diff --git a/.gitignore b/.gitignore index debad77ec2ad..07524bc429e9 100644 --- a/.gitignore +++ b/.gitignore @@ -50,6 +50,7 @@ spark-tests.log streaming-tests.log dependency-reduced-pom.xml .ensime +.ensime_cache/ .ensime_lucene checkpoint derby.log @@ -74,3 +75,7 @@ metastore/ warehouse/ TempStatsStore/ sql/hive-thriftserver/test_warehouses + +# For R session data +.RHistory +.RData diff --git a/.rat-excludes b/.rat-excludes index 236c2db05367..7262c960ed6b 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -15,20 +15,8 @@ TAGS RELEASE control docs -docker.properties.template -fairscheduler.xml.template -spark-defaults.conf.template -log4j.properties -log4j.properties.template -metrics.properties -metrics.properties.template slaves -slaves.template -spark-env.sh spark-env.cmd -spark-env.sh.template -log4j-defaults.properties -log4j-defaults-repl.properties bootstrap-tooltip.js jquery-1.11.1.min.js d3.min.js @@ -93,3 +81,6 @@ INDEX .lintr gen-java.* .*avpr +org.apache.spark.sql.sources.DataSourceRegister +org.apache.spark.scheduler.SparkHistoryListenerFactory +.*parquet diff --git a/LICENSE b/LICENSE index f9e412cade34..a2f75b817ab3 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,3 @@ - Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ @@ -211,717 +210,50 @@ subcomponents is subject to the terms and conditions of the following licenses. -======================================================================= -For the Boto EC2 library (ec2/third_party/boto*.zip): -======================================================================= - -Copyright (c) 2006-2008 Mitch Garnaat http://garnaat.org/ - -Permission is hereby granted, free of charge, to any person obtaining a -copy of this software and associated documentation files (the -"Software"), to deal in the Software without restriction, including -without limitation the rights to use, copy, modify, merge, publish, dis- -tribute, sublicense, and/or sell copies of the Software, and to permit -persons to whom the Software is furnished to do so, subject to the fol- -lowing conditions: - -The above copyright notice and this permission notice shall be included -in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABIL- -ITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT -SHALL THE AUTHOR BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, -WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS -IN THE SOFTWARE. - - -======================================================================== -For CloudPickle (pyspark/cloudpickle.py): -======================================================================== - -Copyright (c) 2012, Regents of the University of California. -Copyright (c) 2009 `PiCloud, Inc. `_. -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions -are met: - * Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. 
- * Neither the name of the University of California, Berkeley nor the - names of its contributors may be used to endorse or promote - products derived from this software without specific prior written - permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED -TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF -LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING -NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - -======================================================================== -For Py4J (python/lib/py4j-0.8.2.1-src.zip) -======================================================================== - -Copyright (c) 2009-2011, Barthelemy Dagenais All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -- Redistributions of source code must retain the above copyright notice, this -list of conditions and the following disclaimer. - -- Redistributions in binary form must reproduce the above copyright notice, -this list of conditions and the following disclaimer in the documentation -and/or other materials provided with the distribution. - -- The name of the author may not be used to endorse or promote products -derived from this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE -LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -POSSIBILITY OF SUCH DAMAGE. - - -======================================================================== -For DPark join code (python/pyspark/join.py): -======================================================================== - -Copyright (c) 2011, Douban Inc. -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. - - * Redistributions in binary form must reproduce the above -copyright notice, this list of conditions and the following disclaimer -in the documentation and/or other materials provided with the -distribution. - - * Neither the name of the Douban Inc. 
nor the names of its -contributors may be used to endorse or promote products derived from -this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - ======================================================================== For heapq (pyspark/heapq3.py): ======================================================================== -# A. HISTORY OF THE SOFTWARE -# ========================== -# -# Python was created in the early 1990s by Guido van Rossum at Stichting -# Mathematisch Centrum (CWI, see http://www.cwi.nl) in the Netherlands -# as a successor of a language called ABC. Guido remains Python's -# principal author, although it includes many contributions from others. -# -# In 1995, Guido continued his work on Python at the Corporation for -# National Research Initiatives (CNRI, see http://www.cnri.reston.va.us) -# in Reston, Virginia where he released several versions of the -# software. -# -# In May 2000, Guido and the Python core development team moved to -# BeOpen.com to form the BeOpen PythonLabs team. In October of the same -# year, the PythonLabs team moved to Digital Creations (now Zope -# Corporation, see http://www.zope.com). In 2001, the Python Software -# Foundation (PSF, see http://www.python.org/psf/) was formed, a -# non-profit organization created specifically to own Python-related -# Intellectual Property. Zope Corporation is a sponsoring member of -# the PSF. -# -# All Python releases are Open Source (see http://www.opensource.org for -# the Open Source Definition). Historically, most, but not all, Python -# releases have also been GPL-compatible; the table below summarizes -# the various releases. -# -# Release Derived Year Owner GPL- -# from compatible? 
(1) -# -# 0.9.0 thru 1.2 1991-1995 CWI yes -# 1.3 thru 1.5.2 1.2 1995-1999 CNRI yes -# 1.6 1.5.2 2000 CNRI no -# 2.0 1.6 2000 BeOpen.com no -# 1.6.1 1.6 2001 CNRI yes (2) -# 2.1 2.0+1.6.1 2001 PSF no -# 2.0.1 2.0+1.6.1 2001 PSF yes -# 2.1.1 2.1+2.0.1 2001 PSF yes -# 2.2 2.1.1 2001 PSF yes -# 2.1.2 2.1.1 2002 PSF yes -# 2.1.3 2.1.2 2002 PSF yes -# 2.2.1 2.2 2002 PSF yes -# 2.2.2 2.2.1 2002 PSF yes -# 2.2.3 2.2.2 2003 PSF yes -# 2.3 2.2.2 2002-2003 PSF yes -# 2.3.1 2.3 2002-2003 PSF yes -# 2.3.2 2.3.1 2002-2003 PSF yes -# 2.3.3 2.3.2 2002-2003 PSF yes -# 2.3.4 2.3.3 2004 PSF yes -# 2.3.5 2.3.4 2005 PSF yes -# 2.4 2.3 2004 PSF yes -# 2.4.1 2.4 2005 PSF yes -# 2.4.2 2.4.1 2005 PSF yes -# 2.4.3 2.4.2 2006 PSF yes -# 2.4.4 2.4.3 2006 PSF yes -# 2.5 2.4 2006 PSF yes -# 2.5.1 2.5 2007 PSF yes -# 2.5.2 2.5.1 2008 PSF yes -# 2.5.3 2.5.2 2008 PSF yes -# 2.6 2.5 2008 PSF yes -# 2.6.1 2.6 2008 PSF yes -# 2.6.2 2.6.1 2009 PSF yes -# 2.6.3 2.6.2 2009 PSF yes -# 2.6.4 2.6.3 2009 PSF yes -# 2.6.5 2.6.4 2010 PSF yes -# 2.7 2.6 2010 PSF yes -# -# Footnotes: -# -# (1) GPL-compatible doesn't mean that we're distributing Python under -# the GPL. All Python licenses, unlike the GPL, let you distribute -# a modified version without making your changes open source. The -# GPL-compatible licenses make it possible to combine Python with -# other software that is released under the GPL; the others don't. -# -# (2) According to Richard Stallman, 1.6.1 is not GPL-compatible, -# because its license has a choice of law clause. According to -# CNRI, however, Stallman's lawyer has told CNRI's lawyer that 1.6.1 -# is "not incompatible" with the GPL. -# -# Thanks to the many outside volunteers who have worked under Guido's -# direction to make these releases possible. -# -# -# B. TERMS AND CONDITIONS FOR ACCESSING OR OTHERWISE USING PYTHON -# =============================================================== -# -# PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2 -# -------------------------------------------- -# -# 1. This LICENSE AGREEMENT is between the Python Software Foundation -# ("PSF"), and the Individual or Organization ("Licensee") accessing and -# otherwise using this software ("Python") in source or binary form and -# its associated documentation. -# -# 2. Subject to the terms and conditions of this License Agreement, PSF hereby -# grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, -# analyze, test, perform and/or display publicly, prepare derivative works, -# distribute, and otherwise use Python alone or in any derivative version, -# provided, however, that PSF's License Agreement and PSF's notice of copyright, -# i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, -# 2011, 2012, 2013 Python Software Foundation; All Rights Reserved" are retained -# in Python alone or in any derivative version prepared by Licensee. -# -# 3. In the event Licensee prepares a derivative work that is based on -# or incorporates Python or any part thereof, and wants to make -# the derivative work available to others as provided herein, then -# Licensee hereby agrees to include in any such work a brief summary of -# the changes made to Python. -# -# 4. PSF is making Python available to Licensee on an "AS IS" -# basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR -# IMPLIED. 
BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND -# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS -# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT -# INFRINGE ANY THIRD PARTY RIGHTS. -# -# 5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON -# FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS -# A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, -# OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. -# -# 6. This License Agreement will automatically terminate upon a material -# breach of its terms and conditions. -# -# 7. Nothing in this License Agreement shall be deemed to create any -# relationship of agency, partnership, or joint venture between PSF and -# Licensee. This License Agreement does not grant permission to use PSF -# trademarks or trade name in a trademark sense to endorse or promote -# products or services of Licensee, or any third party. -# -# 8. By copying, installing or otherwise using Python, Licensee -# agrees to be bound by the terms and conditions of this License -# Agreement. -# -# -# BEOPEN.COM LICENSE AGREEMENT FOR PYTHON 2.0 -# ------------------------------------------- -# -# BEOPEN PYTHON OPEN SOURCE LICENSE AGREEMENT VERSION 1 -# -# 1. This LICENSE AGREEMENT is between BeOpen.com ("BeOpen"), having an -# office at 160 Saratoga Avenue, Santa Clara, CA 95051, and the -# Individual or Organization ("Licensee") accessing and otherwise using -# this software in source or binary form and its associated -# documentation ("the Software"). -# -# 2. Subject to the terms and conditions of this BeOpen Python License -# Agreement, BeOpen hereby grants Licensee a non-exclusive, -# royalty-free, world-wide license to reproduce, analyze, test, perform -# and/or display publicly, prepare derivative works, distribute, and -# otherwise use the Software alone or in any derivative version, -# provided, however, that the BeOpen Python License is retained in the -# Software, alone or in any derivative version prepared by Licensee. -# -# 3. BeOpen is making the Software available to Licensee on an "AS IS" -# basis. BEOPEN MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR -# IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, BEOPEN MAKES NO AND -# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS -# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF THE SOFTWARE WILL NOT -# INFRINGE ANY THIRD PARTY RIGHTS. -# -# 4. BEOPEN SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF THE -# SOFTWARE FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS -# AS A RESULT OF USING, MODIFYING OR DISTRIBUTING THE SOFTWARE, OR ANY -# DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. -# -# 5. This License Agreement will automatically terminate upon a material -# breach of its terms and conditions. -# -# 6. This License Agreement shall be governed by and interpreted in all -# respects by the law of the State of California, excluding conflict of -# law provisions. Nothing in this License Agreement shall be deemed to -# create any relationship of agency, partnership, or joint venture -# between BeOpen and Licensee. This License Agreement does not grant -# permission to use BeOpen trademarks or trade names in a trademark -# sense to endorse or promote products or services of Licensee, or any -# third party. 
As an exception, the "BeOpen Python" logos available at -# http://www.pythonlabs.com/logos.html may be used according to the -# permissions granted on that web page. -# -# 7. By copying, installing or otherwise using the software, Licensee -# agrees to be bound by the terms and conditions of this License -# Agreement. -# -# -# CNRI LICENSE AGREEMENT FOR PYTHON 1.6.1 -# --------------------------------------- -# -# 1. This LICENSE AGREEMENT is between the Corporation for National -# Research Initiatives, having an office at 1895 Preston White Drive, -# Reston, VA 20191 ("CNRI"), and the Individual or Organization -# ("Licensee") accessing and otherwise using Python 1.6.1 software in -# source or binary form and its associated documentation. -# -# 2. Subject to the terms and conditions of this License Agreement, CNRI -# hereby grants Licensee a nonexclusive, royalty-free, world-wide -# license to reproduce, analyze, test, perform and/or display publicly, -# prepare derivative works, distribute, and otherwise use Python 1.6.1 -# alone or in any derivative version, provided, however, that CNRI's -# License Agreement and CNRI's notice of copyright, i.e., "Copyright (c) -# 1995-2001 Corporation for National Research Initiatives; All Rights -# Reserved" are retained in Python 1.6.1 alone or in any derivative -# version prepared by Licensee. Alternately, in lieu of CNRI's License -# Agreement, Licensee may substitute the following text (omitting the -# quotes): "Python 1.6.1 is made available subject to the terms and -# conditions in CNRI's License Agreement. This Agreement together with -# Python 1.6.1 may be located on the Internet using the following -# unique, persistent identifier (known as a handle): 1895.22/1013. This -# Agreement may also be obtained from a proxy server on the Internet -# using the following URL: http://hdl.handle.net/1895.22/1013". -# -# 3. In the event Licensee prepares a derivative work that is based on -# or incorporates Python 1.6.1 or any part thereof, and wants to make -# the derivative work available to others as provided herein, then -# Licensee hereby agrees to include in any such work a brief summary of -# the changes made to Python 1.6.1. -# -# 4. CNRI is making Python 1.6.1 available to Licensee on an "AS IS" -# basis. CNRI MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR -# IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, CNRI MAKES NO AND -# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS -# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON 1.6.1 WILL NOT -# INFRINGE ANY THIRD PARTY RIGHTS. -# -# 5. CNRI SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON -# 1.6.1 FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS -# A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON 1.6.1, -# OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. -# -# 6. This License Agreement will automatically terminate upon a material -# breach of its terms and conditions. -# -# 7. This License Agreement shall be governed by the federal -# intellectual property law of the United States, including without -# limitation the federal copyright law, and, to the extent such -# U.S. federal law does not apply, by the law of the Commonwealth of -# Virginia, excluding Virginia's conflict of law provisions. 
-# Notwithstanding the foregoing, with regard to derivative works based -# on Python 1.6.1 that incorporate non-separable material that was -# previously distributed under the GNU General Public License (GPL), the -# law of the Commonwealth of Virginia shall govern this License -# Agreement only as to issues arising under or with respect to -# Paragraphs 4, 5, and 7 of this License Agreement. Nothing in this -# License Agreement shall be deemed to create any relationship of -# agency, partnership, or joint venture between CNRI and Licensee. This -# License Agreement does not grant permission to use CNRI trademarks or -# trade name in a trademark sense to endorse or promote products or -# services of Licensee, or any third party. -# -# 8. By clicking on the "ACCEPT" button where indicated, or by copying, -# installing or otherwise using Python 1.6.1, Licensee agrees to be -# bound by the terms and conditions of this License Agreement. -# -# ACCEPT -# -# -# CWI LICENSE AGREEMENT FOR PYTHON 0.9.0 THROUGH 1.2 -# -------------------------------------------------- -# -# Copyright (c) 1991 - 1995, Stichting Mathematisch Centrum Amsterdam, -# The Netherlands. All rights reserved. -# -# Permission to use, copy, modify, and distribute this software and its -# documentation for any purpose and without fee is hereby granted, -# provided that the above copyright notice appear in all copies and that -# both that copyright notice and this permission notice appear in -# supporting documentation, and that the name of Stichting Mathematisch -# Centrum or CWI not be used in advertising or publicity pertaining to -# distribution of the software without specific, written prior -# permission. -# -# STICHTING MATHEMATISCH CENTRUM DISCLAIMS ALL WARRANTIES WITH REGARD TO -# THIS SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND -# FITNESS, IN NO EVENT SHALL STICHTING MATHEMATISCH CENTRUM BE LIABLE -# FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES -# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN -# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT -# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - -======================================================================== -For sorttable (core/src/main/resources/org/apache/spark/ui/static/sorttable.js): -======================================================================== - -Copyright (c) 1997-2007 Stuart Langridge - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. 
- -======================================================================== -For d3 (core/src/main/resources/org/apache/spark/ui/static/d3.min.js): -======================================================================== - -Copyright (c) 2010-2015, Michael Bostock -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -* Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - -* Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - -* The name Michael Bostock may not be used to endorse or promote products - derived from this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL MICHAEL BOSTOCK BE LIABLE FOR ANY DIRECT, -INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, -BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY -OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING -NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, -EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -======================================================================== -For Scala Interpreter classes (all .scala files in repl/src/main/scala -except for Main.Scala, SparkHelper.scala and ExecutorClassLoader.scala), -and for SerializableMapWrapper in JavaUtils.scala: -======================================================================== - -Copyright (c) 2002-2013 EPFL -Copyright (c) 2011-2013 Typesafe, Inc. - -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -- Redistributions of source code must retain the above copyright notice, - this list of conditions and the following disclaimer. - -- Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - -- Neither the name of the EPFL nor the names of its contributors may be - used to endorse or promote products derived from this software without - specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -POSSIBILITY OF SUCH DAMAGE. - - -======================================================================== -For sbt and sbt-launch-lib.bash in sbt/: -======================================================================== - -// Generated from http://www.opensource.org/licenses/bsd-license.php -Copyright (c) 2011, Paul Phillips. -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - - * Redistributions of source code must retain the above copyright notice, - this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - * Neither the name of the author nor the names of its contributors may be - used to endorse or promote products derived from this software without - specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE -LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; -LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING -NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, -EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +See license/LICENSE-heapq.txt ======================================================================== For SnapTree: ======================================================================== -SNAPTREE LICENSE - -Copyright (c) 2009-2012 Stanford University, unless otherwise specified. -All rights reserved. - -This software was developed by the Pervasive Parallelism Laboratory of -Stanford University, California, USA. - -Permission to use, copy, modify, and distribute this software in source -or binary form for any purpose with or without fee is hereby granted, -provided that the following conditions are met: - - 1. Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - 2. Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - 3. Neither the name of Stanford University nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. 
- - -THIS SOFTWARE IS PROVIDED BY THE REGENTS AND CONTRIBUTORS ``AS IS'' AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT -LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY -OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF -SUCH DAMAGE. - - -======================================================================== -For Timsort (core/src/main/java/org/apache/spark/util/collection/TimSort.java): -======================================================================== -Copyright (C) 2008 The Android Open Source Project - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - -======================================================================== -For TestTimSort (core/src/test/java/org/apache/spark/util/collection/TestTimSort.java): -======================================================================== -Copyright (C) 2015 Stijn de Gouw - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - -======================================================================== -For LimitedInputStream - (network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java): -======================================================================== -Copyright (C) 2007 The Guava Authors - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - -======================================================================== -For vis.js (core/src/main/resources/org/apache/spark/ui/static/vis.min.js): -======================================================================== -Copyright (C) 2010-2015 Almende B.V. 
- -Vis.js is dual licensed under both - - * The Apache 2.0 License - http://www.apache.org/licenses/LICENSE-2.0 - -and - - * The MIT License - http://opensource.org/licenses/MIT - -Vis.js may be distributed under either license. +See license/LICENSE-SnapTree.txt ======================================================================== -For dagre-d3 (core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js): +For jbcrypt: ======================================================================== -Copyright (c) 2013 Chris Pettitt - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. -======================================================================== -For graphlib-dot (core/src/main/resources/org/apache/spark/ui/static/graphlib-dot.min.js): -======================================================================== -Copyright (c) 2012-2013 Chris Pettitt - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. +See license/LICENSE-jbcrypt.txt ======================================================================== BSD-style licenses ======================================================================== The following components are provided under a BSD-style license. See project link for details. +The text of each license is also included at licenses/LICENSE-[project].txt. 
- (BSD 3 Clause) core (com.github.fommil.netlib:core:1.1.2 - https://github.com/fommil/netlib-java/core) - (BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.1.15 - https://github.com/jpmml/jpmml-model) + (BSD 3 Clause) netlib core (com.github.fommil.netlib:core:1.1.2 - https://github.com/fommil/netlib-java/core) + (BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.2.7 - https://github.com/jpmml/jpmml-model) (BSD 3-clause style license) jblas (org.jblas:jblas:1.2.4 - http://jblas.org/) (BSD License) AntLR Parser Generator (antlr:antlr:2.7.7 - http://www.antlr.org/) - (BSD License) Javolution (javolution:javolution:5.5.1 - http://javolution.org) (BSD licence) ANTLR ST4 4.0.4 (org.antlr:ST4:4.0.4 - http://www.stringtemplate.org) (BSD licence) ANTLR StringTemplate (org.antlr:stringtemplate:3.2.1 - http://www.stringtemplate.org) - (BSD style) Hamcrest Core (org.hamcrest:hamcrest-core:1.1 - no url defined) + (BSD License) Javolution (javolution:javolution:5.5.1 - http://javolution.org) (BSD) JLine (jline:jline:0.9.94 - http://jline.sourceforge.net) (BSD) ParaNamer Core (com.thoughtworks.paranamer:paranamer:2.3 - http://paranamer.codehaus.org/paranamer) (BSD) ParaNamer Core (com.thoughtworks.paranamer:paranamer:2.6 - http://paranamer.codehaus.org/paranamer) - (BSD-like) (The BSD License) jline (org.scala-lang:jline:2.10.4 - http://www.scala-lang.org/) - (BSD-like) Scala Actors library (org.scala-lang:scala-actors:2.10.4 - http://www.scala-lang.org/) - (BSD-like) Scala Compiler (org.scala-lang:scala-compiler:2.10.4 - http://www.scala-lang.org/) - (BSD-like) Scala Compiler (org.scala-lang:scala-reflect:2.10.4 - http://www.scala-lang.org/) - (BSD-like) Scala Library (org.scala-lang:scala-library:2.10.4 - http://www.scala-lang.org/) - (BSD-like) Scalap (org.scala-lang:scalap:2.10.4 - http://www.scala-lang.org/) + (BSD 3 Clause) Scala (http://www.scala-lang.org/download/#License) + (Interpreter classes (all .scala files in repl/src/main/scala + except for Main.Scala, SparkHelper.scala and ExecutorClassLoader.scala), + and for SerializableMapWrapper in JavaUtils.scala) + (BSD-like) Scala Actors library (org.scala-lang:scala-actors:2.10.5 - http://www.scala-lang.org/) + (BSD-like) Scala Compiler (org.scala-lang:scala-compiler:2.10.5 - http://www.scala-lang.org/) + (BSD-like) Scala Compiler (org.scala-lang:scala-reflect:2.10.5 - http://www.scala-lang.org/) + (BSD-like) Scala Library (org.scala-lang:scala-library:2.10.5 - http://www.scala-lang.org/) + (BSD-like) Scalap (org.scala-lang:scalap:2.10.5 - http://www.scala-lang.org/) (BSD-style) scalacheck (org.scalacheck:scalacheck_2.10:1.10.0 - http://www.scalacheck.org) (BSD-style) spire (org.spire-math:spire_2.10:0.7.1 - http://spire-math.org) (BSD-style) spire-macros (org.spire-math:spire-macros_2.10:0.7.1 - http://spire-math.org) @@ -932,15 +264,19 @@ The following components are provided under a BSD-style license. 
See project lin (New BSD license) Protocol Buffer Java API (org.spark-project.protobuf:protobuf-java:2.4.1-shaded - http://code.google.com/p/protobuf) (The BSD License) Fortran to Java ARPACK (net.sourceforge.f2j:arpack_combined_all:0.1 - http://f2j.sourceforge.net) (The BSD License) xmlenc Library (xmlenc:xmlenc:0.52 - http://xmlenc.sourceforge.net) - (The New BSD License) Py4J (net.sf.py4j:py4j:0.8.2.1 - http://py4j.sourceforge.net/) + (The New BSD License) Py4J (net.sf.py4j:py4j:0.9 - http://py4j.sourceforge.net/) (Two-clause BSD-style license) JUnit-Interface (com.novocode:junit-interface:0.10 - http://github.com/szeiger/junit-interface/) - (ISC/BSD License) jbcrypt (org.mindrot:jbcrypt:0.3m - http://www.mindrot.org/) + (BSD licence) sbt and sbt-launch-lib.bash + (BSD 3 Clause) d3.min.js (https://github.com/mbostock/d3/blob/master/LICENSE) + (BSD 3 Clause) DPark (https://github.com/douban/dpark/blob/master/LICENSE) + (BSD 3 Clause) CloudPickle (https://github.com/cloudpipe/cloudpickle/blob/master/LICENSE) ======================================================================== MIT licenses ======================================================================== The following components are provided under the MIT License. See project link for details. +The text of each license is also included at licenses/LICENSE-[project].txt. (MIT License) JCL 1.1.1 implemented over SLF4J (org.slf4j:jcl-over-slf4j:1.7.5 - http://www.slf4j.org) (MIT License) JUL to SLF4J bridge (org.slf4j:jul-to-slf4j:1.7.5 - http://www.slf4j.org) @@ -951,3 +287,7 @@ The following components are provided under the MIT License. See project link fo (The MIT License) Mockito (org.mockito:mockito-core:1.9.5 - http://www.mockito.org) (MIT License) jquery (https://jquery.org/license/) (MIT License) AnchorJS (https://github.com/bryanbraun/anchorjs) + (MIT License) graphlib-dot (https://github.com/cpettitt/graphlib-dot) + (MIT License) dagre-d3 (https://github.com/cpettitt/dagre-d3) + (MIT License) sorttable (https://github.com/stuartlangridge/sorttable) + (MIT License) boto (https://github.com/boto/boto/blob/develop/LICENSE) diff --git a/NOTICE b/NOTICE index 452aef287165..571f8c2fff7f 100644 --- a/NOTICE +++ b/NOTICE @@ -572,3 +572,45 @@ Copyright 2009-2013 The Apache Software Foundation Apache Avro IPC Copyright 2009-2013 The Apache Software Foundation + + +Vis.js +Copyright 2010-2015 Almende B.V. + +Vis.js is dual licensed under both + + * The Apache 2.0 License + http://www.apache.org/licenses/LICENSE-2.0 + + and + + * The MIT License + http://opensource.org/licenses/MIT + +Vis.js may be distributed under either license. + + +Vis.js uses and redistributes the following third-party libraries: + +- component-emitter + https://github.com/component/emitter + The MIT License + +- hammer.js + http://hammerjs.github.io/ + The MIT License + +- moment.js + http://momentjs.com/ + The MIT License + +- keycharm + https://github.com/AlexDM0/keycharm + The MIT License + +=============================================================================== + +The CSS style for the navigation sidebar of the documentation was originally +submitted by Óscar Nájera for the scikit-learn project. The scikit-learn project +is distributed under the 3-Clause BSD license. 
+=============================================================================== diff --git a/R/create-docs.sh b/R/create-docs.sh index 6a4687b06ecb..d2ae160b5002 100755 --- a/R/create-docs.sh +++ b/R/create-docs.sh @@ -39,7 +39,7 @@ pushd $FWDIR mkdir -p pkg/html pushd pkg/html -Rscript -e 'library(SparkR, lib.loc="../../lib"); library(knitr); knit_rd("SparkR")' +Rscript -e 'libDir <- "../../lib"; library(SparkR, lib.loc=libDir); library(knitr); knit_rd("SparkR", links = tools::findHTMLlinks(paste(libDir, "SparkR", sep="/")))' popd diff --git a/R/install-dev.bat b/R/install-dev.bat index f32670b67de9..ed1c91ae3a0f 100644 --- a/R/install-dev.bat +++ b/R/install-dev.bat @@ -30,3 +30,4 @@ rem Zip the SparkR package so that it can be distributed to worker nodes on YARN pushd %SPARK_HOME%\R\lib %JAVA_HOME%\bin\jar.exe cfM "%SPARK_HOME%\R\lib\sparkr.zip" SparkR popd + diff --git a/R/pkg/.lintr b/R/pkg/.lintr index 038236fc149e..39c872663ad4 100644 --- a/R/pkg/.lintr +++ b/R/pkg/.lintr @@ -1,2 +1,2 @@ -linters: with_defaults(line_length_linter(100), camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE)) +linters: with_defaults(line_length_linter(100), camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE), commented_code_linter = NULL) exclusions: list("inst/profile/general.R" = 1, "inst/profile/shell.R") diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 4949d86d20c9..369714f7b99c 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -1,7 +1,7 @@ Package: SparkR Type: Package Title: R frontend for Spark -Version: 1.4.0 +Version: 1.6.0 Date: 2013-09-09 Author: The Apache Software Foundation Maintainer: Shivaram Venkataraman @@ -29,7 +29,10 @@ Collate: 'client.R' 'context.R' 'deserialize.R' + 'functions.R' 'mllib.R' 'serialize.R' 'sparkR.R' + 'stats.R' + 'types.R' 'utils.R' diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index b2d92bdf4840..ccc01fe16960 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -23,10 +23,18 @@ export("setJobGroup", exportClasses("DataFrame") exportMethods("arrange", + "as.data.frame", + "attach", "cache", "collect", + "colnames", + "colnames<-", + "coltypes", + "coltypes<-", "columns", "count", + "cov", + "corr", "crosstab", "describe", "dim", @@ -38,6 +46,7 @@ exportMethods("arrange", "fillna", "filter", "first", + "freqItems", "group_by", "groupBy", "head", @@ -47,12 +56,13 @@ exportMethods("arrange", "join", "limit", "merge", + "mutate", + "na.omit", "names", + "names<-", "ncol", "nrow", "orderBy", - "mutate", - "names", "persist", "printSchema", "rbind", @@ -61,6 +71,7 @@ exportMethods("arrange", "repartition", "sample", "sample_frac", + "sampleBy", "saveAsParquetFile", "saveAsTable", "saveDF", @@ -69,72 +80,181 @@ exportMethods("arrange", "selectExpr", "show", "showDF", + "subset", "summarize", "summary", "take", + "transform", "unionAll", "unique", "unpersist", "where", + "with", "withColumn", "withColumnRenamed", - "write.df") + "write.df", + "write.json", + "write.parquet") exportClasses("Column") -exportMethods("abs", +exportMethods("%in%", + "abs", "acos", + "add_months", "alias", "approxCountDistinct", + "array_contains", "asc", + "ascii", "asin", "atan", "atan2", "avg", + "base64", "between", + "bin", + "bitwiseNOT", "cast", "cbrt", + "ceil", "ceiling", + "column", + "concat", + "concat_ws", "contains", + "conv", "cos", "cosh", + "count", "countDistinct", + "crc32", + "cume_dist", + "date_add", + "date_format", + 
"date_sub", + "datediff", + "dayofmonth", + "dayofyear", + "decode", + "dense_rank", "desc", + "encode", "endsWith", "exp", + "explode", "expm1", + "expr", + "factorial", + "first", "floor", + "format_number", + "format_string", + "from_unixtime", + "from_utc_timestamp", "getField", "getItem", + "greatest", + "hex", + "hour", "hypot", + "ifelse", + "initcap", + "instr", + "isNaN", "isNotNull", "isNull", + "kurtosis", + "lag", "last", + "last_day", + "lead", + "least", + "length", + "levenshtein", "like", + "lit", + "locate", "log", "log10", "log1p", + "log2", "lower", + "lpad", + "ltrim", "max", + "md5", "mean", "min", + "minute", + "month", + "months_between", "n", "n_distinct", + "nanvl", + "negate", + "next_day", + "ntile", + "otherwise", + "percent_rank", + "pmod", + "quarter", + "rand", + "randn", + "rank", + "regexp_extract", + "regexp_replace", + "reverse", "rint", "rlike", + "round", + "row_number", + "rpad", + "rtrim", + "second", + "sha1", + "sha2", + "shiftLeft", + "shiftRight", + "shiftRightUnsigned", + "sd", "sign", + "signum", "sin", "sinh", + "size", + "skewness", + "sort_array", + "soundex", + "stddev", + "stddev_pop", + "stddev_samp", + "struct", "sqrt", "startsWith", "substr", + "substring_index", "sum", "sumDistinct", "tan", "tanh", "toDegrees", "toRadians", - "upper") + "to_date", + "to_utc_timestamp", + "translate", + "trim", + "unbase64", + "unhex", + "unix_timestamp", + "upper", + "var", + "variance", + "var_pop", + "var_samp", + "weekofyear", + "when", + "year") exportClasses("GroupedData") exportMethods("agg") @@ -142,15 +262,18 @@ exportMethods("agg") export("sparkRSQL.init", "sparkRHive.init") -export("cacheTable", +export("as.DataFrame", + "cacheTable", "clearCache", "createDataFrame", "createExternalTable", "dropTempTable", "jsonFile", + "read.json", "loadDF", "parquetFile", "read.df", + "read.parquet", "sql", "table", "tableNames", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 895603235011..0cfa12b997d6 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -23,14 +23,23 @@ NULL setOldClass("jobj") #' @title S4 class that represents a DataFrame -#' @description DataFrames can be created using functions like -#' \code{jsonFile}, \code{table} etc. +#' @description DataFrames can be created using functions like \link{createDataFrame}, +#' \link{read.json}, \link{table} etc. 
+#' @family DataFrame functions #' @rdname DataFrame -#' @seealso jsonFile, table +#' @docType class #' -#' @param env An R environment that stores bookkeeping states of the DataFrame -#' @param sdf A Java object reference to the backing Scala DataFrame +#' @slot env An R environment that stores bookkeeping states of the DataFrame +#' @slot sdf A Java object reference to the backing Scala DataFrame +#' @seealso \link{createDataFrame}, \link{read.json}, \link{table} +#' @seealso \url{https://spark.apache.org/docs/latest/sparkr.html#sparkr-dataframes} #' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' df <- createDataFrame(sqlContext, faithful) +#'} setClass("DataFrame", slots = list(env = "environment", sdf = "jobj")) @@ -45,7 +54,6 @@ setMethod("initialize", "DataFrame", function(.Object, sdf, isCached) { #' @rdname DataFrame #' @export -#' #' @param sdf A Java object reference to the backing Scala DataFrame #' @param isCached TRUE if the dataFrame is cached dataFrame <- function(sdf, isCached = FALSE) { @@ -60,14 +68,16 @@ dataFrame <- function(sdf, isCached = FALSE) { #' #' @param x A SparkSQL DataFrame #' +#' @family DataFrame functions #' @rdname printSchema +#' @name printSchema #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' printSchema(df) #'} setMethod("printSchema", @@ -83,14 +93,16 @@ setMethod("printSchema", #' #' @param x A SparkSQL DataFrame #' +#' @family DataFrame functions #' @rdname schema +#' @name schema #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' dfSchema <- schema(df) #'} setMethod("schema", @@ -105,14 +117,16 @@ setMethod("schema", #' #' @param x A SparkSQL DataFrame #' @param extended Logical. If extended is False, explain() only prints the physical plan. +#' @family DataFrame functions #' @rdname explain +#' @name explain #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' explain(df, TRUE) #'} setMethod("explain", @@ -134,14 +148,16 @@ setMethod("explain", #' #' @param x A SparkSQL DataFrame #' +#' @family DataFrame functions #' @rdname isLocal +#' @name isLocal #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' isLocal(df) #'} setMethod("isLocal", @@ -157,14 +173,16 @@ setMethod("isLocal", #' @param x A SparkSQL DataFrame #' @param numRows The number of rows to print. Defaults to 20. 
#' +#' @family DataFrame functions #' @rdname showDF +#' @name showDF #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' showDF(df) #'} setMethod("showDF", @@ -180,14 +198,16 @@ setMethod("showDF", #' #' @param x A SparkSQL DataFrame #' +#' @family DataFrame functions #' @rdname show +#' @name show #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' df #'} setMethod("show", "DataFrame", @@ -205,14 +225,16 @@ setMethod("show", "DataFrame", #' #' @param x A SparkSQL DataFrame #' +#' @family DataFrame functions #' @rdname dtypes +#' @name dtypes #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' dtypes(df) #'} setMethod("dtypes", @@ -229,15 +251,19 @@ setMethod("dtypes", #' #' @param x A SparkSQL DataFrame #' +#' @family DataFrame functions #' @rdname columns +#' @name columns + #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' columns(df) +#' colnames(df) #'} setMethod("columns", signature(x = "DataFrame"), @@ -248,7 +274,7 @@ setMethod("columns", }) #' @rdname columns -#' @aliases names,DataFrame,function-method +#' @name names setMethod("names", signature(x = "DataFrame"), function(x) { @@ -256,15 +282,131 @@ setMethod("names", }) #' @rdname columns +#' @name names<- setMethod("names<-", signature(x = "DataFrame"), function(x, value) { if (!is.null(value)) { - sdf <- callJMethod(x@sdf, "toDF", listToSeq(as.list(value))) + sdf <- callJMethod(x@sdf, "toDF", as.list(value)) dataFrame(sdf) } }) +#' @rdname columns +#' @name colnames +setMethod("colnames", + signature(x = "DataFrame"), + function(x) { + columns(x) + }) + +#' @rdname columns +#' @name colnames<- +setMethod("colnames<-", + signature(x = "DataFrame", value = "character"), + function(x, value) { + sdf <- callJMethod(x@sdf, "toDF", as.list(value)) + dataFrame(sdf) + }) + +#' coltypes +#' +#' Get column types of a DataFrame +#' +#' @param x A SparkSQL DataFrame +#' @return value A character vector with the column types of the given DataFrame +#' @rdname coltypes +#' @name coltypes +#' @family DataFrame functions +#' @export +#' @examples +#'\dontrun{ +#' irisDF <- createDataFrame(sqlContext, iris) +#' coltypes(irisDF) +#'} +setMethod("coltypes", + signature(x = "DataFrame"), + function(x) { + # Get the data types of the DataFrame by invoking dtypes() function + types <- sapply(dtypes(x), function(x) {x[[2]]}) + + # Map Spark data types into R's data types using DATA_TYPES environment + rTypes <- sapply(types, USE.NAMES=F, FUN=function(x) { + # Check for primitive types + type <- PRIMITIVE_TYPES[[x]] + + if (is.null(type)) { + # Check for complex types + for (t in names(COMPLEX_TYPES)) { + if (substring(x, 1, nchar(t)) == t) { + type <- COMPLEX_TYPES[[t]] + break + } + } + + if (is.null(type)) { + stop(paste("Unsupported data type: ", x)) + } + } + type + }) + + # Find which types don't have mapping to R + naIndices <- which(is.na(rTypes)) + + # Assign the original scala data types to the unmatched ones + rTypes[naIndices] <- 
types[naIndices] + + rTypes + }) + +#' coltypes +#' +#' Set the column types of a DataFrame. +#' +#' @param x A SparkSQL DataFrame +#' @param value A character vector with the target column types for the given +#' DataFrame. Column types can be one of integer, numeric/double, character, logical, or NA +#' to keep that column as-is. +#' @rdname coltypes +#' @name coltypes<- +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- read.json(sqlContext, path) +#' coltypes(df) <- c("character", "integer") +#' coltypes(df) <- c(NA, "numeric") +#'} +setMethod("coltypes<-", + signature(x = "DataFrame", value = "character"), + function(x, value) { + cols <- columns(x) + ncols <- length(cols) + if (length(value) == 0) { + stop("Cannot set types of an empty DataFrame with no Column") + } + if (length(value) != ncols) { + stop("Length of type vector should match the number of columns for DataFrame") + } + newCols <- lapply(seq_len(ncols), function(i) { + col <- getColumn(x, cols[i]) + if (!is.na(value[i])) { + stype <- rToSQLTypes[[value[i]]] + if (is.null(stype)) { + stop("Only atomic type is supported for column types") + } + cast(col, stype) + } else { + col + } + }) + nx <- select(x, newCols) + dataFrame(nx@sdf) + }) + #' Register Temporary Table #' #' Registers a DataFrame as a Temporary Table in the SQLContext @@ -272,14 +414,16 @@ setMethod("names<-", #' @param x A SparkSQL DataFrame #' @param tableName A character vector containing the name of the table #' +#' @family DataFrame functions #' @rdname registerTempTable +#' @name registerTempTable #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' registerTempTable(df, "json_df") #' new_df <- sql(sqlContext, "SELECT * FROM json_df") #'} @@ -298,7 +442,9 @@ setMethod("registerTempTable", #' @param overwrite A logical argument indicating whether or not to overwrite #' the existing rows in the table. #' +#' @family DataFrame functions #' @rdname insertInto +#' @name insertInto #' @export #' @examples #'\dontrun{ @@ -321,14 +467,16 @@ setMethod("insertInto", #' #' @param x A SparkSQL DataFrame #' -#' @rdname cache-methods +#' @family DataFrame functions +#' @rdname cache +#' @name cache #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' cache(df) #'} setMethod("cache", @@ -343,17 +491,20 @@ setMethod("cache", #' #' Persist this DataFrame with the specified storage level. For details of the #' supported storage levels, refer to -#' http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence. +#' \url{http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence}. 
#' #' @param x The DataFrame to persist +#' +#' @family DataFrame functions #' @rdname persist +#' @name persist #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' persist(df, "MEMORY_AND_DISK") #'} setMethod("persist", @@ -371,14 +522,17 @@ setMethod("persist", #' #' @param x The DataFrame to unpersist #' @param blocking Whether to block until all blocks are deleted +#' +#' @family DataFrame functions #' @rdname unpersist-methods +#' @name unpersist #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' persist(df, "MEMORY_AND_DISK") #' unpersist(df) #'} @@ -396,14 +550,17 @@ setMethod("unpersist", #' #' @param x A SparkSQL DataFrame #' @param numPartitions The number of partitions to use. +#' +#' @family DataFrame functions #' @rdname repartition +#' @name repartition #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' newDF <- repartition(df, 2L) #'} setMethod("repartition", @@ -413,23 +570,24 @@ setMethod("repartition", dataFrame(sdf) }) -# toJSON -# -# Convert the rows of a DataFrame into JSON objects and return an RDD where -# each element contains a JSON string. -# -#@param x A SparkSQL DataFrame -# @return A StringRRDD of JSON objects -# @rdname tojson -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# sqlContext <- sparkRSQL.init(sc) -# path <- "path/to/file.json" -# df <- jsonFile(sqlContext, path) -# newRDD <- toJSON(df) -#} +#' toJSON +#' +#' Convert the rows of a DataFrame into JSON objects and return an RDD where +#' each element contains a JSON string. +#' +#' @param x A SparkSQL DataFrame +#' @return A StringRRDD of JSON objects +#' @family DataFrame functions +#' @rdname tojson +#' @noRd +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- read.json(sqlContext, path) +#' newRDD <- toJSON(df) +#'} setMethod("toJSON", signature(x = "DataFrame"), function(x) { @@ -438,27 +596,69 @@ setMethod("toJSON", RDD(jrdd, serializedMode = "string") }) -#' saveAsParquetFile +#' write.json +#' +#' Save the contents of a DataFrame as a JSON file (one object per line). Files written out +#' with this method can be read back in as a DataFrame using read.json(). +#' +#' @param x A SparkSQL DataFrame +#' @param path The directory where the file is saved +#' +#' @family DataFrame functions +#' @rdname write.json +#' @name write.json +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- read.json(sqlContext, path) +#' write.json(df, "/tmp/sparkr-tmp/") +#'} +setMethod("write.json", + signature(x = "DataFrame", path = "character"), + function(x, path) { + write <- callJMethod(x@sdf, "write") + invisible(callJMethod(write, "json", path)) + }) + +#' write.parquet #' #' Save the contents of a DataFrame as a Parquet file, preserving the schema. Files written out -#' with this method can be read back in as a DataFrame using parquetFile(). +#' with this method can be read back in as a DataFrame using read.parquet(). 
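A short sketch of the new write.json writer paired with read.json, assuming the df from the examples above and a writable temporary directory (the output path is hypothetical):

# Hypothetical JSON round trip through a temporary directory.
jsonPath <- file.path(tempdir(), "people_json")
write.json(df, jsonPath)
df2 <- read.json(sqlContext, jsonPath)
count(df2) == count(df)  # TRUE if the round trip preserved all rows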
#' #' @param x A SparkSQL DataFrame #' @param path The directory where the file is saved -#' @rdname saveAsParquetFile +#' +#' @family DataFrame functions +#' @rdname write.parquet +#' @name write.parquet #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) -#' saveAsParquetFile(df, "/tmp/sparkr-tmp/") +#' df <- read.json(sqlContext, path) +#' write.parquet(df, "/tmp/sparkr-tmp1/") +#' saveAsParquetFile(df, "/tmp/sparkr-tmp2/") #'} +setMethod("write.parquet", + signature(x = "DataFrame", path = "character"), + function(x, path) { + write <- callJMethod(x@sdf, "write") + invisible(callJMethod(write, "parquet", path)) + }) + +#' @rdname write.parquet +#' @name saveAsParquetFile +#' @export setMethod("saveAsParquetFile", signature(x = "DataFrame", path = "character"), function(x, path) { - invisible(callJMethod(x@sdf, "saveAsParquetFile", path)) + .Deprecated("write.parquet") + write.parquet(x, path) }) #' Distinct @@ -466,14 +666,17 @@ setMethod("saveAsParquetFile", #' Return a new DataFrame containing the distinct rows in this DataFrame. #' #' @param x A SparkSQL DataFrame +#' +#' @family DataFrame functions #' @rdname distinct +#' @name distinct #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' distinctDF <- distinct(df) #'} setMethod("distinct", @@ -483,12 +686,8 @@ setMethod("distinct", dataFrame(sdf) }) -#' @title Distinct rows in a DataFrame -# -#' @description Returns a new DataFrame containing distinct rows in this DataFrame -#' -#' @rdname unique -#' @aliases unique +#' @rdname distinct +#' @name unique setMethod("unique", signature(x = "DataFrame"), function(x) { @@ -502,52 +701,61 @@ setMethod("unique", #' @param x A SparkSQL DataFrame #' @param withReplacement Sampling with replacement or not #' @param fraction The (rough) sample target fraction +#' @param seed Randomness seed value +#' +#' @family DataFrame functions #' @rdname sample -#' @aliases sample_frac +#' @name sample #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' collect(sample(df, FALSE, 0.5)) #' collect(sample(df, TRUE, 0.5)) #'} setMethod("sample", - # TODO : Figure out how to send integer as java.lang.Long to JVM so - # we can send seed as an argument through callJMethod signature(x = "DataFrame", withReplacement = "logical", fraction = "numeric"), - function(x, withReplacement, fraction) { + function(x, withReplacement, fraction, seed) { if (fraction < 0.0) stop(cat("Negative fraction value:", fraction)) - sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction) + if (!missing(seed)) { + # TODO : Figure out how to send integer as java.lang.Long to JVM so + # we can send seed as an argument through callJMethod + sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction, as.integer(seed)) + } else { + sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction) + } dataFrame(sdf) }) #' @rdname sample -#' @aliases sample +#' @name sample_frac setMethod("sample_frac", signature(x = "DataFrame", withReplacement = "logical", fraction = "numeric"), - function(x, withReplacement, fraction) { - sample(x, withReplacement, fraction) + function(x, withReplacement, fraction, seed) { + sample(x, 
withReplacement, fraction, seed) }) -#' Count +#' nrow #' #' Returns the number of rows in a DataFrame #' #' @param x A SparkSQL DataFrame #' -#' @rdname count +#' @family DataFrame functions +#' @rdname nrow +#' @name count #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' count(df) #' } setMethod("count", @@ -556,13 +764,8 @@ setMethod("count", callJMethod(x@sdf, "count") }) -#' @title Number of rows for a DataFrame -#' @description Returns number of rows in a DataFrames -#' #' @name nrow -#' #' @rdname nrow -#' @aliases count setMethod("nrow", signature(x = "DataFrame"), function(x) { @@ -573,14 +776,16 @@ setMethod("nrow", #' #' @param x a SparkSQL DataFrame #' +#' @family DataFrame functions #' @rdname ncol +#' @name ncol #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' ncol(df) #' } setMethod("ncol", @@ -592,14 +797,16 @@ setMethod("ncol", #' Returns the dimentions (number of rows and columns) of a DataFrame #' @param x a SparkSQL DataFrame #' +#' @family DataFrame functions #' @rdname dim +#' @name dim #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' dim(df) #' } setMethod("dim", @@ -613,32 +820,69 @@ setMethod("dim", #' @param x A SparkSQL DataFrame #' @param stringsAsFactors (Optional) A logical indicating whether or not string columns #' should be converted to factors. FALSE by default. - -#' @rdname collect-methods +#' +#' @family DataFrame functions +#' @rdname collect +#' @name collect #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' collected <- collect(df) #' firstName <- collected[[1]]$name #' } setMethod("collect", signature(x = "DataFrame"), function(x, stringsAsFactors = FALSE) { - # listCols is a list of raw vectors, one per column - listCols <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToCols", x@sdf) - cols <- lapply(listCols, function(col) { - objRaw <- rawConnection(col) - numRows <- readInt(objRaw) - col <- readCol(objRaw, numRows) - close(objRaw) - col - }) - names(cols) <- columns(x) - do.call(cbind.data.frame, list(cols, stringsAsFactors = stringsAsFactors)) + dtypes <- dtypes(x) + ncol <- length(dtypes) + if (ncol <= 0) { + # empty data.frame with 0 columns and 0 rows + data.frame() + } else { + # listCols is a list of columns + listCols <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToCols", x@sdf) + stopifnot(length(listCols) == ncol) + + # An empty data.frame with 0 columns and number of rows as collected + nrow <- length(listCols[[1]]) + if (nrow <= 0) { + df <- data.frame() + } else { + df <- data.frame(row.names = 1 : nrow) + } + + # Append columns one by one + for (colIndex in 1 : ncol) { + # Note: appending a column of list type into a data.frame so that + # data of complex type can be held. But getting a cell from a column + # of list type returns a list instead of a vector. So for columns of + # non-complex type, append them as vector. + # + # For columns of complex type, be careful to access them. 
+ # Get a column of complex type returns a list. + # Get a cell from a column of complex type returns a list instead of a vector. + col <- listCols[[colIndex]] + if (length(col) <= 0) { + df[[colIndex]] <- col + } else { + colType <- dtypes[[colIndex]][[2]] + # Note that "binary" columns behave like complex types. + if (!is.null(PRIMITIVE_TYPES[[colType]]) && colType != "binary") { + vec <- do.call(c, col) + stopifnot(class(vec) != "list") + df[[colIndex]] <- vec + } else { + df[[colIndex]] <- col + } + } + } + names(df) <- names(x) + df + } }) #' Limit @@ -649,14 +893,16 @@ setMethod("collect", #' @param num The number of rows to return #' @return A new DataFrame containing the number of rows specified. #' +#' @family DataFrame functions #' @rdname limit +#' @name limit #' @export #' @examples #' \dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' limitedDF <- limit(df, 10) #' } setMethod("limit", @@ -668,14 +914,16 @@ setMethod("limit", #' Take the first NUM rows of a DataFrame and return a the results as a data.frame #' +#' @family DataFrame functions #' @rdname take +#' @name take #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' take(df, 2) #' } setMethod("take", @@ -695,14 +943,16 @@ setMethod("take", #' @param num The number of rows to return. Default is 6. #' @return A data.frame #' +#' @family DataFrame functions #' @rdname head +#' @name head #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' head(df) #' } setMethod("head", @@ -716,14 +966,16 @@ setMethod("head", #' #' @param x A SparkSQL DataFrame #' +#' @family DataFrame functions #' @rdname first +#' @name first #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' first(df) #' } setMethod("first", @@ -732,22 +984,21 @@ setMethod("first", take(x, 1) }) -# toRDD() -# -# Converts a Spark DataFrame to an RDD while preserving column names. -# -# @param x A Spark DataFrame -# -# @rdname DataFrame -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# sqlContext <- sparkRSQL.init(sc) -# path <- "path/to/file.json" -# df <- jsonFile(sqlContext, path) -# rdd <- toRDD(df) -# } +#' toRDD +#' +#' Converts a Spark DataFrame to an RDD while preserving column names. +#' +#' @param x A Spark DataFrame +#' +#' @noRd +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- read.json(sqlContext, path) +#' rdd <- toRDD(df) +#'} setMethod("toRDD", signature(x = "DataFrame"), function(x) { @@ -767,8 +1018,9 @@ setMethod("toRDD", #' @param x a DataFrame #' @return a GroupedData #' @seealso GroupedData -#' @aliases group_by +#' @family DataFrame functions #' @rdname groupBy +#' @name groupBy #' @export #' @examples #' \dontrun{ @@ -783,16 +1035,16 @@ setMethod("groupBy", function(x, ...) { cols <- list(...) 
if (length(cols) >= 1 && class(cols[[1]]) == "character") { - sgd <- callJMethod(x@sdf, "groupBy", cols[[1]], listToSeq(cols[-1])) + sgd <- callJMethod(x@sdf, "groupBy", cols[[1]], cols[-1]) } else { jcol <- lapply(cols, function(c) { c@jc }) - sgd <- callJMethod(x@sdf, "groupBy", listToSeq(jcol)) + sgd <- callJMethod(x@sdf, "groupBy", jcol) } groupedData(sgd) }) #' @rdname groupBy -#' @aliases group_by +#' @name group_by setMethod("group_by", signature(x = "DataFrame"), function(x, ...) { @@ -804,8 +1056,9 @@ setMethod("group_by", #' Compute aggregates by specifying a list of columns #' #' @param x a DataFrame -#' @rdname DataFrame -#' @aliases summarize +#' @family DataFrame functions +#' @rdname agg +#' @name agg #' @export setMethod("agg", signature(x = "DataFrame"), @@ -813,8 +1066,8 @@ setMethod("agg", agg(groupBy(x), ...) }) -#' @rdname DataFrame -#' @aliases agg +#' @rdname agg +#' @name summarize setMethod("summarize", signature(x = "DataFrame"), function(x, ...) { @@ -828,7 +1081,8 @@ setMethod("summarize", # the requested map function. # ################################################################################### -# @rdname lapply +#' @rdname lapply +#' @noRd setMethod("lapply", signature(X = "DataFrame", FUN = "function"), function(X, FUN) { @@ -836,14 +1090,16 @@ setMethod("lapply", lapply(rdd, FUN) }) -# @rdname lapply +#' @rdname lapply +#' @noRd setMethod("map", signature(X = "DataFrame", FUN = "function"), function(X, FUN) { lapply(X, FUN) }) -# @rdname flatMap +#' @rdname flatMap +#' @noRd setMethod("flatMap", signature(X = "DataFrame", FUN = "function"), function(X, FUN) { @@ -851,7 +1107,8 @@ setMethod("flatMap", flatMap(rdd, FUN) }) -# @rdname lapplyPartition +#' @rdname lapplyPartition +#' @noRd setMethod("lapplyPartition", signature(X = "DataFrame", FUN = "function"), function(X, FUN) { @@ -859,14 +1116,16 @@ setMethod("lapplyPartition", lapplyPartition(rdd, FUN) }) -# @rdname lapplyPartition +#' @rdname lapplyPartition +#' @noRd setMethod("mapPartitions", signature(X = "DataFrame", FUN = "function"), function(X, FUN) { lapplyPartition(X, FUN) }) -# @rdname foreach +#' @rdname foreach +#' @noRd setMethod("foreach", signature(x = "DataFrame", func = "function"), function(x, func) { @@ -874,7 +1133,8 @@ setMethod("foreach", foreach(rdd, func) }) -# @rdname foreach +#' @rdname foreach +#' @noRd setMethod("foreachPartition", signature(x = "DataFrame", func = "function"), function(x, func) { @@ -890,12 +1150,14 @@ getColumn <- function(x, c) { } #' @rdname select +#' @name $ setMethod("$", signature(x = "DataFrame"), function(x, name) { getColumn(x, name) }) #' @rdname select +#' @name $<- setMethod("$<-", signature(x = "DataFrame"), function(x, name, value) { stopifnot(class(value) == "Column" || is.null(value)) @@ -922,8 +1184,11 @@ setMethod("$<-", signature(x = "DataFrame"), x }) -#' @rdname select -setMethod("[[", signature(x = "DataFrame"), +setClassUnion("numericOrcharacter", c("numeric", "character")) + +#' @rdname subset +#' @name [[ +setMethod("[[", signature(x = "DataFrame", i = "numericOrcharacter"), function(x, i) { if (is.numeric(i)) { cols <- columns(x) @@ -932,7 +1197,8 @@ setMethod("[[", signature(x = "DataFrame"), getColumn(x, i) }) -#' @rdname select +#' @rdname subset +#' @name [ setMethod("[", signature(x = "DataFrame", i = "missing"), function(x, i, j, ...) 
{ if (is.numeric(j)) { @@ -945,6 +1211,56 @@ setMethod("[", signature(x = "DataFrame", i = "missing"), select(x, j) }) +#' @rdname subset +#' @name [ +setMethod("[", signature(x = "DataFrame", i = "Column"), + function(x, i, j, ...) { + # It could handle i as "character" but it seems confusing and not required + # https://stat.ethz.ch/R-manual/R-devel/library/base/html/Extract.data.frame.html + filtered <- filter(x, i) + if (!missing(j)) { + filtered[, j, ...] + } else { + filtered + } + }) + +#' Subset +#' +#' Return subsets of DataFrame according to given conditions +#' @param x A DataFrame +#' @param subset (Optional) A logical expression to filter on rows +#' @param select expression for the single Column or a list of columns to select from the DataFrame +#' @return A new DataFrame containing only the rows that meet the condition with selected columns +#' @export +#' @family DataFrame functions +#' @rdname subset +#' @name subset +#' @family subsetting functions +#' @examples +#' \dontrun{ +#' # Columns can be selected using `[[` and `[` +#' df[[2]] == df[["age"]] +#' df[,2] == df[,"age"] +#' df[,c("name", "age")] +#' # Or to filter rows +#' df[df$age > 20,] +#' # DataFrame can be subset on both rows and Columns +#' df[df$name == "Smith", c(1,2)] +#' df[df$age %in% c(19, 30), 1:2] +#' subset(df, df$age %in% c(19, 30), 1:2) +#' subset(df, df$age %in% c(19), select = c(1,2)) +#' subset(df, select = c(1,2)) +#' } +setMethod("subset", signature(x = "DataFrame"), + function(x, subset, select, ...) { + if (missing(subset)) { + x[, select, ...] + } else { + x[subset, select, ...] + } + }) + #' Select #' #' Selects a set of columns with names or Column expressions. @@ -952,7 +1268,10 @@ setMethod("[", signature(x = "DataFrame", i = "missing"), #' @param col A list of columns or single Column or name #' @return A new DataFrame with selected columns #' @export +#' @family DataFrame functions #' @rdname select +#' @name select +#' @family subsetting functions #' @examples #' \dontrun{ #' select(df, "*") @@ -960,18 +1279,24 @@ setMethod("[", signature(x = "DataFrame", i = "missing"), #' select(df, df$name, df$age + 1) #' select(df, c("col1", "col2")) #' select(df, list(df$name, df$age + 1)) -#' # Columns can also be selected using `[[` and `[` -#' df[[2]] == df[["age"]] -#' df[,2] == df[,"age"] #' # Similar to R data frames columns can also be selected using `$` -#' df$age +#' df[,df$age] #' } setMethod("select", signature(x = "DataFrame", col = "character"), function(x, col, ...) 
{ - sdf <- callJMethod(x@sdf, "select", col, toSeq(...)) - dataFrame(sdf) + if (length(col) > 1) { + if (length(list(...)) > 0) { + stop("To select multiple columns, use a character vector or list for col") + } + + select(x, as.list(col)) + } else { + sdf <- callJMethod(x@sdf, "select", col, list(...)) + dataFrame(sdf) + } }) +#' @family DataFrame functions #' @rdname select #' @export setMethod("select", signature(x = "DataFrame", col = "Column"), @@ -979,10 +1304,11 @@ setMethod("select", signature(x = "DataFrame", col = "Column"), jcols <- lapply(list(col, ...), function(c) { c@jc }) - sdf <- callJMethod(x@sdf, "select", listToSeq(jcols)) + sdf <- callJMethod(x@sdf, "select", jcols) dataFrame(sdf) }) +#' @family DataFrame functions #' @rdname select #' @export setMethod("select", @@ -995,7 +1321,7 @@ setMethod("select", col(c)@jc } }) - sdf <- callJMethod(x@sdf, "select", listToSeq(cols)) + sdf <- callJMethod(x@sdf, "select", cols) dataFrame(sdf) }) @@ -1007,21 +1333,23 @@ setMethod("select", #' @param expr A string containing a SQL expression #' @param ... Additional expressions #' @return A DataFrame +#' @family DataFrame functions #' @rdname selectExpr +#' @name selectExpr #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' selectExpr(df, "col1", "(col2 * 5) as newCol") #' } setMethod("selectExpr", signature(x = "DataFrame", expr = "character"), function(x, expr, ...) { exprList <- list(expr, ...) - sdf <- callJMethod(x@sdf, "selectExpr", listToSeq(exprList)) + sdf <- callJMethod(x@sdf, "selectExpr", exprList) dataFrame(sdf) }) @@ -1033,14 +1361,17 @@ setMethod("selectExpr", #' @param colName A string containing the name of the new column. #' @param col A Column expression. #' @return A DataFrame with the new column added. +#' @family DataFrame functions #' @rdname withColumn +#' @name withColumn +#' @seealso \link{rename} \link{mutate} #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' newDF <- withColumn(df, "newCol", df$col1 * 5) #' } setMethod("withColumn", @@ -1048,29 +1379,32 @@ setMethod("withColumn", function(x, colName, col) { select(x, x$"*", alias(col, colName)) }) - #' Mutate #' #' Return a new DataFrame with the specified columns added. #' -#' @param x A DataFrame +#' @param .data A DataFrame #' @param col a named argument of the form name = col #' @return A new DataFrame with the new columns added. -#' @rdname withColumn -#' @aliases withColumn +#' @family DataFrame functions +#' @rdname mutate +#' @name mutate +#' @seealso \link{rename} \link{withColumn} #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' newDF <- mutate(df, newCol = df$col1 * 5, newCol2 = df$col1 * 2) #' names(newDF) # Will contain newCol, newCol2 +#' newDF2 <- transform(df, newCol = df$col1 / 5, newCol2 = df$col1 * 2) #' } setMethod("mutate", - signature(x = "DataFrame"), - function(x, ...) { + signature(.data = "DataFrame"), + function(.data, ...) { + x <- .data cols <- list(...) 
stopifnot(length(cols) > 0) stopifnot(class(cols[[1]]) == "Column") @@ -1085,7 +1419,16 @@ setMethod("mutate", do.call(select, c(x, x$"*", cols)) }) -#' WithColumnRenamed +#' @export +#' @rdname mutate +#' @name transform +setMethod("transform", + signature(`_data` = "DataFrame"), + function(`_data`, ...) { + mutate(`_data`, ...) + }) + +#' rename #' #' Rename an existing column in a DataFrame. #' @@ -1093,14 +1436,17 @@ setMethod("mutate", #' @param existingCol The name of the column you want to change. #' @param newCol The new column name. #' @return A DataFrame with the column name changed. -#' @rdname withColumnRenamed +#' @family DataFrame functions +#' @rdname rename +#' @name withColumnRenamed +#' @seealso \link{mutate} #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' newDF <- withColumnRenamed(df, "col1", "newCol1") #' } setMethod("withColumnRenamed", @@ -1116,22 +1462,16 @@ setMethod("withColumnRenamed", select(x, cols) }) -#' Rename -#' -#' Rename an existing column in a DataFrame. -#' -#' @param x A DataFrame -#' @param newCol A named pair of the form new_column_name = existing_column -#' @return A DataFrame with the column name changed. -#' @rdname withColumnRenamed -#' @aliases withColumnRenamed +#' @param newColPair A named pair of the form new_column_name = existing_column +#' @rdname rename +#' @name rename #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' newDF <- rename(df, col1 = df$newCol1) #' } setMethod("rename", @@ -1161,37 +1501,72 @@ setClassUnion("characterOrColumn", c("character", "Column")) #' Sort a DataFrame by the specified column(s). #' #' @param x A DataFrame to be sorted. -#' @param col Either a Column object or character vector indicating the field to sort on +#' @param col A character or Column object vector indicating the fields to sort on #' @param ... Additional sorting fields +#' @param decreasing A logical argument indicating sorting order for columns when +#' a character vector is specified for col #' @return A DataFrame where all elements are sorted. +#' @family DataFrame functions #' @rdname arrange +#' @name arrange #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' arrange(df, df$col1) -#' arrange(df, "col1") #' arrange(df, asc(df$col1), desc(abs(df$col2))) +#' arrange(df, "col1", decreasing = TRUE) +#' arrange(df, "col1", "col2", decreasing = c(TRUE, FALSE)) #' } setMethod("arrange", - signature(x = "DataFrame", col = "characterOrColumn"), + signature(x = "DataFrame", col = "Column"), function(x, col, ...) { - if (class(col) == "character") { - sdf <- callJMethod(x@sdf, "sort", col, toSeq(...)) - } else if (class(col) == "Column") { jcols <- lapply(list(col, ...), function(c) { c@jc }) - sdf <- callJMethod(x@sdf, "sort", listToSeq(jcols)) - } + + sdf <- callJMethod(x@sdf, "sort", jcols) dataFrame(sdf) }) #' @rdname arrange -#' @aliases orderBy,DataFrame,function-method +#' @name arrange +#' @export +setMethod("arrange", + signature(x = "DataFrame", col = "character"), + function(x, col, ..., decreasing = FALSE) { + + # all sorting columns + by <- list(col, ...) 
+ + if (length(decreasing) == 1) { + # in case only 1 boolean argument - decreasing value is specified, + # it will be used for all columns + decreasing <- rep(decreasing, length(by)) + } else if (length(decreasing) != length(by)) { + stop("Arguments 'col' and 'decreasing' must have the same length") + } + + # builds a list of columns of type Column + # example: [[1]] Column Species ASC + # [[2]] Column Petal_Length DESC + jcols <- lapply(seq_len(length(decreasing)), function(i){ + if (decreasing[[i]]) { + desc(getColumn(x, by[[i]])) + } else { + asc(getColumn(x, by[[i]])) + } + }) + + do.call("arrange", c(x, jcols)) + }) + +#' @rdname arrange +#' @name orderBy +#' @export setMethod("orderBy", signature(x = "DataFrame", col = "characterOrColumn"), function(x, col) { @@ -1206,14 +1581,17 @@ setMethod("orderBy", #' @param condition The condition to filter on. This may either be a Column expression #' or a string containing a SQL statement #' @return A DataFrame containing only the rows that meet the condition. +#' @family DataFrame functions #' @rdname filter +#' @name filter +#' @family subsetting functions #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' filter(df, "col1 > 0") #' filter(df, df$col2 != "abcdefg") #' } @@ -1227,8 +1605,9 @@ setMethod("filter", dataFrame(sdf) }) +#' @family DataFrame functions #' @rdname filter -#' @aliases where,DataFrame,function-method +#' @name where setMethod("where", signature(x = "DataFrame", condition = "characterOrColumn"), function(x, condition) { @@ -1242,18 +1621,22 @@ setMethod("where", #' @param x A Spark DataFrame #' @param y A Spark DataFrame #' @param joinExpr (Optional) The expression used to perform the join. joinExpr must be a -#' Column expression. If joinExpr is omitted, join() wil perform a Cartesian join +#' Column expression. If joinExpr is omitted, join() will perform a Cartesian join #' @param joinType The type of join to perform. The following join types are available: -#' 'inner', 'outer', 'left_outer', 'right_outer', 'semijoin'. The default joinType is "inner". +#' 'inner', 'outer', 'full', 'fullouter', leftouter', 'left_outer', 'left', +#' 'right_outer', 'rightouter', 'right', and 'leftsemi'. The default joinType is "inner". #' @return A DataFrame containing the result of the join operation. 
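A brief sketch of the reworked sorting and filtering methods above, assuming a DataFrame df with hypothetical columns age and name:

# Sort on character column names with per-column direction (new behaviour).
arrange(df, "age", "name", decreasing = c(TRUE, FALSE))
# Equivalent row filtering through filter(), where() and `[`.
filter(df, df$age > 20)
where(df, "age > 20")
df[df$age > 20, c("name", "age")]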
+#' @family DataFrame functions #' @rdname join +#' @name join +#' @seealso \link{merge} #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) -#' df1 <- jsonFile(sqlContext, path) -#' df2 <- jsonFile(sqlContext, path2) +#' df1 <- read.json(sqlContext, path) +#' df2 <- read.json(sqlContext, path2) #' join(df1, df2) # Performs a Cartesian #' join(df1, df2, df1$col1 == df2$col2) # Performs an inner join based on expression #' join(df1, df2, df1$col1 == df2$col2, "right_outer") @@ -1268,27 +1651,164 @@ setMethod("join", if (is.null(joinType)) { sdf <- callJMethod(x@sdf, "join", y@sdf, joinExpr@jc) } else { - if (joinType %in% c("inner", "outer", "left_outer", "right_outer", "semijoin")) { + if (joinType %in% c("inner", "outer", "full", "fullouter", + "leftouter", "left_outer", "left", + "rightouter", "right_outer", "right", "leftsemi")) { + joinType <- gsub("_", "", joinType) sdf <- callJMethod(x@sdf, "join", y@sdf, joinExpr@jc, joinType) } else { stop("joinType must be one of the following types: ", - "'inner', 'outer', 'left_outer', 'right_outer', 'semijoin'") + "'inner', 'outer', 'full', 'fullouter', 'leftouter', 'left_outer', 'left', + 'rightouter', 'right_outer', 'right', 'leftsemi'") } } } dataFrame(sdf) }) -#' rdname merge -#' aliases join +#' @name merge +#' @title Merges two data frames +#' @param x the first data frame to be joined +#' @param y the second data frame to be joined +#' @param by a character vector specifying the join columns. If by is not +#' specified, the common column names in \code{x} and \code{y} will be used. +#' @param by.x a character vector specifying the joining columns for x. +#' @param by.y a character vector specifying the joining columns for y. +#' @param all.x a boolean value indicating whether all the rows in x should +#' be including in the join +#' @param all.y a boolean value indicating whether all the rows in y should +#' be including in the join +#' @param sort a logical argument indicating whether the resulting columns should be sorted +#' @details If all.x and all.y are set to FALSE, a natural join will be returned. If +#' all.x is set to TRUE and all.y is set to FALSE, a left outer join will +#' be returned. If all.x is set to FALSE and all.y is set to TRUE, a right +#' outer join will be returned. If all.x and all.y are set to TRUE, a full +#' outer join will be returned. +#' @family DataFrame functions +#' @rdname merge +#' @seealso \link{join} +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' df1 <- read.json(sqlContext, path) +#' df2 <- read.json(sqlContext, path2) +#' merge(df1, df2) # Performs a Cartesian +#' merge(df1, df2, by = "col1") # Performs an inner join based on expression +#' merge(df1, df2, by.x = "col1", by.y = "col2", all.y = TRUE) +#' merge(df1, df2, by.x = "col1", by.y = "col2", all.x = TRUE) +#' merge(df1, df2, by.x = "col1", by.y = "col2", all.x = TRUE, all.y = TRUE) +#' merge(df1, df2, by.x = "col1", by.y = "col2", all = TRUE, sort = FALSE) +#' merge(df1, df2, by = "col1", all = TRUE, suffixes = c("-X", "-Y")) +#' } setMethod("merge", signature(x = "DataFrame", y = "DataFrame"), - function(x, y, joinExpr = NULL, joinType = NULL, ...) { - join(x, y, joinExpr, joinType) - }) + function(x, y, by = intersect(names(x), names(y)), by.x = by, by.y = by, + all = FALSE, all.x = all, all.y = all, + sort = TRUE, suffixes = c("_x","_y"), ... 
) { + + if (length(suffixes) != 2) { + stop("suffixes must have length 2") + } + # join type is identified based on the values of all, all.x and all.y + # default join type is inner, according to R it should be natural but since it + # is not supported in spark inner join is used + joinType <- "inner" + if (all || (all.x && all.y)) { + joinType <- "outer" + } else if (all.x) { + joinType <- "left_outer" + } else if (all.y) { + joinType <- "right_outer" + } + + # join expression is based on by.x, by.y if both by.x and by.y are not missing + # or on by, if by.x or by.y are missing or have different lengths + if (length(by.x) > 0 && length(by.x) == length(by.y)) { + joinX <- by.x + joinY <- by.y + } else if (length(by) > 0) { + # if join columns have the same name for both dataframes, + # they are used in join expression + joinX <- by + joinY <- by + } else { + # if by or both by.x and by.y have length 0, use Cartesian Product + joinRes <- join(x, y) + return (joinRes) + } -#' UnionAll + # sets alias for making colnames unique in dataframes 'x' and 'y' + colsX <- generateAliasesForIntersectedCols(x, by, suffixes[1]) + colsY <- generateAliasesForIntersectedCols(y, by, suffixes[2]) + + # selects columns with their aliases from dataframes + # in case same column names are present in both data frames + xsel <- select(x, colsX) + ysel <- select(y, colsY) + + # generates join conditions and adds them into a list + # it also considers alias names of the columns while generating join conditions + joinColumns <- lapply(seq_len(length(joinX)), function(i) { + colX <- joinX[[i]] + colY <- joinY[[i]] + + if (colX %in% by) { + colX <- paste(colX, suffixes[1], sep = "") + } + if (colY %in% by) { + colY <- paste(colY, suffixes[2], sep = "") + } + + colX <- getColumn(xsel, colX) + colY <- getColumn(ysel, colY) + + colX == colY + }) + + # concatenates join columns with '&' and executes join + joinExpr <- Reduce("&", joinColumns) + joinRes <- join(xsel, ysel, joinExpr, joinType) + + # sorts the result by 'by' columns if sort = TRUE + if (sort && length(by) > 0) { + colNameWithSuffix <- paste(by, suffixes[2], sep = "") + joinRes <- do.call("arrange", c(joinRes, colNameWithSuffix, decreasing = FALSE)) + } + + joinRes + }) + +#' +#' Creates a list of columns by replacing the intersected ones with aliases. +#' The name of the alias column is formed by concatanating the original column name and a suffix. +#' +#' @param x a DataFrame on which the +#' @param intersectedColNames a list of intersected column names +#' @param suffix a suffix for the column name +#' @return list of columns +#' +generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) { + allColNames <- names(x) + # sets alias for making colnames unique in dataframe 'x' + cols <- lapply(allColNames, function(colName) { + col <- getColumn(x, colName) + if (colName %in% intersectedColNames) { + newJoin <- paste(colName, suffix, sep = "") + if (newJoin %in% allColNames){ + stop ("The following column name: ", newJoin, " occurs more than once in the 'DataFrame'.", + "Please use different suffixes for the intersected columns.") + } + col <- alias(col, newJoin) + } + col + }) + cols +} + +#' rbind #' #' Return a new DataFrame containing the union of rows in this DataFrame #' and another DataFrame. This is equivalent to `UNION ALL` in SQL. @@ -1297,14 +1817,16 @@ setMethod("merge", #' @param x A Spark DataFrame #' @param y A Spark DataFrame #' @return A DataFrame containing the result of the union. 
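A sketch of the merge() semantics implemented above, assuming two DataFrames df1 and df2 that share a hypothetical column "name":

# Left outer join on the shared column; intersected columns are suffixed.
merged <- merge(df1, df2, by = "name", all.x = TRUE)
columns(merged)  # contains "name_x" and "name_y" (the default suffixes)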
-#' @rdname unionAll +#' @family DataFrame functions +#' @rdname rbind +#' @name unionAll #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) -#' df1 <- jsonFile(sqlContext, path) -#' df2 <- jsonFile(sqlContext, path2) +#' df1 <- read.json(sqlContext, path) +#' df2 <- read.json(sqlContext, path2) #' unioned <- unionAll(df, df2) #' } setMethod("unionAll", @@ -1315,11 +1837,11 @@ setMethod("unionAll", }) #' @title Union two or more DataFrames -# #' @description Returns a new DataFrame containing rows of all parameters. -# +#' #' @rdname rbind -#' @aliases unionAll +#' @name rbind +#' @export setMethod("rbind", signature(... = "DataFrame"), function(x, ..., deparse.level = 1) { @@ -1338,14 +1860,16 @@ setMethod("rbind", #' @param x A Spark DataFrame #' @param y A Spark DataFrame #' @return A DataFrame containing the result of the intersect. +#' @family DataFrame functions #' @rdname intersect +#' @name intersect #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) -#' df1 <- jsonFile(sqlContext, path) -#' df2 <- jsonFile(sqlContext, path2) +#' df1 <- read.json(sqlContext, path) +#' df2 <- read.json(sqlContext, path2) #' intersectDF <- intersect(df, df2) #' } setMethod("intersect", @@ -1363,14 +1887,16 @@ setMethod("intersect", #' @param x A Spark DataFrame #' @param y A Spark DataFrame #' @return A DataFrame containing the result of the except operation. +#' @family DataFrame functions #' @rdname except +#' @name except #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) -#' df1 <- jsonFile(sqlContext, path) -#' df2 <- jsonFile(sqlContext, path2) +#' df1 <- read.json(sqlContext, path) +#' df2 <- read.json(sqlContext, path2) #' exceptDF <- except(df, df2) #' } #' @rdname except @@ -1389,32 +1915,34 @@ setMethod("except", #' spark.sql.sources.default will be used. #' #' Additionally, mode is used to specify the behavior of the save operation when -#' data already exists in the data source. There are four modes: -#' append: Contents of this DataFrame are expected to be appended to existing data. -#' overwrite: Existing data is expected to be overwritten by the contents of -# this DataFrame. -#' error: An exception is expected to be thrown. +#' data already exists in the data source. There are four modes: \cr +#' append: Contents of this DataFrame are expected to be appended to existing data. \cr +#' overwrite: Existing data is expected to be overwritten by the contents of this DataFrame. \cr +#' error: An exception is expected to be thrown. \cr #' ignore: The save operation is expected to not save the contents of the DataFrame -# and to not change the existing data. +#' and to not change the existing data. 
\cr #' #' @param df A SparkSQL DataFrame #' @param path A name for the table #' @param source A name for external data source -#' @param mode One of 'append', 'overwrite', 'error', 'ignore' +#' @param mode One of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default) #' +#' @family DataFrame functions #' @rdname write.df +#' @name write.df #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' write.df(df, "myfile", "parquet", "overwrite") +#' saveDF(df, parquetPath2, "parquet", mode = saveMode, mergeSchema = mergeSchema) #' } setMethod("write.df", signature(df = "DataFrame", path = "character"), - function(df, path, source = NULL, mode = "append", ...){ + function(df, path, source = NULL, mode = "error", ...){ if (is.null(source)) { sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", @@ -1435,11 +1963,11 @@ setMethod("write.df", }) #' @rdname write.df -#' @aliases saveDF +#' @name saveDF #' @export setMethod("saveDF", signature(df = "DataFrame", path = "character"), - function(df, path, source = NULL, mode = "append", ...){ + function(df, path, source = NULL, mode = "error", ...){ write.df(df, path, source, mode, ...) }) @@ -1452,33 +1980,34 @@ setMethod("saveDF", #' spark.sql.sources.default will be used. #' #' Additionally, mode is used to specify the behavior of the save operation when -#' data already exists in the data source. There are four modes: -#' append: Contents of this DataFrame are expected to be appended to existing data. -#' overwrite: Existing data is expected to be overwritten by the contents of -# this DataFrame. -#' error: An exception is expected to be thrown. +#' data already exists in the data source. There are four modes: \cr +#' append: Contents of this DataFrame are expected to be appended to existing data. \cr +#' overwrite: Existing data is expected to be overwritten by the contents of this DataFrame. \cr +#' error: An exception is expected to be thrown. \cr #' ignore: The save operation is expected to not save the contents of the DataFrame -# and to not change the existing data. +#' and to not change the existing data. \cr #' #' @param df A SparkSQL DataFrame #' @param tableName A name for the table #' @param source A name for external data source -#' @param mode One of 'append', 'overwrite', 'error', 'ignore' +#' @param mode One of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default) #' +#' @family DataFrame functions #' @rdname saveAsTable +#' @name saveAsTable #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' saveAsTable(df, "myfile") #' } setMethod("saveAsTable", signature(df = "DataFrame", tableName = "character", source = "character", mode = "character"), - function(df, tableName, source = NULL, mode="append", ...){ + function(df, tableName, source = NULL, mode="error", ...){ if (is.null(source)) { sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", @@ -1495,7 +2024,7 @@ setMethod("saveAsTable", callJMethod(df@sdf, "saveAsTable", tableName, source, jmode, options) }) -#' describe +#' summary #' #' Computes statistics for numeric columns. 
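A sketch illustrating the changed default save mode of write.df/saveDF above, assuming df and a hypothetical output path:

# With the default mode now "error", writing to an existing path fails
# unless a mode is passed explicitly.
outPath <- "/tmp/people_parquet"  # hypothetical path
write.df(df, outPath, source = "parquet", mode = "overwrite")
saveDF(df, outPath, source = "parquet", mode = "append")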
#' If no columns are given, this function computes statistics for all numerical columns. @@ -1504,14 +2033,16 @@ setMethod("saveAsTable", #' @param col A string of name #' @param ... Additional expressions #' @return A DataFrame -#' @rdname describe +#' @family DataFrame functions +#' @rdname summary +#' @name describe #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' describe(df) #' describe(df, "col1") #' describe(df, "col1", "col2") @@ -1520,29 +2051,26 @@ setMethod("describe", signature(x = "DataFrame", col = "character"), function(x, col, ...) { colList <- list(col, ...) - sdf <- callJMethod(x@sdf, "describe", listToSeq(colList)) + sdf <- callJMethod(x@sdf, "describe", colList) dataFrame(sdf) }) -#' @rdname describe +#' @rdname summary +#' @name describe setMethod("describe", signature(x = "DataFrame"), function(x) { colList <- as.list(c(columns(x))) - sdf <- callJMethod(x@sdf, "describe", listToSeq(colList)) + sdf <- callJMethod(x@sdf, "describe", colList) dataFrame(sdf) }) -#' @title Summary -#' -#' @description Computes statistics for numeric columns of the DataFrame -#' #' @rdname summary -#' @aliases describe +#' @name summary setMethod("summary", - signature(x = "DataFrame"), - function(x) { - describe(x) + signature(object = "DataFrame"), + function(object, ...) { + describe(object) }) @@ -1561,14 +2089,16 @@ setMethod("summary", #' @param cols Optional list of column names to consider. #' @return A DataFrame #' +#' @family DataFrame functions #' @rdname nafunctions +#' @name dropna #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlCtx <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- read.json(sqlCtx, path) #' dropna(df) #' } setMethod("dropna", @@ -1584,16 +2114,17 @@ setMethod("dropna", naFunctions <- callJMethod(x@sdf, "na") sdf <- callJMethod(naFunctions, "drop", - as.integer(minNonNulls), listToSeq(as.list(cols))) + as.integer(minNonNulls), as.list(cols)) dataFrame(sdf) }) -#' @aliases dropna +#' @rdname nafunctions +#' @name na.omit #' @export setMethod("na.omit", - signature(x = "DataFrame"), - function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { - dropna(x, how, minNonNulls, cols) + signature(object = "DataFrame"), + function(object, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { + dropna(object, how, minNonNulls, cols) }) #' fillna @@ -1612,16 +2143,16 @@ setMethod("na.omit", #' type are ignored. For example, if value is a character, and #' subset contains a non-character column, then the non-character #' column is simply ignored. 
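A sketch of the NA-handling helpers documented above, assuming df has nullable age and name columns (hypothetical):

# Drop rows containing any NA, then fill remaining NAs per column.
cleaned <- dropna(df, how = "any")
filled  <- fillna(df, list("age" = 20, "name" = "unknown"))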
-#' @return A DataFrame #' #' @rdname nafunctions +#' @name fillna #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlCtx <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- read.json(sqlCtx, path) #' fillna(df, 1) #' fillna(df, list("age" = 20, "name" = "unknown")) #' } @@ -1638,17 +2169,15 @@ setMethod("fillna", if (length(colNames) == 0 || !all(colNames != "")) { stop("value should be an a named list with each name being a column name.") } - - # Convert to the named list to an environment to be passed to JVM - valueMap <- new.env() - for (col in colNames) { - # Check each item in the named list is of valid type - v <- value[[col]] + # Check each item in the named list is of valid type + lapply(value, function(v) { if (!(class(v) %in% c("integer", "numeric", "character"))) { stop("Each item in value should be an integer, numeric or charactor.") } - valueMap[[col]] <- v - } + }) + + # Convert to the named list to an environment to be passed to JVM + valueMap <- convertNamedListToEnv(value) # When value is a named list, caller is expected not to pass in cols if (!is.null(cols)) { @@ -1666,35 +2195,80 @@ setMethod("fillna", sdf <- if (length(cols) == 0) { callJMethod(naFunctions, "fill", value) } else { - callJMethod(naFunctions, "fill", value, listToSeq(as.list(cols))) + callJMethod(naFunctions, "fill", value, as.list(cols)) } dataFrame(sdf) }) -#' crosstab -#' -#' Computes a pair-wise frequency table of the given columns. Also known as a contingency -#' table. The number of distinct values for each column should be less than 1e4. At most 1e6 -#' non-zero pair frequencies will be returned. +#' This function downloads the contents of a DataFrame into an R's data.frame. +#' Since data.frames are held in memory, ensure that you have enough memory +#' in your system to accommodate the contents. #' -#' @param col1 name of the first column. Distinct items will make the first item of each row. -#' @param col2 name of the second column. Distinct items will make the column names of the output. -#' @return a local R data.frame representing the contingency table. The first column of each row -#' will be the distinct values of `col1` and the column names will be the distinct values -#' of `col2`. The name of the first column will be `$col1_$col2`. Pairs that have no -#' occurrences will have zero as their counts. +#' @title Download data from a DataFrame into a data.frame +#' @param x a DataFrame +#' @return a data.frame +#' @family DataFrame functions +#' @rdname as.data.frame +#' @examples \dontrun{ #' -#' @rdname statfunctions -#' @export +#' irisDF <- createDataFrame(sqlContext, iris) +#' df <- as.data.frame(irisDF[irisDF$Species == "setosa", ]) +#' } +setMethod("as.data.frame", + signature(x = "DataFrame"), + function(x, ...) { + # Check if additional parameters have been passed + if (length(list(...)) > 0) { + stop(paste("Unused argument(s): ", paste(list(...), collapse=", "))) + } + collect(x) + }) + +#' The specified DataFrame is attached to the R search path. This means that +#' the DataFrame is searched by R when evaluating a variable, so columns in +#' the DataFrame can be accessed by simply giving their names. +#' +#' @family DataFrame functions +#' @rdname attach +#' @title Attach DataFrame to R search path +#' @param what (DataFrame) The DataFrame to attach +#' @param pos (integer) Specify position in search() where to attach. +#' @param name (character) Name to use for the attached DataFrame. 
Names +#' starting with package: are reserved for library. +#' @param warn.conflicts (logical) If TRUE, warnings are printed about conflicts +#' from attaching the database, unless that DataFrame contains an object +#' @examples +#' \dontrun{ +#' attach(irisDf) +#' summary(Sepal_Width) +#' } +#' @seealso \link{detach} +setMethod("attach", + signature(what = "DataFrame"), + function(what, pos = 2, name = deparse(substitute(what)), warn.conflicts = TRUE) { + newEnv <- assignNewEnv(what) + attach(newEnv, pos = pos, name = name, warn.conflicts = warn.conflicts) + }) + +#' Evaluate a R expression in an environment constructed from a DataFrame +#' with() allows access to columns of a DataFrame by simply referring to +#' their name. It appends every column of a DataFrame into a new +#' environment. Then, the given expression is evaluated in this new +#' environment. +#' +#' @rdname with +#' @title Evaluate a R expression in an environment constructed from a DataFrame +#' @param data (DataFrame) DataFrame to use for constructing an environment. +#' @param expr (expression) Expression to evaluate. +#' @param ... arguments to be passed to future methods. #' @examples #' \dontrun{ -#' df <- jsonFile(sqlCtx, "/path/to/file.json") -#' ct = crosstab(df, "title", "gender") +#' with(irisDf, nrow(Sepal_Width)) #' } -setMethod("crosstab", - signature(x = "DataFrame", col1 = "character", col2 = "character"), - function(x, col1, col2) { - statFunctions <- callJMethod(x@sdf, "stat") - sct <- callJMethod(statFunctions, "crosstab", col1, col2) - collect(dataFrame(sct)) +#' @seealso \link{attach} +setMethod("with", + signature(data = "DataFrame"), + function(data, expr, ...) { + newEnv <- assignNewEnv(data) + eval(substitute(expr), envir = newEnv, enclos = newEnv) }) diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index 051e441d4e06..00c40c38cabc 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -19,16 +19,15 @@ setOldClass("jobj") -# @title S4 class that represents an RDD -# @description RDD can be created using functions like -# \code{parallelize}, \code{textFile} etc. -# @rdname RDD -# @seealso parallelize, textFile -# -# @slot env An R environment that stores bookkeeping states of the RDD -# @slot jrdd Java object reference to the backing JavaRDD -# to an RDD -# @export +#' @title S4 class that represents an RDD +#' @description RDD can be created using functions like +#' \code{parallelize}, \code{textFile} etc. 
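A sketch of the new attach() and with() methods, reusing the irisDF example from the documentation above; whether bare column names resolve depends on the environment set up by assignNewEnv:

irisDF <- createDataFrame(sqlContext, iris)
# Columns become visible by name once the DataFrame is attached.
attach(irisDF)
head(select(irisDF, Sepal_Width))
detach("irisDF")
# with() offers the same convenience without modifying the search path.
with(irisDF, count(filter(irisDF, Sepal_Width > 3)))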
+#' @rdname RDD +#' @seealso parallelize, textFile +#' @slot env An R environment that stores bookkeeping states of the RDD +#' @slot jrdd Java object reference to the backing JavaRDD +#' to an RDD +#' @noRd setClass("RDD", slots = list(env = "environment", jrdd = "jobj")) @@ -111,14 +110,13 @@ setMethod("initialize", "PipelinedRDD", function(.Object, prev, func, jrdd_val) .Object }) -# @rdname RDD -# @export -# -# @param jrdd Java object reference to the backing JavaRDD -# @param serializedMode Use "byte" if the RDD stores data serialized in R, "string" if the RDD -# stores strings, and "row" if the RDD stores the rows of a DataFrame -# @param isCached TRUE if the RDD is cached -# @param isCheckpointed TRUE if the RDD has been checkpointed +#' @rdname RDD +#' @noRd +#' @param jrdd Java object reference to the backing JavaRDD +#' @param serializedMode Use "byte" if the RDD stores data serialized in R, "string" if the RDD +#' stores strings, and "row" if the RDD stores the rows of a DataFrame +#' @param isCached TRUE if the RDD is cached +#' @param isCheckpointed TRUE if the RDD has been checkpointed RDD <- function(jrdd, serializedMode = "byte", isCached = FALSE, isCheckpointed = FALSE) { new("RDD", jrdd, serializedMode, isCached, isCheckpointed) @@ -201,19 +199,20 @@ setValidity("RDD", ############ Actions and Transformations ############ -# Persist an RDD -# -# Persist this RDD with the default storage level (MEMORY_ONLY). -# -# @param x The RDD to cache -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10, 2L) -# cache(rdd) -#} -# @rdname cache-methods -# @aliases cache,RDD-method +#' Persist an RDD +#' +#' Persist this RDD with the default storage level (MEMORY_ONLY). +#' +#' @param x The RDD to cache +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' cache(rdd) +#'} +#' @rdname cache-methods +#' @aliases cache,RDD-method +#' @noRd setMethod("cache", signature(x = "RDD"), function(x) { @@ -222,22 +221,23 @@ setMethod("cache", x }) -# Persist an RDD -# -# Persist this RDD with the specified storage level. For details of the -# supported storage levels, refer to -# http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence. -# -# @param x The RDD to persist -# @param newLevel The new storage level to be assigned -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10, 2L) -# persist(rdd, "MEMORY_AND_DISK") -#} -# @rdname persist -# @aliases persist,RDD-method +#' Persist an RDD +#' +#' Persist this RDD with the specified storage level. For details of the +#' supported storage levels, refer to +#' http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence. +#' +#' @param x The RDD to persist +#' @param newLevel The new storage level to be assigned +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' persist(rdd, "MEMORY_AND_DISK") +#'} +#' @rdname persist +#' @aliases persist,RDD-method +#' @noRd setMethod("persist", signature(x = "RDD", newLevel = "character"), function(x, newLevel = "MEMORY_ONLY") { @@ -246,21 +246,22 @@ setMethod("persist", x }) -# Unpersist an RDD -# -# Mark the RDD as non-persistent, and remove all blocks for it from memory and -# disk. 
-# -# @param x The RDD to unpersist -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10, 2L) -# cache(rdd) # rdd@@env$isCached == TRUE -# unpersist(rdd) # rdd@@env$isCached == FALSE -#} -# @rdname unpersist-methods -# @aliases unpersist,RDD-method +#' Unpersist an RDD +#' +#' Mark the RDD as non-persistent, and remove all blocks for it from memory and +#' disk. +#' +#' @param x The RDD to unpersist +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' cache(rdd) # rdd@@env$isCached == TRUE +#' unpersist(rdd) # rdd@@env$isCached == FALSE +#'} +#' @rdname unpersist-methods +#' @aliases unpersist,RDD-method +#' @noRd setMethod("unpersist", signature(x = "RDD"), function(x) { @@ -269,24 +270,25 @@ setMethod("unpersist", x }) -# Checkpoint an RDD -# -# Mark this RDD for checkpointing. It will be saved to a file inside the -# checkpoint directory set with setCheckpointDir() and all references to its -# parent RDDs will be removed. This function must be called before any job has -# been executed on this RDD. It is strongly recommended that this RDD is -# persisted in memory, otherwise saving it on a file will require recomputation. -# -# @param x The RDD to checkpoint -# @examples -#\dontrun{ -# sc <- sparkR.init() -# setCheckpointDir(sc, "checkpoint") -# rdd <- parallelize(sc, 1:10, 2L) -# checkpoint(rdd) -#} -# @rdname checkpoint-methods -# @aliases checkpoint,RDD-method +#' Checkpoint an RDD +#' +#' Mark this RDD for checkpointing. It will be saved to a file inside the +#' checkpoint directory set with setCheckpointDir() and all references to its +#' parent RDDs will be removed. This function must be called before any job has +#' been executed on this RDD. It is strongly recommended that this RDD is +#' persisted in memory, otherwise saving it on a file will require recomputation. +#' +#' @param x The RDD to checkpoint +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' setCheckpointDir(sc, "checkpoint") +#' rdd <- parallelize(sc, 1:10, 2L) +#' checkpoint(rdd) +#'} +#' @rdname checkpoint-methods +#' @aliases checkpoint,RDD-method +#' @noRd setMethod("checkpoint", signature(x = "RDD"), function(x) { @@ -296,44 +298,57 @@ setMethod("checkpoint", x }) -# Gets the number of partitions of an RDD -# -# @param x A RDD. -# @return the number of partitions of rdd as an integer. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10, 2L) -# numPartitions(rdd) # 2L -#} -# @rdname numPartitions -# @aliases numPartitions,RDD-method +#' Gets the number of partitions of an RDD +#' +#' @param x A RDD. +#' @return the number of partitions of rdd as an integer. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' getNumPartitions(rdd) # 2L +#'} +#' @rdname getNumPartitions +#' @aliases getNumPartitions,RDD-method +#' @noRd +setMethod("getNumPartitions", + signature(x = "RDD"), + function(x) { + callJMethod(getJRDD(x), "getNumPartitions") + }) + +#' Gets the number of partitions of an RDD, the same as getNumPartitions. +#' But this function has been deprecated, please use getNumPartitions. 
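A sketch of the renamed partition-count API on the RDD class, assuming an initialized SparkContext sc; the RDD API is internal (noRd), so a user session may need the SparkR::: prefix:

rdd <- parallelize(sc, 1:10, 2L)
getNumPartitions(rdd)  # 2L
numPartitions(rdd)     # still 2L, but now emits a deprecation warning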
+#' +#' @rdname getNumPartitions +#' @aliases numPartitions,RDD-method +#' @noRd setMethod("numPartitions", signature(x = "RDD"), function(x) { - jrdd <- getJRDD(x) - partitions <- callJMethod(jrdd, "partitions") - callJMethod(partitions, "size") + .Deprecated("getNumPartitions") + getNumPartitions(x) }) -# Collect elements of an RDD -# -# @description -# \code{collect} returns a list that contains all of the elements in this RDD. -# -# @param x The RDD to collect -# @param ... Other optional arguments to collect -# @param flatten FALSE if the list should not flattened -# @return a list containing elements in the RDD -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10, 2L) -# collect(rdd) # list from 1 to 10 -# collectPartition(rdd, 0L) # list from 1 to 5 -#} -# @rdname collect-methods -# @aliases collect,RDD-method +#' Collect elements of an RDD +#' +#' @description +#' \code{collect} returns a list that contains all of the elements in this RDD. +#' +#' @param x The RDD to collect +#' @param ... Other optional arguments to collect +#' @param flatten FALSE if the list should not flattened +#' @return a list containing elements in the RDD +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' collect(rdd) # list from 1 to 10 +#' collectPartition(rdd, 0L) # list from 1 to 5 +#'} +#' @rdname collect-methods +#' @aliases collect,RDD-method +#' @noRd setMethod("collect", signature(x = "RDD"), function(x, flatten = TRUE) { @@ -344,12 +359,13 @@ setMethod("collect", }) -# @description -# \code{collectPartition} returns a list that contains all of the elements -# in the specified partition of the RDD. -# @param partitionId the partition to collect (starts from 0) -# @rdname collect-methods -# @aliases collectPartition,integer,RDD-method +#' @description +#' \code{collectPartition} returns a list that contains all of the elements +#' in the specified partition of the RDD. +#' @param partitionId the partition to collect (starts from 0) +#' @rdname collect-methods +#' @aliases collectPartition,integer,RDD-method +#' @noRd setMethod("collectPartition", signature(x = "RDD", partitionId = "integer"), function(x, partitionId) { @@ -362,17 +378,18 @@ setMethod("collectPartition", serializedMode = getSerializedMode(x)) }) -# @description -# \code{collectAsMap} returns a named list as a map that contains all of the elements -# in a key-value pair RDD. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(list(1, 2), list(3, 4)), 2L) -# collectAsMap(rdd) # list(`1` = 2, `3` = 4) -#} -# @rdname collect-methods -# @aliases collectAsMap,RDD-method +#' @description +#' \code{collectAsMap} returns a named list as a map that contains all of the elements +#' in a key-value pair RDD. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 2), list(3, 4)), 2L) +#' collectAsMap(rdd) # list(`1` = 2, `3` = 4) +#'} +#' @rdname collect-methods +#' @aliases collectAsMap,RDD-method +#' @noRd setMethod("collectAsMap", signature(x = "RDD"), function(x) { @@ -382,19 +399,20 @@ setMethod("collectAsMap", as.list(map) }) -# Return the number of elements in the RDD. -# -# @param x The RDD to count -# @return number of elements in the RDD. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# count(rdd) # 10 -# length(rdd) # Same as count -#} -# @rdname count -# @aliases count,RDD-method +#' Return the number of elements in the RDD. 
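# Illustrative sketch, not part of the patch: expected behaviour of the new
# getNumPartitions() generic and the deprecated numPartitions() shim above.
# Assumes a local SparkR session, mirroring the package's own roxygen examples.
sc <- sparkR.init(master = "local")
rdd <- parallelize(sc, 1:10, 2L)
getNumPartitions(rdd)   # 2, now read via the JavaRDD's getNumPartitions()
numPartitions(rdd)      # still 2, but .Deprecated() first warns to use getNumPartitions()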
+#' +#' @param x The RDD to count +#' @return number of elements in the RDD. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' count(rdd) # 10 +#' length(rdd) # Same as count +#'} +#' @rdname count +#' @aliases count,RDD-method +#' @noRd setMethod("count", signature(x = "RDD"), function(x) { @@ -406,55 +424,57 @@ setMethod("count", sum(as.integer(vals)) }) -# Return the number of elements in the RDD -# @export -# @rdname count +#' Return the number of elements in the RDD +#' @rdname count +#' @noRd setMethod("length", signature(x = "RDD"), function(x) { count(x) }) -# Return the count of each unique value in this RDD as a list of -# (value, count) pairs. -# -# Same as countByValue in Spark. -# -# @param x The RDD to count -# @return list of (value, count) pairs, where count is number of each unique -# value in rdd. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, c(1,2,3,2,1)) -# countByValue(rdd) # (1,2L), (2,2L), (3,1L) -#} -# @rdname countByValue -# @aliases countByValue,RDD-method +#' Return the count of each unique value in this RDD as a list of +#' (value, count) pairs. +#' +#' Same as countByValue in Spark. +#' +#' @param x The RDD to count +#' @return list of (value, count) pairs, where count is number of each unique +#' value in rdd. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, c(1,2,3,2,1)) +#' countByValue(rdd) # (1,2L), (2,2L), (3,1L) +#'} +#' @rdname countByValue +#' @aliases countByValue,RDD-method +#' @noRd setMethod("countByValue", signature(x = "RDD"), function(x) { ones <- lapply(x, function(item) { list(item, 1L) }) - collect(reduceByKey(ones, `+`, numPartitions(x))) + collect(reduceByKey(ones, `+`, getNumPartitions(x))) }) -# Apply a function to all elements -# -# This function creates a new RDD by applying the given transformation to all -# elements of the given RDD -# -# @param X The RDD to apply the transformation. -# @param FUN the transformation to apply on each element -# @return a new RDD created by the transformation. -# @rdname lapply -# @aliases lapply -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# multiplyByTwo <- lapply(rdd, function(x) { x * 2 }) -# collect(multiplyByTwo) # 2,4,6... -#} +#' Apply a function to all elements +#' +#' This function creates a new RDD by applying the given transformation to all +#' elements of the given RDD +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on each element +#' @return a new RDD created by the transformation. +#' @rdname lapply +#' @noRd +#' @aliases lapply +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' multiplyByTwo <- lapply(rdd, function(x) { x * 2 }) +#' collect(multiplyByTwo) # 2,4,6... +#'} setMethod("lapply", signature(X = "RDD", FUN = "function"), function(X, FUN) { @@ -464,31 +484,33 @@ setMethod("lapply", lapplyPartitionsWithIndex(X, func) }) -# @rdname lapply -# @aliases map,RDD,function-method +#' @rdname lapply +#' @aliases map,RDD,function-method +#' @noRd setMethod("map", signature(X = "RDD", FUN = "function"), function(X, FUN) { lapply(X, FUN) }) -# Flatten results after apply a function to all elements -# -# This function return a new RDD by first applying a function to all -# elements of this RDD, and then flattening the results. -# -# @param X The RDD to apply the transformation. 
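# Illustrative sketch, not part of the patch: lapply()/map() are lazy and only
# build a PipelinedRDD; actions such as count(), length() and countByValue()
# trigger evaluation. Assumes a local session from sparkR.init().
sc <- sparkR.init()
rdd <- parallelize(sc, 1:10)
doubled <- lapply(rdd, function(x) { x * 2 })    # no Spark job runs yet
count(doubled)                                   # 10, evaluation happens here
length(doubled)                                  # same as count(doubled)
countByValue(parallelize(sc, c(1, 2, 3, 2, 1)))  # pairs (1,2L), (2,2L), (3,1L); order may vary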
-# @param FUN the transformation to apply on each element -# @return a new RDD created by the transformation. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# multiplyByTwo <- flatMap(rdd, function(x) { list(x*2, x*10) }) -# collect(multiplyByTwo) # 2,20,4,40,6,60... -#} -# @rdname flatMap -# @aliases flatMap,RDD,function-method +#' Flatten results after apply a function to all elements +#' +#' This function return a new RDD by first applying a function to all +#' elements of this RDD, and then flattening the results. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on each element +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' multiplyByTwo <- flatMap(rdd, function(x) { list(x*2, x*10) }) +#' collect(multiplyByTwo) # 2,20,4,40,6,60... +#'} +#' @rdname flatMap +#' @aliases flatMap,RDD,function-method +#' @noRd setMethod("flatMap", signature(X = "RDD", FUN = "function"), function(X, FUN) { @@ -501,83 +523,88 @@ setMethod("flatMap", lapplyPartition(X, partitionFunc) }) -# Apply a function to each partition of an RDD -# -# Return a new RDD by applying a function to each partition of this RDD. -# -# @param X The RDD to apply the transformation. -# @param FUN the transformation to apply on each partition. -# @return a new RDD created by the transformation. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# partitionSum <- lapplyPartition(rdd, function(part) { Reduce("+", part) }) -# collect(partitionSum) # 15, 40 -#} -# @rdname lapplyPartition -# @aliases lapplyPartition,RDD,function-method +#' Apply a function to each partition of an RDD +#' +#' Return a new RDD by applying a function to each partition of this RDD. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on each partition. +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' partitionSum <- lapplyPartition(rdd, function(part) { Reduce("+", part) }) +#' collect(partitionSum) # 15, 40 +#'} +#' @rdname lapplyPartition +#' @aliases lapplyPartition,RDD,function-method +#' @noRd setMethod("lapplyPartition", signature(X = "RDD", FUN = "function"), function(X, FUN) { lapplyPartitionsWithIndex(X, function(s, part) { FUN(part) }) }) -# mapPartitions is the same as lapplyPartition. -# -# @rdname lapplyPartition -# @aliases mapPartitions,RDD,function-method +#' mapPartitions is the same as lapplyPartition. +#' +#' @rdname lapplyPartition +#' @aliases mapPartitions,RDD,function-method +#' @noRd setMethod("mapPartitions", signature(X = "RDD", FUN = "function"), function(X, FUN) { lapplyPartition(X, FUN) }) -# Return a new RDD by applying a function to each partition of this RDD, while -# tracking the index of the original partition. -# -# @param X The RDD to apply the transformation. -# @param FUN the transformation to apply on each partition; takes the partition -# index and a list of elements in the particular partition. -# @return a new RDD created by the transformation. 
-# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10, 5L) -# prod <- lapplyPartitionsWithIndex(rdd, function(partIndex, part) { -# partIndex * Reduce("+", part) }) -# collect(prod, flatten = FALSE) # 0, 7, 22, 45, 76 -#} -# @rdname lapplyPartitionsWithIndex -# @aliases lapplyPartitionsWithIndex,RDD,function-method +#' Return a new RDD by applying a function to each partition of this RDD, while +#' tracking the index of the original partition. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on each partition; takes the partition +#' index and a list of elements in the particular partition. +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 5L) +#' prod <- lapplyPartitionsWithIndex(rdd, function(partIndex, part) { +#' partIndex * Reduce("+", part) }) +#' collect(prod, flatten = FALSE) # 0, 7, 22, 45, 76 +#'} +#' @rdname lapplyPartitionsWithIndex +#' @aliases lapplyPartitionsWithIndex,RDD,function-method +#' @noRd setMethod("lapplyPartitionsWithIndex", signature(X = "RDD", FUN = "function"), function(X, FUN) { PipelinedRDD(X, FUN) }) -# @rdname lapplyPartitionsWithIndex -# @aliases mapPartitionsWithIndex,RDD,function-method +#' @rdname lapplyPartitionsWithIndex +#' @aliases mapPartitionsWithIndex,RDD,function-method +#' @noRd setMethod("mapPartitionsWithIndex", signature(X = "RDD", FUN = "function"), function(X, FUN) { lapplyPartitionsWithIndex(X, FUN) }) -# This function returns a new RDD containing only the elements that satisfy -# a predicate (i.e. returning TRUE in a given logical function). -# The same as `filter()' in Spark. -# -# @param x The RDD to be filtered. -# @param f A unary predicate function. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# unlist(collect(filterRDD(rdd, function (x) { x < 3 }))) # c(1, 2) -#} -# @rdname filterRDD -# @aliases filterRDD,RDD,function-method +#' This function returns a new RDD containing only the elements that satisfy +#' a predicate (i.e. returning TRUE in a given logical function). +#' The same as `filter()' in Spark. +#' +#' @param x The RDD to be filtered. +#' @param f A unary predicate function. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' unlist(collect(filterRDD(rdd, function (x) { x < 3 }))) # c(1, 2) +#'} +#' @rdname filterRDD +#' @aliases filterRDD,RDD,function-method +#' @noRd setMethod("filterRDD", signature(x = "RDD", f = "function"), function(x, f) { @@ -587,30 +614,32 @@ setMethod("filterRDD", lapplyPartition(x, filter.func) }) -# @rdname filterRDD -# @aliases Filter +#' @rdname filterRDD +#' @aliases Filter +#' @noRd setMethod("Filter", signature(f = "function", x = "RDD"), function(f, x) { filterRDD(x, f) }) -# Reduce across elements of an RDD. -# -# This function reduces the elements of this RDD using the -# specified commutative and associative binary operator. -# -# @param x The RDD to reduce -# @param func Commutative and associative function to apply on elements -# of the RDD. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# reduce(rdd, "+") # 55 -#} -# @rdname reduce -# @aliases reduce,RDD,ANY-method +#' Reduce across elements of an RDD. +#' +#' This function reduces the elements of this RDD using the +#' specified commutative and associative binary operator. 
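# Illustrative sketch, not part of the patch: partition-indexed transformations
# and filtering, reusing the values from the roxygen examples above. Assumes a
# local session from sparkR.init().
sc <- sparkR.init()
rdd <- parallelize(sc, 1:10, 5L)
prod <- lapplyPartitionsWithIndex(rdd, function(partIndex, part) {
  partIndex * Reduce("+", part)
})
collect(prod, flatten = FALSE)                        # 0, 7, 22, 45, 76
unlist(collect(filterRDD(rdd, function(x) { x < 3 })))  # c(1, 2)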
+#' +#' @param x The RDD to reduce +#' @param func Commutative and associative function to apply on elements +#' of the RDD. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' reduce(rdd, "+") # 55 +#'} +#' @rdname reduce +#' @aliases reduce,RDD,ANY-method +#' @noRd setMethod("reduce", signature(x = "RDD", func = "ANY"), function(x, func) { @@ -624,70 +653,74 @@ setMethod("reduce", Reduce(func, partitionList) }) -# Get the maximum element of an RDD. -# -# @param x The RDD to get the maximum element from -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# maximum(rdd) # 10 -#} -# @rdname maximum -# @aliases maximum,RDD +#' Get the maximum element of an RDD. +#' +#' @param x The RDD to get the maximum element from +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' maximum(rdd) # 10 +#'} +#' @rdname maximum +#' @aliases maximum,RDD +#' @noRd setMethod("maximum", signature(x = "RDD"), function(x) { reduce(x, max) }) -# Get the minimum element of an RDD. -# -# @param x The RDD to get the minimum element from -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# minimum(rdd) # 1 -#} -# @rdname minimum -# @aliases minimum,RDD +#' Get the minimum element of an RDD. +#' +#' @param x The RDD to get the minimum element from +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' minimum(rdd) # 1 +#'} +#' @rdname minimum +#' @aliases minimum,RDD +#' @noRd setMethod("minimum", signature(x = "RDD"), function(x) { reduce(x, min) }) -# Add up the elements in an RDD. -# -# @param x The RDD to add up the elements in -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# sumRDD(rdd) # 55 -#} -# @rdname sumRDD -# @aliases sumRDD,RDD +#' Add up the elements in an RDD. +#' +#' @param x The RDD to add up the elements in +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' sumRDD(rdd) # 55 +#'} +#' @rdname sumRDD +#' @aliases sumRDD,RDD +#' @noRd setMethod("sumRDD", signature(x = "RDD"), function(x) { reduce(x, "+") }) -# Applies a function to all elements in an RDD, and force evaluation. -# -# @param x The RDD to apply the function -# @param func The function to be applied. -# @return invisible NULL. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# foreach(rdd, function(x) { save(x, file=...) }) -#} -# @rdname foreach -# @aliases foreach,RDD,function-method +#' Applies a function to all elements in an RDD, and force evaluation. +#' +#' @param x The RDD to apply the function +#' @param func The function to be applied. +#' @return invisible NULL. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' foreach(rdd, function(x) { save(x, file=...) }) +#'} +#' @rdname foreach +#' @aliases foreach,RDD,function-method +#' @noRd setMethod("foreach", signature(x = "RDD", func = "function"), function(x, func) { @@ -698,44 +731,46 @@ setMethod("foreach", invisible(collect(mapPartitions(x, partition.func))) }) -# Applies a function to each partition in an RDD, and force evaluation. -# -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# foreachPartition(rdd, function(part) { save(part, file=...); NULL }) -#} -# @rdname foreach -# @aliases foreachPartition,RDD,function-method +#' Applies a function to each partition in an RDD, and force evaluation. 
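# Illustrative sketch, not part of the patch: the small family of numeric
# actions documented above, all built on reduce(). Assumes sparkR.init().
sc <- sparkR.init()
rdd <- parallelize(sc, 1:10)
reduce(rdd, "+")   # 55
maximum(rdd)       # 10
minimum(rdd)       # 1
sumRDD(rdd)        # 55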
+#' +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' foreachPartition(rdd, function(part) { save(part, file=...); NULL }) +#'} +#' @rdname foreach +#' @aliases foreachPartition,RDD,function-method +#' @noRd setMethod("foreachPartition", signature(x = "RDD", func = "function"), function(x, func) { invisible(collect(mapPartitions(x, func))) }) -# Take elements from an RDD. -# -# This function takes the first NUM elements in the RDD and -# returns them in a list. -# -# @param x The RDD to take elements from -# @param num Number of elements to take -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# take(rdd, 2L) # list(1, 2) -#} -# @rdname take -# @aliases take,RDD,numeric-method +#' Take elements from an RDD. +#' +#' This function takes the first NUM elements in the RDD and +#' returns them in a list. +#' +#' @param x The RDD to take elements from +#' @param num Number of elements to take +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' take(rdd, 2L) # list(1, 2) +#'} +#' @rdname take +#' @aliases take,RDD,numeric-method +#' @noRd setMethod("take", signature(x = "RDD", num = "numeric"), function(x, num) { resList <- list() index <- -1 jrdd <- getJRDD(x) - numPartitions <- numPartitions(x) + numPartitions <- getNumPartitions(x) serializedModeRDD <- getSerializedMode(x) # TODO(shivaram): Collect more than one partition based on size @@ -763,42 +798,43 @@ setMethod("take", }) -# First -# -# Return the first element of an RDD -# -# @rdname first -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# first(rdd) -# } +#' First +#' +#' Return the first element of an RDD +#' +#' @rdname first +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' first(rdd) +#' } +#' @noRd setMethod("first", signature(x = "RDD"), function(x) { take(x, 1)[[1]] }) -# Removes the duplicates from RDD. -# -# This function returns a new RDD containing the distinct elements in the -# given RDD. The same as `distinct()' in Spark. -# -# @param x The RDD to remove duplicates from. -# @param numPartitions Number of partitions to create. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, c(1,2,2,3,3,3)) -# sort(unlist(collect(distinct(rdd)))) # c(1, 2, 3) -#} -# @rdname distinct -# @aliases distinct,RDD-method +#' Removes the duplicates from RDD. +#' +#' This function returns a new RDD containing the distinct elements in the +#' given RDD. The same as `distinct()' in Spark. +#' +#' @param x The RDD to remove duplicates from. +#' @param numPartitions Number of partitions to create. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, c(1,2,2,3,3,3)) +#' sort(unlist(collect(distinct(rdd)))) # c(1, 2, 3) +#'} +#' @rdname distinct +#' @aliases distinct,RDD-method +#' @noRd setMethod("distinct", signature(x = "RDD"), - function(x, numPartitions = SparkR:::numPartitions(x)) { + function(x, numPartitions = SparkR:::getNumPartitions(x)) { identical.mapped <- lapply(x, function(x) { list(x, NULL) }) reduced <- reduceByKey(identical.mapped, function(x, y) { x }, @@ -807,24 +843,25 @@ setMethod("distinct", resRDD }) -# Return an RDD that is a sampled subset of the given RDD. -# -# The same as `sample()' in Spark. (We rename it due to signature -# inconsistencies with the `sample()' function in R's base package.) 
-# -# @param x The RDD to sample elements from -# @param withReplacement Sampling with replacement or not -# @param fraction The (rough) sample target fraction -# @param seed Randomness seed value -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# collect(sampleRDD(rdd, FALSE, 0.5, 1618L)) # ~5 distinct elements -# collect(sampleRDD(rdd, TRUE, 0.5, 9L)) # ~5 elements possibly with duplicates -#} -# @rdname sampleRDD -# @aliases sampleRDD,RDD +#' Return an RDD that is a sampled subset of the given RDD. +#' +#' The same as `sample()' in Spark. (We rename it due to signature +#' inconsistencies with the `sample()' function in R's base package.) +#' +#' @param x The RDD to sample elements from +#' @param withReplacement Sampling with replacement or not +#' @param fraction The (rough) sample target fraction +#' @param seed Randomness seed value +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' collect(sampleRDD(rdd, FALSE, 0.5, 1618L)) # ~5 distinct elements +#' collect(sampleRDD(rdd, TRUE, 0.5, 9L)) # ~5 elements possibly with duplicates +#'} +#' @rdname sampleRDD +#' @aliases sampleRDD,RDD +#' @noRd setMethod("sampleRDD", signature(x = "RDD", withReplacement = "logical", fraction = "numeric", seed = "integer"), @@ -868,23 +905,24 @@ setMethod("sampleRDD", lapplyPartitionsWithIndex(x, samplingFunc) }) -# Return a list of the elements that are a sampled subset of the given RDD. -# -# @param x The RDD to sample elements from -# @param withReplacement Sampling with replacement or not -# @param num Number of elements to return -# @param seed Randomness seed value -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:100) -# # exactly 5 elements sampled, which may not be distinct -# takeSample(rdd, TRUE, 5L, 1618L) -# # exactly 5 distinct elements sampled -# takeSample(rdd, FALSE, 5L, 16181618L) -#} -# @rdname takeSample -# @aliases takeSample,RDD +#' Return a list of the elements that are a sampled subset of the given RDD. +#' +#' @param x The RDD to sample elements from +#' @param withReplacement Sampling with replacement or not +#' @param num Number of elements to return +#' @param seed Randomness seed value +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:100) +#' # exactly 5 elements sampled, which may not be distinct +#' takeSample(rdd, TRUE, 5L, 1618L) +#' # exactly 5 distinct elements sampled +#' takeSample(rdd, FALSE, 5L, 16181618L) +#'} +#' @rdname takeSample +#' @aliases takeSample,RDD +#' @noRd setMethod("takeSample", signature(x = "RDD", withReplacement = "logical", num = "integer", seed = "integer"), function(x, withReplacement, num, seed) { @@ -931,18 +969,19 @@ setMethod("takeSample", signature(x = "RDD", withReplacement = "logical", base::sample(samples)[1:total] }) -# Creates tuples of the elements in this RDD by applying a function. -# -# @param x The RDD. -# @param func The function to be applied. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(1, 2, 3)) -# collect(keyBy(rdd, function(x) { x*x })) # list(list(1, 1), list(4, 2), list(9, 3)) -#} -# @rdname keyBy -# @aliases keyBy,RDD +#' Creates tuples of the elements in this RDD by applying a function. +#' +#' @param x The RDD. +#' @param func The function to be applied. 
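# Illustrative sketch, not part of the patch: the two sampling helpers above.
# Results depend on the seed; the counts shown are only indicative. Assumes a
# local session from sparkR.init().
sc <- sparkR.init()
rdd <- parallelize(sc, 1:100)
sampled <- sampleRDD(rdd, withReplacement = FALSE, fraction = 0.1, seed = 42L)
length(collect(sampled))           # roughly 10 elements
takeSample(rdd, FALSE, 5L, 1618L)  # exactly 5 distinct elements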
+#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3)) +#' collect(keyBy(rdd, function(x) { x*x })) # list(list(1, 1), list(4, 2), list(9, 3)) +#'} +#' @rdname keyBy +#' @aliases keyBy,RDD +#' @noRd setMethod("keyBy", signature(x = "RDD", func = "function"), function(x, func) { @@ -952,49 +991,51 @@ setMethod("keyBy", lapply(x, apply.func) }) -# Return a new RDD that has exactly numPartitions partitions. -# Can increase or decrease the level of parallelism in this RDD. Internally, -# this uses a shuffle to redistribute data. -# If you are decreasing the number of partitions in this RDD, consider using -# coalesce, which can avoid performing a shuffle. -# -# @param x The RDD. -# @param numPartitions Number of partitions to create. -# @seealso coalesce -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(1, 2, 3, 4, 5, 6, 7), 4L) -# numPartitions(rdd) # 4 -# numPartitions(repartition(rdd, 2L)) # 2 -#} -# @rdname repartition -# @aliases repartition,RDD +#' Return a new RDD that has exactly numPartitions partitions. +#' Can increase or decrease the level of parallelism in this RDD. Internally, +#' this uses a shuffle to redistribute data. +#' If you are decreasing the number of partitions in this RDD, consider using +#' coalesce, which can avoid performing a shuffle. +#' +#' @param x The RDD. +#' @param numPartitions Number of partitions to create. +#' @seealso coalesce +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3, 4, 5, 6, 7), 4L) +#' getNumPartitions(rdd) # 4 +#' getNumPartitions(repartition(rdd, 2L)) # 2 +#'} +#' @rdname repartition +#' @aliases repartition,RDD +#' @noRd setMethod("repartition", signature(x = "RDD", numPartitions = "numeric"), function(x, numPartitions) { coalesce(x, numPartitions, TRUE) }) -# Return a new RDD that is reduced into numPartitions partitions. -# -# @param x The RDD. -# @param numPartitions Number of partitions to create. -# @seealso repartition -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(1, 2, 3, 4, 5), 3L) -# numPartitions(rdd) # 3 -# numPartitions(coalesce(rdd, 1L)) # 1 -#} -# @rdname coalesce -# @aliases coalesce,RDD +#' Return a new RDD that is reduced into numPartitions partitions. +#' +#' @param x The RDD. +#' @param numPartitions Number of partitions to create. +#' @seealso repartition +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3, 4, 5), 3L) +#' getNumPartitions(rdd) # 3 +#' getNumPartitions(coalesce(rdd, 1L)) # 1 +#'} +#' @rdname coalesce +#' @aliases coalesce,RDD +#' @noRd setMethod("coalesce", signature(x = "RDD", numPartitions = "numeric"), function(x, numPartitions, shuffle = FALSE) { numPartitions <- numToInt(numPartitions) - if (shuffle || numPartitions > SparkR:::numPartitions(x)) { + if (shuffle || numPartitions > SparkR:::getNumPartitions(x)) { func <- function(partIndex, part) { set.seed(partIndex) # partIndex as seed start <- as.integer(base::sample(numPartitions, 1) - 1) @@ -1013,19 +1054,20 @@ setMethod("coalesce", } }) -# Save this RDD as a SequenceFile of serialized objects. -# -# @param x The RDD to save -# @param path The directory where the file is saved -# @seealso objectFile -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:3) -# saveAsObjectFile(rdd, "/tmp/sparkR-tmp") -#} -# @rdname saveAsObjectFile -# @aliases saveAsObjectFile,RDD +#' Save this RDD as a SequenceFile of serialized objects. 
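# Illustrative sketch, not part of the patch: how repartition() and coalesce()
# interact with the new getNumPartitions(). Assumes sparkR.init().
sc <- sparkR.init()
rdd <- parallelize(sc, 1:100, 4L)
getNumPartitions(rdd)                   # 4
getNumPartitions(repartition(rdd, 2L))  # 2, always shuffles
getNumPartitions(coalesce(rdd, 2L))     # 2, avoids a shuffle when only shrinking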
+#' +#' @param x The RDD to save +#' @param path The directory where the file is saved +#' @seealso objectFile +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:3) +#' saveAsObjectFile(rdd, "/tmp/sparkR-tmp") +#'} +#' @rdname saveAsObjectFile +#' @aliases saveAsObjectFile,RDD +#' @noRd setMethod("saveAsObjectFile", signature(x = "RDD", path = "character"), function(x, path) { @@ -1038,18 +1080,19 @@ setMethod("saveAsObjectFile", invisible(callJMethod(getJRDD(x), "saveAsObjectFile", path)) }) -# Save this RDD as a text file, using string representations of elements. -# -# @param x The RDD to save -# @param path The directory where the partitions of the text file are saved -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:3) -# saveAsTextFile(rdd, "/tmp/sparkR-tmp") -#} -# @rdname saveAsTextFile -# @aliases saveAsTextFile,RDD +#' Save this RDD as a text file, using string representations of elements. +#' +#' @param x The RDD to save +#' @param path The directory where the partitions of the text file are saved +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:3) +#' saveAsTextFile(rdd, "/tmp/sparkR-tmp") +#'} +#' @rdname saveAsTextFile +#' @aliases saveAsTextFile,RDD +#' @noRd setMethod("saveAsTextFile", signature(x = "RDD", path = "character"), function(x, path) { @@ -1062,24 +1105,25 @@ setMethod("saveAsTextFile", callJMethod(getJRDD(stringRdd, serializedMode = "string"), "saveAsTextFile", path)) }) -# Sort an RDD by the given key function. -# -# @param x An RDD to be sorted. -# @param func A function used to compute the sort key for each element. -# @param ascending A flag to indicate whether the sorting is ascending or descending. -# @param numPartitions Number of partitions to create. -# @return An RDD where all elements are sorted. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(3, 2, 1)) -# collect(sortBy(rdd, function(x) { x })) # list (1, 2, 3) -#} -# @rdname sortBy -# @aliases sortBy,RDD,RDD-method +#' Sort an RDD by the given key function. +#' +#' @param x An RDD to be sorted. +#' @param func A function used to compute the sort key for each element. +#' @param ascending A flag to indicate whether the sorting is ascending or descending. +#' @param numPartitions Number of partitions to create. +#' @return An RDD where all elements are sorted. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(3, 2, 1)) +#' collect(sortBy(rdd, function(x) { x })) # list (1, 2, 3) +#'} +#' @rdname sortBy +#' @aliases sortBy,RDD,RDD-method +#' @noRd setMethod("sortBy", signature(x = "RDD", func = "function"), - function(x, func, ascending = TRUE, numPartitions = SparkR:::numPartitions(x)) { + function(x, func, ascending = TRUE, numPartitions = SparkR:::getNumPartitions(x)) { values(sortByKey(keyBy(x, func), ascending, numPartitions)) }) @@ -1111,7 +1155,7 @@ takeOrderedElem <- function(x, num, ascending = TRUE) { resList <- list() index <- -1 jrdd <- getJRDD(newRdd) - numPartitions <- numPartitions(newRdd) + numPartitions <- getNumPartitions(newRdd) serializedModeRDD <- getSerializedMode(newRdd) while (TRUE) { @@ -1138,97 +1182,95 @@ takeOrderedElem <- function(x, num, ascending = TRUE) { resList } -# Returns the first N elements from an RDD in ascending order. -# -# @param x An RDD. -# @param num Number of elements to return. -# @return The first N elements from the RDD in ascending order. 
-# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(10, 1, 2, 9, 3, 4, 5, 6, 7)) -# takeOrdered(rdd, 6L) # list(1, 2, 3, 4, 5, 6) -#} -# @rdname takeOrdered -# @aliases takeOrdered,RDD,RDD-method +#' Returns the first N elements from an RDD in ascending order. +#' +#' @param x An RDD. +#' @param num Number of elements to return. +#' @return The first N elements from the RDD in ascending order. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(10, 1, 2, 9, 3, 4, 5, 6, 7)) +#' takeOrdered(rdd, 6L) # list(1, 2, 3, 4, 5, 6) +#'} +#' @rdname takeOrdered +#' @aliases takeOrdered,RDD,RDD-method +#' @noRd setMethod("takeOrdered", signature(x = "RDD", num = "integer"), function(x, num) { takeOrderedElem(x, num) }) -# Returns the top N elements from an RDD. -# -# @param x An RDD. -# @param num Number of elements to return. -# @return The top N elements from the RDD. -# @rdname top -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(10, 1, 2, 9, 3, 4, 5, 6, 7)) -# top(rdd, 6L) # list(10, 9, 7, 6, 5, 4) -#} -# @rdname top -# @aliases top,RDD,RDD-method +#' Returns the top N elements from an RDD. +#' +#' @param x An RDD. +#' @param num Number of elements to return. +#' @return The top N elements from the RDD. +#' @rdname top +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(10, 1, 2, 9, 3, 4, 5, 6, 7)) +#' top(rdd, 6L) # list(10, 9, 7, 6, 5, 4) +#'} +#' @aliases top,RDD,RDD-method +#' @noRd setMethod("top", signature(x = "RDD", num = "integer"), function(x, num) { takeOrderedElem(x, num, FALSE) }) -# Fold an RDD using a given associative function and a neutral "zero value". -# -# Aggregate the elements of each partition, and then the results for all the -# partitions, using a given associative function and a neutral "zero value". -# -# @param x An RDD. -# @param zeroValue A neutral "zero value". -# @param op An associative function for the folding operation. -# @return The folding result. -# @rdname fold -# @seealso reduce -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(1, 2, 3, 4, 5)) -# fold(rdd, 0, "+") # 15 -#} -# @rdname fold -# @aliases fold,RDD,RDD-method +#' Fold an RDD using a given associative function and a neutral "zero value". +#' +#' Aggregate the elements of each partition, and then the results for all the +#' partitions, using a given associative function and a neutral "zero value". +#' +#' @param x An RDD. +#' @param zeroValue A neutral "zero value". +#' @param op An associative function for the folding operation. +#' @return The folding result. +#' @rdname fold +#' @seealso reduce +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3, 4, 5)) +#' fold(rdd, 0, "+") # 15 +#'} +#' @aliases fold,RDD,RDD-method +#' @noRd setMethod("fold", signature(x = "RDD", zeroValue = "ANY", op = "ANY"), function(x, zeroValue, op) { aggregateRDD(x, zeroValue, op, op) }) -# Aggregate an RDD using the given combine functions and a neutral "zero value". -# -# Aggregate the elements of each partition, and then the results for all the -# partitions, using given combine functions and a neutral "zero value". -# -# @param x An RDD. -# @param zeroValue A neutral "zero value". -# @param seqOp A function to aggregate the RDD elements. It may return a different -# result type from the type of the RDD elements. -# @param combOp A function to aggregate results of seqOp. -# @return The aggregation result. 
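# Illustrative sketch, not part of the patch: ordered retrieval and folding,
# mirroring the roxygen examples above. Assumes sparkR.init().
sc <- sparkR.init()
rdd <- parallelize(sc, list(10, 1, 2, 9, 3, 4, 5, 6, 7))
takeOrdered(rdd, 3L)                                # list(1, 2, 3)
top(rdd, 3L)                                        # list(10, 9, 7)
fold(parallelize(sc, list(1, 2, 3, 4, 5)), 0, "+")  # 15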
-# @rdname aggregateRDD -# @seealso reduce -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(1, 2, 3, 4)) -# zeroValue <- list(0, 0) -# seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } -# combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } -# aggregateRDD(rdd, zeroValue, seqOp, combOp) # list(10, 4) -#} -# @rdname aggregateRDD -# @aliases aggregateRDD,RDD,RDD-method +#' Aggregate an RDD using the given combine functions and a neutral "zero value". +#' +#' Aggregate the elements of each partition, and then the results for all the +#' partitions, using given combine functions and a neutral "zero value". +#' +#' @param x An RDD. +#' @param zeroValue A neutral "zero value". +#' @param seqOp A function to aggregate the RDD elements. It may return a different +#' result type from the type of the RDD elements. +#' @param combOp A function to aggregate results of seqOp. +#' @return The aggregation result. +#' @rdname aggregateRDD +#' @seealso reduce +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3, 4)) +#' zeroValue <- list(0, 0) +#' seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } +#' combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } +#' aggregateRDD(rdd, zeroValue, seqOp, combOp) # list(10, 4) +#'} +#' @aliases aggregateRDD,RDD,RDD-method +#' @noRd setMethod("aggregateRDD", signature(x = "RDD", zeroValue = "ANY", seqOp = "ANY", combOp = "ANY"), function(x, zeroValue, seqOp, combOp) { @@ -1241,25 +1283,24 @@ setMethod("aggregateRDD", Reduce(combOp, partitionList, zeroValue) }) -# Pipes elements to a forked external process. -# -# The same as 'pipe()' in Spark. -# -# @param x The RDD whose elements are piped to the forked external process. -# @param command The command to fork an external process. -# @param env A named list to set environment variables of the external process. -# @return A new RDD created by piping all elements to a forked external process. -# @rdname pipeRDD -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# collect(pipeRDD(rdd, "more") -# Output: c("1", "2", ..., "10") -#} -# @rdname pipeRDD -# @aliases pipeRDD,RDD,character-method +#' Pipes elements to a forked external process. +#' +#' The same as 'pipe()' in Spark. +#' +#' @param x The RDD whose elements are piped to the forked external process. +#' @param command The command to fork an external process. +#' @param env A named list to set environment variables of the external process. +#' @return A new RDD created by piping all elements to a forked external process. +#' @rdname pipeRDD +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' collect(pipeRDD(rdd, "more") +#' Output: c("1", "2", ..., "10") +#'} +#' @aliases pipeRDD,RDD,character-method +#' @noRd setMethod("pipeRDD", signature(x = "RDD", command = "character"), function(x, command, env = list()) { @@ -1274,42 +1315,40 @@ setMethod("pipeRDD", lapplyPartition(x, func) }) -# TODO: Consider caching the name in the RDD's environment -# Return an RDD's name. -# -# @param x The RDD whose name is returned. -# @rdname name -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(1,2,3)) -# name(rdd) # NULL (if not set before) -#} -# @rdname name -# @aliases name,RDD +#' TODO: Consider caching the name in the RDD's environment +#' Return an RDD's name. +#' +#' @param x The RDD whose name is returned. 
+#' @rdname name +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1,2,3)) +#' name(rdd) # NULL (if not set before) +#'} +#' @aliases name,RDD +#' @noRd setMethod("name", signature(x = "RDD"), function(x) { callJMethod(getJRDD(x), "name") }) -# Set an RDD's name. -# -# @param x The RDD whose name is to be set. -# @param name The RDD name to be set. -# @return a new RDD renamed. -# @rdname setName -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(1,2,3)) -# setName(rdd, "myRDD") -# name(rdd) # "myRDD" -#} -# @rdname setName -# @aliases setName,RDD +#' Set an RDD's name. +#' +#' @param x The RDD whose name is to be set. +#' @param name The RDD name to be set. +#' @return a new RDD renamed. +#' @rdname setName +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1,2,3)) +#' setName(rdd, "myRDD") +#' name(rdd) # "myRDD" +#'} +#' @aliases setName,RDD +#' @noRd setMethod("setName", signature(x = "RDD", name = "character"), function(x, name) { @@ -1317,29 +1356,30 @@ setMethod("setName", x }) -# Zip an RDD with generated unique Long IDs. -# -# Items in the kth partition will get ids k, n+k, 2*n+k, ..., where -# n is the number of partitions. So there may exist gaps, but this -# method won't trigger a spark job, which is different from -# zipWithIndex. -# -# @param x An RDD to be zipped. -# @return An RDD with zipped items. -# @seealso zipWithIndex -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) -# collect(zipWithUniqueId(rdd)) -# # list(list("a", 0), list("b", 3), list("c", 1), list("d", 4), list("e", 2)) -#} -# @rdname zipWithUniqueId -# @aliases zipWithUniqueId,RDD +#' Zip an RDD with generated unique Long IDs. +#' +#' Items in the kth partition will get ids k, n+k, 2*n+k, ..., where +#' n is the number of partitions. So there may exist gaps, but this +#' method won't trigger a spark job, which is different from +#' zipWithIndex. +#' +#' @param x An RDD to be zipped. +#' @return An RDD with zipped items. +#' @seealso zipWithIndex +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) +#' collect(zipWithUniqueId(rdd)) +#' # list(list("a", 0), list("b", 3), list("c", 1), list("d", 4), list("e", 2)) +#'} +#' @rdname zipWithUniqueId +#' @aliases zipWithUniqueId,RDD +#' @noRd setMethod("zipWithUniqueId", signature(x = "RDD"), function(x) { - n <- numPartitions(x) + n <- getNumPartitions(x) partitionFunc <- function(partIndex, part) { mapply( @@ -1354,32 +1394,33 @@ setMethod("zipWithUniqueId", lapplyPartitionsWithIndex(x, partitionFunc) }) -# Zip an RDD with its element indices. -# -# The ordering is first based on the partition index and then the -# ordering of items within each partition. So the first item in -# the first partition gets index 0, and the last item in the last -# partition receives the largest index. -# -# This method needs to trigger a Spark job when this RDD contains -# more than one partition. -# -# @param x An RDD to be zipped. -# @return An RDD with zipped items. -# @seealso zipWithUniqueId -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) -# collect(zipWithIndex(rdd)) -# # list(list("a", 0), list("b", 1), list("c", 2), list("d", 3), list("e", 4)) -#} -# @rdname zipWithIndex -# @aliases zipWithIndex,RDD +#' Zip an RDD with its element indices. 
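# Illustrative sketch, not part of the patch: naming an RDD and the id scheme
# of zipWithUniqueId() (ids k, n+k, 2*n+k within partition k, where n is the
# number of partitions). Assumes sparkR.init(); the exact id assignment follows
# how parallelize() happens to split the data.
sc <- sparkR.init()
rdd <- setName(parallelize(sc, list("a", "b", "c", "d", "e"), 3L), "letters")
name(rdd)                      # "letters"
collect(zipWithUniqueId(rdd))  # e.g. list(list("a", 0), list("b", 3), list("c", 1), list("d", 4), list("e", 2))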
+#' +#' The ordering is first based on the partition index and then the +#' ordering of items within each partition. So the first item in +#' the first partition gets index 0, and the last item in the last +#' partition receives the largest index. +#' +#' This method needs to trigger a Spark job when this RDD contains +#' more than one partition. +#' +#' @param x An RDD to be zipped. +#' @return An RDD with zipped items. +#' @seealso zipWithUniqueId +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) +#' collect(zipWithIndex(rdd)) +#' # list(list("a", 0), list("b", 1), list("c", 2), list("d", 3), list("e", 4)) +#'} +#' @rdname zipWithIndex +#' @aliases zipWithIndex,RDD +#' @noRd setMethod("zipWithIndex", signature(x = "RDD"), function(x) { - n <- numPartitions(x) + n <- getNumPartitions(x) if (n > 1) { nums <- collect(lapplyPartition(x, function(part) { @@ -1407,20 +1448,21 @@ setMethod("zipWithIndex", lapplyPartitionsWithIndex(x, partitionFunc) }) -# Coalesce all elements within each partition of an RDD into a list. -# -# @param x An RDD. -# @return An RDD created by coalescing all elements within -# each partition into a list. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, as.list(1:4), 2L) -# collect(glom(rdd)) -# # list(list(1, 2), list(3, 4)) -#} -# @rdname glom -# @aliases glom,RDD +#' Coalesce all elements within each partition of an RDD into a list. +#' +#' @param x An RDD. +#' @return An RDD created by coalescing all elements within +#' each partition into a list. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, as.list(1:4), 2L) +#' collect(glom(rdd)) +#' # list(list(1, 2), list(3, 4)) +#'} +#' @rdname glom +#' @aliases glom,RDD +#' @noRd setMethod("glom", signature(x = "RDD"), function(x) { @@ -1433,21 +1475,22 @@ setMethod("glom", ############ Binary Functions ############# -# Return the union RDD of two RDDs. -# The same as union() in Spark. -# -# @param x An RDD. -# @param y An RDD. -# @return a new RDD created by performing the simple union (witout removing -# duplicates) of two input RDDs. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:3) -# unionRDD(rdd, rdd) # 1, 2, 3, 1, 2, 3 -#} -# @rdname unionRDD -# @aliases unionRDD,RDD,RDD-method +#' Return the union RDD of two RDDs. +#' The same as union() in Spark. +#' +#' @param x An RDD. +#' @param y An RDD. +#' @return a new RDD created by performing the simple union (witout removing +#' duplicates) of two input RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:3) +#' unionRDD(rdd, rdd) # 1, 2, 3, 1, 2, 3 +#'} +#' @rdname unionRDD +#' @aliases unionRDD,RDD,RDD-method +#' @noRd setMethod("unionRDD", signature(x = "RDD", y = "RDD"), function(x, y) { @@ -1464,32 +1507,33 @@ setMethod("unionRDD", union.rdd }) -# Zip an RDD with another RDD. -# -# Zips this RDD with another one, returning key-value pairs with the -# first element in each RDD second element in each RDD, etc. Assumes -# that the two RDDs have the same number of partitions and the same -# number of elements in each partition (e.g. one was made through -# a map on the other). -# -# @param x An RDD to be zipped. -# @param other Another RDD to be zipped. -# @return An RDD zipped from the two RDDs. 
-# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd1 <- parallelize(sc, 0:4) -# rdd2 <- parallelize(sc, 1000:1004) -# collect(zipRDD(rdd1, rdd2)) -# # list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004)) -#} -# @rdname zipRDD -# @aliases zipRDD,RDD +#' Zip an RDD with another RDD. +#' +#' Zips this RDD with another one, returning key-value pairs with the +#' first element in each RDD second element in each RDD, etc. Assumes +#' that the two RDDs have the same number of partitions and the same +#' number of elements in each partition (e.g. one was made through +#' a map on the other). +#' +#' @param x An RDD to be zipped. +#' @param other Another RDD to be zipped. +#' @return An RDD zipped from the two RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, 0:4) +#' rdd2 <- parallelize(sc, 1000:1004) +#' collect(zipRDD(rdd1, rdd2)) +#' # list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004)) +#'} +#' @rdname zipRDD +#' @aliases zipRDD,RDD +#' @noRd setMethod("zipRDD", signature(x = "RDD", other = "RDD"), function(x, other) { - n1 <- numPartitions(x) - n2 <- numPartitions(other) + n1 <- getNumPartitions(x) + n2 <- getNumPartitions(other) if (n1 != n2) { stop("Can only zip RDDs which have the same number of partitions.") } @@ -1503,24 +1547,25 @@ setMethod("zipRDD", mergePartitions(rdd, TRUE) }) -# Cartesian product of this RDD and another one. -# -# Return the Cartesian product of this RDD and another one, -# that is, the RDD of all pairs of elements (a, b) where a -# is in this and b is in other. -# -# @param x An RDD. -# @param other An RDD. -# @return A new RDD which is the Cartesian product of these two RDDs. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:2) -# sortByKey(cartesian(rdd, rdd)) -# # list(list(1, 1), list(1, 2), list(2, 1), list(2, 2)) -#} -# @rdname cartesian -# @aliases cartesian,RDD,RDD-method +#' Cartesian product of this RDD and another one. +#' +#' Return the Cartesian product of this RDD and another one, +#' that is, the RDD of all pairs of elements (a, b) where a +#' is in this and b is in other. +#' +#' @param x An RDD. +#' @param other An RDD. +#' @return A new RDD which is the Cartesian product of these two RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:2) +#' sortByKey(cartesian(rdd, rdd)) +#' # list(list(1, 1), list(1, 2), list(2, 1), list(2, 2)) +#'} +#' @rdname cartesian +#' @aliases cartesian,RDD,RDD-method +#' @noRd setMethod("cartesian", signature(x = "RDD", other = "RDD"), function(x, other) { @@ -1533,58 +1578,60 @@ setMethod("cartesian", mergePartitions(rdd, FALSE) }) -# Subtract an RDD with another RDD. -# -# Return an RDD with the elements from this that are not in other. -# -# @param x An RDD. -# @param other An RDD. -# @param numPartitions Number of the partitions in the result RDD. -# @return An RDD with the elements from this that are not in other. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd1 <- parallelize(sc, list(1, 1, 2, 2, 3, 4)) -# rdd2 <- parallelize(sc, list(2, 4)) -# collect(subtract(rdd1, rdd2)) -# # list(1, 1, 3) -#} -# @rdname subtract -# @aliases subtract,RDD +#' Subtract an RDD with another RDD. +#' +#' Return an RDD with the elements from this that are not in other. +#' +#' @param x An RDD. +#' @param other An RDD. +#' @param numPartitions Number of the partitions in the result RDD. +#' @return An RDD with the elements from this that are not in other. 
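# Illustrative sketch, not part of the patch: zipRDD() requires the same number
# of partitions and of elements per partition, while cartesian() pairs every
# element with every other. Assumes sparkR.init().
sc <- sparkR.init()
rdd1 <- parallelize(sc, 0:4, 2L)
rdd2 <- parallelize(sc, 1000:1004, 2L)
collect(zipRDD(rdd1, rdd2))       # list(list(0, 1000), list(1, 1001), ..., list(4, 1004))
small <- parallelize(sc, 1:2)
collect(cartesian(small, small))  # the four pairs (1,1), (1,2), (2,1), (2,2); order may vary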
+#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(1, 1, 2, 2, 3, 4)) +#' rdd2 <- parallelize(sc, list(2, 4)) +#' collect(subtract(rdd1, rdd2)) +#' # list(1, 1, 3) +#'} +#' @rdname subtract +#' @aliases subtract,RDD +#' @noRd setMethod("subtract", signature(x = "RDD", other = "RDD"), - function(x, other, numPartitions = SparkR:::numPartitions(x)) { + function(x, other, numPartitions = SparkR:::getNumPartitions(x)) { mapFunction <- function(e) { list(e, NA) } rdd1 <- map(x, mapFunction) rdd2 <- map(other, mapFunction) keys(subtractByKey(rdd1, rdd2, numPartitions)) }) -# Intersection of this RDD and another one. -# -# Return the intersection of this RDD and another one. -# The output will not contain any duplicate elements, -# even if the input RDDs did. Performs a hash partition -# across the cluster. -# Note that this method performs a shuffle internally. -# -# @param x An RDD. -# @param other An RDD. -# @param numPartitions The number of partitions in the result RDD. -# @return An RDD which is the intersection of these two RDDs. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd1 <- parallelize(sc, list(1, 10, 2, 3, 4, 5)) -# rdd2 <- parallelize(sc, list(1, 6, 2, 3, 7, 8)) -# collect(sortBy(intersection(rdd1, rdd2), function(x) { x })) -# # list(1, 2, 3) -#} -# @rdname intersection -# @aliases intersection,RDD +#' Intersection of this RDD and another one. +#' +#' Return the intersection of this RDD and another one. +#' The output will not contain any duplicate elements, +#' even if the input RDDs did. Performs a hash partition +#' across the cluster. +#' Note that this method performs a shuffle internally. +#' +#' @param x An RDD. +#' @param other An RDD. +#' @param numPartitions The number of partitions in the result RDD. +#' @return An RDD which is the intersection of these two RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(1, 10, 2, 3, 4, 5)) +#' rdd2 <- parallelize(sc, list(1, 6, 2, 3, 7, 8)) +#' collect(sortBy(intersection(rdd1, rdd2), function(x) { x })) +#' # list(1, 2, 3) +#'} +#' @rdname intersection +#' @aliases intersection,RDD +#' @noRd setMethod("intersection", signature(x = "RDD", other = "RDD"), - function(x, other, numPartitions = SparkR:::numPartitions(x)) { + function(x, other, numPartitions = SparkR:::getNumPartitions(x)) { rdd1 <- map(x, function(v) { list(v, NA) }) rdd2 <- map(other, function(v) { list(v, NA) }) @@ -1597,26 +1644,27 @@ setMethod("intersection", keys(filterRDD(cogroup(rdd1, rdd2, numPartitions = numPartitions), filterFunction)) }) -# Zips an RDD's partitions with one (or more) RDD(s). -# Same as zipPartitions in Spark. -# -# @param ... RDDs to be zipped. -# @param func A function to transform zipped partitions. -# @return A new RDD by applying a function to the zipped partitions. -# Assumes that all the RDDs have the *same number of partitions*, but -# does *not* require them to have the same number of elements in each partition. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2 -# rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 -# rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 -# collect(zipPartitions(rdd1, rdd2, rdd3, -# func = function(x, y, z) { list(list(x, y, z))} )) -# # list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6))) -#} -# @rdname zipRDD -# @aliases zipPartitions,RDD +#' Zips an RDD's partitions with one (or more) RDD(s). +#' Same as zipPartitions in Spark. +#' +#' @param ... RDDs to be zipped. 
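# Illustrative sketch, not part of the patch: the set-like operations above,
# mirroring their roxygen examples. Assumes sparkR.init().
sc <- sparkR.init()
rdd1 <- parallelize(sc, list(1, 1, 2, 2, 3, 4))
rdd2 <- parallelize(sc, list(2, 4))
collect(subtract(rdd1, rdd2))                    # list(1, 1, 3); order may vary
rddA <- parallelize(sc, list(1, 10, 2, 3, 4, 5))
rddB <- parallelize(sc, list(1, 6, 2, 3, 7, 8))
sort(unlist(collect(intersection(rddA, rddB))))  # c(1, 2, 3)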
+#' @param func A function to transform zipped partitions. +#' @return A new RDD by applying a function to the zipped partitions. +#' Assumes that all the RDDs have the *same number of partitions*, but +#' does *not* require them to have the same number of elements in each partition. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2 +#' rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 +#' rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 +#' collect(zipPartitions(rdd1, rdd2, rdd3, +#' func = function(x, y, z) { list(list(x, y, z))} )) +#' # list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6))) +#'} +#' @rdname zipRDD +#' @aliases zipPartitions,RDD +#' @noRd setMethod("zipPartitions", "RDD", function(..., func) { @@ -1624,7 +1672,7 @@ setMethod("zipPartitions", if (length(rrdds) == 1) { return(rrdds[[1]]) } - nPart <- sapply(rrdds, numPartitions) + nPart <- sapply(rrdds, getNumPartitions) if (length(unique(nPart)) != 1) { stop("Can only zipPartitions RDDs which have the same number of partitions.") } diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 110117a18ccb..9243d70e66f7 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -17,69 +17,76 @@ # SQLcontext.R: SQLContext-driven functions + +# Map top level R type to SQL type +getInternalType <- function(x) { + # class of POSIXlt is c("POSIXlt" "POSIXt") + switch(class(x)[[1]], + integer = "integer", + character = "string", + logical = "boolean", + double = "double", + numeric = "double", + raw = "binary", + list = "array", + struct = "struct", + environment = "map", + Date = "date", + POSIXlt = "timestamp", + POSIXct = "timestamp", + stop(paste("Unsupported type for DataFrame:", class(x)))) +} + #' infer the SQL type infer_type <- function(x) { if (is.null(x)) { stop("can not infer type from NULL") } - # class of POSIXlt is c("POSIXlt" "POSIXt") - type <- switch(class(x)[[1]], - integer = "integer", - character = "string", - logical = "boolean", - double = "double", - numeric = "double", - raw = "binary", - list = "array", - environment = "map", - Date = "date", - POSIXlt = "timestamp", - POSIXct = "timestamp", - stop(paste("Unsupported type for DataFrame:", class(x)))) + type <- getInternalType(x) if (type == "map") { stopifnot(length(x) > 0) key <- ls(x)[[1]] - list(type = "map", - keyType = "string", - valueType = infer_type(get(key, x)), - valueContainsNull = TRUE) + paste0("map<string,", infer_type(get(key, x)), ">") } else if (type == "array") { stopifnot(length(x) > 0) + + paste0("array<", infer_type(x[[1]]), ">") + } else if (type == "struct") { + stopifnot(length(x) > 0) names <- names(x) - if (is.null(names)) { - list(type = "array", elementType = infer_type(x[[1]]), containsNull = TRUE) - } else { - # StructType - types <- lapply(x, infer_type) - fields <- lapply(1:length(x), function(i) { - structField(names[[i]], types[[i]], TRUE) - }) - do.call(structType, fields) - } - } else if (length(x) > 1) { - list(type = "array", elementType = type, containsNull = TRUE) + stopifnot(!is.null(names)) + + type <- lapply(seq_along(x), function(i) { + paste0(names[[i]], ":", infer_type(x[[i]]), ",") + }) + type <- Reduce(paste0, type) + type <- paste0("struct<", substr(type, 1, nchar(type) - 1), ">") + } else if (length(x) > 1 && type != "binary") { + paste0("array<", infer_type(x[[1]]), ">") } else { type } } -#' Create a DataFrame from an RDD +#' Create a DataFrame #' -#' Converts an RDD to a DataFrame by infer the types. +#' Converts R data.frame or list into DataFrame.
#' #' @param sqlContext A SQLContext #' @param data An RDD or list or data.frame #' @param schema a list of column names or named list (StructType), optional #' @return an DataFrame +#' @rdname createDataFrame #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) -#' rdd <- lapply(parallelize(sc, 1:10), function(x) list(a=x, b=as.character(x))) -#' df <- createDataFrame(sqlContext, rdd) +#' df1 <- as.DataFrame(sqlContext, iris) +#' df2 <- as.DataFrame(sqlContext, list(3,4,5,6)) +#' df3 <- createDataFrame(sqlContext, iris) #' } # TODO(davies): support sampling and infer type from NA @@ -89,19 +96,25 @@ createDataFrame <- function(sqlContext, data, schema = NULL, samplingRatio = 1.0 if (is.null(schema)) { schema <- names(data) } - n <- nrow(data) - m <- ncol(data) + # get rid of factor type - dropFactor <- function(x) { + cleanCols <- function(x) { if (is.factor(x)) { as.character(x) } else { x } } - data <- lapply(1:n, function(i) { - lapply(1:m, function(j) { dropFactor(data[i,j]) }) - }) + + # drop factors and wrap lists + data <- setNames(lapply(data, cleanCols), NULL) + + # check if all columns have supported type + lapply(data, getInternalType) + + # convert to rows + args <- list(FUN = list, SIMPLIFY = FALSE, USE.NAMES = FALSE) + data <- do.call(mapply, append(args, data)) } if (is.list(data)) { sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sqlContext) @@ -143,7 +156,6 @@ createDataFrame <- function(sqlContext, data, schema = NULL, samplingRatio = 1.0 } stopifnot(class(schema) == "structType") - # schemaString <- tojson(schema) jrdd <- getJRDD(lapply(rdd, function(x) x), "row") srdd <- callJMethod(jrdd, "rdd") @@ -152,22 +164,28 @@ createDataFrame <- function(sqlContext, data, schema = NULL, samplingRatio = 1.0 dataFrame(sdf) } -# toDF -# -# Converts an RDD to a DataFrame by infer the types. -# -# @param x An RDD -# -# @rdname DataFrame -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# sqlContext <- sparkRSQL.init(sc) -# rdd <- lapply(parallelize(sc, 1:10), function(x) list(a=x, b=as.character(x))) -# df <- toDF(rdd) -# } +#' @rdname createDataFrame +#' @aliases createDataFrame +#' @export +as.DataFrame <- function(sqlContext, data, schema = NULL, samplingRatio = 1.0) { + createDataFrame(sqlContext, data, schema, samplingRatio) +} +#' toDF +#' +#' Converts an RDD to a DataFrame by infer the types. +#' +#' @param x An RDD +#' +#' @rdname DataFrame +#' @noRd +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' rdd <- lapply(parallelize(sc, 1:10), function(x) list(a=x, b=as.character(x))) +#' df <- toDF(rdd) +#'} setGeneric("toDF", function(x, ...) { standardGeneric("toDF") }) setMethod("toDF", signature(x = "RDD"), @@ -190,42 +208,51 @@ setMethod("toDF", signature(x = "RDD"), #' @param sqlContext SQLContext to use #' @param path Path of file to read. A vector of multiple paths is allowed. 
#' @return DataFrame +#' @rdname read.json +#' @name read.json #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" +#' df <- read.json(sqlContext, path) #' df <- jsonFile(sqlContext, path) #' } - -jsonFile <- function(sqlContext, path) { +read.json <- function(sqlContext, path) { # Allow the user to have a more flexible definiton of the text file path - path <- normalizePath(path) - # Convert a string vector of paths to a string containing comma separated paths - path <- paste(path, collapse = ",") - sdf <- callJMethod(sqlContext, "jsonFile", path) + paths <- as.list(suppressWarnings(normalizePath(path))) + read <- callJMethod(sqlContext, "read") + sdf <- callJMethod(read, "json", paths) dataFrame(sdf) } +#' @rdname read.json +#' @name jsonFile +#' @export +jsonFile <- function(sqlContext, path) { + .Deprecated("read.json") + read.json(sqlContext, path) +} -# JSON RDD -# -# Loads an RDD storing one JSON object per string as a DataFrame. -# -# @param sqlContext SQLContext to use -# @param rdd An RDD of JSON string -# @param schema A StructType object to use as schema -# @param samplingRatio The ratio of simpling used to infer the schema -# @return A DataFrame -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# sqlContext <- sparkRSQL.init(sc) -# rdd <- texFile(sc, "path/to/json") -# df <- jsonRDD(sqlContext, rdd) -# } + +#' JSON RDD +#' +#' Loads an RDD storing one JSON object per string as a DataFrame. +#' +#' @param sqlContext SQLContext to use +#' @param rdd An RDD of JSON string +#' @param schema A StructType object to use as schema +#' @param samplingRatio The ratio of simpling used to infer the schema +#' @return A DataFrame +#' @noRd +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' rdd <- texFile(sc, "path/to/json") +#' df <- jsonRDD(sqlContext, rdd) +#'} # TODO: support schema jsonRDD <- function(sqlContext, rdd, schema = NULL, samplingRatio = 1.0) { @@ -238,20 +265,32 @@ jsonRDD <- function(sqlContext, rdd, schema = NULL, samplingRatio = 1.0) { } } - #' Create a DataFrame from a Parquet file. #' #' Loads a Parquet file, returning the result as a DataFrame. #' #' @param sqlContext SQLContext to use -#' @param ... Path(s) of parquet file(s) to read. +#' @param path Path of file to read. A vector of multiple paths is allowed. #' @return DataFrame +#' @rdname read.parquet +#' @name read.parquet #' @export +read.parquet <- function(sqlContext, path) { + # Allow the user to have a more flexible definiton of the text file path + paths <- as.list(suppressWarnings(normalizePath(path))) + read <- callJMethod(sqlContext, "read") + sdf <- callJMethod(read, "parquet", paths) + dataFrame(sdf) +} +#' @rdname read.parquet +#' @name parquetFile +#' @export # TODO: Implement saveasParquetFile and write examples for both parquetFile <- function(sqlContext, ...) { + .Deprecated("read.parquet") # Allow the user to have a more flexible definiton of the text file path - paths <- lapply(list(...), normalizePath) + paths <- lapply(list(...), function(x) suppressWarnings(normalizePath(x))) sdf <- callJMethod(sqlContext, "parquetFile", paths) dataFrame(sdf) } @@ -269,7 +308,7 @@ parquetFile <- function(sqlContext, ...) 
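# The new reader functions route through sqlContext.read, while the old
# jsonFile()/parquetFile() now emit a deprecation warning and delegate to
# them. Paths are placeholders.
df_json    <- read.json(sqlContext, "path/to/file.json")
df_parquet <- read.parquet(sqlContext, "path/to/file.parquet")

# Deprecated spelling still works but warns:
old_json <- jsonFile(sqlContext, "path/to/file.json")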
{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' registerTempTable(df, "table") #' new_df <- sql(sqlContext, "SELECT * FROM table") #' } @@ -293,7 +332,7 @@ sql <- function(sqlContext, sqlQuery) { #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' registerTempTable(df, "table") #' new_df <- table(sqlContext, "table") #' } @@ -366,7 +405,7 @@ tableNames <- function(sqlContext, databaseName = NULL) { #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' registerTempTable(df, "table") #' cacheTable(sqlContext, "table") #' } @@ -388,7 +427,7 @@ cacheTable <- function(sqlContext, tableName) { #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' registerTempTable(df, "table") #' uncacheTable(sqlContext, "table") #' } @@ -444,14 +483,21 @@ dropTempTable <- function(sqlContext, tableName) { #' #' @param sqlContext SQLContext to use #' @param path The path of files to load -#' @param source the name of external data source +#' @param source The name of external data source +#' @param schema The data schema defined in structType #' @return DataFrame +#' @rdname read.df +#' @name read.df #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) -#' df <- read.df(sqlContext, "path/to/file.json", source = "json") +#' df1 <- read.df(sqlContext, "path/to/file.json", source = "json") +#' schema <- structType(structField("name", "string"), +#' structField("info", "map")) +#' df2 <- read.df(sqlContext, mapTypeJsonPath, "json", schema) +#' df3 <- loadDF(sqlContext, "data/test_table", "parquet", mergeSchema = "true") #' } read.df <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) { @@ -474,9 +520,8 @@ read.df <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) dataFrame(sdf) } -#' @aliases loadDF -#' @export - +#' @rdname read.df +#' @name loadDF loadDF <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) { read.df(sqlContext, path, source, schema, ...) 
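# A short sketch of read.df() with an explicit schema, following the roxygen
# example above; the source name and any extra options are passed through to
# the underlying data source. Paths are placeholders.
schema <- structType(structField("name", "string"),
                     structField("age", "double"))
people  <- read.df(sqlContext, "path/to/people.json", source = "json", schema = schema)

# loadDF() is kept as an alias of read.df() under the same @rdname:
people2 <- loadDF(sqlContext, "path/to/people.json", "json", schema)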
} diff --git a/R/pkg/R/broadcast.R b/R/pkg/R/broadcast.R index 2403925b267c..38f0eed95e06 100644 --- a/R/pkg/R/broadcast.R +++ b/R/pkg/R/broadcast.R @@ -51,7 +51,6 @@ Broadcast <- function(id, value, jBroadcastRef, objName) { # # @param bcast The broadcast variable to get # @rdname broadcast -# @aliases value,Broadcast-method setMethod("value", signature(bcast = "Broadcast"), function(bcast) { diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index c811d1dac3bd..25e99390a9c8 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -44,12 +44,16 @@ determineSparkSubmitBin <- function() { } generateSparkSubmitArgs <- function(args, sparkHome, jars, sparkSubmitOpts, packages) { + jars <- paste0(jars, collapse = ",") if (jars != "") { - jars <- paste("--jars", jars) + # construct the jars argument with a space between --jars and comma-separated values + jars <- paste0("--jars ", jars) } - if (!identical(packages, "")) { - packages <- paste("--packages", packages) + packages <- paste0(packages, collapse = ",") + if (packages != "") { + # construct the packages argument with a space between --packages and comma-separated values + packages <- paste0("--packages ", packages) } combinedArgs <- paste(jars, packages, sparkSubmitOpts, args, sep = " ") diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index eeaf9f193b72..7bb8ef2595b5 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -24,10 +24,9 @@ setOldClass("jobj") #' @title S4 class that represents a DataFrame column #' @description The column class supports unary, binary operations on DataFrame columns - #' @rdname column #' -#' @param jc reference to JVM DataFrame column +#' @slot jc reference to JVM DataFrame column #' @export setClass("Column", slots = list(jc = "jobj")) @@ -37,15 +36,14 @@ setMethod("initialize", "Column", function(.Object, jc) { .Object }) -column <- function(jc) { - new("Column", jc) -} - -col <- function(x) { - column(callJStatic("org.apache.spark.sql.functions", "col", x)) -} +setMethod("column", + signature(x = "jobj"), + function(x) { + new("Column", x) + }) #' @rdname show +#' @name show setMethod("show", "Column", function(object) { cat("Column", callJMethod(object@jc, "toString"), "\n") @@ -58,14 +56,8 @@ operators <- list( "&" = "and", "|" = "or", #, "!" 
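# Base-R sketch of the change in generateSparkSubmitArgs(): character vectors
# of jars (and packages) are first collapsed to a comma-separated list and
# only then prefixed with the "--jars"/"--packages" flag. Values are
# illustrative.
jars <- c("a.jar", "b.jar")
jars <- paste0(jars, collapse = ",")
if (jars != "") {
  jars <- paste0("--jars ", jars)
}
jars
# [1] "--jars a.jar,b.jar"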
= "unary_$bang" "^" = "pow" ) -column_functions1 <- c("asc", "desc", "isNull", "isNotNull") +column_functions1 <- c("asc", "desc", "isNaN", "isNull", "isNotNull") column_functions2 <- c("like", "rlike", "startsWith", "endsWith", "getField", "getItem", "contains") -functions <- c("min", "max", "sum", "avg", "mean", "count", "abs", "sqrt", - "first", "last", "lower", "upper", "sumDistinct", - "acos", "asin", "atan", "cbrt", "ceiling", "cos", "cosh", "exp", - "expm1", "floor", "log", "log10", "log1p", "rint", "sign", - "sin", "sinh", "tan", "tanh", "toDegrees", "toRadians") -binary_mathfunctions <- c("atan2", "hypot") createOperator <- function(op) { setMethod(op, @@ -111,33 +103,6 @@ createColumnFunction2 <- function(name) { }) } -createStaticFunction <- function(name) { - setMethod(name, - signature(x = "Column"), - function(x) { - if (name == "ceiling") { - name <- "ceil" - } - if (name == "sign") { - name <- "signum" - } - jc <- callJStatic("org.apache.spark.sql.functions", name, x@jc) - column(jc) - }) -} - -createBinaryMathfunctions <- function(name) { - setMethod(name, - signature(y = "Column"), - function(y, x) { - if (class(x) == "Column") { - x <- x@jc - } - jc <- callJStatic("org.apache.spark.sql.functions", name, y@jc, x) - column(jc) - }) -} - createMethods <- function() { for (op in names(operators)) { createOperator(op) @@ -148,12 +113,6 @@ createMethods <- function() { for (name in column_functions2) { createColumnFunction2(name) } - for (x in functions) { - createStaticFunction(x) - } - for (name in binary_mathfunctions) { - createBinaryMathfunctions(name) - } } createMethods() @@ -161,8 +120,11 @@ createMethods() #' alias #' #' Set a new name for a column - -#' @rdname column +#' +#' @rdname alias +#' @name alias +#' @family colum_func +#' @export setMethod("alias", signature(object = "Column"), function(object, data) { @@ -177,7 +139,9 @@ setMethod("alias", #' #' An expression that returns a substring. #' -#' @rdname column +#' @rdname substr +#' @name substr +#' @family colum_func #' #' @param start starting position #' @param stop ending position @@ -191,7 +155,9 @@ setMethod("substr", signature(x = "Column"), #' #' Test if the column is between the lower bound and upper bound, inclusive. #' -#' @rdname column +#' @rdname between +#' @name between +#' @family colum_func #' #' @param bounds lower and upper bounds setMethod("between", signature(x = "Column"), @@ -206,10 +172,11 @@ setMethod("between", signature(x = "Column"), #' Casts the column to a different data type. #' -#' @rdname column +#' @rdname cast +#' @name cast +#' @family colum_func #' -#' @examples -#' \dontrun{ +#' @examples \dontrun{ #' cast(df$age, "string") #' cast(df$name, list(type="array", elementType="byte", containsNull = TRUE)) #' } @@ -229,58 +196,36 @@ setMethod("cast", #' Match a column with given values. #' -#' @rdname column +#' @rdname match +#' @name %in% +#' @aliases %in% #' @return a matched values as a result of comparing with given values. +#' @export +#' @examples #' \dontrun{ -#' filter(df, "age in (10, 30)") -#' where(df, df$age %in% c(10, 30)) +#' filter(df, "age in (10, 30)") +#' where(df, df$age %in% c(10, 30)) #' } setMethod("%in%", signature(x = "Column"), function(x, table) { - table <- listToSeq(as.list(table)) - jc <- callJMethod(x@jc, "in", table) + jc <- callJMethod(x@jc, "in", as.list(table)) return(column(jc)) }) -#' Approx Count Distinct +#' otherwise #' -#' @rdname column -#' @return the approximate number of distinct items in a group. 
-setMethod("approxCountDistinct", - signature(x = "Column"), - function(x, rsd = 0.95) { - jc <- callJStatic("org.apache.spark.sql.functions", "approxCountDistinct", x@jc, rsd) - column(jc) - }) - -#' Count Distinct +#' If values in the specified column are null, returns the value. +#' Can be used in conjunction with `when` to specify a default value for expressions. #' -#' @rdname column -#' @return the number of distinct items in a group. -setMethod("countDistinct", - signature(x = "Column"), - function(x, ...) { - jcol <- lapply(list(...), function (x) { - x@jc - }) - jc <- callJStatic("org.apache.spark.sql.functions", "countDistinct", x@jc, - listToSeq(jcol)) +#' @rdname otherwise +#' @name otherwise +#' @family colum_func +#' @export +setMethod("otherwise", + signature(x = "Column", value = "ANY"), + function(x, value) { + value <- ifelse(class(value) == "Column", value@jc, value) + jc <- callJMethod(x@jc, "otherwise", value) column(jc) }) - -#' @rdname column -#' @aliases countDistinct -setMethod("n_distinct", - signature(x = "Column"), - function(x, ...) { - countDistinct(x, ...) - }) - -#' @rdname column -#' @aliases count -setMethod("n", - signature(x = "Column"), - function(x) { - count(x) - }) diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 720990e1c608..471bec1eacf0 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -25,23 +25,23 @@ getMinPartitions <- function(sc, minPartitions) { as.integer(minPartitions) } -# Create an RDD from a text file. -# -# This function reads a text file from HDFS, a local file system (available on all -# nodes), or any Hadoop-supported file system URI, and creates an -# RDD of strings from it. -# -# @param sc SparkContext to use -# @param path Path of file to read. A vector of multiple paths is allowed. -# @param minPartitions Minimum number of partitions to be created. If NULL, the default -# value is chosen based on available parallelism. -# @return RDD where each item is of type \code{character} -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# lines <- textFile(sc, "myfile.txt") -#} +#' Create an RDD from a text file. +#' +#' This function reads a text file from HDFS, a local file system (available on all +#' nodes), or any Hadoop-supported file system URI, and creates an +#' RDD of strings from it. +#' +#' @param sc SparkContext to use +#' @param path Path of file to read. A vector of multiple paths is allowed. +#' @param minPartitions Minimum number of partitions to be created. If NULL, the default +#' value is chosen based on available parallelism. +#' @return RDD where each item is of type \code{character} +#' @noRd +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' lines <- textFile(sc, "myfile.txt") +#'} textFile <- function(sc, path, minPartitions = NULL) { # Allow the user to have a more flexible definiton of the text file path path <- suppressWarnings(normalizePath(path)) @@ -53,23 +53,23 @@ textFile <- function(sc, path, minPartitions = NULL) { RDD(jrdd, "string") } -# Load an RDD saved as a SequenceFile containing serialized objects. -# -# The file to be loaded should be one that was previously generated by calling -# saveAsObjectFile() of the RDD class. -# -# @param sc SparkContext to use -# @param path Path of file to read. A vector of multiple paths is allowed. -# @param minPartitions Minimum number of partitions to be created. If NULL, the default -# value is chosen based on available parallelism. -# @return RDD containing serialized R objects. 
-# @seealso saveAsObjectFile -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- objectFile(sc, "myfile") -#} +#' Load an RDD saved as a SequenceFile containing serialized objects. +#' +#' The file to be loaded should be one that was previously generated by calling +#' saveAsObjectFile() of the RDD class. +#' +#' @param sc SparkContext to use +#' @param path Path of file to read. A vector of multiple paths is allowed. +#' @param minPartitions Minimum number of partitions to be created. If NULL, the default +#' value is chosen based on available parallelism. +#' @return RDD containing serialized R objects. +#' @seealso saveAsObjectFile +#' @noRd +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- objectFile(sc, "myfile") +#'} objectFile <- function(sc, path, minPartitions = NULL) { # Allow the user to have a more flexible definiton of the text file path path <- suppressWarnings(normalizePath(path)) @@ -81,24 +81,24 @@ objectFile <- function(sc, path, minPartitions = NULL) { RDD(jrdd, "byte") } -# Create an RDD from a homogeneous list or vector. -# -# This function creates an RDD from a local homogeneous list in R. The elements -# in the list are split into \code{numSlices} slices and distributed to nodes -# in the cluster. -# -# @param sc SparkContext to use -# @param coll collection to parallelize -# @param numSlices number of partitions to create in the RDD -# @return an RDD created from this collection -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10, 2) -# # The RDD should contain 10 elements -# length(rdd) -#} +#' Create an RDD from a homogeneous list or vector. +#' +#' This function creates an RDD from a local homogeneous list in R. The elements +#' in the list are split into \code{numSlices} slices and distributed to nodes +#' in the cluster. +#' +#' @param sc SparkContext to use +#' @param coll collection to parallelize +#' @param numSlices number of partitions to create in the RDD +#' @return an RDD created from this collection +#' @noRd +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2) +#' # The RDD should contain 10 elements +#' length(rdd) +#'} parallelize <- function(sc, coll, numSlices = 1) { # TODO: bound/safeguard numSlices # TODO: unit tests for if the split works for all primitives @@ -133,33 +133,32 @@ parallelize <- function(sc, coll, numSlices = 1) { RDD(jrdd, "byte") } -# Include this specified package on all workers -# -# This function can be used to include a package on all workers before the -# user's code is executed. This is useful in scenarios where other R package -# functions are used in a function passed to functions like \code{lapply}. -# NOTE: The package is assumed to be installed on every node in the Spark -# cluster. -# -# @param sc SparkContext to use -# @param pkg Package name -# -# @export -# @examples -#\dontrun{ -# library(Matrix) -# -# sc <- sparkR.init() -# # Include the matrix library we will be using -# includePackage(sc, Matrix) -# -# generateSparse <- function(x) { -# sparseMatrix(i=c(1, 2, 3), j=c(1, 2, 3), x=c(1, 2, 3)) -# } -# -# rdd <- lapplyPartition(parallelize(sc, 1:2, 2L), generateSparse) -# collect(rdd) -#} +#' Include this specified package on all workers +#' +#' This function can be used to include a package on all workers before the +#' user's code is executed. This is useful in scenarios where other R package +#' functions are used in a function passed to functions like \code{lapply}. 
+#' NOTE: The package is assumed to be installed on every node in the Spark +#' cluster. +#' +#' @param sc SparkContext to use +#' @param pkg Package name +#' @noRd +#' @examples +#'\dontrun{ +#' library(Matrix) +#' +#' sc <- sparkR.init() +#' # Include the matrix library we will be using +#' includePackage(sc, Matrix) +#' +#' generateSparse <- function(x) { +#' sparseMatrix(i=c(1, 2, 3), j=c(1, 2, 3), x=c(1, 2, 3)) +#' } +#' +#' rdd <- lapplyPartition(parallelize(sc, 1:2, 2L), generateSparse) +#' collect(rdd) +#'} includePackage <- function(sc, pkg) { pkg <- as.character(substitute(pkg)) if (exists(".packages", .sparkREnv)) { @@ -171,30 +170,30 @@ includePackage <- function(sc, pkg) { .sparkREnv$.packages <- packages } -# @title Broadcast a variable to all workers -# -# @description -# Broadcast a read-only variable to the cluster, returning a \code{Broadcast} -# object for reading it in distributed functions. -# -# @param sc Spark Context to use -# @param object Object to be broadcast -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:2, 2L) -# -# # Large Matrix object that we want to broadcast -# randomMat <- matrix(nrow=100, ncol=10, data=rnorm(1000)) -# randomMatBr <- broadcast(sc, randomMat) -# -# # Use the broadcast variable inside the function -# useBroadcast <- function(x) { -# sum(value(randomMatBr) * x) -# } -# sumRDD <- lapply(rdd, useBroadcast) -#} +#' @title Broadcast a variable to all workers +#' +#' @description +#' Broadcast a read-only variable to the cluster, returning a \code{Broadcast} +#' object for reading it in distributed functions. +#' +#' @param sc Spark Context to use +#' @param object Object to be broadcast +#' @noRd +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:2, 2L) +#' +#' # Large Matrix object that we want to broadcast +#' randomMat <- matrix(nrow=100, ncol=10, data=rnorm(1000)) +#' randomMatBr <- broadcast(sc, randomMat) +#' +#' # Use the broadcast variable inside the function +#' useBroadcast <- function(x) { +#' sum(value(randomMatBr) * x) +#' } +#' sumRDD <- lapply(rdd, useBroadcast) +#'} broadcast <- function(sc, object) { objName <- as.character(substitute(object)) serializedObj <- serialize(object, connection = NULL) @@ -205,21 +204,21 @@ broadcast <- function(sc, object) { Broadcast(id, object, jBroadcast, objName) } -# @title Set the checkpoint directory -# -# Set the directory under which RDDs are going to be checkpointed. The -# directory must be a HDFS path if running on a cluster. -# -# @param sc Spark Context to use -# @param dirName Directory path -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# setCheckpointDir(sc, "~/checkpoint") -# rdd <- parallelize(sc, 1:2, 2L) -# checkpoint(rdd) -#} +#' @title Set the checkpoint directory +#' +#' Set the directory under which RDDs are going to be checkpointed. The +#' directory must be a HDFS path if running on a cluster. 
+#' +#' @param sc Spark Context to use +#' @param dirName Directory path +#' @noRd +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' setCheckpointDir(sc, "~/checkpoint") +#' rdd <- parallelize(sc, 1:2, 2L) +#' checkpoint(rdd) +#'} setCheckpointDir <- function(sc, dirName) { invisible(callJMethod(sc, "setCheckpointDir", suppressWarnings(normalizePath(dirName)))) } diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index 6d364f77be7e..f7e56e43016e 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -48,7 +48,10 @@ readTypedObject <- function(con, type) { "r" = readRaw(con), "D" = readDate(con), "t" = readTime(con), + "a" = readArray(con), "l" = readList(con), + "e" = readEnv(con), + "s" = readStruct(con), "n" = NULL, "j" = getJobj(readString(con)), stop(paste("Unsupported type for deserialization", type))) @@ -56,8 +59,10 @@ readTypedObject <- function(con, type) { readString <- function(con) { stringLen <- readInt(con) - string <- readBin(con, raw(), stringLen, endian = "big") - rawToChar(string) + raw <- readBin(con, raw(), stringLen, endian = "big") + string <- rawToChar(raw) + Encoding(string) <- "UTF-8" + string } readInt <- function(con) { @@ -85,8 +90,7 @@ readTime <- function(con) { as.POSIXct(t, origin = "1970-01-01") } -# We only support lists where all elements are of same type -readList <- function(con) { +readArray <- function(con) { type <- readType(con) len <- readInt(con) if (len > 0) { @@ -100,6 +104,47 @@ readList <- function(con) { } } +# Read a list. Types of each element may be different. +# Null objects are read as NA. +readList <- function(con) { + len <- readInt(con) + if (len > 0) { + l <- vector("list", len) + for (i in 1:len) { + elem <- readObject(con) + if (is.null(elem)) { + elem <- NA + } + l[[i]] <- elem + } + l + } else { + list() + } +} + +readEnv <- function(con) { + env <- new.env() + len <- readInt(con) + if (len > 0) { + for (i in 1:len) { + key <- readString(con) + value <- readObject(con) + env[[key]] <- value + } + } + env +} + +# Read a field of StructType from DataFrame +# into a named list in R whose class is "struct" +readStruct <- function(con) { + names <- readObject(con) + fields <- readObject(con) + names(fields) <- names + listToStruct(fields) +} + readRaw <- function(con) { dataLen <- readInt(con) readBin(con, raw(), as.integer(dataLen), endian = "big") @@ -132,18 +177,19 @@ readDeserialize <- function(con) { } } -readDeserializeRows <- function(inputCon) { - # readDeserializeRows will deserialize a DataOutputStream composed of - # a list of lists. Since the DOS is one continuous stream and - # the number of rows varies, we put the readRow function in a while loop - # that termintates when the next row is empty. +readMultipleObjects <- function(inputCon) { + # readMultipleObjects will read multiple continuous objects from + # a DataOutputStream. There is no preceding field telling the count + # of the objects, so the number of objects varies, we try to read + # all objects in a loop until the end of the stream. data <- list() while(TRUE) { - row <- readRow(inputCon) - if (length(row) == 0) { + # If reaching the end of the stream, type returned should be "". + type <- readType(inputCon) + if (type == "") { break } - data[[length(data) + 1L]] <- row + data[[length(data) + 1L]] <- readTypedObject(inputCon, type) } data # this is a list of named lists now } @@ -155,31 +201,5 @@ readRowList <- function(obj) { # deserialize the row. 
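# A self-contained base-R sketch of the length-prefixed framing that
# readString() above relies on (a big-endian int length followed by UTF-8
# bytes). This is a standalone illustration, not the SparkR implementation.
buf <- rawConnection(raw(0), "r+")
payload <- charToRaw(enc2utf8("héllo"))
writeBin(length(payload), buf, endian = "big")   # length prefix
writeBin(payload, buf)                           # string bytes
seek(buf, 0)

len <- readBin(buf, integer(), n = 1, endian = "big")
s <- rawToChar(readBin(buf, raw(), len, endian = "big"))
Encoding(s) <- "UTF-8"                           # mirrors the readString() change
s
close(buf)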
rawObj <- rawConnection(obj, "r+") on.exit(close(rawObj)) - readRow(rawObj) -} - -readRow <- function(inputCon) { - numCols <- readInt(inputCon) - if (length(numCols) > 0 && numCols > 0) { - lapply(1:numCols, function(x) { - obj <- readObject(inputCon) - if (is.null(obj)) { - NA - } else { - obj - } - }) # each row is a list now - } else { - list() - } -} - -# Take a single column as Array[Byte] and deserialize it into an atomic vector -readCol <- function(inputCon, numRows) { - # sapply can not work with POSIXlt - do.call(c, lapply(1:numRows, function(x) { - value <- readObject(inputCon) - # Replace NULL with NA so we can coerce to vectors - if (is.null(value)) NA else value - })) + readObject(rawObj) } diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R new file mode 100644 index 000000000000..09e4e04335a3 --- /dev/null +++ b/R/pkg/R/functions.R @@ -0,0 +1,2559 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +#' @include generics.R column.R +NULL + +#' lit +#' +#' A new \linkS4class{Column} is created to represent the literal value. +#' If the parameter is a \linkS4class{Column}, it is returned unchanged. +#' +#' @family normal_funcs +#' @rdname lit +#' @name lit +#' @export +#' @examples +#' \dontrun{ +#' lit(df$name) +#' select(df, lit("x")) +#' select(df, lit("2015-01-01")) +#'} +setMethod("lit", signature("ANY"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", + "lit", + ifelse(class(x) == "Column", x@jc, x)) + column(jc) + }) + +#' abs +#' +#' Computes the absolute value. +#' +#' @rdname abs +#' @name abs +#' @family normal_funcs +#' @export +#' @examples \dontrun{abs(df$c)} +setMethod("abs", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "abs", x@jc) + column(jc) + }) + +#' acos +#' +#' Computes the cosine inverse of the given value; the returned angle is in the range +#' 0.0 through pi. +#' +#' @rdname acos +#' @name acos +#' @family math_funcs +#' @export +#' @examples \dontrun{acos(df$c)} +setMethod("acos", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "acos", x@jc) + column(jc) + }) + +#' approxCountDistinct +#' +#' Aggregate function: returns the approximate number of distinct items in a group. +#' +#' @rdname approxCountDistinct +#' @name approxCountDistinct +#' @family agg_funcs +#' @export +#' @examples \dontrun{approxCountDistinct(df$c)} +setMethod("approxCountDistinct", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "approxCountDistinct", x@jc) + column(jc) + }) + +#' ascii +#' +#' Computes the numeric value of the first character of the string column, and returns the +#' result as a int column. 
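# lit() lifts an R literal into a Column; a Column argument is passed through
# unchanged, so it composes with the operators defined in column.R. df is an
# assumed DataFrame.
select(df, df$name, alias(lit("2015-01-01"), "snapshot"))
select(df, df$age + lit(1))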
+#' +#' @rdname ascii +#' @name ascii +#' @family string_funcs +#' @export +#' @examples \dontrun{\dontrun{ascii(df$c)}} +setMethod("ascii", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "ascii", x@jc) + column(jc) + }) + +#' asin +#' +#' Computes the sine inverse of the given value; the returned angle is in the range +#' -pi/2 through pi/2. +#' +#' @rdname asin +#' @name asin +#' @family math_funcs +#' @export +#' @examples \dontrun{asin(df$c)} +setMethod("asin", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "asin", x@jc) + column(jc) + }) + +#' atan +#' +#' Computes the tangent inverse of the given value. +#' +#' @rdname atan +#' @name atan +#' @family math_funcs +#' @export +#' @examples \dontrun{atan(df$c)} +setMethod("atan", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "atan", x@jc) + column(jc) + }) + +#' avg +#' +#' Aggregate function: returns the average of the values in a group. +#' +#' @rdname avg +#' @name avg +#' @family agg_funcs +#' @export +#' @examples \dontrun{avg(df$c)} +setMethod("avg", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "avg", x@jc) + column(jc) + }) + +#' base64 +#' +#' Computes the BASE64 encoding of a binary column and returns it as a string column. +#' This is the reverse of unbase64. +#' +#' @rdname base64 +#' @name base64 +#' @family string_funcs +#' @export +#' @examples \dontrun{base64(df$c)} +setMethod("base64", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "base64", x@jc) + column(jc) + }) + +#' bin +#' +#' An expression that returns the string representation of the binary value of the given long +#' column. For example, bin("12") returns "1100". +#' +#' @rdname bin +#' @name bin +#' @family math_funcs +#' @export +#' @examples \dontrun{bin(df$c)} +setMethod("bin", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "bin", x@jc) + column(jc) + }) + +#' bitwiseNOT +#' +#' Computes bitwise NOT. +#' +#' @rdname bitwiseNOT +#' @name bitwiseNOT +#' @family normal_funcs +#' @export +#' @examples \dontrun{bitwiseNOT(df$c)} +setMethod("bitwiseNOT", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "bitwiseNOT", x@jc) + column(jc) + }) + +#' cbrt +#' +#' Computes the cube-root of the given value. +#' +#' @rdname cbrt +#' @name cbrt +#' @family math_funcs +#' @export +#' @examples \dontrun{cbrt(df$c)} +setMethod("cbrt", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "cbrt", x@jc) + column(jc) + }) + +#' ceil +#' +#' Computes the ceiling of the given value. +#' +#' @rdname ceil +#' @name ceil +#' @family math_funcs +#' @export +#' @examples \dontrun{ceil(df$c)} +setMethod("ceil", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "ceil", x@jc) + column(jc) + }) + +#' Though scala functions has "col" function, we don't expose it in SparkR +#' because we don't want to conflict with the "col" function in the R base +#' package and we also have "column" function exported which is an alias of "col". +col <- function(x) { + column(callJStatic("org.apache.spark.sql.functions", "col", x)) +} + +#' column +#' +#' Returns a Column based on the given column name. 
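# The single-argument wrappers above (and those that follow) all share one
# shape: call the static function in org.apache.spark.sql.functions and wrap
# the returned JVM column. Typical use on an assumed DataFrame df:
select(df, column("name"), ascii(df$name), ceil(df$age))
agg(df, avg(df$age), approxCountDistinct(df$name))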
+#' +#' @rdname col +#' @name column +#' @family normal_funcs +#' @export +#' @examples \dontrun{column(df)} +setMethod("column", + signature(x = "character"), + function(x) { + col(x) + }) +#' corr +#' +#' Computes the Pearson Correlation Coefficient for two Columns. +#' +#' @rdname corr +#' @name corr +#' @family math_funcs +#' @export +#' @examples \dontrun{corr(df$c, df$d)} +setMethod("corr", signature(x = "Column"), + function(x, col2) { + stopifnot(class(col2) == "Column") + jc <- callJStatic("org.apache.spark.sql.functions", "corr", x@jc, col2@jc) + column(jc) + }) + +#' cos +#' +#' Computes the cosine of the given value. +#' +#' @rdname cos +#' @name cos +#' @family math_funcs +#' @export +#' @examples \dontrun{cos(df$c)} +setMethod("cos", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "cos", x@jc) + column(jc) + }) + +#' cosh +#' +#' Computes the hyperbolic cosine of the given value. +#' +#' @rdname cosh +#' @name cosh +#' @family math_funcs +#' @export +#' @examples \dontrun{cosh(df$c)} +setMethod("cosh", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "cosh", x@jc) + column(jc) + }) + +#' count +#' +#' Aggregate function: returns the number of items in a group. +#' +#' @rdname count +#' @name count +#' @family agg_funcs +#' @export +#' @examples \dontrun{count(df$c)} +setMethod("count", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "count", x@jc) + column(jc) + }) + +#' crc32 +#' +#' Calculates the cyclic redundancy check value (CRC32) of a binary column and +#' returns the value as a bigint. +#' +#' @rdname crc32 +#' @name crc32 +#' @family misc_funcs +#' @export +#' @examples \dontrun{crc32(df$c)} +setMethod("crc32", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "crc32", x@jc) + column(jc) + }) + +#' dayofmonth +#' +#' Extracts the day of the month as an integer from a given date/timestamp/string. +#' +#' @rdname dayofmonth +#' @name dayofmonth +#' @family datetime_funcs +#' @export +#' @examples \dontrun{dayofmonth(df$c)} +setMethod("dayofmonth", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "dayofmonth", x@jc) + column(jc) + }) + +#' dayofyear +#' +#' Extracts the day of the year as an integer from a given date/timestamp/string. +#' +#' @rdname dayofyear +#' @name dayofyear +#' @family datetime_funcs +#' @export +#' @examples \dontrun{dayofyear(df$c)} +setMethod("dayofyear", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "dayofyear", x@jc) + column(jc) + }) + +#' decode +#' +#' Computes the first argument into a string from a binary using the provided character set +#' (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). +#' +#' @rdname decode +#' @name decode +#' @family string_funcs +#' @export +#' @examples \dontrun{decode(df$c, "UTF-8")} +setMethod("decode", + signature(x = "Column", charset = "character"), + function(x, charset) { + jc <- callJStatic("org.apache.spark.sql.functions", "decode", x@jc, charset) + column(jc) + }) + +#' encode +#' +#' Computes the first argument into a binary from a string using the provided character set +#' (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). 
+#' +#' @rdname encode +#' @name encode +#' @family string_funcs +#' @export +#' @examples \dontrun{encode(df$c, "UTF-8")} +setMethod("encode", + signature(x = "Column", charset = "character"), + function(x, charset) { + jc <- callJStatic("org.apache.spark.sql.functions", "encode", x@jc, charset) + column(jc) + }) + +#' exp +#' +#' Computes the exponential of the given value. +#' +#' @rdname exp +#' @name exp +#' @family math_funcs +#' @export +#' @examples \dontrun{exp(df$c)} +setMethod("exp", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "exp", x@jc) + column(jc) + }) + +#' expm1 +#' +#' Computes the exponential of the given value minus one. +#' +#' @rdname expm1 +#' @name expm1 +#' @family math_funcs +#' @export +#' @examples \dontrun{expm1(df$c)} +setMethod("expm1", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "expm1", x@jc) + column(jc) + }) + +#' factorial +#' +#' Computes the factorial of the given value. +#' +#' @rdname factorial +#' @name factorial +#' @family math_funcs +#' @export +#' @examples \dontrun{factorial(df$c)} +setMethod("factorial", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "factorial", x@jc) + column(jc) + }) + +#' first +#' +#' Aggregate function: returns the first value in a group. +#' +#' @rdname first +#' @name first +#' @family agg_funcs +#' @export +#' @examples \dontrun{first(df$c)} +setMethod("first", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "first", x@jc) + column(jc) + }) + +#' floor +#' +#' Computes the floor of the given value. +#' +#' @rdname floor +#' @name floor +#' @family math_funcs +#' @export +#' @examples \dontrun{floor(df$c)} +setMethod("floor", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "floor", x@jc) + column(jc) + }) + +#' hex +#' +#' Computes hex value of the given column. +#' +#' @rdname hex +#' @name hex +#' @family math_funcs +#' @export +#' @examples \dontrun{hex(df$c)} +setMethod("hex", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "hex", x@jc) + column(jc) + }) + +#' hour +#' +#' Extracts the hours as an integer from a given date/timestamp/string. +#' +#' @rdname hour +#' @name hour +#' @family datetime_funcs +#' @export +#' @examples \dontrun{hour(df$c)} +setMethod("hour", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "hour", x@jc) + column(jc) + }) + +#' initcap +#' +#' Returns a new string column by converting the first letter of each word to uppercase. +#' Words are delimited by whitespace. +#' +#' For example, "hello world" will become "Hello World". 
+#' +#' @rdname initcap +#' @name initcap +#' @family string_funcs +#' @export +#' @examples \dontrun{initcap(df$c)} +setMethod("initcap", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "initcap", x@jc) + column(jc) + }) + +#' is.nan +#' +#' Return true if the column is NaN, alias for \link{isnan} +#' +#' @rdname is.nan +#' @name is.nan +#' @family normal_funcs +#' @export +#' @examples +#' \dontrun{ +#' is.nan(df$c) +#' isnan(df$c) +#' } +setMethod("is.nan", + signature(x = "Column"), + function(x) { + isnan(x) + }) + +#' @rdname is.nan +#' @name isnan +setMethod("isnan", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "isnan", x@jc) + column(jc) + }) + +#' kurtosis +#' +#' Aggregate function: returns the kurtosis of the values in a group. +#' +#' @rdname kurtosis +#' @name kurtosis +#' @family agg_funcs +#' @export +#' @examples \dontrun{kurtosis(df$c)} +setMethod("kurtosis", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "kurtosis", x@jc) + column(jc) + }) + +#' last +#' +#' Aggregate function: returns the last value in a group. +#' +#' @rdname last +#' @name last +#' @family agg_funcs +#' @export +#' @examples \dontrun{last(df$c)} +setMethod("last", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "last", x@jc) + column(jc) + }) + +#' last_day +#' +#' Given a date column, returns the last day of the month which the given date belongs to. +#' For example, input "2015-07-27" returns "2015-07-31" since July 31 is the last day of the +#' month in July 2015. +#' +#' @rdname last_day +#' @name last_day +#' @family datetime_funcs +#' @export +#' @examples \dontrun{last_day(df$c)} +setMethod("last_day", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "last_day", x@jc) + column(jc) + }) + +#' length +#' +#' Computes the length of a given string or binary column. +#' +#' @rdname length +#' @name length +#' @family string_funcs +#' @export +#' @examples \dontrun{length(df$c)} +setMethod("length", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "length", x@jc) + column(jc) + }) + +#' log +#' +#' Computes the natural logarithm of the given value. +#' +#' @rdname log +#' @name log +#' @family math_funcs +#' @export +#' @examples \dontrun{log(df$c)} +setMethod("log", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "log", x@jc) + column(jc) + }) + +#' log10 +#' +#' Computes the logarithm of the given value in base 10. +#' +#' @rdname log10 +#' @name log10 +#' @family math_funcs +#' @export +#' @examples \dontrun{log10(df$c)} +setMethod("log10", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "log10", x@jc) + column(jc) + }) + +#' log1p +#' +#' Computes the natural logarithm of the given value plus one. +#' +#' @rdname log1p +#' @name log1p +#' @family math_funcs +#' @export +#' @examples \dontrun{log1p(df$c)} +setMethod("log1p", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "log1p", x@jc) + column(jc) + }) + +#' log2 +#' +#' Computes the logarithm of the given column in base 2. 
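# is.nan() is the R-style alias of isnan(); both return a boolean Column.
# The aggregate and datetime helpers nearby follow the usual wrapper pattern.
# Columns "score", "d" and "name" are assumptions about df.
where(df, is.nan(df$score))
agg(df, kurtosis(df$score), last(df$score))
select(df, last_day(df$d), length(df$name))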
+#' +#' @rdname log2 +#' @name log2 +#' @family math_funcs +#' @export +#' @examples \dontrun{log2(df$c)} +setMethod("log2", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "log2", x@jc) + column(jc) + }) + +#' lower +#' +#' Converts a string column to lower case. +#' +#' @rdname lower +#' @name lower +#' @family string_funcs +#' @export +#' @examples \dontrun{lower(df$c)} +setMethod("lower", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "lower", x@jc) + column(jc) + }) + +#' ltrim +#' +#' Trim the spaces from left end for the specified string value. +#' +#' @rdname ltrim +#' @name ltrim +#' @family string_funcs +#' @export +#' @examples \dontrun{ltrim(df$c)} +setMethod("ltrim", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "ltrim", x@jc) + column(jc) + }) + +#' max +#' +#' Aggregate function: returns the maximum value of the expression in a group. +#' +#' @rdname max +#' @name max +#' @family agg_funcs +#' @export +#' @examples \dontrun{max(df$c)} +setMethod("max", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "max", x@jc) + column(jc) + }) + +#' md5 +#' +#' Calculates the MD5 digest of a binary column and returns the value +#' as a 32 character hex string. +#' +#' @rdname md5 +#' @name md5 +#' @family misc_funcs +#' @export +#' @examples \dontrun{md5(df$c)} +setMethod("md5", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "md5", x@jc) + column(jc) + }) + +#' mean +#' +#' Aggregate function: returns the average of the values in a group. +#' Alias for avg. +#' +#' @rdname mean +#' @name mean +#' @family agg_funcs +#' @export +#' @examples \dontrun{mean(df$c)} +setMethod("mean", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "mean", x@jc) + column(jc) + }) + +#' min +#' +#' Aggregate function: returns the minimum value of the expression in a group. +#' +#' @rdname min +#' @name min +#' @family agg_funcs +#' @export +#' @examples \dontrun{min(df$c)} +setMethod("min", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "min", x@jc) + column(jc) + }) + +#' minute +#' +#' Extracts the minutes as an integer from a given date/timestamp/string. +#' +#' @rdname minute +#' @name minute +#' @family datetime_funcs +#' @export +#' @examples \dontrun{minute(df$c)} +setMethod("minute", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "minute", x@jc) + column(jc) + }) + +#' month +#' +#' Extracts the month as an integer from a given date/timestamp/string. +#' +#' @rdname month +#' @name month +#' @family datetime_funcs +#' @export +#' @examples \dontrun{month(df$c)} +setMethod("month", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "month", x@jc) + column(jc) + }) + +#' negate +#' +#' Unary minus, i.e. negate the expression. +#' +#' @rdname negate +#' @name negate +#' @family normal_funcs +#' @export +#' @examples \dontrun{negate(df$c)} +setMethod("negate", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "negate", x@jc) + column(jc) + }) + +#' quarter +#' +#' Extracts the quarter as an integer from a given date/timestamp/string. 
+#' +#' @rdname quarter +#' @name quarter +#' @family datetime_funcs +#' @export +#' @examples \dontrun{quarter(df$c)} +setMethod("quarter", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "quarter", x@jc) + column(jc) + }) + +#' reverse +#' +#' Reverses the string column and returns it as a new string column. +#' +#' @rdname reverse +#' @name reverse +#' @family string_funcs +#' @export +#' @examples \dontrun{reverse(df$c)} +setMethod("reverse", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "reverse", x@jc) + column(jc) + }) + +#' rint +#' +#' Returns the double value that is closest in value to the argument and +#' is equal to a mathematical integer. +#' +#' @rdname rint +#' @name rint +#' @family math_funcs +#' @export +#' @examples \dontrun{rint(df$c)} +setMethod("rint", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "rint", x@jc) + column(jc) + }) + +#' round +#' +#' Returns the value of the column `e` rounded to 0 decimal places. +#' +#' @rdname round +#' @name round +#' @family math_funcs +#' @export +#' @examples \dontrun{round(df$c)} +setMethod("round", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "round", x@jc) + column(jc) + }) + +#' rtrim +#' +#' Trim the spaces from right end for the specified string value. +#' +#' @rdname rtrim +#' @name rtrim +#' @family string_funcs +#' @export +#' @examples \dontrun{rtrim(df$c)} +setMethod("rtrim", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "rtrim", x@jc) + column(jc) + }) + +#' sd +#' +#' Aggregate function: alias for \link{stddev_samp} +#' +#' @rdname sd +#' @name sd +#' @family agg_funcs +#' @seealso \link{stddev_pop}, \link{stddev_samp} +#' @export +#' @examples +#'\dontrun{ +#'stddev(df$c) +#'select(df, stddev(df$age)) +#'agg(df, sd(df$age)) +#'} +setMethod("sd", + signature(x = "Column"), + function(x) { + # In R, sample standard deviation is calculated with the sd() function. + stddev_samp(x) + }) + +#' second +#' +#' Extracts the seconds as an integer from a given date/timestamp/string. +#' +#' @rdname second +#' @name second +#' @family datetime_funcs +#' @export +#' @examples \dontrun{second(df$c)} +setMethod("second", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "second", x@jc) + column(jc) + }) + +#' sha1 +#' +#' Calculates the SHA-1 digest of a binary column and returns the value +#' as a 40 character hex string. +#' +#' @rdname sha1 +#' @name sha1 +#' @family misc_funcs +#' @export +#' @examples \dontrun{sha1(df$c)} +setMethod("sha1", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "sha1", x@jc) + column(jc) + }) + +#' signum +#' +#' Computes the signum of the given value. +#' +#' @rdname signum +#' @name signum +#' @family math_funcs +#' @export +#' @examples \dontrun{signum(df$c)} +setMethod("signum", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "signum", x@jc) + column(jc) + }) + +#' sin +#' +#' Computes the sine of the given value. 
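# sd() keeps base-R semantics: it is the sample standard deviation, i.e. an
# alias for stddev_samp() defined just below; the rounding and sign helpers
# are the usual one-argument wrappers. "score" is an assumed numeric column.
select(df, round(df$score), rint(df$score), signum(df$score))
agg(df, sd(df$age))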
+#' +#' @rdname sin +#' @name sin +#' @family math_funcs +#' @export +#' @examples \dontrun{sin(df$c)} +setMethod("sin", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "sin", x@jc) + column(jc) + }) + +#' sinh +#' +#' Computes the hyperbolic sine of the given value. +#' +#' @rdname sinh +#' @name sinh +#' @family math_funcs +#' @export +#' @examples \dontrun{sinh(df$c)} +setMethod("sinh", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "sinh", x@jc) + column(jc) + }) + +#' skewness +#' +#' Aggregate function: returns the skewness of the values in a group. +#' +#' @rdname skewness +#' @name skewness +#' @family agg_funcs +#' @export +#' @examples \dontrun{skewness(df$c)} +setMethod("skewness", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "skewness", x@jc) + column(jc) + }) + +#' soundex +#' +#' Return the soundex code for the specified expression. +#' +#' @rdname soundex +#' @name soundex +#' @family string_funcs +#' @export +#' @examples \dontrun{soundex(df$c)} +setMethod("soundex", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "soundex", x@jc) + column(jc) + }) + +#' @rdname sd +#' @name stddev +setMethod("stddev", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "stddev", x@jc) + column(jc) + }) + +#' stddev_pop +#' +#' Aggregate function: returns the population standard deviation of the expression in a group. +#' +#' @rdname stddev_pop +#' @name stddev_pop +#' @family agg_funcs +#' @seealso \link{sd}, \link{stddev_samp} +#' @export +#' @examples \dontrun{stddev_pop(df$c)} +setMethod("stddev_pop", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "stddev_pop", x@jc) + column(jc) + }) + +#' stddev_samp +#' +#' Aggregate function: returns the unbiased sample standard deviation of the expression in a group. +#' +#' @rdname stddev_samp +#' @name stddev_samp +#' @family agg_funcs +#' @seealso \link{stddev_pop}, \link{sd} +#' @export +#' @examples \dontrun{stddev_samp(df$c)} +setMethod("stddev_samp", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "stddev_samp", x@jc) + column(jc) + }) + +#' struct +#' +#' Creates a new struct column that composes multiple input columns. +#' +#' @rdname struct +#' @name struct +#' @family normal_funcs +#' @export +#' @examples +#' \dontrun{ +#' struct(df$c, df$d) +#' struct("col1", "col2") +#' } +setMethod("struct", + signature(x = "characterOrColumn"), + function(x, ...) { + if (class(x) == "Column") { + jcols <- lapply(list(x, ...), function(x) { x@jc }) + jc <- callJStatic("org.apache.spark.sql.functions", "struct", jcols) + } else { + jc <- callJStatic("org.apache.spark.sql.functions", "struct", x, list(...)) + } + column(jc) + }) + +#' sqrt +#' +#' Computes the square root of the specified float value. +#' +#' @rdname sqrt +#' @name sqrt +#' @family math_funcs +#' @export +#' @examples \dontrun{sqrt(df$c)} +setMethod("sqrt", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "sqrt", x@jc) + column(jc) + }) + +#' sum +#' +#' Aggregate function: returns the sum of all values in the expression. 
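# struct() accepts either Columns or column-name strings, matching the two
# branches of the method above, and the stddev variants expose population vs.
# sample standard deviation explicitly.
select(df, alias(struct(df$name, df$age), "person"))
select(df, struct("name", "age"))
agg(df, stddev_pop(df$age), stddev_samp(df$age))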
+#' +#' @rdname sum +#' @name sum +#' @family agg_funcs +#' @export +#' @examples \dontrun{sum(df$c)} +setMethod("sum", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "sum", x@jc) + column(jc) + }) + +#' sumDistinct +#' +#' Aggregate function: returns the sum of distinct values in the expression. +#' +#' @rdname sumDistinct +#' @name sumDistinct +#' @family agg_funcs +#' @export +#' @examples \dontrun{sumDistinct(df$c)} +setMethod("sumDistinct", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "sumDistinct", x@jc) + column(jc) + }) + +#' tan +#' +#' Computes the tangent of the given value. +#' +#' @rdname tan +#' @name tan +#' @family math_funcs +#' @export +#' @examples \dontrun{tan(df$c)} +setMethod("tan", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "tan", x@jc) + column(jc) + }) + +#' tanh +#' +#' Computes the hyperbolic tangent of the given value. +#' +#' @rdname tanh +#' @name tanh +#' @family math_funcs +#' @export +#' @examples \dontrun{tanh(df$c)} +setMethod("tanh", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "tanh", x@jc) + column(jc) + }) + +#' toDegrees +#' +#' Converts an angle measured in radians to an approximately equivalent angle measured in degrees. +#' +#' @rdname toDegrees +#' @name toDegrees +#' @family math_funcs +#' @export +#' @examples \dontrun{toDegrees(df$c)} +setMethod("toDegrees", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "toDegrees", x@jc) + column(jc) + }) + +#' toRadians +#' +#' Converts an angle measured in degrees to an approximately equivalent angle measured in radians. +#' +#' @rdname toRadians +#' @name toRadians +#' @family math_funcs +#' @export +#' @examples \dontrun{toRadians(df$c)} +setMethod("toRadians", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "toRadians", x@jc) + column(jc) + }) + +#' to_date +#' +#' Converts the column into DateType. +#' +#' @rdname to_date +#' @name to_date +#' @family datetime_funcs +#' @export +#' @examples \dontrun{to_date(df$c)} +setMethod("to_date", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "to_date", x@jc) + column(jc) + }) + +#' trim +#' +#' Trim the spaces from both ends for the specified string column. +#' +#' @rdname trim +#' @name trim +#' @family string_funcs +#' @export +#' @examples \dontrun{trim(df$c)} +setMethod("trim", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "trim", x@jc) + column(jc) + }) + +#' unbase64 +#' +#' Decodes a BASE64 encoded string column and returns it as a binary column. +#' This is the reverse of base64. +#' +#' @rdname unbase64 +#' @name unbase64 +#' @family string_funcs +#' @export +#' @examples \dontrun{unbase64(df$c)} +setMethod("unbase64", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "unbase64", x@jc) + column(jc) + }) + +#' unhex +#' +#' Inverse of hex. Interprets each pair of characters as a hexadecimal number +#' and converts to the byte representation of number. 
+#' +#' @rdname unhex +#' @name unhex +#' @family math_funcs +#' @export +#' @examples \dontrun{unhex(df$c)} +setMethod("unhex", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "unhex", x@jc) + column(jc) + }) + +#' upper +#' +#' Converts a string column to upper case. +#' +#' @rdname upper +#' @name upper +#' @family string_funcs +#' @export +#' @examples \dontrun{upper(df$c)} +setMethod("upper", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "upper", x@jc) + column(jc) + }) + +#' var +#' +#' Aggregate function: alias for \link{var_samp}. +#' +#' @rdname var +#' @name var +#' @family agg_funcs +#' @seealso \link{var_pop}, \link{var_samp} +#' @export +#' @examples +#'\dontrun{ +#'variance(df$c) +#'select(df, var_pop(df$age)) +#'agg(df, var(df$age)) +#'} +setMethod("var", + signature(x = "Column"), + function(x) { + # In R, sample variance is calculated with the var() function. + var_samp(x) + }) + +#' @rdname var +#' @name variance +setMethod("variance", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "variance", x@jc) + column(jc) + }) + +#' var_pop +#' +#' Aggregate function: returns the population variance of the values in a group. +#' +#' @rdname var_pop +#' @name var_pop +#' @family agg_funcs +#' @seealso \link{var}, \link{var_samp} +#' @export +#' @examples \dontrun{var_pop(df$c)} +setMethod("var_pop", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "var_pop", x@jc) + column(jc) + }) + +#' var_samp +#' +#' Aggregate function: returns the unbiased variance of the values in a group. +#' +#' @rdname var_samp +#' @name var_samp +#' @family agg_funcs +#' @seealso \link{var_pop}, \link{var} +#' @export +#' @examples \dontrun{var_samp(df$c)} +setMethod("var_samp", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "var_samp", x@jc) + column(jc) + }) + +#' weekofyear +#' +#' Extracts the week number as an integer from a given date/timestamp/string. +#' +#' @rdname weekofyear +#' @name weekofyear +#' @family datetime_funcs +#' @export +#' @examples \dontrun{weekofyear(df$c)} +setMethod("weekofyear", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "weekofyear", x@jc) + column(jc) + }) + +#' year +#' +#' Extracts the year as an integer from a given date/timestamp/string. +#' +#' @rdname year +#' @name year +#' @family datetime_funcs +#' @export +#' @examples \dontrun{year(df$c)} +setMethod("year", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "year", x@jc) + column(jc) + }) + +#' atan2 +#' +#' Returns the angle theta from the conversion of rectangular coordinates (x, y) to +#' polar coordinates (r, theta). +#' +#' @rdname atan2 +#' @name atan2 +#' @family math_funcs +#' @export +#' @examples \dontrun{atan2(df$c, x)} +setMethod("atan2", signature(y = "Column"), + function(y, x) { + if (class(x) == "Column") { + x <- x@jc + } + jc <- callJStatic("org.apache.spark.sql.functions", "atan2", y@jc, x) + column(jc) + }) + +#' datediff +#' +#' Returns the number of days from `start` to `end`. 
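# var() mirrors sd(): it follows R's sample-variance semantics (var_samp),
# while var_pop() is the population variance and variance() maps to the JVM
# "variance" function under the same @rdname. "d" is an assumed date column.
agg(df, var(df$age), var_samp(df$age), var_pop(df$age))
select(df, year(df$d), weekofyear(df$d))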
+#'
+#' @rdname datediff
+#' @name datediff
+#' @family datetime_funcs
+#' @export
+#' @examples \dontrun{datediff(df$c, x)}
+setMethod("datediff", signature(y = "Column"),
+          function(y, x) {
+            if (class(x) == "Column") {
+              x <- x@jc
+            }
+            jc <- callJStatic("org.apache.spark.sql.functions", "datediff", y@jc, x)
+            column(jc)
+          })
+
+#' hypot
+#'
+#' Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow.
+#'
+#' @rdname hypot
+#' @name hypot
+#' @family math_funcs
+#' @export
+#' @examples \dontrun{hypot(df$c, x)}
+setMethod("hypot", signature(y = "Column"),
+          function(y, x) {
+            if (class(x) == "Column") {
+              x <- x@jc
+            }
+            jc <- callJStatic("org.apache.spark.sql.functions", "hypot", y@jc, x)
+            column(jc)
+          })
+
+#' levenshtein
+#'
+#' Computes the Levenshtein distance of the two given string columns.
+#'
+#' @rdname levenshtein
+#' @name levenshtein
+#' @family string_funcs
+#' @export
+#' @examples \dontrun{levenshtein(df$c, x)}
+setMethod("levenshtein", signature(y = "Column"),
+          function(y, x) {
+            if (class(x) == "Column") {
+              x <- x@jc
+            }
+            jc <- callJStatic("org.apache.spark.sql.functions", "levenshtein", y@jc, x)
+            column(jc)
+          })
+
+#' months_between
+#'
+#' Returns number of months between dates `date1` and `date2`.
+#'
+#' @rdname months_between
+#' @name months_between
+#' @family datetime_funcs
+#' @export
+#' @examples \dontrun{months_between(df$c, x)}
+setMethod("months_between", signature(y = "Column"),
+          function(y, x) {
+            if (class(x) == "Column") {
+              x <- x@jc
+            }
+            jc <- callJStatic("org.apache.spark.sql.functions", "months_between", y@jc, x)
+            column(jc)
+          })
+
+#' nanvl
+#'
+#' Returns col1 if it is not NaN, or col2 if col1 is NaN.
+#' Both inputs should be floating point columns (DoubleType or FloatType).
+#'
+#' @rdname nanvl
+#' @name nanvl
+#' @family normal_funcs
+#' @export
+#' @examples \dontrun{nanvl(df$c, x)}
+setMethod("nanvl", signature(y = "Column"),
+          function(y, x) {
+            if (class(x) == "Column") {
+              x <- x@jc
+            }
+            jc <- callJStatic("org.apache.spark.sql.functions", "nanvl", y@jc, x)
+            column(jc)
+          })
+
+#' pmod
+#'
+#' Returns the positive value of dividend mod divisor.
+#'
+#' @rdname pmod
+#' @name pmod
+#' @docType methods
+#' @family math_funcs
+#' @export
+#' @examples \dontrun{pmod(df$c, x)}
+setMethod("pmod", signature(y = "Column"),
+          function(y, x) {
+            if (class(x) == "Column") {
+              x <- x@jc
+            }
+            jc <- callJStatic("org.apache.spark.sql.functions", "pmod", y@jc, x)
+            column(jc)
+          })
+
+
+#' Approx Count Distinct
+#'
+#' @family agg_funcs
+#' @rdname approxCountDistinct
+#' @name approxCountDistinct
+#' @return the approximate number of distinct items in a group.
+#' @export
+#' @examples \dontrun{approxCountDistinct(df$c, 0.02)}
+setMethod("approxCountDistinct",
+          signature(x = "Column"),
+          function(x, rsd = 0.05) {
+            jc <- callJStatic("org.apache.spark.sql.functions", "approxCountDistinct", x@jc, rsd)
+            column(jc)
+          })
+
+#' Count Distinct
+#'
+#' @family agg_funcs
+#' @rdname countDistinct
+#' @name countDistinct
+#' @return the number of distinct items in a group.
+#' @export
+#' @examples \dontrun{countDistinct(df$c)}
+setMethod("countDistinct",
+          signature(x = "Column"),
+          function(x, ...) {
+            jcols <- lapply(list(...), function (x) {
+              stopifnot(class(x) == "Column")
+              x@jc
+            })
+            jc <- callJStatic("org.apache.spark.sql.functions", "countDistinct", x@jc,
+                              jcols)
+            column(jc)
+          })
+
+
+#' concat
+#'
+#' Concatenates multiple input string columns together into a single string column.
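+#'
+#' For example (with hypothetical column names), \code{concat(df$first_name, lit(" "), df$last_name)}
+#' would produce a single string column of full names.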
+#' +#' @family string_funcs +#' @rdname concat +#' @name concat +#' @export +#' @examples \dontrun{concat(df$strings, df$strings2)} +setMethod("concat", + signature(x = "Column"), + function(x, ...) { + jcols <- lapply(list(x, ...), function (x) { + stopifnot(class(x) == "Column") + x@jc + }) + jc <- callJStatic("org.apache.spark.sql.functions", "concat", jcols) + column(jc) + }) + +#' greatest +#' +#' Returns the greatest value of the list of column names, skipping null values. +#' This function takes at least 2 parameters. It will return null if all parameters are null. +#' +#' @family normal_funcs +#' @rdname greatest +#' @name greatest +#' @export +#' @examples \dontrun{greatest(df$c, df$d)} +setMethod("greatest", + signature(x = "Column"), + function(x, ...) { + stopifnot(length(list(...)) > 0) + jcols <- lapply(list(x, ...), function (x) { + stopifnot(class(x) == "Column") + x@jc + }) + jc <- callJStatic("org.apache.spark.sql.functions", "greatest", jcols) + column(jc) + }) + +#' least +#' +#' Returns the least value of the list of column names, skipping null values. +#' This function takes at least 2 parameters. It will return null if all parameters are null. +#' +#' @family normal_funcs +#' @rdname least +#' @name least +#' @export +#' @examples \dontrun{least(df$c, df$d)} +setMethod("least", + signature(x = "Column"), + function(x, ...) { + stopifnot(length(list(...)) > 0) + jcols <- lapply(list(x, ...), function (x) { + stopifnot(class(x) == "Column") + x@jc + }) + jc <- callJStatic("org.apache.spark.sql.functions", "least", jcols) + column(jc) + }) + +#' ceiling +#' +#' Computes the ceiling of the given value. +#' +#' @rdname ceil +#' @name ceiling +#' @export +#' @examples \dontrun{ceiling(df$c)} +setMethod("ceiling", + signature(x = "Column"), + function(x) { + ceil(x) + }) + +#' sign +#' +#' Computes the signum of the given value. +#' +#' @rdname signum +#' @name sign +#' @export +#' @examples \dontrun{sign(df$c)} +setMethod("sign", signature(x = "Column"), + function(x) { + signum(x) + }) + +#' n_distinct +#' +#' Aggregate function: returns the number of distinct items in a group. +#' +#' @rdname countDistinct +#' @name n_distinct +#' @export +#' @examples \dontrun{n_distinct(df$c)} +setMethod("n_distinct", signature(x = "Column"), + function(x, ...) { + countDistinct(x, ...) + }) + +#' n +#' +#' Aggregate function: returns the number of items in a group. +#' +#' @rdname count +#' @name n +#' @export +#' @examples \dontrun{n(df$c)} +setMethod("n", signature(x = "Column"), + function(x) { + count(x) + }) + +#' date_format +#' +#' Converts a date/timestamp/string to a value of string in the format specified by the date +#' format given by the second argument. +#' +#' A pattern could be for instance \preformatted{dd.MM.yyyy} and could return a string like '18.03.1993'. All +#' pattern letters of \code{java.text.SimpleDateFormat} can be used. +#' +#' NOTE: Use when ever possible specialized functions like \code{year}. These benefit from a +#' specialized implementation. +#' +#' @family datetime_funcs +#' @rdname date_format +#' @name date_format +#' @export +#' @examples \dontrun{date_format(df$t, 'MM/dd/yyy')} +setMethod("date_format", signature(y = "Column", x = "character"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "date_format", y@jc, x) + column(jc) + }) + +#' from_utc_timestamp +#' +#' Assumes given timestamp is UTC and converts to given timezone. 
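+#'
+#' For instance (illustrative), \code{from_utc_timestamp(df$t, "PST")} re-expresses a UTC
+#' timestamp column in US Pacific time.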
+#' +#' @family datetime_funcs +#' @rdname from_utc_timestamp +#' @name from_utc_timestamp +#' @export +#' @examples \dontrun{from_utc_timestamp(df$t, 'PST')} +setMethod("from_utc_timestamp", signature(y = "Column", x = "character"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "from_utc_timestamp", y@jc, x) + column(jc) + }) + +#' instr +#' +#' Locate the position of the first occurrence of substr column in the given string. +#' Returns null if either of the arguments are null. +#' +#' NOTE: The position is not zero based, but 1 based index, returns 0 if substr +#' could not be found in str. +#' +#' @family string_funcs +#' @rdname instr +#' @name instr +#' @export +#' @examples \dontrun{instr(df$c, 'b')} +setMethod("instr", signature(y = "Column", x = "character"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "instr", y@jc, x) + column(jc) + }) + +#' next_day +#' +#' Given a date column, returns the first date which is later than the value of the date column +#' that is on the specified day of the week. +#' +#' For example, \code{next_day('2015-07-27', "Sunday")} returns 2015-08-02 because that is the first +#' Sunday after 2015-07-27. +#' +#' Day of the week parameter is case insensitive, and accepts first three or two characters: +#' "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun". +#' +#' @family datetime_funcs +#' @rdname next_day +#' @name next_day +#' @export +#' @examples +#'\dontrun{ +#'next_day(df$d, 'Sun') +#'next_day(df$d, 'Sunday') +#'} +setMethod("next_day", signature(y = "Column", x = "character"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "next_day", y@jc, x) + column(jc) + }) + +#' to_utc_timestamp +#' +#' Assumes given timestamp is in given timezone and converts to UTC. +#' +#' @family datetime_funcs +#' @rdname to_utc_timestamp +#' @name to_utc_timestamp +#' @export +#' @examples \dontrun{to_utc_timestamp(df$t, 'PST')} +setMethod("to_utc_timestamp", signature(y = "Column", x = "character"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "to_utc_timestamp", y@jc, x) + column(jc) + }) + +#' add_months +#' +#' Returns the date that is numMonths after startDate. +#' +#' @name add_months +#' @family datetime_funcs +#' @rdname add_months +#' @export +#' @examples \dontrun{add_months(df$d, 1)} +setMethod("add_months", signature(y = "Column", x = "numeric"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "add_months", y@jc, as.integer(x)) + column(jc) + }) + +#' date_add +#' +#' Returns the date that is `days` days after `start` +#' +#' @family datetime_funcs +#' @rdname date_add +#' @name date_add +#' @export +#' @examples \dontrun{date_add(df$d, 1)} +setMethod("date_add", signature(y = "Column", x = "numeric"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "date_add", y@jc, as.integer(x)) + column(jc) + }) + +#' date_sub +#' +#' Returns the date that is `days` days before `start` +#' +#' @family datetime_funcs +#' @rdname date_sub +#' @name date_sub +#' @export +#' @examples \dontrun{date_sub(df$d, 1)} +setMethod("date_sub", signature(y = "Column", x = "numeric"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "date_sub", y@jc, as.integer(x)) + column(jc) + }) + +#' format_number +#' +#' Formats numeric column y to a format like '#,###,###.##', rounded to x decimal places, +#' and returns the result as a string column. 
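+#'
+#' For example (illustrative), \code{format_number(lit(12345.6789), 2)} would yield the
+#' string "12,345.68".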
+#'
+#' If x is 0, the result has no decimal point or fractional part.
+#' If x < 0, the result will be null.
+#'
+#' @param y column to format
+#' @param x number of decimal places to format to
+#' @family string_funcs
+#' @rdname format_number
+#' @name format_number
+#' @export
+#' @examples \dontrun{format_number(df$n, 4)}
+setMethod("format_number", signature(y = "Column", x = "numeric"),
+          function(y, x) {
+            jc <- callJStatic("org.apache.spark.sql.functions",
+                              "format_number",
+                              y@jc, as.integer(x))
+            column(jc)
+          })
+
+#' sha2
+#'
+#' Calculates the SHA-2 family of hash functions of a binary column and
+#' returns the value as a hex string.
+#'
+#' @param y column to compute SHA-2 on.
+#' @param x one of 224, 256, 384, or 512.
+#' @family misc_funcs
+#' @rdname sha2
+#' @name sha2
+#' @export
+#' @examples \dontrun{sha2(df$c, 256)}
+setMethod("sha2", signature(y = "Column", x = "numeric"),
+          function(y, x) {
+            jc <- callJStatic("org.apache.spark.sql.functions", "sha2", y@jc, as.integer(x))
+            column(jc)
+          })
+
+#' shiftLeft
+#'
+#' Shift the given value numBits left. If the given value is a long value, this function
+#' will return a long value; otherwise it will return an integer value.
+#'
+#' @family math_funcs
+#' @rdname shiftLeft
+#' @name shiftLeft
+#' @export
+#' @examples \dontrun{shiftLeft(df$c, 1)}
+setMethod("shiftLeft", signature(y = "Column", x = "numeric"),
+          function(y, x) {
+            jc <- callJStatic("org.apache.spark.sql.functions",
+                              "shiftLeft",
+                              y@jc, as.integer(x))
+            column(jc)
+          })
+
+#' shiftRight
+#'
+#' Shift the given value numBits right. If the given value is a long value, it will return
+#' a long value; otherwise it will return an integer value.
+#'
+#' @family math_funcs
+#' @rdname shiftRight
+#' @name shiftRight
+#' @export
+#' @examples \dontrun{shiftRight(df$c, 1)}
+setMethod("shiftRight", signature(y = "Column", x = "numeric"),
+          function(y, x) {
+            jc <- callJStatic("org.apache.spark.sql.functions",
+                              "shiftRight",
+                              y@jc, as.integer(x))
+            column(jc)
+          })
+
+#' shiftRightUnsigned
+#'
+#' Unsigned shift the given value numBits right. If the given value is a long value,
+#' it will return a long value; otherwise it will return an integer value.
+#'
+#' @family math_funcs
+#' @rdname shiftRightUnsigned
+#' @name shiftRightUnsigned
+#' @export
+#' @examples \dontrun{shiftRightUnsigned(df$c, 1)}
+setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"),
+          function(y, x) {
+            jc <- callJStatic("org.apache.spark.sql.functions",
+                              "shiftRightUnsigned",
+                              y@jc, as.integer(x))
+            column(jc)
+          })
+
+#' concat_ws
+#'
+#' Concatenates multiple input string columns together into a single string column,
+#' using the given separator.
+#'
+#' @family string_funcs
+#' @rdname concat_ws
+#' @name concat_ws
+#' @export
+#' @examples \dontrun{concat_ws('-', df$s, df$d)}
+setMethod("concat_ws", signature(sep = "character", x = "Column"),
+          function(sep, x, ...) {
+            jcols <- lapply(list(x, ...), function(x) { x@jc })
+            jc <- callJStatic("org.apache.spark.sql.functions", "concat_ws", sep, jcols)
+            column(jc)
+          })
+
+#' conv
+#'
+#' Convert a number in a string column from one base to another.
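+#'
+#' For example (illustrative), \code{conv(lit("100"), 2, 10)} would yield "4", and
+#' \code{conv(lit("255"), 10, 16)} would yield "FF".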
+#'
+#' @family math_funcs
+#' @rdname conv
+#' @name conv
+#' @export
+#' @examples \dontrun{conv(df$n, 2, 16)}
+setMethod("conv", signature(x = "Column", fromBase = "numeric", toBase = "numeric"),
+          function(x, fromBase, toBase) {
+            fromBase <- as.integer(fromBase)
+            toBase <- as.integer(toBase)
+            jc <- callJStatic("org.apache.spark.sql.functions",
+                              "conv",
+                              x@jc, fromBase, toBase)
+            column(jc)
+          })
+
+#' expr
+#'
+#' Parses the expression string into the column that it represents, similar to
+#' DataFrame.selectExpr
+#'
+#' @family normal_funcs
+#' @rdname expr
+#' @name expr
+#' @export
+#' @examples \dontrun{expr('length(name)')}
+setMethod("expr", signature(x = "character"),
+          function(x) {
+            jc <- callJStatic("org.apache.spark.sql.functions", "expr", x)
+            column(jc)
+          })
+
+#' format_string
+#'
+#' Formats the arguments in printf-style and returns the result as a string column.
+#'
+#' @family string_funcs
+#' @rdname format_string
+#' @name format_string
+#' @export
+#' @examples \dontrun{format_string('%d %s', df$a, df$b)}
+setMethod("format_string", signature(format = "character", x = "Column"),
+          function(format, x, ...) {
+            jcols <- lapply(list(x, ...), function(arg) { arg@jc })
+            jc <- callJStatic("org.apache.spark.sql.functions",
+                              "format_string",
+                              format, jcols)
+            column(jc)
+          })
+
+#' from_unixtime
+#'
+#' Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string
+#' representing the timestamp of that moment in the current system time zone in the given
+#' format.
+#'
+#' @family datetime_funcs
+#' @rdname from_unixtime
+#' @name from_unixtime
+#' @export
+#' @examples
+#'\dontrun{
+#'from_unixtime(df$t)
+#'from_unixtime(df$t, 'yyyy/MM/dd HH')
+#'}
+setMethod("from_unixtime", signature(x = "Column"),
+          function(x, format = "yyyy-MM-dd HH:mm:ss") {
+            jc <- callJStatic("org.apache.spark.sql.functions",
+                              "from_unixtime",
+                              x@jc, format)
+            column(jc)
+          })
+
+#' locate
+#'
+#' Locate the position of the first occurrence of substr.
+#' NOTE: The position is not zero based, but 1 based index, returns 0 if substr
+#' could not be found in str.
+#'
+#' @family string_funcs
+#' @rdname locate
+#' @name locate
+#' @export
+#' @examples \dontrun{locate('b', df$c, 1)}
+setMethod("locate", signature(substr = "character", str = "Column"),
+          function(substr, str, pos = 0) {
+            jc <- callJStatic("org.apache.spark.sql.functions",
+                              "locate",
+                              substr, str@jc, as.integer(pos))
+            column(jc)
+          })
+
+#' lpad
+#'
+#' Left-pad the string column with pad to a length of len.
+#'
+#' @family string_funcs
+#' @rdname lpad
+#' @name lpad
+#' @export
+#' @examples \dontrun{lpad(df$c, 6, '#')}
+setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"),
+          function(x, len, pad) {
+            jc <- callJStatic("org.apache.spark.sql.functions",
+                              "lpad",
+                              x@jc, as.integer(len), pad)
+            column(jc)
+          })
+
+#' rand
+#'
+#' Generate a random column with i.i.d. samples from U[0.0, 1.0].
+#'
+#' @family normal_funcs
+#' @rdname rand
+#' @name rand
+#' @export
+#' @examples \dontrun{rand()}
+setMethod("rand", signature(seed = "missing"),
+          function(seed) {
+            jc <- callJStatic("org.apache.spark.sql.functions", "rand")
+            column(jc)
+          })
+
+#' @rdname rand
+#' @name rand
+#' @export
+setMethod("rand", signature(seed = "numeric"),
+          function(seed) {
+            jc <- callJStatic("org.apache.spark.sql.functions", "rand", as.integer(seed))
+            column(jc)
+          })
+
+#' randn
+#'
+#' Generate a column with i.i.d. samples from the standard normal distribution.
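+#'
+#' A seed can be supplied for reproducibility; as an illustrative sketch,
+#' \code{withColumn(df, "noise", randn(42))} would append a column of pseudo-random
+#' draws to a (hypothetical) DataFrame \code{df}.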
+#'
+#' @family normal_funcs
+#' @rdname randn
+#' @name randn
+#' @export
+#' @examples \dontrun{randn()}
+setMethod("randn", signature(seed = "missing"),
+          function(seed) {
+            jc <- callJStatic("org.apache.spark.sql.functions", "randn")
+            column(jc)
+          })
+
+#' @rdname randn
+#' @name randn
+#' @export
+setMethod("randn", signature(seed = "numeric"),
+          function(seed) {
+            jc <- callJStatic("org.apache.spark.sql.functions", "randn", as.integer(seed))
+            column(jc)
+          })
+
+#' regexp_extract
+#'
+#' Extract a specific group (idx) identified by a Java regex from the specified string column.
+#'
+#' @family string_funcs
+#' @rdname regexp_extract
+#' @name regexp_extract
+#' @export
+#' @examples \dontrun{regexp_extract(df$c, '(\\d+)-(\\d+)', 1)}
+setMethod("regexp_extract",
+          signature(x = "Column", pattern = "character", idx = "numeric"),
+          function(x, pattern, idx) {
+            jc <- callJStatic("org.apache.spark.sql.functions",
+                              "regexp_extract",
+                              x@jc, pattern, as.integer(idx))
+            column(jc)
+          })
+
+#' regexp_replace
+#'
+#' Replace all substrings of the specified string value that match regexp with rep.
+#'
+#' @family string_funcs
+#' @rdname regexp_replace
+#' @name regexp_replace
+#' @export
+#' @examples \dontrun{regexp_replace(df$c, '(\\d+)', '--')}
+setMethod("regexp_replace",
+          signature(x = "Column", pattern = "character", replacement = "character"),
+          function(x, pattern, replacement) {
+            jc <- callJStatic("org.apache.spark.sql.functions",
+                              "regexp_replace",
+                              x@jc, pattern, replacement)
+            column(jc)
+          })
+
+#' rpad
+#'
+#' Right-pad the string column with pad to a length of len.
+#'
+#' @family string_funcs
+#' @rdname rpad
+#' @name rpad
+#' @export
+#' @examples \dontrun{rpad(df$c, 6, '#')}
+setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"),
+          function(x, len, pad) {
+            jc <- callJStatic("org.apache.spark.sql.functions",
+                              "rpad",
+                              x@jc, as.integer(len), pad)
+            column(jc)
+          })
+
+#' substring_index
+#'
+#' Returns the substring from string str before count occurrences of the delimiter delim.
+#' If count is positive, everything to the left of the final delimiter (counting from the left)
+#' is returned. If count is negative, everything to the right of the final delimiter (counting
+#' from the right) is returned. substring_index performs a case-sensitive match when searching
+#' for delim.
+#'
+#' @family string_funcs
+#' @rdname substring_index
+#' @name substring_index
+#' @export
+#' @examples
+#'\dontrun{
+#'substring_index(df$c, '.', 2)
+#'substring_index(df$c, '.', -1)
+#'}
+setMethod("substring_index",
+          signature(x = "Column", delim = "character", count = "numeric"),
+          function(x, delim, count) {
+            jc <- callJStatic("org.apache.spark.sql.functions",
+                              "substring_index",
+                              x@jc, delim, as.integer(count))
+            column(jc)
+          })
+
+#' translate
+#'
+#' Translate any character in the src by a character in replaceString.
+#' The characters in replaceString correspond to the characters in matchingString.
+#' The translation happens whenever a character in the string matches a character
+#' in matchingString.
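+#'
+#' For example (illustrative), \code{translate(df$c, "rnlt", "123")} replaces "r", "n"
+#' and "l" with "1", "2" and "3", and removes "t", which has no counterpart in the
+#' replacement string.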
+#' +#' @family string_funcs +#' @rdname translate +#' @name translate +#' @export +#' @examples \dontrun{translate(df$c, 'rnlt', '123')} +setMethod("translate", + signature(x = "Column", matchingString = "character", replaceString = "character"), + function(x, matchingString, replaceString) { + jc <- callJStatic("org.apache.spark.sql.functions", + "translate", x@jc, matchingString, replaceString) + column(jc) + }) + +#' unix_timestamp +#' +#' Gets current Unix timestamp in seconds. +#' +#' @family datetime_funcs +#' @rdname unix_timestamp +#' @name unix_timestamp +#' @export +#' @examples +#'\dontrun{ +#'unix_timestamp() +#'unix_timestamp(df$t) +#'unix_timestamp(df$t, 'yyyy-MM-dd HH') +#'} +setMethod("unix_timestamp", signature(x = "missing", format = "missing"), + function(x, format) { + jc <- callJStatic("org.apache.spark.sql.functions", "unix_timestamp") + column(jc) + }) + +#' @rdname unix_timestamp +#' @name unix_timestamp +#' @export +setMethod("unix_timestamp", signature(x = "Column", format = "missing"), + function(x, format) { + jc <- callJStatic("org.apache.spark.sql.functions", "unix_timestamp", x@jc) + column(jc) + }) + +#' @rdname unix_timestamp +#' @name unix_timestamp +#' @export +setMethod("unix_timestamp", signature(x = "Column", format = "character"), + function(x, format = "yyyy-MM-dd HH:mm:ss") { + jc <- callJStatic("org.apache.spark.sql.functions", "unix_timestamp", x@jc, format) + column(jc) + }) +#' when +#' +#' Evaluates a list of conditions and returns one of multiple possible result expressions. +#' For unmatched expressions null is returned. +#' +#' @family normal_funcs +#' @rdname when +#' @name when +#' @seealso \link{ifelse} +#' @export +#' @examples \dontrun{when(df$age == 2, df$age + 1)} +setMethod("when", signature(condition = "Column", value = "ANY"), + function(condition, value) { + condition <- condition@jc + value <- ifelse(class(value) == "Column", value@jc, value) + jc <- callJStatic("org.apache.spark.sql.functions", "when", condition, value) + column(jc) + }) + +#' ifelse +#' +#' Evaluates a list of conditions and returns \code{yes} if the conditions are satisfied. +#' Otherwise \code{no} is returned for unmatched conditions. +#' +#' @family normal_funcs +#' @rdname ifelse +#' @name ifelse +#' @seealso \link{when} +#' @export +#' @examples \dontrun{ifelse(df$a > 1 & df$b > 2, 0, 1)} +setMethod("ifelse", + signature(test = "Column", yes = "ANY", no = "ANY"), + function(test, yes, no) { + test <- test@jc + yes <- ifelse(class(yes) == "Column", yes@jc, yes) + no <- ifelse(class(no) == "Column", no@jc, no) + jc <- callJMethod(callJStatic("org.apache.spark.sql.functions", + "when", + test, yes), + "otherwise", no) + column(jc) + }) + +###################### Window functions###################### + +#' cume_dist +#' +#' Window function: returns the cumulative distribution of values within a window partition, +#' i.e. the fraction of rows that are below the current row. +#' +#' N = total number of rows in the partition +#' cume_dist(x) = number of values before (and including) x / N +#' +#' This is equivalent to the CUME_DIST function in SQL. +#' +#' @rdname cume_dist +#' @name cume_dist +#' @family window_funcs +#' @export +#' @examples \dontrun{cume_dist()} +setMethod("cume_dist", + signature(x = "missing"), + function() { + jc <- callJStatic("org.apache.spark.sql.functions", "cume_dist") + column(jc) + }) + +#' dense_rank +#' +#' Window function: returns the rank of rows within a window partition, without any gaps. 
+#' The difference between rank and dense_rank is that dense_rank leaves no gaps in ranking
+#' sequence when there are ties. That is, if you were ranking a competition using dense_rank
+#' and had three people tie for second place, you would say that all three were in second
+#' place and that the next person came in third.
+#'
+#' This is equivalent to the DENSE_RANK function in SQL.
+#'
+#' @rdname dense_rank
+#' @name dense_rank
+#' @family window_funcs
+#' @export
+#' @examples \dontrun{dense_rank()}
+setMethod("dense_rank",
+          signature(x = "missing"),
+          function() {
+            jc <- callJStatic("org.apache.spark.sql.functions", "dense_rank")
+            column(jc)
+          })
+
+#' lag
+#'
+#' Window function: returns the value that is `offset` rows before the current row, and
+#' `defaultValue` if there are fewer than `offset` rows before the current row. For example,
+#' an `offset` of one will return the previous row at any given point in the window partition.
+#'
+#' This is equivalent to the LAG function in SQL.
+#'
+#' @rdname lag
+#' @name lag
+#' @family window_funcs
+#' @export
+#' @examples \dontrun{lag(df$c)}
+setMethod("lag",
+          signature(x = "characterOrColumn"),
+          function(x, offset, defaultValue = NULL) {
+            col <- if (class(x) == "Column") {
+              x@jc
+            } else {
+              x
+            }
+
+            jc <- callJStatic("org.apache.spark.sql.functions",
+                              "lag", col, as.integer(offset), defaultValue)
+            column(jc)
+          })
+
+#' lead
+#'
+#' Window function: returns the value that is `offset` rows after the current row, and
+#' `null` if there are fewer than `offset` rows after the current row. For example,
+#' an `offset` of one will return the next row at any given point in the window partition.
+#'
+#' This is equivalent to the LEAD function in SQL.
+#'
+#' @rdname lead
+#' @name lead
+#' @family window_funcs
+#' @export
+#' @examples \dontrun{lead(df$c)}
+setMethod("lead",
+          signature(x = "characterOrColumn", offset = "numeric", defaultValue = "ANY"),
+          function(x, offset, defaultValue = NULL) {
+            col <- if (class(x) == "Column") {
+              x@jc
+            } else {
+              x
+            }
+
+            jc <- callJStatic("org.apache.spark.sql.functions",
+                              "lead", col, as.integer(offset), defaultValue)
+            column(jc)
+          })
+
+#' ntile
+#'
+#' Window function: returns the ntile group id (from 1 to `n` inclusive) in an ordered window
+#' partition. For example, if `n` is 4, the first quarter of the rows will get value 1, the second
+#' quarter will get 2, the third quarter will get 3, and the last quarter will get 4.
+#'
+#' This is equivalent to the NTILE function in SQL.
+#'
+#' @rdname ntile
+#' @name ntile
+#' @family window_funcs
+#' @export
+#' @examples \dontrun{ntile(1)}
+setMethod("ntile",
+          signature(x = "numeric"),
+          function(x) {
+            jc <- callJStatic("org.apache.spark.sql.functions", "ntile", as.integer(x))
+            column(jc)
+          })
+
+#' percent_rank
+#'
+#' Window function: returns the relative rank (i.e. percentile) of rows within a window partition.
+#'
+#' This is computed by:
+#'
+#' (rank of row in its partition - 1) / (number of rows in the partition - 1)
+#'
+#' This is equivalent to the PERCENT_RANK function in SQL.
+#'
+#' @rdname percent_rank
+#' @name percent_rank
+#' @family window_funcs
+#' @export
+#' @examples \dontrun{percent_rank()}
+setMethod("percent_rank",
+          signature(x = "missing"),
+          function() {
+            jc <- callJStatic("org.apache.spark.sql.functions", "percent_rank")
+            column(jc)
+          })
+
+#' rank
+#'
+#' Window function: returns the rank of rows within a window partition.
+#'
+#' The difference between rank and dense_rank is that dense_rank leaves no gaps in ranking
+#' sequence when there are ties. That is, if you were ranking a competition using dense_rank
+#' and had three people tie for second place, you would say that all three were in second
+#' place and that the next person came in third.
+#'
+#' This is equivalent to the RANK function in SQL.
+#'
+#' @rdname rank
+#' @name rank
+#' @family window_funcs
+#' @export
+#' @examples \dontrun{rank()}
+setMethod("rank",
+          signature(x = "missing"),
+          function() {
+            jc <- callJStatic("org.apache.spark.sql.functions", "rank")
+            column(jc)
+          })
+
+# Expose rank() in the R base package for non-Column inputs
+setMethod("rank",
+          signature(x = "ANY"),
+          function(x, ...) {
+            base::rank(x, ...)
+          })
+
+#' row_number
+#'
+#' Window function: returns a sequential number starting at 1 within a window partition.
+#'
+#' This is equivalent to the ROW_NUMBER function in SQL.
+#'
+#' @rdname row_number
+#' @name row_number
+#' @family window_funcs
+#' @export
+#' @examples \dontrun{row_number()}
+setMethod("row_number",
+          signature(x = "missing"),
+          function() {
+            jc <- callJStatic("org.apache.spark.sql.functions", "row_number")
+            column(jc)
+          })
+
+###################### Collection functions ######################
+
+#' array_contains
+#'
+#' Returns true if the array contains the value.
+#'
+#' @param x A Column
+#' @param value A value to be checked if contained in the column
+#' @rdname array_contains
+#' @name array_contains
+#' @family collection_funcs
+#' @export
+#' @examples \dontrun{array_contains(df$c, 1)}
+setMethod("array_contains",
+          signature(x = "Column", value = "ANY"),
+          function(x, value) {
+            jc <- callJStatic("org.apache.spark.sql.functions", "array_contains", x@jc, value)
+            column(jc)
+          })
+
+#' explode
+#'
+#' Creates a new row for each element in the given array or map column.
+#'
+#' @rdname explode
+#' @name explode
+#' @family collection_funcs
+#' @export
+#' @examples \dontrun{explode(df$c)}
+setMethod("explode",
+          signature(x = "Column"),
+          function(x) {
+            jc <- callJStatic("org.apache.spark.sql.functions", "explode", x@jc)
+            column(jc)
+          })
+
+#' size
+#'
+#' Returns the length of the array or map.
+#'
+#' @rdname size
+#' @name size
+#' @family collection_funcs
+#' @export
+#' @examples \dontrun{size(df$c)}
+setMethod("size",
+          signature(x = "Column"),
+          function(x) {
+            jc <- callJStatic("org.apache.spark.sql.functions", "size", x@jc)
+            column(jc)
+          })
+
+#' sort_array
+#'
+#' Sorts the input array for the given column in ascending order,
+#' according to the natural ordering of the array elements.
+#'
+#' @param x A Column to sort
+#' @param asc A logical flag indicating the sorting order.
+#'            TRUE, sorting is in ascending order.
+#'            FALSE, sorting is in descending order.
+#' @rdname sort_array +#' @name sort_array +#' @family collection_funcs +#' @export +#' @examples +#' \dontrun{ +#' sort_array(df$c) +#' sort_array(df$c, FALSE) +#' } +setMethod("sort_array", + signature(x = "Column"), + function(x, asc = TRUE) { + jc <- callJStatic("org.apache.spark.sql.functions", "sort_array", x@jc, asc) + column(jc) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index c43b947129e8..62be2ddc8f52 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -63,6 +63,10 @@ setGeneric("countByValue", function(x) { standardGeneric("countByValue") }) # @export setGeneric("crosstab", function(x, col1, col2) { standardGeneric("crosstab") }) +# @rdname statfunctions +# @export +setGeneric("freqItems", function(x, cols, support = 0.01) { standardGeneric("freqItems") }) + # @rdname distinct # @export setGeneric("distinct", function(x, numPartitions = 1) { standardGeneric("distinct") }) @@ -84,12 +88,8 @@ setGeneric("flatMap", function(X, FUN) { standardGeneric("flatMap") }) # @export setGeneric("fold", function(x, zeroValue, op) { standardGeneric("fold") }) -# @rdname foreach -# @export setGeneric("foreach", function(x, func) { standardGeneric("foreach") }) -# @rdname foreach -# @export setGeneric("foreachPartition", function(x, func) { standardGeneric("foreachPartition") }) # The jrdd accessor function. @@ -103,27 +103,17 @@ setGeneric("glom", function(x) { standardGeneric("glom") }) # @export setGeneric("keyBy", function(x, func) { standardGeneric("keyBy") }) -# @rdname lapplyPartition -# @export setGeneric("lapplyPartition", function(X, FUN) { standardGeneric("lapplyPartition") }) -# @rdname lapplyPartitionsWithIndex -# @export setGeneric("lapplyPartitionsWithIndex", function(X, FUN) { standardGeneric("lapplyPartitionsWithIndex") }) -# @rdname lapply -# @export setGeneric("map", function(X, FUN) { standardGeneric("map") }) -# @rdname lapplyPartition -# @export setGeneric("mapPartitions", function(X, FUN) { standardGeneric("mapPartitions") }) -# @rdname lapplyPartitionsWithIndex -# @export setGeneric("mapPartitionsWithIndex", function(X, FUN) { standardGeneric("mapPartitionsWithIndex") }) @@ -143,7 +133,11 @@ setGeneric("sumRDD", function(x) { standardGeneric("sumRDD") }) # @export setGeneric("name", function(x) { standardGeneric("name") }) -# @rdname numPartitions +# @rdname getNumPartitions +# @export +setGeneric("getNumPartitions", function(x) { standardGeneric("getNumPartitions") }) + +# @rdname getNumPartitions # @export setGeneric("numPartitions", function(x) { standardGeneric("numPartitions") }) @@ -395,11 +389,35 @@ setGeneric("agg", function (x, ...) { standardGeneric("agg") }) #' @export setGeneric("arrange", function(x, col, ...) { standardGeneric("arrange") }) +#' @rdname columns +#' @export +setGeneric("colnames", function(x, do.NULL = TRUE, prefix = "col") { standardGeneric("colnames") }) + +#' @rdname columns +#' @export +setGeneric("colnames<-", function(x, value) { standardGeneric("colnames<-") }) + +#' @rdname coltypes +#' @export +setGeneric("coltypes", function(x) { standardGeneric("coltypes") }) + +#' @rdname coltypes +#' @export +setGeneric("coltypes<-", function(x, value) { standardGeneric("coltypes<-") }) + #' @rdname schema #' @export setGeneric("columns", function(x) {standardGeneric("columns") }) -#' @rdname describe +#' @rdname statfunctions +#' @export +setGeneric("cov", function(x, col1, col2) {standardGeneric("cov") }) + +#' @rdname statfunctions +#' @export +setGeneric("corr", function(x, ...) 
{standardGeneric("corr") }) + +#' @rdname summary #' @export setGeneric("describe", function(x, col, ...) { standardGeneric("describe") }) @@ -413,7 +431,7 @@ setGeneric("dropna", #' @rdname nafunctions #' @export setGeneric("na.omit", - function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { + function(object, ...) { standardGeneric("na.omit") }) @@ -441,7 +459,7 @@ setGeneric("filter", function(x, condition) { standardGeneric("filter") }) #' @export setGeneric("group_by", function(x, ...) { standardGeneric("group_by") }) -#' @rdname DataFrame +#' @rdname groupBy #' @export setGeneric("groupBy", function(x, ...) { standardGeneric("groupBy") }) @@ -461,13 +479,13 @@ setGeneric("isLocal", function(x) { standardGeneric("isLocal") }) #' @export setGeneric("limit", function(x, num) {standardGeneric("limit") }) -#' rdname merge +#' @rdname merge #' @export setGeneric("merge") -#' @rdname withColumn +#' @rdname mutate #' @export -setGeneric("mutate", function(x, ...) {standardGeneric("mutate") }) +setGeneric("mutate", function(.data, ...) {standardGeneric("mutate") }) #' @rdname arrange #' @export @@ -477,7 +495,7 @@ setGeneric("orderBy", function(x, col) { standardGeneric("orderBy") }) #' @export setGeneric("printSchema", function(x) { standardGeneric("printSchema") }) -#' @rdname withColumnRenamed +#' @rdname rename #' @export setGeneric("rename", function(x, ...) { standardGeneric("rename") }) @@ -497,9 +515,9 @@ setGeneric("sample", setGeneric("sample_frac", function(x, withReplacement, fraction, seed) { standardGeneric("sample_frac") }) -#' @rdname saveAsParquetFile +#' @rdname statfunctions #' @export -setGeneric("saveAsParquetFile", function(x, path) { standardGeneric("saveAsParquetFile") }) +setGeneric("sampleBy", function(x, col, fractions, seed) { standardGeneric("sampleBy") }) #' @rdname saveAsTable #' @export @@ -507,6 +525,10 @@ setGeneric("saveAsTable", function(df, tableName, source, mode, ...) { standardGeneric("saveAsTable") }) +#' @rdname withColumn +#' @export +setGeneric("transform", function(`_data`, ...) {standardGeneric("transform") }) + #' @rdname write.df #' @export setGeneric("write.df", function(df, path, ...) { standardGeneric("write.df") }) @@ -515,6 +537,18 @@ setGeneric("write.df", function(df, path, ...) { standardGeneric("write.df") }) #' @export setGeneric("saveDF", function(df, path, ...) { standardGeneric("saveDF") }) +#' @rdname write.json +#' @export +setGeneric("write.json", function(x, path) { standardGeneric("write.json") }) + +#' @rdname write.parquet +#' @export +setGeneric("write.parquet", function(x, path) { standardGeneric("write.parquet") }) + +#' @rdname write.parquet +#' @export +setGeneric("saveAsParquetFile", function(x, path) { standardGeneric("saveAsParquetFile") }) + #' @rdname schema #' @export setGeneric("schema", function(x) { standardGeneric("schema") }) @@ -531,23 +565,23 @@ setGeneric("selectExpr", function(x, expr, ...) { standardGeneric("selectExpr") #' @export setGeneric("showDF", function(x,...) { standardGeneric("showDF") }) +# @rdname subset +# @export +setGeneric("subset", function(x, ...) { standardGeneric("subset") }) + #' @rdname agg #' @export setGeneric("summarize", function(x,...) { standardGeneric("summarize") }) -##' rdname summary -##' @export -setGeneric("summary", function(x, ...) { standardGeneric("summary") }) +#' @rdname summary +#' @export +setGeneric("summary", function(object, ...) 
{ standardGeneric("summary") }) -# @rdname tojson -# @export setGeneric("toJSON", function(x) { standardGeneric("toJSON") }) -#' @rdname DataFrame -#' @export setGeneric("toRDD", function(x) { standardGeneric("toRDD") }) -#' @rdname unionAll +#' @rdname rbind #' @export setGeneric("unionAll", function(x, y) { standardGeneric("unionAll") }) @@ -559,7 +593,7 @@ setGeneric("where", function(x, condition) { standardGeneric("where") }) #' @export setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn") }) -#' @rdname withColumnRenamed +#' @rdname rename #' @export setGeneric("withColumnRenamed", function(x, existingCol, newCol) { standardGeneric("withColumnRenamed") }) @@ -567,18 +601,10 @@ setGeneric("withColumnRenamed", ###################### Column Methods ########################## -#' @rdname column -#' @export -setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCountDistinct") }) - #' @rdname column #' @export setGeneric("asc", function(x) { standardGeneric("asc") }) -#' @rdname column -#' @export -setGeneric("avg", function(x, ...) { standardGeneric("avg") }) - #' @rdname column #' @export setGeneric("between", function(x, bounds) { standardGeneric("between") }) @@ -587,16 +613,9 @@ setGeneric("between", function(x, bounds) { standardGeneric("between") }) #' @export setGeneric("cast", function(x, dataType) { standardGeneric("cast") }) -#' @rdname column -#' @export -setGeneric("cbrt", function(x) { standardGeneric("cbrt") }) - #' @rdname column #' @export setGeneric("contains", function(x, ...) { standardGeneric("contains") }) -#' @rdname column -#' @export -setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") }) #' @rdname column #' @export @@ -616,7 +635,7 @@ setGeneric("getItem", function(x, ...) { standardGeneric("getItem") }) #' @rdname column #' @export -setGeneric("hypot", function(y, x) { standardGeneric("hypot") }) +setGeneric("isNaN", function(x) { standardGeneric("isNaN") }) #' @rdname column #' @export @@ -628,56 +647,481 @@ setGeneric("isNotNull", function(x) { standardGeneric("isNotNull") }) #' @rdname column #' @export -setGeneric("last", function(x) { standardGeneric("last") }) +setGeneric("like", function(x, ...) { standardGeneric("like") }) #' @rdname column #' @export -setGeneric("like", function(x, ...) { standardGeneric("like") }) +setGeneric("rlike", function(x, ...) { standardGeneric("rlike") }) #' @rdname column #' @export -setGeneric("lower", function(x) { standardGeneric("lower") }) +setGeneric("startsWith", function(x, ...) { standardGeneric("startsWith") }) #' @rdname column #' @export -setGeneric("n", function(x) { standardGeneric("n") }) +setGeneric("when", function(condition, value) { standardGeneric("when") }) #' @rdname column #' @export +setGeneric("otherwise", function(x, value) { standardGeneric("otherwise") }) + + +###################### Expression Function Methods ########################## + +#' @rdname add_months +#' @export +setGeneric("add_months", function(y, x) { standardGeneric("add_months") }) + +#' @rdname approxCountDistinct +#' @export +setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCountDistinct") }) + +#' @rdname array_contains +#' @export +setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") }) + +#' @rdname ascii +#' @export +setGeneric("ascii", function(x) { standardGeneric("ascii") }) + +#' @rdname avg +#' @export +setGeneric("avg", function(x, ...) 
{ standardGeneric("avg") }) + +#' @rdname base64 +#' @export +setGeneric("base64", function(x) { standardGeneric("base64") }) + +#' @rdname bin +#' @export +setGeneric("bin", function(x) { standardGeneric("bin") }) + +#' @rdname bitwiseNOT +#' @export +setGeneric("bitwiseNOT", function(x) { standardGeneric("bitwiseNOT") }) + +#' @rdname cbrt +#' @export +setGeneric("cbrt", function(x) { standardGeneric("cbrt") }) + +#' @rdname ceil +#' @export +setGeneric("ceil", function(x) { standardGeneric("ceil") }) + +#' @rdname col +#' @export +setGeneric("column", function(x) { standardGeneric("column") }) + +#' @rdname concat +#' @export +setGeneric("concat", function(x, ...) { standardGeneric("concat") }) + +#' @rdname concat_ws +#' @export +setGeneric("concat_ws", function(sep, x, ...) { standardGeneric("concat_ws") }) + +#' @rdname conv +#' @export +setGeneric("conv", function(x, fromBase, toBase) { standardGeneric("conv") }) + +#' @rdname countDistinct +#' @export +setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") }) + +#' @rdname crc32 +#' @export +setGeneric("crc32", function(x) { standardGeneric("crc32") }) + +#' @rdname cume_dist +#' @export +setGeneric("cume_dist", function(x) { standardGeneric("cume_dist") }) + +#' @rdname datediff +#' @export +setGeneric("datediff", function(y, x) { standardGeneric("datediff") }) + +#' @rdname date_add +#' @export +setGeneric("date_add", function(y, x) { standardGeneric("date_add") }) + +#' @rdname date_format +#' @export +setGeneric("date_format", function(y, x) { standardGeneric("date_format") }) + +#' @rdname date_sub +#' @export +setGeneric("date_sub", function(y, x) { standardGeneric("date_sub") }) + +#' @rdname dayofmonth +#' @export +setGeneric("dayofmonth", function(x) { standardGeneric("dayofmonth") }) + +#' @rdname dayofyear +#' @export +setGeneric("dayofyear", function(x) { standardGeneric("dayofyear") }) + +#' @rdname decode +#' @export +setGeneric("decode", function(x, charset) { standardGeneric("decode") }) + +#' @rdname dense_rank +#' @export +setGeneric("dense_rank", function(x) { standardGeneric("dense_rank") }) + +#' @rdname encode +#' @export +setGeneric("encode", function(x, charset) { standardGeneric("encode") }) + +#' @rdname explode +#' @export +setGeneric("explode", function(x) { standardGeneric("explode") }) + +#' @rdname expr +#' @export +setGeneric("expr", function(x) { standardGeneric("expr") }) + +#' @rdname from_utc_timestamp +#' @export +setGeneric("from_utc_timestamp", function(y, x) { standardGeneric("from_utc_timestamp") }) + +#' @rdname format_number +#' @export +setGeneric("format_number", function(y, x) { standardGeneric("format_number") }) + +#' @rdname format_string +#' @export +setGeneric("format_string", function(format, x, ...) { standardGeneric("format_string") }) + +#' @rdname from_unixtime +#' @export +setGeneric("from_unixtime", function(x, ...) { standardGeneric("from_unixtime") }) + +#' @rdname greatest +#' @export +setGeneric("greatest", function(x, ...) 
{ standardGeneric("greatest") }) + +#' @rdname hex +#' @export +setGeneric("hex", function(x) { standardGeneric("hex") }) + +#' @rdname hour +#' @export +setGeneric("hour", function(x) { standardGeneric("hour") }) + +#' @rdname hypot +#' @export +setGeneric("hypot", function(y, x) { standardGeneric("hypot") }) + +#' @rdname initcap +#' @export +setGeneric("initcap", function(x) { standardGeneric("initcap") }) + +#' @rdname instr +#' @export +setGeneric("instr", function(y, x) { standardGeneric("instr") }) + +#' @rdname is.nan +#' @export +setGeneric("isnan", function(x) { standardGeneric("isnan") }) + +#' @rdname kurtosis +#' @export +setGeneric("kurtosis", function(x) { standardGeneric("kurtosis") }) + +#' @rdname lag +#' @export +setGeneric("lag", function(x, ...) { standardGeneric("lag") }) + +#' @rdname last +#' @export +setGeneric("last", function(x) { standardGeneric("last") }) + +#' @rdname last_day +#' @export +setGeneric("last_day", function(x) { standardGeneric("last_day") }) + +#' @rdname lead +#' @export +setGeneric("lead", function(x, offset, defaultValue = NULL) { standardGeneric("lead") }) + +#' @rdname least +#' @export +setGeneric("least", function(x, ...) { standardGeneric("least") }) + +#' @rdname levenshtein +#' @export +setGeneric("levenshtein", function(y, x) { standardGeneric("levenshtein") }) + +#' @rdname lit +#' @export +setGeneric("lit", function(x) { standardGeneric("lit") }) + +#' @rdname locate +#' @export +setGeneric("locate", function(substr, str, ...) { standardGeneric("locate") }) + +#' @rdname lower +#' @export +setGeneric("lower", function(x) { standardGeneric("lower") }) + +#' @rdname lpad +#' @export +setGeneric("lpad", function(x, len, pad) { standardGeneric("lpad") }) + +#' @rdname ltrim +#' @export +setGeneric("ltrim", function(x) { standardGeneric("ltrim") }) + +#' @rdname md5 +#' @export +setGeneric("md5", function(x) { standardGeneric("md5") }) + +#' @rdname minute +#' @export +setGeneric("minute", function(x) { standardGeneric("minute") }) + +#' @rdname month +#' @export +setGeneric("month", function(x) { standardGeneric("month") }) + +#' @rdname months_between +#' @export +setGeneric("months_between", function(y, x) { standardGeneric("months_between") }) + +#' @rdname count +#' @export +setGeneric("n", function(x) { standardGeneric("n") }) + +#' @rdname nanvl +#' @export +setGeneric("nanvl", function(y, x) { standardGeneric("nanvl") }) + +#' @rdname negate +#' @export +setGeneric("negate", function(x) { standardGeneric("negate") }) + +#' @rdname next_day +#' @export +setGeneric("next_day", function(y, x) { standardGeneric("next_day") }) + +#' @rdname ntile +#' @export +setGeneric("ntile", function(x) { standardGeneric("ntile") }) + +#' @rdname countDistinct +#' @export setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") }) -#' @rdname column +#' @rdname percent_rank +#' @export +setGeneric("percent_rank", function(x) { standardGeneric("percent_rank") }) + +#' @rdname pmod +#' @export +setGeneric("pmod", function(y, x) { standardGeneric("pmod") }) + +#' @rdname quarter +#' @export +setGeneric("quarter", function(x) { standardGeneric("quarter") }) + +#' @rdname rand +#' @export +setGeneric("rand", function(seed) { standardGeneric("rand") }) + +#' @rdname randn +#' @export +setGeneric("randn", function(seed) { standardGeneric("randn") }) + +#' @rdname rank +#' @export +setGeneric("rank", function(x, ...) 
{ standardGeneric("rank") }) + +#' @rdname regexp_extract +#' @export +setGeneric("regexp_extract", function(x, pattern, idx) { standardGeneric("regexp_extract") }) + +#' @rdname regexp_replace +#' @export +setGeneric("regexp_replace", + function(x, pattern, replacement) { standardGeneric("regexp_replace") }) + +#' @rdname reverse +#' @export +setGeneric("reverse", function(x) { standardGeneric("reverse") }) + +#' @rdname rint #' @export setGeneric("rint", function(x, ...) { standardGeneric("rint") }) -#' @rdname column +#' @rdname row_number #' @export -setGeneric("rlike", function(x, ...) { standardGeneric("rlike") }) +setGeneric("row_number", function(x) { standardGeneric("row_number") }) -#' @rdname column +#' @rdname rpad #' @export -setGeneric("startsWith", function(x, ...) { standardGeneric("startsWith") }) +setGeneric("rpad", function(x, len, pad) { standardGeneric("rpad") }) -#' @rdname column +#' @rdname rtrim +#' @export +setGeneric("rtrim", function(x) { standardGeneric("rtrim") }) + +#' @rdname sd +#' @export +setGeneric("sd", function(x, na.rm = FALSE) { standardGeneric("sd") }) + +#' @rdname second +#' @export +setGeneric("second", function(x) { standardGeneric("second") }) + +#' @rdname sha1 +#' @export +setGeneric("sha1", function(x) { standardGeneric("sha1") }) + +#' @rdname sha2 +#' @export +setGeneric("sha2", function(y, x) { standardGeneric("sha2") }) + +#' @rdname shiftLeft +#' @export +setGeneric("shiftLeft", function(y, x) { standardGeneric("shiftLeft") }) + +#' @rdname shiftRight +#' @export +setGeneric("shiftRight", function(y, x) { standardGeneric("shiftRight") }) + +#' @rdname shiftRightUnsigned +#' @export +setGeneric("shiftRightUnsigned", function(y, x) { standardGeneric("shiftRightUnsigned") }) + +#' @rdname signum +#' @export +setGeneric("signum", function(x) { standardGeneric("signum") }) + +#' @rdname size +#' @export +setGeneric("size", function(x) { standardGeneric("size") }) + +#' @rdname skewness +#' @export +setGeneric("skewness", function(x) { standardGeneric("skewness") }) + +#' @rdname sort_array +#' @export +setGeneric("sort_array", function(x, asc = TRUE) { standardGeneric("sort_array") }) + +#' @rdname soundex +#' @export +setGeneric("soundex", function(x) { standardGeneric("soundex") }) + +#' @rdname sd +#' @export +setGeneric("stddev", function(x) { standardGeneric("stddev") }) + +#' @rdname stddev_pop +#' @export +setGeneric("stddev_pop", function(x) { standardGeneric("stddev_pop") }) + +#' @rdname stddev_samp +#' @export +setGeneric("stddev_samp", function(x) { standardGeneric("stddev_samp") }) + +#' @rdname struct +#' @export +setGeneric("struct", function(x, ...) 
{ standardGeneric("struct") }) + +#' @rdname substring_index +#' @export +setGeneric("substring_index", function(x, delim, count) { standardGeneric("substring_index") }) + +#' @rdname sumDistinct #' @export setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") }) -#' @rdname column +#' @rdname toDegrees #' @export setGeneric("toDegrees", function(x) { standardGeneric("toDegrees") }) -#' @rdname column +#' @rdname toRadians #' @export setGeneric("toRadians", function(x) { standardGeneric("toRadians") }) -#' @rdname column +#' @rdname to_date +#' @export +setGeneric("to_date", function(x) { standardGeneric("to_date") }) + +#' @rdname to_utc_timestamp +#' @export +setGeneric("to_utc_timestamp", function(y, x) { standardGeneric("to_utc_timestamp") }) + +#' @rdname translate +#' @export +setGeneric("translate", function(x, matchingString, replaceString) { standardGeneric("translate") }) + +#' @rdname trim +#' @export +setGeneric("trim", function(x) { standardGeneric("trim") }) + +#' @rdname unbase64 +#' @export +setGeneric("unbase64", function(x) { standardGeneric("unbase64") }) + +#' @rdname unhex +#' @export +setGeneric("unhex", function(x) { standardGeneric("unhex") }) + +#' @rdname unix_timestamp +#' @export +setGeneric("unix_timestamp", function(x, format) { standardGeneric("unix_timestamp") }) + +#' @rdname upper #' @export setGeneric("upper", function(x) { standardGeneric("upper") }) +#' @rdname var +#' @export +setGeneric("var", function(x, y = NULL, na.rm = FALSE, use) { standardGeneric("var") }) + +#' @rdname var +#' @export +setGeneric("variance", function(x) { standardGeneric("variance") }) + +#' @rdname var_pop +#' @export +setGeneric("var_pop", function(x) { standardGeneric("var_pop") }) + +#' @rdname var_samp +#' @export +setGeneric("var_samp", function(x) { standardGeneric("var_samp") }) + +#' @rdname weekofyear +#' @export +setGeneric("weekofyear", function(x) { standardGeneric("weekofyear") }) + +#' @rdname year +#' @export +setGeneric("year", function(x) { standardGeneric("year") }) + + #' @rdname glm #' @export setGeneric("glm") +#' @rdname predict +#' @export +setGeneric("predict", function(object, ...) { standardGeneric("predict") }) + #' @rdname rbind #' @export setGeneric("rbind", signature = "...") + +#' @rdname as.data.frame +#' @export +setGeneric("as.data.frame") + +#' @rdname attach +#' @export +setGeneric("attach") + +#' @rdname with +#' @export +setGeneric("with") diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index 576ac72f40fc..23b49aebda05 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -68,7 +68,7 @@ setMethod("count", dataFrame(callJMethod(x@sgd, "count")) }) -#' Agg +#' summarize #' #' Aggregates on the entire DataFrame without groups. #' The resulting DataFrame will also contain the grouping columns. 
@@ -78,11 +78,14 @@ setMethod("count", #' #' @param x a GroupedData #' @return a DataFrame -#' @rdname agg +#' @rdname summarize +#' @name agg +#' @family agg_funcs #' @examples #' \dontrun{ #' df2 <- agg(df, age = "sum") # new column name will be created as 'SUM(age#0)' -#' df2 <- agg(df, ageSum = sum(df$age)) # Creates a new column named ageSum +#' df3 <- agg(df, ageSum = sum(df$age)) # Creates a new column named ageSum +#' df4 <- summarize(df, ageSum = max(df$age)) #' } setMethod("agg", signature(x = "GroupedData"), @@ -102,29 +105,32 @@ setMethod("agg", } } jcols <- lapply(cols, function(c) { c@jc }) - sdf <- callJMethod(x@sgd, "agg", jcols[[1]], listToSeq(jcols[-1])) + sdf <- callJMethod(x@sgd, "agg", jcols[[1]], jcols[-1]) } else { stop("agg can only support Column or character") } dataFrame(sdf) }) -#' @rdname agg -#' @aliases agg +#' @rdname summarize +#' @name summarize setMethod("summarize", signature(x = "GroupedData"), function(x, ...) { agg(x, ...) }) -# sum/mean/avg/min/max -methods <- c("sum", "mean", "avg", "min", "max") +# Aggregate Functions by name +methods <- c("avg", "max", "mean", "min", "sum") + +# These are not exposed on GroupedData: "kurtosis", "skewness", "stddev", "stddev_samp", "stddev_pop", +# "variance", "var_samp", "var_pop" createMethod <- function(name) { setMethod(name, signature(x = "GroupedData"), function(x, ...) { - sdf <- callJMethod(x@sgd, name, toSeq(...)) + sdf <- callJMethod(x@sgd, name, list(...)) dataFrame(sdf) }) } diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index b524d1fd8749..8d3b4388ae57 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -27,11 +27,17 @@ setClass("PipelineModel", representation(model = "jobj")) #' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package. #' #' @param formula A symbolic description of the model to be fitted. Currently only a few formula -#' operators are supported, including '~', '+', '-', and '.'. +#' operators are supported, including '~', '.', ':', '+', and '-'. #' @param data DataFrame for training #' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg. #' @param lambda Regularization parameter #' @param alpha Elastic-net mixing parameter (see glmnet's documentation for details) +#' @param standardize Whether to standardize features before training +#' @param solver The solver algorithm used for optimization, this can be "l-bfgs", "normal" and +#' "auto". "l-bfgs" denotes Limited-memory BFGS which is a limited-memory +#' quasi-Newton optimization method. "normal" denotes using Normal Equation as an +#' analytical solution to the linear regression problem. The default value is "auto" +#' which means that the solver algorithm is selected automatically. 
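+#'               As an illustrative sketch, \code{glm(Sepal_Length ~ Sepal_Width, df,
+#'               family = "gaussian", solver = "normal")} requests the normal-equation
+#'               solver; per the summary() documentation below, coefficient standard errors
+#'               are only reported for models fitted this way.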
#' @return a fitted MLlib model
#' @rdname glm
#' @export
#' @examples
#'\dontrun{
#' sc <- sparkR.init()
#' sqlContext <- sparkRSQL.init(sc)
#' data(iris)
#' df <- createDataFrame(sqlContext, iris)
-#' model <- glm(Sepal_Length ~ Sepal_Width, df)
+#' model <- glm(Sepal_Length ~ Sepal_Width, df, family="gaussian")
+#' summary(model)
#'}
setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFrame"),
-          function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0) {
+          function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0,
+                   standardize = TRUE, solver = "auto") {
            family <- match.arg(family)
+            formula <- paste(deparse(formula), collapse="")
            model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
-                                 "fitRModelFormula", deparse(formula), data@sdf, family, lambda,
-                                 alpha)
+                                 "fitRModelFormula", formula, data@sdf, family, lambda,
+                                 alpha, standardize, solver)
            return(new("PipelineModel", model = model))
          })
@@ -56,10 +65,10 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFram
#'
#' Makes predictions from a model produced by glm(), similarly to R's predict().
#'
-#' @param model A fitted MLlib model
+#' @param object A fitted MLlib model
#' @param newData DataFrame for testing
#' @return DataFrame containing predicted values
-#' @rdname glm
+#' @rdname predict
#' @export
#' @examples
#'\dontrun{
@@ -76,24 +85,44 @@ setMethod("predict", signature(object = "PipelineModel"),
#'
#' Returns the summary of a model produced by glm(), similarly to R's summary().
#'
-#' @param model A fitted MLlib model
-#' @return a list with a 'coefficient' component, which is the matrix of coefficients. See
-#'         summary.glm for more information.
-#' @rdname glm
+#' @param object A fitted MLlib model
+#' @return a list with 'devianceResiduals' and 'coefficients' components for gaussian family
+#'         or a list with 'coefficients' component for binomial family. \cr
+#'         For gaussian family: the 'devianceResiduals' gives the min/max deviance residuals
+#'         of the estimation; the 'coefficients' gives the estimated coefficients and their
+#'         estimated standard errors, t values and p-values. (These are only available when
+#'         the model is fitted by the normal solver.) \cr
+#'         For binomial family: the 'coefficients' gives the estimated coefficients.
+#'         See summary.glm for more information. \cr
+#' @rdname summary
#' @export
#' @examples
#'\dontrun{
#' model <- glm(y ~ x, trainingData)
#' summary(model)
#'}
-setMethod("summary", signature(x = "PipelineModel"),
-          function(x, ...) {
+setMethod("summary", signature(object = "PipelineModel"),
+          function(object, ...)
{ + modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getModelName", object@model) features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getModelFeatures", x@model) - weights <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getModelWeights", x@model) - coefficients <- as.matrix(unlist(weights)) - colnames(coefficients) <- c("Estimate") - rownames(coefficients) <- unlist(features) - return(list(coefficients = coefficients)) + "getModelFeatures", object@model) + coefficients <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getModelCoefficients", object@model) + if (modelName == "LinearRegressionModel") { + devianceResiduals <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getModelDevianceResiduals", object@model) + devianceResiduals <- matrix(devianceResiduals, nrow = 1) + colnames(devianceResiduals) <- c("Min", "Max") + rownames(devianceResiduals) <- rep("", times = 1) + coefficients <- matrix(coefficients, ncol = 4) + colnames(coefficients) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)") + rownames(coefficients) <- unlist(features) + return(list(devianceResiduals = devianceResiduals, coefficients = coefficients)) + } else { + coefficients <- as.matrix(unlist(coefficients)) + colnames(coefficients) <- c("Estimate") + rownames(coefficients) <- unlist(features) + return(list(coefficients = coefficients)) + } }) diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index 199c3fd6ab1b..334c11d2f89a 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -21,23 +21,24 @@ NULL ############ Actions and Transformations ############ -# Look up elements of a key in an RDD -# -# @description -# \code{lookup} returns a list of values in this RDD for key key. -# -# @param x The RDD to collect -# @param key The key to look up for -# @return a list of values in this RDD for key key -# @examples -#\dontrun{ -# sc <- sparkR.init() -# pairs <- list(c(1, 1), c(2, 2), c(1, 3)) -# rdd <- parallelize(sc, pairs) -# lookup(rdd, 1) # list(1, 3) -#} -# @rdname lookup -# @aliases lookup,RDD-method +#' Look up elements of a key in an RDD +#' +#' @description +#' \code{lookup} returns a list of values in this RDD for key key. +#' +#' @param x The RDD to collect +#' @param key The key to look up for +#' @return a list of values in this RDD for key key +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(c(1, 1), c(2, 2), c(1, 3)) +#' rdd <- parallelize(sc, pairs) +#' lookup(rdd, 1) # list(1, 3) +#'} +#' @rdname lookup +#' @aliases lookup,RDD-method +#' @noRd setMethod("lookup", signature(x = "RDD", key = "ANY"), function(x, key) { @@ -49,21 +50,22 @@ setMethod("lookup", collect(valsRDD) }) -# Count the number of elements for each key, and return the result to the -# master as lists of (key, count) pairs. -# -# Same as countByKey in Spark. -# -# @param x The RDD to count keys. -# @return list of (key, count) pairs, where count is number of each key in rdd. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(c("a", 1), c("b", 1), c("a", 1))) -# countByKey(rdd) # ("a", 2L), ("b", 1L) -#} -# @rdname countByKey -# @aliases countByKey,RDD-method +#' Count the number of elements for each key, and return the result to the +#' master as lists of (key, count) pairs. +#' +#' Same as countByKey in Spark. +#' +#' @param x The RDD to count keys. +#' @return list of (key, count) pairs, where count is number of each key in rdd. 
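[Editor's sketch] What callers can expect from the reworked summary() above, assuming `model` was fitted with family = "gaussian" and solver = "normal" as in the previous sketch.

    s <- summary(model)
    s$devianceResiduals   # 1 x 2 matrix with columns "Min" and "Max"
    s$coefficients        # columns "Estimate", "Std. Error", "t value", "Pr(>|t|)", one row per feature
    # For a binomial model only s$coefficients is returned, with a single "Estimate" column.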
+#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(c("a", 1), c("b", 1), c("a", 1))) +#' countByKey(rdd) # ("a", 2L), ("b", 1L) +#'} +#' @rdname countByKey +#' @aliases countByKey,RDD-method +#' @noRd setMethod("countByKey", signature(x = "RDD"), function(x) { @@ -71,17 +73,18 @@ setMethod("countByKey", countByValue(keys) }) -# Return an RDD with the keys of each tuple. -# -# @param x The RDD from which the keys of each tuple is returned. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) -# collect(keys(rdd)) # list(1, 3) -#} -# @rdname keys -# @aliases keys,RDD +#' Return an RDD with the keys of each tuple. +#' +#' @param x The RDD from which the keys of each tuple is returned. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) +#' collect(keys(rdd)) # list(1, 3) +#'} +#' @rdname keys +#' @aliases keys,RDD +#' @noRd setMethod("keys", signature(x = "RDD"), function(x) { @@ -91,17 +94,18 @@ setMethod("keys", lapply(x, func) }) -# Return an RDD with the values of each tuple. -# -# @param x The RDD from which the values of each tuple is returned. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) -# collect(values(rdd)) # list(2, 4) -#} -# @rdname values -# @aliases values,RDD +#' Return an RDD with the values of each tuple. +#' +#' @param x The RDD from which the values of each tuple is returned. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) +#' collect(values(rdd)) # list(2, 4) +#'} +#' @rdname values +#' @aliases values,RDD +#' @noRd setMethod("values", signature(x = "RDD"), function(x) { @@ -111,23 +115,24 @@ setMethod("values", lapply(x, func) }) -# Applies a function to all values of the elements, without modifying the keys. -# -# The same as `mapValues()' in Spark. -# -# @param X The RDD to apply the transformation. -# @param FUN the transformation to apply on the value of each element. -# @return a new RDD created by the transformation. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# makePairs <- lapply(rdd, function(x) { list(x, x) }) -# collect(mapValues(makePairs, function(x) { x * 2) }) -# Output: list(list(1,2), list(2,4), list(3,6), ...) -#} -# @rdname mapValues -# @aliases mapValues,RDD,function-method +#' Applies a function to all values of the elements, without modifying the keys. +#' +#' The same as `mapValues()' in Spark. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on the value of each element. +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' makePairs <- lapply(rdd, function(x) { list(x, x) }) +#' collect(mapValues(makePairs, function(x) { x * 2) }) +#' Output: list(list(1,2), list(2,4), list(3,6), ...) +#'} +#' @rdname mapValues +#' @aliases mapValues,RDD,function-method +#' @noRd setMethod("mapValues", signature(X = "RDD", FUN = "function"), function(X, FUN) { @@ -137,23 +142,24 @@ setMethod("mapValues", lapply(X, func) }) -# Pass each value in the key-value pair RDD through a flatMap function without -# changing the keys; this also retains the original RDD's partitioning. -# -# The same as 'flatMapValues()' in Spark. -# -# @param X The RDD to apply the transformation. -# @param FUN the transformation to apply on the value of each element. 
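[Editor's sketch] Since the RDD documentation above is switched to @noRd, these helpers are no longer part of the exported API; a small sketch of their behaviour, using ::: access purely for illustration.

    sc  <- sparkR.init()
    rdd <- SparkR:::parallelize(sc, list(list(1, 2), list(3, 4)))
    SparkR:::collect(SparkR:::keys(rdd))      # list(1, 3)
    SparkR:::collect(SparkR:::values(rdd))    # list(2, 4)
    doubled <- SparkR:::mapValues(rdd, function(v) { v * 2 })
    SparkR:::collect(doubled)                 # list(list(1, 4), list(3, 8))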
-# @return a new RDD created by the transformation. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(list(1, c(1,2)), list(2, c(3,4)))) -# collect(flatMapValues(rdd, function(x) { x })) -# Output: list(list(1,1), list(1,2), list(2,3), list(2,4)) -#} -# @rdname flatMapValues -# @aliases flatMapValues,RDD,function-method +#' Pass each value in the key-value pair RDD through a flatMap function without +#' changing the keys; this also retains the original RDD's partitioning. +#' +#' The same as 'flatMapValues()' in Spark. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on the value of each element. +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, c(1,2)), list(2, c(3,4)))) +#' collect(flatMapValues(rdd, function(x) { x })) +#' Output: list(list(1,1), list(1,2), list(2,3), list(2,4)) +#'} +#' @rdname flatMapValues +#' @aliases flatMapValues,RDD,function-method +#' @noRd setMethod("flatMapValues", signature(X = "RDD", FUN = "function"), function(X, FUN) { @@ -165,38 +171,34 @@ setMethod("flatMapValues", ############ Shuffle Functions ############ -# Partition an RDD by key -# -# This function operates on RDDs where every element is of the form list(K, V) or c(K, V). -# For each element of this RDD, the partitioner is used to compute a hash -# function and the RDD is partitioned using this hash value. -# -# @param x The RDD to partition. Should be an RDD where each element is -# list(K, V) or c(K, V). -# @param numPartitions Number of partitions to create. -# @param ... Other optional arguments to partitionBy. -# -# @param partitionFunc The partition function to use. Uses a default hashCode -# function if not provided -# @return An RDD partitioned using the specified partitioner. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) -# rdd <- parallelize(sc, pairs) -# parts <- partitionBy(rdd, 2L) -# collectPartition(parts, 0L) # First partition should contain list(1, 2) and list(1, 4) -#} -# @rdname partitionBy -# @aliases partitionBy,RDD,integer-method +#' Partition an RDD by key +#' +#' This function operates on RDDs where every element is of the form list(K, V) or c(K, V). +#' For each element of this RDD, the partitioner is used to compute a hash +#' function and the RDD is partitioned using this hash value. +#' +#' @param x The RDD to partition. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param numPartitions Number of partitions to create. +#' @param ... Other optional arguments to partitionBy. +#' +#' @param partitionFunc The partition function to use. Uses a default hashCode +#' function if not provided +#' @return An RDD partitioned using the specified partitioner. 
+#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' parts <- partitionBy(rdd, 2L) +#' collectPartition(parts, 0L) # First partition should contain list(1, 2) and list(1, 4) +#'} +#' @rdname partitionBy +#' @aliases partitionBy,RDD,integer-method +#' @noRd setMethod("partitionBy", signature(x = "RDD", numPartitions = "numeric"), function(x, numPartitions, partitionFunc = hashCode) { - - #if (missing(partitionFunc)) { - # partitionFunc <- hashCode - #} - partitionFunc <- cleanClosure(partitionFunc) serializedHashFuncBytes <- serialize(partitionFunc, connection = NULL) @@ -233,27 +235,28 @@ setMethod("partitionBy", RDD(r, serializedMode = "byte") }) -# Group values by key -# -# This function operates on RDDs where every element is of the form list(K, V) or c(K, V). -# and group values for each key in the RDD into a single sequence. -# -# @param x The RDD to group. Should be an RDD where each element is -# list(K, V) or c(K, V). -# @param numPartitions Number of partitions to create. -# @return An RDD where each element is list(K, list(V)) -# @seealso reduceByKey -# @examples -#\dontrun{ -# sc <- sparkR.init() -# pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) -# rdd <- parallelize(sc, pairs) -# parts <- groupByKey(rdd, 2L) -# grouped <- collect(parts) -# grouped[[1]] # Should be a list(1, list(2, 4)) -#} -# @rdname groupByKey -# @aliases groupByKey,RDD,integer-method +#' Group values by key +#' +#' This function operates on RDDs where every element is of the form list(K, V) or c(K, V). +#' and group values for each key in the RDD into a single sequence. +#' +#' @param x The RDD to group. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param numPartitions Number of partitions to create. +#' @return An RDD where each element is list(K, list(V)) +#' @seealso reduceByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' parts <- groupByKey(rdd, 2L) +#' grouped <- collect(parts) +#' grouped[[1]] # Should be a list(1, list(2, 4)) +#'} +#' @rdname groupByKey +#' @aliases groupByKey,RDD,integer-method +#' @noRd setMethod("groupByKey", signature(x = "RDD", numPartitions = "numeric"), function(x, numPartitions) { @@ -291,28 +294,29 @@ setMethod("groupByKey", lapplyPartition(shuffled, groupVals) }) -# Merge values by key -# -# This function operates on RDDs where every element is of the form list(K, V) or c(K, V). -# and merges the values for each key using an associative reduce function. -# -# @param x The RDD to reduce by key. Should be an RDD where each element is -# list(K, V) or c(K, V). -# @param combineFunc The associative reduce function to use. -# @param numPartitions Number of partitions to create. -# @return An RDD where each element is list(K, V') where V' is the merged -# value -# @examples -#\dontrun{ -# sc <- sparkR.init() -# pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) -# rdd <- parallelize(sc, pairs) -# parts <- reduceByKey(rdd, "+", 2L) -# reduced <- collect(parts) -# reduced[[1]] # Should be a list(1, 6) -#} -# @rdname reduceByKey -# @aliases reduceByKey,RDD,integer-method +#' Merge values by key +#' +#' This function operates on RDDs where every element is of the form list(K, V) or c(K, V). +#' and merges the values for each key using an associative reduce function. +#' +#' @param x The RDD to reduce by key. 
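[Editor's sketch] The shuffle examples documented above in runnable form (internal API, hence the ::: access); the pairs, the 2L partition count and the expected results are taken from the roxygen examples.

    pairs <- SparkR:::parallelize(sc, list(list(1, 2), list(1.1, 3), list(1, 4)))
    parts <- SparkR:::partitionBy(pairs, 2L)                   # hash-partition by key
    SparkR:::collectPartition(parts, 0L)                       # elements with key 1 land together
    grouped <- SparkR:::collect(SparkR:::groupByKey(pairs, 2L))
    grouped[[1]]                                               # list(1, list(2, 4))
    reduced <- SparkR:::collect(SparkR:::reduceByKey(pairs, "+", 2L))
    reduced[[1]]                                               # list(1, 6)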
Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param combineFunc The associative reduce function to use. +#' @param numPartitions Number of partitions to create. +#' @return An RDD where each element is list(K, V') where V' is the merged +#' value +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' parts <- reduceByKey(rdd, "+", 2L) +#' reduced <- collect(parts) +#' reduced[[1]] # Should be a list(1, 6) +#'} +#' @rdname reduceByKey +#' @aliases reduceByKey,RDD,integer-method +#' @noRd setMethod("reduceByKey", signature(x = "RDD", combineFunc = "ANY", numPartitions = "numeric"), function(x, combineFunc, numPartitions) { @@ -332,27 +336,28 @@ setMethod("reduceByKey", lapplyPartition(shuffled, reduceVals) }) -# Merge values by key locally -# -# This function operates on RDDs where every element is of the form list(K, V) or c(K, V). -# and merges the values for each key using an associative reduce function, but return the -# results immediately to the driver as an R list. -# -# @param x The RDD to reduce by key. Should be an RDD where each element is -# list(K, V) or c(K, V). -# @param combineFunc The associative reduce function to use. -# @return A list of elements of type list(K, V') where V' is the merged value for each key -# @seealso reduceByKey -# @examples -#\dontrun{ -# sc <- sparkR.init() -# pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) -# rdd <- parallelize(sc, pairs) -# reduced <- reduceByKeyLocally(rdd, "+") -# reduced # list(list(1, 6), list(1.1, 3)) -#} -# @rdname reduceByKeyLocally -# @aliases reduceByKeyLocally,RDD,integer-method +#' Merge values by key locally +#' +#' This function operates on RDDs where every element is of the form list(K, V) or c(K, V). +#' and merges the values for each key using an associative reduce function, but return the +#' results immediately to the driver as an R list. +#' +#' @param x The RDD to reduce by key. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param combineFunc The associative reduce function to use. +#' @return A list of elements of type list(K, V') where V' is the merged value for each key +#' @seealso reduceByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' reduced <- reduceByKeyLocally(rdd, "+") +#' reduced # list(list(1, 6), list(1.1, 3)) +#'} +#' @rdname reduceByKeyLocally +#' @aliases reduceByKeyLocally,RDD,integer-method +#' @noRd setMethod("reduceByKeyLocally", signature(x = "RDD", combineFunc = "ANY"), function(x, combineFunc) { @@ -384,41 +389,40 @@ setMethod("reduceByKeyLocally", convertEnvsToList(merged[[1]], merged[[2]]) }) -# Combine values by key -# -# Generic function to combine the elements for each key using a custom set of -# aggregation functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], -# for a "combined type" C. Note that V and C can be different -- for example, one -# might group an RDD of type (Int, Int) into an RDD of type (Int, Seq[Int]). - -# Users provide three functions: -# \itemize{ -# \item createCombiner, which turns a V into a C (e.g., creates a one-element list) -# \item mergeValue, to merge a V into a C (e.g., adds it to the end of a list) - -# \item mergeCombiners, to combine two C's into a single one (e.g., concatentates -# two lists). -# } -# -# @param x The RDD to combine. Should be an RDD where each element is -# list(K, V) or c(K, V). 
-# @param createCombiner Create a combiner (C) given a value (V) -# @param mergeValue Merge the given value (V) with an existing combiner (C) -# @param mergeCombiners Merge two combiners and return a new combiner -# @param numPartitions Number of partitions to create. -# @return An RDD where each element is list(K, C) where C is the combined type -# -# @seealso groupByKey, reduceByKey -# @examples -#\dontrun{ -# sc <- sparkR.init() -# pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) -# rdd <- parallelize(sc, pairs) -# parts <- combineByKey(rdd, function(x) { x }, "+", "+", 2L) -# combined <- collect(parts) -# combined[[1]] # Should be a list(1, 6) -#} -# @rdname combineByKey -# @aliases combineByKey,RDD,ANY,ANY,ANY,integer-method +#' Combine values by key +#' +#' Generic function to combine the elements for each key using a custom set of +#' aggregation functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], +#' for a "combined type" C. Note that V and C can be different -- for example, one +#' might group an RDD of type (Int, Int) into an RDD of type (Int, Seq[Int]). +#' Users provide three functions: +#' \itemize{ +#' \item createCombiner, which turns a V into a C (e.g., creates a one-element list) +#' \item mergeValue, to merge a V into a C (e.g., adds it to the end of a list) - +#' \item mergeCombiners, to combine two C's into a single one (e.g., concatentates +#' two lists). +#' } +#' +#' @param x The RDD to combine. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param createCombiner Create a combiner (C) given a value (V) +#' @param mergeValue Merge the given value (V) with an existing combiner (C) +#' @param mergeCombiners Merge two combiners and return a new combiner +#' @param numPartitions Number of partitions to create. +#' @return An RDD where each element is list(K, C) where C is the combined type +#' @seealso groupByKey, reduceByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' parts <- combineByKey(rdd, function(x) { x }, "+", "+", 2L) +#' combined <- collect(parts) +#' combined[[1]] # Should be a list(1, 6) +#'} +#' @rdname combineByKey +#' @aliases combineByKey,RDD,ANY,ANY,ANY,integer-method +#' @noRd setMethod("combineByKey", signature(x = "RDD", createCombiner = "ANY", mergeValue = "ANY", mergeCombiners = "ANY", numPartitions = "numeric"), @@ -450,36 +454,37 @@ setMethod("combineByKey", lapplyPartition(shuffled, mergeAfterShuffle) }) -# Aggregate a pair RDD by each key. -# -# Aggregate the values of each key in an RDD, using given combine functions -# and a neutral "zero value". This function can return a different result type, -# U, than the type of the values in this RDD, V. Thus, we need one operation -# for merging a V into a U and one operation for merging two U's, The former -# operation is used for merging values within a partition, and the latter is -# used for merging values between partitions. To avoid memory allocation, both -# of these functions are allowed to modify and return their first argument -# instead of creating a new U. -# -# @param x An RDD. -# @param zeroValue A neutral "zero value". -# @param seqOp A function to aggregate the values of each key. It may return -# a different result type from the type of the values. -# @param combOp A function to aggregate results of seqOp. -# @return An RDD containing the aggregation result. 
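[Editor's sketch] Beyond the summation shown in the roxygen example, combineByKey() above is typically used when the combined type differs from the value type; a hedged sketch computing a per-key mean with illustrative data (internal API, ::: for illustration only).

    pairs <- SparkR:::parallelize(sc, list(list("a", 1), list("a", 3), list("b", 5)))
    createCombiner <- function(v) { c(v, 1) }                       # (sum, count)
    mergeValue     <- function(acc, v) { c(acc[1] + v, acc[2] + 1) }
    mergeCombiners <- function(x, y) { c(x[1] + y[1], x[2] + y[2]) }
    sumCounts <- SparkR:::combineByKey(pairs, createCombiner, mergeValue, mergeCombiners, 2L)
    means <- SparkR:::mapValues(sumCounts, function(p) { p[1] / p[2] })
    SparkR:::collect(means)                                         # list(list("a", 2), list("b", 5))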
-# @seealso foldByKey, combineByKey -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) -# zeroValue <- list(0, 0) -# seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } -# combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } -# aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) -# # list(list(1, list(3, 2)), list(2, list(7, 2))) -#} -# @rdname aggregateByKey -# @aliases aggregateByKey,RDD,ANY,ANY,ANY,integer-method +#' Aggregate a pair RDD by each key. +#' +#' Aggregate the values of each key in an RDD, using given combine functions +#' and a neutral "zero value". This function can return a different result type, +#' U, than the type of the values in this RDD, V. Thus, we need one operation +#' for merging a V into a U and one operation for merging two U's, The former +#' operation is used for merging values within a partition, and the latter is +#' used for merging values between partitions. To avoid memory allocation, both +#' of these functions are allowed to modify and return their first argument +#' instead of creating a new U. +#' +#' @param x An RDD. +#' @param zeroValue A neutral "zero value". +#' @param seqOp A function to aggregate the values of each key. It may return +#' a different result type from the type of the values. +#' @param combOp A function to aggregate results of seqOp. +#' @return An RDD containing the aggregation result. +#' @seealso foldByKey, combineByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) +#' zeroValue <- list(0, 0) +#' seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } +#' combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } +#' aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) +#' # list(list(1, list(3, 2)), list(2, list(7, 2))) +#'} +#' @rdname aggregateByKey +#' @aliases aggregateByKey,RDD,ANY,ANY,ANY,integer-method +#' @noRd setMethod("aggregateByKey", signature(x = "RDD", zeroValue = "ANY", seqOp = "ANY", combOp = "ANY", numPartitions = "numeric"), @@ -491,26 +496,27 @@ setMethod("aggregateByKey", combineByKey(x, createCombiner, seqOp, combOp, numPartitions) }) -# Fold a pair RDD by each key. -# -# Aggregate the values of each key in an RDD, using an associative function "func" -# and a neutral "zero value" which may be added to the result an arbitrary -# number of times, and must not change the result (e.g., 0 for addition, or -# 1 for multiplication.). -# -# @param x An RDD. -# @param zeroValue A neutral "zero value". -# @param func An associative function for folding values of each key. -# @return An RDD containing the aggregation result. -# @seealso aggregateByKey, combineByKey -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) -# foldByKey(rdd, 0, "+", 2L) # list(list(1, 3), list(2, 7)) -#} -# @rdname foldByKey -# @aliases foldByKey,RDD,ANY,ANY,integer-method +#' Fold a pair RDD by each key. +#' +#' Aggregate the values of each key in an RDD, using an associative function "func" +#' and a neutral "zero value" which may be added to the result an arbitrary +#' number of times, and must not change the result (e.g., 0 for addition, or +#' 1 for multiplication.). +#' +#' @param x An RDD. +#' @param zeroValue A neutral "zero value". +#' @param func An associative function for folding values of each key. +#' @return An RDD containing the aggregation result. 
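[Editor's sketch] The aggregateByKey()/foldByKey() examples above, lifted into runnable form (internal API, ::: for illustration only).

    rdd <- SparkR:::parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4)))
    zeroValue <- list(0, 0)
    seqOp  <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) }
    combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) }
    SparkR:::collect(SparkR:::aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L))
    # list(list(1, list(3, 2)), list(2, list(7, 2)))
    SparkR:::collect(SparkR:::foldByKey(rdd, 0, "+", 2L))
    # list(list(1, 3), list(2, 7))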
+#' @seealso aggregateByKey, combineByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) +#' foldByKey(rdd, 0, "+", 2L) # list(list(1, 3), list(2, 7)) +#'} +#' @rdname foldByKey +#' @aliases foldByKey,RDD,ANY,ANY,integer-method +#' @noRd setMethod("foldByKey", signature(x = "RDD", zeroValue = "ANY", func = "ANY", numPartitions = "numeric"), @@ -520,28 +526,29 @@ setMethod("foldByKey", ############ Binary Functions ############# -# Join two RDDs -# -# @description -# \code{join} This function joins two RDDs where every element is of the form list(K, V). -# The key types of the two RDDs should be the same. -# -# @param x An RDD to be joined. Should be an RDD where each element is -# list(K, V). -# @param y An RDD to be joined. Should be an RDD where each element is -# list(K, V). -# @param numPartitions Number of partitions to create. -# @return a new RDD containing all pairs of elements with matching keys in -# two input RDDs. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) -# rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) -# join(rdd1, rdd2, 2L) # list(list(1, list(1, 2)), list(1, list(1, 3)) -#} -# @rdname join-methods -# @aliases join,RDD,RDD-method +#' Join two RDDs +#' +#' @description +#' \code{join} This function joins two RDDs where every element is of the form list(K, V). +#' The key types of the two RDDs should be the same. +#' +#' @param x An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param y An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param numPartitions Number of partitions to create. +#' @return a new RDD containing all pairs of elements with matching keys in +#' two input RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) +#' join(rdd1, rdd2, 2L) # list(list(1, list(1, 2)), list(1, list(1, 3)) +#'} +#' @rdname join-methods +#' @aliases join,RDD,RDD-method +#' @noRd setMethod("join", signature(x = "RDD", y = "RDD"), function(x, y, numPartitions) { @@ -556,30 +563,31 @@ setMethod("join", doJoin) }) -# Left outer join two RDDs -# -# @description -# \code{leftouterjoin} This function left-outer-joins two RDDs where every element is of -# the form list(K, V). The key types of the two RDDs should be the same. -# -# @param x An RDD to be joined. Should be an RDD where each element is -# list(K, V). -# @param y An RDD to be joined. Should be an RDD where each element is -# list(K, V). -# @param numPartitions Number of partitions to create. -# @return For each element (k, v) in x, the resulting RDD will either contain -# all pairs (k, (v, w)) for (k, w) in rdd2, or the pair (k, (v, NULL)) -# if no elements in rdd2 have key k. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) -# rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) -# leftOuterJoin(rdd1, rdd2, 2L) -# # list(list(1, list(1, 2)), list(1, list(1, 3)), list(2, list(4, NULL))) -#} -# @rdname join-methods -# @aliases leftOuterJoin,RDD,RDD-method +#' Left outer join two RDDs +#' +#' @description +#' \code{leftouterjoin} This function left-outer-joins two RDDs where every element is of +#' the form list(K, V). The key types of the two RDDs should be the same. +#' +#' @param x An RDD to be joined. 
Should be an RDD where each element is +#' list(K, V). +#' @param y An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param numPartitions Number of partitions to create. +#' @return For each element (k, v) in x, the resulting RDD will either contain +#' all pairs (k, (v, w)) for (k, w) in rdd2, or the pair (k, (v, NULL)) +#' if no elements in rdd2 have key k. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) +#' leftOuterJoin(rdd1, rdd2, 2L) +#' # list(list(1, list(1, 2)), list(1, list(1, 3)), list(2, list(4, NULL))) +#'} +#' @rdname join-methods +#' @aliases leftOuterJoin,RDD,RDD-method +#' @noRd setMethod("leftOuterJoin", signature(x = "RDD", y = "RDD", numPartitions = "numeric"), function(x, y, numPartitions) { @@ -593,30 +601,31 @@ setMethod("leftOuterJoin", joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin) }) -# Right outer join two RDDs -# -# @description -# \code{rightouterjoin} This function right-outer-joins two RDDs where every element is of -# the form list(K, V). The key types of the two RDDs should be the same. -# -# @param x An RDD to be joined. Should be an RDD where each element is -# list(K, V). -# @param y An RDD to be joined. Should be an RDD where each element is -# list(K, V). -# @param numPartitions Number of partitions to create. -# @return For each element (k, w) in y, the resulting RDD will either contain -# all pairs (k, (v, w)) for (k, v) in x, or the pair (k, (NULL, w)) -# if no elements in x have key k. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3))) -# rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) -# rightOuterJoin(rdd1, rdd2, 2L) -# # list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4))) -#} -# @rdname join-methods -# @aliases rightOuterJoin,RDD,RDD-method +#' Right outer join two RDDs +#' +#' @description +#' \code{rightouterjoin} This function right-outer-joins two RDDs where every element is of +#' the form list(K, V). The key types of the two RDDs should be the same. +#' +#' @param x An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param y An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param numPartitions Number of partitions to create. +#' @return For each element (k, w) in y, the resulting RDD will either contain +#' all pairs (k, (v, w)) for (k, v) in x, or the pair (k, (NULL, w)) +#' if no elements in x have key k. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3))) +#' rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' rightOuterJoin(rdd1, rdd2, 2L) +#' # list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4))) +#'} +#' @rdname join-methods +#' @aliases rightOuterJoin,RDD,RDD-method +#' @noRd setMethod("rightOuterJoin", signature(x = "RDD", y = "RDD", numPartitions = "numeric"), function(x, y, numPartitions) { @@ -630,33 +639,34 @@ setMethod("rightOuterJoin", joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin) }) -# Full outer join two RDDs -# -# @description -# \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of -# the form list(K, V). The key types of the two RDDs should be the same. -# -# @param x An RDD to be joined. Should be an RDD where each element is -# list(K, V). 
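[Editor's sketch] The join family documented above in runnable form; data and expected output follow the roxygen examples (internal API, ::: for illustration only).

    rdd1 <- SparkR:::parallelize(sc, list(list(1, 1), list(2, 4)))
    rdd2 <- SparkR:::parallelize(sc, list(list(1, 2), list(1, 3)))
    SparkR:::collect(SparkR:::join(rdd1, rdd2, 2L))
    # list(list(1, list(1, 2)), list(1, list(1, 3)))
    SparkR:::collect(SparkR:::leftOuterJoin(rdd1, rdd2, 2L))
    # list(list(1, list(1, 2)), list(1, list(1, 3)), list(2, list(4, NULL)))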
-# @param y An RDD to be joined. Should be an RDD where each element is -# list(K, V). -# @param numPartitions Number of partitions to create. -# @return For each element (k, v) in x and (k, w) in y, the resulting RDD -# will contain all pairs (k, (v, w)) for both (k, v) in x and -# (k, w) in y, or the pair (k, (NULL, w))/(k, (v, NULL)) if no elements -# in x/y have key k. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3), list(3, 3))) -# rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) -# fullOuterJoin(rdd1, rdd2, 2L) # list(list(1, list(2, 1)), -# # list(1, list(3, 1)), -# # list(2, list(NULL, 4))) -# # list(3, list(3, NULL)), -#} -# @rdname join-methods -# @aliases fullOuterJoin,RDD,RDD-method +#' Full outer join two RDDs +#' +#' @description +#' \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of +#' the form list(K, V). The key types of the two RDDs should be the same. +#' +#' @param x An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param y An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param numPartitions Number of partitions to create. +#' @return For each element (k, v) in x and (k, w) in y, the resulting RDD +#' will contain all pairs (k, (v, w)) for both (k, v) in x and +#' (k, w) in y, or the pair (k, (NULL, w))/(k, (v, NULL)) if no elements +#' in x/y have key k. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3), list(3, 3))) +#' rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' fullOuterJoin(rdd1, rdd2, 2L) # list(list(1, list(2, 1)), +#' # list(1, list(3, 1)), +#' # list(2, list(NULL, 4))) +#' # list(3, list(3, NULL)), +#'} +#' @rdname join-methods +#' @aliases fullOuterJoin,RDD,RDD-method +#' @noRd setMethod("fullOuterJoin", signature(x = "RDD", y = "RDD", numPartitions = "numeric"), function(x, y, numPartitions) { @@ -670,23 +680,24 @@ setMethod("fullOuterJoin", joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin) }) -# For each key k in several RDDs, return a resulting RDD that -# whose values are a list of values for the key in all RDDs. -# -# @param ... Several RDDs. -# @param numPartitions Number of partitions to create. -# @return a new RDD containing all pairs of elements with values in a list -# in all RDDs. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) -# rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) -# cogroup(rdd1, rdd2, numPartitions = 2L) -# # list(list(1, list(1, list(2, 3))), list(2, list(list(4), list())) -#} -# @rdname cogroup -# @aliases cogroup,RDD-method +#' For each key k in several RDDs, return a resulting RDD that +#' whose values are a list of values for the key in all RDDs. +#' +#' @param ... Several RDDs. +#' @param numPartitions Number of partitions to create. +#' @return a new RDD containing all pairs of elements with values in a list +#' in all RDDs. 
+#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) +#' cogroup(rdd1, rdd2, numPartitions = 2L) +#' # list(list(1, list(1, list(2, 3))), list(2, list(list(4), list())) +#'} +#' @rdname cogroup +#' @aliases cogroup,RDD-method +#' @noRd setMethod("cogroup", "RDD", function(..., numPartitions) { @@ -722,23 +733,24 @@ setMethod("cogroup", group.func) }) -# Sort a (k, v) pair RDD by k. -# -# @param x A (k, v) pair RDD to be sorted. -# @param ascending A flag to indicate whether the sorting is ascending or descending. -# @param numPartitions Number of partitions to create. -# @return An RDD where all (k, v) pair elements are sorted. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(list(3, 1), list(2, 2), list(1, 3))) -# collect(sortByKey(rdd)) # list (list(1, 3), list(2, 2), list(3, 1)) -#} -# @rdname sortByKey -# @aliases sortByKey,RDD,RDD-method +#' Sort a (k, v) pair RDD by k. +#' +#' @param x A (k, v) pair RDD to be sorted. +#' @param ascending A flag to indicate whether the sorting is ascending or descending. +#' @param numPartitions Number of partitions to create. +#' @return An RDD where all (k, v) pair elements are sorted. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(3, 1), list(2, 2), list(1, 3))) +#' collect(sortByKey(rdd)) # list (list(1, 3), list(2, 2), list(3, 1)) +#'} +#' @rdname sortByKey +#' @aliases sortByKey,RDD,RDD-method +#' @noRd setMethod("sortByKey", signature(x = "RDD"), - function(x, ascending = TRUE, numPartitions = SparkR:::numPartitions(x)) { + function(x, ascending = TRUE, numPartitions = SparkR:::getNumPartitions(x)) { rangeBounds <- list() if (numPartitions > 1) { @@ -784,28 +796,29 @@ setMethod("sortByKey", lapplyPartition(newRDD, partitionFunc) }) -# Subtract a pair RDD with another pair RDD. -# -# Return an RDD with the pairs from x whose keys are not in other. -# -# @param x An RDD. -# @param other An RDD. -# @param numPartitions Number of the partitions in the result RDD. -# @return An RDD with the pairs from x whose keys are not in other. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd1 <- parallelize(sc, list(list("a", 1), list("b", 4), -# list("b", 5), list("a", 2))) -# rdd2 <- parallelize(sc, list(list("a", 3), list("c", 1))) -# collect(subtractByKey(rdd1, rdd2)) -# # list(list("b", 4), list("b", 5)) -#} -# @rdname subtractByKey -# @aliases subtractByKey,RDD +#' Subtract a pair RDD with another pair RDD. +#' +#' Return an RDD with the pairs from x whose keys are not in other. +#' +#' @param x An RDD. +#' @param other An RDD. +#' @param numPartitions Number of the partitions in the result RDD. +#' @return An RDD with the pairs from x whose keys are not in other. 
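[Editor's sketch] cogroup() and sortByKey() above in runnable form; note that sortByKey()'s default partition count now comes from getNumPartitions() rather than the removed numPartitions() helper (internal API, ::: for illustration only).

    rdd1 <- SparkR:::parallelize(sc, list(list(1, 1), list(2, 4)))
    rdd2 <- SparkR:::parallelize(sc, list(list(1, 2), list(1, 3)))
    SparkR:::collect(SparkR:::cogroup(rdd1, rdd2, numPartitions = 2L))
    # per key, the values from rdd1 and rdd2 grouped into one list per input RDD
    rdd <- SparkR:::parallelize(sc, list(list(3, 1), list(2, 2), list(1, 3)))
    SparkR:::collect(SparkR:::sortByKey(rdd))
    # list(list(1, 3), list(2, 2), list(3, 1))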
+#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list("a", 1), list("b", 4), +#' list("b", 5), list("a", 2))) +#' rdd2 <- parallelize(sc, list(list("a", 3), list("c", 1))) +#' collect(subtractByKey(rdd1, rdd2)) +#' # list(list("b", 4), list("b", 5)) +#'} +#' @rdname subtractByKey +#' @aliases subtractByKey,RDD +#' @noRd setMethod("subtractByKey", signature(x = "RDD", other = "RDD"), - function(x, other, numPartitions = SparkR:::numPartitions(x)) { + function(x, other, numPartitions = SparkR:::getNumPartitions(x)) { filterFunction <- function(elem) { iters <- elem[[2]] (length(iters[[1]]) > 0) && (length(iters[[2]]) == 0) @@ -818,41 +831,42 @@ setMethod("subtractByKey", function (v) { v[[1]] }) }) -# Return a subset of this RDD sampled by key. -# -# @description -# \code{sampleByKey} Create a sample of this RDD using variable sampling rates -# for different keys as specified by fractions, a key to sampling rate map. -# -# @param x The RDD to sample elements by key, where each element is -# list(K, V) or c(K, V). -# @param withReplacement Sampling with replacement or not -# @param fraction The (rough) sample target fraction -# @param seed Randomness seed value -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:3000) -# pairs <- lapply(rdd, function(x) { if (x %% 3 == 0) list("a", x) -# else { if (x %% 3 == 1) list("b", x) else list("c", x) }}) -# fractions <- list(a = 0.2, b = 0.1, c = 0.3) -# sample <- sampleByKey(pairs, FALSE, fractions, 1618L) -# 100 < length(lookup(sample, "a")) && 300 > length(lookup(sample, "a")) # TRUE -# 50 < length(lookup(sample, "b")) && 150 > length(lookup(sample, "b")) # TRUE -# 200 < length(lookup(sample, "c")) && 400 > length(lookup(sample, "c")) # TRUE -# lookup(sample, "a")[which.min(lookup(sample, "a"))] >= 0 # TRUE -# lookup(sample, "a")[which.max(lookup(sample, "a"))] <= 2000 # TRUE -# lookup(sample, "b")[which.min(lookup(sample, "b"))] >= 0 # TRUE -# lookup(sample, "b")[which.max(lookup(sample, "b"))] <= 2000 # TRUE -# lookup(sample, "c")[which.min(lookup(sample, "c"))] >= 0 # TRUE -# lookup(sample, "c")[which.max(lookup(sample, "c"))] <= 2000 # TRUE -# fractions <- list(a = 0.2, b = 0.1, c = 0.3, d = 0.4) -# sample <- sampleByKey(pairs, FALSE, fractions, 1618L) # Key "d" will be ignored -# fractions <- list(a = 0.2, b = 0.1) -# sample <- sampleByKey(pairs, FALSE, fractions, 1618L) # KeyError: "c" -#} -# @rdname sampleByKey -# @aliases sampleByKey,RDD-method +#' Return a subset of this RDD sampled by key. +#' +#' @description +#' \code{sampleByKey} Create a sample of this RDD using variable sampling rates +#' for different keys as specified by fractions, a key to sampling rate map. +#' +#' @param x The RDD to sample elements by key, where each element is +#' list(K, V) or c(K, V). 
+#' @param withReplacement Sampling with replacement or not +#' @param fraction The (rough) sample target fraction +#' @param seed Randomness seed value +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:3000) +#' pairs <- lapply(rdd, function(x) { if (x %% 3 == 0) list("a", x) +#' else { if (x %% 3 == 1) list("b", x) else list("c", x) }}) +#' fractions <- list(a = 0.2, b = 0.1, c = 0.3) +#' sample <- sampleByKey(pairs, FALSE, fractions, 1618L) +#' 100 < length(lookup(sample, "a")) && 300 > length(lookup(sample, "a")) # TRUE +#' 50 < length(lookup(sample, "b")) && 150 > length(lookup(sample, "b")) # TRUE +#' 200 < length(lookup(sample, "c")) && 400 > length(lookup(sample, "c")) # TRUE +#' lookup(sample, "a")[which.min(lookup(sample, "a"))] >= 0 # TRUE +#' lookup(sample, "a")[which.max(lookup(sample, "a"))] <= 2000 # TRUE +#' lookup(sample, "b")[which.min(lookup(sample, "b"))] >= 0 # TRUE +#' lookup(sample, "b")[which.max(lookup(sample, "b"))] <= 2000 # TRUE +#' lookup(sample, "c")[which.min(lookup(sample, "c"))] >= 0 # TRUE +#' lookup(sample, "c")[which.max(lookup(sample, "c"))] <= 2000 # TRUE +#' fractions <- list(a = 0.2, b = 0.1, c = 0.3, d = 0.4) +#' sample <- sampleByKey(pairs, FALSE, fractions, 1618L) # Key "d" will be ignored +#' fractions <- list(a = 0.2, b = 0.1) +#' sample <- sampleByKey(pairs, FALSE, fractions, 1618L) # KeyError: "c" +#'} +#' @rdname sampleByKey +#' @aliases sampleByKey,RDD-method +#' @noRd setMethod("sampleByKey", signature(x = "RDD", withReplacement = "logical", fractions = "vector", seed = "integer"), diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index 79c744ef29c2..c6ddb562270b 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -56,7 +56,7 @@ structType.structField <- function(x, ...) 
{ }) stObj <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createStructType", - listToSeq(sfObjList)) + sfObjList) structType(stObj) } @@ -114,6 +114,66 @@ structField.jobj <- function(x) { obj } +checkType <- function(type) { + if (!is.null(PRIMITIVE_TYPES[[type]])) { + return() + } else { + # Check complex types + firstChar <- substr(type, 1, 1) + switch (firstChar, + a = { + # Array type + m <- regexec("^array<(.+)>$", type) + matchedStrings <- regmatches(type, m) + if (length(matchedStrings[[1]]) >= 2) { + elemType <- matchedStrings[[1]][2] + checkType(elemType) + return() + } + }, + m = { + # Map type + m <- regexec("^map<(.+),(.+)>$", type) + matchedStrings <- regmatches(type, m) + if (length(matchedStrings[[1]]) >= 3) { + keyType <- matchedStrings[[1]][2] + if (keyType != "string" && keyType != "character") { + stop("Key type in a map must be string or character") + } + valueType <- matchedStrings[[1]][3] + checkType(valueType) + return() + } + }, + s = { + # Struct type + m <- regexec("^struct<(.+)>$", type) + matchedStrings <- regmatches(type, m) + if (length(matchedStrings[[1]]) >= 2) { + fieldsString <- matchedStrings[[1]][2] + # strsplit does not return the final empty string, so check if + # the final char is "," + if (substr(fieldsString, nchar(fieldsString), nchar(fieldsString)) != ",") { + fields <- strsplit(fieldsString, ",")[[1]] + for (field in fields) { + m <- regexec("^(.+):(.+)$", field) + matchedStrings <- regmatches(field, m) + if (length(matchedStrings[[1]]) >= 3) { + fieldType <- matchedStrings[[1]][3] + checkType(fieldType) + } else { + break + } + } + return() + } + } + }) + } + + stop(paste("Unsupported type for Dataframe:", type)) +} + structField.character <- function(x, type, nullable = TRUE) { if (class(x) != "character") { stop("Field name must be a string.") @@ -124,28 +184,13 @@ structField.character <- function(x, type, nullable = TRUE) { if (class(nullable) != "logical") { stop("nullable must be either TRUE or FALSE") } - options <- c("byte", - "integer", - "float", - "double", - "numeric", - "character", - "string", - "binary", - "raw", - "logical", - "boolean", - "timestamp", - "date") - dataType <- if (type %in% options) { - type - } else { - stop(paste("Unsupported type for Dataframe:", type)) - } + + checkType(type) + sfObj <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createStructField", x, - dataType, + type, nullable) structField(sfObj) } diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R index 311021e5d847..17082b4e52fc 100644 --- a/R/pkg/R/serialize.R +++ b/R/pkg/R/serialize.R @@ -32,6 +32,21 @@ # environment -> Map[String, T], where T is a native type # jobj -> Object, where jobj is an object created in the backend +getSerdeType <- function(object) { + type <- class(object)[[1]] + if (type != "list") { + type + } else { + # Check if all elements are of same type + elemType <- unique(sapply(object, function(elem) { getSerdeType(elem) })) + if (length(elemType) <= 1) { + "array" + } else { + "list" + } + } +} + writeObject <- function(con, object, writeType = TRUE) { # NOTE: In R vectors have same type as objects. 
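[Editor's sketch] checkType() above replaces the flat type whitelist with a recursive check, so structField() now accepts nested SQL type strings; a sketch assuming an initialized SparkR backend (e.g. after sparkR.init()/sparkRSQL.init()).

    f1 <- structField("age", "integer")
    f2 <- structField("scores", "array<double>")
    f3 <- structField("attrs", "map<string,string>")
    f4 <- structField("address", "struct<street:string,zip:int>")
    schema <- structType(f1, f2, f3, f4)
    # structField("bad", "map<int,string>")   # error: key type in a map must be string or character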
So we don't support # passing in vectors as arrays and instead require arrays to be passed @@ -45,10 +60,12 @@ writeObject <- function(con, object, writeType = TRUE) { type <- "NULL" } } + + serdeType <- getSerdeType(object) if (writeType) { - writeType(con, type) + writeType(con, serdeType) } - switch(type, + switch(serdeType, NULL = writeVoid(con), integer = writeInt(con, object), character = writeString(con, object), @@ -56,7 +73,9 @@ writeObject <- function(con, object, writeType = TRUE) { double = writeDouble(con, object), numeric = writeDouble(con, object), raw = writeRaw(con, object), + array = writeArray(con, object), list = writeList(con, object), + struct = writeList(con, object), jobj = writeJobj(con, object), environment = writeEnv(con, object), Date = writeDate(con, object), @@ -79,7 +98,7 @@ writeJobj <- function(con, value) { writeString <- function(con, value) { utfVal <- enc2utf8(value) writeInt(con, as.integer(nchar(utfVal, type = "bytes") + 1)) - writeBin(utfVal, con, endian = "big") + writeBin(utfVal, con, endian = "big", useBytes=TRUE) } writeInt <- function(con, value) { @@ -110,18 +129,10 @@ writeRowSerialize <- function(outputCon, rows) { serializeRow <- function(row) { rawObj <- rawConnection(raw(0), "wb") on.exit(close(rawObj)) - writeRow(rawObj, row) + writeList(rawObj, row) rawConnectionValue(rawObj) } -writeRow <- function(con, row) { - numCols <- length(row) - writeInt(con, numCols) - for (i in 1:numCols) { - writeObject(con, row[[i]]) - } -} - writeRaw <- function(con, batch) { writeInt(con, length(batch)) writeBin(batch, con, endian = "big") @@ -136,7 +147,9 @@ writeType <- function(con, class) { double = "d", numeric = "d", raw = "r", + array = "a", list = "l", + struct = "s", jobj = "j", environment = "e", Date = "D", @@ -147,15 +160,13 @@ writeType <- function(con, class) { } # Used to pass arrays where all the elements are of the same type -writeList <- function(con, arr) { - # All elements should be of same type - elemType <- unique(sapply(arr, function(elem) { class(elem) })) - stopifnot(length(elemType) <= 1) - +writeArray <- function(con, arr) { # TODO: Empty lists are given type "character" right now. # This may not work if the Java side expects array of any other type. 
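[Editor's sketch] The effect of getSerdeType() introduced above: homogeneous lists are now shipped as typed arrays (type code "a") while mixed lists stay generic lists ("l"); the helper is internal, so ::: access is for illustration only.

    SparkR:::getSerdeType(list(1L, 2L, 3L))     # "array"  -> written with writeArray()
    SparkR:::getSerdeType(list(1L, "a", TRUE))  # "list"   -> written with writeList()
    SparkR:::getSerdeType("abc")                # "character"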
- if (length(elemType) == 0) { + if (length(arr) == 0) { elemType <- class("somestring") + } else { + elemType <- getSerdeType(arr[[1]]) } writeType(con, elemType) @@ -169,7 +180,7 @@ writeList <- function(con, arr) { } # Used to pass arrays where the elements can be of different types -writeGenericList <- function(con, list) { +writeList <- function(con, list) { writeInt(con, length(list)) for (elem in list) { writeObject(con, elem) @@ -182,9 +193,9 @@ writeEnv <- function(con, env) { writeInt(con, len) if (len > 0) { - writeList(con, as.list(ls(env))) + writeArray(con, as.list(ls(env))) vals <- lapply(ls(env), function(x) { env[[x]] }) - writeGenericList(con, as.list(vals)) + writeList(con, as.list(vals)) } } diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index e83104f11642..d2bfad553104 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -34,11 +34,24 @@ connExists <- function(env) { sparkR.stop <- function() { env <- .sparkREnv if (exists(".sparkRCon", envir = env)) { - # cat("Stopping SparkR\n") if (exists(".sparkRjsc", envir = env)) { sc <- get(".sparkRjsc", envir = env) callJMethod(sc, "stop") rm(".sparkRjsc", envir = env) + + if (exists(".sparkRSQLsc", envir = env)) { + rm(".sparkRSQLsc", envir = env) + } + + if (exists(".sparkRHivesc", envir = env)) { + rm(".sparkRHivesc", envir = env) + } + } + + # Remove the R package lib path from .libPaths() + if (exists(".libPath", envir = env)) { + libPath <- get(".libPath", envir = env) + .libPaths(.libPaths()[.libPaths() != libPath]) } if (exists(".backendLaunched", envir = env)) { @@ -69,15 +82,17 @@ sparkR.stop <- function() { #' Initialize a new Spark Context. #' -#' This function initializes a new SparkContext. +#' This function initializes a new SparkContext. For details on how to initialize +#' and use SparkR, refer to SparkR programming guide at +#' \url{http://spark.apache.org/docs/latest/sparkr.html#starting-up-sparkcontext-sqlcontext}. #' -#' @param master The Spark master URL. +#' @param master The Spark master URL #' @param appName Application name to register with cluster manager #' @param sparkHome Spark Home directory -#' @param sparkEnvir Named list of environment variables to set on worker nodes. -#' @param sparkExecutorEnv Named list of environment variables to be used when launching executors. -#' @param sparkJars Character string vector of jar files to pass to the worker nodes. 
-#' @param sparkPackages Character string vector of packages from spark-packages.org +#' @param sparkEnvir Named list of environment variables to set on worker nodes +#' @param sparkExecutorEnv Named list of environment variables to be used when launching executors +#' @param sparkJars Character vector of jar files to pass to the worker nodes +#' @param sparkPackages Character vector of packages from spark-packages.org #' @export #' @examples #'\dontrun{ @@ -85,9 +100,11 @@ sparkR.stop <- function() { #' sc <- sparkR.init("local[2]", "SparkR", "/home/spark", #' list(spark.executor.memory="1g")) #' sc <- sparkR.init("yarn-client", "SparkR", "/home/spark", -#' list(spark.executor.memory="1g"), +#' list(spark.executor.memory="4g"), #' list(LD_LIBRARY_PATH="/directory of JVM libraries (libjvm.so) on workers/"), -#' c("jarfile1.jar","jarfile2.jar")) +#' c("one.jar", "two.jar", "three.jar"), +#' c("com.databricks:spark-avro_2.10:2.0.1", +#' "com.databricks:spark-csv_2.10:1.3.0")) #'} sparkR.init <- function( @@ -105,27 +122,25 @@ sparkR.init <- function( return(get(".sparkRjsc", envir = .sparkREnv)) } - jars <- suppressWarnings(normalizePath(as.character(sparkJars))) + jars <- processSparkJars(sparkJars) + packages <- processSparkPackages(sparkPackages) - # Classpath separator is ";" on Windows - # URI needs four /// as from http://stackoverflow.com/a/18522792 - if (.Platform$OS.type == "unix") { - uriSep <- "//" - } else { - uriSep <- "////" - } + sparkEnvirMap <- convertNamedListToEnv(sparkEnvir) existingPort <- Sys.getenv("EXISTING_SPARKR_BACKEND_PORT", "") if (existingPort != "") { backendPort <- existingPort } else { path <- tempfile(pattern = "backend_port") + submitOps <- getClientModeSparkSubmitOpts( + Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell"), + sparkEnvirMap) launchBackend( args = path, sparkHome = sparkHome, jars = jars, - sparkSubmitOpts = Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell"), - packages = sparkPackages) + sparkSubmitOpts = submitOps, + packages = packages) # wait atmost 100 seconds for JVM to launch wait <- 0.1 for (i in 1:25) { @@ -141,14 +156,20 @@ sparkR.init <- function( f <- file(path, open="rb") backendPort <- readInt(f) monitorPort <- readInt(f) + rLibPath <- readString(f) close(f) file.remove(path) if (length(backendPort) == 0 || backendPort == 0 || - length(monitorPort) == 0 || monitorPort == 0) { + length(monitorPort) == 0 || monitorPort == 0 || + length(rLibPath) != 1) { stop("JVM failed to launch") } assign(".monitorConn", socketConnection(port = monitorPort), envir = .sparkREnv) assign(".backendLaunched", 1, envir = .sparkREnv) + if (rLibPath != "") { + assign(".libPath", rLibPath, envir = .sparkREnv) + .libPaths(c(rLibPath, .libPaths())) + } } .sparkREnv$backendPort <- backendPort @@ -160,25 +181,23 @@ sparkR.init <- function( }) if (nchar(sparkHome) != 0) { - sparkHome <- normalizePath(sparkHome) + sparkHome <- suppressWarnings(normalizePath(sparkHome)) } - sparkEnvirMap <- new.env() - for (varname in names(sparkEnvir)) { - sparkEnvirMap[[varname]] <- sparkEnvir[[varname]] - } - - sparkExecutorEnvMap <- new.env() - if (!any(names(sparkExecutorEnv) == "LD_LIBRARY_PATH")) { + sparkExecutorEnvMap <- convertNamedListToEnv(sparkExecutorEnv) + if(is.null(sparkExecutorEnvMap$LD_LIBRARY_PATH)) { sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <- paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH")) } - for (varname in names(sparkExecutorEnv)) { - sparkExecutorEnvMap[[varname]] <- sparkExecutorEnv[[varname]] - } - nonEmptyJars <- Filter(function(x) { x != "" }, 
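[Editor's sketch] The updated sparkR.init() call corresponding to the revised roxygen example above; paths and package coordinates are illustrative. A comma-separated string still works but now triggers the deprecation warning added in processSparkJars()/processSparkPackages() later in this file.

    sc <- sparkR.init("local[2]", "SparkR", "/home/spark",
                      sparkJars = c("one.jar", "two.jar"),
                      sparkPackages = c("com.databricks:spark-csv_2.10:1.3.0"))
    # deprecated, but still accepted with a warning:
    # sc <- sparkR.init(sparkJars = "one.jar,two.jar")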
jars) - localJarPaths <- sapply(nonEmptyJars, + # Classpath separator is ";" on Windows + # URI needs four /// as from http://stackoverflow.com/a/18522792 + if (.Platform$OS.type == "unix") { + uriSep <- "//" + } else { + uriSep <- "////" + } + localJarPaths <- lapply(jars, function(j) { utils::URLencode(paste("file:", uriSep, j, sep = "")) }) # Set the start time to identify jobjs @@ -193,7 +212,7 @@ sparkR.init <- function( master, appName, as.character(sparkHome), - as.list(localJarPaths), + localJarPaths, sparkEnvirMap, sparkExecutorEnvMap), envir = .sparkREnv @@ -318,3 +337,52 @@ clearJobGroup <- function(sc) { cancelJobGroup <- function(sc, groupId) { callJMethod(sc, "cancelJobGroup", groupId) } + +sparkConfToSubmitOps <- new.env() +sparkConfToSubmitOps[["spark.driver.memory"]] <- "--driver-memory" +sparkConfToSubmitOps[["spark.driver.extraClassPath"]] <- "--driver-class-path" +sparkConfToSubmitOps[["spark.driver.extraJavaOptions"]] <- "--driver-java-options" +sparkConfToSubmitOps[["spark.driver.extraLibraryPath"]] <- "--driver-library-path" + +# Utility function that returns Spark Submit arguments as a string +# +# A few Spark Application and Runtime environment properties cannot take effect after driver +# JVM has started, as documented in: +# http://spark.apache.org/docs/latest/configuration.html#application-properties +# When starting SparkR without using spark-submit, for example, from Rstudio, add them to +# spark-submit commandline if not already set in SPARKR_SUBMIT_ARGS so that they can be effective. +getClientModeSparkSubmitOpts <- function(submitOps, sparkEnvirMap) { + envirToOps <- lapply(ls(sparkConfToSubmitOps), function(conf) { + opsValue <- sparkEnvirMap[[conf]] + # process only if --option is not already specified + if (!is.null(opsValue) && + nchar(opsValue) > 1 && + !grepl(sparkConfToSubmitOps[[conf]], submitOps)) { + # put "" around value in case it has spaces + paste0(sparkConfToSubmitOps[[conf]], " \"", opsValue, "\" ") + } else { + "" + } + }) + # --option must be before the application class "sparkr-shell" in submitOps + paste0(paste0(envirToOps, collapse = ""), submitOps) +} + +# Utility function that handles sparkJars argument, and normalize paths +processSparkJars <- function(jars) { + splittedJars <- splitString(jars) + if (length(splittedJars) > length(jars)) { + warning("sparkJars as a comma-separated string is deprecated, use character vector instead") + } + normalized <- suppressWarnings(normalizePath(splittedJars)) + normalized +} + +# Utility function that handles sparkPackages argument +processSparkPackages <- function(packages) { + splittedPackages <- splitString(packages) + if (length(splittedPackages) > length(packages)) { + warning("sparkPackages as a comma-separated string is deprecated, use character vector instead") + } + splittedPackages +} diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R new file mode 100644 index 000000000000..d17cce9c756e --- /dev/null +++ b/R/pkg/R/stats.R @@ -0,0 +1,162 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
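[Editor's sketch] How the new getClientModeSparkSubmitOpts() above is intended to be used: driver properties passed through sparkEnvir are promoted to spark-submit flags when SparkR is started outside spark-submit (e.g. from RStudio). Values below are illustrative.

    sc <- sparkR.init(master = "local[2]",
                      sparkEnvir = list(spark.driver.memory = "2g",
                                        spark.driver.extraJavaOptions = "-Xss4m"))
    # Before the JVM starts, the entries above are prepended to SPARKR_SUBMIT_ARGS as
    #   --driver-java-options "-Xss4m" --driver-memory "2g" sparkr-shell
    # unless the corresponding flag is already present in SPARKR_SUBMIT_ARGS.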
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# stats.R - Statistic functions for DataFrames. + +setOldClass("jobj") + +#' crosstab +#' +#' Computes a pair-wise frequency table of the given columns. Also known as a contingency +#' table. The number of distinct values for each column should be less than 1e4. At most 1e6 +#' non-zero pair frequencies will be returned. +#' +#' @param col1 name of the first column. Distinct items will make the first item of each row. +#' @param col2 name of the second column. Distinct items will make the column names of the output. +#' @return a local R data.frame representing the contingency table. The first column of each row +#' will be the distinct values of `col1` and the column names will be the distinct values +#' of `col2`. The name of the first column will be `$col1_$col2`. Pairs that have no +#' occurrences will have zero as their counts. +#' +#' @rdname statfunctions +#' @name crosstab +#' @export +#' @examples +#' \dontrun{ +#' df <- jsonFile(sqlContext, "/path/to/file.json") +#' ct <- crosstab(df, "title", "gender") +#' } +setMethod("crosstab", + signature(x = "DataFrame", col1 = "character", col2 = "character"), + function(x, col1, col2) { + statFunctions <- callJMethod(x@sdf, "stat") + sct <- callJMethod(statFunctions, "crosstab", col1, col2) + collect(dataFrame(sct)) + }) + +#' cov +#' +#' Calculate the sample covariance of two numerical columns of a DataFrame. +#' +#' @param x A SparkSQL DataFrame +#' @param col1 the name of the first column +#' @param col2 the name of the second column +#' @return the covariance of the two columns. +#' +#' @rdname statfunctions +#' @name cov +#' @export +#' @examples +#'\dontrun{ +#' df <- jsonFile(sqlContext, "/path/to/file.json") +#' cov <- cov(df, "title", "gender") +#' } +setMethod("cov", + signature(x = "DataFrame", col1 = "character", col2 = "character"), + function(x, col1, col2) { + statFunctions <- callJMethod(x@sdf, "stat") + callJMethod(statFunctions, "cov", col1, col2) + }) + +#' corr +#' +#' Calculates the correlation of two columns of a DataFrame. +#' Currently only supports the Pearson Correlation Coefficient. +#' For Spearman Correlation, consider using RDD methods found in MLlib's Statistics. +#' +#' @param x A SparkSQL DataFrame +#' @param col1 the name of the first column +#' @param col2 the name of the second column +#' @param method Optional. A character specifying the method for calculating the correlation. +#' only "pearson" is allowed now. +#' @return The Pearson Correlation Coefficient as a Double. +#' +#' @rdname statfunctions +#' @name corr +#' @export +#' @examples +#'\dontrun{ +#' df <- jsonFile(sqlContext, "/path/to/file.json") +#' corr <- corr(df, "title", "gender") +#' corr <- corr(df, "title", "gender", method = "pearson") +#' } +setMethod("corr", + signature(x = "DataFrame"), + function(x, col1, col2, method = "pearson") { + stopifnot(class(col1) == "character" && class(col2) == "character") + statFunctions <- callJMethod(x@sdf, "stat") + callJMethod(statFunctions, "corr", col1, col2, method) + }) + +#' freqItems +#' +#' Finding frequent items for columns, possibly with false positives. 
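[Editor's sketch] The new DataFrame statistic functions above in one combined example; the JSON path and column names are placeholders, and cov()/corr() expect numeric columns.

    df <- jsonFile(sqlContext, "/path/to/file.json")
    ct <- crosstab(df, "title", "gender")                 # local R data.frame (contingency table)
    cv <- cov(df, "rating", "numVotes")                   # sample covariance of two numeric columns
    cr <- corr(df, "rating", "numVotes", method = "pearson")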
+#' Using the frequent element count algorithm described in +#' \url{http://dx.doi.org/10.1145/762471.762473}, proposed by Karp, Schenker, and Papadimitriou. +#' +#' @param x A SparkSQL DataFrame. +#' @param cols A vector column names to search frequent items in. +#' @param support (Optional) The minimum frequency for an item to be considered `frequent`. +#' Should be greater than 1e-4. Default support = 0.01. +#' @return a local R data.frame with the frequent items in each column +#' +#' @rdname statfunctions +#' @name freqItems +#' @export +#' @examples +#' \dontrun{ +#' df <- jsonFile(sqlContext, "/path/to/file.json") +#' fi = freqItems(df, c("title", "gender")) +#' } +setMethod("freqItems", signature(x = "DataFrame", cols = "character"), + function(x, cols, support = 0.01) { + statFunctions <- callJMethod(x@sdf, "stat") + sct <- callJMethod(statFunctions, "freqItems", as.list(cols), support) + collect(dataFrame(sct)) + }) + +#' sampleBy +#' +#' Returns a stratified sample without replacement based on the fraction given on each stratum. +#' +#' @param x A SparkSQL DataFrame +#' @param col column that defines strata +#' @param fractions A named list giving sampling fraction for each stratum. If a stratum is +#' not specified, we treat its fraction as zero. +#' @param seed random seed +#' @return A new DataFrame that represents the stratified sample +#' +#' @rdname statfunctions +#' @name sampleBy +#' @export +#' @examples +#'\dontrun{ +#' df <- jsonFile(sqlContext, "/path/to/file.json") +#' sample <- sampleBy(df, "key", fractions, 36) +#' } +setMethod("sampleBy", + signature(x = "DataFrame", col = "character", + fractions = "list", seed = "numeric"), + function(x, col, fractions, seed) { + fractionsEnv <- convertNamedListToEnv(fractions) + + statFunctions <- callJMethod(x@sdf, "stat") + # Seed is expected to be Long on Scala side, here convert it to an integer + # due to SerDe limitation now. + sdf <- callJMethod(statFunctions, "sampleBy", col, fractionsEnv, as.integer(seed)) + dataFrame(sdf) + }) diff --git a/R/pkg/R/types.R b/R/pkg/R/types.R new file mode 100644 index 000000000000..1f06af7e904f --- /dev/null +++ b/R/pkg/R/types.R @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# types.R. This file handles the data type mapping between Spark and R + +# The primitive data types, where names(PRIMITIVE_TYPES) are Scala types whereas +# values are equivalent R types. This is stored in an environment to allow for +# more efficient look up (environments use hashmaps). 
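
Since the comment above keeps the type tables in environments precisely for hashmap-style lookup, here is a small self-contained illustration of that lookup pattern; the entries and helper name are hypothetical stand-ins, not the full PRIMITIVE_TYPES mapping that follows.

# Tiny illustrative SQL-type -> R-type table kept in an environment.
typeMap <- as.environment(list(
  "int"    = "integer",
  "double" = "numeric",
  "string" = "character"
))

lookupRType <- function(sqlType) {
  rType <- typeMap[[sqlType]]   # NULL when the type is not in the table
  if (is.null(rType)) {
    stop(paste("Unsupported SQL type:", sqlType))
  }
  rType
}

lookupRType("string")                              # "character"
exists("int", envir = typeMap, inherits = FALSE)   # TRUE
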
+PRIMITIVE_TYPES <- as.environment(list( + "tinyint" = "integer", + "smallint" = "integer", + "int" = "integer", + "bigint" = "numeric", + "float" = "numeric", + "double" = "numeric", + "decimal" = "numeric", + "string" = "character", + "binary" = "raw", + "boolean" = "logical", + "timestamp" = "POSIXct", + "date" = "Date", + # following types are not SQL types returned by dtypes(). They are listed here for usage + # by checkType() in schema.R. + # TODO: refactor checkType() in schema.R. + "byte" = "integer", + "integer" = "integer" + )) + +# The complex data types. These do not have any direct mapping to R's types. +COMPLEX_TYPES <- list( + "map" = NA, + "array" = NA, + "struct" = NA) + +# The full list of data types. +DATA_TYPES <- as.environment(c(as.list(PRIMITIVE_TYPES), COMPLEX_TYPES)) + +# An environment for mapping R to Scala, names are R types and values are Scala types. +rToSQLTypes <- as.environment(list( + "integer" = "integer", # in R, integer is 32bit + "numeric" = "double", # in R, numeric == double which is 64bit + "double" = "double", + "character" = "string", + "logical" = "boolean")) diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 4f9f4d9cad2a..43105aaa3842 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -314,7 +314,8 @@ convertEnvsToList <- function(keys, vals) { # Utility function to capture the varargs into environment object varargsToEnv <- function(...) { - pairs <- as.list(substitute(list(...)))[-1L] + # Based on http://stackoverflow.com/a/3057419/4577954 + pairs <- list(...) env <- new.env() for (name in names(pairs)) { env[[name]] <- pairs[[name]] @@ -360,16 +361,6 @@ numToInt <- function(num) { as.integer(num) } -# create a Seq in JVM -toSeq <- function(...) { - callJStatic("org.apache.spark.sql.api.r.SQLUtils", "toSeq", list(...)) -} - -# create a Seq in JVM from a list -listToSeq <- function(l) { - callJStatic("org.apache.spark.sql.api.r.SQLUtils", "toSeq", l) -} - # Utility function to recursively traverse the Abstract Syntax Tree (AST) of a # user defined function (UDF), and to examine variables in the UDF to decide # if their values should be included in the new function environment. 
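
The utils.R hunk above changes varargsToEnv() to capture its arguments with list(...) instead of substitute(), so values are evaluated before being copied into the environment. A brief standalone sketch of that capture pattern, with hypothetical function and argument names:

# Copy evaluated named varargs into an environment, mirroring the
# updated varargsToEnv(); the names used here are made up.
namedArgsToEnv <- function(...) {
  pairs <- list(...)            # evaluates the arguments, unlike substitute()
  env <- new.env()
  for (name in names(pairs)) {
    env[[name]] <- pairs[[name]]
  }
  env
}

e <- namedArgsToEnv(memory = "1g", cores = 2L)
e[["memory"]]   # "1g"
ls(e)           # "cores" "memory"
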
@@ -597,3 +588,56 @@ mergePartitions <- function(rdd, zip) { PipelinedRDD(rdd, partitionFunc) } + +# Convert a named list to struct so that +# SerDe won't confuse between a normal named list and struct +listToStruct <- function(list) { + stopifnot(class(list) == "list") + stopifnot(!is.null(names(list))) + class(list) <- "struct" + list +} + +# Convert a struct to a named list +structToList <- function(struct) { + stopifnot(class(list) == "struct") + + class(struct) <- "list" + struct +} + +# Convert a named list to an environment to be passed to JVM +convertNamedListToEnv <- function(namedList) { + # Make sure each item in the list has a name + names <- names(namedList) + stopifnot( + if (is.null(names)) { + length(namedList) == 0 + } else { + !any(is.na(names)) + }) + + env <- new.env() + for (name in names) { + env[[name]] <- namedList[[name]] + } + env +} + +# Assign a new environment for attach() and with() methods +assignNewEnv <- function(data) { + stopifnot(class(data) == "DataFrame") + cols <- columns(data) + stopifnot(length(cols) > 0) + + env <- new.env() + for (i in 1:length(cols)) { + assign(x = cols[i], value = data[, cols[i]], envir = env) + } + env +} + +# Utility function to split by ',' and whitespace, remove empty tokens +splitString <- function(input) { + Filter(nzchar, unlist(strsplit(input, ",|\\s"))) +} diff --git a/R/pkg/inst/profile/general.R b/R/pkg/inst/profile/general.R index 2a8a8213d084..c55fe9ba7af7 100644 --- a/R/pkg/inst/profile/general.R +++ b/R/pkg/inst/profile/general.R @@ -17,6 +17,7 @@ .First <- function() { packageDir <- Sys.getenv("SPARKR_PACKAGE_DIR") - .libPaths(c(packageDir, .libPaths())) + dirs <- strsplit(packageDir, ",")[[1]] + .libPaths(c(dirs, .libPaths())) Sys.setenv(NOAWT=1) } diff --git a/R/pkg/inst/profile/shell.R b/R/pkg/inst/profile/shell.R index 7189f1a26093..90a3761e41f8 100644 --- a/R/pkg/inst/profile/shell.R +++ b/R/pkg/inst/profile/shell.R @@ -38,7 +38,7 @@ if (nchar(sparkVer) == 0) { cat("\n") } else { - cat(" version ", sparkVer, "\n") + cat(" version ", sparkVer, "\n") } cat(" /_/", "\n") cat("\n") diff --git a/R/pkg/inst/tests/test_context.R b/R/pkg/inst/tests/test_context.R deleted file mode 100644 index 513bbc8e6205..000000000000 --- a/R/pkg/inst/tests/test_context.R +++ /dev/null @@ -1,57 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
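
Two of the utils.R helpers added above, splitString() and convertNamedListToEnv(), are plain R and easy to see in isolation; the sketch below mirrors their behavior with made-up inputs (the function names here are local stand-ins, not the SparkR internals themselves).

# Split on commas/whitespace and drop empty tokens, as splitString() does.
splitTokens <- function(input) {
  Filter(nzchar, unlist(strsplit(input, ",|\\s")))
}
splitTokens(" a, b ")    # c("a", "b")

# Copy a named list into an environment so it can be passed to the JVM as a
# map, as convertNamedListToEnv() does (minus its name checks).
namedListToEnv <- function(namedList) {
  env <- new.env()
  for (name in names(namedList)) {
    env[[name]] <- namedList[[name]]
  }
  env
}
fractions <- namedListToEnv(list(yes = 0.1, no = 0.2))
fractions[["yes"]]       # 0.1
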
-# - -context("test functions in sparkR.R") - -test_that("repeatedly starting and stopping SparkR", { - for (i in 1:4) { - sc <- sparkR.init() - rdd <- parallelize(sc, 1:20, 2L) - expect_equal(count(rdd), 20) - sparkR.stop() - } -}) - -test_that("rdd GC across sparkR.stop", { - sparkR.stop() - sc <- sparkR.init() # sc should get id 0 - rdd1 <- parallelize(sc, 1:20, 2L) # rdd1 should get id 1 - rdd2 <- parallelize(sc, 1:10, 2L) # rdd2 should get id 2 - sparkR.stop() - - sc <- sparkR.init() # sc should get id 0 again - - # GC rdd1 before creating rdd3 and rdd2 after - rm(rdd1) - gc() - - rdd3 <- parallelize(sc, 1:20, 2L) # rdd3 should get id 1 now - rdd4 <- parallelize(sc, 1:10, 2L) # rdd4 should get id 2 now - - rm(rdd2) - gc() - - count(rdd3) - count(rdd4) -}) - -test_that("job group functions can be called", { - sc <- sparkR.init() - setJobGroup(sc, "groupId", "job description", TRUE) - cancelJobGroup(sc, "groupId") - clearJobGroup(sc) -}) diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R deleted file mode 100644 index f272de78ad4a..000000000000 --- a/R/pkg/inst/tests/test_mllib.R +++ /dev/null @@ -1,61 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -library(testthat) - -context("MLlib functions") - -# Tests for MLlib functions in SparkR - -sc <- sparkR.init() - -sqlContext <- sparkRSQL.init(sc) - -test_that("glm and predict", { - training <- createDataFrame(sqlContext, iris) - test <- select(training, "Sepal_Length") - model <- glm(Sepal_Width ~ Sepal_Length, training, family = "gaussian") - prediction <- predict(model, test) - expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") -}) - -test_that("predictions match with native glm", { - training <- createDataFrame(sqlContext, iris) - model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) - vals <- collect(select(predict(model, training), "prediction")) - rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) - expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) -}) - -test_that("dot minus and intercept vs native glm", { - training <- createDataFrame(sqlContext, iris) - model <- glm(Sepal_Width ~ . - Species + 0, data = training) - vals <- collect(select(predict(model, training), "prediction")) - rVals <- predict(glm(Sepal.Width ~ . 
- Species + 0, data = iris), iris) - expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) -}) - -test_that("summary coefficients match with native glm", { - training <- createDataFrame(sqlContext, iris) - stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training)) - coefs <- as.vector(stats$coefficients) - rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris))) - expect_true(all(abs(rCoefs - coefs) < 1e-6)) - expect_true(all( - as.character(stats$features) == - c("(Intercept)", "Sepal_Length", "Species__versicolor", "Species__virginica"))) -}) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R deleted file mode 100644 index 7377fc8f1ca9..000000000000 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ /dev/null @@ -1,1042 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -library(testthat) - -context("SparkSQL functions") - -# Utility function for easily checking the values of a StructField -checkStructField <- function(actual, expectedName, expectedType, expectedNullable) { - expect_equal(class(actual), "structField") - expect_equal(actual$name(), expectedName) - expect_equal(actual$dataType.toString(), expectedType) - expect_equal(actual$nullable(), expectedNullable) -} - -# Tests for SparkSQL functions in SparkR - -sc <- sparkR.init() - -sqlContext <- sparkRSQL.init(sc) - -mockLines <- c("{\"name\":\"Michael\"}", - "{\"name\":\"Andy\", \"age\":30}", - "{\"name\":\"Justin\", \"age\":19}") -jsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") -parquetPath <- tempfile(pattern="sparkr-test", fileext=".parquet") -writeLines(mockLines, jsonPath) - -# For test nafunctions, like dropna(), fillna(),... 
-mockLinesNa <- c("{\"name\":\"Bob\",\"age\":16,\"height\":176.5}", - "{\"name\":\"Alice\",\"age\":null,\"height\":164.3}", - "{\"name\":\"David\",\"age\":60,\"height\":null}", - "{\"name\":\"Amy\",\"age\":null,\"height\":null}", - "{\"name\":null,\"age\":null,\"height\":null}") -jsonPathNa <- tempfile(pattern="sparkr-test", fileext=".tmp") -writeLines(mockLinesNa, jsonPathNa) - -test_that("infer types", { - expect_equal(infer_type(1L), "integer") - expect_equal(infer_type(1.0), "double") - expect_equal(infer_type("abc"), "string") - expect_equal(infer_type(TRUE), "boolean") - expect_equal(infer_type(as.Date("2015-03-11")), "date") - expect_equal(infer_type(as.POSIXlt("2015-03-11 12:13:04.043")), "timestamp") - expect_equal(infer_type(c(1L, 2L)), - list(type = "array", elementType = "integer", containsNull = TRUE)) - expect_equal(infer_type(list(1L, 2L)), - list(type = "array", elementType = "integer", containsNull = TRUE)) - testStruct <- infer_type(list(a = 1L, b = "2")) - expect_equal(class(testStruct), "structType") - checkStructField(testStruct$fields()[[1]], "a", "IntegerType", TRUE) - checkStructField(testStruct$fields()[[2]], "b", "StringType", TRUE) - e <- new.env() - assign("a", 1L, envir = e) - expect_equal(infer_type(e), - list(type = "map", keyType = "string", valueType = "integer", - valueContainsNull = TRUE)) -}) - -test_that("structType and structField", { - testField <- structField("a", "string") - expect_is(testField, "structField") - expect_equal(testField$name(), "a") - expect_true(testField$nullable()) - - testSchema <- structType(testField, structField("b", "integer")) - expect_is(testSchema, "structType") - expect_is(testSchema$fields()[[2]], "structField") - expect_equal(testSchema$fields()[[1]]$dataType.toString(), "StringType") -}) - -test_that("create DataFrame from RDD", { - rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) - df <- createDataFrame(sqlContext, rdd, list("a", "b")) - expect_is(df, "DataFrame") - expect_equal(count(df), 10) - expect_equal(nrow(df), 10) - expect_equal(ncol(df), 2) - expect_equal(dim(df), c(10, 2)) - expect_equal(columns(df), c("a", "b")) - expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) - - df <- createDataFrame(sqlContext, rdd) - expect_is(df, "DataFrame") - expect_equal(columns(df), c("_1", "_2")) - - schema <- structType(structField(x = "a", type = "integer", nullable = TRUE), - structField(x = "b", type = "string", nullable = TRUE)) - df <- createDataFrame(sqlContext, rdd, schema) - expect_is(df, "DataFrame") - expect_equal(columns(df), c("a", "b")) - expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) - - rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) - df <- createDataFrame(sqlContext, rdd) - expect_is(df, "DataFrame") - expect_equal(count(df), 10) - expect_equal(columns(df), c("a", "b")) - expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) - - df <- jsonFile(sqlContext, jsonPathNa) - hiveCtx <- tryCatch({ - newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) - }, - error = function(err) { - skip("Hive is not build with SparkSQL, skipped") - }) - sql(hiveCtx, "CREATE TABLE people (name string, age double, height float)") - insertInto(df, "people") - expect_equal(sql(hiveCtx, "SELECT age from people WHERE name = 'Bob'"), c(16)) - expect_equal(sql(hiveCtx, "SELECT height from people WHERE name ='Bob'"), c(176.5)) - - schema <- structType(structField("name", "string"), structField("age", "integer"), - 
structField("height", "float")) - df2 <- createDataFrame(sqlContext, df.toRDD, schema) - expect_equal(columns(df2), c("name", "age", "height")) - expect_equal(dtypes(df2), list(c("name", "string"), c("age", "int"), c("height", "float"))) - expect_equal(collect(where(df2, df2$name == "Bob")), c("Bob", 16, 176.5)) - - localDF <- data.frame(name=c("John", "Smith", "Sarah"), - age=c(19, 23, 18), - height=c(164.10, 181.4, 173.7)) - df <- createDataFrame(sqlContext, localDF, schema) - expect_is(df, "DataFrame") - expect_equal(count(df), 3) - expect_equal(columns(df), c("name", "age", "height")) - expect_equal(dtypes(df), list(c("name", "string"), c("age", "int"), c("height", "float"))) - expect_equal(collect(where(df, df$name == "John")), c("John", 19, 164.10)) -}) - -test_that("convert NAs to null type in DataFrames", { - rdd <- parallelize(sc, list(list(1L, 2L), list(NA, 4L))) - df <- createDataFrame(sqlContext, rdd, list("a", "b")) - expect_true(is.na(collect(df)[2, "a"])) - expect_equal(collect(df)[2, "b"], 4L) - - l <- data.frame(x = 1L, y = c(1L, NA_integer_, 3L)) - df <- createDataFrame(sqlContext, l) - expect_equal(collect(df)[2, "x"], 1L) - expect_true(is.na(collect(df)[2, "y"])) - - rdd <- parallelize(sc, list(list(1, 2), list(NA, 4))) - df <- createDataFrame(sqlContext, rdd, list("a", "b")) - expect_true(is.na(collect(df)[2, "a"])) - expect_equal(collect(df)[2, "b"], 4) - - l <- data.frame(x = 1, y = c(1, NA_real_, 3)) - df <- createDataFrame(sqlContext, l) - expect_equal(collect(df)[2, "x"], 1) - expect_true(is.na(collect(df)[2, "y"])) - - l <- list("a", "b", NA, "d") - df <- createDataFrame(sqlContext, l) - expect_true(is.na(collect(df)[3, "_1"])) - expect_equal(collect(df)[4, "_1"], "d") - - l <- list("a", "b", NA_character_, "d") - df <- createDataFrame(sqlContext, l) - expect_true(is.na(collect(df)[3, "_1"])) - expect_equal(collect(df)[4, "_1"], "d") - - l <- list(TRUE, FALSE, NA, TRUE) - df <- createDataFrame(sqlContext, l) - expect_true(is.na(collect(df)[3, "_1"])) - expect_equal(collect(df)[4, "_1"], TRUE) -}) - -test_that("toDF", { - rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) - df <- toDF(rdd, list("a", "b")) - expect_is(df, "DataFrame") - expect_equal(count(df), 10) - expect_equal(columns(df), c("a", "b")) - expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) - - df <- toDF(rdd) - expect_is(df, "DataFrame") - expect_equal(columns(df), c("_1", "_2")) - - schema <- structType(structField(x = "a", type = "integer", nullable = TRUE), - structField(x = "b", type = "string", nullable = TRUE)) - df <- toDF(rdd, schema) - expect_is(df, "DataFrame") - expect_equal(columns(df), c("a", "b")) - expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) - - rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) - df <- toDF(rdd) - expect_is(df, "DataFrame") - expect_equal(count(df), 10) - expect_equal(columns(df), c("a", "b")) - expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) -}) - -test_that("create DataFrame from list or data.frame", { - l <- list(list(1, 2), list(3, 4)) - df <- createDataFrame(sqlContext, l, c("a", "b")) - expect_equal(columns(df), c("a", "b")) - - l <- list(list(a=1, b=2), list(a=3, b=4)) - df <- createDataFrame(sqlContext, l) - expect_equal(columns(df), c("a", "b")) - - a <- 1:3 - b <- c("a", "b", "c") - ldf <- data.frame(a, b) - df <- createDataFrame(sqlContext, ldf) - expect_equal(columns(df), c("a", "b")) - expect_equal(dtypes(df), list(c("a", "int"), c("b", 
"string"))) - expect_equal(count(df), 3) - ldf2 <- collect(df) - expect_equal(ldf$a, ldf2$a) -}) - -test_that("create DataFrame with different data types", { - l <- list(a = 1L, b = 2, c = TRUE, d = "ss", e = as.Date("2012-12-13"), - f = as.POSIXct("2015-03-15 12:13:14.056")) - df <- createDataFrame(sqlContext, list(l)) - expect_equal(dtypes(df), list(c("a", "int"), c("b", "double"), c("c", "boolean"), - c("d", "string"), c("e", "date"), c("f", "timestamp"))) - expect_equal(count(df), 1) - expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE)) -}) - -# TODO: enable this test after fix serialization for nested object -#test_that("create DataFrame with nested array and struct", { -# e <- new.env() -# assign("n", 3L, envir = e) -# l <- list(1:10, list("a", "b"), e, list(a="aa", b=3L)) -# df <- createDataFrame(sqlContext, list(l), c("a", "b", "c", "d")) -# expect_equal(dtypes(df), list(c("a", "array"), c("b", "array"), -# c("c", "map"), c("d", "struct"))) -# expect_equal(count(df), 1) -# ldf <- collect(df) -# expect_equal(ldf[1,], l[[1]]) -#}) - -test_that("jsonFile() on a local file returns a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) - expect_is(df, "DataFrame") - expect_equal(count(df), 3) -}) - -test_that("jsonRDD() on a RDD with json string", { - rdd <- parallelize(sc, mockLines) - expect_equal(count(rdd), 3) - df <- jsonRDD(sqlContext, rdd) - expect_is(df, "DataFrame") - expect_equal(count(df), 3) - - rdd2 <- flatMap(rdd, function(x) c(x, x)) - df <- jsonRDD(sqlContext, rdd2) - expect_is(df, "DataFrame") - expect_equal(count(df), 6) -}) - -test_that("test cache, uncache and clearCache", { - df <- jsonFile(sqlContext, jsonPath) - registerTempTable(df, "table1") - cacheTable(sqlContext, "table1") - uncacheTable(sqlContext, "table1") - clearCache(sqlContext) - dropTempTable(sqlContext, "table1") -}) - -test_that("test tableNames and tables", { - df <- jsonFile(sqlContext, jsonPath) - registerTempTable(df, "table1") - expect_equal(length(tableNames(sqlContext)), 1) - df <- tables(sqlContext) - expect_equal(count(df), 1) - dropTempTable(sqlContext, "table1") -}) - -test_that("registerTempTable() results in a queryable table and sql() results in a new DataFrame", { - df <- jsonFile(sqlContext, jsonPath) - registerTempTable(df, "table1") - newdf <- sql(sqlContext, "SELECT * FROM table1 where name = 'Michael'") - expect_is(newdf, "DataFrame") - expect_equal(count(newdf), 1) - dropTempTable(sqlContext, "table1") -}) - -test_that("insertInto() on a registered table", { - df <- read.df(sqlContext, jsonPath, "json") - write.df(df, parquetPath, "parquet", "overwrite") - dfParquet <- read.df(sqlContext, parquetPath, "parquet") - - lines <- c("{\"name\":\"Bob\", \"age\":24}", - "{\"name\":\"James\", \"age\":35}") - jsonPath2 <- tempfile(pattern="jsonPath2", fileext=".tmp") - parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") - writeLines(lines, jsonPath2) - df2 <- read.df(sqlContext, jsonPath2, "json") - write.df(df2, parquetPath2, "parquet", "overwrite") - dfParquet2 <- read.df(sqlContext, parquetPath2, "parquet") - - registerTempTable(dfParquet, "table1") - insertInto(dfParquet2, "table1") - expect_equal(count(sql(sqlContext, "select * from table1")), 5) - expect_equal(first(sql(sqlContext, "select * from table1 order by age"))$name, "Michael") - dropTempTable(sqlContext, "table1") - - registerTempTable(dfParquet, "table1") - insertInto(dfParquet2, "table1", overwrite = TRUE) - expect_equal(count(sql(sqlContext, "select * from table1")), 2) - 
expect_equal(first(sql(sqlContext, "select * from table1 order by age"))$name, "Bob") - dropTempTable(sqlContext, "table1") -}) - -test_that("table() returns a new DataFrame", { - df <- jsonFile(sqlContext, jsonPath) - registerTempTable(df, "table1") - tabledf <- table(sqlContext, "table1") - expect_is(tabledf, "DataFrame") - expect_equal(count(tabledf), 3) - dropTempTable(sqlContext, "table1") -}) - -test_that("toRDD() returns an RRDD", { - df <- jsonFile(sqlContext, jsonPath) - testRDD <- toRDD(df) - expect_is(testRDD, "RDD") - expect_equal(count(testRDD), 3) -}) - -test_that("union on two RDDs created from DataFrames returns an RRDD", { - df <- jsonFile(sqlContext, jsonPath) - RDD1 <- toRDD(df) - RDD2 <- toRDD(df) - unioned <- unionRDD(RDD1, RDD2) - expect_is(unioned, "RDD") - expect_equal(SparkR:::getSerializedMode(unioned), "byte") - expect_equal(collect(unioned)[[2]]$name, "Andy") -}) - -test_that("union on mixed serialization types correctly returns a byte RRDD", { - # Byte RDD - nums <- 1:10 - rdd <- parallelize(sc, nums, 2L) - - # String RDD - textLines <- c("Michael", - "Andy, 30", - "Justin, 19") - textPath <- tempfile(pattern="sparkr-textLines", fileext=".tmp") - writeLines(textLines, textPath) - textRDD <- textFile(sc, textPath) - - df <- jsonFile(sqlContext, jsonPath) - dfRDD <- toRDD(df) - - unionByte <- unionRDD(rdd, dfRDD) - expect_is(unionByte, "RDD") - expect_equal(SparkR:::getSerializedMode(unionByte), "byte") - expect_equal(collect(unionByte)[[1]], 1) - expect_equal(collect(unionByte)[[12]]$name, "Andy") - - unionString <- unionRDD(textRDD, dfRDD) - expect_is(unionString, "RDD") - expect_equal(SparkR:::getSerializedMode(unionString), "byte") - expect_equal(collect(unionString)[[1]], "Michael") - expect_equal(collect(unionString)[[5]]$name, "Andy") -}) - -test_that("objectFile() works with row serialization", { - objectPath <- tempfile(pattern="spark-test", fileext=".tmp") - df <- jsonFile(sqlContext, jsonPath) - dfRDD <- toRDD(df) - saveAsObjectFile(coalesce(dfRDD, 1L), objectPath) - objectIn <- objectFile(sc, objectPath) - - expect_is(objectIn, "RDD") - expect_equal(SparkR:::getSerializedMode(objectIn), "byte") - expect_equal(collect(objectIn)[[2]]$age, 30) -}) - -test_that("lapply() on a DataFrame returns an RDD with the correct columns", { - df <- jsonFile(sqlContext, jsonPath) - testRDD <- lapply(df, function(row) { - row$newCol <- row$age + 5 - row - }) - expect_is(testRDD, "RDD") - collected <- collect(testRDD) - expect_equal(collected[[1]]$name, "Michael") - expect_equal(collected[[2]]$newCol, 35) -}) - -test_that("collect() returns a data.frame", { - df <- jsonFile(sqlContext, jsonPath) - rdf <- collect(df) - expect_true(is.data.frame(rdf)) - expect_equal(names(rdf)[1], "age") - expect_equal(nrow(rdf), 3) - expect_equal(ncol(rdf), 2) -}) - -test_that("limit() returns DataFrame with the correct number of rows", { - df <- jsonFile(sqlContext, jsonPath) - dfLimited <- limit(df, 2) - expect_is(dfLimited, "DataFrame") - expect_equal(count(dfLimited), 2) -}) - -test_that("collect() and take() on a DataFrame return the same number of rows and columns", { - df <- jsonFile(sqlContext, jsonPath) - expect_equal(nrow(collect(df)), nrow(take(df, 10))) - expect_equal(ncol(collect(df)), ncol(take(df, 10))) -}) - -test_that("multiple pipeline transformations result in an RDD with the correct values", { - df <- jsonFile(sqlContext, jsonPath) - first <- lapply(df, function(row) { - row$age <- row$age + 5 - row - }) - second <- lapply(first, function(row) { - row$testCol <- if 
(row$age == 35 && !is.na(row$age)) TRUE else FALSE - row - }) - expect_is(second, "RDD") - expect_equal(count(second), 3) - expect_equal(collect(second)[[2]]$age, 35) - expect_true(collect(second)[[2]]$testCol) - expect_false(collect(second)[[3]]$testCol) -}) - -test_that("cache(), persist(), and unpersist() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) - expect_false(df@env$isCached) - cache(df) - expect_true(df@env$isCached) - - unpersist(df) - expect_false(df@env$isCached) - - persist(df, "MEMORY_AND_DISK") - expect_true(df@env$isCached) - - unpersist(df) - expect_false(df@env$isCached) - - # make sure the data is collectable - expect_true(is.data.frame(collect(df))) -}) - -test_that("schema(), dtypes(), columns(), names() return the correct values/format", { - df <- jsonFile(sqlContext, jsonPath) - testSchema <- schema(df) - expect_equal(length(testSchema$fields()), 2) - expect_equal(testSchema$fields()[[1]]$dataType.toString(), "LongType") - expect_equal(testSchema$fields()[[2]]$dataType.simpleString(), "string") - expect_equal(testSchema$fields()[[1]]$name(), "age") - - testTypes <- dtypes(df) - expect_equal(length(testTypes[[1]]), 2) - expect_equal(testTypes[[1]][1], "age") - - testCols <- columns(df) - expect_equal(length(testCols), 2) - expect_equal(testCols[2], "name") - - testNames <- names(df) - expect_equal(length(testNames), 2) - expect_equal(testNames[2], "name") -}) - -test_that("head() and first() return the correct data", { - df <- jsonFile(sqlContext, jsonPath) - testHead <- head(df) - expect_equal(nrow(testHead), 3) - expect_equal(ncol(testHead), 2) - - testHead2 <- head(df, 2) - expect_equal(nrow(testHead2), 2) - expect_equal(ncol(testHead2), 2) - - testFirst <- first(df) - expect_equal(nrow(testFirst), 1) -}) - -test_that("distinct() and unique on DataFrames", { - lines <- c("{\"name\":\"Michael\"}", - "{\"name\":\"Andy\", \"age\":30}", - "{\"name\":\"Justin\", \"age\":19}", - "{\"name\":\"Justin\", \"age\":19}") - jsonPathWithDup <- tempfile(pattern="sparkr-test", fileext=".tmp") - writeLines(lines, jsonPathWithDup) - - df <- jsonFile(sqlContext, jsonPathWithDup) - uniques <- distinct(df) - expect_is(uniques, "DataFrame") - expect_equal(count(uniques), 3) - - uniques2 <- unique(df) - expect_is(uniques2, "DataFrame") - expect_equal(count(uniques2), 3) -}) - -test_that("sample on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) - sampled <- sample(df, FALSE, 1.0) - expect_equal(nrow(collect(sampled)), count(df)) - expect_is(sampled, "DataFrame") - sampled2 <- sample(df, FALSE, 0.1) - expect_true(count(sampled2) < 3) - - # Also test sample_frac - sampled3 <- sample_frac(df, FALSE, 0.1) - expect_true(count(sampled3) < 3) -}) - -test_that("select operators", { - df <- select(jsonFile(sqlContext, jsonPath), "name", "age") - expect_is(df$name, "Column") - expect_is(df[[2]], "Column") - expect_is(df[["age"]], "Column") - - expect_is(df[,1], "DataFrame") - expect_equal(columns(df[,1]), c("name")) - expect_equal(columns(df[,"age"]), c("age")) - df2 <- df[,c("age", "name")] - expect_is(df2, "DataFrame") - expect_equal(columns(df2), c("age", "name")) - - df$age2 <- df$age - expect_equal(columns(df), c("name", "age", "age2")) - expect_equal(count(where(df, df$age2 == df$age)), 2) - df$age2 <- df$age * 2 - expect_equal(columns(df), c("name", "age", "age2")) - expect_equal(count(where(df, df$age2 == df$age * 2)), 2) - - df$age2 <- NULL - expect_equal(columns(df), c("name", "age")) - df$age3 <- NULL - expect_equal(columns(df), c("name", "age")) -}) - 
-test_that("select with column", { - df <- jsonFile(sqlContext, jsonPath) - df1 <- select(df, "name") - expect_equal(columns(df1), c("name")) - expect_equal(count(df1), 3) - - df2 <- select(df, df$age) - expect_equal(columns(df2), c("age")) - expect_equal(count(df2), 3) -}) - -test_that("selectExpr() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) - selected <- selectExpr(df, "age * 2") - expect_equal(names(selected), "(age * 2)") - expect_equal(collect(selected), collect(select(df, df$age * 2L))) - - selected2 <- selectExpr(df, "name as newName", "abs(age) as age") - expect_equal(names(selected2), c("newName", "age")) - expect_equal(count(selected2), 3) -}) - -test_that("column calculation", { - df <- jsonFile(sqlContext, jsonPath) - d <- collect(select(df, alias(df$age + 1, "age2"))) - expect_equal(names(d), c("age2")) - df2 <- select(df, lower(df$name), abs(df$age)) - expect_is(df2, "DataFrame") - expect_equal(count(df2), 3) -}) - -test_that("read.df() from json file", { - df <- read.df(sqlContext, jsonPath, "json") - expect_is(df, "DataFrame") - expect_equal(count(df), 3) - - # Check if we can apply a user defined schema - schema <- structType(structField("name", type = "string"), - structField("age", type = "double")) - - df1 <- read.df(sqlContext, jsonPath, "json", schema) - expect_is(df1, "DataFrame") - expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double"))) - - # Run the same with loadDF - df2 <- loadDF(sqlContext, jsonPath, "json", schema) - expect_is(df2, "DataFrame") - expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double"))) -}) - -test_that("write.df() as parquet file", { - df <- read.df(sqlContext, jsonPath, "json") - write.df(df, parquetPath, "parquet", mode="overwrite") - df2 <- read.df(sqlContext, parquetPath, "parquet") - expect_is(df2, "DataFrame") - expect_equal(count(df2), 3) -}) - -test_that("test HiveContext", { - hiveCtx <- tryCatch({ - newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) - }, - error = function(err) { - skip("Hive is not build with SparkSQL, skipped") - }) - df <- createExternalTable(hiveCtx, "json", jsonPath, "json") - expect_is(df, "DataFrame") - expect_equal(count(df), 3) - df2 <- sql(hiveCtx, "select * from json") - expect_is(df2, "DataFrame") - expect_equal(count(df2), 3) - - jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") - saveAsTable(df, "json", "json", "append", path = jsonPath2) - df3 <- sql(hiveCtx, "select * from json") - expect_is(df3, "DataFrame") - expect_equal(count(df3), 6) -}) - -test_that("column operators", { - c <- SparkR:::col("a") - c2 <- (- c + 1 - 2) * 3 / 4.0 - c3 <- (c + c2 - c2) * c2 %% c2 - c4 <- (c > c2) & (c2 <= c3) | (c == c2) & (c2 != c3) - c5 <- c2 ^ c3 ^ c4 -}) - -test_that("column functions", { - c <- SparkR:::col("a") - c2 <- min(c) + max(c) + sum(c) + avg(c) + count(c) + abs(c) + sqrt(c) - c3 <- lower(c) + upper(c) + first(c) + last(c) - c4 <- approxCountDistinct(c) + countDistinct(c) + cast(c, "string") - c5 <- n(c) + n_distinct(c) - c5 <- acos(c) + asin(c) + atan(c) + cbrt(c) - c6 <- ceiling(c) + cos(c) + cosh(c) + exp(c) + expm1(c) - c7 <- floor(c) + log(c) + log10(c) + log1p(c) + rint(c) - c8 <- sign(c) + sin(c) + sinh(c) + tan(c) + tanh(c) - c9 <- toDegrees(c) + toRadians(c) - - df <- jsonFile(sqlContext, jsonPath) - df2 <- select(df, between(df$age, c(20, 30)), between(df$age, c(10, 20))) - expect_equal(collect(df2)[[2, 1]], TRUE) - expect_equal(collect(df2)[[2, 2]], FALSE) - expect_equal(collect(df2)[[3, 1]], FALSE) - 
expect_equal(collect(df2)[[3, 2]], TRUE) - - df3 <- select(df, between(df$name, c("Apache", "Spark"))) - expect_equal(collect(df3)[[1, 1]], TRUE) - expect_equal(collect(df3)[[2, 1]], FALSE) - expect_equal(collect(df3)[[3, 1]], TRUE) -}) - -test_that("column binary mathfunctions", { - lines <- c("{\"a\":1, \"b\":5}", - "{\"a\":2, \"b\":6}", - "{\"a\":3, \"b\":7}", - "{\"a\":4, \"b\":8}") - jsonPathWithDup <- tempfile(pattern="sparkr-test", fileext=".tmp") - writeLines(lines, jsonPathWithDup) - df <- jsonFile(sqlContext, jsonPathWithDup) - expect_equal(collect(select(df, atan2(df$a, df$b)))[1, "ATAN2(a, b)"], atan2(1, 5)) - expect_equal(collect(select(df, atan2(df$a, df$b)))[2, "ATAN2(a, b)"], atan2(2, 6)) - expect_equal(collect(select(df, atan2(df$a, df$b)))[3, "ATAN2(a, b)"], atan2(3, 7)) - expect_equal(collect(select(df, atan2(df$a, df$b)))[4, "ATAN2(a, b)"], atan2(4, 8)) - ## nolint start - expect_equal(collect(select(df, hypot(df$a, df$b)))[1, "HYPOT(a, b)"], sqrt(1^2 + 5^2)) - expect_equal(collect(select(df, hypot(df$a, df$b)))[2, "HYPOT(a, b)"], sqrt(2^2 + 6^2)) - expect_equal(collect(select(df, hypot(df$a, df$b)))[3, "HYPOT(a, b)"], sqrt(3^2 + 7^2)) - expect_equal(collect(select(df, hypot(df$a, df$b)))[4, "HYPOT(a, b)"], sqrt(4^2 + 8^2)) - ## nolint end -}) - -test_that("string operators", { - df <- jsonFile(sqlContext, jsonPath) - expect_equal(count(where(df, like(df$name, "A%"))), 1) - expect_equal(count(where(df, startsWith(df$name, "A"))), 1) - expect_equal(first(select(df, substr(df$name, 1, 2)))[[1]], "Mi") - expect_equal(collect(select(df, cast(df$age, "string")))[[2, 1]], "30") -}) - -test_that("group by", { - df <- jsonFile(sqlContext, jsonPath) - df1 <- agg(df, name = "max", age = "sum") - expect_equal(1, count(df1)) - df1 <- agg(df, age2 = max(df$age)) - expect_equal(1, count(df1)) - expect_equal(columns(df1), c("age2")) - - gd <- groupBy(df, "name") - expect_is(gd, "GroupedData") - df2 <- count(gd) - expect_is(df2, "DataFrame") - expect_equal(3, count(df2)) - - # Also test group_by, summarize, mean - gd1 <- group_by(df, "name") - expect_is(gd1, "GroupedData") - df_summarized <- summarize(gd, mean_age = mean(df$age)) - expect_is(df_summarized, "DataFrame") - expect_equal(3, count(df_summarized)) - - df3 <- agg(gd, age = "sum") - expect_is(df3, "DataFrame") - expect_equal(3, count(df3)) - - df3 <- agg(gd, age = sum(df$age)) - expect_is(df3, "DataFrame") - expect_equal(3, count(df3)) - expect_equal(columns(df3), c("name", "age")) - - df4 <- sum(gd, "age") - expect_is(df4, "DataFrame") - expect_equal(3, count(df4)) - expect_equal(3, count(mean(gd, "age"))) - expect_equal(3, count(max(gd, "age"))) -}) - -test_that("arrange() and orderBy() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) - sorted <- arrange(df, df$age) - expect_equal(collect(sorted)[1,2], "Michael") - - sorted2 <- arrange(df, "name") - expect_equal(collect(sorted2)[2,"age"], 19) - - sorted3 <- orderBy(df, asc(df$age)) - expect_true(is.na(first(sorted3)$age)) - expect_equal(collect(sorted3)[2, "age"], 19) - - sorted4 <- orderBy(df, desc(df$name)) - expect_equal(first(sorted4)$name, "Michael") - expect_equal(collect(sorted4)[3,"name"], "Andy") -}) - -test_that("filter() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) - filtered <- filter(df, "age > 20") - expect_equal(count(filtered), 1) - expect_equal(collect(filtered)$name, "Andy") - filtered2 <- where(df, df$name != "Michael") - expect_equal(count(filtered2), 2) - expect_equal(collect(filtered2)$age[2], 19) - - # test suites for %in% - 
filtered3 <- filter(df, "age in (19)") - expect_equal(count(filtered3), 1) - filtered4 <- filter(df, "age in (19, 30)") - expect_equal(count(filtered4), 2) - filtered5 <- where(df, df$age %in% c(19)) - expect_equal(count(filtered5), 1) - filtered6 <- where(df, df$age %in% c(19, 30)) - expect_equal(count(filtered6), 2) -}) - -test_that("join() and merge() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) - - mockLines2 <- c("{\"name\":\"Michael\", \"test\": \"yes\"}", - "{\"name\":\"Andy\", \"test\": \"no\"}", - "{\"name\":\"Justin\", \"test\": \"yes\"}", - "{\"name\":\"Bob\", \"test\": \"yes\"}") - jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") - writeLines(mockLines2, jsonPath2) - df2 <- jsonFile(sqlContext, jsonPath2) - - joined <- join(df, df2) - expect_equal(names(joined), c("age", "name", "name", "test")) - expect_equal(count(joined), 12) - - joined2 <- join(df, df2, df$name == df2$name) - expect_equal(names(joined2), c("age", "name", "name", "test")) - expect_equal(count(joined2), 3) - - joined3 <- join(df, df2, df$name == df2$name, "right_outer") - expect_equal(names(joined3), c("age", "name", "name", "test")) - expect_equal(count(joined3), 4) - expect_true(is.na(collect(orderBy(joined3, joined3$age))$age[2])) - - joined4 <- select(join(df, df2, df$name == df2$name, "outer"), - alias(df$age + 5, "newAge"), df$name, df2$test) - expect_equal(names(joined4), c("newAge", "name", "test")) - expect_equal(count(joined4), 4) - expect_equal(collect(orderBy(joined4, joined4$name))$newAge[3], 24) - - merged <- select(merge(df, df2, df$name == df2$name, "outer"), - alias(df$age + 5, "newAge"), df$name, df2$test) - expect_equal(names(merged), c("newAge", "name", "test")) - expect_equal(count(merged), 4) - expect_equal(collect(orderBy(merged, joined4$name))$newAge[3], 24) -}) - -test_that("toJSON() returns an RDD of the correct values", { - df <- jsonFile(sqlContext, jsonPath) - testRDD <- toJSON(df) - expect_is(testRDD, "RDD") - expect_equal(SparkR:::getSerializedMode(testRDD), "string") - expect_equal(collect(testRDD)[[1]], mockLines[1]) -}) - -test_that("showDF()", { - df <- jsonFile(sqlContext, jsonPath) - s <- capture.output(showDF(df)) - expected <- paste("+----+-------+\n", - "| age| name|\n", - "+----+-------+\n", - "|null|Michael|\n", - "| 30| Andy|\n", - "| 19| Justin|\n", - "+----+-------+\n", sep="") - expect_output(s , expected) -}) - -test_that("isLocal()", { - df <- jsonFile(sqlContext, jsonPath) - expect_false(isLocal(df)) -}) - -test_that("unionAll(), rbind(), except(), and intersect() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) - - lines <- c("{\"name\":\"Bob\", \"age\":24}", - "{\"name\":\"Andy\", \"age\":30}", - "{\"name\":\"James\", \"age\":35}") - jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") - writeLines(lines, jsonPath2) - df2 <- read.df(sqlContext, jsonPath2, "json") - - unioned <- arrange(unionAll(df, df2), df$age) - expect_is(unioned, "DataFrame") - expect_equal(count(unioned), 6) - expect_equal(first(unioned)$name, "Michael") - - unioned2 <- arrange(rbind(unioned, df, df2), df$age) - expect_is(unioned2, "DataFrame") - expect_equal(count(unioned2), 12) - expect_equal(first(unioned2)$name, "Michael") - - excepted <- arrange(except(df, df2), desc(df$age)) - expect_is(unioned, "DataFrame") - expect_equal(count(excepted), 2) - expect_equal(first(excepted)$name, "Justin") - - intersected <- arrange(intersect(df, df2), df$age) - expect_is(unioned, "DataFrame") - expect_equal(count(intersected), 1) - 
expect_equal(first(intersected)$name, "Andy") -}) - -test_that("withColumn() and withColumnRenamed()", { - df <- jsonFile(sqlContext, jsonPath) - newDF <- withColumn(df, "newAge", df$age + 2) - expect_equal(length(columns(newDF)), 3) - expect_equal(columns(newDF)[3], "newAge") - expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 32) - - newDF2 <- withColumnRenamed(df, "age", "newerAge") - expect_equal(length(columns(newDF2)), 2) - expect_equal(columns(newDF2)[1], "newerAge") -}) - -test_that("mutate(), rename() and names()", { - df <- jsonFile(sqlContext, jsonPath) - newDF <- mutate(df, newAge = df$age + 2) - expect_equal(length(columns(newDF)), 3) - expect_equal(columns(newDF)[3], "newAge") - expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 32) - - newDF2 <- rename(df, newerAge = df$age) - expect_equal(length(columns(newDF2)), 2) - expect_equal(columns(newDF2)[1], "newerAge") - - names(newDF2) <- c("newerName", "evenNewerAge") - expect_equal(length(names(newDF2)), 2) - expect_equal(names(newDF2)[1], "newerName") -}) - -test_that("write.df() on DataFrame and works with parquetFile", { - df <- jsonFile(sqlContext, jsonPath) - write.df(df, parquetPath, "parquet", mode="overwrite") - parquetDF <- parquetFile(sqlContext, parquetPath) - expect_is(parquetDF, "DataFrame") - expect_equal(count(df), count(parquetDF)) -}) - -test_that("parquetFile works with multiple input paths", { - df <- jsonFile(sqlContext, jsonPath) - write.df(df, parquetPath, "parquet", mode="overwrite") - parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") - write.df(df, parquetPath2, "parquet", mode="overwrite") - parquetDF <- parquetFile(sqlContext, parquetPath, parquetPath2) - expect_is(parquetDF, "DataFrame") - expect_equal(count(parquetDF), count(df) * 2) -}) - -test_that("describe() and summarize() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) - stats <- describe(df, "age") - expect_equal(collect(stats)[1, "summary"], "count") - expect_equal(collect(stats)[2, "age"], "24.5") - expect_equal(collect(stats)[3, "age"], "5.5") - stats <- describe(df) - expect_equal(collect(stats)[4, "name"], "Andy") - expect_equal(collect(stats)[5, "age"], "30") - - stats2 <- summary(df) - expect_equal(collect(stats2)[4, "name"], "Andy") - expect_equal(collect(stats2)[5, "age"], "30") -}) - -test_that("dropna() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPathNa) - rows <- collect(df) - - # drop with columns - - expected <- rows[!is.na(rows$name),] - actual <- collect(dropna(df, cols = "name")) - expect_identical(expected, actual) - - expected <- rows[!is.na(rows$age),] - actual <- collect(dropna(df, cols = "age")) - row.names(expected) <- row.names(actual) - # identical on two dataframes does not work here. Don't know why. - # use identical on all columns as a workaround. 
- expect_identical(expected$age, actual$age) - expect_identical(expected$height, actual$height) - expect_identical(expected$name, actual$name) - - expected <- rows[!is.na(rows$age) & !is.na(rows$height),] - actual <- collect(dropna(df, cols = c("age", "height"))) - expect_identical(expected, actual) - - expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] - actual <- collect(dropna(df)) - expect_identical(expected, actual) - - # drop with how - - expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] - actual <- collect(dropna(df)) - expect_identical(expected, actual) - - expected <- rows[!is.na(rows$age) | !is.na(rows$height) | !is.na(rows$name),] - actual <- collect(dropna(df, "all")) - expect_identical(expected, actual) - - expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] - actual <- collect(dropna(df, "any")) - expect_identical(expected, actual) - - expected <- rows[!is.na(rows$age) & !is.na(rows$height),] - actual <- collect(dropna(df, "any", cols = c("age", "height"))) - expect_identical(expected, actual) - - expected <- rows[!is.na(rows$age) | !is.na(rows$height),] - actual <- collect(dropna(df, "all", cols = c("age", "height"))) - expect_identical(expected, actual) - - # drop with threshold - - expected <- rows[as.integer(!is.na(rows$age)) + as.integer(!is.na(rows$height)) >= 2,] - actual <- collect(dropna(df, minNonNulls = 2, cols = c("age", "height"))) - expect_identical(expected, actual) - - expected <- rows[as.integer(!is.na(rows$age)) + - as.integer(!is.na(rows$height)) + - as.integer(!is.na(rows$name)) >= 3,] - actual <- collect(dropna(df, minNonNulls = 3, cols = c("name", "age", "height"))) - expect_identical(expected, actual) -}) - -test_that("fillna() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPathNa) - rows <- collect(df) - - # fill with value - - expected <- rows - expected$age[is.na(expected$age)] <- 50 - expected$height[is.na(expected$height)] <- 50.6 - actual <- collect(fillna(df, 50.6)) - expect_identical(expected, actual) - - expected <- rows - expected$name[is.na(expected$name)] <- "unknown" - actual <- collect(fillna(df, "unknown")) - expect_identical(expected, actual) - - expected <- rows - expected$age[is.na(expected$age)] <- 50 - actual <- collect(fillna(df, 50.6, "age")) - expect_identical(expected, actual) - - expected <- rows - expected$name[is.na(expected$name)] <- "unknown" - actual <- collect(fillna(df, "unknown", c("age", "name"))) - expect_identical(expected, actual) - - # fill with named list - - expected <- rows - expected$age[is.na(expected$age)] <- 50 - expected$height[is.na(expected$height)] <- 50.6 - expected$name[is.na(expected$name)] <- "unknown" - actual <- collect(fillna(df, list("age" = 50, "height" = 50.6, "name" = "unknown"))) - expect_identical(expected, actual) -}) - -test_that("crosstab() on a DataFrame", { - rdd <- lapply(parallelize(sc, 0:3), function(x) { - list(paste0("a", x %% 3), paste0("b", x %% 2)) - }) - df <- toDF(rdd, list("a", "b")) - ct <- crosstab(df, "a", "b") - ordered <- ct[order(ct$a_b),] - row.names(ordered) <- NULL - expected <- data.frame("a_b" = c("a0", "a1", "a2"), "b0" = c(1, 0, 1), "b1" = c(1, 1, 0), - stringsAsFactors = FALSE, row.names = NULL) - expect_identical(expected, ordered) -}) - -test_that("SQL error message is returned from JVM", { - retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e) - expect_equal(grepl("Table Not Found: blah", retError), TRUE) -}) - -unlink(parquetPath) 
-unlink(jsonPath) -unlink(jsonPathNa) diff --git a/R/pkg/inst/tests/jarTest.R b/R/pkg/inst/tests/testthat/jarTest.R similarity index 100% rename from R/pkg/inst/tests/jarTest.R rename to R/pkg/inst/tests/testthat/jarTest.R diff --git a/dev/run-tests-codes.sh b/R/pkg/inst/tests/testthat/packageInAJarTest.R similarity index 71% rename from dev/run-tests-codes.sh rename to R/pkg/inst/tests/testthat/packageInAJarTest.R index f4b238e1b78a..207a37a0cb47 100644 --- a/dev/run-tests-codes.sh +++ b/R/pkg/inst/tests/testthat/packageInAJarTest.R @@ -1,5 +1,3 @@ -#!/usr/bin/env bash - # # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with @@ -16,14 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. # +library(SparkR) +library(sparkPackageTest) + +sc <- sparkR.init() + +run1 <- myfunc(5L) + +run2 <- myfunc(-4L) + +sparkR.stop() + +if(run1 != 6) quit(save = "no", status = 1) -readonly BLOCK_GENERAL=10 -readonly BLOCK_RAT=11 -readonly BLOCK_SCALA_STYLE=12 -readonly BLOCK_PYTHON_STYLE=13 -readonly BLOCK_DOCUMENTATION=14 -readonly BLOCK_BUILD=15 -readonly BLOCK_MIMA=16 -readonly BLOCK_SPARK_UNIT_TESTS=17 -readonly BLOCK_PYSPARK_UNIT_TESTS=18 -readonly BLOCK_SPARKR_UNIT_TESTS=19 +if(run2 != -3) quit(save = "no", status = 1) diff --git a/R/pkg/inst/tests/testthat/test_Serde.R b/R/pkg/inst/tests/testthat/test_Serde.R new file mode 100644 index 000000000000..dddce54d7044 --- /dev/null +++ b/R/pkg/inst/tests/testthat/test_Serde.R @@ -0,0 +1,77 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +context("SerDe functionality") + +sc <- sparkR.init() + +test_that("SerDe of primitive types", { + x <- callJStatic("SparkRHandler", "echo", 1L) + expect_equal(x, 1L) + expect_equal(class(x), "integer") + + x <- callJStatic("SparkRHandler", "echo", 1) + expect_equal(x, 1) + expect_equal(class(x), "numeric") + + x <- callJStatic("SparkRHandler", "echo", TRUE) + expect_true(x) + expect_equal(class(x), "logical") + + x <- callJStatic("SparkRHandler", "echo", "abc") + expect_equal(x, "abc") + expect_equal(class(x), "character") +}) + +test_that("SerDe of list of primitive types", { + x <- list(1L, 2L, 3L) + y <- callJStatic("SparkRHandler", "echo", x) + expect_equal(x, y) + expect_equal(class(y[[1]]), "integer") + + x <- list(1, 2, 3) + y <- callJStatic("SparkRHandler", "echo", x) + expect_equal(x, y) + expect_equal(class(y[[1]]), "numeric") + + x <- list(TRUE, FALSE) + y <- callJStatic("SparkRHandler", "echo", x) + expect_equal(x, y) + expect_equal(class(y[[1]]), "logical") + + x <- list("a", "b", "c") + y <- callJStatic("SparkRHandler", "echo", x) + expect_equal(x, y) + expect_equal(class(y[[1]]), "character") + + # Empty list + x <- list() + y <- callJStatic("SparkRHandler", "echo", x) + expect_equal(x, y) +}) + +test_that("SerDe of list of lists", { + x <- list(list(1L, 2L, 3L), list(1, 2, 3), + list(TRUE, FALSE), list("a", "b", "c")) + y <- callJStatic("SparkRHandler", "echo", x) + expect_equal(x, y) + + # List of empty lists + x <- list(list(), list()) + y <- callJStatic("SparkRHandler", "echo", x) + expect_equal(x, y) +}) diff --git a/R/pkg/inst/tests/test_binaryFile.R b/R/pkg/inst/tests/testthat/test_binaryFile.R similarity index 100% rename from R/pkg/inst/tests/test_binaryFile.R rename to R/pkg/inst/tests/testthat/test_binaryFile.R diff --git a/R/pkg/inst/tests/test_binary_function.R b/R/pkg/inst/tests/testthat/test_binary_function.R similarity index 100% rename from R/pkg/inst/tests/test_binary_function.R rename to R/pkg/inst/tests/testthat/test_binary_function.R diff --git a/R/pkg/inst/tests/test_broadcast.R b/R/pkg/inst/tests/testthat/test_broadcast.R similarity index 100% rename from R/pkg/inst/tests/test_broadcast.R rename to R/pkg/inst/tests/testthat/test_broadcast.R diff --git a/R/pkg/inst/tests/test_client.R b/R/pkg/inst/tests/testthat/test_client.R similarity index 76% rename from R/pkg/inst/tests/test_client.R rename to R/pkg/inst/tests/testthat/test_client.R index 8a20991f89af..a0664f32f31c 100644 --- a/R/pkg/inst/tests/test_client.R +++ b/R/pkg/inst/tests/testthat/test_client.R @@ -34,3 +34,12 @@ test_that("no package specified doesn't add packages flag", { test_that("multiple packages don't produce a warning", { expect_that(generateSparkSubmitArgs("", "", "", "", c("A", "B")), not(gives_warning())) }) + +test_that("sparkJars sparkPackages as character vectors", { + args <- generateSparkSubmitArgs("", "", c("one.jar", "two.jar", "three.jar"), "", + c("com.databricks:spark-avro_2.10:2.0.1", + "com.databricks:spark-csv_2.10:1.3.0")) + expect_match(args, "--jars one.jar,two.jar,three.jar") + expect_match(args, + "--packages com.databricks:spark-avro_2.10:2.0.1,com.databricks:spark-csv_2.10:1.3.0") +}) diff --git a/R/pkg/inst/tests/testthat/test_context.R b/R/pkg/inst/tests/testthat/test_context.R new file mode 100644 index 000000000000..1707e314beff --- /dev/null +++ b/R/pkg/inst/tests/testthat/test_context.R @@ -0,0 +1,114 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("test functions in sparkR.R") + +test_that("repeatedly starting and stopping SparkR", { + for (i in 1:4) { + sc <- sparkR.init() + rdd <- parallelize(sc, 1:20, 2L) + expect_equal(count(rdd), 20) + sparkR.stop() + } +}) + +test_that("repeatedly starting and stopping SparkR SQL", { + for (i in 1:4) { + sc <- sparkR.init() + sqlContext <- sparkRSQL.init(sc) + df <- createDataFrame(sqlContext, data.frame(a = 1:20)) + expect_equal(count(df), 20) + sparkR.stop() + } +}) + +test_that("rdd GC across sparkR.stop", { + sparkR.stop() + sc <- sparkR.init() # sc should get id 0 + rdd1 <- parallelize(sc, 1:20, 2L) # rdd1 should get id 1 + rdd2 <- parallelize(sc, 1:10, 2L) # rdd2 should get id 2 + sparkR.stop() + + sc <- sparkR.init() # sc should get id 0 again + + # GC rdd1 before creating rdd3 and rdd2 after + rm(rdd1) + gc() + + rdd3 <- parallelize(sc, 1:20, 2L) # rdd3 should get id 1 now + rdd4 <- parallelize(sc, 1:10, 2L) # rdd4 should get id 2 now + + rm(rdd2) + gc() + + count(rdd3) + count(rdd4) +}) + +test_that("job group functions can be called", { + sc <- sparkR.init() + setJobGroup(sc, "groupId", "job description", TRUE) + cancelJobGroup(sc, "groupId") + clearJobGroup(sc) +}) + +test_that("getClientModeSparkSubmitOpts() returns spark-submit args from whitelist", { + e <- new.env() + e[["spark.driver.memory"]] <- "512m" + ops <- getClientModeSparkSubmitOpts("sparkrmain", e) + expect_equal("--driver-memory \"512m\" sparkrmain", ops) + + e[["spark.driver.memory"]] <- "5g" + e[["spark.driver.extraClassPath"]] <- "/opt/class_path" # nolint + e[["spark.driver.extraJavaOptions"]] <- "-XX:+UseCompressedOops -XX:+UseCompressedStrings" + e[["spark.driver.extraLibraryPath"]] <- "/usr/local/hadoop/lib" # nolint + e[["random"]] <- "skipthis" + ops2 <- getClientModeSparkSubmitOpts("sparkr-shell", e) + # nolint start + expect_equal(ops2, paste0("--driver-class-path \"/opt/class_path\" --driver-java-options \"", + "-XX:+UseCompressedOops -XX:+UseCompressedStrings\" --driver-library-path \"", + "/usr/local/hadoop/lib\" --driver-memory \"5g\" sparkr-shell")) + # nolint end + + e[["spark.driver.extraClassPath"]] <- "/" # too short + ops3 <- getClientModeSparkSubmitOpts("--driver-memory 4g sparkr-shell2", e) + # nolint start + expect_equal(ops3, paste0("--driver-java-options \"-XX:+UseCompressedOops ", + "-XX:+UseCompressedStrings\" --driver-library-path \"/usr/local/hadoop/lib\"", + " --driver-memory 4g sparkr-shell2")) + # nolint end +}) + +test_that("sparkJars sparkPackages as comma-separated strings", { + expect_warning(processSparkJars(" a, b ")) + jars <- suppressWarnings(processSparkJars(" a, b ")) + expect_equal(jars, c("a", "b")) + + jars <- suppressWarnings(processSparkJars(" abc ,, def ")) + expect_equal(jars, c("abc", "def")) + + jars <- suppressWarnings(processSparkJars(c(" abc ,, def ", "", "xyz", " ", "a,b"))) + 
expect_equal(jars, c("abc", "def", "xyz", "a", "b")) + + p <- processSparkPackages(c("ghi", "lmn")) + expect_equal(p, c("ghi", "lmn")) + + # check normalizePath + f <- dir()[[1]] + expect_that(processSparkJars(f), not(gives_warning())) + expect_match(processSparkJars(f), f) +}) diff --git a/R/pkg/inst/tests/test_includeJAR.R b/R/pkg/inst/tests/testthat/test_includeJAR.R similarity index 94% rename from R/pkg/inst/tests/test_includeJAR.R rename to R/pkg/inst/tests/testthat/test_includeJAR.R index cc1faeabffe3..f89aa8e507fd 100644 --- a/R/pkg/inst/tests/test_includeJAR.R +++ b/R/pkg/inst/tests/testthat/test_includeJAR.R @@ -20,7 +20,7 @@ runScript <- function() { sparkHome <- Sys.getenv("SPARK_HOME") sparkTestJarPath <- "R/lib/SparkR/test_support/sparktestjar_2.10-1.0.jar" jarPath <- paste("--jars", shQuote(file.path(sparkHome, sparkTestJarPath))) - scriptPath <- file.path(sparkHome, "R/lib/SparkR/tests/jarTest.R") + scriptPath <- file.path(sparkHome, "R/lib/SparkR/tests/testthat/jarTest.R") submitPath <- file.path(sparkHome, "bin/spark-submit") res <- system2(command = submitPath, args = c(jarPath, scriptPath), diff --git a/R/pkg/inst/tests/test_includePackage.R b/R/pkg/inst/tests/testthat/test_includePackage.R similarity index 100% rename from R/pkg/inst/tests/test_includePackage.R rename to R/pkg/inst/tests/testthat/test_includePackage.R diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R new file mode 100644 index 000000000000..08099dd96a87 --- /dev/null +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -0,0 +1,115 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +library(testthat) + +context("MLlib functions") + +# Tests for MLlib functions in SparkR + +sc <- sparkR.init() + +sqlContext <- sparkRSQL.init(sc) + +test_that("glm and predict", { + training <- suppressWarnings(createDataFrame(sqlContext, iris)) + test <- select(training, "Sepal_Length") + model <- glm(Sepal_Width ~ Sepal_Length, training, family = "gaussian") + prediction <- predict(model, test) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") + + # Test stats::predict is working + x <- rnorm(15) + y <- x + rnorm(15) + expect_equal(length(predict(lm(y ~ x))), 15) +}) + +test_that("glm should work with long formula", { + training <- suppressWarnings(createDataFrame(sqlContext, iris)) + training$LongLongLongLongLongName <- training$Sepal_Width + training$VeryLongLongLongLonLongName <- training$Sepal_Length + training$AnotherLongLongLongLongName <- training$Species + model <- glm(LongLongLongLongLongName ~ VeryLongLongLongLonLongName + AnotherLongLongLongLongName, + data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) +}) + +test_that("predictions match with native glm", { + training <- suppressWarnings(createDataFrame(sqlContext, iris)) + model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) +}) + +test_that("dot minus and intercept vs native glm", { + training <- suppressWarnings(createDataFrame(sqlContext, iris)) + model <- glm(Sepal_Width ~ . - Species + 0, data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ . 
- Species + 0, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) +}) + +test_that("feature interaction vs native glm", { + training <- suppressWarnings(createDataFrame(sqlContext, iris)) + model <- glm(Sepal_Width ~ Species:Sepal_Length, data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ Species:Sepal.Length, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) +}) + +test_that("summary coefficients match with native glm", { + training <- suppressWarnings(createDataFrame(sqlContext, iris)) + stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training, solver = "normal")) + coefs <- unlist(stats$coefficients) + devianceResiduals <- unlist(stats$devianceResiduals) + + rStats <- summary(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)) + rCoefs <- unlist(rStats$coefficients) + rDevianceResiduals <- c(-0.95096, 0.72918) + + expect_true(all(abs(rCoefs - coefs) < 1e-5)) + expect_true(all(abs(rDevianceResiduals - devianceResiduals) < 1e-5)) + expect_true(all( + rownames(stats$coefficients) == + c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica"))) +}) + +test_that("summary coefficients match with native glm of family 'binomial'", { + df <- suppressWarnings(createDataFrame(sqlContext, iris)) + training <- filter(df, df$Species != "setosa") + stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training, + family = "binomial")) + coefs <- as.vector(stats$coefficients[,1]) + + rTraining <- iris[iris$Species %in% c("versicolor","virginica"),] + rCoefs <- as.vector(coef(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining, + family = binomial(link = "logit")))) + + expect_true(all(abs(rCoefs - coefs) < 1e-4)) + expect_true(all( + rownames(stats$coefficients) == + c("(Intercept)", "Sepal_Length", "Sepal_Width"))) +}) + +test_that("summary works on base GLM models", { + baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris) + baseSummary <- summary(baseModel) + expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4) +}) diff --git a/R/pkg/inst/tests/test_parallelize_collect.R b/R/pkg/inst/tests/testthat/test_parallelize_collect.R similarity index 100% rename from R/pkg/inst/tests/test_parallelize_collect.R rename to R/pkg/inst/tests/testthat/test_parallelize_collect.R diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/testthat/test_rdd.R similarity index 99% rename from R/pkg/inst/tests/test_rdd.R rename to R/pkg/inst/tests/testthat/test_rdd.R index 71aed2bb9d6a..7423b4f2bed1 100644 --- a/R/pkg/inst/tests/test_rdd.R +++ b/R/pkg/inst/tests/testthat/test_rdd.R @@ -28,8 +28,8 @@ intPairs <- list(list(1L, -1), list(2L, 100), list(2L, 1), list(1L, 200)) intRdd <- parallelize(sc, intPairs, 2L) test_that("get number of partitions in RDD", { - expect_equal(numPartitions(rdd), 2) - expect_equal(numPartitions(intRdd), 2) + expect_equal(getNumPartitions(rdd), 2) + expect_equal(getNumPartitions(intRdd), 2) }) test_that("first on RDD", { @@ -304,18 +304,18 @@ test_that("repartition/coalesce on RDDs", { # repartition r1 <- repartition(rdd, 2) - expect_equal(numPartitions(r1), 2L) + expect_equal(getNumPartitions(r1), 2L) count <- length(collectPartition(r1, 0L)) expect_true(count >= 8 && count <= 12) r2 <- repartition(rdd, 6) - expect_equal(numPartitions(r2), 6L) + expect_equal(getNumPartitions(r2), 6L) count <- length(collectPartition(r2, 0L)) expect_true(count >= 0 && count <= 4) # coalesce r3 <- 
coalesce(rdd, 1) - expect_equal(numPartitions(r3), 1L) + expect_equal(getNumPartitions(r3), 1L) count <- length(collectPartition(r3, 0L)) expect_equal(count, 20) }) diff --git a/R/pkg/inst/tests/test_shuffle.R b/R/pkg/inst/tests/testthat/test_shuffle.R similarity index 100% rename from R/pkg/inst/tests/test_shuffle.R rename to R/pkg/inst/tests/testthat/test_shuffle.R diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R new file mode 100644 index 000000000000..135c7576e529 --- /dev/null +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -0,0 +1,1767 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(testthat) + +context("SparkSQL functions") + +# Utility function for easily checking the values of a StructField +checkStructField <- function(actual, expectedName, expectedType, expectedNullable) { + expect_equal(class(actual), "structField") + expect_equal(actual$name(), expectedName) + expect_equal(actual$dataType.toString(), expectedType) + expect_equal(actual$nullable(), expectedNullable) +} + +markUtf8 <- function(s) { + Encoding(s) <- "UTF-8" + s +} + +# Tests for SparkSQL functions in SparkR + +sc <- sparkR.init() + +sqlContext <- sparkRSQL.init(sc) + +mockLines <- c("{\"name\":\"Michael\"}", + "{\"name\":\"Andy\", \"age\":30}", + "{\"name\":\"Justin\", \"age\":19}") +jsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") +parquetPath <- tempfile(pattern="sparkr-test", fileext=".parquet") +writeLines(mockLines, jsonPath) + +# For test nafunctions, like dropna(), fillna(),... 
+mockLinesNa <- c("{\"name\":\"Bob\",\"age\":16,\"height\":176.5}", + "{\"name\":\"Alice\",\"age\":null,\"height\":164.3}", + "{\"name\":\"David\",\"age\":60,\"height\":null}", + "{\"name\":\"Amy\",\"age\":null,\"height\":null}", + "{\"name\":null,\"age\":null,\"height\":null}") +jsonPathNa <- tempfile(pattern="sparkr-test", fileext=".tmp") +writeLines(mockLinesNa, jsonPathNa) + +# For test complex types in DataFrame +mockLinesComplexType <- + c("{\"c1\":[1, 2, 3], \"c2\":[\"a\", \"b\", \"c\"], \"c3\":[1.0, 2.0, 3.0]}", + "{\"c1\":[4, 5, 6], \"c2\":[\"d\", \"e\", \"f\"], \"c3\":[4.0, 5.0, 6.0]}", + "{\"c1\":[7, 8, 9], \"c2\":[\"g\", \"h\", \"i\"], \"c3\":[7.0, 8.0, 9.0]}") +complexTypeJsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") +writeLines(mockLinesComplexType, complexTypeJsonPath) + +test_that("infer types and check types", { + expect_equal(infer_type(1L), "integer") + expect_equal(infer_type(1.0), "double") + expect_equal(infer_type("abc"), "string") + expect_equal(infer_type(TRUE), "boolean") + expect_equal(infer_type(as.Date("2015-03-11")), "date") + expect_equal(infer_type(as.POSIXlt("2015-03-11 12:13:04.043")), "timestamp") + expect_equal(infer_type(c(1L, 2L)), "array") + expect_equal(infer_type(list(1L, 2L)), "array") + expect_equal(infer_type(listToStruct(list(a = 1L, b = "2"))), "struct") + e <- new.env() + assign("a", 1L, envir = e) + expect_equal(infer_type(e), "map") + + expect_error(checkType("map"), "Key type in a map must be string or character") + + expect_equal(infer_type(as.raw(c(1, 2, 3))), "binary") +}) + +test_that("structType and structField", { + testField <- structField("a", "string") + expect_is(testField, "structField") + expect_equal(testField$name(), "a") + expect_true(testField$nullable()) + + testSchema <- structType(testField, structField("b", "integer")) + expect_is(testSchema, "structType") + expect_is(testSchema$fields()[[2]], "structField") + expect_equal(testSchema$fields()[[1]]$dataType.toString(), "StringType") +}) + +test_that("create DataFrame from RDD", { + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) + df <- createDataFrame(sqlContext, rdd, list("a", "b")) + dfAsDF <- as.DataFrame(sqlContext, rdd, list("a", "b")) + expect_is(df, "DataFrame") + expect_is(dfAsDF, "DataFrame") + expect_equal(count(df), 10) + expect_equal(count(dfAsDF), 10) + expect_equal(nrow(df), 10) + expect_equal(nrow(dfAsDF), 10) + expect_equal(ncol(df), 2) + expect_equal(ncol(dfAsDF), 2) + expect_equal(dim(df), c(10, 2)) + expect_equal(dim(dfAsDF), c(10, 2)) + expect_equal(columns(df), c("a", "b")) + expect_equal(columns(dfAsDF), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + expect_equal(dtypes(dfAsDF), list(c("a", "int"), c("b", "string"))) + + df <- createDataFrame(sqlContext, rdd) + dfAsDF <- as.DataFrame(sqlContext, rdd) + expect_is(df, "DataFrame") + expect_is(dfAsDF, "DataFrame") + expect_equal(columns(df), c("_1", "_2")) + expect_equal(columns(dfAsDF), c("_1", "_2")) + + schema <- structType(structField(x = "a", type = "integer", nullable = TRUE), + structField(x = "b", type = "string", nullable = TRUE)) + df <- createDataFrame(sqlContext, rdd, schema) + expect_is(df, "DataFrame") + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) + df <- createDataFrame(sqlContext, rdd) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) + 
expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + schema <- structType(structField("name", "string"), structField("age", "integer"), + structField("height", "float")) + df <- read.df(sqlContext, jsonPathNa, "json", schema) + df2 <- createDataFrame(sqlContext, toRDD(df), schema) + df2AsDF <- as.DataFrame(sqlContext, toRDD(df), schema) + expect_equal(columns(df2), c("name", "age", "height")) + expect_equal(columns(df2AsDF), c("name", "age", "height")) + expect_equal(dtypes(df2), list(c("name", "string"), c("age", "int"), c("height", "float"))) + expect_equal(dtypes(df2AsDF), list(c("name", "string"), c("age", "int"), c("height", "float"))) + expect_equal(as.list(collect(where(df2, df2$name == "Bob"))), + list(name = "Bob", age = 16, height = 176.5)) + expect_equal(as.list(collect(where(df2AsDF, df2AsDF$name == "Bob"))), + list(name = "Bob", age = 16, height = 176.5)) + + localDF <- data.frame(name=c("John", "Smith", "Sarah"), + age=c(19L, 23L, 18L), + height=c(176.5, 181.4, 173.7)) + df <- createDataFrame(sqlContext, localDF, schema) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) + expect_equal(columns(df), c("name", "age", "height")) + expect_equal(dtypes(df), list(c("name", "string"), c("age", "int"), c("height", "float"))) + expect_equal(as.list(collect(where(df, df$name == "John"))), + list(name = "John", age = 19L, height = 176.5)) + + ssc <- callJMethod(sc, "sc") + hiveCtx <- tryCatch({ + newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) + }, + error = function(err) { + skip("Hive is not build with SparkSQL, skipped") + }) + sql(hiveCtx, "CREATE TABLE people (name string, age double, height float)") + df <- read.df(hiveCtx, jsonPathNa, "json", schema) + invisible(insertInto(df, "people")) + expect_equal(collect(sql(hiveCtx, "SELECT age from people WHERE name = 'Bob'"))$age, + c(16)) + expect_equal(collect(sql(hiveCtx, "SELECT height from people WHERE name ='Bob'"))$height, + c(176.5)) +}) + +test_that("convert NAs to null type in DataFrames", { + rdd <- parallelize(sc, list(list(1L, 2L), list(NA, 4L))) + df <- createDataFrame(sqlContext, rdd, list("a", "b")) + expect_true(is.na(collect(df)[2, "a"])) + expect_equal(collect(df)[2, "b"], 4L) + + l <- data.frame(x = 1L, y = c(1L, NA_integer_, 3L)) + df <- createDataFrame(sqlContext, l) + expect_equal(collect(df)[2, "x"], 1L) + expect_true(is.na(collect(df)[2, "y"])) + + rdd <- parallelize(sc, list(list(1, 2), list(NA, 4))) + df <- createDataFrame(sqlContext, rdd, list("a", "b")) + expect_true(is.na(collect(df)[2, "a"])) + expect_equal(collect(df)[2, "b"], 4) + + l <- data.frame(x = 1, y = c(1, NA_real_, 3)) + df <- createDataFrame(sqlContext, l) + expect_equal(collect(df)[2, "x"], 1) + expect_true(is.na(collect(df)[2, "y"])) + + l <- list("a", "b", NA, "d") + df <- createDataFrame(sqlContext, l) + expect_true(is.na(collect(df)[3, "_1"])) + expect_equal(collect(df)[4, "_1"], "d") + + l <- list("a", "b", NA_character_, "d") + df <- createDataFrame(sqlContext, l) + expect_true(is.na(collect(df)[3, "_1"])) + expect_equal(collect(df)[4, "_1"], "d") + + l <- list(TRUE, FALSE, NA, TRUE) + df <- createDataFrame(sqlContext, l) + expect_true(is.na(collect(df)[3, "_1"])) + expect_equal(collect(df)[4, "_1"], TRUE) +}) + +test_that("toDF", { + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) + df <- toDF(rdd, list("a", "b")) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) + expect_equal(columns(df), c("a", "b")) + 
expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + df <- toDF(rdd) + expect_is(df, "DataFrame") + expect_equal(columns(df), c("_1", "_2")) + + schema <- structType(structField(x = "a", type = "integer", nullable = TRUE), + structField(x = "b", type = "string", nullable = TRUE)) + df <- toDF(rdd, schema) + expect_is(df, "DataFrame") + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) + df <- toDF(rdd) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) +}) + +test_that("create DataFrame from list or data.frame", { + l <- list(list(1, 2), list(3, 4)) + df <- createDataFrame(sqlContext, l, c("a", "b")) + expect_equal(columns(df), c("a", "b")) + + l <- list(list(a = 1, b = 2), list(a = 3, b = 4)) + df <- createDataFrame(sqlContext, l) + expect_equal(columns(df), c("a", "b")) + + a <- 1:3 + b <- c("a", "b", "c") + ldf <- data.frame(a, b) + df <- createDataFrame(sqlContext, ldf) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + expect_equal(count(df), 3) + ldf2 <- collect(df) + expect_equal(ldf$a, ldf2$a) + + irisdf <- suppressWarnings(createDataFrame(sqlContext, iris)) + iris_collected <- collect(irisdf) + expect_equivalent(iris_collected[,-5], iris[,-5]) + expect_equal(iris_collected$Species, as.character(iris$Species)) + + mtcarsdf <- createDataFrame(sqlContext, mtcars) + expect_equivalent(collect(mtcarsdf), mtcars) + + bytes <- as.raw(c(1, 2, 3)) + df <- createDataFrame(sqlContext, list(list(bytes))) + expect_equal(collect(df)[[1]][[1]], bytes) +}) + +test_that("create DataFrame with different data types", { + l <- list(a = 1L, b = 2, c = TRUE, d = "ss", e = as.Date("2012-12-13"), + f = as.POSIXct("2015-03-15 12:13:14.056")) + df <- createDataFrame(sqlContext, list(l)) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "double"), c("c", "boolean"), + c("d", "string"), c("e", "date"), c("f", "timestamp"))) + expect_equal(count(df), 1) + expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE)) +}) + +test_that("create DataFrame with complex types", { + e <- new.env() + assign("n", 3L, envir = e) + + s <- listToStruct(list(a = "aa", b = 3L)) + + l <- list(as.list(1:10), list("a", "b"), e, s) + df <- createDataFrame(sqlContext, list(l), c("a", "b", "c", "d")) + expect_equal(dtypes(df), list(c("a", "array"), + c("b", "array"), + c("c", "map"), + c("d", "struct"))) + expect_equal(count(df), 1) + ldf <- collect(df) + expect_equal(names(ldf), c("a", "b", "c", "d")) + expect_equal(ldf[1, 1][[1]], l[[1]]) + expect_equal(ldf[1, 2][[1]], l[[2]]) + + e <- ldf$c[[1]] + expect_equal(class(e), "environment") + expect_equal(ls(e), "n") + expect_equal(e$n, 3L) + + s <- ldf$d[[1]] + expect_equal(class(s), "struct") + expect_equal(s$a, "aa") + expect_equal(s$b, 3L) +}) + +test_that("create DataFrame from a data.frame with complex types", { + ldf <- data.frame(row.names = 1:2) + ldf$a_list <- list(list(1, 2), list(3, 4)) + ldf$an_envir <- c(as.environment(list(a = 1, b = 2)), as.environment(list(c = 3))) + + sdf <- createDataFrame(sqlContext, ldf) + collected <- collect(sdf) + + expect_identical(ldf[, 1, FALSE], collected[, 1, FALSE]) + expect_equal(ldf$an_envir, collected$an_envir) +}) + +# For test map type and struct type in DataFrame +mockLinesMapType <- 
c("{\"name\":\"Bob\",\"info\":{\"age\":16,\"height\":176.5}}", + "{\"name\":\"Alice\",\"info\":{\"age\":20,\"height\":164.3}}", + "{\"name\":\"David\",\"info\":{\"age\":60,\"height\":180}}") +mapTypeJsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") +writeLines(mockLinesMapType, mapTypeJsonPath) + +test_that("Collect DataFrame with complex types", { + # ArrayType + df <- read.json(sqlContext, complexTypeJsonPath) + + ldf <- collect(df) + expect_equal(nrow(ldf), 3) + expect_equal(ncol(ldf), 3) + expect_equal(names(ldf), c("c1", "c2", "c3")) + expect_equal(ldf$c1, list(list(1, 2, 3), list(4, 5, 6), list (7, 8, 9))) + expect_equal(ldf$c2, list(list("a", "b", "c"), list("d", "e", "f"), list ("g", "h", "i"))) + expect_equal(ldf$c3, list(list(1.0, 2.0, 3.0), list(4.0, 5.0, 6.0), list (7.0, 8.0, 9.0))) + + # MapType + schema <- structType(structField("name", "string"), + structField("info", "map")) + df <- read.df(sqlContext, mapTypeJsonPath, "json", schema) + expect_equal(dtypes(df), list(c("name", "string"), + c("info", "map"))) + ldf <- collect(df) + expect_equal(nrow(ldf), 3) + expect_equal(ncol(ldf), 2) + expect_equal(names(ldf), c("name", "info")) + expect_equal(ldf$name, c("Bob", "Alice", "David")) + bob <- ldf$info[[1]] + expect_equal(class(bob), "environment") + expect_equal(bob$age, 16) + expect_equal(bob$height, 176.5) + + # StructType + df <- read.json(sqlContext, mapTypeJsonPath) + expect_equal(dtypes(df), list(c("info", "struct"), + c("name", "string"))) + ldf <- collect(df) + expect_equal(nrow(ldf), 3) + expect_equal(ncol(ldf), 2) + expect_equal(names(ldf), c("info", "name")) + expect_equal(ldf$name, c("Bob", "Alice", "David")) + bob <- ldf$info[[1]] + expect_equal(class(bob), "struct") + expect_equal(bob$age, 16) + expect_equal(bob$height, 176.5) +}) + +test_that("read/write json files", { + # Test read.df + df <- read.df(sqlContext, jsonPath, "json") + expect_is(df, "DataFrame") + expect_equal(count(df), 3) + + # Test read.df with a user defined schema + schema <- structType(structField("name", type = "string"), + structField("age", type = "double")) + + df1 <- read.df(sqlContext, jsonPath, "json", schema) + expect_is(df1, "DataFrame") + expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double"))) + + # Test loadDF + df2 <- loadDF(sqlContext, jsonPath, "json", schema) + expect_is(df2, "DataFrame") + expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double"))) + + # Test read.json + df <- read.json(sqlContext, jsonPath) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) + + # Test write.df + jsonPath2 <- tempfile(pattern="jsonPath2", fileext=".json") + write.df(df, jsonPath2, "json", mode="overwrite") + + # Test write.json + jsonPath3 <- tempfile(pattern="jsonPath3", fileext=".json") + write.json(df, jsonPath3) + + # Test read.json()/jsonFile() works with multiple input paths + jsonDF1 <- read.json(sqlContext, c(jsonPath2, jsonPath3)) + expect_is(jsonDF1, "DataFrame") + expect_equal(count(jsonDF1), 6) + # Suppress warnings because jsonFile is deprecated + jsonDF2 <- suppressWarnings(jsonFile(sqlContext, c(jsonPath2, jsonPath3))) + expect_is(jsonDF2, "DataFrame") + expect_equal(count(jsonDF2), 6) + + unlink(jsonPath2) + unlink(jsonPath3) +}) + +test_that("jsonRDD() on a RDD with json string", { + rdd <- parallelize(sc, mockLines) + expect_equal(count(rdd), 3) + df <- jsonRDD(sqlContext, rdd) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) + + rdd2 <- flatMap(rdd, function(x) c(x, x)) + df <- jsonRDD(sqlContext, rdd2) + expect_is(df, 
"DataFrame") + expect_equal(count(df), 6) +}) + +test_that("test cache, uncache and clearCache", { + df <- read.json(sqlContext, jsonPath) + registerTempTable(df, "table1") + cacheTable(sqlContext, "table1") + uncacheTable(sqlContext, "table1") + clearCache(sqlContext) + dropTempTable(sqlContext, "table1") +}) + +test_that("test tableNames and tables", { + df <- read.json(sqlContext, jsonPath) + registerTempTable(df, "table1") + expect_equal(length(tableNames(sqlContext)), 1) + df <- tables(sqlContext) + expect_equal(count(df), 1) + dropTempTable(sqlContext, "table1") +}) + +test_that("registerTempTable() results in a queryable table and sql() results in a new DataFrame", { + df <- read.json(sqlContext, jsonPath) + registerTempTable(df, "table1") + newdf <- sql(sqlContext, "SELECT * FROM table1 where name = 'Michael'") + expect_is(newdf, "DataFrame") + expect_equal(count(newdf), 1) + dropTempTable(sqlContext, "table1") +}) + +test_that("insertInto() on a registered table", { + df <- read.df(sqlContext, jsonPath, "json") + write.df(df, parquetPath, "parquet", "overwrite") + dfParquet <- read.df(sqlContext, parquetPath, "parquet") + + lines <- c("{\"name\":\"Bob\", \"age\":24}", + "{\"name\":\"James\", \"age\":35}") + jsonPath2 <- tempfile(pattern="jsonPath2", fileext=".tmp") + parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") + writeLines(lines, jsonPath2) + df2 <- read.df(sqlContext, jsonPath2, "json") + write.df(df2, parquetPath2, "parquet", "overwrite") + dfParquet2 <- read.df(sqlContext, parquetPath2, "parquet") + + registerTempTable(dfParquet, "table1") + insertInto(dfParquet2, "table1") + expect_equal(count(sql(sqlContext, "select * from table1")), 5) + expect_equal(first(sql(sqlContext, "select * from table1 order by age"))$name, "Michael") + dropTempTable(sqlContext, "table1") + + registerTempTable(dfParquet, "table1") + insertInto(dfParquet2, "table1", overwrite = TRUE) + expect_equal(count(sql(sqlContext, "select * from table1")), 2) + expect_equal(first(sql(sqlContext, "select * from table1 order by age"))$name, "Bob") + dropTempTable(sqlContext, "table1") + + unlink(jsonPath2) + unlink(parquetPath2) +}) + +test_that("table() returns a new DataFrame", { + df <- read.json(sqlContext, jsonPath) + registerTempTable(df, "table1") + tabledf <- table(sqlContext, "table1") + expect_is(tabledf, "DataFrame") + expect_equal(count(tabledf), 3) + dropTempTable(sqlContext, "table1") + + # Test base::table is working + #a <- letters[1:3] + #expect_equal(class(table(a, sample(a))), "table") +}) + +test_that("toRDD() returns an RRDD", { + df <- read.json(sqlContext, jsonPath) + testRDD <- toRDD(df) + expect_is(testRDD, "RDD") + expect_equal(count(testRDD), 3) +}) + +test_that("union on two RDDs created from DataFrames returns an RRDD", { + df <- read.json(sqlContext, jsonPath) + RDD1 <- toRDD(df) + RDD2 <- toRDD(df) + unioned <- unionRDD(RDD1, RDD2) + expect_is(unioned, "RDD") + expect_equal(getSerializedMode(unioned), "byte") + expect_equal(collect(unioned)[[2]]$name, "Andy") +}) + +test_that("union on mixed serialization types correctly returns a byte RRDD", { + # Byte RDD + nums <- 1:10 + rdd <- parallelize(sc, nums, 2L) + + # String RDD + textLines <- c("Michael", + "Andy, 30", + "Justin, 19") + textPath <- tempfile(pattern="sparkr-textLines", fileext=".tmp") + writeLines(textLines, textPath) + textRDD <- textFile(sc, textPath) + + df <- read.json(sqlContext, jsonPath) + dfRDD <- toRDD(df) + + unionByte <- unionRDD(rdd, dfRDD) + expect_is(unionByte, "RDD") + 
expect_equal(getSerializedMode(unionByte), "byte") + expect_equal(collect(unionByte)[[1]], 1) + expect_equal(collect(unionByte)[[12]]$name, "Andy") + + unionString <- unionRDD(textRDD, dfRDD) + expect_is(unionString, "RDD") + expect_equal(getSerializedMode(unionString), "byte") + expect_equal(collect(unionString)[[1]], "Michael") + expect_equal(collect(unionString)[[5]]$name, "Andy") +}) + +test_that("objectFile() works with row serialization", { + objectPath <- tempfile(pattern="spark-test", fileext=".tmp") + df <- read.json(sqlContext, jsonPath) + dfRDD <- toRDD(df) + saveAsObjectFile(coalesce(dfRDD, 1L), objectPath) + objectIn <- objectFile(sc, objectPath) + + expect_is(objectIn, "RDD") + expect_equal(getSerializedMode(objectIn), "byte") + expect_equal(collect(objectIn)[[2]]$age, 30) +}) + +test_that("lapply() on a DataFrame returns an RDD with the correct columns", { + df <- read.json(sqlContext, jsonPath) + testRDD <- lapply(df, function(row) { + row$newCol <- row$age + 5 + row + }) + expect_is(testRDD, "RDD") + collected <- collect(testRDD) + expect_equal(collected[[1]]$name, "Michael") + expect_equal(collected[[2]]$newCol, 35) +}) + +test_that("collect() returns a data.frame", { + df <- read.json(sqlContext, jsonPath) + rdf <- collect(df) + expect_true(is.data.frame(rdf)) + expect_equal(names(rdf)[1], "age") + expect_equal(nrow(rdf), 3) + expect_equal(ncol(rdf), 2) + + # collect() returns data correctly from a DataFrame with 0 row + df0 <- limit(df, 0) + rdf <- collect(df0) + expect_true(is.data.frame(rdf)) + expect_equal(names(rdf)[1], "age") + expect_equal(nrow(rdf), 0) + expect_equal(ncol(rdf), 2) + + # collect() correctly handles multiple columns with same name + df <- createDataFrame(sqlContext, list(list(1, 2)), schema = c("name", "name")) + ldf <- collect(df) + expect_equal(names(ldf), c("name", "name")) +}) + +test_that("limit() returns DataFrame with the correct number of rows", { + df <- read.json(sqlContext, jsonPath) + dfLimited <- limit(df, 2) + expect_is(dfLimited, "DataFrame") + expect_equal(count(dfLimited), 2) +}) + +test_that("collect() and take() on a DataFrame return the same number of rows and columns", { + df <- read.json(sqlContext, jsonPath) + expect_equal(nrow(collect(df)), nrow(take(df, 10))) + expect_equal(ncol(collect(df)), ncol(take(df, 10))) +}) + +test_that("collect() support Unicode characters", { + lines <- c("{\"name\":\"안녕하세요\"}", + "{\"name\":\"您好\", \"age\":30}", + "{\"name\":\"こんにちは\", \"age\":19}", + "{\"name\":\"Xin chào\"}") + + jsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(lines, jsonPath) + + df <- read.df(sqlContext, jsonPath, "json") + rdf <- collect(df) + expect_true(is.data.frame(rdf)) + expect_equal(rdf$name[1], markUtf8("안녕하세요")) + expect_equal(rdf$name[2], markUtf8("您好")) + expect_equal(rdf$name[3], markUtf8("こんにちは")) + expect_equal(rdf$name[4], markUtf8("Xin chào")) + + df1 <- createDataFrame(sqlContext, rdf) + expect_equal(collect(where(df1, df1$name == markUtf8("您好")))$name, markUtf8("您好")) +}) + +test_that("multiple pipeline transformations result in an RDD with the correct values", { + df <- read.json(sqlContext, jsonPath) + first <- lapply(df, function(row) { + row$age <- row$age + 5 + row + }) + second <- lapply(first, function(row) { + row$testCol <- if (row$age == 35 && !is.na(row$age)) TRUE else FALSE + row + }) + expect_is(second, "RDD") + expect_equal(count(second), 3) + expect_equal(collect(second)[[2]]$age, 35) + expect_true(collect(second)[[2]]$testCol) + 
expect_false(collect(second)[[3]]$testCol) +}) + +test_that("cache(), persist(), and unpersist() on a DataFrame", { + df <- read.json(sqlContext, jsonPath) + expect_false(df@env$isCached) + cache(df) + expect_true(df@env$isCached) + + unpersist(df) + expect_false(df@env$isCached) + + persist(df, "MEMORY_AND_DISK") + expect_true(df@env$isCached) + + unpersist(df) + expect_false(df@env$isCached) + + # make sure the data is collectable + expect_true(is.data.frame(collect(df))) +}) + +test_that("schema(), dtypes(), columns(), names() return the correct values/format", { + df <- read.json(sqlContext, jsonPath) + testSchema <- schema(df) + expect_equal(length(testSchema$fields()), 2) + expect_equal(testSchema$fields()[[1]]$dataType.toString(), "LongType") + expect_equal(testSchema$fields()[[2]]$dataType.simpleString(), "string") + expect_equal(testSchema$fields()[[1]]$name(), "age") + + testTypes <- dtypes(df) + expect_equal(length(testTypes[[1]]), 2) + expect_equal(testTypes[[1]][1], "age") + + testCols <- columns(df) + expect_equal(length(testCols), 2) + expect_equal(testCols[2], "name") + + testNames <- names(df) + expect_equal(length(testNames), 2) + expect_equal(testNames[2], "name") +}) + +test_that("names() colnames() set the column names", { + df <- read.json(sqlContext, jsonPath) + names(df) <- c("col1", "col2") + expect_equal(colnames(df)[2], "col2") + + colnames(df) <- c("col3", "col4") + expect_equal(names(df)[1], "col3") + + # Test base::colnames base::names + m2 <- cbind(1, 1:4) + expect_equal(colnames(m2, do.NULL = FALSE), c("col1", "col2")) + colnames(m2) <- c("x","Y") + expect_equal(colnames(m2), c("x", "Y")) + + z <- list(a = 1, b = "c", c = 1:3) + expect_equal(names(z)[3], "c") + names(z)[3] <- "c2" + expect_equal(names(z)[3], "c2") +}) + +test_that("head() and first() return the correct data", { + df <- read.json(sqlContext, jsonPath) + testHead <- head(df) + expect_equal(nrow(testHead), 3) + expect_equal(ncol(testHead), 2) + + testHead2 <- head(df, 2) + expect_equal(nrow(testHead2), 2) + expect_equal(ncol(testHead2), 2) + + testFirst <- first(df) + expect_equal(nrow(testFirst), 1) + + # head() and first() return the correct data on + # a DataFrame with 0 row + df0 <- limit(df, 0) + + testHead <- head(df0) + expect_equal(nrow(testHead), 0) + expect_equal(ncol(testHead), 2) + + testFirst <- first(df0) + expect_equal(nrow(testFirst), 0) + expect_equal(ncol(testFirst), 2) +}) + +test_that("distinct() and unique on DataFrames", { + lines <- c("{\"name\":\"Michael\"}", + "{\"name\":\"Andy\", \"age\":30}", + "{\"name\":\"Justin\", \"age\":19}", + "{\"name\":\"Justin\", \"age\":19}") + jsonPathWithDup <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(lines, jsonPathWithDup) + + df <- read.json(sqlContext, jsonPathWithDup) + uniques <- distinct(df) + expect_is(uniques, "DataFrame") + expect_equal(count(uniques), 3) + + uniques2 <- unique(df) + expect_is(uniques2, "DataFrame") + expect_equal(count(uniques2), 3) +}) + +test_that("sample on a DataFrame", { + df <- read.json(sqlContext, jsonPath) + sampled <- sample(df, FALSE, 1.0) + expect_equal(nrow(collect(sampled)), count(df)) + expect_is(sampled, "DataFrame") + sampled2 <- sample(df, FALSE, 0.1, 0) # set seed for predictable result + expect_true(count(sampled2) < 3) + + count1 <- count(sample(df, FALSE, 0.1, 0)) + count2 <- count(sample(df, FALSE, 0.1, 0)) + expect_equal(count1, count2) + + # Also test sample_frac + sampled3 <- sample_frac(df, FALSE, 0.1, 0) # set seed for predictable result + 
expect_true(count(sampled3) < 3) + + # Test base::sample is working + #expect_equal(length(sample(1:12)), 12) +}) + +test_that("select operators", { + df <- select(read.json(sqlContext, jsonPath), "name", "age") + expect_is(df$name, "Column") + expect_is(df[[2]], "Column") + expect_is(df[["age"]], "Column") + + expect_is(df[,1], "DataFrame") + expect_equal(columns(df[,1]), c("name")) + expect_equal(columns(df[,"age"]), c("age")) + df2 <- df[,c("age", "name")] + expect_is(df2, "DataFrame") + expect_equal(columns(df2), c("age", "name")) + + df$age2 <- df$age + expect_equal(columns(df), c("name", "age", "age2")) + expect_equal(count(where(df, df$age2 == df$age)), 2) + df$age2 <- df$age * 2 + expect_equal(columns(df), c("name", "age", "age2")) + expect_equal(count(where(df, df$age2 == df$age * 2)), 2) + + df$age2 <- NULL + expect_equal(columns(df), c("name", "age")) + df$age3 <- NULL + expect_equal(columns(df), c("name", "age")) +}) + +test_that("select with column", { + df <- read.json(sqlContext, jsonPath) + df1 <- select(df, "name") + expect_equal(columns(df1), c("name")) + expect_equal(count(df1), 3) + + df2 <- select(df, df$age) + expect_equal(columns(df2), c("age")) + expect_equal(count(df2), 3) + + df3 <- select(df, lit("x")) + expect_equal(columns(df3), c("x")) + expect_equal(count(df3), 3) + expect_equal(collect(select(df3, "x"))[[1, 1]], "x") + + df4 <- select(df, c("name", "age")) + expect_equal(columns(df4), c("name", "age")) + expect_equal(count(df4), 3) + + expect_error(select(df, c("name", "age"), "name"), + "To select multiple columns, use a character vector or list for col") +}) + +test_that("subsetting", { + # read.json returns columns in random order + df <- select(read.json(sqlContext, jsonPath), "name", "age") + filtered <- df[df$age > 20,] + expect_equal(count(filtered), 1) + expect_equal(columns(filtered), c("name", "age")) + expect_equal(collect(filtered)$name, "Andy") + + df2 <- df[df$age == 19, 1] + expect_is(df2, "DataFrame") + expect_equal(count(df2), 1) + expect_equal(columns(df2), c("name")) + expect_equal(collect(df2)$name, "Justin") + + df3 <- df[df$age > 20, 2] + expect_equal(count(df3), 1) + expect_equal(columns(df3), c("age")) + + df4 <- df[df$age %in% c(19, 30), 1:2] + expect_equal(count(df4), 2) + expect_equal(columns(df4), c("name", "age")) + + df5 <- df[df$age %in% c(19), c(1,2)] + expect_equal(count(df5), 1) + expect_equal(columns(df5), c("name", "age")) + + df6 <- subset(df, df$age %in% c(30), c(1,2)) + expect_equal(count(df6), 1) + expect_equal(columns(df6), c("name", "age")) + + df7 <- subset(df, select = "name") + expect_equal(count(df7), 3) + expect_equal(columns(df7), c("name")) + + # Test base::subset is working + expect_equal(nrow(subset(airquality, Temp > 80, select = c(Ozone, Temp))), 68) +}) + +test_that("selectExpr() on a DataFrame", { + df <- read.json(sqlContext, jsonPath) + selected <- selectExpr(df, "age * 2") + expect_equal(names(selected), "(age * 2)") + expect_equal(collect(selected), collect(select(df, df$age * 2L))) + + selected2 <- selectExpr(df, "name as newName", "abs(age) as age") + expect_equal(names(selected2), c("newName", "age")) + expect_equal(count(selected2), 3) +}) + +test_that("expr() on a DataFrame", { + df <- read.json(sqlContext, jsonPath) + expect_equal(collect(select(df, expr("abs(-123)")))[1, 1], 123) +}) + +test_that("column calculation", { + df <- read.json(sqlContext, jsonPath) + d <- collect(select(df, alias(df$age + 1, "age2"))) + expect_equal(names(d), c("age2")) + df2 <- select(df, lower(df$name), 
abs(df$age)) + expect_is(df2, "DataFrame") + expect_equal(count(df2), 3) +}) + +test_that("test HiveContext", { + ssc <- callJMethod(sc, "sc") + hiveCtx <- tryCatch({ + newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) + }, + error = function(err) { + skip("Hive is not build with SparkSQL, skipped") + }) + df <- createExternalTable(hiveCtx, "json", jsonPath, "json") + expect_is(df, "DataFrame") + expect_equal(count(df), 3) + df2 <- sql(hiveCtx, "select * from json") + expect_is(df2, "DataFrame") + expect_equal(count(df2), 3) + + jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") + invisible(saveAsTable(df, "json2", "json", "append", path = jsonPath2)) + df3 <- sql(hiveCtx, "select * from json2") + expect_is(df3, "DataFrame") + expect_equal(count(df3), 3) + + unlink(jsonPath2) +}) + +test_that("column operators", { + c <- column("a") + c2 <- (- c + 1 - 2) * 3 / 4.0 + c3 <- (c + c2 - c2) * c2 %% c2 + c4 <- (c > c2) & (c2 <= c3) | (c == c2) & (c2 != c3) + c5 <- c2 ^ c3 ^ c4 +}) + +test_that("column functions", { + c <- column("a") + c1 <- abs(c) + acos(c) + approxCountDistinct(c) + ascii(c) + asin(c) + atan(c) + c2 <- avg(c) + base64(c) + bin(c) + bitwiseNOT(c) + cbrt(c) + ceil(c) + cos(c) + c3 <- cosh(c) + count(c) + crc32(c) + exp(c) + c4 <- explode(c) + expm1(c) + factorial(c) + first(c) + floor(c) + hex(c) + c5 <- hour(c) + initcap(c) + last(c) + last_day(c) + length(c) + c6 <- log(c) + (c) + log1p(c) + log2(c) + lower(c) + ltrim(c) + max(c) + md5(c) + c7 <- mean(c) + min(c) + month(c) + negate(c) + quarter(c) + c8 <- reverse(c) + rint(c) + round(c) + rtrim(c) + sha1(c) + c9 <- signum(c) + sin(c) + sinh(c) + size(c) + stddev(c) + soundex(c) + sqrt(c) + sum(c) + c10 <- sumDistinct(c) + tan(c) + tanh(c) + toDegrees(c) + toRadians(c) + c11 <- to_date(c) + trim(c) + unbase64(c) + unhex(c) + upper(c) + c12 <- variance(c) + c13 <- lead("col", 1) + lead(c, 1) + lag("col", 1) + lag(c, 1) + c14 <- cume_dist() + ntile(1) + corr(c, c1) + c15 <- dense_rank() + percent_rank() + rank() + row_number() + c16 <- is.nan(c) + isnan(c) + isNaN(c) + + # Test if base::is.nan() is exposed + expect_equal(is.nan(c("a", "b")), c(FALSE, FALSE)) + + # Test if base::rank() is exposed + expect_equal(class(rank())[[1]], "Column") + expect_equal(rank(1:3), as.numeric(c(1:3))) + + df <- read.json(sqlContext, jsonPath) + df2 <- select(df, between(df$age, c(20, 30)), between(df$age, c(10, 20))) + expect_equal(collect(df2)[[2, 1]], TRUE) + expect_equal(collect(df2)[[2, 2]], FALSE) + expect_equal(collect(df2)[[3, 1]], FALSE) + expect_equal(collect(df2)[[3, 2]], TRUE) + + df3 <- select(df, between(df$name, c("Apache", "Spark"))) + expect_equal(collect(df3)[[1, 1]], TRUE) + expect_equal(collect(df3)[[2, 1]], FALSE) + expect_equal(collect(df3)[[3, 1]], TRUE) + + df4 <- select(df, countDistinct(df$age, df$name)) + expect_equal(collect(df4)[[1, 1]], 2) + + expect_equal(collect(select(df, sum(df$age)))[1, 1], 49) + expect_true(abs(collect(select(df, stddev(df$age)))[1, 1] - 7.778175) < 1e-6) + expect_equal(collect(select(df, var_pop(df$age)))[1, 1], 30.25) + + df5 <- createDataFrame(sqlContext, list(list(a = "010101"))) + expect_equal(collect(select(df5, conv(df5$a, 2, 16)))[1, 1], "15") + + # Test array_contains() and sort_array() + df <- createDataFrame(sqlContext, list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L)))) + result <- collect(select(df, array_contains(df[[1]], 1L)))[[1]] + expect_equal(result, c(TRUE, FALSE)) + + result <- collect(select(df, sort_array(df[[1]], FALSE)))[[1]] + 
expect_equal(result, list(list(3L, 2L, 1L), list(6L, 5L, 4L))) + result <- collect(select(df, sort_array(df[[1]])))[[1]] + expect_equal(result, list(list(1L, 2L, 3L), list(4L, 5L, 6L))) + + # Test that stats::lag is working + expect_equal(length(lag(ldeaths, 12)), 72) + + # Test struct() + df <- createDataFrame(sqlContext, + list(list(1L, 2L, 3L), list(4L, 5L, 6L)), + schema = c("a", "b", "c")) + result <- collect(select(df, struct("a", "c"))) + expected <- data.frame(row.names = 1:2) + expected$"struct(a,c)" <- list(listToStruct(list(a = 1L, c = 3L)), + listToStruct(list(a = 4L, c = 6L))) + expect_equal(result, expected) + + result <- collect(select(df, struct(df$a, df$b))) + expected <- data.frame(row.names = 1:2) + expected$"struct(a,b)" <- list(listToStruct(list(a = 1L, b = 2L)), + listToStruct(list(a = 4L, b = 5L))) + expect_equal(result, expected) + + # Test encode(), decode() + bytes <- as.raw(c(0xe5, 0xa4, 0xa7, 0xe5, 0x8d, 0x83, 0xe4, 0xb8, 0x96, 0xe7, 0x95, 0x8c)) + df <- createDataFrame(sqlContext, + list(list(markUtf8("大千世界"), "utf-8", bytes)), + schema = c("a", "b", "c")) + result <- collect(select(df, encode(df$a, "utf-8"), decode(df$c, "utf-8"))) + expect_equal(result[[1]][[1]], bytes) + expect_equal(result[[2]], markUtf8("大千世界")) +}) + +test_that("column binary mathfunctions", { + lines <- c("{\"a\":1, \"b\":5}", + "{\"a\":2, \"b\":6}", + "{\"a\":3, \"b\":7}", + "{\"a\":4, \"b\":8}") + jsonPathWithDup <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(lines, jsonPathWithDup) + df <- read.json(sqlContext, jsonPathWithDup) + expect_equal(collect(select(df, atan2(df$a, df$b)))[1, "ATAN2(a, b)"], atan2(1, 5)) + expect_equal(collect(select(df, atan2(df$a, df$b)))[2, "ATAN2(a, b)"], atan2(2, 6)) + expect_equal(collect(select(df, atan2(df$a, df$b)))[3, "ATAN2(a, b)"], atan2(3, 7)) + expect_equal(collect(select(df, atan2(df$a, df$b)))[4, "ATAN2(a, b)"], atan2(4, 8)) + ## nolint start + expect_equal(collect(select(df, hypot(df$a, df$b)))[1, "HYPOT(a, b)"], sqrt(1^2 + 5^2)) + expect_equal(collect(select(df, hypot(df$a, df$b)))[2, "HYPOT(a, b)"], sqrt(2^2 + 6^2)) + expect_equal(collect(select(df, hypot(df$a, df$b)))[3, "HYPOT(a, b)"], sqrt(3^2 + 7^2)) + expect_equal(collect(select(df, hypot(df$a, df$b)))[4, "HYPOT(a, b)"], sqrt(4^2 + 8^2)) + ## nolint end + expect_equal(collect(select(df, shiftLeft(df$b, 1)))[4, 1], 16) + expect_equal(collect(select(df, shiftRight(df$b, 1)))[4, 1], 4) + expect_equal(collect(select(df, shiftRightUnsigned(df$b, 1)))[4, 1], 4) + expect_equal(class(collect(select(df, rand()))[2, 1]), "numeric") + expect_equal(collect(select(df, rand(1)))[1, 1], 0.134, tolerance = 0.01) + expect_equal(class(collect(select(df, randn()))[2, 1]), "numeric") + expect_equal(collect(select(df, randn(1)))[1, 1], -1.03, tolerance = 0.01) +}) + +test_that("string operators", { + df <- read.json(sqlContext, jsonPath) + expect_equal(count(where(df, like(df$name, "A%"))), 1) + expect_equal(count(where(df, startsWith(df$name, "A"))), 1) + expect_equal(first(select(df, substr(df$name, 1, 2)))[[1]], "Mi") + expect_equal(collect(select(df, cast(df$age, "string")))[[2, 1]], "30") + expect_equal(collect(select(df, concat(df$name, lit(":"), df$age)))[[2, 1]], "Andy:30") + expect_equal(collect(select(df, concat_ws(":", df$name)))[[2, 1]], "Andy") + expect_equal(collect(select(df, concat_ws(":", df$name, df$age)))[[2, 1]], "Andy:30") + expect_equal(collect(select(df, instr(df$name, "i")))[, 1], c(2, 0, 5)) + expect_equal(collect(select(df, format_number(df$age, 2)))[2, 1], 
"30.00") + expect_equal(collect(select(df, sha1(df$name)))[2, 1], + "ab5a000e88b5d9d0fa2575f5c6263eb93452405d") + expect_equal(collect(select(df, sha2(df$name, 256)))[2, 1], + "80f2aed3c618c423ddf05a2891229fba44942d907173152442cf6591441ed6dc") + expect_equal(collect(select(df, format_string("Name:%s", df$name)))[2, 1], "Name:Andy") + expect_equal(collect(select(df, format_string("%s, %d", df$name, df$age)))[2, 1], "Andy, 30") + expect_equal(collect(select(df, regexp_extract(df$name, "(n.y)", 1)))[2, 1], "ndy") + expect_equal(collect(select(df, regexp_replace(df$name, "(n.y)", "ydn")))[2, 1], "Aydn") + + l2 <- list(list(a = "aaads")) + df2 <- createDataFrame(sqlContext, l2) + expect_equal(collect(select(df2, locate("aa", df2$a)))[1, 1], 1) + expect_equal(collect(select(df2, locate("aa", df2$a, 1)))[1, 1], 2) + expect_equal(collect(select(df2, lpad(df2$a, 8, "#")))[1, 1], "###aaads") + expect_equal(collect(select(df2, rpad(df2$a, 8, "#")))[1, 1], "aaads###") + + l3 <- list(list(a = "a.b.c.d")) + df3 <- createDataFrame(sqlContext, l3) + expect_equal(collect(select(df3, substring_index(df3$a, ".", 2)))[1, 1], "a.b") + expect_equal(collect(select(df3, substring_index(df3$a, ".", -3)))[1, 1], "b.c.d") + expect_equal(collect(select(df3, translate(df3$a, "bc", "12")))[1, 1], "a.1.2.d") +}) + +test_that("date functions on a DataFrame", { + .originalTimeZone <- Sys.getenv("TZ") + Sys.setenv(TZ = "UTC") + l <- list(list(a = 1L, b = as.Date("2012-12-13")), + list(a = 2L, b = as.Date("2013-12-14")), + list(a = 3L, b = as.Date("2014-12-15"))) + df <- createDataFrame(sqlContext, l) + expect_equal(collect(select(df, dayofmonth(df$b)))[, 1], c(13, 14, 15)) + expect_equal(collect(select(df, dayofyear(df$b)))[, 1], c(348, 348, 349)) + expect_equal(collect(select(df, weekofyear(df$b)))[, 1], c(50, 50, 51)) + expect_equal(collect(select(df, year(df$b)))[, 1], c(2012, 2013, 2014)) + expect_equal(collect(select(df, month(df$b)))[, 1], c(12, 12, 12)) + expect_equal(collect(select(df, last_day(df$b)))[, 1], + c(as.Date("2012-12-31"), as.Date("2013-12-31"), as.Date("2014-12-31"))) + expect_equal(collect(select(df, next_day(df$b, "MONDAY")))[, 1], + c(as.Date("2012-12-17"), as.Date("2013-12-16"), as.Date("2014-12-22"))) + expect_equal(collect(select(df, date_format(df$b, "y")))[, 1], c("2012", "2013", "2014")) + expect_equal(collect(select(df, add_months(df$b, 3)))[, 1], + c(as.Date("2013-03-13"), as.Date("2014-03-14"), as.Date("2015-03-15"))) + expect_equal(collect(select(df, date_add(df$b, 1)))[, 1], + c(as.Date("2012-12-14"), as.Date("2013-12-15"), as.Date("2014-12-16"))) + expect_equal(collect(select(df, date_sub(df$b, 1)))[, 1], + c(as.Date("2012-12-12"), as.Date("2013-12-13"), as.Date("2014-12-14"))) + + l2 <- list(list(a = 1L, b = as.POSIXlt("2012-12-13 12:34:00", tz = "UTC")), + list(a = 2L, b = as.POSIXlt("2014-12-15 01:24:34", tz = "UTC"))) + df2 <- createDataFrame(sqlContext, l2) + expect_equal(collect(select(df2, minute(df2$b)))[, 1], c(34, 24)) + expect_equal(collect(select(df2, second(df2$b)))[, 1], c(0, 34)) + expect_equal(collect(select(df2, from_utc_timestamp(df2$b, "JST")))[, 1], + c(as.POSIXlt("2012-12-13 21:34:00 UTC"), as.POSIXlt("2014-12-15 10:24:34 UTC"))) + expect_equal(collect(select(df2, to_utc_timestamp(df2$b, "JST")))[, 1], + c(as.POSIXlt("2012-12-13 03:34:00 UTC"), as.POSIXlt("2014-12-14 16:24:34 UTC"))) + expect_more_than(collect(select(df2, unix_timestamp()))[1, 1], 0) + expect_more_than(collect(select(df2, unix_timestamp(df2$b)))[1, 1], 0) + expect_more_than(collect(select(df2, 
unix_timestamp(lit("2015-01-01"), "yyyy-MM-dd")))[1, 1], 0) + + l3 <- list(list(a = 1000), list(a = -1000)) + df3 <- createDataFrame(sqlContext, l3) + result31 <- collect(select(df3, from_unixtime(df3$a))) + expect_equal(grep("\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}", result31[, 1], perl = TRUE), + c(1, 2)) + result32 <- collect(select(df3, from_unixtime(df3$a, "yyyy"))) + expect_equal(grep("\\d{4}", result32[, 1]), c(1, 2)) + Sys.setenv(TZ = .originalTimeZone) +}) + +test_that("greatest() and least() on a DataFrame", { + l <- list(list(a = 1, b = 2), list(a = 3, b = 4)) + df <- createDataFrame(sqlContext, l) + expect_equal(collect(select(df, greatest(df$a, df$b)))[, 1], c(2, 4)) + expect_equal(collect(select(df, least(df$a, df$b)))[, 1], c(1, 3)) +}) + +test_that("when(), otherwise() and ifelse() on a DataFrame", { + l <- list(list(a = 1, b = 2), list(a = 3, b = 4)) + df <- createDataFrame(sqlContext, l) + expect_equal(collect(select(df, when(df$a > 1 & df$b > 2, 1)))[, 1], c(NA, 1)) + expect_equal(collect(select(df, otherwise(when(df$a > 1, 1), 0)))[, 1], c(0, 1)) + expect_equal(collect(select(df, ifelse(df$a > 1 & df$b > 2, 0, 1)))[, 1], c(1, 0)) +}) + +test_that("group by, agg functions", { + df <- read.json(sqlContext, jsonPath) + df1 <- agg(df, name = "max", age = "sum") + expect_equal(1, count(df1)) + df1 <- agg(df, age2 = max(df$age)) + expect_equal(1, count(df1)) + expect_equal(columns(df1), c("age2")) + + gd <- groupBy(df, "name") + expect_is(gd, "GroupedData") + df2 <- count(gd) + expect_is(df2, "DataFrame") + expect_equal(3, count(df2)) + + # Also test group_by, summarize, mean + gd1 <- group_by(df, "name") + expect_is(gd1, "GroupedData") + df_summarized <- summarize(gd, mean_age = mean(df$age)) + expect_is(df_summarized, "DataFrame") + expect_equal(3, count(df_summarized)) + + df3 <- agg(gd, age = "stddev") + expect_is(df3, "DataFrame") + df3_local <- collect(df3) + expect_true(is.nan(df3_local[df3_local$name == "Andy",][1, 2])) + + df4 <- agg(gd, sumAge = sum(df$age)) + expect_is(df4, "DataFrame") + expect_equal(3, count(df4)) + expect_equal(columns(df4), c("name", "sumAge")) + + df5 <- sum(gd, "age") + expect_is(df5, "DataFrame") + expect_equal(3, count(df5)) + + expect_equal(3, count(mean(gd))) + expect_equal(3, count(max(gd))) + expect_equal(30, collect(max(gd))[1, 2]) + expect_equal(1, collect(count(gd))[1, 2]) + + mockLines2 <- c("{\"name\":\"ID1\", \"value\": \"10\"}", + "{\"name\":\"ID1\", \"value\": \"10\"}", + "{\"name\":\"ID1\", \"value\": \"22\"}", + "{\"name\":\"ID2\", \"value\": \"-3\"}") + jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(mockLines2, jsonPath2) + gd2 <- groupBy(read.json(sqlContext, jsonPath2), "name") + df6 <- agg(gd2, value = "sum") + df6_local <- collect(df6) + expect_equal(42, df6_local[df6_local$name == "ID1",][1, 2]) + expect_equal(-3, df6_local[df6_local$name == "ID2",][1, 2]) + + df7 <- agg(gd2, value = "stddev") + df7_local <- collect(df7) + expect_true(abs(df7_local[df7_local$name == "ID1",][1, 2] - 6.928203) < 1e-6) + expect_true(is.nan(df7_local[df7_local$name == "ID2",][1, 2])) + + mockLines3 <- c("{\"name\":\"Andy\", \"age\":30}", + "{\"name\":\"Andy\", \"age\":30}", + "{\"name\":\"Justin\", \"age\":19}", + "{\"name\":\"Justin\", \"age\":1}") + jsonPath3 <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(mockLines3, jsonPath3) + df8 <- read.json(sqlContext, jsonPath3) + gd3 <- groupBy(df8, "name") + gd3_local <- collect(sum(gd3)) + expect_equal(60, gd3_local[gd3_local$name == "Andy",][1, 2]) + 
expect_equal(20, gd3_local[gd3_local$name == "Justin",][1, 2]) + + expect_true(abs(collect(agg(df, sd(df$age)))[1, 1] - 7.778175) < 1e-6) + gd3_local <- collect(agg(gd3, var(df8$age))) + expect_equal(162, gd3_local[gd3_local$name == "Justin",][1, 2]) + + # Test stats::sd, stats::var are working + expect_true(abs(sd(1:2) - 0.7071068) < 1e-6) + expect_true(abs(var(1:5, 1:5) - 2.5) < 1e-6) + + unlink(jsonPath2) + unlink(jsonPath3) +}) + +test_that("arrange() and orderBy() on a DataFrame", { + df <- read.json(sqlContext, jsonPath) + sorted <- arrange(df, df$age) + expect_equal(collect(sorted)[1,2], "Michael") + + sorted2 <- arrange(df, "name", decreasing = FALSE) + expect_equal(collect(sorted2)[2,"age"], 19) + + sorted3 <- orderBy(df, asc(df$age)) + expect_true(is.na(first(sorted3)$age)) + expect_equal(collect(sorted3)[2, "age"], 19) + + sorted4 <- orderBy(df, desc(df$name)) + expect_equal(first(sorted4)$name, "Michael") + expect_equal(collect(sorted4)[3,"name"], "Andy") + + sorted5 <- arrange(df, "age", "name", decreasing = TRUE) + expect_equal(collect(sorted5)[1,2], "Andy") + + sorted6 <- arrange(df, "age","name", decreasing = c(T, F)) + expect_equal(collect(sorted6)[1,2], "Andy") + + sorted7 <- arrange(df, "name", decreasing = FALSE) + expect_equal(collect(sorted7)[2,"age"], 19) +}) + +test_that("filter() on a DataFrame", { + df <- read.json(sqlContext, jsonPath) + filtered <- filter(df, "age > 20") + expect_equal(count(filtered), 1) + expect_equal(collect(filtered)$name, "Andy") + filtered2 <- where(df, df$name != "Michael") + expect_equal(count(filtered2), 2) + expect_equal(collect(filtered2)$age[2], 19) + + # test suites for %in% + filtered3 <- filter(df, "age in (19)") + expect_equal(count(filtered3), 1) + filtered4 <- filter(df, "age in (19, 30)") + expect_equal(count(filtered4), 2) + filtered5 <- where(df, df$age %in% c(19)) + expect_equal(count(filtered5), 1) + filtered6 <- where(df, df$age %in% c(19, 30)) + expect_equal(count(filtered6), 2) + + # Test stats::filter is working + #expect_true(is.ts(filter(1:100, rep(1, 3)))) +}) + +test_that("join() and merge() on a DataFrame", { + df <- read.json(sqlContext, jsonPath) + + mockLines2 <- c("{\"name\":\"Michael\", \"test\": \"yes\"}", + "{\"name\":\"Andy\", \"test\": \"no\"}", + "{\"name\":\"Justin\", \"test\": \"yes\"}", + "{\"name\":\"Bob\", \"test\": \"yes\"}") + jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(mockLines2, jsonPath2) + df2 <- read.json(sqlContext, jsonPath2) + + joined <- join(df, df2) + expect_equal(names(joined), c("age", "name", "name", "test")) + expect_equal(count(joined), 12) + expect_equal(names(collect(joined)), c("age", "name", "name", "test")) + + joined2 <- join(df, df2, df$name == df2$name) + expect_equal(names(joined2), c("age", "name", "name", "test")) + expect_equal(count(joined2), 3) + + joined3 <- join(df, df2, df$name == df2$name, "rightouter") + expect_equal(names(joined3), c("age", "name", "name", "test")) + expect_equal(count(joined3), 4) + expect_true(is.na(collect(orderBy(joined3, joined3$age))$age[2])) + + joined4 <- select(join(df, df2, df$name == df2$name, "outer"), + alias(df$age + 5, "newAge"), df$name, df2$test) + expect_equal(names(joined4), c("newAge", "name", "test")) + expect_equal(count(joined4), 4) + expect_equal(collect(orderBy(joined4, joined4$name))$newAge[3], 24) + + joined5 <- join(df, df2, df$name == df2$name, "leftouter") + expect_equal(names(joined5), c("age", "name", "name", "test")) + expect_equal(count(joined5), 3) + 
expect_true(is.na(collect(orderBy(joined5, joined5$age))$age[1])) + + joined6 <- join(df, df2, df$name == df2$name, "inner") + expect_equal(names(joined6), c("age", "name", "name", "test")) + expect_equal(count(joined6), 3) + + joined7 <- join(df, df2, df$name == df2$name, "leftsemi") + expect_equal(names(joined7), c("age", "name")) + expect_equal(count(joined7), 3) + + joined8 <- join(df, df2, df$name == df2$name, "left_outer") + expect_equal(names(joined8), c("age", "name", "name", "test")) + expect_equal(count(joined8), 3) + expect_true(is.na(collect(orderBy(joined8, joined8$age))$age[1])) + + joined9 <- join(df, df2, df$name == df2$name, "right_outer") + expect_equal(names(joined9), c("age", "name", "name", "test")) + expect_equal(count(joined9), 4) + expect_true(is.na(collect(orderBy(joined9, joined9$age))$age[2])) + + merged <- merge(df, df2, by.x = "name", by.y = "name", all.x = TRUE, all.y = TRUE) + expect_equal(count(merged), 4) + expect_equal(names(merged), c("age", "name_x", "name_y", "test")) + expect_equal(collect(orderBy(merged, merged$name_x))$age[3], 19) + + merged <- merge(df, df2, suffixes = c("-X","-Y")) + expect_equal(count(merged), 3) + expect_equal(names(merged), c("age", "name-X", "name-Y", "test")) + expect_equal(collect(orderBy(merged, merged$"name-X"))$age[1], 30) + + merged <- merge(df, df2, by = "name", suffixes = c("-X","-Y"), sort = FALSE) + expect_equal(count(merged), 3) + expect_equal(names(merged), c("age", "name-X", "name-Y", "test")) + expect_equal(collect(orderBy(merged, merged$"name-Y"))$"name-X"[3], "Michael") + + merged <- merge(df, df2, by = "name", all = T, sort = T) + expect_equal(count(merged), 4) + expect_equal(names(merged), c("age", "name_x", "name_y", "test")) + expect_equal(collect(orderBy(merged, merged$"name_y"))$"name_x"[1], "Andy") + + merged <- merge(df, df2, by = NULL) + expect_equal(count(merged), 12) + expect_equal(names(merged), c("age", "name", "name", "test")) + + mockLines3 <- c("{\"name\":\"Michael\", \"name_y\":\"Michael\", \"test\": \"yes\"}", + "{\"name\":\"Andy\", \"name_y\":\"Andy\", \"test\": \"no\"}", + "{\"name\":\"Justin\", \"name_y\":\"Justin\", \"test\": \"yes\"}", + "{\"name\":\"Bob\", \"name_y\":\"Bob\", \"test\": \"yes\"}") + jsonPath3 <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(mockLines3, jsonPath3) + df3 <- read.json(sqlContext, jsonPath3) + expect_error(merge(df, df3), + paste("The following column name: name_y occurs more than once in the 'DataFrame'.", + "Please use different suffixes for the intersected columns.", sep = "")) + + unlink(jsonPath2) + unlink(jsonPath3) +}) + +test_that("toJSON() returns an RDD of the correct values", { + df <- read.json(sqlContext, jsonPath) + testRDD <- toJSON(df) + expect_is(testRDD, "RDD") + expect_equal(getSerializedMode(testRDD), "string") + expect_equal(collect(testRDD)[[1]], mockLines[1]) +}) + +test_that("showDF()", { + df <- read.json(sqlContext, jsonPath) + s <- capture.output(showDF(df)) + expected <- paste("+----+-------+\n", + "| age| name|\n", + "+----+-------+\n", + "|null|Michael|\n", + "| 30| Andy|\n", + "| 19| Justin|\n", + "+----+-------+\n", sep="") + expect_output(s , expected) +}) + +test_that("isLocal()", { + df <- read.json(sqlContext, jsonPath) + expect_false(isLocal(df)) +}) + +test_that("unionAll(), rbind(), except(), and intersect() on a DataFrame", { + df <- read.json(sqlContext, jsonPath) + + lines <- c("{\"name\":\"Bob\", \"age\":24}", + "{\"name\":\"Andy\", \"age\":30}", + "{\"name\":\"James\", \"age\":35}") + jsonPath2 <- 
tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(lines, jsonPath2) + df2 <- read.df(sqlContext, jsonPath2, "json") + + unioned <- arrange(unionAll(df, df2), df$age) + expect_is(unioned, "DataFrame") + expect_equal(count(unioned), 6) + expect_equal(first(unioned)$name, "Michael") + + unioned2 <- arrange(rbind(unioned, df, df2), df$age) + expect_is(unioned2, "DataFrame") + expect_equal(count(unioned2), 12) + expect_equal(first(unioned2)$name, "Michael") + + excepted <- arrange(except(df, df2), desc(df$age)) + expect_is(unioned, "DataFrame") + expect_equal(count(excepted), 2) + expect_equal(first(excepted)$name, "Justin") + + intersected <- arrange(intersect(df, df2), df$age) + expect_is(unioned, "DataFrame") + expect_equal(count(intersected), 1) + expect_equal(first(intersected)$name, "Andy") + + # Test base::rbind is working + expect_equal(length(rbind(1:4, c = 2, a = 10, 10, deparse.level = 0)), 16) + + # Test base::intersect is working + expect_equal(length(intersect(1:20, 3:23)), 18) + + unlink(jsonPath2) +}) + +test_that("withColumn() and withColumnRenamed()", { + df <- read.json(sqlContext, jsonPath) + newDF <- withColumn(df, "newAge", df$age + 2) + expect_equal(length(columns(newDF)), 3) + expect_equal(columns(newDF)[3], "newAge") + expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 32) + + newDF2 <- withColumnRenamed(df, "age", "newerAge") + expect_equal(length(columns(newDF2)), 2) + expect_equal(columns(newDF2)[1], "newerAge") +}) + +test_that("mutate(), transform(), rename() and names()", { + df <- read.json(sqlContext, jsonPath) + newDF <- mutate(df, newAge = df$age + 2) + expect_equal(length(columns(newDF)), 3) + expect_equal(columns(newDF)[3], "newAge") + expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 32) + + newDF2 <- rename(df, newerAge = df$age) + expect_equal(length(columns(newDF2)), 2) + expect_equal(columns(newDF2)[1], "newerAge") + + names(newDF2) <- c("newerName", "evenNewerAge") + expect_equal(length(names(newDF2)), 2) + expect_equal(names(newDF2)[1], "newerName") + + transformedDF <- transform(df, newAge = -df$age, newAge2 = df$age / 2) + expect_equal(length(columns(transformedDF)), 4) + expect_equal(columns(transformedDF)[3], "newAge") + expect_equal(columns(transformedDF)[4], "newAge2") + expect_equal(first(filter(transformedDF, transformedDF$name == "Andy"))$newAge, -30) + + # test if base::transform on local data frames works + # ensure the proper signature is used - otherwise this will fail to run + attach(airquality) + result <- transform(Ozone, logOzone = log(Ozone)) + expect_equal(nrow(result), 153) + expect_equal(ncol(result), 2) + detach(airquality) +}) + +test_that("read/write Parquet files", { + df <- read.df(sqlContext, jsonPath, "json") + # Test write.df and read.df + write.df(df, parquetPath, "parquet", mode="overwrite") + df2 <- read.df(sqlContext, parquetPath, "parquet") + expect_is(df2, "DataFrame") + expect_equal(count(df2), 3) + + # Test write.parquet/saveAsParquetFile and read.parquet/parquetFile + parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") + write.parquet(df, parquetPath2) + parquetPath3 <- tempfile(pattern = "parquetPath3", fileext = ".parquet") + suppressWarnings(saveAsParquetFile(df, parquetPath3)) + parquetDF <- read.parquet(sqlContext, c(parquetPath2, parquetPath3)) + expect_is(parquetDF, "DataFrame") + expect_equal(count(parquetDF), count(df) * 2) + parquetDF2 <- suppressWarnings(parquetFile(sqlContext, parquetPath2, parquetPath3)) + expect_is(parquetDF2, "DataFrame") + 
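# Aside (illustration only, not part of the patch): parquetPath2 and parquetPath3 each
# hold a copy of the same 3-row DataFrame, and reading several paths at once unions
# them, which is why the next assertion expects count(df) * 2. saveAsParquetFile()
# and parquetFile() appear to be the deprecated spellings here, hence the
# suppressWarnings() wrappers. Self-contained base-R sketch of the doubling, using
# hypothetical local data:
local_df <- data.frame(age = c(NA, 30, 19), name = c("Michael", "Andy", "Justin"))
nrow(rbind(local_df, local_df)) == 2 * nrow(local_df)   # TRUE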
expect_equal(count(parquetDF2), count(df) * 2) + + # Test if varargs works with variables + saveMode <- "overwrite" + mergeSchema <- "true" + parquetPath4 <- tempfile(pattern = "parquetPath3", fileext = ".parquet") + write.df(df, parquetPath3, "parquet", mode = saveMode, mergeSchema = mergeSchema) + + unlink(parquetPath2) + unlink(parquetPath3) + unlink(parquetPath4) +}) + +test_that("describe() and summarize() on a DataFrame", { + df <- read.json(sqlContext, jsonPath) + stats <- describe(df, "age") + expect_equal(collect(stats)[1, "summary"], "count") + expect_equal(collect(stats)[2, "age"], "24.5") + expect_equal(collect(stats)[3, "age"], "7.7781745930520225") + stats <- describe(df) + expect_equal(collect(stats)[4, "name"], "Andy") + expect_equal(collect(stats)[5, "age"], "30") + + stats2 <- summary(df) + expect_equal(collect(stats2)[4, "name"], "Andy") + expect_equal(collect(stats2)[5, "age"], "30") + + # Test base::summary is working + expect_equal(length(summary(attenu, digits = 4)), 35) +}) + +test_that("dropna() and na.omit() on a DataFrame", { + df <- read.json(sqlContext, jsonPathNa) + rows <- collect(df) + + # drop with columns + + expected <- rows[!is.na(rows$name),] + actual <- collect(dropna(df, cols = "name")) + expect_identical(expected, actual) + actual <- collect(na.omit(df, cols = "name")) + expect_identical(expected, actual) + + expected <- rows[!is.na(rows$age),] + actual <- collect(dropna(df, cols = "age")) + row.names(expected) <- row.names(actual) + # identical on two dataframes does not work here. Don't know why. + # use identical on all columns as a workaround. + expect_identical(expected$age, actual$age) + expect_identical(expected$height, actual$height) + expect_identical(expected$name, actual$name) + actual <- collect(na.omit(df, cols = "age")) + + expected <- rows[!is.na(rows$age) & !is.na(rows$height),] + actual <- collect(dropna(df, cols = c("age", "height"))) + expect_identical(expected, actual) + actual <- collect(na.omit(df, cols = c("age", "height"))) + expect_identical(expected, actual) + + expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] + actual <- collect(dropna(df)) + expect_identical(expected, actual) + actual <- collect(na.omit(df)) + expect_identical(expected, actual) + + # drop with how + + expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] + actual <- collect(dropna(df)) + expect_identical(expected, actual) + actual <- collect(na.omit(df)) + expect_identical(expected, actual) + + expected <- rows[!is.na(rows$age) | !is.na(rows$height) | !is.na(rows$name),] + actual <- collect(dropna(df, "all")) + expect_identical(expected, actual) + actual <- collect(na.omit(df, "all")) + expect_identical(expected, actual) + + expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] + actual <- collect(dropna(df, "any")) + expect_identical(expected, actual) + actual <- collect(na.omit(df, "any")) + expect_identical(expected, actual) + + expected <- rows[!is.na(rows$age) & !is.na(rows$height),] + actual <- collect(dropna(df, "any", cols = c("age", "height"))) + expect_identical(expected, actual) + actual <- collect(na.omit(df, "any", cols = c("age", "height"))) + expect_identical(expected, actual) + + expected <- rows[!is.na(rows$age) | !is.na(rows$height),] + actual <- collect(dropna(df, "all", cols = c("age", "height"))) + expect_identical(expected, actual) + actual <- collect(na.omit(df, "all", cols = c("age", "height"))) + expect_identical(expected, actual) + + # drop with threshold 
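# Aside (illustration only, not part of the patch): dropna(df, minNonNulls = n, cols = ...)
# keeps a row only when at least n of the listed columns are non-null, which is exactly
# what the as.integer(!is.na(...)) sums below recompute on the collected rows. A
# self-contained base-R sketch of the same predicate on hypothetical local data:
local_rows <- data.frame(age = c(NA, 30, 19), height = c(1.8, NA, 1.7))
keep <- as.integer(!is.na(local_rows$age)) + as.integer(!is.na(local_rows$height)) >= 2
local_rows[keep, ]   # only the row with both fields present survives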
+ + expected <- rows[as.integer(!is.na(rows$age)) + as.integer(!is.na(rows$height)) >= 2,] + actual <- collect(dropna(df, minNonNulls = 2, cols = c("age", "height"))) + expect_identical(expected, actual) + actual <- collect(na.omit(df, minNonNulls = 2, cols = c("age", "height"))) + expect_identical(expected, actual) + + expected <- rows[as.integer(!is.na(rows$age)) + + as.integer(!is.na(rows$height)) + + as.integer(!is.na(rows$name)) >= 3,] + actual <- collect(dropna(df, minNonNulls = 3, cols = c("name", "age", "height"))) + expect_identical(expected, actual) + actual <- collect(na.omit(df, minNonNulls = 3, cols = c("name", "age", "height"))) + expect_identical(expected, actual) + + # Test stats::na.omit is working + expect_equal(nrow(na.omit(data.frame(x = c(0, 10, NA)))), 2) +}) + +test_that("fillna() on a DataFrame", { + df <- read.json(sqlContext, jsonPathNa) + rows <- collect(df) + + # fill with value + + expected <- rows + expected$age[is.na(expected$age)] <- 50 + expected$height[is.na(expected$height)] <- 50.6 + actual <- collect(fillna(df, 50.6)) + expect_identical(expected, actual) + + expected <- rows + expected$name[is.na(expected$name)] <- "unknown" + actual <- collect(fillna(df, "unknown")) + expect_identical(expected, actual) + + expected <- rows + expected$age[is.na(expected$age)] <- 50 + actual <- collect(fillna(df, 50.6, "age")) + expect_identical(expected, actual) + + expected <- rows + expected$name[is.na(expected$name)] <- "unknown" + actual <- collect(fillna(df, "unknown", c("age", "name"))) + expect_identical(expected, actual) + + # fill with named list + + expected <- rows + expected$age[is.na(expected$age)] <- 50 + expected$height[is.na(expected$height)] <- 50.6 + expected$name[is.na(expected$name)] <- "unknown" + actual <- collect(fillna(df, list("age" = 50, "height" = 50.6, "name" = "unknown"))) + expect_identical(expected, actual) +}) + +test_that("crosstab() on a DataFrame", { + rdd <- lapply(parallelize(sc, 0:3), function(x) { + list(paste0("a", x %% 3), paste0("b", x %% 2)) + }) + df <- toDF(rdd, list("a", "b")) + ct <- crosstab(df, "a", "b") + ordered <- ct[order(ct$a_b),] + row.names(ordered) <- NULL + expected <- data.frame("a_b" = c("a0", "a1", "a2"), "b0" = c(1, 0, 1), "b1" = c(1, 1, 0), + stringsAsFactors = FALSE, row.names = NULL) + expect_identical(expected, ordered) +}) + +test_that("cov() and corr() on a DataFrame", { + l <- lapply(c(0:9), function(x) { list(x, x * 2.0) }) + df <- createDataFrame(sqlContext, l, c("singles", "doubles")) + result <- cov(df, "singles", "doubles") + expect_true(abs(result - 55.0 / 3) < 1e-12) + + result <- corr(df, "singles", "doubles") + expect_true(abs(result - 1.0) < 1e-12) + result <- corr(df, "singles", "doubles", "pearson") + expect_true(abs(result - 1.0) < 1e-12) + + # Test stats::cov is working + #expect_true(abs(max(cov(swiss)) - 1739.295) < 1e-3) +}) + +test_that("freqItems() on a DataFrame", { + input <- 1:1000 + rdf <- data.frame(numbers = input, letters = as.character(input), + negDoubles = input * -1.0, stringsAsFactors = F) + rdf[ input %% 3 == 0, ] <- c(1, "1", -1) + df <- createDataFrame(sqlContext, rdf) + multiColResults <- freqItems(df, c("numbers", "letters"), support=0.1) + expect_true(1 %in% multiColResults$numbers[[1]]) + expect_true("1" %in% multiColResults$letters[[1]]) + singleColResult <- freqItems(df, "negDoubles", support=0.1) + expect_true(-1 %in% head(singleColResult$negDoubles)[[1]]) + + l <- lapply(c(0:99), function(i) { + if (i %% 2 == 0) { list(1L, -1.0) } + else { list(i, i * -1.0) }}) 
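# Aside (illustration only, not part of the patch): in this block half of the 100 rows
# carry the constant pair (1L, -1.0), so with support = 0.4 those values qualify as
# frequent. freqItems() appears to implement an approximate frequent-items algorithm
# and can return false positives, which is presumably why the expected lists below also
# contain 99 / -99. Exact local analogue of the 40% threshold on hypothetical data:
vals <- ifelse(seq_len(100) %% 2 == 0, 1L, seq_len(100))
names(which(table(vals) / length(vals) > 0.4))   # "1" -- the only value above the support cutoff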
+ df <- createDataFrame(sqlContext, l, c("a", "b")) + result <- freqItems(df, c("a", "b"), 0.4) + expect_identical(result[[1]], list(list(1L, 99L))) + expect_identical(result[[2]], list(list(-1, -99))) +}) + +test_that("sampleBy() on a DataFrame", { + l <- lapply(c(0:99), function(i) { as.character(i %% 3) }) + df <- createDataFrame(sqlContext, l, "key") + fractions <- list("0" = 0.1, "1" = 0.2) + sample <- sampleBy(df, "key", fractions, 0) + result <- collect(orderBy(count(groupBy(sample, "key")), "key")) + expect_identical(as.list(result[1, ]), list(key = "0", count = 3)) + expect_identical(as.list(result[2, ]), list(key = "1", count = 7)) +}) + +test_that("SQL error message is returned from JVM", { + retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e) + expect_equal(grepl("Table not found: blah", retError), TRUE) +}) + +irisDF <- suppressWarnings(createDataFrame(sqlContext, iris)) + +test_that("Method as.data.frame as a synonym for collect()", { + expect_equal(as.data.frame(irisDF), collect(irisDF)) + irisDF2 <- irisDF[irisDF$Species == "setosa", ] + expect_equal(as.data.frame(irisDF2), collect(irisDF2)) +}) + +test_that("attach() on a DataFrame", { + df <- read.json(sqlContext, jsonPath) + expect_error(age) + attach(df) + expect_is(age, "DataFrame") + expected_age <- data.frame(age = c(NA, 30, 19)) + expect_equal(head(age), expected_age) + stat <- summary(age) + expect_equal(collect(stat)[5, "age"], "30") + age <- age$age + 1 + expect_is(age, "Column") + rm(age) + stat2 <- summary(age) + expect_equal(collect(stat2)[5, "age"], "30") + detach("df") + stat3 <- summary(df[, "age"]) + expect_equal(collect(stat3)[5, "age"], "30") + expect_error(age) +}) + +test_that("with() on a DataFrame", { + df <- suppressWarnings(createDataFrame(sqlContext, iris)) + expect_error(Sepal_Length) + sum1 <- with(df, list(summary(Sepal_Length), summary(Sepal_Width))) + expect_equal(collect(sum1[[1]])[1, "Sepal_Length"], "150") + sum2 <- with(df, distinct(Sepal_Length)) + expect_equal(nrow(sum2), 35) +}) + +test_that("Method coltypes() to get and set R's data types of a DataFrame", { + expect_equal(coltypes(irisDF), c(rep("numeric", 4), "character")) + + data <- data.frame(c1=c(1,2,3), + c2=c(T,F,T), + c3=c("2015/01/01 10:00:00", "2015/01/02 10:00:00", "2015/01/03 10:00:00")) + + schema <- structType(structField("c1", "byte"), + structField("c3", "boolean"), + structField("c4", "timestamp")) + + # Test primitive types + DF <- createDataFrame(sqlContext, data, schema) + expect_equal(coltypes(DF), c("integer", "logical", "POSIXct")) + + # Test complex types + x <- createDataFrame(sqlContext, list(list(as.environment( + list("a"="b", "c"="d", "e"="f"))))) + expect_equal(coltypes(x), "map") + + df <- selectExpr(read.json(sqlContext, jsonPath), "name", "(age * 1.21) as age") + expect_equal(dtypes(df), list(c("name", "string"), c("age", "decimal(24,2)"))) + + df1 <- select(df, cast(df$age, "integer")) + coltypes(df) <- c("character", "integer") + expect_equal(dtypes(df), list(c("name", "string"), c("age", "int"))) + value <- collect(df[, 2])[[3, 1]] + expect_equal(value, collect(df1)[[3, 1]]) + expect_equal(value, 22) + + coltypes(df) <- c(NA, "numeric") + expect_equal(dtypes(df), list(c("name", "string"), c("age", "double"))) + + expect_error(coltypes(df) <- c("character"), + "Length of type vector should match the number of columns for DataFrame") + expect_error(coltypes(df) <- c("environment", "list"), + "Only atomic type is supported for column types") +}) + +unlink(parquetPath) 
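# Aside (illustration only, not part of the patch): the coltypes()<- test above maps R
# classes back to Spark SQL types ("character" -> string, "integer" -> int, "numeric" ->
# double), with NA meaning "leave this column as is", and the replacement vector must
# supply one entry per column. Rough base-R analogue of casting one column while
# leaving the other untouched, on hypothetical local data:
local_df <- data.frame(name = c("a", "b"), age = c("1", "2"), stringsAsFactors = FALSE)
local_df$age <- as.integer(local_df$age)   # the NA slot would correspond to leaving `name` alone
str(local_df)   # name stays character, age becomes integer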
+unlink(jsonPath) +unlink(jsonPathNa) diff --git a/R/pkg/inst/tests/test_take.R b/R/pkg/inst/tests/testthat/test_take.R similarity index 100% rename from R/pkg/inst/tests/test_take.R rename to R/pkg/inst/tests/testthat/test_take.R diff --git a/R/pkg/inst/tests/test_textFile.R b/R/pkg/inst/tests/testthat/test_textFile.R similarity index 100% rename from R/pkg/inst/tests/test_textFile.R rename to R/pkg/inst/tests/testthat/test_textFile.R diff --git a/R/pkg/inst/tests/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R similarity index 100% rename from R/pkg/inst/tests/test_utils.R rename to R/pkg/inst/tests/testthat/test_utils.R diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R index 3584b418a71a..f55beac6c8c0 100644 --- a/R/pkg/inst/worker/daemon.R +++ b/R/pkg/inst/worker/daemon.R @@ -18,10 +18,11 @@ # Worker daemon rLibDir <- Sys.getenv("SPARKR_RLIBDIR") -script <- paste(rLibDir, "SparkR/worker/worker.R", sep = "/") +dirs <- strsplit(rLibDir, ",")[[1]] +script <- file.path(dirs[[1]], "SparkR", "worker", "worker.R") # preload SparkR package, speedup worker -.libPaths(c(rLibDir, .libPaths())) +.libPaths(c(dirs, .libPaths())) suppressPackageStartupMessages(library(SparkR)) port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R index 7e3b5fc403b2..3ae072beca11 100644 --- a/R/pkg/inst/worker/worker.R +++ b/R/pkg/inst/worker/worker.R @@ -35,10 +35,11 @@ bootTime <- currentTimeSecs() bootElap <- elapsedSecs() rLibDir <- Sys.getenv("SPARKR_RLIBDIR") +dirs <- strsplit(rLibDir, ",")[[1]] # Set libPaths to include SparkR package as loadNamespace needs this # TODO: Figure out if we can avoid this by not loading any objects that require # SparkR namespace -.libPaths(c(rLibDir, .libPaths())) +.libPaths(c(dirs, .libPaths())) suppressPackageStartupMessages(library(SparkR)) port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) @@ -94,7 +95,7 @@ if (isEmpty != 0) { } else if (deserializer == "string") { data <- as.list(readLines(inputCon)) } else if (deserializer == "row") { - data <- SparkR:::readDeserializeRows(inputCon) + data <- SparkR:::readMultipleObjects(inputCon) } # Timing reading input data for execution inputElap <- elapsedSecs() @@ -120,7 +121,7 @@ if (isEmpty != 0) { } else if (deserializer == "string") { data <- readLines(inputCon) } else if (deserializer == "row") { - data <- SparkR:::readDeserializeRows(inputCon) + data <- SparkR:::readMultipleObjects(inputCon) } # Timing reading input data for execution inputElap <- elapsedSecs() diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R index 4f8a1ed2d83e..1d04656ac259 100644 --- a/R/pkg/tests/run-all.R +++ b/R/pkg/tests/run-all.R @@ -18,4 +18,7 @@ library(testthat) library(SparkR) +# Turn all warnings into errors +options("warn" = 2) + test_package("SparkR") diff --git a/R/run-tests.sh b/R/run-tests.sh index 18a1e13bdc65..e64a4ea94c58 100755 --- a/R/run-tests.sh +++ b/R/run-tests.sh @@ -23,7 +23,7 @@ FAILED=0 LOGFILE=$FWDIR/unit-tests.out rm -f $LOGFILE -SPARK_TESTING=1 $FWDIR/../bin/sparkR --conf spark.buffer.pageSize=4m --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE +SPARK_TESTING=1 $FWDIR/../bin/sparkR --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" --conf spark.hadoop.fs.default.name="file:///" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE FAILED=$((PIPESTATUS[0]||$FAILED)) if [[ $FAILED != 0 ]]; then diff --git a/README.md b/README.md index 
380422ca00db..d5804d1a20b4 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # Apache Spark Spark is a fast and general cluster computing system for Big Data. It provides -high-level APIs in Scala, Java, and Python, and an optimized engine that +high-level APIs in Scala, Java, Python, and R, and an optimized engine that supports general computation graphs for data analysis. It also supports a rich set of higher-level tools including Spark SQL for SQL and DataFrames, MLlib for machine learning, GraphX for graph processing, @@ -27,6 +27,8 @@ To build Spark and its example programs, run: (You do not need to do this if you downloaded a pre-built package.) More detailed documentation is available from the project site, at ["Building Spark"](http://spark.apache.org/docs/latest/building-spark.html). +For developing Spark using an IDE, see [Eclipse](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools#UsefulDeveloperTools-Eclipse) +and [IntelliJ](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools#UsefulDeveloperTools-IntelliJ). ## Interactive Scala Shell @@ -59,7 +61,7 @@ will run the Pi example locally. You can set the MASTER environment variable when running examples to submit examples to a cluster. This can be a mesos:// or spark:// URL, -"yarn-cluster" or "yarn-client" to run on YARN, and "local" to run +"yarn" to run on YARN, and "local" to run locally with one thread, or "local[N]" to run locally with N threads. You can also use an abbreviated class name if the class is in the `examples` package. For instance: @@ -87,12 +89,9 @@ Hadoop, you must build Spark against the same version that your cluster runs. Please refer to the build documentation at ["Specifying the Hadoop Version"](http://spark.apache.org/docs/latest/building-spark.html#specifying-the-hadoop-version) for detailed guidance on building for a particular distribution of Hadoop, including -building for particular Hive and Hive Thriftserver distributions. See also -["Third Party Hadoop Distributions"](http://spark.apache.org/docs/latest/hadoop-third-party-distributions.html) -for guidance on building a Spark application that works with a particular -distribution. +building for particular Hive and Hive Thriftserver distributions. ## Configuration -Please refer to the [Configuration guide](http://spark.apache.org/docs/latest/configuration.html) +Please refer to the [Configuration Guide](http://spark.apache.org/docs/latest/configuration.html) in the online documentation for an overview on how to configure Spark. 
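The README passage above lists the master URL forms ("local", "local[N]", "spark://...", "yarn") accepted when submitting the bundled examples. As a hedged aside, the same master strings are what the SparkR front end touched elsewhere in this patch passes along; a minimal sketch assuming the SparkR API of this release line (sparkR.init / sparkRSQL.init) and a hypothetical app name:

library(SparkR)
# "local[2]" runs the driver with two threads; swap in spark://host:7077 or "yarn" as needed
sc <- sparkR.init(master = "local[2]", appName = "readme-example")
sqlContext <- sparkRSQL.init(sc)
df <- createDataFrame(sqlContext, faithful)   # built-in `faithful` dataset, purely for illustration
head(df)
sparkR.stop()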
diff --git a/assembly/pom.xml b/assembly/pom.xml index e9c6d26ccddc..4b60ee00ffbe 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/assembly/src/main/assembly/assembly.xml b/assembly/src/main/assembly/assembly.xml index 711156337b7c..009d4b92f406 100644 --- a/assembly/src/main/assembly/assembly.xml +++ b/assembly/src/main/assembly/assembly.xml @@ -32,7 +32,7 @@ ${project.parent.basedir}/core/src/main/resources/org/apache/spark/ui/static/ - /ui-resources/org/apache/spark/ui/static + ui-resources/org/apache/spark/ui/static **/* @@ -41,7 +41,7 @@ ${project.parent.basedir}/sbin/ - /sbin + sbin **/* @@ -50,7 +50,7 @@ ${project.parent.basedir}/bin/ - /bin + bin **/* @@ -59,7 +59,7 @@ ${project.parent.basedir}/assembly/target/${spark.jar.dir} - / + ${spark.jar.basename} diff --git a/bagel/pom.xml b/bagel/pom.xml index ed5c37e595a9..672e9469aec9 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml @@ -52,6 +52,10 @@ scalacheck_${scala.binary.version} test + + org.apache.spark + spark-test-tags_${scala.binary.version} + target/scala-${scala.binary.version}/classes diff --git a/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala b/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala index ef0bb2ac13f0..8399033ac61e 100644 --- a/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala +++ b/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala @@ -22,6 +22,7 @@ import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") object Bagel extends Logging { val DEFAULT_STORAGE_LEVEL = StorageLevel.MEMORY_AND_DISK @@ -78,7 +79,7 @@ object Bagel extends Logging { val startTime = System.currentTimeMillis val aggregated = agg(verts, aggregator) - val combinedMsgs = msgs.combineByKey( + val combinedMsgs = msgs.combineByKeyWithClassTag( combiner.createCombiner _, combiner.mergeMsg _, combiner.mergeCombiners _, partitioner) val grouped = combinedMsgs.groupWith(verts) val superstep_ = superstep // Create a read-only copy of superstep for capture in closure @@ -270,18 +271,21 @@ object Bagel extends Logging { } } +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") trait Combiner[M, C] { def createCombiner(msg: M): C def mergeMsg(combiner: C, msg: M): C def mergeCombiners(a: C, b: C): C } +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") trait Aggregator[V, A] { def createAggregator(vert: V): A def mergeAggregators(a: A, b: A): A } /** Default combiner that simply appends messages together (i.e. performs no aggregation) */ +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") class DefaultCombiner[M: Manifest] extends Combiner[M, Array[M]] with Serializable { def createCombiner(msg: M): Array[M] = Array(msg) @@ -297,6 +301,7 @@ class DefaultCombiner[M: Manifest] extends Combiner[M, Array[M]] with Serializab * Subclasses may store state along with each vertex and must * inherit from java.io.Serializable or scala.Serializable. */ +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") trait Vertex { def active: Boolean } @@ -307,6 +312,7 @@ trait Vertex { * Subclasses may contain a payload to deliver to the target vertex * and must inherit from java.io.Serializable or scala.Serializable. 
*/ +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") trait Message[K] { def targetId: K } diff --git a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala deleted file mode 100644 index fb10d734ac74..000000000000 --- a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.bagel - -import org.scalatest.{BeforeAndAfter, Assertions} -import org.scalatest.concurrent.Timeouts -import org.scalatest.time.SpanSugar._ - -import org.apache.spark._ -import org.apache.spark.storage.StorageLevel - -class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable -class TestMessage(val targetId: String) extends Message[String] with Serializable - -class BagelSuite extends SparkFunSuite with Assertions with BeforeAndAfter with Timeouts { - - var sc: SparkContext = _ - - after { - if (sc != null) { - sc.stop() - sc = null - } - } - - test("halting by voting") { - sc = new SparkContext("local", "test") - val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(true, 0)))) - val msgs = sc.parallelize(Array[(String, TestMessage)]()) - val numSupersteps = 5 - val result = - Bagel.run(sc, verts, msgs, sc.defaultParallelism) { - (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => - (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) - } - for ((id, vert) <- result.collect) { - assert(vert.age === numSupersteps) - } - } - - test("halting by message silence") { - sc = new SparkContext("local", "test") - val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(false, 0)))) - val msgs = sc.parallelize(Array("a" -> new TestMessage("a"))) - val numSupersteps = 5 - val result = - Bagel.run(sc, verts, msgs, sc.defaultParallelism) { - (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => - val msgsOut = - msgs match { - case Some(ms) if (superstep < numSupersteps - 1) => - ms - case _ => - Array[TestMessage]() - } - (new TestVertex(self.active, self.age + 1), msgsOut) - } - for ((id, vert) <- result.collect) { - assert(vert.age === numSupersteps) - } - } - - test("large number of iterations") { - // This tests whether jobs with a large number of iterations finish in a reasonable time, - // because non-memoized recursion in RDD or DAGScheduler used to cause them to hang - failAfter(30 seconds) { - sc = new SparkContext("local", "test") - val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0)))) - val msgs = sc.parallelize(Array[(String, TestMessage)]()) - val numSupersteps = 50 - val result = - 
Bagel.run(sc, verts, msgs, sc.defaultParallelism) { - (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => - (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) - } - for ((id, vert) <- result.collect) { - assert(vert.age === numSupersteps) - } - } - } - - test("using non-default persistence level") { - failAfter(10 seconds) { - sc = new SparkContext("local", "test") - val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0)))) - val msgs = sc.parallelize(Array[(String, TestMessage)]()) - val numSupersteps = 20 - val result = - Bagel.run(sc, verts, msgs, sc.defaultParallelism, StorageLevel.DISK_ONLY) { - (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => - (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) - } - for ((id, vert) <- result.collect) { - assert(vert.age === numSupersteps) - } - } - } -} diff --git a/bin/beeline b/bin/beeline index 3fcb6df34339..1627626941a7 100755 --- a/bin/beeline +++ b/bin/beeline @@ -23,8 +23,10 @@ # Enter posix mode for bash set -o posix -# Figure out where Spark is installed -FWDIR="$(cd "`dirname "$0"`"/..; pwd)" +# Figure out if SPARK_HOME is set +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi CLASS="org.apache.hive.beeline.BeeLine" -exec "$FWDIR/bin/spark-class" $CLASS "$@" +exec "${SPARK_HOME}/bin/spark-class" $CLASS "$@" diff --git a/bin/load-spark-env.cmd b/bin/load-spark-env.cmd index 36d932c453b6..59080edd294f 100644 --- a/bin/load-spark-env.cmd +++ b/bin/load-spark-env.cmd @@ -27,7 +27,7 @@ if [%SPARK_ENV_LOADED%] == [] ( if not [%SPARK_CONF_DIR%] == [] ( set user_conf_dir=%SPARK_CONF_DIR% ) else ( - set user_conf_dir=%~dp0..\..\conf + set user_conf_dir=%~dp0..\conf ) call :LoadSparkEnv diff --git a/bin/load-spark-env.sh b/bin/load-spark-env.sh index 95779e9ddbb1..eaea964ed5b3 100644 --- a/bin/load-spark-env.sh +++ b/bin/load-spark-env.sh @@ -20,13 +20,17 @@ # This script loads spark-env.sh if it exists, and ensures it is only loaded once. # spark-env.sh is loaded from SPARK_CONF_DIR if set, or within the current directory's # conf/ subdirectory. -FWDIR="$(cd "`dirname "$0"`"/..; pwd)" + +# Figure out where Spark is installed +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi if [ -z "$SPARK_ENV_LOADED" ]; then export SPARK_ENV_LOADED=1 # Returns the parent of the directory this script lives in. - parent_dir="$(cd "`dirname "$0"`"/..; pwd)" + parent_dir="${SPARK_HOME}" user_conf_dir="${SPARK_CONF_DIR:-"$parent_dir"/conf}" @@ -42,18 +46,18 @@ fi if [ -z "$SPARK_SCALA_VERSION" ]; then - ASSEMBLY_DIR2="$FWDIR/assembly/target/scala-2.11" - ASSEMBLY_DIR1="$FWDIR/assembly/target/scala-2.10" + ASSEMBLY_DIR2="${SPARK_HOME}/assembly/target/scala-2.11" + ASSEMBLY_DIR1="${SPARK_HOME}/assembly/target/scala-2.10" - if [[ -d "$ASSEMBLY_DIR2" && -d "$ASSEMBLY_DIR1" ]]; then - echo -e "Presence of build for both scala versions(SCALA 2.10 and SCALA 2.11) detected." 1>&2 - echo -e 'Either clean one of them or, export SPARK_SCALA_VERSION=2.11 in spark-env.sh.' 1>&2 - exit 1 - fi + if [[ -d "$ASSEMBLY_DIR2" && -d "$ASSEMBLY_DIR1" ]]; then + echo -e "Presence of build for both scala versions(SCALA 2.10 and SCALA 2.11) detected." 1>&2 + echo -e 'Either clean one of them or, export SPARK_SCALA_VERSION=2.11 in spark-env.sh.' 
1>&2 + exit 1 + fi - if [ -d "$ASSEMBLY_DIR2" ]; then - export SPARK_SCALA_VERSION="2.11" - else - export SPARK_SCALA_VERSION="2.10" - fi + if [ -d "$ASSEMBLY_DIR2" ]; then + export SPARK_SCALA_VERSION="2.11" + else + export SPARK_SCALA_VERSION="2.10" + fi fi diff --git a/bin/pyspark b/bin/pyspark index 8f2a3b5a7717..5eaa17d3c201 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -17,9 +17,11 @@ # limitations under the License. # -export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi -source "$SPARK_HOME"/bin/load-spark-env.sh +source "${SPARK_HOME}"/bin/load-spark-env.sh export _SPARK_CMD_USAGE="Usage: ./bin/pyspark [options]" # In Spark <= 1.1, setting IPYTHON=1 would cause the driver to be launched using the `ipython` @@ -64,12 +66,12 @@ fi export PYSPARK_PYTHON # Add the PySpark classes to the Python path: -export PYTHONPATH="$SPARK_HOME/python/:$PYTHONPATH" -export PYTHONPATH="$SPARK_HOME/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH" +export PYTHONPATH="${SPARK_HOME}/python/:$PYTHONPATH" +export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.9-src.zip:$PYTHONPATH" # Load the PySpark shell.py script when ./pyspark is used interactively: export OLD_PYTHONSTARTUP="$PYTHONSTARTUP" -export PYTHONSTARTUP="$SPARK_HOME/python/pyspark/shell.py" +export PYTHONSTARTUP="${SPARK_HOME}/python/pyspark/shell.py" # For pyspark tests if [[ -n "$SPARK_TESTING" ]]; then @@ -82,4 +84,4 @@ fi export PYSPARK_DRIVER_PYTHON export PYSPARK_DRIVER_PYTHON_OPTS -exec "$SPARK_HOME"/bin/spark-submit pyspark-shell-main --name "PySparkShell" "$@" +exec "${SPARK_HOME}"/bin/spark-submit pyspark-shell-main --name "PySparkShell" "$@" diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index 3c6169983e76..a97d884f0bf3 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -30,7 +30,7 @@ if "x%PYSPARK_DRIVER_PYTHON%"=="x" ( ) set PYTHONPATH=%SPARK_HOME%\python;%PYTHONPATH% -set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.8.2.1-src.zip;%PYTHONPATH% +set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.9-src.zip;%PYTHONPATH% set OLD_PYTHONSTARTUP=%PYTHONSTARTUP% set PYTHONSTARTUP=%SPARK_HOME%\python\pyspark\shell.py diff --git a/bin/run-example b/bin/run-example index 798e2caeb88c..e1b0d5789bed 100755 --- a/bin/run-example +++ b/bin/run-example @@ -17,11 +17,13 @@ # limitations under the License. # -FWDIR="$(cd "`dirname "$0"`"/..; pwd)" -export SPARK_HOME="$FWDIR" -EXAMPLES_DIR="$FWDIR"/examples +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi + +EXAMPLES_DIR="${SPARK_HOME}"/examples -. "$FWDIR"/bin/load-spark-env.sh +. "${SPARK_HOME}"/bin/load-spark-env.sh if [ -n "$1" ]; then EXAMPLE_CLASS="$1" @@ -34,8 +36,8 @@ else exit 1 fi -if [ -f "$FWDIR/RELEASE" ]; then - JAR_PATH="${FWDIR}/lib" +if [ -f "${SPARK_HOME}/RELEASE" ]; then + JAR_PATH="${SPARK_HOME}/lib" else JAR_PATH="${EXAMPLES_DIR}/target/scala-${SPARK_SCALA_VERSION}" fi @@ -44,7 +46,7 @@ JAR_COUNT=0 for f in "${JAR_PATH}"/spark-examples-*hadoop*.jar; do if [[ ! -e "$f" ]]; then - echo "Failed to find Spark examples assembly in $FWDIR/lib or $FWDIR/examples/target" 1>&2 + echo "Failed to find Spark examples assembly in ${SPARK_HOME}/lib or ${SPARK_HOME}/examples/target" 1>&2 echo "You need to build Spark before running this program" 1>&2 exit 1 fi @@ -67,7 +69,7 @@ if [[ ! 
$EXAMPLE_CLASS == org.apache.spark.examples* ]]; then EXAMPLE_CLASS="org.apache.spark.examples.$EXAMPLE_CLASS" fi -exec "$FWDIR"/bin/spark-submit \ +exec "${SPARK_HOME}"/bin/spark-submit \ --master $EXAMPLE_MASTER \ --class $EXAMPLE_CLASS \ "$SPARK_EXAMPLES_JAR" \ diff --git a/bin/spark-class b/bin/spark-class index 2b59e5df5736..5d964ba96abd 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -17,10 +17,11 @@ # limitations under the License. # -# Figure out where Spark is installed -export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi -. "$SPARK_HOME"/bin/load-spark-env.sh +. "${SPARK_HOME}"/bin/load-spark-env.sh # Find the java binary if [ -n "${JAVA_HOME}" ]; then @@ -36,24 +37,27 @@ fi # Find assembly jar SPARK_ASSEMBLY_JAR= -if [ -f "$SPARK_HOME/RELEASE" ]; then - ASSEMBLY_DIR="$SPARK_HOME/lib" +if [ -f "${SPARK_HOME}/RELEASE" ]; then + ASSEMBLY_DIR="${SPARK_HOME}/lib" else - ASSEMBLY_DIR="$SPARK_HOME/assembly/target/scala-$SPARK_SCALA_VERSION" + ASSEMBLY_DIR="${SPARK_HOME}/assembly/target/scala-$SPARK_SCALA_VERSION" fi +GREP_OPTIONS= num_jars="$(ls -1 "$ASSEMBLY_DIR" | grep "^spark-assembly.*hadoop.*\.jar$" | wc -l)" -if [ "$num_jars" -eq "0" -a -z "$SPARK_ASSEMBLY_JAR" ]; then +if [ "$num_jars" -eq "0" -a -z "$SPARK_ASSEMBLY_JAR" -a "$SPARK_PREPEND_CLASSES" != "1" ]; then echo "Failed to find Spark assembly in $ASSEMBLY_DIR." 1>&2 echo "You need to build Spark before running this program." 1>&2 exit 1 fi -ASSEMBLY_JARS="$(ls -1 "$ASSEMBLY_DIR" | grep "^spark-assembly.*hadoop.*\.jar$" || true)" -if [ "$num_jars" -gt "1" ]; then - echo "Found multiple Spark assembly jars in $ASSEMBLY_DIR:" 1>&2 - echo "$ASSEMBLY_JARS" 1>&2 - echo "Please remove all but one jar." 1>&2 - exit 1 +if [ -d "$ASSEMBLY_DIR" ]; then + ASSEMBLY_JARS="$(ls -1 "$ASSEMBLY_DIR" | grep "^spark-assembly.*hadoop.*\.jar$" || true)" + if [ "$num_jars" -gt "1" ]; then + echo "Found multiple Spark assembly jars in $ASSEMBLY_DIR:" 1>&2 + echo "$ASSEMBLY_JARS" 1>&2 + echo "Please remove all but one jar." 1>&2 + exit 1 + fi fi SPARK_ASSEMBLY_JAR="${ASSEMBLY_DIR}/${ASSEMBLY_JARS}" @@ -62,11 +66,17 @@ LAUNCH_CLASSPATH="$SPARK_ASSEMBLY_JAR" # Add the launcher build dir to the classpath if requested. if [ -n "$SPARK_PREPEND_CLASSES" ]; then - LAUNCH_CLASSPATH="$SPARK_HOME/launcher/target/scala-$SPARK_SCALA_VERSION/classes:$LAUNCH_CLASSPATH" + LAUNCH_CLASSPATH="${SPARK_HOME}/launcher/target/scala-$SPARK_SCALA_VERSION/classes:$LAUNCH_CLASSPATH" fi export _SPARK_ASSEMBLY="$SPARK_ASSEMBLY_JAR" +# For tests +if [[ -n "$SPARK_TESTING" ]]; then + unset YARN_CONF_DIR + unset HADOOP_CONF_DIR +fi + # The launcher library will print arguments separated by a NULL character, to allow arguments with # characters that would be otherwise interpreted by the shell. Read that in a while loop, populating # an array that will be used to exec the final command. diff --git a/bin/spark-shell b/bin/spark-shell index 00ab7afd118b..6583b5bd880e 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -28,7 +28,10 @@ esac # Enter posix mode for bash set -o posix -export FWDIR="$(cd "`dirname "$0"`"/..; pwd)" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi + export _SPARK_CMD_USAGE="Usage: ./bin/spark-shell [options]" # SPARK-4161: scala does not assume use of the java classpath, @@ -47,11 +50,11 @@ function main() { # (see https://github.com/sbt/sbt/issues/562). 
stty -icanon min 1 -echo > /dev/null 2>&1 export SPARK_SUBMIT_OPTS="$SPARK_SUBMIT_OPTS -Djline.terminal=unix" - "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main --name "Spark shell" "$@" + "${SPARK_HOME}"/bin/spark-submit --class org.apache.spark.repl.Main --name "Spark shell" "$@" stty icanon echo > /dev/null 2>&1 else export SPARK_SUBMIT_OPTS - "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main --name "Spark shell" "$@" + "${SPARK_HOME}"/bin/spark-submit --class org.apache.spark.repl.Main --name "Spark shell" "$@" fi } diff --git a/bin/spark-sql b/bin/spark-sql index 4ea7bc6e39c0..970d12cbf51d 100755 --- a/bin/spark-sql +++ b/bin/spark-sql @@ -17,6 +17,9 @@ # limitations under the License. # -export FWDIR="$(cd "`dirname "$0"`"/..; pwd)" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi + export _SPARK_CMD_USAGE="Usage: ./bin/spark-sql [options] [cli option]" -exec "$FWDIR"/bin/spark-submit --class org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver "$@" +exec "${SPARK_HOME}"/bin/spark-submit --class org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver "$@" diff --git a/bin/spark-submit b/bin/spark-submit index 255378b0f077..023f9c162f4b 100755 --- a/bin/spark-submit +++ b/bin/spark-submit @@ -17,9 +17,11 @@ # limitations under the License. # -SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi # disable randomized hash for string in Python 3.3+ export PYTHONHASHSEED=0 -exec "$SPARK_HOME"/bin/spark-class org.apache.spark.deploy.SparkSubmit "$@" +exec "${SPARK_HOME}"/bin/spark-class org.apache.spark.deploy.SparkSubmit "$@" diff --git a/bin/sparkR b/bin/sparkR index 464c29f36942..2c07a82e2173 100755 --- a/bin/sparkR +++ b/bin/sparkR @@ -17,7 +17,10 @@ # limitations under the License. # -export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" -source "$SPARK_HOME"/bin/load-spark-env.sh +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi + +source "${SPARK_HOME}"/bin/load-spark-env.sh export _SPARK_CMD_USAGE="Usage: ./bin/sparkR [options]" -exec "$SPARK_HOME"/bin/spark-submit sparkr-shell-main "$@" +exec "${SPARK_HOME}"/bin/spark-submit sparkr-shell-main "$@" diff --git a/build/mvn b/build/mvn index f62f61ee1c41..7603ea03deb7 100755 --- a/build/mvn +++ b/build/mvn @@ -51,11 +51,11 @@ install_app() { # check if we have curl installed # download application [ ! -f "${local_tarball}" ] && [ $(command -v curl) ] && \ - echo "exec: curl ${curl_opts} ${remote_tarball}" && \ + echo "exec: curl ${curl_opts} ${remote_tarball}" 1>&2 && \ curl ${curl_opts} "${remote_tarball}" > "${local_tarball}" # if the file still doesn't exist, lets try `wget` and cross our fingers [ ! -f "${local_tarball}" ] && [ $(command -v wget) ] && \ - echo "exec: wget ${wget_opts} ${remote_tarball}" && \ + echo "exec: wget ${wget_opts} ${remote_tarball}" 1>&2 && \ wget ${wget_opts} -O "${local_tarball}" "${remote_tarball}" # if both were unsuccessful, exit [ ! -f "${local_tarball}" ] && \ @@ -82,7 +82,7 @@ install_mvn() { # Install zinc under the build/ folder install_zinc() { local zinc_path="zinc-0.3.5.3/bin/zinc" - [ ! -f "${zinc_path}" ] && ZINC_INSTALL_FLAG=1 + [ ! 
-f "${_DIR}/${zinc_path}" ] && ZINC_INSTALL_FLAG=1 install_app \ "http://downloads.typesafe.com/zinc/0.3.5.3" \ "zinc-0.3.5.3.tgz" \ @@ -104,8 +104,8 @@ install_scala() { "scala-${scala_version}.tgz" \ "scala-${scala_version}/bin/scala" - SCALA_COMPILER="$(cd "$(dirname ${scala_bin})/../lib" && pwd)/scala-compiler.jar" - SCALA_LIBRARY="$(cd "$(dirname ${scala_bin})/../lib" && pwd)/scala-library.jar" + SCALA_COMPILER="$(cd "$(dirname "${scala_bin}")/../lib" && pwd)/scala-compiler.jar" + SCALA_LIBRARY="$(cd "$(dirname "${scala_bin}")/../lib" && pwd)/scala-library.jar" } # Setup healthy defaults for the Zinc port if none were provided from @@ -135,10 +135,10 @@ cd "${_CALLING_DIR}" # Now that zinc is ensured to be installed, check its status and, if its # not running or just installed, start it -if [ -n "${ZINC_INSTALL_FLAG}" -o -z "`${ZINC_BIN} -status`" ]; then +if [ -n "${ZINC_INSTALL_FLAG}" -o -z "`"${ZINC_BIN}" -status -port ${ZINC_PORT}`" ]; then export ZINC_OPTS=${ZINC_OPTS:-"$_COMPILE_JVM_OPTS"} - ${ZINC_BIN} -shutdown - ${ZINC_BIN} -start -port ${ZINC_PORT} \ + "${ZINC_BIN}" -shutdown -port ${ZINC_PORT} + "${ZINC_BIN}" -start -port ${ZINC_PORT} \ -scala-compiler "${SCALA_COMPILER}" \ -scala-library "${SCALA_LIBRARY}" &>/dev/null fi @@ -146,7 +146,7 @@ fi # Set any `mvn` options if not already present export MAVEN_OPTS=${MAVEN_OPTS:-"$_COMPILE_JVM_OPTS"} -echo "Using \`mvn\` from path: $MVN_BIN" +echo "Using \`mvn\` from path: $MVN_BIN" 1>&2 # Last, call the `mvn` command as usual -${MVN_BIN} "$@" +${MVN_BIN} -DzincPort=${ZINC_PORT} "$@" diff --git a/build/sbt b/build/sbt index cc3203d79bcc..7d8d0993e57d 100755 --- a/build/sbt +++ b/build/sbt @@ -20,10 +20,12 @@ # When creating new tests for Spark SQL Hive, the HADOOP_CLASSPATH must contain the hive jars so # that we can run Hive to generate the golden answer. This is not required for normal development # or testing. 
-for i in "$HIVE_HOME"/lib/* -do HADOOP_CLASSPATH="$HADOOP_CLASSPATH:$i" -done -export HADOOP_CLASSPATH +if [ -n "$HIVE_HOME" ]; then + for i in "$HIVE_HOME"/lib/* + do HADOOP_CLASSPATH="$HADOOP_CLASSPATH:$i" + done + export HADOOP_CLASSPATH +fi realpath () { ( diff --git a/build/sbt-launch-lib.bash b/build/sbt-launch-lib.bash index 7930a38b9674..615f84839465 100755 --- a/build/sbt-launch-lib.bash +++ b/build/sbt-launch-lib.bash @@ -38,8 +38,7 @@ dlog () { acquire_sbt_jar () { SBT_VERSION=`awk -F "=" '/sbt\.version/ {print $2}' ./project/build.properties` - URL1=http://typesafe.artifactoryonline.com/typesafe/ivy-releases/org.scala-sbt/sbt-launch/${SBT_VERSION}/sbt-launch.jar - URL2=http://repo.typesafe.com/typesafe/ivy-releases/org.scala-sbt/sbt-launch/${SBT_VERSION}/sbt-launch.jar + URL1=https://dl.bintray.com/typesafe/ivy-releases/org.scala-sbt/sbt-launch/${SBT_VERSION}/sbt-launch.jar JAR=build/sbt-launch-${SBT_VERSION}.jar sbt_jar=$JAR @@ -51,12 +50,10 @@ acquire_sbt_jar () { printf "Attempting to fetch sbt\n" JAR_DL="${JAR}.part" if [ $(command -v curl) ]; then - (curl --fail --location --silent ${URL1} > "${JAR_DL}" ||\ - (rm -f "${JAR_DL}" && curl --fail --location --silent ${URL2} > "${JAR_DL}")) &&\ + curl --fail --location --silent ${URL1} > "${JAR_DL}" &&\ mv "${JAR_DL}" "${JAR}" elif [ $(command -v wget) ]; then - (wget --quiet ${URL1} -O "${JAR_DL}" ||\ - (rm -f "${JAR_DL}" && wget --quiet ${URL2} -O "${JAR_DL}")) &&\ + wget --quiet ${URL1} -O "${JAR_DL}" &&\ mv "${JAR_DL}" "${JAR}" else printf "You do not have curl or wget installed, please install sbt manually from http://www.scala-sbt.org/\n" diff --git a/checkstyle-suppressions.xml b/checkstyle-suppressions.xml new file mode 100644 index 000000000000..9242be3d0357 --- /dev/null +++ b/checkstyle-suppressions.xml @@ -0,0 +1,33 @@ + + + + + + + + + diff --git a/checkstyle.xml b/checkstyle.xml new file mode 100644 index 000000000000..a493ee443c75 --- /dev/null +++ b/checkstyle.xml @@ -0,0 +1,164 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/conf/docker.properties.template b/conf/docker.properties.template index 26e3bfd9c5b9..55cb094b4af4 100644 --- a/conf/docker.properties.template +++ b/conf/docker.properties.template @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + spark.mesos.executor.docker.image: spark.mesos.executor.docker.volumes: /usr/local/lib:/host/usr/local/lib:ro spark.mesos.executor.home: /opt/spark diff --git a/conf/fairscheduler.xml.template b/conf/fairscheduler.xml.template index acf59e2a3598..385b2e772d2c 100644 --- a/conf/fairscheduler.xml.template +++ b/conf/fairscheduler.xml.template @@ -1,4 +1,22 @@ + + + FAIR diff --git a/conf/log4j.properties.template b/conf/log4j.properties.template index 27006e45e932..9809b0c82848 100644 --- a/conf/log4j.properties.template +++ b/conf/log4j.properties.template @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # Set everything to be logged to the console log4j.rootCategory=INFO, console log4j.appender.console=org.apache.log4j.ConsoleAppender @@ -5,11 +22,18 @@ log4j.appender.console.target=System.err log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n +# Set the default spark-shell log level to WARN. When running the spark-shell, the +# log level for this class is used to overwrite the root logger's log level, so that +# the user can have different defaults for the shell and regular Spark apps. +log4j.logger.org.apache.spark.repl.Main=WARN + # Settings to quiet third party logs that are too verbose log4j.logger.org.spark-project.jetty=WARN log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO +log4j.logger.org.apache.parquet=ERROR +log4j.logger.parquet=ERROR # SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=FATAL diff --git a/conf/metrics.properties.template b/conf/metrics.properties.template index 7f17bc7eea4f..d6962e0da2f3 100644 --- a/conf/metrics.properties.template +++ b/conf/metrics.properties.template @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + # syntax: [instance].sink|source.[name].[options]=[value] # This file configures Spark's internal metrics system. The metrics system is diff --git a/conf/slaves.template b/conf/slaves.template index da0a01343d20..be42a638230b 100644 --- a/conf/slaves.template +++ b/conf/slaves.template @@ -1,2 +1,19 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # A Spark Worker will be started on each of the machines listed below. localhost \ No newline at end of file diff --git a/conf/spark-defaults.conf.template b/conf/spark-defaults.conf.template index a48dcc70e136..19cba6e71ed1 100644 --- a/conf/spark-defaults.conf.template +++ b/conf/spark-defaults.conf.template @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # Default system properties included when running spark-submit. # This is useful for setting default environmental settings. diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index 192d3ae09113..771251f90ee3 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -1,5 +1,22 @@ #!/usr/bin/env bash +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # This file is sourced when running various Spark programs. 
# Copy it as spark-env.sh and edit that to configure Spark for your site. @@ -19,10 +36,10 @@ # Options read in YARN client mode # - HADOOP_CONF_DIR, to point Spark towards Hadoop configuration files -# - SPARK_EXECUTOR_INSTANCES, Number of workers to start (Default: 2) -# - SPARK_EXECUTOR_CORES, Number of cores for the workers (Default: 1). -# - SPARK_EXECUTOR_MEMORY, Memory per Worker (e.g. 1000M, 2G) (Default: 1G) -# - SPARK_DRIVER_MEMORY, Memory for Master (e.g. 1000M, 2G) (Default: 1G) +# - SPARK_EXECUTOR_INSTANCES, Number of executors to start (Default: 2) +# - SPARK_EXECUTOR_CORES, Number of cores for the executors (Default: 1). +# - SPARK_EXECUTOR_MEMORY, Memory per Executor (e.g. 1000M, 2G) (Default: 1G) +# - SPARK_DRIVER_MEMORY, Memory for Driver (e.g. 1000M, 2G) (Default: 1G) # - SPARK_YARN_APP_NAME, The name of your application (Default: Spark) # - SPARK_YARN_QUEUE, The hadoop queue to use for allocation requests (Default: ‘default’) # - SPARK_YARN_DIST_FILES, Comma separated list of files to be distributed with the job. @@ -38,6 +55,7 @@ # - SPARK_WORKER_INSTANCES, to set the number of worker processes per node # - SPARK_WORKER_DIR, to set the working directory of worker processes # - SPARK_WORKER_OPTS, to set config properties only for the worker (e.g. "-Dx=y") +# - SPARK_DAEMON_MEMORY, to allocate to the master, worker and history server themselves (default: 1g). # - SPARK_HISTORY_OPTS, to set config properties only for the history server (e.g. "-Dx=y") # - SPARK_SHUFFLE_OPTS, to set config properties only for the external shuffle service (e.g. "-Dx=y") # - SPARK_DAEMON_JAVA_OPTS, to set config properties for all daemons (e.g. "-Dx=y") diff --git a/core/pom.xml b/core/pom.xml index 202678779150..61744bb5c7bf 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml @@ -46,30 +46,14 @@ com.twitter chill_${scala.binary.version} - - - org.ow2.asm - asm - - - org.ow2.asm - asm-commons - - com.twitter chill-java - - - org.ow2.asm - asm - - - org.ow2.asm - asm-commons - - + + + org.apache.xbean + xbean-asm5-shaded org.apache.hadoop @@ -286,7 +270,7 @@ org.tachyonproject tachyon-client - 0.7.0 + 0.8.2 org.apache.hadoop @@ -294,15 +278,19 @@ org.apache.curator - curator-recipes + curator-client - org.tachyonproject - tachyon-underfs-glusterfs + org.apache.curator + curator-framework + + + org.apache.curator + curator-recipes org.tachyonproject - tachyon-underfs-s3 + tachyon-underfs-glusterfs @@ -343,16 +331,6 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - - - com.novocode - junit-interface - test - org.apache.curator curator-test @@ -361,7 +339,7 @@ net.razorvine pyrolite - 4.4 + 4.9 net.razorvine @@ -372,7 +350,11 @@ net.sf.py4j py4j - 0.8.2.1 + 0.9 + + + org.apache.spark + spark-test-tags_${scala.binary.version} diff --git a/core/src/main/java/org/apache/spark/JavaSparkListener.java b/core/src/main/java/org/apache/spark/JavaSparkListener.java index fa9acf0a15b8..23bc9a2e8172 100644 --- a/core/src/main/java/org/apache/spark/JavaSparkListener.java +++ b/core/src/main/java/org/apache/spark/JavaSparkListener.java @@ -82,4 +82,7 @@ public void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { } @Override public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { } + @Override + public void onOtherEvent(SparkListenerEvent event) { } + } diff --git a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java 
b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java index 1214d05ba606..e6b24afd88ad 100644 --- a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java +++ b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java @@ -118,4 +118,8 @@ public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { onEvent(blockUpdated); } + @Override + public void onOtherEvent(SparkListenerEvent event) { + onEvent(event); + } } diff --git a/core/src/main/scala/org/apache/spark/annotation/AlphaComponent.java b/core/src/main/java/org/apache/spark/annotation/AlphaComponent.java similarity index 100% rename from core/src/main/scala/org/apache/spark/annotation/AlphaComponent.java rename to core/src/main/java/org/apache/spark/annotation/AlphaComponent.java diff --git a/core/src/main/scala/org/apache/spark/annotation/DeveloperApi.java b/core/src/main/java/org/apache/spark/annotation/DeveloperApi.java similarity index 100% rename from core/src/main/scala/org/apache/spark/annotation/DeveloperApi.java rename to core/src/main/java/org/apache/spark/annotation/DeveloperApi.java diff --git a/core/src/main/scala/org/apache/spark/annotation/Experimental.java b/core/src/main/java/org/apache/spark/annotation/Experimental.java similarity index 100% rename from core/src/main/scala/org/apache/spark/annotation/Experimental.java rename to core/src/main/java/org/apache/spark/annotation/Experimental.java diff --git a/core/src/main/scala/org/apache/spark/annotation/Private.java b/core/src/main/java/org/apache/spark/annotation/Private.java similarity index 100% rename from core/src/main/scala/org/apache/spark/annotation/Private.java rename to core/src/main/java/org/apache/spark/annotation/Private.java diff --git a/core/src/main/java/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java b/core/src/main/java/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java index 2090efd3b999..d4c42b38ac22 100644 --- a/core/src/main/java/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java +++ b/core/src/main/java/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java @@ -23,11 +23,13 @@ // See // http://scala-programming-language.1934581.n4.nabble.com/Workaround-for-implementing-java-varargs-in-2-7-2-final-tp1944767p1944772.html abstract class JavaSparkContextVarargsWorkaround { - public JavaRDD union(JavaRDD... rdds) { + + @SafeVarargs + public final JavaRDD union(JavaRDD... rdds) { if (rdds.length == 0) { throw new IllegalArgumentException("Union called on empty list"); } - ArrayList> rest = new ArrayList>(rdds.length - 1); + List> rest = new ArrayList<>(rdds.length - 1); for (int i = 1; i < rdds.length; i++) { rest.add(rdds[i]); } @@ -38,18 +40,19 @@ public JavaDoubleRDD union(JavaDoubleRDD... rdds) { if (rdds.length == 0) { throw new IllegalArgumentException("Union called on empty list"); } - ArrayList rest = new ArrayList(rdds.length - 1); + List rest = new ArrayList<>(rdds.length - 1); for (int i = 1; i < rdds.length; i++) { rest.add(rdds[i]); } return union(rdds[0], rest); } - public JavaPairRDD union(JavaPairRDD... rdds) { + @SafeVarargs + public final JavaPairRDD union(JavaPairRDD... rdds) { if (rdds.length == 0) { throw new IllegalArgumentException("Union called on empty list"); } - ArrayList> rest = new ArrayList>(rdds.length - 1); + List> rest = new ArrayList<>(rdds.length - 1); for (int i = 1; i < rdds.length; i++) { rest.add(rdds[i]); } @@ -57,7 +60,7 @@ public JavaPairRDD union(JavaPairRDD... 
rdds) { } // These methods take separate "first" and "rest" elements to avoid having the same type erasure - abstract public JavaRDD union(JavaRDD first, List> rest); - abstract public JavaDoubleRDD union(JavaDoubleRDD first, List rest); - abstract public JavaPairRDD union(JavaPairRDD first, List> rest); + public abstract JavaRDD union(JavaRDD first, List> rest); + public abstract JavaDoubleRDD union(JavaDoubleRDD first, List rest); + public abstract JavaPairRDD union(JavaPairRDD first, List> rest); } diff --git a/core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java b/core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java new file mode 100644 index 000000000000..279639af5d43 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; +import java.util.Iterator; + +/** + * A function that returns zero or more output records from each grouping key and its values from 2 + * Datasets. + */ +public interface CoGroupFunction extends Serializable { + Iterable call(K key, Iterator left, Iterator right) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/FilterFunction.java b/core/src/main/java/org/apache/spark/api/java/function/FilterFunction.java new file mode 100644 index 000000000000..e8d999dd0013 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/FilterFunction.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; + +/** + * Base interface for a function used in Dataset's filter function. + * + * If the function returns true, the element is discarded in the returned Dataset. 
+ */ +public interface FilterFunction extends Serializable { + boolean call(T value) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java index 23f5fdd43631..ef0d1824121e 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java @@ -23,5 +23,5 @@ * A function that returns zero or more output records from each input record. */ public interface FlatMapFunction extends Serializable { - public Iterable call(T t) throws Exception; + Iterable call(T t) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java index c48e92f535ff..14a98a38ef5a 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java +++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java @@ -23,5 +23,5 @@ * A function that takes two inputs and returns zero or more output records. */ public interface FlatMapFunction2 extends Serializable { - public Iterable call(T1 t1, T2 t2) throws Exception; + Iterable call(T1 t1, T2 t2) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsFunction.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsFunction.java new file mode 100644 index 000000000000..d7a80e7b129b --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsFunction.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; +import java.util.Iterator; + +/** + * A function that returns zero or more output records from each grouping key and its values. + */ +public interface FlatMapGroupsFunction extends Serializable { + Iterable call(K key, Iterator values) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/ForeachFunction.java b/core/src/main/java/org/apache/spark/api/java/function/ForeachFunction.java new file mode 100644 index 000000000000..07e54b28fa12 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/ForeachFunction.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; + +/** + * Base interface for a function used in Dataset's foreach function. + * + * Spark will invoke the call function on each element in the input Dataset. + */ +public interface ForeachFunction extends Serializable { + void call(T t) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/ForeachPartitionFunction.java b/core/src/main/java/org/apache/spark/api/java/function/ForeachPartitionFunction.java new file mode 100644 index 000000000000..4938a51bcd71 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/ForeachPartitionFunction.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; +import java.util.Iterator; + +/** + * Base interface for a function used in Dataset's foreachPartition function. + */ +public interface ForeachPartitionFunction extends Serializable { + void call(Iterator t) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function.java b/core/src/main/java/org/apache/spark/api/java/function/Function.java index d00551bb0add..b9d9777a7565 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/Function.java +++ b/core/src/main/java/org/apache/spark/api/java/function/Function.java @@ -25,5 +25,5 @@ * when mapping RDDs of other types. */ public interface Function extends Serializable { - public R call(T1 v1) throws Exception; + R call(T1 v1) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function0.java b/core/src/main/java/org/apache/spark/api/java/function/Function0.java index 38e410c5debe..c86928dd0540 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/Function0.java +++ b/core/src/main/java/org/apache/spark/api/java/function/Function0.java @@ -23,5 +23,5 @@ * A zero-argument function that returns an R. 
*/ public interface Function0 extends Serializable { - public R call() throws Exception; + R call() throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function4.java b/core/src/main/java/org/apache/spark/api/java/function/Function4.java new file mode 100644 index 000000000000..9c35a22ca9d0 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/Function4.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; + +/** + * A four-argument function that takes arguments of type T1, T2, T3 and T4 and returns an R. + */ +public interface Function4 extends Serializable { + R call(T1 v1, T2 v2, T3 v3, T4 v4) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/MapFunction.java b/core/src/main/java/org/apache/spark/api/java/function/MapFunction.java new file mode 100644 index 000000000000..3ae6ef44898e --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/MapFunction.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; + +/** + * Base interface for a map function used in Dataset's map function. + */ +public interface MapFunction extends Serializable { + U call(T value) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/MapGroupsFunction.java b/core/src/main/java/org/apache/spark/api/java/function/MapGroupsFunction.java new file mode 100644 index 000000000000..faa59eabc8b4 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/MapGroupsFunction.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; +import java.util.Iterator; + +/** + * Base interface for a map function used in GroupedDataset's mapGroup function. + */ +public interface MapGroupsFunction extends Serializable { + R call(K key, Iterator values) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java b/core/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java new file mode 100644 index 000000000000..6cb569ce0cb6 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; +import java.util.Iterator; + +/** + * Base interface for function used in Dataset's mapPartitions. + */ +public interface MapPartitionsFunction extends Serializable { + Iterable call(Iterator input) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/ReduceFunction.java b/core/src/main/java/org/apache/spark/api/java/function/ReduceFunction.java new file mode 100644 index 000000000000..ee092d0058f4 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/ReduceFunction.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; + +/** + * Base interface for function used in Dataset's reduce. + */ +public interface ReduceFunction extends Serializable { + T call(T v1, T v2) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/VoidFunction.java b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction.java index 2a10435b7523..f30d42ee5796 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/VoidFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction.java @@ -23,5 +23,5 @@ * A function with no return value. */ public interface VoidFunction extends Serializable { - public void call(T t) throws Exception; + void call(T t) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java new file mode 100644 index 000000000000..da9ae1c9c5cd --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; + +/** + * A two-argument function that takes arguments of type T1 and T2 with no return value. + */ +public interface VoidFunction2 extends Serializable { + void call(T1 v1, T2 v2) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java new file mode 100644 index 000000000000..36138cc9a297 --- /dev/null +++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.memory; + +import java.io.IOException; + +import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.memory.MemoryBlock; + +/** + * An memory consumer of TaskMemoryManager, which support spilling. + * + * Note: this only supports allocation / spilling of Tungsten memory. + */ +public abstract class MemoryConsumer { + + protected final TaskMemoryManager taskMemoryManager; + private final long pageSize; + protected long used; + + protected MemoryConsumer(TaskMemoryManager taskMemoryManager, long pageSize) { + this.taskMemoryManager = taskMemoryManager; + this.pageSize = pageSize; + } + + protected MemoryConsumer(TaskMemoryManager taskMemoryManager) { + this(taskMemoryManager, taskMemoryManager.pageSizeBytes()); + } + + /** + * Returns the size of used memory in bytes. + */ + long getUsed() { + return used; + } + + /** + * Force spill during building. + * + * For testing. + */ + public void spill() throws IOException { + spill(Long.MAX_VALUE, this); + } + + /** + * Spill some data to disk to release memory, which will be called by TaskMemoryManager + * when there is not enough memory for the task. + * + * This should be implemented by subclass. + * + * Note: In order to avoid possible deadlock, should not call acquireMemory() from spill(). + * + * Note: today, this only frees Tungsten-managed pages. + * + * @param size the amount of memory should be released + * @param trigger the MemoryConsumer that trigger this spilling + * @return the amount of released memory in bytes + * @throws IOException + */ + public abstract long spill(long size, MemoryConsumer trigger) throws IOException; + + /** + * Allocates a LongArray of `size`. + */ + public LongArray allocateArray(long size) { + long required = size * 8L; + MemoryBlock page = taskMemoryManager.allocatePage(required, this); + if (page == null || page.size() < required) { + long got = 0; + if (page != null) { + got = page.size(); + taskMemoryManager.freePage(page, this); + } + taskMemoryManager.showMemoryUsage(); + throw new OutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got); + } + used += required; + return new LongArray(page); + } + + /** + * Frees a LongArray. + */ + public void freeArray(LongArray array) { + freePage(array.memoryBlock()); + } + + /** + * Allocate a memory block with at least `required` bytes. + * + * Throws IOException if there is not enough memory. + * + * @throws OutOfMemoryError + */ + protected MemoryBlock allocatePage(long required) { + MemoryBlock page = taskMemoryManager.allocatePage(Math.max(pageSize, required), this); + if (page == null || page.size() < required) { + long got = 0; + if (page != null) { + got = page.size(); + taskMemoryManager.freePage(page, this); + } + taskMemoryManager.showMemoryUsage(); + throw new OutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got); + } + used += page.size(); + return page; + } + + /** + * Free a memory block. + */ + protected void freePage(MemoryBlock page) { + used -= page.size(); + taskMemoryManager.freePage(page, this); + } +} diff --git a/core/src/main/java/org/apache/spark/memory/MemoryMode.java b/core/src/main/java/org/apache/spark/memory/MemoryMode.java new file mode 100644 index 000000000000..3a5e72d8aaec --- /dev/null +++ b/core/src/main/java/org/apache/spark/memory/MemoryMode.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory; + +import org.apache.spark.annotation.Private; + +@Private +public enum MemoryMode { + ON_HEAP, + OFF_HEAP +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java similarity index 54% rename from unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java rename to core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index 358bb3725015..d31eb449eb82 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -15,14 +15,21 @@ * limitations under the License. */ -package org.apache.spark.unsafe.memory; +package org.apache.spark.memory; -import java.util.*; +import javax.annotation.concurrent.GuardedBy; +import java.io.IOException; +import java.util.Arrays; +import java.util.BitSet; +import java.util.HashSet; import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.util.Utils; + /** * Manages the memory allocated by an individual task. *
<p>
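A minimal sketch of how the MemoryConsumer/TaskMemoryManager contract introduced in this patch is meant to be used. The BufferingConsumer class, its reserve() helper, and the policy of simply freeing whole pages on spill are hypothetical and only illustrate the allocatePage()/freePage()/spill() cycle; only the MemoryConsumer and TaskMemoryManager APIs shown above are from the patch.

import java.io.IOException;
import java.util.LinkedList;

import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.unsafe.memory.MemoryBlock;

final class BufferingConsumer extends MemoryConsumer {

  // Pages currently held by this consumer; the inherited `used` counter is kept in sync
  // by allocatePage()/freePage() in the base class.
  private final LinkedList<MemoryBlock> pages = new LinkedList<>();

  BufferingConsumer(TaskMemoryManager taskMemoryManager) {
    super(taskMemoryManager);
  }

  void reserve(long bytes) {
    // May cause the TaskMemoryManager to call spill() on other consumers, or on this one,
    // before the page is granted.
    pages.add(allocatePage(bytes));
  }

  @Override
  public long spill(long size, MemoryConsumer trigger) throws IOException {
    // A real consumer would persist the page contents to disk before freeing them;
    // here we just release whole pages until the requested amount has been freed.
    long freed = 0L;
    while (freed < size && !pages.isEmpty()) {
      MemoryBlock page = pages.removeFirst();
      freed += page.size();
      freePage(page);
    }
    return freed;
  }
}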
@@ -60,7 +67,7 @@ public class TaskMemoryManager { /** * Maximum supported data page size (in bytes). In principle, the maximum addressable page size is - * (1L << OFFSET_BITS) bytes, which is 2+ petabytes. However, the on-heap allocator's maximum page + * (1L << OFFSET_BITS) bytes, which is 2+ petabytes. However, the on-heap allocator's maximum page * size is limited by the maximum amount of data that can be stored in a long[] array, which is * (2^32 - 1) * 8 bytes (or 16 gigabytes). Therefore, we cap this at 16 gigabytes. */ @@ -87,111 +94,208 @@ public class TaskMemoryManager { */ private final BitSet allocatedPages = new BitSet(PAGE_TABLE_SIZE); - /** - * Tracks memory allocated with {@link TaskMemoryManager#allocate(long)}, used to detect / clean - * up leaked memory. - */ - private final HashSet allocatedNonPageMemory = new HashSet(); + private final MemoryManager memoryManager; - private final ExecutorMemoryManager executorMemoryManager; + private final long taskAttemptId; /** * Tracks whether we're in-heap or off-heap. For off-heap, we short-circuit most of these methods * without doing any masking or lookups. Since this branching should be well-predicted by the JIT, * this extra layer of indirection / abstraction hopefully shouldn't be too expensive. */ - private final boolean inHeap; + final MemoryMode tungstenMemoryMode; + + /** + * Tracks spillable memory consumers. + */ + @GuardedBy("this") + private final HashSet consumers; + + /** + * Construct a new TaskMemoryManager. + */ + public TaskMemoryManager(MemoryManager memoryManager, long taskAttemptId) { + this.tungstenMemoryMode = memoryManager.tungstenMemoryMode(); + this.memoryManager = memoryManager; + this.taskAttemptId = taskAttemptId; + this.consumers = new HashSet<>(); + } + + /** + * Acquire N bytes of memory for a consumer. If there is no enough memory, it will call + * spill() of consumers to release more memory. + * + * @return number of bytes successfully granted (<= N). + */ + public long acquireExecutionMemory( + long required, + MemoryMode mode, + MemoryConsumer consumer) { + assert(required >= 0); + // If we are allocating Tungsten pages off-heap and receive a request to allocate on-heap + // memory here, then it may not make sense to spill since that would only end up freeing + // off-heap memory. This is subject to change, though, so it may be risky to make this + // optimization now in case we forget to undo it late when making changes. + synchronized (this) { + long got = memoryManager.acquireExecutionMemory(required, taskAttemptId, mode); + + // Try to release memory from other consumers first, then we can reduce the frequency of + // spilling, avoid to have too many spilled files. 
+ if (got < required) { + // Call spill() on other consumers to release memory + for (MemoryConsumer c: consumers) { + if (c != consumer && c.getUsed() > 0) { + try { + long released = c.spill(required - got, consumer); + if (released > 0 && mode == tungstenMemoryMode) { + logger.debug("Task {} released {} from {} for {}", taskAttemptId, + Utils.bytesToString(released), c, consumer); + got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId, mode); + if (got >= required) { + break; + } + } + } catch (IOException e) { + logger.error("error while calling spill() on " + c, e); + throw new OutOfMemoryError("error while calling spill() on " + c + " : " + + e.getMessage()); + } + } + } + } + + // call spill() on itself + if (got < required && consumer != null) { + try { + long released = consumer.spill(required - got, consumer); + if (released > 0 && mode == tungstenMemoryMode) { + logger.debug("Task {} released {} from itself ({})", taskAttemptId, + Utils.bytesToString(released), consumer); + got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId, mode); + } + } catch (IOException e) { + logger.error("error while calling spill() on " + consumer, e); + throw new OutOfMemoryError("error while calling spill() on " + consumer + " : " + + e.getMessage()); + } + } + + if (consumer != null) { + consumers.add(consumer); + } + logger.debug("Task {} acquire {} for {}", taskAttemptId, Utils.bytesToString(got), consumer); + return got; + } + } + + /** + * Release N bytes of execution memory for a MemoryConsumer. + */ + public void releaseExecutionMemory(long size, MemoryMode mode, MemoryConsumer consumer) { + logger.debug("Task {} release {} from {}", taskAttemptId, Utils.bytesToString(size), consumer); + memoryManager.releaseExecutionMemory(size, taskAttemptId, mode); + } /** - * Construct a new MemoryManager. + * Dump the memory usage of all consumers. */ - public TaskMemoryManager(ExecutorMemoryManager executorMemoryManager) { - this.inHeap = executorMemoryManager.inHeap; - this.executorMemoryManager = executorMemoryManager; + public void showMemoryUsage() { + logger.info("Memory used in task " + taskAttemptId); + synchronized (this) { + long memoryAccountedForByConsumers = 0; + for (MemoryConsumer c: consumers) { + long totalMemUsage = c.getUsed(); + memoryAccountedForByConsumers += totalMemUsage; + if (totalMemUsage > 0) { + logger.info("Acquired by " + c + ": " + Utils.bytesToString(totalMemUsage)); + } + } + long memoryNotAccountedFor = + memoryManager.getExecutionMemoryUsageForTask(taskAttemptId) - memoryAccountedForByConsumers; + logger.info( + "{} bytes of memory were used by task {} but are not associated with specific consumers", + memoryNotAccountedFor, taskAttemptId); + logger.info( + "{} bytes of memory are used for execution and {} bytes of memory are used for storage", + memoryManager.executionMemoryUsed(), memoryManager.storageMemoryUsed()); + } + } + + /** + * Return the page size in bytes. + */ + public long pageSizeBytes() { + return memoryManager.pageSizeBytes(); } /** * Allocate a block of memory that will be tracked in the MemoryManager's page table; this is - * intended for allocating large blocks of memory that will be shared between operators. + * intended for allocating large blocks of Tungsten memory that will be shared between operators. + * + * Returns `null` if there was not enough memory to allocate the page. May return a page that + * contains fewer bytes than requested, so callers should verify the size of returned pages. 
*/ - public MemoryBlock allocatePage(long size) { + public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { if (size > MAXIMUM_PAGE_SIZE_BYTES) { throw new IllegalArgumentException( "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE_BYTES + " bytes"); } + long acquired = acquireExecutionMemory(size, tungstenMemoryMode, consumer); + if (acquired <= 0) { + return null; + } + final int pageNumber; synchronized (this) { pageNumber = allocatedPages.nextClearBit(0); if (pageNumber >= PAGE_TABLE_SIZE) { + releaseExecutionMemory(acquired, tungstenMemoryMode, consumer); throw new IllegalStateException( "Have already allocated a maximum of " + PAGE_TABLE_SIZE + " pages"); } allocatedPages.set(pageNumber); } - final MemoryBlock page = executorMemoryManager.allocate(size); + final MemoryBlock page = memoryManager.tungstenMemoryAllocator().allocate(acquired); page.pageNumber = pageNumber; pageTable[pageNumber] = page; if (logger.isTraceEnabled()) { - logger.trace("Allocate page number {} ({} bytes)", pageNumber, size); + logger.trace("Allocate page number {} ({} bytes)", pageNumber, acquired); } return page; } /** - * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage(long)}. + * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage}. */ - public void freePage(MemoryBlock page) { + public void freePage(MemoryBlock page, MemoryConsumer consumer) { assert (page.pageNumber != -1) : "Called freePage() on memory that wasn't allocated with allocatePage()"; - executorMemoryManager.free(page); + assert(allocatedPages.get(page.pageNumber)); + pageTable[page.pageNumber] = null; synchronized (this) { allocatedPages.clear(page.pageNumber); } - pageTable[page.pageNumber] = null; if (logger.isTraceEnabled()) { logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size()); } - } - - /** - * Allocates a contiguous block of memory. Note that the allocated memory is not guaranteed - * to be zeroed out (call `zero()` on the result if this is necessary). This method is intended - * to be used for allocating operators' internal data structures. For data pages that you want to - * exchange between operators, consider using {@link TaskMemoryManager#allocatePage(long)}, since - * that will enable intra-memory pointers (see - * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} and this class's - * top-level Javadoc for more details). - */ - public MemoryBlock allocate(long size) throws OutOfMemoryError { - assert(size > 0) : "Size must be positive, but got " + size; - final MemoryBlock memory = executorMemoryManager.allocate(size); - allocatedNonPageMemory.add(memory); - return memory; - } - - /** - * Free memory allocated by {@link TaskMemoryManager#allocate(long)}. - */ - public void free(MemoryBlock memory) { - assert (memory.pageNumber == -1) : "Should call freePage() for pages, not free()"; - executorMemoryManager.free(memory); - final boolean wasAlreadyRemoved = !allocatedNonPageMemory.remove(memory); - assert (!wasAlreadyRemoved) : "Called free() on memory that was already freed!"; + long pageSize = page.size(); + memoryManager.tungstenMemoryAllocator().free(page); + releaseExecutionMemory(pageSize, tungstenMemoryMode, consumer); } /** * Given a memory page and offset within that page, encode this address into a 64-bit long. * This address will remain valid as long as the corresponding page has not been freed. * - * @param page a data page allocated by {@link TaskMemoryManager#allocate(long)}. 
+ * @param page a data page allocated by {@link TaskMemoryManager#allocatePage}/ * @param offsetInPage an offset in this page which incorporates the base offset. In other words, * this should be the value that you would pass as the base offset into an * UNSAFE call (e.g. page.baseOffset() + something). * @return an encoded page address. */ public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) { - if (!inHeap) { + if (tungstenMemoryMode == MemoryMode.OFF_HEAP) { // In off-heap mode, an offset is an absolute address that may require a full 64 bits to // encode. Due to our page size limitation, though, we can convert this into an offset that's // relative to the page's base offset; this relative offset will fit in 51 bits. @@ -220,12 +324,13 @@ private static long decodeOffset(long pagePlusOffsetAddress) { * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} */ public Object getPage(long pagePlusOffsetAddress) { - if (inHeap) { + if (tungstenMemoryMode == MemoryMode.ON_HEAP) { final int pageNumber = decodePageNumber(pagePlusOffsetAddress); assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); - final Object page = pageTable[pageNumber].getBaseObject(); + final MemoryBlock page = pageTable[pageNumber]; assert (page != null); - return page; + assert (page.getBaseObject() != null); + return page.getBaseObject(); } else { return null; } @@ -237,14 +342,16 @@ public Object getPage(long pagePlusOffsetAddress) { */ public long getOffsetInPage(long pagePlusOffsetAddress) { final long offsetInPage = decodeOffset(pagePlusOffsetAddress); - if (inHeap) { + if (tungstenMemoryMode == MemoryMode.ON_HEAP) { return offsetInPage; } else { // In off-heap mode, an offset is an absolute address. In encodePageNumberAndOffset, we // converted the absolute address into a relative address. Here, we invert that operation: final int pageNumber = decodePageNumber(pagePlusOffsetAddress); assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); - return pageTable[pageNumber].getBaseOffset() + offsetInPage; + final MemoryBlock page = pageTable[pageNumber]; + assert (page != null); + return page.getBaseOffset() + offsetInPage; } } @@ -253,22 +360,31 @@ public long getOffsetInPage(long pagePlusOffsetAddress) { * value can be used to detect memory leaks. */ public long cleanUpAllAllocatedMemory() { - long freedBytes = 0; + synchronized (this) { + Arrays.fill(pageTable, null); + for (MemoryConsumer c: consumers) { + if (c != null && c.getUsed() > 0) { + // In case of failed task, it's normal to see leaked memory + logger.warn("leak " + Utils.bytesToString(c.getUsed()) + " memory from " + c); + } + } + consumers.clear(); + } + for (MemoryBlock page : pageTable) { if (page != null) { - freedBytes += page.size(); - freePage(page); + memoryManager.tungstenMemoryAllocator().free(page); } } - final Iterator iter = allocatedNonPageMemory.iterator(); - while (iter.hasNext()) { - final MemoryBlock memory = iter.next(); - freedBytes += memory.size(); - // We don't call free() here because that calls Set.remove, which would lead to a - // ConcurrentModificationException here. - executorMemoryManager.free(memory); - iter.remove(); - } - return freedBytes; + Arrays.fill(pageTable, null); + + return memoryManager.releaseAllExecutionMemoryForTask(taskAttemptId); + } + + /** + * Returns the memory consumption, in bytes, for the current task. 
+ */ + public long getMemoryConsumptionForThisTask() { + return memoryManager.getExecutionMemoryUsageForTask(taskAttemptId); } } diff --git a/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java b/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java index 0399abc63c23..0e58bb4f7101 100644 --- a/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java +++ b/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java @@ -25,7 +25,7 @@ import scala.reflect.ClassTag; import org.apache.spark.annotation.Private; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; /** * Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter. @@ -49,7 +49,7 @@ public void flush() { try { s.flush(); } catch (IOException e) { - PlatformDependent.throwException(e); + Platform.throwException(e); } } @@ -64,7 +64,7 @@ public void close() { try { s.close(); } catch (IOException e) { - PlatformDependent.throwException(e); + Platform.throwException(e); } } }; diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 0b8b604e1849..a1a1fb01426a 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -21,21 +21,30 @@ import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; +import javax.annotation.Nullable; +import scala.None$; +import scala.Option; import scala.Product2; import scala.Tuple2; import scala.collection.Iterator; +import com.google.common.annotations.VisibleForTesting; import com.google.common.io.Closeables; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.Partitioner; +import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.scheduler.MapStatus; +import org.apache.spark.scheduler.MapStatus$; import org.apache.spark.serializer.Serializer; import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.shuffle.ShuffleWriter; import org.apache.spark.storage.*; import org.apache.spark.util.Utils; @@ -62,7 +71,7 @@ *
<p>
* There have been proposals to completely remove this code path; see SPARK-6026 for details. */ -final class BypassMergeSortShuffleWriter implements SortShuffleFileWriter { +final class BypassMergeSortShuffleWriter extends ShuffleWriter { private final Logger logger = LoggerFactory.getLogger(BypassMergeSortShuffleWriter.class); @@ -72,31 +81,52 @@ final class BypassMergeSortShuffleWriter implements SortShuffleFileWriter< private final BlockManager blockManager; private final Partitioner partitioner; private final ShuffleWriteMetrics writeMetrics; + private final int shuffleId; + private final int mapId; private final Serializer serializer; + private final IndexShuffleBlockResolver shuffleBlockResolver; /** Array of file writers, one for each partition */ private DiskBlockObjectWriter[] partitionWriters; + @Nullable private MapStatus mapStatus; + private long[] partitionLengths; + + /** + * Are we in the process of stopping? Because map tasks can call stop() with success = true + * and then call stop() with success = false if they get an exception, we want to make sure + * we don't try deleting files, etc twice. + */ + private boolean stopping = false; public BypassMergeSortShuffleWriter( - SparkConf conf, BlockManager blockManager, - Partitioner partitioner, - ShuffleWriteMetrics writeMetrics, - Serializer serializer) { + IndexShuffleBlockResolver shuffleBlockResolver, + BypassMergeSortShuffleHandle handle, + int mapId, + TaskContext taskContext, + SparkConf conf) { // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true); - this.numPartitions = partitioner.numPartitions(); this.blockManager = blockManager; - this.partitioner = partitioner; - this.writeMetrics = writeMetrics; - this.serializer = serializer; + final ShuffleDependency dep = handle.dependency(); + this.mapId = mapId; + this.shuffleId = dep.shuffleId(); + this.partitioner = dep.partitioner(); + this.numPartitions = partitioner.numPartitions(); + this.writeMetrics = new ShuffleWriteMetrics(); + taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics)); + this.serializer = Serializer.getSerializer(dep.serializer()); + this.shuffleBlockResolver = shuffleBlockResolver; } @Override - public void insertAll(Iterator> records) throws IOException { + public void write(Iterator> records) throws IOException { assert (partitionWriters == null); if (!records.hasNext()) { + partitionLengths = new long[numPartitions]; + shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); return; } final SerializerInstance serInstance = serializer.newInstance(); @@ -124,13 +154,25 @@ public void insertAll(Iterator> records) throws IOException { for (DiskBlockObjectWriter writer : partitionWriters) { writer.commitAndClose(); } + + File output = shuffleBlockResolver.getDataFile(shuffleId, mapId); + File tmp = Utils.tempFileWith(output); + partitionLengths = writePartitionedFile(tmp); + shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); } - @Override - public long[] writePartitionedFile( - BlockId blockId, - TaskContext context, - File outputFile) throws IOException { + @VisibleForTesting + long[] 
getPartitionLengths() { + return partitionLengths; + } + + /** + * Concatenate all of the per-partition files into a single combined file. + * + * @return array of lengths, in bytes, of each partition of the file (used by map output tracker). + */ + private long[] writePartitionedFile(File outputFile) throws IOException { // Track location of the partition starts in the output file final long[] lengths = new long[numPartitions]; if (partitionWriters == null) { @@ -151,7 +193,7 @@ public long[] writePartitionedFile( } finally { Closeables.close(in, copyThrewException); } - if (!blockManager.diskBlockManager().getFile(partitionWriters[i].blockId()).delete()) { + if (!partitionWriters[i].fileSegment().file().delete()) { logger.error("Unable to delete file for partition {}", i); } } @@ -165,19 +207,33 @@ public long[] writePartitionedFile( } @Override - public void stop() throws IOException { - if (partitionWriters != null) { - try { - final DiskBlockManager diskBlockManager = blockManager.diskBlockManager(); - for (DiskBlockObjectWriter writer : partitionWriters) { - // This method explicitly does _not_ throw exceptions: - writer.revertPartialWritesAndClose(); - if (!diskBlockManager.getFile(writer.blockId()).delete()) { - logger.error("Error while deleting file for block {}", writer.blockId()); + public Option stop(boolean success) { + if (stopping) { + return None$.empty(); + } else { + stopping = true; + if (success) { + if (mapStatus == null) { + throw new IllegalStateException("Cannot call stop(true) without having called write()"); + } + return Option.apply(mapStatus); + } else { + // The map task failed, so delete our output data. + if (partitionWriters != null) { + try { + for (DiskBlockObjectWriter writer : partitionWriters) { + // This method explicitly does _not_ throw exceptions: + File file = writer.revertPartialWritesAndClose(); + if (!file.delete()) { + logger.error("Error while deleting file {}", file.getAbsolutePath()); + } + } + } finally { + partitionWriters = null; } } - } finally { - partitionWriters = null; + shuffleBlockResolver.removeDataByMap(shuffleId, mapId); + return None$.empty(); } } } diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java b/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java similarity index 96% rename from core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java rename to core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java index 4ee6a82c0423..f8f2b220e181 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java @@ -15,7 +15,9 @@ * limitations under the License. */ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; + +import org.apache.spark.memory.TaskMemoryManager; /** * Wrapper around an 8-byte word that holds a 24-bit partition number and 40-bit record pointer. @@ -26,7 +28,7 @@ * * This implies that the maximum addressable page size is 2^27 bits = 128 megabytes, assuming that * our offsets in pages are not 8-byte-word-aligned. Since we have 2^13 pages (based off the - * 13-bit page numbers assigned by {@link org.apache.spark.unsafe.memory.TaskMemoryManager}), this + * 13-bit page numbers assigned by {@link TaskMemoryManager}), this * implies that we can address 2^13 * 128 megabytes = 1 terabyte of RAM per task. *
<p>
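* For example (an illustrative packing, not part of the original comment, assuming the partition
* number occupies the upper 24 bits): partition 5, page 3 and byte offset 16 would be stored as
* the single word {@code (5L << 40) | (3L << 27) | 16L}, with the partition read back via
* {@code (int) (word >>> 40)} and the low 40 bits forming the compressed record pointer.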
* Assuming word-alignment would allow for a 1 gigabyte maximum page size, but we leave this diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java similarity index 62% rename from core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java rename to core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index 1aa6ba420126..9affff80143d 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -15,8 +15,9 @@ * limitations under the License. */ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; +import javax.annotation.Nullable; import java.io.File; import java.io.IOException; import java.util.LinkedList; @@ -30,13 +31,16 @@ import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.memory.MemoryConsumer; +import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.serializer.DummySerializerInstance; import org.apache.spark.serializer.SerializerInstance; -import org.apache.spark.shuffle.ShuffleMemoryManager; -import org.apache.spark.storage.*; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.storage.DiskBlockObjectWriter; +import org.apache.spark.storage.TempShuffleBlockId; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; -import org.apache.spark.unsafe.memory.TaskMemoryManager; import org.apache.spark.util.Utils; /** @@ -44,7 +48,7 @@ *
<p>
* Incoming records are appended to data pages. When all records have been inserted (or when the * current thread's shuffle memory limit is reached), the in-memory records are sorted according to - * their partition ids (using a {@link UnsafeShuffleInMemorySorter}). The sorted records are then + * their partition ids (using a {@link ShuffleInMemorySorter}). The sorted records are then * written to a single output file (or multiple files, if we've spilled). The format of the output * files is the same as the format of the final output file written by * {@link org.apache.spark.shuffle.sort.SortShuffleWriter}: each output partition's records are @@ -55,24 +59,22 @@ * spill files. Instead, this merging is performed in {@link UnsafeShuffleWriter}, which uses a * specialized merge procedure that avoids extra serialization/deserialization. */ -final class UnsafeShuffleExternalSorter { +final class ShuffleExternalSorter extends MemoryConsumer { - private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class); + private final Logger logger = LoggerFactory.getLogger(ShuffleExternalSorter.class); @VisibleForTesting static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024; - private final int initialSize; private final int numPartitions; - private final int pageSizeBytes; - @VisibleForTesting - final int maxRecordSizeBytes; - private final TaskMemoryManager memoryManager; - private final ShuffleMemoryManager shuffleMemoryManager; + private final TaskMemoryManager taskMemoryManager; private final BlockManager blockManager; private final TaskContext taskContext; private final ShuffleWriteMetrics writeMetrics; + /** Force this sorter to spill when there are this many elements in memory. For testing only */ + private final long numElementsForSpillThreshold; + /** The buffer size to use when writing spills using DiskBlockObjectWriter */ private final int fileBufferSizeBytes; @@ -86,50 +88,35 @@ final class UnsafeShuffleExternalSorter { private final LinkedList spills = new LinkedList(); + /** Peak memory used by this sorter so far, in bytes. 
**/ + private long peakMemoryUsedBytes; + // These variables are reset after spilling: - private UnsafeShuffleInMemorySorter sorter; - private MemoryBlock currentPage = null; - private long currentPagePosition = -1; - private long freeSpaceInCurrentPage = 0; + @Nullable private ShuffleInMemorySorter inMemSorter; + @Nullable private MemoryBlock currentPage = null; + private long pageCursor = -1; - public UnsafeShuffleExternalSorter( + public ShuffleExternalSorter( TaskMemoryManager memoryManager, - ShuffleMemoryManager shuffleMemoryManager, BlockManager blockManager, TaskContext taskContext, int initialSize, int numPartitions, SparkConf conf, - ShuffleWriteMetrics writeMetrics) throws IOException { - this.memoryManager = memoryManager; - this.shuffleMemoryManager = shuffleMemoryManager; + ShuffleWriteMetrics writeMetrics) { + super(memoryManager, (int) Math.min(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, + memoryManager.pageSizeBytes())); + this.taskMemoryManager = memoryManager; this.blockManager = blockManager; this.taskContext = taskContext; - this.initialSize = initialSize; this.numPartitions = numPartitions; // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; - this.pageSizeBytes = (int) Math.min( - PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, - conf.getSizeAsBytes("spark.buffer.pageSize", "64m")); - this.maxRecordSizeBytes = pageSizeBytes - 4; + this.numElementsForSpillThreshold = + conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MAX_VALUE); this.writeMetrics = writeMetrics; - initializeForWriting(); - } - - /** - * Allocates new sort data structures. Called when creating the sorter and after each spill. - */ - private void initializeForWriting() throws IOException { - // TODO: move this sizing calculation logic into a static method of sorter: - final long memoryRequested = initialSize * 8L; - final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested); - if (memoryAcquired != memoryRequested) { - shuffleMemoryManager.release(memoryAcquired); - throw new IOException("Could not acquire " + memoryRequested + " bytes of memory"); - } - - this.sorter = new UnsafeShuffleInMemorySorter(initialSize); + this.inMemSorter = new ShuffleInMemorySorter(this, initialSize); + this.peakMemoryUsedBytes = getMemoryUsage(); } /** @@ -155,8 +142,8 @@ private void writeSortedFile(boolean isLastFile) throws IOException { } // This call performs the actual sort. - final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator sortedRecords = - sorter.getSortedIterator(); + final ShuffleInMemorySorter.ShuffleSorterIterator sortedRecords = + inMemSorter.getSortedIterator(); // Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this // after SPARK-5581 is fixed. 
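A hedged sketch of how the force-spill knob read in the constructor above might be exercised in a test. The class name and the threshold value 100 are made up; only the spark.shuffle.spill.numElementsForceSpillThreshold and spark.shuffle.file.buffer keys come from this patch.

import org.apache.spark.SparkConf;

public final class ForceSpillConfSketch {
  public static void main(String[] args) {
    // Lower the force-spill threshold from its Long.MAX_VALUE default so the sorter's
    // spill()/merge path is exercised even with a handful of records.
    SparkConf conf = new SparkConf()
        .set("spark.shuffle.spill.numElementsForceSpillThreshold", "100")
        .set("spark.shuffle.file.buffer", "32k");  // same default the sorter reads via getSizeAsKb
    System.out.println(conf.toDebugString());
  }
}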
@@ -202,18 +189,14 @@ private void writeSortedFile(boolean isLastFile) throws IOException { } final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer(); - final Object recordPage = memoryManager.getPage(recordPointer); - final long recordOffsetInPage = memoryManager.getOffsetInPage(recordPointer); - int dataRemaining = PlatformDependent.UNSAFE.getInt(recordPage, recordOffsetInPage); + final Object recordPage = taskMemoryManager.getPage(recordPointer); + final long recordOffsetInPage = taskMemoryManager.getOffsetInPage(recordPointer); + int dataRemaining = Platform.getInt(recordPage, recordOffsetInPage); long recordReadPosition = recordOffsetInPage + 4; // skip over record length while (dataRemaining > 0) { final int toTransfer = Math.min(DISK_WRITE_BUFFER_SIZE, dataRemaining); - PlatformDependent.copyMemory( - recordPage, - recordReadPosition, - writeBuffer, - PlatformDependent.BYTE_ARRAY_OFFSET, - toTransfer); + Platform.copyMemory( + recordPage, recordReadPosition, writeBuffer, Platform.BYTE_ARRAY_OFFSET, toTransfer); writer.write(writeBuffer, 0, toTransfer); recordReadPosition += toTransfer; dataRemaining -= toTransfer; @@ -232,6 +215,8 @@ private void writeSortedFile(boolean isLastFile) throws IOException { } } + inMemSorter.reset(); + if (!isLastFile) { // i.e. this is a spill file // The current semantics of `shuffleRecordsWritten` seem to be that it's updated when records // are written to disk, not when they enter the shuffle sorting code. DiskBlockObjectWriter @@ -256,8 +241,12 @@ private void writeSortedFile(boolean isLastFile) throws IOException { /** * Sort and spill the current records in response to memory pressure. */ - @VisibleForTesting - void spill() throws IOException { + @Override + public long spill(long size, MemoryConsumer trigger) throws IOException { + if (trigger != this || inMemSorter == null || inMemSorter.numRecords() == 0) { + return 0L; + } + logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)", Thread.currentThread().getId(), Utils.bytesToString(getMemoryUsage()), @@ -265,13 +254,9 @@ void spill() throws IOException { spills.size() > 1 ? " times" : " time"); writeSortedFile(false); - final long sorterMemoryUsage = sorter.getMemoryUsage(); - sorter = null; - shuffleMemoryManager.release(sorterMemoryUsage); final long spillSize = freeMemory(); taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); - - initializeForWriting(); + return spillSize; } private long getMemoryUsage() { @@ -279,131 +264,124 @@ private long getMemoryUsage() { for (MemoryBlock page : allocatedPages) { totalPageSize += page.size(); } - return sorter.getMemoryUsage() + totalPageSize; + return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + totalPageSize; + } + + private void updatePeakMemoryUsed() { + long mem = getMemoryUsage(); + if (mem > peakMemoryUsedBytes) { + peakMemoryUsedBytes = mem; + } + } + + /** + * Return the peak memory used so far, in bytes. + */ + long getPeakMemoryUsedBytes() { + updatePeakMemoryUsed(); + return peakMemoryUsedBytes; } private long freeMemory() { + updatePeakMemoryUsed(); long memoryFreed = 0; for (MemoryBlock block : allocatedPages) { - memoryManager.freePage(block); - shuffleMemoryManager.release(block.size()); memoryFreed += block.size(); + freePage(block); } allocatedPages.clear(); currentPage = null; - currentPagePosition = -1; - freeSpaceInCurrentPage = 0; + pageCursor = 0; return memoryFreed; } /** * Force all memory and spill files to be deleted; called by shuffle error-handling code. 
*/ - public void cleanupAfterError() { + public void cleanupResources() { freeMemory(); + if (inMemSorter != null) { + inMemSorter.free(); + inMemSorter = null; + } for (SpillInfo spill : spills) { if (spill.file.exists() && !spill.file.delete()) { logger.error("Unable to delete spill file {}", spill.file.getPath()); } } - if (sorter != null) { - shuffleMemoryManager.release(sorter.getMemoryUsage()); - sorter = null; - } } /** - * Checks whether there is enough space to insert a new record into the sorter. - * - * @param requiredSpace the required space in the data page, in bytes, including space for storing - * the record size. - - * @return true if the record can be inserted without requiring more allocations, false otherwise. + * Checks whether there is enough space to insert an additional record in to the sort pointer + * array and grows the array if additional space is required. If the required space cannot be + * obtained, then the in-memory data will be spilled to disk. */ - private boolean haveSpaceForRecord(int requiredSpace) { - assert (requiredSpace > 0); - return (sorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage)); + private void growPointerArrayIfNecessary() throws IOException { + assert(inMemSorter != null); + if (!inMemSorter.hasSpaceForAnotherRecord()) { + long used = inMemSorter.getMemoryUsage(); + LongArray array; + try { + // could trigger spilling + array = allocateArray(used / 8 * 2); + } catch (OutOfMemoryError e) { + // should have trigger spilling + assert(inMemSorter.hasSpaceForAnotherRecord()); + return; + } + // check if spilling is triggered or not + if (inMemSorter.hasSpaceForAnotherRecord()) { + freeArray(array); + } else { + inMemSorter.expandPointerArray(array); + } + } } /** * Allocates more memory in order to insert an additional record. This will request additional - * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be - * obtained. + * memory from the memory manager and spill if the requested memory can not be obtained. * - * @param requiredSpace the required space in the data page, in bytes, including space for storing - * the record size. + * @param required the required space in the data page, in bytes, including space for storing + * the record size. This must be less than or equal to the page size (records + * that exceed the page size are handled via a different code path which uses + * special overflow pages). */ - private void allocateSpaceForRecord(int requiredSpace) throws IOException { - if (!sorter.hasSpaceForAnotherRecord()) { - logger.debug("Attempting to expand sort pointer array"); - final long oldPointerArrayMemoryUsage = sorter.getMemoryUsage(); - final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2; - final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray); - if (memoryAcquired < memoryToGrowPointerArray) { - shuffleMemoryManager.release(memoryAcquired); - spill(); - } else { - sorter.expandPointerArray(); - shuffleMemoryManager.release(oldPointerArrayMemoryUsage); - } - } - if (requiredSpace > freeSpaceInCurrentPage) { - logger.trace("Required space {} is less than free space in current page ({})", requiredSpace, - freeSpaceInCurrentPage); - // TODO: we should track metrics on the amount of space wasted when we roll over to a new page - // without using the free space at the end of the current page. We should also do this for - // BytesToBytesMap. 
- if (requiredSpace > pageSizeBytes) { - throw new IOException("Required space " + requiredSpace + " is greater than page size (" + - pageSizeBytes + ")"); - } else { - final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes); - if (memoryAcquired < pageSizeBytes) { - shuffleMemoryManager.release(memoryAcquired); - spill(); - final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes); - if (memoryAcquiredAfterSpilling != pageSizeBytes) { - shuffleMemoryManager.release(memoryAcquiredAfterSpilling); - throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory"); - } - } - currentPage = memoryManager.allocatePage(pageSizeBytes); - currentPagePosition = currentPage.getBaseOffset(); - freeSpaceInCurrentPage = pageSizeBytes; - allocatedPages.add(currentPage); - } + private void acquireNewPageIfNecessary(int required) { + if (currentPage == null || + pageCursor + required > currentPage.getBaseOffset() + currentPage.size() ) { + // TODO: try to find space in previous pages + currentPage = allocatePage(required); + pageCursor = currentPage.getBaseOffset(); + allocatedPages.add(currentPage); } } /** * Write a record to the shuffle sorter. */ - public void insertRecord( - Object recordBaseObject, - long recordBaseOffset, - int lengthInBytes, - int partitionId) throws IOException { - // Need 4 bytes to store the record length. - final int totalSpaceRequired = lengthInBytes + 4; - if (!haveSpaceForRecord(totalSpaceRequired)) { - allocateSpaceForRecord(totalSpaceRequired); + public void insertRecord(Object recordBase, long recordOffset, int length, int partitionId) + throws IOException { + + // for tests + assert(inMemSorter != null); + if (inMemSorter.numRecords() > numElementsForSpillThreshold) { + spill(); } - final long recordAddress = - memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition); - final Object dataPageBaseObject = currentPage.getBaseObject(); - PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, lengthInBytes); - currentPagePosition += 4; - freeSpaceInCurrentPage -= 4; - PlatformDependent.copyMemory( - recordBaseObject, - recordBaseOffset, - dataPageBaseObject, - currentPagePosition, - lengthInBytes); - currentPagePosition += lengthInBytes; - freeSpaceInCurrentPage -= lengthInBytes; - sorter.insertRecord(recordAddress, partitionId); + growPointerArrayIfNecessary(); + // Need 4 bytes to store the record length. + final int required = length + 4; + acquireNewPageIfNecessary(required); + + assert(currentPage != null); + final Object base = currentPage.getBaseObject(); + final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor); + Platform.putInt(base, pageCursor, length); + pageCursor += 4; + Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length); + pageCursor += length; + inMemSorter.insertRecord(recordAddress, partitionId); } /** @@ -415,14 +393,16 @@ public void insertRecord( */ public SpillInfo[] closeAndGetSpills() throws IOException { try { - if (sorter != null) { + if (inMemSorter != null) { // Do not count the final file towards the spill count. 
writeSortedFile(true); freeMemory(); + inMemSorter.free(); + inMemSorter = null; } return spills.toArray(new SpillInfo[spills.size()]); } catch (IOException e) { - cleanupAfterError(); + cleanupResources(); throw e; } } diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java similarity index 62% rename from core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java rename to core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java index 5bab501da936..58ad88e1ed87 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java @@ -15,15 +15,18 @@ * limitations under the License. */ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; import java.util.Comparator; +import org.apache.spark.memory.MemoryConsumer; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.util.collection.Sorter; -final class UnsafeShuffleInMemorySorter { +final class ShuffleInMemorySorter { - private final Sorter sorter; + private final Sorter sorter; private static final class SortComparator implements Comparator { @Override public int compare(PackedRecordPointer left, PackedRecordPointer right) { @@ -32,38 +35,61 @@ public int compare(PackedRecordPointer left, PackedRecordPointer right) { } private static final SortComparator SORT_COMPARATOR = new SortComparator(); + private final MemoryConsumer consumer; + /** * An array of record pointers and partition ids that have been encoded by * {@link PackedRecordPointer}. The sort operates on this array instead of directly manipulating * records. */ - private long[] pointerArray; + private LongArray array; /** * The position in the pointer array where new records can be inserted. */ - private int pointerArrayInsertPosition = 0; + private int pos = 0; - public UnsafeShuffleInMemorySorter(int initialSize) { + public ShuffleInMemorySorter(MemoryConsumer consumer, int initialSize) { + this.consumer = consumer; assert (initialSize > 0); - this.pointerArray = new long[initialSize]; - this.sorter = new Sorter(UnsafeShuffleSortDataFormat.INSTANCE); + this.array = consumer.allocateArray(initialSize); + this.sorter = new Sorter<>(ShuffleSortDataFormat.INSTANCE); + } + + public void free() { + if (array != null) { + consumer.freeArray(array); + array = null; + } + } + + public int numRecords() { + return pos; + } + + public void reset() { + pos = 0; } - public void expandPointerArray() { - final long[] oldArray = pointerArray; - // Guard against overflow: - final int newLength = oldArray.length * 2 > 0 ? 
(oldArray.length * 2) : Integer.MAX_VALUE; - pointerArray = new long[newLength]; - System.arraycopy(oldArray, 0, pointerArray, 0, oldArray.length); + public void expandPointerArray(LongArray newArray) { + assert(newArray.size() > array.size()); + Platform.copyMemory( + array.getBaseObject(), + array.getBaseOffset(), + newArray.getBaseObject(), + newArray.getBaseOffset(), + array.size() * 8L + ); + consumer.freeArray(array); + array = newArray; } public boolean hasSpaceForAnotherRecord() { - return pointerArrayInsertPosition + 1 < pointerArray.length; + return pos < array.size(); } public long getMemoryUsage() { - return pointerArray.length * 8L; + return array.size() * 8L; } /** @@ -78,28 +104,23 @@ public long getMemoryUsage() { */ public void insertRecord(long recordPointer, int partitionId) { if (!hasSpaceForAnotherRecord()) { - if (pointerArray.length == Integer.MAX_VALUE) { - throw new IllegalStateException("Sort pointer array has reached maximum size"); - } else { - expandPointerArray(); - } + expandPointerArray(consumer.allocateArray(array.size() * 2)); } - pointerArray[pointerArrayInsertPosition] = - PackedRecordPointer.packPointer(recordPointer, partitionId); - pointerArrayInsertPosition++; + array.set(pos, PackedRecordPointer.packPointer(recordPointer, partitionId)); + pos++; } /** * An iterator-like class that's used instead of Java's Iterator in order to facilitate inlining. */ - public static final class UnsafeShuffleSorterIterator { + public static final class ShuffleSorterIterator { - private final long[] pointerArray; + private final LongArray pointerArray; private final int numRecords; final PackedRecordPointer packedRecordPointer = new PackedRecordPointer(); private int position = 0; - public UnsafeShuffleSorterIterator(int numRecords, long[] pointerArray) { + public ShuffleSorterIterator(int numRecords, LongArray pointerArray) { this.numRecords = numRecords; this.pointerArray = pointerArray; } @@ -109,7 +130,7 @@ public boolean hasNext() { } public void loadNext() { - packedRecordPointer.set(pointerArray[position]); + packedRecordPointer.set(pointerArray.get(position)); position++; } } @@ -117,8 +138,8 @@ public void loadNext() { /** * Return an iterator over record pointers in sorted order. */ - public UnsafeShuffleSorterIterator getSortedIterator() { - sorter.sort(pointerArray, 0, pointerArrayInsertPosition, SORT_COMPARATOR); - return new UnsafeShuffleSorterIterator(pointerArrayInsertPosition, pointerArray); + public ShuffleSorterIterator getSortedIterator() { + sorter.sort(array, 0, pos, SORT_COMPARATOR); + return new ShuffleSorterIterator(pos, array); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java new file mode 100644 index 000000000000..8f4e3229976d --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort; + +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.util.collection.SortDataFormat; + +final class ShuffleSortDataFormat extends SortDataFormat { + + public static final ShuffleSortDataFormat INSTANCE = new ShuffleSortDataFormat(); + + private ShuffleSortDataFormat() { } + + @Override + public PackedRecordPointer getKey(LongArray data, int pos) { + // Since we re-use keys, this method shouldn't be called. + throw new UnsupportedOperationException(); + } + + @Override + public PackedRecordPointer newKey() { + return new PackedRecordPointer(); + } + + @Override + public PackedRecordPointer getKey(LongArray data, int pos, PackedRecordPointer reuse) { + reuse.set(data.get(pos)); + return reuse; + } + + @Override + public void swap(LongArray data, int pos0, int pos1) { + final long temp = data.get(pos0); + data.set(pos0, data.get(pos1)); + data.set(pos1, temp); + } + + @Override + public void copyElement(LongArray src, int srcPos, LongArray dst, int dstPos) { + dst.set(dstPos, src.get(srcPos)); + } + + @Override + public void copyRange(LongArray src, int srcPos, LongArray dst, int dstPos, int length) { + Platform.copyMemory( + src.getBaseObject(), + src.getBaseOffset() + srcPos * 8, + dst.getBaseObject(), + dst.getBaseOffset() + dstPos * 8, + length * 8 + ); + } + + @Override + public LongArray allocate(int length) { + // This buffer is used temporary (usually small), so it's fine to allocated from JVM heap. + return new LongArray(MemoryBlock.fromLongArray(new long[length])); + } + +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java deleted file mode 100644 index 656ea0401a14..000000000000 --- a/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.shuffle.sort; - -import java.io.File; -import java.io.IOException; - -import scala.Product2; -import scala.collection.Iterator; - -import org.apache.spark.annotation.Private; -import org.apache.spark.TaskContext; -import org.apache.spark.storage.BlockId; - -/** - * Interface for objects that {@link SortShuffleWriter} uses to write its output files. - */ -@Private -public interface SortShuffleFileWriter { - - void insertAll(Iterator> records) throws IOException; - - /** - * Write all the data added into this shuffle sorter into a file in the disk store. This is - * called by the SortShuffleWriter and can go through an efficient path of just concatenating - * binary files if we decided to avoid merge-sorting. - * - * @param blockId block ID to write to. The index file will be blockId.name + ".index". - * @param context a TaskContext for a running Spark task, for us to update shuffle metrics. - * @return array of lengths, in bytes, of each partition of the file (used by map output tracker) - */ - long[] writePartitionedFile( - BlockId blockId, - TaskContext context, - File outputFile) throws IOException; - - void stop() throws IOException; -} diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java b/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java similarity index 90% rename from core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java rename to core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java index 7bac0dc0bbeb..df9f7b7abe02 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java @@ -15,14 +15,14 @@ * limitations under the License. */ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; import java.io.File; import org.apache.spark.storage.TempShuffleBlockId; /** - * Metadata for a block of data written by {@link UnsafeShuffleExternalSorter}. + * Metadata for a block of data written by {@link ShuffleExternalSorter}. */ final class SpillInfo { final long[] partitionLengths; diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java similarity index 88% rename from core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java rename to core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index d47d6fc9c2ac..744c3008ca50 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -15,16 +15,17 @@ * limitations under the License. 
*/ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; +import javax.annotation.Nullable; import java.io.*; import java.nio.channels.FileChannel; import java.util.Iterator; -import javax.annotation.Nullable; import scala.Option; import scala.Product2; -import scala.collection.JavaConversions; +import scala.collection.JavaConverters; +import scala.collection.immutable.Map; import scala.reflect.ClassTag; import scala.reflect.ClassTag$; @@ -37,10 +38,10 @@ import org.apache.spark.*; import org.apache.spark.annotation.Private; +import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.io.CompressionCodec; import org.apache.spark.io.CompressionCodec$; -import org.apache.spark.io.LZFCompressionCodec; -import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.network.util.LimitedInputStream; import org.apache.spark.scheduler.MapStatus; import org.apache.spark.scheduler.MapStatus$; @@ -48,12 +49,11 @@ import org.apache.spark.serializer.Serializer; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.IndexShuffleBlockResolver; -import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.shuffle.ShuffleWriter; import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.TimeTrackingOutputStream; -import org.apache.spark.unsafe.PlatformDependent; -import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.util.Utils; @Private public class UnsafeShuffleWriter extends ShuffleWriter { @@ -68,7 +68,6 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private final BlockManager blockManager; private final IndexShuffleBlockResolver shuffleBlockResolver; private final TaskMemoryManager memoryManager; - private final ShuffleMemoryManager shuffleMemoryManager; private final SerializerInstance serializer; private final Partitioner partitioner; private final ShuffleWriteMetrics writeMetrics; @@ -78,8 +77,9 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private final SparkConf sparkConf; private final boolean transferToEnabled; - private MapStatus mapStatus = null; - private UnsafeShuffleExternalSorter sorter = null; + @Nullable private MapStatus mapStatus; + @Nullable private ShuffleExternalSorter sorter; + private long peakMemoryUsedBytes = 0; /** Subclass of ByteArrayOutputStream that exposes `buf` directly. 
*/ private static final class MyByteArrayOutputStream extends ByteArrayOutputStream { @@ -101,21 +101,19 @@ public UnsafeShuffleWriter( BlockManager blockManager, IndexShuffleBlockResolver shuffleBlockResolver, TaskMemoryManager memoryManager, - ShuffleMemoryManager shuffleMemoryManager, - UnsafeShuffleHandle handle, + SerializedShuffleHandle handle, int mapId, TaskContext taskContext, SparkConf sparkConf) throws IOException { final int numPartitions = handle.dependency().partitioner().numPartitions(); - if (numPartitions > UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS()) { + if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) { throw new IllegalArgumentException( "UnsafeShuffleWriter can only be used for shuffles with at most " + - UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS() + " reduce partitions"); + SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE() + " reduce partitions"); } this.blockManager = blockManager; this.shuffleBlockResolver = shuffleBlockResolver; this.memoryManager = memoryManager; - this.shuffleMemoryManager = shuffleMemoryManager; this.mapId = mapId; final ShuffleDependency dep = handle.dependency(); this.shuffleId = dep.shuffleId(); @@ -129,9 +127,22 @@ public UnsafeShuffleWriter( open(); } - @VisibleForTesting - public int maxRecordSizeBytes() { - return sorter.maxRecordSizeBytes; + private void updatePeakMemoryUsed() { + // sorter can be null if this writer is closed + if (sorter != null) { + long mem = sorter.getPeakMemoryUsedBytes(); + if (mem > peakMemoryUsedBytes) { + peakMemoryUsedBytes = mem; + } + } + } + + /** + * Return the peak memory used so far, in bytes. + */ + public long getPeakMemoryUsedBytes() { + updatePeakMemoryUsed(); + return peakMemoryUsedBytes; } /** @@ -139,12 +150,12 @@ public int maxRecordSizeBytes() { */ @VisibleForTesting public void write(Iterator> records) throws IOException { - write(JavaConversions.asScalaIterator(records)); + write(JavaConverters.asScalaIteratorConverter(records).asScala()); } @Override public void write(scala.collection.Iterator> records) throws IOException { - // Keep track of success so we know if we ecountered an exception + // Keep track of success so we know if we encountered an exception // We do this rather than a standard try/catch/re-throw to handle // generic throwables. boolean success = false; @@ -157,7 +168,7 @@ public void write(scala.collection.Iterator> records) throws IOEx } finally { if (sorter != null) { try { - sorter.cleanupAfterError(); + sorter.cleanupResources(); } catch (Exception e) { // Only throw this error if we won't be masking another // error. 
@@ -174,9 +185,8 @@ public void write(scala.collection.Iterator> records) throws IOEx private void open() throws IOException { assert (sorter == null); - sorter = new UnsafeShuffleExternalSorter( + sorter = new ShuffleExternalSorter( memoryManager, - shuffleMemoryManager, blockManager, taskContext, INITIAL_SORT_BUFFER_SIZE, @@ -189,13 +199,17 @@ private void open() throws IOException { @VisibleForTesting void closeAndWriteOutput() throws IOException { + assert(sorter != null); + updatePeakMemoryUsed(); serBuffer = null; serOutputStream = null; final SpillInfo[] spills = sorter.closeAndGetSpills(); sorter = null; final long[] partitionLengths; + final File output = shuffleBlockResolver.getDataFile(shuffleId, mapId); + final File tmp = Utils.tempFileWith(output); try { - partitionLengths = mergeSpills(spills); + partitionLengths = mergeSpills(spills, tmp); } finally { for (SpillInfo spill : spills) { if (spill.file.exists() && ! spill.file.delete()) { @@ -203,12 +217,13 @@ void closeAndWriteOutput() throws IOException { } } } - shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths); + shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); } @VisibleForTesting void insertRecordIntoSorter(Product2 record) throws IOException { + assert(sorter != null); final K key = record._1(); final int partitionId = partitioner.getPartition(key); serBuffer.reset(); @@ -220,7 +235,7 @@ void insertRecordIntoSorter(Product2 record) throws IOException { assert (serializedRecordSize > 0); sorter.insertRecord( - serBuffer.getBuf(), PlatformDependent.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId); + serBuffer.getBuf(), Platform.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId); } @VisibleForTesting @@ -235,14 +250,13 @@ void forceSorterToSpill() throws IOException { * * @return the partition lengths in the merged file. 
*/ - private long[] mergeSpills(SpillInfo[] spills) throws IOException { - final File outputFile = shuffleBlockResolver.getDataFile(shuffleId, mapId); + private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException { final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true); final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf); final boolean fastMergeEnabled = sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true); - final boolean fastMergeIsSupported = - !compressionEnabled || compressionCodec instanceof LZFCompressionCodec; + final boolean fastMergeIsSupported = !compressionEnabled || + CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec); try { if (spills.length == 0) { new FileOutputStream(outputFile).close(); // Create an empty file @@ -431,6 +445,14 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th @Override public Option stop(boolean success) { try { + // Update task metrics from accumulators (null in UnsafeShuffleWriterSuite) + Map> internalAccumulators = + taskContext.internalMetricsToAccumulators(); + if (internalAccumulators != null) { + internalAccumulators.apply(InternalAccumulator.PEAK_EXECUTION_MEMORY()) + .add(getPeakMemoryUsedBytes()); + } + if (stopping) { return Option.apply(null); } else { @@ -450,7 +472,7 @@ public Option stop(boolean success) { if (sorter != null) { // If sorter is non-null, then this implies that we called stop() in response to an error, // so we need to clean up memory and spill files created by the sorter - sorter.cleanupAfterError(); + sorter.cleanupResources(); } } } diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java deleted file mode 100644 index a66d74ee4478..000000000000 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.unsafe; - -import org.apache.spark.util.collection.SortDataFormat; - -final class UnsafeShuffleSortDataFormat extends SortDataFormat { - - public static final UnsafeShuffleSortDataFormat INSTANCE = new UnsafeShuffleSortDataFormat(); - - private UnsafeShuffleSortDataFormat() { } - - @Override - public PackedRecordPointer getKey(long[] data, int pos) { - // Since we re-use keys, this method shouldn't be called. 
- throw new UnsupportedOperationException(); - } - - @Override - public PackedRecordPointer newKey() { - return new PackedRecordPointer(); - } - - @Override - public PackedRecordPointer getKey(long[] data, int pos, PackedRecordPointer reuse) { - reuse.set(data[pos]); - return reuse; - } - - @Override - public void swap(long[] data, int pos0, int pos1) { - final long temp = data[pos0]; - data[pos0] = data[pos1]; - data[pos1] = temp; - } - - @Override - public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) { - dst[dstPos] = src[srcPos]; - } - - @Override - public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) { - System.arraycopy(src, srcPos, dst, dstPos, length); - } - - @Override - public long[] allocate(int length) { - return new long[length]; - } - -} diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 01a66084e918..3387f9a4177c 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -17,25 +17,30 @@ package org.apache.spark.unsafe.map; -import java.lang.Override; -import java.lang.UnsupportedOperationException; +import javax.annotation.Nullable; +import java.io.File; +import java.io.IOException; import java.util.Iterator; import java.util.LinkedList; -import java.util.List; - -import javax.annotation.Nullable; import com.google.common.annotations.VisibleForTesting; +import com.google.common.io.Closeables; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.apache.spark.shuffle.ShuffleMemoryManager; -import org.apache.spark.unsafe.*; +import org.apache.spark.SparkEnv; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.memory.MemoryConsumer; +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.array.LongArray; -import org.apache.spark.unsafe.bitset.BitSet; import org.apache.spark.unsafe.hash.Murmur3_x86_32; -import org.apache.spark.unsafe.memory.*; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.MemoryLocation; +import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillReader; +import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillWriter; /** * An append-only hash map where keys and values are contiguous regions of bytes. @@ -56,7 +61,7 @@ * is consistent with {@link org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter}, * so we can pass records from this map directly into the sorter to sort records in place. */ -public final class BytesToBytesMap { +public final class BytesToBytesMap extends MemoryConsumer { private final Logger logger = LoggerFactory.getLogger(BytesToBytesMap.class); @@ -64,38 +69,31 @@ public final class BytesToBytesMap { private static final HashMapGrowthStrategy growthStrategy = HashMapGrowthStrategy.DOUBLING; - /** - * Special record length that is placed after the last record in a data page. - */ - private static final int END_OF_PAGE_MARKER = -1; - private final TaskMemoryManager taskMemoryManager; - private final ShuffleMemoryManager shuffleMemoryManager; - /** * A linked list for tracking all allocated data pages so that we can free all of our memory. 
*/ - private final List dataPages = new LinkedList(); + private final LinkedList dataPages = new LinkedList<>(); /** * The data page that will be used to store keys and values for new hashtable entries. When this * page becomes full, a new page will be allocated and this pointer will change to point to that * new page. */ - private MemoryBlock currentDataPage = null; + private MemoryBlock currentPage = null; /** - * Offset into `currentDataPage` that points to the location where new data can be inserted into + * Offset into `currentPage` that points to the location where new data can be inserted into * the page. This does not incorporate the page's base offset. */ private long pageCursor = 0; /** * The maximum number of keys that BytesToBytesMap supports. The hash table has to be - * power-of-2-sized and its backing Java array can contain at most (1 << 30) elements, since - * that's the largest power-of-2 that's less than Integer.MAX_VALUE. We need two long array - * entries per key, giving us a maximum capacity of (1 << 29). + * power-of-2-sized and its backing Java array can contain at most (1 << 30) elements, + * since that's the largest power-of-2 that's less than Integer.MAX_VALUE. We need two long array + * entries per key, giving us a maximum capacity of (1 << 29). */ @VisibleForTesting static final int MAX_CAPACITY = (1 << 29); @@ -109,7 +107,7 @@ public final class BytesToBytesMap { * Position {@code 2 * i} in the array is used to track a pointer to the key at index {@code i}, * while position {@code 2 * i + 1} in the array holds key's full 32-bit hashcode. */ - private LongArray longArray; + @Nullable private LongArray longArray; // TODO: we're wasting 32 bits of space here; we can probably store fewer bits of the hashcode // and exploit word-alignment to use fewer bits to hold the address. This might let us store // only one long per map entry, increasing the chance that this array will fit in cache at the @@ -121,10 +119,9 @@ public final class BytesToBytesMap { // absolute memory addresses. /** - * A {@link BitSet} used to track location of the map where the key is set. - * Size of the bitset should be half of the size of the long array. + * Whether or not the longArray can grow. We will not insert more elements if it's false. 
*/ - private BitSet bitset; + private boolean canGrowArray = true; private final double loadFactor; @@ -166,15 +163,22 @@ public final class BytesToBytesMap { private long numHashCollisions = 0; + private long peakMemoryUsedBytes = 0L; + + private final BlockManager blockManager; + private volatile MapIterator destructiveIterator = null; + private LinkedList spillWriters = new LinkedList<>(); + public BytesToBytesMap( TaskMemoryManager taskMemoryManager, - ShuffleMemoryManager shuffleMemoryManager, + BlockManager blockManager, int initialCapacity, double loadFactor, long pageSizeBytes, boolean enablePerfMetrics) { + super(taskMemoryManager, pageSizeBytes); this.taskMemoryManager = taskMemoryManager; - this.shuffleMemoryManager = shuffleMemoryManager; + this.blockManager = blockManager; this.loadFactor = loadFactor; this.loc = new Location(); this.pageSizeBytes = pageSizeBytes; @@ -195,21 +199,19 @@ public BytesToBytesMap( public BytesToBytesMap( TaskMemoryManager taskMemoryManager, - ShuffleMemoryManager shuffleMemoryManager, int initialCapacity, long pageSizeBytes) { - this(taskMemoryManager, shuffleMemoryManager, initialCapacity, 0.70, pageSizeBytes, false); + this(taskMemoryManager, initialCapacity, pageSizeBytes, false); } public BytesToBytesMap( TaskMemoryManager taskMemoryManager, - ShuffleMemoryManager shuffleMemoryManager, int initialCapacity, long pageSizeBytes, boolean enablePerfMetrics) { this( taskMemoryManager, - shuffleMemoryManager, + SparkEnv.get() != null ? SparkEnv.get().blockManager() : null, initialCapacity, 0.70, pageSizeBytes, @@ -221,49 +223,159 @@ public BytesToBytesMap( */ public int numElements() { return numElements; } - public static final class BytesToBytesMapIterator implements Iterator { + public final class MapIterator implements Iterator { - private final int numRecords; - private final Iterator dataPagesIterator; + private int numRecords; private final Location loc; - private MemoryBlock currentPage; - private int currentRecordNumber = 0; + private MemoryBlock currentPage = null; + private int recordsInPage = 0; private Object pageBaseObject; private long offsetInPage; - private BytesToBytesMapIterator( - int numRecords, Iterator dataPagesIterator, Location loc) { + // If this iterator destructive or not. When it is true, it frees each page as it moves onto + // next one. 
+ private boolean destructive = false; + private UnsafeSorterSpillReader reader = null; + + private MapIterator(int numRecords, Location loc, boolean destructive) { this.numRecords = numRecords; - this.dataPagesIterator = dataPagesIterator; this.loc = loc; - if (dataPagesIterator.hasNext()) { - advanceToNextPage(); + this.destructive = destructive; + if (destructive) { + destructiveIterator = this; } } private void advanceToNextPage() { - currentPage = dataPagesIterator.next(); - pageBaseObject = currentPage.getBaseObject(); - offsetInPage = currentPage.getBaseOffset(); + synchronized (this) { + int nextIdx = dataPages.indexOf(currentPage) + 1; + if (destructive && currentPage != null) { + dataPages.remove(currentPage); + freePage(currentPage); + nextIdx --; + } + if (dataPages.size() > nextIdx) { + currentPage = dataPages.get(nextIdx); + pageBaseObject = currentPage.getBaseObject(); + offsetInPage = currentPage.getBaseOffset(); + recordsInPage = Platform.getInt(pageBaseObject, offsetInPage); + offsetInPage += 4; + } else { + currentPage = null; + if (reader != null) { + // remove the spill file from disk + File file = spillWriters.removeFirst().getFile(); + if (file != null && file.exists()) { + if (!file.delete()) { + logger.error("Was unable to delete spill file {}", file.getAbsolutePath()); + } + } + } + try { + Closeables.close(reader, /* swallowIOException = */ false); + reader = spillWriters.getFirst().getReader(blockManager); + recordsInPage = -1; + } catch (IOException e) { + // Scala iterator does not handle exception + Platform.throwException(e); + } + } + } } @Override public boolean hasNext() { - return currentRecordNumber != numRecords; + if (numRecords == 0) { + if (reader != null) { + // remove the spill file from disk + File file = spillWriters.removeFirst().getFile(); + if (file != null && file.exists()) { + if (!file.delete()) { + logger.error("Was unable to delete spill file {}", file.getAbsolutePath()); + } + } + } + } + return numRecords > 0; } @Override public Location next() { - int totalLength = PlatformDependent.UNSAFE.getInt(pageBaseObject, offsetInPage); - if (totalLength == END_OF_PAGE_MARKER) { + if (recordsInPage == 0) { advanceToNextPage(); - totalLength = PlatformDependent.UNSAFE.getInt(pageBaseObject, offsetInPage); } - loc.with(currentPage, offsetInPage); - offsetInPage += 4 + totalLength; - currentRecordNumber++; - return loc; + numRecords--; + if (currentPage != null) { + int totalLength = Platform.getInt(pageBaseObject, offsetInPage); + loc.with(currentPage, offsetInPage); + offsetInPage += 4 + totalLength; + recordsInPage --; + return loc; + } else { + assert(reader != null); + if (!reader.hasNext()) { + advanceToNextPage(); + } + try { + reader.loadNext(); + } catch (IOException e) { + try { + reader.close(); + } catch(IOException e2) { + logger.error("Error while closing spill reader", e2); + } + // Scala iterator does not handle exception + Platform.throwException(e); + } + loc.with(reader.getBaseObject(), reader.getBaseOffset(), reader.getRecordLength()); + return loc; + } + } + + public long spill(long numBytes) throws IOException { + synchronized (this) { + if (!destructive || dataPages.size() == 1) { + return 0L; + } + + // TODO: use existing ShuffleWriteMetrics + ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics(); + + long released = 0L; + while (dataPages.size() > 0) { + MemoryBlock block = dataPages.getLast(); + // The currentPage is used, cannot be released + if (block == currentPage) { + break; + } + + Object base = 
block.getBaseObject(); + long offset = block.getBaseOffset(); + int numRecords = Platform.getInt(base, offset); + offset += 4; + final UnsafeSorterSpillWriter writer = + new UnsafeSorterSpillWriter(blockManager, 32 * 1024, writeMetrics, numRecords); + while (numRecords > 0) { + int length = Platform.getInt(base, offset); + writer.write(base, offset + 4, length, 0); + offset += 4 + length; + numRecords--; + } + writer.close(); + spillWriters.add(writer); + + dataPages.removeLast(); + released += block.size(); + freePage(block); + + if (released >= numBytes) { + break; + } + } + + return released; + } } @Override @@ -280,8 +392,22 @@ public void remove() { * If any other lookups or operations are performed on this map while iterating over it, including * `lookup()`, the behavior of the returned iterator is undefined. */ - public BytesToBytesMapIterator iterator() { - return new BytesToBytesMapIterator(numElements, dataPages.iterator(), loc); + public MapIterator iterator() { + return new MapIterator(numElements, loc, false); + } + + /** + * Returns a destructive iterator for iterating over the entries of this map. It frees each page + * as it moves onto next one. Notice: it is illegal to call any method on the map after + * `destructiveIterator()` has been called. + * + * For efficiency, all calls to `next()` will return the same {@link Location} object. + * + * If any other lookups or operations are performed on this map while iterating over it, including + * `lookup()`, the behavior of the returned iterator is undefined. + */ + public MapIterator destructiveIterator() { + return new MapIterator(numElements, loc, true); } /** @@ -290,41 +416,51 @@ public BytesToBytesMapIterator iterator() { * * This function always return the same {@link Location} instance to avoid object allocation. */ - public Location lookup( - Object keyBaseObject, - long keyBaseOffset, - int keyRowLengthBytes) { + public Location lookup(Object keyBase, long keyOffset, int keyLength) { + safeLookup(keyBase, keyOffset, keyLength, loc); + return loc; + } + + /** + * Looks up a key, and saves the result in provided `loc`. + * + * This is a thread-safe version of `lookup`, could be used by multiple threads. + */ + public void safeLookup(Object keyBase, long keyOffset, int keyLength, Location loc) { + assert(longArray != null); + if (enablePerfMetrics) { numKeyLookups++; } - final int hashcode = HASHER.hashUnsafeWords(keyBaseObject, keyBaseOffset, keyRowLengthBytes); + final int hashcode = HASHER.hashUnsafeWords(keyBase, keyOffset, keyLength); int pos = hashcode & mask; int step = 1; while (true) { if (enablePerfMetrics) { numProbes++; } - if (!bitset.isSet(pos)) { + if (longArray.get(pos * 2) == 0) { // This is a new key. - return loc.with(pos, hashcode, false); + loc.with(pos, hashcode, false); + return; } else { long stored = longArray.get(pos * 2 + 1); if ((int) (stored) == hashcode) { // Full hash code matches. Let's compare the keys for equality. 
loc.with(pos, hashcode, true); - if (loc.getKeyLength() == keyRowLengthBytes) { + if (loc.getKeyLength() == keyLength) { final MemoryLocation keyAddress = loc.getKeyAddress(); - final Object storedKeyBaseObject = keyAddress.getBaseObject(); - final long storedKeyBaseOffset = keyAddress.getBaseOffset(); + final Object storedkeyBase = keyAddress.getBaseObject(); + final long storedkeyOffset = keyAddress.getBaseOffset(); final boolean areEqual = ByteArrayMethods.arrayEquals( - keyBaseObject, - keyBaseOffset, - storedKeyBaseObject, - storedKeyBaseOffset, - keyRowLengthBytes + keyBase, + keyOffset, + storedkeyBase, + storedkeyOffset, + keyLength ); if (areEqual) { - return loc; + return; } else { if (enablePerfMetrics) { numHashCollisions++; @@ -368,21 +504,22 @@ private void updateAddressesAndSizes(long fullKeyAddress) { taskMemoryManager.getOffsetInPage(fullKeyAddress)); } - private void updateAddressesAndSizes(final Object page, final long offsetInPage) { - long position = offsetInPage; - final int totalLength = PlatformDependent.UNSAFE.getInt(page, position); + private void updateAddressesAndSizes(final Object base, final long offset) { + long position = offset; + final int totalLength = Platform.getInt(base, position); position += 4; - keyLength = PlatformDependent.UNSAFE.getInt(page, position); + keyLength = Platform.getInt(base, position); position += 4; valueLength = totalLength - keyLength - 4; - keyMemoryLocation.setObjAndOffset(page, position); + keyMemoryLocation.setObjAndOffset(base, position); position += keyLength; - valueMemoryLocation.setObjAndOffset(page, position); + valueMemoryLocation.setObjAndOffset(base, position); } private Location with(int pos, int keyHashcode, boolean isDefined) { + assert(longArray != null); this.pos = pos; this.isDefined = isDefined; this.keyHashcode = keyHashcode; @@ -400,6 +537,19 @@ private Location with(MemoryBlock page, long offsetInPage) { return this; } + /** + * This is only used for spilling + */ + private Location with(Object base, long offset, int length) { + this.isDefined = true; + this.memoryPage = null; + keyLength = Platform.getInt(base, offset); + valueLength = length - 4 - keyLength; + keyMemoryLocation.setObjAndOffset(base, offset + 4); + valueMemoryLocation.setObjAndOffset(base, offset + 4 + keyLength); + return this; + } + /** * Returns the memory page that contains the current record. * This is only valid if this is returned by {@link BytesToBytesMap#iterator()}. @@ -474,9 +624,9 @@ public int getValueLength() { * As an example usage, here's the proper way to store a new key: *

*
-     *   Location loc = map.lookup(keyBaseObject, keyBaseOffset, keyLengthInBytes);
+     *   Location loc = map.lookup(keyBase, keyOffset, keyLength);
      *   if (!loc.isDefined()) {
-     *     if (!loc.putNewKey(keyBaseObject, keyBaseOffset, keyLengthInBytes, ...)) {
+     *     if (!loc.putNewKey(keyBase, keyOffset, keyLength, ...)) {
      *       // handle failure to grow map (by spilling, for example)
      *     }
      *   }
@@ -488,124 +638,90 @@ public int getValueLength() {
      * @return true if the put() was successful and false if the put() failed because memory could
      *         not be acquired.
      */
-    public boolean putNewKey(
-        Object keyBaseObject,
-        long keyBaseOffset,
-        int keyLengthBytes,
-        Object valueBaseObject,
-        long valueBaseOffset,
-        int valueLengthBytes) {
+    public boolean putNewKey(Object keyBase, long keyOffset, int keyLength,
+        Object valueBase, long valueOffset, int valueLength) {
       assert (!isDefined) : "Can only set value once for a key";
-      assert (keyLengthBytes % 8 == 0);
-      assert (valueLengthBytes % 8 == 0);
-      if (numElements == MAX_CAPACITY) {
-        throw new IllegalStateException("BytesToBytesMap has reached maximum capacity");
+      assert (keyLength % 8 == 0);
+      assert (valueLength % 8 == 0);
+      assert(longArray != null);
+
+      if (numElements == MAX_CAPACITY
+        // The map could be reused after the last spill (because there was not enough memory to grow),
+        // in which case we don't try to grow again once we hit the `growthThreshold`.
+        || !canGrowArray && numElements > growthThreshold) {
+        return false;
       }
 
       // Here, we'll copy the data into our data pages. Because we only store a relative offset from
       // the key address instead of storing the absolute address of the value, the key and value
       // must be stored in the same memory page.
-      // (8 byte key length) (key) (8 byte value length) (value)
-      final long requiredSize = 8 + keyLengthBytes + valueLengthBytes;
-
-      // --- Figure out where to insert the new record ---------------------------------------------
-
-      final MemoryBlock dataPage;
-      final Object dataPageBaseObject;
-      final long dataPageInsertOffset;
-      boolean useOverflowPage = requiredSize > pageSizeBytes - 8;
-      if (useOverflowPage) {
-        // The record is larger than the page size, so allocate a special overflow page just to hold
-        // that record.
-        final long memoryRequested = requiredSize + 8;
-        final long memoryGranted = shuffleMemoryManager.tryToAcquire(memoryRequested);
-        if (memoryGranted != memoryRequested) {
-          shuffleMemoryManager.release(memoryGranted);
-          logger.debug("Failed to acquire {} bytes of memory", memoryRequested);
+      // (4 byte total length) (4 byte key length) (key) (value)
+      final long recordLength = 8 + keyLength + valueLength;
+      if (currentPage == null || currentPage.size() - pageCursor < recordLength) {
+        if (!acquireNewPage(recordLength + 4L)) {
           return false;
         }
-        MemoryBlock overflowPage = taskMemoryManager.allocatePage(memoryRequested);
-        dataPages.add(overflowPage);
-        dataPage = overflowPage;
-        dataPageBaseObject = overflowPage.getBaseObject();
-        dataPageInsertOffset = overflowPage.getBaseOffset();
-      } else if (currentDataPage == null || pageSizeBytes - 8 - pageCursor < requiredSize) {
-        // The record can fit in a data page, but either we have not allocated any pages yet or
-        // the current page does not have enough space.
-        if (currentDataPage != null) {
-          // There wasn't enough space in the current page, so write an end-of-page marker:
-          final Object pageBaseObject = currentDataPage.getBaseObject();
-          final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor;
-          PlatformDependent.UNSAFE.putInt(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER);
-        }
-        final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
-        if (memoryGranted != pageSizeBytes) {
-          shuffleMemoryManager.release(memoryGranted);
-          logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes);
-          return false;
-        }
-        MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes);
-        dataPages.add(newPage);
-        pageCursor = 0;
-        currentDataPage = newPage;
-        dataPage = currentDataPage;
-        dataPageBaseObject = currentDataPage.getBaseObject();
-        dataPageInsertOffset = currentDataPage.getBaseOffset();
-      } else {
-        // There is enough space in the current data page.
-        dataPage = currentDataPage;
-        dataPageBaseObject = currentDataPage.getBaseObject();
-        dataPageInsertOffset = currentDataPage.getBaseOffset() + pageCursor;
       }
 
       // --- Append the key and value data to the current data page --------------------------------
-
-      long insertCursor = dataPageInsertOffset;
-
-      // Compute all of our offsets up-front:
-      final long recordOffset = insertCursor;
-      insertCursor += 4;
-      final long keyLengthOffset = insertCursor;
-      insertCursor += 4;
-      final long keyDataOffsetInPage = insertCursor;
-      insertCursor += keyLengthBytes;
-      final long valueDataOffsetInPage = insertCursor;
-      insertCursor += valueLengthBytes; // word used to store the value size
-
-      PlatformDependent.UNSAFE.putInt(dataPageBaseObject, recordOffset,
-        keyLengthBytes + valueLengthBytes + 4);
-      PlatformDependent.UNSAFE.putInt(dataPageBaseObject, keyLengthOffset, keyLengthBytes);
-      // Copy the key
-      PlatformDependent.copyMemory(
-        keyBaseObject, keyBaseOffset, dataPageBaseObject, keyDataOffsetInPage, keyLengthBytes);
-      // Copy the value
-      PlatformDependent.copyMemory(valueBaseObject, valueBaseOffset, dataPageBaseObject,
-        valueDataOffsetInPage, valueLengthBytes);
-
-      // --- Update bookeeping data structures -----------------------------------------------------
-
-      if (useOverflowPage) {
-        // Store the end-of-page marker at the end of the data page
-        PlatformDependent.UNSAFE.putInt(dataPageBaseObject, insertCursor, END_OF_PAGE_MARKER);
-      } else {
-        pageCursor += requiredSize;
-      }
-
+      final Object base = currentPage.getBaseObject();
+      long offset = currentPage.getBaseOffset() + pageCursor;
+      final long recordOffset = offset;
+      Platform.putInt(base, offset, keyLength + valueLength + 4);
+      Platform.putInt(base, offset + 4, keyLength);
+      offset += 8;
+      Platform.copyMemory(keyBase, keyOffset, base, offset, keyLength);
+      offset += keyLength;
+      Platform.copyMemory(valueBase, valueOffset, base, offset, valueLength);
+
+      // --- Update bookkeeping data structures -----------------------------------------------------
+      offset = currentPage.getBaseOffset();
+      Platform.putInt(base, offset, Platform.getInt(base, offset) + 1);
+      pageCursor += recordLength;
       numElements++;
-      bitset.set(pos);
       final long storedKeyAddress = taskMemoryManager.encodePageNumberAndOffset(
-        dataPage, recordOffset);
+        currentPage, recordOffset);
       longArray.set(pos * 2, storedKeyAddress);
       longArray.set(pos * 2 + 1, keyHashcode);
       updateAddressesAndSizes(storedKeyAddress);
       isDefined = true;
+
       if (numElements > growthThreshold && longArray.size() < MAX_CAPACITY) {
-        growAndRehash();
+        try {
+          growAndRehash();
+        } catch (OutOfMemoryError oom) {
+          canGrowArray = false;
+        }
       }
       return true;
     }
   }
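
The rewritten putNewKey() frames every entry as [4-byte total length][4-byte key length][key bytes][value bytes], where total length = keyLength + valueLength + 4, which is why updateAddressesAndSizes() can recover the value length without storing it. Below is a minimal standalone sketch of that framing, with a plain ByteBuffer standing in for Spark's Platform/MemoryBlock; the class and method names are illustrative only, not part of this patch.

// Illustrative sketch, not Spark code: the record framing used by putNewKey(),
// written into and read back from a plain ByteBuffer "page".
import java.nio.ByteBuffer;

public class RecordLayoutSketch {

  // Appends [total length][key length][key][value] and returns the record's offset.
  static int append(ByteBuffer page, byte[] key, byte[] value) {
    final int recordOffset = page.position();
    page.putInt(key.length + value.length + 4); // total length: key-length word + key + value
    page.putInt(key.length);                    // key length
    page.put(key);
    page.put(value);
    return recordOffset;
  }

  // Recovers the three sizes from the two length words, like updateAddressesAndSizes().
  static void read(ByteBuffer page, int recordOffset) {
    final int totalLength = page.getInt(recordOffset);
    final int keyLength = page.getInt(recordOffset + 4);
    final int valueLength = totalLength - keyLength - 4;
    System.out.println("total=" + totalLength + " key=" + keyLength + " value=" + valueLength);
  }

  public static void main(String[] args) {
    ByteBuffer page = ByteBuffer.allocate(1 << 10); // stand-in for a data page
    int offset = append(page, new byte[8], new byte[16]);
    read(page, offset); // prints total=28 key=8 value=16
  }
}
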
 
+  /**
+   * Acquire a new page from the memory manager.
+   * @return whether the new page was successfully acquired.
+   */
+  private boolean acquireNewPage(long required) {
+    try {
+      currentPage = allocatePage(required);
+    } catch (OutOfMemoryError e) {
+      return false;
+    }
+    dataPages.add(currentPage);
+    Platform.putInt(currentPage.getBaseObject(), currentPage.getBaseOffset(), 0);
+    pageCursor = 4;
+    return true;
+  }
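
acquireNewPage() reserves the first 4 bytes of every data page for a record count (written as 0 here and incremented by putNewKey), and starts the write cursor at offset 4; the destructive MapIterator later reads that count back when it advances to a page. A rough standalone sketch of that page layout under the same assumptions, again using ByteBuffer instead of the unsafe memory primitives, with illustrative names:

// Illustrative sketch, not Spark code: a data page whose first 4 bytes hold the
// number of records in the page, followed by length-prefixed records.
import java.nio.ByteBuffer;

public class PageHeaderSketch {

  static ByteBuffer newPage(int sizeBytes) {
    ByteBuffer page = ByteBuffer.allocate(sizeBytes);
    page.putInt(0, 0); // record count lives at offset 0
    page.position(4);  // write cursor starts just past the count
    return page;
  }

  static void appendRecord(ByteBuffer page, byte[] record) {
    page.putInt(record.length);          // 4-byte length prefix
    page.put(record);                    // payload
    page.putInt(0, page.getInt(0) + 1);  // bump the per-page record count
  }

  static void scanPage(ByteBuffer page) {
    int recordsInPage = page.getInt(0);  // what an iterator reads before walking the page
    int offset = 4;
    while (recordsInPage-- > 0) {
      int length = page.getInt(offset);
      System.out.println("record of " + length + " bytes");
      offset += 4 + length;              // skip the length word and the payload
    }
  }

  public static void main(String[] args) {
    ByteBuffer page = newPage(1 << 12);
    appendRecord(page, new byte[24]);
    appendRecord(page, new byte[40]);
    scanPage(page); // prints a 24-byte record, then a 40-byte record
  }
}
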
+
+  @Override
+  public long spill(long size, MemoryConsumer trigger) throws IOException {
+    if (trigger != this && destructiveIterator != null) {
+      return destructiveIterator.spill(size);
+    }
+    return 0L;
+  }
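
This spill(size, trigger) override only releases memory when the request comes from a different consumer and a destructive iterator exists; self-triggered requests return 0 and are handled elsewhere. The toy class below sketches that cooperative contract under those assumptions; it is illustrative only and not Spark's MemoryConsumer API.

// Illustrative sketch, not Spark's MemoryConsumer: a consumer that can release
// its older "pages" when another consumer needs memory, but never the page in use.
import java.util.ArrayDeque;
import java.util.Deque;

public class ToyConsumer {

  private final Deque<byte[]> pages = new ArrayDeque<>();
  private byte[] currentPage; // still being written; must not be released

  void allocatePage(int size) {
    currentPage = new byte[size];
    pages.addLast(currentPage);
  }

  // Called when `trigger` needs roughly `size` more bytes; returns bytes freed.
  long spill(long size, ToyConsumer trigger) {
    if (trigger == this) {
      return 0L; // mirrors the guard above: self-triggered requests are a no-op here
    }
    long released = 0L;
    while (pages.size() > 1 && released < size) {
      byte[] victim = pages.peekFirst();
      if (victim == currentPage) {
        break; // keep the in-use page resident
      }
      pages.removeFirst(); // a real spill would persist these records to disk first
      released += victim.length;
    }
    return released;
  }

  public static void main(String[] args) {
    ToyConsumer a = new ToyConsumer();
    ToyConsumer b = new ToyConsumer();
    for (int i = 0; i < 4; i++) {
      a.allocatePage(1 << 20);
    }
    System.out.println(a.spill(1 << 21, b)); // 2097152: two of a's older pages freed
    System.out.println(a.spill(1 << 21, a)); // 0: the trigger is a itself
  }
}
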
+
   /**
    * Allocate new data structures for this map. When calling this outside of the constructor,
    * make sure to keep references to the old data structures so that you can free them.
@@ -614,11 +730,10 @@ public boolean putNewKey(
    */
   private void allocate(int capacity) {
     assert (capacity >= 0);
-    // The capacity needs to be divisible by 64 so that our bit set can be sized properly
-    capacity = Math.max((int) Math.min(MAX_CAPACITY, nextPowerOf2(capacity)), 64);
+    capacity = Math.max((int) Math.min(MAX_CAPACITY, ByteArrayMethods.nextPowerOf2(capacity)), 64);
     assert (capacity <= MAX_CAPACITY);
-    longArray = new LongArray(MemoryBlock.fromLongArray(new long[capacity * 2]));
-    bitset = new BitSet(MemoryBlock.fromLongArray(new long[capacity / 64]));
+    longArray = allocateArray(capacity * 2);
+    longArray.zeroOut();
 
     this.growthThreshold = (int) (capacity * loadFactor);
     this.mask = capacity - 1;
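
allocate() keeps the table power-of-two sized so `hashcode & mask` replaces a modulo, stores two longs per slot (the encoded record address, then the full 32-bit hash), treats a zero address word as an empty slot, and sets the growth threshold at capacity * loadFactor. A small standalone sketch of that indexing scheme follows, using plain linear probing rather than the map's probe sequence; names are illustrative.

// Illustrative sketch, not Spark code: power-of-two sizing, mask-based indexing,
// and the two-longs-per-slot layout used by the hash table above.
public class OpenAddressingSketch {

  static int nextPowerOfTwo(int n) {
    return Integer.highestOneBit(n - 1) << 1; // assumes 1 < n <= 2^30
  }

  public static void main(String[] args) {
    final double loadFactor = 0.70;
    int capacity = nextPowerOfTwo(1000);                 // 1024
    long[] longArray = new long[capacity * 2];           // [addr0, hash0, addr1, hash1, ...]
    int mask = capacity - 1;                             // cheap modulo for power-of-two sizes
    int growthThreshold = (int) (capacity * loadFactor); // 716

    int hashcode = 0x1B873593;                           // any 32-bit hash value
    int pos = hashcode & mask;
    while (longArray[pos * 2] != 0) {                    // a zero address marks an empty slot
      pos = (pos + 1) & mask;                            // linear probing, wrapping around
    }
    longArray[pos * 2] = 42L;                            // encoded page number + offset
    longArray[pos * 2 + 1] = hashcode;                   // full hash, so rehashing skips recompute

    System.out.println("capacity=" + capacity + " mask=" + mask
        + " growthThreshold=" + growthThreshold + " pos=" + pos);
  }
}
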
@@ -631,37 +746,61 @@ private void allocate(int capacity) {
    * This method is idempotent and can be called multiple times.
    */
   public void free() {
-    longArray = null;
-    bitset = null;
+    updatePeakMemoryUsed();
+    if (longArray != null) {
+      freeArray(longArray);
+      longArray = null;
+    }
     Iterator dataPagesIterator = dataPages.iterator();
     while (dataPagesIterator.hasNext()) {
       MemoryBlock dataPage = dataPagesIterator.next();
       dataPagesIterator.remove();
-      taskMemoryManager.freePage(dataPage);
-      shuffleMemoryManager.release(dataPage.size());
+      freePage(dataPage);
     }
     assert(dataPages.isEmpty());
+
+    while (!spillWriters.isEmpty()) {
+      File file = spillWriters.removeFirst().getFile();
+      if (file != null && file.exists()) {
+        if (!file.delete()) {
+          logger.error("Was unable to delete spill file {}", file.getAbsolutePath());
+        }
+      }
+    }
   }
 
   public TaskMemoryManager getTaskMemoryManager() {
     return taskMemoryManager;
   }
 
-  public ShuffleMemoryManager getShuffleMemoryManager() {
-    return shuffleMemoryManager;
-  }
-
   public long getPageSizeBytes() {
     return pageSizeBytes;
   }
 
-  /** Returns the total amount of memory, in bytes, consumed by this map's managed structures. */
+  /**
+   * Returns the total amount of memory, in bytes, consumed by this map's managed structures.
+   */
   public long getTotalMemoryConsumption() {
     long totalDataPagesSize = 0L;
     for (MemoryBlock dataPage : dataPages) {
       totalDataPagesSize += dataPage.size();
     }
-    return totalDataPagesSize + bitset.memoryBlock().size() + longArray.memoryBlock().size();
+    return totalDataPagesSize + ((longArray != null) ? longArray.memoryBlock().size() : 0L);
+  }
+
+  private void updatePeakMemoryUsed() {
+    long mem = getTotalMemoryConsumption();
+    if (mem > peakMemoryUsedBytes) {
+      peakMemoryUsedBytes = mem;
+    }
+  }
+
+  /**
+   * Return the peak memory used so far, in bytes.
+   */
+  public long getPeakMemoryUsedBytes() {
+    updatePeakMemoryUsed();
+    return peakMemoryUsedBytes;
   }
 
   /**
@@ -674,7 +813,6 @@ public long getTimeSpentResizingNs() {
     return timeSpentResizingNs;
   }
 
-
   /**
    * Returns the average number of probes per key lookup.
    */
@@ -693,58 +831,71 @@ public long getNumHashCollisions() {
   }
 
   @VisibleForTesting
-  int getNumDataPages() {
+  public int getNumDataPages() {
     return dataPages.size();
   }
 
+  /**
+   * Returns the underlying LongArray.
+   */
+  public LongArray getArray() {
+    assert(longArray != null);
+    return longArray;
+  }
+
+  /**
+   * Reset this map to initialized state.
+   */
+  public void reset() {
+    numElements = 0;
+    longArray.zeroOut();
+
+    while (dataPages.size() > 0) {
+      MemoryBlock dataPage = dataPages.removeLast();
+      freePage(dataPage);
+    }
+    currentPage = null;
+    pageCursor = 0;
+  }
+
   /**
    * Grows the size of the hash table and re-hash everything.
    */
   @VisibleForTesting
   void growAndRehash() {
+    assert(longArray != null);
+
     long resizeStartTime = -1;
     if (enablePerfMetrics) {
       resizeStartTime = System.nanoTime();
     }
     // Store references to the old data structures to be used when we re-hash
     final LongArray oldLongArray = longArray;
-    final BitSet oldBitSet = bitset;
-    final int oldCapacity = (int) oldBitSet.capacity();
+    final int oldCapacity = (int) oldLongArray.size() / 2;
 
     // Allocate the new data structures
     allocate(Math.min(growthStrategy.nextCapacity(oldCapacity), MAX_CAPACITY));
 
     // Re-mask (we don't recompute the hashcode because we stored all 32 bits of it)
-    for (int pos = oldBitSet.nextSetBit(0); pos >= 0; pos = oldBitSet.nextSetBit(pos + 1)) {
-      final long keyPointer = oldLongArray.get(pos * 2);
-      final int hashcode = (int) oldLongArray.get(pos * 2 + 1);
+    for (int i = 0; i < oldLongArray.size(); i += 2) {
+      final long keyPointer = oldLongArray.get(i);
+      if (keyPointer == 0) {
+        continue;
+      }
+      final int hashcode = (int) oldLongArray.get(i + 1);
       int newPos = hashcode & mask;
       int step = 1;
-      boolean keepGoing = true;
-
-      // No need to check for equality here when we insert so this has one less if branch than
-      // the similar code path in addWithoutResize.
-      while (keepGoing) {
-        if (!bitset.isSet(newPos)) {
-          bitset.set(newPos);
-          longArray.set(newPos * 2, keyPointer);
-          longArray.set(newPos * 2 + 1, hashcode);
-          keepGoing = false;
-        } else {
-          newPos = (newPos + step) & mask;
-          step++;
-        }
+      while (longArray.get(newPos * 2) != 0) {
+        newPos = (newPos + step) & mask;
+        step++;
       }
+      longArray.set(newPos * 2, keyPointer);
+      longArray.set(newPos * 2 + 1, hashcode);
     }
+    freeArray(oldLongArray);
 
     if (enablePerfMetrics) {
       timeSpentResizingNs += System.nanoTime() - resizeStartTime;
     }
   }
-
-  /** Returns the next number greater or equal num that is power of 2. */
-  private static long nextPowerOf2(long num) {
-    final long highBit = Long.highestOneBit(num);
-    return (highBit == num) ? num : highBit << 1;
-  }
 }
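
For readers following the BytesToBytesMap changes above: after `longArray.zeroOut()`, a stored key address of 0 marks an empty slot, and growAndRehash probes with an increasing step until it finds one. The following standalone sketch mirrors that probing scheme on a plain long[]; the class and method names are illustrative only and not part of the Spark code.

    // Minimal sketch of the rehash scheme above: slot 2*i holds the encoded record
    // address (0 == empty after zeroOut), slot 2*i + 1 holds the full 32-bit hashcode.
    public class RehashSketch {
      static long[] rehash(long[] oldArray, int newCapacity) {
        long[] newArray = new long[newCapacity * 2];   // zero-filled, i.e. all slots empty
        int mask = newCapacity - 1;                    // capacity is a power of two
        for (int i = 0; i < oldArray.length; i += 2) {
          long keyPointer = oldArray[i];
          if (keyPointer == 0) {
            continue;                                  // empty slot in the old table
          }
          int hashcode = (int) oldArray[i + 1];
          int pos = hashcode & mask;
          int step = 1;
          while (newArray[pos * 2] != 0) {             // probe until a free slot is found
            pos = (pos + step) & mask;
            step++;
          }
          newArray[pos * 2] = keyPointer;
          newArray[pos * 2 + 1] = hashcode;
        }
        return newArray;
      }

      public static void main(String[] args) {
        long[] old = new long[8 * 2];
        old[0] = 42L;                                  // one fake entry: pointer 42,
        old[1] = 3L;                                   // hashcode 3
        long[] grown = rehash(old, 16);
        System.out.println(grown[(3 & 15) * 2]);       // prints 42
      }
    }
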
diff --git a/core/src/main/java/org/apache/spark/util/collection/TimSort.java b/core/src/main/java/org/apache/spark/util/collection/TimSort.java
index a90cc0e761f6..40b5fb7fe4b4 100644
--- a/core/src/main/java/org/apache/spark/util/collection/TimSort.java
+++ b/core/src/main/java/org/apache/spark/util/collection/TimSort.java
@@ -15,6 +15,24 @@
  * limitations under the License.
  */
 
+/*
+ * Based on TimSort.java from the Android Open Source Project
+ *
+ *  Copyright (C) 2008 The Android Open Source Project
+ *
+ *  Licensed under the Apache License, Version 2.0 (the "License");
+ *  you may not use this file except in compliance with the License.
+ *  You may obtain a copy of the License at
+ *
+ *       http://www.apache.org/licenses/LICENSE-2.0
+ *
+ *  Unless required by applicable law or agreed to in writing, software
+ *  distributed under the License is distributed on an "AS IS" BASIS,
+ *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ *  See the License for the specific language governing permissions and
+ *  limitations under the License.
+ */
+
 package org.apache.spark.util.collection;
 
 import java.util.Comparator;
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
index 4d7e5b3dfba6..d2bf297c6c17 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
@@ -20,6 +20,8 @@
 import com.google.common.primitives.UnsignedLongs;
 
 import org.apache.spark.annotation.Private;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.types.ByteArray;
 import org.apache.spark.unsafe.types.UTF8String;
 import org.apache.spark.util.Utils;
 
@@ -29,6 +31,8 @@ private PrefixComparators() {}
 
   public static final StringPrefixComparator STRING = new StringPrefixComparator();
   public static final StringPrefixComparatorDesc STRING_DESC = new StringPrefixComparatorDesc();
+  public static final BinaryPrefixComparator BINARY = new BinaryPrefixComparator();
+  public static final BinaryPrefixComparatorDesc BINARY_DESC = new BinaryPrefixComparatorDesc();
   public static final LongPrefixComparator LONG = new LongPrefixComparator();
   public static final LongPrefixComparatorDesc LONG_DESC = new LongPrefixComparatorDesc();
   public static final DoublePrefixComparator DOUBLE = new DoublePrefixComparator();
@@ -52,6 +56,24 @@ public int compare(long bPrefix, long aPrefix) {
     }
   }
 
+  public static final class BinaryPrefixComparator extends PrefixComparator {
+    @Override
+    public int compare(long aPrefix, long bPrefix) {
+      return UnsignedLongs.compare(aPrefix, bPrefix);
+    }
+
+    public static long computePrefix(byte[] bytes) {
+      return ByteArray.getPrefix(bytes);
+    }
+  }
+
+  public static final class BinaryPrefixComparatorDesc extends PrefixComparator {
+    @Override
+    public int compare(long bPrefix, long aPrefix) {
+      return UnsignedLongs.compare(aPrefix, bPrefix);
+    }
+  }
+
   public static final class LongPrefixComparator extends PrefixComparator {
     @Override
     public int compare(long a, long b) {
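
The binary comparators added above order records by an unsigned comparison of an 8-byte prefix taken from the leading bytes of each key. As a rough illustration of why that works, here is a hedged sketch that packs up to the first 8 bytes big-endian into a long and compares prefixes unsigned. The packing is an assumption about what `ByteArray.getPrefix` does, not a copy of it, and `Long.compareUnsigned` stands in for Guava's `UnsignedLongs.compare` to keep the snippet self-contained.

    // Hedged sketch: order byte[] keys by an unsigned 8-byte prefix.
    public class BinaryPrefixSketch {
      // Pack up to the first 8 bytes big-endian so that unsigned long order
      // matches lexicographic byte order on the covered prefix.
      static long computePrefix(byte[] bytes) {
        long prefix = 0L;
        int n = Math.min(bytes.length, 8);
        for (int i = 0; i < n; i++) {
          prefix |= (bytes[i] & 0xFFL) << (56 - 8 * i);
        }
        return prefix;
      }

      public static void main(String[] args) {
        long a = computePrefix(new byte[] {(byte) 0x01, (byte) 0x02});
        long b = computePrefix(new byte[] {(byte) 0xFF});
        // 0xFF.. must sort after 0x01.., which only an unsigned comparison gets right.
        System.out.println(Long.compareUnsigned(a, b) < 0);  // true
      }
    }
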
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java
index 0c4ebde407cf..dbf6770e0739 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java
@@ -17,9 +17,11 @@
 
 package org.apache.spark.util.collection.unsafe.sort;
 
+import org.apache.spark.memory.TaskMemoryManager;
+
 final class RecordPointerAndKeyPrefix {
   /**
-   * A pointer to a record; see {@link org.apache.spark.unsafe.memory.TaskMemoryManager} for a
+   * A pointer to a record; see {@link TaskMemoryManager} for a
    * description of how these addresses are encoded.
    */
   public long recordPointer;
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index b984301cbbf2..79d74b23ceae 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -17,14 +17,11 @@
 
 package org.apache.spark.util.collection.unsafe.sort;
 
+import javax.annotation.Nullable;
 import java.io.File;
 import java.io.IOException;
 import java.util.LinkedList;
-
-import javax.annotation.Nullable;
-
-import scala.runtime.AbstractFunction0;
-import scala.runtime.BoxedUnit;
+import java.util.Queue;
 
 import com.google.common.annotations.VisibleForTesting;
 import org.slf4j.Logger;
@@ -32,26 +29,25 @@
 
 import org.apache.spark.TaskContext;
 import org.apache.spark.executor.ShuffleWriteMetrics;
-import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.memory.MemoryConsumer;
+import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.storage.BlockManager;
-import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.LongArray;
 import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.util.TaskCompletionListener;
 import org.apache.spark.util.Utils;
 
 /**
  * External sorter based on {@link UnsafeInMemorySorter}.
  */
-public final class UnsafeExternalSorter {
+public final class UnsafeExternalSorter extends MemoryConsumer {
 
   private final Logger logger = LoggerFactory.getLogger(UnsafeExternalSorter.class);
 
-  private final long pageSizeBytes;
   private final PrefixComparator prefixComparator;
   private final RecordComparator recordComparator;
-  private final int initialSize;
   private final TaskMemoryManager taskMemoryManager;
-  private final ShuffleMemoryManager shuffleMemoryManager;
   private final BlockManager blockManager;
   private final TaskContext taskContext;
   private ShuffleWriteMetrics writeMetrics;
@@ -70,17 +66,15 @@ public final class UnsafeExternalSorter {
   private final LinkedList<UnsafeSorterSpillWriter> spillWriters = new LinkedList<>();
 
   // These variables are reset after spilling:
-  private UnsafeInMemorySorter inMemSorter;
-  // Whether the in-mem sorter is created internally, or passed in from outside.
-  // If it is passed in from outside, we shouldn't release the in-mem sorter's memory.
-  private boolean isInMemSorterExternal = false;
+  @Nullable private volatile UnsafeInMemorySorter inMemSorter;
+
   private MemoryBlock currentPage = null;
-  private long currentPagePosition = -1;
-  private long freeSpaceInCurrentPage = 0;
+  private long pageCursor = -1;
+  private long peakMemoryUsedBytes = 0;
+  private volatile SpillableIterator readingIterator = null;
 
   public static UnsafeExternalSorter createWithExistingInMemorySorter(
       TaskMemoryManager taskMemoryManager,
-      ShuffleMemoryManager shuffleMemoryManager,
       BlockManager blockManager,
       TaskContext taskContext,
       RecordComparator recordComparator,
@@ -88,86 +82,67 @@ public static UnsafeExternalSorter createWithExistingInMemorySorter(
       int initialSize,
       long pageSizeBytes,
       UnsafeInMemorySorter inMemorySorter) throws IOException {
-    return new UnsafeExternalSorter(taskMemoryManager, shuffleMemoryManager, blockManager,
+    UnsafeExternalSorter sorter = new UnsafeExternalSorter(taskMemoryManager, blockManager,
       taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, inMemorySorter);
+    sorter.spill(Long.MAX_VALUE, sorter);
+    // The external sorter will be used to insert records; the in-memory sorter is not needed.
+    sorter.inMemSorter = null;
+    return sorter;
   }
 
   public static UnsafeExternalSorter create(
       TaskMemoryManager taskMemoryManager,
-      ShuffleMemoryManager shuffleMemoryManager,
       BlockManager blockManager,
       TaskContext taskContext,
       RecordComparator recordComparator,
       PrefixComparator prefixComparator,
       int initialSize,
-      long pageSizeBytes) throws IOException {
-    return new UnsafeExternalSorter(taskMemoryManager, shuffleMemoryManager, blockManager,
+      long pageSizeBytes) {
+    return new UnsafeExternalSorter(taskMemoryManager, blockManager,
       taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, null);
   }
 
   private UnsafeExternalSorter(
       TaskMemoryManager taskMemoryManager,
-      ShuffleMemoryManager shuffleMemoryManager,
       BlockManager blockManager,
       TaskContext taskContext,
       RecordComparator recordComparator,
       PrefixComparator prefixComparator,
       int initialSize,
       long pageSizeBytes,
-      @Nullable UnsafeInMemorySorter existingInMemorySorter) throws IOException {
+      @Nullable UnsafeInMemorySorter existingInMemorySorter) {
+    super(taskMemoryManager, pageSizeBytes);
     this.taskMemoryManager = taskMemoryManager;
-    this.shuffleMemoryManager = shuffleMemoryManager;
     this.blockManager = blockManager;
     this.taskContext = taskContext;
     this.recordComparator = recordComparator;
     this.prefixComparator = prefixComparator;
-    this.initialSize = initialSize;
     // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units
     // this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
     this.fileBufferSizeBytes = 32 * 1024;
-    // this.pageSizeBytes = conf.getSizeAsBytes("spark.buffer.pageSize", "64m");
-    this.pageSizeBytes = pageSizeBytes;
+    // TODO: metrics tracking + integration with shuffle write metrics
+    // need to connect the write metrics to task metrics so we count the spill IO somewhere.
     this.writeMetrics = new ShuffleWriteMetrics();
 
     if (existingInMemorySorter == null) {
-      initializeForWriting();
+      this.inMemSorter = new UnsafeInMemorySorter(
+        this, taskMemoryManager, recordComparator, prefixComparator, initialSize);
     } else {
-      this.isInMemSorterExternal = true;
       this.inMemSorter = existingInMemorySorter;
     }
+    this.peakMemoryUsedBytes = getMemoryUsage();
 
     // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at
     // the end of the task. This is necessary to avoid memory leaks when the downstream operator
     // does not fully consume the sorter's output (e.g. sort followed by limit).
-    taskContext.addOnCompleteCallback(new AbstractFunction0<BoxedUnit>() {
-      @Override
-      public BoxedUnit apply() {
-        deleteSpillFiles();
-        freeMemory();
-        return null;
+    taskContext.addTaskCompletionListener(
+      new TaskCompletionListener() {
+        @Override
+        public void onTaskCompletion(TaskContext context) {
+          cleanupResources();
+        }
       }
-    });
-  }
-
-  // TODO: metrics tracking + integration with shuffle write metrics
-  // need to connect the write metrics to task metrics so we count the spill IO somewhere.
-
-  /**
-   * Allocates new sort data structures. Called when creating the sorter and after each spill.
-   */
-  private void initializeForWriting() throws IOException {
-    this.writeMetrics = new ShuffleWriteMetrics();
-    final long pointerArrayMemory =
-      UnsafeInMemorySorter.getMemoryRequirementsForPointerArray(initialSize);
-    final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pointerArrayMemory);
-    if (memoryAcquired != pointerArrayMemory) {
-      shuffleMemoryManager.release(memoryAcquired);
-      throw new IOException("Could not acquire " + pointerArrayMemory + " bytes of memory");
-    }
-
-    this.inMemSorter =
-      new UnsafeInMemorySorter(taskMemoryManager, recordComparator, prefixComparator, initialSize);
-    this.isInMemSorterExternal = false;
+    );
   }
 
   /**
@@ -176,38 +151,59 @@ private void initializeForWriting() throws IOException {
    */
   @VisibleForTesting
   public void closeCurrentPage() {
-    freeSpaceInCurrentPage = 0;
+    if (currentPage != null) {
+      pageCursor = currentPage.getBaseOffset() + currentPage.size();
+    }
   }
 
   /**
    * Sort and spill the current records in response to memory pressure.
    */
-  public void spill() throws IOException {
+  @Override
+  public long spill(long size, MemoryConsumer trigger) throws IOException {
+    if (trigger != this) {
+      if (readingIterator != null) {
+        return readingIterator.spill();
+      }
+      return 0L; // this should throw an exception
+    }
+
+    if (inMemSorter == null || inMemSorter.numRecords() <= 0) {
+      return 0L;
+    }
+
     logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
       Thread.currentThread().getId(),
       Utils.bytesToString(getMemoryUsage()),
       spillWriters.size(),
       spillWriters.size() > 1 ? " times" : " time");
 
-    final UnsafeSorterSpillWriter spillWriter =
-      new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics,
-        inMemSorter.numRecords());
-    spillWriters.add(spillWriter);
-    final UnsafeSorterIterator sortedRecords = inMemSorter.getSortedIterator();
-    while (sortedRecords.hasNext()) {
-      sortedRecords.loadNext();
-      final Object baseObject = sortedRecords.getBaseObject();
-      final long baseOffset = sortedRecords.getBaseOffset();
-      final int recordLength = sortedRecords.getRecordLength();
-      spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix());
+    // We only write out the contents of the inMemSorter if it is not empty.
+    if (inMemSorter.numRecords() > 0) {
+      final UnsafeSorterSpillWriter spillWriter =
+        new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics,
+          inMemSorter.numRecords());
+      spillWriters.add(spillWriter);
+      final UnsafeSorterIterator sortedRecords = inMemSorter.getSortedIterator();
+      while (sortedRecords.hasNext()) {
+        sortedRecords.loadNext();
+        final Object baseObject = sortedRecords.getBaseObject();
+        final long baseOffset = sortedRecords.getBaseOffset();
+        final int recordLength = sortedRecords.getRecordLength();
+        spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix());
+      }
+      spillWriter.close();
+
+      inMemSorter.reset();
     }
-    spillWriter.close();
+
     final long spillSize = freeMemory();
     // Note that this is more-or-less going to be a multiple of the page size, so wasted space in
     // pages will currently be counted as memory spilled even though that space isn't actually
     // written to disk. This also counts the space needed to store the sorter's pointer array.
     taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
-    initializeForWriting();
+
+    return spillSize;
   }
 
   /**
@@ -219,7 +215,22 @@ private long getMemoryUsage() {
     for (MemoryBlock page : allocatedPages) {
       totalPageSize += page.size();
     }
-    return inMemSorter.getMemoryUsage() + totalPageSize;
+    return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + totalPageSize;
+  }
+
+  private void updatePeakMemoryUsed() {
+    long mem = getMemoryUsage();
+    if (mem > peakMemoryUsedBytes) {
+      peakMemoryUsedBytes = mem;
+    }
+  }
+
+  /**
+   * Return the peak memory used so far, in bytes.
+   */
+  public long getPeakMemoryUsedBytes() {
+    updatePeakMemoryUsed();
+    return peakMemoryUsedBytes;
   }
 
   @VisibleForTesting
@@ -228,142 +239,115 @@ public int getNumberOfAllocatedPages() {
   }
 
   /**
-   * Free this sorter's in-memory data structures, including its data pages and pointer array.
+   * Free this sorter's data pages.
    *
    * @return the number of bytes freed.
    */
-  public long freeMemory() {
+  private long freeMemory() {
+    updatePeakMemoryUsed();
     long memoryFreed = 0;
     for (MemoryBlock block : allocatedPages) {
-      taskMemoryManager.freePage(block);
-      shuffleMemoryManager.release(block.size());
       memoryFreed += block.size();
-    }
-    if (inMemSorter != null) {
-      if (!isInMemSorterExternal) {
-        long sorterMemoryUsage = inMemSorter.getMemoryUsage();
-        memoryFreed += sorterMemoryUsage;
-        shuffleMemoryManager.release(sorterMemoryUsage);
-      }
-      inMemSorter = null;
+      freePage(block);
     }
     allocatedPages.clear();
     currentPage = null;
-    currentPagePosition = -1;
-    freeSpaceInCurrentPage = 0;
+    pageCursor = 0;
     return memoryFreed;
   }
 
   /**
    * Deletes any spill files created by this sorter.
    */
-  public void deleteSpillFiles() {
+  private void deleteSpillFiles() {
     for (UnsafeSorterSpillWriter spill : spillWriters) {
       File file = spill.getFile();
       if (file != null && file.exists()) {
         if (!file.delete()) {
           logger.error("Was unable to delete spill file {}", file.getAbsolutePath());
-        };
+        }
       }
     }
   }
 
   /**
-   * Checks whether there is enough space to insert a new record into the sorter.
-   *
-   * @param requiredSpace the required space in the data page, in bytes, including space for storing
-   *                      the record size.
-
-   * @return true if the record can be inserted without requiring more allocations, false otherwise.
+   * Frees this sorter's in-memory data structures and cleans up its spill files.
    */
-  private boolean haveSpaceForRecord(int requiredSpace) {
-    assert (requiredSpace > 0);
-    return (inMemSorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage));
+  public void cleanupResources() {
+    synchronized (this) {
+      deleteSpillFiles();
+      freeMemory();
+      if (inMemSorter != null) {
+        inMemSorter.free();
+        inMemSorter = null;
+      }
+    }
   }
 
   /**
-   * Allocates more memory in order to insert an additional record. This will request additional
-   * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be
-   * obtained.
-   *
-   * @param requiredSpace the required space in the data page, in bytes, including space for storing
-   *                      the record size.
+   * Checks whether there is enough space to insert an additional record into the sort pointer
+   * array and grows the array if additional space is required. If the required space cannot be
+   * obtained, then the in-memory data will be spilled to disk.
    */
-  private void allocateSpaceForRecord(int requiredSpace) throws IOException {
-    // TODO: merge these steps to first calculate total memory requirements for this insert,
-    // then try to acquire; no point in acquiring sort buffer only to spill due to no space in the
-    // data page.
+  private void growPointerArrayIfNecessary() throws IOException {
+    assert(inMemSorter != null);
     if (!inMemSorter.hasSpaceForAnotherRecord()) {
-      logger.debug("Attempting to expand sort pointer array");
-      final long oldPointerArrayMemoryUsage = inMemSorter.getMemoryUsage();
-      final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2;
-      final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray);
-      if (memoryAcquired < memoryToGrowPointerArray) {
-        shuffleMemoryManager.release(memoryAcquired);
-        spill();
+      long used = inMemSorter.getMemoryUsage();
+      LongArray array;
+      try {
+        // could trigger spilling
+        array = allocateArray(used / 8 * 2);
+      } catch (OutOfMemoryError e) {
+        // should have triggered spilling
+        assert(inMemSorter.hasSpaceForAnotherRecord());
+        return;
+      }
+      // check if spilling is triggered or not
+      if (inMemSorter.hasSpaceForAnotherRecord()) {
+        freeArray(array);
       } else {
-        inMemSorter.expandPointerArray();
-        shuffleMemoryManager.release(oldPointerArrayMemoryUsage);
+        inMemSorter.expandPointerArray(array);
       }
     }
+  }
 
-    if (requiredSpace > freeSpaceInCurrentPage) {
-      logger.trace("Required space {} is less than free space in current page ({})", requiredSpace,
-        freeSpaceInCurrentPage);
-      // TODO: we should track metrics on the amount of space wasted when we roll over to a new page
-      // without using the free space at the end of the current page. We should also do this for
-      // BytesToBytesMap.
-      if (requiredSpace > pageSizeBytes) {
-        throw new IOException("Required space " + requiredSpace + " is greater than page size (" +
-          pageSizeBytes + ")");
-      } else {
-        final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
-        if (memoryAcquired < pageSizeBytes) {
-          if (memoryAcquired > 0) {
-            shuffleMemoryManager.release(memoryAcquired);
-          }
-          spill();
-          final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
-          if (memoryAcquiredAfterSpilling != pageSizeBytes) {
-            shuffleMemoryManager.release(memoryAcquiredAfterSpilling);
-            throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
-          }
-        }
-        currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
-        currentPagePosition = currentPage.getBaseOffset();
-        freeSpaceInCurrentPage = pageSizeBytes;
-        allocatedPages.add(currentPage);
-      }
+  /**
+   * Allocates more memory in order to insert an additional record. This will request additional
+   * memory from the memory manager and spill if the requested memory can not be obtained.
+   *
+   * @param required the required space in the data page, in bytes, including space for storing
+   *                      the record size. This must be less than or equal to the page size (records
+   *                      that exceed the page size are handled via a different code path which uses
+   *                      special overflow pages).
+   */
+  private void acquireNewPageIfNecessary(int required) {
+    if (currentPage == null ||
+      pageCursor + required > currentPage.getBaseOffset() + currentPage.size()) {
+      // TODO: try to find space on previous pages
+      currentPage = allocatePage(required);
+      pageCursor = currentPage.getBaseOffset();
+      allocatedPages.add(currentPage);
     }
   }
 
   /**
    * Write a record to the sorter.
    */
-  public void insertRecord(
-      Object recordBaseObject,
-      long recordBaseOffset,
-      int lengthInBytes,
-      long prefix) throws IOException {
-    // Need 4 bytes to store the record length.
-    final int totalSpaceRequired = lengthInBytes + 4;
-    if (!haveSpaceForRecord(totalSpaceRequired)) {
-      allocateSpaceForRecord(totalSpaceRequired);
-    }
+  public void insertRecord(Object recordBase, long recordOffset, int length, long prefix)
+    throws IOException {
 
-    final long recordAddress =
-      taskMemoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition);
-    final Object dataPageBaseObject = currentPage.getBaseObject();
-    PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, lengthInBytes);
-    currentPagePosition += 4;
-    PlatformDependent.copyMemory(
-      recordBaseObject,
-      recordBaseOffset,
-      dataPageBaseObject,
-      currentPagePosition,
-      lengthInBytes);
-    currentPagePosition += lengthInBytes;
-    freeSpaceInCurrentPage -= totalSpaceRequired;
+    growPointerArrayIfNecessary();
+    // Need 4 bytes to store the record length.
+    final int required = length + 4;
+    acquireNewPageIfNecessary(required);
+
+    final Object base = currentPage.getBaseObject();
+    final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor);
+    Platform.putInt(base, pageCursor, length);
+    pageCursor += 4;
+    Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length);
+    pageCursor += length;
+    assert(inMemSorter != null);
     inMemSorter.insertRecord(recordAddress, prefix);
   }
 
@@ -375,51 +359,232 @@ public void insertRecord(
    *
    * record length = key length + value length + 4
    */
-  public void insertKVRecord(
-      Object keyBaseObj, long keyOffset, int keyLen,
-      Object valueBaseObj, long valueOffset, int valueLen, long prefix) throws IOException {
-    final int totalSpaceRequired = keyLen + valueLen + 4 + 4;
-    if (!haveSpaceForRecord(totalSpaceRequired)) {
-      allocateSpaceForRecord(totalSpaceRequired);
+  public void insertKVRecord(Object keyBase, long keyOffset, int keyLen,
+      Object valueBase, long valueOffset, int valueLen, long prefix)
+    throws IOException {
+
+    growPointerArrayIfNecessary();
+    final int required = keyLen + valueLen + 4 + 4;
+    acquireNewPageIfNecessary(required);
+
+    final Object base = currentPage.getBaseObject();
+    final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor);
+    Platform.putInt(base, pageCursor, keyLen + valueLen + 4);
+    pageCursor += 4;
+    Platform.putInt(base, pageCursor, keyLen);
+    pageCursor += 4;
+    Platform.copyMemory(keyBase, keyOffset, base, pageCursor, keyLen);
+    pageCursor += keyLen;
+    Platform.copyMemory(valueBase, valueOffset, base, pageCursor, valueLen);
+    pageCursor += valueLen;
+
+    assert(inMemSorter != null);
+    inMemSorter.insertRecord(recordAddress, prefix);
+  }
+
+  /**
+   * Merges another UnsafeExternalSorter into this one; the other one will be emptied.
+   *
+   * @throws IOException
+   */
+  public void merge(UnsafeExternalSorter other) throws IOException {
+    other.spill();
+    spillWriters.addAll(other.spillWriters);
+    // Remove them from `other.spillWriters`, or the files will be deleted by `other.cleanupResources()`.
+    other.spillWriters.clear();
+    other.cleanupResources();
+  }
+
+  /**
+   * Returns a sorted iterator. It is the caller's responsibility to call `cleanupResources()`
+   * after consuming this iterator.
+   */
+  public UnsafeSorterIterator getSortedIterator() throws IOException {
+    if (spillWriters.isEmpty()) {
+      assert(inMemSorter != null);
+      readingIterator = new SpillableIterator(inMemSorter.getSortedIterator());
+      return readingIterator;
+    } else {
+      final UnsafeSorterSpillMerger spillMerger =
+        new UnsafeSorterSpillMerger(recordComparator, prefixComparator, spillWriters.size());
+      for (UnsafeSorterSpillWriter spillWriter : spillWriters) {
+        spillMerger.addSpillIfNotEmpty(spillWriter.getReader(blockManager));
+      }
+      if (inMemSorter != null) {
+        readingIterator = new SpillableIterator(inMemSorter.getSortedIterator());
+        spillMerger.addSpillIfNotEmpty(readingIterator);
+      }
+      return spillMerger.getSortedIterator();
+    }
+  }
+
+  /**
+   * An UnsafeSorterIterator that supports spilling.
+   */
+  class SpillableIterator extends UnsafeSorterIterator {
+    private UnsafeSorterIterator upstream;
+    private UnsafeSorterIterator nextUpstream = null;
+    private MemoryBlock lastPage = null;
+    private boolean loaded = false;
+    private int numRecords = 0;
+
+    public SpillableIterator(UnsafeInMemorySorter.SortedIterator inMemIterator) {
+      this.upstream = inMemIterator;
+      this.numRecords = inMemIterator.numRecordsLeft();
     }
 
-    final long recordAddress =
-      taskMemoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition);
-    final Object dataPageBaseObject = currentPage.getBaseObject();
-    PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, keyLen + valueLen + 4);
-    currentPagePosition += 4;
+    public long spill() throws IOException {
+      synchronized (this) {
+        if (!(upstream instanceof UnsafeInMemorySorter.SortedIterator && nextUpstream == null
+          && numRecords > 0)) {
+          return 0L;
+        }
 
-    PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, keyLen);
-    currentPagePosition += 4;
+        UnsafeInMemorySorter.SortedIterator inMemIterator =
+          ((UnsafeInMemorySorter.SortedIterator) upstream).clone();
+
+        // Iterate over the records that have not been returned and spill them.
+        final UnsafeSorterSpillWriter spillWriter =
+          new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, numRecords);
+        while (inMemIterator.hasNext()) {
+          inMemIterator.loadNext();
+          final Object baseObject = inMemIterator.getBaseObject();
+          final long baseOffset = inMemIterator.getBaseOffset();
+          final int recordLength = inMemIterator.getRecordLength();
+          spillWriter.write(baseObject, baseOffset, recordLength, inMemIterator.getKeyPrefix());
+        }
+        spillWriter.close();
+        spillWriters.add(spillWriter);
+        nextUpstream = spillWriter.getReader(blockManager);
+
+        long released = 0L;
+        synchronized (UnsafeExternalSorter.this) {
+          // release the pages except the one that is used. There can still be a caller that
+          // is accessing the current record. We free this page in that caller's next loadNext()
+          // call.
+          for (MemoryBlock page : allocatedPages) {
+            if (!loaded || page.getBaseObject() != upstream.getBaseObject()) {
+              released += page.size();
+              freePage(page);
+            } else {
+              lastPage = page;
+            }
+          }
+          allocatedPages.clear();
+        }
 
-    PlatformDependent.copyMemory(
-      keyBaseObj, keyOffset, dataPageBaseObject, currentPagePosition, keyLen);
-    currentPagePosition += keyLen;
+        // in-memory sorter will not be used after spilling
+        assert(inMemSorter != null);
+        released += inMemSorter.getMemoryUsage();
+        inMemSorter.free();
+        inMemSorter = null;
+        return released;
+      }
+    }
 
-    PlatformDependent.copyMemory(
-      valueBaseObj, valueOffset, dataPageBaseObject, currentPagePosition, valueLen);
-    currentPagePosition += valueLen;
+    @Override
+    public boolean hasNext() {
+      return numRecords > 0;
+    }
 
-    freeSpaceInCurrentPage -= totalSpaceRequired;
-    inMemSorter.insertRecord(recordAddress, prefix);
+    @Override
+    public void loadNext() throws IOException {
+      synchronized (this) {
+        loaded = true;
+        if (nextUpstream != null) {
+          // Just consumed the last record from in memory iterator
+          if (lastPage != null) {
+            freePage(lastPage);
+            lastPage = null;
+          }
+          upstream = nextUpstream;
+          nextUpstream = null;
+        }
+        numRecords--;
+        upstream.loadNext();
+      }
+    }
+
+    @Override
+    public Object getBaseObject() {
+      return upstream.getBaseObject();
+    }
+
+    @Override
+    public long getBaseOffset() {
+      return upstream.getBaseOffset();
+    }
+
+    @Override
+    public int getRecordLength() {
+      return upstream.getRecordLength();
+    }
+
+    @Override
+    public long getKeyPrefix() {
+      return upstream.getKeyPrefix();
+    }
   }
 
-  public UnsafeSorterIterator getSortedIterator() throws IOException {
-    final UnsafeSorterIterator inMemoryIterator = inMemSorter.getSortedIterator();
-    int numIteratorsToMerge = spillWriters.size() + (inMemoryIterator.hasNext() ? 1 : 0);
+  /**
+   * Returns an iterator that returns the rows in the order in which they were inserted.
+   *
+   * It is the caller's responsibility to call `cleanupResources()`
+   * after consuming this iterator.
+   */
+  public UnsafeSorterIterator getIterator() throws IOException {
     if (spillWriters.isEmpty()) {
-      return inMemoryIterator;
+      assert(inMemSorter != null);
+      return inMemSorter.getIterator();
     } else {
-      final UnsafeSorterSpillMerger spillMerger =
-        new UnsafeSorterSpillMerger(recordComparator, prefixComparator, numIteratorsToMerge);
+      LinkedList<UnsafeSorterIterator> queue = new LinkedList<>();
       for (UnsafeSorterSpillWriter spillWriter : spillWriters) {
-        spillMerger.addSpill(spillWriter.getReader(blockManager));
+        queue.add(spillWriter.getReader(blockManager));
       }
-      spillWriters.clear();
-      if (inMemoryIterator.hasNext()) {
-        spillMerger.addSpill(inMemoryIterator);
+      if (inMemSorter != null) {
+        queue.add(inMemSorter.getIterator());
       }
-      return spillMerger.getSortedIterator();
+      return new ChainedIterator(queue);
     }
   }
+
+  /**
+   * Chains multiple UnsafeSorterIterators together into a single one.
+   */
+  class ChainedIterator extends UnsafeSorterIterator {
+
+    private final Queue<UnsafeSorterIterator> iterators;
+    private UnsafeSorterIterator current;
+
+    public ChainedIterator(Queue<UnsafeSorterIterator> iterators) {
+      assert iterators.size() > 0;
+      this.iterators = iterators;
+      this.current = iterators.remove();
+    }
+
+    @Override
+    public boolean hasNext() {
+      while (!current.hasNext() && !iterators.isEmpty()) {
+        current = iterators.remove();
+      }
+      return current.hasNext();
+    }
+
+    @Override
+    public void loadNext() throws IOException {
+      current.loadNext();
+    }
+
+    @Override
+    public Object getBaseObject() { return current.getBaseObject(); }
+
+    @Override
+    public long getBaseOffset() { return current.getBaseOffset(); }
+
+    @Override
+    public int getRecordLength() { return current.getRecordLength(); }
+
+    @Override
+    public long getKeyPrefix() { return current.getKeyPrefix(); }
+  }
 }
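
The ChainedIterator added above simply drains one iterator after another from a queue. As a point of reference, a minimal generic analogue over plain java.util iterators looks like the sketch below; the class name and structure are illustrative, not the Spark implementation, which additionally forwards base object, offset, length, and key prefix.

    import java.util.Arrays;
    import java.util.Iterator;
    import java.util.LinkedList;
    import java.util.NoSuchElementException;
    import java.util.Queue;

    // Minimal analogue of the ChainedIterator pattern: exhaust the current
    // iterator, then move on to the next one in the queue.
    public class ChainedIteratorSketch<T> implements Iterator<T> {
      private final Queue<Iterator<T>> iterators;
      private Iterator<T> current;

      public ChainedIteratorSketch(Queue<Iterator<T>> iterators) {
        assert !iterators.isEmpty();
        this.iterators = iterators;
        this.current = iterators.remove();
      }

      @Override
      public boolean hasNext() {
        while (!current.hasNext() && !iterators.isEmpty()) {
          current = iterators.remove();
        }
        return current.hasNext();
      }

      @Override
      public T next() {
        if (!hasNext()) {
          throw new NoSuchElementException();
        }
        return current.next();
      }

      public static void main(String[] args) {
        Queue<Iterator<Integer>> q = new LinkedList<>();
        q.add(Arrays.asList(1, 2).iterator());
        q.add(Arrays.asList(3).iterator());
        ChainedIteratorSketch<Integer> it = new ChainedIteratorSketch<>(q);
        while (it.hasNext()) {
          System.out.println(it.next());   // prints 1, 2, 3
        }
      }
    }
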
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index 313146539190..c16cbce9a0f6 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -19,9 +19,11 @@
 
 import java.util.Comparator;
 
-import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.memory.MemoryConsumer;
+import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.LongArray;
 import org.apache.spark.util.collection.Sorter;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
 
 /**
  * Sorts records using an AlphaSort-style key-prefix sort. This sort stores pointers to records
@@ -62,58 +64,84 @@ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) {
     }
   }
 
+  private final MemoryConsumer consumer;
   private final TaskMemoryManager memoryManager;
-  private final Sorter<RecordPointerAndKeyPrefix, long[]> sorter;
+  private final Sorter<RecordPointerAndKeyPrefix, LongArray> sorter;
   private final Comparator<RecordPointerAndKeyPrefix> sortComparator;
 
   /**
    * Within this buffer, position {@code 2 * i} holds a pointer to the record at
    * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix.
    */
-  private long[] pointerArray;
+  private LongArray array;
 
   /**
    * The position in the sort buffer where new records can be inserted.
    */
-  private int pointerArrayInsertPosition = 0;
+  private int pos = 0;
 
   public UnsafeInMemorySorter(
+    final MemoryConsumer consumer,
+    final TaskMemoryManager memoryManager,
+    final RecordComparator recordComparator,
+    final PrefixComparator prefixComparator,
+    int initialSize) {
+    this(consumer, memoryManager, recordComparator, prefixComparator,
+      consumer.allocateArray(initialSize * 2));
+  }
+
+  public UnsafeInMemorySorter(
+    final MemoryConsumer consumer,
       final TaskMemoryManager memoryManager,
       final RecordComparator recordComparator,
       final PrefixComparator prefixComparator,
-      int initialSize) {
-    assert (initialSize > 0);
-    this.pointerArray = new long[initialSize * 2];
+      LongArray array) {
+    this.consumer = consumer;
     this.memoryManager = memoryManager;
     this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE);
     this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager);
+    this.array = array;
+  }
+
+  /**
+   * Free the memory used by pointer array.
+   */
+  public void free() {
+    consumer.freeArray(array);
+    array = null;
+  }
+
+  public void reset() {
+    pos = 0;
   }
 
   /**
    * @return the number of records that have been inserted into this sorter.
    */
   public int numRecords() {
-    return pointerArrayInsertPosition / 2;
+    return pos / 2;
   }
 
   public long getMemoryUsage() {
-    return pointerArray.length * 8L;
-  }
-
-  static long getMemoryRequirementsForPointerArray(long numEntries) {
-    return numEntries * 2L * 8L;
+    return array.size() * 8L;
   }
 
   public boolean hasSpaceForAnotherRecord() {
-    return pointerArrayInsertPosition + 2 < pointerArray.length;
+    return pos + 2 <= array.size();
   }
 
-  public void expandPointerArray() {
-    final long[] oldArray = pointerArray;
-    // Guard against overflow:
-    final int newLength = oldArray.length * 2 > 0 ? (oldArray.length * 2) : Integer.MAX_VALUE;
-    pointerArray = new long[newLength];
-    System.arraycopy(oldArray, 0, pointerArray, 0, oldArray.length);
+  public void expandPointerArray(LongArray newArray) {
+    if (newArray.size() < array.size()) {
+      throw new OutOfMemoryError("Not enough memory to grow pointer array");
+    }
+    Platform.copyMemory(
+      array.getBaseObject(),
+      array.getBaseOffset(),
+      newArray.getBaseObject(),
+      newArray.getBaseOffset(),
+      array.size() * 8L);
+    consumer.freeArray(array);
+    array = newArray;
   }
 
   /**
@@ -125,47 +153,55 @@ public void expandPointerArray() {
    */
   public void insertRecord(long recordPointer, long keyPrefix) {
     if (!hasSpaceForAnotherRecord()) {
-      expandPointerArray();
+      expandPointerArray(consumer.allocateArray(array.size() * 2));
     }
-    pointerArray[pointerArrayInsertPosition] = recordPointer;
-    pointerArrayInsertPosition++;
-    pointerArray[pointerArrayInsertPosition] = keyPrefix;
-    pointerArrayInsertPosition++;
+    array.set(pos, recordPointer);
+    pos++;
+    array.set(pos, keyPrefix);
+    pos++;
   }
 
-  private static final class SortedIterator extends UnsafeSorterIterator {
+  public final class SortedIterator extends UnsafeSorterIterator {
 
-    private final TaskMemoryManager memoryManager;
-    private final int sortBufferInsertPosition;
-    private final long[] sortBuffer;
-    private int position = 0;
+    private final int numRecords;
+    private int position;
     private Object baseObject;
     private long baseOffset;
     private long keyPrefix;
     private int recordLength;
 
-    SortedIterator(
-        TaskMemoryManager memoryManager,
-        int sortBufferInsertPosition,
-        long[] sortBuffer) {
-      this.memoryManager = memoryManager;
-      this.sortBufferInsertPosition = sortBufferInsertPosition;
-      this.sortBuffer = sortBuffer;
+    private SortedIterator(int numRecords) {
+      this.numRecords = numRecords;
+      this.position = 0;
+    }
+
+    public SortedIterator clone() {
+      SortedIterator iter = new SortedIterator(numRecords);
+      iter.position = position;
+      iter.baseObject = baseObject;
+      iter.baseOffset = baseOffset;
+      iter.keyPrefix = keyPrefix;
+      iter.recordLength = recordLength;
+      return iter;
     }
 
     @Override
     public boolean hasNext() {
-      return position < sortBufferInsertPosition;
+      return position / 2 < numRecords;
+    }
+
+    public int numRecordsLeft() {
+      return numRecords - position / 2;
     }
 
     @Override
     public void loadNext() {
       // This pointer points to a 4-byte record length, followed by the record's bytes
-      final long recordPointer = sortBuffer[position];
+      final long recordPointer = array.get(position);
       baseObject = memoryManager.getPage(recordPointer);
       baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4;  // Skip over record length
-      recordLength = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset - 4);
-      keyPrefix = sortBuffer[position + 1];
+      recordLength = Platform.getInt(baseObject, baseOffset - 4);
+      keyPrefix = array.get(position + 1);
       position += 2;
     }
 
@@ -186,8 +222,15 @@ public void loadNext() {
    * Return an iterator over record pointers in sorted order. For efficiency, all calls to
    * {@code next()} will return the same mutable object.
    */
-  public UnsafeSorterIterator getSortedIterator() {
-    sorter.sort(pointerArray, 0, pointerArrayInsertPosition / 2, sortComparator);
-    return new SortedIterator(memoryManager, pointerArrayInsertPosition, pointerArray);
+  public SortedIterator getSortedIterator() {
+    sorter.sort(array, 0, pos / 2, sortComparator);
+    return new SortedIterator(pos / 2);
+  }
+
+  /**
+   * Returns an iterator over record pointers in original order (inserted).
+   */
+  public SortedIterator getIterator() {
+    return new SortedIterator(pos / 2);
   }
 }
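
The sorter above keeps pointer/prefix pairs interleaved: slot 2*i holds a record pointer, slot 2*i + 1 its 8-byte key prefix, and sorting compares prefixes first. A small hedged sketch of that layout and ordering (it boxes the pairs purely for illustration and omits the pointer-dereferencing tie-break the real SortComparator performs):

    import java.util.Arrays;
    import java.util.Comparator;

    // Sketch of the AlphaSort-style layout: slot 2*i is a record pointer,
    // slot 2*i + 1 is its key prefix; order is determined by the prefixes.
    public class PrefixSortSketch {
      public static void main(String[] args) {
        long[] array = {
            /* pointer, prefix */
            101L, 9L,
            102L, 3L,
            103L, 7L,
        };
        // Box the pairs just for this illustration; the real sorter sorts in place.
        Long[][] pairs = new Long[array.length / 2][];
        for (int i = 0; i < pairs.length; i++) {
          pairs[i] = new Long[] {array[2 * i], array[2 * i + 1]};
        }
        Arrays.sort(pairs, Comparator.comparingLong((Long[] p) -> p[1]));
        for (Long[] p : pairs) {
          System.out.println("pointer=" + p[0] + " prefix=" + p[1]);
        }
        // pointer=102 prefix=3, pointer=103 prefix=7, pointer=101 prefix=9
      }
    }
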
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java
index d09c728a7a63..d3137f5f31c2 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java
@@ -17,6 +17,9 @@
 
 package org.apache.spark.util.collection.unsafe.sort;
 
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.LongArray;
+import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.util.collection.SortDataFormat;
 
 /**
@@ -26,14 +29,14 @@
  * Within each long[] buffer, position {@code 2 * i} holds a pointer to the record at
  * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix.
  */
-final class UnsafeSortDataFormat extends SortDataFormat<RecordPointerAndKeyPrefix, long[]> {
+final class UnsafeSortDataFormat extends SortDataFormat<RecordPointerAndKeyPrefix, LongArray> {
 
   public static final UnsafeSortDataFormat INSTANCE = new UnsafeSortDataFormat();
 
   private UnsafeSortDataFormat() { }
 
   @Override
-  public RecordPointerAndKeyPrefix getKey(long[] data, int pos) {
+  public RecordPointerAndKeyPrefix getKey(LongArray data, int pos) {
     // Since we re-use keys, this method shouldn't be called.
     throw new UnsupportedOperationException();
   }
@@ -44,37 +47,43 @@ public RecordPointerAndKeyPrefix newKey() {
   }
 
   @Override
-  public RecordPointerAndKeyPrefix getKey(long[] data, int pos, RecordPointerAndKeyPrefix reuse) {
-    reuse.recordPointer = data[pos * 2];
-    reuse.keyPrefix = data[pos * 2 + 1];
+  public RecordPointerAndKeyPrefix getKey(LongArray data, int pos, RecordPointerAndKeyPrefix reuse) {
+    reuse.recordPointer = data.get(pos * 2);
+    reuse.keyPrefix = data.get(pos * 2 + 1);
     return reuse;
   }
 
   @Override
-  public void swap(long[] data, int pos0, int pos1) {
-    long tempPointer = data[pos0 * 2];
-    long tempKeyPrefix = data[pos0 * 2 + 1];
-    data[pos0 * 2] = data[pos1 * 2];
-    data[pos0 * 2 + 1] = data[pos1 * 2 + 1];
-    data[pos1 * 2] = tempPointer;
-    data[pos1 * 2 + 1] = tempKeyPrefix;
+  public void swap(LongArray data, int pos0, int pos1) {
+    long tempPointer = data.get(pos0 * 2);
+    long tempKeyPrefix = data.get(pos0 * 2 + 1);
+    data.set(pos0 * 2, data.get(pos1 * 2));
+    data.set(pos0 * 2 + 1, data.get(pos1 * 2 + 1));
+    data.set(pos1 * 2, tempPointer);
+    data.set(pos1 * 2 + 1, tempKeyPrefix);
   }
 
   @Override
-  public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) {
-    dst[dstPos * 2] = src[srcPos * 2];
-    dst[dstPos * 2 + 1] = src[srcPos * 2 + 1];
+  public void copyElement(LongArray src, int srcPos, LongArray dst, int dstPos) {
+    dst.set(dstPos * 2, src.get(srcPos * 2));
+    dst.set(dstPos * 2 + 1, src.get(srcPos * 2 + 1));
   }
 
   @Override
-  public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) {
-    System.arraycopy(src, srcPos * 2, dst, dstPos * 2, length * 2);
+  public void copyRange(LongArray src, int srcPos, LongArray dst, int dstPos, int length) {
+    Platform.copyMemory(
+      src.getBaseObject(),
+      src.getBaseOffset() + srcPos * 16,
+      dst.getBaseObject(),
+      dst.getBaseOffset() + dstPos * 16,
+      length * 16);
   }
 
   @Override
-  public long[] allocate(int length) {
+  public LongArray allocate(int length) {
     assert (length < Integer.MAX_VALUE / 2) : "Length " + length + " is too large";
-    return new long[length * 2];
+    // This is used as temporary buffer, it's fine to allocate from JVM heap.
+    return new LongArray(MemoryBlock.fromLongArray(new long[length * 2]));
   }
 
 }
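
The copyRange change above expresses the old System.arraycopy in terms of raw byte offsets: each logical element is a (pointer, prefix) pair, i.e. two longs or 16 bytes, which is where the factor of 16 comes from. A quick sanity sketch of that arithmetic over a plain long[] (illustrative only):

    import java.util.Arrays;

    // Each logical element of the sort buffer is two longs (pointer + prefix),
    // so element index e starts at long index 2 * e, i.e. byte offset 16 * e.
    public class CopyRangeSketch {
      static void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) {
        System.arraycopy(src, srcPos * 2, dst, dstPos * 2, length * 2);
      }

      public static void main(String[] args) {
        long[] src = {11L, 1L, 22L, 2L, 33L, 3L};   // three (pointer, prefix) pairs
        long[] dst = new long[6];
        copyRange(src, 1, dst, 0, 2);               // copy elements 1 and 2
        System.out.println(Arrays.toString(dst));   // [22, 2, 33, 3, 0, 0]
      }
    }
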
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
index 8272c2a5be0d..3874a9f9cbdb 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
@@ -47,11 +47,19 @@ public int compare(UnsafeSorterIterator left, UnsafeSorterIterator right) {
     priorityQueue = new PriorityQueue<UnsafeSorterIterator>(numSpills, comparator);
   }
 
-  public void addSpill(UnsafeSorterIterator spillReader) throws IOException {
+  /**
+   * Add an UnsafeSorterIterator to this merger, skipping it if it is empty.
+   */
+  public void addSpillIfNotEmpty(UnsafeSorterIterator spillReader) throws IOException {
     if (spillReader.hasNext()) {
+      // We only add the spillReader to the priorityQueue if it is not empty. Otherwise the
+      // hasNext method of the UnsafeSorterIterator returned by getSortedIterator could give a
+      // wrong answer, because hasNext returns true at least priorityQueue.size() times. If we
+      // allowed n empty spillReaders into the priorityQueue, the merged iterator would yield n
+      // extra empty records.
       spillReader.loadNext();
+      priorityQueue.add(spillReader);
     }
-    priorityQueue.add(spillReader);
   }
 
   public UnsafeSorterIterator getSortedIterator() throws IOException {
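
The reasoning behind addSpillIfNotEmpty is easiest to see with a toy k-way merge: every entry in the priority queue is assumed to have a current record already loaded, so an empty input would surface as a bogus record. The following hedged sketch over plain Java iterators mirrors that structure with illustrative names; it is not the Spark merger itself.

    import java.util.Arrays;
    import java.util.Iterator;
    import java.util.PriorityQueue;

    // Toy k-way merge: each queued entry has a current value already loaded,
    // so empty inputs must never be queued.
    public class KWayMergeSketch {
      static final class Head {
        int value;                       // the "loaded" current record
        final Iterator<Integer> rest;
        Head(int value, Iterator<Integer> rest) { this.value = value; this.rest = rest; }
      }

      public static void main(String[] args) {
        PriorityQueue<Head> queue =
            new PriorityQueue<>((a, b) -> Integer.compare(a.value, b.value));
        for (Iterator<Integer> it : Arrays.asList(
            Arrays.asList(1, 4).iterator(),
            Arrays.asList(2, 3).iterator(),
            Arrays.<Integer>asList().iterator())) {   // empty input: skipped below
          if (it.hasNext()) {                         // the addSpillIfNotEmpty guard
            queue.add(new Head(it.next(), it));
          }
        }
        while (!queue.isEmpty()) {
          Head head = queue.poll();
          System.out.println(head.value);             // prints 1, 2, 3, 4
          if (head.rest.hasNext()) {
            head.value = head.rest.next();
            queue.add(head);
          }
        }
      }
    }
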
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
index ca1ccedc93c8..dcb13e6581e5 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
@@ -20,18 +20,18 @@
 import java.io.*;
 
 import com.google.common.io.ByteStreams;
+import com.google.common.io.Closeables;
 
 import org.apache.spark.storage.BlockId;
 import org.apache.spark.storage.BlockManager;
-import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.Platform;
 
 /**
  * Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description
  * of the file format).
  */
-final class UnsafeSorterSpillReader extends UnsafeSorterIterator {
+public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implements Closeable {
 
-  private final File file;
   private InputStream in;
   private DataInputStream din;
 
@@ -42,18 +42,22 @@ final class UnsafeSorterSpillReader extends UnsafeSorterIterator {
 
   private byte[] arr = new byte[1024 * 1024];
   private Object baseObject = arr;
-  private final long baseOffset = PlatformDependent.BYTE_ARRAY_OFFSET;
+  private final long baseOffset = Platform.BYTE_ARRAY_OFFSET;
 
   public UnsafeSorterSpillReader(
       BlockManager blockManager,
       File file,
       BlockId blockId) throws IOException {
     assert (file.length() > 0);
-    this.file = file;
     final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file));
-    this.in = blockManager.wrapForCompression(blockId, bs);
-    this.din = new DataInputStream(this.in);
-    numRecordsRemaining = din.readInt();
+    try {
+      this.in = blockManager.wrapForCompression(blockId, bs);
+      this.din = new DataInputStream(this.in);
+      numRecordsRemaining = din.readInt();
+    } catch (IOException e) {
+      Closeables.close(bs, /* swallowIOException = */ true);
+      throw e;
+    }
   }
 
   @Override
@@ -72,10 +76,7 @@ public void loadNext() throws IOException {
     ByteStreams.readFully(in, arr, 0, recordLength);
     numRecordsRemaining--;
     if (numRecordsRemaining == 0) {
-      in.close();
-      file.delete();
-      in = null;
-      din = null;
+      close();
     }
   }
 
@@ -98,4 +99,16 @@ public int getRecordLength() {
   public long getKeyPrefix() {
     return keyPrefix;
   }
+
+  @Override
+  public void close() throws IOException {
+    if (in != null) {
+      try {
+        in.close();
+      } finally {
+        in = null;
+        din = null;
+      }
+    }
+  }
 }
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
index 44cf6c756d7c..234e21140a1d 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
@@ -28,14 +28,14 @@
 import org.apache.spark.storage.BlockManager;
 import org.apache.spark.storage.DiskBlockObjectWriter;
 import org.apache.spark.storage.TempLocalBlockId;
-import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.Platform;
 
 /**
  * Spills a list of sorted records to disk. Spill files have the following format:
  *
  *   [# of records (int)] [[len (int)][prefix (long)][data (bytes)]...]
  */
-final class UnsafeSorterSpillWriter {
+public final class UnsafeSorterSpillWriter {
 
   static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
 
@@ -117,11 +117,11 @@ public void write(
     long recordReadPosition = baseOffset;
     while (dataRemaining > 0) {
       final int toTransfer = Math.min(freeSpaceInWriteBuffer, dataRemaining);
-      PlatformDependent.copyMemory(
+      Platform.copyMemory(
         baseObject,
         recordReadPosition,
         writeBuffer,
-        PlatformDependent.BYTE_ARRAY_OFFSET + (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer),
+        Platform.BYTE_ARRAY_OFFSET + (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer),
         toTransfer);
       writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer) + toTransfer);
       recordReadPosition += toTransfer;
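
The spill file format referenced in the writer's javadoc above, [# of records (int)] [[len (int)][prefix (long)][data (bytes)]...], can be exercised with plain Java streams. This hedged sketch writes and reads that layout in memory; it is not the Spark writer itself and skips compression, the block manager, and the 1 MB copy buffer.

    import java.io.ByteArrayInputStream;
    import java.io.ByteArrayOutputStream;
    import java.io.DataInputStream;
    import java.io.DataOutputStream;
    import java.io.IOException;
    import java.nio.charset.StandardCharsets;

    // Write and read the spill layout: [# of records][[len][prefix][data bytes]...]
    public class SpillFormatSketch {
      public static void main(String[] args) throws IOException {
        byte[][] records = {"foo".getBytes(StandardCharsets.UTF_8),
                            "quux".getBytes(StandardCharsets.UTF_8)};
        long[] prefixes = {1L, 2L};

        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        DataOutputStream out = new DataOutputStream(bytes);
        out.writeInt(records.length);                 // number of records
        for (int i = 0; i < records.length; i++) {
          out.writeInt(records[i].length);            // record length
          out.writeLong(prefixes[i]);                 // key prefix
          out.write(records[i]);                      // record payload
        }
        out.flush();

        DataInputStream in = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray()));
        int numRecords = in.readInt();
        for (int i = 0; i < numRecords; i++) {
          int len = in.readInt();
          long prefix = in.readLong();
          byte[] data = new byte[len];
          in.readFully(data);
          System.out.println(prefix + " -> " + new String(data, StandardCharsets.UTF_8));
        }
      }
    }
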
diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties
deleted file mode 100644
index 689afea64f8d..000000000000
--- a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties
+++ /dev/null
@@ -1,16 +0,0 @@
-# Set everything to be logged to the console
-log4j.rootCategory=WARN, console
-log4j.appender.console=org.apache.log4j.ConsoleAppender
-log4j.appender.console.target=System.err
-log4j.appender.console.layout=org.apache.log4j.PatternLayout
-log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
-
-# Settings to quiet third party logs that are too verbose
-log4j.logger.org.spark-project.jetty=WARN
-log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR
-log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO
-log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO
-
-# SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support
-log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=FATAL
-log4j.logger.org.apache.hadoop.hive.ql.exec.FunctionRegistry=ERROR
diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties
index 27006e45e932..0750488e4adf 100644
--- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties
+++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties
@@ -1,3 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
 # Set everything to be logged to the console
 log4j.rootCategory=INFO, console
 log4j.appender.console=org.apache.log4j.ConsoleAppender
@@ -5,6 +22,11 @@ log4j.appender.console.target=System.err
 log4j.appender.console.layout=org.apache.log4j.PatternLayout
 log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
 
+# Set the default spark-shell log level to WARN. When running the spark-shell, the
+# log level for this class is used to override the root logger's log level, so that
+# the user can have different defaults for the shell and regular Spark apps.
+log4j.logger.org.apache.spark.repl.Main=WARN
+
 # Settings to quiet third party logs that are too verbose
 log4j.logger.org.spark-project.jetty=WARN
 log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR
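With the WARN default above applied to the `org.apache.spark.repl.Main` logger, a shell user who wants more or less output can still change the effective level at runtime through the SparkContext, for example:

```scala
// Inside spark-shell, where `sc` is the SparkContext created by the shell.
sc.setLogLevel("INFO")  // temporarily more verbose than the WARN shell default
sc.setLogLevel("WARN")  // restore the shell default
```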
diff --git a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js
index dde6069000bc..ff241470f32d 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js
@@ -89,7 +89,7 @@ sorttable = {
         // make it clickable to sort
         headrow[i].sorttable_columnindex = i;
         headrow[i].sorttable_tbody = table.tBodies[0];
-        dean_addEvent(headrow[i],"click", function(e) {
+        dean_addEvent(headrow[i],"click", sorttable.innerSortFunction = function(e) {
 
           if (this.className.search(/\bsorttable_sorted\b/) != -1) {
             // if we're already sorted by this column, just 
@@ -169,7 +169,7 @@ sorttable = {
     for (var i=0; i
" +
+            stackTraceText +  "
") + } else { + if (!forceAdd) { + stackTrace.remove() + } + } +} + +function expandAllThreadStackTrace(toggleButton) { + $('.accordion-heading').each(function() { + //get thread ID + if (!$(this).hasClass("hidden")) { + var trId = $(this).attr('id').match(/thread_([0-9]+)_tr/m)[1] + toggleThreadStackTrace(trId, true) + } + }) + if (toggleButton) { + $('.expandbutton').toggleClass('hidden') + } +} + +function collapseAllThreadStackTrace(toggleButton) { + $('.accordion-body').each(function() { + $(this).remove() + }) + if (toggleButton) { + $('.expandbutton').toggleClass('hidden'); + } +} + + +// inOrOut - true: over, false: out +function onMouseOverAndOut(threadId) { + $("#" + threadId + "_td_id").toggleClass("threaddump-td-mouseover"); + $("#" + threadId + "_td_name").toggleClass("threaddump-td-mouseover"); + $("#" + threadId + "_td_state").toggleClass("threaddump-td-mouseover"); +} + +function onSearchStringChange() { + var searchString = $('#search').val().toLowerCase(); + //remove the stacktrace + collapseAllThreadStackTrace(false) + if (searchString.length == 0) { + $('tr').each(function() { + $(this).removeClass('hidden') + }) + } else { + $('tr').each(function(){ + if($(this).attr('id') && $(this).attr('id').match(/thread_[0-9]+_tr/) ) { + var children = $(this).children() + var found = false + for (i = 0; i < children.length; i++) { + if (children.eq(i).text().toLowerCase().indexOf(searchString) >= 0) { + found = true + } + } + if (found) { + $(this).removeClass('hidden') + } else { + $(this).addClass('hidden') + } + } + }); + } +} diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index b1cef4704224..b54e33a96fa2 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -16,14 +16,9 @@ */ .navbar { - height: 50px; font-size: 15px; margin-bottom: 15px; - min-width: 1200px -} - -.navbar .navbar-inner { - height: 50px; + min-width: 600px; } .navbar .brand { @@ -46,6 +41,7 @@ .navbar-text { height: 50px; line-height: 3.3; + white-space: nowrap; } table.sortable thead { @@ -207,7 +203,7 @@ span.additional-metric-title { /* Hide all additional metrics by default. This is done here rather than using JavaScript to * avoid slow page loads for stage pages with large numbers (e.g., thousands) of tasks. 
*/ .scheduler_delay, .deserialization_time, .fetch_wait_time, .shuffle_read_remote, -.serialization_time, .getting_result_time { +.serialization_time, .getting_result_time, .peak_execution_memory { display: none; } @@ -224,3 +220,9 @@ span.additional-metric-title { a.expandbutton { cursor: pointer; } + +.threaddump-td-mouseover { + background-color: #49535a !important; + color: white; + cursor:pointer; +} \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index eb75f26718e1..5592b75afb75 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -47,7 +47,7 @@ import org.apache.spark.util.Utils * @tparam T partial data that can be added in */ class Accumulable[R, T] private[spark] ( - @transient initialValue: R, + initialValue: R, param: AccumulableParam[R, T], val name: Option[String], internal: Boolean) @@ -152,8 +152,15 @@ class Accumulable[R, T] private[spark] ( in.defaultReadObject() value_ = zero deserialized = true + // Automatically register the accumulator when it is deserialized with the task closure. + // + // Note internal accumulators sent with task are deserialized before the TaskContext is created + // and are registered in the TaskContext constructor. Other internal accumulators, such SQL + // metrics, still need to register here. val taskContext = TaskContext.get() - taskContext.registerAccumulator(this) + if (taskContext != null) { + taskContext.registerAccumulator(this) + } } override def toString: String = if (value_ == null) "null" else value_.toString @@ -248,10 +255,20 @@ GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializa * @param param helper object defining how to add elements of type `T` * @tparam T result type */ -class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T], name: Option[String]) - extends Accumulable[T, T](initialValue, param, name) { +class Accumulator[T] private[spark] ( + @transient private[spark] val initialValue: T, + param: AccumulatorParam[T], + name: Option[String], + internal: Boolean) + extends Accumulable[T, T](initialValue, param, name, internal) { + + def this(initialValue: T, param: AccumulatorParam[T], name: Option[String]) = { + this(initialValue, param, name, false) + } - def this(initialValue: T, param: AccumulatorParam[T]) = this(initialValue, param, None) + def this(initialValue: T, param: AccumulatorParam[T]) = { + this(initialValue, param, None, false) + } } /** @@ -342,3 +359,41 @@ private[spark] object Accumulators extends Logging { } } + +private[spark] object InternalAccumulator { + val PEAK_EXECUTION_MEMORY = "peakExecutionMemory" + val TEST_ACCUMULATOR = "testAccumulator" + + // For testing only. + // This needs to be a def since we don't want to reuse the same accumulator across stages. + private def maybeTestAccumulator: Option[Accumulator[Long]] = { + if (sys.props.contains("spark.testing")) { + Some(new Accumulator( + 0L, AccumulatorParam.LongAccumulatorParam, Some(TEST_ACCUMULATOR), internal = true)) + } else { + None + } + } + + /** + * Accumulators for tracking internal metrics. + * + * These accumulators are created with the stage such that all tasks in the stage will + * add to the same set of accumulators. We do this to report the distribution of accumulator + * values across all tasks within each stage. 
+ */ + def create(sc: SparkContext): Seq[Accumulator[Long]] = { + val internalAccumulators = Seq( + // Execution memory refers to the memory used by internal data structures created + // during shuffles, aggregations and joins. The value of this accumulator should be + // approximately the sum of the peak sizes across all such data structures created + // in this task. For SQL jobs, this only tracks all unsafe operators and ExternalSort. + new Accumulator( + 0L, AccumulatorParam.LongAccumulatorParam, Some(PEAK_EXECUTION_MEMORY), internal = true) + ) ++ maybeTestAccumulator.toSeq + internalAccumulators.foreach { accumulator => + sc.cleaner.foreach(_.registerAccumulatorForCleanup(accumulator)) + } + internalAccumulators + } +} diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala index ceeb58075d34..7196e57d5d2e 100644 --- a/core/src/main/scala/org/apache/spark/Aggregator.scala +++ b/core/src/main/scala/org/apache/spark/Aggregator.scala @@ -18,7 +18,7 @@ package org.apache.spark import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.util.collection.{AppendOnlyMap, ExternalAppendOnlyMap} +import org.apache.spark.util.collection.ExternalAppendOnlyMap /** * :: DeveloperApi :: @@ -34,68 +34,39 @@ case class Aggregator[K, V, C] ( mergeValue: (C, V) => C, mergeCombiners: (C, C) => C) { - // When spilling is enabled sorting will happen externally, but not necessarily with an - // ExternalSorter. - private val isSpillEnabled = SparkEnv.get.conf.getBoolean("spark.shuffle.spill", true) - @deprecated("use combineValuesByKey with TaskContext argument", "0.9.0") def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]): Iterator[(K, C)] = combineValuesByKey(iter, null) - def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]], - context: TaskContext): Iterator[(K, C)] = { - if (!isSpillEnabled) { - val combiners = new AppendOnlyMap[K, C] - var kv: Product2[K, V] = null - val update = (hadValue: Boolean, oldValue: C) => { - if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2) - } - while (iter.hasNext) { - kv = iter.next() - combiners.changeValue(kv._1, update) - } - combiners.iterator - } else { - val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners) - combiners.insertAll(iter) - // Update task metrics if context is not null - // TODO: Make context non optional in a future release - Option(context).foreach { c => - c.taskMetrics.incMemoryBytesSpilled(combiners.memoryBytesSpilled) - c.taskMetrics.incDiskBytesSpilled(combiners.diskBytesSpilled) - } - combiners.iterator - } + def combineValuesByKey( + iter: Iterator[_ <: Product2[K, V]], + context: TaskContext): Iterator[(K, C)] = { + val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners) + combiners.insertAll(iter) + updateMetrics(context, combiners) + combiners.iterator } @deprecated("use combineCombinersByKey with TaskContext argument", "0.9.0") def combineCombinersByKey(iter: Iterator[_ <: Product2[K, C]]) : Iterator[(K, C)] = combineCombinersByKey(iter, null) - def combineCombinersByKey(iter: Iterator[_ <: Product2[K, C]], context: TaskContext) - : Iterator[(K, C)] = - { - if (!isSpillEnabled) { - val combiners = new AppendOnlyMap[K, C] - var kc: Product2[K, C] = null - val update = (hadValue: Boolean, oldValue: C) => { - if (hadValue) mergeCombiners(oldValue, kc._2) else kc._2 - } - while (iter.hasNext) { - kc = iter.next() - combiners.changeValue(kc._1, update) - } 
- combiners.iterator - } else { - val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners) - combiners.insertAll(iter) - // Update task metrics if context is not null - // TODO: Make context non-optional in a future release - Option(context).foreach { c => - c.taskMetrics.incMemoryBytesSpilled(combiners.memoryBytesSpilled) - c.taskMetrics.incDiskBytesSpilled(combiners.diskBytesSpilled) - } - combiners.iterator + def combineCombinersByKey( + iter: Iterator[_ <: Product2[K, C]], + context: TaskContext): Iterator[(K, C)] = { + val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners) + combiners.insertAll(iter) + updateMetrics(context, combiners) + combiners.iterator + } + + /** Update task metrics after populating the external map. */ + private def updateMetrics(context: TaskContext, map: ExternalAppendOnlyMap[_, _, _]): Unit = { + Option(context).foreach { c => + c.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled) + c.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled) + c.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes) } } } diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index d23c1533db75..bc732535fed8 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -18,12 +18,13 @@ package org.apache.spark import java.lang.ref.{ReferenceQueue, WeakReference} +import java.util.concurrent.{TimeUnit, ScheduledExecutorService} import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{RDD, ReliableRDDCheckpointData} -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} /** * Classes that represent cleaning tasks. @@ -66,6 +67,20 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { private val cleaningThread = new Thread() { override def run() { keepCleaning() }} + private val periodicGCService: ScheduledExecutorService = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("context-cleaner-periodic-gc") + + /** + * How often to trigger a garbage collection in this JVM. + * + * This context cleaner triggers cleanups only when weak references are garbage collected. + * In long-running applications with large driver JVMs, where there is little memory pressure + * on the driver, this may happen very occasionally or not at all. Not cleaning at all may + * lead to executors running out of disk space after a while. + */ + private val periodicGCInterval = + sc.conf.getTimeAsSeconds("spark.cleaner.periodicGC.interval", "30min") + /** * Whether the cleaning thread will block on cleanup tasks (other than shuffle, which * is controlled by the `spark.cleaner.referenceTracking.blocking.shuffle` parameter). 
@@ -104,6 +119,9 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { cleaningThread.setDaemon(true) cleaningThread.setName("Spark Context Cleaner") cleaningThread.start() + periodicGCService.scheduleAtFixedRate(new Runnable { + override def run(): Unit = System.gc() + }, periodicGCInterval, periodicGCInterval, TimeUnit.SECONDS) } /** @@ -119,6 +137,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { cleaningThread.interrupt() } cleaningThread.join() + periodicGCService.shutdown() } /** Register a RDD for cleanup when it is garbage collected. */ diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index fc8cdde9348e..9aafc9eb1cde 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -17,6 +17,8 @@ package org.apache.spark +import scala.reflect.ClassTag + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer @@ -65,8 +67,8 @@ abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] { * @param mapSideCombine whether to perform partial aggregation (also known as map-side combine) */ @DeveloperApi -class ShuffleDependency[K, V, C]( - @transient _rdd: RDD[_ <: Product2[K, V]], +class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( + @transient private val _rdd: RDD[_ <: Product2[K, V]], val partitioner: Partitioner, val serializer: Option[Serializer] = None, val keyOrdering: Option[Ordering[K]] = None, @@ -76,6 +78,13 @@ class ShuffleDependency[K, V, C]( override def rdd: RDD[Product2[K, V]] = _rdd.asInstanceOf[RDD[Product2[K, V]]] + private[spark] val keyClassName: String = reflect.classTag[K].runtimeClass.getName + private[spark] val valueClassName: String = reflect.classTag[V].runtimeClass.getName + // Note: It's possible that the combiner class tag is null, if the combineByKey + // methods in PairRDDFunctions are used instead of combineByKeyWithClassTag. + private[spark] val combinerClassName: Option[String] = + Option(reflect.classTag[C]).map(_.runtimeClass.getName) + val shuffleId: Int = _rdd.context.newShuffleId() val shuffleHandle: ShuffleHandle = _rdd.context.env.shuffleManager.registerShuffle( diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 1877aaf2cac5..6176e258989d 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -89,6 +89,8 @@ private[spark] class ExecutorAllocationManager( private val minNumExecutors = conf.getInt("spark.dynamicAllocation.minExecutors", 0) private val maxNumExecutors = conf.getInt("spark.dynamicAllocation.maxExecutors", Integer.MAX_VALUE) + private val initialNumExecutors = conf.getInt("spark.dynamicAllocation.initialExecutors", + minNumExecutors) // How long there must be backlogged tasks for before an addition is triggered (seconds) private val schedulerBacklogTimeoutS = conf.getTimeAsSeconds( @@ -121,8 +123,7 @@ private[spark] class ExecutorAllocationManager( // The desired number of executors at this moment in time. If all our executors were to die, this // is the number of executors we would immediately want from the cluster manager. 
- private var numExecutorsTarget = - conf.getInt("spark.dynamicAllocation.initialExecutors", minNumExecutors) + private var numExecutorsTarget = initialNumExecutors // Executors that have been requested to be removed but have not been killed yet private val executorsPendingToRemove = new mutable.HashSet[String] @@ -240,6 +241,19 @@ private[spark] class ExecutorAllocationManager( executor.awaitTermination(10, TimeUnit.SECONDS) } + /** + * Reset the allocation manager to the initial state. Currently this will only be called in + * yarn-client mode when AM re-registers after a failure. + */ + def reset(): Unit = synchronized { + initializing = true + numExecutorsTarget = initialNumExecutors + numExecutorsToAdd = 1 + + executorsPendingToRemove.clear() + removeTimes.clear() + } + /** * The maximum number of executors we would need under the current load to satisfy all running * and pending tasks, rounded up. @@ -370,6 +384,7 @@ private[spark] class ExecutorAllocationManager( } else { logWarning( s"Unable to reach the cluster manager to request $numExecutorsTarget total executors!") + numExecutorsTarget = oldNumExecutorsTarget 0 } } @@ -509,6 +524,7 @@ private[spark] class ExecutorAllocationManager( private def onExecutorBusy(executorId: String): Unit = synchronized { logDebug(s"Clearing idle timer for $executorId because it is now running a task") removeTimes.remove(executorId) + executorsPendingToRemove.remove(executorId) } /** @@ -599,14 +615,8 @@ private[spark] class ExecutorAllocationManager( // If this is the last pending task, mark the scheduler queue as empty stageIdToTaskIndices.getOrElseUpdate(stageId, new mutable.HashSet[Int]) += taskIndex - val numTasksScheduled = stageIdToTaskIndices(stageId).size - val numTasksTotal = stageIdToNumTasks.getOrElse(stageId, -1) - if (numTasksScheduled == numTasksTotal) { - // No more pending tasks for this stage - stageIdToNumTasks -= stageId - if (stageIdToNumTasks.isEmpty) { - allocationManager.onSchedulerQueueEmpty() - } + if (totalPendingTasks() == 0) { + allocationManager.onSchedulerQueueEmpty() } // Mark the executor on which this task is scheduled as busy @@ -618,6 +628,8 @@ private[spark] class ExecutorAllocationManager( override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { val executorId = taskEnd.taskInfo.executorId val taskId = taskEnd.taskInfo.taskId + val taskIndex = taskEnd.taskInfo.index + val stageId = taskEnd.stageId allocationManager.synchronized { numRunningTasks -= 1 // If the executor is no longer running any scheduled tasks, mark it as idle @@ -628,6 +640,16 @@ private[spark] class ExecutorAllocationManager( allocationManager.onExecutorIdle(executorId) } } + + // If the task failed, we expect it to be resubmitted later. 
To ensure we have + // enough resources to run the resubmitted task, we need to mark the scheduler + // as backlogged again if it's not already marked as such (SPARK-8366) + if (taskEnd.reason != Success) { + if (totalPendingTasks() == 0) { + allocationManager.onSchedulerBacklogged() + } + stageIdToTaskIndices.get(stageId).foreach { _.remove(taskIndex) } + } } } diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index 48792a958130..2a8220ff4009 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -20,13 +20,15 @@ package org.apache.spark import java.util.Collections import java.util.concurrent.TimeUnit +import scala.concurrent._ +import scala.concurrent.duration.Duration +import scala.util.Try + +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.JavaFutureAction import org.apache.spark.rdd.RDD -import org.apache.spark.scheduler.{JobFailed, JobSucceeded, JobWaiter} +import org.apache.spark.scheduler.JobWaiter -import scala.concurrent._ -import scala.concurrent.duration.Duration -import scala.util.{Failure, Try} /** * A future for the result of an action to support cancellation. This is an extension of the @@ -105,6 +107,7 @@ trait FutureAction[T] extends Future[T] { * A [[FutureAction]] holding the result of an action that triggers a single job. Examples include * count, collect, reduce. */ +@DeveloperApi class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: => T) extends FutureAction[T] { @@ -116,142 +119,96 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: } override def ready(atMost: Duration)(implicit permit: CanAwait): SimpleFutureAction.this.type = { - if (!atMost.isFinite()) { - awaitResult() - } else jobWaiter.synchronized { - val finishTime = System.currentTimeMillis() + atMost.toMillis - while (!isCompleted) { - val time = System.currentTimeMillis() - if (time >= finishTime) { - throw new TimeoutException - } else { - jobWaiter.wait(finishTime - time) - } - } - } + jobWaiter.completionFuture.ready(atMost) this } @throws(classOf[Exception]) override def result(atMost: Duration)(implicit permit: CanAwait): T = { - ready(atMost)(permit) - awaitResult() match { - case scala.util.Success(res) => res - case scala.util.Failure(e) => throw e - } + jobWaiter.completionFuture.ready(atMost) + assert(value.isDefined, "Future has not completed properly") + value.get.get } override def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext) { - executor.execute(new Runnable { - override def run() { - func(awaitResult()) - } - }) + jobWaiter.completionFuture onComplete {_ => func(value.get)} } override def isCompleted: Boolean = jobWaiter.jobFinished override def isCancelled: Boolean = _cancelled - override def value: Option[Try[T]] = { - if (jobWaiter.jobFinished) { - Some(awaitResult()) - } else { - None - } - } - - private def awaitResult(): Try[T] = { - jobWaiter.awaitResult() match { - case JobSucceeded => scala.util.Success(resultFunc) - case JobFailed(e: Exception) => scala.util.Failure(e) - } - } + override def value: Option[Try[T]] = + jobWaiter.completionFuture.value.map {res => res.map(_ => resultFunc)} def jobIds: Seq[Int] = Seq(jobWaiter.jobId) } +/** + * Handle via which a "run" function passed to a [[ComplexFutureAction]] + * can submit jobs for execution. 
+ */ +@DeveloperApi +trait JobSubmitter { + /** + * Submit a job for execution and return a FutureAction holding the result. + * This is a wrapper around the same functionality provided by SparkContext + * to enable cancellation. + */ + def submitJob[T, U, R]( + rdd: RDD[T], + processPartition: Iterator[T] => U, + partitions: Seq[Int], + resultHandler: (Int, U) => Unit, + resultFunc: => R): FutureAction[R] +} + + /** * A [[FutureAction]] for actions that could trigger multiple Spark jobs. Examples include take, - * takeSample. Cancellation works by setting the cancelled flag to true and interrupting the - * action thread if it is being blocked by a job. + * takeSample. Cancellation works by setting the cancelled flag to true and cancelling any pending + * jobs. */ -class ComplexFutureAction[T] extends FutureAction[T] { +@DeveloperApi +class ComplexFutureAction[T](run : JobSubmitter => Future[T]) + extends FutureAction[T] { self => - // Pointer to the thread that is executing the action. It is set when the action is run. - @volatile private var thread: Thread = _ + @volatile private var _cancelled = false - // A flag indicating whether the future has been cancelled. This is used in case the future - // is cancelled before the action was even run (and thus we have no thread to interrupt). - @volatile private var _cancelled: Boolean = false - - @volatile private var jobs: Seq[Int] = Nil + @volatile private var subActions: List[FutureAction[_]] = Nil // A promise used to signal the future. - private val p = promise[T]() + private val p = Promise[T]().tryCompleteWith(run(jobSubmitter)) - override def cancel(): Unit = this.synchronized { + override def cancel(): Unit = synchronized { _cancelled = true - if (thread != null) { - thread.interrupt() - } - } - - /** - * Executes some action enclosed in the closure. To properly enable cancellation, the closure - * should use runJob implementation in this promise. See takeAsync for example. - */ - def run(func: => T)(implicit executor: ExecutionContext): this.type = { - scala.concurrent.future { - thread = Thread.currentThread - try { - p.success(func) - } catch { - case e: Exception => p.failure(e) - } finally { - // This lock guarantees when calling `thread.interrupt()` in `cancel`, - // thread won't be set to null. - ComplexFutureAction.this.synchronized { - thread = null - } - } - } - this + p.tryFailure(new SparkException("Action has been cancelled")) + subActions.foreach(_.cancel()) } - /** - * Runs a Spark job. This is a wrapper around the same functionality provided by SparkContext - * to enable cancellation. - */ - def runJob[T, U, R]( + private def jobSubmitter = new JobSubmitter { + def submitJob[T, U, R]( rdd: RDD[T], processPartition: Iterator[T] => U, partitions: Seq[Int], resultHandler: (Int, U) => Unit, - resultFunc: => R) { - // If the action hasn't been cancelled yet, submit the job. The check and the submitJob - // command need to be in an atomic block. - val job = this.synchronized { + resultFunc: => R): FutureAction[R] = self.synchronized { + // If the action hasn't been cancelled yet, submit the job. The check and the submitJob + // command need to be in an atomic block. 
if (!isCancelled) { - rdd.context.submitJob(rdd, processPartition, partitions, resultHandler, resultFunc) + val job = rdd.context.submitJob( + rdd, + processPartition, + partitions, + resultHandler, + resultFunc) + subActions = job :: subActions + job } else { throw new SparkException("Action has been cancelled") } } - - this.jobs = jobs ++ job.jobIds - - // Wait for the job to complete. If the action is cancelled (with an interrupt), - // cancel the job and stop the execution. This is not in a synchronized block because - // Await.ready eventually waits on the monitor in FutureJob.jobWaiter. - try { - Await.ready(job, Duration.Inf) - } catch { - case e: InterruptedException => - job.cancel() - throw new SparkException("Action has been cancelled") - } } override def isCancelled: Boolean = _cancelled @@ -276,10 +233,11 @@ class ComplexFutureAction[T] extends FutureAction[T] { override def value: Option[Try[T]] = p.future.value - def jobIds: Seq[Int] = jobs + def jobIds: Seq[Int] = subActions.flatMap(_.jobIds) } + private[spark] class JavaFutureActionWrapper[S, T](futureAction: FutureAction[S], converter: S => T) extends JavaFutureAction[T] { @@ -303,7 +261,7 @@ class JavaFutureActionWrapper[S, T](futureAction: FutureAction[S], converter: S Await.ready(futureAction, timeout) futureAction.value.get match { case scala.util.Success(value) => converter(value) - case Failure(exception) => + case scala.util.Failure(exception) => if (isCancelled) { throw new CancellationException("Job cancelled").initCause(exception) } else { diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index ee60d697d879..1f1f0b75de5f 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -20,6 +20,7 @@ package org.apache.spark import java.util.concurrent.{ScheduledFuture, TimeUnit} import scala.collection.mutable +import scala.concurrent.Future import org.apache.spark.executor.TaskMetrics import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv, RpcCallContext} @@ -147,11 +148,31 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock) } } + /** + * Send ExecutorRegistered to the event loop to add a new executor. Only for test. + * + * @return if HeartbeatReceiver is stopped, return None. Otherwise, return a Some(Future) that + * indicate if this operation is successful. + */ + def addExecutor(executorId: String): Option[Future[Boolean]] = { + Option(self).map(_.ask[Boolean](ExecutorRegistered(executorId))) + } + /** * If the heartbeat receiver is not stopped, notify it of executor registrations. */ override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = { - Option(self).foreach(_.ask[Boolean](ExecutorRegistered(executorAdded.executorId))) + addExecutor(executorAdded.executorId) + } + + /** + * Send ExecutorRemoved to the event loop to remove a executor. Only for test. + * + * @return if HeartbeatReceiver is stopped, return None. Otherwise, return a Some(Future) that + * indicate if this operation is successful. + */ + def removeExecutor(executorId: String): Option[Future[Boolean]] = { + Option(self).map(_.ask[Boolean](ExecutorRemoved(executorId))) } /** @@ -165,7 +186,7 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock) * and expire it with loud error messages. 
*/ override def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit = { - Option(self).foreach(_.ask[Boolean](ExecutorRemoved(executorRemoved.executorId))) + removeExecutor(executorRemoved.executorId) } private def expireDeadHosts(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala index 7cf7bc0dc681..46f9f9e9af7d 100644 --- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala @@ -63,12 +63,17 @@ private[spark] class HttpFileServer( def addFile(file: File) : String = { addFileToDir(file, fileDir) - serverUri + "/files/" + file.getName + serverUri + "/files/" + Utils.encodeFileNameToURIRawPath(file.getName) } def addJar(file: File) : String = { addFileToDir(file, jarDir) - serverUri + "/jars/" + file.getName + serverUri + "/jars/" + Utils.encodeFileNameToURIRawPath(file.getName) + } + + def addDirectory(path: String, resourceBase: String): String = { + httpServer.addDirectory(path, resourceBase) + serverUri + path } def addFileToDir(file: File, dir: File) : String = { @@ -80,7 +85,7 @@ private[spark] class HttpFileServer( throw new IllegalArgumentException(s"$file cannot be a directory.") } Files.copy(file, new File(dir, file.getName)) - dir + "/" + file.getName + dir + "/" + Utils.encodeFileNameToURIRawPath(file.getName) } } diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala index 8de3a6c04df3..faa3ef3d7561 100644 --- a/core/src/main/scala/org/apache/spark/HttpServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpServer.scala @@ -23,10 +23,9 @@ import org.eclipse.jetty.server.ssl.SslSocketConnector import org.eclipse.jetty.util.security.{Constraint, Password} import org.eclipse.jetty.security.authentication.DigestAuthenticator import org.eclipse.jetty.security.{ConstraintMapping, ConstraintSecurityHandler, HashLoginService} - import org.eclipse.jetty.server.Server import org.eclipse.jetty.server.bio.SocketConnector -import org.eclipse.jetty.server.handler.{DefaultHandler, HandlerList, ResourceHandler} +import org.eclipse.jetty.servlet.{DefaultServlet, ServletContextHandler, ServletHolder} import org.eclipse.jetty.util.thread.QueuedThreadPool import org.apache.spark.util.Utils @@ -52,6 +51,11 @@ private[spark] class HttpServer( private var server: Server = null private var port: Int = requestedPort + private val servlets = { + val handler = new ServletContextHandler() + handler.setContextPath("/") + handler + } def start() { if (server != null) { @@ -65,6 +69,14 @@ private[spark] class HttpServer( } } + def addDirectory(contextPath: String, resourceBase: String): Unit = { + val holder = new ServletHolder() + holder.setInitParameter("resourceBase", resourceBase) + holder.setInitParameter("pathInfoOnly", "true") + holder.setServlet(new DefaultServlet()) + servlets.addServlet(holder, contextPath.stripSuffix("/") + "/*") + } + /** * Actually start the HTTP server on the given port. 
* @@ -85,21 +97,17 @@ private[spark] class HttpServer( val threadPool = new QueuedThreadPool threadPool.setDaemon(true) server.setThreadPool(threadPool) - val resHandler = new ResourceHandler - resHandler.setResourceBase(resourceBase.getAbsolutePath) - - val handlerList = new HandlerList - handlerList.setHandlers(Array(resHandler, new DefaultHandler)) + addDirectory("/", resourceBase.getAbsolutePath) if (securityManager.isAuthenticationEnabled()) { logDebug("HttpServer is using security") val sh = setupSecurityHandler(securityManager) // make sure we go through security handler to get resources - sh.setHandler(handlerList) + sh.setHandler(servlets) server.setHandler(sh) } else { logDebug("HttpServer is not using security") - server.setHandler(handlerList) + server.setHandler(servlets) } server.start() diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index f0598816d6c0..e35e158c7e8a 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -17,15 +17,14 @@ package org.apache.spark -import org.apache.log4j.{LogManager, PropertyConfigurator} +import org.apache.log4j.{Level, LogManager, PropertyConfigurator} import org.slf4j.{Logger, LoggerFactory} import org.slf4j.impl.StaticLoggerBinder -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.Private import org.apache.spark.util.Utils /** - * :: DeveloperApi :: * Utility trait for classes that want to log data. Creates a SLF4J logger for the class and allows * logging messages at different levels using methods that only evaluate parameters lazily if the * log level is enabled. @@ -33,7 +32,7 @@ import org.apache.spark.util.Utils * NOTE: DO NOT USE this class outside of Spark. It is intended as an internal utility. * This will likely be changed or removed in future releases. 
*/ -@DeveloperApi +@Private trait Logging { // Make the log field transient so that objects with Logging can // be serialized and used on another machine @@ -120,30 +119,31 @@ trait Logging { val usingLog4j12 = "org.slf4j.impl.Log4jLoggerFactory".equals(binderClass) if (usingLog4j12) { val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements + // scalastyle:off println if (!log4j12Initialized) { - // scalastyle:off println - if (Utils.isInInterpreter) { - val replDefaultLogProps = "org/apache/spark/log4j-defaults-repl.properties" - Option(Utils.getSparkClassLoader.getResource(replDefaultLogProps)) match { - case Some(url) => - PropertyConfigurator.configure(url) - System.err.println(s"Using Spark's repl log4j profile: $replDefaultLogProps") - System.err.println("To adjust logging level use sc.setLogLevel(\"INFO\")") - case None => - System.err.println(s"Spark was unable to load $replDefaultLogProps") - } - } else { - val defaultLogProps = "org/apache/spark/log4j-defaults.properties" - Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { - case Some(url) => - PropertyConfigurator.configure(url) - System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") - case None => - System.err.println(s"Spark was unable to load $defaultLogProps") - } + val defaultLogProps = "org/apache/spark/log4j-defaults.properties" + Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { + case Some(url) => + PropertyConfigurator.configure(url) + System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") + case None => + System.err.println(s"Spark was unable to load $defaultLogProps") } - // scalastyle:on println } + + if (Utils.isInInterpreter) { + // Use the repl's main class to define the default log level when running the shell, + // overriding the root logger's config if they're different. + val rootLogger = LogManager.getRootLogger() + val replLogger = LogManager.getLogger("org.apache.spark.repl.Main") + val replLevel = Option(replLogger.getLevel()).getOrElse(Level.WARN) + if (replLevel != rootLogger.getEffectiveLevel()) { + System.err.printf("Setting default log level to \"%s\".\n", replLevel) + System.err.println("To adjust logging level use sc.setLogLevel(newLevel).") + rootLogger.setLevel(replLevel) + } + } + // scalastyle:on println } Logging.initialized = true diff --git a/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala b/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala new file mode 100644 index 000000000000..f8a6f1d0d8cb --- /dev/null +++ b/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark + +/** + * Holds statistics about the output sizes in a map stage. May become a DeveloperApi in the future. + * + * @param shuffleId ID of the shuffle + * @param bytesByPartitionId approximate number of output bytes for each map output partition + * (may be inexact due to use of compressed map statuses) + */ +private[spark] class MapOutputStatistics(val shuffleId: Int, val bytesByPartitionId: Array[Long]) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 92218832d256..72355cdfa68b 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -18,11 +18,12 @@ package org.apache.spark import java.io._ +import java.util.Arrays import java.util.concurrent.ConcurrentHashMap import java.util.zip.{GZIPInputStream, GZIPOutputStream} +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} -import scala.collection.JavaConversions._ import scala.reflect.ClassTag import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, RpcEndpoint} @@ -44,10 +45,10 @@ private[spark] class MapOutputTrackerMasterEndpoint( override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case GetMapOutputStatuses(shuffleId: Int) => - val hostPort = context.sender.address.hostPort + val hostPort = context.senderAddress.hostPort logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort) val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId) - val serializedSize = mapOutputStatuses.size + val serializedSize = mapOutputStatuses.length if (serializedSize > maxAkkaFrameSize) { val msg = s"Map output statuses were $serializedSize bytes which " + s"exceeds spark.akka.frameSize ($maxAkkaFrameSize bytes)." @@ -132,13 +133,57 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * describing the shuffle blocks that are stored at that block manager. */ def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { - logDebug(s"Fetching outputs for shuffle $shuffleId, reduce $reduceId") - val startTime = System.currentTimeMillis + : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1) + } + /** + * Called from executors to get the server URIs and output sizes for each shuffle block that + * needs to be read from a given range of map output partitions (startPartition is included but + * endPartition is excluded from the range). + * + * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, + * and the second item is a sequence of (shuffle block id, shuffle block size) tuples + * describing the shuffle blocks that are stored at that block manager. + */ + def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) + : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") + val statuses = getStatuses(shuffleId) + // Synchronize on the returned array because, on the driver, it gets mutated in place + statuses.synchronized { + return MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) + } + } + + /** + * Return statistics about all of the outputs for a given shuffle. 
+ */ + def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = { + val statuses = getStatuses(dep.shuffleId) + // Synchronize on the returned array because, on the driver, it gets mutated in place + statuses.synchronized { + val totalSizes = new Array[Long](dep.partitioner.numPartitions) + for (s <- statuses) { + for (i <- 0 until totalSizes.length) { + totalSizes(i) += s.getSizeForBlock(i) + } + } + new MapOutputStatistics(dep.shuffleId, totalSizes) + } + } + + /** + * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize + * on this array when reading it, because on the driver, we may be changing it in place. + * + * (It would be nice to remove this restriction in the future.) + */ + private def getStatuses(shuffleId: Int): Array[MapStatus] = { val statuses = mapStatuses.get(shuffleId).orNull if (statuses == null) { logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") + val startTime = System.currentTimeMillis var fetchedStatuses: Array[MapStatus] = null fetching.synchronized { // Someone else is fetching it; wait for them to be done @@ -160,7 +205,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } if (fetchedStatuses == null) { - // We won the race to fetch the output locs; do so + // We won the race to fetch the statuses; do so logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) // This try-finally prevents hangs due to timeouts: try { @@ -175,22 +220,18 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } } } - logDebug(s"Fetching map output location for shuffle $shuffleId, reduce $reduceId took " + + logDebug(s"Fetching map output statuses for shuffle $shuffleId took " + s"${System.currentTimeMillis - startTime} ms") if (fetchedStatuses != null) { - fetchedStatuses.synchronized { - return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) - } + return fetchedStatuses } else { logError("Missing all output locations for shuffle " + shuffleId) throw new MetadataFetchFailedException( - shuffleId, reduceId, "Missing all output locations for shuffle " + shuffleId) + shuffleId, -1, "Missing all output locations for shuffle " + shuffleId) } } else { - statuses.synchronized { - return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses) - } + return statuses } } @@ -235,6 +276,21 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) /** Cache a serialized version of the output statuses for each shuffle to send them out faster */ private var cacheEpoch = epoch + /** Whether to compute locality preferences for reduce tasks */ + private val shuffleLocalityEnabled = conf.getBoolean("spark.shuffle.reduceLocality.enabled", true) + + // Number of map and reduce tasks above which we do not assign preferred locations based on map + // output sizes. We limit the size of jobs for which assign preferred locations as computing the + // top locations by size becomes expensive. + private val SHUFFLE_PREF_MAP_THRESHOLD = 1000 + // NOTE: This should be less than 2000 as we use HighlyCompressedMapStatus beyond that + private val SHUFFLE_PREF_REDUCE_THRESHOLD = 1000 + + // Fraction of total map output that must be at a location for it to considered as a preferred + // location for a reduce task. Making this larger will focus on fewer locations where most data + // can be read locally, but may lead to more delay in scheduling if those locations are busy. 
+ private val REDUCER_PREF_LOCS_FRACTION = 0.2 + /** * Timestamp based HashMap for storing mapStatuses and cached serialized statuses in the driver, * so that statuses are dropped only by explicit de-registering or by TTL-based cleaning (if set). @@ -295,6 +351,30 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) cachedSerializedStatuses.contains(shuffleId) || mapStatuses.contains(shuffleId) } + /** + * Return the preferred hosts on which to run the given map output partition in a given shuffle, + * i.e. the nodes that the most outputs for that partition are on. + * + * @param dep shuffle dependency object + * @param partitionId map output partition that we want to read + * @return a sequence of host names + */ + def getPreferredLocationsForShuffle(dep: ShuffleDependency[_, _, _], partitionId: Int) + : Seq[String] = { + if (shuffleLocalityEnabled && dep.rdd.partitions.length < SHUFFLE_PREF_MAP_THRESHOLD && + dep.partitioner.numPartitions < SHUFFLE_PREF_REDUCE_THRESHOLD) { + val blockManagerIds = getLocationsWithLargestOutputs(dep.shuffleId, partitionId, + dep.partitioner.numPartitions, REDUCER_PREF_LOCS_FRACTION) + if (blockManagerIds.nonEmpty) { + blockManagerIds.get.map(_.host) + } else { + Nil + } + } else { + Nil + } + } + /** * Return a list of locations that each have fraction of map output greater than the specified * threshold. @@ -398,7 +478,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) */ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) { protected val mapStatuses: Map[Int, Array[MapStatus]] = - new ConcurrentHashMap[Int, Array[MapStatus]] + new ConcurrentHashMap[Int, Array[MapStatus]]().asScala } private[spark] object MapOutputTracker extends Logging { @@ -433,23 +513,25 @@ private[spark] object MapOutputTracker extends Logging { } /** - * Converts an array of MapStatuses for a given reduce ID to a sequence that, for each block - * manager ID, lists the shuffle block ids and corresponding shuffle block sizes stored at that - * block manager. + * Given an array of map statuses and a range of map output partitions, returns a sequence that, + * for each block manager ID, lists the shuffle block IDs and corresponding shuffle block sizes + * stored at that block manager. * * If any of the statuses is null (indicating a missing location due to a failed mapper), * throws a FetchFailedException. * * @param shuffleId Identifier for the shuffle - * @param reduceId Identifier for the reduce task + * @param startPartition Start of map output partition ID range (included in range) + * @param endPartition End of map output partition ID range (excluded from range) * @param statuses List of map statuses, indexed by map ID. * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, - * and the second item is a sequence of (shuffle block id, shuffle block size) tuples + * and the second item is a sequence of (shuffle block ID, shuffle block size) tuples * describing the shuffle blocks that are stored at that block manager. 
*/ private def convertMapStatuses( shuffleId: Int, - reduceId: Int, + startPartition: Int, + endPartition: Int, statuses: Array[MapStatus]): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { assert (statuses != null) val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(BlockId, Long)]] @@ -457,10 +539,12 @@ private[spark] object MapOutputTracker extends Logging { if (status == null) { val errorMessage = s"Missing an output location for shuffle $shuffleId" logError(errorMessage) - throw new MetadataFetchFailedException(shuffleId, reduceId, errorMessage) + throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage) } else { - splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) += - ((ShuffleBlockId(shuffleId, mapId, reduceId), status.getSizeForBlock(reduceId))) + for (part <- startPartition until endPartition) { + splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) += + ((ShuffleBlockId(shuffleId, mapId, part), status.getSizeForBlock(part))) + } } } diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 4b9d59975bdc..ef9a2dab1c10 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -104,8 +104,8 @@ class HashPartitioner(partitions: Int) extends Partitioner { * the value of `partitions`. */ class RangePartitioner[K : Ordering : ClassTag, V]( - @transient partitions: Int, - @transient rdd: RDD[_ <: Product2[K, V]], + partitions: Int, + rdd: RDD[_ <: Product2[K, V]], private var ascending: Boolean = true) extends Partitioner { @@ -253,7 +253,7 @@ private[spark] object RangePartitioner { */ def sketch[K : ClassTag]( rdd: RDD[K], - sampleSizePerPartition: Int): (Long, Array[(Int, Int, Array[K])]) = { + sampleSizePerPartition: Int): (Long, Array[(Int, Long, Array[K])]) = { val shift = rdd.id // val classTagK = classTag[K] // to avoid serializing the entire partitioner object val sketched = rdd.mapPartitionsWithIndex { (idx, iter) => @@ -262,7 +262,7 @@ private[spark] object RangePartitioner { iter, sampleSizePerPartition, seed) Iterator((idx, n, sample)) }.collect() - val numItems = sketched.map(_._2.toLong).sum + val numItems = sketched.map(_._2).sum (numItems, sketched) } @@ -291,7 +291,7 @@ private[spark] object RangePartitioner { while ((i < numCandidates) && (j < partitions - 1)) { val (key, weight) = ordered(i) cumWeight += weight - if (cumWeight > target) { + if (cumWeight >= target) { // Skip duplicate values. if (previousBound.isEmpty || ordering.gt(key, previousBound.get)) { bounds += key diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala index 32df42d57dbd..3b9c885bf97a 100644 --- a/core/src/main/scala/org/apache/spark/SSLOptions.scala +++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala @@ -17,9 +17,11 @@ package org.apache.spark -import java.io.{File, FileInputStream} -import java.security.{KeyStore, NoSuchAlgorithmException} -import javax.net.ssl.{KeyManager, KeyManagerFactory, SSLContext, TrustManager, TrustManagerFactory} +import java.io.File +import java.security.NoSuchAlgorithmException +import javax.net.ssl.SSLContext + +import scala.collection.JavaConverters._ import com.typesafe.config.{Config, ConfigFactory, ConfigValueFactory} import org.eclipse.jetty.util.ssl.SslContextFactory @@ -79,7 +81,6 @@ private[spark] case class SSLOptions( * object. 
It can be used then to compose the ultimate Akka configuration. */ def createAkkaConfig: Option[Config] = { - import scala.collection.JavaConversions._ if (enabled) { Some(ConfigFactory.empty() .withValue("akka.remote.netty.tcp.security.key-store", @@ -97,7 +98,7 @@ private[spark] case class SSLOptions( .withValue("akka.remote.netty.tcp.security.protocol", ConfigValueFactory.fromAnyRef(protocol.getOrElse(""))) .withValue("akka.remote.netty.tcp.security.enabled-algorithms", - ConfigValueFactory.fromIterable(supportedAlgorithms.toSeq)) + ConfigValueFactory.fromIterable(supportedAlgorithms.asJava)) .withValue("akka.remote.netty.tcp.enable-ssl", ConfigValueFactory.fromAnyRef(true))) } else { diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 673ef49e7c1c..64e483e38477 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -17,11 +17,13 @@ package org.apache.spark +import java.lang.{Byte => JByte} import java.net.{Authenticator, PasswordAuthentication} -import java.security.KeyStore +import java.security.{KeyStore, SecureRandom} import java.security.cert.X509Certificate import javax.net.ssl._ +import com.google.common.hash.HashCodes import com.google.common.io.Files import org.apache.hadoop.io.Text @@ -130,15 +132,16 @@ import org.apache.spark.util.Utils * * The exact mechanisms used to generate/distribute the shared secret are deployment-specific. * - * For Yarn deployments, the secret is automatically generated using the Akka remote - * Crypt.generateSecureCookie() API. The secret is placed in the Hadoop UGI which gets passed - * around via the Hadoop RPC mechanism. Hadoop RPC can be configured to support different levels - * of protection. See the Hadoop documentation for more details. Each Spark application on Yarn - * gets a different shared secret. On Yarn, the Spark UI gets configured to use the Hadoop Yarn - * AmIpFilter which requires the user to go through the ResourceManager Proxy. That Proxy is there - * to reduce the possibility of web based attacks through YARN. Hadoop can be configured to use - * filters to do authentication. That authentication then happens via the ResourceManager Proxy - * and Spark will use that to do authorization against the view acls. + * For YARN deployments, the secret is automatically generated. The secret is placed in the Hadoop + * UGI which gets passed around via the Hadoop RPC mechanism. Hadoop RPC can be configured to + * support different levels of protection. See the Hadoop documentation for more details. Each + * Spark application on YARN gets a different shared secret. + * + * On YARN, the Spark UI gets configured to use the Hadoop YARN AmIpFilter which requires the user + * to go through the ResourceManager Proxy. That proxy is there to reduce the possibility of web + * based attacks through YARN. Hadoop can be configured to use filters to do authentication. That + * authentication then happens via the ResourceManager Proxy and Spark will use that to do + * authorization against the view acls. * * For other Spark deployments, the shared secret must be specified via the * spark.authenticate.secret config. 
@@ -189,8 +192,7 @@ import org.apache.spark.util.Utils private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging with SecretKeyHolder { - // key used to store the spark secret in the Hadoop UGI - private val sparkSecretLookupKey = "sparkCookie" + import SecurityManager._ private val authOn = sparkConf.getBoolean(SecurityManager.SPARK_AUTH_CONF, false) // keep spark.ui.acls.enable for backwards compatibility with 1.0 @@ -310,7 +312,16 @@ private[spark] class SecurityManager(sparkConf: SparkConf) setViewAcls(Set[String](defaultUser), allowedUsers) } - def getViewAcls: String = viewAcls.mkString(",") + /** + * Checking the existence of "*" is necessary as YARN can't recognize the "*" in "defaultuser,*" + */ + def getViewAcls: String = { + if (viewAcls.contains("*")) { + "*" + } else { + viewAcls.mkString(",") + } + } /** * Admin acls should be set before the view or modify acls. If you modify the admin @@ -321,7 +332,16 @@ private[spark] class SecurityManager(sparkConf: SparkConf) logInfo("Changing modify acls to: " + modifyAcls.mkString(",")) } - def getModifyAcls: String = modifyAcls.mkString(",") + /** + * Checking the existence of "*" is necessary as YARN can't recognize the "*" in "defaultuser,*" + */ + def getModifyAcls: String = { + if (modifyAcls.contains("*")) { + "*" + } else { + modifyAcls.mkString(",") + } + } /** * Admin acls should be set before the view or modify acls. If you modify the admin @@ -347,33 +367,38 @@ private[spark] class SecurityManager(sparkConf: SparkConf) * we throw an exception. */ private def generateSecretKey(): String = { - if (!isAuthenticationEnabled) return null - // first check to see if the secret is already set, else generate a new one if on yarn - val sCookie = if (SparkHadoopUtil.get.isYarnMode) { - val secretKey = SparkHadoopUtil.get.getSecretKeyFromUserCredentials(sparkSecretLookupKey) - if (secretKey != null) { - logDebug("in yarn mode, getting secret from credentials") - return new Text(secretKey).toString + if (!isAuthenticationEnabled) { + null + } else if (SparkHadoopUtil.get.isYarnMode) { + // In YARN mode, the secure cookie will be created by the driver and stashed in the + // user's credentials, where executors can get it. The check for an array of size 0 + // is because of the test code in YarnSparkHadoopUtilSuite. 
+ val secretKey = SparkHadoopUtil.get.getSecretKeyFromUserCredentials(SECRET_LOOKUP_KEY) + if (secretKey == null || secretKey.length == 0) { + logDebug("generateSecretKey: yarn mode, secret key from credentials is null") + val rnd = new SecureRandom() + val length = sparkConf.getInt("spark.authenticate.secretBitLength", 256) / JByte.SIZE + val secret = new Array[Byte](length) + rnd.nextBytes(secret) + + val cookie = HashCodes.fromBytes(secret).toString() + SparkHadoopUtil.get.addSecretKeyToUserCredentials(SECRET_LOOKUP_KEY, cookie) + cookie } else { - logDebug("getSecretKey: yarn mode, secret key from credentials is null") + new Text(secretKey).toString } - val cookie = akka.util.Crypt.generateSecureCookie - // if we generated the secret then we must be the first so lets set it so t - // gets used by everyone else - SparkHadoopUtil.get.addSecretKeyToUserCredentials(sparkSecretLookupKey, cookie) - logInfo("adding secret to credentials in yarn mode") - cookie } else { // user must have set spark.authenticate.secret config // For Master/Worker, auth secret is in conf; for Executors, it is in env variable - sys.env.get(SecurityManager.ENV_AUTH_SECRET) + Option(sparkConf.getenv(SecurityManager.ENV_AUTH_SECRET)) .orElse(sparkConf.getOption(SecurityManager.SPARK_AUTH_SECRET_CONF)) match { case Some(value) => value - case None => throw new Exception("Error: a secret key must be specified via the " + - SecurityManager.SPARK_AUTH_SECRET_CONF + " config") + case None => + throw new IllegalArgumentException( + "Error: a secret key must be specified via the " + + SecurityManager.SPARK_AUTH_SECRET_CONF + " config") } } - sCookie } /** @@ -394,7 +419,7 @@ private[spark] class SecurityManager(sparkConf: SparkConf) def checkUIViewPermissions(user: String): Boolean = { logDebug("user=" + user + " aclsEnabled=" + aclsEnabled() + " viewAcls=" + viewAcls.mkString(",")) - !aclsEnabled || user == null || viewAcls.contains(user) + !aclsEnabled || user == null || viewAcls.contains(user) || viewAcls.contains("*") } /** @@ -409,7 +434,7 @@ private[spark] class SecurityManager(sparkConf: SparkConf) def checkModifyPermissions(user: String): Boolean = { logDebug("user=" + user + " aclsEnabled=" + aclsEnabled() + " modifyAcls=" + modifyAcls.mkString(",")) - !aclsEnabled || user == null || modifyAcls.contains(user) + !aclsEnabled || user == null || modifyAcls.contains(user) || modifyAcls.contains("*") } @@ -457,6 +482,9 @@ private[spark] object SecurityManager { val SPARK_AUTH_CONF: String = "spark.authenticate" val SPARK_AUTH_SECRET_CONF: String = "spark.authenticate.secret" // This is used to set auth secret to an executor's env variable. It should have the same - // value as SPARK_AUTH_SECERET_CONF set in SparkConf + // value as SPARK_AUTH_SECRET_CONF set in SparkConf val ENV_AUTH_SECRET = "_SPARK_AUTH_SECRET" + + // key used to store the spark secret in the Hadoop UGI + val SECRET_LOOKUP_KEY = "sparkCookie" } diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 08bab4bf2739..d3384fb29773 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -249,6 +249,13 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { Utils.byteStringAsBytes(get(key, defaultValue)) } + /** + * Get a size parameter as bytes, falling back to a default if not set. 
+ */ + def getSizeAsBytes(key: String, defaultValue: Long): Long = { + Utils.byteStringAsBytes(get(key, defaultValue + "B")) + } + /** * Get a size parameter as Kibibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Kibibytes are assumed. @@ -382,6 +389,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { val driverOptsKey = "spark.driver.extraJavaOptions" val driverClassPathKey = "spark.driver.extraClassPath" val driverLibraryPathKey = "spark.driver.extraLibraryPath" + val sparkExecutorInstances = "spark.executor.instances" // Used by Yarn in 1.1 and before sys.props.get("spark.driver.libraryPath").foreach { value => @@ -410,16 +418,35 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { } // Validate memory fractions - val memoryKeys = Seq( + val deprecatedMemoryKeys = Seq( "spark.storage.memoryFraction", "spark.shuffle.memoryFraction", "spark.shuffle.safetyFraction", "spark.storage.unrollFraction", "spark.storage.safetyFraction") + val memoryKeys = Seq( + "spark.memory.fraction", + "spark.memory.storageFraction") ++ + deprecatedMemoryKeys for (key <- memoryKeys) { val value = getDouble(key, 0.5) if (value > 1 || value < 0) { - throw new IllegalArgumentException("$key should be between 0 and 1 (was '$value').") + throw new IllegalArgumentException(s"$key should be between 0 and 1 (was '$value').") + } + } + + // Warn against deprecated memory fractions (unless legacy memory management mode is enabled) + val legacyMemoryManagementKey = "spark.memory.useLegacyMode" + val legacyMemoryManagement = getBoolean(legacyMemoryManagementKey, false) + if (!legacyMemoryManagement) { + val keyset = deprecatedMemoryKeys.toSet + val detected = settings.keys().asScala.filter(keyset.contains) + if (detected.nonEmpty) { + logWarning("Detected deprecated memory fraction settings: " + + detected.mkString("[", ", ", "]") + ". As of Spark 1.6, execution and storage " + + "memory management are unified. All memory fractions used in the old model are " + + "now deprecated and no longer read. If you wish to use the old memory management, " + + s"you may explicitly enable `$legacyMemoryManagementKey` (not recommended).") } } @@ -469,6 +496,24 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { } } } + + if (!contains(sparkExecutorInstances)) { + sys.env.get("SPARK_WORKER_INSTANCES").foreach { value => + val warning = + s""" + |SPARK_WORKER_INSTANCES was detected (set to '$value'). + |This is deprecated in Spark 1.0+. + | + |Please instead use: + | - ./spark-submit with --num-executors to specify the number of executors + | - Or set SPARK_EXECUTOR_INSTANCES + | - spark.executor.instances to configure the number of instances in the spark config. 
+ """.stripMargin + logWarning(warning) + + set("spark.executor.instances", value) + } + } } /** @@ -550,7 +595,11 @@ private[spark] object SparkConf extends Logging { "spark.rpc.lookupTimeout" -> Seq( AlternateConfig("spark.akka.lookupTimeout", "1.4")), "spark.streaming.fileStream.minRememberDuration" -> Seq( - AlternateConfig("spark.streaming.minRememberDuration", "1.5")) + AlternateConfig("spark.streaming.minRememberDuration", "1.5")), + "spark.yarn.max.executor.failures" -> Seq( + AlternateConfig("spark.yarn.max.worker.failures", "1.5")), + "spark.memory.offHeap.enabled" -> Seq( + AlternateConfig("spark.unsafe.offHeap", "1.6")) ) /** @@ -574,7 +623,7 @@ private[spark] object SparkConf extends Logging { /** * Return whether the given config should be passed to an executor on start-up. * - * Certain akka and authentication configs are required of the executor when it connects to + * Certain akka and authentication configs are required from the executor when it connects to * the scheduler, while the rest of the spark configs can be inherited from the driver later. */ def isExecutorStartupConf(name: String): Boolean = { @@ -582,6 +631,7 @@ private[spark] object SparkConf extends Logging { name.startsWith("spark.akka") || (name.startsWith("spark.auth") && name != SecurityManager.SPARK_AUTH_SECRET_CONF) || name.startsWith("spark.ssl") || + name.startsWith("spark.rpc") || isSparkPortConf(name) } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 4380cf45cc1b..194ecc0a0434 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -26,13 +26,14 @@ import java.util.{Arrays, Properties, UUID} import java.util.concurrent.atomic.{AtomicReference, AtomicBoolean, AtomicInteger} import java.util.UUID.randomUUID +import scala.collection.JavaConverters._ import scala.collection.{Map, Set} -import scala.collection.JavaConversions._ import scala.collection.generic.Growable import scala.collection.mutable.HashMap import scala.reflect.{ClassTag, classTag} import scala.util.control.NonFatal +import org.apache.commons.lang.SerializationUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable, @@ -44,23 +45,23 @@ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFor import org.apache.mesos.MesosNativeLibrary -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} -import org.apache.spark.executor.{ExecutorEndpoint, TriggerThreadDump} import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, FixedLengthBinaryInputFormat} import org.apache.spark.io.CompressionCodec import org.apache.spark.metrics.MetricsSystem import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ -import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef} +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SparkDeploySchedulerBackend, SimrSchedulerBackend} import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} import 
org.apache.spark.scheduler.local.LocalBackend import org.apache.spark.storage._ +import org.apache.spark.storage.BlockManagerMessages.TriggerThreadDump import org.apache.spark.ui.{SparkUI, ConsoleProgressBar} import org.apache.spark.ui.jobs.JobProgressListener import org.apache.spark.util._ @@ -89,18 +90,29 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // NOTE: this must be placed at the beginning of the SparkContext constructor. SparkContext.markPartiallyConstructed(this, allowMultipleContexts) - // This is used only by YARN for now, but should be relevant to other cluster types (Mesos, - // etc) too. This is typically generated from InputFormatInfo.computePreferredLocations. It - // contains a map from hostname to a list of input format splits on the host. - private[spark] var preferredNodeLocationData: Map[String, Set[SplitInfo]] = Map() - val startTime = System.currentTimeMillis() - private val stopped: AtomicBoolean = new AtomicBoolean(false) + private[spark] val stopped: AtomicBoolean = new AtomicBoolean(false) private def assertNotStopped(): Unit = { if (stopped.get()) { - throw new IllegalStateException("Cannot call methods on a stopped SparkContext") + val activeContext = SparkContext.activeContext.get() + val activeCreationSite = + if (activeContext == null) { + "(No active SparkContext.)" + } else { + activeContext.creationSite.longForm + } + throw new IllegalStateException( + s"""Cannot call methods on a stopped SparkContext. + |This stopped SparkContext was created at: + | + |${creationSite.longForm} + | + |The currently active SparkContext was created at: + | + |$activeCreationSite + """.stripMargin) } } @@ -114,14 +126,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * :: DeveloperApi :: * Alternative constructor for setting preferred locations where Spark will create executors. * - * @param preferredNodeLocationData used in YARN mode to select nodes to launch containers on. - * Can be generated using [[org.apache.spark.scheduler.InputFormatInfo.computePreferredLocations]] - * from a list of input files or InputFormats for the application. + * @param config a [[org.apache.spark.SparkConf]] object specifying other Spark parameters + * @param preferredNodeLocationData not used. Left for backward compatibility. */ + @deprecated("Passing in preferred locations has no effect at all, see SPARK-8949", "1.5.0") @DeveloperApi def this(config: SparkConf, preferredNodeLocationData: Map[String, Set[SplitInfo]]) = { this(config) - this.preferredNodeLocationData = preferredNodeLocationData + logWarning("Passing in preferred locations has no effect at all, see SPARK-8949") } /** @@ -143,7 +155,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @param jars Collection of JARs to send to the cluster. These can be paths on the local file * system or HDFS, HTTP, HTTPS, or FTP URLs. * @param environment Environment variables to set on worker nodes. + * @param preferredNodeLocationData not used. Left for backward compatibility. 
*/ + @deprecated("Passing in preferred locations has no effect at all, see SPARK-10921", "1.6.0") def this( master: String, appName: String, @@ -153,7 +167,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli preferredNodeLocationData: Map[String, Set[SplitInfo]] = Map()) = { this(SparkContext.updatedConf(new SparkConf(), master, appName, sparkHome, jars, environment)) - this.preferredNodeLocationData = preferredNodeLocationData + if (preferredNodeLocationData.nonEmpty) { + logWarning("Passing in preferred locations has no effect at all, see SPARK-8949") + } } // NOTE: The below constructors could be consolidated using default arguments. Due to @@ -167,7 +183,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @param appName A name for your application, to display on the cluster web UI. */ private[spark] def this(master: String, appName: String) = - this(master, appName, null, Nil, Map(), Map()) + this(master, appName, null, Nil, Map()) /** * Alternative constructor that allows setting common Spark properties directly @@ -177,7 +193,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @param sparkHome Location where Spark is installed on cluster nodes. */ private[spark] def this(master: String, appName: String, sparkHome: String) = - this(master, appName, sparkHome, Nil, Map(), Map()) + this(master, appName, sparkHome, Nil, Map()) /** * Alternative constructor that allows setting common Spark properties directly @@ -189,7 +205,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * system or HDFS, HTTP, HTTPS, or FTP URLs. */ private[spark] def this(master: String, appName: String, sparkHome: String, jars: Seq[String]) = - this(master, appName, sparkHome, jars, Map(), Map()) + this(master, appName, sparkHome, jars, Map()) // log out Spark Version in Spark driver log logInfo(s"Running Spark version $SPARK_VERSION") @@ -256,6 +272,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli def isLocal: Boolean = (master == "local" || master.startsWith("local[")) + /** + * @return true if context is stopped or in the midst of stopping. + */ + def isStopped: Boolean = stopped.get() + // An asynchronous listener bus for Spark events private[spark] val listenerBus = new LiveListenerBus @@ -264,7 +285,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli conf: SparkConf, isLocal: Boolean, listenerBus: LiveListenerBus): SparkEnv = { - SparkEnv.createDriverEnv(conf, isLocal, listenerBus) + SparkEnv.createDriverEnv(conf, isLocal, listenerBus, SparkContext.numDriverCores(master)) } private[spark] def env: SparkEnv = _env @@ -338,8 +359,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli private[spark] var checkpointDir: Option[String] = None // Thread Local variable that can be used by users to pass information down the stack - private val localProperties = new InheritableThreadLocal[Properties] { - override protected def childValue(parent: Properties): Properties = new Properties(parent) + protected[spark] val localProperties = new InheritableThreadLocal[Properties] { + override protected def childValue(parent: Properties): Properties = { + // Note: make a clone such that changes in the parent properties aren't reflected in + // the those of the children threads, which has confusing semantics (SPARK-10563). 
+ SerializationUtils.clone(parent).asInstanceOf[Properties] + } override protected def initialValue(): Properties = new Properties() } @@ -432,6 +457,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _env = createSparkEnv(_conf, isLocal, listenerBus) SparkEnv.set(_env) + // If running the REPL, register the repl's output dir with the file server. + _conf.getOption("spark.repl.class.outputDir").foreach { path => + val replUri = _env.rpcEnv.fileServer.addDirectory("/classes", new File(path)) + _conf.set("spark.repl.class.uri", replUri) + } + _metadataCleaner = new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup, _conf) _statusTracker = new SparkStatusTracker(this) @@ -507,6 +538,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _applicationId = _taskScheduler.applicationId() _applicationAttemptId = taskScheduler.applicationAttemptId() _conf.set("spark.app.id", _applicationId) + _ui.foreach(_.setAppId(_applicationId)) _env.blockManager.initialize(_applicationId) // The metrics system for Driver need to be set spark.app.id to app ID. @@ -528,7 +560,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } // Optionally scale number of executors dynamically based on workload. Exposed for testing. - val dynamicAllocationEnabled = _conf.getBoolean("spark.dynamicAllocation.enabled", false) + val dynamicAllocationEnabled = Utils.isDynamicAllocationEnabled(_conf) + if (!dynamicAllocationEnabled && _conf.getBoolean("spark.dynamicAllocation.enabled", false)) { + logWarning("Dynamic Allocation and num executors both set, thus dynamic allocation disabled.") + } + _executorAllocationManager = if (dynamicAllocationEnabled) { Some(new ExecutorAllocationManager(this, listenerBus, _conf)) @@ -551,6 +587,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Post init _taskScheduler.postStartHook() + _env.metricsSystem.registerSource(_dagScheduler.metricsSource) _env.metricsSystem.registerSource(new BlockManagerSource(_env.blockManager)) _executorAllocationManager.foreach { e => _env.metricsSystem.registerSource(e.executorAllocationManagerSource) @@ -559,7 +596,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Make sure the context is stopped if the user forgets about it. This avoids leaving // unfinished event logs around after the JVM exits cleanly. It doesn't help if the JVM // is killed, though. - _shutdownHookRef = Utils.addShutdownHook(Utils.SPARK_CONTEXT_SHUTDOWN_PRIORITY) { () => + _shutdownHookRef = ShutdownHookManager.addShutdownHook( + ShutdownHookManager.SPARK_CONTEXT_SHUTDOWN_PRIORITY) { () => logInfo("Invoking stop() from shutdown hook") stop() } @@ -587,11 +625,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli if (executorId == SparkContext.DRIVER_IDENTIFIER) { Some(Utils.getThreadDump()) } else { - val (host, port) = env.blockManager.master.getRpcHostPortForExecutor(executorId).get - val endpointRef = env.rpcEnv.setupEndpointRef( - SparkEnv.executorActorSystemName, - RpcAddress(host, port), - ExecutorEndpoint.EXECUTOR_ENDPOINT_NAME) + val endpointRef = env.blockManager.master.getExecutorEndpointRef(executorId).get Some(endpointRef.askWithRetry[Array[ThreadStackTrace]](TriggerThreadDump)) } } catch { @@ -629,7 +663,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * [[org.apache.spark.SparkContext.setLocalProperty]]. 
*/ def getLocalProperty(key: String): String = - Option(localProperties.get).map(_.getProperty(key)).getOrElse(null) + Option(localProperties.get).map(_.getProperty(key)).orNull /** Set a human readable description of the current job. */ def setJobDescription(value: String) { @@ -831,6 +865,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @note Small files are preferred, large file is also allowable, but may cause bad performance. * @note On some filesystems, `.../path/*` can be a more efficient way to read all files * in a directory rather than `.../path/` or `.../path` + * + * @param path Directory to the input data files, the path can be comma separated paths as the + * list of inputs. * @param minPartitions A suggestion value of the minimal splitting number for input data. */ def wholeTextFiles( @@ -841,19 +878,17 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Use setInputPaths so that wholeTextFiles aligns with hadoopFile/textFile in taking // comma separated files as input. (see SPARK-7155) NewFileInputFormat.setInputPaths(job, path) - val updateConf = job.getConfiguration + val updateConf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) new WholeTextFileRDD( this, classOf[WholeTextFileInputFormat], - classOf[String], - classOf[String], + classOf[Text], + classOf[Text], updateConf, - minPartitions).setName(path) + minPartitions).setName(path).map(record => (record._1.toString, record._2.toString)) } /** - * :: Experimental :: - * * Get an RDD for a Hadoop-readable dataset as PortableDataStream for each file * (useful for binary data) * @@ -866,7 +901,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * }}} * * Do - * `val rdd = sparkContext.dataStreamFiles("hdfs://a-hdfs-path")`, + * `val rdd = sparkContext.binaryFiles("hdfs://a-hdfs-path")`, * * then `rdd` contains * {{{ @@ -879,9 +914,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @note Small files are preferred; very large files may cause bad performance. * @note On some filesystems, `.../path/*` can be a more efficient way to read all files * in a directory rather than `.../path/` or `.../path` + * + * @param path Directory to the input data files, the path can be comma separated paths as the + * list of inputs. * @param minPartitions A suggestion value of the minimal splitting number for input data. */ - @Experimental def binaryFiles( path: String, minPartitions: Int = defaultMinPartitions): RDD[(String, PortableDataStream)] = withScope { @@ -890,7 +927,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Use setInputPaths so that binaryFiles aligns with hadoopFile/textFile in taking // comma separated files as input. (see SPARK-7155) NewFileInputFormat.setInputPaths(job, path) - val updateConf = job.getConfiguration + val updateConf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) new BinaryFileRDD( this, classOf[StreamInputFormat], @@ -901,18 +938,18 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } /** - * :: Experimental :: - * * Load data from a flat binary file, assuming the length of each record is constant. * * '''Note:''' We ensure that the byte array for each record in the resulting RDD * has the provided record length. 
* - * @param path Directory to the input data files + * @param path Directory to the input data files, the path can be comma separated paths as the + * list of inputs. * @param recordLength The length at which to split the records + * @param conf Configuration for setting up the dataset. + * * @return An RDD of data with values, represented as byte arrays */ - @Experimental def binaryRecords( path: String, recordLength: Int, @@ -1069,7 +1106,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Use setInputPaths so that newAPIHadoopFile aligns with hadoopFile/textFile in taking // comma separated files as input. (see SPARK-7155) NewFileInputFormat.setInputPaths(job, path) - val updatedConf = job.getConfiguration + val updatedConf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) new NewHadoopRDD(this, fClass, kClass, vClass, updatedConf).setName(path) } @@ -1344,7 +1381,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } val key = if (!isLocal && scheme == "file") { - env.httpFileServer.addFile(new File(uri.getPath)) + env.rpcEnv.fileServer.addFile(new File(uri.getPath)) } else { schemeCorrectedPath } @@ -1427,7 +1464,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli override def killExecutors(executorIds: Seq[String]): Boolean = { schedulerBackend match { case b: CoarseGrainedSchedulerBackend => - b.killExecutors(executorIds) + b.killExecutors(executorIds, replace = false, force = true) case _ => logWarning("Killing executors is only supported in coarse-grained mode") false @@ -1465,7 +1502,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli private[spark] def killAndReplaceExecutor(executorId: String): Boolean = { schedulerBackend match { case b: CoarseGrainedSchedulerBackend => - b.killExecutors(Seq(executorId), replace = true) + b.killExecutors(Seq(executorId), replace = true, force = true) case _ => logWarning("Killing executors is only supported in coarse-grained mode") false @@ -1493,8 +1530,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ @DeveloperApi def getRDDStorageInfo: Array[RDDInfo] = { + getRDDStorageInfo(_ => true) + } + + private[spark] def getRDDStorageInfo(filter: RDD[_] => Boolean): Array[RDDInfo] = { assertNotStopped() - val rddInfos = persistentRdds.values.map(RDDInfo.fromRdd).toArray + val rddInfos = persistentRdds.values.filter(filter).map(RDDInfo.fromRdd).toArray StorageUtils.updateRddInfo(rddInfos, getExecutorStorageStatus) rddInfos.filter(_.isCached) } @@ -1523,7 +1564,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli def getAllPools: Seq[Schedulable] = { assertNotStopped() // TODO(xiajunluan): We should take nested pools into account - taskScheduler.rootPool.schedulableQueue.toSeq + taskScheduler.rootPool.schedulableQueue.asScala.toSeq } /** @@ -1567,11 +1608,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Register an RDD to be persisted in memory and/or disk storage */ private[spark] def persistRDD(rdd: RDD[_]) { - _executorAllocationManager.foreach { _ => - logWarning( - s"Dynamic allocation currently does not support cached RDDs. 
Cached data for RDD " + - s"${rdd.id} will be lost when executors are removed.") - } persistentRdds(rdd.id) = rdd } @@ -1596,7 +1632,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli var key = "" if (path.contains("\\")) { // For local paths with backslashes on Windows, URI throws an exception - key = env.httpFileServer.addJar(new File(path)) + key = env.rpcEnv.fileServer.addJar(new File(path)) } else { val uri = new URI(path) key = uri.getScheme match { @@ -1610,7 +1646,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // of the AM to make it show up in the current working directory. val fileName = new Path(uri.getPath).getName() try { - env.httpFileServer.addJar(new File(fileName)) + env.rpcEnv.fileServer.addJar(new File(fileName)) } catch { case e: Exception => // For now just log an error but allow to go through so spark examples work. @@ -1621,7 +1657,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } } else { try { - env.httpFileServer.addJar(new File(uri.getPath)) + env.rpcEnv.fileServer.addJar(new File(uri.getPath)) } catch { case exc: FileNotFoundException => logError(s"Jar not found at $path") @@ -1660,6 +1696,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Shut down the SparkContext. def stop() { + if (AsynchronousListenerBus.withinListenerThread.value) { + throw new SparkException("Cannot stop SparkContext within listener thread of" + + " AsynchronousListenerBus") + } // Use the stopping variable to ensure no contention for the stop scenario. // Still track the stopped variable for use elsewhere in the code. if (!stopped.compareAndSet(false, true)) { @@ -1667,7 +1707,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli return } if (_shutdownHookRef != null) { - Utils.removeShutdownHook(_shutdownHookRef) + ShutdownHookManager.removeShutdownHook(_shutdownHookRef) } Utils.tryLogNonFatalError { @@ -1723,6 +1763,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } SparkEnv.set(null) } + // Unset YARN mode system env variable, to allow switching between cluster types. + System.clearProperty("SPARK_YARN_MODE") SparkContext.clearActiveContext() logInfo("Successfully stopped SparkContext") } @@ -1768,10 +1810,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * has overridden the call site using `setCallSite()`, this will return the user's version. */ private[spark] def getCallSite(): CallSite = { - Option(getLocalProperty(CallSite.SHORT_FORM)).map { case shortCallSite => - val longCallSite = Option(getLocalProperty(CallSite.LONG_FORM)).getOrElse("") - CallSite(shortCallSite, longCallSite) - }.getOrElse(Utils.getCallSite()) + val callSite = Utils.getCallSite() + CallSite( + Option(getLocalProperty(CallSite.SHORT_FORM)).getOrElse(callSite.shortForm), + Option(getLocalProperty(CallSite.LONG_FORM)).getOrElse(callSite.longForm) + ) } /** @@ -1938,10 +1981,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } /** - * :: Experimental :: * Submit a job for execution and return a FutureJob holding the result. */ - @Experimental def submitJob[T, U, R]( rdd: RDD[T], processPartition: Iterator[T] => U, @@ -1962,6 +2003,23 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli new SimpleFutureAction(waiter, resultFunc) } + /** + * Submit a map stage for execution. 
This is currently an internal API only, but might be + * promoted to DeveloperApi in the future. + */ + private[spark] def submitMapStage[K, V, C](dependency: ShuffleDependency[K, V, C]) + : SimpleFutureAction[MapOutputStatistics] = { + assertNotStopped() + val callSite = getCallSite() + var result: MapOutputStatistics = null + val waiter = dagScheduler.submitMapStage( + dependency, + (r: MapOutputStatistics) => { result = r }, + callSite, + localProperties.get) + new SimpleFutureAction[MapOutputStatistics](waiter, result) + } + /** * Cancel active jobs for the specified group. See [[org.apache.spark.SparkContext.setJobGroup]] * for more information. @@ -2518,6 +2576,21 @@ object SparkContext extends Logging { res } + /** + * The number of driver cores to use for execution in local mode, 0 otherwise. + */ + private[spark] def numDriverCores(master: String): Int = { + def convertToInt(threads: String): Int = { + if (threads == "*") Runtime.getRuntime.availableProcessors() else threads.toInt + } + master match { + case "local" => 1 + case SparkMasterRegex.LOCAL_N_REGEX(threads) => convertToInt(threads) + case SparkMasterRegex.LOCAL_N_FAILURES_REGEX(threads, _) => convertToInt(threads) + case _ => 0 // driver is not used for execution + } + } + /** * Create a task scheduler based on a given master URL. * Return a 2-tuple of the scheduler backend and the task scheduler. @@ -2525,18 +2598,7 @@ object SparkContext extends Logging { private def createTaskScheduler( sc: SparkContext, master: String): (SchedulerBackend, TaskScheduler) = { - // Regular expression used for local[N] and local[*] master formats - val LOCAL_N_REGEX = """local\[([0-9]+|\*)\]""".r - // Regular expression for local[N, maxRetries], used in tests with failing tasks - val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+|\*)\s*,\s*([0-9]+)\]""".r - // Regular expression for simulating a Spark cluster of [N, cores, memory] locally - val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r - // Regular expression for connecting to Spark deploy clusters - val SPARK_REGEX = """spark://(.*)""".r - // Regular expression for connection to Mesos cluster by mesos:// or zk:// url - val MESOS_REGEX = """(mesos|zk)://.*""".r - // Regular expression for connection to Simr cluster - val SIMR_REGEX = """simr://(.*)""".r + import SparkMasterRegex._ // When running locally, don't try to re-execute tasks on failure. 
val MAX_LOCAL_TASK_FAILURES = 1 @@ -2652,15 +2714,14 @@ object SparkContext extends Logging { scheduler.initialize(backend) (backend, scheduler) - case mesosUrl @ MESOS_REGEX(_) => + case MESOS_REGEX(mesosUrl) => MesosNativeLibrary.load() val scheduler = new TaskSchedulerImpl(sc) - val coarseGrained = sc.conf.getBoolean("spark.mesos.coarse", false) - val url = mesosUrl.stripPrefix("mesos://") // strip scheme from raw Mesos URLs + val coarseGrained = sc.conf.getBoolean("spark.mesos.coarse", defaultValue = true) val backend = if (coarseGrained) { - new CoarseMesosSchedulerBackend(scheduler, sc, url, sc.env.securityManager) + new CoarseMesosSchedulerBackend(scheduler, sc, mesosUrl, sc.env.securityManager) } else { - new MesosSchedulerBackend(scheduler, sc, url) + new MesosSchedulerBackend(scheduler, sc, mesosUrl) } scheduler.initialize(backend) (backend, scheduler) @@ -2671,12 +2732,35 @@ object SparkContext extends Logging { scheduler.initialize(backend) (backend, scheduler) + case zkUrl if zkUrl.startsWith("zk://") => + logWarning("Master URL for a multi-master Mesos cluster managed by ZooKeeper should be " + + "in the form mesos://zk://host:port. Current Master URL will stop working in Spark 2.0.") + createTaskScheduler(sc, "mesos://" + zkUrl) + case _ => throw new SparkException("Could not parse Master URL: '" + master + "'") } } } +/** + * A collection of regexes for extracting information from the master string. + */ +private object SparkMasterRegex { + // Regular expression used for local[N] and local[*] master formats + val LOCAL_N_REGEX = """local\[([0-9]+|\*)\]""".r + // Regular expression for local[N, maxRetries], used in tests with failing tasks + val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+|\*)\s*,\s*([0-9]+)\]""".r + // Regular expression for simulating a Spark cluster of [N, cores, memory] locally + val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r + // Regular expression for connecting to Spark deploy clusters + val SPARK_REGEX = """spark://(.*)""".r + // Regular expression for connection to Mesos cluster by mesos:// or mesos://zk:// url + val MESOS_REGEX = """mesos://(.*)""".r + // Regular expression for connection to Simr cluster + val SIMR_REGEX = """simr://(.*)""".r +} + /** * A class encapsulating how to convert some type T to Writable. It stores both the Writable class * corresponding to T (e.g. IntWritable for Int) and a function for doing the conversion. 
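The master-URL handling consolidated into SparkMasterRegex above can be illustrated with a small self-contained sketch; the regexes are copied from that object, while the describe helper and the sample URLs are illustrative only:

object MasterUrlDemo {
  // Copied from SparkMasterRegex for illustration
  val LOCAL_N_REGEX = """local\[([0-9]+|\*)\]""".r
  val MESOS_REGEX = """mesos://(.*)""".r

  // Hypothetical helper mirroring numDriverCores / createTaskScheduler dispatch
  def describe(master: String): String = master match {
    case "local" => "local mode, 1 driver core"
    case LOCAL_N_REGEX(threads) =>
      val n = if (threads == "*") Runtime.getRuntime.availableProcessors() else threads.toInt
      s"local mode, $n driver cores"
    case MESOS_REGEX(url) => s"Mesos cluster at $url (coarse-grained by default now)"
    case _ => "other cluster manager; driver cores not used for execution"
  }

  def main(args: Array[String]): Unit = {
    Seq("local", "local[4]", "local[*]", "mesos://zk://zk1:2181").map(describe).foreach(println)
  }
}

Note that bare zk://host:port master URLs still work but now log a deprecation warning and are rewritten to mesos://zk://...; per the change above they will stop working in Spark 2.0.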
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index adfece4d6e7c..52acde1b414e 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -20,29 +20,27 @@ package org.apache.spark import java.io.File import java.net.Socket -import akka.actor.ActorSystem - import scala.collection.mutable import scala.util.Properties +import akka.actor.ActorSystem import com.google.common.collect.MapMaker import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.python.PythonWorkerFactory import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.memory.{MemoryManager, StaticMemoryManager, UnifiedMemoryManager} import org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService -import org.apache.spark.network.nio.NioBlockTransferService import org.apache.spark.rpc.{RpcEndpointRef, RpcEndpoint, RpcEnv} import org.apache.spark.rpc.akka.AkkaRpcEnv import org.apache.spark.scheduler.{OutputCommitCoordinator, LiveListenerBus} import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} +import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.storage._ -import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator} -import org.apache.spark.util.{RpcUtils, Utils} +import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils} /** * :: DeveloperApi :: @@ -58,6 +56,7 @@ import org.apache.spark.util.{RpcUtils, Utils} class SparkEnv ( val executorId: String, private[spark] val rpcEnv: RpcEnv, + _actorSystem: ActorSystem, // TODO Remove actorSystem val serializer: Serializer, val closureSerializer: Serializer, val cacheManager: CacheManager, @@ -67,17 +66,15 @@ class SparkEnv ( val blockTransferService: BlockTransferService, val blockManager: BlockManager, val securityManager: SecurityManager, - val httpFileServer: HttpFileServer, val sparkFilesDir: String, val metricsSystem: MetricsSystem, - val shuffleMemoryManager: ShuffleMemoryManager, - val executorMemoryManager: ExecutorMemoryManager, + val memoryManager: MemoryManager, val outputCommitCoordinator: OutputCommitCoordinator, val conf: SparkConf) extends Logging { // TODO Remove actorSystem @deprecated("Actor system is no longer supported as of 1.4.0", "1.4.0") - val actorSystem: ActorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem + val actorSystem: ActorSystem = _actorSystem private[spark] var isStopped = false private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() @@ -93,7 +90,6 @@ class SparkEnv ( if (!isStopped) { isStopped = true pythonWorkers.values.foreach(_.stop()) - Option(httpFileServer).foreach(_.stop()) mapOutputTracker.stop() shuffleManager.stop() broadcastManager.stop() @@ -101,6 +97,9 @@ class SparkEnv ( blockManager.master.stop() metricsSystem.stop() outputCommitCoordinator.stop() + if (!rpcEnv.isInstanceOf[AkkaRpcEnv]) { + actorSystem.shutdown() + } rpcEnv.shutdown() // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut @@ -185,6 +184,7 @@ object SparkEnv extends Logging { conf: SparkConf, isLocal: Boolean, listenerBus: LiveListenerBus, + numCores: Int, mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv 
= { assert(conf.contains("spark.driver.host"), "spark.driver.host is not set on the driver!") assert(conf.contains("spark.driver.port"), "spark.driver.port is not set on the driver!") @@ -197,6 +197,7 @@ object SparkEnv extends Logging { port, isDriver = true, isLocal = isLocal, + numUsableCores = numCores, listenerBus = listenerBus, mockOutputCommitCoordinator = mockOutputCommitCoordinator ) @@ -236,8 +237,8 @@ object SparkEnv extends Logging { port: Int, isDriver: Boolean, isLocal: Boolean, + numUsableCores: Int, listenerBus: LiveListenerBus = null, - numUsableCores: Int = 0, mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = { // Listener bus is only used on the driver @@ -249,13 +250,34 @@ object SparkEnv extends Logging { // Create the ActorSystem for Akka and get the port it binds to. val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName - val rpcEnv = RpcEnv.create(actorSystemName, hostname, port, conf, securityManager) - val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem + val rpcEnv = RpcEnv.create(actorSystemName, hostname, port, conf, securityManager, + clientMode = !isDriver) + val actorSystem: ActorSystem = + if (rpcEnv.isInstanceOf[AkkaRpcEnv]) { + rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem + } else { + val actorSystemPort = + if (port == 0 || rpcEnv.address == null) { + port + } else { + rpcEnv.address.port + 1 + } + // Create a ActorSystem for legacy codes + AkkaUtils.createActorSystem( + actorSystemName + "ActorSystem", + hostname, + actorSystemPort, + conf, + securityManager + )._1 + } // Figure out which port Akka actually bound to in case the original port is 0 or occupied. + // In the non-driver case, the RPC env's address may be null since it may not be listening + // for incoming connections. 
if (isDriver) { conf.set("spark.driver.port", rpcEnv.address.port.toString) - } else { + } else if (rpcEnv.address != null) { conf.set("spark.executor.port", rpcEnv.address.port.toString) } @@ -319,21 +341,21 @@ object SparkEnv extends Logging { val shortShuffleMgrNames = Map( "hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager", "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager", - "tungsten-sort" -> "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager") + "tungsten-sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager") val shuffleMgrName = conf.get("spark.shuffle.manager", "sort") val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName) val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass) - val shuffleMemoryManager = new ShuffleMemoryManager(conf) - - val blockTransferService = - conf.get("spark.shuffle.blockTransferService", "netty").toLowerCase match { - case "netty" => - new NettyBlockTransferService(conf, securityManager, numUsableCores) - case "nio" => - new NioBlockTransferService(conf, securityManager) + val useLegacyMemoryManager = conf.getBoolean("spark.memory.useLegacyMode", false) + val memoryManager: MemoryManager = + if (useLegacyMemoryManager) { + new StaticMemoryManager(conf, numUsableCores) + } else { + UnifiedMemoryManager(conf, numUsableCores) } + val blockTransferService = new NettyBlockTransferService(conf, securityManager, numUsableCores) + val blockManagerMaster = new BlockManagerMaster(registerOrLookupEndpoint( BlockManagerMaster.DRIVER_ENDPOINT_NAME, new BlockManagerMasterEndpoint(rpcEnv, isLocal, conf, listenerBus)), @@ -341,24 +363,13 @@ object SparkEnv extends Logging { // NB: blockManager is not valid until initialize() is called later. val blockManager = new BlockManager(executorId, rpcEnv, blockManagerMaster, - serializer, conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager, - numUsableCores) + serializer, conf, memoryManager, mapOutputTracker, shuffleManager, + blockTransferService, securityManager, numUsableCores) val broadcastManager = new BroadcastManager(isDriver, conf, securityManager) val cacheManager = new CacheManager(blockManager) - val httpFileServer = - if (isDriver) { - val fileServerPort = conf.getInt("spark.fileserver.port", 0) - val server = new HttpFileServer(conf, securityManager, fileServerPort) - server.initialize() - conf.set("spark.fileserver.uri", server.serverUri) - server - } else { - null - } - val metricsSystem = if (isDriver) { // Don't start metrics system right now for Driver. // We need to wait for the task scheduler to give us an app ID. 
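The memory-manager and shuffle-manager selection above is driven purely by configuration keys; a sketch of opting back into the pre-1.6 behaviour (the values shown are illustrative):

import org.apache.spark.SparkConf

// Assumes a driver-side conf; the keys are the ones read in SparkEnv.create above.
val conf = new SparkConf()
  .set("spark.memory.useLegacyMode", "true")      // selects StaticMemoryManager instead of UnifiedMemoryManager
  .set("spark.storage.memoryFraction", "0.6")     // legacy fractions apply only in this mode
  .set("spark.shuffle.manager", "tungsten-sort")  // now resolves to the same SortShuffleManager as "sort"

There is no corresponding knob for the block transfer service in the sketch: the nio option is removed entirely and Netty is always used.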
@@ -390,18 +401,10 @@ object SparkEnv extends Logging { new OutputCommitCoordinatorEndpoint(rpcEnv, outputCommitCoordinator)) outputCommitCoordinator.coordinatorRef = Some(outputCommitCoordinatorRef) - val executorMemoryManager: ExecutorMemoryManager = { - val allocator = if (conf.getBoolean("spark.unsafe.offHeap", false)) { - MemoryAllocator.UNSAFE - } else { - MemoryAllocator.HEAP - } - new ExecutorMemoryManager(allocator) - } - val envInstance = new SparkEnv( executorId, rpcEnv, + actorSystem, serializer, closureSerializer, cacheManager, @@ -411,11 +414,9 @@ object SparkEnv extends Logging { blockTransferService, blockManager, securityManager, - httpFileServer, sparkFilesDir, metricsSystem, - shuffleMemoryManager, - executorMemoryManager, + memoryManager, outputCommitCoordinator, conf) diff --git a/core/src/main/scala/org/apache/spark/SparkException.scala b/core/src/main/scala/org/apache/spark/SparkException.scala index 2ebd7a7151a5..977a27bdfe1b 100644 --- a/core/src/main/scala/org/apache/spark/SparkException.scala +++ b/core/src/main/scala/org/apache/spark/SparkException.scala @@ -30,3 +30,10 @@ class SparkException(message: String, cause: Throwable) */ private[spark] class SparkDriverExecutionException(cause: Throwable) extends SparkException("Execution error", cause) + +/** + * Exception thrown when the main user code is run as a child process (e.g. pyspark) and we want + * the parent SparkSubmit process to exit with the same exit code. + */ +private[spark] case class SparkUserAppException(exitCode: Int) + extends SparkException(s"User application exited with $exitCode") diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index f5dd36cbcfe6..ac6eaab20d8d 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -37,7 +37,7 @@ import org.apache.spark.util.SerializableJobConf * a filename to write to, etc, exactly like in a Hadoop MapReduce job. */ private[spark] -class SparkHadoopWriter(@transient jobConf: JobConf) +class SparkHadoopWriter(jobConf: JobConf) extends Logging with SparkHadoopMapRedUtil with Serializable { @@ -104,8 +104,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf) } def commit() { - SparkHadoopMapRedUtil.commitTask( - getOutputCommitter(), getTaskContext(), jobID, splitID, attemptID) + SparkHadoopMapRedUtil.commitTask(getOutputCommitter(), getTaskContext(), jobID, splitID) } def commitJob() { diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 5d2c551d5851..af558d6e5b47 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -21,8 +21,8 @@ import java.io.Serializable import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics +import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.source.Source -import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util.TaskCompletionListener @@ -61,12 +61,12 @@ object TaskContext { protected[spark] def unset(): Unit = taskContext.remove() /** - * Return an empty task context that is not actually used. - * Internal use only. + * An empty task context that does not represent an actual task. 
*/ - private[spark] def empty(): TaskContext = { - new TaskContextImpl(0, 0, 0, 0, null, null) + private[spark] def empty(): TaskContextImpl = { + new TaskContextImpl(0, 0, 0, 0, null, null, Seq.empty) } + } @@ -187,4 +187,9 @@ abstract class TaskContext extends Serializable { * accumulator id and the value of the Map is the latest accumulator local value. */ private[spark] def collectAccumulators(): Map[Long, Any] + + /** + * Accumulators for tracking internal metrics indexed by the name. + */ + private[spark] val internalMetricsToAccumulators: Map[String, Accumulator[Long]] } diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 9ee168ae016f..f0ae83a9341b 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -20,9 +20,9 @@ package org.apache.spark import scala.collection.mutable.{ArrayBuffer, HashMap} import org.apache.spark.executor.TaskMetrics +import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.MetricsSystem import org.apache.spark.metrics.source.Source -import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException} private[spark] class TaskContextImpl( @@ -32,6 +32,7 @@ private[spark] class TaskContextImpl( override val attemptNumber: Int, override val taskMemoryManager: TaskMemoryManager, @transient private val metricsSystem: MetricsSystem, + internalAccumulators: Seq[Accumulator[Long]], val runningLocally: Boolean = false, val taskMetrics: TaskMetrics = TaskMetrics.empty) extends TaskContext @@ -114,4 +115,11 @@ private[spark] class TaskContextImpl( private[spark] override def collectAccumulators(): Map[Long, Any] = synchronized { accumulators.mapValues(_.localValue).toMap } + + private[spark] override val internalMetricsToAccumulators: Map[String, Accumulator[Long]] = { + // Explicitly register internal accumulators here because these are + // not captured in the task closure and are already deserialized + internalAccumulators.foreach(registerAccumulator) + internalAccumulators.map { a => (a.name.get, a) }.toMap + } } diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 48fd3e7e23d5..13241b77bf97 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -17,11 +17,17 @@ package org.apache.spark +import java.io.{ObjectInputStream, ObjectOutputStream} + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.Utils +// ============================================================================================== +// NOTE: new task end reasons MUST be accompanied with serialization logic in util.JsonProtocol! +// ============================================================================================== + /** * :: DeveloperApi :: * Various possible reasons why a task ended. The low-level TaskScheduler is supposed to retry @@ -46,6 +52,14 @@ case object Success extends TaskEndReason sealed trait TaskFailedReason extends TaskEndReason { /** Error message displayed in the web UI. 
*/ def toErrorString: String + + /** + * Whether this task failure should be counted towards the maximum number of times the task is + * allowed to fail before the stage is aborted. Set to false in cases where the task's failure + * was unrelated to the task; for example, if the task failed because the executor it was running + * on was killed. + */ + def countTowardsTaskFailures: Boolean = true } /** @@ -90,6 +104,10 @@ case class FetchFailed( * * `fullStackTrace` is a better representation of the stack trace because it contains the whole * stack trace including the exception and its causes + * + * `exception` is the actual exception that caused the task to fail. It may be `None` in + * the case that the exception is not in fact serializable. If a task fails more than + * once (due to retries), `exception` is that one that caused the last failure. */ @DeveloperApi case class ExceptionFailure( @@ -97,11 +115,26 @@ case class ExceptionFailure( description: String, stackTrace: Array[StackTraceElement], fullStackTrace: String, - metrics: Option[TaskMetrics]) + metrics: Option[TaskMetrics], + private val exceptionWrapper: Option[ThrowableSerializationWrapper]) extends TaskFailedReason { + /** + * `preserveCause` is used to keep the exception itself so it is available to the + * driver. This may be set to `false` in the event that the exception is not in fact + * serializable. + */ + private[spark] def this(e: Throwable, metrics: Option[TaskMetrics], preserveCause: Boolean) { + this(e.getClass.getName, e.getMessage, e.getStackTrace, Utils.exceptionString(e), metrics, + if (preserveCause) Some(new ThrowableSerializationWrapper(e)) else None) + } + private[spark] def this(e: Throwable, metrics: Option[TaskMetrics]) { - this(e.getClass.getName, e.getMessage, e.getStackTrace, Utils.exceptionString(e), metrics) + this(e, metrics, preserveCause = true) + } + + def exception: Option[Throwable] = exceptionWrapper.flatMap { + (w: ThrowableSerializationWrapper) => Option(w.exception) } override def toErrorString: String = @@ -127,6 +160,25 @@ case class ExceptionFailure( } } +/** + * A class for recovering from exceptions when deserializing a Throwable that was + * thrown in user task code. If the Throwable cannot be deserialized it will be null, + * but the stacktrace and message will be preserved correctly in SparkException. + */ +private[spark] class ThrowableSerializationWrapper(var exception: Throwable) extends + Serializable with Logging { + private def writeObject(out: ObjectOutputStream): Unit = { + out.writeObject(exception) + } + private def readObject(in: ObjectInputStream): Unit = { + try { + exception = in.readObject().asInstanceOf[Throwable] + } catch { + case e : Exception => log.warn("Task exception could not be deserialized", e) + } + } +} + /** * :: DeveloperApi :: * The task finished successfully, but the result was lost from the executor's block manager before @@ -151,9 +203,18 @@ case object TaskKilled extends TaskFailedReason { * Task requested the driver to commit, but was denied. 
*/ @DeveloperApi -case class TaskCommitDenied(jobID: Int, partitionID: Int, attemptID: Int) extends TaskFailedReason { +case class TaskCommitDenied( + jobID: Int, + partitionID: Int, + attemptNumber: Int) extends TaskFailedReason { override def toErrorString: String = s"TaskCommitDenied (Driver denied task commit)" + - s" for job: $jobID, partition: $partitionID, attempt: $attemptID" + s" for job: $jobID, partition: $partitionID, attemptNumber: $attemptNumber" + /** + * If a task failed because its attempt to commit was denied, do not count this failure + * towards failing the stage. This is intended to prevent spurious stage failures in cases + * where many speculative tasks are launched and denied to commit. + */ + override def countTowardsTaskFailures: Boolean = false } /** @@ -162,8 +223,22 @@ case class TaskCommitDenied(jobID: Int, partitionID: Int, attemptID: Int) extend * the task crashed the JVM. */ @DeveloperApi -case class ExecutorLostFailure(execId: String) extends TaskFailedReason { - override def toErrorString: String = s"ExecutorLostFailure (executor ${execId} lost)" +case class ExecutorLostFailure( + execId: String, + exitCausedByApp: Boolean = true, + reason: Option[String]) extends TaskFailedReason { + override def toErrorString: String = { + val exitBehavior = if (exitCausedByApp) { + "caused by one of the running tasks" + } else { + "unrelated to the running tasks" + } + s"ExecutorLostFailure (executor ${execId} exited ${exitBehavior})" + + reason.map { r => s" Reason: $r" }.getOrElse("") + } + + override def countTowardsTaskFailures: Boolean = exitCausedByApp } /** diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index a1ebbecf93b7..43c89b258f2f 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -19,14 +19,20 @@ package org.apache.spark import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream} import java.net.{URI, URL} +import java.nio.charset.StandardCharsets +import java.nio.file.Paths +import java.util.Arrays import java.util.jar.{JarEntry, JarOutputStream} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer -import com.google.common.base.Charsets.UTF_8 import com.google.common.io.{ByteStreams, Files} import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider} +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.scheduler._ import org.apache.spark.util.Utils /** @@ -71,22 +77,22 @@ private[spark] object TestUtils { files.foreach { case (k, v) => val entry = new JarEntry(k) jarStream.putNextEntry(entry) - ByteStreams.copy(new ByteArrayInputStream(v.getBytes(UTF_8)), jarStream) + ByteStreams.copy(new ByteArrayInputStream(v.getBytes(StandardCharsets.UTF_8)), jarStream) } jarStream.close() jarFile.toURI.toURL } /** - * Create a jar file that contains this set of files. All files will be located at the root - * of the jar. + * Create a jar file that contains this set of files. All files will be located in the specified + * directory or at the root of the jar.
*/ - def createJar(files: Seq[File], jarFile: File): URL = { + def createJar(files: Seq[File], jarFile: File, directoryPrefix: Option[String] = None): URL = { val jarFileStream = new FileOutputStream(jarFile) val jarStream = new JarOutputStream(jarFileStream, new java.util.jar.Manifest()) for (file <- files) { - val jarEntry = new JarEntry(file.getName) + val jarEntry = new JarEntry(Paths.get(directoryPrefix.getOrElse(""), file.getName).toString) jarStream.putNextEntry(jarEntry) val in = new FileInputStream(file) @@ -118,14 +124,14 @@ private[spark] object TestUtils { classpathUrls: Seq[URL]): File = { val compiler = ToolProvider.getSystemJavaCompiler - // Calling this outputs a class file in pwd. It's easier to just rename the file than + // Calling this outputs a class file in pwd. It's easier to just rename the files than // build a custom FileManager that controls the output location. val options = if (classpathUrls.nonEmpty) { Seq("-classpath", classpathUrls.map { _.getFile }.mkString(File.pathSeparator)) } else { Seq() } - compiler.getTask(null, null, null, options, null, Seq(sourceFile)).call() + compiler.getTask(null, null, null, options.asJava, null, Arrays.asList(sourceFile)).call() val fileName = className + ".class" val result = new File(fileName) @@ -153,4 +159,51 @@ private[spark] object TestUtils { " @Override public String toString() { return \"" + toStringValue + "\"; }}") createCompiledClass(className, destDir, sourceFile, classpathUrls) } + + /** + * Run some code involving jobs submitted to the given context and assert that the jobs spilled. + */ + def assertSpilled[T](sc: SparkContext, identifier: String)(body: => T): Unit = { + val spillListener = new SpillListener + sc.addSparkListener(spillListener) + body + assert(spillListener.numSpilledStages > 0, s"expected $identifier to spill, but did not") + } + + /** + * Run some code involving jobs submitted to the given context and assert that the jobs + * did not spill. + */ + def assertNotSpilled[T](sc: SparkContext, identifier: String)(body: => T): Unit = { + val spillListener = new SpillListener + sc.addSparkListener(spillListener) + body + assert(spillListener.numSpilledStages == 0, s"expected $identifier to not spill, but did") + } + +} + + +/** + * A [[SparkListener]] that detects whether spills have occurred in Spark jobs. + */ +private class SpillListener extends SparkListener { + private val stageIdToTaskMetrics = new mutable.HashMap[Int, ArrayBuffer[TaskMetrics]] + private val spilledStageIds = new mutable.HashSet[Int] + + def numSpilledStages: Int = spilledStageIds.size + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + stageIdToTaskMetrics.getOrElseUpdate( + taskEnd.stageId, new ArrayBuffer[TaskMetrics]) += taskEnd.taskMetrics + } + + override def onStageCompleted(stageComplete: SparkListenerStageCompleted): Unit = { + val stageId = stageComplete.stageInfo.stageId + val metrics = stageIdToTaskMetrics.remove(stageId).toSeq.flatten + val spilled = metrics.map(_.memoryBytesSpilled).sum > 0 + if (spilled) { + spilledStageIds += stageId + } + } } diff --git a/core/src/main/scala/org/apache/spark/annotation/Since.scala b/core/src/main/scala/org/apache/spark/annotation/Since.scala new file mode 100644 index 000000000000..af483e361e33 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/annotation/Since.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
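The new assertSpilled/assertNotSpilled helpers wrap a body of Spark jobs and check the SpillListener afterwards. A usage sketch as it might appear inside one of Spark's own test suites (TestUtils is private[spark], and whether the job actually spills depends on the configured memory limits):

package org.apache.spark

// Hypothetical test helper: run a shuffle-heavy aggregation and assert that
// at least one stage spilled to disk.
object SpillCheckExample {
  def checkAggregationSpills(sc: SparkContext): Unit = {
    TestUtils.assertSpilled(sc, "external aggregation") {
      sc.parallelize(1 to 1000000, 10)
        .map(i => (i % 1000, i.toLong))
        .reduceByKey(_ + _)
        .count()
    }
  }
}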
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.annotation + +import scala.annotation.StaticAnnotation +import scala.annotation.meta._ + +/** + * A Scala annotation that specifies the Spark version when a definition was added. + * Different from the `@since` tag in JavaDoc, this annotation does not require explicit JavaDoc and + * hence works for overridden methods that inherit API documentation directly from parents. + * The limitation is that it does not show up in the generated Java API documentation. + */ +@param @field @getter @setter @beanGetter @beanSetter +private[spark] class Since(version: String) extends StaticAnnotation diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index a650df605b92..c32aefac465b 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -24,7 +24,6 @@ import scala.reflect.ClassTag import org.apache.spark.Partitioner import org.apache.spark.SparkContext.doubleRDDToDoubleRDDFunctions -import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.function.{Function => JFunction} import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.rdd.RDD @@ -209,25 +208,19 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) srdd.meanApprox(timeout, confidence) /** - * :: Experimental :: * Approximate operation to return the mean within a timeout. */ - @Experimental def meanApprox(timeout: Long): PartialResult[BoundedDouble] = srdd.meanApprox(timeout) /** - * :: Experimental :: * Approximate operation to return the sum within a timeout. */ - @Experimental def sumApprox(timeout: Long, confidence: JDouble): PartialResult[BoundedDouble] = srdd.sumApprox(timeout, confidence) /** - * :: Experimental :: * Approximate operation to return the sum within a timeout. 
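Because @Since is an ordinary Scala annotation rather than a Javadoc tag, it can sit directly on overrides that inherit their documentation. A hedged, hypothetical usage sketch; the class and package below are invented, and real uses must live under org.apache.spark since the annotation is private[spark]:

package org.apache.spark.mycomponent

import org.apache.spark.annotation.Since

// Hypothetical internal API: the annotation records the version in which each
// definition first appeared, including the inherited-doc override.
private[spark] abstract class Metric {
  def value: Double
}

private[spark] class RunningMax extends Metric {
  private var max = Double.NegativeInfinity

  @Since("1.6.0")
  def update(x: Double): Unit = { max = math.max(max, x) }

  @Since("1.6.0")
  override def value: Double = max
}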
*/ - @Experimental def sumApprox(timeout: Long): PartialResult[BoundedDouble] = srdd.sumApprox(timeout) /** diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaHadoopRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaHadoopRDD.scala index 0ae0b4ec042e..891bcddeac28 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaHadoopRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.api.java -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.hadoop.mapred.InputSplit @@ -37,7 +37,7 @@ class JavaHadoopRDD[K, V](rdd: HadoopRDD[K, V]) def mapPartitionsWithInputSplit[R]( f: JFunction2[InputSplit, java.util.Iterator[(K, V)], java.util.Iterator[R]], preservesPartitioning: Boolean = false): JavaRDD[R] = { - new JavaRDD(rdd.mapPartitionsWithInputSplit((a, b) => f.call(a, asJavaIterator(b)), + new JavaRDD(rdd.mapPartitionsWithInputSplit((a, b) => f.call(a, b.asJava).asScala, preservesPartitioning)(fakeClassTag))(fakeClassTag) } } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaNewHadoopRDD.scala index ec4f3964d75e..0f49279f3e64 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaNewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaNewHadoopRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.api.java -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.hadoop.mapreduce.InputSplit @@ -37,7 +37,7 @@ class JavaNewHadoopRDD[K, V](rdd: NewHadoopRDD[K, V]) def mapPartitionsWithInputSplit[R]( f: JFunction2[InputSplit, java.util.Iterator[(K, V)], java.util.Iterator[R]], preservesPartitioning: Boolean = false): JavaRDD[R] = { - new JavaRDD(rdd.mapPartitionsWithInputSplit((a, b) => f.call(a, asJavaIterator(b)), + new JavaRDD(rdd.mapPartitionsWithInputSplit((a, b) => f.call(a, b.asJava).asScala, preservesPartitioning)(fakeClassTag))(fakeClassTag) } } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 8441bb3a3047..87deaf20e2b2 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -20,7 +20,7 @@ package org.apache.spark.api.java import java.util.{Comparator, List => JList, Map => JMap} import java.lang.{Iterable => JIterable} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.ClassTag @@ -32,7 +32,6 @@ import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} import org.apache.spark.{HashPartitioner, Partitioner} import org.apache.spark.Partitioner._ -import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, PairFunction} @@ -142,7 +141,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) def sampleByKey(withReplacement: Boolean, fractions: JMap[K, Double], seed: Long): JavaPairRDD[K, V] = - new JavaPairRDD[K, V](rdd.sampleByKey(withReplacement, fractions, seed)) + new JavaPairRDD[K, V](rdd.sampleByKey(withReplacement, fractions.asScala, 
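The recurring change throughout this patch is replacing scala.collection.JavaConversions (implicit, silent wrapping) with scala.collection.JavaConverters (explicit .asScala/.asJava calls). A small self-contained illustration of the explicit style, using invented helper names:

import java.util.{List => JList, Map => JMap}
import scala.collection.JavaConverters._

// Explicit conversions make every Java <-> Scala boundary visible at the call site.
def wordLengths(words: JList[String]): JMap[String, Integer] = {
  val lengths: Map[String, Integer] =
    words.asScala.map(w => w -> Integer.valueOf(w.length)).toMap
  lengths.asJava
}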
seed)) /** * Return a subset of this RDD sampled by key (via stratified sampling). @@ -159,7 +158,6 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) sampleByKey(withReplacement, fractions, Utils.random.nextLong) /** - * ::Experimental:: * Return a subset of this RDD sampled by key (via stratified sampling) containing exactly * math.ceil(numItems * samplingRate) for each stratum (group of pairs with the same key). * @@ -169,14 +167,12 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * additional pass over the RDD to guarantee sample size; when sampling with replacement, we need * two additional passes. */ - @Experimental def sampleByKeyExact(withReplacement: Boolean, fractions: JMap[K, Double], seed: Long): JavaPairRDD[K, V] = - new JavaPairRDD[K, V](rdd.sampleByKeyExact(withReplacement, fractions, seed)) + new JavaPairRDD[K, V](rdd.sampleByKeyExact(withReplacement, fractions.asScala, seed)) /** - * ::Experimental:: * Return a subset of this RDD sampled by key (via stratified sampling) containing exactly * math.ceil(numItems * samplingRate) for each stratum (group of pairs with the same key). * @@ -188,7 +184,6 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * * Use Utils.random.nextLong as the default seed for the random number generator. */ - @Experimental def sampleByKeyExact(withReplacement: Boolean, fractions: JMap[K, Double]): JavaPairRDD[K, V] = sampleByKeyExact(withReplacement, fractions, Utils.random.nextLong) @@ -220,13 +215,13 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) /** * Generic function to combine the elements for each key using a custom set of aggregation * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a - * "combined type" C * Note that V and C can be different -- for example, one might group an + * "combined type" C. Note that V and C can be different -- for example, one might group an * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three * functions: * - * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) - * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) - * - `mergeCombiners`, to combine two C's into a single one. + * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) + * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) + * - `mergeCombiners`, to combine two C's into a single one. * * In addition, users can control the partitioning of the output RDD, the serializer that is use * for the shuffle, and whether to perform map-side aggregation (if a mapper can produce multiple @@ -239,7 +234,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) mapSideCombine: Boolean, serializer: Serializer): JavaPairRDD[K, C] = { implicit val ctag: ClassTag[C] = fakeClassTag - fromRDD(rdd.combineByKey( + fromRDD(rdd.combineByKeyWithClassTag( createCombiner, mergeValue, mergeCombiners, @@ -252,13 +247,13 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) /** * Generic function to combine the elements for each key using a custom set of aggregation * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a - * "combined type" C * Note that V and C can be different -- for example, one might group an + * "combined type" C. Note that V and C can be different -- for example, one might group an * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). 
Users provide three * functions: * - * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) - * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) - * - `mergeCombiners`, to combine two C's into a single one. + * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) + * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) + * - `mergeCombiners`, to combine two C's into a single one. * * In addition, users can control the partitioning of the output RDD. This method automatically * uses map-side aggregation in shuffling the RDD. @@ -300,20 +295,16 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) def countByKey(): java.util.Map[K, Long] = mapAsSerializableJavaMap(rdd.countByKey()) /** - * :: Experimental :: * Approximate version of countByKey that can return a partial result if it does * not finish within a timeout. */ - @Experimental def countByKeyApprox(timeout: Long): PartialResult[java.util.Map[K, BoundedDouble]] = rdd.countByKeyApprox(timeout).map(mapAsSerializableJavaMap) /** - * :: Experimental :: * Approximate version of countByKey that can return a partial result if it does * not finish within a timeout. */ - @Experimental def countByKeyApprox(timeout: Long, confidence: Double = 0.95) : PartialResult[java.util.Map[K, BoundedDouble]] = rdd.countByKeyApprox(timeout, confidence).map(mapAsSerializableJavaMap) @@ -768,7 +759,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Return the list of values in the RDD for key `key`. This operation is done efficiently if the * RDD has a known partitioner by only searching the partition that the key maps to. */ - def lookup(key: K): JList[V] = seqAsJavaList(rdd.lookup(key)) + def lookup(key: K): JList[V] = rdd.lookup(key).asJava /** Output the RDD to any Hadoop-supported file system. 
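The three-function contract documented above (createCombiner, mergeValue, mergeCombiners) is easiest to see with a concrete aggregation. A sketch on the Scala pair-RDD API, which the Java wrapper above delegates to; the helper name is illustrative:

import org.apache.spark.rdd.RDD

// Per-key averages: the "combined type" C is a (sum, count) pair.
def averageByKey(rdd: RDD[(String, Double)]): RDD[(String, Double)] = {
  val sumCounts = rdd.combineByKey(
    (v: Double) => (v, 1L),                                    // createCombiner
    (c: (Double, Long), v: Double) => (c._1 + v, c._2 + 1L),   // mergeValue
    (c1: (Double, Long), c2: (Double, Long)) =>                // mergeCombiners
      (c1._1 + c2._1, c1._2 + c2._2))
  sumCounts.mapValues { case (sum, count) => sum / count }
}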
*/ def saveAsHadoopFile[F <: OutputFormat[_, _]]( @@ -987,30 +978,27 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) object JavaPairRDD { private[spark] def groupByResultToJava[K: ClassTag, T](rdd: RDD[(K, Iterable[T])]): RDD[(K, JIterable[T])] = { - rddToPairRDDFunctions(rdd).mapValues(asJavaIterable) + rddToPairRDDFunctions(rdd).mapValues(_.asJava) } private[spark] def cogroupResultToJava[K: ClassTag, V, W]( rdd: RDD[(K, (Iterable[V], Iterable[W]))]): RDD[(K, (JIterable[V], JIterable[W]))] = { - rddToPairRDDFunctions(rdd).mapValues(x => (asJavaIterable(x._1), asJavaIterable(x._2))) + rddToPairRDDFunctions(rdd).mapValues(x => (x._1.asJava, x._2.asJava)) } private[spark] def cogroupResult2ToJava[K: ClassTag, V, W1, W2]( rdd: RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2]))]) : RDD[(K, (JIterable[V], JIterable[W1], JIterable[W2]))] = { - rddToPairRDDFunctions(rdd) - .mapValues(x => (asJavaIterable(x._1), asJavaIterable(x._2), asJavaIterable(x._3))) + rddToPairRDDFunctions(rdd).mapValues(x => (x._1.asJava, x._2.asJava, x._3.asJava)) } private[spark] def cogroupResult3ToJava[K: ClassTag, V, W1, W2, W3]( rdd: RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))]) : RDD[(K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3]))] = { - rddToPairRDDFunctions(rdd) - .mapValues(x => - (asJavaIterable(x._1), asJavaIterable(x._2), asJavaIterable(x._3), asJavaIterable(x._4))) + rddToPairRDDFunctions(rdd).mapValues(x => (x._1.asJava, x._2.asJava, x._3.asJava, x._4.asJava)) } def fromRDD[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]): JavaPairRDD[K, V] = { diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 829fae1d1d9b..0e4d7dce0f2f 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -21,7 +21,6 @@ import java.{lang => jl} import java.lang.{Iterable => JIterable, Long => JLong} import java.util.{Comparator, List => JList, Iterator => JIterator} -import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ import scala.reflect.ClassTag @@ -29,7 +28,7 @@ import com.google.common.base.Optional import org.apache.hadoop.io.compress.CompressionCodec import org.apache.spark._ -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaPairRDD._ import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap @@ -59,10 +58,14 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def rdd: RDD[T] @deprecated("Use partitions() instead.", "1.1.0") - def splits: JList[Partition] = new java.util.ArrayList(rdd.partitions.toSeq) + def splits: JList[Partition] = rdd.partitions.toSeq.asJava /** Set of partitions in this RDD. */ - def partitions: JList[Partition] = new java.util.ArrayList(rdd.partitions.toSeq) + def partitions: JList[Partition] = rdd.partitions.toSeq.asJava + + /** Return the number of partitions in this RDD. */ + @Since("1.6.0") + def getNumPartitions: Int = rdd.getNumPartitions /** The partitioner of this RDD. */ def partitioner: Optional[Partitioner] = JavaUtils.optionToOptional(rdd.partitioner) @@ -82,7 +85,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * subclasses of RDD. 
*/ def iterator(split: Partition, taskContext: TaskContext): java.util.Iterator[T] = - asJavaIterator(rdd.iterator(split, taskContext)) + rdd.iterator(split, taskContext).asJava // Transformations (return a new RDD) @@ -99,7 +102,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def mapPartitionsWithIndex[R]( f: JFunction2[jl.Integer, java.util.Iterator[T], java.util.Iterator[R]], preservesPartitioning: Boolean = false): JavaRDD[R] = - new JavaRDD(rdd.mapPartitionsWithIndex(((a, b) => f(a, asJavaIterator(b))), + new JavaRDD(rdd.mapPartitionsWithIndex((a, b) => f.call(a, b.asJava).asScala, preservesPartitioning)(fakeClassTag))(fakeClassTag) /** @@ -153,7 +156,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaRDD[U] = { def fn: (Iterator[T]) => Iterator[U] = { - (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + (x: Iterator[T]) => f.call(x.asJava).iterator().asScala } JavaRDD.fromRDD(rdd.mapPartitions(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -164,7 +167,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U], preservesPartitioning: Boolean): JavaRDD[U] = { def fn: (Iterator[T]) => Iterator[U] = { - (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + (x: Iterator[T]) => f.call(x.asJava).iterator().asScala } JavaRDD.fromRDD( rdd.mapPartitions(fn, preservesPartitioning)(fakeClassTag[U]))(fakeClassTag[U]) @@ -175,7 +178,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def mapPartitionsToDouble(f: DoubleFlatMapFunction[java.util.Iterator[T]]): JavaDoubleRDD = { def fn: (Iterator[T]) => Iterator[jl.Double] = { - (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + (x: Iterator[T]) => f.call(x.asJava).iterator().asScala } new JavaDoubleRDD(rdd.mapPartitions(fn).map((x: jl.Double) => x.doubleValue())) } @@ -186,7 +189,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]): JavaPairRDD[K2, V2] = { def fn: (Iterator[T]) => Iterator[(K2, V2)] = { - (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + (x: Iterator[T]) => f.call(x.asJava).iterator().asScala } JavaPairRDD.fromRDD(rdd.mapPartitions(fn))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -197,7 +200,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def mapPartitionsToDouble(f: DoubleFlatMapFunction[java.util.Iterator[T]], preservesPartitioning: Boolean): JavaDoubleRDD = { def fn: (Iterator[T]) => Iterator[jl.Double] = { - (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + (x: Iterator[T]) => f.call(x.asJava).iterator().asScala } new JavaDoubleRDD(rdd.mapPartitions(fn, preservesPartitioning) .map(x => x.doubleValue())) @@ -209,7 +212,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2], preservesPartitioning: Boolean): JavaPairRDD[K2, V2] = { def fn: (Iterator[T]) => Iterator[(K2, V2)] = { - (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + (x: Iterator[T]) => f.call(x.asJava).iterator().asScala } JavaPairRDD.fromRDD( rdd.mapPartitions(fn, preservesPartitioning))(fakeClassTag[K2], fakeClassTag[V2]) @@ 
-219,14 +222,14 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Applies a function f to each partition of this RDD. */ def foreachPartition(f: VoidFunction[java.util.Iterator[T]]) { - rdd.foreachPartition((x => f.call(asJavaIterator(x)))) + rdd.foreachPartition((x => f.call(x.asJava))) } /** * Return an RDD created by coalescing all elements within each partition into an array. */ def glom(): JavaRDD[JList[T]] = - new JavaRDD(rdd.glom().map(x => new java.util.ArrayList[T](x.toSeq))) + new JavaRDD(rdd.glom().map(_.toSeq.asJava)) /** * Return the Cartesian product of this RDD and another one, that is, the RDD of all pairs of @@ -266,13 +269,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return an RDD created by piping elements to a forked external process. */ def pipe(command: JList[String]): JavaRDD[String] = - rdd.pipe(asScalaBuffer(command)) + rdd.pipe(command.asScala) /** * Return an RDD created by piping elements to a forked external process. */ def pipe(command: JList[String], env: java.util.Map[String, String]): JavaRDD[String] = - rdd.pipe(asScalaBuffer(command), mapAsScalaMap(env)) + rdd.pipe(command.asScala, env.asScala) /** * Zips this RDD with another one, returning key-value pairs with the first element in each RDD, @@ -294,8 +297,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { other: JavaRDDLike[U, _], f: FlatMapFunction2[java.util.Iterator[T], java.util.Iterator[U], V]): JavaRDD[V] = { def fn: (Iterator[T], Iterator[U]) => Iterator[V] = { - (x: Iterator[T], y: Iterator[U]) => asScalaIterator( - f.call(asJavaIterator(x), asJavaIterator(y)).iterator()) + (x: Iterator[T], y: Iterator[U]) => f.call(x.asJava, y.asJava).iterator().asScala } JavaRDD.fromRDD( rdd.zipPartitions(other.rdd)(fn)(other.classTag, fakeClassTag[V]))(fakeClassTag[V]) @@ -333,28 +335,22 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Return an array that contains all of the elements in this RDD. */ - def collect(): JList[T] = { - import scala.collection.JavaConversions._ - val arr: java.util.Collection[T] = rdd.collect().toSeq - new java.util.ArrayList(arr) - } + def collect(): JList[T] = + rdd.collect().toSeq.asJava /** * Return an iterator that contains all of the elements in this RDD. * * The iterator will consume as much memory as the largest partition in this RDD. */ - def toLocalIterator(): JIterator[T] = { - import scala.collection.JavaConversions._ - rdd.toLocalIterator - } - + def toLocalIterator(): JIterator[T] = + asJavaIteratorConverter(rdd.toLocalIterator).asJava /** * Return an array that contains all of the elements in this RDD. * @deprecated As of Spark 1.0.0, toArray() is deprecated, use {@link #collect()} instead */ - @Deprecated + @deprecated("use collect()", "1.0.0") def toArray(): JList[T] = collect() /** @@ -363,9 +359,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def collectPartitions(partitionIds: Array[Int]): Array[JList[T]] = { // This is useful for implementing `take` from other language frontends // like Python where the data is serialized. 
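As the doc comment above notes, collect() materializes the entire RDD on the driver while toLocalIterator() needs only as much memory as the largest partition, at the cost of running one job per partition as the iterator is consumed. A brief sketch of the lazy variant from the Java API:

import org.apache.spark.api.java.JavaRDD

// Sum driver-side without holding the whole RDD in driver memory at once.
def sumLazily(rdd: JavaRDD[java.lang.Long]): Long = {
  var total = 0L
  val it = rdd.toLocalIterator()
  while (it.hasNext) {
    total += it.next().longValue()
  }
  total
}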
- import scala.collection.JavaConversions._ val res = context.runJob(rdd, (it: Iterator[T]) => it.toArray, partitionIds) - res.map(x => new java.util.ArrayList(x.toSeq)).toArray + res.map(_.toSeq.asJava) } /** @@ -445,20 +440,16 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def count(): Long = rdd.count() /** - * :: Experimental :: * Approximate version of count() that returns a potentially incomplete result * within a timeout, even if not all tasks have finished. */ - @Experimental def countApprox(timeout: Long, confidence: Double): PartialResult[BoundedDouble] = rdd.countApprox(timeout, confidence) /** - * :: Experimental :: * Approximate version of count() that returns a potentially incomplete result * within a timeout, even if not all tasks have finished. */ - @Experimental def countApprox(timeout: Long): PartialResult[BoundedDouble] = rdd.countApprox(timeout) @@ -489,20 +480,14 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * it will be slow if a lot of partitions are required. In that case, use collect() to get the * whole RDD instead. */ - def take(num: Int): JList[T] = { - import scala.collection.JavaConversions._ - val arr: java.util.Collection[T] = rdd.take(num).toSeq - new java.util.ArrayList(arr) - } + def take(num: Int): JList[T] = + rdd.take(num).toSeq.asJava def takeSample(withReplacement: Boolean, num: Int): JList[T] = takeSample(withReplacement, num, Utils.random.nextLong) - def takeSample(withReplacement: Boolean, num: Int, seed: Long): JList[T] = { - import scala.collection.JavaConversions._ - val arr: java.util.Collection[T] = rdd.takeSample(withReplacement, num, seed).toSeq - new java.util.ArrayList(arr) - } + def takeSample(withReplacement: Boolean, num: Int, seed: Long): JList[T] = + rdd.takeSample(withReplacement, num, seed).toSeq.asJava /** * Return the first element in this RDD. @@ -576,21 +561,18 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Returns the top k (largest) elements from this RDD as defined by - * the specified Comparator[T]. + * the specified Comparator[T] and maintains the order. * @param num k, the number of top elements to return * @param comp the comparator that defines the order * @return an array of top elements */ def top(num: Int, comp: Comparator[T]): JList[T] = { - import scala.collection.JavaConversions._ - val topElems = rdd.top(num)(Ordering.comparatorToOrdering(comp)) - val arr: java.util.Collection[T] = topElems.toSeq - new java.util.ArrayList(arr) + rdd.top(num)(Ordering.comparatorToOrdering(comp)).toSeq.asJava } /** * Returns the top k (largest) elements from this RDD using the - * natural ordering for T. + * natural ordering for T and maintains the order. * @param num k, the number of top elements to return * @return an array of top elements */ @@ -607,10 +589,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * @return an array of top elements */ def takeOrdered(num: Int, comp: Comparator[T]): JList[T] = { - import scala.collection.JavaConversions._ - val topElems = rdd.takeOrdered(num)(Ordering.comparatorToOrdering(comp)) - val arr: java.util.Collection[T] = topElems.toSeq - new java.util.ArrayList(arr) + rdd.takeOrdered(num)(Ordering.comparatorToOrdering(comp)).toSeq.asJava } /** @@ -696,7 +675,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * applies a function f to each partition of this RDD. 
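top() and takeOrdered() differ only in the direction of the ordering, and with an explicit Comparator both go through Ordering.comparatorToOrdering as shown above. A small sketch against the Java-facing API (helper name invented; the comparator is made Serializable because it is shipped inside the task closure):

import java.util.Comparator
import org.apache.spark.api.java.JavaRDD

// takeOrdered(k): k smallest under the comparator; top(k): k largest.
def shortestAndLongest(
    words: JavaRDD[String],
    k: Int): (java.util.List[String], java.util.List[String]) = {
  val byLength = new Comparator[String] with java.io.Serializable {
    override def compare(a: String, b: String): Int = Integer.compare(a.length, b.length)
  }
  (words.takeOrdered(k, byLength), words.top(k, byLength))
}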
*/ def foreachPartitionAsync(f: VoidFunction[java.util.Iterator[T]]): JavaFutureAction[Void] = { - new JavaFutureActionWrapper[Unit, Void](rdd.foreachPartitionAsync(x => f.call(x)), + new JavaFutureActionWrapper[Unit, Void](rdd.foreachPartitionAsync(x => f.call(x.asJava)), { x => null.asInstanceOf[Void] }) } } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 02e49a853c5f..4f54cd69e217 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -21,8 +21,7 @@ import java.io.Closeable import java.util import java.util.{Map => JMap} -import scala.collection.JavaConversions -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.ClassTag @@ -34,7 +33,6 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.spark._ import org.apache.spark.AccumulatorParam._ -import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, NewHadoopRDD, RDD} @@ -104,7 +102,7 @@ class JavaSparkContext(val sc: SparkContext) */ def this(master: String, appName: String, sparkHome: String, jars: Array[String], environment: JMap[String, String]) = - this(new SparkContext(master, appName, sparkHome, jars.toSeq, environment, Map())) + this(new SparkContext(master, appName, sparkHome, jars.toSeq, environment.asScala, Map())) private[spark] val env = sc.env @@ -118,7 +116,7 @@ class JavaSparkContext(val sc: SparkContext) def appName: String = sc.appName - def jars: util.List[String] = sc.jars + def jars: util.List[String] = sc.jars.asJava def startTime: java.lang.Long = sc.startTime @@ -142,7 +140,7 @@ class JavaSparkContext(val sc: SparkContext) /** Distribute a local Scala collection to form an RDD. */ def parallelize[T](list: java.util.List[T], numSlices: Int): JavaRDD[T] = { implicit val ctag: ClassTag[T] = fakeClassTag - sc.parallelize(JavaConversions.asScalaBuffer(list), numSlices) + sc.parallelize(list.asScala, numSlices) } /** Get an RDD that has no partitions or elements. */ @@ -161,7 +159,7 @@ class JavaSparkContext(val sc: SparkContext) : JavaPairRDD[K, V] = { implicit val ctagK: ClassTag[K] = fakeClassTag implicit val ctagV: ClassTag[V] = fakeClassTag - JavaPairRDD.fromRDD(sc.parallelize(JavaConversions.asScalaBuffer(list), numSlices)) + JavaPairRDD.fromRDD(sc.parallelize(list.asScala, numSlices)) } /** Distribute a local Scala collection to form an RDD. */ @@ -170,8 +168,7 @@ class JavaSparkContext(val sc: SparkContext) /** Distribute a local Scala collection to form an RDD. */ def parallelizeDoubles(list: java.util.List[java.lang.Double], numSlices: Int): JavaDoubleRDD = - JavaDoubleRDD.fromRDD(sc.parallelize(JavaConversions.asScalaBuffer(list).map(_.doubleValue()), - numSlices)) + JavaDoubleRDD.fromRDD(sc.parallelize(list.asScala.map(_.doubleValue()), numSlices)) /** Distribute a local Scala collection to form an RDD. 
*/ def parallelizeDoubles(list: java.util.List[java.lang.Double]): JavaDoubleRDD = @@ -268,8 +265,6 @@ class JavaSparkContext(val sc: SparkContext) new JavaPairRDD(sc.binaryFiles(path, minPartitions)) /** - * :: Experimental :: - * * Read a directory of binary files from HDFS, a local file system (available on all nodes), * or any Hadoop-supported file system URI as a byte array. Each file is read as a single * record and returned in a key-value pair, where the key is the path of each file, @@ -296,19 +291,15 @@ class JavaSparkContext(val sc: SparkContext) * * @note Small files are preferred; very large files but may cause bad performance. */ - @Experimental def binaryFiles(path: String): JavaPairRDD[String, PortableDataStream] = new JavaPairRDD(sc.binaryFiles(path, defaultMinPartitions)) /** - * :: Experimental :: - * * Load data from a flat binary file, assuming the length of each record is constant. * * @param path Directory to the input data files * @return An RDD of data with values, represented as byte arrays */ - @Experimental def binaryRecords(path: String, recordLength: Int): JavaRDD[Array[Byte]] = { new JavaRDD(sc.binaryRecords(path, recordLength)) } @@ -519,7 +510,7 @@ class JavaSparkContext(val sc: SparkContext) /** Build the union of two or more RDDs. */ override def union[T](first: JavaRDD[T], rest: java.util.List[JavaRDD[T]]): JavaRDD[T] = { - val rdds: Seq[RDD[T]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.rdd) + val rdds: Seq[RDD[T]] = (Seq(first) ++ rest.asScala).map(_.rdd) implicit val ctag: ClassTag[T] = first.classTag sc.union(rdds) } @@ -527,7 +518,7 @@ class JavaSparkContext(val sc: SparkContext) /** Build the union of two or more RDDs. */ override def union[K, V](first: JavaPairRDD[K, V], rest: java.util.List[JavaPairRDD[K, V]]) : JavaPairRDD[K, V] = { - val rdds: Seq[RDD[(K, V)]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.rdd) + val rdds: Seq[RDD[(K, V)]] = (Seq(first) ++ rest.asScala).map(_.rdd) implicit val ctag: ClassTag[(K, V)] = first.classTag implicit val ctagK: ClassTag[K] = first.kClassTag implicit val ctagV: ClassTag[V] = first.vClassTag @@ -536,7 +527,7 @@ class JavaSparkContext(val sc: SparkContext) /** Build the union of two or more RDDs. 
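binaryRecords assumes every record in the input files has exactly the same byte length. A hedged sketch that reads hypothetical fixed-width 12-byte records (an int id followed by a big-endian double); the layout and helper name are illustrative only:

import java.nio.{ByteBuffer, ByteOrder}
import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.rdd.RDD

// Parse each 12-byte record into (id, value); drops down to the Scala RDD for map().
def loadFixedWidth(jsc: JavaSparkContext, path: String): RDD[(Int, Double)] = {
  jsc.binaryRecords(path, 12).rdd.map { bytes =>
    val buf = ByteBuffer.wrap(bytes).order(ByteOrder.BIG_ENDIAN)
    (buf.getInt(), buf.getDouble())
  }
}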
*/ override def union(first: JavaDoubleRDD, rest: java.util.List[JavaDoubleRDD]): JavaDoubleRDD = { - val rdds: Seq[RDD[Double]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.srdd) + val rdds: Seq[RDD[Double]] = (Seq(first) ++ rest.asScala).map(_.srdd) new JavaDoubleRDD(sc.union(rdds)) } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala index b959b683d167..d2beef2a0dd4 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala @@ -17,22 +17,21 @@ package org.apache.spark.api.python -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.rdd.RDD -import org.apache.spark.util.{SerializableConfiguration, Utils} -import org.apache.spark.{Logging, SparkException} +import scala.collection.JavaConverters._ +import scala.util.{Failure, Success, Try} + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io._ -import scala.util.{Failure, Success, Try} -import org.apache.spark.annotation.Experimental +import org.apache.spark.{Logging, SparkException} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.util.{SerializableConfiguration, Utils} /** - * :: Experimental :: * A trait for use with reading custom classes in PySpark. Implement this trait and add custom * transformation code by overriding the convert method. */ -@Experimental trait Converter[T, + U] extends Serializable { def convert(obj: T): U } @@ -68,7 +67,6 @@ private[python] class WritableToJavaConverter( * object representation */ private def convertWritable(writable: Writable): Any = { - import collection.JavaConversions._ writable match { case iw: IntWritable => iw.get() case dw: DoubleWritable => dw.get() @@ -89,9 +87,7 @@ private[python] class WritableToJavaConverter( aw.get().map(convertWritable(_)) case mw: MapWritable => val map = new java.util.HashMap[Any, Any]() - mw.foreach { case (k, v) => - map.put(convertWritable(k), convertWritable(v)) - } + mw.asScala.foreach { case (k, v) => map.put(convertWritable(k), convertWritable(v)) } map case w: Writable => WritableUtils.clone(w, conf.value.value) case other => other @@ -122,7 +118,6 @@ private[python] class JavaToWritableConverter extends Converter[Any, Writable] { * supported out-of-the-box. 
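The Converter trait above (now stable rather than @Experimental) is the extension point PySpark users implement to control how Writables are converted. A hedged sketch of a user-side implementation; the class and the key format are invented for illustration:

import org.apache.hadoop.io.Text
import org.apache.spark.api.python.Converter

// Hypothetical converter: reduce a Text key like "2015-10-07T12:34:56" to just
// the date portion so the Python side sees plain strings.
class DatePrefixConverter extends Converter[Any, String] {
  override def convert(obj: Any): String = obj match {
    case t: Text => t.toString.takeWhile(_ != 'T')
    case other => String.valueOf(other)
  }
}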
*/ private def convertToWritable(obj: Any): Writable = { - import collection.JavaConversions._ obj match { case i: java.lang.Integer => new IntWritable(i) case d: java.lang.Double => new DoubleWritable(d) @@ -134,7 +129,7 @@ private[python] class JavaToWritableConverter extends Converter[Any, Writable] { case null => NullWritable.get() case map: java.util.Map[_, _] => val mapWritable = new MapWritable() - map.foreach { case (k, v) => + map.asScala.foreach { case (k, v) => mapWritable.put(convertToWritable(k), convertToWritable(v)) } mapWritable @@ -161,9 +156,8 @@ private[python] object PythonHadoopUtil { * Convert a [[java.util.Map]] of properties to a [[org.apache.hadoop.conf.Configuration]] */ def mapToConf(map: java.util.Map[String, String]): Configuration = { - import collection.JavaConversions._ val conf = new Configuration() - map.foreach{ case (k, v) => conf.set(k, v) } + map.asScala.foreach { case (k, v) => conf.set(k, v) } conf } @@ -172,9 +166,8 @@ private[python] object PythonHadoopUtil { * any matching keys in left */ def mergeConfs(left: Configuration, right: Configuration): Configuration = { - import collection.JavaConversions._ val copy = new Configuration(left) - right.iterator().foreach(entry => copy.set(entry.getKey, entry.getValue)) + right.asScala.foreach(entry => copy.set(entry.getKey, entry.getValue)) copy } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 55e563ee968b..8464b578ed09 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -21,9 +21,10 @@ import java.io._ import java.net._ import java.util.{Collections, ArrayList => JArrayList, List => JList, Map => JMap} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.language.existentials +import scala.util.control.NonFatal import com.google.common.base.Charsets.UTF_8 import org.apache.hadoop.conf.Configuration @@ -38,10 +39,9 @@ import org.apache.spark.input.PortableDataStream import org.apache.spark.rdd.RDD import org.apache.spark.util.{SerializableConfiguration, Utils} -import scala.util.control.NonFatal private[spark] class PythonRDD( - @transient parent: RDD[_], + parent: RDD[_], command: Array[Byte], envVars: JMap[String, String], pythonIncludes: JList[String], @@ -61,21 +61,49 @@ private[spark] class PythonRDD( if (preservePartitoning) firstParent.partitioner else None } + val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) + override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { + val runner = new PythonRunner( + command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars, accumulator, + bufferSize, reuse_worker) + runner.compute(firstParent.iterator(split, context), split.index, context) + } +} + + +/** + * A helper class to run Python UDFs in Spark. 
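mergeConfs copies the left Configuration and then overlays every entry from the right one, so right-hand keys win. A standalone illustration of that precedence (the helper is invented and simply mirrors the behaviour, since PythonHadoopUtil itself is private[python]):

import org.apache.hadoop.conf.Configuration

// Build a Configuration from a base map, then overlay overrides on a copy;
// keys present in both maps end up with the override's value.
def overlay(base: Map[String, String], overrides: Map[String, String]): Configuration = {
  val conf = new Configuration()
  base.foreach { case (k, v) => conf.set(k, v) }
  val merged = new Configuration(conf)
  overrides.foreach { case (k, v) => merged.set(k, v) }
  merged
}

// overlay(Map("io.file.buffer.size" -> "4096"), Map("io.file.buffer.size" -> "65536"))
//   .get("io.file.buffer.size") == "65536"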
+ */ +private[spark] class PythonRunner( + command: Array[Byte], + envVars: JMap[String, String], + pythonIncludes: JList[String], + pythonExec: String, + pythonVer: String, + broadcastVars: JList[Broadcast[PythonBroadcast]], + accumulator: Accumulator[JList[Array[Byte]]], + bufferSize: Int, + reuse_worker: Boolean) + extends Logging { + + def compute( + inputIterator: Iterator[_], + partitionIndex: Int, + context: TaskContext): Iterator[Array[Byte]] = { val startTime = System.currentTimeMillis val env = SparkEnv.get - val localdir = env.blockManager.diskBlockManager.localDirs.map( - f => f.getPath()).mkString(",") - envVars += ("SPARK_LOCAL_DIRS" -> localdir) // it's also used in monitor thread + val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",") + envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread if (reuse_worker) { - envVars += ("SPARK_REUSE_WORKER" -> "1") + envVars.put("SPARK_REUSE_WORKER", "1") } - val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap) + val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap) // Whether is the worker released into idle pool @volatile var released = false // Start a thread to feed the process input from our parent's iterator - val writerThread = new WriterThread(env, worker, split, context) + val writerThread = new WriterThread(env, worker, inputIterator, partitionIndex, context) context.addTaskCompletionListener { context => writerThread.shutdownOnTaskCompletion() @@ -150,7 +178,7 @@ private[spark] class PythonRDD( // Check whether the worker is ready to be re-used. if (stream.readInt() == SpecialLengths.END_OF_STREAM) { if (reuse_worker) { - env.releasePythonWorker(pythonExec, envVars.toMap, worker) + env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker) released = true } } @@ -183,13 +211,16 @@ private[spark] class PythonRDD( new InterruptibleIterator(context, stdoutIterator) } - val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) - /** * The thread responsible for writing the data from the PythonRDD's parent iterator to the * Python process. 
*/ - class WriterThread(env: SparkEnv, worker: Socket, split: Partition, context: TaskContext) + class WriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[_], + partitionIndex: Int, + context: TaskContext) extends Thread(s"stdout writer for $pythonExec") { @volatile private var _exception: Exception = null @@ -211,19 +242,19 @@ private[spark] class PythonRDD( val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) val dataOut = new DataOutputStream(stream) // Partition index - dataOut.writeInt(split.index) + dataOut.writeInt(partitionIndex) // Python version of driver PythonRDD.writeUTF(pythonVer, dataOut) // sparkFilesDir - PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut) + PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut) // Python includes (*.zip and *.egg files) - dataOut.writeInt(pythonIncludes.length) - for (include <- pythonIncludes) { + dataOut.writeInt(pythonIncludes.size()) + for (include <- pythonIncludes.asScala) { PythonRDD.writeUTF(include, dataOut) } // Broadcast variables val oldBids = PythonRDD.getWorkerBroadcasts(worker) - val newBids = broadcastVars.map(_.id).toSet + val newBids = broadcastVars.asScala.map(_.id).toSet // number of different broadcasts val toRemove = oldBids.diff(newBids) val cnt = toRemove.size + newBids.diff(oldBids).size @@ -233,7 +264,7 @@ private[spark] class PythonRDD( dataOut.writeLong(- bid - 1) // bid >= 0 oldBids.remove(bid) } - for (broadcast <- broadcastVars) { + for (broadcast <- broadcastVars.asScala) { if (!oldBids.contains(broadcast.id)) { // send new broadcast dataOut.writeLong(broadcast.id) @@ -246,7 +277,7 @@ private[spark] class PythonRDD( dataOut.writeInt(command.length) dataOut.write(command) // Data values - PythonRDD.writeIteratorToStream(firstParent.iterator(split, context), dataOut) + PythonRDD.writeIteratorToStream(inputIterator, dataOut) dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) dataOut.writeInt(SpecialLengths.END_OF_STREAM) dataOut.flush() @@ -287,7 +318,7 @@ private[spark] class PythonRDD( if (!context.isCompleted) { try { logWarning("Incomplete task interrupted: Attempting to kill Python Worker") - env.destroyPythonWorker(pythonExec, envVars.toMap, worker) + env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker) } catch { case e: Exception => logError("Exception when trying to kill worker", e) @@ -327,7 +358,8 @@ private[spark] object PythonRDD extends Logging { // remember the broadcasts sent to each worker private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]() - private def getWorkerBroadcasts(worker: Socket) = { + + def getWorkerBroadcasts(worker: Socket): mutable.Set[Long] = { synchronized { workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]()) } @@ -358,10 +390,10 @@ private[spark] object PythonRDD extends Logging { type ByteArray = Array[Byte] type UnrolledPartition = Array[ByteArray] val allPartitions: Array[UnrolledPartition] = - sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions) + sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions.asScala) val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*) serveIterator(flattenedPartition.iterator, - s"serve RDD ${rdd.id} with partitions ${partitions.mkString(",")}") + s"serve RDD ${rdd.id} with partitions ${partitions.asScala.mkString(",")}") } /** @@ -601,7 +633,7 @@ private[spark] object PythonRDD extends Logging { * * The thread will terminate after all the data are sent or any exceptions happen. 
*/ - private def serveIterator[T](items: Iterator[T], threadName: String): Int = { + def serveIterator[T](items: Iterator[T], threadName: String): Int = { val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) // Close the socket if no connection in 3 seconds serverSocket.setSoTimeout(3000) @@ -785,7 +817,7 @@ class BytesToString extends org.apache.spark.api.java.function.Function[Array[By * Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it * collects a list of pickled strings that we pass to Python through a socket. */ -private class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int) +private class PythonAccumulatorParam(@transient private val serverHost: String, serverPort: Int) extends AccumulatorParam[JList[Array[Byte]]] { Utils.checkHost(serverHost, "Expected hostname") @@ -794,7 +826,7 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort: /** * We try to reuse a single Socket to transfer accumulator updates, as they are all added - * by the DAGScheduler's single-threaded actor anyway. + * by the DAGScheduler's single-threaded RpcEndpoint anyway. */ @transient var socket: Socket = _ @@ -819,7 +851,7 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort: val in = socket.getInputStream val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize)) out.writeInt(val2.size) - for (array <- val2) { + for (array <- val2.asScala) { out.writeInt(array.length) out.write(array) } @@ -839,7 +871,8 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort: * write the data into disk after deserialization, then Python can read it from disks. */ // scalastyle:off no.finalize -private[spark] class PythonBroadcast(@transient var path: String) extends Serializable { +private[spark] class PythonBroadcast(@transient var path: String) extends Serializable + with Logging { /** * Read data from disks, then copy it to `out` @@ -875,7 +908,9 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial if (!path.isEmpty) { val file = new File(path) if (file.exists()) { - file.delete() + if (!file.delete()) { + logWarning(s"Error deleting ${file.getPath}") + } } } } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index 90dacaeb9342..292ac4cfc35b 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -17,10 +17,10 @@ package org.apache.spark.api.python -import java.io.{File} +import java.io.File import java.util.{List => JList} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkContext @@ -32,7 +32,7 @@ private[spark] object PythonUtils { val pythonPath = new ArrayBuffer[String] for (sparkHome <- sys.env.get("SPARK_HOME")) { pythonPath += Seq(sparkHome, "python", "lib", "pyspark.zip").mkString(File.separator) - pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.8.2.1-src.zip").mkString(File.separator) + pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.9-src.zip").mkString(File.separator) } pythonPath ++= SparkContext.jarOfObject(this) pythonPath.mkString(File.pathSeparator) @@ -51,7 +51,14 @@ private[spark] object PythonUtils { * Convert list of T into seq of T (for calling API with varargs) */ 
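serveIterator (now exposed beyond PythonRDD) binds an ephemeral local port, waits up to three seconds for a single client, and streams the items from a background thread. A much-simplified sketch of that pattern, without Spark's framing, serialization, or error handling, and with newline-free strings assumed:

import java.io.{BufferedOutputStream, DataOutputStream}
import java.net.{InetAddress, ServerSocket}

// Serve strings to one local client on a daemon thread and return the port.
def serveStrings(items: Iterator[String], threadName: String): Int = {
  val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
  serverSocket.setSoTimeout(3000)   // give up if no client connects in 3 seconds
  new Thread(threadName) {
    setDaemon(true)
    override def run(): Unit = {
      try {
        val sock = serverSocket.accept()
        val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
        try {
          items.foreach(out.writeUTF)
          out.flush()
        } finally {
          sock.close()
        }
      } finally {
        serverSocket.close()
      }
    }
  }.start()
  serverSocket.getLocalPort
}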
def toSeq[T](vs: JList[T]): Seq[T] = { - vs.toList.toSeq + vs.asScala + } + + /** + * Convert list of T into a (Scala) List of T + */ + def toList[T](vs: JList[T]): List[T] = { + vs.asScala.toList } /** @@ -65,6 +72,6 @@ private[spark] object PythonUtils { * Convert java map of K, V into Map of K, V (for calling API with varargs) */ def toScalaMap[K, V](jm: java.util.Map[K, V]): Map[K, V] = { - jm.toMap + jm.asScala.toMap } } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index e314408c067e..7039b734d2e4 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -19,9 +19,10 @@ package org.apache.spark.api.python import java.io.{DataOutputStream, DataInputStream, InputStream, OutputStreamWriter} import java.net.{InetAddress, ServerSocket, Socket, SocketException} +import java.util.Arrays import scala.collection.mutable -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark._ import org.apache.spark.util.{RedirectThread, Utils} @@ -108,9 +109,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1))) // Create and start the worker - val pb = new ProcessBuilder(Seq(pythonExec, "-m", "pyspark.worker")) + val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", "pyspark.worker")) val workerEnv = pb.environment() - workerEnv.putAll(envVars) + workerEnv.putAll(envVars.asJava) workerEnv.put("PYTHONPATH", pythonPath) // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: workerEnv.put("PYTHONUNBUFFERED", "YES") @@ -151,9 +152,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String try { // Create and start the daemon - val pb = new ProcessBuilder(Seq(pythonExec, "-m", "pyspark.daemon")) + val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", "pyspark.daemon")) val workerEnv = pb.environment() - workerEnv.putAll(envVars) + workerEnv.putAll(envVars.asJava) workerEnv.put("PYTHONPATH", pythonPath) // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: workerEnv.put("PYTHONUNBUFFERED", "YES") diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala index 1f1debcf84ad..fd27276e70bf 100644 --- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala @@ -22,7 +22,6 @@ import java.util.{ArrayList => JArrayList} import org.apache.spark.api.java.JavaRDD -import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.Failure @@ -214,7 +213,7 @@ private[spark] object SerDeUtil extends Logging { new AutoBatchedPickler(cleaned) } else { val pickle = new Pickler - cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched))) + cleaned.grouped(batchSize).map(batched => pickle.dumps(batched.asJava)) } } } diff --git a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala index 8f30ff9202c8..ee1fb056f0d9 100644 --- 
a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala +++ b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala @@ -20,6 +20,8 @@ package org.apache.spark.api.python import java.io.{DataOutput, DataInput} import java.{util => ju} +import scala.collection.JavaConverters._ + import com.google.common.base.Charsets.UTF_8 import org.apache.hadoop.io._ @@ -62,10 +64,9 @@ private[python] class TestInputKeyConverter extends Converter[Any, Any] { } private[python] class TestInputValueConverter extends Converter[Any, Any] { - import collection.JavaConversions._ override def convert(obj: Any): ju.List[Double] = { val m = obj.asInstanceOf[MapWritable] - seqAsJavaList(m.keySet.map(w => w.asInstanceOf[DoubleWritable].get()).toSeq) + m.keySet.asScala.map(_.asInstanceOf[DoubleWritable].get()).toSeq.asJava } } @@ -76,9 +77,8 @@ private[python] class TestOutputKeyConverter extends Converter[Any, Any] { } private[python] class TestOutputValueConverter extends Converter[Any, Any] { - import collection.JavaConversions._ override def convert(obj: Any): DoubleWritable = { - new DoubleWritable(obj.asInstanceOf[java.util.Map[Double, _]].keySet().head) + new DoubleWritable(obj.asInstanceOf[java.util.Map[Double, _]].keySet().iterator().next()) } } diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala index b7e72d4d0ed0..8b3be0da2c8c 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -113,6 +113,7 @@ private[spark] object RBackend extends Logging { val dos = new DataOutputStream(new FileOutputStream(f)) dos.writeInt(boundPort) dos.writeInt(listenPort) + SerDe.writeString(dos, RUtils.rPackages.getOrElse("")) dos.close() f.renameTo(new File(path)) diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index 14dac4ed28ce..0095548c463c 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -53,6 +53,13 @@ private[r] class RBackendHandler(server: RBackend) if (objId == "SparkRHandler") { methodName match { + // This function is for test-purpose only + case "echo" => + val args = readArgs(numArgs, dis) + assert(numArgs == 1) + + writeInt(dos, 0) + writeObject(dos, args(0)) case "stopBackend" => writeInt(dos, 0) writeType(dos, "void") @@ -118,10 +125,11 @@ private[r] class RBackendHandler(server: RBackend) val methods = cls.getMethods val selectedMethods = methods.filter(m => m.getName == methodName) if (selectedMethods.length > 0) { - val methods = selectedMethods.filter { x => - matchMethod(numArgs, args, x.getParameterTypes) - } - if (methods.isEmpty) { + val index = findMatchedSignature( + selectedMethods.map(_.getParameterTypes), + args) + + if (index.isEmpty) { logWarning(s"cannot find matching method ${cls}.$methodName. 
" + s"Candidates are:") selectedMethods.foreach { method => @@ -129,18 +137,29 @@ private[r] class RBackendHandler(server: RBackend) } throw new Exception(s"No matched method found for $cls.$methodName") } - val ret = methods.head.invoke(obj, args : _*) + + val ret = selectedMethods(index.get).invoke(obj, args : _*) // Write status bit writeInt(dos, 0) writeObject(dos, ret.asInstanceOf[AnyRef]) } else if (methodName == "") { // methodName should be "" for constructor - val ctor = cls.getConstructors.filter { x => - matchMethod(numArgs, args, x.getParameterTypes) - }.head + val ctors = cls.getConstructors + val index = findMatchedSignature( + ctors.map(_.getParameterTypes), + args) + + if (index.isEmpty) { + logWarning(s"cannot find matching constructor for ${cls}. " + + s"Candidates are:") + ctors.foreach { ctor => + logWarning(s"$cls(${ctor.getParameterTypes.mkString(",")})") + } + throw new Exception(s"No matched constructor found for $cls") + } - val obj = ctor.newInstance(args : _*) + val obj = ctors(index.get).newInstance(args : _*) writeInt(dos, 0) writeObject(dos, obj.asInstanceOf[AnyRef]) @@ -159,39 +178,80 @@ private[r] class RBackendHandler(server: RBackend) // Read a number of arguments from the data input stream def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = { - (0 until numArgs).map { arg => + (0 until numArgs).map { _ => readObject(dis) }.toArray } - // Checks if the arguments passed in args matches the parameter types. - // NOTE: Currently we do exact match. We may add type conversions later. - def matchMethod( - numArgs: Int, - args: Array[java.lang.Object], - parameterTypes: Array[Class[_]]): Boolean = { - if (parameterTypes.length != numArgs) { - return false - } + // Find a matching method signature in an array of signatures of constructors + // or methods of the same name according to the passed arguments. Arguments + // may be converted in order to match a signature. + // + // Note that in Java reflection, constructors and normal methods are of different + // classes, and share no parent class that provides methods for reflection uses. + // There is no unified way to handle them in this function. So an array of signatures + // is passed in instead of an array of candidate constructors or methods. + // + // Returns an Option[Int] which is the index of the matched signature in the array. + def findMatchedSignature( + parameterTypesOfMethods: Array[Array[Class[_]]], + args: Array[Object]): Option[Int] = { + val numArgs = args.length + + for (index <- 0 until parameterTypesOfMethods.length) { + val parameterTypes = parameterTypesOfMethods(index) - for (i <- 0 to numArgs - 1) { - val parameterType = parameterTypes(i) - var parameterWrapperType = parameterType - - // Convert native parameters to Object types as args is Array[Object] here - if (parameterType.isPrimitive) { - parameterWrapperType = parameterType match { - case java.lang.Integer.TYPE => classOf[java.lang.Integer] - case java.lang.Double.TYPE => classOf[java.lang.Double] - case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] - case _ => parameterType + if (parameterTypes.length == numArgs) { + var argMatched = true + var i = 0 + while (i < numArgs && argMatched) { + val parameterType = parameterTypes(i) + + if (parameterType == classOf[Seq[Any]] && args(i).getClass.isArray) { + // The case that the parameter type is a Scala Seq and the argument + // is a Java array is considered matching. The array will be converted + // to a Seq later if this method is matched. 
+ } else { + var parameterWrapperType = parameterType + + // Convert native parameters to Object types as args is Array[Object] here + if (parameterType.isPrimitive) { + parameterWrapperType = parameterType match { + case java.lang.Integer.TYPE => classOf[java.lang.Integer] + case java.lang.Long.TYPE => classOf[java.lang.Integer] + case java.lang.Double.TYPE => classOf[java.lang.Double] + case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] + case _ => parameterType + } + } + if ((parameterType.isPrimitive || args(i) != null) && + !parameterWrapperType.isInstance(args(i))) { + argMatched = false + } + } + + i = i + 1 + } + + if (argMatched) { + // For now, we return the first matching method. + // TODO: find best method in matching methods. + + // Convert args if needed + val parameterTypes = parameterTypesOfMethods(index) + + (0 until numArgs).map { i => + if (parameterTypes(i) == classOf[Seq[Any]] && args(i).getClass.isArray) { + // Convert a Java array to scala Seq + args(i) = args(i).asInstanceOf[Array[_]].toSeq + } + } + + return Some(index) } - } - if (!parameterWrapperType.isInstance(args(i))) { - return false } } - true + None } } diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index 1cf2824f862e..7509b3d3f44b 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -19,9 +19,10 @@ package org.apache.spark.api.r import java.io._ import java.net.{InetAddress, ServerSocket} +import java.util.Arrays import java.util.{Map => JMap} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.io.Source import scala.reflect.ClassTag import scala.util.Try @@ -365,11 +366,11 @@ private[r] object RRDD { sparkConf.setIfMissing("spark.master", "local") } - for ((name, value) <- sparkEnvirMap) { - sparkConf.set(name.asInstanceOf[String], value.asInstanceOf[String]) + for ((name, value) <- sparkEnvirMap.asScala) { + sparkConf.set(name.toString, value.toString) } - for ((name, value) <- sparkExecutorEnvMap) { - sparkConf.setExecutorEnv(name.asInstanceOf[String], value.asInstanceOf[String]) + for ((name, value) <- sparkExecutorEnvMap.asScala) { + sparkConf.setExecutorEnv(name.toString, value.toString) } val jsc = new JavaSparkContext(sparkConf) @@ -391,17 +392,22 @@ private[r] object RRDD { } private def createRProcess(port: Int, script: String): BufferedStreamThread = { - val rCommand = SparkEnv.get.conf.get("spark.sparkr.r.command", "Rscript") + // "spark.sparkr.r.command" is deprecated and replaced by "spark.r.command", + // but kept here for backward compatibility. + val sparkConf = SparkEnv.get.conf + var rCommand = sparkConf.get("spark.sparkr.r.command", "Rscript") + rCommand = sparkConf.get("spark.r.command", rCommand) + val rOptions = "--vanilla" val rLibDir = RUtils.sparkRPackagePath(isDriver = false) - val rExecScript = rLibDir + "/SparkR/worker/" + script - val pb = new ProcessBuilder(List(rCommand, rOptions, rExecScript)) + val rExecScript = rLibDir(0) + "/SparkR/worker/" + script + val pb = new ProcessBuilder(Arrays.asList(rCommand, rOptions, rExecScript)) // Unset the R_TESTS environment variable for workers. 
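The backward-compatible lookup added to createRProcess above can be exercised on its own; a minimal sketch with hypothetical values, showing that the new key spark.r.command wins whenever it is set while the deprecated spark.sparkr.r.command only supplies the fallback:

import org.apache.spark.SparkConf

// Hypothetical configuration; neither value is a real default.
val conf = new SparkConf(loadDefaults = false)
  .set("spark.sparkr.r.command", "/opt/R/old/Rscript") // deprecated key
  .set("spark.r.command", "/opt/R/new/Rscript")        // preferred key

var rCommand = conf.get("spark.sparkr.r.command", "Rscript")
rCommand = conf.get("spark.r.command", rCommand)
// rCommand == "/opt/R/new/Rscript"; without the second key it would be the
// deprecated value, and with neither key it would fall back to "Rscript".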
// This is set by R CMD check as startup.Rs // (http://svn.r-project.org/R/trunk/src/library/tools/R/testing.R) // and confuses worker script which tries to load a non-existent file pb.environment().put("R_TESTS", "") - pb.environment().put("SPARKR_RLIBDIR", rLibDir) + pb.environment().put("SPARKR_RLIBDIR", rLibDir.mkString(",")) pb.environment().put("SPARKR_WORKER_PORT", port.toString) pb.redirectErrorStream(true) // redirect stderr into stdout val proc = pb.start() diff --git a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala index d53abd3408c5..16157414fd12 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala @@ -18,48 +18,82 @@ package org.apache.spark.api.r import java.io.File +import java.util.Arrays import org.apache.spark.{SparkEnv, SparkException} private[spark] object RUtils { + // Local path where R binary packages built from R source code contained in the spark + // packages specified with "--packages" or "--jars" command line option reside. + var rPackages: Option[String] = None + /** * Get the SparkR package path in the local spark distribution. */ def localSparkRPackagePath: Option[String] = { - val sparkHome = sys.env.get("SPARK_HOME") + val sparkHome = sys.env.get("SPARK_HOME").orElse(sys.props.get("spark.test.home")) sparkHome.map( Seq(_, "R", "lib").mkString(File.separator) ) } /** - * Get the SparkR package path in various deployment modes. + * Get the list of paths for R packages in various deployment modes, of which the first + * path is for the SparkR package itself. The second path is for R packages built as + * part of Spark Packages, if any exist. Spark Packages can be provided through the + * "--packages" or "--jars" command line options. + * * This assumes that Spark properties `spark.master` and `spark.submit.deployMode` * and environment variable `SPARK_HOME` are set. */ - def sparkRPackagePath(isDriver: Boolean): String = { + def sparkRPackagePath(isDriver: Boolean): Seq[String] = { val (master, deployMode) = if (isDriver) { (sys.props("spark.master"), sys.props("spark.submit.deployMode")) } else { val sparkConf = SparkEnv.get.conf - (sparkConf.get("spark.master"), sparkConf.get("spark.submit.deployMode")) + (sparkConf.get("spark.master"), sparkConf.get("spark.submit.deployMode", "client")) } - val isYarnCluster = master.contains("yarn") && deployMode == "cluster" - val isYarnClient = master.contains("yarn") && deployMode == "client" + val isYarnCluster = master != null && master.contains("yarn") && deployMode == "cluster" + val isYarnClient = master != null && master.contains("yarn") && deployMode == "client" // In YARN mode, the SparkR package is distributed as an archive symbolically - // linked to the "sparkr" file in the current directory. Note that this does not apply - // to the driver in client mode because it is run outside of the cluster. + // linked to the "sparkr" file in the current directory and additional R packages + // are distributed as an archive symbolically linked to the "rpkg" file in the + // current directory. + // + // Note that this does not apply to the driver in client mode because it is run + // outside of the cluster. 
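As a quick illustration of the new contract (paths are hypothetical): the first entry returned by sparkRPackagePath is always the SparkR installation itself, an optional second entry points at R packages built from Spark Packages, and callers such as createRProcess export the whole list as one comma-separated SPARKR_RLIBDIR value.

// Hypothetical paths; on YARN executors these would be "./sparkr" and "./rpkg".
val rLibDirs: Seq[String] = Seq("/opt/spark/R/lib", "/tmp/spark-built-rpkgs")
val sparkRHome = rLibDirs(0)               // always the SparkR package location
val extraPkgs  = rLibDirs.drop(1)          // present only when Spark Packages ship R code
val exported   = rLibDirs.mkString(",")    // value placed into SPARKR_RLIBDIR
val recovered  = exported.split(",").toSeq // what a consumer gets back, assuming no
                                           // path itself contains a comma
assert(recovered == rLibDirs)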
if (isYarnCluster || (isYarnClient && !isDriver)) { - new File("sparkr").getAbsolutePath + val sparkRPkgPath = new File("sparkr").getAbsolutePath + val rPkgPath = new File("rpkg") + if (rPkgPath.exists()) { + Seq(sparkRPkgPath, rPkgPath.getAbsolutePath) + } else { + Seq(sparkRPkgPath) + } } else { // Otherwise, assume the package is local // TODO: support this for Mesos - localSparkRPackagePath.getOrElse { - throw new SparkException("SPARK_HOME not set. Can't locate SparkR package.") + val sparkRPkgPath = localSparkRPackagePath.getOrElse { + throw new SparkException("SPARK_HOME not set. Can't locate SparkR package.") + } + if (!rPackages.isEmpty) { + Seq(sparkRPkgPath, rPackages.get) + } else { + Seq(sparkRPkgPath) } } } + + /** Check if R is installed before running tests that use R commands. */ + def isRInstalled: Boolean = { + try { + val builder = new ProcessBuilder(Arrays.asList("R", "--version")) + builder.start().waitFor() == 0 + } catch { + case e: Exception => false + } + } } diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index d5b4260bf452..da126bac7ad1 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -20,12 +20,21 @@ package org.apache.spark.api.r import java.io.{DataInputStream, DataOutputStream} import java.sql.{Timestamp, Date, Time} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ +import scala.collection.mutable.WrappedArray /** * Utility functions to serialize, deserialize objects to / from R */ private[spark] object SerDe { + type ReadObject = (DataInputStream, Char) => Object + type WriteObject = (DataOutputStream, Object) => Boolean + + var sqlSerDe: (ReadObject, WriteObject) = _ + + def registerSqlSerDe(sqlSerDe: (ReadObject, WriteObject)): Unit = { + this.sqlSerDe = sqlSerDe + } // Type mapping from R to Java // @@ -62,11 +71,22 @@ private[spark] object SerDe { case 'c' => readString(dis) case 'e' => readMap(dis) case 'r' => readBytes(dis) + case 'a' => readArray(dis) case 'l' => readList(dis) case 'D' => readDate(dis) case 't' => readTime(dis) case 'j' => JVMObjectTracker.getObject(readString(dis)) - case _ => throw new IllegalArgumentException(s"Invalid type $dataType") + case _ => + if (sqlSerDe == null || sqlSerDe._1 == null) { + throw new IllegalArgumentException (s"Invalid type $dataType") + } else { + val obj = (sqlSerDe._1)(dis, dataType) + if (obj == null) { + throw new IllegalArgumentException (s"Invalid type $dataType") + } else { + obj + } + } } } @@ -140,7 +160,8 @@ private[spark] object SerDe { (0 until len).map(_ => readString(in)).toArray } - def readList(dis: DataInputStream): Array[_] = { + // All elements of an array must be of the same type + def readArray(dis: DataInputStream): Array[_] = { val arrType = readObjectType(dis) arrType match { case 'i' => readIntArr(dis) @@ -149,23 +170,44 @@ private[spark] object SerDe { case 'b' => readBooleanArr(dis) case 'j' => readStringArr(dis).map(x => JVMObjectTracker.getObject(x)) case 'r' => readBytesArr(dis) - case _ => throw new IllegalArgumentException(s"Invalid array type $arrType") + case 'a' => + val len = readInt(dis) + (0 until len).map(_ => readArray(dis)).toArray + case 'l' => + val len = readInt(dis) + (0 until len).map(_ => readList(dis)).toArray + case _ => + if (sqlSerDe == null || sqlSerDe._1 == null) { + throw new IllegalArgumentException (s"Invalid array type $arrType") + } else { + val len = 
readInt(dis) + (0 until len).map { _ => + val obj = (sqlSerDe._1)(dis, arrType) + if (obj == null) { + throw new IllegalArgumentException (s"Invalid array type $arrType") + } else { + obj + } + }.toArray + } } } + // Each element of a list can be of different type. They are all represented + // as Object on JVM side + def readList(dis: DataInputStream): Array[Object] = { + val len = readInt(dis) + (0 until len).map(_ => readObject(dis)).toArray + } + def readMap(in: DataInputStream): java.util.Map[Object, Object] = { val len = readInt(in) if (len > 0) { - val keysType = readObjectType(in) - val keysLen = readInt(in) - val keys = (0 until keysLen).map(_ => readTypedObject(in, keysType)) - - val valuesLen = readInt(in) - val values = (0 until valuesLen).map(_ => { - val valueType = readObjectType(in) - readTypedObject(in, valueType) - }) - mapAsJavaMap(keys.zip(values).toMap) + // Keys is an array of String + val keys = readArray(in).asInstanceOf[Array[Object]] + val values = readList(in) + + keys.zip(values).toMap.asJava } else { new java.util.HashMap[Object, Object]() } @@ -181,6 +223,7 @@ private[spark] object SerDe { // Boolean -> logical // Float -> double // Double -> double + // Decimal -> double // Long -> double // Array[Byte] -> raw // Date -> Date @@ -199,78 +242,140 @@ private[spark] object SerDe { case "date" => dos.writeByte('D') case "time" => dos.writeByte('t') case "raw" => dos.writeByte('r') + // Array of primitive types + case "array" => dos.writeByte('a') + // Array of objects case "list" => dos.writeByte('l') + case "map" => dos.writeByte('e') case "jobj" => dos.writeByte('j') case _ => throw new IllegalArgumentException(s"Invalid type $typeStr") } } - def writeObject(dos: DataOutputStream, value: Object): Unit = { - if (value == null) { + private def writeKeyValue(dos: DataOutputStream, key: Object, value: Object): Unit = { + if (key == null) { + throw new IllegalArgumentException("Key in map can't be null.") + } else if (!key.isInstanceOf[String]) { + throw new IllegalArgumentException(s"Invalid map key type: ${key.getClass.getName}") + } + + writeString(dos, key.asInstanceOf[String]) + writeObject(dos, value) + } + + def writeObject(dos: DataOutputStream, obj: Object): Unit = { + if (obj == null) { writeType(dos, "void") } else { - value.getClass.getName match { - case "java.lang.String" => + // Convert ArrayType collected from DataFrame to Java array + // Collected data of ArrayType from a DataFrame is observed to be of + // type "scala.collection.mutable.WrappedArray" + val value = + if (obj.isInstanceOf[WrappedArray[_]]) { + obj.asInstanceOf[WrappedArray[_]].toArray + } else { + obj + } + + value match { + case v: java.lang.Character => writeType(dos, "character") - writeString(dos, value.asInstanceOf[String]) - case "long" | "java.lang.Long" => + writeString(dos, v.toString) + case v: java.lang.String => + writeType(dos, "character") + writeString(dos, v) + case v: java.lang.Long => + writeType(dos, "double") + writeDouble(dos, v.toDouble) + case v: java.lang.Float => writeType(dos, "double") - writeDouble(dos, value.asInstanceOf[Long].toDouble) - case "float" | "java.lang.Float" => + writeDouble(dos, v.toDouble) + case v: java.math.BigDecimal => writeType(dos, "double") - writeDouble(dos, value.asInstanceOf[Float].toDouble) - case "double" | "java.lang.Double" => + writeDouble(dos, scala.math.BigDecimal(v).toDouble) + case v: java.lang.Double => writeType(dos, "double") - writeDouble(dos, value.asInstanceOf[Double]) - case "int" | "java.lang.Integer" => + 
writeDouble(dos, v) + case v: java.lang.Byte => + writeType(dos, "integer") + writeInt(dos, v.toInt) + case v: java.lang.Short => + writeType(dos, "integer") + writeInt(dos, v.toInt) + case v: java.lang.Integer => writeType(dos, "integer") - writeInt(dos, value.asInstanceOf[Int]) - case "boolean" | "java.lang.Boolean" => + writeInt(dos, v) + case v: java.lang.Boolean => writeType(dos, "logical") - writeBoolean(dos, value.asInstanceOf[Boolean]) - case "java.sql.Date" => + writeBoolean(dos, v) + case v: java.sql.Date => writeType(dos, "date") - writeDate(dos, value.asInstanceOf[Date]) - case "java.sql.Time" => + writeDate(dos, v) + case v: java.sql.Time => writeType(dos, "time") - writeTime(dos, value.asInstanceOf[Time]) - case "java.sql.Timestamp" => + writeTime(dos, v) + case v: java.sql.Timestamp => writeType(dos, "time") - writeTime(dos, value.asInstanceOf[Timestamp]) - case "[B" => - writeType(dos, "raw") - writeBytes(dos, value.asInstanceOf[Array[Byte]]) - // TODO: Types not handled right now include - // byte, char, short, float + writeTime(dos, v) // Handle arrays - case "[Ljava.lang.String;" => - writeType(dos, "list") - writeStringArr(dos, value.asInstanceOf[Array[String]]) - case "[I" => - writeType(dos, "list") - writeIntArr(dos, value.asInstanceOf[Array[Int]]) - case "[J" => - writeType(dos, "list") - writeDoubleArr(dos, value.asInstanceOf[Array[Long]].map(_.toDouble)) - case "[D" => - writeType(dos, "list") - writeDoubleArr(dos, value.asInstanceOf[Array[Double]]) - case "[Z" => - writeType(dos, "list") - writeBooleanArr(dos, value.asInstanceOf[Array[Boolean]]) - case "[[B" => + + // Array of primitive types + + // Special handling for byte array + case v: Array[Byte] => + writeType(dos, "raw") + writeBytes(dos, v) + + case v: Array[Char] => + writeType(dos, "array") + writeStringArr(dos, v.map(_.toString)) + case v: Array[Short] => + writeType(dos, "array") + writeIntArr(dos, v.map(_.toInt)) + case v: Array[Int] => + writeType(dos, "array") + writeIntArr(dos, v) + case v: Array[Long] => + writeType(dos, "array") + writeDoubleArr(dos, v.map(_.toDouble)) + case v: Array[Float] => + writeType(dos, "array") + writeDoubleArr(dos, v.map(_.toDouble)) + case v: Array[Double] => + writeType(dos, "array") + writeDoubleArr(dos, v) + case v: Array[Boolean] => + writeType(dos, "array") + writeBooleanArr(dos, v) + + // Array of objects, null objects use "void" type + case v: Array[Object] => writeType(dos, "list") - writeBytesArr(dos, value.asInstanceOf[Array[Array[Byte]]]) - case otherName => - // Handle array of objects - if (otherName.startsWith("[L")) { - val objArr = value.asInstanceOf[Array[Object]] - writeType(dos, "list") - writeType(dos, "jobj") - dos.writeInt(objArr.length) - objArr.foreach(o => writeJObj(dos, o)) - } else { + writeInt(dos, v.length) + v.foreach(elem => writeObject(dos, elem)) + + // Handle map + case v: java.util.Map[_, _] => + writeType(dos, "map") + writeInt(dos, v.size) + val iter = v.entrySet.iterator + while(iter.hasNext) { + val entry = iter.next + val key = entry.getKey + val value = entry.getValue + + writeKeyValue(dos, key.asInstanceOf[Object], value.asInstanceOf[Object]) + } + case v: scala.collection.Map[_, _] => + writeType(dos, "map") + writeInt(dos, v.size) + v.foreach { case (key, value) => + writeKeyValue(dos, key.asInstanceOf[Object], value.asInstanceOf[Object]) + } + + case _ => + if (sqlSerDe == null || sqlSerDe._2 == null || !(sqlSerDe._2)(dos, value)) { writeType(dos, "jobj") writeJObj(dos, value) } @@ -303,12 +408,11 @@ private[spark] object 
SerDe { out.writeDouble((value.getTime / 1000).toDouble + value.getNanos.toDouble / 1e9) } - // NOTE: Only works for ASCII right now def writeString(out: DataOutputStream, value: String): Unit = { - val len = value.length - out.writeInt(len + 1) // For the \0 - out.writeBytes(value) - out.writeByte(0) + val utf8 = value.getBytes("UTF-8") + val len = utf8.length + out.writeInt(len) + out.write(utf8, 0, len) } def writeBytes(out: DataOutputStream, value: Array[Byte]): Unit = { @@ -345,11 +449,6 @@ private[spark] object SerDe { value.foreach(v => writeString(out, v)) } - def writeBytesArr(out: DataOutputStream, value: Array[Array[Byte]]): Unit = { - writeType(out, "raw") - out.writeInt(value.length) - value.foreach(v => writeBytes(out, v)) - } } private[r] object SerializationFormats { diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index a0c9b5e63c74..7e3764d802fe 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -20,7 +20,7 @@ package org.apache.spark.broadcast import java.io._ import java.nio.ByteBuffer -import scala.collection.JavaConversions.asJavaEnumeration +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import scala.util.Random @@ -210,7 +210,7 @@ private object TorrentBroadcast extends Logging { compressionCodec: Option[CompressionCodec]): T = { require(blocks.nonEmpty, "Cannot unblockify an empty array of blocks") val is = new SequenceInputStream( - asJavaEnumeration(blocks.iterator.map(block => new ByteBufferInputStream(block)))) + blocks.iterator.map(new ByteBufferInputStream(_)).asJavaEnumeration) val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is) val ser = serializer.newInstance() val serIn = ser.deserializeStream(in) diff --git a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala index ae99432f5ce8..78bbd5c03f4a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala @@ -19,30 +19,17 @@ package org.apache.spark.deploy import java.net.URI -private[spark] class ApplicationDescription( - val name: String, - val maxCores: Option[Int], - val memoryPerExecutorMB: Int, - val command: Command, - var appUiUrl: String, - val eventLogDir: Option[URI] = None, +private[spark] case class ApplicationDescription( + name: String, + maxCores: Option[Int], + memoryPerExecutorMB: Int, + command: Command, + appUiUrl: String, + eventLogDir: Option[URI] = None, // short name of compression codec used when writing event logs, if any (e.g. 
lzf) - val eventLogCodec: Option[String] = None, - val coresPerExecutor: Option[Int] = None) - extends Serializable { - - val user = System.getProperty("user.name", "") - - def copy( - name: String = name, - maxCores: Option[Int] = maxCores, - memoryPerExecutorMB: Int = memoryPerExecutorMB, - command: Command = command, - appUiUrl: String = appUiUrl, - eventLogDir: Option[URI] = eventLogDir, - eventLogCodec: Option[String] = eventLogCodec): ApplicationDescription = - new ApplicationDescription( - name, maxCores, memoryPerExecutorMB, command, appUiUrl, eventLogDir, eventLogCodec) + eventLogCodec: Option[String] = None, + coresPerExecutor: Option[Int] = None, + user: String = System.getProperty("user.name", "")) { override def toString: String = "ApplicationDescription(" + name + ")" } diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index d8084a57658a..3feb7cea593e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -69,9 +69,14 @@ private[deploy] object DeployMessages { // Master to Worker + sealed trait RegisterWorkerResponse + case class RegisteredWorker(master: RpcEndpointRef, masterWebUiUrl: String) extends DeployMessage + with RegisterWorkerResponse + + case class RegisterWorkerFailed(message: String) extends DeployMessage with RegisterWorkerResponse - case class RegisterWorkerFailed(message: String) extends DeployMessage + case object MasterInStandby extends DeployMessage with RegisterWorkerResponse case class ReconnectWorker(masterUrl: String) extends DeployMessage diff --git a/core/src/main/scala/org/apache/spark/deploy/DriverDescription.scala b/core/src/main/scala/org/apache/spark/deploy/DriverDescription.scala index 659fb434a80f..1f5626ab5a89 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DriverDescription.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DriverDescription.scala @@ -17,21 +17,12 @@ package org.apache.spark.deploy -private[deploy] class DriverDescription( - val jarUrl: String, - val mem: Int, - val cores: Int, - val supervise: Boolean, - val command: Command) - extends Serializable { - - def copy( - jarUrl: String = jarUrl, - mem: Int = mem, - cores: Int = cores, - supervise: Boolean = supervise, - command: Command = command): DriverDescription = - new DriverDescription(jarUrl, mem, cores, supervise, command) +private[deploy] case class DriverDescription( + jarUrl: String, + mem: Int, + cores: Int, + supervise: Boolean, + command: Command) { override def toString: String = s"DriverDescription (${command.mainClass})" } diff --git a/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala b/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala index efa88c62e1f5..69c98e28931d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy private[deploy] object ExecutorState extends Enumeration { - val LAUNCHING, LOADING, RUNNING, KILLED, FAILED, LOST, EXITED = Value + val LAUNCHING, RUNNING, KILLED, FAILED, LOST, EXITED = Value type ExecutorState = Value diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index 20a9faa1784b..7fc96e4f764b 100644 --- 
a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -19,16 +19,16 @@ package org.apache.spark.deploy import java.util.concurrent.CountDownLatch -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.{Logging, SparkConf, SecurityManager} import org.apache.spark.network.TransportContext import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.sasl.SaslServerBootstrap -import org.apache.spark.network.server.TransportServer +import org.apache.spark.network.server.{TransportServerBootstrap, TransportServer} import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler import org.apache.spark.network.util.TransportConf -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} /** * Provides a server from which Executors can read shuffle files (rather than reading directly from @@ -45,15 +45,17 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana private val port = sparkConf.getInt("spark.shuffle.service.port", 7337) private val useSasl: Boolean = securityManager.isAuthenticationEnabled() - private val transportConf = SparkTransportConf.fromSparkConf(sparkConf, numUsableCores = 0) + private val transportConf = + SparkTransportConf.fromSparkConf(sparkConf, "shuffle", numUsableCores = 0) private val blockHandler = newShuffleBlockHandler(transportConf) - private val transportContext: TransportContext = new TransportContext(transportConf, blockHandler) + private val transportContext: TransportContext = + new TransportContext(transportConf, blockHandler, true) private var server: TransportServer = _ /** Create a new shuffle block handler. Factored out for subclasses to override. */ protected def newShuffleBlockHandler(conf: TransportConf): ExternalShuffleBlockHandler = { - new ExternalShuffleBlockHandler(conf) + new ExternalShuffleBlockHandler(conf, null) } /** Starts the external shuffle service if the user has configured us to. */ @@ -67,13 +69,13 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana def start() { require(server == null, "Shuffle server already started") logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl") - val bootstraps = + val bootstraps: Seq[TransportServerBootstrap] = if (useSasl) { Seq(new SaslServerBootstrap(transportConf, securityManager)) } else { Nil } - server = transportContext.createServer(port, bootstraps) + server = transportContext.createServer(port, bootstraps.asJava) } /** Clean up all shuffle files associated with an application that has exited. 
*/ @@ -116,19 +118,13 @@ object ExternalShuffleService extends Logging { server = newShuffleService(sparkConf, securityManager) server.start() - installShutdownHook() + ShutdownHookManager.addShutdownHook { () => + logInfo("Shutting down shuffle service.") + server.stop() + barrier.countDown() + } // keep running until the process is terminated barrier.await() } - - private def installShutdownHook(): Unit = { - Runtime.getRuntime.addShutdownHook(new Thread("External Shuffle Service shutdown thread") { - override def run() { - logInfo("Shutting down shuffle service.") - server.stop() - barrier.countDown() - } - }) - } } diff --git a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala index ccffb3665298..220b20bf7cbd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala @@ -45,7 +45,7 @@ private[deploy] object JsonProtocol { ("id" -> obj.id) ~ ("name" -> obj.desc.name) ~ ("cores" -> obj.desc.maxCores) ~ - ("user" -> obj.desc.user) ~ + ("user" -> obj.desc.user) ~ ("memoryperslave" -> obj.desc.memoryPerExecutorMB) ~ ("submitdate" -> obj.submitDate.toString) ~ ("state" -> obj.state.toString) ~ diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala index 53356addf6ed..5bb62d37d637 100644 --- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala @@ -73,12 +73,10 @@ class LocalSparkCluster( def stop() { logInfo("Shutting down local Spark cluster.") // Stop the workers before the master so they don't get upset that it disconnected - // TODO: In Akka 2.1.x, ActorSystem.awaitTermination hangs when you have remote actors! - // This is unfortunate, but for now we just comment it out. workerRpcEnvs.foreach(_.shutdown()) - // workerActorSystems.foreach(_.awaitTermination()) masterRpcEnvs.foreach(_.shutdown()) - // masterActorSystems.foreach(_.awaitTermination()) + workerRpcEnvs.foreach(_.awaitTermination()) + masterRpcEnvs.foreach(_.awaitTermination()) masterRpcEnvs.clear() workerRpcEnvs.clear() } diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index c2ed43a5397d..d85327603f64 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -21,9 +21,10 @@ import java.net.URI import java.io.File import scala.collection.mutable.ArrayBuffer -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.util.Try +import org.apache.spark.SparkUserAppException import org.apache.spark.api.python.PythonUtils import org.apache.spark.util.{RedirectThread, Utils} @@ -46,7 +47,20 @@ object PythonRunner { // Launch a Py4J gateway server for the process to connect to; this will let it see our // Java system properties and such val gatewayServer = new py4j.GatewayServer(null, 0) - gatewayServer.start() + val thread = new Thread(new Runnable() { + override def run(): Unit = Utils.logUncaughtExceptions { + gatewayServer.start() + } + }) + thread.setName("py4j-gateway-init") + thread.setDaemon(true) + thread.start() + + // Wait until the gateway server has started, so that we know which port is it bound to. 
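A self-contained sketch of the start-then-join pattern used in PythonRunner above: start() binds the server socket and hands the serving loop to a thread of its own, so joining the launcher thread is enough to guarantee that getListeningPort returns a real port. The variable names are illustrative.

val gatewayServer = new py4j.GatewayServer(null, 0)
val launcher = new Thread(new Runnable {
  // start() returns once the socket is initialized; the serving loop keeps
  // running on a thread created by py4j itself.
  override def run(): Unit = gatewayServer.start()
})
launcher.setName("py4j-gateway-init")
launcher.setDaemon(true)
launcher.start()
launcher.join() // after this, the port below is guaranteed to be bound
val boundPort = gatewayServer.getListeningPort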
+ // `gatewayServer.start()` will start a new thread and run the server code there, after + // initializing the socket, so the thread started above will end as soon as the server is + // ready to serve connections. + thread.join() // Build up a PYTHONPATH that includes the Spark assembly JAR (where this class is), the // python directories in SPARK_HOME (if set), and any files in the pyFiles argument @@ -57,18 +71,25 @@ object PythonRunner { val pythonPath = PythonUtils.mergePythonPaths(pathElements: _*) // Launch Python process - val builder = new ProcessBuilder(Seq(pythonExec, formattedPythonFile) ++ otherArgs) + val builder = new ProcessBuilder((Seq(pythonExec, formattedPythonFile) ++ otherArgs).asJava) val env = builder.environment() env.put("PYTHONPATH", pythonPath) // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort) builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize - val process = builder.start() + try { + val process = builder.start() - new RedirectThread(process.getInputStream, System.out, "redirect output").start() + new RedirectThread(process.getInputStream, System.out, "redirect output").start() - System.exit(process.waitFor()) + val exitCode = process.waitFor() + if (exitCode != 0) { + throw new SparkUserAppException(exitCode) + } + } finally { + gatewayServer.shutdown() + } } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala b/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala new file mode 100644 index 000000000000..d46dc87a92c9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala @@ -0,0 +1,248 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import java.io._ +import java.util.jar.JarFile +import java.util.logging.Level +import java.util.zip.{ZipEntry, ZipOutputStream} + +import scala.collection.JavaConverters._ + +import com.google.common.io.{ByteStreams, Files} + +import org.apache.spark.{SparkException, Logging} +import org.apache.spark.api.r.RUtils +import org.apache.spark.util.{RedirectThread, Utils} + +private[deploy] object RPackageUtils extends Logging { + + /** The key in the MANIFEST.mf that we look for, in case a jar contains R code. */ + private final val hasRPackage = "Spark-HasRPackage" + + /** Base of the shell command used in order to install R packages. */ + private final val baseInstallCmd = Seq("R", "CMD", "INSTALL", "-l") + + /** R source code should exist under R/pkg in a jar. 
*/ + private final val RJarEntries = "R/pkg" + + /** Documentation on how the R source file layout should be in the jar. */ + private[deploy] final val RJarDoc = + s"""In order for Spark to build R packages that are parts of Spark Packages, there are a few + |requirements. The R source code must be shipped in a jar, with additional Java/Scala + |classes. The jar must be in the following format: + | 1- The Manifest (META-INF/MANIFEST.mf) must contain the key-value: $hasRPackage: true + | 2- The standard R package layout must be preserved under R/pkg/ inside the jar. More + | information on the standard R package layout can be found in: + | http://cran.r-project.org/doc/contrib/Leisch-CreatingPackages.pdf + | An example layout is given below. After running `jar tf $$JAR_FILE | sort`: + | + |META-INF/MANIFEST.MF + |R/ + |R/pkg/ + |R/pkg/DESCRIPTION + |R/pkg/NAMESPACE + |R/pkg/R/ + |R/pkg/R/myRcode.R + |org/ + |org/apache/ + |... + """.stripMargin.trim + + /** Internal method for logging. We log to a printStream in tests, for debugging purposes. */ + private def print( + msg: String, + printStream: PrintStream, + level: Level = Level.FINE, + e: Throwable = null): Unit = { + if (printStream != null) { + // scalastyle:off println + printStream.println(msg) + // scalastyle:on println + if (e != null) { + e.printStackTrace(printStream) + } + } else { + level match { + case Level.INFO => logInfo(msg) + case Level.WARNING => logWarning(msg) + case Level.SEVERE => logError(msg, e) + case _ => logDebug(msg) + } + } + } + + /** + * Checks the manifest of the Jar whether there is any R source code bundled with it. + * Exposed for testing. + */ + private[deploy] def checkManifestForR(jar: JarFile): Boolean = { + val manifest = jar.getManifest.getMainAttributes + manifest.getValue(hasRPackage) != null && manifest.getValue(hasRPackage).trim == "true" + } + + /** + * Runs the standard R package installation code to build the R package from source. + * Multiple runs don't cause problems. + */ + private def rPackageBuilder( + dir: File, + printStream: PrintStream, + verbose: Boolean, + libDir: String): Boolean = { + // this code should be always running on the driver. + val pathToPkg = Seq(dir, "R", "pkg").mkString(File.separator) + val installCmd = baseInstallCmd ++ Seq(libDir, pathToPkg) + if (verbose) { + print(s"Building R package with the command: $installCmd", printStream) + } + try { + val builder = new ProcessBuilder(installCmd.asJava) + builder.redirectErrorStream(true) + + // Put the SparkR package directory into R library search paths in case this R package + // may depend on SparkR. + val env = builder.environment() + val rPackageDir = RUtils.sparkRPackagePath(isDriver = true) + env.put("SPARKR_PACKAGE_DIR", rPackageDir.mkString(",")) + env.put("R_PROFILE_USER", + Seq(rPackageDir(0), "SparkR", "profile", "general.R").mkString(File.separator)) + + val process = builder.start() + new RedirectThread(process.getInputStream, printStream, "redirect R packaging").start() + process.waitFor() == 0 + } catch { + case e: Throwable => + print("Failed to build R package.", printStream, Level.SEVERE, e) + false + } + } + + /** + * Extracts the files under /R in the jar to a temporary directory for building. 
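For reference, the command assembled by rPackageBuilder above reduces to `R CMD INSTALL -l <target lib dir> <extracted pkg dir>`; a minimal sketch with hypothetical paths:

// Hypothetical paths; mirrors baseInstallCmd ++ Seq(libDir, pathToPkg) above.
val baseInstallCmd = Seq("R", "CMD", "INSTALL", "-l")
val libDir     = "/tmp/spark-r-libs"         // target library the package is installed into
val pathToPkg  = "/tmp/extracted-jar/R/pkg"  // R source extracted from the jar
val installCmd = baseInstallCmd ++ Seq(libDir, pathToPkg)
// installCmd == Seq("R", "CMD", "INSTALL", "-l", "/tmp/spark-r-libs", "/tmp/extracted-jar/R/pkg"),
// run through ProcessBuilder with SPARKR_PACKAGE_DIR and R_PROFILE_USER set so the
// package can resolve SparkR during installation.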
+ */ + private def extractRFolder(jar: JarFile, printStream: PrintStream, verbose: Boolean): File = { + val tempDir = Utils.createTempDir(null) + val jarEntries = jar.entries() + while (jarEntries.hasMoreElements) { + val entry = jarEntries.nextElement() + val entryRIndex = entry.getName.indexOf(RJarEntries) + if (entryRIndex > -1) { + val entryPath = entry.getName.substring(entryRIndex) + if (entry.isDirectory) { + val dir = new File(tempDir, entryPath) + if (verbose) { + print(s"Creating directory: $dir", printStream) + } + dir.mkdirs + } else { + val inStream = jar.getInputStream(entry) + val outPath = new File(tempDir, entryPath) + Files.createParentDirs(outPath) + val outStream = new FileOutputStream(outPath) + if (verbose) { + print(s"Extracting $entry to $outPath", printStream) + } + Utils.copyStream(inStream, outStream, closeStreams = true) + } + } + } + tempDir + } + + /** + * Extracts the files under /R in the jar to a temporary directory for building. + */ + private[deploy] def checkAndBuildRPackage( + jars: String, + printStream: PrintStream = null, + verbose: Boolean = false): Unit = { + jars.split(",").foreach { jarPath => + val file = new File(Utils.resolveURI(jarPath)) + if (file.exists()) { + val jar = new JarFile(file) + if (checkManifestForR(jar)) { + print(s"$file contains R source code. Now installing package.", printStream, Level.INFO) + val rSource = extractRFolder(jar, printStream, verbose) + if (RUtils.rPackages.isEmpty) { + RUtils.rPackages = Some(Utils.createTempDir().getAbsolutePath) + } + try { + if (!rPackageBuilder(rSource, printStream, verbose, RUtils.rPackages.get)) { + print(s"ERROR: Failed to build R package in $file.", printStream) + print(RJarDoc, printStream) + } + } finally { // clean up + if (!rSource.delete()) { + logWarning(s"Error deleting ${rSource.getPath()}") + } + } + } else { + if (verbose) { + print(s"$file doesn't contain R source code, skipping...", printStream) + } + } + } else { + print(s"WARN: $file resolved as dependency, but not found.", printStream, Level.WARNING) + } + } + } + + private def listFilesRecursively(dir: File, excludePatterns: Seq[String]): Set[File] = { + if (!dir.exists()) { + Set.empty[File] + } else { + if (dir.isDirectory) { + val subDir = dir.listFiles(new FilenameFilter { + override def accept(dir: File, name: String): Boolean = { + !excludePatterns.map(name.contains).reduce(_ || _) // exclude files with given pattern + } + }) + subDir.flatMap(listFilesRecursively(_, excludePatterns)).toSet + } else { + Set(dir) + } + } + } + + /** Zips all the R libraries built for distribution to the cluster. */ + private[deploy] def zipRLibraries(dir: File, name: String): File = { + val filesToBundle = listFilesRecursively(dir, Seq(".zip")) + // create a zip file from scratch, do not append to existing file. 
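The entry names written by the zip loop below are simply the absolute file paths made relative to the library directory; a small illustration with hypothetical paths (replaceFirst treats its first argument as a regex, which is harmless for plain paths like these):

// Hypothetical paths showing how a zip entry name is derived below.
val dirAbs  = "/tmp/spark-r-libs"
val fileAbs = "/tmp/spark-r-libs/mypkg/R/mypkg.rdb"
val relPath = fileAbs.replaceFirst(dirAbs, "")  // "/mypkg/R/mypkg.rdb"
// relPath becomes the ZipEntry name, so the per-package directory layout survives
// the round trip to the cluster.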
+ val zipFile = new File(dir, name) + if (!zipFile.delete()) { + logWarning(s"Error deleting ${zipFile.getPath()}") + } + val zipOutputStream = new ZipOutputStream(new FileOutputStream(zipFile, false)) + try { + filesToBundle.foreach { file => + // get the relative paths for proper naming in the zip file + val relPath = file.getAbsolutePath.replaceFirst(dir.getAbsolutePath, "") + val fis = new FileInputStream(file) + val zipEntry = new ZipEntry(relPath) + zipOutputStream.putNextEntry(zipEntry) + ByteStreams.copy(fis, zipOutputStream) + zipOutputStream.closeEntry() + fis.close() + } + } finally { + zipOutputStream.close() + } + zipFile + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala index c0cab22fa825..661f7317c674 100644 --- a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala @@ -20,11 +20,12 @@ package org.apache.spark.deploy import java.io._ import java.util.concurrent.{Semaphore, TimeUnit} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path import org.apache.spark.api.r.{RBackend, RUtils} +import org.apache.spark.{SparkException, SparkUserAppException} import org.apache.spark.util.RedirectThread /** @@ -39,7 +40,16 @@ object RRunner { // Time to wait for SparkR backend to initialize in seconds val backendTimeout = sys.env.getOrElse("SPARKR_BACKEND_TIMEOUT", "120").toInt - val rCommand = "Rscript" + val rCommand = { + // "spark.sparkr.r.command" is deprecated and replaced by "spark.r.command", + // but kept here for backward compatibility. + var cmd = sys.props.getOrElse("spark.sparkr.r.command", "Rscript") + cmd = sys.props.getOrElse("spark.r.command", cmd) + if (sys.props.getOrElse("spark.submit.deployMode", "client") == "client") { + cmd = sys.props.getOrElse("spark.r.driver.command", cmd) + } + cmd + } // Check if the file path exists. 
// If not, change directory to current working directory for YARN cluster mode @@ -68,13 +78,14 @@ object RRunner { if (initialized.tryAcquire(backendTimeout, TimeUnit.SECONDS)) { // Launch R val returnCode = try { - val builder = new ProcessBuilder(Seq(rCommand, rFileNormalized) ++ otherArgs) + val builder = new ProcessBuilder((Seq(rCommand, rFileNormalized) ++ otherArgs).asJava) val env = builder.environment() env.put("EXISTING_SPARKR_BACKEND_PORT", sparkRBackendPort.toString) val rPackageDir = RUtils.sparkRPackagePath(isDriver = true) - env.put("SPARKR_PACKAGE_DIR", rPackageDir) + // Put the R package directories into an env variable of comma-separated paths + env.put("SPARKR_PACKAGE_DIR", rPackageDir.mkString(",")) env.put("R_PROFILE_USER", - Seq(rPackageDir, "SparkR", "profile", "general.R").mkString(File.separator)) + Seq(rPackageDir(0), "SparkR", "profile", "general.R").mkString(File.separator)) builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize val process = builder.start() @@ -84,12 +95,15 @@ object RRunner { } finally { sparkRBackend.close() } - System.exit(returnCode) + if (returnCode != 0) { + throw new SparkUserAppException(returnCode) + } } else { + val errorMessage = s"SparkR backend did not initialize in $backendTimeout seconds" // scalastyle:off println - System.err.println("SparkR backend did not initialize in " + backendTimeout + " seconds") + System.err.println(errorMessage) // scalastyle:on println - System.exit(-1) + throw new SparkException(errorMessage) } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkCuratorUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkCuratorUtil.scala index b8d399354022..8d5e716e6aea 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkCuratorUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkCuratorUtil.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.curator.framework.{CuratorFramework, CuratorFrameworkFactory} import org.apache.curator.retry.ExponentialBackoffRetry @@ -57,7 +57,7 @@ private[spark] object SparkCuratorUtil extends Logging { def deleteRecursive(zk: CuratorFramework, path: String) { if (zk.checkExists().forPath(path) != null) { - for (child <- zk.getChildren.forPath(path)) { + for (child <- zk.getChildren.forPath(path).asScala) { zk.delete().forPath(path + "/" + child) } zk.delete().forPath(path) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index e06b06e06fb4..59e90564b351 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -22,7 +22,7 @@ import java.lang.reflect.Method import java.security.PrivilegedExceptionAction import java.util.{Arrays, Comparator} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.control.NonFatal @@ -34,6 +34,8 @@ import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter} import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapreduce.JobContext +import org.apache.hadoop.mapreduce.{TaskAttemptContext => MapReduceTaskAttemptContext} +import org.apache.hadoop.mapreduce.{TaskAttemptID 
=> MapReduceTaskAttemptID} import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.spark.annotation.DeveloperApi @@ -69,12 +71,12 @@ class SparkHadoopUtil extends Logging { } def transferCredentials(source: UserGroupInformation, dest: UserGroupInformation) { - for (token <- source.getTokens()) { + for (token <- source.getTokens.asScala) { dest.addToken(token) } } - @Deprecated + @deprecated("use newConfiguration with SparkConf argument", "1.2.0") def newConfiguration(): Configuration = newConfiguration(null) /** @@ -90,10 +92,15 @@ class SparkHadoopUtil extends Logging { // Explicitly check for S3 environment variables if (System.getenv("AWS_ACCESS_KEY_ID") != null && System.getenv("AWS_SECRET_ACCESS_KEY") != null) { - hadoopConf.set("fs.s3.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID")) - hadoopConf.set("fs.s3n.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID")) - hadoopConf.set("fs.s3.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY")) - hadoopConf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY")) + val keyId = System.getenv("AWS_ACCESS_KEY_ID") + val accessKey = System.getenv("AWS_SECRET_ACCESS_KEY") + + hadoopConf.set("fs.s3.awsAccessKeyId", keyId) + hadoopConf.set("fs.s3n.awsAccessKeyId", keyId) + hadoopConf.set("fs.s3a.access.key", keyId) + hadoopConf.set("fs.s3.awsSecretAccessKey", accessKey) + hadoopConf.set("fs.s3n.awsSecretAccessKey", accessKey) + hadoopConf.set("fs.s3a.secret.key", accessKey) } // Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar" conf.getAll.foreach { case (key, value) => @@ -173,8 +180,8 @@ class SparkHadoopUtil extends Logging { } private def getFileSystemThreadStatistics(): Seq[AnyRef] = { - val stats = FileSystem.getAllStatistics() - stats.map(Utils.invoke(classOf[Statistics], _, "getThreadStatistics")) + FileSystem.getAllStatistics.asScala.map( + Utils.invoke(classOf[Statistics], _, "getThreadStatistics")) } private def getFileSystemThreadStatisticsMethod(methodName: String): Method = { @@ -190,10 +197,26 @@ class SparkHadoopUtil extends Logging { * while it's interface in Hadoop 2.+. */ def getConfigurationFromJobContext(context: JobContext): Configuration = { + // scalastyle:off jobconfig val method = context.getClass.getMethod("getConfiguration") + // scalastyle:on jobconfig method.invoke(context).asInstanceOf[Configuration] } + /** + * Using reflection to call `getTaskAttemptID` from TaskAttemptContext. If we directly + * call `TaskAttemptContext.getTaskAttemptID`, it will generate different byte codes + * for Hadoop 1.+ and Hadoop 2.+ because TaskAttemptContext is class in Hadoop 1.+ + * while it's interface in Hadoop 2.+. + */ + def getTaskAttemptIDFromTaskAttemptContext( + context: MapReduceTaskAttemptContext): MapReduceTaskAttemptID = { + // scalastyle:off jobconfig + val method = context.getClass.getMethod("getTaskAttemptID") + // scalastyle:on jobconfig + method.invoke(context).asInstanceOf[MapReduceTaskAttemptID] + } + /** * Get [[FileStatus]] objects for all leaf children (files) under the given base path. 
If the * given path points to a file, return a single-element collection containing [[FileStatus]] of @@ -292,12 +315,13 @@ class SparkHadoopUtil extends Logging { val renewalInterval = sparkConf.getLong("spark.yarn.token.renewal.interval", (24 hours).toMillis) - credentials.getAllTokens.filter(_.getKind == DelegationTokenIdentifier.HDFS_DELEGATION_KIND) + credentials.getAllTokens.asScala + .filter(_.getKind == DelegationTokenIdentifier.HDFS_DELEGATION_KIND) .map { t => - val identifier = new DelegationTokenIdentifier() - identifier.readFields(new DataInputStream(new ByteArrayInputStream(t.getIdentifier))) - (identifier.getIssueDate + fraction * renewalInterval).toLong - now - }.foldLeft(0L)(math.max) + val identifier = new DelegationTokenIdentifier() + identifier.readFields(new DataInputStream(new ByteArrayInputStream(t.getIdentifier))) + (identifier.getIssueDate + fraction * renewalInterval).toLong - now + }.foldLeft(0L)(math.max) } @@ -366,20 +390,13 @@ class SparkHadoopUtil extends Logging { object SparkHadoopUtil { - private val hadoop = { - val yarnMode = java.lang.Boolean.valueOf( - System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE"))) - if (yarnMode) { - try { - Utils.classForName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil") - .newInstance() - .asInstanceOf[SparkHadoopUtil] - } catch { - case e: Exception => throw new SparkException("Unable to load YARN support", e) - } - } else { - new SparkHadoopUtil - } + private lazy val hadoop = new SparkHadoopUtil + private lazy val yarn = try { + Utils.classForName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil") + .newInstance() + .asInstanceOf[SparkHadoopUtil] + } catch { + case e: Exception => throw new SparkException("Unable to load YARN support", e) } val SPARK_YARN_CREDS_TEMP_EXTENSION = ".tmp" @@ -387,6 +404,13 @@ object SparkHadoopUtil { val SPARK_YARN_CREDS_COUNTER_DELIM = "-" def get: SparkHadoopUtil = { - hadoop + // Check each time to support changing to/from YARN + val yarnMode = java.lang.Boolean.valueOf( + System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE"))) + if (yarnMode) { + yarn + } else { + hadoop + } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 0b39ee8fe3ba..52d3ab34c178 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -24,6 +24,7 @@ import java.security.PrivilegedExceptionAction import scala.collection.mutable.{ArrayBuffer, HashMap, Map} +import org.apache.commons.lang3.StringUtils import org.apache.hadoop.fs.Path import org.apache.hadoop.security.UserGroupInformation import org.apache.ivy.Ivy @@ -37,8 +38,9 @@ import org.apache.ivy.core.settings.IvySettings import org.apache.ivy.plugins.matcher.GlobPatternMatcher import org.apache.ivy.plugins.repository.file.FileRepository import org.apache.ivy.plugins.resolver.{FileSystemResolver, ChainResolver, IBiblioResolver} + +import org.apache.spark.{SparkException, SparkUserAppException, SPARK_VERSION} import org.apache.spark.api.r.RUtils -import org.apache.spark.SPARK_VERSION import org.apache.spark.deploy.rest._ import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} @@ -81,6 +83,7 @@ object SparkSubmit { private val PYSPARK_SHELL = "pyspark-shell" private val SPARKR_SHELL = "sparkr-shell" private val SPARKR_PACKAGE_ARCHIVE = "sparkr.zip" + private val R_PACKAGE_ARCHIVE = "rpkg.zip" private val 
CLASS_NOT_FOUND_EXIT_STATUS = 101 @@ -275,24 +278,27 @@ object SparkSubmit { // Resolve maven dependencies if there are any and add classpath to jars. Add them to py-files // too for packages that include Python code - val resolvedMavenCoordinates = - SparkSubmitUtils.resolveMavenCoordinates( - args.packages, Option(args.repositories), Option(args.ivyRepoPath)) - if (!resolvedMavenCoordinates.trim.isEmpty) { - if (args.jars == null || args.jars.trim.isEmpty) { - args.jars = resolvedMavenCoordinates + val exclusions: Seq[String] = + if (!StringUtils.isBlank(args.packagesExclusions)) { + args.packagesExclusions.split(",") } else { - args.jars += s",$resolvedMavenCoordinates" + Nil } + val resolvedMavenCoordinates = SparkSubmitUtils.resolveMavenCoordinates(args.packages, + Option(args.repositories), Option(args.ivyRepoPath), exclusions = exclusions) + if (!StringUtils.isBlank(resolvedMavenCoordinates)) { + args.jars = mergeFileLists(args.jars, resolvedMavenCoordinates) if (args.isPython) { - if (args.pyFiles == null || args.pyFiles.trim.isEmpty) { - args.pyFiles = resolvedMavenCoordinates - } else { - args.pyFiles += s",$resolvedMavenCoordinates" - } + args.pyFiles = mergeFileLists(args.pyFiles, resolvedMavenCoordinates) } } + // install any R packages that may have been passed through --jars or --packages. + // Spark Packages may contain R source code inside the jar. + if (args.isR && !StringUtils.isBlank(args.jars)) { + RPackageUtils.checkAndBuildRPackage(args.jars, printStream, args.verbose) + } + // Require all python files to be local, so we can add them to the PYTHONPATH // In YARN cluster mode, python files are distributed as regular files, which can be non-local if (args.isPython && !isYarnCluster) { @@ -314,8 +320,8 @@ object SparkSubmit { // The following modes are not supported or applicable (clusterManager, deployMode) match { - case (MESOS, CLUSTER) if args.isPython => - printErrorAndExit("Cluster deploy mode is currently not supported for python " + + case (MESOS, CLUSTER) if args.isR => + printErrorAndExit("Cluster deploy mode is currently not supported for R " + "applications on Mesos clusters.") case (STANDALONE, CLUSTER) if args.isPython => printErrorAndExit("Cluster deploy mode is currently not supported for python " + @@ -323,6 +329,8 @@ object SparkSubmit { case (STANDALONE, CLUSTER) if args.isR => printErrorAndExit("Cluster deploy mode is currently not supported for R " + "applications on standalone clusters.") + case (LOCAL, CLUSTER) => + printErrorAndExit("Cluster deploy mode is not compatible with master \"local\"") case (_, CLUSTER) if isShell(args.primaryResource) => printErrorAndExit("Cluster deploy mode is not applicable to Spark shells.") case (_, CLUSTER) if isSqlShell(args.mainClass) => @@ -355,21 +363,46 @@ object SparkSubmit { } } - // In YARN mode for an R app, add the SparkR package archive to archives - // that can be distributed with the job + // In YARN mode for an R app, add the SparkR package archive and the R package + // archive containing all of the built R libraries to archives so that they can + // be distributed with the job if (args.isR && clusterManager == YARN) { - val rPackagePath = RUtils.localSparkRPackagePath - if (rPackagePath.isEmpty) { + val sparkRPackagePath = RUtils.localSparkRPackagePath + if (sparkRPackagePath.isEmpty) { printErrorAndExit("SPARK_HOME does not exist for R application in YARN mode.") } - val rPackageFile = new File(rPackagePath.get, SPARKR_PACKAGE_ARCHIVE) - if (!rPackageFile.exists()) { + val sparkRPackageFile = new 
File(sparkRPackagePath.get, SPARKR_PACKAGE_ARCHIVE) + if (!sparkRPackageFile.exists()) { printErrorAndExit(s"$SPARKR_PACKAGE_ARCHIVE does not exist for R application in YARN mode.") } - val localURI = Utils.resolveURI(rPackageFile.getAbsolutePath) + val sparkRPackageURI = Utils.resolveURI(sparkRPackageFile.getAbsolutePath).toString + // Distribute the SparkR package. // Assigns a symbol link name "sparkr" to the shipped package. - args.archives = mergeFileLists(args.archives, localURI.toString + "#sparkr") + args.archives = mergeFileLists(args.archives, sparkRPackageURI + "#sparkr") + + // Distribute the R package archive containing all the built R packages. + if (!RUtils.rPackages.isEmpty) { + val rPackageFile = + RPackageUtils.zipRLibraries(new File(RUtils.rPackages.get), R_PACKAGE_ARCHIVE) + if (!rPackageFile.exists()) { + printErrorAndExit("Failed to zip all the built R packages.") + } + + val rPackageURI = Utils.resolveURI(rPackageFile.getAbsolutePath).toString + // Assigns a symbol link name "rpkg" to the shipped package. + args.archives = mergeFileLists(args.archives, rPackageURI + "#rpkg") + } + } + + // TODO: Support distributing R packages with standalone cluster + if (args.isR && clusterManager == STANDALONE && !RUtils.rPackages.isEmpty) { + printErrorAndExit("Distributing R packages with standalone cluster is not supported.") + } + + // TODO: Support SparkR with mesos cluster + if (args.isR && clusterManager == MESOS) { + printErrorAndExit("SparkR is not supported for Mesos cluster.") } // If we're running a R app, set the main class to our specific R runner @@ -416,7 +449,8 @@ object SparkSubmit { // Yarn client only OptionAssigner(args.queue, YARN, CLIENT, sysProp = "spark.yarn.queue"), - OptionAssigner(args.numExecutors, YARN, CLIENT, sysProp = "spark.executor.instances"), + OptionAssigner(args.numExecutors, YARN, ALL_DEPLOY_MODES, + sysProp = "spark.executor.instances"), OptionAssigner(args.files, YARN, CLIENT, sysProp = "spark.yarn.dist.files"), OptionAssigner(args.archives, YARN, CLIENT, sysProp = "spark.yarn.dist.archives"), OptionAssigner(args.principal, YARN, CLIENT, sysProp = "spark.yarn.principal"), @@ -427,7 +461,6 @@ object SparkSubmit { OptionAssigner(args.driverMemory, YARN, CLUSTER, clOption = "--driver-memory"), OptionAssigner(args.driverCores, YARN, CLUSTER, clOption = "--driver-cores"), OptionAssigner(args.queue, YARN, CLUSTER, clOption = "--queue"), - OptionAssigner(args.numExecutors, YARN, CLUSTER, clOption = "--num-executors"), OptionAssigner(args.executorMemory, YARN, CLUSTER, clOption = "--executor-memory"), OptionAssigner(args.executorCores, YARN, CLUSTER, clOption = "--executor-cores"), OptionAssigner(args.files, YARN, CLUSTER, clOption = "--files"), @@ -512,9 +545,24 @@ object SparkSubmit { if (args.isPython) { sysProps.put("spark.yarn.isPython", "true") } + } + + // assure a keytab is available from any place in a JVM + if (clusterManager == YARN || clusterManager == LOCAL) { if (args.principal != null) { - require(args.keytab != null, "Keytab must be specified when the keytab is specified") - UserGroupInformation.loginUserFromKeytab(args.principal, args.keytab) + require(args.keytab != null, "Keytab must be specified when principal is specified") + if (!new File(args.keytab).exists()) { + throw new SparkException(s"Keytab file: ${args.keytab} does not exist") + } else { + // Add keytab and principal configurations in sysProps to make them available + // for later use; e.g. 
in spark sql, the isolated class loader used to talk + // to HiveMetastore will use these settings. They will be set as Java system + // properties and then loaded by SparkConf + sysProps.put("spark.yarn.keytab", args.keytab) + sysProps.put("spark.yarn.principal", args.principal) + + UserGroupInformation.loginUserFromKeytab(args.principal, args.keytab) + } } } @@ -545,7 +593,15 @@ object SparkSubmit { if (isMesosCluster) { assert(args.useRest, "Mesos cluster mode is only supported through the REST submission API") childMainClass = "org.apache.spark.deploy.rest.RestSubmissionClient" - childArgs += (args.primaryResource, args.mainClass) + if (args.isPython) { + // Second argument is main class + childArgs += (args.primaryResource, "") + if (args.pyFiles != null) { + sysProps("spark.submit.pyFiles") = args.pyFiles + } + } else { + childArgs += (args.primaryResource, args.mainClass) + } if (args.childArgs != null) { childArgs ++= args.childArgs } @@ -641,6 +697,15 @@ object SparkSubmit { // scalastyle:on println } System.exit(CLASS_NOT_FOUND_EXIT_STATUS) + case e: NoClassDefFoundError => + e.printStackTrace(printStream) + if (e.getMessage.contains("org/apache/hadoop/hive")) { + // scalastyle:off println + printStream.println(s"Failed to load hive class.") + printStream.println("You need to build Spark with -Phive and -Phive-thriftserver.") + // scalastyle:on println + } + System.exit(CLASS_NOT_FOUND_EXIT_STATUS) } // SPARK-4170 @@ -666,7 +731,13 @@ object SparkSubmit { mainMethod.invoke(null, childArgs.toArray) } catch { case t: Throwable => - throw findCause(t) + findCause(t) match { + case SparkUserAppException(exitCode) => + System.exit(exitCode) + + case t: Throwable => + throw t + } } } @@ -736,7 +807,7 @@ object SparkSubmit { * no files, into a single comma-separated string. 
*/ private def mergeFileLists(lists: String*): String = { - val merged = lists.filter(_ != null) + val merged = lists.filterNot(StringUtils.isBlank) .flatMap(_.split(",")) .mkString(",") if (merged == "") null else merged @@ -938,7 +1009,7 @@ private[spark] object SparkSubmitUtils { // are supplied to spark-submit val alternateIvyCache = ivyPath.getOrElse("") val packagesDirectory: File = - if (alternateIvyCache.trim.isEmpty) { + if (alternateIvyCache == null || alternateIvyCache.trim.isEmpty) { new File(ivySettings.getDefaultIvyUserDir, "jars") } else { ivySettings.setDefaultIvyUserDir(new File(alternateIvyCache)) @@ -988,11 +1059,9 @@ private[spark] object SparkSubmitUtils { addExclusionRules(ivySettings, ivyConfName, md) // add all supplied maven artifacts as dependencies addDependenciesToIvy(md, artifacts, ivyConfName) - exclusions.foreach { e => md.addExcludeRule(createExclusion(e + ":*", ivySettings, ivyConfName)) } - // resolve dependencies val rr: ResolveReport = ivy.resolve(md, resolveOptions) if (rr.hasError) { @@ -1010,7 +1079,7 @@ private[spark] object SparkSubmitUtils { } } - private def createExclusion( + private[deploy] def createExclusion( coords: String, ivySettings: IvySettings, ivyConfName: String): ExcludeRule = { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index b3710073e330..915ef81b4eae 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -23,7 +23,7 @@ import java.net.URI import java.util.{List => JList} import java.util.jar.JarFile -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.io.Source @@ -59,6 +59,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S var packages: String = null var repositories: String = null var ivyRepoPath: String = null + var packagesExclusions: String = null var verbose: Boolean = false var isPython: Boolean = false var pyFiles: String = null @@ -93,7 +94,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S // Set parameters from command line arguments try { - parse(args.toList) + parse(args.asJava) } catch { case e: IllegalArgumentException => SparkSubmit.printErrorAndExit(e.getMessage()) @@ -172,7 +173,13 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S name = Option(name).orElse(sparkProperties.get("spark.app.name")).orNull jars = Option(jars).orElse(sparkProperties.get("spark.jars")).orNull ivyRepoPath = sparkProperties.get("spark.jars.ivy").orNull - deployMode = Option(deployMode).orElse(env.get("DEPLOY_MODE")).orNull + packages = Option(packages).orElse(sparkProperties.get("spark.jars.packages")).orNull + packagesExclusions = Option(packagesExclusions) + .orElse(sparkProperties.get("spark.jars.excludes")).orNull + deployMode = Option(deployMode) + .orElse(sparkProperties.get("spark.submit.deployMode")) + .orElse(env.get("DEPLOY_MODE")) + .orNull numExecutors = Option(numExecutors) .getOrElse(sparkProperties.get("spark.executor.instances").orNull) keytab = Option(keytab).orElse(sparkProperties.get("spark.yarn.keytab")).orNull @@ -299,6 +306,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | childArgs [${childArgs.mkString(" ")}] | jars $jars | packages $packages + | 
packagesExclusions $packagesExclusions | repositories $repositories | verbose $verbose | @@ -391,6 +399,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S case PACKAGES => packages = value + case PACKAGES_EXCLUDE => + packagesExclusions = value + case REPOSITORIES => repositories = value @@ -450,7 +461,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S } override protected def handleExtraArgs(extra: JList[String]): Unit = { - childArgs ++= extra + childArgs ++= extra.asScala } private def printUsageAndExit(exitCode: Int, unknownParam: Any = null): Unit = { @@ -482,6 +493,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | maven repo, then maven central and any additional remote | repositories given by --repositories. The format for the | coordinates should be groupId:artifactId:version. + | --exclude-packages Comma-separated list of groupId:artifactId, to exclude while + | resolving the dependencies provided in --packages to avoid + | dependency conflicts. | --repositories Comma-separated list of additional remote repositories to | search for the maven coordinates given with --packages. | --py-files PY_FILES Comma-separated list of .zip, .egg, or .py files to place @@ -600,5 +614,4 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S System.setErr(currentErr) } } - } diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 7576a2985ee7..1e2f469214b8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.client import java.util.concurrent._ +import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference} import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} import scala.util.control.NonFatal @@ -49,9 +50,9 @@ private[spark] class AppClient( private val REGISTRATION_TIMEOUT_SECONDS = 20 private val REGISTRATION_RETRIES = 3 - private var endpoint: RpcEndpointRef = null - private var appId: String = null - @volatile private var registered = false + private val endpoint = new AtomicReference[RpcEndpointRef] + private val appId = new AtomicReference[String] + private val registered = new AtomicBoolean(false) private class ClientEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { @@ -59,24 +60,28 @@ private[spark] class AppClient( private var master: Option[RpcEndpointRef] = None // To avoid calling listener.disconnected() multiple times private var alreadyDisconnected = false - @volatile private var alreadyDead = false // To avoid calling listener.dead() multiple times - @volatile private var registerMasterFutures: Array[JFuture[_]] = null - @volatile private var registrationRetryTimer: JScheduledFuture[_] = null + // To avoid calling listener.dead() multiple times + private val alreadyDead = new AtomicBoolean(false) + private val registerMasterFutures = new AtomicReference[Array[JFuture[_]]] + private val registrationRetryTimer = new AtomicReference[JScheduledFuture[_]] // A thread pool for registering with masters. Because registering with a master is a blocking // action, this thread pool must be able to create "masterRpcAddresses.size" threads at the same // time so that we can register with all masters. 
- private val registerMasterThreadPool = new ThreadPoolExecutor( - 0, - masterRpcAddresses.size, // Make sure we can register with all masters at the same time - 60L, TimeUnit.SECONDS, - new SynchronousQueue[Runnable](), - ThreadUtils.namedThreadFactory("appclient-register-master-threadpool")) + private val registerMasterThreadPool = ThreadUtils.newDaemonCachedThreadPool( + "appclient-register-master-threadpool", + masterRpcAddresses.length // Make sure we can register with all masters at the same time + ) // A scheduled executor for scheduling the registration actions private val registrationRetryThread = ThreadUtils.newDaemonSingleThreadScheduledExecutor("appclient-registration-retry-thread") + // A thread pool to perform receive then reply actions in a thread so as not to block the + // event loop. + private val askAndReplyThreadPool = + ThreadUtils.newDaemonCachedThreadPool("appclient-receive-and-reply-threadpool") + override def onStart(): Unit = { try { registerWithMaster(1) @@ -95,7 +100,7 @@ private[spark] class AppClient( for (masterAddress <- masterRpcAddresses) yield { registerMasterThreadPool.submit(new Runnable { override def run(): Unit = try { - if (registered) { + if (registered.get) { return } logInfo("Connecting to master " + masterAddress.toSparkURL + "...") @@ -118,22 +123,22 @@ private[spark] class AppClient( * nthRetry means this is the nth attempt to register with master. */ private def registerWithMaster(nthRetry: Int) { - registerMasterFutures = tryRegisterAllMasters() - registrationRetryTimer = registrationRetryThread.scheduleAtFixedRate(new Runnable { + registerMasterFutures.set(tryRegisterAllMasters()) + registrationRetryTimer.set(registrationRetryThread.scheduleAtFixedRate(new Runnable { override def run(): Unit = { Utils.tryOrExit { - if (registered) { - registerMasterFutures.foreach(_.cancel(true)) + if (registered.get) { + registerMasterFutures.get.foreach(_.cancel(true)) registerMasterThreadPool.shutdownNow() } else if (nthRetry >= REGISTRATION_RETRIES) { markDead("All masters are unresponsive! Giving up.") } else { - registerMasterFutures.foreach(_.cancel(true)) + registerMasterFutures.get.foreach(_.cancel(true)) registerWithMaster(nthRetry + 1) } } } - }, REGISTRATION_TIMEOUT_SECONDS, REGISTRATION_TIMEOUT_SECONDS, TimeUnit.SECONDS) + }, REGISTRATION_TIMEOUT_SECONDS, REGISTRATION_TIMEOUT_SECONDS, TimeUnit.SECONDS)) } /** @@ -158,10 +163,10 @@ private[spark] class AppClient( // RegisteredApplications due to an unstable network. // 2. Receive multiple RegisteredApplication from different masters because the master is // changing. - appId = appId_ - registered = true + appId.set(appId_) + registered.set(true) master = Some(masterRef) - listener.connected(appId) + listener.connected(appId.get) case ApplicationRemoved(message) => markDead("Master removed our application: %s".format(message)) @@ -171,9 +176,6 @@ private[spark] class AppClient( val fullId = appId + "/" + id logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, hostPort, cores)) - // FIXME if changing master and `ExecutorAdded` happen at the same time (the order is not - // guaranteed), `ExecutorStateChanged` may be sent to a dead master. 
- sendToMaster(ExecutorStateChanged(appId, id, ExecutorState.RUNNING, None, None)) listener.executorAdded(fullId, workerId, hostPort, cores, memory) case ExecutorUpdated(id, state, message, exitStatus) => @@ -188,19 +190,19 @@ private[spark] class AppClient( logInfo("Master has changed, new master is at " + masterRef.address.toSparkURL) master = Some(masterRef) alreadyDisconnected = false - masterRef.send(MasterChangeAcknowledged(appId)) + masterRef.send(MasterChangeAcknowledged(appId.get)) } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case StopAppClient => markDead("Application has been stopped.") - sendToMaster(UnregisterApplication(appId)) + sendToMaster(UnregisterApplication(appId.get)) context.reply(true) stop() case r: RequestExecutors => master match { - case Some(m) => context.reply(m.askWithRetry[Boolean](r)) + case Some(m) => askAndReplyAsync(m, context, r) case None => logWarning("Attempted to request executors before registering with Master.") context.reply(false) @@ -208,13 +210,32 @@ private[spark] class AppClient( case k: KillExecutors => master match { - case Some(m) => context.reply(m.askWithRetry[Boolean](k)) + case Some(m) => askAndReplyAsync(m, context, k) case None => logWarning("Attempted to kill executors before registering with Master.") context.reply(false) } } + private def askAndReplyAsync[T]( + endpointRef: RpcEndpointRef, + context: RpcCallContext, + msg: T): Unit = { + // Create a thread to ask a message and reply with the result. Allow thread to be + // interrupted during shutdown, otherwise context must be notified of NonFatal errors. + askAndReplyThreadPool.execute(new Runnable { + override def run(): Unit = { + try { + context.reply(endpointRef.askWithRetry[Boolean](msg)) + } catch { + case ie: InterruptedException => // Cancelled + case NonFatal(t) => + context.sendFailure(t) + } + } + }) + } + override def onDisconnected(address: RpcAddress): Unit = { if (master.exists(_.address == address)) { logWarning(s"Connection to $address failed; waiting for master to reconnect...") @@ -239,38 +260,39 @@ private[spark] class AppClient( } def markDead(reason: String) { - if (!alreadyDead) { + if (!alreadyDead.get) { listener.dead(reason) - alreadyDead = true + alreadyDead.set(true) } } override def onStop(): Unit = { - if (registrationRetryTimer != null) { - registrationRetryTimer.cancel(true) + if (registrationRetryTimer.get != null) { + registrationRetryTimer.get.cancel(true) } registrationRetryThread.shutdownNow() - registerMasterFutures.foreach(_.cancel(true)) + registerMasterFutures.get.foreach(_.cancel(true)) registerMasterThreadPool.shutdownNow() + askAndReplyThreadPool.shutdownNow() } } def start() { - // Just launch an actor; it will call back into the listener. - endpoint = rpcEnv.setupEndpoint("AppClient", new ClientEndpoint(rpcEnv)) + // Just launch an rpcEndpoint; it will call back into the listener. + endpoint.set(rpcEnv.setupEndpoint("AppClient", new ClientEndpoint(rpcEnv))) } def stop() { - if (endpoint != null) { + if (endpoint.get != null) { try { val timeout = RpcUtils.askRpcTimeout(conf) - timeout.awaitResult(endpoint.ask[Boolean](StopAppClient)) + timeout.awaitResult(endpoint.get.ask[Boolean](StopAppClient)) } catch { case e: TimeoutException => logInfo("Stop request to Master timed out; it may already be shut down.") } - endpoint = null + endpoint.set(null) } } @@ -281,8 +303,8 @@ private[spark] class AppClient( * @return whether the request is acknowledged. 
*/ def requestTotalExecutors(requestedTotal: Int): Boolean = { - if (endpoint != null && appId != null) { - endpoint.askWithRetry[Boolean](RequestExecutors(appId, requestedTotal)) + if (endpoint.get != null && appId.get != null) { + endpoint.get.askWithRetry[Boolean](RequestExecutors(appId.get, requestedTotal)) } else { logWarning("Attempted to request executors before driver fully initialized.") false @@ -294,8 +316,8 @@ private[spark] class AppClient( * @return whether the kill request is acknowledged. */ def killExecutors(executorIds: Seq[String]): Boolean = { - if (endpoint != null && appId != null) { - endpoint.askWithRetry[Boolean](KillExecutors(appId, executorIds)) + if (endpoint.get != null && appId.get != null) { + endpoint.get.askWithRetry[Boolean](KillExecutors(appId.get, executorIds)) } else { logWarning("Attempted to kill executors before driver fully initialized.") false diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala index 1c79089303e3..adb3f0225802 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala @@ -48,8 +48,9 @@ private[spark] object TestClient { val url = args(0) val conf = new SparkConf val rpcEnv = RpcEnv.create("spark", Utils.localHostName(), 0, conf, new SecurityManager(conf)) + val executorClassname = TestExecutor.getClass.getCanonicalName.stripSuffix("$") val desc = new ApplicationDescription("TestClient", Some(1), 512, - Command("spark.deploy.client.TestExecutor", Seq(), Map(), Seq(), Seq(), Seq()), "ignored") + Command(executorClassname, Seq(), Map(), Seq(), Seq(), Seq()), "ignored") val listener = new TestListener val client = new AppClient(rpcEnv, Array(url), desc, listener, new SparkConf) client.start() diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index e3060ac3fa1a..718efc4f3bd5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.history import java.io.{BufferedInputStream, FileNotFoundException, InputStream, IOException, OutputStream} +import java.util.UUID import java.util.concurrent.{ExecutorService, Executors, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} @@ -26,7 +27,8 @@ import scala.collection.mutable import com.google.common.io.ByteStreams import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} -import org.apache.hadoop.fs.permission.AccessControlException +import org.apache.hadoop.hdfs.DistributedFileSystem +import org.apache.hadoop.security.AccessControlException import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.SparkHadoopUtil @@ -51,6 +53,10 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) private val NOT_STARTED = "" + // Interval between safemode checks. 
+ private val SAFEMODE_CHECK_INTERVAL_S = conf.getTimeAsSeconds( + "spark.history.fs.safemodeCheck.interval", "5s") + // Interval between each check for event log updates private val UPDATE_INTERVAL_S = conf.getTimeAsSeconds("spark.history.fs.update.interval", "10s") @@ -73,7 +79,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // The modification time of the newest log detected during the last scan. This is used // to ignore logs that are older during subsequent scans, to avoid processing data that // is already known. - private var lastModifiedTime = -1L + private var lastScanTime = -1L // Mapping of application IDs to their metadata, in descending end time order. Apps are inserted // into the map in order, so the LinkedHashMap maintains the correct ordering. @@ -106,9 +112,52 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } - initialize() + // Conf option used for testing the initialization code. + val initThread = initialize() + + private[history] def initialize(): Thread = { + if (!isFsInSafeMode()) { + startPolling() + null + } else { + startSafeModeCheckThread(None) + } + } - private def initialize(): Unit = { + private[history] def startSafeModeCheckThread( + errorHandler: Option[Thread.UncaughtExceptionHandler]): Thread = { + // Cannot probe anything while the FS is in safe mode, so spawn a new thread that will wait + // for the FS to leave safe mode before enabling polling. This allows the main history server + // UI to be shown (so that the user can see the HDFS status). + val initThread = new Thread(new Runnable() { + override def run(): Unit = { + try { + while (isFsInSafeMode()) { + logInfo("HDFS is still in safe mode. Waiting...") + val deadline = clock.getTimeMillis() + + TimeUnit.SECONDS.toMillis(SAFEMODE_CHECK_INTERVAL_S) + clock.waitTillTime(deadline) + } + startPolling() + } catch { + case _: InterruptedException => + } + } + }) + initThread.setDaemon(true) + initThread.setName(s"${getClass().getSimpleName()}-init") + initThread.setUncaughtExceptionHandler(errorHandler.getOrElse( + new Thread.UncaughtExceptionHandler() { + override def uncaughtException(t: Thread, e: Throwable): Unit = { + logError("Error initializing FsHistoryProvider.", e) + System.exit(1) + } + })) + initThread.start() + initThread + } + + private def startPolling(): Unit = { // Validate the log directory. val path = new Path(logDir) if (!fs.exists(path)) { @@ -126,11 +175,11 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // Disable the background thread during tests. if (!conf.contains("spark.testing")) { // A task that periodically checks for event log updates on disk. - pool.scheduleAtFixedRate(getRunner(checkForLogs), 0, UPDATE_INTERVAL_S, TimeUnit.SECONDS) + pool.scheduleWithFixedDelay(getRunner(checkForLogs), 0, UPDATE_INTERVAL_S, TimeUnit.SECONDS) if (conf.getBoolean("spark.history.fs.cleaner.enabled", false)) { // A task that periodically cleans event logs on disk. 
- pool.scheduleAtFixedRate(getRunner(cleanLogs), 0, CLEAN_INTERVAL_S, TimeUnit.SECONDS) + pool.scheduleWithFixedDelay(getRunner(cleanLogs), 0, CLEAN_INTERVAL_S, TimeUnit.SECONDS) } } } @@ -145,16 +194,15 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val ui = { val conf = this.conf.clone() val appSecManager = new SecurityManager(conf) - SparkUI.createHistoryUI(conf, replayBus, appSecManager, appId, + SparkUI.createHistoryUI(conf, replayBus, appSecManager, appInfo.name, HistoryServer.getAttemptURI(appId, attempt.attemptId), attempt.startTime) // Do not call ui.bind() to avoid creating a new server for each application } val appListener = new ApplicationEventListener() replayBus.addListener(appListener) - val appInfo = replay(fs.getFileStatus(new Path(logDir, attempt.logPath)), replayBus) - appInfo.map { info => - ui.setAppName(s"${info.name} ($appId)") - + val appAttemptInfo = replay(fs.getFileStatus(new Path(logDir, attempt.logPath)), + replayBus) + appAttemptInfo.map { info => val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false) ui.getSecurityManager.setAcls(uiAclsEnabled) // make sure to set admin acls before view acls so they are properly picked up @@ -170,7 +218,21 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } - override def getConfig(): Map[String, String] = Map("Event log directory" -> logDir.toString) + override def getConfig(): Map[String, String] = { + val safeMode = if (isFsInSafeMode()) { + Map("HDFS State" -> "In safe mode, application logs not available.") + } else { + Map() + } + Map("Event log directory" -> logDir.toString) ++ safeMode + } + + override def stop(): Unit = { + if (initThread != null && initThread.isAlive()) { + initThread.interrupt() + initThread.join() + } + } /** * Builds the application list based on the current contents of the log directory. @@ -179,15 +241,14 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) */ private[history] def checkForLogs(): Unit = { try { + val newLastScanTime = getNewLastScanTime() val statusList = Option(fs.listStatus(new Path(logDir))).map(_.toSeq) .getOrElse(Seq[FileStatus]()) - var newLastModifiedTime = lastModifiedTime val logInfos: Seq[FileStatus] = statusList .filter { entry => try { getModificationTime(entry).map { time => - newLastModifiedTime = math.max(newLastModifiedTime, time) - time >= lastModifiedTime + time >= lastScanTime }.getOrElse(false) } catch { case e: AccessControlException => @@ -204,18 +265,51 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) mod1 >= mod2 } - logInfos.sliding(20, 20).foreach { batch => - replayExecutor.submit(new Runnable { - override def run(): Unit = mergeApplicationListing(batch) - }) - } + logInfos.grouped(20) + .map { batch => + replayExecutor.submit(new Runnable { + override def run(): Unit = mergeApplicationListing(batch) + }) + } + .foreach { task => + try { + // Wait for all tasks to finish. This makes sure that checkForLogs + // is not scheduled again while some tasks are already running in + // the replayExecutor. + task.get() + } catch { + case e: InterruptedException => + throw e + case e: Exception => + logError("Exception while merging application listings", e) + } + } - lastModifiedTime = newLastModifiedTime + lastScanTime = newLastScanTime } catch { case e: Exception => logError("Exception in checking for event log updates", e) } } + private def getNewLastScanTime(): Long = { + val fileName = "." 
+ UUID.randomUUID().toString + val path = new Path(logDir, fileName) + val fos = fs.create(path) + + try { + fos.close() + fs.getFileStatus(path).getModificationTime + } catch { + case e: Exception => + logError("Exception encountered when attempting to update last scan time", e) + lastScanTime + } finally { + if (!fs.delete(path)) { + logWarning(s"Error deleting ${path}") + } + } + } + override def writeEventLogs( appId: String, attemptId: Option[String], @@ -272,9 +366,9 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) * Replay the log files in the list and merge the list of old applications with new ones */ private def mergeApplicationListing(logs: Seq[FileStatus]): Unit = { - val bus = new ReplayListenerBus() val newAttempts = logs.flatMap { fileStatus => try { + val bus = new ReplayListenerBus() val res = replay(fileStatus, bus) res match { case Some(r) => logDebug(s"Application log ${r.logPath} loaded successfully.") @@ -375,7 +469,9 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) try { val path = new Path(logDir, attempt.logPath) if (fs.exists(path)) { - fs.delete(path, true) + if (!fs.delete(path, true)) { + logWarning(s"Error deleting ${path}") + } } } catch { case e: AccessControlException => @@ -551,6 +647,37 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } + /** + * Checks whether HDFS is in safe mode. The API is slightly different between hadoop 1 and 2, + * so we have to resort to ugly reflection (as usual...). + * + * Note that DistributedFileSystem is a `@LimitedPrivate` class, which for all practical reasons + * makes it more public than not. + */ + private[history] def isFsInSafeMode(): Boolean = fs match { + case dfs: DistributedFileSystem => + isFsInSafeMode(dfs) + case _ => + false + } + + // For testing. + private[history] def isFsInSafeMode(dfs: DistributedFileSystem): Boolean = { + val hadoop1Class = "org.apache.hadoop.hdfs.protocol.FSConstants$SafeModeAction" + val hadoop2Class = "org.apache.hadoop.hdfs.protocol.HdfsConstants$SafeModeAction" + val actionClass: Class[_] = + try { + getClass().getClassLoader().loadClass(hadoop2Class) + } catch { + case _: ClassNotFoundException => + getClass().getClassLoader().loadClass(hadoop1Class) + } + + val action = actionClass.getField("SAFEMODE_GET").get(null) + val method = dfs.getClass().getMethod("setSafeMode", action.getClass()) + method.invoke(dfs, action).asInstanceOf[Boolean] + } + } private[history] object FsHistoryProvider { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 0830cc1ba124..642d71b18c9e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -51,7 +51,10 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") val hasMultipleAttempts = appsToShow.exists(_.attempts.size > 1) val appTable = if (hasMultipleAttempts) { - UIUtils.listingTable(appWithAttemptHeader, appWithAttemptRow, appsToShow) + // Sorting is disable here as table sort on rowspan has issues. + // ref. 
SPARK-10172 + UIUtils.listingTable(appWithAttemptHeader, appWithAttemptRow, + appsToShow, sortable = false) } else { UIUtils.listingTable(appHeader, appRow, appsToShow) } @@ -158,7 +161,7 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") info: ApplicationHistoryInfo, attempt: ApplicationAttemptInfo, isFirst: Boolean): Seq[Node] = { - val uiAddress = HistoryServer.getAttemptURI(info.id, attempt.attemptId) + val uiAddress = UIUtils.prependBaseUri(HistoryServer.getAttemptURI(info.id, attempt.attemptId)) val startTime = UIUtils.formatDate(attempt.startTime) val endTime = if (attempt.endTime > 0) UIUtils.formatDate(attempt.endTime) else "-" val duration = @@ -187,8 +190,7 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { if (renderAttemptIdColumn) { if (info.attempts.size > 1 && attempt.attemptId.isDefined) { - - {attempt.attemptId.get} + {attempt.attemptId.get} } else {   } @@ -215,9 +217,9 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") } private def makePageLink(linkPage: Int, showIncomplete: Boolean): String = { - "/?" + Array( + UIUtils.prependBaseUri("/?" + Array( "page=" + linkPage, "showIncomplete=" + showIncomplete - ).mkString("&") + ).mkString("&")) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index a076a9c3f984..f31fef0eccc3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -30,7 +30,7 @@ import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationInfo, Applica UIRoot} import org.apache.spark.ui.{SparkUI, UIUtils, WebUI} import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.util.{SignalLogger, Utils} +import org.apache.spark.util.{ShutdownHookManager, SignalLogger, Utils} /** * A web server that renders SparkUIs of completed applications. @@ -103,7 +103,9 @@ class HistoryServer( // Note we don't use the UI retrieved from the cache; the cache loader above will register // the app's UI, and all we need to do is redirect the user to the same URI that was // requested, and the proper data should be served at that point. - res.sendRedirect(res.encodeRedirectURL(req.getRequestURI())) + // Also, make sure that the redirect url contains the query string present in the request. + val requestURI = req.getRequestURI + Option(req.getQueryString).map("?" + _).getOrElse("") + res.sendRedirect(res.encodeRedirectURL(requestURI)) } // SPARK-5983 ensure TRACE is not supported @@ -238,7 +240,7 @@ object HistoryServer extends Logging { val server = new HistoryServer(conf, provider, securityManager, port) server.bind() - Utils.addShutdownHook { () => server.stop() } + ShutdownHookManager.addShutdownHook { () => server.stop() } // Wait until the end of the world... 
or if the HistoryServer process is manually stopped while(true) { Thread.sleep(Int.MaxValue) } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala index 18265df9faa2..d03bab3820bb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala @@ -30,28 +30,35 @@ private[history] class HistoryServerArguments(conf: SparkConf, args: Array[Strin parse(args.toList) private def parse(args: List[String]): Unit = { - args match { - case ("--dir" | "-d") :: value :: tail => - logWarning("Setting log directory through the command line is deprecated as of " + - "Spark 1.1.0. Please set this through spark.history.fs.logDirectory instead.") - conf.set("spark.history.fs.logDirectory", value) - System.setProperty("spark.history.fs.logDirectory", value) - parse(tail) + if (args.length == 1) { + setLogDirectory(args.head) + } else { + args match { + case ("--dir" | "-d") :: value :: tail => + setLogDirectory(value) + parse(tail) - case ("--help" | "-h") :: tail => - printUsageAndExit(0) + case ("--help" | "-h") :: tail => + printUsageAndExit(0) - case ("--properties-file") :: value :: tail => - propertiesFile = value - parse(tail) + case ("--properties-file") :: value :: tail => + propertiesFile = value + parse(tail) - case Nil => + case Nil => - case _ => - printUsageAndExit(1) + case _ => + printUsageAndExit(1) + } } } + private def setLogDirectory(value: String): Unit = { + logWarning("Setting log directory through the command line is deprecated as of " + + "Spark 1.1.0. Please set this through spark.history.fs.logDirectory instead.") + conf.set("spark.history.fs.logDirectory", value) + } + // This mutates the SparkConf, so all accesses to it must be made after this line Utils.loadDefaultSparkProperties(conf, propertiesFile) @@ -62,6 +69,8 @@ private[history] class HistoryServerArguments(conf: SparkConf, args: Array[Strin |Usage: HistoryServer [options] | |Options: + | DIR Deprecated; set spark.history.fs.logDirectory directly + | --dir DIR (-d DIR) Deprecated; set spark.history.fs.logDirectory directly | --properties-file FILE Path to a custom Spark properties file. | Default is conf/spark-defaults.conf. | @@ -90,3 +99,4 @@ private[history] class HistoryServerArguments(conf: SparkConf, args: Array[Strin } } + diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index b40d20f9f786..7e2cf956c725 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -41,6 +41,7 @@ private[spark] class ApplicationInfo( @transient var coresGranted: Int = _ @transient var endTime: Long = _ @transient var appSource: ApplicationSource = _ + @transient @volatile var appUIUrlAtHistoryServer: Option[String] = None // A cap on the number of executors this application can have at any given time. // By default, this is infinite. 
Only after the first allocation request is issued by the @@ -65,6 +66,7 @@ private[spark] class ApplicationInfo( nextExecutorId = 0 removedExecutors = new ArrayBuffer[ExecutorDesc] executorLimit = Integer.MAX_VALUE + appUIUrlAtHistoryServer = None } private def newExecutorId(useID: Option[Int] = None): Int = { @@ -135,4 +137,10 @@ private[spark] class ApplicationInfo( } } + /** + * Returns the original application UI url unless there is its address at history server + * is defined + */ + def curAppUIUrl: String = appUIUrlAtHistoryServer.getOrElse(desc.appUiUrl) + } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala index aa379d4cd61e..1aa8cd5013b4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala @@ -45,7 +45,10 @@ private[master] class FileSystemPersistenceEngine( } override def unpersist(name: String): Unit = { - new File(dir + File.separator + name).delete() + val f = new File(dir + File.separator + name) + if (!f.delete()) { + logWarning(s"Error deleting ${f.getPath()}") + } } override def read[T: ClassTag](prefix: String): Seq[T] = { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala index cf77c86d760c..70f21fbe0de8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala @@ -26,7 +26,7 @@ import org.apache.spark.annotation.DeveloperApi */ @DeveloperApi trait LeaderElectionAgent { - val masterActor: LeaderElectable + val masterInstance: LeaderElectable def stop() {} // to avoid noops in implementations. } @@ -37,7 +37,7 @@ trait LeaderElectable { } /** Single-node implementation of LeaderElectionAgent -- we're initially and always the leader. 
*/ -private[spark] class MonarchyLeaderAgent(val masterActor: LeaderElectable) +private[spark] class MonarchyLeaderAgent(val masterInstance: LeaderElectable) extends LeaderElectionAgent { - masterActor.electedLeader() + masterInstance.electedLeader() } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index e38e437fe1c5..fc42bf06e40a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -21,9 +21,11 @@ import java.io.FileNotFoundException import java.net.URLEncoder import java.text.SimpleDateFormat import java.util.Date -import java.util.concurrent.{ScheduledFuture, TimeUnit} +import java.util.concurrent.{ConcurrentHashMap, ScheduledFuture, TimeUnit} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.concurrent.duration.Duration +import scala.concurrent.{Await, ExecutionContext, Future} import scala.language.postfixOps import scala.util.Random @@ -56,6 +58,10 @@ private[deploy] class Master( private val forwardMessageThread = ThreadUtils.newDaemonSingleThreadScheduledExecutor("master-forward-message-thread") + private val rebuildUIThread = + ThreadUtils.newDaemonSingleThreadExecutor("master-rebuild-ui-thread") + private val rebuildUIContext = ExecutionContext.fromExecutor(rebuildUIThread) + private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs @@ -78,7 +84,8 @@ private[deploy] class Master( private val addressToApp = new HashMap[RpcAddress, ApplicationInfo] private val completedApps = new ArrayBuffer[ApplicationInfo] private var nextAppNumber = 0 - private val appIdToUI = new HashMap[String, SparkUI] + // Using ConcurrentHashMap so that master-rebuild-ui-thread can add a UI after asyncRebuildUI + private val appIdToUI = new ConcurrentHashMap[String, SparkUI] private val drivers = new HashSet[DriverInfo] private val completedDrivers = new ArrayBuffer[DriverInfo] @@ -127,14 +134,8 @@ private[deploy] class Master( // Alternative application submission gateway that is stable across Spark versions private val restServerEnabled = conf.getBoolean("spark.master.rest.enabled", true) - private val restServer = - if (restServerEnabled) { - val port = conf.getInt("spark.master.rest.port", 6066) - Some(new StandaloneRestServer(address.host, port, conf, self, masterUrl)) - } else { - None - } - private val restServerBoundPort = restServer.map(_.start()) + private var restServer: Option[StandaloneRestServer] = None + private var restServerBoundPort: Option[Int] = None override def onStart(): Unit = { logInfo("Starting Spark master at " + masterUrl) @@ -148,6 +149,12 @@ private[deploy] class Master( } }, 0, WORKER_TIMEOUT_MS, TimeUnit.MILLISECONDS) + if (restServerEnabled) { + val port = conf.getInt("spark.master.rest.port", 6066) + restServer = Some(new StandaloneRestServer(address.host, port, conf, self, masterUrl)) + } + restServerBoundPort = restServer.map(_.start()) + masterMetricsSystem.registerSource(masterSource) masterMetricsSystem.start() applicationMetricsSystem.start() @@ -191,6 +198,7 @@ private[deploy] class Master( checkForWorkerTimeOutTask.cancel(true) } forwardMessageThread.shutdownNow() + rebuildUIThread.shutdownNow() webUi.stop() restServer.foreach(_.stop()) masterMetricsSystem.stop() @@ -233,31 +241,6 @@ private[deploy] class Master( System.exit(0) } - case 
RegisterWorker( - id, workerHost, workerPort, workerRef, cores, memory, workerUiPort, publicAddress) => { - logInfo("Registering worker %s:%d with %d cores, %s RAM".format( - workerHost, workerPort, cores, Utils.megabytesToString(memory))) - if (state == RecoveryState.STANDBY) { - // ignore, don't send response - } else if (idToWorker.contains(id)) { - workerRef.send(RegisterWorkerFailed("Duplicate worker ID")) - } else { - val worker = new WorkerInfo(id, workerHost, workerPort, cores, memory, - workerRef, workerUiPort, publicAddress) - if (registerWorker(worker)) { - persistenceEngine.addWorker(worker) - workerRef.send(RegisteredWorker(self, masterWebUiUrl)) - schedule() - } else { - val workerAddress = worker.endpoint.address - logWarning("Worker registration failed. Attempted to re-register worker at same " + - "address: " + workerAddress) - workerRef.send(RegisterWorkerFailed("Attempted to re-register worker at same address: " - + workerAddress)) - } - } - } - case RegisterApplication(description, driver) => { // TODO Prevent repeated registrations from some driver if (state == RecoveryState.STANDBY) { @@ -278,9 +261,17 @@ private[deploy] class Master( execOption match { case Some(exec) => { val appInfo = idToApp(appId) + val oldState = exec.state exec.state = state - if (state == ExecutorState.RUNNING) { appInfo.resetRetryCount() } + + if (state == ExecutorState.RUNNING) { + assert(oldState == ExecutorState.LAUNCHING, + s"executor $execId state transfer from $oldState to RUNNING is illegal") + appInfo.resetRetryCount() + } + exec.application.driver.send(ExecutorUpdated(execId, state, message, exitStatus)) + if (ExecutorState.isFinished(state)) { // Remove this executor from the worker and app logInfo(s"Removing executor ${exec.fullId} because it is $state") @@ -384,9 +375,38 @@ private[deploy] class Master( case CheckForWorkerTimeOut => { timeOutDeadWorkers() } + + case AttachCompletedRebuildUI(appId) => + // An asyncRebuildSparkUI has completed, so need to attach to master webUi + Option(appIdToUI.get(appId)).foreach { ui => webUi.attachSparkUI(ui) } } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RegisterWorker( + id, workerHost, workerPort, workerRef, cores, memory, workerUiPort, publicAddress) => { + logInfo("Registering worker %s:%d with %d cores, %s RAM".format( + workerHost, workerPort, cores, Utils.megabytesToString(memory))) + if (state == RecoveryState.STANDBY) { + context.reply(MasterInStandby) + } else if (idToWorker.contains(id)) { + context.reply(RegisterWorkerFailed("Duplicate worker ID")) + } else { + val worker = new WorkerInfo(id, workerHost, workerPort, cores, memory, + workerRef, workerUiPort, publicAddress) + if (registerWorker(worker)) { + persistenceEngine.addWorker(worker) + context.reply(RegisteredWorker(self, masterWebUiUrl)) + schedule() + } else { + val workerAddress = worker.endpoint.address + logWarning("Worker registration failed. Attempted to re-register worker at same " + + "address: " + workerAddress) + context.reply(RegisterWorkerFailed("Attempted to re-register worker at same address: " + + workerAddress)) + } + } + } + case RequestSubmitDriver(description) => { if (state != RecoveryState.ALIVE) { val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + @@ -581,20 +601,22 @@ private[deploy] class Master( /** Return whether the specified worker can launch an executor for this app. 
*/ def canLaunchExecutor(pos: Int): Boolean = { + val keepScheduling = coresToAssign >= minCoresPerExecutor + val enoughCores = usableWorkers(pos).coresFree - assignedCores(pos) >= minCoresPerExecutor + // If we allow multiple executors per worker, then we can always launch new executors. - // Otherwise, we may have already started assigning cores to the executor on this worker. + // Otherwise, if there is already an executor on this worker, just give it more cores. val launchingNewExecutor = !oneExecutorPerWorker || assignedExecutors(pos) == 0 - val underLimit = - if (launchingNewExecutor) { - assignedExecutors.sum + app.executors.size < app.executorLimit - } else { - true - } - val assignedMemory = assignedExecutors(pos) * memoryPerExecutor - usableWorkers(pos).memoryFree - assignedMemory >= memoryPerExecutor && - usableWorkers(pos).coresFree - assignedCores(pos) >= minCoresPerExecutor && - coresToAssign >= minCoresPerExecutor && - underLimit + if (launchingNewExecutor) { + val assignedMemory = assignedExecutors(pos) * memoryPerExecutor + val enoughMemory = usableWorkers(pos).memoryFree - assignedMemory >= memoryPerExecutor + val underLimit = assignedExecutors.sum + app.executors.size < app.executorLimit + keepScheduling && enoughCores && enoughMemory && underLimit + } else { + // We're adding cores to an existing executor, so no need + // to check memory and executor limits + keepScheduling && enoughCores + } } // Keep launching executors until no more workers can accommodate any @@ -700,8 +722,8 @@ private[deploy] class Master( worker.addExecutor(exec) worker.endpoint.send(LaunchExecutor(masterUrl, exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory)) - exec.application.driver.send(ExecutorAdded( - exec.id, worker.id, worker.hostPort, exec.cores, exec.memory)) + exec.application.driver.send( + ExecutorAdded(exec.id, worker.id, worker.hostPort, exec.cores, exec.memory)) } private def registerWorker(worker: WorkerInfo): Boolean = { @@ -766,7 +788,8 @@ private[deploy] class Master( ApplicationInfo = { val now = System.currentTimeMillis() val date = new Date(now) - new ApplicationInfo(now, newApplicationId(date), desc, date, driver, defaultCores) + val appId = newApplicationId(date) + new ApplicationInfo(now, appId, desc, date, driver, defaultCores) } private def registerApplication(app: ApplicationInfo): Unit = { @@ -798,7 +821,7 @@ private[deploy] class Master( if (completedApps.size >= RETAINED_APPLICATIONS) { val toRemove = math.max(RETAINED_APPLICATIONS / 10, 1) completedApps.take(toRemove).foreach( a => { - appIdToUI.remove(a.id).foreach { ui => webUi.detachSparkUI(ui) } + Option(appIdToUI.remove(a.id)).foreach { ui => webUi.detachSparkUI(ui) } applicationMetricsSystem.removeSource(a.appSource) }) completedApps.trimStart(toRemove) @@ -807,7 +830,7 @@ private[deploy] class Master( waitingApps -= app // If application events are logged, use them to rebuild the UI - rebuildSparkUI(app) + asyncRebuildSparkUI(app) for (exec <- app.executors.values) { killExecutor(exec) @@ -912,49 +935,57 @@ private[deploy] class Master( * Return the UI if successful, else None */ private[master] def rebuildSparkUI(app: ApplicationInfo): Option[SparkUI] = { + val futureUI = asyncRebuildSparkUI(app) + Await.result(futureUI, Duration.Inf) + } + + /** Rebuild a new SparkUI asynchronously to not block RPC event loop */ + private[master] def asyncRebuildSparkUI(app: ApplicationInfo): Future[Option[SparkUI]] = { val appName = app.desc.name val notFoundBasePath = 
HistoryServer.UI_PATH_PREFIX + "/not-found" - try { - val eventLogDir = app.desc.eventLogDir - .getOrElse { - // Event logging is not enabled for this application - app.desc.appUiUrl = notFoundBasePath - return None - } - + val eventLogDir = app.desc.eventLogDir + .getOrElse { + // Event logging is disabled for this application + app.appUIUrlAtHistoryServer = Some(notFoundBasePath) + return Future.successful(None) + } + val futureUI = Future { val eventLogFilePrefix = EventLoggingListener.getLogPath( - eventLogDir, app.id, app.desc.eventLogCodec) + eventLogDir, app.id, appAttemptId = None, compressionCodecName = app.desc.eventLogCodec) val fs = Utils.getHadoopFileSystem(eventLogDir, hadoopConf) val inProgressExists = fs.exists(new Path(eventLogFilePrefix + - EventLoggingListener.IN_PROGRESS)) + EventLoggingListener.IN_PROGRESS)) - if (inProgressExists) { + val eventLogFile = if (inProgressExists) { // Event logging is enabled for this application, but the application is still in progress logWarning(s"Application $appName is still in progress, it may be terminated abnormally.") - } - - val (eventLogFile, status) = if (inProgressExists) { - (eventLogFilePrefix + EventLoggingListener.IN_PROGRESS, " (in progress)") + eventLogFilePrefix + EventLoggingListener.IN_PROGRESS } else { - (eventLogFilePrefix, " (completed)") + eventLogFilePrefix } val logInput = EventLoggingListener.openEventLog(new Path(eventLogFile), fs) val replayBus = new ReplayListenerBus() val ui = SparkUI.createHistoryUI(new SparkConf, replayBus, new SecurityManager(conf), - appName + status, HistoryServer.UI_PATH_PREFIX + s"/${app.id}", app.startTime) - val maybeTruncated = eventLogFile.endsWith(EventLoggingListener.IN_PROGRESS) + appName, HistoryServer.UI_PATH_PREFIX + s"/${app.id}", app.startTime) try { - replayBus.replay(logInput, eventLogFile, maybeTruncated) + replayBus.replay(logInput, eventLogFile, inProgressExists) } finally { logInput.close() } - appIdToUI(app.id) = ui - webUi.attachSparkUI(ui) - // Application UI is successfully rebuilt, so link the Master UI to it - app.desc.appUiUrl = ui.basePath + Some(ui) - } catch { + }(rebuildUIContext) + + futureUI.onSuccess { case Some(ui) => + appIdToUI.put(app.id, ui) + self.send(AttachCompletedRebuildUI(app.id)) + // Application UI is successfully rebuilt, so link the Master UI to it + // NOTE - app.appUIUrlAtHistoryServer is volatile + app.appUIUrlAtHistoryServer = Some(ui.basePath) + }(ThreadUtils.sameThread) + + futureUI.onFailure { case fnf: FileNotFoundException => // Event logging is enabled for this application, but no event logs are found val title = s"Application history not found (${app.id})" @@ -962,8 +993,8 @@ private[deploy] class Master( logWarning(msg) msg += " Did you specify the correct logging directory?" msg = URLEncoder.encode(msg, "UTF-8") - app.desc.appUiUrl = notFoundBasePath + s"?msg=$msg&title=$title" - None + app.appUIUrlAtHistoryServer = Some(notFoundBasePath + s"?msg=$msg&title=$title") + case e: Exception => // Relay exception message to application UI page val title = s"Application history load error (${app.id})" @@ -971,9 +1002,11 @@ private[deploy] class Master( var msg = s"Exception in replaying log for application $appName!" 
logError(msg, e) msg = URLEncoder.encode(msg, "UTF-8") - app.desc.appUiUrl = notFoundBasePath + s"?msg=$msg&exception=$exception&title=$title" - None - } + app.appUIUrlAtHistoryServer = + Some(notFoundBasePath + s"?msg=$msg&exception=$exception&title=$title") + }(ThreadUtils.sameThread) + + futureUI } /** Generate a new app ID given a app's submission date */ diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala index 68c937188b33..a055d097674c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala @@ -38,5 +38,7 @@ private[master] object MasterMessages { case object BoundPortsRequest - case class BoundPortsResponse(actorPort: Int, webUIPort: Int, restPort: Option[Int]) + case class BoundPortsResponse(rpcEndpointPort: Int, webUIPort: Int, restPort: Option[Int]) + + case class AttachCompletedRebuildUI(appId: String) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala index 6fdff86f66e0..d317206a614f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala @@ -22,7 +22,7 @@ import org.apache.curator.framework.CuratorFramework import org.apache.curator.framework.recipes.leader.{LeaderLatchListener, LeaderLatch} import org.apache.spark.deploy.SparkCuratorUtil -private[master] class ZooKeeperLeaderElectionAgent(val masterActor: LeaderElectable, +private[master] class ZooKeeperLeaderElectionAgent(val masterInstance: LeaderElectable, conf: SparkConf) extends LeaderLatchListener with LeaderElectionAgent with Logging { val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/leader_election" @@ -73,10 +73,10 @@ private[master] class ZooKeeperLeaderElectionAgent(val masterActor: LeaderElecta private def updateLeadershipStatus(isLeader: Boolean) { if (isLeader && status == LeadershipStatus.NOT_LEADER) { status = LeadershipStatus.LEADER - masterActor.electedLeader() + masterInstance.electedLeader() } else if (!isLeader && status == LeadershipStatus.LEADER) { status = LeadershipStatus.NOT_LEADER - masterActor.revokedLeadership() + masterInstance.revokedLeadership() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index 563831cc6b8d..540e802420ce 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.master import java.nio.ByteBuffer -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.curator.framework.CuratorFramework @@ -49,8 +49,8 @@ private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializer } override def read[T: ClassTag](prefix: String): Seq[T] = { - val file = zk.getChildren.forPath(WORKING_DIR).filter(_.startsWith(prefix)) - file.map(deserializeFromFile[T]).flatten + zk.getChildren.forPath(WORKING_DIR).asScala + .filter(_.startsWith(prefix)).map(deserializeFromFile[T]).flatten } override def 
close() { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index e28e7e379ac9..f405aa2bdc8b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -76,7 +76,7 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app")
             <li><strong>Submit Date:</strong> {app.submitDate}</li>
             <li><strong>State:</strong> {app.state}</li>
             <li><strong>
-              <a href={app.desc.appUiUrl}>Application Detail UI</a>
+              <a href={app.curAppUIUrl}>Application Detail UI</a>
             </strong></li>
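Illustrative aside, not part of the patch: the link rewrite above relies on the curAppUIUrl fallback that ApplicationInfo.scala gains earlier in this change, where the master keeps pointing at the application's own UI until the asynchronous rebuild publishes a history-server address. A minimal Scala sketch of that behaviour follows; AppUiLink, AppUiLinkDemo, and the URLs are made-up names for illustration only.

// Sketch only: AppUiLink stands in for ApplicationInfo; the volatile Option mirrors
// appUIUrlAtHistoryServer, which the rebuild thread sets once the history UI is ready.
class AppUiLink(originalUiUrl: String) {
  @volatile var appUIUrlAtHistoryServer: Option[String] = None
  // Same fallback shape as ApplicationInfo.curAppUIUrl in this patch.
  def curAppUIUrl: String = appUIUrlAtHistoryServer.getOrElse(originalUiUrl)
}

object AppUiLinkDemo extends App {
  val app = new AppUiLink("http://driver-host:4040")        // hypothetical driver UI address
  println(app.curAppUIUrl)                                   // original UI while the app runs
  app.appUIUrlAtHistoryServer = Some("/history/app-20150101000000-0000")
  println(app.curAppUIUrl)                                   // history-server path after the rebuild
}

Because the field is volatile and is only rebound (never mutated in place), UI threads reading curAppUIUrl see either the original URL or the fully published history-server URL, never a partially built value.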
  • diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index c3e20ebf8d6e..ee539dd1f511 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -206,7 +206,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { {killLink} - {app.desc.name} + {app.desc.name} {app.coresGranted} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index 6174fc11f83d..e41554a5a6d2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -28,14 +28,17 @@ import org.apache.spark.ui.JettyUtils._ * Web UI server for the standalone master. */ private[master] -class MasterWebUI(val master: Master, requestedPort: Int) +class MasterWebUI( + val master: Master, + requestedPort: Int, + customMasterPage: Option[MasterPage] = None) extends WebUI(master.securityMgr, requestedPort, master.conf, name = "MasterUI") with Logging with UIRoot { val masterEndpointRef = master.self val killEnabled = master.conf.getBoolean("spark.ui.killEnabled", true) - val masterPage = new MasterPage(this) + val masterPage = customMasterPage.getOrElse(new MasterPage(this)) initialize() diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala index 5d4e5b899dfd..389eff5e0645 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala @@ -22,7 +22,7 @@ import java.util.concurrent.CountDownLatch import org.apache.spark.deploy.mesos.ui.MesosClusterUI import org.apache.spark.deploy.rest.mesos.MesosRestServer import org.apache.spark.scheduler.cluster.mesos._ -import org.apache.spark.util.SignalLogger +import org.apache.spark.util.{ShutdownHookManager, SignalLogger} import org.apache.spark.{Logging, SecurityManager, SparkConf} /* @@ -103,14 +103,11 @@ private[mesos] object MesosClusterDispatcher extends Logging { } val dispatcher = new MesosClusterDispatcher(dispatcherArgs, conf) dispatcher.start() - val shutdownHook = new Thread() { - override def run() { - logInfo("Shutdown hook is shutting down dispatcher") - dispatcher.stop() - dispatcher.awaitShutdown() - } + ShutdownHookManager.addShutdownHook { () => + logInfo("Shutdown hook is shutting down dispatcher") + dispatcher.stop() + dispatcher.awaitShutdown() } - Runtime.getRuntime.addShutdownHook(shutdownHook) dispatcher.awaitShutdown() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala index 061857476a8a..8ffcfc0878a4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.mesos import java.net.SocketAddress +import java.nio.ByteBuffer import scala.collection.mutable @@ -34,7 +35,7 @@ import org.apache.spark.network.util.TransportConf * It detects driver termination and calls the cleanup callback to 
[[ExternalShuffleService]]. */ private[mesos] class MesosExternalShuffleBlockHandler(transportConf: TransportConf) - extends ExternalShuffleBlockHandler(transportConf) with Logging { + extends ExternalShuffleBlockHandler(transportConf, null) with Logging { // Stores a map of driver socket addresses to app ids private val connectedApps = new mutable.HashMap[SocketAddress, String] @@ -56,7 +57,7 @@ private[mesos] class MesosExternalShuffleBlockHandler(transportConf: TransportCo } } connectedApps(address) = appId - callback.onSuccess(new Array[Byte](0)) + callback.onSuccess(ByteBuffer.allocate(0)) case _ => super.handleMessage(message, client, callback) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala index e8ef60bd5428..bc67fd460d9a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala @@ -46,7 +46,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver") val schedulerHeaders = Seq("Scheduler property", "Value") val commandEnvHeaders = Seq("Command environment variable", "Value") val launchedHeaders = Seq("Launched property", "Value") - val commandHeaders = Seq("Comamnd property", "Value") + val commandHeaders = Seq("Command property", "Value") val retryHeaders = Seq("Last failed status", "Next retry time", "Retry count") val driverDescription = Iterable.apply(driverState.description) val submissionState = Iterable.apply(driverState.submissionState) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala index 1fe956320a1b..f0dd667ea1b2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -19,16 +19,19 @@ package org.apache.spark.deploy.rest import java.io.{DataOutputStream, FileNotFoundException} import java.net.{ConnectException, HttpURLConnection, SocketException, URL} +import java.util.concurrent.TimeoutException import javax.servlet.http.HttpServletResponse import scala.collection.mutable +import scala.concurrent.duration._ +import scala.concurrent.{Await, Future} import scala.io.Source import com.fasterxml.jackson.core.JsonProcessingException import com.google.common.base.Charsets -import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion} import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SPARK_VERSION => sparkVersion, SparkConf} /** * A client that submits applications to a [[RestSubmissionServer]]. @@ -225,7 +228,8 @@ private[spark] class RestSubmissionClient(master: String) extends Logging { * Exposed for testing. 
*/ private[rest] def readResponse(connection: HttpURLConnection): SubmitRestProtocolResponse = { - try { + import scala.concurrent.ExecutionContext.Implicits.global + val responseFuture = Future { val dataStream = if (connection.getResponseCode == HttpServletResponse.SC_OK) { connection.getInputStream @@ -251,11 +255,15 @@ private[spark] class RestSubmissionClient(master: String) extends Logging { throw new SubmitRestProtocolException( s"Message received from server was not a response:\n${unexpected.toJson}") } - } catch { + } + + try { Await.result(responseFuture, 10.seconds) } catch { case unreachable @ (_: FileNotFoundException | _: SocketException) => throw new SubmitRestConnectionException("Unable to connect to server", unreachable) case malformed @ (_: JsonProcessingException | _: SubmitRestProtocolException) => throw new SubmitRestProtocolException("Malformed response received from server", malformed) + case timeout: TimeoutException => + throw new SubmitRestConnectionException("No response from server", timeout) } } @@ -392,15 +400,14 @@ private[spark] object RestSubmissionClient { mainClass: String, appArgs: Array[String], conf: SparkConf, - env: Map[String, String] = sys.env): SubmitRestProtocolResponse = { + env: Map[String, String] = Map()): SubmitRestProtocolResponse = { val master = conf.getOption("spark.master").getOrElse { throw new IllegalArgumentException("'spark.master' must be set.") } val sparkProperties = conf.getAll.toMap - val environmentVariables = env.filter { case (k, _) => k.startsWith("SPARK_") } val client = new RestSubmissionClient(master) val submitRequest = client.constructSubmitRequest( - appResource, mainClass, appArgs, sparkProperties, environmentVariables) + appResource, mainClass, appArgs, sparkProperties, env) client.createSubmission(submitRequest) } @@ -413,6 +420,16 @@ private[spark] object RestSubmissionClient { val mainClass = args(1) val appArgs = args.slice(2, args.size) val conf = new SparkConf - run(appResource, mainClass, appArgs, conf) + val env = filterSystemEnvironment(sys.env) + run(appResource, mainClass, appArgs, conf, env) + } + + /** + * Filter non-spark environment variables from any environment. + */ + private[rest] def filterSystemEnvironment(env: Map[String, String]): Map[String, String] = { + env.filter { case (k, _) => + (k.startsWith("SPARK_") && k != "SPARK_ENV_LOADED") || k.startsWith("MESOS_") + } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala index 868cc35d06ef..c0b93596508f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala @@ -94,7 +94,12 @@ private[mesos] class MesosSubmitRequestServlet( val driverMemory = sparkProperties.get("spark.driver.memory") val driverCores = sparkProperties.get("spark.driver.cores") val appArgs = request.appArgs - val environmentVariables = request.environmentVariables + // We don't want to pass down SPARK_HOME when launching Spark apps + // with Mesos cluster mode since it's populated by default on the client and it will + // cause spark-submit script to look for files in SPARK_HOME instead. + // We only need the ability to specify where to find spark-submit script + // which user can user spark.executor.home or spark.home configurations. 
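The readResponse change above moves the blocking read of the HTTP response into a Future and bounds it with Await.result(responseFuture, 10.seconds), so a hung connection surfaces as a connection error instead of blocking forever. A minimal standalone sketch of that pattern follows; blockingRead is a placeholder for draining the real HttpURLConnection stream, and the plain RuntimeException stands in for the Spark-specific exception type used in the hunk.

import java.util.concurrent.TimeoutException
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

object ReadWithTimeout {
  // Placeholder for the real work of draining an HTTP response stream.
  def blockingRead(): String = { Thread.sleep(100); "response body" }

  def readResponseBounded(): String = {
    val responseFuture = Future { blockingRead() }
    try {
      Await.result(responseFuture, 10.seconds) // give up after 10 seconds
    } catch {
      case timeout: TimeoutException =>
        throw new RuntimeException("No response from server", timeout)
    }
  }
}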
+ val environmentVariables = request.environmentVariables.filterKeys(!_.equals("SPARK_HOME")) val name = request.sparkProperties.get("spark.app.name").getOrElse(mainClass) // Construct driver description diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala index 45a3f4304543..ce02ee203a4b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala @@ -18,9 +18,8 @@ package org.apache.spark.deploy.worker import java.io.{File, FileOutputStream, InputStream, IOException} -import java.lang.System._ -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.Map import org.apache.spark.Logging @@ -62,7 +61,7 @@ object CommandUtils extends Logging { // SPARK-698: do not call the run.cmd script, as process.destroy() // fails to kill a process tree on Windows val cmd = new WorkerCommandBuilder(sparkHome, memory, command).buildCommand() - cmd.toSeq ++ Seq(command.mainClass) ++ command.arguments + cmd.asScala ++ Seq(command.mainClass) ++ command.arguments } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index ec51c3d935d8..89159ff5e2b3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.worker import java.io._ -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files @@ -172,8 +172,8 @@ private[deploy] class DriverRunner( CommandUtils.redirectStream(process.getInputStream, stdout) val stderr = new File(baseDir, "stderr") - val header = "Launch Command: %s\n%s\n\n".format( - builder.command.mkString("\"", "\" \"", "\""), "=" * 40) + val formattedCommand = builder.command.asScala.mkString("\"", "\" \"", "\"") + val header = "Launch Command: %s\n%s\n\n".format(formattedCommand, "=" * 40) Files.append(header, stderr, UTF_8) CommandUtils.redirectStream(process.getErrorStream, stderr) } @@ -229,6 +229,6 @@ private[deploy] trait ProcessBuilderLike { private[deploy] object ProcessBuilderLike { def apply(processBuilder: ProcessBuilder): ProcessBuilderLike = new ProcessBuilderLike { override def start(): Process = processBuilder.start() - override def command: Seq[String] = processBuilder.command() + override def command: Seq[String] = processBuilder.command().asScala } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 29a504228557..9a42487bb37a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.worker import java.io._ -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files @@ -28,7 +28,7 @@ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.{SecurityManager, SparkConf, Logging} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import 
org.apache.spark.deploy.DeployMessages.ExecutorStateChanged -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} import org.apache.spark.util.logging.FileAppender /** @@ -70,7 +70,13 @@ private[deploy] class ExecutorRunner( } workerThread.start() // Shutdown hook that kills actors on shutdown. - shutdownHook = Utils.addShutdownHook { () => killProcess(Some("Worker shutting down")) } + shutdownHook = ShutdownHookManager.addShutdownHook { () => + // It's possible that we arrive here before calling `fetchAndRunExecutor`, then `state` will + // be `ExecutorState.RUNNING`. In this case, we should set `state` to `FAILED`. + if (state == ExecutorState.RUNNING) { + state = ExecutorState.FAILED + } + killProcess(Some("Worker shutting down")) } } /** @@ -91,7 +97,11 @@ private[deploy] class ExecutorRunner( process.destroy() exitCode = Some(process.waitFor()) } - worker.send(ExecutorStateChanged(appId, execId, state, message, exitCode)) + try { + worker.send(ExecutorStateChanged(appId, execId, state, message, exitCode)) + } catch { + case e: IllegalStateException => logWarning(e.getMessage(), e) + } } /** Stop this executor runner, including killing the process it launched */ @@ -102,7 +112,7 @@ private[deploy] class ExecutorRunner( workerThread = null state = ExecutorState.KILLED try { - Utils.removeShutdownHook(shutdownHook) + ShutdownHookManager.removeShutdownHook(shutdownHook) } catch { case e: IllegalStateException => None } @@ -128,7 +138,8 @@ private[deploy] class ExecutorRunner( val builder = CommandUtils.buildProcessBuilder(appDesc.command, new SecurityManager(conf), memory, sparkHome.getAbsolutePath, substituteVariables) val command = builder.command() - logInfo("Launch command: " + command.mkString("\"", "\" \"", "\"")) + val formattedCommand = command.asScala.mkString("\"", "\" \"", "\"") + logInfo(s"Launch command: $formattedCommand") builder.directory(executorDir) builder.environment.put("SPARK_EXECUTOR_DIRS", appLocalDirs.mkString(File.pathSeparator)) @@ -144,7 +155,7 @@ private[deploy] class ExecutorRunner( process = builder.start() val header = "Spark Executor Command: %s\n%s\n\n".format( - command.mkString("\"", "\" \"", "\""), "=" * 40) + formattedCommand, "=" * 40) // Redirect its stdout and stderr to files val stdout = new File(executorDir, "stdout") diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index c82a7ccab54d..f41efb097b4b 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -24,10 +24,9 @@ import java.util.{UUID, Date} import java.util.concurrent._ import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} -import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, HashSet, LinkedHashMap} import scala.concurrent.ExecutionContext -import scala.util.Random +import scala.util.{Failure, Random, Success} import scala.util.control.NonFatal import org.apache.spark.{Logging, SecurityManager, SparkConf} @@ -147,12 +146,10 @@ private[deploy] class Worker( // A thread pool for registering with masters. Because registering with a master is a blocking // action, this thread pool must be able to create "masterRpcAddresses.size" threads at the same // time so that we can register with all masters. 
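The Worker hunk above replaces a hand-built ThreadPoolExecutor with ThreadUtils.newDaemonCachedThreadPool. As a rough sketch of what the removed construction did (and what the helper presumably encapsulates), a bounded cached pool of daemon threads can be assembled from java.util.concurrent alone; namedDaemonFactory below is a stand-in for Spark's namedThreadFactory, not its actual implementation.

import java.util.concurrent.{SynchronousQueue, ThreadFactory, ThreadPoolExecutor, TimeUnit}
import java.util.concurrent.atomic.AtomicInteger

object DaemonPools {
  // Stand-in for ThreadUtils.namedThreadFactory: daemon threads with a readable name.
  private def namedDaemonFactory(prefix: String): ThreadFactory = new ThreadFactory {
    private val counter = new AtomicInteger(0)
    override def newThread(r: Runnable): Thread = {
      val t = new Thread(r, s"$prefix-${counter.incrementAndGet()}")
      t.setDaemon(true)
      t
    }
  }

  // Mirrors the removed construction: no core threads, at most maxThreads workers,
  // idle threads reclaimed after 60 seconds, no task queueing (SynchronousQueue).
  def boundedCachedPool(prefix: String, maxThreads: Int): ThreadPoolExecutor =
    new ThreadPoolExecutor(0, maxThreads, 60L, TimeUnit.SECONDS,
      new SynchronousQueue[Runnable](), namedDaemonFactory(prefix))
}

With a SynchronousQueue there is no task buffering: once maxThreads workers are busy, further submissions are rejected, which matches the comment above since the pool only ever needs masterRpcAddresses.size concurrent registrations.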
- private val registerMasterThreadPool = new ThreadPoolExecutor( - 0, - masterRpcAddresses.size, // Make sure we can register with all masters at the same time - 60L, TimeUnit.SECONDS, - new SynchronousQueue[Runnable](), - ThreadUtils.namedThreadFactory("worker-register-master-threadpool")) + private val registerMasterThreadPool = ThreadUtils.newDaemonCachedThreadPool( + "worker-register-master-threadpool", + masterRpcAddresses.size // Make sure we can register with all masters at the same time + ) var coresUsed = 0 var memoryUsed = 0 @@ -214,8 +211,7 @@ private[deploy] class Worker( logInfo("Connecting to master " + masterAddress + "...") val masterEndpoint = rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME) - masterEndpoint.send(RegisterWorker( - workerId, host, port, self, cores, memory, webUi.boundPort, publicAddress)) + registerWithMaster(masterEndpoint) } catch { case ie: InterruptedException => // Cancelled case NonFatal(e) => logWarning(s"Failed to connect to master $masterAddress", e) @@ -228,7 +224,7 @@ private[deploy] class Worker( /** * Re-register with the master because a network failure or a master failure has occurred. * If the re-registration attempt threshold is exceeded, the worker exits with error. - * Note that for thread-safety this should only be called from the actor. + * Note that for thread-safety this should only be called from the rpcEndpoint. */ private def reregisterWithMaster(): Unit = { Utils.tryOrExit { @@ -272,8 +268,7 @@ private[deploy] class Worker( logInfo("Connecting to master " + masterAddress + "...") val masterEndpoint = rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME) - masterEndpoint.send(RegisterWorker( - workerId, host, port, self, cores, memory, webUi.boundPort, publicAddress)) + registerWithMaster(masterEndpoint) } catch { case ie: InterruptedException => // Cancelled case NonFatal(e) => logWarning(s"Failed to connect to master $masterAddress", e) @@ -330,7 +325,7 @@ private[deploy] class Worker( registrationRetryTimer = Some(forwordMessageScheduler.scheduleAtFixedRate( new Runnable { override def run(): Unit = Utils.tryLogNonFatalError { - self.send(ReregisterWithMaster) + Option(self).foreach(_.send(ReregisterWithMaster)) } }, INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS, @@ -342,30 +337,60 @@ private[deploy] class Worker( } } - override def receive: PartialFunction[Any, Unit] = { - case RegisteredWorker(masterRef, masterWebUiUrl) => - logInfo("Successfully registered with master " + masterRef.address.toSparkURL) - registered = true - changeMaster(masterRef, masterWebUiUrl) - forwordMessageScheduler.scheduleAtFixedRate(new Runnable { - override def run(): Unit = Utils.tryLogNonFatalError { - self.send(SendHeartbeat) - } - }, 0, HEARTBEAT_MILLIS, TimeUnit.MILLISECONDS) - if (CLEANUP_ENABLED) { - logInfo(s"Worker cleanup enabled; old application directories will be deleted in: $workDir") + private def registerWithMaster(masterEndpoint: RpcEndpointRef): Unit = { + masterEndpoint.ask[RegisterWorkerResponse](RegisterWorker( + workerId, host, port, self, cores, memory, webUi.boundPort, publicAddress)) + .onComplete { + // This is a very fast action so we can use "ThreadUtils.sameThread" + case Success(msg) => + Utils.tryLogNonFatalError { + handleRegisterResponse(msg) + } + case Failure(e) => + logError(s"Cannot register with master: ${masterEndpoint.address}", e) + System.exit(1) + }(ThreadUtils.sameThread) + } + + private def handleRegisterResponse(msg: RegisterWorkerResponse): Unit = 
synchronized { + msg match { + case RegisteredWorker(masterRef, masterWebUiUrl) => + logInfo("Successfully registered with master " + masterRef.address.toSparkURL) + registered = true + changeMaster(masterRef, masterWebUiUrl) forwordMessageScheduler.scheduleAtFixedRate(new Runnable { override def run(): Unit = Utils.tryLogNonFatalError { - self.send(WorkDirCleanup) + self.send(SendHeartbeat) } - }, CLEANUP_INTERVAL_MILLIS, CLEANUP_INTERVAL_MILLIS, TimeUnit.MILLISECONDS) - } + }, 0, HEARTBEAT_MILLIS, TimeUnit.MILLISECONDS) + if (CLEANUP_ENABLED) { + logInfo( + s"Worker cleanup enabled; old application directories will be deleted in: $workDir") + forwordMessageScheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(WorkDirCleanup) + } + }, CLEANUP_INTERVAL_MILLIS, CLEANUP_INTERVAL_MILLIS, TimeUnit.MILLISECONDS) + } + case RegisterWorkerFailed(message) => + if (!registered) { + logError("Worker registration failed: " + message) + System.exit(1) + } + + case MasterInStandby => + // Ignore. Master not yet ready. + } + } + + override def receive: PartialFunction[Any, Unit] = synchronized { case SendHeartbeat => if (connected) { sendToMaster(Heartbeat(workerId, self)) } case WorkDirCleanup => - // Spin up a separate thread (in a future) to do the dir cleanup; don't tie up worker actor + // Spin up a separate thread (in a future) to do the dir cleanup; don't tie up worker + // rpcEndpoint. // Copy ids so that it can be used in the cleanup thread. val appIds = executors.values.map(_.appId).toSet val cleanupFuture = concurrent.future { @@ -399,12 +424,6 @@ private[deploy] class Worker( map(e => new ExecutorDescription(e.appId, e.execId, e.cores, e.state)) masterRef.send(WorkerSchedulerStateResponse(workerId, execs.toList, drivers.keys.toSeq)) - case RegisterWorkerFailed(message) => - if (!registered) { - logError("Worker registration failed: " + message) - System.exit(1) - } - case ReconnectWorker(masterUrl) => logInfo(s"Master with url $masterUrl requested this worker to reconnect.") registerWithMaster() @@ -427,7 +446,9 @@ private[deploy] class Worker( // application finishes. 
val appLocalDirs = appDirectories.get(appId).getOrElse { Utils.getOrCreateLocalRootDirs(conf).map { dir => - Utils.createDirectory(dir, namePrefix = "executor").getAbsolutePath() + val appDir = Utils.createDirectory(dir, namePrefix = "executor") + Utils.chmod700(appDir) + appDir.getAbsolutePath() }.toSeq } appDirectories(appId) = appLocalDirs @@ -446,7 +467,7 @@ private[deploy] class Worker( executorDir, workerUri, conf, - appLocalDirs, ExecutorState.LOADING) + appLocalDirs, ExecutorState.RUNNING) executors(appId + "/" + execId) = manager manager.start() coresUsed += cores_ @@ -669,7 +690,7 @@ private[deploy] object Worker extends Logging { val conf = new SparkConf val args = new WorkerArguments(argStrings, conf) val rpcEnv = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, args.cores, - args.memory, args.masters, args.workDir) + args.memory, args.masters, args.workDir, conf = conf) rpcEnv.awaitTermination() } @@ -684,7 +705,7 @@ private[deploy] object Worker extends Logging { workerNumber: Option[Int] = None, conf: SparkConf = new SparkConf): RpcEnv = { - // The LocalSparkCluster runs multiple local sparkWorkerX actor systems + // The LocalSparkCluster runs multiple local sparkWorkerX RPC Environments val systemName = SYSTEM_NAME + workerNumber.map(_.toString).getOrElse("") val securityMgr = new SecurityManager(conf) val rpcEnv = RpcEnv.create(systemName, host, port, conf, securityMgr) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index fae5640b9a21..ab56fde938ba 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -24,14 +24,13 @@ import org.apache.spark.rpc._ * Actor which connects to a worker process and terminates the JVM if the connection is severed. * Provides fate sharing between a worker and its associated child processes. */ -private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: String) +private[spark] class WorkerWatcher( + override val rpcEnv: RpcEnv, workerUrl: String, isTesting: Boolean = false) extends RpcEndpoint with Logging { - override def onStart() { - logInfo(s"Connecting to worker $workerUrl") - if (!isTesting) { - rpcEnv.asyncSetupEndpointRefByURI(workerUrl) - } + logInfo(s"Connecting to worker $workerUrl") + if (!isTesting) { + rpcEnv.asyncSetupEndpointRefByURI(workerUrl) } // Used to avoid shutting down JVM during tests @@ -40,10 +39,8 @@ private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: Strin // true rather than calling `System.exit`. The user can check `isShutDown` to know if // `exitNonZero` is called. private[deploy] var isShutDown = false - private[deploy] def setTesting(testing: Boolean) = isTesting = testing - private var isTesting = false - // Lets us filter events only from the worker's actor system + // Lets filter events only from the worker's rpc system private val expectedAddress = RpcAddress.fromURIString(workerUrl) private def isWorker(address: RpcAddress) = expectedAddress == address @@ -62,7 +59,7 @@ private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: Strin override def onDisconnected(remoteAddress: RpcAddress): Unit = { if (isWorker(remoteAddress)) { // This log message will never be seen - logError(s"Lost connection to worker actor $workerUrl. Exiting.") + logError(s"Lost connection to worker rpc endpoint $workerUrl. 
Exiting.") exitNonZero() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index 709a27233598..1a0598e50dcf 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -20,9 +20,8 @@ package org.apache.spark.deploy.worker.ui import java.io.File import javax.servlet.http.HttpServletRequest -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.Logging import org.apache.spark.deploy.worker.Worker -import org.apache.spark.deploy.worker.ui.WorkerWebUI._ import org.apache.spark.ui.{SparkUI, WebUI} import org.apache.spark.ui.JettyUtils._ import org.apache.spark.util.RpcUtils @@ -49,7 +48,9 @@ class WorkerWebUI( attachPage(new WorkerPage(this)) attachHandler(createStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE, "/static")) attachHandler(createServletHandler("/log", - (request: HttpServletRequest) => logPage.renderLog(request), worker.securityMgr)) + (request: HttpServletRequest) => logPage.renderLog(request), + worker.securityMgr, + worker.conf)) } } diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index fcd76ec52742..c2ebf3059621 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -45,8 +45,6 @@ private[spark] class CoarseGrainedExecutorBackend( env: SparkEnv) extends ThreadSafeRpcEndpoint with ExecutorBackend with Logging { - Utils.checkHostPort(hostPort, "Expected hostport") - var executor: Executor = null @volatile var driver: Option[RpcEndpointRef] = None @@ -59,12 +57,12 @@ private[spark] class CoarseGrainedExecutorBackend( rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref => // This is a very fast action so we can use "ThreadUtils.sameThread" driver = Some(ref) - ref.ask[RegisteredExecutor.type]( + ref.ask[RegisterExecutorResponse]( RegisterExecutor(executorId, self, hostPort, cores, extractLogUrls)) }(ThreadUtils.sameThread).onComplete { // This is a very fast action so we can use "ThreadUtils.sameThread" case Success(msg) => Utils.tryLogNonFatalError { - Option(self).foreach(_.send(msg)) // msg must be RegisteredExecutor + Option(self).foreach(_.send(msg)) // msg must be RegisterExecutorResponse } case Failure(e) => { logError(s"Cannot register with driver: $driverUrl", e) @@ -80,9 +78,8 @@ private[spark] class CoarseGrainedExecutorBackend( } override def receive: PartialFunction[Any, Unit] = { - case RegisteredExecutor => + case RegisteredExecutor(hostname) => logInfo("Successfully registered with driver") - val (hostname, _) = Utils.parseHostPort(hostPort) executor = new Executor(executorId, hostname, env, userClassPath, isLocal = false) case RegisterExecutorFailed(message) => @@ -110,6 +107,11 @@ private[spark] class CoarseGrainedExecutorBackend( case StopExecutor => logInfo("Driver commanded a shutdown") + // Cannot shutdown here because an ack may need to be sent back to the caller. So send + // a message to self to actually do the shutdown. 
+ self.send(Shutdown) + + case Shutdown => executor.stop() stop() rpcEnv.shutdown() @@ -158,7 +160,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { hostname, port, executorConf, - new SecurityManager(executorConf)) + new SecurityManager(executorConf), + clientMode = true) val driver = fetcher.setupEndpointRefByURI(driverUrl) val props = driver.askWithRetry[Seq[(String, String)]](RetrieveSparkProps) ++ Seq[(String, String)](("spark.app.id", appId)) @@ -183,12 +186,12 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { val env = SparkEnv.createExecutorEnv( driverConf, executorId, hostname, port, cores, isLocal = false) - // SparkEnv sets spark.driver.port so it shouldn't be 0 anymore. - val boundPort = env.conf.getInt("spark.executor.port", 0) - assert(boundPort != 0) - - // Start the CoarseGrainedExecutorBackend endpoint. - val sparkHostPort = hostname + ":" + boundPort + // SparkEnv will set spark.executor.port if the rpc env is listening for incoming + // connections (e.g., if it's using akka). Otherwise, the executor is running in + // client mode only, and does not accept incoming connections. + val sparkHostPort = env.conf.getOption("spark.executor.port").map { port => + hostname + ":" + port + }.orNull env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend( env.rpcEnv, driverUrl, executorId, sparkHostPort, cores, userClassPath, env)) workerUrl.foreach { url => diff --git a/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala index f47d7ef511da..7d84889a2def 100644 --- a/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala +++ b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala @@ -26,8 +26,8 @@ private[spark] class CommitDeniedException( msg: String, jobID: Int, splitID: Int, - attemptID: Int) + attemptNumber: Int) extends Exception(msg) { - def toTaskEndReason: TaskEndReason = TaskCommitDenied(jobID, splitID, attemptID) + def toTaskEndReason: TaskEndReason = TaskCommitDenied(jobID, splitID, attemptNumber) } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 7bc7fce7ae8d..552b644d13aa 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -17,22 +17,22 @@ package org.apache.spark.executor -import java.io.File +import java.io.{File, NotSerializableException} import java.lang.management.ManagementFactory import java.net.URL import java.nio.ByteBuffer import java.util.concurrent.{ConcurrentHashMap, TimeUnit} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.util.control.NonFatal import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{StorageLevel, TaskResultBlockId} -import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util._ /** @@ -85,10 +85,6 @@ private[spark] class Executor( env.blockManager.initialize(conf.getAppId) } - // Create an RpcEndpoint for receiving RPCs from the driver - private val executorEndpoint = env.rpcEnv.setupEndpoint( - 
ExecutorEndpoint.EXECUTOR_ENDPOINT_NAME, new ExecutorEndpoint(env.rpcEnv, executorId)) - // Whether to load classes in user jars before those in Spark jars private val userClassPathFirst = conf.getBoolean("spark.executor.userClassPathFirst", false) @@ -113,6 +109,10 @@ private[spark] class Executor( // Executor for the heartbeat task. private val heartbeater = ThreadUtils.newDaemonSingleThreadScheduledExecutor("driver-heartbeater") + // must be initialized before running startDriverHeartbeat() + private val heartbeatReceiverRef = + RpcUtils.makeDriverRef(HeartbeatReceiver.ENDPOINT_NAME, conf, env.rpcEnv) + startDriverHeartbeater() def launchTask( @@ -136,7 +136,6 @@ private[spark] class Executor( def stop(): Unit = { env.metricsSystem.report() - env.rpcEnv.stop(executorEndpoint) heartbeater.shutdown() heartbeater.awaitTermination(10, TimeUnit.SECONDS) threadPool.shutdown() @@ -147,7 +146,7 @@ private[spark] class Executor( /** Returns the total amount of time this JVM process has spent in garbage collection. */ private def computeTotalGcTime(): Long = { - ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum + ManagementFactory.getGarbageCollectorMXBeans.asScala.map(_.getCollectionTime).sum } class TaskRunner( @@ -179,7 +178,7 @@ private[spark] class Executor( } override def run(): Unit = { - val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager) + val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId) val deserializeStartTime = System.currentTimeMillis() Thread.currentThread.setContextClassLoader(replClassLoader) val ser = env.closureSerializer.newInstance() @@ -249,6 +248,7 @@ private[spark] class Executor( m.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime) m.setJvmGCTime(computeTotalGcTime() - startGCTime) m.setResultSerializationTime(afterSerialization - beforeSerialization) + m.updateAccumulators() } val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull) @@ -300,11 +300,20 @@ private[spark] class Executor( task.metrics.map { m => m.setExecutorRunTime(System.currentTimeMillis() - taskStart) m.setJvmGCTime(computeTotalGcTime() - startGCTime) + m.updateAccumulators() m } } - val taskEndReason = new ExceptionFailure(t, metrics) - execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(taskEndReason)) + val serializedTaskEndReason = { + try { + ser.serialize(new ExceptionFailure(t, metrics)) + } catch { + case _: NotSerializableException => + // t is not serializable so just send the stacktrace + ser.serialize(new ExceptionFailure(t, metrics, false)) + } + } + execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason) // Don't forcibly exit unless the exception was inherently fatal, to avoid // stopping other tasks unnecessarily. 
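The Executor hunk above guards against a task failure whose exception graph is not serializable: it first tries to serialize the full ExceptionFailure and, per the hunk's own comment, falls back to sending just the stack trace when NotSerializableException is thrown. A hedged sketch of the same fallback, using plain Java serialization in place of env.closureSerializer and a stack-trace string as the degraded payload:

import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

object SafeFailureSerialization {
  private def javaSerialize(obj: AnyRef): Array[Byte] = {
    val bytes = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bytes)
    try out.writeObject(obj) finally out.close()
    bytes.toByteArray
  }

  // Prefer the full throwable; if some field in its graph is not serializable,
  // fall back to a plain stack-trace string so the status update can still be sent.
  def serializeFailure(t: Throwable): Array[Byte] = {
    try {
      javaSerialize(t)
    } catch {
      case _: NotSerializableException =>
        javaSerialize(t.getStackTrace.mkString("\n"))
    }
  }
}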
@@ -355,9 +364,9 @@ private[spark] class Executor( val _userClassPathFirst: java.lang.Boolean = userClassPathFirst val klass = Utils.classForName("org.apache.spark.repl.ExecutorClassLoader") .asInstanceOf[Class[_ <: ClassLoader]] - val constructor = klass.getConstructor(classOf[SparkConf], classOf[String], - classOf[ClassLoader], classOf[Boolean]) - constructor.newInstance(conf, classUri, parent, _userClassPathFirst) + val constructor = klass.getConstructor(classOf[SparkConf], classOf[SparkEnv], + classOf[String], classOf[ClassLoader], classOf[Boolean]) + constructor.newInstance(conf, env, classUri, parent, _userClassPathFirst) } catch { case _: ClassNotFoundException => logError("Could not find org.apache.spark.repl.ExecutorClassLoader on classpath!") @@ -406,16 +415,13 @@ private[spark] class Executor( } } - private val heartbeatReceiverRef = - RpcUtils.makeDriverRef(HeartbeatReceiver.ENDPOINT_NAME, conf, env.rpcEnv) - /** Reports heartbeat and metrics for active tasks to the driver. */ private def reportHeartBeat(): Unit = { // list of (task id, metrics) to send back to the driver val tasksMetrics = new ArrayBuffer[(Long, TaskMetrics)]() val curGCTime = computeTotalGcTime() - for (taskRunner <- runningTasks.values()) { + for (taskRunner <- runningTasks.values().asScala) { if (taskRunner.task != null) { taskRunner.task.metrics.foreach { metrics => metrics.updateShuffleReadMetrics() diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala index 293c512f8b70..d16f4a1fc4e3 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala @@ -19,7 +19,7 @@ package org.apache.spark.executor import java.util.concurrent.ThreadPoolExecutor -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.hadoop.fs.FileSystem @@ -30,7 +30,7 @@ private[spark] class ExecutorSource(threadPool: ThreadPoolExecutor, executorId: String) extends Source { private def fileStats(scheme: String) : Option[FileSystem.Statistics] = - FileSystem.getAllStatistics().find(s => s.getScheme.equals(scheme)) + FileSystem.getAllStatistics.asScala.find(s => s.getScheme.equals(scheme)) private def registerFileSystemStat[T]( scheme: String, name: String, f: FileSystem.Statistics => T, defaultValue: T) = { diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala index cfd672e1d8a9..c9f18ebc7f0e 100644 --- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala @@ -19,7 +19,7 @@ package org.apache.spark.executor import java.nio.ByteBuffer -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.mesos.protobuf.ByteString import org.apache.mesos.{Executor => MesosExecutor, ExecutorDriver, MesosExecutorDriver} @@ -28,7 +28,7 @@ import org.apache.mesos.Protos.{TaskStatus => MesosTaskStatus, _} import org.apache.spark.{Logging, TaskState, SparkConf, SparkEnv} import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.scheduler.cluster.mesos.{MesosTaskLaunchData} +import org.apache.spark.scheduler.cluster.mesos.MesosTaskLaunchData import 
org.apache.spark.util.{SignalLogger, Utils} private[spark] class MesosExecutorBackend @@ -55,7 +55,7 @@ private[spark] class MesosExecutorBackend slaveInfo: SlaveInfo) { // Get num cores for this task from ExecutorInfo, created in MesosSchedulerBackend. - val cpusPerTask = executorInfo.getResourcesList + val cpusPerTask = executorInfo.getResourcesList.asScala .find(_.getName == "cpus") .map(_.getScalar.getValue.toInt) .getOrElse(0) @@ -63,6 +63,11 @@ private[spark] class MesosExecutorBackend logInfo(s"Registered with Mesos as executor ID $executorId with $cpusPerTask cpus") this.driver = driver + // Set a context class loader to be picked up by the serializer. Without this call + // the serializer would default to the null class loader, and fail to find Spark classes + // See SPARK-10986. + Thread.currentThread().setContextClassLoader(this.getClass.getClassLoader) + val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray) ++ Seq[(String, String)](("spark.app.id", frameworkInfo.getId.getValue)) val conf = new SparkConf(loadDefaults = true).setAll(properties) diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala index 6cda7772f77b..280e7a5fe893 100644 --- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala +++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala @@ -19,15 +19,14 @@ package org.apache.spark.input import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ -import com.google.common.io.ByteStreams +import com.google.common.io.{Closeables, ByteStreams} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat, CombineFileRecordReader, CombineFileSplit} -import org.apache.spark.annotation.Experimental import org.apache.spark.deploy.SparkHadoopUtil /** @@ -44,12 +43,9 @@ private[spark] abstract class StreamFileInputFormat[T] * which is set through setMaxSplitSize */ def setMinPartitions(context: JobContext, minPartitions: Int) { - val files = listStatus(context) - val totalLen = files.map { file => - if (file.isDir) 0L else file.getLen - }.sum - - val maxSplitSize = Math.ceil(totalLen * 1.0 / files.length).toLong + val files = listStatus(context).asScala + val totalLen = files.map(file => if (file.isDir) 0L else file.getLen).sum + val maxSplitSize = Math.ceil(totalLen * 1.0 / files.size).toLong super.setMaxSplitSize(maxSplitSize) } @@ -86,7 +82,6 @@ private[spark] abstract class StreamBasedRecordReader[T]( if (!processed) { val fileIn = new PortableDataStream(split, context, index) value = parseStream(fileIn) - fileIn.close() // if it has not been open yet, close does nothing key = fileIn.getPath processed = true true @@ -132,19 +127,12 @@ private[spark] class StreamInputFormat extends StreamFileInputFormat[PortableDat * @note TaskAttemptContext is not serializable resulting in the confBytes construct * @note CombineFileSplit is not serializable resulting in the splitBytes construct */ -@Experimental class PortableDataStream( - @transient isplit: CombineFileSplit, - @transient context: TaskAttemptContext, + isplit: CombineFileSplit, + context: TaskAttemptContext, index: Integer) extends Serializable { - // transient 
forces file to be reopened after being serialization - // it is also used for non-serializable classes - - @transient private var fileIn: DataInputStream = null - @transient private var isOpen = false - private val confBytes = { val baos = new ByteArrayOutputStream() SparkHadoopUtil.get.getConfigurationFromJobContext(context). @@ -180,40 +168,34 @@ class PortableDataStream( } /** - * Create a new DataInputStream from the split and context + * Create a new DataInputStream from the split and context. The user of this method is responsible + * for closing the stream after usage. */ def open(): DataInputStream = { - if (!isOpen) { - val pathp = split.getPath(index) - val fs = pathp.getFileSystem(conf) - fileIn = fs.open(pathp) - isOpen = true - } - fileIn + val pathp = split.getPath(index) + val fs = pathp.getFileSystem(conf) + fs.open(pathp) } /** * Read the file as a byte array */ def toArray(): Array[Byte] = { - open() - val innerBuffer = ByteStreams.toByteArray(fileIn) - close() - innerBuffer + val stream = open() + try { + ByteStreams.toByteArray(stream) + } finally { + Closeables.close(stream, true) + } } /** - * Close the file (if it is currently open) + * Closing the PortableDataStream is not needed anymore. The user either can use the + * PortableDataStream to get a DataInputStream (which the user needs to close after usage), + * or a byte array. */ + @deprecated("Closing the PortableDataStream is not needed anymore.", "1.6.0") def close(): Unit = { - if (isOpen) { - try { - fileIn.close() - isOpen = false - } catch { - case ioe: java.io.IOException => // do nothing - } - } } def getPath(): String = path diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala index aaef7c74eea3..413408723b54 100644 --- a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala +++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala @@ -17,9 +17,10 @@ package org.apache.spark.input -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.Text import org.apache.hadoop.mapreduce.InputSplit import org.apache.hadoop.mapreduce.JobContext import org.apache.hadoop.mapreduce.lib.input.CombineFileInputFormat @@ -33,14 +34,13 @@ import org.apache.hadoop.mapreduce.TaskAttemptContext */ private[spark] class WholeTextFileInputFormat - extends CombineFileInputFormat[String, String] with Configurable { + extends CombineFileInputFormat[Text, Text] with Configurable { override protected def isSplitable(context: JobContext, file: Path): Boolean = false override def createRecordReader( split: InputSplit, - context: TaskAttemptContext): RecordReader[String, String] = { - + context: TaskAttemptContext): RecordReader[Text, Text] = { val reader = new ConfigurableCombineFileRecordReader(split, context, classOf[WholeTextFileRecordReader]) reader.setConf(getConf) @@ -52,10 +52,8 @@ private[spark] class WholeTextFileInputFormat * which is set through setMaxSplitSize */ def setMinPartitions(context: JobContext, minPartitions: Int) { - val files = listStatus(context) - val totalLen = files.map { file => - if (file.isDir) 0L else file.getLen - }.sum + val files = listStatus(context).asScala + val totalLen = files.map(file => if (file.isDir) 0L else file.getLen).sum val maxSplitSize = Math.ceil(totalLen * 1.0 / (if (minPartitions == 0) 1 else minPartitions)).toLong 
super.setMaxSplitSize(maxSplitSize) diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala index 31bde8a78f3c..b56b2aa88a41 100644 --- a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala +++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala @@ -49,7 +49,7 @@ private[spark] class WholeTextFileRecordReader( split: CombineFileSplit, context: TaskAttemptContext, index: Integer) - extends RecordReader[String, String] with Configurable { + extends RecordReader[Text, Text] with Configurable { private[this] val path = split.getPath(index) private[this] val fs = path.getFileSystem( @@ -58,8 +58,8 @@ private[spark] class WholeTextFileRecordReader( // True means the current file has been processed, then skip it. private[this] var processed = false - private[this] val key = path.toString - private[this] var value: String = null + private[this] val key: Text = new Text(path.toString) + private[this] var value: Text = null override def initialize(split: InputSplit, context: TaskAttemptContext): Unit = {} @@ -67,9 +67,9 @@ private[spark] class WholeTextFileRecordReader( override def getProgress: Float = if (processed) 1.0f else 0.0f - override def getCurrentKey: String = key + override def getCurrentKey: Text = key - override def getCurrentValue: String = value + override def getCurrentValue: Text = value override def nextKeyValue(): Boolean = { if (!processed) { @@ -83,7 +83,7 @@ private[spark] class WholeTextFileRecordReader( ByteStreams.toByteArray(fileIn) } - value = new Text(innerBuffer).toString + value = new Text(innerBuffer) Closeables.close(fileIn, false) processed = true true diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 607d5a321efc..ca74eedf89be 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -47,6 +47,11 @@ trait CompressionCodec { private[spark] object CompressionCodec { private val configKey = "spark.io.compression.codec" + + private[spark] def supportsConcatenationOfSerializedStreams(codec: CompressionCodec): Boolean = { + codec.isInstanceOf[SnappyCompressionCodec] || codec.isInstanceOf[LZFCompressionCodec] + } + private val shortCompressionCodecNames = Map( "lz4" -> classOf[LZ4CompressionCodec].getName, "lzf" -> classOf[LZFCompressionCodec].getName, @@ -148,7 +153,7 @@ class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec { try { Snappy.getNativeLibraryVersion } catch { - case e: Error => throw new IllegalArgumentException + case e: Error => throw new IllegalArgumentException(e) } override def compressedOutputStream(s: OutputStream): OutputStream = { diff --git a/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala b/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala new file mode 100644 index 000000000000..a5d41a1eeb47 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher + +import java.net.{InetAddress, Socket} + +import org.apache.spark.SPARK_VERSION +import org.apache.spark.launcher.LauncherProtocol._ +import org.apache.spark.util.{ThreadUtils, Utils} + +/** + * A class that can be used to talk to a launcher server. Users should extend this class to + * provide implementation for the abstract methods. + * + * See `LauncherServer` for an explanation of how launcher communication works. + */ +private[spark] abstract class LauncherBackend { + + private var clientThread: Thread = _ + private var connection: BackendConnection = _ + private var lastState: SparkAppHandle.State = _ + @volatile private var _isConnected = false + + def connect(): Unit = { + val port = sys.env.get(LauncherProtocol.ENV_LAUNCHER_PORT).map(_.toInt) + val secret = sys.env.get(LauncherProtocol.ENV_LAUNCHER_SECRET) + if (port != None && secret != None) { + val s = new Socket(InetAddress.getLoopbackAddress(), port.get) + connection = new BackendConnection(s) + connection.send(new Hello(secret.get, SPARK_VERSION)) + clientThread = LauncherBackend.threadFactory.newThread(connection) + clientThread.start() + _isConnected = true + } + } + + def close(): Unit = { + if (connection != null) { + try { + connection.close() + } finally { + if (clientThread != null) { + clientThread.join() + } + } + } + } + + def setAppId(appId: String): Unit = { + if (connection != null) { + connection.send(new SetAppId(appId)) + } + } + + def setState(state: SparkAppHandle.State): Unit = { + if (connection != null && lastState != state) { + connection.send(new SetState(state)) + lastState = state + } + } + + /** Return whether the launcher handle is still connected to this backend. */ + def isConnected(): Boolean = _isConnected + + /** + * Implementations should provide this method, which should try to stop the application + * as gracefully as possible. + */ + protected def onStopRequest(): Unit + + /** + * Callback for when the launcher handle disconnects from this backend. 
+ */ + protected def onDisconnected() : Unit = { } + + private def fireStopRequest(): Unit = { + val thread = LauncherBackend.threadFactory.newThread(new Runnable() { + override def run(): Unit = Utils.tryLogNonFatalError { + onStopRequest() + } + }) + thread.start() + } + + private class BackendConnection(s: Socket) extends LauncherConnection(s) { + + override protected def handle(m: Message): Unit = m match { + case _: Stop => + fireStopRequest() + + case _ => + throw new IllegalArgumentException(s"Unexpected message type: ${m.getClass().getName()}") + } + + override def close(): Unit = { + try { + super.close() + } finally { + onDisconnected() + _isConnected = false + } + } + + } + +} + +private object LauncherBackend { + + val threadFactory = ThreadUtils.namedThreadFactory("LauncherBackend") + +} diff --git a/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala b/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala index 9be98723aed1..a2add6161728 100644 --- a/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala +++ b/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala @@ -20,7 +20,7 @@ package org.apache.spark.launcher import java.io.File import java.util.{HashMap => JHashMap, List => JList, Map => JMap} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.deploy.Command @@ -32,7 +32,7 @@ import org.apache.spark.deploy.Command private[spark] class WorkerCommandBuilder(sparkHome: String, memoryMb: Int, command: Command) extends AbstractCommandBuilder { - childEnv.putAll(command.environment) + childEnv.putAll(command.environment.asJava) childEnv.put(CommandBuilderUtils.ENV_SPARK_HOME, sparkHome) override def buildCommand(env: JMap[String, String]): JList[String] = { @@ -40,7 +40,7 @@ private[spark] class WorkerCommandBuilder(sparkHome: String, memoryMb: Int, comm cmd.add(s"-Xms${memoryMb}M") cmd.add(s"-Xmx${memoryMb}M") command.javaOpts.foreach(cmd.add) - addPermGenSizeOpt(cmd) + CommandBuilderUtils.addPermGenSizeOpt(cmd) addOptionString(cmd, getenv("SPARK_JAVA_OPTS")) cmd } diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala index 87df42748be4..f7298e8d5c62 100644 --- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala @@ -24,6 +24,7 @@ import org.apache.hadoop.mapred._ import org.apache.hadoop.mapreduce.{TaskAttemptContext => MapReduceTaskAttemptContext} import org.apache.hadoop.mapreduce.{OutputCommitter => MapReduceOutputCommitter} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.CommitDeniedException import org.apache.spark.{Logging, SparkEnv, TaskContext} import org.apache.spark.util.{Utils => SparkUtils} @@ -90,10 +91,9 @@ object SparkHadoopMapRedUtil extends Logging { committer: MapReduceOutputCommitter, mrTaskContext: MapReduceTaskAttemptContext, jobId: Int, - splitId: Int, - attemptId: Int): Unit = { + splitId: Int): Unit = { - val mrTaskAttemptID = mrTaskContext.getTaskAttemptID + val mrTaskAttemptID = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(mrTaskContext) // Called after we have decided to commit def performCommit(): Unit = { @@ -121,7 +121,8 @@ object SparkHadoopMapRedUtil extends Logging { if (shouldCoordinateWithDriver) { val outputCommitCoordinator = 
SparkEnv.get.outputCommitCoordinator - val canCommit = outputCommitCoordinator.canCommit(jobId, splitId, attemptId) + val taskAttemptNumber = TaskContext.get().attemptNumber() + val canCommit = outputCommitCoordinator.canCommit(jobId, splitId, taskAttemptNumber) if (canCommit) { performCommit() @@ -131,7 +132,7 @@ object SparkHadoopMapRedUtil extends Logging { logInfo(message) // We need to abort the task so that the driver can reschedule new attempts, if necessary committer.abortTask(mrTaskContext) - throw new CommitDeniedException(message, jobId, splitId, attemptId) + throw new CommitDeniedException(message, jobId, splitId, taskAttemptNumber) } } else { // Speculation is disabled or a user has chosen to manually bypass the commit coordination @@ -142,16 +143,4 @@ object SparkHadoopMapRedUtil extends Logging { logInfo(s"No need to commit output of task because needsTaskCommit=false: $mrTaskAttemptID") } } - - def commitTask( - committer: MapReduceOutputCommitter, - mrTaskContext: MapReduceTaskAttemptContext, - sparkTaskContext: TaskContext): Unit = { - commitTask( - committer, - mrTaskContext, - sparkTaskContext.stageId(), - sparkTaskContext.partitionId(), - sparkTaskContext.attemptNumber()) - } } diff --git a/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala b/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala new file mode 100644 index 000000000000..dbb0ad8d5c67 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory + +import javax.annotation.concurrent.GuardedBy + +import scala.collection.mutable + +import org.apache.spark.Logging + +/** + * Implements policies and bookkeeping for sharing a adjustable-sized pool of memory between tasks. + * + * Tries to ensure that each task gets a reasonable share of memory, instead of some task ramping up + * to a large amount first and then causing others to spill to disk repeatedly. + * + * If there are N tasks, it ensures that each task can acquire at least 1 / 2N of the memory + * before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the + * set of active tasks and redo the calculations of 1 / 2N and 1 / N in waiting tasks whenever this + * set changes. This is all done by synchronizing access to mutable state and using wait() and + * notifyAll() to signal changes to callers. Prior to Spark 1.6, this arbitration of memory across + * tasks was performed by the ShuffleMemoryManager. 
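The ExecutionMemoryPool scaladoc above pins each task's share between 1/(2N) and 1/N of the pool. A small worked example of those bounds, assuming a fixed pool size for simplicity (the real pool can also grow by evicting storage):

object FairShareExample {
  // Bounds implied by the policy above: each of N active tasks is guaranteed
  // poolSize / (2 * N) before it can be forced to spill, and capped at poolSize / N.
  def bounds(poolSize: Long, numActiveTasks: Int): (Long, Long) =
    (poolSize / (2L * numActiveTasks), poolSize / numActiveTasks)

  def main(args: Array[String]): Unit = {
    val (min, max) = bounds(1024L * 1024 * 1024, 4)
    println(s"guaranteed at least $min bytes (128 MiB), capped at $max bytes (256 MiB)")
  }
}

So with a 1 GiB pool and 4 active tasks, a task can always ramp up to 128 MiB before acquireMemory blocks it, and is never granted more than 256 MiB in total.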
+ * + * @param lock a [[MemoryManager]] instance to synchronize on + * @param poolName a human-readable name for this pool, for use in log messages + */ +private[memory] class ExecutionMemoryPool( + lock: Object, + poolName: String + ) extends MemoryPool(lock) with Logging { + + /** + * Map from taskAttemptId -> memory consumption in bytes + */ + @GuardedBy("lock") + private val memoryForTask = new mutable.HashMap[Long, Long]() + + override def memoryUsed: Long = lock.synchronized { + memoryForTask.values.sum + } + + /** + * Returns the memory consumption, in bytes, for the given task. + */ + def getMemoryUsageForTask(taskAttemptId: Long): Long = lock.synchronized { + memoryForTask.getOrElse(taskAttemptId, 0L) + } + + /** + * Try to acquire up to `numBytes` of memory for the given task and return the number of bytes + * obtained, or 0 if none can be allocated. + * + * This call may block until there is enough free memory in some situations, to make sure each + * task has a chance to ramp up to at least 1 / 2N of the total memory pool (where N is the # of + * active tasks) before it is forced to spill. This can happen if the number of tasks increase + * but an older task had a lot of memory already. + * + * @param numBytes number of bytes to acquire + * @param taskAttemptId the task attempt acquiring memory + * @param maybeGrowPool a callback that potentially grows the size of this pool. It takes in + * one parameter (Long) that represents the desired amount of memory by + * which this pool should be expanded. + * @param computeMaxPoolSize a callback that returns the maximum allowable size of this pool + * at this given moment. This is not a field because the max pool + * size is variable in certain cases. For instance, in unified + * memory management, the execution pool can be expanded by evicting + * cached blocks, thereby shrinking the storage pool. + * + * @return the number of bytes granted to the task. + */ + private[memory] def acquireMemory( + numBytes: Long, + taskAttemptId: Long, + maybeGrowPool: Long => Unit = (additionalSpaceNeeded: Long) => Unit, + computeMaxPoolSize: () => Long = () => poolSize): Long = lock.synchronized { + assert(numBytes > 0, s"invalid number of bytes requested: $numBytes") + + // TODO: clean up this clunky method signature + + // Add this task to the taskMemory map just so we can keep an accurate count of the number + // of active tasks, to let other tasks ramp down their memory in calls to `acquireMemory` + if (!memoryForTask.contains(taskAttemptId)) { + memoryForTask(taskAttemptId) = 0L + // This will later cause waiting tasks to wake up and check numTasks again + lock.notifyAll() + } + + // Keep looping until we're either sure that we don't want to grant this request (because this + // task would have more than 1 / numActiveTasks of the memory) or we have enough free + // memory to give it (we always let each task get at least 1 / (2 * numActiveTasks)). + // TODO: simplify this to limit each task to its own slot + while (true) { + val numActiveTasks = memoryForTask.keys.size + val curMem = memoryForTask(taskAttemptId) + + // In every iteration of this loop, we should first try to reclaim any borrowed execution + // space from storage. This is necessary because of the potential race condition where new + // storage blocks may steal the free execution memory that this task was waiting for. + maybeGrowPool(numBytes - memoryFree) + + // Maximum size the pool would have after potentially growing the pool. 
+ // This is used to compute the upper bound of how much memory each task can occupy. This + // must take into account potential free memory as well as the amount this pool currently + // occupies. Otherwise, we may run into SPARK-12155 where, in unified memory management, + // we did not take into account space that could have been freed by evicting cached blocks. + val maxPoolSize = computeMaxPoolSize() + val maxMemoryPerTask = maxPoolSize / numActiveTasks + val minMemoryPerTask = poolSize / (2 * numActiveTasks) + + // How much we can grant this task; keep its share within 0 <= X <= 1 / numActiveTasks + val maxToGrant = math.min(numBytes, math.max(0, maxMemoryPerTask - curMem)) + // Only give it as much memory as is free, which might be none if it reached 1 / numTasks + val toGrant = math.min(maxToGrant, memoryFree) + + // We want to let each task get at least 1 / (2 * numActiveTasks) before blocking; + // if we can't give it this much now, wait for other tasks to free up memory + // (this happens if older tasks allocated lots of memory before N grew) + if (toGrant < numBytes && curMem + toGrant < minMemoryPerTask) { + logInfo(s"TID $taskAttemptId waiting for at least 1/2N of $poolName pool to be free") + lock.wait() + } else { + memoryForTask(taskAttemptId) += toGrant + return toGrant + } + } + 0L // Never reached + } + + /** + * Release `numBytes` of memory acquired by the given task. + */ + def releaseMemory(numBytes: Long, taskAttemptId: Long): Unit = lock.synchronized { + val curMem = memoryForTask.getOrElse(taskAttemptId, 0L) + var memoryToFree = if (curMem < numBytes) { + logWarning( + s"Internal error: release called on $numBytes bytes but task only has $curMem bytes " + + s"of memory from the $poolName pool") + curMem + } else { + numBytes + } + if (memoryForTask.contains(taskAttemptId)) { + memoryForTask(taskAttemptId) -= memoryToFree + if (memoryForTask(taskAttemptId) <= 0) { + memoryForTask.remove(taskAttemptId) + } + } + lock.notifyAll() // Notify waiters in acquireMemory() that memory has been freed + } + + /** + * Release all memory for the given task and mark it as inactive (e.g. when a task ends). + * @return the number of bytes freed. + */ + def releaseAllMemoryForTask(taskAttemptId: Long): Long = lock.synchronized { + val numBytesToFree = getMemoryUsageForTask(taskAttemptId) + releaseMemory(numBytesToFree, taskAttemptId) + numBytesToFree + } + +} diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala new file mode 100644 index 000000000000..e707e27d96b5 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
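// [Illustrative sketch, not part of the patch] A minimal caller of the acquireMemory signature
// defined above, showing how the two callbacks plug in. Hypothetical code: it would have to live
// in the org.apache.spark.memory package because the pool and acquireMemory are private[memory];
// the no-op grow callback and the sizes are assumptions.
package org.apache.spark.memory

object ExecutionMemoryPoolSketch {
  def main(args: Array[String]): Unit = {
    val lock = new Object
    val pool = new ExecutionMemoryPool(lock, "example execution")
    pool.incrementPoolSize(64L * 1024 * 1024)          // give the pool 64 MiB to hand out
    val granted = pool.acquireMemory(
      numBytes = 8L * 1024 * 1024,
      taskAttemptId = 0L,
      maybeGrowPool = (_: Long) => (),                 // this sketch's pool cannot grow
      computeMaxPoolSize = () => pool.poolSize)        // fixed-size pool: the cap is poolSize
    println(s"granted $granted bytes")
    pool.releaseMemory(granted, taskAttemptId = 0L)
  }
}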
+ */ + +package org.apache.spark.memory + +import javax.annotation.concurrent.GuardedBy + +import scala.collection.mutable + +import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore} +import org.apache.spark.unsafe.array.ByteArrayMethods +import org.apache.spark.unsafe.memory.MemoryAllocator + +/** + * An abstract memory manager that enforces how memory is shared between execution and storage. + * + * In this context, execution memory refers to that used for computation in shuffles, joins, + * sorts and aggregations, while storage memory refers to that used for caching and propagating + * internal data across the cluster. There exists one MemoryManager per JVM. + */ +private[spark] abstract class MemoryManager( + conf: SparkConf, + numCores: Int, + storageMemory: Long, + onHeapExecutionMemory: Long) extends Logging { + + // -- Methods related to memory allocation policies and bookkeeping ------------------------------ + + @GuardedBy("this") + protected val storageMemoryPool = new StorageMemoryPool(this) + @GuardedBy("this") + protected val onHeapExecutionMemoryPool = new ExecutionMemoryPool(this, "on-heap execution") + @GuardedBy("this") + protected val offHeapExecutionMemoryPool = new ExecutionMemoryPool(this, "off-heap execution") + + storageMemoryPool.incrementPoolSize(storageMemory) + onHeapExecutionMemoryPool.incrementPoolSize(onHeapExecutionMemory) + offHeapExecutionMemoryPool.incrementPoolSize(conf.getSizeAsBytes("spark.memory.offHeap.size", 0)) + + /** + * Total available memory for storage, in bytes. This amount can vary over time, depending on + * the MemoryManager implementation. + * In this model, this is equivalent to the amount of memory not occupied by execution. + */ + def maxStorageMemory: Long + + /** + * Set the [[MemoryStore]] used by this manager to evict cached blocks. + * This must be set after construction due to initialization ordering constraints. + */ + final def setMemoryStore(store: MemoryStore): Unit = synchronized { + storageMemoryPool.setMemoryStore(store) + } + + // TODO: avoid passing evicted blocks around to simplify method signatures (SPARK-10985) + + /** + * Acquire N bytes of memory to cache the given block, evicting existing ones if necessary. + * Blocks evicted in the process, if any, are added to `evictedBlocks`. + * @return whether all N bytes were successfully granted. + */ + def acquireStorageMemory( + blockId: BlockId, + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean + + /** + * Acquire N bytes of memory to unroll the given block, evicting existing ones if necessary. + * + * This extra method allows subclasses to differentiate behavior between acquiring storage + * memory and acquiring unroll memory. For instance, the memory management model in Spark + * 1.5 and before places a limit on the amount of space that can be freed from unrolling. + * Blocks evicted in the process, if any, are added to `evictedBlocks`. + * + * @return whether all N bytes were successfully granted. + */ + def acquireUnrollMemory( + blockId: BlockId, + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean + + /** + * Try to acquire up to `numBytes` of execution memory for the current task and return the + * number of bytes obtained, or 0 if none can be allocated. 
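// [Illustrative sketch, not part of the patch] The off-heap execution pool sized above is driven
// by two configuration keys that appear in this file; the values below are arbitrary examples.
import org.apache.spark.SparkConf

object OffHeapConfExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .set("spark.memory.offHeap.enabled", "true")   // selects MemoryMode.OFF_HEAP for Tungsten
      .set("spark.memory.offHeap.size", "1g")        // must be > 0 once off-heap is enabled
    println(conf.getSizeAsBytes("spark.memory.offHeap.size"))
  }
}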
+ * + * This call may block until there is enough free memory in some situations, to make sure each + * task has a chance to ramp up to at least 1 / 2N of the total memory pool (where N is the # of + * active tasks) before it is forced to spill. This can happen if the number of tasks increase + * but an older task had a lot of memory already. + */ + private[memory] + def acquireExecutionMemory( + numBytes: Long, + taskAttemptId: Long, + memoryMode: MemoryMode): Long + + /** + * Release numBytes of execution memory belonging to the given task. + */ + private[memory] + def releaseExecutionMemory( + numBytes: Long, + taskAttemptId: Long, + memoryMode: MemoryMode): Unit = synchronized { + memoryMode match { + case MemoryMode.ON_HEAP => onHeapExecutionMemoryPool.releaseMemory(numBytes, taskAttemptId) + case MemoryMode.OFF_HEAP => offHeapExecutionMemoryPool.releaseMemory(numBytes, taskAttemptId) + } + } + + /** + * Release all memory for the given task and mark it as inactive (e.g. when a task ends). + * @return the number of bytes freed. + */ + private[memory] def releaseAllExecutionMemoryForTask(taskAttemptId: Long): Long = synchronized { + onHeapExecutionMemoryPool.releaseAllMemoryForTask(taskAttemptId) + + offHeapExecutionMemoryPool.releaseAllMemoryForTask(taskAttemptId) + } + + /** + * Release N bytes of storage memory. + */ + def releaseStorageMemory(numBytes: Long): Unit = synchronized { + storageMemoryPool.releaseMemory(numBytes) + } + + /** + * Release all storage memory acquired. + */ + final def releaseAllStorageMemory(): Unit = synchronized { + storageMemoryPool.releaseAllMemory() + } + + /** + * Release N bytes of unroll memory. + */ + final def releaseUnrollMemory(numBytes: Long): Unit = synchronized { + releaseStorageMemory(numBytes) + } + + /** + * Execution memory currently in use, in bytes. + */ + final def executionMemoryUsed: Long = synchronized { + onHeapExecutionMemoryPool.memoryUsed + offHeapExecutionMemoryPool.memoryUsed + } + + /** + * Storage memory currently in use, in bytes. + */ + final def storageMemoryUsed: Long = synchronized { + storageMemoryPool.memoryUsed + } + + /** + * Returns the execution memory consumption, in bytes, for the given task. + */ + private[memory] def getExecutionMemoryUsageForTask(taskAttemptId: Long): Long = synchronized { + onHeapExecutionMemoryPool.getMemoryUsageForTask(taskAttemptId) + + offHeapExecutionMemoryPool.getMemoryUsageForTask(taskAttemptId) + } + + // -- Fields related to Tungsten managed memory ------------------------------------------------- + + /** + * Tracks whether Tungsten memory will be allocated on the JVM heap or off-heap using + * sun.misc.Unsafe. + */ + final val tungstenMemoryMode: MemoryMode = { + if (conf.getBoolean("spark.memory.offHeap.enabled", false)) { + require(conf.getSizeAsBytes("spark.memory.offHeap.size", 0) > 0, + "spark.memory.offHeap.size must be > 0 when spark.memory.offHeap.enabled == true") + MemoryMode.OFF_HEAP + } else { + MemoryMode.ON_HEAP + } + } + + /** + * The default page size, in bytes. + * + * If user didn't explicitly set "spark.buffer.pageSize", we figure out the default value + * by looking at the number of cores available to the process, and the total amount of memory, + * and then divide it by a factor of safety. 
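// [Illustrative arithmetic, not part of the patch] The default page size computed just below,
// assuming a 1 GiB on-heap execution pool and 8 usable cores (both assumptions); the 1 MiB floor,
// 64 MiB ceiling and safety factor of 16 are taken from the code that follows.
object PageSizeExample {
  def main(args: Array[String]): Unit = {
    val maxTungstenMemory = 1L << 30                       // assumed execution pool size
    val cores = 8                                          // assumed usable cores
    val safetyFactor = 16
    val size = maxTungstenMemory / cores / safetyFactor    // 8 MiB, already a power of two,
                                                           // so rounding up leaves it unchanged
    val default = math.min(64L << 20, math.max(1L << 20, size))
    println(s"default spark.buffer.pageSize = $default bytes")  // 8388608
  }
}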
+ */ + val pageSizeBytes: Long = { + val minPageSize = 1L * 1024 * 1024 // 1MB + val maxPageSize = 64L * minPageSize // 64MB + val cores = if (numCores > 0) numCores else Runtime.getRuntime.availableProcessors() + // Because of rounding to next power of 2, we may have safetyFactor as 8 in worst case + val safetyFactor = 16 + val maxTungstenMemory: Long = tungstenMemoryMode match { + case MemoryMode.ON_HEAP => onHeapExecutionMemoryPool.poolSize + case MemoryMode.OFF_HEAP => offHeapExecutionMemoryPool.poolSize + } + val size = ByteArrayMethods.nextPowerOf2(maxTungstenMemory / cores / safetyFactor) + val default = math.min(maxPageSize, math.max(minPageSize, size)) + conf.getSizeAsBytes("spark.buffer.pageSize", default) + } + + /** + * Allocates memory for use by Unsafe/Tungsten code. + */ + private[memory] final val tungstenMemoryAllocator: MemoryAllocator = { + tungstenMemoryMode match { + case MemoryMode.ON_HEAP => MemoryAllocator.HEAP + case MemoryMode.OFF_HEAP => MemoryAllocator.UNSAFE + } + } +} diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryPool.scala b/core/src/main/scala/org/apache/spark/memory/MemoryPool.scala new file mode 100644 index 000000000000..1b9edf9c43bd --- /dev/null +++ b/core/src/main/scala/org/apache/spark/memory/MemoryPool.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory + +import javax.annotation.concurrent.GuardedBy + +/** + * Manages bookkeeping for an adjustable-sized region of memory. This class is internal to + * the [[MemoryManager]]. See subclasses for more details. + * + * @param lock a [[MemoryManager]] instance, used for synchronization. We purposely erase the type + * to `Object` to avoid programming errors, since this object should only be used for + * synchronization purposes. + */ +private[memory] abstract class MemoryPool(lock: Object) { + + @GuardedBy("lock") + private[this] var _poolSize: Long = 0 + + /** + * Returns the current size of the pool, in bytes. + */ + final def poolSize: Long = lock.synchronized { + _poolSize + } + + /** + * Returns the amount of free memory in the pool, in bytes. + */ + final def memoryFree: Long = lock.synchronized { + _poolSize - memoryUsed + } + + /** + * Expands the pool by `delta` bytes. + */ + final def incrementPoolSize(delta: Long): Unit = lock.synchronized { + require(delta >= 0) + _poolSize += delta + } + + /** + * Shrinks the pool by `delta` bytes. + */ + final def decrementPoolSize(delta: Long): Unit = lock.synchronized { + require(delta >= 0) + require(delta <= _poolSize) + require(_poolSize - delta >= memoryUsed) + _poolSize -= delta + } + + /** + * Returns the amount of used memory in this pool (in bytes). 
+ */ + def memoryUsed: Long +} diff --git a/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala new file mode 100644 index 000000000000..3554b558f212 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory + +import scala.collection.mutable + +import org.apache.spark.SparkConf +import org.apache.spark.storage.{BlockId, BlockStatus} + +/** + * A [[MemoryManager]] that statically partitions the heap space into disjoint regions. + * + * The sizes of the execution and storage regions are determined through + * `spark.shuffle.memoryFraction` and `spark.storage.memoryFraction` respectively. The two + * regions are cleanly separated such that neither usage can borrow memory from the other. + */ +private[spark] class StaticMemoryManager( + conf: SparkConf, + maxOnHeapExecutionMemory: Long, + override val maxStorageMemory: Long, + numCores: Int) + extends MemoryManager( + conf, + numCores, + maxStorageMemory, + maxOnHeapExecutionMemory) { + + def this(conf: SparkConf, numCores: Int) { + this( + conf, + StaticMemoryManager.getMaxExecutionMemory(conf), + StaticMemoryManager.getMaxStorageMemory(conf), + numCores) + } + + // Max number of bytes worth of blocks to evict when unrolling + private val maxUnrollMemory: Long = { + (maxStorageMemory * conf.getDouble("spark.storage.unrollFraction", 0.2)).toLong + } + + override def acquireStorageMemory( + blockId: BlockId, + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { + if (numBytes > maxStorageMemory) { + // Fail fast if the block simply won't fit + logInfo(s"Will not store $blockId as the required space ($numBytes bytes) exceeds our " + + s"memory limit ($maxStorageMemory bytes)") + false + } else { + storageMemoryPool.acquireMemory(blockId, numBytes, evictedBlocks) + } + } + + override def acquireUnrollMemory( + blockId: BlockId, + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { + val currentUnrollMemory = storageMemoryPool.memoryStore.currentUnrollMemory + val freeMemory = storageMemoryPool.memoryFree + // When unrolling, we will use all of the existing free memory, and, if necessary, + // some extra space freed from evicting cached blocks. We must place a cap on the + // amount of memory to be evicted by unrolling, however, otherwise unrolling one + // big block can blow away the entire cache. 
+ val maxNumBytesToFree = math.max(0, maxUnrollMemory - currentUnrollMemory - freeMemory) + // Keep it within the range 0 <= X <= maxNumBytesToFree + val numBytesToFree = math.max(0, math.min(maxNumBytesToFree, numBytes - freeMemory)) + storageMemoryPool.acquireMemory(blockId, numBytes, numBytesToFree, evictedBlocks) + } + + private[memory] + override def acquireExecutionMemory( + numBytes: Long, + taskAttemptId: Long, + memoryMode: MemoryMode): Long = synchronized { + memoryMode match { + case MemoryMode.ON_HEAP => onHeapExecutionMemoryPool.acquireMemory(numBytes, taskAttemptId) + case MemoryMode.OFF_HEAP => offHeapExecutionMemoryPool.acquireMemory(numBytes, taskAttemptId) + } + } +} + + +private[spark] object StaticMemoryManager { + + /** + * Return the total amount of memory available for the storage region, in bytes. + */ + private def getMaxStorageMemory(conf: SparkConf): Long = { + val systemMaxMemory = conf.getLong("spark.testing.memory", Runtime.getRuntime.maxMemory) + val memoryFraction = conf.getDouble("spark.storage.memoryFraction", 0.6) + val safetyFraction = conf.getDouble("spark.storage.safetyFraction", 0.9) + (systemMaxMemory * memoryFraction * safetyFraction).toLong + } + + /** + * Return the total amount of memory available for the execution region, in bytes. + */ + private def getMaxExecutionMemory(conf: SparkConf): Long = { + val systemMaxMemory = conf.getLong("spark.testing.memory", Runtime.getRuntime.maxMemory) + val memoryFraction = conf.getDouble("spark.shuffle.memoryFraction", 0.2) + val safetyFraction = conf.getDouble("spark.shuffle.safetyFraction", 0.8) + (systemMaxMemory * memoryFraction * safetyFraction).toLong + } + +} diff --git a/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala b/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala new file mode 100644 index 000000000000..70af83b5ee09 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory + +import javax.annotation.concurrent.GuardedBy + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{TaskContext, Logging} +import org.apache.spark.storage.{MemoryStore, BlockStatus, BlockId} + +/** + * Performs bookkeeping for managing an adjustable-size pool of memory that is used for storage + * (caching). 
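// [Illustrative arithmetic, not part of the patch] The static region sizes computed by the
// StaticMemoryManager companion above, worked through for an assumed 4 GiB heap with the
// default fractions.
object StaticSizingExample {
  def main(args: Array[String]): Unit = {
    val systemMaxMemory = 4L << 30                            // assumed Runtime.maxMemory
    val maxStorage = (systemMaxMemory * 0.6 * 0.9).toLong     // storage.memoryFraction * safetyFraction
    val maxExecution = (systemMaxMemory * 0.2 * 0.8).toLong   // shuffle.memoryFraction * safetyFraction
    println(s"storage region: $maxStorage bytes, execution region: $maxExecution bytes")
  }
}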
+ * + * @param lock a [[MemoryManager]] instance to synchronize on + */ +private[memory] class StorageMemoryPool(lock: Object) extends MemoryPool(lock) with Logging { + + @GuardedBy("lock") + private[this] var _memoryUsed: Long = 0L + + override def memoryUsed: Long = lock.synchronized { + _memoryUsed + } + + private var _memoryStore: MemoryStore = _ + def memoryStore: MemoryStore = { + if (_memoryStore == null) { + throw new IllegalStateException("memory store not initialized yet") + } + _memoryStore + } + + /** + * Set the [[MemoryStore]] used by this manager to evict cached blocks. + * This must be set after construction due to initialization ordering constraints. + */ + final def setMemoryStore(store: MemoryStore): Unit = { + _memoryStore = store + } + + /** + * Acquire N bytes of memory to cache the given block, evicting existing ones if necessary. + * Blocks evicted in the process, if any, are added to `evictedBlocks`. + * @return whether all N bytes were successfully granted. + */ + def acquireMemory( + blockId: BlockId, + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = lock.synchronized { + val numBytesToFree = math.max(0, numBytes - memoryFree) + acquireMemory(blockId, numBytes, numBytesToFree, evictedBlocks) + } + + /** + * Acquire N bytes of storage memory for the given block, evicting existing ones if necessary. + * + * @param blockId the ID of the block we are acquiring storage memory for + * @param numBytesToAcquire the size of this block + * @param numBytesToFree the amount of space to be freed through evicting blocks + * @return whether all N bytes were successfully granted. + */ + def acquireMemory( + blockId: BlockId, + numBytesToAcquire: Long, + numBytesToFree: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = lock.synchronized { + assert(numBytesToAcquire >= 0) + assert(numBytesToFree >= 0) + assert(memoryUsed <= poolSize) + if (numBytesToFree > 0) { + memoryStore.evictBlocksToFreeSpace(Some(blockId), numBytesToFree, evictedBlocks) + // Register evicted blocks, if any, with the active task metrics + Option(TaskContext.get()).foreach { tc => + val metrics = tc.taskMetrics() + val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]()) + metrics.updatedBlocks = Some(lastUpdatedBlocks ++ evictedBlocks.toSeq) + } + } + // NOTE: If the memory store evicts blocks, then those evictions will synchronously call + // back into this StorageMemoryPool in order to free memory. Therefore, these variables + // should have been updated. + val enoughMemory = numBytesToAcquire <= memoryFree + if (enoughMemory) { + _memoryUsed += numBytesToAcquire + } + enoughMemory + } + + def releaseMemory(size: Long): Unit = lock.synchronized { + if (size > _memoryUsed) { + logWarning(s"Attempted to release $size bytes of storage " + + s"memory when we only have ${_memoryUsed} bytes") + _memoryUsed = 0 + } else { + _memoryUsed -= size + } + } + + def releaseAllMemory(): Unit = lock.synchronized { + _memoryUsed = 0 + } + + /** + * Try to shrink the size of this storage memory pool by `spaceToFree` bytes. Return the number + * of bytes removed from the pool's capacity. 
+ */ + def shrinkPoolToFreeSpace(spaceToFree: Long): Long = lock.synchronized { + // First, shrink the pool by reclaiming free memory: + val spaceFreedByReleasingUnusedMemory = math.min(spaceToFree, memoryFree) + decrementPoolSize(spaceFreedByReleasingUnusedMemory) + val remainingSpaceToFree = spaceToFree - spaceFreedByReleasingUnusedMemory + if (remainingSpaceToFree > 0) { + // If reclaiming free memory did not adequately shrink the pool, begin evicting blocks: + val evictedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] + memoryStore.evictBlocksToFreeSpace(None, remainingSpaceToFree, evictedBlocks) + val spaceFreedByEviction = evictedBlocks.map(_._2.memSize).sum + // When a block is released, BlockManager.dropFromMemory() calls releaseMemory(), so we do + // not need to decrement _memoryUsed here. However, we do need to decrement the pool size. + decrementPoolSize(spaceFreedByEviction) + spaceFreedByReleasingUnusedMemory + spaceFreedByEviction + } else { + spaceFreedByReleasingUnusedMemory + } + } +} diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala new file mode 100644 index 000000000000..829f054dba0e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory + +import scala.collection.mutable + +import org.apache.spark.SparkConf +import org.apache.spark.storage.{BlockStatus, BlockId} + +/** + * A [[MemoryManager]] that enforces a soft boundary between execution and storage such that + * either side can borrow memory from the other. + * + * The region shared between execution and storage is a fraction of (the total heap space - 300MB) + * configurable through `spark.memory.fraction` (default 0.75). The position of the boundary + * within this space is further determined by `spark.memory.storageFraction` (default 0.5). + * This means the size of the storage region is 0.75 * 0.5 = 0.375 of the heap space by default. + * + * Storage can borrow as much execution memory as is free until execution reclaims its space. + * When this happens, cached blocks will be evicted from memory until sufficient borrowed + * memory is released to satisfy the execution memory request. + * + * Similarly, execution can borrow as much storage memory as is free. However, execution + * memory is *never* evicted by storage due to the complexities involved in implementing this. + * The implication is that attempts to cache blocks may fail if execution has already eaten + * up most of the storage space, in which case the new blocks will be evicted immediately + * according to their respective storage levels. 
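// [Illustrative arithmetic, not part of the patch] The region sizes described in the class
// comment above, for an assumed 1 GiB JVM heap with the default spark.memory.fraction and
// spark.memory.storageFraction; the 300 MB reservation comes from the companion object further
// down in this file.
object UnifiedSizingExample {
  def main(args: Array[String]): Unit = {
    val systemMemory = 1L << 30                                  // assumed JVM heap
    val reserved = 300L * 1024 * 1024                            // reserved, non-storage/non-execution
    val maxMemory = ((systemMemory - reserved) * 0.75).toLong    // shared region, ~543 MB
    val storageRegionSize = (maxMemory * 0.5).toLong             // soft storage boundary, ~271 MB
    println(s"maxMemory=$maxMemory storageRegionSize=$storageRegionSize")
  }
}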
+ * + * @param storageRegionSize Size of the storage region, in bytes. + * This region is not statically reserved; execution can borrow from + * it if necessary. Cached blocks can be evicted only if actual + * storage memory usage exceeds this region. + */ +private[spark] class UnifiedMemoryManager private[memory] ( + conf: SparkConf, + val maxMemory: Long, + storageRegionSize: Long, + numCores: Int) + extends MemoryManager( + conf, + numCores, + storageRegionSize, + maxMemory - storageRegionSize) { + + // We always maintain this invariant: + assert(onHeapExecutionMemoryPool.poolSize + storageMemoryPool.poolSize == maxMemory) + + override def maxStorageMemory: Long = synchronized { + maxMemory - onHeapExecutionMemoryPool.memoryUsed + } + + /** + * Try to acquire up to `numBytes` of execution memory for the current task and return the + * number of bytes obtained, or 0 if none can be allocated. + * + * This call may block until there is enough free memory in some situations, to make sure each + * task has a chance to ramp up to at least 1 / 2N of the total memory pool (where N is the # of + * active tasks) before it is forced to spill. This can happen if the number of tasks increase + * but an older task had a lot of memory already. + */ + override private[memory] def acquireExecutionMemory( + numBytes: Long, + taskAttemptId: Long, + memoryMode: MemoryMode): Long = synchronized { + assert(onHeapExecutionMemoryPool.poolSize + storageMemoryPool.poolSize == maxMemory) + assert(numBytes >= 0) + memoryMode match { + case MemoryMode.ON_HEAP => + + /** + * Grow the execution pool by evicting cached blocks, thereby shrinking the storage pool. + * + * When acquiring memory for a task, the execution pool may need to make multiple + * attempts. Each attempt must be able to evict storage in case another task jumps in + * and caches a large block between the attempts. This is called once per attempt. + */ + def maybeGrowExecutionPool(extraMemoryNeeded: Long): Unit = { + if (extraMemoryNeeded > 0) { + // There is not enough free memory in the execution pool, so try to reclaim memory from + // storage. We can reclaim any free memory from the storage pool. If the storage pool + // has grown to become larger than `storageRegionSize`, we can evict blocks and reclaim + // the memory that storage has borrowed from execution. + val memoryReclaimableFromStorage = + math.max(storageMemoryPool.memoryFree, storageMemoryPool.poolSize - storageRegionSize) + if (memoryReclaimableFromStorage > 0) { + // Only reclaim as much space as is necessary and available: + val spaceReclaimed = storageMemoryPool.shrinkPoolToFreeSpace( + math.min(extraMemoryNeeded, memoryReclaimableFromStorage)) + onHeapExecutionMemoryPool.incrementPoolSize(spaceReclaimed) + } + } + } + + /** + * The size the execution pool would have after evicting storage memory. + * + * The execution memory pool divides this quantity among the active tasks evenly to cap + * the execution memory allocation for each task. It is important to keep this greater + * than the execution pool size, which doesn't take into account potential memory that + * could be freed by evicting storage. Otherwise we may hit SPARK-12155. + * + * Additionally, this quantity should be kept below `maxMemory` to arbitrate fairness + * in execution memory allocation across tasks, Otherwise, a task may occupy more than + * its fair share of execution memory, mistakenly thinking that other tasks can acquire + * the portion of storage memory that cannot be evicted. 
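// [Illustrative arithmetic, not part of the patch] How maybeGrowExecutionPool reclaims borrowed
// storage memory, using made-up numbers: storage has grown to 700 bytes against a 500-byte
// storage region (so 200 were borrowed from execution), 100 of it is currently free, and a task
// needs 300 more bytes of execution memory.
object ReclaimFromStorageExample {
  def main(args: Array[String]): Unit = {
    val storagePoolSize = 700L
    val storageRegionSize = 500L
    val storageMemoryFree = 100L
    val extraMemoryNeeded = 300L
    val reclaimable = math.max(storageMemoryFree, storagePoolSize - storageRegionSize)  // 200
    val reclaimed = math.min(extraMemoryNeeded, reclaimable)                            // 200
    // shrinkPoolToFreeSpace would release the 100 free bytes and evict blocks worth another 100;
    // the execution pool grows by 200, and the rest of the request must wait for other tasks.
    println(s"execution pool grows by $reclaimed bytes")
  }
}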
+ */ + def computeMaxExecutionPoolSize(): Long = { + maxMemory - math.min(storageMemoryUsed, storageRegionSize) + } + + onHeapExecutionMemoryPool.acquireMemory( + numBytes, taskAttemptId, maybeGrowExecutionPool, computeMaxExecutionPoolSize) + + case MemoryMode.OFF_HEAP => + // For now, we only support on-heap caching of data, so we do not need to interact with + // the storage pool when allocating off-heap memory. This will change in the future, though. + offHeapExecutionMemoryPool.acquireMemory(numBytes, taskAttemptId) + } + } + + override def acquireStorageMemory( + blockId: BlockId, + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { + assert(onHeapExecutionMemoryPool.poolSize + storageMemoryPool.poolSize == maxMemory) + assert(numBytes >= 0) + if (numBytes > maxStorageMemory) { + // Fail fast if the block simply won't fit + logInfo(s"Will not store $blockId as the required space ($numBytes bytes) exceeds our " + + s"memory limit ($maxStorageMemory bytes)") + return false + } + if (numBytes > storageMemoryPool.memoryFree) { + // There is not enough free memory in the storage pool, so try to borrow free memory from + // the execution pool. + val memoryBorrowedFromExecution = Math.min(onHeapExecutionMemoryPool.memoryFree, numBytes) + onHeapExecutionMemoryPool.decrementPoolSize(memoryBorrowedFromExecution) + storageMemoryPool.incrementPoolSize(memoryBorrowedFromExecution) + } + storageMemoryPool.acquireMemory(blockId, numBytes, evictedBlocks) + } + + override def acquireUnrollMemory( + blockId: BlockId, + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { + acquireStorageMemory(blockId, numBytes, evictedBlocks) + } +} + +object UnifiedMemoryManager { + + // Set aside a fixed amount of memory for non-storage, non-execution purposes. + // This serves a function similar to `spark.memory.fraction`, but guarantees that we reserve + // sufficient memory for the system even for small heaps. E.g. if we have a 1GB JVM, then + // the memory used for execution and storage will be (1024 - 300) * 0.75 = 543MB by default. + private val RESERVED_SYSTEM_MEMORY_BYTES = 300 * 1024 * 1024 + + def apply(conf: SparkConf, numCores: Int): UnifiedMemoryManager = { + val maxMemory = getMaxMemory(conf) + new UnifiedMemoryManager( + conf, + maxMemory = maxMemory, + storageRegionSize = + (maxMemory * conf.getDouble("spark.memory.storageFraction", 0.5)).toLong, + numCores = numCores) + } + + /** + * Return the total amount of memory shared between execution and storage, in bytes. + */ + private def getMaxMemory(conf: SparkConf): Long = { + val systemMemory = conf.getLong("spark.testing.memory", Runtime.getRuntime.maxMemory) + val reservedMemory = conf.getLong("spark.testing.reservedMemory", + if (conf.contains("spark.testing")) 0 else RESERVED_SYSTEM_MEMORY_BYTES) + val minSystemMemory = reservedMemory * 1.5 + if (systemMemory < minSystemMemory) { + throw new IllegalArgumentException(s"System memory $systemMemory must " + + s"be at least $minSystemMemory. 
Please use a larger heap size.") + } + val usableMemory = systemMemory - reservedMemory + val memoryFraction = conf.getDouble("spark.memory.fraction", 0.75) + (usableMemory * memoryFraction).toLong + } +} diff --git a/core/src/main/scala/org/apache/spark/memory/package.scala b/core/src/main/scala/org/apache/spark/memory/package.scala new file mode 100644 index 000000000000..3d00cd9cb637 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/memory/package.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +/** + * This package implements Spark's memory management system. This system consists of two main + * components, a JVM-wide memory manager and a per-task manager: + * + * - [[org.apache.spark.memory.MemoryManager]] manages Spark's overall memory usage within a JVM. + * This component implements the policies for dividing the available memory across tasks and for + * allocating memory between storage (memory used caching and data transfer) and execution + * (memory used by computations, such as shuffles, joins, sorts, and aggregations). + * - [[org.apache.spark.memory.TaskMemoryManager]] manages the memory allocated by individual + * tasks. Tasks interact with TaskMemoryManager and never directly interact with the JVM-wide + * MemoryManager. + * + * Internally, each of these components have additional abstractions for memory bookkeeping: + * + * - [[org.apache.spark.memory.MemoryConsumer]]s are clients of the TaskMemoryManager and + * correspond to individual operators and data structures within a task. The TaskMemoryManager + * receives memory allocation requests from MemoryConsumers and issues callbacks to consumers + * in order to trigger spilling when running low on memory. + * - [[org.apache.spark.memory.MemoryPool]]s are a bookkeeping abstraction used by the + * MemoryManager to track the division of memory between storage and execution. 
+ * + * Diagrammatically: + * + * {{{ + * +-------------+ + * | MemConsumer |----+ +------------------------+ + * +-------------+ | +-------------------+ | MemoryManager | + * +--->| TaskMemoryManager |----+ | | + * +-------------+ | +-------------------+ | | +------------------+ | + * | MemConsumer |----+ | | | StorageMemPool | | + * +-------------+ +-------------------+ | | +------------------+ | + * | TaskMemoryManager |----+ | | + * +-------------------+ | | +------------------+ | + * +---->| |OnHeapExecMemPool | | + * * | | +------------------+ | + * * | | | + * +-------------+ * | | +------------------+ | + * | MemConsumer |----+ | | |OffHeapExecMemPool| | + * +-------------+ | +-------------------+ | | +------------------+ | + * +--->| TaskMemoryManager |----+ | | + * +-------------------+ +------------------------+ + * }}} + * + * + * There are two implementations of [[org.apache.spark.memory.MemoryManager]] which vary in how + * they handle the sizing of their memory pools: + * + * - [[org.apache.spark.memory.UnifiedMemoryManager]], the default in Spark 1.6+, enforces soft + * boundaries between storage and execution memory, allowing requests for memory in one region + * to be fulfilled by borrowing memory from the other. + * - [[org.apache.spark.memory.StaticMemoryManager]] enforces hard boundaries between storage + * and execution memory by statically partitioning Spark's memory and preventing storage and + * execution from borrowing memory from each other. This mode is retained only for legacy + * compatibility purposes. + */ +package object memory diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala index d7495551ad23..dd2d325d8703 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala @@ -20,6 +20,7 @@ package org.apache.spark.metrics import java.io.{FileInputStream, InputStream} import java.util.Properties +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.matching.Regex @@ -58,25 +59,20 @@ private[spark] class MetricsConfig(conf: SparkConf) extends Logging { propertyCategories = subProperties(properties, INSTANCE_REGEX) if (propertyCategories.contains(DEFAULT_PREFIX)) { - import scala.collection.JavaConversions._ - - val defaultProperty = propertyCategories(DEFAULT_PREFIX) - for { (inst, prop) <- propertyCategories - if (inst != DEFAULT_PREFIX) - (k, v) <- defaultProperty - if (prop.getProperty(k) == null) } { - prop.setProperty(k, v) + val defaultProperty = propertyCategories(DEFAULT_PREFIX).asScala + for((inst, prop) <- propertyCategories if (inst != DEFAULT_PREFIX); + (k, v) <- defaultProperty if (prop.get(k) == null)) { + prop.put(k, v) } } } def subProperties(prop: Properties, regex: Regex): mutable.HashMap[String, Properties] = { val subProperties = new mutable.HashMap[String, Properties] - import scala.collection.JavaConversions._ - prop.foreach { kv => - if (regex.findPrefixOf(kv._1).isDefined) { - val regex(prefix, suffix) = kv._1 - subProperties.getOrElseUpdate(prefix, new Properties).setProperty(suffix, kv._2) + prop.asScala.foreach { kv => + if (regex.findPrefixOf(kv._1.toString).isDefined) { + val regex(prefix, suffix) = kv._1.toString + subProperties.getOrElseUpdate(prefix, new Properties).setProperty(suffix, kv._2.toString) } } subProperties diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala 
b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 4517f465ebd3..fdf76d312db3 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -88,7 +88,7 @@ private[spark] class MetricsSystem private ( */ def getServletHandlers: Array[ServletContextHandler] = { require(running, "Can only call getServletHandlers on a running MetricsSystem") - metricsServlet.map(_.getHandlers).getOrElse(Array()) + metricsServlet.map(_.getHandlers(conf)).getOrElse(Array()) } metricsConfig.initialize() @@ -197,7 +197,7 @@ private[spark] class MetricsSystem private ( } } catch { case e: Exception => { - logError("Sink class " + classPath + " cannot be instantialized") + logError("Sink class " + classPath + " cannot be instantiated") throw e } } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala index 0c2e212a3307..4193e1d21d3c 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala @@ -27,7 +27,7 @@ import com.codahale.metrics.json.MetricsModule import com.fasterxml.jackson.databind.ObjectMapper import org.eclipse.jetty.servlet.ServletContextHandler -import org.apache.spark.SecurityManager +import org.apache.spark.{SparkConf, SecurityManager} import org.apache.spark.ui.JettyUtils._ private[spark] class MetricsServlet( @@ -49,10 +49,10 @@ private[spark] class MetricsServlet( val mapper = new ObjectMapper().registerModule( new MetricsModule(TimeUnit.SECONDS, TimeUnit.MILLISECONDS, servletShowSample)) - def getHandlers: Array[ServletContextHandler] = { + def getHandlers(conf: SparkConf): Array[ServletContextHandler] = { Array[ServletContextHandler]( createServletHandler(servletPath, - new ServletParams(request => getMetricsSnapshot(request), "text/json"), securityMgr) + new ServletParams(request => getMetricsSnapshot(request), "text/json"), securityMgr, conf) ) } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index b089da8596e2..df8c21fb837e 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -19,7 +19,7 @@ package org.apache.spark.network.netty import java.nio.ByteBuffer -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.Logging import org.apache.spark.network.BlockDataManager @@ -38,6 +38,7 @@ import org.apache.spark.storage.{BlockId, StorageLevel} * is equivalent to one Spark-level shuffle block. 
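// [Illustrative sketch, not part of the patch] The JavaConverters idiom adopted in the
// MetricsConfig hunk above: an explicit .asScala view over java.util.Properties instead of the
// implicit JavaConversions wrappers. The property key and value are arbitrary examples.
import java.util.Properties
import scala.collection.JavaConverters._

object AsScalaExample {
  def main(args: Array[String]): Unit = {
    val props = new Properties()
    props.setProperty("*.sink.servlet.class", "org.apache.spark.metrics.sink.MetricsServlet")
    // asScala returns a mutable Map[String, String] view backed by the Properties object
    props.asScala.foreach { case (k, v) => println(s"$k = $v") }
  }
}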
*/ class NettyBlockRpcServer( + appId: String, serializer: Serializer, blockManager: BlockDataManager) extends RpcHandler with Logging { @@ -46,18 +47,18 @@ class NettyBlockRpcServer( override def receive( client: TransportClient, - messageBytes: Array[Byte], + rpcMessage: ByteBuffer, responseContext: RpcResponseCallback): Unit = { - val message = BlockTransferMessage.Decoder.fromByteArray(messageBytes) + val message = BlockTransferMessage.Decoder.fromByteBuffer(rpcMessage) logTrace(s"Received request: $message") message match { case openBlocks: OpenBlocks => val blocks: Seq[ManagedBuffer] = openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData) - val streamId = streamManager.registerStream(blocks.iterator) + val streamId = streamManager.registerStream(appId, blocks.iterator.asJava) logTrace(s"Registered streamId $streamId with ${blocks.size} buffers") - responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray) + responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteBuffer) case uploadBlock: UploadBlock => // StorageLevel is serialized as bytes using our JavaSerializer. @@ -65,7 +66,7 @@ class NettyBlockRpcServer( serializer.newInstance().deserialize(ByteBuffer.wrap(uploadBlock.metadata)) val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData)) blockManager.putBlockData(BlockId(uploadBlock.blockId), data, level) - responseContext.onSuccess(new Array[Byte](0)) + responseContext.onSuccess(ByteBuffer.allocate(0)) } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index d650d5fe7308..40604a4da18d 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -17,7 +17,9 @@ package org.apache.spark.network.netty -import scala.collection.JavaConversions._ +import java.nio.ByteBuffer + +import scala.collection.JavaConverters._ import scala.concurrent.{Future, Promise} import org.apache.spark.{SecurityManager, SparkConf} @@ -28,6 +30,7 @@ import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap} import org.apache.spark.network.server._ import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher} import org.apache.spark.network.shuffle.protocol.UploadBlock +import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.JavaSerializer import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.util.Utils @@ -41,7 +44,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage // TODO: Don't use Java serialization, use a more cross-version compatible serialization format. 
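// [Illustrative sketch, not part of the patch] A round trip through the ByteBuffer-based
// protocol used by the RPC handler above. OpenBlocks and the Decoder live in the
// network-shuffle protocol package referenced in these files; the application, executor and
// block ids are made up.
import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, OpenBlocks}

object ProtocolRoundTrip {
  def main(args: Array[String]): Unit = {
    val msg = new OpenBlocks("app-20151101", "exec-1", Array("shuffle_0_0_0"))
    val buf = msg.toByteBuffer                                      // what client.sendRpc(...) sends
    val decoded = BlockTransferMessage.Decoder.fromByteBuffer(buf)  // what receive() does first
    println(decoded)
  }
}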
private val serializer = new JavaSerializer(conf) private val authEnabled = securityManager.isAuthenticationEnabled() - private val transportConf = SparkTransportConf.fromSparkConf(conf, numCores) + private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numCores) private[this] var transportContext: TransportContext = _ private[this] var server: TransportServer = _ @@ -49,7 +52,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage private[this] var appId: String = _ override def init(blockDataManager: BlockDataManager): Unit = { - val rpcHandler = new NettyBlockRpcServer(serializer, blockDataManager) + val rpcHandler = new NettyBlockRpcServer(conf.getAppId, serializer, blockDataManager) var serverBootstrap: Option[TransportServerBootstrap] = None var clientBootstrap: Option[TransportClientBootstrap] = None if (authEnabled) { @@ -58,7 +61,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage securityManager.isSaslEncryptionEnabled())) } transportContext = new TransportContext(transportConf, rpcHandler) - clientFactory = transportContext.createClientFactory(clientBootstrap.toList) + clientFactory = transportContext.createClientFactory(clientBootstrap.toSeq.asJava) server = createServer(serverBootstrap.toList) appId = conf.getAppId logInfo("Server created on " + server.getPort) @@ -67,7 +70,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage /** Creates and binds the TransportServer, possibly trying multiple ports. */ private def createServer(bootstraps: List[TransportServerBootstrap]): TransportServer = { def startService(port: Int): (TransportServer, Int) = { - val server = transportContext.createServer(port, bootstraps) + val server = transportContext.createServer(port, bootstraps.asJava) (server, server.getPort) } @@ -121,23 +124,16 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage // StorageLevel is serialized as bytes using our JavaSerializer. Everything else is encoded // using our binary protocol. - val levelBytes = serializer.newInstance().serialize(level).array() + val levelBytes = JavaUtils.bufferToArray(serializer.newInstance().serialize(level)) // Convert or copy nio buffer into array in order to serialize it. 
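// [Illustrative sketch, not part of the patch] With the module name now threaded through
// SparkTransportConf.fromSparkConf (used above as fromSparkConf(conf, "shuffle", numCores) and
// defined in the SparkTransportConf hunk below), the io thread counts are read from
// spark.<module>.io.* keys and only defaulted when missing. The values here are examples.
import org.apache.spark.SparkConf

object TransportThreadsExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .set("spark.shuffle.io.serverThreads", "8")
      .set("spark.shuffle.io.clientThreads", "8")
    // setIfMissing in fromSparkConf will leave these explicit settings untouched.
    println(conf.get("spark.shuffle.io.serverThreads"))
  }
}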
- val nioBuffer = blockData.nioByteBuffer() - val array = if (nioBuffer.hasArray) { - nioBuffer.array() - } else { - val data = new Array[Byte](nioBuffer.remaining()) - nioBuffer.get(data) - data - } + val array = JavaUtils.bufferToArray(blockData.nioByteBuffer()) - client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteArray, + client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteBuffer, new RpcResponseCallback { - override def onSuccess(response: Array[Byte]): Unit = { + override def onSuccess(response: ByteBuffer): Unit = { logTrace(s"Successfully uploaded block $blockId") - result.success() + result.success((): Unit) } override def onFailure(e: Throwable): Unit = { logError(s"Error while uploading block $blockId", e) @@ -149,7 +145,11 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage } override def close(): Unit = { - server.close() - clientFactory.close() + if (server != null) { + server.close() + } + if (clientFactory != null) { + clientFactory.close() + } } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala index cef203006d68..84833f59d7af 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala @@ -40,23 +40,23 @@ object SparkTransportConf { /** * Utility for creating a [[TransportConf]] from a [[SparkConf]]. + * @param _conf the [[SparkConf]] + * @param module the module name * @param numUsableCores if nonzero, this will restrict the server and client threads to only * use the given number of cores, rather than all of the machine's cores. * This restriction will only occur if these properties are not already set. */ - def fromSparkConf(_conf: SparkConf, numUsableCores: Int = 0): TransportConf = { + def fromSparkConf(_conf: SparkConf, module: String, numUsableCores: Int = 0): TransportConf = { val conf = _conf.clone // Specify thread configuration based on our JVM's allocation of cores (rather than necessarily // assuming we have all the machine's cores). // NB: Only set if serverThreads/clientThreads not already set. val numThreads = defaultNumThreads(numUsableCores) - conf.set("spark.shuffle.io.serverThreads", - conf.get("spark.shuffle.io.serverThreads", numThreads.toString)) - conf.set("spark.shuffle.io.clientThreads", - conf.get("spark.shuffle.io.clientThreads", numThreads.toString)) + conf.setIfMissing(s"spark.$module.io.serverThreads", numThreads.toString) + conf.setIfMissing(s"spark.$module.io.clientThreads", numThreads.toString) - new TransportConf(new ConfigProvider { + new TransportConf(module, new ConfigProvider { override def get(name: String): String = conf.get(name) }) } diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala deleted file mode 100644 index 79cb0640c867..000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala +++ /dev/null @@ -1,175 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.nio.ByteBuffer - -import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} - -import scala.collection.mutable.{ArrayBuffer, StringBuilder} - -// private[spark] because we need to register them in Kryo -private[spark] case class GetBlock(id: BlockId) -private[spark] case class GotBlock(id: BlockId, data: ByteBuffer) -private[spark] case class PutBlock(id: BlockId, data: ByteBuffer, level: StorageLevel) - -private[nio] class BlockMessage() { - // Un-initialized: typ = 0 - // GetBlock: typ = 1 - // GotBlock: typ = 2 - // PutBlock: typ = 3 - private var typ: Int = BlockMessage.TYPE_NON_INITIALIZED - private var id: BlockId = null - private var data: ByteBuffer = null - private var level: StorageLevel = null - - def set(getBlock: GetBlock) { - typ = BlockMessage.TYPE_GET_BLOCK - id = getBlock.id - } - - def set(gotBlock: GotBlock) { - typ = BlockMessage.TYPE_GOT_BLOCK - id = gotBlock.id - data = gotBlock.data - } - - def set(putBlock: PutBlock) { - typ = BlockMessage.TYPE_PUT_BLOCK - id = putBlock.id - data = putBlock.data - level = putBlock.level - } - - def set(buffer: ByteBuffer) { - typ = buffer.getInt() - val idLength = buffer.getInt() - val idBuilder = new StringBuilder(idLength) - for (i <- 1 to idLength) { - idBuilder += buffer.getChar() - } - id = BlockId(idBuilder.toString) - - if (typ == BlockMessage.TYPE_PUT_BLOCK) { - - val booleanInt = buffer.getInt() - val replication = buffer.getInt() - level = StorageLevel(booleanInt, replication) - - val dataLength = buffer.getInt() - data = ByteBuffer.allocate(dataLength) - if (dataLength != buffer.remaining) { - throw new Exception("Error parsing buffer") - } - data.put(buffer) - data.flip() - } else if (typ == BlockMessage.TYPE_GOT_BLOCK) { - - val dataLength = buffer.getInt() - data = ByteBuffer.allocate(dataLength) - if (dataLength != buffer.remaining) { - throw new Exception("Error parsing buffer") - } - data.put(buffer) - data.flip() - } - - } - - def set(bufferMsg: BufferMessage) { - val buffer = bufferMsg.buffers.apply(0) - buffer.clear() - set(buffer) - } - - def getType: Int = typ - def getId: BlockId = id - def getData: ByteBuffer = data - def getLevel: StorageLevel = level - - def toBufferMessage: BufferMessage = { - val buffers = new ArrayBuffer[ByteBuffer]() - var buffer = ByteBuffer.allocate(4 + 4 + id.name.length * 2) - buffer.putInt(typ).putInt(id.name.length) - id.name.foreach((x: Char) => buffer.putChar(x)) - buffer.flip() - buffers += buffer - - if (typ == BlockMessage.TYPE_PUT_BLOCK) { - buffer = ByteBuffer.allocate(8).putInt(level.toInt).putInt(level.replication) - buffer.flip() - buffers += buffer - - buffer = ByteBuffer.allocate(4).putInt(data.remaining) - buffer.flip() - buffers += buffer - - buffers += data - } else if (typ == BlockMessage.TYPE_GOT_BLOCK) { - buffer = ByteBuffer.allocate(4).putInt(data.remaining) - buffer.flip() - buffers += buffer - - buffers += data - 
} - - Message.createBufferMessage(buffers) - } - - override def toString: String = { - "BlockMessage [type = " + typ + ", id = " + id + ", level = " + level + - ", data = " + (if (data != null) data.remaining.toString else "null") + "]" - } -} - -private[nio] object BlockMessage { - val TYPE_NON_INITIALIZED: Int = 0 - val TYPE_GET_BLOCK: Int = 1 - val TYPE_GOT_BLOCK: Int = 2 - val TYPE_PUT_BLOCK: Int = 3 - - def fromBufferMessage(bufferMessage: BufferMessage): BlockMessage = { - val newBlockMessage = new BlockMessage() - newBlockMessage.set(bufferMessage) - newBlockMessage - } - - def fromByteBuffer(buffer: ByteBuffer): BlockMessage = { - val newBlockMessage = new BlockMessage() - newBlockMessage.set(buffer) - newBlockMessage - } - - def fromGetBlock(getBlock: GetBlock): BlockMessage = { - val newBlockMessage = new BlockMessage() - newBlockMessage.set(getBlock) - newBlockMessage - } - - def fromGotBlock(gotBlock: GotBlock): BlockMessage = { - val newBlockMessage = new BlockMessage() - newBlockMessage.set(gotBlock) - newBlockMessage - } - - def fromPutBlock(putBlock: PutBlock): BlockMessage = { - val newBlockMessage = new BlockMessage() - newBlockMessage.set(putBlock) - newBlockMessage - } -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala deleted file mode 100644 index f1c9ea8b64ca..000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network.nio - -import java.nio.ByteBuffer - -import org.apache.spark._ -import org.apache.spark.storage.{StorageLevel, TestBlockId} - -import scala.collection.mutable.ArrayBuffer - -private[nio] -class BlockMessageArray(var blockMessages: Seq[BlockMessage]) - extends Seq[BlockMessage] with Logging { - - def this(bm: BlockMessage) = this(Array(bm)) - - def this() = this(null.asInstanceOf[Seq[BlockMessage]]) - - def apply(i: Int): BlockMessage = blockMessages(i) - - def iterator: Iterator[BlockMessage] = blockMessages.iterator - - def length: Int = blockMessages.length - - def set(bufferMessage: BufferMessage) { - val startTime = System.currentTimeMillis - val newBlockMessages = new ArrayBuffer[BlockMessage]() - val buffer = bufferMessage.buffers(0) - buffer.clear() - while (buffer.remaining() > 0) { - val size = buffer.getInt() - logDebug("Creating block message of size " + size + " bytes") - val newBuffer = buffer.slice() - newBuffer.clear() - newBuffer.limit(size) - logDebug("Trying to convert buffer " + newBuffer + " to block message") - val newBlockMessage = BlockMessage.fromByteBuffer(newBuffer) - logDebug("Created " + newBlockMessage) - newBlockMessages += newBlockMessage - buffer.position(buffer.position() + size) - } - val finishTime = System.currentTimeMillis - logDebug("Converted block message array from buffer message in " + - (finishTime - startTime) / 1000.0 + " s") - this.blockMessages = newBlockMessages - } - - def toBufferMessage: BufferMessage = { - val buffers = new ArrayBuffer[ByteBuffer]() - - blockMessages.foreach(blockMessage => { - val bufferMessage = blockMessage.toBufferMessage - logDebug("Adding " + blockMessage) - val sizeBuffer = ByteBuffer.allocate(4).putInt(bufferMessage.size) - sizeBuffer.flip - buffers += sizeBuffer - buffers ++= bufferMessage.buffers - logDebug("Added " + bufferMessage) - }) - - logDebug("Buffer list:") - buffers.foreach((x: ByteBuffer) => logDebug("" + x)) - Message.createBufferMessage(buffers) - } -} - -private[nio] object BlockMessageArray extends Logging { - - def fromBufferMessage(bufferMessage: BufferMessage): BlockMessageArray = { - val newBlockMessageArray = new BlockMessageArray() - newBlockMessageArray.set(bufferMessage) - newBlockMessageArray - } - - def main(args: Array[String]) { - val blockMessages = - (0 until 10).map { i => - if (i % 2 == 0) { - val buffer = ByteBuffer.allocate(100) - buffer.clear() - BlockMessage.fromPutBlock(PutBlock(TestBlockId(i.toString), buffer, - StorageLevel.MEMORY_ONLY_SER)) - } else { - BlockMessage.fromGetBlock(GetBlock(TestBlockId(i.toString))) - } - } - val blockMessageArray = new BlockMessageArray(blockMessages) - logDebug("Block message array created") - - val bufferMessage = blockMessageArray.toBufferMessage - logDebug("Converted to buffer message") - - val totalSize = bufferMessage.size - val newBuffer = ByteBuffer.allocate(totalSize) - newBuffer.clear() - bufferMessage.buffers.foreach(buffer => { - assert (0 == buffer.position()) - newBuffer.put(buffer) - buffer.rewind() - }) - newBuffer.flip - val newBufferMessage = Message.createBufferMessage(newBuffer) - logDebug("Copied to new buffer message, size = " + newBufferMessage.size) - - val newBlockMessageArray = BlockMessageArray.fromBufferMessage(newBufferMessage) - logDebug("Converted back to block message array") - // scalastyle:off println - newBlockMessageArray.foreach(blockMessage => { - blockMessage.getType match { - case BlockMessage.TYPE_PUT_BLOCK => { - val pB = PutBlock(blockMessage.getId, 
blockMessage.getData, blockMessage.getLevel) - println(pB) - } - case BlockMessage.TYPE_GET_BLOCK => { - val gB = new GetBlock(blockMessage.getId) - println(gB) - } - } - }) - // scalastyle:on println - } -} - - diff --git a/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala deleted file mode 100644 index 9a9e22b0c236..000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.nio.ByteBuffer - -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.storage.BlockManager - - -private[nio] -class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int) - extends Message(Message.BUFFER_MESSAGE, id_) { - - val initialSize = currentSize() - var gotChunkForSendingOnce = false - - def size: Int = initialSize - - def currentSize(): Int = { - if (buffers == null || buffers.isEmpty) { - 0 - } else { - buffers.map(_.remaining).reduceLeft(_ + _) - } - } - - def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = { - if (maxChunkSize <= 0) { - throw new Exception("Max chunk size is " + maxChunkSize) - } - - val security = if (isSecurityNeg) 1 else 0 - if (size == 0 && !gotChunkForSendingOnce) { - val newChunk = new MessageChunk( - new MessageChunkHeader(typ, id, 0, 0, ackId, hasError, security, senderAddress), null) - gotChunkForSendingOnce = true - return Some(newChunk) - } - - while(!buffers.isEmpty) { - val buffer = buffers(0) - if (buffer.remaining == 0) { - BlockManager.dispose(buffer) - buffers -= buffer - } else { - val newBuffer = if (buffer.remaining <= maxChunkSize) { - buffer.duplicate() - } else { - buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer] - } - buffer.position(buffer.position + newBuffer.remaining) - val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, - hasError, security, senderAddress), newBuffer) - gotChunkForSendingOnce = true - return Some(newChunk) - } - } - None - } - - def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = { - // STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer - if (buffers.size > 1) { - throw new Exception("Attempting to get chunk from message with multiple data buffers") - } - val buffer = buffers(0) - val security = if (isSecurityNeg) 1 else 0 - if (buffer.remaining > 0) { - if (buffer.remaining < chunkSize) { - throw new Exception("Not enough space in data buffer for receiving chunk") - } - val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer] - buffer.position(buffer.position + 
newBuffer.remaining) - val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, hasError, security, senderAddress), newBuffer) - return Some(newChunk) - } - None - } - - def flip() { - buffers.foreach(_.flip) - } - - def hasAckId(): Boolean = ackId != 0 - - def isCompletelyReceived: Boolean = !buffers(0).hasRemaining - - override def toString: String = { - if (hasAckId) { - "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")" - } else { - "BufferMessage(id = " + id + ", size = " + size + ")" - } - } -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala deleted file mode 100644 index 1499da07bb83..000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala +++ /dev/null @@ -1,619 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.net._ -import java.nio._ -import java.nio.channels._ -import java.util.concurrent.ConcurrentLinkedQueue -import java.util.LinkedList - -import scala.collection.JavaConversions._ -import scala.collection.mutable.{ArrayBuffer, HashMap} -import scala.util.control.NonFatal - -import org.apache.spark._ -import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer} - -private[nio] -abstract class Connection(val channel: SocketChannel, val selector: Selector, - val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId, - val securityMgr: SecurityManager) - extends Logging { - - var sparkSaslServer: SparkSaslServer = null - var sparkSaslClient: SparkSaslClient = null - - def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId, - securityMgr_ : SecurityManager) = { - this(channel_, selector_, - ConnectionManagerId.fromSocketAddress( - channel_.socket.getRemoteSocketAddress.asInstanceOf[InetSocketAddress]), - id_, securityMgr_) - } - - channel.configureBlocking(false) - channel.socket.setTcpNoDelay(true) - channel.socket.setReuseAddress(true) - channel.socket.setKeepAlive(true) - /* channel.socket.setReceiveBufferSize(32768) */ - - @volatile private var closed = false - var onCloseCallback: Connection => Unit = null - val onExceptionCallbacks = new ConcurrentLinkedQueue[(Connection, Throwable) => Unit] - var onKeyInterestChangeCallback: (Connection, Int) => Unit = null - - val remoteAddress = getRemoteAddress() - - def isSaslComplete(): Boolean - - def resetForceReregister(): Boolean - - // Read channels typically do not register for write and write does not for read - // Now, we do have write registering for read too (temporarily), but this is to detect - // channel close NOT to actually read/consume data on it ! 
- // How does this work if/when we move to SSL ? - - // What is the interest to register with selector for when we want this connection to be selected - def registerInterest() - - // What is the interest to register with selector for when we want this connection to - // be de-selected - // Traditionally, 0 - but in our case, for example, for close-detection on SendingConnection hack, - // it will be SelectionKey.OP_READ (until we fix it properly) - def unregisterInterest() - - // On receiving a read event, should we change the interest for this channel or not ? - // Will be true for ReceivingConnection, false for SendingConnection. - def changeInterestForRead(): Boolean - - private def disposeSasl() { - if (sparkSaslServer != null) { - sparkSaslServer.dispose() - } - - if (sparkSaslClient != null) { - sparkSaslClient.dispose() - } - } - - // On receiving a write event, should we change the interest for this channel or not ? - // Will be false for ReceivingConnection, true for SendingConnection. - // Actually, for now, should not get triggered for ReceivingConnection - def changeInterestForWrite(): Boolean - - def getRemoteConnectionManagerId(): ConnectionManagerId = { - socketRemoteConnectionManagerId - } - - def key(): SelectionKey = channel.keyFor(selector) - - def getRemoteAddress(): InetSocketAddress = { - channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress] - } - - // Returns whether we have to register for further reads or not. - def read(): Boolean = { - throw new UnsupportedOperationException( - "Cannot read on connection of type " + this.getClass.toString) - } - - // Returns whether we have to register for further writes or not. - def write(): Boolean = { - throw new UnsupportedOperationException( - "Cannot write on connection of type " + this.getClass.toString) - } - - def close() { - closed = true - val k = key() - if (k != null) { - k.cancel() - } - channel.close() - disposeSasl() - callOnCloseCallback() - } - - protected def isClosed: Boolean = closed - - def onClose(callback: Connection => Unit) { - onCloseCallback = callback - } - - def onException(callback: (Connection, Throwable) => Unit) { - onExceptionCallbacks.add(callback) - } - - def onKeyInterestChange(callback: (Connection, Int) => Unit) { - onKeyInterestChangeCallback = callback - } - - def callOnExceptionCallbacks(e: Throwable) { - onExceptionCallbacks foreach { - callback => - try { - callback(this, e) - } catch { - case NonFatal(e) => { - logWarning("Ignored error in onExceptionCallback", e) - } - } - } - } - - def callOnCloseCallback() { - if (onCloseCallback != null) { - onCloseCallback(this) - } else { - logWarning("Connection to " + getRemoteConnectionManagerId() + - " closed and OnExceptionCallback not registered") - } - - } - - def changeConnectionKeyInterest(ops: Int) { - if (onKeyInterestChangeCallback != null) { - onKeyInterestChangeCallback(this, ops) - } else { - throw new Exception("OnKeyInterestChangeCallback not registered") - } - } - - def printRemainingBuffer(buffer: ByteBuffer) { - val bytes = new Array[Byte](buffer.remaining) - val curPosition = buffer.position - buffer.get(bytes) - bytes.foreach(x => print(x + " ")) - buffer.position(curPosition) - print(" (" + bytes.length + ")") - } - - def printBuffer(buffer: ByteBuffer, position: Int, length: Int) { - val bytes = new Array[Byte](length) - val curPosition = buffer.position - buffer.position(position) - buffer.get(bytes) - bytes.foreach(x => print(x + " ")) - print(" (" + position + ", " + length + ")") - 
buffer.position(curPosition) - } -} - - -private[nio] -class SendingConnection(val address: InetSocketAddress, selector_ : Selector, - remoteId_ : ConnectionManagerId, id_ : ConnectionId, - securityMgr_ : SecurityManager) - extends Connection(SocketChannel.open, selector_, remoteId_, id_, securityMgr_) { - - def isSaslComplete(): Boolean = { - if (sparkSaslClient != null) sparkSaslClient.isComplete() else false - } - - private class Outbox { - val messages = new LinkedList[Message]() - val defaultChunkSize = 65536 - var nextMessageToBeUsed = 0 - - def addMessage(message: Message) { - messages.synchronized { - messages.add(message) - logDebug("Added [" + message + "] to outbox for sending to " + - "[" + getRemoteConnectionManagerId() + "]") - } - } - - def getChunk(): Option[MessageChunk] = { - messages.synchronized { - while (!messages.isEmpty) { - /* nextMessageToBeUsed = nextMessageToBeUsed % messages.size */ - /* val message = messages(nextMessageToBeUsed) */ - - val message = if (securityMgr.isAuthenticationEnabled() && !isSaslComplete()) { - // only allow sending of security messages until sasl is complete - var pos = 0 - var securityMsg: Message = null - while (pos < messages.size() && securityMsg == null) { - if (messages.get(pos).isSecurityNeg) { - securityMsg = messages.remove(pos) - } - pos = pos + 1 - } - // didn't find any security messages and auth isn't completed so return - if (securityMsg == null) return None - securityMsg - } else { - messages.removeFirst() - } - - val chunk = message.getChunkForSending(defaultChunkSize) - if (chunk.isDefined) { - messages.add(message) - nextMessageToBeUsed = nextMessageToBeUsed + 1 - if (!message.started) { - logDebug( - "Starting to send [" + message + "] to [" + getRemoteConnectionManagerId() + "]") - message.started = true - message.startTime = System.currentTimeMillis - } - logTrace( - "Sending chunk from [" + message + "] to [" + getRemoteConnectionManagerId() + "]") - return chunk - } else { - message.finishTime = System.currentTimeMillis - logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + - "] in " + message.timeTaken ) - } - } - } - None - } - } - - // outbox is used as a lock - ensure that it is always used as a leaf (since methods which - // lock it are invoked in context of other locks) - private val outbox = new Outbox() - /* - This is orthogonal to whether we have pending bytes to write or not - and satisfies a slightly - different purpose. This flag is to see if we need to force reregister for write even when we - do not have any pending bytes to write to socket. - This can happen due to a race between adding pending buffers, and checking for existing of - data as detailed in https://github.com/mesos/spark/pull/791 - */ - private var needForceReregister = false - - val currentBuffers = new ArrayBuffer[ByteBuffer]() - - /* channel.socket.setSendBufferSize(256 * 1024) */ - - override def getRemoteAddress(): InetSocketAddress = address - - val DEFAULT_INTEREST = SelectionKey.OP_READ - - override def registerInterest() { - // Registering read too - does not really help in most cases, but for some - // it does - so let us keep it for now. 
- changeConnectionKeyInterest(SelectionKey.OP_WRITE | DEFAULT_INTEREST) - } - - override def unregisterInterest() { - changeConnectionKeyInterest(DEFAULT_INTEREST) - } - - def registerAfterAuth(): Unit = { - outbox.synchronized { - needForceReregister = true - } - if (channel.isConnected) { - registerInterest() - } - } - - def send(message: Message) { - outbox.synchronized { - outbox.addMessage(message) - needForceReregister = true - } - if (channel.isConnected) { - registerInterest() - } - } - - // return previous value after resetting it. - def resetForceReregister(): Boolean = { - outbox.synchronized { - val result = needForceReregister - needForceReregister = false - result - } - } - - // MUST be called within the selector loop - def connect() { - try { - channel.register(selector, SelectionKey.OP_CONNECT) - channel.connect(address) - logInfo("Initiating connection to [" + address + "]") - } catch { - case e: Exception => - logError("Error connecting to " + address, e) - callOnExceptionCallbacks(e) - } - } - - def finishConnect(force: Boolean): Boolean = { - try { - // Typically, this should finish immediately since it was triggered by a connect - // selection - though need not necessarily always complete successfully. - val connected = channel.finishConnect - if (!force && !connected) { - logInfo( - "finish connect failed [" + address + "], " + outbox.messages.size + " messages pending") - return false - } - - // Fallback to previous behavior - assume finishConnect completed - // This will happen only when finishConnect failed for some repeated number of times - // (10 or so) - // Is highly unlikely unless there was an unclean close of socket, etc - registerInterest() - logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending") - } catch { - case e: Exception => { - logWarning("Error finishing connection to " + address, e) - callOnExceptionCallbacks(e) - } - } - true - } - - override def write(): Boolean = { - try { - while (true) { - if (currentBuffers.size == 0) { - outbox.synchronized { - outbox.getChunk() match { - case Some(chunk) => { - val buffers = chunk.buffers - // If we have 'seen' pending messages, then reset flag - since we handle that as - // normal registering of event (below) - if (needForceReregister && buffers.exists(_.remaining() > 0)) resetForceReregister() - - currentBuffers ++= buffers - } - case None => { - // changeConnectionKeyInterest(0) - /* key.interestOps(0) */ - return false - } - } - } - } - - if (currentBuffers.size > 0) { - val buffer = currentBuffers(0) - val remainingBytes = buffer.remaining - val writtenBytes = channel.write(buffer) - if (buffer.remaining == 0) { - currentBuffers -= buffer - } - if (writtenBytes < remainingBytes) { - // re-register for write. - return true - } - } - } - } catch { - case e: Exception => { - logWarning("Error writing in connection to " + getRemoteConnectionManagerId(), e) - callOnExceptionCallbacks(e) - close() - return false - } - } - // should not happen - to keep scala compiler happy - true - } - - // This is a hack to determine if remote socket was closed or not. - // SendingConnection DOES NOT expect to receive any data - if it does, it is an error - // For a bunch of cases, read will return -1 in case remote socket is closed : hence we - // register for reads to determine that. - override def read(): Boolean = { - // We don't expect the other side to send anything; so, we just read to detect an error or EOF. 
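The comments above rely on a standard java.nio contract: a SocketChannel.read that returns -1 means the remote end has closed the connection, while 0 simply means no data is available on a non-blocking channel. A minimal, self-contained illustration of that contract (the helper name below is invented for this sketch and is not part of the original code):

  import java.nio.ByteBuffer
  import java.nio.channels.SocketChannel

  // Probe a non-blocking, send-only channel purely to detect EOF; any real
  // payload arriving here would be unexpected.
  def peerHasClosed(channel: SocketChannel): Boolean = {
    val probe = ByteBuffer.allocate(1)
    channel.read(probe) == -1   // -1 = EOF (peer closed), 0 = nothing to read
  }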
- try { - val length = channel.read(ByteBuffer.allocate(1)) - if (length == -1) { // EOF - close() - } else if (length > 0) { - logWarning( - "Unexpected data read from SendingConnection to " + getRemoteConnectionManagerId()) - } - } catch { - case e: Exception => - logError("Exception while reading SendingConnection to " + getRemoteConnectionManagerId(), - e) - callOnExceptionCallbacks(e) - close() - } - - false - } - - override def changeInterestForRead(): Boolean = false - - override def changeInterestForWrite(): Boolean = ! isClosed -} - - -// Must be created within selector loop - else deadlock -private[spark] class ReceivingConnection( - channel_ : SocketChannel, - selector_ : Selector, - id_ : ConnectionId, - securityMgr_ : SecurityManager) - extends Connection(channel_, selector_, id_, securityMgr_) { - - def isSaslComplete(): Boolean = { - if (sparkSaslServer != null) sparkSaslServer.isComplete() else false - } - - class Inbox() { - val messages = new HashMap[Int, BufferMessage]() - - def getChunk(header: MessageChunkHeader): Option[MessageChunk] = { - - def createNewMessage: BufferMessage = { - val newMessage = Message.create(header).asInstanceOf[BufferMessage] - newMessage.started = true - newMessage.startTime = System.currentTimeMillis - newMessage.isSecurityNeg = header.securityNeg == 1 - logDebug( - "Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]") - messages += ((newMessage.id, newMessage)) - newMessage - } - - val message = messages.getOrElseUpdate(header.id, createNewMessage) - logTrace( - "Receiving chunk of [" + message + "] from [" + getRemoteConnectionManagerId() + "]") - message.getChunkForReceiving(header.chunkSize) - } - - def getMessageForChunk(chunk: MessageChunk): Option[BufferMessage] = { - messages.get(chunk.header.id) - } - - def removeMessage(message: Message) { - messages -= message.id - } - } - - @volatile private var inferredRemoteManagerId: ConnectionManagerId = null - - override def getRemoteConnectionManagerId(): ConnectionManagerId = { - val currId = inferredRemoteManagerId - if (currId != null) currId else super.getRemoteConnectionManagerId() - } - - // The receiver's remote address is the local socket on remote side : which is NOT - // the connection manager id of the receiver. - // We infer that from the messages we receive on the receiver socket. - private def processConnectionManagerId(header: MessageChunkHeader) { - val currId = inferredRemoteManagerId - if (header.address == null || currId != null) return - - val managerId = ConnectionManagerId.fromSocketAddress(header.address) - - if (managerId != null) { - inferredRemoteManagerId = managerId - } - } - - - val inbox = new Inbox() - val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE) - var onReceiveCallback: (Connection, Message) => Unit = null - var currentChunk: MessageChunk = null - - channel.register(selector, SelectionKey.OP_READ) - - override def read(): Boolean = { - try { - while (true) { - if (currentChunk == null) { - val headerBytesRead = channel.read(headerBuffer) - if (headerBytesRead == -1) { - close() - return false - } - if (headerBuffer.remaining > 0) { - // re-register for read event ... 
- return true - } - headerBuffer.flip - if (headerBuffer.remaining != MessageChunkHeader.HEADER_SIZE) { - throw new Exception( - "Unexpected number of bytes (" + headerBuffer.remaining + ") in the header") - } - val header = MessageChunkHeader.create(headerBuffer) - headerBuffer.clear() - - processConnectionManagerId(header) - - header.typ match { - case Message.BUFFER_MESSAGE => { - if (header.totalSize == 0) { - if (onReceiveCallback != null) { - onReceiveCallback(this, Message.create(header)) - } - currentChunk = null - // re-register for read event ... - return true - } else { - currentChunk = inbox.getChunk(header).orNull - } - } - case _ => throw new Exception("Message of unknown type received") - } - } - - if (currentChunk == null) throw new Exception("No message chunk to receive data") - - val bytesRead = channel.read(currentChunk.buffer) - if (bytesRead == 0) { - // re-register for read event ... - return true - } else if (bytesRead == -1) { - close() - return false - } - - /* logDebug("Read " + bytesRead + " bytes for the buffer") */ - - if (currentChunk.buffer.remaining == 0) { - /* println("Filled buffer at " + System.currentTimeMillis) */ - val bufferMessage = inbox.getMessageForChunk(currentChunk).get - if (bufferMessage.isCompletelyReceived) { - bufferMessage.flip() - bufferMessage.finishTime = System.currentTimeMillis - logDebug("Finished receiving [" + bufferMessage + "] from " + - "[" + getRemoteConnectionManagerId() + "] in " + bufferMessage.timeTaken) - if (onReceiveCallback != null) { - onReceiveCallback(this, bufferMessage) - } - inbox.removeMessage(bufferMessage) - } - currentChunk = null - } - } - } catch { - case e: Exception => { - logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e) - callOnExceptionCallbacks(e) - close() - return false - } - } - // should not happen - to keep scala compiler happy - true - } - - def onReceive(callback: (Connection, Message) => Unit) {onReceiveCallback = callback} - - // override def changeInterestForRead(): Boolean = ! isClosed - override def changeInterestForRead(): Boolean = true - - override def changeInterestForWrite(): Boolean = { - throw new IllegalStateException("Unexpected invocation right now") - } - - override def registerInterest() { - // Registering read too - does not really help in most cases, but for some - // it does - so let us keep it for now. - changeConnectionKeyInterest(SelectionKey.OP_READ) - } - - override def unregisterInterest() { - changeConnectionKeyInterest(0) - } - - // For read conn, always false. - override def resetForceReregister(): Boolean = false -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala deleted file mode 100644 index 914391879038..000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ /dev/null @@ -1,1157 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.io.IOException -import java.lang.ref.WeakReference -import java.net._ -import java.nio._ -import java.nio.channels._ -import java.nio.channels.spi._ -import java.util.concurrent.atomic.AtomicInteger -import java.util.concurrent.{LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit} - -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, SynchronizedMap, SynchronizedQueue} -import scala.concurrent.duration._ -import scala.concurrent.{Await, ExecutionContext, Future, Promise} -import scala.language.postfixOps - -import com.google.common.base.Charsets.UTF_8 -import io.netty.util.{Timeout, TimerTask, HashedWheelTimer} - -import org.apache.spark._ -import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer} -import org.apache.spark.util.{ThreadUtils, Utils} - -import scala.util.Try -import scala.util.control.NonFatal - -private[nio] class ConnectionManager( - port: Int, - conf: SparkConf, - securityManager: SecurityManager, - name: String = "Connection manager") - extends Logging { - - /** - * Used by sendMessageReliably to track messages being sent. - * @param message the message that was sent - * @param connectionManagerId the connection manager that sent this message - * @param completionHandler callback that's invoked when the send has completed or failed - */ - class MessageStatus( - val message: Message, - val connectionManagerId: ConnectionManagerId, - completionHandler: Try[Message] => Unit) { - - def success(ackMessage: Message) { - if (ackMessage == null) { - failure(new NullPointerException) - } - else { - completionHandler(scala.util.Success(ackMessage)) - } - } - - def failWithoutAck() { - completionHandler(scala.util.Failure(new IOException("Failed without being ACK'd"))) - } - - def failure(e: Throwable) { - completionHandler(scala.util.Failure(e)) - } - } - - private val selector = SelectorProvider.provider.openSelector() - private val ackTimeoutMonitor = - new HashedWheelTimer(ThreadUtils.namedThreadFactory("AckTimeoutMonitor")) - - private val ackTimeout = - conf.getTimeAsSeconds("spark.core.connection.ack.wait.timeout", - conf.get("spark.network.timeout", "120s")) - - // Get the thread counts from the Spark Configuration. - // - // Even though the ThreadPoolExecutor constructor takes both a minimum and maximum value, - // we only query for the minimum value because we are using LinkedBlockingDeque. - // - // The JavaDoc for ThreadPoolExecutor points out that when using a LinkedBlockingDeque (which is - // an unbounded queue) no more than corePoolSize threads will ever be created, so only the "min" - // parameter is necessary. 
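The comment above makes a claim worth spelling out: because LinkedBlockingDeque is unbounded, ThreadPoolExecutor only ever spawns corePoolSize threads (extra threads are created solely when the queue rejects an offer, which an unbounded queue never does), so the "min" settings are effectively the pool sizes. A small standalone sketch of that behaviour, with arbitrary sizes and task counts:

  import java.util.concurrent.{LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit}

  val pool = new ThreadPoolExecutor(
    2,                                  // corePoolSize - the only size that matters here
    20,                                 // maximumPoolSize - never reached with an unbounded queue
    60L, TimeUnit.SECONDS,
    new LinkedBlockingDeque[Runnable]())

  (1 to 100).foreach { _ =>
    pool.execute(new Runnable { override def run(): Unit = Thread.sleep(100) })
  }
  Thread.sleep(500)
  println(pool.getPoolSize)             // prints 2, not 20
  pool.shutdown()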
- private val handlerThreadCount = conf.getInt("spark.core.connection.handler.threads.min", 20) - private val ioThreadCount = conf.getInt("spark.core.connection.io.threads.min", 4) - private val connectThreadCount = conf.getInt("spark.core.connection.connect.threads.min", 1) - - private val handleMessageExecutor = new ThreadPoolExecutor( - handlerThreadCount, - handlerThreadCount, - conf.getInt("spark.core.connection.handler.threads.keepalive", 60), TimeUnit.SECONDS, - new LinkedBlockingDeque[Runnable](), - ThreadUtils.namedThreadFactory("handle-message-executor")) { - - override def afterExecute(r: Runnable, t: Throwable): Unit = { - super.afterExecute(r, t) - if (t != null && NonFatal(t)) { - logError("Error in handleMessageExecutor is not handled properly", t) - } - } - } - - private val handleReadWriteExecutor = new ThreadPoolExecutor( - ioThreadCount, - ioThreadCount, - conf.getInt("spark.core.connection.io.threads.keepalive", 60), TimeUnit.SECONDS, - new LinkedBlockingDeque[Runnable](), - ThreadUtils.namedThreadFactory("handle-read-write-executor")) { - - override def afterExecute(r: Runnable, t: Throwable): Unit = { - super.afterExecute(r, t) - if (t != null && NonFatal(t)) { - logError("Error in handleReadWriteExecutor is not handled properly", t) - } - } - } - - // Use a different, yet smaller, thread pool - infrequently used with very short lived tasks : - // which should be executed asap - private val handleConnectExecutor = new ThreadPoolExecutor( - connectThreadCount, - connectThreadCount, - conf.getInt("spark.core.connection.connect.threads.keepalive", 60), TimeUnit.SECONDS, - new LinkedBlockingDeque[Runnable](), - ThreadUtils.namedThreadFactory("handle-connect-executor")) { - - override def afterExecute(r: Runnable, t: Throwable): Unit = { - super.afterExecute(r, t) - if (t != null && NonFatal(t)) { - logError("Error in handleConnectExecutor is not handled properly", t) - } - } - } - - private val serverChannel = ServerSocketChannel.open() - // used to track the SendingConnections waiting to do SASL negotiation - private val connectionsAwaitingSasl = new HashMap[ConnectionId, SendingConnection] - with SynchronizedMap[ConnectionId, SendingConnection] - private val connectionsByKey = - new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection] - private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] - with SynchronizedMap[ConnectionManagerId, SendingConnection] - // Tracks sent messages for which we are awaiting acknowledgements. 
Entries are added to this - // map when messages are sent and are removed when acknowledgement messages are received or when - // acknowledgement timeouts expire - private val messageStatuses = new HashMap[Int, MessageStatus] // [MessageId, MessageStatus] - private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] - private val registerRequests = new SynchronizedQueue[SendingConnection] - - implicit val futureExecContext = ExecutionContext.fromExecutor( - ThreadUtils.newDaemonCachedThreadPool("Connection manager future execution context")) - - @volatile - private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message] = null - - private val authEnabled = securityManager.isAuthenticationEnabled() - - serverChannel.configureBlocking(false) - serverChannel.socket.setReuseAddress(true) - serverChannel.socket.setReceiveBufferSize(256 * 1024) - - private def startService(port: Int): (ServerSocketChannel, Int) = { - serverChannel.socket.bind(new InetSocketAddress(port)) - (serverChannel, serverChannel.socket.getLocalPort) - } - Utils.startServiceOnPort[ServerSocketChannel](port, startService, conf, name) - serverChannel.register(selector, SelectionKey.OP_ACCEPT) - - val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort) - logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id) - - // used in combination with the ConnectionManagerId to create unique Connection ids - // to be able to track asynchronous messages - private val idCount: AtomicInteger = new AtomicInteger(1) - - private val writeRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() - private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() - - @volatile private var isActive = true - private val selectorThread = new Thread("connection-manager-thread") { - override def run(): Unit = ConnectionManager.this.run() - } - selectorThread.setDaemon(true) - // start this thread last, since it invokes run(), which accesses members above - selectorThread.start() - - private def triggerWrite(key: SelectionKey) { - val conn = connectionsByKey.getOrElse(key, null) - if (conn == null) return - - writeRunnableStarted.synchronized { - // So that we do not trigger more write events while processing this one. - // The write method will re-register when done. - if (conn.changeInterestForWrite()) conn.unregisterInterest() - if (writeRunnableStarted.contains(key)) { - // key.interestOps(key.interestOps() & ~ SelectionKey.OP_WRITE) - return - } - - writeRunnableStarted += key - } - handleReadWriteExecutor.execute(new Runnable { - override def run() { - try { - var register: Boolean = false - try { - register = conn.write() - } finally { - writeRunnableStarted.synchronized { - writeRunnableStarted -= key - val needReregister = register || conn.resetForceReregister() - if (needReregister && conn.changeInterestForWrite()) { - conn.registerInterest() - } - } - } - } catch { - case NonFatal(e) => { - logError("Error when writing to " + conn.getRemoteConnectionManagerId(), e) - conn.callOnExceptionCallbacks(e) - } - } - } - } ) - } - - - private def triggerRead(key: SelectionKey) { - val conn = connectionsByKey.getOrElse(key, null) - if (conn == null) return - - readRunnableStarted.synchronized { - // So that we do not trigger more read events while processing this one. - // The read method will re-register when done. 
- if (conn.changeInterestForRead())conn.unregisterInterest() - if (readRunnableStarted.contains(key)) { - return - } - - readRunnableStarted += key - } - handleReadWriteExecutor.execute(new Runnable { - override def run() { - try { - var register: Boolean = false - try { - register = conn.read() - } finally { - readRunnableStarted.synchronized { - readRunnableStarted -= key - if (register && conn.changeInterestForRead()) { - conn.registerInterest() - } - } - } - } catch { - case NonFatal(e) => { - logError("Error when reading from " + conn.getRemoteConnectionManagerId(), e) - conn.callOnExceptionCallbacks(e) - } - } - } - } ) - } - - private def triggerConnect(key: SelectionKey) { - val conn = connectionsByKey.getOrElse(key, null).asInstanceOf[SendingConnection] - if (conn == null) return - - // prevent other events from being triggered - // Since we are still trying to connect, we do not need to do the additional steps in - // triggerWrite - conn.changeConnectionKeyInterest(0) - - handleConnectExecutor.execute(new Runnable { - override def run() { - try { - var tries: Int = 10 - while (tries >= 0) { - if (conn.finishConnect(false)) return - // Sleep ? - Thread.sleep(1) - tries -= 1 - } - - // fallback to previous behavior : we should not really come here since this method was - // triggered since channel became connectable : but at times, the first finishConnect need - // not succeed : hence the loop to retry a few 'times'. - conn.finishConnect(true) - } catch { - case NonFatal(e) => { - logError("Error when finishConnect for " + conn.getRemoteConnectionManagerId(), e) - conn.callOnExceptionCallbacks(e) - } - } - } - } ) - } - - // MUST be called within selector loop - else deadlock. - private def triggerForceCloseByException(key: SelectionKey, e: Exception) { - try { - key.interestOps(0) - } catch { - // ignore exceptions - case e: Exception => logDebug("Ignoring exception", e) - } - - val conn = connectionsByKey.getOrElse(key, null) - if (conn == null) return - - // Pushing to connect threadpool - handleConnectExecutor.execute(new Runnable { - override def run() { - try { - conn.callOnExceptionCallbacks(e) - } catch { - // ignore exceptions - case NonFatal(e) => logDebug("Ignoring exception", e) - } - try { - conn.close() - } catch { - // ignore exceptions - case NonFatal(e) => logDebug("Ignoring exception", e) - } - } - }) - } - - - def run() { - try { - while (isActive) { - while (!registerRequests.isEmpty) { - val conn: SendingConnection = registerRequests.dequeue() - addListeners(conn) - conn.connect() - addConnection(conn) - } - - while(!keyInterestChangeRequests.isEmpty) { - val (key, ops) = keyInterestChangeRequests.dequeue() - - try { - if (key.isValid) { - val connection = connectionsByKey.getOrElse(key, null) - if (connection != null) { - val lastOps = key.interestOps() - key.interestOps(ops) - - // hot loop - prevent materialization of string if trace not enabled. 
- if (isTraceEnabled()) { - def intToOpStr(op: Int): String = { - val opStrs = ArrayBuffer[String]() - if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ" - if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE" - if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT" - if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT" - if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " " - } - - logTrace("Changed key for connection to [" + - connection.getRemoteConnectionManagerId() + "] changed from [" + - intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]") - } - } - } else { - logInfo("Key not valid ? " + key) - throw new CancelledKeyException() - } - } catch { - case e: CancelledKeyException => { - logInfo("key already cancelled ? " + key, e) - triggerForceCloseByException(key, e) - } - case e: Exception => { - logError("Exception processing key " + key, e) - triggerForceCloseByException(key, e) - } - } - } - - val selectedKeysCount = - try { - selector.select() - } catch { - // Explicitly only dealing with CancelledKeyException here since other exceptions - // should be dealt with differently. - case e: CancelledKeyException => - // Some keys within the selectors list are invalid/closed. clear them. - val allKeys = selector.keys().iterator() - - while (allKeys.hasNext) { - val key = allKeys.next() - try { - if (! key.isValid) { - logInfo("Key not valid ? " + key) - throw new CancelledKeyException() - } - } catch { - case e: CancelledKeyException => { - logInfo("key already cancelled ? " + key, e) - triggerForceCloseByException(key, e) - } - case e: Exception => { - logError("Exception processing key " + key, e) - triggerForceCloseByException(key, e) - } - } - } - 0 - - case e: ClosedSelectorException => - logDebug("Failed select() as selector is closed.", e) - return - } - - if (selectedKeysCount == 0) { - logDebug("Selector selected " + selectedKeysCount + " of " + selector.keys.size + - " keys") - } - if (selectorThread.isInterrupted) { - logInfo("Selector thread was interrupted!") - return - } - - if (0 != selectedKeysCount) { - val selectedKeys = selector.selectedKeys().iterator() - while (selectedKeys.hasNext) { - val key = selectedKeys.next - selectedKeys.remove() - try { - if (key.isValid) { - if (key.isAcceptable) { - acceptConnection(key) - } else - if (key.isConnectable) { - triggerConnect(key) - } else - if (key.isReadable) { - triggerRead(key) - } else - if (key.isWritable) { - triggerWrite(key) - } - } else { - logInfo("Key not valid ? " + key) - throw new CancelledKeyException() - } - } catch { - // weird, but we saw this happening - even though key.isValid was true, - // key.isAcceptable would throw CancelledKeyException. - case e: CancelledKeyException => { - logInfo("key already cancelled ? " + key, e) - triggerForceCloseByException(key, e) - } - case e: Exception => { - logError("Exception processing key " + key, e) - triggerForceCloseByException(key, e) - } - } - } - } - } - } catch { - case e: Exception => logError("Error in select loop", e) - } - } - - def acceptConnection(key: SelectionKey) { - val serverChannel = key.channel.asInstanceOf[ServerSocketChannel] - - var newChannel = serverChannel.accept() - - // accept them all in a tight loop. 
non blocking accept with no processing, should be fine - while (newChannel != null) { - try { - val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) - val newConnection = new ReceivingConnection(newChannel, selector, newConnectionId, - securityManager) - newConnection.onReceive(receiveMessage) - addListeners(newConnection) - addConnection(newConnection) - logInfo("Accepted connection from [" + newConnection.remoteAddress + "]") - } catch { - // might happen in case of issues with registering with selector - case e: Exception => logError("Error in accept loop", e) - } - - newChannel = serverChannel.accept() - } - } - - private def addListeners(connection: Connection) { - connection.onKeyInterestChange(changeConnectionKeyInterest) - connection.onException(handleConnectionError) - connection.onClose(removeConnection) - } - - def addConnection(connection: Connection) { - connectionsByKey += ((connection.key, connection)) - } - - def removeConnection(connection: Connection) { - connectionsByKey -= connection.key - - try { - connection match { - case sendingConnection: SendingConnection => - val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId() - logInfo("Removing SendingConnection to " + sendingConnectionManagerId) - - connectionsById -= sendingConnectionManagerId - connectionsAwaitingSasl -= connection.connectionId - - messageStatuses.synchronized { - messageStatuses.values.filter(_.connectionManagerId == sendingConnectionManagerId) - .foreach(status => { - logInfo("Notifying " + status) - status.failWithoutAck() - }) - - messageStatuses.retain((i, status) => { - status.connectionManagerId != sendingConnectionManagerId - }) - } - case receivingConnection: ReceivingConnection => - val remoteConnectionManagerId = receivingConnection.getRemoteConnectionManagerId() - logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId) - - val sendingConnectionOpt = connectionsById.get(remoteConnectionManagerId) - if (!sendingConnectionOpt.isDefined) { - logError(s"Corresponding SendingConnection to ${remoteConnectionManagerId} not found") - return - } - - val sendingConnection = sendingConnectionOpt.get - connectionsById -= remoteConnectionManagerId - sendingConnection.close() - - val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId() - - assert(sendingConnectionManagerId == remoteConnectionManagerId) - - messageStatuses.synchronized { - for (s <- messageStatuses.values - if s.connectionManagerId == sendingConnectionManagerId) { - logInfo("Notifying " + s) - s.failWithoutAck() - } - - messageStatuses.retain((i, status) => { - status.connectionManagerId != sendingConnectionManagerId - }) - } - case _ => logError("Unsupported type of connection.") - } - } finally { - // So that the selection keys can be removed. - wakeupSelector() - } - } - - def handleConnectionError(connection: Connection, e: Throwable) { - logInfo("Handling connection error on connection to " + - connection.getRemoteConnectionManagerId()) - removeConnection(connection) - } - - def changeConnectionKeyInterest(connection: Connection, ops: Int) { - keyInterestChangeRequests += ((connection.key, ops)) - // so that registrations happen ! 
- wakeupSelector() - } - - def receiveMessage(connection: Connection, message: Message) { - val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress) - logDebug("Received [" + message + "] from [" + connectionManagerId + "]") - val runnable = new Runnable() { - val creationTime = System.currentTimeMillis - def run() { - try { - logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms") - handleMessage(connectionManagerId, message, connection) - logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms") - } catch { - case NonFatal(e) => { - logError("Error when handling messages from " + - connection.getRemoteConnectionManagerId(), e) - connection.callOnExceptionCallbacks(e) - } - } - } - } - handleMessageExecutor.execute(runnable) - /* handleMessage(connection, message) */ - } - - private def handleClientAuthentication( - waitingConn: SendingConnection, - securityMsg: SecurityMessage, - connectionId : ConnectionId) { - if (waitingConn.isSaslComplete()) { - logDebug("Client sasl completed for id: " + waitingConn.connectionId) - connectionsAwaitingSasl -= waitingConn.connectionId - waitingConn.registerAfterAuth() - wakeupSelector() - return - } else { - var replyToken : Array[Byte] = null - try { - replyToken = waitingConn.sparkSaslClient.response(securityMsg.getToken) - if (waitingConn.isSaslComplete()) { - logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId) - connectionsAwaitingSasl -= waitingConn.connectionId - waitingConn.registerAfterAuth() - wakeupSelector() - return - } - val securityMsgResp = SecurityMessage.fromResponse(replyToken, - securityMsg.getConnectionId.toString) - val message = securityMsgResp.toBufferMessage - if (message == null) throw new IOException("Error creating security message") - sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message) - } catch { - case e: Exception => - logError("Error handling sasl client authentication", e) - waitingConn.close() - throw new IOException("Error evaluating sasl response: ", e) - } - } - } - - private def handleServerAuthentication( - connection: Connection, - securityMsg: SecurityMessage, - connectionId: ConnectionId) { - if (!connection.isSaslComplete()) { - logDebug("saslContext not established") - var replyToken : Array[Byte] = null - try { - connection.synchronized { - if (connection.sparkSaslServer == null) { - logDebug("Creating sasl Server") - connection.sparkSaslServer = new SparkSaslServer(conf.getAppId, securityManager, false) - } - } - replyToken = connection.sparkSaslServer.response(securityMsg.getToken) - if (connection.isSaslComplete()) { - logDebug("Server sasl completed: " + connection.connectionId + - " for: " + connectionId) - } else { - logDebug("Server sasl not completed: " + connection.connectionId + - " for: " + connectionId) - } - if (replyToken != null) { - val securityMsgResp = SecurityMessage.fromResponse(replyToken, - securityMsg.getConnectionId) - val message = securityMsgResp.toBufferMessage - if (message == null) throw new Exception("Error creating security Message") - sendSecurityMessage(connection.getRemoteConnectionManagerId(), message) - } - } catch { - case e: Exception => { - logError("Error in server auth negotiation: " + e) - // It would probably be better to send an error message telling other side auth failed - // but for now just close - connection.close() - } - } - } else { - logDebug("connection already established for this connection id: " + 
connection.connectionId) - } - } - - - private def handleAuthentication(conn: Connection, bufferMessage: BufferMessage): Boolean = { - if (bufferMessage.isSecurityNeg) { - logDebug("This is security neg message") - - // parse as SecurityMessage - val securityMsg = SecurityMessage.fromBufferMessage(bufferMessage) - val connectionId = ConnectionId.createConnectionIdFromString(securityMsg.getConnectionId) - - connectionsAwaitingSasl.get(connectionId) match { - case Some(waitingConn) => { - // Client - this must be in response to us doing Send - logDebug("Client handleAuth for id: " + waitingConn.connectionId) - handleClientAuthentication(waitingConn, securityMsg, connectionId) - } - case None => { - // Server - someone sent us something and we haven't authenticated yet - logDebug("Server handleAuth for id: " + connectionId) - handleServerAuthentication(conn, securityMsg, connectionId) - } - } - return true - } else { - if (!conn.isSaslComplete()) { - // We could handle this better and tell the client we need to do authentication - // negotiation, but for now just ignore them. - logError("message sent that is not security negotiation message on connection " + - "not authenticated yet, ignoring it!!") - return true - } - } - false - } - - private def handleMessage( - connectionManagerId: ConnectionManagerId, - message: Message, - connection: Connection) { - logDebug("Handling [" + message + "] from [" + connectionManagerId + "]") - message match { - case bufferMessage: BufferMessage => { - if (authEnabled) { - val res = handleAuthentication(connection, bufferMessage) - if (res) { - // message was security negotiation so skip the rest - logDebug("After handleAuth result was true, returning") - return - } - } - if (bufferMessage.hasAckId()) { - messageStatuses.synchronized { - messageStatuses.get(bufferMessage.ackId) match { - case Some(status) => { - messageStatuses -= bufferMessage.ackId - status.success(message) - } - case None => { - /** - * We can fall down on this code because of following 2 cases - * - * (1) Invalid ack sent due to buggy code. 
- * - * (2) Late-arriving ack for a SendMessageStatus - * To avoid unwilling late-arriving ack - * caused by long pause like GC, you can set - * larger value than default to spark.core.connection.ack.wait.timeout - */ - logWarning(s"Could not find reference for received ack Message ${message.id}") - } - } - } - } else { - var ackMessage : Option[Message] = None - try { - ackMessage = if (onReceiveCallback != null) { - logDebug("Calling back") - onReceiveCallback(bufferMessage, connectionManagerId) - } else { - logDebug("Not calling back as callback is null") - None - } - - if (ackMessage.isDefined) { - if (!ackMessage.get.isInstanceOf[BufferMessage]) { - logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type " - + ackMessage.get.getClass) - } else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) { - logDebug("Response to " + bufferMessage + " does not have ack id set") - ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id - } - } - } catch { - case e: Exception => { - logError(s"Exception was thrown while processing message", e) - ackMessage = Some(Message.createErrorMessage(e, bufferMessage.id)) - } - } finally { - sendMessage(connectionManagerId, ackMessage.getOrElse { - Message.createBufferMessage(bufferMessage.id) - }) - } - } - } - case _ => throw new Exception("Unknown type message received") - } - } - - private def checkSendAuthFirst(connManagerId: ConnectionManagerId, conn: SendingConnection) { - // see if we need to do sasl before writing - // this should only be the first negotiation as the Client!!! - if (!conn.isSaslComplete()) { - conn.synchronized { - if (conn.sparkSaslClient == null) { - conn.sparkSaslClient = new SparkSaslClient(conf.getAppId, securityManager, false) - var firstResponse: Array[Byte] = null - try { - firstResponse = conn.sparkSaslClient.firstToken() - val securityMsg = SecurityMessage.fromResponse(firstResponse, - conn.connectionId.toString()) - val message = securityMsg.toBufferMessage - if (message == null) throw new Exception("Error creating security message") - connectionsAwaitingSasl += ((conn.connectionId, conn)) - sendSecurityMessage(connManagerId, message) - logDebug("adding connectionsAwaitingSasl id: " + conn.connectionId + - " to: " + connManagerId) - } catch { - case e: Exception => { - logError("Error getting first response from the SaslClient.", e) - conn.close() - throw new Exception("Error getting first response from the SaslClient") - } - } - } - } - } else { - logDebug("Sasl already established ") - } - } - - // allow us to add messages to the inbox for doing sasl negotiating - private def sendSecurityMessage(connManagerId: ConnectionManagerId, message: Message) { - def startNewConnection(): SendingConnection = { - val inetSocketAddress = new InetSocketAddress(connManagerId.host, connManagerId.port) - val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) - val newConnection = new SendingConnection(inetSocketAddress, selector, connManagerId, - newConnectionId, securityManager) - logInfo("creating new sending connection for security! " + newConnectionId ) - registerRequests.enqueue(newConnection) - - newConnection - } - // I removed the lookupKey stuff as part of merge ... should I re-add it ? - // We did not find it useful in our test-env ... - // If we do re-add it, we should consistently use it everywhere I guess ? 
- message.senderAddress = id.toSocketAddress() - logTrace("Sending Security [" + message + "] to [" + connManagerId + "]") - val connection = connectionsById.getOrElseUpdate(connManagerId, startNewConnection()) - - // send security message until going connection has been authenticated - connection.send(message) - - wakeupSelector() - } - - private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) { - def startNewConnection(): SendingConnection = { - val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, - connectionManagerId.port) - val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) - val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId, - newConnectionId, securityManager) - newConnection.onException { - case (conn, e) => { - logError("Exception while sending message.", e) - reportSendingMessageFailure(message.id, e) - } - } - logTrace("creating new sending connection: " + newConnectionId) - registerRequests.enqueue(newConnection) - - newConnection - } - val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection()) - - message.senderAddress = id.toSocketAddress() - logDebug("Before Sending [" + message + "] to [" + connectionManagerId + "]" + " " + - "connectionid: " + connection.connectionId) - - if (authEnabled) { - try { - checkSendAuthFirst(connectionManagerId, connection) - } catch { - case NonFatal(e) => { - reportSendingMessageFailure(message.id, e) - } - } - } - logDebug("Sending [" + message + "] to [" + connectionManagerId + "]") - connection.send(message) - wakeupSelector() - } - - private def reportSendingMessageFailure(messageId: Int, e: Throwable): Unit = { - // need to tell sender it failed - messageStatuses.synchronized { - val s = messageStatuses.get(messageId) - s match { - case Some(msgStatus) => { - messageStatuses -= messageId - logInfo("Notifying " + msgStatus.connectionManagerId) - msgStatus.failure(e) - } - case None => { - logError("no messageStatus for failed message id: " + messageId) - } - } - } - } - - private def wakeupSelector() { - selector.wakeup() - } - - /** - * Send a message and block until an acknowledgment is received or an error occurs. - * @param connectionManagerId the message's destination - * @param message the message being sent - * @return a Future that either returns the acknowledgment message or captures an exception. - */ - def sendMessageReliably(connectionManagerId: ConnectionManagerId, message: Message) - : Future[Message] = { - val promise = Promise[Message]() - - // It's important that the TimerTask doesn't capture a reference to `message`, which can cause - // memory leaks since cancelled TimerTasks won't necessarily be garbage collected until the time - // at which they would originally be scheduled to run. Therefore, extract the message id - // from outside of the TimerTask closure (see SPARK-4393 for more context). 
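The comment above describes a leak-avoidance pattern that is easy to miss in the surrounding code: the long-lived timer task captures only the message id, and reaches the promise through a WeakReference, so a cancelled-but-not-yet-fired task cannot keep either object alive. Distilled into a sketch (scheduleTimeout and schedule are invented names for illustration):

  import java.io.IOException
  import java.lang.ref.WeakReference
  import scala.concurrent.Promise

  def scheduleTimeout[T](messageId: Int, promise: Promise[T], schedule: Runnable => Unit): Unit = {
    val promiseRef = new WeakReference(promise)   // deliberately not capturing the message itself
    schedule(new Runnable {
      override def run(): Unit = {
        val p = promiseRef.get()
        if (p != null) {
          // tryFailure returns false if the promise was already completed elsewhere.
          p.tryFailure(new IOException("no ack received for message " + messageId))
        }
      }
    })
  }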
- val messageId = message.id - // Keep a weak reference to the promise so that the completed promise may be garbage-collected - val promiseReference = new WeakReference(promise) - val timeoutTask: TimerTask = new TimerTask { - override def run(timeout: Timeout): Unit = { - messageStatuses.synchronized { - messageStatuses.remove(messageId).foreach { s => - val e = new IOException("sendMessageReliably failed because ack " + - s"was not received within $ackTimeout sec") - val p = promiseReference.get - if (p != null) { - // Attempt to fail the promise with a Timeout exception - if (!p.tryFailure(e)) { - // If we reach here, then someone else has already signalled success or failure - // on this promise, so log a warning: - logError("Ignore error because promise is completed", e) - } - } else { - // The WeakReference was empty, which should never happen because - // sendMessageReliably's caller should have a strong reference to promise.future; - logError("Promise was garbage collected; this should never happen!", e) - } - } - } - } - } - - val timeoutTaskHandle = ackTimeoutMonitor.newTimeout(timeoutTask, ackTimeout, TimeUnit.SECONDS) - - val status = new MessageStatus(message, connectionManagerId, s => { - timeoutTaskHandle.cancel() - s match { - case scala.util.Failure(e) => - // Indicates a failure where we either never sent or never got ACK'd - if (!promise.tryFailure(e)) { - logWarning("Ignore error because promise is completed", e) - } - case scala.util.Success(ackMessage) => - if (ackMessage.hasError) { - val errorMsgByteBuf = ackMessage.asInstanceOf[BufferMessage].buffers.head - val errorMsgBytes = new Array[Byte](errorMsgByteBuf.limit()) - errorMsgByteBuf.get(errorMsgBytes) - val errorMsg = new String(errorMsgBytes, UTF_8) - val e = new IOException( - s"sendMessageReliably failed with ACK that signalled a remote error: $errorMsg") - if (!promise.tryFailure(e)) { - logWarning("Ignore error because promise is completed", e) - } - } else { - if (!promise.trySuccess(ackMessage)) { - logWarning("Drop ackMessage because promise is completed") - } - } - } - }) - messageStatuses.synchronized { - messageStatuses += ((message.id, status)) - } - - sendMessage(connectionManagerId, message) - promise.future - } - - def onReceiveMessage(callback: (Message, ConnectionManagerId) => Option[Message]) { - onReceiveCallback = callback - } - - def stop() { - isActive = false - ackTimeoutMonitor.stop() - selector.close() - selectorThread.interrupt() - selectorThread.join() - val connections = connectionsByKey.values - connections.foreach(_.close()) - if (connectionsByKey.size != 0) { - logWarning("All connections not cleaned up") - } - handleMessageExecutor.shutdown() - handleReadWriteExecutor.shutdown() - handleConnectExecutor.shutdown() - logInfo("ConnectionManager stopped") - } -} - - -private[spark] object ConnectionManager { - import scala.concurrent.ExecutionContext.Implicits.global - - def main(args: Array[String]) { - val conf = new SparkConf - val manager = new ConnectionManager(9999, conf, new SecurityManager(conf)) - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - // scalastyle:off println - println("Received [" + msg + "] from [" + id + "]") - // scalastyle:on println - None - }) - - /* testSequentialSending(manager) */ - /* System.gc() */ - - /* testParallelSending(manager) */ - /* System.gc() */ - - /* testParallelDecreasingSending(manager) */ - /* System.gc() */ - - testContinuousSending(manager) - System.gc() - } - - // scalastyle:off println - def 
testSequentialSending(manager: ConnectionManager) { - println("--------------------------") - println("Sequential Sending") - println("--------------------------") - val size = 10 * 1024 * 1024 - val count = 10 - - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - (0 until count).map(i => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - Await.result(manager.sendMessageReliably(manager.id, bufferMessage), Duration.Inf) - }) - println("--------------------------") - println() - } - - def testParallelSending(manager: ConnectionManager) { - println("--------------------------") - println("Parallel Sending") - println("--------------------------") - val size = 10 * 1024 * 1024 - val count = 10 - - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - val startTime = System.currentTimeMillis - (0 until count).map(i => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliably(manager.id, bufferMessage) - }).foreach(f => { - f.onFailure { - case e => println("Failed due to " + e) - } - Await.ready(f, 1 second) - }) - val finishTime = System.currentTimeMillis - - val mb = size * count / 1024.0 / 1024.0 - val ms = finishTime - startTime - val tput = mb * 1000.0 / ms - println("--------------------------") - println("Started at " + startTime + ", finished at " + finishTime) - println("Sent " + count + " messages of size " + size + " in " + ms + " ms " + - "(" + tput + " MB/s)") - println("--------------------------") - println() - } - - def testParallelDecreasingSending(manager: ConnectionManager) { - println("--------------------------") - println("Parallel Decreasing Sending") - println("--------------------------") - val size = 10 * 1024 * 1024 - val count = 10 - val buffers = Array.tabulate(count) { i => - val bufferLen = size * (i + 1) - val bufferContent = Array.tabulate[Byte](bufferLen)(x => x.toByte) - ByteBuffer.allocate(bufferLen).put(bufferContent) - } - buffers.foreach(_.flip) - val mb = buffers.map(_.remaining).reduceLeft(_ + _) / 1024.0 / 1024.0 - - val startTime = System.currentTimeMillis - (0 until count).map(i => { - val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate) - manager.sendMessageReliably(manager.id, bufferMessage) - }).foreach(f => { - f.onFailure { - case e => println("Failed due to " + e) - } - Await.ready(f, 1 second) - }) - val finishTime = System.currentTimeMillis - - val ms = finishTime - startTime - val tput = mb * 1000.0 / ms - println("--------------------------") - /* println("Started at " + startTime + ", finished at " + finishTime) */ - println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)") - println("--------------------------") - println() - } - - def testContinuousSending(manager: ConnectionManager) { - println("--------------------------") - println("Continuous Sending") - println("--------------------------") - val size = 10 * 1024 * 1024 - val count = 10 - - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - val startTime = System.currentTimeMillis - while(true) { - (0 until count).map(i => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliably(manager.id, bufferMessage) - }).foreach(f => { - f.onFailure { - case e => println("Failed due to " + e) - } - Await.ready(f, 1 second) - }) - val finishTime = System.currentTimeMillis - Thread.sleep(1000) - val 
mb = size * count / 1024.0 / 1024.0 - val ms = finishTime - startTime - val tput = mb * 1000.0 / ms - println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)") - println("--------------------------") - println() - } - } - // scalastyle:on println -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/Message.scala b/core/src/main/scala/org/apache/spark/network/nio/Message.scala deleted file mode 100644 index 85d2fe2bf9c2..000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/Message.scala +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.net.InetSocketAddress -import java.nio.ByteBuffer - -import scala.collection.mutable.ArrayBuffer - -import com.google.common.base.Charsets.UTF_8 - -import org.apache.spark.util.Utils - -private[nio] abstract class Message(val typ: Long, val id: Int) { - var senderAddress: InetSocketAddress = null - var started = false - var startTime = -1L - var finishTime = -1L - var isSecurityNeg = false - var hasError = false - - def size: Int - - def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] - - def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] - - def timeTaken(): String = (finishTime - startTime).toString + " ms" - - override def toString: String = { - this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")" - } -} - - -private[nio] object Message { - val BUFFER_MESSAGE = 1111111111L - - var lastId = 1 - - def getNewId(): Int = synchronized { - lastId += 1 - if (lastId == 0) { - lastId += 1 - } - lastId - } - - def createBufferMessage(dataBuffers: Seq[ByteBuffer], ackId: Int): BufferMessage = { - if (dataBuffers == null) { - return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer], ackId) - } - if (dataBuffers.exists(_ == null)) { - throw new Exception("Attempting to create buffer message with null buffer") - } - new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer] ++= dataBuffers, ackId) - } - - def createBufferMessage(dataBuffers: Seq[ByteBuffer]): BufferMessage = - createBufferMessage(dataBuffers, 0) - - def createBufferMessage(dataBuffer: ByteBuffer, ackId: Int): BufferMessage = { - if (dataBuffer == null) { - createBufferMessage(Array(ByteBuffer.allocate(0)), ackId) - } else { - createBufferMessage(Array(dataBuffer), ackId) - } - } - - def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage = - createBufferMessage(dataBuffer, 0) - - def createBufferMessage(ackId: Int): BufferMessage = { - createBufferMessage(new Array[ByteBuffer](0), ackId) - } - - /** - * Create a "negative acknowledgment" to notify a sender that an error occurred - * while processing its message. 
The exception's stacktrace will be formatted - * as a string, serialized into a byte array, and sent as the message payload. - */ - def createErrorMessage(exception: Exception, ackId: Int): BufferMessage = { - val exceptionString = Utils.exceptionString(exception) - val serializedExceptionString = ByteBuffer.wrap(exceptionString.getBytes(UTF_8)) - val errorMessage = createBufferMessage(serializedExceptionString, ackId) - errorMessage.hasError = true - errorMessage - } - - def create(header: MessageChunkHeader): Message = { - val newMessage: Message = header.typ match { - case BUFFER_MESSAGE => new BufferMessage(header.id, - ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other) - } - newMessage.hasError = header.hasError - newMessage.senderAddress = header.address - newMessage - } -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala b/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala deleted file mode 100644 index 7b3da4bb9d5e..000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.net.{InetAddress, InetSocketAddress} -import java.nio.ByteBuffer - -private[nio] class MessageChunkHeader( - val typ: Long, - val id: Int, - val totalSize: Int, - val chunkSize: Int, - val other: Int, - val hasError: Boolean, - val securityNeg: Int, - val address: InetSocketAddress) { - lazy val buffer = { - // No need to change this, at 'use' time, we do a reverse lookup of the hostname. - // Refer to network.Connection - val ip = address.getAddress.getAddress() - val port = address.getPort() - ByteBuffer. - allocate(MessageChunkHeader.HEADER_SIZE). - putLong(typ). - putInt(id). - putInt(totalSize). - putInt(chunkSize). - putInt(other). - put(if (hasError) 1.asInstanceOf[Byte] else 0.asInstanceOf[Byte]). - putInt(securityNeg). - putInt(ip.size). - put(ip). - putInt(port). - position(MessageChunkHeader.HEADER_SIZE). 
- flip.asInstanceOf[ByteBuffer] - } - - override def toString: String = { - "" + this.getClass.getSimpleName + ":" + id + " of type " + typ + - " and sizes " + totalSize + " / " + chunkSize + " bytes, securityNeg: " + securityNeg - } - -} - - -private[nio] object MessageChunkHeader { - val HEADER_SIZE = 45 - - def create(buffer: ByteBuffer): MessageChunkHeader = { - if (buffer.remaining != HEADER_SIZE) { - throw new IllegalArgumentException("Cannot convert buffer data to Message") - } - val typ = buffer.getLong() - val id = buffer.getInt() - val totalSize = buffer.getInt() - val chunkSize = buffer.getInt() - val other = buffer.getInt() - val hasError = buffer.get() != 0 - val securityNeg = buffer.getInt() - val ipSize = buffer.getInt() - val ipBytes = new Array[Byte](ipSize) - buffer.get(ipBytes) - val ip = InetAddress.getByAddress(ipBytes) - val port = buffer.getInt() - new MessageChunkHeader(typ, id, totalSize, chunkSize, other, hasError, securityNeg, - new InetSocketAddress(ip, port)) - } -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala deleted file mode 100644 index b2aec160635c..000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala +++ /dev/null @@ -1,217 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.nio.ByteBuffer - -import org.apache.spark.network._ -import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} -import org.apache.spark.network.shuffle.BlockFetchingListener -import org.apache.spark.storage.{BlockId, StorageLevel} -import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} - -import scala.concurrent.Future - - -/** - * A [[BlockTransferService]] implementation based on [[ConnectionManager]], a custom - * implementation using Java NIO. - */ -final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityManager) - extends BlockTransferService with Logging { - - private var cm: ConnectionManager = _ - - private var blockDataManager: BlockDataManager = _ - - /** - * Port number the service is listening on, available only after [[init]] is invoked. - */ - override def port: Int = { - checkInit() - cm.id.port - } - - /** - * Host name the service is listening on, available only after [[init]] is invoked. - */ - override def hostName: String = { - checkInit() - cm.id.host - } - - /** - * Initialize the transfer service by giving it the BlockDataManager that can be used to fetch - * local blocks or put local blocks. 
- */ - override def init(blockDataManager: BlockDataManager): Unit = { - this.blockDataManager = blockDataManager - cm = new ConnectionManager( - conf.getInt("spark.blockManager.port", 0), - conf, - securityManager, - "Connection manager for block manager") - cm.onReceiveMessage(onBlockMessageReceive) - } - - /** - * Tear down the transfer service. - */ - override def close(): Unit = { - if (cm != null) { - cm.stop() - } - } - - override def fetchBlocks( - host: String, - port: Int, - execId: String, - blockIds: Array[String], - listener: BlockFetchingListener): Unit = { - checkInit() - - val cmId = new ConnectionManagerId(host, port) - val blockMessageArray = new BlockMessageArray(blockIds.map { blockId => - BlockMessage.fromGetBlock(GetBlock(BlockId(blockId))) - }) - - val future = cm.sendMessageReliably(cmId, blockMessageArray.toBufferMessage) - - // Register the listener on success/failure future callback. - future.onSuccess { case message => - val bufferMessage = message.asInstanceOf[BufferMessage] - val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) - - // SPARK-4064: In some cases(eg. Remote block was removed) blockMessageArray may be empty. - if (blockMessageArray.isEmpty) { - blockIds.foreach { id => - listener.onBlockFetchFailure(id, new SparkException(s"Received empty message from $cmId")) - } - } else { - for (blockMessage: BlockMessage <- blockMessageArray) { - val msgType = blockMessage.getType - if (msgType != BlockMessage.TYPE_GOT_BLOCK) { - if (blockMessage.getId != null) { - listener.onBlockFetchFailure(blockMessage.getId.toString, - new SparkException(s"Unexpected message $msgType received from $cmId")) - } - } else { - val blockId = blockMessage.getId - val networkSize = blockMessage.getData.limit() - listener.onBlockFetchSuccess( - blockId.toString, new NioManagedBuffer(blockMessage.getData)) - } - } - } - }(cm.futureExecContext) - - future.onFailure { case exception => - blockIds.foreach { blockId => - listener.onBlockFetchFailure(blockId, exception) - } - }(cm.futureExecContext) - } - - /** - * Upload a single block to a remote node, available only after [[init]] is invoked. - * - * This call blocks until the upload completes, or throws an exception upon failures. 
- */ - override def uploadBlock( - hostname: String, - port: Int, - execId: String, - blockId: BlockId, - blockData: ManagedBuffer, - level: StorageLevel) - : Future[Unit] = { - checkInit() - val msg = PutBlock(blockId, blockData.nioByteBuffer(), level) - val blockMessageArray = new BlockMessageArray(BlockMessage.fromPutBlock(msg)) - val remoteCmId = new ConnectionManagerId(hostName, port) - val reply = cm.sendMessageReliably(remoteCmId, blockMessageArray.toBufferMessage) - reply.map(x => ())(cm.futureExecContext) - } - - private def checkInit(): Unit = if (cm == null) { - throw new IllegalStateException(getClass.getName + " has not been initialized") - } - - private def onBlockMessageReceive(msg: Message, id: ConnectionManagerId): Option[Message] = { - logDebug("Handling message " + msg) - msg match { - case bufferMessage: BufferMessage => - try { - logDebug("Handling as a buffer message " + bufferMessage) - val blockMessages = BlockMessageArray.fromBufferMessage(bufferMessage) - logDebug("Parsed as a block message array") - val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get) - Some(new BlockMessageArray(responseMessages).toBufferMessage) - } catch { - case e: Exception => - logError("Exception handling buffer message", e) - Some(Message.createErrorMessage(e, msg.id)) - } - - case otherMessage: Any => - val errorMsg = s"Received unknown message type: ${otherMessage.getClass.getName}" - logError(errorMsg) - Some(Message.createErrorMessage(new UnsupportedOperationException(errorMsg), msg.id)) - } - } - - private def processBlockMessage(blockMessage: BlockMessage): Option[BlockMessage] = { - blockMessage.getType match { - case BlockMessage.TYPE_PUT_BLOCK => - val msg = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel) - logDebug("Received [" + msg + "]") - putBlock(msg.id, msg.data, msg.level) - None - - case BlockMessage.TYPE_GET_BLOCK => - val msg = new GetBlock(blockMessage.getId) - logDebug("Received [" + msg + "]") - val buffer = getBlock(msg.id) - if (buffer == null) { - return None - } - Some(BlockMessage.fromGotBlock(GotBlock(msg.id, buffer))) - - case _ => None - } - } - - private def putBlock(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel) { - val startTimeMs = System.currentTimeMillis() - logDebug("PutBlock " + blockId + " started from " + startTimeMs + " with data: " + bytes) - blockDataManager.putBlockData(blockId, new NioManagedBuffer(bytes), level) - logDebug("PutBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs) - + " with data size: " + bytes.limit) - } - - private def getBlock(blockId: BlockId): ByteBuffer = { - val startTimeMs = System.currentTimeMillis() - logDebug("GetBlock " + blockId + " started from " + startTimeMs) - val buffer = blockDataManager.getBlockData(blockId) - logDebug("GetBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs) - + " and got buffer " + buffer) - buffer.nioByteBuffer() - } -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala deleted file mode 100644 index 232c552f9865..000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala +++ /dev/null @@ -1,160 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.nio.ByteBuffer - -import scala.collection.mutable.{ArrayBuffer, StringBuilder} - -import org.apache.spark._ - -/** - * SecurityMessage is class that contains the connectionId and sasl token - * used in SASL negotiation. SecurityMessage has routines for converting - * it to and from a BufferMessage so that it can be sent by the ConnectionManager - * and easily consumed by users when received. - * The api was modeled after BlockMessage. - * - * The connectionId is the connectionId of the client side. Since - * message passing is asynchronous and its possible for the server side (receiving) - * to get multiple different types of messages on the same connection the connectionId - * is used to know which connnection the security message is intended for. - * - * For instance, lets say we are node_0. We need to send data to node_1. The node_0 side - * is acting as a client and connecting to node_1. SASL negotiation has to occur - * between node_0 and node_1 before node_1 trusts node_0 so node_0 sends a security message. - * node_1 receives the message from node_0 but before it can process it and send a response, - * some thread on node_1 decides it needs to send data to node_0 so it connects to node_0 - * and sends a security message of its own to authenticate as a client. Now node_0 gets - * the message and it needs to decide if this message is in response to it being a client - * (from the first send) or if its just node_1 trying to connect to it to send data. This - * is where the connectionId field is used. node_0 can lookup the connectionId to see if - * it is in response to it being a client or if its in response to someone sending other data. - * - * The format of a SecurityMessage as its sent is: - * - Length of the ConnectionId - * - ConnectionId - * - Length of the token - * - Token - */ -private[nio] class SecurityMessage extends Logging { - - private var connectionId: String = null - private var token: Array[Byte] = null - - def set(byteArr: Array[Byte], newconnectionId: String) { - if (byteArr == null) { - token = new Array[Byte](0) - } else { - token = byteArr - } - connectionId = newconnectionId - } - - /** - * Read the given buffer and set the members of this class. 
- */ - def set(buffer: ByteBuffer) { - val idLength = buffer.getInt() - val idBuilder = new StringBuilder(idLength) - for (i <- 1 to idLength) { - idBuilder += buffer.getChar() - } - connectionId = idBuilder.toString() - - val tokenLength = buffer.getInt() - token = new Array[Byte](tokenLength) - if (tokenLength > 0) { - buffer.get(token, 0, tokenLength) - } - } - - def set(bufferMsg: BufferMessage) { - val buffer = bufferMsg.buffers.apply(0) - buffer.clear() - set(buffer) - } - - def getConnectionId: String = { - return connectionId - } - - def getToken: Array[Byte] = { - return token - } - - /** - * Create a BufferMessage that can be sent by the ConnectionManager containing - * the security information from this class. - * @return BufferMessage - */ - def toBufferMessage: BufferMessage = { - val buffers = new ArrayBuffer[ByteBuffer]() - - // 4 bytes for the length of the connectionId - // connectionId is of type char so multiple the length by 2 to get number of bytes - // 4 bytes for the length of token - // token is a byte buffer so just take the length - var buffer = ByteBuffer.allocate(4 + connectionId.length() * 2 + 4 + token.length) - buffer.putInt(connectionId.length()) - connectionId.foreach((x: Char) => buffer.putChar(x)) - buffer.putInt(token.length) - - if (token.length > 0) { - buffer.put(token) - } - buffer.flip() - buffers += buffer - - var message = Message.createBufferMessage(buffers) - logDebug("message total size is : " + message.size) - message.isSecurityNeg = true - return message - } - - override def toString: String = { - "SecurityMessage [connId= " + connectionId + ", Token = " + token + "]" - } -} - -private[nio] object SecurityMessage { - - /** - * Convert the given BufferMessage to a SecurityMessage by parsing the contents - * of the BufferMessage and populating the SecurityMessage fields. - * @param bufferMessage is a BufferMessage that was received - * @return new SecurityMessage - */ - def fromBufferMessage(bufferMessage: BufferMessage): SecurityMessage = { - val newSecurityMessage = new SecurityMessage() - newSecurityMessage.set(bufferMessage) - newSecurityMessage - } - - /** - * Create a SecurityMessage to send from a given saslResponse. 
- * @param response is the response to a challenge from the SaslClient or Saslserver - * @param connectionId the client connectionId we are negotiation authentication for - * @return a new SecurityMessage - */ - def fromResponse(response : Array[Byte], connectionId : String) : SecurityMessage = { - val newSecurityMessage = new SecurityMessage() - newSecurityMessage.set(response, connectionId) - newSecurityMessage - } -} diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala index 8ae76c5f72f2..7515aad09db7 100644 --- a/core/src/main/scala/org/apache/spark/package.scala +++ b/core/src/main/scala/org/apache/spark/package.scala @@ -43,5 +43,5 @@ package org.apache package object spark { // For package docs only - val SPARK_VERSION = "1.5.0-SNAPSHOT" + val SPARK_VERSION = "1.6.0-SNAPSHOT" } diff --git a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala index aed035334442..48b943415317 100644 --- a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala +++ b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala @@ -17,13 +17,9 @@ package org.apache.spark.partial -import org.apache.spark.annotation.Experimental - /** - * :: Experimental :: * A Double value with error bars and associated confidence. */ -@Experimental class BoundedDouble(val mean: Double, val confidence: Double, val low: Double, val high: Double) { override def toString(): String = "[%.3f, %.3f]".format(low, high) } diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala index 91b07ce3af1b..5afce75680f9 100644 --- a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala @@ -19,7 +19,7 @@ package org.apache.spark.partial import java.util.{HashMap => JHashMap} -import scala.collection.JavaConversions.mapAsScalaMap +import scala.collection.JavaConverters._ import scala.collection.Map import scala.collection.mutable.HashMap import scala.reflect.ClassTag @@ -48,9 +48,9 @@ private[spark] class GroupedCountEvaluator[T : ClassTag](totalOutputs: Int, conf if (outputsMerged == totalOutputs) { val result = new JHashMap[T, BoundedDouble](sums.size) sums.foreach { case (key, sum) => - result(key) = new BoundedDouble(sum, 1.0, sum, sum) + result.put(key, new BoundedDouble(sum, 1.0, sum, sum)) } - result + result.asScala } else if (outputsMerged == 0) { new HashMap[T, BoundedDouble] } else { @@ -64,9 +64,9 @@ private[spark] class GroupedCountEvaluator[T : ClassTag](totalOutputs: Int, conf val stdev = math.sqrt(variance) val low = mean - confFactor * stdev val high = mean + confFactor * stdev - result(key) = new BoundedDouble(mean, confidence, low, high) + result.put(key, new BoundedDouble(mean, confidence, low, high)) } - result + result.asScala } } } diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala index af26c3d59ac0..a16404068480 100644 --- a/core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala @@ -19,7 +19,7 @@ package org.apache.spark.partial import java.util.{HashMap => JHashMap} -import scala.collection.JavaConversions.mapAsScalaMap +import scala.collection.JavaConverters._ 
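The `GroupedCountEvaluator`/`GroupedMeanEvaluator` hunks above replace the implicit `JavaConversions.mapAsScalaMap` conversion with explicit `JavaConverters` calls: the Java map is mutated through its own `put`, and `asScala` is applied only where a Scala `Map` must be returned. A minimal self-contained illustration of the same migration:

```scala
import java.util.{HashMap => JHashMap}

import scala.collection.JavaConverters._
import scala.collection.Map

object ConvertersExample {
  def main(args: Array[String]): Unit = {
    val result = new JHashMap[String, Double]()
    // With JavaConverters the Java map is used through its own API...
    result.put("a", 1.0)
    result.put("b", 2.0)
    // ...and converted explicitly where a Scala Map is required. `asScala`
    // returns a mutable view that wraps the same underlying JHashMap.
    val scalaView: Map[String, Double] = result.asScala
    println(scalaView("a")) // 1.0
  }
}
```

The explicit call makes the conversion point visible at the return site, which is the main readability gain over the old implicit conversions.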
import scala.collection.Map import scala.collection.mutable.HashMap @@ -55,9 +55,9 @@ private[spark] class GroupedMeanEvaluator[T](totalOutputs: Int, confidence: Doub while (iter.hasNext) { val entry = iter.next() val mean = entry.getValue.mean - result(entry.getKey) = new BoundedDouble(mean, 1.0, mean, mean) + result.put(entry.getKey, new BoundedDouble(mean, 1.0, mean, mean)) } - result + result.asScala } else if (outputsMerged == 0) { new HashMap[T, BoundedDouble] } else { @@ -72,9 +72,9 @@ private[spark] class GroupedMeanEvaluator[T](totalOutputs: Int, confidence: Doub val confFactor = studentTCacher.get(counter.count) val low = mean - confFactor * stdev val high = mean + confFactor * stdev - result(entry.getKey) = new BoundedDouble(mean, confidence, low, high) + result.put(entry.getKey, new BoundedDouble(mean, confidence, low, high)) } - result + result.asScala } } } diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala index 442fb86227d8..54a1beab3514 100644 --- a/core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala @@ -19,7 +19,7 @@ package org.apache.spark.partial import java.util.{HashMap => JHashMap} -import scala.collection.JavaConversions.mapAsScalaMap +import scala.collection.JavaConverters._ import scala.collection.Map import scala.collection.mutable.HashMap @@ -55,9 +55,9 @@ private[spark] class GroupedSumEvaluator[T](totalOutputs: Int, confidence: Doubl while (iter.hasNext) { val entry = iter.next() val sum = entry.getValue.sum - result(entry.getKey) = new BoundedDouble(sum, 1.0, sum, sum) + result.put(entry.getKey, new BoundedDouble(sum, 1.0, sum, sum)) } - result + result.asScala } else if (outputsMerged == 0) { new HashMap[T, BoundedDouble] } else { @@ -80,9 +80,9 @@ private[spark] class GroupedSumEvaluator[T](totalOutputs: Int, confidence: Doubl val confFactor = studentTCacher.get(counter.count) val low = sumEstimate - confFactor * sumStdev val high = sumEstimate + confFactor * sumStdev - result(entry.getKey) = new BoundedDouble(sumEstimate, confidence, low, high) + result.put(entry.getKey, new BoundedDouble(sumEstimate, confidence, low, high)) } - result + result.asScala } } } diff --git a/core/src/main/scala/org/apache/spark/partial/PartialResult.scala b/core/src/main/scala/org/apache/spark/partial/PartialResult.scala index 53c4b32c95ab..25cb7490aa9c 100644 --- a/core/src/main/scala/org/apache/spark/partial/PartialResult.scala +++ b/core/src/main/scala/org/apache/spark/partial/PartialResult.scala @@ -17,9 +17,6 @@ package org.apache.spark.partial -import org.apache.spark.annotation.Experimental - -@Experimental class PartialResult[R](initialVal: R, isFinal: Boolean) { private var finalValue: Option[R] = if (isFinal) Some(initialVal) else None private var failure: Option[Exception] = None diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index ca1eb1f4e4a9..14f541f937b4 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -19,13 +19,12 @@ package org.apache.spark.rdd import java.util.concurrent.atomic.AtomicLong -import org.apache.spark.util.ThreadUtils - import scala.collection.mutable.ArrayBuffer -import scala.concurrent.ExecutionContext +import scala.concurrent.{Future, ExecutionContext} 
import scala.reflect.ClassTag -import org.apache.spark.{ComplexFutureAction, FutureAction, Logging} +import org.apache.spark.{JobSubmitter, ComplexFutureAction, FutureAction, Logging} +import org.apache.spark.util.ThreadUtils /** * A set of asynchronous RDD actions available through an implicit conversion. @@ -65,15 +64,23 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi * Returns a future for retrieving the first num elements of the RDD. */ def takeAsync(num: Int): FutureAction[Seq[T]] = self.withScope { - val f = new ComplexFutureAction[Seq[T]] - - f.run { - // This is a blocking action so we should use "AsyncRDDActions.futureExecutionContext" which - // is a cached thread pool. - val results = new ArrayBuffer[T](num) - val totalParts = self.partitions.length - var partsScanned = 0 - while (results.size < num && partsScanned < totalParts) { + val callSite = self.context.getCallSite + val localProperties = self.context.getLocalProperties + // Cached thread pool to handle aggregation of subtasks. + implicit val executionContext = AsyncRDDActions.futureExecutionContext + val results = new ArrayBuffer[T](num) + val totalParts = self.partitions.length + + /* + Recursively triggers jobs to scan partitions until either the requested + number of elements are retrieved, or the partitions to scan are exhausted. + This implementation is non-blocking, asynchronously handling the + results of each job and triggering the next job using callbacks on futures. + */ + def continue(partsScanned: Int)(implicit jobSubmitter: JobSubmitter) : Future[Seq[T]] = + if (results.size >= num || partsScanned >= totalParts) { + Future.successful(results.toSeq) + } else { // The number of partitions to try in this iteration. It is ok for this number to be // greater than totalParts because we actually cap it at totalParts in runJob. 
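The rewritten `takeAsync` above replaces a blocking while-loop with a recursive `continue` that submits one job per round and chains the next round from the previous job's future. A standalone sketch of that non-blocking shape using plain Futures; `scanBatch` is a hypothetical stand-in for `jobSubmitter.submitJob`, and the fixed batch size of 4 replaces the adaptive `numPartsToTry` logic:

```scala
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.{ExecutionContext, Future}

object TakeAsyncSketch {
  /** Keep launching asynchronous "scan a batch of partitions" jobs until `num`
    * results are collected or the partitions run out. No thread ever blocks:
    * each step schedules the next from the previous future's callback. */
  def takeNonBlocking[T](
      num: Int,
      totalParts: Int,
      scanBatch: Range => Future[Seq[T]])(implicit ec: ExecutionContext): Future[Seq[T]] = {
    val results = new ArrayBuffer[T](num)

    def continue(partsScanned: Int): Future[Seq[T]] = {
      if (results.size >= num || partsScanned >= totalParts) {
        Future.successful(results.toSeq)
      } else {
        val batch = partsScanned until math.min(partsScanned + 4, totalParts)
        scanBatch(batch).flatMap { batchResults =>
          // Steps run one after another, so this mutation is not concurrent.
          results ++= batchResults.take(num - results.size)
          continue(batch.end)     // chain the next job instead of looping
        }
      }
    }
    continue(0)
  }
}
```

The recursion terminates because `partsScanned` strictly increases toward `totalParts`, mirroring the termination condition of the original loop.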
var numPartsToTry = 1 @@ -95,19 +102,20 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) val buf = new Array[Array[T]](p.size) - f.runJob(self, + self.context.setCallSite(callSite) + self.context.setLocalProperties(localProperties) + val job = jobSubmitter.submitJob(self, (it: Iterator[T]) => it.take(left).toArray, p, (index: Int, data: Array[T]) => buf(index) = data, Unit) - - buf.foreach(results ++= _.take(num - results.size)) - partsScanned += numPartsToTry + job.flatMap {_ => + buf.foreach(results ++= _.take(num - results.size)) + continue(partsScanned + numPartsToTry) + } } - results.toSeq - }(AsyncRDDActions.futureExecutionContext) - f + new ComplexFutureAction[Seq[T]](continue(0)(_)) } /** diff --git a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala index 1f755db48581..aedced7408cd 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala @@ -28,12 +28,13 @@ private[spark] class BinaryFileRDD[T]( inputFormatClass: Class[_ <: StreamFileInputFormat[T]], keyClass: Class[String], valueClass: Class[T], - @transient conf: Configuration, + conf: Configuration, minPartitions: Int) extends NewHadoopRDD[String, T](sc, inputFormatClass, keyClass, valueClass, conf) { override def getPartitions: Array[Partition] = { val inputFormat = inputFormatClass.newInstance + val conf = getConf inputFormat match { case configurable: Configurable => configurable.setConf(conf) diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala index 922030263756..fc1710fbad0a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala @@ -28,7 +28,7 @@ private[spark] class BlockRDDPartition(val blockId: BlockId, idx: Int) extends P } private[spark] -class BlockRDD[T: ClassTag](@transient sc: SparkContext, @transient val blockIds: Array[BlockId]) +class BlockRDD[T: ClassTag](sc: SparkContext, @transient val blockIds: Array[BlockId]) extends RDD[T](sc, Nil) { @transient lazy val _locations = BlockManager.blockIdsToHosts(blockIds, SparkEnv.get) @@ -64,7 +64,7 @@ class BlockRDD[T: ClassTag](@transient sc: SparkContext, @transient val blockIds */ private[spark] def removeBlocks() { blockIds.foreach { blockId => - sc.env.blockManager.master.removeBlock(blockId) + sparkContext.env.blockManager.master.removeBlock(blockId) } _isValid = false } diff --git a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala index c1d697178757..18e8cddbc40d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala @@ -27,8 +27,8 @@ import org.apache.spark.util.Utils private[spark] class CartesianPartition( idx: Int, - @transient rdd1: RDD[_], - @transient rdd2: RDD[_], + @transient private val rdd1: RDD[_], + @transient private val rdd2: RDD[_], s1Index: Int, s2Index: Int ) extends Partition { diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index 72fe215dae73..b0364623af4c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ 
b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -29,7 +29,7 @@ private[spark] class CheckpointRDDPartition(val index: Int) extends Partition /** * An RDD that recovers checkpointed data from storage. */ -private[spark] abstract class CheckpointRDD[T: ClassTag](@transient sc: SparkContext) +private[spark] abstract class CheckpointRDD[T: ClassTag](sc: SparkContext) extends RDD[T](sc, Nil) { // CheckpointRDD should not be checkpointed again diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 130b58882d8e..3a0ca1d81329 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -22,11 +22,11 @@ import scala.language.existentials import java.io.{IOException, ObjectOutputStream} import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag -import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext} -import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency} +import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap, CompactBuffer} +import org.apache.spark.util.collection.{CompactBuffer, ExternalAppendOnlyMap} import org.apache.spark.util.Utils import org.apache.spark.serializer.Serializer @@ -70,12 +70,14 @@ private[spark] class CoGroupPartition( * * Note: This is an internal API. We recommend users use RDD.cogroup(...) instead of * instantiating this directly. - + * * @param rdds parent RDDs. * @param part partitioner used to partition the shuffle output */ @DeveloperApi -class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: Partitioner) +class CoGroupedRDD[K: ClassTag]( + @transient var rdds: Seq[RDD[_ <: Product2[K, _]]], + part: Partitioner) extends RDD[(K, Array[Iterable[_]])](rdds.head.context, Nil) { // For example, `(k, a) cogroup (k, b)` produces k -> Array(ArrayBuffer as, ArrayBuffer bs). 
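`CoGroupedRDD`'s scaladoc above recommends `RDD.cogroup(...)` over instantiating the class directly, and the inline comment describes the per-key output shape. A small usage example of that public API (illustrative only, run against an existing `SparkContext` such as `sc` in spark-shell):

```scala
val orders = sc.parallelize(Seq(("u1", "order-1"), ("u1", "order-2"), ("u2", "order-3")))
val clicks = sc.parallelize(Seq(("u1", "click-a"), ("u3", "click-b")))

// cogroup keeps every key seen on either side and groups both value sets per key.
val grouped = orders.cogroup(clicks) // RDD[(String, (Iterable[String], Iterable[String]))]

grouped.collect().foreach { case (user, (os, cs)) =>
  println(s"$user -> orders=${os.toList} clicks=${cs.toList}")
}
// u1 -> orders=List(order-1, order-2) clicks=List(click-a)
// u2 -> orders=List(order-3)          clicks=List()
// u3 -> orders=List()                 clicks=List(click-b)
```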
@@ -126,8 +128,6 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: override val partitioner: Some[Partitioner] = Some(part) override def compute(s: Partition, context: TaskContext): Iterator[(K, Array[Iterable[_]])] = { - val sparkConf = SparkEnv.get.conf - val externalSorting = sparkConf.getBoolean("spark.shuffle.spill", true) val split = s.asInstanceOf[CoGroupPartition] val numRdds = dependencies.length @@ -148,32 +148,16 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: rddIterators += ((it, depNum)) } - if (!externalSorting) { - val map = new AppendOnlyMap[K, CoGroupCombiner] - val update: (Boolean, CoGroupCombiner) => CoGroupCombiner = (hadVal, oldVal) => { - if (hadVal) oldVal else Array.fill(numRdds)(new CoGroup) - } - val getCombiner: K => CoGroupCombiner = key => { - map.changeValue(key, update) - } - rddIterators.foreach { case (it, depNum) => - while (it.hasNext) { - val kv = it.next() - getCombiner(kv._1)(depNum) += kv._2 - } - } - new InterruptibleIterator(context, - map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]]) - } else { - val map = createExternalMap(numRdds) - for ((it, depNum) <- rddIterators) { - map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum)))) - } - context.taskMetrics.incMemoryBytesSpilled(map.memoryBytesSpilled) - context.taskMetrics.incDiskBytesSpilled(map.diskBytesSpilled) - new InterruptibleIterator(context, - map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]]) + val map = createExternalMap(numRdds) + for ((it, depNum) <- rddIterators) { + map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum)))) } + context.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled) + context.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes) + new InterruptibleIterator(context, + map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]]) } private def createExternalMap(numRdds: Int) diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index 926bce6f15a2..7fbaadcea3a3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -74,10 +74,8 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { } /** - * :: Experimental :: * Approximate operation to return the mean within a timeout. */ - @Experimental def meanApprox( timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = self.withScope { @@ -87,10 +85,8 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { } /** - * :: Experimental :: * Approximate operation to return the sum within a timeout. 
*/ - @Experimental def sumApprox( timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = self.withScope { diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index f1c17369cb48..f37c95bedc0a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -44,14 +44,14 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD -import org.apache.spark.util.{SerializableConfiguration, NextIterator, Utils} +import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, NextIterator, Utils} import org.apache.spark.scheduler.{HostTaskLocation, HDFSCacheTaskLocation} import org.apache.spark.storage.StorageLevel /** * A Spark split class that wraps around a Hadoop InputSplit. */ -private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSplit) +private[spark] class HadoopPartition(rddId: Int, idx: Int, s: InputSplit) extends Partition { val inputSplit = new SerializableWritable[InputSplit](s) @@ -88,8 +88,8 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp * * @param sc The SparkContext to associate the RDD with. * @param broadcastedConf A general Hadoop Configuration, or a subclass of it. If the enclosed - * variabe references an instance of JobConf, then that JobConf will be used for the Hadoop job. - * Otherwise, a new JobConf will be created on each slave using the enclosed Configuration. + * variable references an instance of JobConf, then that JobConf will be used for the Hadoop job. + * Otherwise, a new JobConf will be created on each slave using the enclosed Configuration. * @param initLocalJobConfFuncOpt Optional closure used to initialize any JobConf that HadoopRDD * creates. * @param inputFormatClass Storage format of the data to be read. @@ -99,7 +99,7 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp */ @DeveloperApi class HadoopRDD[K, V]( - @transient sc: SparkContext, + sc: SparkContext, broadcastedConf: Broadcast[SerializableConfiguration], initLocalJobConfFuncOpt: Option[JobConf => Unit], inputFormatClass: Class[_ <: InputFormat[K, V]], @@ -109,7 +109,7 @@ class HadoopRDD[K, V]( extends RDD[(K, V)](sc, Nil) with Logging { if (initLocalJobConfFuncOpt.isDefined) { - sc.clean(initLocalJobConfFuncOpt.get) + sparkContext.clean(initLocalJobConfFuncOpt.get) } def this( @@ -123,7 +123,7 @@ class HadoopRDD[K, V]( sc, sc.broadcast(new SerializableConfiguration(conf)) .asInstanceOf[Broadcast[SerializableConfiguration]], - None /* initLocalJobConfFuncOpt */, + initLocalJobConfFuncOpt = None, inputFormatClass, keyClass, valueClass, @@ -137,7 +137,7 @@ class HadoopRDD[K, V]( // used to build JobTracker ID private val createTime = new Date() - private val shouldCloneJobConf = sc.conf.getBoolean("spark.hadoop.cloneConf", false) + private val shouldCloneJobConf = sparkContext.conf.getBoolean("spark.hadoop.cloneConf", false) // Returns a JobConf that will be used on slaves to obtain input splits for Hadoop reads. 
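As the constructor above shows, `HadoopRDD` ships its Hadoop `Configuration` to executors wrapped in a broadcast `SerializableConfiguration`, because `Configuration` itself is not `java.io.Serializable`. A hedged sketch of what such a wrapper looks like, relying only on `Configuration` implementing Hadoop's `Writable`; `SerializableConf` is a hypothetical stand-in, not Spark's private class:

```scala
import java.io.{ObjectInputStream, ObjectOutputStream}

import org.apache.hadoop.conf.Configuration

/** Hypothetical stand-in for a serializable Configuration wrapper: lets a Hadoop
  * Configuration travel inside a broadcast variable or task closure. */
class SerializableConf(@transient var value: Configuration) extends Serializable {
  private def writeObject(out: ObjectOutputStream): Unit = {
    out.defaultWriteObject()
    value.write(out)             // Configuration implements Writable
  }

  private def readObject(in: ObjectInputStream): Unit = {
    in.defaultReadObject()
    value = new Configuration(false)
    value.readFields(in)
  }
}

// Usage sketch: broadcast once on the driver, read it inside tasks.
// val confBroadcast = sc.broadcast(new SerializableConf(hadoopConf))
// rdd.map { x => val conf = confBroadcast.value.value; /* use conf */ x }
```

Broadcasting the wrapper once avoids re-serializing a roughly 10 KB configuration with every task, which is the motivation stated in the `NewHadoopRDD` hunk below.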
protected def getJobConf(): JobConf = { @@ -182,17 +182,12 @@ class HadoopRDD[K, V]( } protected def getInputFormat(conf: JobConf): InputFormat[K, V] = { - if (HadoopRDD.containsCachedMetadata(inputFormatCacheKey)) { - return HadoopRDD.getCachedMetadata(inputFormatCacheKey).asInstanceOf[InputFormat[K, V]] - } - // Once an InputFormat for this RDD is created, cache it so that only one reflection call is - // done in each local process. val newInputFormat = ReflectionUtils.newInstance(inputFormatClass.asInstanceOf[Class[_]], conf) .asInstanceOf[InputFormat[K, V]] - if (newInputFormat.isInstanceOf[Configurable]) { - newInputFormat.asInstanceOf[Configurable].setConf(conf) + newInputFormat match { + case c: Configurable => c.setConf(conf) + case _ => } - HadoopRDD.putCachedMetadata(inputFormatCacheKey, newInputFormat) newInputFormat } @@ -201,9 +196,6 @@ class HadoopRDD[K, V]( // add the credentials here as this can be called before SparkContext initialized SparkHadoopUtil.get.addCredentials(jobConf) val inputFormat = getInputFormat(jobConf) - if (inputFormat.isInstanceOf[Configurable]) { - inputFormat.asInstanceOf[Configurable].setConf(jobConf) - } val inputSplits = inputFormat.getSplits(jobConf, minPartitions) val array = new Array[Partition](inputSplits.size) for (i <- 0 until inputSplits.size) { @@ -221,6 +213,12 @@ class HadoopRDD[K, V]( val inputMetrics = context.taskMetrics.getInputMetricsForReadMethod(DataReadMethod.Hadoop) + // Sets the thread local variable for the file's name + split.inputSplit.value match { + case fs: FileSplit => SqlNewHadoopRDDState.setInputFileName(fs.getPath.toString) + case _ => SqlNewHadoopRDDState.unsetInputFileName() + } + // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes val bytesReadCallback = inputMetrics.bytesReadCallback.orElse { @@ -257,8 +255,22 @@ class HadoopRDD[K, V]( } override def close() { - try { - reader.close() + if (reader != null) { + SqlNewHadoopRDDState.unsetInputFileName() + // Close the reader and release it. Note: it's very important that we don't close the + // reader more than once, since that exposes us to MAPREDUCE-5918 when running against + // Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic + // corruption issues when reading compressed input. 
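Earlier in this hunk the task thread records the file it is reading via `SqlNewHadoopRDDState.setInputFileName` before iterating a `FileSplit`, and clears it again in `close()`. A minimal sketch of that kind of thread-local holder; `CurrentInputFile` is a hypothetical object name standing in for the real class:

```scala
/** Hypothetical sketch: the task thread records which file it is reading so code
  * running later on the same thread (e.g. a SQL input_file_name() expression)
  * can look it up without the value being threaded through every call. */
object CurrentInputFile {
  private val name = new ThreadLocal[String] {
    override def initialValue(): String = ""
  }
  def set(path: String): Unit = name.set(path)
  def unset(): Unit = name.remove()   // call from close(), as the hunk above does
  def get(): String = name.get()
}
```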
+ try { + reader.close() + } catch { + case e: Exception => + if (!ShutdownHookManager.inShutdown()) { + logWarning("Exception in RecordReader.close()", e) + } + } finally { + reader = null + } if (bytesReadCallback.isDefined) { inputMetrics.updateBytesRead() } else if (split.inputSplit.value.isInstanceOf[FileSplit] || @@ -272,12 +284,6 @@ class HadoopRDD[K, V]( logWarning("Unable to get input size to set InputMetrics for task", e) } } - } catch { - case e: Exception => { - if (!Utils.inShutdown()) { - logWarning("Exception in RecordReader.close()", e) - } - } } } } diff --git a/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala index daa5779d688c..bfe19195fcd3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala @@ -35,7 +35,7 @@ import org.apache.spark.storage.RDDBlockId * @param numPartitions the number of partitions in the checkpointed RDD */ private[spark] class LocalCheckpointRDD[T: ClassTag]( - @transient sc: SparkContext, + sc: SparkContext, rddId: Int, numPartitions: Int) extends CheckpointRDD[T](sc) { diff --git a/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala index d6fad896845f..c115e0ff74d3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala @@ -31,7 +31,7 @@ import org.apache.spark.util.Utils * is written to the local, ephemeral block storage that lives in each executor. This is useful * for use cases where RDDs build up long lineages that need to be truncated often (e.g. GraphX). */ -private[spark] class LocalRDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) +private[spark] class LocalRDDCheckpointData[T: ClassTag](@transient private val rdd: RDD[T]) extends RDDCheckpointData[T](rdd) with Logging { /** diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala index a838aac6e8d1..4312d3a41775 100644 --- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala @@ -21,6 +21,9 @@ import scala.reflect.ClassTag import org.apache.spark.{Partition, TaskContext} +/** + * An RDD that applies the provided function to every partition of the parent RDD. 
+ */ private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag]( prev: RDD[T], f: (TaskContext, Int, Iterator[T]) => Iterator[U], // (TaskContext, partition index, iterator) diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index f83a051f5da1..86f38ae836b2 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -28,23 +28,21 @@ import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.input.WholeTextFileInputFormat import org.apache.spark._ import org.apache.spark.executor.DataReadMethod import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD -import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.storage.StorageLevel private[spark] class NewHadoopPartition( rddId: Int, val index: Int, - @transient rawSplit: InputSplit with Writable) + rawSplit: InputSplit with Writable) extends Partition { val serializableHadoopSplit = new SerializableWritable(rawSplit) - override def hashCode(): Int = 41 * (41 + rddId) + index } @@ -60,7 +58,6 @@ private[spark] class NewHadoopPartition( * @param inputFormatClass Storage format of the data to be read. * @param keyClass Class of the key associated with the inputFormatClass. * @param valueClass Class of the value associated with the inputFormatClass. - * @param conf The Hadoop configuration. */ @DeveloperApi class NewHadoopRDD[K, V]( @@ -68,14 +65,14 @@ class NewHadoopRDD[K, V]( inputFormatClass: Class[_ <: InputFormat[K, V]], keyClass: Class[K], valueClass: Class[V], - @transient conf: Configuration) + @transient private val _conf: Configuration) extends RDD[(K, V)](sc, Nil) with SparkHadoopMapReduceUtil with Logging { // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it - private val confBroadcast = sc.broadcast(new SerializableConfiguration(conf)) - // private val serializableConf = new SerializableWritable(conf) + private val confBroadcast = sc.broadcast(new SerializableConfiguration(_conf)) + // private val serializableConf = new SerializableWritable(_conf) private val jobTrackerId: String = { val formatter = new SimpleDateFormat("yyyyMMddHHmm") @@ -84,14 +81,35 @@ class NewHadoopRDD[K, V]( @transient protected val jobId = new JobID(jobTrackerId, id) + private val shouldCloneJobConf = sparkContext.conf.getBoolean("spark.hadoop.cloneConf", false) + + def getConf: Configuration = { + val conf: Configuration = confBroadcast.value.value + if (shouldCloneJobConf) { + // Hadoop Configuration objects are not thread-safe, which may lead to various problems if + // one job modifies a configuration while another reads it (SPARK-2546, SPARK-10611). This + // problem occurs somewhat rarely because most jobs treat the configuration as though it's + // immutable. One solution, implemented here, is to clone the Configuration object. + // Unfortunately, this clone can be very expensive. To avoid unexpected performance + // regressions for workloads and Hadoop versions that do not suffer from these thread-safety + // issues, this cloning is disabled by default. 
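The comment above notes that the defensive clone is disabled by default because it is expensive; the same hunk reads the switch as `spark.hadoop.cloneConf` with a `false` default. A short sketch of opting in from application code, for workloads that do hit the SPARK-2546 / SPARK-10611 style races:

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Opting in to per-task Configuration cloning: the cost is one extra
// Configuration copy per task, the benefit is isolation from concurrent jobs
// that mutate the shared broadcast Configuration.
val conf = new SparkConf()
  .setAppName("clone-conf-example")
  .set("spark.hadoop.cloneConf", "true") // read by HadoopRDD/NewHadoopRDD as shown above
val sc = new SparkContext(conf)
```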
+ NewHadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized { + logDebug("Cloning Hadoop Configuration") + new Configuration(conf) + } + } else { + conf + } + } + override def getPartitions: Array[Partition] = { val inputFormat = inputFormatClass.newInstance inputFormat match { case configurable: Configurable => - configurable.setConf(conf) + configurable.setConf(_conf) case _ => } - val jobContext = newJobContext(conf, jobId) + val jobContext = newJobContext(_conf, jobId) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) for (i <- 0 until rawSplits.size) { @@ -104,7 +122,7 @@ class NewHadoopRDD[K, V]( val iter = new Iterator[(K, V)] { val split = theSplit.asInstanceOf[NewHadoopPartition] logInfo("Input split: " + split.serializableHadoopSplit) - val conf = confBroadcast.value.value + val conf = getConf val inputMetrics = context.taskMetrics .getInputMetricsForReadMethod(DataReadMethod.Hadoop) @@ -120,14 +138,14 @@ class NewHadoopRDD[K, V]( } inputMetrics.setBytesReadCallback(bytesReadCallback) - val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) - val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) val format = inputFormatClass.newInstance format match { case configurable: Configurable => configurable.setConf(conf) case _ => } + val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) + val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) private var reader = format.createRecordReader( split.serializableHadoopSplit.value, hadoopAttemptContext) reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) @@ -164,30 +182,32 @@ class NewHadoopRDD[K, V]( } private def close() { - try { - if (reader != null) { - // Close reader and release it + if (reader != null) { + // Close the reader and release it. Note: it's very important that we don't close the + // reader more than once, since that exposes us to MAPREDUCE-5918 when running against + // Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic + // corruption issues when reading compressed input. + try { reader.close() - reader = null - - if (bytesReadCallback.isDefined) { - inputMetrics.updateBytesRead() - } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || - split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { - // If we can't get the bytes read from the FS stats, fall back to the split size, - // which may be inaccurate. - try { - inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) - } catch { - case e: java.io.IOException => - logWarning("Unable to get input size to set InputMetrics for task", e) + } catch { + case e: Exception => + if (!ShutdownHookManager.inShutdown()) { + logWarning("Exception in RecordReader.close()", e) } - } + } finally { + reader = null } - } catch { - case e: Exception => { - if (!Utils.inShutdown()) { - logWarning("Exception in RecordReader.close()", e) + if (bytesReadCallback.isDefined) { + inputMetrics.updateBytesRead() + } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || + split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { + // If we can't get the bytes read from the FS stats, fall back to the split size, + // which may be inaccurate. 
+ try { + inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) + } catch { + case e: java.io.IOException => + logWarning("Unable to get input size to set InputMetrics for task", e) } } } @@ -230,11 +250,15 @@ class NewHadoopRDD[K, V]( super.persist(storageLevel) } - - def getConf: Configuration = confBroadcast.value.value } private[spark] object NewHadoopRDD { + /** + * Configuration's constructor is not threadsafe (see SPARK-1097 and HADOOP-10456). + * Therefore, we synchronize on this lock before calling new Configuration(). + */ + val CONFIGURATION_INSTANTIATION_LOCK = new Object() + /** * Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to * the given function rather than the index of the partition. @@ -256,31 +280,3 @@ private[spark] object NewHadoopRDD { } } } - -private[spark] class WholeTextFileRDD( - sc : SparkContext, - inputFormatClass: Class[_ <: WholeTextFileInputFormat], - keyClass: Class[String], - valueClass: Class[String], - @transient conf: Configuration, - minPartitions: Int) - extends NewHadoopRDD[String, String](sc, inputFormatClass, keyClass, valueClass, conf) { - - override def getPartitions: Array[Partition] = { - val inputFormat = inputFormatClass.newInstance - inputFormat match { - case configurable: Configurable => - configurable.setConf(conf) - case _ => - } - val jobContext = newJobContext(conf, jobId) - inputFormat.setMinPartitions(jobContext, minPartitions) - val rawSplits = inputFormat.getSplits(jobContext).toArray - val result = new Array[Partition](rawSplits.size) - for (i <- 0 until rawSplits.size) { - result(i) = new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) - } - result - } -} - diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 326fafb230a4..44d195587a08 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -22,7 +22,7 @@ import java.text.SimpleDateFormat import java.util.{Date, HashMap => JHashMap} import scala.collection.{Map, mutable} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag import scala.util.DynamicVariable @@ -57,25 +57,29 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) with SparkHadoopMapReduceUtil with Serializable { + /** + * :: Experimental :: * Generic function to combine the elements for each key using a custom set of aggregation * functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined type" C * Note that V and C can be different -- for example, one might group an RDD of type * (Int, Int) into an RDD of type (Int, Seq[Int]). Users provide three functions: * - * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) - * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) - * - `mergeCombiners`, to combine two C's into a single one. + * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) + * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) + * - `mergeCombiners`, to combine two C's into a single one. * * In addition, users can control the partitioning of the output RDD, and whether to perform * map-side aggregation (if a mapper can produce multiple items with the same key). 
*/ - def combineByKey[C](createCombiner: V => C, + @Experimental + def combineByKeyWithClassTag[C]( + createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C, partitioner: Partitioner, mapSideCombine: Boolean = true, - serializer: Serializer = null): RDD[(K, C)] = self.withScope { + serializer: Serializer = null)(implicit ct: ClassTag[C]): RDD[(K, C)] = self.withScope { require(mergeCombiners != null, "mergeCombiners must be defined") // required as of Spark 0.9.0 if (keyClass.isArray) { if (mapSideCombine) { @@ -103,13 +107,50 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } /** - * Simplified version of combineByKey that hash-partitions the output RDD. + * Generic function to combine the elements for each key using a custom set of aggregation + * functions. This method is here for backward compatibility. It does not provide combiner + * classtag information to the shuffle. + * + * @see [[combineByKeyWithClassTag]] + */ + def combineByKey[C]( + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C, + partitioner: Partitioner, + mapSideCombine: Boolean = true, + serializer: Serializer = null): RDD[(K, C)] = self.withScope { + combineByKeyWithClassTag(createCombiner, mergeValue, mergeCombiners, + partitioner, mapSideCombine, serializer)(null) + } + + /** + * Simplified version of combineByKeyWithClassTag that hash-partitions the output RDD. + * This method is here for backward compatibility. It does not provide combiner + * classtag information to the shuffle. + * + * @see [[combineByKeyWithClassTag]] */ - def combineByKey[C](createCombiner: V => C, + def combineByKey[C]( + createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C, numPartitions: Int): RDD[(K, C)] = self.withScope { - combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numPartitions)) + combineByKeyWithClassTag(createCombiner, mergeValue, mergeCombiners, numPartitions)(null) + } + + /** + * :: Experimental :: + * Simplified version of combineByKeyWithClassTag that hash-partitions the output RDD. + */ + @Experimental + def combineByKeyWithClassTag[C]( + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C, + numPartitions: Int)(implicit ct: ClassTag[C]): RDD[(K, C)] = self.withScope { + combineByKeyWithClassTag(createCombiner, mergeValue, mergeCombiners, + new HashPartitioner(numPartitions)) } /** @@ -133,7 +174,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) // We will clean the combiner closure later in `combineByKey` val cleanedSeqOp = self.context.clean(seqOp) - combineByKey[U]((v: V) => cleanedSeqOp(createZero(), v), cleanedSeqOp, combOp, partitioner) + combineByKeyWithClassTag[U]((v: V) => cleanedSeqOp(createZero(), v), + cleanedSeqOp, combOp, partitioner) } /** @@ -182,7 +224,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val createZero = () => cachedSerializer.deserialize[V](ByteBuffer.wrap(zeroArray)) val cleanedFunc = self.context.clean(func) - combineByKey[V]((v: V) => cleanedFunc(createZero(), v), cleanedFunc, cleanedFunc, partitioner) + combineByKeyWithClassTag[V]((v: V) => cleanedFunc(createZero(), v), + cleanedFunc, cleanedFunc, partitioner) } /** @@ -231,7 +274,6 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } /** - * ::Experimental:: * Return a subset of this RDD sampled by key (via stratified sampling) containing exactly * math.ceil(numItems * samplingRate) for each stratum (group of pairs with the same key). 
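(Illustrative aside, not part of the patch: a sketch of the three aggregation functions documented above, using the new numPartitions overload of combineByKeyWithClassTag. It assumes an existing SparkContext `sc`; the data and partition count are placeholders.)

val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3)))
val perKeyLists = pairs.combineByKeyWithClassTag[List[Int]](
  (v: Int) => List(v),                          // createCombiner: start a one-element list
  (c: List[Int], v: Int) => v :: c,             // mergeValue: fold a value into a partial list
  (c1: List[Int], c2: List[Int]) => c1 ::: c2,  // mergeCombiners: concatenate two partial lists
  numPartitions = 2)
// perKeyLists: RDD[(String, List[Int])], e.g. ("a", List(3, 1)) and ("b", List(2)).
// Unlike the old combineByKey overloads, the ClassTag[List[Int]] is now made available to the shuffle.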
* @@ -246,7 +288,6 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * @param seed seed for the random number generator * @return RDD containing the sampled subset */ - @Experimental def sampleByKeyExact( withReplacement: Boolean, fractions: Map[K, Double], @@ -268,7 +309,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * "combiner" in MapReduce. */ def reduceByKey(partitioner: Partitioner, func: (V, V) => V): RDD[(K, V)] = self.withScope { - combineByKey[V]((v: V) => v, func, func, partitioner) + combineByKeyWithClassTag[V]((v: V) => v, func, func, partitioner) } /** @@ -312,14 +353,14 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } : Iterator[JHashMap[K, V]] val mergeMaps = (m1: JHashMap[K, V], m2: JHashMap[K, V]) => { - m2.foreach { pair => + m2.asScala.foreach { pair => val old = m1.get(pair._1) m1.put(pair._1, if (old == null) pair._2 else cleanedF(old, pair._2)) } m1 } : JHashMap[K, V] - self.mapPartitions(reducePartition).reduce(mergeMaps) + self.mapPartitions(reducePartition).reduce(mergeMaps).asScala } /** Alias for reduceByKeyLocally */ @@ -341,19 +382,15 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } /** - * :: Experimental :: * Approximate version of countByKey that can return a partial result if it does * not finish within a timeout. */ - @Experimental def countByKeyApprox(timeout: Long, confidence: Double = 0.95) : PartialResult[Map[K, BoundedDouble]] = self.withScope { self.map(_._1).countByValueApprox(timeout, confidence) } /** - * :: Experimental :: - * * Return approximate number of distinct values for each key in this RDD. * * The algorithm used is based on streamlib's implementation of "HyperLogLog in Practice: @@ -370,7 +407,6 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * If `sp` equals 0, the sparse representation is skipped. * @param partitioner Partitioner to use for the resulting RDD. */ - @Experimental def countApproxDistinctByKey( p: Int, sp: Int, @@ -392,7 +428,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) h1 } - combineByKey(createHLL, mergeValueHLL, mergeHLL, partitioner).mapValues(_.cardinality()) + combineByKeyWithClassTag(createHLL, mergeValueHLL, mergeHLL, partitioner) + .mapValues(_.cardinality()) } /** @@ -466,7 +503,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val createCombiner = (v: V) => CompactBuffer(v) val mergeValue = (buf: CompactBuffer[V], v: V) => buf += v val mergeCombiners = (c1: CompactBuffer[V], c2: CompactBuffer[V]) => c1 ++= c2 - val bufs = combineByKey[CompactBuffer[V]]( + val bufs = combineByKeyWithClassTag[CompactBuffer[V]]( createCombiner, mergeValue, mergeCombiners, partitioner, mapSideCombine = false) bufs.asInstanceOf[RDD[(K, Iterable[V])]] } @@ -565,12 +602,30 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } /** - * Simplified version of combineByKey that hash-partitions the resulting RDD using the + * Simplified version of combineByKeyWithClassTag that hash-partitions the resulting RDD using the + * existing partitioner/parallelism level. This method is here for backward compatibility. It + * does not provide combiner classtag information to the shuffle. + * + * @see [[combineByKeyWithClassTag]] + */ + def combineByKey[C]( + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C): RDD[(K, C)] = self.withScope { + combineByKeyWithClassTag(createCombiner, mergeValue, mergeCombiners)(null) + } + + /** + * :: Experimental :: + * Simplified version of combineByKeyWithClassTag that hash-partitions the resulting RDD using the * existing partitioner/parallelism level. 
*/ - def combineByKey[C](createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C) - : RDD[(K, C)] = self.withScope { - combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(self)) + @Experimental + def combineByKeyWithClassTag[C]( + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C)(implicit ct: ClassTag[C]): RDD[(K, C)] = self.withScope { + combineByKeyWithClassTag(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(self)) } /** @@ -934,8 +989,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) job.setOutputKeyClass(keyClass) job.setOutputValueClass(valueClass) job.setOutputFormatClass(outputFormatClass) - job.getConfiguration.set("mapred.output.dir", path) - saveAsNewAPIHadoopDataset(job.getConfiguration) + val jobConfiguration = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + jobConfiguration.set("mapred.output.dir", path) + saveAsNewAPIHadoopDataset(jobConfiguration) } /** @@ -955,6 +1011,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) /** * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class * supporting the key and value types K and V in this RDD. + * + * Note that, we should make sure our tasks are idempotent when speculation is enabled, i.e. do + * not use output committer that writes data directly. + * There is an example in https://issues.apache.org/jira/browse/SPARK-10063 to show the bad + * result of using direct output committer with speculation enabled. */ def saveAsHadoopFile( path: String, @@ -967,10 +1028,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val hadoopConf = conf hadoopConf.setOutputKeyClass(keyClass) hadoopConf.setOutputValueClass(valueClass) - // Doesn't work in Scala 2.9 due to what may be a generics bug - // TODO: Should we uncomment this for Scala 2.10? - // conf.setOutputFormat(outputFormatClass) - hadoopConf.set("mapred.output.format.class", outputFormatClass.getName) + conf.setOutputFormat(outputFormatClass) for (c <- codec) { hadoopConf.setCompressMapOutput(true) hadoopConf.set("mapred.output.compress", "true") @@ -984,6 +1042,19 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) hadoopConf.setOutputCommitter(classOf[FileOutputCommitter]) } + // When speculation is on and output committer class name contains "Direct", we should warn + // users that they may loss data if they are using a direct output committer. + val speculationEnabled = self.conf.getBoolean("spark.speculation", false) + val outputCommitterClass = hadoopConf.get("mapred.output.committer.class", "") + if (speculationEnabled && outputCommitterClass.contains("Direct")) { + val warningMessage = + s"$outputCommitterClass may be an output committer that writes data directly to " + + "the final location. Because speculation is enabled, this output committer may " + + "cause data loss (see the case in SPARK-10063). If possible, please use a output " + + "committer that does not have this behavior (e.g. FileOutputCommitter)." + logWarning(warningMessage) + } + FileOutputFormat.setOutputPath(hadoopConf, SparkHadoopWriter.createPathFromString(path, hadoopConf)) saveAsHadoopDataset(hadoopConf) @@ -994,6 +1065,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Configuration object for that storage system. The Conf should set an OutputFormat and any * output paths required (e.g. a table name to write to) in the same way as it would be * configured for a Hadoop MapReduce job. 
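(Illustrative aside, not part of the patch: the speculation note above, made concrete. "spark.speculation" and "mapred.output.committer.class" are the keys the new warning consults; the Direct committer class name is hypothetical.)

import org.apache.hadoop.mapred.JobConf
import org.apache.spark.SparkConf

val sparkConf = new SparkConf().set("spark.speculation", "true")
val jobConf = new JobConf()
jobConf.set("mapred.output.committer.class", "org.example.DirectOutputCommitter") // hypothetical
// With this combination, the save paths patched above (saveAsHadoopFile / saveAsNewAPIHadoopDataset)
// log the new warning; keeping the default FileOutputCommitter, or disabling speculation, avoids
// the SPARK-10063 hazard.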
+ * + * Note that, we should make sure our tasks are idempotent when speculation is enabled, i.e. do + * not use output committer that writes data directly. + * There is an example in https://issues.apache.org/jira/browse/SPARK-10063 to show the bad + * result of using direct output committer with speculation enabled. */ def saveAsNewAPIHadoopDataset(conf: Configuration): Unit = self.withScope { // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). @@ -1002,7 +1078,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val formatter = new SimpleDateFormat("yyyyMMddHHmm") val jobtrackerID = formatter.format(new Date()) val stageId = self.id - val wrappedConf = new SerializableConfiguration(job.getConfiguration) + val jobConfiguration = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val wrappedConf = new SerializableConfiguration(jobConfiguration) val outfmt = job.getOutputFormatClass val jobFormat = outfmt.newInstance @@ -1051,6 +1128,20 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = true, 0, 0) val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId) val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) + + // When speculation is on and output committer class name contains "Direct", we should warn + // users that they may loss data if they are using a direct output committer. + val speculationEnabled = self.conf.getBoolean("spark.speculation", false) + val outputCommitterClass = jobCommitter.getClass.getSimpleName + if (speculationEnabled && outputCommitterClass.contains("Direct")) { + val warningMessage = + s"$outputCommitterClass may be an output committer that writes data directly to " + + "the final location. Because speculation is enabled, this output committer may " + + "cause data loss (see the case in SPARK-10063). If possible, please use a output " + + "committer that does not have this behavior (e.g. FileOutputCommitter)." + logWarning(warningMessage) + } + jobCommitter.setupJob(jobTaskContext) self.context.runJob(self, writeShard) jobCommitter.commitJob(jobTaskContext) @@ -1065,7 +1156,6 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) def saveAsHadoopDataset(conf: JobConf): Unit = self.withScope { // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). val hadoopConf = conf - val wrappedConf = new SerializableConfiguration(hadoopConf) val outputFormatInstance = hadoopConf.getOutputFormat val keyClass = hadoopConf.getOutputKeyClass val valueClass = hadoopConf.getOutputValueClass @@ -1093,7 +1183,6 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) writer.preSetup() val writeToFile = (context: TaskContext, iter: Iterator[(K, V)]) => { - val config = wrappedConf.value // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. 
val taskAttemptId = (context.taskAttemptId % Int.MaxValue).toInt diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala index e2394e28f8d2..582fa93afe34 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -83,8 +83,8 @@ private[spark] class ParallelCollectionPartition[T: ClassTag]( } private[spark] class ParallelCollectionRDD[T: ClassTag]( - @transient sc: SparkContext, - @transient data: Seq[T], + sc: SparkContext, + @transient private val data: Seq[T], numSlices: Int, locationPrefs: Map[Int, Seq[String]]) extends RDD[T](sc, Nil) { diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala index a00f4c1cdff9..0c6ddda52cee 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala @@ -32,7 +32,7 @@ private[spark] class PartitionPruningRDDPartition(idx: Int, val parentSplit: Par * Represents a dependency between the PartitionPruningRDD and its parent. In this * case, the child RDD contains a subset of partitions of the parents'. */ -private[spark] class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boolean) +private[spark] class PruneDependency[T](rdd: RDD[T], partitionFilterFunc: Int => Boolean) extends NarrowDependency[T](rdd) { @transient @@ -55,8 +55,8 @@ private[spark] class PruneDependency[T](rdd: RDD[T], @transient partitionFilterF */ @DeveloperApi class PartitionPruningRDD[T: ClassTag]( - @transient prev: RDD[T], - @transient partitionFilterFunc: Int => Boolean) + prev: RDD[T], + partitionFilterFunc: Int => Boolean) extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) { override def compute(split: Partition, context: TaskContext): Iterator[T] = { @@ -65,7 +65,7 @@ class PartitionPruningRDD[T: ClassTag]( } override protected def getPartitions: Array[Partition] = - getDependencies.head.asInstanceOf[PruneDependency[T]].partitions + dependencies.head.asInstanceOf[PruneDependency[T]].partitions } diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala index a637d6f15b7e..3b1acacf409b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala @@ -47,8 +47,8 @@ class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long) private[spark] class PartitionwiseSampledRDD[T: ClassTag, U: ClassTag]( prev: RDD[T], sampler: RandomSampler[T, U], - @transient preservesPartitioning: Boolean, - @transient seed: Long = Utils.random.nextLong) + preservesPartitioning: Boolean, + @transient private val seed: Long = Utils.random.nextLong) extends RDD[U](prev) { @transient override val partitioner = if (preservesPartitioning) prev.partitioner else None diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala index 3bb9998e1db4..afbe566b7656 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala @@ -23,7 +23,7 @@ import java.io.IOException import java.io.PrintWriter import java.util.StringTokenizer -import 
scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.Map import scala.collection.mutable.ArrayBuffer import scala.io.Source @@ -72,7 +72,7 @@ private[spark] class PipedRDD[T: ClassTag]( } override def compute(split: Partition, context: TaskContext): Iterator[String] = { - val pb = new ProcessBuilder(command) + val pb = new ProcessBuilder(command.asJava) // Add the environmental variables to the process. val currentEnvVars = pb.environment() envVars.foreach { case (variable, value) => currentEnvVars.put(variable, value) } @@ -81,7 +81,7 @@ private[spark] class PipedRDD[T: ClassTag]( // so the user code can access the input filename if (split.isInstanceOf[HadoopPartition]) { val hadoopSplit = split.asInstanceOf[HadoopPartition] - currentEnvVars.putAll(hadoopSplit.getPipeEnvVars()) + currentEnvVars.putAll(hadoopSplit.getPipeEnvVars().asJava) } // When spark.worker.separated.working.directory option is turned on, each diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 081c721f2368..9fe9d83a705b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -31,7 +31,7 @@ import org.apache.hadoop.mapred.TextOutputFormat import org.apache.spark._ import org.apache.spark.Partitioner._ -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{Since, DeveloperApi} import org.apache.spark.api.java.JavaRDD import org.apache.spark.partial.BoundedDouble import org.apache.spark.partial.CountEvaluator @@ -242,6 +242,12 @@ abstract class RDD[T: ClassTag]( } } + /** + * Returns the number of partitions of this RDD. + */ + @Since("1.6.0") + final def getNumPartitions: Int = partitions.length + /** * Get the preferred locations of a partition, taking into account whether the * RDD is checkpointed. 
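(Illustrative aside, not part of the patch: the accessor added above in use. It assumes a SparkContext `sc`; the numbers are placeholders.)

val rdd = sc.parallelize(1 to 100, numSlices = 4)
rdd.getNumPartitions     // 4, available since 1.6.0
rdd.partitions.length    // the older, equivalent spelling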
@@ -294,7 +300,11 @@ abstract class RDD[T: ClassTag]( */ private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] = { - if (isCheckpointed) firstParent[T].iterator(split, context) else compute(split, context) + if (isCheckpointedAndMaterialized) { + firstParent[T].iterator(split, context) + } else { + compute(split, context) + } } /** @@ -469,50 +479,44 @@ abstract class RDD[T: ClassTag]( * @param seed seed for the random number generator * @return sample of specified size in an array */ - // TODO: rewrite this without return statements so we can wrap it in a scope def takeSample( withReplacement: Boolean, num: Int, - seed: Long = Utils.random.nextLong): Array[T] = { + seed: Long = Utils.random.nextLong): Array[T] = withScope { val numStDev = 10.0 - if (num < 0) { - throw new IllegalArgumentException("Negative number of elements requested") - } else if (num == 0) { - return new Array[T](0) - } - - val initialCount = this.count() - if (initialCount == 0) { - return new Array[T](0) - } - - val maxSampleSize = Int.MaxValue - (numStDev * math.sqrt(Int.MaxValue)).toInt - if (num > maxSampleSize) { - throw new IllegalArgumentException("Cannot support a sample size > Int.MaxValue - " + - s"$numStDev * math.sqrt(Int.MaxValue)") - } - - val rand = new Random(seed) - if (!withReplacement && num >= initialCount) { - return Utils.randomizeInPlace(this.collect(), rand) - } - - val fraction = SamplingUtils.computeFractionForSampleSize(num, initialCount, - withReplacement) - - var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() + require(num >= 0, "Negative number of elements requested") + require(num <= (Int.MaxValue - (numStDev * math.sqrt(Int.MaxValue)).toInt), + "Cannot support a sample size > Int.MaxValue - " + + s"$numStDev * math.sqrt(Int.MaxValue)") - // If the first sample didn't turn out large enough, keep trying to take samples; - // this shouldn't happen often because we use a big multiplier for the initial size - var numIters = 0 - while (samples.length < num) { - logWarning(s"Needed to re-sample due to insufficient sample size. Repeat #$numIters") - samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() - numIters += 1 + if (num == 0) { + new Array[T](0) + } else { + val initialCount = this.count() + if (initialCount == 0) { + new Array[T](0) + } else { + val rand = new Random(seed) + if (!withReplacement && num >= initialCount) { + Utils.randomizeInPlace(this.collect(), rand) + } else { + val fraction = SamplingUtils.computeFractionForSampleSize(num, initialCount, + withReplacement) + var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() + + // If the first sample didn't turn out large enough, keep trying to take samples; + // this shouldn't happen often because we use a big multiplier for the initial size + var numIters = 0 + while (samples.length < num) { + logWarning(s"Needed to re-sample due to insufficient sample size. Repeat #$numIters") + samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() + numIters += 1 + } + Utils.randomizeInPlace(samples, rand).take(num) + } + } } - - Utils.randomizeInPlace(samples, rand).take(num) } /** @@ -707,6 +711,24 @@ abstract class RDD[T: ClassTag]( preservesPartitioning) } + /** + * [performance] Spark's internal mapPartitions method which skips closure cleaning. It is a + * performance API to be used carefully only if we are sure that the RDD elements are + * serializable and don't require closure cleaning. 
+ * + * @param preservesPartitioning indicates whether the input function preserves the partitioner, + * which should be `false` unless this is a pair RDD and the input function doesn't modify + * the keys. + */ + private[spark] def mapPartitionsInternal[U: ClassTag]( + f: Iterator[T] => Iterator[U], + preservesPartitioning: Boolean = false): RDD[U] = withScope { + new MapPartitionsRDD( + this, + (context: TaskContext, index: Int, iter: Iterator[T]) => f(iter), + preservesPartitioning) + } + /** * Return a new RDD by applying a function to each partition of this RDD, while tracking the index * of the original partition. @@ -1121,11 +1143,9 @@ abstract class RDD[T: ClassTag]( def count(): Long = sc.runJob(this, Utils.getIteratorSize _).sum /** - * :: Experimental :: * Approximate version of count() that returns a potentially incomplete result * within a timeout, even if not all tasks have finished. */ - @Experimental def countApprox( timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = withScope { @@ -1154,10 +1174,8 @@ abstract class RDD[T: ClassTag]( } /** - * :: Experimental :: * Approximate version of countByValue(). */ - @Experimental def countByValueApprox(timeout: Long, confidence: Double = 0.95) (implicit ord: Ordering[T] = null) : PartialResult[Map[T, BoundedDouble]] = withScope { @@ -1176,7 +1194,6 @@ abstract class RDD[T: ClassTag]( } /** - * :: Experimental :: * Return approximate number of distinct elements in the RDD. * * The algorithm used is based on streamlib's implementation of "HyperLogLog in Practice: @@ -1192,7 +1209,6 @@ abstract class RDD[T: ClassTag]( * @param sp The precision value for the sparse set, between 0 and 32. * If `sp` equals 0, the sparse representation is skipped. */ - @Experimental def countApproxDistinct(p: Int, sp: Int): Long = withScope { require(p >= 4, s"p ($p) must be >= 4") require(sp <= 32, s"sp ($sp) must be <= 32") @@ -1317,7 +1333,8 @@ abstract class RDD[T: ClassTag]( /** * Returns the top k (largest) elements from this RDD as defined by the specified - * implicit Ordering[T]. This does the opposite of [[takeOrdered]]. For example: + * implicit Ordering[T] and maintains the ordering. This does the opposite of + * [[takeOrdered]]. For example: * {{{ * sc.parallelize(Seq(10, 4, 2, 12, 3)).top(1) * // returns Array(12) @@ -1526,20 +1543,37 @@ abstract class RDD[T: ClassTag]( persist(LocalRDDCheckpointData.transformStorageLevel(storageLevel), allowOverride = true) } - checkpointData match { - case Some(reliable: ReliableRDDCheckpointData[_]) => logWarning( - "RDD was already marked for reliable checkpointing: overriding with local checkpoint.") - case _ => + // If this RDD is already checkpointed and materialized, its lineage is already truncated. + // We must not override our `checkpointData` in this case because it is needed to recover + // the checkpointed data. If it is overridden, next time materializing on this RDD will + // cause error. 
+ if (isCheckpointedAndMaterialized) { + logWarning("Not marking RDD for local checkpoint because it was already " + + "checkpointed and materialized") + } else { + // Lineage is not truncated yet, so just override any existing checkpoint data with ours + checkpointData match { + case Some(_: ReliableRDDCheckpointData[_]) => logWarning( + "RDD was already marked for reliable checkpointing: overriding with local checkpoint.") + case _ => + } + checkpointData = Some(new LocalRDDCheckpointData(this)) } - checkpointData = Some(new LocalRDDCheckpointData(this)) this } /** - * Return whether this RDD is marked for checkpointing, either reliably or locally. + * Return whether this RDD is checkpointed and materialized, either reliably or locally. */ def isCheckpointed: Boolean = checkpointData.exists(_.isCheckpointed) + /** + * Return whether this RDD is checkpointed and materialized, either reliably or locally. + * This is introduced as an alias for `isCheckpointed` to clarify the semantics of the + * return value. Exposed for testing. + */ + private[spark] def isCheckpointedAndMaterialized: Boolean = isCheckpointed + /** * Return whether this RDD is marked for local checkpointing. * Exposed for testing. @@ -1666,7 +1700,7 @@ abstract class RDD[T: ClassTag]( import Utils.bytesToString val persistence = if (storageLevel != StorageLevel.NONE) storageLevel.description else "" - val storageInfo = rdd.context.getRDDStorageInfo.filter(_.id == rdd.id).map(info => + val storageInfo = rdd.context.getRDDStorageInfo(_.id == rdd.id).map(info => " CachedPartitions: %d; MemorySize: %s; ExternalBlockStoreSize: %s; DiskSize: %s".format( info.numCachedPartitions, bytesToString(info.memSize), bytesToString(info.externalBlockStoreSize), bytesToString(info.diskSize))) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index 0e43520870c0..429514b4f6be 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -36,7 +36,7 @@ private[spark] object CheckpointState extends Enumeration { * as well as, manages the post-checkpoint state by providing the updated partitions, * iterator and preferred locations of the checkpointed RDD. 
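(Illustrative aside, not part of the patch: a walk-through of the guard added above. It assumes a SparkContext `sc`; the checkpoint directory is a placeholder.)

sc.setCheckpointDir("/tmp/checkpoints")   // placeholder path
val rdd = sc.parallelize(1 to 10).map(_ * 2)
rdd.checkpoint()                          // request a reliable checkpoint
rdd.count()                               // first action materializes it
rdd.isCheckpointed                        // true: checkpointed and materialized
rdd.localCheckpoint()                     // existing checkpoint data is kept; only the
                                          // "Not marking RDD for local checkpoint" warning is logged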
*/ -private[spark] abstract class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) +private[spark] abstract class RDDCheckpointData[T: ClassTag](@transient private val rdd: RDD[T]) extends Serializable { import CheckpointState._ diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala b/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala index 44667281c106..540cbd688b63 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala @@ -23,6 +23,7 @@ import com.fasterxml.jackson.annotation.{JsonIgnore, JsonInclude, JsonPropertyOr import com.fasterxml.jackson.annotation.JsonInclude.Include import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.module.scala.DefaultScalaModule +import com.google.common.base.Objects import org.apache.spark.{Logging, SparkContext} @@ -67,6 +68,8 @@ private[spark] class RDDOperationScope( } } + override def hashCode(): Int = Objects.hashCode(id, name, parent) + override def toString: String = toJson } diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala index 35d8b0bfd18c..fa71b8c26233 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala @@ -20,21 +20,22 @@ package org.apache.spark.rdd import java.io.IOException import scala.reflect.ClassTag +import scala.util.control.NonFatal import org.apache.hadoop.fs.Path import org.apache.spark._ import org.apache.spark.broadcast.Broadcast -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.util.{SerializableConfiguration, Utils} /** * An RDD that reads from checkpoint files previously written to reliable storage. */ private[spark] class ReliableCheckpointRDD[T: ClassTag]( - @transient sc: SparkContext, - val checkpointPath: String) - extends CheckpointRDD[T](sc) { + sc: SparkContext, + val checkpointPath: String, + _partitioner: Option[Partitioner] = None + ) extends CheckpointRDD[T](sc) { @transient private val hadoopConf = sc.hadoopConfiguration @transient private val cpath = new Path(checkpointPath) @@ -47,7 +48,13 @@ private[spark] class ReliableCheckpointRDD[T: ClassTag]( /** * Return the path of the checkpoint directory this RDD reads data from. */ - override def getCheckpointFile: Option[String] = Some(checkpointPath) + override val getCheckpointFile: Option[String] = Some(checkpointPath) + + override val partitioner: Option[Partitioner] = { + _partitioner.orElse { + ReliableCheckpointRDD.readCheckpointedPartitionerFile(context, checkpointPath) + } + } /** * Return partitions described by the files in the checkpoint directory. @@ -100,10 +107,52 @@ private[spark] object ReliableCheckpointRDD extends Logging { "part-%05d".format(partitionIndex) } + private def checkpointPartitionerFileName(): String = { + "_partitioner" + } + + /** + * Write RDD to checkpoint files and return a ReliableCheckpointRDD representing the RDD. 
+ */ + def writeRDDToCheckpointDirectory[T: ClassTag]( + originalRDD: RDD[T], + checkpointDir: String, + blockSize: Int = -1): ReliableCheckpointRDD[T] = { + + val sc = originalRDD.sparkContext + + // Create the output path for the checkpoint + val checkpointDirPath = new Path(checkpointDir) + val fs = checkpointDirPath.getFileSystem(sc.hadoopConfiguration) + if (!fs.mkdirs(checkpointDirPath)) { + throw new SparkException(s"Failed to create checkpoint path $checkpointDirPath") + } + + // Save to file, and reload it as an RDD + val broadcastedConf = sc.broadcast( + new SerializableConfiguration(sc.hadoopConfiguration)) + // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582) + sc.runJob(originalRDD, + writePartitionToCheckpointFile[T](checkpointDirPath.toString, broadcastedConf) _) + + if (originalRDD.partitioner.nonEmpty) { + writePartitionerToCheckpointDir(sc, originalRDD.partitioner.get, checkpointDirPath) + } + + val newRDD = new ReliableCheckpointRDD[T]( + sc, checkpointDirPath.toString, originalRDD.partitioner) + if (newRDD.partitions.length != originalRDD.partitions.length) { + throw new SparkException( + s"Checkpoint RDD $newRDD(${newRDD.partitions.length}) has different " + + s"number of partitions from original RDD $originalRDD(${originalRDD.partitions.length})") + } + newRDD + } + /** - * Write this partition's values to a checkpoint file. + * Write a RDD partition's data to a checkpoint file. */ - def writeCheckpointFile[T: ClassTag]( + def writePartitionToCheckpointFile[T: ClassTag]( path: String, broadcastedConf: Broadcast[SerializableConfiguration], blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) { @@ -144,8 +193,71 @@ private[spark] object ReliableCheckpointRDD extends Logging { } else { // Some other copy of this task must've finished before us and renamed it logInfo(s"Final output path $finalOutputPath already exists; not overwriting it") - fs.delete(tempOutputPath, false) + if (!fs.delete(tempOutputPath, false)) { + logWarning(s"Error deleting ${tempOutputPath}") + } + } + } + } + + /** + * Write a partitioner to the given RDD checkpoint directory. This is done on a best-effort + * basis; any exception while writing the partitioner is caught, logged and ignored. + */ + private def writePartitionerToCheckpointDir( + sc: SparkContext, partitioner: Partitioner, checkpointDirPath: Path): Unit = { + try { + val partitionerFilePath = new Path(checkpointDirPath, checkpointPartitionerFileName) + val bufferSize = sc.conf.getInt("spark.buffer.size", 65536) + val fs = partitionerFilePath.getFileSystem(sc.hadoopConfiguration) + val fileOutputStream = fs.create(partitionerFilePath, false, bufferSize) + val serializer = SparkEnv.get.serializer.newInstance() + val serializeStream = serializer.serializeStream(fileOutputStream) + Utils.tryWithSafeFinally { + serializeStream.writeObject(partitioner) + } { + serializeStream.close() + } + logDebug(s"Written partitioner to $partitionerFilePath") + } catch { + case NonFatal(e) => + logWarning(s"Error writing partitioner $partitioner to $checkpointDirPath") + } + } + + + /** + * Read a partitioner from the given RDD checkpoint directory, if it exists. + * This is done on a best-effort basis; any exception while reading the partitioner is + * caught, logged and ignored. 
+ */ + private def readCheckpointedPartitionerFile( + sc: SparkContext, + checkpointDirPath: String): Option[Partitioner] = { + try { + val bufferSize = sc.conf.getInt("spark.buffer.size", 65536) + val partitionerFilePath = new Path(checkpointDirPath, checkpointPartitionerFileName) + val fs = partitionerFilePath.getFileSystem(sc.hadoopConfiguration) + if (fs.exists(partitionerFilePath)) { + val fileInputStream = fs.open(partitionerFilePath, bufferSize) + val serializer = SparkEnv.get.serializer.newInstance() + val deserializeStream = serializer.deserializeStream(fileInputStream) + val partitioner = Utils.tryWithSafeFinally[Partitioner] { + deserializeStream.readObject[Partitioner] + } { + deserializeStream.close() + } + logDebug(s"Read partitioner from $partitionerFilePath") + Some(partitioner) + } else { + logDebug("No partitioner file") + None } + } catch { + case NonFatal(e) => + logWarning(s"Error reading partitioner from $checkpointDirPath, " + + s"partitioner will not be recovered which may lead to performance loss", e) + None } } diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala index 1df8eef5ff2b..cac6cbe780e9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala @@ -28,7 +28,7 @@ import org.apache.spark.util.SerializableConfiguration * An implementation of checkpointing that writes the RDD data to reliable storage. * This allows drivers to be restarted on failure with previously computed state. */ -private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) +private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private val rdd: RDD[T]) extends RDDCheckpointData[T](rdd) with Logging { // The directory to which the associated RDD has been checkpointed to @@ -55,25 +55,7 @@ private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient rdd: RDD[ * This is called immediately after the first action invoked on this RDD has completed. 
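(Illustrative aside, not part of the patch: what the partitioner file written above provides. It assumes a SparkContext `sc`; the path and partition count are placeholders.)

import org.apache.spark.HashPartitioner

sc.setCheckpointDir("/tmp/checkpoints")   // placeholder path
val byKey = sc.parallelize(Seq(("a", 1), ("b", 2))).partitionBy(new HashPartitioner(4))
byKey.checkpoint()
byKey.count()   // materializes the checkpoint; a "_partitioner" file is written next to the
                // part-***** files on a best-effort basis
// When the checkpointed data is later read back as a ReliableCheckpointRDD, its partitioner is
// restored from that file instead of silently defaulting to None.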
*/ protected override def doCheckpoint(): CheckpointRDD[T] = { - - // Create the output path for the checkpoint - val path = new Path(cpDir) - val fs = path.getFileSystem(rdd.context.hadoopConfiguration) - if (!fs.mkdirs(path)) { - throw new SparkException(s"Failed to create checkpoint path $cpDir") - } - - // Save to file, and reload it as an RDD - val broadcastedConf = rdd.context.broadcast( - new SerializableConfiguration(rdd.context.hadoopConfiguration)) - // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582) - rdd.context.runJob(rdd, ReliableCheckpointRDD.writeCheckpointFile[T](cpDir, broadcastedConf) _) - val newRDD = new ReliableCheckpointRDD[T](rdd.context, cpDir) - if (newRDD.partitions.length != rdd.partitions.length) { - throw new SparkException( - s"Checkpoint RDD $newRDD(${newRDD.partitions.length}) has different " + - s"number of partitions from original RDD $rdd(${rdd.partitions.length})") - } + val newRDD = ReliableCheckpointRDD.writeRDDToCheckpointDirectory(rdd, cpDir) // Optionally clean our checkpoint files if the reference is out of scope if (rdd.conf.getBoolean("spark.cleaner.referenceTracking.cleanCheckpoints", false)) { @@ -83,13 +65,12 @@ private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient rdd: RDD[ } logInfo(s"Done checkpointing RDD ${rdd.id} to $cpDir, new parent is RDD ${newRDD.id}") - newRDD } } -private[spark] object ReliableRDDCheckpointData { +private[spark] object ReliableRDDCheckpointData extends Logging { /** Return the path of the directory to which this RDD's checkpoint data is written. */ def checkpointPath(sc: SparkContext, rddId: Int): Option[Path] = { @@ -101,7 +82,9 @@ private[spark] object ReliableRDDCheckpointData { checkpointPath(sc, rddId).foreach { path => val fs = path.getFileSystem(sc.hadoopConfiguration) if (fs.exists(path)) { - fs.delete(path, true) + if (!fs.delete(path, true)) { + logWarning(s"Error deleting ${path.toString()}") + } } } } diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index 2dc47f95937c..3ef506e1562b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -17,6 +17,8 @@ package org.apache.spark.rdd +import scala.reflect.ClassTag + import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.serializer.Serializer @@ -37,7 +39,7 @@ private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { */ // TODO: Make this return RDD[Product2[K, C]] or have some way to configure mutable pairs @DeveloperApi -class ShuffledRDD[K, V, C]( +class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag]( @transient var prev: RDD[_ <: Product2[K, V]], part: Partitioner) extends RDD[(K, C)](prev.context, Nil) { @@ -84,6 +86,12 @@ class ShuffledRDD[K, V, C]( Array.tabulate[Partition](part.numPartitions)(i => new ShuffledRDDPartition(i)) } + override protected def getPreferredLocations(partition: Partition): Seq[String] = { + val tracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] + val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]] + tracker.getPreferredLocationsForShuffle(dep, partition.index) + } + override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = { val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]] SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, 
context) diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala new file mode 100644 index 000000000000..3f15fff79366 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.apache.spark.unsafe.types.UTF8String + +/** + * State for SqlNewHadoopRDD objects. This is split this way because of the package splits. + * TODO: Move/Combine this with org.apache.spark.sql.datasources.SqlNewHadoopRDD + */ +private[spark] object SqlNewHadoopRDDState { + /** + * The thread variable for the name of the current file being read. This is used by + * the InputFileName function in Spark SQL. + */ + private[this] val inputFileName: ThreadLocal[UTF8String] = new ThreadLocal[UTF8String] { + override protected def initialValue(): UTF8String = UTF8String.fromString("") + } + + def getInputFileName(): UTF8String = inputFileName.get() + + private[spark] def setInputFileName(file: String) = inputFileName.set(UTF8String.fromString(file)) + + private[spark] def unsetInputFileName(): Unit = inputFileName.remove() + +} diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index f7cb1791d4ac..25ec685eff5a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -19,7 +19,7 @@ package org.apache.spark.rdd import java.util.{HashMap => JHashMap} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag @@ -63,15 +63,17 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( } override def getDependencies: Seq[Dependency[_]] = { - Seq(rdd1, rdd2).map { rdd => + def rddDependency[T1: ClassTag, T2: ClassTag](rdd: RDD[_ <: Product2[T1, T2]]) + : Dependency[_] = { if (rdd.partitioner == Some(part)) { logDebug("Adding one-to-one dependency with " + rdd) new OneToOneDependency(rdd) } else { logDebug("Adding shuffle dependency with " + rdd) - new ShuffleDependency(rdd, part, serializer) + new ShuffleDependency[T1, T2, Any](rdd, part, serializer) } } + Seq(rddDependency[K, V](rdd1), rddDependency[K, W](rdd2)) } override def getPartitions: Array[Partition] = { @@ -105,7 +107,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( seq } } - def integrate(depNum: Int, op: Product2[K, V] => Unit) = { + def integrate(depNum: Int, op: Product2[K, V] => Unit): Unit = { dependencies(depNum) match { case oneToOneDependency: OneToOneDependency[_] 
=> val dependencyPartition = partition.narrowDeps(depNum).get.split @@ -125,7 +127,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( integrate(0, t => getSeq(t._1) += t._2) // the second dep is rdd2; remove all of its keys integrate(1, t => map.remove(t._1)) - map.iterator.map { t => t._2.iterator.map { (t._1, _) } }.flatten + map.asScala.iterator.map(t => t._2.iterator.map((t._1, _))).flatten } override def clearDependencies() { diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala index 3986645350a8..66cf4369da2e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala @@ -37,9 +37,9 @@ import org.apache.spark.util.Utils */ private[spark] class UnionPartition[T: ClassTag]( idx: Int, - @transient rdd: RDD[T], + @transient private val rdd: RDD[T], val parentRddIndex: Int, - @transient parentRddPartitionIndex: Int) + @transient private val parentRddPartitionIndex: Int) extends Partition { var parentPartition: Partition = rdd.partitions(parentRddPartitionIndex) diff --git a/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala new file mode 100644 index 000000000000..e3f14fe7ef0f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.apache.hadoop.conf.{Configurable, Configuration} +import org.apache.hadoop.io.{Text, Writable} +import org.apache.hadoop.mapreduce.InputSplit + +import org.apache.spark.{Partition, SparkContext} +import org.apache.spark.input.WholeTextFileInputFormat + +/** + * An RDD that reads a bunch of text files in, and each text file becomes one record. 
+ */ +private[spark] class WholeTextFileRDD( + sc : SparkContext, + inputFormatClass: Class[_ <: WholeTextFileInputFormat], + keyClass: Class[Text], + valueClass: Class[Text], + conf: Configuration, + minPartitions: Int) + extends NewHadoopRDD[Text, Text](sc, inputFormatClass, keyClass, valueClass, conf) { + + override def getPartitions: Array[Partition] = { + val inputFormat = inputFormatClass.newInstance + val conf = getConf + inputFormat match { + case configurable: Configurable => + configurable.setConf(conf) + case _ => + } + val jobContext = newJobContext(conf, jobId) + inputFormat.setMinPartitions(jobContext, minPartitions) + val rawSplits = inputFormat.getSplits(jobContext).toArray + val result = new Array[Partition](rawSplits.size) + for (i <- 0 until rawSplits.size) { + result(i) = new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) + } + result + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala index 81f40ad33aa5..4333a679c8aa 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala @@ -26,7 +26,7 @@ import org.apache.spark.util.Utils private[spark] class ZippedPartitionsPartition( idx: Int, - @transient rdds: Seq[RDD[_]], + @transient private val rdds: Seq[RDD[_]], @transient val preferredLocations: Seq[String]) extends Partition { diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala index e277ae28d588..32931d59acb1 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala @@ -37,7 +37,7 @@ class ZippedWithIndexRDDPartition(val prev: Partition, val startIndex: Long) * @tparam T parent RDD item type */ private[spark] -class ZippedWithIndexRDD[T: ClassTag](@transient prev: RDD[T]) extends RDD[(T, Long)](prev) { +class ZippedWithIndexRDD[T: ClassTag](prev: RDD[T]) extends RDD[(T, Long)](prev) { /** The start index of each partition. */ @transient private val startIndices: Array[Long] = { diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcAddress.scala b/core/src/main/scala/org/apache/spark/rpc/RpcAddress.scala new file mode 100644 index 000000000000..eb0b26947f50 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/RpcAddress.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc + +import org.apache.spark.util.Utils + + +/** + * Address for an RPC environment, with hostname and port. 
+ */ +private[spark] case class RpcAddress(host: String, port: Int) { + + def hostPort: String = host + ":" + port + + /** Returns a string in the form of "spark://host:port". */ + def toSparkURL: String = "spark://" + hostPort + + override def toString: String = hostPort +} + + +private[spark] object RpcAddress { + + /** Return the [[RpcAddress]] represented by `uri`. */ + def fromURIString(uri: String): RpcAddress = { + val uriObj = new java.net.URI(uri) + RpcAddress(uriObj.getHost, uriObj.getPort) + } + + /** Returns the [[RpcAddress]] encoded in the form of "spark://host:port" */ + def fromSparkURL(sparkUrl: String): RpcAddress = { + val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl) + RpcAddress(host, port) + } +} diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala index 3e5b64265e91..f527ec86ab7b 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala @@ -37,5 +37,5 @@ private[spark] trait RpcCallContext { /** * The sender of this message. */ - def sender: RpcEndpointRef + def senderAddress: RpcAddress } diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala index dfcbc51cdf61..0ba95169529e 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala @@ -28,20 +28,6 @@ private[spark] trait RpcEnvFactory { def create(config: RpcEnvConfig): RpcEnv } -/** - * A trait that requires RpcEnv thread-safely sending messages to it. - * - * Thread-safety means processing of one message happens before processing of the next message by - * the same [[ThreadSafeRpcEndpoint]]. In the other words, changes to internal fields of a - * [[ThreadSafeRpcEndpoint]] are visible when processing the next message, and fields in the - * [[ThreadSafeRpcEndpoint]] need not be volatile or equivalent. - * - * However, there is no guarantee that the same thread will be executing the same - * [[ThreadSafeRpcEndpoint]] for different messages. - */ -private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint - - /** * An end point for the RPC that defines what functions to trigger given a message. * @@ -101,38 +87,39 @@ private[spark] trait RpcEndpoint { } /** - * Invoked before [[RpcEndpoint]] starts to handle any message. + * Invoked when `remoteAddress` is connected to the current node. */ - def onStart(): Unit = { + def onConnected(remoteAddress: RpcAddress): Unit = { // By default, do nothing. } /** - * Invoked when [[RpcEndpoint]] is stopping. + * Invoked when `remoteAddress` is lost. */ - def onStop(): Unit = { + def onDisconnected(remoteAddress: RpcAddress): Unit = { // By default, do nothing. } /** - * Invoked when `remoteAddress` is connected to the current node. + * Invoked when some network error happens in the connection between the current node and + * `remoteAddress`. */ - def onConnected(remoteAddress: RpcAddress): Unit = { + def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { // By default, do nothing. } /** - * Invoked when `remoteAddress` is lost. + * Invoked before [[RpcEndpoint]] starts to handle any message. */ - def onDisconnected(remoteAddress: RpcAddress): Unit = { + def onStart(): Unit = { // By default, do nothing. } /** - * Invoked when some network error happens in the connection between the current node and - * `remoteAddress`. 
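(Illustrative aside, not part of the patch: the helpers on the RpcAddress case class introduced above. RpcAddress is private[spark], so this only compiles from Spark-internal code; the host and port are placeholders.)

val addr = RpcAddress("rpc-host", 7077)
addr.hostPort                                       // "rpc-host:7077"
addr.toSparkURL                                     // "spark://rpc-host:7077"
RpcAddress.fromSparkURL("spark://rpc-host:7077")    // RpcAddress("rpc-host", 7077)
RpcAddress.fromURIString("spark://rpc-host:7077")   // same address, parsed via java.net.URI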
+ * Invoked when [[RpcEndpoint]] is stopping. `self` will be `null` in this method and you cannot + * use it to send or ask messages. */ - def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + def onStop(): Unit = { // By default, do nothing. } @@ -146,3 +133,16 @@ private[spark] trait RpcEndpoint { } } } + +/** + * A trait that requires RpcEnv thread-safely sending messages to it. + * + * Thread-safety means processing of one message happens before processing of the next message by + * the same [[ThreadSafeRpcEndpoint]]. In the other words, changes to internal fields of a + * [[ThreadSafeRpcEndpoint]] are visible when processing the next message, and fields in the + * [[ThreadSafeRpcEndpoint]] need not be volatile or equivalent. + * + * However, there is no guarantee that the same thread will be executing the same + * [[ThreadSafeRpcEndpoint]] for different messages. + */ +private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointNotFoundException.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointNotFoundException.scala new file mode 100644 index 000000000000..d177881fb305 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointNotFoundException.scala @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.rpc + +import org.apache.spark.SparkException + +private[rpc] class RpcEndpointNotFoundException(uri: String) + extends SparkException(s"Cannot find endpoint: $uri") diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala index 6ae47894598b..623da3e9c11b 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala @@ -26,7 +26,7 @@ import org.apache.spark.{SparkException, Logging, SparkConf} /** * A reference for a remote [[RpcEndpoint]]. [[RpcEndpointRef]] is thread-safe. */ -private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) +private[spark] abstract class RpcEndpointRef(conf: SparkConf) extends Serializable with Logging { private[this] val maxRetries = RpcUtils.numRetries(conf) @@ -67,7 +67,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) * The default `timeout` will be used in every trial of calling `sendWithReply`. Because this * method retries, the message handling in the receiver side should be idempotent. * - * Note: this is a blocking action which may cost a lot of time, so don't call it in an message + * Note: this is a blocking action which may cost a lot of time, so don't call it in a message * loop of [[RpcEndpoint]]. 
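// To make the lifecycle contract above concrete: a minimal endpoint wired to these callbacks,
// plus the corresponding ref-side calls. Illustrative sketch only; the endpoint name and log
// messages are made up, and it assumes code compiled inside Spark (the rpc package is
// private[spark]).
import org.apache.spark.Logging
import org.apache.spark.rpc._

class EchoEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging {

  override def onStart(): Unit = logInfo("echo endpoint started")

  // One-way messages delivered via RpcEndpointRef.send
  override def receive: PartialFunction[Any, Unit] = {
    case msg => logInfo(s"received $msg")
  }

  // Request/reply messages delivered via RpcEndpointRef.ask
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case msg => context.reply(msg)
  }

  override def onDisconnected(remoteAddress: RpcAddress): Unit =
    logWarning(s"connection to $remoteAddress lost")

  override def onStop(): Unit = logInfo("echo endpoint stopped")
}

// Registration and use through the ref (do not block like this inside a message loop):
// val ref = rpcEnv.setupEndpoint("echo", new EchoEndpoint(rpcEnv))
// ref.send("fire-and-forget")
// val reply = ref.ask[String]("ping")   // Future[String], governed by the default ask timeout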
* * @param message the message to send @@ -82,7 +82,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) * retries. `timeout` will be used in every trial of calling `sendWithReply`. Because this method * retries, the message handling in the receiver side should be idempotent. * - * Note: this is a blocking action which may cost a lot of time, so don't call it in an message + * Note: this is a blocking action which may cost a lot of time, so don't call it in a message * loop of [[RpcEndpoint]]. * * @param message the message to send @@ -100,7 +100,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) val future = ask[T](message, timeout) val result = timeout.awaitResult(future) if (result == null) { - throw new SparkException("Actor returned null") + throw new SparkException("RpcEndpoint returned null") } return result } catch { diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 29debe808130..64a4a8bf7c5e 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -17,12 +17,10 @@ package org.apache.spark.rpc -import java.net.URI -import java.util.concurrent.TimeoutException +import java.io.File +import java.nio.channels.ReadableByteChannel -import scala.concurrent.{Awaitable, Await, Future} -import scala.concurrent.duration._ -import scala.language.postfixOps +import scala.concurrent.Future import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.util.{RpcUtils, Utils} @@ -35,9 +33,10 @@ import org.apache.spark.util.{RpcUtils, Utils} private[spark] object RpcEnv { private def getRpcEnvFactory(conf: SparkConf): RpcEnvFactory = { - // Add more RpcEnv implementations here - val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory") - val rpcEnvName = conf.get("spark.rpc", "akka") + val rpcEnvNames = Map( + "akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory", + "netty" -> "org.apache.spark.rpc.netty.NettyRpcEnvFactory") + val rpcEnvName = conf.get("spark.rpc", "netty") val rpcEnvFactoryClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName) Utils.classForName(rpcEnvFactoryClassName).newInstance().asInstanceOf[RpcEnvFactory] } @@ -47,12 +46,12 @@ private[spark] object RpcEnv { host: String, port: Int, conf: SparkConf, - securityManager: SecurityManager): RpcEnv = { + securityManager: SecurityManager, + clientMode: Boolean = false): RpcEnv = { // Using Reflection to create the RpcEnv to avoid to depend on Akka directly - val config = RpcEnvConfig(conf, name, host, port, securityManager) + val config = RpcEnvConfig(conf, name, host, port, securityManager, clientMode) getRpcEnvFactory(conf).create(config) } - } @@ -98,15 +97,6 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { defaultLookupTimeout.awaitResult(asyncSetupEndpointRefByURI(uri)) } - /** - * Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName` - * asynchronously. - */ - def asyncSetupEndpointRef( - systemName: String, address: RpcAddress, endpointName: String): Future[RpcEndpointRef] = { - asyncSetupEndpointRefByURI(uriOf(systemName, address, endpointName)) - } - /** * Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName`. * This is a blocking action. 
@@ -145,153 +135,74 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { * that contains [[RpcEndpointRef]]s, the deserialization codes should be wrapped by this method. */ def deserialize[T](deserializationAction: () => T): T -} - - -private[spark] case class RpcEnvConfig( - conf: SparkConf, - name: String, - host: String, - port: Int, - securityManager: SecurityManager) - - -/** - * Represents a host and port. - */ -private[spark] case class RpcAddress(host: String, port: Int) { - // TODO do we need to add the type of RpcEnv in the address? - - val hostPort: String = host + ":" + port - - override val toString: String = hostPort - - def toSparkURL: String = "spark://" + hostPort -} - - -private[spark] object RpcAddress { /** - * Return the [[RpcAddress]] represented by `uri`. + * Return the instance of the file server used to serve files. This may be `null` if the + * RpcEnv is not operating in server mode. */ - def fromURI(uri: URI): RpcAddress = { - RpcAddress(uri.getHost, uri.getPort) - } + def fileServer: RpcEnvFileServer /** - * Return the [[RpcAddress]] represented by `uri`. + * Open a channel to download a file from the given URI. If the URIs returned by the + * RpcEnvFileServer use the "spark" scheme, this method will be called by the Utils class to + * retrieve the files. + * + * @param uri URI with location of the file. */ - def fromURIString(uri: String): RpcAddress = { - fromURI(new java.net.URI(uri)) - } + def openChannel(uri: String): ReadableByteChannel - def fromSparkURL(sparkUrl: String): RpcAddress = { - val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl) - RpcAddress(host, port) - } } - /** - * An exception thrown if RpcTimeout modifies a [[TimeoutException]]. - */ -private[rpc] class RpcTimeoutException(message: String, cause: TimeoutException) - extends TimeoutException(message) { initCause(cause) } - - -/** - * Associates a timeout with a description so that a when a TimeoutException occurs, additional - * context about the timeout can be amended to the exception message. - * @param duration timeout duration in seconds - * @param timeoutProp the configuration property that controls this timeout + * A server used by the RpcEnv to server files to other processes owned by the application. + * + * The file server can return URIs handled by common libraries (such as "http" or "hdfs"), or + * it can return "spark" URIs which will be handled by `RpcEnv#fetchFile`. */ -private[spark] class RpcTimeout(val duration: FiniteDuration, val timeoutProp: String) - extends Serializable { - - /** Amends the standard message of TimeoutException to include the description */ - private def createRpcTimeoutException(te: TimeoutException): RpcTimeoutException = { - new RpcTimeoutException(te.getMessage() + ". This timeout is controlled by " + timeoutProp, te) - } +private[spark] trait RpcEnvFileServer { /** - * PartialFunction to match a TimeoutException and add the timeout description to the message + * Adds a file to be served by this RpcEnv. This is used to serve files from the driver + * to executors when they're stored on the driver's local file system. * - * @note This can be used in the recover callback of a Future to add to a TimeoutException - * Example: - * val timeout = new RpcTimeout(5 millis, "short timeout") - * Future(throw new TimeoutException).recover(timeout.addMessageIfTimeout) + * @param file Local file to serve. + * @return A URI for the location of the file. 
*/ - def addMessageIfTimeout[T]: PartialFunction[Throwable, T] = { - // The exception has already been converted to a RpcTimeoutException so just raise it - case rte: RpcTimeoutException => throw rte - // Any other TimeoutException get converted to a RpcTimeoutException with modified message - case te: TimeoutException => throw createRpcTimeoutException(te) - } - - /** - * Wait for the completed result and return it. If the result is not available within this - * timeout, throw a [[RpcTimeoutException]] to indicate which configuration controls the timeout. - * @param awaitable the `Awaitable` to be awaited - * @throws RpcTimeoutException if after waiting for the specified time `awaitable` - * is still not ready - */ - def awaitResult[T](awaitable: Awaitable[T]): T = { - try { - Await.result(awaitable, duration) - } catch addMessageIfTimeout - } -} - - -private[spark] object RpcTimeout { + def addFile(file: File): String /** - * Lookup the timeout property in the configuration and create - * a RpcTimeout with the property key in the description. - * @param conf configuration properties containing the timeout - * @param timeoutProp property key for the timeout in seconds - * @throws NoSuchElementException if property is not set + * Adds a jar to be served by this RpcEnv. Similar to `addFile` but for jars added using + * `SparkContext.addJar`. + * + * @param file Local file to serve. + * @return A URI for the location of the file. */ - def apply(conf: SparkConf, timeoutProp: String): RpcTimeout = { - val timeout = { conf.getTimeAsSeconds(timeoutProp) seconds } - new RpcTimeout(timeout, timeoutProp) - } + def addJar(file: File): String /** - * Lookup the timeout property in the configuration and create - * a RpcTimeout with the property key in the description. - * Uses the given default value if property is not set - * @param conf configuration properties containing the timeout - * @param timeoutProp property key for the timeout in seconds - * @param defaultValue default timeout value in seconds if property not found - */ - def apply(conf: SparkConf, timeoutProp: String, defaultValue: String): RpcTimeout = { - val timeout = { conf.getTimeAsSeconds(timeoutProp, defaultValue) seconds } - new RpcTimeout(timeout, timeoutProp) + * Adds a local directory to be served via this file server. + * + * @param baseUri Leading URI path (files can be retrieved by appending their relative + * path to this base URI). This cannot be "files" nor "jars". + * @param path Path to the local directory. + * @return URI for the root of the directory in the file server. + */ + def addDirectory(baseUri: String, path: File): String + + /** Validates and normalizes the base URI for directories. */ + protected def validateDirectoryUri(baseUri: String): String = { + val fixedBaseUri = "/" + baseUri.stripPrefix("/").stripSuffix("/") + require(fixedBaseUri != "/files" && fixedBaseUri != "/jars", + "Directory URI cannot be /files nor /jars.") + fixedBaseUri } - /** - * Lookup prioritized list of timeout properties in the configuration - * and create a RpcTimeout with the first set property key in the - * description. 
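// A short sketch of driver-side usage of the file-server interface above. Illustrative only:
// it assumes an already-constructed RpcEnv in scope as `rpcEnv` and throwaway local paths.
import java.io.File

val confUri: String = rpcEnv.fileServer.addFile(new File("/tmp/app.conf"))
val extraUri: String = rpcEnv.fileServer.addDirectory("/custom", new File("/tmp/extra"))

// With the netty backend these URIs use the "spark" scheme and can be read back through
// openChannel; the Akka backend instead serves plain HTTP URLs (see AkkaRpcEnv below).
val channel = rpcEnv.openChannel(confUri)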
- * Uses the given default value if property is not set - * @param conf configuration properties containing the timeout - * @param timeoutPropList prioritized list of property keys for the timeout in seconds - * @param defaultValue default timeout value in seconds if no properties found - */ - def apply(conf: SparkConf, timeoutPropList: Seq[String], defaultValue: String): RpcTimeout = { - require(timeoutPropList.nonEmpty) - - // Find the first set property or use the default value with the first property - val itr = timeoutPropList.iterator - var foundProp: Option[(String, String)] = None - while (itr.hasNext && foundProp.isEmpty){ - val propKey = itr.next() - conf.getOption(propKey).foreach { prop => foundProp = Some(propKey, prop) } - } - val finalProp = foundProp.getOrElse(timeoutPropList.head, defaultValue) - val timeout = { Utils.timeStringAsSeconds(finalProp._2) seconds } - new RpcTimeout(timeout, finalProp._1) - } } + +private[spark] case class RpcEnvConfig( + conf: SparkConf, + name: String, + host: String, + port: Int, + securityManager: SecurityManager, + clientMode: Boolean) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala new file mode 100644 index 000000000000..285786ebf9f1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc + +import java.util.concurrent.TimeoutException + +import scala.concurrent.{Awaitable, Await} +import scala.concurrent.duration._ + +import org.apache.spark.SparkConf +import org.apache.spark.util.Utils + + +/** + * An exception thrown if RpcTimeout modifies a [[TimeoutException]]. + */ +private[rpc] class RpcTimeoutException(message: String, cause: TimeoutException) + extends TimeoutException(message) { initCause(cause) } + + +/** + * Associates a timeout with a description so that a when a TimeoutException occurs, additional + * context about the timeout can be amended to the exception message. + * + * @param duration timeout duration in seconds + * @param timeoutProp the configuration property that controls this timeout + */ +private[spark] class RpcTimeout(val duration: FiniteDuration, val timeoutProp: String) + extends Serializable { + + /** Amends the standard message of TimeoutException to include the description */ + private def createRpcTimeoutException(te: TimeoutException): RpcTimeoutException = { + new RpcTimeoutException(te.getMessage + ". 
This timeout is controlled by " + timeoutProp, te) + } + + /** + * PartialFunction to match a TimeoutException and add the timeout description to the message + * + * @note This can be used in the recover callback of a Future to add to a TimeoutException + * Example: + * val timeout = new RpcTimeout(5 millis, "short timeout") + * Future(throw new TimeoutException).recover(timeout.addMessageIfTimeout) + */ + def addMessageIfTimeout[T]: PartialFunction[Throwable, T] = { + // The exception has already been converted to a RpcTimeoutException so just raise it + case rte: RpcTimeoutException => throw rte + // Any other TimeoutException get converted to a RpcTimeoutException with modified message + case te: TimeoutException => throw createRpcTimeoutException(te) + } + + /** + * Wait for the completed result and return it. If the result is not available within this + * timeout, throw a [[RpcTimeoutException]] to indicate which configuration controls the timeout. + * @param awaitable the `Awaitable` to be awaited + * @throws RpcTimeoutException if after waiting for the specified time `awaitable` + * is still not ready + */ + def awaitResult[T](awaitable: Awaitable[T]): T = { + try { + Await.result(awaitable, duration) + } catch addMessageIfTimeout + } +} + + +private[spark] object RpcTimeout { + + /** + * Lookup the timeout property in the configuration and create + * a RpcTimeout with the property key in the description. + * @param conf configuration properties containing the timeout + * @param timeoutProp property key for the timeout in seconds + * @throws NoSuchElementException if property is not set + */ + def apply(conf: SparkConf, timeoutProp: String): RpcTimeout = { + val timeout = { conf.getTimeAsSeconds(timeoutProp).seconds } + new RpcTimeout(timeout, timeoutProp) + } + + /** + * Lookup the timeout property in the configuration and create + * a RpcTimeout with the property key in the description. + * Uses the given default value if property is not set + * @param conf configuration properties containing the timeout + * @param timeoutProp property key for the timeout in seconds + * @param defaultValue default timeout value in seconds if property not found + */ + def apply(conf: SparkConf, timeoutProp: String, defaultValue: String): RpcTimeout = { + val timeout = { conf.getTimeAsSeconds(timeoutProp, defaultValue).seconds } + new RpcTimeout(timeout, timeoutProp) + } + + /** + * Lookup prioritized list of timeout properties in the configuration + * and create a RpcTimeout with the first set property key in the + * description. 
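// For concreteness, the intended usage of RpcTimeout, as a sketch. Only SparkConf and the
// constructors shown above are assumed (RpcTimeout itself is private[spark]); the property
// keys mirror existing Spark settings and the values are made up.
import java.util.concurrent.TimeoutException
import scala.concurrent.Future
import scala.concurrent.ExecutionContext.Implicits.global

import org.apache.spark.SparkConf
import org.apache.spark.rpc.RpcTimeout

val conf = new SparkConf().set("spark.rpc.askTimeout", "5s")

// The first key that is set wins; "120s" is used if neither key is present.
val askTimeout = RpcTimeout(conf, Seq("spark.rpc.askTimeout", "spark.network.timeout"), "120s")

// Blocking wait: a bare TimeoutException surfaces as an RpcTimeoutException naming the property.
// val reply = askTimeout.awaitResult(someFuture)

// Non-blocking variant, as in the scaladoc above.
val tagged: Future[String] =
  Future[String](throw new TimeoutException("no reply")).recover(askTimeout.addMessageIfTimeout)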
+ * Uses the given default value if property is not set + * @param conf configuration properties containing the timeout + * @param timeoutPropList prioritized list of property keys for the timeout in seconds + * @param defaultValue default timeout value in seconds if no properties found + */ + def apply(conf: SparkConf, timeoutPropList: Seq[String], defaultValue: String): RpcTimeout = { + require(timeoutPropList.nonEmpty) + + // Find the first set property or use the default value with the first property + val itr = timeoutPropList.iterator + var foundProp: Option[(String, String)] = None + while (itr.hasNext && foundProp.isEmpty){ + val propKey = itr.next() + conf.getOption(propKey).foreach { prop => foundProp = Some(propKey, prop) } + } + val finalProp = foundProp.getOrElse(timeoutPropList.head, defaultValue) + val timeout = { Utils.timeStringAsSeconds(finalProp._2).seconds } + new RpcTimeout(timeout, finalProp._1) + } +} diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index fc17542abf81..9d098154f719 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -17,6 +17,8 @@ package org.apache.spark.rpc.akka +import java.io.File +import java.nio.channels.ReadableByteChannel import java.util.concurrent.ConcurrentHashMap import scala.concurrent.Future @@ -30,7 +32,7 @@ import akka.pattern.{ask => akkaAsk} import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent} import akka.serialization.JavaSerializer -import org.apache.spark.{SparkException, Logging, SparkConf} +import org.apache.spark.{HttpFileServer, Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.rpc._ import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils} @@ -39,13 +41,12 @@ import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils} * * TODO Once we remove all usages of Akka in other place, we can move this file to a new project and * remove Akka from the dependencies. - * - * @param actorSystem - * @param conf - * @param boundPort */ private[spark] class AkkaRpcEnv private[akka] ( - val actorSystem: ActorSystem, conf: SparkConf, boundPort: Int) + val actorSystem: ActorSystem, + val securityManager: SecurityManager, + conf: SparkConf, + boundPort: Int) extends RpcEnv(conf) with Logging { private val defaultAddress: RpcAddress = { @@ -68,6 +69,8 @@ private[spark] class AkkaRpcEnv private[akka] ( */ private val refToEndpoint = new ConcurrentHashMap[RpcEndpointRef, RpcEndpoint]() + private val _fileServer = new AkkaFileServer(conf, securityManager) + private def registerEndpoint(endpoint: RpcEndpoint, endpointRef: RpcEndpointRef): Unit = { endpointToRef.put(endpoint, endpointRef) refToEndpoint.put(endpointRef, endpoint) @@ -87,9 +90,9 @@ private[spark] class AkkaRpcEnv private[akka] ( override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = { @volatile var endpointRef: AkkaRpcEndpointRef = null - // Use lazy because the Actor needs to use `endpointRef`. + // Use defered function because the Actor needs to use `endpointRef`. // So `actorRef` should be created after assigning `endpointRef`. 
- lazy val actorRef = actorSystem.actorOf(Props(new Actor with ActorLogReceive with Logging { + val actorRef = () => actorSystem.actorOf(Props(new Actor with ActorLogReceive with Logging { assert(endpointRef != null) @@ -166,9 +169,9 @@ private[spark] class AkkaRpcEnv private[akka] ( _sender ! AkkaMessage(response, false) } - // Some RpcEndpoints need to know the sender's address - override val sender: RpcEndpointRef = - new AkkaRpcEndpointRef(defaultAddress, _sender, conf) + // Use "lazy" because most of RpcEndpoints don't need "senderAddress" + override lazy val senderAddress: RpcAddress = + new AkkaRpcEndpointRef(defaultAddress, _sender, conf).address }) } else { endpoint.receive @@ -227,6 +230,7 @@ private[spark] class AkkaRpcEnv private[akka] ( override def shutdown(): Unit = { actorSystem.shutdown() + _fileServer.shutdown() } override def stop(endpoint: RpcEndpointRef): Unit = { @@ -245,6 +249,57 @@ private[spark] class AkkaRpcEnv private[akka] ( deserializationAction() } } + + override def openChannel(uri: String): ReadableByteChannel = { + throw new UnsupportedOperationException( + "AkkaRpcEnv's files should be retrieved using an HTTP client.") + } + + override def fileServer: RpcEnvFileServer = _fileServer + +} + +private[akka] class AkkaFileServer( + conf: SparkConf, + securityManager: SecurityManager) extends RpcEnvFileServer { + + @volatile private var httpFileServer: HttpFileServer = _ + + override def addFile(file: File): String = { + getFileServer().addFile(file) + } + + override def addJar(file: File): String = { + getFileServer().addJar(file) + } + + override def addDirectory(baseUri: String, path: File): String = { + val fixedBaseUri = validateDirectoryUri(baseUri) + getFileServer().addDirectory(fixedBaseUri, path.getAbsolutePath()) + } + + def shutdown(): Unit = { + if (httpFileServer != null) { + httpFileServer.stop() + } + } + + private def getFileServer(): HttpFileServer = { + if (httpFileServer == null) synchronized { + if (httpFileServer == null) { + httpFileServer = startFileServer() + } + } + httpFileServer + } + + private def startFileServer(): HttpFileServer = { + val fileServerPort = conf.getInt("spark.fileserver.port", 0) + val server = new HttpFileServer(conf, securityManager, fileServerPort) + server.initialize() + server + } + } private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory { @@ -253,7 +308,7 @@ private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory { val (actorSystem, boundPort) = AkkaUtils.createActorSystem( config.name, config.host, config.port, config.conf, config.securityManager) actorSystem.actorOf(Props(classOf[ErrorMonitor]), "ErrorMonitor") - new AkkaRpcEnv(actorSystem, config.conf, boundPort) + new AkkaRpcEnv(actorSystem, config.securityManager, config.conf, boundPort) } } @@ -267,18 +322,25 @@ private[akka] class ErrorMonitor extends Actor with ActorLogReceive with Logging } override def receiveWithLogging: Actor.Receive = { - case Error(cause: Throwable, _, _, message: String) => logError(message, cause) + case Error(cause: Throwable, _, _, message: String) => logDebug(message, cause) } } private[akka] class AkkaRpcEndpointRef( - @transient defaultAddress: RpcAddress, - @transient _actorRef: => ActorRef, - @transient conf: SparkConf, - @transient initInConstructor: Boolean = true) + @transient private val defaultAddress: RpcAddress, + @transient private val _actorRef: () => ActorRef, + conf: SparkConf, + initInConstructor: Boolean) extends RpcEndpointRef(conf) with Logging { - lazy val actorRef = _actorRef + def this( + 
defaultAddress: RpcAddress, + _actorRef: ActorRef, + conf: SparkConf) = { + this(defaultAddress, () => _actorRef, conf, true) + } + + lazy val actorRef = _actorRef() override lazy val address: RpcAddress = { val akkaAddress = actorRef.path.address diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala new file mode 100644 index 000000000000..533c9847661b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc.netty + +import java.util.concurrent.{ThreadPoolExecutor, ConcurrentHashMap, LinkedBlockingQueue, TimeUnit} +import javax.annotation.concurrent.GuardedBy + +import scala.collection.JavaConverters._ +import scala.concurrent.Promise +import scala.util.control.NonFatal + +import org.apache.spark.{SparkException, Logging} +import org.apache.spark.network.client.RpcResponseCallback +import org.apache.spark.rpc._ +import org.apache.spark.util.ThreadUtils + +/** + * A message dispatcher, responsible for routing RPC messages to the appropriate endpoint(s). + */ +private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { + + private class EndpointData( + val name: String, + val endpoint: RpcEndpoint, + val ref: NettyRpcEndpointRef) { + val inbox = new Inbox(ref, endpoint) + } + + private val endpoints = new ConcurrentHashMap[String, EndpointData] + private val endpointRefs = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef] + + // Track the receivers whose inboxes may contain messages. + private val receivers = new LinkedBlockingQueue[EndpointData] + + /** + * True if the dispatcher has been stopped. Once stopped, all messages posted will be bounced + * immediately. 
+ */ + @GuardedBy("this") + private var stopped = false + + def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = { + val addr = RpcEndpointAddress(nettyEnv.address, name) + val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv) + synchronized { + if (stopped) { + throw new IllegalStateException("RpcEnv has been stopped") + } + if (endpoints.putIfAbsent(name, new EndpointData(name, endpoint, endpointRef)) != null) { + throw new IllegalArgumentException(s"There is already an RpcEndpoint called $name") + } + val data = endpoints.get(name) + endpointRefs.put(data.endpoint, data.ref) + receivers.offer(data) // for the OnStart message + } + endpointRef + } + + def getRpcEndpointRef(endpoint: RpcEndpoint): RpcEndpointRef = endpointRefs.get(endpoint) + + def removeRpcEndpointRef(endpoint: RpcEndpoint): Unit = endpointRefs.remove(endpoint) + + // Should be idempotent + private def unregisterRpcEndpoint(name: String): Unit = { + val data = endpoints.remove(name) + if (data != null) { + data.inbox.stop() + receivers.offer(data) // for the OnStop message + } + // Don't clean `endpointRefs` here because it's possible that some messages are being processed + // now and they can use `getRpcEndpointRef`. So `endpointRefs` will be cleaned in Inbox via + // `removeRpcEndpointRef`. + } + + def stop(rpcEndpointRef: RpcEndpointRef): Unit = { + synchronized { + if (stopped) { + // This endpoint will be stopped by Dispatcher.stop() method. + return + } + unregisterRpcEndpoint(rpcEndpointRef.name) + } + } + + /** + * Send a message to all registered [[RpcEndpoint]]s in this process. + * + * This can be used to make network events known to all end points (e.g. "a new node connected"). + */ + def postToAll(message: InboxMessage): Unit = { + val iter = endpoints.keySet().iterator() + while (iter.hasNext) { + val name = iter.next + postMessage(name, message, (e) => logWarning(s"Message $message dropped.", e)) + } + } + + /** Posts a message sent by a remote endpoint. */ + def postRemoteMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = { + val rpcCallContext = + new RemoteNettyRpcCallContext(nettyEnv, callback, message.senderAddress) + val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext) + postMessage(message.receiver.name, rpcMessage, (e) => callback.onFailure(e)) + } + + /** Posts a message sent by a local endpoint. */ + def postLocalMessage(message: RequestMessage, p: Promise[Any]): Unit = { + val rpcCallContext = + new LocalNettyRpcCallContext(message.senderAddress, p) + val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext) + postMessage(message.receiver.name, rpcMessage, (e) => p.tryFailure(e)) + } + + /** Posts a one-way message. */ + def postOneWayMessage(message: RequestMessage): Unit = { + postMessage(message.receiver.name, OneWayMessage(message.senderAddress, message.content), + (e) => throw e) + } + + /** + * Posts a message to a specific endpoint. + * + * @param endpointName name of the endpoint. + * @param createMessageFn function to create the message. + * @param callbackIfStopped callback function if the endpoint is stopped. 
+ */ + private def postMessage( + endpointName: String, + message: InboxMessage, + callbackIfStopped: (Exception) => Unit): Unit = { + val shouldCallOnStop = synchronized { + val data = endpoints.get(endpointName) + if (stopped || data == null) { + true + } else { + data.inbox.post(message) + receivers.offer(data) + false + } + } + if (shouldCallOnStop) { + // We don't need to call `onStop` in the `synchronized` block + val error = if (stopped) { + new IllegalStateException("RpcEnv already stopped.") + } else { + new SparkException(s"Could not find $endpointName or it has been stopped.") + } + callbackIfStopped(error) + } + } + + def stop(): Unit = { + synchronized { + if (stopped) { + return + } + stopped = true + } + // Stop all endpoints. This will queue all endpoints for processing by the message loops. + endpoints.keySet().asScala.foreach(unregisterRpcEndpoint) + // Enqueue a message that tells the message loops to stop. + receivers.offer(PoisonPill) + threadpool.shutdown() + } + + def awaitTermination(): Unit = { + threadpool.awaitTermination(Long.MaxValue, TimeUnit.MILLISECONDS) + } + + /** + * Return if the endpoint exists + */ + def verify(name: String): Boolean = { + endpoints.containsKey(name) + } + + /** Thread pool used for dispatching messages. */ + private val threadpool: ThreadPoolExecutor = { + val numThreads = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.numThreads", + Runtime.getRuntime.availableProcessors()) + val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop") + for (i <- 0 until numThreads) { + pool.execute(new MessageLoop) + } + pool + } + + /** Message loop used for dispatching messages. */ + private class MessageLoop extends Runnable { + override def run(): Unit = { + try { + while (true) { + try { + val data = receivers.take() + if (data == PoisonPill) { + // Put PoisonPill back so that other MessageLoops can see it. + receivers.offer(PoisonPill) + return + } + data.inbox.process(Dispatcher.this) + } catch { + case NonFatal(e) => logError(e.getMessage, e) + } + } + } catch { + case ie: InterruptedException => // exit + } + } + } + + /** A poison endpoint that indicates MessageLoop should exit its message loop. */ + private val PoisonPill = new EndpointData(null, null, null) +} diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala new file mode 100644 index 000000000000..175463cc1031 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
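// The PoisonPill handshake above is worth spelling out: a message loop that dequeues the
// sentinel re-offers it before exiting, so the single offer() in stop() is enough to shut down
// every MessageLoop. A self-contained sketch of that pattern (plain JDK/Scala only; the names
// below are illustrative, not Spark's):
import java.util.concurrent.{Executors, LinkedBlockingQueue, TimeUnit}

object PoisonPillSketch {
  sealed trait Msg
  final case class Work(payload: String) extends Msg
  case object PoisonPill extends Msg

  def main(args: Array[String]): Unit = {
    val queue = new LinkedBlockingQueue[Msg]()
    val pool = Executors.newFixedThreadPool(4)

    (1 to 4).foreach { _ =>
      pool.execute(new Runnable {
        override def run(): Unit = {
          while (true) {
            queue.take() match {
              case PoisonPill =>
                // Put the sentinel back so the other loops also see it, then exit this loop.
                queue.offer(PoisonPill)
                return
              case Work(payload) =>
                println(s"processing $payload")
            }
          }
        }
      })
    }

    (1 to 10).foreach(i => queue.offer(Work(s"msg-$i")))
    queue.offer(PoisonPill)                      // a single offer stops all four workers
    pool.shutdown()
    pool.awaitTermination(10, TimeUnit.SECONDS)
  }
}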
+ */ + +package org.apache.spark.rpc.netty + +import javax.annotation.concurrent.GuardedBy + +import scala.util.control.NonFatal + +import org.apache.spark.{Logging, SparkException} +import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, ThreadSafeRpcEndpoint} + + +private[netty] sealed trait InboxMessage + +private[netty] case class OneWayMessage( + senderAddress: RpcAddress, + content: Any) extends InboxMessage + +private[netty] case class RpcMessage( + senderAddress: RpcAddress, + content: Any, + context: NettyRpcCallContext) extends InboxMessage + +private[netty] case object OnStart extends InboxMessage + +private[netty] case object OnStop extends InboxMessage + +/** A message to tell all endpoints that a remote process has connected. */ +private[netty] case class RemoteProcessConnected(remoteAddress: RpcAddress) extends InboxMessage + +/** A message to tell all endpoints that a remote process has disconnected. */ +private[netty] case class RemoteProcessDisconnected(remoteAddress: RpcAddress) extends InboxMessage + +/** A message to tell all endpoints that a network error has happened. */ +private[netty] case class RemoteProcessConnectionError(cause: Throwable, remoteAddress: RpcAddress) + extends InboxMessage + +/** + * A inbox that stores messages for an [[RpcEndpoint]] and posts messages to it thread-safely. + */ +private[netty] class Inbox( + val endpointRef: NettyRpcEndpointRef, + val endpoint: RpcEndpoint) + extends Logging { + + inbox => // Give this an alias so we can use it more clearly in closures. + + @GuardedBy("this") + protected val messages = new java.util.LinkedList[InboxMessage]() + + /** True if the inbox (and its associated endpoint) is stopped. */ + @GuardedBy("this") + private var stopped = false + + /** Allow multiple threads to process messages at the same time. */ + @GuardedBy("this") + private var enableConcurrent = false + + /** The number of threads processing messages for this inbox. */ + @GuardedBy("this") + private var numActiveThreads = 0 + + // OnStart should be the first message to process + inbox.synchronized { + messages.add(OnStart) + } + + /** + * Process stored messages. + */ + def process(dispatcher: Dispatcher): Unit = { + var message: InboxMessage = null + inbox.synchronized { + if (!enableConcurrent && numActiveThreads != 0) { + return + } + message = messages.poll() + if (message != null) { + numActiveThreads += 1 + } else { + return + } + } + while (true) { + safelyCall(endpoint) { + message match { + case RpcMessage(_sender, content, context) => + try { + endpoint.receiveAndReply(context).applyOrElse[Any, Unit](content, { msg => + throw new SparkException(s"Unsupported message $message from ${_sender}") + }) + } catch { + case NonFatal(e) => + context.sendFailure(e) + // Throw the exception -- this exception will be caught by the safelyCall function. + // The endpoint's onError function will be called. 
+ throw e + } + + case OneWayMessage(_sender, content) => + endpoint.receive.applyOrElse[Any, Unit](content, { msg => + throw new SparkException(s"Unsupported message $message from ${_sender}") + }) + + case OnStart => + endpoint.onStart() + if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) { + inbox.synchronized { + if (!stopped) { + enableConcurrent = true + } + } + } + + case OnStop => + val activeThreads = inbox.synchronized { inbox.numActiveThreads } + assert(activeThreads == 1, + s"There should be only a single active thread but found $activeThreads threads.") + dispatcher.removeRpcEndpointRef(endpoint) + endpoint.onStop() + assert(isEmpty, "OnStop should be the last message") + + case RemoteProcessConnected(remoteAddress) => + endpoint.onConnected(remoteAddress) + + case RemoteProcessDisconnected(remoteAddress) => + endpoint.onDisconnected(remoteAddress) + + case RemoteProcessConnectionError(cause, remoteAddress) => + endpoint.onNetworkError(cause, remoteAddress) + } + } + + inbox.synchronized { + // "enableConcurrent" will be set to false after `onStop` is called, so we should check it + // every time. + if (!enableConcurrent && numActiveThreads != 1) { + // If we are not the only one worker, exit + numActiveThreads -= 1 + return + } + message = messages.poll() + if (message == null) { + numActiveThreads -= 1 + return + } + } + } + } + + def post(message: InboxMessage): Unit = inbox.synchronized { + if (stopped) { + // We already put "OnStop" into "messages", so we should drop further messages + onDrop(message) + } else { + messages.add(message) + false + } + } + + def stop(): Unit = inbox.synchronized { + // The following codes should be in `synchronized` so that we can make sure "OnStop" is the last + // message + if (!stopped) { + // We should disable concurrent here. Then when RpcEndpoint.onStop is called, it's the only + // thread that is processing messages. So `RpcEndpoint.onStop` can release its resources + // safely. + enableConcurrent = false + stopped = true + messages.add(OnStop) + // Note: The concurrent events in messages will be processed one by one. + } + } + + def isEmpty: Boolean = inbox.synchronized { messages.isEmpty } + + /** + * Called when we are dropping a message. Test cases override this to test message dropping. + * Exposed for testing. + */ + protected def onDrop(message: InboxMessage): Unit = { + logWarning(s"Drop $message because $endpointRef is stopped") + } + + /** + * Calls action closure, and calls the endpoint's onError function in the case of exceptions. + */ + private def safelyCall(endpoint: RpcEndpoint)(action: => Unit): Unit = { + try action catch { + case NonFatal(e) => + try endpoint.onError(e) catch { + case NonFatal(ee) => logError(s"Ignoring error", ee) + } + } + } + +} diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala new file mode 100644 index 000000000000..6637e2321f67 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc.netty + +import scala.concurrent.Promise + +import org.apache.spark.Logging +import org.apache.spark.network.client.RpcResponseCallback +import org.apache.spark.rpc.{RpcAddress, RpcCallContext} + +private[netty] abstract class NettyRpcCallContext(override val senderAddress: RpcAddress) + extends RpcCallContext with Logging { + + protected def send(message: Any): Unit + + override def reply(response: Any): Unit = { + send(response) + } + + override def sendFailure(e: Throwable): Unit = { + send(RpcFailure(e)) + } + +} + +/** + * If the sender and the receiver are in the same process, the reply can be sent back via `Promise`. + */ +private[netty] class LocalNettyRpcCallContext( + senderAddress: RpcAddress, + p: Promise[Any]) + extends NettyRpcCallContext(senderAddress) { + + override protected def send(message: Any): Unit = { + p.success(message) + } +} + +/** + * A [[RpcCallContext]] that will call [[RpcResponseCallback]] to send the reply back. + */ +private[netty] class RemoteNettyRpcCallContext( + nettyEnv: NettyRpcEnv, + callback: RpcResponseCallback, + senderAddress: RpcAddress) + extends NettyRpcCallContext(senderAddress) { + + override protected def send(message: Any): Unit = { + val reply = nettyEnv.serialize(message) + callback.onSuccess(reply) + } +} diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala new file mode 100644 index 000000000000..de3db6ba624f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -0,0 +1,636 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.rpc.netty + +import java.io._ +import java.lang.{Boolean => JBoolean} +import java.net.{InetSocketAddress, URI} +import java.nio.ByteBuffer +import java.nio.channels.{Pipe, ReadableByteChannel, WritableByteChannel} +import java.util.concurrent._ +import java.util.concurrent.atomic.AtomicBoolean +import javax.annotation.Nullable + +import scala.concurrent.{Future, Promise} +import scala.reflect.ClassTag +import scala.util.{DynamicVariable, Failure, Success, Try} +import scala.util.control.NonFatal + +import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.network.TransportContext +import org.apache.spark.network.client._ +import org.apache.spark.network.netty.SparkTransportConf +import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap} +import org.apache.spark.network.server._ +import org.apache.spark.rpc._ +import org.apache.spark.serializer.{JavaSerializer, JavaSerializerInstance} +import org.apache.spark.util.{ThreadUtils, Utils} + +private[netty] class NettyRpcEnv( + val conf: SparkConf, + javaSerializerInstance: JavaSerializerInstance, + host: String, + securityManager: SecurityManager) extends RpcEnv(conf) with Logging { + + private[netty] val transportConf = SparkTransportConf.fromSparkConf( + conf.clone.set("spark.rpc.io.numConnectionsPerPeer", "1"), + "rpc", + conf.getInt("spark.rpc.io.threads", 0)) + + private val dispatcher: Dispatcher = new Dispatcher(this) + + private val streamManager = new NettyStreamManager(this) + + private val transportContext = new TransportContext(transportConf, + new NettyRpcHandler(dispatcher, this, streamManager)) + + private def createClientBootstraps(): java.util.List[TransportClientBootstrap] = { + if (securityManager.isAuthenticationEnabled()) { + java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager, + securityManager.isSaslEncryptionEnabled())) + } else { + java.util.Collections.emptyList[TransportClientBootstrap] + } + } + + private val clientFactory = transportContext.createClientFactory(createClientBootstraps()) + + /** + * A separate client factory for file downloads. This avoids using the same RPC handler as + * the main RPC context, so that events caused by these clients are kept isolated from the + * main RPC traffic. + * + * It also allows for different configuration of certain properties, such as the number of + * connections per peer. + */ + @volatile private var fileDownloadFactory: TransportClientFactory = _ + + val timeoutScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("netty-rpc-env-timeout") + + // Because TransportClientFactory.createClient is blocking, we need to run it in this thread pool + // to implement non-blocking send/ask. + // TODO: a non-blocking TransportClientFactory.createClient in future + private[netty] val clientConnectionExecutor = ThreadUtils.newDaemonCachedThreadPool( + "netty-rpc-connection", + conf.getInt("spark.rpc.connect.threads", 64)) + + @volatile private var server: TransportServer = _ + + private val stopped = new AtomicBoolean(false) + + /** + * A map for [[RpcAddress]] and [[Outbox]]. When we are connecting to a remote [[RpcAddress]], + * we just put messages to its [[Outbox]] to implement a non-blocking `send` method. + */ + private val outboxes = new ConcurrentHashMap[RpcAddress, Outbox]() + + /** + * Remove the address's Outbox and stop it. 
+ */ + private[netty] def removeOutbox(address: RpcAddress): Unit = { + val outbox = outboxes.remove(address) + if (outbox != null) { + outbox.stop() + } + } + + def startServer(port: Int): Unit = { + val bootstraps: java.util.List[TransportServerBootstrap] = + if (securityManager.isAuthenticationEnabled()) { + java.util.Arrays.asList(new SaslServerBootstrap(transportConf, securityManager)) + } else { + java.util.Collections.emptyList() + } + server = transportContext.createServer(host, port, bootstraps) + dispatcher.registerRpcEndpoint( + RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher)) + } + + @Nullable + override lazy val address: RpcAddress = { + if (server != null) RpcAddress(host, server.getPort()) else null + } + + override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = { + dispatcher.registerRpcEndpoint(name, endpoint) + } + + def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = { + val addr = RpcEndpointAddress(uri) + val endpointRef = new NettyRpcEndpointRef(conf, addr, this) + val verifier = new NettyRpcEndpointRef( + conf, RpcEndpointAddress(addr.rpcAddress, RpcEndpointVerifier.NAME), this) + verifier.ask[Boolean](RpcEndpointVerifier.CheckExistence(endpointRef.name)).flatMap { find => + if (find) { + Future.successful(endpointRef) + } else { + Future.failed(new RpcEndpointNotFoundException(uri)) + } + }(ThreadUtils.sameThread) + } + + override def stop(endpointRef: RpcEndpointRef): Unit = { + require(endpointRef.isInstanceOf[NettyRpcEndpointRef]) + dispatcher.stop(endpointRef) + } + + private def postToOutbox(receiver: NettyRpcEndpointRef, message: OutboxMessage): Unit = { + if (receiver.client != null) { + message.sendWith(receiver.client) + } else { + require(receiver.address != null, + "Cannot send message to client endpoint with no listen address.") + val targetOutbox = { + val outbox = outboxes.get(receiver.address) + if (outbox == null) { + val newOutbox = new Outbox(this, receiver.address) + val oldOutbox = outboxes.putIfAbsent(receiver.address, newOutbox) + if (oldOutbox == null) { + newOutbox + } else { + oldOutbox + } + } else { + outbox + } + } + if (stopped.get) { + // It's possible that we put `targetOutbox` after stopping. So we need to clean it. + outboxes.remove(receiver.address) + targetOutbox.stop() + } else { + targetOutbox.send(message) + } + } + } + + private[netty] def send(message: RequestMessage): Unit = { + val remoteAddr = message.receiver.address + if (remoteAddr == address) { + // Message to a local RPC endpoint. + dispatcher.postOneWayMessage(message) + } else { + // Message to a remote RPC endpoint. 
+ postToOutbox(message.receiver, OneWayOutboxMessage(serialize(message))) + } + } + + private[netty] def createClient(address: RpcAddress): TransportClient = { + clientFactory.createClient(address.host, address.port) + } + + private[netty] def ask[T: ClassTag](message: RequestMessage, timeout: RpcTimeout): Future[T] = { + val promise = Promise[Any]() + val remoteAddr = message.receiver.address + + def onFailure(e: Throwable): Unit = { + if (!promise.tryFailure(e)) { + logWarning(s"Ignored failure: $e") + } + } + + def onSuccess(reply: Any): Unit = reply match { + case RpcFailure(e) => onFailure(e) + case rpcReply => + if (!promise.trySuccess(rpcReply)) { + logWarning(s"Ignored message: $reply") + } + } + + if (remoteAddr == address) { + val p = Promise[Any]() + p.future.onComplete { + case Success(response) => onSuccess(response) + case Failure(e) => onFailure(e) + }(ThreadUtils.sameThread) + dispatcher.postLocalMessage(message, p) + } else { + val rpcMessage = RpcOutboxMessage(serialize(message), + onFailure, + (client, response) => onSuccess(deserialize[Any](client, response))) + postToOutbox(message.receiver, rpcMessage) + promise.future.onFailure { + case _: TimeoutException => rpcMessage.onTimeout() + case _ => + }(ThreadUtils.sameThread) + } + + val timeoutCancelable = timeoutScheduler.schedule(new Runnable { + override def run(): Unit = { + promise.tryFailure( + new TimeoutException(s"Cannot receive any reply in ${timeout.duration}")) + } + }, timeout.duration.toNanos, TimeUnit.NANOSECONDS) + promise.future.onComplete { v => + timeoutCancelable.cancel(true) + }(ThreadUtils.sameThread) + promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread) + } + + private[netty] def serialize(content: Any): ByteBuffer = { + javaSerializerInstance.serialize(content) + } + + private[netty] def deserialize[T: ClassTag](client: TransportClient, bytes: ByteBuffer): T = { + NettyRpcEnv.currentClient.withValue(client) { + deserialize { () => + javaSerializerInstance.deserialize[T](bytes) + } + } + } + + override def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef = { + dispatcher.getRpcEndpointRef(endpoint) + } + + override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = + new RpcEndpointAddress(address, endpointName).toString + + override def shutdown(): Unit = { + cleanup() + } + + override def awaitTermination(): Unit = { + dispatcher.awaitTermination() + } + + private def cleanup(): Unit = { + if (!stopped.compareAndSet(false, true)) { + return + } + + val iter = outboxes.values().iterator() + while (iter.hasNext()) { + val outbox = iter.next() + outboxes.remove(outbox.address) + outbox.stop() + } + if (timeoutScheduler != null) { + timeoutScheduler.shutdownNow() + } + if (server != null) { + server.close() + } + if (clientFactory != null) { + clientFactory.close() + } + if (dispatcher != null) { + dispatcher.stop() + } + if (clientConnectionExecutor != null) { + clientConnectionExecutor.shutdownNow() + } + if (fileDownloadFactory != null) { + fileDownloadFactory.close() + } + } + + override def deserialize[T](deserializationAction: () => T): T = { + NettyRpcEnv.currentEnv.withValue(this) { + deserializationAction() + } + } + + override def fileServer: RpcEnvFileServer = streamManager + + override def openChannel(uri: String): ReadableByteChannel = { + val parsedUri = new URI(uri) + require(parsedUri.getHost() != null, "Host name must be defined.") + require(parsedUri.getPort() > 0, "Port must be defined.") + 
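// The ask() implementation above combines a Promise completed by the RPC callback, a scheduled
// task that fails the Promise if no reply arrives within the RpcTimeout, and cancellation of
// that task once the future completes. The same pattern in isolation (plain Scala/JDK
// concurrency; nothing below is Spark API):
import java.util.concurrent.{Executors, TimeUnit, TimeoutException}
import scala.concurrent.{ExecutionContext, Future, Promise}
import scala.concurrent.duration._

object AskTimeoutSketch {
  private implicit val ec: ExecutionContext = ExecutionContext.global
  private val scheduler = Executors.newSingleThreadScheduledExecutor()

  /** Returns a future that fails with TimeoutException unless the promise is completed in time. */
  def askWithTimeout[T](timeout: FiniteDuration)(send: Promise[T] => Unit): Future[T] = {
    val promise = Promise[T]()
    send(promise)                                    // hand the promise to the "transport"
    val timeoutTask = scheduler.schedule(new Runnable {
      override def run(): Unit =
        promise.tryFailure(new TimeoutException(s"no reply within $timeout"))
    }, timeout.toNanos, TimeUnit.NANOSECONDS)
    promise.future.onComplete(_ => timeoutTask.cancel(true))   // avoid leaking the timer task
    promise.future
  }

  def main(args: Array[String]): Unit = {
    // Reply arrives in time: the scheduled failure is cancelled.
    val ok = askWithTimeout[String](1.second) { p => p.trySuccess("pong") }
    // No reply: the scheduled task fails the promise.
    val late = askWithTimeout[String](100.millis) { _ => () }
    Thread.sleep(500)
    println((ok.value, late.value))
    scheduler.shutdown()
  }
}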
require(parsedUri.getPath() != null && parsedUri.getPath().nonEmpty, "Path must be defined.") + + val pipe = Pipe.open() + val source = new FileDownloadChannel(pipe.source()) + try { + val client = downloadClient(parsedUri.getHost(), parsedUri.getPort()) + val callback = new FileDownloadCallback(pipe.sink(), source, client) + client.stream(parsedUri.getPath(), callback) + } catch { + case e: Exception => + pipe.sink().close() + source.close() + throw e + } + + source + } + + private def downloadClient(host: String, port: Int): TransportClient = { + if (fileDownloadFactory == null) synchronized { + if (fileDownloadFactory == null) { + val module = "files" + val prefix = "spark.rpc.io." + val clone = conf.clone() + + // Copy any RPC configuration that is not overridden in the spark.files namespace. + conf.getAll.foreach { case (key, value) => + if (key.startsWith(prefix)) { + val opt = key.substring(prefix.length()) + clone.setIfMissing(s"spark.$module.io.$opt", value) + } + } + + val ioThreads = clone.getInt("spark.files.io.threads", 1) + val downloadConf = SparkTransportConf.fromSparkConf(clone, module, ioThreads) + val downloadContext = new TransportContext(downloadConf, new NoOpRpcHandler(), true) + fileDownloadFactory = downloadContext.createClientFactory(createClientBootstraps()) + } + } + fileDownloadFactory.createClient(host, port) + } + + private class FileDownloadChannel(source: ReadableByteChannel) extends ReadableByteChannel { + + @volatile private var error: Throwable = _ + + def setError(e: Throwable): Unit = { + error = e + source.close() + } + + override def read(dst: ByteBuffer): Int = { + val result = if (error == null) { + Try(source.read(dst)) + } else { + Failure(error) + } + + result match { + case Success(bytesRead) => bytesRead + case Failure(error) => throw error + } + } + + override def close(): Unit = source.close() + + override def isOpen(): Boolean = source.isOpen() + + } + + private class FileDownloadCallback( + sink: WritableByteChannel, + source: FileDownloadChannel, + client: TransportClient) extends StreamCallback { + + override def onData(streamId: String, buf: ByteBuffer): Unit = { + while (buf.remaining() > 0) { + sink.write(buf) + } + } + + override def onComplete(streamId: String): Unit = { + sink.close() + } + + override def onFailure(streamId: String, cause: Throwable): Unit = { + logError(s"Error downloading stream $streamId.", cause) + source.setError(cause) + sink.close() + } + + } + +} + +private[netty] object NettyRpcEnv extends Logging { + + /** + * When deserializing the [[NettyRpcEndpointRef]], it needs a reference to [[NettyRpcEnv]]. + * Use `currentEnv` to wrap the deserialization codes. E.g., + * + * {{{ + * NettyRpcEnv.currentEnv.withValue(this) { + * your deserialization codes + * } + * }}} + */ + private[netty] val currentEnv = new DynamicVariable[NettyRpcEnv](null) + + /** + * Similar to `currentEnv`, this variable references the client instance associated with an + * RPC, in case it's needed to find out the remote address during deserialization. + */ + private[netty] val currentClient = new DynamicVariable[TransportClient](null) + +} + +private[netty] class NettyRpcEnvFactory extends RpcEnvFactory with Logging { + + def create(config: RpcEnvConfig): RpcEnv = { + val sparkConf = config.conf + // Use JavaSerializerInstance in multiple threads is safe. 
However, if we plan to support + // KryoSerializer in future, we have to use ThreadLocal to store SerializerInstance + val javaSerializerInstance = + new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance] + val nettyEnv = + new NettyRpcEnv(sparkConf, javaSerializerInstance, config.host, config.securityManager) + if (!config.clientMode) { + val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort => + nettyEnv.startServer(actualPort) + (nettyEnv, nettyEnv.address.port) + } + try { + Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, config.name)._1 + } catch { + case NonFatal(e) => + nettyEnv.shutdown() + throw e + } + } + nettyEnv + } +} + +/** + * The NettyRpcEnv version of RpcEndpointRef. + * + * This class behaves differently depending on where it's created. On the node that "owns" the + * RpcEndpoint, it's a simple wrapper around the RpcEndpointAddress instance. + * + * On other machines that receive a serialized version of the reference, the behavior changes. The + * instance will keep track of the TransportClient that sent the reference, so that messages + * to the endpoint are sent over the client connection, instead of needing a new connection to + * be opened. + * + * The RpcAddress of this ref can be null; what that means is that the ref can only be used through + * a client connection, since the process hosting the endpoint is not listening for incoming + * connections. These refs should not be shared with 3rd parties, since they will not be able to + * send messages to the endpoint. + * + * @param conf Spark configuration. + * @param endpointAddress The address where the endpoint is listening. + * @param nettyEnv The RpcEnv associated with this ref. + */ +private[netty] class NettyRpcEndpointRef( + @transient private val conf: SparkConf, + endpointAddress: RpcEndpointAddress, + @transient @volatile private var nettyEnv: NettyRpcEnv) + extends RpcEndpointRef(conf) with Serializable with Logging { + + @transient @volatile var client: TransportClient = _ + + private val _address = if (endpointAddress.rpcAddress != null) endpointAddress else null + private val _name = endpointAddress.name + + override def address: RpcAddress = if (_address != null) _address.rpcAddress else null + + private def readObject(in: ObjectInputStream): Unit = { + in.defaultReadObject() + nettyEnv = NettyRpcEnv.currentEnv.value + client = NettyRpcEnv.currentClient.value + } + + private def writeObject(out: ObjectOutputStream): Unit = { + out.defaultWriteObject() + } + + override def name: String = _name + + override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = { + nettyEnv.ask(RequestMessage(nettyEnv.address, this, message), timeout) + } + + override def send(message: Any): Unit = { + require(message != null, "Message is null") + nettyEnv.send(RequestMessage(nettyEnv.address, this, message)) + } + + override def toString: String = s"NettyRpcEndpointRef(${_address})" + + def toURI: URI = new URI(_address.toString) + + final override def equals(that: Any): Boolean = that match { + case other: NettyRpcEndpointRef => _address == other._address + case _ => false + } + + final override def hashCode(): Int = if (_address == null) 0 else _address.hashCode() +} + +/** + * The message that is sent from the sender to the receiver. + */ +private[netty] case class RequestMessage( + senderAddress: RpcAddress, receiver: NettyRpcEndpointRef, content: Any) + +/** + * A response that indicates some failure happens in the receiver side. 
+ */
+private[netty] case class RpcFailure(e: Throwable)
+
+/**
+ * Dispatches incoming RPCs to registered endpoints.
+ *
+ * The handler keeps track of all client instances that communicate with it, so that the RpcEnv
+ * knows which `TransportClient` instance to use when sending RPCs to a client endpoint (i.e.,
+ * one that is not listening for incoming connections, but rather needs to be contacted via the
+ * client socket).
+ *
+ * Events are sent on a per-connection basis, so if a client opens multiple connections to the
+ * RpcEnv, multiple connection / disconnection events will be created for that client (albeit
+ * with different `RpcAddress` information).
+ */
+private[netty] class NettyRpcHandler(
+    dispatcher: Dispatcher,
+    nettyEnv: NettyRpcEnv,
+    streamManager: StreamManager) extends RpcHandler with Logging {
+
+  // TODO: Can we add connection callback (channel registered) to the underlying framework?
+  // A variable to track whether we should dispatch the RemoteProcessConnected message.
+  private val clients = new ConcurrentHashMap[TransportClient, JBoolean]()
+
+  // A variable to track the remote RpcEnv addresses of all clients
+  private val remoteAddresses = new ConcurrentHashMap[RpcAddress, RpcAddress]()
+
+  override def receive(
+      client: TransportClient,
+      message: ByteBuffer,
+      callback: RpcResponseCallback): Unit = {
+    val messageToDispatch = internalReceive(client, message)
+    dispatcher.postRemoteMessage(messageToDispatch, callback)
+  }
+
+  override def receive(
+      client: TransportClient,
+      message: ByteBuffer): Unit = {
+    val messageToDispatch = internalReceive(client, message)
+    dispatcher.postOneWayMessage(messageToDispatch)
+  }
+
+  private def internalReceive(client: TransportClient, message: ByteBuffer): RequestMessage = {
+    val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
+    assert(addr != null)
+    val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
+    if (clients.putIfAbsent(client, JBoolean.TRUE) == null) {
+      dispatcher.postToAll(RemoteProcessConnected(clientAddr))
+    }
+    val requestMessage = nettyEnv.deserialize[RequestMessage](client, message)
+    if (requestMessage.senderAddress == null) {
+      // Create a new message with the socket address of the client as the sender.
+      RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content)
+    } else {
+      // The remote RpcEnv listens on some port, so we should also fire a RemoteProcessConnected
+      // for that listening address
+      val remoteEnvAddress = requestMessage.senderAddress
+      if (remoteAddresses.putIfAbsent(clientAddr, remoteEnvAddress) == null) {
+        dispatcher.postToAll(RemoteProcessConnected(remoteEnvAddress))
+      }
+      requestMessage
+    }
+  }
+
+  override def getStreamManager: StreamManager = streamManager
+
+  override def exceptionCaught(cause: Throwable, client: TransportClient): Unit = {
+    val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
+    if (addr != null) {
+      val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
+      dispatcher.postToAll(RemoteProcessConnectionError(cause, clientAddr))
+      // If the remote RpcEnv listens on some address, we should also fire a
+      // RemoteProcessConnectionError for the remote RpcEnv's listening address
+      val remoteEnvAddress = remoteAddresses.get(clientAddr)
+      if (remoteEnvAddress != null) {
+        dispatcher.postToAll(RemoteProcessConnectionError(cause, remoteEnvAddress))
+      }
+    } else {
+      // If the channel is closed before connecting, its remoteAddress will be null.
+ // See java.net.Socket.getRemoteSocketAddress + // Because we cannot get a RpcAddress, just log it + logError("Exception before connecting to the client", cause) + } + } + + override def connectionTerminated(client: TransportClient): Unit = { + val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] + if (addr != null) { + clients.remove(client) + val clientAddr = RpcAddress(addr.getHostName, addr.getPort) + nettyEnv.removeOutbox(clientAddr) + dispatcher.postToAll(RemoteProcessDisconnected(clientAddr)) + val remoteEnvAddress = remoteAddresses.remove(clientAddr) + // If the remove RpcEnv listens to some address, we should also fire a + // RemoteProcessDisconnected for the remote RpcEnv listening address + if (remoteEnvAddress != null) { + dispatcher.postToAll(RemoteProcessDisconnected(remoteEnvAddress)) + } + } else { + // If the channel is closed before connecting, its remoteAddress will be null. In this case, + // we can ignore it since we don't fire "Associated". + // See java.net.Socket.getRemoteSocketAddress + } + } +} diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala new file mode 100644 index 000000000000..394cde4fa076 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.rpc.netty + +import java.io.File +import java.util.concurrent.ConcurrentHashMap + +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.server.StreamManager +import org.apache.spark.rpc.RpcEnvFileServer +import org.apache.spark.util.Utils + +/** + * StreamManager implementation for serving files from a NettyRpcEnv. + * + * Three kinds of resources can be registered in this manager, all backed by actual files: + * + * - "/files": a flat list of files; used as the backend for [[SparkContext.addFile]]. + * - "/jars": a flat list of files; used as the backend for [[SparkContext.addJar]]. + * - arbitrary directories; all files under the directory become available through the manager, + * respecting the directory's hierarchy. + * + * Only streaming (openStream) is supported. 
+ */ +private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv) + extends StreamManager with RpcEnvFileServer { + + private val files = new ConcurrentHashMap[String, File]() + private val jars = new ConcurrentHashMap[String, File]() + private val dirs = new ConcurrentHashMap[String, File]() + + override def getChunk(streamId: Long, chunkIndex: Int): ManagedBuffer = { + throw new UnsupportedOperationException() + } + + override def openStream(streamId: String): ManagedBuffer = { + val Array(ftype, fname) = streamId.stripPrefix("/").split("/", 2) + val file = ftype match { + case "files" => files.get(fname) + case "jars" => jars.get(fname) + case other => + val dir = dirs.get(ftype) + require(dir != null, s"Invalid stream URI: $ftype not found.") + new File(dir, fname) + } + + require(file != null && file.isFile(), s"File not found: $streamId") + new FileSegmentManagedBuffer(rpcEnv.transportConf, file, 0, file.length()) + } + + override def addFile(file: File): String = { + require(files.putIfAbsent(file.getName(), file) == null, + s"File ${file.getName()} already registered.") + s"${rpcEnv.address.toSparkURL}/files/${Utils.encodeFileNameToURIRawPath(file.getName())}" + } + + override def addJar(file: File): String = { + require(jars.putIfAbsent(file.getName(), file) == null, + s"JAR ${file.getName()} already registered.") + s"${rpcEnv.address.toSparkURL}/jars/${Utils.encodeFileNameToURIRawPath(file.getName())}" + } + + override def addDirectory(baseUri: String, path: File): String = { + val fixedBaseUri = validateDirectoryUri(baseUri) + require(dirs.putIfAbsent(fixedBaseUri.stripPrefix("/"), path) == null, + s"URI '$fixedBaseUri' already registered.") + s"${rpcEnv.address.toSparkURL}$fixedBaseUri" + } + +} diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala new file mode 100644 index 000000000000..2316ebe347bb --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala @@ -0,0 +1,271 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.rpc.netty + +import java.nio.ByteBuffer +import java.util.concurrent.Callable +import javax.annotation.concurrent.GuardedBy + +import scala.util.control.NonFatal + +import org.apache.spark.{Logging, SparkException} +import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} +import org.apache.spark.rpc.RpcAddress + +private[netty] sealed trait OutboxMessage { + + def sendWith(client: TransportClient): Unit + + def onFailure(e: Throwable): Unit + +} + +private[netty] case class OneWayOutboxMessage(content: ByteBuffer) extends OutboxMessage + with Logging { + + override def sendWith(client: TransportClient): Unit = { + client.send(content) + } + + override def onFailure(e: Throwable): Unit = { + logWarning(s"Failed to send one-way RPC.", e) + } + +} + +private[netty] case class RpcOutboxMessage( + content: ByteBuffer, + _onFailure: (Throwable) => Unit, + _onSuccess: (TransportClient, ByteBuffer) => Unit) + extends OutboxMessage with RpcResponseCallback { + + private var client: TransportClient = _ + private var requestId: Long = _ + + override def sendWith(client: TransportClient): Unit = { + this.client = client + this.requestId = client.sendRpc(content, this) + } + + def onTimeout(): Unit = { + require(client != null, "TransportClient has not yet been set.") + client.removeRpcRequest(requestId) + } + + override def onFailure(e: Throwable): Unit = { + _onFailure(e) + } + + override def onSuccess(response: ByteBuffer): Unit = { + _onSuccess(client, response) + } + +} + +private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) { + + outbox => // Give this an alias so we can use it more clearly in closures. + + @GuardedBy("this") + private val messages = new java.util.LinkedList[OutboxMessage] + + @GuardedBy("this") + private var client: TransportClient = null + + /** + * connectFuture points to the connect task. If there is no connect task, connectFuture will be + * null. + */ + @GuardedBy("this") + private var connectFuture: java.util.concurrent.Future[Unit] = null + + @GuardedBy("this") + private var stopped = false + + /** + * If there is any thread draining the message queue + */ + @GuardedBy("this") + private var draining = false + + /** + * Send a message. If there is no active connection, cache it and launch a new connection. If + * [[Outbox]] is stopped, the sender will be notified with a [[SparkException]]. + */ + def send(message: OutboxMessage): Unit = { + val dropped = synchronized { + if (stopped) { + true + } else { + messages.add(message) + false + } + } + if (dropped) { + message.onFailure(new SparkException("Message is dropped because Outbox is stopped")) + } else { + drainOutbox() + } + } + + /** + * Drain the message queue. If there is other draining thread, just exit. If the connection has + * not been established, launch a task in the `nettyEnv.clientConnectionExecutor` to setup the + * connection. + */ + private def drainOutbox(): Unit = { + var message: OutboxMessage = null + synchronized { + if (stopped) { + return + } + if (connectFuture != null) { + // We are connecting to the remote address, so just exit + return + } + if (client == null) { + // There is no connect task but client is null, so we need to launch the connect task. 
+ launchConnectTask() + return + } + if (draining) { + // There is some thread draining, so just exit + return + } + message = messages.poll() + if (message == null) { + return + } + draining = true + } + while (true) { + try { + val _client = synchronized { client } + if (_client != null) { + message.sendWith(_client) + } else { + assert(stopped == true) + } + } catch { + case NonFatal(e) => + handleNetworkFailure(e) + return + } + synchronized { + if (stopped) { + return + } + message = messages.poll() + if (message == null) { + draining = false + return + } + } + } + } + + private def launchConnectTask(): Unit = { + connectFuture = nettyEnv.clientConnectionExecutor.submit(new Callable[Unit] { + + override def call(): Unit = { + try { + val _client = nettyEnv.createClient(address) + outbox.synchronized { + client = _client + if (stopped) { + closeClient() + } + } + } catch { + case ie: InterruptedException => + // exit + return + case NonFatal(e) => + outbox.synchronized { connectFuture = null } + handleNetworkFailure(e) + return + } + outbox.synchronized { connectFuture = null } + // It's possible that no thread is draining now. If we don't drain here, we cannot send the + // messages until the next message arrives. + drainOutbox() + } + }) + } + + /** + * Stop [[Inbox]] and notify the waiting messages with the cause. + */ + private def handleNetworkFailure(e: Throwable): Unit = { + synchronized { + assert(connectFuture == null) + if (stopped) { + return + } + stopped = true + closeClient() + } + // Remove this Outbox from nettyEnv so that the further messages will create a new Outbox along + // with a new connection + nettyEnv.removeOutbox(address) + + // Notify the connection failure for the remaining messages + // + // We always check `stopped` before updating messages, so here we can make sure no thread will + // update messages and it's safe to just drain the queue. + var message = messages.poll() + while (message != null) { + message.onFailure(e) + message = messages.poll() + } + assert(messages.isEmpty) + } + + private def closeClient(): Unit = synchronized { + // Not sure if `client.close` is idempotent. Just for safety. + if (client != null) { + client.close() + } + client = null + } + + /** + * Stop [[Outbox]]. The remaining messages in the [[Outbox]] will be notified with a + * [[SparkException]]. + */ + def stop(): Unit = { + synchronized { + if (stopped) { + return + } + stopped = true + if (connectFuture != null) { + connectFuture.cancel(true) + } + closeClient() + } + + // We always check `stopped` before updating messages, so here we can make sure no thread will + // update messages and it's safe to just drain the queue. + var message = messages.poll() + while (message != null) { + message.onFailure(new SparkException("Message is dropped because Outbox is stopped")) + message = messages.poll() + } + } +} diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala new file mode 100644 index 000000000000..d2e94f943aba --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc.netty + +import org.apache.spark.SparkException +import org.apache.spark.rpc.RpcAddress + +/** + * An address identifier for an RPC endpoint. + * + * The `rpcAddress` may be null, in which case the endpoint is registered via a client-only + * connection and can only be reached via the client that sent the endpoint reference. + * + * @param rpcAddress The socket address of the endpint. + * @param name Name of the endpoint. + */ +private[netty] case class RpcEndpointAddress(val rpcAddress: RpcAddress, val name: String) { + + require(name != null, "RpcEndpoint name must be provided.") + + def this(host: String, port: Int, name: String) = { + this(RpcAddress(host, port), name) + } + + override val toString = if (rpcAddress != null) { + s"spark://$name@${rpcAddress.host}:${rpcAddress.port}" + } else { + s"spark-client://$name" + } +} + +private[netty] object RpcEndpointAddress { + + def apply(sparkUrl: String): RpcEndpointAddress = { + try { + val uri = new java.net.URI(sparkUrl) + val host = uri.getHost + val port = uri.getPort + val name = uri.getUserInfo + if (uri.getScheme != "spark" || + host == null || + port < 0 || + name == null || + (uri.getPath != null && !uri.getPath.isEmpty) || // uri.getPath returns "" instead of null + uri.getFragment != null || + uri.getQuery != null) { + throw new SparkException("Invalid Spark URL: " + sparkUrl) + } + new RpcEndpointAddress(host, port, name) + } catch { + case e: java.net.URISyntaxException => + throw new SparkException("Invalid Spark URL: " + sparkUrl, e) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointVerifier.scala similarity index 57% rename from core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala rename to core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointVerifier.scala index cf362f846473..99f20da2d66a 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointVerifier.scala @@ -15,29 +15,26 @@ * limitations under the License. */ -package org.apache.spark.executor +package org.apache.spark.rpc.netty -import org.apache.spark.rpc.{RpcEnv, RpcCallContext, RpcEndpoint} -import org.apache.spark.util.Utils +import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv} /** - * Driver -> Executor message to trigger a thread dump. - */ -private[spark] case object TriggerThreadDump - -/** - * [[RpcEndpoint]] that runs inside of executors to enable driver -> executor RPC. + * An [[RpcEndpoint]] for remote [[RpcEnv]]s to query if an [[RpcEndpoint]] exists. + * + * This is used when setting up a remote endpoint reference. 
*/ -private[spark] -class ExecutorEndpoint(override val rpcEnv: RpcEnv, executorId: String) extends RpcEndpoint { +private[netty] class RpcEndpointVerifier(override val rpcEnv: RpcEnv, dispatcher: Dispatcher) + extends RpcEndpoint { override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case TriggerThreadDump => - context.reply(Utils.getThreadDump()) + case RpcEndpointVerifier.CheckExistence(name) => context.reply(dispatcher.verify(name)) } - } -object ExecutorEndpoint { - val EXECUTOR_ENDPOINT_NAME = "ExecutorEndpoint" +private[netty] object RpcEndpointVerifier { + val NAME = "endpoint-verifier" + + /** A message used to ask the remote [[RpcEndpointVerifier]] if an [[RpcEndpoint]] exists. */ + case class CheckExistence(name: String) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala index e0edd7d4ae96..146cfb9ba803 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala @@ -24,26 +24,42 @@ import org.apache.spark.annotation.DeveloperApi * Information about an [[org.apache.spark.Accumulable]] modified during a task or stage. */ @DeveloperApi -class AccumulableInfo ( +class AccumulableInfo private[spark] ( val id: Long, val name: String, val update: Option[String], // represents a partial update within a task - val value: String) { + val value: String, + val internal: Boolean) { override def equals(other: Any): Boolean = other match { case acc: AccumulableInfo => this.id == acc.id && this.name == acc.name && - this.update == acc.update && this.value == acc.value + this.update == acc.update && this.value == acc.value && + this.internal == acc.internal case _ => false } + + override def hashCode(): Int = { + val state = Seq(id, name, update, value, internal) + state.map(_.hashCode).reduceLeft(31 * _ + _) + } } object AccumulableInfo { + def apply( + id: Long, + name: String, + update: Option[String], + value: String, + internal: Boolean): AccumulableInfo = { + new AccumulableInfo(id, name, update, value, internal) + } + def apply(id: Long, name: String, update: Option[String], value: String): AccumulableInfo = { - new AccumulableInfo(id, name, update, value) + new AccumulableInfo(id, name, update, value, internal = false) } def apply(id: Long, name: String, value: String): AccumulableInfo = { - new AccumulableInfo(id, name, None, value) + new AccumulableInfo(id, name, None, value, internal = false) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala index 50a69379412d..a3d2db31301b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala @@ -23,18 +23,42 @@ import org.apache.spark.TaskContext import org.apache.spark.util.CallSite /** - * Tracks information about an active job in the DAGScheduler. + * A running job in the DAGScheduler. Jobs can be of two types: a result job, which computes a + * ResultStage to execute an action, or a map-stage job, which computes the map outputs for a + * ShuffleMapStage before any downstream stages are submitted. The latter is used for adaptive + * query planning, to look at map output statistics before submitting later stages. We distinguish + * between these two types of jobs using the finalStage field of this class. 
+ * + * Jobs are only tracked for "leaf" stages that clients directly submitted, through DAGScheduler's + * submitJob or submitMapStage methods. However, either type of job may cause the execution of + * other earlier stages (for RDDs in the DAG it depends on), and multiple jobs may share some of + * these previous stages. These dependencies are managed inside DAGScheduler. + * + * @param jobId A unique ID for this job. + * @param finalStage The stage that this job computes (either a ResultStage for an action or a + * ShuffleMapStage for submitMapStage). + * @param callSite Where this job was initiated in the user's program (shown on UI). + * @param listener A listener to notify if tasks in this job finish or the job fails. + * @param properties Scheduling properties attached to the job, such as fair scheduler pool name. */ private[spark] class ActiveJob( val jobId: Int, - val finalStage: ResultStage, - val func: (TaskContext, Iterator[_]) => _, - val partitions: Array[Int], + val finalStage: Stage, val callSite: CallSite, val listener: JobListener, val properties: Properties) { - val numPartitions = partitions.length + /** + * Number of partitions we need to compute for this job. Note that result stages may not need + * to compute all partitions in their target RDD, for actions like first() and lookup(). + */ + val numPartitions = finalStage match { + case r: ResultStage => r.partitions.length + case m: ShuffleMapStage => m.rdd.partitions.length + } + + /** Which partitions of the stage have finished */ val finished = Array.fill[Boolean](numPartitions)(false) + var numFinished = 0 } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index c4fa277c2125..b128ed50cad5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -23,7 +23,8 @@ import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicInteger import scala.collection.Map -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Stack} +import scala.collection.mutable.{HashMap, HashSet, Stack} +import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.existentials import scala.language.postfixOps @@ -34,6 +35,7 @@ import org.apache.commons.lang3.SerializationUtils import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.TaskMetrics +import org.apache.spark.network.util.JavaUtils import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD import org.apache.spark.rpc.RpcTimeout @@ -45,17 +47,65 @@ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat * The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of * stages for each job, keeps track of which RDDs and stage outputs are materialized, and finds a * minimal schedule to run the job. It then submits stages as TaskSets to an underlying - * TaskScheduler implementation that runs them on the cluster. + * TaskScheduler implementation that runs them on the cluster. A TaskSet contains fully independent + * tasks that can run right away based on the data that's already on the cluster (e.g. map output + * files from previous stages), though it may fail if this data becomes unavailable. 
* - * In addition to coming up with a DAG of stages, this class also determines the preferred + * Spark stages are created by breaking the RDD graph at shuffle boundaries. RDD operations with + * "narrow" dependencies, like map() and filter(), are pipelined together into one set of tasks + * in each stage, but operations with shuffle dependencies require multiple stages (one to write a + * set of map output files, and another to read those files after a barrier). In the end, every + * stage will have only shuffle dependencies on other stages, and may compute multiple operations + * inside it. The actual pipelining of these operations happens in the RDD.compute() functions of + * various RDDs (MappedRDD, FilteredRDD, etc). + * + * In addition to coming up with a DAG of stages, the DAGScheduler also determines the preferred * locations to run each task on, based on the current cache status, and passes these to the * low-level TaskScheduler. Furthermore, it handles failures due to shuffle output files being * lost, in which case old stages may need to be resubmitted. Failures *within* a stage that are * not caused by shuffle file loss are handled by the TaskScheduler, which will retry each task * a small number of times before cancelling the whole stage. * + * When looking through this code, there are several key concepts: + * + * - Jobs (represented by [[ActiveJob]]) are the top-level work items submitted to the scheduler. + * For example, when the user calls an action, like count(), a job will be submitted through + * submitJob. Each Job may require the execution of multiple stages to build intermediate data. + * + * - Stages ([[Stage]]) are sets of tasks that compute intermediate results in jobs, where each + * task computes the same function on partitions of the same RDD. Stages are separated at shuffle + * boundaries, which introduce a barrier (where we must wait for the previous stage to finish to + * fetch outputs). There are two types of stages: [[ResultStage]], for the final stage that + * executes an action, and [[ShuffleMapStage]], which writes map output files for a shuffle. + * Stages are often shared across multiple jobs, if these jobs reuse the same RDDs. + * + * - Tasks are individual units of work, each sent to one machine. + * + * - Cache tracking: the DAGScheduler figures out which RDDs are cached to avoid recomputing them + * and likewise remembers which shuffle map stages have already produced output files to avoid + * redoing the map side of a shuffle. + * + * - Preferred locations: the DAGScheduler also computes where to run each task in a stage based + * on the preferred locations of its underlying RDDs, or the location of cached or shuffle data. + * + * - Cleanup: all data structures are cleared when the running jobs that depend on them finish, + * to prevent memory leaks in a long-running application. + * + * To recover from failures, the same stage might need to run multiple times, which are called + * "attempts". If the TaskScheduler reports that a task failed because a map output file from a + * previous stage was lost, the DAGScheduler resubmits that lost stage. This is detected through a + * CompletionEvent with FetchFailed, or an ExecutorLost event. The DAGScheduler will wait a small + * amount of time to see whether other nodes or tasks fail, then resubmit TaskSets for any lost + * stage(s) that compute the missing tasks. 
As part of this process, we might also have to create + * Stage objects for old (finished) stages where we previously cleaned up the Stage object. Since + * tasks from the old attempt of a stage could still be running, care must be taken to map any + * events received in the correct Stage object. + * * Here's a checklist to use when making or reviewing changes to this class: * + * - All data structures should be cleared when the jobs involving them end to avoid indefinite + * accumulation of state in long-running programs. + * * - When adding a new data structure, update `DAGSchedulerSuite.assertDataStructuresEmpty` to * include the new structure. This will help to catch memory leaks. */ @@ -82,7 +132,7 @@ class DAGScheduler( def this(sc: SparkContext) = this(sc, sc.taskScheduler) - private[scheduler] val metricsSource: DAGSchedulerSource = new DAGSchedulerSource(this) + private[spark] val metricsSource: DAGSchedulerSource = new DAGSchedulerSource(this) private[scheduler] val nextJobId = new AtomicInteger(0) private[scheduler] def numTotalJobs: Int = nextJobId.get() @@ -111,7 +161,7 @@ class DAGScheduler( * * All accesses to this map should be guarded by synchronizing on it (see SPARK-4454). */ - private val cacheLocs = new HashMap[Int, Seq[Seq[TaskLocation]]] + private val cacheLocs = new HashMap[Int, IndexedSeq[Seq[TaskLocation]]] // For tracking failed nodes, we use the MapOutputTracker's epoch number, which is sent with // every task. When we detect a node failing, we note the current epoch number and failed @@ -136,33 +186,24 @@ class DAGScheduler( private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this) taskScheduler.setDAGScheduler(this) - // Flag to control if reduce tasks are assigned preferred locations - private val shuffleLocalityEnabled = - sc.getConf.getBoolean("spark.shuffle.reduceLocality.enabled", true) - // Number of map, reduce tasks above which we do not assign preferred locations - // based on map output sizes. We limit the size of jobs for which assign preferred locations - // as computing the top locations by size becomes expensive. - private[this] val SHUFFLE_PREF_MAP_THRESHOLD = 1000 - // NOTE: This should be less than 2000 as we use HighlyCompressedMapStatus beyond that - private[this] val SHUFFLE_PREF_REDUCE_THRESHOLD = 1000 - - // Fraction of total map output that must be at a location for it to considered as a preferred - // location for a reduce task. - // Making this larger will focus on fewer locations where most data can be read locally, but - // may lead to more delay in scheduling if those locations are busy. - private[scheduler] val REDUCER_PREF_LOCS_FRACTION = 0.2 - - // Called by TaskScheduler to report task's starting. + /** + * Called by the TaskSetManager to report task's starting. + */ def taskStarted(task: Task[_], taskInfo: TaskInfo) { eventProcessLoop.post(BeginEvent(task, taskInfo)) } - // Called to report that a task has completed and results are being fetched remotely. + /** + * Called by the TaskSetManager to report that a task has completed + * and results are being fetched remotely. + */ def taskGettingResult(taskInfo: TaskInfo) { eventProcessLoop.post(GettingResultEvent(taskInfo)) } - // Called by TaskScheduler to report task completions or failures. + /** + * Called by the TaskSetManager to report task completions or failures. 
+ */ def taskEnded( task: Task[_], reason: TaskEndReason, @@ -188,29 +229,35 @@ class DAGScheduler( BlockManagerHeartbeat(blockManagerId), new RpcTimeout(600 seconds, "BlockManagerHeartbeat")) } - // Called by TaskScheduler when an executor fails. + /** + * Called by TaskScheduler implementation when an executor fails. + */ def executorLost(execId: String): Unit = { eventProcessLoop.post(ExecutorLost(execId)) } - // Called by TaskScheduler when a host is added + /** + * Called by TaskScheduler implementation when a host is added. + */ def executorAdded(execId: String, host: String): Unit = { eventProcessLoop.post(ExecutorAdded(execId, host)) } - // Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or - // cancellation of the job itself. - def taskSetFailed(taskSet: TaskSet, reason: String): Unit = { - eventProcessLoop.post(TaskSetFailed(taskSet, reason)) + /** + * Called by the TaskSetManager to cancel an entire TaskSet due to either repeated failures or + * cancellation of the job itself. + */ + def taskSetFailed(taskSet: TaskSet, reason: String, exception: Option[Throwable]): Unit = { + eventProcessLoop.post(TaskSetFailed(taskSet, reason, exception)) } private[scheduler] - def getCacheLocs(rdd: RDD[_]): Seq[Seq[TaskLocation]] = cacheLocs.synchronized { + def getCacheLocs(rdd: RDD[_]): IndexedSeq[Seq[TaskLocation]] = cacheLocs.synchronized { // Note: this doesn't use `getOrElse()` because this method is called O(num tasks) times if (!cacheLocs.contains(rdd.id)) { // Note: if the storage level is NONE, we don't need to get locations from block manager. - val locs: Seq[Seq[TaskLocation]] = if (rdd.getStorageLevel == StorageLevel.NONE) { - Seq.fill(rdd.partitions.size)(Nil) + val locs: IndexedSeq[Seq[TaskLocation]] = if (rdd.getStorageLevel == StorageLevel.NONE) { + IndexedSeq.fill(rdd.partitions.length)(Nil) } else { val blockIds = rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).toArray[BlockId] @@ -237,11 +284,12 @@ class DAGScheduler( case Some(stage) => stage case None => // We are going to register ancestor shuffle dependencies - registerShuffleDependencies(shuffleDep, firstJobId) + getAncestorShuffleDependencies(shuffleDep.rdd).foreach { dep => + shuffleToMapStage(dep.shuffleId) = newOrUsedShuffleStage(dep, firstJobId) + } // Then register current shuffleDep val stage = newOrUsedShuffleStage(shuffleDep, firstJobId) shuffleToMapStage(shuffleDep.shuffleId) = stage - stage } } @@ -281,12 +329,12 @@ class DAGScheduler( */ private def newResultStage( rdd: RDD[_], - numTasks: Int, + func: (TaskContext, Iterator[_]) => _, + partitions: Array[Int], jobId: Int, callSite: CallSite): ResultStage = { val (parentStages: List[Stage], id: Int) = getParentStagesAndId(rdd, jobId) - val stage: ResultStage = new ResultStage(id, rdd, numTasks, parentStages, jobId, callSite) - + val stage = new ResultStage(id, rdd, func, partitions, parentStages, jobId, callSite) stageIdToStage(id) = stage updateJobIdStageIdMaps(jobId, stage) stage @@ -302,20 +350,22 @@ class DAGScheduler( shuffleDep: ShuffleDependency[_, _, _], firstJobId: Int): ShuffleMapStage = { val rdd = shuffleDep.rdd - val numTasks = rdd.partitions.size + val numTasks = rdd.partitions.length val stage = newShuffleMapStage(rdd, numTasks, shuffleDep, firstJobId, rdd.creationSite) if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) { val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId) val locs = MapOutputTracker.deserializeMapStatuses(serLocs) - for (i <- 0 until 
locs.size) { - stage.outputLocs(i) = Option(locs(i)).toList // locs(i) will be null if missing + (0 until locs.length).foreach { i => + if (locs(i) ne null) { + // locs(i) will be null if missing + stage.addOutputLoc(i, locs(i)) + } } - stage.numAvailableOutputs = locs.count(_ != null) } else { // Kind of ugly: need to register RDDs with the cache and map output tracker here // since we can't do it in the RDD constructor because # of partitions is unknown logInfo("Registering RDD " + rdd.id + " (" + rdd.getCreationSite + ")") - mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.size) + mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.length) } stage } @@ -352,16 +402,6 @@ class DAGScheduler( parents.toList } - /** Find ancestor missing shuffle dependencies and register into shuffleToMapStage */ - private def registerShuffleDependencies(shuffleDep: ShuffleDependency[_, _, _], firstJobId: Int) { - val parentsWithNoMapStage = getAncestorShuffleDependencies(shuffleDep.rdd) - while (parentsWithNoMapStage.nonEmpty) { - val currentShufDep = parentsWithNoMapStage.pop() - val stage = newOrUsedShuffleStage(currentShufDep, firstJobId) - shuffleToMapStage(currentShufDep.shuffleId) = stage - } - } - /** Find ancestor shuffle dependencies that are not registered in shuffleToMapStage yet */ private def getAncestorShuffleDependencies(rdd: RDD[_]): Stack[ShuffleDependency[_, _, _]] = { val parents = new Stack[ShuffleDependency[_, _, _]] @@ -378,11 +418,9 @@ class DAGScheduler( if (!shuffleToMapStage.contains(shufDep.shuffleId)) { parents.push(shufDep) } - - waitingForVisit.push(shufDep.rdd) case _ => - waitingForVisit.push(dep.rdd) } + waitingForVisit.push(dep.rdd) } } } @@ -498,12 +536,27 @@ class DAGScheduler( jobIdToStageIds -= job.jobId jobIdToActiveJob -= job.jobId activeJobs -= job - job.finalStage.resultOfJob = None + job.finalStage match { + case r: ResultStage => r.removeActiveJob() + case m: ShuffleMapStage => m.removeActiveJob(job) + } } /** - * Submit a job to the job scheduler and get a JobWaiter object back. The JobWaiter object - * can be used to block until the the job finishes executing or can be used to cancel the job. + * Submit an action job to the scheduler. + * + * @param rdd target RDD to run tasks on + * @param func a function to run on each partition of the RDD + * @param partitions set of partitions to run on; some jobs may not want to compute on all + * partitions of the target RDD, e.g. for operations like first() + * @param callSite where in the user program this job was called + * @param resultHandler callback to pass each result to + * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name + * + * @return a JobWaiter object that can be used to block until the job finishes executing + * or can be used to cancel the job. + * + * @throws IllegalArgumentException when partitions ids are illegal */ def submitJob[T, U]( rdd: RDD[T], @@ -522,6 +575,7 @@ class DAGScheduler( val jobId = nextJobId.getAndIncrement() if (partitions.size == 0) { + // Return immediately if the job is running 0 tasks return new JobWaiter[U](this, jobId, 0, resultHandler) } @@ -534,6 +588,20 @@ class DAGScheduler( waiter } + /** + * Run an action job on the given RDD and pass all the results to the resultHandler function as + * they arrive. 
+ * + * @param rdd target RDD to run tasks on + * @param func a function to run on each partition of the RDD + * @param partitions set of partitions to run on; some jobs may not want to compute on all + * partitions of the target RDD, e.g. for operations like first() + * @param callSite where in the user program this job was called + * @param resultHandler callback to pass each result to + * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name + * + * @throws Exception when the job fails + */ def runJob[T, U]( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, @@ -543,11 +611,12 @@ class DAGScheduler( properties: Properties): Unit = { val start = System.nanoTime val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties) - waiter.awaitResult() match { - case JobSucceeded => + Await.ready(waiter.completionFuture, atMost = Duration.Inf) + waiter.completionFuture.value.get match { + case scala.util.Success(_) => logInfo("Job %d finished: %s, took %f s".format (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9)) - case JobFailed(exception: Exception) => + case scala.util.Failure(exception) => logInfo("Job %d failed: %s, took %f s".format (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9)) // SPARK-8644: Include user stack trace in exceptions coming from DAGScheduler. @@ -557,6 +626,17 @@ class DAGScheduler( } } + /** + * Run an approximate job on the given RDD and pass all the results to an ApproximateEvaluator + * as they arrive. Returns a partial result object from the evaluator. + * + * @param rdd target RDD to run tasks on + * @param func a function to run on each partition of the RDD + * @param evaluator [[ApproximateEvaluator]] to receive the partial results + * @param callSite where in the user program this job was called + * @param timeout maximum time to wait for the job, in milliseconds + * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name + */ def runApproximateJob[T, U, R]( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, @@ -566,13 +646,48 @@ class DAGScheduler( properties: Properties): PartialResult[R] = { val listener = new ApproximateActionListener(rdd, func, evaluator, timeout) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] - val partitions = (0 until rdd.partitions.size).toArray + val partitions = (0 until rdd.partitions.length).toArray val jobId = nextJobId.getAndIncrement() eventProcessLoop.post(JobSubmitted( jobId, rdd, func2, partitions, callSite, listener, SerializationUtils.clone(properties))) listener.awaitResult() // Will throw an exception if the job fails } + /** + * Submit a shuffle map stage to run independently and get a JobWaiter object back. The waiter + * can be used to block until the the job finishes executing or can be used to cancel the job. + * This method is used for adaptive query planning, to run map stages and look at statistics + * about their outputs before submitting downstream stages. + * + * @param dependency the ShuffleDependency to run a map stage for + * @param callback function called with the result of the job, which in this case will be a + * single MapOutputStatistics object showing how much data was produced for each partition + * @param callSite where in the user program this job was submitted + * @param properties scheduler properties to attach to this job, e.g. 
fair scheduler pool name + */ + def submitMapStage[K, V, C]( + dependency: ShuffleDependency[K, V, C], + callback: MapOutputStatistics => Unit, + callSite: CallSite, + properties: Properties): JobWaiter[MapOutputStatistics] = { + + val rdd = dependency.rdd + val jobId = nextJobId.getAndIncrement() + if (rdd.partitions.length == 0) { + throw new SparkException("Can't run submitMapStage on RDD with 0 partitions") + } + + // We create a JobWaiter with only one "task", which will be marked as complete when the whole + // map stage has completed, and will be passed the MapOutputStatistics for that stage. + // This makes it easier to avoid race conditions between the user code and the map output + // tracker that might result if we told the user the stage had finished, but then they queries + // the map output tracker and some node failures had caused the output statistics to be lost. + val waiter = new JobWaiter(this, jobId, 1, (i: Int, r: MapOutputStatistics) => callback(r)) + eventProcessLoop.post(MapStageSubmitted( + jobId, dependency, callSite, waiter, SerializationUtils.clone(properties))) + waiter + } + /** * Cancel a job that is running or waiting in the queue. */ @@ -581,6 +696,9 @@ class DAGScheduler( eventProcessLoop.post(JobCancelled(jobId)) } + /** + * Cancel all jobs in the given job group ID. + */ def cancelJobGroup(groupId: String): Unit = { logInfo("Asked to cancel job group " + groupId) eventProcessLoop.post(JobGroupCancelled(groupId)) @@ -677,14 +795,18 @@ class DAGScheduler( submitWaitingStages() } - private[scheduler] def handleTaskSetFailed(taskSet: TaskSet, reason: String) { - stageIdToStage.get(taskSet.stageId).foreach {abortStage(_, reason) } + private[scheduler] def handleTaskSetFailed( + taskSet: TaskSet, + reason: String, + exception: Option[Throwable]): Unit = { + stageIdToStage.get(taskSet.stageId).foreach { abortStage(_, reason, exception) } submitWaitingStages() } private[scheduler] def cleanUpAfterSchedulerStop() { for (job <- activeJobs) { - val error = new SparkException("Job cancelled because SparkContext was shut down") + val error = + new SparkException(s"Job ${job.jobId} cancelled because SparkContext was shut down") job.listener.jobFailed(error) // Tell the listeners that all of the running stages have ended. Don't bother // cancelling the stages because if the DAG scheduler is stopped, the entire application @@ -715,31 +837,77 @@ class DAGScheduler( try { // New stage creation may throw an exception if, for example, jobs are run on a // HadoopRDD whose underlying HDFS files have been deleted. 
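      // Editor's note: the newResultStage call just below changes because newResultStage (see
      // its signature change earlier in this file) now takes the user function and the target
      // partitions directly, so the ResultStage carries them instead of ActiveJob.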
- finalStage = newResultStage(finalRDD, partitions.size, jobId, callSite) + finalStage = newResultStage(finalRDD, func, partitions, jobId, callSite) } catch { case e: Exception => logWarning("Creating new stage failed due to exception - job: " + jobId, e) listener.jobFailed(e) return } - if (finalStage != null) { - val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties) - clearCacheLocs() - logInfo("Got job %s (%s) with %d output partitions".format( - job.jobId, callSite.shortForm, partitions.length)) - logInfo("Final stage: " + finalStage + "(" + finalStage.name + ")") - logInfo("Parents of final stage: " + finalStage.parents) - logInfo("Missing parents: " + getMissingParentStages(finalStage)) - val jobSubmissionTime = clock.getTimeMillis() - jobIdToActiveJob(jobId) = job - activeJobs += job - finalStage.resultOfJob = Some(job) - val stageIds = jobIdToStageIds(jobId).toArray - val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) - listenerBus.post( - SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties)) - submitStage(finalStage) + + val job = new ActiveJob(jobId, finalStage, callSite, listener, properties) + clearCacheLocs() + logInfo("Got job %s (%s) with %d output partitions".format( + job.jobId, callSite.shortForm, partitions.length)) + logInfo("Final stage: " + finalStage + " (" + finalStage.name + ")") + logInfo("Parents of final stage: " + finalStage.parents) + logInfo("Missing parents: " + getMissingParentStages(finalStage)) + + val jobSubmissionTime = clock.getTimeMillis() + jobIdToActiveJob(jobId) = job + activeJobs += job + finalStage.setActiveJob(job) + val stageIds = jobIdToStageIds(jobId).toArray + val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) + listenerBus.post( + SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties)) + submitStage(finalStage) + + submitWaitingStages() + } + + private[scheduler] def handleMapStageSubmitted(jobId: Int, + dependency: ShuffleDependency[_, _, _], + callSite: CallSite, + listener: JobListener, + properties: Properties) { + // Submitting this map stage might still require the creation of some parent stages, so make + // sure that happens. + var finalStage: ShuffleMapStage = null + try { + // New stage creation may throw an exception if, for example, jobs are run on a + // HadoopRDD whose underlying HDFS files have been deleted. 
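      // Editor's note: unlike handleJobSubmitted above, the "final" stage of a map-stage job is
      // a ShuffleMapStage looked up (or created) from the shuffle dependency; as the code further
      // below shows, if all of its map outputs are already available the job is marked finished
      // immediately via markMapStageJobAsFinished.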
+ finalStage = getShuffleMapStage(dependency, jobId) + } catch { + case e: Exception => + logWarning("Creating new stage failed due to exception - job: " + jobId, e) + listener.jobFailed(e) + return } + + val job = new ActiveJob(jobId, finalStage, callSite, listener, properties) + clearCacheLocs() + logInfo("Got map stage job %s (%s) with %d output partitions".format( + jobId, callSite.shortForm, dependency.rdd.partitions.length)) + logInfo("Final stage: " + finalStage + " (" + finalStage.name + ")") + logInfo("Parents of final stage: " + finalStage.parents) + logInfo("Missing parents: " + getMissingParentStages(finalStage)) + + val jobSubmissionTime = clock.getTimeMillis() + jobIdToActiveJob(jobId) = job + activeJobs += job + finalStage.addActiveJob(job) + val stageIds = jobIdToStageIds(jobId).toArray + val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) + listenerBus.post( + SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties)) + submitStage(finalStage) + + // If the whole stage has already finished, tell the listener and remove it + if (finalStage.isAvailable) { + markMapStageJobAsFinished(job, mapOutputTracker.getStatistics(dependency)) + } + submitWaitingStages() } @@ -762,7 +930,7 @@ class DAGScheduler( } } } else { - abortStage(stage, "No active job for stage " + stage.id) + abortStage(stage, "No active job for stage " + stage.id, None) } } @@ -770,35 +938,42 @@ class DAGScheduler( private def submitMissingTasks(stage: Stage, jobId: Int) { logDebug("submitMissingTasks(" + stage + ")") // Get our pending tasks and remember them in our pendingTasks entry - stage.pendingTasks.clear() + stage.pendingPartitions.clear() // First figure out the indexes of partition ids to compute. - val partitionsToCompute: Seq[Int] = { - stage match { - case stage: ShuffleMapStage => - (0 until stage.numPartitions).filter(id => stage.outputLocs(id).isEmpty) - case stage: ResultStage => - val job = stage.resultOfJob.get - (0 until job.numPartitions).filter(id => !job.finished(id)) - } + val partitionsToCompute: Seq[Int] = stage.findMissingPartitions() + + // Create internal accumulators if the stage has no accumulators initialized. + // Reset internal accumulators only if this stage is not partially submitted + // Otherwise, we may override existing accumulator values from some tasks + if (stage.internalAccumulators.isEmpty || stage.numPartitions == partitionsToCompute.size) { + stage.resetInternalAccumulators() } - val properties = jobIdToActiveJob.get(stage.firstJobId).map(_.properties).orNull + // Use the scheduling pool, job group, description, etc. from an ActiveJob associated + // with this Stage + val properties = jobIdToActiveJob(jobId).properties runningStages += stage // SparkListenerStageSubmitted should be posted before testing whether tasks are // serializable. If tasks are not serializable, a SparkListenerStageCompleted event // will be posted, which should always come after a corresponding SparkListenerStageSubmitted // event. 
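    // Editor's sketch (not part of the patch): the per-stage-type logic removed earlier in this
    // hunk is what Stage.findMissingPartitions() is expected to encapsulate, roughly:
    //   ShuffleMapStage: (0 until numPartitions).filter(id => outputLocs(id).isEmpty)
    //   ResultStage:     (0 until job.numPartitions).filter(id => !job.finished(id))
    // The concrete implementations live in ShuffleMapStage/ResultStage, outside this hunk.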
- outputCommitCoordinator.stageStart(stage.id) - val taskIdToLocations = try { + stage match { + case s: ShuffleMapStage => + outputCommitCoordinator.stageStart(stage = s.id, maxPartitionId = s.numPartitions - 1) + case s: ResultStage => + outputCommitCoordinator.stageStart( + stage = s.id, maxPartitionId = s.rdd.partitions.length - 1) + } + val taskIdToLocations: Map[Int, Seq[TaskLocation]] = try { stage match { case s: ShuffleMapStage => partitionsToCompute.map { id => (id, getPreferredLocs(stage.rdd, id))}.toMap case s: ResultStage => - val job = s.resultOfJob.get + val job = s.activeJob.get partitionsToCompute.map { id => - val p = job.partitions(id) + val p = s.partitions(id) (id, getPreferredLocs(stage.rdd, p)) }.toMap } @@ -806,7 +981,7 @@ class DAGScheduler( case NonFatal(e) => stage.makeNewStageAttempt(partitionsToCompute.size) listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties)) - abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}") + abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}", Some(e)) runningStages -= stage return } @@ -826,22 +1001,23 @@ class DAGScheduler( // For ResultTask, serialize and broadcast (rdd, func). val taskBinaryBytes: Array[Byte] = stage match { case stage: ShuffleMapStage => - closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef).array() + JavaUtils.bufferToArray( + closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef)) case stage: ResultStage => - closureSerializer.serialize((stage.rdd, stage.resultOfJob.get.func): AnyRef).array() + JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef)) } taskBinary = sc.broadcast(taskBinaryBytes) } catch { // In the case of a failure during serialization, abort the stage. 
case e: NotSerializableException => - abortStage(stage, "Task not serializable: " + e.toString) + abortStage(stage, "Task not serializable: " + e.toString, Some(e)) runningStages -= stage // Abort execution return case NonFatal(e) => - abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}") + abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}", Some(e)) runningStages -= stage return } @@ -852,31 +1028,33 @@ class DAGScheduler( partitionsToCompute.map { id => val locs = taskIdToLocations(id) val part = stage.rdd.partitions(id) - new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs) + new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, + taskBinary, part, locs, stage.internalAccumulators) } case stage: ResultStage => - val job = stage.resultOfJob.get + val job = stage.activeJob.get partitionsToCompute.map { id => - val p: Int = job.partitions(id) + val p: Int = stage.partitions(id) val part = stage.rdd.partitions(p) val locs = taskIdToLocations(id) - new ResultTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs, id) + new ResultTask(stage.id, stage.latestInfo.attemptId, + taskBinary, part, locs, id, stage.internalAccumulators) } } } catch { case NonFatal(e) => - abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}") + abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}", Some(e)) runningStages -= stage return } if (tasks.size > 0) { logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")") - stage.pendingTasks ++= tasks - logDebug("New pending tasks: " + stage.pendingTasks) + stage.pendingPartitions ++= tasks.map(_.partitionId) + logDebug("New pending partitions: " + stage.pendingPartitions) taskScheduler.submitTasks(new TaskSet( - tasks.toArray, stage.id, stage.latestInfo.attemptId, stage.firstJobId, properties)) + tasks.toArray, stage.id, stage.latestInfo.attemptId, jobId, properties)) stage.latestInfo.submissionTime = Some(clock.getTimeMillis()) } else { // Because we posted SparkListenerStageSubmitted earlier, we should mark @@ -916,9 +1094,11 @@ class DAGScheduler( // To avoid UI cruft, ignore cases where value wasn't updated if (acc.name.isDefined && partialValue != acc.zero) { val name = acc.name.get - stage.latestInfo.accumulables(id) = AccumulableInfo(id, name, s"${acc.value}") + val value = s"${acc.value}" + stage.latestInfo.accumulables(id) = + new AccumulableInfo(id, name, None, value, acc.isInternal) event.taskInfo.accumulables += - AccumulableInfo(id, name, Some(s"$partialValue"), s"${acc.value}") + new AccumulableInfo(id, name, Some(s"$partialValue"), value, acc.isInternal) } } } catch { @@ -939,8 +1119,11 @@ class DAGScheduler( val stageId = task.stageId val taskType = Utils.getFormattedClassName(task) - outputCommitCoordinator.taskCompleted(stageId, task.partitionId, - event.taskInfo.attempt, event.reason) + outputCommitCoordinator.taskCompleted( + stageId, + task.partitionId, + event.taskInfo.attemptNumber, // this is a task attempt number + event.reason) // The success case is dealt with separately below, since we need to compute accumulator // updates before posting. 
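// Above, the stage now tracks pending work as partition ids (pendingPartitions) rather than Task
// objects, so a completion event can be matched by task.partitionId alone. A tiny standalone
// sketch of that bookkeeping, with illustrative names rather than the Spark classes:
import scala.collection.mutable

object PendingPartitionsSketch {
  final class StageProgress {
    val pendingPartitions = new mutable.HashSet[Int]

    def submit(partitions: Seq[Int]): Unit = pendingPartitions ++= partitions
    def taskFinished(partitionId: Int): Unit = pendingPartitions -= partitionId
    def allTasksDone: Boolean = pendingPartitions.isEmpty
  }

  def main(args: Array[String]): Unit = {
    val progress = new StageProgress
    progress.submit(Seq(0, 1, 2))
    progress.taskFinished(1)
    println(progress.pendingPartitions.toList.sorted) // List(0, 2)
    Seq(0, 2).foreach(progress.taskFinished)
    println(progress.allTasksDone)                    // true
  }
}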
@@ -960,13 +1143,13 @@ class DAGScheduler( case Success => listenerBus.post(SparkListenerTaskEnd(stageId, stage.latestInfo.attemptId, taskType, event.reason, event.taskInfo, event.taskMetrics)) - stage.pendingTasks -= task + stage.pendingPartitions -= task.partitionId task match { case rt: ResultTask[_, _] => // Cast to ResultStage here because it's part of the ResultTask // TODO Refactor this out to a function that accepts a ResultStage val resultStage = stage.asInstanceOf[ResultStage] - resultStage.resultOfJob match { + resultStage.activeJob match { case Some(job) => if (!job.finished(rt.outputId)) { updateAccumulators(event) @@ -1006,7 +1189,7 @@ class DAGScheduler( shuffleStage.addOutputLoc(smt.partitionId, status) } - if (runningStages.contains(shuffleStage) && shuffleStage.pendingTasks.isEmpty) { + if (runningStages.contains(shuffleStage) && shuffleStage.pendingPartitions.isEmpty) { markStageAsFinished(shuffleStage) logInfo("looking for newly runnable stages") logInfo("running: " + runningStages) @@ -1021,45 +1204,35 @@ class DAGScheduler( // we registered these map outputs. mapOutputTracker.registerMapOutputs( shuffleStage.shuffleDep.shuffleId, - shuffleStage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray, + shuffleStage.outputLocInMapOutputTrackerFormat(), changeEpoch = true) clearCacheLocs() - if (shuffleStage.outputLocs.contains(Nil)) { + + if (!shuffleStage.isAvailable) { // Some tasks had failed; let's resubmit this shuffleStage // TODO: Lower-level scheduler should also deal with this logInfo("Resubmitting " + shuffleStage + " (" + shuffleStage.name + ") because some of its tasks had failed: " + - shuffleStage.outputLocs.zipWithIndex.filter(_._1.isEmpty) - .map(_._2).mkString(", ")) + shuffleStage.findMissingPartitions().mkString(", ")) submitStage(shuffleStage) } else { - val newlyRunnable = new ArrayBuffer[Stage] - for (shuffleStage <- waitingStages) { - logInfo("Missing parents for " + shuffleStage + ": " + - getMissingParentStages(shuffleStage)) - } - for (shuffleStage <- waitingStages if getMissingParentStages(shuffleStage).isEmpty) - { - newlyRunnable += shuffleStage - } - waitingStages --= newlyRunnable - runningStages ++= newlyRunnable - for { - shuffleStage <- newlyRunnable.sortBy(_.id) - jobId <- activeJobForStage(shuffleStage) - } { - logInfo("Submitting " + shuffleStage + " (" + - shuffleStage.rdd + "), which is now runnable") - submitMissingTasks(shuffleStage, jobId) + // Mark any map-stage jobs waiting on this stage as finished + if (shuffleStage.mapStageJobs.nonEmpty) { + val stats = mapOutputTracker.getStatistics(shuffleStage.shuffleDep) + for (job <- shuffleStage.mapStageJobs) { + markMapStageJobAsFinished(job, stats) + } } } + + // Note: newly runnable stages will be submitted below when we submit waiting stages } - } + } case Resubmitted => logInfo("Resubmitted " + task + ", so marking it as still running") - stage.pendingTasks += task + stage.pendingPartitions += task.partitionId case FetchFailed(bmAddress, shuffleId, mapId, reduceId, failureMessage) => val failedStage = stageIdToStage(task.stageId) @@ -1070,7 +1243,6 @@ class DAGScheduler( s" ${task.stageAttemptId} and there is a more recent attempt for that stage " + s"(attempt ID ${failedStage.latestInfo.attemptId}) running") } else { - // It is likely that we receive multiple FetchFailed for a single stage (because we have // multiple tasks running concurrently on different executors). In that case, it is // possible the fetch failure has already been handled by the scheduler. 
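// The completion path above now registers map outputs via outputLocInMapOutputTrackerFormat(),
// i.e. one status per map partition (the head of each list) or null when a partition has none,
// and resubmits the stage when any partition is still missing. A simplified sketch of that
// flattening, with a stand-in Status type instead of Spark's MapStatus:
object OutputLocSketch {
  final case class Status(executor: String)

  // outputLocs(i) holds every known status for map partition i (a task may have run several times).
  def toTrackerFormat(outputLocs: Array[List[Status]]): Array[Status] =
    outputLocs.map(_.headOption.orNull)

  def missingPartitions(outputLocs: Array[List[Status]]): Seq[Int] =
    outputLocs.indices.filter(i => outputLocs(i).isEmpty)

  def main(args: Array[String]): Unit = {
    val locs: Array[List[Status]] =
      Array(List(Status("exec-1")), Nil, List(Status("exec-2"), Status("exec-3")))
    println(toTrackerFormat(locs).toList) // List(Status(exec-1), null, Status(exec-2))
    println(missingPartitions(locs))      // Vector(1)
  }
}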
@@ -1084,7 +1256,13 @@ class DAGScheduler( } if (disallowStageRetryForTest) { - abortStage(failedStage, "Fetch failure will not retry stage due to testing config") + abortStage(failedStage, "Fetch failure will not retry stage due to testing config", + None) + } else if (failedStage.failedOnFetchAndShouldAbort(task.stageAttemptId)) { + abortStage(failedStage, s"$failedStage (${failedStage.name}) " + + s"has failed the maximum allowable number of " + + s"times: ${Stage.MAX_CONSECUTIVE_FETCH_FAILURES}. " + + s"Most recent failure reason: ${failureMessage}", None) } else if (failedStages.isEmpty) { // Don't schedule an event to resubmit failed stages if failed isn't empty, because // in that case the event will already have been scheduled. @@ -1112,13 +1290,13 @@ class DAGScheduler( case commitDenied: TaskCommitDenied => // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits - case ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) => + case exceptionFailure: ExceptionFailure => // Do nothing here, left up to the TaskScheduler to decide how to handle user failures case TaskResultLost => // Do nothing here; the TaskScheduler handles these failures and resubmits the task. - case other => + case _: ExecutorLostFailure | TaskKilled | UnknownReason => // Unrecognized failure - also do nothing. If the task fails repeatedly, the TaskScheduler // will abort the job. } @@ -1150,8 +1328,10 @@ class DAGScheduler( // TODO: This will be really slow if we keep accumulating shuffle map stages for ((shuffleId, stage) <- shuffleToMapStage) { stage.removeOutputsOnExecutor(execId) - val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray - mapOutputTracker.registerMapOutputs(shuffleId, locs, changeEpoch = true) + mapOutputTracker.registerMapOutputs( + shuffleId, + stage.outputLocInMapOutputTrackerFormat(), + changeEpoch = true) } if (shuffleToMapStage.isEmpty) { mapOutputTracker.incrementEpoch() @@ -1208,10 +1388,17 @@ class DAGScheduler( if (errorMessage.isEmpty) { logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime)) stage.latestInfo.completionTime = Some(clock.getTimeMillis()) + + // Clear failure count for this stage, now that it's succeeded. + // We only limit consecutive failures of stage attempts,so that if a stage is + // re-used many times in a long-running job, unrelated failures don't eventually cause the + // stage to be aborted. + stage.clearFailures() } else { stage.latestInfo.stageFailed(errorMessage.get) logInfo("%s (%s) failed in %s s".format(stage, stage.name, serviceTime)) } + outputCommitCoordinator.stageEnd(stage.id) listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) runningStages -= stage @@ -1221,7 +1408,10 @@ class DAGScheduler( * Aborts all jobs depending on a particular Stage. This is called in response to a task set * being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside. */ - private[scheduler] def abortStage(failedStage: Stage, reason: String) { + private[scheduler] def abortStage( + failedStage: Stage, + reason: String, + exception: Option[Throwable]): Unit = { if (!stageIdToStage.contains(failedStage.id)) { // Skip all the actions if the stage has been removed. 
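// The FetchFailed branch above aborts a stage once enough distinct stage attempts have failed
// with a fetch failure (Stage.MAX_CONSECUTIVE_FETCH_FAILURES, defined later in this patch).
// Tracking attempt ids in a set means several FetchFailed events from the same attempt count
// only once. A minimal sketch of that bookkeeping, with illustrative names:
import scala.collection.mutable

object FetchFailureTrackerSketch {
  final class FailureTracker(maxConsecutiveFailures: Int) {
    private val failedAttemptIds = new mutable.HashSet[Int]

    /** Record a fetch failure for the given stage attempt; return true if the stage should abort. */
    def failedOnFetchAndShouldAbort(stageAttemptId: Int): Boolean = {
      failedAttemptIds += stageAttemptId
      failedAttemptIds.size >= maxConsecutiveFailures
    }

    /** Called when an attempt succeeds, so later, unrelated failures start from a clean slate. */
    def clearFailures(): Unit = failedAttemptIds.clear()
  }

  def main(args: Array[String]): Unit = {
    val tracker = new FailureTracker(maxConsecutiveFailures = 4)
    println((0 until 3).map(tracker.failedOnFetchAndShouldAbort)) // Vector(false, false, false)
    println(tracker.failedOnFetchAndShouldAbort(3))               // true: 4th distinct failed attempt
  }
}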
return @@ -1230,7 +1420,7 @@ class DAGScheduler( activeJobs.filter(job => stageDependsOn(job.finalStage, failedStage)).toSeq failedStage.latestInfo.completionTime = Some(clock.getTimeMillis()) for (job <- dependentJobs) { - failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason") + failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason", exception) } if (dependentJobs.isEmpty) { logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done") @@ -1238,8 +1428,11 @@ class DAGScheduler( } /** Fails a job and all stages that are only used by that job, and cleans up relevant state. */ - private def failJobAndIndependentStages(job: ActiveJob, failureReason: String) { - val error = new SparkException(failureReason) + private def failJobAndIndependentStages( + job: ActiveJob, + failureReason: String, + exception: Option[Throwable] = None): Unit = { + val error = new SparkException(failureReason, exception.getOrElse(null)) var ableToCancelStages = true val shouldInterruptThread = @@ -1358,45 +1551,41 @@ class DAGScheduler( return rddPrefs.map(TaskLocation(_)) } + // If the RDD has narrow dependencies, pick the first partition of the first narrow dependency + // that has any placement preferences. Ideally we would choose based on transfer sizes, + // but this will do for now. rdd.dependencies.foreach { case n: NarrowDependency[_] => - // If the RDD has narrow dependencies, pick the first partition of the first narrow dep - // that has any placement preferences. Ideally we would choose based on transfer sizes, - // but this will do for now. for (inPart <- n.getParents(partition)) { val locs = getPreferredLocsInternal(n.rdd, inPart, visited) if (locs != Nil) { return locs } } - case s: ShuffleDependency[_, _, _] => - // For shuffle dependencies, pick locations which have at least REDUCER_PREF_LOCS_FRACTION - // of data as preferred locations - if (shuffleLocalityEnabled && - rdd.partitions.size < SHUFFLE_PREF_REDUCE_THRESHOLD && - s.rdd.partitions.size < SHUFFLE_PREF_MAP_THRESHOLD) { - // Get the preferred map output locations for this reducer - val topLocsForReducer = mapOutputTracker.getLocationsWithLargestOutputs(s.shuffleId, - partition, rdd.partitions.size, REDUCER_PREF_LOCS_FRACTION) - if (topLocsForReducer.nonEmpty) { - return topLocsForReducer.get.map(loc => TaskLocation(loc.host, loc.executorId)) - } - } case _ => } + Nil } + /** Mark a map stage job as finished with the given output stats, and report to its listener. 
*/ + def markMapStageJobAsFinished(job: ActiveJob, stats: MapOutputStatistics): Unit = { + // In map stage jobs, we only create a single "task", which is to finish all of the stage + // (including reusing any previous map outputs, etc); so we just mark task 0 as done + job.finished(0) = true + job.numFinished += 1 + job.listener.taskSucceeded(0, stats) + cleanupStateForJobAndIndependentStages(job) + listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobSucceeded)) + } + def stop() { - logInfo("Stopping DAGScheduler") messageScheduler.shutdownNow() eventProcessLoop.stop() taskScheduler.stop() } - // Start the event thread and register the metrics source at the end of the constructor - env.metricsSystem.registerSource(metricsSource) eventProcessLoop.start() } @@ -1421,6 +1610,9 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler case JobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) => dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) + case MapStageSubmitted(jobId, dependency, callSite, listener, properties) => + dagScheduler.handleMapStageSubmitted(jobId, dependency, callSite, listener, properties) + case StageCancelled(stageId) => dagScheduler.handleStageCancellation(stageId) @@ -1448,8 +1640,8 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler case completion @ CompletionEvent(task, reason, _, _, taskInfo, taskMetrics) => dagScheduler.handleTaskCompletion(completion) - case TaskSetFailed(taskSet, reason) => - dagScheduler.handleTaskSetFailed(taskSet, reason) + case TaskSetFailed(taskSet, reason, exception) => + dagScheduler.handleTaskSetFailed(taskSet, reason, exception) case ResubmitFailedStages => dagScheduler.resubmitFailedStages() diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index a213d419cf03..dda3b6cc7f96 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -35,6 +35,7 @@ import org.apache.spark.util.CallSite */ private[scheduler] sealed trait DAGSchedulerEvent +/** A result-yielding job was submitted on a target RDD */ private[scheduler] case class JobSubmitted( jobId: Int, finalRDD: RDD[_], @@ -45,6 +46,15 @@ private[scheduler] case class JobSubmitted( properties: Properties = null) extends DAGSchedulerEvent +/** A map stage as submitted to run as a separate job */ +private[scheduler] case class MapStageSubmitted( + jobId: Int, + dependency: ShuffleDependency[_, _, _], + callSite: CallSite, + listener: JobListener, + properties: Properties = null) + extends DAGSchedulerEvent + private[scheduler] case class StageCancelled(stageId: Int) extends DAGSchedulerEvent private[scheduler] case class JobCancelled(jobId: Int) extends DAGSchedulerEvent @@ -73,6 +83,7 @@ private[scheduler] case class ExecutorAdded(execId: String, host: String) extend private[scheduler] case class ExecutorLost(execId: String) extends DAGSchedulerEvent private[scheduler] -case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent +case class TaskSetFailed(taskSet: TaskSet, reason: String, exception: Option[Throwable]) + extends DAGSchedulerEvent private[scheduler] case object ResubmitFailedStages extends DAGSchedulerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala 
b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 5a06ef02f5c5..eaa07acc5132 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -109,7 +109,9 @@ private[spark] class EventLoggingListener( if (shouldOverwrite && fileSystem.exists(path)) { logWarning(s"Event log $path already exists. Overwriting...") - fileSystem.delete(path, true) + if (!fileSystem.delete(path, true)) { + logWarning(s"Error deleting $path") + } } /* The Hadoop LocalFileSystem (r1.0.4) has known issues with syncing (HADOOP-7844). @@ -205,6 +207,10 @@ private[spark] class EventLoggingListener( // No-op because logging every update would be overkill override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { } + override def onOtherEvent(event: SparkListenerEvent): Unit = { + logEvent(event, flushLogger = true) + } + /** * Stop logging events. The event log file will be renamed so that it loses the * ".inprogress" suffix. @@ -216,7 +222,9 @@ private[spark] class EventLoggingListener( if (fileSystem.exists(target)) { if (shouldOverwrite) { logWarning(s"Event log $target already exists. Overwriting...") - fileSystem.delete(target, true) + if (!fileSystem.delete(target, true)) { + logWarning(s"Error deleting $target") + } } else { throw new IOException("Target log file already exists (%s)".format(logPath)) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala index 2bc43a918644..7e1197d74280 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala @@ -23,16 +23,34 @@ import org.apache.spark.executor.ExecutorExitCode * Represents an explanation for a executor or whole slave failing or exiting. */ private[spark] -class ExecutorLossReason(val message: String) { +class ExecutorLossReason(val message: String) extends Serializable { override def toString: String = message } private[spark] -case class ExecutorExited(val exitCode: Int) - extends ExecutorLossReason(ExecutorExitCode.explainExitCode(exitCode)) { +case class ExecutorExited(exitCode: Int, exitCausedByApp: Boolean, reason: String) + extends ExecutorLossReason(reason) + +private[spark] object ExecutorExited { + def apply(exitCode: Int, exitCausedByApp: Boolean): ExecutorExited = { + ExecutorExited( + exitCode, + exitCausedByApp, + ExecutorExitCode.explainExitCode(exitCode)) + } } +private[spark] object ExecutorKilled extends ExecutorLossReason("Executor killed by driver.") + +/** + * A loss reason that means we don't yet know why the executor exited. + * + * This is used by the task scheduler to remove state associated with the executor, but + * not yet fail any tasks that were running in the executor before the real loss reason + * is known. 
+ */ +private [spark] object LossReasonPending extends ExecutorLossReason("Pending loss reason.") + private[spark] case class SlaveLost(_message: String = "Slave lost") - extends ExecutorLossReason(_message) { -} + extends ExecutorLossReason(_message) diff --git a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala index bac37bfdaa23..0e438ab4366d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.immutable.Set import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} @@ -107,7 +107,7 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl val retval = new ArrayBuffer[SplitInfo]() val list = instance.getSplits(job) - for (split <- list) { + for (split <- list.asScala) { retval ++= SplitInfo.toSplitInfo(inputFormatClazz, path, split) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala index 382b09422a4a..4326135186a7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala @@ -17,6 +17,10 @@ package org.apache.spark.scheduler +import java.util.concurrent.atomic.AtomicInteger + +import scala.concurrent.{Future, Promise} + /** * An object that waits for a DAGScheduler job to complete. As tasks finish, it passes their * results to the given handler function. @@ -28,17 +32,15 @@ private[spark] class JobWaiter[T]( resultHandler: (Int, T) => Unit) extends JobListener { - private var finishedTasks = 0 - - // Is the job as a whole finished (succeeded or failed)? - @volatile - private var _jobFinished = totalTasks == 0 - - def jobFinished: Boolean = _jobFinished - + private val finishedTasks = new AtomicInteger(0) // If the job is finished, this will be its result. In the case of 0 task jobs (e.g. zero // partition RDDs), we set the jobResult directly to JobSucceeded. - private var jobResult: JobResult = if (jobFinished) JobSucceeded else null + private val jobPromise: Promise[Unit] = + if (totalTasks == 0) Promise.successful(()) else Promise() + + def jobFinished: Boolean = jobPromise.isCompleted + + def completionFuture: Future[Unit] = jobPromise.future /** * Sends a signal to the DAGScheduler to cancel the job. The cancellation itself is handled @@ -49,29 +51,17 @@ private[spark] class JobWaiter[T]( dagScheduler.cancelJob(jobId) } - override def taskSucceeded(index: Int, result: Any): Unit = synchronized { - if (_jobFinished) { - throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter") + override def taskSucceeded(index: Int, result: Any): Unit = { + // resultHandler call must be synchronized in case resultHandler itself is not thread safe. 
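// The JobWaiter rewrite above replaces wait()/notifyAll() with an AtomicInteger task counter and
// a scala.concurrent.Promise that is completed when the last task finishes (or failed on error),
// exposing completion as a Future. A minimal standalone sketch of the same pattern, with
// simplified names rather than the Spark class:
import java.util.concurrent.atomic.AtomicInteger
import scala.concurrent.{Await, Future, Promise}
import scala.concurrent.duration._

object PromiseWaiterSketch {
  final class Waiter[T](totalTasks: Int, resultHandler: (Int, T) => Unit) {
    private val finishedTasks = new AtomicInteger(0)
    // Zero-task jobs are finished immediately.
    private val promise: Promise[Unit] =
      if (totalTasks == 0) Promise.successful(()) else Promise()

    def completionFuture: Future[Unit] = promise.future

    def taskSucceeded(index: Int, result: T): Unit = {
      // Serialize calls into the handler in case it is not thread safe.
      synchronized { resultHandler(index, result) }
      if (finishedTasks.incrementAndGet() == totalTasks) promise.success(())
    }

    def jobFailed(exception: Exception): Unit = promise.tryFailure(exception)
  }

  def main(args: Array[String]): Unit = {
    val results = Array.fill(3)("")
    val waiter = new Waiter[String](3, (i, r) => results(i) = r)
    Seq("a", "b", "c").zipWithIndex.foreach { case (r, i) => waiter.taskSucceeded(i, r) }
    Await.result(waiter.completionFuture, 1.second)
    println(results.toList) // List(a, b, c)
  }
}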
+ synchronized { + resultHandler(index, result.asInstanceOf[T]) } - resultHandler(index, result.asInstanceOf[T]) - finishedTasks += 1 - if (finishedTasks == totalTasks) { - _jobFinished = true - jobResult = JobSucceeded - this.notifyAll() + if (finishedTasks.incrementAndGet() == totalTasks) { + jobPromise.success(()) } } - override def jobFailed(exception: Exception): Unit = synchronized { - _jobFinished = true - jobResult = JobFailed(exception) - this.notifyAll() - } + override def jobFailed(exception: Exception): Unit = + jobPromise.failure(exception) - def awaitResult(): JobResult = synchronized { - while (!_jobFinished) { - this.wait() - } - return jobResult - } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 1efce124c0a6..b2e9a97129f0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -122,8 +122,7 @@ private[spark] class CompressedMapStatus( /** * A [[MapStatus]] implementation that only stores the average size of non-empty blocks, - * plus a bitmap for tracking which blocks are empty. During serialization, this bitmap - * is compressed. + * plus a bitmap for tracking which blocks are empty. * * @param loc location where the task is being executed * @param numNonEmptyBlocks the number of non-empty blocks @@ -194,6 +193,8 @@ private[spark] object HighlyCompressedMapStatus { } else { 0 } + emptyBlocks.trim() + emptyBlocks.runOptimize() new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index 8321037cdc02..4d146678174f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -25,7 +25,7 @@ import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, RpcEndpoint private sealed trait OutputCommitCoordinationMessage extends Serializable private case object StopCoordinator extends OutputCommitCoordinationMessage -private case class AskPermissionToCommitOutput(stage: Int, task: Long, taskAttempt: Long) +private case class AskPermissionToCommitOutput(stage: Int, partition: Int, attemptNumber: Int) /** * Authority that decides whether tasks can commit output to HDFS. Uses a "first committer wins" @@ -44,8 +44,10 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) var coordinatorRef: Option[RpcEndpointRef] = None private type StageId = Int - private type PartitionId = Long - private type TaskAttemptId = Long + private type PartitionId = Int + private type TaskAttemptNumber = Int + + private val NO_AUTHORIZED_COMMITTER: TaskAttemptNumber = -1 /** * Map from active stages's id => partition id => task attempt with exclusive lock on committing @@ -56,8 +58,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) * * Access to this map should be guarded by synchronizing on the OutputCommitCoordinator instance. 
*/ - private val authorizedCommittersByStage: CommittersByStageMap = mutable.Map() - private type CommittersByStageMap = mutable.Map[StageId, mutable.Map[PartitionId, TaskAttemptId]] + private val authorizedCommittersByStage = mutable.Map[StageId, Array[TaskAttemptNumber]]() /** * Returns whether the OutputCommitCoordinator's internal data structures are all empty. @@ -75,14 +76,15 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) * * @param stage the stage number * @param partition the partition number - * @param attempt a unique identifier for this task attempt + * @param attemptNumber how many times this task has been attempted + * (see [[TaskContext.attemptNumber()]]) * @return true if this task is authorized to commit, false otherwise */ def canCommit( stage: StageId, partition: PartitionId, - attempt: TaskAttemptId): Boolean = { - val msg = AskPermissionToCommitOutput(stage, partition, attempt) + attemptNumber: TaskAttemptNumber): Boolean = { + val msg = AskPermissionToCommitOutput(stage, partition, attemptNumber) coordinatorRef match { case Some(endpointRef) => endpointRef.askWithRetry[Boolean](msg) @@ -93,9 +95,21 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) } } - // Called by DAGScheduler - private[scheduler] def stageStart(stage: StageId): Unit = synchronized { - authorizedCommittersByStage(stage) = mutable.HashMap[PartitionId, TaskAttemptId]() + /** + * Called by the DAGScheduler when a stage starts. + * + * @param stage the stage id. + * @param maxPartitionId the maximum partition id that could appear in this stage's tasks (i.e. + * the maximum possible value of `context.partitionId`). + */ + private[scheduler] def stageStart( + stage: StageId, + maxPartitionId: Int): Unit = { + val arr = new Array[TaskAttemptNumber](maxPartitionId + 1) + java.util.Arrays.fill(arr, NO_AUTHORIZED_COMMITTER) + synchronized { + authorizedCommittersByStage(stage) = arr + } } // Called by DAGScheduler @@ -107,7 +121,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) private[scheduler] def taskCompleted( stage: StageId, partition: PartitionId, - attempt: TaskAttemptId, + attemptNumber: TaskAttemptNumber, reason: TaskEndReason): Unit = synchronized { val authorizedCommitters = authorizedCommittersByStage.getOrElse(stage, { logDebug(s"Ignoring task completion for completed stage") @@ -117,13 +131,13 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) case Success => // The task output has been committed successfully case denied: TaskCommitDenied => - logInfo( - s"Task was denied committing, stage: $stage, partition: $partition, attempt: $attempt") + logInfo(s"Task was denied committing, stage: $stage, partition: $partition, " + + s"attempt: $attemptNumber") case otherReason => - if (authorizedCommitters.get(partition).exists(_ == attempt)) { - logDebug(s"Authorized committer $attempt (stage=$stage, partition=$partition) failed;" + - s" clearing lock") - authorizedCommitters.remove(partition) + if (authorizedCommitters(partition) == attemptNumber) { + logDebug(s"Authorized committer (attemptNumber=$attemptNumber, stage=$stage, " + + s"partition=$partition) failed; clearing lock") + authorizedCommitters(partition) = NO_AUTHORIZED_COMMITTER } } } @@ -140,21 +154,23 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) private[scheduler] def handleAskPermissionToCommit( stage: StageId, partition: PartitionId, - attempt: TaskAttemptId): 
Boolean = synchronized { + attemptNumber: TaskAttemptNumber): Boolean = synchronized { authorizedCommittersByStage.get(stage) match { case Some(authorizedCommitters) => - authorizedCommitters.get(partition) match { - case Some(existingCommitter) => - logDebug(s"Denying $attempt to commit for stage=$stage, partition=$partition; " + - s"existingCommitter = $existingCommitter") - false - case None => - logDebug(s"Authorizing $attempt to commit for stage=$stage, partition=$partition") - authorizedCommitters(partition) = attempt + authorizedCommitters(partition) match { + case NO_AUTHORIZED_COMMITTER => + logDebug(s"Authorizing attemptNumber=$attemptNumber to commit for stage=$stage, " + + s"partition=$partition") + authorizedCommitters(partition) = attemptNumber true + case existingCommitter => + logDebug(s"Denying attemptNumber=$attemptNumber to commit for stage=$stage, " + + s"partition=$partition; existingCommitter = $existingCommitter") + false } case None => - logDebug(s"Stage $stage has completed, so not allowing task attempt $attempt to commit") + logDebug(s"Stage $stage has completed, so not allowing attempt number $attemptNumber of" + + s"partition $partition to commit") false } } @@ -162,7 +178,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) private[spark] object OutputCommitCoordinator { - // This actor is used only for RPC + // This endpoint is used only for RPC private[spark] class OutputCommitCoordinatorEndpoint( override val rpcEnv: RpcEnv, outputCommitCoordinator: OutputCommitCoordinator) extends RpcEndpoint with Logging { @@ -174,9 +190,9 @@ private[spark] object OutputCommitCoordinator { } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case AskPermissionToCommitOutput(stage, partition, taskAttempt) => + case AskPermissionToCommitOutput(stage, partition, attemptNumber) => context.reply( - outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, taskAttempt)) + outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, attemptNumber)) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala index 174b73221afc..551e39a81b69 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala @@ -19,7 +19,7 @@ package org.apache.spark.scheduler import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import org.apache.spark.Logging @@ -74,7 +74,7 @@ private[spark] class Pool( if (schedulableNameToSchedulable.containsKey(schedulableName)) { return schedulableNameToSchedulable.get(schedulableName) } - for (schedulable <- schedulableQueue) { + for (schedulable <- schedulableQueue.asScala) { val sched = schedulable.getSchedulableByName(schedulableName) if (sched != null) { return sched @@ -83,13 +83,13 @@ private[spark] class Pool( null } - override def executorLost(executorId: String, host: String) { - schedulableQueue.foreach(_.executorLost(executorId, host)) + override def executorLost(executorId: String, host: String, reason: ExecutorLossReason) { + schedulableQueue.asScala.foreach(_.executorLost(executorId, host, reason)) } override def checkSpeculatableTasks(): Boolean = { var shouldRevive = false - for (schedulable <- schedulableQueue) { + for (schedulable <- schedulableQueue.asScala) { 
shouldRevive |= schedulable.checkSpeculatableTasks() } shouldRevive @@ -98,7 +98,7 @@ private[spark] class Pool( override def getSortedTaskSetQueue: ArrayBuffer[TaskSetManager] = { var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager] val sortedSchedulableQueue = - schedulableQueue.toSeq.sortWith(taskSetSchedulingAlgorithm.comparator) + schedulableQueue.asScala.toSeq.sortWith(taskSetSchedulingAlgorithm.comparator) for (schedulable <- sortedSchedulableQueue) { sortedTaskSetQueue ++= schedulable.getSortedTaskSetQueue } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala index bf81b9aca481..d1687830ff7b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala @@ -17,24 +17,51 @@ package org.apache.spark.scheduler +import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.util.CallSite /** - * The ResultStage represents the final stage in a job. + * ResultStages apply a function on some partitions of an RDD to compute the result of an action. + * The ResultStage object captures the function to execute, `func`, which will be applied to each + * partition, and the set of partition IDs, `partitions`. Some stages may not run on all partitions + * of the RDD, for actions like first() and lookup(). */ private[spark] class ResultStage( id: Int, rdd: RDD[_], - numTasks: Int, + val func: (TaskContext, Iterator[_]) => _, + val partitions: Array[Int], parents: List[Stage], firstJobId: Int, callSite: CallSite) - extends Stage(id, rdd, numTasks, parents, firstJobId, callSite) { + extends Stage(id, rdd, partitions.length, parents, firstJobId, callSite) { - // The active job for this result stage. Will be empty if the job has already finished - // (e.g., because the job was cancelled). - var resultOfJob: Option[ActiveJob] = None + /** + * The active job for this result stage. Will be empty if the job has already finished + * (e.g., because the job was cancelled). + */ + private[this] var _activeJob: Option[ActiveJob] = None + + def activeJob: Option[ActiveJob] = _activeJob + + def setActiveJob(job: ActiveJob): Unit = { + _activeJob = Option(job) + } + + def removeActiveJob(): Unit = { + _activeJob = None + } + + /** + * Returns the sequence of partition ids that are missing (i.e. needs to be computed). + * + * This can only be called when there is an active job. 
+ */ + override def findMissingPartitions(): Seq[Int] = { + val job = activeJob.get + (0 until job.numPartitions).filter(id => !job.finished(id)) + } override def toString: String = "ResultStage " + id } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 9c2606e278c5..fb693721a9cb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -44,9 +44,11 @@ private[spark] class ResultTask[T, U]( stageAttemptId: Int, taskBinary: Broadcast[Array[Byte]], partition: Partition, - @transient locs: Seq[TaskLocation], - val outputId: Int) - extends Task[U](stageId, stageAttemptId, partition.index) with Serializable { + locs: Seq[TaskLocation], + val outputId: Int, + internalAccumulators: Seq[Accumulator[Long]]) + extends Task[U](stageId, stageAttemptId, partition.index, internalAccumulators) + with Serializable { @transient private[this] val preferredLocs: Seq[TaskLocation] = { if (locs == null) Nil else locs.toSet.toSeq diff --git a/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala b/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala index a87ef030e69c..ab00bc8f0bf4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala @@ -42,7 +42,7 @@ private[spark] trait Schedulable { def addSchedulable(schedulable: Schedulable): Unit def removeSchedulable(schedulable: Schedulable): Unit def getSchedulableByName(name: String): Schedulable - def executorLost(executorId: String, host: String): Unit + def executorLost(executorId: String, host: String, reason: ExecutorLossReason): Unit def checkSpeculatableTasks(): Boolean def getSortedTaskSetQueue: ArrayBuffer[TaskSetManager] } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala index 66c75f325fcd..51416e5ce97f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala @@ -23,7 +23,15 @@ import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.CallSite /** - * The ShuffleMapStage represents the intermediate stages in a job. + * ShuffleMapStages are intermediate stages in the execution DAG that produce data for a shuffle. + * They occur right before each shuffle operation, and might contain multiple pipelined operations + * before that (e.g. map and filter). When executed, they save map output files that can later be + * fetched by reduce tasks. The `shuffleDep` field describes the shuffle each stage is part of, + * and variables like `outputLocs` and `numAvailableOutputs` track how many map outputs are ready. + * + * ShuffleMapStages can also be submitted independently as jobs with DAGScheduler.submitMapStage. + * For such stages, the ActiveJobs that submitted them are tracked in `mapStageJobs`. Note that + * there can be multiple ActiveJobs trying to compute the same shuffle map stage. 
*/ private[spark] class ShuffleMapStage( id: Int, @@ -35,19 +43,61 @@ private[spark] class ShuffleMapStage( val shuffleDep: ShuffleDependency[_, _, _]) extends Stage(id, rdd, numTasks, parents, firstJobId, callSite) { + private[this] var _mapStageJobs: List[ActiveJob] = Nil + + private[this] var _numAvailableOutputs: Int = 0 + + /** + * List of [[MapStatus]] for each partition. The index of the array is the map partition id, + * and each value in the array is the list of possible [[MapStatus]] for a partition + * (a single task might run multiple times). + */ + private[this] val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil) + override def toString: String = "ShuffleMapStage " + id - var numAvailableOutputs: Long = 0 + /** + * Returns the list of active jobs, + * i.e. map-stage jobs that were submitted to execute this stage independently (if any). + */ + def mapStageJobs: Seq[ActiveJob] = _mapStageJobs + + /** Adds the job to the active job list. */ + def addActiveJob(job: ActiveJob): Unit = { + _mapStageJobs = job :: _mapStageJobs + } + + /** Removes the job from the active job list. */ + def removeActiveJob(job: ActiveJob): Unit = { + _mapStageJobs = _mapStageJobs.filter(_ != job) + } - def isAvailable: Boolean = numAvailableOutputs == numPartitions + /** + * Number of partitions that have shuffle outputs. + * When this reaches [[numPartitions]], this map stage is ready. + * This should be kept consistent as `outputLocs.filter(!_.isEmpty).size`. + */ + def numAvailableOutputs: Int = _numAvailableOutputs - val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil) + /** + * Returns true if the map stage is ready, i.e. all partitions have shuffle outputs. + * This should be the same as `outputLocs.contains(Nil)`. + */ + def isAvailable: Boolean = _numAvailableOutputs == numPartitions + + /** Returns the sequence of partition ids that are missing (i.e. needs to be computed). */ + override def findMissingPartitions(): Seq[Int] = { + val missing = (0 until numPartitions).filter(id => outputLocs(id).isEmpty) + assert(missing.size == numPartitions - _numAvailableOutputs, + s"${missing.size} missing, expected ${numPartitions - _numAvailableOutputs}") + missing + } def addOutputLoc(partition: Int, status: MapStatus): Unit = { val prevList = outputLocs(partition) outputLocs(partition) = status :: prevList if (prevList == Nil) { - numAvailableOutputs += 1 + _numAvailableOutputs += 1 } } @@ -56,10 +106,19 @@ private[spark] class ShuffleMapStage( val newList = prevList.filterNot(_.location == bmAddress) outputLocs(partition) = newList if (prevList != Nil && newList == Nil) { - numAvailableOutputs -= 1 + _numAvailableOutputs -= 1 } } + /** + * Returns an array of [[MapStatus]] (index by partition id). For each partition, the returned + * value contains only one (i.e. the first) [[MapStatus]]. If there is no entry for the partition, + * that position is filled with null. + */ + def outputLocInMapOutputTrackerFormat(): Array[MapStatus] = { + outputLocs.map(_.headOption.orNull) + } + /** * Removes all shuffle outputs associated with this executor. 
Note that this will also remove * outputs which are served by an external shuffle server (if one exists), as they are still @@ -73,12 +132,12 @@ private[spark] class ShuffleMapStage( outputLocs(partition) = newList if (prevList != Nil && newList == Nil) { becameUnavailable = true - numAvailableOutputs -= 1 + _numAvailableOutputs -= 1 } } if (becameUnavailable) { logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format( - this, execId, numAvailableOutputs, numPartitions, isAvailable)) + this, execId, _numAvailableOutputs, numPartitions, isAvailable)) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 14c8c0096148..ea97ef0e746d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -27,11 +27,11 @@ import org.apache.spark.rdd.RDD import org.apache.spark.shuffle.ShuffleWriter /** -* A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner -* specified in the ShuffleDependency). -* -* See [[org.apache.spark.scheduler.Task]] for more information. -* + * A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner + * specified in the ShuffleDependency). + * + * See [[org.apache.spark.scheduler.Task]] for more information. + * * @param stageId id of the stage this task belongs to * @param taskBinary broadcast version of the RDD and the ShuffleDependency. Once deserialized, * the type should be (RDD[_], ShuffleDependency[_, _, _]). @@ -43,12 +43,14 @@ private[spark] class ShuffleMapTask( stageAttemptId: Int, taskBinary: Broadcast[Array[Byte]], partition: Partition, - @transient private var locs: Seq[TaskLocation]) - extends Task[MapStatus](stageId, stageAttemptId, partition.index) with Logging { + @transient private var locs: Seq[TaskLocation], + internalAccumulators: Seq[Accumulator[Long]]) + extends Task[MapStatus](stageId, stageAttemptId, partition.index, internalAccumulators) + with Logging { /** A constructor used only in test suites. This does not require passing in an RDD. 
*/ def this(partitionId: Int) { - this(0, 0, null, new Partition { override def index: Int = 0 }, null) + this(0, 0, null, new Partition { override def index: Int = 0 }, null, null) } @transient private val preferredLocs: Seq[TaskLocation] = { @@ -69,7 +71,7 @@ private[spark] class ShuffleMapTask( val manager = SparkEnv.get.shuffleManager writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context) writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]]) - return writer.stop(success = true).get + writer.stop(success = true).get } catch { case e: Exception => try { diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 896f1743332f..075a7f13172d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -22,15 +22,19 @@ import java.util.Properties import scala.collection.Map import scala.collection.mutable -import org.apache.spark.{Logging, TaskEndReason} +import com.fasterxml.jackson.annotation.JsonTypeInfo + +import org.apache.spark.{Logging, SparkConf, TaskEndReason} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.storage.{BlockManagerId, BlockUpdatedInfo} import org.apache.spark.util.{Distribution, Utils} +import org.apache.spark.ui.SparkUI @DeveloperApi -sealed trait SparkListenerEvent +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "Event") +trait SparkListenerEvent @DeveloperApi case class SparkListenerStageSubmitted(stageInfo: StageInfo, properties: Properties = null) @@ -130,6 +134,17 @@ case class SparkListenerApplicationEnd(time: Long) extends SparkListenerEvent */ private[spark] case class SparkListenerLogStart(sparkVersion: String) extends SparkListenerEvent +/** + * Interface for creating history listeners defined in other modules like SQL, which are used to + * rebuild the history UI. + */ +private[spark] trait SparkHistoryListenerFactory { + /** + * Create listeners used to rebuild the history UI. + */ + def createListeners(conf: SparkConf, sparkUI: SparkUI): Seq[SparkListener] +} + /** * :: DeveloperApi :: * Interface for listening to events from the Spark scheduler. Note that this is an internal @@ -223,6 +238,11 @@ trait SparkListener { * Called when the driver receives a block update info. */ def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated) { } + + /** + * Called when other events like SQL-specific events are posted. 
+ */ + def onOtherEvent(event: SparkListenerEvent) { } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index 04afde33f5aa..95722a07144e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -61,6 +61,7 @@ private[spark] trait SparkListenerBus extends ListenerBus[SparkListener, SparkLi case blockUpdated: SparkListenerBlockUpdated => listener.onBlockUpdated(blockUpdated) case logStart: SparkListenerLogStart => // ignore event log metadata + case _ => listener.onOtherEvent(event) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 40a333a3e06b..7ea24a217bd3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -24,29 +24,35 @@ import org.apache.spark.rdd.RDD import org.apache.spark.util.CallSite /** - * A stage is a set of independent tasks all computing the same function that need to run as part + * A stage is a set of parallel tasks all computing the same function that need to run as part * of a Spark job, where all the tasks have the same shuffle dependencies. Each DAG of tasks run * by the scheduler is split up into stages at the boundaries where shuffle occurs, and then the * DAGScheduler runs these stages in topological order. * * Each Stage can either be a shuffle map stage, in which case its tasks' results are input for - * another stage, or a result stage, in which case its tasks directly compute the action that - * initiated a job (e.g. count(), save(), etc). For shuffle map stages, we also track the nodes - * that each output partition is on. + * other stage(s), or a result stage, in which case its tasks directly compute a Spark action + * (e.g. count(), save(), etc) by running a function on an RDD. For shuffle map stages, we also + * track the nodes that each output partition is on. * * Each Stage also has a firstJobId, identifying the job that first submitted the stage. When FIFO * scheduling is used, this allows Stages from earlier jobs to be computed first or recovered * faster on failure. * - * The callSite provides a location in user code which relates to the stage. For a shuffle map - * stage, the callSite gives the user code that created the RDD being shuffled. For a result - * stage, the callSite gives the user code that executes the associated action (e.g. count()). - * - * A single stage can consist of multiple attempts. In that case, the latestInfo field will - * be updated for each attempt. + * Finally, a single stage can be re-executed in multiple attempts due to fault recovery. In that + * case, the Stage object will track multiple StageInfo objects to pass to listeners or the web UI. + * The latest one will be accessible through latestInfo. * + * @param id Unique stage ID + * @param rdd RDD that this stage runs on: for a shuffle map stage, it's the RDD we run map tasks + * on, while for a result stage, it's the target RDD that we ran an action on + * @param numTasks Total number of tasks in stage; result stages in particular may not need to + * compute all partitions, e.g. for first(), lookup(), and take(). + * @param parents List of stages that this stage depends on (through shuffle dependencies). 
+ * @param firstJobId ID of the first job this stage was part of, for FIFO scheduling. + * @param callSite Location in the user program associated with this stage: either where the target + * RDD was created, for a shuffle map stage, or where the action for a result stage was called. */ -private[spark] abstract class Stage( +private[scheduler] abstract class Stage( val id: Int, val rdd: RDD[_], val numTasks: Int, @@ -55,18 +61,34 @@ private[spark] abstract class Stage( val callSite: CallSite) extends Logging { - val numPartitions = rdd.partitions.size + val numPartitions = rdd.partitions.length /** Set of jobs that this stage belongs to. */ val jobIds = new HashSet[Int] - var pendingTasks = new HashSet[Task[_]] + val pendingPartitions = new HashSet[Int] /** The ID to use for the next new attempt for this stage. */ private var nextAttemptId: Int = 0 - val name = callSite.shortForm - val details = callSite.longForm + val name: String = callSite.shortForm + val details: String = callSite.longForm + + private var _internalAccumulators: Seq[Accumulator[Long]] = Seq.empty + + /** Internal accumulators shared across all tasks in this stage. */ + def internalAccumulators: Seq[Accumulator[Long]] = _internalAccumulators + + /** + * Re-initialize the internal accumulators associated with this stage. + * + * This is called every time the stage is submitted, *except* when a subset of tasks + * belonging to this stage has already finished. Otherwise, reinitializing the internal + * accumulators here again will override partial values from the finished tasks. + */ + def resetInternalAccumulators(): Unit = { + _internalAccumulators = InternalAccumulator.create(rdd.sparkContext) + } /** * Pointer to the [StageInfo] object for the most recent attempt. This needs to be initialized @@ -76,6 +98,29 @@ private[spark] abstract class Stage( */ private var _latestInfo: StageInfo = StageInfo.fromStage(this, nextAttemptId) + /** + * Set of stage attempt IDs that have failed with a FetchFailure. We keep track of these + * failures in order to avoid endless retries if a stage keeps failing with a FetchFailure. + * We keep track of each attempt ID that has failed to avoid recording duplicate failures if + * multiple tasks from the same stage attempt fail (SPARK-5945). + */ + private val fetchFailedAttemptIds = new HashSet[Int] + + private[scheduler] def clearFailures() : Unit = { + fetchFailedAttemptIds.clear() + } + + /** + * Check whether we should abort the failedStage due to multiple consecutive fetch failures. + * + * This method updates the running set of failed stage attempts and returns + * true if the number of failures exceeds the allowable number of failures. + */ + private[scheduler] def failedOnFetchAndShouldAbort(stageAttemptId: Int): Boolean = { + fetchFailedAttemptIds.add(stageAttemptId) + fetchFailedAttemptIds.size >= Stage.MAX_CONSECUTIVE_FETCH_FAILURES + } + /** Creates a new attempt for this stage by creating a new StageInfo with a new attempt ID. */ def makeNewStageAttempt( numPartitionsToCompute: Int, @@ -89,8 +134,17 @@ private[spark] abstract class Stage( def latestInfo: StageInfo = _latestInfo override final def hashCode(): Int = id + override final def equals(other: Any): Boolean = other match { case stage: Stage => stage != null && stage.id == id case _ => false } + + /** Returns the sequence of partition ids that are missing (i.e. needs to be computed). 
*/ + def findMissingPartitions(): Seq[Int] +} + +private[scheduler] object Stage { + // The number of consecutive failures allowed before a stage is aborted + val MAX_CONSECUTIVE_FETCH_FAILURES = 4 } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 1978305cfefb..9f27eed626be 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -23,18 +23,18 @@ import java.nio.ByteBuffer import scala.collection.mutable.HashMap import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.{SparkEnv, TaskContextImpl, TaskContext} +import org.apache.spark.{Accumulator, SparkEnv, TaskContextImpl, TaskContext} import org.apache.spark.executor.TaskMetrics +import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.serializer.SerializerInstance -import org.apache.spark.unsafe.memory.TaskMemoryManager -import org.apache.spark.util.ByteBufferInputStream -import org.apache.spark.util.Utils +import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils} /** * A unit of execution. We have two kinds of Task's in Spark: - * - [[org.apache.spark.scheduler.ShuffleMapTask]] - * - [[org.apache.spark.scheduler.ResultTask]] + * + * - [[org.apache.spark.scheduler.ShuffleMapTask]] + * - [[org.apache.spark.scheduler.ResultTask]] * * A Spark job consists of one or more stages. The very last stage in a job consists of multiple * ResultTasks, while earlier stages consist of ShuffleMapTasks. A ResultTask executes the task @@ -47,7 +47,8 @@ import org.apache.spark.util.Utils private[spark] abstract class Task[T]( val stageId: Int, val stageAttemptId: Int, - var partitionId: Int) extends Serializable { + val partitionId: Int, + internalAccumulators: Seq[Accumulator[Long]]) extends Serializable { /** * The key of the Map is the accumulator id and the value of the Map is the latest accumulator @@ -68,12 +69,13 @@ private[spark] abstract class Task[T]( metricsSystem: MetricsSystem) : (T, AccumulatorUpdates) = { context = new TaskContextImpl( - stageId = stageId, - partitionId = partitionId, - taskAttemptId = taskAttemptId, - attemptNumber = attemptNumber, - taskMemoryManager = taskMemoryManager, - metricsSystem = metricsSystem, + stageId, + partitionId, + taskAttemptId, + attemptNumber, + taskMemoryManager, + metricsSystem, + internalAccumulators, runningLocally = false) TaskContext.setTaskContext(context) context.taskMetrics.setHostname(Utils.localHostName()) @@ -87,13 +89,15 @@ private[spark] abstract class Task[T]( } finally { context.markTaskCompleted() try { - Utils.tryLogNonFatalError { - // Release memory used by this thread for shuffles - SparkEnv.get.shuffleMemoryManager.releaseMemoryForThisTask() - } Utils.tryLogNonFatalError { // Release memory used by this thread for unrolling blocks SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask() + // Notify any tasks waiting for execution memory to be freed to wake up and try to + // acquire memory again. This makes impossible the scenario where a task sleeps forever + // because there are no other tasks left to notify it. Since this is safe to do but may + // not be strictly necessary, we should revisit whether we can remove this in the future. 
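// The finally-block above wakes up any task blocked waiting for execution memory by calling
// notifyAll() on the memory manager's monitor after this task releases its memory. A minimal
// standalone sketch of that wait/notify handshake on a shared monitor, using a toy MemoryPool
// rather than Spark's MemoryManager:
object MemoryNotifySketch {
  final class MemoryPool(var freeBytes: Long) {
    /** Block until at least `needed` bytes are free, then take them. */
    def acquire(needed: Long): Unit = synchronized {
      while (freeBytes < needed) wait() // releases the monitor while sleeping
      freeBytes -= needed
    }

    /** Return bytes to the pool and wake every waiter so it can re-check its condition. */
    def release(amount: Long): Unit = synchronized {
      freeBytes += amount
      notifyAll()
    }
  }

  def main(args: Array[String]): Unit = {
    val pool = new MemoryPool(freeBytes = 0L)
    val waiter = new Thread(() => { pool.acquire(100); println("acquired after release") })
    waiter.start()
    Thread.sleep(100) // let the waiter block first (illustrative only)
    pool.release(100) // wakes the waiter, mirroring memoryManager.notifyAll() above
    waiter.join()
  }
}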
+ val memoryManager = SparkEnv.get.memoryManager + memoryManager.synchronized { memoryManager.notifyAll() } } } finally { TaskContext.unset() @@ -173,7 +177,7 @@ private[spark] object Task { serializer: SerializerInstance) : ByteBuffer = { - val out = new ByteArrayOutputStream(4096) + val out = new ByteBufferOutputStream(4096) val dataOut = new DataOutputStream(out) // Write currentFiles @@ -192,9 +196,9 @@ private[spark] object Task { // Write the task itself and finish dataOut.flush() - val taskBytes = serializer.serialize(task).array() - out.write(taskBytes) - ByteBuffer.wrap(out.toByteArray) + val taskBytes = serializer.serialize(task) + Utils.writeByteBuffer(taskBytes, out) + out.toByteBuffer } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index 132a9ced7770..f113c2b1b843 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -29,7 +29,7 @@ import org.apache.spark.annotation.DeveloperApi class TaskInfo( val taskId: Long, val index: Int, - val attempt: Int, + val attemptNumber: Int, val launchTime: Long, val executorId: String, val host: String, @@ -95,7 +95,10 @@ class TaskInfo( } } - def id: String = s"$index.$attempt" + @deprecated("Use attemptNumber", "1.6.0") + def attempt: Int = attemptNumber + + def id: String = s"$index.$attemptNumber" def duration: Long = { if (!finished) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala index da07ce2c6ea4..1eb6c1614fc0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala @@ -31,7 +31,9 @@ private[spark] sealed trait TaskLocation { */ private [spark] case class ExecutorCacheTaskLocation(override val host: String, executorId: String) - extends TaskLocation + extends TaskLocation { + override def toString: String = s"${TaskLocation.executorLocationTag}${host}_$executorId" +} /** * A location on a host. @@ -53,6 +55,9 @@ private[spark] object TaskLocation { // confusion. See RFC 952 and RFC 1123 for information about the format of hostnames. val inMemoryLocationTag = "hdfs_cache_" + // Identify locations of executors with this prefix. 
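// The "executor_" tag above lets a preferred location be serialized as a plain string
// ("executor_<host>_<executorId>", alongside "hdfs_cache_<host>" and bare hostnames) and parsed
// back, as the next hunk does. A small standalone round-trip sketch of that string format, with
// simplified case classes rather than Spark's TaskLocation hierarchy:
object LocationStringSketch {
  sealed trait Loc { def host: String }
  final case class ExecutorLoc(host: String, executorId: String) extends Loc
  final case class HdfsCacheLoc(host: String) extends Loc
  final case class HostLoc(host: String) extends Loc

  val executorTag = "executor_"
  val hdfsCacheTag = "hdfs_cache_"

  def encode(loc: Loc): String = loc match {
    case ExecutorLoc(host, execId) => s"$executorTag${host}_$execId"
    case HdfsCacheLoc(host)        => s"$hdfsCacheTag$host"
    case HostLoc(host)             => host
  }

  def decode(str: String): Loc = {
    val stripped = str.stripPrefix(hdfsCacheTag)
    if (stripped != str) {
      HdfsCacheLoc(stripped)
    } else if (str.startsWith(executorTag)) {
      str.split("_") match {
        case Array(_, host, execId) => ExecutorLoc(host, execId)
        case _ => throw new IllegalArgumentException("Illegal executor location format: " + str)
      }
    } else {
      HostLoc(str)
    }
  }

  def main(args: Array[String]): Unit = {
    val locs = Seq(ExecutorLoc("host1", "3"), HdfsCacheLoc("host2"), HostLoc("host3"))
    println(locs.map(encode))                      // List(executor_host1_3, hdfs_cache_host2, host3)
    println(locs.map(l => decode(encode(l)) == l)) // List(true, true, true)
  }
}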
+ val executorLocationTag = "executor_" + def apply(host: String, executorId: String): TaskLocation = { new ExecutorCacheTaskLocation(host, executorId) } @@ -65,9 +70,17 @@ private[spark] object TaskLocation { def apply(str: String): TaskLocation = { val hstr = str.stripPrefix(inMemoryLocationTag) if (hstr.equals(str)) { - new HostTaskLocation(str) + if (str.startsWith(executorLocationTag)) { + val splits = str.split("_") + if (splits.length != 3) { + throw new IllegalArgumentException("Illegal executor location format: " + str) + } + new ExecutorCacheTaskLocation(splits(1), splits(2)) + } else { + new HostTaskLocation(str) + } } else { - new HostTaskLocation(hstr) + new HDFSCacheTaskLocation(hstr) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index 46a6f6537e2e..f4965994d827 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -103,16 +103,16 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul try { getTaskResultExecutor.execute(new Runnable { override def run(): Unit = Utils.logUncaughtExceptions { + val loader = Utils.getContextOrSparkClassLoader try { if (serializedData != null && serializedData.limit() > 0) { reason = serializer.get().deserialize[TaskEndReason]( - serializedData, Utils.getSparkClassLoader) + serializedData, loader) } } catch { case cnd: ClassNotFoundException => // Log an error but keep going here -- the task failed, so not catastrophic // if we can't deserialize the reason. - val loader = Utils.getContextOrSparkClassLoader logError( "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader) case ex: Exception => {} diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index f25f3ed0d903..cb9a3008107d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -22,7 +22,8 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockManagerId /** - * Low-level task scheduler interface, currently implemented exclusively by TaskSchedulerImpl. + * Low-level task scheduler interface, currently implemented exclusively by + * [[org.apache.spark.scheduler.TaskSchedulerImpl]]. * This interface allows plugging in different task schedulers. Each TaskScheduler schedules tasks * for a single SparkContext. 
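// A self-contained sketch of the location-string format handled above: "executor_<host>_<id>"
// round-trips through the new toString/apply pair, "hdfs_cache_" marks HDFS-cached hosts, and a
// bare string is a plain host. The case classes here are simplified stand-ins, not Spark's.
sealed trait Loc
case class ExecutorLoc(host: String, executorId: String) extends Loc
case class HostLoc(host: String) extends Loc
case class HdfsCacheLoc(host: String) extends Loc

def parseLoc(str: String): Loc = {
  val hstr = str.stripPrefix("hdfs_cache_")
  if (hstr == str) {
    if (str.startsWith("executor_")) {
      val splits = str.split("_")
      require(splits.length == 3, s"Illegal executor location format: $str")
      ExecutorLoc(splits(1), splits(2))
    } else {
      HostLoc(str)
    }
  } else {
    HdfsCacheLoc(hstr)
  }
}
// parseLoc("executor_host1_3") == ExecutorLoc("host1", "3")
// parseLoc("hdfs_cache_host1") == HdfsCacheLoc("host1")
// parseLoc("host1")            == HostLoc("host1")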
These schedulers get sets of tasks submitted to them from the * DAGScheduler for each stage, and are responsible for sending the tasks to the cluster, running diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 1705e7f962de..bdf19f9f277d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -87,8 +87,8 @@ private[spark] class TaskSchedulerImpl( // Incrementing task IDs val nextTaskId = new AtomicLong(0) - // Which executor IDs we have executors on - val activeExecutorIds = new HashSet[String] + // Number of tasks running on each executor + private val executorIdToTaskCount = new HashMap[String, Int] // The set of executors we have on each host; this is used to compute hostsAlive, which // in turn is used to decide when we can attain data locality on a given host @@ -254,6 +254,7 @@ private[spark] class TaskSchedulerImpl( val tid = task.taskId taskIdToTaskSetManager(tid) = taskSet taskIdToExecutorId(tid) = execId + executorIdToTaskCount(execId) += 1 executorsByHost(host) += execId availableCpus(i) -= CPUS_PER_TASK assert(availableCpus(i) >= 0) @@ -282,7 +283,7 @@ private[spark] class TaskSchedulerImpl( var newExecAvail = false for (o <- offers) { executorIdToHost(o.executorId) = o.host - activeExecutorIds += o.executorId + executorIdToTaskCount.getOrElseUpdate(o.executorId, 0) if (!executorsByHost.contains(o.host)) { executorsByHost(o.host) = new HashSet[String]() executorAdded(o.executorId, o.host) @@ -331,8 +332,10 @@ private[spark] class TaskSchedulerImpl( if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) { // We lost this entire executor, so remember that it's gone val execId = taskIdToExecutorId(tid) - if (activeExecutorIds.contains(execId)) { - removeExecutor(execId) + + if (executorIdToTaskCount.contains(execId)) { + removeExecutor(execId, + SlaveLost(s"Task $tid was lost, so marking the executor as lost as well.")) failedExecutor = Some(execId) } } @@ -340,7 +343,11 @@ private[spark] class TaskSchedulerImpl( case Some(taskSet) => if (TaskState.isFinished(state)) { taskIdToTaskSetManager.remove(tid) - taskIdToExecutorId.remove(tid) + taskIdToExecutorId.remove(tid).foreach { execId => + if (executorIdToTaskCount.contains(execId)) { + executorIdToTaskCount(execId) -= 1 + } + } } if (state == TaskState.FINISHED) { taskSet.removeRunningTask(tid) @@ -461,17 +468,27 @@ private[spark] class TaskSchedulerImpl( var failedExecutor: Option[String] = None synchronized { - if (activeExecutorIds.contains(executorId)) { + if (executorIdToTaskCount.contains(executorId)) { val hostPort = executorIdToHost(executorId) - logError("Lost executor %s on %s: %s".format(executorId, hostPort, reason)) - removeExecutor(executorId) + logExecutorLoss(executorId, hostPort, reason) + removeExecutor(executorId, reason) failedExecutor = Some(executorId) } else { - // We may get multiple executorLost() calls with different loss reasons. For example, one - // may be triggered by a dropped connection from the slave while another may be a report - // of executor termination from Mesos. We produce log messages for both so we eventually - // report the termination reason. 
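// Minimal sketch of the bookkeeping change above: the activeExecutorIds set becomes a map from
// executor id to its number of running tasks, so the scheduler can later tell busy executors
// apart from idle ones. A bare HashMap stands in for TaskSchedulerImpl's private state.
import scala.collection.mutable.HashMap

val executorIdToTaskCount = new HashMap[String, Int]

def executorOffered(execId: String): Unit =
  executorIdToTaskCount.getOrElseUpdate(execId, 0)   // replaces activeExecutorIds += execId

def taskLaunched(execId: String): Unit =
  executorIdToTaskCount(execId) += 1

def taskFinished(execId: String): Unit =
  if (executorIdToTaskCount.contains(execId)) executorIdToTaskCount(execId) -= 1

def isExecutorBusy(execId: String): Boolean =
  executorIdToTaskCount.getOrElse(execId, -1) > 0    // consulted before killing an executor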
- logError("Lost an executor " + executorId + " (already removed): " + reason) + executorIdToHost.get(executorId) match { + case Some(hostPort) => + // If the host mapping still exists, it means we don't know the loss reason for the + // executor. So call removeExecutor() to update tasks running on that executor when + // the real loss reason is finally known. + logExecutorLoss(executorId, hostPort, reason) + removeExecutor(executorId, reason) + + case None => + // We may get multiple executorLost() calls with different loss reasons. For example, + // one may be triggered by a dropped connection from the slave while another may be a + // report of executor termination from Mesos. We produce log messages for both so we + // eventually report the termination reason. + logError(s"Lost an executor $executorId (already removed): $reason") + } } } // Call dagScheduler.executorLost without holding the lock on this to prevent deadlock @@ -481,9 +498,26 @@ private[spark] class TaskSchedulerImpl( } } - /** Remove an executor from all our data structures and mark it as lost */ - private def removeExecutor(executorId: String) { - activeExecutorIds -= executorId + private def logExecutorLoss( + executorId: String, + hostPort: String, + reason: ExecutorLossReason): Unit = reason match { + case LossReasonPending => + logDebug(s"Executor $executorId on $hostPort lost, but reason not yet known.") + case ExecutorKilled => + logInfo(s"Executor $executorId on $hostPort killed by driver.") + case _ => + logError(s"Lost executor $executorId on $hostPort: $reason") + } + + /** + * Remove an executor from all our data structures and mark it as lost. If the executor's loss + * reason is not yet known, do not yet remove its association with its host nor update the status + * of any running tasks, since the loss reason defines whether we'll fail those tasks. + */ + private def removeExecutor(executorId: String, reason: ExecutorLossReason) { + executorIdToTaskCount -= executorId + val host = executorIdToHost(executorId) val execs = executorsByHost.getOrElse(host, new HashSet) execs -= executorId @@ -496,8 +530,11 @@ private[spark] class TaskSchedulerImpl( } } } - executorIdToHost -= executorId - rootPool.executorLost(executorId, host) + + if (reason != LossReasonPending) { + executorIdToHost -= executorId + rootPool.executorLost(executorId, host, reason) + } } def executorAdded(execId: String, host: String) { @@ -517,7 +554,11 @@ private[spark] class TaskSchedulerImpl( } def isExecutorAlive(execId: String): Boolean = synchronized { - activeExecutorIds.contains(execId) + executorIdToTaskCount.contains(execId) + } + + def isExecutorBusy(execId: String): Boolean = synchronized { + executorIdToTaskCount.getOrElse(execId, -1) > 0 } // By default, rack is unknown diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala index be8526ba9b94..517c8991aed7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala @@ -29,7 +29,7 @@ private[spark] class TaskSet( val stageAttemptId: Int, val priority: Int, val properties: Properties) { - val id: String = stageId + "." + stageAttemptId + val id: String = stageId + "." 
+ stageAttemptId override def toString: String = "TaskSet " + id } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 82455b0426a5..a02f3017cb6e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -177,14 +177,11 @@ private[spark] class TaskSetManager( var emittedTaskSizeWarning = false - /** - * Add a task to all the pending-task lists that it should be on. If readding is set, we are - * re-adding the task so only include it in each list if it's not already there. - */ - private def addPendingTask(index: Int, readding: Boolean = false) { - // Utility method that adds `index` to a list only if readding=false or it's not already there + /** Add a task to all the pending-task lists that it should be on. */ + private def addPendingTask(index: Int) { + // Utility method that adds `index` to a list only if it's not already there def addTo(list: ArrayBuffer[Int]) { - if (!readding || !list.contains(index)) { + if (!list.contains(index)) { list += index } } @@ -219,9 +216,7 @@ private[spark] class TaskSetManager( addTo(pendingTasksWithNoPrefs) } - if (!readding) { - allPendingTasks += index // No point scanning this whole list to find the old task there - } + allPendingTasks += index // No point scanning this whole list to find the old task there } /** @@ -487,8 +482,8 @@ private[spark] class TaskSetManager( // a good proxy to task serialization time. // val timeTaken = clock.getTime() - startTime val taskName = s"task ${info.id} in stage ${taskSet.id}" - logInfo("Starting %s (TID %d, %s, %s, %d bytes)".format( - taskName, taskId, host, taskLocality, serializedTask.limit)) + logInfo(s"Starting $taskName (TID $taskId, $host, partition ${task.partitionId}," + + s"$taskLocality, ${serializedTask.limit} bytes)") sched.dagScheduler.taskStarted(task, info) return Some(new TaskDescription(taskId = taskId, attemptNumber = attemptNum, execId, @@ -662,7 +657,7 @@ private[spark] class TaskSetManager( val failureReason = s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid, ${info.host}): " + reason.asInstanceOf[TaskFailedReason].toErrorString - reason match { + val failureException: Option[Throwable] = reason match { case fetchFailed: FetchFailed => logWarning(failureReason) if (!successful(index)) { @@ -671,6 +666,7 @@ private[spark] class TaskSetManager( } // Not adding to failed executors for FetchFailed. isZombie = true + None case ef: ExceptionFailure => taskMetrics = ef.metrics.orNull @@ -706,38 +702,46 @@ private[spark] class TaskSetManager( s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid) on executor ${info.host}: " + s"${ef.className} (${ef.description}) [duplicate $dupCount]") } + ef.exception + + case e: ExecutorLostFailure if !e.exitCausedByApp => + logInfo(s"Task $tid failed because while it was being computed, its executor" + + "exited for a reason unrelated to the task. Not counting this failure towards the " + + "maximum number of failures for the task.") + None case e: TaskFailedReason => // TaskResultLost, TaskKilled, and others logWarning(failureReason) + None case e: TaskEndReason => logError("Unknown TaskEndReason: " + e) + None } // always add to failed executors failedExecutors.getOrElseUpdate(index, new HashMap[String, Long]()). 
put(info.executorId, clock.getTimeMillis()) sched.dagScheduler.taskEnded(tasks(index), reason, null, null, info, taskMetrics) addPendingTask(index) - if (!isZombie && state != TaskState.KILLED && !reason.isInstanceOf[TaskCommitDenied]) { - // If a task failed because its attempt to commit was denied, do not count this failure - // towards failing the stage. This is intended to prevent spurious stage failures in cases - // where many speculative tasks are launched and denied to commit. + if (!isZombie && state != TaskState.KILLED + && reason.isInstanceOf[TaskFailedReason] + && reason.asInstanceOf[TaskFailedReason].countTowardsTaskFailures) { assert (null != failureReason) numFailures(index) += 1 if (numFailures(index) >= maxTaskFailures) { logError("Task %d in stage %s failed %d times; aborting job".format( index, taskSet.id, maxTaskFailures)) abort("Task %d in stage %s failed %d times, most recent failure: %s\nDriver stacktrace:" - .format(index, taskSet.id, maxTaskFailures, failureReason)) + .format(index, taskSet.id, maxTaskFailures, failureReason), failureException) return } } maybeFinishTaskSet() } - def abort(message: String): Unit = sched.synchronized { + def abort(message: String, exception: Option[Throwable] = None): Unit = sched.synchronized { // TODO: Kill running tasks if we were not terminated due to a Mesos error - sched.dagScheduler.taskSetFailed(taskSet, message) + sched.dagScheduler.taskSetFailed(taskSet, message, exception) isZombie = true maybeFinishTaskSet() } @@ -774,19 +778,7 @@ private[spark] class TaskSetManager( } /** Called by TaskScheduler when an executor is lost so we can re-enqueue our tasks */ - override def executorLost(execId: String, host: String) { - logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id) - - // Re-enqueue pending tasks for this host based on the status of the cluster. Note - // that it's okay if we add a task to the same queue twice (if it had multiple preferred - // locations), because dequeueTaskFromList will skip already-running tasks. - for (index <- getPendingTasksForExecutor(execId)) { - addPendingTask(index, readding = true) - } - for (index <- getPendingTasksForHost(host)) { - addPendingTask(index, readding = true) - } - + override def executorLost(execId: String, host: String, reason: ExecutorLossReason) { // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage, // and we are not using an external shuffle server which could serve the shuffle outputs. 
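// Sketch of the failure-counting guard introduced above: a failed task only moves toward
// maxTaskFailures when its reason reports countTowardsTaskFailures, e.g. an executor lost for
// reasons unrelated to the app does not count. The trait is a trimmed stand-in for Spark's
// TaskFailedReason.
trait FailedReason { def countTowardsTaskFailures: Boolean }

def shouldCountFailure(isZombie: Boolean, wasKilled: Boolean, reason: Any): Boolean =
  !isZombie && !wasKilled &&
    reason.isInstanceOf[FailedReason] &&
    reason.asInstanceOf[FailedReason].countTowardsTaskFailures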
// The reason is the next stage wouldn't be able to fetch the data from this dead executor @@ -805,9 +797,14 @@ private[spark] class TaskSetManager( } } } - // Also re-enqueue any tasks that were running on the node for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { - handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure(execId)) + val exitCausedByApp: Boolean = reason match { + case exited: ExecutorExited => exited.exitCausedByApp + case ExecutorKilled => false + case _ => true + } + handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure(info.executorId, exitCausedByApp, + Some(reason.toString))) } // recalculate valid locality levels and waits when executor is lost recomputeLocality() diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 06f5438433b6..f3d0d8547677 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -21,6 +21,7 @@ import java.nio.ByteBuffer import org.apache.spark.TaskState.TaskState import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.scheduler.ExecutorLossReason import org.apache.spark.util.{SerializableBuffer, Utils} private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable @@ -35,9 +36,13 @@ private[spark] object CoarseGrainedClusterMessages { case class KillTask(taskId: Long, executor: String, interruptThread: Boolean) extends CoarseGrainedClusterMessage - case object RegisteredExecutor extends CoarseGrainedClusterMessage + sealed trait RegisterExecutorResponse + + case class RegisteredExecutor(hostname: String) extends CoarseGrainedClusterMessage + with RegisterExecutorResponse case class RegisterExecutorFailed(message: String) extends CoarseGrainedClusterMessage + with RegisterExecutorResponse // Executors to driver case class RegisterExecutor( @@ -46,9 +51,7 @@ private[spark] object CoarseGrainedClusterMessages { hostPort: String, cores: Int, logUrls: Map[String, String]) - extends CoarseGrainedClusterMessage { - Utils.checkHostPort(hostPort, "Expected host port") - } + extends CoarseGrainedClusterMessage case class StatusUpdate(executorId: String, taskId: Long, state: TaskState, data: SerializableBuffer) extends CoarseGrainedClusterMessage @@ -70,7 +73,8 @@ private[spark] object CoarseGrainedClusterMessages { case object StopExecutors extends CoarseGrainedClusterMessage - case class RemoveExecutor(executorId: String, reason: String) extends CoarseGrainedClusterMessage + case class RemoveExecutor(executorId: String, reason: ExecutorLossReason) + extends CoarseGrainedClusterMessage case class SetupDriver(driver: RpcEndpointRef) extends CoarseGrainedClusterMessage @@ -92,6 +96,13 @@ private[spark] object CoarseGrainedClusterMessages { hostToLocalTaskCount: Map[String, Int]) extends CoarseGrainedClusterMessage + // Check if an executor was force-killed but for a reason unrelated to the running tasks. + // This could be the case if the executor is preempted, for instance. + case class GetExecutorLossReason(executorId: String) extends CoarseGrainedClusterMessage + case class KillExecutors(executorIds: Seq[String]) extends CoarseGrainedClusterMessage + // Used internally by executors to shut themselves down. 
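// Sketch of the blame decision above: whether a lost executor's running tasks count as
// app-caused failures depends on the loss reason. The types below are simplified stand-ins
// for Spark's ExecutorLossReason hierarchy.
sealed trait LossReason
case class Exited(exitCausedByApp: Boolean) extends LossReason
case object KilledByDriver extends LossReason
case class Other(message: String) extends LossReason

def blameApp(reason: LossReason): Boolean = reason match {
  case Exited(byApp)  => byApp     // the exit itself says whether the app caused it
  case KilledByDriver => false     // driver-initiated kills never count against the tasks
  case _              => true      // unknown reasons are conservatively charged to the app
}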
+ case object Shutdown extends CoarseGrainedClusterMessage + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 6acf8a9a5e9b..7efe16749e59 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -26,6 +26,7 @@ import org.apache.spark.rpc._ import org.apache.spark.{ExecutorAllocationClient, Logging, SparkEnv, SparkException, TaskState} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend.ENDPOINT_NAME import org.apache.spark.util.{ThreadUtils, SerializableBuffer, AkkaUtils, Utils} /** @@ -63,8 +64,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp private val listenerBus = scheduler.sc.listenerBus - // Executors we have requested the cluster manager to kill that have not died yet - private val executorsPendingToRemove = new HashSet[String] + // Executors we have requested the cluster manager to kill that have not died yet; maps + // the executor ID to whether it was explicitly killed by the driver (and thus shouldn't + // be considered an app-related failure). + private val executorsPendingToRemove = new HashMap[String, Boolean] // A map to store hostname with its possible task number running on it protected var hostToLocalTaskCount: Map[String, Int] = Map.empty @@ -72,6 +75,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // The number of pending tasks which is locality required protected var localityAwareTasks = 0 + // Executors that have been lost, but for which we don't yet know the real exit reason. + protected val executorsPendingLossReason = new HashSet[String] + class DriverEndpoint(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) extends ThreadSafeRpcEndpoint with Logging { @@ -82,7 +88,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp override protected def log = CoarseGrainedSchedulerBackend.this.log - private val addressToExecutorId = new HashMap[RpcAddress, String] + protected val addressToExecutorId = new HashMap[RpcAddress, String] private val reviveThread = ThreadUtils.newDaemonSingleThreadScheduledExecutor("driver-revive-thread") @@ -124,21 +130,27 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Ignoring the task kill since the executor is not registered. logWarning(s"Attempted to kill task $taskId for unknown executor $executorId.") } - } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RegisterExecutor(executorId, executorRef, hostPort, cores, logUrls) => - Utils.checkHostPort(hostPort, "Host port expected " + hostPort) if (executorDataMap.contains(executorId)) { context.reply(RegisterExecutorFailed("Duplicate executor ID: " + executorId)) } else { - logInfo("Registered executor: " + executorRef + " with ID " + executorId) - addressToExecutorId(executorRef.address) = executorId + // If the executor's rpc env is not listening for incoming connections, `hostPort` + // will be null, and the client connection should be used to contact the executor. 
+ val executorAddress = if (executorRef.address != null) { + executorRef.address + } else { + context.senderAddress + } + logInfo(s"Registered executor $executorRef ($executorAddress) with ID $executorId") + addressToExecutorId(executorAddress) = executorId totalCoreCount.addAndGet(cores) totalRegisteredExecutors.addAndGet(1) - val (host, _) = Utils.parseHostPort(hostPort) - val data = new ExecutorData(executorRef, executorRef.address, host, cores, cores, logUrls) + val data = new ExecutorData(executorRef, executorRef.address, executorAddress.host, + cores, cores, logUrls) // This must be synchronized because variables mutated // in this block are read when requesting executors CoarseGrainedSchedulerBackend.this.synchronized { @@ -149,7 +161,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } } // Note: some tests expect the reply to come after we put the executor in the map - context.reply(RegisteredExecutor) + context.reply(RegisteredExecutor(executorAddress.host)) listenerBus.post( SparkListenerExecutorAdded(System.currentTimeMillis(), executorId, data)) makeOffers() @@ -177,7 +189,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Make fake resource offers on all executors private def makeOffers() { // Filter out executors under killing - val activeExecutors = executorDataMap.filterKeys(!executorsPendingToRemove.contains(_)) + val activeExecutors = executorDataMap.filterKeys(executorIsAlive) val workOffers = activeExecutors.map { case (id, executorData) => new WorkerOffer(id, executorData.executorHost, executorData.freeCores) }.toSeq @@ -185,14 +197,17 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } override def onDisconnected(remoteAddress: RpcAddress): Unit = { - addressToExecutorId.get(remoteAddress).foreach(removeExecutor(_, - "remote Rpc client disassociated")) + addressToExecutorId + .get(remoteAddress) + .foreach(removeExecutor(_, SlaveLost("Remote RPC client disassociated. Likely due to " + + "containers exceeding thresholds, or network issues. 
Check driver logs for WARN " + + "messages."))) } // Make fake resource offers on just one executor private def makeOffers(executorId: String) { // Filter out executors under killing - if (!executorsPendingToRemove.contains(executorId)) { + if (executorIsAlive(executorId)) { val executorData = executorDataMap(executorId) val workOffers = Seq( new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores)) @@ -200,6 +215,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } } + private def executorIsAlive(executorId: String): Boolean = synchronized { + !executorsPendingToRemove.contains(executorId) && + !executorsPendingLossReason.contains(executorId) + } + // Launch tasks returned by a set of resource offers private def launchTasks(tasks: Seq[Seq[TaskDescription]]) { for (task <- tasks.flatten) { @@ -227,25 +247,52 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } // Remove a disconnected slave from the cluster - def removeExecutor(executorId: String, reason: String): Unit = { + def removeExecutor(executorId: String, reason: ExecutorLossReason): Unit = { executorDataMap.get(executorId) match { case Some(executorInfo) => // This must be synchronized because variables mutated // in this block are read when requesting executors - CoarseGrainedSchedulerBackend.this.synchronized { + val killed = CoarseGrainedSchedulerBackend.this.synchronized { addressToExecutorId -= executorInfo.executorAddress executorDataMap -= executorId - executorsPendingToRemove -= executorId + executorsPendingLossReason -= executorId + executorsPendingToRemove.remove(executorId).getOrElse(false) } totalCoreCount.addAndGet(-executorInfo.totalCores) totalRegisteredExecutors.addAndGet(-1) - scheduler.executorLost(executorId, SlaveLost(reason)) + scheduler.executorLost(executorId, if (killed) ExecutorKilled else reason) listenerBus.post( - SparkListenerExecutorRemoved(System.currentTimeMillis(), executorId, reason)) + SparkListenerExecutorRemoved(System.currentTimeMillis(), executorId, reason.toString)) case None => logInfo(s"Asked to remove non-existent executor $executorId") } } + /** + * Stop making resource offers for the given executor. The executor is marked as lost with + * the loss reason still pending. + * + * @return Whether executor should be disabled + */ + protected def disableExecutor(executorId: String): Boolean = { + val shouldDisable = CoarseGrainedSchedulerBackend.this.synchronized { + if (executorIsAlive(executorId)) { + executorsPendingLossReason += executorId + true + } else { + // Returns true for explicitly killed executors, we also need to get pending loss reasons; + // For others return false. 
+ executorsPendingToRemove.contains(executorId) + } + } + + if (shouldDisable) { + logInfo(s"Disabling executor $executorId.") + scheduler.executorLost(executorId, LossReasonPending) + } + + shouldDisable + } + override def onStop() { reviveThread.shutdownNow() } @@ -263,8 +310,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } // TODO (prashant) send conf instead of properties - driverEndpoint = rpcEnv.setupEndpoint( - CoarseGrainedSchedulerBackend.ENDPOINT_NAME, new DriverEndpoint(rpcEnv, properties)) + driverEndpoint = rpcEnv.setupEndpoint(ENDPOINT_NAME, createDriverEndpoint(properties)) + } + + protected def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = { + new DriverEndpoint(rpcEnv, properties) } def stopExecutors() { @@ -291,6 +341,25 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } } + /** + * Reset the state of CoarseGrainedSchedulerBackend to the initial state. Currently it will only + * be called in the yarn-client mode when AM re-registers after a failure, also dynamic + * allocation is enabled. + * */ + protected def reset(): Unit = synchronized { + if (Utils.isDynamicAllocationEnabled(conf)) { + numPendingExecutors = 0 + executorsPendingToRemove.clear() + + // Remove all the lingering executors that should be removed but not yet. The reason might be + // because (1) disconnected event is not yet received; (2) executors die silently. + executorDataMap.toMap.foreach { case (eid, _) => + driverEndpoint.askWithRetry[Boolean]( + RemoveExecutor(eid, SlaveLost("Stale executor after cluster manager re-registered."))) + } + } + } + override def reviveOffers() { driverEndpoint.send(ReviveOffers) } @@ -304,7 +373,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } // Called by subclasses when notified of a lost worker - def removeExecutor(executorId: String, reason: String) { + def removeExecutor(executorId: String, reason: ExecutorLossReason) { try { driverEndpoint.askWithRetry[Boolean](RemoveExecutor(executorId, reason)) } catch { @@ -405,33 +474,49 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp * @return whether the kill request is acknowledged. */ final override def killExecutors(executorIds: Seq[String]): Boolean = synchronized { - killExecutors(executorIds, replace = false) + killExecutors(executorIds, replace = false, force = false) } /** * Request that the cluster manager kill the specified executors. * + * When asking the executor to be replaced, the executor loss is considered a failure, and + * killed tasks that are running on the executor will count towards the failure limits. If no + * replacement is being requested, then the tasks will not count towards the limit. + * * @param executorIds identifiers of executors to kill * @param replace whether to replace the killed executors with new ones + * @param force whether to force kill busy executors * @return whether the kill request is acknowledged. 
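// Sketch of the two-phase loss handling above: an executor is first "disabled" (no further
// offers, reason still pending) and only fully removed once the real ExecutorLossReason is
// known. The collections below stand in for CoarseGrainedSchedulerBackend's private state.
import scala.collection.mutable.{HashMap, HashSet}

val executorsPendingLossReason = new HashSet[String]
val executorsPendingToRemove   = new HashMap[String, Boolean]   // id -> killed by driver?

def disable(execId: String, isAlive: String => Boolean): Boolean = {
  val shouldDisable =
    if (isAlive(execId)) {
      executorsPendingLossReason += execId
      true
    } else {
      executorsPendingToRemove.contains(execId)   // already being killed; still want the reason
    }
  // when true, the scheduler would be told executorLost(execId, LossReasonPending)
  shouldDisable
}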
*/ - final def killExecutors(executorIds: Seq[String], replace: Boolean): Boolean = synchronized { + final def killExecutors( + executorIds: Seq[String], + replace: Boolean, + force: Boolean): Boolean = synchronized { logInfo(s"Requesting to kill executor(s) ${executorIds.mkString(", ")}") val (knownExecutors, unknownExecutors) = executorIds.partition(executorDataMap.contains) unknownExecutors.foreach { id => logWarning(s"Executor to kill $id does not exist!") } + // If an executor is already pending to be removed, do not kill it again (SPARK-9795) + // If this executor is busy, do not kill it unless we are told to force kill it (SPARK-9552) + val executorsToKill = knownExecutors + .filter { id => !executorsPendingToRemove.contains(id) } + .filter { id => force || !scheduler.isExecutorBusy(id) } + executorsToKill.foreach { id => executorsPendingToRemove(id) = !replace } + // If we do not wish to replace the executors we kill, sync the target number of executors // with the cluster manager to avoid allocating new ones. When computing the new target, // take into account executors that are pending to be added or removed. if (!replace) { - doRequestTotalExecutors(numExistingExecutors + numPendingExecutors - - executorsPendingToRemove.size - knownExecutors.size) + doRequestTotalExecutors( + numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size) + } else { + numPendingExecutors += knownExecutors.size } - executorsPendingToRemove ++= knownExecutors - doKillExecutors(knownExecutors) + doKillExecutors(executorsToKill) } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala index 26e72c0bff38..626a2b7d69ab 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala @@ -22,7 +22,7 @@ import org.apache.spark.rpc.{RpcEndpointRef, RpcAddress} /** * Grouping of data for an executor used by CoarseGrainedSchedulerBackend. 
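// Sketch of the kill-request filtering above: executors already pending removal are not killed
// twice, and busy executors are only killed when the caller forces it. The predicates are
// stand-ins for backend/scheduler state lookups.
def selectExecutorsToKill(
    requested: Seq[String],
    isKnown: String => Boolean,
    pendingToRemove: Set[String],
    isBusy: String => Boolean,
    force: Boolean): Seq[String] = {
  requested
    .filter(isKnown)                              // unknown ids only produce a warning
    .filter(id => !pendingToRemove.contains(id))  // don't kill an executor twice
    .filter(id => force || !isBusy(id))           // spare executors with running tasks
}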
* - * @param executorEndpoint The ActorRef representing this executor + * @param executorEndpoint The RpcEndpointRef representing this executor * @param executorAddress The network address of this executor * @param executorHost The hostname that this executor is running on * @param freeCores The current number of cores available for work on the executor diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala index 0324c9dab910..641638a77d5f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala @@ -65,7 +65,9 @@ private[spark] class SimrSchedulerBackend( override def stop() { val conf = SparkHadoopUtil.get.newConfiguration(sc.conf) val fs = FileSystem.get(conf) - fs.delete(new Path(driverFilePath), false) + if (!fs.delete(new Path(driverFilePath), false)) { + logWarning(s"error deleting ${driverFilePath}") + } super.stop() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index bbe51b4a09a2..5105475c760e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -23,7 +23,8 @@ import org.apache.spark.rpc.RpcAddress import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv} import org.apache.spark.deploy.{ApplicationDescription, Command} import org.apache.spark.deploy.client.{AppClient, AppClientListener} -import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, TaskSchedulerImpl} +import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle} +import org.apache.spark.scheduler._ import org.apache.spark.util.Utils private[spark] class SparkDeploySchedulerBackend( @@ -36,6 +37,9 @@ private[spark] class SparkDeploySchedulerBackend( private var client: AppClient = null private var stopping = false + private val launcherBackend = new LauncherBackend() { + override protected def onStopRequest(): Unit = stop(SparkAppHandle.State.KILLED) + } @volatile var shutdownCallback: SparkDeploySchedulerBackend => Unit = _ @volatile private var appId: String = _ @@ -47,6 +51,7 @@ private[spark] class SparkDeploySchedulerBackend( override def start() { super.start() + launcherBackend.connect() // The endpoint for executors to talk to us val driverUrl = rpcEnv.uriOf(SparkEnv.driverActorSystemName, @@ -87,24 +92,20 @@ private[spark] class SparkDeploySchedulerBackend( command, appUIAddress, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor) client = new AppClient(sc.env.rpcEnv, masters, appDesc, this, conf) client.start() + launcherBackend.setState(SparkAppHandle.State.SUBMITTED) waitForRegistration() + launcherBackend.setState(SparkAppHandle.State.RUNNING) } - override def stop() { - stopping = true - super.stop() - client.stop() - - val callback = shutdownCallback - if (callback != null) { - callback(this) - } + override def stop(): Unit = synchronized { + stop(SparkAppHandle.State.FINISHED) } override def connected(appId: String) { logInfo("Connected to Spark cluster with app ID " + appId) this.appId = appId notifyContext() + launcherBackend.setAppId(appId) } override def disconnected() { @@ -117,6 +118,7 @@ private[spark] class 
SparkDeploySchedulerBackend( override def dead(reason: String) { notifyContext() if (!stopping) { + launcherBackend.setState(SparkAppHandle.State.KILLED) logError("Application has been killed. Reason: " + reason) try { scheduler.error(reason) @@ -135,11 +137,11 @@ private[spark] class SparkDeploySchedulerBackend( override def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]) { val reason: ExecutorLossReason = exitStatus match { - case Some(code) => ExecutorExited(code) + case Some(code) => ExecutorExited(code, exitCausedByApp = true, message) case None => SlaveLost(message) } logInfo("Executor %s removed: %s".format(fullId, message)) - removeExecutor(fullId.split("/")(1), reason.toString) + removeExecutor(fullId.split("/")(1), reason) } override def sufficientResourcesRegistered(): Boolean = { @@ -188,4 +190,21 @@ private[spark] class SparkDeploySchedulerBackend( registrationBarrier.release() } + private def stop(finalState: SparkAppHandle.State): Unit = synchronized { + try { + stopping = true + + super.stop() + client.stop() + + val callback = shutdownCallback + if (callback != null) { + callback(this) + } + } finally { + launcherBackend.setState(finalState) + launcherBackend.close() + } + } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala deleted file mode 100644 index 044f6288fabd..000000000000 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ /dev/null @@ -1,163 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler.cluster - -import scala.concurrent.{Future, ExecutionContext} - -import org.apache.spark.{Logging, SparkContext} -import org.apache.spark.rpc._ -import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.scheduler.TaskSchedulerImpl -import org.apache.spark.ui.JettyUtils -import org.apache.spark.util.{ThreadUtils, RpcUtils} - -import scala.util.control.NonFatal - -/** - * Abstract Yarn scheduler backend that contains common logic - * between the client and cluster Yarn scheduler backends. 
- */ -private[spark] abstract class YarnSchedulerBackend( - scheduler: TaskSchedulerImpl, - sc: SparkContext) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) { - - if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { - minRegisteredRatio = 0.8 - } - - protected var totalExpectedExecutors = 0 - - private val yarnSchedulerEndpoint = rpcEnv.setupEndpoint( - YarnSchedulerBackend.ENDPOINT_NAME, new YarnSchedulerEndpoint(rpcEnv)) - - private implicit val askTimeout = RpcUtils.askRpcTimeout(sc.conf) - - /** - * Request executors from the ApplicationMaster by specifying the total number desired. - * This includes executors already pending or running. - */ - override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { - yarnSchedulerEndpoint.askWithRetry[Boolean]( - RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount)) - } - - /** - * Request that the ApplicationMaster kill the specified executors. - */ - override def doKillExecutors(executorIds: Seq[String]): Boolean = { - yarnSchedulerEndpoint.askWithRetry[Boolean](KillExecutors(executorIds)) - } - - override def sufficientResourcesRegistered(): Boolean = { - totalRegisteredExecutors.get() >= totalExpectedExecutors * minRegisteredRatio - } - - /** - * Add filters to the SparkUI. - */ - private def addWebUIFilter( - filterName: String, - filterParams: Map[String, String], - proxyBase: String): Unit = { - if (proxyBase != null && proxyBase.nonEmpty) { - System.setProperty("spark.ui.proxyBase", proxyBase) - } - - val hasFilter = - filterName != null && filterName.nonEmpty && - filterParams != null && filterParams.nonEmpty - if (hasFilter) { - logInfo(s"Add WebUI Filter. $filterName, $filterParams, $proxyBase") - conf.set("spark.ui.filters", filterName) - filterParams.foreach { case (k, v) => conf.set(s"spark.$filterName.param.$k", v) } - scheduler.sc.ui.foreach { ui => JettyUtils.addFilters(ui.getHandlers, conf) } - } - } - - /** - * An [[RpcEndpoint]] that communicates with the ApplicationMaster. 
- */ - private class YarnSchedulerEndpoint(override val rpcEnv: RpcEnv) - extends ThreadSafeRpcEndpoint with Logging { - private var amEndpoint: Option[RpcEndpointRef] = None - - private val askAmThreadPool = - ThreadUtils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-thread-pool") - implicit val askAmExecutor = ExecutionContext.fromExecutor(askAmThreadPool) - - override def receive: PartialFunction[Any, Unit] = { - case RegisterClusterManager(am) => - logInfo(s"ApplicationMaster registered as $am") - amEndpoint = Some(am) - - case AddWebUIFilter(filterName, filterParams, proxyBase) => - addWebUIFilter(filterName, filterParams, proxyBase) - - case RemoveExecutor(executorId, reason) => - removeExecutor(executorId, reason) - } - - override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case r: RequestExecutors => - amEndpoint match { - case Some(am) => - Future { - context.reply(am.askWithRetry[Boolean](r)) - } onFailure { - case NonFatal(e) => - logError(s"Sending $r to AM was unsuccessful", e) - context.sendFailure(e) - } - case None => - logWarning("Attempted to request executors before the AM has registered!") - context.reply(false) - } - - case k: KillExecutors => - amEndpoint match { - case Some(am) => - Future { - context.reply(am.askWithRetry[Boolean](k)) - } onFailure { - case NonFatal(e) => - logError(s"Sending $k to AM was unsuccessful", e) - context.sendFailure(e) - } - case None => - logWarning("Attempted to kill executors before the AM has registered!") - context.reply(false) - } - - } - - override def onDisconnected(remoteAddress: RpcAddress): Unit = { - if (amEndpoint.exists(_.address == remoteAddress)) { - logWarning(s"ApplicationMaster has disassociated: $remoteAddress") - } - } - - override def onStop(): Unit = { - askAmThreadPool.shutdownNow() - } - } -} - -private[spark] object YarnSchedulerBackend { - val ENDPOINT_NAME = "YarnScheduler" -} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 15a0915708c7..7d08eae0b487 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -21,7 +21,7 @@ import java.io.File import java.util.concurrent.locks.ReentrantLock import java.util.{Collections, List => JList} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{HashMap, HashSet} import com.google.common.collect.HashBiMap @@ -32,7 +32,7 @@ import org.apache.spark.{SecurityManager, SparkContext, SparkEnv, SparkException import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient import org.apache.spark.rpc.RpcAddress -import org.apache.spark.scheduler.TaskSchedulerImpl +import org.apache.spark.scheduler.{SlaveLost, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.Utils @@ -101,11 +101,15 @@ private[spark] class CoarseMesosSchedulerBackend( private val slaveOfferConstraints = parseConstraintString(sc.conf.get("spark.mesos.constraints", "")) + // reject offers with mismatched constraints in seconds + private val rejectOfferDurationForUnmetConstraints = + getRejectOfferDurationForUnmetConstraints(sc) + // A client for talking to the 
external shuffle service, if it is a private val mesosExternalShuffleClient: Option[MesosExternalShuffleClient] = { if (shuffleServiceEnabled) { Some(new MesosExternalShuffleClient( - SparkTransportConf.fromSparkConf(conf), + SparkTransportConf.fromSparkConf(conf, "shuffle"), securityManager, securityManager.isAuthenticationEnabled(), securityManager.isSaslEncryptionEnabled())) @@ -127,7 +131,12 @@ private[spark] class CoarseMesosSchedulerBackend( override def start() { super.start() val driver = createSchedulerDriver( - master, CoarseMesosSchedulerBackend.this, sc.sparkUser, sc.appName, sc.conf) + master, + CoarseMesosSchedulerBackend.this, + sc.sparkUser, + sc.appName, + sc.conf, + sc.ui.map(_.appUIAddress)) startScheduler(driver) } @@ -194,6 +203,11 @@ private[spark] class CoarseMesosSchedulerBackend( s" --app-id $appId") command.addUris(CommandInfo.URI.newBuilder().setValue(uri.get)) } + + conf.getOption("spark.mesos.uris").map { uris => + setupUris(uris, command) + } + command.build() } @@ -217,6 +231,10 @@ private[spark] class CoarseMesosSchedulerBackend( markRegistered() } + override def sufficientResourcesRegistered(): Boolean = { + totalCoresAcquired >= maxCores * minRegisteredRatio + } + override def disconnected(d: SchedulerDriver) {} override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {} @@ -228,55 +246,63 @@ private[spark] class CoarseMesosSchedulerBackend( override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { stateLock.synchronized { val filters = Filters.newBuilder().setRefuseSeconds(5).build() - for (offer <- offers) { + for (offer <- offers.asScala) { val offerAttributes = toAttributeMap(offer.getAttributesList) val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) val slaveId = offer.getSlaveId.getValue val mem = getResource(offer.getResourcesList, "mem") val cpus = getResource(offer.getResourcesList, "cpus").toInt val id = offer.getId.getValue - if (taskIdToSlaveId.size < executorLimit && - totalCoresAcquired < maxCores && - meetsConstraints && - mem >= calculateTotalMemory(sc) && - cpus >= 1 && - failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES && - !slaveIdsWithExecutors.contains(slaveId)) { - // Launch an executor on the slave - val cpusToUse = math.min(cpus, maxCores - totalCoresAcquired) - totalCoresAcquired += cpusToUse - val taskId = newMesosTaskId() - taskIdToSlaveId(taskId) = slaveId - slaveIdsWithExecutors += slaveId - coresByTaskId(taskId) = cpusToUse - // Gather cpu resources from the available resources and use them in the task. 
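// Sketch of the collection-conversion change running through these Mesos files: the implicit
// scala.collection.JavaConversions are replaced with explicit .asScala / .asJava calls from
// scala.collection.JavaConverters. A List[String] stands in for the Mesos offer list.
import java.util.{Arrays, List => JList}
import scala.collection.JavaConverters._

val offers: JList[String] = Arrays.asList("offer-1", "offer-2")
val usable = offers.asScala.filter(_.endsWith("1"))   // explicit wrapper, no silent conversion
val javaView: JList[String] = usable.asJava           // converted back where a Java API needs it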
- val (remainingResources, cpuResourcesToUse) = - partitionResources(offer.getResourcesList, "cpus", cpusToUse) - val (_, memResourcesToUse) = - partitionResources(remainingResources, "mem", calculateTotalMemory(sc)) - val taskBuilder = MesosTaskInfo.newBuilder() - .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build()) - .setSlaveId(offer.getSlaveId) - .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave, taskId)) - .setName("Task " + taskId) - .addAllResources(cpuResourcesToUse) - .addAllResources(memResourcesToUse) - - sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image => - MesosSchedulerBackendUtil - .setupContainerBuilderDockerInfo(image, sc.conf, taskBuilder.getContainerBuilder()) + if (meetsConstraints) { + if (taskIdToSlaveId.size < executorLimit && + totalCoresAcquired < maxCores && + mem >= calculateTotalMemory(sc) && + cpus >= 1 && + failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES && + !slaveIdsWithExecutors.contains(slaveId)) { + // Launch an executor on the slave + val cpusToUse = math.min(cpus, maxCores - totalCoresAcquired) + totalCoresAcquired += cpusToUse + val taskId = newMesosTaskId() + taskIdToSlaveId.put(taskId, slaveId) + slaveIdsWithExecutors += slaveId + coresByTaskId(taskId) = cpusToUse + // Gather cpu resources from the available resources and use them in the task. + val (remainingResources, cpuResourcesToUse) = + partitionResources(offer.getResourcesList, "cpus", cpusToUse) + val (_, memResourcesToUse) = + partitionResources(remainingResources.asJava, "mem", calculateTotalMemory(sc)) + val taskBuilder = MesosTaskInfo.newBuilder() + .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build()) + .setSlaveId(offer.getSlaveId) + .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave, taskId)) + .setName("Task " + taskId) + .addAllResources(cpuResourcesToUse.asJava) + .addAllResources(memResourcesToUse.asJava) + + sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image => + MesosSchedulerBackendUtil + .setupContainerBuilderDockerInfo(image, sc.conf, taskBuilder.getContainerBuilder()) + } + + // Accept the offer and launch the task + logDebug(s"Accepting offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") + slaveIdToHost(offer.getSlaveId.getValue) = offer.getHostname + d.launchTasks( + Collections.singleton(offer.getId), + Collections.singleton(taskBuilder.build()), filters) + } else { + // Decline the offer + logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") + d.declineOffer(offer.getId) } - - // accept the offer and launch the task - logDebug(s"Accepting offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") - slaveIdToHost(offer.getSlaveId.getValue) = offer.getHostname - d.launchTasks( - Collections.singleton(offer.getId), - Collections.singleton(taskBuilder.build()), filters) } else { - // Decline the offer - logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") - d.declineOffer(offer.getId) + // This offer does not meet constraints. We don't need to see it again. + // Decline the offer for a long period of time. 
+ logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus" + + s" for $rejectOfferDurationForUnmetConstraints seconds") + d.declineOffer(offer.getId, Filters.newBuilder() + .setRefuseSeconds(rejectOfferDurationForUnmetConstraints).build()) } } } @@ -309,9 +335,9 @@ private[spark] class CoarseMesosSchedulerBackend( } if (TaskState.isFinished(TaskState.fromMesos(state))) { - val slaveId = taskIdToSlaveId(taskId) + val slaveId = taskIdToSlaveId.get(taskId) slaveIdsWithExecutors -= slaveId - taskIdToSlaveId -= taskId + taskIdToSlaveId.remove(taskId) // Remove the cores we have remembered for this task, if it's in the hashmap for (cores <- coresByTaskId.get(taskId)) { totalCoresAcquired -= cores @@ -356,10 +382,10 @@ private[spark] class CoarseMesosSchedulerBackend( stateLock.synchronized { if (slaveIdsWithExecutors.contains(slaveId)) { val slaveIdToTaskId = taskIdToSlaveId.inverse() - if (slaveIdToTaskId.contains(slaveId)) { + if (slaveIdToTaskId.containsKey(slaveId)) { val taskId: Int = slaveIdToTaskId.get(slaveId) taskIdToSlaveId.remove(taskId) - removeExecutor(sparkExecutorId(slaveId, taskId.toString), reason) + removeExecutor(sparkExecutorId(slaveId, taskId.toString), SlaveLost(reason)) } // TODO: This assumes one Spark executor per Mesos slave, // which may no longer be true after SPARK-5095 @@ -406,7 +432,7 @@ private[spark] class CoarseMesosSchedulerBackend( val slaveIdToTaskId = taskIdToSlaveId.inverse() for (executorId <- executorIds) { val slaveId = executorId.split("/")(0) - if (slaveIdToTaskId.contains(slaveId)) { + if (slaveIdToTaskId.containsKey(slaveId)) { mesosDriver.killTask( TaskID.newBuilder().setValue(slaveIdToTaskId.get(slaveId).toString).build()) pendingRemovedSlaveIds += slaveId diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala index 3efc536f1456..e0c547dce6d0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler.cluster.mesos -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.curator.framework.CuratorFramework import org.apache.zookeeper.CreateMode @@ -129,6 +129,6 @@ private[spark] class ZookeeperMesosClusterPersistenceEngine( } override def fetchAll[T](): Iterable[T] = { - zk.getChildren.forPath(WORKING_DIR).map(fetch[T]).flatten + zk.getChildren.forPath(WORKING_DIR).asScala.flatMap(fetch[T]) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index f078547e7135..4a21a779d2ce 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -21,7 +21,7 @@ import java.io.File import java.util.concurrent.locks.ReentrantLock import java.util.{Collections, Date, List => JList} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -29,7 +29,6 @@ import org.apache.mesos.Protos.Environment.Variable import org.apache.mesos.Protos.TaskStatus.Reason 
import org.apache.mesos.Protos.{TaskState => MesosTaskState, _} import org.apache.mesos.{Scheduler, SchedulerDriver} - import org.apache.spark.deploy.mesos.MesosDriverDescription import org.apache.spark.deploy.rest.{CreateSubmissionResponse, KillSubmissionResponse, SubmissionStatusResponse} import org.apache.spark.metrics.MetricsSystem @@ -350,7 +349,7 @@ private[spark] class MesosClusterScheduler( } // TODO: Page the status updates to avoid trying to reconcile // a large amount of tasks at once. - driver.reconcileTasks(statuses) + driver.reconcileTasks(statuses.toSeq.asJava) } } } @@ -375,21 +374,20 @@ private[spark] class MesosClusterScheduler( val executorOpts = desc.schedulerProperties.map { case (k, v) => s"-D$k=$v" }.mkString(" ") envBuilder.addVariables( Variable.newBuilder().setName("SPARK_EXECUTOR_OPTS").setValue(executorOpts)) - val cmdOptions = generateCmdOption(desc).mkString(" ") val dockerDefined = desc.schedulerProperties.contains("spark.mesos.executor.docker.image") val executorUri = desc.schedulerProperties.get("spark.executor.uri") .orElse(desc.command.environment.get("SPARK_EXECUTOR_URI")) - val appArguments = desc.command.arguments.mkString(" ") - val (executable, jar) = if (dockerDefined) { + // Gets the path to run spark-submit, and the path to the Mesos sandbox. + val (executable, sandboxPath) = if (dockerDefined) { // Application jar is automatically downloaded in the mounted sandbox by Mesos, // and the path to the mounted volume is stored in $MESOS_SANDBOX env variable. - ("./bin/spark-submit", s"$$MESOS_SANDBOX/${desc.jarUrl.split("/").last}") + ("./bin/spark-submit", "$MESOS_SANDBOX") } else if (executorUri.isDefined) { builder.addUris(CommandInfo.URI.newBuilder().setValue(executorUri.get).build()) val folderBasename = executorUri.get.split('/').last.split('.').head val cmdExecutable = s"cd $folderBasename*; $prefixEnv bin/spark-submit" - val cmdJar = s"../${desc.jarUrl.split("/").last}" - (cmdExecutable, cmdJar) + // Sandbox path points to the parent folder as we chdir into the folderBasename. + (cmdExecutable, "..") } else { val executorSparkHome = desc.schedulerProperties.get("spark.mesos.executor.home") .orElse(conf.getOption("spark.home")) @@ -398,27 +396,57 @@ private[spark] class MesosClusterScheduler( throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!") } val cmdExecutable = new File(executorSparkHome, "./bin/spark-submit").getCanonicalPath - val cmdJar = desc.jarUrl.split("/").last - (cmdExecutable, cmdJar) + // Sandbox points to the current directory by default with Mesos. 
+ (cmdExecutable, ".") } - builder.setValue(s"$executable $cmdOptions $jar $appArguments") + val primaryResource = new File(sandboxPath, desc.jarUrl.split("/").last).toString() + val cmdOptions = generateCmdOption(desc, sandboxPath).mkString(" ") + val appArguments = desc.command.arguments.mkString(" ") + builder.setValue(s"$executable $cmdOptions $primaryResource $appArguments") builder.setEnvironment(envBuilder.build()) + conf.getOption("spark.mesos.uris").map { uris => + setupUris(uris, builder) + } + desc.schedulerProperties.get("spark.mesos.uris").map { uris => + setupUris(uris, builder) + } + desc.schedulerProperties.get("spark.submit.pyFiles").map { pyFiles => + setupUris(pyFiles, builder) + } builder.build() } - private def generateCmdOption(desc: MesosDriverDescription): Seq[String] = { + private def generateCmdOption(desc: MesosDriverDescription, sandboxPath: String): Seq[String] = { var options = Seq( "--name", desc.schedulerProperties("spark.app.name"), - "--class", desc.command.mainClass, "--master", s"mesos://${conf.get("spark.master")}", "--driver-cores", desc.cores.toString, "--driver-memory", s"${desc.mem}M") + + val replicatedOptionsBlacklist = Set( + "spark.jars" // Avoids duplicate classes in classpath + ) + + // Assume empty main class means we're running python + if (!desc.command.mainClass.equals("")) { + options ++= Seq("--class", desc.command.mainClass) + } + desc.schedulerProperties.get("spark.executor.memory").map { v => options ++= Seq("--executor-memory", v) } desc.schedulerProperties.get("spark.cores.max").map { v => options ++= Seq("--total-executor-cores", v) } + desc.schedulerProperties.get("spark.submit.pyFiles").map { pyFiles => + val formattedFiles = pyFiles.split(",") + .map { path => new File(sandboxPath, path.split("/").last).toString() } + .mkString(",") + options ++= Seq("--py-files", formattedFiles) + } + desc.schedulerProperties + .filter { case (key, _) => !replicatedOptionsBlacklist.contains(key) } + .foreach { case (key, value) => options ++= Seq("--conf", Seq(key, "=\"", value, "\"").mkString("")) } options } @@ -490,10 +518,10 @@ private[spark] class MesosClusterScheduler( } override def resourceOffers(driver: SchedulerDriver, offers: JList[Offer]): Unit = { - val currentOffers = offers.map { o => + val currentOffers = offers.asScala.map(o => new ResourceOffer( o, getResource(o.getResourcesList, "cpus"), getResource(o.getResourcesList, "mem")) - }.toList + ).toList logTrace(s"Received offers from Mesos: \n${currentOffers.mkString("\n")}") val tasks = new mutable.HashMap[OfferID, ArrayBuffer[TaskInfo]]() val currentTime = new Date() @@ -504,33 +532,36 @@ private[spark] class MesosClusterScheduler( val driversToRetry = pendingRetryDrivers.filter { d => d.retryState.get.nextRetry.before(currentTime) } + scheduleTasks( - driversToRetry, + copyBuffer(driversToRetry), removeFromPendingRetryDrivers, currentOffers, tasks) + // Then we walk through the queued drivers and try to schedule them. 
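// Sketch of the option forwarding above: every scheduler property except blacklisted ones is
// replicated onto the driver's spark-submit command line as a --conf flag. The helper is a
// standalone stand-in for part of generateCmdOption.
val replicatedOptionsBlacklist = Set("spark.jars")   // forwarding jars would duplicate classes

def confOptions(schedulerProperties: Map[String, String]): Seq[String] =
  schedulerProperties
    .filter { case (key, _) => !replicatedOptionsBlacklist.contains(key) }
    .flatMap { case (key, value) => Seq("--conf", Seq(key, "=\"", value, "\"").mkString("")) }
    .toSeq
// confOptions(Map("spark.cores.max" -> "4")) == Seq("--conf", "spark.cores.max=\"4\"")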
scheduleTasks( - queuedDrivers, + copyBuffer(queuedDrivers), removeFromQueuedDrivers, currentOffers, tasks) } - tasks.foreach { case (offerId, tasks) => - driver.launchTasks(Collections.singleton(offerId), tasks) + tasks.foreach { case (offerId, taskInfos) => + driver.launchTasks(Collections.singleton(offerId), taskInfos.asJava) } - offers + offers.asScala .filter(o => !tasks.keySet.contains(o.getId)) .foreach(o => driver.declineOffer(o.getId)) } + private def copyBuffer( + buffer: ArrayBuffer[MesosDriverDescription]): ArrayBuffer[MesosDriverDescription] = { + val newBuffer = new ArrayBuffer[MesosDriverDescription](buffer.size) + buffer.copyToBuffer(newBuffer) + newBuffer + } + def getSchedulerState(): MesosClusterSchedulerState = { - def copyBuffer( - buffer: ArrayBuffer[MesosDriverDescription]): ArrayBuffer[MesosDriverDescription] = { - val newBuffer = new ArrayBuffer[MesosDriverDescription](buffer.size) - buffer.copyToBuffer(newBuffer) - newBuffer - } stateLock.synchronized { new MesosClusterSchedulerState( frameworkId, diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 3f63ec1c5832..281965a5981b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -20,7 +20,7 @@ package org.apache.spark.scheduler.cluster.mesos import java.io.File import java.util.{ArrayList => JArrayList, Collections, List => JList} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{HashMap, HashSet} import org.apache.mesos.{Scheduler => MScheduler, _} @@ -32,7 +32,6 @@ import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.util.Utils - /** * A SchedulerBackend for running fine-grained tasks on Mesos. 
Each Spark task is mapped to a * separate Mesos task, allowing multiple applications to share cluster nodes both in space (tasks @@ -64,12 +63,21 @@ private[spark] class MesosSchedulerBackend( private[this] val slaveOfferConstraints = parseConstraintString(sc.conf.get("spark.mesos.constraints", "")) + // reject offers with mismatched constraints in seconds + private val rejectOfferDurationForUnmetConstraints = + getRejectOfferDurationForUnmetConstraints(sc) + @volatile var appId: String = _ override def start() { classLoader = Thread.currentThread.getContextClassLoader val driver = createSchedulerDriver( - master, MesosSchedulerBackend.this, sc.sparkUser, sc.appName, sc.conf) + master, + MesosSchedulerBackend.this, + sc.sparkUser, + sc.appName, + sc.conf, + sc.ui.map(_.appUIAddress)) startScheduler(driver) } @@ -127,12 +135,15 @@ private[spark] class MesosSchedulerBackend( } val builder = MesosExecutorInfo.newBuilder() val (resourcesAfterCpu, usedCpuResources) = - partitionResources(availableResources, "cpus", scheduler.CPUS_PER_TASK) + partitionResources(availableResources, "cpus", mesosExecutorCores) val (resourcesAfterMem, usedMemResources) = - partitionResources(resourcesAfterCpu, "mem", calculateTotalMemory(sc)) + partitionResources(resourcesAfterCpu.asJava, "mem", calculateTotalMemory(sc)) + + builder.addAllResources(usedCpuResources.asJava) + builder.addAllResources(usedMemResources.asJava) + + sc.conf.getOption("spark.mesos.uris").foreach(setupUris(_, command)) - builder.addAllResources(usedCpuResources) - builder.addAllResources(usedMemResources) val executorInfo = builder .setExecutorId(ExecutorID.newBuilder().setValue(execId).build()) .setCommand(command) @@ -143,7 +154,7 @@ private[spark] class MesosSchedulerBackend( .setupContainerBuilderDockerInfo(image, sc.conf, executorInfo.getContainerBuilder()) } - (executorInfo.build(), resourcesAfterMem) + (executorInfo.build(), resourcesAfterMem.asJava) } /** @@ -188,7 +199,7 @@ private[spark] class MesosSchedulerBackend( private def getTasksSummary(tasks: JArrayList[MesosTaskInfo]): String = { val builder = new StringBuilder - tasks.foreach { t => + tasks.asScala.foreach { t => builder.append("Task id: ").append(t.getTaskId.getValue).append("\n") .append("Slave id: ").append(t.getSlaveId.getValue).append("\n") .append("Task resources: ").append(t.getResourcesList).append("\n") @@ -205,29 +216,47 @@ private[spark] class MesosSchedulerBackend( */ override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { inClassLoader() { - // Fail-fast on offers we know will be rejected - val (usableOffers, unUsableOffers) = offers.partition { o => + // Fail first on offers with unmet constraints + val (offersMatchingConstraints, offersNotMatchingConstraints) = + offers.asScala.partition { o => + val offerAttributes = toAttributeMap(o.getAttributesList) + val meetsConstraints = + matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) + + // add some debug messaging + if (!meetsConstraints) { + val id = o.getId.getValue + logDebug(s"Declining offer: $id with attributes: $offerAttributes") + } + + meetsConstraints + } + + // These offers do not meet constraints. We don't need to see them again. + // Decline the offer for a long period of time. 
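// Hedged sketch, not the PR's code: the decline-with-backoff idea in isolation. Instead of
// declining an unusable offer with default filters (which lets Mesos re-offer it almost
// immediately), the scheduler asks Mesos not to re-offer it for a configurable number of
// seconds. `MesosLikeDriver` is an illustrative stand-in for the real SchedulerDriver API.
trait MesosLikeDriver {
  def declineOffer(offerId: String): Unit                       // default filters
  def declineOffer(offerId: String, refuseSeconds: Long): Unit  // long backoff
}

object DeclineSketch {
  // The diff reads this from "spark.mesos.rejectOfferDurationForUnmetConstraints" (default 120s).
  val rejectOfferDurationForUnmetConstraints: Long = 120L

  def handle(
      driver: MesosLikeDriver,
      offerId: String,
      meetsConstraints: Boolean,
      launchedSomething: Boolean): Unit = {
    if (!meetsConstraints) {
      // This offer can never satisfy the constraints; keep it away for a long period.
      driver.declineOffer(offerId, rejectOfferDurationForUnmetConstraints)
    } else if (!launchedSomething) {
      // Usable but currently unneeded offers are declined with default filters so they return soon.
      driver.declineOffer(offerId)
    }
  }
}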
+ offersNotMatchingConstraints.foreach { o => + d.declineOffer(o.getId, Filters.newBuilder() + .setRefuseSeconds(rejectOfferDurationForUnmetConstraints).build()) + } + + // Of the matching constraints, see which ones give us enough memory and cores + val (usableOffers, unUsableOffers) = offersMatchingConstraints.partition { o => val mem = getResource(o.getResourcesList, "mem") val cpus = getResource(o.getResourcesList, "cpus") val slaveId = o.getSlaveId.getValue val offerAttributes = toAttributeMap(o.getAttributesList) - // check if all constraints are satisfield - // 1. Attribute constraints - // 2. Memory requirements - // 3. CPU requirements - need at least 1 for executor, 1 for task - val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) + // check offers for + // 1. Memory requirements + // 2. CPU requirements - need at least 1 for executor, 1 for task val meetsMemoryRequirements = mem >= calculateTotalMemory(sc) val meetsCPURequirements = cpus >= (mesosExecutorCores + scheduler.CPUS_PER_TASK) - val meetsRequirements = - (meetsConstraints && meetsMemoryRequirements && meetsCPURequirements) || + (meetsMemoryRequirements && meetsCPURequirements) || (slaveIdToExecutorInfo.contains(slaveId) && cpus >= scheduler.CPUS_PER_TASK) - - // add some debug messaging val debugstr = if (meetsRequirements) "Accepting" else "Declining" - val id = o.getId.getValue - logDebug(s"$debugstr offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") + logDebug(s"$debugstr offer: ${o.getId.getValue} with attributes: " + + s"$offerAttributes mem: $mem cpu: $cpus") meetsRequirements } @@ -318,10 +347,10 @@ private[spark] class MesosSchedulerBackend( .setSlaveId(SlaveID.newBuilder().setValue(slaveId).build()) .setExecutor(executorInfo) .setName(task.name) - .addAllResources(cpuResources) + .addAllResources(cpuResources.asJava) .setData(MesosTaskLaunchData(task.serializedTask, task.attemptNumber).toByteString) .build() - (taskInfo, finalResources) + (taskInfo, finalResources.asJava) } override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { @@ -387,7 +416,7 @@ private[spark] class MesosSchedulerBackend( slaveId: SlaveID, status: Int) { logInfo("Executor lost: %s, marking slave %s as lost".format(executorId.getValue, slaveId.getValue)) - recordSlaveLost(d, slaveId, ExecutorExited(status)) + recordSlaveLost(d, slaveId, ExecutorExited(status, exitCausedByApp = true)) } override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index c04920e4f587..573355ba5813 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -20,7 +20,7 @@ package org.apache.spark.scheduler.cluster.mesos import java.util.{List => JList} import java.util.concurrent.CountDownLatch -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal @@ -34,7 +34,7 @@ import org.apache.spark.util.Utils /** * Shared trait for implementing a Mesos Scheduler. This holds common state and helper - * methods and Mesos scheduler will use. + * methods the Mesos scheduler will use. 
*/ private[mesos] trait MesosSchedulerUtils extends Logging { // Lock used to wait for scheduler to be registered @@ -137,7 +137,7 @@ private[mesos] trait MesosSchedulerUtils extends Logging { protected def getResource(res: JList[Resource], name: String): Double = { // A resource can have multiple values in the offer since it can either be from // a specific role or wildcard. - res.filter(_.getName == name).map(_.getScalar.getValue).sum + res.asScala.filter(_.getName == name).map(_.getScalar.getValue).sum } protected def markRegistered(): Unit = { @@ -169,7 +169,7 @@ private[mesos] trait MesosSchedulerUtils extends Logging { amountToUse: Double): (List[Resource], List[Resource]) = { var remain = amountToUse var requestedResources = new ArrayBuffer[Resource] - val remainingResources = resources.map { + val remainingResources = resources.asScala.map { case r => { if (remain > 0 && r.getType == Value.Type.SCALAR && @@ -214,7 +214,7 @@ private[mesos] trait MesosSchedulerUtils extends Logging { * @return */ protected def toAttributeMap(offerAttributes: JList[Attribute]): Map[String, GeneratedMessage] = { - offerAttributes.map(attr => { + offerAttributes.asScala.map(attr => { val attrValue = attr.getType match { case Value.Type.SCALAR => attr.getScalar case Value.Type.RANGES => attr.getRanges @@ -253,7 +253,7 @@ private[mesos] trait MesosSchedulerUtils extends Logging { requiredValues.map(_.toLong).exists(offerRange.contains(_)) case Some(offeredValue: Value.Set) => // check if the specified required values is a subset of offered set - requiredValues.subsetOf(offeredValue.getItemList.toSet) + requiredValues.subsetOf(offeredValue.getItemList.asScala.toSet) case Some(textValue: Value.Text) => // check if the specified value is equal, if multiple values are specified // we succeed if any of them match. 
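// Hedged sketch of the attribute-matching rules handled above, using plain Scala types in place
// of the Mesos protobuf Value classes. A requirement is satisfied when the offered attribute is
// a set containing all required values, a range covering one of them, or a text/scalar equal to
// one of them; an empty requirement only checks that the attribute is present.
object AttributeMatchSketch {
  sealed trait AttrValue
  final case class TextAttr(value: String) extends AttrValue
  final case class ScalarAttr(value: Double) extends AttrValue
  final case class SetAttr(values: Set[String]) extends AttrValue
  final case class RangeAttr(ranges: Seq[(Long, Long)]) extends AttrValue

  def matches(required: Set[String], offered: Option[AttrValue]): Boolean = offered match {
    case None => false
    case Some(_) if required.isEmpty => true // presence-only constraint
    case Some(TextAttr(v)) => required.contains(v)
    case Some(ScalarAttr(v)) => required.map(_.toDouble).contains(v)
    case Some(SetAttr(vs)) => required.subsetOf(vs)
    case Some(RangeAttr(rs)) =>
      required.map(_.toLong).exists(v => rs.exists { case (lo, hi) => v >= lo && v <= hi })
  }
}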
@@ -299,14 +299,13 @@ private[mesos] trait MesosSchedulerUtils extends Logging { Map() } else { try { - Map() ++ mapAsScalaMap(splitter.split(constraintsVal)).map { - case (k, v) => - if (v == null || v.isEmpty) { - (k, Set[String]()) - } else { - (k, v.split(',').toSet) - } - } + splitter.split(constraintsVal).asScala.toMap.mapValues(v => + if (v == null || v.isEmpty) { + Set[String]() + } else { + v.split(',').toSet + } + ) } catch { case NonFatal(e) => throw new IllegalArgumentException(s"Bad constraint string: $constraintsVal", e) @@ -331,4 +330,14 @@ private[mesos] trait MesosSchedulerUtils extends Logging { sc.executorMemory } + def setupUris(uris: String, builder: CommandInfo.Builder): Unit = { + uris.split(",").foreach { uri => + builder.addUris(CommandInfo.URI.newBuilder().setValue(uri.trim())) + } + } + + protected def getRejectOfferDurationForUnmetConstraints(sc: SparkContext): Long = { + sc.conf.getTimeAsSeconds("spark.mesos.rejectOfferDurationForUnmetConstraints", "120s") + } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 4d48fcfea44e..c633d860ae6e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -24,6 +24,7 @@ import java.nio.ByteBuffer import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.{Executor, ExecutorBackend} +import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle} import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo @@ -103,6 +104,9 @@ private[spark] class LocalBackend( private var localEndpoint: RpcEndpointRef = null private val userClassPath = getUserClasspath(conf) private val listenerBus = scheduler.sc.listenerBus + private val launcherBackend = new LauncherBackend() { + override def onStopRequest(): Unit = stop(SparkAppHandle.State.KILLED) + } /** * Returns a list of URLs representing the user classpath. 
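// Hedged sketch, not the PR's code: the launcher hand-shake pattern the LocalBackend changes
// follow in the hunks around here. A backend connects to the launcher when constructed, reports
// its app id and RUNNING state once started, and reports a terminal state (FINISHED, or KILLED
// when the launcher requests a stop) before closing the connection. `LauncherLike` and
// `AppState` are illustrative stand-ins for LauncherBackend and SparkAppHandle.State.
object LauncherHandshakeSketch {
  sealed trait AppState
  case object Running extends AppState
  case object Finished extends AppState
  case object Killed extends AppState

  trait LauncherLike {
    def connect(): Unit
    def setAppId(id: String): Unit
    def setState(state: AppState): Unit
    def close(): Unit
  }

  final class Backend(launcher: LauncherLike, appId: String) {
    launcher.connect() // connect as soon as the backend is constructed

    def start(): Unit = {
      launcher.setAppId(appId)
      launcher.setState(Running)
    }

    // A normal stop() reports FINISHED; a stop requested by the launcher reports KILLED.
    def stop(finalState: AppState = Finished): Unit = {
      try launcher.setState(finalState)
      finally launcher.close() // always release the connection, even if setState fails
    }
  }
}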
@@ -114,6 +118,8 @@ private[spark] class LocalBackend( userClassPathStr.map(_.split(File.pathSeparator)).toSeq.flatten.map(new File(_).toURI.toURL) } + launcherBackend.connect() + override def start() { val rpcEnv = SparkEnv.get.rpcEnv val executorEndpoint = new LocalEndpoint(rpcEnv, userClassPath, scheduler, this, totalCores) @@ -122,10 +128,12 @@ private[spark] class LocalBackend( System.currentTimeMillis, executorEndpoint.localExecutorId, new ExecutorInfo(executorEndpoint.localExecutorHostname, totalCores, Map.empty))) + launcherBackend.setAppId(appId) + launcherBackend.setState(SparkAppHandle.State.RUNNING) } override def stop() { - localEndpoint.ask(StopExecutor) + stop(SparkAppHandle.State.FINISHED) } override def reviveOffers() { @@ -145,4 +153,13 @@ private[spark] class LocalBackend( override def applicationId(): String = appId + private def stop(finalState: SparkAppHandle.State): Unit = { + localEndpoint.ask(StopExecutor) + try { + launcherBackend.setState(finalState) + } finally { + launcherBackend.close() + } + } + } diff --git a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala index 62f8aae7f212..8d6af9cae892 100644 --- a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala @@ -81,7 +81,10 @@ private[serializer] class GenericAvroSerializer(schemas: Map[Long, String]) * seen values so to limit the number of times that decompression has to be done. */ def decompress(schemaBytes: ByteBuffer): Schema = decompressCache.getOrElseUpdate(schemaBytes, { - val bis = new ByteArrayInputStream(schemaBytes.array()) + val bis = new ByteArrayInputStream( + schemaBytes.array(), + schemaBytes.arrayOffset() + schemaBytes.position(), + schemaBytes.remaining()) val bytes = IOUtils.toByteArray(codec.compressedInputStream(bis)) new Schema.Parser().parse(new String(bytes, "UTF-8")) }) diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index 4a5274b46b7a..ea718a0edbe7 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -24,8 +24,7 @@ import scala.reflect.ClassTag import org.apache.spark.SparkConf import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.util.ByteBufferInputStream -import org.apache.spark.util.Utils +import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils} private[spark] class JavaSerializationStream( out: OutputStream, counterReset: Int, extraDebugInfo: Boolean) @@ -62,28 +61,45 @@ private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoa extends DeserializationStream { private val objIn = new ObjectInputStream(in) { - override def resolveClass(desc: ObjectStreamClass): Class[_] = { - // scalastyle:off classforname - Class.forName(desc.getName, false, loader) - // scalastyle:on classforname - } + override def resolveClass(desc: ObjectStreamClass): Class[_] = + try { + // scalastyle:off classforname + Class.forName(desc.getName, false, loader) + // scalastyle:on classforname + } catch { + case e: ClassNotFoundException => + JavaDeserializationStream.primitiveMappings.get(desc.getName).getOrElse(throw e) + } } def readObject[T: ClassTag](): T = objIn.readObject().asInstanceOf[T] def close() { objIn.close() } } 
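// Hedged sketch of why the primitive mapping introduced above is needed: resolveClass can be
// handed primitive type names (for example when a Class object for a primitive type appears in
// the stream), and Class.forName("int", ...) throws ClassNotFoundException, so those names are
// mapped to the primitive classes directly. This is a standalone illustration, not the PR's class.
object PrimitiveResolveSketch {
  private val primitiveMappings = Map[String, Class[_]](
    "int" -> classOf[Int], "long" -> classOf[Long], "boolean" -> classOf[Boolean])

  def resolve(name: String, loader: ClassLoader): Class[_] =
    try {
      Class.forName(name, false, loader)
    } catch {
      case e: ClassNotFoundException =>
        primitiveMappings.getOrElse(name, throw e) // fall back to primitives, otherwise rethrow
    }

  def main(args: Array[String]): Unit = {
    val loader = Thread.currentThread.getContextClassLoader
    println(resolve("int", loader))              // int
    println(resolve("java.lang.String", loader)) // class java.lang.String
  }
}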
+private object JavaDeserializationStream { + val primitiveMappings = Map[String, Class[_]]( + "boolean" -> classOf[Boolean], + "byte" -> classOf[Byte], + "char" -> classOf[Char], + "short" -> classOf[Short], + "int" -> classOf[Int], + "long" -> classOf[Long], + "float" -> classOf[Float], + "double" -> classOf[Double], + "void" -> classOf[Void] + ) +} private[spark] class JavaSerializerInstance( counterReset: Int, extraDebugInfo: Boolean, defaultClassLoader: ClassLoader) extends SerializerInstance { override def serialize[T: ClassTag](t: T): ByteBuffer = { - val bos = new ByteArrayOutputStream() + val bos = new ByteBufferOutputStream() val out = serializeStream(bos) out.writeObject(t) out.close() - ByteBuffer.wrap(bos.toByteArray) + bos.toByteBuffer } override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 0ff7562e912c..cb2ac5ea167e 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -17,28 +17,29 @@ package org.apache.spark.serializer -import java.io.{EOFException, IOException, InputStream, OutputStream} +import java.io.{DataInput, DataOutput, EOFException, IOException, InputStream, OutputStream} import java.nio.ByteBuffer import javax.annotation.Nullable +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag -import com.esotericsoftware.kryo.{Kryo, KryoException} import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer} +import com.esotericsoftware.kryo.{Kryo, KryoException, Serializer => KryoClassSerializer} import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator} import org.apache.avro.generic.{GenericData, GenericRecord} -import org.roaringbitmap.{ArrayContainer, BitmapContainer, RoaringArray, RoaringBitmap} +import org.roaringbitmap.RoaringBitmap import org.apache.spark._ import org.apache.spark.api.python.PythonBroadcast import org.apache.spark.broadcast.HttpBroadcast -import org.apache.spark.network.nio.{GetBlock, GotBlock, PutBlock} import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus} import org.apache.spark.storage._ -import org.apache.spark.util.{BoundedPriorityQueue, SerializableConfiguration, SerializableJobConf} import org.apache.spark.util.collection.CompactBuffer +import org.apache.spark.util.{BoundedPriorityQueue, SerializableConfiguration, SerializableJobConf, Utils} /** * A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]]. 
@@ -69,7 +70,9 @@ class KryoSerializer(conf: SparkConf) private val referenceTracking = conf.getBoolean("spark.kryo.referenceTracking", true) private val registrationRequired = conf.getBoolean("spark.kryo.registrationRequired", false) - private val userRegistrator = conf.getOption("spark.kryo.registrator") + private val userRegistrators = conf.get("spark.kryo.registrator", "") + .split(',') + .filter(!_.isEmpty) private val classesToRegister = conf.get("spark.kryo.classesToRegister", "") .split(',') .filter(!_.isEmpty) @@ -93,6 +96,9 @@ class KryoSerializer(conf: SparkConf) for (cls <- KryoSerializer.toRegister) { kryo.register(cls) } + for ((cls, ser) <- KryoSerializer.toRegisterSerializer) { + kryo.register(cls, ser) + } // For results returned by asJavaIterable. See JavaIterableWrapperSerializer. kryo.register(JavaIterableWrapperSerializer.wrapperClass, new JavaIterableWrapperSerializer) @@ -115,7 +121,7 @@ class KryoSerializer(conf: SparkConf) classesToRegister .foreach { className => kryo.register(Class.forName(className, true, classLoader)) } // Allow the user to register their own classes by setting spark.kryo.registrator. - userRegistrator + userRegistrators .map(Class.forName(_, true, classLoader).newInstance().asInstanceOf[KryoRegistrator]) .foreach { reg => reg.registerClasses(kryo) } // scalastyle:on classforname @@ -130,6 +136,38 @@ class KryoSerializer(conf: SparkConf) // our code override the generic serializers in Chill for things like Seq new AllScalaRegistrar().apply(kryo) + // Register types missed by Chill. + // scalastyle:off + kryo.register(classOf[Array[Tuple1[Any]]]) + kryo.register(classOf[Array[Tuple2[Any, Any]]]) + kryo.register(classOf[Array[Tuple3[Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple4[Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple5[Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple6[Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple7[Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple8[Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple9[Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, 
Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + + // scalastyle:on + + kryo.register(None.getClass) + kryo.register(Nil.getClass) + kryo.register(Utils.classForName("scala.collection.immutable.$colon$colon")) + kryo.register(classOf[ArrayBuffer[Any]]) + kryo.setClassLoader(classLoader) kryo } @@ -271,7 +309,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { val kryo = borrowKryo() try { - input.setBuffer(bytes.array) + input.setBuffer(bytes.array(), bytes.arrayOffset() + bytes.position(), bytes.remaining()) kryo.readClassAndObject(input).asInstanceOf[T] } finally { releaseKryo(kryo) @@ -283,7 +321,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ val oldClassLoader = kryo.getClassLoader try { kryo.setClassLoader(loader) - input.setBuffer(bytes.array) + input.setBuffer(bytes.array(), bytes.arrayOffset() + bytes.position(), bytes.remaining()) kryo.readClassAndObject(input).asInstanceOf[T] } finally { kryo.setClassLoader(oldClassLoader) @@ -328,17 +366,8 @@ private[serializer] object KryoSerializer { private val toRegister: Seq[Class[_]] = Seq( ByteBuffer.allocate(1).getClass, classOf[StorageLevel], - classOf[PutBlock], - classOf[GotBlock], - classOf[GetBlock], classOf[CompressedMapStatus], classOf[HighlyCompressedMapStatus], - classOf[RoaringBitmap], - classOf[RoaringArray], - classOf[RoaringArray.Element], - classOf[Array[RoaringArray.Element]], - classOf[ArrayContainer], - classOf[BitmapContainer], classOf[CompactBuffer[_]], classOf[BlockManagerId], classOf[Array[Byte]], @@ -347,6 +376,63 @@ private[serializer] object KryoSerializer { classOf[BoundedPriorityQueue[_]], classOf[SparkConf] ) + + private val toRegisterSerializer = Map[Class[_], KryoClassSerializer[_]]( + classOf[RoaringBitmap] -> new KryoClassSerializer[RoaringBitmap]() { + override def write(kryo: Kryo, output: KryoOutput, bitmap: RoaringBitmap): Unit = { + bitmap.serialize(new KryoOutputDataOutputBridge(output)) + } + override def read(kryo: Kryo, input: KryoInput, cls: Class[RoaringBitmap]): RoaringBitmap = { + val ret = new RoaringBitmap + ret.deserialize(new KryoInputDataInputBridge(input)) + ret + } + } + ) +} + +private[serializer] class KryoInputDataInputBridge(input: KryoInput) extends DataInput { + override def readLong(): Long = input.readLong() + override def readChar(): Char = input.readChar() + override def readFloat(): Float = input.readFloat() + override def readByte(): Byte = input.readByte() + override def readShort(): Short = input.readShort() + override def readUTF(): String = input.readString() // readString in kryo does utf8 + override def readInt(): Int = input.readInt() + override def readUnsignedShort(): Int = input.readShortUnsigned() + override def skipBytes(n: Int): Int = { + var remaining: Long = n + while (remaining > 0) { + val skip = Math.min(Integer.MAX_VALUE, remaining).asInstanceOf[Int] + input.skip(skip) + remaining -= skip + } + n + } + override def readFully(b: Array[Byte]): Unit = input.read(b) + override def readFully(b: Array[Byte], off: Int, len: Int): Unit = input.read(b, off, len) + override def readLine(): String = throw new UnsupportedOperationException("readLine") + override def readBoolean(): Boolean = input.readBoolean() + override def readUnsignedByte(): Int = input.readByteUnsigned() + override def readDouble(): 
Double = input.readDouble() +} + +private[serializer] class KryoOutputDataOutputBridge(output: KryoOutput) extends DataOutput { + override def writeFloat(v: Float): Unit = output.writeFloat(v) + // There is no "readChars" counterpart, except maybe "readLine", which is not supported + override def writeChars(s: String): Unit = throw new UnsupportedOperationException("writeChars") + override def writeDouble(v: Double): Unit = output.writeDouble(v) + override def writeUTF(s: String): Unit = output.writeString(s) // writeString in kryo does UTF8 + override def writeShort(v: Int): Unit = output.writeShort(v) + override def writeInt(v: Int): Unit = output.writeInt(v) + override def writeBoolean(v: Boolean): Unit = output.writeBoolean(v) + override def write(b: Int): Unit = output.write(b) + override def write(b: Array[Byte]): Unit = output.write(b) + override def write(b: Array[Byte], off: Int, len: Int): Unit = output.write(b, off, len) + override def writeBytes(s: String): Unit = output.writeString(s) + override def writeChar(v: Int): Unit = output.writeChar(v.toChar) + override def writeLong(v: Long): Unit = output.writeLong(v) + override def writeByte(v: Int): Unit = output.writeByte(v) } /** @@ -373,16 +459,15 @@ private class JavaIterableWrapperSerializer override def read(kryo: Kryo, in: KryoInput, clz: Class[java.lang.Iterable[_]]) : java.lang.Iterable[_] = { kryo.readClassAndObject(in) match { - case scalaIterable: Iterable[_] => - scala.collection.JavaConversions.asJavaIterable(scalaIterable) - case javaIterable: java.lang.Iterable[_] => - javaIterable + case scalaIterable: Iterable[_] => scalaIterable.asJava + case javaIterable: java.lang.Iterable[_] => javaIterable } } } private object JavaIterableWrapperSerializer extends Logging { - // The class returned by asJavaIterable (scala.collection.convert.Wrappers$IterableWrapper). + // The class returned by JavaConverters.asJava + // (scala.collection.convert.Wrappers$IterableWrapper). val wrapperClass = scala.collection.convert.WrapAsJava.asJavaIterable(Seq(1)).getClass diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala index a1b1e1631eaf..e2951d8a3e09 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala @@ -53,12 +53,13 @@ private[spark] object SerializationDebugger extends Logging { /** * Find the path leading to a not serializable object. This method is modeled after OpenJDK's * serialization mechanism, and handles the following cases: - * - primitives - * - arrays of primitives - * - arrays of non-primitive objects - * - Serializable objects - * - Externalizable objects - * - writeReplace + * + * - primitives + * - arrays of primitives + * - arrays of non-primitive objects + * - Serializable objects + * - Externalizable objects + * - writeReplace * * It does not yet handle writeObject override, but that shouldn't be too hard to do either. 
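// Hedged sketch: RoaringBitmap already defines its own portable format via
// serialize(DataOutput) / deserialize(DataInput); the Kryo bridges above simply adapt Kryo's
// Output/Input to those interfaces instead of registering the bitmap's internal classes. The
// same round trip shown here with plain java.io streams is what the bridged Kryo serializer
// performs. Assumes the org.roaringbitmap library is on the classpath.
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import org.roaringbitmap.RoaringBitmap

object RoaringBitmapRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val bitmap = new RoaringBitmap
    bitmap.add(1)
    bitmap.add(7)
    bitmap.add(1 << 20)

    val bos = new ByteArrayOutputStream()
    bitmap.serialize(new DataOutputStream(bos)) // same call the bridged write() makes

    val copy = new RoaringBitmap
    copy.deserialize(new DataInputStream(new ByteArrayInputStream(bos.toByteArray)))
    assert(copy == bitmap) // identical bit set after the round trip
  }
}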
*/ diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala similarity index 84% rename from core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala rename to core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index de79fa56f017..b0abda4a81b8 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -15,16 +15,19 @@ * limitations under the License. */ -package org.apache.spark.shuffle.hash +package org.apache.spark.shuffle -import org.apache.spark.{InterruptibleIterator, Logging, MapOutputTracker, SparkEnv, TaskContext} +import org.apache.spark._ import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter -private[spark] class HashShuffleReader[K, C]( +/** + * Fetches and reads the partitions in range [startPartition, endPartition) from a shuffle by + * requesting them from other nodes' block stores. + */ +private[spark] class BlockStoreShuffleReader[K, C]( handle: BaseShuffleHandle[K, _, C], startPartition: Int, endPartition: Int, @@ -33,9 +36,6 @@ private[spark] class HashShuffleReader[K, C]( mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) extends ShuffleReader[K, C] with Logging { - require(endPartition == startPartition + 1, - "Hash shuffle currently only supports fetching one partition") - private val dep = handle.dependency /** Read the combined key-values for this reduce task */ @@ -44,7 +44,7 @@ private[spark] class HashShuffleReader[K, C]( context, blockManager.shuffleClient, blockManager, - mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition), + mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition), // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) @@ -98,11 +98,14 @@ private[spark] class HashShuffleReader[K, C]( case Some(keyOrd: Ordering[K]) => // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled, // the ExternalSorter won't spill to disk. 
- val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser)) + val sorter = + new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = Some(ser)) sorter.insertAll(aggregatedIter) - context.taskMetrics.incMemoryBytesSpilled(sorter.memoryBytesSpilled) - context.taskMetrics.incDiskBytesSpilled(sorter.diskBytesSpilled) - sorter.iterator + context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) + context.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes) + CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop()) case None => aggregatedIter } diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala index f6a96d81e7aa..cc5f933393ad 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala @@ -17,21 +17,17 @@ package org.apache.spark.shuffle -import java.io.File import java.util.concurrent.ConcurrentLinkedQueue -import java.util.concurrent.atomic.AtomicInteger -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ -import org.apache.spark.{Logging, SparkConf, SparkEnv} import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.FileShuffleBlockResolver.ShuffleFileGroup import org.apache.spark.storage._ -import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} -import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector} +import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils} +import org.apache.spark.{Logging, SparkConf, SparkEnv} /** A group of writers for a ShuffleMapTask, one writer per reducer. */ private[spark] trait ShuffleWriterGroup { @@ -43,54 +39,26 @@ private[spark] trait ShuffleWriterGroup { /** * Manages assigning disk-based block writers to shuffle tasks. Each shuffle task gets one file - * per reducer (this set of files is called a ShuffleFileGroup). - * - * As an optimization to reduce the number of physical shuffle files produced, multiple shuffle - * blocks are aggregated into the same file. There is one "combined shuffle file" per reducer - * per concurrently executing shuffle task. As soon as a task finishes writing to its shuffle - * files, it releases them for another task. - * Regarding the implementation of this feature, shuffle files are identified by a 3-tuple: - * - shuffleId: The unique id given to the entire shuffle stage. - * - bucketId: The id of the output partition (i.e., reducer id) - * - fileId: The unique id identifying a group of "combined shuffle files." Only one task at a - * time owns a particular fileId, and this id is returned to a pool when the task finishes. - * Each shuffle file is then mapped to a FileSegment, which is a 3-tuple (file, offset, length) - * that specifies where in a given file the actual block data is located. - * - * Shuffle file metadata is stored in a space-efficient manner. 
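// Hedged sketch of the CompletionIterator pattern used just above: wrap an iterator so that a
// cleanup action (here standing in for the sorter's stop()) runs exactly once as soon as the
// underlying iterator is exhausted, rather than waiting for the whole task to finish.
object CompletionIteratorSketch {
  def onCompletion[A](underlying: Iterator[A])(completion: => Unit): Iterator[A] =
    new Iterator[A] {
      private var completed = false
      override def hasNext: Boolean = {
        val more = underlying.hasNext
        if (!more && !completed) { completed = true; completion } // run cleanup once
        more
      }
      override def next(): A = underlying.next()
    }

  def main(args: Array[String]): Unit = {
    val it = onCompletion(Iterator(1, 2, 3))(println("resources released"))
    it.foreach(println) // prints 1, 2, 3, then "resources released"
  }
}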
Rather than simply mapping - * ShuffleBlockIds directly to FileSegments, each ShuffleFileGroup maintains a list of offsets for - * each block stored in each file. In order to find the location of a shuffle block, we search the - * files within a ShuffleFileGroups associated with the block's reducer. + * per reducer. */ // Note: Changes to the format in this file should be kept in sync with // org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getHashBasedShuffleBlockData(). private[spark] class FileShuffleBlockResolver(conf: SparkConf) extends ShuffleBlockResolver with Logging { - private val transportConf = SparkTransportConf.fromSparkConf(conf) + private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle") private lazy val blockManager = SparkEnv.get.blockManager - // Turning off shuffle file consolidation causes all shuffle Blocks to get their own file. - // TODO: Remove this once the shuffle file consolidation feature is stable. - private val consolidateShuffleFiles = - conf.getBoolean("spark.shuffle.consolidateFiles", false) - // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided private val bufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 /** - * Contains all the state related to a particular shuffle. This includes a pool of unused - * ShuffleFileGroups, as well as all ShuffleFileGroups that have been created for the shuffle. + * Contains all the state related to a particular shuffle. */ - private class ShuffleState(val numBuckets: Int) { - val nextFileId = new AtomicInteger(0) - val unusedFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]() - val allFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]() - + private class ShuffleState(val numReducers: Int) { /** * The mapIds of all map tasks completed on this Executor for this shuffle. - * NB: This is only populated if consolidateShuffleFiles is FALSE. We don't need it otherwise. 
*/ val completedMapTasks = new ConcurrentLinkedQueue[Int]() } @@ -104,37 +72,20 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) * Get a ShuffleWriterGroup for the given map task, which will register it as complete * when the writers are closed successfully */ - def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer, + def forMapTask(shuffleId: Int, mapId: Int, numReducers: Int, serializer: Serializer, writeMetrics: ShuffleWriteMetrics): ShuffleWriterGroup = { new ShuffleWriterGroup { - shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets)) + shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numReducers)) private val shuffleState = shuffleStates(shuffleId) - private var fileGroup: ShuffleFileGroup = null val openStartTime = System.nanoTime val serializerInstance = serializer.newInstance() - val writers: Array[DiskBlockObjectWriter] = if (consolidateShuffleFiles) { - fileGroup = getUnusedFileGroup() - Array.tabulate[DiskBlockObjectWriter](numBuckets) { bucketId => - val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) - blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializerInstance, bufferSize, - writeMetrics) - } - } else { - Array.tabulate[DiskBlockObjectWriter](numBuckets) { bucketId => + val writers: Array[DiskBlockObjectWriter] = { + Array.tabulate[DiskBlockObjectWriter](numReducers) { bucketId => val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) val blockFile = blockManager.diskBlockManager.getFile(blockId) - // Because of previous failures, the shuffle file may already exist on this machine. - // If so, remove it. - if (blockFile.exists) { - if (blockFile.delete()) { - logInfo(s"Removed existing shuffle file $blockFile") - } else { - logWarning(s"Failed to remove existing shuffle file $blockFile") - } - } - blockManager.getDiskWriter(blockId, blockFile, serializerInstance, bufferSize, - writeMetrics) + val tmp = Utils.tempFileWith(blockFile) + blockManager.getDiskWriter(blockId, tmp, serializerInstance, bufferSize, writeMetrics) } } // Creating the file to write to and creating a disk writer both involve interacting with @@ -142,58 +93,14 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) writeMetrics.incShuffleWriteTime(System.nanoTime - openStartTime) override def releaseWriters(success: Boolean) { - if (consolidateShuffleFiles) { - if (success) { - val offsets = writers.map(_.fileSegment().offset) - val lengths = writers.map(_.fileSegment().length) - fileGroup.recordMapOutput(mapId, offsets, lengths) - } - recycleFileGroup(fileGroup) - } else { - shuffleState.completedMapTasks.add(mapId) - } - } - - private def getUnusedFileGroup(): ShuffleFileGroup = { - val fileGroup = shuffleState.unusedFileGroups.poll() - if (fileGroup != null) fileGroup else newFileGroup() - } - - private def newFileGroup(): ShuffleFileGroup = { - val fileId = shuffleState.nextFileId.getAndIncrement() - val files = Array.tabulate[File](numBuckets) { bucketId => - val filename = physicalFileName(shuffleId, bucketId, fileId) - blockManager.diskBlockManager.getFile(filename) - } - val fileGroup = new ShuffleFileGroup(shuffleId, fileId, files) - shuffleState.allFileGroups.add(fileGroup) - fileGroup - } - - private def recycleFileGroup(group: ShuffleFileGroup) { - shuffleState.unusedFileGroups.add(group) + shuffleState.completedMapTasks.add(mapId) } } } override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { - if (consolidateShuffleFiles) { - // Search all file groups associated with this shuffle. 
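// Hedged sketch of the write-to-temp-then-commit pattern the resolver switches to here: each
// attempt writes to a uniquely named sibling of the final file, and committing either renames
// the temp file into place or, if another attempt already produced the output, keeps the
// existing file and discards the temp one (first writer wins). `tempSiblingOf` is an
// illustrative stand-in for Utils.tempFileWith.
import java.io.{File, IOException}
import java.util.UUID

object CommitByRenameSketch {
  def tempSiblingOf(target: File): File =
    new File(target.getParentFile, target.getName + "." + UUID.randomUUID())

  def commit(tmp: File, target: File): Long = synchronized {
    if (target.exists()) {
      // Another attempt for the same map task already committed its output; reuse it.
      tmp.delete()
      target.length()
    } else {
      if (!tmp.renameTo(target)) {
        throw new IOException(s"failed to rename $tmp to $target")
      }
      target.length()
    }
  }
}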
- val shuffleState = shuffleStates(blockId.shuffleId) - val iter = shuffleState.allFileGroups.iterator - while (iter.hasNext) { - val segmentOpt = iter.next.getFileSegmentFor(blockId.mapId, blockId.reduceId) - if (segmentOpt.isDefined) { - val segment = segmentOpt.get - return new FileSegmentManagedBuffer( - transportConf, segment.file, segment.offset, segment.length) - } - } - throw new IllegalStateException("Failed to find shuffle block: " + blockId) - } else { - val file = blockManager.diskBlockManager.getFile(blockId) - new FileSegmentManagedBuffer(transportConf, file, 0, file.length) - } + val file = blockManager.diskBlockManager.getFile(blockId) + new FileSegmentManagedBuffer(transportConf, file, 0, file.length) } /** Remove all the blocks / files and metadata related to a particular shuffle. */ @@ -209,14 +116,11 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) private def removeShuffleBlocks(shuffleId: ShuffleId): Boolean = { shuffleStates.get(shuffleId) match { case Some(state) => - if (consolidateShuffleFiles) { - for (fileGroup <- state.allFileGroups; file <- fileGroup.files) { - file.delete() - } - } else { - for (mapId <- state.completedMapTasks; reduceId <- 0 until state.numBuckets) { - val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId) - blockManager.diskBlockManager.getFile(blockId).delete() + for (mapId <- state.completedMapTasks.asScala; reduceId <- 0 until state.numReducers) { + val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId) + val file = blockManager.diskBlockManager.getFile(blockId) + if (!file.delete()) { + logWarning(s"Error deleting ${file.getPath()}") } } logInfo("Deleted all files for shuffle " + shuffleId) @@ -227,10 +131,6 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) } } - private def physicalFileName(shuffleId: Int, bucketId: Int, fileId: Int) = { - "merged_shuffle_%d_%d_%d".format(shuffleId, bucketId, fileId) - } - private def cleanup(cleanupTime: Long) { shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => removeShuffleBlocks(shuffleId)) } @@ -239,59 +139,3 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) metadataCleaner.cancel() } } - -private[spark] object FileShuffleBlockResolver { - /** - * A group of shuffle files, one per reducer. - * A particular mapper will be assigned a single ShuffleFileGroup to write its output to. - */ - private class ShuffleFileGroup(val shuffleId: Int, val fileId: Int, val files: Array[File]) { - private var numBlocks: Int = 0 - - /** - * Stores the absolute index of each mapId in the files of this group. For instance, - * if mapId 5 is the first block in each file, mapIdToIndex(5) = 0. - */ - private val mapIdToIndex = new PrimitiveKeyOpenHashMap[Int, Int]() - - /** - * Stores consecutive offsets and lengths of blocks into each reducer file, ordered by - * position in the file. - * Note: mapIdToIndex(mapId) returns the index of the mapper into the vector for every - * reducer. 
- */ - private val blockOffsetsByReducer = Array.fill[PrimitiveVector[Long]](files.length) { - new PrimitiveVector[Long]() - } - private val blockLengthsByReducer = Array.fill[PrimitiveVector[Long]](files.length) { - new PrimitiveVector[Long]() - } - - def apply(bucketId: Int): File = files(bucketId) - - def recordMapOutput(mapId: Int, offsets: Array[Long], lengths: Array[Long]) { - assert(offsets.length == lengths.length) - mapIdToIndex(mapId) = numBlocks - numBlocks += 1 - for (i <- 0 until offsets.length) { - blockOffsetsByReducer(i) += offsets(i) - blockLengthsByReducer(i) += lengths(i) - } - } - - /** Returns the FileSegment associated with the given map task, or None if no entry exists. */ - def getFileSegmentFor(mapId: Int, reducerId: Int): Option[FileSegment] = { - val file = files(reducerId) - val blockOffsets = blockOffsetsByReducer(reducerId) - val blockLengths = blockLengthsByReducer(reducerId) - val index = mapIdToIndex.getOrElse(mapId, -1) - if (index >= 0) { - val offset = blockOffsets(index) - val length = blockLengths(index) - Some(new FileSegment(file, offset, length)) - } else { - None - } - } - } -} diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index fae69551e733..fadb8fe7ed0a 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -21,13 +21,12 @@ import java.io._ import com.google.common.io.ByteStreams -import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf +import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID import org.apache.spark.storage._ import org.apache.spark.util.Utils - -import IndexShuffleBlockResolver.NOOP_REDUCE_ID +import org.apache.spark.{SparkEnv, Logging, SparkConf} /** * Create and maintain the shuffle blocks' mapping between logic block and physical file location. @@ -40,11 +39,15 @@ import IndexShuffleBlockResolver.NOOP_REDUCE_ID */ // Note: Changes to the format in this file should be kept in sync with // org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getSortBasedShuffleBlockData(). 
-private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleBlockResolver { +private[spark] class IndexShuffleBlockResolver( + conf: SparkConf, + _blockManager: BlockManager = null) + extends ShuffleBlockResolver + with Logging { - private lazy val blockManager = SparkEnv.get.blockManager + private lazy val blockManager = Option(_blockManager).getOrElse(SparkEnv.get.blockManager) - private val transportConf = SparkTransportConf.fromSparkConf(conf) + private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle") def getDataFile(shuffleId: Int, mapId: Int): File = { blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) @@ -60,23 +63,82 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB def removeDataByMap(shuffleId: Int, mapId: Int): Unit = { var file = getDataFile(shuffleId, mapId) if (file.exists()) { - file.delete() + if (!file.delete()) { + logWarning(s"Error deleting data ${file.getPath()}") + } } file = getIndexFile(shuffleId, mapId) if (file.exists()) { - file.delete() + if (!file.delete()) { + logWarning(s"Error deleting index ${file.getPath()}") + } + } + } + + /** + * Check whether the given index and data files match each other. + * If so, return the partition lengths in the data file. Otherwise return null. + */ + private def checkIndexAndDataFile(index: File, data: File, blocks: Int): Array[Long] = { + // the index file should have `block + 1` longs as offset. + if (index.length() != (blocks + 1) * 8) { + return null + } + val lengths = new Array[Long](blocks) + // Read the lengths of blocks + val in = try { + new DataInputStream(new BufferedInputStream(new FileInputStream(index))) + } catch { + case e: IOException => + return null + } + try { + // Convert the offsets into lengths of each block + var offset = in.readLong() + if (offset != 0L) { + return null + } + var i = 0 + while (i < blocks) { + val off = in.readLong() + lengths(i) = off - offset + offset = off + i += 1 + } + } catch { + case e: IOException => + return null + } finally { + in.close() + } + + // the size of data file should match with index file + if (data.length() == lengths.sum) { + lengths + } else { + null } } /** * Write an index file with the offsets of each block, plus a final offset at the end for the - * end of the output file. This will be used by getBlockLocation to figure out where each block + * end of the output file. This will be used by getBlockData to figure out where each block * begins and ends. + * + * It will commit the data and index file as an atomic operation, use the existing ones, or + * replace them with new ones. + * + * Note: the `lengths` will be updated to match the existing index file if use the existing ones. * */ - def writeIndexFile(shuffleId: Int, mapId: Int, lengths: Array[Long]): Unit = { + def writeIndexFileAndCommit( + shuffleId: Int, + mapId: Int, + lengths: Array[Long], + dataTmp: File): Unit = { val indexFile = getIndexFile(shuffleId, mapId) - val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile))) + val indexTmp = Utils.tempFileWith(indexFile) + val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp))) Utils.tryWithSafeFinally { // We take in lengths of each block, need to convert it to offsets. 
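// Hedged sketch of the index file layout handled above: for N blocks the index stores N + 1
// longs, starting at 0, where entry i is the cumulative offset of block i in the data file.
// Lengths and offsets convert both ways, which is what checkIndexAndDataFile relies on when it
// validates an existing index against the data file's size.
object ShuffleIndexSketch {
  def lengthsToOffsets(lengths: Array[Long]): Array[Long] =
    lengths.scanLeft(0L)(_ + _) // N + 1 entries, the first is always 0

  def offsetsToLengths(offsets: Array[Long]): Array[Long] =
    offsets.sliding(2).map { case Array(prev, next) => next - prev }.toArray

  def main(args: Array[String]): Unit = {
    val lengths = Array(10L, 0L, 25L)
    val offsets = lengthsToOffsets(lengths) // Array(0, 10, 10, 35)
    assert(offsetsToLengths(offsets).sameElements(lengths))
    assert(offsets.last == lengths.sum)     // the data file size must match the final offset
  }
}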
var offset = 0L @@ -88,6 +150,37 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB } { out.close() } + + val dataFile = getDataFile(shuffleId, mapId) + // There is only one IndexShuffleBlockResolver per executor, this synchronization make sure + // the following check and rename are atomic. + synchronized { + val existingLengths = checkIndexAndDataFile(indexFile, dataFile, lengths.length) + if (existingLengths != null) { + // Another attempt for the same task has already written our map outputs successfully, + // so just use the existing partition lengths and delete our temporary map outputs. + System.arraycopy(existingLengths, 0, lengths, 0, lengths.length) + if (dataTmp != null && dataTmp.exists()) { + dataTmp.delete() + } + indexTmp.delete() + } else { + // This is the first successful attempt in writing the map outputs for this task, + // so override any existing index and data files with the ones we wrote. + if (indexFile.exists()) { + indexFile.delete() + } + if (dataFile.exists()) { + dataFile.delete() + } + if (!indexTmp.renameTo(indexFile)) { + throw new IOException("fail to rename file " + indexTmp + " to " + indexFile) + } + if (dataTmp != null && dataTmp.exists() && !dataTmp.renameTo(dataFile)) { + throw new IOException("fail to rename file " + dataTmp + " to " + dataFile) + } + } + } } override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { @@ -114,9 +207,8 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB } private[spark] object IndexShuffleBlockResolver { - // No-op reduce ID used in interactions with disk store and DiskBlockObjectWriter. + // No-op reduce ID used in interactions with disk store. // The disk store currently expects puts to relate to a (map, reduce) pair, but in the sort // shuffle outputs for several reduces are glommed into a single file. - // TODO: Avoid this entirely by having the DiskBlockObjectWriter not require a BlockId. val NOOP_REDUCE_ID = 0 } diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala index 978366d1a1d1..a3444bf4daa3 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala @@ -28,6 +28,10 @@ import org.apache.spark.{TaskContext, ShuffleDependency} * boolean isDriver as parameters. */ private[spark] trait ShuffleManager { + + /** Return short name for the ShuffleManager */ + val shortName: String + /** * Register a shuffle with the manager and obtain a handle for it to pass to tasks. */ diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala deleted file mode 100644 index 00c1e078a441..000000000000 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala +++ /dev/null @@ -1,138 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle - -import scala.collection.mutable - -import org.apache.spark.{Logging, SparkException, SparkConf, TaskContext} - -/** - * Allocates a pool of memory to tasks for use in shuffle operations. Each disk-spilling - * collection (ExternalAppendOnlyMap or ExternalSorter) used by these tasks can acquire memory - * from this pool and release it as it spills data out. When a task ends, all its memory will be - * released by the Executor. - * - * This class tries to ensure that each task gets a reasonable share of memory, instead of some - * task ramping up to a large amount first and then causing others to spill to disk repeatedly. - * If there are N tasks, it ensures that each tasks can acquire at least 1 / 2N of the memory - * before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the - * set of active tasks and redo the calculations of 1 / 2N and 1 / N in waiting tasks whenever - * this set changes. This is all done by synchronizing access on "this" to mutate state and using - * wait() and notifyAll() to signal changes. - */ -private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging { - private val taskMemory = new mutable.HashMap[Long, Long]() // taskAttemptId -> memory bytes - - def this(conf: SparkConf) = this(ShuffleMemoryManager.getMaxMemory(conf)) - - private def currentTaskAttemptId(): Long = { - // In case this is called on the driver, return an invalid task attempt id. - Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L) - } - - /** - * Try to acquire up to numBytes memory for the current task, and return the number of bytes - * obtained, or 0 if none can be allocated. This call may block until there is enough free memory - * in some situations, to make sure each task has a chance to ramp up to at least 1 / 2N of the - * total memory pool (where N is the # of active tasks) before it is forced to spill. This can - * happen if the number of tasks increases but an older task had a lot of memory already. - */ - def tryToAcquire(numBytes: Long): Long = synchronized { - val taskAttemptId = currentTaskAttemptId() - assert(numBytes > 0, "invalid number of bytes requested: " + numBytes) - - // Add this task to the taskMemory map just so we can keep an accurate count of the number - // of active tasks, to let other tasks ramp down their memory in calls to tryToAcquire - if (!taskMemory.contains(taskAttemptId)) { - taskMemory(taskAttemptId) = 0L - notifyAll() // Will later cause waiting tasks to wake up and check numThreads again - } - - // Keep looping until we're either sure that we don't want to grant this request (because this - // task would have more than 1 / numActiveTasks of the memory) or we have enough free - // memory to give it (we always let each task get at least 1 / (2 * numActiveTasks)). 
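// Hedged sketch of the 1/(2N)..1/N fairness rule described in the comments of the class being
// deleted here: with N active tasks, a task may be granted at most 1/N of the pool and is
// entitled to reach at least 1/(2N) before it can be forced to spill. Names are illustrative.
object ShuffleMemoryPolicySketch {
  def maxGrant(maxMemory: Long, numActiveTasks: Int, currentTaskMem: Long, requested: Long): Long =
    math.min(requested, math.max(0L, maxMemory / numActiveTasks - currentTaskMem))

  def minimumShare(maxMemory: Long, numActiveTasks: Int): Long =
    maxMemory / (2L * numActiveTasks)

  def main(args: Array[String]): Unit = {
    val pool = 1000L
    // A task holding nothing out of a 1000-byte pool with 4 active tasks can be granted at most
    // 250 bytes now, and may wait until at least 125 bytes are free before spilling.
    println(maxGrant(pool, 4, currentTaskMem = 0L, requested = 400L)) // 250
    println(minimumShare(pool, 4))                                    // 125
  }
}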
- while (true) { - val numActiveTasks = taskMemory.keys.size - val curMem = taskMemory(taskAttemptId) - val freeMemory = maxMemory - taskMemory.values.sum - - // How much we can grant this task; don't let it grow to more than 1 / numActiveTasks; - // don't let it be negative - val maxToGrant = math.min(numBytes, math.max(0, (maxMemory / numActiveTasks) - curMem)) - - if (curMem < maxMemory / (2 * numActiveTasks)) { - // We want to let each task get at least 1 / (2 * numActiveTasks) before blocking; - // if we can't give it this much now, wait for other tasks to free up memory - // (this happens if older tasks allocated lots of memory before N grew) - if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveTasks) - curMem)) { - val toGrant = math.min(maxToGrant, freeMemory) - taskMemory(taskAttemptId) += toGrant - return toGrant - } else { - logInfo( - s"TID $taskAttemptId waiting for at least 1/2N of shuffle memory pool to be free") - wait() - } - } else { - // Only give it as much memory as is free, which might be none if it reached 1 / numThreads - val toGrant = math.min(maxToGrant, freeMemory) - taskMemory(taskAttemptId) += toGrant - return toGrant - } - } - 0L // Never reached - } - - /** Release numBytes bytes for the current task. */ - def release(numBytes: Long): Unit = synchronized { - val taskAttemptId = currentTaskAttemptId() - val curMem = taskMemory.getOrElse(taskAttemptId, 0L) - if (curMem < numBytes) { - throw new SparkException( - s"Internal error: release called on ${numBytes} bytes but task only has ${curMem}") - } - taskMemory(taskAttemptId) -= numBytes - notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed - } - - /** Release all memory for the current task and mark it as inactive (e.g. when a task ends). */ - def releaseMemoryForThisTask(): Unit = synchronized { - val taskAttemptId = currentTaskAttemptId() - taskMemory.remove(taskAttemptId) - notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed - } - - /** Returns the memory consumption, in bytes, for the current task */ - def getMemoryConsumptionForThisTask(): Long = synchronized { - val taskAttemptId = currentTaskAttemptId() - taskMemory.getOrElse(taskAttemptId, 0L) - } -} - -private object ShuffleMemoryManager { - /** - * Figure out the shuffle memory limit from a SparkConf. We currently have both a fraction - * of the memory pool and a safety factor since collections can sometimes grow bigger than - * the size we target before we estimate their sizes again. - */ - def getMaxMemory(conf: SparkConf): Long = { - val memoryFraction = conf.getDouble("spark.shuffle.memoryFraction", 0.2) - val safetyFraction = conf.getDouble("spark.shuffle.safetyFraction", 0.8) - (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong - } -} diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala index c089088f409d..4f30da0878ee 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala @@ -24,10 +24,18 @@ import org.apache.spark.shuffle._ * A ShuffleManager using hashing, that creates one output file per reduce partition on each * mapper (possibly reusing these across waves of tasks). 
*/ -private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager { +private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { + + if (!conf.getBoolean("spark.shuffle.spill", true)) { + logWarning( + "spark.shuffle.spill was set to false, but this configuration is ignored as of Spark 1.6+." + + " Shuffle will continue to spill to disk when necessary.") + } private val fileShuffleBlockResolver = new FileShuffleBlockResolver(conf) + override val shortName: String = "hash" + /* Register a shuffle with the manager and obtain a handle for it to pass to tasks. */ override def registerShuffle[K, V, C]( shuffleId: Int, @@ -45,7 +53,7 @@ private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager startPartition: Int, endPartition: Int, context: TaskContext): ShuffleReader[K, C] = { - new HashShuffleReader( + new BlockStoreShuffleReader( handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala index 41df70c602c3..412bf70000da 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala @@ -17,6 +17,8 @@ package org.apache.spark.shuffle.hash +import java.io.IOException + import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.scheduler.MapStatus @@ -106,6 +108,29 @@ private[spark] class HashShuffleWriter[K, V]( writer.commitAndClose() writer.fileSegment().length } + // rename all shuffle files to final paths + // Note: there is only one ShuffleBlockResolver in executor + shuffleBlockResolver.synchronized { + shuffle.writers.zipWithIndex.foreach { case (writer, i) => + val output = blockManager.diskBlockManager.getFile(writer.blockId) + if (sizes(i) > 0) { + if (output.exists()) { + // Use length of existing file and delete our own temporary one + sizes(i) = output.length() + writer.file.delete() + } else { + // Commit by renaming our temporary file to something the fetcher expects + if (!writer.file.renameTo(output)) { + throw new IOException(s"fail to rename ${writer.file} to $output") + } + } + } else { + if (output.exists()) { + output.delete() + } + } + } + } MapStatus(blockManager.shuffleServerId, sizes) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index d7fab351ca3b..9b1a27952842 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -19,14 +19,69 @@ package org.apache.spark.shuffle.sort import java.util.concurrent.ConcurrentHashMap -import org.apache.spark.{SparkConf, TaskContext, ShuffleDependency} +import org.apache.spark._ +import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle._ -import org.apache.spark.shuffle.hash.HashShuffleReader -private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager { +/** + * In sort-based shuffle, incoming records are sorted according to their target partition ids, then + * written to a single map output file. Reducers fetch contiguous regions of this file in order to + * read their portion of the map output. 
In cases where the map output data is too large to fit in + * memory, sorted subsets of the output are spilled to disk and those on-disk files are merged + * to produce the final output file. + * + * Sort-based shuffle has two different write paths for producing its map output files: + * + * - Serialized sorting: used when all three of the following conditions hold: + * 1. The shuffle dependency specifies no aggregation or output ordering. + * 2. The shuffle serializer supports relocation of serialized values (this is currently + * supported by KryoSerializer and Spark SQL's custom serializers). + * 3. The shuffle produces fewer than 16777216 output partitions. + * - Deserialized sorting: used to handle all other cases. + * + * ----------------------- + * Serialized sorting mode + * ----------------------- + * + * In the serialized sorting mode, incoming records are serialized as soon as they are passed to the + * shuffle writer and are buffered in a serialized form during sorting. This write path implements + * several optimizations: + * + * - Its sort operates on serialized binary data rather than Java objects, which reduces memory + * consumption and GC overheads. This optimization requires the record serializer to have certain + * properties to allow serialized records to be re-ordered without requiring deserialization. + * See SPARK-4550, where this optimization was first proposed and implemented, for more details. + * + * - It uses a specialized cache-efficient sorter ([[ShuffleExternalSorter]]) that sorts + * arrays of compressed record pointers and partition ids. By using only 8 bytes of space per + * record in the sorting array, this fits more of the array into cache. + * + * - The spill merging procedure operates on blocks of serialized records that belong to the same + * partition and does not need to deserialize records during the merge. + * + * - When the spill compression codec supports concatenation of compressed data, the spill merge + * simply concatenates the serialized and compressed spill partitions to produce the final output + * partition. This allows efficient data copying methods, like NIO's `transferTo`, to be used + * and avoids the need to allocate decompression or copying buffers during the merge. + * + * For more details on these optimizations, see SPARK-7081. + */ +private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { - private val indexShuffleBlockResolver = new IndexShuffleBlockResolver(conf) - private val shuffleMapNumber = new ConcurrentHashMap[Int, Int]() + if (!conf.getBoolean("spark.shuffle.spill", true)) { + logWarning( + "spark.shuffle.spill was set to false, but this configuration is ignored as of Spark 1.6+." + + " Shuffle will continue to spill to disk when necessary.") + } + + /** + * A mapping from shuffle ids to the number of mappers producing output for those shuffles. + */ + private[this] val numMapsForShuffle = new ConcurrentHashMap[Int, Int]() + + override val shortName: String = "sort" + + override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) /** * Register a shuffle with the manager and obtain a handle for it to pass to tasks.
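// Illustration of the 16777216-partition limit mentioned above: the serialized sort keeps one
// 8-byte value per record, combining the partition id with a record address. The exact bit
// layout of PackedRecordPointer may differ; this sketch simply assumes 24 bits for the
// partition id (hence at most 2^24 = 16777216 partitions) and 40 bits for the address.
object PackedPointerSketch {
  private val PartitionBits = 24
  private val AddressMask = (1L << (64 - PartitionBits)) - 1 // low 40 bits

  def pack(partitionId: Int, address: Long): Long = {
    require(partitionId < (1 << PartitionBits), "too many partitions for serialized shuffle")
    (partitionId.toLong << (64 - PartitionBits)) | (address & AddressMask)
  }

  def partitionId(packed: Long): Int = (packed >>> (64 - PartitionBits)).toInt
  def address(packed: Long): Long = packed & AddressMask
}
// Because the partition id occupies the high bits, sorting an Array[Long] of packed values
// orders records by partition, which is all the shuffle write path needs.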
@@ -35,7 +90,22 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager shuffleId: Int, numMaps: Int, dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { - new BaseShuffleHandle(shuffleId, numMaps, dependency) + if (SortShuffleWriter.shouldBypassMergeSort(SparkEnv.get.conf, dependency)) { + // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't + // need map-side aggregation, then write numPartitions files directly and just concatenate + // them at the end. This avoids doing serialization and deserialization twice to merge + // together the spilled files, which would happen with the normal code path. The downside is + // having multiple files open at a time and thus more memory allocated to buffers. + new BypassMergeSortShuffleHandle[K, V]( + shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]]) + } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) { + // Otherwise, try to buffer map outputs in a serialized form, since this is more efficient: + new SerializedShuffleHandle[K, V]( + shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]]) + } else { + // Otherwise, buffer map outputs in a deserialized form: + new BaseShuffleHandle(shuffleId, numMaps, dependency) + } } /** @@ -47,38 +117,113 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager startPartition: Int, endPartition: Int, context: TaskContext): ShuffleReader[K, C] = { - // We currently use the same block store shuffle fetcher as the hash-based shuffle. - new HashShuffleReader( + new BlockStoreShuffleReader( handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context) } /** Get a writer for a given partition. Called on executors by map tasks. */ - override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext) - : ShuffleWriter[K, V] = { - val baseShuffleHandle = handle.asInstanceOf[BaseShuffleHandle[K, V, _]] - shuffleMapNumber.putIfAbsent(baseShuffleHandle.shuffleId, baseShuffleHandle.numMaps) - new SortShuffleWriter( - shuffleBlockResolver, baseShuffleHandle, mapId, context) + override def getWriter[K, V]( + handle: ShuffleHandle, + mapId: Int, + context: TaskContext): ShuffleWriter[K, V] = { + numMapsForShuffle.putIfAbsent( + handle.shuffleId, handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps) + val env = SparkEnv.get + handle match { + case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] => + new UnsafeShuffleWriter( + env.blockManager, + shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], + context.taskMemoryManager(), + unsafeShuffleHandle, + mapId, + context, + env.conf) + case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] => + new BypassMergeSortShuffleWriter( + env.blockManager, + shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], + bypassMergeSortHandle, + mapId, + context, + env.conf) + case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => + new SortShuffleWriter(shuffleBlockResolver, other, mapId, context) + } } /** Remove a shuffle's metadata from the ShuffleManager. 
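// Sketch of the "write per-partition files, then concatenate" idea behind the bypass-merge
// path described above. This is not BypassMergeSortShuffleWriter itself; the helper name and
// signature are hypothetical. The recorded lengths are what lets reducers later seek to their
// contiguous region of the single output file. (A real implementation would loop until
// transferTo has copied every byte.)
import java.io.{File, FileInputStream, FileOutputStream}

def concatenatePartitionFiles(partitionFiles: Seq[File], output: File): Array[Long] = {
  val lengths = new Array[Long](partitionFiles.length)
  val out = new FileOutputStream(output)
  try {
    partitionFiles.zipWithIndex.foreach { case (f, i) =>
      if (f.exists()) {
        val in = new FileInputStream(f)
        try {
          lengths(i) = in.getChannel.transferTo(0, f.length(), out.getChannel)
        } finally {
          in.close()
        }
      }
    }
  } finally {
    out.close()
  }
  lengths
}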
*/ override def unregisterShuffle(shuffleId: Int): Boolean = { - if (shuffleMapNumber.containsKey(shuffleId)) { - val numMaps = shuffleMapNumber.remove(shuffleId) - (0 until numMaps).map{ mapId => + Option(numMapsForShuffle.remove(shuffleId)).foreach { numMaps => + (0 until numMaps).foreach { mapId => shuffleBlockResolver.removeDataByMap(shuffleId, mapId) } } true } - override val shuffleBlockResolver: IndexShuffleBlockResolver = { - indexShuffleBlockResolver - } - /** Shut down this ShuffleManager. */ override def stop(): Unit = { shuffleBlockResolver.stop() } } + +private[spark] object SortShuffleManager extends Logging { + + /** + * The maximum number of shuffle output partitions that SortShuffleManager supports when + * buffering map outputs in a serialized form. This is an extreme defensive programming measure, + * since it's extremely unlikely that a single shuffle produces over 16 million output partitions. + * */ + val MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE = + PackedRecordPointer.MAXIMUM_PARTITION_ID + 1 + + /** + * Helper method for determining whether a shuffle should use an optimized serialized shuffle + * path or whether it should fall back to the original path that operates on deserialized objects. + */ + def canUseSerializedShuffle(dependency: ShuffleDependency[_, _, _]): Boolean = { + val shufId = dependency.shuffleId + val numPartitions = dependency.partitioner.numPartitions + val serializer = Serializer.getSerializer(dependency.serializer) + if (!serializer.supportsRelocationOfSerializedObjects) { + log.debug(s"Can't use serialized shuffle for shuffle $shufId because the serializer, " + + s"${serializer.getClass.getName}, does not support object relocation") + false + } else if (dependency.aggregator.isDefined) { + log.debug( + s"Can't use serialized shuffle for shuffle $shufId because an aggregator is defined") + false + } else if (numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE) { + log.debug(s"Can't use serialized shuffle for shuffle $shufId because it has more than " + + s"$MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE partitions") + false + } else { + log.debug(s"Can use serialized shuffle for shuffle $shufId") + true + } + } +} + +/** + * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the + * serialized shuffle. + */ +private[spark] class SerializedShuffleHandle[K, V]( + shuffleId: Int, + numMaps: Int, + dependency: ShuffleDependency[K, V, V]) + extends BaseShuffleHandle(shuffleId, numMaps, dependency) { +} + +/** + * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the + * bypass merge sort shuffle path. 
+ */ +private[spark] class BypassMergeSortShuffleHandle[K, V]( + shuffleId: Int, + numMaps: Int, + dependency: ShuffleDependency[K, V, V]) + extends BaseShuffleHandle(shuffleId, numMaps, dependency) { +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 5865e7640c1c..f83cf8859e58 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -20,9 +20,9 @@ package org.apache.spark.shuffle.sort import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.scheduler.MapStatus -import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle} +import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleWriter} import org.apache.spark.storage.ShuffleBlockId +import org.apache.spark.util.Utils import org.apache.spark.util.collection.ExternalSorter private[spark] class SortShuffleWriter[K, V, C]( @@ -36,7 +36,7 @@ private[spark] class SortShuffleWriter[K, V, C]( private val blockManager = SparkEnv.get.blockManager - private var sorter: SortShuffleFileWriter[K, V] = null + private var sorter: ExternalSorter[K, V, _] = null // Are we in the process of stopping? Because map tasks can call stop() with success = true // and then call stop() with success = false if they get an exception, we want to make sure @@ -53,33 +53,24 @@ private[spark] class SortShuffleWriter[K, V, C]( sorter = if (dep.mapSideCombine) { require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!") new ExternalSorter[K, V, C]( - dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer) - } else if (SortShuffleWriter.shouldBypassMergeSort( - SparkEnv.get.conf, dep.partitioner.numPartitions, aggregator = None, keyOrdering = None)) { - // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't - // need local aggregation and sorting, write numPartitions files directly and just concatenate - // them at the end. This avoids doing serialization and deserialization twice to merge - // together the spilled files, which would happen with the normal code path. The downside is - // having multiple files open at a time and thus more memory allocated to buffers. - new BypassMergeSortShuffleWriter[K, V](SparkEnv.get.conf, blockManager, dep.partitioner, - writeMetrics, Serializer.getSerializer(dep.serializer)) + context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer) } else { // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't // care whether the keys get sorted in each partition; that will be done on the reduce side // if the operation being run is sortByKey. new ExternalSorter[K, V, V]( - aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer) + context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer) } sorter.insertAll(records) // Don't bother including the time to open the merged output file in the shuffle write time, // because it just opens a single file, so is typically too fast to measure accurately // (see SPARK-3570). 
- val outputFile = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId) + val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId) + val tmp = Utils.tempFileWith(output) val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) - val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile) - shuffleBlockResolver.writeIndexFile(dep.shuffleId, mapId, partitionLengths) - + val partitionLengths = sorter.writePartitionedFile(blockId, tmp) + shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp) mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) } @@ -111,12 +102,14 @@ private[spark] class SortShuffleWriter[K, V, C]( } private[spark] object SortShuffleWriter { - def shouldBypassMergeSort( - conf: SparkConf, - numPartitions: Int, - aggregator: Option[Aggregator[_, _, _]], - keyOrdering: Option[Ordering[_]]): Boolean = { - val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) - numPartitions <= bypassMergeThreshold && aggregator.isEmpty && keyOrdering.isEmpty + def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = { + // We cannot bypass sorting if we need to do map-side aggregation. + if (dep.mapSideCombine) { + require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!") + false + } else { + val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) + dep.partitioner.numPartitions <= bypassMergeThreshold + } } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala deleted file mode 100644 index df7bbd64247d..000000000000 --- a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala +++ /dev/null @@ -1,202 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.unsafe - -import java.util.Collections -import java.util.concurrent.ConcurrentHashMap - -import org.apache.spark._ -import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle._ -import org.apache.spark.shuffle.sort.SortShuffleManager - -/** - * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the new shuffle. - */ -private[spark] class UnsafeShuffleHandle[K, V]( - shuffleId: Int, - numMaps: Int, - dependency: ShuffleDependency[K, V, V]) - extends BaseShuffleHandle(shuffleId, numMaps, dependency) { -} - -private[spark] object UnsafeShuffleManager extends Logging { - - /** - * The maximum number of shuffle output partitions that UnsafeShuffleManager supports. 
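// Generic sketch of the write-to-temporary-file-then-commit pattern that the SortShuffleWriter
// change above adopts via Utils.tempFileWith and writeIndexFileAndCommit. The helper below is
// hypothetical (not a Spark API): data is written to a temp file in the same directory, and an
// atomic rename makes it visible under the final name, so readers never see a partial output.
import java.io.File
import java.nio.file.{Files, StandardCopyOption}

def writeAndCommit(finalFile: File)(write: File => Unit): Unit = {
  val tmp = File.createTempFile(finalFile.getName, ".tmp", finalFile.getParentFile)
  try {
    write(tmp)
    Files.move(tmp.toPath, finalFile.toPath,
      StandardCopyOption.ATOMIC_MOVE, StandardCopyOption.REPLACE_EXISTING)
  } finally {
    if (tmp.exists()) {
      tmp.delete() // clean up if the write or the rename failed
    }
  }
}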
- */ - val MAX_SHUFFLE_OUTPUT_PARTITIONS = PackedRecordPointer.MAXIMUM_PARTITION_ID + 1 - - /** - * Helper method for determining whether a shuffle should use the optimized unsafe shuffle - * path or whether it should fall back to the original sort-based shuffle. - */ - def canUseUnsafeShuffle[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = { - val shufId = dependency.shuffleId - val serializer = Serializer.getSerializer(dependency.serializer) - if (!serializer.supportsRelocationOfSerializedObjects) { - log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because the serializer, " + - s"${serializer.getClass.getName}, does not support object relocation") - false - } else if (dependency.aggregator.isDefined) { - log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because an aggregator is defined") - false - } else if (dependency.partitioner.numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS) { - log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because it has more than " + - s"$MAX_SHUFFLE_OUTPUT_PARTITIONS partitions") - false - } else { - log.debug(s"Can use UnsafeShuffle for shuffle $shufId") - true - } - } -} - -/** - * A shuffle implementation that uses directly-managed memory to implement several performance - * optimizations for certain types of shuffles. In cases where the new performance optimizations - * cannot be applied, this shuffle manager delegates to [[SortShuffleManager]] to handle those - * shuffles. - * - * UnsafeShuffleManager's optimizations will apply when _all_ of the following conditions hold: - * - * - The shuffle dependency specifies no aggregation or output ordering. - * - The shuffle serializer supports relocation of serialized values (this is currently supported - * by KryoSerializer and Spark SQL's custom serializers). - * - The shuffle produces fewer than 16777216 output partitions. - * - No individual record is larger than 128 MB when serialized. - * - * In addition, extra spill-merging optimizations are automatically applied when the shuffle - * compression codec supports concatenation of serialized streams. This is currently supported by - * Spark's LZF serializer. - * - * At a high-level, UnsafeShuffleManager's design is similar to Spark's existing SortShuffleManager. - * In sort-based shuffle, incoming records are sorted according to their target partition ids, then - * written to a single map output file. Reducers fetch contiguous regions of this file in order to - * read their portion of the map output. In cases where the map output data is too large to fit in - * memory, sorted subsets of the output can are spilled to disk and those on-disk files are merged - * to produce the final output file. - * - * UnsafeShuffleManager optimizes this process in several ways: - * - * - Its sort operates on serialized binary data rather than Java objects, which reduces memory - * consumption and GC overheads. This optimization requires the record serializer to have certain - * properties to allow serialized records to be re-ordered without requiring deserialization. - * See SPARK-4550, where this optimization was first proposed and implemented, for more details. - * - * - It uses a specialized cache-efficient sorter ([[UnsafeShuffleExternalSorter]]) that sorts - * arrays of compressed record pointers and partition ids. By using only 8 bytes of space per - * record in the sorting array, this fits more of the array into cache. 
- * - * - The spill merging procedure operates on blocks of serialized records that belong to the same - * partition and does not need to deserialize records during the merge. - * - * - When the spill compression codec supports concatenation of compressed data, the spill merge - * simply concatenates the serialized and compressed spill partitions to produce the final output - * partition. This allows efficient data copying methods, like NIO's `transferTo`, to be used - * and avoids the need to allocate decompression or copying buffers during the merge. - * - * For more details on UnsafeShuffleManager's design, see SPARK-7081. - */ -private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { - - if (!conf.getBoolean("spark.shuffle.spill", true)) { - logWarning( - "spark.shuffle.spill was set to false, but this is ignored by the tungsten-sort shuffle " + - "manager; its optimized shuffles will continue to spill to disk when necessary.") - } - - private[this] val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf) - private[this] val shufflesThatFellBackToSortShuffle = - Collections.newSetFromMap(new ConcurrentHashMap[Int, java.lang.Boolean]()) - private[this] val numMapsForShufflesThatUsedNewPath = new ConcurrentHashMap[Int, Int]() - - /** - * Register a shuffle with the manager and obtain a handle for it to pass to tasks. - */ - override def registerShuffle[K, V, C]( - shuffleId: Int, - numMaps: Int, - dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { - if (UnsafeShuffleManager.canUseUnsafeShuffle(dependency)) { - new UnsafeShuffleHandle[K, V]( - shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]]) - } else { - new BaseShuffleHandle(shuffleId, numMaps, dependency) - } - } - - /** - * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive). - * Called on executors by reduce tasks. - */ - override def getReader[K, C]( - handle: ShuffleHandle, - startPartition: Int, - endPartition: Int, - context: TaskContext): ShuffleReader[K, C] = { - sortShuffleManager.getReader(handle, startPartition, endPartition, context) - } - - /** Get a writer for a given partition. Called on executors by map tasks. */ - override def getWriter[K, V]( - handle: ShuffleHandle, - mapId: Int, - context: TaskContext): ShuffleWriter[K, V] = { - handle match { - case unsafeShuffleHandle: UnsafeShuffleHandle[K, V] => - numMapsForShufflesThatUsedNewPath.putIfAbsent(handle.shuffleId, unsafeShuffleHandle.numMaps) - val env = SparkEnv.get - new UnsafeShuffleWriter( - env.blockManager, - shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], - context.taskMemoryManager(), - env.shuffleMemoryManager, - unsafeShuffleHandle, - mapId, - context, - env.conf) - case other => - shufflesThatFellBackToSortShuffle.add(handle.shuffleId) - sortShuffleManager.getWriter(handle, mapId, context) - } - } - - /** Remove a shuffle's metadata from the ShuffleManager. */ - override def unregisterShuffle(shuffleId: Int): Boolean = { - if (shufflesThatFellBackToSortShuffle.remove(shuffleId)) { - sortShuffleManager.unregisterShuffle(shuffleId) - } else { - Option(numMapsForShufflesThatUsedNewPath.remove(shuffleId)).foreach { numMaps => - (0 until numMaps).foreach { mapId => - shuffleBlockResolver.removeDataByMap(shuffleId, mapId) - } - } - true - } - } - - override val shuffleBlockResolver: IndexShuffleBlockResolver = { - sortShuffleManager.shuffleBlockResolver - } - - /** Shut down this ShuffleManager. 
*/ - override def stop(): Unit = { - sortShuffleManager.stop() - } -} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala index 390c136df79b..31b4dd7c0f42 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala @@ -17,8 +17,8 @@ package org.apache.spark.status.api.v1 import java.util.{Arrays, Date, List => JList} -import javax.ws.rs.{GET, PathParam, Produces, QueryParam} import javax.ws.rs.core.MediaType +import javax.ws.rs.{GET, Produces, QueryParam} import org.apache.spark.executor.{InputMetrics => InternalInputMetrics, OutputMetrics => InternalOutputMetrics, ShuffleReadMetrics => InternalShuffleReadMetrics, ShuffleWriteMetrics => InternalShuffleWriteMetrics, TaskMetrics => InternalTaskMetrics} import org.apache.spark.scheduler.{AccumulableInfo => InternalAccumulableInfo, StageInfo} @@ -59,6 +59,15 @@ private[v1] object AllStagesResource { stageUiData: StageUIData, includeDetails: Boolean): StageData = { + val taskLaunchTimes = stageUiData.taskData.values.map(_.taskInfo.launchTime).filter(_ > 0) + + val firstTaskLaunchedTime: Option[Date] = + if (taskLaunchTimes.nonEmpty) { + Some(new Date(taskLaunchTimes.min)) + } else { + None + } + val taskData = if (includeDetails) { Some(stageUiData.taskData.map { case (k, v) => k -> convertTaskData(v) } ) } else { @@ -92,6 +101,9 @@ private[v1] object AllStagesResource { numCompleteTasks = stageUiData.numCompleteTasks, numFailedTasks = stageUiData.numFailedTasks, executorRunTime = stageUiData.executorRunTime, + submissionTime = stageInfo.submissionTime.map(new Date(_)), + firstTaskLaunchedTime, + completionTime = stageInfo.completionTime.map(new Date(_)), inputBytes = stageUiData.inputBytes, inputRecords = stageUiData.inputRecords, outputBytes = stageUiData.outputBytes, @@ -127,7 +139,7 @@ private[v1] object AllStagesResource { new TaskData( taskId = uiData.taskInfo.taskId, index = uiData.taskInfo.index, - attempt = uiData.taskInfo.attempt, + attempt = uiData.taskInfo.attemptNumber, launchTime = new Date(uiData.taskInfo.launchTime), executorId = uiData.taskInfo.executorId, host = uiData.taskInfo.host, diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala index 17b521f3e1d4..0fc0fb59d861 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala @@ -62,6 +62,10 @@ private[spark] object ApplicationsListResource { new ApplicationInfo( id = app.id, name = app.name, + coresGranted = None, + maxCores = None, + coresPerExecutor = None, + memoryPerExecutorMB = None, attempts = app.attempts.map { internalAttemptInfo => new ApplicationAttemptInfo( attemptId = internalAttemptInfo.attemptId, @@ -81,6 +85,10 @@ private[spark] object ApplicationsListResource { new ApplicationInfo( id = internal.id, name = internal.desc.name, + coresGranted = Some(internal.coresGranted), + maxCores = internal.desc.maxCores, + coresPerExecutor = internal.desc.coresPerExecutor, + memoryPerExecutorMB = Some(internal.desc.memoryPerExecutorMB), attempts = Seq(new ApplicationAttemptInfo( attemptId = None, startTime = new Date(internal.startTime), diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala 
b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 2bec64f2ef02..5feb1dc2e5b7 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -25,6 +25,10 @@ import org.apache.spark.JobExecutionStatus class ApplicationInfo private[spark]( val id: String, val name: String, + val coresGranted: Option[Int], + val maxCores: Option[Int], + val coresPerExecutor: Option[Int], + val memoryPerExecutorMB: Option[Int], val attempts: Seq[ApplicationAttemptInfo]) class ApplicationAttemptInfo private[spark]( @@ -116,6 +120,9 @@ class StageData private[spark]( val numFailedTasks: Int, val executorRunTime: Long, + val submissionTime: Option[Date], + val firstTaskLaunchedTime: Option[Date], + val completionTime: Option[Date], val inputBytes: Long, val inputRecords: Long, diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetchException.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetchException.scala new file mode 100644 index 000000000000..f6e46ae9a481 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockFetchException.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.storage + +import org.apache.spark.SparkException + +private[spark] +case class BlockFetchException(messages: String, throwable: Throwable) + extends SparkException(messages, throwable) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 86493673d958..6074fc58d70d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -21,24 +21,25 @@ import java.io._ import java.nio.{ByteBuffer, MappedByteBuffer} import scala.collection.mutable.{ArrayBuffer, HashMap} -import scala.concurrent.{ExecutionContext, Await, Future} import scala.concurrent.duration._ +import scala.concurrent.{Await, ExecutionContext, Future} import scala.util.Random +import scala.util.control.NonFatal import sun.nio.ch.DirectBuffer import org.apache.spark._ import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics} import org.apache.spark.io.CompressionCodec +import org.apache.spark.memory.MemoryManager import org.apache.spark.network._ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.ExternalShuffleClient import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo import org.apache.spark.rpc.RpcEnv -import org.apache.spark.serializer.{SerializerInstance, Serializer} +import org.apache.spark.serializer.{Serializer, SerializerInstance} import org.apache.spark.shuffle.ShuffleManager -import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.util._ private[spark] sealed trait BlockValues @@ -63,8 +64,8 @@ private[spark] class BlockManager( rpcEnv: RpcEnv, val master: BlockManagerMaster, defaultSerializer: Serializer, - maxMemory: Long, val conf: SparkConf, + memoryManager: MemoryManager, mapOutputTracker: MapOutputTracker, shuffleManager: ShuffleManager, blockTransferService: BlockTransferService, @@ -81,28 +82,35 @@ private[spark] class BlockManager( // Actual storage of where blocks are kept private var externalBlockStoreInitialized = false - private[spark] val memoryStore = new MemoryStore(this, maxMemory) + private[spark] val memoryStore = new MemoryStore(this, memoryManager) private[spark] val diskStore = new DiskStore(this, diskBlockManager) private[spark] lazy val externalBlockStore: ExternalBlockStore = { externalBlockStoreInitialized = true new ExternalBlockStore(this, executorId) } + memoryManager.setMemoryStore(memoryStore) + + // Note: depending on the memory manager, `maxStorageMemory` may actually vary over time. + // However, since we use this only for reporting and logging, what we actually want here is + // the absolute maximum value that `maxStorageMemory` can ever possibly reach. We may need + // to revisit whether reporting this value as the "max" is intuitive to the user. + private val maxMemory = memoryManager.maxStorageMemory private[spark] val externalShuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) // Port used by the external shuffle service. In Yarn mode, this may be already be // set through the Hadoop configuration as the server is launched in the Yarn NM. - private val externalShuffleServicePort = - Utils.getSparkOrYarnConfig(conf, "spark.shuffle.service.port", "7337").toInt - - // Check that we're not using external shuffle service with consolidated shuffle files. 
- if (externalShuffleServiceEnabled - && conf.getBoolean("spark.shuffle.consolidateFiles", false) - && shuffleManager.isInstanceOf[HashShuffleManager]) { - throw new UnsupportedOperationException("Cannot use external shuffle service with consolidated" - + " shuffle files in hash-based shuffle. Please disable spark.shuffle.consolidateFiles or " - + " switch to sort-based shuffle.") + private val externalShuffleServicePort = { + val tmpPort = Utils.getSparkOrYarnConfig(conf, "spark.shuffle.service.port", "7337").toInt + if (tmpPort == 0) { + // for testing, we set "spark.shuffle.service.port" to 0 in the yarn config, so yarn finds + // an open port. But we still need to tell our spark apps the right port to use. So + // only if the yarn config has the port set to 0, we prefer the value in the spark config + conf.get("spark.shuffle.service.port").toInt + } else { + tmpPort + } } var blockManagerId: BlockManagerId = _ @@ -114,7 +122,7 @@ private[spark] class BlockManager( // Client to read other executors' shuffle files. This is either an external service, or just the // standard BlockTransferService to directly connect to other Executors. private[spark] val shuffleClient = if (externalShuffleServiceEnabled) { - val transConf = SparkTransportConf.fromSparkConf(conf, numUsableCores) + val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores) new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled(), securityManager.isSaslEncryptionEnabled()) } else { @@ -156,24 +164,6 @@ private[spark] class BlockManager( * loaded yet. */ private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf) - /** - * Construct a BlockManager with a memory limit set based on system properties. - */ - def this( - execId: String, - rpcEnv: RpcEnv, - master: BlockManagerMaster, - serializer: Serializer, - conf: SparkConf, - mapOutputTracker: MapOutputTracker, - shuffleManager: ShuffleManager, - blockTransferService: BlockTransferService, - securityManager: SecurityManager, - numUsableCores: Int) = { - this(execId, rpcEnv, master, serializer, BlockManager.getMaxMemory(conf), - conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager, numUsableCores) - } - /** * Initializes the BlockManager with the given appId. 
This is not performed in the constructor as * the appId may not be known at BlockManager instantiation time (in particular for the driver, @@ -191,6 +181,7 @@ private[spark] class BlockManager( executorId, blockTransferService.hostName, blockTransferService.port) shuffleServerId = if (externalShuffleServiceEnabled) { + logInfo(s"external shuffle service port = $externalShuffleServicePort") BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort) } else { blockManagerId @@ -209,7 +200,7 @@ private[spark] class BlockManager( val shuffleConfig = new ExecutorShuffleInfo( diskBlockManager.localDirs.map(_.toString), diskBlockManager.subDirsPerLocalDir, - shuffleManager.getClass.getName) + shuffleManager.shortName) val MAX_ATTEMPTS = 3 val SLEEP_TIME_SECS = 5 @@ -222,7 +213,7 @@ private[spark] class BlockManager( return } catch { case e: Exception if i < MAX_ATTEMPTS => - logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}}" + logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}" + s" more times after waiting $SLEEP_TIME_SECS seconds...", e) Thread.sleep(SLEEP_TIME_SECS * 1000) } @@ -590,10 +581,26 @@ private[spark] class BlockManager( private def doGetRemote(blockId: BlockId, asBlockResult: Boolean): Option[Any] = { require(blockId != null, "BlockId is null") val locations = Random.shuffle(master.getLocations(blockId)) + var numFetchFailures = 0 for (loc <- locations) { logDebug(s"Getting remote block $blockId from $loc") - val data = blockTransferService.fetchBlockSync( - loc.host, loc.port, loc.executorId, blockId.toString).nioByteBuffer() + val data = try { + blockTransferService.fetchBlockSync( + loc.host, loc.port, loc.executorId, blockId.toString).nioByteBuffer() + } catch { + case NonFatal(e) => + numFetchFailures += 1 + if (numFetchFailures == locations.size) { + // An exception is thrown while fetching this block from all locations + throw new BlockFetchException(s"Failed to fetch block from" + + s" ${locations.size} locations. Most recent failure cause:", e) + } else { + // This location failed, so we retry fetch from a different one by returning null here + logWarning(s"Failed to fetch remote block $blockId " + + s"from $loc (failed attempt $numFetchFailures)", e) + null + } + } if (data != null) { if (asBlockResult) { @@ -651,8 +658,8 @@ private[spark] class BlockManager( writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = { val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _) val syncWrites = conf.getBoolean("spark.shuffle.sync", false) - new DiskBlockObjectWriter(blockId, file, serializerInstance, bufferSize, compressStream, - syncWrites, writeMetrics) + new DiskBlockObjectWriter(file, serializerInstance, bufferSize, compressStream, + syncWrites, writeMetrics, blockId) } /** @@ -1183,45 +1190,38 @@ private[spark] class BlockManager( def dataSerializeStream( blockId: BlockId, outputStream: OutputStream, - values: Iterator[Any], - serializer: Serializer = defaultSerializer): Unit = { + values: Iterator[Any]): Unit = { val byteStream = new BufferedOutputStream(outputStream) - val ser = serializer.newInstance() + val ser = defaultSerializer.newInstance() ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() } /** Serializes into a byte buffer. 
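// Generic sketch of the fetch-with-failover behaviour added to doGetRemote above: try each
// candidate location in turn, keep going on failure, and only surface an error once every
// location has failed. The helper name and the use of println are hypothetical; Spark logs
// through Logging and wraps the final failure in BlockFetchException.
import scala.util.control.NonFatal

def fetchFromAnyLocation[L, T](locations: Seq[L])(fetch: L => T): T = {
  var lastError: Throwable = null
  val it = locations.iterator
  while (it.hasNext) {
    val loc = it.next()
    try {
      return fetch(loc)
    } catch {
      case NonFatal(e) =>
        lastError = e
        println(s"fetch from $loc failed, trying next location")
    }
  }
  throw new RuntimeException(s"failed to fetch from all ${locations.size} locations", lastError)
}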
*/ - def dataSerialize( - blockId: BlockId, - values: Iterator[Any], - serializer: Serializer = defaultSerializer): ByteBuffer = { - val byteStream = new ByteArrayOutputStream(4096) - dataSerializeStream(blockId, byteStream, values, serializer) - ByteBuffer.wrap(byteStream.toByteArray) + def dataSerialize(blockId: BlockId, values: Iterator[Any]): ByteBuffer = { + val byteStream = new ByteBufferOutputStream(4096) + dataSerializeStream(blockId, byteStream, values) + byteStream.toByteBuffer } /** * Deserializes a ByteBuffer into an iterator of values and disposes of it when the end of * the iterator is reached. */ - def dataDeserialize( - blockId: BlockId, - bytes: ByteBuffer, - serializer: Serializer = defaultSerializer): Iterator[Any] = { + def dataDeserialize(blockId: BlockId, bytes: ByteBuffer): Iterator[Any] = { bytes.rewind() - dataDeserializeStream(blockId, new ByteBufferInputStream(bytes, true), serializer) + dataDeserializeStream(blockId, new ByteBufferInputStream(bytes, true)) } /** * Deserializes a InputStream into an iterator of values and disposes of it when the end of * the iterator is reached. */ - def dataDeserializeStream( - blockId: BlockId, - inputStream: InputStream, - serializer: Serializer = defaultSerializer): Iterator[Any] = { + def dataDeserializeStream(blockId: BlockId, inputStream: InputStream): Iterator[Any] = { val stream = new BufferedInputStream(inputStream) - serializer.newInstance().deserializeStream(wrapForCompression(blockId, stream)).asIterator + defaultSerializer + .newInstance() + .deserializeStream(wrapForCompression(blockId, stream)) + .asIterator } def stop(): Unit = { @@ -1249,13 +1249,6 @@ private[spark] class BlockManager( private[spark] object BlockManager extends Logging { private val ID_GENERATOR = new IdGenerator - /** Return the total amount of storage memory available. */ - private def getMaxMemory(conf: SparkConf): Long = { - val memoryFraction = conf.getDouble("spark.storage.memoryFraction", 0.6) - val safetyFraction = conf.getDouble("spark.storage.safetyFraction", 0.9) - (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong - } - /** * Attempt to clean up a ByteBuffer if it is memory-mapped. 
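// Worked example of the legacy pool sizing that is being removed in this patch, using the
// defaults that appear in the deleted code (storage: 0.6 * 0.9 of the heap; shuffle, from the
// deleted ShuffleMemoryManager earlier in this diff: 0.2 * 0.8 of the heap). With a
// hypothetical 4 GiB executor heap this reserved roughly 2.16 GiB for storage and 0.64 GiB for
// shuffle; the MemoryManager now passed into BlockManager takes over this bookkeeping.
val heapBytes = 4L * 1024 * 1024 * 1024
val legacyStorageBytes = (heapBytes * 0.6 * 0.9).toLong // ~2.16 GiB
val legacyShuffleBytes = (heapBytes * 0.2 * 0.8).toLong // ~0.64 GiB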
This uses an *unsafe* Sun API that * might cause errors if one attempts to read from the unmapped buffer, but it's better than diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index f70f701494db..440c4c18aadd 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -69,8 +69,9 @@ class BlockManagerMaster( } /** Get locations of multiple blockIds from the driver */ - def getLocations(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = { - driverEndpoint.askWithRetry[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) + def getLocations(blockIds: Array[BlockId]): IndexedSeq[Seq[BlockManagerId]] = { + driverEndpoint.askWithRetry[IndexedSeq[Seq[BlockManagerId]]]( + GetLocationsMultipleBlockIds(blockIds)) } /** @@ -86,8 +87,8 @@ class BlockManagerMaster( driverEndpoint.askWithRetry[Seq[BlockManagerId]](GetPeers(blockManagerId)) } - def getRpcHostPortForExecutor(executorId: String): Option[(String, Int)] = { - driverEndpoint.askWithRetry[Option[(String, Int)]](GetRpcHostPortForExecutor(executorId)) + def getExecutorEndpointRef(executorId: String): Option[RpcEndpointRef] = { + driverEndpoint.askWithRetry[Option[RpcEndpointRef]](GetExecutorEndpointRef(executorId)) } /** @@ -103,7 +104,7 @@ class BlockManagerMaster( val future = driverEndpoint.askWithRetry[Future[Seq[Int]]](RemoveRdd(rddId)) future.onFailure { case e: Exception => - logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}}", e) + logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}", e) }(ThreadUtils.sameThread) if (blocking) { timeout.awaitResult(future) @@ -115,7 +116,7 @@ class BlockManagerMaster( val future = driverEndpoint.askWithRetry[Future[Seq[Boolean]]](RemoveShuffle(shuffleId)) future.onFailure { case e: Exception => - logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}}", e) + logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}", e) }(ThreadUtils.sameThread) if (blocking) { timeout.awaitResult(future) @@ -129,7 +130,7 @@ class BlockManagerMaster( future.onFailure { case e: Exception => logWarning(s"Failed to remove broadcast $broadcastId" + - s" with removeFromMaster = $removeFromMaster - ${e.getMessage}}", e) + s" with removeFromMaster = $removeFromMaster - ${e.getMessage}", e) }(ThreadUtils.sameThread) if (blocking) { timeout.awaitResult(future) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 5dc0c537cbb6..41892b4ffce5 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -19,9 +19,8 @@ package org.apache.spark.storage import java.util.{HashMap => JHashMap} -import scala.collection.immutable.HashSet import scala.collection.mutable -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.concurrent.{ExecutionContext, Future} import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, ThreadSafeRpcEndpoint} @@ -75,8 +74,8 @@ class BlockManagerMasterEndpoint( case GetPeers(blockManagerId) => context.reply(getPeers(blockManagerId)) - case GetRpcHostPortForExecutor(executorId) => - context.reply(getRpcHostPortForExecutor(executorId)) + case 
GetExecutorEndpointRef(executorId) => + context.reply(getExecutorEndpointRef(executorId)) case GetMemoryStatus => context.reply(memoryStatus) @@ -133,7 +132,7 @@ class BlockManagerMasterEndpoint( // Find all blocks for the given RDD, remove the block from both blockLocations and // the blockManagerInfo that is tracking the blocks. - val blocks = blockLocations.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) + val blocks = blockLocations.asScala.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) blocks.foreach { blockId => val bms: mutable.HashSet[BlockManagerId] = blockLocations.get(blockId) bms.foreach(bm => blockManagerInfo.get(bm).foreach(_.removeBlock(blockId))) @@ -242,7 +241,7 @@ class BlockManagerMasterEndpoint( private def storageStatus: Array[StorageStatus] = { blockManagerInfo.map { case (blockManagerId, info) => - new StorageStatus(blockManagerId, info.maxMem, info.blocks) + new StorageStatus(blockManagerId, info.maxMem, info.blocks.asScala) }.toArray } @@ -292,7 +291,7 @@ class BlockManagerMasterEndpoint( if (askSlaves) { info.slaveEndpoint.ask[Seq[BlockId]](getMatchingBlockIds) } else { - Future { info.blocks.keys.filter(filter).toSeq } + Future { info.blocks.asScala.keys.filter(filter).toSeq } } future } @@ -372,7 +371,8 @@ class BlockManagerMasterEndpoint( if (blockLocations.containsKey(blockId)) blockLocations.get(blockId).toSeq else Seq.empty } - private def getLocationsMultipleBlockIds(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = { + private def getLocationsMultipleBlockIds( + blockIds: Array[BlockId]): IndexedSeq[Seq[BlockManagerId]] = { blockIds.map(blockId => getLocations(blockId)) } @@ -387,15 +387,14 @@ class BlockManagerMasterEndpoint( } /** - * Returns the hostname and port of an executor, based on the [[RpcEnv]] address of its - * [[BlockManagerSlaveEndpoint]]. + * Returns an [[RpcEndpointRef]] of the [[BlockManagerSlaveEndpoint]] for sending RPC messages. */ - private def getRpcHostPortForExecutor(executorId: String): Option[(String, Int)] = { + private def getExecutorEndpointRef(executorId: String): Option[RpcEndpointRef] = { for ( blockManagerId <- blockManagerIdByExecutor.get(executorId); info <- blockManagerInfo.get(blockManagerId) ) yield { - (info.slaveEndpoint.address.host, info.slaveEndpoint.address.port) + info.slaveEndpoint } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 376e9eb48843..f392a4a0cd9b 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -42,6 +42,11 @@ private[spark] object BlockManagerMessages { case class RemoveBroadcast(broadcastId: Long, removeFromDriver: Boolean = true) extends ToBlockManagerSlave + /** + * Driver -> Executor message to trigger a thread dump. + */ + case object TriggerThreadDump extends ToBlockManagerSlave + ////////////////////////////////////////////////////////////////////////////////// // Messages from slaves to the master. 
////////////////////////////////////////////////////////////////////////////////// @@ -90,7 +95,7 @@ private[spark] object BlockManagerMessages { case class GetPeers(blockManagerId: BlockManagerId) extends ToBlockManagerMaster - case class GetRpcHostPortForExecutor(executorId: String) extends ToBlockManagerMaster + case class GetExecutorEndpointRef(executorId: String) extends ToBlockManagerMaster case class RemoveExecutor(execId: String) extends ToBlockManagerMaster diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala index 7478ab0fc2f7..9eca902f7454 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala @@ -19,10 +19,10 @@ package org.apache.spark.storage import scala.concurrent.{ExecutionContext, Future} -import org.apache.spark.rpc.{RpcEnv, RpcCallContext, RpcEndpoint} -import org.apache.spark.util.ThreadUtils import org.apache.spark.{Logging, MapOutputTracker, SparkEnv} +import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.storage.BlockManagerMessages._ +import org.apache.spark.util.{ThreadUtils, Utils} /** * An RpcEndpoint to take commands from the master to execute options. For example, @@ -33,7 +33,7 @@ class BlockManagerSlaveEndpoint( override val rpcEnv: RpcEnv, blockManager: BlockManager, mapOutputTracker: MapOutputTracker) - extends RpcEndpoint with Logging { + extends ThreadSafeRpcEndpoint with Logging { private val asyncThreadPool = ThreadUtils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool") @@ -70,6 +70,9 @@ class BlockManagerSlaveEndpoint( case GetMatchingBlockIds(filter, _) => context.reply(blockManager.getMatchingBlockIds(filter)) + + case TriggerThreadDump => + context.reply(Utils.getThreadDump()) } private def doAsync[T](actionMessage: String, context: RpcCallContext)(body: => T) { @@ -80,7 +83,7 @@ class BlockManagerSlaveEndpoint( future.onSuccess { case response => logDebug("Done " + actionMessage + ", response is " + response) context.reply(response) - logDebug("Sent response: " + response + " to " + context.sender) + logDebug("Sent response: " + response + " to " + context.senderAddress) } future.onFailure { case t: Throwable => logError("Error in " + actionMessage, t) diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 5f537692a16c..f7e84a2c2e14 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -22,7 +22,7 @@ import java.io.{IOException, File} import org.apache.spark.{SparkConf, Logging} import org.apache.spark.executor.ExecutorExitCode -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} /** * Creates and maintains the logical mapping between logical blocks and physical on-disk @@ -133,7 +133,6 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon Utils.getConfiguredLocalDirs(conf).flatMap { rootDir => try { val localDir = Utils.createDirectory(rootDir, "blockmgr") - Utils.chmod700(localDir) logInfo(s"Created local directory at $localDir") Some(localDir) } catch { @@ -145,7 +144,7 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon } private def 
addShutdownHook(): AnyRef = { - Utils.addShutdownHook(Utils.TEMP_DIR_SHUTDOWN_PRIORITY + 1) { () => + ShutdownHookManager.addShutdownHook(ShutdownHookManager.TEMP_DIR_SHUTDOWN_PRIORITY + 1) { () => logInfo("Shutdown hook called") DiskBlockManager.this.doStop() } @@ -155,7 +154,7 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon private[spark] def stop() { // Remove the shutdown hook. It causes memory leaks if we leave it around. try { - Utils.removeShutdownHook(shutdownHook) + ShutdownHookManager.removeShutdownHook(shutdownHook) } catch { case e: Exception => logError(s"Exception while removing shutdown hook.", e) @@ -165,11 +164,15 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon private def doStop(): Unit = { // Only perform cleanup if an external service is not serving our shuffle files. - if (!blockManager.externalShuffleServiceEnabled || blockManager.blockManagerId.isDriver) { + // Also blockManagerId could be null if block manager is not initialized properly. + if (!blockManager.externalShuffleServiceEnabled || + (blockManager.blockManagerId != null && blockManager.blockManagerId.isDriver)) { localDirs.foreach { localDir => if (localDir.isDirectory() && localDir.exists()) { try { - if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir) + if (!ShutdownHookManager.hasRootAsShutdownDeleteDir(localDir)) { + Utils.deleteRecursively(localDir) + } } catch { case e: Exception => logError(s"Exception while deleting local spark dir: $localDir", e) diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index 49d9154f95a5..e2dd80f24393 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -34,15 +34,15 @@ import org.apache.spark.util.Utils * reopened again. */ private[spark] class DiskBlockObjectWriter( - val blockId: BlockId, - file: File, + val file: File, serializerInstance: SerializerInstance, bufferSize: Int, compressStream: OutputStream => OutputStream, syncWrites: Boolean, // These write metrics concurrently shared with other active DiskBlockObjectWriters who // are themselves performing writes. All updates must be relative. - writeMetrics: ShuffleWriteMetrics) + writeMetrics: ShuffleWriteMetrics, + val blockId: BlockId = null) extends OutputStream with Logging { @@ -144,8 +144,10 @@ private[spark] class DiskBlockObjectWriter( * Reverts writes that haven't been flushed yet. Callers should invoke this function * when there are runtime exceptions. This method will not throw, though it may be * unsuccessful in truncating written data. + * + * @return the file that this DiskBlockObjectWriter wrote to. */ - def revertPartialWritesAndClose() { + def revertPartialWritesAndClose(): File = { // Discard current writes. We do this by flushing the outstanding writes and then // truncating the file to its initial position. 
try { @@ -160,12 +162,14 @@ private[spark] class DiskBlockObjectWriter( val truncateStream = new FileOutputStream(file, true) try { truncateStream.getChannel.truncate(initialPosition) + file } finally { truncateStream.close() } } catch { case e: Exception => logError("Uncaught exception while reverting partial writes to file " + file, e) + file } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 1f4595628216..6c4477184d5b 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -86,7 +86,9 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc } catch { case e: Throwable => if (file.exists()) { - file.delete() + if (!file.delete()) { + logWarning(s"Error deleting ${file}") + } } throw e } @@ -142,23 +144,14 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc getBytes(blockId).map(buffer => blockManager.dataDeserialize(blockId, buffer)) } - /** - * A version of getValues that allows a custom serializer. This is used as part of the - * shuffle short-circuit code. - */ - def getValues(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = { - // TODO: Should bypass getBytes and use a stream based implementation, so that - // we won't use a lot of memory during e.g. external sort merge. - getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes, serializer)) - } - override def remove(blockId: BlockId): Boolean = { val file = diskManager.getFile(blockId.name) - // If consolidation mode is used With HashShuffleMananger, the physical filename for the block - // is different from blockId.name. So the file returns here will not be exist, thus we avoid to - // delete the whole consolidated file by mistake. if (file.exists()) { - file.delete() + val ret = file.delete() + if (!ret) { + logWarning(s"Error deleting ${file.getPath()}") + } + ret } else { false } diff --git a/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala b/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala index db965d54bafd..94883a54a74e 100644 --- a/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala @@ -22,7 +22,7 @@ import java.nio.ByteBuffer import scala.util.control.NonFatal import org.apache.spark.Logging -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} /** @@ -177,15 +177,6 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId: } } - private def addShutdownHook() { - Runtime.getRuntime.addShutdownHook(new Thread("ExternalBlockStore shutdown hook") { - override def run(): Unit = Utils.logUncaughtExceptions { - logDebug("Shutdown hook called") - externalBlockManager.map(_.shutdown()) - } - }) - } - // Create concrete block manager and fall back to Tachyon by default for backward compatibility. 
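// Simplified, standalone sketch of the revert-partial-writes idea from the
// DiskBlockObjectWriter change above (helper name is hypothetical): remember the file length
// before appending, and on failure truncate back to that position so previously committed
// bytes are preserved.
import java.io.{File, FileOutputStream}

def appendOrRevert(file: File)(append: FileOutputStream => Unit): Unit = {
  val initialPosition = file.length()
  val out = new FileOutputStream(file, true) // append mode
  try {
    append(out)
  } catch {
    case e: Exception =>
      out.getChannel.truncate(initialPosition) // drop the partially written suffix
      throw e
  } finally {
    out.close()
  }
}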
private def createBlkManager(): Option[ExternalBlockManager] = { val clsName = blockManager.conf.getOption(ExternalBlockStore.BLOCK_MANAGER_NAME) @@ -196,7 +187,10 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId: .newInstance() .asInstanceOf[ExternalBlockManager] instance.init(blockManager, executorId) - addShutdownHook(); + ShutdownHookManager.addShutdownHook { () => + logDebug("Shutdown hook called") + externalBlockManager.map(_.shutdown()) + } Some(instance) } catch { case NonFatal(t) => diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 6f27f00307f8..bdab8c2332fa 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -24,6 +24,7 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.TaskContext +import org.apache.spark.memory.MemoryManager import org.apache.spark.util.{SizeEstimator, Utils} import org.apache.spark.util.collection.SizeTrackingVector @@ -33,19 +34,17 @@ private case class MemoryEntry(value: Any, size: Long, deserialized: Boolean) * Stores blocks in memory, either as Arrays of deserialized Java objects or as * serialized ByteBuffers. */ -private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) +private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: MemoryManager) extends BlockStore(blockManager) { + // Note: all changes to memory allocations, notably putting blocks, evicting blocks, and + // acquiring or releasing unroll memory, must be synchronized on `memoryManager`! + private val conf = blockManager.conf private val entries = new LinkedHashMap[BlockId, MemoryEntry](32, 0.75f, true) - @volatile private var currentMemory = 0L - - // Ensure only one thread is putting, and if necessary, dropping blocks at any given time - private val accountingLock = new Object - // A mapping from taskAttemptId to amount of memory used for unrolling a block (in bytes) - // All accesses of this map are assumed to have manually synchronized on `accountingLock` + // All accesses of this map are assumed to have manually synchronized on `memoryManager` private val unrollMemoryMap = mutable.HashMap[Long, Long]() // Same as `unrollMemoryMap`, but for pending unroll memory as defined below. // Pending unroll memory refers to the intermediate memory occupied by a task @@ -56,19 +55,13 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // memory (SPARK-4777). private val pendingUnrollMemoryMap = mutable.HashMap[Long, Long]() - /** - * The amount of space ensured for unrolling values in memory, shared across all cores. - * This space is not reserved in advance, but allocated dynamically by dropping existing blocks. - */ - private val maxUnrollMemory: Long = { - val unrollFraction = conf.getDouble("spark.storage.unrollFraction", 0.2) - (maxMemory * unrollFraction).toLong - } - // Initial memory to request before unrolling any block private val unrollMemoryThreshold: Long = conf.getLong("spark.storage.unrollMemoryThreshold", 1024 * 1024) + /** Total amount of memory available for storage, in bytes. 
*/ + private def maxMemory: Long = memoryManager.maxStorageMemory + if (maxMemory < unrollMemoryThreshold) { logWarning(s"Max memory ${Utils.bytesToString(maxMemory)} is less than the initial memory " + s"threshold ${Utils.bytesToString(unrollMemoryThreshold)} needed to store a block in " + @@ -77,8 +70,16 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) logInfo("MemoryStore started with capacity %s".format(Utils.bytesToString(maxMemory))) - /** Free memory not occupied by existing blocks. Note that this does not include unroll memory. */ - def freeMemory: Long = maxMemory - currentMemory + /** Total storage memory used including unroll memory, in bytes. */ + private def memoryUsed: Long = memoryManager.storageMemoryUsed + + /** + * Amount of storage memory, in bytes, used for caching blocks. + * This does not include memory used for unrolling. + */ + private def blocksMemoryUsed: Long = memoryManager.synchronized { + memoryUsed - currentUnrollMemory + } override def getSize(blockId: BlockId): Long = { entries.synchronized { @@ -94,8 +95,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) val values = blockManager.dataDeserialize(blockId, bytes) putIterator(blockId, values, level, returnValues = true) } else { - val putAttempt = tryToPut(blockId, bytes, bytes.limit, deserialized = false) - PutResult(bytes.limit(), Right(bytes.duplicate()), putAttempt.droppedBlocks) + val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] + tryToPut(blockId, bytes, bytes.limit, deserialized = false, droppedBlocks) + PutResult(bytes.limit(), Right(bytes.duplicate()), droppedBlocks) } } @@ -108,15 +110,16 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) def putBytes(blockId: BlockId, size: Long, _bytes: () => ByteBuffer): PutResult = { // Work on a duplicate - since the original input might be used elsewhere. 
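For reference, the storage accounting that replaces the old currentMemory counter reduces to a little arithmetic performed under a shared MemoryManager lock. The sketch below is only an approximation: ToyMemoryManager is an invented stand-in for org.apache.spark.memory.MemoryManager, kept just to show how maxMemory, memoryUsed and blocksMemoryUsed relate.

// Invented stand-in for MemoryManager; only the members the derived metrics need.
class ToyMemoryManager(val maxStorageMemory: Long) {
  private var used = 0L
  def storageMemoryUsed: Long = synchronized { used }
  def acquireStorageMemory(numBytes: Long): Boolean = synchronized {
    if (used + numBytes <= maxStorageMemory) { used += numBytes; true } else false
  }
  def releaseStorageMemory(numBytes: Long): Unit = synchronized {
    used = math.max(0L, used - numBytes)
  }
}

class ToyStore(memoryManager: ToyMemoryManager) {
  // Unroll bookkeeping, guarded by the same lock as the manager, per the hunk's comment.
  private var unrollMemory = 0L

  private def maxMemory: Long = memoryManager.maxStorageMemory
  private def memoryUsed: Long = memoryManager.storageMemoryUsed

  // Memory held by cached blocks is whatever storage memory is not reserved for unrolling.
  private def blocksMemoryUsed: Long = memoryManager.synchronized {
    memoryUsed - unrollMemory
  }
}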
lazy val bytes = _bytes().duplicate().rewind().asInstanceOf[ByteBuffer] - val putAttempt = tryToPut(blockId, () => bytes, size, deserialized = false) + val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] + val putSuccess = tryToPut(blockId, () => bytes, size, deserialized = false, droppedBlocks) val data = - if (putAttempt.success) { + if (putSuccess) { assert(bytes.limit == size) Right(bytes.duplicate()) } else { null } - PutResult(size, data, putAttempt.droppedBlocks) + PutResult(size, data, droppedBlocks) } override def putArray( @@ -124,14 +127,15 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) values: Array[Any], level: StorageLevel, returnValues: Boolean): PutResult = { + val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] if (level.deserialized) { val sizeEstimate = SizeEstimator.estimate(values.asInstanceOf[AnyRef]) - val putAttempt = tryToPut(blockId, values, sizeEstimate, deserialized = true) - PutResult(sizeEstimate, Left(values.iterator), putAttempt.droppedBlocks) + tryToPut(blockId, values, sizeEstimate, deserialized = true, droppedBlocks) + PutResult(sizeEstimate, Left(values.iterator), droppedBlocks) } else { val bytes = blockManager.dataSerialize(blockId, values.iterator) - val putAttempt = tryToPut(blockId, bytes, bytes.limit, deserialized = false) - PutResult(bytes.limit(), Right(bytes.duplicate()), putAttempt.droppedBlocks) + tryToPut(blockId, bytes, bytes.limit, deserialized = false, droppedBlocks) + PutResult(bytes.limit(), Right(bytes.duplicate()), droppedBlocks) } } @@ -208,24 +212,25 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } - override def remove(blockId: BlockId): Boolean = { - entries.synchronized { - val entry = entries.remove(blockId) - if (entry != null) { - currentMemory -= entry.size - logDebug(s"Block $blockId of size ${entry.size} dropped from memory (free $freeMemory)") - true - } else { - false - } + override def remove(blockId: BlockId): Boolean = memoryManager.synchronized { + val entry = entries.synchronized { entries.remove(blockId) } + if (entry != null) { + memoryManager.releaseStorageMemory(entry.size) + logDebug(s"Block $blockId of size ${entry.size} dropped " + + s"from memory (free ${maxMemory - blocksMemoryUsed})") + true + } else { + false } } - override def clear() { + override def clear(): Unit = memoryManager.synchronized { entries.synchronized { entries.clear() - currentMemory = 0 } + unrollMemoryMap.clear() + pendingUnrollMemoryMap.clear() + memoryManager.releaseAllStorageMemory() logInfo("MemoryStore cleared") } @@ -265,7 +270,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) var vector = new SizeTrackingVector[Any] // Request enough memory to begin unrolling - keepUnrolling = reserveUnrollMemoryForThisTask(initialMemoryThreshold) + keepUnrolling = reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold, droppedBlocks) if (!keepUnrolling) { logWarning(s"Failed to reserve initial memory threshold of " + @@ -281,20 +286,8 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) val currentSize = vector.estimateSize() if (currentSize >= memoryThreshold) { val amountToRequest = (currentSize * memoryGrowthFactor - memoryThreshold).toLong - // Hold the accounting lock, in case another thread concurrently puts a block that - // takes up the unrolling space we just ensured here - accountingLock.synchronized { - if (!reserveUnrollMemoryForThisTask(amountToRequest)) { - // If the first request 
is not granted, try again after ensuring free space - // If there is still not enough space, give up and drop the partition - val spaceToEnsure = maxUnrollMemory - currentUnrollMemory - if (spaceToEnsure > 0) { - val result = ensureFreeSpace(blockId, spaceToEnsure) - droppedBlocks ++= result.droppedBlocks - } - keepUnrolling = reserveUnrollMemoryForThisTask(amountToRequest) - } - } + keepUnrolling = reserveUnrollMemoryForThisTask( + blockId, amountToRequest, droppedBlocks) // New threshold is currentSize * memoryGrowthFactor memoryThreshold += amountToRequest } @@ -312,16 +305,23 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } finally { - // If we return an array, the values returned will later be cached in `tryToPut`. - // In this case, we should release the memory after we cache the block there. - // Otherwise, if we return an iterator, we release the memory reserved here - // later when the task finishes. + // If we return an array, the values returned here will be cached in `tryToPut` later. + // In this case, we should release the memory only after we cache the block there. if (keepUnrolling) { - accountingLock.synchronized { - val amountToRelease = currentUnrollMemoryForThisTask - previousMemoryReserved - releaseUnrollMemoryForThisTask(amountToRelease) - reservePendingUnrollMemoryForThisTask(amountToRelease) + val taskAttemptId = currentTaskAttemptId() + memoryManager.synchronized { + // Since we continue to hold onto the array until we actually cache it, we cannot + // release the unroll memory yet. Instead, we transfer it to pending unroll memory + // so `tryToPut` can further transfer it to normal storage memory later. + // TODO: we can probably express this without pending unroll memory (SPARK-10907) + val amountToTransferToPending = currentUnrollMemoryForThisTask - previousMemoryReserved + unrollMemoryMap(taskAttemptId) -= amountToTransferToPending + pendingUnrollMemoryMap(taskAttemptId) = + pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) + amountToTransferToPending } + } else { + // Otherwise, if we return an iterator, we can only release the unroll memory when + // the task finishes since we don't know when the iterator will be consumed. } } } @@ -337,8 +337,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) blockId: BlockId, value: Any, size: Long, - deserialized: Boolean): ResultWithDroppedBlocks = { - tryToPut(blockId, () => value, size, deserialized) + deserialized: Boolean, + droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { + tryToPut(blockId, () => value, size, deserialized, droppedBlocks) } /** @@ -349,18 +350,21 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) * `value` will be lazily created. If it cannot be put into MemoryStore or disk, `value` won't be * created to avoid OOM since it may be a big ByteBuffer. * - * Synchronize on `accountingLock` to ensure that all the put requests and its associated block + * Synchronize on `memoryManager` to ensure that all the put requests and its associated block * dropping is done by only on thread at a time. Otherwise while one thread is dropping * blocks to free memory for one block, another thread may use up the freed space for * another block. * - * Return whether put was successful, along with the blocks dropped in the process. + * All blocks evicted in the process, if any, will be added to `droppedBlocks`. + * + * @return whether put was successful. 
*/ private def tryToPut( blockId: BlockId, value: () => Any, size: Long, - deserialized: Boolean): ResultWithDroppedBlocks = { + deserialized: Boolean, + droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { /* TODO: Its possible to optimize the locking by locking entries only when selecting blocks * to be dropped. Once the to-be-dropped blocks have been selected, and lock on entries has @@ -368,24 +372,24 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) * for freeing up more space for another block that needs to be put. Only then the actually * dropping of blocks (and writing to disk if necessary) can proceed in parallel. */ - var putSuccess = false - val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] - - accountingLock.synchronized { - val freeSpaceResult = ensureFreeSpace(blockId, size) - val enoughFreeSpace = freeSpaceResult.success - droppedBlocks ++= freeSpaceResult.droppedBlocks - - if (enoughFreeSpace) { + memoryManager.synchronized { + // Note: if we have previously unrolled this block successfully, then pending unroll + // memory should be non-zero. This is the amount that we already reserved during the + // unrolling process. In this case, we can just reuse this space to cache our block. + // The synchronization on `memoryManager` here guarantees that the release and acquire + // happen atomically. This relies on the assumption that all memory acquisitions are + // synchronized on the same lock. + releasePendingUnrollMemoryForThisTask() + val enoughMemory = memoryManager.acquireStorageMemory(blockId, size, droppedBlocks) + if (enoughMemory) { + // We acquired enough memory for the block, so go ahead and put it val entry = new MemoryEntry(value(), size, deserialized) entries.synchronized { entries.put(blockId, entry) - currentMemory += size } val valuesOrBytes = if (deserialized) "values" else "bytes" logInfo("Block %s stored as %s in memory (estimated size %s, free %s)".format( - blockId, valuesOrBytes, Utils.bytesToString(size), Utils.bytesToString(freeMemory))) - putSuccess = true + blockId, valuesOrBytes, Utils.bytesToString(size), Utils.bytesToString(blocksMemoryUsed))) } else { // Tell the block manager that we couldn't put it in memory so that it can drop it to // disk if the block allows disk storage. @@ -397,62 +401,46 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) val droppedBlockStatus = blockManager.dropFromMemory(blockId, () => data) droppedBlockStatus.foreach { status => droppedBlocks += ((blockId, status)) } } - // Release the unroll memory used because we no longer need the underlying Array - releasePendingUnrollMemoryForThisTask() + enoughMemory } - ResultWithDroppedBlocks(putSuccess, droppedBlocks) } /** - * Try to free up a given amount of space to store a particular block, but can fail if - * either the block is bigger than our memory or it would require replacing another block - * from the same RDD (which leads to a wasteful cyclic replacement pattern for RDDs that - * don't fit into memory that we want to avoid). - * - * Assume that `accountingLock` is held by the caller to ensure only one thread is dropping - * blocks. Otherwise, the freed space may fill up before the caller puts in their new value. - * - * Return whether there is enough free space, along with the blocks dropped in the process. 
- */ - private def ensureFreeSpace( - blockIdToAdd: BlockId, - space: Long): ResultWithDroppedBlocks = { - logInfo(s"ensureFreeSpace($space) called with curMem=$currentMemory, maxMem=$maxMemory") - - val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] - - if (space > maxMemory) { - logInfo(s"Will not store $blockIdToAdd as it is larger than our memory limit") - return ResultWithDroppedBlocks(success = false, droppedBlocks) - } - - // Take into account the amount of memory currently occupied by unrolling blocks - // and minus the pending unroll memory for that block on current thread. - val taskAttemptId = currentTaskAttemptId() - val actualFreeMemory = freeMemory - currentUnrollMemory + - pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) - - if (actualFreeMemory < space) { - val rddToAdd = getRddId(blockIdToAdd) + * Try to evict blocks to free up a given amount of space to store a particular block. + * Can fail if either the block is bigger than our memory or it would require replacing + * another block from the same RDD (which leads to a wasteful cyclic replacement pattern for + * RDDs that don't fit into memory that we want to avoid). + * + * @param blockId the ID of the block we are freeing space for, if any + * @param space the size of this block + * @param droppedBlocks a holder for blocks evicted in the process + * @return whether the requested free space is freed. + */ + private[spark] def evictBlocksToFreeSpace( + blockId: Option[BlockId], + space: Long, + droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { + assert(space > 0) + memoryManager.synchronized { + var freedMemory = 0L + val rddToAdd = blockId.flatMap(getRddId) val selectedBlocks = new ArrayBuffer[BlockId] - var selectedMemory = 0L - // This is synchronized to ensure that the set of entries is not changed // (because of getValue or getBytes) while traversing the iterator, as that // can lead to exceptions. entries.synchronized { val iterator = entries.entrySet().iterator() - while (actualFreeMemory + selectedMemory < space && iterator.hasNext) { + while (freedMemory < space && iterator.hasNext) { val pair = iterator.next() val blockId = pair.getKey if (rddToAdd.isEmpty || rddToAdd != getRddId(blockId)) { selectedBlocks += blockId - selectedMemory += pair.getValue.size + freedMemory += pair.getValue.size } } } - if (actualFreeMemory + selectedMemory >= space) { + if (freedMemory >= space) { logInfo(s"${selectedBlocks.size} blocks selected for dropping") for (blockId <- selectedBlocks) { val entry = entries.synchronized { entries.get(blockId) } @@ -469,14 +457,15 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) droppedBlockStatus.foreach { status => droppedBlocks += ((blockId, status)) } } } - return ResultWithDroppedBlocks(success = true, droppedBlocks) + true } else { - logInfo(s"Will not store $blockIdToAdd as it would require dropping another block " + - "from the same RDD") - return ResultWithDroppedBlocks(success = false, droppedBlocks) + blockId.foreach { id => + logInfo(s"Will not store $id as it would require dropping another block " + + "from the same RDD") + } + false } } - ResultWithDroppedBlocks(success = true, droppedBlocks) } override def contains(blockId: BlockId): Boolean = { @@ -489,17 +478,20 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) } /** - * Reserve additional memory for unrolling blocks used by this task. - * Return whether the request is granted. 
+ * Reserve memory for unrolling the given block for this task. + * @return whether the request is granted. */ - def reserveUnrollMemoryForThisTask(memory: Long): Boolean = { - accountingLock.synchronized { - val granted = freeMemory > currentUnrollMemory + memory - if (granted) { + def reserveUnrollMemoryForThisTask( + blockId: BlockId, + memory: Long, + droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { + memoryManager.synchronized { + val success = memoryManager.acquireUnrollMemory(blockId, memory, droppedBlocks) + if (success) { val taskAttemptId = currentTaskAttemptId() unrollMemoryMap(taskAttemptId) = unrollMemoryMap.getOrElse(taskAttemptId, 0L) + memory } - granted + success } } @@ -507,73 +499,68 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) * Release memory used by this task for unrolling blocks. * If the amount is not specified, remove the current task's allocation altogether. */ - def releaseUnrollMemoryForThisTask(memory: Long = -1L): Unit = { + def releaseUnrollMemoryForThisTask(memory: Long = Long.MaxValue): Unit = { val taskAttemptId = currentTaskAttemptId() - accountingLock.synchronized { - if (memory < 0) { - unrollMemoryMap.remove(taskAttemptId) - } else { - unrollMemoryMap(taskAttemptId) = unrollMemoryMap.getOrElse(taskAttemptId, memory) - memory - // If this task claims no more unroll memory, release it completely - if (unrollMemoryMap(taskAttemptId) <= 0) { - unrollMemoryMap.remove(taskAttemptId) + memoryManager.synchronized { + if (unrollMemoryMap.contains(taskAttemptId)) { + val memoryToRelease = math.min(memory, unrollMemoryMap(taskAttemptId)) + if (memoryToRelease > 0) { + unrollMemoryMap(taskAttemptId) -= memoryToRelease + if (unrollMemoryMap(taskAttemptId) == 0) { + unrollMemoryMap.remove(taskAttemptId) + } + memoryManager.releaseUnrollMemory(memoryToRelease) } } } } - /** - * Reserve the unroll memory of current unroll successful block used by this task - * until actually put the block into memory entry. - */ - def reservePendingUnrollMemoryForThisTask(memory: Long): Unit = { - val taskAttemptId = currentTaskAttemptId() - accountingLock.synchronized { - pendingUnrollMemoryMap(taskAttemptId) = - pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) + memory - } - } - /** * Release pending unroll memory of current unroll successful block used by this task */ - def releasePendingUnrollMemoryForThisTask(): Unit = { + def releasePendingUnrollMemoryForThisTask(memory: Long = Long.MaxValue): Unit = { val taskAttemptId = currentTaskAttemptId() - accountingLock.synchronized { - pendingUnrollMemoryMap.remove(taskAttemptId) + memoryManager.synchronized { + if (pendingUnrollMemoryMap.contains(taskAttemptId)) { + val memoryToRelease = math.min(memory, pendingUnrollMemoryMap(taskAttemptId)) + if (memoryToRelease > 0) { + pendingUnrollMemoryMap(taskAttemptId) -= memoryToRelease + if (pendingUnrollMemoryMap(taskAttemptId) == 0) { + pendingUnrollMemoryMap.remove(taskAttemptId) + } + memoryManager.releaseUnrollMemory(memoryToRelease) + } + } } } /** * Return the amount of memory currently occupied for unrolling blocks across all tasks. */ - def currentUnrollMemory: Long = accountingLock.synchronized { + def currentUnrollMemory: Long = memoryManager.synchronized { unrollMemoryMap.values.sum + pendingUnrollMemoryMap.values.sum } /** * Return the amount of memory currently occupied for unrolling blocks by this task. 
*/ - def currentUnrollMemoryForThisTask: Long = accountingLock.synchronized { + def currentUnrollMemoryForThisTask: Long = memoryManager.synchronized { unrollMemoryMap.getOrElse(currentTaskAttemptId(), 0L) } /** * Return the number of tasks currently unrolling blocks. */ - def numTasksUnrolling: Int = accountingLock.synchronized { unrollMemoryMap.keys.size } + private def numTasksUnrolling: Int = memoryManager.synchronized { unrollMemoryMap.keys.size } /** * Log information about current memory usage. */ - def logMemoryUsage(): Unit = { - val blocksMemory = currentMemory - val unrollMemory = currentUnrollMemory - val totalMemory = blocksMemory + unrollMemory + private def logMemoryUsage(): Unit = { logInfo( - s"Memory use = ${Utils.bytesToString(blocksMemory)} (blocks) + " + - s"${Utils.bytesToString(unrollMemory)} (scratch space shared across " + - s"$numTasksUnrolling tasks(s)) = ${Utils.bytesToString(totalMemory)}. " + + s"Memory use = ${Utils.bytesToString(blocksMemoryUsed)} (blocks) + " + + s"${Utils.bytesToString(currentUnrollMemory)} (scratch space shared across " + + s"$numTasksUnrolling tasks(s)) = ${Utils.bytesToString(memoryUsed)}. " + s"Storage limit = ${Utils.bytesToString(maxMemory)}." ) } @@ -584,7 +571,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) * @param blockId ID of the block we are trying to unroll. * @param finalVectorSize Final size of the vector before unrolling failed. */ - def logUnrollFailureMessage(blockId: BlockId, finalVectorSize: Long): Unit = { + private def logUnrollFailureMessage(blockId: BlockId, finalVectorSize: Long): Unit = { logWarning( s"Not enough space to cache $blockId in memory! " + s"(computed ${Utils.bytesToString(finalVectorSize)} so far)" @@ -592,7 +579,3 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) logMemoryUsage() } } - -private[spark] case class ResultWithDroppedBlocks( - success: Boolean, - droppedBlocks: Seq[(BlockId, BlockStatus)]) diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala index 96062626b504..94e8559bd2e9 100644 --- a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala @@ -19,7 +19,7 @@ package org.apache.spark.storage import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.{RDDOperationScope, RDD} -import org.apache.spark.util.Utils +import org.apache.spark.util.{CallSite, Utils} @DeveloperApi class RDDInfo( @@ -28,6 +28,7 @@ class RDDInfo( val numPartitions: Int, var storageLevel: StorageLevel, val parentIds: Seq[Int], + val callSite: String = "", val scope: Option[RDDOperationScope] = None) extends Ordered[RDDInfo] { @@ -56,6 +57,7 @@ private[spark] object RDDInfo { def fromRdd(rdd: RDD[_]): RDDInfo = { val rddName = Option(rdd.name).getOrElse(Utils.getFormattedClassName(rdd)) val parentIds = rdd.dependencies.map(_.rdd.id) - new RDDInfo(rdd.id, rddName, rdd.partitions.length, rdd.getStorageLevel, parentIds, rdd.scope) + new RDDInfo(rdd.id, rddName, rdd.partitions.length, + rdd.getStorageLevel, parentIds, rdd.creationSite.shortForm, rdd.scope) } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index a759ceb96ec1..0d0448feb5b0 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ 
b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -260,10 +260,7 @@ final class ShuffleBlockFetcherIterator( fetchRequests ++= Utils.randomize(remoteRequests) // Send out initial requests for blocks, up to our maxBytesInFlight - while (fetchRequests.nonEmpty && - (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { - sendRequest(fetchRequests.dequeue()) - } + fetchUpToMaxBytes() val numFetches = remoteRequests.size - fetchRequests.size logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime)) @@ -296,10 +293,7 @@ final class ShuffleBlockFetcherIterator( case _ => } // Send fetch requests up to maxBytesInFlight - while (fetchRequests.nonEmpty && - (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { - sendRequest(fetchRequests.dequeue()) - } + fetchUpToMaxBytes() result match { case FailureFetchResult(blockId, address, e) => @@ -315,6 +309,14 @@ final class ShuffleBlockFetcherIterator( } } + private def fetchUpToMaxBytes(): Unit = { + // Send fetch requests up to maxBytesInFlight + while (fetchRequests.nonEmpty && + (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { + sendRequest(fetchRequests.dequeue()) + } + } + private def throwFetchFailedException(blockId: BlockId, address: BlockManagerId, e: Throwable) = { blockId match { case ShuffleBlockId(shufId, mapId, reduceId) => diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala index b53c86e89a27..d14fe4613528 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala @@ -27,11 +27,12 @@ import scala.util.control.NonFatal import com.google.common.io.ByteStreams import tachyon.client.{ReadType, WriteType, TachyonFS, TachyonFile} +import tachyon.conf.TachyonConf import tachyon.TachyonURI -import org.apache.spark.{SparkException, SparkConf, Logging} +import org.apache.spark.Logging import org.apache.spark.executor.ExecutorExitCode -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} /** @@ -60,7 +61,11 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log rootDirs = s"$storeDir/$appFolderName/$executorId" master = blockManager.conf.get(ExternalBlockStore.MASTER_URL, "tachyon://localhost:19998") - client = if (master != null && master != "") TachyonFS.get(new TachyonURI(master)) else null + client = if (master != null && master != "") { + TachyonFS.get(new TachyonURI(master), new TachyonConf()) + } else { + null + } // original implementation call System.exit, we change it to run without extblkstore support if (client == null) { logError("Failed to connect to the Tachyon as the master address is not configured") @@ -75,7 +80,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log // in order to avoid having really large inodes at the top level in Tachyon. 
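The ShuffleBlockFetcherIterator change above deduplicates the request-throttling loop into fetchUpToMaxBytes(). A self-contained sketch of that loop, with a simplified FetchRequest and no actual network call (both invented for the example); the guard mirrors the patch: one request is always allowed even if it alone exceeds maxBytesInFlight.

import scala.collection.mutable

case class FetchRequest(size: Long)

class FetchThrottler(maxBytesInFlight: Long) {
  private val fetchRequests = new mutable.Queue[FetchRequest]()
  private var bytesInFlight = 0L

  def enqueue(requests: Seq[FetchRequest]): Unit = fetchRequests ++= requests

  private def sendRequest(req: FetchRequest): Unit = {
    bytesInFlight += req.size
    // issue the remote fetch here; bytesInFlight is decremented when the result arrives
  }

  // Send fetch requests up to maxBytesInFlight.
  def fetchUpToMaxBytes(): Unit = {
    while (fetchRequests.nonEmpty &&
        (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
      sendRequest(fetchRequests.dequeue())
    }
  }
}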
tachyonDirs = createTachyonDirs() subDirs = Array.fill(tachyonDirs.length)(new Array[TachyonFile](subDirsPerTachyonDir)) - tachyonDirs.foreach(tachyonDir => Utils.registerShutdownDeleteDir(tachyonDir)) + tachyonDirs.foreach(tachyonDir => ShutdownHookManager.registerShutdownDeleteDir(tachyonDir)) } override def toString: String = {"ExternalBlockStore-Tachyon"} @@ -98,7 +103,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log val file = getFile(blockId) val os = file.getOutStream(WriteType.TRY_CACHE) try { - os.write(bytes.array()) + Utils.writeByteBuffer(bytes, os) } catch { case NonFatal(e) => logWarning(s"Failed to put bytes of block $blockId into Tachyon", e) @@ -235,7 +240,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log logDebug("Shutdown hook called") tachyonDirs.foreach { tachyonDir => try { - if (!Utils.hasRootAsShutdownDeleteDir(tachyonDir)) { + if (!ShutdownHookManager.hasRootAsShutdownDeleteDir(tachyonDir)) { Utils.deleteRecursively(tachyonDir, client) } } catch { diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index c8356467fab8..b796a44fe01a 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -59,7 +59,17 @@ private[spark] object JettyUtils extends Logging { def createServlet[T <% AnyRef]( servletParams: ServletParams[T], - securityMgr: SecurityManager): HttpServlet = { + securityMgr: SecurityManager, + conf: SparkConf): HttpServlet = { + + // SPARK-10589 avoid frame-related click-jacking vulnerability, using X-Frame-Options + // (see http://tools.ietf.org/html/rfc7034). By default allow framing only from the + // same origin, but allow framing for a specific named URI. 
+ // Example: spark.ui.allowFramingFrom = https://example.com/ + val allowFramingFrom = conf.getOption("spark.ui.allowFramingFrom") + val xFrameOptionsValue = + allowFramingFrom.map(uri => s"ALLOW-FROM $uri").getOrElse("SAMEORIGIN") + new HttpServlet { override def doGet(request: HttpServletRequest, response: HttpServletResponse) { try { @@ -68,6 +78,7 @@ private[spark] object JettyUtils extends Logging { response.setStatus(HttpServletResponse.SC_OK) val result = servletParams.responder(request) response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") + response.setHeader("X-Frame-Options", xFrameOptionsValue) // scalastyle:off println response.getWriter.println(servletParams.extractFn(result)) // scalastyle:on println @@ -97,8 +108,9 @@ private[spark] object JettyUtils extends Logging { path: String, servletParams: ServletParams[T], securityMgr: SecurityManager, + conf: SparkConf, basePath: String = ""): ServletContextHandler = { - createServletHandler(path, createServlet(servletParams, securityMgr), basePath) + createServletHandler(path, createServlet(servletParams, securityMgr, conf), basePath) } /** Create a context handler that responds to a request with the given path prefix */ @@ -106,7 +118,11 @@ private[spark] object JettyUtils extends Logging { path: String, servlet: HttpServlet, basePath: String): ServletContextHandler = { - val prefixedPath = attachPrefix(basePath, path) + val prefixedPath = if (basePath == "" && path == "/") { + path + } else { + (basePath + path).stripSuffix("/") + } val contextHandler = new ServletContextHandler val holder = new ServletHolder(servlet) contextHandler.setContextPath(prefixedPath) @@ -121,7 +137,7 @@ private[spark] object JettyUtils extends Logging { beforeRedirect: HttpServletRequest => Unit = x => (), basePath: String = "", httpMethods: Set[String] = Set("GET")): ServletContextHandler = { - val prefixedDestPath = attachPrefix(basePath, destPath) + val prefixedDestPath = basePath + destPath val servlet = new HttpServlet { override def doGet(request: HttpServletRequest, response: HttpServletResponse): Unit = { if (httpMethods.contains("GET")) { @@ -246,11 +262,6 @@ private[spark] object JettyUtils extends Logging { val (server, boundPort) = Utils.startServiceOnPort[Server](port, connect, conf, serverName) ServerInfo(server, boundPort, collection) } - - /** Attach a prefix to the given path, but avoid returning an empty path */ - private def attachPrefix(basePath: String, relativePath: String): String = { - if (basePath == "") relativePath else (basePath + relativePath).stripSuffix("/") - } } private[spark] case class ServerInfo( diff --git a/core/src/main/scala/org/apache/spark/ui/PagedTable.scala b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala index 17d7b39c2d95..6e2375477a68 100644 --- a/core/src/main/scala/org/apache/spark/ui/PagedTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala @@ -159,9 +159,9 @@ private[ui] trait PagedTable[T] { // "goButtonJsFuncName" val formJs = s"""$$(function(){ - | $$( "#form-task-page" ).submit(function(event) { - | var page = $$("#form-task-page-no").val() - | var pageSize = $$("#form-task-page-size").val() + | $$( "#form-$tableId-page" ).submit(function(event) { + | var page = $$("#form-$tableId-page-no").val() + | var pageSize = $$("#form-$tableId-page-size").val() | pageSize = pageSize ? pageSize: 100; | if (page != "") { | ${goButtonJsFuncName}(page, pageSize); @@ -173,12 +173,14 @@ private[ui] trait PagedTable[T] {
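The JettyUtils hunk above adds a click-jacking guard: every UI servlet now emits an X-Frame-Options header, defaulting to SAMEORIGIN unless spark.ui.allowFramingFrom names a trusted URI. A stripped-down sketch of just that header logic, using only the javax.servlet API (the FramingAwareServlet name and the fixed response body are illustrative):

import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}

class FramingAwareServlet(allowFramingFrom: Option[String]) extends HttpServlet {

  // SAMEORIGIN by default; ALLOW-FROM <uri> when a specific origin is allowed to frame the UI.
  private val xFrameOptionsValue =
    allowFramingFrom.map(uri => s"ALLOW-FROM $uri").getOrElse("SAMEORIGIN")

  override def doGet(request: HttpServletRequest, response: HttpServletResponse): Unit = {
    response.setStatus(HttpServletResponse.SC_OK)
    response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate")
    response.setHeader("X-Frame-Options", xFrameOptionsValue)
    response.getWriter.println("ok")
  }
}

With the example setting from the comment above, spark.ui.allowFramingFrom=https://example.com/ would produce the header X-Frame-Options: ALLOW-FROM https://example.com/.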
    diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 3788916cf39b..8da6884a3853 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -17,10 +17,13 @@ package org.apache.spark.ui -import java.util.Date +import java.util.{Date, ServiceLoader} + +import scala.collection.JavaConverters._ import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationAttemptInfo, ApplicationInfo, UIRoot} +import org.apache.spark.util.Utils import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext} import org.apache.spark.scheduler._ import org.apache.spark.storage.StorageStatusListener @@ -56,6 +59,8 @@ private[spark] class SparkUI private ( val stagesTab = new StagesTab(this) + var appId: String = _ + /** Initialize all components of the server. */ def initialize() { attachTab(new JobsTab(this)) @@ -64,20 +69,19 @@ private[spark] class SparkUI private ( attachTab(new EnvironmentTab(this)) attachTab(new ExecutorsTab(this)) attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) - attachHandler(createRedirectHandler("/", "/jobs", basePath = basePath)) + attachHandler(createRedirectHandler("/", "/jobs/", basePath = basePath)) attachHandler(ApiRootResource.getServletHandler(this)) // This should be POST only, but, the YARN AM proxy won't proxy POSTs attachHandler(createRedirectHandler( - "/stages/stage/kill", "/stages", stagesTab.handleKillRequest, + "/stages/stage/kill", "/stages/", stagesTab.handleKillRequest, httpMethods = Set("GET", "POST"))) } initialize() def getAppName: String = appName - /** Set the app name for this UI. */ - def setAppName(name: String) { - appName = name + def setAppId(id: String): Unit = { + appId = id } /** Stop the server behind this web interface. Only valid after bind(). 
*/ @@ -94,13 +98,17 @@ private[spark] class SparkUI private ( private[spark] def appUIAddress = s"http://$appUIHostPort" def getSparkUI(appId: String): Option[SparkUI] = { - if (appId == appName) Some(this) else None + if (appId == this.appId) Some(this) else None } def getApplicationInfoList: Iterator[ApplicationInfo] = { Iterator(new ApplicationInfo( - id = appName, + id = appId, name = appName, + coresGranted = None, + maxCores = None, + coresPerExecutor = None, + memoryPerExecutorMB = None, attempts = Seq(new ApplicationAttemptInfo( attemptId = None, startTime = new Date(startTime), @@ -149,7 +157,16 @@ private[spark] object SparkUI { appName: String, basePath: String, startTime: Long): SparkUI = { - create(None, conf, listenerBus, securityManager, appName, basePath, startTime = startTime) + val sparkUI = create( + None, conf, listenerBus, securityManager, appName, basePath, startTime = startTime) + + val listenerFactories = ServiceLoader.load(classOf[SparkHistoryListenerFactory], + Utils.getContextOrSparkClassLoader).asScala + listenerFactories.foreach { listenerFactory => + val listeners = listenerFactory.createListeners(conf, sparkUI) + listeners.foreach(listenerBus.addListener) + } + sparkUI } /** diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala index e2d25e36365f..cb122eaed83d 100644 --- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala +++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala @@ -62,6 +62,13 @@ private[spark] object ToolTips { """Time that the executor spent paused for Java garbage collection while the task was running.""" + val PEAK_EXECUTION_MEMORY = + """Execution memory refers to the memory used by internal data structures created during + shuffles, aggregations and joins when Tungsten is enabled. The value of this accumulator + should be approximately the sum of the peak sizes across all such data structures created + in this task. For SQL jobs, this only tracks all unsafe operators, broadcast joins, and + external sort.""" + val JOB_TIMELINE = """Shows when jobs started and ended and when executors joined or left. Drag to scroll. Click Enable Zooming and use mouse wheel to zoom in/out.""" diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 718aea7e1dc2..81a6f07ec836 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -18,9 +18,11 @@ package org.apache.spark.ui import java.text.SimpleDateFormat -import java.util.{Locale, Date} +import java.util.{Date, Locale} -import scala.xml.{Node, Text, Unparsed} +import scala.util.control.NonFatal +import scala.xml._ +import scala.xml.transform.{RewriteRule, RuleTransformer} import org.apache.spark.Logging import org.apache.spark.ui.scope.RDDOperationGraph @@ -29,6 +31,7 @@ import org.apache.spark.ui.scope.RDDOperationGraph private[spark] object UIUtils extends Logging { val TABLE_CLASS_NOT_STRIPED = "table table-bordered table-condensed" val TABLE_CLASS_STRIPED = TABLE_CLASS_NOT_STRIPED + " table-striped" + val TABLE_CLASS_STRIPED_SORTABLE = TABLE_CLASS_STRIPED + " sortable" // SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use. 
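SparkUI.createHistoryUI above now discovers additional listeners through java.util.ServiceLoader, so other modules can contribute history-UI listeners without a compile-time dependency. A rough sketch of that discovery pattern; HistoryListenerFactory and the AnyRef listener type are simplified stand-ins for SparkHistoryListenerFactory and SparkListener:

import java.util.ServiceLoader
import scala.collection.JavaConverters._

trait HistoryListenerFactory {
  def createListeners(): Seq[AnyRef]
}

object ListenerDiscovery {
  // Instantiates every implementation listed under META-INF/services on the given classpath
  // and hands its listeners to the bus via the supplied callback.
  def registerDiscoveredListeners(loader: ClassLoader, addListener: AnyRef => Unit): Unit = {
    val factories = ServiceLoader.load(classOf[HistoryListenerFactory], loader).asScala
    factories.foreach(_.createListeners().foreach(addListener))
  }
}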
private val dateFormat = new ThreadLocal[SimpleDateFormat]() { @@ -140,14 +143,10 @@ private[spark] object UIUtils extends Logging { // Yarn has to go through a proxy so the base uri is provided and has to be on all links def uiRoot: String = { - if (System.getenv("APPLICATION_WEB_PROXY_BASE") != null) { - System.getenv("APPLICATION_WEB_PROXY_BASE") - } else if (System.getProperty("spark.ui.proxyBase") != null) { - System.getProperty("spark.ui.proxyBase") - } - else { - "" - } + // SPARK-11484 - Use the proxyBase set by the AM, if not found then use env. + sys.props.get("spark.ui.proxyBase") + .orElse(sys.env.get("APPLICATION_WEB_PROXY_BASE")) + .getOrElse("") } def prependBaseUri(basePath: String = "", resource: String = ""): String = { @@ -211,10 +210,10 @@ private[spark] object UIUtils extends Logging { {org.apache.spark.SPARK_VERSION}
    @@ -320,7 +319,9 @@ private[spark] object UIUtils extends Logging { skipped: Int, total: Int): Seq[Node] = { val completeWidth = "width: %s%%".format((completed.toDouble/total)*100) - val startWidth = "width: %s%%".format((started.toDouble/total)*100) + // started + completed can be > total when there are speculative tasks + val boundedStarted = math.min(started, total - completed) + val startWidth = "width: %s%%".format((boundedStarted.toDouble/total)*100)
    @@ -352,7 +353,8 @@ private[spark] object UIUtils extends Logging { */ private def showDagViz(graphs: Seq[RDDOperationGraph], forJob: Boolean): Seq[Node] = {
    - + @@ -394,4 +396,59 @@ private[spark] object UIUtils extends Logging { } + /** + * Returns HTML rendering of a job or stage description. It will try to parse the string as HTML + * and make sure that it only contains anchors with root-relative links. Otherwise, + * the whole string will rendered as a simple escaped text. + * + * Note: In terms of security, only anchor tags with root relative links are supported. So any + * attempts to embed links outside Spark UI, or other tags like } private def createExecutorTable() : Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 0c854f04890b..ca37829216f2 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -21,8 +21,6 @@ import java.util.concurrent.TimeoutException import scala.collection.mutable.{HashMap, HashSet, ListBuffer} -import com.google.common.annotations.VisibleForTesting - import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics @@ -53,8 +51,9 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { type PoolName = String type ExecutorId = String - // Applicatin: + // Application: @volatile var startTime = -1L + @volatile var endTime = -1L // Jobs: val activeJobs = new HashMap[JobId, JobUIData] @@ -536,6 +535,10 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { startTime = appStarted.time } + override def onApplicationEnd(appEnded: SparkListenerApplicationEnd) { + endTime = appEnded.time + } + /** * For testing only. Wait until at least `numExecutors` executors are up, or throw * `TimeoutException` if the waiting time elapsed before `numExecutors` executors up. diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index cf04b5e59239..1b34ba9f03c4 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -26,8 +26,9 @@ import scala.xml.{Elem, Node, Unparsed} import org.apache.commons.lang3.StringEscapeUtils +import org.apache.spark.{InternalAccumulator, SparkConf} import org.apache.spark.executor.TaskMetrics -import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo} +import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo, TaskLocality} import org.apache.spark.ui._ import org.apache.spark.ui.jobs.UIData._ import org.apache.spark.util.{Utils, Distribution} @@ -48,7 +49,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { ("shuffle-read-time-proportion", "Shuffle Read Time"), ("executor-runtime-proportion", "Executor Computing Time"), ("shuffle-write-time-proportion", "Shuffle Write Time"), - ("serialization-time-proportion", "Result Serialization TIme"), + ("serialization-time-proportion", "Result Serialization Time"), ("getting-result-time-proportion", "Getting Result Time")) legendPairs.zipWithIndex.map { @@ -67,6 +68,22 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { // if we find that it's okay. 
private val MAX_TIMELINE_TASKS = parent.conf.getInt("spark.ui.timeline.tasks.maximum", 1000) + private val displayPeakExecutionMemory = parent.conf.getBoolean("spark.sql.unsafe.enabled", true) + + private def getLocalitySummaryString(stageData: StageUIData): String = { + val localities = stageData.taskData.values.map(_.taskInfo.taskLocality) + val localityCounts = localities.groupBy(identity).mapValues(_.size) + val localityNamesAndCounts = localityCounts.toSeq.map { case (locality, count) => + val localityName = locality match { + case TaskLocality.PROCESS_LOCAL => "Process local" + case TaskLocality.NODE_LOCAL => "Node local" + case TaskLocality.RACK_LOCAL => "Rack local" + case TaskLocality.ANY => "Any" + } + s"$localityName: $count" + } + localityNamesAndCounts.sorted.mkString("; ") + } def render(request: HttpServletRequest): Seq[Node] = { progressListener.synchronized { @@ -114,10 +131,11 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val stageData = stageDataOption.get val tasks = stageData.taskData.values.toSeq.sortBy(_.taskInfo.launchTime) - val numCompleted = tasks.count(_.taskInfo.finished) - val accumulables = progressListener.stageIdToData((stageId, stageAttemptId)).accumulables - val hasAccumulators = accumulables.size > 0 + + val allAccumulables = progressListener.stageIdToData((stageId, stageAttemptId)).accumulables + val externalAccumulables = allAccumulables.values.filter { acc => !acc.internal } + val hasAccumulators = externalAccumulables.size > 0 val summary =
    @@ -126,6 +144,10 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { Total Time Across All Tasks: {UIUtils.formatDuration(stageData.executorRunTime)} +
  • + Locality Level Summary: + {getLocalitySummaryString(stageData)} +
  • {if (stageData.hasInput) {
  • Input Size / Records: @@ -221,6 +243,15 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { Getting Result Time
  • + {if (displayPeakExecutionMemory) { +
  • + + + Peak Execution Memory + +
  • + }}
    @@ -241,11 +272,12 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val accumulableTable = UIUtils.listingTable( accumulableHeaders, accumulableRow, - accumulables.values.toSeq) + externalAccumulables.toSeq) val currentTime = System.currentTimeMillis() val (taskTable, taskTableHTML) = try { val _taskTable = new TaskPagedTable( + parent.conf, UIUtils.prependBaseUri(parent.basePath) + s"/stages/stage?id=${stageId}&attempt=${stageAttemptId}", tasks, @@ -294,12 +326,14 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { else { def getDistributionQuantiles(data: Seq[Double]): IndexedSeq[Double] = Distribution(data).get.getQuantiles() - def getFormattedTimeQuantiles(times: Seq[Double]): Seq[Node] = { getDistributionQuantiles(times).map { millis => {UIUtils.formatDuration(millis.toLong)} } } + def getFormattedSizeQuantiles(data: Seq[Double]): Seq[Elem] = { + getDistributionQuantiles(data).map(d => {Utils.bytesToString(d.toLong)}) + } val deserializationTimes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.executorDeserializeTime.toDouble @@ -349,6 +383,23 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
    +: getFormattedTimeQuantiles(gettingResultTimes) + + val peakExecutionMemory = validTasks.map { case TaskUIData(info, _, _) => + info.accumulables + .find { acc => acc.name == InternalAccumulator.PEAK_EXECUTION_MEMORY } + .map { acc => acc.update.getOrElse("0").toLong } + .getOrElse(0L) + .toDouble + } + val peakExecutionMemoryQuantiles = { + + + Peak Execution Memory + + +: getFormattedSizeQuantiles(peakExecutionMemory) + } + // The scheduler delay includes the network delay to send the task to the worker // machine and to send back the result (but not the time to fetch the task result, // if it needed to be fetched from the block manager on the worker). @@ -359,10 +410,6 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { title={ToolTips.SCHEDULER_DELAY} data-placement="right">Scheduler Delay val schedulerDelayQuantiles = schedulerDelayTitle +: getFormattedTimeQuantiles(schedulerDelays) - - def getFormattedSizeQuantiles(data: Seq[Double]): Seq[Elem] = - getDistributionQuantiles(data).map(d => {Utils.bytesToString(d.toLong)}) - def getFormattedSizeQuantilesWithRecords(data: Seq[Double], records: Seq[Double]) : Seq[Elem] = { val recordDist = getDistributionQuantiles(records).iterator @@ -466,6 +513,13 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { {serializationQuantiles} , {gettingResultQuantiles}, + if (displayPeakExecutionMemory) { + + {peakExecutionMemoryQuantiles} + + } else { + Nil + }, if (stageData.hasInput) {inputQuantiles} else Nil, if (stageData.hasOutput) {outputQuantiles} else Nil, if (stageData.hasShuffleRead) { @@ -499,7 +553,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val executorTable = new ExecutorTable(stageId, stageAttemptId, parent) val maybeAccumulableTable: Seq[Node] = - if (accumulables.size > 0) {

<h4>Accumulators</h4> ++ accumulableTable } else Seq()
+      if (hasAccumulators) { <h4>Accumulators</h4>
    ++ accumulableTable } else Seq() val content = summary ++ @@ -586,7 +640,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { serializationTimeProportionPos + serializationTimeProportion val index = taskInfo.index - val attempt = taskInfo.attempt + val attempt = taskInfo.attemptNumber val svgTag = if (totalExecutionTime == 0) { @@ -750,29 +804,30 @@ private[ui] case class TaskTableRowBytesSpilledData( * Contains all data that needs for sorting and generating HTML. Using this one rather than * TaskUIData to avoid creating duplicate contents during sorting the data. */ -private[ui] case class TaskTableRowData( - index: Int, - taskId: Long, - attempt: Int, - speculative: Boolean, - status: String, - taskLocality: String, - executorIdAndHost: String, - launchTime: Long, - duration: Long, - formatDuration: String, - schedulerDelay: Long, - taskDeserializationTime: Long, - gcTime: Long, - serializationTime: Long, - gettingResultTime: Long, - accumulators: Option[String], // HTML - input: Option[TaskTableRowInputData], - output: Option[TaskTableRowOutputData], - shuffleRead: Option[TaskTableRowShuffleReadData], - shuffleWrite: Option[TaskTableRowShuffleWriteData], - bytesSpilled: Option[TaskTableRowBytesSpilledData], - error: String) +private[ui] class TaskTableRowData( + val index: Int, + val taskId: Long, + val attempt: Int, + val speculative: Boolean, + val status: String, + val taskLocality: String, + val executorIdAndHost: String, + val launchTime: Long, + val duration: Long, + val formatDuration: String, + val schedulerDelay: Long, + val taskDeserializationTime: Long, + val gcTime: Long, + val serializationTime: Long, + val gettingResultTime: Long, + val peakExecutionMemoryUsed: Long, + val accumulators: Option[String], // HTML + val input: Option[TaskTableRowInputData], + val output: Option[TaskTableRowOutputData], + val shuffleRead: Option[TaskTableRowShuffleReadData], + val shuffleWrite: Option[TaskTableRowShuffleWriteData], + val bytesSpilled: Option[TaskTableRowBytesSpilledData], + val error: String) private[ui] class TaskDataSource( tasks: Seq[TaskUIData], @@ -816,10 +871,15 @@ private[ui] class TaskDataSource( val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L) val gettingResultTime = getGettingResultTime(info, currentTime) - val maybeAccumulators = info.accumulables - val accumulatorsReadable = maybeAccumulators.map { acc => + val (taskInternalAccumulables, taskExternalAccumulables) = + info.accumulables.partition(_.internal) + val externalAccumulableReadable = taskExternalAccumulables.map { acc => StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update.get}") } + val peakExecutionMemoryUsed = taskInternalAccumulables + .find { acc => acc.name == InternalAccumulator.PEAK_EXECUTION_MEMORY } + .map { acc => acc.update.getOrElse("0").toLong } + .getOrElse(0L) val maybeInput = metrics.flatMap(_.inputMetrics) val inputSortable = maybeInput.map(_.bytesRead).getOrElse(0L) @@ -923,10 +983,10 @@ private[ui] class TaskDataSource( None } - TaskTableRowData( + new TaskTableRowData( info.index, info.taskId, - info.attempt, + info.attemptNumber, info.speculative, info.status, info.taskLocality.toString, @@ -939,14 +999,14 @@ private[ui] class TaskDataSource( gcTime, serializationTime, gettingResultTime, - if (hasAccumulators) Some(accumulatorsReadable.mkString("
    ")) else None, + peakExecutionMemoryUsed, + if (hasAccumulators) Some(externalAccumulableReadable.mkString("
    ")) else None, input, output, shuffleRead, shuffleWrite, bytesSpilled, - errorMessage.getOrElse("") - ) + errorMessage.getOrElse("")) } /** @@ -1006,6 +1066,10 @@ private[ui] class TaskDataSource( override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = Ordering.Long.compare(x.gettingResultTime, y.gettingResultTime) } + case "Peak Execution Memory" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.peakExecutionMemoryUsed, y.peakExecutionMemoryUsed) + } case "Accumulators" => if (hasAccumulators) { new Ordering[TaskTableRowData] { @@ -1132,6 +1196,7 @@ private[ui] class TaskDataSource( } private[ui] class TaskPagedTable( + conf: SparkConf, basePath: String, data: Seq[TaskUIData], hasAccumulators: Boolean, @@ -1143,9 +1208,12 @@ private[ui] class TaskPagedTable( currentTime: Long, pageSize: Int, sortColumn: String, - desc: Boolean) extends PagedTable[TaskTableRowData]{ + desc: Boolean) extends PagedTable[TaskTableRowData] { - override def tableId: String = "" + // We only track peak memory used for unsafe operators + private val displayPeakExecutionMemory = conf.getBoolean("spark.sql.unsafe.enabled", true) + + override def tableId: String = "task-table" override def tableCssClass: String = "table table-bordered table-condensed table-striped" @@ -1160,8 +1228,7 @@ private[ui] class TaskPagedTable( currentTime, pageSize, sortColumn, - desc - ) + desc) override def pageLink(page: Int): String = { val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8") @@ -1195,6 +1262,13 @@ private[ui] class TaskPagedTable( ("GC Time", ""), ("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME), ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME)) ++ + { + if (displayPeakExecutionMemory) { + Seq(("Peak Execution Memory", TaskDetailsClassNames.PEAK_EXECUTION_MEMORY)) + } else { + Nil + } + } ++ {if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++ {if (hasInput) Seq(("Input Size / Records", "")) else Nil} ++ {if (hasOutput) Seq(("Output Size / Records", "")) else Nil} ++ @@ -1218,7 +1292,7 @@ private[ui] class TaskPagedTable( Seq(("Errors", "")) if (!taskHeadersAndCssClasses.map(_._1).contains(sortColumn)) { - new IllegalArgumentException(s"Unknown column: $sortColumn") + throw new IllegalArgumentException(s"Unknown column: $sortColumn") } val headerRow: Seq[Node] = { @@ -1271,6 +1345,11 @@ private[ui] class TaskPagedTable( {UIUtils.formatDuration(task.gettingResultTime)} + {if (displayPeakExecutionMemory) { + + {Utils.bytesToString(task.peakExecutionMemoryUsed)} + + }} {if (task.accumulators.nonEmpty) { {Unparsed(task.accumulators.get)} }} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 99812db4912a..2a1c3c1a50ec 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -17,11 +17,10 @@ package org.apache.spark.ui.jobs -import scala.xml.Node -import scala.xml.Text - import java.util.Date +import scala.xml.{Node, Text} + import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.scheduler.StageInfo @@ -116,7 +115,7 @@ private[ui] class StageTableBase( stageData <- listener.stageIdToData.get((s.stageId, s.attemptId)) desc <- stageData.description } yield { - {desc} + UIUtils.makeDescription(desc, basePathUri) }
    <div>{stageDesc.getOrElse("")} {killLink} {nameLink} {details}</div>
    } @@ -146,9 +145,22 @@ private[ui] class StageTableBase( case None => "Unknown" } val finishTime = s.completionTime.getOrElse(System.currentTimeMillis) - val duration = s.submissionTime.map { t => - if (finishTime > t) finishTime - t else System.currentTimeMillis - t - } + + // The submission time for a stage is misleading because it counts the time + // the stage waits to be launched. (SPARK-10930) + val taskLaunchTimes = + stageData.taskData.values.map(_.taskInfo.launchTime).filter(_ > 0) + val duration: Option[Long] = + if (taskLaunchTimes.nonEmpty) { + val startTime = taskLaunchTimes.min + if (finishTime > startTime) { + Some(finishTime - startTime) + } else { + Some(System.currentTimeMillis() - startTime) + } + } else { + None + } val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("Unknown") val inputRead = stageData.inputBytes diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala index 9bf67db8acde..d2dfc5a32915 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala @@ -31,4 +31,5 @@ private[spark] object TaskDetailsClassNames { val SHUFFLE_READ_REMOTE_SIZE = "shuffle_read_remote" val RESULT_SERIALIZATION_TIME = "serialization_time" val GETTING_RESULT_TIME = "getting_result_time" + val PEAK_EXECUTION_MEMORY = "peak_execution_memory" } diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala index ffea9817c0b0..e9c8a8e299cd 100644 --- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala +++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala @@ -18,11 +18,12 @@ package org.apache.spark.ui.scope import scala.collection.mutable -import scala.collection.mutable.ListBuffer +import scala.collection.mutable.{StringBuilder, ListBuffer} import org.apache.spark.Logging import org.apache.spark.scheduler.StageInfo import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.CallSite /** * A representation of a generic cluster graph used for storing information on RDD operations. @@ -38,7 +39,7 @@ private[ui] case class RDDOperationGraph( rootCluster: RDDOperationCluster) /** A node in an RDDOperationGraph. This represents an RDD. */ -private[ui] case class RDDOperationNode(id: Int, name: String, cached: Boolean) +private[ui] case class RDDOperationNode(id: Int, name: String, cached: Boolean, callsite: String) /** * A directed edge connecting two nodes in an RDDOperationGraph. 
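The duration change above can be read in isolation: instead of measuring from the stage's submission time (which also counts the time the stage waited to be scheduled), the start is taken as the earliest launch time of any of the stage's tasks (SPARK-10930). A small standalone sketch of that logic, with a hypothetical function name and the launch times passed in directly:

    // Returns None when no task has launched yet, mirroring the Option[Long] above.
    def stageDuration(taskLaunchTimes: Seq[Long], finishTime: Long): Option[Long] = {
      val launched = taskLaunchTimes.filter(_ > 0)        // 0 means "not launched"
      if (launched.isEmpty) {
        None
      } else {
        val startTime = launched.min
        if (finishTime > startTime) {
          Some(finishTime - startTime)                    // completed stage
        } else {
          Some(System.currentTimeMillis() - startTime)    // stage still running
        }
      }
    }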
@@ -104,8 +105,8 @@ private[ui] object RDDOperationGraph extends Logging { edges ++= rdd.parentIds.map { parentId => RDDOperationEdge(parentId, rdd.id) } // TODO: differentiate between the intention to cache an RDD and whether it's actually cached - val node = nodes.getOrElseUpdate( - rdd.id, RDDOperationNode(rdd.id, rdd.name, rdd.storageLevel != StorageLevel.NONE)) + val node = nodes.getOrElseUpdate(rdd.id, RDDOperationNode( + rdd.id, rdd.name, rdd.storageLevel != StorageLevel.NONE, rdd.callSite)) if (rdd.scope.isEmpty) { // This RDD has no encompassing scope, so we put it directly in the root cluster @@ -167,7 +168,7 @@ private[ui] object RDDOperationGraph extends Logging { def makeDotFile(graph: RDDOperationGraph): String = { val dotFile = new StringBuilder dotFile.append("digraph G {\n") - dotFile.append(makeDotSubgraph(graph.rootCluster, indent = " ")) + makeDotSubgraph(dotFile, graph.rootCluster, indent = " ") graph.edges.foreach { edge => dotFile.append(s""" ${edge.fromId}->${edge.toId};\n""") } dotFile.append("}") val result = dotFile.toString() @@ -177,21 +178,23 @@ private[ui] object RDDOperationGraph extends Logging { /** Return the dot representation of a node in an RDDOperationGraph. */ private def makeDotNode(node: RDDOperationNode): String = { - s"""${node.id} [label="${node.name} [${node.id}]"]""" + val label = s"${node.name} [${node.id}]\n${node.callsite}" + s"""${node.id} [label="$label"]""" } - /** Return the dot representation of a subgraph in an RDDOperationGraph. */ - private def makeDotSubgraph(cluster: RDDOperationCluster, indent: String): String = { - val subgraph = new StringBuilder - subgraph.append(indent + s"subgraph cluster${cluster.id} {\n") - subgraph.append(indent + s""" label="${cluster.name}";\n""") + /** Update the dot representation of the RDDOperationGraph in cluster to subgraph. 
*/ + private def makeDotSubgraph( + subgraph: StringBuilder, + cluster: RDDOperationCluster, + indent: String): Unit = { + subgraph.append(indent).append(s"subgraph cluster${cluster.id} {\n") + subgraph.append(indent).append(s""" label="${cluster.name}";\n""") cluster.childNodes.foreach { node => - subgraph.append(indent + s" ${makeDotNode(node)};\n") + subgraph.append(indent).append(s" ${makeDotNode(node)};\n") } cluster.childClusters.foreach { cscope => - subgraph.append(makeDotSubgraph(cscope, indent + " ")) + makeDotSubgraph(subgraph, cscope, indent + " ") } - subgraph.append(indent + "}\n") - subgraph.toString() + subgraph.append(indent).append("}\n") } } diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index 36943978ff59..fd6cc3ed759b 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -17,12 +17,13 @@ package org.apache.spark.ui.storage +import java.net.URLEncoder import javax.servlet.http.HttpServletRequest -import scala.xml.Node +import scala.xml.{Node, Unparsed} import org.apache.spark.status.api.v1.{AllRDDResource, RDDDataDistribution, RDDPartitionInfo} -import org.apache.spark.ui.{UIUtils, WebUIPage} +import org.apache.spark.ui.{PagedDataSource, PagedTable, UIUtils, WebUIPage} import org.apache.spark.util.Utils /** Page showing storage details for a given RDD */ @@ -32,6 +33,17 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { def render(request: HttpServletRequest): Seq[Node] = { val parameterId = request.getParameter("id") require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") + + val parameterBlockPage = request.getParameter("block.page") + val parameterBlockSortColumn = request.getParameter("block.sort") + val parameterBlockSortDesc = request.getParameter("block.desc") + val parameterBlockPageSize = request.getParameter("block.pageSize") + + val blockPage = Option(parameterBlockPage).map(_.toInt).getOrElse(1) + val blockSortColumn = Option(parameterBlockSortColumn).getOrElse("Block Name") + val blockSortDesc = Option(parameterBlockSortDesc).map(_.toBoolean).getOrElse(false) + val blockPageSize = Option(parameterBlockPageSize).map(_.toInt).getOrElse(100) + val rddId = parameterId.toInt val rddStorageInfo = AllRDDResource.getRDDStorageInfo(rddId, listener, includeDetails = true) .getOrElse { @@ -44,8 +56,34 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { rddStorageInfo.dataDistribution.get, id = Some("rdd-storage-by-worker-table")) // Block table - val blockTable = UIUtils.listingTable(blockHeader, blockRow, rddStorageInfo.partitions.get, - id = Some("rdd-storage-by-block-table")) + val (blockTable, blockTableHTML) = try { + val _blockTable = new BlockPagedTable( + UIUtils.prependBaseUri(parent.basePath) + s"/storage/rdd/?id=${rddId}", + rddStorageInfo.partitions.get, + blockPageSize, + blockSortColumn, + blockSortDesc) + (_blockTable, _blockTable.table(blockPage)) + } catch { + case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) => + (null,
    {e.getMessage}
    ) + } + + val jsForScrollingDownToBlockTable = + val content =
    @@ -85,11 +123,11 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") {
-        {rddStorageInfo.partitions.map(_.size).getOrElse(0)} Partitions
-        {blockTable}
+        {rddStorageInfo.partitions.map(_.size).getOrElse(0)} Partitions
+        {blockTableHTML ++ jsForScrollingDownToBlockTable}
    ; UIUtils.headerSparkPage("RDD Storage Info for " + rddStorageInfo.name, content, parent) @@ -101,14 +139,6 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { "Memory Usage", "Disk Usage") - /** Header fields for the block table */ - private def blockHeader = Seq( - "Block Name", - "Storage Level", - "Size in Memory", - "Size on Disk", - "Executors") - /** Render an HTML row representing a worker */ private def workerRow(worker: RDDDataDistribution): Seq[Node] = { @@ -120,23 +150,157 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { {Utils.bytesToString(worker.diskUsed)} } +} + +private[ui] case class BlockTableRowData( + blockName: String, + storageLevel: String, + memoryUsed: Long, + diskUsed: Long, + executors: String) + +private[ui] class BlockDataSource( + rddPartitions: Seq[RDDPartitionInfo], + pageSize: Int, + sortColumn: String, + desc: Boolean) extends PagedDataSource[BlockTableRowData](pageSize) { + + private val data = rddPartitions.map(blockRow).sorted(ordering(sortColumn, desc)) + + override def dataSize: Int = data.size + + override def sliceData(from: Int, to: Int): Seq[BlockTableRowData] = { + data.slice(from, to) + } + + private def blockRow(rddPartition: RDDPartitionInfo): BlockTableRowData = { + BlockTableRowData( + rddPartition.blockName, + rddPartition.storageLevel, + rddPartition.memoryUsed, + rddPartition.diskUsed, + rddPartition.executors.mkString(" ")) + } + + /** + * Return Ordering according to sortColumn and desc + */ + private def ordering(sortColumn: String, desc: Boolean): Ordering[BlockTableRowData] = { + val ordering = sortColumn match { + case "Block Name" => new Ordering[BlockTableRowData] { + override def compare(x: BlockTableRowData, y: BlockTableRowData): Int = + Ordering.String.compare(x.blockName, y.blockName) + } + case "Storage Level" => new Ordering[BlockTableRowData] { + override def compare(x: BlockTableRowData, y: BlockTableRowData): Int = + Ordering.String.compare(x.storageLevel, y.storageLevel) + } + case "Size in Memory" => new Ordering[BlockTableRowData] { + override def compare(x: BlockTableRowData, y: BlockTableRowData): Int = + Ordering.Long.compare(x.memoryUsed, y.memoryUsed) + } + case "Size on Disk" => new Ordering[BlockTableRowData] { + override def compare(x: BlockTableRowData, y: BlockTableRowData): Int = + Ordering.Long.compare(x.diskUsed, y.diskUsed) + } + case "Executors" => new Ordering[BlockTableRowData] { + override def compare(x: BlockTableRowData, y: BlockTableRowData): Int = + Ordering.String.compare(x.executors, y.executors) + } + case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn") + } + if (desc) { + ordering.reverse + } else { + ordering + } + } +} + +private[ui] class BlockPagedTable( + basePath: String, + rddPartitions: Seq[RDDPartitionInfo], + pageSize: Int, + sortColumn: String, + desc: Boolean) extends PagedTable[BlockTableRowData] { + + override def tableId: String = "rdd-storage-by-block-table" + + override def tableCssClass: String = "table table-bordered table-condensed table-striped" + + override val dataSource: BlockDataSource = new BlockDataSource( + rddPartitions, + pageSize, + sortColumn, + desc) + + override def pageLink(page: Int): String = { + val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8") + s"${basePath}&block.page=$page&block.sort=${encodedSortColumn}&block.desc=${desc}" + + s"&block.pageSize=${pageSize}" + } + + override def goButtonJavascriptFunction: (String, String) = { + val jsFuncName = 
"goToBlockPage" + val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8") + val jsFunc = s""" + |currentBlockPageSize = ${pageSize} + |function goToBlockPage(page, pageSize) { + | // Set page to 1 if the page size changes + | page = pageSize == currentBlockPageSize ? page : 1; + | var url = "${basePath}&block.sort=${encodedSortColumn}&block.desc=${desc}" + + | "&block.page=" + page + "&block.pageSize=" + pageSize; + | window.location.href = url; + |} + """.stripMargin + (jsFuncName, jsFunc) + } - /** Render an HTML row representing a block */ - private def blockRow(row: RDDPartitionInfo): Seq[Node] = { + override def headers: Seq[Node] = { + val blockHeaders = Seq( + "Block Name", + "Storage Level", + "Size in Memory", + "Size on Disk", + "Executors") + + if (!blockHeaders.contains(sortColumn)) { + throw new IllegalArgumentException(s"Unknown column: $sortColumn") + } + + val headerRow: Seq[Node] = { + blockHeaders.map { header => + if (header == sortColumn) { + val headerLink = + s"$basePath&block.sort=${URLEncoder.encode(header, "UTF-8")}&block.desc=${!desc}" + + s"&block.pageSize=${pageSize}" + val js = Unparsed(s"window.location.href='${headerLink}'") + val arrow = if (desc) "▾" else "▴" // UP or DOWN + + {header} +  {Unparsed(arrow)} + + } else { + val headerLink = + s"$basePath&block.sort=${URLEncoder.encode(header, "UTF-8")}" + + s"&block.pageSize=${pageSize}" + val js = Unparsed(s"window.location.href='${headerLink}'") + + {header} + + } + } + } + {headerRow} + } + + override def row(block: BlockTableRowData): Seq[Node] = { - {row.blockName} - - {row.storageLevel} - - - {Utils.bytesToString(row.memoryUsed)} - - - {Utils.bytesToString(row.diskUsed)} - - - {row.executors.map(l => {l}
    )} - + {block.blockName} + {block.storageLevel} + {Utils.bytesToString(block.memoryUsed)} + {Utils.bytesToString(block.diskUsed)} + {block.executors} } } diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 78e7ddc27d1c..1738258a0c79 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.util -import scala.collection.JavaConversions.mapAsJavaMap +import scala.collection.JavaConverters._ import akka.actor.{ActorRef, ActorSystem, ExtendedActorSystem} import akka.pattern.ask @@ -92,7 +92,7 @@ private[spark] object AkkaUtils extends Logging { val akkaSslConfig = securityManager.akkaSSLOptions.createAkkaConfig .getOrElse(ConfigFactory.empty()) - val akkaConf = ConfigFactory.parseMap(conf.getAkkaConf.toMap[String, String]) + val akkaConf = ConfigFactory.parseMap(conf.getAkkaConf.toMap.asJava) .withFallback(akkaSslConfig).withFallback(ConfigFactory.parseString( s""" |akka.daemonic = on diff --git a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala index 61b5a4cecddc..6c1fca71f228 100644 --- a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala @@ -19,8 +19,8 @@ package org.apache.spark.util import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean +import scala.util.DynamicVariable -import com.google.common.annotations.VisibleForTesting import org.apache.spark.SparkContext /** @@ -61,25 +61,27 @@ private[spark] abstract class AsynchronousListenerBus[L <: AnyRef, E](name: Stri private val listenerThread = new Thread(name) { setDaemon(true) override def run(): Unit = Utils.tryOrStopSparkContext(sparkContext) { - while (true) { - eventLock.acquire() - self.synchronized { - processingEvent = true - } - try { - val event = eventQueue.poll - if (event == null) { - // Get out of the while loop and shutdown the daemon thread - if (!stopped.get) { - throw new IllegalStateException("Polling `null` from eventQueue means" + - " the listener bus has been stopped. So `stopped` must be true") - } - return - } - postToAll(event) - } finally { + AsynchronousListenerBus.withinListenerThread.withValue(true) { + while (true) { + eventLock.acquire() self.synchronized { - processingEvent = false + processingEvent = true + } + try { + val event = eventQueue.poll + if (event == null) { + // Get out of the while loop and shutdown the daemon thread + if (!stopped.get) { + throw new IllegalStateException("Polling `null` from eventQueue means" + + " the listener bus has been stopped. So `stopped` must be true") + } + return + } + postToAll(event) + } finally { + self.synchronized { + processingEvent = false + } } } } @@ -122,8 +124,8 @@ private[spark] abstract class AsynchronousListenerBus[L <: AnyRef, E](name: Stri * For testing only. Wait until there are no more events in the queue, or until the specified * time has elapsed. Throw `TimeoutException` if the specified time elapsed before the queue * emptied. + * Exposed for testing. 
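The AsynchronousListenerBus change above wraps the event loop in a scala.util.DynamicVariable so that code can later ask whether it is running on the listener thread (for example, to detect a stop() issued from inside a listener). A small sketch of the same mechanism, with illustrative names rather than Spark's:

    import scala.util.DynamicVariable

    object ListenerThreadMarker {
      // false everywhere except inside withValue(true) { ... } on that thread
      val withinListenerThread = new DynamicVariable[Boolean](false)
    }

    def onStopRequested(): String =
      if (ListenerThreadMarker.withinListenerThread.value) "stop() from the listener thread"
      else "stop() from another thread"

    ListenerThreadMarker.withinListenerThread.withValue(true) {
      onStopRequested()   // "stop() from the listener thread"
    }
    onStopRequested()     // "stop() from another thread"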
*/ - @VisibleForTesting @throws(classOf[TimeoutException]) def waitUntilEmpty(timeoutMillis: Long): Unit = { val finishTime = System.currentTimeMillis + timeoutMillis @@ -140,8 +142,8 @@ private[spark] abstract class AsynchronousListenerBus[L <: AnyRef, E](name: Stri /** * For testing only. Return whether the listener daemon thread is still alive. + * Exposed for testing. */ - @VisibleForTesting def listenerThreadIsAlive: Boolean = listenerThread.isAlive /** @@ -178,3 +180,10 @@ private[spark] abstract class AsynchronousListenerBus[L <: AnyRef, E](name: Stri */ def onDropEvent(event: E): Unit } + +private[spark] object AsynchronousListenerBus { + /* Allows for Context to check whether stop() call is made within listener thread + */ + val withinListenerThread: DynamicVariable[Boolean] = new DynamicVariable[Boolean](false) +} + diff --git a/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala similarity index 61% rename from core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala rename to core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala index a4568e849fa1..8527e3ae692e 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala +++ b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala @@ -15,27 +15,19 @@ * limitations under the License. */ -package org.apache.spark.network.nio +package org.apache.spark.util +import java.io.ByteArrayOutputStream import java.nio.ByteBuffer -import scala.collection.mutable.ArrayBuffer - -private[nio] -class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) { - - val size: Int = if (buffer == null) 0 else buffer.remaining +/** + * Provide a zero-copy way to convert data in ByteArrayOutputStream to ByteBuffer + */ +private[spark] class ByteBufferOutputStream(capacity: Int) extends ByteArrayOutputStream(capacity) { - lazy val buffers: ArrayBuffer[ByteBuffer] = { - val ab = new ArrayBuffer[ByteBuffer]() - ab += header.buffer - if (buffer != null) { - ab += buffer - } - ab - } + def this() = this(32) - override def toString: String = { - "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")" + def toByteBuffer: ByteBuffer = { + return ByteBuffer.wrap(buf, 0, count) } } diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index ebead830c646..e27d2e6c94f7 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -21,8 +21,8 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import scala.collection.mutable.{Map, Set} -import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type} -import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ +import org.apache.xbean.asm5.{ClassReader, ClassVisitor, MethodVisitor, Type} +import org.apache.xbean.asm5.Opcodes._ import org.apache.spark.{Logging, SparkEnv, SparkException} @@ -94,7 +94,7 @@ private[spark] object ClosureCleaner extends Logging { if (cls.isPrimitive) { cls match { case java.lang.Boolean.TYPE => new java.lang.Boolean(false) - case java.lang.Character.TYPE => new java.lang.Character('\0') + case java.lang.Character.TYPE => new java.lang.Character('\u0000') case java.lang.Void.TYPE => // This should not happen because `Foo(void x) {}` does not compile. 
throw new IllegalStateException("Unexpected void parameter in constructor") @@ -181,7 +181,7 @@ private[spark] object ClosureCleaner extends Logging { return } - logDebug(s"+++ Cleaning closure $func (${func.getClass.getName}}) +++") + logDebug(s"+++ Cleaning closure $func (${func.getClass.getName}) +++") // A list of classes that represents closures enclosed in the given one val innerClasses = getInnerClosureClasses(func) @@ -325,11 +325,11 @@ private[spark] object ClosureCleaner extends Logging { private[spark] class ReturnStatementInClosureException extends SparkException("Return statements aren't allowed in Spark closures") -private class ReturnStatementFinder extends ClassVisitor(ASM4) { +private class ReturnStatementFinder extends ClassVisitor(ASM5) { override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { if (name.contains("apply")) { - new MethodVisitor(ASM4) { + new MethodVisitor(ASM5) { override def visitTypeInsn(op: Int, tp: String) { if (op == NEW && tp.contains("scala/runtime/NonLocalReturnControl")) { throw new ReturnStatementInClosureException @@ -337,7 +337,7 @@ private class ReturnStatementFinder extends ClassVisitor(ASM4) { } } } else { - new MethodVisitor(ASM4) {} + new MethodVisitor(ASM5) {} } } } @@ -361,7 +361,7 @@ private[util] class FieldAccessFinder( findTransitively: Boolean, specificMethod: Option[MethodIdentifier[_]] = None, visitedMethods: Set[MethodIdentifier[_]] = Set.empty) - extends ClassVisitor(ASM4) { + extends ClassVisitor(ASM5) { override def visitMethod( access: Int, @@ -376,7 +376,7 @@ private[util] class FieldAccessFinder( return null } - new MethodVisitor(ASM4) { + new MethodVisitor(ASM5) { override def visitFieldInsn(op: Int, owner: String, name: String, desc: String) { if (op == GETFIELD) { for (cl <- fields.keys if cl.getName == owner.replace('/', '.')) { @@ -385,7 +385,8 @@ private[util] class FieldAccessFinder( } } - override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) { + override def visitMethodInsn( + op: Int, owner: String, name: String, desc: String, itf: Boolean) { for (cl <- fields.keys if cl.getName == owner.replace('/', '.')) { // Check for calls a getter method for a variable in an interpreter wrapper object. // This means that the corresponding field will be accessed, so we should save it. @@ -408,7 +409,7 @@ private[util] class FieldAccessFinder( } } -private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM4) { +private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM5) { var myName: String = null // TODO: Recursively find inner closures that we indirectly reference, e.g. @@ -423,9 +424,9 @@ private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { - new MethodVisitor(ASM4) { - override def visitMethodInsn(op: Int, owner: String, name: String, - desc: String) { + new MethodVisitor(ASM5) { + override def visitMethodInsn( + op: Int, owner: String, name: String, desc: String, itf: Boolean) { val argTypes = Type.getArgumentTypes(desc) if (op == INVOKESPECIAL && name == "" && argTypes.length > 0 && argTypes(0).toString.startsWith("L") // is it an object? 
diff --git a/core/src/main/scala/org/apache/spark/util/IdGenerator.scala b/core/src/main/scala/org/apache/spark/util/IdGenerator.scala index 17e55f7996bf..53934ad4ce47 100644 --- a/core/src/main/scala/org/apache/spark/util/IdGenerator.scala +++ b/core/src/main/scala/org/apache/spark/util/IdGenerator.scala @@ -22,10 +22,10 @@ import java.util.concurrent.atomic.AtomicInteger /** * A util used to get a unique generation ID. This is a wrapper around Java's * AtomicInteger. An example usage is in BlockManager, where each BlockManager - * instance would start an Akka actor and we use this utility to assign the Akka - * actors unique names. + * instance would start an RpcEndpoint and we use this utility to assign the RpcEndpoints' + * unique names. */ private[spark] class IdGenerator { - private var id = new AtomicInteger + private val id = new AtomicInteger def next: Int = id.incrementAndGet } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index c600319d9ddb..cb0f1bf79f3d 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -19,19 +19,21 @@ package org.apache.spark.util import java.util.{Properties, UUID} -import org.apache.spark.scheduler.cluster.ExecutorInfo - import scala.collection.JavaConverters._ import scala.collection.Map +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.module.scala.DefaultScalaModule import org.json4s.DefaultFormats import org.json4s.JsonDSL._ import org.json4s.JsonAST._ +import org.json4s.jackson.JsonMethods._ import org.apache.spark._ import org.apache.spark.executor._ import org.apache.spark.rdd.RDDOperationScope import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.storage._ /** @@ -54,6 +56,8 @@ private[spark] object JsonProtocol { private implicit val format = DefaultFormats + private val mapper = new ObjectMapper().registerModule(DefaultScalaModule) + /** ------------------------------------------------- * * JSON serialization methods for SparkListenerEvents | * -------------------------------------------------- */ @@ -96,6 +100,7 @@ private[spark] object JsonProtocol { executorMetricsUpdateToJson(metricsUpdate) case blockUpdated: SparkListenerBlockUpdated => throw new MatchError(blockUpdated) // TODO(ekl) implement this + case _ => parse(mapper.writeValueAsString(event)) } } @@ -266,7 +271,7 @@ private[spark] object JsonProtocol { def taskInfoToJson(taskInfo: TaskInfo): JValue = { ("Task ID" -> taskInfo.taskId) ~ ("Index" -> taskInfo.index) ~ - ("Attempt" -> taskInfo.attempt) ~ + ("Attempt" -> taskInfo.attemptNumber) ~ ("Launch Time" -> taskInfo.launchTime) ~ ("Executor ID" -> taskInfo.executorId) ~ ("Host" -> taskInfo.host) ~ @@ -282,7 +287,8 @@ private[spark] object JsonProtocol { ("ID" -> accumulableInfo.id) ~ ("Name" -> accumulableInfo.name) ~ ("Update" -> accumulableInfo.update.map(new JString(_)).getOrElse(JNothing)) ~ - ("Value" -> accumulableInfo.value) + ("Value" -> accumulableInfo.value) ~ + ("Internal" -> accumulableInfo.internal) } def taskMetricsToJson(taskMetrics: TaskMetrics): JValue = { @@ -362,8 +368,14 @@ private[spark] object JsonProtocol { ("Stack Trace" -> stackTrace) ~ ("Full Stack Trace" -> exceptionFailure.fullStackTrace) ~ ("Metrics" -> metrics) - case ExecutorLostFailure(executorId) => - ("Executor ID" -> executorId) + case taskCommitDenied: TaskCommitDenied => + 
("Job ID" -> taskCommitDenied.jobID) ~ + ("Partition ID" -> taskCommitDenied.partitionID) ~ + ("Attempt Number" -> taskCommitDenied.attemptNumber) + case ExecutorLostFailure(executorId, exitCausedByApp, reason) => + ("Executor ID" -> executorId) ~ + ("Exit Caused By App" -> exitCausedByApp) ~ + ("Loss Reason" -> reason.map(_.toString)) case _ => Utils.emptyJson } ("Reason" -> reason) ~ json @@ -391,6 +403,7 @@ private[spark] object JsonProtocol { ("RDD ID" -> rddInfo.id) ~ ("Name" -> rddInfo.name) ~ ("Scope" -> rddInfo.scope.map(_.toJson)) ~ + ("Callsite" -> rddInfo.callSite) ~ ("Parent IDs" -> parentIds) ~ ("Storage Level" -> storageLevel) ~ ("Number of Partitions" -> rddInfo.numPartitions) ~ @@ -498,6 +511,8 @@ private[spark] object JsonProtocol { case `executorRemoved` => executorRemovedFromJson(json) case `logStart` => logStartFromJson(json) case `metricsUpdate` => executorMetricsUpdateFromJson(json) + case other => mapper.readValue(compact(render(json)), Utils.classForName(other)) + .asInstanceOf[SparkListenerEvent] } } @@ -691,7 +706,8 @@ private[spark] object JsonProtocol { val name = (json \ "Name").extract[String] val update = Utils.jsonOption(json \ "Update").map(_.extract[String]) val value = (json \ "Value").extract[String] - AccumulableInfo(id, name, update, value) + val internal = (json \ "Internal").extractOpt[Boolean].getOrElse(false) + AccumulableInfo(id, name, update, value, internal) } def taskMetricsFromJson(json: JValue): TaskMetrics = { @@ -769,6 +785,7 @@ private[spark] object JsonProtocol { val exceptionFailure = Utils.getFormattedClassName(ExceptionFailure) val taskResultLost = Utils.getFormattedClassName(TaskResultLost) val taskKilled = Utils.getFormattedClassName(TaskKilled) + val taskCommitDenied = Utils.getFormattedClassName(TaskCommitDenied) val executorLostFailure = Utils.getFormattedClassName(ExecutorLostFailure) val unknownReason = Utils.getFormattedClassName(UnknownReason) @@ -790,12 +807,25 @@ private[spark] object JsonProtocol { val fullStackTrace = Utils.jsonOption(json \ "Full Stack Trace"). map(_.extract[String]).orNull val metrics = Utils.jsonOption(json \ "Metrics").map(taskMetricsFromJson) - ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) + ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics, None) case `taskResultLost` => TaskResultLost case `taskKilled` => TaskKilled + case `taskCommitDenied` => + // Unfortunately, the `TaskCommitDenied` message was introduced in 1.3.0 but the JSON + // de/serialization logic was not added until 1.5.1. To provide backward compatibility + // for reading those logs, we need to provide default values for all the fields. 
+ val jobId = Utils.jsonOption(json \ "Job ID").map(_.extract[Int]).getOrElse(-1) + val partitionId = Utils.jsonOption(json \ "Partition ID").map(_.extract[Int]).getOrElse(-1) + val attemptNo = Utils.jsonOption(json \ "Attempt Number").map(_.extract[Int]).getOrElse(-1) + TaskCommitDenied(jobId, partitionId, attemptNo) case `executorLostFailure` => + val exitCausedByApp = Utils.jsonOption(json \ "Exit Caused By App").map(_.extract[Boolean]) val executorId = Utils.jsonOption(json \ "Executor ID").map(_.extract[String]) - ExecutorLostFailure(executorId.getOrElse("Unknown")) + val reason = Utils.jsonOption(json \ "Loss Reason").map(_.extract[String]) + ExecutorLostFailure( + executorId.getOrElse("Unknown"), + exitCausedByApp.getOrElse(true), + reason) case `unknownReason` => UnknownReason } } @@ -829,6 +859,7 @@ private[spark] object JsonProtocol { val scope = Utils.jsonOption(json \ "Scope") .map(_.extract[String]) .map(RDDOperationScope.fromJson) + val callsite = Utils.jsonOption(json \ "Callsite").map(_.extract[String]).getOrElse("") val parentIds = Utils.jsonOption(json \ "Parent IDs") .map { l => l.extract[List[JValue]].map(_.extract[Int]) } .getOrElse(Seq.empty) @@ -841,7 +872,7 @@ private[spark] object JsonProtocol { .getOrElse(json \ "Tachyon Size").extract[Long] val diskSize = (json \ "Disk Size").extract[Long] - val rddInfo = new RDDInfo(rddId, name, numPartitions, storageLevel, parentIds, scope) + val rddInfo = new RDDInfo(rddId, name, numPartitions, storageLevel, parentIds, callsite, scope) rddInfo.numCachedPartitions = numCachedPartitions rddInfo.memSize = memSize rddInfo.externalBlockStoreSize = externalBlockStoreSize diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala index a725767d08cc..13cb516b583e 100644 --- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala @@ -19,12 +19,11 @@ package org.apache.spark.util import java.util.concurrent.CopyOnWriteArrayList -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import scala.util.control.NonFatal import org.apache.spark.Logging -import org.apache.spark.scheduler.SparkListener /** * An event bus which posts events to its listeners. @@ -46,7 +45,7 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { * `postToAll` in the same thread for all events. */ final def postToAll(event: E): Unit = { - // JavaConversions will create a JIterableWrapper if we use some Scala collection functions. + // JavaConverters can create a JIterableWrapper if we use asScala. // However, this method will be called frequently. To avoid the wrapper cost, here ewe use // Java Iterator directly. 
val iter = listeners.iterator @@ -69,7 +68,7 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { private[spark] def findListenersByClass[T <: L : ClassTag](): Seq[T] = { val c = implicitly[ClassTag[T]].runtimeClass - listeners.filter(_.getClass == c).map(_.asInstanceOf[T]).toSeq + listeners.asScala.filter(_.getClass == c).map(_.asInstanceOf[T]).toSeq } } diff --git a/core/src/main/scala/org/apache/spark/util/ManualClock.scala b/core/src/main/scala/org/apache/spark/util/ManualClock.scala index 171855406198..e7a65d74a440 100644 --- a/core/src/main/scala/org/apache/spark/util/ManualClock.scala +++ b/core/src/main/scala/org/apache/spark/util/ManualClock.scala @@ -58,7 +58,7 @@ private[spark] class ManualClock(private var time: Long) extends Clock { */ def waitTillTime(targetTime: Long): Long = synchronized { while (time < targetTime) { - wait(100) + wait(10) } getTimeMillis() } diff --git a/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala b/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala index 169489df6c1e..945217203be7 100644 --- a/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala +++ b/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala @@ -21,7 +21,7 @@ import java.net.{URLClassLoader, URL} import java.util.Enumeration import java.util.concurrent.ConcurrentHashMap -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ /** * URL class loader that exposes the `addURL` and `getURLs` methods in URLClassLoader. @@ -84,14 +84,9 @@ private[spark] class ChildFirstURLClassLoader(urls: Array[URL], parent: ClassLoa } override def getResources(name: String): Enumeration[URL] = { - val urls = super.findResources(name) - val res = - if (urls != null && urls.hasMoreElements()) { - urls - } else { - parentClassLoader.getResources(name) - } - res + val childUrls = super.findResources(name).asScala + val parentUrls = parentClassLoader.getResources(name).asScala + (childUrls ++ parentUrls).asJavaEnumeration } override def addURL(url: URL) { diff --git a/core/src/main/scala/org/apache/spark/util/NextIterator.scala b/core/src/main/scala/org/apache/spark/util/NextIterator.scala index e5c732a5a559..0b505a576768 100644 --- a/core/src/main/scala/org/apache/spark/util/NextIterator.scala +++ b/core/src/main/scala/org/apache/spark/util/NextIterator.scala @@ -60,8 +60,10 @@ private[spark] abstract class NextIterator[U] extends Iterator[U] { */ def closeIfNeeded() { if (!closed) { - close() + // Note: it's important that we set closed = true before calling close(), since setting it + // afterwards would permit us to call close() multiple times if close() threw an exception. closed = true + close() } } diff --git a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala new file mode 100644 index 000000000000..1a0f3b477ba3 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io.File +import java.util.PriorityQueue + +import scala.util.{Failure, Success, Try} +import tachyon.client.TachyonFile + +import org.apache.hadoop.fs.FileSystem +import org.apache.spark.Logging + +/** + * Various utility methods used by Spark. + */ +private[spark] object ShutdownHookManager extends Logging { + val DEFAULT_SHUTDOWN_PRIORITY = 100 + + /** + * The shutdown priority of the SparkContext instance. This is lower than the default + * priority, so that by default hooks are run before the context is shut down. + */ + val SPARK_CONTEXT_SHUTDOWN_PRIORITY = 50 + + /** + * The shutdown priority of temp directory must be lower than the SparkContext shutdown + * priority. Otherwise cleaning the temp directories while Spark jobs are running can + * throw undesirable errors at the time of shutdown. + */ + val TEMP_DIR_SHUTDOWN_PRIORITY = 25 + + private lazy val shutdownHooks = { + val manager = new SparkShutdownHookManager() + manager.install() + manager + } + + private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]() + private val shutdownDeleteTachyonPaths = new scala.collection.mutable.HashSet[String]() + + // Add a shutdown hook to delete the temp dirs when the JVM exits + addShutdownHook(TEMP_DIR_SHUTDOWN_PRIORITY) { () => + logInfo("Shutdown hook called") + // we need to materialize the paths to delete because deleteRecursively removes items from + // shutdownDeletePaths as we are traversing through it. + shutdownDeletePaths.toArray.foreach { dirPath => + try { + logInfo("Deleting directory " + dirPath) + Utils.deleteRecursively(new File(dirPath)) + } catch { + case e: Exception => logError(s"Exception while deleting Spark temp dir: $dirPath", e) + } + } + } + + // Register the path to be deleted via shutdown hook + def registerShutdownDeleteDir(file: File) { + val absolutePath = file.getAbsolutePath() + shutdownDeletePaths.synchronized { + shutdownDeletePaths += absolutePath + } + } + + // Register the tachyon path to be deleted via shutdown hook + def registerShutdownDeleteDir(tachyonfile: TachyonFile) { + val absolutePath = tachyonfile.getPath() + shutdownDeleteTachyonPaths.synchronized { + shutdownDeleteTachyonPaths += absolutePath + } + } + + // Remove the path to be deleted via shutdown hook + def removeShutdownDeleteDir(file: File) { + val absolutePath = file.getAbsolutePath() + shutdownDeletePaths.synchronized { + shutdownDeletePaths.remove(absolutePath) + } + } + + // Remove the tachyon path to be deleted via shutdown hook + def removeShutdownDeleteDir(tachyonfile: TachyonFile) { + val absolutePath = tachyonfile.getPath() + shutdownDeleteTachyonPaths.synchronized { + shutdownDeleteTachyonPaths.remove(absolutePath) + } + } + + // Is the path already registered to be deleted via a shutdown hook ? + def hasShutdownDeleteDir(file: File): Boolean = { + val absolutePath = file.getAbsolutePath() + shutdownDeletePaths.synchronized { + shutdownDeletePaths.contains(absolutePath) + } + } + + // Is the path already registered to be deleted via a shutdown hook ? 
+ def hasShutdownDeleteTachyonDir(file: TachyonFile): Boolean = { + val absolutePath = file.getPath() + shutdownDeleteTachyonPaths.synchronized { + shutdownDeleteTachyonPaths.contains(absolutePath) + } + } + + // Note: if file is child of some registered path, while not equal to it, then return true; + // else false. This is to ensure that two shutdown hooks do not try to delete each others + // paths - resulting in IOException and incomplete cleanup. + def hasRootAsShutdownDeleteDir(file: File): Boolean = { + val absolutePath = file.getAbsolutePath() + val retval = shutdownDeletePaths.synchronized { + shutdownDeletePaths.exists { path => + !absolutePath.equals(path) && absolutePath.startsWith(path) + } + } + if (retval) { + logInfo("path = " + file + ", already present as root for deletion.") + } + retval + } + + // Note: if file is child of some registered path, while not equal to it, then return true; + // else false. This is to ensure that two shutdown hooks do not try to delete each others + // paths - resulting in Exception and incomplete cleanup. + def hasRootAsShutdownDeleteDir(file: TachyonFile): Boolean = { + val absolutePath = file.getPath() + val retval = shutdownDeleteTachyonPaths.synchronized { + shutdownDeleteTachyonPaths.exists { path => + !absolutePath.equals(path) && absolutePath.startsWith(path) + } + } + if (retval) { + logInfo("path = " + file + ", already present as root for deletion.") + } + retval + } + + /** + * Detect whether this thread might be executing a shutdown hook. Will always return true if + * the current thread is a running a shutdown hook but may spuriously return true otherwise (e.g. + * if System.exit was just called by a concurrent thread). + * + * Currently, this detects whether the JVM is shutting down by Runtime#addShutdownHook throwing + * an IllegalStateException. + */ + def inShutdown(): Boolean = { + try { + val hook = new Thread { + override def run() {} + } + // scalastyle:off runtimeaddshutdownhook + Runtime.getRuntime.addShutdownHook(hook) + // scalastyle:on runtimeaddshutdownhook + Runtime.getRuntime.removeShutdownHook(hook) + } catch { + case ise: IllegalStateException => return true + } + false + } + + /** + * Adds a shutdown hook with default priority. + * + * @param hook The code to run during shutdown. + * @return A handle that can be used to unregister the shutdown hook. + */ + def addShutdownHook(hook: () => Unit): AnyRef = { + addShutdownHook(DEFAULT_SHUTDOWN_PRIORITY)(hook) + } + + /** + * Adds a shutdown hook with the given priority. Hooks with lower priority values run + * first. + * + * @param hook The code to run during shutdown. + * @return A handle that can be used to unregister the shutdown hook. + */ + def addShutdownHook(priority: Int)(hook: () => Unit): AnyRef = { + shutdownHooks.add(priority, hook) + } + + /** + * Remove a previously installed shutdown hook. + * + * @param ref A handle returned by `addShutdownHook`. + * @return Whether the hook was removed. + */ + def removeShutdownHook(ref: AnyRef): Boolean = { + shutdownHooks.remove(ref) + } + +} + +private [util] class SparkShutdownHookManager { + + private val hooks = new PriorityQueue[SparkShutdownHook]() + @volatile private var shuttingDown = false + + /** + * Install a hook to run at shutdown and run all registered hooks in order. Hadoop 1.x does not + * have `ShutdownHookManager`, so in that case we just use the JVM's `Runtime` object and hope for + * the best. 
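The inShutdown() helper above relies on a documented JVM behaviour: Runtime.addShutdownHook throws IllegalStateException once shutdown has begun. The probe below is the same trick in isolation (illustrative name; as noted above, it can spuriously report true if another thread calls System.exit concurrently):

    def probablyInShutdown(): Boolean = {
      try {
        val probe = new Thread { override def run(): Unit = () }
        Runtime.getRuntime.addShutdownHook(probe)     // throws if the JVM is shutting down
        Runtime.getRuntime.removeShutdownHook(probe)  // clean up the probe immediately
        false
      } catch {
        case _: IllegalStateException => true
      }
    }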
+ */ + def install(): Unit = { + val hookTask = new Runnable() { + override def run(): Unit = runAll() + } + Try(Utils.classForName("org.apache.hadoop.util.ShutdownHookManager")) match { + case Success(shmClass) => + val fsPriority = classOf[FileSystem] + .getField("SHUTDOWN_HOOK_PRIORITY") + .get(null) // static field, the value is not used + .asInstanceOf[Int] + val shm = shmClass.getMethod("get").invoke(null) + shm.getClass().getMethod("addShutdownHook", classOf[Runnable], classOf[Int]) + .invoke(shm, hookTask, Integer.valueOf(fsPriority + 30)) + + case Failure(_) => + // scalastyle:off runtimeaddshutdownhook + Runtime.getRuntime.addShutdownHook(new Thread(hookTask, "Spark Shutdown Hook")); + // scalastyle:on runtimeaddshutdownhook + } + } + + def runAll(): Unit = { + shuttingDown = true + var nextHook: SparkShutdownHook = null + while ({ nextHook = hooks.synchronized { hooks.poll() }; nextHook != null }) { + Try(Utils.logUncaughtExceptions(nextHook.run())) + } + } + + def add(priority: Int, hook: () => Unit): AnyRef = { + hooks.synchronized { + if (shuttingDown) { + throw new IllegalStateException("Shutdown hooks cannot be modified during shutdown.") + } + val hookRef = new SparkShutdownHook(priority, hook) + hooks.add(hookRef) + hookRef + } + } + + def remove(ref: AnyRef): Boolean = { + hooks.synchronized { hooks.remove(ref) } + } + +} + +private class SparkShutdownHook(private val priority: Int, hook: () => Unit) + extends Comparable[SparkShutdownHook] { + + override def compareTo(other: SparkShutdownHook): Int = { + other.priority - priority + } + + def run(): Unit = hook() + +} diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index 14b1f2a17e70..09864e3f8392 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -17,6 +17,8 @@ package org.apache.spark.util +import com.google.common.collect.MapMaker + import java.lang.management.ManagementFactory import java.lang.reflect.{Field, Modifier} import java.util.{IdentityHashMap, Random} @@ -29,6 +31,20 @@ import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.collection.OpenHashSet +/** + * A trait that allows a class to give [[SizeEstimator]] more accurate size estimation. + * When a class extends it, [[SizeEstimator]] will query the `estimatedSize` first. + * If `estimatedSize` does not return [[None]], [[SizeEstimator]] will use the returned size + * as the size of the object. Otherwise, [[SizeEstimator]] will do the estimation work. + * The difference between a [[KnownSizeEstimation]] and + * [[org.apache.spark.util.collection.SizeTracker]] is that, a + * [[org.apache.spark.util.collection.SizeTracker]] still uses [[SizeEstimator]] to + * estimate the size. However, a [[KnownSizeEstimation]] can provide a better estimation without + * using [[SizeEstimator]]. 
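The KnownSizeEstimation trait declared just below lets an object short-circuit reflection-based size estimation by reporting its own footprint. A hedged sketch of how a class might opt in (the buffer class and the 16-byte header allowance are illustrative, and the trait is re-declared locally so the snippet stands alone):

    trait KnownSizeEstimation {
      def estimatedSize: Long
    }

    // The backing array dominates the footprint, so report it directly rather than
    // letting a reflection-based estimator walk the object graph.
    class PreSizedBuffer(bytes: Array[Byte]) extends KnownSizeEstimation {
      override def estimatedSize: Long = 16L + bytes.length
    }

    // An estimator then prefers the self-reported size, as in the match added below.
    def estimate(obj: AnyRef): Long = obj match {
      case s: KnownSizeEstimation => s.estimatedSize
      case _                      => 64L  // placeholder for the reflective fallback
    }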
+ */ +private[spark] trait KnownSizeEstimation { + def estimatedSize: Long +} /** * :: DeveloperApi :: @@ -73,7 +89,8 @@ object SizeEstimator extends Logging { private val ALIGN_SIZE = 8 // A cache of ClassInfo objects for each class - private val classInfos = new ConcurrentHashMap[Class[_], ClassInfo] + // We use weakKeys to allow GC of dynamically created classes + private val classInfos = new MapMaker().weakKeys().makeMap[Class[_], ClassInfo]() // Object and pointer sizes are arch dependent private var is64bit = false @@ -197,10 +214,15 @@ object SizeEstimator extends Logging { // the size estimator since it references the whole REPL. Do nothing in this case. In // general all ClassLoaders and Classes will be shared between objects anyway. } else { - val classInfo = getClassInfo(cls) - state.size += alignSize(classInfo.shellSize) - for (field <- classInfo.pointerFields) { - state.enqueue(field.get(obj)) + obj match { + case s: KnownSizeEstimation => + state.size += s.estimatedSize + case _ => + val classInfo = getClassInfo(cls) + state.size += alignSize(classInfo.shellSize) + for (field <- classInfo.pointerFields) { + state.enqueue(field.get(obj)) + } } } } diff --git a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala index ad3db1fbb57e..5e322557e964 100644 --- a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala +++ b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala @@ -29,11 +29,15 @@ private[spark] object SparkUncaughtExceptionHandler override def uncaughtException(thread: Thread, exception: Throwable) { try { - logError("Uncaught exception in thread " + thread, exception) + // Make it explicit that uncaught exceptions are thrown when container is shutting down. + // It will help users when they analyze the executor logs + val inShutdownMsg = if (ShutdownHookManager.inShutdown()) "[Container in shutdown] " else "" + val errMsg = "Uncaught exception in thread " + logError(inShutdownMsg + errMsg + thread, exception) // We may have been called from a shutdown hook. If so, we must not call System.exit(). // (If we do, we will deadlock.) - if (!Utils.inShutdown()) { + if (!ShutdownHookManager.inShutdown()) { if (exception.isInstanceOf[OutOfMemoryError]) { System.exit(SparkExitCode.OOM) } else { diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index ca5624a3d8b3..f9fbe2ff858c 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -15,12 +15,12 @@ * limitations under the License. */ - package org.apache.spark.util import java.util.concurrent._ import scala.concurrent.{ExecutionContext, ExecutionContextExecutor} +import scala.util.control.NonFatal import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} @@ -56,10 +56,18 @@ private[spark] object ThreadUtils { * Create a cached thread pool whose max number of threads is `maxThreadNumber`. Thread names * are formatted as prefix-ID, where ID is a unique, sequentially assigned integer. 
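The rewritten newDaemonCachedThreadPool just below bounds the number of threads while still letting idle ones exit: the core size equals the maximum, the queue is unbounded, and allowCoreThreadTimeOut(true) supplies the keep-alive behaviour. A standalone sketch of the same construction (hypothetical helper name):

    import java.util.concurrent.{LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit}
    import com.google.common.util.concurrent.ThreadFactoryBuilder

    def boundedCachedPool(prefix: String, maxThreads: Int, keepAliveSeconds: Int = 60): ThreadPoolExecutor = {
      val factory = new ThreadFactoryBuilder().setDaemon(true).setNameFormat(prefix + "-%d").build()
      val pool = new ThreadPoolExecutor(
        maxThreads,                      // corePoolSize: threads created before queuing kicks in
        maxThreads,                      // maximumPoolSize: unused with an unbounded queue
        keepAliveSeconds, TimeUnit.SECONDS,
        new LinkedBlockingQueue[Runnable],
        factory)
      pool.allowCoreThreadTimeOut(true)  // let idle core threads die after the keep-alive
      pool
    }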
*/ - def newDaemonCachedThreadPool(prefix: String, maxThreadNumber: Int): ThreadPoolExecutor = { + def newDaemonCachedThreadPool( + prefix: String, maxThreadNumber: Int, keepAliveSeconds: Int = 60): ThreadPoolExecutor = { val threadFactory = namedThreadFactory(prefix) - new ThreadPoolExecutor( - 0, maxThreadNumber, 60L, TimeUnit.SECONDS, new SynchronousQueue[Runnable], threadFactory) + val threadPool = new ThreadPoolExecutor( + maxThreadNumber, // corePoolSize: the max number of threads to create before queuing the tasks + maxThreadNumber, // maximumPoolSize: because we use LinkedBlockingDeque, this one is not used + keepAliveSeconds, + TimeUnit.SECONDS, + new LinkedBlockingQueue[Runnable], + threadFactory) + threadPool.allowCoreThreadTimeOut(true) + threadPool } /** @@ -80,10 +88,72 @@ private[spark] object ThreadUtils { } /** - * Wrapper over newSingleThreadScheduledExecutor. + * Wrapper over ScheduledThreadPoolExecutor. */ def newDaemonSingleThreadScheduledExecutor(threadName: String): ScheduledExecutorService = { val threadFactory = new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadName).build() - Executors.newSingleThreadScheduledExecutor(threadFactory) + val executor = new ScheduledThreadPoolExecutor(1, threadFactory) + // By default, a cancelled task is not automatically removed from the work queue until its delay + // elapses. We have to enable it manually. + executor.setRemoveOnCancelPolicy(true) + executor + } + + /** + * Run a piece of code in a new thread and return the result. Exception in the new thread is + * thrown in the caller thread with an adjusted stack trace that removes references to this + * method for clarity. The exception stack traces will be like the following + * + * SomeException: exception-message + * at CallerClass.body-method (sourcefile.scala) + * at ... run in separate thread using org.apache.spark.util.ThreadUtils ... () + * at CallerClass.caller-method (sourcefile.scala) + * ... + */ + def runInNewThread[T]( + threadName: String, + isDaemon: Boolean = true)(body: => T): T = { + @volatile var exception: Option[Throwable] = None + @volatile var result: T = null.asInstanceOf[T] + + val thread = new Thread(threadName) { + override def run(): Unit = { + try { + result = body + } catch { + case NonFatal(e) => + exception = Some(e) + } + } + } + thread.setDaemon(isDaemon) + thread.start() + thread.join() + + exception match { + case Some(realException) => + // Remove the part of the stack that shows method calls into this helper method + // This means drop everything from the top until the stack element + // ThreadUtils.runInNewThread(), and then drop that as well (hence the `drop(1)`). + val baseStackTrace = Thread.currentThread().getStackTrace().dropWhile( + ! _.getClassName.contains(this.getClass.getSimpleName)).drop(1) + + // Remove the part of the new thread stack that shows methods call from this helper method + val extraStackTrace = realException.getStackTrace.takeWhile( + ! _.getClassName.contains(this.getClass.getSimpleName)) + + // Combine the two stack traces, with a place holder just specifying that there + // was a helper method used, without any further details of the helper + val placeHolderStackElem = new StackTraceElement( + s"... 
run in separate thread using ${ThreadUtils.getClass.getName.stripSuffix("$")} ..", + " ", "", -1) + val finalStackTrace = extraStackTrace ++ Seq(placeHolderStackElem) ++ baseStackTrace + + // Update the stack trace and rethrow the exception in the caller thread + realException.setStackTrace(finalStackTrace) + throw realException + case None => + result + } } } diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala index 8de75ba9a9c9..d7e5143c3095 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala @@ -21,7 +21,8 @@ import java.util.Set import java.util.Map.Entry import java.util.concurrent.ConcurrentHashMap -import scala.collection.{JavaConversions, mutable} +import scala.collection.JavaConverters._ +import scala.collection.mutable import org.apache.spark.Logging @@ -50,8 +51,7 @@ private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = fa } def iterator: Iterator[(A, B)] = { - val jIterator = getEntrySet.iterator - JavaConversions.asScalaIterator(jIterator).map(kv => (kv.getKey, kv.getValue.value)) + getEntrySet.iterator.asScala.map(kv => (kv.getKey, kv.getValue.value)) } def getEntrySet: Set[Entry[A, TimeStampedValue[B]]] = internalMap.entrySet @@ -90,9 +90,7 @@ private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = fa } override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = { - JavaConversions.mapAsScalaConcurrentMap(internalMap) - .map { case (k, TimeStampedValue(v, t)) => (k, v) } - .filter(p) + internalMap.asScala.map { case (k, TimeStampedValue(v, t)) => (k, v) }.filter(p) } override def empty: mutable.Map[A, B] = new TimeStampedHashMap[A, B]() diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala index 7cd8f28b12dd..65efeb1f4c19 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala @@ -19,7 +19,7 @@ package org.apache.spark.util import java.util.concurrent.ConcurrentHashMap -import scala.collection.JavaConversions +import scala.collection.JavaConverters._ import scala.collection.mutable.Set private[spark] class TimeStampedHashSet[A] extends Set[A] { @@ -31,7 +31,7 @@ private[spark] class TimeStampedHashSet[A] extends Set[A] { def iterator: Iterator[A] = { val jIterator = internalMap.entrySet().iterator() - JavaConversions.asScalaIterator(jIterator).map(_.getKey) + jIterator.asScala.map(_.getKey) } override def + (elem: A): Set[A] = { diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index c4012d0e83f7..fce89dfccfe2 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -21,16 +21,17 @@ import java.io._ import java.lang.management.ManagementFactory import java.net._ import java.nio.ByteBuffer -import java.util.{PriorityQueue, Properties, Locale, Random, UUID} +import java.nio.channels.Channels import java.util.concurrent._ +import java.util.{Locale, Properties, Random, UUID} import javax.net.ssl.HttpsURLConnection -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.Map import scala.collection.mutable.ArrayBuffer import scala.io.Source import 
scala.reflect.ClassTag -import scala.util.{Failure, Success, Try} +import scala.util.Try import scala.util.control.{ControlThrowable, NonFatal} import com.google.common.io.{ByteStreams, Files} @@ -42,7 +43,6 @@ import org.apache.hadoop.security.UserGroupInformation import org.apache.log4j.PropertyConfigurator import org.eclipse.jetty.util.MultiException import org.json4s._ - import tachyon.TachyonURI import tachyon.client.{TachyonFS, TachyonFile} @@ -57,6 +57,7 @@ private[spark] case class CallSite(shortForm: String, longForm: String) private[spark] object CallSite { val SHORT_FORM = "callSite.short" val LONG_FORM = "callSite.long" + val empty = CallSite("", "") } /** @@ -65,21 +66,6 @@ private[spark] object CallSite { private[spark] object Utils extends Logging { val random = new Random() - val DEFAULT_SHUTDOWN_PRIORITY = 100 - - /** - * The shutdown priority of the SparkContext instance. This is lower than the default - * priority, so that by default hooks are run before the context is shut down. - */ - val SPARK_CONTEXT_SHUTDOWN_PRIORITY = 50 - - /** - * The shutdown priority of temp directory must be lower than the SparkContext shutdown - * priority. Otherwise cleaning the temp directories while Spark jobs are running can - * throw undesirable errors at the time of shutdown. - */ - val TEMP_DIR_SHUTDOWN_PRIORITY = 25 - /** * Define a default value for driver memory here since this value is referenced across the code * base and nearly all files already use Utils.scala @@ -90,9 +76,6 @@ private[spark] object Utils extends Logging { @volatile private var localRootDirs: Array[String] = null - private val shutdownHooks = new SparkShutdownHookManager() - shutdownHooks.install() - /** Serialize an object using Java serialization */ def serialize[T](o: T): Array[Byte] = { val bos = new ByteArrayOutputStream() @@ -195,7 +178,7 @@ private[spark] object Utils extends Logging { /** * Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.DataOutput]] */ - def writeByteBuffer(bb: ByteBuffer, out: ObjectOutput): Unit = { + def writeByteBuffer(bb: ByteBuffer, out: DataOutput): Unit = { if (bb.hasArray) { out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) } else { @@ -205,84 +188,17 @@ private[spark] object Utils extends Logging { } } - private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]() - private val shutdownDeleteTachyonPaths = new scala.collection.mutable.HashSet[String]() - - // Add a shutdown hook to delete the temp dirs when the JVM exits - addShutdownHook(TEMP_DIR_SHUTDOWN_PRIORITY) { () => - logInfo("Shutdown hook called") - shutdownDeletePaths.foreach { dirPath => - try { - logInfo("Deleting directory " + dirPath) - Utils.deleteRecursively(new File(dirPath)) - } catch { - case e: Exception => logError(s"Exception while deleting Spark temp dir: $dirPath", e) - } - } - } - - // Register the path to be deleted via shutdown hook - def registerShutdownDeleteDir(file: File) { - val absolutePath = file.getAbsolutePath() - shutdownDeletePaths.synchronized { - shutdownDeletePaths += absolutePath - } - } - - // Register the tachyon path to be deleted via shutdown hook - def registerShutdownDeleteDir(tachyonfile: TachyonFile) { - val absolutePath = tachyonfile.getPath() - shutdownDeleteTachyonPaths.synchronized { - shutdownDeleteTachyonPaths += absolutePath - } - } - - // Is the path already registered to be deleted via a shutdown hook ? 
- def hasShutdownDeleteDir(file: File): Boolean = { - val absolutePath = file.getAbsolutePath() - shutdownDeletePaths.synchronized { - shutdownDeletePaths.contains(absolutePath) - } - } - - // Is the path already registered to be deleted via a shutdown hook ? - def hasShutdownDeleteTachyonDir(file: TachyonFile): Boolean = { - val absolutePath = file.getPath() - shutdownDeleteTachyonPaths.synchronized { - shutdownDeleteTachyonPaths.contains(absolutePath) - } - } - - // Note: if file is child of some registered path, while not equal to it, then return true; - // else false. This is to ensure that two shutdown hooks do not try to delete each others - // paths - resulting in IOException and incomplete cleanup. - def hasRootAsShutdownDeleteDir(file: File): Boolean = { - val absolutePath = file.getAbsolutePath() - val retval = shutdownDeletePaths.synchronized { - shutdownDeletePaths.exists { path => - !absolutePath.equals(path) && absolutePath.startsWith(path) - } - } - if (retval) { - logInfo("path = " + file + ", already present as root for deletion.") - } - retval - } - - // Note: if file is child of some registered path, while not equal to it, then return true; - // else false. This is to ensure that two shutdown hooks do not try to delete each others - // paths - resulting in Exception and incomplete cleanup. - def hasRootAsShutdownDeleteDir(file: TachyonFile): Boolean = { - val absolutePath = file.getPath() - val retval = shutdownDeleteTachyonPaths.synchronized { - shutdownDeleteTachyonPaths.exists { path => - !absolutePath.equals(path) && absolutePath.startsWith(path) - } - } - if (retval) { - logInfo("path = " + file + ", already present as root for deletion.") + /** + * Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.OutputStream]] + */ + def writeByteBuffer(bb: ByteBuffer, out: OutputStream): Unit = { + if (bb.hasArray) { + out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) + } else { + val bbval = new Array[Byte](bb.remaining()) + bb.get(bbval) + out.write(bbval) } - retval } /** @@ -333,7 +249,7 @@ private[spark] object Utils extends Logging { root: String = System.getProperty("java.io.tmpdir"), namePrefix: String = "spark"): File = { val dir = createDirectory(root, namePrefix) - registerShutdownDeleteDir(dir) + ShutdownHookManager.registerShutdownDeleteDir(dir) dir } @@ -415,6 +331,30 @@ private[spark] object Utils extends Logging { } /** + * A file name may contain some invalid URI characters, such as " ". This method will convert the + * file name to a raw path accepted by `java.net.URI(String)`. + * + * Note: the file name must not contain "/" or "\" + */ + def encodeFileNameToURIRawPath(fileName: String): String = { + require(!fileName.contains("/") && !fileName.contains("\\")) + // `file` and `localhost` are not used. Just to prevent URI from parsing `fileName` as + // scheme or host. The prefix "/" is required because URI doesn't accept a relative path. + // We should remove it after we get the raw path. + new URI("file", null, "localhost", -1, "/" + fileName, null, null).getRawPath.substring(1) + } + + /** + * Get the file name from uri's raw path and decode it. If the raw path of uri ends with "/", + * return the name before the last "/". + */ + def decodeFileNameInURI(uri: URI): String = { + val rawPath = uri.getRawPath + val rawFileName = rawPath.split("/").last + new URI("file:///" + rawFileName).getPath.substring(1) + } + + /** * Download a file or directory to target directory. 
Supports fetching the file in a variety of * ways, including HTTP, Hadoop-compatible filesystems, and files on a standard filesystem, based * on the URL parameter. Fetching directories is only supported from Hadoop-compatible @@ -435,7 +375,7 @@ private[spark] object Utils extends Logging { hadoopConf: Configuration, timestamp: Long, useCache: Boolean) { - val fileName = url.split("/").last + val fileName = decodeFileNameInURI(new URI(url)) val targetFile = new File(targetDir, fileName) val fetchCacheEnabled = conf.getBoolean("spark.files.useFetchCache", defaultValue = true) if (useCache && fetchCacheEnabled) { @@ -633,6 +573,14 @@ private[spark] object Utils extends Logging { val uri = new URI(url) val fileOverwrite = conf.getBoolean("spark.files.overwrite", defaultValue = false) Option(uri.getScheme).getOrElse("file") match { + case "spark" => + if (SparkEnv.get == null) { + throw new IllegalStateException( + "Cannot retrieve files with 'spark' scheme without an active SparkEnv.") + } + val source = SparkEnv.get.rpcEnv.openChannel(url) + val is = Channels.newInputStream(source) + downloadFile(url, is, targetFile, fileOverwrite) case "http" | "https" | "ftp" => var uc: URLConnection = null if (securityMgr.isAuthenticationEnabled()) { @@ -747,6 +695,7 @@ private[spark] object Utils extends Logging { * logic of locating the local directories according to deployment mode. */ def getConfiguredLocalDirs(conf: SparkConf): Array[String] = { + val shuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) if (isRunningInYarnContainer(conf)) { // If we are in yarn mode, systems can have different disk layouts so we must set it // to what Yarn on this system said was available. Note this assumes that Yarn has @@ -755,13 +704,23 @@ private[spark] object Utils extends Logging { getYarnLocalDirs(conf).split(",") } else if (conf.getenv("SPARK_EXECUTOR_DIRS") != null) { conf.getenv("SPARK_EXECUTOR_DIRS").split(File.pathSeparator) + } else if (conf.getenv("SPARK_LOCAL_DIRS") != null) { + conf.getenv("SPARK_LOCAL_DIRS").split(",") + } else if (conf.getenv("MESOS_DIRECTORY") != null && !shuffleServiceEnabled) { + // Mesos already creates a directory per Mesos task. Spark should use that directory + // instead so all temporary files are automatically cleaned up when the Mesos task ends. + // Note that we don't want this if the shuffle service is enabled because we want to + // continue to serve shuffle files after the executors that wrote them have already exited. + Array(conf.getenv("MESOS_DIRECTORY")) } else { + if (conf.getenv("MESOS_DIRECTORY") != null && shuffleServiceEnabled) { + logInfo("MESOS_DIRECTORY available but not using provided Mesos sandbox because " + + "spark.shuffle.service.enabled is enabled.") + } // In non-Yarn mode (or for the driver in yarn-client mode), we cannot trust the user // configuration to point to a secure directory. So create a subdirectory with restricted // permissions under each listed directory. - Option(conf.getenv("SPARK_LOCAL_DIRS")) - .getOrElse(conf.get("spark.local.dir", System.getProperty("java.io.tmpdir"))) - .split(",") + conf.get("spark.local.dir", System.getProperty("java.io.tmpdir")).split(",") } } @@ -846,12 +805,12 @@ private[spark] object Utils extends Logging { // getNetworkInterfaces returns ifs in reverse order compared to ifconfig output order // on unix-like system. On windows, it returns in index order. // It's more proper to pick ip address following system output order. 
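
The new encodeFileNameToURIRawPath/decodeFileNameInURI helpers let doFetchFile recover file names that contain URI-unsafe characters such as spaces. A standalone sketch of the same round trip, mirroring the helpers added in the patch (the spark:// URL below is illustrative only):

```scala
import java.net.URI

object UriFileNameExample {
  // Mirrors the idea behind Utils.encodeFileNameToURIRawPath: let java.net.URI
  // do the percent-encoding by building a URI with a dummy scheme and host.
  def encodeFileNameToURIRawPath(fileName: String): String = {
    require(!fileName.contains("/") && !fileName.contains("\\"))
    new URI("file", null, "localhost", -1, "/" + fileName, null, null)
      .getRawPath.substring(1)
  }

  // Mirrors Utils.decodeFileNameInURI: take the last raw path segment and decode it.
  def decodeFileNameInURI(uri: URI): String = {
    val rawFileName = uri.getRawPath.split("/").last
    new URI("file:///" + rawFileName).getPath.substring(1)
  }

  def main(args: Array[String]): Unit = {
    val encoded = encodeFileNameToURIRawPath("my jar file.jar")
    println(encoded)   // my%20jar%20file.jar
    val decoded = decodeFileNameInURI(new URI("spark://host:1234/files/" + encoded))
    println(decoded)   // my jar file.jar
  }
}
```
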
- val activeNetworkIFs = NetworkInterface.getNetworkInterfaces.toList + val activeNetworkIFs = NetworkInterface.getNetworkInterfaces.asScala.toSeq val reOrderedNetworkIFs = if (isWindows) activeNetworkIFs else activeNetworkIFs.reverse for (ni <- reOrderedNetworkIFs) { - val addresses = ni.getInetAddresses.toList - .filterNot(addr => addr.isLinkLocalAddress || addr.isLoopbackAddress) + val addresses = ni.getInetAddresses.asScala + .filterNot(addr => addr.isLinkLocalAddress || addr.isLoopbackAddress).toSeq if (addresses.nonEmpty) { val addr = addresses.find(_.isInstanceOf[Inet4Address]).getOrElse(addresses.head) // because of Inet6Address.toHostName may add interface at the end if it knows about it @@ -973,9 +932,7 @@ private[spark] object Utils extends Logging { if (savedIOException != null) { throw savedIOException } - shutdownDeletePaths.synchronized { - shutdownDeletePaths.remove(file.getAbsolutePath) - } + ShutdownHookManager.removeShutdownDeleteDir(file) } } finally { if (!file.delete()) { @@ -1041,7 +998,7 @@ private[spark] object Utils extends Logging { } /** - * Convert a time parameter such as (50s, 100ms, or 250us) to microseconds for internal use. If + * Convert a time parameter such as (50s, 100ms, or 250us) to seconds for internal use. If * no suffix is provided, the passed number is assumed to be in seconds. */ def timeStringAsSeconds(str: String): Long = { @@ -1466,7 +1423,7 @@ private[spark] object Utils extends Logging { file.getAbsolutePath, effectiveStartIndex, effectiveEndIndex)) } sum += fileToLength(file) - logDebug(s"After processing file $file, string built is ${stringBuffer.toString}}") + logDebug(s"After processing file $file, string built is ${stringBuffer.toString}") } stringBuffer.toString } @@ -1478,27 +1435,6 @@ private[spark] object Utils extends Logging { serializer.deserialize[T](serializer.serialize(value)) } - /** - * Detect whether this thread might be executing a shutdown hook. Will always return true if - * the current thread is a running a shutdown hook but may spuriously return true otherwise (e.g. - * if System.exit was just called by a concurrent thread). - * - * Currently, this detects whether the JVM is shutting down by Runtime#addShutdownHook throwing - * an IllegalStateException. - */ - def inShutdown(): Boolean = { - try { - val hook = new Thread { - override def run() {} - } - Runtime.getRuntime.addShutdownHook(hook) - Runtime.getRuntime.removeShutdownHook(hook) - } catch { - case ise: IllegalStateException => return true - } - false - } - private def isSpace(c: Char): Boolean = { " \t\r\n".indexOf(c) != -1 } @@ -1619,10 +1555,8 @@ private[spark] object Utils extends Logging { * properties which have been set explicitly, as well as those for which only a default value * has been defined. 
*/ def getSystemProperties: Map[String, String] = { - val sysProps = for (key <- System.getProperties.stringPropertyNames()) yield - (key, System.getProperty(key)) - - sysProps.toMap + System.getProperties.stringPropertyNames().asScala + .map(key => (key, System.getProperty(key))).toMap } /** @@ -1872,6 +1806,13 @@ private[spark] object Utils extends Logging { if (uri.getScheme() != null) { return uri } + // make sure to handle if the path has a fragment (applies to yarn + // distributed cache) + if (uri.getFragment() != null) { + val absoluteURI = new File(uri.getPath()).getAbsoluteFile().toURI() + return new URI(absoluteURI.getScheme(), absoluteURI.getHost(), absoluteURI.getPath(), + uri.getFragment()) + } } catch { case e: URISyntaxException => } @@ -1933,7 +1874,8 @@ private[spark] object Utils extends Logging { try { val properties = new Properties() properties.load(inReader) - properties.stringPropertyNames().map(k => (k, properties(k).trim)).toMap + properties.stringPropertyNames().asScala.map( + k => (k, properties.getProperty(k).trim)).toMap } catch { case e: IOException => throw new SparkException(s"Failed when loading Spark properties from $filename", e) @@ -2010,6 +1952,7 @@ private[spark] object Utils extends Logging { * This is expected to throw java.net.BindException on port collision. * @param conf A SparkConf used to get the maximum number of retries when binding to a port. * @param serviceName Name of the service. + * @return (service: T, port: Int) */ def startServiceOnPort[T]( startPort: Int, @@ -2062,7 +2005,8 @@ private[spark] object Utils extends Logging { return true } isBindCollision(e.getCause) - case e: MultiException => e.getThrowables.exists(isBindCollision) + case e: MultiException => + e.getThrowables.asScala.exists(isBindCollision) case e: Exception => isBindCollision(e.getCause) case _ => false } @@ -2221,37 +2165,6 @@ private[spark] object Utils extends Logging { msg.startsWith(BACKUP_STANDALONE_MASTER_PREFIX) } - /** - * Adds a shutdown hook with default priority. - * - * @param hook The code to run during shutdown. - * @return A handle that can be used to unregister the shutdown hook. - */ - def addShutdownHook(hook: () => Unit): AnyRef = { - addShutdownHook(DEFAULT_SHUTDOWN_PRIORITY)(hook) - } - - /** - * Adds a shutdown hook with the given priority. Hooks with lower priority values run - * first. - * - * @param hook The code to run during shutdown. - * @return A handle that can be used to unregister the shutdown hook. - */ - def addShutdownHook(priority: Int)(hook: () => Unit): AnyRef = { - shutdownHooks.add(priority, hook) - } - - /** - * Remove a previously installed shutdown hook. - * - * @param ref A handle returned by `addShutdownHook`. - * @return Whether the hook was removed. - */ - def removeShutdownHook(ref: AnyRef): Boolean = { - shutdownHooks.remove(ref) - } - /** * To avoid calling `Utils.getCallSite` for every single RDD we create in the body, * set a dummy call site that RDDs use instead. This is for performance optimization. @@ -2286,70 +2199,28 @@ private[spark] object Utils extends Logging { isInDirectory(parent, child.getParentFile) } -} - -private [util] class SparkShutdownHookManager { - - private val hooks = new PriorityQueue[SparkShutdownHook]() - private var shuttingDown = false - /** - * Install a hook to run at shutdown and run all registered hooks in order. Hadoop 1.x does not - * have `ShutdownHookManager`, so in that case we just use the JVM's `Runtime` object and hope for - * the best. 
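
The getSystemProperties and loadProperties changes above both walk `stringPropertyNames().asScala` and read values through `getProperty`, so keys defined only by a defaults table are included and resolved. A small standalone illustration of that behavior (the property names are made up for the example):

```scala
import java.util.Properties

import scala.collection.JavaConverters._

object PropertiesToMapExample {
  // stringPropertyNames() also returns keys that only have a default value,
  // and getProperty resolves those defaults (unlike get on the raw table).
  def toMap(props: Properties): Map[String, String] =
    props.stringPropertyNames().asScala
      .map(k => k -> props.getProperty(k).trim)
      .toMap

  def main(args: Array[String]): Unit = {
    val defaults = new Properties()
    defaults.setProperty("spark.app.name", "demo")
    val props = new Properties(defaults)
    props.setProperty("spark.master", " local[2] ")
    println(toMap(props))  // both keys appear; values are trimmed
  }
}
```
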
+ * Return whether dynamic allocation is enabled in the given conf + * Dynamic allocation and explicitly setting the number of executors are inherently + * incompatible. In environments where dynamic allocation is turned on by default, + * the latter should override the former (SPARK-9092). */ - def install(): Unit = { - val hookTask = new Runnable() { - override def run(): Unit = runAll() - } - Try(Utils.classForName("org.apache.hadoop.util.ShutdownHookManager")) match { - case Success(shmClass) => - val fsPriority = classOf[FileSystem].getField("SHUTDOWN_HOOK_PRIORITY").get() - .asInstanceOf[Int] - val shm = shmClass.getMethod("get").invoke(null) - shm.getClass().getMethod("addShutdownHook", classOf[Runnable], classOf[Int]) - .invoke(shm, hookTask, Integer.valueOf(fsPriority + 30)) - - case Failure(_) => - Runtime.getRuntime.addShutdownHook(new Thread(hookTask, "Spark Shutdown Hook")); - } - } - - def runAll(): Unit = synchronized { - shuttingDown = true - while (!hooks.isEmpty()) { - Try(Utils.logUncaughtExceptions(hooks.poll().run())) - } + def isDynamicAllocationEnabled(conf: SparkConf): Boolean = { + conf.getBoolean("spark.dynamicAllocation.enabled", false) && + conf.getInt("spark.executor.instances", 0) == 0 } - def add(priority: Int, hook: () => Unit): AnyRef = synchronized { - checkState() - val hookRef = new SparkShutdownHook(priority, hook) - hooks.add(hookRef) - hookRef + def tryWithResource[R <: Closeable, T](createResource: => R)(f: R => T): T = { + val resource = createResource + try f.apply(resource) finally resource.close() } - def remove(ref: AnyRef): Boolean = synchronized { - hooks.remove(ref) - } - - private def checkState(): Unit = { - if (shuttingDown) { - throw new IllegalStateException("Shutdown hooks cannot be modified during shutdown.") - } - } - -} - -private class SparkShutdownHook(private val priority: Int, hook: () => Unit) - extends Comparable[SparkShutdownHook] { - - override def compareTo(other: SparkShutdownHook): Int = { - other.priority - priority + /** + * Returns a path of temporary file which is in the same directory with `path`. + */ + def tempFileWith(path: File): File = { + new File(path.getAbsolutePath + "." + UUID.randomUUID()) } - - def run(): Unit = hook() - } /** diff --git a/core/src/main/scala/org/apache/spark/util/Vector.scala b/core/src/main/scala/org/apache/spark/util/Vector.scala index 2ed827eab46d..6b3fa8491904 100644 --- a/core/src/main/scala/org/apache/spark/util/Vector.scala +++ b/core/src/main/scala/org/apache/spark/util/Vector.scala @@ -122,6 +122,7 @@ class Vector(val elements: Array[Double]) extends Serializable { override def toString: String = elements.mkString("(", ", ", ")") } +@deprecated("Use Vectors.dense from Spark's mllib.linalg package instead.", "1.0.0") object Vector { def apply(elements: Array[Double]): Vector = new Vector(elements) diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala index 9c15b1188d91..7ab67fc3a2de 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala @@ -32,6 +32,17 @@ class BitSet(numBits: Int) extends Serializable { */ def capacity: Int = numWords * 64 + /** + * Clear all set bits. 
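
The Utils.tryWithResource helper added above is the classic loan pattern: acquire a Closeable, hand it to a function, and close it even when the function throws. A standalone sketch of the same shape (the file path is just an example):

```scala
import java.io.{BufferedReader, Closeable, FileReader}

object TryWithResourceExample {
  // Same shape as the tryWithResource helper added in this patch.
  def tryWithResource[R <: Closeable, T](createResource: => R)(f: R => T): T = {
    val resource = createResource
    try f(resource) finally resource.close()
  }

  def main(args: Array[String]): Unit = {
    // Illustrative path only; any readable text file works.
    val firstLine = tryWithResource(new BufferedReader(new FileReader("/etc/hosts"))) {
      reader => reader.readLine()
    }
    println(firstLine)
  }
}
```
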
+ */ + def clear(): Unit = { + var i = 0 + while (i < numWords) { + words(i) = 0L + i += 1 + } + } + /** * Set all the bits up to a given index */ diff --git a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala deleted file mode 100644 index ae60f3b0cb55..000000000000 --- a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala +++ /dev/null @@ -1,146 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util.collection - -import java.io.OutputStream - -import scala.collection.mutable.ArrayBuffer - -/** - * A logical byte buffer that wraps a list of byte arrays. All the byte arrays have equal size. The - * advantage of this over a standard ArrayBuffer is that it can grow without claiming large amounts - * of memory and needing to copy the full contents. The disadvantage is that the contents don't - * occupy a contiguous segment of memory. - */ -private[spark] class ChainedBuffer(chunkSize: Int) { - - private val chunkSizeLog2: Int = java.lang.Long.numberOfTrailingZeros( - java.lang.Long.highestOneBit(chunkSize)) - assert((1 << chunkSizeLog2) == chunkSize, - s"ChainedBuffer chunk size $chunkSize must be a power of two") - private val chunks: ArrayBuffer[Array[Byte]] = new ArrayBuffer[Array[Byte]]() - private var _size: Long = 0 - - /** - * Feed bytes from this buffer into a DiskBlockObjectWriter. - * - * @param pos Offset in the buffer to read from. - * @param os OutputStream to read into. - * @param len Number of bytes to read. - */ - def read(pos: Long, os: OutputStream, len: Int): Unit = { - if (pos + len > _size) { - throw new IndexOutOfBoundsException( - s"Read of $len bytes at position $pos would go past size ${_size} of buffer") - } - var chunkIndex: Int = (pos >> chunkSizeLog2).toInt - var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt - var written: Int = 0 - while (written < len) { - val toRead: Int = math.min(len - written, chunkSize - posInChunk) - os.write(chunks(chunkIndex), posInChunk, toRead) - written += toRead - chunkIndex += 1 - posInChunk = 0 - } - } - - /** - * Read bytes from this buffer into a byte array. - * - * @param pos Offset in the buffer to read from. - * @param bytes Byte array to read into. - * @param offs Offset in the byte array to read to. - * @param len Number of bytes to read. 
- */ - def read(pos: Long, bytes: Array[Byte], offs: Int, len: Int): Unit = { - if (pos + len > _size) { - throw new IndexOutOfBoundsException( - s"Read of $len bytes at position $pos would go past size of buffer") - } - var chunkIndex: Int = (pos >> chunkSizeLog2).toInt - var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt - var written: Int = 0 - while (written < len) { - val toRead: Int = math.min(len - written, chunkSize - posInChunk) - System.arraycopy(chunks(chunkIndex), posInChunk, bytes, offs + written, toRead) - written += toRead - chunkIndex += 1 - posInChunk = 0 - } - } - - /** - * Write bytes from a byte array into this buffer. - * - * @param pos Offset in the buffer to write to. - * @param bytes Byte array to write from. - * @param offs Offset in the byte array to write from. - * @param len Number of bytes to write. - */ - def write(pos: Long, bytes: Array[Byte], offs: Int, len: Int): Unit = { - if (pos > _size) { - throw new IndexOutOfBoundsException( - s"Write at position $pos starts after end of buffer ${_size}") - } - // Grow if needed - val endChunkIndex: Int = ((pos + len - 1) >> chunkSizeLog2).toInt - while (endChunkIndex >= chunks.length) { - chunks += new Array[Byte](chunkSize) - } - - var chunkIndex: Int = (pos >> chunkSizeLog2).toInt - var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt - var written: Int = 0 - while (written < len) { - val toWrite: Int = math.min(len - written, chunkSize - posInChunk) - System.arraycopy(bytes, offs + written, chunks(chunkIndex), posInChunk, toWrite) - written += toWrite - chunkIndex += 1 - posInChunk = 0 - } - - _size = math.max(_size, pos + len) - } - - /** - * Total size of buffer that can be written to without allocating additional memory. - */ - def capacity: Long = chunks.size.toLong * chunkSize - - /** - * Size of the logical buffer. - */ - def size: Long = _size -} - -/** - * Output stream that writes to a ChainedBuffer. - */ -private[spark] class ChainedBufferOutputStream(chainedBuffer: ChainedBuffer) extends OutputStream { - private var pos: Long = 0 - - override def write(b: Int): Unit = { - throw new UnsupportedOperationException() - } - - override def write(bytes: Array[Byte], offs: Int, len: Int): Unit = { - chainedBuffer.write(pos, bytes, offs, len) - pos += len - } -} diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index d166037351c3..f6d81ee5bf05 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -28,8 +28,10 @@ import com.google.common.io.ByteStreams import org.apache.spark.{Logging, SparkEnv, TaskContext} import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.serializer.{DeserializationStream, Serializer} import org.apache.spark.storage.{BlockId, BlockManager} +import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalAppendOnlyMap.HashComparator import org.apache.spark.executor.ShuffleWriteMetrics @@ -48,16 +50,6 @@ import org.apache.spark.executor.ShuffleWriteMetrics * However, if the spill threshold is too low, we spill frequently and incur unnecessary disk * writes. This may lead to a performance regression compared to the normal case of using the * non-spilling AppendOnlyMap. 
- * - * Two parameters control the memory threshold: - * - * `spark.shuffle.memoryFraction` specifies the collective amount of memory used for storing - * these maps as a fraction of the executor's total memory. Since each concurrently running - * task maintains one map, the actual threshold for each map is this quantity divided by the - * number of running tasks. - * - * `spark.shuffle.safetyFraction` specifies an additional margin of safety as a fraction of - * this threshold, in case map size estimation is not sufficiently accurate. */ @DeveloperApi class ExternalAppendOnlyMap[K, V, C]( @@ -65,12 +57,30 @@ class ExternalAppendOnlyMap[K, V, C]( mergeValue: (C, V) => C, mergeCombiners: (C, C) => C, serializer: Serializer = SparkEnv.get.serializer, - blockManager: BlockManager = SparkEnv.get.blockManager) + blockManager: BlockManager = SparkEnv.get.blockManager, + context: TaskContext = TaskContext.get()) extends Iterable[(K, C)] with Serializable with Logging with Spillable[SizeTracker] { + if (context == null) { + throw new IllegalStateException( + "Spillable collections should not be instantiated outside of tasks") + } + + // Backwards-compatibility constructor for binary compatibility + def this( + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C, + serializer: Serializer, + blockManager: BlockManager) { + this(createCombiner, mergeValue, mergeCombiners, serializer, blockManager, TaskContext.get()) + } + + override protected[this] def taskMemoryManager: TaskMemoryManager = context.taskMemoryManager() + private var currentMap = new SizeTrackingAppendOnlyMap[K, C] private val spilledMaps = new ArrayBuffer[DiskMapIterator] private val sparkConf = SparkEnv.get.conf @@ -89,6 +99,7 @@ class ExternalAppendOnlyMap[K, V, C]( // Number of bytes spilled in total private var _diskBytesSpilled = 0L + def diskBytesSpilled: Long = _diskBytesSpilled // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided private val fileBufferSize = @@ -97,9 +108,19 @@ class ExternalAppendOnlyMap[K, V, C]( // Write metrics for current spill private var curWriteMetrics: ShuffleWriteMetrics = _ + // Peak size of the in-memory map observed so far, in bytes + private var _peakMemoryUsedBytes: Long = 0L + def peakMemoryUsedBytes: Long = _peakMemoryUsedBytes + private val keyComparator = new HashComparator[K] private val ser = serializer.newInstance() + /** + * Number of files this map has spilled so far. + * Exposed for testing. + */ + private[collection] def numSpills: Int = spilledMaps.size + /** * Insert the given key and value into the map. */ @@ -117,6 +138,10 @@ class ExternalAppendOnlyMap[K, V, C]( * The shuffle memory usage of the first trackMemoryThreshold entries is not tracked. 
*/ def insertAll(entries: Iterator[Product2[K, V]]): Unit = { + if (currentMap == null) { + throw new IllegalStateException( + "Cannot insert new elements into a map after calling iterator") + } // An update function for the map that we reuse across entries to avoid allocating // a new closure each time var curEntry: Product2[K, V] = null @@ -126,7 +151,11 @@ class ExternalAppendOnlyMap[K, V, C]( while (entries.hasNext) { curEntry = entries.next() - if (maybeSpill(currentMap, currentMap.estimateSize())) { + val estimatedSize = currentMap.estimateSize() + if (estimatedSize > _peakMemoryUsedBytes) { + _peakMemoryUsedBytes = estimatedSize + } + if (maybeSpill(currentMap, estimatedSize)) { currentMap = new SizeTrackingAppendOnlyMap[K, C] } currentMap.changeValue(curEntry._1, update) @@ -199,7 +228,9 @@ class ExternalAppendOnlyMap[K, V, C]( writer.revertPartialWritesAndClose() } if (file.exists()) { - file.delete() + if (!file.delete()) { + logWarning(s"Error deleting ${file}") + } } } } @@ -207,20 +238,27 @@ class ExternalAppendOnlyMap[K, V, C]( spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes)) } - def diskBytesSpilled: Long = _diskBytesSpilled - /** - * Return an iterator that merges the in-memory map with the spilled maps. + * Return a destructive iterator that merges the in-memory map with the spilled maps. * If no spill has occurred, simply return the in-memory map's iterator. */ override def iterator: Iterator[(K, C)] = { + if (currentMap == null) { + throw new IllegalStateException( + "ExternalAppendOnlyMap.iterator is destructive and should only be called once.") + } if (spilledMaps.isEmpty) { - currentMap.iterator + CompletionIterator[(K, C), Iterator[(K, C)]](currentMap.iterator, freeCurrentMap()) } else { new ExternalIterator() } } + private def freeCurrentMap(): Unit = { + currentMap = null // So that the memory can be garbage-collected + releaseMemory() + } + /** * An iterator that sort-merges (K, C) pairs from the in-memory map and the spilled maps */ @@ -232,7 +270,8 @@ class ExternalAppendOnlyMap[K, V, C]( // Input streams are derived both from the in-memory map and spilled maps on disk // The in-memory map is sorted in place, while the spilled maps are already in sorted order - private val sortedMap = currentMap.destructiveSortedIterator(keyComparator) + private val sortedMap = CompletionIterator[(K, C), Iterator[(K, C)]]( + currentMap.destructiveSortedIterator(keyComparator), freeCurrentMap()) private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered) inputStreams.foreach { it => @@ -482,16 +521,13 @@ class ExternalAppendOnlyMap[K, V, C]( fileStream = null } if (file.exists()) { - file.delete() + if (!file.delete()) { + logWarning(s"Error deleting ${file}") + } } } - val context = TaskContext.get() - // context is null in some tests of ExternalAppendOnlyMapSuite because these tests don't run in - // a TaskContext. - if (context != null) { - context.addTaskCompletionListener(context => cleanup()) - } + context.addTaskCompletionListener(context => cleanup()) } /** Convenience function to hash the given (K, C) pair by the key. 
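
The destructive iterator above wraps the in-memory map in a CompletionIterator so that freeCurrentMap() runs as soon as the data has been fully consumed. A dependency-free sketch of that idea (this is not Spark's CompletionIterator implementation, just the pattern):

```scala
object CompletionIteratorExample {
  // Run a cleanup callback exactly once, as soon as the wrapped iterator is drained.
  def completing[A](sub: Iterator[A])(completion: => Unit): Iterator[A] = new Iterator[A] {
    private var completed = false
    def hasNext: Boolean = {
      val more = sub.hasNext
      if (!more && !completed) { completed = true; completion }
      more
    }
    def next(): A = sub.next()
  }

  def main(args: Array[String]): Unit = {
    var freed = false
    val it = completing(Iterator(1, 2, 3)) { freed = true }  // e.g. freeCurrentMap()
    println(it.sum)   // 6
    println(freed)    // true: cleanup ran when the iterator was exhausted
  }
}
```
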
*/ diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index ba7ec834d622..44b1d90667e6 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -23,13 +23,12 @@ import java.util.Comparator import scala.collection.mutable.ArrayBuffer import scala.collection.mutable -import com.google.common.annotations.VisibleForTesting import com.google.common.io.ByteStreams import org.apache.spark._ +import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.serializer._ import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.shuffle.sort.{SortShuffleFileWriter, SortShuffleWriter} import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter} /** @@ -68,33 +67,35 @@ import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter} * * At a high level, this class works internally as follows: * - * - We repeatedly fill up buffers of in-memory data, using either a PartitionedAppendOnlyMap if - * we want to combine by key, or a PartitionedSerializedPairBuffer or PartitionedPairBuffer if we - * don't. Inside these buffers, we sort elements by partition ID and then possibly also by key. - * To avoid calling the partitioner multiple times with each key, we store the partition ID - * alongside each record. + * - We repeatedly fill up buffers of in-memory data, using either a PartitionedAppendOnlyMap if + * we want to combine by key, or a PartitionedPairBuffer if we don't. + * Inside these buffers, we sort elements by partition ID and then possibly also by key. + * To avoid calling the partitioner multiple times with each key, we store the partition ID + * alongside each record. * - * - When each buffer reaches our memory limit, we spill it to a file. This file is sorted first - * by partition ID and possibly second by key or by hash code of the key, if we want to do - * aggregation. For each file, we track how many objects were in each partition in memory, so we - * don't have to write out the partition ID for every element. + * - When each buffer reaches our memory limit, we spill it to a file. This file is sorted first + * by partition ID and possibly second by key or by hash code of the key, if we want to do + * aggregation. For each file, we track how many objects were in each partition in memory, so we + * don't have to write out the partition ID for every element. * - * - When the user requests an iterator or file output, the spilled files are merged, along with - * any remaining in-memory data, using the same sort order defined above (unless both sorting - * and aggregation are disabled). If we need to aggregate by key, we either use a total ordering - * from the ordering parameter, or read the keys with the same hash code and compare them with - * each other for equality to merge values. + * - When the user requests an iterator or file output, the spilled files are merged, along with + * any remaining in-memory data, using the same sort order defined above (unless both sorting + * and aggregation are disabled). If we need to aggregate by key, we either use a total ordering + * from the ordering parameter, or read the keys with the same hash code and compare them with + * each other for equality to merge values. * - * - Users are expected to call stop() at the end to delete all the intermediate files. 
+ * - Users are expected to call stop() at the end to delete all the intermediate files. */ private[spark] class ExternalSorter[K, V, C]( + context: TaskContext, aggregator: Option[Aggregator[K, V, C]] = None, partitioner: Option[Partitioner] = None, ordering: Option[Ordering[K]] = None, serializer: Option[Serializer] = None) extends Logging - with Spillable[WritablePartitionedPairCollection[K, C]] - with SortShuffleFileWriter[K, V] { + with Spillable[WritablePartitionedPairCollection[K, C]] { + + override protected[this] def taskMemoryManager: TaskMemoryManager = context.taskMemoryManager() private val conf = SparkEnv.get.conf @@ -104,20 +105,11 @@ private[spark] class ExternalSorter[K, V, C]( if (shouldPartition) partitioner.get.getPartition(key) else 0 } - // Since SPARK-7855, bypassMergeSort optimization is no longer performed as part of this class. - // As a sanity check, make sure that we're not handling a shuffle which should use that path. - if (SortShuffleWriter.shouldBypassMergeSort(conf, numPartitions, aggregator, ordering)) { - throw new IllegalArgumentException("ExternalSorter should not be used to handle " - + " a sort that the BypassMergeSortShuffleWriter should handle") - } - private val blockManager = SparkEnv.get.blockManager private val diskBlockManager = blockManager.diskBlockManager private val ser = Serializer.getSerializer(serializer) private val serInstance = ser.newInstance() - private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true) - // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided private val fileBufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 @@ -130,28 +122,19 @@ private[spark] class ExternalSorter[K, V, C]( // grow internal data structures by growing + copying every time the number of objects doubles. private val serializerBatchSize = conf.getLong("spark.shuffle.spill.batchSize", 10000) - private val useSerializedPairBuffer = - ordering.isEmpty && - conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) && - ser.supportsRelocationOfSerializedObjects - private val kvChunkSize = conf.getInt("spark.shuffle.sort.kvChunkSize", 1 << 22) // 4 MB - private def newBuffer(): WritablePartitionedPairCollection[K, C] with SizeTracker = { - if (useSerializedPairBuffer) { - new PartitionedSerializedPairBuffer(metaInitialRecords = 256, kvChunkSize, serInstance) - } else { - new PartitionedPairBuffer[K, C] - } - } // Data structures to store in-memory objects before we spill. Depending on whether we have an // Aggregator set, we either put objects into an AppendOnlyMap where we combine them, or we // store them in an array buffer. private var map = new PartitionedAppendOnlyMap[K, C] - private var buffer = newBuffer() + private var buffer = new PartitionedPairBuffer[K, C] // Total spilling statistics private var _diskBytesSpilled = 0L def diskBytesSpilled: Long = _diskBytesSpilled + // Peak size of the in-memory data structure observed so far, in bytes + private var _peakMemoryUsedBytes: Long = 0L + def peakMemoryUsedBytes: Long = _peakMemoryUsedBytes // A comparator for keys K that orders them within a partition to allow aggregation or sorting. // Can be a partial ordering by hash code if a total ordering is not provided through by the @@ -185,7 +168,13 @@ private[spark] class ExternalSorter[K, V, C]( private val spills = new ArrayBuffer[SpilledFile] - override def insertAll(records: Iterator[Product2[K, V]]): Unit = { + /** + * Number of files this sorter has spilled so far. 
+ * Exposed for testing. + */ + private[spark] def numSpills: Int = spills.size + + def insertAll(records: Iterator[Product2[K, V]]): Unit = { // TODO: stop combining if we find that the reduction factor isn't high val shouldCombine = aggregator.isDefined @@ -220,19 +209,22 @@ private[spark] class ExternalSorter[K, V, C]( * @param usingMap whether we're using a map or buffer as our current in-memory collection */ private def maybeSpillCollection(usingMap: Boolean): Unit = { - if (!spillingEnabled) { - return - } - + var estimatedSize = 0L if (usingMap) { - if (maybeSpill(map, map.estimateSize())) { + estimatedSize = map.estimateSize() + if (maybeSpill(map, estimatedSize)) { map = new PartitionedAppendOnlyMap[K, C] } } else { - if (maybeSpill(buffer, buffer.estimateSize())) { - buffer = newBuffer() + estimatedSize = buffer.estimateSize() + if (maybeSpill(buffer, estimatedSize)) { + buffer = new PartitionedPairBuffer[K, C] } } + + if (estimatedSize > _peakMemoryUsedBytes) { + _peakMemoryUsedBytes = estimatedSize + } } /** @@ -281,6 +273,8 @@ private[spark] class ExternalSorter[K, V, C]( val it = collection.destructiveSortedWritablePartitionedIterator(comparator) while (it.hasNext) { val partitionId = it.nextPartition() + require(partitionId >= 0 && partitionId < numPartitions, + s"partition Id: ${partitionId} should be in the range [0, ${numPartitions})") it.writeNext(writer) elementsPerPartition(partitionId) += 1 objectsWritten += 1 @@ -306,7 +300,9 @@ private[spark] class ExternalSorter[K, V, C]( writer.revertPartialWritesAndClose() } if (file.exists()) { - file.delete() + if (!file.delete()) { + logWarning(s"Error deleting ${file}") + } } } } @@ -611,8 +607,8 @@ private[spark] class ExternalSorter[K, V, C]( * * For now, we just merge all the spilled files in once pass, but this can be modified to * support hierarchical merging. + * Exposed for testing. */ - @VisibleForTesting def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = { val usingMap = aggregator.isDefined val collection: WritablePartitionedPairCollection[K, C] = if (usingMap) map else buffer @@ -642,12 +638,10 @@ private[spark] class ExternalSorter[K, V, C]( * called by the SortShuffleWriter. * * @param blockId block ID to write to. The index file will be blockId.name + ".index". - * @param context a TaskContext for a running Spark task, for us to update shuffle metrics. 
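
The reworked maybeSpillCollection above records the estimated in-memory size on every check, keeps the running peak, and replaces the collection whenever a spill happens. A toy, self-contained version of that bookkeeping (the byte-array "records" and the threshold are invented for illustration; a real spill would write sorted contents to disk):

```scala
object PeakMemoryTrackingExample {
  final class SpillingBuffer(spillThresholdBytes: Long) {
    private var buffer = Vector.empty[Array[Byte]]
    private var estimatedBytes = 0L
    private var _peakBytes = 0L
    var spills = 0

    def peakBytes: Long = _peakBytes

    def insert(record: Array[Byte]): Unit = {
      buffer :+= record
      estimatedBytes += record.length
      // Track the peak before any reset, as the patch does for _peakMemoryUsedBytes.
      if (estimatedBytes > _peakBytes) _peakBytes = estimatedBytes
      if (estimatedBytes >= spillThresholdBytes) {
        spills += 1              // stand-in for spilling to disk
        buffer = Vector.empty
        estimatedBytes = 0L
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val b = new SpillingBuffer(spillThresholdBytes = 1024)
    (1 to 100).foreach(_ => b.insert(new Array[Byte](64)))
    println(s"spills=${b.spills}, peakBytes=${b.peakBytes}")
  }
}
```
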
* @return array of lengths, in bytes, of each partition of the file (used by map output tracker) */ - override def writePartitionedFile( + def writePartitionedFile( blockId: BlockId, - context: TaskContext, outputFile: File): Array[Long] = { // Track location of each range in the output file @@ -684,15 +678,20 @@ private[spark] class ExternalSorter[K, V, C]( } } - context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled) - context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled) + context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled) + context.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(peakMemoryUsedBytes) lengths } def stop(): Unit = { + map = null // So that the memory can be garbage-collected + buffer = null // So that the memory can be garbage-collected spills.foreach(s => s.file.delete()) spills.clear() + releaseMemory() } /** diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala deleted file mode 100644 index 87a786b02d65..000000000000 --- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala +++ /dev/null @@ -1,273 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util.collection - -import java.io.InputStream -import java.nio.IntBuffer -import java.util.Comparator - -import org.apache.spark.serializer.{JavaSerializerInstance, SerializerInstance} -import org.apache.spark.storage.DiskBlockObjectWriter -import org.apache.spark.util.collection.PartitionedSerializedPairBuffer._ - -/** - * Append-only buffer of key-value pairs, each with a corresponding partition ID, that serializes - * its records upon insert and stores them as raw bytes. - * - * We use two data-structures to store the contents. The serialized records are stored in a - * ChainedBuffer that can expand gracefully as records are added. This buffer is accompanied by a - * metadata buffer that stores pointers into the data buffer as well as the partition ID of each - * record. Each entry in the metadata buffer takes up a fixed amount of space. - * - * Sorting the collection means swapping entries in the metadata buffer - the record buffer need not - * be modified at all. Storing the partition IDs in the metadata buffer means that comparisons can - * happen without following any pointers, which should minimize cache misses. - * - * Currently, only sorting by partition is supported. - * - * Each record is laid out inside the the metaBuffer as follows. 
keyStart, a long, is split across - * two integers: - * - * +-------------+------------+------------+-------------+ - * | keyStart | keyValLen | partitionId | - * +-------------+------------+------------+-------------+ - * - * The buffer can support up to `536870911 (2 ^ 29 - 1)` records. - * - * @param metaInitialRecords The initial number of entries in the metadata buffer. - * @param kvBlockSize The size of each byte buffer in the ChainedBuffer used to store the records. - * @param serializerInstance the serializer used for serializing inserted records. - */ -private[spark] class PartitionedSerializedPairBuffer[K, V]( - metaInitialRecords: Int, - kvBlockSize: Int, - serializerInstance: SerializerInstance) - extends WritablePartitionedPairCollection[K, V] with SizeTracker { - - if (serializerInstance.isInstanceOf[JavaSerializerInstance]) { - throw new IllegalArgumentException("PartitionedSerializedPairBuffer does not support" + - " Java-serialized objects.") - } - - require(metaInitialRecords <= MAXIMUM_RECORDS, - s"Can't make capacity bigger than ${MAXIMUM_RECORDS} records") - private var metaBuffer = IntBuffer.allocate(metaInitialRecords * RECORD_SIZE) - - private val kvBuffer: ChainedBuffer = new ChainedBuffer(kvBlockSize) - private val kvOutputStream = new ChainedBufferOutputStream(kvBuffer) - private val kvSerializationStream = serializerInstance.serializeStream(kvOutputStream) - - def insert(partition: Int, key: K, value: V): Unit = { - if (metaBuffer.position == metaBuffer.capacity) { - growMetaBuffer() - } - - val keyStart = kvBuffer.size - kvSerializationStream.writeKey[Any](key) - kvSerializationStream.writeValue[Any](value) - kvSerializationStream.flush() - val keyValLen = (kvBuffer.size - keyStart).toInt - - // keyStart, a long, gets split across two ints - metaBuffer.put(keyStart.toInt) - metaBuffer.put((keyStart >> 32).toInt) - metaBuffer.put(keyValLen) - metaBuffer.put(partition) - } - - /** Double the size of the array because we've reached capacity */ - private def growMetaBuffer(): Unit = { - if (metaBuffer.capacity >= MAXIMUM_META_BUFFER_CAPACITY) { - throw new IllegalStateException(s"Can't insert more than ${MAXIMUM_RECORDS} records") - } - val newCapacity = - if (metaBuffer.capacity * 2 < 0 || metaBuffer.capacity * 2 > MAXIMUM_META_BUFFER_CAPACITY) { - // Overflow - MAXIMUM_META_BUFFER_CAPACITY - } else { - metaBuffer.capacity * 2 - } - val newMetaBuffer = IntBuffer.allocate(newCapacity) - newMetaBuffer.put(metaBuffer.array) - metaBuffer = newMetaBuffer - } - - /** Iterate through the data in a given order. For this class this is not really destructive. 
*/ - override def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]]) - : Iterator[((Int, K), V)] = { - sort(keyComparator) - val is = orderedInputStream - val deserStream = serializerInstance.deserializeStream(is) - new Iterator[((Int, K), V)] { - var metaBufferPos = 0 - def hasNext: Boolean = metaBufferPos < metaBuffer.position - def next(): ((Int, K), V) = { - val key = deserStream.readKey[Any]().asInstanceOf[K] - val value = deserStream.readValue[Any]().asInstanceOf[V] - val partition = metaBuffer.get(metaBufferPos + PARTITION) - metaBufferPos += RECORD_SIZE - ((partition, key), value) - } - } - } - - override def estimateSize: Long = metaBuffer.capacity * 4L + kvBuffer.capacity - - override def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]]) - : WritablePartitionedIterator = { - sort(keyComparator) - new WritablePartitionedIterator { - // current position in the meta buffer in ints - var pos = 0 - - def writeNext(writer: DiskBlockObjectWriter): Unit = { - val keyStart = getKeyStartPos(metaBuffer, pos) - val keyValLen = metaBuffer.get(pos + KEY_VAL_LEN) - pos += RECORD_SIZE - kvBuffer.read(keyStart, writer, keyValLen) - writer.recordWritten() - } - def nextPartition(): Int = metaBuffer.get(pos + PARTITION) - def hasNext(): Boolean = pos < metaBuffer.position - } - } - - // Visible for testing - def orderedInputStream: OrderedInputStream = { - new OrderedInputStream(metaBuffer, kvBuffer) - } - - private def sort(keyComparator: Option[Comparator[K]]): Unit = { - val comparator = if (keyComparator.isEmpty) { - new Comparator[Int]() { - def compare(partition1: Int, partition2: Int): Int = { - partition1 - partition2 - } - } - } else { - throw new UnsupportedOperationException() - } - - val sorter = new Sorter(new SerializedSortDataFormat) - sorter.sort(metaBuffer, 0, metaBuffer.position / RECORD_SIZE, comparator) - } -} - -private[spark] class OrderedInputStream(metaBuffer: IntBuffer, kvBuffer: ChainedBuffer) - extends InputStream { - - import PartitionedSerializedPairBuffer._ - - private var metaBufferPos = 0 - private var kvBufferPos = - if (metaBuffer.position > 0) getKeyStartPos(metaBuffer, metaBufferPos) else 0 - - override def read(bytes: Array[Byte]): Int = read(bytes, 0, bytes.length) - - override def read(bytes: Array[Byte], offs: Int, len: Int): Int = { - if (metaBufferPos >= metaBuffer.position) { - return -1 - } - val bytesRemainingInRecord = (metaBuffer.get(metaBufferPos + KEY_VAL_LEN) - - (kvBufferPos - getKeyStartPos(metaBuffer, metaBufferPos))).toInt - val toRead = math.min(bytesRemainingInRecord, len) - kvBuffer.read(kvBufferPos, bytes, offs, toRead) - if (toRead == bytesRemainingInRecord) { - metaBufferPos += RECORD_SIZE - if (metaBufferPos < metaBuffer.position) { - kvBufferPos = getKeyStartPos(metaBuffer, metaBufferPos) - } - } else { - kvBufferPos += toRead - } - toRead - } - - override def read(): Int = { - throw new UnsupportedOperationException() - } -} - -private[spark] class SerializedSortDataFormat extends SortDataFormat[Int, IntBuffer] { - - private val META_BUFFER_TMP = new Array[Int](RECORD_SIZE) - - /** Return the sort key for the element at the given index. */ - override protected def getKey(metaBuffer: IntBuffer, pos: Int): Int = { - metaBuffer.get(pos * RECORD_SIZE + PARTITION) - } - - /** Swap two elements. 
*/ - override def swap(metaBuffer: IntBuffer, pos0: Int, pos1: Int): Unit = { - val iOff = pos0 * RECORD_SIZE - val jOff = pos1 * RECORD_SIZE - System.arraycopy(metaBuffer.array, iOff, META_BUFFER_TMP, 0, RECORD_SIZE) - System.arraycopy(metaBuffer.array, jOff, metaBuffer.array, iOff, RECORD_SIZE) - System.arraycopy(META_BUFFER_TMP, 0, metaBuffer.array, jOff, RECORD_SIZE) - } - - /** Copy a single element from src(srcPos) to dst(dstPos). */ - override def copyElement( - src: IntBuffer, - srcPos: Int, - dst: IntBuffer, - dstPos: Int): Unit = { - val srcOff = srcPos * RECORD_SIZE - val dstOff = dstPos * RECORD_SIZE - System.arraycopy(src.array, srcOff, dst.array, dstOff, RECORD_SIZE) - } - - /** - * Copy a range of elements starting at src(srcPos) to dst, starting at dstPos. - * Overlapping ranges are allowed. - */ - override def copyRange( - src: IntBuffer, - srcPos: Int, - dst: IntBuffer, - dstPos: Int, - length: Int): Unit = { - val srcOff = srcPos * RECORD_SIZE - val dstOff = dstPos * RECORD_SIZE - System.arraycopy(src.array, srcOff, dst.array, dstOff, RECORD_SIZE * length) - } - - /** - * Allocates a Buffer that can hold up to 'length' elements. - * All elements of the buffer should be considered invalid until data is explicitly copied in. - */ - override def allocate(length: Int): IntBuffer = { - IntBuffer.allocate(length * RECORD_SIZE) - } -} - -private object PartitionedSerializedPairBuffer { - val KEY_START = 0 // keyStart, a long, gets split across two ints - val KEY_VAL_LEN = 2 - val PARTITION = 3 - val RECORD_SIZE = PARTITION + 1 // num ints of metadata - - val MAXIMUM_RECORDS = Int.MaxValue / RECORD_SIZE // (2 ^ 29) - 1 - val MAXIMUM_META_BUFFER_CAPACITY = MAXIMUM_RECORDS * RECORD_SIZE // (2 ^ 31) - 4 - - def getKeyStartPos(metaBuffer: IntBuffer, metaBufferPos: Int): Long = { - val lower32 = metaBuffer.get(metaBufferPos + KEY_START) - val upper32 = metaBuffer.get(metaBufferPos + KEY_START + 1) - (upper32.toLong << 32) | (lower32 & 0xFFFFFFFFL) - } -} diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala index 747ecf075a39..3a48af82b1da 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala @@ -17,8 +17,8 @@ package org.apache.spark.util.collection -import org.apache.spark.Logging -import org.apache.spark.SparkEnv +import org.apache.spark.memory.{MemoryMode, TaskMemoryManager} +import org.apache.spark.{Logging, SparkEnv} /** * Spills contents of an in-memory collection to disk when the memory threshold @@ -40,13 +40,18 @@ private[spark] trait Spillable[C] extends Logging { protected def addElementsRead(): Unit = { _elementsRead += 1 } // Memory manager that can be used to acquire/release memory - private[this] val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager + protected[this] def taskMemoryManager: TaskMemoryManager // Initial threshold for the size of a collection before we start tracking its memory usage - // Exposed for testing + // For testing only private[this] val initialMemoryThreshold: Long = SparkEnv.get.conf.getLong("spark.shuffle.spill.initialMemoryThreshold", 5 * 1024 * 1024) + // Force this collection to spill when there are this many elements in memory + // For testing only + private[this] val numElementsForceSpillThreshold: Long = + SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MaxValue) + // Threshold for this 
collection's size in bytes before we start tracking its memory usage // To avoid a large number of small spills, initialize this to a value orders of magnitude > 0 private[this] var myMemoryThreshold = initialMemoryThreshold @@ -69,27 +74,28 @@ private[spark] trait Spillable[C] extends Logging { * @return true if `collection` was spilled to disk; false otherwise */ protected def maybeSpill(collection: C, currentMemory: Long): Boolean = { + var shouldSpill = false if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) { // Claim up to double our current memory from the shuffle memory pool val amountToRequest = 2 * currentMemory - myMemoryThreshold - val granted = shuffleMemoryManager.tryToAcquire(amountToRequest) + val granted = + taskMemoryManager.acquireExecutionMemory(amountToRequest, MemoryMode.ON_HEAP, null) myMemoryThreshold += granted - if (myMemoryThreshold <= currentMemory) { - // We were granted too little memory to grow further (either tryToAcquire returned 0, - // or we already had more memory than myMemoryThreshold); spill the current collection - _spillCount += 1 - logSpillage(currentMemory) - - spill(collection) - - _elementsRead = 0 - // Keep track of spills, and release memory - _memoryBytesSpilled += currentMemory - releaseMemoryForThisThread() - return true - } + // If we were granted too little memory to grow further (either tryToAcquire returned 0, + // or we already had more memory than myMemoryThreshold), spill the current collection + shouldSpill = currentMemory >= myMemoryThreshold + } + shouldSpill = shouldSpill || _elementsRead > numElementsForceSpillThreshold + // Actually spill + if (shouldSpill) { + _spillCount += 1 + logSpillage(currentMemory) + spill(collection) + _elementsRead = 0 + _memoryBytesSpilled += currentMemory + releaseMemory() } - false + shouldSpill } /** @@ -98,11 +104,12 @@ private[spark] trait Spillable[C] extends Logging { def memoryBytesSpilled: Long = _memoryBytesSpilled /** - * Release our memory back to the shuffle pool so that other threads can grab it. + * Release our memory back to the execution pool so that other tasks can grab it. 
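
The reworked Spillable.maybeSpill above asks the memory manager for up to double the collection's current size every 32 elements, spills when it is granted too little or when the force-spill element count is exceeded, and then releases everything above the initial threshold. A self-contained sketch of that policy against a fake memory pool (the pool class and the demo numbers are invented; the real code goes through TaskMemoryManager):

```scala
object MaybeSpillSketch {
  // Invented stand-in for the execution memory pool.
  final class FakePool(var available: Long) {
    def acquire(bytes: Long): Long = { val g = math.min(bytes, available); available -= g; g }
    def release(bytes: Long): Unit = { available += bytes }
  }

  final class SpillPolicy(pool: FakePool, initialThreshold: Long, forceSpillElements: Long) {
    private var myMemoryThreshold = initialThreshold
    private var elementsRead = 0L
    var spillCount = 0

    /** Call once per inserted element with the collection's current estimated size. */
    def maybeSpill(currentMemory: Long): Boolean = {
      elementsRead += 1
      var shouldSpill = false
      if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
        // Claim up to double our current memory from the pool.
        val granted = pool.acquire(2 * currentMemory - myMemoryThreshold)
        myMemoryThreshold += granted
        // Spill if we were granted too little memory to keep growing.
        shouldSpill = currentMemory >= myMemoryThreshold
      }
      shouldSpill = shouldSpill || elementsRead > forceSpillElements
      if (shouldSpill) {
        spillCount += 1
        elementsRead = 0
        // Keep only the initial tracking threshold; give the rest back.
        pool.release(myMemoryThreshold - initialThreshold)
        myMemoryThreshold = initialThreshold
      }
      shouldSpill
    }
  }

  def main(args: Array[String]): Unit = {
    val policy = new SpillPolicy(new FakePool(available = 64),
      initialThreshold = 32, forceSpillElements = Long.MaxValue)
    var size = 0L
    (1 to 200).foreach { _ =>
      size += 2
      if (policy.maybeSpill(size)) size = 0
    }
    println(s"spills=${policy.spillCount}")
  }
}
```
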
*/ - private def releaseMemoryForThisThread(): Unit = { + def releaseMemory(): Unit = { // The amount we requested does not include the initial memory tracking threshold - shuffleMemoryManager.release(myMemoryThreshold - initialMemoryThreshold) + taskMemoryManager.releaseExecutionMemory( + myMemoryThreshold - initialMemoryThreshold, MemoryMode.ON_HEAP, null) myMemoryThreshold = initialMemoryThreshold } diff --git a/core/src/main/scala/org/apache/spark/util/collection/Utils.scala b/core/src/main/scala/org/apache/spark/util/collection/Utils.scala index bdbca00a0062..4939b600dbfb 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Utils.scala @@ -17,7 +17,7 @@ package org.apache.spark.util.collection -import scala.collection.JavaConversions.{collectionAsScalaIterable, asJavaIterator} +import scala.collection.JavaConverters._ import com.google.common.collect.{Ordering => GuavaOrdering} @@ -34,6 +34,6 @@ private[spark] object Utils { val ordering = new GuavaOrdering[T] { override def compare(l: T, r: T): Int = ord.compare(l, r) } - collectionAsScalaIterable(ordering.leastOf(asJavaIterator(input), num)).iterator + ordering.leastOf(input.asJava, num).iterator.asScala } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala index 38848e9018c6..5232c2bd8d6f 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala @@ -23,9 +23,10 @@ import org.apache.spark.storage.DiskBlockObjectWriter /** * A common interface for size-tracking collections of key-value pairs that - * - Have an associated partition for each key-value pair. - * - Support a memory-efficient sorted iterator - * - Support a WritablePartitionedIterator for writing the contents directly as bytes. + * + * - Have an associated partition for each key-value pair. + * - Support a memory-efficient sorted iterator + * - Support a WritablePartitionedIterator for writing the contents directly as bytes. */ private[spark] trait WritablePartitionedPairCollection[K, V] { /** diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala index 7138b4b8e453..1e8476c4a047 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala @@ -79,32 +79,30 @@ private[spark] class RollingFileAppender( val rolloverSuffix = rollingPolicy.generateRolledOverFileSuffix() val rolloverFile = new File( activeFile.getParentFile, activeFile.getName + rolloverSuffix).getAbsoluteFile - try { - logDebug(s"Attempting to rollover file $activeFile to file $rolloverFile") - if (activeFile.exists) { - if (!rolloverFile.exists) { - Files.move(activeFile, rolloverFile) - logInfo(s"Rolled over $activeFile to $rolloverFile") - } else { - // In case the rollover file name clashes, make a unique file name. - // The resultant file names are long and ugly, so this is used only - // if there is a name collision. This can be avoided by the using - // the right pattern such that name collisions do not occur. 
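
The takeOrdered helper in collection.Utils, switched to JavaConverters above, relies on Guava's Ordering.leastOf to keep only the `num` smallest elements rather than fully sorting the input. A standalone version of the same approach (assumes Guava is on the classpath, as it is for Spark core):

```scala
import scala.collection.JavaConverters._

import com.google.common.collect.{Ordering => GuavaOrdering}

object TakeOrderedExample {
  // leastOf runs in roughly O(n log num) and returns results in ascending order.
  def takeOrdered[T](input: Iterator[T], num: Int)(implicit ord: Ordering[T]): Iterator[T] = {
    val ordering = new GuavaOrdering[T] {
      override def compare(l: T, r: T): Int = ord.compare(l, r)
    }
    ordering.leastOf(input.asJava, num).iterator.asScala
  }

  def main(args: Array[String]): Unit = {
    println(takeOrdered(Iterator(9, 3, 7, 1, 8, 2), 3).toList)  // List(1, 2, 3)
  }
}
```
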
- var i = 0 - var altRolloverFile: File = null - do { - altRolloverFile = new File(activeFile.getParent, - s"${activeFile.getName}$rolloverSuffix--$i").getAbsoluteFile - i += 1 - } while (i < 10000 && altRolloverFile.exists) - - logWarning(s"Rollover file $rolloverFile already exists, " + - s"rolled over $activeFile to file $altRolloverFile") - Files.move(activeFile, altRolloverFile) - } + logDebug(s"Attempting to rollover file $activeFile to file $rolloverFile") + if (activeFile.exists) { + if (!rolloverFile.exists) { + Files.move(activeFile, rolloverFile) + logInfo(s"Rolled over $activeFile to $rolloverFile") } else { - logWarning(s"File $activeFile does not exist") + // In case the rollover file name clashes, make a unique file name. + // The resultant file names are long and ugly, so this is used only + // if there is a name collision. This can be avoided by the using + // the right pattern such that name collisions do not occur. + var i = 0 + var altRolloverFile: File = null + do { + altRolloverFile = new File(activeFile.getParent, + s"${activeFile.getName}$rolloverSuffix--$i").getAbsoluteFile + i += 1 + } while (i < 10000 && altRolloverFile.exists) + + logWarning(s"Rollover file $rolloverFile already exists, " + + s"rolled over $activeFile to file $altRolloverFile") + Files.move(activeFile, altRolloverFile) } + } else { + logWarning(s"File $activeFile does not exist") } } diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala index 786b97ad7b9e..c156b03cdb7c 100644 --- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala +++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala @@ -176,10 +176,15 @@ class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T * A sampler for sampling with replacement, based on values drawn from Poisson distribution. * * @param fraction the sampling fraction (with replacement) + * @param useGapSamplingIfPossible if true, use gap sampling when sampling ratio is low. * @tparam T item type */ @DeveloperApi -class PoissonSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T] { +class PoissonSampler[T: ClassTag]( + fraction: Double, + useGapSamplingIfPossible: Boolean) extends RandomSampler[T, T] { + + def this(fraction: Double) = this(fraction, useGapSamplingIfPossible = true) /** Epsilon slop to avoid failure from floating point jitter. 
*/ require( @@ -199,17 +204,18 @@ class PoissonSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T] override def sample(items: Iterator[T]): Iterator[T] = { if (fraction <= 0.0) { Iterator.empty - } else if (fraction <= RandomSampler.defaultMaxGapSamplingFraction) { - new GapSamplingReplacementIterator(items, fraction, rngGap, RandomSampler.rngEpsilon) + } else if (useGapSamplingIfPossible && + fraction <= RandomSampler.defaultMaxGapSamplingFraction) { + new GapSamplingReplacementIterator(items, fraction, rngGap, RandomSampler.rngEpsilon) } else { - items.flatMap { item => { + items.flatMap { item => val count = rng.sample() if (count == 0) Iterator.empty else Iterator.fill(count)(item) - }} + } } } - override def clone: PoissonSampler[T] = new PoissonSampler[T](fraction) + override def clone: PoissonSampler[T] = new PoissonSampler[T](fraction, useGapSamplingIfPossible) } diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala index c9a864ae6277..f98932a47016 100644 --- a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala @@ -34,7 +34,7 @@ private[spark] object SamplingUtils { input: Iterator[T], k: Int, seed: Long = Random.nextLong()) - : (Array[T], Int) = { + : (Array[T], Long) = { val reservoir = new Array[T](k) // Put the first k elements in the reservoir. var i = 0 @@ -52,16 +52,17 @@ private[spark] object SamplingUtils { (trimReservoir, i) } else { // If input size > k, continue the sampling process. + var l = i.toLong val rand = new XORShiftRandom(seed) while (input.hasNext) { val item = input.next() - val replacementIndex = rand.nextInt(i) + val replacementIndex = (rand.nextDouble() * l).toLong if (replacementIndex < k) { - reservoir(replacementIndex) = item + reservoir(replacementIndex.toInt) = item } - i += 1 + l += 1 } - (reservoir, i) + (reservoir, l) } } diff --git a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala index 85fb923cd9bc..e8cdb6e98bf3 100644 --- a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala +++ b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala @@ -60,9 +60,11 @@ private[spark] class XORShiftRandom(init: Long) extends JavaRandom(init) { private[spark] object XORShiftRandom { /** Hash seeds to have 0/1 bits throughout. 
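The SamplingUtils change above widens the reservoir-sampling counter from Int to Long and draws the replacement index as (rand.nextDouble() * l).toLong instead of rand.nextInt(i), so streams longer than Int.MaxValue keep a uniform selection probability. A self-contained sketch of that loop; ReservoirSketch and the use of scala.util.Random (instead of Spark's XORShiftRandom) are illustrative choices:

    import scala.reflect.ClassTag
    import scala.util.Random

    object ReservoirSketch {
      // Reservoir sampling with a Long element counter, mirroring the updated SamplingUtils:
      // returns (up to) k sampled items plus the total number of items seen.
      def reservoirSample[T: ClassTag](input: Iterator[T], k: Int, seed: Long): (Array[T], Long) = {
        val reservoir = new Array[T](k)
        var i = 0
        while (i < k && input.hasNext) {   // fill the reservoir with the first k items
          reservoir(i) = input.next()
          i += 1
        }
        if (i < k) {
          (reservoir.take(i), i.toLong)    // fewer than k items in total
        } else {
          var l = i.toLong
          val rand = new Random(seed)      // Spark uses its own XORShiftRandom here
          while (input.hasNext) {
            val item = input.next()
            // Draw the replacement index over [0, l) as a Long so counts past Int.MaxValue stay uniform.
            val replacementIndex = (rand.nextDouble() * l).toLong
            if (replacementIndex < k) {
              reservoir(replacementIndex.toInt) = item
            }
            l += 1
          }
          (reservoir, l)
        }
      }
    }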
*/ - private def hashSeed(seed: Long): Long = { + private[random] def hashSeed(seed: Long): Long = { val bytes = ByteBuffer.allocate(java.lang.Long.SIZE).putLong(seed).array() - MurmurHash3.bytesHash(bytes) + val lowBits = MurmurHash3.bytesHash(bytes) + val highBits = MurmurHash3.bytesHash(bytes, lowBits) + (highBits.toLong << 32) | (lowBits.toLong & 0xFFFFFFFFL) } /** diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index e948ca33471a..11f1248c24d3 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -24,10 +24,10 @@ import java.util.*; import java.util.concurrent.*; -import scala.collection.JavaConversions; import scala.Tuple2; import scala.Tuple3; import scala.Tuple4; +import scala.collection.JavaConverters; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; @@ -51,7 +51,6 @@ import org.apache.spark.api.java.*; import org.apache.spark.api.java.function.*; -import org.apache.spark.executor.TaskMetrics; import org.apache.spark.input.PortableDataStream; import org.apache.spark.partial.BoundedDouble; import org.apache.spark.partial.PartialResult; @@ -91,7 +90,7 @@ public void sparkContextUnion() { JavaRDD sUnion = sc.union(s1, s2); Assert.assertEquals(4, sUnion.count()); // List - List> list = new ArrayList>(); + List> list = new ArrayList<>(); list.add(s2); sUnion = sc.union(s1, list); Assert.assertEquals(4, sUnion.count()); @@ -104,9 +103,9 @@ public void sparkContextUnion() { Assert.assertEquals(4, dUnion.count()); // Union of JavaPairRDDs - List> pairs = new ArrayList>(); - pairs.add(new Tuple2(1, 2)); - pairs.add(new Tuple2(3, 4)); + List> pairs = new ArrayList<>(); + pairs.add(new Tuple2<>(1, 2)); + pairs.add(new Tuple2<>(3, 4)); JavaPairRDD p1 = sc.parallelizePairs(pairs); JavaPairRDD p2 = sc.parallelizePairs(pairs); JavaPairRDD pUnion = sc.union(p1, p2); @@ -134,9 +133,9 @@ public void intersection() { JavaDoubleRDD dIntersection = d1.intersection(d2); Assert.assertEquals(2, dIntersection.count()); - List> pairs = new ArrayList>(); - pairs.add(new Tuple2(1, 2)); - pairs.add(new Tuple2(3, 4)); + List> pairs = new ArrayList<>(); + pairs.add(new Tuple2<>(1, 2)); + pairs.add(new Tuple2<>(3, 4)); JavaPairRDD p1 = sc.parallelizePairs(pairs); JavaPairRDD p2 = sc.parallelizePairs(pairs); JavaPairRDD pIntersection = p1.intersection(p2); @@ -147,66 +146,76 @@ public void intersection() { public void sample() { List ints = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); JavaRDD rdd = sc.parallelize(ints); - JavaRDD sample20 = rdd.sample(true, 0.2, 3); + // the seeds here are "magic" to make this work out nicely + JavaRDD sample20 = rdd.sample(true, 0.2, 8); Assert.assertEquals(2, sample20.count()); - JavaRDD sample20WithoutReplacement = rdd.sample(false, 0.2, 5); + JavaRDD sample20WithoutReplacement = rdd.sample(false, 0.2, 2); Assert.assertEquals(2, sample20WithoutReplacement.count()); } @Test public void randomSplit() { - List ints = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); + List ints = new ArrayList<>(1000); + for (int i = 0; i < 1000; i++) { + ints.add(i); + } JavaRDD rdd = sc.parallelize(ints); JavaRDD[] splits = rdd.randomSplit(new double[] { 0.4, 0.6, 1.0 }, 31); + // the splits aren't perfect -- not enough data for them to be -- just check they're about right Assert.assertEquals(3, splits.length); - Assert.assertEquals(1, splits[0].count()); - Assert.assertEquals(2, splits[1].count()); - 
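The XORShiftRandom change above upgrades the hashed seed from 32 to 64 bits by running MurmurHash3 over the seed bytes twice, seeding the second pass with the first result, and packing the two ints into one Long. A self-contained sketch of that construction (SeedHashSketch and hashSeed64 are illustrative names):

    import java.nio.ByteBuffer
    import scala.util.hashing.MurmurHash3

    object SeedHashSketch {
      // Two 32-bit MurmurHash3 passes over the seed bytes, the second seeded with the first,
      // packed into one 64-bit value, as in XORShiftRandom.hashSeed.
      def hashSeed64(seed: Long): Long = {
        val bytes = ByteBuffer.allocate(java.lang.Long.SIZE).putLong(seed).array()
        val lowBits = MurmurHash3.bytesHash(bytes)
        val highBits = MurmurHash3.bytesHash(bytes, lowBits)
        (highBits.toLong << 32) | (lowBits.toLong & 0xFFFFFFFFL)
      }
    }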
Assert.assertEquals(7, splits[2].count()); + long s0 = splits[0].count(); + long s1 = splits[1].count(); + long s2 = splits[2].count(); + Assert.assertTrue(s0 + " not within expected range", s0 > 150 && s0 < 250); + Assert.assertTrue(s1 + " not within expected range", s1 > 250 && s0 < 350); + Assert.assertTrue(s2 + " not within expected range", s2 > 430 && s2 < 570); } @Test public void sortByKey() { - List> pairs = new ArrayList>(); - pairs.add(new Tuple2(0, 4)); - pairs.add(new Tuple2(3, 2)); - pairs.add(new Tuple2(-1, 1)); + List> pairs = new ArrayList<>(); + pairs.add(new Tuple2<>(0, 4)); + pairs.add(new Tuple2<>(3, 2)); + pairs.add(new Tuple2<>(-1, 1)); JavaPairRDD rdd = sc.parallelizePairs(pairs); // Default comparator JavaPairRDD sortedRDD = rdd.sortByKey(); - Assert.assertEquals(new Tuple2(-1, 1), sortedRDD.first()); + Assert.assertEquals(new Tuple2<>(-1, 1), sortedRDD.first()); List> sortedPairs = sortedRDD.collect(); - Assert.assertEquals(new Tuple2(0, 4), sortedPairs.get(1)); - Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(2)); + Assert.assertEquals(new Tuple2<>(0, 4), sortedPairs.get(1)); + Assert.assertEquals(new Tuple2<>(3, 2), sortedPairs.get(2)); // Custom comparator sortedRDD = rdd.sortByKey(Collections.reverseOrder(), false); - Assert.assertEquals(new Tuple2(-1, 1), sortedRDD.first()); + Assert.assertEquals(new Tuple2<>(-1, 1), sortedRDD.first()); sortedPairs = sortedRDD.collect(); - Assert.assertEquals(new Tuple2(0, 4), sortedPairs.get(1)); - Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(2)); + Assert.assertEquals(new Tuple2<>(0, 4), sortedPairs.get(1)); + Assert.assertEquals(new Tuple2<>(3, 2), sortedPairs.get(2)); } @SuppressWarnings("unchecked") @Test public void repartitionAndSortWithinPartitions() { - List> pairs = new ArrayList>(); - pairs.add(new Tuple2(0, 5)); - pairs.add(new Tuple2(3, 8)); - pairs.add(new Tuple2(2, 6)); - pairs.add(new Tuple2(0, 8)); - pairs.add(new Tuple2(3, 8)); - pairs.add(new Tuple2(1, 3)); + List> pairs = new ArrayList<>(); + pairs.add(new Tuple2<>(0, 5)); + pairs.add(new Tuple2<>(3, 8)); + pairs.add(new Tuple2<>(2, 6)); + pairs.add(new Tuple2<>(0, 8)); + pairs.add(new Tuple2<>(3, 8)); + pairs.add(new Tuple2<>(1, 3)); JavaPairRDD rdd = sc.parallelizePairs(pairs); Partitioner partitioner = new Partitioner() { + @Override public int numPartitions() { return 2; } + @Override public int getPartition(Object key) { - return ((Integer)key).intValue() % 2; + return (Integer) key % 2; } }; @@ -215,10 +224,10 @@ public int getPartition(Object key) { Assert.assertTrue(repartitioned.partitioner().isPresent()); Assert.assertEquals(repartitioned.partitioner().get(), partitioner); List>> partitions = repartitioned.glom().collect(); - Assert.assertEquals(partitions.get(0), Arrays.asList(new Tuple2(0, 5), - new Tuple2(0, 8), new Tuple2(2, 6))); - Assert.assertEquals(partitions.get(1), Arrays.asList(new Tuple2(1, 3), - new Tuple2(3, 8), new Tuple2(3, 8))); + Assert.assertEquals(partitions.get(0), + Arrays.asList(new Tuple2<>(0, 5), new Tuple2<>(0, 8), new Tuple2<>(2, 6))); + Assert.assertEquals(partitions.get(1), + Arrays.asList(new Tuple2<>(1, 3), new Tuple2<>(3, 8), new Tuple2<>(3, 8))); } @Test @@ -229,35 +238,37 @@ public void emptyRDD() { @Test public void sortBy() { - List> pairs = new ArrayList>(); - pairs.add(new Tuple2(0, 4)); - pairs.add(new Tuple2(3, 2)); - pairs.add(new Tuple2(-1, 1)); + List> pairs = new ArrayList<>(); + pairs.add(new Tuple2<>(0, 4)); + pairs.add(new Tuple2<>(3, 2)); + pairs.add(new Tuple2<>(-1, 1)); JavaRDD> rdd 
= sc.parallelize(pairs); // compare on first value JavaRDD> sortedRDD = rdd.sortBy(new Function, Integer>() { - public Integer call(Tuple2 t) throws Exception { + @Override + public Integer call(Tuple2 t) { return t._1(); } }, true, 2); - Assert.assertEquals(new Tuple2(-1, 1), sortedRDD.first()); + Assert.assertEquals(new Tuple2<>(-1, 1), sortedRDD.first()); List> sortedPairs = sortedRDD.collect(); - Assert.assertEquals(new Tuple2(0, 4), sortedPairs.get(1)); - Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(2)); + Assert.assertEquals(new Tuple2<>(0, 4), sortedPairs.get(1)); + Assert.assertEquals(new Tuple2<>(3, 2), sortedPairs.get(2)); // compare on second value sortedRDD = rdd.sortBy(new Function, Integer>() { - public Integer call(Tuple2 t) throws Exception { + @Override + public Integer call(Tuple2 t) { return t._2(); } }, true, 2); - Assert.assertEquals(new Tuple2(-1, 1), sortedRDD.first()); + Assert.assertEquals(new Tuple2<>(-1, 1), sortedRDD.first()); sortedPairs = sortedRDD.collect(); - Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(1)); - Assert.assertEquals(new Tuple2(0, 4), sortedPairs.get(2)); + Assert.assertEquals(new Tuple2<>(3, 2), sortedPairs.get(1)); + Assert.assertEquals(new Tuple2<>(0, 4), sortedPairs.get(2)); } @Test @@ -266,7 +277,7 @@ public void foreach() { JavaRDD rdd = sc.parallelize(Arrays.asList("Hello", "World")); rdd.foreach(new VoidFunction() { @Override - public void call(String s) throws IOException { + public void call(String s) { accum.add(1); } }); @@ -279,7 +290,7 @@ public void foreachPartition() { JavaRDD rdd = sc.parallelize(Arrays.asList("Hello", "World")); rdd.foreachPartition(new VoidFunction>() { @Override - public void call(Iterator iter) throws IOException { + public void call(Iterator iter) { while (iter.hasNext()) { iter.next(); accum.add(1); @@ -302,7 +313,7 @@ public void zipWithUniqueId() { List dataArray = Arrays.asList(1, 2, 3, 4); JavaPairRDD zip = sc.parallelize(dataArray).zipWithUniqueId(); JavaRDD indexes = zip.values(); - Assert.assertEquals(4, new HashSet(indexes.collect()).size()); + Assert.assertEquals(4, new HashSet<>(indexes.collect()).size()); } @Test @@ -318,10 +329,10 @@ public void zipWithIndex() { @Test public void lookup() { JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( - new Tuple2("Apples", "Fruit"), - new Tuple2("Oranges", "Fruit"), - new Tuple2("Oranges", "Citrus") - )); + new Tuple2<>("Apples", "Fruit"), + new Tuple2<>("Oranges", "Fruit"), + new Tuple2<>("Oranges", "Citrus") + )); Assert.assertEquals(2, categories.lookup("Oranges").size()); Assert.assertEquals(2, Iterables.size(categories.groupByKey().lookup("Oranges").get(0))); } @@ -391,18 +402,17 @@ public String call(Tuple2 x) { @Test public void cogroup() { JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( - new Tuple2("Apples", "Fruit"), - new Tuple2("Oranges", "Fruit"), - new Tuple2("Oranges", "Citrus") + new Tuple2<>("Apples", "Fruit"), + new Tuple2<>("Oranges", "Fruit"), + new Tuple2<>("Oranges", "Citrus") )); JavaPairRDD prices = sc.parallelizePairs(Arrays.asList( - new Tuple2("Oranges", 2), - new Tuple2("Apples", 3) + new Tuple2<>("Oranges", 2), + new Tuple2<>("Apples", 3) )); JavaPairRDD, Iterable>> cogrouped = categories.cogroup(prices); - Assert.assertEquals("[Fruit, Citrus]", - Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); + Assert.assertEquals("[Fruit, Citrus]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); Assert.assertEquals("[2]", 
Iterables.toString(cogrouped.lookup("Oranges").get(0)._2())); cogrouped.collect(); @@ -412,23 +422,22 @@ public void cogroup() { @Test public void cogroup3() { JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( - new Tuple2("Apples", "Fruit"), - new Tuple2("Oranges", "Fruit"), - new Tuple2("Oranges", "Citrus") + new Tuple2<>("Apples", "Fruit"), + new Tuple2<>("Oranges", "Fruit"), + new Tuple2<>("Oranges", "Citrus") )); JavaPairRDD prices = sc.parallelizePairs(Arrays.asList( - new Tuple2("Oranges", 2), - new Tuple2("Apples", 3) + new Tuple2<>("Oranges", 2), + new Tuple2<>("Apples", 3) )); JavaPairRDD quantities = sc.parallelizePairs(Arrays.asList( - new Tuple2("Oranges", 21), - new Tuple2("Apples", 42) + new Tuple2<>("Oranges", 21), + new Tuple2<>("Apples", 42) )); JavaPairRDD, Iterable, Iterable>> cogrouped = categories.cogroup(prices, quantities); - Assert.assertEquals("[Fruit, Citrus]", - Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); + Assert.assertEquals("[Fruit, Citrus]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2())); Assert.assertEquals("[42]", Iterables.toString(cogrouped.lookup("Apples").get(0)._3())); @@ -440,27 +449,26 @@ public void cogroup3() { @Test public void cogroup4() { JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( - new Tuple2("Apples", "Fruit"), - new Tuple2("Oranges", "Fruit"), - new Tuple2("Oranges", "Citrus") + new Tuple2<>("Apples", "Fruit"), + new Tuple2<>("Oranges", "Fruit"), + new Tuple2<>("Oranges", "Citrus") )); JavaPairRDD prices = sc.parallelizePairs(Arrays.asList( - new Tuple2("Oranges", 2), - new Tuple2("Apples", 3) + new Tuple2<>("Oranges", 2), + new Tuple2<>("Apples", 3) )); JavaPairRDD quantities = sc.parallelizePairs(Arrays.asList( - new Tuple2("Oranges", 21), - new Tuple2("Apples", 42) + new Tuple2<>("Oranges", 21), + new Tuple2<>("Apples", 42) )); JavaPairRDD countries = sc.parallelizePairs(Arrays.asList( - new Tuple2("Oranges", "BR"), - new Tuple2("Apples", "US") + new Tuple2<>("Oranges", "BR"), + new Tuple2<>("Apples", "US") )); JavaPairRDD, Iterable, Iterable, Iterable>> cogrouped = categories.cogroup(prices, quantities, countries); - Assert.assertEquals("[Fruit, Citrus]", - Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); + Assert.assertEquals("[Fruit, Citrus]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2())); Assert.assertEquals("[42]", Iterables.toString(cogrouped.lookup("Apples").get(0)._3())); Assert.assertEquals("[BR]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._4())); @@ -472,16 +480,16 @@ public void cogroup4() { @Test public void leftOuterJoin() { JavaPairRDD rdd1 = sc.parallelizePairs(Arrays.asList( - new Tuple2(1, 1), - new Tuple2(1, 2), - new Tuple2(2, 1), - new Tuple2(3, 1) + new Tuple2<>(1, 1), + new Tuple2<>(1, 2), + new Tuple2<>(2, 1), + new Tuple2<>(3, 1) )); JavaPairRDD rdd2 = sc.parallelizePairs(Arrays.asList( - new Tuple2(1, 'x'), - new Tuple2(2, 'y'), - new Tuple2(2, 'z'), - new Tuple2(4, 'w') + new Tuple2<>(1, 'x'), + new Tuple2<>(2, 'y'), + new Tuple2<>(2, 'z'), + new Tuple2<>(4, 'w') )); List>>> joined = rdd1.leftOuterJoin(rdd2).collect(); @@ -549,11 +557,11 @@ public Integer call(Integer a, Integer b) { public void aggregateByKey() { JavaPairRDD pairs = sc.parallelizePairs( Arrays.asList( - new Tuple2(1, 1), - new Tuple2(1, 1), - new Tuple2(3, 2), - new 
Tuple2(5, 1), - new Tuple2(5, 3)), 2); + new Tuple2<>(1, 1), + new Tuple2<>(1, 1), + new Tuple2<>(3, 2), + new Tuple2<>(5, 1), + new Tuple2<>(5, 3)), 2); Map> sets = pairs.aggregateByKey(new HashSet(), new Function2, Integer, Set>() { @@ -571,20 +579,20 @@ public Set call(Set a, Set b) { } }).collectAsMap(); Assert.assertEquals(3, sets.size()); - Assert.assertEquals(new HashSet(Arrays.asList(1)), sets.get(1)); - Assert.assertEquals(new HashSet(Arrays.asList(2)), sets.get(3)); - Assert.assertEquals(new HashSet(Arrays.asList(1, 3)), sets.get(5)); + Assert.assertEquals(new HashSet<>(Arrays.asList(1)), sets.get(1)); + Assert.assertEquals(new HashSet<>(Arrays.asList(2)), sets.get(3)); + Assert.assertEquals(new HashSet<>(Arrays.asList(1, 3)), sets.get(5)); } @SuppressWarnings("unchecked") @Test public void foldByKey() { List> pairs = Arrays.asList( - new Tuple2(2, 1), - new Tuple2(2, 1), - new Tuple2(1, 1), - new Tuple2(3, 2), - new Tuple2(3, 1) + new Tuple2<>(2, 1), + new Tuple2<>(2, 1), + new Tuple2<>(1, 1), + new Tuple2<>(3, 2), + new Tuple2<>(3, 1) ); JavaPairRDD rdd = sc.parallelizePairs(pairs); JavaPairRDD sums = rdd.foldByKey(0, @@ -603,11 +611,11 @@ public Integer call(Integer a, Integer b) { @Test public void reduceByKey() { List> pairs = Arrays.asList( - new Tuple2(2, 1), - new Tuple2(2, 1), - new Tuple2(1, 1), - new Tuple2(3, 2), - new Tuple2(3, 1) + new Tuple2<>(2, 1), + new Tuple2<>(2, 1), + new Tuple2<>(1, 1), + new Tuple2<>(3, 2), + new Tuple2<>(3, 1) ); JavaPairRDD rdd = sc.parallelizePairs(pairs); JavaPairRDD counts = rdd.reduceByKey( @@ -691,7 +699,7 @@ public void cartesian() { JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0)); JavaRDD stringRDD = sc.parallelize(Arrays.asList("Hello", "World")); JavaPairRDD cartesian = stringRDD.cartesian(doubleRDD); - Assert.assertEquals(new Tuple2("Hello", 1.0), cartesian.first()); + Assert.assertEquals(new Tuple2<>("Hello", 1.0), cartesian.first()); } @Test @@ -744,6 +752,7 @@ public void javaDoubleRDDHistoGram() { } private static class DoubleComparator implements Comparator, Serializable { + @Override public int compare(Double o1, Double o2) { return o1.compareTo(o2); } @@ -767,14 +776,14 @@ public void min() { public void naturalMax() { JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); double max = rdd.max(); - Assert.assertTrue(4.0 == max); + Assert.assertEquals(4.0, max, 0.0); } @Test public void naturalMin() { JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); double max = rdd.min(); - Assert.assertTrue(1.0 == max); + Assert.assertEquals(1.0, max, 0.0); } @Test @@ -810,7 +819,7 @@ public void reduceOnJavaDoubleRDD() { JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); double sum = rdd.reduce(new Function2() { @Override - public Double call(Double v1, Double v2) throws Exception { + public Double call(Double v1, Double v2) { return v1 + v2; } }); @@ -845,7 +854,7 @@ public double call(Integer x) { new PairFunction() { @Override public Tuple2 call(Integer x) { - return new Tuple2(x, x); + return new Tuple2<>(x, x); } }).cache(); pairs.collect(); @@ -871,26 +880,25 @@ public Iterable call(String x) { Assert.assertEquals("Hello", words.first()); Assert.assertEquals(11, words.count()); - JavaPairRDD pairs = rdd.flatMapToPair( + JavaPairRDD pairsRDD = rdd.flatMapToPair( new PairFlatMapFunction() { - @Override public Iterable> call(String s) { - List> pairs = new LinkedList>(); + List> pairs = new LinkedList<>(); for 
(String word : s.split(" ")) { - pairs.add(new Tuple2(word, word)); + pairs.add(new Tuple2<>(word, word)); } return pairs; } } ); - Assert.assertEquals(new Tuple2("Hello", "Hello"), pairs.first()); - Assert.assertEquals(11, pairs.count()); + Assert.assertEquals(new Tuple2<>("Hello", "Hello"), pairsRDD.first()); + Assert.assertEquals(11, pairsRDD.count()); JavaDoubleRDD doubles = rdd.flatMapToDouble(new DoubleFlatMapFunction() { @Override public Iterable call(String s) { - List lengths = new LinkedList(); + List lengths = new LinkedList<>(); for (String word : s.split(" ")) { lengths.add((double) word.length()); } @@ -898,36 +906,36 @@ public Iterable call(String s) { } }); Assert.assertEquals(5.0, doubles.first(), 0.01); - Assert.assertEquals(11, pairs.count()); + Assert.assertEquals(11, pairsRDD.count()); } @SuppressWarnings("unchecked") @Test public void mapsFromPairsToPairs() { - List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") - ); - JavaPairRDD pairRDD = sc.parallelizePairs(pairs); - - // Regression test for SPARK-668: - JavaPairRDD swapped = pairRDD.flatMapToPair( - new PairFlatMapFunction, String, Integer>() { - @Override - public Iterable> call(Tuple2 item) { - return Collections.singletonList(item.swap()); - } + List> pairs = Arrays.asList( + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") + ); + JavaPairRDD pairRDD = sc.parallelizePairs(pairs); + + // Regression test for SPARK-668: + JavaPairRDD swapped = pairRDD.flatMapToPair( + new PairFlatMapFunction, String, Integer>() { + @Override + public Iterable> call(Tuple2 item) { + return Collections.singletonList(item.swap()); + } }); - swapped.collect(); + swapped.collect(); - // There was never a bug here, but it's worth testing: - pairRDD.mapToPair(new PairFunction, String, Integer>() { - @Override - public Tuple2 call(Tuple2 item) { - return item.swap(); - } - }).collect(); + // There was never a bug here, but it's worth testing: + pairRDD.mapToPair(new PairFunction, String, Integer>() { + @Override + public Tuple2 call(Tuple2 item) { + return item.swap(); + } + }).collect(); } @Test @@ -954,7 +962,7 @@ public void mapPartitionsWithIndex() { JavaRDD partitionSums = rdd.mapPartitionsWithIndex( new Function2, Iterator>() { @Override - public Iterator call(Integer index, Iterator iter) throws Exception { + public Iterator call(Integer index, Iterator iter) { int sum = 0; while (iter.hasNext()) { sum += iter.next(); @@ -965,6 +973,19 @@ public Iterator call(Integer index, Iterator iter) throws Exce Assert.assertEquals("[3, 7]", partitionSums.collect().toString()); } + @Test + public void getNumPartitions(){ + JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 3); + JavaDoubleRDD rdd2 = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0), 2); + JavaPairRDD rdd3 = sc.parallelizePairs(Arrays.asList( + new Tuple2<>("a", 1), + new Tuple2<>("aa", 2), + new Tuple2<>("aaa", 3) + ), 2); + Assert.assertEquals(3, rdd1.getNumPartitions()); + Assert.assertEquals(2, rdd2.getNumPartitions()); + Assert.assertEquals(2, rdd3.getNumPartitions()); + } @Test public void repartition() { @@ -973,8 +994,8 @@ public void repartition() { JavaRDD repartitioned1 = in1.repartition(4); List> result1 = repartitioned1.glom().collect(); Assert.assertEquals(4, result1.size()); - for (List l: result1) { - Assert.assertTrue(l.size() > 0); + for (List l : result1) { + Assert.assertFalse(l.isEmpty()); } // Growing number of partitions @@ -983,7 +1004,7 @@ public void 
repartition() { List> result2 = repartitioned2.glom().collect(); Assert.assertEquals(2, result2.size()); for (List l: result2) { - Assert.assertTrue(l.size() > 0); + Assert.assertFalse(l.isEmpty()); } } @@ -995,9 +1016,9 @@ public void persist() { Assert.assertEquals(20, doubleRDD.sum(), 0.1); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD pairRDD = sc.parallelizePairs(pairs); pairRDD = pairRDD.persist(StorageLevel.DISK_ONLY()); @@ -1011,7 +1032,7 @@ public void persist() { @Test public void iterator() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2); - TaskContext context = new TaskContextImpl(0, 0, 0L, 0, null, null, false, new TaskMetrics()); + TaskContext context = TaskContext$.MODULE$.empty(); Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue()); } @@ -1047,7 +1068,7 @@ public void wholeTextFiles() throws Exception { Files.write(content1, new File(tempDirName + "/part-00000")); Files.write(content2, new File(tempDirName + "/part-00001")); - Map container = new HashMap(); + Map container = new HashMap<>(); container.put(tempDirName+"/part-00000", new Text(content1).toString()); container.put(tempDirName+"/part-00001", new Text(content2).toString()); @@ -1076,16 +1097,16 @@ public void textFilesCompressed() throws IOException { public void sequenceFile() { String outputDir = new File(tempDir, "output").getAbsolutePath(); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.mapToPair(new PairFunction, IntWritable, Text>() { @Override public Tuple2 call(Tuple2 pair) { - return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + return new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2())); } }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); @@ -1094,7 +1115,7 @@ public Tuple2 call(Tuple2 pair) { Text.class).mapToPair(new PairFunction, Integer, String>() { @Override public Tuple2 call(Tuple2 pair) { - return new Tuple2(pair._1().get(), pair._2().toString()); + return new Tuple2<>(pair._1().get(), pair._2().toString()); } }); Assert.assertEquals(pairs, readRDD.collect()); @@ -1111,7 +1132,7 @@ public void binaryFiles() throws Exception { FileOutputStream fos1 = new FileOutputStream(file1); FileChannel channel1 = fos1.getChannel(); - ByteBuffer bbuf = java.nio.ByteBuffer.wrap(content1); + ByteBuffer bbuf = ByteBuffer.wrap(content1); channel1.write(bbuf); channel1.close(); JavaPairRDD readRDD = sc.binaryFiles(tempDirName, 3); @@ -1132,14 +1153,14 @@ public void binaryFilesCaching() throws Exception { FileOutputStream fos1 = new FileOutputStream(file1); FileChannel channel1 = fos1.getChannel(); - ByteBuffer bbuf = java.nio.ByteBuffer.wrap(content1); + ByteBuffer bbuf = ByteBuffer.wrap(content1); channel1.write(bbuf); channel1.close(); JavaPairRDD readRDD = sc.binaryFiles(tempDirName).cache(); readRDD.foreach(new VoidFunction>() { @Override - public void call(Tuple2 pair) throws Exception { + public void call(Tuple2 pair) { pair._2().toArray(); // force the file to read } }); @@ -1163,7 +1184,7 @@ public void binaryRecords() throws Exception { FileChannel channel1 = fos1.getChannel(); for (int i = 0; i < numOfCopies; i++) { - ByteBuffer bbuf = java.nio.ByteBuffer.wrap(content1); 
+ ByteBuffer bbuf = ByteBuffer.wrap(content1); channel1.write(bbuf); } channel1.close(); @@ -1181,24 +1202,23 @@ public void binaryRecords() throws Exception { public void writeWithNewAPIHadoopFile() { String outputDir = new File(tempDir, "output").getAbsolutePath(); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.mapToPair(new PairFunction, IntWritable, Text>() { @Override public Tuple2 call(Tuple2 pair) { - return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + return new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2())); } - }).saveAsNewAPIHadoopFile(outputDir, IntWritable.class, Text.class, - org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat.class); + }).saveAsNewAPIHadoopFile( + outputDir, IntWritable.class, Text.class, + org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat.class); - JavaPairRDD output = sc.sequenceFile(outputDir, IntWritable.class, - Text.class); - Assert.assertEquals(pairs.toString(), output.map(new Function, - String>() { + JavaPairRDD output = sc.sequenceFile(outputDir, IntWritable.class, Text.class); + Assert.assertEquals(pairs.toString(), output.map(new Function, String>() { @Override public String call(Tuple2 x) { return x.toString(); @@ -1211,24 +1231,23 @@ public String call(Tuple2 x) { public void readWithNewAPIHadoopFile() throws IOException { String outputDir = new File(tempDir, "output").getAbsolutePath(); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.mapToPair(new PairFunction, IntWritable, Text>() { @Override public Tuple2 call(Tuple2 pair) { - return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + return new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2())); } }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); JavaPairRDD output = sc.newAPIHadoopFile(outputDir, - org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat.class, IntWritable.class, - Text.class, new Job().getConfiguration()); - Assert.assertEquals(pairs.toString(), output.map(new Function, - String>() { + org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat.class, + IntWritable.class, Text.class, new Job().getConfiguration()); + Assert.assertEquals(pairs.toString(), output.map(new Function, String>() { @Override public String call(Tuple2 x) { return x.toString(); @@ -1252,9 +1271,9 @@ public void objectFilesOfInts() { public void objectFilesOfComplexTypes() { String outputDir = new File(tempDir, "output").getAbsolutePath(); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.saveAsObjectFile(outputDir); @@ -1268,23 +1287,22 @@ public void objectFilesOfComplexTypes() { public void hadoopFile() { String outputDir = new File(tempDir, "output").getAbsolutePath(); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.mapToPair(new PairFunction, IntWritable, Text>() { @Override public 
Tuple2 call(Tuple2 pair) { - return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + return new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2())); } }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); JavaPairRDD output = sc.hadoopFile(outputDir, - SequenceFileInputFormat.class, IntWritable.class, Text.class); - Assert.assertEquals(pairs.toString(), output.map(new Function, - String>() { + SequenceFileInputFormat.class, IntWritable.class, Text.class); + Assert.assertEquals(pairs.toString(), output.map(new Function, String>() { @Override public String call(Tuple2 x) { return x.toString(); @@ -1297,16 +1315,16 @@ public String call(Tuple2 x) { public void hadoopFileCompressed() { String outputDir = new File(tempDir, "output_compressed").getAbsolutePath(); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.mapToPair(new PairFunction, IntWritable, Text>() { @Override public Tuple2 call(Tuple2 pair) { - return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + return new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2())); } }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class, DefaultCodec.class); @@ -1314,8 +1332,7 @@ public Tuple2 call(Tuple2 pair) { JavaPairRDD output = sc.hadoopFile(outputDir, SequenceFileInputFormat.class, IntWritable.class, Text.class); - Assert.assertEquals(pairs.toString(), output.map(new Function, - String>() { + Assert.assertEquals(pairs.toString(), output.map(new Function, String>() { @Override public String call(Tuple2 x) { return x.toString(); @@ -1415,8 +1432,8 @@ public String call(Integer t) { return t.toString(); } }).collect(); - Assert.assertEquals(new Tuple2("1", 1), s.get(0)); - Assert.assertEquals(new Tuple2("2", 2), s.get(1)); + Assert.assertEquals(new Tuple2<>("1", 1), s.get(0)); + Assert.assertEquals(new Tuple2<>("2", 2), s.get(1)); } @Test @@ -1449,20 +1466,20 @@ public void combineByKey() { JavaRDD originalRDD = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6)); Function keyFunction = new Function() { @Override - public Integer call(Integer v1) throws Exception { + public Integer call(Integer v1) { return v1 % 3; } }; Function createCombinerFunction = new Function() { @Override - public Integer call(Integer v1) throws Exception { + public Integer call(Integer v1) { return v1; } }; Function2 mergeValueFunction = new Function2() { @Override - public Integer call(Integer v1, Integer v2) throws Exception { + public Integer call(Integer v1, Integer v2) { return v1 + v2; } }; @@ -1474,7 +1491,9 @@ public Integer call(Integer v1, Integer v2) throws Exception { Assert.assertEquals(expected, results); Partitioner defaultPartitioner = Partitioner.defaultPartitioner( - combinedRDD.rdd(), JavaConversions.asScalaBuffer(Lists.>newArrayList())); + combinedRDD.rdd(), + JavaConverters.collectionAsScalaIterableConverter( + Collections.>emptyList()).asScala().toSeq()); combinedRDD = originalRDD.keyBy(keyFunction) .combineByKey( createCombinerFunction, @@ -1495,21 +1514,21 @@ public void mapOnPairRDD() { new PairFunction() { @Override public Tuple2 call(Integer i) { - return new Tuple2(i, i % 2); + return new Tuple2<>(i, i % 2); } }); JavaPairRDD rdd3 = rdd2.mapToPair( new PairFunction, Integer, Integer>() { - @Override - public Tuple2 call(Tuple2 in) { - return new Tuple2(in._2(), 
in._1()); - } - }); + @Override + public Tuple2 call(Tuple2 in) { + return new Tuple2<>(in._2(), in._1()); + } + }); Assert.assertEquals(Arrays.asList( - new Tuple2(1, 1), - new Tuple2(0, 2), - new Tuple2(1, 3), - new Tuple2(0, 4)), rdd3.collect()); + new Tuple2<>(1, 1), + new Tuple2<>(0, 2), + new Tuple2<>(1, 3), + new Tuple2<>(0, 4)), rdd3.collect()); } @@ -1522,7 +1541,7 @@ public void collectPartitions() { new PairFunction() { @Override public Tuple2 call(Integer i) { - return new Tuple2(i, i % 2); + return new Tuple2<>(i, i % 2); } }); @@ -1533,23 +1552,23 @@ public Tuple2 call(Integer i) { Assert.assertEquals(Arrays.asList(3, 4), parts[0]); Assert.assertEquals(Arrays.asList(5, 6, 7), parts[1]); - Assert.assertEquals(Arrays.asList(new Tuple2(1, 1), - new Tuple2(2, 0)), + Assert.assertEquals(Arrays.asList(new Tuple2<>(1, 1), + new Tuple2<>(2, 0)), rdd2.collectPartitions(new int[] {0})[0]); List>[] parts2 = rdd2.collectPartitions(new int[] {1, 2}); - Assert.assertEquals(Arrays.asList(new Tuple2(3, 1), - new Tuple2(4, 0)), + Assert.assertEquals(Arrays.asList(new Tuple2<>(3, 1), + new Tuple2<>(4, 0)), parts2[0]); - Assert.assertEquals(Arrays.asList(new Tuple2(5, 1), - new Tuple2(6, 0), - new Tuple2(7, 1)), + Assert.assertEquals(Arrays.asList(new Tuple2<>(5, 1), + new Tuple2<>(6, 0), + new Tuple2<>(7, 1)), parts2[1]); } @Test public void countApproxDistinct() { - List arrayData = new ArrayList(); + List arrayData = new ArrayList<>(); int size = 100; for (int i = 0; i < 100000; i++) { arrayData.add(i % size); @@ -1560,15 +1579,15 @@ public void countApproxDistinct() { @Test public void countApproxDistinctByKey() { - List> arrayData = new ArrayList>(); + List> arrayData = new ArrayList<>(); for (int i = 10; i < 100; i++) { for (int j = 0; j < i; j++) { - arrayData.add(new Tuple2(i, j)); + arrayData.add(new Tuple2<>(i, j)); } } double relativeSD = 0.001; JavaPairRDD pairRdd = sc.parallelizePairs(arrayData); - List> res = pairRdd.countApproxDistinctByKey(8, 0).collect(); + List> res = pairRdd.countApproxDistinctByKey(relativeSD, 8).collect(); for (Tuple2 resItem : res) { double count = (double)resItem._1(); Long resCount = (Long)resItem._2(); @@ -1586,7 +1605,7 @@ public void collectAsMapWithIntArrayValues() { new PairFunction() { @Override public Tuple2 call(Integer x) { - return new Tuple2(x, new int[] { x }); + return new Tuple2<>(x, new int[]{x}); } }); pairRDD.collect(); // Works fine @@ -1597,7 +1616,7 @@ public Tuple2 call(Integer x) { @Test public void collectAsMapAndSerialize() throws Exception { JavaPairRDD rdd = - sc.parallelizePairs(Arrays.asList(new Tuple2("foo", 1))); + sc.parallelizePairs(Arrays.asList(new Tuple2<>("foo", 1))); Map map = rdd.collectAsMap(); ByteArrayOutputStream bytes = new ByteArrayOutputStream(); new ObjectOutputStream(bytes).writeObject(map); @@ -1614,7 +1633,7 @@ public void sampleByKey() { new PairFunction() { @Override public Tuple2 call(Integer i) { - return new Tuple2(i % 2, 1); + return new Tuple2<>(i % 2, 1); } }); Map fractions = Maps.newHashMap(); @@ -1622,12 +1641,12 @@ public Tuple2 call(Integer i) { fractions.put(1, 1.0); JavaPairRDD wr = rdd2.sampleByKey(true, fractions, 1L); Map wrCounts = (Map) (Object) wr.countByKey(); - Assert.assertTrue(wrCounts.size() == 2); + Assert.assertEquals(2, wrCounts.size()); Assert.assertTrue(wrCounts.get(0) > 0); Assert.assertTrue(wrCounts.get(1) > 0); JavaPairRDD wor = rdd2.sampleByKey(false, fractions, 1L); Map worCounts = (Map) (Object) wor.countByKey(); - Assert.assertTrue(worCounts.size() == 2); + 
Assert.assertEquals(2, worCounts.size()); Assert.assertTrue(worCounts.get(0) > 0); Assert.assertTrue(worCounts.get(1) > 0); } @@ -1640,7 +1659,7 @@ public void sampleByKeyExact() { new PairFunction() { @Override public Tuple2 call(Integer i) { - return new Tuple2(i % 2, 1); + return new Tuple2<>(i % 2, 1); } }); Map fractions = Maps.newHashMap(); @@ -1648,25 +1667,25 @@ public Tuple2 call(Integer i) { fractions.put(1, 1.0); JavaPairRDD wrExact = rdd2.sampleByKeyExact(true, fractions, 1L); Map wrExactCounts = (Map) (Object) wrExact.countByKey(); - Assert.assertTrue(wrExactCounts.size() == 2); + Assert.assertEquals(2, wrExactCounts.size()); Assert.assertTrue(wrExactCounts.get(0) == 2); Assert.assertTrue(wrExactCounts.get(1) == 4); JavaPairRDD worExact = rdd2.sampleByKeyExact(false, fractions, 1L); Map worExactCounts = (Map) (Object) worExact.countByKey(); - Assert.assertTrue(worExactCounts.size() == 2); + Assert.assertEquals(2, worExactCounts.size()); Assert.assertTrue(worExactCounts.get(0) == 2); Assert.assertTrue(worExactCounts.get(1) == 4); } private static class SomeCustomClass implements Serializable { - public SomeCustomClass() { + SomeCustomClass() { // Intentionally left blank } } @Test public void collectUnderlyingScalaRDD() { - List data = new ArrayList(); + List data = new ArrayList<>(); for (int i = 0; i < 100; i++) { data.add(new SomeCustomClass()); } @@ -1678,7 +1697,7 @@ public void collectUnderlyingScalaRDD() { private static final class BuggyMapFunction implements Function { @Override - public T call(T x) throws Exception { + public T call(T x) { throw new IllegalStateException("Custom exception!"); } } @@ -1715,7 +1734,7 @@ public void foreachAsync() throws Exception { JavaFutureAction future = rdd.foreachAsync( new VoidFunction() { @Override - public void call(Integer integer) throws Exception { + public void call(Integer integer) { // intentionally left blank. } } @@ -1744,7 +1763,7 @@ public void testAsyncActionCancellation() throws Exception { JavaRDD rdd = sc.parallelize(data, 1); JavaFutureAction future = rdd.foreachAsync(new VoidFunction() { @Override - public void call(Integer integer) throws Exception { + public void call(Integer integer) throws InterruptedException { Thread.sleep(10000); // To ensure that the job won't finish before it's cancelled. 
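The foreachAsync and cancellation tests above work against the JavaFutureAction returned by the Java API; the Scala side exposes the same pattern through the async RDD actions, where foreachAsync returns a FutureAction that can be cancelled. A hedged sketch of that usage, with an illustrative local SparkContext and sleep duration:

    import org.apache.spark.{SparkConf, SparkContext}

    object AsyncCancelSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("async-cancel-sketch"))
        try {
          // foreachAsync returns a FutureAction; the long sleep keeps the job running
          // long enough for cancel() to be observed before it completes.
          val future = sc.parallelize(1 to 4, 1).foreachAsync(_ => Thread.sleep(10000))
          future.cancel()
        } finally {
          sc.stop()
        }
      }
    }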
} }); diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java similarity index 51% rename from launcher/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java rename to core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index 252d5abae1ca..aa15e792e2b2 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -20,12 +20,14 @@ import java.io.BufferedReader; import java.io.InputStream; import java.io.InputStreamReader; +import java.util.Arrays; import java.util.HashMap; import java.util.Map; import org.junit.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.slf4j.bridge.SLF4JBridgeHandler; import static org.junit.Assert.*; /** @@ -33,10 +35,62 @@ */ public class SparkLauncherSuite { + static { + SLF4JBridgeHandler.removeHandlersForRootLogger(); + SLF4JBridgeHandler.install(); + } + private static final Logger LOG = LoggerFactory.getLogger(SparkLauncherSuite.class); + private static final NamedThreadFactory TF = new NamedThreadFactory("SparkLauncherSuite-%d"); + + @Test + public void testSparkArgumentHandling() throws Exception { + SparkLauncher launcher = new SparkLauncher() + .setSparkHome(System.getProperty("spark.test.home")); + SparkSubmitOptionParser opts = new SparkSubmitOptionParser(); + + launcher.addSparkArg(opts.HELP); + try { + launcher.addSparkArg(opts.PROXY_USER); + fail("Expected IllegalArgumentException."); + } catch (IllegalArgumentException e) { + // Expected. + } + + launcher.addSparkArg(opts.PROXY_USER, "someUser"); + try { + launcher.addSparkArg(opts.HELP, "someValue"); + fail("Expected IllegalArgumentException."); + } catch (IllegalArgumentException e) { + // Expected. 
+ } + + launcher.addSparkArg("--future-argument"); + launcher.addSparkArg("--future-argument", "someValue"); + + launcher.addSparkArg(opts.MASTER, "myMaster"); + assertEquals("myMaster", launcher.builder.master); + + launcher.addJar("foo"); + launcher.addSparkArg(opts.JARS, "bar"); + assertEquals(Arrays.asList("bar"), launcher.builder.jars); + + launcher.addFile("foo"); + launcher.addSparkArg(opts.FILES, "bar"); + assertEquals(Arrays.asList("bar"), launcher.builder.files); + + launcher.addPyFile("foo"); + launcher.addSparkArg(opts.PY_FILES, "bar"); + assertEquals(Arrays.asList("bar"), launcher.builder.pyFiles); + + launcher.setConf("spark.foo", "foo"); + launcher.addSparkArg(opts.CONF, "spark.foo=bar"); + assertEquals("bar", launcher.builder.conf.get("spark.foo")); + } @Test public void testChildProcLauncher() throws Exception { + SparkSubmitOptionParser opts = new SparkSubmitOptionParser(); Map env = new HashMap(); env.put("SPARK_PRINT_LAUNCH_COMMAND", "1"); @@ -44,14 +98,18 @@ public void testChildProcLauncher() throws Exception { .setSparkHome(System.getProperty("spark.test.home")) .setMaster("local") .setAppResource("spark-internal") + .addSparkArg(opts.CONF, + String.format("%s=-Dfoo=ShouldBeOverriddenBelow", SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS)) .setConf(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, - "-Dfoo=bar -Dtest.name=-testChildProcLauncher") + "-Dfoo=bar -Dtest.appender=childproc") .setConf(SparkLauncher.DRIVER_EXTRA_CLASSPATH, System.getProperty("java.class.path")) + .addSparkArg(opts.CLASS, "ShouldBeOverriddenBelow") .setMainClass(SparkLauncherTestApp.class.getName()) .addAppArgs("proc"); final Process app = launcher.launch(); - new Redirector("stdout", app.getInputStream()).start(); - new Redirector("stderr", app.getErrorStream()).start(); + + new OutputRedirector(app.getInputStream(), TF); + new OutputRedirector(app.getErrorStream(), TF); assertEquals(0, app.waitFor()); } @@ -66,29 +124,4 @@ public static void main(String[] args) throws Exception { } - private static class Redirector extends Thread { - - private final InputStream in; - - Redirector(String name, InputStream in) { - this.in = in; - setName(name); - setDaemon(true); - } - - @Override - public void run() { - try { - BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8")); - String line; - while ((line = reader.readLine()) != null) { - LOG.warn(line); - } - } catch (Exception e) { - LOG.error("Error reading process output.", e); - } - } - - } - } diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java new file mode 100644 index 000000000000..776a2997cf91 --- /dev/null +++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory; + +import org.junit.Assert; +import org.junit.Test; + +import org.apache.spark.SparkConf; +import org.apache.spark.unsafe.memory.MemoryBlock; + +public class TaskMemoryManagerSuite { + + @Test + public void leakedPageMemoryIsDetected() { + final TaskMemoryManager manager = new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.memory.offHeap.enabled", "false"), + Long.MAX_VALUE, + Long.MAX_VALUE, + 1), + 0); + manager.allocatePage(4096, null); // leak memory + Assert.assertEquals(4096, manager.getMemoryConsumptionForThisTask()); + Assert.assertEquals(4096, manager.cleanUpAllAllocatedMemory()); + } + + @Test + public void encodePageNumberAndOffsetOffHeap() { + final SparkConf conf = new SparkConf() + .set("spark.memory.offHeap.enabled", "true") + .set("spark.memory.offHeap.size", "1000"); + final TaskMemoryManager manager = new TaskMemoryManager(new TestMemoryManager(conf), 0); + final MemoryBlock dataPage = manager.allocatePage(256, null); + // In off-heap mode, an offset is an absolute address that may require more than 51 bits to + // encode. This test exercises that corner-case: + final long offset = ((1L << TaskMemoryManager.OFFSET_BITS) + 10); + final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, offset); + Assert.assertEquals(null, manager.getPage(encodedAddress)); + Assert.assertEquals(offset, manager.getOffsetInPage(encodedAddress)); + } + + @Test + public void encodePageNumberAndOffsetOnHeap() { + final TaskMemoryManager manager = new TaskMemoryManager( + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); + final MemoryBlock dataPage = manager.allocatePage(256, null); + final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, 64); + Assert.assertEquals(dataPage.getBaseObject(), manager.getPage(encodedAddress)); + Assert.assertEquals(64, manager.getOffsetInPage(encodedAddress)); + } + + @Test + public void cooperativeSpilling() { + final TestMemoryManager memoryManager = new TestMemoryManager(new SparkConf()); + memoryManager.limit(100); + final TaskMemoryManager manager = new TaskMemoryManager(memoryManager, 0); + + TestMemoryConsumer c1 = new TestMemoryConsumer(manager); + TestMemoryConsumer c2 = new TestMemoryConsumer(manager); + c1.use(100); + assert(c1.getUsed() == 100); + c2.use(100); + assert(c2.getUsed() == 100); + assert(c1.getUsed() == 0); // spilled + c1.use(100); + assert(c1.getUsed() == 100); + assert(c2.getUsed() == 0); // spilled + + c1.use(50); + assert(c1.getUsed() == 50); // spilled + assert(c2.getUsed() == 0); + c2.use(50); + assert(c1.getUsed() == 50); + assert(c2.getUsed() == 50); + + c1.use(100); + assert(c1.getUsed() == 100); + assert(c2.getUsed() == 0); // spilled + + c1.free(20); + assert(c1.getUsed() == 80); + c2.use(10); + assert(c1.getUsed() == 80); + assert(c2.getUsed() == 10); + c2.use(100); + assert(c2.getUsed() == 100); + assert(c1.getUsed() == 0); // spilled + + c1.free(0); + c2.free(100); + assert(manager.cleanUpAllAllocatedMemory() == 0); + } + + @Test + public void offHeapConfigurationBackwardsCompatibility() { + // Tests backwards-compatibility with the old `spark.unsafe.offHeap` configuration, which + // was deprecated in Spark 1.6 and replaced by `spark.memory.offHeap.enabled` (see SPARK-12251). 
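The offHeapConfigurationBackwardsCompatibility test above checks that the deprecated spark.unsafe.offHeap flag still selects off-heap Tungsten memory after its replacement by spark.memory.offHeap.enabled in Spark 1.6 (SPARK-12251). A minimal sketch of the two equivalent configurations (the size value is illustrative):

    import org.apache.spark.SparkConf

    object OffHeapConfSketch {
      // Deprecated spelling, kept working for backwards compatibility.
      val legacyConf = new SparkConf()
        .set("spark.unsafe.offHeap", "true")
        .set("spark.memory.offHeap.size", "1000")

      // Current spelling introduced in Spark 1.6.
      val currentConf = new SparkConf()
        .set("spark.memory.offHeap.enabled", "true")
        .set("spark.memory.offHeap.size", "1000")
    }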
+ final SparkConf conf = new SparkConf() + .set("spark.unsafe.offHeap", "true") + .set("spark.memory.offHeap.size", "1000"); + final TaskMemoryManager manager = new TaskMemoryManager(new TestMemoryManager(conf), 0); + assert(manager.tungstenMemoryMode == MemoryMode.OFF_HEAP); + } + +} diff --git a/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java new file mode 100644 index 000000000000..e6e16fff8040 --- /dev/null +++ b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory; + +import java.io.IOException; + +public class TestMemoryConsumer extends MemoryConsumer { + public TestMemoryConsumer(TaskMemoryManager memoryManager) { + super(memoryManager); + } + + @Override + public long spill(long size, MemoryConsumer trigger) throws IOException { + long used = getUsed(); + free(used); + return used; + } + + void use(long size) { + long got = taskMemoryManager.acquireExecutionMemory( + size, + taskMemoryManager.tungstenMemoryMode, + this); + used += got; + } + + void free(long size) { + used -= size; + taskMemoryManager.releaseExecutionMemory( + size, + taskMemoryManager.tungstenMemoryMode, + this); + } +} + + diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala b/core/src/test/java/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala new file mode 100644 index 000000000000..0b19861fc41e --- /dev/null +++ b/core/src/test/java/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.sort + +import java.io.{File, FileInputStream, FileOutputStream} + +import org.mockito.Answers.RETURNS_SMART_NULLS +import org.mockito.Matchers._ +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.mockito.{Mock, MockitoAnnotations} +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.shuffle.IndexShuffleBlockResolver +import org.apache.spark.storage._ +import org.apache.spark.util.Utils +import org.apache.spark.{SparkConf, SparkFunSuite} + + +class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEach { + + @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = _ + @Mock(answer = RETURNS_SMART_NULLS) private var diskBlockManager: DiskBlockManager = _ + + private var tempDir: File = _ + private val conf: SparkConf = new SparkConf(loadDefaults = false) + + override def beforeEach(): Unit = { + tempDir = Utils.createTempDir() + MockitoAnnotations.initMocks(this) + + when(blockManager.diskBlockManager).thenReturn(diskBlockManager) + when(diskBlockManager.getFile(any[BlockId])).thenAnswer( + new Answer[File] { + override def answer(invocation: InvocationOnMock): File = { + new File(tempDir, invocation.getArguments.head.toString) + } + }) + } + + override def afterEach(): Unit = { + Utils.deleteRecursively(tempDir) + } + + test("commit shuffle files multiple times") { + val lengths = Array[Long](10, 0, 20) + val resolver = new IndexShuffleBlockResolver(conf, blockManager) + val dataTmp = File.createTempFile("shuffle", null, tempDir) + val out = new FileOutputStream(dataTmp) + out.write(new Array[Byte](30)) + out.close() + resolver.writeIndexFileAndCommit(1, 2, lengths, dataTmp) + + val dataFile = resolver.getDataFile(1, 2) + assert(dataFile.exists()) + assert(dataFile.length() === 30) + assert(!dataTmp.exists()) + + val dataTmp2 = File.createTempFile("shuffle", null, tempDir) + val out2 = new FileOutputStream(dataTmp2) + val lengths2 = new Array[Long](3) + out2.write(Array[Byte](1)) + out2.write(new Array[Byte](29)) + out2.close() + resolver.writeIndexFileAndCommit(1, 2, lengths2, dataTmp2) + assert(lengths2.toSeq === lengths.toSeq) + assert(dataFile.exists()) + assert(dataFile.length() === 30) + assert(!dataTmp2.exists()) + + // The dataFile should be the previous one + val in = new FileInputStream(dataFile) + val firstByte = new Array[Byte](1) + in.read(firstByte) + assert(firstByte(0) === 0) + + // remove data file + dataFile.delete() + + val dataTmp3 = File.createTempFile("shuffle", null, tempDir) + val out3 = new FileOutputStream(dataTmp3) + val lengths3 = Array[Long](10, 10, 15) + out3.write(Array[Byte](2)) + out3.write(new Array[Byte](34)) + out3.close() + resolver.writeIndexFileAndCommit(1, 2, lengths3, dataTmp3) + assert(lengths3.toSeq != lengths.toSeq) + assert(dataFile.exists()) + assert(dataFile.length() === 35) + assert(!dataTmp2.exists()) + + // The dataFile should be the previous one + val in2 = new FileInputStream(dataFile) + val firstByte2 = new Array[Byte](1) + in2.read(firstByte2) + assert(firstByte2(0) === 2) + } +} diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java similarity index 76% rename from core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java rename to core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java index 
db9e82759090..fe5abc5c2304 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java @@ -15,25 +15,31 @@ * limitations under the License. */ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; + +import java.io.IOException; import org.junit.Test; -import static org.junit.Assert.*; -import org.apache.spark.unsafe.memory.ExecutorMemoryManager; -import org.apache.spark.unsafe.memory.MemoryAllocator; +import org.apache.spark.SparkConf; +import org.apache.spark.memory.TestMemoryManager; +import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.unsafe.memory.MemoryBlock; -import org.apache.spark.unsafe.memory.TaskMemoryManager; -import static org.apache.spark.shuffle.unsafe.PackedRecordPointer.*; + +import static org.apache.spark.shuffle.sort.PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES; +import static org.apache.spark.shuffle.sort.PackedRecordPointer.MAXIMUM_PARTITION_ID; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; public class PackedRecordPointerSuite { @Test - public void heap() { + public void heap() throws IOException { + final SparkConf conf = new SparkConf().set("spark.memory.offHeap.enabled", "false"); final TaskMemoryManager memoryManager = - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); - final MemoryBlock page0 = memoryManager.allocatePage(100); - final MemoryBlock page1 = memoryManager.allocatePage(100); + new TaskMemoryManager(new TestMemoryManager(conf), 0); + final MemoryBlock page0 = memoryManager.allocatePage(128, null); + final MemoryBlock page1 = memoryManager.allocatePage(128, null); final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1, page1.getBaseOffset() + 42); PackedRecordPointer packedPointer = new PackedRecordPointer(); @@ -47,11 +53,14 @@ public void heap() { } @Test - public void offHeap() { + public void offHeap() throws IOException { + final SparkConf conf = new SparkConf() + .set("spark.memory.offHeap.enabled", "true") + .set("spark.memory.offHeap.size", "10000"); final TaskMemoryManager memoryManager = - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE)); - final MemoryBlock page0 = memoryManager.allocatePage(100); - final MemoryBlock page1 = memoryManager.allocatePage(100); + new TaskMemoryManager(new TestMemoryManager(conf), 0); + final MemoryBlock page0 = memoryManager.allocatePage(128, null); + final MemoryBlock page1 = memoryManager.allocatePage(128, null); final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1, page1.getBaseOffset() + 42); PackedRecordPointer packedPointer = new PackedRecordPointer(); diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java similarity index 70% rename from core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java rename to core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java index 8fa72597db24..0328e63e4543 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java @@ -15,7 +15,7 @@ * limitations under the License. 
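The renamed PackedRecordPointerSuite above shows the new allocation pattern in miniature: pages come straight from TaskMemoryManager (the second argument to allocatePage is the owning consumer, null in these tests), and a page plus an offset is folded into one long that PackedRecordPointer then compresses further together with a partition id. A rough sketch of just that encode/decode round trip, restricted to calls that appear in this diff (variable names are mine):

SparkConf conf = new SparkConf().set("spark.memory.offHeap.enabled", "false");
TaskMemoryManager mm = new TaskMemoryManager(new TestMemoryManager(conf), 0);
MemoryBlock page = mm.allocatePage(128, null);      // null: no owning consumer in these tests
long addr = mm.encodePageNumberAndOffset(page, page.getBaseOffset() + 42);
Object base = mm.getPage(addr);                     // back to the page's base object
long offset = mm.getOffsetInPage(addr);             // and to the offset that was encoded
mm.cleanUpAllAllocatedMemory();                     // the suites expect 0 leaked bytes here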
*/ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; import java.util.Arrays; import java.util.Random; @@ -24,28 +24,30 @@ import org.junit.Test; import org.apache.spark.HashPartitioner; -import org.apache.spark.unsafe.PlatformDependent; -import org.apache.spark.unsafe.memory.ExecutorMemoryManager; -import org.apache.spark.unsafe.memory.MemoryAllocator; +import org.apache.spark.SparkConf; +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.memory.TestMemoryConsumer; +import org.apache.spark.memory.TestMemoryManager; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.MemoryBlock; -import org.apache.spark.unsafe.memory.TaskMemoryManager; -public class UnsafeShuffleInMemorySorterSuite { +public class ShuffleInMemorySorterSuite { + + final TestMemoryManager memoryManager = + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")); + final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0); + final TestMemoryConsumer consumer = new TestMemoryConsumer(taskMemoryManager); private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength) { final byte[] strBytes = new byte[strLength]; - PlatformDependent.copyMemory( - baseObject, - baseOffset, - strBytes, - PlatformDependent.BYTE_ARRAY_OFFSET, strLength); + Platform.copyMemory(baseObject, baseOffset, strBytes, Platform.BYTE_ARRAY_OFFSET, strLength); return new String(strBytes); } @Test public void testSortingEmptyInput() { - final UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(100); - final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator(); + final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(consumer, 100); + final ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator(); assert(!iter.hasNext()); } @@ -62,11 +64,12 @@ public void testBasicSorting() throws Exception { "Lychee", "Mango" }; + final SparkConf conf = new SparkConf().set("spark.memory.offHeap.enabled", "false"); final TaskMemoryManager memoryManager = - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); - final MemoryBlock dataPage = memoryManager.allocatePage(2048); + new TaskMemoryManager(new TestMemoryManager(conf), 0); + final MemoryBlock dataPage = memoryManager.allocatePage(2048, null); final Object baseObject = dataPage.getBaseObject(); - final UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(4); + final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(consumer, 4); final HashPartitioner hashPartitioner = new HashPartitioner(4); // Write the records into the data page and store pointers into the sorter @@ -74,20 +77,16 @@ public void testBasicSorting() throws Exception { for (String str : dataToSort) { final long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, position); final byte[] strBytes = str.getBytes("utf-8"); - PlatformDependent.UNSAFE.putInt(baseObject, position, strBytes.length); + Platform.putInt(baseObject, position, strBytes.length); position += 4; - PlatformDependent.copyMemory( - strBytes, - PlatformDependent.BYTE_ARRAY_OFFSET, - baseObject, - position, - strBytes.length); + Platform.copyMemory( + strBytes, Platform.BYTE_ARRAY_OFFSET, baseObject, position, strBytes.length); position += strBytes.length; sorter.insertRecord(recordAddress, hashPartitioner.getPartition(str)); } // Sort the records - final 
UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator(); + final ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator(); int prevPartitionId = -1; Arrays.sort(dataToSort); for (int i = 0; i < dataToSort.length; i++) { @@ -98,7 +97,7 @@ public void testBasicSorting() throws Exception { Assert.assertTrue("Partition id " + partitionId + " should be >= prev id " + prevPartitionId, partitionId >= prevPartitionId); final long recordAddress = iter.packedRecordPointer.getRecordPointer(); - final int recordLength = PlatformDependent.UNSAFE.getInt( + final int recordLength = Platform.getInt( memoryManager.getPage(recordAddress), memoryManager.getOffsetInPage(recordAddress)); final String str = getStringFromDataPage( memoryManager.getPage(recordAddress), @@ -111,7 +110,7 @@ public void testBasicSorting() throws Exception { @Test public void testSortingManyNumbers() throws Exception { - UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(4); + ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(consumer, 4); int[] numbersToSort = new int[128000]; Random random = new Random(16); for (int i = 0; i < numbersToSort.length; i++) { @@ -120,7 +119,7 @@ public void testSortingManyNumbers() throws Exception { } Arrays.sort(numbersToSort); int[] sorterResult = new int[numbersToSort.length]; - UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator(); + ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator(); int j = 0; while (iter.hasNext()) { iter.loadNext(); diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java similarity index 81% rename from core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java rename to core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 04fc09b323db..5fe64bde3604 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -15,7 +15,7 @@ * limitations under the License. 
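The record layout exercised by ShuffleInMemorySorterSuite is unchanged, a 4-byte length prefix followed by the payload; only the accessor class moved from PlatformDependent to Platform. A small sketch of writing one such record into a page and reading it back, reusing the TaskMemoryManager setup shown above (variable names are mine):

TaskMemoryManager mm = new TaskMemoryManager(
    new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0);
MemoryBlock dataPage = mm.allocatePage(2048, null);
Object base = dataPage.getBaseObject();
long pos = dataPage.getBaseOffset();

byte[] strBytes = "Mango".getBytes(java.nio.charset.StandardCharsets.UTF_8);
Platform.putInt(base, pos, strBytes.length);                               // length prefix
Platform.copyMemory(strBytes, Platform.BYTE_ARRAY_OFFSET, base, pos + 4, strBytes.length);

int len = Platform.getInt(base, pos);                                      // read the prefix back
byte[] out = new byte[len];
Platform.copyMemory(base, pos + 4, out, Platform.BYTE_ARRAY_OFFSET, len);  // then the payload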
*/ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; import java.io.*; import java.nio.ByteBuffer; @@ -23,7 +23,6 @@ import scala.*; import scala.collection.Iterator; -import scala.reflect.ClassTag; import scala.runtime.AbstractFunction1; import com.google.common.collect.Iterators; @@ -40,7 +39,6 @@ import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.lessThan; import static org.junit.Assert.*; -import static org.mockito.AdditionalAnswers.returnsFirstArg; import static org.mockito.Answers.RETURNS_SMART_NULLS; import static org.mockito.Mockito.*; @@ -55,18 +53,16 @@ import org.apache.spark.serializer.*; import org.apache.spark.scheduler.MapStatus; import org.apache.spark.shuffle.IndexShuffleBlockResolver; -import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.storage.*; -import org.apache.spark.unsafe.memory.ExecutorMemoryManager; -import org.apache.spark.unsafe.memory.MemoryAllocator; -import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.memory.TestMemoryManager; +import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.util.Utils; public class UnsafeShuffleWriterSuite { static final int NUM_PARTITITONS = 4; - final TaskMemoryManager taskMemoryManager = - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + TestMemoryManager memoryManager; + TaskMemoryManager taskMemoryManager; final HashPartitioner hashPartitioner = new HashPartitioner(NUM_PARTITITONS); File mergedOutputFile; File tempDir; @@ -76,7 +72,6 @@ public class UnsafeShuffleWriterSuite { final Serializer serializer = new KryoSerializer(new SparkConf()); TaskMetrics taskMetrics; - @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager; @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager; @Mock(answer = RETURNS_SMART_NULLS) IndexShuffleBlockResolver shuffleBlockResolver; @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager; @@ -111,10 +106,12 @@ public void setUp() throws IOException { mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir); partitionSizesInMergedFile = null; spillFilesCreated.clear(); - conf = new SparkConf().set("spark.buffer.pageSize", "128m"); + conf = new SparkConf() + .set("spark.buffer.pageSize", "1m") + .set("spark.memory.offHeap.enabled", "false"); taskMetrics = new TaskMetrics(); - - when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg()); + memoryManager = new TestMemoryManager(conf); + taskMemoryManager = new TaskMemoryManager(memoryManager, 0); when(blockManager.diskBlockManager()).thenReturn(diskBlockManager); when(blockManager.getDiskWriter( @@ -128,13 +125,13 @@ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Th Object[] args = invocationOnMock.getArguments(); return new DiskBlockObjectWriter( - (BlockId) args[0], (File) args[1], (SerializerInstance) args[2], (Integer) args[3], new CompressStream(), false, - (ShuffleWriteMetrics) args[4] + (ShuffleWriteMetrics) args[4], + (BlockId) args[0] ); } }); @@ -173,9 +170,13 @@ public OutputStream answer(InvocationOnMock invocation) throws Throwable { @Override public Void answer(InvocationOnMock invocationOnMock) throws Throwable { partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2]; + File tmp = (File) invocationOnMock.getArguments()[3]; + mergedOutputFile.delete(); + tmp.renameTo(mergedOutputFile); return null; } - }).when(shuffleBlockResolver).writeIndexFile(anyInt(), 
anyInt(), any(long[].class)); + }).when(shuffleBlockResolver) + .writeIndexFileAndCommit(anyInt(), anyInt(), any(long[].class), any(File.class)); when(diskBlockManager.createTempShuffleBlock()).thenAnswer( new Answer>() { @@ -190,6 +191,7 @@ public Tuple2 answer( }); when(taskContext.taskMetrics()).thenReturn(taskMetrics); + when(taskContext.internalMetricsToAccumulators()).thenReturn(null); when(shuffleDep.serializer()).thenReturn(Option.apply(serializer)); when(shuffleDep.partitioner()).thenReturn(hashPartitioner); @@ -202,8 +204,7 @@ private UnsafeShuffleWriter createWriter( blockManager, shuffleBlockResolver, taskMemoryManager, - shuffleMemoryManager, - new UnsafeShuffleHandle(0, 1, shuffleDep), + new SerializedShuffleHandle(0, 1, shuffleDep), 0, // map id taskContext, conf @@ -350,9 +351,7 @@ private void testMergingSpills( } assertEquals(sumOfPartitionSizes, mergedOutputFile.length()); - assertEquals( - HashMultiset.create(dataToWrite), - HashMultiset.create(readRecordsFromFile())); + assertEquals(HashMultiset.create(dataToWrite), HashMultiset.create(readRecordsFromFile())); assertSpillFilesWereCleanedUp(); ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get(); assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten()); @@ -404,19 +403,14 @@ public void mergeSpillsWithFileStreamAndNoCompression() throws Exception { @Test public void writeEnoughDataToTriggerSpill() throws Exception { - when(shuffleMemoryManager.tryToAcquire(anyLong())) - .then(returnsFirstArg()) // Allocate initial sort buffer - .then(returnsFirstArg()) // Allocate initial data page - .thenReturn(0L) // Deny request to allocate new data page - .then(returnsFirstArg()); // Grant new sort buffer and data page. + memoryManager.limit(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES); final UnsafeShuffleWriter writer = createWriter(false); final ArrayList> dataToWrite = new ArrayList>(); - final byte[] bigByteArray = new byte[PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES / 128]; - for (int i = 0; i < 128 + 1; i++) { + final byte[] bigByteArray = new byte[PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES / 10]; + for (int i = 0; i < 10 + 1; i++) { dataToWrite.add(new Tuple2(i, bigByteArray)); } writer.write(dataToWrite.iterator()); - verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong()); assertEquals(2, spillFilesCreated.size()); writer.stop(true); readRecordsFromFile(); @@ -431,18 +425,13 @@ public void writeEnoughDataToTriggerSpill() throws Exception { @Test public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception { - when(shuffleMemoryManager.tryToAcquire(anyLong())) - .then(returnsFirstArg()) // Allocate initial sort buffer - .then(returnsFirstArg()) // Allocate initial data page - .thenReturn(0L) // Deny request to grow sort buffer - .then(returnsFirstArg()); // Grant new sort buffer and data page. 
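The stubbed tryToAcquire() answers being deleted here are replaced by a real, test-only memory manager: the rewritten UnsafeShuffleWriterSuite caps TestMemoryManager (or makes a single allocation fail) and lets the writer hit a genuine out-of-memory condition, which is what actually exercises the spill paths. Roughly, with the constants used elsewhere in this diff:

TestMemoryManager memoryManager = new TestMemoryManager(
    new SparkConf().set("spark.memory.offHeap.enabled", "false"));
TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0);

// Cap total execution memory so that filling one page forces a spill...
memoryManager.limit(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES);
// ...or make exactly the next allocation fail, as UnsafeExternalSorterSuite does further down.
memoryManager.markExecutionAsOutOfMemoryOnce();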
+ memoryManager.limit(UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE * 16); final UnsafeShuffleWriter writer = createWriter(false); - final ArrayList> dataToWrite = new ArrayList>(); - for (int i = 0; i < UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE; i++) { + final ArrayList> dataToWrite = new ArrayList<>(); + for (int i = 0; i < UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE + 1; i++) { dataToWrite.add(new Tuple2(i, i)); } writer.write(dataToWrite.iterator()); - verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong()); assertEquals(2, spillFilesCreated.size()); writer.stop(true); readRecordsFromFile(); @@ -460,7 +449,7 @@ public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception final UnsafeShuffleWriter writer = createWriter(false); final ArrayList> dataToWrite = new ArrayList>(); - final byte[] bytes = new byte[(int) (UnsafeShuffleExternalSorter.DISK_WRITE_BUFFER_SIZE * 2.5)]; + final byte[] bytes = new byte[(int) (ShuffleExternalSorter.DISK_WRITE_BUFFER_SIZE * 2.5)]; new Random(42).nextBytes(bytes); dataToWrite.add(new Tuple2(1, ByteBuffer.wrap(bytes))); writer.write(dataToWrite.iterator()); @@ -473,62 +462,22 @@ public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception @Test public void writeRecordsThatAreBiggerThanMaxRecordSize() throws Exception { - // Use a custom serializer so that we have exact control over the size of serialized data. - final Serializer byteArraySerializer = new Serializer() { - @Override - public SerializerInstance newInstance() { - return new SerializerInstance() { - @Override - public SerializationStream serializeStream(final OutputStream s) { - return new SerializationStream() { - @Override - public void flush() { } - - @Override - public SerializationStream writeObject(T t, ClassTag ev1) { - byte[] bytes = (byte[]) t; - try { - s.write(bytes); - } catch (IOException e) { - throw new RuntimeException(e); - } - return this; - } - - @Override - public void close() { } - }; - } - public ByteBuffer serialize(T t, ClassTag ev1) { return null; } - public DeserializationStream deserializeStream(InputStream s) { return null; } - public T deserialize(ByteBuffer b, ClassLoader l, ClassTag ev1) { return null; } - public T deserialize(ByteBuffer bytes, ClassTag ev1) { return null; } - }; - } - }; - when(shuffleDep.serializer()).thenReturn(Option.apply(byteArraySerializer)); final UnsafeShuffleWriter writer = createWriter(false); - // Insert a record and force a spill so that there's something to clean up: - writer.insertRecordIntoSorter(new Tuple2(new byte[1], new byte[1])); - writer.forceSorterToSpill(); + final ArrayList> dataToWrite = new ArrayList>(); + dataToWrite.add(new Tuple2(1, ByteBuffer.wrap(new byte[1]))); // We should be able to write a record that's right _at_ the max record size - final byte[] atMaxRecordSize = new byte[writer.maxRecordSizeBytes()]; + final byte[] atMaxRecordSize = new byte[(int) taskMemoryManager.pageSizeBytes() - 4]; new Random(42).nextBytes(atMaxRecordSize); - writer.insertRecordIntoSorter(new Tuple2(new byte[0], atMaxRecordSize)); - writer.forceSorterToSpill(); - // Inserting a record that's larger than the max record size should fail: - final byte[] exceedsMaxRecordSize = new byte[writer.maxRecordSizeBytes() + 1]; + dataToWrite.add(new Tuple2(2, ByteBuffer.wrap(atMaxRecordSize))); + // Inserting a record that's larger than the max record size + final byte[] exceedsMaxRecordSize = new byte[(int) taskMemoryManager.pageSizeBytes()]; new Random(42).nextBytes(exceedsMaxRecordSize); - Product2 
hugeRecord = - new Tuple2(new byte[0], exceedsMaxRecordSize); - try { - // Here, we write through the public `write()` interface instead of the test-only - // `insertRecordIntoSorter` interface: - writer.write(Collections.singletonList(hugeRecord).iterator()); - fail("Expected exception to be thrown"); - } catch (IOException e) { - // Pass - } + dataToWrite.add(new Tuple2(3, ByteBuffer.wrap(exceedsMaxRecordSize))); + writer.write(dataToWrite.iterator()); + writer.stop(true); + assertEquals( + HashMultiset.create(dataToWrite), + HashMultiset.create(readRecordsFromFile())); assertSpillFilesWereCleanedUp(); } @@ -542,4 +491,58 @@ public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException { writer.stop(false); assertSpillFilesWereCleanedUp(); } + + @Test + public void testPeakMemoryUsed() throws Exception { + final long recordLengthBytes = 8; + final long pageSizeBytes = 256; + final long numRecordsPerPage = pageSizeBytes / recordLengthBytes; + taskMemoryManager = spy(taskMemoryManager); + when(taskMemoryManager.pageSizeBytes()).thenReturn(pageSizeBytes); + final UnsafeShuffleWriter writer = + new UnsafeShuffleWriter( + blockManager, + shuffleBlockResolver, + taskMemoryManager, + new SerializedShuffleHandle<>(0, 1, shuffleDep), + 0, // map id + taskContext, + conf); + + // Peak memory should be monotonically increasing. More specifically, every time + // we allocate a new page it should increase by exactly the size of the page. + long previousPeakMemory = writer.getPeakMemoryUsedBytes(); + long newPeakMemory; + try { + for (int i = 0; i < numRecordsPerPage * 10; i++) { + writer.insertRecordIntoSorter(new Tuple2(1, 1)); + newPeakMemory = writer.getPeakMemoryUsedBytes(); + if (i % numRecordsPerPage == 0) { + // The first page is allocated in constructor, another page will be allocated after + // every numRecordsPerPage records (peak memory should change). 
+ assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory); + } else { + assertEquals(previousPeakMemory, newPeakMemory); + } + previousPeakMemory = newPeakMemory; + } + + // Spilling should not change peak memory + writer.forceSorterToSpill(); + newPeakMemory = writer.getPeakMemoryUsedBytes(); + assertEquals(previousPeakMemory, newPeakMemory); + for (int i = 0; i < numRecordsPerPage; i++) { + writer.insertRecordIntoSorter(new Tuple2(1, 1)); + } + newPeakMemory = writer.getPeakMemoryUsedBytes(); + assertEquals(previousPeakMemory, newPeakMemory); + + // Closing the writer should not change peak memory + writer.closeAndWriteOutput(); + newPeakMemory = writer.getPeakMemoryUsedBytes(); + assertEquals(previousPeakMemory, newPeakMemory); + } finally { + writer.stop(false); + } + } } diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index dbb7c662d787..702ba5469b8b 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -17,75 +17,135 @@ package org.apache.spark.unsafe.map; -import java.lang.Exception; +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; import java.nio.ByteBuffer; import java.util.*; -import org.junit.*; +import scala.Tuple2; +import scala.Tuple2$; +import scala.runtime.AbstractFunction1; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; -import static org.hamcrest.Matchers.greaterThan; -import static org.mockito.AdditionalMatchers.geq; -import static org.mockito.Mockito.*; -import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.SparkConf; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.memory.TestMemoryManager; +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.storage.*; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; -import org.apache.spark.unsafe.memory.*; -import org.apache.spark.unsafe.PlatformDependent; -import static org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET; -import static org.apache.spark.unsafe.PlatformDependent.LONG_ARRAY_OFFSET; +import org.apache.spark.unsafe.memory.MemoryLocation; +import org.apache.spark.util.Utils; + +import static org.hamcrest.Matchers.greaterThan; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.mockito.AdditionalAnswers.returnsSecondArg; +import static org.mockito.Answers.RETURNS_SMART_NULLS; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyInt; +import static org.mockito.Mockito.when; public abstract class AbstractBytesToBytesMapSuite { private final Random rand = new Random(42); - private ShuffleMemoryManager shuffleMemoryManager; + private TestMemoryManager memoryManager; private TaskMemoryManager taskMemoryManager; - private TaskMemoryManager sizeLimitedTaskMemoryManager; private final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes + final LinkedList spillFilesCreated = new 
LinkedList(); + File tempDir; + + @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager; + @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager; + + private static final class CompressStream extends AbstractFunction1 { + @Override + public OutputStream apply(OutputStream stream) { + return stream; + } + } + @Before public void setup() { - shuffleMemoryManager = new ShuffleMemoryManager(Long.MAX_VALUE); - taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(getMemoryAllocator())); - // Mocked memory manager for tests that check the maximum array size, since actually allocating - // such large arrays will cause us to run out of memory in our tests. - sizeLimitedTaskMemoryManager = mock(TaskMemoryManager.class); - when(sizeLimitedTaskMemoryManager.allocate(geq(1L << 20))).thenAnswer( - new Answer() { - @Override - public MemoryBlock answer(InvocationOnMock invocation) throws Throwable { - if (((Long) invocation.getArguments()[0] / 8) > Integer.MAX_VALUE) { - throw new OutOfMemoryError("Requested array size exceeds VM limit"); - } - return new MemoryBlock(null, 0, (Long) invocation.getArguments()[0]); - } + memoryManager = + new TestMemoryManager( + new SparkConf() + .set("spark.memory.offHeap.enabled", "" + useOffHeapMemoryAllocator()) + .set("spark.memory.offHeap.size", "256mb")); + taskMemoryManager = new TaskMemoryManager(memoryManager, 0); + + tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "unsafe-test"); + spillFilesCreated.clear(); + MockitoAnnotations.initMocks(this); + when(blockManager.diskBlockManager()).thenReturn(diskBlockManager); + when(diskBlockManager.createTempLocalBlock()).thenAnswer(new Answer>() { + @Override + public Tuple2 answer(InvocationOnMock invocationOnMock) throws Throwable { + TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID()); + File file = File.createTempFile("spillFile", ".spill", tempDir); + spillFilesCreated.add(file); + return Tuple2$.MODULE$.apply(blockId, file); } - ); + }); + when(blockManager.getDiskWriter( + any(BlockId.class), + any(File.class), + any(SerializerInstance.class), + anyInt(), + any(ShuffleWriteMetrics.class))).thenAnswer(new Answer() { + @Override + public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable { + Object[] args = invocationOnMock.getArguments(); + + return new DiskBlockObjectWriter( + (File) args[1], + (SerializerInstance) args[2], + (Integer) args[3], + new CompressStream(), + false, + (ShuffleWriteMetrics) args[4], + (BlockId) args[0] + ); + } + }); + when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))) + .then(returnsSecondArg()); } @After public void tearDown() { + Utils.deleteRecursively(tempDir); + tempDir = null; + Assert.assertEquals(0L, taskMemoryManager.cleanUpAllAllocatedMemory()); - if (shuffleMemoryManager != null) { - long leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask(); - shuffleMemoryManager = null; - Assert.assertEquals(0L, leakedShuffleMemory); + if (taskMemoryManager != null) { + long leakedMemory = taskMemoryManager.getMemoryConsumptionForThisTask(); + taskMemoryManager = null; + Assert.assertEquals(0L, leakedMemory); } } - protected abstract MemoryAllocator getMemoryAllocator(); + protected abstract boolean useOffHeapMemoryAllocator(); private static byte[] getByteArray(MemoryLocation loc, int size) { final byte[] arr = new byte[size]; - PlatformDependent.copyMemory( - loc.getBaseObject(), - loc.getBaseOffset(), - arr, - BYTE_ARRAY_OFFSET, - 
size - ); + Platform.copyMemory( + loc.getBaseObject(), loc.getBaseOffset(), arr, Platform.BYTE_ARRAY_OFFSET, size); return arr; } @@ -107,7 +167,7 @@ private static boolean arrayEquals( long actualLengthBytes) { return (actualLengthBytes == expected.length) && ByteArrayMethods.arrayEquals( expected, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, actualAddr.getBaseObject(), actualAddr.getBaseOffset(), expected.length @@ -116,14 +176,13 @@ private static boolean arrayEquals( @Test public void emptyMap() { - BytesToBytesMap map = new BytesToBytesMap( - taskMemoryManager, shuffleMemoryManager, 64, PAGE_SIZE_BYTES); + BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 64, PAGE_SIZE_BYTES); try { Assert.assertEquals(0, map.numElements()); final int keyLengthInWords = 10; final int keyLengthInBytes = keyLengthInWords * 8; final byte[] key = getRandomByteArray(keyLengthInWords); - Assert.assertFalse(map.lookup(key, BYTE_ARRAY_OFFSET, keyLengthInBytes).isDefined()); + Assert.assertFalse(map.lookup(key, Platform.BYTE_ARRAY_OFFSET, keyLengthInBytes).isDefined()); Assert.assertFalse(map.iterator().hasNext()); } finally { map.free(); @@ -132,22 +191,21 @@ public void emptyMap() { @Test public void setAndRetrieveAKey() { - BytesToBytesMap map = new BytesToBytesMap( - taskMemoryManager, shuffleMemoryManager, 64, PAGE_SIZE_BYTES); + BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 64, PAGE_SIZE_BYTES); final int recordLengthWords = 10; final int recordLengthBytes = recordLengthWords * 8; final byte[] keyData = getRandomByteArray(recordLengthWords); final byte[] valueData = getRandomByteArray(recordLengthWords); try { final BytesToBytesMap.Location loc = - map.lookup(keyData, BYTE_ARRAY_OFFSET, recordLengthBytes); + map.lookup(keyData, Platform.BYTE_ARRAY_OFFSET, recordLengthBytes); Assert.assertFalse(loc.isDefined()); Assert.assertTrue(loc.putNewKey( keyData, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, recordLengthBytes, valueData, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, recordLengthBytes )); // After storing the key and value, the other location methods should return results that @@ -158,7 +216,8 @@ public void setAndRetrieveAKey() { Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes)); // After calling lookup() the location should still point to the correct data. 
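For readers unfamiliar with the BytesToBytesMap API being re-plumbed here: the constructor now takes only the TaskMemoryManager (the ShuffleMemoryManager argument is gone), and a lookup that misses returns a Location that putNewKey() fills in place. A condensed sketch of the flow that setAndRetrieveAKey() walks through, written against the suite's taskMemoryManager and PAGE_SIZE_BYTES fields:

BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 64, PAGE_SIZE_BYTES);
try {
  byte[] key = "k".getBytes(java.nio.charset.StandardCharsets.UTF_8);
  byte[] value = "v".getBytes(java.nio.charset.StandardCharsets.UTF_8);
  BytesToBytesMap.Location loc = map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length);
  if (!loc.isDefined()) {                       // a miss: the slot is ours to fill
    loc.putNewKey(
        key, Platform.BYTE_ARRAY_OFFSET, key.length,
        value, Platform.BYTE_ARRAY_OFFSET, value.length);
  }
  long valueLength = loc.getValueLength();      // the same Location now exposes the stored bytes
} finally {
  map.free();                                   // release pages; leak checks run in tearDown()
}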
- Assert.assertTrue(map.lookup(keyData, BYTE_ARRAY_OFFSET, recordLengthBytes).isDefined()); + Assert.assertTrue( + map.lookup(keyData, Platform.BYTE_ARRAY_OFFSET, recordLengthBytes).isDefined()); Assert.assertEquals(recordLengthBytes, loc.getKeyLength()); Assert.assertEquals(recordLengthBytes, loc.getValueLength()); Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes)); @@ -167,10 +226,10 @@ public void setAndRetrieveAKey() { try { Assert.assertTrue(loc.putNewKey( keyData, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, recordLengthBytes, valueData, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, recordLengthBytes )); Assert.fail("Should not be able to set a new value for a key"); @@ -182,56 +241,71 @@ public void setAndRetrieveAKey() { } } - @Test - public void iteratorTest() throws Exception { + private void iteratorTestBase(boolean destructive) throws Exception { final int size = 4096; - BytesToBytesMap map = new BytesToBytesMap( - taskMemoryManager, shuffleMemoryManager, size / 2, PAGE_SIZE_BYTES); + BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, size / 2, PAGE_SIZE_BYTES); try { for (long i = 0; i < size; i++) { final long[] value = new long[] { i }; final BytesToBytesMap.Location loc = - map.lookup(value, PlatformDependent.LONG_ARRAY_OFFSET, 8); + map.lookup(value, Platform.LONG_ARRAY_OFFSET, 8); Assert.assertFalse(loc.isDefined()); // Ensure that we store some zero-length keys if (i % 5 == 0) { Assert.assertTrue(loc.putNewKey( null, - PlatformDependent.LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, 0, value, - PlatformDependent.LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, 8 )); } else { Assert.assertTrue(loc.putNewKey( value, - PlatformDependent.LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, 8, value, - PlatformDependent.LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, 8 )); } } final java.util.BitSet valuesSeen = new java.util.BitSet(size); - final Iterator iter = map.iterator(); + final Iterator iter; + if (destructive) { + iter = map.destructiveIterator(); + } else { + iter = map.iterator(); + } + int numPages = map.getNumDataPages(); + int countFreedPages = 0; while (iter.hasNext()) { final BytesToBytesMap.Location loc = iter.next(); Assert.assertTrue(loc.isDefined()); final MemoryLocation keyAddress = loc.getKeyAddress(); final MemoryLocation valueAddress = loc.getValueAddress(); - final long value = PlatformDependent.UNSAFE.getLong( + final long value = Platform.getLong( valueAddress.getBaseObject(), valueAddress.getBaseOffset()); final long keyLength = loc.getKeyLength(); if (keyLength == 0) { Assert.assertTrue("value " + value + " was not divisible by 5", value % 5 == 0); } else { - final long key = PlatformDependent.UNSAFE.getLong( - keyAddress.getBaseObject(), keyAddress.getBaseOffset()); + final long key = Platform.getLong(keyAddress.getBaseObject(), keyAddress.getBaseOffset()); Assert.assertEquals(value, key); } valuesSeen.set((int) value); + if (destructive) { + // The iterator moves onto next page and frees previous page + if (map.getNumDataPages() < numPages) { + numPages = map.getNumDataPages(); + countFreedPages++; + } + } + } + if (destructive) { + // Latest page is not freed by iterator but by map itself + Assert.assertEquals(countFreedPages, numPages - 1); } Assert.assertEquals(size, valuesSeen.cardinality()); } finally { @@ -239,13 +313,23 @@ public void iteratorTest() throws Exception { } } + @Test + public void iteratorTest() throws Exception { + iteratorTestBase(false); + } + + @Test + public void 
destructiveIteratorTest() throws Exception { + iteratorTestBase(true); + } + @Test public void iteratingOverDataPagesWithWastedSpace() throws Exception { final int NUM_ENTRIES = 1000 * 1000; final int KEY_LENGTH = 24; final int VALUE_LENGTH = 40; - final BytesToBytesMap map = new BytesToBytesMap( - taskMemoryManager, shuffleMemoryManager, NUM_ENTRIES, PAGE_SIZE_BYTES); + final BytesToBytesMap map = + new BytesToBytesMap(taskMemoryManager, NUM_ENTRIES, PAGE_SIZE_BYTES); // Each record will take 8 + 24 + 40 = 72 bytes of space in the data page. Our 64-megabyte // pages won't be evenly-divisible by records of this size, which will cause us to waste some // space at the end of the page. This is necessary in order for us to take the end-of-record @@ -256,16 +340,16 @@ public void iteratingOverDataPagesWithWastedSpace() throws Exception { final long[] value = new long[] { i, i, i, i, i }; // 5 * 8 = 40 bytes final BytesToBytesMap.Location loc = map.lookup( key, - LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, KEY_LENGTH ); Assert.assertFalse(loc.isDefined()); Assert.assertTrue(loc.putNewKey( key, - LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, KEY_LENGTH, value, - LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, VALUE_LENGTH )); } @@ -273,25 +357,25 @@ public void iteratingOverDataPagesWithWastedSpace() throws Exception { final java.util.BitSet valuesSeen = new java.util.BitSet(NUM_ENTRIES); final Iterator iter = map.iterator(); - final long key[] = new long[KEY_LENGTH / 8]; - final long value[] = new long[VALUE_LENGTH / 8]; + final long[] key = new long[KEY_LENGTH / 8]; + final long[] value = new long[VALUE_LENGTH / 8]; while (iter.hasNext()) { final BytesToBytesMap.Location loc = iter.next(); Assert.assertTrue(loc.isDefined()); Assert.assertEquals(KEY_LENGTH, loc.getKeyLength()); Assert.assertEquals(VALUE_LENGTH, loc.getValueLength()); - PlatformDependent.copyMemory( + Platform.copyMemory( loc.getKeyAddress().getBaseObject(), loc.getKeyAddress().getBaseOffset(), key, - LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, KEY_LENGTH ); - PlatformDependent.copyMemory( + Platform.copyMemory( loc.getValueAddress().getBaseObject(), loc.getValueAddress().getBaseOffset(), value, - LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, VALUE_LENGTH ); for (long j : key) { @@ -314,9 +398,7 @@ public void randomizedStressTest() { // Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays // into ByteBuffers in order to use them as keys here. 
final Map expected = new HashMap(); - final BytesToBytesMap map = new BytesToBytesMap( - taskMemoryManager, shuffleMemoryManager, size, PAGE_SIZE_BYTES); - + final BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, size, PAGE_SIZE_BYTES); try { // Fill the map to 90% full so that we can trigger probing for (int i = 0; i < size * 0.9; i++) { @@ -326,16 +408,16 @@ public void randomizedStressTest() { expected.put(ByteBuffer.wrap(key), value); final BytesToBytesMap.Location loc = map.lookup( key, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, key.length ); Assert.assertFalse(loc.isDefined()); Assert.assertTrue(loc.putNewKey( key, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, key.length, value, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, value.length )); // After calling putNewKey, the following should be true, even before calling @@ -349,9 +431,10 @@ public void randomizedStressTest() { } for (Map.Entry entry : expected.entrySet()) { - final byte[] key = entry.getKey().array(); + final byte[] key = JavaUtils.bufferToArray(entry.getKey()); final byte[] value = entry.getValue(); - final BytesToBytesMap.Location loc = map.lookup(key, BYTE_ARRAY_OFFSET, key.length); + final BytesToBytesMap.Location loc = + map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length); Assert.assertTrue(loc.isDefined()); Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength())); Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength())); @@ -364,8 +447,7 @@ public void randomizedStressTest() { @Test public void randomizedTestWithRecordsLargerThanPageSize() { final long pageSizeBytes = 128; - final BytesToBytesMap map = new BytesToBytesMap( - taskMemoryManager, shuffleMemoryManager, 64, pageSizeBytes); + final BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 64, pageSizeBytes); // Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays // into ByteBuffers in order to use them as keys here. 
final Map expected = new HashMap(); @@ -377,16 +459,16 @@ public void randomizedTestWithRecordsLargerThanPageSize() { expected.put(ByteBuffer.wrap(key), value); final BytesToBytesMap.Location loc = map.lookup( key, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, key.length ); Assert.assertFalse(loc.isDefined()); Assert.assertTrue(loc.putNewKey( key, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, key.length, value, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, value.length )); // After calling putNewKey, the following should be true, even before calling @@ -399,9 +481,10 @@ public void randomizedTestWithRecordsLargerThanPageSize() { } } for (Map.Entry entry : expected.entrySet()) { - final byte[] key = entry.getKey().array(); + final byte[] key = JavaUtils.bufferToArray(entry.getKey()); final byte[] value = entry.getValue(); - final BytesToBytesMap.Location loc = map.lookup(key, BYTE_ARRAY_OFFSET, key.length); + final BytesToBytesMap.Location loc = + map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length); Assert.assertTrue(loc.isDefined()); Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength())); Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength())); @@ -413,18 +496,15 @@ public void randomizedTestWithRecordsLargerThanPageSize() { @Test public void failureToAllocateFirstPage() { - shuffleMemoryManager = new ShuffleMemoryManager(1024); - BytesToBytesMap map = - new BytesToBytesMap(taskMemoryManager, shuffleMemoryManager, 1, PAGE_SIZE_BYTES); + memoryManager.limit(1024); // longArray + BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1, PAGE_SIZE_BYTES); try { final long[] emptyArray = new long[0]; final BytesToBytesMap.Location loc = - map.lookup(emptyArray, PlatformDependent.LONG_ARRAY_OFFSET, 0); + map.lookup(emptyArray, Platform.LONG_ARRAY_OFFSET, 0); Assert.assertFalse(loc.isDefined()); Assert.assertFalse(loc.putNewKey( - emptyArray, LONG_ARRAY_OFFSET, 0, - emptyArray, LONG_ARRAY_OFFSET, 0 - )); + emptyArray, Platform.LONG_ARRAY_OFFSET, 0, emptyArray, Platform.LONG_ARRAY_OFFSET, 0)); } finally { map.free(); } @@ -433,15 +513,18 @@ public void failureToAllocateFirstPage() { @Test public void failureToGrow() { - shuffleMemoryManager = new ShuffleMemoryManager(1024 * 10); - BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, shuffleMemoryManager, 1, 1024); + BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1, 1024); try { boolean success = true; int i; - for (i = 0; i < 1024; i++) { + for (i = 0; i < 127; i++) { + if (i > 0) { + memoryManager.limit(0); + } final long[] arr = new long[]{i}; - final BytesToBytesMap.Location loc = map.lookup(arr, PlatformDependent.LONG_ARRAY_OFFSET, 8); - success = loc.putNewKey(arr, LONG_ARRAY_OFFSET, 8, arr, LONG_ARRAY_OFFSET, 8); + final BytesToBytesMap.Location loc = map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8); + success = + loc.putNewKey(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8); if (!success) { break; } @@ -453,10 +536,48 @@ public void failureToGrow() { } } + @Test + public void spillInIterator() throws IOException { + BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, blockManager, 1, 0.75, 1024, false); + try { + int i; + for (i = 0; i < 1024; i++) { + final long[] arr = new long[]{i}; + final BytesToBytesMap.Location loc = map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8); + loc.putNewKey(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8); + } + BytesToBytesMap.MapIterator iter = map.iterator(); 
+ for (i = 0; i < 100; i++) { + iter.next(); + } + // Non-destructive iterator is not spillable + Assert.assertEquals(0, iter.spill(1024L * 10)); + for (i = 100; i < 1024; i++) { + iter.next(); + } + + BytesToBytesMap.MapIterator iter2 = map.destructiveIterator(); + for (i = 0; i < 100; i++) { + iter2.next(); + } + Assert.assertTrue(iter2.spill(1024) >= 1024); + for (i = 100; i < 1024; i++) { + iter2.next(); + } + assertFalse(iter2.hasNext()); + } finally { + map.free(); + for (File spillFile : spillFilesCreated) { + assertFalse("Spill file " + spillFile.getPath() + " was not cleaned up", + spillFile.exists()); + } + } + } + @Test public void initialCapacityBoundsChecking() { try { - new BytesToBytesMap(sizeLimitedTaskMemoryManager, shuffleMemoryManager, 0, PAGE_SIZE_BYTES); + new BytesToBytesMap(taskMemoryManager, 0, PAGE_SIZE_BYTES); Assert.fail("Expected IllegalArgumentException to be thrown"); } catch (IllegalArgumentException e) { // expected exception @@ -464,35 +585,56 @@ public void initialCapacityBoundsChecking() { try { new BytesToBytesMap( - sizeLimitedTaskMemoryManager, - shuffleMemoryManager, + taskMemoryManager, BytesToBytesMap.MAX_CAPACITY + 1, PAGE_SIZE_BYTES); Assert.fail("Expected IllegalArgumentException to be thrown"); } catch (IllegalArgumentException e) { // expected exception } - - // Ignored because this can OOM now that we allocate the long array w/o a TaskMemoryManager - // Can allocate _at_ the max capacity - // BytesToBytesMap map = new BytesToBytesMap( - // sizeLimitedTaskMemoryManager, - // shuffleMemoryManager, - // BytesToBytesMap.MAX_CAPACITY, - // PAGE_SIZE_BYTES); - // map.free(); } - // Ignored because this can OOM now that we allocate the long array w/o a TaskMemoryManager - @Ignore - public void resizingLargeMap() { - // As long as a map's capacity is below the max, we should be able to resize up to the max - BytesToBytesMap map = new BytesToBytesMap( - sizeLimitedTaskMemoryManager, - shuffleMemoryManager, - BytesToBytesMap.MAX_CAPACITY - 64, - PAGE_SIZE_BYTES); - map.growAndRehash(); - map.free(); + @Test + public void testPeakMemoryUsed() { + final long recordLengthBytes = 24; + final long pageSizeBytes = 256 + 8; // 8 bytes for end-of-page marker + final long numRecordsPerPage = (pageSizeBytes - 8) / recordLengthBytes; + final BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1024, pageSizeBytes); + + // Since BytesToBytesMap is append-only, we expect the total memory consumption to be + // monotonically increasing. More specifically, every time we allocate a new page it + // should increase by exactly the size of the page. In this regard, the memory usage + // at any given time is also the peak memory used. 
+ long previousPeakMemory = map.getPeakMemoryUsedBytes(); + long newPeakMemory; + try { + for (long i = 0; i < numRecordsPerPage * 10; i++) { + final long[] value = new long[]{i}; + map.lookup(value, Platform.LONG_ARRAY_OFFSET, 8).putNewKey( + value, + Platform.LONG_ARRAY_OFFSET, + 8, + value, + Platform.LONG_ARRAY_OFFSET, + 8); + newPeakMemory = map.getPeakMemoryUsedBytes(); + if (i % numRecordsPerPage == 0) { + // We allocated a new page for this record, so peak memory should change + assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory); + } else { + assertEquals(previousPeakMemory, newPeakMemory); + } + previousPeakMemory = newPeakMemory; + } + + // Freeing the map should not change the peak memory + map.free(); + newPeakMemory = map.getPeakMemoryUsedBytes(); + assertEquals(previousPeakMemory, newPeakMemory); + + } finally { + map.free(); + } } + } diff --git a/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java index 5a10de49f54f..f0bad4d760c1 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java @@ -17,13 +17,10 @@ package org.apache.spark.unsafe.map; -import org.apache.spark.unsafe.memory.MemoryAllocator; - public class BytesToBytesMapOffHeapSuite extends AbstractBytesToBytesMapSuite { @Override - protected MemoryAllocator getMemoryAllocator() { - return MemoryAllocator.UNSAFE; + protected boolean useOffHeapMemoryAllocator() { + return true; } - } diff --git a/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java index 12cc9b25d93b..d76bb4fd05c5 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java @@ -17,13 +17,10 @@ package org.apache.spark.unsafe.map; -import org.apache.spark.unsafe.memory.MemoryAllocator; - public class BytesToBytesMapOnHeapSuite extends AbstractBytesToBytesMapSuite { @Override - protected MemoryAllocator getMemoryAllocator() { - return MemoryAllocator.HEAP; + protected boolean useOffHeapMemoryAllocator() { + return false; } - } diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index 52fa8bcd57e7..e0ee281e98b7 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -18,8 +18,10 @@ package org.apache.spark.util.collection.unsafe.sort; import java.io.File; +import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; +import java.util.Arrays; import java.util.LinkedList; import java.util.UUID; @@ -34,29 +36,30 @@ import org.mockito.MockitoAnnotations; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; -import static org.junit.Assert.*; -import static org.mockito.AdditionalAnswers.returnsSecondArg; -import static org.mockito.Answers.RETURNS_SMART_NULLS; -import static org.mockito.Mockito.*; import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; import org.apache.spark.executor.ShuffleWriteMetrics; import 
org.apache.spark.executor.TaskMetrics; +import org.apache.spark.memory.TestMemoryManager; +import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.serializer.SerializerInstance; -import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.storage.*; -import org.apache.spark.unsafe.PlatformDependent; -import org.apache.spark.unsafe.memory.ExecutorMemoryManager; -import org.apache.spark.unsafe.memory.MemoryAllocator; -import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.unsafe.Platform; import org.apache.spark.util.Utils; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.junit.Assert.*; +import static org.mockito.AdditionalAnswers.returnsSecondArg; +import static org.mockito.Answers.RETURNS_SMART_NULLS; +import static org.mockito.Mockito.*; + public class UnsafeExternalSorterSuite { final LinkedList spillFilesCreated = new LinkedList(); - final TaskMemoryManager taskMemoryManager = - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + final TestMemoryManager memoryManager = + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")); + final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0); // Use integer comparison for comparing prefixes (which are partition ids, in this case) final PrefixComparator prefixComparator = new PrefixComparator() { @Override @@ -77,14 +80,14 @@ public int compare( } }; - ShuffleMemoryManager shuffleMemoryManager; + SparkConf sparkConf; + File tempDir; @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager; @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager; @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext; - File tempDir; - private final long pageSizeBytes = new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "64m"); + private final long pageSizeBytes = new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "4m"); private static final class CompressStream extends AbstractFunction1 { @Override @@ -96,8 +99,8 @@ public OutputStream apply(OutputStream stream) { @Before public void setUp() { MockitoAnnotations.initMocks(this); - tempDir = new File(Utils.createTempDir$default$1()); - shuffleMemoryManager = new ShuffleMemoryManager(Long.MAX_VALUE); + sparkConf = new SparkConf(); + tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "unsafe-test"); spillFilesCreated.clear(); taskContext = mock(TaskContext.class); when(taskContext.taskMetrics()).thenReturn(new TaskMetrics()); @@ -122,13 +125,13 @@ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Th Object[] args = invocationOnMock.getArguments(); return new DiskBlockObjectWriter( - (BlockId) args[0], (File) args[1], (SerializerInstance) args[2], (Integer) args[3], new CompressStream(), false, - (ShuffleWriteMetrics) args[4] + (ShuffleWriteMetrics) args[4], + (BlockId) args[0] ); } }); @@ -138,13 +141,12 @@ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Th @After public void tearDown() { - long leakedUnsafeMemory = taskMemoryManager.cleanUpAllAllocatedMemory(); - if (shuffleMemoryManager != null) { - long leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask(); - shuffleMemoryManager = null; - assertEquals(0L, leakedShuffleMemory); + try { + assertEquals(0L, taskMemoryManager.cleanUpAllAllocatedMemory()); + } finally { + Utils.deleteRecursively(tempDir); + tempDir = null; } - assertEquals(0, leakedUnsafeMemory); } 
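The helper refactoring below (newSorter(), insertNumber(), insertRecord()) is worth spelling out once, because every test in this suite now follows the same shape: build the sorter without a ShuffleMemoryManager, insert prefix-tagged records, drain getSortedIterator(), and finish with cleanupResources(). A sketch of that shape, written against the suite's own fields and assumed to live inside a test method that declares throws Exception:

UnsafeExternalSorter sorter = UnsafeExternalSorter.create(
    taskMemoryManager, blockManager, taskContext,
    recordComparator, prefixComparator,
    /* initialSize */ 1024, pageSizeBytes);

int[] record = new int[]{ 42 };
sorter.insertRecord(record, Platform.INT_ARRAY_OFFSET, 4, /* prefix */ 42);

UnsafeSorterIterator iter = sorter.getSortedIterator();
while (iter.hasNext()) {
  iter.loadNext();
  long prefix = iter.getKeyPrefix();       // records come back ordered by key prefix
  int length = iter.getRecordLength();     // 4 bytes for the single int written above
  int value = Platform.getInt(iter.getBaseObject(), iter.getBaseOffset());
}
sorter.cleanupResources();                 // frees memory and deletes any spill files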
private void assertSpillFilesWereCleanedUp() { @@ -155,23 +157,31 @@ private void assertSpillFilesWereCleanedUp() { } private static void insertNumber(UnsafeExternalSorter sorter, int value) throws Exception { - final int[] arr = new int[] { value }; - sorter.insertRecord(arr, PlatformDependent.INT_ARRAY_OFFSET, 4, value); + final int[] arr = new int[]{ value }; + sorter.insertRecord(arr, Platform.INT_ARRAY_OFFSET, 4, value); } - @Test - public void testSortingOnlyByPrefix() throws Exception { + private static void insertRecord( + UnsafeExternalSorter sorter, + int[] record, + long prefix) throws IOException { + sorter.insertRecord(record, Platform.INT_ARRAY_OFFSET, record.length * 4, prefix); + } - final UnsafeExternalSorter sorter = UnsafeExternalSorter.create( + private UnsafeExternalSorter newSorter() throws IOException { + return UnsafeExternalSorter.create( taskMemoryManager, - shuffleMemoryManager, blockManager, taskContext, recordComparator, prefixComparator, /* initialSize */ 1024, pageSizeBytes); + } + @Test + public void testSortingOnlyByPrefix() throws Exception { + final UnsafeExternalSorter sorter = newSorter(); insertNumber(sorter, 5); insertNumber(sorter, 1); insertNumber(sorter, 3); @@ -186,26 +196,16 @@ public void testSortingOnlyByPrefix() throws Exception { iter.loadNext(); assertEquals(i, iter.getKeyPrefix()); assertEquals(4, iter.getRecordLength()); - // TODO: read rest of value. + assertEquals(i, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset())); } - sorter.freeMemory(); + sorter.cleanupResources(); assertSpillFilesWereCleanedUp(); } @Test public void testSortingEmptyArrays() throws Exception { - - final UnsafeExternalSorter sorter = UnsafeExternalSorter.create( - taskMemoryManager, - shuffleMemoryManager, - blockManager, - taskContext, - recordComparator, - prefixComparator, - /* initialSize */ 1024, - pageSizeBytes); - + final UnsafeExternalSorter sorter = newSorter(); sorter.insertRecord(null, 0, 0, 0); sorter.insertRecord(null, 0, 0, 0); sorter.spill(); @@ -222,29 +222,198 @@ public void testSortingEmptyArrays() throws Exception { assertEquals(0, iter.getRecordLength()); } - sorter.freeMemory(); + sorter.cleanupResources(); + assertSpillFilesWereCleanedUp(); + } + + @Test + public void spillingOccursInResponseToMemoryPressure() throws Exception { + final UnsafeExternalSorter sorter = newSorter(); + // This should be enough records to completely fill up a data page: + final int numRecords = (int) (pageSizeBytes / (4 + 4)); + for (int i = 0; i < numRecords; i++) { + insertNumber(sorter, numRecords - i); + } + assertEquals(1, sorter.getNumberOfAllocatedPages()); + memoryManager.markExecutionAsOutOfMemoryOnce(); + // The insertion of this record should trigger a spill: + insertNumber(sorter, 0); + // Ensure that spill files were created + assertThat(tempDir.listFiles().length, greaterThanOrEqualTo(1)); + // Read back the sorted data: + UnsafeSorterIterator iter = sorter.getSortedIterator(); + + int i = 0; + while (iter.hasNext()) { + iter.loadNext(); + assertEquals(i, iter.getKeyPrefix()); + assertEquals(4, iter.getRecordLength()); + assertEquals(i, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset())); + i++; + } + assertEquals(numRecords + 1, i); + sorter.cleanupResources(); assertSpillFilesWereCleanedUp(); } @Test public void testFillingPage() throws Exception { + final UnsafeExternalSorter sorter = newSorter(); + byte[] record = new byte[16]; + while (sorter.getNumberOfAllocatedPages() < 2) { + sorter.insertRecord(record, 
Platform.BYTE_ARRAY_OFFSET, record.length, 0); + } + sorter.cleanupResources(); + assertSpillFilesWereCleanedUp(); + } + + @Test + public void sortingRecordsThatExceedPageSize() throws Exception { + final UnsafeExternalSorter sorter = newSorter(); + final int[] largeRecord = new int[(int) pageSizeBytes + 16]; + Arrays.fill(largeRecord, 456); + final int[] smallRecord = new int[100]; + Arrays.fill(smallRecord, 123); + + insertRecord(sorter, largeRecord, 456); + sorter.spill(); + insertRecord(sorter, smallRecord, 123); + sorter.spill(); + insertRecord(sorter, smallRecord, 123); + insertRecord(sorter, largeRecord, 456); + + UnsafeSorterIterator iter = sorter.getSortedIterator(); + // Small record + assertTrue(iter.hasNext()); + iter.loadNext(); + assertEquals(123, iter.getKeyPrefix()); + assertEquals(smallRecord.length * 4, iter.getRecordLength()); + assertEquals(123, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset())); + // Small record + assertTrue(iter.hasNext()); + iter.loadNext(); + assertEquals(123, iter.getKeyPrefix()); + assertEquals(smallRecord.length * 4, iter.getRecordLength()); + assertEquals(123, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset())); + // Large record + assertTrue(iter.hasNext()); + iter.loadNext(); + assertEquals(456, iter.getKeyPrefix()); + assertEquals(largeRecord.length * 4, iter.getRecordLength()); + assertEquals(456, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset())); + // Large record + assertTrue(iter.hasNext()); + iter.loadNext(); + assertEquals(456, iter.getKeyPrefix()); + assertEquals(largeRecord.length * 4, iter.getRecordLength()); + assertEquals(456, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset())); + + assertFalse(iter.hasNext()); + sorter.cleanupResources(); + assertSpillFilesWereCleanedUp(); + } + + @Test + public void forcedSpillingWithReadIterator() throws Exception { + final UnsafeExternalSorter sorter = newSorter(); + long[] record = new long[100]; + int recordSize = record.length * 8; + int n = (int) pageSizeBytes / recordSize * 3; + for (int i = 0; i < n; i++) { + record[0] = (long) i; + sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0); + } + assert(sorter.getNumberOfAllocatedPages() >= 2); + UnsafeExternalSorter.SpillableIterator iter = + (UnsafeExternalSorter.SpillableIterator) sorter.getSortedIterator(); + int lastv = 0; + for (int i = 0; i < n / 3; i++) { + iter.hasNext(); + iter.loadNext(); + assert(Platform.getLong(iter.getBaseObject(), iter.getBaseOffset()) == i); + lastv = i; + } + assert(iter.spill() > 0); + assert(iter.spill() == 0); + assert(Platform.getLong(iter.getBaseObject(), iter.getBaseOffset()) == lastv); + for (int i = n / 3; i < n; i++) { + iter.hasNext(); + iter.loadNext(); + assert(Platform.getLong(iter.getBaseObject(), iter.getBaseOffset()) == i); + } + sorter.cleanupResources(); + assertSpillFilesWereCleanedUp(); + } + + @Test + public void forcedSpillingWithNotReadIterator() throws Exception { + final UnsafeExternalSorter sorter = newSorter(); + long[] record = new long[100]; + int recordSize = record.length * 8; + int n = (int) pageSizeBytes / recordSize * 3; + for (int i = 0; i < n; i++) { + record[0] = (long) i; + sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0); + } + assert(sorter.getNumberOfAllocatedPages() >= 2); + UnsafeExternalSorter.SpillableIterator iter = + (UnsafeExternalSorter.SpillableIterator) sorter.getSortedIterator(); + assert(iter.spill() > 0); + assert(iter.spill() == 0); + for (int i = 0; i < n; i++) { + 
iter.hasNext(); + iter.loadNext(); + assert(Platform.getLong(iter.getBaseObject(), iter.getBaseOffset()) == i); + } + sorter.cleanupResources(); + assertSpillFilesWereCleanedUp(); + } + @Test + public void testPeakMemoryUsed() throws Exception { + final long recordLengthBytes = 8; + final long pageSizeBytes = 256; + final long numRecordsPerPage = pageSizeBytes / recordLengthBytes; final UnsafeExternalSorter sorter = UnsafeExternalSorter.create( taskMemoryManager, - shuffleMemoryManager, blockManager, taskContext, recordComparator, prefixComparator, - /* initialSize */ 1024, + 1024, pageSizeBytes); - byte[] record = new byte[16]; - while (sorter.getNumberOfAllocatedPages() < 2) { - sorter.insertRecord(record, PlatformDependent.BYTE_ARRAY_OFFSET, record.length, 0); + // Peak memory should be monotonically increasing. More specifically, every time + // we allocate a new page it should increase by exactly the size of the page. + long previousPeakMemory = sorter.getPeakMemoryUsedBytes(); + long newPeakMemory; + try { + for (int i = 0; i < numRecordsPerPage * 10; i++) { + insertNumber(sorter, i); + newPeakMemory = sorter.getPeakMemoryUsedBytes(); + if (i % numRecordsPerPage == 0) { + // We allocated a new page for this record, so peak memory should change + assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory); + } else { + assertEquals(previousPeakMemory, newPeakMemory); + } + previousPeakMemory = newPeakMemory; + } + + // Spilling should not change peak memory + sorter.spill(); + newPeakMemory = sorter.getPeakMemoryUsedBytes(); + assertEquals(previousPeakMemory, newPeakMemory); + for (int i = 0; i < numRecordsPerPage; i++) { + insertNumber(sorter, i); + } + newPeakMemory = sorter.getPeakMemoryUsedBytes(); + assertEquals(previousPeakMemory, newPeakMemory); + } finally { + sorter.cleanupResources(); + assertSpillFilesWereCleanedUp(); } - sorter.freeMemory(); - assertSpillFilesWereCleanedUp(); } } + diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index 909500930539..93efd033eb94 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -20,34 +20,36 @@ import java.util.Arrays; import org.junit.Test; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.*; -import static org.junit.Assert.*; -import static org.mockito.Mockito.mock; import org.apache.spark.HashPartitioner; -import org.apache.spark.unsafe.PlatformDependent; -import org.apache.spark.unsafe.memory.ExecutorMemoryManager; -import org.apache.spark.unsafe.memory.MemoryAllocator; +import org.apache.spark.SparkConf; +import org.apache.spark.memory.TestMemoryConsumer; +import org.apache.spark.memory.TestMemoryManager; +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.MemoryBlock; -import org.apache.spark.unsafe.memory.TaskMemoryManager; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.isIn; +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.mock; public class UnsafeInMemorySorterSuite { private static String getStringFromDataPage(Object baseObject, long baseOffset, int length) { 
final byte[] strBytes = new byte[length]; - PlatformDependent.copyMemory( - baseObject, - baseOffset, - strBytes, - PlatformDependent.BYTE_ARRAY_OFFSET, length); + Platform.copyMemory(baseObject, baseOffset, strBytes, Platform.BYTE_ARRAY_OFFSET, length); return new String(strBytes); } @Test public void testSortingEmptyInput() { - final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter( - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)), + final TaskMemoryManager memoryManager = new TaskMemoryManager( + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); + final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager); + final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(consumer, + memoryManager, mock(RecordComparator.class), mock(PrefixComparator.class), 100); @@ -68,22 +70,19 @@ public void testSortingOnlyByIntegerPrefix() throws Exception { "Lychee", "Mango" }; - final TaskMemoryManager memoryManager = - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); - final MemoryBlock dataPage = memoryManager.allocatePage(2048); + final TaskMemoryManager memoryManager = new TaskMemoryManager( + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); + final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager); + final MemoryBlock dataPage = memoryManager.allocatePage(2048, null); final Object baseObject = dataPage.getBaseObject(); // Write the records into the data page: long position = dataPage.getBaseOffset(); for (String str : dataToSort) { final byte[] strBytes = str.getBytes("utf-8"); - PlatformDependent.UNSAFE.putInt(baseObject, position, strBytes.length); + Platform.putInt(baseObject, position, strBytes.length); position += 4; - PlatformDependent.copyMemory( - strBytes, - PlatformDependent.BYTE_ARRAY_OFFSET, - baseObject, - position, - strBytes.length); + Platform.copyMemory( + strBytes, Platform.BYTE_ARRAY_OFFSET, baseObject, position, strBytes.length); position += strBytes.length; } // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so @@ -107,13 +106,13 @@ public int compare(long prefix1, long prefix2) { return (int) prefix1 - (int) prefix2; } }; - UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(memoryManager, recordComparator, + UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(consumer, memoryManager, recordComparator, prefixComparator, dataToSort.length); // Given a page of records, insert those records into the sorter one-by-one: position = dataPage.getBaseOffset(); for (int i = 0; i < dataToSort.length; i++) { // position now points to the start of a record (which holds its length). 
- final int recordLength = PlatformDependent.UNSAFE.getInt(baseObject, position); + final int recordLength = Platform.getInt(baseObject, position); final long address = memoryManager.encodePageNumberAndOffset(dataPage, position); final String str = getStringFromDataPage(baseObject, position + 4, recordLength); final int partitionId = hashPartitioner.getPartition(str); diff --git a/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json index 31ac9beea878..8f8067f86d57 100644 --- a/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json @@ -6,6 +6,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 162, + "submissionTime" : "2015-02-03T16:43:07.191GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:07.191GMT", + "completionTime" : "2015-02-03T16:43:07.226GMT", "inputBytes" : 160, "inputRecords" : 0, "outputBytes" : 0, @@ -28,6 +31,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 3476, + "submissionTime" : "2015-02-03T16:43:05.829GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:05.829GMT", + "completionTime" : "2015-02-03T16:43:06.286GMT", "inputBytes" : 28000128, "inputRecords" : 0, "outputBytes" : 0, @@ -50,6 +56,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 4338, + "submissionTime" : "2015-02-03T16:43:04.228GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:04.234GMT", + "completionTime" : "2015-02-03T16:43:04.819GMT", "inputBytes" : 0, "inputRecords" : 0, "outputBytes" : 0, @@ -64,4 +73,4 @@ "details" : "org.apache.spark.rdd.RDD.count(RDD.scala:910)\n$line9.$read$$iwC$$iwC$$iwC$$iwC.(:15)\n$line9.$read$$iwC$$iwC$$iwC.(:20)\n$line9.$read$$iwC$$iwC.(:22)\n$line9.$read$$iwC.(:24)\n$line9.$read.(:26)\n$line9.$read$.(:30)\n$line9.$read$.()\n$line9.$eval$.(:7)\n$line9.$eval$.()\n$line9.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:606)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", "schedulingPool" : "default", "accumulatorUpdates" : [ ] -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json index bff6a4f69d07..08b692eda802 100644 --- a/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json @@ -6,6 +6,9 @@ "numCompleteTasks" : 7, "numFailedTasks" : 1, "executorRunTime" : 278, + "submissionTime" : "2015-02-03T16:43:06.296GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:06.296GMT", + "completionTime" : "2015-02-03T16:43:06.347GMT", "inputBytes" : 0, "inputRecords" : 0, "outputBytes" : 0, @@ -20,4 +23,4 @@ "details" : 
"org.apache.spark.rdd.RDD.count(RDD.scala:910)\n$line11.$read$$iwC$$iwC$$iwC$$iwC.(:20)\n$line11.$read$$iwC$$iwC$$iwC.(:25)\n$line11.$read$$iwC$$iwC.(:27)\n$line11.$read$$iwC.(:29)\n$line11.$read.(:31)\n$line11.$read$.(:35)\n$line11.$read$.()\n$line11.$eval$.(:7)\n$line11.$eval$.()\n$line11.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:606)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", "schedulingPool" : "default", "accumulatorUpdates" : [ ] -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json index 111cb8163eb3..b07011d4f113 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json @@ -6,6 +6,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 3476, + "submissionTime" : "2015-02-03T16:43:05.829GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:05.829GMT", + "completionTime" : "2015-02-03T16:43:06.286GMT", "inputBytes" : 28000128, "inputRecords" : 0, "outputBytes" : 0, @@ -267,4 +270,4 @@ "diskBytesSpilled" : 0 } } -} \ No newline at end of file +} diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json index ef339f89afa4..2f71520549e1 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json @@ -6,6 +6,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 3476, + "submissionTime" : "2015-02-03T16:43:05.829GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:05.829GMT", + "completionTime" : "2015-02-03T16:43:06.286GMT", "inputBytes" : 28000128, "inputRecords" : 0, "outputBytes" : 0, @@ -267,4 +270,4 @@ "diskBytesSpilled" : 0 } } -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json index 056fac708859..5b957ed54955 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json @@ -6,6 +6,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 162, + "submissionTime" : "2015-02-03T16:43:07.191GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:07.191GMT", + "completionTime" : "2015-02-03T16:43:07.226GMT", "inputBytes" : 160, "inputRecords" : 0, "outputBytes" : 0, @@ -28,6 +31,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 3476, + "submissionTime" : "2015-02-03T16:43:05.829GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:05.829GMT", + 
"completionTime" : "2015-02-03T16:43:06.286GMT", "inputBytes" : 28000128, "inputRecords" : 0, "outputBytes" : 0, @@ -50,6 +56,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 4338, + "submissionTime" : "2015-02-03T16:43:04.228GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:04.234GMT", + "completionTime" : "2015-02-03T16:43:04.819GMT", "inputBytes" : 0, "inputRecords" : 0, "outputBytes" : 0, @@ -72,6 +81,9 @@ "numCompleteTasks" : 7, "numFailedTasks" : 1, "executorRunTime" : 278, + "submissionTime" : "2015-02-03T16:43:06.296GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:06.296GMT", + "completionTime" : "2015-02-03T16:43:06.347GMT", "inputBytes" : 0, "inputRecords" : 0, "outputBytes" : 0, @@ -86,4 +98,4 @@ "details" : "org.apache.spark.rdd.RDD.count(RDD.scala:910)\n$line11.$read$$iwC$$iwC$$iwC$$iwC.(:20)\n$line11.$read$$iwC$$iwC$$iwC.(:25)\n$line11.$read$$iwC$$iwC.(:27)\n$line11.$read$$iwC.(:29)\n$line11.$read.(:31)\n$line11.$read$.(:35)\n$line11.$read$.()\n$line11.$eval$.(:7)\n$line11.$eval$.()\n$line11.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:606)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", "schedulingPool" : "default", "accumulatorUpdates" : [ ] -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json index 79ccacd30969..afa425f8c27b 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json @@ -6,6 +6,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 120, + "submissionTime" : "2015-03-16T19:25:36.103GMT", + "firstTaskLaunchedTime" : "2015-03-16T19:25:36.515GMT", + "completionTime" : "2015-03-16T19:25:36.579GMT", "inputBytes" : 0, "inputRecords" : 0, "outputBytes" : 0, @@ -24,4 +27,4 @@ "name" : "my counter", "value" : "5050" } ] -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json index 32d5731676ad..12665a152c9e 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json @@ -6,6 +6,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 120, + "submissionTime" : "2015-03-16T19:25:36.103GMT", + "firstTaskLaunchedTime" : "2015-03-16T19:25:36.515GMT", + "completionTime" : "2015-03-16T19:25:36.579GMT", "inputBytes" : 0, "inputRecords" : 0, "outputBytes" : 0, @@ -239,4 +242,4 @@ "diskBytesSpilled" : 0 } } -} \ No newline at end of file +} diff --git a/core/src/test/resources/log4j.properties 
b/core/src/test/resources/log4j.properties index eb3b1999eb99..a54d27de91ed 100644 --- a/core/src/test/resources/log4j.properties +++ b/core/src/test/resources/log4j.properties @@ -16,13 +16,22 @@ # # Set everything to be logged to the file target/unit-tests.log -log4j.rootCategory=INFO, file +test.appender=file +log4j.rootCategory=INFO, ${test.appender} log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=true log4j.appender.file.file=target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n +# Tests that launch java subprocesses can set the "test.appender" system property to +# "console" to avoid having the child process's logs overwrite the unit test's +# log file. +log4j.appender.console=org.apache.log4j.ConsoleAppender +log4j.appender.console.target=System.err +log4j.appender.console.layout=org.apache.log4j.PatternLayout +log4j.appender.console.layout.ConversionPattern=%t: %m%n + # Ignore messages below warning level from Jetty, because it's a bit verbose log4j.logger.org.spark-project.jetty=WARN org.spark-project.jetty.LEVEL=WARN diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index e942d6579b2f..5b84acf40be4 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -18,13 +18,17 @@ package org.apache.spark import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import scala.ref.WeakReference import org.scalatest.Matchers +import org.scalatest.exceptions.TestFailedException +import org.apache.spark.scheduler._ -class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContext { +class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContext { + import InternalAccumulator._ implicit def setAccum[A]: AccumulableParam[mutable.Set[A], A] = new AccumulableParam[mutable.Set[A], A] { @@ -155,4 +159,223 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex assert(!Accumulators.originals.get(accId).isDefined) } + test("internal accumulators in TaskContext") { + sc = new SparkContext("local", "test") + val accums = InternalAccumulator.create(sc) + val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null, accums) + val internalMetricsToAccums = taskContext.internalMetricsToAccumulators + val collectedInternalAccums = taskContext.collectInternalAccumulators() + val collectedAccums = taskContext.collectAccumulators() + assert(internalMetricsToAccums.size > 0) + assert(internalMetricsToAccums.values.forall(_.isInternal)) + assert(internalMetricsToAccums.contains(TEST_ACCUMULATOR)) + val testAccum = internalMetricsToAccums(TEST_ACCUMULATOR) + assert(collectedInternalAccums.size === internalMetricsToAccums.size) + assert(collectedInternalAccums.size === collectedAccums.size) + assert(collectedInternalAccums.contains(testAccum.id)) + assert(collectedAccums.contains(testAccum.id)) + } + + test("internal accumulators in a stage") { + val listener = new SaveInfoListener + val numPartitions = 10 + sc = new SparkContext("local", "test") + sc.addSparkListener(listener) + // Have each task add 1 to the internal accumulator + val rdd = sc.parallelize(1 to 100, numPartitions).mapPartitions { iter => + TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1 + iter + } + // Register asserts in job completion 
callback to avoid flakiness + listener.registerJobCompletionCallback { _ => + val stageInfos = listener.getCompletedStageInfos + val taskInfos = listener.getCompletedTaskInfos + assert(stageInfos.size === 1) + assert(taskInfos.size === numPartitions) + // The accumulator values should be merged in the stage + val stageAccum = findAccumulableInfo(stageInfos.head.accumulables.values, TEST_ACCUMULATOR) + assert(stageAccum.value.toLong === numPartitions) + // The accumulator should be updated locally on each task + val taskAccumValues = taskInfos.map { taskInfo => + val taskAccum = findAccumulableInfo(taskInfo.accumulables, TEST_ACCUMULATOR) + assert(taskAccum.update.isDefined) + assert(taskAccum.update.get.toLong === 1) + taskAccum.value.toLong + } + // Each task should keep track of the partial value on the way, i.e. 1, 2, ... numPartitions + assert(taskAccumValues.sorted === (1L to numPartitions).toSeq) + } + rdd.count() + } + + test("internal accumulators in multiple stages") { + val listener = new SaveInfoListener + val numPartitions = 10 + sc = new SparkContext("local", "test") + sc.addSparkListener(listener) + // Each stage creates its own set of internal accumulators so the + // values for the same metric should not be mixed up across stages + val rdd = sc.parallelize(1 to 100, numPartitions) + .map { i => (i, i) } + .mapPartitions { iter => + TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1 + iter + } + .reduceByKey { case (x, y) => x + y } + .mapPartitions { iter => + TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 10 + iter + } + .repartition(numPartitions * 2) + .mapPartitions { iter => + TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 100 + iter + } + // Register asserts in job completion callback to avoid flakiness + listener.registerJobCompletionCallback { _ => + // We ran 3 stages, and the accumulator values should be distinct + val stageInfos = listener.getCompletedStageInfos + assert(stageInfos.size === 3) + val (firstStageAccum, secondStageAccum, thirdStageAccum) = + (findAccumulableInfo(stageInfos(0).accumulables.values, TEST_ACCUMULATOR), + findAccumulableInfo(stageInfos(1).accumulables.values, TEST_ACCUMULATOR), + findAccumulableInfo(stageInfos(2).accumulables.values, TEST_ACCUMULATOR)) + assert(firstStageAccum.value.toLong === numPartitions) + assert(secondStageAccum.value.toLong === numPartitions * 10) + assert(thirdStageAccum.value.toLong === numPartitions * 2 * 100) + } + rdd.count() + } + + test("internal accumulators in fully resubmitted stages") { + testInternalAccumulatorsWithFailedTasks((i: Int) => true) // fail all tasks + } + + test("internal accumulators in partially resubmitted stages") { + testInternalAccumulatorsWithFailedTasks((i: Int) => i % 2 == 0) // fail a subset + } + + /** + * Return the accumulable info that matches the specified name. + */ + private def findAccumulableInfo( + accums: Iterable[AccumulableInfo], + name: String): AccumulableInfo = { + accums.find { a => a.name == name }.getOrElse { + throw new TestFailedException(s"internal accumulator '$name' not found", 0) + } + } + + /** + * Test whether internal accumulators are merged properly if some tasks fail. 
+ */ + private def testInternalAccumulatorsWithFailedTasks(failCondition: (Int => Boolean)): Unit = { + val listener = new SaveInfoListener + val numPartitions = 10 + val numFailedPartitions = (0 until numPartitions).count(failCondition) + // This says use 1 core and retry tasks up to 2 times + sc = new SparkContext("local[1, 2]", "test") + sc.addSparkListener(listener) + val rdd = sc.parallelize(1 to 100, numPartitions).mapPartitionsWithIndex { case (i, iter) => + val taskContext = TaskContext.get() + taskContext.internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1 + // Fail the first attempts of a subset of the tasks + if (failCondition(i) && taskContext.attemptNumber() == 0) { + throw new Exception("Failing a task intentionally.") + } + iter + } + // Register asserts in job completion callback to avoid flakiness + listener.registerJobCompletionCallback { _ => + val stageInfos = listener.getCompletedStageInfos + val taskInfos = listener.getCompletedTaskInfos + assert(stageInfos.size === 1) + assert(taskInfos.size === numPartitions + numFailedPartitions) + val stageAccum = findAccumulableInfo(stageInfos.head.accumulables.values, TEST_ACCUMULATOR) + // We should not double count values in the merged accumulator + assert(stageAccum.value.toLong === numPartitions) + val taskAccumValues = taskInfos.flatMap { taskInfo => + if (!taskInfo.failed) { + // If a task succeeded, its update value should always be 1 + val taskAccum = findAccumulableInfo(taskInfo.accumulables, TEST_ACCUMULATOR) + assert(taskAccum.update.isDefined) + assert(taskAccum.update.get.toLong === 1) + Some(taskAccum.value.toLong) + } else { + // If a task failed, we should not get its accumulator values + assert(taskInfo.accumulables.isEmpty) + None + } + } + assert(taskAccumValues.sorted === (1L to numPartitions).toSeq) + } + rdd.count() + } + +} + +private[spark] object AccumulatorSuite { + + /** + * Run one or more Spark jobs and verify that the peak execution memory accumulator + * is updated afterwards. + */ + def verifyPeakExecutionMemorySet( + sc: SparkContext, + testName: String)(testBody: => Unit): Unit = { + val listener = new SaveInfoListener + sc.addSparkListener(listener) + // Register asserts in job completion callback to avoid flakiness + listener.registerJobCompletionCallback { jobId => + if (jobId == 0) { + // The first job is a dummy one to verify that the accumulator does not already exist + val accums = listener.getCompletedStageInfos.flatMap(_.accumulables.values) + assert(!accums.exists(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY)) + } else { + // In the subsequent jobs, verify that peak execution memory is updated + val accum = listener.getCompletedStageInfos + .flatMap(_.accumulables.values) + .find(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY) + .getOrElse { + throw new TestFailedException( + s"peak execution memory accumulator not set in '$testName'", 0) + } + assert(accum.value.toLong > 0) + } + } + // Run the jobs + sc.parallelize(1 to 10).count() + testBody + } +} + +/** + * A simple listener that keeps track of the TaskInfos and StageInfos of all completed jobs. 
+ */ +private class SaveInfoListener extends SparkListener { + private val completedStageInfos: ArrayBuffer[StageInfo] = new ArrayBuffer[StageInfo] + private val completedTaskInfos: ArrayBuffer[TaskInfo] = new ArrayBuffer[TaskInfo] + private var jobCompletionCallback: (Int => Unit) = null // parameter is job ID + + def getCompletedStageInfos: Seq[StageInfo] = completedStageInfos.toArray.toSeq + def getCompletedTaskInfos: Seq[TaskInfo] = completedTaskInfos.toArray.toSeq + + /** Register a callback to be called on job end. */ + def registerJobCompletionCallback(callback: (Int => Unit)): Unit = { + jobCompletionCallback = callback + } + + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { + if (jobCompletionCallback != null) { + jobCompletionCallback(jobEnd.jobId) + } + } + + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { + completedStageInfos += stageCompleted.stageInfo + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + completedTaskInfos += taskEnd.taskInfo + } } diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index 618a5fb24710..cb8bd04e496a 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -21,7 +21,7 @@ import org.mockito.Mockito._ import org.scalatest.BeforeAndAfter import org.scalatest.mock.MockitoSugar -import org.apache.spark.executor.DataReadMethod +import org.apache.spark.executor.{DataReadMethod, TaskMetrics} import org.apache.spark.rdd.RDD import org.apache.spark.storage._ @@ -65,7 +65,7 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before // in blockManager.put is a losing battle. You have been warned. blockManager = sc.env.blockManager cacheManager = sc.env.cacheManager - val context = new TaskContextImpl(0, 0, 0, 0, null, null) + val context = TaskContext.empty() val computeValue = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) val getValue = blockManager.get(RDDBlockId(rdd.id, split.index)) assert(computeValue.toList === List(1, 2, 3, 4)) @@ -77,7 +77,7 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before val result = new BlockResult(Array(5, 6, 7).iterator, DataReadMethod.Memory, 12) when(blockManager.get(RDDBlockId(0, 0))).thenReturn(Some(result)) - val context = new TaskContextImpl(0, 0, 0, 0, null, null) + val context = TaskContext.empty() val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(5, 6, 7)) } @@ -86,14 +86,14 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before // Local computation should not persist the resulting value, so don't expect a put(). 
when(blockManager.get(RDDBlockId(0, 0))).thenReturn(None) - val context = new TaskContextImpl(0, 0, 0, 0, null, null, true) + val context = new TaskContextImpl(0, 0, 0, 0, null, null, Seq.empty, runningLocally = true) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(1, 2, 3, 4)) } test("verify task metrics updated correctly") { cacheManager = sc.env.cacheManager - val context = new TaskContextImpl(0, 0, 0, 0, null, null) + val context = TaskContext.empty() cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY) assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2) } diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index d343bb95cb68..553d46285ac0 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -21,17 +21,231 @@ import java.io.File import scala.reflect.ClassTag +import org.apache.hadoop.fs.Path + import org.apache.spark.rdd._ import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} import org.apache.spark.util.Utils +trait RDDCheckpointTester { self: SparkFunSuite => + + protected val partitioner = new HashPartitioner(2) + + private def defaultCollectFunc[T](rdd: RDD[T]): Any = rdd.collect() + + /** Implementations of this trait must implement this method */ + protected def sparkContext: SparkContext + + /** + * Test checkpointing of the RDD generated by the given operation. It tests whether the + * serialized size of the RDD is reduced after checkpointing or not. This function should be called + * on all RDDs that have a parent RDD (i.e., do not call on ParallelCollection, BlockRDD, etc.).
+ * + * @param op an operation to run on the RDD + * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints + * @param collectFunc a function for collecting the values in the RDD, in case there are + * non-comparable types like arrays that we want to convert to something + * that supports == + */ + protected def testRDD[U: ClassTag]( + op: (RDD[Int]) => RDD[U], + reliableCheckpoint: Boolean, + collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = { + // Generate the final RDD using given RDD operation + val baseRDD = generateFatRDD() + val operatedRDD = op(baseRDD) + val parentRDD = operatedRDD.dependencies.headOption.orNull + val rddType = operatedRDD.getClass.getSimpleName + val numPartitions = operatedRDD.partitions.length + + // Force initialization of all the data structures in RDDs + // Without this, serializing the RDD will give a wrong estimate of the size of the RDD + initializeRdd(operatedRDD) + + val partitionsBeforeCheckpoint = operatedRDD.partitions + + // Find serialized sizes before and after the checkpoint + logInfo("RDD before checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) + val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) + checkpoint(operatedRDD, reliableCheckpoint) + val result = collectFunc(operatedRDD) + operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables + val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) + logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) + + // Test whether the checkpoint file has been created + if (reliableCheckpoint) { + assert(operatedRDD.getCheckpointFile.nonEmpty) + val recoveredRDD = sparkContext.checkpointFile[U](operatedRDD.getCheckpointFile.get) + assert(collectFunc(recoveredRDD) === result) + assert(recoveredRDD.partitioner === operatedRDD.partitioner) + } + + // Test whether dependencies have been changed from its earlier parent RDD + assert(operatedRDD.dependencies.head.rdd != parentRDD) + + // Test whether the partitions have been changed from its earlier partitions + assert(operatedRDD.partitions.toList != partitionsBeforeCheckpoint.toList) + + // Test whether the partitions have been changed to the new Hadoop partitions + assert(operatedRDD.partitions.toList === operatedRDD.checkpointData.get.getPartitions.toList) + + // Test whether the number of partitions is same as before + assert(operatedRDD.partitions.length === numPartitions) + + // Test whether the data in the checkpointed RDD is same as original + assert(collectFunc(operatedRDD) === result) + + // Test whether serialized size of the RDD has reduced. + logInfo("Size of " + rddType + + " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]") + assert( + rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint, + "Size of " + rddType + " did not reduce after checkpointing " + + " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]" + ) + } + + /** + * Test whether checkpointing of the parent of the generated RDD also + * truncates the lineage or not. Some RDDs like CoGroupedRDD hold on to their parent + * RDDs' partitions. So even if the parent RDD is checkpointed and its partitions changed, + * the generated RDD will remember the partitions and therefore potentially the whole lineage. + * This function should be called only on those RDDs whose partitions refer to the parent RDD's + * partitions (i.e., do not call it on a simple RDD like MappedRDD).
+ * + * @param op an operation to run on the RDD + * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints + * @param collectFunc a function for collecting the values in the RDD, in case there are + * non-comparable types like arrays that we want to convert to something + * that supports == + */ + protected def testRDDPartitions[U: ClassTag]( + op: (RDD[Int]) => RDD[U], + reliableCheckpoint: Boolean, + collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = { + // Generate the final RDD using given RDD operation + val baseRDD = generateFatRDD() + val operatedRDD = op(baseRDD) + val parentRDDs = operatedRDD.dependencies.map(_.rdd) + val rddType = operatedRDD.getClass.getSimpleName + + // Force initialization of all the data structures in RDDs + // Without this, serializing the RDD will give a wrong estimate of the size of the RDD + initializeRdd(operatedRDD) + + // Find serialized sizes before and after the checkpoint + logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) + val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) + // checkpoint the parent RDD, not the generated one + parentRDDs.foreach { rdd => + checkpoint(rdd, reliableCheckpoint) + } + val result = collectFunc(operatedRDD) // force checkpointing + operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables + val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) + logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) + + // Test whether the data in the checkpointed RDD is same as original + assert(collectFunc(operatedRDD) === result) + + // Test whether serialized size of the partitions has reduced + logInfo("Size of partitions of " + rddType + + " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]") + assert( + partitionSizeAfterCheckpoint < partitionSizeBeforeCheckpoint, + "Size of " + rddType + " partitions did not reduce after checkpointing parent RDDs" + + " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]" + ) + } + + /** + * Get serialized sizes of the RDD and its partitions, in order to test whether the size shrinks + * upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint. + */ + private def getSerializedSizes(rdd: RDD[_]): (Int, Int) = { + val rddSize = Utils.serialize(rdd).size + val rddCpDataSize = Utils.serialize(rdd.checkpointData).size + val rddPartitionSize = Utils.serialize(rdd.partitions).size + val rddDependenciesSize = Utils.serialize(rdd.dependencies).size + + // Print detailed size, helps in debugging + logInfo("Serialized sizes of " + rdd + + ": RDD = " + rddSize + + ", RDD checkpoint data = " + rddCpDataSize + + ", RDD partitions = " + rddPartitionSize + + ", RDD dependencies = " + rddDependenciesSize + ) + // this makes sure that serializing the RDD's checkpoint data does not + // serialize the whole RDD as well + assert( + rddSize > rddCpDataSize, + "RDD's checkpoint data (" + rddCpDataSize + ") is equal or larger than the " + + "whole RDD with checkpoint data (" + rddSize + ")" + ) + (rddSize - rddCpDataSize, rddPartitionSize) + } + + /** + * Serialize and deserialize an object. 
This is useful to verify the object's + * contents after deserialization (e.g., the contents of an RDD split after + * it is sent to a slave along with a task) + */ + protected def serializeDeserialize[T](obj: T): T = { + val bytes = Utils.serialize(obj) + Utils.deserialize[T](bytes) + } + + /** + * Recursively force the initialization of all members of an RDD and its parents. + */ + private def initializeRdd(rdd: RDD[_]): Unit = { + rdd.partitions // forces the initialization of the partitions + rdd.dependencies.map(_.rdd).foreach(initializeRdd) + } + + /** Checkpoint the RDD either locally or reliably. */ + protected def checkpoint(rdd: RDD[_], reliableCheckpoint: Boolean): Unit = { + if (reliableCheckpoint) { + rdd.checkpoint() + } else { + rdd.localCheckpoint() + } + } + + /** Run a test twice, once for local checkpointing and once for reliable checkpointing. */ + protected def runTest( + name: String, + skipLocalCheckpoint: Boolean = false + )(body: Boolean => Unit): Unit = { + test(name + " [reliable checkpoint]")(body(true)) + if (!skipLocalCheckpoint) { + test(name + " [local checkpoint]")(body(false)) + } + } + + /** + * Generate an RDD such that both the RDD and its partitions have large size. + */ + protected def generateFatRDD(): RDD[Int] = { + new FatRDD(sparkContext.makeRDD(1 to 100, 4)).map(x => x) + } + + /** + * Generate a pair RDD (with partitioner) such that both the RDD and its partitions + * have large size. + */ + protected def generateFatPairRDD(): RDD[(Int, Int)] = { + new FatPairRDD(sparkContext.makeRDD(1 to 100, 4), partitioner).mapValues(x => x) + } +} + /** * Test suite for end-to-end checkpointing functionality. * This tests both reliable checkpoints and local checkpoints. */ -class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging { +class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalSparkContext { private var checkpointDir: File = _ - private val partitioner = new HashPartitioner(2) override def beforeEach(): Unit = { super.beforeEach() @@ -46,6 +260,8 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging Utils.deleteRecursively(checkpointDir) } + override def sparkContext: SparkContext = sc + runTest("basic checkpointing") { reliableCheckpoint: Boolean => val parCollection = sc.makeRDD(1 to 4) val flatMappedRDD = parCollection.flatMap(x => 1 to x) @@ -56,6 +272,49 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging assert(flatMappedRDD.collect() === result) } + runTest("checkpointing partitioners", skipLocalCheckpoint = true) { _: Boolean => + + def testPartitionerCheckpointing( + partitioner: Partitioner, + corruptPartitionerFile: Boolean = false + ): Unit = { + val rddWithPartitioner = sc.makeRDD(1 to 4).map { _ -> 1 }.partitionBy(partitioner) + rddWithPartitioner.checkpoint() + rddWithPartitioner.count() + assert(rddWithPartitioner.getCheckpointFile.get.nonEmpty, + "checkpointing was not successful") + + if (corruptPartitionerFile) { + // Overwrite the partitioner file with garbage data + val checkpointDir = new Path(rddWithPartitioner.getCheckpointFile.get) + val fs = checkpointDir.getFileSystem(sc.hadoopConfiguration) + val partitionerFile = fs.listStatus(checkpointDir) + .find(_.getPath.getName.contains("partitioner")) + .map(_.getPath) + require(partitionerFile.nonEmpty, "could not find the partitioner file for testing") + val output = fs.create(partitionerFile.get, true) + output.write(100) + output.close() + } + + val newRDD =
sc.checkpointFile[(Int, Int)](rddWithPartitioner.getCheckpointFile.get) + assert(newRDD.collect().toSet === rddWithPartitioner.collect().toSet, "RDD not recovered") + + if (!corruptPartitionerFile) { + assert(newRDD.partitioner != None, "partitioner not recovered") + assert(newRDD.partitioner === rddWithPartitioner.partitioner, + "recovered partitioner does not match") + } else { + assert(newRDD.partitioner == None, "partitioner unexpectedly recovered") + } + } + + testPartitionerCheckpointing(partitioner) + + // Test that corrupted partitioner file does not prevent recovery of RDD + testPartitionerCheckpointing(partitioner, corruptPartitionerFile = true) + } + runTest("RDDs with one-to-one dependencies") { reliableCheckpoint: Boolean => testRDD(_.map(x => x.toString), reliableCheckpoint) testRDD(_.flatMap(x => 1 to x), reliableCheckpoint) @@ -241,209 +500,15 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging val rdd = new BlockRDD[Int](sc, Array[BlockId]()) assert(rdd.partitions.size === 0) assert(rdd.isCheckpointed === false) + assert(rdd.isCheckpointedAndMaterialized === false) checkpoint(rdd, reliableCheckpoint) + assert(rdd.isCheckpointed === false) + assert(rdd.isCheckpointedAndMaterialized === false) assert(rdd.count() === 0) assert(rdd.isCheckpointed === true) + assert(rdd.isCheckpointedAndMaterialized === true) assert(rdd.partitions.size === 0) } - - // Utility test methods - - /** Checkpoint the RDD either locally or reliably. */ - private def checkpoint(rdd: RDD[_], reliableCheckpoint: Boolean): Unit = { - if (reliableCheckpoint) { - rdd.checkpoint() - } else { - rdd.localCheckpoint() - } - } - - /** Run a test twice, once for local checkpointing and once for reliable checkpointing. */ - private def runTest(name: String)(body: Boolean => Unit): Unit = { - test(name + " [reliable checkpoint]")(body(true)) - test(name + " [local checkpoint]")(body(false)) - } - - private def defaultCollectFunc[T](rdd: RDD[T]): Any = rdd.collect() - - /** - * Test checkpointing of the RDD generated by the given operation. It tests whether the - * serialized size of the RDD is reduce after checkpointing or not. This function should be called - * on all RDDs that have a parent RDD (i.e., do not call on ParallelCollection, BlockRDD, etc.). 
- * - * @param op an operation to run on the RDD - * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints - * @param collectFunc a function for collecting the values in the RDD, in case there are - * non-comparable types like arrays that we want to convert to something that supports == - */ - private def testRDD[U: ClassTag]( - op: (RDD[Int]) => RDD[U], - reliableCheckpoint: Boolean, - collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = { - // Generate the final RDD using given RDD operation - val baseRDD = generateFatRDD() - val operatedRDD = op(baseRDD) - val parentRDD = operatedRDD.dependencies.headOption.orNull - val rddType = operatedRDD.getClass.getSimpleName - val numPartitions = operatedRDD.partitions.length - - // Force initialization of all the data structures in RDDs - // Without this, serializing the RDD will give a wrong estimate of the size of the RDD - initializeRdd(operatedRDD) - - val partitionsBeforeCheckpoint = operatedRDD.partitions - - // Find serialized sizes before and after the checkpoint - logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) - val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) - checkpoint(operatedRDD, reliableCheckpoint) - val result = collectFunc(operatedRDD) - operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables - val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) - logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) - - // Test whether the checkpoint file has been created - if (reliableCheckpoint) { - assert(collectFunc(sc.checkpointFile[U](operatedRDD.getCheckpointFile.get)) === result) - } - - // Test whether dependencies have been changed from its earlier parent RDD - assert(operatedRDD.dependencies.head.rdd != parentRDD) - - // Test whether the partitions have been changed from its earlier partitions - assert(operatedRDD.partitions.toList != partitionsBeforeCheckpoint.toList) - - // Test whether the partitions have been changed to the new Hadoop partitions - assert(operatedRDD.partitions.toList === operatedRDD.checkpointData.get.getPartitions.toList) - - // Test whether the number of partitions is same as before - assert(operatedRDD.partitions.length === numPartitions) - - // Test whether the data in the checkpointed RDD is same as original - assert(collectFunc(operatedRDD) === result) - - // Test whether serialized size of the RDD has reduced. - logInfo("Size of " + rddType + - " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]") - assert( - rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint, - "Size of " + rddType + " did not reduce after checkpointing " + - " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]" - ) - } - - /** - * Test whether checkpointing of the parent of the generated RDD also - * truncates the lineage or not. Some RDDs like CoGroupedRDD hold on to its parent - * RDDs partitions. So even if the parent RDD is checkpointed and its partitions changed, - * the generated RDD will remember the partitions and therefore potentially the whole lineage. - * This function should be called only those RDD whose partitions refer to parent RDD's - * partitions (i.e., do not call it on simple RDD like MappedRDD). 
- * - * @param op an operation to run on the RDD - * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints - * @param collectFunc a function for collecting the values in the RDD, in case there are - * non-comparable types like arrays that we want to convert to something that supports == - */ - private def testRDDPartitions[U: ClassTag]( - op: (RDD[Int]) => RDD[U], - reliableCheckpoint: Boolean, - collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = { - // Generate the final RDD using given RDD operation - val baseRDD = generateFatRDD() - val operatedRDD = op(baseRDD) - val parentRDDs = operatedRDD.dependencies.map(_.rdd) - val rddType = operatedRDD.getClass.getSimpleName - - // Force initialization of all the data structures in RDDs - // Without this, serializing the RDD will give a wrong estimate of the size of the RDD - initializeRdd(operatedRDD) - - // Find serialized sizes before and after the checkpoint - logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) - val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) - // checkpoint the parent RDD, not the generated one - parentRDDs.foreach { rdd => - checkpoint(rdd, reliableCheckpoint) - } - val result = collectFunc(operatedRDD) // force checkpointing - operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables - val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) - logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) - - // Test whether the data in the checkpointed RDD is same as original - assert(collectFunc(operatedRDD) === result) - - // Test whether serialized size of the partitions has reduced - logInfo("Size of partitions of " + rddType + - " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]") - assert( - partitionSizeAfterCheckpoint < partitionSizeBeforeCheckpoint, - "Size of " + rddType + " partitions did not reduce after checkpointing parent RDDs" + - " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]" - ) - } - - /** - * Generate an RDD such that both the RDD and its partitions have large size. - */ - private def generateFatRDD(): RDD[Int] = { - new FatRDD(sc.makeRDD(1 to 100, 4)).map(x => x) - } - - /** - * Generate an pair RDD (with partitioner) such that both the RDD and its partitions - * have large size. - */ - private def generateFatPairRDD(): RDD[(Int, Int)] = { - new FatPairRDD(sc.makeRDD(1 to 100, 4), partitioner).mapValues(x => x) - } - - /** - * Get serialized sizes of the RDD and its partitions, in order to test whether the size shrinks - * upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint. 
- */ - private def getSerializedSizes(rdd: RDD[_]): (Int, Int) = { - val rddSize = Utils.serialize(rdd).size - val rddCpDataSize = Utils.serialize(rdd.checkpointData).size - val rddPartitionSize = Utils.serialize(rdd.partitions).size - val rddDependenciesSize = Utils.serialize(rdd.dependencies).size - - // Print detailed size, helps in debugging - logInfo("Serialized sizes of " + rdd + - ": RDD = " + rddSize + - ", RDD checkpoint data = " + rddCpDataSize + - ", RDD partitions = " + rddPartitionSize + - ", RDD dependencies = " + rddDependenciesSize - ) - // this makes sure that serializing the RDD's checkpoint data does not - // serialize the whole RDD as well - assert( - rddSize > rddCpDataSize, - "RDD's checkpoint data (" + rddCpDataSize + ") is equal or larger than the " + - "whole RDD with checkpoint data (" + rddSize + ")" - ) - (rddSize - rddCpDataSize, rddPartitionSize) - } - - /** - * Serialize and deserialize an object. This is useful to verify the objects - * contents after deserialization (e.g., the contents of an RDD split after - * it is sent to a slave along with a task) - */ - private def serializeDeserialize[T](obj: T): T = { - val bytes = Utils.serialize(obj) - Utils.deserialize[T](bytes) - } - - /** - * Recursively force the initialization of the all members of an RDD and it parents. - */ - private def initializeRdd(rdd: RDD[_]): Unit = { - rdd.partitions // forces the - rdd.dependencies.map(_.rdd).foreach(initializeRdd) - } - } /** RDD partition that has large serialized size. */ @@ -483,12 +548,11 @@ class FatPairRDD(parent: RDD[Int], _partitioner: Partitioner) extends RDD[(Int, object CheckpointSuite { // This is a custom cogroup function that does not use mapValues like // the PairRDDFunctions.cogroup() - def cogroup[K, V](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) + def cogroup[K: ClassTag, V: ClassTag](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) : RDD[(K, Array[Iterable[V]])] = { new CoGroupedRDD[K]( Seq(first.asInstanceOf[RDD[(K, _)]], second.asInstanceOf[RDD[(K, _)]]), part ).asInstanceOf[RDD[(K, Array[Iterable[V]])]] } - } diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 600c1403b034..1c3f2bc315dd 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -203,25 +203,35 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex } test("compute without caching when no partitions fit in memory") { - sc = new SparkContext(clusterUrl, "test") - // data will be 4 million * 4 bytes = 16 MB in size, but our memoryFraction set the cache - // to only 50 KB (0.0001 of 512 MB), so no partitions should fit in memory - val data = sc.parallelize(1 to 4000000, 2).persist(StorageLevel.MEMORY_ONLY_SER) - assert(data.count() === 4000000) - assert(data.count() === 4000000) - assert(data.count() === 4000000) + val size = 10000 + val conf = new SparkConf() + .set("spark.storage.unrollMemoryThreshold", "1024") + .set("spark.testing.memory", (size / 2).toString) + sc = new SparkContext(clusterUrl, "test", conf) + val data = sc.parallelize(1 to size, 2).persist(StorageLevel.MEMORY_ONLY) + assert(data.count() === size) + assert(data.count() === size) + assert(data.count() === size) + // ensure only a subset of partitions were cached + val rddBlocks = sc.env.blockManager.master.getMatchingBlockIds(_.isRDD, askSlaves = true) + assert(rddBlocks.size === 0, 
s"expected no RDD blocks, found ${rddBlocks.size}") } test("compute when only some partitions fit in memory") { - val conf = new SparkConf().set("spark.storage.memoryFraction", "0.01") + val size = 10000 + val numPartitions = 10 + val conf = new SparkConf() + .set("spark.storage.unrollMemoryThreshold", "1024") + .set("spark.testing.memory", (size * numPartitions).toString) sc = new SparkContext(clusterUrl, "test", conf) - // data will be 4 million * 4 bytes = 16 MB in size, but our memoryFraction set the cache - // to only 5 MB (0.01 of 512 MB), so not all of it will fit in memory; we use 20 partitions - // to make sure that *some* of them do fit though - val data = sc.parallelize(1 to 4000000, 20).persist(StorageLevel.MEMORY_ONLY_SER) - assert(data.count() === 4000000) - assert(data.count() === 4000000) - assert(data.count() === 4000000) + val data = sc.parallelize(1 to size, numPartitions).persist(StorageLevel.MEMORY_ONLY) + assert(data.count() === size) + assert(data.count() === size) + assert(data.count() === size) + // ensure only a subset of partitions were cached + val rddBlocks = sc.env.blockManager.master.getMatchingBlockIds(_.isRDD, askSlaves = true) + assert(rddBlocks.size > 0, "no RDD blocks found") + assert(rddBlocks.size < numPartitions, s"too many RDD blocks found, expected <$numPartitions") } test("passing environment variables to cluster") { diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 34caca892891..fedfbd547b91 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -206,8 +206,8 @@ class ExecutorAllocationManagerSuite val task2Info = createTaskInfo(1, 0, "executor-1") sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, task2Info)) - sc.listenerBus.postToAll(SparkListenerTaskEnd(2, 0, null, null, task1Info, null)) - sc.listenerBus.postToAll(SparkListenerTaskEnd(2, 0, null, null, task2Info, null)) + sc.listenerBus.postToAll(SparkListenerTaskEnd(2, 0, null, Success, task1Info, null)) + sc.listenerBus.postToAll(SparkListenerTaskEnd(2, 0, null, Success, task2Info, null)) assert(adjustRequestedExecutors(manager) === -1) } @@ -787,6 +787,108 @@ class ExecutorAllocationManagerSuite Map("host2" -> 1, "host3" -> 2, "host4" -> 1, "host5" -> 2)) } + test("SPARK-8366: maxNumExecutorsNeeded should properly handle failed tasks") { + sc = createSparkContext() + val manager = sc.executorAllocationManager.get + assert(maxNumExecutorsNeeded(manager) === 0) + + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1))) + assert(maxNumExecutorsNeeded(manager) === 1) + + val taskInfo = createTaskInfo(1, 1, "executor-1") + sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, taskInfo)) + assert(maxNumExecutorsNeeded(manager) === 1) + + // If the task is failed, we expect it to be resubmitted later. + val taskEndReason = ExceptionFailure(null, null, null, null, null, None) + sc.listenerBus.postToAll(SparkListenerTaskEnd(0, 0, null, taskEndReason, taskInfo, null)) + assert(maxNumExecutorsNeeded(manager) === 1) + } + + test("reset the state of allocation manager") { + sc = createSparkContext() + val manager = sc.executorAllocationManager.get + assert(numExecutorsTarget(manager) === 1) + assert(numExecutorsToAdd(manager) === 1) + + // Allocation manager is reset when adding executor requests are sent without reporting back + // executor added. 
+ sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 10))) + + assert(addExecutors(manager) === 1) + assert(numExecutorsTarget(manager) === 2) + assert(addExecutors(manager) === 2) + assert(numExecutorsTarget(manager) === 4) + assert(addExecutors(manager) === 1) + assert(numExecutorsTarget(manager) === 5) + + manager.reset() + assert(numExecutorsTarget(manager) === 1) + assert(numExecutorsToAdd(manager) === 1) + assert(executorIds(manager) === Set.empty) + + // Allocation manager is reset when executors are added. + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 10))) + + addExecutors(manager) + addExecutors(manager) + addExecutors(manager) + assert(numExecutorsTarget(manager) === 5) + + onExecutorAdded(manager, "first") + onExecutorAdded(manager, "second") + onExecutorAdded(manager, "third") + onExecutorAdded(manager, "fourth") + onExecutorAdded(manager, "fifth") + assert(executorIds(manager) === Set("first", "second", "third", "fourth", "fifth")) + + // Cluster manager lost will make all the live executors lost, so here simulate this behavior + onExecutorRemoved(manager, "first") + onExecutorRemoved(manager, "second") + onExecutorRemoved(manager, "third") + onExecutorRemoved(manager, "fourth") + onExecutorRemoved(manager, "fifth") + + manager.reset() + assert(numExecutorsTarget(manager) === 1) + assert(numExecutorsToAdd(manager) === 1) + assert(executorIds(manager) === Set.empty) + assert(removeTimes(manager) === Map.empty) + + // Allocation manager is reset when executors are pending to remove + addExecutors(manager) + addExecutors(manager) + addExecutors(manager) + assert(numExecutorsTarget(manager) === 5) + + onExecutorAdded(manager, "first") + onExecutorAdded(manager, "second") + onExecutorAdded(manager, "third") + onExecutorAdded(manager, "fourth") + onExecutorAdded(manager, "fifth") + assert(executorIds(manager) === Set("first", "second", "third", "fourth", "fifth")) + + removeExecutor(manager, "first") + removeExecutor(manager, "second") + assert(executorsPendingToRemove(manager) === Set("first", "second")) + assert(executorIds(manager) === Set("first", "second", "third", "fourth", "fifth")) + + + // Cluster manager lost will make all the live executors lost, so here simulate this behavior + onExecutorRemoved(manager, "first") + onExecutorRemoved(manager, "second") + onExecutorRemoved(manager, "third") + onExecutorRemoved(manager, "fourth") + onExecutorRemoved(manager, "fifth") + + manager.reset() + + assert(numExecutorsTarget(manager) === 1) + assert(numExecutorsToAdd(manager) === 1) + assert(executorsPendingToRemove(manager) === Set.empty) + assert(removeTimes(manager) === Map.empty) + } + private def createSparkContext( minExecutors: Int = 1, maxExecutors: Int = 5, diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index c38d70252add..1c775bcb3d9c 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -35,8 +35,8 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { var rpcHandler: ExternalShuffleBlockHandler = _ override def beforeAll() { - val transportConf = SparkTransportConf.fromSparkConf(conf, numUsableCores = 2) - rpcHandler = new ExternalShuffleBlockHandler(transportConf) + val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores = 2) + rpcHandler = new 
ExternalShuffleBlockHandler(transportConf, null) val transportContext = new TransportContext(transportConf, rpcHandler) server = transportContext.createServer() @@ -61,7 +61,7 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { // local blocks from the local BlockManager and won't send requests to ExternalShuffleService. // In this case, we won't receive FetchFailed. And it will make this test fail. // Therefore, we should wait until all slaves are up - sc.jobProgressListener.waitUntilExecutorsUp(2, 10000) + sc.jobProgressListener.waitUntilExecutorsUp(2, 60000) val rdd = sc.parallelize(0 until 1000, 10).map(i => (i, 1)).reduceByKey(_ + _) diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala index 69cb4b44cf7e..203dab934ca1 100644 --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark import org.apache.spark.util.NonSerializable -import java.io.NotSerializableException +import java.io.{IOException, NotSerializableException, ObjectInputStream} // Common state shared by FailureSuite-launched tasks. We use a global object // for this because any local variables used in the task closures will rightfully @@ -149,7 +149,7 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext { // cause is preserved val thrownDueToTaskFailure = intercept[SparkException] { sc.parallelize(Seq(0)).mapPartitions { iter => - TaskContext.get().taskMemoryManager().allocate(128) + TaskContext.get().taskMemoryManager().allocatePage(128, null) throw new Exception("intentional task failure") iter }.count() @@ -159,12 +159,97 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext { // If the task succeeded but memory was leaked, then the task should fail due to that leak val thrownDueToMemoryLeak = intercept[SparkException] { sc.parallelize(Seq(0)).mapPartitions { iter => - TaskContext.get().taskMemoryManager().allocate(128) + TaskContext.get().taskMemoryManager().allocatePage(128, null) iter }.count() } assert(thrownDueToMemoryLeak.getMessage.contains("memory leak")) } + // Run a 3-task map job in which one task always fails with an exception message that + // depends on the failure number, and check that we get the last failure.
+ test("last failure cause is sent back to driver") { + sc = new SparkContext("local[1,2]", "test") + val data = sc.makeRDD(1 to 3, 3).map { x => + FailureSuiteState.synchronized { + FailureSuiteState.tasksRun += 1 + if (x == 3) { + FailureSuiteState.tasksFailed += 1 + throw new UserException("oops", + new IllegalArgumentException("failed=" + FailureSuiteState.tasksFailed)) + } + } + x * x + } + val thrown = intercept[SparkException] { + data.collect() + } + FailureSuiteState.synchronized { + assert(FailureSuiteState.tasksRun === 4) + } + assert(thrown.getClass === classOf[SparkException]) + assert(thrown.getCause.getClass === classOf[UserException]) + assert(thrown.getCause.getMessage === "oops") + assert(thrown.getCause.getCause.getClass === classOf[IllegalArgumentException]) + assert(thrown.getCause.getCause.getMessage === "failed=2") + FailureSuiteState.clear() + } + + test("failure cause stacktrace is sent back to driver if exception is not serializable") { + sc = new SparkContext("local", "test") + val thrown = intercept[SparkException] { + sc.makeRDD(1 to 3).foreach { _ => throw new NonSerializableUserException } + } + assert(thrown.getClass === classOf[SparkException]) + assert(thrown.getCause === null) + assert(thrown.getMessage.contains("NonSerializableUserException")) + FailureSuiteState.clear() + } + + test("failure cause stacktrace is sent back to driver if exception is not deserializable") { + sc = new SparkContext("local", "test") + val thrown = intercept[SparkException] { + sc.makeRDD(1 to 3).foreach { _ => throw new NonDeserializableUserException } + } + assert(thrown.getClass === classOf[SparkException]) + assert(thrown.getCause === null) + assert(thrown.getMessage.contains("NonDeserializableUserException")) + FailureSuiteState.clear() + } + + // Run a 3-task map stage where one task fails once. + test("failure in tasks in a submitMapStage") { + sc = new SparkContext("local[1,2]", "test") + val rdd = sc.makeRDD(1 to 3, 3).map { x => + FailureSuiteState.synchronized { + FailureSuiteState.tasksRun += 1 + if (x == 1 && FailureSuiteState.tasksFailed == 0) { + FailureSuiteState.tasksFailed += 1 + throw new Exception("Intentional task failure") + } + } + (x, x) + } + val dep = new ShuffleDependency[Int, Int, Int](rdd, new HashPartitioner(2)) + sc.submitMapStage(dep).get() + FailureSuiteState.synchronized { + assert(FailureSuiteState.tasksRun === 4) + } + FailureSuiteState.clear() + } + // TODO: Need to add tests with shuffle fetch failures. 
} + +class UserException(message: String, cause: Throwable) + extends RuntimeException(message, cause) + +class NonSerializableUserException extends RuntimeException { + val nonSerializableInstanceVariable = new NonSerializable +} + +class NonDeserializableUserException extends RuntimeException { + private def readObject(in: ObjectInputStream): Unit = { + throw new IOException("Intentional exception during deserialization.") + } +} diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 418763f4e5ff..fdb00aafc4a4 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark import java.io.{File, FileWriter} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.input.PortableDataStream import org.apache.spark.storage.StorageLevel @@ -506,8 +507,9 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { job.setOutputKeyClass(classOf[String]) job.setOutputValueClass(classOf[String]) job.setOutputFormatClass(classOf[NewTextOutputFormat[String, String]]) - job.getConfiguration.set("mapred.output.dir", tempDir.getPath + "/outputDataset_new") - randomRDD.saveAsNewAPIHadoopDataset(job.getConfiguration) + val jobConfig = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + jobConfig.set("mapred.output.dir", tempDir.getPath + "/outputDataset_new") + randomRDD.saveAsNewAPIHadoopDataset(jobConfig) assert(new File(tempDir.getPath + "/outputDataset_new/part-r-00000").exists() === true) } diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index 139b8dc25f4b..3cd80c0f7d17 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -19,7 +19,10 @@ package org.apache.spark import java.util.concurrent.{ExecutorService, TimeUnit} +import scala.collection.Map import scala.collection.mutable +import scala.concurrent.Await +import scala.concurrent.duration._ import scala.language.postfixOps import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester} @@ -96,18 +99,18 @@ class HeartbeatReceiverSuite test("normal heartbeat") { heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) - heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null)) - heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null)) + addExecutorAndVerify(executorId1) + addExecutorAndVerify(executorId2) triggerHeartbeat(executorId1, executorShouldReregister = false) triggerHeartbeat(executorId2, executorShouldReregister = false) - val trackedExecutors = heartbeatReceiver.invokePrivate(_executorLastSeen()) + val trackedExecutors = getTrackedExecutors assert(trackedExecutors.size === 2) assert(trackedExecutors.contains(executorId1)) assert(trackedExecutors.contains(executorId2)) } test("reregister if scheduler is not ready yet") { - heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null)) + addExecutorAndVerify(executorId1) // Task scheduler is not set yet in HeartbeatReceiver, so executors should reregister triggerHeartbeat(executorId1, executorShouldReregister = true) } @@ -116,20 +119,20 @@ class HeartbeatReceiverSuite heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) // Received heartbeat from unknown executor, so we ask it to re-register 
triggerHeartbeat(executorId1, executorShouldReregister = true) - assert(heartbeatReceiver.invokePrivate(_executorLastSeen()).isEmpty) + assert(getTrackedExecutors.isEmpty) } test("reregister if heartbeat from removed executor") { heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) - heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null)) - heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null)) + addExecutorAndVerify(executorId1) + addExecutorAndVerify(executorId2) // Remove the second executor but not the first - heartbeatReceiver.onExecutorRemoved(SparkListenerExecutorRemoved(0, executorId2, "bad boy")) + removeExecutorAndVerify(executorId2) // Now trigger the heartbeats // A heartbeat from the second executor should require reregistering triggerHeartbeat(executorId1, executorShouldReregister = false) triggerHeartbeat(executorId2, executorShouldReregister = true) - val trackedExecutors = heartbeatReceiver.invokePrivate(_executorLastSeen()) + val trackedExecutors = getTrackedExecutors assert(trackedExecutors.size === 1) assert(trackedExecutors.contains(executorId1)) assert(!trackedExecutors.contains(executorId2)) @@ -138,8 +141,8 @@ class HeartbeatReceiverSuite test("expire dead hosts") { val executorTimeout = heartbeatReceiver.invokePrivate(_executorTimeoutMs()) heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) - heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null)) - heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null)) + addExecutorAndVerify(executorId1) + addExecutorAndVerify(executorId2) triggerHeartbeat(executorId1, executorShouldReregister = false) triggerHeartbeat(executorId2, executorShouldReregister = false) // Advance the clock and only trigger a heartbeat for the first executor @@ -149,7 +152,7 @@ class HeartbeatReceiverSuite heartbeatReceiverRef.askWithRetry[Boolean](ExpireDeadHosts) // Only the second executor should be expired as a dead host verify(scheduler).executorLost(Matchers.eq(executorId2), any()) - val trackedExecutors = heartbeatReceiver.invokePrivate(_executorLastSeen()) + val trackedExecutors = getTrackedExecutors assert(trackedExecutors.size === 1) assert(trackedExecutors.contains(executorId1)) assert(!trackedExecutors.contains(executorId2)) @@ -170,13 +173,13 @@ class HeartbeatReceiverSuite val dummyExecutorEndpoint2 = new FakeExecutorEndpoint(rpcEnv) val dummyExecutorEndpointRef1 = rpcEnv.setupEndpoint("fake-executor-1", dummyExecutorEndpoint1) val dummyExecutorEndpointRef2 = rpcEnv.setupEndpoint("fake-executor-2", dummyExecutorEndpoint2) - fakeSchedulerBackend.driverEndpoint.askWithRetry[RegisteredExecutor.type]( + fakeSchedulerBackend.driverEndpoint.askWithRetry[RegisterExecutorResponse]( RegisterExecutor(executorId1, dummyExecutorEndpointRef1, "dummy:4040", 0, Map.empty)) - fakeSchedulerBackend.driverEndpoint.askWithRetry[RegisteredExecutor.type]( + fakeSchedulerBackend.driverEndpoint.askWithRetry[RegisterExecutorResponse]( RegisterExecutor(executorId2, dummyExecutorEndpointRef2, "dummy:4040", 0, Map.empty)) heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) - heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null)) - heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null)) + addExecutorAndVerify(executorId1) + addExecutorAndVerify(executorId2) triggerHeartbeat(executorId1, executorShouldReregister = false) triggerHeartbeat(executorId2, 
executorShouldReregister = false) @@ -222,6 +225,26 @@ class HeartbeatReceiverSuite } } + private def addExecutorAndVerify(executorId: String): Unit = { + assert( + heartbeatReceiver.addExecutor(executorId).map { f => + Await.result(f, 10.seconds) + } === Some(true)) + } + + private def removeExecutorAndVerify(executorId: String): Unit = { + assert( + heartbeatReceiver.removeExecutor(executorId).map { f => + Await.result(f, 10.seconds) + } === Some(true)) + } + + private def getTrackedExecutors: Map[String, Long] = { + // We may receive undesired SparkListenerExecutorAdded from LocalBackend, so exclude it from + // the map. See SPARK-10800. + heartbeatReceiver.invokePrivate(_executorLastSeen()). + filterKeys(_ != SparkContext.DRIVER_IDENTIFIER) + } } // TODO: use these classes to add end-to-end tests for dynamic allocation! diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index af4e68950f75..7e70308bb360 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -168,10 +168,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { masterTracker.registerShuffle(10, 1) masterTracker.registerMapOutput(10, 0, MapStatus( BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0))) - val sender = mock(classOf[RpcEndpointRef]) - when(sender.address).thenReturn(RpcAddress("localhost", 12345)) + val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) - when(rpcCallContext.sender).thenReturn(sender) + when(rpcCallContext.senderAddress).thenReturn(senderAddress) masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(10)) verify(rpcCallContext).reply(any()) verify(rpcCallContext, never()).sendFailure(any()) @@ -198,10 +197,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { masterTracker.registerMapOutput(20, i, new CompressedMapStatus( BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0))) } - val sender = mock(classOf[RpcEndpointRef]) - when(sender.address).thenReturn(RpcAddress("localhost", 12345)) + val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) - when(rpcCallContext.sender).thenReturn(sender) + when(rpcCallContext.senderAddress).thenReturn(senderAddress) masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(20)) verify(rpcCallContext, never()).reply(any()) verify(rpcCallContext).sendFailure(isA(classOf[SparkException])) diff --git a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala index 33270bec6247..2d14249855c9 100644 --- a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala +++ b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala @@ -41,6 +41,7 @@ object SSLSampleConfigs { def sparkSSLConfig(): SparkConf = { val conf = new SparkConf(loadDefaults = false) + conf.set("spark.rpc", "akka") conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", keyStorePath) conf.set("spark.ssl.keyStorePassword", "password") @@ -54,6 +55,7 @@ object SSLSampleConfigs { def sparkSSLConfigUntrusted(): SparkConf = { val conf = new SparkConf(loadDefaults = false) + conf.set("spark.rpc", "akka") conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", untrustedKeyStorePath) conf.set("spark.ssl.keyStorePassword", "password") diff --git 
a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala index f34aefca4eb1..26b95c06789f 100644 --- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark import java.io.File -import org.apache.spark.util.Utils +import org.apache.spark.util.{SparkConfWithEnv, Utils} class SecurityManagerSuite extends SparkFunSuite { @@ -125,6 +125,47 @@ class SecurityManagerSuite extends SparkFunSuite { } + test("set security with * in acls") { + val conf = new SparkConf + conf.set("spark.ui.acls.enable", "true") + conf.set("spark.admin.acls", "user1,user2") + conf.set("spark.ui.view.acls", "*") + conf.set("spark.modify.acls", "user4") + + val securityManager = new SecurityManager(conf) + assert(securityManager.aclsEnabled() === true) + + // check for viewAcls with * + assert(securityManager.checkUIViewPermissions("user1") === true) + assert(securityManager.checkUIViewPermissions("user5") === true) + assert(securityManager.checkUIViewPermissions("user6") === true) + assert(securityManager.checkModifyPermissions("user4") === true) + assert(securityManager.checkModifyPermissions("user7") === false) + assert(securityManager.checkModifyPermissions("user8") === false) + + // check for modifyAcls with * + securityManager.setModifyAcls(Set("user4"), "*") + assert(securityManager.checkModifyPermissions("user7") === true) + assert(securityManager.checkModifyPermissions("user8") === true) + + securityManager.setAdminAcls("user1,user2") + securityManager.setModifyAcls(Set("user1"), "user2") + securityManager.setViewAcls(Set("user1"), "user2") + assert(securityManager.checkUIViewPermissions("user5") === false) + assert(securityManager.checkUIViewPermissions("user6") === false) + assert(securityManager.checkModifyPermissions("user7") === false) + assert(securityManager.checkModifyPermissions("user8") === false) + + // check for adminAcls with * + securityManager.setAdminAcls("user1,*") + securityManager.setModifyAcls(Set("user1"), "user2") + securityManager.setViewAcls(Set("user1"), "user2") + assert(securityManager.checkUIViewPermissions("user5") === true) + assert(securityManager.checkUIViewPermissions("user6") === true) + assert(securityManager.checkModifyPermissions("user7") === true) + assert(securityManager.checkModifyPermissions("user8") === true) + } + test("ssl on setup") { val conf = SSLSampleConfigs.sparkSSLConfig() val expectedAlgorithms = Set( @@ -182,5 +223,26 @@ class SecurityManagerSuite extends SparkFunSuite { assert(securityManager.hostnameVerifier.isDefined === false) } + test("missing secret authentication key") { + val conf = new SparkConf().set("spark.authenticate", "true") + intercept[IllegalArgumentException] { + new SecurityManager(conf) + } + } + + test("secret authentication key") { + val key = "very secret key" + val conf = new SparkConf() + .set(SecurityManager.SPARK_AUTH_CONF, "true") + .set(SecurityManager.SPARK_AUTH_SECRET_CONF, key) + assert(key === new SecurityManager(conf).getSecretKey()) + + val keyFromEnv = "very secret key from env" + val conf2 = new SparkConfWithEnv(Map(SecurityManager.ENV_AUTH_SECRET -> keyFromEnv)) + .set(SecurityManager.SPARK_AUTH_CONF, "true") + .set(SecurityManager.SPARK_AUTH_SECRET_CONF, key) + assert(keyFromEnv === new SecurityManager(conf2).getSecretKey()) + } + } diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala 
b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index d91b799ecfc0..0de10ae48537 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -17,12 +17,16 @@ package org.apache.spark +import java.util.concurrent.{Callable, Executors, ExecutorService, CyclicBarrier} + import org.scalatest.Matchers import org.apache.spark.ShuffleSuite.NonJavaSerializableClass +import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.rdd.{CoGroupedRDD, OrderedRDDFunctions, RDD, ShuffledRDD, SubtractedRDD} -import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} +import org.apache.spark.scheduler.{MyRDD, MapStatus, SparkListener, SparkListenerTaskEnd} import org.apache.spark.serializer.KryoSerializer +import org.apache.spark.shuffle.ShuffleWriter import org.apache.spark.storage.{ShuffleDataBlockId, ShuffleBlockId} import org.apache.spark.util.MutablePair @@ -247,11 +251,13 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC .setMaster("local") .set("spark.shuffle.spill.compress", shuffleSpillCompress.toString) .set("spark.shuffle.compress", shuffleCompress.toString) - .set("spark.shuffle.memoryFraction", "0.001") resetSparkContext() sc = new SparkContext(myConf) + val diskBlockManager = sc.env.blockManager.diskBlockManager try { - sc.parallelize(0 until 100000).map(i => (i / 4, i)).groupByKey().collect() + assert(diskBlockManager.getAllFiles().isEmpty) + sc.parallelize(0 until 10).map(i => (i / 4, i)).groupByKey().collect() + assert(diskBlockManager.getAllFiles().nonEmpty) } catch { case e: Exception => val errMsg = s"Failed with spark.shuffle.spill.compress=$shuffleSpillCompress," + @@ -315,6 +321,107 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC assert(metrics.bytesWritten === metrics.byresRead) assert(metrics.bytesWritten > 0) } + + test("multiple simultaneous attempts for one task (SPARK-8029)") { + sc = new SparkContext("local", "test", conf) + val mapTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] + val manager = sc.env.shuffleManager + + val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0L) + val metricsSystem = sc.env.metricsSystem + val shuffleMapRdd = new MyRDD(sc, 1, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1)) + val shuffleHandle = manager.registerShuffle(0, 1, shuffleDep) + + // first attempt -- it's successful + val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0, + new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, metricsSystem, + InternalAccumulator.create(sc))) + val data1 = (1 to 10).map { x => x -> x} + + // second attempt -- also successful. We'll write out different data, + // just to simulate the fact that the records may get written differently + // depending on what gets spilled, what gets combined, etc.
+ val writer2 = manager.getWriter[Int, Int](shuffleHandle, 0, + new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, metricsSystem, + InternalAccumulator.create(sc))) + val data2 = (11 to 20).map { x => x -> x} + + // interleave writes of both attempts -- we want to test that both attempts can occur + // simultaneously, and everything is still OK + + def writeAndClose( + writer: ShuffleWriter[Int, Int])( + iter: Iterator[(Int, Int)]): Option[MapStatus] = { + val files = writer.write(iter) + writer.stop(true) + } + val interleaver = new InterleaveIterators( + data1, writeAndClose(writer1), data2, writeAndClose(writer2)) + val (mapOutput1, mapOutput2) = interleaver.run() + + // check that we can read the map output and it has the right data + assert(mapOutput1.isDefined) + assert(mapOutput2.isDefined) + assert(mapOutput1.get.location === mapOutput2.get.location) + assert(mapOutput1.get.getSizeForBlock(0) === mapOutput1.get.getSizeForBlock(0)) + + // register one of the map outputs -- doesn't matter which one + mapOutput1.foreach { case mapStatus => + mapTrackerMaster.registerMapOutputs(0, Array(mapStatus)) + } + + val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, + new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, metricsSystem, + InternalAccumulator.create(sc))) + val readData = reader.read().toIndexedSeq + assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq) + + manager.unregisterShuffle(0) + } +} + +/** + * Utility to help tests make sure that we can process two different iterators simultaneously + * in different threads. This makes sure that in your test, you don't completely process data1 with + * f1 before processing data2 with f2 (or vice versa). It adds a barrier so that the functions only + * process one element, before pausing to wait for the other function to "catch up". + */ +class InterleaveIterators[T, R]( + data1: Seq[T], + f1: Iterator[T] => R, + data2: Seq[T], + f2: Iterator[T] => R) { + + require(data1.size == data2.size) + + val barrier = new CyclicBarrier(2) + class BarrierIterator[E](id: Int, sub: Iterator[E]) extends Iterator[E] { + def hasNext: Boolean = sub.hasNext + + def next: E = { + barrier.await() + sub.next() + } + } + + val c1 = new Callable[R] { + override def call(): R = f1(new BarrierIterator(1, data1.iterator)) + } + val c2 = new Callable[R] { + override def call(): R = f2(new BarrierIterator(2, data2.iterator)) + } + + val e: ExecutorService = Executors.newFixedThreadPool(2) + + def run(): (R, R) = { + val future1 = e.submit(c1) + val future2 = e.submit(c2) + val r1 = future1.get() + val r2 = future2.get() + e.shutdown() + (r1, r2) + } } object ShuffleSuite { diff --git a/core/src/test/scala/org/apache/spark/Smuggle.scala b/core/src/test/scala/org/apache/spark/Smuggle.scala new file mode 100644 index 000000000000..01694a6e6f74 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/Smuggle.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.util.UUID +import java.util.concurrent.locks.ReentrantReadWriteLock + +import scala.collection.mutable + +/** + * Utility wrapper to "smuggle" objects into tasks while bypassing serialization. + * This is intended for testing purposes, primarily to make locks, semaphores, and + * other constructs that would not survive serialization available from within tasks. + * A Smuggle reference is itself serializable, but after being serialized and + * deserialized, it still refers to the same underlying "smuggled" object, as long + * as it was deserialized within the same JVM. This can be useful for tests that + * depend on the timing of task completion to be deterministic, since one can "smuggle" + * a lock or semaphore into the task, and then the task can block until the test gives + * the go-ahead to proceed via the lock. + */ +class Smuggle[T] private(val key: Symbol) extends Serializable { + def smuggledObject: T = Smuggle.get(key) +} + + +object Smuggle { + /** + * Wraps the specified object to be smuggled into a serialized task without + * being serialized itself. + * + * @param smuggledObject + * @tparam T + * @return Smuggle wrapper around smuggledObject. + */ + def apply[T](smuggledObject: T): Smuggle[T] = { + val key = Symbol(UUID.randomUUID().toString) + lock.writeLock().lock() + try { + smuggledObjects += key -> smuggledObject + } finally { + lock.writeLock().unlock() + } + new Smuggle(key) + } + + private val lock = new ReentrantReadWriteLock + private val smuggledObjects = mutable.WeakHashMap.empty[Symbol, Any] + + private def get[T](key: Symbol) : T = { + lock.readLock().lock() + try { + smuggledObjects(key).asInstanceOf[T] + } finally { + lock.readLock().unlock() + } + } + + /** + * Implicit conversion of a Smuggle wrapper to the object being smuggled. + * + * @param smuggle the wrapper to unpack. + * @tparam T + * @return the smuggled object represented by the wrapper. + */ + implicit def unpackSmuggledObject[T](smuggle : Smuggle[T]): T = smuggle.smuggledObject + +} diff --git a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala index 63358172ea1f..b8ab227517cc 100644 --- a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala @@ -17,13 +17,78 @@ package org.apache.spark +import java.io.File + +import scala.collection.JavaConverters._ + +import org.apache.commons.io.FileUtils +import org.apache.commons.io.filefilter.TrueFileFilter import org.scalatest.BeforeAndAfterAll +import org.apache.spark.rdd.ShuffledRDD +import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} +import org.apache.spark.util.Utils + class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { // This test suite should run all tests in ShuffleSuite with sort-based shuffle. 
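+ // Each test gets its own temporary directory, used as spark.local.dir (set in beforeEach and deleted in afterEach), so the cleanup tests below can check exactly which files a shuffle creates.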
+ private var tempDir: File = _ + override def beforeAll() { conf.set("spark.shuffle.manager", "sort") } + + override def beforeEach(): Unit = { + tempDir = Utils.createTempDir() + conf.set("spark.local.dir", tempDir.getAbsolutePath) + } + + override def afterEach(): Unit = { + try { + Utils.deleteRecursively(tempDir) + } finally { + super.afterEach() + } + } + + test("SortShuffleManager properly cleans up files for shuffles that use the serialized path") { + sc = new SparkContext("local", "test", conf) + // Create a shuffled RDD and verify that it actually uses the new serialized map output path + val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x)) + val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4)) + .setSerializer(new KryoSerializer(conf)) + val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + assert(SortShuffleManager.canUseSerializedShuffle(shuffleDep)) + ensureFilesAreCleanedUp(shuffledRdd) + } + + test("SortShuffleManager properly cleans up files for shuffles that use the deserialized path") { + sc = new SparkContext("local", "test", conf) + // Create a shuffled RDD and verify that it actually uses the old deserialized map output path + val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x)) + val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4)) + .setSerializer(new JavaSerializer(conf)) + val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + assert(!SortShuffleManager.canUseSerializedShuffle(shuffleDep)) + ensureFilesAreCleanedUp(shuffledRdd) + } + + private def ensureFilesAreCleanedUp(shuffledRdd: ShuffledRDD[_, _, _]): Unit = { + def getAllFiles: Set[File] = + FileUtils.listFiles(tempDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet + val filesBeforeShuffle = getAllFiles + // Force the shuffle to be performed + shuffledRdd.count() + // Ensure that the shuffle actually created files that will need to be cleaned up + val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle + filesCreatedByShuffle.map(_.getName) should be ( + Set("shuffle_0_0_0.data", "shuffle_0_0_0.index")) + // Check that the cleanup actually removes the files + sc.env.blockManager.master.removeShuffle(0, blocking = true) + for (file <- filesCreatedByShuffle) { + assert (!file.exists(), s"Shuffle file $file was not cleaned up") + } + } } diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index 90cb7da94e88..ff9a92cc0a42 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark import java.util.concurrent.{TimeUnit, Executors} +import scala.collection.JavaConverters._ import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.{Try, Random} @@ -148,7 +149,6 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst } test("Thread safeness - SPARK-5425") { - import scala.collection.JavaConversions._ val executor = Executors.newSingleThreadScheduledExecutor() val sf = executor.scheduleAtFixedRate(new Runnable { override def run(): Unit = @@ -163,8 +163,9 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst } } finally { executor.shutdownNow() - for (key <- System.getProperties.stringPropertyNames() if key.startsWith("spark.5425.")) - System.getProperties.remove(key) + val sysProps =
System.getProperties + for (key <- sysProps.stringPropertyNames().asScala if key.startsWith("spark.5425.")) + sysProps.remove(key) } } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index e5a14a69ef05..d18e0782c039 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -175,6 +175,11 @@ class SparkContextSchedulerCreationSuite } test("mesos with zookeeper") { + testMesos("mesos://zk://localhost:1234,localhost:2345", + classOf[MesosSchedulerBackend], coarse = false) + } + + test("mesos with zookeeper and Master URL starting with zk://") { testMesos("zk://localhost:1234,localhost:2345", classOf[MesosSchedulerBackend], coarse = false) } } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 5c57940fa5f7..d4f2ea87650a 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -285,4 +285,12 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext { } } + test("No exception when both num-executors and dynamic allocation set.") { + noException should be thrownBy { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local") + .set("spark.dynamicAllocation.enabled", "true").set("spark.executor.instances", "6")) + assert(sc.executorAllocationManager.isEmpty) + assert(sc.getConf.getInt("spark.executor.instances", 0) === 6) + } + } } diff --git a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala index 46516e8d2529..5483f2b8434a 100644 --- a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala @@ -86,4 +86,30 @@ class StatusTrackerSuite extends SparkFunSuite with Matchers with LocalSparkCont Set(firstJobId, secondJobId)) } } + + test("getJobIdsForGroup() with takeAsync()") { + sc = new SparkContext("local", "test", new SparkConf(false)) + sc.setJobGroup("my-job-group2", "description") + sc.statusTracker.getJobIdsForGroup("my-job-group2") shouldBe empty + val firstJobFuture = sc.parallelize(1 to 1000, 1).takeAsync(1) + val firstJobId = eventually(timeout(10 seconds)) { + firstJobFuture.jobIds.head + } + eventually(timeout(10 seconds)) { + sc.statusTracker.getJobIdsForGroup("my-job-group2") should be (Seq(firstJobId)) + } + } + + test("getJobIdsForGroup() with takeAsync() across multiple partitions") { + sc = new SparkContext("local", "test", new SparkConf(false)) + sc.setJobGroup("my-job-group2", "description") + sc.statusTracker.getJobIdsForGroup("my-job-group2") shouldBe empty + val firstJobFuture = sc.parallelize(1 to 1000, 2).takeAsync(999) + val firstJobId = eventually(timeout(10 seconds)) { + firstJobFuture.jobIds.head + } + eventually(timeout(10 seconds)) { + sc.statusTracker.getJobIdsForGroup("my-job-group2") should have size 2 + } + } } diff --git a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala index 48509f0759a3..54c131cdae36 100644 --- a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala +++ b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala @@ -119,27 +119,35 @@ class ThreadingSuite extends SparkFunSuite 
with LocalSparkContext with Logging { val nums = sc.parallelize(1 to 2, 2) val sem = new Semaphore(0) ThreadingSuiteState.clear() + var throwable: Option[Throwable] = None for (i <- 0 until 2) { new Thread { override def run() { - val ans = nums.map(number => { - val running = ThreadingSuiteState.runningThreads - running.getAndIncrement() - val time = System.currentTimeMillis() - while (running.get() != 4 && System.currentTimeMillis() < time + 1000) { - Thread.sleep(100) - } - if (running.get() != 4) { - ThreadingSuiteState.failed.set(true) - } - number - }).collect() - assert(ans.toList === List(1, 2)) - sem.release() + try { + val ans = nums.map(number => { + val running = ThreadingSuiteState.runningThreads + running.getAndIncrement() + val time = System.currentTimeMillis() + while (running.get() != 4 && System.currentTimeMillis() < time + 1000) { + Thread.sleep(100) + } + if (running.get() != 4) { + ThreadingSuiteState.failed.set(true) + } + number + }).collect() + assert(ans.toList === List(1, 2)) + } catch { + case t: Throwable => + throwable = Some(t) + } finally { + sem.release() + } } }.start() } sem.acquire(2) + throwable.foreach { t => throw improveStackTrace(t) } if (ThreadingSuiteState.failed.get()) { logError("Waited 1 second without seeing runningThreads = 4 (it was " + ThreadingSuiteState.runningThreads.get() + "); failing test") @@ -150,13 +158,19 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { test("set local properties in different thread") { sc = new SparkContext("local", "test") val sem = new Semaphore(0) - + var throwable: Option[Throwable] = None val threads = (1 to 5).map { i => new Thread() { override def run() { - sc.setLocalProperty("test", i.toString) - assert(sc.getLocalProperty("test") === i.toString) - sem.release() + try { + sc.setLocalProperty("test", i.toString) + assert(sc.getLocalProperty("test") === i.toString) + } catch { + case t: Throwable => + throwable = Some(t) + } finally { + sem.release() + } } } } @@ -164,6 +178,7 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { threads.foreach(_.start()) sem.acquire(5) + throwable.foreach { t => throw improveStackTrace(t) } assert(sc.getLocalProperty("test") === null) } @@ -171,14 +186,20 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { sc = new SparkContext("local", "test") sc.setLocalProperty("test", "parent") val sem = new Semaphore(0) - + var throwable: Option[Throwable] = None val threads = (1 to 5).map { i => new Thread() { override def run() { - assert(sc.getLocalProperty("test") === "parent") - sc.setLocalProperty("test", i.toString) - assert(sc.getLocalProperty("test") === i.toString) - sem.release() + try { + assert(sc.getLocalProperty("test") === "parent") + sc.setLocalProperty("test", i.toString) + assert(sc.getLocalProperty("test") === i.toString) + } catch { + case t: Throwable => + throwable = Some(t) + } finally { + sem.release() + } } } } @@ -186,50 +207,41 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { threads.foreach(_.start()) sem.acquire(5) + throwable.foreach { t => throw improveStackTrace(t) } assert(sc.getLocalProperty("test") === "parent") assert(sc.getLocalProperty("Foo") === null) } - test("mutations to local properties should not affect submitted jobs (SPARK-6629)") { - val jobStarted = new Semaphore(0) - val jobEnded = new Semaphore(0) - @volatile var jobResult: JobResult = null - + test("mutation in parent local property does not affect 
child (SPARK-10563)") { sc = new SparkContext("local", "test") - sc.setJobGroup("originalJobGroupId", "description") - sc.addSparkListener(new SparkListener { - override def onJobStart(jobStart: SparkListenerJobStart): Unit = { - jobStarted.release() - } - override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { - jobResult = jobEnd.jobResult - jobEnded.release() - } - }) - - // Create a new thread which will inherit the current thread's properties - val thread = new Thread() { + val originalTestValue: String = "original-value" + var threadTestValue: String = null + sc.setLocalProperty("test", originalTestValue) + var throwable: Option[Throwable] = None + val thread = new Thread { override def run(): Unit = { - assert(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) === "originalJobGroupId") - // Sleeps for a total of 10 seconds, but allows cancellation to interrupt the task try { - sc.parallelize(1 to 100).foreach { x => - Thread.sleep(100) - } + threadTestValue = sc.getLocalProperty("test") } catch { - case s: SparkException => // ignored so that we don't print noise in test logs + case t: Throwable => + throwable = Some(t) } } } + sc.setLocalProperty("test", "this-should-not-be-inherited") thread.start() - // Wait for the job to start, then mutate the original properties, which should have been - // inherited by the running job but hopefully defensively copied or snapshotted: - jobStarted.tryAcquire(10, TimeUnit.SECONDS) - sc.setJobGroup("modifiedJobGroupId", "description") - // Canceling the original job group should cancel the running job. In other words, the - // modification of the properties object should not affect the properties of running jobs - sc.cancelJobGroup("originalJobGroupId") - jobEnded.tryAcquire(10, TimeUnit.SECONDS) - assert(jobResult.isInstanceOf[JobFailed]) + thread.join() + throwable.foreach { t => throw improveStackTrace(t) } + assert(threadTestValue === originalTestValue) + } + + /** + * Improve the stack trace of an error thrown from within a thread. + * Otherwise it's difficult to tell which line in the test the error came from. 
+ */ + private def improveStackTrace(t: Throwable): Throwable = { + t.setStackTrace(t.getStackTrace ++ Thread.currentThread.getStackTrace) + t } + } diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index 48e74f06f79b..ba21075ce6be 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -310,8 +310,14 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { val _sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", broadcastConf) // Wait until all salves are up - _sc.jobProgressListener.waitUntilExecutorsUp(numSlaves, 10000) - _sc + try { + _sc.jobProgressListener.waitUntilExecutorsUp(numSlaves, 60000) + _sc + } catch { + case e: Throwable => + _sc.stop() + throw e + } } else { new SparkContext("local", "test", broadcastConf) } diff --git a/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala b/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala index 967aa0976f0c..3164760b08a7 100644 --- a/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala +++ b/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala @@ -31,8 +31,9 @@ private[deploy] object DeployTestUtils { } def createAppInfo() : ApplicationInfo = { + val appDesc = createAppDesc() val appInfo = new ApplicationInfo(JsonConstants.appInfoStartTime, - "id", createAppDesc(), JsonConstants.submitDate, null, Int.MaxValue) + "id", appDesc, JsonConstants.submitDate, null, Int.MaxValue) appInfo.endTime = JsonConstants.currTimeInMillis appInfo } diff --git a/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala b/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala index 823050b0aabb..d93febcfd23f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala +++ b/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala @@ -19,6 +19,10 @@ package org.apache.spark.deploy import java.io.{File, FileInputStream, FileOutputStream} import java.util.jar.{JarEntry, JarOutputStream} +import java.util.jar.Attributes.Name +import java.util.jar.Manifest + +import scala.collection.mutable.ArrayBuffer import com.google.common.io.{Files, ByteStreams} @@ -35,7 +39,7 @@ private[deploy] object IvyTestUtils { * Create the path for the jar and pom from the maven coordinate. Extension should be `jar` * or `pom`. */ - private def pathFromCoordinate( + private[deploy] def pathFromCoordinate( artifact: MavenCoordinate, prefix: File, ext: String, @@ -52,7 +56,7 @@ private[deploy] object IvyTestUtils { } /** Returns the artifact naming based on standard ivy or maven format. */ - private def artifactName( + private[deploy] def artifactName( artifact: MavenCoordinate, useIvyLayout: Boolean, ext: String = ".jar"): String = { @@ -73,7 +77,7 @@ private[deploy] object IvyTestUtils { } /** Write the contents to a file to the supplied directory. */ - private def writeFile(dir: File, fileName: String, contents: String): File = { + private[deploy] def writeFile(dir: File, fileName: String, contents: String): File = { val outputFile = new File(dir, fileName) val outputStream = new FileOutputStream(outputFile) outputStream.write(contents.toCharArray.map(_.toByte)) @@ -90,6 +94,42 @@ private[deploy] object IvyTestUtils { writeFile(dir, "mylib.py", contents) } + /** Create an example R package that calls the given Java class. 
*/ + private def createRFiles( + dir: File, + className: String, + packageName: String): Seq[(String, File)] = { + val rFilesDir = new File(dir, "R" + File.separator + "pkg") + Files.createParentDirs(new File(rFilesDir, "R" + File.separator + "mylib.R")) + val contents = + s"""myfunc <- function(x) { + | SparkR:::callJStatic("$packageName.$className", "myFunc", x) + |} + """.stripMargin + val source = writeFile(new File(rFilesDir, "R"), "mylib.R", contents) + val description = + """Package: sparkPackageTest + |Type: Package + |Title: Test for building an R package + |Version: 0.1 + |Date: 2015-07-08 + |Author: Burak Yavuz + |Imports: methods, SparkR + |Depends: R (>= 3.1), methods, SparkR + |Suggests: testthat + |Description: Test for building an R package within a jar + |License: Apache License (== 2.0) + |Collate: 'mylib.R' + """.stripMargin + val descFile = writeFile(rFilesDir, "DESCRIPTION", description) + val namespace = + """import(SparkR) + |export("myfunc") + """.stripMargin + val nameFile = writeFile(rFilesDir, "NAMESPACE", namespace) + Seq(("R/pkg/R/mylib.R", source), ("R/pkg/DESCRIPTION", descFile), ("R/pkg/NAMESPACE", nameFile)) + } + /** Create a simple testable Class. */ private def createJavaClass(dir: File, className: String, packageName: String): File = { val contents = @@ -97,17 +137,14 @@ private[deploy] object IvyTestUtils { | |import java.lang.Integer; | - |class $className implements java.io.Serializable { - | - | public $className() {} - | - | public Integer myFunc(Integer x) { + |public class $className implements java.io.Serializable { + | public static Integer myFunc(Integer x) { | return x + 1; | } |} """.stripMargin val sourceFile = - new JavaSourceFromString(new File(dir, className + ".java").getAbsolutePath, contents) + new JavaSourceFromString(new File(dir, className).getAbsolutePath, contents) createCompiledClass(className, dir, sourceFile, Seq.empty) } @@ -199,14 +236,25 @@ private[deploy] object IvyTestUtils { } /** Create the jar for the given maven coordinate, using the supplied files. 
*/ - private def packJar( + private[deploy] def packJar( dir: File, artifact: MavenCoordinate, files: Seq[(String, File)], - useIvyLayout: Boolean): File = { + useIvyLayout: Boolean, + withR: Boolean, + withManifest: Option[Manifest] = None): File = { val jarFile = new File(dir, artifactName(artifact, useIvyLayout)) val jarFileStream = new FileOutputStream(jarFile) - val jarStream = new JarOutputStream(jarFileStream, new java.util.jar.Manifest()) + val manifest = withManifest.getOrElse { + val mani = new Manifest() + if (withR) { + val attr = mani.getMainAttributes + attr.put(Name.MANIFEST_VERSION, "1.0") + attr.put(new Name("Spark-HasRPackage"), "true") + } + mani + } + val jarStream = new JarOutputStream(jarFileStream, manifest) for (file <- files) { val jarEntry = new JarEntry(file._1) @@ -239,7 +287,8 @@ private[deploy] object IvyTestUtils { dependencies: Option[Seq[MavenCoordinate]] = None, tempDir: Option[File] = None, useIvyLayout: Boolean = false, - withPython: Boolean = false): File = { + withPython: Boolean = false, + withR: Boolean = false): File = { // Where the root of the repository exists, and what Ivy will search in val tempPath = tempDir.getOrElse(Files.createTempDir()) // Create directory if it doesn't exist @@ -255,14 +304,16 @@ private[deploy] object IvyTestUtils { val javaClass = createJavaClass(root, className, artifact.groupId) // A tuple of files representation in the jar, and the file val javaFile = (artifact.groupId.replace(".", "/") + "/" + javaClass.getName, javaClass) - val allFiles = - if (withPython) { - val pythonFile = createPythonFile(root) - Seq(javaFile, (pythonFile.getName, pythonFile)) - } else { - Seq(javaFile) - } - val jarFile = packJar(jarPath, artifact, allFiles, useIvyLayout) + val allFiles = ArrayBuffer[(String, File)](javaFile) + if (withPython) { + val pythonFile = createPythonFile(root) + allFiles.append((pythonFile.getName, pythonFile)) + } + if (withR) { + val rFiles = createRFiles(root, className, artifact.groupId) + allFiles.append(rFiles: _*) + } + val jarFile = packJar(jarPath, artifact, allFiles, useIvyLayout, withR) assert(jarFile.exists(), "Problem creating Jar file") val descriptor = createDescriptor(tempPath, artifact, dependencies, useIvyLayout) assert(descriptor.exists(), "Problem creating Pom file") @@ -286,9 +337,10 @@ private[deploy] object IvyTestUtils { dependencies: Option[String], rootDir: Option[File], useIvyLayout: Boolean = false, - withPython: Boolean = false): File = { + withPython: Boolean = false, + withR: Boolean = false): File = { val deps = dependencies.map(SparkSubmitUtils.extractMavenCoordinates) - val mainRepo = createLocalRepository(artifact, deps, rootDir, useIvyLayout, withPython) + val mainRepo = createLocalRepository(artifact, deps, rootDir, useIvyLayout, withPython, withR) deps.foreach { seq => seq.foreach { dep => createLocalRepository(dep, None, Some(mainRepo), useIvyLayout, withPython = false) }} @@ -311,11 +363,12 @@ private[deploy] object IvyTestUtils { rootDir: Option[File], useIvyLayout: Boolean = false, withPython: Boolean = false, + withR: Boolean = false, ivySettings: IvySettings = new IvySettings)(f: String => Unit): Unit = { val deps = dependencies.map(SparkSubmitUtils.extractMavenCoordinates) purgeLocalIvyCache(artifact, deps, ivySettings) val repo = createLocalRepositoryForTests(artifact, dependencies, rootDir, useIvyLayout, - withPython) + withPython, withR) try { f(repo.toURI.toString) } finally { diff --git a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala 
b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala index cbd2aee10c0e..8dd31b4b6fdd 100644 --- a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala @@ -19,13 +19,13 @@ package org.apache.spark.deploy import java.net.URL -import scala.collection.JavaConversions._ import scala.collection.mutable import scala.io.Source import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.scheduler.{SparkListenerExecutorAdded, SparkListener} import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.util.SparkConfWithEnv class LogUrlsStandaloneSuite extends SparkFunSuite with LocalSparkContext { @@ -54,17 +54,7 @@ class LogUrlsStandaloneSuite extends SparkFunSuite with LocalSparkContext { test("verify that log urls reflect SPARK_PUBLIC_DNS (SPARK-6175)") { val SPARK_PUBLIC_DNS = "public_dns" - class MySparkConf extends SparkConf(false) { - override def getenv(name: String): String = { - if (name == "SPARK_PUBLIC_DNS") SPARK_PUBLIC_DNS - else super.getenv(name) - } - - override def clone: SparkConf = { - new MySparkConf().setAll(getAll) - } - } - val conf = new MySparkConf().set( + val conf = new SparkConfWithEnv(Map("SPARK_PUBLIC_DNS" -> SPARK_PUBLIC_DNS)).set( "spark.extraListeners", classOf[SaveExecutorInfo].getName) sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) diff --git a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala new file mode 100644 index 000000000000..cc30ba223e1c --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy + +import java.io.{PrintStream, OutputStream, File} +import java.net.URI +import java.util.jar.Attributes.Name +import java.util.jar.{JarFile, Manifest} +import java.util.zip.ZipFile + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import com.google.common.io.Files +import org.apache.commons.io.FileUtils +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.SparkFunSuite +import org.apache.spark.api.r.RUtils +import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate +import org.apache.spark.util.ResetSystemProperties + +class RPackageUtilsSuite + extends SparkFunSuite + with BeforeAndAfterEach + with ResetSystemProperties { + + private val main = MavenCoordinate("a", "b", "c") + private val dep1 = MavenCoordinate("a", "dep1", "c") + private val dep2 = MavenCoordinate("a", "dep2", "d") + + private def getJarPath(coord: MavenCoordinate, repo: File): File = { + new File(IvyTestUtils.pathFromCoordinate(coord, repo, "jar", useIvyLayout = false), + IvyTestUtils.artifactName(coord, useIvyLayout = false, ".jar")) + } + + private val lineBuffer = ArrayBuffer[String]() + + private val noOpOutputStream = new OutputStream { + def write(b: Int) = {} + } + + /** Simple PrintStream that reads data into a buffer */ + private class BufferPrintStream extends PrintStream(noOpOutputStream) { + // scalastyle:off println + override def println(line: String) { + // scalastyle:on println + lineBuffer += line + } + } + + override def beforeEach(): Unit = { + super.beforeEach() + System.setProperty("spark.testing", "true") + lineBuffer.clear() + } + + test("pick which jars to unpack using the manifest") { + val deps = Seq(dep1, dep2).mkString(",") + IvyTestUtils.withRepository(main, Some(deps), None, withR = true) { repo => + val jars = Seq(main, dep1, dep2).map(c => new JarFile(getJarPath(c, new File(new URI(repo))))) + assert(RPackageUtils.checkManifestForR(jars(0)), "should have R code") + assert(!RPackageUtils.checkManifestForR(jars(1)), "should not have R code") + assert(!RPackageUtils.checkManifestForR(jars(2)), "should not have R code") + } + } + + test("build an R package from a jar end to end") { + assume(RUtils.isRInstalled, "R isn't installed on this machine.") + val deps = Seq(dep1, dep2).mkString(",") + IvyTestUtils.withRepository(main, Some(deps), None, withR = true) { repo => + val jars = Seq(main, dep1, dep2).map { c => + getJarPath(c, new File(new URI(repo))) + }.mkString(",") + RPackageUtils.checkAndBuildRPackage(jars, new BufferPrintStream, verbose = true) + val firstJar = jars.substring(0, jars.indexOf(",")) + val output = lineBuffer.mkString("\n") + assert(output.contains("Building R package")) + assert(output.contains("Extracting")) + assert(output.contains(s"$firstJar contains R source code. 
Now installing package.")) + assert(output.contains("doesn't contain R source code, skipping...")) + } + } + + test("jars that don't exist are skipped and print warning") { + assume(RUtils.isRInstalled, "R isn't installed on this machine.") + val deps = Seq(dep1, dep2).mkString(",") + IvyTestUtils.withRepository(main, Some(deps), None, withR = true) { repo => + val jars = Seq(main, dep1, dep2).map { c => + getJarPath(c, new File(new URI(repo))) + "dummy" + }.mkString(",") + RPackageUtils.checkAndBuildRPackage(jars, new BufferPrintStream, verbose = true) + val individualJars = jars.split(",") + val output = lineBuffer.mkString("\n") + individualJars.foreach { jarFile => + assert(output.contains(s"$jarFile")) + } + } + } + + test("faulty R package shows documentation") { + assume(RUtils.isRInstalled, "R isn't installed on this machine.") + IvyTestUtils.withRepository(main, None, None) { repo => + val manifest = new Manifest + val attr = manifest.getMainAttributes + attr.put(Name.MANIFEST_VERSION, "1.0") + attr.put(new Name("Spark-HasRPackage"), "true") + val jar = IvyTestUtils.packJar(new File(new URI(repo)), dep1, Nil, + useIvyLayout = false, withR = false, Some(manifest)) + RPackageUtils.checkAndBuildRPackage(jar.getAbsolutePath, new BufferPrintStream, + verbose = true) + val output = lineBuffer.mkString("\n") + assert(output.contains(RPackageUtils.RJarDoc)) + } + } + + test("SparkR zipping works properly") { + val tempDir = Files.createTempDir() + try { + IvyTestUtils.writeFile(tempDir, "test.R", "abc") + val fakeSparkRDir = new File(tempDir, "SparkR") + assert(fakeSparkRDir.mkdirs()) + IvyTestUtils.writeFile(fakeSparkRDir, "abc.R", "abc") + IvyTestUtils.writeFile(fakeSparkRDir, "DESCRIPTION", "abc") + IvyTestUtils.writeFile(tempDir, "package.zip", "abc") // fake zip file :) + val fakePackageDir = new File(tempDir, "packageTest") + assert(fakePackageDir.mkdirs()) + IvyTestUtils.writeFile(fakePackageDir, "def.R", "abc") + IvyTestUtils.writeFile(fakePackageDir, "DESCRIPTION", "abc") + val finalZip = RPackageUtils.zipRLibraries(tempDir, "sparkr.zip") + assert(finalZip.exists()) + val entries = new ZipFile(finalZip).entries().asScala.map(_.getName).toSeq + assert(entries.contains("/test.R")) + assert(entries.contains("/SparkR/abc.R")) + assert(entries.contains("/SparkR/DESCRIPTION")) + assert(!entries.contains("/package.zip")) + assert(entries.contains("/packageTest/def.R")) + assert(entries.contains("/packageTest/DESCRIPTION")) + } finally { + FileUtils.deleteDirectory(tempDir) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index aa78bfe30974..2626f5a16dfb 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -23,11 +23,12 @@ import scala.collection.mutable.ArrayBuffer import com.google.common.base.Charsets.UTF_8 import com.google.common.io.ByteStreams -import org.scalatest.Matchers +import org.scalatest.{BeforeAndAfterEach, Matchers} import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ import org.apache.spark._ +import org.apache.spark.api.r.RUtils import org.apache.spark.deploy.SparkSubmit._ import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate import org.apache.spark.util.{ResetSystemProperties, Utils} @@ -37,10 +38,12 @@ import org.apache.spark.util.{ResetSystemProperties, Utils} class SparkSubmitSuite extends SparkFunSuite with Matchers + 
with BeforeAndAfterEach with ResetSystemProperties with Timeouts { - def beforeAll() { + override def beforeEach() { + super.beforeEach() System.setProperty("spark.testing", "true") } @@ -133,6 +136,47 @@ class SparkSubmitSuite appArgs.childArgs should be (Seq("--master", "local", "some", "--weird", "args")) } + test("specify deploy mode through configuration") { + val clArgs = Seq( + "--master", "yarn", + "--conf", "spark.submit.deployMode=client", + "--class", "org.SomeClass", + "thejar.jar" + ) + val appArgs = new SparkSubmitArguments(clArgs) + val (_, _, sysProps, _) = prepareSubmitEnvironment(appArgs) + + appArgs.deployMode should be ("client") + sysProps("spark.submit.deployMode") should be ("client") + + // Both cmd line and configuration are specified, cmdline option takes the priority + val clArgs1 = Seq( + "--master", "yarn", + "--deploy-mode", "cluster", + "--conf", "spark.submit.deployMode=client", + "--class", "org.SomeClass", + "thejar.jar" + ) + val appArgs1 = new SparkSubmitArguments(clArgs1) + val (_, _, sysProps1, _) = prepareSubmitEnvironment(appArgs1) + + appArgs1.deployMode should be ("cluster") + sysProps1("spark.submit.deployMode") should be ("cluster") + + // Neither cmdline nor configuration are specified, client mode is the default choice + val clArgs2 = Seq( + "--master", "yarn", + "--class", "org.SomeClass", + "thejar.jar" + ) + val appArgs2 = new SparkSubmitArguments(clArgs2) + appArgs2.deployMode should be (null) + + val (_, _, sysProps2, _) = prepareSubmitEnvironment(appArgs2) + appArgs2.deployMode should be ("client") + sysProps2("spark.submit.deployMode") should be ("client") + } + test("handles YARN cluster mode") { val clArgs = Seq( "--deploy-mode", "cluster", @@ -147,7 +191,7 @@ class SparkSubmitSuite "--archives", "archive1.txt,archive2.txt", "--num-executors", "6", "--name", "beauty", - "--conf", "spark.shuffle.spill=false", + "--conf", "spark.ui.enabled=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -159,7 +203,6 @@ class SparkSubmitSuite childArgsStr should include ("--executor-cores 5") childArgsStr should include ("--arg arg1 --arg arg2") childArgsStr should include ("--queue thequeue") - childArgsStr should include ("--num-executors 6") childArgsStr should include regex ("--jar .*thejar.jar") childArgsStr should include regex ("--addJars .*one.jar,.*two.jar,.*three.jar") childArgsStr should include regex ("--files .*file1.txt,.*file2.txt") @@ -167,7 +210,7 @@ class SparkSubmitSuite mainClass should be ("org.apache.spark.deploy.yarn.Client") classpath should have length (0) sysProps("spark.app.name") should be ("beauty") - sysProps("spark.shuffle.spill") should be ("false") + sysProps("spark.ui.enabled") should be ("false") sysProps("SPARK_SUBMIT") should be ("true") sysProps.keys should not contain ("spark.jars") } @@ -186,7 +229,7 @@ class SparkSubmitSuite "--archives", "archive1.txt,archive2.txt", "--num-executors", "6", "--name", "trill", - "--conf", "spark.shuffle.spill=false", + "--conf", "spark.ui.enabled=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -207,7 +250,7 @@ class SparkSubmitSuite sysProps("spark.yarn.dist.archives") should include regex (".*archive1.txt,.*archive2.txt") sysProps("spark.jars") should include regex (".*one.jar,.*two.jar,.*three.jar,.*thejar.jar") sysProps("SPARK_SUBMIT") should be ("true") - sysProps("spark.shuffle.spill") should be ("false") + sysProps("spark.ui.enabled") should be ("false") } test("handles standalone cluster mode") {
@@ -230,7 +273,7 @@ class SparkSubmitSuite "--supervise", "--driver-memory", "4g", "--driver-cores", "5", - "--conf", "spark.shuffle.spill=false", + "--conf", "spark.ui.enabled=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -254,9 +297,9 @@ class SparkSubmitSuite sysProps.keys should contain ("spark.driver.memory") sysProps.keys should contain ("spark.driver.cores") sysProps.keys should contain ("spark.driver.supervise") - sysProps.keys should contain ("spark.shuffle.spill") + sysProps.keys should contain ("spark.ui.enabled") sysProps.keys should contain ("spark.submit.deployMode") - sysProps("spark.shuffle.spill") should be ("false") + sysProps("spark.ui.enabled") should be ("false") } test("handles standalone client mode") { @@ -267,7 +310,7 @@ class SparkSubmitSuite "--total-executor-cores", "5", "--class", "org.SomeClass", "--driver-memory", "4g", - "--conf", "spark.shuffle.spill=false", + "--conf", "spark.ui.enabled=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -278,7 +321,7 @@ class SparkSubmitSuite classpath(0) should endWith ("thejar.jar") sysProps("spark.executor.memory") should be ("5g") sysProps("spark.cores.max") should be ("5") - sysProps("spark.shuffle.spill") should be ("false") + sysProps("spark.ui.enabled") should be ("false") } test("handles mesos client mode") { @@ -289,7 +332,7 @@ class SparkSubmitSuite "--total-executor-cores", "5", "--class", "org.SomeClass", "--driver-memory", "4g", - "--conf", "spark.shuffle.spill=false", + "--conf", "spark.ui.enabled=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -300,7 +343,7 @@ class SparkSubmitSuite classpath(0) should endWith ("thejar.jar") sysProps("spark.executor.memory") should be ("5g") sysProps("spark.cores.max") should be ("5") - sysProps("spark.shuffle.spill") should be ("false") + sysProps("spark.ui.enabled") should be ("false") } test("handles confs with flag equivalents") { @@ -325,6 +368,8 @@ class SparkSubmitSuite "--class", SimpleApplicationTest.getClass.getName.stripSuffix("$"), "--name", "testApp", "--master", "local", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", unusedJar.toString) runSparkSubmit(args) } @@ -338,6 +383,8 @@ class SparkSubmitSuite "--class", JarCreationTest.getClass.getName.stripSuffix("$"), "--name", "testApp", "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", "--jars", jarsString, unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB") runSparkSubmit(args) @@ -356,12 +403,35 @@ class SparkSubmitSuite "--packages", Seq(main, dep).mkString(","), "--repositories", repo, "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", unusedJar.toString, "my.great.lib.MyLib", "my.great.dep.MyLib") runSparkSubmit(args) } } + // TODO(SPARK-9603): Building a package is flaky on Jenkins Maven builds. 
+ // See https://gist.github.com/shivaram/3a2fecce60768a603dac for an error log + ignore("correctly builds R packages included in a jar with --packages") { + assume(RUtils.isRInstalled, "R isn't installed on this machine.") + val main = MavenCoordinate("my.great.lib", "mylib", "0.1") + val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) + val rScriptDir = + Seq(sparkHome, "R", "pkg", "inst", "tests", "packageInAJarTest.R").mkString(File.separator) + assert(new File(rScriptDir).exists) + IvyTestUtils.withRepository(main, None, None, withR = true) { repo => + val args = Seq( + "--name", "testApp", + "--master", "local-cluster[2,1,1024]", + "--packages", main.toString, + "--repositories", repo, + "--verbose", + "--conf", "spark.ui.enabled=false", + rScriptDir) + runSparkSubmit(args) + } + } + test("resolves command line argument paths correctly") { val jars = "/jar1,/jar2" // --jars val files = "hdfs:/file1,file2" // --files @@ -477,6 +547,8 @@ class SparkSubmitSuite "--master", "local", "--conf", "spark.driver.extraClassPath=" + systemJar, "--conf", "spark.driver.userClassPathFirst=true", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", userJar.toString) runSparkSubmit(args) } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala index 01ece1a10f46..63c346c1b890 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -95,6 +95,25 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { assert(md.getDependencies.length === 2) } + test("excludes works correctly") { + val md = SparkSubmitUtils.getModuleDescriptor + val excludes = Seq("a:b", "c:d") + excludes.foreach { e => + md.addExcludeRule(SparkSubmitUtils.createExclusion(e + ":*", new IvySettings, "default")) + } + val rules = md.getAllExcludeRules + assert(rules.length === 2) + val rule1 = rules(0).getId.getModuleId + assert(rule1.getOrganisation === "a") + assert(rule1.getName === "b") + val rule2 = rules(1).getId.getModuleId + assert(rule2.getOrganisation === "c") + assert(rule2.getName === "d") + intercept[IllegalArgumentException] { + SparkSubmitUtils.createExclusion("e:f:g:h", new IvySettings, "default") + } + } + test("ivy path works correctly") { val md = SparkSubmitUtils.getModuleDescriptor val artifacts = for (i <- 0 until 3) yield new MDArtifact(md, s"jar-$i", "jar", "jar") @@ -168,4 +187,15 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { assert(files.indexOf(main.artifactId) >= 0, "Did not return artifact") } } + + test("exclude dependencies end to end") { + val main = new MavenCoordinate("my.great.lib", "mylib", "0.1") + val dep = "my.great.dep:mydep:0.5" + IvyTestUtils.withRepository(main, Some(dep), None) { repo => + val files = SparkSubmitUtils.resolveMavenCoordinates(main.toString, + Some(repo), None, Seq("my.great.dep:mydep"), isTest = true) + assert(files.indexOf(main.artifactId) >= 0, "Did not return artifact") + assert(files.indexOf("my.great.dep") < 0, "Returned excluded artifact") + } + } } diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index 08c41a897a86..2fa795f84666 100644 ---
a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -17,13 +17,20 @@ package org.apache.spark.deploy +import scala.collection.mutable +import scala.concurrent.duration._ + import org.mockito.Mockito.{mock, when} -import org.scalatest.BeforeAndAfterAll +import org.scalatest.{BeforeAndAfterAll, PrivateMethodTester} +import org.scalatest.concurrent.Eventually._ import org.apache.spark._ +import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} +import org.apache.spark.deploy.master.ApplicationInfo import org.apache.spark.deploy.master.Master import org.apache.spark.deploy.worker.Worker import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv} +import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.scheduler.cluster._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RegisterExecutor @@ -33,7 +40,8 @@ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RegisterE class StandaloneDynamicAllocationSuite extends SparkFunSuite with LocalSparkContext - with BeforeAndAfterAll { + with BeforeAndAfterAll + with PrivateMethodTester { private val numWorkers = 2 private val conf = new SparkConf() @@ -56,6 +64,10 @@ class StandaloneDynamicAllocationSuite } master = makeMaster() workers = makeWorkers(10, 2048) + // Wait until all workers register with master successfully + eventually(timeout(60.seconds), interval(10.millis)) { + assert(getMasterState.workers.size === numWorkers) + } } override def afterAll(): Unit = { @@ -73,167 +85,208 @@ class StandaloneDynamicAllocationSuite test("dynamic allocation default behavior") { sc = new SparkContext(appConf) val appId = sc.applicationId - assert(master.apps.size === 1) - assert(master.apps.head.id === appId) - assert(master.apps.head.executors.size === 2) - assert(master.apps.head.getExecutorLimit === Int.MaxValue) + eventually(timeout(10.seconds), interval(10.millis)) { + val apps = getApplications() + assert(apps.size === 1) + assert(apps.head.id === appId) + assert(apps.head.executors.size === 2) + assert(apps.head.getExecutorLimit === Int.MaxValue) + } // kill all executors assert(killAllExecutors(sc)) - assert(master.apps.head.executors.size === 0) - assert(master.apps.head.getExecutorLimit === 0) + var apps = getApplications() + assert(apps.head.executors.size === 0) + assert(apps.head.getExecutorLimit === 0) // request 1 assert(sc.requestExecutors(1)) - assert(master.apps.head.executors.size === 1) - assert(master.apps.head.getExecutorLimit === 1) + apps = getApplications() + assert(apps.head.executors.size === 1) + assert(apps.head.getExecutorLimit === 1) // request 1 more assert(sc.requestExecutors(1)) - assert(master.apps.head.executors.size === 2) - assert(master.apps.head.getExecutorLimit === 2) + apps = getApplications() + assert(apps.head.executors.size === 2) + assert(apps.head.getExecutorLimit === 2) // request 1 more; this one won't go through assert(sc.requestExecutors(1)) - assert(master.apps.head.executors.size === 2) - assert(master.apps.head.getExecutorLimit === 3) + apps = getApplications() + assert(apps.head.executors.size === 2) + assert(apps.head.getExecutorLimit === 3) // kill all existing executors; we should end up with 3 - 2 = 1 executor assert(killAllExecutors(sc)) - assert(master.apps.head.executors.size === 1) - assert(master.apps.head.getExecutorLimit === 1) + apps = getApplications() + 
assert(apps.head.executors.size === 1) + assert(apps.head.getExecutorLimit === 1) // kill all executors again; this time we'll have 1 - 1 = 0 executors left assert(killAllExecutors(sc)) - assert(master.apps.head.executors.size === 0) - assert(master.apps.head.getExecutorLimit === 0) + apps = getApplications() + assert(apps.head.executors.size === 0) + assert(apps.head.getExecutorLimit === 0) // request many more; this increases the limit well beyond the cluster capacity assert(sc.requestExecutors(1000)) - assert(master.apps.head.executors.size === 2) - assert(master.apps.head.getExecutorLimit === 1000) + apps = getApplications() + assert(apps.head.executors.size === 2) + assert(apps.head.getExecutorLimit === 1000) } test("dynamic allocation with max cores <= cores per worker") { sc = new SparkContext(appConf.set("spark.cores.max", "8")) val appId = sc.applicationId - assert(master.apps.size === 1) - assert(master.apps.head.id === appId) - assert(master.apps.head.executors.size === 2) - assert(master.apps.head.executors.values.map(_.cores).toArray === Array(4, 4)) - assert(master.apps.head.getExecutorLimit === Int.MaxValue) + eventually(timeout(10.seconds), interval(10.millis)) { + val apps = getApplications() + assert(apps.size === 1) + assert(apps.head.id === appId) + assert(apps.head.executors.size === 2) + assert(apps.head.executors.values.map(_.cores).toArray === Array(4, 4)) + assert(apps.head.getExecutorLimit === Int.MaxValue) + } // kill all executors assert(killAllExecutors(sc)) - assert(master.apps.head.executors.size === 0) - assert(master.apps.head.getExecutorLimit === 0) + var apps = getApplications() + assert(apps.head.executors.size === 0) + assert(apps.head.getExecutorLimit === 0) // request 1 assert(sc.requestExecutors(1)) - assert(master.apps.head.executors.size === 1) - assert(master.apps.head.executors.values.head.cores === 8) - assert(master.apps.head.getExecutorLimit === 1) + apps = getApplications() + assert(apps.head.executors.size === 1) + assert(apps.head.executors.values.head.cores === 8) + assert(apps.head.getExecutorLimit === 1) // request 1 more; this one won't go through because we're already at max cores. // This highlights a limitation of using dynamic allocation with max cores WITHOUT // setting cores per executor: once an application scales down and then scales back // up, its executors may not be spread out anymore! 
assert(sc.requestExecutors(1)) - assert(master.apps.head.executors.size === 1) - assert(master.apps.head.getExecutorLimit === 2) + apps = getApplications() + assert(apps.head.executors.size === 1) + assert(apps.head.getExecutorLimit === 2) // request 1 more; this one also won't go through for the same reason assert(sc.requestExecutors(1)) - assert(master.apps.head.executors.size === 1) - assert(master.apps.head.getExecutorLimit === 3) + apps = getApplications() + assert(apps.head.executors.size === 1) + assert(apps.head.getExecutorLimit === 3) // kill all existing executors; we should end up with 3 - 1 = 2 executor // Note: we scheduled these executors together, so their cores should be evenly distributed assert(killAllExecutors(sc)) - assert(master.apps.head.executors.size === 2) - assert(master.apps.head.executors.values.map(_.cores).toArray === Array(4, 4)) - assert(master.apps.head.getExecutorLimit === 2) + apps = getApplications() + assert(apps.head.executors.size === 2) + assert(apps.head.executors.values.map(_.cores).toArray === Array(4, 4)) + assert(apps.head.getExecutorLimit === 2) // kill all executors again; this time we'll have 1 - 1 = 0 executors left assert(killAllExecutors(sc)) - assert(master.apps.head.executors.size === 0) - assert(master.apps.head.getExecutorLimit === 0) + apps = getApplications() + assert(apps.head.executors.size === 0) + assert(apps.head.getExecutorLimit === 0) // request many more; this increases the limit well beyond the cluster capacity assert(sc.requestExecutors(1000)) - assert(master.apps.head.executors.size === 2) - assert(master.apps.head.executors.values.map(_.cores).toArray === Array(4, 4)) - assert(master.apps.head.getExecutorLimit === 1000) + apps = getApplications() + assert(apps.head.executors.size === 2) + assert(apps.head.executors.values.map(_.cores).toArray === Array(4, 4)) + assert(apps.head.getExecutorLimit === 1000) } test("dynamic allocation with max cores > cores per worker") { sc = new SparkContext(appConf.set("spark.cores.max", "16")) val appId = sc.applicationId - assert(master.apps.size === 1) - assert(master.apps.head.id === appId) - assert(master.apps.head.executors.size === 2) - assert(master.apps.head.executors.values.map(_.cores).toArray === Array(8, 8)) - assert(master.apps.head.getExecutorLimit === Int.MaxValue) + eventually(timeout(10.seconds), interval(10.millis)) { + val apps = getApplications() + assert(apps.size === 1) + assert(apps.head.id === appId) + assert(apps.head.executors.size === 2) + assert(apps.head.executors.values.map(_.cores).toArray === Array(8, 8)) + assert(apps.head.getExecutorLimit === Int.MaxValue) + } // kill all executors assert(killAllExecutors(sc)) - assert(master.apps.head.executors.size === 0) - assert(master.apps.head.getExecutorLimit === 0) + var apps = getApplications() + assert(apps.head.executors.size === 0) + assert(apps.head.getExecutorLimit === 0) // request 1 assert(sc.requestExecutors(1)) - assert(master.apps.head.executors.size === 1) - assert(master.apps.head.executors.values.head.cores === 10) - assert(master.apps.head.getExecutorLimit === 1) + apps = getApplications() + assert(apps.head.executors.size === 1) + assert(apps.head.executors.values.head.cores === 10) + assert(apps.head.getExecutorLimit === 1) // request 1 more // Note: the cores are not evenly distributed because we scheduled these executors 1 by 1 assert(sc.requestExecutors(1)) - assert(master.apps.head.executors.size === 2) - assert(master.apps.head.executors.values.map(_.cores).toSet === Set(10, 6)) - 
assert(master.apps.head.getExecutorLimit === 2) + apps = getApplications() + assert(apps.head.executors.size === 2) + assert(apps.head.executors.values.map(_.cores).toSet === Set(10, 6)) + assert(apps.head.getExecutorLimit === 2) // request 1 more; this one won't go through assert(sc.requestExecutors(1)) - assert(master.apps.head.executors.size === 2) - assert(master.apps.head.getExecutorLimit === 3) + apps = getApplications() + assert(apps.head.executors.size === 2) + assert(apps.head.getExecutorLimit === 3) // kill all existing executors; we should end up with 3 - 2 = 1 executor assert(killAllExecutors(sc)) - assert(master.apps.head.executors.size === 1) - assert(master.apps.head.executors.values.head.cores === 10) - assert(master.apps.head.getExecutorLimit === 1) + apps = getApplications() + assert(apps.head.executors.size === 1) + assert(apps.head.executors.values.head.cores === 10) + assert(apps.head.getExecutorLimit === 1) // kill all executors again; this time we'll have 1 - 1 = 0 executors left assert(killAllExecutors(sc)) - assert(master.apps.head.executors.size === 0) - assert(master.apps.head.getExecutorLimit === 0) + apps = getApplications() + assert(apps.head.executors.size === 0) + assert(apps.head.getExecutorLimit === 0) // request many more; this increases the limit well beyond the cluster capacity assert(sc.requestExecutors(1000)) - assert(master.apps.head.executors.size === 2) - assert(master.apps.head.executors.values.map(_.cores).toArray === Array(8, 8)) - assert(master.apps.head.getExecutorLimit === 1000) + apps = getApplications() + assert(apps.head.executors.size === 2) + assert(apps.head.executors.values.map(_.cores).toArray === Array(8, 8)) + assert(apps.head.getExecutorLimit === 1000) } test("dynamic allocation with cores per executor") { sc = new SparkContext(appConf.set("spark.executor.cores", "2")) val appId = sc.applicationId - assert(master.apps.size === 1) - assert(master.apps.head.id === appId) - assert(master.apps.head.executors.size === 10) // 20 cores total - assert(master.apps.head.getExecutorLimit === Int.MaxValue) + eventually(timeout(10.seconds), interval(10.millis)) { + val apps = getApplications() + assert(apps.size === 1) + assert(apps.head.id === appId) + assert(apps.head.executors.size === 10) // 20 cores total + assert(apps.head.getExecutorLimit === Int.MaxValue) + } // kill all executors assert(killAllExecutors(sc)) - assert(master.apps.head.executors.size === 0) - assert(master.apps.head.getExecutorLimit === 0) + var apps = getApplications() + assert(apps.head.executors.size === 0) + assert(apps.head.getExecutorLimit === 0) // request 1 assert(sc.requestExecutors(1)) - assert(master.apps.head.executors.size === 1) - assert(master.apps.head.getExecutorLimit === 1) + apps = getApplications() + assert(apps.head.executors.size === 1) + assert(apps.head.getExecutorLimit === 1) // request 3 more assert(sc.requestExecutors(3)) - assert(master.apps.head.executors.size === 4) - assert(master.apps.head.getExecutorLimit === 4) + apps = getApplications() + assert(apps.head.executors.size === 4) + assert(apps.head.getExecutorLimit === 4) // request 10 more; only 6 will go through assert(sc.requestExecutors(10)) - assert(master.apps.head.executors.size === 10) - assert(master.apps.head.getExecutorLimit === 14) + apps = getApplications() + assert(apps.head.executors.size === 10) + assert(apps.head.getExecutorLimit === 14) // kill 2 executors; we should get 2 back immediately assert(killNExecutors(sc, 2)) - assert(master.apps.head.executors.size === 10) - 
assert(master.apps.head.getExecutorLimit === 12) + apps = getApplications() + assert(apps.head.executors.size === 10) + assert(apps.head.getExecutorLimit === 12) // kill 4 executors; we should end up with 12 - 4 = 8 executors assert(killNExecutors(sc, 4)) - assert(master.apps.head.executors.size === 8) - assert(master.apps.head.getExecutorLimit === 8) + apps = getApplications() + assert(apps.head.executors.size === 8) + assert(apps.head.getExecutorLimit === 8) // kill all executors; this time we'll have 8 - 8 = 0 executors left assert(killAllExecutors(sc)) - assert(master.apps.head.executors.size === 0) - assert(master.apps.head.getExecutorLimit === 0) + apps = getApplications() + assert(apps.head.executors.size === 0) + assert(apps.head.getExecutorLimit === 0) // request many more; this increases the limit well beyond the cluster capacity assert(sc.requestExecutors(1000)) - assert(master.apps.head.executors.size === 10) - assert(master.apps.head.getExecutorLimit === 1000) + apps = getApplications() + assert(apps.head.executors.size === 10) + assert(apps.head.getExecutorLimit === 1000) } test("dynamic allocation with cores per executor AND max cores") { @@ -241,46 +294,152 @@ class StandaloneDynamicAllocationSuite .set("spark.executor.cores", "2") .set("spark.cores.max", "8")) val appId = sc.applicationId - assert(master.apps.size === 1) - assert(master.apps.head.id === appId) - assert(master.apps.head.executors.size === 4) // 8 cores total - assert(master.apps.head.getExecutorLimit === Int.MaxValue) + eventually(timeout(10.seconds), interval(10.millis)) { + val apps = getApplications() + assert(apps.size === 1) + assert(apps.head.id === appId) + assert(apps.head.executors.size === 4) // 8 cores total + assert(apps.head.getExecutorLimit === Int.MaxValue) + } // kill all executors assert(killAllExecutors(sc)) - assert(master.apps.head.executors.size === 0) - assert(master.apps.head.getExecutorLimit === 0) + var apps = getApplications() + assert(apps.head.executors.size === 0) + assert(apps.head.getExecutorLimit === 0) // request 1 assert(sc.requestExecutors(1)) - assert(master.apps.head.executors.size === 1) - assert(master.apps.head.getExecutorLimit === 1) + apps = getApplications() + assert(apps.head.executors.size === 1) + assert(apps.head.getExecutorLimit === 1) // request 3 more assert(sc.requestExecutors(3)) - assert(master.apps.head.executors.size === 4) - assert(master.apps.head.getExecutorLimit === 4) + apps = getApplications() + assert(apps.head.executors.size === 4) + assert(apps.head.getExecutorLimit === 4) // request 10 more; none will go through assert(sc.requestExecutors(10)) - assert(master.apps.head.executors.size === 4) - assert(master.apps.head.getExecutorLimit === 14) + apps = getApplications() + assert(apps.head.executors.size === 4) + assert(apps.head.getExecutorLimit === 14) // kill all executors; 4 executors will be launched immediately assert(killAllExecutors(sc)) - assert(master.apps.head.executors.size === 4) - assert(master.apps.head.getExecutorLimit === 10) + apps = getApplications() + assert(apps.head.executors.size === 4) + assert(apps.head.getExecutorLimit === 10) // ... and again assert(killAllExecutors(sc)) - assert(master.apps.head.executors.size === 4) - assert(master.apps.head.getExecutorLimit === 6) + apps = getApplications() + assert(apps.head.executors.size === 4) + assert(apps.head.getExecutorLimit === 6) // ... 
and again; now we end up with 6 - 4 = 2 executors left assert(killAllExecutors(sc)) - assert(master.apps.head.executors.size === 2) - assert(master.apps.head.getExecutorLimit === 2) + apps = getApplications() + assert(apps.head.executors.size === 2) + assert(apps.head.getExecutorLimit === 2) // ... and again; this time we have 2 - 2 = 0 executors left assert(killAllExecutors(sc)) - assert(master.apps.head.executors.size === 0) - assert(master.apps.head.getExecutorLimit === 0) + apps = getApplications() + assert(apps.head.executors.size === 0) + assert(apps.head.getExecutorLimit === 0) // request many more; this increases the limit well beyond the cluster capacity assert(sc.requestExecutors(1000)) - assert(master.apps.head.executors.size === 4) - assert(master.apps.head.getExecutorLimit === 1000) + apps = getApplications() + assert(apps.head.executors.size === 4) + assert(apps.head.getExecutorLimit === 1000) + } + + test("kill the same executor twice (SPARK-9795)") { + sc = new SparkContext(appConf) + val appId = sc.applicationId + eventually(timeout(10.seconds), interval(10.millis)) { + val apps = getApplications() + assert(apps.size === 1) + assert(apps.head.id === appId) + assert(apps.head.executors.size === 2) + assert(apps.head.getExecutorLimit === Int.MaxValue) + } + // sync executors between the Master and the driver, needed because + // the driver refuses to kill executors it does not know about + syncExecutors(sc) + // kill the same executor twice + val executors = getExecutorIds(sc) + assert(executors.size === 2) + assert(sc.killExecutor(executors.head)) + assert(sc.killExecutor(executors.head)) + val apps = getApplications() + assert(apps.head.executors.size === 1) + // The limit should not be lowered twice + assert(apps.head.getExecutorLimit === 1) + } + + test("the pending replacement executors should not be lost (SPARK-10515)") { + sc = new SparkContext(appConf) + val appId = sc.applicationId + eventually(timeout(10.seconds), interval(10.millis)) { + val apps = getApplications() + assert(apps.size === 1) + assert(apps.head.id === appId) + assert(apps.head.executors.size === 2) + assert(apps.head.getExecutorLimit === Int.MaxValue) + } + // sync executors between the Master and the driver, needed because + // the driver refuses to kill executors it does not know about + syncExecutors(sc) + val executors = getExecutorIds(sc) + assert(executors.size === 2) + // kill executor 1, and replace it + assert(sc.killAndReplaceExecutor(executors.head)) + eventually(timeout(10.seconds), interval(10.millis)) { + val apps = getApplications() + assert(apps.head.executors.size === 2) + } + + var apps = getApplications() + // kill executor 1 + assert(sc.killExecutor(executors.head)) + apps = getApplications() + assert(apps.head.executors.size === 2) + assert(apps.head.getExecutorLimit === 2) + // kill executor 2 + assert(sc.killExecutor(executors(1))) + apps = getApplications() + assert(apps.head.executors.size === 1) + assert(apps.head.getExecutorLimit === 1) + } + + test("disable force kill for busy executors (SPARK-9552)") { + sc = new SparkContext(appConf) + val appId = sc.applicationId + eventually(timeout(10.seconds), interval(10.millis)) { + val apps = getApplications() + assert(apps.size === 1) + assert(apps.head.id === appId) + assert(apps.head.executors.size === 2) + assert(apps.head.getExecutorLimit === Int.MaxValue) + } + var apps = getApplications() + // sync executors between the Master and the driver, needed because + // the driver refuses to kill executors it does not know about 
+ syncExecutors(sc) + val executors = getExecutorIds(sc) + assert(executors.size === 2) + + // simulate running a task on the executor + val getMap = PrivateMethod[mutable.HashMap[String, Int]]('executorIdToTaskCount) + val taskScheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl] + val executorIdToTaskCount = taskScheduler invokePrivate getMap() + executorIdToTaskCount(executors.head) = 1 + // kill the busy executor without force; this should fail + assert(killExecutor(sc, executors.head, force = false)) + apps = getApplications() + assert(apps.head.executors.size === 2) + + // force kill busy executor + assert(killExecutor(sc, executors.head, force = true)) + apps = getApplications() + // kill executor successfully + assert(apps.head.executors.size === 1) + } // =============================== @@ -313,6 +472,16 @@ class StandaloneDynamicAllocationSuite } } + /** Get the Master state */ + private def getMasterState: MasterStateResponse = { + master.self.askWithRetry[MasterStateResponse](RequestMasterState) + } + + /** Get the applications that are active from Master */ + private def getApplications(): Seq[ApplicationInfo] = { + getMasterState.activeApps + } + /** Kill all executors belonging to this application. */ private def killAllExecutors(sc: SparkContext): Boolean = { killNExecutors(sc, Int.MaxValue) @@ -324,6 +493,16 @@ class StandaloneDynamicAllocationSuite sc.killExecutors(getExecutorIds(sc).take(n)) } + /** Kill the given executor, specifying whether to force kill it. */ + private def killExecutor(sc: SparkContext, executorId: String, force: Boolean): Boolean = { + syncExecutors(sc) + sc.schedulerBackend match { + case b: CoarseGrainedSchedulerBackend => + b.killExecutors(Seq(executorId), replace = false, force) + case _ => fail("expected coarse grained scheduler") + } + } + /** * Return a list of executor IDs belonging to this application. * @@ -332,8 +511,11 @@ class StandaloneDynamicAllocationSuite * don't wait for executors to register. Otherwise the tests will take much longer to run. */ private def getExecutorIds(sc: SparkContext): Seq[String] = { - assert(master.idToApp.contains(sc.applicationId)) - master.idToApp(sc.applicationId).executors.keys.map(_.toString).toSeq + val app = getApplications().find(_.id == sc.applicationId) + assert(app.isDefined) + // Although executors is transient, master is in the same process so the message won't be + // serialized and it's safe here. + app.get.executors.keys.map(_.toString).toSeq } /** diff --git a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala new file mode 100644 index 000000000000..1e5c05a73f8a --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.client + +import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} +import scala.concurrent.duration._ + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.concurrent.Eventually._ + +import org.apache.spark._ +import org.apache.spark.deploy.{ApplicationDescription, Command} +import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} +import org.apache.spark.deploy.master.{ApplicationInfo, Master} +import org.apache.spark.deploy.worker.Worker +import org.apache.spark.rpc.RpcEnv +import org.apache.spark.util.Utils + +/** + * End-to-end tests for application client in standalone mode. + */ +class AppClientSuite extends SparkFunSuite with LocalSparkContext with BeforeAndAfterAll { + private val numWorkers = 2 + private val conf = new SparkConf() + private val securityManager = new SecurityManager(conf) + + private var masterRpcEnv: RpcEnv = null + private var workerRpcEnvs: Seq[RpcEnv] = null + private var master: Master = null + private var workers: Seq[Worker] = null + + /** + * Start the local cluster. + * Note: local-cluster mode is insufficient because we want a reference to the Master. + */ + override def beforeAll(): Unit = { + super.beforeAll() + masterRpcEnv = RpcEnv.create(Master.SYSTEM_NAME, "localhost", 0, conf, securityManager) + workerRpcEnvs = (0 until numWorkers).map { i => + RpcEnv.create(Worker.SYSTEM_NAME + i, "localhost", 0, conf, securityManager) + } + master = makeMaster() + workers = makeWorkers(10, 2048) + // Wait until all workers register with master successfully + eventually(timeout(60.seconds), interval(10.millis)) { + assert(getMasterState.workers.size === numWorkers) + } + } + + override def afterAll(): Unit = { + workerRpcEnvs.foreach(_.shutdown()) + masterRpcEnv.shutdown() + workers.foreach(_.stop()) + master.stop() + workerRpcEnvs = null + masterRpcEnv = null + workers = null + master = null + super.afterAll() + } + + test("interface methods of AppClient using local Master") { + val ci = new AppClientInst(masterRpcEnv.address.toSparkURL) + + ci.client.start() + + // Client should connect with one Master which registers the application + eventually(timeout(10.seconds), interval(10.millis)) { + val apps = getApplications() + assert(ci.listener.connectedIdList.size === 1, "client listener should have one connection") + assert(apps.size === 1, "master should have 1 registered app") + } + + // Send message to Master to request Executors, verify request by change in executor limit + val numExecutorsRequested = 1 + assert(ci.client.requestTotalExecutors(numExecutorsRequested)) + + eventually(timeout(10.seconds), interval(10.millis)) { + val apps = getApplications() + assert(apps.head.getExecutorLimit === numExecutorsRequested, s"executor request failed") + } + + // Send request to kill executor, verify request was made + assert { + val apps = getApplications() + val executorId: String = apps.head.executors.head._2.fullId + ci.client.killExecutors(Seq(executorId)) + } + + // Issue stop command for Client to disconnect from Master + ci.client.stop() + + // Verify Client is marked dead and unregistered from Master + eventually(timeout(10.seconds), interval(10.millis)) { + val apps = getApplications() + assert(ci.listener.deadReasonList.size === 1, "client should have been marked dead") + assert(apps.isEmpty, "master should have 0 registered apps") + } + } + + test("request from 
AppClient before initialized with master") { + val ci = new AppClientInst(masterRpcEnv.address.toSparkURL) + + // requests to master should fail immediately + assert(ci.client.requestTotalExecutors(3) === false) + } + + // =============================== + // | Utility methods for testing | + // =============================== + + /** Return a SparkConf for applications that want to talk to our Master. */ + private def appConf: SparkConf = { + new SparkConf() + .setMaster(masterRpcEnv.address.toSparkURL) + .setAppName("test") + .set("spark.executor.memory", "256m") + } + + /** Make a master to which our application will send executor requests. */ + private def makeMaster(): Master = { + val master = new Master(masterRpcEnv, masterRpcEnv.address, 0, securityManager, conf) + masterRpcEnv.setupEndpoint(Master.ENDPOINT_NAME, master) + master + } + + /** Make a few workers that talk to our master. */ + private def makeWorkers(cores: Int, memory: Int): Seq[Worker] = { + (0 until numWorkers).map { i => + val rpcEnv = workerRpcEnvs(i) + val worker = new Worker(rpcEnv, 0, cores, memory, Array(masterRpcEnv.address), + Worker.SYSTEM_NAME + i, Worker.ENDPOINT_NAME, null, conf, securityManager) + rpcEnv.setupEndpoint(Worker.ENDPOINT_NAME, worker) + worker + } + } + + /** Get the Master state */ + private def getMasterState: MasterStateResponse = { + master.self.askWithRetry[MasterStateResponse](RequestMasterState) + } + + /** Get the applications that are active from Master */ + private def getApplications(): Seq[ApplicationInfo] = { + getMasterState.activeApps + } + + /** Application Listener to collect events */ + private class AppClientCollector extends AppClientListener with Logging { + val connectedIdList = new ArrayBuffer[String] with SynchronizedBuffer[String] + @volatile var disconnectedCount: Int = 0 + val deadReasonList = new ArrayBuffer[String] with SynchronizedBuffer[String] + val execAddedList = new ArrayBuffer[String] with SynchronizedBuffer[String] + val execRemovedList = new ArrayBuffer[String] with SynchronizedBuffer[String] + + def connected(id: String): Unit = { + connectedIdList += id + } + + def disconnected(): Unit = { + synchronized { + disconnectedCount += 1 + } + } + + def dead(reason: String): Unit = { + deadReasonList += reason + } + + def executorAdded( + id: String, + workerId: String, + hostPort: String, + cores: Int, + memory: Int): Unit = { + execAddedList += id + } + + def executorRemoved(id: String, message: String, exitStatus: Option[Int]): Unit = { + execRemovedList += id + } + } + + /** Create AppClient and supporting objects */ + private class AppClientInst(masterUrl: String) { + val rpcEnv = RpcEnv.create("spark", Utils.localHostName(), 0, conf, securityManager) + private val cmd = new Command(TestExecutor.getClass.getCanonicalName.stripSuffix("$"), + List(), Map(), Seq(), Seq(), Seq()) + private val desc = new ApplicationDescription("AppClientSuite", Some(1), 512, cmd, "ignored") + val listener = new AppClientCollector + val client = new AppClient(rpcEnv, Array(masterUrl), desc, listener, new SparkConf) + } + +} diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 73cff89544dc..5cab17f8a38f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -24,18 +24,24 @@ import
java.util.zip.{ZipInputStream, ZipOutputStream} import scala.io.Source +import scala.concurrent.duration._ +import scala.language.postfixOps import com.google.common.base.Charsets import com.google.common.io.{ByteStreams, Files} import org.apache.hadoop.fs.Path +import org.apache.hadoop.hdfs.DistributedFileSystem import org.json4s.jackson.JsonMethods._ +import org.mockito.Matchers.any +import org.mockito.Mockito.{doReturn, mock, spy, verify, when} import org.scalatest.BeforeAndAfter import org.scalatest.Matchers +import org.scalatest.concurrent.Eventually._ import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.io._ import org.apache.spark.scheduler._ -import org.apache.spark.util.{JsonProtocol, ManualClock, Utils} +import org.apache.spark.util.{Clock, JsonProtocol, ManualClock, Utils} class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { @@ -407,6 +413,53 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc } } + test("provider correctly checks whether fs is in safe mode") { + val provider = spy(new FsHistoryProvider(createTestConf())) + val dfs = mock(classOf[DistributedFileSystem]) + // Asserts that safe mode is false because we can't really control the return value of the mock, + // since the API is different between hadoop 1 and 2. + assert(!provider.isFsInSafeMode(dfs)) + } + + test("provider waits for safe mode to finish before initializing") { + val clock = new ManualClock() + val provider = new SafeModeTestProvider(createTestConf(), clock) + val initThread = provider.initialize() + try { + provider.getConfig().keys should contain ("HDFS State") + + clock.setTime(5000) + provider.getConfig().keys should contain ("HDFS State") + + provider.inSafeMode = false + clock.setTime(10000) + + eventually(timeout(1 second), interval(10 millis)) { + provider.getConfig().keys should not contain ("HDFS State") + } + } finally { + provider.stop() + } + } + + test("provider reports error after FS leaves safe mode") { + testDir.delete() + val clock = new ManualClock() + val provider = new SafeModeTestProvider(createTestConf(), clock) + val errorHandler = mock(classOf[Thread.UncaughtExceptionHandler]) + val initThread = provider.startSafeModeCheckThread(Some(errorHandler)) + try { + provider.inSafeMode = false + clock.setTime(10000) + + eventually(timeout(1 second), interval(10 millis)) { + verify(errorHandler).uncaughtException(any(), any()) + } + } finally { + provider.stop() + } + } + /** * Asks the provider to check for logs and calls a function to perform checks on the updated * app list. Example: @@ -465,4 +518,16 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc log } + private class SafeModeTestProvider(conf: SparkConf, clock: Clock) + extends FsHistoryProvider(conf, clock) { + + @volatile var inSafeMode = true + + // Skip initialization so that we can manually start the safe mode check thread. 
+ private[history] override def initialize(): Thread = null + + private[history] override def isFsInSafeMode(): Boolean = inSafeMode + + } + } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerArgumentsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerArgumentsSuite.scala new file mode 100644 index 000000000000..34f27ecaa07a --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerArgumentsSuite.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.history + +import java.io.File +import java.nio.charset.StandardCharsets._ + +import com.google.common.io.Files + +import org.apache.spark._ +import org.apache.spark.util.Utils + +class HistoryServerArgumentsSuite extends SparkFunSuite { + + private val logDir = new File("src/test/resources/spark-events") + private val conf = new SparkConf() + .set("spark.history.fs.logDirectory", logDir.getAbsolutePath) + .set("spark.history.fs.updateInterval", "1") + .set("spark.testing", "true") + + test("No Arguments Parsing") { + val argStrings = Array[String]() + val hsa = new HistoryServerArguments(conf, argStrings) + assert(conf.get("spark.history.fs.logDirectory") === logDir.getAbsolutePath) + assert(conf.get("spark.history.fs.updateInterval") === "1") + assert(conf.get("spark.testing") === "true") + } + + test("Directory Arguments Parsing --dir or -d") { + val argStrings = Array("--dir", "src/test/resources/spark-events1") + val hsa = new HistoryServerArguments(conf, argStrings) + assert(conf.get("spark.history.fs.logDirectory") === "src/test/resources/spark-events1") + } + + test("Directory Param can also be set directly") { + val argStrings = Array("src/test/resources/spark-events2") + val hsa = new HistoryServerArguments(conf, argStrings) + assert(conf.get("spark.history.fs.logDirectory") === "src/test/resources/spark-events2") + } + + test("Properties File Arguments Parsing --properties-file") { + val tmpDir = Utils.createTempDir() + val outFile = File.createTempFile("test-load-spark-properties", "test", tmpDir) + try { + Files.write("spark.test.CustomPropertyA blah\n" + + "spark.test.CustomPropertyB notblah\n", outFile, UTF_8) + val argStrings = Array("--properties-file", outFile.getAbsolutePath) + val hsa = new HistoryServerArguments(conf, argStrings) + assert(conf.get("spark.test.CustomPropertyA") === "blah") + assert(conf.get("spark.test.CustomPropertyB") === "notblah") + } finally { + Utils.deleteRecursively(tmpDir) + } + } + +} diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index e5b5e1bb6533..4b7fd4f13b69 100644 --- 
a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -29,7 +29,7 @@ import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.mock.MockitoSugar import org.apache.spark.{JsonTestUtils, SecurityManager, SparkConf, SparkFunSuite} -import org.apache.spark.ui.SparkUI +import org.apache.spark.ui.{SparkUI, UIUtils} /** * A collection of tests against the historyserver, including comparing responses from the json @@ -261,7 +261,24 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers l <- links attrs <- l.attribute("href") } yield (attrs.toString) - justHrefs should contain(link) + justHrefs should contain (UIUtils.prependBaseUri(resource = link)) + } + + test("relative links are prefixed with uiRoot (spark.ui.proxyBase)") { + val proxyBaseBeforeTest = System.getProperty("spark.ui.proxyBase") + val uiRoot = Option(System.getenv("APPLICATION_WEB_PROXY_BASE")).getOrElse("/testwebproxybase") + val page = new HistoryPage(server) + val request = mock[HttpServletRequest] + + // when + System.setProperty("spark.ui.proxyBase", uiRoot) + val response = page.render(request) + System.setProperty("spark.ui.proxyBase", Option(proxyBaseBeforeTest).getOrElse("")) + + // then + val urls = response \\ "@href" map (_.toString) + val siteRelativeLinks = urls filter (_.startsWith("/")) + all (siteRelativeLinks) should startWith (uiRoot) } def getContentAndCode(path: String, port: Int = port): (Int, Option[String], Option[String]) = { diff --git a/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala b/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala index 8c96b0e71dfd..4b86da536768 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala @@ -99,7 +99,7 @@ object CustomPersistenceEngine { @volatile var lastInstance: Option[CustomPersistenceEngine] = None } -class CustomLeaderElectionAgent(val masterActor: LeaderElectable) extends LeaderElectionAgent { - masterActor.electedLeader() +class CustomLeaderElectionAgent(val masterInstance: LeaderElectable) extends LeaderElectionAgent { + masterInstance.electedLeader() } diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index 30780a0da7f8..242bf4b5566e 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -40,6 +40,7 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually with Priva conf.set("spark.deploy.recoveryMode", "CUSTOM") conf.set("spark.deploy.recoveryMode.factory", classOf[CustomRecoveryModeFactory].getCanonicalName) + conf.set("spark.master.rest.enabled", "false") val instantiationAttempts = CustomRecoveryModeFactory.instantiationAttempts @@ -93,8 +94,8 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually with Priva publicAddress = "" ) - val (rpcEnv, uiPort, restPort) = - Master.startRpcEnvAndEndpoint("127.0.0.1", 7077, 8080, conf) + val (rpcEnv, _, _) = + Master.startRpcEnvAndEndpoint("127.0.0.1", 0, 0, conf) try { rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, rpcEnv.address, Master.ENDPOINT_NAME) @@ -151,6 +152,14 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually 
with Priva basicScheduling(spreadOut = false) } + test("basic scheduling with more memory - spread out") { + basicSchedulingWithMoreMemory(spreadOut = true) + } + + test("basic scheduling with more memory - no spread out") { + basicSchedulingWithMoreMemory(spreadOut = false) + } + test("scheduling with max cores - spread out") { schedulingWithMaxCores(spreadOut = true) } @@ -214,6 +223,13 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually with Priva assert(scheduledCores === Array(10, 10, 10)) } + private def basicSchedulingWithMoreMemory(spreadOut: Boolean): Unit = { + val master = makeMaster() + val appInfo = makeAppInfo(3072) + val scheduledCores = scheduleExecutorsOnWorkers(master, appInfo, workerInfos, spreadOut) + assert(scheduledCores === Array(10, 10, 10)) + } + private def schedulingWithMaxCores(spreadOut: Boolean): Unit = { val master = makeMaster() val appInfo1 = makeAppInfo(1024, maxCores = Some(8)) @@ -343,8 +359,8 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually with Priva private def makeMaster(conf: SparkConf = new SparkConf): Master = { val securityMgr = new SecurityManager(conf) - val rpcEnv = RpcEnv.create(Master.SYSTEM_NAME, "localhost", 7077, conf, securityMgr) - val master = new Master(rpcEnv, rpcEnv.address, 8080, securityMgr, conf) + val rpcEnv = RpcEnv.create(Master.SYSTEM_NAME, "localhost", 0, conf, securityMgr) + val master = new Master(rpcEnv, rpcEnv.address, 0, securityMgr, conf) master } diff --git a/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala index 11e87bd1dd8e..7a4472867568 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala @@ -63,55 +63,60 @@ class PersistenceEngineSuite extends SparkFunSuite { conf: SparkConf, persistenceEngineCreator: Serializer => PersistenceEngine): Unit = { val serializer = new JavaSerializer(conf) val persistenceEngine = persistenceEngineCreator(serializer) - persistenceEngine.persist("test_1", "test_1_value") - assert(Seq("test_1_value") === persistenceEngine.read[String]("test_")) - persistenceEngine.persist("test_2", "test_2_value") - assert(Set("test_1_value", "test_2_value") === persistenceEngine.read[String]("test_").toSet) - persistenceEngine.unpersist("test_1") - assert(Seq("test_2_value") === persistenceEngine.read[String]("test_")) - persistenceEngine.unpersist("test_2") - assert(persistenceEngine.read[String]("test_").isEmpty) - - // Test deserializing objects that contain RpcEndpointRef - val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) try { - // Create a real endpoint so that we can test RpcEndpointRef deserialization - val workerEndpoint = rpcEnv.setupEndpoint("worker", new RpcEndpoint { - override val rpcEnv: RpcEnv = rpcEnv - }) - - val workerToPersist = new WorkerInfo( - id = "test_worker", - host = "127.0.0.1", - port = 10000, - cores = 0, - memory = 0, - endpoint = workerEndpoint, - webUiPort = 0, - publicAddress = "" - ) - - persistenceEngine.addWorker(workerToPersist) - - val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData(rpcEnv) - - assert(storedApps.isEmpty) - assert(storedDrivers.isEmpty) - - // Check deserializing WorkerInfo - assert(storedWorkers.size == 1) - val recoveryWorkerInfo = storedWorkers.head - assert(workerToPersist.id === 
recoveryWorkerInfo.id) - assert(workerToPersist.host === recoveryWorkerInfo.host) - assert(workerToPersist.port === recoveryWorkerInfo.port) - assert(workerToPersist.cores === recoveryWorkerInfo.cores) - assert(workerToPersist.memory === recoveryWorkerInfo.memory) - assert(workerToPersist.endpoint === recoveryWorkerInfo.endpoint) - assert(workerToPersist.webUiPort === recoveryWorkerInfo.webUiPort) - assert(workerToPersist.publicAddress === recoveryWorkerInfo.publicAddress) + persistenceEngine.persist("test_1", "test_1_value") + assert(Seq("test_1_value") === persistenceEngine.read[String]("test_")) + persistenceEngine.persist("test_2", "test_2_value") + assert(Set("test_1_value", "test_2_value") === persistenceEngine.read[String]("test_").toSet) + persistenceEngine.unpersist("test_1") + assert(Seq("test_2_value") === persistenceEngine.read[String]("test_")) + persistenceEngine.unpersist("test_2") + assert(persistenceEngine.read[String]("test_").isEmpty) + + // Test deserializing objects that contain RpcEndpointRef + val testRpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + try { + // Create a real endpoint so that we can test RpcEndpointRef deserialization + val workerEndpoint = testRpcEnv.setupEndpoint("worker", new RpcEndpoint { + override val rpcEnv: RpcEnv = testRpcEnv + }) + + val workerToPersist = new WorkerInfo( + id = "test_worker", + host = "127.0.0.1", + port = 10000, + cores = 0, + memory = 0, + endpoint = workerEndpoint, + webUiPort = 0, + publicAddress = "" + ) + + persistenceEngine.addWorker(workerToPersist) + + val (storedApps, storedDrivers, storedWorkers) = + persistenceEngine.readPersistedData(testRpcEnv) + + assert(storedApps.isEmpty) + assert(storedDrivers.isEmpty) + + // Check deserializing WorkerInfo + assert(storedWorkers.size == 1) + val recoveryWorkerInfo = storedWorkers.head + assert(workerToPersist.id === recoveryWorkerInfo.id) + assert(workerToPersist.host === recoveryWorkerInfo.host) + assert(workerToPersist.port === recoveryWorkerInfo.port) + assert(workerToPersist.cores === recoveryWorkerInfo.cores) + assert(workerToPersist.memory === recoveryWorkerInfo.memory) + assert(workerToPersist.endpoint === recoveryWorkerInfo.endpoint) + assert(workerToPersist.webUiPort === recoveryWorkerInfo.webUiPort) + assert(workerToPersist.publicAddress === recoveryWorkerInfo.publicAddress) + } finally { + testRpcEnv.shutdown() + testRpcEnv.awaitTermination() + } } finally { - rpcEnv.shutdown() - rpcEnv.awaitTermination() + persistenceEngine.close() } } diff --git a/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala new file mode 100644 index 000000000000..fba835f054f8 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.master.ui + +import java.util.Date + +import scala.io.Source +import scala.language.postfixOps + +import org.json4s.jackson.JsonMethods._ +import org.json4s.JsonAST.{JNothing, JString, JInt} +import org.mockito.Mockito.{mock, when} +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SparkConf, SecurityManager, SparkFunSuite} +import org.apache.spark.deploy.DeployMessages.MasterStateResponse +import org.apache.spark.deploy.DeployTestUtils._ +import org.apache.spark.deploy.master._ +import org.apache.spark.rpc.RpcEnv + + +class MasterWebUISuite extends SparkFunSuite with BeforeAndAfter { + + val masterPage = mock(classOf[MasterPage]) + val master = { + val conf = new SparkConf + val securityMgr = new SecurityManager(conf) + val rpcEnv = RpcEnv.create(Master.SYSTEM_NAME, "localhost", 0, conf, securityMgr) + val master = new Master(rpcEnv, rpcEnv.address, 0, securityMgr, conf) + master + } + val masterWebUI = new MasterWebUI(master, 0, customMasterPage = Some(masterPage)) + + before { + masterWebUI.bind() + } + + after { + masterWebUI.stop() + } + + test("list applications") { + val worker = createWorkerInfo() + val appDesc = createAppDesc() + // use new start date so it isn't filtered by UI + val activeApp = new ApplicationInfo( + new Date().getTime, "id", appDesc, new Date(), null, Int.MaxValue) + activeApp.addExecutor(worker, 2) + + val workers = Array[WorkerInfo](worker) + val activeApps = Array(activeApp) + val completedApps = Array[ApplicationInfo]() + val activeDrivers = Array[DriverInfo]() + val completedDrivers = Array[DriverInfo]() + val stateResponse = new MasterStateResponse( + "host", 8080, None, workers, activeApps, completedApps, + activeDrivers, completedDrivers, RecoveryState.ALIVE) + + when(masterPage.getMasterState).thenReturn(stateResponse) + + val resultJson = Source.fromURL( + s"http://localhost:${masterWebUI.boundPort}/api/v1/applications") + .mkString + val parsedJson = parse(resultJson) + val firstApp = parsedJson(0) + + assert(firstApp \ "id" === JString(activeApp.id)) + assert(firstApp \ "name" === JString(activeApp.desc.name)) + assert(firstApp \ "coresGranted" === JInt(2)) + assert(firstApp \ "maxCores" === JInt(4)) + assert(firstApp \ "memoryPerExecutorMB" === JInt(1234)) + assert(firstApp \ "coresPerExecutor" === JNothing) + } + +} diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index 96e456d889ac..9693e32bf6af 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -366,6 +366,18 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { assert(conn3.getResponseCode === HttpServletResponse.SC_INTERNAL_SERVER_ERROR) } + test("client does not send 'SPARK_ENV_LOADED' env var by default") { + val environmentVariables = Map("SPARK_VAR" -> "1", "SPARK_ENV_LOADED" -> "1") + val filteredVariables = 
RestSubmissionClient.filterSystemEnvironment(environmentVariables) + assert(filteredVariables == Map("SPARK_VAR" -> "1")) + } + + test("client includes mesos env vars") { + val environmentVariables = Map("SPARK_VAR" -> "1", "MESOS_VAR" -> "1", "OTHER_VAR" -> "1") + val filteredVariables = RestSubmissionClient.filterSystemEnvironment(environmentVariables) + assert(filteredVariables == Map("SPARK_VAR" -> "1", "MESOS_VAR" -> "1")) + } + /* --------------------- * | Helper methods | * --------------------- */ diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala index bed6f3ea6124..98664dc1101e 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala @@ -19,8 +19,6 @@ package org.apache.spark.deploy.worker import java.io.File -import scala.collection.JavaConversions._ - import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState} import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} @@ -36,6 +34,7 @@ class ExecutorRunnerTest extends SparkFunSuite { ExecutorState.RUNNING) val builder = CommandUtils.buildProcessBuilder( appDesc.command, new SecurityManager(conf), 512, sparkHome, er.substituteVariables) - assert(builder.command().last === appId) + val builderCommand = builder.command() + assert(builderCommand.get(builderCommand.size() - 1) === appId) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala index 15f7ca4a6dac..637e78fda019 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.worker import org.apache.spark.{SparkConf, SparkFunSuite} - +import org.apache.spark.util.SparkConfWithEnv class WorkerArgumentsTest extends SparkFunSuite { @@ -34,18 +34,7 @@ class WorkerArgumentsTest extends SparkFunSuite { test("Memory can't be set to 0 when SPARK_WORKER_MEMORY env property leaves off M or G") { val args = Array("spark://localhost:0000 ") - - class MySparkConf extends SparkConf(false) { - override def getenv(name: String): String = { - if (name == "SPARK_WORKER_MEMORY") "50000" - else super.getenv(name) - } - - override def clone: SparkConf = { - new MySparkConf().setAll(getAll) - } - } - val conf = new MySparkConf() + val conf = new SparkConfWithEnv(Map("SPARK_WORKER_MEMORY" -> "50000")) intercept[IllegalStateException] { new WorkerArguments(args, conf) } @@ -53,18 +42,7 @@ class WorkerArgumentsTest extends SparkFunSuite { test("Memory correctly set when SPARK_WORKER_MEMORY env property appends G") { val args = Array("spark://localhost:0000 ") - - class MySparkConf extends SparkConf(false) { - override def getenv(name: String): String = { - if (name == "SPARK_WORKER_MEMORY") "5G" - else super.getenv(name) - } - - override def clone: SparkConf = { - new MySparkConf().setAll(getAll) - } - } - val conf = new MySparkConf() + val conf = new SparkConfWithEnv(Map("SPARK_WORKER_MEMORY" -> "5G")) val workerArgs = new WorkerArguments(args, conf) assert(workerArgs.memory === 5120) } diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala index 
cd24d7942331..40c24bdecc6c 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala @@ -26,8 +26,7 @@ class WorkerWatcherSuite extends SparkFunSuite { val conf = new SparkConf() val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) val targetWorkerUrl = rpcEnv.uriOf("test", RpcAddress("1.2.3.4", 1234), "Worker") - val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) - workerWatcher.setTesting(testing = true) + val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl, isTesting = true) rpcEnv.setupEndpoint("worker-watcher", workerWatcher) workerWatcher.onDisconnected(RpcAddress("1.2.3.4", 1234)) assert(workerWatcher.isShutDown) @@ -38,12 +37,10 @@ class WorkerWatcherSuite extends SparkFunSuite { val conf = new SparkConf() val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) val targetWorkerUrl = rpcEnv.uriOf("test", RpcAddress("1.2.3.4", 1234), "Worker") - val otherAddress = "akka://test@4.3.2.1:1234/user/OtherActor" - val otherAkkaAddress = RpcAddress("4.3.2.1", 1234) - val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) - workerWatcher.setTesting(testing = true) + val otherRpcAddress = RpcAddress("4.3.2.1", 1234) + val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl, isTesting = true) rpcEnv.setupEndpoint("worker-watcher", workerWatcher) - workerWatcher.onDisconnected(otherAkkaAddress) + workerWatcher.onDisconnected(otherRpcAddress) assert(!workerWatcher.isShutDown) rpcEnv.shutdown() } diff --git a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala index cbdb33c89d0f..1553ab60bdda 100644 --- a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala @@ -100,12 +100,10 @@ class CompressionCodecSuite extends SparkFunSuite { testCodec(codec) } - test("snappy does not support concatenation of serialized streams") { + test("snappy supports concatenation of serialized streams") { val codec = CompressionCodec.createCodec(conf, classOf[SnappyCompressionCodec].getName) assert(codec.getClass === classOf[SnappyCompressionCodec]) - intercept[Exception] { - testConcatenationOfSerializedStreams(codec) - } + testConcatenationOfSerializedStreams(codec) } test("bad compression codec") { diff --git a/core/src/test/scala/org/apache/spark/launcher/LauncherBackendSuite.scala b/core/src/test/scala/org/apache/spark/launcher/LauncherBackendSuite.scala new file mode 100644 index 000000000000..639d1daa36c7 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/launcher/LauncherBackendSuite.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher + +import java.util.concurrent.TimeUnit + +import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.scalatest.Matchers +import org.scalatest.concurrent.Eventually._ + +import org.apache.spark._ +import org.apache.spark.launcher._ + +class LauncherBackendSuite extends SparkFunSuite with Matchers { + + private val tests = Seq( + "local" -> "local", + "standalone/client" -> "local-cluster[1,1,1024]") + + tests.foreach { case (name, master) => + test(s"$name: launcher handle") { + testWithMaster(master) + } + } + + private def testWithMaster(master: String): Unit = { + val env = new java.util.HashMap[String, String]() + env.put("SPARK_PRINT_LAUNCH_COMMAND", "1") + val handle = new SparkLauncher(env) + .setSparkHome(sys.props("spark.test.home")) + .setConf(SparkLauncher.DRIVER_EXTRA_CLASSPATH, System.getProperty("java.class.path")) + .setConf("spark.ui.enabled", "false") + .setConf(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, s"-Dtest.appender=console") + .setMaster(master) + .setAppResource("spark-internal") + .setMainClass(TestApp.getClass.getName().stripSuffix("$")) + .startApplication() + + try { + eventually(timeout(30 seconds), interval(100 millis)) { + handle.getAppId() should not be (null) + } + + handle.stop() + + eventually(timeout(30 seconds), interval(100 millis)) { + handle.getState() should be (SparkAppHandle.State.KILLED) + } + } finally { + handle.kill() + } + } + +} + +object TestApp { + + def main(args: Array[String]): Unit = { + new SparkContext(new SparkConf()).parallelize(Seq(1)).foreach { i => + Thread.sleep(TimeUnit.SECONDS.toMillis(20)) + } + } + +} diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala new file mode 100644 index 000000000000..555b640cb424 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala @@ -0,0 +1,296 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.memory + +import java.util.concurrent.atomic.AtomicLong + +import scala.collection.mutable +import scala.concurrent.duration.Duration +import scala.concurrent.{Await, ExecutionContext, Future} + +import org.mockito.Matchers.{any, anyLong} +import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS} +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.BeforeAndAfterEach +import org.scalatest.time.SpanSugar._ +
import org.apache.spark.SparkFunSuite +import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore, StorageLevel} + + +/** + * Helper trait for sharing code among [[MemoryManager]] tests. + */ +private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAfterEach { + + protected val evictedBlocks = new mutable.ArrayBuffer[(BlockId, BlockStatus)] + + import MemoryManagerSuite.DEFAULT_EVICT_BLOCKS_TO_FREE_SPACE_CALLED + + // Note: Mockito's verify mechanism does not provide a way to reset method call counts + // without also resetting stubbed methods. Since our test code relies on the latter, + // we need to use our own variable to track invocations of `evictBlocksToFreeSpace`. + + /** + * The amount of space requested in the last call to [[MemoryStore.evictBlocksToFreeSpace]]. + * + * This is set whenever [[MemoryStore.evictBlocksToFreeSpace]] is called, and cleared when the test + * code makes explicit assertions on this variable through + * [[assertEvictBlocksToFreeSpaceCalled]]. + */ + private val evictBlocksToFreeSpaceCalled = new AtomicLong(0) + + override def beforeEach(): Unit = { + super.beforeEach() + evictedBlocks.clear() + evictBlocksToFreeSpaceCalled.set(DEFAULT_EVICT_BLOCKS_TO_FREE_SPACE_CALLED) + } + + /** + * Make a mocked [[MemoryStore]] whose [[MemoryStore.evictBlocksToFreeSpace]] method is stubbed. + * + * This allows our test code to release storage memory when these methods are called + * without relying on [[org.apache.spark.storage.BlockManager]] and all of its dependencies. + */ + protected def makeMemoryStore(mm: MemoryManager): MemoryStore = { + val ms = mock(classOf[MemoryStore], RETURNS_SMART_NULLS) + when(ms.evictBlocksToFreeSpace(any(), anyLong(), any())) + .thenAnswer(evictBlocksToFreeSpaceAnswer(mm)) + mm.setMemoryStore(ms) + ms + } + + /** + * Simulate the part of [[MemoryStore.evictBlocksToFreeSpace]] that releases storage memory. + * + * This is a significant simplification of the real method, which actually drops existing + * blocks based on the size of each block. Instead, here we simply release as many bytes + * as needed to ensure the requested amount of free space. This allows us to set up the + * test without relying on the [[org.apache.spark.storage.BlockManager]], which brings in + * many other dependencies. + * + * Every call to this method will set a global variable, [[evictBlocksToFreeSpaceCalled]], that + * records the number of bytes this is called with. This variable is expected to be cleared + * by the test code later through [[assertEvictBlocksToFreeSpaceCalled]]. 
+ */ + private def evictBlocksToFreeSpaceAnswer(mm: MemoryManager): Answer[Boolean] = { + new Answer[Boolean] { + override def answer(invocation: InvocationOnMock): Boolean = { + val args = invocation.getArguments + val numBytesToFree = args(1).asInstanceOf[Long] + assert(numBytesToFree > 0) + require(evictBlocksToFreeSpaceCalled.get() === DEFAULT_EVICT_BLOCKS_TO_FREE_SPACE_CALLED, + "bad test: evictBlocksToFreeSpace() variable was not reset") + evictBlocksToFreeSpaceCalled.set(numBytesToFree) + if (numBytesToFree <= mm.storageMemoryUsed) { + // We can evict enough blocks to fulfill the request for space + mm.releaseStorageMemory(numBytesToFree) + args.last.asInstanceOf[mutable.Buffer[(BlockId, BlockStatus)]].append( + (null, BlockStatus(StorageLevel.MEMORY_ONLY, numBytesToFree, 0L, 0L))) + // We need to add this call so that the suite-level `evictedBlocks` is updated when + // execution evicts storage; in that case, args.last will not be equal to evictedBlocks + // because it will be a temporary buffer created inside of the MemoryManager rather than + // being passed in by the test code. + if (!(evictedBlocks eq args.last)) { + evictedBlocks.append( + (null, BlockStatus(StorageLevel.MEMORY_ONLY, numBytesToFree, 0L, 0L))) + } + true + } else { + // No blocks were evicted because eviction would not free enough space. + false + } + } + } + } + + /** + * Assert that [[MemoryStore.evictBlocksToFreeSpace]] is called with the given parameters. + */ + protected def assertEvictBlocksToFreeSpaceCalled(ms: MemoryStore, numBytes: Long): Unit = { + assert(evictBlocksToFreeSpaceCalled.get() === numBytes, + s"expected evictBlocksToFreeSpace() to be called with $numBytes") + evictBlocksToFreeSpaceCalled.set(DEFAULT_EVICT_BLOCKS_TO_FREE_SPACE_CALLED) + } + + /** + * Assert that [[MemoryStore.evictBlocksToFreeSpace]] is NOT called. + */ + protected def assertEvictBlocksToFreeSpaceNotCalled[T](ms: MemoryStore): Unit = { + assert(evictBlocksToFreeSpaceCalled.get() === DEFAULT_EVICT_BLOCKS_TO_FREE_SPACE_CALLED, + "evictBlocksToFreeSpace() should not have been called!") + assert(evictedBlocks.isEmpty) + } + + /** + * Create a MemoryManager with the specified execution memory limits and no storage memory. + */ + protected def createMemoryManager( + maxOnHeapExecutionMemory: Long, + maxOffHeapExecutionMemory: Long = 0L): MemoryManager + + // -- Tests of sharing of execution memory between tasks ---------------------------------------- + // Prior to Spark 1.6, these tests were part of ShuffleMemoryManagerSuite. 
+ + implicit val ec = ExecutionContext.global + + test("single task requesting on-heap execution memory") { + val manager = createMemoryManager(1000L) + val taskMemoryManager = new TaskMemoryManager(manager, 0) + + assert(taskMemoryManager.acquireExecutionMemory(100L, MemoryMode.ON_HEAP, null) === 100L) + assert(taskMemoryManager.acquireExecutionMemory(400L, MemoryMode.ON_HEAP, null) === 400L) + assert(taskMemoryManager.acquireExecutionMemory(400L, MemoryMode.ON_HEAP, null) === 400L) + assert(taskMemoryManager.acquireExecutionMemory(200L, MemoryMode.ON_HEAP, null) === 100L) + assert(taskMemoryManager.acquireExecutionMemory(100L, MemoryMode.ON_HEAP, null) === 0L) + assert(taskMemoryManager.acquireExecutionMemory(100L, MemoryMode.ON_HEAP, null) === 0L) + + taskMemoryManager.releaseExecutionMemory(500L, MemoryMode.ON_HEAP, null) + assert(taskMemoryManager.acquireExecutionMemory(300L, MemoryMode.ON_HEAP, null) === 300L) + assert(taskMemoryManager.acquireExecutionMemory(300L, MemoryMode.ON_HEAP, null) === 200L) + + taskMemoryManager.cleanUpAllAllocatedMemory() + assert(taskMemoryManager.acquireExecutionMemory(1000L, MemoryMode.ON_HEAP, null) === 1000L) + assert(taskMemoryManager.acquireExecutionMemory(100L, MemoryMode.ON_HEAP, null) === 0L) + } + + test("two tasks requesting full on-heap execution memory") { + val memoryManager = createMemoryManager(1000L) + val t1MemManager = new TaskMemoryManager(memoryManager, 1) + val t2MemManager = new TaskMemoryManager(memoryManager, 2) + val futureTimeout: Duration = 20.seconds + + // Have both tasks request 500 bytes, then wait until both requests have been granted: + val t1Result1 = Future { t1MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } + val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } + assert(Await.result(t1Result1, futureTimeout) === 500L) + assert(Await.result(t2Result1, futureTimeout) === 500L) + + // Have both tasks each request 500 bytes more; both should immediately return 0 as they are + // both now at 1 / N + val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } + val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } + assert(Await.result(t1Result2, 200.millis) === 0L) + assert(Await.result(t2Result2, 200.millis) === 0L) + } + + test("two tasks cannot grow past 1 / N of on-heap execution memory") { + val memoryManager = createMemoryManager(1000L) + val t1MemManager = new TaskMemoryManager(memoryManager, 1) + val t2MemManager = new TaskMemoryManager(memoryManager, 2) + val futureTimeout: Duration = 20.seconds + + // Have both tasks request 250 bytes, then wait until both requests have been granted: + val t1Result1 = Future { t1MemManager.acquireExecutionMemory(250L, MemoryMode.ON_HEAP, null) } + val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L, MemoryMode.ON_HEAP, null) } + assert(Await.result(t1Result1, futureTimeout) === 250L) + assert(Await.result(t2Result1, futureTimeout) === 250L) + + // Have both tasks each request 500 bytes more. 
+ // We should only grant 250 bytes to each of them on this second request + val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } + val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } + assert(Await.result(t1Result2, futureTimeout) === 250L) + assert(Await.result(t2Result2, futureTimeout) === 250L) + } + + test("tasks can block to get at least 1 / 2N of on-heap execution memory") { + val memoryManager = createMemoryManager(1000L) + val t1MemManager = new TaskMemoryManager(memoryManager, 1) + val t2MemManager = new TaskMemoryManager(memoryManager, 2) + val futureTimeout: Duration = 20.seconds + + // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. + val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L, MemoryMode.ON_HEAP, null) } + assert(Await.result(t1Result1, futureTimeout) === 1000L) + val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L, MemoryMode.ON_HEAP, null) } + // Make sure that t2 didn't grab the memory right away. This is hacky but it would be difficult + // to make sure the other thread blocks for some time otherwise. + Thread.sleep(300) + t1MemManager.releaseExecutionMemory(250L, MemoryMode.ON_HEAP, null) + // The memory freed from t1 should now be granted to t2. + assert(Await.result(t2Result1, futureTimeout) === 250L) + // Further requests by t2 should be denied immediately because it now has 1 / 2N of the memory. + val t2Result2 = Future { t2MemManager.acquireExecutionMemory(100L, MemoryMode.ON_HEAP, null) } + assert(Await.result(t2Result2, 200.millis) === 0L) + } + + test("TaskMemoryManager.cleanUpAllAllocatedMemory") { + val memoryManager = createMemoryManager(1000L) + val t1MemManager = new TaskMemoryManager(memoryManager, 1) + val t2MemManager = new TaskMemoryManager(memoryManager, 2) + val futureTimeout: Duration = 20.seconds + + // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. + val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L, MemoryMode.ON_HEAP, null) } + assert(Await.result(t1Result1, futureTimeout) === 1000L) + val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } + // Make sure that t2 didn't grab the memory right away. This is hacky but it would be difficult + // to make sure the other thread blocks for some time otherwise. + Thread.sleep(300) + // t1 releases all of its memory, so t2 should be able to grab all of the memory + t1MemManager.cleanUpAllAllocatedMemory() + assert(Await.result(t2Result1, futureTimeout) === 500L) + val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } + assert(Await.result(t2Result2, futureTimeout) === 500L) + val t2Result3 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } + assert(Await.result(t2Result3, 200.millis) === 0L) + } + + test("tasks should not be granted a negative amount of execution memory") { + // This is a regression test for SPARK-4715. 
+ val memoryManager = createMemoryManager(1000L) + val t1MemManager = new TaskMemoryManager(memoryManager, 1) + val t2MemManager = new TaskMemoryManager(memoryManager, 2) + val futureTimeout: Duration = 20.seconds + + val t1Result1 = Future { t1MemManager.acquireExecutionMemory(700L, MemoryMode.ON_HEAP, null) } + assert(Await.result(t1Result1, futureTimeout) === 700L) + + val t2Result1 = Future { t2MemManager.acquireExecutionMemory(300L, MemoryMode.ON_HEAP, null) } + assert(Await.result(t2Result1, futureTimeout) === 300L) + + val t1Result2 = Future { t1MemManager.acquireExecutionMemory(300L, MemoryMode.ON_HEAP, null) } + assert(Await.result(t1Result2, 200.millis) === 0L) + } + + test("off-heap execution allocations cannot exceed limit") { + val memoryManager = createMemoryManager( + maxOnHeapExecutionMemory = 0L, + maxOffHeapExecutionMemory = 1000L) + + val tMemManager = new TaskMemoryManager(memoryManager, 1) + val result1 = Future { tMemManager.acquireExecutionMemory(1000L, MemoryMode.OFF_HEAP, null) } + assert(Await.result(result1, 200.millis) === 1000L) + assert(tMemManager.getMemoryConsumptionForThisTask === 1000L) + + val result2 = Future { tMemManager.acquireExecutionMemory(300L, MemoryMode.OFF_HEAP, null) } + assert(Await.result(result2, 200.millis) === 0L) + + assert(tMemManager.getMemoryConsumptionForThisTask === 1000L) + tMemManager.releaseExecutionMemory(500L, MemoryMode.OFF_HEAP, null) + assert(tMemManager.getMemoryConsumptionForThisTask === 500L) + tMemManager.releaseExecutionMemory(500L, MemoryMode.OFF_HEAP, null) + assert(tMemManager.getMemoryConsumptionForThisTask === 0L) + } +} + +private object MemoryManagerSuite { + private val DEFAULT_EVICT_BLOCKS_TO_FREE_SPACE_CALLED = -1L +} diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala new file mode 100644 index 000000000000..4b4c3b031132 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory + +import org.apache.spark.{SparkEnv, TaskContextImpl, TaskContext} + +/** + * Helper methods for mocking out memory-management-related classes in tests. 
+ */ +object MemoryTestingUtils { + def fakeTaskContext(env: SparkEnv): TaskContext = { + val taskMemoryManager = new TaskMemoryManager(env.memoryManager, 0) + new TaskContextImpl( + stageId = 0, + partitionId = 0, + taskAttemptId = 0, + attemptNumber = 0, + taskMemoryManager = taskMemoryManager, + metricsSystem = env.metricsSystem, + internalAccumulators = Seq.empty) + } +} diff --git a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala new file mode 100644 index 000000000000..68cf26fc3ed5 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory + +import org.mockito.Mockito.when + +import org.apache.spark.SparkConf +import org.apache.spark.storage.{MemoryStore, TestBlockId} + +class StaticMemoryManagerSuite extends MemoryManagerSuite { + private val conf = new SparkConf().set("spark.storage.unrollFraction", "0.4") + + /** + * Make a [[StaticMemoryManager]] and a [[MemoryStore]] with limited class dependencies. 
+ */ + private def makeThings( + maxExecutionMem: Long, + maxStorageMem: Long): (StaticMemoryManager, MemoryStore) = { + val mm = new StaticMemoryManager( + conf, + maxOnHeapExecutionMemory = maxExecutionMem, + maxStorageMemory = maxStorageMem, + numCores = 1) + val ms = makeMemoryStore(mm) + (mm, ms) + } + + override protected def createMemoryManager( + maxOnHeapExecutionMemory: Long, + maxOffHeapExecutionMemory: Long): StaticMemoryManager = { + new StaticMemoryManager( + conf.clone + .set("spark.memory.fraction", "1") + .set("spark.testing.memory", maxOnHeapExecutionMemory.toString) + .set("spark.memory.offHeap.size", maxOffHeapExecutionMemory.toString), + maxOnHeapExecutionMemory = maxOnHeapExecutionMemory, + maxStorageMemory = 0, + numCores = 1) + } + + test("basic execution memory") { + val maxExecutionMem = 1000L + val taskAttemptId = 0L + val (mm, _) = makeThings(maxExecutionMem, Long.MaxValue) + assert(mm.executionMemoryUsed === 0L) + assert(mm.acquireExecutionMemory(10L, taskAttemptId, MemoryMode.ON_HEAP) === 10L) + assert(mm.executionMemoryUsed === 10L) + assert(mm.acquireExecutionMemory(100L, taskAttemptId, MemoryMode.ON_HEAP) === 100L) + // Acquire up to the max + assert(mm.acquireExecutionMemory(1000L, taskAttemptId, MemoryMode.ON_HEAP) === 890L) + assert(mm.executionMemoryUsed === maxExecutionMem) + assert(mm.acquireExecutionMemory(1L, taskAttemptId, MemoryMode.ON_HEAP) === 0L) + assert(mm.executionMemoryUsed === maxExecutionMem) + mm.releaseExecutionMemory(800L, taskAttemptId, MemoryMode.ON_HEAP) + assert(mm.executionMemoryUsed === 200L) + // Acquire after release + assert(mm.acquireExecutionMemory(1L, taskAttemptId, MemoryMode.ON_HEAP) === 1L) + assert(mm.executionMemoryUsed === 201L) + // Release beyond what was acquired + mm.releaseExecutionMemory(maxExecutionMem, taskAttemptId, MemoryMode.ON_HEAP) + assert(mm.executionMemoryUsed === 0L) + } + + test("basic storage memory") { + val maxStorageMem = 1000L + val dummyBlock = TestBlockId("you can see the world you brought to live") + val (mm, ms) = makeThings(Long.MaxValue, maxStorageMem) + assert(mm.storageMemoryUsed === 0L) + assert(mm.acquireStorageMemory(dummyBlock, 10L, evictedBlocks)) + assertEvictBlocksToFreeSpaceNotCalled(ms) + assert(mm.storageMemoryUsed === 10L) + + assert(mm.acquireStorageMemory(dummyBlock, 100L, evictedBlocks)) + assertEvictBlocksToFreeSpaceNotCalled(ms) + assert(mm.storageMemoryUsed === 110L) + // Acquire more than the max, not granted + assert(!mm.acquireStorageMemory(dummyBlock, maxStorageMem + 1L, evictedBlocks)) + assertEvictBlocksToFreeSpaceNotCalled(ms) + assert(mm.storageMemoryUsed === 110L) + // Acquire up to the max, requests after this are still granted due to LRU eviction + assert(mm.acquireStorageMemory(dummyBlock, maxStorageMem, evictedBlocks)) + assertEvictBlocksToFreeSpaceCalled(ms, 110L) + assert(mm.storageMemoryUsed === 1000L) + assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) + assertEvictBlocksToFreeSpaceCalled(ms, 1L) + assert(evictedBlocks.nonEmpty) + evictedBlocks.clear() + // Note: We evicted 1 byte to put another 1-byte block in, so the storage memory used remains at + // 1000 bytes. This is different from real behavior, where the 1-byte block would have evicted + // the 1000-byte block entirely. This is set up differently so we can write finer-grained tests. 
+ assert(mm.storageMemoryUsed === 1000L) + mm.releaseStorageMemory(800L) + assert(mm.storageMemoryUsed === 200L) + // Acquire after release + assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) + assertEvictBlocksToFreeSpaceNotCalled(ms) + assert(mm.storageMemoryUsed === 201L) + mm.releaseAllStorageMemory() + assert(mm.storageMemoryUsed === 0L) + assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) + assertEvictBlocksToFreeSpaceNotCalled(ms) + assert(mm.storageMemoryUsed === 1L) + // Release beyond what was acquired + mm.releaseStorageMemory(100L) + assert(mm.storageMemoryUsed === 0L) + } + + test("execution and storage isolation") { + val maxExecutionMem = 200L + val maxStorageMem = 1000L + val taskAttemptId = 0L + val dummyBlock = TestBlockId("ain't nobody love like you do") + val (mm, ms) = makeThings(maxExecutionMem, maxStorageMem) + // Only execution memory should increase + assert(mm.acquireExecutionMemory(100L, taskAttemptId, MemoryMode.ON_HEAP) === 100L) + assert(mm.storageMemoryUsed === 0L) + assert(mm.executionMemoryUsed === 100L) + assert(mm.acquireExecutionMemory(1000L, taskAttemptId, MemoryMode.ON_HEAP) === 100L) + assert(mm.storageMemoryUsed === 0L) + assert(mm.executionMemoryUsed === 200L) + // Only storage memory should increase + assert(mm.acquireStorageMemory(dummyBlock, 50L, evictedBlocks)) + assertEvictBlocksToFreeSpaceNotCalled(ms) + assert(mm.storageMemoryUsed === 50L) + assert(mm.executionMemoryUsed === 200L) + // Only execution memory should be released + mm.releaseExecutionMemory(133L, taskAttemptId, MemoryMode.ON_HEAP) + assert(mm.storageMemoryUsed === 50L) + assert(mm.executionMemoryUsed === 67L) + // Only storage memory should be released + mm.releaseAllStorageMemory() + assert(mm.storageMemoryUsed === 0L) + assert(mm.executionMemoryUsed === 67L) + } + + test("unroll memory") { + val maxStorageMem = 1000L + val dummyBlock = TestBlockId("lonely water") + val (mm, ms) = makeThings(Long.MaxValue, maxStorageMem) + assert(mm.acquireUnrollMemory(dummyBlock, 100L, evictedBlocks)) + when(ms.currentUnrollMemory).thenReturn(100L) + assertEvictBlocksToFreeSpaceNotCalled(ms) + assert(mm.storageMemoryUsed === 100L) + mm.releaseUnrollMemory(40L) + assert(mm.storageMemoryUsed === 60L) + when(ms.currentUnrollMemory).thenReturn(60L) + assert(mm.acquireStorageMemory(dummyBlock, 800L, evictedBlocks)) + assertEvictBlocksToFreeSpaceNotCalled(ms) + assert(mm.storageMemoryUsed === 860L) + // `spark.storage.unrollFraction` is 0.4, so the max unroll space is 400 bytes. + // As of this point, cache memory is 800 bytes and current unroll memory is 60 bytes. + // Requesting 240 more bytes of unroll memory will leave our total unroll memory at + // 300 bytes, still under the 400-byte limit. Therefore, all 240 bytes are granted. + assert(mm.acquireUnrollMemory(dummyBlock, 240L, evictedBlocks)) + assertEvictBlocksToFreeSpaceCalled(ms, 100L) // 860 + 240 - 1000 + when(ms.currentUnrollMemory).thenReturn(300L) // 60 + 240 + assert(mm.storageMemoryUsed === 1000L) + evictedBlocks.clear() + // We already have 300 bytes of unroll memory, so requesting 150 more will leave us + // above the 400-byte limit. Since there is not enough free memory, this request will + // fail even after evicting as much as we can (400 - 300 = 100 bytes). 
+ assert(!mm.acquireUnrollMemory(dummyBlock, 150L, evictedBlocks)) + assertEvictBlocksToFreeSpaceCalled(ms, 100L) + assert(mm.storageMemoryUsed === 900L) + // Release beyond what was acquired + mm.releaseUnrollMemory(maxStorageMem) + assert(mm.storageMemoryUsed === 0L) + } + +} diff --git a/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala b/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala new file mode 100644 index 000000000000..0706a6e45de8 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory + +import scala.collection.mutable + +import org.apache.spark.SparkConf +import org.apache.spark.storage.{BlockStatus, BlockId} + +class TestMemoryManager(conf: SparkConf) + extends MemoryManager(conf, numCores = 1, Long.MaxValue, Long.MaxValue) { + + override private[memory] def acquireExecutionMemory( + numBytes: Long, + taskAttemptId: Long, + memoryMode: MemoryMode): Long = { + if (oomOnce) { + oomOnce = false + 0 + } else if (available >= numBytes) { + available -= numBytes + numBytes + } else { + val grant = available + available = 0 + grant + } + } + override def acquireStorageMemory( + blockId: BlockId, + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = true + override def acquireUnrollMemory( + blockId: BlockId, + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = true + override def releaseStorageMemory(numBytes: Long): Unit = {} + override private[memory] def releaseExecutionMemory( + numBytes: Long, + taskAttemptId: Long, + memoryMode: MemoryMode): Unit = { + available += numBytes + } + override def maxStorageMemory: Long = Long.MaxValue + + private var oomOnce = false + private var available = Long.MaxValue + + def markExecutionAsOutOfMemoryOnce(): Unit = { + oomOnce = true + } + + def limit(avail: Long): Unit = { + available = avail + } + +} diff --git a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala new file mode 100644 index 000000000000..6cc48597d38f --- /dev/null +++ b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala @@ -0,0 +1,258 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory + +import org.scalatest.PrivateMethodTester + +import org.apache.spark.SparkConf +import org.apache.spark.storage.{MemoryStore, TestBlockId} + +class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTester { + private val dummyBlock = TestBlockId("--") + + private val storageFraction: Double = 0.5 + + /** + * Make a [[UnifiedMemoryManager]] and a [[MemoryStore]] with limited class dependencies. + */ + private def makeThings(maxMemory: Long): (UnifiedMemoryManager, MemoryStore) = { + val mm = createMemoryManager(maxMemory) + val ms = makeMemoryStore(mm) + (mm, ms) + } + + override protected def createMemoryManager( + maxOnHeapExecutionMemory: Long, + maxOffHeapExecutionMemory: Long): UnifiedMemoryManager = { + val conf = new SparkConf() + .set("spark.memory.fraction", "1") + .set("spark.testing.memory", maxOnHeapExecutionMemory.toString) + .set("spark.memory.offHeap.size", maxOffHeapExecutionMemory.toString) + .set("spark.memory.storageFraction", storageFraction.toString) + UnifiedMemoryManager(conf, numCores = 1) + } + + test("basic execution memory") { + val maxMemory = 1000L + val taskAttemptId = 0L + val (mm, _) = makeThings(maxMemory) + assert(mm.executionMemoryUsed === 0L) + assert(mm.acquireExecutionMemory(10L, taskAttemptId, MemoryMode.ON_HEAP) === 10L) + assert(mm.executionMemoryUsed === 10L) + assert(mm.acquireExecutionMemory(100L, taskAttemptId, MemoryMode.ON_HEAP) === 100L) + // Acquire up to the max + assert(mm.acquireExecutionMemory(1000L, taskAttemptId, MemoryMode.ON_HEAP) === 890L) + assert(mm.executionMemoryUsed === maxMemory) + assert(mm.acquireExecutionMemory(1L, taskAttemptId, MemoryMode.ON_HEAP) === 0L) + assert(mm.executionMemoryUsed === maxMemory) + mm.releaseExecutionMemory(800L, taskAttemptId, MemoryMode.ON_HEAP) + assert(mm.executionMemoryUsed === 200L) + // Acquire after release + assert(mm.acquireExecutionMemory(1L, taskAttemptId, MemoryMode.ON_HEAP) === 1L) + assert(mm.executionMemoryUsed === 201L) + // Release beyond what was acquired + mm.releaseExecutionMemory(maxMemory, taskAttemptId, MemoryMode.ON_HEAP) + assert(mm.executionMemoryUsed === 0L) + } + + test("basic storage memory") { + val maxMemory = 1000L + val (mm, ms) = makeThings(maxMemory) + assert(mm.storageMemoryUsed === 0L) + assert(mm.acquireStorageMemory(dummyBlock, 10L, evictedBlocks)) + assertEvictBlocksToFreeSpaceNotCalled(ms) + assert(mm.storageMemoryUsed === 10L) + + assert(mm.acquireStorageMemory(dummyBlock, 100L, evictedBlocks)) + assertEvictBlocksToFreeSpaceNotCalled(ms) + assert(mm.storageMemoryUsed === 110L) + // Acquire more than the max, not granted + assert(!mm.acquireStorageMemory(dummyBlock, maxMemory + 1L, evictedBlocks)) + assertEvictBlocksToFreeSpaceNotCalled(ms) + assert(mm.storageMemoryUsed === 110L) + // Acquire up to the max, requests after this are still granted due to LRU eviction + assert(mm.acquireStorageMemory(dummyBlock, maxMemory, evictedBlocks)) + assertEvictBlocksToFreeSpaceCalled(ms, 110L) + assert(mm.storageMemoryUsed === 1000L) + assert(evictedBlocks.nonEmpty) + evictedBlocks.clear() + 
assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) + assertEvictBlocksToFreeSpaceCalled(ms, 1L) + assert(evictedBlocks.nonEmpty) + evictedBlocks.clear() + // Note: We evicted 1 byte to put another 1-byte block in, so the storage memory used remains at + // 1000 bytes. This is different from real behavior, where the 1-byte block would have evicted + // the 1000-byte block entirely. This is set up differently so we can write finer-grained tests. + assert(mm.storageMemoryUsed === 1000L) + mm.releaseStorageMemory(800L) + assert(mm.storageMemoryUsed === 200L) + // Acquire after release + assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) + assertEvictBlocksToFreeSpaceNotCalled(ms) + assert(mm.storageMemoryUsed === 201L) + mm.releaseAllStorageMemory() + assert(mm.storageMemoryUsed === 0L) + assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) + assertEvictBlocksToFreeSpaceNotCalled(ms) + assert(mm.storageMemoryUsed === 1L) + // Release beyond what was acquired + mm.releaseStorageMemory(100L) + assert(mm.storageMemoryUsed === 0L) + } + + test("execution evicts storage") { + val maxMemory = 1000L + val taskAttemptId = 0L + val (mm, ms) = makeThings(maxMemory) + // Acquire enough storage memory to exceed the storage region + assert(mm.acquireStorageMemory(dummyBlock, 750L, evictedBlocks)) + assertEvictBlocksToFreeSpaceNotCalled(ms) + assert(mm.executionMemoryUsed === 0L) + assert(mm.storageMemoryUsed === 750L) + // Execution needs to request 250 bytes to evict storage memory + assert(mm.acquireExecutionMemory(100L, taskAttemptId, MemoryMode.ON_HEAP) === 100L) + assert(mm.executionMemoryUsed === 100L) + assert(mm.storageMemoryUsed === 750L) + assertEvictBlocksToFreeSpaceNotCalled(ms) + // Execution wants 200 bytes but only 150 are free, so storage is evicted + assert(mm.acquireExecutionMemory(200L, taskAttemptId, MemoryMode.ON_HEAP) === 200L) + assert(mm.executionMemoryUsed === 300L) + assert(mm.storageMemoryUsed === 700L) + assertEvictBlocksToFreeSpaceCalled(ms, 50L) + assert(evictedBlocks.nonEmpty) + evictedBlocks.clear() + mm.releaseAllStorageMemory() + require(mm.executionMemoryUsed === 300L) + require(mm.storageMemoryUsed === 0, "bad test: all storage memory should have been released") + // Acquire some storage memory again, but this time keep it within the storage region + assert(mm.acquireStorageMemory(dummyBlock, 400L, evictedBlocks)) + assertEvictBlocksToFreeSpaceNotCalled(ms) + assert(mm.storageMemoryUsed === 400L) + assert(mm.executionMemoryUsed === 300L) + // Execution cannot evict storage because the latter is within the storage fraction, + // so grant only what's remaining without evicting anything, i.e. 
1000 - 300 - 400 = 300 + assert(mm.acquireExecutionMemory(400L, taskAttemptId, MemoryMode.ON_HEAP) === 300L) + assert(mm.executionMemoryUsed === 600L) + assert(mm.storageMemoryUsed === 400L) + assertEvictBlocksToFreeSpaceNotCalled(ms) + } + + test("execution memory requests smaller than free memory should evict storage (SPARK-12165)") { + val maxMemory = 1000L + val taskAttemptId = 0L + val (mm, ms) = makeThings(maxMemory) + // Acquire enough storage memory to exceed the storage region size + assert(mm.acquireStorageMemory(dummyBlock, 700L, evictedBlocks)) + assertEvictBlocksToFreeSpaceNotCalled(ms) + assert(mm.executionMemoryUsed === 0L) + assert(mm.storageMemoryUsed === 700L) + // SPARK-12165: previously, MemoryStore would not evict anything because it would + // mistakenly think that the 300 bytes of free space was still available even after + // using it to expand the execution pool. Consequently, no storage memory was released + // and the following call granted only 300 bytes to execution. + assert(mm.acquireExecutionMemory(500L, taskAttemptId, MemoryMode.ON_HEAP) === 500L) + assertEvictBlocksToFreeSpaceCalled(ms, 200L) + assert(mm.storageMemoryUsed === 500L) + assert(mm.executionMemoryUsed === 500L) + assert(evictedBlocks.nonEmpty) + } + + test("storage does not evict execution") { + val maxMemory = 1000L + val taskAttemptId = 0L + val (mm, ms) = makeThings(maxMemory) + // Acquire enough execution memory to exceed the execution region + assert(mm.acquireExecutionMemory(800L, taskAttemptId, MemoryMode.ON_HEAP) === 800L) + assert(mm.executionMemoryUsed === 800L) + assert(mm.storageMemoryUsed === 0L) + assertEvictBlocksToFreeSpaceNotCalled(ms) + // Storage should not be able to evict execution + assert(mm.acquireStorageMemory(dummyBlock, 100L, evictedBlocks)) + assert(mm.executionMemoryUsed === 800L) + assert(mm.storageMemoryUsed === 100L) + assertEvictBlocksToFreeSpaceNotCalled(ms) + assert(!mm.acquireStorageMemory(dummyBlock, 250L, evictedBlocks)) + assert(mm.executionMemoryUsed === 800L) + assert(mm.storageMemoryUsed === 100L) + // Do not attempt to evict blocks, since evicting will not free enough memory: + assertEvictBlocksToFreeSpaceNotCalled(ms) + mm.releaseExecutionMemory(maxMemory, taskAttemptId, MemoryMode.ON_HEAP) + mm.releaseStorageMemory(maxMemory) + // Acquire some execution memory again, but this time keep it within the execution region + assert(mm.acquireExecutionMemory(200L, taskAttemptId, MemoryMode.ON_HEAP) === 200L) + assert(mm.executionMemoryUsed === 200L) + assert(mm.storageMemoryUsed === 0L) + assertEvictBlocksToFreeSpaceNotCalled(ms) + // Storage should still not be able to evict execution + assert(mm.acquireStorageMemory(dummyBlock, 750L, evictedBlocks)) + assert(mm.executionMemoryUsed === 200L) + assert(mm.storageMemoryUsed === 750L) + assertEvictBlocksToFreeSpaceNotCalled(ms) // since there were 800 bytes free + assert(!mm.acquireStorageMemory(dummyBlock, 850L, evictedBlocks)) + assert(mm.executionMemoryUsed === 200L) + assert(mm.storageMemoryUsed === 750L) + // Do not attempt to evict blocks, since evicting will not free enough memory: + assertEvictBlocksToFreeSpaceNotCalled(ms) + } + + test("small heap") { + val systemMemory = 1024 * 1024 + val reservedMemory = 300 * 1024 + val memoryFraction = 0.8 + val conf = new SparkConf() + .set("spark.memory.fraction", memoryFraction.toString) + .set("spark.testing.memory", systemMemory.toString) + .set("spark.testing.reservedMemory", reservedMemory.toString) + val mm = UnifiedMemoryManager(conf, numCores = 1) + val 
expectedMaxMemory = ((systemMemory - reservedMemory) * memoryFraction).toLong + assert(mm.maxMemory === expectedMaxMemory) + + // Try using a system memory that's too small + val conf2 = conf.clone().set("spark.testing.memory", (reservedMemory / 2).toString) + val exception = intercept[IllegalArgumentException] { + UnifiedMemoryManager(conf2, numCores = 1) + } + assert(exception.getMessage.contains("larger heap size")) + } + + test("execution can evict cached blocks when there are multiple active tasks (SPARK-12155)") { + val conf = new SparkConf() + .set("spark.memory.fraction", "1") + .set("spark.memory.storageFraction", "0") + .set("spark.testing.memory", "1000") + val mm = UnifiedMemoryManager(conf, numCores = 2) + val ms = makeMemoryStore(mm) + assert(mm.maxMemory === 1000) + // Have two tasks each acquire some execution memory so that the memory pool registers that + // there are two active tasks: + assert(mm.acquireExecutionMemory(100L, 0, MemoryMode.ON_HEAP) === 100L) + assert(mm.acquireExecutionMemory(100L, 1, MemoryMode.ON_HEAP) === 100L) + // Fill up all of the remaining memory with storage. + assert(mm.acquireStorageMemory(dummyBlock, 800L, evictedBlocks)) + assertEvictBlocksToFreeSpaceNotCalled(ms) + assert(mm.storageMemoryUsed === 800) + assert(mm.executionMemoryUsed === 200) + // A task should still be able to allocate 100 bytes execution memory by evicting blocks + assert(mm.acquireExecutionMemory(100L, 0, MemoryMode.ON_HEAP) === 100L) + assertEvictBlocksToFreeSpaceCalled(ms, 100L) + assert(mm.executionMemoryUsed === 300) + assert(mm.storageMemoryUsed === 700) + assert(evictedBlocks.nonEmpty) + } + +} diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index d3218a548efc..44eb5a046912 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -286,6 +286,10 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext private def runAndReturnMetrics(job: => Unit, collector: (SparkListenerTaskEnd) => Option[Long]): Long = { val taskMetrics = new ArrayBuffer[Long]() + + // Avoid receiving earlier taskEnd events + sc.listenerBus.waitUntilEmpty(500) + sc.addSparkListener(new SparkListener() { override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { collector(taskEnd).foreach(taskMetrics += _) diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index 3940527fb874..98da94139f7f 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -148,7 +148,7 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi } }) - Await.ready(promise.future, FiniteDuration(1000, TimeUnit.MILLISECONDS)) + Await.ready(promise.future, FiniteDuration(10, TimeUnit.SECONDS)) promise.future.value.get } } diff --git a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala deleted file mode 100644 index 5e364cc0edeb..000000000000 --- a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala +++ /dev/null @@ -1,296 +0,0 @@ -/* - * Licensed to 
the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.io.IOException -import java.nio._ - -import scala.concurrent.duration._ -import scala.concurrent.{Await, TimeoutException} -import scala.language.postfixOps - -import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} -import org.apache.spark.util.Utils - -/** - * Test the ConnectionManager with various security settings. - */ -class ConnectionManagerSuite extends SparkFunSuite { - - test("security default off") { - val conf = new SparkConf - val securityManager = new SecurityManager(conf) - val manager = new ConnectionManager(0, conf, securityManager) - var receivedMessage = false - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - receivedMessage = true - None - }) - - val size = 10 * 1024 * 1024 - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - Await.result(manager.sendMessageReliably(manager.id, bufferMessage), 10 seconds) - - assert(receivedMessage == true) - - manager.stop() - } - - test("security on same password") { - val conf = new SparkConf - conf.set("spark.authenticate", "true") - conf.set("spark.authenticate.secret", "good") - conf.set("spark.app.id", "app-id") - val securityManager = new SecurityManager(conf) - val manager = new ConnectionManager(0, conf, securityManager) - var numReceivedMessages = 0 - - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - numReceivedMessages += 1 - None - }) - val managerServer = new ConnectionManager(0, conf, securityManager) - var numReceivedServerMessages = 0 - managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - numReceivedServerMessages += 1 - None - }) - - val size = 10 * 1024 * 1024 - val count = 10 - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - (0 until count).map(i => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - Await.result(manager.sendMessageReliably(managerServer.id, bufferMessage), 10 seconds) - }) - - assert(numReceivedServerMessages == 10) - assert(numReceivedMessages == 0) - - manager.stop() - managerServer.stop() - } - - test("security mismatch password") { - val conf = new SparkConf - conf.set("spark.authenticate", "true") - conf.set("spark.app.id", "app-id") - conf.set("spark.authenticate.secret", "good") - val securityManager = new SecurityManager(conf) - val manager = new ConnectionManager(0, conf, securityManager) - var numReceivedMessages = 0 - - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - numReceivedMessages += 1 - None - }) - - val badconf = 
conf.clone.set("spark.authenticate.secret", "bad") - val badsecurityManager = new SecurityManager(badconf) - val managerServer = new ConnectionManager(0, badconf, badsecurityManager) - var numReceivedServerMessages = 0 - - managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - numReceivedServerMessages += 1 - None - }) - - val size = 10 * 1024 * 1024 - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - // Expect managerServer to close connection, which we'll report as an error: - intercept[IOException] { - Await.result(manager.sendMessageReliably(managerServer.id, bufferMessage), 10 seconds) - } - - assert(numReceivedServerMessages == 0) - assert(numReceivedMessages == 0) - - manager.stop() - managerServer.stop() - } - - test("security mismatch auth off") { - val conf = new SparkConf - conf.set("spark.authenticate", "false") - conf.set("spark.authenticate.secret", "good") - val securityManager = new SecurityManager(conf) - val manager = new ConnectionManager(0, conf, securityManager) - var numReceivedMessages = 0 - - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - numReceivedMessages += 1 - None - }) - - val badconf = new SparkConf - badconf.set("spark.authenticate", "true") - badconf.set("spark.authenticate.secret", "good") - val badsecurityManager = new SecurityManager(badconf) - val managerServer = new ConnectionManager(0, badconf, badsecurityManager) - var numReceivedServerMessages = 0 - managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - numReceivedServerMessages += 1 - None - }) - - val size = 10 * 1024 * 1024 - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - (0 until 1).map(i => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliably(managerServer.id, bufferMessage) - }).foreach(f => { - try { - val g = Await.result(f, 1 second) - assert(false) - } catch { - case i: IOException => - assert(true) - case e: TimeoutException => { - // we should timeout here since the client can't do the negotiation - assert(true) - } - } - }) - - assert(numReceivedServerMessages == 0) - assert(numReceivedMessages == 0) - manager.stop() - managerServer.stop() - } - - test("security auth off") { - val conf = new SparkConf - conf.set("spark.authenticate", "false") - val securityManager = new SecurityManager(conf) - val manager = new ConnectionManager(0, conf, securityManager) - var numReceivedMessages = 0 - - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - numReceivedMessages += 1 - None - }) - - val badconf = new SparkConf - badconf.set("spark.authenticate", "false") - val badsecurityManager = new SecurityManager(badconf) - val managerServer = new ConnectionManager(0, badconf, badsecurityManager) - var numReceivedServerMessages = 0 - - managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - numReceivedServerMessages += 1 - None - }) - - val size = 10 * 1024 * 1024 - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - (0 until 10).map(i => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliably(managerServer.id, bufferMessage) - }).foreach(f => { - 
try { - val g = Await.result(f, 1 second) - } catch { - case e: Exception => { - assert(false) - } - } - }) - assert(numReceivedServerMessages == 10) - assert(numReceivedMessages == 0) - - manager.stop() - managerServer.stop() - } - - test("Ack error message") { - val conf = new SparkConf - conf.set("spark.authenticate", "false") - val securityManager = new SecurityManager(conf) - val manager = new ConnectionManager(0, conf, securityManager) - val managerServer = new ConnectionManager(0, conf, securityManager) - managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - throw new Exception("Custom exception text") - }) - - val size = 10 * 1024 * 1024 - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - val bufferMessage = Message.createBufferMessage(buffer) - - val future = manager.sendMessageReliably(managerServer.id, bufferMessage) - - val exception = intercept[IOException] { - Await.result(future, 1 second) - } - assert(Utils.exceptionString(exception).contains("Custom exception text")) - - manager.stop() - managerServer.stop() - - } - - test("sendMessageReliably timeout") { - val clientConf = new SparkConf - clientConf.set("spark.authenticate", "false") - val ackTimeoutS = 30 - clientConf.set("spark.core.connection.ack.wait.timeout", s"${ackTimeoutS}s") - - val clientSecurityManager = new SecurityManager(clientConf) - val manager = new ConnectionManager(0, clientConf, clientSecurityManager) - - val serverConf = new SparkConf - serverConf.set("spark.authenticate", "false") - val serverSecurityManager = new SecurityManager(serverConf) - val managerServer = new ConnectionManager(0, serverConf, serverSecurityManager) - managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - // sleep 60 sec > ack timeout for simulating server slow down or hang up - Thread.sleep(ackTimeoutS * 3 * 1000) - None - }) - - val size = 10 * 1024 * 1024 - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - - val future = manager.sendMessageReliably(managerServer.id, bufferMessage) - - // Future should throw IOException in 30 sec. - // Otherwise TimeoutExcepton is thrown from Await.result. - // We expect TimeoutException is not thrown. 
- intercept[IOException] { - Await.result(future, (ackTimeoutS * 2) second) - } - - manager.stop() - managerServer.stop() - } - -} - diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala index ec99f2a1bad6..de015ebd5d23 100644 --- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.rdd import java.util.concurrent.Semaphore -import scala.concurrent.{Await, TimeoutException} +import scala.concurrent._ import scala.concurrent.duration.Duration import scala.concurrent.ExecutionContext.Implicits.global @@ -27,7 +27,7 @@ import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ -import org.apache.spark.{LocalSparkContext, SparkContext, SparkException, SparkFunSuite} +import org.apache.spark._ class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Timeouts { @@ -197,4 +197,33 @@ class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Tim Await.result(f, Duration(20, "milliseconds")) } } + + private def testAsyncAction[R](action: RDD[Int] => FutureAction[R]): Unit = { + val executionContextInvoked = Promise[Unit] + val fakeExecutionContext = new ExecutionContext { + override def execute(runnable: Runnable): Unit = { + executionContextInvoked.success(()) + } + override def reportFailure(t: Throwable): Unit = () + } + val starter = Smuggle(new Semaphore(0)) + starter.drainPermits() + val rdd = sc.parallelize(1 to 100, 4).mapPartitions {itr => starter.acquire(1); itr} + val f = action(rdd) + f.onComplete(_ => ())(fakeExecutionContext) + // Here we verify that registering the callback didn't cause a thread to be consumed. + assert(!executionContextInvoked.isCompleted) + // Now allow the executors to proceed with task processing. + starter.release(rdd.partitions.length) + // Waiting for the result verifies that the tasks were successfully processed. 
+ Await.result(executionContextInvoked.future, atMost = 15.seconds) + } + + test("SimpleFutureAction callback must not consume a thread while waiting") { + testAsyncAction(_.countAsync()) + } + + test("ComplexFutureAction callback must not consume a thread while waiting") { + testAsyncAction((_.takeAsync(100))) + } } diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 1321ec84735b..7d2cfcca9436 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.rdd +import org.apache.commons.math3.distribution.{PoissonDistribution, BinomialDistribution} import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.mapred._ import org.apache.hadoop.util.Progressable @@ -578,17 +579,36 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { (x: Int) => if (x % 10 < (10 * fractionPositive).toInt) "1" else "0" } - def checkSize(exact: Boolean, - withReplacement: Boolean, - expected: Long, - actual: Long, - p: Double): Boolean = { + def assertBinomialSample( + exact: Boolean, + actual: Int, + trials: Int, + p: Double): Unit = { + if (exact) { + assert(actual == math.ceil(p * trials).toInt) + } else { + val dist = new BinomialDistribution(trials, p) + val q = dist.cumulativeProbability(actual) + withClue(s"p = $p: trials = $trials") { + assert(q >= 0.001 && q <= 0.999) + } + } + } + + def assertPoissonSample( + exact: Boolean, + actual: Int, + trials: Int, + p: Double): Unit = { if (exact) { - return expected == actual + assert(actual == math.ceil(p * trials).toInt) + } else { + val dist = new PoissonDistribution(p * trials) + val q = dist.cumulativeProbability(actual) + withClue(s"p = $p: trials = $trials") { + assert(q >= 0.001 && q <= 0.999) + } } - val stdev = if (withReplacement) math.sqrt(expected) else math.sqrt(expected * p * (1 - p)) - // Very forgiving margin since we're dealing with very small sample sizes most of the time - math.abs(actual - expected) <= 6 * stdev } def testSampleExact(stratifiedData: RDD[(String, Int)], @@ -613,8 +633,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { samplingRate: Double, seed: Long, n: Long): Unit = { - val expectedSampleSize = stratifiedData.countByKey() - .mapValues(count => math.ceil(count * samplingRate).toInt) + val trials = stratifiedData.countByKey() val fractions = Map("1" -> samplingRate, "0" -> samplingRate) val sample = if (exact) { stratifiedData.sampleByKeyExact(false, fractions, seed) @@ -623,8 +642,10 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { } val sampleCounts = sample.countByKey() val takeSample = sample.collect() - sampleCounts.foreach { case(k, v) => - assert(checkSize(exact, false, expectedSampleSize(k), v, samplingRate)) } + sampleCounts.foreach { case (k, v) => + assertBinomialSample(exact = exact, actual = v.toInt, trials = trials(k).toInt, + p = samplingRate) + } assert(takeSample.size === takeSample.toSet.size) takeSample.foreach { x => assert(1 <= x._2 && x._2 <= n, s"elements not in [1, $n]") } } @@ -635,6 +656,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { samplingRate: Double, seed: Long, n: Long): Unit = { + val trials = stratifiedData.countByKey() val expectedSampleSize = stratifiedData.countByKey().mapValues(count => math.ceil(count * samplingRate).toInt) 
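The sampling assertions introduced in this hunk replace the old six-standard-deviation margin with a distributional check: the observed per-stratum count is fed into the CDF of the sampling distribution (binomial for sampling without replacement, Poisson for sampling with replacement), and the check passes only if that CDF value lies between 0.001 and 0.999. A minimal, self-contained sketch of that check, assuming commons-math3 is on the classpath; the helper name and the example figures are illustrative only and are not part of the patch:

    import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution}

    // Accept an observed stratum count when its CDF value sits inside [0.001, 0.999],
    // i.e. a two-sided check that wrongly rejects roughly 0.2% of legitimate samples.
    def withinAcceptanceRegion(withReplacement: Boolean, actual: Int, trials: Int, p: Double): Boolean = {
      val q =
        if (withReplacement) new PoissonDistribution(p * trials).cumulativeProbability(actual)
        else new BinomialDistribution(trials, p).cumulativeProbability(actual)
      q >= 0.001 && q <= 0.999
    }

    // Example: sampling 1000 elements per stratum at rate 0.1 without replacement,
    // a count near the mean (100) is accepted, while 140 (about 4 standard deviations
    // above the mean) is rejected.
    withinAcceptanceRegion(withReplacement = false, actual = 100, trials = 1000, p = 0.1)  // true
    withinAcceptanceRegion(withReplacement = false, actual = 140, trials = 1000, p = 0.1)  // false
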
val fractions = Map("1" -> samplingRate, "0" -> samplingRate) @@ -646,7 +668,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { val sampleCounts = sample.countByKey() val takeSample = sample.collect() sampleCounts.foreach { case (k, v) => - assert(checkSize(exact, true, expectedSampleSize(k), v, samplingRate)) + assertPoissonSample(exact, actual = v.toInt, trials = trials(k).toInt, p = samplingRate) } val groupedByKey = takeSample.groupBy(_._1) for ((key, v) <- groupedByKey) { @@ -657,7 +679,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { if (exact) { assert(v.toSet.size <= expectedSampleSize(key)) } else { - assert(checkSize(false, true, expectedSampleSize(key), v.toSet.size, samplingRate)) + assertPoissonSample(false, actual = v.toSet.size, trials(key).toInt, p = samplingRate) } } } diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala index 3e8816a4c65b..5f73ec867596 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala @@ -175,7 +175,7 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext { } val hadoopPart1 = generateFakeHadoopPartition() val pipedRdd = new PipedRDD(nums, "printenv " + varName) - val tContext = new TaskContextImpl(0, 0, 0, 0, null, null) + val tContext = TaskContext.empty() val rddIter = pipedRdd.compute(hadoopPart1, tContext) val arr = rddIter.toArray assert(arr(0) == "/some/path") diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDOperationScopeSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDOperationScopeSuite.scala index f65349e3e358..16a92f54f936 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDOperationScopeSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDOperationScopeSuite.scala @@ -38,6 +38,13 @@ class RDDOperationScopeSuite extends SparkFunSuite with BeforeAndAfter { sc.stop() } + test("equals and hashCode") { + val opScope1 = new RDDOperationScope("scope1", id = "1") + val opScope2 = new RDDOperationScope("scope1", id = "1") + assert(opScope1 === opScope2) + assert(opScope1.hashCode() === opScope2.hashCode()) + } + test("getAllScopes") { assert(scope1.getAllScopes === Seq(scope1)) assert(scope2.getAllScopes === Seq(scope1, scope2)) diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 5f718ea9f7be..007a71f87cf1 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -34,6 +34,7 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { test("basic operations") { val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + assert(nums.getNumPartitions === 2) assert(nums.collect().toList === List(1, 2, 3, 4)) assert(nums.toLocalIterator.toList === List(1, 2, 3, 4)) val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2) @@ -100,21 +101,21 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { } test("SparkContext.union creates UnionRDD if at least one RDD has no partitioner") { - val rddWithPartitioner = sc.parallelize(Seq(1->true)).partitionBy(new HashPartitioner(1)) - val rddWithNoPartitioner = sc.parallelize(Seq(2->true)) + val rddWithPartitioner = sc.parallelize(Seq(1 -> true)).partitionBy(new HashPartitioner(1)) + val rddWithNoPartitioner = sc.parallelize(Seq(2 -> true)) val unionRdd = sc.union(rddWithNoPartitioner, 
rddWithPartitioner) assert(unionRdd.isInstanceOf[UnionRDD[_]]) } test("SparkContext.union creates PartitionAwareUnionRDD if all RDDs have partitioners") { - val rddWithPartitioner = sc.parallelize(Seq(1->true)).partitionBy(new HashPartitioner(1)) + val rddWithPartitioner = sc.parallelize(Seq(1 -> true)).partitionBy(new HashPartitioner(1)) val unionRdd = sc.union(rddWithPartitioner, rddWithPartitioner) assert(unionRdd.isInstanceOf[PartitionerAwareUnionRDD[_]]) } test("PartitionAwareUnionRDD raises exception if at least one RDD has no partitioner") { - val rddWithPartitioner = sc.parallelize(Seq(1->true)).partitionBy(new HashPartitioner(1)) - val rddWithNoPartitioner = sc.parallelize(Seq(2->true)) + val rddWithPartitioner = sc.parallelize(Seq(1 -> true)).partitionBy(new HashPartitioner(1)) + val rddWithNoPartitioner = sc.parallelize(Seq(2 -> true)) intercept[IllegalArgumentException] { new PartitionerAwareUnionRDD(sc, Seq(rddWithNoPartitioner, rddWithPartitioner)) } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 6ceafe433774..49e3e0191c38 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -17,6 +17,9 @@ package org.apache.spark.rpc +import java.io.{File, NotSerializableException} +import java.util.UUID +import java.nio.charset.StandardCharsets.UTF_8 import java.util.concurrent.{TimeUnit, CountDownLatch, TimeoutException} import scala.collection.mutable @@ -24,10 +27,14 @@ import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.postfixOps +import com.google.common.io.Files +import org.mockito.Mockito.{mock, when} import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} +import org.apache.spark.{SecurityManager, SparkConf, SparkEnv, SparkException, SparkFunSuite} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.util.Utils /** * Common tests for an RpcEnv implementation. 
@@ -38,16 +45,21 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { override def beforeAll(): Unit = { val conf = new SparkConf() - env = createRpcEnv(conf, "local", 12345) + env = createRpcEnv(conf, "local", 0) + + val sparkEnv = mock(classOf[SparkEnv]) + when(sparkEnv.rpcEnv).thenReturn(env) + SparkEnv.set(sparkEnv) } override def afterAll(): Unit = { if (env != null) { env.shutdown() } + SparkEnv.set(null) } - def createRpcEnv(conf: SparkConf, name: String, port: Int): RpcEnv + def createRpcEnv(conf: SparkConf, name: String, port: Int, clientMode: Boolean = false): RpcEnv test("send a message locally") { @volatile var message: String = null @@ -75,7 +87,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "send-remotely") try { @@ -99,7 +111,6 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } } val rpcEndpointRef = env.setupEndpoint("send-ref", endpoint) - val newRpcEndpointRef = rpcEndpointRef.askWithRetry[RpcEndpointRef]("Hello") val reply = newRpcEndpointRef.askWithRetry[String]("Echo") assert("Echo" === reply) @@ -130,7 +141,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-remotely") try { @@ -158,7 +169,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val shortProp = "spark.rpc.short.timeout" conf.set("spark.rpc.retry.wait", "0") conf.set("spark.rpc.numRetries", "1") - val anotherEnv = createRpcEnv(conf, "remote", 13345) + val anotherEnv = createRpcEnv(conf, "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-timeout") try { @@ -328,9 +339,6 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { override def onStop(): Unit = { selfOption = Option(self) } - - override def onError(cause: Throwable): Unit = { - } }) env.stop(endpointRef) @@ -420,7 +428,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "sendWithReply-remotely") try { @@ -460,7 +468,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef( "local", env.address, "sendWithReply-remotely-error") @@ -500,23 +508,82 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to 
find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef( "local", env.address, "network-events") val remoteAddress = anotherEnv.address rpcEndpointRef.send("hello") eventually(timeout(5 seconds), interval(5 millis)) { - assert(events === List(("onConnected", remoteAddress))) + // anotherEnv is connected in client mode, so the remote address may be unknown depending on + // the implementation. Account for that when doing checks. + if (remoteAddress != null) { + assert(events === List(("onConnected", remoteAddress))) + } else { + assert(events.size === 1) + assert(events(0)._1 === "onConnected") + } + } + + anotherEnv.shutdown() + anotherEnv.awaitTermination() + eventually(timeout(5 seconds), interval(5 millis)) { + // Account for anotherEnv not having an address due to running in client mode. + if (remoteAddress != null) { + assert(events === List( + ("onConnected", remoteAddress), + ("onNetworkError", remoteAddress), + ("onDisconnected", remoteAddress)) || + events === List( + ("onConnected", remoteAddress), + ("onDisconnected", remoteAddress))) + } else { + val eventNames = events.map(_._1) + assert(eventNames === List("onConnected", "onNetworkError", "onDisconnected") || + eventNames === List("onConnected", "onDisconnected")) + } + } + } + + test("network events between non-client-mode RpcEnvs") { + val events = new mutable.ArrayBuffer[(Any, Any)] with mutable.SynchronizedBuffer[(Any, Any)] + env.setupEndpoint("network-events-non-client", new ThreadSafeRpcEndpoint { + override val rpcEnv = env + + override def receive: PartialFunction[Any, Unit] = { + case "hello" => + case m => events += "receive" -> m + } + + override def onConnected(remoteAddress: RpcAddress): Unit = { + events += "onConnected" -> remoteAddress + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + events += "onDisconnected" -> remoteAddress + } + + override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + events += "onNetworkError" -> remoteAddress + } + + }) + + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = false) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef( + "local", env.address, "network-events-non-client") + val remoteAddress = anotherEnv.address + rpcEndpointRef.send("hello") + eventually(timeout(5 seconds), interval(5 millis)) { + assert(events.contains(("onConnected", remoteAddress))) } anotherEnv.shutdown() anotherEnv.awaitTermination() eventually(timeout(5 seconds), interval(5 millis)) { - assert(events === List( - ("onConnected", remoteAddress), - ("onNetworkError", remoteAddress), - ("onDisconnected", remoteAddress))) + assert(events.contains(("onConnected", remoteAddress))) + assert(events.contains(("onDisconnected", remoteAddress))) } } @@ -529,21 +596,90 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef( "local", env.address, "sendWithReply-unserializable-error") try { val f = rpcEndpointRef.ask[String]("hello") - intercept[TimeoutException] { + val e = intercept[Exception] { Await.result(f, 1 seconds) } + assert(e.isInstanceOf[TimeoutException] || // For Akka + e.isInstanceOf[NotSerializableException] // For Netty + ) } finally { anotherEnv.shutdown() 
anotherEnv.awaitTermination() } } + test("port conflict") { + val anotherEnv = createRpcEnv(new SparkConf(), "remote", env.address.port) + assert(anotherEnv.address.port != env.address.port) + } + + test("send with authentication") { + val conf = new SparkConf + conf.set("spark.authenticate", "true") + conf.set("spark.authenticate.secret", "good") + + val localEnv = createRpcEnv(conf, "authentication-local", 0) + val remoteEnv = createRpcEnv(conf, "authentication-remote", 0, clientMode = true) + + try { + @volatile var message: String = null + localEnv.setupEndpoint("send-authentication", new RpcEndpoint { + override val rpcEnv = localEnv + + override def receive: PartialFunction[Any, Unit] = { + case msg: String => message = msg + } + }) + val rpcEndpointRef = + remoteEnv.setupEndpointRef("authentication-local", localEnv.address, "send-authentication") + rpcEndpointRef.send("hello") + eventually(timeout(5 seconds), interval(10 millis)) { + assert("hello" === message) + } + } finally { + localEnv.shutdown() + localEnv.awaitTermination() + remoteEnv.shutdown() + remoteEnv.awaitTermination() + } + } + + test("ask with authentication") { + val conf = new SparkConf + conf.set("spark.authenticate", "true") + conf.set("spark.authenticate.secret", "good") + + val localEnv = createRpcEnv(conf, "authentication-local", 0) + val remoteEnv = createRpcEnv(conf, "authentication-remote", 0, clientMode = true) + + try { + localEnv.setupEndpoint("ask-authentication", new RpcEndpoint { + override val rpcEnv = localEnv + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case msg: String => { + context.reply(msg) + } + } + }) + val rpcEndpointRef = + remoteEnv.setupEndpointRef("authentication-local", localEnv.address, "ask-authentication") + val reply = rpcEndpointRef.askWithRetry[String]("hello") + assert("hello" === reply) + } finally { + localEnv.shutdown() + localEnv.awaitTermination() + remoteEnv.shutdown() + remoteEnv.awaitTermination() + } + } + test("construct RpcTimeout with conf property") { val conf = new SparkConf @@ -612,7 +748,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { // once the future is complete to verify addMessageIfTimeout was invoked val reply3 = intercept[RpcTimeoutException] { - Await.result(fut3, 200 millis) + Await.result(fut3, 2000 millis) }.getMessage // When the future timed out, the recover callback should have used @@ -630,6 +766,68 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { assert(shortTimeout.timeoutProp.r.findAllIn(reply4).length === 1) } + test("file server") { + val conf = new SparkConf() + val tempDir = Utils.createTempDir() + val file = new File(tempDir, "file") + Files.write(UUID.randomUUID().toString(), file, UTF_8) + val fileWithSpecialChars = new File(tempDir, "file name") + Files.write(UUID.randomUUID().toString(), fileWithSpecialChars, UTF_8) + val empty = new File(tempDir, "empty") + Files.write("", empty, UTF_8); + val jar = new File(tempDir, "jar") + Files.write(UUID.randomUUID().toString(), jar, UTF_8) + + val dir1 = new File(tempDir, "dir1") + assert(dir1.mkdir()) + val subFile1 = new File(dir1, "file1") + Files.write(UUID.randomUUID().toString(), subFile1, UTF_8) + + val dir2 = new File(tempDir, "dir2") + assert(dir2.mkdir()) + val subFile2 = new File(dir2, "file2") + Files.write(UUID.randomUUID().toString(), subFile2, UTF_8) + + val fileUri = env.fileServer.addFile(file) + val fileWithSpecialCharsUri = env.fileServer.addFile(fileWithSpecialChars) + val 
emptyUri = env.fileServer.addFile(empty) + val jarUri = env.fileServer.addJar(jar) + val dir1Uri = env.fileServer.addDirectory("/dir1", dir1) + val dir2Uri = env.fileServer.addDirectory("/dir2", dir2) + + // Try registering directories with invalid names. + Seq("/files", "/jars").foreach { uri => + intercept[IllegalArgumentException] { + env.fileServer.addDirectory(uri, dir1) + } + } + + val destDir = Utils.createTempDir() + val sm = new SecurityManager(conf) + val hc = SparkHadoopUtil.get.conf + + val files = Seq( + (file, fileUri), + (fileWithSpecialChars, fileWithSpecialCharsUri), + (empty, emptyUri), + (jar, jarUri), + (subFile1, dir1Uri + "/file1"), + (subFile2, dir2Uri + "/file2")) + files.foreach { case (f, uri) => + val destFile = new File(destDir, f.getName()) + Utils.fetchFile(uri, destDir, conf, sm, hc, 0L, false) + assert(Files.equal(f, destFile)) + } + + // Try to download files that do not exist. + Seq("files", "jars", "dir1").foreach { root => + intercept[Exception] { + val uri = env.address.toSparkURL + s"/$root/doesNotExist" + Utils.fetchFile(uri, destDir, conf, sm, hc, 0L, false) + } + } + } + } class UnserializableClass diff --git a/core/src/test/scala/org/apache/spark/rpc/TestRpcEndpoint.scala b/core/src/test/scala/org/apache/spark/rpc/TestRpcEndpoint.scala new file mode 100644 index 000000000000..5e8da3e205ab --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rpc/TestRpcEndpoint.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc + +import scala.collection.mutable.ArrayBuffer + +import org.scalactic.TripleEquals + +class TestRpcEndpoint extends ThreadSafeRpcEndpoint with TripleEquals { + + override val rpcEnv: RpcEnv = null + + @volatile private var receiveMessages = ArrayBuffer[Any]() + + @volatile private var receiveAndReplyMessages = ArrayBuffer[Any]() + + @volatile private var onConnectedMessages = ArrayBuffer[RpcAddress]() + + @volatile private var onDisconnectedMessages = ArrayBuffer[RpcAddress]() + + @volatile private var onNetworkErrorMessages = ArrayBuffer[(Throwable, RpcAddress)]() + + @volatile private var started = false + + @volatile private var stopped = false + + override def receive: PartialFunction[Any, Unit] = { + case message: Any => receiveMessages += message + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case message: Any => receiveAndReplyMessages += message + } + + override def onConnected(remoteAddress: RpcAddress): Unit = { + onConnectedMessages += remoteAddress + } + + /** + * Invoked when some network error happens in the connection between the current node and + * `remoteAddress`. 
+ */ + override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + onNetworkErrorMessages += cause -> remoteAddress + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + onDisconnectedMessages += remoteAddress + } + + def numReceiveMessages: Int = receiveMessages.size + + override def onStart(): Unit = { + started = true + } + + override def onStop(): Unit = { + stopped = true + } + + def verifyStarted(): Unit = { + assert(started, "RpcEndpoint is not started") + } + + def verifyStopped(): Unit = { + assert(stopped, "RpcEndpoint is not stopped") + } + + def verifyReceiveMessages(expected: Seq[Any]): Unit = { + assert(receiveMessages === expected) + } + + def verifySingleReceiveMessage(message: Any): Unit = { + verifyReceiveMessages(List(message)) + } + + def verifyReceiveAndReplyMessages(expected: Seq[Any]): Unit = { + assert(receiveAndReplyMessages === expected) + } + + def verifySingleReceiveAndReplyMessage(message: Any): Unit = { + verifyReceiveAndReplyMessages(List(message)) + } + + def verifySingleOnConnectedMessage(remoteAddress: RpcAddress): Unit = { + verifyOnConnectedMessages(List(remoteAddress)) + } + + def verifyOnConnectedMessages(expected: Seq[RpcAddress]): Unit = { + assert(onConnectedMessages === expected) + } + + def verifySingleOnDisconnectedMessage(remoteAddress: RpcAddress): Unit = { + verifyOnDisconnectedMessages(List(remoteAddress)) + } + + def verifyOnDisconnectedMessages(expected: Seq[RpcAddress]): Unit = { + assert(onDisconnectedMessages === expected) + } + + def verifySingleOnNetworkErrorMessage(cause: Throwable, remoteAddress: RpcAddress): Unit = { + verifyOnNetworkErrorMessages(List(cause -> remoteAddress)) + } + + def verifyOnNetworkErrorMessages(expected: Seq[(Throwable, RpcAddress)]): Unit = { + assert(onNetworkErrorMessages === expected) + } +} diff --git a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala index 4aa75c9230b2..7aac02775e1b 100644 --- a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala @@ -22,9 +22,12 @@ import org.apache.spark.{SSLSampleConfigs, SecurityManager, SparkConf} class AkkaRpcEnvSuite extends RpcEnvSuite { - override def createRpcEnv(conf: SparkConf, name: String, port: Int): RpcEnv = { + override def createRpcEnv(conf: SparkConf, + name: String, + port: Int, + clientMode: Boolean = false): RpcEnv = { new AkkaRpcEnvFactory().create( - RpcEnvConfig(conf, name, "localhost", port, new SecurityManager(conf))) + RpcEnvConfig(conf, name, "localhost", port, new SecurityManager(conf), clientMode)) } test("setupEndpointRef: systemName, address, endpointName") { @@ -37,7 +40,7 @@ class AkkaRpcEnvSuite extends RpcEnvSuite { }) val conf = new SparkConf() val newRpcEnv = new AkkaRpcEnvFactory().create( - RpcEnvConfig(conf, "test", "localhost", 12346, new SecurityManager(conf))) + RpcEnvConfig(conf, "test", "localhost", 0, new SecurityManager(conf), false)) try { val newRef = newRpcEnv.setupEndpointRef("local", ref.address, "test_endpoint") assert(s"akka.tcp://local@${env.address}/user/test_endpoint" === @@ -56,7 +59,7 @@ class AkkaRpcEnvSuite extends RpcEnvSuite { val conf = SSLSampleConfigs.sparkSSLConfig() val securityManager = new SecurityManager(conf) val rpcEnv = new AkkaRpcEnvFactory().create( - RpcEnvConfig(conf, "test", "localhost", 12346, securityManager)) + RpcEnvConfig(conf, "test", "localhost", 0, 
securityManager, false)) try { val uri = rpcEnv.uriOf("local", RpcAddress("1.2.3.4", 12345), "test_endpoint") assert("akka.ssl.tcp://local@1.2.3.4:12345/user/test_endpoint" === uri) diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala new file mode 100644 index 000000000000..2136795b1881 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc.netty + +import java.util.concurrent.{CountDownLatch, TimeUnit} +import java.util.concurrent.atomic.AtomicInteger + +import org.mockito.Mockito._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.rpc.{RpcEnv, RpcEndpoint, RpcAddress, TestRpcEndpoint} + +class InboxSuite extends SparkFunSuite { + + test("post") { + val endpoint = new TestRpcEndpoint + val endpointRef = mock(classOf[NettyRpcEndpointRef]) + when(endpointRef.name).thenReturn("hello") + + val dispatcher = mock(classOf[Dispatcher]) + + val inbox = new Inbox(endpointRef, endpoint) + val message = OneWayMessage(null, "hi") + inbox.post(message) + inbox.process(dispatcher) + assert(inbox.isEmpty) + + endpoint.verifySingleReceiveMessage("hi") + + inbox.stop() + inbox.process(dispatcher) + assert(inbox.isEmpty) + endpoint.verifyStarted() + endpoint.verifyStopped() + } + + test("post: with reply") { + val endpoint = new TestRpcEndpoint + val endpointRef = mock(classOf[NettyRpcEndpointRef]) + val dispatcher = mock(classOf[Dispatcher]) + + val inbox = new Inbox(endpointRef, endpoint) + val message = RpcMessage(null, "hi", null) + inbox.post(message) + inbox.process(dispatcher) + assert(inbox.isEmpty) + + endpoint.verifySingleReceiveAndReplyMessage("hi") + } + + test("post: multiple threads") { + val endpoint = new TestRpcEndpoint + val endpointRef = mock(classOf[NettyRpcEndpointRef]) + when(endpointRef.name).thenReturn("hello") + + val dispatcher = mock(classOf[Dispatcher]) + + val numDroppedMessages = new AtomicInteger(0) + val inbox = new Inbox(endpointRef, endpoint) { + override def onDrop(message: InboxMessage): Unit = { + numDroppedMessages.incrementAndGet() + } + } + + val exitLatch = new CountDownLatch(10) + + for (_ <- 0 until 10) { + new Thread { + override def run(): Unit = { + for (_ <- 0 until 100) { + val message = OneWayMessage(null, "hi") + inbox.post(message) + } + exitLatch.countDown() + } + }.start() + } + // Try to process some messages + inbox.process(dispatcher) + inbox.stop() + // After `stop` is called, further messages will be dropped. However, while `stop` is called, + // some messages may be post to Inbox, so process them here. 
+ inbox.process(dispatcher) + assert(inbox.isEmpty) + + exitLatch.await(30, TimeUnit.SECONDS) + + assert(1000 === endpoint.numReceiveMessages + numDroppedMessages.get) + endpoint.verifyStarted() + endpoint.verifyStopped() + } + + test("post: Associated") { + val endpoint = new TestRpcEndpoint + val endpointRef = mock(classOf[NettyRpcEndpointRef]) + val dispatcher = mock(classOf[Dispatcher]) + + val remoteAddress = RpcAddress("localhost", 11111) + + val inbox = new Inbox(endpointRef, endpoint) + inbox.post(RemoteProcessConnected(remoteAddress)) + inbox.process(dispatcher) + + endpoint.verifySingleOnConnectedMessage(remoteAddress) + } + + test("post: Disassociated") { + val endpoint = new TestRpcEndpoint + val endpointRef = mock(classOf[NettyRpcEndpointRef]) + val dispatcher = mock(classOf[Dispatcher]) + + val remoteAddress = RpcAddress("localhost", 11111) + + val inbox = new Inbox(endpointRef, endpoint) + inbox.post(RemoteProcessDisconnected(remoteAddress)) + inbox.process(dispatcher) + + endpoint.verifySingleOnDisconnectedMessage(remoteAddress) + } + + test("post: AssociationError") { + val endpoint = new TestRpcEndpoint + val endpointRef = mock(classOf[NettyRpcEndpointRef]) + val dispatcher = mock(classOf[Dispatcher]) + + val remoteAddress = RpcAddress("localhost", 11111) + val cause = new RuntimeException("Oops") + + val inbox = new Inbox(endpointRef, endpoint) + inbox.post(RemoteProcessConnectionError(cause, remoteAddress)) + inbox.process(dispatcher) + + endpoint.verifySingleOnNetworkErrorMessage(cause, remoteAddress) + } +} diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala similarity index 60% rename from core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala rename to core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala index 1cd13d887c6f..56743ba650b4 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala @@ -15,23 +15,20 @@ * limitations under the License. */ -package org.apache.spark.network.nio +package org.apache.spark.rpc.netty -import java.net.InetSocketAddress +import org.apache.spark.SparkFunSuite -import org.apache.spark.util.Utils - -private[nio] case class ConnectionManagerId(host: String, port: Int) { - // DEBUG code - Utils.checkHost(host) - assert (port > 0) - - def toSocketAddress(): InetSocketAddress = new InetSocketAddress(host, port) -} +class NettyRpcAddressSuite extends SparkFunSuite { + test("toString") { + val addr = new RpcEndpointAddress("localhost", 12345, "test") + assert(addr.toString === "spark://test@localhost:12345") + } -private[nio] object ConnectionManagerId { - def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = { - new ConnectionManagerId(socketAddress.getHostName, socketAddress.getPort) + test("toString for client mode") { + val addr = RpcEndpointAddress(null, "test") + assert(addr.toString === "spark-client://test") } + } diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala new file mode 100644 index 000000000000..ce83087ec04d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc.netty + +import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.rpc._ + +class NettyRpcEnvSuite extends RpcEnvSuite { + + override def createRpcEnv( + conf: SparkConf, + name: String, + port: Int, + clientMode: Boolean = false): RpcEnv = { + val config = RpcEnvConfig(conf, "test", "localhost", port, new SecurityManager(conf), + clientMode) + new NettyRpcEnvFactory().create(config) + } + + test("non-existent endpoint") { + val uri = env.uriOf("test", env.address, "nonexist-endpoint") + val e = intercept[RpcEndpointNotFoundException] { + env.setupEndpointRef("test", env.address, "nonexist-endpoint") + } + assert(e.getMessage.contains(uri)) + } + +} diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala new file mode 100644 index 000000000000..ebd6f700710b --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.rpc.netty + +import java.net.InetSocketAddress +import java.nio.ByteBuffer + +import io.netty.channel.Channel +import org.mockito.Mockito._ +import org.mockito.Matchers._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.network.client.{TransportResponseHandler, TransportClient} +import org.apache.spark.network.server.StreamManager +import org.apache.spark.rpc._ + +class NettyRpcHandlerSuite extends SparkFunSuite { + + val env = mock(classOf[NettyRpcEnv]) + val sm = mock(classOf[StreamManager]) + when(env.deserialize(any(classOf[TransportClient]), any(classOf[ByteBuffer]))(any())) + .thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null)) + + test("receive") { + val dispatcher = mock(classOf[Dispatcher]) + val nettyRpcHandler = new NettyRpcHandler(dispatcher, env, sm) + + val channel = mock(classOf[Channel]) + val client = new TransportClient(channel, mock(classOf[TransportResponseHandler])) + when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000)) + nettyRpcHandler.receive(client, null, null) + + verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 40000))) + } + + test("connectionTerminated") { + val dispatcher = mock(classOf[Dispatcher]) + val nettyRpcHandler = new NettyRpcHandler(dispatcher, env, sm) + + val channel = mock(classOf[Channel]) + val client = new TransportClient(channel, mock(classOf[TransportResponseHandler])) + when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000)) + nettyRpcHandler.receive(client, null, null) + + when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000)) + nettyRpcHandler.connectionTerminated(client) + + verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 40000))) + verify(dispatcher, times(1)).postToAll( + RemoteProcessDisconnected(RpcAddress("localhost", 40000))) + } + +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/AdaptiveSchedulingSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/AdaptiveSchedulingSuite.scala new file mode 100644 index 000000000000..e0f474aa505c --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/AdaptiveSchedulingSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler + +import org.apache.spark._ + +object AdaptiveSchedulingSuiteState { + var tasksRun = 0 + + def clear(): Unit = { + tasksRun = 0 + } +} + +class AdaptiveSchedulingSuite extends SparkFunSuite with LocalSparkContext { + test("simple use of submitMapStage") { + try { + sc = new SparkContext("local", "test") + val rdd = sc.parallelize(1 to 3, 3).map { x => + AdaptiveSchedulingSuiteState.tasksRun += 1 + (x, x) + } + val dep = new ShuffleDependency[Int, Int, Int](rdd, new HashPartitioner(2)) + val shuffled = new CustomShuffledRDD[Int, Int, Int](dep) + sc.submitMapStage(dep).get() + assert(AdaptiveSchedulingSuiteState.tasksRun == 3) + assert(shuffled.collect().toSet == Set((1, 1), (2, 2), (3, 3))) + assert(AdaptiveSchedulingSuiteState.tasksRun == 3) + } finally { + AdaptiveSchedulingSuiteState.clear() + } + } + + test("fetching multiple map output partitions per reduce") { + sc = new SparkContext("local", "test") + val rdd = sc.parallelize(0 to 2, 3).map(x => (x, x)) + val dep = new ShuffleDependency[Int, Int, Int](rdd, new HashPartitioner(3)) + val shuffled = new CustomShuffledRDD[Int, Int, Int](dep, Array(0, 2)) + assert(shuffled.partitions.length === 2) + assert(shuffled.glom().map(_.toSet).collect().toSet == Set(Set((0, 0), (1, 1)), Set((2, 2)))) + } + + test("fetching all map output partitions in one reduce") { + sc = new SparkContext("local", "test") + val rdd = sc.parallelize(0 to 2, 3).map(x => (x, x)) + // Also create lots of hash partitions so that some of them are empty + val dep = new ShuffleDependency[Int, Int, Int](rdd, new HashPartitioner(5)) + val shuffled = new CustomShuffledRDD[Int, Int, Int](dep, Array(0)) + assert(shuffled.partitions.length === 1) + assert(shuffled.collect().toSet == Set((0, 0), (1, 1), (2, 2))) + } + + test("more reduce tasks than map output partitions") { + sc = new SparkContext("local", "test") + val rdd = sc.parallelize(0 to 2, 3).map(x => (x, x)) + val dep = new ShuffleDependency[Int, Int, Int](rdd, new HashPartitioner(3)) + val shuffled = new CustomShuffledRDD[Int, Int, Int](dep, Array(0, 0, 0, 1, 1, 1, 2)) + assert(shuffled.partitions.length === 7) + assert(shuffled.collect().toSet == Set((0, 0), (1, 1), (2, 2))) + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/CustomShuffledRDD.scala b/core/src/test/scala/org/apache/spark/scheduler/CustomShuffledRDD.scala new file mode 100644 index 000000000000..d8d818ceed45 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/CustomShuffledRDD.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler + +import java.util.Arrays + +import org.apache.spark._ +import org.apache.spark.rdd.RDD + +/** + * A Partitioner that might group together one or more partitions from the parent. + * + * @param parent a parent partitioner + * @param partitionStartIndices indices of partitions in parent that should create new partitions + * in child (this should be an array of increasing partition IDs). For example, if we have a + * parent with 5 partitions, and partitionStartIndices is [0, 2, 4], we get three output + * partitions, corresponding to partition ranges [0, 1], [2, 3] and [4] of the parent partitioner. + */ +class CoalescedPartitioner(val parent: Partitioner, val partitionStartIndices: Array[Int]) + extends Partitioner { + + @transient private lazy val parentPartitionMapping: Array[Int] = { + val n = parent.numPartitions + val result = new Array[Int](n) + for (i <- 0 until partitionStartIndices.length) { + val start = partitionStartIndices(i) + val end = if (i < partitionStartIndices.length - 1) partitionStartIndices(i + 1) else n + for (j <- start until end) { + result(j) = i + } + } + result + } + + override def numPartitions: Int = partitionStartIndices.size + + override def getPartition(key: Any): Int = { + parentPartitionMapping(parent.getPartition(key)) + } + + override def equals(other: Any): Boolean = other match { + case c: CoalescedPartitioner => + c.parent == parent && Arrays.equals(c.partitionStartIndices, partitionStartIndices) + case _ => + false + } +} + +private[spark] class CustomShuffledRDDPartition( + val index: Int, val startIndexInParent: Int, val endIndexInParent: Int) + extends Partition { + + override def hashCode(): Int = index +} + +/** + * A special ShuffledRDD that supports a ShuffleDependency object from outside and launching reduce + * tasks that read multiple map output partitions. 
+ */ +class CustomShuffledRDD[K, V, C]( + var dependency: ShuffleDependency[K, V, C], + partitionStartIndices: Array[Int]) + extends RDD[(K, C)](dependency.rdd.context, Seq(dependency)) { + + def this(dep: ShuffleDependency[K, V, C]) = { + this(dep, (0 until dep.partitioner.numPartitions).toArray) + } + + override def getDependencies: Seq[Dependency[_]] = List(dependency) + + override val partitioner = { + Some(new CoalescedPartitioner(dependency.partitioner, partitionStartIndices)) + } + + override def getPartitions: Array[Partition] = { + val n = dependency.partitioner.numPartitions + Array.tabulate[Partition](partitionStartIndices.length) { i => + val startIndex = partitionStartIndices(i) + val endIndex = if (i < partitionStartIndices.length - 1) partitionStartIndices(i + 1) else n + new CustomShuffledRDDPartition(i, startIndex, endIndex) + } + } + + override def compute(p: Partition, context: TaskContext): Iterator[(K, C)] = { + val part = p.asInstanceOf[CustomShuffledRDDPartition] + SparkEnv.get.shuffleManager.getReader( + dependency.shuffleHandle, part.startIndexInParent, part.endIndexInParent, context) + .read() + .asInstanceOf[Iterator[(K, C)]] + } + + override def clearDependencies() { + super.clearDependencies() + dependency = null + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 86dff8fb577d..2869f0fde4c5 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.scheduler +import java.util.Properties + import scala.collection.mutable.{ArrayBuffer, HashSet, HashMap, Map} import scala.language.reflectiveCalls import scala.util.control.NonFatal @@ -26,11 +28,11 @@ import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ import org.apache.spark._ +import org.apache.spark.executor.TaskMetrics import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} import org.apache.spark.util.CallSite -import org.apache.spark.executor.TaskMetrics class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler) extends DAGSchedulerEventProcessLoop(dagScheduler) { @@ -43,25 +45,52 @@ class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler) case NonFatal(e) => onError(e) } } + + override def onError(e: Throwable): Unit = { + logError("Error in DAGSchedulerEventLoop: ", e) + dagScheduler.stop() + throw e + } + } /** * An RDD for passing to DAGScheduler. These RDDs will use the dependencies and * preferredLocations (if any) that are passed to them. They are deliberately not executable * so we can test that DAGScheduler does not try to execute RDDs locally. + * + * Optionally, one can pass in a list of locations to use as preferred locations for each task, + * and a MapOutputTrackerMaster to enable reduce task locality. We pass the tracker separately + * because, in this test suite, it won't be the same as sc.env.mapOutputTracker. 
*/ class MyRDD( sc: SparkContext, numPartitions: Int, dependencies: List[Dependency[_]], - locations: Seq[Seq[String]] = Nil) extends RDD[(Int, Int)](sc, dependencies) with Serializable { + locations: Seq[Seq[String]] = Nil, + @transient tracker: MapOutputTrackerMaster = null) + extends RDD[(Int, Int)](sc, dependencies) with Serializable { + override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = throw new RuntimeException("should not be reached") + override def getPartitions: Array[Partition] = (0 until numPartitions).map(i => new Partition { override def index: Int = i }).toArray - override def getPreferredLocations(split: Partition): Seq[String] = - if (locations.isDefinedAt(split.index)) locations(split.index) else Nil + + override def getPreferredLocations(partition: Partition): Seq[String] = { + if (locations.isDefinedAt(partition.index)) { + locations(partition.index) + } else if (tracker != null && dependencies.size == 1 && + dependencies(0).isInstanceOf[ShuffleDependency[_, _, _]]) { + // If we have only one shuffle dependency, use the same code path as ShuffledRDD for locality + val dep = dependencies(0).asInstanceOf[ShuffleDependency[_, _, _]] + tracker.getPreferredLocationsForShuffle(dep, partition.index) + } else { + Nil + } + } + override def toString: String = "DAGSchedulerSuiteRDD " + id } @@ -133,11 +162,11 @@ class DAGSchedulerSuite val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]] // stub out BlockManagerMaster.getLocations to use our cacheLocations val blockManagerMaster = new BlockManagerMaster(null, conf, true) { - override def getLocations(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = { + override def getLocations(blockIds: Array[BlockId]): IndexedSeq[Seq[BlockManagerId]] = { blockIds.map { _.asRDDId.map(id => (id.rddId -> id.splitIndex)).flatMap(key => cacheLocations.get(key)). getOrElse(Seq()) - }.toSeq + }.toIndexedSeq } override def removeExecutor(execId: String) { // don't need to propagate to the driver, which we don't have @@ -152,6 +181,14 @@ class DAGSchedulerSuite override def jobFailed(exception: Exception) = { failure = exception } } + /** A simple helper class for creating custom JobListeners */ + class SimpleListener extends JobListener { + val results = new HashMap[Int, Any] + var failure: Exception = null + override def taskSucceeded(index: Int, result: Any): Unit = results.put(index, result) + override def jobFailed(exception: Exception): Unit = { failure = exception } + } + before { sc = new SparkContext("local", "DAGSchedulerSuite") sparkListener.submittedStageInfos.clear() @@ -229,20 +266,30 @@ class DAGSchedulerSuite } } - /** Sends the rdd to the scheduler for scheduling and returns the job id. */ + /** Submits a job to the scheduler and returns the job id. */ private def submit( rdd: RDD[_], partitions: Array[Int], func: (TaskContext, Iterator[_]) => _ = jobComputeFunc, + listener: JobListener = jobListener, + properties: Properties = null): Int = { + val jobId = scheduler.nextJobId.getAndIncrement() + runEvent(JobSubmitted(jobId, rdd, func, partitions, CallSite("", ""), listener, properties)) + jobId + } + + /** Submits a map stage to the scheduler and returns the job id. 
*/ + private def submitMapStage( + shuffleDep: ShuffleDependency[_, _, _], listener: JobListener = jobListener): Int = { val jobId = scheduler.nextJobId.getAndIncrement() - runEvent(JobSubmitted(jobId, rdd, func, partitions, CallSite("", ""), listener)) + runEvent(MapStageSubmitted(jobId, shuffleDep, CallSite("", ""), listener)) jobId } /** Sends TaskSetFailed to the scheduler. */ private def failed(taskSet: TaskSet, message: String) { - runEvent(TaskSetFailed(taskSet, message)) + runEvent(TaskSetFailed(taskSet, message, None)) } /** Sends JobCancelled to the DAG scheduler. */ @@ -260,13 +307,18 @@ class DAGSchedulerSuite test("zero split job") { var numResults = 0 + var failureReason: Option[Exception] = None val fakeListener = new JobListener() { - override def taskSucceeded(partition: Int, value: Any) = numResults += 1 - override def jobFailed(exception: Exception) = throw exception + override def taskSucceeded(partition: Int, value: Any): Unit = numResults += 1 + override def jobFailed(exception: Exception): Unit = { + failureReason = Some(exception) + } } val jobId = submit(new MyRDD(sc, 0, Nil), Array(), listener = fakeListener) assert(numResults === 0) cancel(jobId) + assert(failureReason.isDefined) + assert(failureReason.get.getMessage() === "Job 0 cancelled ") } test("run trivial job") { @@ -285,6 +337,15 @@ class DAGSchedulerSuite assertDataStructuresEmpty() } + test("equals and hashCode AccumulableInfo") { + val accInfo1 = new AccumulableInfo(1, " Accumulable " + 1, Some("delta" + 1), "val" + 1, true) + val accInfo2 = new AccumulableInfo(1, " Accumulable " + 1, Some("delta" + 1), "val" + 1, false) + val accInfo3 = new AccumulableInfo(1, " Accumulable " + 1, Some("delta" + 1), "val" + 1, false) + assert(accInfo1 !== accInfo2) + assert(accInfo2 === accInfo3) + assert(accInfo2.hashCode() === accInfo3.hashCode()) + } + test("cache location preferences w/ dependency") { val baseRdd = new MyRDD(sc, 1, Nil).cache() val finalRdd = new MyRDD(sc, 1, List(new OneToOneDependency(baseRdd))) @@ -325,7 +386,8 @@ class DAGSchedulerSuite */ test("getMissingParentStages should consider all ancestor RDDs' cache statuses") { val rddA = new MyRDD(sc, 1, Nil) - val rddB = new MyRDD(sc, 1, List(new ShuffleDependency(rddA, null))) + val rddB = new MyRDD(sc, 1, List(new ShuffleDependency(rddA, new HashPartitioner(1))), + tracker = mapOutputTracker) val rddC = new MyRDD(sc, 1, List(new OneToOneDependency(rddB))).cache() val rddD = new MyRDD(sc, 1, List(new OneToOneDependency(rddC))) cacheLocations(rddC.id -> 0) = @@ -432,9 +494,9 @@ class DAGSchedulerSuite test("run trivial shuffle") { val shuffleMapRdd = new MyRDD(sc, 2, Nil) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1)) val shuffleId = shuffleDep.shuffleId - val reduceRdd = new MyRDD(sc, 1, List(shuffleDep)) + val reduceRdd = new MyRDD(sc, 1, List(shuffleDep), tracker = mapOutputTracker) submit(reduceRdd, Array(0)) complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)), @@ -448,13 +510,13 @@ class DAGSchedulerSuite test("run trivial shuffle with fetch failure") { val shuffleMapRdd = new MyRDD(sc, 2, Nil) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) val shuffleId = shuffleDep.shuffleId - val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) submit(reduceRdd, Array(0, 1)) 
complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", reduceRdd.partitions.size)), - (Success, makeMapStatus("hostB", reduceRdd.partitions.size)))) + (Success, makeMapStatus("hostA", reduceRdd.partitions.length)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)))) // the 2nd ResultTask failed complete(taskSets(1), Seq( (Success, 42), @@ -464,7 +526,7 @@ class DAGSchedulerSuite // ask the scheduler to try it again scheduler.resubmitFailedStages() // have the 2nd attempt pass - complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.size)))) + complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.length)))) // we can see both result blocks now assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet === HashSet("hostA", "hostB")) @@ -473,15 +535,295 @@ class DAGSchedulerSuite assertDataStructuresEmpty() } + + // Helper function to validate state when creating tests for task failures + private def checkStageId(stageId: Int, attempt: Int, stageAttempt: TaskSet) { + assert(stageAttempt.stageId === stageId) + assert(stageAttempt.stageAttemptId == attempt) + } + + + // Helper functions to extract commonly used code in Fetch Failure test cases + private def setupStageAbortTest(sc: SparkContext) { + sc.listenerBus.addListener(new EndListener()) + ended = false + jobResult = null + } + + // Create a new Listener to confirm that the listenerBus sees the JobEnd message + // when we abort the stage. This message will also be consumed by the EventLoggingListener + // so this will propagate up to the user. + var ended = false + var jobResult : JobResult = null + + class EndListener extends SparkListener { + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { + jobResult = jobEnd.jobResult + ended = true + } + } + + /** + * Common code to get the next stage attempt, confirm it's the one we expect, and complete it + * successfully. + * + * @param stageId - The current stageId + * @param attemptIdx - The current attempt count + * @param numShufflePartitions - The number of partitions in the next stage + */ + private def completeShuffleMapStageSuccessfully( + stageId: Int, + attemptIdx: Int, + numShufflePartitions: Int): Unit = { + val stageAttempt = taskSets.last + checkStageId(stageId, attemptIdx, stageAttempt) + complete(stageAttempt, stageAttempt.tasks.zipWithIndex.map { + case (task, idx) => + (Success, makeMapStatus("host" + ('A' + idx).toChar, numShufflePartitions)) + }.toSeq) + } + + /** + * Common code to get the next stage attempt, confirm it's the one we expect, and complete it + * with all FetchFailure. + * + * @param stageId - The current stageId + * @param attemptIdx - The current attempt count + * @param shuffleDep - The shuffle dependency of the stage with a fetch failure + */ + private def completeNextStageWithFetchFailure( + stageId: Int, + attemptIdx: Int, + shuffleDep: ShuffleDependency[_, _, _]): Unit = { + val stageAttempt = taskSets.last + checkStageId(stageId, attemptIdx, stageAttempt) + complete(stageAttempt, stageAttempt.tasks.zipWithIndex.map { case (task, idx) => + (FetchFailed(makeBlockManagerId("hostA"), shuffleDep.shuffleId, 0, idx, "ignored"), null) + }.toSeq) + } + + /** + * Common code to get the next result stage attempt, confirm it's the one we expect, and + * complete it with a success where we return 42. 
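One small detail of completeShuffleMapStageSuccessfully above is the host name it fabricates per task: 'A' + idx is Int arithmetic on the character code, and toChar maps the result back to a letter, so tasks 0, 1, 2 report map output on "hostA", "hostB", "hostC". A quick standalone check (plain Scala, illustrative only):

```scala
// Verifies the "host" + ('A' + idx).toChar expression used by the helper above.
object HostNamingSketch {
  def main(args: Array[String]): Unit = {
    val hosts = (0 until 3).map(idx => "host" + ('A' + idx).toChar)
    assert(hosts == Seq("hostA", "hostB", "hostC"))
    println(hosts.mkString(", ")) // hostA, hostB, hostC
  }
}
```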
+ * + * @param stageId - The current stageId + * @param attemptIdx - The current attempt count + */ + private def completeNextResultStageWithSuccess( + stageId: Int, + attemptIdx: Int, + partitionToResult: Int => Int = _ => 42): Unit = { + val stageAttempt = taskSets.last + checkStageId(stageId, attemptIdx, stageAttempt) + assert(scheduler.stageIdToStage(stageId).isInstanceOf[ResultStage]) + val taskResults = stageAttempt.tasks.zipWithIndex.map { case (task, idx) => + (Success, partitionToResult(idx)) + } + complete(stageAttempt, taskResults.toSeq) + } + + /** + * In this test, we simulate a job where many tasks in the same stage fail. We want to show + * that many fetch failures inside a single stage attempt do not trigger an abort + * on their own, but only when there are enough failing stage attempts. + */ + test("Single stage fetch failure should not abort the stage.") { + setupStageAbortTest(sc) + + val parts = 8 + val shuffleMapRdd = new MyRDD(sc, parts, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(parts)) + val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = mapOutputTracker) + submit(reduceRdd, (0 until parts).toArray) + + completeShuffleMapStageSuccessfully(0, 0, numShufflePartitions = parts) + + completeNextStageWithFetchFailure(1, 0, shuffleDep) + + // Resubmit and confirm that now all is well + scheduler.resubmitFailedStages() + + assert(scheduler.runningStages.nonEmpty) + assert(!ended) + + // Complete stage 0 and then stage 1 with a "42" + completeShuffleMapStageSuccessfully(0, 1, numShufflePartitions = parts) + completeNextResultStageWithSuccess(1, 1) + + // Confirm job finished succesfully + sc.listenerBus.waitUntilEmpty(1000) + assert(ended === true) + assert(results === (0 until parts).map { idx => idx -> 42 }.toMap) + assertDataStructuresEmpty() + } + + /** + * In this test we simulate a job failure where the first stage completes successfully and + * the second stage fails due to a fetch failure. Multiple successive fetch failures of a stage + * trigger an overall job abort to avoid endless retries. + */ + test("Multiple consecutive stage fetch failures should lead to job being aborted.") { + setupStageAbortTest(sc) + + val shuffleMapRdd = new MyRDD(sc, 2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) + submit(reduceRdd, Array(0, 1)) + + for (attempt <- 0 until Stage.MAX_CONSECUTIVE_FETCH_FAILURES) { + // Complete all the tasks for the current attempt of stage 0 successfully + completeShuffleMapStageSuccessfully(0, attempt, numShufflePartitions = 2) + + // Now we should have a new taskSet, for a new attempt of stage 1. 
+ // Fail all these tasks with FetchFailure + completeNextStageWithFetchFailure(1, attempt, shuffleDep) + + // this will trigger a resubmission of stage 0, since we've lost some of its + // map output, for the next iteration through the loop + scheduler.resubmitFailedStages() + + if (attempt < Stage.MAX_CONSECUTIVE_FETCH_FAILURES - 1) { + assert(scheduler.runningStages.nonEmpty) + assert(!ended) + } else { + // Stage should have been aborted and removed from running stages + assertDataStructuresEmpty() + sc.listenerBus.waitUntilEmpty(1000) + assert(ended) + jobResult match { + case JobFailed(reason) => + assert(reason.getMessage.contains("ResultStage 1 () has failed the maximum")) + case other => fail(s"expected JobFailed, not $other") + } + } + } + } + + /** + * In this test, we create a job with two consecutive shuffles, and simulate 2 failures for each + * shuffle fetch. In total In total, the job has had four failures overall but not four failures + * for a particular stage, and as such should not be aborted. + */ + test("Failures in different stages should not trigger an overall abort") { + setupStageAbortTest(sc) + + val shuffleOneRdd = new MyRDD(sc, 2, Nil).cache() + val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, new HashPartitioner(2)) + val shuffleTwoRdd = new MyRDD(sc, 2, List(shuffleDepOne), tracker = mapOutputTracker).cache() + val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, new HashPartitioner(1)) + val finalRdd = new MyRDD(sc, 1, List(shuffleDepTwo), tracker = mapOutputTracker) + submit(finalRdd, Array(0)) + + // In the first two iterations, Stage 0 succeeds and stage 1 fails. In the next two iterations, + // stage 2 fails. + for (attempt <- 0 until Stage.MAX_CONSECUTIVE_FETCH_FAILURES) { + // Complete all the tasks for the current attempt of stage 0 successfully + completeShuffleMapStageSuccessfully(0, attempt, numShufflePartitions = 2) + + if (attempt < Stage.MAX_CONSECUTIVE_FETCH_FAILURES / 2) { + // Now we should have a new taskSet, for a new attempt of stage 1. + // Fail all these tasks with FetchFailure + completeNextStageWithFetchFailure(1, attempt, shuffleDepOne) + } else { + completeShuffleMapStageSuccessfully(1, attempt, numShufflePartitions = 1) + + // Fail stage 2 + completeNextStageWithFetchFailure(2, attempt - Stage.MAX_CONSECUTIVE_FETCH_FAILURES / 2, + shuffleDepTwo) + } + + // this will trigger a resubmission of stage 0, since we've lost some of its + // map output, for the next iteration through the loop + scheduler.resubmitFailedStages() + } + + completeShuffleMapStageSuccessfully(0, 4, numShufflePartitions = 2) + completeShuffleMapStageSuccessfully(1, 4, numShufflePartitions = 1) + + // Succeed stage2 with a "42" + completeNextResultStageWithSuccess(2, Stage.MAX_CONSECUTIVE_FETCH_FAILURES/2) + + assert(results === Map(0 -> 42)) + assertDataStructuresEmpty() + } + + /** + * In this test we demonstrate that only consecutive failures trigger a stage abort. A stage may + * fail multiple times, succeed, then fail a few more times (because its run again by downstream + * dependencies). The total number of failed attempts for one stage will go over the limit, + * but that doesn't matter, since they have successes in the middle. 
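The rule these stage-abort tests exercise can be summarized with a small standalone model: a stage is aborted only once a run of consecutive failed attempts reaches the limit, and any successful attempt in between resets the run. A limit of 4 is consistent with the attempt numbers used in the tests above, but it is restated here as an assumption of the sketch rather than a statement about Stage's actual constant or implementation:

```scala
// Plain-Scala model of the consecutive-failure bookkeeping exercised by these tests.
object ConsecutiveFailureSketch {
  val MaxConsecutiveFetchFailures = 4 // assumed limit for this sketch

  /** Given attempt outcomes in order (true = succeeded), would the stage be aborted? */
  def wouldAbort(outcomes: Seq[Boolean]): Boolean = {
    var consecutiveFailures = 0
    outcomes.exists { succeeded =>
      if (succeeded) {
        consecutiveFailures = 0
        false
      } else {
        consecutiveFailures += 1
        consecutiveFailures >= MaxConsecutiveFetchFailures
      }
    }
  }

  def main(args: Array[String]): Unit = {
    // Four failed attempts in a row abort the stage ...
    assert(wouldAbort(Seq(false, false, false, false)))
    // ... but the same total number of failures does not, if a success intervenes.
    assert(!wouldAbort(Seq(false, false, false, true, false, false, false)))
    println("consecutive-failure bookkeeping behaves as expected")
  }
}
```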
+ */ + test("Non-consecutive stage failures don't trigger abort") { + setupStageAbortTest(sc) + + val shuffleOneRdd = new MyRDD(sc, 2, Nil).cache() + val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, new HashPartitioner(2)) + val shuffleTwoRdd = new MyRDD(sc, 2, List(shuffleDepOne), tracker = mapOutputTracker).cache() + val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, new HashPartitioner(1)) + val finalRdd = new MyRDD(sc, 1, List(shuffleDepTwo), tracker = mapOutputTracker) + submit(finalRdd, Array(0)) + + // First, execute stages 0 and 1, failing stage 1 up to MAX-1 times. + for (attempt <- 0 until Stage.MAX_CONSECUTIVE_FETCH_FAILURES - 1) { + // Make each task in stage 0 success + completeShuffleMapStageSuccessfully(0, attempt, numShufflePartitions = 2) + + // Now we should have a new taskSet, for a new attempt of stage 1. + // Fail these tasks with FetchFailure + completeNextStageWithFetchFailure(1, attempt, shuffleDepOne) + + scheduler.resubmitFailedStages() + + // Confirm we have not yet aborted + assert(scheduler.runningStages.nonEmpty) + assert(!ended) + } + + // Rerun stage 0 and 1 to step through the task set + completeShuffleMapStageSuccessfully(0, 3, numShufflePartitions = 2) + completeShuffleMapStageSuccessfully(1, 3, numShufflePartitions = 1) + + // Fail stage 2 so that stage 1 is resubmitted when we call scheduler.resubmitFailedStages() + completeNextStageWithFetchFailure(2, 0, shuffleDepTwo) + + scheduler.resubmitFailedStages() + + // Rerun stage 0 to step through the task set + completeShuffleMapStageSuccessfully(0, 4, numShufflePartitions = 2) + + // Now again, fail stage 1 (up to MAX_FAILURES) but confirm that this doesn't trigger an abort + // since we succeeded in between. + completeNextStageWithFetchFailure(1, 4, shuffleDepOne) + + scheduler.resubmitFailedStages() + + // Confirm we have not yet aborted + assert(scheduler.runningStages.nonEmpty) + assert(!ended) + + // Next, succeed all and confirm output + // Rerun stage 0 + 1 + completeShuffleMapStageSuccessfully(0, 5, numShufflePartitions = 2) + completeShuffleMapStageSuccessfully(1, 5, numShufflePartitions = 1) + + // Succeed stage 2 and verify results + completeNextResultStageWithSuccess(2, 1) + + assertDataStructuresEmpty() + sc.listenerBus.waitUntilEmpty(1000) + assert(ended === true) + assert(results === Map(0 -> 42)) + } + test("trivial shuffle with multiple fetch failures") { val shuffleMapRdd = new MyRDD(sc, 2, Nil) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) val shuffleId = shuffleDep.shuffleId - val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) submit(reduceRdd, Array(0, 1)) complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", reduceRdd.partitions.size)), - (Success, makeMapStatus("hostB", reduceRdd.partitions.size)))) + (Success, makeMapStatus("hostA", reduceRdd.partitions.length)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)))) // The MapOutputTracker should know about both map output locations. 
assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet === HashSet("hostA", "hostB")) @@ -516,9 +858,9 @@ class DAGSchedulerSuite */ test("late fetch failures don't cause multiple concurrent attempts for the same map stage") { val shuffleMapRdd = new MyRDD(sc, 2, Nil) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) val shuffleId = shuffleDep.shuffleId - val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) submit(reduceRdd, Array(0, 1)) val mapStageId = 0 @@ -527,6 +869,7 @@ class DAGSchedulerSuite } // The map stage should have been submitted. + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(countSubmittedMapStageAttempts() === 1) complete(taskSets(0), Seq( @@ -583,9 +926,9 @@ class DAGSchedulerSuite test("extremely late fetch failures don't cause multiple concurrent attempts for " + "the same stage") { val shuffleMapRdd = new MyRDD(sc, 2, Nil) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) val shuffleId = shuffleDep.shuffleId - val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) submit(reduceRdd, Array(0, 1)) def countSubmittedReduceStageAttempts(): Int = { @@ -646,31 +989,68 @@ class DAGSchedulerSuite test("ignore late map task completions") { val shuffleMapRdd = new MyRDD(sc, 2, Nil) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) val shuffleId = shuffleDep.shuffleId - val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) submit(reduceRdd, Array(0, 1)) + // pretend we were told hostA went away val oldEpoch = mapOutputTracker.getEpoch runEvent(ExecutorLost("exec-hostA")) val newEpoch = mapOutputTracker.getEpoch assert(newEpoch > oldEpoch) + + // now start completing some tasks in the shuffle map stage, under different hosts + // and epochs, and make sure scheduler updates its state correctly val taskSet = taskSets(0) + val shuffleStage = scheduler.stageIdToStage(taskSet.stageId).asInstanceOf[ShuffleMapStage] + assert(shuffleStage.numAvailableOutputs === 0) + // should be ignored for being too old - runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", - reduceRdd.partitions.size), null, createFakeTaskInfo(), null)) - // should work because it's a non-failed host - runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", - reduceRdd.partitions.size), null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent( + taskSet.tasks(0), + Success, + makeMapStatus("hostA", reduceRdd.partitions.size), + null, + createFakeTaskInfo(), + null)) + assert(shuffleStage.numAvailableOutputs === 0) + + // should work because it's a non-failed host (so the available map outputs will increase) + runEvent(CompletionEvent( + taskSet.tasks(0), + Success, + makeMapStatus("hostB", reduceRdd.partitions.size), + null, + createFakeTaskInfo(), + null)) + assert(shuffleStage.numAvailableOutputs === 1) + // should be ignored for being too old - runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", - reduceRdd.partitions.size), null, createFakeTaskInfo(), null)) - // should work because it's a new epoch + 
runEvent(CompletionEvent( + taskSet.tasks(0), + Success, + makeMapStatus("hostA", reduceRdd.partitions.size), + null, + createFakeTaskInfo(), + null)) + assert(shuffleStage.numAvailableOutputs === 1) + + // should work because it's a new epoch, which will increase the number of available map + // outputs, and also finish the stage taskSet.tasks(1).epoch = newEpoch - runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", - reduceRdd.partitions.size), null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent( + taskSet.tasks(1), + Success, + makeMapStatus("hostA", reduceRdd.partitions.size), + null, + createFakeTaskInfo(), + null)) + assert(shuffleStage.numAvailableOutputs === 2) assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === HashSet(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) + + // finish the next stage normally, which completes the job complete(taskSets(1), Seq((Success, 42), (Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) assertDataStructuresEmpty() @@ -678,8 +1058,8 @@ class DAGSchedulerSuite test("run shuffle with map stage failure") { val shuffleMapRdd = new MyRDD(sc, 2, Nil) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) - val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) submit(reduceRdd, Array(0, 1)) // Fail the map stage. This should cause the entire job to fail. @@ -695,6 +1075,214 @@ class DAGSchedulerSuite assertDataStructuresEmpty() } + /** + * Run two jobs, with a shared dependency. We simulate a fetch failure in the second job, which + * requires regenerating some outputs of the shared dependency. One key aspect of this test is + * that the second job actually uses a different stage for the shared dependency (a "skipped" + * stage). + */ + test("shuffle fetch failure in a reused shuffle dependency") { + // Run the first job successfully, which creates one shuffle dependency + + val shuffleMapRdd = new MyRDD(sc, 2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) + submit(reduceRdd, Array(0, 1)) + + completeShuffleMapStageSuccessfully(0, 0, 2) + completeNextResultStageWithSuccess(1, 0) + assert(results === Map(0 -> 42, 1 -> 42)) + assertDataStructuresEmpty() + + // submit another job w/ the shared dependency, and have a fetch failure + val reduce2 = new MyRDD(sc, 2, List(shuffleDep)) + submit(reduce2, Array(0, 1)) + // Note that the stage numbering here is only b/c the shared dependency produces a new, skipped + // stage. 
If instead it reused the existing stage, then this would be stage 2 + completeNextStageWithFetchFailure(3, 0, shuffleDep) + scheduler.resubmitFailedStages() + + // the scheduler now creates a new task set to regenerate the missing map output, but this time + // using a different stage, the "skipped" one + + // SPARK-9809 -- this stage is submitted without a task for each partition (because some of + // the shuffle map output is still available from stage 0); make sure we've still got internal + // accumulators setup + assert(scheduler.stageIdToStage(2).internalAccumulators.nonEmpty) + completeShuffleMapStageSuccessfully(2, 0, 2) + completeNextResultStageWithSuccess(3, 1, idx => idx + 1234) + assert(results === Map(0 -> 1234, 1 -> 1235)) + + assertDataStructuresEmpty() + } + + /** + * This test runs a three stage job, with a fetch failure in stage 1. but during the retry, we + * have completions from both the first & second attempt of stage 1. So all the map output is + * available before we finish any task set for stage 1. We want to make sure that we don't + * submit stage 2 until the map output for stage 1 is registered + */ + test("don't submit stage until its dependencies map outputs are registered (SPARK-5259)") { + val firstRDD = new MyRDD(sc, 3, Nil) + val firstShuffleDep = new ShuffleDependency(firstRDD, new HashPartitioner(2)) + val firstShuffleId = firstShuffleDep.shuffleId + val shuffleMapRdd = new MyRDD(sc, 3, List(firstShuffleDep)) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) + val reduceRdd = new MyRDD(sc, 1, List(shuffleDep)) + submit(reduceRdd, Array(0)) + + // things start out smoothly, stage 0 completes with no issues + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostB", shuffleMapRdd.partitions.length)), + (Success, makeMapStatus("hostB", shuffleMapRdd.partitions.length)), + (Success, makeMapStatus("hostA", shuffleMapRdd.partitions.length)) + )) + + // then one executor dies, and a task fails in stage 1 + runEvent(ExecutorLost("exec-hostA")) + runEvent(CompletionEvent( + taskSets(1).tasks(0), + FetchFailed(null, firstShuffleId, 2, 0, "Fetch failed"), + null, + null, + createFakeTaskInfo(), + null)) + + // so we resubmit stage 0, which completes happily + scheduler.resubmitFailedStages() + val stage0Resubmit = taskSets(2) + assert(stage0Resubmit.stageId == 0) + assert(stage0Resubmit.stageAttemptId === 1) + val task = stage0Resubmit.tasks(0) + assert(task.partitionId === 2) + runEvent(CompletionEvent( + task, + Success, + makeMapStatus("hostC", shuffleMapRdd.partitions.length), + null, + createFakeTaskInfo(), + null)) + + // now here is where things get tricky : we will now have a task set representing + // the second attempt for stage 1, but we *also* have some tasks for the first attempt for + // stage 1 still going + val stage1Resubmit = taskSets(3) + assert(stage1Resubmit.stageId == 1) + assert(stage1Resubmit.stageAttemptId === 1) + assert(stage1Resubmit.tasks.length === 3) + + // we'll have some tasks finish from the first attempt, and some finish from the second attempt, + // so that we actually have all stage outputs, though no attempt has completed all its + // tasks + runEvent(CompletionEvent( + taskSets(3).tasks(0), + Success, + makeMapStatus("hostC", reduceRdd.partitions.length), + null, + createFakeTaskInfo(), + null)) + runEvent(CompletionEvent( + taskSets(3).tasks(1), + Success, + makeMapStatus("hostC", reduceRdd.partitions.length), + null, + createFakeTaskInfo(), + null)) + // late task finish from the first 
attempt + runEvent(CompletionEvent( + taskSets(1).tasks(2), + Success, + makeMapStatus("hostB", reduceRdd.partitions.length), + null, + createFakeTaskInfo(), + null)) + + // What should happen now is that we submit stage 2. However, we might not see an error + // b/c of DAGScheduler's error handling (it tends to swallow errors and just log them). But + // we can check some conditions. + // Note that the really important thing here is not so much that we submit stage 2 *immediately* + // but that we don't end up with some error from these interleaved completions. It would also + // be OK (though sub-optimal) if stage 2 simply waited until the resubmission of stage 1 had + // all its tasks complete + + // check that we have all the map output for stage 0 (it should have been there even before + // the last round of completions from stage 1, but just to double check it hasn't been messed + // up) and also the newly available stage 1 + val stageToReduceIdxs = Seq( + 0 -> (0 until 3), + 1 -> (0 until 1) + ) + for { + (stage, reduceIdxs) <- stageToReduceIdxs + reduceIdx <- reduceIdxs + } { + // this would throw an exception if the map status hadn't been registered + val statuses = mapOutputTracker.getMapSizesByExecutorId(stage, reduceIdx) + // really we should have already thrown an exception rather than fail either of these + // asserts, but just to be extra defensive let's double check the statuses are OK + assert(statuses != null) + assert(statuses.nonEmpty) + } + + // and check that stage 2 has been submitted + assert(taskSets.size == 5) + val stage2TaskSet = taskSets(4) + assert(stage2TaskSet.stageId == 2) + assert(stage2TaskSet.stageAttemptId == 0) + } + + /** + * We lose an executor after completing some shuffle map tasks on it. Those tasks get + * resubmitted, and when they finish the job completes normally + */ + test("register map outputs correctly after ExecutorLost and task Resubmitted") { + val firstRDD = new MyRDD(sc, 3, Nil) + val firstShuffleDep = new ShuffleDependency(firstRDD, new HashPartitioner(2)) + val reduceRdd = new MyRDD(sc, 5, List(firstShuffleDep)) + submit(reduceRdd, Array(0)) + + // complete some of the tasks from the first stage, on one host + runEvent(CompletionEvent( + taskSets(0).tasks(0), Success, + makeMapStatus("hostA", reduceRdd.partitions.length), null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent( + taskSets(0).tasks(1), Success, + makeMapStatus("hostA", reduceRdd.partitions.length), null, createFakeTaskInfo(), null)) + + // now that host goes down + runEvent(ExecutorLost("exec-hostA")) + + // so we resubmit those tasks + runEvent(CompletionEvent( + taskSets(0).tasks(0), Resubmitted, null, null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent( + taskSets(0).tasks(1), Resubmitted, null, null, createFakeTaskInfo(), null)) + + // now complete everything on a different host + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)) + )) + + // now we should submit stage 1, and the map output from stage 0 should be registered + + // check that we have all the map output for stage 0 + (0 until reduceRdd.partitions.length).foreach { reduceIdx => + val statuses = mapOutputTracker.getMapSizesByExecutorId(0, reduceIdx) + // really we should have already thrown an exception rather than fail either of these + // asserts, but just to be extra defensive let's double check the statuses are 
OK + assert(statuses != null) + assert(statuses.nonEmpty) + } + + // and check that stage 1 has been submitted + assert(taskSets.size == 2) + val stage1TaskSet = taskSets(1) + assert(stage1TaskSet.stageId == 1) + assert(stage1TaskSet.stageAttemptId == 0) + } + /** * Makes sure that failures of stage used by multiple jobs are correctly handled. * @@ -714,12 +1302,12 @@ class DAGSchedulerSuite */ test("failure of stage used by two jobs") { val shuffleMapRdd1 = new MyRDD(sc, 2, Nil) - val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, null) + val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(2)) val shuffleMapRdd2 = new MyRDD(sc, 2, Nil) - val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, null) + val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(2)) - val reduceRdd1 = new MyRDD(sc, 2, List(shuffleDep1)) - val reduceRdd2 = new MyRDD(sc, 2, List(shuffleDep1, shuffleDep2)) + val reduceRdd1 = new MyRDD(sc, 2, List(shuffleDep1), tracker = mapOutputTracker) + val reduceRdd2 = new MyRDD(sc, 2, List(shuffleDep1, shuffleDep2), tracker = mapOutputTracker) // We need to make our own listeners for this test, since by default submit uses the same // listener for all jobs, and here we want to capture the failure for each job separately. @@ -749,11 +1337,111 @@ class DAGSchedulerSuite assertDataStructuresEmpty() } + def checkJobPropertiesAndPriority(taskSet: TaskSet, expected: String, priority: Int): Unit = { + assert(taskSet.properties != null) + assert(taskSet.properties.getProperty("testProperty") === expected) + assert(taskSet.priority === priority) + } + + def launchJobsThatShareStageAndCancelFirst(): ShuffleDependency[Int, Int, Nothing] = { + val baseRdd = new MyRDD(sc, 1, Nil) + val shuffleDep1 = new ShuffleDependency(baseRdd, new HashPartitioner(1)) + val intermediateRdd = new MyRDD(sc, 1, List(shuffleDep1)) + val shuffleDep2 = new ShuffleDependency(intermediateRdd, new HashPartitioner(1)) + val finalRdd1 = new MyRDD(sc, 1, List(shuffleDep2)) + val finalRdd2 = new MyRDD(sc, 1, List(shuffleDep2)) + val job1Properties = new Properties() + val job2Properties = new Properties() + job1Properties.setProperty("testProperty", "job1") + job2Properties.setProperty("testProperty", "job2") + + // Run jobs 1 & 2, both referencing the same stage, then cancel job1. + // Note that we have to submit job2 before we cancel job1 to have them actually share + // *Stages*, and not just shuffle dependencies, due to skipped stages (at least until + // we address SPARK-10193.) + val jobId1 = submit(finalRdd1, Array(0), properties = job1Properties) + val jobId2 = submit(finalRdd2, Array(0), properties = job2Properties) + assert(scheduler.activeJobs.nonEmpty) + val testProperty1 = scheduler.jobIdToActiveJob(jobId1).properties.getProperty("testProperty") + + // remove job1 as an ActiveJob + cancel(jobId1) + + // job2 should still be running + assert(scheduler.activeJobs.nonEmpty) + val testProperty2 = scheduler.jobIdToActiveJob(jobId2).properties.getProperty("testProperty") + assert(testProperty1 != testProperty2) + // NB: This next assert isn't necessarily the "desired" behavior; it's just to document + // the current behavior. We've already submitted the TaskSet for stage 0 based on job1, but + // even though we have cancelled that job and are now running it because of job2, we haven't + // updated the TaskSet's properties. Changing the properties to "job2" is likely the more + // correct behavior. 
+ val job1Id = 0 // TaskSet priority for Stages run with "job1" as the ActiveJob + checkJobPropertiesAndPriority(taskSets(0), "job1", job1Id) + complete(taskSets(0), Seq((Success, makeMapStatus("hostA", 1)))) + + shuffleDep1 + } + + /** + * Makes sure that tasks for a stage used by multiple jobs are submitted with the properties of a + * later, active job if they were previously run under a job that is no longer active + */ + test("stage used by two jobs, the first no longer active (SPARK-6880)") { + launchJobsThatShareStageAndCancelFirst() + + // The next check is the key for SPARK-6880. For the stage which was shared by both job1 and + // job2 but never had any tasks submitted for job1, the properties of job2 are now used to run + // the stage. + checkJobPropertiesAndPriority(taskSets(1), "job2", 1) + + complete(taskSets(1), Seq((Success, makeMapStatus("hostA", 1)))) + assert(taskSets(2).properties != null) + complete(taskSets(2), Seq((Success, 42))) + assert(results === Map(0 -> 42)) + assert(scheduler.activeJobs.isEmpty) + + assertDataStructuresEmpty() + } + + /** + * Makes sure that tasks for a stage used by multiple jobs are submitted with the properties of a + * later, active job if they were previously run under a job that is no longer active, even when + * there are fetch failures + */ + test("stage used by two jobs, some fetch failures, and the first job no longer active " + + "(SPARK-6880)") { + val shuffleDep1 = launchJobsThatShareStageAndCancelFirst() + val job2Id = 1 // TaskSet priority for Stages run with "job2" as the ActiveJob + + // lets say there is a fetch failure in this task set, which makes us go back and + // run stage 0, attempt 1 + complete(taskSets(1), Seq( + (FetchFailed(makeBlockManagerId("hostA"), shuffleDep1.shuffleId, 0, 0, "ignored"), null))) + scheduler.resubmitFailedStages() + + // stage 0, attempt 1 should have the properties of job2 + assert(taskSets(2).stageId === 0) + assert(taskSets(2).stageAttemptId === 1) + checkJobPropertiesAndPriority(taskSets(2), "job2", job2Id) + + // run the rest of the stages normally, checking that they have the correct properties + complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1)))) + checkJobPropertiesAndPriority(taskSets(3), "job2", job2Id) + complete(taskSets(3), Seq((Success, makeMapStatus("hostA", 1)))) + checkJobPropertiesAndPriority(taskSets(4), "job2", job2Id) + complete(taskSets(4), Seq((Success, 42))) + assert(results === Map(0 -> 42)) + assert(scheduler.activeJobs.isEmpty) + + assertDataStructuresEmpty() + } + test("run trivial shuffle with out-of-band failure and retry") { val shuffleMapRdd = new MyRDD(sc, 2, Nil) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) val shuffleId = shuffleDep.shuffleId - val reduceRdd = new MyRDD(sc, 1, List(shuffleDep)) + val reduceRdd = new MyRDD(sc, 1, List(shuffleDep), tracker = mapOutputTracker) submit(reduceRdd, Array(0)) // blockManagerMaster.removeExecutor("exec-hostA") // pretend we were told hostA went away @@ -774,10 +1462,10 @@ class DAGSchedulerSuite test("recursive shuffle failures") { val shuffleOneRdd = new MyRDD(sc, 2, Nil) - val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null) - val shuffleTwoRdd = new MyRDD(sc, 2, List(shuffleDepOne)) - val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null) - val finalRdd = new MyRDD(sc, 1, List(shuffleDepTwo)) + val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, new HashPartitioner(2)) + val 
shuffleTwoRdd = new MyRDD(sc, 2, List(shuffleDepOne), tracker = mapOutputTracker) + val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, new HashPartitioner(1)) + val finalRdd = new MyRDD(sc, 1, List(shuffleDepTwo), tracker = mapOutputTracker) submit(finalRdd, Array(0)) // have the first stage complete normally complete(taskSets(0), Seq( @@ -803,14 +1491,14 @@ class DAGSchedulerSuite test("cached post-shuffle") { val shuffleOneRdd = new MyRDD(sc, 2, Nil).cache() - val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null) - val shuffleTwoRdd = new MyRDD(sc, 2, List(shuffleDepOne)).cache() - val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null) - val finalRdd = new MyRDD(sc, 1, List(shuffleDepTwo)) + val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, new HashPartitioner(2)) + val shuffleTwoRdd = new MyRDD(sc, 2, List(shuffleDepOne), tracker = mapOutputTracker).cache() + val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, new HashPartitioner(1)) + val finalRdd = new MyRDD(sc, 1, List(shuffleDepTwo), tracker = mapOutputTracker) submit(finalRdd, Array(0)) cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD")) cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC")) - // complete stage 2 + // complete stage 0 complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 2)), (Success, makeMapStatus("hostB", 2)))) @@ -818,7 +1506,7 @@ class DAGSchedulerSuite complete(taskSets(1), Seq( (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostB", 1)))) - // pretend stage 0 failed because hostA went down + // pretend stage 2 failed because hostA went down complete(taskSets(2), Seq( (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0, "ignored"), null))) // TODO assert this: @@ -849,18 +1537,27 @@ class DAGSchedulerSuite assert(sc.parallelize(1 to 10, 2).count() === 10) } + /** + * The job will be failed on first task throwing a DAGSchedulerSuiteDummyException. + * Any subsequent task WILL throw a legitimate java.lang.UnsupportedOperationException. + * If multiple tasks, there exists a race condition between the SparkDriverExecutionExceptions + * and their differing causes as to which will represent result for job... 
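The scenario described above is that a user-supplied result handler throwing on the driver should surface as a driver-side execution exception carrying the original error as its cause, rather than crashing the scheduler or the SparkContext. A minimal sketch of that wrap-and-rethrow pattern, using hypothetical names (runJobSketch, DriverExecutionException) rather than Spark's own classes:

```scala
// Illustrative sketch of wrapping a failing driver-side result handler.
object ResultHandlerSketch {
  class DriverExecutionException(cause: Throwable)
    extends Exception("Execution error", cause)

  /** Feed each result to the user-supplied handler, wrapping any handler failure. */
  def runJobSketch[T](results: Seq[T])(resultHandler: (Int, T) => Unit): Unit = {
    results.zipWithIndex.foreach { case (result, index) =>
      try {
        resultHandler(index, result)
      } catch {
        case e: Exception => throw new DriverExecutionException(e)
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val thrown =
      try {
        runJobSketch(Seq(1, 2)) { (_, _) => throw new IllegalStateException("bad handler") }
        None
      } catch {
        case e: DriverExecutionException => Some(e)
      }
    // The handler's own exception is preserved as the cause.
    assert(thrown.exists(_.getCause.isInstanceOf[IllegalStateException]))
    println("handler failure wrapped as DriverExecutionException")
  }
}
```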
+ */ test("misbehaved resultHandler should not crash DAGScheduler and SparkContext") { val e = intercept[SparkDriverExecutionException] { + // Number of parallelized partitions implies number of tasks of job val rdd = sc.parallelize(1 to 10, 2) sc.runJob[Int, Int]( rdd, (context: TaskContext, iter: Iterator[Int]) => iter.size, - Seq(0, 1), + // For a robust test assertion, limit number of job tasks to 1; that is, + // if multiple RDD partitions, use id of any one partition, say, first partition id=0 + Seq(0), (part: Int, result: Int) => throw new DAGSchedulerSuiteDummyException) } assert(e.getCause.isInstanceOf[DAGSchedulerSuiteDummyException]) - // Make sure we can still run commands + // Make sure we can still run commands on our SparkContext assert(sc.parallelize(1 to 10, 2).count() === 10) } @@ -912,9 +1609,9 @@ class DAGSchedulerSuite test("reduce tasks should be placed locally with map output") { // Create an shuffleMapRdd with 1 partition val shuffleMapRdd = new MyRDD(sc, 1, Nil) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) val shuffleId = shuffleDep.shuffleId - val reduceRdd = new MyRDD(sc, 1, List(shuffleDep)) + val reduceRdd = new MyRDD(sc, 1, List(shuffleDep), tracker = mapOutputTracker) submit(reduceRdd, Array(0)) complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)))) @@ -926,16 +1623,16 @@ class DAGSchedulerSuite assertLocations(reduceTaskSet, Seq(Seq("hostA"))) complete(reduceTaskSet, Seq((Success, 42))) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("reduce task locality preferences should only include machines with largest map outputs") { val numMapTasks = 4 // Create an shuffleMapRdd with more partitions val shuffleMapRdd = new MyRDD(sc, numMapTasks, Nil) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1)) val shuffleId = shuffleDep.shuffleId - val reduceRdd = new MyRDD(sc, 1, List(shuffleDep)) + val reduceRdd = new MyRDD(sc, 1, List(shuffleDep), tracker = mapOutputTracker) submit(reduceRdd, Array(0)) val statuses = (1 to numMapTasks).map { i => @@ -950,7 +1647,29 @@ class DAGSchedulerSuite assertLocations(reduceTaskSet, Seq(hosts)) complete(reduceTaskSet, Seq((Success, 42))) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() + } + + test("stages with both narrow and shuffle dependencies use narrow ones for locality") { + // Create an RDD that has both a shuffle dependency and a narrow dependency (e.g. 
for a join) + val rdd1 = new MyRDD(sc, 1, Nil) + val rdd2 = new MyRDD(sc, 1, Nil, locations = Seq(Seq("hostB"))) + val shuffleDep = new ShuffleDependency(rdd1, new HashPartitioner(1)) + val narrowDep = new OneToOneDependency(rdd2) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 1, List(shuffleDep, narrowDep), tracker = mapOutputTracker) + submit(reduceRdd, Array(0)) + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", 1)))) + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostA"))) + + // Reducer should run where RDD 2 has preferences, even though though it also has a shuffle dep + val reduceTaskSet = taskSets(1) + assertLocations(reduceTaskSet, Seq(Seq("hostB"))) + complete(reduceTaskSet, Seq((Success, 42))) + assert(results === Map(0 -> 42)) + assertDataStructuresEmpty() } test("Spark exceptions should include call site in stack trace") { @@ -968,6 +1687,244 @@ class DAGSchedulerSuite assert(stackTraceString.contains("org.scalatest.FunSuite")) } + test("catch errors in event loop") { + // this is a test of our testing framework -- make sure errors in event loop don't get ignored + + // just run some bad event that will throw an exception -- we'll give a null TaskEndReason + val rdd1 = new MyRDD(sc, 1, Nil) + submit(rdd1, Array(0)) + intercept[Exception] { + complete(taskSets(0), Seq( + (null, makeMapStatus("hostA", 1)))) + } + } + + test("simple map stage submission") { + val shuffleMapRdd = new MyRDD(sc, 2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1)) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 1, List(shuffleDep), tracker = mapOutputTracker) + + // Submit a map stage by itself + submitMapStage(shuffleDep) + assert(results.size === 0) // No results yet + completeShuffleMapStageSuccessfully(0, 0, 1) + assert(results.size === 1) + results.clear() + assertDataStructuresEmpty() + + // Submit a reduce job that depends on this map stage; it should directly do the reduce + submit(reduceRdd, Array(0)) + completeNextResultStageWithSuccess(2, 0) + assert(results === Map(0 -> 42)) + results.clear() + assertDataStructuresEmpty() + + // Check that if we submit the map stage again, no tasks run + submitMapStage(shuffleDep) + assert(results.size === 1) + assertDataStructuresEmpty() + } + + test("map stage submission with reduce stage also depending on the data") { + val shuffleMapRdd = new MyRDD(sc, 2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1)) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 1, List(shuffleDep), tracker = mapOutputTracker) + + // Submit the map stage by itself + submitMapStage(shuffleDep) + + // Submit a reduce job that depends on this map stage + submit(reduceRdd, Array(0)) + + // Complete tasks for the map stage + completeShuffleMapStageSuccessfully(0, 0, 1) + assert(results.size === 1) + results.clear() + + // Complete tasks for the reduce stage + completeNextResultStageWithSuccess(1, 0) + assert(results === Map(0 -> 42)) + results.clear() + assertDataStructuresEmpty() + + // Check that if we submit the map stage again, no tasks run + submitMapStage(shuffleDep) + assert(results.size === 1) + assertDataStructuresEmpty() + } + + test("map stage submission with fetch failure") { + val shuffleMapRdd = new MyRDD(sc, 2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = 
new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) + + // Submit a map stage by itself + submitMapStage(shuffleDep) + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", reduceRdd.partitions.length)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)))) + assert(results.size === 1) + results.clear() + assertDataStructuresEmpty() + + // Submit a reduce job that depends on this map stage, but where one reduce will fail a fetch + submit(reduceRdd, Array(0, 1)) + complete(taskSets(1), Seq( + (Success, 42), + (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), null))) + // Ask the scheduler to try it again; TaskSet 2 will rerun the map task that we couldn't fetch + // from, then TaskSet 3 will run the reduce stage + scheduler.resubmitFailedStages() + complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.length)))) + complete(taskSets(3), Seq((Success, 43))) + assert(results === Map(0 -> 42, 1 -> 43)) + results.clear() + assertDataStructuresEmpty() + + // Run another reduce job without a failure; this should just work + submit(reduceRdd, Array(0, 1)) + complete(taskSets(4), Seq( + (Success, 44), + (Success, 45))) + assert(results === Map(0 -> 44, 1 -> 45)) + results.clear() + assertDataStructuresEmpty() + + // Resubmit the map stage; this should also just work + submitMapStage(shuffleDep) + assert(results.size === 1) + results.clear() + assertDataStructuresEmpty() + } + + /** + * In this test, we have three RDDs with shuffle dependencies, and we submit map stage jobs + * that are waiting on each one, as well as a reduce job on the last one. We test that all of + * these jobs complete even if there are some fetch failures in both shuffles. + */ + test("map stage submission with multiple shared stages and failures") { + val rdd1 = new MyRDD(sc, 2, Nil) + val dep1 = new ShuffleDependency(rdd1, new HashPartitioner(2)) + val rdd2 = new MyRDD(sc, 2, List(dep1), tracker = mapOutputTracker) + val dep2 = new ShuffleDependency(rdd2, new HashPartitioner(2)) + val rdd3 = new MyRDD(sc, 2, List(dep2), tracker = mapOutputTracker) + + val listener1 = new SimpleListener + val listener2 = new SimpleListener + val listener3 = new SimpleListener + + submitMapStage(dep1, listener1) + submitMapStage(dep2, listener2) + submit(rdd3, Array(0, 1), listener = listener3) + + // Complete the first stage + assert(taskSets(0).stageId === 0) + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", rdd1.partitions.length)), + (Success, makeMapStatus("hostB", rdd1.partitions.length)))) + assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) + assert(listener1.results.size === 1) + + // When attempting the second stage, show a fetch failure + assert(taskSets(1).stageId === 1) + complete(taskSets(1), Seq( + (Success, makeMapStatus("hostA", rdd2.partitions.length)), + (FetchFailed(makeBlockManagerId("hostA"), dep1.shuffleId, 0, 0, "ignored"), null))) + scheduler.resubmitFailedStages() + assert(listener2.results.size === 0) // Second stage listener should not have a result yet + + // Stage 0 should now be running as task set 2; make its task succeed + assert(taskSets(2).stageId === 0) + complete(taskSets(2), Seq( + (Success, makeMapStatus("hostC", rdd2.partitions.length)))) + assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) + 
assert(listener2.results.size === 0) // Second stage listener should still not have a result + + // Stage 1 should now be running as task set 3; make its first task succeed + assert(taskSets(3).stageId === 1) + complete(taskSets(3), Seq( + (Success, makeMapStatus("hostB", rdd2.partitions.length)), + (Success, makeMapStatus("hostD", rdd2.partitions.length)))) + assert(mapOutputTracker.getMapSizesByExecutorId(dep2.shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostB"), makeBlockManagerId("hostD"))) + assert(listener2.results.size === 1) + + // Finally, the reduce job should be running as task set 4; make it see a fetch failure, + // then make it run again and succeed + assert(taskSets(4).stageId === 2) + complete(taskSets(4), Seq( + (Success, 52), + (FetchFailed(makeBlockManagerId("hostD"), dep2.shuffleId, 0, 0, "ignored"), null))) + scheduler.resubmitFailedStages() + + // TaskSet 5 will rerun stage 1's lost task, then TaskSet 6 will rerun stage 2 + assert(taskSets(5).stageId === 1) + complete(taskSets(5), Seq( + (Success, makeMapStatus("hostE", rdd2.partitions.length)))) + complete(taskSets(6), Seq( + (Success, 53))) + assert(listener3.results === Map(0 -> 52, 1 -> 53)) + assertDataStructuresEmpty() + } + + /** + * In this test, we run a map stage where one of the executors fails but we still receive a + * "zombie" complete message from that executor. We want to make sure the stage is not reported + * as done until all tasks have completed. + */ + test("map stage submission with executor failure late map task completions") { + val shuffleMapRdd = new MyRDD(sc, 3, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) + + submitMapStage(shuffleDep) + + val oldTaskSet = taskSets(0) + runEvent(CompletionEvent(oldTaskSet.tasks(0), Success, makeMapStatus("hostA", 2), + null, createFakeTaskInfo(), null)) + assert(results.size === 0) // Map stage job should not be complete yet + + // Pretend host A was lost + val oldEpoch = mapOutputTracker.getEpoch + runEvent(ExecutorLost("exec-hostA")) + val newEpoch = mapOutputTracker.getEpoch + assert(newEpoch > oldEpoch) + + // Suppose we also get a completed event from task 1 on the same host; this should be ignored + runEvent(CompletionEvent(oldTaskSet.tasks(1), Success, makeMapStatus("hostA", 2), + null, createFakeTaskInfo(), null)) + assert(results.size === 0) // Map stage job should not be complete yet + + // A completion from another task should work because it's a non-failed host + runEvent(CompletionEvent(oldTaskSet.tasks(2), Success, makeMapStatus("hostB", 2), + null, createFakeTaskInfo(), null)) + assert(results.size === 0) // Map stage job should not be complete yet + + // Now complete tasks in the second task set + val newTaskSet = taskSets(1) + assert(newTaskSet.tasks.size === 2) // Both tasks 0 and 1 were on on hostA + runEvent(CompletionEvent(newTaskSet.tasks(0), Success, makeMapStatus("hostB", 2), + null, createFakeTaskInfo(), null)) + assert(results.size === 0) // Map stage job should not be complete yet + runEvent(CompletionEvent(newTaskSet.tasks(1), Success, makeMapStatus("hostB", 2), + null, createFakeTaskInfo(), null)) + assert(results.size === 1) // Map stage job should now finally be complete + assertDataStructuresEmpty() + + // Also test that a reduce stage using this shuffled data can immediately run + val reduceRDD = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) + results.clear() + submit(reduceRDD, Array(0, 1)) + complete(taskSets(2), Seq((Success, 42), (Success, 43))) 
+ assert(results === Map(0 -> 42, 1 -> 43)) + results.clear() + assertDataStructuresEmpty() + } + /** * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations. * Note that this checks only the host and not the executor ID. diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala index b3ca150195a5..f7e16af9d3a9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala @@ -19,9 +19,11 @@ package org.apache.spark.scheduler import org.apache.spark.TaskContext -class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0, 0) { +class FakeTask( + stageId: Int, + prefLocs: Seq[TaskLocation] = Nil) + extends Task[Int](stageId, 0, 0, Seq.empty) { override def runTask(context: TaskContext): Int = 0 - override def preferredLocations: Seq[TaskLocation] = prefLocs } diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala index b8e466fab450..15c8de61b824 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.storage.BlockManagerId import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer +import org.roaringbitmap.RoaringBitmap import scala.util.Random @@ -97,4 +98,34 @@ class MapStatusSuite extends SparkFunSuite { val buf = ser.newInstance().serialize(status) ser.newInstance().deserialize[MapStatus](buf) } + + test("RoaringBitmap: runOptimize succeeded") { + val r = new RoaringBitmap + (1 to 200000).foreach(i => + if (i % 200 != 0) { + r.add(i) + } + ) + val size1 = r.getSizeInBytes + val success = r.runOptimize() + r.trim() + val size2 = r.getSizeInBytes + assert(size1 > size2) + assert(success) + } + + test("RoaringBitmap: runOptimize failed") { + val r = new RoaringBitmap + (1 to 200000).foreach(i => + if (i % 200 == 0) { + r.add(i) + } + ) + val size1 = r.getSizeInBytes + val success = r.runOptimize() + r.trim() + val size2 = r.getSizeInBytes + assert(size1 === size2) + assert(!success) + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala index 383855caefa2..f33324792495 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala @@ -25,7 +25,7 @@ import org.apache.spark.TaskContext * A Task implementation that fails to serialize. 
*/ private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int) - extends Task[Array[Byte]](stageId, 0, 0) { + extends Task[Array[Byte]](stageId, 0, 0, Seq.empty) { override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte] override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]() diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala new file mode 100644 index 000000000000..1ae5b030f083 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.apache.hadoop.mapred.{FileOutputCommitter, TaskAttemptContext} +import org.scalatest.concurrent.Timeouts +import org.scalatest.time.{Span, Seconds} + +import org.apache.spark.{SparkConf, SparkContext, LocalSparkContext, SparkFunSuite, TaskContext} +import org.apache.spark.util.Utils + +/** + * Integration tests for the OutputCommitCoordinator. + * + * See also: [[OutputCommitCoordinatorSuite]] for unit tests that use mocks. 
+ */ +class OutputCommitCoordinatorIntegrationSuite + extends SparkFunSuite + with LocalSparkContext + with Timeouts { + + override def beforeAll(): Unit = { + super.beforeAll() + val conf = new SparkConf() + .set("master", "local[2,4]") + .set("spark.speculation", "true") + .set("spark.hadoop.mapred.output.committer.class", + classOf[ThrowExceptionOnFirstAttemptOutputCommitter].getCanonicalName) + sc = new SparkContext("local[2, 4]", "test", conf) + } + + test("exception thrown in OutputCommitter.commitTask()") { + // Regression test for SPARK-10381 + failAfter(Span(60, Seconds)) { + val tempDir = Utils.createTempDir() + try { + sc.parallelize(1 to 4, 2).map(_.toString).saveAsTextFile(tempDir.getAbsolutePath + "/out") + } finally { + Utils.deleteRecursively(tempDir) + } + } + } +} + +private class ThrowExceptionOnFirstAttemptOutputCommitter extends FileOutputCommitter { + override def commitTask(context: TaskAttemptContext): Unit = { + val ctx = TaskContext.get() + if (ctx.attemptNumber < 1) { + throw new java.io.FileNotFoundException("Intentional exception") + } + super.commitTask(context) + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index e5ecd4b7c261..7345508bfe99 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -63,6 +63,9 @@ import scala.language.postfixOps * was not in SparkHadoopWriter, the tests would still pass because only one of the * increments would be captured even though the commit in both tasks was executed * erroneously. + * + * See also: [[OutputCommitCoordinatorIntegrationSuite]] for integration tests that do + * not use mocks. */ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { @@ -84,7 +87,8 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { outputCommitCoordinator = spy(new OutputCommitCoordinator(conf, isDriver = true)) // Use Mockito.spy() to maintain the default infrastructure everywhere else. // This mocking allows us to control the coordinator responses in test cases. 
- SparkEnv.createDriverEnv(conf, isLocal, listenerBus, Some(outputCommitCoordinator)) + SparkEnv.createDriverEnv(conf, isLocal, listenerBus, + SparkContext.numDriverCores(master), Some(outputCommitCoordinator)) } } // Use Mockito.spy() to maintain the default infrastructure everywhere else @@ -164,27 +168,28 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { test("Only authorized committer failures can clear the authorized committer lock (SPARK-6614)") { val stage: Int = 1 - val partition: Long = 2 - val authorizedCommitter: Long = 3 - val nonAuthorizedCommitter: Long = 100 - outputCommitCoordinator.stageStart(stage) - assert(outputCommitCoordinator.canCommit(stage, partition, attempt = authorizedCommitter)) - assert(!outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter)) + val partition: Int = 2 + val authorizedCommitter: Int = 3 + val nonAuthorizedCommitter: Int = 100 + outputCommitCoordinator.stageStart(stage, maxPartitionId = 2) + + assert(outputCommitCoordinator.canCommit(stage, partition, authorizedCommitter)) + assert(!outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter)) // The non-authorized committer fails outputCommitCoordinator.taskCompleted( - stage, partition, attempt = nonAuthorizedCommitter, reason = TaskKilled) + stage, partition, attemptNumber = nonAuthorizedCommitter, reason = TaskKilled) // New tasks should still not be able to commit because the authorized committer has not failed assert( - !outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter + 1)) + !outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 1)) // The authorized committer now fails, clearing the lock outputCommitCoordinator.taskCompleted( - stage, partition, attempt = authorizedCommitter, reason = TaskKilled) + stage, partition, attemptNumber = authorizedCommitter, reason = TaskKilled) // A new task should now be allowed to become the authorized committer assert( - outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter + 2)) + outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 2)) // There can only be one authorized committer assert( - !outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter + 3)) + !outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 3)) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 730535ece787..f20d5be7c0ee 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -20,10 +20,11 @@ package org.apache.spark.scheduler import java.util.concurrent.Semaphore import scala.collection.mutable -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.scalatest.Matchers +import org.apache.spark.SparkException import org.apache.spark.executor.TaskMetrics import org.apache.spark.util.ResetSystemProperties import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} @@ -36,6 +37,21 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val jobCompletionTime = 1421191296660L + test("don't call sc.stop in listener") { + sc = new SparkContext("local", "SparkListenerSuite") + val listener = new SparkContextStoppingListener(sc) + val bus = 
new LiveListenerBus + bus.addListener(listener) + + // Starting listener bus should flush all buffered events + bus.start(sc) + bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) + bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + + bus.stop() + assert(listener.sparkExSeen) + } + test("basic creation and shutdown of LiveListenerBus") { val counter = new BasicJobCounter val bus = new LiveListenerBus @@ -212,14 +228,15 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match i } - val d = sc.parallelize(0 to 1e4.toInt, 64).map(w) + val numSlices = 16 + val d = sc.parallelize(0 to 1e3.toInt, numSlices).map(w) d.count() sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) listener.stageInfos.size should be (1) val d2 = d.map { i => w(i) -> i * 2 }.setName("shuffle input 1") val d3 = d.map { i => w(i) -> (0 to (i % 5)) }.setName("shuffle input 2") - val d4 = d2.cogroup(d3, 64).map { case (k, (v1, v2)) => + val d4 = d2.cogroup(d3, numSlices).map { case (k, (v1, v2)) => w(k) -> (v1.size, v2.size) } d4.setName("A Cogroup") @@ -258,8 +275,8 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match if (stageInfo.rddInfos.exists(_.name == d4.name)) { taskMetrics.shuffleReadMetrics should be ('defined) val sm = taskMetrics.shuffleReadMetrics.get - sm.totalBlocksFetched should be (128) - sm.localBlocksFetched should be (128) + sm.totalBlocksFetched should be (2*numSlices) + sm.localBlocksFetched should be (2*numSlices) sm.remoteBlocksFetched should be (0) sm.remoteBytesRead should be (0L) } @@ -268,14 +285,15 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match } test("onTaskGettingResult() called when result fetched remotely") { - sc = new SparkContext("local", "SparkListenerSuite") + val conf = new SparkConf().set("spark.akka.frameSize", "1") + sc = new SparkContext("local", "SparkListenerSuite", conf) val listener = new SaveTaskEvents sc.addSparkListener(listener) // Make a task whose result is larger than the akka frame size - System.setProperty("spark.akka.frameSize", "1") val akkaFrameSize = sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt + assert(akkaFrameSize === 1024 * 1024) val result = sc.parallelize(Seq(1), 1) .map { x => 1.to(akkaFrameSize).toArray } .reduce { case (x, y) => x } @@ -365,10 +383,9 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match .set("spark.extraListeners", classOf[ListenerThatAcceptsSparkConf].getName + "," + classOf[BasicJobCounter].getName) sc = new SparkContext(conf) - sc.listenerBus.listeners.collect { case x: BasicJobCounter => x}.size should be (1) - sc.listenerBus.listeners.collect { - case x: ListenerThatAcceptsSparkConf => x - }.size should be (1) + sc.listenerBus.listeners.asScala.count(_.isInstanceOf[BasicJobCounter]) should be (1) + sc.listenerBus.listeners.asScala + .count(_.isInstanceOf[ListenerThatAcceptsSparkConf]) should be (1) } /** @@ -442,6 +459,21 @@ private class BasicJobCounter extends SparkListener { override def onJobEnd(job: SparkListenerJobEnd): Unit = count += 1 } +/** + * A simple listener that tries to stop SparkContext. 
+ */ +private class SparkContextStoppingListener(val sc: SparkContext) extends SparkListener { + @volatile var sparkExSeen = false + override def onJobEnd(job: SparkListenerJobEnd): Unit = { + try { + sc.stop() + } catch { + case se: SparkException => + sparkExSeen = true + } + } +} + private class ListenerThatAcceptsSparkConf(conf: SparkConf) extends SparkListener { var count = 0 override def onJobEnd(job: SparkListenerJobEnd): Unit = count += 1 diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala index d1e23ed527ff..9fa885938291 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala @@ -43,7 +43,7 @@ class SparkListenerWithClusterSuite extends SparkFunSuite with LocalSparkContext // This test will check if the number of executors received by "SparkListener" is same as the // number of all executors, so we need to wait until all executors are up - sc.jobProgressListener.waitUntilExecutorsUp(2, 10000) + sc.jobProgressListener.waitUntilExecutorsUp(2, 60000) val rdd1 = sc.parallelize(1 to 100, 4) val rdd2 = rdd1.map(_.toString) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 9201d1e1f328..d83d0aee4225 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -23,6 +23,7 @@ import org.mockito.Matchers.any import org.scalatest.BeforeAndAfter import org.apache.spark._ +import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException} import org.apache.spark.metrics.source.JvmSource @@ -57,8 +58,9 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark } val closureSerializer = SparkEnv.get.closureSerializer.newInstance() val func = (c: TaskContext, i: Iterator[String]) => i.next() - val task = new ResultTask[String, String](0, 0, - sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0) + val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func)))) + val task = new ResultTask[String, String]( + 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, Seq.empty) intercept[RuntimeException] { task.run(0, 0, null) } @@ -66,7 +68,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark } test("all TaskCompletionListeners should be called even if some fail") { - val context = new TaskContextImpl(0, 0, 0, 0, null, null) + val context = TaskContext.empty() val listener = mock(classOf[TaskCompletionListener]) context.addTaskCompletionListener(_ => throw new Exception("blah")) context.addTaskCompletionListener(listener) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index 815caa79ff52..bc72c3685e8c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.scheduler +import java.io.File +import java.net.URL import java.nio.ByteBuffer import 
scala.concurrent.duration._ @@ -26,8 +28,10 @@ import scala.util.control.NonFatal import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite} +import org.apache.spark._ import org.apache.spark.storage.TaskResultBlockId +import org.apache.spark.TestUtils.JavaSourceFromString +import org.apache.spark.util.{MutableURLClassLoader, Utils} /** * Removes the TaskResult from the BlockManager before delegating to a normal TaskResultGetter. @@ -119,5 +123,64 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local // Make sure two tasks were run (one failed one, and a second retried one). assert(scheduler.nextTaskId.get() === 2) } + + /** + * Make sure we are using the context classloader when deserializing failed TaskResults instead + * of the Spark classloader. + + * This test compiles a jar containing an exception and tests that when it is thrown on the + * executor, enqueueFailedTask can correctly deserialize the failure and identify the thrown + * exception as the cause. + + * Before this fix, enqueueFailedTask would throw a ClassNotFoundException when deserializing + * the exception, resulting in an UnknownReason for the TaskEndResult. + */ + test("failed task deserialized with the correct classloader (SPARK-11195)") { + // compile a small jar containing an exception that will be thrown on an executor. + val tempDir = Utils.createTempDir() + val srcDir = new File(tempDir, "repro/") + srcDir.mkdirs() + val excSource = new JavaSourceFromString(new File(srcDir, "MyException").getAbsolutePath, + """package repro; + | + |public class MyException extends Exception { + |} + """.stripMargin) + val excFile = TestUtils.createCompiledClass("MyException", srcDir, excSource, Seq.empty) + val jarFile = new File(tempDir, "testJar-%s.jar".format(System.currentTimeMillis())) + TestUtils.createJar(Seq(excFile), jarFile, directoryPrefix = Some("repro")) + + // ensure we reset the classloader after the test completes + val originalClassLoader = Thread.currentThread.getContextClassLoader + try { + // load the exception from the jar + val loader = new MutableURLClassLoader(new Array[URL](0), originalClassLoader) + loader.addURL(jarFile.toURI.toURL) + Thread.currentThread().setContextClassLoader(loader) + val excClass: Class[_] = Utils.classForName("repro.MyException") + + // NOTE: we must run the cluster with "local" so that the executor can load the compiled + // jar. + sc = new SparkContext("local", "test", conf) + val rdd = sc.parallelize(Seq(1), 1).map { _ => + val exc = excClass.newInstance().asInstanceOf[Exception] + throw exc + } + + // the driver should not have any problems resolving the exception class and determining + // why the task failed. 
+ val exceptionMessage = intercept[SparkException] { + rdd.collect() + }.getMessage + + val expectedFailure = """(?s).*Lost task.*: repro.MyException.*""".r + val unknownFailure = """(?s).*Lost task.*: UnknownReason.*""".r + + assert(expectedFailure.findFirstMatchIn(exceptionMessage).isDefined) + assert(unknownFailure.findFirstMatchIn(exceptionMessage).isEmpty) + } finally { + Thread.currentThread.setContextClassLoader(originalClassLoader) + } + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index c2edd4c317d6..2afb595e6f10 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -237,4 +237,40 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L } } + test("tasks are not re-scheduled while executor loss reason is pending") { + sc = new SparkContext("local", "TaskSchedulerImplSuite") + val taskScheduler = new TaskSchedulerImpl(sc) + taskScheduler.initialize(new FakeSchedulerBackend) + // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks. + new DAGScheduler(sc, taskScheduler) { + override def taskStarted(task: Task[_], taskInfo: TaskInfo) {} + override def executorAdded(execId: String, host: String) {} + } + + val e0Offers = Seq(new WorkerOffer("executor0", "host0", 1)) + val e1Offers = Seq(new WorkerOffer("executor1", "host0", 1)) + val attempt1 = FakeTask.createTaskSet(1) + + // submit attempt 1, offer resources, task gets scheduled + taskScheduler.submitTasks(attempt1) + val taskDescriptions = taskScheduler.resourceOffers(e0Offers).flatten + assert(1 === taskDescriptions.length) + + // mark executor0 as dead but pending fail reason + taskScheduler.executorLost("executor0", LossReasonPending) + + // offer some more resources on a different executor, nothing should change + val taskDescriptions2 = taskScheduler.resourceOffers(e1Offers).flatten + assert(0 === taskDescriptions2.length) + + // provide the actual loss reason for executor0 + taskScheduler.executorLost("executor0", SlaveLost("oops")) + + // executor0's tasks should have failed now that the loss reason is known, so offering more + // resources should make them be scheduled on the new executor. + val taskDescriptions3 = taskScheduler.resourceOffers(e1Offers).flatten + assert(1 === taskDescriptions3.length) + assert("executor1" === taskDescriptions3(0).executorId) + } + } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 3abb99c4b2b5..ecc18fc6e15b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -48,7 +48,10 @@ class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler) override def executorLost(execId: String) {} - override def taskSetFailed(taskSet: TaskSet, reason: String) { + override def taskSetFailed( + taskSet: TaskSet, + reason: String, + exception: Option[Throwable]): Unit = { taskScheduler.taskSetsFailed += taskSet.id } } @@ -136,7 +139,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex /** * A Task implementation that results in a large serialized task. 
*/ -class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0) { +class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0, Seq.empty) { val randomBuffer = new Array[Byte](TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024) val random = new Random(0) random.nextBytes(randomBuffer) @@ -331,7 +334,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg // Now mark host2 as dead sched.removeExecutor("exec2") - manager.executorLost("exec2", "host2") + manager.executorLost("exec2", "host2", SlaveLost()) // nothing should be chosen assert(manager.resourceOffer("exec1", "host1", ANY) === None) @@ -501,13 +504,40 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg Array(PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY))) // test if the valid locality is recomputed when the executor is lost sched.removeExecutor("execC") - manager.executorLost("execC", "host2") + manager.executorLost("execC", "host2", SlaveLost()) assert(manager.myLocalityLevels.sameElements(Array(NODE_LOCAL, NO_PREF, ANY))) sched.removeExecutor("execD") - manager.executorLost("execD", "host1") + manager.executorLost("execD", "host1", SlaveLost()) assert(manager.myLocalityLevels.sameElements(Array(NO_PREF, ANY))) } + test("Executors exit for reason unrelated to currently running tasks") { + sc = new SparkContext("local", "test") + val sched = new FakeTaskScheduler(sc) + val taskSet = FakeTask.createTaskSet(4, + Seq(TaskLocation("host1", "execA")), + Seq(TaskLocation("host1", "execB")), + Seq(TaskLocation("host2", "execC")), + Seq()) + val manager = new TaskSetManager(sched, taskSet, 1, new ManualClock) + sched.addExecutor("execA", "host1") + manager.executorAdded() + sched.addExecutor("execC", "host2") + manager.executorAdded() + assert(manager.resourceOffer("exec1", "host1", ANY).isDefined) + sched.removeExecutor("execA") + manager.executorLost( + "execA", + "host1", + ExecutorExited(143, false, "Terminated for reason unrelated to running tasks")) + assert(!sched.taskSetsFailed.contains(taskSet.id)) + assert(manager.resourceOffer("execC", "host2", ANY).isDefined) + sched.removeExecutor("execC") + manager.executorLost( + "execC", "host2", ExecutorExited(1, true, "Terminated due to issue with running tasks")) + assert(sched.taskSetsFailed.contains(taskSet.id)) + } + test("test RACK_LOCAL tasks") { // Assign host1 to rack1 FakeRackUtil.assignHostToRack("host1", "rack1") @@ -718,8 +748,8 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(manager.resourceOffer("execB.2", "host2", ANY) !== None) sched.removeExecutor("execA") sched.removeExecutor("execB.2") - manager.executorLost("execA", "host1") - manager.executorLost("execB.2", "host2") + manager.executorLost("execA", "host1", SlaveLost()) + manager.executorLost("execB.2", "host2", SlaveLost()) clock.advance(LOCALITY_WAIT_MS * 4) sched.addExecutor("execC", "host3") manager.executorAdded() @@ -733,9 +763,9 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2"), ("execC", "host3")) val taskSet = FakeTask.createTaskSet(3, - Seq(HostTaskLocation("host1")), - Seq(HostTaskLocation("host2")), - Seq(HDFSCacheTaskLocation("host3"))) + Seq(TaskLocation("host1")), + Seq(TaskLocation("host2")), + Seq(TaskLocation("hdfs_cache_host3"))) val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) 
assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY))) @@ -750,6 +780,12 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(manager.myLocalityLevels.sameElements(Array(ANY))) } + test("Test TaskLocation for different host type.") { + assert(TaskLocation("host1") === HostTaskLocation("host1")) + assert(TaskLocation("hdfs_cache_host1") === HDFSCacheTaskLocation("host1")) + assert(TaskLocation("executor_host1_3") === ExecutorCacheTaskLocation("host1", "3")) + } + def createTaskResult(id: Int): DirectTaskResult[Int] = { val valueSer = SparkEnv.get.serializer.newInstance() new DirectTaskResult[Int](valueSer.serialize(id), mutable.Map.empty, new TaskMetrics) diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala index 5ed30f64d705..c4dc56003120 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala @@ -18,10 +18,11 @@ package org.apache.spark.scheduler.cluster.mesos import java.nio.ByteBuffer -import java.util +import java.util.Arrays +import java.util.Collection import java.util.Collections -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -41,6 +42,38 @@ import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSui class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar { + test("Use configured mesosExecutor.cores for ExecutorInfo") { + val mesosExecutorCores = 3 + val conf = new SparkConf + conf.set("spark.mesos.mesosExecutor.cores", mesosExecutorCores.toString) + + val listenerBus = mock[LiveListenerBus] + listenerBus.post( + SparkListenerExecutorAdded(anyLong, "s1", new ExecutorInfo("host1", 2, Map.empty))) + + val sc = mock[SparkContext] + when(sc.getSparkHome()).thenReturn(Option("/spark-home")) + + when(sc.conf).thenReturn(conf) + when(sc.executorEnvs).thenReturn(new mutable.HashMap[String, String]) + when(sc.executorMemory).thenReturn(100) + when(sc.listenerBus).thenReturn(listenerBus) + val taskScheduler = mock[TaskSchedulerImpl] + when(taskScheduler.CPUS_PER_TASK).thenReturn(2) + + val mesosSchedulerBackend = new MesosSchedulerBackend(taskScheduler, sc, "master") + + val resources = Arrays.asList( + mesosSchedulerBackend.createResource("cpus", 4), + mesosSchedulerBackend.createResource("mem", 1024)) + // uri is null. + val (executorInfo, _) = mesosSchedulerBackend.createExecutorInfo(resources, "test-id") + val executorResources = executorInfo.getResourcesList + val cpus = executorResources.asScala.find(_.getName.equals("cpus")).get.getScalar.getValue + + assert(cpus === mesosExecutorCores) + } + test("check spark-class location correctly") { val conf = new SparkConf conf.set("spark.mesos.executor.home" , "/mesos-home") @@ -61,7 +94,7 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi val mesosSchedulerBackend = new MesosSchedulerBackend(taskScheduler, sc, "master") - val resources = List( + val resources = Arrays.asList( mesosSchedulerBackend.createResource("cpus", 4), mesosSchedulerBackend.createResource("mem", 1024)) // uri is null. 
@@ -98,7 +131,7 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi val backend = new MesosSchedulerBackend(taskScheduler, sc, "master") val (execInfo, _) = backend.createExecutorInfo( - List(backend.createResource("cpus", 4)), "mockExecutor") + Arrays.asList(backend.createResource("cpus", 4)), "mockExecutor") assert(execInfo.getContainer.getDocker.getImage.equals("spark/mock")) val portmaps = execInfo.getContainer.getDocker.getPortMappingsList assert(portmaps.get(0).getHostPort.equals(80)) @@ -179,7 +212,7 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi when(taskScheduler.resourceOffers(expectedWorkerOffers)).thenReturn(Seq(Seq(taskDesc))) when(taskScheduler.CPUS_PER_TASK).thenReturn(2) - val capture = ArgumentCaptor.forClass(classOf[util.Collection[TaskInfo]]) + val capture = ArgumentCaptor.forClass(classOf[Collection[TaskInfo]]) when( driver.launchTasks( Matchers.eq(Collections.singleton(mesosOffers.get(0).getId)), @@ -262,7 +295,6 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi .setSlaveId(SlaveID.newBuilder().setValue(s"s${id.toString}")) .setHostname(s"host${id.toString}").build() - val mesosOffers = new java.util.ArrayList[Offer] mesosOffers.add(offer) @@ -279,7 +311,7 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi when(taskScheduler.resourceOffers(expectedWorkerOffers)).thenReturn(Seq(Seq(taskDesc))) when(taskScheduler.CPUS_PER_TASK).thenReturn(1) - val capture = ArgumentCaptor.forClass(classOf[util.Collection[TaskInfo]]) + val capture = ArgumentCaptor.forClass(classOf[Collection[TaskInfo]]) when( driver.launchTasks( Matchers.eq(Collections.singleton(mesosOffers.get(0).getId)), @@ -304,7 +336,7 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi assert(cpusDev.getName.equals("cpus")) assert(cpusDev.getScalar.getValue.equals(1.0)) assert(cpusDev.getRole.equals("dev")) - val executorResources = taskInfo.getExecutor.getResourcesList + val executorResources = taskInfo.getExecutor.getResourcesList.asScala assert(executorResources.exists { r => r.getName.equals("mem") && r.getScalar.getValue.equals(484.0) && r.getRole.equals("prod") }) diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala index b354914b6ffd..2eb43b731338 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala @@ -17,10 +17,13 @@ package org.apache.spark.scheduler.cluster.mesos +import scala.language.reflectiveCalls + import org.apache.mesos.Protos.Value import org.mockito.Mockito._ import org.scalatest._ import org.scalatest.mock.MockitoSugar + import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} class MesosSchedulerUtilsSuite extends SparkFunSuite with Matchers with MockitoSugar { diff --git a/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala index bc9f3708ed69..87f25e7245e1 100644 --- a/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala @@ -76,9 +76,9 @@ class GenericAvroSerializerSuite extends SparkFunSuite with 
SharedSparkContext { test("caches previously seen schemas") { val genericSer = new GenericAvroSerializer(conf.getAvroSchema) val compressedSchema = genericSer.compress(schema) - val decompressedScheam = genericSer.decompress(ByteBuffer.wrap(compressedSchema)) + val decompressedSchema = genericSer.decompress(ByteBuffer.wrap(compressedSchema)) assert(compressedSchema.eq(genericSer.compress(schema))) - assert(decompressedScheam.eq(genericSer.decompress(ByteBuffer.wrap(compressedSchema)))) + assert(decompressedSchema.eq(genericSer.decompress(ByteBuffer.wrap(compressedSchema)))) } } diff --git a/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala index 329a2b6dad83..20f45670bc2b 100644 --- a/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala @@ -25,4 +25,22 @@ class JavaSerializerSuite extends SparkFunSuite { val instance = serializer.newInstance() instance.deserialize[JavaSerializer](instance.serialize(serializer)) } + + test("Deserialize object containing a primitive Class as attribute") { + val serializer = new JavaSerializer(new SparkConf()) + val instance = serializer.newInstance() + instance.deserialize[JavaSerializer](instance.serialize(new ContainsPrimitiveClass())) + } +} + +private class ContainsPrimitiveClass extends Serializable { + val intClass = classOf[Int] + val longClass = classOf[Long] + val shortClass = classOf[Short] + val charClass = classOf[Char] + val doubleClass = classOf[Double] + val floatClass = classOf[Float] + val booleanClass = classOf[Boolean] + val byteClass = classOf[Byte] + val voidClass = classOf[Void] } diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 23a1fdb0f500..9fcc22b608c6 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -17,16 +17,21 @@ package org.apache.spark.serializer -import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileOutputStream, FileInputStream} +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.reflect.ClassTag import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} + +import org.roaringbitmap.RoaringBitmap import org.apache.spark.{SharedSparkContext, SparkConf, SparkFunSuite} import org.apache.spark.scheduler.HighlyCompressedMapStatus import org.apache.spark.serializer.KryoTest._ +import org.apache.spark.util.Utils import org.apache.spark.storage.BlockManagerId class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { @@ -143,10 +148,40 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { check(mutable.Map("one" -> 1, "two" -> 2)) check(mutable.HashMap(1 -> "one", 2 -> "two")) check(mutable.HashMap("one" -> 1, "two" -> 2)) - check(List(Some(mutable.HashMap(1->1, 2->2)), None, Some(mutable.HashMap(3->4)))) + check(List(Some(mutable.HashMap(1 -> 1, 2 -> 2)), None, Some(mutable.HashMap(3 -> 4)))) + check(List( + mutable.HashMap("one" -> 1, "two" -> 2), + mutable.HashMap(1 -> "one", 2 -> "two", 3 -> "three"))) + } + + test("Bug: SPARK-10251") { + val ser = new 
KryoSerializer(conf.clone.set("spark.kryo.registrationRequired", "true")) + .newInstance() + def check[T: ClassTag](t: T) { + assert(ser.deserialize[T](ser.serialize(t)) === t) + } + check((1, 3)) + check(Array((1, 3))) + check(List((1, 3))) + check(List[Int]()) + check(List[Int](1, 2, 3)) + check(List[String]()) + check(List[String]("x", "y", "z")) + check(None) + check(Some(1)) + check(Some("hi")) + check(1 -> 1) + check(mutable.ArrayBuffer(1, 2, 3)) + check(mutable.ArrayBuffer("1", "2", "3")) + check(mutable.Map()) + check(mutable.Map(1 -> "one", 2 -> "two")) + check(mutable.Map("one" -> 1, "two" -> 2)) + check(mutable.HashMap(1 -> "one", 2 -> "two")) + check(mutable.HashMap("one" -> 1, "two" -> 2)) + check(List(Some(mutable.HashMap(1 -> 1, 2 -> 2)), None, Some(mutable.HashMap(3 -> 4)))) check(List( mutable.HashMap("one" -> 1, "two" -> 2), - mutable.HashMap(1->"one", 2->"two", 3->"three"))) + mutable.HashMap(1 -> "one", 2 -> "two", 3 -> "three"))) } test("ranges") { @@ -173,7 +208,7 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { test("asJavaIterable") { // Serialize a collection wrapped by asJavaIterable val ser = new KryoSerializer(conf).newInstance() - val a = ser.serialize(scala.collection.convert.WrapAsJava.asJavaIterable(Seq(12345))) + val a = ser.serialize(Seq(12345).asJava) val b = ser.deserialize[java.lang.Iterable[Int]](a) assert(b.iterator().next() === 12345) @@ -319,6 +354,28 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { assert(thrown.getMessage.contains(kryoBufferMaxProperty)) } + test("SPARK-12222: deserialize RoaringBitmap throw Buffer underflow exception") { + val dir = Utils.createTempDir() + val tmpfile = dir.toString + "/RoaringBitmap" + val outStream = new FileOutputStream(tmpfile) + val output = new KryoOutput(outStream) + val bitmap = new RoaringBitmap + bitmap.add(1) + bitmap.add(3) + bitmap.add(5) + bitmap.serialize(new KryoOutputDataOutputBridge(output)) + output.flush() + output.close() + + val inStream = new FileInputStream(tmpfile) + val input = new KryoInput(inStream) + val ret = new RoaringBitmap + ret.deserialize(new KryoInputDataInputBridge(input)) + input.close() + assert(ret == bitmap) + Utils.deleteRecursively(dir) + } + test("getAutoReset") { val ser = new KryoSerializer(new SparkConf).newInstance().asInstanceOf[KryoSerializerInstance] assert(ser.getAutoReset) diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala similarity index 95% rename from core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala rename to core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index db718ecabbdb..26a372d6a905 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.shuffle.hash +package org.apache.spark.shuffle import java.io.{ByteArrayOutputStream, InputStream} import java.nio.ByteBuffer @@ -28,7 +28,6 @@ import org.mockito.stubbing.Answer import org.apache.spark._ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.shuffle.BaseShuffleHandle import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId} /** @@ -56,7 +55,7 @@ class RecordingManagedBuffer(underlyingBuffer: NioManagedBuffer) extends Managed } } -class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { +class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { /** * This test makes sure that, when data is read from a HashShuffleReader, the underlying @@ -115,7 +114,7 @@ class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { // Make a mocked MapOutputTracker for the shuffle reader to use to determine what // shuffle data to read. val mapOutputTracker = mock(classOf[MapOutputTracker]) - when(mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId)).thenReturn { + when(mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1)).thenReturn { // Test a scenario where all data is local, to avoid creating a bunch of additional mocks // for the code to read data over the network. val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId => @@ -134,11 +133,11 @@ class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { new BaseShuffleHandle(shuffleId, numMaps, dependency) } - val shuffleReader = new HashShuffleReader( + val shuffleReader = new BlockStoreShuffleReader( shuffleHandle, reduceId, reduceId + 1, - new TaskContextImpl(0, 0, 0, 0, null, null), + TaskContext.empty(), blockManager, mapOutputTracker) diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleDependencySuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleDependencySuite.scala new file mode 100644 index 000000000000..4d5f599fb12a --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleDependencySuite.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.shuffle + +import org.apache.spark._ + +case class KeyClass() + +case class ValueClass() + +case class CombinerClass() + +class ShuffleDependencySuite extends SparkFunSuite with LocalSparkContext { + + val conf = new SparkConf(loadDefaults = false) + + test("key, value, and combiner classes correct in shuffle dependency without aggregation") { + sc = new SparkContext("local", "test", conf.clone()) + val rdd = sc.parallelize(1 to 5, 4) + .map(key => (KeyClass(), ValueClass())) + .groupByKey() + val dep = rdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + assert(!dep.mapSideCombine, "Test requires that no map-side aggregator is defined") + assert(dep.keyClassName == classOf[KeyClass].getName) + assert(dep.valueClassName == classOf[ValueClass].getName) + } + + test("key, value, and combiner classes available in shuffle dependency with aggregation") { + sc = new SparkContext("local", "test", conf.clone()) + val rdd = sc.parallelize(1 to 5, 4) + .map(key => (KeyClass(), ValueClass())) + .aggregateByKey(CombinerClass())({ case (a, b) => a }, { case (a, b) => a }) + val dep = rdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + assert(dep.mapSideCombine && dep.aggregator.isDefined, "Test requires map-side aggregation") + assert(dep.keyClassName == classOf[KeyClass].getName) + assert(dep.valueClassName == classOf[ValueClass].getName) + assert(dep.combinerClassName == Some(classOf[CombinerClass].getName)) + } + + test("combineByKey null combiner class tag handled correctly") { + sc = new SparkContext("local", "test", conf.clone()) + val rdd = sc.parallelize(1 to 5, 4) + .map(key => (KeyClass(), ValueClass())) + .combineByKey((v: ValueClass) => v, + (c: AnyRef, v: ValueClass) => c, + (c1: AnyRef, c2: AnyRef) => c1) + val dep = rdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + assert(dep.keyClassName == classOf[KeyClass].getName) + assert(dep.valueClassName == classOf[ValueClass].getName) + assert(dep.combinerClassName == None) + } + +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala deleted file mode 100644 index f495b6a03795..000000000000 --- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala +++ /dev/null @@ -1,323 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.shuffle - -import java.util.concurrent.CountDownLatch -import java.util.concurrent.atomic.AtomicInteger - -import org.mockito.Mockito._ -import org.scalatest.concurrent.Timeouts -import org.scalatest.time.SpanSugar._ - -import org.apache.spark.{SparkFunSuite, TaskContext} - -class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { - - val nextTaskAttemptId = new AtomicInteger() - - /** Launch a thread with the given body block and return it. */ - private def startThread(name: String)(body: => Unit): Thread = { - val thread = new Thread("ShuffleMemorySuite " + name) { - override def run() { - try { - val taskAttemptId = nextTaskAttemptId.getAndIncrement - val mockTaskContext = mock(classOf[TaskContext], RETURNS_SMART_NULLS) - when(mockTaskContext.taskAttemptId()).thenReturn(taskAttemptId) - TaskContext.setTaskContext(mockTaskContext) - body - } finally { - TaskContext.unset() - } - } - } - thread.start() - thread - } - - test("single task requesting memory") { - val manager = new ShuffleMemoryManager(1000L) - - assert(manager.tryToAcquire(100L) === 100L) - assert(manager.tryToAcquire(400L) === 400L) - assert(manager.tryToAcquire(400L) === 400L) - assert(manager.tryToAcquire(200L) === 100L) - assert(manager.tryToAcquire(100L) === 0L) - assert(manager.tryToAcquire(100L) === 0L) - - manager.release(500L) - assert(manager.tryToAcquire(300L) === 300L) - assert(manager.tryToAcquire(300L) === 200L) - - manager.releaseMemoryForThisTask() - assert(manager.tryToAcquire(1000L) === 1000L) - assert(manager.tryToAcquire(100L) === 0L) - } - - test("two threads requesting full memory") { - // Two threads request 500 bytes first, wait for each other to get it, and then request - // 500 more; we should immediately return 0 as both are now at 1 / N - - val manager = new ShuffleMemoryManager(1000L) - - class State { - var t1Result1 = -1L - var t2Result1 = -1L - var t1Result2 = -1L - var t2Result2 = -1L - } - val state = new State - - val t1 = startThread("t1") { - val r1 = manager.tryToAcquire(500L) - state.synchronized { - state.t1Result1 = r1 - state.notifyAll() - while (state.t2Result1 === -1L) { - state.wait() - } - } - val r2 = manager.tryToAcquire(500L) - state.synchronized { state.t1Result2 = r2 } - } - - val t2 = startThread("t2") { - val r1 = manager.tryToAcquire(500L) - state.synchronized { - state.t2Result1 = r1 - state.notifyAll() - while (state.t1Result1 === -1L) { - state.wait() - } - } - val r2 = manager.tryToAcquire(500L) - state.synchronized { state.t2Result2 = r2 } - } - - failAfter(20 seconds) { - t1.join() - t2.join() - } - - assert(state.t1Result1 === 500L) - assert(state.t2Result1 === 500L) - assert(state.t1Result2 === 0L) - assert(state.t2Result2 === 0L) - } - - - test("tasks cannot grow past 1 / N") { - // Two tasks request 250 bytes first, wait for each other to get it, and then request - // 500 more; we should only grant 250 bytes to each of them on this second request - - val manager = new ShuffleMemoryManager(1000L) - - class State { - var t1Result1 = -1L - var t2Result1 = -1L - var t1Result2 = -1L - var t2Result2 = -1L - } - val state = new State - - val t1 = startThread("t1") { - val r1 = manager.tryToAcquire(250L) - state.synchronized { - state.t1Result1 = r1 - state.notifyAll() - while (state.t2Result1 === -1L) { - state.wait() - } - } - val r2 = manager.tryToAcquire(500L) - state.synchronized { state.t1Result2 = r2 } - } - - val t2 = startThread("t2") { - val r1 = manager.tryToAcquire(250L) - state.synchronized { - state.t2Result1 = 
r1 - state.notifyAll() - while (state.t1Result1 === -1L) { - state.wait() - } - } - val r2 = manager.tryToAcquire(500L) - state.synchronized { state.t2Result2 = r2 } - } - - failAfter(20 seconds) { - t1.join() - t2.join() - } - - assert(state.t1Result1 === 250L) - assert(state.t2Result1 === 250L) - assert(state.t1Result2 === 250L) - assert(state.t2Result2 === 250L) - } - - test("tasks can block to get at least 1 / 2N memory") { - // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps - // for a bit and releases 250 bytes, which should then be granted to t2. Further requests - // by t2 will return false right away because it now has 1 / 2N of the memory. - - val manager = new ShuffleMemoryManager(1000L) - - class State { - var t1Requested = false - var t2Requested = false - var t1Result = -1L - var t2Result = -1L - var t2Result2 = -1L - var t2WaitTime = 0L - } - val state = new State - - val t1 = startThread("t1") { - state.synchronized { - state.t1Result = manager.tryToAcquire(1000L) - state.t1Requested = true - state.notifyAll() - while (!state.t2Requested) { - state.wait() - } - } - // Sleep a bit before releasing our memory; this is hacky but it would be difficult to make - // sure the other thread blocks for some time otherwise - Thread.sleep(300) - manager.release(250L) - } - - val t2 = startThread("t2") { - state.synchronized { - while (!state.t1Requested) { - state.wait() - } - state.t2Requested = true - state.notifyAll() - } - val startTime = System.currentTimeMillis() - val result = manager.tryToAcquire(250L) - val endTime = System.currentTimeMillis() - state.synchronized { - state.t2Result = result - // A second call should return 0 because we're now already at 1 / 2N - state.t2Result2 = manager.tryToAcquire(100L) - state.t2WaitTime = endTime - startTime - } - } - - failAfter(20 seconds) { - t1.join() - t2.join() - } - - // Both threads should've been able to acquire their memory; the second one will have waited - // until the first one acquired 1000 bytes and then released 250 - state.synchronized { - assert(state.t1Result === 1000L, "t1 could not allocate memory") - assert(state.t2Result === 250L, "t2 could not allocate memory") - assert(state.t2WaitTime > 200, s"t2 waited less than 200 ms (${state.t2WaitTime})") - assert(state.t2Result2 === 0L, "t1 got extra memory the second time") - } - } - - test("releaseMemoryForThisTask") { - // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps - // for a bit and releases all its memory. t2 should now be able to grab all the memory. 
- - val manager = new ShuffleMemoryManager(1000L) - - class State { - var t1Requested = false - var t2Requested = false - var t1Result = -1L - var t2Result1 = -1L - var t2Result2 = -1L - var t2Result3 = -1L - var t2WaitTime = 0L - } - val state = new State - - val t1 = startThread("t1") { - state.synchronized { - state.t1Result = manager.tryToAcquire(1000L) - state.t1Requested = true - state.notifyAll() - while (!state.t2Requested) { - state.wait() - } - } - // Sleep a bit before releasing our memory; this is hacky but it would be difficult to make - // sure the other task blocks for some time otherwise - Thread.sleep(300) - manager.releaseMemoryForThisTask() - } - - val t2 = startThread("t2") { - state.synchronized { - while (!state.t1Requested) { - state.wait() - } - state.t2Requested = true - state.notifyAll() - } - val startTime = System.currentTimeMillis() - val r1 = manager.tryToAcquire(500L) - val endTime = System.currentTimeMillis() - val r2 = manager.tryToAcquire(500L) - val r3 = manager.tryToAcquire(500L) - state.synchronized { - state.t2Result1 = r1 - state.t2Result2 = r2 - state.t2Result3 = r3 - state.t2WaitTime = endTime - startTime - } - } - - failAfter(20 seconds) { - t1.join() - t2.join() - } - - // Both tasks should've been able to acquire their memory; the second one will have waited - // until the first one acquired 1000 bytes and then released all of it - state.synchronized { - assert(state.t1Result === 1000L, "t1 could not allocate memory") - assert(state.t2Result1 === 500L, "t2 didn't get 500 bytes the first time") - assert(state.t2Result2 === 500L, "t2 didn't get 500 bytes the second time") - assert(state.t2Result3 === 0L, s"t2 got more bytes a third time (${state.t2Result3})") - assert(state.t2WaitTime > 200, s"t2 waited less than 200 ms (${state.t2WaitTime})") - } - } - - test("tasks should not be granted a negative size") { - val manager = new ShuffleMemoryManager(1000L) - manager.tryToAcquire(700L) - - val latch = new CountDownLatch(1) - startThread("t1") { - manager.tryToAcquire(300L) - latch.countDown() - } - latch.await() // Wait until `t1` calls `tryToAcquire` - - val granted = manager.tryToAcquire(300L) - assert(0 === granted, "granted is negative") - } -} diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala deleted file mode 100644 index 491dc3659e18..000000000000 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.shuffle.hash - -import java.io.{File, FileWriter} - -import scala.language.reflectiveCalls - -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite} -import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} -import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.shuffle.FileShuffleBlockResolver -import org.apache.spark.storage.{ShuffleBlockId, FileSegment} - -class HashShuffleManagerSuite extends SparkFunSuite with LocalSparkContext { - private val testConf = new SparkConf(false) - - private def checkSegments(expected: FileSegment, buffer: ManagedBuffer) { - assert(buffer.isInstanceOf[FileSegmentManagedBuffer]) - val segment = buffer.asInstanceOf[FileSegmentManagedBuffer] - assert(expected.file.getCanonicalPath === segment.getFile.getCanonicalPath) - assert(expected.offset === segment.getOffset) - assert(expected.length === segment.getLength) - } - - test("consolidated shuffle can write to shuffle group without messing existing offsets/lengths") { - - val conf = new SparkConf(false) - // reset after EACH object write. This is to ensure that there are bytes appended after - // an object is written. So if the codepaths assume writeObject is end of data, this should - // flush those bugs out. This was common bug in ExternalAppendOnlyMap, etc. - conf.set("spark.serializer.objectStreamReset", "1") - conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer") - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager") - - sc = new SparkContext("local", "test", conf) - - val shuffleBlockResolver = - SparkEnv.get.shuffleManager.shuffleBlockResolver.asInstanceOf[FileShuffleBlockResolver] - - val shuffle1 = shuffleBlockResolver.forMapTask(1, 1, 1, new JavaSerializer(conf), - new ShuffleWriteMetrics) - for (writer <- shuffle1.writers) { - writer.write("test1", "value") - writer.write("test2", "value") - } - for (writer <- shuffle1.writers) { - writer.commitAndClose() - } - - val shuffle1Segment = shuffle1.writers(0).fileSegment() - shuffle1.releaseWriters(success = true) - - val shuffle2 = shuffleBlockResolver.forMapTask(1, 2, 1, new JavaSerializer(conf), - new ShuffleWriteMetrics) - - for (writer <- shuffle2.writers) { - writer.write("test3", "value") - writer.write("test4", "vlue") - } - for (writer <- shuffle2.writers) { - writer.commitAndClose() - } - val shuffle2Segment = shuffle2.writers(0).fileSegment() - shuffle2.releaseWriters(success = true) - - // Now comes the test : - // Write to shuffle 3; and close it, but before registering it, check if the file lengths for - // previous task (forof shuffle1) is the same as 'segments'. Earlier, we were inferring length - // of block based on remaining data in file : which could mess things up when there is - // concurrent read and writes happening to the same shuffle group. - - val shuffle3 = shuffleBlockResolver.forMapTask(1, 3, 1, new JavaSerializer(testConf), - new ShuffleWriteMetrics) - for (writer <- shuffle3.writers) { - writer.write("test3", "value") - writer.write("test4", "value") - } - for (writer <- shuffle3.writers) { - writer.commitAndClose() - } - // check before we register. 
- checkSegments(shuffle2Segment, shuffleBlockResolver.getBlockData(ShuffleBlockId(1, 2, 0))) - shuffle3.releaseWriters(success = true) - checkSegments(shuffle2Segment, shuffleBlockResolver.getBlockData(ShuffleBlockId(1, 2, 0))) - shuffleBlockResolver.removeShuffle(1) - } - - def writeToFile(file: File, numBytes: Int) { - val writer = new FileWriter(file, true) - for (i <- 0 until numBytes) writer.write(i) - writer.close() - } -} diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index cc7342f1ecd7..d3b1b2b620b4 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -33,7 +33,8 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark._ import org.apache.spark.executor.{TaskMetrics, ShuffleWriteMetrics} -import org.apache.spark.serializer.{SerializerInstance, Serializer, JavaSerializer} +import org.apache.spark.shuffle.IndexShuffleBlockResolver +import org.apache.spark.serializer.{JavaSerializer, SerializerInstance} import org.apache.spark.storage._ import org.apache.spark.util.Utils @@ -42,25 +43,42 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = _ @Mock(answer = RETURNS_SMART_NULLS) private var diskBlockManager: DiskBlockManager = _ @Mock(answer = RETURNS_SMART_NULLS) private var taskContext: TaskContext = _ + @Mock(answer = RETURNS_SMART_NULLS) private var blockResolver: IndexShuffleBlockResolver = _ + @Mock(answer = RETURNS_SMART_NULLS) private var dependency: ShuffleDependency[Int, Int, Int] = _ private var taskMetrics: TaskMetrics = _ - private var shuffleWriteMetrics: ShuffleWriteMetrics = _ private var tempDir: File = _ private var outputFile: File = _ private val conf: SparkConf = new SparkConf(loadDefaults = false) private val temporaryFilesCreated: mutable.Buffer[File] = new ArrayBuffer[File]() private val blockIdToFileMap: mutable.Map[BlockId, File] = new mutable.HashMap[BlockId, File] - private val shuffleBlockId: ShuffleBlockId = new ShuffleBlockId(0, 0, 0) - private val serializer: Serializer = new JavaSerializer(conf) + private var shuffleHandle: BypassMergeSortShuffleHandle[Int, Int] = _ override def beforeEach(): Unit = { tempDir = Utils.createTempDir() outputFile = File.createTempFile("shuffle", null, tempDir) - shuffleWriteMetrics = new ShuffleWriteMetrics taskMetrics = new TaskMetrics - taskMetrics.shuffleWriteMetrics = Some(shuffleWriteMetrics) MockitoAnnotations.initMocks(this) + shuffleHandle = new BypassMergeSortShuffleHandle[Int, Int]( + shuffleId = 0, + numMaps = 2, + dependency = dependency + ) + when(dependency.partitioner).thenReturn(new HashPartitioner(7)) + when(dependency.serializer).thenReturn(Some(new JavaSerializer(conf))) when(taskContext.taskMetrics()).thenReturn(taskMetrics) + when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile) + doAnswer(new Answer[Void] { + def answer(invocationOnMock: InvocationOnMock): Void = { + val tmp: File = invocationOnMock.getArguments()(3).asInstanceOf[File] + if (tmp != null) { + outputFile.delete + tmp.renameTo(outputFile) + } + null + } + }).when(blockResolver) + .writeIndexFileAndCommit(anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File])) 
when(blockManager.diskBlockManager).thenReturn(diskBlockManager) when(blockManager.getDiskWriter( any[BlockId], @@ -72,13 +90,13 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte override def answer(invocation: InvocationOnMock): DiskBlockObjectWriter = { val args = invocation.getArguments new DiskBlockObjectWriter( - args(0).asInstanceOf[BlockId], args(1).asInstanceOf[File], args(2).asInstanceOf[SerializerInstance], args(3).asInstanceOf[Int], compressStream = identity, syncWrites = false, - args(4).asInstanceOf[ShuffleWriteMetrics] + args(4).asInstanceOf[ShuffleWriteMetrics], + blockId = args(0).asInstanceOf[BlockId] ) } }) @@ -108,18 +126,20 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte test("write empty iterator") { val writer = new BypassMergeSortShuffleWriter[Int, Int]( - new SparkConf(loadDefaults = false), blockManager, - new HashPartitioner(7), - shuffleWriteMetrics, - serializer + blockResolver, + shuffleHandle, + 0, // MapId + taskContext, + conf ) - writer.insertAll(Iterator.empty) - val partitionLengths = writer.writePartitionedFile(shuffleBlockId, taskContext, outputFile) - assert(partitionLengths.sum === 0) + writer.write(Iterator.empty) + writer.stop( /* success = */ true) + assert(writer.getPartitionLengths.sum === 0) assert(outputFile.exists()) assert(outputFile.length() === 0) assert(temporaryFilesCreated.isEmpty) + val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics.get assert(shuffleWriteMetrics.shuffleBytesWritten === 0) assert(shuffleWriteMetrics.shuffleRecordsWritten === 0) assert(taskMetrics.diskBytesSpilled === 0) @@ -130,17 +150,19 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte def records: Iterator[(Int, Int)] = Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) val writer = new BypassMergeSortShuffleWriter[Int, Int]( - new SparkConf(loadDefaults = false), blockManager, - new HashPartitioner(7), - shuffleWriteMetrics, - serializer + blockResolver, + shuffleHandle, + 0, // MapId + taskContext, + conf ) - writer.insertAll(records) + writer.write(records) + writer.stop( /* success = */ true) assert(temporaryFilesCreated.nonEmpty) - val partitionLengths = writer.writePartitionedFile(shuffleBlockId, taskContext, outputFile) - assert(partitionLengths.sum === outputFile.length()) + assert(writer.getPartitionLengths.sum === outputFile.length()) assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temporary files were deleted + val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics.get assert(shuffleWriteMetrics.shuffleBytesWritten === outputFile.length()) assert(shuffleWriteMetrics.shuffleRecordsWritten === records.length) assert(taskMetrics.diskBytesSpilled === 0) @@ -149,14 +171,15 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte test("cleanup of intermediate files after errors") { val writer = new BypassMergeSortShuffleWriter[Int, Int]( - new SparkConf(loadDefaults = false), blockManager, - new HashPartitioner(7), - shuffleWriteMetrics, - serializer + blockResolver, + shuffleHandle, + 0, // MapId + taskContext, + conf ) intercept[SparkException] { - writer.insertAll((0 until 100000).iterator.map(i => { + writer.write((0 until 100000).iterator.map(i => { if (i == 99990) { throw new SparkException("Intentional failure") } @@ -164,7 +187,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte })) } 
assert(temporaryFilesCreated.nonEmpty) - writer.stop() + writer.stop( /* success = */ false) assert(temporaryFilesCreated.count(_.exists()) === 0) } diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala similarity index 80% rename from core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala rename to core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala index 6727934d8c7c..8744a072cb3f 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.shuffle.unsafe +package org.apache.spark.shuffle.sort import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock @@ -29,9 +29,9 @@ import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, Serializer} * Tests for the fallback logic in UnsafeShuffleManager. Actual tests of shuffling data are * performed in other suites. */ -class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers { +class SortShuffleManagerSuite extends SparkFunSuite with Matchers { - import UnsafeShuffleManager.canUseUnsafeShuffle + import SortShuffleManager.canUseSerializedShuffle private class RuntimeExceptionAnswer extends Answer[Object] { override def answer(invocation: InvocationOnMock): Object = { @@ -55,10 +55,10 @@ class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers { dep } - test("supported shuffle dependencies") { + test("supported shuffle dependencies for serialized shuffle") { val kryo = Some(new KryoSerializer(new SparkConf())) - assert(canUseUnsafeShuffle(shuffleDep( + assert(canUseSerializedShuffle(shuffleDep( partitioner = new HashPartitioner(2), serializer = kryo, keyOrdering = None, @@ -68,7 +68,7 @@ class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers { val rangePartitioner = mock(classOf[RangePartitioner[Any, Any]]) when(rangePartitioner.numPartitions).thenReturn(2) - assert(canUseUnsafeShuffle(shuffleDep( + assert(canUseSerializedShuffle(shuffleDep( partitioner = rangePartitioner, serializer = kryo, keyOrdering = None, @@ -77,7 +77,7 @@ class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers { ))) // Shuffles with key orderings are supported as long as no aggregator is specified - assert(canUseUnsafeShuffle(shuffleDep( + assert(canUseSerializedShuffle(shuffleDep( partitioner = new HashPartitioner(2), serializer = kryo, keyOrdering = Some(mock(classOf[Ordering[Any]])), @@ -87,12 +87,12 @@ class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers { } - test("unsupported shuffle dependencies") { + test("unsupported shuffle dependencies for serialized shuffle") { val kryo = Some(new KryoSerializer(new SparkConf())) val java = Some(new JavaSerializer(new SparkConf())) // We only support serializers that support object relocation - assert(!canUseUnsafeShuffle(shuffleDep( + assert(!canUseSerializedShuffle(shuffleDep( partitioner = new HashPartitioner(2), serializer = java, keyOrdering = None, @@ -100,9 +100,11 @@ class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers { mapSideCombine = false ))) - // We do not support shuffles with more than 16 million output partitions - assert(!canUseUnsafeShuffle(shuffleDep( - partitioner = new 
HashPartitioner(UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS + 1), + // The serialized shuffle path does not support shuffles with more than 16 million output + // partitions, due to a limitation in its sorter implementation. + assert(!canUseSerializedShuffle(shuffleDep( + partitioner = new HashPartitioner( + SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE + 1), serializer = kryo, keyOrdering = None, aggregator = None, @@ -110,14 +112,14 @@ class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers { ))) // We do not support shuffles that perform aggregation - assert(!canUseUnsafeShuffle(shuffleDep( + assert(!canUseSerializedShuffle(shuffleDep( partitioner = new HashPartitioner(2), serializer = kryo, keyOrdering = None, aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])), mapSideCombine = false ))) - assert(!canUseUnsafeShuffle(shuffleDep( + assert(!canUseSerializedShuffle(shuffleDep( partitioner = new HashPartitioner(2), serializer = kryo, keyOrdering = Some(mock(classOf[Ordering[Any]])), diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala deleted file mode 100644 index 34b4984f12c0..000000000000 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ - -package org.apache.spark.shuffle.sort - -import org.mockito.Mockito._ - -import org.apache.spark.{Aggregator, SparkConf, SparkFunSuite} - -class SortShuffleWriterSuite extends SparkFunSuite { - - import SortShuffleWriter._ - - test("conditions for bypassing merge-sort") { - val conf = new SparkConf(loadDefaults = false) - val agg = mock(classOf[Aggregator[_, _, _]], RETURNS_SMART_NULLS) - val ord = implicitly[Ordering[Int]] - - // Numbers of partitions that are above and below the default bypassMergeThreshold - val FEW_PARTITIONS = 50 - val MANY_PARTITIONS = 10000 - - // Shuffles with no ordering or aggregator: should bypass unless # of partitions is high - assert(shouldBypassMergeSort(conf, FEW_PARTITIONS, None, None)) - assert(!shouldBypassMergeSort(conf, MANY_PARTITIONS, None, None)) - - // Shuffles with an ordering or aggregator: should not bypass even if they have few partitions - assert(!shouldBypassMergeSort(conf, FEW_PARTITIONS, None, Some(ord))) - assert(!shouldBypassMergeSort(conf, FEW_PARTITIONS, Some(agg), None)) - } -} diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala deleted file mode 100644 index 6351539e91e9..000000000000 --- a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala +++ /dev/null @@ -1,105 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.unsafe - -import java.io.File - -import scala.collection.JavaConverters._ - -import org.apache.commons.io.FileUtils -import org.apache.commons.io.filefilter.TrueFileFilter -import org.scalatest.BeforeAndAfterAll - -import org.apache.spark.{HashPartitioner, ShuffleDependency, SparkContext, ShuffleSuite} -import org.apache.spark.rdd.ShuffledRDD -import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} -import org.apache.spark.util.Utils - -class UnsafeShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { - - // This test suite should run all tests in ShuffleSuite with unsafe-based shuffle. - - override def beforeAll() { - conf.set("spark.shuffle.manager", "tungsten-sort") - // UnsafeShuffleManager requires at least 128 MB of memory per task in order to be able to sort - // shuffle records. 
- conf.set("spark.shuffle.memoryFraction", "0.5") - } - - test("UnsafeShuffleManager properly cleans up files for shuffles that use the new shuffle path") { - val tmpDir = Utils.createTempDir() - try { - val myConf = conf.clone() - .set("spark.local.dir", tmpDir.getAbsolutePath) - sc = new SparkContext("local", "test", myConf) - // Create a shuffled RDD and verify that it will actually use the new UnsafeShuffle path - val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x)) - val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4)) - .setSerializer(new KryoSerializer(myConf)) - val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] - assert(UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep)) - def getAllFiles: Set[File] = - FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet - val filesBeforeShuffle = getAllFiles - // Force the shuffle to be performed - shuffledRdd.count() - // Ensure that the shuffle actually created files that will need to be cleaned up - val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle - filesCreatedByShuffle.map(_.getName) should be - Set("shuffle_0_0_0.data", "shuffle_0_0_0.index") - // Check that the cleanup actually removes the files - sc.env.blockManager.master.removeShuffle(0, blocking = true) - for (file <- filesCreatedByShuffle) { - assert (!file.exists(), s"Shuffle file $file was not cleaned up") - } - } finally { - Utils.deleteRecursively(tmpDir) - } - } - - test("UnsafeShuffleManager properly cleans up files for shuffles that use the old shuffle path") { - val tmpDir = Utils.createTempDir() - try { - val myConf = conf.clone() - .set("spark.local.dir", tmpDir.getAbsolutePath) - sc = new SparkContext("local", "test", myConf) - // Create a shuffled RDD and verify that it will actually use the old SortShuffle path - val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x)) - val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4)) - .setSerializer(new JavaSerializer(myConf)) - val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] - assert(!UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep)) - def getAllFiles: Set[File] = - FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet - val filesBeforeShuffle = getAllFiles - // Force the shuffle to be performed - shuffledRdd.count() - // Ensure that the shuffle actually created files that will need to be cleaned up - val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle - filesCreatedByShuffle.map(_.getName) should be - Set("shuffle_0_0_0.data", "shuffle_0_0_0.index") - // Check that the cleanup actually removes the files - sc.env.blockManager.master.removeShuffle(0, blocking = true) - for (file <- filesCreatedByShuffle) { - assert (!file.exists(), s"Shuffle file $file was not cleaned up") - } - } finally { - Utils.deleteRecursively(tmpDir) - } - } -} diff --git a/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala b/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala new file mode 100644 index 000000000000..88817dccf349 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status.api.v1 + +import java.util.Date + +import scala.collection.mutable.HashMap + +import org.apache.spark.SparkFunSuite +import org.apache.spark.scheduler.{StageInfo, TaskInfo, TaskLocality} +import org.apache.spark.ui.jobs.UIData.{StageUIData, TaskUIData} + +class AllStagesResourceSuite extends SparkFunSuite { + + def getFirstTaskLaunchTime(taskLaunchTimes: Seq[Long]): Option[Date] = { + val tasks = new HashMap[Long, TaskUIData] + taskLaunchTimes.zipWithIndex.foreach { case (time, idx) => + tasks(idx.toLong) = new TaskUIData( + new TaskInfo(idx, idx, 1, time, "", "", TaskLocality.ANY, false), None, None) + } + + val stageUiData = new StageUIData() + stageUiData.taskData = tasks + val status = StageStatus.ACTIVE + val stageInfo = new StageInfo( + 1, 1, "stage 1", 10, Seq.empty, Seq.empty, "details abc", Seq.empty) + val stageData = AllStagesResource.stageUiToStageData(status, stageInfo, stageUiData, false) + + stageData.firstTaskLaunchedTime + } + + test("firstTaskLaunchedTime when there are no tasks") { + val result = getFirstTaskLaunchTime(Seq()) + assert(result == None) + } + + test("firstTaskLaunchedTime when there are tasks but none launched") { + val result = getFirstTaskLaunchTime(Seq(-100L, -200L, -300L)) + assert(result == None) + } + + test("firstTaskLaunchedTime when there are tasks and some launched") { + val result = getFirstTaskLaunchTime(Seq(-100L, 1449255596000L, 1449255597000L)) + assert(result == Some(new Date(1449255596000L))) + } + +} diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index 0f5ba46f69c2..6e3f500e15dc 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -26,10 +26,11 @@ import org.mockito.Mockito.{mock, when} import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ +import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.rpc.RpcEnv import org.apache.spark._ +import org.apache.spark.memory.StaticMemoryManager import org.apache.spark.network.BlockTransferService -import org.apache.spark.network.nio.NioBlockTransferService import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.hash.HashShuffleManager @@ -38,30 +39,32 @@ import org.apache.spark.storage.StorageLevel._ /** Testsuite that tests block replication in BlockManager */ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with BeforeAndAfter { - private val conf = new SparkConf(false) - var rpcEnv: RpcEnv = null - var master: BlockManagerMaster = null - val securityMgr = new SecurityManager(conf) - val mapOutputTracker = new MapOutputTrackerMaster(conf) - val shuffleManager = new 
HashShuffleManager(conf) + private val conf = new SparkConf(false).set("spark.app.id", "test") + private var rpcEnv: RpcEnv = null + private var master: BlockManagerMaster = null + private val securityMgr = new SecurityManager(conf) + private val mapOutputTracker = new MapOutputTrackerMaster(conf) + private val shuffleManager = new HashShuffleManager(conf) // List of block manager created during an unit test, so that all of the them can be stopped // after the unit test. - val allStores = new ArrayBuffer[BlockManager] + private val allStores = new ArrayBuffer[BlockManager] // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test conf.set("spark.kryoserializer.buffer", "1m") - val serializer = new KryoSerializer(conf) + private val serializer = new KryoSerializer(conf) // Implicitly convert strings to BlockIds for test clarity. - implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value) + private implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value) private def makeBlockManager( maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { - val transfer = new NioBlockTransferService(conf, securityMgr) - val store = new BlockManager(name, rpcEnv, master, serializer, maxMem, conf, - mapOutputTracker, shuffleManager, transfer, securityMgr, 0) + val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) + val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1) + val store = new BlockManager(name, rpcEnv, master, serializer, conf, + memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) + memManager.setMemoryStore(store.memoryStore) store.initialize("app-id") allStores += store store @@ -258,8 +261,10 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo val failableTransfer = mock(classOf[BlockTransferService]) // this wont actually work when(failableTransfer.hostName).thenReturn("some-hostname") when(failableTransfer.port).thenReturn(1000) - val failableStore = new BlockManager("failable-store", rpcEnv, master, serializer, - 10000, conf, mapOutputTracker, shuffleManager, failableTransfer, securityMgr, 0) + val memManager = new StaticMemoryManager(conf, Long.MaxValue, 10000, numCores = 1) + val failableStore = new BlockManager("failable-store", rpcEnv, master, serializer, conf, + memManager, mapOutputTracker, shuffleManager, failableTransfer, securityMgr, 0) + memManager.setMemoryStore(failableStore.memoryStore) failableStore.initialize("app-id") allStores += failableStore // so that this gets stopped after test assert(master.getPeers(store.blockManagerId).toSet === Set(failableStore.blockManagerId)) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index f480fd107a0c..53991d8a1aed 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -30,10 +30,11 @@ import org.scalatest._ import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts._ +import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.rpc.RpcEnv import org.apache.spark._ import org.apache.spark.executor.DataReadMethod -import org.apache.spark.network.nio.NioBlockTransferService +import org.apache.spark.memory.StaticMemoryManager import org.apache.spark.scheduler.LiveListenerBus import 
org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.shuffle.hash.HashShuffleManager @@ -44,9 +45,10 @@ import org.apache.spark.util._ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach with PrivateMethodTester with ResetSystemProperties { - private val conf = new SparkConf(false) + private val conf = new SparkConf(false).set("spark.app.id", "test") var store: BlockManager = null var store2: BlockManager = null + var store3: BlockManager = null var rpcEnv: RpcEnv = null var master: BlockManagerMaster = null conf.set("spark.authenticate", "false") @@ -65,11 +67,13 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE private def makeBlockManager( maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { - val transfer = new NioBlockTransferService(conf, securityMgr) - val manager = new BlockManager(name, rpcEnv, master, serializer, maxMem, conf, - mapOutputTracker, shuffleManager, transfer, securityMgr, 0) - manager.initialize("app-id") - manager + val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) + val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1) + val blockManager = new BlockManager(name, rpcEnv, master, serializer, conf, + memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) + memManager.setMemoryStore(blockManager.memoryStore) + blockManager.initialize("app-id") + blockManager } override def beforeEach(): Unit = { @@ -99,6 +103,10 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE store2.stop() store2 = null } + if (store3 != null) { + store3.stop() + store3 = null + } rpcEnv.shutdown() rpcEnv.awaitTermination() rpcEnv = null @@ -443,6 +451,38 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(list2DiskGet.get.readMethod === DataReadMethod.Disk) } + test("SPARK-9591: getRemoteBytes from another location when Exception throw") { + val origTimeoutOpt = conf.getOption("spark.network.timeout") + try { + conf.set("spark.network.timeout", "2s") + store = makeBlockManager(8000, "executor1") + store2 = makeBlockManager(8000, "executor2") + store3 = makeBlockManager(8000, "executor3") + val list1 = List(new Array[Byte](4000)) + store2.putIterator("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store3.putIterator("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + var list1Get = store.getRemoteBytes("list1") + assert(list1Get.isDefined, "list1Get expected to be fetched") + // block manager exit + store2.stop() + store2 = null + list1Get = store.getRemoteBytes("list1") + // get `list1` block + assert(list1Get.isDefined, "list1Get expected to be fetched") + store3.stop() + store3 = null + // exception throw because there is no locations + intercept[BlockFetchException] { + list1Get = store.getRemoteBytes("list1") + } + } finally { + origTimeoutOpt match { + case Some(t) => conf.set("spark.network.timeout", t) + case None => conf.remove("spark.network.timeout") + } + } + } + test("in-memory LRU storage") { store = makeBlockManager(12000) val a1 = new Array[Byte](4000) @@ -782,10 +822,16 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE test("block store put failure") { // Use Java serializer so we can create an unserializable error. 
- val transfer = new NioBlockTransferService(conf, securityMgr) + val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) + val memoryManager = new StaticMemoryManager( + conf, + maxOnHeapExecutionMemory = Long.MaxValue, + maxStorageMemory = 1200, + numCores = 1) store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master, - new JavaSerializer(conf), 1200, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, - 0) + new JavaSerializer(conf), conf, memoryManager, mapOutputTracker, + shuffleManager, transfer, securityMgr, 0) + memoryManager.setMemoryStore(store.memoryStore) // The put should fail since a1 is not serializable. class UnserializableClass @@ -796,7 +842,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE // Make sure get a1 doesn't hang and returns None. failAfter(1 second) { - assert(store.getSingle("a1") == None, "a1 should not be in store") + assert(store.getSingle("a1").isEmpty, "a1 should not be in store") } } @@ -1006,14 +1052,19 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(memoryStore.currentUnrollMemory === 0) assert(memoryStore.currentUnrollMemoryForThisTask === 0) + def reserveUnrollMemoryForThisTask(memory: Long): Boolean = { + memoryStore.reserveUnrollMemoryForThisTask( + TestBlockId(""), memory, new ArrayBuffer[(BlockId, BlockStatus)]) + } + // Reserve - memoryStore.reserveUnrollMemoryForThisTask(100) + assert(reserveUnrollMemoryForThisTask(100)) assert(memoryStore.currentUnrollMemoryForThisTask === 100) - memoryStore.reserveUnrollMemoryForThisTask(200) + assert(reserveUnrollMemoryForThisTask(200)) assert(memoryStore.currentUnrollMemoryForThisTask === 300) - memoryStore.reserveUnrollMemoryForThisTask(500) + assert(reserveUnrollMemoryForThisTask(500)) assert(memoryStore.currentUnrollMemoryForThisTask === 800) - memoryStore.reserveUnrollMemoryForThisTask(1000000) + assert(!reserveUnrollMemoryForThisTask(1000000)) assert(memoryStore.currentUnrollMemoryForThisTask === 800) // not granted // Release memoryStore.releaseUnrollMemoryForThisTask(100) @@ -1021,9 +1072,9 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE memoryStore.releaseUnrollMemoryForThisTask(100) assert(memoryStore.currentUnrollMemoryForThisTask === 600) // Reserve again - memoryStore.reserveUnrollMemoryForThisTask(4400) + assert(reserveUnrollMemoryForThisTask(4400)) assert(memoryStore.currentUnrollMemoryForThisTask === 5000) - memoryStore.reserveUnrollMemoryForThisTask(20000) + assert(!reserveUnrollMemoryForThisTask(20000)) assert(memoryStore.currentUnrollMemoryForThisTask === 5000) // not granted // Release again memoryStore.releaseUnrollMemoryForThisTask(1000) diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala index 66af6e1a7974..7c19531c1880 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala @@ -20,7 +20,6 @@ import java.io.File import org.scalatest.BeforeAndAfterEach -import org.apache.spark.SparkConf import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.serializer.JavaSerializer @@ -41,8 +40,8 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { test("verify write metrics") { val file = new File(tempDir, 
"somefile") val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, - new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + val writer = new DiskBlockObjectWriter( + file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) writer.write(Long.box(20), Long.box(30)) // Record metrics update on every write @@ -63,8 +62,8 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { test("verify write metrics on revert") { val file = new File(tempDir, "somefile") val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, - new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + val writer = new DiskBlockObjectWriter( + file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) writer.write(Long.box(20), Long.box(30)) // Record metrics update on every write @@ -86,8 +85,8 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { test("Reopening a closed block writer") { val file = new File(tempDir, "somefile") val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, - new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + val writer = new DiskBlockObjectWriter( + file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) writer.open() writer.close() @@ -99,8 +98,8 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { test("calling revertPartialWritesAndClose() on a closed block writer should have no effect") { val file = new File(tempDir, "somefile") val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, - new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + val writer = new DiskBlockObjectWriter( + file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) for (i <- 1 to 1000) { writer.write(i, i) } @@ -115,8 +114,8 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { test("commitAndClose() should be idempotent") { val file = new File(tempDir, "somefile") val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, - new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + val writer = new DiskBlockObjectWriter( + file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) for (i <- 1 to 1000) { writer.write(i, i) } @@ -133,8 +132,8 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { test("revertPartialWritesAndClose() should be idempotent") { val file = new File(tempDir, "somefile") val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, - new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + val writer = new DiskBlockObjectWriter( + file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) for (i <- 1 to 1000) { writer.write(i, i) } @@ -151,8 +150,8 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { test("fileSegment() can only be called after commitAndClose() has been called") { val 
file = new File(tempDir, "somefile") val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, - new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + val writer = new DiskBlockObjectWriter( + file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) for (i <- 1 to 1000) { writer.write(i, i) } @@ -165,8 +164,8 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { test("commitAndClose() without ever opening or writing") { val file = new File(tempDir, "somefile") val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, - new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + val writer = new DiskBlockObjectWriter( + file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) writer.commitAndClose() assert(writer.fileSegment().length === 0) } diff --git a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala index ac6fec56bbf4..cc50289c7b3e 100644 --- a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.util.Utils import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkConf, SparkFunSuite} - +import org.apache.spark.util.SparkConfWithEnv /** * Tests for the spark.local.dir and SPARK_LOCAL_DIRS configuration options. @@ -45,20 +45,10 @@ class LocalDirsSuite extends SparkFunSuite with BeforeAndAfter { test("SPARK_LOCAL_DIRS override also affects driver") { // Regression test for SPARK-2975 assert(!new File("/NONEXISTENT_DIR").exists()) - // SPARK_LOCAL_DIRS is a valid directory: - class MySparkConf extends SparkConf(false) { - override def getenv(name: String): String = { - if (name == "SPARK_LOCAL_DIRS") System.getProperty("java.io.tmpdir") - else super.getenv(name) - } - - override def clone: SparkConf = { - new MySparkConf().setAll(getAll) - } - } // spark.local.dir only contains invalid directories, but that's not a problem since // SPARK_LOCAL_DIRS will override it on both the driver and workers: - val conf = new MySparkConf().set("spark.local.dir", "/NONEXISTENT_PATH") + val conf = new SparkConfWithEnv(Map("SPARK_LOCAL_DIRS" -> System.getProperty("java.io.tmpdir"))) + .set("spark.local.dir", "/NONEXISTENT_PATH") assert(new File(Utils.getLocalDir(conf)).exists()) } diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index cf8bd8ae6962..828153bdbfc4 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -29,7 +29,7 @@ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.PrivateMethodTester -import org.apache.spark.{SparkFunSuite, TaskContextImpl} +import org.apache.spark.{SparkFunSuite, TaskContext} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.shuffle.BlockFetchingListener @@ -95,7 +95,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT ) val iterator = new ShuffleBlockFetcherIterator( 
- new TaskContextImpl(0, 0, 0, 0, null, null), + TaskContext.empty(), transfer, blockManager, blocksByAddress, @@ -165,7 +165,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) - val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null) + val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( taskContext, transfer, @@ -227,7 +227,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) - val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null) + val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( taskContext, transfer, diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala new file mode 100644 index 000000000000..86699e7f5695 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ui + +import javax.servlet.http.HttpServletRequest + +import scala.xml.Node + +import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS} + +import org.apache.spark._ +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.scheduler._ +import org.apache.spark.ui.jobs.{JobProgressListener, StagePage, StagesTab} +import org.apache.spark.ui.scope.RDDOperationGraphListener + +class StagePageSuite extends SparkFunSuite with LocalSparkContext { + + test("peak execution memory only displayed if unsafe is enabled") { + val unsafeConf = "spark.sql.unsafe.enabled" + val conf = new SparkConf(false).set(unsafeConf, "true") + val html = renderStagePage(conf).toString().toLowerCase + val targetString = "peak execution memory" + assert(html.contains(targetString)) + // Disable unsafe and make sure it's not there + val conf2 = new SparkConf(false).set(unsafeConf, "false") + val html2 = renderStagePage(conf2).toString().toLowerCase + assert(!html2.contains(targetString)) + // Avoid setting anything; it should be displayed by default + val conf3 = new SparkConf(false) + val html3 = renderStagePage(conf3).toString().toLowerCase + assert(html3.contains(targetString)) + } + + test("SPARK-10543: peak execution memory should be per-task rather than cumulative") { + val unsafeConf = "spark.sql.unsafe.enabled" + val conf = new SparkConf(false).set(unsafeConf, "true") + val html = renderStagePage(conf).toString().toLowerCase + // verify min/25/50/75/max show task value not cumulative values + assert(html.contains("10.0 b" * 5)) + } + + /** + * Render a stage page started with the given conf and return the HTML. + * This also runs a dummy stage to populate the page with useful content. + */ + private def renderStagePage(conf: SparkConf): Seq[Node] = { + val jobListener = new JobProgressListener(conf) + val graphListener = new RDDOperationGraphListener(conf) + val tab = mock(classOf[StagesTab], RETURNS_SMART_NULLS) + val request = mock(classOf[HttpServletRequest]) + when(tab.conf).thenReturn(conf) + when(tab.progressListener).thenReturn(jobListener) + when(tab.operationGraphListener).thenReturn(graphListener) + when(tab.appName).thenReturn("testing") + when(tab.headerTabs).thenReturn(Seq.empty) + when(request.getParameter("id")).thenReturn("0") + when(request.getParameter("attempt")).thenReturn("0") + val page = new StagePage(tab) + + // Simulate a stage in job progress listener + val stageInfo = new StageInfo(0, 0, "dummy", 1, Seq.empty, Seq.empty, "details") + // Simulate two tasks to test PEAK_EXECUTION_MEMORY correctness + (1 to 2).foreach { + taskId => + val taskInfo = new TaskInfo(taskId, taskId, 0, 0, "0", "localhost", TaskLocality.ANY, false) + val peakExecutionMemory = 10 + taskInfo.accumulables += new AccumulableInfo(0, InternalAccumulator.PEAK_EXECUTION_MEMORY, + Some(peakExecutionMemory.toString), (peakExecutionMemory * taskId).toString, true) + jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo)) + jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo)) + taskInfo.markSuccessful() + jobListener.onTaskEnd( + SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, TaskMetrics.empty)) + } + jobListener.onStageCompleted(SparkListenerStageCompleted(stageInfo)) + page.render(request) + } + +} diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index 3aa672f8b713..ceecfd665bf8 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala 
+++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.ui import java.net.{HttpURLConnection, URL} import javax.servlet.http.{HttpServletResponse, HttpServletRequest} -import scala.collection.JavaConversions._ +import scala.io.Source import scala.xml.Node import com.gargoylesoftware.htmlunit.DefaultCssErrorHandler @@ -340,15 +340,15 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B // The completed jobs table should have two rows. The first row will be the most recent job: val firstRow = find(cssSelector("tbody tr")).get.underlying val firstRowColumns = firstRow.findElements(By.tagName("td")) - firstRowColumns(0).getText should be ("1") - firstRowColumns(4).getText should be ("1/1 (2 skipped)") - firstRowColumns(5).getText should be ("8/8 (16 skipped)") + firstRowColumns.get(0).getText should be ("1") + firstRowColumns.get(4).getText should be ("1/1 (2 skipped)") + firstRowColumns.get(5).getText should be ("8/8 (16 skipped)") // The second row is the first run of the job, where nothing was skipped: val secondRow = findAll(cssSelector("tbody tr")).toSeq(1).underlying val secondRowColumns = secondRow.findElements(By.tagName("td")) - secondRowColumns(0).getText should be ("0") - secondRowColumns(4).getText should be ("3/3") - secondRowColumns(5).getText should be ("24/24") + secondRowColumns.get(0).getText should be ("0") + secondRowColumns.get(4).getText should be ("3/3") + secondRowColumns.get(5).getText should be ("24/24") } } } @@ -501,8 +501,8 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B for { (row, idx) <- rows.zipWithIndex columns = row.findElements(By.tagName("td")) - id = columns(0).getText() - description = columns(1).getText() + id = columns.get(0).getText() + description = columns.get(1).getText() } { id should be (expJobInfo(idx)._1) description should include (expJobInfo(idx)._2) @@ -546,8 +546,8 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B for { (row, idx) <- rows.zipWithIndex columns = row.findElements(By.tagName("td")) - id = columns(0).getText() - description = columns(1).getText() + id = columns.get(0).getText() + description = columns.get(1).getText() } { id should be (expStageInfo(idx)._1) description should include (expStageInfo(idx)._2) @@ -603,6 +603,44 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B } } + test("job stages should have expected dotfile under DAG visualization") { + withSpark(newSparkContext()) { sc => + // Create a multi-stage job + val rdd = + sc.parallelize(Seq(1, 2, 3)).map(identity).groupBy(identity).map(identity).groupBy(identity) + rdd.count() + + val stage0 = Source.fromURL(sc.ui.get.appUIAddress + + "/stages/stage/?id=0&attempt=0&expandDagViz=true").mkString + assert(stage0.contains("digraph G {\n subgraph clusterstage_0 {\n " + + "label="Stage 0";\n subgraph ")) + assert(stage0.contains("{\n label="parallelize";\n " + + "0 [label="ParallelCollectionRDD [0]")) + assert(stage0.contains("{\n label="map";\n " + + "1 [label="MapPartitionsRDD [1]")) + assert(stage0.contains("{\n label="groupBy";\n " + + "2 [label="MapPartitionsRDD [2]")) + + val stage1 = Source.fromURL(sc.ui.get.appUIAddress + + "/stages/stage/?id=1&attempt=0&expandDagViz=true").mkString + assert(stage1.contains("digraph G {\n subgraph clusterstage_1 {\n " + + "label="Stage 1";\n subgraph ")) + assert(stage1.contains("{\n label="groupBy";\n " + + "3 [label="ShuffledRDD [3]")) + 
assert(stage1.contains("{\n label="map";\n " + + "4 [label="MapPartitionsRDD [4]")) + assert(stage1.contains("{\n label="groupBy";\n " + + "5 [label="MapPartitionsRDD [5]")) + + val stage2 = Source.fromURL(sc.ui.get.appUIAddress + + "/stages/stage/?id=2&attempt=0&expandDagViz=true").mkString + assert(stage2.contains("digraph G {\n subgraph clusterstage_2 {\n " + + "label="Stage 2";\n subgraph ")) + assert(stage2.contains("{\n label="groupBy";\n " + + "6 [label="ShuffledRDD [6]")) + } + } + def goToUi(sc: SparkContext, path: String): Unit = { goToUi(sc.ui.get, path) } @@ -620,6 +658,6 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B } def apiUrl(ui: SparkUI, path: String): URL = { - new URL(ui.appUIAddress + "/api/v1/applications/test/" + path) + new URL(ui.appUIAddress + "/api/v1/applications/" + ui.sc.get.applicationId + "/" + path) } } diff --git a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala new file mode 100644 index 000000000000..dd8d5ec27f87 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui + +import scala.xml.Elem + +import org.apache.spark.SparkFunSuite + +class UIUtilsSuite extends SparkFunSuite { + import UIUtils._ + + test("makeDescription") { + verify( + """test
    text """, + test text , + "Correctly formatted text with only anchors and relative links should generate HTML" + ) + + verify( + """test """, + {"""test """}, + "Badly formatted text should make the description be treated as a streaming instead of HTML" + ) + + verify( + """test text """, + {"""test text """}, + "Non-relative links should make the description be treated as a string instead of HTML" + ) + + verify( + """test""", + {"""test"""}, + "Non-anchor elements should make the description be treated as a string instead of HTML" + ) + + verify( + """test text """, + test text , + baseUrl = "base", + errorMsg = "Base URL should be prepended to html links" + ) + } + + test("SPARK-11906: Progress bar should not overflow because of speculative tasks") { + val generated = makeProgressBar(2, 3, 0, 0, 4).head.child.filter(_.label == "div") + val expected = Seq( +
    <div class="bar bar-completed" style="width: 75.0%"></div>, + <div class="bar bar-running" style="width: 25.0%"></div>
    + ) + assert(generated.sameElements(expected), + s"\nRunning progress bar should round down\n\nExpected:\n$expected\nGenerated:\n$generated") + } + + private def verify( + desc: String, expected: Elem, errorMsg: String = "", baseUrl: String = ""): Unit = { + val generated = makeDescription(desc, baseUrl) + assert(generated.sameElements(expected), + s"\n$errorMsg\n\nExpected:\n$expected\nGenerated:\n$generated") + } +} diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 56f7b9cf1f35..e02f5a1b20fe 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -240,10 +240,10 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with val taskFailedReasons = Seq( Resubmitted, new FetchFailed(null, 0, 0, 0, "ignored"), - ExceptionFailure("Exception", "description", null, null, None), + ExceptionFailure("Exception", "description", null, null, None, None), TaskResultLost, TaskKilled, - ExecutorLostFailure("0"), + ExecutorLostFailure("0", true, Some("Induced failure")), UnknownReason) var failCount = 0 for (reason <- taskFailedReasons) { diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala index 61601016e005..0af4b6098bb0 100644 --- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala @@ -340,10 +340,11 @@ class AkkaUtilsSuite extends SparkFunSuite with LocalSparkContext with ResetSyst new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) val slaveConf = sparkSSLConfig() + .set("spark.rpc.askTimeout", "5s") + .set("spark.rpc.lookupTimeout", "5s") val securityManagerBad = new SecurityManager(slaveConf) val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, slaveConf, securityManagerBad) - val slaveTracker = new MapOutputTrackerWorker(conf) try { slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) fail("should receive either ActorNotFound or TimeoutException") diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index dde95f377843..1939ce5c743b 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -151,7 +151,8 @@ class JsonProtocolSuite extends SparkFunSuite { testTaskEndReason(exceptionFailure) testTaskEndReason(TaskResultLost) testTaskEndReason(TaskKilled) - testTaskEndReason(ExecutorLostFailure("100")) + testTaskEndReason(TaskCommitDenied(2, 3, 4)) + testTaskEndReason(ExecutorLostFailure("100", true, Some("Induced failure"))) testTaskEndReason(UnknownReason) // BlockId @@ -162,8 +163,13 @@ class JsonProtocolSuite extends SparkFunSuite { testBlockId(StreamBlockId(1, 2L)) } + /* ============================== * + | Backward compatibility tests | + * ============================== */ + test("ExceptionFailure backward compatibility") { - val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, null, None) + val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, null, + None, None) val oldEvent = JsonProtocol.taskEndReasonToJson(exceptionFailure) .removeField({ _._1 == "Full Stack Trace" }) 
assertEquals(exceptionFailure, JsonProtocol.taskEndReasonFromJson(oldEvent)) @@ -294,10 +300,10 @@ class JsonProtocolSuite extends SparkFunSuite { test("ExecutorLostFailure backward compatibility") { // ExecutorLostFailure in Spark 1.1.0 does not have an "Executor ID" property. - val executorLostFailure = ExecutorLostFailure("100") + val executorLostFailure = ExecutorLostFailure("100", true, Some("Induced failure")) val oldEvent = JsonProtocol.taskEndReasonToJson(executorLostFailure) .removeField({ _._1 == "Executor ID" }) - val expectedExecutorLostFailure = ExecutorLostFailure("Unknown") + val expectedExecutorLostFailure = ExecutorLostFailure("Unknown", true, Some("Induced failure")) assert(expectedExecutorLostFailure === JsonProtocol.taskEndReasonFromJson(oldEvent)) } @@ -332,14 +338,17 @@ class JsonProtocolSuite extends SparkFunSuite { assertEquals(expectedJobEnd, JsonProtocol.jobEndFromJson(oldEndEvent)) } - test("RDDInfo backward compatibility (scope, parent IDs)") { - // Prior to Spark 1.4.0, RDDInfo did not have the "Scope" and "Parent IDs" properties - val rddInfo = new RDDInfo( - 1, "one", 100, StorageLevel.NONE, Seq(1, 6, 8), Some(new RDDOperationScope("fable"))) + test("RDDInfo backward compatibility (scope, parent IDs, callsite)") { + // "Scope" and "Parent IDs" were introduced in Spark 1.4.0 + // "Callsite" was introduced in Spark 1.6.0 + val rddInfo = new RDDInfo(1, "one", 100, StorageLevel.NONE, Seq(1, 6, 8), + "callsite", Some(new RDDOperationScope("fable"))) val oldRddInfoJson = JsonProtocol.rddInfoToJson(rddInfo) .removeField({ _._1 == "Parent IDs"}) .removeField({ _._1 == "Scope"}) - val expectedRddInfo = new RDDInfo(1, "one", 100, StorageLevel.NONE, Seq.empty, scope = None) + .removeField({ _._1 == "Callsite"}) + val expectedRddInfo = new RDDInfo( + 1, "one", 100, StorageLevel.NONE, Seq.empty, "", scope = None) assertEquals(expectedRddInfo, JsonProtocol.rddInfoFromJson(oldRddInfoJson)) } @@ -351,6 +360,26 @@ class JsonProtocolSuite extends SparkFunSuite { assertEquals(expectedStageInfo, JsonProtocol.stageInfoFromJson(oldStageInfo)) } + // `TaskCommitDenied` was added in 1.3.0 but JSON de/serialization logic was added in 1.5.1 + test("TaskCommitDenied backward compatibility") { + val denied = TaskCommitDenied(1, 2, 3) + val oldDenied = JsonProtocol.taskEndReasonToJson(denied) + .removeField({ _._1 == "Job ID" }) + .removeField({ _._1 == "Partition ID" }) + .removeField({ _._1 == "Attempt Number" }) + val expectedDenied = TaskCommitDenied(-1, -1, -1) + assertEquals(expectedDenied, JsonProtocol.taskEndReasonFromJson(oldDenied)) + } + + test("AccumulableInfo backward compatibility") { + // "Internal" property of AccumulableInfo were added after 1.5.1. 
+ val accumulableInfo = makeAccumulableInfo(1) + val oldJson = JsonProtocol.accumulableInfoToJson(accumulableInfo) + .removeField({ _._1 == "Internal" }) + val oldInfo = JsonProtocol.accumulableInfoFromJson(oldJson) + assert(false === oldInfo.internal) + } + /** -------------------------- * | Helper test running methods | * --------------------------- */ @@ -498,7 +527,7 @@ class JsonProtocolSuite extends SparkFunSuite { private def assertEquals(info1: TaskInfo, info2: TaskInfo) { assert(info1.taskId === info2.taskId) assert(info1.index === info2.index) - assert(info1.attempt === info2.attempt) + assert(info1.attemptNumber === info2.attemptNumber) assert(info1.launchTime === info2.launchTime) assert(info1.executorId === info2.executorId) assert(info1.host === info2.host) @@ -576,8 +605,16 @@ class JsonProtocolSuite extends SparkFunSuite { assertOptionEquals(r1.metrics, r2.metrics, assertTaskMetricsEquals) case (TaskResultLost, TaskResultLost) => case (TaskKilled, TaskKilled) => - case (ExecutorLostFailure(execId1), ExecutorLostFailure(execId2)) => + case (TaskCommitDenied(jobId1, partitionId1, attemptNumber1), + TaskCommitDenied(jobId2, partitionId2, attemptNumber2)) => + assert(jobId1 === jobId2) + assert(partitionId1 === partitionId2) + assert(attemptNumber1 === attemptNumber2) + case (ExecutorLostFailure(execId1, exit1CausedByApp, reason1), + ExecutorLostFailure(execId2, exit2CausedByApp, reason2)) => assert(execId1 === execId2) + assert(exit1CausedByApp === exit2CausedByApp) + assert(reason1 === reason2) case (UnknownReason, UnknownReason) => case _ => fail("Task end reasons don't match in types!") } @@ -683,7 +720,7 @@ class JsonProtocolSuite extends SparkFunSuite { } private def makeRddInfo(a: Int, b: Int, c: Int, d: Long, e: Long) = { - val r = new RDDInfo(a, "mayor", b, StorageLevel.MEMORY_AND_DISK, Seq(1, 4, 7)) + val r = new RDDInfo(a, "mayor", b, StorageLevel.MEMORY_AND_DISK, Seq(1, 4, 7), a.toString) r.numCachedPartitions = c r.memSize = d r.diskSize = e @@ -703,15 +740,15 @@ class JsonProtocolSuite extends SparkFunSuite { val taskInfo = new TaskInfo(a, b, c, d, "executor", "your kind sir", TaskLocality.NODE_LOCAL, speculative) val (acc1, acc2, acc3) = - (makeAccumulableInfo(1), makeAccumulableInfo(2), makeAccumulableInfo(3)) + (makeAccumulableInfo(1), makeAccumulableInfo(2), makeAccumulableInfo(3, internal = true)) taskInfo.accumulables += acc1 taskInfo.accumulables += acc2 taskInfo.accumulables += acc3 taskInfo } - private def makeAccumulableInfo(id: Int): AccumulableInfo = - AccumulableInfo(id, " Accumulable " + id, Some("delta" + id), "val" + id) + private def makeAccumulableInfo(id: Int, internal: Boolean = false): AccumulableInfo = + AccumulableInfo(id, " Accumulable " + id, Some("delta" + id), "val" + id, internal) /** * Creates a TaskMetrics object describing a task that read data from Hadoop (if hasHadoopInput is @@ -792,13 +829,15 @@ class JsonProtocolSuite extends SparkFunSuite { | "ID": 2, | "Name": "Accumulable2", | "Update": "delta2", - | "Value": "val2" + | "Value": "val2", + | "Internal": false | }, | { | "ID": 1, | "Name": "Accumulable1", | "Update": "delta1", - | "Value": "val1" + | "Value": "val1", + | "Internal": false | } | ] | }, @@ -824,6 +863,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 101, | "Name": "mayor", + | "Callsite": "101", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -846,13 +886,15 @@ class JsonProtocolSuite extends SparkFunSuite { | "ID": 2, | "Name": "Accumulable2", | "Update": "delta2", - | 
"Value": "val2" + | "Value": "val2", + | "Internal": false | }, | { | "ID": 1, | "Name": "Accumulable1", | "Update": "delta1", - | "Value": "val1" + | "Value": "val1", + | "Internal": false | } | ] | } @@ -882,19 +924,22 @@ class JsonProtocolSuite extends SparkFunSuite { | "ID": 1, | "Name": "Accumulable1", | "Update": "delta1", - | "Value": "val1" + | "Value": "val1", + | "Internal": false | }, | { | "ID": 2, | "Name": "Accumulable2", | "Update": "delta2", - | "Value": "val2" + | "Value": "val2", + | "Internal": false | }, | { | "ID": 3, | "Name": "Accumulable3", | "Update": "delta3", - | "Value": "val3" + | "Value": "val3", + | "Internal": true | } | ] | } @@ -922,19 +967,22 @@ class JsonProtocolSuite extends SparkFunSuite { | "ID": 1, | "Name": "Accumulable1", | "Update": "delta1", - | "Value": "val1" + | "Value": "val1", + | "Internal": false | }, | { | "ID": 2, | "Name": "Accumulable2", | "Update": "delta2", - | "Value": "val2" + | "Value": "val2", + | "Internal": false | }, | { | "ID": 3, | "Name": "Accumulable3", | "Update": "delta3", - | "Value": "val3" + | "Value": "val3", + | "Internal": true | } | ] | } @@ -968,19 +1016,22 @@ class JsonProtocolSuite extends SparkFunSuite { | "ID": 1, | "Name": "Accumulable1", | "Update": "delta1", - | "Value": "val1" + | "Value": "val1", + | "Internal": false | }, | { | "ID": 2, | "Name": "Accumulable2", | "Update": "delta2", - | "Value": "val2" + | "Value": "val2", + | "Internal": false | }, | { | "ID": 3, | "Name": "Accumulable3", | "Update": "delta3", - | "Value": "val3" + | "Value": "val3", + | "Internal": true | } | ] | }, @@ -1054,19 +1105,22 @@ class JsonProtocolSuite extends SparkFunSuite { | "ID": 1, | "Name": "Accumulable1", | "Update": "delta1", - | "Value": "val1" + | "Value": "val1", + | "Internal": false | }, | { | "ID": 2, | "Name": "Accumulable2", | "Update": "delta2", - | "Value": "val2" + | "Value": "val2", + | "Internal": false | }, | { | "ID": 3, | "Name": "Accumulable3", | "Update": "delta3", - | "Value": "val3" + | "Value": "val3", + | "Internal": true | } | ] | }, @@ -1137,19 +1191,22 @@ class JsonProtocolSuite extends SparkFunSuite { | "ID": 1, | "Name": "Accumulable1", | "Update": "delta1", - | "Value": "val1" + | "Value": "val1", + | "Internal": false | }, | { | "ID": 2, | "Name": "Accumulable2", | "Update": "delta2", - | "Value": "val2" + | "Value": "val2", + | "Internal": false | }, | { | "ID": 3, | "Name": "Accumulable3", | "Update": "delta3", - | "Value": "val3" + | "Value": "val3", + | "Internal": true | } | ] | }, @@ -1209,6 +1266,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 1, | "Name": "mayor", + | "Callsite": "1", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1231,13 +1289,15 @@ class JsonProtocolSuite extends SparkFunSuite { | "ID": 2, | "Name": " Accumulable 2", | "Update": "delta2", - | "Value": "val2" + | "Value": "val2", + | "Internal": false | }, | { | "ID": 1, | "Name": " Accumulable 1", | "Update": "delta1", - | "Value": "val1" + | "Value": "val1", + | "Internal": false | } | ] | }, @@ -1250,6 +1310,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 2, | "Name": "mayor", + | "Callsite": "2", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1267,6 +1328,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 3, | "Name": "mayor", + | "Callsite": "3", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1289,13 +1351,15 @@ class JsonProtocolSuite extends SparkFunSuite { | "ID": 2, | 
"Name": " Accumulable 2", | "Update": "delta2", - | "Value": "val2" + | "Value": "val2", + | "Internal": false | }, | { | "ID": 1, | "Name": " Accumulable 1", | "Update": "delta1", - | "Value": "val1" + | "Value": "val1", + | "Internal": false | } | ] | }, @@ -1308,6 +1372,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 3, | "Name": "mayor", + | "Callsite": "3", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1325,6 +1390,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 4, | "Name": "mayor", + | "Callsite": "4", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1342,6 +1408,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 5, | "Name": "mayor", + | "Callsite": "5", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1364,13 +1431,15 @@ class JsonProtocolSuite extends SparkFunSuite { | "ID": 2, | "Name": " Accumulable 2", | "Update": "delta2", - | "Value": "val2" + | "Value": "val2", + | "Internal": false | }, | { | "ID": 1, | "Name": " Accumulable 1", | "Update": "delta1", - | "Value": "val1" + | "Value": "val1", + | "Internal": false | } | ] | }, @@ -1383,6 +1452,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 4, | "Name": "mayor", + | "Callsite": "4", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1400,6 +1470,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 5, | "Name": "mayor", + | "Callsite": "5", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1417,6 +1488,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 6, | "Name": "mayor", + | "Callsite": "6", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1434,6 +1506,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 7, | "Name": "mayor", + | "Callsite": "7", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1456,13 +1529,15 @@ class JsonProtocolSuite extends SparkFunSuite { | "ID": 2, | "Name": " Accumulable 2", | "Update": "delta2", - | "Value": "val2" + | "Value": "val2", + | "Internal": false | }, | { | "ID": 1, | "Name": " Accumulable 1", | "Update": "delta1", - | "Value": "val1" + | "Value": "val1", + | "Internal": false | } | ] | } diff --git a/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala b/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala index d3d464e84ffd..8b53d4f14a6a 100644 --- a/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala @@ -19,9 +19,14 @@ package org.apache.spark.util import java.net.URLClassLoader +import scala.collection.JavaConverters._ + +import org.scalatest.Matchers +import org.scalatest.Matchers._ + import org.apache.spark.{SparkContext, SparkException, SparkFunSuite, TestUtils} -class MutableURLClassLoaderSuite extends SparkFunSuite { +class MutableURLClassLoaderSuite extends SparkFunSuite with Matchers { val urls2 = List(TestUtils.createJarWithClasses( classNames = Seq("FakeClass1", "FakeClass2", "FakeClass3"), @@ -32,6 +37,12 @@ class MutableURLClassLoaderSuite extends SparkFunSuite { toStringValue = "1", classpathUrls = urls2)).toArray + val fileUrlsChild = List(TestUtils.createJarWithFiles(Map( + "resource1" -> "resource1Contents-child", + "resource2" -> "resource2Contents"))).toArray + val fileUrlsParent = List(TestUtils.createJarWithFiles(Map( + "resource1" 
-> "resource1Contents-parent"))).toArray + test("child first") { val parentLoader = new URLClassLoader(urls2, null) val classLoader = new ChildFirstURLClassLoader(urls, parentLoader) @@ -68,6 +79,33 @@ class MutableURLClassLoaderSuite extends SparkFunSuite { } } + test("default JDK classloader get resources") { + val parentLoader = new URLClassLoader(fileUrlsParent, null) + val classLoader = new URLClassLoader(fileUrlsChild, parentLoader) + assert(classLoader.getResources("resource1").asScala.size === 2) + assert(classLoader.getResources("resource2").asScala.size === 1) + } + + test("parent first get resources") { + val parentLoader = new URLClassLoader(fileUrlsParent, null) + val classLoader = new MutableURLClassLoader(fileUrlsChild, parentLoader) + assert(classLoader.getResources("resource1").asScala.size === 2) + assert(classLoader.getResources("resource2").asScala.size === 1) + } + + test("child first get resources") { + val parentLoader = new URLClassLoader(fileUrlsParent, null) + val classLoader = new ChildFirstURLClassLoader(fileUrlsChild, parentLoader) + + val res1 = classLoader.getResources("resource1").asScala.toList + assert(res1.size === 2) + assert(classLoader.getResources("resource2").asScala.size === 1) + + res1.map(scala.io.Source.fromURL(_).mkString) should contain inOrderOnly + ("resource1Contents-child", "resource1Contents-parent") + } + + test("driver sets context class loader in local mode") { // Test the case where the driver program sets a context classloader and then runs a job // in local mode. This is what happens when ./spark-submit is called with "local" as the diff --git a/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala b/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala index c58db5e606f7..60fb7abb66d3 100644 --- a/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala +++ b/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala @@ -45,7 +45,7 @@ private[spark] trait ResetSystemProperties extends BeforeAndAfterEach { this: Su var oldProperties: Properties = null override def beforeEach(): Unit = { - // we need SerializationUtils.clone instead of `new Properties(System.getProperties()` because + // we need SerializationUtils.clone instead of `new Properties(System.getProperties())` because // the later way of creating a copy does not copy the properties but it initializes a new // Properties object with the given properties as defaults. They are not recognized at all // by standard Scala wrapper over Java Properties then. diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala index 20550178fb1b..101610e38014 100644 --- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala @@ -60,6 +60,12 @@ class DummyString(val arr: Array[Char]) { @transient val hash32: Int = 0 } +class DummyClass8 extends KnownSizeEstimation { + val x: Int = 0 + + override def estimatedSize: Long = 2015 +} + class SizeEstimatorSuite extends SparkFunSuite with BeforeAndAfterEach @@ -214,4 +220,10 @@ class SizeEstimatorSuite // Class should be 32 bytes on s390x if recognised as 64 bit platform assertResult(32)(SizeEstimator.estimate(new DummyClass7)) } + + test("SizeEstimation can provide the estimated size") { + // DummyClass8 provides its size estimation. 
+ assertResult(2015)(SizeEstimator.estimate(new DummyClass8)) + assertResult(20206)(SizeEstimator.estimate(Array.fill(10)(new DummyClass8))) + } } diff --git a/core/src/test/scala/org/apache/spark/util/SparkConfWithEnv.scala b/core/src/test/scala/org/apache/spark/util/SparkConfWithEnv.scala new file mode 100644 index 000000000000..ddd5edf4f739 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/SparkConfWithEnv.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import org.apache.spark.SparkConf + +/** + * Customized SparkConf that allows env variables to be overridden. + */ +class SparkConfWithEnv(env: Map[String, String]) extends SparkConf(false) { + override def getenv(name: String): String = { + env.get(name).getOrElse(super.getenv(name)) + } + + override def clone: SparkConf = { + new SparkConfWithEnv(env).setAll(getAll) + } + +} diff --git a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala index 8c51e6b14b7f..92ae03896752 100644 --- a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala @@ -20,8 +20,11 @@ package org.apache.spark.util import java.util.concurrent.{CountDownLatch, TimeUnit} -import scala.concurrent.{Await, Future} import scala.concurrent.duration._ +import scala.concurrent.{Await, Future} +import scala.util.Random + +import org.scalatest.concurrent.Eventually._ import org.apache.spark.SparkFunSuite @@ -58,6 +61,49 @@ class ThreadUtilsSuite extends SparkFunSuite { } } + test("newDaemonCachedThreadPool") { + val maxThreadNumber = 10 + val startThreadsLatch = new CountDownLatch(maxThreadNumber) + val latch = new CountDownLatch(1) + val cachedThreadPool = ThreadUtils.newDaemonCachedThreadPool( + "ThreadUtilsSuite-newDaemonCachedThreadPool", + maxThreadNumber, + keepAliveSeconds = 2) + try { + for (_ <- 1 to maxThreadNumber) { + cachedThreadPool.execute(new Runnable { + override def run(): Unit = { + startThreadsLatch.countDown() + latch.await(10, TimeUnit.SECONDS) + } + }) + } + startThreadsLatch.await(10, TimeUnit.SECONDS) + assert(cachedThreadPool.getActiveCount === maxThreadNumber) + assert(cachedThreadPool.getQueue.size === 0) + + // Submit a new task and it should be put into the queue since the thread number reaches the + // limitation + cachedThreadPool.execute(new Runnable { + override def run(): Unit = { + latch.await(10, TimeUnit.SECONDS) + } + }) + + assert(cachedThreadPool.getActiveCount === maxThreadNumber) + assert(cachedThreadPool.getQueue.size === 1) + + latch.countDown() + eventually(timeout(10.seconds)) { + // All threads should be stopped after keepAliveSeconds + 
assert(cachedThreadPool.getActiveCount === 0) + assert(cachedThreadPool.getPoolSize === 0) + } + } finally { + cachedThreadPool.shutdownNow() + } + } + test("sameThread") { val callerThreadName = Thread.currentThread().getName() val f = Future { @@ -66,4 +112,25 @@ class ThreadUtilsSuite extends SparkFunSuite { val futureThreadName = Await.result(f, 10.seconds) assert(futureThreadName === callerThreadName) } + + test("runInNewThread") { + import ThreadUtils._ + assert(runInNewThread("thread-name") { Thread.currentThread().getName } === "thread-name") + assert(runInNewThread("thread-name") { Thread.currentThread().isDaemon } === true) + assert( + runInNewThread("thread-name", isDaemon = false) { Thread.currentThread().isDaemon } === false + ) + val uniqueExceptionMessage = "test" + Random.nextInt() + val exception = intercept[IllegalArgumentException] { + runInNewThread("thread-name") { throw new IllegalArgumentException(uniqueExceptionMessage) } + } + assert(exception.asInstanceOf[IllegalArgumentException].getMessage === uniqueExceptionMessage) + assert(exception.getStackTrace.mkString("\n").contains( + "... run in separate thread using org.apache.spark.util.ThreadUtils ...") === true, + "stack trace does not contain expected place holder" + ) + assert(exception.getStackTrace.mkString("\n").contains("ThreadUtils.scala") === false, + "stack trace contains unexpected references to ThreadUtils" + ) + } } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 8f7e402d5f2a..fdb51d440eff 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -384,7 +384,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { assertResolves("hdfs:/root/spark.jar", "hdfs:/root/spark.jar") assertResolves("hdfs:///root/spark.jar#app.jar", "hdfs:/root/spark.jar#app.jar") assertResolves("spark.jar", s"file:$cwd/spark.jar") - assertResolves("spark.jar#app.jar", s"file:$cwd/spark.jar%23app.jar") + assertResolves("spark.jar#app.jar", s"file:$cwd/spark.jar#app.jar") assertResolves("path to/file.txt", s"file:$cwd/path%20to/file.txt") if (Utils.isWindows) { assertResolves("C:\\path\\to\\file.txt", "file:/C:/path/to/file.txt") @@ -414,10 +414,10 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { assertResolves("file:/jar1,file:/jar2", "file:/jar1,file:/jar2") assertResolves("hdfs:/jar1,file:/jar2,jar3", s"hdfs:/jar1,file:/jar2,file:$cwd/jar3") assertResolves("hdfs:/jar1,file:/jar2,jar3,jar4#jar5,path to/jar6", - s"hdfs:/jar1,file:/jar2,file:$cwd/jar3,file:$cwd/jar4%23jar5,file:$cwd/path%20to/jar6") + s"hdfs:/jar1,file:/jar2,file:$cwd/jar3,file:$cwd/jar4#jar5,file:$cwd/path%20to/jar6") if (Utils.isWindows) { assertResolves("""hdfs:/jar1,file:/jar2,jar3,C:\pi.py#py.pi,C:\path to\jar4""", - s"hdfs:/jar1,file:/jar2,file:$cwd/jar3,file:/C:/pi.py%23py.pi,file:/C:/path%20to/jar4") + s"hdfs:/jar1,file:/jar2,file:$cwd/jar3,file:/C:/pi.py#py.pi,file:/C:/path%20to/jar4") } } @@ -720,4 +720,29 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { assert(Utils.nanSafeCompareFloats(Float.PositiveInfinity, Float.NaN) === -1) assert(Utils.nanSafeCompareFloats(Float.NegativeInfinity, Float.NaN) === -1) } + + test("isDynamicAllocationEnabled") { + val conf = new SparkConf() + assert(Utils.isDynamicAllocationEnabled(conf) === false) + assert(Utils.isDynamicAllocationEnabled( + 
conf.set("spark.dynamicAllocation.enabled", "false")) === false) + assert(Utils.isDynamicAllocationEnabled( + conf.set("spark.dynamicAllocation.enabled", "true")) === true) + assert(Utils.isDynamicAllocationEnabled( + conf.set("spark.executor.instances", "1")) === false) + assert(Utils.isDynamicAllocationEnabled( + conf.set("spark.executor.instances", "0")) === true) + } + + test("encodeFileNameToURIRawPath") { + assert(Utils.encodeFileNameToURIRawPath("abc") === "abc") + assert(Utils.encodeFileNameToURIRawPath("abc xyz") === "abc%20xyz") + assert(Utils.encodeFileNameToURIRawPath("abc:xyz") === "abc:xyz") + } + + test("decodeFileNameInURI") { + assert(Utils.decodeFileNameInURI(new URI("files:///abc/xyz")) === "xyz") + assert(Utils.decodeFileNameInURI(new URI("files:///abc")) === "abc") + assert(Utils.decodeFileNameInURI(new URI("files:///abc%20xyz")) === "abc xyz") + } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala deleted file mode 100644 index 05306f408847..000000000000 --- a/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala +++ /dev/null @@ -1,144 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.util.collection - -import java.nio.ByteBuffer - -import org.scalatest.Matchers._ - -import org.apache.spark.SparkFunSuite - -class ChainedBufferSuite extends SparkFunSuite { - test("write and read at start") { - // write from start of source array - val buffer = new ChainedBuffer(8) - buffer.capacity should be (0) - verifyWriteAndRead(buffer, 0, 0, 0, 4) - buffer.capacity should be (8) - - // write from middle of source array - verifyWriteAndRead(buffer, 0, 5, 0, 4) - buffer.capacity should be (8) - - // read to middle of target array - verifyWriteAndRead(buffer, 0, 0, 5, 4) - buffer.capacity should be (8) - - // write up to border - verifyWriteAndRead(buffer, 0, 0, 0, 8) - buffer.capacity should be (8) - - // expand into second buffer - verifyWriteAndRead(buffer, 0, 0, 0, 12) - buffer.capacity should be (16) - - // expand into multiple buffers - verifyWriteAndRead(buffer, 0, 0, 0, 28) - buffer.capacity should be (32) - } - - test("write and read at middle") { - val buffer = new ChainedBuffer(8) - - // fill to a middle point - verifyWriteAndRead(buffer, 0, 0, 0, 3) - - // write from start of source array - verifyWriteAndRead(buffer, 3, 0, 0, 4) - buffer.capacity should be (8) - - // write from middle of source array - verifyWriteAndRead(buffer, 3, 5, 0, 4) - buffer.capacity should be (8) - - // read to middle of target array - verifyWriteAndRead(buffer, 3, 0, 5, 4) - buffer.capacity should be (8) - - // write up to border - verifyWriteAndRead(buffer, 3, 0, 0, 5) - buffer.capacity should be (8) - - // expand into second buffer - verifyWriteAndRead(buffer, 3, 0, 0, 12) - buffer.capacity should be (16) - - // expand into multiple buffers - verifyWriteAndRead(buffer, 3, 0, 0, 28) - buffer.capacity should be (32) - } - - test("write and read at later buffer") { - val buffer = new ChainedBuffer(8) - - // fill to a middle point - verifyWriteAndRead(buffer, 0, 0, 0, 11) - - // write from start of source array - verifyWriteAndRead(buffer, 11, 0, 0, 4) - buffer.capacity should be (16) - - // write from middle of source array - verifyWriteAndRead(buffer, 11, 5, 0, 4) - buffer.capacity should be (16) - - // read to middle of target array - verifyWriteAndRead(buffer, 11, 0, 5, 4) - buffer.capacity should be (16) - - // write up to border - verifyWriteAndRead(buffer, 11, 0, 0, 5) - buffer.capacity should be (16) - - // expand into second buffer - verifyWriteAndRead(buffer, 11, 0, 0, 12) - buffer.capacity should be (24) - - // expand into multiple buffers - verifyWriteAndRead(buffer, 11, 0, 0, 28) - buffer.capacity should be (40) - } - - - // Used to make sure we're writing different bytes each time - var rangeStart = 0 - - /** - * @param buffer The buffer to write to and read from. - * @param offsetInBuffer The offset to write to in the buffer. - * @param offsetInSource The offset in the array that the bytes are written from. - * @param offsetInTarget The offset in the array to read the bytes into. 
- * @param length The number of bytes to read and write - */ - def verifyWriteAndRead( - buffer: ChainedBuffer, - offsetInBuffer: Int, - offsetInSource: Int, - offsetInTarget: Int, - length: Int): Unit = { - val source = new Array[Byte](offsetInSource + length) - (rangeStart until rangeStart + length).map(_.toByte).copyToArray(source, offsetInSource) - buffer.write(offsetInBuffer, source, offsetInSource, length) - val target = new Array[Byte](offsetInTarget + length) - buffer.read(offsetInBuffer, target, offsetInTarget, length) - ByteBuffer.wrap(source, offsetInSource, length) should be - (ByteBuffer.wrap(target, offsetInTarget, length)) - - rangeStart += 100 - } -} diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index 9c362f0de707..dc3185a6d505 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -21,16 +21,22 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark._ import org.apache.spark.io.CompressionCodec +import org.apache.spark.memory.MemoryTestingUtils class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { + import TestUtils.{assertNotSpilled, assertSpilled} + private val allCompressionCodecs = CompressionCodec.ALL_COMPRESSION_CODECS private def createCombiner[T](i: T) = ArrayBuffer[T](i) private def mergeValue[T](buffer: ArrayBuffer[T], i: T): ArrayBuffer[T] = buffer += i private def mergeCombiners[T](buf1: ArrayBuffer[T], buf2: ArrayBuffer[T]): ArrayBuffer[T] = buf1 ++= buf2 - private def createExternalMap[T] = new ExternalAppendOnlyMap[T, T, ArrayBuffer[T]]( - createCombiner[T], mergeValue[T], mergeCombiners[T]) + private def createExternalMap[T] = { + val context = MemoryTestingUtils.fakeTaskContext(sc.env) + new ExternalAppendOnlyMap[T, T, ArrayBuffer[T]]( + createCombiner[T], mergeValue[T], mergeCombiners[T], context = context) + } private def createSparkConf(loadDefaults: Boolean, codec: Option[String] = None): SparkConf = { val conf = new SparkConf(loadDefaults) @@ -46,23 +52,27 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { conf } - test("simple insert") { + test("single insert insert") { val conf = createSparkConf(loadDefaults = false) sc = new SparkContext("local", "test", conf) val map = createExternalMap[Int] - - // Single insert map.insert(1, 10) - var it = map.iterator + val it = map.iterator assert(it.hasNext) val kv = it.next() assert(kv._1 === 1 && kv._2 === ArrayBuffer[Int](10)) assert(!it.hasNext) + sc.stop() + } - // Multiple insert + test("multiple insert") { + val conf = createSparkConf(loadDefaults = false) + sc = new SparkContext("local", "test", conf) + val map = createExternalMap[Int] + map.insert(1, 10) map.insert(2, 20) map.insert(3, 30) - it = map.iterator + val it = map.iterator assert(it.hasNext) assert(it.toSet === Set[(Int, ArrayBuffer[Int])]( (1, ArrayBuffer[Int](10)), @@ -141,39 +151,22 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { sc = new SparkContext("local", "test", conf) val map = createExternalMap[Int] + val nullInt = null.asInstanceOf[Int] map.insert(1, 5) map.insert(2, 6) map.insert(3, 7) - assert(map.size === 3) - assert(map.iterator.toSet === Set[(Int, Seq[Int])]( - (1, Seq[Int](5)), - (2, Seq[Int](6)), - (3, Seq[Int](7)) - )) - - // Null keys - val 
nullInt = null.asInstanceOf[Int] + map.insert(4, nullInt) map.insert(nullInt, 8) - assert(map.size === 4) - assert(map.iterator.toSet === Set[(Int, Seq[Int])]( + map.insert(nullInt, nullInt) + val result = map.iterator.toSet[(Int, ArrayBuffer[Int])].map(kv => (kv._1, kv._2.sorted)) + assert(result === Set[(Int, Seq[Int])]( (1, Seq[Int](5)), (2, Seq[Int](6)), (3, Seq[Int](7)), - (nullInt, Seq[Int](8)) + (4, Seq[Int](nullInt)), + (nullInt, Seq[Int](nullInt, 8)) )) - // Null values - map.insert(4, nullInt) - map.insert(nullInt, nullInt) - assert(map.size === 5) - val result = map.iterator.toSet[(Int, ArrayBuffer[Int])].map(kv => (kv._1, kv._2.toSet)) - assert(result === Set[(Int, Set[Int])]( - (1, Set[Int](5)), - (2, Set[Int](6)), - (3, Set[Int](7)), - (4, Set[Int](nullInt)), - (nullInt, Set[Int](nullInt, 8)) - )) sc.stop() } @@ -242,56 +235,53 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { * If a compression codec is provided, use it. Otherwise, do not compress spills. */ private def testSimpleSpilling(codec: Option[String] = None): Unit = { + val size = 1000 val conf = createSparkConf(loadDefaults = true, codec) // Load defaults for Spark home - conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "hash") // avoid using external sorter + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 4).toString) sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) - // reduceByKey - should spill ~8 times - val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i)) - val resultA = rddA.reduceByKey(math.max).collect() - assert(resultA.length === 50000) - resultA.foreach { case (k, v) => - assert(v === k * 2 + 1, s"Value for $k was wrong: expected ${k * 2 + 1}, got $v") + assertSpilled(sc, "reduceByKey") { + val result = sc.parallelize(0 until size) + .map { i => (i / 2, i) }.reduceByKey(math.max).collect() + assert(result.length === size / 2) + result.foreach { case (k, v) => + val expected = k * 2 + 1 + assert(v === expected, s"Value for $k was wrong: expected $expected, got $v") + } } - // groupByKey - should spill ~17 times - val rddB = sc.parallelize(0 until 100000).map(i => (i/4, i)) - val resultB = rddB.groupByKey().collect() - assert(resultB.length === 25000) - resultB.foreach { case (i, seq) => - val expected = Set(i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3) - assert(seq.toSet === expected, - s"Value for $i was wrong: expected $expected, got ${seq.toSet}") + assertSpilled(sc, "groupByKey") { + val result = sc.parallelize(0 until size).map { i => (i / 2, i) }.groupByKey().collect() + assert(result.length == size / 2) + result.foreach { case (i, seq) => + val actual = seq.toSet + val expected = Set(i * 2, i * 2 + 1) + assert(actual === expected, s"Value for $i was wrong: expected $expected, got $actual") + } } - // cogroup - should spill ~7 times - val rddC1 = sc.parallelize(0 until 10000).map(i => (i, i)) - val rddC2 = sc.parallelize(0 until 10000).map(i => (i%1000, i)) - val resultC = rddC1.cogroup(rddC2).collect() - assert(resultC.length === 10000) - resultC.foreach { case (i, (seq1, seq2)) => - i match { - case 0 => - assert(seq1.toSet === Set[Int](0)) - assert(seq2.toSet === Set[Int](0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000)) - case 1 => - assert(seq1.toSet === Set[Int](1)) - assert(seq2.toSet === Set[Int](1, 1001, 2001, 3001, 4001, 5001, 6001, 7001, 8001, 9001)) - case 5000 => - assert(seq1.toSet === Set[Int](5000)) - assert(seq2.toSet === Set[Int]()) - case 9999 => - 
assert(seq1.toSet === Set[Int](9999)) - assert(seq2.toSet === Set[Int]()) - case _ => + assertSpilled(sc, "cogroup") { + val rdd1 = sc.parallelize(0 until size).map { i => (i / 2, i) } + val rdd2 = sc.parallelize(0 until size).map { i => (i / 2, i) } + val result = rdd1.cogroup(rdd2).collect() + assert(result.length === size / 2) + result.foreach { case (i, (seq1, seq2)) => + val actual1 = seq1.toSet + val actual2 = seq2.toSet + val expected = Set(i * 2, i * 2 + 1) + assert(actual1 === expected, s"Value 1 for $i was wrong: expected $expected, got $actual1") + assert(actual2 === expected, s"Value 2 for $i was wrong: expected $expected, got $actual2") } } + sc.stop() } test("spilling with hash collisions") { + val size = 1000 val conf = createSparkConf(loadDefaults = true) - conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) val map = createExternalMap[String] @@ -315,11 +305,12 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { assert(w1.hashCode === w2.hashCode) } - map.insertAll((1 to 100000).iterator.map(_.toString).map(i => (i, i))) + map.insertAll((1 to size).iterator.map(_.toString).map(i => (i, i))) collisionPairs.foreach { case (w1, w2) => map.insert(w1, w2) map.insert(w2, w1) } + assert(map.numSpills > 0, "map did not spill") // A map of collision pairs in both directions val collisionPairsMap = (collisionPairs ++ collisionPairs.map(_.swap)).toMap @@ -334,23 +325,27 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { assert(kv._2.equals(expectedValue)) count += 1 } - assert(count === 100000 + collisionPairs.size * 2) + assert(count === size + collisionPairs.size * 2) sc.stop() } test("spilling with many hash collisions") { + val size = 1000 val conf = createSparkConf(loadDefaults = true) - conf.set("spark.shuffle.memoryFraction", "0.0001") + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) - val map = new ExternalAppendOnlyMap[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _) + val context = MemoryTestingUtils.fakeTaskContext(sc.env) + val map = + new ExternalAppendOnlyMap[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _, context = context) // Insert 10 copies each of lots of objects whose hash codes are either 0 or 1. This causes // problems if the map fails to group together the objects with the same code (SPARK-2043). 
for (i <- 1 to 10) { - for (j <- 1 to 10000) { + for (j <- 1 to size) { map.insert(FixedHashObject(j, j % 2), 1) } } + assert(map.numSpills > 0, "map did not spill") val it = map.iterator var count = 0 @@ -359,18 +354,20 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { assert(kv._2 === 10) count += 1 } - assert(count === 10000) + assert(count === size) sc.stop() } test("spilling with hash collisions using the Int.MaxValue key") { + val size = 1000 val conf = createSparkConf(loadDefaults = true) - conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) val map = createExternalMap[Int] - (1 to 100000).foreach { i => map.insert(i, i) } + (1 to size).foreach { i => map.insert(i, i) } map.insert(Int.MaxValue, Int.MaxValue) + assert(map.numSpills > 0, "map did not spill") val it = map.iterator while (it.hasNext) { @@ -381,15 +378,17 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { } test("spilling with null keys and values") { + val size = 1000 val conf = createSparkConf(loadDefaults = true) - conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) val map = createExternalMap[Int] - map.insertAll((1 to 100000).iterator.map(i => (i, i))) + map.insertAll((1 to size).iterator.map(i => (i, i))) map.insert(null.asInstanceOf[Int], 1) map.insert(1, null.asInstanceOf[Int]) map.insert(null.asInstanceOf[Int], null.asInstanceOf[Int]) + assert(map.numSpills > 0, "map did not spill") val it = map.iterator while (it.hasNext) { @@ -399,4 +398,24 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { sc.stop() } + test("external aggregation updates peak execution memory") { + val spillThreshold = 1000 + val conf = createSparkConf(loadDefaults = false) + .set("spark.shuffle.manager", "hash") // make sure we're not also using ExternalSorter + .set("spark.shuffle.spill.numElementsForceSpillThreshold", spillThreshold.toString) + sc = new SparkContext("local", "test", conf) + // No spilling + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external map without spilling") { + assertNotSpilled(sc, "verify peak memory") { + sc.parallelize(1 to spillThreshold / 2, 2).map { i => (i, i) }.reduceByKey(_ + _).count() + } + } + // With spilling + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external map with spilling") { + assertSpilled(sc, "verify peak memory") { + sc.parallelize(1 to spillThreshold * 3, 2).map { i => (i, i) }.reduceByKey(_ + _).count() + } + } + } + } diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 986cd8623d14..d7b2d07a4005 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -17,555 +17,102 @@ package org.apache.spark.util.collection -import scala.collection.mutable.ArrayBuffer +import org.apache.spark.memory.MemoryTestingUtils +import scala.collection.mutable.ArrayBuffer import scala.util.Random import org.apache.spark._ import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} -class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { - 
private def createSparkConf(loadDefaults: Boolean, kryo: Boolean): SparkConf = { - val conf = new SparkConf(loadDefaults) - if (kryo) { - conf.set("spark.serializer", classOf[KryoSerializer].getName) - } else { - // Make the Java serializer write a reset instruction (TC_RESET) after each object to test - // for a bug we had with bytes written past the last object in a batch (SPARK-2792) - conf.set("spark.serializer.objectStreamReset", "1") - conf.set("spark.serializer", classOf[JavaSerializer].getName) - } - conf.set("spark.shuffle.sort.bypassMergeThreshold", "0") - // Ensure that we actually have multiple batches per spill file - conf.set("spark.shuffle.spill.batchSize", "10") - conf - } - - test("empty data stream with kryo ser") { - emptyDataStream(createSparkConf(false, true)) - } - - test("empty data stream with java ser") { - emptyDataStream(createSparkConf(false, false)) - } - - def emptyDataStream(conf: SparkConf) { - conf.set("spark.shuffle.memoryFraction", "0.001") - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local", "test", conf) - - val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) - val ord = implicitly[Ordering[Int]] - - // Both aggregator and ordering - val sorter = new ExternalSorter[Int, Int, Int]( - Some(agg), Some(new HashPartitioner(3)), Some(ord), None) - assert(sorter.iterator.toSeq === Seq()) - sorter.stop() - - // Only aggregator - val sorter2 = new ExternalSorter[Int, Int, Int]( - Some(agg), Some(new HashPartitioner(3)), None, None) - assert(sorter2.iterator.toSeq === Seq()) - sorter2.stop() - - // Only ordering - val sorter3 = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(3)), Some(ord), None) - assert(sorter3.iterator.toSeq === Seq()) - sorter3.stop() - - // Neither aggregator nor ordering - val sorter4 = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(3)), None, None) - assert(sorter4.iterator.toSeq === Seq()) - sorter4.stop() - } - - test("few elements per partition with kryo ser") { - fewElementsPerPartition(createSparkConf(false, true)) - } - - test("few elements per partition with java ser") { - fewElementsPerPartition(createSparkConf(false, false)) - } - - def fewElementsPerPartition(conf: SparkConf) { - conf.set("spark.shuffle.memoryFraction", "0.001") - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local", "test", conf) - - val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) - val ord = implicitly[Ordering[Int]] - val elements = Set((1, 1), (2, 2), (5, 5)) - val expected = Set( - (0, Set()), (1, Set((1, 1))), (2, Set((2, 2))), (3, Set()), (4, Set()), - (5, Set((5, 5))), (6, Set())) - - // Both aggregator and ordering - val sorter = new ExternalSorter[Int, Int, Int]( - Some(agg), Some(new HashPartitioner(7)), Some(ord), None) - sorter.insertAll(elements.iterator) - assert(sorter.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) - sorter.stop() - - // Only aggregator - val sorter2 = new ExternalSorter[Int, Int, Int]( - Some(agg), Some(new HashPartitioner(7)), None, None) - sorter2.insertAll(elements.iterator) - assert(sorter2.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) - sorter2.stop() - - // Only ordering - val sorter3 = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(7)), Some(ord), None) - sorter3.insertAll(elements.iterator) - 
assert(sorter3.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) - sorter3.stop() - - // Neither aggregator nor ordering - val sorter4 = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(7)), None, None) - sorter4.insertAll(elements.iterator) - assert(sorter4.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) - sorter4.stop() - } - - test("empty partitions with spilling with kryo ser") { - emptyPartitionsWithSpilling(createSparkConf(false, true)) - } - - test("empty partitions with spilling with java ser") { - emptyPartitionsWithSpilling(createSparkConf(false, false)) - } - - def emptyPartitionsWithSpilling(conf: SparkConf) { - conf.set("spark.shuffle.memoryFraction", "0.001") - conf.set("spark.shuffle.spill.initialMemoryThreshold", "512") - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local", "test", conf) - - val ord = implicitly[Ordering[Int]] - val elements = Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) - - val sorter = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(7)), Some(ord), None) - sorter.insertAll(elements) - assert(sc.env.blockManager.diskBlockManager.getAllFiles().length > 0) // Make sure it spilled - val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList)) - assert(iter.next() === (0, Nil)) - assert(iter.next() === (1, List((1, 1)))) - assert(iter.next() === (2, (0 until 100000).map(x => (2, 2)).toList)) - assert(iter.next() === (3, Nil)) - assert(iter.next() === (4, Nil)) - assert(iter.next() === (5, List((5, 5)))) - assert(iter.next() === (6, Nil)) - sorter.stop() - } - - test("spilling in local cluster with kryo ser") { - // Load defaults, otherwise SPARK_HOME is not found - testSpillingInLocalCluster(createSparkConf(true, true)) - } - - test("spilling in local cluster with java ser") { - // Load defaults, otherwise SPARK_HOME is not found - testSpillingInLocalCluster(createSparkConf(true, false)) - } - - def testSpillingInLocalCluster(conf: SparkConf) { - conf.set("spark.shuffle.memoryFraction", "0.001") - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) - - // reduceByKey - should spill ~8 times - val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i)) - val resultA = rddA.reduceByKey(math.max).collect() - assert(resultA.length == 50000) - resultA.foreach { case(k, v) => - if (v != k * 2 + 1) { - fail(s"Value for ${k} was wrong: expected ${k * 2 + 1}, got ${v}") - } - } - // groupByKey - should spill ~17 times - val rddB = sc.parallelize(0 until 100000).map(i => (i/4, i)) - val resultB = rddB.groupByKey().collect() - assert(resultB.length == 25000) - resultB.foreach { case(i, seq) => - val expected = Set(i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3) - if (seq.toSet != expected) { - fail(s"Value for ${i} was wrong: expected ${expected}, got ${seq.toSet}") - } - } - - // cogroup - should spill ~7 times - val rddC1 = sc.parallelize(0 until 10000).map(i => (i, i)) - val rddC2 = sc.parallelize(0 until 10000).map(i => (i%1000, i)) - val resultC = rddC1.cogroup(rddC2).collect() - assert(resultC.length == 10000) - resultC.foreach { case(i, (seq1, seq2)) => - i match { - case 0 => - assert(seq1.toSet == Set[Int](0)) - assert(seq2.toSet == Set[Int](0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000)) - case 1 => - assert(seq1.toSet == Set[Int](1)) - assert(seq2.toSet == Set[Int](1, 1001, 2001, 3001, 
4001, 5001, 6001, 7001, 8001, 9001)) - case 5000 => - assert(seq1.toSet == Set[Int](5000)) - assert(seq2.toSet == Set[Int]()) - case 9999 => - assert(seq1.toSet == Set[Int](9999)) - assert(seq2.toSet == Set[Int]()) - case _ => - } - } +class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { + import TestUtils.{assertNotSpilled, assertSpilled} - // larger cogroup - should spill ~7 times - val rddD1 = sc.parallelize(0 until 10000).map(i => (i/2, i)) - val rddD2 = sc.parallelize(0 until 10000).map(i => (i/2, i)) - val resultD = rddD1.cogroup(rddD2).collect() - assert(resultD.length == 5000) - resultD.foreach { case(i, (seq1, seq2)) => - val expected = Set(i * 2, i * 2 + 1) - if (seq1.toSet != expected) { - fail(s"Value 1 for ${i} was wrong: expected ${expected}, got ${seq1.toSet}") - } - if (seq2.toSet != expected) { - fail(s"Value 2 for ${i} was wrong: expected ${expected}, got ${seq2.toSet}") - } - } + testWithMultipleSer("empty data stream")(emptyDataStream) - // sortByKey - should spill ~17 times - val rddE = sc.parallelize(0 until 100000).map(i => (i/4, i)) - val resultE = rddE.sortByKey().collect().toSeq - assert(resultE === (0 until 100000).map(i => (i/4, i)).toSeq) - } + testWithMultipleSer("few elements per partition")(fewElementsPerPartition) - test("spilling in local cluster with many reduce tasks with kryo ser") { - spillingInLocalClusterWithManyReduceTasks(createSparkConf(true, true)) - } + testWithMultipleSer("empty partitions with spilling")(emptyPartitionsWithSpilling) - test("spilling in local cluster with many reduce tasks with java ser") { - spillingInLocalClusterWithManyReduceTasks(createSparkConf(true, false)) + // Load defaults, otherwise SPARK_HOME is not found + testWithMultipleSer("spilling in local cluster", loadDefaults = true) { + (conf: SparkConf) => testSpillingInLocalCluster(conf, 2) } - def spillingInLocalClusterWithManyReduceTasks(conf: SparkConf) { - conf.set("spark.shuffle.memoryFraction", "0.001") - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) - - // reduceByKey - should spill ~4 times per executor - val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i)) - val resultA = rddA.reduceByKey(math.max _, 100).collect() - assert(resultA.length == 50000) - resultA.foreach { case(k, v) => - if (v != k * 2 + 1) { - fail(s"Value for ${k} was wrong: expected ${k * 2 + 1}, got ${v}") - } - } - - // groupByKey - should spill ~8 times per executor - val rddB = sc.parallelize(0 until 100000).map(i => (i/4, i)) - val resultB = rddB.groupByKey(100).collect() - assert(resultB.length == 25000) - resultB.foreach { case(i, seq) => - val expected = Set(i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3) - if (seq.toSet != expected) { - fail(s"Value for ${i} was wrong: expected ${expected}, got ${seq.toSet}") - } - } - - // cogroup - should spill ~4 times per executor - val rddC1 = sc.parallelize(0 until 10000).map(i => (i, i)) - val rddC2 = sc.parallelize(0 until 10000).map(i => (i%1000, i)) - val resultC = rddC1.cogroup(rddC2, 100).collect() - assert(resultC.length == 10000) - resultC.foreach { case(i, (seq1, seq2)) => - i match { - case 0 => - assert(seq1.toSet == Set[Int](0)) - assert(seq2.toSet == Set[Int](0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000)) - case 1 => - assert(seq1.toSet == Set[Int](1)) - assert(seq2.toSet == Set[Int](1, 1001, 2001, 3001, 4001, 5001, 6001, 7001, 8001, 9001)) - case 5000 => - assert(seq1.toSet == Set[Int](5000)) - 
assert(seq2.toSet == Set[Int]()) - case 9999 => - assert(seq1.toSet == Set[Int](9999)) - assert(seq2.toSet == Set[Int]()) - case _ => - } - } - - // larger cogroup - should spill ~4 times per executor - val rddD1 = sc.parallelize(0 until 10000).map(i => (i/2, i)) - val rddD2 = sc.parallelize(0 until 10000).map(i => (i/2, i)) - val resultD = rddD1.cogroup(rddD2).collect() - assert(resultD.length == 5000) - resultD.foreach { case(i, (seq1, seq2)) => - val expected = Set(i * 2, i * 2 + 1) - if (seq1.toSet != expected) { - fail(s"Value 1 for ${i} was wrong: expected ${expected}, got ${seq1.toSet}") - } - if (seq2.toSet != expected) { - fail(s"Value 2 for ${i} was wrong: expected ${expected}, got ${seq2.toSet}") - } - } - - // sortByKey - should spill ~8 times per executor - val rddE = sc.parallelize(0 until 100000).map(i => (i/4, i)) - val resultE = rddE.sortByKey().collect().toSeq - assert(resultE === (0 until 100000).map(i => (i/4, i)).toSeq) + testWithMultipleSer("spilling in local cluster with many reduce tasks", loadDefaults = true) { + (conf: SparkConf) => testSpillingInLocalCluster(conf, 100) } test("cleanup of intermediate files in sorter") { - val conf = createSparkConf(true, false) // Load defaults, otherwise SPARK_HOME is not found - conf.set("spark.shuffle.memoryFraction", "0.001") - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local", "test", conf) - val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager - - val ord = implicitly[Ordering[Int]] - - val sorter = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(3)), Some(ord), None) - sorter.insertAll((0 until 120000).iterator.map(i => (i, i))) - assert(diskBlockManager.getAllFiles().length > 0) - sorter.stop() - assert(diskBlockManager.getAllBlocks().length === 0) - - val sorter2 = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(3)), Some(ord), None) - sorter2.insertAll((0 until 120000).iterator.map(i => (i, i))) - assert(diskBlockManager.getAllFiles().length > 0) - assert(sorter2.iterator.toSet === (0 until 120000).map(i => (i, i)).toSet) - sorter2.stop() - assert(diskBlockManager.getAllBlocks().length === 0) + cleanupIntermediateFilesInSorter(withFailures = false) } - test("cleanup of intermediate files in sorter if there are errors") { - val conf = createSparkConf(true, false) // Load defaults, otherwise SPARK_HOME is not found - conf.set("spark.shuffle.memoryFraction", "0.001") - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local", "test", conf) - val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager - - val ord = implicitly[Ordering[Int]] - - val sorter = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(3)), Some(ord), None) - intercept[SparkException] { - sorter.insertAll((0 until 120000).iterator.map(i => { - if (i == 119990) { - throw new SparkException("Intentional failure") - } - (i, i) - })) - } - assert(diskBlockManager.getAllFiles().length > 0) - sorter.stop() - assert(diskBlockManager.getAllBlocks().length === 0) + test("cleanup of intermediate files in sorter with failures") { + cleanupIntermediateFilesInSorter(withFailures = true) } test("cleanup of intermediate files in shuffle") { - val conf = createSparkConf(false, false) - conf.set("spark.shuffle.memoryFraction", "0.001") - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local", "test", conf) 
- val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager - - val data = sc.parallelize(0 until 100000, 2).map(i => (i, i)) - assert(data.reduceByKey(_ + _).count() === 100000) - - // After the shuffle, there should be only 4 files on disk: our two map output files and - // their index files. All other intermediate files should've been deleted. - assert(diskBlockManager.getAllFiles().length === 4) - } - - test("cleanup of intermediate files in shuffle with errors") { - val conf = createSparkConf(false, false) - conf.set("spark.shuffle.memoryFraction", "0.001") - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local", "test", conf) - val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager - - val data = sc.parallelize(0 until 100000, 2).map(i => { - if (i == 99990) { - throw new Exception("Intentional failure") - } - (i, i) - }) - intercept[SparkException] { - data.reduceByKey(_ + _).count() - } - - // After the shuffle, there should be only 2 files on disk: the output of task 1 and its index. - // All other files (map 2's output and intermediate merge files) should've been deleted. - assert(diskBlockManager.getAllFiles().length === 2) + cleanupIntermediateFilesInShuffle(withFailures = false) } - test("no partial aggregation or sorting with kryo ser") { - noPartialAggregationOrSorting(createSparkConf(false, true)) + test("cleanup of intermediate files in shuffle with failures") { + cleanupIntermediateFilesInShuffle(withFailures = true) } - test("no partial aggregation or sorting with java ser") { - noPartialAggregationOrSorting(createSparkConf(false, false)) + testWithMultipleSer("no sorting or partial aggregation") { (conf: SparkConf) => + basicSorterTest(conf, withPartialAgg = false, withOrdering = false, withSpilling = false) } - def noPartialAggregationOrSorting(conf: SparkConf) { - conf.set("spark.shuffle.memoryFraction", "0.001") - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local", "test", conf) - - val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None) - sorter.insertAll((0 until 100000).iterator.map(i => (i / 4, i))) - val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet - val expected = (0 until 3).map(p => { - (p, (0 until 100000).map(i => (i / 4, i)).filter(_._1 % 3 == p).toSet) - }).toSet - assert(results === expected) + testWithMultipleSer("no sorting or partial aggregation with spilling") { (conf: SparkConf) => + basicSorterTest(conf, withPartialAgg = false, withOrdering = false, withSpilling = true) } - test("partial aggregation without spill with kryo ser") { - partialAggregationWithoutSpill(createSparkConf(false, true)) + testWithMultipleSer("sorting, no partial aggregation") { (conf: SparkConf) => + basicSorterTest(conf, withPartialAgg = false, withOrdering = true, withSpilling = false) } - test("partial aggregation without spill with java ser") { - partialAggregationWithoutSpill(createSparkConf(false, false)) + testWithMultipleSer("sorting, no partial aggregation with spilling") { (conf: SparkConf) => + basicSorterTest(conf, withPartialAgg = false, withOrdering = true, withSpilling = true) } - def partialAggregationWithoutSpill(conf: SparkConf) { - conf.set("spark.shuffle.memoryFraction", "0.001") - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local", "test", conf) - - val agg = new 
Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) - val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), None, None) - sorter.insertAll((0 until 100).iterator.map(i => (i / 2, i))) - val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet - val expected = (0 until 3).map(p => { - (p, (0 until 50).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet) - }).toSet - assert(results === expected) + testWithMultipleSer("partial aggregation, no sorting") { (conf: SparkConf) => + basicSorterTest(conf, withPartialAgg = true, withOrdering = false, withSpilling = false) } - test("partial aggregation with spill, no ordering with kryo ser") { - partialAggregationWIthSpillNoOrdering(createSparkConf(false, true)) + testWithMultipleSer("partial aggregation, no sorting with spilling") { (conf: SparkConf) => + basicSorterTest(conf, withPartialAgg = true, withOrdering = false, withSpilling = true) } - test("partial aggregation with spill, no ordering with java ser") { - partialAggregationWIthSpillNoOrdering(createSparkConf(false, false)) + testWithMultipleSer("partial aggregation and sorting") { (conf: SparkConf) => + basicSorterTest(conf, withPartialAgg = true, withOrdering = true, withSpilling = false) } - def partialAggregationWIthSpillNoOrdering(conf: SparkConf) { - conf.set("spark.shuffle.memoryFraction", "0.001") - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local", "test", conf) - - val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) - val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), None, None) - sorter.insertAll((0 until 100000).iterator.map(i => (i / 2, i))) - val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet - val expected = (0 until 3).map(p => { - (p, (0 until 50000).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet) - }).toSet - assert(results === expected) - } - - test("partial aggregation with spill, with ordering with kryo ser") { - partialAggregationWithSpillWithOrdering(createSparkConf(false, true)) + testWithMultipleSer("partial aggregation and sorting with spilling") { (conf: SparkConf) => + basicSorterTest(conf, withPartialAgg = true, withOrdering = true, withSpilling = true) } - - test("partial aggregation with spill, with ordering with java ser") { - partialAggregationWithSpillWithOrdering(createSparkConf(false, false)) - } - - def partialAggregationWithSpillWithOrdering(conf: SparkConf) { - conf.set("spark.shuffle.memoryFraction", "0.001") - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local", "test", conf) - - val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) - val ord = implicitly[Ordering[Int]] - val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), Some(ord), None) - - // avoid combine before spill - sorter.insertAll((0 until 50000).iterator.map(i => (i , 2 * i))) - sorter.insertAll((0 until 50000).iterator.map(i => (i, 2 * i + 1))) - val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet - val expected = (0 until 3).map(p => { - (p, (0 until 50000).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet) - }).toSet - assert(results === expected) - } - - test("sorting without aggregation, no spill with kryo ser") { - sortingWithoutAggregationNoSpill(createSparkConf(false, true)) - } - - test("sorting without aggregation, no 
spill with java ser") { - sortingWithoutAggregationNoSpill(createSparkConf(false, false)) - } - - def sortingWithoutAggregationNoSpill(conf: SparkConf) { - conf.set("spark.shuffle.memoryFraction", "0.001") - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local", "test", conf) - - val ord = implicitly[Ordering[Int]] - val sorter = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(3)), Some(ord), None) - sorter.insertAll((0 until 100).iterator.map(i => (i, i))) - val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSeq)}.toSeq - val expected = (0 until 3).map(p => { - (p, (0 until 100).map(i => (i, i)).filter(_._1 % 3 == p).toSeq) - }).toSeq - assert(results === expected) - } - - test("sorting without aggregation, with spill with kryo ser") { - sortingWithoutAggregationWithSpill(createSparkConf(false, true)) - } - - test("sorting without aggregation, with spill with java ser") { - sortingWithoutAggregationWithSpill(createSparkConf(false, false)) - } - - def sortingWithoutAggregationWithSpill(conf: SparkConf) { - conf.set("spark.shuffle.memoryFraction", "0.001") - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local", "test", conf) - - val ord = implicitly[Ordering[Int]] - val sorter = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(3)), Some(ord), None) - sorter.insertAll((0 until 100000).iterator.map(i => (i, i))) - val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSeq)}.toSeq - val expected = (0 until 3).map(p => { - (p, (0 until 100000).map(i => (i, i)).filter(_._1 % 3 == p).toSeq) - }).toSeq - assert(results === expected) - } + testWithMultipleSer("sort without breaking sorting contracts", loadDefaults = true)( + sortWithoutBreakingSortingContracts) test("spilling with hash collisions") { - val conf = createSparkConf(true, false) - conf.set("spark.shuffle.memoryFraction", "0.001") + val size = 1000 + val conf = createSparkConf(loadDefaults = true, kryo = false) + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) + val context = MemoryTestingUtils.fakeTaskContext(sc.env) def createCombiner(i: String): ArrayBuffer[String] = ArrayBuffer[String](i) def mergeValue(buffer: ArrayBuffer[String], i: String): ArrayBuffer[String] = buffer += i - def mergeCombiners(buffer1: ArrayBuffer[String], buffer2: ArrayBuffer[String]) - : ArrayBuffer[String] = buffer1 ++= buffer2 + def mergeCombiners( + buffer1: ArrayBuffer[String], + buffer2: ArrayBuffer[String]): ArrayBuffer[String] = buffer1 ++= buffer2 val agg = new Aggregator[String, String, ArrayBuffer[String]]( createCombiner _, mergeValue _, mergeCombiners _) val sorter = new ExternalSorter[String, String, ArrayBuffer[String]]( - Some(agg), None, None, None) + context, Some(agg), None, None, None) val collisionPairs = Seq( ("Aa", "BB"), // 2112 @@ -587,10 +134,11 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { assert(w1.hashCode === w2.hashCode) } - val toInsert = (1 to 100000).iterator.map(_.toString).map(s => (s, s)) ++ + val toInsert = (1 to size).iterator.map(_.toString).map(s => (s, s)) ++ collisionPairs.iterator ++ collisionPairs.iterator.map(_.swap) sorter.insertAll(toInsert) + assert(sorter.numSpills > 0, "sorter did not spill") // A map of collision pairs in both directions val collisionPairsMap = (collisionPairs ++ 
collisionPairs.map(_.swap)).toMap @@ -605,22 +153,22 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { assert(kv._2.equals(expectedValue)) count += 1 } - assert(count === 100000 + collisionPairs.size * 2) + assert(count === size + collisionPairs.size * 2) } test("spilling with many hash collisions") { - val conf = createSparkConf(true, false) - conf.set("spark.shuffle.memoryFraction", "0.0001") + val size = 1000 + val conf = createSparkConf(loadDefaults = true, kryo = false) + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) - + val context = MemoryTestingUtils.fakeTaskContext(sc.env) val agg = new Aggregator[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _) - val sorter = new ExternalSorter[FixedHashObject, Int, Int](Some(agg), None, None, None) - + val sorter = new ExternalSorter[FixedHashObject, Int, Int](context, Some(agg), None, None, None) // Insert 10 copies each of lots of objects whose hash codes are either 0 or 1. This causes // problems if the map fails to group together the objects with the same code (SPARK-2043). - val toInsert = for (i <- 1 to 10; j <- 1 to 10000) yield (FixedHashObject(j, j % 2), 1) + val toInsert = for (i <- 1 to 10; j <- 1 to size) yield (FixedHashObject(j, j % 2), 1) sorter.insertAll(toInsert.iterator) - + assert(sorter.numSpills > 0, "sorter did not spill") val it = sorter.iterator var count = 0 while (it.hasNext) { @@ -628,13 +176,15 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { assert(kv._2 === 10) count += 1 } - assert(count === 10000) + assert(count === size) } test("spilling with hash collisions using the Int.MaxValue key") { - val conf = createSparkConf(true, false) - conf.set("spark.shuffle.memoryFraction", "0.001") + val size = 1000 + val conf = createSparkConf(loadDefaults = true, kryo = false) + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) + val context = MemoryTestingUtils.fakeTaskContext(sc.env) def createCombiner(i: Int): ArrayBuffer[Int] = ArrayBuffer[Int](i) def mergeValue(buffer: ArrayBuffer[Int], i: Int): ArrayBuffer[Int] = buffer += i @@ -643,11 +193,11 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { } val agg = new Aggregator[Int, Int, ArrayBuffer[Int]](createCombiner, mergeValue, mergeCombiners) - val sorter = new ExternalSorter[Int, Int, ArrayBuffer[Int]](Some(agg), None, None, None) - + val sorter = + new ExternalSorter[Int, Int, ArrayBuffer[Int]](context, Some(agg), None, None, None) sorter.insertAll( - (1 to 100000).iterator.map(i => (i, i)) ++ Iterator((Int.MaxValue, Int.MaxValue))) - + (1 to size).iterator.map(i => (i, i)) ++ Iterator((Int.MaxValue, Int.MaxValue))) + assert(sorter.numSpills > 0, "sorter did not spill") val it = sorter.iterator while (it.hasNext) { // Should not throw NoSuchElementException @@ -656,9 +206,11 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { } test("spilling with null keys and values") { - val conf = createSparkConf(true, false) - conf.set("spark.shuffle.memoryFraction", "0.001") + val size = 1000 + val conf = createSparkConf(loadDefaults = true, kryo = false) + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) + val context = MemoryTestingUtils.fakeTaskContext(sc.env) def createCombiner(i: String): 
ArrayBuffer[String] = ArrayBuffer[String](i) def mergeValue(buffer: ArrayBuffer[String], i: String): ArrayBuffer[String] = buffer += i @@ -669,14 +221,14 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { createCombiner, mergeValue, mergeCombiners) val sorter = new ExternalSorter[String, String, ArrayBuffer[String]]( - Some(agg), None, None, None) + context, Some(agg), None, None, None) - sorter.insertAll((1 to 100000).iterator.map(i => (i.toString, i.toString)) ++ Iterator( + sorter.insertAll((1 to size).iterator.map(i => (i.toString, i.toString)) ++ Iterator( (null.asInstanceOf[String], "1"), ("1", null.asInstanceOf[String]), (null.asInstanceOf[String], null.asInstanceOf[String]) )) - + assert(sorter.numSpills > 0, "sorter did not spill") val it = sorter.iterator while (it.hasNext) { // Should not throw NullPointerException @@ -684,17 +236,307 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { } } - test("sort without breaking sorting contracts with kryo ser") { - sortWithoutBreakingSortingContracts(createSparkConf(true, true)) + /* ============================= * + | Helper test utility methods | + * ============================= */ + + private def createSparkConf(loadDefaults: Boolean, kryo: Boolean): SparkConf = { + val conf = new SparkConf(loadDefaults) + if (kryo) { + conf.set("spark.serializer", classOf[KryoSerializer].getName) + } else { + // Make the Java serializer write a reset instruction (TC_RESET) after each object to test + // for a bug we had with bytes written past the last object in a batch (SPARK-2792) + conf.set("spark.serializer.objectStreamReset", "1") + conf.set("spark.serializer", classOf[JavaSerializer].getName) + } + conf.set("spark.shuffle.sort.bypassMergeThreshold", "0") + // Ensure that we actually have multiple batches per spill file + conf.set("spark.shuffle.spill.batchSize", "10") + conf.set("spark.shuffle.spill.initialMemoryThreshold", "512") + conf + } + + /** + * Run a test multiple times, each time with a different serializer. 
+ */ + private def testWithMultipleSer( + name: String, + loadDefaults: Boolean = false)(body: (SparkConf => Unit)): Unit = { + test(name + " with kryo ser") { + body(createSparkConf(loadDefaults, kryo = true)) + } + test(name + " with java ser") { + body(createSparkConf(loadDefaults, kryo = false)) + } } - test("sort without breaking sorting contracts with java ser") { - sortWithoutBreakingSortingContracts(createSparkConf(true, false)) + /* =========================================== * + | Helper methods that contain the test body | + * =========================================== */ + + private def emptyDataStream(conf: SparkConf) { + conf.set("spark.shuffle.manager", "sort") + sc = new SparkContext("local", "test", conf) + val context = MemoryTestingUtils.fakeTaskContext(sc.env) + + val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) + val ord = implicitly[Ordering[Int]] + + // Both aggregator and ordering + val sorter = new ExternalSorter[Int, Int, Int]( + context, Some(agg), Some(new HashPartitioner(3)), Some(ord), None) + assert(sorter.iterator.toSeq === Seq()) + sorter.stop() + + // Only aggregator + val sorter2 = new ExternalSorter[Int, Int, Int]( + context, Some(agg), Some(new HashPartitioner(3)), None, None) + assert(sorter2.iterator.toSeq === Seq()) + sorter2.stop() + + // Only ordering + val sorter3 = new ExternalSorter[Int, Int, Int]( + context, None, Some(new HashPartitioner(3)), Some(ord), None) + assert(sorter3.iterator.toSeq === Seq()) + sorter3.stop() + + // Neither aggregator nor ordering + val sorter4 = new ExternalSorter[Int, Int, Int]( + context, None, Some(new HashPartitioner(3)), None, None) + assert(sorter4.iterator.toSeq === Seq()) + sorter4.stop() } - def sortWithoutBreakingSortingContracts(conf: SparkConf) { - conf.set("spark.shuffle.memoryFraction", "0.01") + private def fewElementsPerPartition(conf: SparkConf) { conf.set("spark.shuffle.manager", "sort") + sc = new SparkContext("local", "test", conf) + val context = MemoryTestingUtils.fakeTaskContext(sc.env) + + val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) + val ord = implicitly[Ordering[Int]] + val elements = Set((1, 1), (2, 2), (5, 5)) + val expected = Set( + (0, Set()), (1, Set((1, 1))), (2, Set((2, 2))), (3, Set()), (4, Set()), + (5, Set((5, 5))), (6, Set())) + + // Both aggregator and ordering + val sorter = new ExternalSorter[Int, Int, Int]( + context, Some(agg), Some(new HashPartitioner(7)), Some(ord), None) + sorter.insertAll(elements.iterator) + assert(sorter.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) + sorter.stop() + + // Only aggregator + val sorter2 = new ExternalSorter[Int, Int, Int]( + context, Some(agg), Some(new HashPartitioner(7)), None, None) + sorter2.insertAll(elements.iterator) + assert(sorter2.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) + sorter2.stop() + + // Only ordering + val sorter3 = new ExternalSorter[Int, Int, Int]( + context, None, Some(new HashPartitioner(7)), Some(ord), None) + sorter3.insertAll(elements.iterator) + assert(sorter3.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) + sorter3.stop() + + // Neither aggregator nor ordering + val sorter4 = new ExternalSorter[Int, Int, Int]( + context, None, Some(new HashPartitioner(7)), None, None) + sorter4.insertAll(elements.iterator) + assert(sorter4.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) + sorter4.stop() + } + + private def emptyPartitionsWithSpilling(conf: 
SparkConf) { + val size = 1000 + conf.set("spark.shuffle.manager", "sort") + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) + sc = new SparkContext("local", "test", conf) + val context = MemoryTestingUtils.fakeTaskContext(sc.env) + + val ord = implicitly[Ordering[Int]] + val elements = Iterator((1, 1), (5, 5)) ++ (0 until size).iterator.map(x => (2, 2)) + + val sorter = new ExternalSorter[Int, Int, Int]( + context, None, Some(new HashPartitioner(7)), Some(ord), None) + sorter.insertAll(elements) + assert(sorter.numSpills > 0, "sorter did not spill") + val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList)) + assert(iter.next() === (0, Nil)) + assert(iter.next() === (1, List((1, 1)))) + assert(iter.next() === (2, (0 until 1000).map(x => (2, 2)).toList)) + assert(iter.next() === (3, Nil)) + assert(iter.next() === (4, Nil)) + assert(iter.next() === (5, List((5, 5)))) + assert(iter.next() === (6, Nil)) + sorter.stop() + } + + private def testSpillingInLocalCluster(conf: SparkConf, numReduceTasks: Int) { + val size = 5000 + conf.set("spark.shuffle.manager", "sort") + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 4).toString) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) + + assertSpilled(sc, "reduceByKey") { + val result = sc.parallelize(0 until size) + .map { i => (i / 2, i) } + .reduceByKey(math.max _, numReduceTasks) + .collect() + assert(result.length === size / 2) + result.foreach { case (k, v) => + val expected = k * 2 + 1 + assert(v === expected, s"Value for $k was wrong: expected $expected, got $v") + } + } + + assertSpilled(sc, "groupByKey") { + val result = sc.parallelize(0 until size) + .map { i => (i / 2, i) } + .groupByKey(numReduceTasks) + .collect() + assert(result.length == size / 2) + result.foreach { case (i, seq) => + val actual = seq.toSet + val expected = Set(i * 2, i * 2 + 1) + assert(actual === expected, s"Value for $i was wrong: expected $expected, got $actual") + } + } + + assertSpilled(sc, "cogroup") { + val rdd1 = sc.parallelize(0 until size).map { i => (i / 2, i) } + val rdd2 = sc.parallelize(0 until size).map { i => (i / 2, i) } + val result = rdd1.cogroup(rdd2, numReduceTasks).collect() + assert(result.length === size / 2) + result.foreach { case (i, (seq1, seq2)) => + val actual1 = seq1.toSet + val actual2 = seq2.toSet + val expected = Set(i * 2, i * 2 + 1) + assert(actual1 === expected, s"Value 1 for $i was wrong: expected $expected, got $actual1") + assert(actual2 === expected, s"Value 2 for $i was wrong: expected $expected, got $actual2") + } + } + + assertSpilled(sc, "sortByKey") { + val result = sc.parallelize(0 until size) + .map { i => (i / 2, i) } + .sortByKey(numPartitions = numReduceTasks) + .collect() + val expected = (0 until size).map { i => (i / 2, i) }.toArray + assert(result.length === size) + result.zipWithIndex.foreach { case ((k, _), i) => + val (expectedKey, _) = expected(i) + assert(k === expectedKey, s"Value for $i was wrong: expected $expectedKey, got $k") + } + } + } + + private def cleanupIntermediateFilesInSorter(withFailures: Boolean): Unit = { + val size = 1200 + val conf = createSparkConf(loadDefaults = false, kryo = false) + conf.set("spark.shuffle.manager", "sort") + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 4).toString) + sc = new SparkContext("local", "test", conf) + val diskBlockManager = sc.env.blockManager.diskBlockManager + val ord = implicitly[Ordering[Int]] + val expectedSize = if (withFailures) size - 1 
else size + val context = MemoryTestingUtils.fakeTaskContext(sc.env) + val sorter = new ExternalSorter[Int, Int, Int]( + context, None, Some(new HashPartitioner(3)), Some(ord), None) + if (withFailures) { + intercept[SparkException] { + sorter.insertAll((0 until size).iterator.map { i => + if (i == size - 1) { throw new SparkException("intentional failure") } + (i, i) + }) + } + } else { + sorter.insertAll((0 until size).iterator.map(i => (i, i))) + } + assert(sorter.iterator.toSet === (0 until expectedSize).map(i => (i, i)).toSet) + assert(sorter.numSpills > 0, "sorter did not spill") + assert(diskBlockManager.getAllFiles().nonEmpty, "sorter did not spill") + sorter.stop() + assert(diskBlockManager.getAllFiles().isEmpty, "spilled files were not cleaned up") + } + + private def cleanupIntermediateFilesInShuffle(withFailures: Boolean): Unit = { + val size = 1200 + val conf = createSparkConf(loadDefaults = false, kryo = false) + conf.set("spark.shuffle.manager", "sort") + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 4).toString) + sc = new SparkContext("local", "test", conf) + val diskBlockManager = sc.env.blockManager.diskBlockManager + val data = sc.parallelize(0 until size, 2).map { i => + if (withFailures && i == size - 1) { + throw new SparkException("intentional failure") + } + (i, i) + } + + assertSpilled(sc, "test shuffle cleanup") { + if (withFailures) { + intercept[SparkException] { + data.reduceByKey(_ + _).count() + } + // After the shuffle, there should be only 2 files on disk: the output of task 1 and + // its index. All other files (map 2's output and intermediate merge files) should + // have been deleted. + assert(diskBlockManager.getAllFiles().length === 2) + } else { + assert(data.reduceByKey(_ + _).count() === size) + // After the shuffle, there should be only 4 files on disk: the output of both tasks + // and their indices. All intermediate merge files should have been deleted. 
+ assert(diskBlockManager.getAllFiles().length === 4) + } + } + } + + private def basicSorterTest( + conf: SparkConf, + withPartialAgg: Boolean, + withOrdering: Boolean, + withSpilling: Boolean) { + val size = 1000 + if (withSpilling) { + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) + } + conf.set("spark.shuffle.manager", "sort") + sc = new SparkContext("local", "test", conf) + val agg = + if (withPartialAgg) { + Some(new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)) + } else { + None + } + val ord = if (withOrdering) Some(implicitly[Ordering[Int]]) else None + val context = MemoryTestingUtils.fakeTaskContext(sc.env) + val sorter = + new ExternalSorter[Int, Int, Int](context, agg, Some(new HashPartitioner(3)), ord, None) + sorter.insertAll((0 until size).iterator.map { i => (i / 4, i) }) + if (withSpilling) { + assert(sorter.numSpills > 0, "sorter did not spill") + } else { + assert(sorter.numSpills === 0, "sorter spilled") + } + val results = sorter.partitionedIterator.map { case (p, vs) => (p, vs.toSet) }.toSet + val expected = (0 until 3).map { p => + var v = (0 until size).map { i => (i / 4, i) }.filter { case (k, _) => k % 3 == p }.toSet + if (withPartialAgg) { + v = v.groupBy(_._1).mapValues { s => s.map(_._2).sum }.toSet + } + (p, v.toSet) + }.toSet + assert(results === expected) + } + + private def sortWithoutBreakingSortingContracts(conf: SparkConf) { + val size = 100000 + val conf = createSparkConf(loadDefaults = true, kryo = false) + conf.set("spark.shuffle.manager", "sort") + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) // Using wrongOrdering to show integer overflow introduced exception. 
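The rewritten tests above no longer depend on spark.shuffle.memoryFraction to trigger spills; they set spark.shuffle.spill.numElementsForceSpillThreshold to a fraction of the input size and then assert sorter.numSpills > 0. A minimal standalone sketch of that bookkeeping, with made-up names (ForcedSpillBuffer, ForcedSpillExample) and no Spark dependency, assuming only plain Scala collections:

    import scala.collection.mutable.ArrayBuffer

    // Toy collection that "spills" its in-memory contents every `threshold` inserts,
    // mirroring the deterministic-spill setup the rewritten tests rely on.
    class ForcedSpillBuffer[T](threshold: Int) {
      private val inMemory = ArrayBuffer.empty[T]
      private val spills = ArrayBuffer.empty[List[T]]

      def numSpills: Int = spills.length

      def insert(elem: T): Unit = {
        inMemory += elem
        if (inMemory.length >= threshold) {
          spills += inMemory.toList // pretend this batch went to disk
          inMemory.clear()
        }
      }

      def insertAll(elems: Iterator[T]): Unit = elems.foreach(insert)

      // Read everything back: spilled batches first, then what is still in memory.
      def iterator: Iterator[T] = spills.iterator.flatten ++ inMemory.iterator
    }

    object ForcedSpillExample extends App {
      val size = 1000
      val buffer = new ForcedSpillBuffer[(Int, Int)](size / 2) // threshold = size / 2, as in the tests
      buffer.insertAll((0 until size).iterator.map(i => (i / 4, i)))
      assert(buffer.numSpills > 0, "buffer did not spill")
      assert(buffer.iterator.size == size)
    }

With size = 1000 and a threshold of size / 2 the buffer spills exactly twice, which is why the suite can assert numSpills > 0 deterministically instead of depending on memory pressure.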
@@ -707,17 +549,19 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { } } - val testData = Array.tabulate(100000) { _ => rand.nextInt().toString } + val testData = Array.tabulate(size) { _ => rand.nextInt().toString } + val context = MemoryTestingUtils.fakeTaskContext(sc.env) val sorter1 = new ExternalSorter[String, String, String]( - None, None, Some(wrongOrdering), None) + context, None, None, Some(wrongOrdering), None) val thrown = intercept[IllegalArgumentException] { sorter1.insertAll(testData.iterator.map(i => (i, i))) + assert(sorter1.numSpills > 0, "sorter did not spill") sorter1.iterator } - assert(thrown.getClass() === classOf[IllegalArgumentException]) - assert(thrown.getMessage().contains("Comparison method violates its general contract")) + assert(thrown.getClass === classOf[IllegalArgumentException]) + assert(thrown.getMessage.contains("Comparison method violates its general contract")) sorter1.stop() // Using aggregation and external spill to make sure ExternalSorter using @@ -731,8 +575,9 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { createCombiner, mergeValue, mergeCombiners) val sorter2 = new ExternalSorter[String, String, ArrayBuffer[String]]( - Some(agg), None, None, None) + context, Some(agg), None, None, None) sorter2.insertAll(testData.iterator.map(i => (i, i))) + assert(sorter2.numSpills > 0, "sorter did not spill") // To validate the hash ordering of key var minKey = Int.MinValue @@ -743,5 +588,26 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { } sorter2.stop() - } + } + + test("sorting updates peak execution memory") { + val spillThreshold = 1000 + val conf = createSparkConf(loadDefaults = false, kryo = false) + .set("spark.shuffle.manager", "sort") + .set("spark.shuffle.spill.numElementsForceSpillThreshold", spillThreshold.toString) + sc = new SparkContext("local", "test", conf) + // Avoid aggregating here to make sure we're not also using ExternalAppendOnlyMap + // No spilling + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external sorter without spilling") { + assertNotSpilled(sc, "verify peak memory") { + sc.parallelize(1 to spillThreshold / 2, 2).repartition(100).count() + } + } + // With spilling + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external sorter with spilling") { + assertSpilled(sc, "verify peak memory") { + sc.parallelize(1 to spillThreshold * 3, 2).repartition(100).count() + } + } + } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala deleted file mode 100644 index 3b67f6206495..000000000000 --- a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala +++ /dev/null @@ -1,148 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
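The testWithMultipleSer helper in the rewritten ExternalSorterSuite above registers every test body twice, once per serializer. A rough sketch of that registration pattern against ScalaTest's older FunSuite API, with a made-up FakeConf standing in for SparkConf and an illustrative RoundTripSuite (not Spark code):

    import org.scalatest.FunSuite

    // Stand-in for SparkConf: only records which serializer the body should assume.
    case class FakeConf(useKryo: Boolean)

    abstract class MultiSerFunSuite extends FunSuite {
      // One call produces two ScalaTest cases, "<name> with kryo ser" and "<name> with java ser".
      protected def testWithMultipleSer(name: String)(body: FakeConf => Unit): Unit = {
        test(name + " with kryo ser") { body(FakeConf(useKryo = true)) }
        test(name + " with java ser") { body(FakeConf(useKryo = false)) }
      }
    }

    class RoundTripSuite extends MultiSerFunSuite {
      testWithMultipleSer("keeps element count") { conf =>
        // The body sees the conf and can branch on it if one serializer needs special setup.
        val data = (1 to 100).map(i => (i, i.toString))
        assert(data.size === 100, s"failed with useKryo=${conf.useKryo}")
      }
    }

Each helper call contributes two test cases whose names end in "with kryo ser" and "with java ser", matching the test names the refactored suite produces.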
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util.collection - -import java.io.{ByteArrayInputStream, ByteArrayOutputStream} - -import com.google.common.io.ByteStreams - -import org.mockito.Matchers.any -import org.mockito.Mockito._ -import org.mockito.Mockito.RETURNS_SMART_NULLS -import org.mockito.invocation.InvocationOnMock -import org.mockito.stubbing.Answer -import org.scalatest.Matchers._ - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.serializer.KryoSerializer -import org.apache.spark.storage.DiskBlockObjectWriter - -class PartitionedSerializedPairBufferSuite extends SparkFunSuite { - test("OrderedInputStream single record") { - val serializerInstance = new KryoSerializer(new SparkConf()).newInstance - - val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance) - val struct = SomeStruct("something", 5) - buffer.insert(4, 10, struct) - - val bytes = ByteStreams.toByteArray(buffer.orderedInputStream) - - val baos = new ByteArrayOutputStream() - val stream = serializerInstance.serializeStream(baos) - stream.writeObject(10) - stream.writeObject(struct) - stream.close() - - baos.toByteArray should be (bytes) - } - - test("insert single record") { - val serializerInstance = new KryoSerializer(new SparkConf()).newInstance - val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance) - val struct = SomeStruct("something", 5) - buffer.insert(4, 10, struct) - val elements = buffer.partitionedDestructiveSortedIterator(None).toArray - elements.size should be (1) - elements.head should be (((4, 10), struct)) - } - - test("insert multiple records") { - val serializerInstance = new KryoSerializer(new SparkConf()).newInstance - val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance) - val struct1 = SomeStruct("something1", 8) - buffer.insert(6, 1, struct1) - val struct2 = SomeStruct("something2", 9) - buffer.insert(4, 2, struct2) - val struct3 = SomeStruct("something3", 10) - buffer.insert(5, 3, struct3) - - val elements = buffer.partitionedDestructiveSortedIterator(None).toArray - elements.size should be (3) - elements(0) should be (((4, 2), struct2)) - elements(1) should be (((5, 3), struct3)) - elements(2) should be (((6, 1), struct1)) - } - - test("write single record") { - val serializerInstance = new KryoSerializer(new SparkConf()).newInstance - val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance) - val struct = SomeStruct("something", 5) - buffer.insert(4, 10, struct) - val it = buffer.destructiveSortedWritablePartitionedIterator(None) - val (writer, baos) = createMockWriter() - assert(it.hasNext) - it.nextPartition should be (4) - it.writeNext(writer) - assert(!it.hasNext) - - val stream = serializerInstance.deserializeStream(new ByteArrayInputStream(baos.toByteArray)) - stream.readObject[AnyRef]() should be (10) - stream.readObject[AnyRef]() should be (struct) - } - - test("write multiple records") { - val serializerInstance = new KryoSerializer(new SparkConf()).newInstance - val buffer = new 
PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance) - val struct1 = SomeStruct("something1", 8) - buffer.insert(6, 1, struct1) - val struct2 = SomeStruct("something2", 9) - buffer.insert(4, 2, struct2) - val struct3 = SomeStruct("something3", 10) - buffer.insert(5, 3, struct3) - - val it = buffer.destructiveSortedWritablePartitionedIterator(None) - val (writer, baos) = createMockWriter() - assert(it.hasNext) - it.nextPartition should be (4) - it.writeNext(writer) - assert(it.hasNext) - it.nextPartition should be (5) - it.writeNext(writer) - assert(it.hasNext) - it.nextPartition should be (6) - it.writeNext(writer) - assert(!it.hasNext) - - val stream = serializerInstance.deserializeStream(new ByteArrayInputStream(baos.toByteArray)) - val iter = stream.asIterator - iter.next() should be (2) - iter.next() should be (struct2) - iter.next() should be (3) - iter.next() should be (struct3) - iter.next() should be (1) - iter.next() should be (struct1) - assert(!iter.hasNext) - } - - def createMockWriter(): (DiskBlockObjectWriter, ByteArrayOutputStream) = { - val writer = mock(classOf[DiskBlockObjectWriter], RETURNS_SMART_NULLS) - val baos = new ByteArrayOutputStream() - when(writer.write(any(), any(), any())).thenAnswer(new Answer[Unit] { - override def answer(invocationOnMock: InvocationOnMock): Unit = { - val args = invocationOnMock.getArguments - val bytes = args(0).asInstanceOf[Array[Byte]] - val offset = args(1).asInstanceOf[Int] - val length = args(2).asInstanceOf[Int] - baos.write(bytes, offset, length) - } - }) - (writer, baos) - } -} - -case class SomeStruct(str: String, num: Int) diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala index 26a2e96edaaa..0326ed70b5ed 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala @@ -55,6 +55,44 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks { forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) } } + test("Binary prefix comparator") { + + def compareBinary(x: Array[Byte], y: Array[Byte]): Int = { + for (i <- 0 until x.length; if i < y.length) { + val res = x(i).compare(y(i)) + if (res != 0) return res + } + x.length - y.length + } + + def testPrefixComparison(x: Array[Byte], y: Array[Byte]): Unit = { + val s1Prefix = PrefixComparators.BinaryPrefixComparator.computePrefix(x) + val s2Prefix = PrefixComparators.BinaryPrefixComparator.computePrefix(y) + val prefixComparisonResult = + PrefixComparators.BINARY.compare(s1Prefix, s2Prefix) + assert( + (prefixComparisonResult == 0) || + (prefixComparisonResult < 0 && compareBinary(x, y) < 0) || + (prefixComparisonResult > 0 && compareBinary(x, y) > 0)) + } + + // scalastyle:off + val regressionTests = Table( + ("s1", "s2"), + ("abc", "世界"), + ("你好", "世界"), + ("你好123", "你好122") + ) + // scalastyle:on + + forAll (regressionTests) { (s1: String, s2: String) => + testPrefixComparison(s1.getBytes("UTF-8"), s2.getBytes("UTF-8")) + } + forAll { (s1: String, s2: String) => + testPrefixComparison(s1.getBytes("UTF-8"), s2.getBytes("UTF-8")) + } + } + test("double prefix comparator handles NaNs properly") { val nan1: Double = java.lang.Double.longBitsToDouble(0x7ff0000000000001L) val nan2: Double = java.lang.Double.longBitsToDouble(0x7fffffffffffffffL) diff 
--git a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala index d26667bf720c..a5b50fce5c0a 100644 --- a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala @@ -65,4 +65,19 @@ class XORShiftRandomSuite extends SparkFunSuite with Matchers { val random = new XORShiftRandom(0L) assert(random.nextInt() != 0) } + + test ("hashSeed has random bits throughout") { + val totalBitCount = (0 until 10).map { seed => + val hashed = XORShiftRandom.hashSeed(seed) + val bitCount = java.lang.Long.bitCount(hashed) + // make sure we have roughly equal numbers of 0s and 1s. Mostly just check that we + // don't have all 0s or 1s in the high bits + bitCount should be > 20 + bitCount should be < 44 + bitCount + }.sum + // and over all the seeds, very close to equal numbers of 0s & 1s + totalBitCount should be > (32 * 10 - 30) + totalBitCount should be < (32 * 10 + 30) + } } diff --git a/dev/audit-release/README.md b/dev/audit-release/README.md index 38becda0eae9..f72f8c653a26 100644 --- a/dev/audit-release/README.md +++ b/dev/audit-release/README.md @@ -4,7 +4,7 @@ run them locally by setting appropriate environment variables. ``` $ cd sbt_app_core -$ SCALA_VERSION=2.10.4 \ +$ SCALA_VERSION=2.10.5 \ SPARK_VERSION=1.0.0-SNAPSHOT \ SPARK_RELEASE_REPOSITORY=file:///home/patrick/.ivy2/local \ sbt run diff --git a/dev/audit-release/audit_release.py b/dev/audit-release/audit_release.py index 0b7069f6e116..27d1dd784ce2 100755 --- a/dev/audit-release/audit_release.py +++ b/dev/audit-release/audit_release.py @@ -35,7 +35,7 @@ RELEASE_KEY = "XXXXXXXX" # Your 8-digit hex RELEASE_REPOSITORY = "https://repository.apache.org/content/repositories/orgapachespark-1033" RELEASE_VERSION = "1.1.1" -SCALA_VERSION = "2.10.4" +SCALA_VERSION = "2.10.5" SCALA_BINARY_VERSION = "2.10" # Do not set these diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh deleted file mode 100755 index 4311c8c9e4ca..000000000000 --- a/dev/create-release/create-release.sh +++ /dev/null @@ -1,267 +0,0 @@ -#!/usr/bin/env bash - -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Quick-and-dirty automation of making maven and binary releases. Not robust at all. -# Publishes releases to Maven and packages/copies binary release artifacts. -# Expects to be run in a totally empty directory. 
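The binary prefix comparator test added to PrefixComparatorsSuite earlier in this patch only requires that comparing 8-byte prefixes never contradicts a full comparison; a tie (0) is always allowed. A self-contained sketch of that invariant, using unsigned byte order for both the packed prefix and the full comparison; this is illustrative only and not Spark's PrefixComparators implementation:

    object BinaryPrefixSketch {
      // Pack the first 8 bytes (big-endian, zero-padded) into a Long.
      def computePrefix(bytes: Array[Byte]): Long = {
        var prefix = 0L
        var i = 0
        while (i < 8) {
          val b = if (i < bytes.length) bytes(i) & 0xffL else 0L
          prefix = (prefix << 8) | b
          i += 1
        }
        prefix
      }

      // Compare packed prefixes as unsigned 64-bit values (flip the sign bit, then compare signed).
      def comparePrefixes(a: Long, b: Long): Int =
        java.lang.Long.compare(a ^ Long.MinValue, b ^ Long.MinValue)

      // Full comparison: unsigned lexicographic bytes, then length.
      def compareFull(x: Array[Byte], y: Array[Byte]): Int = {
        var i = 0
        while (i < x.length && i < y.length) {
          val c = (x(i) & 0xff) - (y(i) & 0xff)
          if (c != 0) return c
          i += 1
        }
        x.length - y.length
      }

      def main(args: Array[String]): Unit = {
        val pairs = Seq("abc" -> "世界", "你好" -> "世界", "你好123" -> "你好122")
        for ((s1, s2) <- pairs) {
          val (x, y) = (s1.getBytes("UTF-8"), s2.getBytes("UTF-8"))
          val p = comparePrefixes(computePrefix(x), computePrefix(y))
          val f = compareFull(x, y)
          // The prefix result may be 0, but it must never point the other way.
          assert(p == 0 || p.signum == f.signum, s"($s1, $s2): prefix=$p full=$f")
        }
        println("prefix comparisons consistent with full comparisons")
      }
    }

For "你好123" vs "你好122" the first 8 UTF-8 bytes are identical, so the prefix comparison returns 0 and the full comparison decides the order, which is exactly the weaker guarantee the regression table in the suite asserts.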
-# -# Options: -# --skip-create-release Assume the desired release tag already exists -# --skip-publish Do not publish to Maven central -# --skip-package Do not package and upload binary artifacts -# Would be nice to add: -# - Send output to stderr and have useful logging in stdout - -# Note: The following variables must be set before use! -ASF_USERNAME=${ASF_USERNAME:-pwendell} -ASF_PASSWORD=${ASF_PASSWORD:-XXX} -GPG_PASSPHRASE=${GPG_PASSPHRASE:-XXX} -GIT_BRANCH=${GIT_BRANCH:-branch-1.0} -RELEASE_VERSION=${RELEASE_VERSION:-1.2.0} -# Allows publishing under a different version identifier than -# was present in the actual release sources (e.g. rc-X) -PUBLISH_VERSION=${PUBLISH_VERSION:-$RELEASE_VERSION} -NEXT_VERSION=${NEXT_VERSION:-1.2.1} -RC_NAME=${RC_NAME:-rc2} - -M2_REPO=~/.m2/repository -SPARK_REPO=$M2_REPO/org/apache/spark -NEXUS_ROOT=https://repository.apache.org/service/local/staging -NEXUS_PROFILE=d63f592e7eac0 # Profile for Spark staging uploads - -if [ -z "$JAVA_HOME" ]; then - echo "Error: JAVA_HOME is not set, cannot proceed." - exit -1 -fi -JAVA_7_HOME=${JAVA_7_HOME:-$JAVA_HOME} - -set -e - -GIT_TAG=v$RELEASE_VERSION-$RC_NAME - -if [[ ! "$@" =~ --skip-create-release ]]; then - echo "Creating release commit and publishing to Apache repository" - # Artifact publishing - git clone https://$ASF_USERNAME:$ASF_PASSWORD@git-wip-us.apache.org/repos/asf/spark.git \ - -b $GIT_BRANCH - pushd spark - export MAVEN_OPTS="-Xmx3g -XX:MaxPermSize=1g -XX:ReservedCodeCacheSize=1g" - - # Create release commits and push them to github - # NOTE: This is done "eagerly" i.e. we don't check if we can succesfully build - # or before we coin the release commit. This helps avoid races where - # other people add commits to this branch while we are in the middle of building. - cur_ver="${RELEASE_VERSION}-SNAPSHOT" - rel_ver="${RELEASE_VERSION}" - next_ver="${NEXT_VERSION}-SNAPSHOT" - - old="^\( \{2,4\}\)${cur_ver}<\/version>$" - new="\1${rel_ver}<\/version>" - find . -name pom.xml | grep -v dev | xargs -I {} sed -i \ - -e "s/${old}/${new}/" {} - find . -name package.scala | grep -v dev | xargs -I {} sed -i \ - -e "s/${old}/${new}/" {} - - git commit -a -m "Preparing Spark release $GIT_TAG" - echo "Creating tag $GIT_TAG at the head of $GIT_BRANCH" - git tag $GIT_TAG - - old="^\( \{2,4\}\)${rel_ver}<\/version>$" - new="\1${next_ver}<\/version>" - find . -name pom.xml | grep -v dev | xargs -I {} sed -i \ - -e "s/$old/$new/" {} - find . -name package.scala | grep -v dev | xargs -I {} sed -i \ - -e "s/${old}/${new}/" {} - git commit -a -m "Preparing development version $next_ver" - git push origin $GIT_TAG - git push origin HEAD:$GIT_BRANCH - popd - rm -rf spark -fi - -if [[ ! "$@" =~ --skip-publish ]]; then - git clone https://$ASF_USERNAME:$ASF_PASSWORD@git-wip-us.apache.org/repos/asf/spark.git - pushd spark - git checkout --force $GIT_TAG - - # Substitute in case published version is different than released - old="^\( \{2,4\}\)${RELEASE_VERSION}<\/version>$" - new="\1${PUBLISH_VERSION}<\/version>" - find . 
-name pom.xml | grep -v dev | xargs -I {} sed -i \ - -e "s/${old}/${new}/" {} - - # Using Nexus API documented here: - # https://support.sonatype.com/entries/39720203-Uploading-to-a-Staging-Repository-via-REST-API - echo "Creating Nexus staging repository" - repo_request="Apache Spark $GIT_TAG (published as $PUBLISH_VERSION)" - out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ - -H "Content-Type:application/xml" -v \ - $NEXUS_ROOT/profiles/$NEXUS_PROFILE/start) - staged_repo_id=$(echo $out | sed -e "s/.*\(orgapachespark-[0-9]\{4\}\).*/\1/") - echo "Created Nexus staging repository: $staged_repo_id" - - rm -rf $SPARK_REPO - - build/mvn -DskipTests -Pyarn -Phive \ - -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ - clean install - - ./dev/change-scala-version.sh 2.11 - - build/mvn -DskipTests -Pyarn -Phive \ - -Dscala-2.11 -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ - clean install - - ./dev/change-scala-version.sh 2.10 - - pushd $SPARK_REPO - - # Remove any extra files generated during install - find . -type f |grep -v \.jar |grep -v \.pom | xargs rm - - echo "Creating hash and signature files" - for file in $(find . -type f) - do - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --output $file.asc --detach-sig --armour $file; - if [ $(command -v md5) ]; then - # Available on OS X; -q to keep only hash - md5 -q $file > $file.md5 - else - # Available on Linux; cut to keep only hash - md5sum $file | cut -f1 -d' ' > $file.md5 - fi - shasum -a 1 $file | cut -f1 -d' ' > $file.sha1 - done - - nexus_upload=$NEXUS_ROOT/deployByRepositoryId/$staged_repo_id - echo "Uplading files to $nexus_upload" - for file in $(find . -type f) - do - # strip leading ./ - file_short=$(echo $file | sed -e "s/\.\///") - dest_url="$nexus_upload/org/apache/spark/$file_short" - echo " Uploading $file_short" - curl -u $ASF_USERNAME:$ASF_PASSWORD --upload-file $file_short $dest_url - done - - echo "Closing nexus staging repository" - repo_request="$staged_repo_idApache Spark $GIT_TAG (published as $PUBLISH_VERSION)" - out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ - -H "Content-Type:application/xml" -v \ - $NEXUS_ROOT/profiles/$NEXUS_PROFILE/finish) - echo "Closed Nexus staging repository: $staged_repo_id" - - popd - popd - rm -rf spark -fi - -if [[ ! "$@" =~ --skip-package ]]; then - # Source and binary tarballs - echo "Packaging release tarballs" - git clone https://git-wip-us.apache.org/repos/asf/spark.git - cd spark - git checkout --force $GIT_TAG - release_hash=`git rev-parse HEAD` - - rm .gitignore - rm -rf .git - cd .. 
- - cp -r spark spark-$RELEASE_VERSION - tar cvzf spark-$RELEASE_VERSION.tgz spark-$RELEASE_VERSION - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --armour --output spark-$RELEASE_VERSION.tgz.asc \ - --detach-sig spark-$RELEASE_VERSION.tgz - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md MD5 spark-$RELEASE_VERSION.tgz > \ - spark-$RELEASE_VERSION.tgz.md5 - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md SHA512 spark-$RELEASE_VERSION.tgz > \ - spark-$RELEASE_VERSION.tgz.sha - rm -rf spark-$RELEASE_VERSION - - # Updated for each binary build - make_binary_release() { - NAME=$1 - FLAGS=$2 - ZINC_PORT=$3 - cp -r spark spark-$RELEASE_VERSION-bin-$NAME - - cd spark-$RELEASE_VERSION-bin-$NAME - - # TODO There should probably be a flag to make-distribution to allow 2.11 support - if [[ $FLAGS == *scala-2.11* ]]; then - ./dev/change-scala-version.sh 2.11 - fi - - export ZINC_PORT=$ZINC_PORT - echo "Creating distribution: $NAME ($FLAGS)" - ./make-distribution.sh --name $NAME --tgz $FLAGS -DzincPort=$ZINC_PORT 2>&1 > \ - ../binary-release-$NAME.log - cd .. - cp spark-$RELEASE_VERSION-bin-$NAME/spark-$RELEASE_VERSION-bin-$NAME.tgz . - - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --armour \ - --output spark-$RELEASE_VERSION-bin-$NAME.tgz.asc \ - --detach-sig spark-$RELEASE_VERSION-bin-$NAME.tgz - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md \ - MD5 spark-$RELEASE_VERSION-bin-$NAME.tgz > \ - spark-$RELEASE_VERSION-bin-$NAME.tgz.md5 - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md \ - SHA512 spark-$RELEASE_VERSION-bin-$NAME.tgz > \ - spark-$RELEASE_VERSION-bin-$NAME.tgz.sha - } - - # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds - # share the same Zinc server. - make_binary_release "hadoop1" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver" "3030" & - make_binary_release "hadoop1-scala2.11" "-Psparkr -Phadoop-1 -Phive -Dscala-2.11" "3031" & - make_binary_release "cdh4" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" "3032" & - make_binary_release "hadoop2.3" "-Psparkr -Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" "3033" & - make_binary_release "hadoop2.4" "-Psparkr -Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" "3034" & - make_binary_release "mapr3" "-Pmapr3 -Psparkr -Phive -Phive-thriftserver" "3035" & - make_binary_release "mapr4" "-Pmapr4 -Psparkr -Pyarn -Phive -Phive-thriftserver" "3036" & - make_binary_release "hadoop2.4-without-hive" "-Psparkr -Phadoop-2.4 -Pyarn" "3037" & - wait - rm -rf spark-$RELEASE_VERSION-bin-*/ - - # Copy data - echo "Copying release tarballs" - rc_folder=spark-$RELEASE_VERSION-$RC_NAME - ssh $ASF_USERNAME@people.apache.org \ - mkdir /home/$ASF_USERNAME/public_html/$rc_folder - scp spark-* \ - $ASF_USERNAME@people.apache.org:/home/$ASF_USERNAME/public_html/$rc_folder/ - - # Docs - cd spark - sbt/sbt clean - cd docs - # Compile docs with Java 7 to use nicer format - JAVA_HOME="$JAVA_7_HOME" PRODUCTION=1 RELEASE_VERSION="$RELEASE_VERSION" jekyll build - echo "Copying release documentation" - rc_docs_folder=${rc_folder}-docs - ssh $ASF_USERNAME@people.apache.org \ - mkdir /home/$ASF_USERNAME/public_html/$rc_docs_folder - rsync -r _site/* $ASF_USERNAME@people.apache.org:/home/$ASF_USERNAME/public_html/$rc_docs_folder - - echo "Release $RELEASE_VERSION completed:" - echo "Git tag:\t $GIT_TAG" - echo "Release commit:\t $release_hash" - echo "Binary location:\t http://people.apache.org/~$ASF_USERNAME/$rc_folder" - echo "Doc location:\t 
http://people.apache.org/~$ASF_USERNAME/$rc_docs_folder" -fi diff --git a/dev/create-release/generate-contributors.py b/dev/create-release/generate-contributors.py index 8aaa250bd7e2..db9c680a4bad 100755 --- a/dev/create-release/generate-contributors.py +++ b/dev/create-release/generate-contributors.py @@ -178,13 +178,16 @@ def populate(issue_type, components): author_info[author][issue_type].add(component) # Find issues and components associated with this commit for issue in issues: - jira_issue = jira_client.issue(issue) - jira_type = jira_issue.fields.issuetype.name - jira_type = translate_issue_type(jira_type, issue, warnings) - jira_components = [translate_component(c.name, _hash, warnings)\ - for c in jira_issue.fields.components] - all_components = set(jira_components + commit_components) - populate(jira_type, all_components) + try: + jira_issue = jira_client.issue(issue) + jira_type = jira_issue.fields.issuetype.name + jira_type = translate_issue_type(jira_type, issue, warnings) + jira_components = [translate_component(c.name, _hash, warnings)\ + for c in jira_issue.fields.components] + all_components = set(jira_components + commit_components) + populate(jira_type, all_components) + except Exception as e: + print "Unexpected error:", e # For docs without an associated JIRA, manually add it ourselves if is_docs(title) and not issues: populate("documentation", commit_components) @@ -223,7 +226,8 @@ def populate(issue_type, components): # E.g. andrewor14/SPARK-3425/SPARK-1157/SPARK-6672 if author in invalid_authors and invalid_authors[author]: author = author + "/" + "/".join(invalid_authors[author]) - line = " * %s -- %s" % (author, contribution) + #line = " * %s -- %s" % (author, contribution) + line = author contributors_file.write(line + "\n") contributors_file.close() print "Contributors list is successfully written to %s!" % contributors_file_name diff --git a/dev/create-release/known_translations b/dev/create-release/known_translations index e462302f2842..3563fe3cc3c0 100644 --- a/dev/create-release/known_translations +++ b/dev/create-release/known_translations @@ -138,3 +138,30 @@ lee19 - Lee lockwobr - Brian Lockwood navis - Navis Ryu pparkkin - Paavo Parkkinen +HyukjinKwon - Hyukjin Kwon +JDrit - Joseph Batchik +JuhongPark - Juhong Park +KaiXinXiaoLei - KaiXinXIaoLei +NamelessAnalyst - NamelessAnalyst +alyaxey - Alex Slusarenko +baishuo - Shuo Bai +fe2s - Oleksiy Dyagilev +felixcheung - Felix Cheung +feynmanliang - Feynman Liang +josepablocam - Jose Cambronero +kai-zeng - Kai Zeng +mosessky - mosessky +msannell - Michael Sannella +nishkamravi2 - Nishkam Ravi +noel-smith - Noel Smith +petz2000 - Patrick Baier +qiansl127 - Shilei Qian +rahulpalamuttam - Rahul Palamuttam +rowan000 - Rowan Chattaway +sarutak - Kousuke Saruta +sethah - Seth Hendrickson +small-wang - Wang Wei +stanzhai - Stan Zhai +tien-dungle - Tien-Dung Le +xuchenCN - Xu Chen +zhangjiajin - Zhang JiaJin diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh new file mode 100755 index 000000000000..cb79e9eba06e --- /dev/null +++ b/dev/create-release/release-build.sh @@ -0,0 +1,326 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +function exit_with_usage { + cat << EOF +usage: release-build.sh +Creates build deliverables from a Spark commit. + +Top level targets are + package: Create binary packages and copy them to people.apache + docs: Build docs and copy them to people.apache + publish-snapshot: Publish snapshot release to Apache snapshots + publish-release: Publish a release to Apache release repo + +All other inputs are environment variables + +GIT_REF - Release tag or commit to build from +SPARK_VERSION - Release identifier used when publishing +SPARK_PACKAGE_VERSION - Release identifier in top level package directory +REMOTE_PARENT_DIR - Parent in which to create doc or release builds. +REMOTE_PARENT_MAX_LENGTH - If set, parent directory will be cleaned to only + have this number of subdirectories (by deleting old ones). WARNING: This deletes data. + +ASF_USERNAME - Username of ASF committer account +ASF_PASSWORD - Password of ASF committer account +ASF_RSA_KEY - RSA private key file for ASF committer account + +GPG_KEY - GPG key used to sign release artifacts +GPG_PASSPHRASE - Passphrase for GPG key +EOF + exit 1 +} + +set -e + +if [ $# -eq 0 ]; then + exit_with_usage +fi + +if [[ $@ == *"help"* ]]; then + exit_with_usage +fi + +for env in ASF_USERNAME ASF_RSA_KEY GPG_PASSPHRASE GPG_KEY; do + if [ -z "${!env}" ]; then + echo "ERROR: $env must be set to run this script" + exit_with_usage + fi +done + +# Commit ref to checkout when building +GIT_REF=${GIT_REF:-master} + +# Destination directory parent on remote server +REMOTE_PARENT_DIR=${REMOTE_PARENT_DIR:-/home/$ASF_USERNAME/public_html} + +SSH="ssh -o ConnectTimeout=300 -o StrictHostKeyChecking=no -i $ASF_RSA_KEY" +GPG="gpg --no-tty --batch" +NEXUS_ROOT=https://repository.apache.org/service/local/staging +NEXUS_PROFILE=d63f592e7eac0 # Profile for Spark staging uploads +BASE_DIR=$(pwd) + +MVN="build/mvn --force" +PUBLISH_PROFILES="-Pyarn -Phive -Phadoop-2.2" +PUBLISH_PROFILES="$PUBLISH_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" + +rm -rf spark +git clone https://git-wip-us.apache.org/repos/asf/spark.git +cd spark +git checkout $GIT_REF +git_hash=`git rev-parse --short HEAD` +echo "Checked out Spark git hash $git_hash" + +if [ -z "$SPARK_VERSION" ]; then + SPARK_VERSION=$($MVN help:evaluate -Dexpression=project.version \ + | grep -v INFO | grep -v WARNING | grep -v Download) +fi + +if [ -z "$SPARK_PACKAGE_VERSION" ]; then + SPARK_PACKAGE_VERSION="${SPARK_VERSION}-$(date +%Y_%m_%d_%H_%M)-${git_hash}" +fi + +DEST_DIR_NAME="spark-$SPARK_PACKAGE_VERSION" +USER_HOST="$ASF_USERNAME@people.apache.org" + +git clean -d -f -x +rm .gitignore +rm -rf .git +cd .. 
+ +if [ -n "$REMOTE_PARENT_MAX_LENGTH" ]; then + old_dirs=$($SSH $USER_HOST ls -t $REMOTE_PARENT_DIR | tail -n +$REMOTE_PARENT_MAX_LENGTH) + for old_dir in $old_dirs; do + echo "Removing directory: $old_dir" + $SSH $USER_HOST rm -r $REMOTE_PARENT_DIR/$old_dir + done +fi + +if [[ "$1" == "package" ]]; then + # Source and binary tarballs + echo "Packaging release tarballs" + cp -r spark spark-$SPARK_VERSION + tar cvzf spark-$SPARK_VERSION.tgz spark-$SPARK_VERSION + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour --output spark-$SPARK_VERSION.tgz.asc \ + --detach-sig spark-$SPARK_VERSION.tgz + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md MD5 spark-$SPARK_VERSION.tgz > \ + spark-$SPARK_VERSION.tgz.md5 + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ + SHA512 spark-$SPARK_VERSION.tgz > spark-$SPARK_VERSION.tgz.sha + rm -rf spark-$SPARK_VERSION + + # Updated for each binary build + make_binary_release() { + NAME=$1 + FLAGS=$2 + ZINC_PORT=$3 + cp -r spark spark-$SPARK_VERSION-bin-$NAME + + cd spark-$SPARK_VERSION-bin-$NAME + + # TODO There should probably be a flag to make-distribution to allow 2.11 support + if [[ $FLAGS == *scala-2.11* ]]; then + ./dev/change-scala-version.sh 2.11 + fi + + export ZINC_PORT=$ZINC_PORT + echo "Creating distribution: $NAME ($FLAGS)" + + # Get maven home set by MVN + MVN_HOME=`$MVN -version 2>&1 | grep 'Maven home' | awk '{print $NF}'` + + ./make-distribution.sh --name $NAME --mvn $MVN_HOME/bin/mvn --tgz $FLAGS \ + -DzincPort=$ZINC_PORT 2>&1 > ../binary-release-$NAME.log + cd .. + cp spark-$SPARK_VERSION-bin-$NAME/spark-$SPARK_VERSION-bin-$NAME.tgz . + + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour \ + --output spark-$SPARK_VERSION-bin-$NAME.tgz.asc \ + --detach-sig spark-$SPARK_VERSION-bin-$NAME.tgz + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ + MD5 spark-$SPARK_VERSION-bin-$NAME.tgz > \ + spark-$SPARK_VERSION-bin-$NAME.tgz.md5 + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ + SHA512 spark-$SPARK_VERSION-bin-$NAME.tgz > \ + spark-$SPARK_VERSION-bin-$NAME.tgz.sha + } + + # TODO: Check exit codes of children here: + # http://stackoverflow.com/questions/1570262/shell-get-exit-code-of-background-process + + # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds + # share the same Zinc server. 
+ make_binary_release "hadoop1" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver" "3030" & + make_binary_release "hadoop1-scala2.11" "-Psparkr -Phadoop-1 -Phive -Dscala-2.11" "3031" & + make_binary_release "cdh4" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" "3032" & + make_binary_release "hadoop2.3" "-Psparkr -Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" "3033" & + make_binary_release "hadoop2.4" "-Psparkr -Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" "3034" & + make_binary_release "hadoop2.6" "-Psparkr -Phadoop-2.6 -Phive -Phive-thriftserver -Pyarn" "3034" & + make_binary_release "hadoop2.4-without-hive" "-Psparkr -Phadoop-2.4 -Pyarn" "3037" & + make_binary_release "without-hadoop" "-Psparkr -Phadoop-provided -Pyarn" "3038" & + wait + rm -rf spark-$SPARK_VERSION-bin-*/ + + # Copy data + dest_dir="$REMOTE_PARENT_DIR/${DEST_DIR_NAME}-bin" + echo "Copying release tarballs to $dest_dir" + $SSH $USER_HOST mkdir $dest_dir + rsync -e "$SSH" spark-* $USER_HOST:$dest_dir + echo "Linking /latest to $dest_dir" + $SSH $USER_HOST rm -f "$REMOTE_PARENT_DIR/latest" + $SSH $USER_HOST ln -s $dest_dir "$REMOTE_PARENT_DIR/latest" + exit 0 +fi + +if [[ "$1" == "docs" ]]; then + # Documentation + cd spark + echo "Building Spark docs" + dest_dir="$REMOTE_PARENT_DIR/${DEST_DIR_NAME}-docs" + cd docs + # Compile docs with Java 7 to use nicer format + # TODO: Make configurable to add this: PRODUCTION=1 + PRODUCTION=1 RELEASE_VERSION="$SPARK_VERSION" jekyll build + echo "Copying release documentation to $dest_dir" + $SSH $USER_HOST mkdir $dest_dir + echo "Linking /latest to $dest_dir" + $SSH $USER_HOST rm -f "$REMOTE_PARENT_DIR/latest" + $SSH $USER_HOST ln -s $dest_dir "$REMOTE_PARENT_DIR/latest" + rsync -e "$SSH" -r _site/* $USER_HOST:$dest_dir + cd .. + exit 0 +fi + +if [[ "$1" == "publish-snapshot" ]]; then + cd spark + # Publish Spark to Maven release repo + echo "Deploying Spark SNAPSHOT at '$GIT_REF' ($git_hash)" + echo "Publish version is $SPARK_VERSION" + if [[ ! $SPARK_VERSION == *"SNAPSHOT"* ]]; then + echo "ERROR: Snapshots must have a version containing SNAPSHOT" + echo "ERROR: You gave version '$SPARK_VERSION'" + exit 1 + fi + # Coerce the requested version + $MVN versions:set -DnewVersion=$SPARK_VERSION + tmp_settings="tmp-settings.xml" + echo "" > $tmp_settings + echo "apache.snapshots.https$ASF_USERNAME" >> $tmp_settings + echo "$ASF_PASSWORD" >> $tmp_settings + echo "" >> $tmp_settings + + # Generate random point for Zinc + export ZINC_PORT=$(python -S -c "import random; print random.randrange(3030,4030)") + + $MVN -DzincPort=$ZINC_PORT --settings $tmp_settings -DskipTests $PUBLISH_PROFILES \ + -Phive-thriftserver deploy + ./dev/change-scala-version.sh 2.11 + $MVN -DzincPort=$ZINC_PORT -Dscala-2.11 --settings $tmp_settings \ + -DskipTests $PUBLISH_PROFILES clean deploy + + # Clean-up Zinc nailgun process + /usr/sbin/lsof -P |grep $ZINC_PORT | grep LISTEN | awk '{ print $2; }' | xargs kill + + rm $tmp_settings + cd .. 
+  exit 0
+fi
+
+if [[ "$1" == "publish-release" ]]; then
+  cd spark
+  # Publish Spark to Maven release repo
+  echo "Publishing Spark checkout at '$GIT_REF' ($git_hash)"
+  echo "Publish version is $SPARK_VERSION"
+  # Coerce the requested version
+  $MVN versions:set -DnewVersion=$SPARK_VERSION
+
+  # Using Nexus API documented here:
+  # https://support.sonatype.com/entries/39720203-Uploading-to-a-Staging-Repository-via-REST-API
+  echo "Creating Nexus staging repository"
+  repo_request="<promoteRequest><data><description>Apache Spark $SPARK_VERSION (commit $git_hash)</description></data></promoteRequest>"
+  out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \
+    -H "Content-Type:application/xml" -v \
+    $NEXUS_ROOT/profiles/$NEXUS_PROFILE/start)
+  staged_repo_id=$(echo $out | sed -e "s/.*\(orgapachespark-[0-9]\{4\}\).*/\1/")
+  echo "Created Nexus staging repository: $staged_repo_id"
+
+  tmp_repo=$(mktemp -d spark-repo-XXXXX)
+
+  # Generate a random port for Zinc
+  export ZINC_PORT=$(python -S -c "import random; print random.randrange(3030,4030)")
+
+  $MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -DskipTests $PUBLISH_PROFILES \
+    -Phive-thriftserver clean install
+
+  ./dev/change-scala-version.sh 2.11
+
+  $MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -Dscala-2.11 \
+    -DskipTests $PUBLISH_PROFILES clean install
+
+  # Clean-up Zinc nailgun process
+  /usr/sbin/lsof -P |grep $ZINC_PORT | grep LISTEN | awk '{ print $2; }' | xargs kill
+
+  ./dev/change-version-to-2.10.sh
+
+  pushd $tmp_repo/org/apache/spark
+
+  # Remove any extra files generated during install
+  find . -type f |grep -v \.jar |grep -v \.pom | xargs rm
+
+  echo "Creating hash and signature files"
+  for file in $(find . -type f)
+  do
+    echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --output $file.asc \
+      --detach-sig --armour $file;
+    if [ $(command -v md5) ]; then
+      # Available on OS X; -q to keep only hash
+      md5 -q $file > $file.md5
+    else
+      # Available on Linux; cut to keep only hash
+      md5sum $file | cut -f1 -d' ' > $file.md5
+    fi
+    sha1sum $file | cut -f1 -d' ' > $file.sha1
+  done
+
+  nexus_upload=$NEXUS_ROOT/deployByRepositoryId/$staged_repo_id
+  echo "Uploading files to $nexus_upload"
+  for file in $(find . -type f)
+  do
+    # strip leading ./
+    file_short=$(echo $file | sed -e "s/\.\///")
+    dest_url="$nexus_upload/org/apache/spark/$file_short"
+    echo "  Uploading $file_short"
+    curl -u $ASF_USERNAME:$ASF_PASSWORD --upload-file $file_short $dest_url
+  done
+
+  echo "Closing nexus staging repository"
+  repo_request="<promoteRequest><data><stagedRepositoryId>$staged_repo_id</stagedRepositoryId><description>Apache Spark $SPARK_VERSION (commit $git_hash)</description></data></promoteRequest>"
+  out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \
+    -H "Content-Type:application/xml" -v \
+    $NEXUS_ROOT/profiles/$NEXUS_PROFILE/finish)
+  echo "Closed Nexus staging repository: $staged_repo_id"
+  popd
+  rm -rf $tmp_repo
+  cd ..
+  exit 0
+fi
+
+cd ..
+rm -rf spark
+echo "ERROR: expects to be called with 'package', 'docs', 'publish-release' or 'publish-snapshot'"
diff --git a/dev/create-release/release-tag.sh b/dev/create-release/release-tag.sh
new file mode 100755
index 000000000000..b0a3374becc6
--- /dev/null
+++ b/dev/create-release/release-tag.sh
@@ -0,0 +1,79 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.
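The publish-release branch drives Sonatype Nexus through its staging REST API: a POST to `$NEXUS_ROOT/profiles/$NEXUS_PROFILE/start` opens a staging repository whose id (`orgapachespark-NNNN`) is scraped from the response, every artifact is then uploaded under `$NEXUS_ROOT/deployByRepositoryId/<id>`, and a final POST to `.../finish` closes the repository. The Python 2 sketch below mirrors only the start/finish calls; the `NEXUS_ROOT` value and payload shape are copied from the curl invocations above and should be treated as assumptions rather than an independent reference for the Nexus API.

```python
import base64
import re
import urllib2

NEXUS_ROOT = "https://repository.apache.org/service/local/staging"  # assumed to match $NEXUS_ROOT


def nexus_post(path, payload, username, password):
    request = urllib2.Request(NEXUS_ROOT + path, data=payload,
                              headers={"Content-Type": "application/xml"})
    auth = base64.b64encode("%s:%s" % (username, password))
    request.add_header("Authorization", "Basic " + auth)
    return urllib2.urlopen(request).read()


def create_staging_repo(profile, version, git_hash, username, password):
    payload = ("<promoteRequest><data><description>"
               "Apache Spark %s (commit %s)"
               "</description></data></promoteRequest>") % (version, git_hash)
    out = nexus_post("/profiles/%s/start" % profile, payload, username, password)
    # The response embeds the id of the newly created staging repository.
    return re.search(r"orgapachespark-[0-9]{4}", out).group(0)


def close_staging_repo(profile, repo_id, version, git_hash, username, password):
    payload = ("<promoteRequest><data><stagedRepositoryId>%s</stagedRepositoryId>"
               "<description>Apache Spark %s (commit %s)</description>"
               "</data></promoteRequest>") % (repo_id, version, git_hash)
    nexus_post("/profiles/%s/finish" % profile, payload, username, password)
```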
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +function exit_with_usage { + cat << EOF +usage: tag-release.sh +Tags a Spark release on a particular branch. + +Inputs are specified with the following environment variables: +ASF_USERNAME - Apache Username +ASF_PASSWORD - Apache Password +GIT_NAME - Name to use with git +GIT_EMAIL - E-mail address to use with git +GIT_BRANCH - Git branch on which to make release +RELEASE_VERSION - Version used in pom files for release +RELEASE_TAG - Name of release tag +NEXT_VERSION - Development version after release +EOF + exit 1 +} + +set -e + +if [[ $@ == *"help"* ]]; then + exit_with_usage +fi + +for env in ASF_USERNAME ASF_PASSWORD RELEASE_VERSION RELEASE_TAG NEXT_VERSION GIT_EMAIL GIT_NAME GIT_BRANCH; do + if [ -z "${!env}" ]; then + echo "$env must be set to run this script" + exit 1 + fi +done + +ASF_SPARK_REPO="git-wip-us.apache.org/repos/asf/spark.git" +MVN="build/mvn --force" + +rm -rf spark +git clone https://$ASF_USERNAME:$ASF_PASSWORD@$ASF_SPARK_REPO -b $GIT_BRANCH +cd spark + +git config user.name "$GIT_NAME" +git config user.email $GIT_EMAIL + +# Create release version +$MVN versions:set -DnewVersion=$RELEASE_VERSION | grep -v "no value" # silence logs +git commit -a -m "Preparing Spark release $RELEASE_TAG" +echo "Creating tag $RELEASE_TAG at the head of $GIT_BRANCH" +git tag $RELEASE_TAG + +# TODO: It would be nice to do some verifications here +# i.e. check whether ec2 scripts have the new version + +# Create next version +$MVN versions:set -DnewVersion=$NEXT_VERSION | grep -v "no value" # silence logs +git commit -a -m "Preparing development version $NEXT_VERSION" + +# Push changes +git push origin $RELEASE_TAG +git push origin HEAD:$GIT_BRANCH + +cd .. +rm -rf spark diff --git a/dev/create-release/releaseutils.py b/dev/create-release/releaseutils.py index 51ab25a6a5bd..7f152b7f5355 100755 --- a/dev/create-release/releaseutils.py +++ b/dev/create-release/releaseutils.py @@ -24,7 +24,11 @@ try: from jira.client import JIRA - from jira.exceptions import JIRAError + # Old versions have JIRAError in exceptions package, new (0.5+) in utils. + try: + from jira.exceptions import JIRAError + except ImportError: + from jira.utils import JIRAError except ImportError: print "This tool requires the jira-python library" print "Install using 'sudo pip install jira'" diff --git a/bagel/src/test/resources/log4j.properties b/dev/lint-java old mode 100644 new mode 100755 similarity index 61% rename from bagel/src/test/resources/log4j.properties rename to dev/lint-java index edbecdae9209..fe8ab83d562d --- a/bagel/src/test/resources/log4j.properties +++ b/dev/lint-java @@ -1,3 +1,5 @@ +#!/usr/bin/env bash + # # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with @@ -15,13 +17,14 @@ # limitations under the License. 
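release-tag.sh validates its inputs with a bash loop over variable names and `${!env}` indirect expansion. For readers less familiar with that idiom, an equivalent check in Python (a sketch, not part of the patch) is simply:

```python
import os
import sys

REQUIRED_ENV = ["ASF_USERNAME", "ASF_PASSWORD", "RELEASE_VERSION", "RELEASE_TAG",
                "NEXT_VERSION", "GIT_EMAIL", "GIT_NAME", "GIT_BRANCH"]


def check_release_env():
    """Abort early if any required release setting is unset or empty."""
    missing = [name for name in REQUIRED_ENV if not os.environ.get(name)]
    if missing:
        sys.exit("These variables must be set to run this script: " + ", ".join(missing))
```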
# -# Set everything to be logged to the file target/unit-tests.log -log4j.rootCategory=INFO, file -log4j.appender.file=org.apache.log4j.FileAppender -log4j.appender.file.append=true -log4j.appender.file.file=target/unit-tests.log -log4j.appender.file.layout=org.apache.log4j.PatternLayout -log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n +SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" +SPARK_ROOT_DIR="$(dirname $SCRIPT_DIR)" + +ERRORS=$($SCRIPT_DIR/../build/mvn -Pkinesis-asl -Pyarn -Phive -Phive-thriftserver checkstyle:check | grep ERROR) -# Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN +if test ! -z "$ERRORS"; then + echo -e "Checkstyle checks failed at following occurrences:\n$ERRORS" + exit 1 +else + echo -e "Checkstyle checks passed." +fi diff --git a/dev/lint-python b/dev/lint-python index 575dbb0ae321..0b97213ae3df 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -20,7 +20,7 @@ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" SPARK_ROOT_DIR="$(dirname "$SCRIPT_DIR")" PATHS_TO_CHECK="./python/pyspark/ ./ec2/spark_ec2.py ./examples/src/main/python/ ./dev/sparktestsupport" -PATHS_TO_CHECK="$PATHS_TO_CHECK ./dev/run-tests.py ./python/run-tests.py" +PATHS_TO_CHECK="$PATHS_TO_CHECK ./dev/run-tests.py ./python/run-tests.py ./dev/run-tests-jenkins.py" PEP8_REPORT_PATH="$SPARK_ROOT_DIR/dev/pep8-report.txt" PYLINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/pylint-report.txt" PYLINT_INSTALL_INFO="$SPARK_ROOT_DIR/dev/pylint-info.txt" diff --git a/dev/lint-r b/dev/lint-r index 7d5f4cd31153..bfda0bca15eb 100755 --- a/dev/lint-r +++ b/dev/lint-r @@ -28,3 +28,14 @@ if ! type "Rscript" > /dev/null; then fi `which Rscript` --vanilla "$SPARK_ROOT_DIR/dev/lint-r.R" "$SPARK_ROOT_DIR" | tee "$LINT_R_REPORT_FILE_NAME" + +NUM_LINES=`wc -l < "$LINT_R_REPORT_FILE_NAME" | awk '{print $1}'` +if [ "$NUM_LINES" = "0" ] ; then + lint_status=0 + echo "lintr checks passed." +else + lint_status=1 + echo "lintr checks failed." +fi + +exit "$lint_status" diff --git a/dev/lint-r.R b/dev/lint-r.R index 48bd6246096a..999eef571b82 100644 --- a/dev/lint-r.R +++ b/dev/lint-r.R @@ -17,8 +17,14 @@ argv <- commandArgs(TRUE) SPARK_ROOT_DIR <- as.character(argv[1]) +LOCAL_LIB_LOC <- file.path(SPARK_ROOT_DIR, "R", "lib") -# Installs lintr from Github. +# Checks if SparkR is installed in a local directory. +if (! library(SparkR, lib.loc = LOCAL_LIB_LOC, logical.return = TRUE)) { + stop("You should install SparkR in a local directory with `R/install-dev.sh`.") +} + +# Installs lintr from Github in a local directory. # NOTE: The CRAN's version is too old to adapt to our rules. if ("lintr" %in% row.names(installed.packages()) == FALSE) { devtools::install_github("jimhester/lintr") @@ -27,9 +33,5 @@ if ("lintr" %in% row.names(installed.packages()) == FALSE) { library(lintr) library(methods) library(testthat) -if (! 
library(SparkR, lib.loc = file.path(SPARK_ROOT_DIR, "R", "lib"), logical.return = TRUE)) { - stop("You should install SparkR in a local directory with `R/install-dev.sh`.") -} - path.to.package <- file.path(SPARK_ROOT_DIR, "R", "pkg") lint_package(path.to.package, cache = FALSE) diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index ad4b76695c9f..bf1a000f4679 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -159,11 +159,7 @@ def merge_pr(pr_num, target_ref, title, body, pr_repo_desc): merge_message_flags += ["-m", message] # The string "Closes #%s" string is required for GitHub to correctly close the PR - merge_message_flags += [ - "-m", - "Closes #%s from %s and squashes the following commits:" % (pr_num, pr_repo_desc)] - for c in commits: - merge_message_flags += ["-m", c] + merge_message_flags += ["-m", "Closes #%s from %s." % (pr_num, pr_repo_desc)] run_cmd(['git', 'commit', '--author="%s"' % primary_author] + merge_message_flags) @@ -304,24 +300,24 @@ def resolve_jira_issues(title, merge_branches, comment): def standardize_jira_ref(text): """ Standardize the [SPARK-XXXXX] [MODULE] prefix - Converts "[SPARK-XXX][mllib] Issue", "[MLLib] SPARK-XXX. Issue" or "SPARK XXX [MLLIB]: Issue" to "[SPARK-XXX] [MLLIB] Issue" + Converts "[SPARK-XXX][mllib] Issue", "[MLLib] SPARK-XXX. Issue" or "SPARK XXX [MLLIB]: Issue" to "[SPARK-XXX][MLLIB] Issue" >>> standardize_jira_ref("[SPARK-5821] [SQL] ParquetRelation2 CTAS should check if delete is successful") - '[SPARK-5821] [SQL] ParquetRelation2 CTAS should check if delete is successful' + '[SPARK-5821][SQL] ParquetRelation2 CTAS should check if delete is successful' >>> standardize_jira_ref("[SPARK-4123][Project Infra][WIP]: Show new dependencies added in pull requests") - '[SPARK-4123] [PROJECT INFRA] [WIP] Show new dependencies added in pull requests' + '[SPARK-4123][PROJECT INFRA][WIP] Show new dependencies added in pull requests' >>> standardize_jira_ref("[MLlib] Spark 5954: Top by key") - '[SPARK-5954] [MLLIB] Top by key' + '[SPARK-5954][MLLIB] Top by key' >>> standardize_jira_ref("[SPARK-979] a LRU scheduler for load balancing in TaskSchedulerImpl") '[SPARK-979] a LRU scheduler for load balancing in TaskSchedulerImpl' >>> standardize_jira_ref("SPARK-1094 Support MiMa for reporting binary compatibility accross versions.") '[SPARK-1094] Support MiMa for reporting binary compatibility accross versions.' >>> standardize_jira_ref("[WIP] [SPARK-1146] Vagrant support for Spark") - '[SPARK-1146] [WIP] Vagrant support for Spark' + '[SPARK-1146][WIP] Vagrant support for Spark' >>> standardize_jira_ref("SPARK-1032. If Yarn app fails before registering, app master stays aroun...") '[SPARK-1032] If Yarn app fails before registering, app master stays aroun...' >>> standardize_jira_ref("[SPARK-6250][SPARK-6146][SPARK-5911][SQL] Types are now reserved words in DDL parser.") - '[SPARK-6250] [SPARK-6146] [SPARK-5911] [SQL] Types are now reserved words in DDL parser.' + '[SPARK-6250][SPARK-6146][SPARK-5911][SQL] Types are now reserved words in DDL parser.' 
>>> standardize_jira_ref("Additional information for users building from source code") 'Additional information for users building from source code' """ @@ -329,7 +325,7 @@ def standardize_jira_ref(text): components = [] # If the string is compliant, no need to process any further - if (re.search(r'^\[SPARK-[0-9]{3,6}\] (\[[A-Z0-9_\s,]+\] )+\S+', text)): + if (re.search(r'^\[SPARK-[0-9]{3,6}\](\[[A-Z0-9_\s,]+\] )+\S+', text)): return text # Extract JIRA ref(s): @@ -352,7 +348,7 @@ def standardize_jira_ref(text): text = pattern.search(text).groups()[0] # Assemble full text (JIRA ref(s), module(s), remaining text) - clean_text = ' '.join(jira_refs).strip() + " " + ' '.join(components).strip() + " " + text.strip() + clean_text = ''.join(jira_refs).strip() + ''.join(components).strip() + " " + text.strip() # Replace multiple spaces with a single space, e.g. if no jira refs and/or components were included clean_text = re.sub(r'\s+', ' ', clean_text.strip()) diff --git a/dev/mima b/dev/mima index 2952fa65d42f..d5baffc6ef8a 100755 --- a/dev/mima +++ b/dev/mima @@ -38,7 +38,7 @@ generate_mima_ignore() { # it did not process the new classes (which are in assembly jar). generate_mima_ignore -export SPARK_CLASSPATH="`find lib_managed \( -name '*spark*jar' -a -type f \) | tr "\\n" ":"`" +export SPARK_CLASSPATH="$(build/sbt "export oldDeps/fullClasspath" | tail -n1)" echo "SPARK_CLASSPATH=$SPARK_CLASSPATH" generate_mima_ignore diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index c4d39d95d589..e79accf9e987 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -22,226 +22,7 @@ # Environment variables are populated by the code here: #+ https://github.com/jenkinsci/ghprb-plugin/blob/master/src/main/java/org/jenkinsci/plugins/ghprb/GhprbTrigger.java#L139 -# Go to the Spark project root directory -FWDIR="$(cd `dirname $0`/..; pwd)" +FWDIR="$(cd "`dirname $0`"/..; pwd)" cd "$FWDIR" -source "$FWDIR/dev/run-tests-codes.sh" - -COMMENTS_URL="https://api.github.com/repos/apache/spark/issues/$ghprbPullId/comments" -PULL_REQUEST_URL="https://github.com/apache/spark/pull/$ghprbPullId" - -# Important Environment Variables -# --- -# $ghprbActualCommit -#+ This is the hash of the most recent commit in the PR. -#+ The merge-base of this and master is the commit from which the PR was branched. -# $sha1 -#+ If the patch merges cleanly, this is a reference to the merge commit hash -#+ (e.g. "origin/pr/2606/merge"). -#+ If the patch does not merge cleanly, it is equal to $ghprbActualCommit. -#+ The merge-base of this and master in the case of a clean merge is the most recent commit -#+ against master. - -COMMIT_URL="https://github.com/apache/spark/commit/${ghprbActualCommit}" -# GitHub doesn't auto-link short hashes when submitted via the API, unfortunately. :( -SHORT_COMMIT_HASH="${ghprbActualCommit:0:7}" - -# format: http://linux.die.net/man/1/timeout -# must be less than the timeout configured on Jenkins (currently 180m) -TESTS_TIMEOUT="175m" - -# Array to capture all tests to run on the pull request. These tests are held under the -#+ dev/tests/ directory. 
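With the space between bracket groups dropped, the short-circuit regex in `standardize_jira_ref` now treats a title such as `[SPARK-5821][SQL] ...` as already compliant, while the old spaced form falls through and gets rewritten. A small standalone check against the same pattern used in the diff:

```python
import re

# Same pattern as the updated compliance check in standardize_jira_ref.
COMPLIANT = re.compile(r'^\[SPARK-[0-9]{3,6}\](\[[A-Z0-9_\s,]+\] )+\S+')

# Matches: JIRA ref and component packed together, new style.
print(bool(COMPLIANT.search('[SPARK-5821][SQL] ParquetRelation2 CTAS should check if delete is successful')))  # True
# The old spaced style no longer short-circuits and is normalized instead.
print(bool(COMPLIANT.search('[SPARK-5821] [SQL] ParquetRelation2 CTAS should check if delete is successful')))  # False
```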
-# -# To write a PR test: -#+ * the file must reside within the dev/tests directory -#+ * be an executable bash script -#+ * accept three arguments on the command line, the first being the Github PR long commit -#+ hash, the second the Github SHA1 hash, and the final the current PR hash -#+ * and, lastly, return string output to be included in the pr message output that will -#+ be posted to Github -PR_TESTS=( - "pr_merge_ability" - "pr_public_classes" -# DISABLED (pwendell) "pr_new_dependencies" -) - -function post_message () { - local message=$1 - local data="{\"body\": \"$message\"}" - local HTTP_CODE_HEADER="HTTP Response Code: " - - echo "Attempting to post to Github..." - - local curl_output=$( - curl `#--dump-header -` \ - --silent \ - --user x-oauth-basic:$GITHUB_OAUTH_KEY \ - --request POST \ - --data "$data" \ - --write-out "${HTTP_CODE_HEADER}%{http_code}\n" \ - --header "Content-Type: application/json" \ - "$COMMENTS_URL" #> /dev/null #| "$FWDIR/dev/jq" .id #| head -n 8 - ) - local curl_status=${PIPESTATUS[0]} - - if [ "$curl_status" -ne 0 ]; then - echo "Failed to post message to GitHub." >&2 - echo " > curl_status: ${curl_status}" >&2 - echo " > curl_output: ${curl_output}" >&2 - echo " > data: ${data}" >&2 - # exit $curl_status - fi - - local api_response=$( - echo "${curl_output}" \ - | grep -v -e "^${HTTP_CODE_HEADER}" - ) - - local http_code=$( - echo "${curl_output}" \ - | grep -e "^${HTTP_CODE_HEADER}" \ - | sed -r -e "s/^${HTTP_CODE_HEADER}//g" - ) - - if [ -n "$http_code" ] && [ "$http_code" -ne "201" ]; then - echo " > http_code: ${http_code}." >&2 - echo " > api_response: ${api_response}" >&2 - echo " > data: ${data}" >&2 - fi - - if [ "$curl_status" -eq 0 ] && [ "$http_code" -eq "201" ]; then - echo " > Post successful." - fi -} - -function send_archived_logs () { - echo "Archiving unit tests logs..." - - local log_files=$( - find .\ - -name "unit-tests.log" -o\ - -path "./sql/hive/target/HiveCompatibilitySuite.failed" -o\ - -path "./sql/hive/target/HiveCompatibilitySuite.hiveFailed" -o\ - -path "./sql/hive/target/HiveCompatibilitySuite.wrong" - ) - - if [ -z "$log_files" ]; then - echo "> No log files found." >&2 - else - local log_archive="unit-tests-logs.tar.gz" - echo "$log_files" | xargs tar czf ${log_archive} - - local jenkins_build_dir=${JENKINS_HOME}/jobs/${JOB_NAME}/builds/${BUILD_NUMBER} - local scp_output=$(scp ${log_archive} amp-jenkins-master:${jenkins_build_dir}/${log_archive}) - local scp_status="$?" - - if [ "$scp_status" -ne 0 ]; then - echo "Failed to send archived unit tests logs to Jenkins master." >&2 - echo "> scp_status: ${scp_status}" >&2 - echo "> scp_output: ${scp_output}" >&2 - else - echo "> Send successful." - fi - - rm -f ${log_archive} - fi -} - -# post start message -{ - start_message="\ - [Test build ${BUILD_DISPLAY_NAME} has started](${BUILD_URL}consoleFull) for \ - PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL})." 
- - post_message "$start_message" -} - -# Environment variable to capture PR test output -pr_message="" -# Ensure we save off the current HEAD to revert to -current_pr_head="`git rev-parse HEAD`" - -echo "HEAD: `git rev-parse HEAD`" -echo "GHPRB: $ghprbActualCommit" -echo "SHA1: $sha1" - -# Run pull request tests -for t in "${PR_TESTS[@]}"; do - this_test="${FWDIR}/dev/tests/${t}.sh" - # Ensure the test can be found and is a file - if [ -f "${this_test}" ]; then - echo "Running test: $t" - this_mssg="$(bash "${this_test}" "${ghprbActualCommit}" "${sha1}" "${current_pr_head}")" - # Check if this is the merge test as we submit that note *before* and *after* - # the tests run - [ "$t" == "pr_merge_ability" ] && merge_note="${this_mssg}" - pr_message="${pr_message}\n${this_mssg}" - # Ensure, after each test, that we're back on the current PR - git checkout -f "${current_pr_head}" &>/dev/null - else - echo "Cannot find test ${this_test}." - fi -done - -# run tests -{ - # Marks this build is a pull request build. - export AMP_JENKINS_PRB=true - timeout "${TESTS_TIMEOUT}" ./dev/run-tests - test_result="$?" - - if [ "$test_result" -eq "124" ]; then - fail_message="**[Test build ${BUILD_DISPLAY_NAME} timed out](${BUILD_URL}console)** \ - for PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL}) \ - after a configured wait of \`${TESTS_TIMEOUT}\`." - - post_message "$fail_message" - exit $test_result - elif [ "$test_result" -eq "0" ]; then - test_result_note=" * This patch **passes all tests**." - else - if [ "$test_result" -eq "$BLOCK_GENERAL" ]; then - failing_test="some tests" - elif [ "$test_result" -eq "$BLOCK_RAT" ]; then - failing_test="RAT tests" - elif [ "$test_result" -eq "$BLOCK_SCALA_STYLE" ]; then - failing_test="Scala style tests" - elif [ "$test_result" -eq "$BLOCK_PYTHON_STYLE" ]; then - failing_test="Python style tests" - elif [ "$test_result" -eq "$BLOCK_DOCUMENTATION" ]; then - failing_test="to generate documentation" - elif [ "$test_result" -eq "$BLOCK_BUILD" ]; then - failing_test="to build" - elif [ "$test_result" -eq "$BLOCK_MIMA" ]; then - failing_test="MiMa tests" - elif [ "$test_result" -eq "$BLOCK_SPARK_UNIT_TESTS" ]; then - failing_test="Spark unit tests" - elif [ "$test_result" -eq "$BLOCK_PYSPARK_UNIT_TESTS" ]; then - failing_test="PySpark unit tests" - elif [ "$test_result" -eq "$BLOCK_SPARKR_UNIT_TESTS" ]; then - failing_test="SparkR unit tests" - else - failing_test="some tests" - fi - - test_result_note=" * This patch **fails $failing_test**." - fi - - send_archived_logs -} - -# post end message -{ - result_message="\ - [Test build ${BUILD_DISPLAY_NAME} has finished](${BUILD_URL}console) for \ - PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL})." - - result_message="${result_message}\n${test_result_note}" - result_message="${result_message}${pr_message}" - - post_message "$result_message" -} - -exit $test_result +exec python -u ./dev/run-tests-jenkins.py "$@" diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py new file mode 100755 index 000000000000..7aecea25b209 --- /dev/null +++ b/dev/run-tests-jenkins.py @@ -0,0 +1,229 @@ +#!/usr/bin/env python2 + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function +import os +import sys +import json +import urllib2 +import functools +import subprocess + +from sparktestsupport import SPARK_HOME, ERROR_CODES +from sparktestsupport.shellutils import run_cmd + + +def print_err(msg): + """ + Given a set of arguments, will print them to the STDERR stream + """ + print(msg, file=sys.stderr) + + +def post_message_to_github(msg, ghprb_pull_id): + print("Attempting to post to Github...") + + url = "https://api.github.com/repos/apache/spark/issues/" + ghprb_pull_id + "/comments" + github_oauth_key = os.environ["GITHUB_OAUTH_KEY"] + + posted_message = json.dumps({"body": msg}) + request = urllib2.Request(url, + headers={ + "Authorization": "token %s" % github_oauth_key, + "Content-Type": "application/json" + }, + data=posted_message) + try: + response = urllib2.urlopen(request) + + if response.getcode() == 201: + print(" > Post successful.") + except urllib2.HTTPError as http_e: + print_err("Failed to post message to Github.") + print_err(" > http_code: %s" % http_e.code) + print_err(" > api_response: %s" % http_e.read()) + print_err(" > data: %s" % posted_message) + except urllib2.URLError as url_e: + print_err("Failed to post message to Github.") + print_err(" > urllib2_status: %s" % url_e.reason[1]) + print_err(" > data: %s" % posted_message) + + +def pr_message(build_display_name, + build_url, + ghprb_pull_id, + short_commit_hash, + commit_url, + msg, + post_msg=''): + # align the arguments properly for string formatting + str_args = (build_display_name, + msg, + build_url, + ghprb_pull_id, + short_commit_hash, + commit_url, + str(' ' + post_msg + '.') if post_msg else '.') + return '**[Test build %s %s](%sconsoleFull)** for PR %s at commit [`%s`](%s)%s' % str_args + + +def run_pr_checks(pr_tests, ghprb_actual_commit, sha1): + """ + Executes a set of pull request checks to ease development and report issues with various + components such as style, linting, dependencies, compatibilities, etc. + @return a list of messages to post back to Github + """ + # Ensure we save off the current HEAD to revert to + current_pr_head = run_cmd(['git', 'rev-parse', 'HEAD'], return_output=True).strip() + pr_results = list() + + for pr_test in pr_tests: + test_name = pr_test + '.sh' + pr_results.append(run_cmd(['bash', os.path.join(SPARK_HOME, 'dev', 'tests', test_name), + ghprb_actual_commit, sha1], + return_output=True).rstrip()) + # Ensure, after each test, that we're back on the current PR + run_cmd(['git', 'checkout', '-f', current_pr_head]) + return pr_results + + +def run_tests(tests_timeout): + """ + Runs the `dev/run-tests` script and responds with the correct error message + under the various failure scenarios. 
+ @return a tuple containing the test result code and the result note to post to Github + """ + + test_result_code = subprocess.Popen(['timeout', + tests_timeout, + os.path.join(SPARK_HOME, 'dev', 'run-tests')]).wait() + + failure_note_by_errcode = { + 1: 'executing the `dev/run-tests` script', # error to denote run-tests script failures + ERROR_CODES["BLOCK_GENERAL"]: 'some tests', + ERROR_CODES["BLOCK_RAT"]: 'RAT tests', + ERROR_CODES["BLOCK_SCALA_STYLE"]: 'Scala style tests', + ERROR_CODES["BLOCK_JAVA_STYLE"]: 'Java style tests', + ERROR_CODES["BLOCK_PYTHON_STYLE"]: 'Python style tests', + ERROR_CODES["BLOCK_R_STYLE"]: 'R style tests', + ERROR_CODES["BLOCK_DOCUMENTATION"]: 'to generate documentation', + ERROR_CODES["BLOCK_BUILD"]: 'to build', + ERROR_CODES["BLOCK_MIMA"]: 'MiMa tests', + ERROR_CODES["BLOCK_SPARK_UNIT_TESTS"]: 'Spark unit tests', + ERROR_CODES["BLOCK_PYSPARK_UNIT_TESTS"]: 'PySpark unit tests', + ERROR_CODES["BLOCK_SPARKR_UNIT_TESTS"]: 'SparkR unit tests', + ERROR_CODES["BLOCK_TIMEOUT"]: 'from timeout after a configured wait of \`%s\`' % ( + tests_timeout) + } + + if test_result_code == 0: + test_result_note = ' * This patch passes all tests.' + else: + test_result_note = ' * This patch **fails %s**.' % failure_note_by_errcode[test_result_code] + + return [test_result_code, test_result_note] + + +def main(): + # Important Environment Variables + # --- + # $ghprbActualCommit + # This is the hash of the most recent commit in the PR. + # The merge-base of this and master is the commit from which the PR was branched. + # $sha1 + # If the patch merges cleanly, this is a reference to the merge commit hash + # (e.g. "origin/pr/2606/merge"). + # If the patch does not merge cleanly, it is equal to $ghprbActualCommit. + # The merge-base of this and master in the case of a clean merge is the most recent commit + # against master. + ghprb_pull_id = os.environ["ghprbPullId"] + ghprb_actual_commit = os.environ["ghprbActualCommit"] + ghprb_pull_title = os.environ["ghprbPullTitle"] + sha1 = os.environ["sha1"] + + # Marks this build as a pull request build. + os.environ["AMP_JENKINS_PRB"] = "true" + # Switch to a Maven-based build if the PR title contains "test-maven": + if "test-maven" in ghprb_pull_title: + os.environ["AMPLAB_JENKINS_BUILD_TOOL"] = "maven" + # Switch the Hadoop profile based on the PR title: + if "test-hadoop1.0" in ghprb_pull_title: + os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop1.0" + if "test-hadoop2.0" in ghprb_pull_title: + os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop2.0" + if "test-hadoop2.2" in ghprb_pull_title: + os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop2.2" + if "test-hadoop2.3" in ghprb_pull_title: + os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop2.3" + + build_display_name = os.environ["BUILD_DISPLAY_NAME"] + build_url = os.environ["BUILD_URL"] + + commit_url = "https://github.com/apache/spark/commit/" + ghprb_actual_commit + + # GitHub doesn't auto-link short hashes when submitted via the API, unfortunately. :( + short_commit_hash = ghprb_actual_commit[0:7] + + # format: http://linux.die.net/man/1/timeout + # must be less than the timeout configured on Jenkins (currently 300m) + tests_timeout = "250m" + + # Array to capture all test names to run on the pull request. These tests are represented + # by their file equivalents in the dev/tests/ directory. 
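`run_tests` above maps the exit status of `dev/run-tests` (including `timeout`'s convention of returning 124) to a human-readable note via `failure_note_by_errcode`, and a plain dictionary lookup raises `KeyError` for any code outside that table. A hedged variant that degrades gracefully for unexpected codes could look like this; it is a sketch of an alternative, not what the patch does:

```python
def failure_note(test_result_code, failure_note_by_errcode):
    """Return the Github-facing note for a dev/run-tests exit code."""
    if test_result_code == 0:
        return ' * This patch passes all tests.'
    # Fall back to a generic message instead of raising KeyError on an exit
    # code that is not in the table (e.g. one produced by a signal).
    reason = failure_note_by_errcode.get(
        test_result_code, 'with exit code %d' % test_result_code)
    return ' * This patch **fails %s**.' % reason
```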
+ # + # To write a PR test: + # * the file must reside within the dev/tests directory + # * be an executable bash script + # * accept three arguments on the command line, the first being the Github PR long commit + # hash, the second the Github SHA1 hash, and the final the current PR hash + # * and, lastly, return string output to be included in the pr message output that will + # be posted to Github + pr_tests = [ + "pr_merge_ability", + "pr_public_classes" + # DISABLED (pwendell) "pr_new_dependencies" + ] + + # `bind_message_base` returns a function to generate messages for Github posting + github_message = functools.partial(pr_message, + build_display_name, + build_url, + ghprb_pull_id, + short_commit_hash, + commit_url) + + # post start message + post_message_to_github(github_message('has started'), ghprb_pull_id) + + pr_check_results = run_pr_checks(pr_tests, ghprb_actual_commit, sha1) + + test_result_code, test_result_note = run_tests(tests_timeout) + + # post end message + result_message = github_message('has finished') + result_message += '\n' + test_result_note + '\n' + result_message += '\n'.join(pr_check_results) + + post_message_to_github(result_message, ghprb_pull_id) + + sys.exit(test_result_code) + + +if __name__ == "__main__": + main() diff --git a/dev/run-tests.py b/dev/run-tests.py index b6d181418f02..e7e10f1d8c72 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -21,15 +21,17 @@ import itertools from optparse import OptionParser import os +import random import re import sys import subprocess from collections import namedtuple -from sparktestsupport import SPARK_HOME, USER_HOME +from sparktestsupport import SPARK_HOME, USER_HOME, ERROR_CODES from sparktestsupport.shellutils import exit_from_command_with_retcode, run_cmd, rm_r, which import sparktestsupport.modules as modules + # ------------------------------------------------------------------------------------------------- # Functions for traversing module dependency graph # ------------------------------------------------------------------------------------------------- @@ -117,23 +119,18 @@ def determine_modules_to_test(changed_modules): return modules_to_test.union(set(changed_modules)) +def determine_tags_to_exclude(changed_modules): + tags = [] + for m in modules.all_modules: + if m not in changed_modules: + tags += m.test_tags + return tags + + # ------------------------------------------------------------------------------------------------- # Functions for working with subprocesses and shell tools # ------------------------------------------------------------------------------------------------- -def get_error_codes(err_code_file): - """Function to retrieve all block numbers from the `run-tests-codes.sh` - file to maintain backwards compatibility with the `run-tests-jenkins` - script""" - - with open(err_code_file, 'r') as f: - err_codes = [e.split()[1].strip().split('=') - for e in f if e.startswith("readonly")] - return dict(err_codes) - - -ERROR_CODES = get_error_codes(os.path.join(SPARK_HOME, "dev/run-tests-codes.sh")) - def determine_java_executable(): """Will return the path of the java executable that will be used by Spark's @@ -167,17 +164,14 @@ def determine_java_version(java_exe): # find raw version string, eg 'java version "1.8.0_25"' raw_version_str = next(x for x in raw_output_lines if " version " in x) - version_str = raw_version_str.split()[-1].strip('"') # eg '1.8.0_25' - version, update = version_str.split('_') # eg ['1.8.0', '25'] - - # map over the values and convert them to integers - 
version_info = [int(x) for x in version.split('.') + [update]] + match = re.search('(\d+)\.(\d+)\.(\d+)_(\d+)', raw_version_str) - return JavaVersion(major=version_info[0], - minor=version_info[1], - patch=version_info[2], - update=version_info[3]) + major = int(match.group(1)) + minor = int(match.group(2)) + patch = int(match.group(3)) + update = int(match.group(4)) + return JavaVersion(major, minor, patch, update) # ------------------------------------------------------------------------------------------------- # Functions for running the other build and test scripts @@ -185,7 +179,7 @@ def determine_java_version(java_exe): def set_title_and_block(title, err_block): - os.environ["CURRENT_BLOCK"] = ERROR_CODES[err_block] + os.environ["CURRENT_BLOCK"] = str(ERROR_CODES[err_block]) line_str = '=' * 72 print('') @@ -204,11 +198,28 @@ def run_scala_style_checks(): run_cmd([os.path.join(SPARK_HOME, "dev", "lint-scala")]) +def run_java_style_checks(): + set_title_and_block("Running Java style checks", "BLOCK_JAVA_STYLE") + run_cmd([os.path.join(SPARK_HOME, "dev", "lint-java")]) + + def run_python_style_checks(): set_title_and_block("Running Python style checks", "BLOCK_PYTHON_STYLE") run_cmd([os.path.join(SPARK_HOME, "dev", "lint-python")]) +def run_sparkr_style_checks(): + set_title_and_block("Running R style checks", "BLOCK_R_STYLE") + + if which("R"): + # R style check should be executed after `install-dev.sh`. + # Since warnings about `no visible global function definition` appear + # without the installation. SEE ALSO: SPARK-9121. + run_cmd([os.path.join(SPARK_HOME, "dev", "lint-r")]) + else: + print("Ignoring SparkR style check as R was not found in PATH") + + def build_spark_documentation(): set_title_and_block("Building Spark Documentation", "BLOCK_DOCUMENTATION") os.environ["PRODUCTION"] = "1 jekyll build" @@ -227,11 +238,32 @@ def build_spark_documentation(): os.chdir(SPARK_HOME) +def get_zinc_port(): + """ + Get a randomized port on which to start Zinc + """ + return random.randrange(3030, 4030) + + +def kill_zinc_on_port(zinc_port): + """ + Kill the Zinc process running on the given port, if one exists. 
+ """ + cmd = ("/usr/sbin/lsof -P |grep %s | grep LISTEN " + "| awk '{ print $2; }' | xargs kill") % zinc_port + subprocess.check_call(cmd, shell=True) + + def exec_maven(mvn_args=()): """Will call Maven in the current directory with the list of mvn_args passed in and returns the subprocess for any further processing""" - run_cmd([os.path.join(SPARK_HOME, "build", "mvn")] + mvn_args) + zinc_port = get_zinc_port() + os.environ["ZINC_PORT"] = "%s" % zinc_port + zinc_flag = "-DzincPort=%s" % zinc_port + flags = [os.path.join(SPARK_HOME, "build", "mvn"), "--force", zinc_flag] + run_cmd(flags + mvn_args) + kill_zinc_on_port(zinc_port) def exec_sbt(sbt_args=()): @@ -273,6 +305,7 @@ def get_hadoop_profiles(hadoop_version): "hadoop2.0": ["-Phadoop-1", "-Dhadoop.version=2.0.0-mr1-cdh4.1.1"], "hadoop2.2": ["-Pyarn", "-Phadoop-2.2"], "hadoop2.3": ["-Pyarn", "-Phadoop-2.3", "-Dhadoop.version=2.3.0"], + "hadoop2.6": ["-Pyarn", "-Phadoop-2.6"], } if hadoop_version in sbt_maven_hadoop_profiles: @@ -289,7 +322,7 @@ def build_spark_maven(hadoop_version): mvn_goals = ["clean", "package", "-DskipTests"] profiles_and_goals = build_profiles + mvn_goals - print("[info] Building Spark (w/Hive 0.13.1) using Maven with these arguments: ", + print("[info] Building Spark (w/Hive 1.2.1) using Maven with these arguments: ", " ".join(profiles_and_goals)) exec_maven(profiles_and_goals) @@ -302,17 +335,19 @@ def build_spark_sbt(hadoop_version): "assembly/assembly", "streaming-kafka-assembly/assembly", "streaming-flume-assembly/assembly", + "streaming-mqtt-assembly/assembly", + "streaming-mqtt/test:assembly", "streaming-kinesis-asl-assembly/assembly"] profiles_and_goals = build_profiles + sbt_goals - print("[info] Building Spark (w/Hive 0.13.1) using SBT with these arguments: ", + print("[info] Building Spark (w/Hive 1.2.1) using SBT with these arguments: ", " ".join(profiles_and_goals)) exec_sbt(profiles_and_goals) def build_apache_spark(build_tool, hadoop_version): - """Will build Spark against Hive v0.13.1 given the passed in build tool (either `sbt` or + """Will build Spark against Hive v1.2.1 given the passed in build tool (either `sbt` or `maven`). 
Defaults to using `sbt`.""" set_title_and_block("Building Spark", "BLOCK_BUILD") @@ -332,6 +367,7 @@ def detect_binary_inop_with_mima(): def run_scala_tests_maven(test_profiles): mvn_test_goals = ["test", "--fail-at-end"] + profiles_and_goals = test_profiles + mvn_test_goals print("[info] Running Spark tests using Maven with these arguments: ", @@ -355,7 +391,7 @@ def run_scala_tests_sbt(test_modules, test_profiles): exec_sbt(profiles_and_goals) -def run_scala_tests(build_tool, hadoop_version, test_modules): +def run_scala_tests(build_tool, hadoop_version, test_modules, excluded_tags): """Function to properly execute all tests passed in as a set from the `determine_test_suites` function""" set_title_and_block("Running Spark unit tests", "BLOCK_SPARK_UNIT_TESTS") @@ -364,6 +400,10 @@ def run_scala_tests(build_tool, hadoop_version, test_modules): test_profiles = get_hadoop_profiles(hadoop_version) + \ list(set(itertools.chain.from_iterable(m.build_profile_flags for m in test_modules))) + + if excluded_tags: + test_profiles += ['-Dtest.exclude.tags=' + ",".join(excluded_tags)] + if build_tool == "maven": run_scala_tests_maven(test_profiles) else: @@ -384,7 +424,6 @@ def run_sparkr_tests(): set_title_and_block("Running SparkR tests", "BLOCK_SPARKR_UNIT_TESTS") if which("R"): - run_cmd([os.path.join(SPARK_HOME, "R", "install-dev.sh")]) run_cmd([os.path.join(SPARK_HOME, "R", "run-tests.sh")]) else: print("Ignoring SparkR tests as R was not found in PATH") @@ -421,7 +460,7 @@ def main(): rm_r(os.path.join(USER_HOME, ".ivy2", "local", "org.apache.spark")) rm_r(os.path.join(USER_HOME, ".ivy2", "cache", "org.apache.spark")) - os.environ["CURRENT_BLOCK"] = ERROR_CODES["BLOCK_GENERAL"] + os.environ["CURRENT_BLOCK"] = str(ERROR_CODES["BLOCK_GENERAL"]) java_exe = determine_java_executable() @@ -435,6 +474,12 @@ def main(): if java_version.minor < 8: print("[warn] Java 8 tests will not run because JDK version is < 1.8.") + # install SparkR + if which("R"): + run_cmd([os.path.join(SPARK_HOME, "R", "install-dev.sh")]) + else: + print("Can't install SparkR as R is was not found in PATH") + if os.environ.get("AMPLAB_JENKINS"): # if we're on the Amplab Jenkins build servers setup variables # to reflect the environment settings @@ -446,7 +491,7 @@ def main(): else: # else we're running locally and can use local settings build_tool = "sbt" - hadoop_version = "hadoop2.3" + hadoop_version = os.environ.get("HADOOP_PROFILE", "hadoop2.3") test_env = "local" print("[info] Using build tool", build_tool, "with Hadoop profile", hadoop_version, @@ -458,8 +503,10 @@ def main(): target_branch = os.environ["ghprbTargetBranch"] changed_files = identify_changed_files_from_git_commits("HEAD", target_branch=target_branch) changed_modules = determine_modules_for_files(changed_files) + excluded_tags = determine_tags_to_exclude(changed_modules) if not changed_modules: changed_modules = [modules.root] + excluded_tags = [] print("[info] Found the following changed modules:", ", ".join(x.name for x in changed_modules)) @@ -480,8 +527,12 @@ def main(): # style checks if not changed_files or any(f.endswith(".scala") for f in changed_files): run_scala_style_checks() + if not changed_files or any(f.endswith(".java") for f in changed_files): + run_java_style_checks() if not changed_files or any(f.endswith(".py") for f in changed_files): run_python_style_checks() + if not changed_files or any(f.endswith(".R") for f in changed_files): + run_sparkr_style_checks() # determine if docs were changed and if we're inside the amplab environment # 
note - the below commented out until *all* Jenkins workers can get `jekyll` installed @@ -492,10 +543,12 @@ def main(): build_apache_spark(build_tool, hadoop_version) # backwards compatibility checks - detect_binary_inop_with_mima() + if build_tool == "sbt": + # Note: compatiblity tests only supported in sbt for now + detect_binary_inop_with_mima() # run the test suites - run_scala_tests(build_tool, hadoop_version, test_modules) + run_scala_tests(build_tool, hadoop_version, test_modules, excluded_tags) modules_with_python_tests = [m for m in test_modules if m.python_test_goals] if modules_with_python_tests: diff --git a/dev/scalastyle b/dev/scalastyle index ad93f7e85b27..8fd3604b9f45 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -17,14 +17,17 @@ # limitations under the License. # -echo -e "q\n" | build/sbt -Pkinesis-asl -Phive -Phive-thriftserver scalastyle > scalastyle.txt -echo -e "q\n" | build/sbt -Pkinesis-asl -Phive -Phive-thriftserver test:scalastyle >> scalastyle.txt -# Check style with YARN built too -echo -e "q\n" | build/sbt -Pkinesis-asl -Pyarn -Phadoop-2.2 scalastyle >> scalastyle.txt -echo -e "q\n" | build/sbt -Pkinesis-asl -Pyarn -Phadoop-2.2 test:scalastyle >> scalastyle.txt - -ERRORS=$(cat scalastyle.txt | awk '{if($1~/error/)print}') -rm scalastyle.txt +# NOTE: echo "q" is needed because SBT prompts the user for input on encountering a build file +# with failure (either resolution or compilation); the "q" makes SBT quit. +ERRORS=$(echo -e "q\n" \ + | build/sbt \ + -Pkinesis-asl \ + -Pyarn \ + -Phive \ + -Phive-thriftserver \ + scalastyle test:scalastyle \ + | awk '{if($1~/error/)print}' \ +) if test ! -z "$ERRORS"; then echo -e "Scalastyle checks failed at following occurrences:\n$ERRORS" diff --git a/dev/sparktestsupport/__init__.py b/dev/sparktestsupport/__init__.py index 12696d98fb98..0e8032d13341 100644 --- a/dev/sparktestsupport/__init__.py +++ b/dev/sparktestsupport/__init__.py @@ -19,3 +19,18 @@ SPARK_HOME = os.path.abspath(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../")) USER_HOME = os.environ.get("HOME") +ERROR_CODES = { + "BLOCK_GENERAL": 10, + "BLOCK_RAT": 11, + "BLOCK_SCALA_STYLE": 12, + "BLOCK_PYTHON_STYLE": 13, + "BLOCK_R_STYLE": 14, + "BLOCK_DOCUMENTATION": 15, + "BLOCK_BUILD": 16, + "BLOCK_MIMA": 17, + "BLOCK_SPARK_UNIT_TESTS": 18, + "BLOCK_PYSPARK_UNIT_TESTS": 19, + "BLOCK_SPARKR_UNIT_TESTS": 20, + "BLOCK_JAVA_STYLE": 21, + "BLOCK_TIMEOUT": 124 +} diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 956dc81b62e9..d65547e04db4 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -31,7 +31,7 @@ class Module(object): def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(), environ={}, sbt_test_goals=(), python_test_goals=(), blacklisted_python_implementations=(), - should_run_r_tests=False): + test_tags=(), should_run_r_tests=False): """ Define a new module. @@ -50,6 +50,8 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags= :param blacklisted_python_implementations: A set of Python implementations that are not supported by this module's Python components. The values in this set should match strings returned by Python's `platform.python_implementation()`. + :param test_tags A set of tags that will be excluded when running unit tests if the module + is not explicitly changed. :param should_run_r_tests: If true, changes in this module will trigger all R tests. 
""" self.name = name @@ -60,6 +62,7 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags= self.environ = environ self.python_test_goals = python_test_goals self.blacklisted_python_implementations = blacklisted_python_implementations + self.test_tags = test_tags self.should_run_r_tests = should_run_r_tests self.dependent_modules = set() @@ -85,6 +88,9 @@ def contains_file(self, filename): "catalyst/test", "sql/test", "hive/test", + ], + test_tags=[ + "org.apache.spark.tags.ExtendedHiveTest" ] ) @@ -134,7 +140,7 @@ def contains_file(self, filename): # files in streaming_kinesis_asl are changed, so that if Kinesis experiences an outage, we don't # fail other PRs. streaming_kinesis_asl = Module( - name="kinesis-asl", + name="streaming-kinesis-asl", dependencies=[], source_file_regexes=[ "extras/kinesis-asl/", @@ -147,7 +153,7 @@ def contains_file(self, filename): "ENABLE_KINESIS_TESTS": "1" }, sbt_test_goals=[ - "kinesis-asl/test", + "streaming-kinesis-asl/test", ] ) @@ -181,6 +187,7 @@ def contains_file(self, filename): dependencies=[streaming], source_file_regexes=[ "external/mqtt", + "external/mqtt-assembly", ], sbt_test_goals=[ "streaming-mqtt/test", @@ -306,6 +313,7 @@ def contains_file(self, filename): streaming, streaming_kafka, streaming_flume_assembly, + streaming_mqtt, streaming_kinesis_asl ], source_file_regexes=[ @@ -331,6 +339,7 @@ def contains_file(self, filename): "pyspark.mllib.feature", "pyspark.mllib.fpm", "pyspark.mllib.linalg.__init__", + "pyspark.mllib.linalg.distributed", "pyspark.mllib.random", "pyspark.mllib.recommendation", "pyspark.mllib.regression", @@ -395,6 +404,22 @@ def contains_file(self, filename): ) +yarn = Module( + name="yarn", + dependencies=[], + source_file_regexes=[ + "yarn/", + "network/yarn/", + ], + sbt_test_goals=[ + "yarn/test", + "network-yarn/test", + ], + test_tags=[ + "org.apache.spark.tags.ExtendedYarnTest" + ] +) + # The root module is a dummy module which is used to run all of the tests. # No other modules should directly depend on this module. 
root = Module( diff --git a/dev/sparktestsupport/shellutils.py b/dev/sparktestsupport/shellutils.py index 12bd0bf3a4fe..d280e797077d 100644 --- a/dev/sparktestsupport/shellutils.py +++ b/dev/sparktestsupport/shellutils.py @@ -22,6 +22,36 @@ import sys +if sys.version_info >= (2, 7): + subprocess_check_output = subprocess.check_output + subprocess_check_call = subprocess.check_call +else: + # SPARK-8763 + # backported from subprocess module in Python 2.7 + def subprocess_check_output(*popenargs, **kwargs): + if 'stdout' in kwargs: + raise ValueError('stdout argument not allowed, it will be overridden.') + process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs) + output, unused_err = process.communicate() + retcode = process.poll() + if retcode: + cmd = kwargs.get("args") + if cmd is None: + cmd = popenargs[0] + raise subprocess.CalledProcessError(retcode, cmd, output=output) + return output + + # backported from subprocess module in Python 2.7 + def subprocess_check_call(*popenargs, **kwargs): + retcode = call(*popenargs, **kwargs) + if retcode: + cmd = kwargs.get("args") + if cmd is None: + cmd = popenargs[0] + raise CalledProcessError(retcode, cmd) + return 0 + + def exit_from_command_with_retcode(cmd, retcode): print("[error] running", ' '.join(cmd), "; received return code", retcode) sys.exit(int(os.environ.get("CURRENT_BLOCK", 255))) @@ -39,7 +69,7 @@ def rm_r(path): os.remove(path) -def run_cmd(cmd): +def run_cmd(cmd, return_output=False): """ Given a command as a list of arguments will attempt to execute the command and, on failure, print an error message and exit. @@ -48,7 +78,10 @@ def run_cmd(cmd): if not isinstance(cmd, list): cmd = cmd.split() try: - subprocess.check_call(cmd) + if return_output: + return subprocess_check_output(cmd) + else: + return subprocess_check_call(cmd) except subprocess.CalledProcessError as e: exit_from_command_with_retcode(e.cmd, e.returncode) diff --git a/docker-integration-tests/pom.xml b/docker-integration-tests/pom.xml new file mode 100644 index 000000000000..39d3f344615e --- /dev/null +++ b/docker-integration-tests/pom.xml @@ -0,0 +1,171 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.10 + 1.6.0-SNAPSHOT + ../pom.xml + + + spark-docker-integration-tests_2.10 + jar + Spark Project Docker Integration Tests + http://spark.apache.org/ + + docker-integration-tests + + + + + com.spotify + docker-client + shaded + test + + + + com.fasterxml.jackson.jaxrs + jackson-jaxrs-json-provider + + + com.fasterxml.jackson.datatype + jackson-datatype-guava + + + com.fasterxml.jackson.core + jackson-databind + + + org.glassfish.jersey.core + jersey-client + + + org.glassfish.jersey.connectors + jersey-apache-connector + + + org.glassfish.jersey.media + jersey-media-json-jackson + + + + + org.apache.httpcomponents + httpclient + 4.5 + test + + + org.apache.httpcomponents + httpcore + 4.4.1 + test + + + + com.google.guava + guava + 18.0 + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + test + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-test-tags_${scala.binary.version} + ${project.version} + test + + + mysql + mysql-connector-java + test + + + org.postgresql + postgresql + test + + + + com.sun.jersey + jersey-server + 1.19 + test + + + 
com.sun.jersey + jersey-core + 1.19 + test + + + com.sun.jersey + jersey-servlet + 1.19 + test + + + com.sun.jersey + jersey-json + 1.19 + test + + + stax + stax-api + + + + + + diff --git a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala new file mode 100644 index 000000000000..c503c4a13b48 --- /dev/null +++ b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.net.ServerSocket +import java.sql.Connection + +import scala.collection.JavaConverters._ +import scala.util.control.NonFatal + +import com.spotify.docker.client._ +import com.spotify.docker.client.messages.{ContainerConfig, HostConfig, PortBinding} +import org.scalatest.BeforeAndAfterAll +import org.scalatest.concurrent.Eventually +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.util.DockerUtils +import org.apache.spark.sql.test.SharedSQLContext + +abstract class DatabaseOnDocker { + /** + * The docker image to be pulled. + */ + val imageName: String + + /** + * Environment variables to set inside of the Docker container while launching it. + */ + val env: Map[String, String] + + /** + * The container-internal JDBC port that the database listens on. + */ + val jdbcPort: Int + + /** + * Return a JDBC URL that connects to the database running at the given IP address and port. + */ + def getJdbcUrl(ip: String, port: Int): String +} + +abstract class DockerJDBCIntegrationSuite + extends SparkFunSuite + with BeforeAndAfterAll + with Eventually + with SharedSQLContext { + + val db: DatabaseOnDocker + + private var docker: DockerClient = _ + private var containerId: String = _ + protected var jdbcUrl: String = _ + + override def beforeAll() { + super.beforeAll() + try { + docker = DefaultDockerClient.fromEnv.build() + // Check that Docker is actually up + try { + docker.ping() + } catch { + case NonFatal(e) => + log.error("Exception while connecting to Docker. 
Check whether Docker is running.") + throw e + } + // Ensure that the Docker image is installed: + try { + docker.inspectImage(db.imageName) + } catch { + case e: ImageNotFoundException => + log.warn(s"Docker image ${db.imageName} not found; pulling image from registry") + docker.pull(db.imageName) + } + // Configure networking (necessary for boot2docker / Docker Machine) + val externalPort: Int = { + val sock = new ServerSocket(0) + val port = sock.getLocalPort + sock.close() + port + } + val dockerIp = DockerUtils.getDockerIp() + val hostConfig: HostConfig = HostConfig.builder() + .networkMode("bridge") + .portBindings( + Map(s"${db.jdbcPort}/tcp" -> List(PortBinding.of(dockerIp, externalPort)).asJava).asJava) + .build() + // Create the database container: + val config = ContainerConfig.builder() + .image(db.imageName) + .networkDisabled(false) + .env(db.env.map { case (k, v) => s"$k=$v" }.toSeq.asJava) + .hostConfig(hostConfig) + .exposedPorts(s"${db.jdbcPort}/tcp") + .build() + containerId = docker.createContainer(config).id + // Start the container and wait until the database can accept JDBC connections: + docker.startContainer(containerId) + jdbcUrl = db.getJdbcUrl(dockerIp, externalPort) + eventually(timeout(60.seconds), interval(1.seconds)) { + val conn = java.sql.DriverManager.getConnection(jdbcUrl) + conn.close() + } + // Run any setup queries: + val conn: Connection = java.sql.DriverManager.getConnection(jdbcUrl) + try { + dataPreparation(conn) + } finally { + conn.close() + } + } catch { + case NonFatal(e) => + try { + afterAll() + } finally { + throw e + } + } + } + + override def afterAll() { + try { + if (docker != null) { + try { + if (containerId != null) { + docker.killContainer(containerId) + docker.removeContainer(containerId) + } + } catch { + case NonFatal(e) => + logWarning(s"Could not stop container $containerId", e) + } finally { + docker.close() + } + } + } finally { + super.afterAll() + } + } + + /** + * Prepare databases and tables for testing. + */ + def dataPreparation(connection: Connection): Unit +} diff --git a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala new file mode 100644 index 000000000000..c68e4dc4933b --- /dev/null +++ b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
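`DockerJDBCIntegrationSuite` picks a free host port by binding a `ServerSocket(0)`, maps it to the container's JDBC port, and then polls inside `eventually` until a JDBC connection succeeds before running `dataPreparation`. Both tricks translate directly to other languages; the Python sketch below uses plain TCP reachability as the readiness probe instead of a real JDBC handshake, which is an assumption rather than what the suite does.

```python
import socket
import time


def find_free_port():
    # Bind port 0 so the OS picks an unused port, then release it immediately.
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.bind(("", 0))
    port = sock.getsockname()[1]
    sock.close()
    return port


def wait_until_reachable(host, port, timeout=60, interval=1):
    """Poll until a TCP connect succeeds, mirroring the eventually {...} block."""
    deadline = time.time() + timeout
    while True:
        try:
            socket.create_connection((host, port), timeout=interval).close()
            return
        except socket.error:
            if time.time() > deadline:
                raise
            time.sleep(interval)
```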
+ */ + +package org.apache.spark.sql.jdbc + +import java.math.BigDecimal +import java.sql.{Connection, Date, Timestamp} +import java.util.Properties + +import org.apache.spark.tags.DockerTest + +@DockerTest +class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite { + override val db = new DatabaseOnDocker { + override val imageName = "mysql:5.7.9" + override val env = Map( + "MYSQL_ROOT_PASSWORD" -> "rootpass" + ) + override val jdbcPort: Int = 3306 + override def getJdbcUrl(ip: String, port: Int): String = + s"jdbc:mysql://$ip:$port/mysql?user=root&password=rootpass" + } + + override def dataPreparation(conn: Connection): Unit = { + conn.prepareStatement("CREATE DATABASE foo").executeUpdate() + conn.prepareStatement("CREATE TABLE tbl (x INTEGER, y TEXT(8))").executeUpdate() + conn.prepareStatement("INSERT INTO tbl VALUES (42,'fred')").executeUpdate() + conn.prepareStatement("INSERT INTO tbl VALUES (17,'dave')").executeUpdate() + + conn.prepareStatement("CREATE TABLE numbers (onebit BIT(1), tenbits BIT(10), " + + "small SMALLINT, med MEDIUMINT, nor INT, big BIGINT, deci DECIMAL(40,20), flt FLOAT, " + + "dbl DOUBLE)").executeUpdate() + conn.prepareStatement("INSERT INTO numbers VALUES (b'0', b'1000100101', " + + "17, 77777, 123456789, 123456789012345, 123456789012345.123456789012345, " + + "42.75, 1.0000000000000002)").executeUpdate() + + conn.prepareStatement("CREATE TABLE dates (d DATE, t TIME, dt DATETIME, ts TIMESTAMP, " + + "yr YEAR)").executeUpdate() + conn.prepareStatement("INSERT INTO dates VALUES ('1991-11-09', '13:31:24', " + + "'1996-01-01 01:23:45', '2009-02-13 23:31:30', '2001')").executeUpdate() + + // TODO: Test locale conversion for strings. + conn.prepareStatement("CREATE TABLE strings (a CHAR(10), b VARCHAR(10), c TINYTEXT, " + + "d TEXT, e MEDIUMTEXT, f LONGTEXT, g BINARY(4), h VARBINARY(10), i BLOB)" + ).executeUpdate() + conn.prepareStatement("INSERT INTO strings VALUES ('the', 'quick', 'brown', 'fox', " + + "'jumps', 'over', 'the', 'lazy', 'dog')").executeUpdate() + } + + test("Basic test") { + val df = sqlContext.read.jdbc(jdbcUrl, "tbl", new Properties) + val rows = df.collect() + assert(rows.length == 2) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 2) + assert(types(0).equals("class java.lang.Integer")) + assert(types(1).equals("class java.lang.String")) + } + + test("Numeric types") { + val df = sqlContext.read.jdbc(jdbcUrl, "numbers", new Properties) + val rows = df.collect() + assert(rows.length == 1) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 9) + assert(types(0).equals("class java.lang.Boolean")) + assert(types(1).equals("class java.lang.Long")) + assert(types(2).equals("class java.lang.Integer")) + assert(types(3).equals("class java.lang.Integer")) + assert(types(4).equals("class java.lang.Integer")) + assert(types(5).equals("class java.lang.Long")) + assert(types(6).equals("class java.math.BigDecimal")) + assert(types(7).equals("class java.lang.Double")) + assert(types(8).equals("class java.lang.Double")) + assert(rows(0).getBoolean(0) == false) + assert(rows(0).getLong(1) == 0x225) + assert(rows(0).getInt(2) == 17) + assert(rows(0).getInt(3) == 77777) + assert(rows(0).getInt(4) == 123456789) + assert(rows(0).getLong(5) == 123456789012345L) + val bd = new BigDecimal("123456789012345.12345678901234500000") + assert(rows(0).getAs[BigDecimal](6).equals(bd)) + assert(rows(0).getDouble(7) == 42.75) + assert(rows(0).getDouble(8) == 1.0000000000000002) + } + + test("Date types") { + 
val df = sqlContext.read.jdbc(jdbcUrl, "dates", new Properties) + val rows = df.collect() + assert(rows.length == 1) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 5) + assert(types(0).equals("class java.sql.Date")) + assert(types(1).equals("class java.sql.Timestamp")) + assert(types(2).equals("class java.sql.Timestamp")) + assert(types(3).equals("class java.sql.Timestamp")) + assert(types(4).equals("class java.sql.Date")) + assert(rows(0).getAs[Date](0).equals(Date.valueOf("1991-11-09"))) + assert(rows(0).getAs[Timestamp](1).equals(Timestamp.valueOf("1970-01-01 13:31:24"))) + assert(rows(0).getAs[Timestamp](2).equals(Timestamp.valueOf("1996-01-01 01:23:45"))) + assert(rows(0).getAs[Timestamp](3).equals(Timestamp.valueOf("2009-02-13 23:31:30"))) + assert(rows(0).getAs[Date](4).equals(Date.valueOf("2001-01-01"))) + } + + test("String types") { + val df = sqlContext.read.jdbc(jdbcUrl, "strings", new Properties) + val rows = df.collect() + assert(rows.length == 1) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 9) + assert(types(0).equals("class java.lang.String")) + assert(types(1).equals("class java.lang.String")) + assert(types(2).equals("class java.lang.String")) + assert(types(3).equals("class java.lang.String")) + assert(types(4).equals("class java.lang.String")) + assert(types(5).equals("class java.lang.String")) + assert(types(6).equals("class [B")) + assert(types(7).equals("class [B")) + assert(types(8).equals("class [B")) + assert(rows(0).getString(0).equals("the")) + assert(rows(0).getString(1).equals("quick")) + assert(rows(0).getString(2).equals("brown")) + assert(rows(0).getString(3).equals("fox")) + assert(rows(0).getString(4).equals("jumps")) + assert(rows(0).getString(5).equals("over")) + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](6), Array[Byte](116, 104, 101, 0))) + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](7), Array[Byte](108, 97, 122, 121))) + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](8), Array[Byte](100, 111, 103))) + } + + test("Basic write test") { + val df1 = sqlContext.read.jdbc(jdbcUrl, "numbers", new Properties) + val df2 = sqlContext.read.jdbc(jdbcUrl, "dates", new Properties) + val df3 = sqlContext.read.jdbc(jdbcUrl, "strings", new Properties) + df1.write.jdbc(jdbcUrl, "numberscopy", new Properties) + df2.write.jdbc(jdbcUrl, "datescopy", new Properties) + df3.write.jdbc(jdbcUrl, "stringscopy", new Properties) + } +} diff --git a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala new file mode 100644 index 000000000000..6eb6b3391a4a --- /dev/null +++ b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.sql.Connection +import java.util.Properties + +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.expressions.{Literal, If} +import org.apache.spark.tags.DockerTest + +@DockerTest +class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { + override val db = new DatabaseOnDocker { + override val imageName = "postgres:9.4.5" + override val env = Map( + "POSTGRES_PASSWORD" -> "rootpass" + ) + override val jdbcPort = 5432 + override def getJdbcUrl(ip: String, port: Int): String = + s"jdbc:postgresql://$ip:$port/postgres?user=postgres&password=rootpass" + } + + override def dataPreparation(conn: Connection): Unit = { + conn.prepareStatement("CREATE DATABASE foo").executeUpdate() + conn.setCatalog("foo") + conn.prepareStatement("CREATE TABLE bar (c0 text, c1 integer, c2 double precision, c3 bigint, " + + "c4 bit(1), c5 bit(10), c6 bytea, c7 boolean, c8 inet, c9 cidr, " + + "c10 integer[], c11 text[])").executeUpdate() + conn.prepareStatement("INSERT INTO bar VALUES ('hello', 42, 1.25, 123456789012345, B'0', " + + "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16', " + + """'{1, 2}', '{"a", null, "b"}')""").executeUpdate() + } + + test("Type mapping for various types") { + val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties) + val rows = df.collect() + assert(rows.length == 1) + val types = rows(0).toSeq.map(x => x.getClass) + assert(types.length == 12) + assert(classOf[String].isAssignableFrom(types(0))) + assert(classOf[java.lang.Integer].isAssignableFrom(types(1))) + assert(classOf[java.lang.Double].isAssignableFrom(types(2))) + assert(classOf[java.lang.Long].isAssignableFrom(types(3))) + assert(classOf[java.lang.Boolean].isAssignableFrom(types(4))) + assert(classOf[Array[Byte]].isAssignableFrom(types(5))) + assert(classOf[Array[Byte]].isAssignableFrom(types(6))) + assert(classOf[java.lang.Boolean].isAssignableFrom(types(7))) + assert(classOf[String].isAssignableFrom(types(8))) + assert(classOf[String].isAssignableFrom(types(9))) + assert(classOf[Seq[Int]].isAssignableFrom(types(10))) + assert(classOf[Seq[String]].isAssignableFrom(types(11))) + assert(rows(0).getString(0).equals("hello")) + assert(rows(0).getInt(1) == 42) + assert(rows(0).getDouble(2) == 1.25) + assert(rows(0).getLong(3) == 123456789012345L) + assert(rows(0).getBoolean(4) == false) + // BIT(10)'s come back as ASCII strings of ten ASCII 0's and 1's... + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](5), + Array[Byte](49, 48, 48, 48, 49, 48, 48, 49, 48, 49))) + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](6), + Array[Byte](0xDE.toByte, 0xAD.toByte, 0xBE.toByte, 0xEF.toByte))) + assert(rows(0).getBoolean(7) == true) + assert(rows(0).getString(8) == "172.16.0.42") + assert(rows(0).getString(9) == "192.168.0.0/16") + assert(rows(0).getSeq(10) == Seq(1, 2)) + assert(rows(0).getSeq(11) == Seq("a", null, "b")) + } + + test("Basic write test") { + val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties) + // Test only that it doesn't crash. 
+ df.write.jdbc(jdbcUrl, "public.barcopy", new Properties) + // Test write null values. + df.select(df.queryExecution.analyzed.output.map { a => + Column(Literal.create(null, a.dataType)).as(a.name) + }: _*).write.jdbc(jdbcUrl, "public.barcopy2", new Properties) + } +} diff --git a/docker-integration-tests/src/test/scala/org/apache/spark/util/DockerUtils.scala b/docker-integration-tests/src/test/scala/org/apache/spark/util/DockerUtils.scala new file mode 100644 index 000000000000..87271776d856 --- /dev/null +++ b/docker-integration-tests/src/test/scala/org/apache/spark/util/DockerUtils.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.net.{Inet4Address, NetworkInterface, InetAddress} + +import scala.collection.JavaConverters._ +import scala.sys.process._ +import scala.util.Try + +private[spark] object DockerUtils { + + def getDockerIp(): String = { + /** If docker-machine is setup on this box, attempts to find the ip from it. */ + def findFromDockerMachine(): Option[String] = { + sys.env.get("DOCKER_MACHINE_NAME").flatMap { name => + Try(Seq("/bin/bash", "-c", s"docker-machine ip $name 2>/dev/null").!!.trim).toOption + } + } + sys.env.get("DOCKER_IP") + .orElse(findFromDockerMachine()) + .orElse(Try(Seq("/bin/bash", "-c", "boot2docker ip 2>/dev/null").!!.trim).toOption) + .getOrElse { + // This block of code is based on Utils.findLocalInetAddress(), but is modified to blacklist + // certain interfaces. + val address = InetAddress.getLocalHost + // Address resolves to something like 127.0.1.1, which happens on Debian; try to find + // a better address using the local network interfaces + // getNetworkInterfaces returns ifs in reverse order compared to ifconfig output order + // on unix-like system. On windows, it returns in index order. + // It's more proper to pick ip address following system output order. 
+ val blackListedIFs = Seq( + "vboxnet0", // Mac + "docker0" // Linux + ) + val activeNetworkIFs = NetworkInterface.getNetworkInterfaces.asScala.toSeq.filter { i => + !blackListedIFs.contains(i.getName) + } + val reOrderedNetworkIFs = activeNetworkIFs.reverse + for (ni <- reOrderedNetworkIFs) { + val addresses = ni.getInetAddresses.asScala + .filterNot(addr => addr.isLinkLocalAddress || addr.isLoopbackAddress).toSeq + if (addresses.nonEmpty) { + val addr = addresses.find(_.isInstanceOf[Inet4Address]).getOrElse(addresses.head) + // because of Inet6Address.toHostName may add interface at the end if it knows about it + val strippedAddress = InetAddress.getByAddress(addr.getAddress) + return strippedAddress.getHostAddress + } + } + address.getHostAddress + } + } +} diff --git a/docker/spark-mesos/Dockerfile b/docker/spark-mesos/Dockerfile index b90aef3655de..fb3f267fe5c7 100644 --- a/docker/spark-mesos/Dockerfile +++ b/docker/spark-mesos/Dockerfile @@ -24,7 +24,7 @@ RUN apt-get update && \ apt-get install -y python libnss3 openjdk-7-jre-headless curl RUN mkdir /opt/spark && \ - curl http://www.apache.org/dyn/closer.cgi/spark/spark-1.4.0/spark-1.4.0-bin-hadoop2.4.tgz \ + curl http://www.apache.org/dyn/closer.lua/spark/spark-1.4.0/spark-1.4.0-bin-hadoop2.4.tgz \ | tar -xzC /opt ENV SPARK_HOME /opt/spark ENV MESOS_NATIVE_JAVA_LIBRARY /usr/local/lib/libmesos.so diff --git a/docker/spark-test/base/Dockerfile b/docker/spark-test/base/Dockerfile index 5dbdb8b22a44..7ba0de603dc7 100644 --- a/docker/spark-test/base/Dockerfile +++ b/docker/spark-test/base/Dockerfile @@ -25,7 +25,7 @@ RUN apt-get update && \ apt-get install -y less openjdk-7-jre-headless net-tools vim-tiny sudo openssh-server && \ rm -rf /var/lib/apt/lists/* -ENV SCALA_VERSION 2.10.4 +ENV SCALA_VERSION 2.10.5 ENV CDH_VERSION cdh4 ENV SCALA_HOME /opt/scala-$SCALA_VERSION ENV SPARK_HOME /opt/spark diff --git a/docs/README.md b/docs/README.md index d7652e921f7d..1f4fd3e56ed5 100644 --- a/docs/README.md +++ b/docs/README.md @@ -8,6 +8,17 @@ Read on to learn more about viewing documentation in plain text (i.e., markdown) documentation yourself. Why build it yourself? So that you have the docs that corresponds to whichever version of Spark you currently have checked out of revision control. +## Prerequisites +The Spark documentation build uses a number of tools to build HTML docs and API docs in Scala, +Python and R. To get started you can run the following commands + + $ sudo gem install jekyll + $ sudo gem install jekyll-redirect-from + $ sudo pip install Pygments + $ sudo pip install sphinx + $ Rscript -e 'install.packages(c("knitr", "devtools"), repos="http://cran.stat.ucla.edu/")' + + ## Generating the Documentation HTML We include the Spark documentation as part of the source (as opposed to using a hosted wiki, such as @@ -19,17 +30,12 @@ you have checked out or downloaded. In this directory you will find textfiles formatted using Markdown, with an ".md" suffix. You can read those text files directly if you want. Start with index.md. -The markdown code can be compiled to HTML using the [Jekyll tool](http://jekyllrb.com). -`Jekyll` and a few dependencies must be installed for this to work. We recommend -installing via the Ruby Gem dependency manager. Since the exact HTML output -varies between versions of Jekyll and its dependencies, we list specific versions here -in some cases: - - $ sudo gem install jekyll - $ sudo gem install jekyll-redirect-from +Execute `jekyll build` from the `docs/` directory to compile the site. 
Compiling the site with +Jekyll will create a directory called `_site` containing index.html as well as the rest of the +compiled files. -Execute `jekyll build` from the `docs/` directory to compile the site. Compiling the site with Jekyll will create a directory -called `_site` containing index.html as well as the rest of the compiled files. + $ cd docs + $ jekyll build You can modify the default Jekyll build as follows: @@ -40,29 +46,6 @@ You can modify the default Jekyll build as follows: # Build the site with extra features used on the live page $ PRODUCTION=1 jekyll build -## Pygments - -We also use pygments (http://pygments.org) for syntax highlighting in documentation markdown pages, -so you will also need to install that (it requires Python) by running `sudo pip install Pygments`. - -To mark a block of code in your markdown to be syntax highlighted by jekyll during the compile -phase, use the following sytax: - - {% highlight scala %} - // Your scala code goes here, you can replace scala with many other - // supported languages too. - {% endhighlight %} - -## Sphinx - -We use Sphinx to generate Python API docs, so you will need to install it by running -`sudo pip install sphinx`. - -## knitr, devtools - -SparkR documentation is written using `roxygen2` and we use `knitr`, `devtools` to generate -documentation. To install these packages you can run `install.packages(c("knitr", "devtools"))` from a -R console. ## API Docs (Scaladoc, Sphinx, roxygen2) diff --git a/docs/_config.yml b/docs/_config.yml index c0e031a83ba9..2c70b76be8b7 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -14,10 +14,10 @@ include: # These allow the documentation to be updated with newer releases # of Spark, Scala, and Mesos. -SPARK_VERSION: 1.5.0-SNAPSHOT -SPARK_VERSION_SHORT: 1.5.0 +SPARK_VERSION: 1.6.0-SNAPSHOT +SPARK_VERSION_SHORT: 1.6.0 SCALA_BINARY_VERSION: "2.10" -SCALA_VERSION: "2.10.4" +SCALA_VERSION: "2.10.5" MESOS_VERSION: 0.21.0 SPARK_ISSUE_TRACKER_URL: https://issues.apache.org/jira/browse/SPARK SPARK_GITHUB_URL: https://github.com/apache/spark diff --git a/docs/_data/menu-ml.yaml b/docs/_data/menu-ml.yaml new file mode 100644 index 000000000000..2eea9a917a4c --- /dev/null +++ b/docs/_data/menu-ml.yaml @@ -0,0 +1,10 @@ +- text: "Overview: estimators, transformers and pipelines" + url: ml-guide.html +- text: Extracting, transforming and selecting features + url: ml-features.html +- text: Classification and Regression + url: ml-classification-regression.html +- text: Clustering + url: ml-clustering.html +- text: Advanced topics + url: ml-advanced.html diff --git a/docs/_data/menu-mllib.yaml b/docs/_data/menu-mllib.yaml new file mode 100644 index 000000000000..12d22abd5282 --- /dev/null +++ b/docs/_data/menu-mllib.yaml @@ -0,0 +1,75 @@ +- text: Data types + url: mllib-data-types.html +- text: Basic statistics + url: mllib-statistics.html + subitems: + - text: Summary statistics + url: mllib-statistics.html#summary-statistics + - text: Correlations + url: mllib-statistics.html#correlations + - text: Stratified sampling + url: mllib-statistics.html#stratified-sampling + - text: Hypothesis testing + url: mllib-statistics.html#hypothesis-testing + - text: Random data generation + url: mllib-statistics.html#random-data-generation +- text: Classification and regression + url: mllib-classification-regression.html + subitems: + - text: Linear models (SVMs, logistic regression, linear regression) + url: mllib-linear-methods.html + - text: Naive Bayes + url: mllib-naive-bayes.html + - text: decision 
trees + url: mllib-decision-tree.html + - text: ensembles of trees (Random Forests and Gradient-Boosted Trees) + url: mllib-ensembles.html + - text: isotonic regression + url: mllib-isotonic-regression.html +- text: Collaborative filtering + url: mllib-collaborative-filtering.html + subitems: + - text: alternating least squares (ALS) + url: mllib-collaborative-filtering.html#collaborative-filtering +- text: Clustering + url: mllib-clustering.html + subitems: + - text: k-means + url: mllib-clustering.html#k-means + - text: Gaussian mixture + url: mllib-clustering.html#gaussian-mixture + - text: power iteration clustering (PIC) + url: mllib-clustering.html#power-iteration-clustering-pic + - text: latent Dirichlet allocation (LDA) + url: mllib-clustering.html#latent-dirichlet-allocation-lda + - text: streaming k-means + url: mllib-clustering.html#streaming-k-means +- text: Dimensionality reduction + url: mllib-dimensionality-reduction.html + subitems: + - text: singular value decomposition (SVD) + url: mllib-dimensionality-reduction.html#singular-value-decomposition-svd + - text: principal component analysis (PCA) + url: mllib-dimensionality-reduction.html#principal-component-analysis-pca +- text: Feature extraction and transformation + url: mllib-feature-extraction.html +- text: Frequent pattern mining + url: mllib-frequent-pattern-mining.html + subitems: + - text: FP-growth + url: mllib-frequent-pattern-mining.html#fp-growth + - text: association rules + url: mllib-frequent-pattern-mining.html#association-rules + - text: PrefixSpan + url: mllib-frequent-pattern-mining.html#prefix-span +- text: Evaluation metrics + url: mllib-evaluation-metrics.html +- text: PMML model export + url: mllib-pmml-model-export.html +- text: Optimization (developer) + url: mllib-optimization.html + subitems: + - text: stochastic gradient descent + url: mllib-optimization.html#stochastic-gradient-descent-sgd + - text: limited-memory BFGS (L-BFGS) + url: mllib-optimization.html#limited-memory-bfgs-l-bfgs diff --git a/docs/_includes/nav-left-wrapper-ml.html b/docs/_includes/nav-left-wrapper-ml.html new file mode 100644 index 000000000000..e2d7eda027c6 --- /dev/null +++ b/docs/_includes/nav-left-wrapper-ml.html @@ -0,0 +1,8 @@ +
    +
    +

    spark.ml package

    + {% include nav-left.html nav=include.nav-ml %} +

    spark.mllib package

    + {% include nav-left.html nav=include.nav-mllib %} +
    +
    \ No newline at end of file diff --git a/docs/_includes/nav-left.html b/docs/_includes/nav-left.html new file mode 100644 index 000000000000..73176f413255 --- /dev/null +++ b/docs/_includes/nav-left.html @@ -0,0 +1,17 @@ +{% assign navurl = page.url | remove: 'index.html' %} + diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index b4952fe97ca0..3089474c1338 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -1,3 +1,4 @@ + @@ -71,7 +72,7 @@
  • Spark Programming Guide
  • Spark Streaming
  • -
  • DataFrames and SQL
  • +
  • DataFrames, Datasets and SQL
  • MLlib (Machine Learning)
  • GraphX (Graph Processing)
  • Bagel (Pregel on Spark)
  • @@ -112,7 +113,6 @@
  • Job Scheduling
  • Security
  • Hardware Provisioning
  • -
  • 3rd-Party Hadoop Distros
  • Building Spark
  • Contributing to Spark
  • @@ -125,16 +125,36 @@
    -
    - {% if page.displayTitle %} -

    {{ page.displayTitle }}

    - {% else %} -

    {{ page.title }}

    - {% endif %} +
    + + {% if page.url contains "/ml" %} + {% include nav-left-wrapper-ml.html nav-mllib=site.data.menu-mllib nav-ml=site.data.menu-ml %} + + +
    + {% if page.displayTitle %} +

    {{ page.displayTitle }}

    + {% else %} +

    {{ page.title }}

    + {% endif %} - {{ content }} + {{ content }} + +
    + {% else %} +
    + {% if page.displayTitle %} +

    {{ page.displayTitle }}

    + {% else %} +

    {{ page.title }}

    + {% endif %} -
    + {{ content }} + +
    + {% endif %} + +
    diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 15ceda11a8a8..174c202e3791 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -26,12 +26,15 @@ curr_dir = pwd cd("..") - puts "Running 'build/sbt -Pkinesis-asl compile unidoc' from " + pwd + "; this may take a few minutes..." - puts `build/sbt -Pkinesis-asl compile unidoc` + puts "Running 'build/sbt -Pkinesis-asl clean compile unidoc' from " + pwd + "; this may take a few minutes..." + system("build/sbt -Pkinesis-asl clean compile unidoc") || raise("Unidoc generation failed") puts "Moving back into docs dir." cd("docs") + puts "Removing old docs" + puts `rm -rf api` + # Copy over the unified ScalaDoc for all projects to api/scala. # This directory will be copied over to _site when `jekyll` command is run. source = "../target/scala-2.10/unidoc" @@ -114,7 +117,7 @@ puts "Moving to python/docs directory and building sphinx." cd("../python/docs") - puts `make html` + system("make html") || raise("Python doc generation failed") puts "Moving back into home dir." cd("../../") @@ -128,7 +131,7 @@ # Build SparkR API docs puts "Moving to R directory and building roxygen docs." cd("R") - puts `./create-docs.sh` + system("./create-docs.sh") || raise("R doc generation failed") puts "Moving back into home dir." cd("../") diff --git a/docs/_plugins/include_example.rb b/docs/_plugins/include_example.rb new file mode 100644 index 000000000000..f7485826a762 --- /dev/null +++ b/docs/_plugins/include_example.rb @@ -0,0 +1,103 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +require 'liquid' +require 'pygments' + +module Jekyll + class IncludeExampleTag < Liquid::Tag + + def initialize(tag_name, markup, tokens) + @markup = markup + super + end + + def render(context) + site = context.registers[:site] + config_dir = '../examples/src/main' + @code_dir = File.join(site.source, config_dir) + + clean_markup = @markup.strip + @file = File.join(@code_dir, clean_markup) + @lang = clean_markup.split('.').last + + code = File.open(@file).read.encode("UTF-8") + code = select_lines(code) + + rendered_code = Pygments.highlight(code, :lexer => @lang) + + hint = "
    Find full example code at " \ + "\"examples/src/main/#{clean_markup}\" in the Spark repo.
    " + + rendered_code + hint + end + + # Trim the code block so as to have the same indention, regardless of their positions in the + # code file. + def trim_codeblock(lines) + # Select the minimum indention of the current code block. + min_start_spaces = lines + .select { |l| l.strip.size !=0 } + .map { |l| l[/\A */].size } + .min + + lines.map { |l| l.strip.size == 0 ? l : l[min_start_spaces .. -1] } + end + + # Select lines according to labels in code. Currently we use "$example on$" and "$example off$" + # as labels. Note that code blocks identified by the labels should not overlap. + def select_lines(code) + lines = code.each_line.to_a + + # Select the array of start labels from code. + startIndices = lines + .each_with_index + .select { |l, i| l.include? "$example on$" } + .map { |l, i| i } + + # Select the array of end labels from code. + endIndices = lines + .each_with_index + .select { |l, i| l.include? "$example off$" } + .map { |l, i| i } + + raise "Start indices amount is not equal to end indices amount, see #{@file}." \ + unless startIndices.size == endIndices.size + + raise "No code is selected by include_example, see #{@file}." \ + if startIndices.size == 0 + + # Select and join code blocks together, with a space line between each of two continuous + # blocks. + lastIndex = -1 + result = "" + startIndices.zip(endIndices).each do |start, endline| + raise "Overlapping between two example code blocks are not allowed, see #{@file}." \ + if start <= lastIndex + raise "$example on$ should not be in the same line with $example off$, see #{@file}." \ + if start == endline + lastIndex = endline + range = Range.new(start + 1, endline - 1) + result += trim_codeblock(lines[range]).join + result += "\n" + end + result + end + end +end + +Liquid::Template.register_tag('include_example', Jekyll::IncludeExampleTag) diff --git a/docs/bagel-programming-guide.md b/docs/bagel-programming-guide.md index c2fe6b0e286c..347ca4a7af98 100644 --- a/docs/bagel-programming-guide.md +++ b/docs/bagel-programming-guide.md @@ -4,7 +4,7 @@ displayTitle: Bagel Programming Guide title: Bagel --- -**Bagel will soon be superseded by [GraphX](graphx-programming-guide.html); we recommend that new users try GraphX instead.** +**Bagel is deprecated, and superseded by [GraphX](graphx-programming-guide.html).** Bagel is a Spark implementation of Google's [Pregel](http://portal.acm.org/citation.cfm?id=1807184) graph processing framework. Bagel currently supports basic graph computation, combiners, and aggregators. @@ -157,11 +157,3 @@ trait Message[K] { def targetId: K } {% endhighlight %} - -# Where to Go from Here - -Two example jobs, PageRank and shortest path, are included in `examples/src/main/scala/org/apache/spark/examples/bagel`. You can run them by passing the class name to the `bin/run-example` script included in Spark; e.g.: - - ./bin/run-example org.apache.spark.examples.bagel.WikipediaPageRank - -Each example program prints usage help when run without any arguments. diff --git a/docs/building-spark.md b/docs/building-spark.md index a5da3b39502e..3d38edbdad4b 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -7,7 +7,8 @@ redirect_from: "building-with-maven.html" * This will become a table of contents (this text will be scraped). {:toc} -Building Spark using Maven requires Maven 3.0.4 or newer and Java 7+. +Building Spark using Maven requires Maven 3.3.3 or newer and Java 7+. +The Spark build can supply a suitable Maven binary; see below. 
# Building with `build/mvn` @@ -37,7 +38,7 @@ To create a Spark distribution like those distributed by the to be runnable, use `make-distribution.sh` in the project root directory. It can be configured with Maven profile settings and so on like the direct Maven build. Example: - ./make-distribution.sh --name custom-spark --tgz -Phadoop-2.4 -Pyarn + ./make-distribution.sh --name custom-spark --tgz -Psparkr -Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn For more information on usage, run `./make-distribution.sh --help` @@ -60,12 +61,13 @@ If you don't run this, you may see errors like the following: You can fix this by setting the `MAVEN_OPTS` variable as discussed before. **Note:** -* *For Java 8 and above this step is not required.* -* *If using `build/mvn` and `MAVEN_OPTS` were not already set, the script will automate this for you.* + +* For Java 8 and above this step is not required. +* If using `build/mvn` with no `MAVEN_OPTS` set, the script will automate this for you. # Specifying the Hadoop Version -Because HDFS is not protocol-compatible across versions, if you want to read from HDFS, you'll need to build Spark against the specific HDFS version in your environment. You can do this through the "hadoop.version" property. If unset, Spark will build against Hadoop 2.2.0 by default. Note that certain build profiles are required for particular Hadoop versions: +Because HDFS is not protocol-compatible across versions, if you want to read from HDFS, you'll need to build Spark against the specific HDFS version in your environment. You can do this through the `hadoop.version` property. If unset, Spark will build against Hadoop 2.2.0 by default. Note that certain build profiles are required for particular Hadoop versions: @@ -90,7 +92,7 @@ mvn -Dhadoop.version=1.2.1 -Phadoop-1 -DskipTests clean package mvn -Dhadoop.version=2.0.0-mr1-cdh4.2.0 -Phadoop-1 -DskipTests clean package {% endhighlight %} -You can enable the "yarn" profile and optionally set the "yarn.version" property if it is different from "hadoop.version". Spark only supports YARN versions 2.2.0 and later. +You can enable the `yarn` profile and optionally set the `yarn.version` property if it is different from `hadoop.version`. Spark only supports YARN versions 2.2.0 and later. Examples: @@ -124,7 +126,7 @@ mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -Phive-thriftserver -Dskip # Building for Scala 2.11 To produce a Spark package compiled with Scala 2.11, use the `-Dscala-2.11` property: - dev/change-scala-version.sh 2.11 + ./dev/change-scala-version.sh 2.11 mvn -Pyarn -Phadoop-2.4 -Dscala-2.11 -DskipTests clean package Spark does not yet support its JDBC component for Scala 2.11. @@ -142,6 +144,17 @@ The ScalaTest plugin also supports running only a specific test suite as follows mvn -Dhadoop.version=... -DwildcardSuites=org.apache.spark.repl.ReplSuite test +# Building submodules individually + +It's possible to build Spark sub-modules using the `mvn -pl` option. + +For instance, you can build the Spark Streaming module using: + +{% highlight bash %} +mvn -pl :spark-streaming_2.10 clean install +{% endhighlight %} + +where `spark-streaming_2.10` is the `artifactId` as defined in `streaming/pom.xml` file. # Continuous Compilation @@ -162,11 +175,9 @@ the `spark-parent` module). 
Thus, the full flow for running continuous-compilation of the `core` submodule may look more like: -``` - $ mvn install - $ cd core - $ mvn scala:cc -``` + $ mvn install + $ cd core + $ mvn scala:cc # Building Spark with IntelliJ IDEA or Eclipse @@ -179,6 +190,10 @@ Running only Java 8 tests and nothing else. mvn install -DskipTests -Pjava8-tests +or + + sbt -Pjava8-tests java8-tests/test + Java 8 tests are run when `-Pjava8-tests` profile is enabled, they will run in spite of `-DskipTests`. For these tests to run your system must have a JDK 8 installation. If you have JDK 8 installed but it is not the system default, you can set JAVA_HOME to point to JDK 8 before running the tests. @@ -192,11 +207,11 @@ then ship it over to the cluster. We are investigating the exact cause for this. # Packaging without Hadoop Dependencies for YARN -The assembly jar produced by `mvn package` will, by default, include all of Spark's dependencies, including Hadoop and some of its ecosystem projects. On YARN deployments, this causes multiple versions of these to appear on executor classpaths: the version packaged in the Spark assembly and the version on each node, included with yarn.application.classpath. The `hadoop-provided` profile builds the assembly without including Hadoop-ecosystem projects, like ZooKeeper and Hadoop itself. +The assembly jar produced by `mvn package` will, by default, include all of Spark's dependencies, including Hadoop and some of its ecosystem projects. On YARN deployments, this causes multiple versions of these to appear on executor classpaths: the version packaged in the Spark assembly and the version on each node, included with `yarn.application.classpath`. The `hadoop-provided` profile builds the assembly without including Hadoop-ecosystem projects, like ZooKeeper and Hadoop itself. # Building with SBT -Maven is the official recommendation for packaging Spark, and is the "build of reference". +Maven is the official build tool recommended for packaging Spark, and is the *build of reference*. But SBT is supported for day-to-day development since it can provide much faster iterative compilation. More advanced developers may wish to use SBT. @@ -205,6 +220,11 @@ can be set to control the SBT build. For example: build/sbt -Pyarn -Phadoop-2.3 assembly +To avoid the overhead of launching sbt each time you need to re-compile, you can launch sbt +in interactive mode by running `build/sbt`, and then run all build commands at the command +prompt. For more recommendations on reducing build time, refer to the +[wiki page](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools#UsefulDeveloperTools-ReducingBuildTimes). + # Testing with SBT Some of the tests require Spark to be packaged first, so always run `build/sbt assembly` the first time. The following is an example of a correct (build, test) sequence: diff --git a/docs/cluster-overview.md b/docs/cluster-overview.md index 7079de546e2f..faaf154d243f 100644 --- a/docs/cluster-overview.md +++ b/docs/cluster-overview.md @@ -5,18 +5,19 @@ title: Cluster Mode Overview This document gives a short overview of how Spark runs on clusters, to make it easier to understand the components involved. Read through the [application submission guide](submitting-applications.html) -to submit applications to a cluster. +to learn about launching applications on a cluster. 
# Components -Spark applications run as independent sets of processes on a cluster, coordinated by the SparkContext +Spark applications run as independent sets of processes on a cluster, coordinated by the `SparkContext` object in your main program (called the _driver program_). + Specifically, to run on a cluster, the SparkContext can connect to several types of _cluster managers_ -(either Spark's own standalone cluster manager or Mesos/YARN), which allocate resources across +(either Spark's own standalone cluster manager, Mesos or YARN), which allocate resources across applications. Once connected, Spark acquires *executors* on nodes in the cluster, which are processes that run computations and store data for your application. Next, it sends your application code (defined by JAR or Python files passed to SparkContext) to -the executors. Finally, SparkContext sends *tasks* for the executors to run. +the executors. Finally, SparkContext sends *tasks* to the executors to run.
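A rough Scala sketch of the driver/executor flow this paragraph describes — not part of the patch itself; the application name and master URL below are placeholders:

```scala
import org.apache.spark.{SparkConf, SparkContext}

object ClusterOverviewSketch {
  def main(args: Array[String]): Unit = {
    // The master URL selects the cluster manager (standalone, Mesos or YARN);
    // "spark://host:7077" is only a placeholder here.
    val conf = new SparkConf()
      .setAppName("cluster-overview-sketch")
      .setMaster("spark://host:7077")
    val sc = new SparkContext(conf)

    // Calling an action breaks the job into tasks, which the SparkContext
    // sends to the executors acquired on the worker nodes.
    val total = sc.parallelize(1L to 1000000L).map(_ * 2).reduce(_ + _)
    println(s"total = $total")

    sc.stop()
  }
}
```

In client mode this program runs where it is launched and must stay network-addressable to its executors, as the architecture notes above point out.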

    Spark cluster components @@ -33,9 +34,9 @@ There are several useful things to note about this architecture: 2. Spark is agnostic to the underlying cluster manager. As long as it can acquire executor processes, and these communicate with each other, it is relatively easy to run it even on a cluster manager that also supports other applications (e.g. Mesos/YARN). -3. The driver program must listen for and accept incoming connections from its executors throughout - its lifetime (e.g., see [spark.driver.port and spark.fileserver.port in the network config - section](configuration.html#networking)). As such, the driver program must be network +3. The driver program must listen for and accept incoming connections from its executors throughout + its lifetime (e.g., see [spark.driver.port and spark.fileserver.port in the network config + section](configuration.html#networking)). As such, the driver program must be network addressable from the worker nodes. 4. Because the driver schedules tasks on the cluster, it should be run close to the worker nodes, preferably on the same local area network. If you'd like to send requests to the diff --git a/docs/configuration.md b/docs/configuration.md index 24b606356a14..38d3d059f9d3 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -34,21 +34,21 @@ val conf = new SparkConf() val sc = new SparkContext(conf) {% endhighlight %} -Note that we can have more than 1 thread in local mode, and in cases like Spark Streaming, we may -actually require one to prevent any sort of starvation issues. +Note that we can have more than 1 thread in local mode, and in cases like Spark Streaming, we may +actually require more than 1 thread to prevent any sort of starvation issues. -Properties that specify some time duration should be configured with a unit of time. +Properties that specify some time duration should be configured with a unit of time. The following format is accepted: - + 25ms (milliseconds) 5s (seconds) 10m or 10min (minutes) 3h (hours) 5d (days) 1y (years) - - -Properties that specify a byte size should be configured with a unit of size. + + +Properties that specify a byte size should be configured with a unit of size. The following format is accepted: 1b (bytes) @@ -69,7 +69,7 @@ val sc = new SparkContext(new SparkConf()) Then, you can supply configuration values at runtime: {% highlight bash %} -./bin/spark-submit --name "My app" --master local[4] --conf spark.shuffle.spill=false +./bin/spark-submit --name "My app" --master local[4] --conf spark.eventLog.enabled=false --conf "spark.executor.extraJavaOptions=-XX:+PrintGCDetails -XX:+PrintGCTimeStamps" myApp.jar {% endhighlight %} @@ -140,7 +140,7 @@ of the most common options to set are:

    + + + + +
    Amount of memory to use for the driver process, i.e. where SparkContext is initialized. (e.g. 1g, 2g). - +
    Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. Instead, please set this through the --driver-memory command line option @@ -192,6 +192,15 @@ of the most common options to set are: allowed master URL's.
+<tr>
+  <td><code>spark.submit.deployMode</code></td>
+  <td>(none)</td>
+  <td>
+    The deploy mode of the Spark driver program, either "client" or "cluster",
+    which determines whether the driver is launched locally ("client")
+    or remotely ("cluster") on one of the nodes inside the cluster.
+  </td>
+</tr>
</table>
    Apart from these, the following properties are also available, and may be useful in some situations: @@ -207,7 +216,7 @@ Apart from these, the following properties are also available, and may be useful
    Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. - Instead, please set this through the --driver-class-path command line option or in + Instead, please set this through the --driver-class-path command line option or in your default properties file. @@ -216,10 +225,10 @@ Apart from these, the following properties are also available, and may be useful (none) A string of extra JVM options to pass to the driver. For instance, GC settings or other logging. - +
    Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. - Instead, please set this through the --driver-java-options command line option or in + Instead, please set this through the --driver-java-options command line option or in your default properties file. @@ -228,10 +237,10 @@ Apart from these, the following properties are also available, and may be useful (none) Set a special library path to use when launching the driver JVM. - +
    Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. - Instead, please set this through the --driver-library-path command line option or in + Instead, please set this through the --driver-library-path command line option or in your default properties file. @@ -242,7 +251,7 @@ Apart from these, the following properties are also available, and may be useful (Experimental) Whether to give user-added jars precedence over Spark's own jars when loading classes in the the driver. This feature can be used to mitigate conflicts between Spark's dependencies and user dependencies. It is currently an experimental feature. - + This is used in cluster mode only. @@ -250,8 +259,8 @@ Apart from these, the following properties are also available, and may be useful spark.executor.extraClassPath (none) - Extra classpath entries to prepend to the classpath of executors. This exists primarily for - backwards-compatibility with older versions of Spark. Users typically should not need to set + Extra classpath entries to prepend to the classpath of executors. This exists primarily for + backwards-compatibility with older versions of Spark. Users typically should not need to set this option. @@ -259,9 +268,9 @@ Apart from these, the following properties are also available, and may be useful spark.executor.extraJavaOptions (none) - A string of extra JVM options to pass to executors. For instance, GC settings or other logging. - Note that it is illegal to set Spark properties or heap size settings with this option. Spark - properties should be set using a SparkConf object or the spark-defaults.conf file used with the + A string of extra JVM options to pass to executors. For instance, GC settings or other logging. + Note that it is illegal to set Spark properties or heap size settings with this option. Spark + properties should be set using a SparkConf object or the spark-defaults.conf file used with the spark-submit script. Heap size settings can be set with spark.executor.memory. @@ -305,7 +314,7 @@ Apart from these, the following properties are also available, and may be useful daily Set the time interval by which the executor logs will be rolled over. - Rolling is disabled by default. Valid values are `daily`, `hourly`, `minutely` or + Rolling is disabled by default. Valid values are daily, hourly, minutely or any interval in seconds. See spark.executor.logs.rolling.maxRetainedFiles for automatic cleaning of old logs. @@ -330,13 +339,13 @@ Apart from these, the following properties are also available, and may be useful spark.python.profile false - Enable profiling in Python worker, the profile result will show up by `sc.show_profiles()`, + Enable profiling in Python worker, the profile result will show up by sc.show_profiles(), or it will be displayed before the driver exiting. It also can be dumped into disk by - `sc.dump_profiles(path)`. If some of the profile results had been displayed manually, + sc.dump_profiles(path). If some of the profile results had been displayed manually, they will not be displayed automatically before driver exiting. - By default the `pyspark.profiler.BasicProfiler` will be used, but this can be overridden by - passing a profiler class in as a parameter to the `SparkContext` constructor. + By default the pyspark.profiler.BasicProfiler will be used, but this can be overridden by + passing a profiler class in as a parameter to the SparkContext constructor. 
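As a side note to the spark.executor.extraJavaOptions guidance above: JVM flags such as GC logging belong in that property (or in spark-defaults.conf), while heap size must be set through spark.executor.memory. A hedged sketch, with illustrative values only and no claim that these flags suit any particular workload:

```scala
import org.apache.spark.SparkConf

object ExecutorOptionsSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("executor-options-sketch")
      // GC/logging flags are legal here; heap size settings are not.
      .set("spark.executor.extraJavaOptions", "-XX:+PrintGCDetails -XX:+PrintGCTimeStamps")
      // Heap size is configured separately.
      .set("spark.executor.memory", "2g")
    println(conf.toDebugString)
  }
}
```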
@@ -382,16 +391,6 @@ Apart from these, the following properties are also available, and may be useful overhead per reduce task, so keep it small unless you have a large amount of memory. - - spark.shuffle.blockTransferService - netty - - Implementation to use for transferring shuffle and cached blocks between executors. There - are two implementations available: netty and nio. Netty-based - block transfer is intended to be simpler but equally efficient and is the default option - starting in 1.2. - - spark.shuffle.compress true @@ -400,16 +399,6 @@ Apart from these, the following properties are also available, and may be useful spark.io.compression.codec. - - spark.shuffle.consolidateFiles - false - - If set to "true", consolidates intermediate files created during a shuffle. Creating fewer - files can improve filesystem performance for shuffles with large numbers of reduce tasks. It - is recommended to set this to "true" when using ext4 or xfs filesystems. On ext3, this option - might degrade performance on machines with many (>8) cores due to filesystem limitations. - - spark.shuffle.file.buffer 32k @@ -458,35 +447,35 @@ Apart from these, the following properties are also available, and may be useful sort Implementation to use for shuffling data. There are two implementations available: - sort and hash. Sort-based shuffle is more memory-efficient and is - the default option starting in 1.2. + sort and hash. + Sort-based shuffle is more memory-efficient and is the default option starting in 1.2. - spark.shuffle.memoryFraction - 0.2 + spark.shuffle.service.enabled + false - Fraction of Java heap to use for aggregation and cogroups during shuffles, if - spark.shuffle.spill is true. At any given time, the collective size of - all in-memory maps used for shuffles is bounded by this limit, beyond which the contents will - begin to spill to disk. If spills are often, consider increasing this value at the expense of - spark.storage.memoryFraction. + Enables the external shuffle service. This service preserves the shuffle files written by + executors so the executors can be safely removed. This must be enabled if + spark.dynamicAllocation.enabled is "true". The external shuffle service + must be set up in order to enable it. See + dynamic allocation + configuration and setup documentation for more information. - spark.shuffle.sort.bypassMergeThreshold - 200 + spark.shuffle.service.port + 7337 - (Advanced) In the sort-based shuffle manager, avoid merge-sorting data if there is no - map-side aggregation and there are at most this many reduce partitions. + Port on which the external shuffle service will run. - spark.shuffle.spill - true + spark.shuffle.sort.bypassMergeThreshold + 200 - If set to "true", limits the amount of memory used during reduces by spilling data out to disk. - This spilling threshold is specified by spark.shuffle.memoryFraction. + (Advanced) In the sort-based shuffle manager, avoid merge-sorting data if there is no + map-side aggregation and there are at most this many reduce partitions. @@ -571,6 +560,20 @@ Apart from these, the following properties are also available, and may be useful How many finished drivers the Spark UI and status APIs remember before garbage collecting. + + spark.sql.ui.retainedExecutions + 1000 + + How many finished executions the Spark UI and status APIs remember before garbage collecting. + + + + spark.streaming.ui.retainedBatches + 1000 + + How many finished batches the Spark UI and status APIs remember before garbage collecting. 
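The spark.shuffle.service.enabled entry above requires the external shuffle service whenever spark.dynamicAllocation.enabled is true, so that shuffle files outlive removed executors. A minimal sketch of setting the two together on a SparkConf — the object name and values are illustrative, not taken from this patch:

```scala
import org.apache.spark.SparkConf

object ShuffleServiceSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("shuffle-service-sketch")
      .set("spark.dynamicAllocation.enabled", "true")
      // Required with dynamic allocation: the service preserves shuffle files
      // written by executors so the executors can be removed safely.
      .set("spark.shuffle.service.enabled", "true")
      .set("spark.shuffle.service.port", "7337") // the documented default port
    println(conf.toDebugString)
  }
}
```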
+ + #### Compression and Serialization @@ -653,10 +656,10 @@ Apart from these, the following properties are also available, and may be useful spark.kryo.registrator (none) - If you use Kryo serialization, set this class to register your custom classes with Kryo. This + If you use Kryo serialization, give a comma-separated list of classes that register your custom classes with Kryo. This property is useful if you need to register your classes in a custom way, e.g. to specify a custom field serializer. Otherwise spark.kryo.classesToRegister is simpler. It should be - set to a class that extends + set to classes that extend KryoRegistrator. See the tuning guide for more details. @@ -718,6 +721,95 @@ Apart from these, the following properties are also available, and may be useful +#### Memory Management + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+<table class="table">
+<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
+<tr>
+  <td><code>spark.memory.fraction</code></td>
+  <td>0.75</td>
+  <td>
+    Fraction of (heap space - 300MB) used for execution and storage. The lower this is, the
+    more frequently spills and cached data eviction occur. The purpose of this config is to set
+    aside memory for internal metadata, user data structures, and imprecise size estimation
+    in the case of sparse, unusually large records. Leaving this at the default value is
+    recommended. For more detail, see this description.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.memory.storageFraction</code></td>
+  <td>0.5</td>
+  <td>
+    Amount of storage memory immune to eviction, expressed as a fraction of the size of the
+    region set aside by <code>spark.memory.fraction</code>. The higher this is, the less
+    working memory may be available to execution and tasks may spill to disk more often.
+    Leaving this at the default value is recommended. For more detail, see this description.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.memory.offHeap.enabled</code></td>
+  <td>true</td>
+  <td>
+    If true, Spark will attempt to use off-heap memory for certain operations. If off-heap
+    memory use is enabled, then <code>spark.memory.offHeap.size</code> must be positive.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.memory.offHeap.size</code></td>
+  <td>0</td>
+  <td>
+    The absolute amount of memory which can be used for off-heap allocation.
+    This setting has no impact on heap memory usage, so if your executors' total memory
+    consumption must fit within some hard limit then be sure to shrink your JVM heap size
+    accordingly. This must be set to a positive value when
+    <code>spark.memory.offHeap.enabled=true</code>.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.memory.useLegacyMode</code></td>
+  <td>false</td>
+  <td>
+    Whether to enable the legacy memory management mode used in Spark 1.5 and before.
+    The legacy mode rigidly partitions the heap space into fixed-size regions,
+    potentially leading to excessive spilling if the application was not tuned.
+    The following deprecated memory fraction configurations are not read unless this is enabled:
+    <code>spark.shuffle.memoryFraction</code><br>
+    <code>spark.storage.memoryFraction</code><br>
+    <code>spark.storage.unrollFraction</code>
+  </td>
+</tr>
+<tr>
+  <td><code>spark.shuffle.memoryFraction</code></td>
+  <td>0.2</td>
+  <td>
+    (deprecated) This is read only if <code>spark.memory.useLegacyMode</code> is enabled.
+    Fraction of Java heap to use for aggregation and cogroups during shuffles.
+    At any given time, the collective size of
+    all in-memory maps used for shuffles is bounded by this limit, beyond which the contents will
+    begin to spill to disk. If spills are often, consider increasing this value at the expense of
+    <code>spark.storage.memoryFraction</code>.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.storage.memoryFraction</code></td>
+  <td>0.6</td>
+  <td>
+    (deprecated) This is read only if <code>spark.memory.useLegacyMode</code> is enabled.
+    Fraction of Java heap to use for Spark's memory cache. This should not be larger than the "old"
+    generation of objects in the JVM, which by default is given 0.6 of the heap, but you can
+    increase it if you configure your own old generation size.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.storage.unrollFraction</code></td>
+  <td>0.2</td>
+  <td>
+    (deprecated) This is read only if <code>spark.memory.useLegacyMode</code> is enabled.
+    Fraction of <code>spark.storage.memoryFraction</code> to use for unrolling blocks in memory.
+    This is dynamically allocated by dropping existing blocks when there is not enough free
+    storage space to unroll the new block in its entirety.
+  </td>
+</tr>
    + #### Execution Behavior @@ -753,9 +845,9 @@ Apart from these, the following properties are also available, and may be useful @@ -830,15 +922,6 @@ Apart from these, the following properties are also available, and may be useful This setting is ignored for jobs generated through Spark Streaming's StreamingContext, since data may need to be rewritten to pre-existing output directories during checkpoint recovery. - - - - - @@ -848,15 +931,6 @@ Apart from these, the following properties are also available, and may be useful mapping has high overhead for blocks close to or below the page size of the operating system. - - - - - @@ -886,21 +960,11 @@ Apart from these, the following properties are also available, and may be useful #### Networking
<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
    1 in YARN mode, all the available cores on the worker in standalone mode. The number of cores to use on each executor. For YARN and standalone mode only. - - In standalone mode, setting this parameter allows an application to run multiple executors on - the same worker, provided that there are enough cores on that worker. Otherwise, only one + + In standalone mode, setting this parameter allows an application to run multiple executors on + the same worker, provided that there are enough cores on that worker. Otherwise, only one executor per application will run on each worker.
    spark.storage.memoryFraction0.6 - Fraction of Java heap to use for Spark's memory cache. This should not be larger than the "old" - generation of objects in the JVM, which by default is given 0.6 of the heap, but you can - increase it if you configure your own old generation size. -
    spark.storage.memoryMapThreshold 2m
    spark.storage.unrollFraction0.2 - Fraction of spark.storage.memoryFraction to use for unrolling blocks in memory. - This is dynamically allocated by dropping existing blocks when there is not enough free - storage space to unroll the new block in its entirety. -
    spark.externalBlockStore.blockManager org.apache.spark.storage.TachyonBlockManager
    - - - - - @@ -909,14 +973,14 @@ Apart from these, the following properties are also available, and may be useful @@ -925,9 +989,9 @@ Apart from these, the following properties are also available, and may be useful @@ -981,6 +1045,7 @@ Apart from these, the following properties are also available, and may be useful @@ -988,13 +1053,14 @@ Apart from these, the following properties are also available, and may be useful - - - - - @@ -1042,7 +1104,7 @@ Apart from these, the following properties are also available, and may be useful
<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
    spark.akka.failure-detector.threshold300.0 - This is set to a larger value to disable failure detector that comes inbuilt akka. It can be - enabled again, if you plan to use this feature (Not recommended). This maps to akka's - `akka.remote.transport-failure-detector.threshold`. Tune this in combination of - `spark.akka.heartbeat.pauses` and `spark.akka.heartbeat.interval` if you need to. -
    spark.akka.frameSize 128 - Maximum message size to allow in "control plane" communication; generally only applies to map + Maximum message size (in MB) to allow in "control plane" communication; generally only applies to map output size information sent between executors and the driver. Increase this if you are running jobs with many thousands of map and reduce tasks and see messages about the frame size. spark.akka.heartbeat.interval 1000s - This is set to a larger value to disable the transport failure detector that comes built in to - Akka. It can be enabled again, if you plan to use this feature (Not recommended). A larger - interval value reduces network overhead and a smaller value ( ~ 1 s) might be more - informative for Akka's failure detector. Tune this in combination of `spark.akka.heartbeat.pauses` - if you need to. A likely positive use case for using failure detector would be: a sensistive - failure detector can help evict rogue executors quickly. However this is usually not the case - as GC pauses and network lags are expected in a real Spark cluster. Apart from that enabling - this leads to a lot of exchanges of heart beats between nodes leading to flooding the network + This is set to a larger value to disable the transport failure detector that comes built in to + Akka. It can be enabled again, if you plan to use this feature (Not recommended). A larger + interval value reduces network overhead and a smaller value ( ~ 1 s) might be more + informative for Akka's failure detector. Tune this in combination of spark.akka.heartbeat.pauses + if you need to. A likely positive use case for using failure detector would be: a sensistive + failure detector can help evict rogue executors quickly. However this is usually not the case + as GC pauses and network lags are expected in a real Spark cluster. Apart from that enabling + this leads to a lot of exchanges of heart beats between nodes leading to flooding the network with those.
    6000s This is set to a larger value to disable the transport failure detector that comes built in to Akka. - It can be enabled again, if you plan to use this feature (Not recommended). Acceptable heart + It can be enabled again, if you plan to use this feature (Not recommended). Acceptable heart beat pause for Akka. This can be used to control sensitivity to GC pauses. Tune - this along with `spark.akka.heartbeat.interval` if you need to. + this along with spark.akka.heartbeat.interval if you need to.
    (random) Port for the executor to listen on. This is used for communicating with the driver. + This is only relevant when using the Akka RPC backend.
    (random) Port for the driver's HTTP file server to listen on. + This is only relevant when using the Akka RPC backend.
    spark.network.timeout 120s - Default timeout for all network interactions. This config will be used in place of + Default timeout for all network interactions. This config will be used in place of spark.core.connection.ack.wait.timeout, spark.akka.timeout, spark.storage.blockManagerSlaveTimeoutMs, spark.shuffle.io.connectionTimeout, spark.rpc.askTimeout or @@ -1005,15 +1071,11 @@ Apart from these, the following properties are also available, and may be useful spark.port.maxRetries 16 - Default maximum number of retries when binding to a port before giving up. -
    spark.replClassServer.port(random) - Port for the driver's HTTP class server to listen on. - This is only relevant for the Spark shell. + Maximum number of retries when binding to a port before giving up. + When a port is given a specific value (non 0), each subsequent retry will + increment the port used in the previous attempt by 1 before retrying. This + essentially allows it to try a range of ports from the start port specified + to port + maxRetries.
    spark.rpc.lookupTimeout 120s - Duration for an RPC remote endpoint lookup operation to wait before timing out. + Duration for an RPC remote endpoint lookup operation to wait before timing out.
@@ -1106,10 +1168,11 @@ Apart from these, the following properties are also available, and may be useful spark.scheduler.minRegisteredResourcesRatio - 0.8 for YARN mode; 0.0 otherwise + 0.8 for YARN mode; 0.0 for standalone mode and Mesos coarse-grained mode The minimum ratio of registered resources (registered resources / total expected resources) - (resources are executors in yarn mode, CPU cores in standalone mode) + (resources are executors in yarn mode, CPU cores in standalone mode and Mesos coarse-grained + mode ['spark.cores.max' value is total expected resources for Mesos coarse-grained mode] ) to wait for before scheduling begins. Specified as a double between 0.0 and 1.0. Regardless of whether the minimum ratio of resources has been reached, the maximum amount of time it will wait before scheduling begins is controlled by config @@ -1202,7 +1265,7 @@ Apart from these, the following properties are also available, and may be useful spark.dynamicAllocation.executorIdleTimeout 60s - If dynamic allocation is enabled and an executor has been idle for more than this duration, + If dynamic allocation is enabled and an executor has been idle for more than this duration, the executor will be removed. For more detail, see this description. @@ -1276,7 +1339,8 @@ Apart from these, the following properties are also available, and may be useful Comma separated list of users/administrators that have view and modify access to all Spark jobs. This can be used if you run on a shared cluster and have a set of administrators or devs who - help debug when things work. + help debug when things work. Putting a "*" in the list means any user can have the privilege + of admin. @@ -1295,6 +1359,22 @@ Apart from these, the following properties are also available, and may be useful not running on YARN and authentication is enabled. + + spark.authenticate.enableSaslEncryption + false + + Enable encrypted communication when authentication is enabled. This option is currently + only supported by the block transfer service. + + + + spark.network.sasl.serverAlwaysEncrypt + false + + Disable unencrypted connections for services that support SASL authentication. This is + currently supported by the external shuffle service. + + spark.core.connection.ack.wait.timeout 60s @@ -1317,7 +1397,8 @@ Apart from these, the following properties are also available, and may be useful Empty Comma separated list of users that have modify access to the Spark job. By default only the - user that started the Spark job has access to modify it (kill it for example). + user that started the Spark job has access to modify it (kill it for example). Putting a "*" in + the list means any user can have access to modify it. @@ -1339,7 +1420,8 @@ Apart from these, the following properties are also available, and may be useful Empty Comma separated list of users that have view access to the Spark web ui. By default only the - user that started the Spark job has view access. + user that started the Spark job has view access. Putting a "*" in the list means any user can + have view access to this Spark job. @@ -1427,6 +1509,19 @@ Apart from these, the following properties are also available, and may be useful #### Spark Streaming + + + + + @@ -1469,6 +1564,14 @@ Apart from these, the following properties are also available, and may be useful higher memory usage in Spark.
+ + + + + @@ -1508,6 +1611,20 @@ Apart from these, the following properties are also available, and may be useful Number of threads used by RBackend to handle RPC calls from SparkR package. + + + + + + + + + +
Property Name | Default | Meaning
spark.streaming.backpressure.enabled false + Enables or disables Spark Streaming's internal backpressure mechanism (since 1.5). + This enables Spark Streaming to control the receiving rate based on the + current batch scheduling delays and processing times so that the system receives + only as fast as it can process. Internally, this dynamically sets the + maximum receiving rate of receivers. This rate is upper bounded by the values + spark.streaming.receiver.maxRate and spark.streaming.kafka.maxRatePerPartition + if they are set (see below). +
    spark.streaming.blockInterval 200ms
spark.streaming.stopGracefullyOnShutdown false + If true, Spark shuts down the StreamingContext gracefully on JVM + shutdown rather than immediately. +
    spark.streaming.kafka.maxRatePerPartition not set
spark.r.command Rscript + Executable for executing R scripts in cluster modes for both driver and workers. +
spark.r.driver.command spark.r.command + Executable for executing R scripts in client modes for driver. Ignored in cluster modes. +
    #### Cluster Managers @@ -1537,11 +1654,19 @@ The following variables can be set in `spark-env.sh`: Environment VariableMeaning JAVA_HOME - Location where Java is installed (if it's not on your default `PATH`). + Location where Java is installed (if it's not on your default PATH). PYSPARK_PYTHON - Python binary executable to use for PySpark. + Python binary executable to use for PySpark in both driver and workers (default is python). + + + PYSPARK_DRIVER_PYTHON + Python binary executable to use for PySpark in driver only (default is PYSPARK_PYTHON). + + + SPARKR_DRIVER_R + R binary executable to use for SparkR shell (default is R). SPARK_LOCAL_IP @@ -1572,3 +1697,17 @@ To specify a different configuration directory other than the default "SPARK_HOM you can set SPARK_CONF_DIR. Spark will use the the configuration files (spark-defaults.conf, spark-env.sh, log4j.properties, etc) from this directory. +# Inheriting Hadoop Cluster Configuration + +If you plan to read and write from HDFS using Spark, there are two Hadoop configuration files that +should be included on Spark's classpath: + +* `hdfs-site.xml`, which provides default behaviors for the HDFS client. +* `core-site.xml`, which sets the default filesystem name. + +The location of these configuration files varies across CDH and HDP versions, but +a common location is inside of `/etc/hadoop/conf`. Some tools, such as Cloudera Manager, create +configurations on-the-fly, but offer a mechanisms to download copies of them. + +To make these files visible to Spark, set `HADOOP_CONF_DIR` in `$SPARK_HOME/spark-env.sh` +to a location containing the configuration files. diff --git a/docs/css/main.css b/docs/css/main.css index 89305a7d3a35..175e8004fca0 100755 --- a/docs/css/main.css +++ b/docs/css/main.css @@ -39,8 +39,15 @@ margin-left: 10px; } -body #content { - line-height: 1.6; /* Inspired by Github's wiki style */ +body .container-wrapper { + background-color: #FFF; + color: #1D1F22; + max-width: 1024px; + margin-top: 10px; + margin-left: auto; + margin-right: auto; + border-radius: 15px; + position: relative; } .title { @@ -74,6 +81,10 @@ code { color: #444444; } +div .highlight pre { + font-size: 12px; +} + a code { color: #0088cc; } @@ -87,6 +98,24 @@ a:hover code { max-width: 914px; } +.content { + z-index: 1; + position: relative; + background-color: #FFF; + max-width: 914px; + line-height: 1.6; /* Inspired by Github's wiki style */ + padding-left: 15px; +} + +.content-with-sidebar { + z-index: 1; + position: relative; + background-color: #FFF; + max-width: 914px; + line-height: 1.6; /* Inspired by Github's wiki style */ + padding-left: 30px; +} + .dropdown-menu { /* Remove the default 2px top margin which causes a small gap between the hover trigger area and the popup menu */ @@ -151,3 +180,110 @@ ul.nav li.dropdown ul.dropdown-menu li.dropdown-submenu ul.dropdown-menu { * AnchorJS (anchor links when hovering over headers) */ a.anchorjs-link:hover { text-decoration: none; } + + +/** + * The left navigation bar. + */ +.left-menu-wrapper { + margin-left: 0px; + margin-right: 0px; + background-color: #F0F8FC; + border-top-width: 0px; + border-left-width: 0px; + border-bottom-width: 0px; + margin-top: 0px; + width: 210px; + float: left; + position: absolute; +} + +.left-menu { + padding: 0px; + width: 199px; +} + +.left-menu h3 { + margin-left: 10px; + line-height: 30px; +} + +/** + * The collapsing button for the navigation bar. 
+ */ +.nav-trigger { + position: fixed; + clip: rect(0, 0, 0, 0); +} + +.nav-trigger + label:after { + content: '»'; +} + +label { + z-index: 10; +} + +label[for="nav-trigger"] { + position: fixed; + margin-left: 0px; + padding-top: 100px; + padding-left: 5px; + width: 10px; + height: 80%; + cursor: pointer; + background-size: contain; + background-color: #D4F0FF; +} + +label[for="nav-trigger"]:hover { + background-color: #BEE9FF; +} + +.nav-trigger:checked + label { + margin-left: 200px; +} + +.nav-trigger:checked + label:after { + content: '«'; +} + +.nav-trigger:checked ~ div.content-with-sidebar { + margin-left: 200px; +} + +.nav-trigger + label, div.content-with-sidebar { + transition: left 0.4s; +} + +/** + * Rules to collapse the menu automatically when the screen becomes too thin. + */ + +@media all and (max-width: 780px) { + + div.content-with-sidebar { + margin-left: 200px; + } + .nav-trigger + label:after { + content: '«'; + } + label[for="nav-trigger"] { + margin-left: 200px; + } + + .nav-trigger:checked + label { + margin-left: 0px; + } + .nav-trigger:checked + label:after { + content: '»'; + } + .nav-trigger:checked ~ div.content-with-sidebar { + margin-left: 0px; + } + + div.container-index { + margin-left: -215px; + } + +} \ No newline at end of file diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index 99f8c827f767..9dea9b5904d2 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -70,7 +70,7 @@ operators (e.g., [subgraph](#structural_operators), [joinVertices](#join_operato ## Migrating from Spark 1.1 -GraphX in Spark {{site.SPARK_VERSION}} contains a few user facing API changes: +GraphX in Spark 1.2 contains a few user facing API changes: 1. To improve performance we have introduced a new version of [`mapReduceTriplets`][Graph.mapReduceTriplets] called @@ -768,16 +768,14 @@ class GraphOps[VD, ED] { // Loop until no messages remain or maxIterations is achieved var i = 0 while (activeMessages > 0 && i < maxIterations) { - // Receive the messages: ----------------------------------------------------------------------- - // Run the vertex program on all vertices that receive messages - val newVerts = g.vertices.innerJoin(messages)(vprog).cache() - // Merge the new vertex values back into the graph - g = g.outerJoinVertices(newVerts) { (vid, old, newOpt) => newOpt.getOrElse(old) }.cache() - // Send Messages: ------------------------------------------------------------------------------ - // Vertices that didn't receive a message above don't appear in newVerts and therefore don't - // get to send messages. More precisely the map phase of mapReduceTriplets is only invoked - // on edges in the activeDir of vertices in newVerts - messages = g.mapReduceTriplets(sendMsg, mergeMsg, Some((newVerts, activeDir))).cache() + // Receive the messages and update the vertices. + g = g.joinVertices(messages)(vprog).cache() + val oldMessages = messages + // Send new messages, skipping edges where neither side received a message. We must cache + // messages so it can be materialized on the next line, allowing us to uncache the previous + // iteration. 
+ messages = g.mapReduceTriplets( + sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache() activeMessages = messages.count() i += 1 } @@ -946,7 +944,7 @@ The three additional functions exposed by the `EdgeRDD` are: {% highlight scala %} // Transform the edge attributes while preserving the structure def mapValues[ED2](f: Edge[ED] => ED2): EdgeRDD[ED2] -// Revere the edges reusing both attributes and structure +// Reverse the edges reusing both attributes and structure def reverse: EdgeRDD[ED] // Join two `EdgeRDD`s partitioned using the same partitioning strategy. def innerJoin[ED2, ED3](other: EdgeRDD[ED2])(f: (VertexId, VertexId, ED, ED2) => ED3): EdgeRDD[ED3] diff --git a/docs/hadoop-third-party-distributions.md b/docs/hadoop-third-party-distributions.md deleted file mode 100644 index 795dd82a6be0..000000000000 --- a/docs/hadoop-third-party-distributions.md +++ /dev/null @@ -1,117 +0,0 @@ ---- -layout: global -title: Third-Party Hadoop Distributions ---- - -Spark can run against all versions of Cloudera's Distribution Including Apache Hadoop (CDH) and -the Hortonworks Data Platform (HDP). There are a few things to keep in mind when using Spark -with these distributions: - -# Compile-time Hadoop Version - -When compiling Spark, you'll need to specify the Hadoop version by defining the `hadoop.version` -property. For certain versions, you will need to specify additional profiles. For more detail, -see the guide on [building with maven](building-spark.html#specifying-the-hadoop-version): - - mvn -Dhadoop.version=1.0.4 -DskipTests clean package - mvn -Phadoop-2.3 -Dhadoop.version=2.3.0 -DskipTests clean package - -The table below lists the corresponding `hadoop.version` code for each CDH/HDP release. Note that -some Hadoop releases are binary compatible across client versions. This means the pre-built Spark -distribution may "just work" without you needing to compile. That said, we recommend compiling with -the _exact_ Hadoop version you are running to avoid any compatibility errors. - - - - - - -
    -

    CDH Releases

    - - - - -
    ReleaseVersion code
    CDH 4.X.X (YARN mode)2.0.0-cdh4.X.X
    CDH 4.X.X2.0.0-mr1-cdh4.X.X
    -
    -

    HDP Releases

    - - - - - - - -
    ReleaseVersion code
    HDP 1.31.2.0
    HDP 1.21.1.2
    HDP 1.11.0.3
    HDP 1.01.0.3
    HDP 2.02.2.0
    -
    - -In SBT, the equivalent can be achieved by setting the the `hadoop.version` property: - - build/sbt -Dhadoop.version=1.0.4 assembly - -# Linking Applications to the Hadoop Version - -In addition to compiling Spark itself against the right version, you need to add a Maven dependency on that -version of `hadoop-client` to any Spark applications you run, so they can also talk to the HDFS version -on the cluster. If you are using CDH, you also need to add the Cloudera Maven repository. -This looks as follows in SBT: - -{% highlight scala %} -libraryDependencies += "org.apache.hadoop" % "hadoop-client" % "" - -// If using CDH, also add Cloudera repo -resolvers += "Cloudera Repository" at "https://repository.cloudera.com/artifactory/cloudera-repos/" -{% endhighlight %} - -Or in Maven: - -{% highlight xml %} - - - ... - - org.apache.hadoop - hadoop-client - [version] - - - - - - ... - - Cloudera repository - https://repository.cloudera.com/artifactory/cloudera-repos/ - - - - -{% endhighlight %} - -# Where to Run Spark - -As described in the [Hardware Provisioning](hardware-provisioning.html#storage-systems) guide, -Spark can run in a variety of deployment modes: - -* Using dedicated set of Spark nodes in your cluster. These nodes should be co-located with your - Hadoop installation. -* Running on the same nodes as an existing Hadoop installation, with a fixed amount memory and - cores dedicated to Spark on each node. -* Run Spark alongside Hadoop using a cluster resource manager, such as YARN or Mesos. - -These options are identical for those using CDH and HDP. - -# Inheriting Cluster Configuration - -If you plan to read and write from HDFS using Spark, there are two Hadoop configuration files that -should be included on Spark's classpath: - -* `hdfs-site.xml`, which provides default behaviors for the HDFS client. -* `core-site.xml`, which sets the default filesystem name. - -The location of these configuration files varies across CDH and HDP versions, but -a common location is inside of `/etc/hadoop/conf`. Some tools, such as Cloudera Manager, create -configurations on-the-fly, but offer a mechanisms to download copies of them. - -To make these files visible to Spark, set `HADOOP_CONF_DIR` in `$SPARK_HOME/spark-env.sh` -to a location containing the configuration files. 
diff --git a/docs/index.md b/docs/index.md index d85cf12defef..ae26f97c86c2 100644 --- a/docs/index.md +++ b/docs/index.md @@ -87,10 +87,9 @@ options for deployment: in all supported languages (Scala, Java, Python, R) * Modules built on Spark: * [Spark Streaming](streaming-programming-guide.html): processing real-time data streams - * [Spark SQL and DataFrames](sql-programming-guide.html): support for structured data and relational queries + * [Spark SQL, Datasets, and DataFrames](sql-programming-guide.html): support for structured data and relational queries * [MLlib](mllib-guide.html): built-in machine learning library * [GraphX](graphx-programming-guide.html): Spark's new API for graph processing - * [Bagel (Pregel on Spark)](bagel-programming-guide.html): older, simple graph processing model **API Docs:** @@ -118,7 +117,6 @@ options for deployment: * [Job Scheduling](job-scheduling.html): scheduling resources across and within Spark applications * [Security](security.html): Spark security support * [Hardware Provisioning](hardware-provisioning.html): recommendations for cluster hardware -* [3rd Party Hadoop Distributions](hadoop-third-party-distributions.html): using common Hadoop distributions * Integration with other storage systems: * [OpenStack Swift](storage-openstack-swift.html) * [Building Spark](building-spark.html): build Spark using the Maven system diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md index 8d9c2ba2041b..36327c6efeaf 100644 --- a/docs/job-scheduling.md +++ b/docs/job-scheduling.md @@ -47,7 +47,7 @@ application is not running tasks on a machine, other applications may run tasks is useful when you expect large numbers of not overly active applications, such as shell sessions from separate users. However, it comes with a risk of less predictable latency, because it may take a while for an application to gain back cores on one node when it has work to do. To use this mode, simply use a -`mesos://` URL without setting `spark.mesos.coarse` to true. +`mesos://` URL and set `spark.mesos.coarse` to false. Note that none of the modes currently provide memory sharing across applications. If you would like to share data this way, we recommend running a single server application that can serve multiple requests by querying @@ -56,36 +56,32 @@ provide another approach to share RDDs. ## Dynamic Resource Allocation -Spark 1.2 introduces the ability to dynamically scale the set of cluster resources allocated to -your application up and down based on the workload. This means that your application may give -resources back to the cluster if they are no longer used and request them again later when there -is demand. This feature is particularly useful if multiple applications share resources in your -Spark cluster. If a subset of the resources allocated to an application becomes idle, it can be -returned to the cluster's pool of resources and acquired by other applications. In Spark, dynamic -resource allocation is performed on the granularity of the executor and can be enabled through -`spark.dynamicAllocation.enabled`. - -This feature is currently disabled by default and available only on [YARN](running-on-yarn.html). -A future release will extend this to [standalone mode](spark-standalone.html) and -[Mesos coarse-grained mode](running-on-mesos.html#mesos-run-modes). 
Note that although Spark on -Mesos already has a similar notion of dynamic resource sharing in fine-grained mode, enabling -dynamic allocation allows your Mesos application to take advantage of coarse-grained low-latency -scheduling while sharing cluster resources efficiently. +Spark provides a mechanism to dynamically adjust the resources your application occupies based +on the workload. This means that your application may give resources back to the cluster if they +are no longer used and request them again later when there is demand. This feature is particularly +useful if multiple applications share resources in your Spark cluster. + +This feature is disabled by default and available on all coarse-grained cluster managers, i.e. +[standalone mode](spark-standalone.html), [YARN mode](running-on-yarn.html), and +[Mesos coarse-grained mode](running-on-mesos.html#mesos-run-modes). ### Configuration and Setup -All configurations used by this feature live under the `spark.dynamicAllocation.*` namespace. -To enable this feature, your application must set `spark.dynamicAllocation.enabled` to `true`. -Other relevant configurations are described on the -[configurations page](configuration.html#dynamic-allocation) and in the subsequent sections in -detail. +There are two requirements for using this feature. First, your application must set +`spark.dynamicAllocation.enabled` to `true`. Second, you must set up an *external shuffle service* +on each worker node in the same cluster and set `spark.shuffle.service.enabled` to true in your +application. The purpose of the external shuffle service is to allow executors to be removed +without deleting shuffle files written by them (more detail described +[below](job-scheduling.html#graceful-decommission-of-executors)). The way to set up this service +varies across cluster managers: + +In standalone mode, simply start your workers with `spark.shuffle.service.enabled` set to `true`. -Additionally, your application must use an external shuffle service. The purpose of the service is -to preserve the shuffle files written by executors so the executors can be safely removed (more -detail described [below](job-scheduling.html#graceful-decommission-of-executors)). To enable -this service, set `spark.shuffle.service.enabled` to `true`. In YARN, this external shuffle service -is implemented in `org.apache.spark.yarn.network.YarnShuffleService` that runs in each `NodeManager` -in your cluster. To start this service, follow these steps: +In Mesos coarse-grained mode, run `$SPARK_HOME/sbin/start-mesos-shuffle-service.sh` on all +slave nodes with `spark.shuffle.service.enabled` set to `true`. For instance, you may do so +through Marathon. + +In YARN mode, start the shuffle service on each `NodeManager` as follows: 1. Build Spark with the [YARN profile](building-spark.html). Skip this step if you are using a pre-packaged distribution. @@ -95,10 +91,13 @@ pre-packaged distribution. 2. Add this jar to the classpath of all `NodeManager`s in your cluster. 3. In the `yarn-site.xml` on each node, add `spark_shuffle` to `yarn.nodemanager.aux-services`, then set `yarn.nodemanager.aux-services.spark_shuffle.class` to -`org.apache.spark.network.yarn.YarnShuffleService`. Additionally, set all relevant -`spark.shuffle.service.*` [configurations](configuration.html). +`org.apache.spark.network.yarn.YarnShuffleService` and `spark.shuffle.service.enabled` to true. 4. Restart all `NodeManager`s in your cluster. 
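As a quick illustration of the two application-side requirements described above, a hedged sketch (the optional knobs at the end are examples, not required settings):

{% highlight scala %}
import org.apache.spark.{SparkConf, SparkContext}

// Minimal sketch: enable dynamic allocation and point the application at the
// external shuffle service, assuming the service itself has already been set up
// on each worker/NodeManager as described above.
val conf = new SparkConf()
  .setAppName("dynamic-allocation-sketch")
  .set("spark.dynamicAllocation.enabled", "true")
  .set("spark.shuffle.service.enabled", "true")
  // Optional tuning under spark.dynamicAllocation.* (see the configuration page):
  .set("spark.dynamicAllocation.minExecutors", "1")
  .set("spark.dynamicAllocation.executorIdleTimeout", "60s")

val sc = new SparkContext(conf)
{% endhighlight %}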
+All other relevant configurations are optional and under the `spark.dynamicAllocation.*` and +`spark.shuffle.service.*` namespaces. For more detail, see the +[configurations page](configuration.html#dynamic-allocation). + ### Resource Allocation Policy At a high level, Spark should relinquish executors when they are no longer used and acquire diff --git a/docs/js/main.js b/docs/js/main.js index f5d66b16f7b2..2329eb8327dd 100755 --- a/docs/js/main.js +++ b/docs/js/main.js @@ -83,7 +83,7 @@ $(function() { // Display anchor links when hovering over headers. For documentation of the // configuration options, see the AnchorJS documentation. anchors.options = { - placement: 'left' + placement: 'right' }; anchors.add(); diff --git a/docs/ml-advanced.md b/docs/ml-advanced.md new file mode 100644 index 000000000000..91731d78a2d4 --- /dev/null +++ b/docs/ml-advanced.md @@ -0,0 +1,13 @@ +--- +layout: global +title: Advanced topics - spark.ml +displayTitle: Advanced topics - spark.ml +--- + +# Optimization of linear methods + +The optimization algorithm underlying the implementation is called +[Orthant-Wise Limited-memory +QuasiNewton](http://research-srv.microsoft.com/en-us/um/people/jfgao/paper/icml07scalable.pdf) +(OWL-QN). It is an extension of L-BFGS that can effectively handle L1 +regularization and elastic net. diff --git a/docs/ml-ann.md b/docs/ml-ann.md new file mode 100644 index 000000000000..c2d9bd200f62 --- /dev/null +++ b/docs/ml-ann.md @@ -0,0 +1,8 @@ +--- +layout: global +title: Multilayer perceptron classifier - spark.ml +displayTitle: Multilayer perceptron classifier - spark.ml +--- + + > This section has been moved into the + [classification and regression section](ml-classification-regression.html#multilayer-perceptron-classifier). diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md new file mode 100644 index 000000000000..d63438bf74c1 --- /dev/null +++ b/docs/ml-classification-regression.md @@ -0,0 +1,776 @@ +--- +layout: global +title: Classification and regression - spark.ml +displayTitle: Classification and regression - spark.ml +--- + + +`\[ +\newcommand{\R}{\mathbb{R}} +\newcommand{\E}{\mathbb{E}} +\newcommand{\x}{\mathbf{x}} +\newcommand{\y}{\mathbf{y}} +\newcommand{\wv}{\mathbf{w}} +\newcommand{\av}{\mathbf{\alpha}} +\newcommand{\bv}{\mathbf{b}} +\newcommand{\N}{\mathbb{N}} +\newcommand{\id}{\mathbf{I}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} +\newcommand{\zero}{\mathbf{0}} +\]` + +**Table of Contents** + +* This will become a table of contents (this text will be scraped). +{:toc} + +In `spark.ml`, we implement popular linear methods such as logistic +regression and linear least squares with $L_1$ or $L_2$ regularization. +Refer to [the linear methods in mllib](mllib-linear-methods.html) for +details about implementation and tuning. We also include a DataFrame API for [Elastic +net](http://en.wikipedia.org/wiki/Elastic_net_regularization), a hybrid +of $L_1$ and $L_2$ regularization proposed in [Zou et al, Regularization +and variable selection via the elastic +net](http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf). 
+Mathematically, it is defined as a convex combination of the $L_1$ and +the $L_2$ regularization terms: +`\[ +\alpha \left( \lambda \|\wv\|_1 \right) + (1-\alpha) \left( \frac{\lambda}{2}\|\wv\|_2^2 \right) , \alpha \in [0, 1], \lambda \geq 0 +\]` +By setting $\alpha$ properly, elastic net contains both $L_1$ and $L_2$ +regularization as special cases. For example, if a [linear +regression](https://en.wikipedia.org/wiki/Linear_regression) model is +trained with the elastic net parameter $\alpha$ set to $1$, it is +equivalent to a +[Lasso](http://en.wikipedia.org/wiki/Least_squares#Lasso_method) model. +On the other hand, if $\alpha$ is set to $0$, the trained model reduces +to a [ridge +regression](http://en.wikipedia.org/wiki/Tikhonov_regularization) model. +We implement Pipelines API for both linear regression and logistic +regression with elastic net regularization. + + +# Classification + +## Logistic regression + +Logistic regression is a popular method to predict a binary response. It is a special case of [Generalized Linear models](https://en.wikipedia.org/wiki/Generalized_linear_model) that predicts the probability of the outcome. +For more background and more details about the implementation, refer to the documentation of the [logistic regression in `spark.mllib`](mllib-linear-methods.html#logistic-regression). + + > The current implementation of logistic regression in `spark.ml` only supports binary classes. Support for multiclass regression will be added in the future. + +**Example** + +The following example shows how to train a logistic regression model +with elastic net regularization. `elasticNetParam` corresponds to +$\alpha$ and `regParam` corresponds to $\lambda$. + +
    + +
    +{% include_example scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala %} +
    + +
    +{% include_example java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java %} +
    + +
    +{% include_example python/ml/logistic_regression_with_elastic_net.py %} +
    + +
    + +The `spark.ml` implementation of logistic regression also supports +extracting a summary of the model over the training set. Note that the +predictions and metrics which are stored as `DataFrame` in +`BinaryLogisticRegressionSummary` are annotated `@transient` and hence +only available on the driver. + +
    + +
    + +[`LogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionTrainingSummary) +provides a summary for a +[`LogisticRegressionModel`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionModel). +Currently, only binary classification is supported and the +summary must be explicitly cast to +[`BinaryLogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary). +This will likely change when multiclass classification is supported. + +Continuing the earlier example: + +{% include_example scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala %} +
    + +
    +[`LogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/LogisticRegressionTrainingSummary.html) +provides a summary for a +[`LogisticRegressionModel`](api/java/org/apache/spark/ml/classification/LogisticRegressionModel.html). +Currently, only binary classification is supported and the +summary must be explicitly cast to +[`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html). +This will likely change when multiclass classification is supported. + +Continuing the earlier example: + +{% include_example java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java %} +
    + + +
    +Logistic regression model summary is not yet supported in Python. +
    + +
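Before moving on, a hedged sketch of how `elasticNetParam` ($\alpha$) selects the L1/L2 special cases discussed above; the numeric values here are arbitrary:

{% highlight scala %}
import org.apache.spark.ml.classification.LogisticRegression

// elasticNetParam corresponds to alpha and regParam to lambda in the formula above.
val lasso = new LogisticRegression()
  .setMaxIter(100)
  .setRegParam(0.01)
  .setElasticNetParam(1.0)   // alpha = 1: pure L1 (Lasso-like) penalty

val ridge = new LogisticRegression()
  .setMaxIter(100)
  .setRegParam(0.01)
  .setElasticNetParam(0.0)   // alpha = 0: pure L2 (ridge-like) penalty
{% endhighlight %}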
    + + +## Decision tree classifier + +Decision trees are a popular family of classification and regression methods. +More information about the `spark.ml` implementation can be found further in the [section on decision trees](#decision-trees). + +**Example** + +The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. +We use two feature transformers to prepare the data; these help index categories for the label and categorical features, adding metadata to the `DataFrame` which the Decision Tree algorithm can recognize. + +
    +
    + +More details on parameters can be found in the [Scala API documentation](api/scala/index.html#org.apache.spark.ml.classification.DecisionTreeClassifier). + +{% include_example scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala %} + +
    + +
    + +More details on parameters can be found in the [Java API documentation](api/java/org/apache/spark/ml/classification/DecisionTreeClassifier.html). + +{% include_example java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java %} + +
    + +
    + +More details on parameters can be found in the [Python API documentation](api/python/pyspark.ml.html#pyspark.ml.classification.DecisionTreeClassifier). + +{% include_example python/ml/decision_tree_classification_example.py %} + +
    + +
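The feature-transformer pattern mentioned above can be condensed into a short, hedged sketch (column names and the `maxCategories` threshold are illustrative):

{% highlight scala %}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer}

// Index the string label column and mark low-arity feature columns as categorical,
// attaching the metadata the tree algorithm reads from the DataFrame.
val labelIndexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
val featureIndexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(4)   // features with more than 4 distinct values are treated as continuous

val dt = new DecisionTreeClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("indexedFeatures")

val pipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, dt))
// val model = pipeline.fit(trainingData)   // trainingData: DataFrame with "label" and "features"
{% endhighlight %}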
    + +## Random forest classifier + +Random forests are a popular family of classification and regression methods. +More information about the `spark.ml` implementation can be found further in the [section on random forests](#random-forests). + +**Example** + +The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. +We use two feature transformers to prepare the data; these help index categories for the label and categorical features, adding metadata to the `DataFrame` which the tree-based algorithms can recognize. + +
    +
    + +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classification.RandomForestClassifier) for more details. + +{% include_example scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala %} +
    + +
    + +Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/RandomForestClassifier.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaRandomForestClassifierExample.java %} +
    + +
    + +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classification.RandomForestClassifier) for more details. + +{% include_example python/ml/random_forest_classifier_example.py %} +
    +
    + +## Gradient-boosted tree classifier + +Gradient-boosted trees (GBTs) are a popular classification and regression method using ensembles of decision trees. +More information about the `spark.ml` implementation can be found further in the [section on GBTs](#gradient-boosted-trees-gbts). + +**Example** + +The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. +We use two feature transformers to prepare the data; these help index categories for the label and categorical features, adding metadata to the `DataFrame` which the tree-based algorithms can recognize. + +
    +
    + +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classification.GBTClassifier) for more details. + +{% include_example scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala %} +
    + +
    + +Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/GBTClassifier.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaGradientBoostedTreeClassifierExample.java %} +
    + +
    + +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classification.GBTClassifier) for more details. + +{% include_example python/ml/gradient_boosted_tree_classifier_example.py %} +
    +
+ + ## Multilayer perceptron classifier + Multilayer perceptron classifier (MLPC) is a classifier based on the [feedforward artificial neural network](https://en.wikipedia.org/wiki/Feedforward_neural_network). +MLPC consists of multiple layers of nodes. +Each layer is fully connected to the next layer in the network. Nodes in the input layer represent the input data. All other nodes map inputs to outputs +by performing a linear combination of the inputs with the node's weights `$\wv$` and bias `$\bv$` and applying an activation function. +It can be written in matrix form for MLPC with `$K+1$` layers as follows: +`\[ +\mathrm{y}(\x) = \mathrm{f_K}(...\mathrm{f_2}(\wv_2^T\mathrm{f_1}(\wv_1^T \x+b_1)+b_2)...+b_K) +\]` +Nodes in intermediate layers use the sigmoid (logistic) function: +`\[ +\mathrm{f}(z_i) = \frac{1}{1 + e^{-z_i}} +\]` +Nodes in the output layer use the softmax function: +`\[ +\mathrm{f}(z_i) = \frac{e^{z_i}}{\sum_{k=1}^N e^{z_k}} +\]` +The number of nodes `$N$` in the output layer corresponds to the number of classes. + +MLPC employs backpropagation for learning the model. We use the logistic loss function for optimization and L-BFGS as the optimization routine. + +**Example** +
    + +
    +{% include_example scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala %} +
    + +
    +{% include_example java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java %} +
    + +
    +{% include_example python/ml/multilayer_perceptron_classification.py %} +
    + +
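To connect the layer description above to the API, a hedged sketch (the layer sizes are arbitrary; the first entry is the number of input features and the last the number of classes):

{% highlight scala %}
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier

// Input layer of size 4 (features), two hidden sigmoid layers of sizes 5 and 4,
// and a softmax output layer of size 3 (classes).
val trainer = new MultilayerPerceptronClassifier()
  .setLayers(Array(4, 5, 4, 3))
  .setBlockSize(128)
  .setMaxIter(100)
  .setSeed(1234L)
// val model = trainer.fit(train)   // train: DataFrame with "label" and "features" columns
{% endhighlight %}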
    + + +## One-vs-Rest classifier (a.k.a. One-vs-All) + +[OneVsRest](http://en.wikipedia.org/wiki/Multiclass_classification#One-vs.-rest) is an example of a machine learning reduction for performing multiclass classification given a base classifier that can perform binary classification efficiently. It is also known as "One-vs-All." + +`OneVsRest` is implemented as an `Estimator`. For the base classifier it takes instances of `Classifier` and creates a binary classification problem for each of the k classes. The classifier for class i is trained to predict whether the label is i or not, distinguishing class i from all other classes. + +Predictions are done by evaluating each binary classifier and the index of the most confident classifier is output as label. + +**Example** + +The example below demonstrates how to load the +[Iris dataset](http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/iris.scale), parse it as a DataFrame and perform multiclass classification using `OneVsRest`. The test error is calculated to measure the algorithm accuracy. + +
    +
    + +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classifier.OneVsRest) for more details. + +{% include_example scala/org/apache/spark/examples/ml/OneVsRestExample.scala %} +
    + +
    + +Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/OneVsRest.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaOneVsRestExample.java %} +
    +
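The prediction rule described above (output the index of the most confident of the k binary classifiers) can be sketched in plain Scala, independently of the actual `OneVsRest` implementation:

{% highlight scala %}
// Hypothetical helper: given one example's confidence score from each of the k
// per-class binary classifiers, return the index of the most confident one as the label.
def oneVsRestPredict(classScores: Array[Double]): Double =
  classScores.zipWithIndex.maxBy(_._1)._2.toDouble

oneVsRestPredict(Array(0.1, 0.3, 0.9))   // scores for classes 0, 1, 2 => label 2.0
{% endhighlight %}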
    + + +# Regression + +## Linear regression + +The interface for working with linear regression models and model +summaries is similar to the logistic regression case. + +**Example** + +The following +example demonstrates training an elastic net regularized linear +regression model and extracting model summary statistics. + +
    + +
    +{% include_example scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala %} +
    + +
    +{% include_example java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java %} +
    + +
    + +{% include_example python/ml/linear_regression_with_elastic_net.py %} +
    + +
    + + +## Decision tree regression + +Decision trees are a popular family of classification and regression methods. +More information about the `spark.ml` implementation can be found further in the [section on decision trees](#decision-trees). + +**Example** + +The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. +We use a feature transformer to index categorical features, adding metadata to the `DataFrame` which the Decision Tree algorithm can recognize. + +
    +
    + +More details on parameters can be found in the [Scala API documentation](api/scala/index.html#org.apache.spark.ml.regression.DecisionTreeRegressor). + +{% include_example scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala %} +
    + +
    + +More details on parameters can be found in the [Java API documentation](api/java/org/apache/spark/ml/regression/DecisionTreeRegressor.html). + +{% include_example java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java %} +
    + +
    + +More details on parameters can be found in the [Python API documentation](api/python/pyspark.ml.html#pyspark.ml.regression.DecisionTreeRegressor). + +{% include_example python/ml/decision_tree_regression_example.py %} +
    + +
    + + +## Random forest regression + +Random forests are a popular family of classification and regression methods. +More information about the `spark.ml` implementation can be found further in the [section on random forests](#random-forests). + +**Example** + +The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. +We use a feature transformer to index categorical features, adding metadata to the `DataFrame` which the tree-based algorithms can recognize. + +
    +
    + +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.regression.RandomForestRegressor) for more details. + +{% include_example scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala %} +
    + +
    + +Refer to the [Java API docs](api/java/org/apache/spark/ml/regression/RandomForestRegressor.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaRandomForestRegressorExample.java %} +
    + +
    + +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression.RandomForestRegressor) for more details. + +{% include_example python/ml/random_forest_regressor_example.py %} +
    +
    + +## Gradient-boosted tree regression + +Gradient-boosted trees (GBTs) are a popular regression method using ensembles of decision trees. +More information about the `spark.ml` implementation can be found further in the [section on GBTs](#gradient-boosted-trees-gbts). + +**Example** + +Note: For this example dataset, `GBTRegressor` actually only needs 1 iteration, but that will not +be true in general. + +
    +
    + +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.regression.GBTRegressor) for more details. + +{% include_example scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala %} +
    + +
    + +Refer to the [Java API docs](api/java/org/apache/spark/ml/regression/GBTRegressor.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaGradientBoostedTreeRegressorExample.java %} +
    + +
    + +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression.GBTRegressor) for more details. + +{% include_example python/ml/gradient_boosted_tree_regressor_example.py %} +
    +
+ + ## Survival regression + + In `spark.ml`, we implement the [Accelerated failure time (AFT)](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) +model, which is a parametric survival regression model for censored data. +It describes a model for the log of the survival time, so it is often called a +log-linear model for survival analysis. Different from the +[Proportional hazards](https://en.wikipedia.org/wiki/Proportional_hazards_model) model +designed for the same purpose, the AFT model is easier to parallelize +because each instance contributes to the objective function independently. + +Given the values of the covariates $x^{'}$, for random lifetime $t_{i}$ of +subjects i = 1, ..., n, with possible right-censoring, +the likelihood function under the AFT model is given as: +`\[ +L(\beta,\sigma)=\prod_{i=1}^n[\frac{1}{\sigma}f_{0}(\frac{\log{t_{i}}-x^{'}\beta}{\sigma})]^{\delta_{i}}S_{0}(\frac{\log{t_{i}}-x^{'}\beta}{\sigma})^{1-\delta_{i}} +\]` +where $\delta_{i}$ is the indicator of whether the event has occurred, i.e. whether the observation is uncensored. +Using $\epsilon_{i}=\frac{\log{t_{i}}-x^{'}\beta}{\sigma}$, the log-likelihood function +assumes the form: +`\[ +\iota(\beta,\sigma)=\sum_{i=1}^{n}[-\delta_{i}\log\sigma+\delta_{i}\log{f_{0}}(\epsilon_{i})+(1-\delta_{i})\log{S_{0}(\epsilon_{i})}] +\]` +where $S_{0}(\epsilon_{i})$ is the baseline survivor function +and $f_{0}(\epsilon_{i})$ is the corresponding density function. + +The most commonly used AFT model is based on the Weibull distribution of the survival time. +The Weibull distribution for lifetime corresponds to the extreme value distribution for the +log of the lifetime, and the $S_{0}(\epsilon)$ function is: +`\[ +S_{0}(\epsilon_{i})=\exp(-e^{\epsilon_{i}}) +\]` +the $f_{0}(\epsilon_{i})$ function is: +`\[ +f_{0}(\epsilon_{i})=e^{\epsilon_{i}}\exp(-e^{\epsilon_{i}}) +\]` +The log-likelihood function for the AFT model with the Weibull distribution of lifetime is: +`\[ +\iota(\beta,\sigma)= -\sum_{i=1}^n[\delta_{i}\log\sigma-\delta_{i}\epsilon_{i}+e^{\epsilon_{i}}] +\]` +Since minimizing the negative log-likelihood is equivalent to maximizing the posterior probability, +the loss function we optimize is $-\iota(\beta,\sigma)$. +The gradient functions for $\beta$ and $\log\sigma$ respectively are: +`\[ +\frac{\partial (-\iota)}{\partial \beta}=\sum_{i=1}^{n}[\delta_{i}-e^{\epsilon_{i}}]\frac{x_{i}}{\sigma} +\]` +`\[ +\frac{\partial (-\iota)}{\partial (\log\sigma)}=\sum_{i=1}^{n}[\delta_{i}+(\delta_{i}-e^{\epsilon_{i}})\epsilon_{i}] +\]` + +The AFT model can be formulated as a convex optimization problem, +i.e. the task of finding a minimizer of the convex function $-\iota(\beta,\sigma)$ +that depends on the coefficients vector $\beta$ and the log of the scale parameter $\log\sigma$. +The optimization algorithm underlying the implementation is L-BFGS. +The implementation matches the result from R's survival function +[survreg](https://stat.ethz.ch/R-manual/R-devel/library/survival/html/survreg.html). + +**Example** +
    + +
    +{% include_example scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala %} +
    + +
    +{% include_example java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java %} +
    + +
    +{% include_example python/ml/aft_survival_regression.py %} +
    + +
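As a worked illustration of the Weibull log-likelihood above, a small hedged sketch (not the `spark.ml` implementation) that evaluates one instance's contribution to the loss $-\iota(\beta,\sigma)$:

{% highlight scala %}
import scala.math.{exp, log}

// One instance's contribution to -iota(beta, sigma) under the Weibull AFT model above.
// `delta` is 1.0 for an uncensored (observed) event and 0.0 for a censored one.
def aftLossTerm(logT: Double, x: Array[Double], beta: Array[Double],
                sigma: Double, delta: Double): Double = {
  val xb = x.zip(beta).map { case (xi, bi) => xi * bi }.sum   // x' * beta
  val epsilon = (logT - xb) / sigma
  delta * log(sigma) - delta * epsilon + exp(epsilon)
}

// The total loss is the sum over all instances, e.g. for a single toy observation:
// aftLossTerm(log(2.0), Array(1.0, 0.5), Array(0.3, -0.1), sigma = 1.0, delta = 1.0)
{% endhighlight %}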
    + + + +# Decision trees + +[Decision trees](http://en.wikipedia.org/wiki/Decision_tree_learning) +and their ensembles are popular methods for the machine learning tasks of +classification and regression. Decision trees are widely used since they are easy to interpret, +handle categorical features, extend to the multiclass classification setting, do not require +feature scaling, and are able to capture non-linearities and feature interactions. Tree ensemble +algorithms such as random forests and boosting are among the top performers for classification and +regression tasks. + +The `spark.ml` implementation supports decision trees for binary and multiclass classification and for regression, +using both continuous and categorical features. The implementation partitions data by rows, +allowing distributed training with millions or even billions of instances. + +Users can find more information about the decision tree algorithm in the [MLlib Decision Tree guide](mllib-decision-tree.html). +The main differences between this API and the [original MLlib Decision Tree API](mllib-decision-tree.html) are: + +* support for ML Pipelines +* separation of Decision Trees for classification vs. regression +* use of DataFrame metadata to distinguish continuous and categorical features + + +The Pipelines API for Decision Trees offers a bit more functionality than the original API. In particular, for classification, users can get the predicted probability of each class (a.k.a. class conditional probabilities). + +Ensembles of trees (Random Forests and Gradient-Boosted Trees) are described below in the [Tree ensembles section](#tree-ensembles). + +## Inputs and Outputs + +We list the input and output (prediction) column types here. +All output columns are optional; to exclude an output column, set its corresponding Param to an empty string. + +### Input Columns + + + + + + + + + + + + + + + + + + + + + + + + +
Param name | Type(s) | Default | Description
labelCol | Double | "label" | Label to predict
featuresCol | Vector | "features" | Feature vector
    + +### Output Columns + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Param name | Type(s) | Default | Description | Notes
predictionCol | Double | "prediction" | Predicted label
rawPredictionCol | Vector | "rawPrediction" | Vector of length # classes, with the counts of training instance labels at the tree node which makes the prediction | Classification only
probabilityCol | Vector | "probability" | Vector of length # classes equal to rawPrediction normalized to a multinomial distribution | Classification only
    + + +# Tree Ensembles + +The DataFrame API supports two major tree ensemble algorithms: [Random Forests](http://en.wikipedia.org/wiki/Random_forest) and [Gradient-Boosted Trees (GBTs)](http://en.wikipedia.org/wiki/Gradient_boosting). +Both use [`spark.ml` decision trees](ml-classification-regression.html#decision-trees) as their base models. + +Users can find more information about ensemble algorithms in the [MLlib Ensemble guide](mllib-ensembles.html). +In this section, we demonstrate the DataFrame API for ensembles. + +The main differences between this API and the [original MLlib ensembles API](mllib-ensembles.html) are: + +* support for DataFrames and ML Pipelines +* separation of classification vs. regression +* use of DataFrame metadata to distinguish continuous and categorical features +* more functionality for random forests: estimates of feature importance, as well as the predicted probability of each class (a.k.a. class conditional probabilities) for classification. + +## Random Forests + +[Random forests](http://en.wikipedia.org/wiki/Random_forest) +are ensembles of [decision trees](ml-decision-tree.html). +Random forests combine many decision trees in order to reduce the risk of overfitting. +The `spark.ml` implementation supports random forests for binary and multiclass classification and for regression, +using both continuous and categorical features. + +For more information on the algorithm itself, please see the [`spark.mllib` documentation on random forests](mllib-ensembles.html). + +### Inputs and Outputs + +We list the input and output (prediction) column types here. +All output columns are optional; to exclude an output column, set its corresponding Param to an empty string. + +#### Input Columns + + + + + + + + + + + + + + + + + + + + + + + + +
Param name | Type(s) | Default | Description
labelCol | Double | "label" | Label to predict
featuresCol | Vector | "features" | Feature vector
    + +#### Output Columns (Predictions) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Param name | Type(s) | Default | Description | Notes
predictionCol | Double | "prediction" | Predicted label
rawPredictionCol | Vector | "rawPrediction" | Vector of length # classes, with the counts of training instance labels at the tree node which makes the prediction | Classification only
probabilityCol | Vector | "probability" | Vector of length # classes equal to rawPrediction normalized to a multinomial distribution | Classification only
    + + + +## Gradient-Boosted Trees (GBTs) + +[Gradient-Boosted Trees (GBTs)](http://en.wikipedia.org/wiki/Gradient_boosting) +are ensembles of [decision trees](ml-decision-tree.html). +GBTs iteratively train decision trees in order to minimize a loss function. +The `spark.ml` implementation supports GBTs for binary classification and for regression, +using both continuous and categorical features. + +For more information on the algorithm itself, please see the [`spark.mllib` documentation on GBTs](mllib-ensembles.html). + +### Inputs and Outputs + +We list the input and output (prediction) column types here. +All output columns are optional; to exclude an output column, set its corresponding Param to an empty string. + +#### Input Columns + + + + + + + + + + + + + + + + + + + + + + + + +
Param name | Type(s) | Default | Description
labelCol | Double | "label" | Label to predict
featuresCol | Vector | "features" | Feature vector
    + +Note that `GBTClassifier` currently only supports binary labels. + +#### Output Columns (Predictions) + + + + + + + + + + + + + + + + + + + + +
Param name | Type(s) | Default | Description | Notes
predictionCol | Double | "prediction" | Predicted label
    + +In the future, `GBTClassifier` will also output columns for `rawPrediction` and `probability`, just as `RandomForestClassifier` does. + diff --git a/docs/ml-clustering.md b/docs/ml-clustering.md new file mode 100644 index 000000000000..440c455cd077 --- /dev/null +++ b/docs/ml-clustering.md @@ -0,0 +1,107 @@ +--- +layout: global +title: Clustering - spark.ml +displayTitle: Clustering - spark.ml +--- + +In this section, we introduce the pipeline API for [clustering in mllib](mllib-clustering.html). + +**Table of Contents** + +* This will become a table of contents (this text will be scraped). +{:toc} + +## K-means + +[k-means](http://en.wikipedia.org/wiki/K-means_clustering) is one of the +most commonly used clustering algorithms that clusters the data points into a +predefined number of clusters. The MLlib implementation includes a parallelized +variant of the [k-means++](http://en.wikipedia.org/wiki/K-means%2B%2B) method +called [kmeans||](http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf). + +`KMeans` is implemented as an `Estimator` and generates a `KMeansModel` as the base model. + +### Input Columns + + + + + + + + + + + + + + + + + + +
Param name | Type(s) | Default | Description
featuresCol | Vector | "features" | Feature vector
    + +### Output Columns + + + + + + + + + + + + + + + + + + +
Param name | Type(s) | Default | Description
predictionCol | Int | "prediction" | Predicted cluster center
    + +### Example + +
    + +
    +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.clustering.KMeans) for more details. + +{% include_example scala/org/apache/spark/examples/ml/KMeansExample.scala %} +
    + +
    +Refer to the [Java API docs](api/java/org/apache/spark/ml/clustering/KMeans.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaKMeansExample.java %} +
    + +
+ + ## Latent Dirichlet allocation (LDA) + `LDA` is implemented as an `Estimator` that supports both `EMLDAOptimizer` and `OnlineLDAOptimizer`, +and generates an `LDAModel` as the base model. Expert users may cast an `LDAModel` generated by +`EMLDAOptimizer` to a `DistributedLDAModel` if needed. +
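A hedged sketch of what that looks like in code (parameter values are arbitrary, and the cast is only valid when the EM optimizer is used):

{% highlight scala %}
import org.apache.spark.ml.clustering.{DistributedLDAModel, LDA}

// Choose the optimizer via the "optimizer" param ("online" or "em").
val lda = new LDA()
  .setK(10)
  .setMaxIter(20)
  .setOptimizer("em")
// val model = lda.fit(dataset)   // dataset: DataFrame with a "features" vector column
// With the EM optimizer, the fitted model can be treated as a DistributedLDAModel:
// val distributedModel = model.asInstanceOf[DistributedLDAModel]
{% endhighlight %}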
    + +
    + +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.clustering.LDA) for more details. + +{% include_example scala/org/apache/spark/examples/ml/LDAExample.scala %} +
    + +
    + +Refer to the [Java API docs](api/java/org/apache/spark/ml/clustering/LDA.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaLDAExample.java %} +
    + +
    \ No newline at end of file diff --git a/docs/ml-decision-tree.md b/docs/ml-decision-tree.md new file mode 100644 index 000000000000..a721d55bc675 --- /dev/null +++ b/docs/ml-decision-tree.md @@ -0,0 +1,8 @@ +--- +layout: global +title: Decision trees - spark.ml +displayTitle: Decision trees - spark.ml +--- + + > This section has been moved into the + [classification and regression section](ml-classification-regression.html#decision-trees). diff --git a/docs/ml-ensembles.md b/docs/ml-ensembles.md index 9ff50e95fc47..303773e8038f 100644 --- a/docs/ml-ensembles.md +++ b/docs/ml-ensembles.md @@ -1,129 +1,8 @@ --- layout: global -title: Ensembles -displayTitle: ML - Ensembles +title: Tree ensemble methods - spark.ml +displayTitle: Tree ensemble methods - spark.ml --- -**Table of Contents** - -* This will become a table of contents (this text will be scraped). -{:toc} - -An [ensemble method](http://en.wikipedia.org/wiki/Ensemble_learning) -is a learning algorithm which creates a model composed of a set of other base models. -The Pipelines API supports the following ensemble algorithms: [`OneVsRest`](api/scala/index.html#org.apache.spark.ml.classifier.OneVsRest) - -## OneVsRest - -[OneVsRest](http://en.wikipedia.org/wiki/Multiclass_classification#One-vs.-rest) is an example of a machine learning reduction for performing multiclass classification given a base classifier that can perform binary classification efficiently. - -`OneVsRest` is implemented as an `Estimator`. For the base classifier it takes instances of `Classifier` and creates a binary classification problem for each of the k classes. The classifier for class i is trained to predict whether the label is i or not, distinguishing class i from all other classes. - -Predictions are done by evaluating each binary classifier and the index of the most confident classifier is output as label. - -### Example - -The example below demonstrates how to load the -[Iris dataset](http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/iris.scale), parse it as a DataFrame and perform multiclass classification using `OneVsRest`. The test error is calculated to measure the algorithm accuracy. - -
    -
    -{% highlight scala %} -import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest} -import org.apache.spark.mllib.evaluation.MulticlassMetrics -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.sql.{Row, SQLContext} - -val sqlContext = new SQLContext(sc) - -// parse data into dataframe -val data = MLUtils.loadLibSVMFile(sc, - "data/mllib/sample_multiclass_classification_data.txt") -val Array(train, test) = data.toDF().randomSplit(Array(0.7, 0.3)) - -// instantiate multiclass learner and train -val ovr = new OneVsRest().setClassifier(new LogisticRegression) - -val ovrModel = ovr.fit(train) - -// score model on test data -val predictions = ovrModel.transform(test).select("prediction", "label") -val predictionsAndLabels = predictions.map {case Row(p: Double, l: Double) => (p, l)} - -// compute confusion matrix -val metrics = new MulticlassMetrics(predictionsAndLabels) -println(metrics.confusionMatrix) - -// the Iris DataSet has three classes -val numClasses = 3 - -println("label\tfpr\n") -(0 until numClasses).foreach { index => - val label = index.toDouble - println(label + "\t" + metrics.falsePositiveRate(label)) -} -{% endhighlight %} -
    -
    -{% highlight java %} - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.ml.classification.LogisticRegression; -import org.apache.spark.ml.classification.OneVsRest; -import org.apache.spark.ml.classification.OneVsRestModel; -import org.apache.spark.mllib.evaluation.MulticlassMetrics; -import org.apache.spark.mllib.linalg.Matrix; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.rdd.RDD; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.SQLContext; - -SparkConf conf = new SparkConf().setAppName("JavaOneVsRestExample"); -JavaSparkContext jsc = new JavaSparkContext(conf); -SQLContext jsql = new SQLContext(jsc); - -RDD data = MLUtils.loadLibSVMFile(jsc.sc(), - "data/mllib/sample_multiclass_classification_data.txt"); - -DataFrame dataFrame = jsql.createDataFrame(data, LabeledPoint.class); -DataFrame[] splits = dataFrame.randomSplit(new double[]{0.7, 0.3}, 12345); -DataFrame train = splits[0]; -DataFrame test = splits[1]; - -// instantiate the One Vs Rest Classifier -OneVsRest ovr = new OneVsRest().setClassifier(new LogisticRegression()); - -// train the multiclass model -OneVsRestModel ovrModel = ovr.fit(train.cache()); - -// score the model on test data -DataFrame predictions = ovrModel - .transform(test) - .select("prediction", "label"); - -// obtain metrics -MulticlassMetrics metrics = new MulticlassMetrics(predictions); -Matrix confusionMatrix = metrics.confusionMatrix(); - -// output the Confusion Matrix -System.out.println("Confusion Matrix"); -System.out.println(confusionMatrix); - -// compute the false positive rate per label -System.out.println(); -System.out.println("label\tfpr\n"); - -// the Iris DataSet has three classes -int numClasses = 3; -for (int index = 0; index < numClasses; index++) { - double label = (double) index; - System.out.print(label); - System.out.print("\t"); - System.out.print(metrics.falsePositiveRate(label)); - System.out.println(); -} -{% endhighlight %} -
    -
    + > This section has been moved into the + [classification and regression section](ml-classification-regression.html#tree-ensembles). diff --git a/docs/ml-features.md b/docs/ml-features.md index 54068debe215..677e4bfb916e 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1,7 +1,7 @@ --- layout: global -title: Feature Extraction, Transformation, and Selection - SparkML -displayTitle: ML - Features +title: Extracting, transforming and selecting features - spark.ml +displayTitle: Extracting, transforming and selecting features - spark.ml --- This section covers algorithms for working with features, roughly divided into these groups: @@ -28,186 +28,126 @@ The algorithm combines Term Frequency (TF) counts with the [hashing trick](http: **IDF**: `IDF` is an `Estimator` which fits on a dataset and produces an `IDFModel`. The `IDFModel` takes feature vectors (generally created from `HashingTF`) and scales each column. Intuitively, it down-weights columns which appear frequently in a corpus. Please refer to the [MLlib user guide on TF-IDF](mllib-feature-extraction.html#tf-idf) for more details on Term Frequency and Inverse Document Frequency. -For API details, refer to the [HashingTF API docs](api/scala/index.html#org.apache.spark.ml.feature.HashingTF) and the [IDF API docs](api/scala/index.html#org.apache.spark.ml.feature.IDF). In the following code segment, we start with a set of sentences. We split each sentence into words using `Tokenizer`. For each sentence (bag of words), we use `HashingTF` to hash the sentence into a feature vector. We use `IDF` to rescale the feature vectors; this generally improves performance when using text as features. Our feature vectors could then be passed to a learning algorithm.
    -{% highlight scala %} -import org.apache.spark.ml.feature.{HashingTF, IDF, Tokenizer} - -val sentenceData = sqlContext.createDataFrame(Seq( - (0, "Hi I heard about Spark"), - (0, "I wish Java could use case classes"), - (1, "Logistic regression models are neat") -)).toDF("label", "sentence") -val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words") -val wordsData = tokenizer.transform(sentenceData) -val hashingTF = new HashingTF().setInputCol("words").setOutputCol("rawFeatures").setNumFeatures(20) -val featurizedData = hashingTF.transform(wordsData) -val idf = new IDF().setInputCol("rawFeatures").setOutputCol("features") -val idfModel = idf.fit(featurizedData) -val rescaledData = idfModel.transform(featurizedData) -rescaledData.select("features", "label").take(3).foreach(println) -{% endhighlight %} + +Refer to the [HashingTF Scala docs](api/scala/index.html#org.apache.spark.ml.feature.HashingTF) and +the [IDF Scala docs](api/scala/index.html#org.apache.spark.ml.feature.IDF) for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/TfIdfExample.scala %}
    -{% highlight java %} -import com.google.common.collect.Lists; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.HashingTF; -import org.apache.spark.ml.feature.IDF; -import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( - RowFactory.create(0, "Hi I heard about Spark"), - RowFactory.create(0, "I wish Java could use case classes"), - RowFactory.create(1, "Logistic regression models are neat") -)); -StructType schema = new StructType(new StructField[]{ - new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), - new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) -}); -DataFrame sentenceData = sqlContext.createDataFrame(jrdd, schema); -Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words"); -DataFrame wordsData = tokenizer.transform(sentenceData); -int numFeatures = 20; -HashingTF hashingTF = new HashingTF() - .setInputCol("words") - .setOutputCol("rawFeatures") - .setNumFeatures(numFeatures); -DataFrame featurizedData = hashingTF.transform(wordsData); -IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features"); -IDFModel idfModel = idf.fit(featurizedData); -DataFrame rescaledData = idfModel.transform(featurizedData); -for (Row r : rescaledData.select("features", "label").take(3)) { - Vector features = r.getAs(0); - Double label = r.getDouble(1); - System.out.println(features); -} -{% endhighlight %} + +Refer to the [HashingTF Java docs](api/java/org/apache/spark/ml/feature/HashingTF.html) and the +[IDF Java docs](api/java/org/apache/spark/ml/feature/IDF.html) for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaTfIdfExample.java %}
    -{% highlight python %} -from pyspark.ml.feature import HashingTF, IDF, Tokenizer - -sentenceData = sqlContext.createDataFrame([ - (0, "Hi I heard about Spark"), - (0, "I wish Java could use case classes"), - (1, "Logistic regression models are neat") -], ["label", "sentence"]) -tokenizer = Tokenizer(inputCol="sentence", outputCol="words") -wordsData = tokenizer.transform(sentenceData) -hashingTF = HashingTF(inputCol="words", outputCol="rawFeatures", numFeatures=20) -featurizedData = hashingTF.transform(wordsData) -idf = IDF(inputCol="rawFeatures", outputCol="features") -idfModel = idf.fit(featurizedData) -rescaledData = idfModel.transform(featurizedData) -for features_label in rescaledData.select("features", "label").take(3): - print(features_label) -{% endhighlight %} + +Refer to the [HashingTF Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.HashingTF) and +the [IDF Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.IDF) for more details on the API. + +{% include_example python/ml/tf_idf_example.py %}
## Word2Vec -`Word2Vec` is an `Estimator` which takes sequences of words that represents documents and trains a `Word2VecModel`. The model is a `Map(String, Vector)` essentially, which maps each word to an unique fix-sized vector. The `Word2VecModel` transforms each documents into a vector using the average of all words in the document, which aims to other computations of documents such as similarity calculation consequencely. Please refer to the [MLlib user guide on Word2Vec](mllib-feature-extraction.html#Word2Vec) for more details on Word2Vec. +`Word2Vec` is an `Estimator` which takes sequences of words representing documents and trains a +`Word2VecModel`. The model maps each word to a unique fixed-size vector. The `Word2VecModel` +transforms each document into a vector using the average of all words in the document; this vector +can then be used as features for prediction, document similarity calculations, etc. +Please refer to the [MLlib user guide on Word2Vec](mllib-feature-extraction.html#word2vec) for more +details. -Word2Vec is implemented in [Word2Vec](api/scala/index.html#org.apache.spark.ml.feature.Word2Vec). In the following code segment, we start with a set of documents, each of them is represented as a sequence of words. For each document, we transform it into a feature vector. This feature vector could then be passed to a learning algorithm. +In the following code segment, we start with a set of documents, each of which is represented as a sequence of words. For each document, we transform it into a feature vector. This feature vector could then be passed to a learning algorithm.
    -{% highlight scala %} -import org.apache.spark.ml.feature.Word2Vec - -// Input data: Each row is a bag of words from a sentence or document. -val documentDF = sqlContext.createDataFrame(Seq( - "Hi I heard about Spark".split(" "), - "I wish Java could use case classes".split(" "), - "Logistic regression models are neat".split(" ") -).map(Tuple1.apply)).toDF("text") - -// Learn a mapping from words to Vectors. -val word2Vec = new Word2Vec() - .setInputCol("text") - .setOutputCol("result") - .setVectorSize(3) - .setMinCount(0) -val model = word2Vec.fit(documentDF) -val result = model.transform(documentDF) -result.select("result").take(3).foreach(println) -{% endhighlight %} + +Refer to the [Word2Vec Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Word2Vec) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/Word2VecExample.scala %}
    -{% highlight java %} -import com.google.common.collect.Lists; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.types.*; - -JavaSparkContext jsc = ... -SQLContext sqlContext = ... - -// Input data: Each row is a bag of words from a sentence or document. -JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( - RowFactory.create(Lists.newArrayList("Hi I heard about Spark".split(" "))), - RowFactory.create(Lists.newArrayList("I wish Java could use case classes".split(" "))), - RowFactory.create(Lists.newArrayList("Logistic regression models are neat".split(" "))) -)); -StructType schema = new StructType(new StructField[]{ - new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) -}); -DataFrame documentDF = sqlContext.createDataFrame(jrdd, schema); - -// Learn a mapping from words to Vectors. -Word2Vec word2Vec = new Word2Vec() - .setInputCol("text") - .setOutputCol("result") - .setVectorSize(3) - .setMinCount(0); -Word2VecModel model = word2Vec.fit(documentDF); -DataFrame result = model.transform(documentDF); -for (Row r: result.select("result").take(3)) { - System.out.println(r); -} -{% endhighlight %} + +Refer to the [Word2Vec Java docs](api/java/org/apache/spark/ml/feature/Word2Vec.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaWord2VecExample.java %}
    -{% highlight python %} -from pyspark.ml.feature import Word2Vec -# Input data: Each row is a bag of words from a sentence or document. -documentDF = sqlContext.createDataFrame([ - ("Hi I heard about Spark".split(" "), ), - ("I wish Java could use case classes".split(" "), ), - ("Logistic regression models are neat".split(" "), ) -], ["text"]) -# Learn a mapping from words to Vectors. -word2Vec = Word2Vec(vectorSize=3, minCount=0, inputCol="text", outputCol="result") -model = word2Vec.fit(documentDF) -result = model.transform(documentDF) -for feature in result.select("result").take(3): - print(feature) -{% endhighlight %} +Refer to the [Word2Vec Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Word2Vec) +for more details on the API. + +{% include_example python/ml/word2vec_example.py %} +
    +
+ +## CountVectorizer + +`CountVectorizer` and `CountVectorizerModel` aim to help convert a collection of text documents + to vectors of token counts. When an a-priori dictionary is not available, `CountVectorizer` can + be used as an `Estimator` to extract the vocabulary and generate a `CountVectorizerModel`. The + model produces sparse representations for the documents over the vocabulary, which can then be + passed to other algorithms like LDA. + + During the fitting process, `CountVectorizer` will select the top `vocabSize` words ordered by + term frequency across the corpus. An optional parameter "minDF" also affects the fitting process + by specifying the minimum number (or fraction if < 1.0) of documents a term must appear in to be + included in the vocabulary. + +**Examples** + +Assume that we have the following DataFrame with columns `id` and `texts`: + +~~~~ + id | texts +----|---------- + 0 | Array("a", "b", "c") + 1 | Array("a", "b", "b", "c", "a") +~~~~ + +Each row in `texts` is a document of type Array[String]. +Invoking fit of `CountVectorizer` produces a `CountVectorizerModel` with vocabulary (a, b, c); +the output column "vector" after transformation then contains: + +~~~~ + id | texts | vector +----|---------------------------------|--------------- + 0 | Array("a", "b", "c") | (3,[0,1,2],[1.0,1.0,1.0]) + 1 | Array("a", "b", "b", "c", "a") | (3,[0,1,2],[2.0,2.0,1.0]) +~~~~ + +Each vector represents the token counts of the document over the vocabulary. + 
    +
    + +Refer to the [CountVectorizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.CountVectorizer) +and the [CountVectorizerModel Scala docs](api/scala/index.html#org.apache.spark.ml.feature.CountVectorizerModel) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/CountVectorizerExample.scala %} +
    + +
    + +Refer to the [CountVectorizer Java docs](api/java/org/apache/spark/ml/feature/CountVectorizer.html) +and the [CountVectorizerModel Java docs](api/java/org/apache/spark/ml/feature/CountVectorizerModel.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaCountVectorizerExample.java %}
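For readers who want an inline sketch alongside the bundled example files, something like the following (a minimal spark-shell sketch, assuming `sqlContext` is predefined as in the other inline examples in this guide) reproduces the tables above:

{% highlight scala %}
import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel}

// The DataFrame from the table above.
val df = sqlContext.createDataFrame(Seq(
  (0, Array("a", "b", "c")),
  (1, Array("a", "b", "b", "c", "a"))
)).toDF("id", "texts")

// Fit a CountVectorizerModel from the corpus, keeping a vocabulary of at most 3 terms
// that each appear in at least 2 documents (minDF).
val cvModel: CountVectorizerModel = new CountVectorizer()
  .setInputCol("texts")
  .setOutputCol("vector")
  .setVocabSize(3)
  .setMinDF(2)
  .fit(df)

cvModel.transform(df).select("id", "vector").show()
{% endhighlight %}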
@@ -217,247 +157,199 @@ for feature in result.select("result").take(3): [Tokenization](http://en.wikipedia.org/wiki/Lexical_analysis#Tokenization) is the process of taking text (such as a sentence) and breaking it into individual terms (usually words). A simple [Tokenizer](api/scala/index.html#org.apache.spark.ml.feature.Tokenizer) class provides this functionality. The example below shows how to split sentences into sequences of words. -Note: A more advanced tokenizer is provided via [RegexTokenizer](api/scala/index.html#org.apache.spark.ml.feature.RegexTokenizer). +[RegexTokenizer](api/scala/index.html#org.apache.spark.ml.feature.RegexTokenizer) allows more + advanced tokenization based on regular expression (regex) matching. + By default, the parameter "pattern" (regex, default: "\\s+") is used as the delimiter to split the input text. + Alternatively, users can set the parameter "gaps" to false, indicating that the regex "pattern" denotes + "tokens" rather than splitting gaps; all matching occurrences are then returned as the tokenization result.
-{% highlight scala %} -import org.apache.spark.ml.feature.Tokenizer -val sentenceDataFrame = sqlContext.createDataFrame(Seq( - (0, "Hi I heard about Spark"), - (0, "I wish Java could use case classes"), - (1, "Logistic regression models are neat") -)).toDF("label", "sentence") -val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words") -val wordsDataFrame = tokenizer.transform(sentenceDataFrame) -wordsDataFrame.select("words", "label").take(3).foreach(println) -{% endhighlight %} +Refer to the [Tokenizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Tokenizer) +and the [RegexTokenizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.RegexTokenizer) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/TokenizerExample.scala %}
    -{% highlight java %} -import com.google.common.collect.Lists; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( - RowFactory.create(0, "Hi I heard about Spark"), - RowFactory.create(0, "I wish Java could use case classes"), - RowFactory.create(1, "Logistic regression models are neat") -)); -StructType schema = new StructType(new StructField[]{ - new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), - new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) -}); -DataFrame sentenceDataFrame = sqlContext.createDataFrame(jrdd, schema); -Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words"); -DataFrame wordsDataFrame = tokenizer.transform(sentenceDataFrame); -for (Row r : wordsDataFrame.select("words", "label").take(3)) { - java.util.List words = r.getList(0); - for (String word : words) System.out.print(word + " "); - System.out.println(); -} -{% endhighlight %} + +Refer to the [Tokenizer Java docs](api/java/org/apache/spark/ml/feature/Tokenizer.html) +and the [RegexTokenizer Java docs](api/java/org/apache/spark/ml/feature/RegexTokenizer.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaTokenizerExample.java %}
-{% highlight python %} -from pyspark.ml.feature import Tokenizer -sentenceDataFrame = sqlContext.createDataFrame([ - (0, "Hi I heard about Spark"), - (0, "I wish Java could use case classes"), - (1, "Logistic regression models are neat") -], ["label", "sentence"]) -tokenizer = Tokenizer(inputCol="sentence", outputCol="words") -wordsDataFrame = tokenizer.transform(sentenceDataFrame) -for words_label in wordsDataFrame.select("words", "label").take(3): - print(words_label) -{% endhighlight %} +Refer to the [Tokenizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Tokenizer) and +the [RegexTokenizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.RegexTokenizer) +for more details on the API. + +{% include_example python/ml/tokenizer_example.py %}
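The contrast between the two tokenizers can also be sketched inline (a rough spark-shell snippet assuming `sqlContext` is predefined; the comma-separated third sentence is an illustrative addition used to show the regex behaviour):

{% highlight scala %}
import org.apache.spark.ml.feature.{RegexTokenizer, Tokenizer}

val sentenceDataFrame = sqlContext.createDataFrame(Seq(
  (0, "Hi I heard about Spark"),
  (1, "I wish Java could use case classes"),
  (2, "Logistic,regression,models,are,neat")
)).toDF("label", "sentence")

// Simple tokenizer: lowercases the text and splits it on whitespace.
val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words")

// Regex tokenizer: split on non-word characters; alternatively, set gaps to false
// and use a pattern that matches the tokens themselves, e.g. setPattern("\\w+").setGaps(false).
val regexTokenizer = new RegexTokenizer()
  .setInputCol("sentence")
  .setOutputCol("words")
  .setPattern("\\W")

tokenizer.transform(sentenceDataFrame).select("words", "label").show()
regexTokenizer.transform(sentenceDataFrame).select("words", "label").show()
{% endhighlight %}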
    +## StopWordsRemover +[Stop words](https://en.wikipedia.org/wiki/Stop_words) are words which +should be excluded from the input, typically because the words appear +frequently and don't carry as much meaning. -## $n$-gram +`StopWordsRemover` takes as input a sequence of strings (e.g. the output +of a [Tokenizer](ml-features.html#tokenizer)) and drops all the stop +words from the input sequences. The list of stopwords is specified by +the `stopWords` parameter. We provide [a list of stop +words](http://ir.dcs.gla.ac.uk/resources/linguistic_utils/stop_words) by +default, accessible by calling `getStopWords` on a newly instantiated +`StopWordsRemover` instance. A boolean parameter `caseSensitive` indicates +if the matches should be case sensitive (false by default). -An [n-gram](https://en.wikipedia.org/wiki/N-gram) is a sequence of $n$ tokens (typically words) for some integer $n$. The `NGram` class can be used to transform input features into $n$-grams. +**Examples** -`NGram` takes as input a sequence of strings (e.g. the output of a [Tokenizer](ml-features.html#tokenizer). The parameter `n` is used to determine the number of terms in each $n$-gram. The output will consist of a sequence of $n$-grams where each $n$-gram is represented by a space-delimited string of $n$ consecutive words. If the input sequence contains fewer than `n` strings, no output is produced. +Assume that we have the following DataFrame with columns `id` and `raw`: + +~~~~ + id | raw +----|---------- + 0 | [I, saw, the, red, baloon] + 1 | [Mary, had, a, little, lamb] +~~~~ + +Applying `StopWordsRemover` with `raw` as the input column and `filtered` as the output +column, we should get the following: + +~~~~ + id | raw | filtered +----|-----------------------------|-------------------- + 0 | [I, saw, the, red, baloon] | [saw, red, baloon] + 1 | [Mary, had, a, little, lamb]|[Mary, little, lamb] +~~~~ + +In `filtered`, the stop words "I", "the", "had", and "a" have been +filtered out. -
    -
    -[`NGram`](api/scala/index.html#org.apache.spark.ml.feature.NGram) takes an input column name, an output column name, and an optional length parameter n (n=2 by default). +Refer to the [StopWordsRemover Scala docs](api/scala/index.html#org.apache.spark.ml.feature.StopWordsRemover) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala %} +
    + +
    + +Refer to the [StopWordsRemover Java docs](api/java/org/apache/spark/ml/feature/StopWordsRemover.html) +for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.NGram +{% include_example java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java %} +
    + +
    -val wordDataFrame = sqlContext.createDataFrame(Seq( - (0, Array("Hi", "I", "heard", "about", "Spark")), - (1, Array("I", "wish", "Java", "could", "use", "case", "classes")), - (2, Array("Logistic", "regression", "models", "are", "neat")) -)).toDF("label", "words") +Refer to the [StopWordsRemover Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.StopWordsRemover) +for more details on the API. -val ngram = new NGram().setInputCol("words").setOutputCol("ngrams") -val ngramDataFrame = ngram.transform(wordDataFrame) -ngramDataFrame.take(3).map(_.getAs[Stream[String]]("ngrams").toList).foreach(println) -{% endhighlight %} +{% include_example python/ml/stopwords_remover_example.py %} +
    +
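A minimal inline sketch of the example above (assuming a spark-shell session with `sqlContext` predefined):

{% highlight scala %}
import org.apache.spark.ml.feature.StopWordsRemover

// Uses the default stop word list; matching is case-insensitive by default.
val remover = new StopWordsRemover()
  .setInputCol("raw")
  .setOutputCol("filtered")

// The DataFrame from the table above.
val dataSet = sqlContext.createDataFrame(Seq(
  (0, Seq("I", "saw", "the", "red", "baloon")),
  (1, Seq("Mary", "had", "a", "little", "lamb"))
)).toDF("id", "raw")

remover.transform(dataSet).select("raw", "filtered").show()
{% endhighlight %}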
    + +## $n$-gram + +An [n-gram](https://en.wikipedia.org/wiki/N-gram) is a sequence of $n$ tokens (typically words) for some integer $n$. The `NGram` class can be used to transform input features into $n$-grams. + +`NGram` takes as input a sequence of strings (e.g. the output of a [Tokenizer](ml-features.html#tokenizer)). The parameter `n` is used to determine the number of terms in each $n$-gram. The output will consist of a sequence of $n$-grams where each $n$-gram is represented by a space-delimited string of $n$ consecutive words. If the input sequence contains fewer than `n` strings, no output is produced. + +
    + +
    + +Refer to the [NGram Scala docs](api/scala/index.html#org.apache.spark.ml.feature.NGram) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/NGramExample.scala %}
    -[`NGram`](api/java/org/apache/spark/ml/feature/NGram.html) takes an input column name, an output column name, and an optional length parameter n (n=2 by default). - -{% highlight java %} -import com.google.common.collect.Lists; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.NGram; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( - RowFactory.create(0D, Lists.newArrayList("Hi", "I", "heard", "about", "Spark")), - RowFactory.create(1D, Lists.newArrayList("I", "wish", "Java", "could", "use", "case", "classes")), - RowFactory.create(2D, Lists.newArrayList("Logistic", "regression", "models", "are", "neat")) -)); -StructType schema = new StructType(new StructField[]{ - new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), - new StructField("words", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) -}); -DataFrame wordDataFrame = sqlContext.createDataFrame(jrdd, schema); -NGram ngramTransformer = new NGram().setInputCol("words").setOutputCol("ngrams"); -DataFrame ngramDataFrame = ngramTransformer.transform(wordDataFrame); -for (Row r : ngramDataFrame.select("ngrams", "label").take(3)) { - java.util.List ngrams = r.getList(0); - for (String ngram : ngrams) System.out.print(ngram + " --- "); - System.out.println(); -} -{% endhighlight %} +Refer to the [NGram Java docs](api/java/org/apache/spark/ml/feature/NGram.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaNGramExample.java %}
    -[`NGram`](api/python/pyspark.ml.html#pyspark.ml.feature.NGram) takes an input column name, an output column name, and an optional length parameter n (n=2 by default). - -{% highlight python %} -from pyspark.ml.feature import NGram +Refer to the [NGram Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.NGram) +for more details on the API. -wordDataFrame = sqlContext.createDataFrame([ - (0, ["Hi", "I", "heard", "about", "Spark"]), - (1, ["I", "wish", "Java", "could", "use", "case", "classes"]), - (2, ["Logistic", "regression", "models", "are", "neat"]) -], ["label", "words"]) -ngram = NGram(inputCol="words", outputCol="ngrams") -ngramDataFrame = ngram.transform(wordDataFrame) -for ngrams_label in ngramDataFrame.select("ngrams", "label").take(3): - print(ngrams_label) -{% endhighlight %} +{% include_example python/ml/n_gram_example.py %}
    ## Binarizer -Binarization is the process of thresholding numerical features to binary features. As some probabilistic estimators make assumption that the input data is distributed according to [Bernoulli distribution](http://en.wikipedia.org/wiki/Bernoulli_distribution), a binarizer is useful for pre-processing the input data with continuous numerical features. +Binarization is the process of thresholding numerical features to binary (0/1) features. -A simple [Binarizer](api/scala/index.html#org.apache.spark.ml.feature.Binarizer) class provides this functionality. Besides the common parameters of `inputCol` and `outputCol`, `Binarizer` has the parameter `threshold` used for binarizing continuous numerical features. The features greater than the threshold, will be binarized to 1.0. The features equal to or less than the threshold, will be binarized to 0.0. The example below shows how to binarize numerical features. +`Binarizer` takes the common parameters `inputCol` and `outputCol`, as well as the `threshold` for binarization. Feature values greater than the threshold are binarized to 1.0; values equal to or less than the threshold are binarized to 0.0.
    -{% highlight scala %} -import org.apache.spark.ml.feature.Binarizer -import org.apache.spark.sql.DataFrame -val data = Array( - (0, 0.1), - (1, 0.8), - (2, 0.2) -) -val dataFrame: DataFrame = sqlContext.createDataFrame(data).toDF("label", "feature") +Refer to the [Binarizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Binarizer) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/BinarizerExample.scala %} +
    + +
    + +Refer to the [Binarizer Java docs](api/java/org/apache/spark/ml/feature/Binarizer.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaBinarizerExample.java %} +
    + +
    + +Refer to the [Binarizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Binarizer) +for more details on the API. + +{% include_example python/ml/binarizer_example.py %} +
    +
    + +## PCA -val binarizer: Binarizer = new Binarizer() - .setInputCol("feature") - .setOutputCol("binarized_feature") - .setThreshold(0.5) +[PCA](http://en.wikipedia.org/wiki/Principal_component_analysis) is a statistical procedure that uses an orthogonal transformation to convert a set of observations of possibly correlated variables into a set of values of linearly uncorrelated variables called principal components. A [PCA](api/scala/index.html#org.apache.spark.ml.feature.PCA) class trains a model to project vectors to a low-dimensional space using PCA. The example below shows how to project 5-dimensional feature vectors into 3-dimensional principal components. -val binarizedDataFrame = binarizer.transform(dataFrame) -val binarizedFeatures = binarizedDataFrame.select("binarized_feature") -binarizedFeatures.collect().foreach(println) -{% endhighlight %} +
    +
    + +Refer to the [PCA Scala docs](api/scala/index.html#org.apache.spark.ml.feature.PCA) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/PCAExample.scala %}
    -{% highlight java %} -import com.google.common.collect.Lists; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.Binarizer; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( - RowFactory.create(0, 0.1), - RowFactory.create(1, 0.8), - RowFactory.create(2, 0.2) -)); -StructType schema = new StructType(new StructField[]{ - new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), - new StructField("feature", DataTypes.DoubleType, false, Metadata.empty()) -}); -DataFrame continuousDataFrame = jsql.createDataFrame(jrdd, schema); -Binarizer binarizer = new Binarizer() - .setInputCol("feature") - .setOutputCol("binarized_feature") - .setThreshold(0.5); -DataFrame binarizedDataFrame = binarizer.transform(continuousDataFrame); -DataFrame binarizedFeatures = binarizedDataFrame.select("binarized_feature"); -for (Row r : binarizedFeatures.collect()) { - Double binarized_value = r.getDouble(0); - System.out.println(binarized_value); -} -{% endhighlight %} + +Refer to the [PCA Java docs](api/java/org/apache/spark/ml/feature/PCA.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaPCAExample.java %}
    -{% highlight python %} -from pyspark.ml.feature import Binarizer -continuousDataFrame = sqlContext.createDataFrame([ - (0, 0.1), - (1, 0.8), - (2, 0.2) -], ["label", "feature"]) -binarizer = Binarizer(threshold=0.5, inputCol="feature", outputCol="binarized_feature") -binarizedDataFrame = binarizer.transform(continuousDataFrame) -binarizedFeatures = binarizedDataFrame.select("binarized_feature") -for binarized_feature, in binarizedFeatures.collect(): - print(binarized_feature) -{% endhighlight %} +Refer to the [PCA Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.PCA) +for more details on the API. + +{% include_example python/ml/pca_example.py %}
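An inline sketch of such a projection might look like this (assuming `sqlContext` from a spark-shell session; the three 5-dimensional input vectors are illustrative values, not taken from the bundled example):

{% highlight scala %}
import org.apache.spark.ml.feature.PCA
import org.apache.spark.mllib.linalg.Vectors

// Three 5-dimensional feature vectors.
val data = Array(
  Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))),
  Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
  Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)
)
val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features")

// Fit a PCA model that projects onto the top 3 principal components.
val pca = new PCA()
  .setInputCol("features")
  .setOutputCol("pcaFeatures")
  .setK(3)
  .fit(df)

pca.transform(df).select("pcaFeatures").show()
{% endhighlight %}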
    @@ -467,80 +359,59 @@ for binarized_feature, in binarizedFeatures.collect():
    -{% highlight scala %} -import org.apache.spark.ml.feature.PolynomialExpansion -import org.apache.spark.mllib.linalg.Vectors - -val data = Array( - Vectors.dense(-2.0, 2.3), - Vectors.dense(0.0, 0.0), - Vectors.dense(0.6, -1.1) -) -val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") -val polynomialExpansion = new PolynomialExpansion() - .setInputCol("features") - .setOutputCol("polyFeatures") - .setDegree(3) -val polyDF = polynomialExpansion.transform(df) -polyDF.select("polyFeatures").take(3).foreach(println) -{% endhighlight %} + +Refer to the [PolynomialExpansion Scala docs](api/scala/index.html#org.apache.spark.ml.feature.PolynomialExpansion) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala %}
    -{% highlight java %} -import com.google.common.collect.Lists; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaSparkContext jsc = ... -SQLContext jsql = ... -PolynomialExpansion polyExpansion = new PolynomialExpansion() - .setInputCol("features") - .setOutputCol("polyFeatures") - .setDegree(3); -JavaRDD data = jsc.parallelize(Lists.newArrayList( - RowFactory.create(Vectors.dense(-2.0, 2.3)), - RowFactory.create(Vectors.dense(0.0, 0.0)), - RowFactory.create(Vectors.dense(0.6, -1.1)) -)); -StructType schema = new StructType(new StructField[] { - new StructField("features", new VectorUDT(), false, Metadata.empty()), -}); -DataFrame df = jsql.createDataFrame(data, schema); -DataFrame polyDF = polyExpansion.transform(df); -Row[] row = polyDF.select("polyFeatures").take(3); -for (Row r : row) { - System.out.println(r.get(0)); -} -{% endhighlight %} + +Refer to the [PolynomialExpansion Java docs](api/java/org/apache/spark/ml/feature/PolynomialExpansion.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java %}
    -{% highlight python %} -from pyspark.ml.feature import PolynomialExpansion -from pyspark.mllib.linalg import Vectors -df = sqlContext.createDataFrame( - [(Vectors.dense([-2.0, 2.3]), ), - (Vectors.dense([0.0, 0.0]), ), - (Vectors.dense([0.6, -1.1]), )], - ["features"]) -px = PolynomialExpansion(degree=2, inputCol="features", outputCol="polyFeatures") -polyDF = px.transform(df) -for expanded in polyDF.select("polyFeatures").take(3): - print(expanded) -{% endhighlight %} +Refer to the [PolynomialExpansion Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.PolynomialExpansion) +for more details on the API. + +{% include_example python/ml/polynomial_expansion_example.py %} +
    +
    + +## Discrete Cosine Transform (DCT) + +The [Discrete Cosine +Transform](https://en.wikipedia.org/wiki/Discrete_cosine_transform) +transforms a length $N$ real-valued sequence in the time domain into +another length $N$ real-valued sequence in the frequency domain. A +[DCT](api/scala/index.html#org.apache.spark.ml.feature.DCT) class +provides this functionality, implementing the +[DCT-II](https://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II) +and scaling the result by $1/\sqrt{2}$ such that the representing matrix +for the transform is unitary. No shift is applied to the transformed +sequence (e.g. the $0$th element of the transformed sequence is the +$0$th DCT coefficient and _not_ the $N/2$th). + +
    +
    + +Refer to the [DCT Scala docs](api/scala/index.html#org.apache.spark.ml.feature.DCT) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/DCTExample.scala %} +
    + +
    + +Refer to the [DCT Java docs](api/java/org/apache/spark/ml/feature/DCT.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaDCTExample.java %}
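An inline sketch of applying the transform (assuming `sqlContext` from a spark-shell session; the input vectors are arbitrary illustrative values):

{% highlight scala %}
import org.apache.spark.ml.feature.DCT
import org.apache.spark.mllib.linalg.Vectors

// A few length-4 real-valued sequences.
val data = Seq(
  Vectors.dense(0.0, 1.0, -2.0, 3.0),
  Vectors.dense(-1.0, 2.0, 4.0, -7.0),
  Vectors.dense(14.0, -2.0, -5.0, 1.0))
val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features")

// Forward DCT-II; setInverse(true) would apply the inverse transform instead.
val dct = new DCT()
  .setInputCol("features")
  .setOutputCol("featuresDCT")
  .setInverse(false)

dct.transform(df).select("featuresDCT").show(3)
{% endhighlight %}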
@@ -549,7 +420,11 @@ for expanded in polyDF.select("polyFeatures").take(3): `StringIndexer` encodes a string column of labels to a column of label indices. The indices are in `[0, numLabels)`, ordered by label frequencies. So the most frequent label gets index `0`. -If the input column is numeric, we cast it to string and index the string values. +If the input column is numeric, we cast it to string and index the string +values. When downstream pipeline components such as `Estimator` or +`Transformer` make use of this string-indexed label, you must set the input +column of the component to this string-indexed column name. In many cases, +you can set the input column with `setInputCol`. **Examples** @@ -584,174 +459,165 @@ column, we should get the following: "a" gets index `0` because it is the most frequent, followed by "c" with index `1` and "b" with index `2`. +Additionally, there are two strategies regarding how `StringIndexer` will handle +unseen labels when you have fit a `StringIndexer` on one dataset and then use it +to transform another: + +- throw an exception (which is the default) +- skip the row containing the unseen label entirely + +**Examples** + +Let's go back to our previous example but this time reuse our previously defined +`StringIndexer` on the following dataset: + +~~~~ + id | category +----|---------- + 0 | a + 1 | b + 2 | c + 3 | d +~~~~ + +If you've not set how `StringIndexer` handles unseen labels or set it to +"error", an exception will be thrown. +However, if you had called `setHandleInvalid("skip")`, the following dataset +will be generated: + +~~~~ + id | category | categoryIndex +----|----------|--------------- + 0 | a | 0.0 + 1 | b | 2.0 + 2 | c | 1.0 +~~~~ + +Notice that the row containing "d" does not appear. +
    -[`StringIndexer`](api/scala/index.html#org.apache.spark.ml.feature.StringIndexer) takes an input -column name and an output column name. +Refer to the [StringIndexer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.StringIndexer) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/StringIndexerExample.scala %} +
    + +
    + +Refer to the [StringIndexer Java docs](api/java/org/apache/spark/ml/feature/StringIndexer.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaStringIndexerExample.java %} +
    + +
    -{% highlight scala %} -import org.apache.spark.ml.feature.StringIndexer +Refer to the [StringIndexer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.StringIndexer) +for more details on the API. + +{% include_example python/ml/string_indexer_example.py %} +
    +
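The "skip" strategy described above can be sketched inline as follows (assuming `sqlContext` from a spark-shell session; the column names follow the tables above):

{% highlight scala %}
import org.apache.spark.ml.feature.StringIndexer

// The original dataset used to fit the indexer.
val df = sqlContext.createDataFrame(
  Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c"))
).toDF("id", "category")

// "skip" drops rows with unseen labels at transform time; the default "error" throws instead.
val indexer = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")
  .setHandleInvalid("skip")
val indexerModel = indexer.fit(df)

// A second dataset containing the unseen label "d".
val df2 = sqlContext.createDataFrame(
  Seq((0, "a"), (1, "b"), (2, "c"), (3, "d"))
).toDF("id", "category")

// The row containing "d" is dropped from the output.
indexerModel.transform(df2).show()
{% endhighlight %}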
    + + +## IndexToString + +Symmetrically to `StringIndexer`, `IndexToString` maps a column of label indices +back to a column containing the original labels as strings. The common use case +is to produce indices from labels with `StringIndexer`, train a model with those +indices and retrieve the original labels from the column of predicted indices +with `IndexToString`. However, you are free to supply your own labels. + +**Examples** + +Building on the `StringIndexer` example, let's assume we have the following +DataFrame with columns `id` and `categoryIndex`: + +~~~~ + id | categoryIndex +----|--------------- + 0 | 0.0 + 1 | 2.0 + 2 | 1.0 + 3 | 0.0 + 4 | 0.0 + 5 | 1.0 +~~~~ + +Applying `IndexToString` with `categoryIndex` as the input column, +`originalCategory` as the output column, we are able to retrieve our original +labels (they will be inferred from the columns' metadata): + +~~~~ + id | categoryIndex | originalCategory +----|---------------|----------------- + 0 | 0.0 | a + 1 | 2.0 | b + 2 | 1.0 | c + 3 | 0.0 | a + 4 | 0.0 | a + 5 | 1.0 | c +~~~~ + +
    +
    + +Refer to the [IndexToString Scala docs](api/scala/index.html#org.apache.spark.ml.feature.IndexToString) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/IndexToStringExample.scala %} -val df = sqlContext.createDataFrame( - Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")) -).toDF("id", "category") -val indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") -val indexed = indexer.fit(df).transform(df) -indexed.show() -{% endhighlight %}
    -[`StringIndexer`](api/java/org/apache/spark/ml/feature/StringIndexer.html) takes an input column -name and an output column name. - -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.StringIndexer; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -import static org.apache.spark.sql.types.DataTypes.*; - -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, "a"), - RowFactory.create(1, "b"), - RowFactory.create(2, "c"), - RowFactory.create(3, "a"), - RowFactory.create(4, "a"), - RowFactory.create(5, "c") -)); -StructType schema = new StructType(new StructField[] { - createStructField("id", DoubleType, false), - createStructField("category", StringType, false) -}); -DataFrame df = sqlContext.createDataFrame(jrdd, schema); -StringIndexer indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex"); -DataFrame indexed = indexer.fit(df).transform(df); -indexed.show(); -{% endhighlight %} + +Refer to the [IndexToString Java docs](api/java/org/apache/spark/ml/feature/IndexToString.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaIndexToStringExample.java %} +
    -[`StringIndexer`](api/python/pyspark.ml.html#pyspark.ml.feature.StringIndexer) takes an input -column name and an output column name. +Refer to the [IndexToString Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.IndexToString) +for more details on the API. -{% highlight python %} -from pyspark.ml.feature import StringIndexer +{% include_example python/ml/index_to_string_example.py %} -df = sqlContext.createDataFrame( - [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")], - ["id", "category"]) -indexer = StringIndexer(inputCol="category", outputCol="categoryIndex") -indexed = indexer.fit(df).transform(df) -indexed.show() -{% endhighlight %}
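The round trip described above can be sketched inline as follows (assuming `sqlContext` from a spark-shell session):

{% highlight scala %}
import org.apache.spark.ml.feature.{IndexToString, StringIndexer}

val df = sqlContext.createDataFrame(
  Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c"))
).toDF("id", "category")

// Index the string labels first.
val indexed = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")
  .fit(df)
  .transform(df)

// Map the indices back to the original labels, inferred from the column metadata.
val converter = new IndexToString()
  .setInputCol("categoryIndex")
  .setOutputCol("originalCategory")

converter.transform(indexed).select("id", "categoryIndex", "originalCategory").show()
{% endhighlight %}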
## OneHotEncoder -[One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a column of label indices to a column of binary vectors, with at most a single one-value. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features +[One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a column of label indices to a column of binary vectors, with at most a single one-value. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features.
    -{% highlight scala %} -import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer} - -val df = sqlContext.createDataFrame(Seq( - (0, "a"), - (1, "b"), - (2, "c"), - (3, "a"), - (4, "a"), - (5, "c") -)).toDF("id", "category") - -val indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") - .fit(df) -val indexed = indexer.transform(df) - -val encoder = new OneHotEncoder().setInputCol("categoryIndex"). - setOutputCol("categoryVec") -val encoded = encoder.transform(indexed) -encoded.select("id", "categoryVec").foreach(println) -{% endhighlight %} + +Refer to the [OneHotEncoder Scala docs](api/scala/index.html#org.apache.spark.ml.feature.OneHotEncoder) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala %}
    -{% highlight java %} -import com.google.common.collect.Lists; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.OneHotEncoder; -import org.apache.spark.ml.feature.StringIndexer; -import org.apache.spark.ml.feature.StringIndexerModel; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( - RowFactory.create(0, "a"), - RowFactory.create(1, "b"), - RowFactory.create(2, "c"), - RowFactory.create(3, "a"), - RowFactory.create(4, "a"), - RowFactory.create(5, "c") -)); -StructType schema = new StructType(new StructField[]{ - new StructField("id", DataTypes.DoubleType, false, Metadata.empty()), - new StructField("category", DataTypes.StringType, false, Metadata.empty()) -}); -DataFrame df = sqlContext.createDataFrame(jrdd, schema); -StringIndexerModel indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") - .fit(df); -DataFrame indexed = indexer.transform(df); - -OneHotEncoder encoder = new OneHotEncoder() - .setInputCol("categoryIndex") - .setOutputCol("categoryVec"); -DataFrame encoded = encoder.transform(indexed); -{% endhighlight %} + +Refer to the [OneHotEncoder Java docs](api/java/org/apache/spark/ml/feature/OneHotEncoder.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java %}
    -{% highlight python %} -from pyspark.ml.feature import OneHotEncoder, StringIndexer -df = sqlContext.createDataFrame([ - (0, "a"), - (1, "b"), - (2, "c"), - (3, "a"), - (4, "a"), - (5, "c") -], ["id", "category"]) +Refer to the [OneHotEncoder Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.OneHotEncoder) +for more details on the API. -stringIndexer = StringIndexer(inputCol="category", outputCol="categoryIndex") -model = stringIndexer.fit(df) -indexed = model.transform(df) -encoder = OneHotEncoder(includeFirst=False, inputCol="categoryIndex", outputCol="categoryVec") -encoded = encoder.transform(indexed) -{% endhighlight %} +{% include_example python/ml/onehot_encoder_example.py %}
    @@ -767,74 +633,31 @@ It can both automatically decide which features are categorical and convert orig Indexing categorical features allows algorithms such as Decision Trees and Tree Ensembles to treat categorical features appropriately, improving performance. -Please refer to the [VectorIndexer API docs](api/scala/index.html#org.apache.spark.ml.feature.VectorIndexer) for more details. - In the example below, we read in a dataset of labeled points and then use `VectorIndexer` to decide which features should be treated as categorical. We transform the categorical feature values to their indices. This transformed data could then be passed to algorithms such as `DecisionTreeRegressor` that handle categorical features.
    -{% highlight scala %} -import org.apache.spark.ml.feature.VectorIndexer -import org.apache.spark.mllib.util.MLUtils -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() -val indexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexed") - .setMaxCategories(10) -val indexerModel = indexer.fit(data) -val categoricalFeatures: Set[Int] = indexerModel.categoryMaps.keys.toSet -println(s"Chose ${categoricalFeatures.size} categorical features: " + - categoricalFeatures.mkString(", ")) +Refer to the [VectorIndexer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.VectorIndexer) +for more details on the API. -// Create new column "indexed" with categorical values transformed to indices -val indexedData = indexerModel.transform(data) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/VectorIndexerExample.scala %}
    -{% highlight java %} -import java.util.Map; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.VectorIndexer; -import org.apache.spark.ml.feature.VectorIndexerModel; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.sql.DataFrame; - -JavaRDD rdd = MLUtils.loadLibSVMFile(sc.sc(), - "data/mllib/sample_libsvm_data.txt").toJavaRDD(); -DataFrame data = sqlContext.createDataFrame(rdd, LabeledPoint.class); -VectorIndexer indexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexed") - .setMaxCategories(10); -VectorIndexerModel indexerModel = indexer.fit(data); -Map> categoryMaps = indexerModel.javaCategoryMaps(); -System.out.print("Chose " + categoryMaps.size() + "categorical features:"); -for (Integer feature : categoryMaps.keySet()) { - System.out.print(" " + feature); -} -System.out.println(); - -// Create new column "indexed" with categorical values transformed to indices -DataFrame indexedData = indexerModel.transform(data); -{% endhighlight %} + +Refer to the [VectorIndexer Java docs](api/java/org/apache/spark/ml/feature/VectorIndexer.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java %}
    -{% highlight python %} -from pyspark.ml.feature import VectorIndexer -from pyspark.mllib.util import MLUtils -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() -indexer = VectorIndexer(inputCol="features", outputCol="indexed", maxCategories=10) -indexerModel = indexer.fit(data) +Refer to the [VectorIndexer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.VectorIndexer) +for more details on the API. -# Create new column "indexed" with categorical values transformed to indices -indexedData = indexerModel.transform(data) -{% endhighlight %} +{% include_example python/ml/vector_indexer_example.py %}
@@ -846,66 +669,28 @@ indexedData = indexerModel.transform(data) The following example demonstrates how to load a dataset in libsvm format and then normalize each row to have unit $L^1$ norm and unit $L^\infty$ norm.
    -
    -{% highlight scala %} -import org.apache.spark.ml.feature.Normalizer -import org.apache.spark.mllib.util.MLUtils - -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -val dataFrame = sqlContext.createDataFrame(data) +
    -// Normalize each Vector using $L^1$ norm. -val normalizer = new Normalizer() - .setInputCol("features") - .setOutputCol("normFeatures") - .setP(1.0) -val l1NormData = normalizer.transform(dataFrame) +Refer to the [Normalizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Normalizer) +for more details on the API. -// Normalize each Vector using $L^\infty$ norm. -val lInfNormData = normalizer.transform(dataFrame, normalizer.p -> Double.PositiveInfinity) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/NormalizerExample.scala %}
    -
    -{% highlight java %} -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.Normalizer; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.sql.DataFrame; - -JavaRDD data = - MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt").toJavaRDD(); -DataFrame dataFrame = jsql.createDataFrame(data, LabeledPoint.class); +
    -// Normalize each Vector using $L^1$ norm. -Normalizer normalizer = new Normalizer() - .setInputCol("features") - .setOutputCol("normFeatures") - .setP(1.0); -DataFrame l1NormData = normalizer.transform(dataFrame); +Refer to the [Normalizer Java docs](api/java/org/apache/spark/ml/feature/Normalizer.html) +for more details on the API. -// Normalize each Vector using $L^\infty$ norm. -DataFrame lInfNormData = - normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY)); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaNormalizerExample.java %}
    -
    -{% highlight python %} -from pyspark.mllib.util import MLUtils -from pyspark.ml.feature import Normalizer - -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -dataFrame = sqlContext.createDataFrame(data) +
    -# Normalize each Vector using $L^1$ norm. -normalizer = Normalizer(inputCol="features", outputCol="normFeatures", p=1.0) -l1NormData = normalizer.transform(dataFrame) +Refer to the [Normalizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Normalizer) +for more details on the API. -# Normalize each Vector using $L^\infty$ norm. -lInfNormData = normalizer.transform(dataFrame, {normalizer.p: float("inf")}) -{% endhighlight %} +{% include_example python/ml/normalizer_example.py %}
    @@ -917,79 +702,74 @@ lInfNormData = normalizer.transform(dataFrame, {normalizer.p: float("inf")}) * `withStd`: True by default. Scales the data to unit standard deviation. * `withMean`: False by default. Centers the data with mean before scaling. It will build a dense output, so this does not work on sparse input and will raise an exception. -`StandardScaler` is a `Model` which can be `fit` on a dataset to produce a `StandardScalerModel`; this amounts to computing summary statistics. The model can then transform a `Vector` column in a dataset to have unit standard deviation and/or zero mean features. +`StandardScaler` is an `Estimator` which can be `fit` on a dataset to produce a `StandardScalerModel`; this amounts to computing summary statistics. The model can then transform a `Vector` column in a dataset to have unit standard deviation and/or zero mean features. Note that if the standard deviation of a feature is zero, it will return default `0.0` value in the `Vector` for that feature. -More details can be found in the API docs for -[StandardScaler](api/scala/index.html#org.apache.spark.ml.feature.StandardScaler) and -[StandardScalerModel](api/scala/index.html#org.apache.spark.ml.feature.StandardScalerModel). - The following example demonstrates how to load a dataset in libsvm format and then normalize each feature to have unit standard deviation.
    -
    -{% highlight scala %} -import org.apache.spark.ml.feature.StandardScaler -import org.apache.spark.mllib.util.MLUtils - -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -val dataFrame = sqlContext.createDataFrame(data) -val scaler = new StandardScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures") - .setWithStd(true) - .setWithMean(false) +
    -// Compute summary statistics by fitting the StandardScaler -val scalerModel = scaler.fit(dataFrame) +Refer to the [StandardScaler Scala docs](api/scala/index.html#org.apache.spark.ml.feature.StandardScaler) +for more details on the API. -// Normalize each feature to have unit standard deviation. -val scaledData = scalerModel.transform(dataFrame) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/StandardScalerExample.scala %}
    -
    -{% highlight java %} -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.StandardScaler; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.sql.DataFrame; +
    + +Refer to the [StandardScaler Java docs](api/java/org/apache/spark/ml/feature/StandardScaler.html) +for more details on the API. -JavaRDD data = - MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt").toJavaRDD(); -DataFrame dataFrame = jsql.createDataFrame(data, LabeledPoint.class); -StandardScaler scaler = new StandardScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures") - .setWithStd(true) - .setWithMean(false); +{% include_example java/org/apache/spark/examples/ml/JavaStandardScalerExample.java %} +
    + +
    -// Compute summary statistics by fitting the StandardScaler -StandardScalerModel scalerModel = scaler.fit(dataFrame); +Refer to the [StandardScaler Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.StandardScaler) +for more details on the API. -// Normalize each feature to have unit standard deviation. -DataFrame scaledData = scalerModel.transform(dataFrame); -{% endhighlight %} +{% include_example python/ml/standard_scaler_example.py %}
    +
    + +## MinMaxScaler + +`MinMaxScaler` transforms a dataset of `Vector` rows, rescaling each feature to a specific range (often [0, 1]). It takes parameters: -
    -{% highlight python %} -from pyspark.mllib.util import MLUtils -from pyspark.ml.feature import StandardScaler +* `min`: 0.0 by default. Lower bound after transformation, shared by all features. +* `max`: 1.0 by default. Upper bound after transformation, shared by all features. -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -dataFrame = sqlContext.createDataFrame(data) -scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures", - withStd=True, withMean=False) +`MinMaxScaler` computes summary statistics on a data set and produces a `MinMaxScalerModel`. The model can then transform each feature individually such that it is in the given range. -# Compute summary statistics by fitting the StandardScaler -scalerModel = scaler.fit(dataFrame) +The rescaled value for a feature E is calculated as, +`\begin{equation} + Rescaled(e_i) = \frac{e_i - E_{min}}{E_{max} - E_{min}} * (max - min) + min +\end{equation}` +For the case `E_{max} == E_{min}`, `Rescaled(e_i) = 0.5 * (max + min)` -# Normalize each feature to have unit standard deviation. -scaledData = scalerModel.transform(dataFrame) -{% endhighlight %} +Note that since zero values will probably be transformed to non-zero values, output of the transformer will be DenseVector even for sparse input. + +The following example demonstrates how to load a dataset in libsvm format and then rescale each feature to [0, 1]. + +
    +
    + +Refer to the [MinMaxScaler Scala docs](api/scala/index.html#org.apache.spark.ml.feature.MinMaxScaler) +and the [MinMaxScalerModel Scala docs](api/scala/index.html#org.apache.spark.ml.feature.MinMaxScalerModel) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala %} +
    + +
    + +Refer to the [MinMaxScaler Java docs](api/java/org/apache/spark/ml/feature/MinMaxScaler.html) +and the [MinMaxScalerModel Java docs](api/java/org/apache/spark/ml/feature/MinMaxScalerModel.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java %}
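For quick reference, a minimal inline sketch of the same rescaling, assuming an existing `sc` and `sqlContext` and the sample LibSVM data file used elsewhere in this guide; the bundled `MinMaxScalerExample` files linked above are the complete versions.

{% highlight scala %}
import org.apache.spark.ml.feature.MinMaxScaler
import org.apache.spark.mllib.util.MLUtils

// Load the sample LibSVM data and convert it to a DataFrame with a "features" column.
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
val dataFrame = sqlContext.createDataFrame(data)

// min and max keep their defaults of 0.0 and 1.0.
val scaler = new MinMaxScaler()
  .setInputCol("features")
  .setOutputCol("scaledFeatures")

// Compute summary statistics and produce a MinMaxScalerModel.
val scalerModel = scaler.fit(dataFrame)

// Rescale each feature individually to the range [min, max].
val scaledData = scalerModel.transform(dataFrame)
{% endhighlight %}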
    @@ -1008,75 +788,28 @@ More details can be found in the API docs for [Bucketizer](api/scala/index.html# The following example demonstrates how to bucketize a column of `Double`s into another index-wised column.
    -
    -{% highlight scala %} -import org.apache.spark.ml.feature.Bucketizer -import org.apache.spark.sql.DataFrame - -val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity) - -val data = Array(-0.5, -0.3, 0.0, 0.2) -val dataFrame = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") +
    -val bucketizer = new Bucketizer() - .setInputCol("features") - .setOutputCol("bucketedFeatures") - .setSplits(splits) +Refer to the [Bucketizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Bucketizer) +for more details on the API. -// Transform original data into its bucket index. -val bucketedData = bucketizer.transform(dataFrame) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/BucketizerExample.scala %}
    -
    -{% highlight java %} -import com.google.common.collect.Lists; - -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -double[] splits = {Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY}; - -JavaRDD data = jsc.parallelize(Lists.newArrayList( - RowFactory.create(-0.5), - RowFactory.create(-0.3), - RowFactory.create(0.0), - RowFactory.create(0.2) -)); -StructType schema = new StructType(new StructField[] { - new StructField("features", DataTypes.DoubleType, false, Metadata.empty()) -}); -DataFrame dataFrame = jsql.createDataFrame(data, schema); +
    -Bucketizer bucketizer = new Bucketizer() - .setInputCol("features") - .setOutputCol("bucketedFeatures") - .setSplits(splits); +Refer to the [Bucketizer Java docs](api/java/org/apache/spark/ml/feature/Bucketizer.html) +for more details on the API. -// Transform original data into its bucket index. -DataFrame bucketedData = bucketizer.transform(dataFrame); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaBucketizerExample.java %}
    -
    -{% highlight python %} -from pyspark.ml.feature import Bucketizer - -splits = [-float("inf"), -0.5, 0.0, 0.5, float("inf")] - -data = [(-0.5,), (-0.3,), (0.0,), (0.2,)] -dataFrame = sqlContext.createDataFrame(data, ["features"]) +
    -bucketizer = Bucketizer(splits=splits, inputCol="features", outputCol="bucketedFeatures") +Refer to the [Bucketizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Bucketizer) +for more details on the API. -# Transform original data into its bucket index. -bucketedData = bucketizer.transform(dataFrame) -{% endhighlight %} +{% include_example python/ml/bucketizer_example.py %}
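For quick reference, a minimal inline Scala sketch of the bucketing above, assuming an existing `sqlContext`; the bundled `BucketizerExample` files linked above are the complete versions.

{% highlight scala %}
import org.apache.spark.ml.feature.Bucketizer

// Splits define the bucket boundaries; -inf and +inf cover all real values.
val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity)

val data = Array(-0.5, -0.3, 0.0, 0.2)
val dataFrame = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features")

val bucketizer = new Bucketizer()
  .setInputCol("features")
  .setOutputCol("bucketedFeatures")
  .setSplits(splits)

// Transform original data into its bucket index.
val bucketedData = bucketizer.transform(dataFrame)
{% endhighlight %}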
    @@ -1100,71 +833,90 @@ v_N \end{pmatrix} \]` -[`ElementwiseProduct`](api/scala/index.html#org.apache.spark.ml.feature.ElementwiseProduct) takes the following parameter: +This example below demonstrates how to transform vectors using a transforming vector value. -* `scalingVec`: the transforming vector. +
    +
    -This example below demonstrates how to transform vectors using a transforming vector value. +Refer to the [ElementwiseProduct Scala docs](api/scala/index.html#org.apache.spark.ml.feature.ElementwiseProduct) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/ElementwiseProductExample.scala %} +
    + +
    + +Refer to the [ElementwiseProduct Java docs](api/java/org/apache/spark/ml/feature/ElementwiseProduct.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java %} +
    + +
    + +Refer to the [ElementwiseProduct Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.ElementwiseProduct) +for more details on the API. + +{% include_example python/ml/elementwise_product_example.py %} +
    +
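For quick reference, a minimal inline Scala sketch of the element-wise multiplication above, assuming an existing `sqlContext`; the bundled `ElementwiseProductExample` files linked above are the complete versions.

{% highlight scala %}
import org.apache.spark.ml.feature.ElementwiseProduct
import org.apache.spark.mllib.linalg.Vectors

// Create some vector data; this also works for sparse vectors.
val dataFrame = sqlContext.createDataFrame(Seq(
  ("a", Vectors.dense(1.0, 2.0, 3.0)),
  ("b", Vectors.dense(4.0, 5.0, 6.0)))).toDF("id", "vector")

val transformingVector = Vectors.dense(0.0, 1.0, 2.0)
val transformer = new ElementwiseProduct()
  .setScalingVec(transformingVector)
  .setInputCol("vector")
  .setOutputCol("transformedVector")

// Batch transform the vectors to create a new column.
val transformedData = transformer.transform(dataFrame)
{% endhighlight %}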
+ +## SQLTransformer + +`SQLTransformer` implements the transformations that are defined by a SQL statement. +Currently we only support SQL syntax like `"SELECT ... FROM __THIS__ ..."` +where `"__THIS__"` represents the underlying table of the input dataset. +The select clause specifies the fields, constants, and expressions to display in +the output; it can be any select clause that Spark SQL supports. Users can also +use Spark SQL built-in functions and UDFs to operate on these selected columns. +For example, `SQLTransformer` supports statements like: + +* `SELECT a, a + b AS a_b FROM __THIS__` +* `SELECT a, SQRT(b) AS b_sqrt FROM __THIS__ where a > 5` +* `SELECT a, b, SUM(c) AS c_sum FROM __THIS__ GROUP BY a, b` + +**Examples** + +Assume that we have the following DataFrame with columns `id`, `v1` and `v2`: + +~~~~ + id | v1 | v2 +----|-----|----- + 0 | 1.0 | 3.0 + 2 | 2.0 | 5.0 +~~~~ + +This is the output of the `SQLTransformer` with statement `"SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__"`: + +~~~~ + id | v1 | v2 | v3 | v4 +----|-----|-----|-----|----- + 0 | 1.0 | 3.0 | 4.0 | 3.0 + 2 | 2.0 | 5.0 | 7.0 |10.0 +~~~~
    -
    -{% highlight scala %} -import org.apache.spark.ml.feature.ElementwiseProduct -import org.apache.spark.mllib.linalg.Vectors - -// Create some vector data; also works for sparse vectors -val dataFrame = sqlContext.createDataFrame(Seq( - ("a", Vectors.dense(1.0, 2.0, 3.0)), - ("b", Vectors.dense(4.0, 5.0, 6.0)))).toDF("id", "vector") - -val transformingVector = Vectors.dense(0.0, 1.0, 2.0) -val transformer = new ElementwiseProduct() - .setScalingVec(transformingVector) - .setInputCol("vector") - .setOutputCol("transformedVector") - -// Batch transform the vectors to create new column: -val transformedData = transformer.transform(dataFrame) - -{% endhighlight %} -
    - -
    -{% highlight java %} -import com.google.common.collect.Lists; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.ElementwiseProduct; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -// Create some vector data; also works for sparse vectors -JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( - RowFactory.create("a", Vectors.dense(1.0, 2.0, 3.0)), - RowFactory.create("b", Vectors.dense(4.0, 5.0, 6.0)) -)); -List fields = new ArrayList(2); -fields.add(DataTypes.createStructField("id", DataTypes.StringType, false)); -fields.add(DataTypes.createStructField("vector", DataTypes.StringType, false)); -StructType schema = DataTypes.createStructType(fields); -DataFrame dataFrame = sqlContext.createDataFrame(jrdd, schema); -Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0); -ElementwiseProduct transformer = new ElementwiseProduct() - .setScalingVec(transformingVector) - .setInputCol("vector") - .setOutputCol("transformedVector"); -// Batch transform the vectors to create new column: -DataFrame transformedData = transformer.transform(dataFrame); - -{% endhighlight %} +
    + +Refer to the [SQLTransformer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.SQLTransformer) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/SQLTransformerExample.scala %} +
    + +
    + +Refer to the [SQLTransformer Java docs](api/java/org/apache/spark/ml/feature/SQLTransformer.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaSQLTransformerExample.java %} +
    + +
    + +Refer to the [SQLTransformer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.SQLTransformer) for more details on the API. + +{% include_example python/ml/sql_transformer.py %}
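For quick reference, a minimal inline Scala sketch that reproduces the tables above, assuming an existing `sqlContext`; the bundled `SQLTransformerExample` files linked above are the complete versions.

{% highlight scala %}
import org.apache.spark.ml.feature.SQLTransformer

// The same toy data as in the input table above.
val df = sqlContext.createDataFrame(
  Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2")

val sqlTrans = new SQLTransformer().setStatement(
  "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")

// Produces the id, v1, v2, v3, v4 columns shown in the output table above.
sqlTrans.transform(df).show()
{% endhighlight %}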
    @@ -1206,81 +958,266 @@ output column to `features`, after transformation we should get the following Da
    -[`VectorAssembler`](api/scala/index.html#org.apache.spark.ml.feature.VectorAssembler) takes an array -of input column names and an output column name. - -{% highlight scala %} -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.ml.feature.VectorAssembler +Refer to the [VectorAssembler Scala docs](api/scala/index.html#org.apache.spark.ml.feature.VectorAssembler) +for more details on the API. -val dataset = sqlContext.createDataFrame( - Seq((0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0)) -).toDF("id", "hour", "mobile", "userFeatures", "clicked") -val assembler = new VectorAssembler() - .setInputCols(Array("hour", "mobile", "userFeatures")) - .setOutputCol("features") -val output = assembler.transform(dataset) -println(output.select("features", "clicked").first()) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala %}
    -[`VectorAssembler`](api/java/org/apache/spark/ml/feature/VectorAssembler.html) takes an array -of input column names and an output column name. - -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.*; -import static org.apache.spark.sql.types.DataTypes.*; - -StructType schema = createStructType(new StructField[] { - createStructField("id", IntegerType, false), - createStructField("hour", IntegerType, false), - createStructField("mobile", DoubleType, false), - createStructField("userFeatures", new VectorUDT(), false), - createStructField("clicked", DoubleType, false) -}); -Row row = RowFactory.create(0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0); -JavaRDD rdd = jsc.parallelize(Arrays.asList(row)); -DataFrame dataset = sqlContext.createDataFrame(rdd, schema); - -VectorAssembler assembler = new VectorAssembler() - .setInputCols(new String[] {"hour", "mobile", "userFeatures"}) - .setOutputCol("features"); - -DataFrame output = assembler.transform(dataset); -System.out.println(output.select("features", "clicked").first()); -{% endhighlight %} +Refer to the [VectorAssembler Java docs](api/java/org/apache/spark/ml/feature/VectorAssembler.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java %}
    -[`VectorAssembler`](api/python/pyspark.ml.html#pyspark.ml.feature.VectorAssembler) takes a list -of input column names and an output column name. +Refer to the [VectorAssembler Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.VectorAssembler) +for more details on the API. + +{% include_example python/ml/vector_assembler_example.py %} +
    +
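For quick reference, a minimal inline Scala sketch of assembling the `hour`, `mobile`, and `userFeatures` columns into a single `features` vector, assuming an existing `sqlContext`; the bundled `VectorAssemblerExample` files linked above are the complete versions.

{% highlight scala %}
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.mllib.linalg.Vectors

val dataset = sqlContext.createDataFrame(
  Seq((0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0))
).toDF("id", "hour", "mobile", "userFeatures", "clicked")

val assembler = new VectorAssembler()
  .setInputCols(Array("hour", "mobile", "userFeatures"))
  .setOutputCol("features")

// Combine the three input columns into a single vector column named "features".
val output = assembler.transform(dataset)
println(output.select("features", "clicked").first())
{% endhighlight %}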
    + +## QuantileDiscretizer + +`QuantileDiscretizer` takes a column with continuous features and outputs a column with binned +categorical features. +The bin ranges are chosen by taking a sample of the data and dividing it into roughly equal parts. +The lower and upper bin bounds will be `-Infinity` and `+Infinity`, covering all real values. +This attempts to find `numBuckets` partitions based on a sample of the given input data, but it may +find fewer depending on the data sample values. -{% highlight python %} -from pyspark.mllib.linalg import Vectors -from pyspark.ml.feature import VectorAssembler +Note that the result may be different every time you run it, since the sample strategy behind it is +non-deterministic. -dataset = sqlContext.createDataFrame( - [(0, 18, 1.0, Vectors.dense([0.0, 10.0, 0.5]), 1.0)], - ["id", "hour", "mobile", "userFeatures", "clicked"]) -assembler = VectorAssembler( - inputCols=["hour", "mobile", "userFeatures"], - outputCol="features") -output = assembler.transform(dataset) -print(output.select("features", "clicked").first()) -{% endhighlight %} +**Examples** + +Assume that we have a DataFrame with the columns `id`, `hour`: + +~~~ + id | hour +----|------ + 0 | 18.0 +----|------ + 1 | 19.0 +----|------ + 2 | 8.0 +----|------ + 3 | 5.0 +----|------ + 4 | 2.2 +~~~ + +`hour` is a continuous feature with `Double` type. We want to turn the continuous feature into +categorical one. Given `numBuckets = 3`, we should get the following DataFrame: + +~~~ + id | hour | result +----|------|------ + 0 | 18.0 | 2.0 +----|------|------ + 1 | 19.0 | 2.0 +----|------|------ + 2 | 8.0 | 1.0 +----|------|------ + 3 | 5.0 | 1.0 +----|------|------ + 4 | 2.2 | 0.0 +~~~ + +
    +
    + +Refer to the [QuantileDiscretizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.QuantileDiscretizer) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala %} +
    + +
    + +Refer to the [QuantileDiscretizer Java docs](api/java/org/apache/spark/ml/feature/QuantileDiscretizer.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java %}
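For quick reference, a minimal inline Scala sketch of the discretization above, assuming an existing `sqlContext`; the bundled `QuantileDiscretizerExample` files linked above are the complete versions. Because the split points are chosen from a non-deterministic sample, the exact bucket assignments may vary between runs.

{% highlight scala %}
import org.apache.spark.ml.feature.QuantileDiscretizer

// The same toy data as in the table above.
val data = Seq((0, 18.0), (1, 19.0), (2, 8.0), (3, 5.0), (4, 2.2))
val df = sqlContext.createDataFrame(data).toDF("id", "hour")

val discretizer = new QuantileDiscretizer()
  .setInputCol("hour")
  .setOutputCol("result")
  .setNumBuckets(3)

// fit() chooses the split points from a sample of the data; the fitted model then
// assigns each hour value to a bucket index in the "result" column.
val result = discretizer.fit(df).transform(df)
result.show()
{% endhighlight %}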
# Feature Selectors +## VectorSlicer + +`VectorSlicer` is a transformer that takes a feature vector and outputs a new feature vector with a +sub-array of the original features. It is useful for extracting features from a vector column. + +`VectorSlicer` accepts a vector column with specified indices, then outputs a new vector column +whose values are selected via those indices. There are two types of indices: + + 1. Integer indices that represent the indices into the vector, `setIndices()`; + + 2. String indices that represent the names of features in the vector, `setNames()`. + *This requires the vector column to have an `AttributeGroup` since the implementation matches on + the name field of an `Attribute`.* + +Specification by integer and by string are both acceptable. Moreover, you can use integer indices and +string names simultaneously. At least one feature must be selected. Duplicate features are not +allowed, so there can be no overlap between selected indices and names. Note that if names of +features are selected, an exception will be thrown if empty input attributes are encountered. + +The output vector will order features with the selected indices first (in the order given), +followed by the selected names (in the order given). + +**Examples** + +Suppose that we have a DataFrame with the column `userFeatures`: + +~~~ + userFeatures +------------------ + [0.0, 10.0, 0.5] +~~~ + +`userFeatures` is a vector column that contains three user features. Suppose the first column +of `userFeatures` is all zeros, so we want to remove it and select only the last two columns. +The `VectorSlicer` selects the last two elements with `setIndices(1, 2)` and then produces a new vector +column named `features`: + +~~~ + userFeatures | features +------------------|----------------------------- + [0.0, 10.0, 0.5] | [10.0, 0.5] +~~~ + +Suppose also that we have potential input attributes for `userFeatures`, i.e. +`["f1", "f2", "f3"]`; then we can use `setNames("f2", "f3")` to select them. + +~~~ + userFeatures | features +------------------|----------------------------- + [0.0, 10.0, 0.5] | [10.0, 0.5] + ["f1", "f2", "f3"] | ["f2", "f3"] +~~~ + +<div class="codetabs">
    +
    + +Refer to the [VectorSlicer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.VectorSlicer) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/VectorSlicerExample.scala %} +
    + +
    + +Refer to the [VectorSlicer Java docs](api/java/org/apache/spark/ml/feature/VectorSlicer.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java %} +
    +
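For quick reference, a minimal inline Scala sketch of slicing by integer indices, assuming an existing `sqlContext`. Selection by name (`setNames`) additionally requires an `AttributeGroup` on the input column, so it is left to the bundled `VectorSlicerExample` files linked above.

{% highlight scala %}
import org.apache.spark.ml.feature.VectorSlicer
import org.apache.spark.mllib.linalg.Vectors

// A single row containing the userFeatures vector from the example above.
val dataset = sqlContext.createDataFrame(
  Seq(Tuple1(Vectors.dense(0.0, 10.0, 0.5)))).toDF("userFeatures")

val slicer = new VectorSlicer()
  .setInputCol("userFeatures")
  .setOutputCol("features")
  .setIndices(Array(1, 2))  // keep only the last two features

val output = slicer.transform(dataset)
output.show()
{% endhighlight %}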
    + +## RFormula + +`RFormula` selects columns specified by an [R model formula](https://stat.ethz.ch/R-manual/R-devel/library/stats/html/formula.html). It produces a vector column of features and a double column of labels. Like when formulas are used in R for linear regression, string input columns will be one-hot encoded, and numeric columns will be cast to doubles. If not already present in the DataFrame, the output label column will be created from the specified response variable in the formula. + +**Examples** + +Assume that we have a DataFrame with the columns `id`, `country`, `hour`, and `clicked`: + +~~~ +id | country | hour | clicked +---|---------|------|--------- + 7 | "US" | 18 | 1.0 + 8 | "CA" | 12 | 0.0 + 9 | "NZ" | 15 | 0.0 +~~~ + +If we use `RFormula` with a formula string of `clicked ~ country + hour`, which indicates that we want to +predict `clicked` based on `country` and `hour`, after transformation we should get the following DataFrame: + +~~~ +id | country | hour | clicked | features | label +---|---------|------|---------|------------------|------- + 7 | "US" | 18 | 1.0 | [0.0, 0.0, 18.0] | 1.0 + 8 | "CA" | 12 | 0.0 | [0.0, 1.0, 12.0] | 0.0 + 9 | "NZ" | 15 | 0.0 | [1.0, 0.0, 15.0] | 0.0 +~~~ + +
    +
    + +Refer to the [RFormula Scala docs](api/scala/index.html#org.apache.spark.ml.feature.RFormula) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/RFormulaExample.scala %} +
    + +
    + +Refer to the [RFormula Java docs](api/java/org/apache/spark/ml/feature/RFormula.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaRFormulaExample.java %} +
    + +
    + +Refer to the [RFormula Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.RFormula) +for more details on the API. + +{% include_example python/ml/rformula_example.py %} +
    +
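For quick reference, a minimal inline Scala sketch of the formula above, assuming an existing `sqlContext`; the bundled `RFormulaExample` files linked above are the complete versions.

{% highlight scala %}
import org.apache.spark.ml.feature.RFormula

// The same toy data as in the table above.
val dataset = sqlContext.createDataFrame(Seq(
  (7, "US", 18, 1.0),
  (8, "CA", 12, 0.0),
  (9, "NZ", 15, 0.0)
)).toDF("id", "country", "hour", "clicked")

val formula = new RFormula()
  .setFormula("clicked ~ country + hour")
  .setFeaturesCol("features")
  .setLabelCol("label")

// One-hot encodes "country", casts "hour" to double, and creates the "label" column
// from the response variable "clicked".
val output = formula.fit(dataset).transform(dataset)
output.select("features", "label").show()
{% endhighlight %}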
+ +## ChiSqSelector + +`ChiSqSelector` stands for Chi-Squared feature selection. It operates on labeled data with +categorical features. ChiSqSelector orders features based on a +[Chi-Squared test of independence](https://en.wikipedia.org/wiki/Chi-squared_test) +from the class, and then filters (selects) the top features on which the class label depends the +most. This is akin to yielding the features with the most predictive power. + +**Examples** + +Assume that we have a DataFrame with the columns `id`, `features`, and `clicked`, where `clicked` is +our target to be predicted: + +~~~ +id | features | clicked +---|-----------------------|--------- + 7 | [0.0, 0.0, 18.0, 1.0] | 1.0 + 8 | [0.0, 1.0, 12.0, 0.0] | 0.0 + 9 | [1.0, 0.0, 15.0, 0.1] | 0.0 +~~~ + +If we use `ChiSqSelector` with `numTopFeatures = 1`, then according to our label `clicked` the +last column in our `features` is chosen as the most useful feature: + +~~~ +id | features | clicked | selectedFeatures +---|-----------------------|---------|------------------ + 7 | [0.0, 0.0, 18.0, 1.0] | 1.0 | [1.0] + 8 | [0.0, 1.0, 12.0, 0.0] | 0.0 | [0.0] + 9 | [1.0, 0.0, 15.0, 0.1] | 0.0 | [0.1] +~~~ + +<div class="codetabs">
    +
    + +Refer to the [ChiSqSelector Scala docs](api/scala/index.html#org.apache.spark.ml.feature.ChiSqSelector) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala %} +
    + +
    + +Refer to the [ChiSqSelector Java docs](api/java/org/apache/spark/ml/feature/ChiSqSelector.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java %} +
    +
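For quick reference, a minimal inline Scala sketch of the selection above, assuming an existing `sqlContext`; the bundled `ChiSqSelectorExample` files linked above are the complete versions.

{% highlight scala %}
import org.apache.spark.ml.feature.ChiSqSelector
import org.apache.spark.mllib.linalg.Vectors

// The same toy data as in the table above.
val df = sqlContext.createDataFrame(Seq(
  (7, Vectors.dense(0.0, 0.0, 18.0, 1.0), 1.0),
  (8, Vectors.dense(0.0, 1.0, 12.0, 0.0), 0.0),
  (9, Vectors.dense(1.0, 0.0, 15.0, 0.1), 0.0)
)).toDF("id", "features", "clicked")

val selector = new ChiSqSelector()
  .setNumTopFeatures(1)
  .setFeaturesCol("features")
  .setLabelCol("clicked")
  .setOutputCol("selectedFeatures")

// Keeps the single feature most strongly associated with "clicked"
// according to the chi-squared test.
val result = selector.fit(df).transform(df)
result.show()
{% endhighlight %}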
    diff --git a/docs/ml-guide.md b/docs/ml-guide.md index b6ca50e98db0..44a316a07dfe 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -1,8 +1,10 @@ --- layout: global -title: Spark ML Programming Guide +title: "Overview: estimators, transformers and pipelines - spark.ml" +displayTitle: "Overview: estimators, transformers and pipelines - spark.ml" --- + `\[ \newcommand{\R}{\mathbb{R}} \newcommand{\E}{\mathbb{E}} @@ -21,75 +23,78 @@ title: Spark ML Programming Guide \]` -Spark 1.2 introduced a new package called `spark.ml`, which aims to provide a uniform set of -high-level APIs that help users create and tune practical machine learning pipelines. - -*Graduated from Alpha!* The Pipelines API is no longer an alpha component, although many elements of it are still `Experimental` or `DeveloperApi`. - -Note that we will keep supporting and adding features to `spark.mllib` along with the -development of `spark.ml`. -Users should be comfortable using `spark.mllib` features and expect more features coming. -Developers should contribute new algorithms to `spark.mllib` and can optionally contribute -to `spark.ml`. - -Guides for sub-packages of `spark.ml` include: - -* [Feature Extraction, Transformation, and Selection](ml-features.html): Details on transformers supported in the Pipelines API, including a few not in the lower-level `spark.mllib` API -* [Ensembles](ml-ensembles.html): Details on ensemble learning methods in the Pipelines API - +The `spark.ml` package aims to provide a uniform set of high-level APIs built on top of +[DataFrames](sql-programming-guide.html#dataframes) that help users create and tune practical +machine learning pipelines. +See the [algorithm guides](#algorithm-guides) section below for guides on sub-packages of +`spark.ml`, including feature transformers unique to the Pipelines API, ensembles, and more. -**Table of Contents** +**Table of contents** * This will become a table of contents (this text will be scraped). {:toc} -# Main Concepts -Spark ML standardizes APIs for machine learning algorithms to make it easier to combine multiple algorithms into a single pipeline, or workflow. This section covers the key concepts introduced by the Spark ML API. +# Main concepts in Pipelines -* **[ML Dataset](ml-guide.html#ml-dataset)**: Spark ML uses the [`DataFrame`](api/scala/index.html#org.apache.spark.sql.DataFrame) from Spark SQL as a dataset which can hold a variety of data types. -E.g., a dataset could have different columns storing text, feature vectors, true labels, and predictions. +Spark ML standardizes APIs for machine learning algorithms to make it easier to combine multiple +algorithms into a single pipeline, or workflow. +This section covers the key concepts introduced by the Spark ML API, where the pipeline concept is +mostly inspired by the [scikit-learn](http://scikit-learn.org/) project. + +* **[`DataFrame`](ml-guide.html#dataframe)**: Spark ML uses `DataFrame` from Spark SQL as an ML + dataset, which can hold a variety of data types. + E.g., a `DataFrame` could have different columns storing text, feature vectors, true labels, and predictions. * **[`Transformer`](ml-guide.html#transformers)**: A `Transformer` is an algorithm which can transform one `DataFrame` into another `DataFrame`. -E.g., an ML model is a `Transformer` which transforms an RDD with features into an RDD with predictions. +E.g., an ML model is a `Transformer` which transforms `DataFrame` with features into a `DataFrame` with predictions. 
* **[`Estimator`](ml-guide.html#estimators)**: An `Estimator` is an algorithm which can be fit on a `DataFrame` to produce a `Transformer`. -E.g., a learning algorithm is an `Estimator` which trains on a dataset and produces a model. +E.g., a learning algorithm is an `Estimator` which trains on a `DataFrame` and produces a model. * **[`Pipeline`](ml-guide.html#pipeline)**: A `Pipeline` chains multiple `Transformer`s and `Estimator`s together to specify an ML workflow. -* **[`Param`](ml-guide.html#parameters)**: All `Transformer`s and `Estimator`s now share a common API for specifying parameters. +* **[`Parameter`](ml-guide.html#parameters)**: All `Transformer`s and `Estimator`s now share a common API for specifying parameters. -## ML Dataset +## DataFrame Machine learning can be applied to a wide variety of data types, such as vectors, text, images, and structured data. -Spark ML adopts the [`DataFrame`](api/scala/index.html#org.apache.spark.sql.DataFrame) from Spark SQL in order to support a variety of data types under a unified Dataset concept. +Spark ML adopts the `DataFrame` from Spark SQL in order to support a variety of data types. `DataFrame` supports many basic and structured types; see the [Spark SQL datatype reference](sql-programming-guide.html#spark-sql-datatype-reference) for a list of supported types. -In addition to the types listed in the Spark SQL guide, `DataFrame` can use ML [`Vector`](api/scala/index.html#org.apache.spark.mllib.linalg.Vector) types. +In addition to the types listed in the Spark SQL guide, `DataFrame` can use ML [`Vector`](mllib-data-types.html#local-vector) types. A `DataFrame` can be created either implicitly or explicitly from a regular `RDD`. See the code examples below and the [Spark SQL programming guide](sql-programming-guide.html) for examples. Columns in a `DataFrame` are named. The code examples below use names such as "text," "features," and "label." -## ML Algorithms +## Pipeline components ### Transformers -A [`Transformer`](api/scala/index.html#org.apache.spark.ml.Transformer) is an abstraction which includes feature transformers and learned models. Technically, a `Transformer` implements a method `transform()` which converts one `DataFrame` into another, generally by appending one or more columns. +A `Transformer` is an abstraction that includes feature transformers and learned models. +Technically, a `Transformer` implements a method `transform()`, which converts one `DataFrame` into +another, generally by appending one or more columns. For example: -* A feature transformer might take a dataset, read a column (e.g., text), convert it into a new column (e.g., feature vectors), append the new column to the dataset, and output the updated dataset. -* A learning model might take a dataset, read the column containing feature vectors, predict the label for each feature vector, append the labels as a new column, and output the updated dataset. +* A feature transformer might take a `DataFrame`, read a column (e.g., text), map it into a new + column (e.g., feature vectors), and output a new `DataFrame` with the mapped column appended. +* A learning model might take a `DataFrame`, read the column containing feature vectors, predict the + label for each feature vector, and output a new `DataFrame` with predicted labels appended as a + column. ### Estimators -An [`Estimator`](api/scala/index.html#org.apache.spark.ml.Estimator) abstracts the concept of a learning algorithm or any algorithm which fits or trains on data. 
Technically, an `Estimator` implements a method `fit()` which accepts a `DataFrame` and produces a `Transformer`. -For example, a learning algorithm such as `LogisticRegression` is an `Estimator`, and calling `fit()` trains a `LogisticRegressionModel`, which is a `Transformer`. +An `Estimator` abstracts the concept of a learning algorithm or any algorithm that fits or trains on +data. +Technically, an `Estimator` implements a method `fit()`, which accepts a `DataFrame` and produces a +`Model`, which is a `Transformer`. +For example, a learning algorithm such as `LogisticRegression` is an `Estimator`, and calling +`fit()` trains a `LogisticRegressionModel`, which is a `Model` and hence a `Transformer`. -### Properties of ML Algorithms +### Properties of pipeline components -`Transformer`s and `Estimator`s are both stateless. In the future, stateful algorithms may be supported via alternative concepts. +`Transformer.transform()`s and `Estimator.fit()`s are both stateless. In the future, stateful algorithms may be supported via alternative concepts. Each instance of a `Transformer` or `Estimator` has a unique ID, which is useful in specifying parameters (discussed below). @@ -102,15 +107,16 @@ E.g., a simple text document processing workflow might include several stages: * Convert each document's words into a numerical feature vector. * Learn a prediction model using the feature vectors and labels. -Spark ML represents such a workflow as a [`Pipeline`](api/scala/index.html#org.apache.spark.ml.Pipeline), -which consists of a sequence of [`PipelineStage`s](api/scala/index.html#org.apache.spark.ml.PipelineStage) (`Transformer`s and `Estimator`s) to be run in a specific order. We will use this simple workflow as a running example in this section. +Spark ML represents such a workflow as a `Pipeline`, which consists of a sequence of +`PipelineStage`s (`Transformer`s and `Estimator`s) to be run in a specific order. +We will use this simple workflow as a running example in this section. -### How It Works +### How it works A `Pipeline` is specified as a sequence of stages, and each stage is either a `Transformer` or an `Estimator`. -These stages are run in order, and the input dataset is modified as it passes through each stage. -For `Transformer` stages, the `transform()` method is called on the dataset. -For `Estimator` stages, the `fit()` method is called to produce a `Transformer` (which becomes part of the `PipelineModel`, or fitted `Pipeline`), and that `Transformer`'s `transform()` method is called on the dataset. +These stages are run in order, and the input `DataFrame` is transformed as it passes through each stage. +For `Transformer` stages, the `transform()` method is called on the `DataFrame`. +For `Estimator` stages, the `fit()` method is called to produce a `Transformer` (which becomes part of the `PipelineModel`, or fitted `Pipeline`), and that `Transformer`'s `transform()` method is called on the `DataFrame`. We illustrate this for the simple text document workflow. The figure below is for the *training time* usage of a `Pipeline`. @@ -126,14 +132,17 @@ We illustrate this for the simple text document workflow. The figure below is f Above, the top row represents a `Pipeline` with three stages. The first two (`Tokenizer` and `HashingTF`) are `Transformer`s (blue), and the third (`LogisticRegression`) is an `Estimator` (red). The bottom row represents data flowing through the pipeline, where cylinders indicate `DataFrame`s. 
-The `Pipeline.fit()` method is called on the original dataset which has raw text documents and labels. -The `Tokenizer.transform()` method splits the raw text documents into words, adding a new column with words into the dataset. -The `HashingTF.transform()` method converts the words column into feature vectors, adding a new column with those vectors to the dataset. +The `Pipeline.fit()` method is called on the original `DataFrame`, which has raw text documents and labels. +The `Tokenizer.transform()` method splits the raw text documents into words, adding a new column with words to the `DataFrame`. +The `HashingTF.transform()` method converts the words column into feature vectors, adding a new column with those vectors to the `DataFrame`. Now, since `LogisticRegression` is an `Estimator`, the `Pipeline` first calls `LogisticRegression.fit()` to produce a `LogisticRegressionModel`. -If the `Pipeline` had more stages, it would call the `LogisticRegressionModel`'s `transform()` method on the dataset before passing the dataset to the next stage. +If the `Pipeline` had more stages, it would call the `LogisticRegressionModel`'s `transform()` +method on the `DataFrame` before passing the `DataFrame` to the next stage. A `Pipeline` is an `Estimator`. -Thus, after a `Pipeline`'s `fit()` method runs, it produces a `PipelineModel` which is a `Transformer`. This `PipelineModel` is used at *test time*; the figure below illustrates this usage. +Thus, after a `Pipeline`'s `fit()` method runs, it produces a `PipelineModel`, which is a +`Transformer`. +This `PipelineModel` is used at *test time*; the figure below illustrates this usage.

    In the figure above, the `PipelineModel` has the same number of stages as the original `Pipeline`, but all `Estimator`s in the original `Pipeline` have become `Transformer`s. -When the `PipelineModel`'s `transform()` method is called on a test dataset, the data are passed through the `Pipeline` in order. +When the `PipelineModel`'s `transform()` method is called on a test dataset, the data are passed +through the fitted pipeline in order. Each stage's `transform()` method updates the dataset and passes it to the next stage. `Pipeline`s and `PipelineModel`s help to ensure that training and test data go through identical feature processing steps. @@ -154,41 +164,47 @@ Each stage's `transform()` method updates the dataset and passes it to the next *DAG `Pipeline`s*: A `Pipeline`'s stages are specified as an ordered array. The examples given here are all for linear `Pipeline`s, i.e., `Pipeline`s in which each stage uses data produced by the previous stage. It is possible to create non-linear `Pipeline`s as long as the data flow graph forms a Directed Acyclic Graph (DAG). This graph is currently specified implicitly based on the input and output column names of each stage (generally specified as parameters). If the `Pipeline` forms a DAG, then the stages must be specified in topological order. -*Runtime checking*: Since `Pipeline`s can operate on datasets with varied types, they cannot use compile-time type checking. `Pipeline`s and `PipelineModel`s instead do runtime checking before actually running the `Pipeline`. This type checking is done using the dataset *schema*, a description of the data types of columns in the `DataFrame`. +*Runtime checking*: Since `Pipeline`s can operate on `DataFrame`s with varied types, they cannot use +compile-time type checking. +`Pipeline`s and `PipelineModel`s instead do runtime checking before actually running the `Pipeline`. +This type checking is done using the `DataFrame` *schema*, a description of the data types of columns in the `DataFrame`. + +*Unique Pipeline stages*: A `Pipeline`'s stages should be unique instances. E.g., the same instance +`myHashingTF` should not be inserted into the `Pipeline` twice since `Pipeline` stages must have +unique IDs. However, different instances `myHashingTF1` and `myHashingTF2` (both of type `HashingTF`) +can be put into the same `Pipeline` since different instances will be created with different IDs. ## Parameters Spark ML `Estimator`s and `Transformer`s use a uniform API for specifying parameters. -A [`Param`](api/scala/index.html#org.apache.spark.ml.param.Param) is a named parameter with self-contained documentation. -A [`ParamMap`](api/scala/index.html#org.apache.spark.ml.param.ParamMap) is a set of (parameter, value) pairs. +A `Param` is a named parameter with self-contained documentation. +A `ParamMap` is a set of (parameter, value) pairs. There are two main ways to pass parameters to an algorithm: -1. Set parameters for an instance. E.g., if `lr` is an instance of `LogisticRegression`, one could call `lr.setMaxIter(10)` to make `lr.fit()` use at most 10 iterations. This API resembles the API used in MLlib. +1. Set parameters for an instance. E.g., if `lr` is an instance of `LogisticRegression`, one could + call `lr.setMaxIter(10)` to make `lr.fit()` use at most 10 iterations. + This API resembles the API used in `spark.mllib` package. 2. Pass a `ParamMap` to `fit()` or `transform()`. Any parameters in the `ParamMap` will override parameters previously specified via setter methods. 
Parameters belong to specific instances of `Estimator`s and `Transformer`s. For example, if we have two `LogisticRegression` instances `lr1` and `lr2`, then we can build a `ParamMap` with both `maxIter` parameters specified: `ParamMap(lr1.maxIter -> 10, lr2.maxIter -> 20)`. This is useful if there are two algorithms with the `maxIter` parameter in a `Pipeline`. -# Algorithm Guides - -There are now several algorithms in the Pipelines API which are not in the lower-level MLlib API, so we link to documentation for them here. These algorithms are mostly feature transformers, which fit naturally into the `Transformer` abstraction in Pipelines, and ensembles, which fit naturally into the `Estimator` abstraction in the Pipelines. - -**Pipelines API Algorithm Guides** - -* [Feature Extraction, Transformation, and Selection](ml-features.html) -* [Ensembles](ml-ensembles.html) - -**Algorithms in `spark.ml`** +## Saving and Loading Pipelines -* [Linear methods with elastic net regularization](ml-linear-methods.html) +It is often worthwhile to save a model or a pipeline to disk for later use. In Spark 1.6, model import/export functionality was added to the Pipeline API. Most basic transformers are supported, as well as some of the more basic ML models. Please refer to the algorithm's API documentation to see if saving and loading is supported. -# Code Examples +# Code examples This section gives code examples illustrating the functionality discussed above. -There is not yet documentation for specific algorithms in Spark ML. For more info, please refer to the [API Documentation](api/scala/index.html#org.apache.spark.ml.package). Spark ML algorithms are currently wrappers for MLlib algorithms, and the [MLlib programming guide](mllib-guide.html) has details on specific algorithms. +For more info, please refer to the API documentation +([Scala](api/scala/index.html#org.apache.spark.ml.package), +[Java](api/java/org/apache/spark/ml/package-summary.html), +and [Python](api/python/pyspark.ml.html)). +Some Spark ML algorithms are wrappers for `spark.mllib` algorithms, and the +[MLlib programming guide](mllib-guide.html) has details on specific algorithms. @@ -198,26 +214,18 @@ This example covers the concepts of `Estimator`, `Transformer`, and `Param`.

    {% highlight scala %} -import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.param.ParamMap import org.apache.spark.mllib.linalg.{Vector, Vectors} -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.Row -val conf = new SparkConf().setAppName("SimpleParamsExample") -val sc = new SparkContext(conf) -val sqlContext = new SQLContext(sc) -import sqlContext.implicits._ - -// Prepare training data. -// We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of case classes -// into DataFrames, where it uses the case class metadata to infer the schema. -val training = sc.parallelize(Seq( - LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), - LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), - LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), - LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)))) +// Prepare training data from a list of (label, features) tuples. +val training = sqlContext.createDataFrame(Seq( + (1.0, Vectors.dense(0.0, 1.1, 0.1)), + (0.0, Vectors.dense(2.0, 1.0, -1.0)), + (0.0, Vectors.dense(2.0, 1.3, 1.0)), + (1.0, Vectors.dense(0.0, 1.2, -0.5)) +)).toDF("label", "features") // Create a LogisticRegression instance. This instance is an Estimator. val lr = new LogisticRegression() @@ -229,7 +237,7 @@ lr.setMaxIter(10) .setRegParam(0.01) // Learn a LogisticRegression model. This uses the parameters stored in lr. -val model1 = lr.fit(training.toDF) +val model1 = lr.fit(training) // Since model1 is a Model (i.e., a Transformer produced by an Estimator), // we can view the parameters it used during fit(). // This prints the parameter (name: value) pairs, where names are unique IDs for this @@ -239,8 +247,8 @@ println("Model 1 was fit using parameters: " + model1.parent.extractParamMap) // We may alternatively specify parameters using a ParamMap, // which supports several methods for specifying parameters. val paramMap = ParamMap(lr.maxIter -> 20) -paramMap.put(lr.maxIter, 30) // Specify 1 Param. This overwrites the original maxIter. -paramMap.put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params. + .put(lr.maxIter, 30) // Specify 1 Param. This overwrites the original maxIter. + .put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params. // One can also combine ParamMaps. val paramMap2 = ParamMap(lr.probabilityCol -> "myProbability") // Change output column name @@ -248,58 +256,52 @@ val paramMapCombined = paramMap ++ paramMap2 // Now learn a new model using the paramMapCombined parameters. // paramMapCombined overrides all parameters set earlier via lr.set* methods. -val model2 = lr.fit(training.toDF, paramMapCombined) +val model2 = lr.fit(training, paramMapCombined) println("Model 2 was fit using parameters: " + model2.parent.extractParamMap) // Prepare test data. -val test = sc.parallelize(Seq( - LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), - LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), - LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)))) +val test = sqlContext.createDataFrame(Seq( + (1.0, Vectors.dense(-1.0, 1.5, 1.3)), + (0.0, Vectors.dense(3.0, 2.0, -0.1)), + (1.0, Vectors.dense(0.0, 2.2, -1.5)) +)).toDF("label", "features") // Make predictions on test data using the Transformer.transform() method. // LogisticRegression.transform will only use the 'features' column. 
// Note that model2.transform() outputs a 'myProbability' column instead of the usual // 'probability' column since we renamed the lr.probabilityCol parameter previously. -model2.transform(test.toDF) +model2.transform(test) .select("features", "label", "myProbability", "prediction") .collect() .foreach { case Row(features: Vector, label: Double, prob: Vector, prediction: Double) => println(s"($features, $label) -> prob=$prob, prediction=$prediction") } -sc.stop() {% endhighlight %}
    {% highlight java %} +import java.util.Arrays; import java.util.List; -import com.google.common.collect.Lists; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; + import org.apache.spark.ml.classification.LogisticRegressionModel; import org.apache.spark.ml.param.ParamMap; import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.Row; -SparkConf conf = new SparkConf().setAppName("JavaSimpleParamsExample"); -JavaSparkContext jsc = new JavaSparkContext(conf); -SQLContext jsql = new SQLContext(jsc); - // Prepare training data. // We use LabeledPoint, which is a JavaBean. Spark SQL can convert RDDs of JavaBeans // into DataFrames, where it uses the bean metadata to infer the schema. -List localTraining = Lists.newArrayList( +DataFrame training = sqlContext.createDataFrame(Arrays.asList( new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), - new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))); -DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledPoint.class); + new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)) +), LabeledPoint.class); // Create a LogisticRegression instance. This instance is an Estimator. LogisticRegression lr = new LogisticRegression(); @@ -319,14 +321,14 @@ LogisticRegressionModel model1 = lr.fit(training); System.out.println("Model 1 was fit using parameters: " + model1.parent().extractParamMap()); // We may alternatively specify parameters using a ParamMap. -ParamMap paramMap = new ParamMap(); -paramMap.put(lr.maxIter().w(20)); // Specify 1 Param. -paramMap.put(lr.maxIter(), 30); // This overwrites the original maxIter. -paramMap.put(lr.regParam().w(0.1), lr.threshold().w(0.55)); // Specify multiple Params. +ParamMap paramMap = new ParamMap() + .put(lr.maxIter().w(20)) // Specify 1 Param. + .put(lr.maxIter(), 30) // This overwrites the original maxIter. + .put(lr.regParam().w(0.1), lr.threshold().w(0.55)); // Specify multiple Params. // One can also combine ParamMaps. -ParamMap paramMap2 = new ParamMap(); -paramMap2.put(lr.probabilityCol().w("myProbability")); // Change output column name +ParamMap paramMap2 = new ParamMap() + .put(lr.probabilityCol().w("myProbability")); // Change output column name ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2); // Now learn a new model using the paramMapCombined parameters. @@ -335,11 +337,11 @@ LogisticRegressionModel model2 = lr.fit(training, paramMapCombined); System.out.println("Model 2 was fit using parameters: " + model2.parent().extractParamMap()); // Prepare test documents. -List localTest = Lists.newArrayList( - new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), - new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), - new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))); -DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class); +DataFrame test = sqlContext.createDataFrame(Arrays.asList( + new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), + new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), + new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)) +), LabeledPoint.class); // Make predictions on test documents using the Transformer.transform() method. 
// LogisticRegression.transform will only use the 'features' column. @@ -351,7 +353,68 @@ for (Row r: results.select("features", "label", "myProbability", "prediction").c + ", prediction=" + r.get(3)); } -jsc.stop(); +{% endhighlight %} +
    + +
    +{% highlight python %} +from pyspark.mllib.linalg import Vectors +from pyspark.ml.classification import LogisticRegression +from pyspark.ml.param import Param, Params + +# Prepare training data from a list of (label, features) tuples. +training = sqlContext.createDataFrame([ + (1.0, Vectors.dense([0.0, 1.1, 0.1])), + (0.0, Vectors.dense([2.0, 1.0, -1.0])), + (0.0, Vectors.dense([2.0, 1.3, 1.0])), + (1.0, Vectors.dense([0.0, 1.2, -0.5]))], ["label", "features"]) + +# Create a LogisticRegression instance. This instance is an Estimator. +lr = LogisticRegression(maxIter=10, regParam=0.01) +# Print out the parameters, documentation, and any default values. +print "LogisticRegression parameters:\n" + lr.explainParams() + "\n" + +# Learn a LogisticRegression model. This uses the parameters stored in lr. +model1 = lr.fit(training) + +# Since model1 is a Model (i.e., a transformer produced by an Estimator), +# we can view the parameters it used during fit(). +# This prints the parameter (name: value) pairs, where names are unique IDs for this +# LogisticRegression instance. +print "Model 1 was fit using parameters: " +print model1.extractParamMap() + +# We may alternatively specify parameters using a Python dictionary as a paramMap +paramMap = {lr.maxIter: 20} +paramMap[lr.maxIter] = 30 # Specify 1 Param, overwriting the original maxIter. +paramMap.update({lr.regParam: 0.1, lr.threshold: 0.55}) # Specify multiple Params. + +# You can combine paramMaps, which are python dictionaries. +paramMap2 = {lr.probabilityCol: "myProbability"} # Change output column name +paramMapCombined = paramMap.copy() +paramMapCombined.update(paramMap2) + +# Now learn a new model using the paramMapCombined parameters. +# paramMapCombined overrides all parameters set earlier via lr.set* methods. +model2 = lr.fit(training, paramMapCombined) +print "Model 2 was fit using parameters: " +print model2.extractParamMap() + +# Prepare test data +test = sqlContext.createDataFrame([ + (1.0, Vectors.dense([-1.0, 1.5, 1.3])), + (0.0, Vectors.dense([3.0, 2.0, -0.1])), + (1.0, Vectors.dense([0.0, 2.2, -1.5]))], ["label", "features"]) + +# Make predictions on test data using the Transformer.transform() method. +# LogisticRegression.transform will only use the 'features' column. +# Note that model2.transform() outputs a "myProbability" column instead of the usual +# 'probability' column since we renamed the lr.probabilityCol parameter previously. +prediction = model2.transform(test) +selected = prediction.select("features", "label", "myProbability", "prediction") +for row in selected.collect(): + print row + {% endhighlight %}
    @@ -365,30 +428,19 @@ This example follows the simple text document `Pipeline` illustrated in the figu
    {% highlight scala %} -import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.feature.{HashingTF, Tokenizer} import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.sql.{Row, SQLContext} - -// Labeled and unlabeled instance types. -// Spark SQL can infer schema from case classes. -case class LabeledDocument(id: Long, text: String, label: Double) -case class Document(id: Long, text: String) +import org.apache.spark.sql.Row -// Set up contexts. Import implicit conversions to DataFrame from sqlContext. -val conf = new SparkConf().setAppName("SimpleTextClassificationPipeline") -val sc = new SparkContext(conf) -val sqlContext = new SQLContext(sc) -import sqlContext.implicits._ - -// Prepare training documents, which are labeled. -val training = sc.parallelize(Seq( - LabeledDocument(0L, "a b c d e spark", 1.0), - LabeledDocument(1L, "b d", 0.0), - LabeledDocument(2L, "spark f g h", 1.0), - LabeledDocument(3L, "hadoop mapreduce", 0.0))) +// Prepare training documents from a list of (id, text, label) tuples. +val training = sqlContext.createDataFrame(Seq( + (0L, "a b c d e spark", 1.0), + (1L, "b d", 0.0), + (2L, "spark f g h", 1.0), + (3L, "hadoop mapreduce", 0.0) +)).toDF("id", "text", "label") // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. val tokenizer = new Tokenizer() @@ -405,33 +457,41 @@ val pipeline = new Pipeline() .setStages(Array(tokenizer, hashingTF, lr)) // Fit the pipeline to training documents. -val model = pipeline.fit(training.toDF) +val model = pipeline.fit(training) -// Prepare test documents, which are unlabeled. -val test = sc.parallelize(Seq( - Document(4L, "spark i j k"), - Document(5L, "l m n"), - Document(6L, "mapreduce spark"), - Document(7L, "apache hadoop"))) +// now we can optionally save the fitted pipeline to disk +model.save("/tmp/spark-logistic-regression-model") + +// we can also save this unfit pipeline to disk +pipeline.save("/tmp/unfit-lr-model") + +// and load it back in during production +val sameModel = Pipeline.load("/tmp/spark-logistic-regression-model") + +// Prepare test documents, which are unlabeled (id, text) tuples. +val test = sqlContext.createDataFrame(Seq( + (4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop") +)).toDF("id", "text") // Make predictions on test documents. -model.transform(test.toDF) +model.transform(test) .select("id", "text", "probability", "prediction") .collect() .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => println(s"($id, $text) --> prob=$prob, prediction=$prediction") } -sc.stop() {% endhighlight %}
    {% highlight java %} +import java.util.Arrays; import java.util.List; -import com.google.common.collect.Lists; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; + import org.apache.spark.ml.Pipeline; import org.apache.spark.ml.PipelineModel; import org.apache.spark.ml.PipelineStage; @@ -440,7 +500,6 @@ import org.apache.spark.ml.feature.HashingTF; import org.apache.spark.ml.feature.Tokenizer; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; // Labeled and unlabeled instance types. // Spark SQL can infer schema from Java Beans. @@ -472,18 +531,13 @@ public class LabeledDocument extends Document implements Serializable { public void setLabel(double label) { this.label = label; } } -// Set up contexts. -SparkConf conf = new SparkConf().setAppName("JavaSimpleTextClassificationPipeline"); -JavaSparkContext jsc = new JavaSparkContext(conf); -SQLContext jsql = new SQLContext(jsc); - // Prepare training documents, which are labeled. -List localTraining = Lists.newArrayList( +DataFrame training = sqlContext.createDataFrame(Arrays.asList( new LabeledDocument(0L, "a b c d e spark", 1.0), new LabeledDocument(1L, "b d", 0.0), new LabeledDocument(2L, "spark f g h", 1.0), - new LabeledDocument(3L, "hadoop mapreduce", 0.0)); -DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class); + new LabeledDocument(3L, "hadoop mapreduce", 0.0) +), LabeledDocument.class); // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. Tokenizer tokenizer = new Tokenizer() @@ -503,12 +557,12 @@ Pipeline pipeline = new Pipeline() PipelineModel model = pipeline.fit(training); // Prepare test documents, which are unlabeled. -List localTest = Lists.newArrayList( +DataFrame test = sqlContext.createDataFrame(Arrays.asList( new Document(4L, "spark i j k"), new Document(5L, "l m n"), new Document(6L, "mapreduce spark"), - new Document(7L, "apache hadoop")); -DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class); + new Document(7L, "apache hadoop") +), Document.class); // Make predictions on test documents. DataFrame predictions = model.transform(test); @@ -517,28 +571,23 @@ for (Row r: predictions.select("id", "text", "probability", "prediction").collec + ", prediction=" + r.get(3)); } -jsc.stop(); {% endhighlight %}
    {% highlight python %} -from pyspark import SparkContext from pyspark.ml import Pipeline from pyspark.ml.classification import LogisticRegression from pyspark.ml.feature import HashingTF, Tokenizer -from pyspark.sql import Row, SQLContext +from pyspark.sql import Row -sc = SparkContext(appName="SimpleTextClassificationPipeline") -sqlContext = SQLContext(sc) - -# Prepare training documents, which are labeled. +# Prepare training documents from a list of (id, text, label) tuples. LabeledDocument = Row("id", "text", "label") -training = sc.parallelize([(0L, "a b c d e spark", 1.0), - (1L, "b d", 0.0), - (2L, "spark f g h", 1.0), - (3L, "hadoop mapreduce", 0.0)]) \ - .map(lambda x: LabeledDocument(*x)).toDF() +training = sqlContext.createDataFrame([ + (0L, "a b c d e spark", 1.0), + (1L, "b d", 0.0), + (2L, "spark f g h", 1.0), + (3L, "hadoop mapreduce", 0.0)], ["id", "text", "label"]) # Configure an ML pipeline, which consists of tree stages: tokenizer, hashingTF, and lr. tokenizer = Tokenizer(inputCol="text", outputCol="words") @@ -549,13 +598,12 @@ pipeline = Pipeline(stages=[tokenizer, hashingTF, lr]) # Fit the pipeline to training documents. model = pipeline.fit(training) -# Prepare test documents, which are unlabeled. -Document = Row("id", "text") -test = sc.parallelize([(4L, "spark i j k"), - (5L, "l m n"), - (6L, "mapreduce spark"), - (7L, "apache hadoop")]) \ - .map(lambda x: Document(*x)).toDF() +# Prepare test documents, which are unlabeled (id, text) tuples. +test = sqlContext.createDataFrame([ + (4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop")], ["id", "text"]) # Make predictions on test documents and print columns of interest. prediction = model.transform(test) @@ -563,20 +611,26 @@ selected = prediction.select("id", "text", "prediction") for row in selected.collect(): print(row) -sc.stop() {% endhighlight %}
-## Example: Model Selection via Cross-Validation +## Example: model selection via cross-validation An important task in ML is *model selection*, or using data to find the best model or parameters for a given task. This is also called *tuning*. `Pipeline`s facilitate model selection by making it easy to tune an entire `Pipeline` at once, rather than tuning each element in the `Pipeline` separately. -Currently, `spark.ml` supports model selection using the [`CrossValidator`](api/scala/index.html#org.apache.spark.ml.tuning.CrossValidator) class, which takes an `Estimator`, a set of `ParamMap`s, and an [`Evaluator`](api/scala/index.html#org.apache.spark.ml.Evaluator). +Currently, `spark.ml` supports model selection using the [`CrossValidator`](api/scala/index.html#org.apache.spark.ml.tuning.CrossValidator) class, which takes an `Estimator`, a set of `ParamMap`s, and an [`Evaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.Evaluator). `CrossValidator` begins by splitting the dataset into a set of *folds* which are used as separate training and test datasets; e.g., with `$k=3$` folds, `CrossValidator` will generate 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing. `CrossValidator` iterates through the set of `ParamMap`s. For each `ParamMap`, it trains the given `Estimator` and evaluates it using the given `Evaluator`. + +The `Evaluator` can be a [`RegressionEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.RegressionEvaluator) +for regression problems, a [`BinaryClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.BinaryClassificationEvaluator) +for binary data, or a [`MulticlassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator) +for multiclass problems. The default metric used to choose the best `ParamMap` can be overridden by the `setMetricName` +method in each of these evaluators. + The `ParamMap` which produces the best evaluation metric (averaged over the `$k$` folds) is selected as the best model. `CrossValidator` finally fits the `Estimator` using the best `ParamMap` and the entire dataset. @@ -593,39 +647,29 @@ However, it is also a well-established method for choosing parameters which is m
    {% highlight scala %} -import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator import org.apache.spark.ml.feature.{HashingTF, Tokenizer} import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator} import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.sql.{Row, SQLContext} - -// Labeled and unlabeled instance types. -// Spark SQL can infer schema from case classes. -case class LabeledDocument(id: Long, text: String, label: Double) -case class Document(id: Long, text: String) - -val conf = new SparkConf().setAppName("CrossValidatorExample") -val sc = new SparkContext(conf) -val sqlContext = new SQLContext(sc) -import sqlContext.implicits._ - -// Prepare training documents, which are labeled. -val training = sc.parallelize(Seq( - LabeledDocument(0L, "a b c d e spark", 1.0), - LabeledDocument(1L, "b d", 0.0), - LabeledDocument(2L, "spark f g h", 1.0), - LabeledDocument(3L, "hadoop mapreduce", 0.0), - LabeledDocument(4L, "b spark who", 1.0), - LabeledDocument(5L, "g d a y", 0.0), - LabeledDocument(6L, "spark fly", 1.0), - LabeledDocument(7L, "was mapreduce", 0.0), - LabeledDocument(8L, "e spark program", 1.0), - LabeledDocument(9L, "a e c l", 0.0), - LabeledDocument(10L, "spark compile", 1.0), - LabeledDocument(11L, "hadoop software", 0.0))) +import org.apache.spark.sql.Row + +// Prepare training data from a list of (id, text, label) tuples. +val training = sqlContext.createDataFrame(Seq( + (0L, "a b c d e spark", 1.0), + (1L, "b d", 0.0), + (2L, "spark f g h", 1.0), + (3L, "hadoop mapreduce", 0.0), + (4L, "b spark who", 1.0), + (5L, "g d a y", 0.0), + (6L, "spark fly", 1.0), + (7L, "was mapreduce", 0.0), + (8L, "e spark program", 1.0), + (9L, "a e c l", 0.0), + (10L, "spark compile", 1.0), + (11L, "hadoop software", 0.0) +)).toDF("id", "text", "label") // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. val tokenizer = new Tokenizer() @@ -639,12 +683,6 @@ val lr = new LogisticRegression() val pipeline = new Pipeline() .setStages(Array(tokenizer, hashingTF, lr)) -// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. -// This will allow us to jointly choose parameters for all Pipeline stages. -// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. -val crossval = new CrossValidator() - .setEstimator(pipeline) - .setEvaluator(new BinaryClassificationEvaluator) // We use a ParamGridBuilder to construct a grid of parameters to search over. // With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, // this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. @@ -652,37 +690,45 @@ val paramGrid = new ParamGridBuilder() .addGrid(hashingTF.numFeatures, Array(10, 100, 1000)) .addGrid(lr.regParam, Array(0.1, 0.01)) .build() -crossval.setEstimatorParamMaps(paramGrid) -crossval.setNumFolds(2) // Use 3+ in practice + +// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. +// This will allow us to jointly choose parameters for all Pipeline stages. +// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. +// Note that the evaluator here is a BinaryClassificationEvaluator and its default metric +// is areaUnderROC. 
+val cv = new CrossValidator() + .setEstimator(pipeline) + .setEvaluator(new BinaryClassificationEvaluator) + .setEstimatorParamMaps(paramGrid) + .setNumFolds(2) // Use 3+ in practice // Run cross-validation, and choose the best set of parameters. -val cvModel = crossval.fit(training.toDF) +val cvModel = cv.fit(training) -// Prepare test documents, which are unlabeled. -val test = sc.parallelize(Seq( - Document(4L, "spark i j k"), - Document(5L, "l m n"), - Document(6L, "mapreduce spark"), - Document(7L, "apache hadoop"))) +// Prepare test documents, which are unlabeled (id, text) tuples. +val test = sqlContext.createDataFrame(Seq( + (4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop") +)).toDF("id", "text") // Make predictions on test documents. cvModel uses the best model found (lrModel). -cvModel.transform(test.toDF) +cvModel.transform(test) .select("id", "text", "probability", "prediction") .collect() .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => - println(s"($id, $text) --> prob=$prob, prediction=$prediction") -} + println(s"($id, $text) --> prob=$prob, prediction=$prediction") + } -sc.stop() {% endhighlight %}
    {% highlight java %} +import java.util.Arrays; import java.util.List; -import com.google.common.collect.Lists; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; + import org.apache.spark.ml.Pipeline; import org.apache.spark.ml.PipelineStage; import org.apache.spark.ml.classification.LogisticRegression; @@ -695,7 +741,6 @@ import org.apache.spark.ml.tuning.CrossValidatorModel; import org.apache.spark.ml.tuning.ParamGridBuilder; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; // Labeled and unlabeled instance types. // Spark SQL can infer schema from Java Beans. @@ -727,12 +772,9 @@ public class LabeledDocument extends Document implements Serializable { public void setLabel(double label) { this.label = label; } } -SparkConf conf = new SparkConf().setAppName("JavaCrossValidatorExample"); -JavaSparkContext jsc = new JavaSparkContext(conf); -SQLContext jsql = new SQLContext(jsc); // Prepare training documents, which are labeled. -List localTraining = Lists.newArrayList( +DataFrame training = sqlContext.createDataFrame(Arrays.asList( new LabeledDocument(0L, "a b c d e spark", 1.0), new LabeledDocument(1L, "b d", 0.0), new LabeledDocument(2L, "spark f g h", 1.0), @@ -744,8 +786,8 @@ List localTraining = Lists.newArrayList( new LabeledDocument(8L, "e spark program", 1.0), new LabeledDocument(9L, "a e c l", 0.0), new LabeledDocument(10L, "spark compile", 1.0), - new LabeledDocument(11L, "hadoop software", 0.0)); -DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class); + new LabeledDocument(11L, "hadoop software", 0.0) +), LabeledDocument.class); // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. Tokenizer tokenizer = new Tokenizer() @@ -761,12 +803,6 @@ LogisticRegression lr = new LogisticRegression() Pipeline pipeline = new Pipeline() .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); -// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. -// This will allow us to jointly choose parameters for all Pipeline stages. -// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. -CrossValidator crossval = new CrossValidator() - .setEstimator(pipeline) - .setEvaluator(new BinaryClassificationEvaluator()); // We use a ParamGridBuilder to construct a grid of parameters to search over. // With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, // this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. @@ -774,19 +810,28 @@ ParamMap[] paramGrid = new ParamGridBuilder() .addGrid(hashingTF.numFeatures(), new int[]{10, 100, 1000}) .addGrid(lr.regParam(), new double[]{0.1, 0.01}) .build(); -crossval.setEstimatorParamMaps(paramGrid); -crossval.setNumFolds(2); // Use 3+ in practice + +// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. +// This will allow us to jointly choose parameters for all Pipeline stages. +// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. +// Note that the evaluator here is a BinaryClassificationEvaluator and its default metric +// is areaUnderROC. +CrossValidator cv = new CrossValidator() + .setEstimator(pipeline) + .setEvaluator(new BinaryClassificationEvaluator()) + .setEstimatorParamMaps(paramGrid) + .setNumFolds(2); // Use 3+ in practice // Run cross-validation, and choose the best set of parameters. 
-CrossValidatorModel cvModel = crossval.fit(training); +CrossValidatorModel cvModel = cv.fit(training); // Prepare test documents, which are unlabeled. -List localTest = Lists.newArrayList( +DataFrame test = sqlContext.createDataFrame(Arrays.asList( new Document(4L, "spark i j k"), new Document(5L, "l m n"), new Document(6L, "mapreduce spark"), - new Document(7L, "apache hadoop")); -DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class); + new Document(7L, "apache hadoop") +), Document.class); // Make predictions on test documents. cvModel uses the best model found (lrModel). DataFrame predictions = cvModel.transform(test); @@ -795,40 +840,115 @@ for (Row r: predictions.select("id", "text", "probability", "prediction").collec + ", prediction=" + r.get(3)); } -jsc.stop(); {% endhighlight %}
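Once cross-validation has run, the fitted model can be inspected; the fragment below is a hedged sketch that assumes the `paramGrid` and `cvModel` values from the Scala example above and relies on `CrossValidatorModel`'s `avgMetrics` and `bestModel` members.

{% highlight scala %}
// A minimal sketch, assuming `paramGrid` and `cvModel` from the Scala example above.
// avgMetrics holds the metric averaged over the folds, one entry per ParamMap
// (in the same order as paramGrid); bestModel is the model refit on the full dataset.
paramGrid.zip(cvModel.avgMetrics).foreach { case (params, metric) =>
  println(s"$params => average metric: $metric")
}
println(s"Best model: ${cvModel.bestModel}")
{% endhighlight %}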
    -# Dependencies +## Example: model selection via train validation split +In addition to `CrossValidator`, Spark also offers `TrainValidationSplit` for hyper-parameter tuning. +`TrainValidationSplit` only evaluates each combination of parameters once, as opposed to k times in + the case of `CrossValidator`. It is therefore less expensive, + but will not produce as reliable results when the training dataset is not sufficiently large. + +`TrainValidationSplit` takes an `Estimator`, a set of `ParamMap`s provided in the `estimatorParamMaps` parameter, +and an `Evaluator`. +It begins by splitting the dataset into two parts using the `trainRatio` parameter; +these are used as separate training and test datasets. For example, with `$trainRatio=0.75$` (default), +`TrainValidationSplit` will generate a training and test dataset pair where 75% of the data is used for training and 25% for validation. +Similar to `CrossValidator`, `TrainValidationSplit` also iterates through the set of `ParamMap`s. +For each combination of parameters, it trains the given `Estimator` and evaluates it using the given `Evaluator`. +The `ParamMap` which produces the best evaluation metric is selected as the best option. +`TrainValidationSplit` finally fits the `Estimator` using the best `ParamMap` and the entire dataset. + +
    + +
    +{% highlight scala %} +import org.apache.spark.ml.evaluation.RegressionEvaluator +import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit} + +// Prepare training and test data. +val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") +val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345) + +val lr = new LinearRegression() + +// We use a ParamGridBuilder to construct a grid of parameters to search over. +// TrainValidationSplit will try all combinations of values and determine best model using +// the evaluator. +val paramGrid = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.01)) + .addGrid(lr.fitIntercept) + .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0)) + .build() -Spark ML currently depends on MLlib and has the same dependencies. -Please see the [MLlib Dependencies guide](mllib-guide.html#dependencies) for more info. +// In this case the estimator is simply the linear regression. +// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. +val trainValidationSplit = new TrainValidationSplit() + .setEstimator(lr) + .setEvaluator(new RegressionEvaluator) + .setEstimatorParamMaps(paramGrid) + // 80% of the data will be used for training and the remaining 20% for validation. + .setTrainRatio(0.8) -Spark ML also depends upon Spark SQL, but the relevant parts of Spark SQL do not bring additional dependencies. +// Run train validation split, and choose the best set of parameters. +val model = trainValidationSplit.fit(training) -# Migration Guide +// Make predictions on test data. model is the model with combination of parameters +// that performed best. +model.transform(test) + .select("features", "label", "prediction") + .show() -## From 1.3 to 1.4 +{% endhighlight %} +
    -Several major API changes occurred, including: -* `Param` and other APIs for specifying parameters -* `uid` unique IDs for Pipeline components -* Reorganization of certain classes -Since the `spark.ml` API was an Alpha Component in Spark 1.3, we do not list all changes here. +
    +{% highlight java %} +import org.apache.spark.ml.evaluation.RegressionEvaluator; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.ml.regression.LinearRegression; +import org.apache.spark.ml.tuning.*; +import org.apache.spark.sql.DataFrame; -However, now that `spark.ml` is no longer an Alpha Component, we will provide details on any API changes for future releases. +DataFrame data = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); -## From 1.2 to 1.3 +// Prepare training and test data. +DataFrame[] splits = data.randomSplit(new double[] {0.9, 0.1}, 12345); +DataFrame training = splits[0]; +DataFrame test = splits[1]; -The main API changes are from Spark SQL. We list the most important changes here: +LinearRegression lr = new LinearRegression(); -* The old [SchemaRDD](http://spark.apache.org/docs/1.2.1/api/scala/index.html#org.apache.spark.sql.SchemaRDD) has been replaced with [DataFrame](api/scala/index.html#org.apache.spark.sql.DataFrame) with a somewhat modified API. All algorithms in Spark ML which used to use SchemaRDD now use DataFrame. -* In Spark 1.2, we used implicit conversions from `RDD`s of `LabeledPoint` into `SchemaRDD`s by calling `import sqlContext._` where `sqlContext` was an instance of `SQLContext`. These implicits have been moved, so we now call `import sqlContext.implicits._`. -* Java APIs for SQL have also changed accordingly. Please see the examples above and the [Spark SQL Programming Guide](sql-programming-guide.html) for details. +// We use a ParamGridBuilder to construct a grid of parameters to search over. +// TrainValidationSplit will try all combinations of values and determine best model using +// the evaluator. +ParamMap[] paramGrid = new ParamGridBuilder() + .addGrid(lr.regParam(), new double[] {0.1, 0.01}) + .addGrid(lr.fitIntercept()) + .addGrid(lr.elasticNetParam(), new double[] {0.0, 0.5, 1.0}) + .build(); + +// In this case the estimator is simply the linear regression. +// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. +TrainValidationSplit trainValidationSplit = new TrainValidationSplit() + .setEstimator(lr) + .setEvaluator(new RegressionEvaluator()) + .setEstimatorParamMaps(paramGrid) + .setTrainRatio(0.8); // 80% for training and the remaining 20% for validation + +// Run train validation split, and choose the best set of parameters. +TrainValidationSplitModel model = trainValidationSplit.fit(training); + +// Make predictions on test data. model is the model with combination of parameters +// that performed best. +model.transform(test) + .select("features", "label", "prediction") + .show(); -Other changes were in `LogisticRegression`: +{% endhighlight %} +
    -* The `scoreCol` output column (with default value "score") was renamed to be `probabilityCol` (with default value "probability"). The type was originally `Double` (for the probability of class 1.0), but it is now `Vector` (for the probability of each class, to support multiclass classification in the future). -* In Spark 1.2, `LogisticRegressionModel` did not include an intercept. In Spark 1.3, it includes an intercept; however, it will always be 0.0 since it uses the default settings for [spark.mllib.LogisticRegressionWithLBFGS](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS). The option to use an intercept will be added in the future. +
    \ No newline at end of file diff --git a/docs/ml-linear-methods.md b/docs/ml-linear-methods.md index 1ac83d94c9e8..a8754835cab9 100644 --- a/docs/ml-linear-methods.md +++ b/docs/ml-linear-methods.md @@ -1,129 +1,8 @@ --- layout: global -title: Linear Methods - ML -displayTitle: ML - Linear Methods +title: Linear methods - spark.ml +displayTitle: Linear methods - spark.ml --- - -`\[ -\newcommand{\R}{\mathbb{R}} -\newcommand{\E}{\mathbb{E}} -\newcommand{\x}{\mathbf{x}} -\newcommand{\y}{\mathbf{y}} -\newcommand{\wv}{\mathbf{w}} -\newcommand{\av}{\mathbf{\alpha}} -\newcommand{\bv}{\mathbf{b}} -\newcommand{\N}{\mathbb{N}} -\newcommand{\id}{\mathbf{I}} -\newcommand{\ind}{\mathbf{1}} -\newcommand{\0}{\mathbf{0}} -\newcommand{\unit}{\mathbf{e}} -\newcommand{\one}{\mathbf{1}} -\newcommand{\zero}{\mathbf{0}} -\]` - - -In MLlib, we implement popular linear methods such as logistic regression and linear least squares with L1 or L2 regularization. Refer to [the linear methods in mllib](mllib-linear-methods.html) for details. In `spark.ml`, we also include Pipelines API for [Elastic net](http://en.wikipedia.org/wiki/Elastic_net_regularization), a hybrid of L1 and L2 regularization proposed in [this paper](http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf). Mathematically it is defined as a linear combination of the L1-norm and the L2-norm: -`\[ -\alpha \|\wv\|_1 + (1-\alpha) \frac{1}{2}\|\wv\|_2^2, \alpha \in [0, 1]. -\]` -By setting $\alpha$ properly, it contains both L1 and L2 regularization as special cases. For example, if a [linear regression](https://en.wikipedia.org/wiki/Linear_regression) model is trained with the elastic net parameter $\alpha$ set to $1$, it is equivalent to a [Lasso](http://en.wikipedia.org/wiki/Least_squares#Lasso_method) model. On the other hand, if $\alpha$ is set to $0$, the trained model reduces to a [ridge regression](http://en.wikipedia.org/wiki/Tikhonov_regularization) model. We implement Pipelines API for both linear regression and logistic regression with elastic net regularization. - -**Examples** - -
    - -
    - -{% highlight scala %} - -import org.apache.spark.ml.classification.LogisticRegression -import org.apache.spark.mllib.util.MLUtils - -// Load training data -val training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() - -val lr = new LogisticRegression() - .setMaxIter(10) - .setRegParam(0.3) - .setElasticNetParam(0.8) - -// Fit the model -val lrModel = lr.fit(training) - -// Print the weights and intercept for logistic regression -println(s"Weights: ${lrModel.weights} Intercept: ${lrModel.intercept}") - -{% endhighlight %} - -
    - -
    - -{% highlight java %} - -import org.apache.spark.ml.classification.LogisticRegression; -import org.apache.spark.ml.classification.LogisticRegressionModel; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkContext; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.SQLContext; - -public class LogisticRegressionWithElasticNetExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf() - .setAppName("Logistic Regression with Elastic Net Example"); - - SparkContext sc = new SparkContext(conf); - SQLContext sql = new SQLContext(sc); - String path = "sample_libsvm_data.txt"; - - // Load training data - DataFrame training = sql.createDataFrame(MLUtils.loadLibSVMFile(sc, path).toJavaRDD(), LabeledPoint.class); - - LogisticRegression lr = new LogisticRegression() - .setMaxIter(10) - .setRegParam(0.3) - .setElasticNetParam(0.8) - - // Fit the model - LogisticRegressionModel lrModel = lr.fit(training); - - // Print the weights and intercept for logistic regression - System.out.println("Weights: " + lrModel.weights() + " Intercept: " + lrModel.intercept()); - } -} -{% endhighlight %} -
    - -
    - -{% highlight python %} - -from pyspark.ml.classification import LogisticRegression -from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.util import MLUtils - -# Load training data -training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() - -lr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8) - -# Fit the model -lrModel = lr.fit(training) - -# Print the weights and intercept for logistic regression -print("Weights: " + str(lrModel.weights)) -print("Intercept: " + str(lrModel.intercept)) -{% endhighlight %} - -
    - -
    - -### Optimization - -The optimization algorithm underlies the implementation is called [Orthant-Wise Limited-memory QuasiNewton](http://research-srv.microsoft.com/en-us/um/people/jfgao/paper/icml07scalable.pdf) -(OWL-QN). It is an extension of L-BFGS that can effectively handle L1 regularization and elastic net. + > This section has been moved into the + [classification and regression section](ml-classification-regression.html). diff --git a/docs/ml-survival-regression.md b/docs/ml-survival-regression.md new file mode 100644 index 000000000000..856ceb2f4e7f --- /dev/null +++ b/docs/ml-survival-regression.md @@ -0,0 +1,8 @@ +--- +layout: global +title: Survival Regression - spark.ml +displayTitle: Survival Regression - spark.ml +--- + + > This section has been moved into the + [classification and regression section](ml-classification-regression.html#survival-regression). diff --git a/docs/mllib-classification-regression.md b/docs/mllib-classification-regression.md index 0210950b8990..aaf8bd465c9a 100644 --- a/docs/mllib-classification-regression.md +++ b/docs/mllib-classification-regression.md @@ -1,10 +1,10 @@ --- layout: global -title: Classification and Regression - MLlib -displayTitle: MLlib - Classification and Regression +title: Classification and Regression - spark.mllib +displayTitle: Classification and Regression - spark.mllib --- -MLlib supports various methods for +The `spark.mllib` package supports various methods for [binary classification](http://en.wikipedia.org/wiki/Binary_classification), [multiclass classification](http://en.wikipedia.org/wiki/Multiclass_classification), and diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index bb875ae2ae6c..93cd0c1c61ae 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -1,28 +1,28 @@ --- layout: global -title: Clustering - MLlib -displayTitle: MLlib - Clustering +title: Clustering - spark.mllib +displayTitle: Clustering - spark.mllib --- -Clustering is an unsupervised learning problem whereby we aim to group subsets +[Clustering](https://en.wikipedia.org/wiki/Cluster_analysis) is an unsupervised learning problem whereby we aim to group subsets of entities with one another based on some notion of similarity. Clustering is often used for exploratory analysis and/or as a component of a hierarchical -supervised learning pipeline (in which distinct classifiers or regression +[supervised learning](https://en.wikipedia.org/wiki/Supervised_learning) pipeline (in which distinct classifiers or regression models are trained for each cluster). -MLlib supports the following models: +The `spark.mllib` package supports the following models: * Table of contents {:toc} ## K-means -[k-means](http://en.wikipedia.org/wiki/K-means_clustering) is one of the +[K-means](http://en.wikipedia.org/wiki/K-means_clustering) is one of the most commonly used clustering algorithms that clusters the data points into a -predefined number of clusters. The MLlib implementation includes a parallelized +predefined number of clusters. The `spark.mllib` implementation includes a parallelized variant of the [k-means++](http://en.wikipedia.org/wiki/K-means%2B%2B) method called [kmeans||](http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf). -The implementation in MLlib has the following parameters: +The implementation in `spark.mllib` has the following parameters: * *k* is the number of desired clusters. * *maxIterations* is the maximum number of iterations to run. @@ -47,6 +47,8 @@ into two clusters. 
The number of desired clusters is passed to the algorithm. We Set Sum of Squared Error (WSSSE). You can reduce this error measure by increasing *k*. In fact the optimal *k* is usually one where there is an "elbow" in the WSSSE graph. +Refer to the [`KMeans` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.KMeans) and [`KMeansModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.KMeansModel) for details on the API. + {% highlight scala %} import org.apache.spark.mllib.clustering.{KMeans, KMeansModel} import org.apache.spark.mllib.linalg.Vectors @@ -77,6 +79,8 @@ Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a calling `.rdd()` on your `JavaRDD` object. A self-contained application example that is equivalent to the provided example in Scala is given below: +Refer to the [`KMeans` Java docs](api/java/org/apache/spark/mllib/clustering/KMeans.html) and [`KMeansModel` Java docs](api/java/org/apache/spark/mllib/clustering/KMeansModel.html) for details on the API. + {% highlight java %} import org.apache.spark.api.java.*; import org.apache.spark.api.java.function.Function; @@ -132,6 +136,8 @@ data into two clusters. The number of desired clusters is passed to the algorith Within Set Sum of Squared Error (WSSSE). You can reduce this error measure by increasing *k*. In fact the optimal *k* is usually one where there is an "elbow" in the WSSSE graph. +Refer to the [`KMeans` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.clustering.KMeans) and [`KMeansModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.clustering.KMeansModel) for more details on the API. + {% highlight python %} from pyspark.mllib.clustering import KMeans, KMeansModel from numpy import array @@ -165,7 +171,7 @@ sameModel = KMeansModel.load(sc, "myModelPath") A [Gaussian Mixture Model](http://en.wikipedia.org/wiki/Mixture_model#Multivariate_Gaussian_mixture_model) represents a composite distribution whereby points are drawn from one of *k* Gaussian sub-distributions, -each with its own probability. The MLlib implementation uses the +each with its own probability. The `spark.mllib` implementation uses the [expectation-maximization](http://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) algorithm to induce the maximum-likelihood model given a set of samples. The implementation has the following parameters: @@ -184,6 +190,8 @@ In the following example after loading and parsing data, we use a object to cluster the data into two clusters. The number of desired clusters is passed to the algorithm. We then output the parameters of the mixture model. +Refer to the [`GaussianMixture` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.GaussianMixture) and [`GaussianMixtureModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.GaussianMixtureModel) for details on the API. + {% highlight scala %} import org.apache.spark.mllib.clustering.GaussianMixture import org.apache.spark.mllib.clustering.GaussianMixtureModel @@ -216,6 +224,8 @@ Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a calling `.rdd()` on your `JavaRDD` object. A self-contained application example that is equivalent to the provided example in Scala is given below: +Refer to the [`GaussianMixture` Java docs](api/java/org/apache/spark/mllib/clustering/GaussianMixture.html) and [`GaussianMixtureModel` Java docs](api/java/org/apache/spark/mllib/clustering/GaussianMixtureModel.html) for details on the API. 
+ {% highlight java %} import org.apache.spark.api.java.*; import org.apache.spark.api.java.function.Function; @@ -268,6 +278,8 @@ In the following example after loading and parsing data, we use a object to cluster the data into two clusters. The number of desired clusters is passed to the algorithm. We then output the parameters of the mixture model. +Refer to the [`GaussianMixture` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.clustering.GaussianMixture) and [`GaussianMixtureModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.clustering.GaussianMixtureModel) for more details on the API. + {% highlight python %} from pyspark.mllib.clustering import GaussianMixture from numpy import array @@ -296,13 +308,13 @@ graph given pairwise similarties as edge properties, described in [Lin and Cohen, Power Iteration Clustering](http://www.icml2010.org/papers/387.pdf). It computes a pseudo-eigenvector of the normalized affinity matrix of the graph via [power iteration](http://en.wikipedia.org/wiki/Power_iteration) and uses it to cluster vertices. -MLlib includes an implementation of PIC using GraphX as its backend. +`spark.mllib` includes an implementation of PIC using GraphX as its backend. It takes an `RDD` of `(srcId, dstId, similarity)` tuples and outputs a model with the clustering assignments. The similarities must be nonnegative. PIC assumes that the similarity measure is symmetric. A pair `(srcId, dstId)` regardless of the ordering should appear at most once in the input data. If a pair is missing from input, their similarity is treated as zero. -MLlib's PIC implementation takes the following (hyper-)parameters: +`spark.mllib`'s PIC implementation takes the following (hyper-)parameters: * `k`: number of clusters * `maxIterations`: maximum number of power iterations @@ -311,7 +323,7 @@ MLlib's PIC implementation takes the following (hyper-)parameters: **Examples** -In the following, we show code snippets to demonstrate how to use PIC in MLlib. +In the following, we show code snippets to demonstrate how to use PIC in `spark.mllib`.
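Before the language-specific snippets referenced in the hunks below, here is a hedged Scala sketch of the `(srcId, dstId, similarity)` input format and the hyper-parameters listed above; the similarity values are made up, and `sc` is assumed to be an existing `SparkContext`.

{% highlight scala %}
import org.apache.spark.mllib.clustering.PowerIterationClustering

// A minimal sketch, assuming `sc` is an existing SparkContext.
// Each tuple is (srcId, dstId, similarity); similarities must be nonnegative,
// and each undirected pair should appear at most once.
val similarities = sc.parallelize(Seq(
  (0L, 1L, 0.9), (1L, 2L, 0.9), (2L, 3L, 0.1), (3L, 4L, 0.9)))

val model = new PowerIterationClustering()
  .setK(2)               // number of clusters
  .setMaxIterations(10)  // maximum number of power iterations
  .run(similarities)

model.assignments.collect().foreach { a =>
  println(s"vertex ${a.id} -> cluster ${a.cluster}")
}
{% endhighlight %}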
    @@ -324,6 +336,8 @@ Calling `PowerIterationClustering.run` returns a [`PowerIterationClusteringModel`](api/scala/index.html#org.apache.spark.mllib.clustering.PowerIterationClusteringModel), which contains the computed clustering assignments. +Refer to the [`PowerIterationClustering` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.PowerIterationClustering) and [`PowerIterationClusteringModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.PowerIterationClusteringModel) for details on the API. + {% highlight scala %} import org.apache.spark.mllib.clustering.{PowerIterationClustering, PowerIterationClusteringModel} import org.apache.spark.mllib.linalg.Vectors @@ -365,6 +379,8 @@ Calling `PowerIterationClustering.run` returns a [`PowerIterationClusteringModel`](api/java/org/apache/spark/mllib/clustering/PowerIterationClusteringModel.html) which contains the computed clustering assignments. +Refer to the [`PowerIterationClustering` Java docs](api/java/org/apache/spark/mllib/clustering/PowerIterationClustering.html) and [`PowerIterationClusteringModel` Java docs](api/java/org/apache/spark/mllib/clustering/PowerIterationClusteringModel.html) for details on the API. + {% highlight java %} import scala.Tuple2; import scala.Tuple3; @@ -411,6 +427,8 @@ Calling `PowerIterationClustering.run` returns a [`PowerIterationClusteringModel`](api/python/pyspark.mllib.html#pyspark.mllib.clustering.PowerIterationClustering), which contains the computed clustering assignments. +Refer to the [`PowerIterationClustering` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.clustering.PowerIterationClustering) and [`PowerIterationClusteringModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.clustering.PowerIterationClusteringModel) for more details on the API. + {% highlight python %} from __future__ import print_function from pyspark.mllib.clustering import PowerIterationClustering, PowerIterationClusteringModel @@ -438,28 +456,129 @@ sameModel = PowerIterationClusteringModel.load(sc, "myModelPath") is a topic model which infers topics from a collection of text documents. LDA can be thought of as a clustering algorithm as follows: -* Topics correspond to cluster centers, and documents correspond to examples (rows) in a dataset. -* Topics and documents both exist in a feature space, where feature vectors are vectors of word counts. -* Rather than estimating a clustering using a traditional distance, LDA uses a function based - on a statistical model of how text documents are generated. - -LDA takes in a collection of documents as vectors of word counts. -It supports different inference algorithms via `setOptimizer` function. EMLDAOptimizer learns clustering using [expectation-maximization](http://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) -on the likelihood function and yields comprehensive results, while OnlineLDAOptimizer uses iterative mini-batch sampling for [online variational inference](https://www.cs.princeton.edu/~blei/papers/HoffmanBleiBach2010b.pdf) and is generally memory friendly. After fitting on the documents, LDA provides: +* Topics correspond to cluster centers, and documents correspond to +examples (rows) in a dataset. +* Topics and documents both exist in a feature space, where feature +vectors are vectors of word counts (bag of words). +* Rather than estimating a clustering using a traditional distance, LDA +uses a function based on a statistical model of how text documents are +generated. 
-* Topics: Inferred topics, each of which is a probability distribution over terms (words). -* Topic distributions for documents: For each non empty document in the training set, LDA gives a probability distribution over topics. (EM only). Note that for empty documents, we don't create the topic distributions. (EM only) +LDA supports different inference algorithms via `setOptimizer` function. +`EMLDAOptimizer` learns clustering using +[expectation-maximization](http://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) +on the likelihood function and yields comprehensive results, while +`OnlineLDAOptimizer` uses iterative mini-batch sampling for [online +variational +inference](https://www.cs.princeton.edu/~blei/papers/HoffmanBleiBach2010b.pdf) +and is generally memory friendly. -LDA takes the following parameters: +LDA takes in a collection of documents as vectors of word counts and the +following parameters (set using the builder pattern): * `k`: Number of topics (i.e., cluster centers) -* `maxIterations`: Limit on the number of iterations of EM used for learning -* `docConcentration`: Hyperparameter for prior over documents' distributions over topics. Currently must be > 1, where larger values encourage smoother inferred distributions. -* `topicConcentration`: Hyperparameter for prior over topics' distributions over terms (words). Currently must be > 1, where larger values encourage smoother inferred distributions. -* `checkpointInterval`: If using checkpointing (set in the Spark configuration), this parameter specifies the frequency with which checkpoints will be created. If `maxIterations` is large, using checkpointing can help reduce shuffle file sizes on disk and help with failure recovery. - -*Note*: LDA is a new feature with some missing functionality. In particular, it does not yet -support prediction on new documents, and it does not have a Python API. These will be added in the future. +* `optimizer`: Optimizer to use for learning the LDA model, either +`EMLDAOptimizer` or `OnlineLDAOptimizer` +* `docConcentration`: Dirichlet parameter for prior over documents' +distributions over topics. Larger values encourage smoother inferred +distributions. +* `topicConcentration`: Dirichlet parameter for prior over topics' +distributions over terms (words). Larger values encourage smoother +inferred distributions. +* `maxIterations`: Limit on the number of iterations. +* `checkpointInterval`: If using checkpointing (set in the Spark +configuration), this parameter specifies the frequency with which +checkpoints will be created. If `maxIterations` is large, using +checkpointing can help reduce shuffle file sizes on disk and help with +failure recovery. + + +All of `spark.mllib`'s LDA models support: + +* `describeTopics`: Returns topics as arrays of most important terms and +term weights +* `topicsMatrix`: Returns a `vocabSize` by `k` matrix where each column +is a topic + +*Note*: LDA is still an experimental feature under active development. +As a result, certain features are only available in one of the two +optimizers / models generated by the optimizer. Currently, a distributed +model can be converted into a local model, but not vice-versa. + +The following discussion will describe each optimizer/model pair +separately. + +**Expectation Maximization** + +Implemented in +[`EMLDAOptimizer`](api/scala/index.html#org.apache.spark.mllib.clustering.EMLDAOptimizer) +and +[`DistributedLDAModel`](api/scala/index.html#org.apache.spark.mllib.clustering.DistributedLDAModel). 
+ +For the parameters provided to `LDA`: + +* `docConcentration`: Only symmetric priors are supported, so all values +in the provided `k`-dimensional vector must be identical. All values +must also be $> 1.0$. Providing `Vector(-1)` results in default behavior +(uniform `k` dimensional vector with value $(50 / k) + 1$) +* `topicConcentration`: Only symmetric priors are supported. Values must be +$> 1.0$. Providing `-1` results in defaulting to a value of $0.1 + 1$. +* `maxIterations`: The maximum number of EM iterations. + +*Note*: It is important to do enough iterations. In early iterations, EM often has useless topics, +but those topics improve dramatically after more iterations. Using at least 20 and possibly +50-100 iterations is often reasonable, depending on your dataset. + +`EMLDAOptimizer` produces a `DistributedLDAModel`, which stores not only +the inferred topics but also the full training corpus and topic +distributions for each document in the training corpus. A +`DistributedLDAModel` supports: + + * `topTopicsPerDocument`: The top topics and their weights for + each document in the training corpus + * `topDocumentsPerTopic`: The top documents for each topic and + the corresponding weight of the topic in the documents. + * `logPrior`: log probability of the estimated topics and + document-topic distributions given the hyperparameters + `docConcentration` and `topicConcentration` + * `logLikelihood`: log likelihood of the training corpus, given the + inferred topics and document-topic distributions + +**Online Variational Bayes** + +Implemented in +[`OnlineLDAOptimizer`](api/scala/index.html#org.apache.spark.mllib.clustering.OnlineLDAOptimizer) +and +[`LocalLDAModel`](api/scala/index.html#org.apache.spark.mllib.clustering.LocalLDAModel). + +For the parameters provided to `LDA`: + +* `docConcentration`: Asymmetric priors can be used by passing in a +vector with values equal to the Dirichlet parameter in each of the `k` +dimensions. Values should be $>= 0$. Providing `Vector(-1)` results in +default behavior (uniform `k` dimensional vector with value $(1.0 / k)$) +* `topicConcentration`: Only symmetric priors are supported. Values must be +$>= 0$. Providing `-1` results in defaulting to a value of $(1.0 / k)$. +* `maxIterations`: Maximum number of minibatches to submit. + +In addition, `OnlineLDAOptimizer` accepts the following parameters: + +* `miniBatchFraction`: Fraction of corpus sampled and used at each +iteration +* `optimizeDocConcentration`: If set to true, performs maximum-likelihood +estimation of the hyperparameter `docConcentration` (aka `alpha`) +after each minibatch and sets the optimized `docConcentration` in the +returned `LocalLDAModel` +* `tau0` and `kappa`: Used for learning-rate decay, which is computed by +$(\tau_0 + iter)^{-\kappa}$ where $iter$ is the current number of iterations. + +`OnlineLDAOptimizer` produces a `LocalLDAModel`, which only stores the +inferred topics. A `LocalLDAModel` supports: + +* `logLikelihood(documents)`: Calculates a lower bound on the provided +`documents` given the inferred topics. +* `logPerplexity(documents)`: Calculates an upper bound on the +perplexity of the provided `documents` given the inferred topics. **Examples** @@ -470,6 +589,7 @@ to the algorithm. We then output the topics, represented as probability distribu
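The end-to-end examples in the tabs below use the default EM optimizer; as a hedged sketch of the optimizer choice described above, the fragment below switches to the online optimizer (the setter names follow the parameter list above, and `corpus` is assumed to be an `RDD[(Long, Vector)]` of (document id, word-count vector) pairs as prepared in those examples).

{% highlight scala %}
import org.apache.spark.mllib.clustering.{LDA, LocalLDAModel, OnlineLDAOptimizer}

// A minimal sketch, assuming `corpus: RDD[(Long, Vector)]` as in the examples below.
val onlineLDA = new LDA()
  .setK(3)
  .setOptimizer(new OnlineLDAOptimizer()
    .setMiniBatchFraction(0.05)          // fraction of the corpus sampled per minibatch
    .setOptimizeDocConcentration(true))  // also learn docConcentration (alpha)

// With the online optimizer the returned model is a LocalLDAModel,
// which stores only the inferred topics.
val localModel = onlineLDA.run(corpus).asInstanceOf[LocalLDAModel]
println(s"Upper bound on perplexity: ${localModel.logPerplexity(corpus)}")
{% endhighlight %}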
    +Refer to the [`LDA` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.LDA) and [`DistributedLDAModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.DistributedLDAModel) for details on the API. {% highlight scala %} import org.apache.spark.mllib.clustering.{LDA, DistributedLDAModel} @@ -501,6 +621,8 @@ val sameModel = DistributedLDAModel.load(sc, "myLDAModel")
    +Refer to the [`LDA` Java docs](api/java/org/apache/spark/mllib/clustering/LDA.html) and [`DistributedLDAModel` Java docs](api/java/org/apache/spark/mllib/clustering/DistributedLDAModel.html) for details on the API. + {% highlight java %} import scala.Tuple2; @@ -564,12 +686,77 @@ public class JavaLDAExample { {% endhighlight %}
    +
    +Refer to the [`LDA` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.clustering.LDA) and [`LDAModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.clustering.LDAModel) for more details on the API. + +{% highlight python %} +from pyspark.mllib.clustering import LDA, LDAModel +from pyspark.mllib.linalg import Vectors + +# Load and parse the data +data = sc.textFile("data/mllib/sample_lda_data.txt") +parsedData = data.map(lambda line: Vectors.dense([float(x) for x in line.strip().split(' ')])) +# Index documents with unique IDs +corpus = parsedData.zipWithIndex().map(lambda x: [x[1], x[0]]).cache() + +# Cluster the documents into three topics using LDA +ldaModel = LDA.train(corpus, k=3) + +# Output topics. Each is a distribution over words (matching word count vectors) +print("Learned topics (as distributions over vocab of " + str(ldaModel.vocabSize()) + " words):") +topics = ldaModel.topicsMatrix() +for topic in range(3): + print("Topic " + str(topic) + ":") + for word in range(0, ldaModel.vocabSize()): + print(" " + str(topics[word][topic])) + +# Save and load model +ldaModel.save(sc, "myModelPath") +sameModel = LDAModel.load(sc, "myModelPath") +{% endhighlight %} +
    + +
    + +## Bisecting k-means + +Bisecting k-means can often be much faster than regular k-means, but it will generally produce a different clustering. + +Bisecting k-means is a kind of [hierarchical clustering](https://en.wikipedia.org/wiki/Hierarchical_clustering). +Hierarchical clustering is one of the most commonly used methods of cluster analysis, and it seeks to build a hierarchy of clusters. +Strategies for hierarchical clustering generally fall into two types: + +- Agglomerative: This is a "bottom up" approach: each observation starts in its own cluster, and pairs of clusters are merged as one moves up the hierarchy. +- Divisive: This is a "top down" approach: all observations start in one cluster, and splits are performed recursively as one moves down the hierarchy. + +The bisecting k-means algorithm is a kind of divisive algorithm. +The implementation in `spark.mllib` has the following parameters: + +* *k*: the desired number of leaf clusters (default: 4). The actual number could be smaller if there are no divisible leaf clusters. +* *maxIterations*: the maximum number of k-means iterations used to split clusters (default: 20) +* *minDivisibleClusterSize*: the minimum number of points (if >= 1.0) or the minimum proportion of points (if < 1.0) of a divisible cluster (default: 1) +* *seed*: a random seed (default: hash value of the class name) + +**Examples** + +
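The runnable examples are pulled in via the `include_example` tags below; as a quick hedged sketch of the parameters listed above (the data points are made up and `sc` is assumed to be an existing `SparkContext`):

{% highlight scala %}
import org.apache.spark.mllib.clustering.BisectingKMeans
import org.apache.spark.mllib.linalg.Vectors

// A minimal sketch, assuming `sc` is an existing SparkContext; the points are made up.
val data = sc.parallelize(Seq(
  Vectors.dense(0.1, 0.1), Vectors.dense(0.3, 0.3),
  Vectors.dense(10.1, 10.1), Vectors.dense(10.3, 10.3),
  Vectors.dense(20.1, 20.1), Vectors.dense(20.3, 20.3)))

val bkm = new BisectingKMeans()
  .setK(3)               // desired number of leaf clusters
  .setMaxIterations(20)  // k-means iterations per split
val model = bkm.run(data)

println(s"Compute cost: ${model.computeCost(data)}")
model.clusterCenters.zipWithIndex.foreach { case (center, idx) =>
  println(s"Cluster center $idx: $center")
}
{% endhighlight %}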
    +
    +Refer to the [`BisectingKMeans` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.BisectingKMeans) and [`BisectingKMeansModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.BisectingKMeansModel) for details on the API. + +{% include_example scala/org/apache/spark/examples/mllib/BisectingKMeansExample.scala %} +
    + +
    +Refer to the [`BisectingKMeans` Java docs](api/java/org/apache/spark/mllib/clustering/BisectingKMeans.html) and [`BisectingKMeansModel` Java docs](api/java/org/apache/spark/mllib/clustering/BisectingKMeansModel.html) for details on the API. + +{% include_example java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java %} +
    ## Streaming k-means When data arrive in a stream, we may want to estimate clusters dynamically, -updating them as new data arrive. MLlib provides support for streaming k-means clustering, +updating them as new data arrive. `spark.mllib` provides support for streaming k-means clustering, with parameters to control the decay (or "forgetfulness") of the estimates. The algorithm uses a generalization of the mini-batch k-means update rule. For each batch of data, we assign all points to their nearest cluster, compute new cluster centers, then update each cluster using: @@ -601,6 +788,7 @@ This example shows how to estimate clusters on streaming data.
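Before the full streaming example below, here is a hedged Scala sketch of the decay-controlled API described above; `trainingData` and `testData` are assumed to be DStreams built from an existing `StreamingContext`, and the parameter values are illustrative.

{% highlight scala %}
import org.apache.spark.mllib.clustering.StreamingKMeans

// A minimal sketch, assuming `trainingData: DStream[Vector]` and
// `testData: DStream[(Long, Vector)]` built from an existing StreamingContext.
val model = new StreamingKMeans()
  .setK(3)                   // number of clusters
  .setDecayFactor(1.0)       // 1.0 keeps all history; smaller values forget old batches faster
  .setRandomCenters(2, 0.0)  // dimension of the data and initial weight of the random centers

model.trainOn(trainingData)              // update the cluster centers on each new batch
model.predictOnValues(testData).print()  // emit (key, clusterIndex) pairs
{% endhighlight %}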
    +Refer to the [`StreamingKMeans` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.StreamingKMeans) for details on the API. First we import the neccessary classes. @@ -651,6 +839,8 @@ ssc.awaitTermination()
    +Refer to the [`StreamingKMeans` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.clustering.StreamingKMeans) for more details on the API. + First we import the neccessary classes. {% highlight python %} diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index eedc23424ad5..1ebb4654aef1 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -1,7 +1,7 @@ --- layout: global -title: Collaborative Filtering - MLlib -displayTitle: MLlib - Collaborative Filtering +title: Collaborative Filtering - spark.mllib +displayTitle: Collaborative Filtering - spark.mllib --- * Table of contents @@ -11,12 +11,12 @@ displayTitle: MLlib - Collaborative Filtering [Collaborative filtering](http://en.wikipedia.org/wiki/Recommender_system#Collaborative_filtering) is commonly used for recommender systems. These techniques aim to fill in the -missing entries of a user-item association matrix. MLlib currently supports +missing entries of a user-item association matrix. `spark.mllib` currently supports model-based collaborative filtering, in which users and products are described by a small set of latent factors that can be used to predict missing entries. -MLlib uses the [alternating least squares +`spark.mllib` uses the [alternating least squares (ALS)](http://dl.acm.org/citation.cfm?id=1608614) -algorithm to learn these latent factors. The implementation in MLlib has the +algorithm to learn these latent factors. The implementation in `spark.mllib` has the following parameters: * *numBlocks* is the number of blocks used to parallelize computation (set to -1 to auto-configure). @@ -34,7 +34,7 @@ The standard approach to matrix factorization based collaborative filtering trea the entries in the user-item matrix as *explicit* preferences given by the user to the item. It is common in many real-world use cases to only have access to *implicit feedback* (e.g. views, -clicks, purchases, likes, shares etc.). The approach used in MLlib to deal with such data is taken +clicks, purchases, likes, shares etc.). The approach used in `spark.mllib` to deal with such data is taken from [Collaborative Filtering for Implicit Feedback Datasets](http://dx.doi.org/10.1109/ICDM.2008.22). Essentially instead of trying to model the matrix of ratings directly, this approach treats the data @@ -64,43 +64,9 @@ We use the default [ALS.train()](api/scala/index.html#org.apache.spark.mllib.rec method which assumes ratings are explicit. We evaluate the recommendation model by measuring the Mean Squared Error of rating prediction. 
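The examples referenced below all train on explicit ratings; to make the implicit-feedback variant described above concrete, here is a hedged Scala sketch built around `ALS.trainImplicit`. The file path reuses the explicit example's test data purely for illustration, and the `rank`, iteration, `lambda`, and `alpha` values are not tuned.

{% highlight scala %}
import org.apache.spark.mllib.recommendation.{ALS, Rating}

// A minimal sketch, assuming `sc` and a "user,product,strength" text file;
// all parameter values below are illustrative only.
val implicitPrefs = sc.textFile("data/mllib/als/test.data")
  .map(_.split(',') match { case Array(user, product, strength) =>
    Rating(user.toInt, product.toInt, strength.toDouble)
  })

val rank = 10
val numIterations = 10
val lambda = 0.01
val alpha = 1.0  // confidence scaling applied to the implicit feedback
val model = ALS.trainImplicit(implicitPrefs, rank, numIterations, lambda, alpha)

// Predictions are preference scores rather than explicit ratings.
println(model.predict(1, 1))
{% endhighlight %}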
-{% highlight scala %} -import org.apache.spark.mllib.recommendation.ALS -import org.apache.spark.mllib.recommendation.MatrixFactorizationModel -import org.apache.spark.mllib.recommendation.Rating - -// Load and parse the data -val data = sc.textFile("data/mllib/als/test.data") -val ratings = data.map(_.split(',') match { case Array(user, item, rate) => - Rating(user.toInt, item.toInt, rate.toDouble) - }) - -// Build the recommendation model using ALS -val rank = 10 -val numIterations = 10 -val model = ALS.train(ratings, rank, numIterations, 0.01) - -// Evaluate the model on rating data -val usersProducts = ratings.map { case Rating(user, product, rate) => - (user, product) -} -val predictions = - model.predict(usersProducts).map { case Rating(user, product, rate) => - ((user, product), rate) - } -val ratesAndPreds = ratings.map { case Rating(user, product, rate) => - ((user, product), rate) -}.join(predictions) -val MSE = ratesAndPreds.map { case ((user, product), (r1, r2)) => - val err = (r1 - r2) - err * err -}.mean() -println("Mean Squared Error = " + MSE) - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = MatrixFactorizationModel.load(sc, "myModelPath") -{% endhighlight %} +Refer to the [`ALS` Scala docs](api/scala/index.html#org.apache.spark.mllib.recommendation.ALS) for details on the API. + +{% include_example scala/org/apache/spark/examples/mllib/RecommendationExample.scala %} If the rating matrix is derived from another source of information (e.g., it is inferred from other signals), you can use the `trainImplicit` method to get better results. @@ -117,83 +83,11 @@ All of MLlib's methods use Java-friendly types, so you can import and call them way you do in Scala. The only caveat is that the methods take Scala RDD objects, while the Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a Scala one by calling `.rdd()` on your `JavaRDD` object. 
A self-contained application example -that is equivalent to the provided example in Scala is given bellow: - -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.recommendation.ALS; -import org.apache.spark.mllib.recommendation.MatrixFactorizationModel; -import org.apache.spark.mllib.recommendation.Rating; -import org.apache.spark.SparkConf; - -public class CollaborativeFiltering { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("Collaborative Filtering Example"); - JavaSparkContext sc = new JavaSparkContext(conf); - - // Load and parse the data - String path = "data/mllib/als/test.data"; - JavaRDD data = sc.textFile(path); - JavaRDD ratings = data.map( - new Function() { - public Rating call(String s) { - String[] sarray = s.split(","); - return new Rating(Integer.parseInt(sarray[0]), Integer.parseInt(sarray[1]), - Double.parseDouble(sarray[2])); - } - } - ); - - // Build the recommendation model using ALS - int rank = 10; - int numIterations = 10; - MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), rank, numIterations, 0.01); - - // Evaluate the model on rating data - JavaRDD> userProducts = ratings.map( - new Function>() { - public Tuple2 call(Rating r) { - return new Tuple2(r.user(), r.product()); - } - } - ); - JavaPairRDD, Double> predictions = JavaPairRDD.fromJavaRDD( - model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map( - new Function, Double>>() { - public Tuple2, Double> call(Rating r){ - return new Tuple2, Double>( - new Tuple2(r.user(), r.product()), r.rating()); - } - } - )); - JavaRDD> ratesAndPreds = - JavaPairRDD.fromJavaRDD(ratings.map( - new Function, Double>>() { - public Tuple2, Double> call(Rating r){ - return new Tuple2, Double>( - new Tuple2(r.user(), r.product()), r.rating()); - } - } - )).join(predictions).values(); - double MSE = JavaDoubleRDD.fromRDD(ratesAndPreds.map( - new Function, Object>() { - public Object call(Tuple2 pair) { - Double err = pair._1() - pair._2(); - return err * err; - } - } - ).rdd()).mean(); - System.out.println("Mean Squared Error = " + MSE); - - // Save and load model - model.save(sc.sc(), "myModelPath"); - MatrixFactorizationModel sameModel = MatrixFactorizationModel.load(sc.sc(), "myModelPath"); - } -} -{% endhighlight %} +that is equivalent to the provided example in Scala is given below: + +Refer to the [`ALS` Java docs](api/java/org/apache/spark/mllib/recommendation/ALS.html) for details on the API. + +{% include_example java/org/apache/spark/examples/mllib/JavaRecommendationExample.java %}
    @@ -201,29 +95,9 @@ In the following example we load rating data. Each row consists of a user, a pro We use the default ALS.train() method which assumes ratings are explicit. We evaluate the recommendation by measuring the Mean Squared Error of rating prediction. -{% highlight python %} -from pyspark.mllib.recommendation import ALS, MatrixFactorizationModel, Rating - -# Load and parse the data -data = sc.textFile("data/mllib/als/test.data") -ratings = data.map(lambda l: l.split(',')).map(lambda l: Rating(int(l[0]), int(l[1]), float(l[2]))) - -# Build the recommendation model using Alternating Least Squares -rank = 10 -numIterations = 10 -model = ALS.train(ratings, rank, numIterations) - -# Evaluate the model on training data -testdata = ratings.map(lambda p: (p[0], p[1])) -predictions = model.predictAll(testdata).map(lambda r: ((r[0], r[1]), r[2])) -ratesAndPreds = ratings.map(lambda r: ((r[0], r[1]), r[2])).join(predictions) -MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).mean() -print("Mean Squared Error = " + str(MSE)) - -# Save and load model -model.save(sc, "myModelPath") -sameModel = MatrixFactorizationModel.load(sc, "myModelPath") -{% endhighlight %} +Refer to the [`ALS` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.recommendation.ALS) for more details on the API. + +{% include_example python/mllib/recommendation_example.py %} If the rating matrix is derived from other source of information (i.e., it is inferred from other signals), you can use the trainImplicit method to get better results. @@ -245,4 +119,4 @@ a dependency. ## Tutorial The [training exercises](https://databricks-training.s3.amazonaws.com/index.html) from the Spark Summit 2014 include a hands-on tutorial for -[personalized movie recommendation with MLlib](https://databricks-training.s3.amazonaws.com/movie-recommendation-with-mllib.html). +[personalized movie recommendation with `spark.mllib`](https://databricks-training.s3.amazonaws.com/movie-recommendation-with-mllib.html). diff --git a/docs/mllib-data-types.md b/docs/mllib-data-types.md index 3aa040046fca..363dc7c13b30 100644 --- a/docs/mllib-data-types.md +++ b/docs/mllib-data-types.md @@ -1,7 +1,7 @@ --- layout: global title: Data Types - MLlib -displayTitle: MLlib - Data Types +displayTitle: Data Types - MLlib --- * Table of contents @@ -33,6 +33,8 @@ implementations: [`DenseVector`](api/scala/index.html#org.apache.spark.mllib.lin using the factory methods implemented in [`Vectors`](api/scala/index.html#org.apache.spark.mllib.linalg.Vectors$) to create local vectors. +Refer to the [`Vector` Scala docs](api/scala/index.html#org.apache.spark.mllib.linalg.Vector) and [`Vectors` Scala docs](api/scala/index.html#org.apache.spark.mllib.linalg.Vectors) for details on the API. + {% highlight scala %} import org.apache.spark.mllib.linalg.{Vector, Vectors} @@ -59,6 +61,8 @@ implementations: [`DenseVector`](api/java/org/apache/spark/mllib/linalg/DenseVec using the factory methods implemented in [`Vectors`](api/java/org/apache/spark/mllib/linalg/Vectors.html) to create local vectors. +Refer to the [`Vector` Java docs](api/java/org/apache/spark/mllib/linalg/Vector.html) and [`Vectors` Java docs](api/java/org/apache/spark/mllib/linalg/Vectors.html) for details on the API. 
+ {% highlight java %} import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; @@ -86,6 +90,8 @@ and the following as sparse vectors: We recommend using NumPy arrays over lists for efficiency, and using the factory methods implemented in [`Vectors`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.Vectors) to create sparse vectors. +Refer to the [`Vectors` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.linalg.Vectors) for more details on the API. + {% highlight python %} import numpy as np import scipy.sparse as sps @@ -119,6 +125,8 @@ For multiclass classification, labels should be class indices starting from zero A labeled point is represented by the case class [`LabeledPoint`](api/scala/index.html#org.apache.spark.mllib.regression.LabeledPoint). +Refer to the [`LabeledPoint` Scala docs](api/scala/index.html#org.apache.spark.mllib.regression.LabeledPoint) for details on the API. + {% highlight scala %} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint @@ -136,6 +144,8 @@ val neg = LabeledPoint(0.0, Vectors.sparse(3, Array(0, 2), Array(1.0, 3.0))) A labeled point is represented by [`LabeledPoint`](api/java/org/apache/spark/mllib/regression/LabeledPoint.html). +Refer to the [`LabeledPoint` Java docs](api/java/org/apache/spark/mllib/regression/LabeledPoint.html) for details on the API. + {% highlight java %} import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; @@ -144,7 +154,7 @@ import org.apache.spark.mllib.regression.LabeledPoint; LabeledPoint pos = new LabeledPoint(1.0, Vectors.dense(1.0, 0.0, 3.0)); // Create a labeled point with a negative label and a sparse feature vector. -LabeledPoint neg = new LabeledPoint(1.0, Vectors.sparse(3, new int[] {0, 2}, new double[] {1.0, 3.0})); +LabeledPoint neg = new LabeledPoint(0.0, Vectors.sparse(3, new int[] {0, 2}, new double[] {1.0, 3.0})); {% endhighlight %}
    @@ -153,6 +163,8 @@ LabeledPoint neg = new LabeledPoint(1.0, Vectors.sparse(3, new int[] {0, 2}, new A labeled point is represented by [`LabeledPoint`](api/python/pyspark.mllib.html#pyspark.mllib.regression.LabeledPoint). +Refer to the [`LabeledPoint` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.regression.LabeledPoint) for more details on the API. + {% highlight python %} from pyspark.mllib.linalg import SparseVector from pyspark.mllib.regression import LabeledPoint @@ -187,6 +199,8 @@ After loading, the feature indices are converted to zero-based. [`MLUtils.loadLibSVMFile`](api/scala/index.html#org.apache.spark.mllib.util.MLUtils$) reads training examples stored in LIBSVM format. +Refer to the [`MLUtils` Scala docs](api/scala/index.html#org.apache.spark.mllib.util.MLUtils) for details on the API. + {% highlight scala %} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLUtils @@ -200,6 +214,8 @@ val examples: RDD[LabeledPoint] = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_ [`MLUtils.loadLibSVMFile`](api/java/org/apache/spark/mllib/util/MLUtils.html) reads training examples stored in LIBSVM format. +Refer to the [`MLUtils` Java docs](api/java/org/apache/spark/mllib/util/MLUtils.html) for details on the API. + {% highlight java %} import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.util.MLUtils; @@ -214,6 +230,8 @@ JavaRDD examples = [`MLUtils.loadLibSVMFile`](api/python/pyspark.mllib.html#pyspark.mllib.util.MLUtils) reads training examples stored in LIBSVM format. +Refer to the [`MLUtils` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.util.MLUtils) for more details on the API. + {% highlight python %} from pyspark.mllib.util import MLUtils @@ -246,6 +264,8 @@ We recommend using the factory methods implemented in [`Matrices`](api/scala/index.html#org.apache.spark.mllib.linalg.Matrices$) to create local matrices. Remember, local matrices in MLlib are stored in column-major order. +Refer to the [`Matrix` Scala docs](api/scala/index.html#org.apache.spark.mllib.linalg.Matrix) and [`Matrices` Scala docs](api/scala/index.html#org.apache.spark.mllib.linalg.Matrices) for details on the API. + {% highlight scala %} import org.apache.spark.mllib.linalg.{Matrix, Matrices} @@ -267,6 +287,8 @@ We recommend using the factory methods implemented in [`Matrices`](api/java/org/apache/spark/mllib/linalg/Matrices.html) to create local matrices. Remember, local matrices in MLlib are stored in column-major order. +Refer to the [`Matrix` Java docs](api/java/org/apache/spark/mllib/linalg/Matrix.html) and [`Matrices` Java docs](api/java/org/apache/spark/mllib/linalg/Matrices.html) for details on the API. + {% highlight java %} import org.apache.spark.mllib.linalg.Matrix; import org.apache.spark.mllib.linalg.Matrices; @@ -289,6 +311,8 @@ We recommend using the factory methods implemented in [`Matrices`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.Matrices) to create local matrices. Remember, local matrices in MLlib are stored in column-major order. +Refer to the [`Matrix` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.linalg.Matrix) and [`Matrices` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.linalg.Matrices) for more details on the API. + {% highlight python %} import org.apache.spark.mllib.linalg.{Matrix, Matrices} @@ -337,7 +361,11 @@ limited by the integer range but it should be much smaller in practice.
    A [`RowMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.RowMatrix) can be -created from an `RDD[Vector]` instance. Then we can compute its column summary statistics. +created from an `RDD[Vector]` instance. Then we can compute its column summary statistics and decompositions. +[QR decomposition](https://en.wikipedia.org/wiki/QR_decomposition) is of the form A = QR where Q is an orthogonal matrix and R is an upper triangular matrix. +For [singular value decomposition (SVD)](https://en.wikipedia.org/wiki/Singular_value_decomposition) and [principal component analysis (PCA)](https://en.wikipedia.org/wiki/Principal_component_analysis), please refer to [Dimensionality reduction](mllib-dimensionality-reduction.html). + +Refer to the [`RowMatrix` Scala docs](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.RowMatrix) for details on the API. {% highlight scala %} import org.apache.spark.mllib.linalg.Vector @@ -350,6 +378,9 @@ val mat: RowMatrix = new RowMatrix(rows) // Get its size. val m = mat.numRows() val n = mat.numCols() + +// QR decomposition +val qrResult = mat.tallSkinnyQR(true) {% endhighlight %}
    @@ -358,6 +389,8 @@ val n = mat.numCols() A [`RowMatrix`](api/java/org/apache/spark/mllib/linalg/distributed/RowMatrix.html) can be created from a `JavaRDD` instance. Then we can compute its column summary statistics. +Refer to the [`RowMatrix` Java docs](api/java/org/apache/spark/mllib/linalg/distributed/RowMatrix.html) for details on the API. + {% highlight java %} import org.apache.spark.api.java.JavaRDD; import org.apache.spark.mllib.linalg.Vector; @@ -370,14 +403,44 @@ RowMatrix mat = new RowMatrix(rows.rdd()); // Get its size. long m = mat.numRows(); long n = mat.numCols(); + +// QR decomposition +QRDecomposition result = mat.tallSkinnyQR(true); {% endhighlight %}
    + +
    + +A [`RowMatrix`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.RowMatrix) can be +created from an `RDD` of vectors. + +Refer to the [`RowMatrix` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.RowMatrix) for more details on the API. + +{% highlight python %} +from pyspark.mllib.linalg.distributed import RowMatrix + +# Create an RDD of vectors. +rows = sc.parallelize([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]) + +# Create a RowMatrix from an RDD of vectors. +mat = RowMatrix(rows) + +# Get its size. +m = mat.numRows() # 4 +n = mat.numCols() # 3 + +# Get the rows as an RDD of vectors again. +rowsRDD = mat.rows +{% endhighlight %} +
    +
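As an aside (not part of the example above), the Python `RowMatrix` example only reads back the rows and dimensions; if you also want the column summary statistics mentioned for the Scala API, one option is to run `Statistics.colStats` on the underlying `RDD` of vectors. A minimal sketch, assuming `sc` is an existing `SparkContext` and reusing the same rows as above:

{% highlight python %}
from pyspark.mllib.stat import Statistics

# The same rows used to build the RowMatrix above.
rows = sc.parallelize([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])

# Column-wise summary statistics of the rows.
summary = Statistics.colStats(rows)
print(summary.mean())         # vector of column means
print(summary.variance())     # vector of column variances
print(summary.numNonzeros())  # number of nonzeros per column
{% endhighlight %}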
    ### IndexedRowMatrix An `IndexedRowMatrix` is similar to a `RowMatrix` but with meaningful row indices. It is backed by -an RDD of indexed rows, so that each row is represented by its index (long-typed) and a local vector. +an RDD of indexed rows, so that each row is represented by its index (long-typed) and a local +vector.
    @@ -389,6 +452,8 @@ can be created from an `RDD[IndexedRow]` instance, where wrapper over `(Long, Vector)`. An `IndexedRowMatrix` can be converted to a `RowMatrix` by dropping its row indices. +Refer to the [`IndexedRowMatrix` Scala docs](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix) for details on the API. + {% highlight scala %} import org.apache.spark.mllib.linalg.distributed.{IndexedRow, IndexedRowMatrix, RowMatrix} @@ -414,6 +479,8 @@ can be created from an `JavaRDD` instance, where wrapper over `(long, Vector)`. An `IndexedRowMatrix` can be converted to a `RowMatrix` by dropping its row indices. +Refer to the [`IndexedRowMatrix` Java docs](api/java/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.html) for details on the API. + {% highlight java %} import org.apache.spark.api.java.JavaRDD; import org.apache.spark.mllib.linalg.distributed.IndexedRow; @@ -431,7 +498,53 @@ long n = mat.numCols(); // Drop its row indices. RowMatrix rowMat = mat.toRowMatrix(); {% endhighlight %} -
    +
    + +
    + +An [`IndexedRowMatrix`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.IndexedRowMatrix) +can be created from an `RDD` of `IndexedRow`s, where +[`IndexedRow`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.IndexedRow) is a +wrapper over `(long, vector)`. An `IndexedRowMatrix` can be converted to a `RowMatrix` by dropping +its row indices. + +Refer to the [`IndexedRowMatrix` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.IndexedRowMatrix) for more details on the API. + +{% highlight python %} +from pyspark.mllib.linalg.distributed import IndexedRow, IndexedRowMatrix + +# Create an RDD of indexed rows. +# - This can be done explicitly with the IndexedRow class: +indexedRows = sc.parallelize([IndexedRow(0, [1, 2, 3]), + IndexedRow(1, [4, 5, 6]), + IndexedRow(2, [7, 8, 9]), + IndexedRow(3, [10, 11, 12])]) +# - or by using (long, vector) tuples: +indexedRows = sc.parallelize([(0, [1, 2, 3]), (1, [4, 5, 6]), + (2, [7, 8, 9]), (3, [10, 11, 12])]) + +# Create an IndexedRowMatrix from an RDD of IndexedRows. +mat = IndexedRowMatrix(indexedRows) + +# Get its size. +m = mat.numRows() # 4 +n = mat.numCols() # 3 + +# Get the rows as an RDD of IndexedRows. +rowsRDD = mat.rows + +# Convert to a RowMatrix by dropping the row indices. +rowMat = mat.toRowMatrix() + +# Convert to a CoordinateMatrix. +coordinateMat = mat.toCoordinateMatrix() + +# Convert to a BlockMatrix. +blockMat = mat.toBlockMatrix() +{% endhighlight %} +
    + +
    ### CoordinateMatrix @@ -451,6 +564,8 @@ wrapper over `(Long, Long, Double)`. A `CoordinateMatrix` can be converted to a with sparse rows by calling `toIndexedRowMatrix`. Other computations for `CoordinateMatrix` are not currently supported. +Refer to the [`CoordinateMatrix` Scala docs](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.CoordinateMatrix) for details on the API. + {% highlight scala %} import org.apache.spark.mllib.linalg.distributed.{CoordinateMatrix, MatrixEntry} @@ -477,6 +592,8 @@ wrapper over `(long, long, double)`. A `CoordinateMatrix` can be converted to a with sparse rows by calling `toIndexedRowMatrix`. Other computations for `CoordinateMatrix` are not currently supported. +Refer to the [`CoordinateMatrix` Java docs](api/java/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.html) for details on the API. + {% highlight java %} import org.apache.spark.api.java.JavaRDD; import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix; @@ -495,6 +612,47 @@ long n = mat.numCols(); IndexedRowMatrix indexedRowMatrix = mat.toIndexedRowMatrix(); {% endhighlight %}
    + +
    + +A [`CoordinateMatrix`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.CoordinateMatrix) +can be created from an `RDD` of `MatrixEntry` entries, where +[`MatrixEntry`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.MatrixEntry) is a +wrapper over `(long, long, float)`. A `CoordinateMatrix` can be converted to a `RowMatrix` by +calling `toRowMatrix`, or to an `IndexedRowMatrix` with sparse rows by calling `toIndexedRowMatrix`. + +Refer to the [`CoordinateMatrix` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.CoordinateMatrix) for more details on the API. + +{% highlight python %} +from pyspark.mllib.linalg.distributed import CoordinateMatrix, MatrixEntry + +# Create an RDD of coordinate entries. +# - This can be done explicitly with the MatrixEntry class: +entries = sc.parallelize([MatrixEntry(0, 0, 1.2), MatrixEntry(1, 0, 2.1), MatrixEntry(6, 1, 3.7)]) +# - or using (long, long, float) tuples: +entries = sc.parallelize([(0, 0, 1.2), (1, 0, 2.1), (2, 1, 3.7)]) + +# Create an CoordinateMatrix from an RDD of MatrixEntries. +mat = CoordinateMatrix(entries) + +# Get its size. +m = mat.numRows() # 3 +n = mat.numCols() # 2 + +# Get the entries as an RDD of MatrixEntries. +entriesRDD = mat.entries + +# Convert to a RowMatrix. +rowMat = mat.toRowMatrix() + +# Convert to an IndexedRowMatrix. +indexedRowMat = mat.toIndexedRowMatrix() + +# Convert to a BlockMatrix. +blockMat = mat.toBlockMatrix() +{% endhighlight %} +
    +
    ### BlockMatrix @@ -514,6 +672,8 @@ most easily created from an `IndexedRowMatrix` or `CoordinateMatrix` by calling `toBlockMatrix` creates blocks of size 1024 x 1024 by default. Users may change the block size by supplying the values through `toBlockMatrix(rowsPerBlock, colsPerBlock)`. +Refer to the [`BlockMatrix` Scala docs](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.BlockMatrix) for details on the API. + {% highlight scala %} import org.apache.spark.mllib.linalg.distributed.{BlockMatrix, CoordinateMatrix, MatrixEntry} @@ -539,6 +699,8 @@ most easily created from an `IndexedRowMatrix` or `CoordinateMatrix` by calling `toBlockMatrix` creates blocks of size 1024 x 1024 by default. Users may change the block size by supplying the values through `toBlockMatrix(rowsPerBlock, colsPerBlock)`. +Refer to the [`BlockMatrix` Java docs](api/java/org/apache/spark/mllib/linalg/distributed/BlockMatrix.html) for details on the API. + {% highlight java %} import org.apache.spark.api.java.JavaRDD; import org.apache.spark.mllib.linalg.distributed.BlockMatrix; @@ -559,4 +721,41 @@ matA.validate(); BlockMatrix ata = matA.transpose().multiply(matA); {% endhighlight %}
    + +
    + +A [`BlockMatrix`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.BlockMatrix) +can be created from an `RDD` of sub-matrix blocks, where a sub-matrix block is a +`((blockRowIndex, blockColIndex), sub-matrix)` tuple. + +Refer to the [`BlockMatrix` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.BlockMatrix) for more details on the API. + +{% highlight python %} +from pyspark.mllib.linalg import Matrices +from pyspark.mllib.linalg.distributed import BlockMatrix + +# Create an RDD of sub-matrix blocks. +blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + +# Create a BlockMatrix from an RDD of sub-matrix blocks. +mat = BlockMatrix(blocks, 3, 2) + +# Get its size. +m = mat.numRows() # 6 +n = mat.numCols() # 2 + +# Get the blocks as an RDD of sub-matrix blocks. +blocksRDD = mat.blocks + +# Convert to a LocalMatrix. +localMat = mat.toLocalMatrix() + +# Convert to an IndexedRowMatrix. +indexedRowMat = mat.toIndexedRowMatrix() + +# Convert to a CoordinateMatrix. +coordinateMat = mat.toCoordinateMatrix() +{% endhighlight %} +
    diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md index c1d0f8a6b1cd..a8612b6c84fe 100644 --- a/docs/mllib-decision-tree.md +++ b/docs/mllib-decision-tree.md @@ -1,7 +1,7 @@ --- layout: global -title: Decision Trees - MLlib -displayTitle: MLlib - Decision Trees +title: Decision Trees - spark.mllib +displayTitle: Decision Trees - spark.mllib --- * Table of contents @@ -15,7 +15,7 @@ feature scaling, and are able to capture non-linearities and feature interaction algorithms such as random forests and boosting are among the top performers for classification and regression tasks. -MLlib supports decision trees for binary and multiclass classification and for regression, +`spark.mllib` supports decision trees for binary and multiclass classification and for regression, using both continuous and categorical features. The implementation partitions data by rows, allowing distributed training with millions of instances. @@ -191,135 +191,22 @@ maximum tree depth of 5. The test error is calculated to measure the algorithm a
    -
    -{% highlight scala %} -import org.apache.spark.mllib.tree.DecisionTree -import org.apache.spark.mllib.tree.model.DecisionTreeModel -import org.apache.spark.mllib.util.MLUtils - -// Load and parse the data file. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -// Split the data into training and test sets (30% held out for testing) -val splits = data.randomSplit(Array(0.7, 0.3)) -val (trainingData, testData) = (splits(0), splits(1)) - -// Train a DecisionTree model. -// Empty categoricalFeaturesInfo indicates all features are continuous. -val numClasses = 2 -val categoricalFeaturesInfo = Map[Int, Int]() -val impurity = "gini" -val maxDepth = 5 -val maxBins = 32 - -val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo, - impurity, maxDepth, maxBins) - -// Evaluate model on test instances and compute test error -val labelAndPreds = testData.map { point => - val prediction = model.predict(point.features) - (point.label, prediction) -} -val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() -println("Test Error = " + testErr) -println("Learned classification tree model:\n" + model.toDebugString) - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = DecisionTreeModel.load(sc, "myModelPath") -{% endhighlight %} +
    +Refer to the [`DecisionTree` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) and [`DecisionTreeModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.model.DecisionTreeModel) for details on the API. + +{% include_example scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala %}
    -
    -{% highlight java %} -import java.util.HashMap; -import scala.Tuple2; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.DecisionTree; -import org.apache.spark.mllib.tree.model.DecisionTreeModel; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.SparkConf; - -SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree"); -JavaSparkContext sc = new JavaSparkContext(sparkConf); - -// Load and parse the data file. -String datapath = "data/mllib/sample_libsvm_data.txt"; -JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD(); -// Split the data into training and test sets (30% held out for testing) -JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); -JavaRDD trainingData = splits[0]; -JavaRDD testData = splits[1]; - -// Set parameters. -// Empty categoricalFeaturesInfo indicates all features are continuous. -Integer numClasses = 2; -Map categoricalFeaturesInfo = new HashMap(); -String impurity = "gini"; -Integer maxDepth = 5; -Integer maxBins = 32; - -// Train a DecisionTree model for classification. -final DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses, - categoricalFeaturesInfo, impurity, maxDepth, maxBins); - -// Evaluate model on test instances and compute test error -JavaPairRDD predictionAndLabel = - testData.mapToPair(new PairFunction() { - @Override - public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); -Double testErr = - 1.0 * predictionAndLabel.filter(new Function, Boolean>() { - @Override - public Boolean call(Tuple2 pl) { - return !pl._1().equals(pl._2()); - } - }).count() / testData.count(); -System.out.println("Test Error: " + testErr); -System.out.println("Learned classification tree model:\n" + model.toDebugString()); - -// Save and load model -model.save(sc.sc(), "myModelPath"); -DecisionTreeModel sameModel = DecisionTreeModel.load(sc.sc(), "myModelPath"); -{% endhighlight %} +
    +Refer to the [`DecisionTree` Java docs](api/java/org/apache/spark/mllib/tree/DecisionTree.html) and [`DecisionTreeModel` Java docs](api/java/org/apache/spark/mllib/tree/model/DecisionTreeModel.html) for details on the API. + +{% include_example java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java %}
    -
    - -{% highlight python %} -from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.tree import DecisionTree, DecisionTreeModel -from pyspark.mllib.util import MLUtils - -# Load and parse the data file into an RDD of LabeledPoint. -data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') -# Split the data into training and test sets (30% held out for testing) -(trainingData, testData) = data.randomSplit([0.7, 0.3]) - -# Train a DecisionTree model. -# Empty categoricalFeaturesInfo indicates all features are continuous. -model = DecisionTree.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={}, - impurity='gini', maxDepth=5, maxBins=32) - -# Evaluate model on test instances and compute test error -predictions = model.predict(testData.map(lambda x: x.features)) -labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) -testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(testData.count()) -print('Test Error = ' + str(testErr)) -print('Learned classification tree model:') -print(model.toDebugString()) - -# Save and load model -model.save(sc, "myModelPath") -sameModel = DecisionTreeModel.load(sc, "myModelPath") -{% endhighlight %} +
    +Refer to the [`DecisionTree` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.DecisionTree) and [`DecisionTreeModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.DecisionTreeModel) for more details on the API. + +{% include_example python/mllib/decision_tree_classification_example.py %}
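For quick inline reference, the classification workflow described above boils down to the following condensed sketch of the Python snippet removed above (the included example file is the authoritative version; `sc` is assumed to be an existing `SparkContext`, as elsewhere in this guide):

{% highlight python %}
from pyspark.mllib.tree import DecisionTree
from pyspark.mllib.util import MLUtils

# Load a LIBSVM file as an RDD of LabeledPoint and hold out 30% for testing.
data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
trainingData, testData = data.randomSplit([0.7, 0.3])

# Train a decision tree classifier with Gini impurity and maximum depth 5.
# An empty categoricalFeaturesInfo dict means all features are treated as continuous.
model = DecisionTree.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={},
                                     impurity='gini', maxDepth=5, maxBins=32)

# Compute the test error.
predictions = model.predict(testData.map(lambda x: x.features))
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
testErr = labelsAndPredictions.filter(lambda vp: vp[0] != vp[1]).count() / float(testData.count())
print('Test Error = ' + str(testErr))
{% endhighlight %}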
    @@ -335,140 +222,22 @@ depth of 5. The Mean Squared Error (MSE) is computed at the end to evaluate
    -
    -{% highlight scala %} -import org.apache.spark.mllib.tree.DecisionTree -import org.apache.spark.mllib.tree.model.DecisionTreeModel -import org.apache.spark.mllib.util.MLUtils - -// Load and parse the data file. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -// Split the data into training and test sets (30% held out for testing) -val splits = data.randomSplit(Array(0.7, 0.3)) -val (trainingData, testData) = (splits(0), splits(1)) - -// Train a DecisionTree model. -// Empty categoricalFeaturesInfo indicates all features are continuous. -val categoricalFeaturesInfo = Map[Int, Int]() -val impurity = "variance" -val maxDepth = 5 -val maxBins = 32 - -val model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo, impurity, - maxDepth, maxBins) - -// Evaluate model on test instances and compute test error -val labelsAndPredictions = testData.map { point => - val prediction = model.predict(point.features) - (point.label, prediction) -} -val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() -println("Test Mean Squared Error = " + testMSE) -println("Learned regression tree model:\n" + model.toDebugString) - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = DecisionTreeModel.load(sc, "myModelPath") -{% endhighlight %} +
    +Refer to the [`DecisionTree` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) and [`DecisionTreeModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.model.DecisionTreeModel) for details on the API. + +{% include_example scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala %}
    -
    -{% highlight java %} -import java.util.HashMap; -import scala.Tuple2; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.DecisionTree; -import org.apache.spark.mllib.tree.model.DecisionTreeModel; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.SparkConf; - -SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree"); -JavaSparkContext sc = new JavaSparkContext(sparkConf); - -// Load and parse the data file. -String datapath = "data/mllib/sample_libsvm_data.txt"; -JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD(); -// Split the data into training and test sets (30% held out for testing) -JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); -JavaRDD trainingData = splits[0]; -JavaRDD testData = splits[1]; - -// Set parameters. -// Empty categoricalFeaturesInfo indicates all features are continuous. -Map categoricalFeaturesInfo = new HashMap(); -String impurity = "variance"; -Integer maxDepth = 5; -Integer maxBins = 32; - -// Train a DecisionTree model. -final DecisionTreeModel model = DecisionTree.trainRegressor(trainingData, - categoricalFeaturesInfo, impurity, maxDepth, maxBins); - -// Evaluate model on test instances and compute test error -JavaPairRDD predictionAndLabel = - testData.mapToPair(new PairFunction() { - @Override - public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); -Double testMSE = - predictionAndLabel.map(new Function, Double>() { - @Override - public Double call(Tuple2 pl) { - Double diff = pl._1() - pl._2(); - return diff * diff; - } - }).reduce(new Function2() { - @Override - public Double call(Double a, Double b) { - return a + b; - } - }) / data.count(); -System.out.println("Test Mean Squared Error: " + testMSE); -System.out.println("Learned regression tree model:\n" + model.toDebugString()); - -// Save and load model -model.save(sc.sc(), "myModelPath"); -DecisionTreeModel sameModel = DecisionTreeModel.load(sc.sc(), "myModelPath"); -{% endhighlight %} +
    +Refer to the [`DecisionTree` Java docs](api/java/org/apache/spark/mllib/tree/DecisionTree.html) and [`DecisionTreeModel` Java docs](api/java/org/apache/spark/mllib/tree/model/DecisionTreeModel.html) for details on the API. + +{% include_example java/org/apache/spark/examples/mllib/JavaDecisionTreeRegressionExample.java %}
    -
    - -{% highlight python %} -from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.tree import DecisionTree, DecisionTreeModel -from pyspark.mllib.util import MLUtils - -# Load and parse the data file into an RDD of LabeledPoint. -data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') -# Split the data into training and test sets (30% held out for testing) -(trainingData, testData) = data.randomSplit([0.7, 0.3]) - -# Train a DecisionTree model. -# Empty categoricalFeaturesInfo indicates all features are continuous. -model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo={}, - impurity='variance', maxDepth=5, maxBins=32) - -# Evaluate model on test instances and compute test error -predictions = model.predict(testData.map(lambda x: x.features)) -labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) -testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / float(testData.count()) -print('Test Mean Squared Error = ' + str(testMSE)) -print('Learned regression tree model:') -print(model.toDebugString()) - -# Save and load model -model.save(sc, "myModelPath") -sameModel = DecisionTreeModel.load(sc, "myModelPath") -{% endhighlight %} +
    +Refer to the [`DecisionTree` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.DecisionTree) and [`DecisionTreeModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.DecisionTreeModel) for more details on the API. + +{% include_example python/mllib/decision_tree_regression_example.py %}
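Again for quick reference, a condensed sketch of the Python snippet removed above (the included example file is authoritative; `sc` is an existing `SparkContext`):

{% highlight python %}
from pyspark.mllib.tree import DecisionTree
from pyspark.mllib.util import MLUtils

# Load a LIBSVM file and hold out 30% for testing.
data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
trainingData, testData = data.randomSplit([0.7, 0.3])

# Train a regression tree with variance impurity and maximum depth 5.
model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo={},
                                    impurity='variance', maxDepth=5, maxBins=32)

# Compute the Mean Squared Error on the held-out test data.
predictions = model.predict(testData.map(lambda x: x.features))
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
testMSE = labelsAndPredictions.map(lambda vp: (vp[0] - vp[1]) ** 2).mean()
print('Test Mean Squared Error = ' + str(testMSE))
{% endhighlight %}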
    diff --git a/docs/mllib-dimensionality-reduction.md b/docs/mllib-dimensionality-reduction.md index 05f51168d837..11d8e0bd1d23 100644 --- a/docs/mllib-dimensionality-reduction.md +++ b/docs/mllib-dimensionality-reduction.md @@ -1,7 +1,7 @@ --- layout: global -title: Dimensionality Reduction - MLlib -displayTitle: MLlib - Dimensionality Reduction +title: Dimensionality Reduction - spark.mllib +displayTitle: Dimensionality Reduction - spark.mllib --- * Table of contents @@ -11,7 +11,7 @@ displayTitle: MLlib - Dimensionality Reduction of reducing the number of variables under consideration. It can be used to extract latent features from raw and noisy features or compress data while maintaining the structure. -MLlib provides support for dimensionality reduction on the RowMatrix class. +`spark.mllib` provides support for dimensionality reduction on the RowMatrix class. ## Singular value decomposition (SVD) @@ -57,11 +57,13 @@ passes, $O(n)$ storage on each executor, and $O(n k)$ storage on the driver. ### SVD Example -MLlib provides SVD functionality to row-oriented matrices, provided in the +`spark.mllib` provides SVD functionality to row-oriented matrices, provided in the RowMatrix class.
    +Refer to the [`SingularValueDecomposition` Scala docs](api/scala/index.html#org.apache.spark.mllib.linalg.SingularValueDecomposition) for details on the API. + {% highlight scala %} import org.apache.spark.mllib.linalg.Matrix import org.apache.spark.mllib.linalg.distributed.RowMatrix @@ -80,6 +82,8 @@ The same code applies to `IndexedRowMatrix` if `U` is defined as an `IndexedRowMatrix`.
    +Refer to the [`SingularValueDecomposition` Java docs](api/java/org/apache/spark/mllib/linalg/SingularValueDecomposition.html) for details on the API. + {% highlight java %} import java.util.LinkedList; @@ -137,7 +141,7 @@ statistical method to find a rotation such that the first coordinate has the lar possible, and each succeeding coordinate in turn has the largest variance possible. The columns of the rotation matrix are called principal components. PCA is used widely in dimensionality reduction. -MLlib supports PCA for tall-and-skinny matrices stored in row-oriented format and any Vectors. +`spark.mllib` supports PCA for tall-and-skinny matrices stored in row-oriented format and any Vectors.
    @@ -145,6 +149,8 @@ MLlib supports PCA for tall-and-skinny matrices stored in row-oriented format an The following code demonstrates how to compute principal components on a `RowMatrix` and use them to project the vectors into a low-dimensional space. +Refer to the [`RowMatrix` Scala docs](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.RowMatrix) for details on the API. + {% highlight scala %} import org.apache.spark.mllib.linalg.Matrix import org.apache.spark.mllib.linalg.distributed.RowMatrix @@ -161,6 +167,8 @@ val projected: RowMatrix = mat.multiply(pc) The following code demonstrates how to compute principal components on source vectors and use them to project the vectors into a low-dimensional space while keeping associated labels: +Refer to the [`PCA` Scala docs](api/scala/index.html#org.apache.spark.mllib.feature.PCA) for details on the API. + {% highlight scala %} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.feature.PCA @@ -182,6 +190,8 @@ The following code demonstrates how to compute principal components on a `RowMat and use them to project the vectors into a low-dimensional space. The number of columns should be small, e.g, less than 1000. +Refer to the [`RowMatrix` Java docs](api/java/org/apache/spark/mllib/linalg/distributed/RowMatrix.html) for details on the API. + {% highlight java %} import java.util.LinkedList; diff --git a/docs/mllib-ensembles.md b/docs/mllib-ensembles.md index 7521fb14a7bd..2416b6fa0aeb 100644 --- a/docs/mllib-ensembles.md +++ b/docs/mllib-ensembles.md @@ -1,7 +1,7 @@ --- layout: global -title: Ensembles - MLlib -displayTitle: MLlib - Ensembles +title: Ensembles - spark.mllib +displayTitle: Ensembles - spark.mllib --- * Table of contents @@ -9,7 +9,7 @@ displayTitle: MLlib - Ensembles An [ensemble method](http://en.wikipedia.org/wiki/Ensemble_learning) is a learning algorithm which creates a model composed of a set of other base models. -MLlib supports two major ensemble algorithms: [`GradientBoostedTrees`](api/scala/index.html#org.apache.spark.mllib.tree.GradientBosotedTrees) and [`RandomForest`](api/scala/index.html#org.apache.spark.mllib.tree.RandomForest). +`spark.mllib` supports two major ensemble algorithms: [`GradientBoostedTrees`](api/scala/index.html#org.apache.spark.mllib.tree.GradientBoostedTrees) and [`RandomForest`](api/scala/index.html#org.apache.spark.mllib.tree.RandomForest). Both use [decision trees](mllib-decision-tree.html) as their base models. ## Gradient-Boosted Trees vs. Random Forests @@ -33,9 +33,9 @@ Like decision trees, random forests handle categorical features, extend to the multiclass classification setting, do not require feature scaling, and are able to capture non-linearities and feature interactions. -MLlib supports random forests for binary and multiclass classification and for regression, +`spark.mllib` supports random forests for binary and multiclass classification and for regression, using both continuous and categorical features. -MLlib implements random forests using the existing [decision tree](mllib-decision-tree.html) +`spark.mllib` implements random forests using the existing [decision tree](mllib-decision-tree.html) implementation. Please see the decision tree guide for more information on trees. ### Basic algorithm @@ -95,142 +95,22 @@ The test error is calculated to measure the algorithm accuracy.
    -
    -{% highlight scala %} -import org.apache.spark.mllib.tree.RandomForest -import org.apache.spark.mllib.tree.model.RandomForestModel -import org.apache.spark.mllib.util.MLUtils - -// Load and parse the data file. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -// Split the data into training and test sets (30% held out for testing) -val splits = data.randomSplit(Array(0.7, 0.3)) -val (trainingData, testData) = (splits(0), splits(1)) - -// Train a RandomForest model. -// Empty categoricalFeaturesInfo indicates all features are continuous. -val numClasses = 2 -val categoricalFeaturesInfo = Map[Int, Int]() -val numTrees = 3 // Use more in practice. -val featureSubsetStrategy = "auto" // Let the algorithm choose. -val impurity = "gini" -val maxDepth = 4 -val maxBins = 32 - -val model = RandomForest.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo, - numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins) - -// Evaluate model on test instances and compute test error -val labelAndPreds = testData.map { point => - val prediction = model.predict(point.features) - (point.label, prediction) -} -val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() -println("Test Error = " + testErr) -println("Learned classification forest model:\n" + model.toDebugString) - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = RandomForestModel.load(sc, "myModelPath") -{% endhighlight %} +
    +Refer to the [`RandomForest` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.RandomForest) and [`RandomForestModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.model.RandomForestModel) for details on the API. + +{% include_example scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala %}
    -
    -{% highlight java %} -import scala.Tuple2; -import java.util.HashMap; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.RandomForest; -import org.apache.spark.mllib.tree.model.RandomForestModel; -import org.apache.spark.mllib.util.MLUtils; - -SparkConf sparkConf = new SparkConf().setAppName("JavaRandomForestClassification"); -JavaSparkContext sc = new JavaSparkContext(sparkConf); - -// Load and parse the data file. -String datapath = "data/mllib/sample_libsvm_data.txt"; -JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD(); -// Split the data into training and test sets (30% held out for testing) -JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); -JavaRDD trainingData = splits[0]; -JavaRDD testData = splits[1]; - -// Train a RandomForest model. -// Empty categoricalFeaturesInfo indicates all features are continuous. -Integer numClasses = 2; -HashMap categoricalFeaturesInfo = new HashMap(); -Integer numTrees = 3; // Use more in practice. -String featureSubsetStrategy = "auto"; // Let the algorithm choose. -String impurity = "gini"; -Integer maxDepth = 5; -Integer maxBins = 32; -Integer seed = 12345; - -final RandomForestModel model = RandomForest.trainClassifier(trainingData, numClasses, - categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, - seed); - -// Evaluate model on test instances and compute test error -JavaPairRDD predictionAndLabel = - testData.mapToPair(new PairFunction() { - @Override - public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); -Double testErr = - 1.0 * predictionAndLabel.filter(new Function, Boolean>() { - @Override - public Boolean call(Tuple2 pl) { - return !pl._1().equals(pl._2()); - } - }).count() / testData.count(); -System.out.println("Test Error: " + testErr); -System.out.println("Learned classification forest model:\n" + model.toDebugString()); - -// Save and load model -model.save(sc.sc(), "myModelPath"); -RandomForestModel sameModel = RandomForestModel.load(sc.sc(), "myModelPath"); -{% endhighlight %} +
    +Refer to the [`RandomForest` Java docs](api/java/org/apache/spark/mllib/tree/RandomForest.html) and [`RandomForestModel` Java docs](api/java/org/apache/spark/mllib/tree/model/RandomForestModel.html) for details on the API. + +{% include_example java/org/apache/spark/examples/mllib/JavaRandomForestClassificationExample.java %}
    -
    - -{% highlight python %} -from pyspark.mllib.tree import RandomForest, RandomForestModel -from pyspark.mllib.util import MLUtils - -# Load and parse the data file into an RDD of LabeledPoint. -data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') -# Split the data into training and test sets (30% held out for testing) -(trainingData, testData) = data.randomSplit([0.7, 0.3]) - -# Train a RandomForest model. -# Empty categoricalFeaturesInfo indicates all features are continuous. -# Note: Use larger numTrees in practice. -# Setting featureSubsetStrategy="auto" lets the algorithm choose. -model = RandomForest.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={}, - numTrees=3, featureSubsetStrategy="auto", - impurity='gini', maxDepth=4, maxBins=32) - -# Evaluate model on test instances and compute test error -predictions = model.predict(testData.map(lambda x: x.features)) -labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) -testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(testData.count()) -print('Test Error = ' + str(testErr)) -print('Learned classification forest model:') -print(model.toDebugString()) - -# Save and load model -model.save(sc, "myModelPath") -sameModel = RandomForestModel.load(sc, "myModelPath") -{% endhighlight %} +
+Refer to the [`RandomForest` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.RandomForest) and [`RandomForestModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.RandomForestModel) for more details on the API. + +{% include_example python/mllib/random_forest_classification_example.py %}
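A condensed sketch of the Python snippet removed above, kept here for quick reference (the included example file is authoritative; `sc` is an existing `SparkContext`, and a larger `numTrees` should be used in practice):

{% highlight python %}
from pyspark.mllib.tree import RandomForest
from pyspark.mllib.util import MLUtils

# Load a LIBSVM file and hold out 30% for testing.
data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
trainingData, testData = data.randomSplit([0.7, 0.3])

# Train a random forest of 3 trees (use more in practice).
# featureSubsetStrategy="auto" lets the algorithm choose the features considered per split.
model = RandomForest.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={},
                                     numTrees=3, featureSubsetStrategy="auto",
                                     impurity='gini', maxDepth=4, maxBins=32)

# Compute the test error.
predictions = model.predict(testData.map(lambda x: x.features))
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
testErr = labelsAndPredictions.filter(lambda vp: vp[0] != vp[1]).count() / float(testData.count())
print('Test Error = ' + str(testErr))
{% endhighlight %}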
    @@ -246,145 +126,22 @@ The Mean Squared Error (MSE) is computed at the end to evaluate
    -
    -{% highlight scala %} -import org.apache.spark.mllib.tree.RandomForest -import org.apache.spark.mllib.tree.model.RandomForestModel -import org.apache.spark.mllib.util.MLUtils - -// Load and parse the data file. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -// Split the data into training and test sets (30% held out for testing) -val splits = data.randomSplit(Array(0.7, 0.3)) -val (trainingData, testData) = (splits(0), splits(1)) - -// Train a RandomForest model. -// Empty categoricalFeaturesInfo indicates all features are continuous. -val numClasses = 2 -val categoricalFeaturesInfo = Map[Int, Int]() -val numTrees = 3 // Use more in practice. -val featureSubsetStrategy = "auto" // Let the algorithm choose. -val impurity = "variance" -val maxDepth = 4 -val maxBins = 32 - -val model = RandomForest.trainRegressor(trainingData, categoricalFeaturesInfo, - numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins) - -// Evaluate model on test instances and compute test error -val labelsAndPredictions = testData.map { point => - val prediction = model.predict(point.features) - (point.label, prediction) -} -val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() -println("Test Mean Squared Error = " + testMSE) -println("Learned regression forest model:\n" + model.toDebugString) - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = RandomForestModel.load(sc, "myModelPath") -{% endhighlight %} +
    +Refer to the [`RandomForest` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.RandomForest) and [`RandomForestModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.model.RandomForestModel) for details on the API. + +{% include_example scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala %}
    -
    -{% highlight java %} -import java.util.HashMap; -import scala.Tuple2; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.RandomForest; -import org.apache.spark.mllib.tree.model.RandomForestModel; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.SparkConf; - -SparkConf sparkConf = new SparkConf().setAppName("JavaRandomForest"); -JavaSparkContext sc = new JavaSparkContext(sparkConf); - -// Load and parse the data file. -String datapath = "data/mllib/sample_libsvm_data.txt"; -JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD(); -// Split the data into training and test sets (30% held out for testing) -JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); -JavaRDD trainingData = splits[0]; -JavaRDD testData = splits[1]; - -// Set parameters. -// Empty categoricalFeaturesInfo indicates all features are continuous. -Map categoricalFeaturesInfo = new HashMap(); -String impurity = "variance"; -Integer maxDepth = 4; -Integer maxBins = 32; - -// Train a RandomForest model. -final RandomForestModel model = RandomForest.trainRegressor(trainingData, - categoricalFeaturesInfo, impurity, maxDepth, maxBins); - -// Evaluate model on test instances and compute test error -JavaPairRDD predictionAndLabel = - testData.mapToPair(new PairFunction() { - @Override - public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); -Double testMSE = - predictionAndLabel.map(new Function, Double>() { - @Override - public Double call(Tuple2 pl) { - Double diff = pl._1() - pl._2(); - return diff * diff; - } - }).reduce(new Function2() { - @Override - public Double call(Double a, Double b) { - return a + b; - } - }) / testData.count(); -System.out.println("Test Mean Squared Error: " + testMSE); -System.out.println("Learned regression forest model:\n" + model.toDebugString()); - -// Save and load model -model.save(sc.sc(), "myModelPath"); -RandomForestModel sameModel = RandomForestModel.load(sc.sc(), "myModelPath"); -{% endhighlight %} +
    +Refer to the [`RandomForest` Java docs](api/java/org/apache/spark/mllib/tree/RandomForest.html) and [`RandomForestModel` Java docs](api/java/org/apache/spark/mllib/tree/model/RandomForestModel.html) for details on the API. + +{% include_example java/org/apache/spark/examples/mllib/JavaRandomForestRegressionExample.java %}
    -
    - -{% highlight python %} -from pyspark.mllib.tree import RandomForest, RandomForestModel -from pyspark.mllib.util import MLUtils - -# Load and parse the data file into an RDD of LabeledPoint. -data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') -# Split the data into training and test sets (30% held out for testing) -(trainingData, testData) = data.randomSplit([0.7, 0.3]) - -# Train a RandomForest model. -# Empty categoricalFeaturesInfo indicates all features are continuous. -# Note: Use larger numTrees in practice. -# Setting featureSubsetStrategy="auto" lets the algorithm choose. -model = RandomForest.trainRegressor(trainingData, categoricalFeaturesInfo={}, - numTrees=3, featureSubsetStrategy="auto", - impurity='variance', maxDepth=4, maxBins=32) - -# Evaluate model on test instances and compute test error -predictions = model.predict(testData.map(lambda x: x.features)) -labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) -testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / float(testData.count()) -print('Test Mean Squared Error = ' + str(testMSE)) -print('Learned regression forest model:') -print(model.toDebugString()) - -# Save and load model -model.save(sc, "myModelPath") -sameModel = RandomForestModel.load(sc, "myModelPath") -{% endhighlight %} +
+Refer to the [`RandomForest` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.RandomForest) and [`RandomForestModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.RandomForestModel) for more details on the API. + +{% include_example python/mllib/random_forest_regression_example.py %}
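A condensed sketch of the Python snippet removed above (the included example file is authoritative; `sc` is an existing `SparkContext`):

{% highlight python %}
from pyspark.mllib.tree import RandomForest
from pyspark.mllib.util import MLUtils

# Load a LIBSVM file and hold out 30% for testing.
data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
trainingData, testData = data.randomSplit([0.7, 0.3])

# Train a random forest for regression with variance impurity (use more trees in practice).
model = RandomForest.trainRegressor(trainingData, categoricalFeaturesInfo={},
                                    numTrees=3, featureSubsetStrategy="auto",
                                    impurity='variance', maxDepth=4, maxBins=32)

# Compute the Mean Squared Error on the held-out test data.
predictions = model.predict(testData.map(lambda x: x.features))
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
testMSE = labelsAndPredictions.map(lambda vp: (vp[0] - vp[1]) ** 2).mean()
print('Test Mean Squared Error = ' + str(testMSE))
{% endhighlight %}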
    @@ -398,9 +155,9 @@ Like decision trees, GBTs handle categorical features, extend to the multiclass classification setting, do not require feature scaling, and are able to capture non-linearities and feature interactions. -MLlib supports GBTs for binary classification and for regression, +`spark.mllib` supports GBTs for binary classification and for regression, using both continuous and categorical features. -MLlib implements GBTs using the existing [decision tree](mllib-decision-tree.html) implementation. Please see the decision tree guide for more information on trees. +`spark.mllib` implements GBTs using the existing [decision tree](mllib-decision-tree.html) implementation. Please see the decision tree guide for more information on trees. *Note*: GBTs do not yet support multiclass classification. For multiclass problems, please use [decision trees](mllib-decision-tree.html) or [Random Forests](mllib-ensembles.html#Random-Forest). @@ -414,7 +171,7 @@ The specific mechanism for re-labeling instances is defined by a loss function ( #### Losses -The table below lists the losses currently supported by GBTs in MLlib. +The table below lists the losses currently supported by GBTs in `spark.mllib`. Note that each loss is applicable to one of classification or regression, not both. Notation: $N$ = number of instances. $y_i$ = label of instance $i$. $x_i$ = features of instance $i$. $F(x_i)$ = model's predicted label for instance $i$. @@ -479,139 +236,22 @@ The test error is calculated to measure the algorithm accuracy.
    -
    -{% highlight scala %} -import org.apache.spark.mllib.tree.GradientBoostedTrees -import org.apache.spark.mllib.tree.configuration.BoostingStrategy -import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel -import org.apache.spark.mllib.util.MLUtils - -// Load and parse the data file. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -// Split the data into training and test sets (30% held out for testing) -val splits = data.randomSplit(Array(0.7, 0.3)) -val (trainingData, testData) = (splits(0), splits(1)) - -// Train a GradientBoostedTrees model. -// The defaultParams for Classification use LogLoss by default. -val boostingStrategy = BoostingStrategy.defaultParams("Classification") -boostingStrategy.numIterations = 3 // Note: Use more iterations in practice. -boostingStrategy.treeStrategy.numClasses = 2 -boostingStrategy.treeStrategy.maxDepth = 5 -// Empty categoricalFeaturesInfo indicates all features are continuous. -boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]() - -val model = GradientBoostedTrees.train(trainingData, boostingStrategy) - -// Evaluate model on test instances and compute test error -val labelAndPreds = testData.map { point => - val prediction = model.predict(point.features) - (point.label, prediction) -} -val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() -println("Test Error = " + testErr) -println("Learned classification GBT model:\n" + model.toDebugString) - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = GradientBoostedTreesModel.load(sc, "myModelPath") -{% endhighlight %} +
    +Refer to the [`GradientBoostedTrees` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.GradientBoostedTrees) and [`GradientBoostedTreesModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.model.GradientBoostedTreesModel) for details on the API. + +{% include_example scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala %}
    -
    -{% highlight java %} -import scala.Tuple2; -import java.util.HashMap; -import java.util.Map; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.GradientBoostedTrees; -import org.apache.spark.mllib.tree.configuration.BoostingStrategy; -import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel; -import org.apache.spark.mllib.util.MLUtils; - -SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTrees"); -JavaSparkContext sc = new JavaSparkContext(sparkConf); - -// Load and parse the data file. -String datapath = "data/mllib/sample_libsvm_data.txt"; -JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD(); -// Split the data into training and test sets (30% held out for testing) -JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); -JavaRDD trainingData = splits[0]; -JavaRDD testData = splits[1]; - -// Train a GradientBoostedTrees model. -// The defaultParams for Classification use LogLoss by default. -BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Classification"); -boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice. -boostingStrategy.getTreeStrategy().setNumClassesForClassification(2); -boostingStrategy.getTreeStrategy().setMaxDepth(5); -// Empty categoricalFeaturesInfo indicates all features are continuous. -Map categoricalFeaturesInfo = new HashMap(); -boostingStrategy.treeStrategy().setCategoricalFeaturesInfo(categoricalFeaturesInfo); - -final GradientBoostedTreesModel model = - GradientBoostedTrees.train(trainingData, boostingStrategy); - -// Evaluate model on test instances and compute test error -JavaPairRDD predictionAndLabel = - testData.mapToPair(new PairFunction() { - @Override - public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); -Double testErr = - 1.0 * predictionAndLabel.filter(new Function, Boolean>() { - @Override - public Boolean call(Tuple2 pl) { - return !pl._1().equals(pl._2()); - } - }).count() / testData.count(); -System.out.println("Test Error: " + testErr); -System.out.println("Learned classification GBT model:\n" + model.toDebugString()); - -// Save and load model -model.save(sc.sc(), "myModelPath"); -GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(sc.sc(), "myModelPath"); -{% endhighlight %} +
    +Refer to the [`GradientBoostedTrees` Java docs](api/java/org/apache/spark/mllib/tree/GradientBoostedTrees.html) and [`GradientBoostedTreesModel` Java docs](api/java/org/apache/spark/mllib/tree/model/GradientBoostedTreesModel.html) for details on the API. + +{% include_example java/org/apache/spark/examples/mllib/JavaGradientBoostingClassificationExample.java %}
    -
    - -{% highlight python %} -from pyspark.mllib.tree import GradientBoostedTrees, GradientBoostedTreesModel -from pyspark.mllib.util import MLUtils - -# Load and parse the data file. -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -# Split the data into training and test sets (30% held out for testing) -(trainingData, testData) = data.randomSplit([0.7, 0.3]) - -# Train a GradientBoostedTrees model. -# Notes: (a) Empty categoricalFeaturesInfo indicates all features are continuous. -# (b) Use more iterations in practice. -model = GradientBoostedTrees.trainClassifier(trainingData, - categoricalFeaturesInfo={}, numIterations=3) - -# Evaluate model on test instances and compute test error -predictions = model.predict(testData.map(lambda x: x.features)) -labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) -testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(testData.count()) -print('Test Error = ' + str(testErr)) -print('Learned classification GBT model:') -print(model.toDebugString()) - -# Save and load model -model.save(sc, "myModelPath") -sameModel = GradientBoostedTreesModel.load(sc, "myModelPath") -{% endhighlight %} +
    +Refer to the [`GradientBoostedTrees` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.GradientBoostedTrees) and [`GradientBoostedTreesModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.GradientBoostedTreesModel) for more details on the API. + +{% include_example python/mllib/gradient_boosting_classification_example.py %}
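For quick reference, a condensed sketch of the Python snippet removed above (the included example file is authoritative; `sc` is an existing `SparkContext`, and more boosting iterations should be used in practice):

{% highlight python %}
from pyspark.mllib.tree import GradientBoostedTrees
from pyspark.mllib.util import MLUtils

# Load a LIBSVM file and hold out 30% for testing.
data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
trainingData, testData = data.randomSplit([0.7, 0.3])

# Train a boosted ensemble of 3 trees; empty categoricalFeaturesInfo means all features are continuous.
model = GradientBoostedTrees.trainClassifier(trainingData,
                                             categoricalFeaturesInfo={}, numIterations=3)

# Compute the test error.
predictions = model.predict(testData.map(lambda x: x.features))
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
testErr = labelsAndPredictions.filter(lambda vp: vp[0] != vp[1]).count() / float(testData.count())
print('Test Error = ' + str(testErr))
{% endhighlight %}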
    @@ -627,144 +267,22 @@ The Mean Squared Error (MSE) is computed at the end to evaluate
    -
    -{% highlight scala %} -import org.apache.spark.mllib.tree.GradientBoostedTrees -import org.apache.spark.mllib.tree.configuration.BoostingStrategy -import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel -import org.apache.spark.mllib.util.MLUtils - -// Load and parse the data file. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -// Split the data into training and test sets (30% held out for testing) -val splits = data.randomSplit(Array(0.7, 0.3)) -val (trainingData, testData) = (splits(0), splits(1)) - -// Train a GradientBoostedTrees model. -// The defaultParams for Regression use SquaredError by default. -val boostingStrategy = BoostingStrategy.defaultParams("Regression") -boostingStrategy.numIterations = 3 // Note: Use more iterations in practice. -boostingStrategy.treeStrategy.maxDepth = 5 -// Empty categoricalFeaturesInfo indicates all features are continuous. -boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]() - -val model = GradientBoostedTrees.train(trainingData, boostingStrategy) - -// Evaluate model on test instances and compute test error -val labelsAndPredictions = testData.map { point => - val prediction = model.predict(point.features) - (point.label, prediction) -} -val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() -println("Test Mean Squared Error = " + testMSE) -println("Learned regression GBT model:\n" + model.toDebugString) - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = GradientBoostedTreesModel.load(sc, "myModelPath") -{% endhighlight %} +
    +Refer to the [`GradientBoostedTrees` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.GradientBoostedTrees) and [`GradientBoostedTreesModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.model.GradientBoostedTreesModel) for details on the API. + +{% include_example scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala %}
    -
    -{% highlight java %} -import scala.Tuple2; -import java.util.HashMap; -import java.util.Map; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.GradientBoostedTrees; -import org.apache.spark.mllib.tree.configuration.BoostingStrategy; -import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel; -import org.apache.spark.mllib.util.MLUtils; - -SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTrees"); -JavaSparkContext sc = new JavaSparkContext(sparkConf); - -// Load and parse the data file. -String datapath = "data/mllib/sample_libsvm_data.txt"; -JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD(); -// Split the data into training and test sets (30% held out for testing) -JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); -JavaRDD trainingData = splits[0]; -JavaRDD testData = splits[1]; - -// Train a GradientBoostedTrees model. -// The defaultParams for Regression use SquaredError by default. -BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Regression"); -boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice. -boostingStrategy.getTreeStrategy().setMaxDepth(5); -// Empty categoricalFeaturesInfo indicates all features are continuous. -Map categoricalFeaturesInfo = new HashMap(); -boostingStrategy.treeStrategy().setCategoricalFeaturesInfo(categoricalFeaturesInfo); - -final GradientBoostedTreesModel model = - GradientBoostedTrees.train(trainingData, boostingStrategy); - -// Evaluate model on test instances and compute test error -JavaPairRDD predictionAndLabel = - testData.mapToPair(new PairFunction() { - @Override - public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); -Double testMSE = - predictionAndLabel.map(new Function, Double>() { - @Override - public Double call(Tuple2 pl) { - Double diff = pl._1() - pl._2(); - return diff * diff; - } - }).reduce(new Function2() { - @Override - public Double call(Double a, Double b) { - return a + b; - } - }) / data.count(); -System.out.println("Test Mean Squared Error: " + testMSE); -System.out.println("Learned regression GBT model:\n" + model.toDebugString()); - -// Save and load model -model.save(sc.sc(), "myModelPath"); -GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(sc.sc(), "myModelPath"); -{% endhighlight %} +
    +Refer to the [`GradientBoostedTrees` Java docs](api/java/org/apache/spark/mllib/tree/GradientBoostedTrees.html) and [`GradientBoostedTreesModel` Java docs](api/java/org/apache/spark/mllib/tree/model/GradientBoostedTreesModel.html) for details on the API. + +{% include_example java/org/apache/spark/examples/mllib/JavaGradientBoostingRegressionExample.java %}
    -
    - -{% highlight python %} -from pyspark.mllib.tree import GradientBoostedTrees, GradientBoostedTreesModel -from pyspark.mllib.util import MLUtils - -# Load and parse the data file. -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -# Split the data into training and test sets (30% held out for testing) -(trainingData, testData) = data.randomSplit([0.7, 0.3]) - -# Train a GradientBoostedTrees model. -# Notes: (a) Empty categoricalFeaturesInfo indicates all features are continuous. -# (b) Use more iterations in practice. -model = GradientBoostedTrees.trainRegressor(trainingData, - categoricalFeaturesInfo={}, numIterations=3) - -# Evaluate model on test instances and compute test error -predictions = model.predict(testData.map(lambda x: x.features)) -labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) -testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / float(testData.count()) -print('Test Mean Squared Error = ' + str(testMSE)) -print('Learned regression GBT model:') -print(model.toDebugString()) - -# Save and load model -model.save(sc, "myModelPath") -sameModel = GradientBoostedTreesModel.load(sc, "myModelPath") -{% endhighlight %} +
    +Refer to the [`GradientBoostedTrees` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.GradientBoostedTrees) and [`GradientBoostedTreesModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.GradientBoostedTreesModel) for more details on the API. + +{% include_example python/mllib/gradient_boosting_regression_example.py %}
    diff --git a/docs/mllib-evaluation-metrics.md b/docs/mllib-evaluation-metrics.md index 7066d5c97418..774826c2703f 100644 --- a/docs/mllib-evaluation-metrics.md +++ b/docs/mllib-evaluation-metrics.md @@ -1,20 +1,20 @@ --- layout: global -title: Evaluation Metrics - MLlib -displayTitle: MLlib - Evaluation Metrics +title: Evaluation Metrics - spark.mllib +displayTitle: Evaluation Metrics - spark.mllib --- * Table of contents {:toc} -Spark's MLlib comes with a number of machine learning algorithms that can be used to learn from and make predictions +`spark.mllib` comes with a number of machine learning algorithms that can be used to learn from and make predictions on data. When these algorithms are applied to build machine learning models, there is a need to evaluate the performance -of the model on some criteria, which depends on the application and its requirements. Spark's MLlib also provides a +of the model on some criteria, which depends on the application and its requirements. `spark.mllib` also provides a suite of metrics for the purpose of evaluating the performance of machine learning models. Specific machine learning algorithms fall under broader types of machine learning applications like classification, regression, clustering, etc. Each of these types have well established metrics for performance evaluation and those -metrics that are currently available in Spark's MLlib are detailed in this section. +metrics that are currently available in `spark.mllib` are detailed in this section. ## Classification model evaluation @@ -102,213 +102,23 @@ The following code snippets illustrate how to load a sample dataset, train a bin data, and evaluate the performance of the algorithm by several binary evaluation metrics.
    +Refer to the [`LogisticRegressionWithLBFGS` Scala docs](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS) and [`BinaryClassificationMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.BinaryClassificationMetrics) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS -import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.util.MLUtils - -// Load training data in LIBSVM format -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt") - -// Split data into training (60%) and test (40%) -val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L) -training.cache() - -// Run training algorithm to build the model -val model = new LogisticRegressionWithLBFGS() - .setNumClasses(2) - .run(training) - -// Clear the prediction threshold so the model will return probabilities -model.clearThreshold - -// Compute raw scores on the test set -val predictionAndLabels = test.map { case LabeledPoint(label, features) => - val prediction = model.predict(features) - (prediction, label) -} - -// Instantiate metrics object -val metrics = new BinaryClassificationMetrics(predictionAndLabels) - -// Precision by threshold -val precision = metrics.precisionByThreshold -precision.foreach { case (t, p) => - println(s"Threshold: $t, Precision: $p") -} - -// Recall by threshold -val recall = metrics.precisionByThreshold -recall.foreach { case (t, r) => - println(s"Threshold: $t, Recall: $r") -} - -// Precision-Recall Curve -val PRC = metrics.pr - -// F-measure -val f1Score = metrics.fMeasureByThreshold -f1Score.foreach { case (t, f) => - println(s"Threshold: $t, F-score: $f, Beta = 1") -} - -val beta = 0.5 -val fScore = metrics.fMeasureByThreshold(beta) -f1Score.foreach { case (t, f) => - println(s"Threshold: $t, F-score: $f, Beta = 0.5") -} - -// AUPRC -val auPRC = metrics.areaUnderPR -println("Area under precision-recall curve = " + auPRC) - -// Compute thresholds used in ROC and PR curves -val thresholds = precision.map(_._1) - -// ROC Curve -val roc = metrics.roc - -// AUROC -val auROC = metrics.areaUnderROC -println("Area under ROC = " + auROC) - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala %}
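The core of that example can be sketched as follows (condensed from the inline snippet it replaces, assuming an existing `SparkContext` `sc`). Clearing the threshold makes the model emit raw scores rather than hard 0/1 predictions, which is what the threshold-based metrics expect.

{% highlight scala %}
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils

// Load binary-labeled data in LIBSVM format and split it 60/40.
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt")
val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L)

// Train a logistic regression model and clear the threshold so it returns raw scores.
val model = new LogisticRegressionWithLBFGS()
  .setNumClasses(2)
  .run(training.cache())
model.clearThreshold

// Pair each test point's score with its true label.
val scoreAndLabels = test.map { case LabeledPoint(label, features) =>
  (model.predict(features), label)
}

val metrics = new BinaryClassificationMetrics(scoreAndLabels)
println("Area under precision-recall curve = " + metrics.areaUnderPR)
println("Area under ROC = " + metrics.areaUnderROC)
{% endhighlight %}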
    +Refer to the [`LogisticRegressionModel` Java docs](api/java/org/apache/spark/mllib/classification/LogisticRegressionModel.html) and [`LogisticRegressionWithLBFGS` Java docs](api/java/org/apache/spark/mllib/classification/LogisticRegressionWithLBFGS.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.rdd.RDD; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.classification.LogisticRegressionModel; -import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; -import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkContext; - -public class BinaryClassification { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("Binary Classification Metrics"); - SparkContext sc = new SparkContext(conf); - String path = "data/mllib/sample_binary_classification_data.txt"; - JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); - - // Split initial RDD into two... [60% training data, 40% testing data]. - JavaRDD[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L); - JavaRDD training = splits[0].cache(); - JavaRDD test = splits[1]; - - // Run training algorithm to build the model. - final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() - .setNumClasses(2) - .run(training.rdd()); - - // Clear the prediction threshold so the model will return probabilities - model.clearThreshold(); - - // Compute raw scores on the test set. - JavaRDD> predictionAndLabels = test.map( - new Function>() { - public Tuple2 call(LabeledPoint p) { - Double prediction = model.predict(p.features()); - return new Tuple2(prediction, p.label()); - } - } - ); - - // Get evaluation metrics. 
- BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(predictionAndLabels.rdd()); - - // Precision by threshold - JavaRDD> precision = metrics.precisionByThreshold().toJavaRDD(); - System.out.println("Precision by threshold: " + precision.toArray()); - - // Recall by threshold - JavaRDD> recall = metrics.recallByThreshold().toJavaRDD(); - System.out.println("Recall by threshold: " + recall.toArray()); - - // F Score by threshold - JavaRDD> f1Score = metrics.fMeasureByThreshold().toJavaRDD(); - System.out.println("F1 Score by threshold: " + f1Score.toArray()); - - JavaRDD> f2Score = metrics.fMeasureByThreshold(2.0).toJavaRDD(); - System.out.println("F2 Score by threshold: " + f2Score.toArray()); - - // Precision-recall curve - JavaRDD> prc = metrics.pr().toJavaRDD(); - System.out.println("Precision-recall curve: " + prc.toArray()); - - // Thresholds - JavaRDD thresholds = precision.map( - new Function, Double>() { - public Double call (Tuple2 t) { - return new Double(t._1().toString()); - } - } - ); - - // ROC Curve - JavaRDD> roc = metrics.roc().toJavaRDD(); - System.out.println("ROC curve: " + roc.toArray()); - - // AUPRC - System.out.println("Area under precision-recall curve = " + metrics.areaUnderPR()); - - // AUROC - System.out.println("Area under ROC = " + metrics.areaUnderROC()); - - // Save and load model - model.save(sc, "myModelPath"); - LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath"); - } -} - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java %}
    +Refer to the [`BinaryClassificationMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.BinaryClassificationMetrics) and [`LogisticRegressionWithLBFGS` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.classification.LogisticRegressionWithLBFGS) for more details on the API. -{% highlight python %} -from pyspark.mllib.classification import LogisticRegressionWithLBFGS -from pyspark.mllib.evaluation import BinaryClassificationMetrics -from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.util import MLUtils - -# Several of the methods available in scala are currently missing from pyspark - -# Load training data in LIBSVM format -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt") - -# Split data into training (60%) and test (40%) -training, test = data.randomSplit([0.6, 0.4], seed = 11L) -training.cache() - -# Run training algorithm to build the model -model = LogisticRegressionWithLBFGS.train(training) - -# Compute raw scores on the test set -predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label)) - -# Instantiate metrics object -metrics = BinaryClassificationMetrics(predictionAndLabels) - -# Area under precision-recall curve -print("Area under PR = %s" % metrics.areaUnderPR) - -# Area under ROC curve -print("Area under ROC = %s" % metrics.areaUnderROC) - -{% endhighlight %} - +{% include_example python/mllib/binary_classification_metrics_example.py %}
    @@ -428,203 +238,23 @@ The following code snippets illustrate how to load a sample dataset, train a mul the data, and evaluate the performance of the algorithm by several multiclass classification evaluation metrics.
    +Refer to the [`MulticlassMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.MulticlassMetrics) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS -import org.apache.spark.mllib.evaluation.MulticlassMetrics -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.util.MLUtils - -// Load training data in LIBSVM format -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") - -// Split data into training (60%) and test (40%) -val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L) -training.cache() - -// Run training algorithm to build the model -val model = new LogisticRegressionWithLBFGS() - .setNumClasses(3) - .run(training) - -// Compute raw scores on the test set -val predictionAndLabels = test.map { case LabeledPoint(label, features) => - val prediction = model.predict(features) - (prediction, label) -} - -// Instantiate metrics object -val metrics = new MulticlassMetrics(predictionAndLabels) - -// Confusion matrix -println("Confusion matrix:") -println(metrics.confusionMatrix) - -// Overall Statistics -val precision = metrics.precision -val recall = metrics.recall // same as true positive rate -val f1Score = metrics.fMeasure -println("Summary Statistics") -println(s"Precision = $precision") -println(s"Recall = $recall") -println(s"F1 Score = $f1Score") - -// Precision by label -val labels = metrics.labels -labels.foreach { l => - println(s"Precision($l) = " + metrics.precision(l)) -} - -// Recall by label -labels.foreach { l => - println(s"Recall($l) = " + metrics.recall(l)) -} - -// False positive rate by label -labels.foreach { l => - println(s"FPR($l) = " + metrics.falsePositiveRate(l)) -} - -// F-measure by label -labels.foreach { l => - println(s"F1-Score($l) = " + metrics.fMeasure(l)) -} - -// Weighted stats -println(s"Weighted precision: ${metrics.weightedPrecision}") -println(s"Weighted recall: ${metrics.weightedRecall}") -println(s"Weighted F1 score: ${metrics.weightedFMeasure}") -println(s"Weighted false positive rate: ${metrics.weightedFalsePositiveRate}") - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala %}
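In outline, the example boils down to the pattern below (a condensed sketch of the replaced snippet, assuming an existing `SparkContext` `sc`):

{% highlight scala %}
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils

// Load multiclass data in LIBSVM format and split it 60/40.
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt")
val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L)

// Train a multiclass logistic regression model.
val model = new LogisticRegressionWithLBFGS()
  .setNumClasses(3)
  .run(training.cache())

// Pair each test point's prediction with its true label.
val predictionAndLabels = test.map { case LabeledPoint(label, features) =>
  (model.predict(features), label)
}

val metrics = new MulticlassMetrics(predictionAndLabels)
println("Confusion matrix:\n" + metrics.confusionMatrix)
println(s"Weighted precision = ${metrics.weightedPrecision}")
println(s"Weighted recall = ${metrics.weightedRecall}")
println(s"Weighted F1 score = ${metrics.weightedFMeasure}")
{% endhighlight %}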
    +Refer to the [`MulticlassMetrics` Java docs](api/java/org/apache/spark/mllib/evaluation/MulticlassMetrics.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.rdd.RDD; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.classification.LogisticRegressionModel; -import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; -import org.apache.spark.mllib.evaluation.MulticlassMetrics; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.mllib.linalg.Matrix; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkContext; - -public class MulticlassClassification { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("Multiclass Classification Metrics"); - SparkContext sc = new SparkContext(conf); - String path = "data/mllib/sample_multiclass_classification_data.txt"; - JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); - - // Split initial RDD into two... [60% training data, 40% testing data]. - JavaRDD[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L); - JavaRDD training = splits[0].cache(); - JavaRDD test = splits[1]; - - // Run training algorithm to build the model. - final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() - .setNumClasses(3) - .run(training.rdd()); - - // Compute raw scores on the test set. - JavaRDD> predictionAndLabels = test.map( - new Function>() { - public Tuple2 call(LabeledPoint p) { - Double prediction = model.predict(p.features()); - return new Tuple2(prediction, p.label()); - } - } - ); - - // Get evaluation metrics. - MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd()); - - // Confusion matrix - Matrix confusion = metrics.confusionMatrix(); - System.out.println("Confusion matrix: \n" + confusion); - - // Overall statistics - System.out.println("Precision = " + metrics.precision()); - System.out.println("Recall = " + metrics.recall()); - System.out.println("F1 Score = " + metrics.fMeasure()); - - // Stats by labels - for (int i = 0; i < metrics.labels().length; i++) { - System.out.format("Class %f precision = %f\n", metrics.labels()[i], metrics.precision(metrics.labels()[i])); - System.out.format("Class %f recall = %f\n", metrics.labels()[i], metrics.recall(metrics.labels()[i])); - System.out.format("Class %f F1 score = %f\n", metrics.labels()[i], metrics.fMeasure(metrics.labels()[i])); - } - - //Weighted stats - System.out.format("Weighted precision = %f\n", metrics.weightedPrecision()); - System.out.format("Weighted recall = %f\n", metrics.weightedRecall()); - System.out.format("Weighted F1 score = %f\n", metrics.weightedFMeasure()); - System.out.format("Weighted false positive rate = %f\n", metrics.weightedFalsePositiveRate()); - - // Save and load model - model.save(sc, "myModelPath"); - LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath"); - } -} - -{% endhighlight %} + {% include_example java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java %}
    +Refer to the [`MulticlassMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.MulticlassMetrics) for more details on the API. -{% highlight python %} -from pyspark.mllib.classification import LogisticRegressionWithLBFGS -from pyspark.mllib.util import MLUtils -from pyspark.mllib.evaluation import MulticlassMetrics - -# Load training data in LIBSVM format -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") - -# Split data into training (60%) and test (40%) -training, test = data.randomSplit([0.6, 0.4], seed = 11L) -training.cache() - -# Run training algorithm to build the model -model = LogisticRegressionWithLBFGS.train(training, numClasses=3) - -# Compute raw scores on the test set -predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label)) - -# Instantiate metrics object -metrics = MulticlassMetrics(predictionAndLabels) - -# Overall statistics -precision = metrics.precision() -recall = metrics.recall() -f1Score = metrics.fMeasure() -print("Summary Stats") -print("Precision = %s" % precision) -print("Recall = %s" % recall) -print("F1 Score = %s" % f1Score) - -# Statistics by class -labels = data.map(lambda lp: lp.label).distinct().collect() -for label in sorted(labels): - print("Class %s precision = %s" % (label, metrics.precision(label))) - print("Class %s recall = %s" % (label, metrics.recall(label))) - print("Class %s F1 Measure = %s" % (label, metrics.fMeasure(label, beta=1.0))) - -# Weighted stats -print("Weighted recall = %s" % metrics.weightedRecall) -print("Weighted precision = %s" % metrics.weightedPrecision) -print("Weighted F(1) Score = %s" % metrics.weightedFMeasure()) -print("Weighted F(0.5) Score = %s" % metrics.weightedFMeasure(beta=0.5)) -print("Weighted false positive rate = %s" % metrics.weightedFalsePositiveRate) -{% endhighlight %} +{% include_example python/mllib/multi_class_metrics_example.py %}
    @@ -758,153 +388,23 @@ True classes:
    +Refer to the [`MultilabelMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.MultilabelMetrics) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.evaluation.MultilabelMetrics -import org.apache.spark.rdd.RDD; - -val scoreAndLabels: RDD[(Array[Double], Array[Double])] = sc.parallelize( - Seq((Array(0.0, 1.0), Array(0.0, 2.0)), - (Array(0.0, 2.0), Array(0.0, 1.0)), - (Array(), Array(0.0)), - (Array(2.0), Array(2.0)), - (Array(2.0, 0.0), Array(2.0, 0.0)), - (Array(0.0, 1.0, 2.0), Array(0.0, 1.0)), - (Array(1.0), Array(1.0, 2.0))), 2) - -// Instantiate metrics object -val metrics = new MultilabelMetrics(scoreAndLabels) - -// Summary stats -println(s"Recall = ${metrics.recall}") -println(s"Precision = ${metrics.precision}") -println(s"F1 measure = ${metrics.f1Measure}") -println(s"Accuracy = ${metrics.accuracy}") - -// Individual label stats -metrics.labels.foreach(label => println(s"Class $label precision = ${metrics.precision(label)}")) -metrics.labels.foreach(label => println(s"Class $label recall = ${metrics.recall(label)}")) -metrics.labels.foreach(label => println(s"Class $label F1-score = ${metrics.f1Measure(label)}")) - -// Micro stats -println(s"Micro recall = ${metrics.microRecall}") -println(s"Micro precision = ${metrics.microPrecision}") -println(s"Micro F1 measure = ${metrics.microF1Measure}") - -// Hamming loss -println(s"Hamming loss = ${metrics.hammingLoss}") - -// Subset accuracy -println(s"Subset accuracy = ${metrics.subsetAccuracy}") - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/MultiLabelMetricsExample.scala %}
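A minimal sketch of the pattern, using a small hand-built set of (predicted labels, true labels) pairs in the spirit of the example above (assumes an existing `SparkContext` `sc`):

{% highlight scala %}
import org.apache.spark.mllib.evaluation.MultilabelMetrics
import org.apache.spark.rdd.RDD

// Each record pairs the predicted label set with the true label set for one document.
val scoreAndLabels: RDD[(Array[Double], Array[Double])] = sc.parallelize(
  Seq((Array(0.0, 1.0), Array(0.0, 2.0)),
      (Array(0.0, 2.0), Array(0.0, 1.0)),
      (Array(2.0), Array(2.0))), 2)

val metrics = new MultilabelMetrics(scoreAndLabels)
println(s"Precision = ${metrics.precision}")
println(s"Recall = ${metrics.recall}")
println(s"F1 measure = ${metrics.f1Measure}")
println(s"Hamming loss = ${metrics.hammingLoss}")
{% endhighlight %}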
    +Refer to the [`MultilabelMetrics` Java docs](api/java/org/apache/spark/mllib/evaluation/MultilabelMetrics.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.rdd.RDD; -import org.apache.spark.mllib.evaluation.MultilabelMetrics; -import org.apache.spark.SparkConf; -import java.util.Arrays; -import java.util.List; - -public class MultilabelClassification { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("Multilabel Classification Metrics"); - JavaSparkContext sc = new JavaSparkContext(conf); - - List> data = Arrays.asList( - new Tuple2(new double[]{0.0, 1.0}, new double[]{0.0, 2.0}), - new Tuple2(new double[]{0.0, 2.0}, new double[]{0.0, 1.0}), - new Tuple2(new double[]{}, new double[]{0.0}), - new Tuple2(new double[]{2.0}, new double[]{2.0}), - new Tuple2(new double[]{2.0, 0.0}, new double[]{2.0, 0.0}), - new Tuple2(new double[]{0.0, 1.0, 2.0}, new double[]{0.0, 1.0}), - new Tuple2(new double[]{1.0}, new double[]{1.0, 2.0}) - ); - JavaRDD> scoreAndLabels = sc.parallelize(data); - - // Instantiate metrics object - MultilabelMetrics metrics = new MultilabelMetrics(scoreAndLabels.rdd()); - - // Summary stats - System.out.format("Recall = %f\n", metrics.recall()); - System.out.format("Precision = %f\n", metrics.precision()); - System.out.format("F1 measure = %f\n", metrics.f1Measure()); - System.out.format("Accuracy = %f\n", metrics.accuracy()); - - // Stats by labels - for (int i = 0; i < metrics.labels().length - 1; i++) { - System.out.format("Class %1.1f precision = %f\n", metrics.labels()[i], metrics.precision(metrics.labels()[i])); - System.out.format("Class %1.1f recall = %f\n", metrics.labels()[i], metrics.recall(metrics.labels()[i])); - System.out.format("Class %1.1f F1 score = %f\n", metrics.labels()[i], metrics.f1Measure(metrics.labels()[i])); - } - - // Micro stats - System.out.format("Micro recall = %f\n", metrics.microRecall()); - System.out.format("Micro precision = %f\n", metrics.microPrecision()); - System.out.format("Micro F1 measure = %f\n", metrics.microF1Measure()); - - // Hamming loss - System.out.format("Hamming loss = %f\n", metrics.hammingLoss()); - - // Subset accuracy - System.out.format("Subset accuracy = %f\n", metrics.subsetAccuracy()); - - } -} - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java %}
    +Refer to the [`MultilabelMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.MultilabelMetrics) for more details on the API. -{% highlight python %} -from pyspark.mllib.evaluation import MultilabelMetrics - -scoreAndLabels = sc.parallelize([ - ([0.0, 1.0], [0.0, 2.0]), - ([0.0, 2.0], [0.0, 1.0]), - ([], [0.0]), - ([2.0], [2.0]), - ([2.0, 0.0], [2.0, 0.0]), - ([0.0, 1.0, 2.0], [0.0, 1.0]), - ([1.0], [1.0, 2.0])]) - -# Instantiate metrics object -metrics = MultilabelMetrics(scoreAndLabels) - -# Summary stats -print("Recall = %s" % metrics.recall()) -print("Precision = %s" % metrics.precision()) -print("F1 measure = %s" % metrics.f1Measure()) -print("Accuracy = %s" % metrics.accuracy) - -# Individual label stats -labels = scoreAndLabels.flatMap(lambda x: x[1]).distinct().collect() -for label in labels: - print("Class %s precision = %s" % (label, metrics.precision(label))) - print("Class %s recall = %s" % (label, metrics.recall(label))) - print("Class %s F1 Measure = %s" % (label, metrics.f1Measure(label))) - -# Micro stats -print("Micro precision = %s" % metrics.microPrecision) -print("Micro recall = %s" % metrics.microRecall) -print("Micro F1 measure = %s" % metrics.microF1Measure) - -# Hamming loss -print("Hamming loss = %s" % metrics.hammingLoss) - -# Subset accuracy -print("Subset accuracy = %s" % metrics.subsetAccuracy) - -{% endhighlight %} +{% include_example python/mllib/multi_label_metrics_example.py %}
    @@ -1016,279 +516,23 @@ expanded world of non-positive weights are "the same as never having interacted
    +Refer to the [`RegressionMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.RegressionMetrics) and [`RankingMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.RankingMetrics) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.evaluation.{RegressionMetrics, RankingMetrics} -import org.apache.spark.mllib.recommendation.{ALS, Rating} - -// Read in the ratings data -val ratings = sc.textFile("data/mllib/sample_movielens_data.txt").map { line => - val fields = line.split("::") - Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble - 2.5) -}.cache() - -// Map ratings to 1 or 0, 1 indicating a movie that should be recommended -val binarizedRatings = ratings.map(r => Rating(r.user, r.product, if (r.rating > 0) 1.0 else 0.0)).cache() - -// Summarize ratings -val numRatings = ratings.count() -val numUsers = ratings.map(_.user).distinct().count() -val numMovies = ratings.map(_.product).distinct().count() -println(s"Got $numRatings ratings from $numUsers users on $numMovies movies.") - -// Build the model -val numIterations = 10 -val rank = 10 -val lambda = 0.01 -val model = ALS.train(ratings, rank, numIterations, lambda) - -// Define a function to scale ratings from 0 to 1 -def scaledRating(r: Rating): Rating = { - val scaledRating = math.max(math.min(r.rating, 1.0), 0.0) - Rating(r.user, r.product, scaledRating) -} - -// Get sorted top ten predictions for each user and then scale from [0, 1] -val userRecommended = model.recommendProductsForUsers(10).map{ case (user, recs) => - (user, recs.map(scaledRating)) -} - -// Assume that any movie a user rated 3 or higher (which maps to a 1) is a relevant document -// Compare with top ten most relevant documents -val userMovies = binarizedRatings.groupBy(_.user) -val relevantDocuments = userMovies.join(userRecommended).map{ case (user, (actual, predictions)) => - (predictions.map(_.product), actual.filter(_.rating > 0.0).map(_.product).toArray) -} - -// Instantiate metrics object -val metrics = new RankingMetrics(relevantDocuments) - -// Precision at K -Array(1, 3, 5).foreach{ k => - println(s"Precision at $k = ${metrics.precisionAt(k)}") -} - -// Mean average precision -println(s"Mean average precision = ${metrics.meanAveragePrecision}") - -// Normalized discounted cumulative gain -Array(1, 3, 5).foreach{ k => - println(s"NDCG at $k = ${metrics.ndcgAt(k)}") -} - -// Get predictions for each data point -val allPredictions = model.predict(ratings.map(r => (r.user, r.product))).map(r => ((r.user, r.product), r.rating)) -val allRatings = ratings.map(r => ((r.user, r.product), r.rating)) -val predictionsAndLabels = allPredictions.join(allRatings).map{ case ((user, product), (predicted, actual)) => - (predicted, actual) -} - -// Get the RMSE using regression metrics -val regressionMetrics = new RegressionMetrics(predictionsAndLabels) -println(s"RMSE = ${regressionMetrics.rootMeanSquaredError}") - -// R-squared -println(s"R-squared = ${regressionMetrics.r2}") - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala %}
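A simplified sketch of the approach is shown below (assuming an existing `SparkContext` `sc`). It skips the rating re-scaling that the full example performs and simply treats movies rated above the centered midpoint as relevant; see the example file above for the complete version.

{% highlight scala %}
import org.apache.spark.mllib.evaluation.{RankingMetrics, RegressionMetrics}
import org.apache.spark.mllib.recommendation.{ALS, Rating}

// Parse MovieLens-style ratings and center them around zero.
val ratings = sc.textFile("data/mllib/sample_movielens_data.txt").map { line =>
  val fields = line.split("::")
  Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble - 2.5)
}.cache()

val model = ALS.train(ratings, 10, 10, 0.01)

// Regression-style evaluation: compare predicted ratings with observed ones.
val observed = ratings.map(r => ((r.user, r.product), r.rating))
val predicted = model.predict(ratings.map(r => (r.user, r.product)))
  .map(r => ((r.user, r.product), r.rating))
val regressionMetrics = new RegressionMetrics(predicted.join(observed).values)
println(s"RMSE = ${regressionMetrics.rootMeanSquaredError}")
println(s"R-squared = ${regressionMetrics.r2}")

// Ranking-style evaluation: relevant movies per user vs. the top 10 recommendations.
val relevant = ratings.filter(_.rating > 0.0).groupBy(_.user)
  .mapValues(_.map(_.product).toArray)
val recommended = model.recommendProductsForUsers(10)
  .mapValues(_.map(_.product))
val rankingMetrics = new RankingMetrics(recommended.join(relevant).values)
println(s"Precision at 5 = ${rankingMetrics.precisionAt(5)}")
println(s"Mean average precision = ${rankingMetrics.meanAveragePrecision}")
{% endhighlight %}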
    +Refer to the [`RegressionMetrics` Java docs](api/java/org/apache/spark/mllib/evaluation/RegressionMetrics.html) and [`RankingMetrics` Java docs](api/java/org/apache/spark/mllib/evaluation/RankingMetrics.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.rdd.RDD; -import org.apache.spark.mllib.recommendation.MatrixFactorizationModel; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.function.Function; -import java.util.*; -import org.apache.spark.mllib.evaluation.RegressionMetrics; -import org.apache.spark.mllib.evaluation.RankingMetrics; -import org.apache.spark.mllib.recommendation.ALS; -import org.apache.spark.mllib.recommendation.Rating; - -// Read in the ratings data -public class Ranking { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("Ranking Metrics"); - JavaSparkContext sc = new JavaSparkContext(conf); - String path = "data/mllib/sample_movielens_data.txt"; - JavaRDD data = sc.textFile(path); - JavaRDD ratings = data.map( - new Function() { - public Rating call(String line) { - String[] parts = line.split("::"); - return new Rating(Integer.parseInt(parts[0]), Integer.parseInt(parts[1]), Double.parseDouble(parts[2]) - 2.5); - } - } - ); - ratings.cache(); - - // Train an ALS model - final MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), 10, 10, 0.01); - - // Get top 10 recommendations for every user and scale ratings from 0 to 1 - JavaRDD> userRecs = model.recommendProductsForUsers(10).toJavaRDD(); - JavaRDD> userRecsScaled = userRecs.map( - new Function, Tuple2>() { - public Tuple2 call(Tuple2 t) { - Rating[] scaledRatings = new Rating[t._2().length]; - for (int i = 0; i < scaledRatings.length; i++) { - double newRating = Math.max(Math.min(t._2()[i].rating(), 1.0), 0.0); - scaledRatings[i] = new Rating(t._2()[i].user(), t._2()[i].product(), newRating); - } - return new Tuple2(t._1(), scaledRatings); - } - } - ); - JavaPairRDD userRecommended = JavaPairRDD.fromJavaRDD(userRecsScaled); - - // Map ratings to 1 or 0, 1 indicating a movie that should be recommended - JavaRDD binarizedRatings = ratings.map( - new Function() { - public Rating call(Rating r) { - double binaryRating; - if (r.rating() > 0.0) { - binaryRating = 1.0; - } - else { - binaryRating = 0.0; - } - return new Rating(r.user(), r.product(), binaryRating); - } - } - ); - - // Group ratings by common user - JavaPairRDD> userMovies = binarizedRatings.groupBy( - new Function() { - public Object call(Rating r) { - return r.user(); - } - } - ); - - // Get true relevant documents from all user ratings - JavaPairRDD> userMoviesList = userMovies.mapValues( - new Function, List>() { - public List call(Iterable docs) { - List products = new ArrayList(); - for (Rating r : docs) { - if (r.rating() > 0.0) { - products.add(r.product()); - } - } - return products; - } - } - ); - - // Extract the product id from each recommendation - JavaPairRDD> userRecommendedList = userRecommended.mapValues( - new Function>() { - public List call(Rating[] docs) { - List products = new ArrayList(); - for (Rating r : docs) { - products.add(r.product()); - } - return products; - } - } - ); - JavaRDD, List>> relevantDocs = userMoviesList.join(userRecommendedList).values(); - - // Instantiate the metrics object - RankingMetrics metrics = RankingMetrics.of(relevantDocs); - - // Precision and NDCG at k - Integer[] kVector = {1, 3, 5}; - for (Integer k : kVector) { - 
System.out.format("Precision at %d = %f\n", k, metrics.precisionAt(k)); - System.out.format("NDCG at %d = %f\n", k, metrics.ndcgAt(k)); - } - - // Mean average precision - System.out.format("Mean average precision = %f\n", metrics.meanAveragePrecision()); - - // Evaluate the model using numerical ratings and regression metrics - JavaRDD> userProducts = ratings.map( - new Function>() { - public Tuple2 call(Rating r) { - return new Tuple2(r.user(), r.product()); - } - } - ); - JavaPairRDD, Object> predictions = JavaPairRDD.fromJavaRDD( - model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map( - new Function, Object>>() { - public Tuple2, Object> call(Rating r){ - return new Tuple2, Object>( - new Tuple2(r.user(), r.product()), r.rating()); - } - } - )); - JavaRDD> ratesAndPreds = - JavaPairRDD.fromJavaRDD(ratings.map( - new Function, Object>>() { - public Tuple2, Object> call(Rating r){ - return new Tuple2, Object>( - new Tuple2(r.user(), r.product()), r.rating()); - } - } - )).join(predictions).values(); - - // Create regression metrics object - RegressionMetrics regressionMetrics = new RegressionMetrics(ratesAndPreds.rdd()); - - // Root mean squared error - System.out.format("RMSE = %f\n", regressionMetrics.rootMeanSquaredError()); - - // R-squared - System.out.format("R-squared = %f\n", regressionMetrics.r2()); - } -} - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java %}
    +Refer to the [`RegressionMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.RegressionMetrics) and [`RankingMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.RankingMetrics) for more details on the API. -{% highlight python %} -from pyspark.mllib.recommendation import ALS, Rating -from pyspark.mllib.evaluation import RegressionMetrics, RankingMetrics - -# Read in the ratings data -lines = sc.textFile("data/mllib/sample_movielens_data.txt") - -def parseLine(line): - fields = line.split("::") - return Rating(int(fields[0]), int(fields[1]), float(fields[2]) - 2.5) -ratings = lines.map(lambda r: parseLine(r)) - -# Train a model on to predict user-product ratings -model = ALS.train(ratings, 10, 10, 0.01) - -# Get predicted ratings on all existing user-product pairs -testData = ratings.map(lambda p: (p.user, p.product)) -predictions = model.predictAll(testData).map(lambda r: ((r.user, r.product), r.rating)) - -ratingsTuple = ratings.map(lambda r: ((r.user, r.product), r.rating)) -scoreAndLabels = predictions.join(ratingsTuple).map(lambda tup: tup[1]) - -# Instantiate regression metrics to compare predicted and actual ratings -metrics = RegressionMetrics(scoreAndLabels) - -# Root mean sqaured error -print("RMSE = %s" % metrics.rootMeanSquaredError) - -# R-squared -print("R-squared = %s" % metrics.r2) - -{% endhighlight %} +{% include_example python/mllib/ranking_metrics_example.py %}
    @@ -1336,162 +580,23 @@ The following code snippets illustrate how to load a sample dataset, train a lin and evaluate the performance of the algorithm by several regression metrics.
    +Refer to the [`RegressionMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.RegressionMetrics) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.regression.LinearRegressionModel -import org.apache.spark.mllib.regression.LinearRegressionWithSGD -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.evaluation.RegressionMetrics -import org.apache.spark.mllib.util.MLUtils - -// Load the data -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_linear_regression_data.txt").cache() - -// Build the model -val numIterations = 100 -val model = LinearRegressionWithSGD.train(data, numIterations) - -// Get predictions -val valuesAndPreds = data.map{ point => - val prediction = model.predict(point.features) - (prediction, point.label) -} - -// Instantiate metrics object -val metrics = new RegressionMetrics(valuesAndPreds) - -// Squared error -println(s"MSE = ${metrics.meanSquaredError}") -println(s"RMSE = ${metrics.rootMeanSquaredError}") - -// R-squared -println(s"R-squared = ${metrics.r2}") - -// Mean absolute error -println(s"MAE = ${metrics.meanAbsoluteError}") - -// Explained variance -println(s"Explained variance = ${metrics.explainedVariance}") - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala %}
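In outline (a condensed sketch of the replaced snippet, assuming an existing `SparkContext` `sc`):

{% highlight scala %}
import org.apache.spark.mllib.evaluation.RegressionMetrics
import org.apache.spark.mllib.regression.LinearRegressionWithSGD
import org.apache.spark.mllib.util.MLUtils

// Load the data and fit a simple linear model.
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_linear_regression_data.txt").cache()
val model = LinearRegressionWithSGD.train(data, 100)

// Score the data and pair each prediction with the observed label.
val predictionAndObservation = data.map(point => (model.predict(point.features), point.label))

val metrics = new RegressionMetrics(predictionAndObservation)
println(s"MSE = ${metrics.meanSquaredError}")
println(s"RMSE = ${metrics.rootMeanSquaredError}")
println(s"R-squared = ${metrics.r2}")
println(s"MAE = ${metrics.meanAbsoluteError}")
println(s"Explained variance = ${metrics.explainedVariance}")
{% endhighlight %}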
    +Refer to the [`RegressionMetrics` Java docs](api/java/org/apache/spark/mllib/evaluation/RegressionMetrics.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.regression.LinearRegressionModel; -import org.apache.spark.mllib.regression.LinearRegressionWithSGD; -import org.apache.spark.mllib.evaluation.RegressionMetrics; -import org.apache.spark.SparkConf; - -public class LinearRegression { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("Linear Regression Example"); - JavaSparkContext sc = new JavaSparkContext(conf); - - // Load and parse the data - String path = "data/mllib/sample_linear_regression_data.txt"; - JavaRDD data = sc.textFile(path); - JavaRDD parsedData = data.map( - new Function() { - public LabeledPoint call(String line) { - String[] parts = line.split(" "); - double[] v = new double[parts.length - 1]; - for (int i = 1; i < parts.length - 1; i++) - v[i - 1] = Double.parseDouble(parts[i].split(":")[1]); - return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v)); - } - } - ); - parsedData.cache(); - - // Building the model - int numIterations = 100; - final LinearRegressionModel model = - LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations); - - // Evaluate model on training examples and compute training error - JavaRDD> valuesAndPreds = parsedData.map( - new Function>() { - public Tuple2 call(LabeledPoint point) { - double prediction = model.predict(point.features()); - return new Tuple2(prediction, point.label()); - } - } - ); - - // Instantiate metrics object - RegressionMetrics metrics = new RegressionMetrics(valuesAndPreds.rdd()); - - // Squared error - System.out.format("MSE = %f\n", metrics.meanSquaredError()); - System.out.format("RMSE = %f\n", metrics.rootMeanSquaredError()); - - // R-squared - System.out.format("R Squared = %f\n", metrics.r2()); - - // Mean absolute error - System.out.format("MAE = %f\n", metrics.meanAbsoluteError()); - - // Explained variance - System.out.format("Explained Variance = %f\n", metrics.explainedVariance()); - - // Save and load model - model.save(sc.sc(), "myModelPath"); - LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(), "myModelPath"); - } -} - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java %}
    +Refer to the [`RegressionMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.RegressionMetrics) for more details on the API. -{% highlight python %} -from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD -from pyspark.mllib.evaluation import RegressionMetrics -from pyspark.mllib.linalg import DenseVector - -# Load and parse the data -def parsePoint(line): - values = line.split() - return LabeledPoint(float(values[0]), DenseVector([float(x.split(':')[1]) for x in values[1:]])) - -data = sc.textFile("data/mllib/sample_linear_regression_data.txt") -parsedData = data.map(parsePoint) - -# Build the model -model = LinearRegressionWithSGD.train(parsedData) - -# Get predictions -valuesAndPreds = parsedData.map(lambda p: (float(model.predict(p.features)), p.label)) - -# Instantiate metrics object -metrics = RegressionMetrics(valuesAndPreds) - -# Squared Error -print("MSE = %s" % metrics.meanSquaredError) -print("RMSE = %s" % metrics.rootMeanSquaredError) - -# R-squared -print("R-squared = %s" % metrics.r2) - -# Mean absolute error -print("MAE = %s" % metrics.meanAbsoluteError) - -# Explained variance -print("Explained variance = %s" % metrics.explainedVariance) - -{% endhighlight %} +{% include_example python/mllib/regression_metrics_example.py %}
    -
    \ No newline at end of file +
    diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index de86aba2ae62..7796bac69756 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -1,7 +1,7 @@ --- layout: global -title: Feature Extraction and Transformation - MLlib -displayTitle: MLlib - Feature Extraction and Transformation +title: Feature Extraction and Transformation - spark.mllib +displayTitle: Feature Extraction and Transformation - spark.mllib --- * Table of contents @@ -31,7 +31,7 @@ The TF-IDF measure is simply the product of TF and IDF: TFIDF(t, d, D) = TF(t, d) \cdot IDF(t, D). \]` There are several variants on the definition of term frequency and document frequency. -In MLlib, we separate TF and IDF to make them flexible. +In `spark.mllib`, we separate TF and IDF to make them flexible. Our implementation of term frequency utilizes the [hashing trick](http://en.wikipedia.org/wiki/Feature_hashing). @@ -44,7 +44,7 @@ To reduce the chance of collision, we can increase the target feature dimension, the number of buckets of the hash table. The default feature dimension is `$2^{20} = 1,048,576$`. -**Note:** MLlib doesn't provide tools for text segmentation. +**Note:** `spark.mllib` doesn't provide tools for text segmentation. We refer users to the [Stanford NLP Group](http://nlp.stanford.edu/) and [scalanlp/chalk](https://github.com/scalanlp/chalk). @@ -56,6 +56,9 @@ and [IDF](api/scala/index.html#org.apache.spark.mllib.feature.IDF). `HashingTF` takes an `RDD[Iterable[_]]` as the input. Each record could be an iterable of strings or other types. +Refer to the [`HashingTF` Scala docs](api/scala/index.html#org.apache.spark.mllib.feature.HashingTF) for details on the API. + + {% highlight scala %} import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext @@ -83,7 +86,7 @@ val idf = new IDF().fit(tf) val tfidf: RDD[Vector] = idf.transform(tf) {% endhighlight %} -MLlib's IDF implementation provides an option for ignoring terms which occur in less than a +`spark.mllib`'s IDF implementation provides an option for ignoring terms which occur in less than a minimum number of documents. In such cases, the IDF for these terms is set to 0. This feature can be used by passing the `minDocFreq` value to the IDF constructor. @@ -103,6 +106,9 @@ and [IDF](api/python/pyspark.mllib.html#pyspark.mllib.feature.IDF). `HashingTF` takes an RDD of list as the input. Each record could be an iterable of strings or other types. + +Refer to the [`HashingTF` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.feature.HashingTF) for details on the API. + {% highlight python %} from pyspark import SparkContext from pyspark.mllib.feature import HashingTF @@ -128,7 +134,7 @@ idf = IDF().fit(tf) tfidf = idf.transform(tf) {% endhighlight %} -MLLib's IDF implementation provides an option for ignoring terms which occur in less than a +`spark.mllib`'s IDF implementation provides an option for ignoring terms which occur in less than a minimum number of documents. In such cases, the IDF for these terms is set to 0. This feature can be used by passing the `minDocFreq` value to the IDF constructor. @@ -183,7 +189,9 @@ the [text8](http://mattmahoney.net/dc/text8.zip) data and extract it to your pre Here we assume the extracted file is `text8` and in same directory as you run the spark shell.
    -
    +
    +Refer to the [`Word2Vec` Scala docs](api/scala/index.html#org.apache.spark.mllib.feature.Word2Vec) for details on the API. + {% highlight scala %} import org.apache.spark._ import org.apache.spark.rdd._ @@ -207,7 +215,9 @@ model.save(sc, "myModelPath") val sameModel = Word2VecModel.load(sc, "myModelPath") {% endhighlight %}
    -
    +
    +Refer to the [`Word2Vec` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.feature.Word2Vec) for more details on the API. + {% highlight python %} from pyspark import SparkContext from pyspark.mllib.feature import Word2Vec @@ -264,7 +274,9 @@ The example below demonstrates how to load a dataset in libsvm format, and stand so that the new features have unit standard deviation and/or zero mean.
    -
    +
    +Refer to the [`StandardScaler` Scala docs](api/scala/index.html#org.apache.spark.mllib.feature.StandardScaler) for details on the API. + {% highlight scala %} import org.apache.spark.SparkContext._ import org.apache.spark.mllib.feature.StandardScaler @@ -288,7 +300,9 @@ val data2 = data.map(x => (x.label, scaler2.transform(Vectors.dense(x.features.t {% endhighlight %}
    -
    +
    +Refer to the [`StandardScaler` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.feature.StandardScaler) for more details on the API. + {% highlight python %} from pyspark.mllib.util import MLUtils from pyspark.mllib.linalg import Vectors @@ -338,7 +352,9 @@ The example below demonstrates how to load a dataset in libsvm format, and norma with $L^2$ norm, and $L^\infty$ norm.
    -
    +
    +Refer to the [`Normalizer` Scala docs](api/scala/index.html#org.apache.spark.mllib.feature.Normalizer) for details on the API. + {% highlight scala %} import org.apache.spark.SparkContext._ import org.apache.spark.mllib.feature.Normalizer @@ -358,7 +374,9 @@ val data2 = data.map(x => (x.label, normalizer2.transform(x.features))) {% endhighlight %}
    -
    +
    +Refer to the [`Normalizer` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.feature.Normalizer) for more details on the API. + {% highlight python %} from pyspark.mllib.util import MLUtils from pyspark.mllib.linalg import Vectors @@ -380,35 +398,43 @@ data2 = labels.zip(normalizer2.transform(features))
    -## Feature selection -[Feature selection](http://en.wikipedia.org/wiki/Feature_selection) allows selecting the most relevant features for use in model construction. Feature selection reduces the size of the vector space and, in turn, the complexity of any subsequent operation with vectors. The number of features to select can be tuned using a held-out validation set. +## ChiSqSelector -### ChiSqSelector -[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) stands for Chi-Squared feature selection. It operates on labeled data with categorical features. `ChiSqSelector` orders features based on a Chi-Squared test of independence from the class, and then filters (selects) the top features which the class label depends on the most. This is akin to yielding the features with the most predictive power. +[Feature selection](http://en.wikipedia.org/wiki/Feature_selection) tries to identify relevant +features for use in model construction. It reduces the size of the feature space, which can improve +both speed and statistical learning behavior. -#### Model Fitting +[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) implements +Chi-Squared feature selection. It operates on labeled data with categorical features. +`ChiSqSelector` orders features based on a Chi-Squared test of independence from the class, +and then filters (selects) the top features which the class label depends on the most. +This is akin to yielding the features with the most predictive power. -[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) has the -following parameters in the constructor: +The number of features to select can be tuned using a held-out validation set. -* `numTopFeatures` number of top features that the selector will select (filter). +### Model Fitting -We provide a [`fit`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) method in -`ChiSqSelector` which can take an input of `RDD[LabeledPoint]` with categorical features, learn the summary statistics, and then -return a `ChiSqSelectorModel` which can transform an input dataset into the reduced feature space. +`ChiSqSelector` takes a `numTopFeatures` parameter specifying the number of top features that +the selector will select. -This model implements [`VectorTransformer`](api/scala/index.html#org.apache.spark.mllib.feature.VectorTransformer) -which can apply the Chi-Squared feature selection on a `Vector` to produce a reduced `Vector` or on +The [`fit`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) method takes +an input of `RDD[LabeledPoint]` with categorical features, learns the summary statistics, and then +returns a `ChiSqSelectorModel` which can transform an input dataset into the reduced feature space. +The `ChiSqSelectorModel` can be applied either to a `Vector` to produce a reduced `Vector`, or to an `RDD[Vector]` to produce a reduced `RDD[Vector]`. Note that the user can also construct a `ChiSqSelectorModel` by hand by providing an array of selected feature indices (which must be sorted in ascending order). -#### Example +### Example The following example shows the basic use of ChiSqSelector. The data set used has a feature matrix consisting of greyscale values that vary from 0 to 255 for each feature.
    -
    +
    + +Refer to the [`ChiSqSelector` Scala docs](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) +for details on the API. + {% highlight scala %} import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.Vectors @@ -434,7 +460,11 @@ val filteredData = discretizedData.map { lp => {% endhighlight %}
    -
    +
+ +Refer to the [`ChiSqSelector` Java docs](api/java/org/apache/spark/mllib/feature/ChiSqSelector.html) +for details on the API. + {% highlight java %} import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; @@ -486,7 +516,12 @@ sc.stop(); ## ElementwiseProduct -ElementwiseProduct multiplies each input vector by a provided "weight" vector, using element-wise multiplication. In other words, it scales each column of the dataset by a scalar multiplier. This represents the [Hadamard product](https://en.wikipedia.org/wiki/Hadamard_product_%28matrices%29) between the input vector, `v` and transforming vector, `w`, to yield a result vector. +`ElementwiseProduct` multiplies each input vector by a provided "weight" vector, using element-wise +multiplication. In other words, it scales each column of the dataset by a scalar multiplier. This +represents the [Hadamard product](https://en.wikipedia.org/wiki/Hadamard_product_%28matrices%29) +between the input vector, `v` and transforming vector, `scalingVec`, to yield a result vector. +Denoting the `scalingVec` as "`w`," this transformation may be written as: `\[ \begin{pmatrix} v_1 \\ @@ -506,7 +541,7 @@ v_N [`ElementwiseProduct`](api/scala/index.html#org.apache.spark.mllib.feature.ElementwiseProduct) has the following parameter in the constructor: -* `w`: the transforming vector. +* `scalingVec`: the transforming vector. `ElementwiseProduct` implements [`VectorTransformer`](api/scala/index.html#org.apache.spark.mllib.feature.VectorTransformer) which can apply the weighting on a `Vector` to produce a transformed `Vector` or on an `RDD[Vector]` to produce a transformed `RDD[Vector]`. @@ -515,7 +550,10 @@ v_N This example below demonstrates how to transform vectors using a transforming vector value.
    -
    +
    + +Refer to the [`ElementwiseProduct` Scala docs](api/scala/index.html#org.apache.spark.mllib.feature.ElementwiseProduct) for details on the API. + {% highlight scala %} import org.apache.spark.SparkContext._ import org.apache.spark.mllib.feature.ElementwiseProduct @@ -534,7 +572,9 @@ val transformedData2 = data.map(x => transformer.transform(x)) {% endhighlight %}
    -
    +
    +Refer to the [`ElementwiseProduct` Java docs](api/java/org/apache/spark/mllib/feature/ElementwiseProduct.html) for details on the API. + {% highlight java %} import java.util.Arrays; import org.apache.spark.api.java.JavaRDD; @@ -563,7 +603,9 @@ JavaRDD transformedData2 = data.map( {% endhighlight %}
    -
    +
    +Refer to the [`ElementwiseProduct` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.feature.ElementwiseProduct) for more details on the API. + {% highlight python %} from pyspark import SparkContext from pyspark.mllib.linalg import Vectors @@ -600,7 +642,9 @@ and use them to project the vectors into a low-dimensional space while keeping a for calculation a [Linear Regression]((mllib-linear-methods.html))
    -
    +
    +Refer to the [`PCA` Scala docs](api/scala/index.html#org.apache.spark.mllib.feature.PCA) for details on the API. + {% highlight scala %} import org.apache.spark.mllib.regression.LinearRegressionWithSGD import org.apache.spark.mllib.regression.LabeledPoint diff --git a/docs/mllib-frequent-pattern-mining.md b/docs/mllib-frequent-pattern-mining.md index bcc066a18552..2c8a8f236163 100644 --- a/docs/mllib-frequent-pattern-mining.md +++ b/docs/mllib-frequent-pattern-mining.md @@ -1,7 +1,7 @@ --- layout: global -title: Frequent Pattern Mining - MLlib -displayTitle: MLlib - Frequent Pattern Mining +title: Frequent Pattern Mining - spark.mllib +displayTitle: Frequent Pattern Mining - spark.mllib --- Mining frequent items, itemsets, subsequences, or other substructures is usually among the @@ -9,7 +9,7 @@ first steps to analyze a large-scale dataset, which has been an active research data mining for years. We refer users to Wikipedia's [association rule learning](http://en.wikipedia.org/wiki/Association_rule_learning) for more information. -MLlib provides a parallel implementation of FP-growth, +`spark.mllib` provides a parallel implementation of FP-growth, a popular algorithm to mining frequent itemsets. ## FP-growth @@ -22,13 +22,13 @@ Different from [Apriori-like](http://en.wikipedia.org/wiki/Apriori_algorithm) al the second step of FP-growth uses a suffix tree (FP-tree) structure to encode transactions without generating candidate sets explicitly, which are usually expensive to generate. After the second step, the frequent itemsets can be extracted from the FP-tree. -In MLlib, we implemented a parallel version of FP-growth called PFP, +In `spark.mllib`, we implemented a parallel version of FP-growth called PFP, as described in [Li et al., PFP: Parallel FP-growth for query recommendation](http://dx.doi.org/10.1145/1454008.1454027). PFP distributes the work of growing FP-trees based on the suffices of transactions, and hence more scalable than a single-machine implementation. We refer users to the papers for more details. -MLlib's FP-growth implementation takes the following (hyper-)parameters: +`spark.mllib`'s FP-growth implementation takes the following (hyper-)parameters: * `minSupport`: the minimum support for an itemset to be identified as frequent. For example, if an item appears 3 out of 5 transactions, it has a support of 3/5=0.6. @@ -41,26 +41,18 @@ MLlib's FP-growth implementation takes the following (hyper-)parameters: [`FPGrowth`](api/scala/index.html#org.apache.spark.mllib.fpm.FPGrowth) implements the FP-growth algorithm. -It take a `JavaRDD` of transactions, where each transaction is an `Iterable` of items of a generic type. +It take a `RDD` of transactions, where each transaction is an `Array` of items of a generic type. Calling `FPGrowth.run` with transactions returns an [`FPGrowthModel`](api/scala/index.html#org.apache.spark.mllib.fpm.FPGrowthModel) -that stores the frequent itemsets with their frequencies. - -{% highlight scala %} -import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel} - -val transactions: RDD[Array[String]] = ... +that stores the frequent itemsets with their frequencies. The following +example illustrates how to mine frequent itemsets and association rules +(see [Association +Rules](mllib-frequent-pattern-mining.html#association-rules) for +details) from `transactions`. 
-val fpg = new FPGrowth() - .setMinSupport(0.2) - .setNumPartitions(10) -val model = fpg.run(transactions) +Refer to the [`FPGrowth` Scala docs](api/scala/index.html#org.apache.spark.mllib.fpm.FPGrowth) for details on the API. -model.freqItemsets.collect().foreach { itemset => - println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq) -} -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala %}
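For orientation, the pattern is roughly the following. This is a sketch with made-up string transactions (it assumes an existing `SparkContext` `sc`); `generateAssociationRules` is the model method used to derive rules from the mined frequent itemsets, and the example file above is the authoritative version.

{% highlight scala %}
import org.apache.spark.mllib.fpm.FPGrowth
import org.apache.spark.rdd.RDD

// Each transaction is an array of items; plain strings are used here for brevity.
val transactions: RDD[Array[String]] = sc.parallelize(Seq(
  Array("r", "z", "h", "k", "p"),
  Array("z", "y", "x", "w", "v", "u", "t", "s"),
  Array("s", "x", "o", "n", "r"),
  Array("x", "z", "y", "m", "t", "s", "q", "e"),
  Array("z"),
  Array("x", "z", "y", "r", "q", "t", "p")))

val fpg = new FPGrowth()
  .setMinSupport(0.2)
  .setNumPartitions(10)
val model = fpg.run(transactions)

// Frequent itemsets with their frequencies.
model.freqItemsets.collect().foreach { itemset =>
  println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq)
}

// Association rules can be generated directly from the fitted model.
val minConfidence = 0.8
model.generateAssociationRules(minConfidence).collect().foreach { rule =>
  println(rule.antecedent.mkString("[", ",", "]") + " => " +
    rule.consequent.mkString("[", ",", "]") + ", " + rule.confidence)
}
{% endhighlight %}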
@@ -68,31 +60,123 @@ model.freqItemsets.collect().foreach { itemset => [`FPGrowth`](api/java/org/apache/spark/mllib/fpm/FPGrowth.html) implements the FP-growth algorithm. -It take an `RDD` of transactions, where each transaction is an `Array` of items of a generic type. +It takes a `JavaRDD` of transactions, where each transaction is an `Iterable` of items of a generic type. Calling `FPGrowth.run` with transactions returns an [`FPGrowthModel`](api/java/org/apache/spark/mllib/fpm/FPGrowthModel.html) +that stores the frequent itemsets with their frequencies. The following +example illustrates how to mine frequent itemsets and association rules +(see [Association +Rules](mllib-frequent-pattern-mining.html#association-rules) for +details) from `transactions`. + +Refer to the [`FPGrowth` Java docs](api/java/org/apache/spark/mllib/fpm/FPGrowth.html) for details on the API. + +{% include_example java/org/apache/spark/examples/mllib/JavaSimpleFPGrowth.java %} + +
    + +
+ +[`FPGrowth`](api/python/pyspark.mllib.html#pyspark.mllib.fpm.FPGrowth) implements the +FP-growth algorithm. +It takes an `RDD` of transactions, where each transaction is a `List` of items of a generic type. +Calling `FPGrowth.train` with transactions returns an +[`FPGrowthModel`](api/python/pyspark.mllib.html#pyspark.mllib.fpm.FPGrowthModel) that stores the frequent itemsets with their frequencies. -{% highlight java %} -import java.util.List; +Refer to the [`FPGrowth` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.fpm.FPGrowth) for more details on the API. + +{% include_example python/mllib/fpgrowth_example.py %} + +
    + +
    + +## Association Rules + +
    +
    +[AssociationRules](api/scala/index.html#org.apache.spark.mllib.fpm.AssociationRules) +implements a parallel rule generation algorithm for constructing rules +that have a single item as the consequent. + +Refer to the [`AssociationRules` Scala docs](api/scala/index.html#org.apache.spark.mllib.fpm.AssociationRules) for details on the API. -import com.google.common.base.Joiner; +{% include_example scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala %} -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.mllib.fpm.FPGrowth; -import org.apache.spark.mllib.fpm.FPGrowthModel; +
    -JavaRDD> transactions = ... +
    +[AssociationRules](api/java/org/apache/spark/mllib/fpm/AssociationRules.html) +implements a parallel rule generation algorithm for constructing rules +that have a single item as the consequent. -FPGrowth fpg = new FPGrowth() - .setMinSupport(0.2) - .setNumPartitions(10); -FPGrowthModel model = fpg.run(transactions); +Refer to the [`AssociationRules` Java docs](api/java/org/apache/spark/mllib/fpm/AssociationRules.html) for details on the API. -for (FPGrowth.FreqItemset itemset: model.freqItemsets().toJavaRDD().collect()) { - System.out.println("[" + Joiner.on(",").join(s.javaItems()) + "], " + s.freq()); -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaAssociationRulesExample.java %}
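A rough Scala sketch of rule generation, paralleling the referenced examples rather than reproducing them (the itemsets and their frequencies below are made-up illustrative values, and an existing `SparkContext` named `sc` is assumed):

{% highlight scala %}
import org.apache.spark.mllib.fpm.AssociationRules
import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset

// Frequent itemsets with their observed frequencies (illustrative values only).
val freqItemsets = sc.parallelize(Seq(
  new FreqItemset(Array("a"), 15L),
  new FreqItemset(Array("b"), 35L),
  new FreqItemset(Array("a", "b"), 12L)
))

val ar = new AssociationRules()
  .setMinConfidence(0.8)  // only keep rules with confidence >= 0.8
val rules = ar.run(freqItemsets)

// Each generated rule has a single item as the consequent, as described above.
rules.collect().foreach { rule =>
  println(rule.antecedent.mkString(",") + " => " +
    rule.consequent.mkString(",") + ": " + rule.confidence)
}
{% endhighlight %}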
    + +## PrefixSpan + +PrefixSpan is a sequential pattern mining algorithm described in +[Pei et al., Mining Sequential Patterns by Pattern-Growth: The +PrefixSpan Approach](http://dx.doi.org/10.1109%2FTKDE.2004.77). We refer +the reader to the referenced paper for the formal definition of the sequential +pattern mining problem. + +`spark.mllib`'s PrefixSpan implementation takes the following parameters: + +* `minSupport`: the minimum support required to be considered a frequent + sequential pattern. +* `maxPatternLength`: the maximum length of a frequent sequential + pattern. Any frequent pattern exceeding this length will not be + included in the results. +* `maxLocalProjDBSize`: the maximum number of items allowed in a + prefix-projected database before local iterative processing of the + projected database begins. This parameter should be tuned with respect + to the size of your executors. + +**Examples** + +The following example illustrates PrefixSpan running on the sequences +(using the same notation as Pei et al.): + +~~~ + <(12)3> + <1(32)(12)> + <(12)5> + <6> +~~~ + +
    +
    + +[`PrefixSpan`](api/scala/index.html#org.apache.spark.mllib.fpm.PrefixSpan) implements the +PrefixSpan algorithm. +Calling `PrefixSpan.run` returns a +[`PrefixSpanModel`](api/scala/index.html#org.apache.spark.mllib.fpm.PrefixSpanModel) +that stores the frequent sequences with their frequencies. + +Refer to the [`PrefixSpan` Scala docs](api/scala/index.html#org.apache.spark.mllib.fpm.PrefixSpan) and [`PrefixSpanModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.fpm.PrefixSpanModel) for details on the API. + +{% include_example scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala %} + +
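A minimal Scala sketch using the four sequences listed above, encoded as nested arrays (each inner array is one itemset); this parallels the referenced example and assumes an existing `SparkContext` named `sc`:

{% highlight scala %}
import org.apache.spark.mllib.fpm.PrefixSpan

// The four example sequences above: <(12)3>, <1(32)(12)>, <(12)5>, <6>.
val sequences = sc.parallelize(Seq(
  Array(Array(1, 2), Array(3)),
  Array(Array(1), Array(3, 2), Array(1, 2)),
  Array(Array(1, 2), Array(5)),
  Array(Array(6))
), 2).cache()

val prefixSpan = new PrefixSpan()
  .setMinSupport(0.5)      // a pattern must occur in at least half of the sequences
  .setMaxPatternLength(5)  // drop frequent patterns longer than 5 items
val model = prefixSpan.run(sequences)

// Print each frequent sequential pattern with its frequency.
model.freqSequences.collect().foreach { freqSequence =>
  println(
    freqSequence.sequence.map(_.mkString("[", ", ", "]")).mkString("[", ", ", "]") +
      ", " + freqSequence.freq)
}
{% endhighlight %}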
    + +
    + +[`PrefixSpan`](api/java/org/apache/spark/mllib/fpm/PrefixSpan.html) implements the +PrefixSpan algorithm. +Calling `PrefixSpan.run` returns a +[`PrefixSpanModel`](api/java/org/apache/spark/mllib/fpm/PrefixSpanModel.html) +that stores the frequent sequences with their frequencies. + +Refer to the [`PrefixSpan` Java docs](api/java/org/apache/spark/mllib/fpm/PrefixSpan.html) and [`PrefixSpanModel` Java docs](api/java/org/apache/spark/mllib/fpm/PrefixSpanModel.html) for details on the API. + +{% include_example java/org/apache/spark/examples/mllib/JavaPrefixSpanExample.java %} + +
    +
    + diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index eea864eacf7c..7ef91a178ccd 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -5,117 +5,131 @@ displayTitle: Machine Learning Library (MLlib) Guide description: MLlib machine learning library overview for Spark SPARK_VERSION_SHORT --- -MLlib is Spark's scalable machine learning library consisting of common learning algorithms and utilities, -including classification, regression, clustering, collaborative -filtering, dimensionality reduction, as well as underlying optimization primitives. -Guides for individual algorithms are listed below. +MLlib is Spark's machine learning (ML) library. +Its goal is to make practical machine learning scalable and easy. +It consists of common learning algorithms and utilities, including classification, regression, +clustering, collaborative filtering, dimensionality reduction, as well as lower-level optimization +primitives and higher-level pipeline APIs. -The API is divided into 2 parts: +It is divided into two packages: -* [The original `spark.mllib` API](mllib-guide.html#mllib-types-algorithms-and-utilities) is the primary API. -* [The "Pipelines" `spark.ml` API](mllib-guide.html#sparkml-high-level-apis-for-ml-pipelines) is a higher-level API for constructing ML workflows. +* [`spark.mllib`](mllib-guide.html#data-types-algorithms-and-utilities) contains the original API + built on top of [RDDs](programming-guide.html#resilient-distributed-datasets-rdds). +* [`spark.ml`](ml-guide.html) provides a higher-level API + built on top of [DataFrames](sql-programming-guide.html#dataframes) for constructing ML pipelines. -We list major functionality from both below, with links to detailed guides. +Using `spark.ml` is recommended because with DataFrames the API is more versatile and flexible. +But we will keep supporting `spark.mllib` along with the development of `spark.ml`. +Users should be comfortable using `spark.mllib` features and can expect more features to come. +Developers should contribute new algorithms to `spark.ml` if they fit the ML pipeline concept well, +e.g., feature extractors and transformers. -# MLlib types, algorithms and utilities +We list major functionality from both below, with links to detailed guides. -This lists functionality included in `spark.mllib`, the main MLlib API.
+# spark.mllib: data types, algorithms, and utilities * [Data types](mllib-data-types.html) * [Basic statistics](mllib-statistics.html) - * summary statistics - * correlations - * stratified sampling - * hypothesis testing - * random data generation + * [summary statistics](mllib-statistics.html#summary-statistics) + * [correlations](mllib-statistics.html#correlations) + * [stratified sampling](mllib-statistics.html#stratified-sampling) + * [hypothesis testing](mllib-statistics.html#hypothesis-testing) + * [streaming significance testing](mllib-statistics.html#streaming-significance-testing) + * [random data generation](mllib-statistics.html#random-data-generation) * [Classification and regression](mllib-classification-regression.html) * [linear models (SVMs, logistic regression, linear regression)](mllib-linear-methods.html) * [naive Bayes](mllib-naive-bayes.html) * [decision trees](mllib-decision-tree.html) - * [ensembles of trees](mllib-ensembles.html) (Random Forests and Gradient-Boosted Trees) + * [ensembles of trees (Random Forests and Gradient-Boosted Trees)](mllib-ensembles.html) * [isotonic regression](mllib-isotonic-regression.html) * [Collaborative filtering](mllib-collaborative-filtering.html) - * alternating least squares (ALS) + * [alternating least squares (ALS)](mllib-collaborative-filtering.html#collaborative-filtering) * [Clustering](mllib-clustering.html) * [k-means](mllib-clustering.html#k-means) * [Gaussian mixture](mllib-clustering.html#gaussian-mixture) * [power iteration clustering (PIC)](mllib-clustering.html#power-iteration-clustering-pic) * [latent Dirichlet allocation (LDA)](mllib-clustering.html#latent-dirichlet-allocation-lda) + * [bisecting k-means](mllib-clustering.html#bisecting-kmeans) * [streaming k-means](mllib-clustering.html#streaming-k-means) * [Dimensionality reduction](mllib-dimensionality-reduction.html) - * singular value decomposition (SVD) - * principal component analysis (PCA) + * [singular value decomposition (SVD)](mllib-dimensionality-reduction.html#singular-value-decomposition-svd) + * [principal component analysis (PCA)](mllib-dimensionality-reduction.html#principal-component-analysis-pca) * [Feature extraction and transformation](mllib-feature-extraction.html) * [Frequent pattern mining](mllib-frequent-pattern-mining.html) - * FP-growth -* [Evaluation Metrics](mllib-evaluation-metrics.html) -* [Optimization (developer)](mllib-optimization.html) - * stochastic gradient descent - * limited-memory BFGS (L-BFGS) + * [FP-growth](mllib-frequent-pattern-mining.html#fp-growth) + * [association rules](mllib-frequent-pattern-mining.html#association-rules) + * [PrefixSpan](mllib-frequent-pattern-mining.html#prefix-span) +* [Evaluation metrics](mllib-evaluation-metrics.html) * [PMML model export](mllib-pmml-model-export.html) - -MLlib is under active development. -The APIs marked `Experimental`/`DeveloperApi` may change in future releases, -and the migration guide below will explain all changes between releases. +* [Optimization (developer)](mllib-optimization.html) + * [stochastic gradient descent](mllib-optimization.html#stochastic-gradient-descent-sgd) + * [limited-memory BFGS (L-BFGS)](mllib-optimization.html#limited-memory-bfgs-l-bfgs) # spark.ml: high-level APIs for ML pipelines -Spark 1.2 introduced a new package called `spark.ml`, which aims to provide a uniform set of -high-level APIs that help users create and tune practical machine learning pipelines. 
+* [Overview: estimators, transformers and pipelines](ml-guide.html) +* [Extracting, transforming and selecting features](ml-features.html) +* [Classification and regression](ml-classification-regression.html) +* [Clustering](ml-clustering.html) +* [Advanced topics](ml-advanced.html) -*Graduated from Alpha!* The Pipelines API is no longer an alpha component, although many elements of it are still `Experimental` or `DeveloperApi`. +Some techniques are not available yet in `spark.ml`, most notably dimensionality reduction. +Users can seamlessly combine the implementation of these techniques found in `spark.mllib` with the rest of the algorithms found in `spark.ml`. -Note that we will keep supporting and adding features to `spark.mllib` along with the -development of `spark.ml`. -Users should be comfortable using `spark.mllib` features and expect more features coming. -Developers should contribute new algorithms to `spark.mllib` and can optionally contribute -to `spark.ml`. - -More detailed guides for `spark.ml` include: +# Dependencies -* **[spark.ml programming guide](ml-guide.html)**: overview of the Pipelines API and major concepts -* [Feature transformers](ml-features.html): Details on transformers supported in the Pipelines API, including a few not in the lower-level `spark.mllib` API -* [Ensembles](ml-ensembles.html): Details on ensemble learning methods in the Pipelines API +MLlib uses the linear algebra package [Breeze](http://www.scalanlp.org/), which depends on +[netlib-java](https://github.com/fommil/netlib-java) for optimised numerical processing. +If native libraries[^1] are not available at runtime, you will see a warning message and a pure JVM +implementation will be used instead. -# Dependencies +Due to licensing issues with runtime proprietary binaries, we do not include `netlib-java`'s native +proxies by default. +To configure `netlib-java` / Breeze to use system optimised binaries, include +`com.github.fommil.netlib:all:1.1.2` (or build Spark with `-Pnetlib-lgpl`) as a dependency of your +project and read the [netlib-java](https://github.com/fommil/netlib-java) documentation for your +platform's additional installation instructions. -MLlib uses the linear algebra package -[Breeze](http://www.scalanlp.org/), which depends on -[netlib-java](https://github.com/fommil/netlib-java) for optimised -numerical processing. If natives are not available at runtime, you -will see a warning message and a pure JVM implementation will be used -instead. +To use MLlib in Python, you will need [NumPy](http://www.numpy.org) version 1.4 or newer. -To learn more about the benefits and background of system optimised -natives, you may wish to watch Sam Halliday's ScalaX talk on -[High Performance Linear Algebra in Scala](http://fommil.github.io/scalax14/#/)). +[^1]: To learn more about the benefits and background of system optimised natives, you may wish to + watch Sam Halliday's ScalaX talk on [High Performance Linear Algebra in Scala](http://fommil.github.io/scalax14/#/). -Due to licensing issues with runtime proprietary binaries, we do not -include `netlib-java`'s native proxies by default. To configure -`netlib-java` / Breeze to use system optimised binaries, include -`com.github.fommil.netlib:all:1.1.2` (or build Spark with -`-Pnetlib-lgpl`) as a dependency of your project and read the -[netlib-java](https://github.com/fommil/netlib-java) documentation for -your platform's additional installation instructions.
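As an illustration only (not part of the official instructions; consult the netlib-java documentation for the authoritative setup), an sbt build might declare the dependency roughly as follows. Because the `all` module is published as an aggregating POM, sbt's `pomOnly()` qualifier is usually needed:

{% highlight scala %}
// build.sbt (sketch): pull in netlib-java's system-native backends.
// pomOnly() tells sbt the "all" coordinate resolves to a POM, not a jar.
libraryDependencies += "com.github.fommil.netlib" % "all" % "1.1.2" pomOnly()
{% endhighlight %}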
+# Migration guide -To use MLlib in Python, you will need [NumPy](http://www.numpy.org) -version 1.4 or newer. +MLlib is under active development. +The APIs marked `Experimental`/`DeveloperApi` may change in future releases, +and the migration guide below will explain all changes between releases. ---- +## From 1.5 to 1.6 -# Migration Guide +There are no breaking API changes in the `spark.mllib` or `spark.ml` packages, but there are +deprecations and changes of behavior. -For the `spark.ml` package, please see the [spark.ml Migration Guide](ml-guide.html#migration-guide). +Deprecations: -## From 1.3 to 1.4 +* [SPARK-11358](https://issues.apache.org/jira/browse/SPARK-11358): + In `spark.mllib.clustering.KMeans`, the `runs` parameter has been deprecated. +* [SPARK-10592](https://issues.apache.org/jira/browse/SPARK-10592): + In `spark.ml.classification.LogisticRegressionModel` and + `spark.ml.regression.LinearRegressionModel`, the `weights` field has been deprecated in favor of + the new name `coefficients`. This helps disambiguate from instance (row) "weights" given to + algorithms. -In the `spark.mllib` package, there were several breaking changes, but all in `DeveloperApi` or `Experimental` APIs: +Changes of behavior: -* Gradient-Boosted Trees - * *(Breaking change)* The signature of the [`Loss.gradient`](api/scala/index.html#org.apache.spark.mllib.tree.loss.Loss) method was changed. This is only an issues for users who wrote their own losses for GBTs. - * *(Breaking change)* The `apply` and `copy` methods for the case class [`BoostingStrategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.BoostingStrategy) have been changed because of a modification to the case class fields. This could be an issue for users who use `BoostingStrategy` to set GBT parameters. -* *(Breaking change)* The return value of [`LDA.run`](api/scala/index.html#org.apache.spark.mllib.clustering.LDA) has changed. It now returns an abstract class `LDAModel` instead of the concrete class `DistributedLDAModel`. The object of type `LDAModel` can still be cast to the appropriate concrete type, which depends on the optimization algorithm. +* [SPARK-7770](https://issues.apache.org/jira/browse/SPARK-7770): + `spark.mllib.tree.GradientBoostedTrees`: `validationTol` has changed semantics in 1.6. + Previously, it was a threshold for absolute change in error. Now, it resembles the behavior of + `GradientDescent`'s `convergenceTol`: For large errors, it uses relative error (relative to the + previous error); for small errors (`< 0.01`), it uses absolute error. +* [SPARK-11069](https://issues.apache.org/jira/browse/SPARK-11069): + `spark.ml.feature.RegexTokenizer`: Previously, it did not convert strings to lowercase before + tokenizing. Now, it converts to lowercase by default, with an option not to. This matches the + behavior of the simpler `Tokenizer` transformer. -## Previous Spark Versions +## Previous Spark versions Earlier migration guides are archived [on this page](mllib-migration-guides.html). 
+ +--- diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md index 5732bc4c7e79..8ede4407d584 100644 --- a/docs/mllib-isotonic-regression.md +++ b/docs/mllib-isotonic-regression.md @@ -1,7 +1,7 @@ --- layout: global -title: Isotonic regression - MLlib -displayTitle: MLlib - Regression +title: Isotonic regression - spark.mllib +displayTitle: Regression - spark.mllib --- ## Isotonic regression @@ -23,7 +23,7 @@ Essentially isotonic regression is a [monotonic function](http://en.wikipedia.org/wiki/Monotonic_function) best fitting the original data points. -MLlib supports a +`spark.mllib` supports a [pool adjacent violators algorithm](http://doi.org/10.1198/TECH.2010.10111) which uses an approach to [parallelizing isotonic regression](http://doi.org/10.1007/978-3-642-99789-1_10). @@ -59,105 +59,28 @@ i.e. 4710.28,500.00. The data are split to training and testing set. Model is created using the training set and a mean squared error is calculated from the predicted labels and real labels in the test set. -{% highlight scala %} -import org.apache.spark.mllib.regression.{IsotonicRegression, IsotonicRegressionModel} +Refer to the [`IsotonicRegression` Scala docs](api/scala/index.html#org.apache.spark.mllib.regression.IsotonicRegression) and [`IsotonicRegressionModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.regression.IsotonicRegressionModel) for details on the API. -val data = sc.textFile("data/mllib/sample_isotonic_regression_data.txt") - -// Create label, feature, weight tuples from input data with weight set to default value 1.0. -val parsedData = data.map { line => - val parts = line.split(',').map(_.toDouble) - (parts(0), parts(1), 1.0) -} - -// Split data into training (60%) and test (40%) sets. -val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L) -val training = splits(0) -val test = splits(1) - -// Create isotonic regression model from training data. -// Isotonic parameter defaults to true so it is only shown for demonstration -val model = new IsotonicRegression().setIsotonic(true).run(training) - -// Create tuples of predicted and real labels. -val predictionAndLabel = test.map { point => - val predictedLabel = model.predict(point._2) - (predictedLabel, point._1) -} - -// Calculate mean squared error between predicted and real labels. -val meanSquaredError = predictionAndLabel.map{case(p, l) => math.pow((p - l), 2)}.mean() -println("Mean Squared Error = " + meanSquaredError) - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = IsotonicRegressionModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala %}
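A condensed Scala sketch of that workflow, based on the inline snippet this change removes (an existing `SparkContext` named `sc` is assumed):

{% highlight scala %}
import org.apache.spark.mllib.regression.IsotonicRegression

val data = sc.textFile("data/mllib/sample_isotonic_regression_data.txt")

// Create (label, feature, weight) tuples, with the weight defaulting to 1.0.
val parsedData = data.map { line =>
  val parts = line.split(',').map(_.toDouble)
  (parts(0), parts(1), 1.0)
}

// Split data into training (60%) and test (40%) sets.
val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L)
val (training, test) = (splits(0), splits(1))

// The isotonic parameter defaults to true; it is set here only for illustration.
val model = new IsotonicRegression().setIsotonic(true).run(training)

// Pair predicted and actual labels and compute the mean squared error.
val predictionAndLabel = test.map(point => (model.predict(point._2), point._1))
val meanSquaredError = predictionAndLabel.map { case (p, l) => math.pow(p - l, 2) }.mean()
println("Mean Squared Error = " + meanSquaredError)
{% endhighlight %}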
    -
    Data are read from a file where each line has a format label,feature i.e. 4710.28,500.00. The data are split to training and testing set. Model is created using the training set and a mean squared error is calculated from the predicted labels and real labels in the test set. -{% highlight java %} -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaDoubleRDD; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.regression.IsotonicRegressionModel; -import scala.Tuple2; -import scala.Tuple3; - -JavaRDD data = sc.textFile("data/mllib/sample_isotonic_regression_data.txt"); - -// Create label, feature, weight tuples from input data with weight set to default value 1.0. -JavaRDD> parsedData = data.map( - new Function>() { - public Tuple3 call(String line) { - String[] parts = line.split(","); - return new Tuple3<>(new Double(parts[0]), new Double(parts[1]), 1.0); - } - } -); +Refer to the [`IsotonicRegression` Java docs](api/java/org/apache/spark/mllib/regression/IsotonicRegression.html) and [`IsotonicRegressionModel` Java docs](api/java/org/apache/spark/mllib/regression/IsotonicRegressionModel.html) for details on the API. -// Split data into training (60%) and test (40%) sets. -JavaRDD>[] splits = parsedData.randomSplit(new double[] {0.6, 0.4}, 11L); -JavaRDD> training = splits[0]; -JavaRDD> test = splits[1]; - -// Create isotonic regression model from training data. -// Isotonic parameter defaults to true so it is only shown for demonstration -IsotonicRegressionModel model = new IsotonicRegression().setIsotonic(true).run(training); - -// Create tuples of predicted and real labels. -JavaPairRDD predictionAndLabel = test.mapToPair( - new PairFunction, Double, Double>() { - @Override public Tuple2 call(Tuple3 point) { - Double predictedLabel = model.predict(point._2()); - return new Tuple2(predictedLabel, point._1()); - } - } -); - -// Calculate mean squared error between predicted and real labels. -Double meanSquaredError = new JavaDoubleRDD(predictionAndLabel.map( - new Function, Object>() { - @Override public Object call(Tuple2 pl) { - return Math.pow(pl._1() - pl._2(), 2); - } - } -).rdd()).mean(); +{% include_example java/org/apache/spark/examples/mllib/JavaIsotonicRegressionExample.java %} +
    +
    +Data are read from a file where each line is in the label,feature format, +e.g. 4710.28,500.00. The data are split into training and test sets. +A model is created using the training set, and a mean squared error is calculated from the predicted +labels and the real labels in the test set. -System.out.println("Mean Squared Error = " + meanSquaredError); +Refer to the [`IsotonicRegression` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.regression.IsotonicRegression) and [`IsotonicRegressionModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.regression.IsotonicRegressionModel) for more details on the API. -// Save and load model -model.save(sc.sc(), "myModelPath"); -IsotonicRegressionModel sameModel = IsotonicRegressionModel.load(sc.sc(), "myModelPath"); -{% endhighlight %} +{% include_example python/mllib/isotonic_regression_example.py %}
    diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 07655baa414b..20b35612cab9 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -1,7 +1,7 @@ --- layout: global -title: Linear Methods - MLlib -displayTitle: MLlib - Linear Methods +title: Linear Methods - spark.mllib +displayTitle: Linear Methods - spark.mllib --- * Table of contents @@ -41,7 +41,7 @@ the objective function is of the form Here the vectors `$\x_i\in\R^d$` are the training data examples, for `$1\le i\le n$`, and `$y_i\in\R$` are their corresponding labels, which we want to predict. We call the method *linear* if $L(\wv; \x, y)$ can be expressed as a function of $\wv^T x$ and $y$. -Several of MLlib's classification and regression algorithms fall into this category, +Several of `spark.mllib`'s classification and regression algorithms fall into this category, and are discussed here. The objective function `$f$` has two parts: @@ -55,7 +55,7 @@ training error) and minimizing model complexity (i.e., to avoid overfitting). ### Loss functions The following table summarizes the loss functions and their gradients or sub-gradients for the -methods MLlib supports: +methods `spark.mllib` supports: @@ -83,7 +83,7 @@ methods MLlib supports: The purpose of the [regularizer](http://en.wikipedia.org/wiki/Regularization_(mathematics)) is to encourage simple models and avoid overfitting. We support the following -regularizers in MLlib: +regularizers in `spark.mllib`:
    @@ -115,27 +115,30 @@ especially when the number of training examples is small. ### Optimization -Under the hood, linear methods use convex optimization methods to optimize the objective functions. MLlib uses two methods, SGD and L-BFGS, described in the [optimization section](mllib-optimization.html). Currently, most algorithm APIs support Stochastic Gradient Descent (SGD), and a few support L-BFGS. Refer to [this optimization section](mllib-optimization.html#Choosing-an-Optimization-Method) for guidelines on choosing between optimization methods. +Under the hood, linear methods use convex optimization methods to optimize the objective functions. +`spark.mllib` uses two methods, SGD and L-BFGS, described in the [optimization section](mllib-optimization.html). +Currently, most algorithm APIs support Stochastic Gradient Descent (SGD), and a few support L-BFGS. +Refer to [this optimization section](mllib-optimization.html#Choosing-an-Optimization-Method) for guidelines on choosing between optimization methods. ## Classification [Classification](http://en.wikipedia.org/wiki/Statistical_classification) aims to divide items into categories. The most common classification type is -[binary classificaion](http://en.wikipedia.org/wiki/Binary_classification), where there are two +[binary classification](http://en.wikipedia.org/wiki/Binary_classification), where there are two categories, usually named positive and negative. If there are more than two categories, it is called [multiclass classification](http://en.wikipedia.org/wiki/Multiclass_classification). -MLlib supports two linear methods for classification: linear Support Vector Machines (SVMs) +`spark.mllib` supports two linear methods for classification: linear Support Vector Machines (SVMs) and logistic regression. Linear SVMs supports only binary classification, while logistic regression supports both binary and multiclass classification problems. -For both methods, MLlib supports L1 and L2 regularized variants. +For both methods, `spark.mllib` supports L1 and L2 regularized variants. The training data set is represented by an RDD of [LabeledPoint](mllib-data-types.html) in MLlib, where labels are class indices starting from zero: $0, 1, 2, \ldots$. Note that, in the mathematical formulation in this guide, a binary label $y$ is denoted as either $+1$ (positive) or $-1$ (negative), which is convenient for the formulation. -*However*, the negative label is represented by $0$ in MLlib instead of $-1$, to be consistent with +*However*, the negative label is represented by $0$ in `spark.mllib` instead of $-1$, to be consistent with multiclass labeling. ### Linear Support Vector Machines (SVMs) @@ -165,6 +168,8 @@ training algorithm on this training data using a static method in the algorithm object, and make predictions with the resulting model to compute the training error. +Refer to the [`SVMWithSGD` Scala docs](api/scala/index.html#org.apache.spark.mllib.classification.SVMWithSGD) and [`SVMModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.classification.SVMModel) for details on the API. + {% highlight scala %} import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD} import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics @@ -205,7 +210,7 @@ val sameModel = SVMModel.load(sc, "myModelPath") The `SVMWithSGD.train()` method by default performs L2 regularization with the regularization parameter set to 1.0. 
If we want to configure this algorithm, we can customize `SVMWithSGD` further by creating a new object directly and -calling setter methods. All other MLlib algorithms support customization in +calling setter methods. All other `spark.mllib` algorithms support customization in this way as well. For example, the following code produces an L1 regularized variant of SVMs with regularization parameter set to 0.1, and runs the training algorithm for 200 iterations. @@ -228,7 +233,9 @@ All of MLlib's methods use Java-friendly types, so you can import and call them way you do in Scala. The only caveat is that the methods take Scala RDD objects, while the Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a Scala one by calling `.rdd()` on your `JavaRDD` object. A self-contained application example -that is equivalent to the provided example in Scala is given bellow: +that is equivalent to the provided example in Scala is given below: + +Refer to the [`SVMWithSGD` Java docs](api/java/org/apache/spark/mllib/classification/SVMWithSGD.html) and [`SVMModel` Java docs](api/java/org/apache/spark/mllib/classification/SVMModel.html) for details on the API. {% highlight java %} import scala.Tuple2; @@ -289,7 +296,7 @@ public class SVMClassifier { The `SVMWithSGD.train()` method by default performs L2 regularization with the regularization parameter set to 1.0. If we want to configure this algorithm, we can customize `SVMWithSGD` further by creating a new object directly and -calling setter methods. All other MLlib algorithms support customization in +calling setter methods. All other `spark.mllib` algorithms support customization in this way as well. For example, the following code produces an L1 regularized variant of SVMs with regularization parameter set to 0.1, and runs the training algorithm for 200 iterations. @@ -316,6 +323,8 @@ a dependency. The following example shows how to load a sample dataset, build SVM model, and make predictions with the resulting model to compute the training error. +Refer to the [`SVMWithSGD` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.classification.SVMWithSGD) and [`SVMModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.classification.SVMModel) for more details on the API. + {% highlight python %} from pyspark.mllib.classification import SVMWithSGD, SVMModel from pyspark.mllib.regression import LabeledPoint @@ -369,7 +378,7 @@ Binary logistic regression can be generalized into train and predict multiclass classification problems. For example, for $K$ possible outcomes, one of the outcomes can be chosen as a "pivot", and the other $K - 1$ outcomes can be separately regressed against the pivot outcome. -In MLlib, the first class $0$ is chosen as the "pivot" class. +In `spark.mllib`, the first class $0$ is chosen as the "pivot" class. See Section 4.4 of [The Elements of Statistical Learning](http://statweb.stanford.edu/~tibs/ElemStatLearn/) for references. @@ -395,6 +404,8 @@ test, and use to fit a logistic regression model. Then the model is evaluated against the test dataset and saved to disk. +Refer to the [`LogisticRegressionWithLBFGS` Scala docs](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS) and [`LogisticRegressionModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionModel) for details on the API. 
+ {% highlight scala %} import org.apache.spark.SparkContext import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, LogisticRegressionModel} @@ -441,6 +452,8 @@ test, and use to fit a logistic regression model. Then the model is evaluated against the test dataset and saved to disk. +Refer to the [`LogisticRegressionWithLBFGS` Java docs](api/java/org/apache/spark/mllib/classification/LogisticRegressionWithLBFGS.html) and [`LogisticRegressionModel` Java docs](api/java/org/apache/spark/mllib/classification/LogisticRegressionModel.html) for details on the API. + {% highlight java %} import scala.Tuple2; @@ -501,10 +514,11 @@ and make predictions with the resulting model to compute the training error. Note that the Python API does not yet support multiclass classification and model save/load but will in the future. +Refer to the [`LogisticRegressionWithLBFGS` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.classification.LogisticRegressionWithLBFGS) and [`LogisticRegressionModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.classification.LogisticRegressionModel) for more details on the API. + {% highlight python %} from pyspark.mllib.classification import LogisticRegressionWithLBFGS, LogisticRegressionModel from pyspark.mllib.regression import LabeledPoint -from numpy import array # Load and parse the data def parsePoint(line): @@ -559,6 +573,8 @@ The example then uses LinearRegressionWithSGD to build a simple linear model to values. We compute the mean squared error at the end to evaluate [goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit). +Refer to the [`LinearRegressionWithSGD` Scala docs](api/scala/index.html#org.apache.spark.mllib.regression.LinearRegressionWithSGD) and [`LinearRegressionModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.regression.LinearRegressionModel) for details on the API. + {% highlight scala %} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.regression.LinearRegressionModel @@ -599,7 +615,9 @@ All of MLlib's methods use Java-friendly types, so you can import and call them way you do in Scala. The only caveat is that the methods take Scala RDD objects, while the Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a Scala one by calling `.rdd()` on your `JavaRDD` object. The corresponding Java example to -the Scala snippet provided, is presented bellow: +the Scala snippet provided, is presented below: + +Refer to the [`LinearRegressionWithSGD` Java docs](api/java/org/apache/spark/mllib/regression/LinearRegressionWithSGD.html) and [`LinearRegressionModel` Java docs](api/java/org/apache/spark/mllib/regression/LinearRegressionModel.html) for details on the API. {% highlight java %} import scala.Tuple2; @@ -674,9 +692,10 @@ values. We compute the mean squared error at the end to evaluate Note that the Python API does not yet support model save/load but will in the future. +Refer to the [`LinearRegressionWithSGD` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.regression.LinearRegressionWithSGD) and [`LinearRegressionModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.regression.LinearRegressionModel) for more details on the API. + {% highlight python %} from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD, LinearRegressionModel -from numpy import array # Load and parse the data def parsePoint(line): @@ -710,7 +729,7 @@ a dependency. 
###Streaming linear regression When data arrive in a streaming fashion, it is useful to fit regression models online, -updating the parameters of the model as new data arrives. MLlib currently supports +updating the parameters of the model as new data arrives. `spark.mllib` currently supports streaming linear regression using ordinary least squares. The fitting is similar to that performed offline, except fitting occurs on each batch of data, so that the model continually updates to reflect the data from the stream. @@ -836,7 +855,7 @@ will get better! # Implementation (developer) -Behind the scene, MLlib implements a simple distributed version of stochastic gradient descent +Behind the scenes, `spark.mllib` implements a simple distributed version of stochastic gradient descent (SGD), building on the underlying gradient descent primitive (as described in the optimization section). All provided algorithms take as input a regularization parameter (`regParam`) along with various parameters associated with stochastic diff --git a/docs/mllib-migration-guides.md b/docs/mllib-migration-guides.md index 8df68d81f3c7..f3daef2dbadb 100644 --- a/docs/mllib-migration-guides.md +++ b/docs/mllib-migration-guides.md @@ -1,12 +1,50 @@ --- layout: global -title: Old Migration Guides - MLlib -displayTitle: MLlib - Old Migration Guides +title: Old Migration Guides - spark.mllib +displayTitle: Old Migration Guides - spark.mllib description: MLlib migration guides from before Spark SPARK_VERSION_SHORT --- The migration guide for the current Spark version is kept on the [MLlib Programming Guide main page](mllib-guide.html#migration-guide). +## From 1.4 to 1.5 + +In the `spark.mllib` package, there are no breaking API changes but several behavior changes: + +* [SPARK-9005](https://issues.apache.org/jira/browse/SPARK-9005): + `RegressionMetrics.explainedVariance` returns the average regression sum of squares. +* [SPARK-8600](https://issues.apache.org/jira/browse/SPARK-8600): `NaiveBayesModel.labels` become + sorted. +* [SPARK-3382](https://issues.apache.org/jira/browse/SPARK-3382): `GradientDescent` has a default + convergence tolerance `1e-3`, and hence iterations might end earlier than 1.4. + +In the `spark.ml` package, there exists one breaking API change and one behavior change: + +* [SPARK-9268](https://issues.apache.org/jira/browse/SPARK-9268): Java's varargs support is removed + from `Params.setDefault` due to a + [Scala compiler bug](https://issues.scala-lang.org/browse/SI-9013). +* [SPARK-10097](https://issues.apache.org/jira/browse/SPARK-10097): `Evaluator.isLargerBetter` is + added to indicate metric ordering. Metrics like RMSE no longer flip signs as in 1.4. + +## From 1.3 to 1.4 + +In the `spark.mllib` package, there were several breaking changes, but all in `DeveloperApi` or `Experimental` APIs: + +* Gradient-Boosted Trees + * *(Breaking change)* The signature of the [`Loss.gradient`](api/scala/index.html#org.apache.spark.mllib.tree.loss.Loss) method was changed. This is only an issue for users who wrote their own losses for GBTs. + * *(Breaking change)* The `apply` and `copy` methods for the case class [`BoostingStrategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.BoostingStrategy) have been changed because of a modification to the case class fields. This could be an issue for users who use `BoostingStrategy` to set GBT parameters. +* *(Breaking change)* The return value of [`LDA.run`](api/scala/index.html#org.apache.spark.mllib.clustering.LDA) has changed.
It now returns an abstract class `LDAModel` instead of the concrete class `DistributedLDAModel`. The object of type `LDAModel` can still be cast to the appropriate concrete type, which depends on the optimization algorithm. + +In the `spark.ml` package, several major API changes occurred, including: + +* `Param` and other APIs for specifying parameters +* `uid` unique IDs for Pipeline components +* Reorganization of certain classes + +Since the `spark.ml` API was an alpha component in Spark 1.3, we do not list all changes here. +However, since 1.4 `spark.ml` is no longer an alpha component, we will provide details on any API +changes for future releases. + ## From 1.2 to 1.3 In the `spark.mllib` package, there were several breaking changes. The first change (in `ALS`) is the only one in a component not marked as Alpha or Experimental. @@ -23,6 +61,17 @@ In the `spark.mllib` package, there were several breaking changes. The first ch * In linear regression (including Lasso and ridge regression), the squared loss is now divided by 2. So in order to produce the same result as in 1.2, the regularization parameter needs to be divided by 2 and the step size needs to be multiplied by 2. +In the `spark.ml` package, the main API changes are from Spark SQL. We list the most important changes here: + +* The old [SchemaRDD](http://spark.apache.org/docs/1.2.1/api/scala/index.html#org.apache.spark.sql.SchemaRDD) has been replaced with [DataFrame](api/scala/index.html#org.apache.spark.sql.DataFrame) with a somewhat modified API. All algorithms in Spark ML which used to use SchemaRDD now use DataFrame. +* In Spark 1.2, we used implicit conversions from `RDD`s of `LabeledPoint` into `SchemaRDD`s by calling `import sqlContext._` where `sqlContext` was an instance of `SQLContext`. These implicits have been moved, so we now call `import sqlContext.implicits._`. +* Java APIs for SQL have also changed accordingly. Please see the examples above and the [Spark SQL Programming Guide](sql-programming-guide.html) for details. + +Other changes were in `LogisticRegression`: + +* The `scoreCol` output column (with default value "score") was renamed to be `probabilityCol` (with default value "probability"). The type was originally `Double` (for the probability of class 1.0), but it is now `Vector` (for the probability of each class, to support multiclass classification in the future). +* In Spark 1.2, `LogisticRegressionModel` did not include an intercept. In Spark 1.3, it includes an intercept; however, it will always be 0.0 since it uses the default settings for [spark.mllib.LogisticRegressionWithLBFGS](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS). The option to use an intercept will be added in the future. + ## From 1.1 to 1.2 The only API changes in MLlib v1.2 are in diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index e73bd30f3a90..d0d594af6a4a 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -1,7 +1,7 @@ --- layout: global -title: Naive Bayes - MLlib -displayTitle: MLlib - Naive Bayes +title: Naive Bayes - spark.mllib +displayTitle: Naive Bayes - spark.mllib --- [Naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier) is a simple @@ -12,7 +12,7 @@ distribution of each feature given label, and then it applies Bayes' theorem to compute the conditional probability distribution of label given an observation and use it for prediction. 
-MLlib supports [multinomial naive +`spark.mllib` supports [multinomial naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes) and [Bernoulli naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html). These models are typically used for [document classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html). @@ -38,32 +38,10 @@ smoothing parameter `lambda` as input, an optional model type parameter (default [NaiveBayesModel](api/scala/index.html#org.apache.spark.mllib.classification.NaiveBayesModel), which can be used for evaluation and prediction. -{% highlight scala %} -import org.apache.spark.mllib.classification.{NaiveBayes, NaiveBayesModel} -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.regression.LabeledPoint +Refer to the [`NaiveBayes` Scala docs](api/scala/index.html#org.apache.spark.mllib.classification.NaiveBayes) and [`NaiveBayesModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.classification.NaiveBayesModel) for details on the API. -val data = sc.textFile("data/mllib/sample_naive_bayes_data.txt") -val parsedData = data.map { line => - val parts = line.split(',') - LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble))) -} -// Split data into training (60%) and test (40%). -val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L) -val training = splits(0) -val test = splits(1) - -val model = NaiveBayes.train(training, lambda = 1.0, modelType = "multinomial") - -val predictionAndLabel = test.map(p => (model.predict(p.features), p.label)) -val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count() - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = NaiveBayesModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala %} -
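A condensed Scala sketch of that workflow, based on the inline snippet this change removes (an existing `SparkContext` named `sc` is assumed):

{% highlight scala %}
import org.apache.spark.mllib.classification.{NaiveBayes, NaiveBayesModel}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

// Parse each line "label,f1 f2 f3 ..." into a LabeledPoint.
val parsedData = sc.textFile("data/mllib/sample_naive_bayes_data.txt").map { line =>
  val parts = line.split(',')
  LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))
}

// Split data into training (60%) and test (40%) sets.
val Array(training, test) = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L)

// Train a multinomial naive Bayes model with additive smoothing lambda = 1.0.
val model = NaiveBayes.train(training, lambda = 1.0, modelType = "multinomial")

// Compute accuracy on the test set.
val predictionAndLabel = test.map(p => (model.predict(p.features), p.label))
val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count()

// Models can be saved to and loaded from storage.
model.save(sc, "myModelPath")
val sameModel = NaiveBayesModel.load(sc, "myModelPath")
{% endhighlight %}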
    [NaiveBayes](api/java/org/apache/spark/mllib/classification/NaiveBayes.html) implements @@ -73,40 +51,10 @@ optionally smoothing parameter `lambda` as input, and output a [NaiveBayesModel](api/java/org/apache/spark/mllib/classification/NaiveBayesModel.html), which can be used for evaluation and prediction. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.classification.NaiveBayes; -import org.apache.spark.mllib.classification.NaiveBayesModel; -import org.apache.spark.mllib.regression.LabeledPoint; - -JavaRDD training = ... // training set -JavaRDD test = ... // test set - -final NaiveBayesModel model = NaiveBayes.train(training.rdd(), 1.0); - -JavaPairRDD predictionAndLabel = - test.mapToPair(new PairFunction() { - @Override public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); -double accuracy = predictionAndLabel.filter(new Function, Boolean>() { - @Override public Boolean call(Tuple2 pl) { - return pl._1().equals(pl._2()); - } - }).count() / (double) test.count(); +Refer to the [`NaiveBayes` Java docs](api/java/org/apache/spark/mllib/classification/NaiveBayes.html) and [`NaiveBayesModel` Java docs](api/java/org/apache/spark/mllib/classification/NaiveBayesModel.html) for details on the API. -// Save and load model -model.save(sc.sc(), "myModelPath"); -NaiveBayesModel sameModel = NaiveBayesModel.load(sc.sc(), "myModelPath"); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaNaiveBayesExample.java %}
    -
    [NaiveBayes](api/python/pyspark.mllib.html#pyspark.mllib.classification.NaiveBayes) implements multinomial @@ -118,33 +66,8 @@ used for evaluation and prediction. Note that the Python API does not yet support model save/load but will in the future. -{% highlight python %} -from pyspark.mllib.classification import NaiveBayes, NaiveBayesModel -from pyspark.mllib.linalg import Vectors -from pyspark.mllib.regression import LabeledPoint - -def parseLine(line): - parts = line.split(',') - label = float(parts[0]) - features = Vectors.dense([float(x) for x in parts[1].split(' ')]) - return LabeledPoint(label, features) - -data = sc.textFile('data/mllib/sample_naive_bayes_data.txt').map(parseLine) - -# Split data aproximately into training (60%) and test (40%) -training, test = data.randomSplit([0.6, 0.4], seed = 0) - -# Train a naive Bayes model. -model = NaiveBayes.train(training, 1.0) - -# Make prediction and test accuracy. -predictionAndLabel = test.map(lambda p : (model.predict(p.features), p.label)) -accuracy = 1.0 * predictionAndLabel.filter(lambda (x, v): x == v).count() / test.count() - -# Save and load model -model.save(sc, "myModelPath") -sameModel = NaiveBayesModel.load(sc, "myModelPath") -{% endhighlight %} +Refer to the [`NaiveBayes` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.classification.NaiveBayes) and [`NaiveBayesModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.classification.NaiveBayesModel) for more details on the API. +{% include_example python/mllib/naive_bayes_example.py %}
    diff --git a/docs/mllib-optimization.md b/docs/mllib-optimization.md index 6cabc1610a15..f90b66f8e2c4 100644 --- a/docs/mllib-optimization.md +++ b/docs/mllib-optimization.md @@ -1,7 +1,7 @@ --- layout: global -title: Optimization - MLlib -displayTitle: MLlib - Optimization +title: Optimization - spark.mllib +displayTitle: Optimization - spark.mllib --- * Table of contents @@ -87,7 +87,7 @@ in the `$t$`-th iteration, with the input parameter `$s=$ stepSize`. Note that s step-size for SGD methods can often be delicate in practice and is a topic of active research. **Gradients.** -A table of (sub)gradients of the machine learning methods implemented in MLlib, is available in +A table of (sub)gradients of the machine learning methods implemented in `spark.mllib`, is available in the classification and regression section. @@ -140,7 +140,7 @@ other first-order optimization. ### Choosing an Optimization Method -[Linear methods](mllib-linear-methods.html) use optimization internally, and some linear methods in MLlib support both SGD and L-BFGS. +[Linear methods](mllib-linear-methods.html) use optimization internally, and some linear methods in `spark.mllib` support both SGD and L-BFGS. Different optimization methods can have different convergence guarantees depending on the properties of the objective function, and we cannot cover the literature here. In general, when L-BFGS is available, we recommend using it instead of SGD since L-BFGS tends to converge faster (in fewer iterations). @@ -218,152 +218,15 @@ L-BFGS optimizer.
    -{% highlight scala %} -import org.apache.spark.SparkContext -import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.mllib.classification.LogisticRegressionModel -import org.apache.spark.mllib.optimization.{LBFGS, LogisticGradient, SquaredL2Updater} - -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -val numFeatures = data.take(1)(0).features.size - -// Split data into training (60%) and test (40%). -val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L) - -// Append 1 into the training data as intercept. -val training = splits(0).map(x => (x.label, MLUtils.appendBias(x.features))).cache() - -val test = splits(1) - -// Run training algorithm to build the model -val numCorrections = 10 -val convergenceTol = 1e-4 -val maxNumIterations = 20 -val regParam = 0.1 -val initialWeightsWithIntercept = Vectors.dense(new Array[Double](numFeatures + 1)) - -val (weightsWithIntercept, loss) = LBFGS.runLBFGS( - training, - new LogisticGradient(), - new SquaredL2Updater(), - numCorrections, - convergenceTol, - maxNumIterations, - regParam, - initialWeightsWithIntercept) - -val model = new LogisticRegressionModel( - Vectors.dense(weightsWithIntercept.toArray.slice(0, weightsWithIntercept.size - 1)), - weightsWithIntercept(weightsWithIntercept.size - 1)) - -// Clear the default threshold. -model.clearThreshold() - -// Compute raw scores on the test set. -val scoreAndLabels = test.map { point => - val score = model.predict(point.features) - (score, point.label) -} - -// Get evaluation metrics. -val metrics = new BinaryClassificationMetrics(scoreAndLabels) -val auROC = metrics.areaUnderROC() - -println("Loss of each step in training process") -loss.foreach(println) -println("Area under ROC = " + auROC) -{% endhighlight %} +Refer to the [`LBFGS` Scala docs](api/scala/index.html#org.apache.spark.mllib.optimization.LBFGS) and [`SquaredL2Updater` Scala docs](api/scala/index.html#org.apache.spark.mllib.optimization.SquaredL2Updater) for details on the API. + +{% include_example scala/org/apache/spark/examples/mllib/LBFGSExample.scala %}
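A condensed Scala sketch of the L-BFGS workflow, based on the inline snippet this change removes (an existing `SparkContext` named `sc` is assumed); the key call is `LBFGS.runLBFGS`, which returns the optimised weights together with the loss at each iteration:

{% highlight scala %}
import org.apache.spark.mllib.classification.LogisticRegressionModel
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.optimization.{LBFGS, LogisticGradient, SquaredL2Updater}
import org.apache.spark.mllib.util.MLUtils

val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
val numFeatures = data.take(1)(0).features.size

// Split into training (60%) and test (40%); append 1 to the training features as an intercept.
val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L)
val training = splits(0).map(x => (x.label, MLUtils.appendBias(x.features))).cache()
val test = splits(1)

// Run L-BFGS with a logistic loss and an L2 regulariser.
val (weightsWithIntercept, loss) = LBFGS.runLBFGS(
  training,
  new LogisticGradient(),
  new SquaredL2Updater(),
  10,    // numCorrections
  1e-4,  // convergenceTol
  20,    // maxNumIterations
  0.1,   // regParam
  Vectors.dense(new Array[Double](numFeatures + 1)))  // initial weights, intercept included

// Wrap the learned weights in a model and evaluate with area under ROC.
val model = new LogisticRegressionModel(
  Vectors.dense(weightsWithIntercept.toArray.slice(0, weightsWithIntercept.size - 1)),
  weightsWithIntercept(weightsWithIntercept.size - 1))
model.clearThreshold()

val scoreAndLabels = test.map(point => (model.predict(point.features), point.label))
val auROC = new BinaryClassificationMetrics(scoreAndLabels).areaUnderROC()

loss.foreach(println)  // loss of each L-BFGS step
println("Area under ROC = " + auROC)
{% endhighlight %}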
    -{% highlight java %} -import java.util.Arrays; -import java.util.Random; - -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.classification.LogisticRegressionModel; -import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.mllib.optimization.*; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkContext; - -public class LBFGSExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("L-BFGS Example"); - SparkContext sc = new SparkContext(conf); - String path = "data/mllib/sample_libsvm_data.txt"; - JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); - int numFeatures = data.take(1).get(0).features().size(); - - // Split initial RDD into two... [60% training data, 40% testing data]. - JavaRDD trainingInit = data.sample(false, 0.6, 11L); - JavaRDD test = data.subtract(trainingInit); - - // Append 1 into the training data as intercept. - JavaRDD> training = data.map( - new Function>() { - public Tuple2 call(LabeledPoint p) { - return new Tuple2(p.label(), MLUtils.appendBias(p.features())); - } - }); - training.cache(); - - // Run training algorithm to build the model. - int numCorrections = 10; - double convergenceTol = 1e-4; - int maxNumIterations = 20; - double regParam = 0.1; - Vector initialWeightsWithIntercept = Vectors.dense(new double[numFeatures + 1]); - - Tuple2 result = LBFGS.runLBFGS( - training.rdd(), - new LogisticGradient(), - new SquaredL2Updater(), - numCorrections, - convergenceTol, - maxNumIterations, - regParam, - initialWeightsWithIntercept); - Vector weightsWithIntercept = result._1(); - double[] loss = result._2(); - - final LogisticRegressionModel model = new LogisticRegressionModel( - Vectors.dense(Arrays.copyOf(weightsWithIntercept.toArray(), weightsWithIntercept.size() - 1)), - (weightsWithIntercept.toArray())[weightsWithIntercept.size() - 1]); - - // Clear the default threshold. - model.clearThreshold(); - - // Compute raw scores on the test set. - JavaRDD> scoreAndLabels = test.map( - new Function>() { - public Tuple2 call(LabeledPoint p) { - Double score = model.predict(p.features()); - return new Tuple2(score, p.label()); - } - }); - - // Get evaluation metrics. - BinaryClassificationMetrics metrics = - new BinaryClassificationMetrics(scoreAndLabels.rdd()); - double auROC = metrics.areaUnderROC(); - - System.out.println("Loss of each step in training process"); - for (double l : loss) - System.out.println(l); - System.out.println("Area under ROC = " + auROC); - } -} -{% endhighlight %} +Refer to the [`LBFGS` Java docs](api/java/org/apache/spark/mllib/optimization/LBFGS.html) and [`SquaredL2Updater` Java docs](api/java/org/apache/spark/mllib/optimization/SquaredL2Updater.html) for details on the API. + +{% include_example java/org/apache/spark/examples/mllib/JavaLBFGSExample.java %}
    diff --git a/docs/mllib-pmml-model-export.md b/docs/mllib-pmml-model-export.md index 42ea2ca81f80..b532ad907dfc 100644 --- a/docs/mllib-pmml-model-export.md +++ b/docs/mllib-pmml-model-export.md @@ -1,21 +1,21 @@ --- layout: global -title: PMML model export - MLlib -displayTitle: MLlib - PMML model export +title: PMML model export - spark.mllib +displayTitle: PMML model export - spark.mllib --- * Table of contents {:toc} -## MLlib supported models +## `spark.mllib` supported models -MLlib supports model export to Predictive Model Markup Language ([PMML](http://en.wikipedia.org/wiki/Predictive_Model_Markup_Language)). +`spark.mllib` supports model export to Predictive Model Markup Language ([PMML](http://en.wikipedia.org/wiki/Predictive_Model_Markup_Language)). -The table below outlines the MLlib models that can be exported to PMML and their equivalent PMML model. +The table below outlines the `spark.mllib` models that can be exported to PMML and their equivalent PMML model.
    - + @@ -45,6 +45,8 @@ The table below outlines the MLlib models that can be exported to PMML and their
    To export a supported `model` (see table above) to PMML, simply call `model.toPMML`. +Refer to the [`KMeans` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.KMeans) and [`Vectors` Scala docs](api/scala/index.html#org.apache.spark.mllib.linalg.Vectors) for details on the API. + Here a complete example of building a KMeansModel and print it out in PMML format: {% highlight scala %} import org.apache.spark.mllib.clustering.KMeans diff --git a/docs/mllib-statistics.md b/docs/mllib-statistics.md index be04d0b4b53a..652d215fa865 100644 --- a/docs/mllib-statistics.md +++ b/docs/mllib-statistics.md @@ -1,7 +1,7 @@ --- layout: global -title: Basic Statistics - MLlib -displayTitle: MLlib - Basic Statistics +title: Basic Statistics - spark.mllib +displayTitle: Basic Statistics - spark.mllib --- * Table of contents @@ -38,6 +38,8 @@ available in `Statistics`. which contains the column-wise max, min, mean, variance, and number of nonzeros, as well as the total count. +Refer to the [`MultivariateStatisticalSummary` Scala docs](api/scala/index.html#org.apache.spark.mllib.stat.MultivariateStatisticalSummary) for details on the API. + {% highlight scala %} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics} @@ -60,6 +62,8 @@ println(summary.numNonzeros) // number of nonzeros in each column which contains the column-wise max, min, mean, variance, and number of nonzeros, as well as the total count. +Refer to the [`MultivariateStatisticalSummary` Java docs](api/java/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.html) for details on the API. + {% highlight java %} import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -86,6 +90,8 @@ System.out.println(summary.numNonzeros()); // number of nonzeros in each column which contains the column-wise max, min, mean, variance, and number of nonzeros, as well as the total count. +Refer to the [`MultivariateStatisticalSummary` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.stat.MultivariateStatisticalSummary) for more details on the API. + {% highlight python %} from pyspark.mllib.stat import Statistics @@ -106,7 +112,7 @@ print(summary.numNonzeros()) ## Correlations -Calculating the correlation between two series of data is a common operation in Statistics. In MLlib +Calculating the correlation between two series of data is a common operation in Statistics. In `spark.mllib` we provide the flexibility to calculate pairwise correlations among many series. The supported correlation methods are currently Pearson's and Spearman's correlation. @@ -116,6 +122,8 @@ correlation methods are currently Pearson's and Spearman's correlation. calculate correlations between series. Depending on the type of input, two `RDD[Double]`s or an `RDD[Vector]`, the output will be a `Double` or the correlation `Matrix` respectively. +Refer to the [`Statistics` Scala docs](api/scala/index.html#org.apache.spark.mllib.stat.Statistics) for details on the API. + {% highlight scala %} import org.apache.spark.SparkContext import org.apache.spark.mllib.linalg._ @@ -144,6 +152,8 @@ val correlMatrix: Matrix = Statistics.corr(data, "pearson") calculate correlations between series. Depending on the type of input, two `JavaDoubleRDD`s or a `JavaRDD`, the output will be a `Double` or the correlation `Matrix` respectively. +Refer to the [`Statistics` Java docs](api/java/org/apache/spark/mllib/stat/Statistics.html) for details on the API. 
+ {% highlight java %} import org.apache.spark.api.java.JavaDoubleRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -173,6 +183,8 @@ Matrix correlMatrix = Statistics.corr(data.rdd(), "pearson"); calculate correlations between series. Depending on the type of input, two `RDD[Double]`s or an `RDD[Vector]`, the output will be a `Double` or the correlation `Matrix` respectively. +Refer to the [`Statistics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.stat.Statistics) for more details on the API. + {% highlight python %} from pyspark.mllib.stat import Statistics @@ -197,7 +209,7 @@ print(Statistics.corr(data, method="pearson")) ## Stratified sampling -Unlike the other statistics functions, which reside in MLlib, stratified sampling methods, +Unlike the other statistics functions, which reside in `spark.mllib`, stratified sampling methods, `sampleByKey` and `sampleByKeyExact`, can be performed on RDD's of key-value pairs. For stratified sampling, the keys can be thought of as a label and the value as a specific attribute. For example the key can be man or woman, or document ids, and the respective values can be the list of ages @@ -282,12 +294,12 @@ approxSample = data.sampleByKey(False, fractions); ## Hypothesis testing Hypothesis testing is a powerful tool in statistics to determine whether a result is statistically -significant, whether this result occurred by chance or not. MLlib currently supports Pearson's +significant, whether this result occurred by chance or not. `spark.mllib` currently supports Pearson's chi-squared ( $\chi^2$) tests for goodness of fit and independence. The input data types determine whether the goodness of fit or the independence test is conducted. The goodness of fit test requires an input type of `Vector`, whereas the independence test requires a `Matrix` as input. -MLlib also supports the input type `RDD[LabeledPoint]` to enable feature selection via chi-squared +`spark.mllib` also supports the input type `RDD[LabeledPoint]` to enable feature selection via chi-squared independence tests.
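The Scala example for these tests is elided from the hunk above; purely as an illustrative sketch (assuming an existing `SparkContext` named `sc`), Pearson's chi-squared tests are exposed through `Statistics.chiSqTest`:

{% highlight scala %}
import org.apache.spark.mllib.linalg.{Matrices, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.stat.Statistics

// Goodness of fit: observed frequencies as a Vector (compared against a uniform distribution)
val observed = Vectors.dense(0.1, 0.15, 0.2, 0.3, 0.25)
val goodnessOfFitResult = Statistics.chiSqTest(observed)
println(goodnessOfFitResult) // p-value, degrees of freedom, test statistic, method

// Independence: a contingency table as a Matrix (column-major, 3 rows x 2 columns)
val contingency = Matrices.dense(3, 2, Array(1.0, 3.0, 5.0, 2.0, 4.0, 6.0))
val independenceResult = Statistics.chiSqTest(contingency)

// Feature selection: one chi-squared test per feature against the label
val labeledData = sc.parallelize(Seq(
  LabeledPoint(0.0, Vectors.dense(0.5, 10.0)),
  LabeledPoint(1.0, Vectors.dense(1.5, 20.0))))
val featureTestResults = Statistics.chiSqTest(labeledData)
{% endhighlight %}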
    @@ -338,6 +350,8 @@ featureTestResults.foreach { result => run Pearson's chi-squared tests. The following example demonstrates how to run and interpret hypothesis tests. +Refer to the [`ChiSqTestResult` Java docs](api/java/org/apache/spark/mllib/stat/test/ChiSqTestResult.html) for details on the API. + {% highlight java %} import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -385,6 +399,8 @@ for (ChiSqTestResult result : featureTestResults) { run Pearson's chi-squared tests. The following example demonstrates how to run and interpret hypothesis tests. +Refer to the [`Statistics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.stat.Statistics) for more details on the API. + {% highlight python %} from pyspark import SparkContext from pyspark.mllib.linalg import Vectors, Matrices @@ -422,7 +438,7 @@ for i, result in enumerate(featureTestResults):
    -Additionally, MLlib provides a 1-sample, 2-sided implementation of the Kolmogorov-Smirnov (KS) test +Additionally, `spark.mllib` provides a 1-sample, 2-sided implementation of the Kolmogorov-Smirnov (KS) test for equality of probability distributions. By providing the name of a theoretical distribution (currently solely supported for the normal distribution) and its parameters, or a function to calculate the cumulative distribution according to a given theoretical distribution, the user can @@ -437,30 +453,104 @@ message. run a 1-sample, 2-sided Kolmogorov-Smirnov test. The following example demonstrates how to run and interpret the hypothesis tests. +Refer to the [`Statistics` Scala docs](api/scala/index.html#org.apache.spark.mllib.stat.Statistics) for details on the API. + {% highlight scala %} -import org.apache.spark.SparkContext -import org.apache.spark.mllib.stat.Statistics._ +import org.apache.spark.mllib.stat.Statistics val data: RDD[Double] = ... // an RDD of sample data // run a KS test for the sample versus a standard normal distribution val testResult = Statistics.kolmogorovSmirnovTest(data, "norm", 0, 1) println(testResult) // summary of the test including the p-value, test statistic, - // and null hypothesis - // if our p-value indicates significance, we can reject the null hypothesis + // and null hypothesis + // if our p-value indicates significance, we can reject the null hypothesis // perform a KS test using a cumulative distribution function of our making val myCDF: Double => Double = ... val testResult2 = Statistics.kolmogorovSmirnovTest(data, myCDF) {% endhighlight %}
    + +
    +[`Statistics`](api/java/org/apache/spark/mllib/stat/Statistics.html) provides methods to +run a 1-sample, 2-sided Kolmogorov-Smirnov test. The following example demonstrates how to run +and interpret the hypothesis tests. + +Refer to the [`Statistics` Java docs](api/java/org/apache/spark/mllib/stat/Statistics.html) for details on the API. + +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaSparkContext; + +import org.apache.spark.mllib.stat.Statistics; +import org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult; + +JavaSparkContext jsc = ... +JavaDoubleRDD data = jsc.parallelizeDoubles(Arrays.asList(0.2, 1.0, ...)); +KolmogorovSmirnovTestResult testResult = Statistics.kolmogorovSmirnovTest(data, "norm", 0.0, 1.0); +// summary of the test including the p-value, test statistic, +// and null hypothesis +// if our p-value indicates significance, we can reject the null hypothesis +System.out.println(testResult); +{% endhighlight %} +
    + +
    +[`Statistics`](api/python/pyspark.mllib.html#pyspark.mllib.stat.Statistics) provides methods to +run a 1-sample, 2-sided Kolmogorov-Smirnov test. The following example demonstrates how to run +and interpret the hypothesis tests. + +Refer to the [`Statistics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.stat.Statistics) for more details on the API. + +{% highlight python %} +from pyspark.mllib.stat import Statistics + +parallelData = sc.parallelize([1.0, 2.0, ... ]) + +# run a KS test for the sample versus a standard normal distribution +testResult = Statistics.kolmogorovSmirnovTest(parallelData, "norm", 0, 1) +print(testResult) # summary of the test including the p-value, test statistic, + # and null hypothesis + # if our p-value indicates significance, we can reject the null hypothesis +# Note that the Scala functionality of calling Statistics.kolmogorovSmirnovTest with +# a lambda to calculate the CDF is not made available in the Python API +{% endhighlight %} +
    + + +### Streaming Significance Testing +`spark.mllib` provides online implementations of some tests to support use cases +like A/B testing. These tests may be performed on a Spark Streaming +`DStream[(Boolean,Double)]` where the first element of each tuple +indicates control group (`false`) or treatment group (`true`) and the +second element is the value of an observation. + +Streaming significance testing supports the following parameters: + +* `peacePeriod` - The number of initial data points from the stream to +ignore, used to mitigate novelty effects. +* `windowSize` - The number of past batches to perform hypothesis +testing over. Setting to `0` will perform cumulative processing using +all prior batches. + + +
    +
    +[`StreamingTest`](api/scala/index.html#org.apache.spark.mllib.stat.test.StreamingTest) +provides streaming hypothesis testing. + +{% include_example scala/org/apache/spark/examples/mllib/StreamingTestExample.scala %} +
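The included example file is not expanded here. Purely as an illustration of the parameters described above (the builder methods and the `"welch"` test-method name are assumptions, and `data` stands for a stream built from an existing `StreamingContext`), a sketch might look like:

{% highlight scala %}
import org.apache.spark.mllib.stat.test.StreamingTest
import org.apache.spark.streaming.dstream.DStream

// A stream of (isTreatmentGroup, observedValue) pairs, e.g. parsed from text received by a StreamingContext
val data: DStream[(Boolean, Double)] = ...

val streamingTest = new StreamingTest()
  .setPeacePeriod(0)      // ignore no initial data points
  .setWindowSize(0)       // 0 = cumulative processing over all prior batches
  .setTestMethod("welch") // assumed test-method name

val out = streamingTest.registerStream(data)
out.print() // each batch emits a result with the p-value and test statistic
{% endhighlight %}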
    ## Random data generation Random data generation is useful for randomized algorithms, prototyping, and performance testing. -MLlib supports generating random RDDs with i.i.d. values drawn from a given distribution: +`spark.mllib` supports generating random RDDs with i.i.d. values drawn from a given distribution: uniform, standard normal, or Poisson.
    @@ -470,6 +560,8 @@ methods to generate random double RDDs or vector RDDs. The following example generates a random double RDD, whose values follows the standard normal distribution `N(0, 1)`, and then map it to `N(1, 4)`. +Refer to the [`RandomRDDs` Scala docs](api/scala/index.html#org.apache.spark.mllib.random.RandomRDDs) for details on the API. + {% highlight scala %} import org.apache.spark.SparkContext import org.apache.spark.mllib.random.RandomRDDs._ @@ -490,6 +582,8 @@ methods to generate random double RDDs or vector RDDs. The following example generates a random double RDD, whose values follows the standard normal distribution `N(0, 1)`, and then map it to `N(1, 4)`. +Refer to the [`RandomRDDs` Java docs](api/java/org/apache/spark/mllib/random/RandomRDDs) for details on the API. + {% highlight java %} import org.apache.spark.SparkContext; import org.apache.spark.api.JavaDoubleRDD; @@ -516,6 +610,8 @@ methods to generate random double RDDs or vector RDDs. The following example generates a random double RDD, whose values follows the standard normal distribution `N(0, 1)`, and then map it to `N(1, 4)`. +Refer to the [`RandomRDDs` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.random.RandomRDDs) for more details on the API. + {% highlight python %} from pyspark.mllib.random import RandomRDDs @@ -523,10 +619,93 @@ sc = ... # SparkContext # Generate a random double RDD that contains 1 million i.i.d. values drawn from the # standard normal distribution `N(0, 1)`, evenly distributed in 10 partitions. -u = RandomRDDs.uniformRDD(sc, 1000000L, 10) +u = RandomRDDs.normalRDD(sc, 1000000L, 10) # Apply a transform to get a random double RDD following `N(1, 4)`. v = u.map(lambda x: 1.0 + 2.0 * x) {% endhighlight %}
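The examples above only draw standard normal doubles; as a further illustration (the `poissonRDD` and `normalVectorRDD` factory methods and their argument order are assumed here from the `RandomRDDs` API, with an existing `SparkContext` `sc`), other distributions and vector RDDs can be generated in the same way:

{% highlight scala %}
import org.apache.spark.mllib.random.RandomRDDs._

// 1000 i.i.d. samples drawn from a Poisson distribution with mean 2.0, in 4 partitions
val poissonData = poissonRDD(sc, 2.0, 1000L, 4)

// 1000 random vectors of length 3 whose entries are i.i.d. standard normal, in 4 partitions
val normalVectors = normalVectorRDD(sc, 1000L, 3, 4)
{% endhighlight %}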
+
+
+## Kernel density estimation
+
+[Kernel density estimation](https://en.wikipedia.org/wiki/Kernel_density_estimation) is a technique
+useful for visualizing empirical probability distributions without requiring assumptions about the
+particular distribution that the observed samples are drawn from. It computes an estimate of the
+probability density function of a random variable, evaluated at a given set of points. It achieves
+this estimate by expressing the PDF of the empirical distribution at a particular point as the
+mean of PDFs of normal distributions centered around each of the samples.
+
+
    + +
    +[`KernelDensity`](api/scala/index.html#org.apache.spark.mllib.stat.KernelDensity) provides methods +to compute kernel density estimates from an RDD of samples. The following example demonstrates how +to do so. + +Refer to the [`KernelDensity` Scala docs](api/scala/index.html#org.apache.spark.mllib.stat.KernelDensity) for details on the API. + +{% highlight scala %} +import org.apache.spark.mllib.stat.KernelDensity +import org.apache.spark.rdd.RDD + +val data: RDD[Double] = ... // an RDD of sample data + +// Construct the density estimator with the sample data and a standard deviation for the Gaussian +// kernels +val kd = new KernelDensity() + .setSample(data) + .setBandwidth(3.0) + +// Find density estimates for the given values +val densities = kd.estimate(Array(-1.0, 2.0, 5.0)) +{% endhighlight %} +
    + +
+
+[`KernelDensity`](api/java/org/apache/spark/mllib/stat/KernelDensity.html) provides methods
+to compute kernel density estimates from an RDD of samples. The following example demonstrates how
+to do so.
+
+Refer to the [`KernelDensity` Java docs](api/java/org/apache/spark/mllib/stat/KernelDensity.html) for details on the API.
+
+{% highlight java %}
+import org.apache.spark.mllib.stat.KernelDensity;
+import org.apache.spark.rdd.RDD;
+
+RDD<Double> data = ... // an RDD of sample data
+
+// Construct the density estimator with the sample data and a standard deviation for the Gaussian
+// kernels
+KernelDensity kd = new KernelDensity()
+  .setSample(data)
+  .setBandwidth(3.0);
+
+// Find density estimates for the given values
+double[] densities = kd.estimate(new double[] {-1.0, 2.0, 5.0});
+{% endhighlight %}
+
    + +
    +[`KernelDensity`](api/python/pyspark.mllib.html#pyspark.mllib.stat.KernelDensity) provides methods +to compute kernel density estimates from an RDD of samples. The following example demonstrates how +to do so. + +Refer to the [`KernelDensity` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.stat.KernelDensity) for more details on the API. + +{% highlight python %} +from pyspark.mllib.stat import KernelDensity + +data = ... # an RDD of sample data + +# Construct the density estimator with the sample data and a standard deviation for the Gaussian +# kernels +kd = KernelDensity() +kd.setSample(data) +kd.setBandwidth(3.0) + +# Find density estimates for the given values +densities = kd.estimate([-1.0, 2.0, 5.0]) +{% endhighlight %} +
    diff --git a/docs/monitoring.md b/docs/monitoring.md index bcf885fe4e68..cedceb295802 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -48,7 +48,7 @@ follows: - + diff --git a/docs/programming-guide.md b/docs/programming-guide.md index ae712d62746f..f823b89a4b5e 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -34,8 +34,7 @@ To write a Spark application, you need to add a Maven dependency on Spark. Spark version = {{site.SPARK_VERSION}} In addition, if you wish to access an HDFS cluster, you need to add a dependency on -`hadoop-client` for your version of HDFS. Some common HDFS version tags are listed on the -[third party distributions](hadoop-third-party-distributions.html) page. +`hadoop-client` for your version of HDFS. groupId = org.apache.hadoop artifactId = hadoop-client @@ -66,8 +65,7 @@ To write a Spark application in Java, you need to add a dependency on Spark. Spa version = {{site.SPARK_VERSION}} In addition, if you wish to access an HDFS cluster, you need to add a dependency on -`hadoop-client` for your version of HDFS. Some common HDFS version tags are listed on the -[third party distributions](hadoop-third-party-distributions.html) page. +`hadoop-client` for your version of HDFS. groupId = org.apache.hadoop artifactId = hadoop-client @@ -85,16 +83,15 @@ import org.apache.spark.SparkConf
-Spark {{site.SPARK_VERSION}} works with Python 2.6 or higher (but not Python 3). It uses the standard CPython interpreter,
-so C libraries like NumPy can be used.
+Spark {{site.SPARK_VERSION}} works with Python 2.6+ or Python 3.4+. It can use the standard CPython interpreter,
+so C libraries like NumPy can be used. It also works with PyPy 2.3+.

To run Spark applications in Python, use the `bin/spark-submit` script located in the Spark directory.
This script will load Spark's Java/Scala libraries and allow you to submit applications to a cluster.
You can also use `bin/pyspark` to launch an interactive Python shell.

If you wish to access HDFS data, you need to use a build of PySpark linking
-to your version of HDFS. Some common HDFS version tags are listed on the
-[third party distributions](hadoop-third-party-distributions.html) page.
+to your version of HDFS.
[Prebuilt packages](http://spark.apache.org/downloads.html) are also available on the Spark homepage
for common HDFS versions.

@@ -104,6 +101,14 @@ Finally, you need to import some Spark classes into your program. Add the follow
from pyspark import SparkContext, SparkConf
{% endhighlight %}

+PySpark requires the same minor version of Python in both driver and workers. It uses the default Python version in PATH;
+you can specify which version of Python you want to use with `PYSPARK_PYTHON`, for example:
+
+{% highlight bash %}
+$ PYSPARK_PYTHON=python3.4 bin/pyspark
+$ PYSPARK_PYTHON=/opt/pypy-2.5/bin/pypy bin/spark-submit examples/src/main/python/pi.py
+{% endhighlight %}
+
    @@ -174,8 +179,8 @@ in-process. In the Spark shell, a special interpreter-aware SparkContext is already created for you, in the variable called `sc`. Making your own SparkContext will not work. You can set which master the context connects to using the `--master` argument, and you can add JARs to the classpath -by passing a comma-separated list to the `--jars` argument. You can also add dependencies -(e.g. Spark Packages) to your shell session by supplying a comma-separated list of maven coordinates +by passing a comma-separated list to the `--jars` argument. You can also add dependencies +(e.g. Spark Packages) to your shell session by supplying a comma-separated list of maven coordinates to the `--packages` argument. Any additional repositories where dependencies might exist (e.g. SonaType) can be passed to the `--repositories` argument. For example, to run `bin/spark-shell` on exactly four cores, use: @@ -209,7 +214,7 @@ context connects to using the `--master` argument, and you can add Python .zip, to the runtime path by passing a comma-separated list to `--py-files`. You can also add dependencies (e.g. Spark Packages) to your shell session by supplying a comma-separated list of maven coordinates to the `--packages` argument. Any additional repositories where dependencies might exist (e.g. SonaType) -can be passed to the `--repositories` argument. Any python dependencies a Spark Package has (listed in +can be passed to the `--repositories` argument. Any python dependencies a Spark Package has (listed in the requirements.txt of that package) must be manually installed using pip when necessary. For example, to run `bin/pyspark` on exactly four cores, use: @@ -241,8 +246,8 @@ the [IPython Notebook](http://ipython.org/notebook.html) with PyLab plot support $ PYSPARK_DRIVER_PYTHON=ipython PYSPARK_DRIVER_PYTHON_OPTS="notebook" ./bin/pyspark {% endhighlight %} -After the IPython Notebook server is launched, you can create a new "Python 2" notebook from -the "Files" tab. Inside the notebook, you can input the command `%pylab inline` as part of +After the IPython Notebook server is launched, you can create a new "Python 2" notebook from +the "Files" tab. Inside the notebook, you can input the command `%pylab inline` as part of your notebook before you start to try Spark from the IPython notebook. @@ -410,9 +415,9 @@ Apart from text files, Spark's Python API also supports several other data forma **Writable Support** -PySpark SequenceFile support loads an RDD of key-value pairs within Java, converts Writables to base Java types, and pickles the -resulting Java objects using [Pyrolite](https://github.com/irmen/Pyrolite/). When saving an RDD of key-value pairs to SequenceFile, -PySpark does the reverse. It unpickles Python objects into Java objects and then converts them to Writables. The following +PySpark SequenceFile support loads an RDD of key-value pairs within Java, converts Writables to base Java types, and pickles the +resulting Java objects using [Pyrolite](https://github.com/irmen/Pyrolite/). When saving an RDD of key-value pairs to SequenceFile, +PySpark does the reverse. It unpickles Python objects into Java objects and then converts them to Writables. The following Writables are automatically converted:
-      <th>MLlib model</th><th>PMML model</th>
+      <th>`spark.mllib` model</th><th>PMML model</th>
    <tr><th>Environment Variable</th><th>Meaning</th></tr>
    <td>SPARK_DAEMON_MEMORY</td>
-    <td>Memory to allocate to the history server (default: 512m).</td>
+    <td>Memory to allocate to the history server (default: 1g).</td>
    <td>SPARK_DAEMON_JAVA_OPTS</td>
    @@ -427,9 +432,9 @@ Writables are automatically converted:
    <tr><td>MapWritable</td><td>dict</td></tr>
    -Arrays are not handled out-of-the-box. Users need to specify custom `ArrayWritable` subtypes when reading or writing. When writing, -users also need to specify custom converters that convert arrays to custom `ArrayWritable` subtypes. When reading, the default -converter will convert custom `ArrayWritable` subtypes to Java `Object[]`, which then get pickled to Python tuples. To get +Arrays are not handled out-of-the-box. Users need to specify custom `ArrayWritable` subtypes when reading or writing. When writing, +users also need to specify custom converters that convert arrays to custom `ArrayWritable` subtypes. When reading, the default +converter will convert custom `ArrayWritable` subtypes to Java `Object[]`, which then get pickled to Python tuples. To get Python `array.array` for arrays of primitive types, users need to specify custom converters. **Saving and Loading SequenceFiles** @@ -446,7 +451,7 @@ classes can be specified, but for standard Writables this is not required. **Saving and Loading Other Hadoop Input/Output Formats** -PySpark can also read any Hadoop InputFormat or write any Hadoop OutputFormat, for both 'new' and 'old' Hadoop MapReduce APIs. +PySpark can also read any Hadoop InputFormat or write any Hadoop OutputFormat, for both 'new' and 'old' Hadoop MapReduce APIs. If required, a Hadoop configuration can be passed in as a Python dict. Here is an example using the Elasticsearch ESInputFormat: @@ -466,15 +471,15 @@ Note that, if the InputFormat simply depends on a Hadoop configuration and/or in the key and value classes can easily be converted according to the above table, then this approach should work well for such cases. -If you have custom serialized binary data (such as loading data from Cassandra / HBase), then you will first need to +If you have custom serialized binary data (such as loading data from Cassandra / HBase), then you will first need to transform that data on the Scala/Java side to something which can be handled by Pyrolite's pickler. -A [Converter](api/scala/index.html#org.apache.spark.api.python.Converter) trait is provided -for this. Simply extend this trait and implement your transformation code in the ```convert``` -method. Remember to ensure that this class, along with any dependencies required to access your ```InputFormat```, are packaged into your Spark job jar and included on the PySpark +A [Converter](api/scala/index.html#org.apache.spark.api.python.Converter) trait is provided +for this. Simply extend this trait and implement your transformation code in the ```convert``` +method. Remember to ensure that this class, along with any dependencies required to access your ```InputFormat```, are packaged into your Spark job jar and included on the PySpark classpath. -See the [Python examples]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python) and -the [Converter examples]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/pythonconverters) +See the [Python examples]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python) and +the [Converter examples]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/pythonconverters) for examples of using Cassandra / HBase ```InputFormat``` and ```OutputFormat``` with custom converters.
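As a sketch of what such a converter might look like (the `CustomRecord` class below is a made-up placeholder; only the `Converter` trait and its `convert` method come from the paragraph above):

{% highlight scala %}
import org.apache.spark.api.python.Converter

// Made-up record type standing in for a custom serialized class from Cassandra / HBase
case class CustomRecord(id: Int, value: String)

// Convert each record to a String that Pyrolite can pickle for the Python side
class CustomRecordToStringConverter extends Converter[Any, String] {
  override def convert(obj: Any): String = obj match {
    case record: CustomRecord => s"${record.id}\t${record.value}"
    case other => other.toString
  }
}
{% endhighlight %}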
    @@ -541,7 +546,7 @@ returning only its answer to the driver program. If we also wanted to use `lineLengths` again later, we could add: {% highlight java %} -lineLengths.persist(); +lineLengths.persist(StorageLevel.MEMORY_ONLY()); {% endhighlight %} before the `reduce`, which would cause `lineLengths` to be saved in memory after the first time it is computed. @@ -750,7 +755,7 @@ One of the harder things about Spark is understanding the scope and life cycle o #### Example -Consider the naive RDD element sum below, which behaves completely differently depending on whether execution is happening within the same JVM. A common example of this is when running Spark in `local` mode (`--master = local[n]`) versus deploying a Spark application to a cluster (e.g. via spark-submit to YARN): +Consider the naive RDD element sum below, which behaves completely differently depending on whether execution is happening within the same JVM. A common example of this is when running Spark in `local` mode (`--master = local[n]`) versus deploying a Spark application to a cluster (e.g. via spark-submit to YARN):
    @@ -769,7 +774,7 @@ println("Counter value: " + counter)
{% highlight java %}
int counter = 0;
-JavaRDD<Integer> rdd = sc.parallelize(data);
+JavaRDD<Integer> rdd = sc.parallelize(data); // Wrong: Don't do this!!

rdd.foreach(x -> counter += x);

@@ -795,19 +800,19 @@ print("Counter value: " + counter)

#### Local vs. cluster modes

-The primary challenge is that the behavior of the above code is undefined. In local mode with a single JVM, the above code will sum the values within the RDD and store it in **counter**. This is because both the RDD and the variable **counter** are in the same memory space on the driver node.
+The primary challenge is that the behavior of the above code is undefined. In local mode with a single JVM, the above code will sum the values within the RDD and store it in **counter**. This is because both the RDD and the variable **counter** are in the same memory space on the driver node.

-However, in `cluster` mode, what happens is more complicated, and the above may not work as intended. To execute jobs, Spark breaks up the processing of RDD operations into tasks - each of which is operated on by an executor. Prior to execution, Spark computes the **closure**. The closure is those variables and methods which must be visible for the executor to perform its computations on the RDD (in this case `foreach()`). This closure is serialized and sent to each executor. In `local` mode, there is only the one executors so everything shares the same closure. In other modes however, this is not the case and the executors running on seperate worker nodes each have their own copy of the closure.
+However, in `cluster` mode, what happens is more complicated, and the above may not work as intended. To execute jobs, Spark breaks up the processing of RDD operations into tasks - each of which is operated on by an executor. Prior to execution, Spark computes the **closure**. The closure is those variables and methods which must be visible for the executor to perform its computations on the RDD (in this case `foreach()`). This closure is serialized and sent to each executor. In `local` mode, there is only one executor so everything shares the same closure. In other modes however, this is not the case and the executors running on separate worker nodes each have their own copy of the closure.

-What is happening here is that the variables within the closure sent to each executor are now copies and thus, when **counter** is referenced within the `foreach` function, it's no longer the **counter** on the driver node. There is still a **counter** in the memory of the driver node but this is no longer visible to the executors! The executors only sees the copy from the serialized closure. Thus, the final value of **counter** will still be zero since all operations on **counter** were referencing the value within the serialized closure.
+What is happening here is that the variables within the closure sent to each executor are now copies and thus, when **counter** is referenced within the `foreach` function, it's no longer the **counter** on the driver node. There is still a **counter** in the memory of the driver node but this is no longer visible to the executors! The executors only see the copy from the serialized closure. Thus, the final value of **counter** will still be zero since all operations on **counter** were referencing the value within the serialized closure.

To ensure well-defined behavior in these sorts of scenarios one should use an [`Accumulator`](#AccumLink).
Accumulators in Spark are used specifically to provide a mechanism for safely updating a variable when execution is split up across worker nodes in a cluster. The Accumulators section of this guide discusses these in more detail. In general, closures - constructs like loops or locally defined methods, should not be used to mutate some global state. Spark does not define or guarantee the behavior of mutations to objects referenced from outside of closures. Some code that does this may work in local mode, but that's just by accident and such code will not behave as expected in distributed mode. Use an Accumulator instead if some global aggregation is needed. -#### Printing elements of an RDD +#### Printing elements of an RDD Another common idiom is attempting to print out the elements of an RDD using `rdd.foreach(println)` or `rdd.map(println)`. On a single machine, this will generate the expected output and print all the RDD's elements. However, in `cluster` mode, the output to `stdout` being called by the executors is now writing to the executor's `stdout` instead, not the one on the driver, so `stdout` on the driver won't show these! To print all elements on the driver, one can use the `collect()` method to first bring the RDD to the driver node thus: `rdd.collect().foreach(println)`. This can cause the driver to run out of memory, though, because `collect()` fetches the entire RDD to a single machine; if you only need to print a few elements of the RDD, a safer approach is to use the `take()`: `rdd.take(100).foreach(println)`. - + ### Working with Key-Value Pairs
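As a small illustration of the key-value operations this section describes (a word count that prefers `reduceByKey` over `groupByKey` for per-key aggregation; `sc` is an existing `SparkContext` and `data.txt` a local text file):

{% highlight scala %}
// Build (word, 1) pairs and aggregate them per key.
val lines = sc.textFile("data.txt")
val pairs = lines.flatMap(line => line.split(" ")).map(word => (word, 1))

// reduceByKey combines values on the map side before shuffling,
// which is usually much cheaper than groupByKey followed by a sum.
val counts = pairs.reduceByKey(_ + _)
counts.take(10).foreach(println)
{% endhighlight %}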
    @@ -851,7 +856,7 @@ only available on RDDs of key-value pairs. The most common ones are distributed "shuffle" operations, such as grouping or aggregating the elements by a key. -In Java, key-value pairs are represented using the +In Java, key-value pairs are represented using the [scala.Tuple2](http://www.scala-lang.org/api/{{site.SCALA_VERSION}}/index.html#scala.Tuple2) class from the Scala standard library. You can simply call `new Tuple2(a, b)` to create a tuple, and access its fields later with `tuple._1()` and `tuple._2()`. @@ -966,7 +971,7 @@ for details. groupByKey([numTasks]) When called on a dataset of (K, V) pairs, returns a dataset of (K, Iterable<V>) pairs.
    Note: If you are grouping in order to perform an aggregation (such as a sum or - average) over each key, using reduceByKey or aggregateByKey will yield much better + average) over each key, using reduceByKey or aggregateByKey will yield much better performance.
    Note: By default, the level of parallelism in the output depends on the number of partitions of the parent RDD. @@ -1017,7 +1022,7 @@ for details. repartitionAndSortWithinPartitions(partitioner) Repartition the RDD according to the given partitioner and, within each resulting partition, - sort records by their keys. This is more efficient than calling repartition and then sorting within + sort records by their keys. This is more efficient than calling repartition and then sorting within each partition because it can push the sorting down into the shuffle machinery. @@ -1030,7 +1035,7 @@ RDD API doc [Java](api/java/index.html?org/apache/spark/api/java/JavaRDD.html), [Python](api/python/pyspark.html#pyspark.RDD), [R](api/R/index.html)) - + and pair RDD functions doc ([Scala](api/scala/index.html#org.apache.spark.rdd.PairRDDFunctions), [Java](api/java/index.html?org/apache/spark/api/java/JavaPairRDD.html)) @@ -1086,7 +1091,7 @@ for details. foreach(func) - Run a function func on each element of the dataset. This is usually done for side effects such as updating an Accumulator or interacting with external storage systems. + Run a function func on each element of the dataset. This is usually done for side effects such as updating an Accumulator or interacting with external storage systems.
    Note: modifying variables other than Accumulators outside of the foreach() may result in undefined behavior. See Understanding closures for more details. @@ -1110,13 +1115,13 @@ co-located to compute the result. In Spark, data is generally not distributed across partitions to be in the necessary place for a specific operation. During computations, a single task will operate on a single partition - thus, to organize all the data for a single `reduceByKey` reduce task to execute, Spark needs to perform an -all-to-all operation. It must read from all partitions to find all the values for all keys, -and then bring together values across partitions to compute the final result for each key - +all-to-all operation. It must read from all partitions to find all the values for all keys, +and then bring together values across partitions to compute the final result for each key - this is called the **shuffle**. Although the set of elements in each partition of newly shuffled data will be deterministic, and so -is the ordering of partitions themselves, the ordering of these elements is not. If one desires predictably -ordered data following shuffle then it's possible to use: +is the ordering of partitions themselves, the ordering of these elements is not. If one desires predictably +ordered data following shuffle then it's possible to use: * `mapPartitions` to sort each partition using, for example, `.sorted` * `repartitionAndSortWithinPartitions` to efficiently sort partitions while simultaneously repartitioning @@ -1133,26 +1138,26 @@ network I/O. To organize data for the shuffle, Spark generates sets of tasks - * organize the data, and a set of *reduce* tasks to aggregate it. This nomenclature comes from MapReduce and does not directly relate to Spark's `map` and `reduce` operations. -Internally, results from individual map tasks are kept in memory until they can't fit. Then, these -are sorted based on the target partition and written to a single file. On the reduce side, tasks +Internally, results from individual map tasks are kept in memory until they can't fit. Then, these +are sorted based on the target partition and written to a single file. On the reduce side, tasks read the relevant sorted blocks. - -Certain shuffle operations can consume significant amounts of heap memory since they employ -in-memory data structures to organize records before or after transferring them. Specifically, -`reduceByKey` and `aggregateByKey` create these structures on the map side, and `'ByKey` operations -generate these on the reduce side. When data does not fit in memory Spark will spill these tables + +Certain shuffle operations can consume significant amounts of heap memory since they employ +in-memory data structures to organize records before or after transferring them. Specifically, +`reduceByKey` and `aggregateByKey` create these structures on the map side, and `'ByKey` operations +generate these on the reduce side. When data does not fit in memory Spark will spill these tables to disk, incurring the additional overhead of disk I/O and increased garbage collection. Shuffle also generates a large number of intermediate files on disk. As of Spark 1.3, these files -are preserved until the corresponding RDDs are no longer used and are garbage collected. -This is done so the shuffle files don't need to be re-created if the lineage is re-computed. -Garbage collection may happen only after a long period time, if the application retains references -to these RDDs or if GC does not kick in frequently. 
This means that long-running Spark jobs may +are preserved until the corresponding RDDs are no longer used and are garbage collected. +This is done so the shuffle files don't need to be re-created if the lineage is re-computed. +Garbage collection may happen only after a long period time, if the application retains references +to these RDDs or if GC does not kick in frequently. This means that long-running Spark jobs may consume a large amount of disk space. The temporary storage directory is specified by the `spark.local.dir` configuration parameter when configuring the Spark context. Shuffle behavior can be tuned by adjusting a variety of configuration parameters. See the -'Shuffle Behavior' section within the [Spark Configuration Guide](configuration.html). +'Shuffle Behavior' section within the [Spark Configuration Guide](configuration.html). ## RDD Persistence @@ -1238,7 +1243,7 @@ efficiency. We recommend going through the following process to select one: This is the most CPU-efficient option, allowing operations on the RDDs to run as fast as possible. * If not, try using `MEMORY_ONLY_SER` and [selecting a fast serialization library](tuning.html) to -make the objects much more space-efficient, but still reasonably fast to access. +make the objects much more space-efficient, but still reasonably fast to access. * Don't spill to disk unless the functions that computed your datasets are expensive, or they filter a large amount of the data. Otherwise, recomputing a partition may be as fast as reading it from @@ -1337,7 +1342,7 @@ Accumulators are variables that are only "added" to through an associative opera therefore be efficiently supported in parallel. They can be used to implement counters (as in MapReduce) or sums. Spark natively supports accumulators of numeric types, and programmers can add support for new types. If accumulators are created with a name, they will be -displayed in Spark's UI. This can be useful for understanding the progress of +displayed in Spark's UI. This can be useful for understanding the progress of running stages (NOTE: this is not yet supported in Python). An accumulator is created from an initial value `v` by calling `SparkContext.accumulator(v)`. Tasks @@ -1466,8 +1471,8 @@ vecAccum = sc.accumulator(Vector(...), VectorAccumulatorParam())
-For accumulator updates performed inside actions only, Spark guarantees that each task's update to the accumulator
-will only be applied once, i.e. restarted tasks will not update the value. In transformations, users should be aware
+For accumulator updates performed inside actions only, Spark guarantees that each task's update to the accumulator
+will only be applied once, i.e. restarted tasks will not update the value. In transformations, users should be aware
that each task's update may be applied more than once if tasks or job stages are re-executed.

Accumulators do not change the lazy evaluation model of Spark. If they are being updated within an operation on an RDD, their value is only updated once that RDD is computed as part of an action. Consequently, accumulator updates are not guaranteed to be executed when made within a lazy transformation like `map()`. The below code fragment demonstrates this property:

@@ -1478,7 +1483,7 @@ Accumulators do not change the lazy evaluation model of Spark. If they are being

{% highlight scala %}
val accum = sc.accumulator(0)
data.map { x => accum += x; f(x) }
-// Here, accum is still 0 because no actions have caused the `map` to be computed.
+// Here, accum is still 0 because no actions have caused the map to be computed.
{% endhighlight %}
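To make the same point with an action forcing the computation, a minimal self-contained sketch (assuming only a `SparkContext` `sc`):

{% highlight scala %}
val accum = sc.accumulator(0)
val data = sc.parallelize(1 to 5)

val mapped = data.map { x => accum += x; x }
// Nothing has run yet, so accum.value is still 0.
mapped.count()       // an action: the map is now computed
println(accum.value) // 15, the sum of 1 to 5
{% endhighlight %}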
@@ -1545,7 +1550,7 @@ Several changes were made to the Java API:
  code that `extends Function` should `implement Function` instead.
* New variants of the `map` transformations, like `mapToPair`
  and `mapToDouble`, were added to create RDDs of special data types.
-* Grouping operations like `groupByKey`, `cogroup` and `join` have changed from returning
+* Grouping operations like `groupByKey`, `cogroup` and `join` have changed from returning
  `(Key, List<V>)` pairs to `(Key, Iterable<V>)`.
    diff --git a/docs/quick-start.md b/docs/quick-start.md index ce2cc9d2169c..d481fe0ea6d7 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -126,7 +126,7 @@ scala> val wordCounts = textFile.flatMap(line => line.split(" ")).map(word => (w wordCounts: spark.RDD[(String, Int)] = spark.ShuffledAggregatedRDD@71f027b8 {% endhighlight %} -Here, we combined the [`flatMap`](programming-guide.html#transformations), [`map`](programming-guide.html#transformations) and [`reduceByKey`](programming-guide.html#transformations) transformations to compute the per-word counts in the file as an RDD of (String, Int) pairs. To collect the word counts in our shell, we can use the [`collect`](programming-guide.html#actions) action: +Here, we combined the [`flatMap`](programming-guide.html#transformations), [`map`](programming-guide.html#transformations), and [`reduceByKey`](programming-guide.html#transformations) transformations to compute the per-word counts in the file as an RDD of (String, Int) pairs. To collect the word counts in our shell, we can use the [`collect`](programming-guide.html#actions) action: {% highlight scala %} scala> wordCounts.collect() @@ -163,7 +163,7 @@ One common data flow pattern is MapReduce, as popularized by Hadoop. Spark can i >>> wordCounts = textFile.flatMap(lambda line: line.split()).map(lambda word: (word, 1)).reduceByKey(lambda a, b: a+b) {% endhighlight %} -Here, we combined the [`flatMap`](programming-guide.html#transformations), [`map`](programming-guide.html#transformations) and [`reduceByKey`](programming-guide.html#transformations) transformations to compute the per-word counts in the file as an RDD of (string, int) pairs. To collect the word counts in our shell, we can use the [`collect`](programming-guide.html#actions) action: +Here, we combined the [`flatMap`](programming-guide.html#transformations), [`map`](programming-guide.html#transformations), and [`reduceByKey`](programming-guide.html#transformations) transformations to compute the per-word counts in the file as an RDD of (string, int) pairs. To collect the word counts in our shell, we can use the [`collect`](programming-guide.html#actions) action: {% highlight python %} >>> wordCounts.collect() @@ -217,13 +217,13 @@ a cluster, as described in the [programming guide](programming-guide.html#initia
    # Self-Contained Applications -Now say we wanted to write a self-contained application using the Spark API. We will walk through a -simple application in both Scala (with SBT), Java (with Maven), and Python. +Suppose we wish to write a self-contained application using the Spark API. We will walk through a +simple application in Scala (with sbt), Java (with Maven), and Python.
    -We'll create a very simple Spark application in Scala. So simple, in fact, that it's +We'll create a very simple Spark application in Scala--so simple, in fact, that it's named `SimpleApp.scala`: {% highlight scala %} @@ -259,7 +259,7 @@ object which contains information about our application. Our application depends on the Spark API, so we'll also include an sbt configuration file, -`simple.sbt` which explains that Spark is a dependency. This file also adds a repository that +`simple.sbt`, which explains that Spark is a dependency. This file also adds a repository that Spark depends on: {% highlight scala %} @@ -302,7 +302,7 @@ Lines with a: 46, Lines with b: 23
-This example will use Maven to compile an application jar, but any similar build system will work.
+This example will use Maven to compile an application JAR, but any similar build system will work.

We'll create a very simple Spark application, `SimpleApp.java`:
@@ -374,7 +374,7 @@ $ find .
Now, we can package the application using Maven and execute it with `./bin/spark-submit`.

{% highlight bash %}
-# Package a jar containing your application
+# Package a JAR containing your application
$ mvn package
...
[INFO] Building jar: {..}/{..}/target/simple-project-1.0.jar
diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md
index debdd2adf22d..3193e1785348 100644
--- a/docs/running-on-mesos.md
+++ b/docs/running-on-mesos.md
@@ -45,7 +45,7 @@ frameworks. You can install Mesos either from source or using prebuilt packages
To install Apache Mesos from source, follow these steps:

1. Download a Mesos release from a
-   [mirror](http://www.apache.org/dyn/closer.cgi/mesos/{{site.MESOS_VERSION}}/)
+   [mirror](http://www.apache.org/dyn/closer.lua/mesos/{{site.MESOS_VERSION}}/)
2. Follow the Mesos [Getting Started](http://mesos.apache.org/gettingstarted) page for compiling and
   installing Mesos

@@ -150,30 +150,42 @@ it does not need to be redundantly passed in as a system property.
Spark on Mesos also supports cluster mode, where the driver is launched in the cluster and the client
can find the results of the driver from the Mesos Web UI.

-To use cluster mode, you must start the MesosClusterDispatcher in your cluster via the `sbin/start-mesos-dispatcher.sh` script,
-passing in the Mesos master url (e.g: mesos://host:5050).
+To use cluster mode, you must start the `MesosClusterDispatcher` in your cluster via the `sbin/start-mesos-dispatcher.sh` script,
+passing in the Mesos master URL (e.g. mesos://host:5050). This starts the `MesosClusterDispatcher` as a daemon running on the host.

-From the client, you can submit a job to Mesos cluster by running `spark-submit` and specifying the master url
-to the url of the MesosClusterDispatcher (e.g: mesos://dispatcher:7077). You can view driver statuses on the
+If you would like to run the `MesosClusterDispatcher` with Marathon, you need to run the `MesosClusterDispatcher` in the foreground (i.e. `bin/spark-class org.apache.spark.deploy.mesos.MesosClusterDispatcher`).
+
+From the client, you can submit a job to the Mesos cluster by running `spark-submit` and specifying the master URL
+to the URL of the `MesosClusterDispatcher` (e.g. mesos://dispatcher:7077). You can view driver statuses on the
Spark cluster Web UI.

-# Mesos Run Modes
+For example:
+{% highlight bash %}
+./bin/spark-submit \
+  --class org.apache.spark.examples.SparkPi \
+  --master mesos://207.184.161.138:7077 \
+  --deploy-mode cluster \
+  --supervise \
+  --executor-memory 20G \
+  --total-executor-cores 100 \
+  http://path/to/examples.jar \
+  1000
+{% endhighlight %}

-Spark can run over Mesos in two modes: "fine-grained" (default) and "coarse-grained".
-In "fine-grained" mode (default), each Spark task runs as a separate Mesos task. This allows
-multiple instances of Spark (and other frameworks) to share machines at a very fine granularity,
-where each application gets more or fewer machines as it ramps up and down, but it comes with an
-additional overhead in launching each task. This mode may be inappropriate for low-latency
-requirements like interactive queries or serving web requests.
+Note that JARs or Python files that are passed to spark-submit should be URIs reachable by Mesos slaves, as the Spark driver doesn't automatically upload local JARs.

-The "coarse-grained" mode will instead launch only *one* long-running Spark task on each Mesos
+# Mesos Run Modes
+
+Spark can run over Mesos in two modes: "coarse-grained" (default) and "fine-grained".
+
+The "coarse-grained" mode will launch only *one* long-running Spark task on each Mesos
machine, and dynamically schedule its own "mini-tasks" within it. The benefit is much lower startup
overhead, but at the cost of reserving the Mesos resources for the complete duration of the
application.

-To run in coarse-grained mode, set the `spark.mesos.coarse` property in your
-[SparkConf](configuration.html#spark-properties):
+Coarse-grained is the default mode. You can also set the `spark.mesos.coarse` property to true
+to turn it on explicitly in [SparkConf](configuration.html#spark-properties):

{% highlight scala %}
conf.set("spark.mesos.coarse", "true")

@@ -184,13 +196,26 @@ acquire. By default, it will acquire *all* cores in the cluster (that get offere
only makes sense if you run just one application at a time.

You can cap the maximum number of cores using `conf.set("spark.cores.max", "10")` (for example).

+In "fine-grained" mode, each Spark task runs as a separate Mesos task. This allows
+multiple instances of Spark (and other frameworks) to share machines at a very fine granularity,
+where each application gets more or fewer machines as it ramps up and down, but it comes with an
+additional overhead in launching each task. This mode may be inappropriate for low-latency
+requirements like interactive queries or serving web requests.
+
+To run in fine-grained mode, set the `spark.mesos.coarse` property to false in your
+[SparkConf](configuration.html#spark-properties):
+
+{% highlight scala %}
+conf.set("spark.mesos.coarse", "false")
+{% endhighlight %}
+
You may also make use of `spark.mesos.constraints` to set attribute based constraints on mesos resource offers. By default, all resource offers will be accepted.

{% highlight scala %}
-conf.set("spark.mesos.constraints", "tachyon=true;us-east-1=false")
+conf.set("spark.mesos.constraints", "tachyon:true;us-east-1:false")
{% endhighlight %}

-For example, Let's say `spark.mesos.constraints` is set to `tachyon=true;us-east-1=false`, then the resource offers will be checked to see if they meet both these constraints and only then will be accepted to start new executors.
+For example, let's say `spark.mesos.constraints` is set to `tachyon:true;us-east-1:false`, then the resource offers will be checked to see if they meet both these constraints and only then will they be accepted to start new executors.

# Mesos Docker Support

@@ -216,6 +241,20 @@ node. Please refer to [Hadoop on Mesos](https://github.com/mesos/hadoop).

In either case, HDFS runs separately from Hadoop MapReduce, without being scheduled through Mesos.

+# Dynamic Resource Allocation with Mesos
+
+Mesos supports dynamic allocation only with coarse-grained mode, which can resize the number of executors based on statistics
+of the application. While dynamic allocation supports both scaling up and scaling down the number of executors, the coarse-grained scheduler only supports scaling down
+since it is already designed to run one executor per slave with the configured amount of resources.
However, after scaling down the number of executors the coarse-grained scheduler
+can scale back up to the same number of executors when Spark signals that more executors are needed.
+
+Users who would like to use this feature should launch the Mesos Shuffle Service, which
+provides shuffle data cleanup functionality on top of the regular Shuffle Service, since Mesos doesn't yet support notification of another framework's
+termination. To launch or stop the Mesos Shuffle Service, please use the provided sbin/start-mesos-shuffle-service.sh and sbin/stop-mesos-shuffle-service.sh
+scripts accordingly.
+
+The Shuffle Service is expected to be running on each slave node that will run Spark executors. One way to easily achieve this with Mesos
+is to launch the Shuffle Service with Marathon with a unique host constraint.

# Configuration

@@ -229,7 +268,7 @@ See the [configuration page](configuration.html) for information on Spark config
  spark.mesos.coarse
  false
-    If set to "true", runs over Mesos clusters in
+    If set to true, runs over Mesos clusters in
    "coarse-grained" sharing mode, where Spark acquires one long-lived Mesos task
    on each machine instead of one Mesos task per Spark task. This gives lower-latency
    scheduling for short queries, but leaves resources in use
@@ -238,16 +277,16 @@ See the [configuration page](configuration.html) for information on Spark config
  spark.mesos.extra.cores
-  0
+  0
    Set the extra amount of cpus to request per task. This setting is only used for Mesos coarse grain mode.
    The total amount of cores requested per task is the number of cores in the offer plus the extra cores configured.
-    Note that total amount of cores the executor will request in total will not exceed the spark.cores.max setting.
+    Note that total amount of cores the executor will request in total will not exceed the spark.cores.max setting.
  spark.mesos.mesosExecutor.cores
-  1.0
+  1.0
    (Fine-grained mode only) Number of cores to give each Mesos executor. This does not
    include the cores used to run the Spark tasks. In other words, even if no Spark task
@@ -262,7 +301,7 @@ See the [configuration page](configuration.html) for information on Spark config
    Set the name of the docker image that the Spark executors will run in. The selected
    image must have Spark installed, as well as a compatible version of the Mesos library.
    The installed path of Spark in the image can be specified with spark.mesos.executor.home;
-    the installed path of the Mesos library can be specified with spark.executorEnv.MESOS_NATIVE_LIBRARY.
+    the installed path of the Mesos library can be specified with spark.executorEnv.MESOS_NATIVE_JAVA_LIBRARY.
@@ -271,7 +310,7 @@ See the [configuration page](configuration.html) for information on Spark config
    Set the list of volumes which will be mounted into the Docker image, which was set using
    spark.mesos.executor.docker.image. The format of this property is a comma-separated list of
-    mappings following the form passed to docker run -v. That is they take the form:
+    mappings following the form passed to docker run -v. That is, they take the form:
    [host_path:]container_path[:ro|:rw]
    @@ -302,27 +341,35 @@ See the [configuration page](configuration.html) for information on Spark config executor memory * 0.10, with minimum of 384 The amount of additional memory, specified in MB, to be allocated per executor. By default, - the overhead will be larger of either 384 or 10% of `spark.executor.memory`. If it's set, + the overhead will be larger of either 384 or 10% of spark.executor.memory. If set, the final overhead will be this value. + + spark.mesos.uris + (none) + + A list of URIs to be downloaded to the sandbox when driver or executor is launched by Mesos. + This applies to both coarse-grain and fine-grain mode. + + spark.mesos.principal - Framework principal to authenticate to Mesos + (none) Set the principal with which Spark framework will use to authenticate with Mesos. spark.mesos.secret - Framework secret to authenticate to Mesos + (none) Set the secret with which Spark framework will use to authenticate with Mesos. spark.mesos.role - Role for the Spark framework + * Set the role of this Spark framework for Mesos. Roles are used in Mesos for reservations and resource weight sharing. @@ -330,7 +377,7 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.constraints - Attribute based constraints to be matched against when accepting resource offers. + (none) Attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. Refer to Mesos Attributes & Resources for more information on attributes.
      diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index cac08a91b97d..06413f83c3a7 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -16,18 +16,19 @@ containers used by the application use the same configuration. If the configurat Java system properties or environment variables not managed by YARN, they should also be set in the Spark application's configuration (driver, executors, and the AM when running in client mode). -There are two deploy modes that can be used to launch Spark applications on YARN. In `yarn-cluster` mode, the Spark driver runs inside an application master process which is managed by YARN on the cluster, and the client can go away after initiating the application. In `yarn-client` mode, the driver runs in the client process, and the application master is only used for requesting resources from YARN. +There are two deploy modes that can be used to launch Spark applications on YARN. In `cluster` mode, the Spark driver runs inside an application master process which is managed by YARN on the cluster, and the client can go away after initiating the application. In `client` mode, the driver runs in the client process, and the application master is only used for requesting resources from YARN. -Unlike in Spark standalone and Mesos mode, in which the master's address is specified in the `--master` parameter, in YARN mode the ResourceManager's address is picked up from the Hadoop configuration. Thus, the `--master` parameter is `yarn-client` or `yarn-cluster`. -To launch a Spark application in `yarn-cluster` mode: +Unlike [Spark standalone](spark-standalone.html) and [Mesos](running-on-mesos.html) modes, in which the master's address is specified in the `--master` parameter, in YARN mode the ResourceManager's address is picked up from the Hadoop configuration. Thus, the `--master` parameter is `yarn`. + +To launch a Spark application in `cluster` mode: + + $ ./bin/spark-submit --class path.to.your.Class --master yarn --deploy-mode cluster [options] [app options] - `$ ./bin/spark-submit --class path.to.your.Class --master yarn-cluster [options] [app options]` - For example: $ ./bin/spark-submit --class org.apache.spark.examples.SparkPi \ - --master yarn-cluster \ - --num-executors 3 \ + --master yarn \ + --deploy-mode cluster \ --driver-memory 4g \ --executor-memory 2g \ --executor-cores 1 \ @@ -37,16 +38,17 @@ For example: The above starts a YARN client program which starts the default Application Master. Then SparkPi will be run as a child thread of Application Master. The client will periodically poll the Application Master for status updates and display them in the console. The client will exit once your application has finished running. Refer to the "Debugging your Application" section below for how to see driver and executor logs. -To launch a Spark application in `yarn-client` mode, do the same, but replace `yarn-cluster` with `yarn-client`. To run spark-shell: +To launch a Spark application in `client` mode, do the same, but replace `cluster` with `client`. The following shows how you can run `spark-shell` in `client` mode: - $ ./bin/spark-shell --master yarn-client + $ ./bin/spark-shell --master yarn --deploy-mode client ## Adding Other JARs -In `yarn-cluster` mode, the driver runs on a different machine than the client, so `SparkContext.addJar` won't work out of the box with files that are local to the client. 
To make files on the client available to `SparkContext.addJar`, include them with the `--jars` option in the launch command. +In `cluster` mode, the driver runs on a different machine than the client, so `SparkContext.addJar` won't work out of the box with files that are local to the client. To make files on the client available to `SparkContext.addJar`, include them with the `--jars` option in the launch command. $ ./bin/spark-submit --class my.main.Class \ - --master yarn-cluster \ + --master yarn \ + --deploy-mode cluster \ --jars my-other-jar.jar,my-other-other-jar.jar my-main-jar.jar app_arg1 app_arg2 @@ -54,8 +56,8 @@ In `yarn-cluster` mode, the driver runs on a different machine than the client, # Preparations -Running Spark-on-YARN requires a binary distribution of Spark which is built with YARN support. -Binary distributions can be downloaded from the Spark project website. +Running Spark on YARN requires a binary distribution of Spark which is built with YARN support. +Binary distributions can be downloaded from the [downloads page](http://spark.apache.org/downloads.html) of the project website. To build Spark yourself, refer to [Building Spark](building-spark.html). # Configuration @@ -64,22 +66,22 @@ Most of the configs are the same for Spark on YARN as for other deployment modes # Debugging your Application -In YARN terminology, executors and application masters run inside "containers". YARN has two modes for handling container logs after an application has completed. If log aggregation is turned on (with the `yarn.log-aggregation-enable` config), container logs are copied to HDFS and deleted on the local machine. These logs can be viewed from anywhere on the cluster with the "yarn logs" command. +In YARN terminology, executors and application masters run inside "containers". YARN has two modes for handling container logs after an application has completed. If log aggregation is turned on (with the `yarn.log-aggregation-enable` config), container logs are copied to HDFS and deleted on the local machine. These logs can be viewed from anywhere on the cluster with the `yarn logs` command. yarn logs -applicationId - + will print out the contents of all log files from all containers from the given application. You can also view the container log files directly in HDFS using the HDFS shell or API. The directory where they are located can be found by looking at your YARN configs (`yarn.nodemanager.remote-app-log-dir` and `yarn.nodemanager.remote-app-log-dir-suffix`). The logs are also available on the Spark Web UI under the Executors Tab. You need to have both the Spark history server and the MapReduce history server running and configure `yarn.log.server.url` in `yarn-site.xml` properly. The log URL on the Spark history server UI will redirect you to the MapReduce history server to show the aggregated logs. When log aggregation isn't turned on, logs are retained locally on each machine under `YARN_APP_LOGS_DIR`, which is usually configured to `/tmp/logs` or `$HADOOP_HOME/logs/userlogs` depending on the Hadoop version and installation. Viewing logs for a container requires going to the host that contains them and looking in this directory. Subdirectories organize log files by application ID and container ID. The logs are also available on the Spark Web UI under the Executors Tab and doesn't require running the MapReduce history server. To review per-container launch environment, increase `yarn.nodemanager.delete.debug-delay-sec` to a -large value (e.g. 
36000), and then access the application cache through `yarn.nodemanager.local-dirs` +large value (e.g. `36000`), and then access the application cache through `yarn.nodemanager.local-dirs` on the nodes on which containers are launched. This directory contains the launch script, JARs, and all environment variables used for launching each container. This process is useful for debugging classpath problems in particular. (Note that enabling this requires admin privileges on cluster settings and a restart of all node managers. Thus, this is not applicable to hosted clusters). -To use a custom log4j configuration for the application master or executors, there are two options: +To use a custom log4j configuration for the application master or executors, here are the options: - upload a custom `log4j.properties` using `spark-submit`, by adding it to the `--files` list of files to be uploaded with the application. @@ -87,12 +89,15 @@ To use a custom log4j configuration for the application master or executors, the (for the driver) or `spark.executor.extraJavaOptions` (for executors). Note that if using a file, the `file:` protocol should be explicitly provided, and the file needs to exist locally on all the nodes. +- update the `$SPARK_CONF_DIR/log4j.properties` file and it will be automatically uploaded along + with the other configurations. Note that the other two options have higher priority than this one if + multiple options are specified. Note that for the first option, both executors and the application master will share the same log4j configuration, which may cause issues when they run on the same node (e.g. trying to write to the same log file). -If you need a reference to the proper location to put log files in the YARN so that YARN can properly display and aggregate them, use `spark.yarn.app.container.log.dir` in your log4j.properties. For example, `log4j.appender.file_appender.File=${spark.yarn.app.container.log.dir}/spark.log`. For streaming application, configuring `RollingFileAppender` and setting file location to YARN's log directory will avoid disk overflow caused by large log file, and logs can be accessed using YARN's log utility. +If you need a reference to the proper location to put log files in YARN so that YARN can properly display and aggregate them, use `spark.yarn.app.container.log.dir` in your `log4j.properties`. For example, `log4j.appender.file_appender.File=${spark.yarn.app.container.log.dir}/spark.log`. For streaming applications, configuring `RollingFileAppender` and setting file location to YARN's log directory will avoid disk overflow caused by large log files, and logs can be accessed using YARN's log utility. #### Spark Properties @@ -100,24 +105,26 @@ If you need a reference to the proper location to put log files in the YARN so t Property NameDefaultMeaning spark.yarn.am.memory - 512m + 512m Amount of memory to use for the YARN Application Master in client mode, in the same format as JVM memory strings (e.g. 512m, 2g). In cluster mode, use spark.driver.memory instead. +

      + Use lower-case suffixes, e.g. k, m, g, t, and p, for kibi-, mebi-, gibi-, tebi-, and pebibytes, respectively. spark.driver.cores - 1 + 1 Number of cores used by the driver in YARN cluster mode. - Since the driver is run in the same JVM as the YARN Application Master in cluster mode, this also controls the cores used by the YARN AM. - In client mode, use spark.yarn.am.cores to control the number of cores used by the YARN AM instead. + Since the driver is run in the same JVM as the YARN Application Master in cluster mode, this also controls the cores used by the YARN Application Master. + In client mode, use spark.yarn.am.cores to control the number of cores used by the YARN Application Master instead. spark.yarn.am.cores - 1 + 1 Number of cores to use for the YARN Application Master in client mode. In cluster mode, use spark.driver.cores instead. @@ -125,39 +132,39 @@ If you need a reference to the proper location to put log files in the YARN so t spark.yarn.am.waitTime - 100s + 100s - In `yarn-cluster` mode, time for the application master to wait for the - SparkContext to be initialized. In `yarn-client` mode, time for the application master to wait + In cluster mode, time for the YARN Application Master to wait for the + SparkContext to be initialized. In client mode, time for the YARN Application Master to wait for the driver to connect to it. spark.yarn.submit.file.replication - The default HDFS replication (usually 3) + The default HDFS replication (usually 3) HDFS replication level for the files uploaded into HDFS for the application. These include things like the Spark jar, the app jar, and any distributed cache files/archives. spark.yarn.preserve.staging.files - false + false - Set to true to preserve the staged files (Spark jar, app jar, distributed cache files) at the end of the job rather than delete them. + Set to true to preserve the staged files (Spark jar, app jar, distributed cache files) at the end of the job rather than delete them. spark.yarn.scheduler.heartbeat.interval-ms - 3000 + 3000 The interval in ms in which the Spark application master heartbeats into the YARN ResourceManager. - The value is capped at half the value of YARN's configuration for the expiry interval - (yarn.am.liveness-monitor.expiry-interval-ms). + The value is capped at half the value of YARN's configuration for the expiry interval, i.e. + yarn.am.liveness-monitor.expiry-interval-ms. spark.yarn.scheduler.initial-allocation.interval - 200ms + 200ms The initial interval in which the Spark application master eagerly heartbeats to the YARN ResourceManager when there are pending container allocation requests. It should be no larger than @@ -177,8 +184,8 @@ If you need a reference to the proper location to put log files in the YARN so t spark.yarn.historyServer.address (none) - The address of the Spark history server (i.e. host.com:18080). The address should not contain a scheme (http://). Defaults to not being set since the history server is an optional service. This address is given to the YARN ResourceManager when the Spark application finishes to link the application from the ResourceManager UI to the Spark history server UI. - For this property, YARN properties can be used as variables, and these are substituted by Spark at runtime. For eg, if the Spark history server runs on the same node as the YARN ResourceManager, it can be set to `${hadoopconf-yarn.resourcemanager.hostname}:18080`. + The address of the Spark history server, e.g. host.com:18080. 
The address should not contain a scheme (http://). Defaults to not being set since the history server is an optional service. This address is given to the YARN ResourceManager when the Spark application finishes to link the application from the ResourceManager UI to the Spark history server UI. + For this property, YARN properties can be used as variables, and these are substituted by Spark at runtime. For example, if the Spark history server runs on the same node as the YARN ResourceManager, it can be set to ${hadoopconf-yarn.resourcemanager.hostname}:18080. @@ -197,42 +204,42 @@ If you need a reference to the proper location to put log files in the YARN so t spark.executor.instances - 2 + 2 - The number of executors. Note that this property is incompatible with spark.dynamicAllocation.enabled. + The number of executors. Note that this property is incompatible with spark.dynamicAllocation.enabled. If both spark.dynamicAllocation.enabled and spark.executor.instances are specified, dynamic allocation is turned off and the specified number of spark.executor.instances is used. spark.yarn.executor.memoryOverhead executorMemory * 0.10, with minimum of 384 - The amount of off heap memory (in megabytes) to be allocated per executor. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the executor size (typically 6-10%). + The amount of off-heap memory (in megabytes) to be allocated per executor. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the executor size (typically 6-10%). spark.yarn.driver.memoryOverhead - driverMemory * 0.07, with minimum of 384 + driverMemory * 0.10, with minimum of 384 - The amount of off heap memory (in megabytes) to be allocated per driver in cluster mode. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the container size (typically 6-10%). + The amount of off-heap memory (in megabytes) to be allocated per driver in cluster mode. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the container size (typically 6-10%). spark.yarn.am.memoryOverhead - AM memory * 0.07, with minimum of 384 + AM memory * 0.10, with minimum of 384 - Same as spark.yarn.driver.memoryOverhead, but for the Application Master in client mode. + Same as spark.yarn.driver.memoryOverhead, but for the YARN Application Master in client mode. spark.yarn.am.port (random) - Port for the YARN Application Master to listen on. In YARN client mode, this is used to communicate between the Spark driver running on a gateway and the Application Master running on YARN. In YARN cluster mode, this is used for the dynamic executor feature, where it handles the kill from the scheduler backend. + Port for the YARN Application Master to listen on. In YARN client mode, this is used to communicate between the Spark driver running on a gateway and the YARN Application Master running on YARN. In YARN cluster mode, this is used for the dynamic executor feature, where it handles the kill from the scheduler backend. spark.yarn.queue - default + default The name of the YARN queue to which the application is submitted. 
@@ -245,18 +252,18 @@ If you need a reference to the proper location to put log files in the YARN so t By default, Spark on YARN will use a Spark jar installed locally, but the Spark jar can also be in a world-readable location on HDFS. This allows YARN to cache it on nodes so that it doesn't need to be distributed each time an application runs. To point to a jar on HDFS, for example, - set this configuration to "hdfs:///some/path". + set this configuration to hdfs:///some/path. spark.yarn.access.namenodes (none) - A list of secure HDFS namenodes your Spark application is going to access. For - example, `spark.yarn.access.namenodes=hdfs://nn1.com:8032,hdfs://nn2.com:8032`. - The Spark application must have acess to the namenodes listed and Kerberos must - be properly configured to be able to access them (either in the same realm or in - a trusted realm). Spark acquires security tokens for each of the namenodes so that + A comma-separated list of secure HDFS namenodes your Spark application is going to access. For + example, spark.yarn.access.namenodes=hdfs://nn1.com:8032,hdfs://nn2.com:8032. + The Spark application must have access to the namenodes listed and Kerberos must + be properly configured to be able to access them (either in the same realm or in + a trusted realm). Spark acquires security tokens for each of the namenodes so that the Spark application can access those remote HDFS clusters. @@ -264,18 +271,18 @@ If you need a reference to the proper location to put log files in the YARN so t spark.yarn.appMasterEnv.[EnvironmentVariableName] (none) - Add the environment variable specified by EnvironmentVariableName to the - Application Master process launched on YARN. The user can specify multiple of - these and to set multiple environment variables. In `yarn-cluster` mode this controls - the environment of the SPARK driver and in `yarn-client` mode it only controls - the environment of the executor launcher. + Add the environment variable specified by EnvironmentVariableName to the + Application Master process launched on YARN. The user can specify multiple of + these and to set multiple environment variables. In cluster mode this controls + the environment of the Spark driver and in client mode it only controls + the environment of the executor launcher. spark.yarn.containerLauncherMaxThreads - 25 + 25 - The maximum number of threads to use in the application master for launching executor containers. + The maximum number of threads to use in the YARN Application Master for launching executor containers. @@ -283,33 +290,51 @@ If you need a reference to the proper location to put log files in the YARN so t (none) A string of extra JVM options to pass to the YARN Application Master in client mode. - In cluster mode, use `spark.driver.extraJavaOptions` instead. + In cluster mode, use spark.driver.extraJavaOptions instead. spark.yarn.am.extraLibraryPath (none) - Set a special library path to use when launching the application master in client mode. + Set a special library path to use when launching the YARN Application Master in client mode. spark.yarn.maxAppAttempts - yarn.resourcemanager.am.max-attempts in YARN + yarn.resourcemanager.am.max-attempts in YARN The maximum number of attempts that will be made to submit the application. It should be no larger than the global number of max attempts in the YARN configuration. + + spark.yarn.am.attemptFailuresValidityInterval + (none) + + Defines the validity interval for AM failure tracking. 
+ If the AM has been running for at least the defined interval, the AM failure count will be reset. + This feature is not enabled if not configured, and only supported in Hadoop 2.6+. + + spark.yarn.submit.waitAppCompletion - true + true In YARN cluster mode, controls whether the client waits to exit until the application completes. - If set to true, the client process will stay alive reporting the application's status. + If set to true, the client process will stay alive reporting the application's status. Otherwise, the client process will exit after submission. + + spark.yarn.am.nodeLabelExpression + (none) + + A YARN node label expression that restricts the set of nodes AM will be scheduled on. + Only versions of YARN greater than or equal to 2.6 support node label expressions, so when + running against earlier versions, this property will be ignored. + + spark.yarn.executor.nodeLabelExpression (none) @@ -319,20 +344,28 @@ If you need a reference to the proper location to put log files in the YARN so t running against earlier versions, this property will be ignored. + + spark.yarn.tags + (none) + + Comma-separated list of strings to pass through as YARN application tags appearing + in YARN ApplicationReports, which can be used for filtering when querying YARN apps. + + spark.yarn.keytab (none) The full path to the file that contains the keytab for the principal specified above. - This keytab will be copied to the node running the Application Master via the Secure Distributed Cache, - for renewing the login tickets and the delegation tokens periodically. + This keytab will be copied to the node running the YARN Application Master via the Secure Distributed Cache, + for renewing the login tickets and the delegation tokens periodically. (Works also with the "local" master) spark.yarn.principal (none) - Principal to be used to login to KDC, while running on secure HDFS. + Principal to be used to login to KDC, while running on secure HDFS. (Works also with the "local" master) @@ -361,11 +394,23 @@ If you need a reference to the proper location to put log files in the YARN so t See spark.yarn.config.gatewayPath. + + spark.yarn.security.tokens.${service}.enabled + true + + Controls whether to retrieve delegation tokens for non-HDFS services when security is enabled. + By default, delegation tokens for all supported services are retrieved when those services are + configured, but it's possible to disable that behavior if it somehow conflicts with the + application being run. +

      + Currently supported services are: hive, hbase + + # Important notes - Whether core requests are honored in scheduling decisions depends on which scheduler is in use and how it is configured. -- In `yarn-cluster` mode, the local directories used by the Spark executors and the Spark driver will be the local directories configured for YARN (Hadoop YARN config `yarn.nodemanager.local-dirs`). If the user specifies `spark.local.dir`, it will be ignored. In `yarn-client` mode, the Spark executors will use the local directories configured for YARN while the Spark driver will use those defined in `spark.local.dir`. This is because the Spark driver does not run on the YARN cluster in `yarn-client` mode, only the Spark executors do. -- The `--files` and `--archives` options support specifying file names with the # similar to Hadoop. For example you can specify: `--files localtest.txt#appSees.txt` and this will upload the file you have locally named localtest.txt into HDFS but this will be linked to by the name `appSees.txt`, and your application should use the name as `appSees.txt` to reference it when running on YARN. -- The `--jars` option allows the `SparkContext.addJar` function to work if you are using it with local files and running in `yarn-cluster` mode. It does not need to be used if you are using it with HDFS, HTTP, HTTPS, or FTP files. +- In `cluster` mode, the local directories used by the Spark executors and the Spark driver will be the local directories configured for YARN (Hadoop YARN config `yarn.nodemanager.local-dirs`). If the user specifies `spark.local.dir`, it will be ignored. In `client` mode, the Spark executors will use the local directories configured for YARN while the Spark driver will use those defined in `spark.local.dir`. This is because the Spark driver does not run on the YARN cluster in `client` mode, only the Spark executors do. +- The `--files` and `--archives` options support specifying file names with the # similar to Hadoop. For example you can specify: `--files localtest.txt#appSees.txt` and this will upload the file you have locally named `localtest.txt` into HDFS but this will be linked to by the name `appSees.txt`, and your application should use the name as `appSees.txt` to reference it when running on YARN. +- The `--jars` option allows the `SparkContext.addJar` function to work if you are using it with local files and running in `cluster` mode. It does not need to be used if you are using it with HDFS, HTTP, HTTPS, or FTP files. diff --git a/docs/security.md b/docs/security.md index d4ffa60e59a3..0bfc791c5744 100644 --- a/docs/security.md +++ b/docs/security.md @@ -23,9 +23,16 @@ If your applications are using event logging, the directory where the event logs ## Encryption -Spark supports SSL for Akka and HTTP (for broadcast and file server) protocols. However SSL is not supported yet for WebUI and block transfer service. +Spark supports SSL for Akka and HTTP (for broadcast and file server) protocols. SASL encryption is +supported for the block transfer service. Encryption is not yet supported for the WebUI. -Connection encryption (SSL) configuration is organized hierarchically. The user can configure the default SSL settings which will be used for all the supported communication protocols unless they are overwritten by protocol-specific settings. This way the user can easily provide the common settings for all the protocols without disabling the ability to configure each one individually. 
The common SSL settings are at `spark.ssl` namespace in Spark configuration, while Akka SSL configuration is at `spark.ssl.akka` and HTTP for broadcast and file server SSL configuration is at `spark.ssl.fs`. The full breakdown can be found on the [configuration page](configuration.html). +Encryption is not yet supported for data stored by Spark in temporary local storage, such as shuffle +files, cached data, and other application files. If encrypting this data is desired, a workaround is +to configure your cluster manager to store application data on encrypted disks. + +### SSL Configuration + +Configuration for SSL is organized hierarchically. The user can configure the default SSL settings which will be used for all the supported communication protocols unless they are overwritten by protocol-specific settings. This way the user can easily provide the common settings for all the protocols without disabling the ability to configure each one individually. The common SSL settings are at `spark.ssl` namespace in Spark configuration, while Akka SSL configuration is at `spark.ssl.akka` and HTTP for broadcast and file server SSL configuration is at `spark.ssl.fs`. The full breakdown can be found on the [configuration page](configuration.html). SSL must be configured on each node and configured for each component involved in communication using the particular protocol. @@ -47,6 +54,17 @@ follows: * Import all exported public keys into a single trust-store * Distribute the trust-store over the nodes +### Configuring SASL Encryption + +SASL encryption is currently supported for the block transfer service when authentication +(`spark.authenticate`) is enabled. To enable SASL encryption for an application, set +`spark.authenticate.enableSaslEncryption` to `true` in the application's configuration. + +When using an external shuffle service, it's possible to disable unencrypted connections by setting +`spark.network.sasl.serverAlwaysEncrypt` to `true` in the shuffle service's configuration. If that +option is enabled, applications that are not set up to use SASL encryption will fail to connect to +the shuffle service. + ## Configuring Ports for Network Security Spark makes heavy use of the network, and some environments have strict requirements for using tight @@ -131,7 +149,8 @@ configure those ports. (random) Schedule tasks spark.executor.port - Akka-based. Set to "0" to choose a port randomly. + Akka-based. Set to "0" to choose a port randomly. Only used if Akka RPC backend is + configured. Executor @@ -139,7 +158,7 @@ configure those ports. (random) File server for files and jars spark.fileserver.port - Jetty-based + Jetty-based. Only used if Akka RPC backend is configured. Executor @@ -150,14 +169,6 @@ configure those ports. Jetty-based. Not used by TorrentBroadcast, which sends data through the block manager instead. - - Executor - Driver - (random) - Class file server - spark.replClassServer.port - Jetty-based. Only used in Spark shells. - Executor / Driver Executor / Driver diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 4f71fbc086cd..2fe9ec3542b2 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -152,7 +152,7 @@ You can optionally configure the cluster further by setting environment variable SPARK_DAEMON_MEMORY - Memory to allocate to the Spark master and worker daemons themselves (default: 512m). + Memory to allocate to the Spark master and worker daemons themselves (default: 1g). 
SPARK_DAEMON_JAVA_OPTS diff --git a/docs/sparkr.md b/docs/sparkr.md index 4385a4eeacd5..9ddd2eda3fe8 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -11,7 +11,8 @@ title: SparkR (R on Spark) SparkR is an R package that provides a light-weight frontend to use Apache Spark from R. In Spark {{site.SPARK_VERSION}}, SparkR provides a distributed data frame implementation that supports operations like selection, filtering, aggregation etc. (similar to R data frames, -[dplyr](https://github.com/hadley/dplyr)) but on large datasets. +[dplyr](https://github.com/hadley/dplyr)) but on large datasets. SparkR also supports distributed +machine learning using MLlib. # SparkR DataFrames @@ -28,13 +29,65 @@ All of the examples on this page use sample data included in R or the Spark dist The entry point into SparkR is the `SparkContext` which connects your R program to a Spark cluster. You can create a `SparkContext` using `sparkR.init` and pass in options such as the application name , any spark packages depended on, etc. Further, to work with DataFrames we will need a `SQLContext`, -which can be created from the SparkContext. If you are working from the SparkR shell, the -`SQLContext` and `SparkContext` should already be created for you. +which can be created from the SparkContext. If you are working from the `sparkR` shell, the +`SQLContext` and `SparkContext` should already be created for you, and you would not need to call +`sparkR.init`. +

      {% highlight r %} sc <- sparkR.init() sqlContext <- sparkRSQL.init(sc) {% endhighlight %} +
+ +## Starting Up from RStudio + +You can also start SparkR from RStudio. You can connect your R program to a Spark cluster from +RStudio, R shell, Rscript or other R IDEs. To start, make sure SPARK_HOME is set in the environment +(you can check [Sys.getenv](https://stat.ethz.ch/R-manual/R-devel/library/base/html/Sys.getenv.html)), +load the SparkR package, and call `sparkR.init` as below. In addition to calling `sparkR.init`, you +could also specify certain Spark driver properties. Normally these +[Application properties](configuration.html#application-properties) and +[Runtime Environment](configuration.html#runtime-environment) properties cannot be set programmatically, because +the driver JVM process has already been started by then; in this case, SparkR takes care of this for you. To set +them, pass them as you would other configuration properties in the `sparkEnvir` argument to +`sparkR.init()`. + +
      +{% highlight r %} +if (nchar(Sys.getenv("SPARK_HOME")) < 1) { + Sys.setenv(SPARK_HOME = "/home/spark") +} +library(SparkR, lib.loc = c(file.path(Sys.getenv("SPARK_HOME"), "R", "lib"))) +sc <- sparkR.init(master = "local[*]", sparkEnvir = list(spark.driver.memory="2g")) +{% endhighlight %} +
      + +The following options can be set in `sparkEnvir` with `sparkR.init` from RStudio: + + + + + + + + + + + + + + + + + + + + + + + +
      Property NameProperty groupspark-submit equivalent
      spark.driver.memoryApplication Properties--driver-memory
      spark.driver.extraClassPathRuntime Environment--driver-class-path
      spark.driver.extraJavaOptionsRuntime Environment--driver-java-options
      spark.driver.extraLibraryPathRuntime Environment--driver-library-path
    @@ -42,11 +95,11 @@ sqlContext <- sparkRSQL.init(sc) With a `SQLContext`, applications can create `DataFrame`s from a local R data frame, from a [Hive table](sql-programming-guide.html#hive-tables), or from other [data sources](sql-programming-guide.html#data-sources). ### From local data frames -The simplest way to create a data frame is to convert a local R data frame into a SparkR DataFrame. Specifically we can use `createDataFrame` and pass in the local R data frame to create a SparkR DataFrame. As an example, the following creates a `DataFrame` based using the `faithful` dataset from R. +The simplest way to create a data frame is to convert a local R data frame into a SparkR DataFrame. Specifically we can use `createDataFrame` and pass in the local R data frame to create a SparkR DataFrame. As an example, the following creates a `DataFrame` based using the `faithful` dataset from R.
    {% highlight r %} -df <- createDataFrame(sqlContext, faithful) +df <- createDataFrame(sqlContext, faithful) # Displays the content of the DataFrame to stdout head(df) @@ -95,7 +148,7 @@ printSchema(people)
The data sources API can also be used to save out DataFrames into multiple file formats. For example we can save the DataFrame from the previous example -to a Parquet file using `write.df` +to a Parquet file using `write.df`. (Until Spark 1.6, the default mode for writes was `append`. It was changed in Spark 1.7 to `error` to match the Scala API.)
    {% highlight r %} @@ -138,7 +191,7 @@ Here we include some basic examples and a complete list can be found in the [API
    {% highlight r %} # Create the DataFrame -df <- createDataFrame(sqlContext, faithful) +df <- createDataFrame(sqlContext, faithful) # Get basic information about the DataFrame df @@ -151,7 +204,7 @@ head(select(df, df$eruptions)) ##2 1.800 ##3 3.333 -# You can also pass in column name as strings +# You can also pass in column name as strings head(select(df, "eruptions")) # Filter the DataFrame to only retain rows with wait times shorter than 50 mins @@ -165,7 +218,7 @@ head(filter(df, df$waiting < 50))
    -### Grouping, Aggregation +### Grouping, Aggregation SparkR data frames support a number of commonly used functions to aggregate data after grouping. For example we can compute a histogram of the `waiting` time in the `faithful` dataset as shown below @@ -193,7 +246,7 @@ head(arrange(waiting_counts, desc(waiting_counts$count))) ### Operating on Columns -SparkR also provides a number of functions that can directly applied to columns for data processing and during aggregation. The example below shows the use of basic arithmetic functions. +SparkR also provides a number of functions that can directly applied to columns for data processing and during aggregation. The example below shows the use of basic arithmetic functions.
    {% highlight r %} @@ -230,3 +283,114 @@ head(teenagers) {% endhighlight %}
+ +# Machine Learning + +SparkR allows the fitting of generalized linear models over DataFrames using the [glm()](api/R/glm.html) function. Under the hood, SparkR uses MLlib to train a model of the specified family. Currently the gaussian and binomial families are supported. We support a subset of the available R formula operators for model fitting, including '~', '.', ':', '+', and '-'. + +The [summary()](api/R/summary.html) function gives the summary of a model produced by [glm()](api/R/glm.html). + +* For a gaussian GLM model, it returns a list with 'devianceResiduals' and 'coefficients' components. The 'devianceResiduals' gives the min/max deviance residuals of the estimation; the 'coefficients' gives the estimated coefficients and their estimated standard errors, t values and p-values. (These are only available when the model is fitted by the normal solver.) +* For a binomial GLM model, it returns a list with a 'coefficients' component, which gives the estimated coefficients. + +The examples below show how to build gaussian and binomial GLM models using SparkR. + +## Gaussian GLM model +
+{% highlight r %} +# Create the DataFrame +df <- createDataFrame(sqlContext, iris) + +# Fit a gaussian GLM model over the dataset. +model <- glm(Sepal_Length ~ Sepal_Width + Species, data = df, family = "gaussian") + +# The model summary is returned in a similar format to R's native glm(). +summary(model) +##$devianceResiduals +## Min Max +## -1.307112 1.412532 +## +##$coefficients +## Estimate Std. Error t value Pr(>|t|) +##(Intercept) 2.251393 0.3697543 6.08889 9.568102e-09 +##Sepal_Width 0.8035609 0.106339 7.556598 4.187317e-12 +##Species_versicolor 1.458743 0.1121079 13.01195 0 +##Species_virginica 1.946817 0.100015 19.46525 0 + +# Make predictions based on the model. +predictions <- predict(model, newData = df) +head(select(predictions, "Sepal_Length", "prediction")) +## Sepal_Length prediction +##1 5.1 5.063856 +##2 4.9 4.662076 +##3 4.7 4.822788 +##4 4.6 4.742432 +##5 5.0 5.144212 +##6 5.4 5.385281 +{% endhighlight %} +
    + +## Binomial GLM model + +
    +{% highlight r %} +# Create the DataFrame +df <- createDataFrame(sqlContext, iris) +training <- filter(df, df$Species != "setosa") + +# Fit a binomial GLM model over the dataset. +model <- glm(Species ~ Sepal_Length + Sepal_Width, data = training, family = "binomial") + +# Model coefficients are returned in a similar format to R's native glm(). +summary(model) +##$coefficients +## Estimate +##(Intercept) -13.046005 +##Sepal_Length 1.902373 +##Sepal_Width 0.404655 +{% endhighlight %} +
    + +# R Function Name Conflicts + +When loading and attaching a new package in R, it is possible to have a name [conflict](https://stat.ethz.ch/R-manual/R-devel/library/base/html/library.html), where a +function is masking another function. + +The following functions are masked by the SparkR package: + + + + + + + + + + + + + + + + + + + +
    Masked functionHow to Access
    cov in package:stats
    stats::cov(x, y = NULL, use = "everything",
    +           method = c("pearson", "kendall", "spearman"))
    filter in package:stats
    stats::filter(x, filter, method = c("convolution", "recursive"),
    +              sides = 2, circular = FALSE, init)
    sample in package:basebase::sample(x, size, replace = FALSE, prob = NULL)
    table in package:base
    base::table(...,
    +            exclude = if (useNA == "no") c(NA, NaN),
    +            useNA = c("no", "ifany", "always"),
    +            dnn = list.names(...), deparse.level = 1)
    + +Since part of SparkR is modeled on the `dplyr` package, certain functions in SparkR share the same names with those in `dplyr`. Depending on the load order of the two packages, some functions from the package loaded first are masked by those in the package loaded after. In such case, prefix such calls with the package name, for instance, `SparkR::cume_dist(x)` or `dplyr::cume_dist(x)`. + +You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-manual/R-devel/library/base/html/search.html) + + +# Migration Guide + +## Upgrading From SparkR 1.6 to 1.7 + + - Until Spark 1.6, the default mode for writes was `append`. It was changed in Spark 1.7 to `error` to match the Scala API. diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 3ea77e82422f..3f9a831eddc8 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1,6 +1,6 @@ --- layout: global -displayTitle: Spark SQL and DataFrame Guide +displayTitle: Spark SQL, DataFrames and Datasets Guide title: Spark SQL and DataFrames --- @@ -9,18 +9,51 @@ title: Spark SQL and DataFrames # Overview -Spark SQL is a Spark module for structured data processing. It provides a programming abstraction called DataFrames and can also act as distributed SQL query engine. +Spark SQL is a Spark module for structured data processing. Unlike the basic Spark RDD API, the interfaces provided +by Spark SQL provide Spark with more information about the structure of both the data and the computation being performed. Internally, +Spark SQL uses this extra information to perform extra optimizations. There are several ways to +interact with Spark SQL including SQL, the DataFrames API and the Datasets API. When computing a result +the same execution engine is used, independent of which API/language you are using to express the +computation. This unification means that developers can easily switch back and forth between the +various APIs based on which provides the most natural way to express a given transformation. -For how to enable Hive support, please refer to the [Hive Tables](#hive-tables) section. +All of the examples on this page use sample data included in the Spark distribution and can be run in +the `spark-shell`, `pyspark` shell, or `sparkR` shell. -# DataFrames +## SQL -A DataFrame is a distributed collection of data organized into named columns. It is conceptually equivalent to a table in a relational database or a data frame in R/Python, but with richer optimizations under the hood. DataFrames can be constructed from a wide array of sources such as: structured data files, tables in Hive, external databases, or existing RDDs. +One use of Spark SQL is to execute SQL queries written using either a basic SQL syntax or HiveQL. +Spark SQL can also be used to read data from an existing Hive installation. For more on how to +configure this feature, please refer to the [Hive Tables](#hive-tables) section. When running +SQL from within another programming language the results will be returned as a [DataFrame](#DataFrames). +You can also interact with the SQL interface using the [command-line](#running-the-spark-sql-cli) +or over [JDBC/ODBC](#running-the-thrift-jdbcodbc-server). -The DataFrame API is available in [Scala](api/scala/index.html#org.apache.spark.sql.DataFrame), [Java](api/java/index.html?org/apache/spark/sql/DataFrame.html), [Python](api/python/pyspark.sql.html#pyspark.sql.DataFrame), and [R](api/R/index.html). 
+## DataFrames -All of the examples on this page use sample data included in the Spark distribution and can be run in the `spark-shell`, `pyspark` shell, or `sparkR` shell. +A DataFrame is a distributed collection of data organized into named columns. It is conceptually +equivalent to a table in a relational database or a data frame in R/Python, but with richer +optimizations under the hood. DataFrames can be constructed from a wide array of [sources](#data-sources) such +as: structured data files, tables in Hive, external databases, or existing RDDs. +The DataFrame API is available in [Scala](api/scala/index.html#org.apache.spark.sql.DataFrame), +[Java](api/java/index.html?org/apache/spark/sql/DataFrame.html), +[Python](api/python/pyspark.sql.html#pyspark.sql.DataFrame), and [R](api/R/index.html). + +## Datasets + +A Dataset is a new experimental interface added in Spark 1.6 that tries to provide the benefits of +RDDs (strong typing, ability to use powerful lambda functions) with the benefits of Spark SQL's +optimized execution engine. A Dataset can be [constructed](#creating-datasets) from JVM objects and then manipulated +using functional transformations (map, flatMap, filter, etc.). + +The unified Dataset API can be used both in [Scala](api/scala/index.html#org.apache.spark.sql.Dataset) and +[Java](api/java/index.html?org/apache/spark/sql/Dataset.html). Python does not yet have support for +the Dataset API, but due to its dynamic nature many of the benefits are already available (i.e. you can +access the field of a row by name naturally `row.columnName`). Full python support will be added +in a future release. + +# Getting Started ## Starting Point: SQLContext @@ -28,8 +61,8 @@ All of the examples on this page use sample data included in the Spark distribut
    The entry point into all functionality in Spark SQL is the -[`SQLContext`](api/scala/index.html#org.apache.spark.sql.`SQLContext`) class, or one of its -descendants. To create a basic `SQLContext`, all you need is a SparkContext. +[`SQLContext`](api/scala/index.html#org.apache.spark.sql.SQLContext) class, or one of its +descendants. To create a basic `SQLContext`, all you need is a SparkContext. {% highlight scala %} val sc: SparkContext // An existing SparkContext. @@ -45,7 +78,7 @@ import sqlContext.implicits._ The entry point into all functionality in Spark SQL is the [`SQLContext`](api/java/index.html#org.apache.spark.sql.SQLContext) class, or one of its -descendants. To create a basic `SQLContext`, all you need is a SparkContext. +descendants. To create a basic `SQLContext`, all you need is a SparkContext. {% highlight java %} JavaSparkContext sc = ...; // An existing JavaSparkContext. @@ -58,7 +91,7 @@ SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); The entry point into all relational functionality in Spark is the [`SQLContext`](api/python/pyspark.sql.html#pyspark.sql.SQLContext) class, or one -of its decedents. To create a basic `SQLContext`, all you need is a SparkContext. +of its decedents. To create a basic `SQLContext`, all you need is a SparkContext. {% highlight python %} from pyspark.sql import SQLContext @@ -70,7 +103,7 @@ sqlContext = SQLContext(sc)
    The entry point into all relational functionality in Spark is the -`SQLContext` class, or one of its decedents. To create a basic `SQLContext`, all you need is a SparkContext. +`SQLContext` class, or one of its decedents. To create a basic `SQLContext`, all you need is a SparkContext. {% highlight r %} sqlContext <- sparkRSQL.init(sc) @@ -82,18 +115,18 @@ sqlContext <- sparkRSQL.init(sc) In addition to the basic `SQLContext`, you can also create a `HiveContext`, which provides a superset of the functionality provided by the basic `SQLContext`. Additional features include the ability to write queries using the more complete HiveQL parser, access to Hive UDFs, and the -ability to read data from Hive tables. To use a `HiveContext`, you do not need to have an +ability to read data from Hive tables. To use a `HiveContext`, you do not need to have an existing Hive setup, and all of the data sources available to a `SQLContext` are still available. `HiveContext` is only packaged separately to avoid including all of Hive's dependencies in the default -Spark build. If these dependencies are not a problem for your application then using `HiveContext` -is recommended for the 1.3 release of Spark. Future releases will focus on bringing `SQLContext` up +Spark build. If these dependencies are not a problem for your application then using `HiveContext` +is recommended for the 1.3 release of Spark. Future releases will focus on bringing `SQLContext` up to feature parity with a `HiveContext`. The specific variant of SQL that is used to parse queries can also be selected using the -`spark.sql.dialect` option. This parameter can be changed using either the `setConf` method on -a `SQLContext` or by using a `SET key=value` command in SQL. For a `SQLContext`, the only dialect -available is "sql" which uses a simple SQL parser provided by Spark SQL. In a `HiveContext`, the -default is "hiveql", though "sql" is also available. Since the HiveQL parser is much more complete, +`spark.sql.dialect` option. This parameter can be changed using either the `setConf` method on +a `SQLContext` or by using a `SET key=value` command in SQL. For a `SQLContext`, the only dialect +available is "sql" which uses a simple SQL parser provided by Spark SQL. In a `HiveContext`, the +default is "hiveql", though "sql" is also available. Since the HiveQL parser is much more complete, this is recommended for most use cases. @@ -160,7 +193,7 @@ showDF(df) ## DataFrame Operations -DataFrames provide a domain-specific language for structured data manipulation in [Scala](api/scala/index.html#org.apache.spark.sql.DataFrame), [Java](api/java/index.html?org/apache/spark/sql/DataFrame.html), and [Python](api/python/pyspark.sql.html#pyspark.sql.DataFrame). +DataFrames provide a domain-specific language for structured data manipulation in [Scala](api/scala/index.html#org.apache.spark.sql.DataFrame), [Java](api/java/index.html?org/apache/spark/sql/DataFrame.html), [Python](api/python/pyspark.sql.html#pyspark.sql.DataFrame) and [R](api/R/DataFrame.html). Here we include some basic examples of structured data processing using DataFrames: @@ -213,6 +246,11 @@ df.groupBy("age").count().show() // 30 1 {% endhighlight %} +For a complete list of the types of operations that can be performed on a DataFrame refer to the [API Documentation](api/scala/index.html#org.apache.spark.sql.DataFrame). 
+ +In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/scala/index.html#org.apache.spark.sql.functions$). + +
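For instance, a small sketch (assuming `df` has `name` and `age` columns, as in the examples above) that combines a few of these functions in a single `select`:

{% highlight scala %}
import org.apache.spark.sql.functions._

// Upper-case a string column, do simple arithmetic, and round the result of a division.
df.select(
  upper(df("name")).as("name_upper"),
  (df("age") + 1).as("age_next_year"),
  round(df("age") / 10, 1).as("decades")
).show()
{% endhighlight %}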
    @@ -263,6 +301,10 @@ df.groupBy("age").count().show(); // 30 1 {% endhighlight %} +For a complete list of the types of operations that can be performed on a DataFrame refer to the [API Documentation](api/java/org/apache/spark/sql/DataFrame.html). + +In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/java/org/apache/spark/sql/functions.html). +
    @@ -320,6 +362,10 @@ df.groupBy("age").count().show() {% endhighlight %} +For a complete list of the types of operations that can be performed on a DataFrame refer to the [API Documentation](api/python/pyspark.sql.html#pyspark.sql.DataFrame). + +In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/python/pyspark.sql.html#module-pyspark.sql.functions). +
    @@ -370,10 +416,13 @@ showDF(count(groupBy(df, "age"))) {% endhighlight %} -
    +For a complete list of the types of operations that can be performed on a DataFrame refer to the [API Documentation](api/R/index.html). + +In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/R/index.html).
    +
    ## Running SQL Queries Programmatically @@ -382,14 +431,14 @@ The `sql` function on a `SQLContext` enables applications to run SQL queries pro
    {% highlight scala %} -val sqlContext = ... // An existing SQLContext +val sqlContext = ... // An existing SQLContext val df = sqlContext.sql("SELECT * FROM table") {% endhighlight %}
    {% highlight java %} -SQLContext sqlContext = ... // An existing SQLContext +SQLContext sqlContext = ... // An existing SQLContext DataFrame df = sqlContext.sql("SELECT * FROM table") {% endhighlight %}
    @@ -412,15 +461,54 @@ df <- sql(sqlContext, "SELECT * FROM table")
+## Creating Datasets + +Datasets are similar to RDDs; however, instead of using Java Serialization or Kryo, they use +a specialized [Encoder](api/scala/index.html#org.apache.spark.sql.Encoder) to serialize the objects +for processing or transmitting over the network. While both encoders and standard serialization are +responsible for turning an object into bytes, encoders are code-generated dynamically and use a format +that allows Spark to perform many operations like filtering, sorting and hashing without deserializing +the bytes back into an object. + +
    +
    + +{% highlight scala %} +// Encoders for most common types are automatically provided by importing sqlContext.implicits._ +val ds = Seq(1, 2, 3).toDS() +ds.map(_ + 1).collect() // Returns: Array(2, 3, 4) + +// Encoders are also created for case classes. +case class Person(name: String, age: Long) +val ds = Seq(Person("Andy", 32)).toDS() + +// DataFrames can be converted to a Dataset by providing a class. Mapping will be done by name. +val path = "examples/src/main/resources/people.json" +val people = sqlContext.read.json(path).as[Person] + +{% endhighlight %} + +
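As a further sketch, the typed `people` Dataset created above can then be manipulated with the functional transformations mentioned earlier (the age threshold is an arbitrary example):

{% highlight scala %}
// Typed filter and map; the String encoder comes from sqlContext.implicits._ as above.
val adultNames = people.filter(_.age > 21).map(_.name)
adultNames.collect().foreach(println)
{% endhighlight %}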
    + +
    + +{% highlight java %} +JavaSparkContext sc = ...; // An existing JavaSparkContext. +SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); +{% endhighlight %} + +
    +
    + ## Interoperating with RDDs -Spark SQL supports two different methods for converting existing RDDs into DataFrames. The first -method uses reflection to infer the schema of an RDD that contains specific types of objects. This +Spark SQL supports two different methods for converting existing RDDs into DataFrames. The first +method uses reflection to infer the schema of an RDD that contains specific types of objects. This reflection based approach leads to more concise code and works well when you already know the schema while writing your Spark application. The second method for creating DataFrames is through a programmatic interface that allows you to -construct a schema and then apply it to an existing RDD. While this method is more verbose, it allows +construct a schema and then apply it to an existing RDD. While this method is more verbose, it allows you to construct DataFrames when the columns and their types are not known until runtime. ### Inferring the Schema Using Reflection @@ -429,11 +517,11 @@ you to construct DataFrames when the columns and their types are not known until
    The Scala interface for Spark SQL supports automatically converting an RDD containing case classes -to a DataFrame. The case class -defines the schema of the table. The names of the arguments to the case class are read using +to a DataFrame. The case class +defines the schema of the table. The names of the arguments to the case class are read using reflection and become the names of the columns. Case classes can also be nested or contain complex types such as Sequences or Arrays. This RDD can be implicitly converted to a DataFrame and then be -registered as a table. Tables can be used in subsequent SQL statements. +registered as a table. Tables can be used in subsequent SQL statements. {% highlight scala %} // sc is an existing SparkContext. @@ -470,9 +558,9 @@ teenagers.map(_.getValuesMap[Any](List("name", "age"))).collect().foreach(printl
    Spark SQL supports automatically converting an RDD of [JavaBeans](http://stackoverflow.com/questions/3295496/what-is-a-javabean-exactly) -into a DataFrame. The BeanInfo, obtained using reflection, defines the schema of the table. +into a DataFrame. The BeanInfo, obtained using reflection, defines the schema of the table. Currently, Spark SQL does not support JavaBeans that contain -nested or contain complex types such as Lists or Arrays. You can create a JavaBean by creating a +nested or contain complex types such as Lists or Arrays. You can create a JavaBean by creating a class that implements Serializable and has getters and setters for all of its fields. {% highlight java %} @@ -543,9 +631,9 @@ List teenagerNames = teenagers.javaRDD().map(new Function()
    -Spark SQL can convert an RDD of Row objects to a DataFrame, inferring the datatypes. Rows are constructed by passing a list of +Spark SQL can convert an RDD of Row objects to a DataFrame, inferring the datatypes. Rows are constructed by passing a list of key/value pairs as kwargs to the Row class. The keys of this list define the column names of the table, -and the types are inferred by looking at the first row. Since we currently only look at the first +and the types are inferred by looking at the first row. Since we currently only look at the first row, it is important that there is no missing data in the first row of the RDD. In future versions we plan to more completely infer the schema by looking at more data, similar to the inference that is performed on JSON files. @@ -764,7 +852,7 @@ for name in names.collect(): Spark SQL supports operating on a variety of data sources through the `DataFrame` interface. A DataFrame can be operated on as normal RDDs and can also be registered as a temporary table. -Registering a DataFrame as a table allows you to run SQL queries over its data. This section +Registering a DataFrame as a table allows you to run SQL queries over its data. This section describes the general methods for loading and saving data using the Spark Data Sources and then goes into specific options that are available for the built-in data sources. @@ -818,9 +906,9 @@ saveDF(select(df, "name", "age"), "namesAndAges.parquet") ### Manually Specifying Options You can also manually specify the data source that will be used along with any extra options -that you would like to pass to the data source. Data sources are specified by their fully qualified +that you would like to pass to the data source. Data sources are specified by their fully qualified name (i.e., `org.apache.spark.sql.parquet`), but for built-in sources you can also use their short -names (`json`, `parquet`, `jdbc`). DataFrames of any type can be converted into other types +names (`json`, `parquet`, `jdbc`). DataFrames of any type can be converted into other types using this syntax.
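For example, a brief Scala sketch of specifying the format explicitly, using the sample files shipped with the Spark distribution:

{% highlight scala %}
// Load a JSON file by naming its format, then save selected columns out as Parquet.
val df = sqlContext.read.format("json").load("examples/src/main/resources/people.json")
df.select("name", "age").write.format("parquet").save("namesAndAges.parquet")
{% endhighlight %}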
    @@ -866,16 +954,53 @@ saveDF(select(df, "name", "age"), "namesAndAges.parquet", "parquet")
+### Run SQL on files directly + +Instead of using the read API to load a file into a DataFrame and then querying it, you can also query that +file directly with SQL. + +
    +
    + +{% highlight scala %} +val df = sqlContext.sql("SELECT * FROM parquet.`examples/src/main/resources/users.parquet`") +{% endhighlight %} + +
    + +
    + +{% highlight java %} +DataFrame df = sqlContext.sql("SELECT * FROM parquet.`examples/src/main/resources/users.parquet`"); +{% endhighlight %} +
    + +
    + +{% highlight python %} +df = sqlContext.sql("SELECT * FROM parquet.`examples/src/main/resources/users.parquet`") +{% endhighlight %} + +
    + +
    + +{% highlight r %} +df <- sql(sqlContext, "SELECT * FROM parquet.`examples/src/main/resources/users.parquet`") +{% endhighlight %} + +
    +
+ ### Save Modes Save operations can optionally take a `SaveMode`, that specifies how to handle existing data if -present. It is important to realize that these save modes do not utilize any locking and are not -atomic. Thus, it is not safe to have multiple writers attempting to write to the same location. -Additionally, when performing a `Overwrite`, the data will be deleted before writing out the +present. It is important to realize that these save modes do not utilize any locking and are not +atomic. Additionally, when performing an `Overwrite`, the data will be deleted before writing out the new data. - + @@ -907,7 +1032,7 @@ new data.
    Scala/JavaPythonMeaning
    Scala/JavaAny LanguageMeaning
    SaveMode.ErrorIfExists (default) "error" (default) Ignore mode means that when saving a DataFrame to a data source, if data already exists, the save operation is expected to not save the contents of the DataFrame and to not - change the existing data. This is similar to a CREATE TABLE IF NOT EXISTS in SQL. + change the existing data. This is similar to a CREATE TABLE IF NOT EXISTS in SQL.
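As a short sketch, a save mode can be passed either as the `SaveMode` enum (Scala/Java) or as its string form from any language; the output path below is a placeholder:

{% highlight scala %}
import org.apache.spark.sql.SaveMode

// Replace any existing data at the (placeholder) target path.
df.write.mode(SaveMode.Overwrite).parquet("path/to/people_overwritten.parquet")

// Equivalent using the string form of the mode.
df.write.mode("overwrite").parquet("path/to/people_overwritten.parquet")
{% endhighlight %}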
    @@ -915,21 +1040,22 @@ new data. ### Saving to Persistent Tables When working with a `HiveContext`, `DataFrames` can also be saved as persistent tables using the -`saveAsTable` command. Unlike the `registerTempTable` command, `saveAsTable` will materialize the -contents of the dataframe and create a pointer to the data in the HiveMetastore. Persistent tables +`saveAsTable` command. Unlike the `registerTempTable` command, `saveAsTable` will materialize the +contents of the dataframe and create a pointer to the data in the HiveMetastore. Persistent tables will still exist even after your Spark program has restarted, as long as you maintain your connection -to the same metastore. A DataFrame for a persistent table can be created by calling the `table` +to the same metastore. A DataFrame for a persistent table can be created by calling the `table` method on a `SQLContext` with the name of the table. By default `saveAsTable` will create a "managed table", meaning that the location of the data will -be controlled by the metastore. Managed tables will also have their data deleted automatically +be controlled by the metastore. Managed tables will also have their data deleted automatically when a table is dropped. ## Parquet Files [Parquet](http://parquet.io) is a columnar format that is supported by many other data processing systems. Spark SQL provides support for both reading and writing Parquet files that automatically preserves the schema -of the original data. +of the original data. When writing Parquet files, all columns are automatically converted to be nullable for +compatibility reasons. ### Loading Data Programmatically @@ -949,7 +1075,7 @@ val people: RDD[Person] = ... // An RDD of case class objects, from the previous // The RDD is implicitly converted to a DataFrame by implicits, allowing it to be stored using Parquet. people.write.parquet("people.parquet") -// Read in the parquet file created above. Parquet files are self-describing so the schema is preserved. +// Read in the parquet file created above. Parquet files are self-describing so the schema is preserved. // The result of loading a Parquet file is also a DataFrame. val parquetFile = sqlContext.read.parquet("people.parquet") @@ -971,7 +1097,7 @@ DataFrame schemaPeople = ... // The DataFrame from the previous example. // DataFrames can be saved as Parquet files, maintaining the schema information. schemaPeople.write().parquet("people.parquet"); -// Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. +// Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. // The result of loading a parquet file is also a DataFrame. DataFrame parquetFile = sqlContext.read().parquet("people.parquet"); @@ -997,7 +1123,7 @@ schemaPeople # The DataFrame from the previous example. # DataFrames can be saved as Parquet files, maintaining the schema information. schemaPeople.write.parquet("people.parquet") -# Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. +# Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. # The result of loading a parquet file is also a DataFrame. parquetFile = sqlContext.read.parquet("people.parquet") @@ -1021,7 +1147,7 @@ schemaPeople # The DataFrame from the previous example. # DataFrames can be saved as Parquet files, maintaining the schema information. 
saveAsParquetFile(schemaPeople, "people.parquet") -# Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. +# Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. # The result of loading a parquet file is also a DataFrame. parquetFile <- parquetFile(sqlContext, "people.parquet") @@ -1036,15 +1162,6 @@ for (teenName in collect(teenNames)) {
    -
    - -{% highlight python %} -# sqlContext is an existing HiveContext -sqlContext.sql("REFRESH TABLE my_table") -{% endhighlight %} - -
    -
    {% highlight sql %} @@ -1065,10 +1182,10 @@ SELECT * FROM parquetTable ### Partition Discovery -Table partitioning is a common optimization approach used in systems like Hive. In a partitioned +Table partitioning is a common optimization approach used in systems like Hive. In a partitioned table, data are usually stored in different directories, with partitioning column values encoded in -the path of each partition directory. The Parquet data source is now able to discover and infer -partitioning information automatically. For example, we can store all our previously used +the path of each partition directory. The Parquet data source is now able to discover and infer +partitioning information automatically. For example, we can store all our previously used population data into a partitioned table using the following directory structure, with two extra columns, `gender` and `country` as partitioning columns: @@ -1110,20 +1227,34 @@ root {% endhighlight %} -Notice that the data types of the partitioning columns are automatically inferred. Currently, +Notice that the data types of the partitioning columns are automatically inferred. Currently, numeric data types and string type are supported. Sometimes users may not want to automatically infer the data types of the partitioning columns. For these use cases, the automatic type inference can be configured by `spark.sql.sources.partitionColumnTypeInference.enabled`, which is default to `true`. When type inference is disabled, string type will be used for the partitioning columns. +Starting from Spark 1.6.0, partition discovery only finds partitions under the given paths +by default. For the above example, if users pass `path/to/table/gender=male` to either +`SQLContext.read.parquet` or `SQLContext.read.load`, `gender` will not be considered as a +partitioning column. If users need to specify the base path that partition discovery +should start with, they can set `basePath` in the data source options. For example, +when `path/to/table/gender=male` is the path of the data and +users set `basePath` to `path/to/table/`, `gender` will be a partitioning column. ### Schema Merging -Like ProtocolBuffer, Avro, and Thrift, Parquet also supports schema evolution. Users can start with -a simple schema, and gradually add more columns to the schema as needed. In this way, users may end -up with multiple Parquet files with different but mutually compatible schemas. The Parquet data +Like ProtocolBuffer, Avro, and Thrift, Parquet also supports schema evolution. Users can start with +a simple schema, and gradually add more columns to the schema as needed. In this way, users may end +up with multiple Parquet files with different but mutually compatible schemas. The Parquet data source is now able to automatically detect this case and merge schemas of all these files. +Since schema merging is a relatively expensive operation, and is not a necessity in most cases, we +turned it off by default starting from 1.5.0. You may enable it by + +1. setting data source option `mergeSchema` to `true` when reading Parquet files (as shown in the + examples below), or +2. setting the global SQL option `spark.sql.parquet.mergeSchema` to `true`. +
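Before moving on to the merged-schema examples below, here is a minimal sketch of the partition-discovery `basePath` option described above; the paths are the illustrative ones from that example and an existing `sqlContext` is assumed.

{% highlight scala %}
// Only data under gender=male is read, but because basePath points at the table root,
// `gender` is still discovered as a partitioning column.
val maleOnly = sqlContext.read
  .option("basePath", "path/to/table/")
  .parquet("path/to/table/gender=male")
maleOnly.printSchema()
{% endhighlight %}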
    @@ -1143,7 +1274,7 @@ val df2 = sc.makeRDD(6 to 10).map(i => (i, i * 3)).toDF("single", "triple") df2.write.parquet("data/test_table/key=2") // Read the partitioned table -val df3 = sqlContext.read.parquet("data/test_table") +val df3 = sqlContext.read.option("mergeSchema", "true").parquet("data/test_table") df3.printSchema() // The final schema consists of all 3 columns in the Parquet files together @@ -1165,16 +1296,16 @@ df3.printSchema() # Create a simple DataFrame, stored into a partition directory df1 = sqlContext.createDataFrame(sc.parallelize(range(1, 6))\ .map(lambda i: Row(single=i, double=i * 2))) -df1.save("data/test_table/key=1", "parquet") +df1.write.parquet("data/test_table/key=1") # Create another DataFrame in a new partition directory, # adding a new column and dropping an existing column df2 = sqlContext.createDataFrame(sc.parallelize(range(6, 11)) .map(lambda i: Row(single=i, triple=i * 3))) -df2.save("data/test_table/key=2", "parquet") +df2.write.parquet("data/test_table/key=2") # Read the partitioned table -df3 = sqlContext.load("data/test_table", "parquet") +df3 = sqlContext.read.option("mergeSchema", "true").parquet("data/test_table") df3.printSchema() # The final schema consists of all 3 columns in the Parquet files together @@ -1201,7 +1332,7 @@ saveDF(df1, "data/test_table/key=1", "parquet", "overwrite") saveDF(df2, "data/test_table/key=2", "parquet", "overwrite") # Read the partitioned table -df3 <- loadDF(sqlContext, "data/test_table", "parquet") +df3 <- loadDF(sqlContext, "data/test_table", "parquet", mergeSchema="true") printSchema(df3) # The final schema consists of all 3 columns in the Parquet files together @@ -1232,10 +1363,10 @@ processing. 1. Hive considers all columns nullable, while nullability in Parquet is significant Due to this reason, we must reconcile Hive metastore schema with Parquet schema when converting a -Hive metastore Parquet table to a Spark SQL Parquet table. The reconciliation rules are: +Hive metastore Parquet table to a Spark SQL Parquet table. The reconciliation rules are: 1. Fields that have the same name in both schema must have the same data type regardless of - nullability. The reconciled field should have the data type of the Parquet side, so that + nullability. The reconciled field should have the data type of the Parquet side, so that nullability is respected. 1. The reconciled schema contains exactly those fields defined in Hive metastore schema. @@ -1246,8 +1377,8 @@ Hive metastore Parquet table to a Spark SQL Parquet table. The reconciliation r #### Metadata Refreshing -Spark SQL caches Parquet metadata for better performance. When Hive metastore Parquet table -conversion is enabled, metadata of those converted tables are also cached. If these tables are +Spark SQL caches Parquet metadata for better performance. When Hive metastore Parquet table +conversion is enabled, metadata of those converted tables are also cached. If these tables are updated by Hive or other external tools, you need to refresh them manually to ensure consistent metadata. @@ -1301,7 +1432,7 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext` spark.sql.parquet.binaryAsString false - Some other Parquet-producing systems, in particular Impala and older versions of Spark SQL, do + Some other Parquet-producing systems, in particular Impala, Hive, and older versions of Spark SQL, do not differentiate between binary data and strings when writing out the Parquet schema. 
This flag tells Spark SQL to interpret binary data as a string to provide compatibility with these systems. @@ -1310,8 +1441,7 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext` spark.sql.parquet.int96AsTimestamp true - Some Parquet-producing systems, in particular Impala, store Timestamp into INT96. Spark would also - store Timestamp as INT96 because we need to avoid precision lost of the nanoseconds field. This + Some Parquet-producing systems, in particular Impala and Hive, store Timestamp into INT96. This flag tells Spark SQL to interpret INT96 data as a timestamp to provide compatibility with these systems. @@ -1349,12 +1479,15 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext`

The output committer class used by Parquet. The specified class needs to be a subclass of
- org.apache.hadoop.mapreduce.OutputCommitter. Typically, it's also a
+ org.apache.hadoop.mapreduce.OutputCommitter. Typically, it's also a
subclass of org.apache.parquet.hadoop.ParquetOutputCommitter.

    Note:

      +
    • + This option is automatically ignored if spark.speculation is turned on. +
    • This option must be set via Hadoop Configuration rather than Spark SQLConf. @@ -1371,6 +1504,16 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext`

      + + spark.sql.parquet.mergeSchema + false + +

+ When true, the Parquet data source merges schemas collected from all data files; otherwise, the + schema is picked from the summary file or a random data file if no summary file is available. +
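As a brief, hedged sketch of the `setConf` route mentioned above for these Parquet options (assuming an existing `sqlContext`):

{% highlight scala %}
// Globally equivalent to passing the mergeSchema data source option on every Parquet read.
sqlContext.setConf("spark.sql.parquet.mergeSchema", "true")
{% endhighlight %}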

      + + ## JSON Datasets @@ -1550,8 +1693,9 @@ This command builds a new assembly jar that includes Hive. Note that this Hive a on all of the worker nodes, as they will need access to the Hive serialization and deserialization libraries (SerDes) in order to access data stored in Hive. -Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. Please note when running -the query on a YARN cluster (`yarn-cluster` mode), the `datanucleus` jars under the `lib_managed/jars` directory +Configuration of Hive is done by placing your `hive-site.xml`, `core-site.xml` (for security configuration), + `hdfs-site.xml` (for HDFS configuration) file in `conf/`. Please note when running +the query on a YARN cluster (`cluster` mode), the `datanucleus` jars under the `lib_managed/jars` directory and `hive-site.xml` under `conf/` directory need to be available on the driver and all executors launched by the YARN cluster. The convenient way to do this is adding them through the `--jars` option and `--file` option of the `spark-submit` command. @@ -1563,9 +1707,11 @@ YARN cluster. The convenient way to do this is adding them through the `--jars` When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and adds support for finding tables in the MetaStore and writing queries using HiveQL. Users who do -not have an existing Hive deployment can still create a `HiveContext`. When not configured by the -hive-site.xml, the context automatically creates `metastore_db` and `warehouse` in the current -directory. +not have an existing Hive deployment can still create a `HiveContext`. When not configured by the +hive-site.xml, the context automatically creates `metastore_db` in the current directory and +creates `warehouse` directory indicated by HiveConf, which defaults to `/user/hive/warehouse`. +Note that you may need to grant write privilege on `/user/hive/warehouse` to the user who starts +the spark application. {% highlight scala %} // sc is an existing SparkContext. @@ -1642,21 +1788,21 @@ results <- collect(sql(sqlContext, "FROM src SELECT key, value")) ### Interacting with Different Versions of Hive Metastore One of the most important pieces of Spark SQL's Hive support is interaction with Hive metastore, -which enables Spark SQL to access metadata of Hive tables. Starting from Spark 1.4.0, a single binary build of Spark SQL can be used to query different versions of Hive metastores, using the configuration described below. +which enables Spark SQL to access metadata of Hive tables. Starting from Spark 1.4.0, a single binary +build of Spark SQL can be used to query different versions of Hive metastores, using the configuration described below. +Note that independent of the version of Hive that is being used to talk to the metastore, internally Spark SQL +will compile against Hive 1.2.1 and use those classes for internal execution (serdes, UDFs, UDAFs, etc). -Internally, Spark SQL uses two Hive clients, one for executing native Hive commands like `SET` -and `DESCRIBE`, the other dedicated for communicating with Hive metastore. The former uses Hive -jars of version 0.13.1, which are bundled with Spark 1.4.0. The latter uses Hive jars of the -version specified by users. An isolated classloader is used here to avoid dependency conflicts. +The following options can be used to configure the version of Hive that is used to retrieve metadata: - + @@ -1667,12 +1813,16 @@ version specified by users. 
An isolated classloader is used here to avoid dependency conflicts. The property can be one of three options:
      1. builtin
      2. - Use Hive 0.13.1, which is bundled with the Spark assembly jar when -Phive is + Use Hive 1.2.1, which is bundled with the Spark assembly jar when -Phive is enabled. When this option is chosen, spark.sql.hive.metastore.version must be - either 0.13.1 or not defined. + either 1.2.1 or not defined.
      3. maven
      4. - Use Hive jars of specified version downloaded from Maven repositories. -
      5. A classpath in the standard format for both Hive and Hadoop.
      6. + Use Hive jars of specified version downloaded from Maven repositories. This configuration + is not generally recommended for production deployments. +
7. A classpath in the standard format for the JVM. This classpath must include all of Hive + and its dependencies, including the correct version of Hadoop. These jars only need to be + present on the driver, but if you are running in yarn cluster mode then you must ensure + they are packaged with your application.
      @@ -1705,7 +1855,7 @@ version specified by users. An isolated classloader is used here to avoid depend ## JDBC To Other Databases -Spark SQL also includes a data source that can read data from other databases using JDBC. This +Spark SQL also includes a data source that can read data from other databases using JDBC. This functionality should be preferred over using [JdbcRDD](api/scala/index.html#org.apache.spark.rdd.JdbcRDD). This is because the results are returned as a DataFrame and they can easily be processed in Spark SQL or joined with other data sources. @@ -1715,7 +1865,7 @@ provide a ClassTag. run queries using Spark SQL). To get started you will need to include the JDBC driver for you particular database on the -spark classpath. For example, to connect to postgres from the Spark Shell you would run the +spark classpath. For example, to connect to postgres from the Spark Shell you would run the following command: {% highlight bash %} @@ -1723,7 +1873,7 @@ SPARK_CLASSPATH=postgresql-9.3-1102-jdbc41.jar bin/spark-shell {% endhighlight %} Tables from the remote database can be loaded as a DataFrame or Spark SQL Temporary table using -the Data Sources API. The following options are supported: +the Data Sources API. The following options are supported:
      Property NameDefaultMeaning
      spark.sql.hive.metastore.version0.13.11.2.1 Version of the Hive metastore. Available - options are 0.12.0 and 0.13.1. Support for more versions is coming in the future. + options are 0.12.0 through 1.2.1.
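A hedged sketch of pinning the metastore client version before the `HiveContext` is created; the companion `spark.sql.hive.metastore.jars` key and the chosen values are assumptions for illustration, and in practice these are normally set in `spark-defaults.conf` or via `--conf`.

{% highlight scala %}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.hive.HiveContext

val conf = new SparkConf()
  .setAppName("hive-metastore-version-example")
  .set("spark.sql.hive.metastore.version", "0.13.1") // talk to an older metastore
  .set("spark.sql.hive.metastore.jars", "maven")     // fetch matching client jars (not recommended for production)
val sc = new SparkContext(conf)
val sqlContext = new HiveContext(sc)
{% endhighlight %}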
      @@ -1736,8 +1886,8 @@ the Data Sources API. The following options are supported: @@ -1745,15 +1895,16 @@ the Data Sources API. The following options are supported: + + + + + +
      Property NameMeaning
      dbtable - The JDBC table that should be read. Note that anything that is valid in a FROM clause of - a SQL query can be used. For example, instead of a full table you could also use a + The JDBC table that should be read. Note that anything that is valid in a FROM clause of + a SQL query can be used. For example, instead of a full table you could also use a subquery in parentheses.
driver - The class name of the JDBC driver needed to connect to this URL. This class will be loaded + The class name of the JDBC driver needed to connect to this URL. This class will be loaded on the master and workers before running any JDBC commands to allow the driver to register itself with the JDBC subsystem.
      partitionColumn, lowerBound, upperBound, numPartitions - These options must all be specified if any of them is specified. They describe how to + These options must all be specified if any of them is specified. They describe how to partition the table when reading in parallel from multiple workers. partitionColumn must be a numeric column from the table in question. Notice that lowerBound and upperBound are just used to decide the @@ -1761,6 +1912,13 @@ the Data Sources API. The following options are supported: partitioned and returned.
fetchSize + The JDBC fetch size, which determines how many rows to fetch per round trip. This can help performance on JDBC drivers which default to a low fetch size (e.g., Oracle with 10 rows). +
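Pulling the options in this table together, a hedged Scala sketch; the connection URL, table name, and bounds are placeholders, and it assumes the driver jar is already on the classpath as described above.

{% highlight scala %}
val jdbcDF = sqlContext.read.format("jdbc").options(Map(
  "url" -> "jdbc:postgresql:dbserver",
  "dbtable" -> "schema.tablename",
  "partitionColumn" -> "id",   // must be a numeric column of the table
  "lowerBound" -> "1",         // bounds only determine partition stride
  "upperBound" -> "1000000",
  "numPartitions" -> "10",
  "fetchSize" -> "1000")).load()
{% endhighlight %}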
      @@ -1768,7 +1926,7 @@ the Data Sources API. The following options are supported:
      {% highlight scala %} -val jdbcDF = sqlContext.read.format("jdbc").options( +val jdbcDF = sqlContext.read.format("jdbc").options( Map("url" -> "jdbc:postgresql:dbserver", "dbtable" -> "schema.tablename")).load() {% endhighlight %} @@ -1859,7 +2017,7 @@ Configuration of in-memory caching can be done using the `setConf` method on `SQ spark.sql.inMemoryColumnarStorage.batchSize 10000 - Controls the size of batches for columnar caching. Larger batch sizes can improve memory utilization + Controls the size of batches for columnar caching. Larger batch sizes can improve memory utilization and compression, but risk OOMs when caching data. @@ -1868,7 +2026,7 @@ Configuration of in-memory caching can be done using the `setConf` method on `SQ ## Other Configuration Options -The following options can also be used to tune the performance of query execution. It is possible +The following options can also be used to tune the performance of query execution. It is possible that these options will be deprecated in future release as more optimizations are performed automatically. @@ -1878,17 +2036,17 @@ that these options will be deprecated in future release as more optimizations ar - + @@ -1898,13 +2056,6 @@ that these options will be deprecated in future release as more optimizations ar Configures the number of partitions to use when shuffling data for joins or aggregations. - - - - -
      10485760 (10 MB) Configures the maximum size in bytes for a table that will be broadcast to all worker nodes when - performing a join. By setting this value to -1 broadcasting can be disabled. Note that currently + performing a join. By setting this value to -1 broadcasting can be disabled. Note that currently statistics are only supported for Hive Metastore tables where the command ANALYZE TABLE <tableName> COMPUTE STATISTICS noscan has been run.
      spark.sql.codegenspark.sql.tungsten.enabled true - When true, code will be dynamically generated at runtime for expression evaluation in a specific - query. For some queries with complicated expression this option can lead to significant speed-ups. + When true, use the optimized Tungsten physical execution backend which explicitly manages memory + and dynamically generates bytecode for expression evaluation.
      spark.sql.planner.externalSorttrue - When true, performs sorts spilling to disk as needed otherwise sort each partition in memory. -
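A short, hedged sketch of adjusting the options above at runtime through `setConf` (the values are illustrative, not recommendations; an existing `sqlContext` is assumed):

{% highlight scala %}
// Number of partitions used when shuffling data for joins or aggregations.
sqlContext.setConf("spark.sql.shuffle.partitions", "400")
// Raise the broadcast-join threshold to 50 MB; -1 disables broadcasting entirely.
sqlContext.setConf("spark.sql.autoBroadcastJoinThreshold", (50 * 1024 * 1024).toString)
{% endhighlight %}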
      # Distributed SQL Engine @@ -1916,15 +2067,15 @@ without the need to write any code. ## Running the Thrift JDBC/ODBC server The Thrift JDBC/ODBC server implemented here corresponds to the [`HiveServer2`](https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2) -in Hive 0.13. You can test the JDBC server with the beeline script that comes with either Spark or Hive 0.13. +in Hive 1.2.1 You can test the JDBC server with the beeline script that comes with either Spark or Hive 1.2.1. To start the JDBC/ODBC server, run the following in the Spark directory: ./sbin/start-thriftserver.sh This script accepts all `bin/spark-submit` command line options, plus a `--hiveconf` option to -specify Hive properties. You may run `./sbin/start-thriftserver.sh --help` for a complete list of -all available options. By default, the server listens on localhost:10000. You may override this +specify Hive properties. You may run `./sbin/start-thriftserver.sh --help` for a complete list of +all available options. By default, the server listens on localhost:10000. You may override this behaviour via either environment variables, i.e.: {% highlight bash %} @@ -1957,7 +2108,7 @@ Beeline will ask you for a username and password. In non-secure mode, simply ent your machine and a blank password. For secure mode, please follow the instructions given in the [beeline documentation](https://cwiki.apache.org/confluence/display/Hive/HiveServer2+Clients). -Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. +Configuration of Hive is done by placing your `hive-site.xml`, `core-site.xml` and `hdfs-site.xml` files in `conf/`. You may also use the beeline script that comes with Hive. @@ -1982,12 +2133,54 @@ To start the Spark SQL CLI, run the following in the Spark directory: ./bin/spark-sql -Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. +Configuration of Hive is done by placing your `hive-site.xml`, `core-site.xml` and `hdfs-site.xml` files in `conf/`. You may run `./bin/spark-sql --help` for a complete list of all available options. # Migration Guide +## Upgrading From Spark SQL 1.5 to 1.6 + + - From Spark 1.6, by default the Thrift server runs in multi-session mode. Which means each JDBC/ODBC + connection owns a copy of their own SQL configuration and temporary function registry. Cached + tables are still shared though. If you prefer to run the Thrift server in the old single-session + mode, please set option `spark.sql.hive.thriftServer.singleSession` to `true`. You may either add + this option to `spark-defaults.conf`, or pass it to `start-thriftserver.sh` via `--conf`: + + {% highlight bash %} + ./sbin/start-thriftserver.sh \ + --conf spark.sql.hive.thriftServer.singleSession=true \ + ... + {% endhighlight %} + +## Upgrading From Spark SQL 1.4 to 1.5 + + - Optimized execution using manually managed memory (Tungsten) is now enabled by default, along with + code generation for expression evaluation. These features can both be disabled by setting + `spark.sql.tungsten.enabled` to `false`. + - Parquet schema merging is no longer enabled by default. It can be re-enabled by setting + `spark.sql.parquet.mergeSchema` to `true`. + - Resolution of strings to columns in python now supports using dots (`.`) to qualify the column or + access nested values. For example `df['table.column.nestedField']`. However, this means that if + your column name contains any dots you must now escape them using backticks (e.g., ``table.`column.with.dots`.nested``). 
+ - In-memory columnar storage partition pruning is on by default. It can be disabled by setting + `spark.sql.inMemoryColumnarStorage.partitionPruning` to `false`. + - Unlimited precision decimal columns are no longer supported, instead Spark SQL enforces a maximum + precision of 38. When inferring schema from `BigDecimal` objects, a precision of (38, 18) is now + used. When no precision is specified in DDL then the default remains `Decimal(10, 0)`. + - Timestamps are now stored at a precision of 1us, rather than 1ns + - In the `sql` dialect, floating point numbers are now parsed as decimal. HiveQL parsing remains + unchanged. + - The canonical name of SQL/DataFrame functions are now lower case (e.g. sum vs SUM). + - It has been determined that using the DirectOutputCommitter when speculation is enabled is unsafe + and thus this output committer will not be used when speculation is on, independent of configuration. + - JSON data source will not automatically load new files that are created by other applications + (i.e. files that are not inserted to the dataset through Spark SQL). + For a JSON persistent table (i.e. the metadata of the table is stored in Hive Metastore), + users can use `REFRESH TABLE` SQL command or `HiveContext`'s `refreshTable` method + to include those new files to the table. For a DataFrame representing a JSON dataset, users need to recreate + the DataFrame and the new DataFrame will include new files. + ## Upgrading from Spark SQL 1.3 to 1.4 #### DataFrame data reader/writer interface @@ -2009,7 +2202,8 @@ See the API docs for `SQLContext.read` ( #### DataFrame.groupBy retains grouping columns -Based on user feedback, we changed the default behavior of `DataFrame.groupBy().agg()` to retain the grouping columns in the resulting `DataFrame`. To keep the behavior in 1.3, set `spark.sql.retainGroupColumns` to `false`. +Based on user feedback, we changed the default behavior of `DataFrame.groupBy().agg()` to retain the +grouping columns in the resulting `DataFrame`. To keep the behavior in 1.3, set `spark.sql.retainGroupColumns` to `false`.
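To make the grouping-column change concrete, a hedged sketch assuming an existing `sqlContext` and a DataFrame `df` with `department` and `age` columns:

{% highlight scala %}
import org.apache.spark.sql.functions._

// In 1.4+ the result contains both "department" and "max(age)" columns.
val maxAges = df.groupBy("department").agg(max("age"))
// Revert to the 1.3 behavior (aggregate columns only):
sqlContext.setConf("spark.sql.retainGroupColumns", "false")
{% endhighlight %}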
      @@ -2068,38 +2262,38 @@ sqlContext.setConf("spark.sql.retainGroupColumns", "false") ## Upgrading from Spark SQL 1.0-1.2 to 1.3 In Spark 1.3 we removed the "Alpha" label from Spark SQL and as part of this did a cleanup of the -available APIs. From Spark 1.3 onwards, Spark SQL will provide binary compatibility with other -releases in the 1.X series. This compatibility guarantee excludes APIs that are explicitly marked +available APIs. From Spark 1.3 onwards, Spark SQL will provide binary compatibility with other +releases in the 1.X series. This compatibility guarantee excludes APIs that are explicitly marked as unstable (i.e., DeveloperAPI or Experimental). #### Rename of SchemaRDD to DataFrame The largest change that users will notice when upgrading to Spark SQL 1.3 is that `SchemaRDD` has -been renamed to `DataFrame`. This is primarily because DataFrames no longer inherit from RDD +been renamed to `DataFrame`. This is primarily because DataFrames no longer inherit from RDD directly, but instead provide most of the functionality that RDDs provide though their own -implementation. DataFrames can still be converted to RDDs by calling the `.rdd` method. +implementation. DataFrames can still be converted to RDDs by calling the `.rdd` method. In Scala there is a type alias from `SchemaRDD` to `DataFrame` to provide source compatibility for -some use cases. It is still recommended that users update their code to use `DataFrame` instead. +some use cases. It is still recommended that users update their code to use `DataFrame` instead. Java and Python users will need to update their code. #### Unification of the Java and Scala APIs Prior to Spark 1.3 there were separate Java compatible classes (`JavaSQLContext` and `JavaSchemaRDD`) -that mirrored the Scala API. In Spark 1.3 the Java API and Scala API have been unified. Users -of either language should use `SQLContext` and `DataFrame`. In general theses classes try to +that mirrored the Scala API. In Spark 1.3 the Java API and Scala API have been unified. Users +of either language should use `SQLContext` and `DataFrame`. In general theses classes try to use types that are usable from both languages (i.e. `Array` instead of language specific collections). In some cases where no common type exists (e.g., for passing in closures or Maps) function overloading is used instead. -Additionally the Java specific types API has been removed. Users of both Scala and Java should +Additionally the Java specific types API has been removed. Users of both Scala and Java should use the classes present in `org.apache.spark.sql.types` to describe schema programmatically. #### Isolation of Implicit Conversions and Removal of dsl Package (Scala-only) Many of the code examples prior to Spark 1.3 started with `import sqlContext._`, which brought -all of the functions from sqlContext into scope. In Spark 1.3 we have isolated the implicit +all of the functions from sqlContext into scope. In Spark 1.3 we have isolated the implicit conversions for converting `RDD`s into `DataFrame`s into an object inside of the `SQLContext`. Users should now write `import sqlContext.implicits._`. @@ -2107,7 +2301,7 @@ Additionally, the implicit conversions now only augment RDDs that are composed o case classes or tuples) with a method `toDF`, instead of applying automatically. When using function inside of the DSL (now replaced with the `DataFrame` API) users used to import -`org.apache.spark.sql.catalyst.dsl`. 
Instead the public dataframe functions API should be used: +`org.apache.spark.sql.catalyst.dsl`. Instead the public dataframe functions API should be used: `import org.apache.spark.sql.functions._`. #### Removal of the type aliases in org.apache.spark.sql for DataType (Scala-only) @@ -2146,7 +2340,7 @@ Python UDF registration is unchanged. When using DataTypes in Python you will need to construct them (i.e. `StringType()`) instead of referencing a singleton. -## Migration Guide for Shark User +## Migration Guide for Shark Users ### Scheduling To set a [Fair Scheduler](job-scheduling.html#fair-scheduler-pools) pool for a JDBC client session, @@ -2193,8 +2387,10 @@ Several caching related features are not supported yet: ## Compatibility with Apache Hive -Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. Currently Spark -SQL is based on Hive 0.12.0 and 0.13.1. +Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. +Currently Hive SerDes and UDFs are based on Hive 1.2.1, +and Spark SQL can be connected to different versions of Hive Metastore +(from 0.12.0 to 1.2.1. Also see [Interacting with Different Versions of Hive Metastore] (#interacting-with-different-versions-of-hive-metastore)). #### Deploying in Existing Hive Warehouses @@ -2222,6 +2418,7 @@ Spark SQL supports the vast majority of Hive features, such as: * User defined functions (UDF) * User defined aggregation functions (UDAF) * User defined serialization formats (SerDes) +* Window functions * Joins * `JOIN` * `{LEFT|RIGHT|FULL} OUTER JOIN` @@ -2232,7 +2429,7 @@ Spark SQL supports the vast majority of Hive features, such as: * `SELECT col FROM ( SELECT a + b AS col from t1) t2` * Sampling * Explain -* Partitioned tables +* Partitioned tables including dynamic partition insertion * View * All Hive DDL Functions, including: * `CREATE TABLE` @@ -2294,8 +2491,9 @@ releases of Spark SQL. Hive can optionally merge the small files into fewer large files to avoid overflowing the HDFS metadata. Spark SQL does not support that. +# Reference -# Data Types +## Data Types Spark SQL and DataFrames support the following data types: @@ -2908,3 +3106,13 @@ from pyspark.sql.types import *
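As a hedged companion to the data types reference above, a minimal Scala sketch of describing a schema programmatically with the classes in `org.apache.spark.sql.types`:

{% highlight scala %}
import org.apache.spark.sql.types._

val schema = StructType(Seq(
  StructField("name", StringType, nullable = true),
  StructField("age", IntegerType, nullable = true)))
// The schema can then be applied, e.g., when creating a DataFrame from an RDD of Rows.
{% endhighlight %}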
      +## NaN Semantics + +There is specially handling for not-a-number (NaN) when dealing with `float` or `double` types that +does not exactly match standard floating point semantics. +Specifically: + + - NaN = NaN returns true. + - In aggregations all NaN values are grouped together. + - NaN is treated as a normal value in join keys. + - NaN values go last when in ascending order, larger than any other numeric value. diff --git a/docs/streaming-flume-integration.md b/docs/streaming-flume-integration.md index de0461010dae..383d954409ce 100644 --- a/docs/streaming-flume-integration.md +++ b/docs/streaming-flume-integration.md @@ -5,8 +5,6 @@ title: Spark Streaming + Flume Integration Guide [Apache Flume](https://flume.apache.org/) is a distributed, reliable, and available service for efficiently collecting, aggregating, and moving large amounts of log data. Here we explain how to configure Flume and Spark Streaming to receive data from Flume. There are two approaches to this. -Python API Flume is not yet available in the Python API. - ## Approach 1: Flume-style Push-based Approach Flume is designed to push data between Flume agents. In this approach, Spark Streaming essentially sets up a receiver that acts an Avro agent for Flume, to which Flume can push the data. Here are the configuration steps. diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md index 775d508d4879..5be73c42560f 100644 --- a/docs/streaming-kafka-integration.md +++ b/docs/streaming-kafka-integration.md @@ -5,7 +5,7 @@ title: Spark Streaming + Kafka Integration Guide [Apache Kafka](http://kafka.apache.org/) is publish-subscribe messaging rethought as a distributed, partitioned, replicated commit log service. Here we explain how to configure Spark Streaming to receive data from Kafka. There are two approaches to this - the old approach using Receivers and Kafka's high-level API, and a new experimental approach (introduced in Spark 1.3) without using Receivers. They have different programming models, performance characteristics, and semantics guarantees, so read on for more details. ## Approach 1: Receiver-based Approach -This approach uses a Receiver to receive the data. The Received is implemented using the Kafka high-level consumer API. As with all receivers, the data received from Kafka through a Receiver is stored in Spark executors, and then jobs launched by Spark Streaming processes the data. +This approach uses a Receiver to receive the data. The Receiver is implemented using the Kafka high-level consumer API. As with all receivers, the data received from Kafka through a Receiver is stored in Spark executors, and then jobs launched by Spark Streaming processes the data. However, under default configuration, this approach can lose data under failures (see [receiver reliability](streaming-programming-guide.html#receiver-reliability). To ensure zero-data loss, you have to additionally enable Write Ahead Logs in Spark Streaming (introduced in Spark 1.2). This synchronously saves all the received Kafka data into write ahead logs on a distributed file system (e.g HDFS), so that all the data can be recovered on failure. See [Deploying section](streaming-programming-guide.html#deploying-applications) in the streaming programming guide for more details on Write Ahead Logs. @@ -74,7 +74,7 @@ Next, we discuss how to use this approach in your streaming application. 
[Maven repository](http://search.maven.org/#search|ga|1|a%3A%22spark-streaming-kafka-assembly_2.10%22%20AND%20v%3A%22{{site.SPARK_VERSION_SHORT}}%22) and add it to `spark-submit` with `--jars`. ## Approach 2: Direct Approach (No Receivers) -This new receiver-less "direct" approach has been introduced in Spark 1.3 to ensure stronger end-to-end guarantees. Instead of using receivers to receive data, this approach periodically queries Kafka for the latest offsets in each topic+partition, and accordingly defines the offset ranges to process in each batch. When the jobs to process the data are launched, Kafka's simple consumer API is used to read the defined ranges of offsets from Kafka (similar to read files from a file system). Note that this is an experimental feature introduced in Spark 1.3 for the Scala and Java API. Spark 1.4 added a Python API, but it is not yet at full feature parity. +This new receiver-less "direct" approach has been introduced in Spark 1.3 to ensure stronger end-to-end guarantees. Instead of using receivers to receive data, this approach periodically queries Kafka for the latest offsets in each topic+partition, and accordingly defines the offset ranges to process in each batch. When the jobs to process the data are launched, Kafka's simple consumer API is used to read the defined ranges of offsets from Kafka (similar to read files from a file system). Note that this is an experimental feature introduced in Spark 1.3 for the Scala and Java API, in Spark 1.4 for the Python API. This approach has the following advantages over the receiver-based approach (i.e. Approach 1). @@ -82,7 +82,7 @@ This approach has the following advantages over the receiver-based approach (i.e - *Efficiency:* Achieving zero-data loss in the first approach required the data to be stored in a Write Ahead Log, which further replicated the data. This is actually inefficient as the data effectively gets replicated twice - once by Kafka, and a second time by the Write Ahead Log. This second approach eliminates the problem as there is no receiver, and hence no need for Write Ahead Logs. As long as you have sufficient Kafka retention, messages can be recovered from Kafka. -- *Exactly-once semantics:* The first approach uses Kafka's high level API to store consumed offsets in Zookeeper. This is traditionally the way to consume data from Kafka. While this approach (in combination with write ahead logs) can ensure zero data loss (i.e. at-least once semantics), there is a small chance some records may get consumed twice under some failures. This occurs because of inconsistencies between data reliably received by Spark Streaming and offsets tracked by Zookeeper. Hence, in this second approach, we use simple Kafka API that does not use Zookeeper. Offsets are tracked by Spark Streaming within its checkpoints. This eliminates inconsistencies between Spark Streaming and Zookeeper/Kafka, and so each record is received by Spark Streaming effectively exactly once despite failures. In order to achieve exactly-once semantics for output of your results, your output operation that saves the data to an external data store must be either idempotent, or an atomic transaction that saves results and offsets (see [Semanitcs of output operations](streaming-programming-guide.html#semantics-of-output-operations) in the main programming guide for further information). +- *Exactly-once semantics:* The first approach uses Kafka's high level API to store consumed offsets in Zookeeper. 
This is traditionally the way to consume data from Kafka. While this approach (in combination with write ahead logs) can ensure zero data loss (i.e. at-least once semantics), there is a small chance some records may get consumed twice under some failures. This occurs because of inconsistencies between data reliably received by Spark Streaming and offsets tracked by Zookeeper. Hence, in this second approach, we use simple Kafka API that does not use Zookeeper. Offsets are tracked by Spark Streaming within its checkpoints. This eliminates inconsistencies between Spark Streaming and Zookeeper/Kafka, and so each record is received by Spark Streaming effectively exactly once despite failures. In order to achieve exactly-once semantics for output of your results, your output operation that saves the data to an external data store must be either idempotent, or an atomic transaction that saves results and offsets (see [Semantics of output operations](streaming-programming-guide.html#semantics-of-output-operations) in the main programming guide for further information). Note that one disadvantage of this approach is that it does not update offsets in Zookeeper, hence Zookeeper-based Kafka monitoring tools will not show progress. However, you can access the offsets processed by this approach in each batch and update Zookeeper yourself (see below). @@ -152,7 +152,7 @@ Next, we discuss how to use this approach in your streaming application.
      // Hold a reference to the current offset ranges, so it can be used downstream - final AtomicReference offsetRanges = new AtomicReference(); + final AtomicReference offsetRanges = new AtomicReference<>(); directKafkaStream.transformToPair( new Function, JavaPairRDD>() { @@ -181,7 +181,20 @@ Next, we discuss how to use this approach in your streaming application. );
      - Not supported yet + offsetRanges = [] + + def storeOffsetRanges(rdd): + global offsetRanges + offsetRanges = rdd.offsetRanges() + return rdd + + def printOffsetRanges(rdd): + for o in offsetRanges: + print "%s %s %s %s" % (o.topic, o.partition, o.fromOffset, o.untilOffset) + + directKafkaStream\ + .transform(storeOffsetRanges)\ + .foreachRDD(printOffsetRanges)
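For readers following the Scala tab, a hedged Scala counterpart to the Python snippet above; it is a sketch that assumes `directKafkaStream` was created with `KafkaUtils.createDirectStream`.

{% highlight scala %}
import org.apache.spark.streaming.kafka.{HasOffsetRanges, OffsetRange}

// Hold a reference to the current offset ranges, so it can be used downstream.
var offsetRanges = Array.empty[OffsetRange]

directKafkaStream.transform { rdd =>
  offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges
  rdd
}.foreachRDD { rdd =>
  for (o <- offsetRanges) {
    println(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}")
  }
}
{% endhighlight %}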
      diff --git a/docs/streaming-kinesis-integration.md b/docs/streaming-kinesis-integration.md index a7bcaec6fcd8..238a911a9199 100644 --- a/docs/streaming-kinesis-integration.md +++ b/docs/streaming-kinesis-integration.md @@ -91,7 +91,7 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m - Kinesis data processing is ordered per partition and occurs at-least once per message. - - Multiple applications can read from the same Kinesis stream. Kinesis will maintain the application-specific shard and checkpoint info in DynamodDB. + - Multiple applications can read from the same Kinesis stream. Kinesis will maintain the application-specific shard and checkpoint info in DynamoDB. - A single Kinesis stream shard is processed by one input DStream at a time. diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 4663b3f14c52..ed6b28c28213 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -50,13 +50,7 @@ all of which are presented in this guide. You will find tabs throughout this guide that let you choose between code snippets of different languages. -**Note:** Python API for Spark Streaming has been introduced in Spark 1.2. It has all the DStream -transformations and almost all the output operations available in Scala and Java interfaces. -However, it only has support for basic sources like text files and text data over sockets. -APIs for additional sources, like Kafka and Flume, will be available in the future. -Further information about available features in the Python API are mentioned throughout this -document; look out for the tag -Python API. +**Note:** There are a few APIs that are either different or not available in Python. Throughout this guide, you will find the tag Python API highlighting these differences. *************************************************************************************************** @@ -683,7 +677,7 @@ for Java, and [StreamingContext](api/python/pyspark.streaming.html#pyspark.strea {:.no_toc} Python API As of Spark {{site.SPARK_VERSION_SHORT}}, -out of these sources, *only* Kafka and Flume are available in the Python API. We will add more advanced sources in the Python API in future. +out of these sources, Kafka, Kinesis, Flume and MQTT are available in the Python API. This category of sources require interfacing with external non-Spark libraries, some of them with complex dependencies (e.g., Kafka and Flume). Hence, to minimize issues related to version conflicts @@ -725,11 +719,11 @@ Some of these advanced sources are as follows. - **Kafka:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Kafka 0.8.2.1. See the [Kafka Integration Guide](streaming-kafka-integration.html) for more details. -- **Flume:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Flume 1.4.0. See the [Flume Integration Guide](streaming-flume-integration.html) for more details. +- **Flume:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Flume 1.6.0. See the [Flume Integration Guide](streaming-flume-integration.html) for more details. -- **Kinesis:** See the [Kinesis Integration Guide](streaming-kinesis-integration.html) for more details. +- **Kinesis:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Kinesis Client Library 1.2.1. See the [Kinesis Integration Guide](streaming-kinesis-integration.html) for more details. 
-- **Twitter:** Spark Streaming's TwitterUtils uses Twitter4j 3.0.3 to get the public stream of tweets using +- **Twitter:** Spark Streaming's TwitterUtils uses Twitter4j to get the public stream of tweets using [Twitter's Streaming API](https://dev.twitter.com/docs/streaming-apis). Authentication information can be provided by any of the [methods](http://twitter4j.org/en/configuration.html) supported by Twitter4J library. You can either get the public stream, or get the filtered stream based on a @@ -1141,7 +1135,7 @@ val joinedStream = stream1.join(stream2) {% highlight java %} JavaPairDStream stream1 = ... JavaPairDStream stream2 = ... -JavaPairDStream joinedStream = stream1.join(stream2); +JavaPairDStream> joinedStream = stream1.join(stream2); {% endhighlight %}
      @@ -1166,7 +1160,7 @@ val joinedStream = windowedStream1.join(windowedStream2) {% highlight java %} JavaPairDStream windowedStream1 = stream1.window(Durations.seconds(20)); JavaPairDStream windowedStream2 = stream2.window(Durations.minutes(1)); -JavaPairDStream joinedStream = windowedStream1.join(windowedStream2); +JavaPairDStream> joinedStream = windowedStream1.join(windowedStream2); {% endhighlight %}
      @@ -1702,7 +1696,7 @@ context.awaitTermination(); If the `checkpointDirectory` exists, then the context will be recreated from the checkpoint data. If the directory does not exist (i.e., running for the first time), then the function `contextFactory` will be called to create a new -context and set up the DStreams. See the Scala example +context and set up the DStreams. See the Java example [JavaRecoverableNetworkWordCount]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java). This example appends the word counts of network data into a file. @@ -1813,7 +1807,7 @@ To run a Spark Streaming applications, you need to have the following. + *Mesos* - [Marathon](https://github.com/mesosphere/marathon) has been used to achieve this with Mesos. -- *[Since Spark 1.2] Configuring write ahead logs* - Since Spark 1.2, +- *Configuring write ahead logs* - Since Spark 1.2, we have introduced _write ahead logs_ for achieving strong fault-tolerance guarantees. If enabled, all the data received from a receiver gets written into a write ahead log in the configuration checkpoint directory. This prevents data loss on driver @@ -1828,6 +1822,17 @@ To run a Spark Streaming applications, you need to have the following. stored in a replicated storage system. This can be done by setting the storage level for the input stream to `StorageLevel.MEMORY_AND_DISK_SER`. +- *Setting the max receiving rate* - If the cluster resources is not large enough for the streaming + application to process data as fast as it is being received, the receivers can be rate limited + by setting a maximum rate limit in terms of records / sec. + See the [configuration parameters](configuration.html#spark-streaming) + `spark.streaming.receiver.maxRate` for receivers and `spark.streaming.kafka.maxRatePerPartition` + for Direct Kafka approach. In Spark 1.5, we have introduced a feature called *backpressure* that + eliminate the need to set this rate limit, as Spark Streaming automatically figures out the + rate limits and dynamically adjusts them if the processing conditions change. This backpressure + can be enabled by setting the [configuration parameter](configuration.html#spark-streaming) + `spark.streaming.backpressure.enabled` to `true`. + ### Upgrading Application Code {:.no_toc} @@ -1943,8 +1948,8 @@ unifiedStream.print(); {% highlight python %} numStreams = 5 kafkaStreams = [KafkaUtils.createStream(...) for _ in range (numStreams)] -unifiedStream = streamingContext.union(kafkaStreams) -unifiedStream.print() +unifiedStream = streamingContext.union(*kafkaStreams) +unifiedStream.pprint() {% endhighlight %}
    @@ -1996,8 +2001,7 @@ If the number of tasks launched per second is high (say, 50 or more per second), of sending out tasks to the slaves may be significant and will make it hard to achieve sub-second latencies. The overhead can be reduced by the following changes: -* **Task Serialization**: Using Kryo serialization for serializing tasks can reduce the task - sizes, and therefore reduce the time taken to send them to the slaves. +* **Task Serialization**: Using Kryo serialization for serializing tasks can reduce the task sizes, and therefore reduce the time taken to send them to the slaves. This is controlled by the ```spark.closure.serializer``` property. However, at this time, Kryo serialization cannot be enabled for closure serialization. This may be resolved in a future release. * **Execution mode**: Running Spark in Standalone mode or coarse-grained Mesos mode leads to better task launch times than the fine-grained Mesos mode. Please refer to the diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index e58645274e52..acbb0f298fe4 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -65,8 +65,8 @@ For Python applications, simply pass a `.py` file in the place of ` Master URLMeaning - local Run Spark locally with one worker thread (i.e. no parallelism at all). - local[K] Run Spark locally with K worker threads (ideally, set this to the number of cores on your machine). - local[*] Run Spark locally with as many worker threads as logical cores on your machine. - spark://HOST:PORT Connect to the given Spark standalone + local Run Spark locally with one worker thread (i.e. no parallelism at all). + local[K] Run Spark locally with K worker threads (ideally, set this to the number of cores on your machine). + local[*] Run Spark locally with as many worker threads as logical cores on your machine. + spark://HOST:PORT Connect to the given Spark standalone cluster master. The port must be whichever one your master is configured to use, which is 7077 by default. - mesos://HOST:PORT Connect to the given Mesos cluster. + mesos://HOST:PORT Connect to the given Mesos cluster. The port must be whichever one your is configured to use, which is 5050 by default. Or, for a Mesos cluster using ZooKeeper, use mesos://zk://.... + To submit with --deploy-mode cluster, the HOST:PORT should be configured to connect to the MesosClusterDispatcher. + + yarn Connect to a YARN cluster in + client or cluster mode depending on the value of --deploy-mode. + The cluster location will be found based on the HADOOP_CONF_DIR or YARN_CONF_DIR variable. - yarn-client Connect to a YARN cluster in -client mode. The cluster location will be found based on the HADOOP_CONF_DIR or YARN_CONF_DIR variable. + yarn-client Equivalent to yarn with --deploy-mode client, + which is preferred to `yarn-client` - yarn-cluster Connect to a YARN cluster in -cluster mode. The cluster location will be found based on the HADOOP_CONF_DIR or YARN_CONF_DIR variable. + yarn-cluster Equivalent to yarn with --deploy-mode cluster, + which is preferred to `yarn-cluster` @@ -174,9 +192,9 @@ This can use up a significant amount of space over time and will need to be clea is handled automatically, and with Spark standalone, automatic cleanup can be configured with the `spark.worker.cleanup.appDataTtl` property. -Users may also include any other dependencies by supplying a comma-delimited list of maven coordinates -with `--packages`. 
All transitive dependencies will be handled when using this command. Additional -repositories (or resolvers in SBT) can be added in a comma-delimited fashion with the flag `--repositories`. +Users may also include any other dependencies by supplying a comma-delimited list of maven coordinates +with `--packages`. All transitive dependencies will be handled when using this command. Additional +repositories (or resolvers in SBT) can be added in a comma-delimited fashion with the flag `--repositories`. These commands can be used with `pyspark`, `spark-shell`, and `spark-submit` to include Spark Packages. For Python, the equivalent `--py-files` option can be used to distribute `.egg`, `.zip` and `.py` libraries diff --git a/docs/tuning.md b/docs/tuning.md index 572c7270e499..e73ed69ffbbf 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -61,8 +61,8 @@ The [Kryo documentation](https://github.com/EsotericSoftware/kryo) describes mor registration options, such as adding custom serialization code. If your objects are large, you may also need to increase the `spark.kryoserializer.buffer` -config property. The default is 2, but this value needs to be large enough to hold the *largest* -object you will serialize. +[config](configuration.html#compression-and-serialization). This value needs to be large enough +to hold the *largest* object you will serialize. Finally, if you don't register your custom classes, Kryo will still work, but it will have to store the full class name with each object, which is wasteful. @@ -88,9 +88,39 @@ than the "raw" data inside their fields. This is due to several reasons: but also pointers (typically 8 bytes each) to the next object in the list. * Collections of primitive types often store them as "boxed" objects such as `java.lang.Integer`. -This section will discuss how to determine the memory usage of your objects, and how to improve -it -- either by changing your data structures, or by storing data in a serialized format. -We will then cover tuning Spark's cache size and the Java garbage collector. +This section will start with an overview of memory management in Spark, then discuss specific +strategies the user can take to make more efficient use of memory in his/her application. In +particular, we will describe how to determine the memory usage of your objects, and how to +improve it -- either by changing your data structures, or by storing data in a serialized +format. We will then cover tuning Spark's cache size and the Java garbage collector. + +## Memory Management Overview + +Memory usage in Spark largely falls under one of two categories: execution and storage. +Execution memory refers to that used for computation in shuffles, joins, sorts and aggregations, +while storage memory refers to that used for caching and propagating internal data across the +cluster. In Spark, execution and storage share a unified region (M). When no execution memory is +used, storage can acquire all the available memory and vice versa. Execution may evict storage +if necessary, but only until total storage memory usage falls under a certain threshold (R). +In other words, `R` describes a subregion within `M` where cached blocks are never evicted. +Storage may not evict execution due to complexities in implementation. + +This design ensures several desirable properties. First, applications that do not use caching +can use the entire space for execution, obviating unnecessary disk spills. 
Second, applications +that do use caching can reserve a minimum storage space (R) where their data blocks are immune +to being evicted. Lastly, this approach provides reasonable out-of-the-box performance for a +variety of workloads without requiring user expertise of how memory is divided internally. + +Although there are two relevant configurations, the typical user should not need to adjust them +as the default values are applicable to most workloads: + +* `spark.memory.fraction` expresses the size of `M` as a fraction of the (JVM heap space - 300MB) +(default 0.75). The rest of the space (25%) is reserved for user data structures, internal +metadata in Spark, and safeguarding against OOM errors in the case of sparse and unusually +large records. +* `spark.memory.storageFraction` expresses the size of `R` as a fraction of `M` (default 0.5). +`R` is the storage space within `M` where cached blocks immune to being evicted by execution. + ## Determining Memory Consumption @@ -151,18 +181,6 @@ time spent GC. This can be done by adding `-verbose:gc -XX:+PrintGCDetails -XX:+ each time a garbage collection occurs. Note these logs will be on your cluster's worker nodes (in the `stdout` files in their work directories), *not* on your driver program. -**Cache Size Tuning** - -One important configuration parameter for GC is the amount of memory that should be used for caching RDDs. -By default, Spark uses 60% of the configured executor memory (`spark.executor.memory`) to -cache RDDs. This means that 40% of memory is available for any objects created during task execution. - -In case your tasks slow down and you find that your JVM is garbage-collecting frequently or running out of -memory, lowering this value will help reduce the memory consumption. To change this to, say, 50%, you can call -`conf.set("spark.storage.memoryFraction", "0.5")` on your SparkConf. Combined with the use of serialized caching, -using a smaller cache should be sufficient to mitigate most of the garbage collection problems. -In case you are interested in further tuning the Java GC, continue reading below. - **Advanced GC Tuning** To further tune garbage collection, we first need to understand some basic information about memory management in the JVM: @@ -183,9 +201,9 @@ temporary objects created during task execution. Some steps which may be useful * Check if there are too many garbage collections by collecting GC stats. If a full GC is invoked multiple times for before a task completes, it means that there isn't enough memory available for executing tasks. -* In the GC stats that are printed, if the OldGen is close to being full, reduce the amount of memory used for caching. - This can be done using the `spark.storage.memoryFraction` property. It is better to cache fewer objects than to slow - down task execution! +* In the GC stats that are printed, if the OldGen is close to being full, reduce the amount of + memory used for caching by lowering `spark.memory.storageFraction`; it is better to cache fewer + objects than to slow down task execution! * If there are too many minor collections but not many major GCs, allocating more memory for Eden would help. You can set the size of the Eden to be an over-estimate of how much memory each task will need. If the size of Eden @@ -240,7 +258,7 @@ worth optimizing. ## Data Locality Data locality can have a major impact on the performance of Spark jobs. If data and the code that -operates on it are together than computation tends to be fast. 
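As a concrete anchor for the two memory settings described above, a minimal sketch showing where they would be set if tuning were ever required; the values shown are simply the documented defaults, and most applications should not change them.

```java
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;

public class MemoryTuningSketch {
  public static void main(String[] args) {
    SparkConf conf = new SparkConf()
        .setAppName("MemoryTuningSketch")
        // Size of the unified region M as a fraction of (JVM heap - 300MB).
        .set("spark.memory.fraction", "0.75")
        // Size of R, the eviction-protected storage subregion, as a fraction of M.
        // Lowering this is the lever mentioned above when OldGen fills with cached data.
        .set("spark.memory.storageFraction", "0.5");

    JavaSparkContext jsc = new JavaSparkContext(conf);
    // ... cache and compute as usual ...
    jsc.stop();
  }
}
```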
But if code and data are separated, +operates on it are together then computation tends to be fast. But if code and data are separated, one must move to the other. Typically it is faster to ship serialized code from place to place than a chunk of data because code size is much smaller than data. Spark builds its scheduling around this general principle of data locality. diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index ccf922d9371f..19d5980560fe 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -51,7 +51,7 @@ raw_input = input xrange = range -SPARK_EC2_VERSION = "1.4.0" +SPARK_EC2_VERSION = "1.6.0" SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__)) VALID_SPARK_VERSIONS = set([ @@ -71,6 +71,11 @@ "1.3.0", "1.3.1", "1.4.0", + "1.4.1", + "1.5.0", + "1.5.1", + "1.5.2", + "1.6.0", ]) SPARK_TACHYON_MAP = { @@ -84,14 +89,19 @@ "1.3.0": "0.5.0", "1.3.1": "0.5.0", "1.4.0": "0.6.4", + "1.4.1": "0.6.4", + "1.5.0": "0.7.1", + "1.5.1": "0.7.1", + "1.5.2": "0.7.1", + "1.6.0": "0.8.2", } DEFAULT_SPARK_VERSION = SPARK_EC2_VERSION DEFAULT_SPARK_GITHUB_REPO = "https://github.com/apache/spark" # Default location to get the spark-ec2 scripts (and ami-list) from -DEFAULT_SPARK_EC2_GITHUB_REPO = "https://github.com/mesos/spark-ec2" -DEFAULT_SPARK_EC2_BRANCH = "branch-1.4" +DEFAULT_SPARK_EC2_GITHUB_REPO = "https://github.com/amplab/spark-ec2" +DEFAULT_SPARK_EC2_BRANCH = "branch-1.5" def setup_external_libs(libs): @@ -177,6 +187,10 @@ def parse_args(): parser.add_option( "-i", "--identity-file", help="SSH private key file to use for logging into instances") + parser.add_option( + "-p", "--profile", default=None, + help="If you have multiple profiles (AWS or boto config), you can configure " + + "additional, named profiles by using this option (default: %default)") parser.add_option( "-t", "--instance-type", default="m1.large", help="Type of instance to launch (default: %default). " + @@ -587,7 +601,7 @@ def launch_cluster(conn, opts, cluster_name): dev = BlockDeviceType() dev.ephemeral_name = 'ephemeral%d' % i # The first ephemeral drive is /dev/sdb. 
- name = '/dev/sd' + string.letters[i + 1] + name = '/dev/sd' + string.ascii_letters[i + 1] block_map[name] = dev # Launch slaves @@ -1234,6 +1248,10 @@ def get_ip_address(instance, private_ips=False): def get_dns_name(instance, private_ips=False): dns = instance.public_dns_name if not private_ips else \ instance.private_ip_address + if not dns: + raise UsageError("Failed to determine hostname of {0}.\n" + "Please check that you provided --private-ips if " + "necessary".format(instance)) return dns @@ -1311,7 +1329,10 @@ def real_main(): sys.exit(1) try: - conn = ec2.connect_to_region(opts.region) + if opts.profile is None: + conn = ec2.connect_to_region(opts.region) + else: + conn = ec2.connect_to_region(opts.region, profile_name=opts.profile) except Exception as e: print((e), file=stderr) sys.exit(1) diff --git a/examples/pom.xml b/examples/pom.xml index e6884b09dca9..f5ab2a7fdc09 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java new file mode 100644 index 000000000000..69a174562fcf --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +// $example on$ +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.regression.AFTSurvivalRegression; +import org.apache.spark.ml.regression.AFTSurvivalRegressionModel; +import org.apache.spark.mllib.linalg.*; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.*; +// $example off$ + +public class JavaAFTSurvivalRegressionExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaAFTSurvivalRegressionExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + List data = Arrays.asList( + RowFactory.create(1.218, 1.0, Vectors.dense(1.560, -0.605)), + RowFactory.create(2.949, 0.0, Vectors.dense(0.346, 2.158)), + RowFactory.create(3.627, 0.0, Vectors.dense(1.380, 0.231)), + RowFactory.create(0.273, 1.0, Vectors.dense(0.520, 1.151)), + RowFactory.create(4.199, 0.0, Vectors.dense(0.795, -0.226)) + ); + StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("censor", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("features", new VectorUDT(), false, Metadata.empty()) + }); + DataFrame training = jsql.createDataFrame(data, schema); + double[] quantileProbabilities = new double[]{0.3, 0.6}; + AFTSurvivalRegression aft = new AFTSurvivalRegression() + .setQuantileProbabilities(quantileProbabilities) + .setQuantilesCol("quantiles"); + + AFTSurvivalRegressionModel model = aft.fit(training); + + // Print the coefficients, intercept and scale parameter for AFT survival regression + System.out.println("Coefficients: " + model.coefficients() + " Intercept: " + + model.intercept() + " Scale: " + model.scale()); + model.transform(training).show(false); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java new file mode 100644 index 000000000000..1eda1f694fc2 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.Binarizer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaBinarizerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaBinarizerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, 0.1), + RowFactory.create(1, 0.8), + RowFactory.create(2, 0.2) + )); + StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("feature", DataTypes.DoubleType, false, Metadata.empty()) + }); + DataFrame continuousDataFrame = jsql.createDataFrame(jrdd, schema); + Binarizer binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + .setThreshold(0.5); + DataFrame binarizedDataFrame = binarizer.transform(continuousDataFrame); + DataFrame binarizedFeatures = binarizedDataFrame.select("binarized_feature"); + for (Row r : binarizedFeatures.collect()) { + Double binarized_value = r.getDouble(0); + System.out.println(binarized_value); + } + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java new file mode 100644 index 000000000000..8ad369cc93e8 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.Bucketizer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaBucketizerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaBucketizerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + double[] splits = {Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY}; + + JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(-0.5), + RowFactory.create(-0.3), + RowFactory.create(0.0), + RowFactory.create(0.2) + )); + StructType schema = new StructType(new StructField[]{ + new StructField("features", DataTypes.DoubleType, false, Metadata.empty()) + }); + DataFrame dataFrame = jsql.createDataFrame(data, schema); + + Bucketizer bucketizer = new Bucketizer() + .setInputCol("features") + .setOutputCol("bucketedFeatures") + .setSplits(splits); + + // Transform original data into its bucket index. + DataFrame bucketedData = bucketizer.transform(dataFrame); + bucketedData.show(); + // $example off$ + jsc.stop(); + } +} + + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java new file mode 100644 index 000000000000..ede05d6e2011 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.ml.feature.ChiSqSelector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaChiSqSelectorExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaChiSqSelectorExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(7, Vectors.dense(0.0, 0.0, 18.0, 1.0), 1.0), + RowFactory.create(8, Vectors.dense(0.0, 1.0, 12.0, 0.0), 0.0), + RowFactory.create(9, Vectors.dense(1.0, 0.0, 15.0, 0.1), 0.0) + )); + StructType schema = new StructType(new StructField[]{ + new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("features", new VectorUDT(), false, Metadata.empty()), + new StructField("clicked", DataTypes.DoubleType, false, Metadata.empty()) + }); + + DataFrame df = sqlContext.createDataFrame(jrdd, schema); + + ChiSqSelector selector = new ChiSqSelector() + .setNumTopFeatures(1) + .setFeaturesCol("features") + .setLabelCol("clicked") + .setOutputCol("selectedFeatures"); + + DataFrame result = selector.fit(df).transform(df); + result.show(); + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaCountVectorizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaCountVectorizerExample.java new file mode 100644 index 000000000000..ac33adb65292 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaCountVectorizerExample.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.feature.CountVectorizer; +import org.apache.spark.ml.feature.CountVectorizerModel; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.*; +// $example off$ + +public class JavaCountVectorizerExample { + public static void main(String[] args) { + + SparkConf conf = new SparkConf().setAppName("JavaCountVectorizerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + // Input data: Each row is a bag of words from a sentence or document. + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(Arrays.asList("a", "b", "c")), + RowFactory.create(Arrays.asList("a", "b", "b", "c", "a")) + )); + StructType schema = new StructType(new StructField [] { + new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) + }); + DataFrame df = sqlContext.createDataFrame(jrdd, schema); + + // fit a CountVectorizerModel from the corpus + CountVectorizerModel cvModel = new CountVectorizer() + .setInputCol("text") + .setOutputCol("feature") + .setVocabSize(3) + .setMinDF(2) + .fit(df); + + // alternatively, define CountVectorizerModel with a-priori vocabulary + CountVectorizerModel cvm = new CountVectorizerModel(new String[]{"a", "b", "c"}) + .setInputCol("text") + .setOutputCol("feature"); + + cvModel.transform(df).show(); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java new file mode 100644 index 000000000000..35c0d534a45e --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.DCT; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaDCTExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaDCTExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.dense(0.0, 1.0, -2.0, 3.0)), + RowFactory.create(Vectors.dense(-1.0, 2.0, 4.0, -7.0)), + RowFactory.create(Vectors.dense(14.0, -2.0, -5.0, 1.0)) + )); + StructType schema = new StructType(new StructField[]{ + new StructField("features", new VectorUDT(), false, Metadata.empty()), + }); + DataFrame df = jsql.createDataFrame(data, schema); + DCT dct = new DCT() + .setInputCol("features") + .setOutputCol("featuresDCT") + .setInverse(false); + DataFrame dctDf = dct.transform(df); + dctDf.select("featuresDCT").show(3); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java new file mode 100644 index 000000000000..482225e585cf --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +// scalastyle:off println +package org.apache.spark.examples.ml; +// $example on$ +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.classification.DecisionTreeClassifier; +import org.apache.spark.ml.classification.DecisionTreeClassificationModel; +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; +import org.apache.spark.ml.feature.*; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; +// $example off$ + +public class JavaDecisionTreeClassificationExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaDecisionTreeClassificationExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + // Load the data stored in LIBSVM format as a DataFrame. + DataFrame data = sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + + // Index labels, adding metadata to the label column. + // Fit on whole dataset to include all labels in index. + StringIndexerModel labelIndexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("indexedLabel") + .fit(data); + + // Automatically identify categorical features, and index them. + VectorIndexerModel featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) // features with > 4 distinct values are treated as continuous + .fit(data); + + // Split the data into training and test sets (30% held out for testing) + DataFrame[] splits = data.randomSplit(new double[]{0.7, 0.3}); + DataFrame trainingData = splits[0]; + DataFrame testData = splits[1]; + + // Train a DecisionTree model. + DecisionTreeClassifier dt = new DecisionTreeClassifier() + .setLabelCol("indexedLabel") + .setFeaturesCol("indexedFeatures"); + + // Convert indexed labels back to original labels. + IndexToString labelConverter = new IndexToString() + .setInputCol("prediction") + .setOutputCol("predictedLabel") + .setLabels(labelIndexer.labels()); + + // Chain indexers and tree in a Pipeline + Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[]{labelIndexer, featureIndexer, dt, labelConverter}); + + // Train model. This also runs the indexers. + PipelineModel model = pipeline.fit(trainingData); + + // Make predictions. + DataFrame predictions = model.transform(testData); + + // Select example rows to display. 
+ predictions.select("predictedLabel", "label", "features").show(5); + + // Select (prediction, true label) and compute test error + MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() + .setLabelCol("indexedLabel") + .setPredictionCol("prediction") + .setMetricName("precision"); + double accuracy = evaluator.evaluate(predictions); + System.out.println("Test Error = " + (1.0 - accuracy)); + + DecisionTreeClassificationModel treeModel = + (DecisionTreeClassificationModel) (model.stages()[2]); + System.out.println("Learned classification tree model:\n" + treeModel.toDebugString()); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java new file mode 100644 index 000000000000..c7f1868dd105 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// scalastyle:off println +package org.apache.spark.examples.ml; +// $example on$ +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.evaluation.RegressionEvaluator; +import org.apache.spark.ml.feature.VectorIndexer; +import org.apache.spark.ml.feature.VectorIndexerModel; +import org.apache.spark.ml.regression.DecisionTreeRegressionModel; +import org.apache.spark.ml.regression.DecisionTreeRegressor; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; +// $example off$ + +public class JavaDecisionTreeRegressionExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaDecisionTreeRegressionExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + // $example on$ + // Load the data stored in LIBSVM format as a DataFrame. + DataFrame data = sqlContext.read().format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); + + // Automatically identify categorical features, and index them. + // Set maxCategories so features with > 4 distinct values are treated as continuous. + VectorIndexerModel featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data); + + // Split the data into training and test sets (30% held out for testing) + DataFrame[] splits = data.randomSplit(new double[]{0.7, 0.3}); + DataFrame trainingData = splits[0]; + DataFrame testData = splits[1]; + + // Train a DecisionTree model. 
+ DecisionTreeRegressor dt = new DecisionTreeRegressor() + .setFeaturesCol("indexedFeatures"); + + // Chain indexer and tree in a Pipeline + Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[]{featureIndexer, dt}); + + // Train model. This also runs the indexer. + PipelineModel model = pipeline.fit(trainingData); + + // Make predictions. + DataFrame predictions = model.transform(testData); + + // Select example rows to display. + predictions.select("label", "features").show(5); + + // Select (prediction, true label) and compute test error + RegressionEvaluator evaluator = new RegressionEvaluator() + .setLabelCol("label") + .setPredictionCol("prediction") + .setMetricName("rmse"); + double rmse = evaluator.evaluate(predictions); + System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse); + + DecisionTreeRegressionModel treeModel = + (DecisionTreeRegressionModel) (model.stages()[1]); + System.out.println("Learned regression tree model:\n" + treeModel.toDebugString()); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java index 9df26ffca577..b9dd3ad95771 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java @@ -89,7 +89,7 @@ public static void main(String[] args) throws Exception { } if (sumPredictions != 0.0) { throw new Exception("MyJavaLogisticRegression predicted something other than 0," + - " even though all weights are 0!"); + " even though all coefficients are 0!"); } jsc.stop(); @@ -124,7 +124,7 @@ public String uid() { /** * Param for max number of iterations - *

    + *

    * NOTE: The usual way to add a parameter to a model or algorithm is to include: * - val myParamName: ParamType * - def getMyParamName @@ -149,12 +149,12 @@ public MyJavaLogisticRegressionModel train(DataFrame dataset) { // Extract columns from data using helper method. JavaRDD oldDataset = extractLabeledPoints(dataset).toJavaRDD(); - // Do learning to estimate the weight vector. + // Do learning to estimate the coefficients vector. int numFeatures = oldDataset.take(1).get(0).features().size(); - Vector weights = Vectors.zeros(numFeatures); // Learning would happen here. + Vector coefficients = Vectors.zeros(numFeatures); // Learning would happen here. // Create a model, and return it. - return new MyJavaLogisticRegressionModel(uid(), weights).setParent(this); + return new MyJavaLogisticRegressionModel(uid(), coefficients).setParent(this); } @Override @@ -173,12 +173,12 @@ public MyJavaLogisticRegression copy(ParamMap extra) { class MyJavaLogisticRegressionModel extends ClassificationModel { - private Vector weights_; - public Vector weights() { return weights_; } + private Vector coefficients_; + public Vector coefficients() { return coefficients_; } - public MyJavaLogisticRegressionModel(String uid, Vector weights) { + public MyJavaLogisticRegressionModel(String uid, Vector coefficients) { this.uid_ = uid; - this.weights_ = weights; + this.coefficients_ = coefficients; } private String uid_ = Identifiable$.MODULE$.randomUID("myJavaLogReg"); @@ -208,7 +208,7 @@ public String uid() { * modifier. */ public Vector predictRaw(Vector features) { - double margin = BLAS.dot(features, weights_); + double margin = BLAS.dot(features, coefficients_); // There are 2 classes (binary classification), so we return a length-2 vector, // where index i corresponds to class i (i = 0, 1). return Vectors.dense(-margin, margin); @@ -219,10 +219,15 @@ public Vector predictRaw(Vector features) { */ public int numClasses() { return 2; } + /** + * Number of features the model was trained on. + */ + public int numFeatures() { return coefficients_.size(); } + /** * Create a copy of the model. * The copy is shallow, except for the embedded paramMap, which gets a deep copy. - *

    + *

    * This is used for the defaul implementation of [[transform()]]. * * In Java, we have to make this method public since Java does not understand Scala's protected @@ -230,6 +235,7 @@ public Vector predictRaw(Vector features) { */ @Override public MyJavaLogisticRegressionModel copy(ParamMap extra) { - return copyValues(new MyJavaLogisticRegressionModel(uid(), weights_), extra); + return copyValues(new MyJavaLogisticRegressionModel(uid(), coefficients_), extra) + .setParent(parent()); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java new file mode 100644 index 000000000000..2898accec61b --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.ElementwiseProduct; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaElementwiseProductExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaElementwiseProductExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + // Create some vector data; also works for sparse vectors + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create("a", Vectors.dense(1.0, 2.0, 3.0)), + RowFactory.create("b", Vectors.dense(4.0, 5.0, 6.0)) + )); + + List fields = new ArrayList(2); + fields.add(DataTypes.createStructField("id", DataTypes.StringType, false)); + fields.add(DataTypes.createStructField("vector", new VectorUDT(), false)); + + StructType schema = DataTypes.createStructType(fields); + + DataFrame dataFrame = sqlContext.createDataFrame(jrdd, schema); + + Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0); + + ElementwiseProduct transformer = new ElementwiseProduct() + .setScalingVec(transformingVector) + .setInputCol("vector") + .setOutputCol("transformedVector"); + + // Batch transform the vectors 
to create new column: + transformer.transform(dataFrame).show(); + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeClassifierExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeClassifierExample.java new file mode 100644 index 000000000000..848fe6566c1e --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeClassifierExample.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +// $example on$ +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.classification.GBTClassificationModel; +import org.apache.spark.ml.classification.GBTClassifier; +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; +import org.apache.spark.ml.feature.*; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; +// $example off$ + +public class JavaGradientBoostedTreeClassifierExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaGradientBoostedTreeClassifierExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + // Load and parse the data file, converting it to a DataFrame. + DataFrame data = sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + + // Index labels, adding metadata to the label column. + // Fit on whole dataset to include all labels in index. + StringIndexerModel labelIndexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("indexedLabel") + .fit(data); + // Automatically identify categorical features, and index them. + // Set maxCategories so features with > 4 distinct values are treated as continuous. + VectorIndexerModel featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data); + + // Split the data into training and test sets (30% held out for testing) + DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3}); + DataFrame trainingData = splits[0]; + DataFrame testData = splits[1]; + + // Train a GBT model. + GBTClassifier gbt = new GBTClassifier() + .setLabelCol("indexedLabel") + .setFeaturesCol("indexedFeatures") + .setMaxIter(10); + + // Convert indexed labels back to original labels. 
+ IndexToString labelConverter = new IndexToString() + .setInputCol("prediction") + .setOutputCol("predictedLabel") + .setLabels(labelIndexer.labels()); + + // Chain indexers and GBT in a Pipeline + Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {labelIndexer, featureIndexer, gbt, labelConverter}); + + // Train model. This also runs the indexers. + PipelineModel model = pipeline.fit(trainingData); + + // Make predictions. + DataFrame predictions = model.transform(testData); + + // Select example rows to display. + predictions.select("predictedLabel", "label", "features").show(5); + + // Select (prediction, true label) and compute test error + MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() + .setLabelCol("indexedLabel") + .setPredictionCol("prediction") + .setMetricName("precision"); + double accuracy = evaluator.evaluate(predictions); + System.out.println("Test Error = " + (1.0 - accuracy)); + + GBTClassificationModel gbtModel = (GBTClassificationModel)(model.stages()[2]); + System.out.println("Learned classification GBT model:\n" + gbtModel.toDebugString()); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeRegressorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeRegressorExample.java new file mode 100644 index 000000000000..1f67b0842db0 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeRegressorExample.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +// $example on$ +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.evaluation.RegressionEvaluator; +import org.apache.spark.ml.feature.VectorIndexer; +import org.apache.spark.ml.feature.VectorIndexerModel; +import org.apache.spark.ml.regression.GBTRegressionModel; +import org.apache.spark.ml.regression.GBTRegressor; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; +// $example off$ + +public class JavaGradientBoostedTreeRegressorExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaGradientBoostedTreeRegressorExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + // Load and parse the data file, converting it to a DataFrame. 
+ DataFrame data = sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + + // Automatically identify categorical features, and index them. + // Set maxCategories so features with > 4 distinct values are treated as continuous. + VectorIndexerModel featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data); + + // Split the data into training and test sets (30% held out for testing) + DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3}); + DataFrame trainingData = splits[0]; + DataFrame testData = splits[1]; + + // Train a GBT model. + GBTRegressor gbt = new GBTRegressor() + .setLabelCol("label") + .setFeaturesCol("indexedFeatures") + .setMaxIter(10); + + // Chain indexer and GBT in a Pipeline + Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] {featureIndexer, gbt}); + + // Train model. This also runs the indexer. + PipelineModel model = pipeline.fit(trainingData); + + // Make predictions. + DataFrame predictions = model.transform(testData); + + // Select example rows to display. + predictions.select("prediction", "label", "features").show(5); + + // Select (prediction, true label) and compute test error + RegressionEvaluator evaluator = new RegressionEvaluator() + .setLabelCol("label") + .setPredictionCol("prediction") + .setMetricName("rmse"); + double rmse = evaluator.evaluate(predictions); + System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse); + + GBTRegressionModel gbtModel = (GBTRegressionModel)(model.stages()[1]); + System.out.println("Learned regression GBT model:\n" + gbtModel.toDebugString()); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaIndexToStringExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaIndexToStringExample.java new file mode 100644 index 000000000000..3ccd6993261e --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaIndexToStringExample.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.ml.feature.IndexToString; +import org.apache.spark.ml.feature.StringIndexer; +import org.apache.spark.ml.feature.StringIndexerModel; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaIndexToStringExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaIndexToStringExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "a"), + RowFactory.create(1, "b"), + RowFactory.create(2, "c"), + RowFactory.create(3, "a"), + RowFactory.create(4, "a"), + RowFactory.create(5, "c") + )); + StructType schema = new StructType(new StructField[]{ + new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("category", DataTypes.StringType, false, Metadata.empty()) + }); + DataFrame df = sqlContext.createDataFrame(jrdd, schema); + + StringIndexerModel indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") + .fit(df); + DataFrame indexed = indexer.transform(df); + + IndexToString converter = new IndexToString() + .setInputCol("categoryIndex") + .setOutputCol("originalCategory"); + DataFrame converted = converter.transform(indexed); + converted.select("id", "originalCategory").show(); + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java index be2bf0c7b465..96481d882a5d 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java @@ -23,6 +23,9 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.catalyst.expressions.GenericRow; +// $example on$ import org.apache.spark.ml.clustering.KMeansModel; import org.apache.spark.ml.clustering.KMeans; import org.apache.spark.mllib.linalg.Vector; @@ -30,18 +33,17 @@ import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.catalyst.expressions.GenericRow; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; +// $example off$ /** * An example demonstrating a k-means clustering. * Run with *

    - * bin/run-example ml.JavaSimpleParamsExample  
    + * bin/run-example ml.JavaKMeansExample  
      * 
    */ public class JavaKMeansExample { @@ -74,6 +76,7 @@ public static void main(String[] args) { JavaSparkContext jsc = new JavaSparkContext(conf); SQLContext sqlContext = new SQLContext(jsc); + // $example on$ // Loads data JavaRDD points = jsc.textFile(inputFile).map(new ParsePoint()); StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())}; @@ -91,6 +94,7 @@ public static void main(String[] args) { for (Vector center: centers) { System.out.println(center); } + // $example off$ jsc.stop(); } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java new file mode 100644 index 000000000000..3a5d3237c85f --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; +// $example on$ +import java.util.regex.Pattern; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.ml.clustering.LDA; +import org.apache.spark.ml.clustering.LDAModel; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +/** + * An example demonstrating LDA + * Run with + *
    + * bin/run-example ml.JavaLDAExample
    + * 
    + */ +public class JavaLDAExample { + + // $example on$ + private static class ParseVector implements Function { + private static final Pattern separator = Pattern.compile(" "); + + @Override + public Row call(String line) { + String[] tok = separator.split(line); + double[] point = new double[tok.length]; + for (int i = 0; i < tok.length; ++i) { + point[i] = Double.parseDouble(tok[i]); + } + Vector[] points = {Vectors.dense(point)}; + return new GenericRow(points); + } + } + + public static void main(String[] args) { + + String inputFile = "data/mllib/sample_lda_data.txt"; + + // Parses the arguments + SparkConf conf = new SparkConf().setAppName("JavaLDAExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // Loads data + JavaRDD points = jsc.textFile(inputFile).map(new ParseVector()); + StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())}; + StructType schema = new StructType(fields); + DataFrame dataset = sqlContext.createDataFrame(points, schema); + + // Trains a LDA model + LDA lda = new LDA() + .setK(10) + .setMaxIter(10); + LDAModel model = lda.fit(dataset); + + System.out.println(model.logLikelihood(dataset)); + System.out.println(model.logPerplexity(dataset)); + + // Shows the result + DataFrame topics = model.describeTopics(3); + topics.show(false); + model.transform(dataset).show(false); + + jsc.stop(); + } + // $example off$ +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java new file mode 100644 index 000000000000..4ad7676c8d32 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +// $example on$ +import org.apache.spark.ml.regression.LinearRegression; +import org.apache.spark.ml.regression.LinearRegressionModel; +import org.apache.spark.ml.regression.LinearRegressionTrainingSummary; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; +// $example off$ + +public class JavaLinearRegressionWithElasticNetExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaLinearRegressionWithElasticNetExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + // Load training data + DataFrame training = sqlContext.read().format("libsvm") + .load("data/mllib/sample_linear_regression_data.txt"); + + LinearRegression lr = new LinearRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8); + + // Fit the model + LinearRegressionModel lrModel = lr.fit(training); + + // Print the coefficients and intercept for linear regression + System.out.println("Coefficients: " + + lrModel.coefficients() + " Intercept: " + lrModel.intercept()); + + // Summarize the model over the training set and print out some metrics + LinearRegressionTrainingSummary trainingSummary = lrModel.summary(); + System.out.println("numIterations: " + trainingSummary.totalIterations()); + System.out.println("objectiveHistory: " + Vectors.dense(trainingSummary.objectiveHistory())); + trainingSummary.residuals().show(); + System.out.println("RMSE: " + trainingSummary.rootMeanSquaredError()); + System.out.println("r2: " + trainingSummary.r2()); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java new file mode 100644 index 000000000000..986f3b3b28d7 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +// $example on$ +import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.classification.LogisticRegressionModel; +import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.functions; +// $example off$ + +public class JavaLogisticRegressionSummaryExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaLogisticRegressionSummaryExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // Load training data + DataFrame training = sqlContext.read().format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); + + LogisticRegression lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8); + + // Fit the model + LogisticRegressionModel lrModel = lr.fit(training); + + // $example on$ + // Extract the summary from the returned LogisticRegressionModel instance trained in the earlier + // example + LogisticRegressionTrainingSummary trainingSummary = lrModel.summary(); + + // Obtain the loss per iteration. + double[] objectiveHistory = trainingSummary.objectiveHistory(); + for (double lossPerIteration : objectiveHistory) { + System.out.println(lossPerIteration); + } + + // Obtain the metrics useful to judge performance on test data. + // We cast the summary to a BinaryLogisticRegressionSummary since the problem is a binary + // classification problem. + BinaryLogisticRegressionSummary binarySummary = + (BinaryLogisticRegressionSummary) trainingSummary; + + // Obtain the receiver-operating characteristic as a dataframe and areaUnderROC. + DataFrame roc = binarySummary.roc(); + roc.show(); + roc.select("FPR").show(); + System.out.println(binarySummary.areaUnderROC()); + + // Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with + // this selected threshold. + DataFrame fMeasure = binarySummary.fMeasureByThreshold(); + double maxFMeasure = fMeasure.select(functions.max("F-Measure")).head().getDouble(0); + double bestThreshold = fMeasure.where(fMeasure.col("F-Measure").equalTo(maxFMeasure)) + .select("threshold").head().getDouble(0); + lrModel.setThreshold(bestThreshold); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java new file mode 100644 index 000000000000..1d28279d72a0 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +// $example on$ +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.classification.LogisticRegressionModel; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; +// $example off$ + +public class JavaLogisticRegressionWithElasticNetExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaLogisticRegressionWithElasticNetExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + // Load training data + DataFrame training = sqlContext.read().format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); + + LogisticRegression lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8); + + // Fit the model + LogisticRegressionModel lrModel = lr.fit(training); + + // Print the coefficients and intercept for logistic regression + System.out.println("Coefficients: " + + lrModel.coefficients() + " Intercept: " + lrModel.intercept()); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java new file mode 100644 index 000000000000..2d50ba7faa1a --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import org.apache.spark.ml.feature.MinMaxScaler; +import org.apache.spark.ml.feature.MinMaxScalerModel; +import org.apache.spark.sql.DataFrame; +// $example off$ + +public class JavaMinMaxScalerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JaveMinMaxScalerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + DataFrame dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + MinMaxScaler scaler = new MinMaxScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures"); + + // Compute summary statistics and generate MinMaxScalerModel + MinMaxScalerModel scalerModel = scaler.fit(dataFrame); + + // rescale each feature to range [min, max]. + DataFrame scaledData = scalerModel.transform(dataFrame); + scaledData.show(); + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java new file mode 100644 index 000000000000..84369f6681d0 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel; +import org.apache.spark.ml.classification.MultilayerPerceptronClassifier; +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; +import org.apache.spark.sql.DataFrame; +// $example off$ + +/** + * An example for Multilayer Perceptron Classification. 
+ */ +public class JavaMultilayerPerceptronClassifierExample { + + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaMultilayerPerceptronClassifierExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + // Load training data + String path = "data/mllib/sample_multiclass_classification_data.txt"; + DataFrame dataFrame = jsql.read().format("libsvm").load(path); + // Split the data into train and test + DataFrame[] splits = dataFrame.randomSplit(new double[]{0.6, 0.4}, 1234L); + DataFrame train = splits[0]; + DataFrame test = splits[1]; + // specify layers for the neural network: + // input layer of size 4 (features), two intermediate of size 5 and 4 + // and output of size 3 (classes) + int[] layers = new int[] {4, 5, 4, 3}; + // create the trainer and set its parameters + MultilayerPerceptronClassifier trainer = new MultilayerPerceptronClassifier() + .setLayers(layers) + .setBlockSize(128) + .setSeed(1234L) + .setMaxIter(100); + // train the model + MultilayerPerceptronClassificationModel model = trainer.fit(train); + // compute precision on the test set + DataFrame result = model.transform(test); + DataFrame predictionAndLabels = result.select("prediction", "label"); + MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() + .setMetricName("precision"); + System.out.println("Precision = " + evaluator.evaluate(predictionAndLabels)); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java new file mode 100644 index 000000000000..8fd75ed8b5f4 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.NGram; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaNGramExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaNGramExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0.0, Arrays.asList("Hi", "I", "heard", "about", "Spark")), + RowFactory.create(1.0, Arrays.asList("I", "wish", "Java", "could", "use", "case", "classes")), + RowFactory.create(2.0, Arrays.asList("Logistic", "regression", "models", "are", "neat")) + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField( + "words", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) + }); + + DataFrame wordDataFrame = sqlContext.createDataFrame(jrdd, schema); + + NGram ngramTransformer = new NGram().setInputCol("words").setOutputCol("ngrams"); + + DataFrame ngramDataFrame = ngramTransformer.transform(wordDataFrame); + + for (Row r : ngramDataFrame.select("ngrams", "label").take(3)) { + java.util.List ngrams = r.getList(0); + for (String ngram : ngrams) System.out.print(ngram + " --- "); + System.out.println(); + } + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java new file mode 100644 index 000000000000..ed3f6163c055 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import org.apache.spark.ml.feature.Normalizer; +import org.apache.spark.sql.DataFrame; +// $example off$ + +public class JavaNormalizerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaNormalizerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + DataFrame dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + + // Normalize each Vector using $L^1$ norm. + Normalizer normalizer = new Normalizer() + .setInputCol("features") + .setOutputCol("normFeatures") + .setP(1.0); + + DataFrame l1NormData = normalizer.transform(dataFrame); + l1NormData.show(); + + // Normalize each Vector using $L^\infty$ norm. + DataFrame lInfNormData = + normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY)); + lInfNormData.show(); + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java new file mode 100644 index 000000000000..bc509607084b --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.OneHotEncoder; +import org.apache.spark.ml.feature.StringIndexer; +import org.apache.spark.ml.feature.StringIndexerModel; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaOneHotEncoderExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaOneHotEncoderExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "a"), + RowFactory.create(1, "b"), + RowFactory.create(2, "c"), + RowFactory.create(3, "a"), + RowFactory.create(4, "a"), + RowFactory.create(5, "c") + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("id", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("category", DataTypes.StringType, false, Metadata.empty()) + }); + + DataFrame df = sqlContext.createDataFrame(jrdd, schema); + + StringIndexerModel indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") + .fit(df); + DataFrame indexed = indexer.transform(df); + + OneHotEncoder encoder = new OneHotEncoder() + .setInputCol("categoryIndex") + .setOutputCol("categoryVec"); + DataFrame encoded = encoder.transform(indexed); + encoded.select("id", "categoryVec").show(); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java index 75063dbf800d..42374e77acf0 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java @@ -21,18 +21,18 @@ import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; +// $example on$ import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.classification.OneVsRest; import org.apache.spark.ml.classification.OneVsRestModel; import org.apache.spark.ml.util.MetadataUtils; import org.apache.spark.mllib.evaluation.MulticlassMetrics; import org.apache.spark.mllib.linalg.Matrix; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.rdd.RDD; +import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.types.StructField; +// $example off$ /** * An example runner for Multiclass to Binary Reduction with One Vs Rest. 
@@ -63,6 +63,7 @@ public static void main(String[] args) { JavaSparkContext jsc = new JavaSparkContext(conf); SQLContext jsql = new SQLContext(jsc); + // $example on$ // configure the base classifier LogisticRegression classifier = new LogisticRegression() .setMaxIter(params.maxIter) @@ -80,31 +81,30 @@ public static void main(String[] args) { OneVsRest ovr = new OneVsRest().setClassifier(classifier); String input = params.input; - RDD inputData = MLUtils.loadLibSVMFile(jsc.sc(), input); - RDD train; - RDD test; + DataFrame inputData = jsql.read().format("libsvm").load(input); + DataFrame train; + DataFrame test; // compute the train/ test split: if testInput is not provided use part of input String testInput = params.testInput; if (testInput != null) { train = inputData; // compute the number of features in the training set. - int numFeatures = inputData.first().features().size(); - test = MLUtils.loadLibSVMFile(jsc.sc(), testInput, numFeatures); + int numFeatures = inputData.first().getAs(1).size(); + test = jsql.read().format("libsvm").option("numFeatures", + String.valueOf(numFeatures)).load(testInput); } else { double f = params.fracTest; - RDD[] tmp = inputData.randomSplit(new double[]{1 - f, f}, 12345); + DataFrame[] tmp = inputData.randomSplit(new double[]{1 - f, f}, 12345); train = tmp[0]; test = tmp[1]; } // train the multiclass model - DataFrame trainingDataFrame = jsql.createDataFrame(train, LabeledPoint.class); - OneVsRestModel ovrModel = ovr.fit(trainingDataFrame.cache()); + OneVsRestModel ovrModel = ovr.fit(train.cache()); // score the model on test data - DataFrame testDataFrame = jsql.createDataFrame(test, LabeledPoint.class); - DataFrame predictions = ovrModel.transform(testDataFrame.cache()) + DataFrame predictions = ovrModel.transform(test.cache()) .select("prediction", "label"); // obtain metrics @@ -128,6 +128,7 @@ public static void main(String[] args) { System.out.println(confusionMatrix); System.out.println(); System.out.println(results); + // $example off$ jsc.stop(); } @@ -178,6 +179,7 @@ private static Params parse(String[] args) { return params; } + @SuppressWarnings("static") private static Options generateCommandlineOptions() { Option input = OptionBuilder.withArgName("input") .hasArg() diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java new file mode 100644 index 000000000000..8282fab084f3 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.PCA; +import org.apache.spark.ml.feature.PCAModel; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaPCAExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaPCAExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.sparse(5, new int[]{1, 3}, new double[]{1.0, 7.0})), + RowFactory.create(Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)), + RowFactory.create(Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("features", new VectorUDT(), false, Metadata.empty()), + }); + + DataFrame df = jsql.createDataFrame(data, schema); + + PCAModel pca = new PCA() + .setInputCol("features") + .setOutputCol("pcaFeatures") + .setK(3) + .fit(df); + + DataFrame result = pca.transform(df).select("pcaFeatures"); + result.show(); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java new file mode 100644 index 000000000000..668f71e64056 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.PolynomialExpansion; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaPolynomialExpansionExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaPolynomialExpansionExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + PolynomialExpansion polyExpansion = new PolynomialExpansion() + .setInputCol("features") + .setOutputCol("polyFeatures") + .setDegree(3); + + JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.dense(-2.0, 2.3)), + RowFactory.create(Vectors.dense(0.0, 0.0)), + RowFactory.create(Vectors.dense(0.6, -1.1)) + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("features", new VectorUDT(), false, Metadata.empty()), + }); + + DataFrame df = jsql.createDataFrame(data, schema); + DataFrame polyDF = polyExpansion.transform(df); + + Row[] row = polyDF.select("polyFeatures").take(3); + for (Row r : row) { + System.out.println(r.get(0)); + } + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java new file mode 100644 index 000000000000..251ae79d9a10 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.QuantileDiscretizer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaQuantileDiscretizerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaQuantileDiscretizerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize( + Arrays.asList( + RowFactory.create(0, 18.0), + RowFactory.create(1, 19.0), + RowFactory.create(2, 8.0), + RowFactory.create(3, 5.0), + RowFactory.create(4, 2.2) + ) + ); + + StructType schema = new StructType(new StructField[]{ + new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("hour", DataTypes.DoubleType, false, Metadata.empty()) + }); + + DataFrame df = sqlContext.createDataFrame(jrdd, schema); + + QuantileDiscretizer discretizer = new QuantileDiscretizer() + .setInputCol("hour") + .setOutputCol("result") + .setNumBuckets(3); + + DataFrame result = discretizer.fit(df).transform(df); + result.show(); + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java new file mode 100644 index 000000000000..1e1062b541ad --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.RFormula; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +import static org.apache.spark.sql.types.DataTypes.*; +// $example off$ + +public class JavaRFormulaExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaRFormulaExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + StructType schema = createStructType(new StructField[]{ + createStructField("id", IntegerType, false), + createStructField("country", StringType, false), + createStructField("hour", IntegerType, false), + createStructField("clicked", DoubleType, false) + }); + + JavaRDD rdd = jsc.parallelize(Arrays.asList( + RowFactory.create(7, "US", 18, 1.0), + RowFactory.create(8, "CA", 12, 0.0), + RowFactory.create(9, "NZ", 15, 0.0) + )); + + DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + RFormula formula = new RFormula() + .setFormula("clicked ~ country + hour") + .setFeaturesCol("features") + .setLabelCol("label"); + DataFrame output = formula.fit(dataset).transform(dataset); + output.select("features", "label").show(); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestClassifierExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestClassifierExample.java new file mode 100644 index 000000000000..5a6249666029 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestClassifierExample.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +// $example on$ +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.classification.RandomForestClassificationModel; +import org.apache.spark.ml.classification.RandomForestClassifier; +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; +import org.apache.spark.ml.feature.*; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; +// $example off$ + +public class JavaRandomForestClassifierExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaRandomForestClassifierExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + // Load and parse the data file, converting it to a DataFrame. + DataFrame data = sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + + // Index labels, adding metadata to the label column. + // Fit on whole dataset to include all labels in index. + StringIndexerModel labelIndexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("indexedLabel") + .fit(data); + // Automatically identify categorical features, and index them. + // Set maxCategories so features with > 4 distinct values are treated as continuous. + VectorIndexerModel featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data); + + // Split the data into training and test sets (30% held out for testing) + DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3}); + DataFrame trainingData = splits[0]; + DataFrame testData = splits[1]; + + // Train a RandomForest model. + RandomForestClassifier rf = new RandomForestClassifier() + .setLabelCol("indexedLabel") + .setFeaturesCol("indexedFeatures"); + + // Convert indexed labels back to original labels. + IndexToString labelConverter = new IndexToString() + .setInputCol("prediction") + .setOutputCol("predictedLabel") + .setLabels(labelIndexer.labels()); + + // Chain indexers and forest in a Pipeline + Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {labelIndexer, featureIndexer, rf, labelConverter}); + + // Train model. This also runs the indexers. + PipelineModel model = pipeline.fit(trainingData); + + // Make predictions. + DataFrame predictions = model.transform(testData); + + // Select example rows to display. 
+ predictions.select("predictedLabel", "label", "features").show(5); + + // Select (prediction, true label) and compute test error + MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() + .setLabelCol("indexedLabel") + .setPredictionCol("prediction") + .setMetricName("precision"); + double accuracy = evaluator.evaluate(predictions); + System.out.println("Test Error = " + (1.0 - accuracy)); + + RandomForestClassificationModel rfModel = (RandomForestClassificationModel)(model.stages()[2]); + System.out.println("Learned classification forest model:\n" + rfModel.toDebugString()); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestRegressorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestRegressorExample.java new file mode 100644 index 000000000000..05782a0724a7 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestRegressorExample.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +// $example on$ +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.evaluation.RegressionEvaluator; +import org.apache.spark.ml.feature.VectorIndexer; +import org.apache.spark.ml.feature.VectorIndexerModel; +import org.apache.spark.ml.regression.RandomForestRegressionModel; +import org.apache.spark.ml.regression.RandomForestRegressor; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; +// $example off$ + +public class JavaRandomForestRegressorExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaRandomForestRegressorExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + // Load and parse the data file, converting it to a DataFrame. + DataFrame data = sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + + // Automatically identify categorical features, and index them. + // Set maxCategories so features with > 4 distinct values are treated as continuous. + VectorIndexerModel featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data); + + // Split the data into training and test sets (30% held out for testing) + DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3}); + DataFrame trainingData = splits[0]; + DataFrame testData = splits[1]; + + // Train a RandomForest model. 
+ RandomForestRegressor rf = new RandomForestRegressor() + .setLabelCol("label") + .setFeaturesCol("indexedFeatures"); + + // Chain indexer and forest in a Pipeline + Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {featureIndexer, rf}); + + // Train model. This also runs the indexer. + PipelineModel model = pipeline.fit(trainingData); + + // Make predictions. + DataFrame predictions = model.transform(testData); + + // Select example rows to display. + predictions.select("prediction", "label", "features").show(5); + + // Select (prediction, true label) and compute test error + RegressionEvaluator evaluator = new RegressionEvaluator() + .setLabelCol("label") + .setPredictionCol("prediction") + .setMetricName("rmse"); + double rmse = evaluator.evaluate(predictions); + System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse); + + RandomForestRegressionModel rfModel = (RandomForestRegressionModel)(model.stages()[1]); + System.out.println("Learned regression forest model:\n" + rfModel.toDebugString()); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSQLTransformerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSQLTransformerExample.java new file mode 100644 index 000000000000..d55c70796a96 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSQLTransformerExample.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.feature.SQLTransformer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.*; +// $example off$ + +public class JavaSQLTransformerExample { + public static void main(String[] args) { + + SparkConf conf = new SparkConf().setAppName("JavaSQLTransformerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, 1.0, 3.0), + RowFactory.create(2, 2.0, 5.0) + )); + StructType schema = new StructType(new StructField [] { + new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("v1", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("v2", DataTypes.DoubleType, false, Metadata.empty()) + }); + DataFrame df = sqlContext.createDataFrame(jrdd, schema); + + SQLTransformer sqlTrans = new SQLTransformer().setStatement( + "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__"); + + sqlTrans.transform(df).show(); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java index dac649d1d5ae..ea83e8fef9eb 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java @@ -77,7 +77,8 @@ public static void main(String[] args) { ParamMap paramMap = new ParamMap(); paramMap.put(lr.maxIter().w(20)); // Specify 1 Param. paramMap.put(lr.maxIter(), 30); // This overwrites the original maxIter. - paramMap.put(lr.regParam().w(0.1), lr.threshold().w(0.55)); // Specify multiple Params. + double[] thresholds = {0.45, 0.55}; + paramMap.put(lr.regParam().w(0.1), lr.thresholds().w(thresholds)); // Specify multiple Params. // One can also combine ParamMaps. ParamMap paramMap2 = new ParamMap(); diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java new file mode 100644 index 000000000000..da4756643f3c --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import org.apache.spark.ml.feature.StandardScaler; +import org.apache.spark.ml.feature.StandardScalerModel; +import org.apache.spark.sql.DataFrame; +// $example off$ + +public class JavaStandardScalerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaStandardScalerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + DataFrame dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + + StandardScaler scaler = new StandardScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + .setWithStd(true) + .setWithMean(false); + + // Compute summary statistics by fitting the StandardScaler + StandardScalerModel scalerModel = scaler.fit(dataFrame); + + // Normalize each feature to have unit standard deviation. + DataFrame scaledData = scalerModel.transform(dataFrame); + scaledData.show(); + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java new file mode 100644 index 000000000000..b6b201c6b68d --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.StopWordsRemover; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaStopWordsRemoverExample { + + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaStopWordsRemoverExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + StopWordsRemover remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered"); + + JavaRDD rdd = jsc.parallelize(Arrays.asList( + RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")), + RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb")) + )); + + StructType schema = new StructType(new StructField[]{ + new StructField( + "raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) + }); + + DataFrame dataset = jsql.createDataFrame(rdd, schema); + remover.transform(dataset).show(); + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java new file mode 100644 index 000000000000..05d12c1e702f --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.StringIndexer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +import static org.apache.spark.sql.types.DataTypes.*; +// $example off$ + +public class JavaStringIndexerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaStringIndexerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "a"), + RowFactory.create(1, "b"), + RowFactory.create(2, "c"), + RowFactory.create(3, "a"), + RowFactory.create(4, "a"), + RowFactory.create(5, "c") + )); + StructType schema = new StructType(new StructField[]{ + createStructField("id", IntegerType, false), + createStructField("category", StringType, false) + }); + DataFrame df = sqlContext.createDataFrame(jrdd, schema); + StringIndexer indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex"); + DataFrame indexed = indexer.fit(df).transform(df); + indexed.show(); + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java new file mode 100644 index 000000000000..a41a5ec9bff0 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.feature.HashingTF; +import org.apache.spark.ml.feature.IDF; +import org.apache.spark.ml.feature.IDFModel; +import org.apache.spark.ml.feature.Tokenizer; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaTfIdfExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaTfIdfExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "Hi I heard about Spark"), + RowFactory.create(0, "I wish Java could use case classes"), + RowFactory.create(1, "Logistic regression models are neat") + )); + StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) + }); + DataFrame sentenceData = sqlContext.createDataFrame(jrdd, schema); + Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words"); + DataFrame wordsData = tokenizer.transform(sentenceData); + int numFeatures = 20; + HashingTF hashingTF = new HashingTF() + .setInputCol("words") + .setOutputCol("rawFeatures") + .setNumFeatures(numFeatures); + DataFrame featurizedData = hashingTF.transform(wordsData); + IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features"); + IDFModel idfModel = idf.fit(featurizedData); + DataFrame rescaledData = idfModel.transform(featurizedData); + for (Row r : rescaledData.select("features", "label").take(3)) { + Vector features = r.getAs(0); + Double label = r.getDouble(1); + System.out.println(features); + System.out.println(label); + } + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java new file mode 100644 index 000000000000..617dc3f66e3b --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.RegexTokenizer; +import org.apache.spark.ml.feature.Tokenizer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaTokenizerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaTokenizerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "Hi I heard about Spark"), + RowFactory.create(1, "I wish Java could use case classes"), + RowFactory.create(2, "Logistic,regression,models,are,neat") + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) + }); + + DataFrame sentenceDataFrame = sqlContext.createDataFrame(jrdd, schema); + + Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words"); + + DataFrame wordsDataFrame = tokenizer.transform(sentenceDataFrame); + for (Row r : wordsDataFrame.select("words", "label"). take(3)) { + java.util.List words = r.getList(0); + for (String word : words) System.out.print(word + " "); + System.out.println(); + } + + RegexTokenizer regexTokenizer = new RegexTokenizer() + .setInputCol("sentence") + .setOutputCol("words") + .setPattern("\\W"); // alternatively .setPattern("\\w+").setGaps(false); + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java new file mode 100644 index 000000000000..d433905fc801 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.evaluation.RegressionEvaluator; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.ml.regression.LinearRegression; +import org.apache.spark.ml.tuning.*; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; + +/** + * A simple example demonstrating model selection using TrainValidationSplit. + * + * The example is based on {@link org.apache.spark.examples.ml.JavaSimpleParamsExample} + * using linear regression. + * + * Run with + * {{{ + * bin/run-example ml.JavaTrainValidationSplitExample + * }}} + */ +public class JavaTrainValidationSplitExample { + + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaTrainValidationSplitExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + DataFrame data = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + + // Prepare training and test data. + DataFrame[] splits = data.randomSplit(new double [] {0.9, 0.1}, 12345); + DataFrame training = splits[0]; + DataFrame test = splits[1]; + + LinearRegression lr = new LinearRegression(); + + // We use a ParamGridBuilder to construct a grid of parameters to search over. + // TrainValidationSplit will try all combinations of values and determine best model using + // the evaluator. + ParamMap[] paramGrid = new ParamGridBuilder() + .addGrid(lr.regParam(), new double[] {0.1, 0.01}) + .addGrid(lr.fitIntercept()) + .addGrid(lr.elasticNetParam(), new double[] {0.0, 0.5, 1.0}) + .build(); + + // In this case the estimator is simply the linear regression. + // A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. + TrainValidationSplit trainValidationSplit = new TrainValidationSplit() + .setEstimator(lr) + .setEvaluator(new RegressionEvaluator()) + .setEstimatorParamMaps(paramGrid); + + // 80% of the data will be used for training and the remaining 20% for validation. + trainValidationSplit.setTrainRatio(0.8); + + // Run train validation split, and choose the best set of parameters. + TrainValidationSplitModel model = trainValidationSplit.fit(training); + + // Make predictions on test data. model is the model with combination of parameters + // that performed best. + model.transform(test) + .select("features", "label", "prediction") + .show(); + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java new file mode 100644 index 000000000000..7e230b5897c1 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.VectorAssembler; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.*; + +import static org.apache.spark.sql.types.DataTypes.*; +// $example off$ + +public class JavaVectorAssemblerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaVectorAssemblerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + StructType schema = createStructType(new StructField[]{ + createStructField("id", IntegerType, false), + createStructField("hour", IntegerType, false), + createStructField("mobile", DoubleType, false), + createStructField("userFeatures", new VectorUDT(), false), + createStructField("clicked", DoubleType, false) + }); + Row row = RowFactory.create(0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0); + JavaRDD rdd = jsc.parallelize(Arrays.asList(row)); + DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + + VectorAssembler assembler = new VectorAssembler() + .setInputCols(new String[]{"hour", "mobile", "userFeatures"}) + .setOutputCol("features"); + + DataFrame output = assembler.transform(dataset); + System.out.println(output.select("features", "clicked").first()); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java new file mode 100644 index 000000000000..545758e31d97 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.examples.ml;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SQLContext;
+
+// $example on$
+import java.util.Map;
+
+import org.apache.spark.ml.feature.VectorIndexer;
+import org.apache.spark.ml.feature.VectorIndexerModel;
+import org.apache.spark.sql.DataFrame;
+// $example off$
+
+public class JavaVectorIndexerExample {
+  public static void main(String[] args) {
+    SparkConf conf = new SparkConf().setAppName("JavaVectorIndexerExample");
+    JavaSparkContext jsc = new JavaSparkContext(conf);
+    SQLContext jsql = new SQLContext(jsc);
+
+    // $example on$
+    DataFrame data = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
+
+    VectorIndexer indexer = new VectorIndexer()
+      .setInputCol("features")
+      .setOutputCol("indexed")
+      .setMaxCategories(10);
+    VectorIndexerModel indexerModel = indexer.fit(data);
+
+    Map<Integer, Map<Double, Integer>> categoryMaps = indexerModel.javaCategoryMaps();
+    System.out.print("Chose " + categoryMaps.size() + " categorical features:");
+
+    for (Integer feature : categoryMaps.keySet()) {
+      System.out.print(" " + feature);
+    }
+    System.out.println();
+
+    // Create new column "indexed" with categorical values transformed to indices
+    DataFrame indexedData = indexerModel.transform(data);
+    indexedData.show();
+    // $example off$
+    jsc.stop();
+  }
+}
\ No newline at end of file
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java
new file mode 100644
index 000000000000..4d5cb04ff5e2
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import com.google.common.collect.Lists; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.attribute.Attribute; +import org.apache.spark.ml.attribute.AttributeGroup; +import org.apache.spark.ml.attribute.NumericAttribute; +import org.apache.spark.ml.feature.VectorSlicer; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.*; +// $example off$ + +public class JavaVectorSlicerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaVectorSlicerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + Attribute[] attrs = new Attribute[]{ + NumericAttribute.defaultAttr().withName("f1"), + NumericAttribute.defaultAttr().withName("f2"), + NumericAttribute.defaultAttr().withName("f3") + }; + AttributeGroup group = new AttributeGroup("userFeatures", attrs); + + JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( + RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})), + RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0)) + )); + + DataFrame dataset = jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField())); + + VectorSlicer vectorSlicer = new VectorSlicer() + .setInputCol("userFeatures").setOutputCol("features"); + + vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"}); + // or slicer.setIndices(new int[]{1, 2}), or slicer.setNames(new String[]{"f2", "f3"}) + + DataFrame output = vectorSlicer.transform(dataset); + + System.out.println(output.select("userFeatures", "features").first()); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaWord2VecExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaWord2VecExample.java new file mode 100644 index 000000000000..d472375ca982 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaWord2VecExample.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.feature.Word2Vec; +import org.apache.spark.ml.feature.Word2VecModel; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.*; +// $example off$ + +public class JavaWord2VecExample { + public static void main(String[] args) { + + SparkConf conf = new SparkConf().setAppName("JavaWord2VecExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + // Input data: Each row is a bag of words from a sentence or document. + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))), + RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))), + RowFactory.create(Arrays.asList("Logistic regression models are neat".split(" "))) + )); + StructType schema = new StructType(new StructField[]{ + new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) + }); + DataFrame documentDF = sqlContext.createDataFrame(jrdd, schema); + + // Learn a mapping from words to Vectors. + Word2Vec word2Vec = new Word2Vec() + .setInputCol("text") + .setOutputCol("result") + .setVectorSize(3) + .setMinCount(0); + Word2VecModel model = word2Vec.fit(documentDF); + DataFrame result = model.transform(documentDF); + for (Row r : result.select("result").take(3)) { + System.out.println(r); + } + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaAssociationRulesExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaAssociationRulesExample.java new file mode 100644 index 000000000000..4d0f989819ac --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaAssociationRulesExample.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.examples.mllib;
+
+// $example on$
+import java.util.Arrays;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.fpm.AssociationRules;
+import org.apache.spark.mllib.fpm.FPGrowth;
+import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset;
+// $example off$
+
+import org.apache.spark.SparkConf;
+
+public class JavaAssociationRulesExample {
+
+  public static void main(String[] args) {
+
+    SparkConf sparkConf = new SparkConf().setAppName("JavaAssociationRulesExample");
+    JavaSparkContext sc = new JavaSparkContext(sparkConf);
+
+    // $example on$
+    JavaRDD<FPGrowth.FreqItemset<String>> freqItemsets = sc.parallelize(Arrays.asList(
+      new FreqItemset<String>(new String[] {"a"}, 15L),
+      new FreqItemset<String>(new String[] {"b"}, 35L),
+      new FreqItemset<String>(new String[] {"a", "b"}, 12L)
+    ));
+
+    AssociationRules arules = new AssociationRules()
+      .setMinConfidence(0.8);
+    JavaRDD<AssociationRules.Rule<String>> results = arules.run(freqItemsets);
+
+    for (AssociationRules.Rule<String> rule : results.collect()) {
+      System.out.println(
+        rule.javaAntecedent() + " => " + rule.javaConsequent() + ", " + rule.confidence());
+    }
+    // $example off$
+  }
+}
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java
new file mode 100644
index 000000000000..980a9108af53
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib;
+
+// $example on$
+import scala.Tuple2;
+
+import org.apache.spark.api.java.*;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.classification.LogisticRegressionModel;
+import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
+import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.util.MLUtils;
+// $example off$
+import org.apache.spark.SparkConf;
+import org.apache.spark.SparkContext;
+
+public class JavaBinaryClassificationMetricsExample {
+  public static void main(String[] args) {
+    SparkConf conf = new SparkConf().setAppName("Java Binary Classification Metrics Example");
+    SparkContext sc = new SparkContext(conf);
+    // $example on$
+    String path = "data/mllib/sample_binary_classification_data.txt";
+    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD();
+
+    // Split initial RDD into two... [60% training data, 40% testing data].
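+    // randomSplit divides the data in proportion to the supplied weights; the second
+    // argument (11L) is a random seed, so the 60/40 split is reproducible across runs.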
+ JavaRDD[] splits = + data.randomSplit(new double[]{0.6, 0.4}, 11L); + JavaRDD training = splits[0].cache(); + JavaRDD test = splits[1]; + + // Run training algorithm to build the model. + final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() + .setNumClasses(2) + .run(training.rdd()); + + // Clear the prediction threshold so the model will return probabilities + model.clearThreshold(); + + // Compute raw scores on the test set. + JavaRDD> predictionAndLabels = test.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + Double prediction = model.predict(p.features()); + return new Tuple2(prediction, p.label()); + } + } + ); + + // Get evaluation metrics. + BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(predictionAndLabels.rdd()); + + // Precision by threshold + JavaRDD> precision = metrics.precisionByThreshold().toJavaRDD(); + System.out.println("Precision by threshold: " + precision.toArray()); + + // Recall by threshold + JavaRDD> recall = metrics.recallByThreshold().toJavaRDD(); + System.out.println("Recall by threshold: " + recall.toArray()); + + // F Score by threshold + JavaRDD> f1Score = metrics.fMeasureByThreshold().toJavaRDD(); + System.out.println("F1 Score by threshold: " + f1Score.toArray()); + + JavaRDD> f2Score = metrics.fMeasureByThreshold(2.0).toJavaRDD(); + System.out.println("F2 Score by threshold: " + f2Score.toArray()); + + // Precision-recall curve + JavaRDD> prc = metrics.pr().toJavaRDD(); + System.out.println("Precision-recall curve: " + prc.toArray()); + + // Thresholds + JavaRDD thresholds = precision.map( + new Function, Double>() { + public Double call(Tuple2 t) { + return new Double(t._1().toString()); + } + } + ); + + // ROC Curve + JavaRDD> roc = metrics.roc().toJavaRDD(); + System.out.println("ROC curve: " + roc.toArray()); + + // AUPRC + System.out.println("Area under precision-recall curve = " + metrics.areaUnderPR()); + + // AUROC + System.out.println("Area under ROC = " + metrics.areaUnderROC()); + + // Save and load model + model.save(sc, "target/tmp/LogisticRegressionModel"); + LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, + "target/tmp/LogisticRegressionModel"); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java new file mode 100644 index 000000000000..0001500f4fa5 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.examples.mllib;
+
+import java.util.ArrayList;
+
+// $example on$
+import com.google.common.collect.Lists;
+// $example off$
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaSparkContext;
+// $example on$
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.mllib.clustering.BisectingKMeans;
+import org.apache.spark.mllib.clustering.BisectingKMeansModel;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.Vectors;
+// $example off$
+
+/**
+ * Java example for clustering using bisecting k-means.
+ */
+public class JavaBisectingKMeansExample {
+  public static void main(String[] args) {
+    SparkConf sparkConf = new SparkConf().setAppName("JavaBisectingKMeansExample");
+    JavaSparkContext sc = new JavaSparkContext(sparkConf);
+
+    // $example on$
+    ArrayList<Vector> localData = Lists.newArrayList(
+      Vectors.dense(0.1, 0.1), Vectors.dense(0.3, 0.3),
+      Vectors.dense(10.1, 10.1), Vectors.dense(10.3, 10.3),
+      Vectors.dense(20.1, 20.1), Vectors.dense(20.3, 20.3),
+      Vectors.dense(30.1, 30.1), Vectors.dense(30.3, 30.3)
+    );
+    JavaRDD<Vector> data = sc.parallelize(localData, 2);
+
+    BisectingKMeans bkm = new BisectingKMeans()
+      .setK(4);
+    BisectingKMeansModel model = bkm.run(data);
+
+    System.out.println("Compute Cost: " + model.computeCost(data));
+
+    Vector[] clusterCenters = model.clusterCenters();
+    for (int i = 0; i < clusterCenters.length; i++) {
+      Vector clusterCenter = clusterCenters[i];
+      System.out.println("Cluster Center " + i + ": " + clusterCenter);
+    }
+    // $example off$
+
+    sc.stop();
+  }
+}
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java
deleted file mode 100644
index 1f82e3f4cb18..000000000000
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java
+++ /dev/null
@@ -1,116 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -package org.apache.spark.examples.mllib; - -import java.util.HashMap; - -import scala.Tuple2; - -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.DecisionTree; -import org.apache.spark.mllib.tree.model.DecisionTreeModel; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.SparkConf; - -/** - * Classification and regression using decision trees. - */ -public final class JavaDecisionTree { - - public static void main(String[] args) { - String datapath = "data/mllib/sample_libsvm_data.txt"; - if (args.length == 1) { - datapath = args[0]; - } else if (args.length > 1) { - System.err.println("Usage: JavaDecisionTree "); - System.exit(1); - } - SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree"); - JavaSparkContext sc = new JavaSparkContext(sparkConf); - - JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache(); - - // Compute the number of classes from the data. - Integer numClasses = data.map(new Function() { - @Override public Double call(LabeledPoint p) { - return p.label(); - } - }).countByValue().size(); - - // Set parameters. - // Empty categoricalFeaturesInfo indicates all features are continuous. - HashMap categoricalFeaturesInfo = new HashMap(); - String impurity = "gini"; - Integer maxDepth = 5; - Integer maxBins = 32; - - // Train a DecisionTree model for classification. - final DecisionTreeModel model = DecisionTree.trainClassifier(data, numClasses, - categoricalFeaturesInfo, impurity, maxDepth, maxBins); - - // Evaluate model on training instances and compute training error - JavaPairRDD predictionAndLabel = - data.mapToPair(new PairFunction() { - @Override public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); - Double trainErr = - 1.0 * predictionAndLabel.filter(new Function, Boolean>() { - @Override public Boolean call(Tuple2 pl) { - return !pl._1().equals(pl._2()); - } - }).count() / data.count(); - System.out.println("Training error: " + trainErr); - System.out.println("Learned classification tree model:\n" + model); - - // Train a DecisionTree model for regression. 
- impurity = "variance"; - final DecisionTreeModel regressionModel = DecisionTree.trainRegressor(data, - categoricalFeaturesInfo, impurity, maxDepth, maxBins); - - // Evaluate model on training instances and compute training error - JavaPairRDD regressorPredictionAndLabel = - data.mapToPair(new PairFunction() { - @Override public Tuple2 call(LabeledPoint p) { - return new Tuple2(regressionModel.predict(p.features()), p.label()); - } - }); - Double trainMSE = - regressorPredictionAndLabel.map(new Function, Double>() { - @Override public Double call(Tuple2 pl) { - Double diff = pl._1() - pl._2(); - return diff * diff; - } - }).reduce(new Function2() { - @Override public Double call(Double a, Double b) { - return a + b; - } - }) / data.count(); - System.out.println("Training Mean Squared Error: " + trainMSE); - System.out.println("Learned regression tree model:\n" + regressionModel); - - sc.stop(); - } -} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java new file mode 100644 index 000000000000..5839b0cf8a8f --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.HashMap; +import java.util.Map; + +import scala.Tuple2; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.tree.DecisionTree; +import org.apache.spark.mllib.tree.model.DecisionTreeModel; +import org.apache.spark.mllib.util.MLUtils; +// $example off$ + +class JavaDecisionTreeClassificationExample { + + public static void main(String[] args) { + + // $example on$ + SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTreeClassificationExample"); + JavaSparkContext jsc = new JavaSparkContext(sparkConf); + + // Load and parse the data file. + String datapath = "data/mllib/sample_libsvm_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD(); + // Split the data into training and test sets (30% held out for testing) + JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); + JavaRDD trainingData = splits[0]; + JavaRDD testData = splits[1]; + + // Set parameters. + // Empty categoricalFeaturesInfo indicates all features are continuous. 
+ Integer numClasses = 2; + Map categoricalFeaturesInfo = new HashMap(); + String impurity = "gini"; + Integer maxDepth = 5; + Integer maxBins = 32; + + // Train a DecisionTree model for classification. + final DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses, + categoricalFeaturesInfo, impurity, maxDepth, maxBins); + + // Evaluate model on test instances and compute test error + JavaPairRDD predictionAndLabel = + testData.mapToPair(new PairFunction() { + @Override + public Tuple2 call(LabeledPoint p) { + return new Tuple2(model.predict(p.features()), p.label()); + } + }); + Double testErr = + 1.0 * predictionAndLabel.filter(new Function, Boolean>() { + @Override + public Boolean call(Tuple2 pl) { + return !pl._1().equals(pl._2()); + } + }).count() / testData.count(); + + System.out.println("Test Error: " + testErr); + System.out.println("Learned classification tree model:\n" + model.toDebugString()); + + // Save and load model + model.save(jsc.sc(), "target/tmp/myDecisionTreeClassificationModel"); + DecisionTreeModel sameModel = DecisionTreeModel + .load(jsc.sc(), "target/tmp/myDecisionTreeClassificationModel"); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeRegressionExample.java new file mode 100644 index 000000000000..ccde578249f7 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeRegressionExample.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.HashMap; +import java.util.Map; + +import scala.Tuple2; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.tree.DecisionTree; +import org.apache.spark.mllib.tree.model.DecisionTreeModel; +import org.apache.spark.mllib.util.MLUtils; +// $example off$ + +class JavaDecisionTreeRegressionExample { + + public static void main(String[] args) { + + // $example on$ + SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTreeRegressionExample"); + JavaSparkContext jsc = new JavaSparkContext(sparkConf); + + // Load and parse the data file. 
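+    // MLUtils.loadLibSVMFile parses LIBSVM-formatted text ("label index:value ...") into an
+    // RDD of LabeledPoint; toJavaRDD() exposes it through the Java API.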
+ String datapath = "data/mllib/sample_libsvm_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD(); + // Split the data into training and test sets (30% held out for testing) + JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); + JavaRDD trainingData = splits[0]; + JavaRDD testData = splits[1]; + + // Set parameters. + // Empty categoricalFeaturesInfo indicates all features are continuous. + Map categoricalFeaturesInfo = new HashMap(); + String impurity = "variance"; + Integer maxDepth = 5; + Integer maxBins = 32; + + // Train a DecisionTree model. + final DecisionTreeModel model = DecisionTree.trainRegressor(trainingData, + categoricalFeaturesInfo, impurity, maxDepth, maxBins); + + // Evaluate model on test instances and compute test error + JavaPairRDD predictionAndLabel = + testData.mapToPair(new PairFunction() { + @Override + public Tuple2 call(LabeledPoint p) { + return new Tuple2(model.predict(p.features()), p.label()); + } + }); + Double testMSE = + predictionAndLabel.map(new Function, Double>() { + @Override + public Double call(Tuple2 pl) { + Double diff = pl._1() - pl._2(); + return diff * diff; + } + }).reduce(new Function2() { + @Override + public Double call(Double a, Double b) { + return a + b; + } + }) / data.count(); + System.out.println("Test Mean Squared Error: " + testMSE); + System.out.println("Learned regression tree model:\n" + model.toDebugString()); + + // Save and load model + model.save(jsc.sc(), "target/tmp/myDecisionTreeRegressionModel"); + DecisionTreeModel sameModel = DecisionTreeModel + .load(jsc.sc(), "target/tmp/myDecisionTreeRegressionModel"); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java deleted file mode 100644 index a1844d5d07ad..000000000000 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.examples.mllib; - -import scala.Tuple2; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.GradientBoostedTrees; -import org.apache.spark.mllib.tree.configuration.BoostingStrategy; -import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel; -import org.apache.spark.mllib.util.MLUtils; - -/** - * Classification and regression using gradient-boosted decision trees. - */ -public final class JavaGradientBoostedTreesRunner { - - private static void usage() { - System.err.println("Usage: JavaGradientBoostedTreesRunner " + - " "); - System.exit(-1); - } - - public static void main(String[] args) { - String datapath = "data/mllib/sample_libsvm_data.txt"; - String algo = "Classification"; - if (args.length >= 1) { - datapath = args[0]; - } - if (args.length >= 2) { - algo = args[1]; - } - if (args.length > 2) { - usage(); - } - SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTreesRunner"); - JavaSparkContext sc = new JavaSparkContext(sparkConf); - - JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache(); - - // Set parameters. - // Note: All features are treated as continuous. - BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(algo); - boostingStrategy.setNumIterations(10); - boostingStrategy.treeStrategy().setMaxDepth(5); - - if (algo.equals("Classification")) { - // Compute the number of classes from the data. - Integer numClasses = data.map(new Function() { - @Override public Double call(LabeledPoint p) { - return p.label(); - } - }).countByValue().size(); - boostingStrategy.treeStrategy().setNumClasses(numClasses); - - // Train a GradientBoosting model for classification. - final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy); - - // Evaluate model on training instances and compute training error - JavaPairRDD predictionAndLabel = - data.mapToPair(new PairFunction() { - @Override public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); - Double trainErr = - 1.0 * predictionAndLabel.filter(new Function, Boolean>() { - @Override public Boolean call(Tuple2 pl) { - return !pl._1().equals(pl._2()); - } - }).count() / data.count(); - System.out.println("Training error: " + trainErr); - System.out.println("Learned classification tree model:\n" + model); - } else if (algo.equals("Regression")) { - // Train a GradientBoosting model for classification. 
- final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy); - - // Evaluate model on training instances and compute training error - JavaPairRDD predictionAndLabel = - data.mapToPair(new PairFunction() { - @Override public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); - Double trainMSE = - predictionAndLabel.map(new Function, Double>() { - @Override public Double call(Tuple2 pl) { - Double diff = pl._1() - pl._2(); - return diff * diff; - } - }).reduce(new Function2() { - @Override public Double call(Double a, Double b) { - return a + b; - } - }) / data.count(); - System.out.println("Training Mean Squared Error: " + trainMSE); - System.out.println("Learned regression tree model:\n" + model); - } else { - usage(); - } - - sc.stop(); - } -} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostingClassificationExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostingClassificationExample.java new file mode 100644 index 000000000000..80faabd2325d --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostingClassificationExample.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.HashMap; +import java.util.Map; + +import scala.Tuple2; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.tree.GradientBoostedTrees; +import org.apache.spark.mllib.tree.configuration.BoostingStrategy; +import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel; +import org.apache.spark.mllib.util.MLUtils; +// $example off$ + +public class JavaGradientBoostingClassificationExample { + public static void main(String[] args) { + // $example on$ + SparkConf sparkConf = new SparkConf() + .setAppName("JavaGradientBoostedTreesClassificationExample"); + JavaSparkContext jsc = new JavaSparkContext(sparkConf); + + // Load and parse the data file. + String datapath = "data/mllib/sample_libsvm_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD(); + // Split the data into training and test sets (30% held out for testing) + JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); + JavaRDD trainingData = splits[0]; + JavaRDD testData = splits[1]; + + // Train a GradientBoostedTrees model. 
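+    // BoostingStrategy bundles the boosting parameters (loss, number of iterations, learning
+    // rate) together with the per-tree Strategy used for each weak learner.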
+ // The defaultParams for Classification use LogLoss by default. + BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Classification"); + boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice. + boostingStrategy.getTreeStrategy().setNumClasses(2); + boostingStrategy.getTreeStrategy().setMaxDepth(5); + // Empty categoricalFeaturesInfo indicates all features are continuous. + Map categoricalFeaturesInfo = new HashMap(); + boostingStrategy.treeStrategy().setCategoricalFeaturesInfo(categoricalFeaturesInfo); + + final GradientBoostedTreesModel model = + GradientBoostedTrees.train(trainingData, boostingStrategy); + + // Evaluate model on test instances and compute test error + JavaPairRDD predictionAndLabel = + testData.mapToPair(new PairFunction() { + @Override + public Tuple2 call(LabeledPoint p) { + return new Tuple2(model.predict(p.features()), p.label()); + } + }); + Double testErr = + 1.0 * predictionAndLabel.filter(new Function, Boolean>() { + @Override + public Boolean call(Tuple2 pl) { + return !pl._1().equals(pl._2()); + } + }).count() / testData.count(); + System.out.println("Test Error: " + testErr); + System.out.println("Learned classification GBT model:\n" + model.toDebugString()); + + // Save and load model + model.save(jsc.sc(), "target/tmp/myGradientBoostingClassificationModel"); + GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(jsc.sc(), + "target/tmp/myGradientBoostingClassificationModel"); + // $example off$ + } + +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostingRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostingRegressionExample.java new file mode 100644 index 000000000000..216895b36820 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostingRegressionExample.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.HashMap; +import java.util.Map; + +import scala.Tuple2; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.tree.GradientBoostedTrees; +import org.apache.spark.mllib.tree.configuration.BoostingStrategy; +import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel; +import org.apache.spark.mllib.util.MLUtils; +// $example off$ + +public class JavaGradientBoostingRegressionExample { + public static void main(String[] args) { + // $example on$ + SparkConf sparkConf = new SparkConf() + .setAppName("JavaGradientBoostedTreesRegressionExample"); + JavaSparkContext jsc = new JavaSparkContext(sparkConf); + // Load and parse the data file. + String datapath = "data/mllib/sample_libsvm_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD(); + // Split the data into training and test sets (30% held out for testing) + JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); + JavaRDD trainingData = splits[0]; + JavaRDD testData = splits[1]; + + // Train a GradientBoostedTrees model. + // The defaultParams for Regression use SquaredError by default. + BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Regression"); + boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice. + boostingStrategy.getTreeStrategy().setMaxDepth(5); + // Empty categoricalFeaturesInfo indicates all features are continuous. + Map categoricalFeaturesInfo = new HashMap(); + boostingStrategy.treeStrategy().setCategoricalFeaturesInfo(categoricalFeaturesInfo); + + final GradientBoostedTreesModel model = + GradientBoostedTrees.train(trainingData, boostingStrategy); + + // Evaluate model on test instances and compute test error + JavaPairRDD predictionAndLabel = + testData.mapToPair(new PairFunction() { + @Override + public Tuple2 call(LabeledPoint p) { + return new Tuple2(model.predict(p.features()), p.label()); + } + }); + Double testMSE = + predictionAndLabel.map(new Function, Double>() { + @Override + public Double call(Tuple2 pl) { + Double diff = pl._1() - pl._2(); + return diff * diff; + } + }).reduce(new Function2() { + @Override + public Double call(Double a, Double b) { + return a + b; + } + }) / data.count(); + System.out.println("Test Mean Squared Error: " + testMSE); + System.out.println("Learned regression GBT model:\n" + model.toDebugString()); + + // Save and load model + model.save(jsc.sc(), "target/tmp/myGradientBoostingRegressionModel"); + GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(jsc.sc(), + "target/tmp/myGradientBoostingRegressionModel"); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaIsotonicRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaIsotonicRegressionExample.java new file mode 100644 index 000000000000..37e709b4cbc0 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaIsotonicRegressionExample.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.examples.mllib; + +// $example on$ +import scala.Tuple2; +import scala.Tuple3; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.mllib.regression.IsotonicRegression; +import org.apache.spark.mllib.regression.IsotonicRegressionModel; +// $example off$ +import org.apache.spark.SparkConf; + +public class JavaIsotonicRegressionExample { + public static void main(String[] args) { + SparkConf sparkConf = new SparkConf().setAppName("JavaIsotonicRegressionExample"); + JavaSparkContext jsc = new JavaSparkContext(sparkConf); + // $example on$ + JavaRDD data = jsc.textFile("data/mllib/sample_isotonic_regression_data.txt"); + + // Create label, feature, weight tuples from input data with weight set to default value 1.0. + JavaRDD> parsedData = data.map( + new Function>() { + public Tuple3 call(String line) { + String[] parts = line.split(","); + return new Tuple3<>(new Double(parts[0]), new Double(parts[1]), 1.0); + } + } + ); + + // Split data into training (60%) and test (40%) sets. + JavaRDD>[] splits = parsedData.randomSplit(new double[]{0.6, 0.4}, 11L); + JavaRDD> training = splits[0]; + JavaRDD> test = splits[1]; + + // Create isotonic regression model from training data. + // Isotonic parameter defaults to true so it is only shown for demonstration + final IsotonicRegressionModel model = new IsotonicRegression().setIsotonic(true).run(training); + + // Create tuples of predicted and real labels. + JavaPairRDD predictionAndLabel = test.mapToPair( + new PairFunction, Double, Double>() { + @Override + public Tuple2 call(Tuple3 point) { + Double predictedLabel = model.predict(point._2()); + return new Tuple2(predictedLabel, point._1()); + } + } + ); + + // Calculate mean squared error between predicted and real labels. 
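+    // Wrapping the squared errors in a JavaDoubleRDD gives access to numeric helpers such as
+    // mean(), which is used here to compute the MSE.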
+ Double meanSquaredError = new JavaDoubleRDD(predictionAndLabel.map( + new Function, Object>() { + @Override + public Object call(Tuple2 pl) { + return Math.pow(pl._1() - pl._2(), 2); + } + } + ).rdd()).mean(); + System.out.println("Mean Squared Error = " + meanSquaredError); + + // Save and load model + model.save(jsc.sc(), "target/tmp/myIsotonicRegressionModel"); + IsotonicRegressionModel sameModel = IsotonicRegressionModel.load(jsc.sc(), "target/tmp/myIsotonicRegressionModel"); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLBFGSExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLBFGSExample.java new file mode 100644 index 000000000000..355883f61bd6 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLBFGSExample.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.Arrays; + +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.classification.LogisticRegressionModel; +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.optimization.*; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +// $example off$ + +public class JavaLBFGSExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("L-BFGS Example"); + SparkContext sc = new SparkContext(conf); + + // $example on$ + String path = "data/mllib/sample_libsvm_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); + int numFeatures = data.take(1).get(0).features().size(); + + // Split initial RDD into two... [60% training data, 40% testing data]. + JavaRDD trainingInit = data.sample(false, 0.6, 11L); + JavaRDD test = data.subtract(trainingInit); + + // Append 1 into the training data as intercept. + JavaRDD> training = data.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + return new Tuple2(p.label(), MLUtils.appendBias(p.features())); + } + }); + training.cache(); + + // Run training algorithm to build the model. 
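+ // Optimizer settings: numCorrections bounds the L-BFGS update history, convergenceTol and
+ // maxNumIterations stop the search, and regParam is the L2 penalty applied by SquaredL2Updater.
+ // The initial weight vector has one extra slot for the intercept added by MLUtils.appendBias.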
+ int numCorrections = 10; + double convergenceTol = 1e-4; + int maxNumIterations = 20; + double regParam = 0.1; + Vector initialWeightsWithIntercept = Vectors.dense(new double[numFeatures + 1]); + + Tuple2 result = LBFGS.runLBFGS( + training.rdd(), + new LogisticGradient(), + new SquaredL2Updater(), + numCorrections, + convergenceTol, + maxNumIterations, + regParam, + initialWeightsWithIntercept); + Vector weightsWithIntercept = result._1(); + double[] loss = result._2(); + + final LogisticRegressionModel model = new LogisticRegressionModel( + Vectors.dense(Arrays.copyOf(weightsWithIntercept.toArray(), weightsWithIntercept.size() - 1)), + (weightsWithIntercept.toArray())[weightsWithIntercept.size() - 1]); + + // Clear the default threshold. + model.clearThreshold(); + + // Compute raw scores on the test set. + JavaRDD> scoreAndLabels = test.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + Double score = model.predict(p.features()); + return new Tuple2(score, p.label()); + } + }); + + // Get evaluation metrics. + BinaryClassificationMetrics metrics = + new BinaryClassificationMetrics(scoreAndLabels.rdd()); + double auROC = metrics.areaUnderROC(); + + System.out.println("Loss of each step in training process"); + for (double l : loss) + System.out.println(l); + System.out.println("Area under ROC = " + auROC); + // $example off$ + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java index fd53c81cc497..de8e739ac925 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java @@ -41,8 +41,9 @@ public static void main(String[] args) { public Vector call(String s) { String[] sarray = s.trim().split(" "); double[] values = new double[sarray.length]; - for (int i = 0; i < sarray.length; i++) + for (int i = 0; i < sarray.length; i++) { values[i] = Double.parseDouble(sarray[i]); + } return Vectors.dense(values); } } diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java new file mode 100644 index 000000000000..5ba01e0d0881 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.Arrays; +import java.util.List; + +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.mllib.evaluation.MultilabelMetrics; +import org.apache.spark.rdd.RDD; +import org.apache.spark.SparkConf; +// $example off$ +import org.apache.spark.SparkContext; + +public class JavaMultiLabelClassificationMetricsExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Multilabel Classification Metrics Example"); + JavaSparkContext sc = new JavaSparkContext(conf); + // $example on$ + List> data = Arrays.asList( + new Tuple2(new double[]{0.0, 1.0}, new double[]{0.0, 2.0}), + new Tuple2(new double[]{0.0, 2.0}, new double[]{0.0, 1.0}), + new Tuple2(new double[]{}, new double[]{0.0}), + new Tuple2(new double[]{2.0}, new double[]{2.0}), + new Tuple2(new double[]{2.0, 0.0}, new double[]{2.0, 0.0}), + new Tuple2(new double[]{0.0, 1.0, 2.0}, new double[]{0.0, 1.0}), + new Tuple2(new double[]{1.0}, new double[]{1.0, 2.0}) + ); + JavaRDD> scoreAndLabels = sc.parallelize(data); + + // Instantiate metrics object + MultilabelMetrics metrics = new MultilabelMetrics(scoreAndLabels.rdd()); + + // Summary stats + System.out.format("Recall = %f\n", metrics.recall()); + System.out.format("Precision = %f\n", metrics.precision()); + System.out.format("F1 measure = %f\n", metrics.f1Measure()); + System.out.format("Accuracy = %f\n", metrics.accuracy()); + + // Stats by labels + for (int i = 0; i < metrics.labels().length - 1; i++) { + System.out.format("Class %1.1f precision = %f\n", metrics.labels()[i], metrics.precision( + metrics.labels()[i])); + System.out.format("Class %1.1f recall = %f\n", metrics.labels()[i], metrics.recall( + metrics.labels()[i])); + System.out.format("Class %1.1f F1 score = %f\n", metrics.labels()[i], metrics.f1Measure( + metrics.labels()[i])); + } + + // Micro stats + System.out.format("Micro recall = %f\n", metrics.microRecall()); + System.out.format("Micro precision = %f\n", metrics.microPrecision()); + System.out.format("Micro F1 measure = %f\n", metrics.microF1Measure()); + + // Hamming loss + System.out.format("Hamming loss = %f\n", metrics.hammingLoss()); + + // Subset accuracy + System.out.format("Subset accuracy = %f\n", metrics.subsetAccuracy()); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java new file mode 100644 index 000000000000..5247c9c74861 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.classification.LogisticRegressionModel; +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; +import org.apache.spark.mllib.evaluation.MulticlassMetrics; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.mllib.linalg.Matrix; +// $example off$ +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; + +public class JavaMulticlassClassificationMetricsExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Multi class Classification Metrics Example"); + SparkContext sc = new SparkContext(conf); + // $example on$ + String path = "data/mllib/sample_multiclass_classification_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); + + // Split initial RDD into two... [60% training data, 40% testing data]. + JavaRDD[] splits = data.randomSplit(new double[]{0.6, 0.4}, 11L); + JavaRDD training = splits[0].cache(); + JavaRDD test = splits[1]; + + // Run training algorithm to build the model. + final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() + .setNumClasses(3) + .run(training.rdd()); + + // Compute raw scores on the test set. + JavaRDD> predictionAndLabels = test.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + Double prediction = model.predict(p.features()); + return new Tuple2(prediction, p.label()); + } + } + ); + + // Get evaluation metrics. 
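+ // MulticlassMetrics is built from (prediction, label) pairs and exposes the confusion matrix
+ // as well as the overall, per-label and weighted precision/recall/F-measure printed below.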
+ MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd()); + + // Confusion matrix + Matrix confusion = metrics.confusionMatrix(); + System.out.println("Confusion matrix: \n" + confusion); + + // Overall statistics + System.out.println("Precision = " + metrics.precision()); + System.out.println("Recall = " + metrics.recall()); + System.out.println("F1 Score = " + metrics.fMeasure()); + + // Stats by labels + for (int i = 0; i < metrics.labels().length; i++) { + System.out.format("Class %f precision = %f\n", metrics.labels()[i],metrics.precision( + metrics.labels()[i])); + System.out.format("Class %f recall = %f\n", metrics.labels()[i], metrics.recall( + metrics.labels()[i])); + System.out.format("Class %f F1 score = %f\n", metrics.labels()[i], metrics.fMeasure( + metrics.labels()[i])); + } + + //Weighted stats + System.out.format("Weighted precision = %f\n", metrics.weightedPrecision()); + System.out.format("Weighted recall = %f\n", metrics.weightedRecall()); + System.out.format("Weighted F1 score = %f\n", metrics.weightedFMeasure()); + System.out.format("Weighted false positive rate = %f\n", metrics.weightedFalsePositiveRate()); + + // Save and load model + model.save(sc, "target/tmp/LogisticRegressionModel"); + LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, + "target/tmp/LogisticRegressionModel"); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaNaiveBayesExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaNaiveBayesExample.java new file mode 100644 index 000000000000..e6a5904bd71f --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaNaiveBayesExample.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import scala.Tuple2; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.classification.NaiveBayes; +import org.apache.spark.mllib.classification.NaiveBayesModel; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +// $example off$ +import org.apache.spark.SparkConf; + +public class JavaNaiveBayesExample { + public static void main(String[] args) { + SparkConf sparkConf = new SparkConf().setAppName("JavaNaiveBayesExample"); + JavaSparkContext jsc = new JavaSparkContext(sparkConf); + // $example on$ + String path = "data/mllib/sample_naive_bayes_data.txt"; + JavaRDD inputData = MLUtils.loadLibSVMFile(jsc.sc(), path).toJavaRDD(); + JavaRDD[] tmp = inputData.randomSplit(new double[]{0.6, 0.4}, 12345); + JavaRDD training = tmp[0]; // training set + JavaRDD test = tmp[1]; // test set + final NaiveBayesModel model = NaiveBayes.train(training.rdd(), 1.0); + JavaPairRDD predictionAndLabel = + test.mapToPair(new PairFunction() { + @Override + public Tuple2 call(LabeledPoint p) { + return new Tuple2(model.predict(p.features()), p.label()); + } + }); + double accuracy = predictionAndLabel.filter(new Function, Boolean>() { + @Override + public Boolean call(Tuple2 pl) { + return pl._1().equals(pl._2()); + } + }).count() / (double) test.count(); + + // Save and load model + model.save(jsc.sc(), "target/tmp/myNaiveBayesModel"); + NaiveBayesModel sameModel = NaiveBayesModel.load(jsc.sc(), "target/tmp/myNaiveBayesModel"); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaPrefixSpanExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPrefixSpanExample.java new file mode 100644 index 000000000000..68ec7c1e6ebe --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPrefixSpanExample.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.Arrays; +import java.util.List; +// $example off$ +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +// $example on$ +import org.apache.spark.mllib.fpm.PrefixSpan; +import org.apache.spark.mllib.fpm.PrefixSpanModel; +// $example off$ +import org.apache.spark.SparkConf; + +public class JavaPrefixSpanExample { + + public static void main(String[] args) { + + SparkConf sparkConf = new SparkConf().setAppName("JavaPrefixSpanExample"); + JavaSparkContext sc = new JavaSparkContext(sparkConf); + + // $example on$ + JavaRDD>> sequences = sc.parallelize(Arrays.asList( + Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)), + Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)), + Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5)), + Arrays.asList(Arrays.asList(6)) + ), 2); + PrefixSpan prefixSpan = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5); + PrefixSpanModel model = prefixSpan.run(sequences); + for (PrefixSpan.FreqSequence freqSeq: model.freqSequences().toJavaRDD().collect()) { + System.out.println(freqSeq.javaSequence() + ", " + freqSeq.freq()); + } + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestClassificationExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestClassificationExample.java new file mode 100644 index 000000000000..9219eef1ad2d --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestClassificationExample.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.HashMap; + +import scala.Tuple2; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.tree.RandomForest; +import org.apache.spark.mllib.tree.model.RandomForestModel; +import org.apache.spark.mllib.util.MLUtils; +// $example off$ + +public class JavaRandomForestClassificationExample { + public static void main(String[] args) { + // $example on$ + SparkConf sparkConf = new SparkConf().setAppName("JavaRandomForestClassificationExample"); + JavaSparkContext jsc = new JavaSparkContext(sparkConf); + // Load and parse the data file. 
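+ // MLUtils.loadLibSVMFile parses the LIBSVM text format into an RDD of LabeledPoint;
+ // toJavaRDD() wraps it for use with the Java API.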
+ String datapath = "data/mllib/sample_libsvm_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD(); + // Split the data into training and test sets (30% held out for testing) + JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); + JavaRDD trainingData = splits[0]; + JavaRDD testData = splits[1]; + + // Train a RandomForest model. + // Empty categoricalFeaturesInfo indicates all features are continuous. + Integer numClasses = 2; + HashMap categoricalFeaturesInfo = new HashMap(); + Integer numTrees = 3; // Use more in practice. + String featureSubsetStrategy = "auto"; // Let the algorithm choose. + String impurity = "gini"; + Integer maxDepth = 5; + Integer maxBins = 32; + Integer seed = 12345; + + final RandomForestModel model = RandomForest.trainClassifier(trainingData, numClasses, + categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, + seed); + + // Evaluate model on test instances and compute test error + JavaPairRDD predictionAndLabel = + testData.mapToPair(new PairFunction() { + @Override + public Tuple2 call(LabeledPoint p) { + return new Tuple2(model.predict(p.features()), p.label()); + } + }); + Double testErr = + 1.0 * predictionAndLabel.filter(new Function, Boolean>() { + @Override + public Boolean call(Tuple2 pl) { + return !pl._1().equals(pl._2()); + } + }).count() / testData.count(); + System.out.println("Test Error: " + testErr); + System.out.println("Learned classification forest model:\n" + model.toDebugString()); + + // Save and load model + model.save(jsc.sc(), "target/tmp/myRandomForestClassificationModel"); + RandomForestModel sameModel = RandomForestModel.load(jsc.sc(), + "target/tmp/myRandomForestClassificationModel"); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestExample.java deleted file mode 100644 index 89a4e092a5af..000000000000 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestExample.java +++ /dev/null @@ -1,139 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.examples.mllib; - -import scala.Tuple2; - -import java.util.HashMap; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.RandomForest; -import org.apache.spark.mllib.tree.model.RandomForestModel; -import org.apache.spark.mllib.util.MLUtils; - -public final class JavaRandomForestExample { - - /** - * Note: This example illustrates binary classification. - * For information on multiclass classification, please refer to the JavaDecisionTree.java - * example. - */ - private static void testClassification(JavaRDD trainingData, - JavaRDD testData) { - // Train a RandomForest model. - // Empty categoricalFeaturesInfo indicates all features are continuous. - Integer numClasses = 2; - HashMap categoricalFeaturesInfo = new HashMap(); - Integer numTrees = 3; // Use more in practice. - String featureSubsetStrategy = "auto"; // Let the algorithm choose. - String impurity = "gini"; - Integer maxDepth = 4; - Integer maxBins = 32; - Integer seed = 12345; - - final RandomForestModel model = RandomForest.trainClassifier(trainingData, numClasses, - categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, - seed); - - // Evaluate model on test instances and compute test error - JavaPairRDD predictionAndLabel = - testData.mapToPair(new PairFunction() { - @Override - public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); - Double testErr = - 1.0 * predictionAndLabel.filter(new Function, Boolean>() { - @Override - public Boolean call(Tuple2 pl) { - return !pl._1().equals(pl._2()); - } - }).count() / testData.count(); - System.out.println("Test Error: " + testErr); - System.out.println("Learned classification forest model:\n" + model.toDebugString()); - } - - private static void testRegression(JavaRDD trainingData, - JavaRDD testData) { - // Train a RandomForest model. - // Empty categoricalFeaturesInfo indicates all features are continuous. - HashMap categoricalFeaturesInfo = new HashMap(); - Integer numTrees = 3; // Use more in practice. - String featureSubsetStrategy = "auto"; // Let the algorithm choose. 
- String impurity = "variance"; - Integer maxDepth = 4; - Integer maxBins = 32; - Integer seed = 12345; - - final RandomForestModel model = RandomForest.trainRegressor(trainingData, - categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, - seed); - - // Evaluate model on test instances and compute test error - JavaPairRDD predictionAndLabel = - testData.mapToPair(new PairFunction() { - @Override - public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); - Double testMSE = - predictionAndLabel.map(new Function, Double>() { - @Override - public Double call(Tuple2 pl) { - Double diff = pl._1() - pl._2(); - return diff * diff; - } - }).reduce(new Function2() { - @Override - public Double call(Double a, Double b) { - return a + b; - } - }) / testData.count(); - System.out.println("Test Mean Squared Error: " + testMSE); - System.out.println("Learned regression forest model:\n" + model.toDebugString()); - } - - public static void main(String[] args) { - SparkConf sparkConf = new SparkConf().setAppName("JavaRandomForestExample"); - JavaSparkContext sc = new JavaSparkContext(sparkConf); - - // Load and parse the data file. - String datapath = "data/mllib/sample_libsvm_data.txt"; - JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD(); - // Split the data into training and test sets (30% held out for testing) - JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); - JavaRDD trainingData = splits[0]; - JavaRDD testData = splits[1]; - - System.out.println("\nRunning example of classification using RandomForest\n"); - testClassification(trainingData, testData); - - System.out.println("\nRunning example of regression using RandomForest\n"); - testRegression(trainingData, testData); - sc.stop(); - } -} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestRegressionExample.java new file mode 100644 index 000000000000..4db926a4218f --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestRegressionExample.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.HashMap; +import java.util.Map; + +import scala.Tuple2; + +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.tree.RandomForest; +import org.apache.spark.mllib.tree.model.RandomForestModel; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.SparkConf; +// $example off$ + +public class JavaRandomForestRegressionExample { + public static void main(String[] args) { + // $example on$ + SparkConf sparkConf = new SparkConf().setAppName("JavaRandomForestRegressionExample"); + JavaSparkContext jsc = new JavaSparkContext(sparkConf); + // Load and parse the data file. + String datapath = "data/mllib/sample_libsvm_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD(); + // Split the data into training and test sets (30% held out for testing) + JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); + JavaRDD trainingData = splits[0]; + JavaRDD testData = splits[1]; + + // Set parameters. + // Empty categoricalFeaturesInfo indicates all features are continuous. + Map categoricalFeaturesInfo = new HashMap(); + Integer numTrees = 3; // Use more in practice. + String featureSubsetStrategy = "auto"; // Let the algorithm choose. + String impurity = "variance"; + Integer maxDepth = 4; + Integer maxBins = 32; + Integer seed = 12345; + // Train a RandomForest model. + final RandomForestModel model = RandomForest.trainRegressor(trainingData, + categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed); + + // Evaluate model on test instances and compute test error + JavaPairRDD predictionAndLabel = + testData.mapToPair(new PairFunction() { + @Override + public Tuple2 call(LabeledPoint p) { + return new Tuple2(model.predict(p.features()), p.label()); + } + }); + Double testMSE = + predictionAndLabel.map(new Function, Double>() { + @Override + public Double call(Tuple2 pl) { + Double diff = pl._1() - pl._2(); + return diff * diff; + } + }).reduce(new Function2() { + @Override + public Double call(Double a, Double b) { + return a + b; + } + }) / testData.count(); + System.out.println("Test Mean Squared Error: " + testMSE); + System.out.println("Learned regression forest model:\n" + model.toDebugString()); + + // Save and load model + model.save(jsc.sc(), "target/tmp/myRandomForestRegressionModel"); + RandomForestModel sameModel = RandomForestModel.load(jsc.sc(), + "target/tmp/myRandomForestRegressionModel"); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java new file mode 100644 index 000000000000..47ab3fc35824 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.*; + +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.evaluation.RegressionMetrics; +import org.apache.spark.mllib.evaluation.RankingMetrics; +import org.apache.spark.mllib.recommendation.ALS; +import org.apache.spark.mllib.recommendation.MatrixFactorizationModel; +import org.apache.spark.mllib.recommendation.Rating; +// $example off$ +import org.apache.spark.SparkConf; + +public class JavaRankingMetricsExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Java Ranking Metrics Example"); + JavaSparkContext sc = new JavaSparkContext(conf); + // $example on$ + String path = "data/mllib/sample_movielens_data.txt"; + JavaRDD data = sc.textFile(path); + JavaRDD ratings = data.map( + new Function() { + public Rating call(String line) { + String[] parts = line.split("::"); + return new Rating(Integer.parseInt(parts[0]), Integer.parseInt(parts[1]), Double + .parseDouble(parts[2]) - 2.5); + } + } + ); + ratings.cache(); + + // Train an ALS model + final MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), 10, 10, 0.01); + + // Get top 10 recommendations for every user and scale ratings from 0 to 1 + JavaRDD> userRecs = model.recommendProductsForUsers(10).toJavaRDD(); + JavaRDD> userRecsScaled = userRecs.map( + new Function, Tuple2>() { + public Tuple2 call(Tuple2 t) { + Rating[] scaledRatings = new Rating[t._2().length]; + for (int i = 0; i < scaledRatings.length; i++) { + double newRating = Math.max(Math.min(t._2()[i].rating(), 1.0), 0.0); + scaledRatings[i] = new Rating(t._2()[i].user(), t._2()[i].product(), newRating); + } + return new Tuple2(t._1(), scaledRatings); + } + } + ); + JavaPairRDD userRecommended = JavaPairRDD.fromJavaRDD(userRecsScaled); + + // Map ratings to 1 or 0, 1 indicating a movie that should be recommended + JavaRDD binarizedRatings = ratings.map( + new Function() { + public Rating call(Rating r) { + double binaryRating; + if (r.rating() > 0.0) { + binaryRating = 1.0; + } else { + binaryRating = 0.0; + } + return new Rating(r.user(), r.product(), binaryRating); + } + } + ); + + // Group ratings by common user + JavaPairRDD> userMovies = binarizedRatings.groupBy( + new Function() { + public Object call(Rating r) { + return r.user(); + } + } + ); + + // Get true relevant documents from all user ratings + JavaPairRDD> userMoviesList = userMovies.mapValues( + new Function, List>() { + public List call(Iterable docs) { + List products = new ArrayList(); + for (Rating r : docs) { + if (r.rating() > 0.0) { + products.add(r.product()); + } + } + return products; + } + } + ); + + // Extract the product id from each recommendation + JavaPairRDD> userRecommendedList = userRecommended.mapValues( + new Function>() { + public List call(Rating[] docs) { + List products = new 
ArrayList(); + for (Rating r : docs) { + products.add(r.product()); + } + return products; + } + } + ); + JavaRDD, List>> relevantDocs = userMoviesList.join( + userRecommendedList).values(); + + // Instantiate the metrics object + RankingMetrics metrics = RankingMetrics.of(relevantDocs); + + // Precision and NDCG at k + Integer[] kVector = {1, 3, 5}; + for (Integer k : kVector) { + System.out.format("Precision at %d = %f\n", k, metrics.precisionAt(k)); + System.out.format("NDCG at %d = %f\n", k, metrics.ndcgAt(k)); + } + + // Mean average precision + System.out.format("Mean average precision = %f\n", metrics.meanAveragePrecision()); + + // Evaluate the model using numerical ratings and regression metrics + JavaRDD> userProducts = ratings.map( + new Function>() { + public Tuple2 call(Rating r) { + return new Tuple2(r.user(), r.product()); + } + } + ); + JavaPairRDD, Object> predictions = JavaPairRDD.fromJavaRDD( + model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map( + new Function, Object>>() { + public Tuple2, Object> call(Rating r) { + return new Tuple2, Object>( + new Tuple2(r.user(), r.product()), r.rating()); + } + } + )); + JavaRDD> ratesAndPreds = + JavaPairRDD.fromJavaRDD(ratings.map( + new Function, Object>>() { + public Tuple2, Object> call(Rating r) { + return new Tuple2, Object>( + new Tuple2(r.user(), r.product()), r.rating()); + } + } + )).join(predictions).values(); + + // Create regression metrics object + RegressionMetrics regressionMetrics = new RegressionMetrics(ratesAndPreds.rdd()); + + // Root mean squared error + System.out.format("RMSE = %f\n", regressionMetrics.rootMeanSquaredError()); + + // R-squared + System.out.format("R-squared = %f\n", regressionMetrics.r2()); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java new file mode 100644 index 000000000000..c179e7578cdf --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.recommendation.ALS; +import org.apache.spark.mllib.recommendation.MatrixFactorizationModel; +import org.apache.spark.mllib.recommendation.Rating; +import org.apache.spark.SparkConf; +// $example off$ + +public class JavaRecommendationExample { + public static void main(String[] args) { + // $example on$ + SparkConf conf = new SparkConf().setAppName("Java Collaborative Filtering Example"); + JavaSparkContext jsc = new JavaSparkContext(conf); + + // Load and parse the data + String path = "data/mllib/als/test.data"; + JavaRDD data = jsc.textFile(path); + JavaRDD ratings = data.map( + new Function() { + public Rating call(String s) { + String[] sarray = s.split(","); + return new Rating(Integer.parseInt(sarray[0]), Integer.parseInt(sarray[1]), + Double.parseDouble(sarray[2])); + } + } + ); + + // Build the recommendation model using ALS + int rank = 10; + int numIterations = 10; + MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), rank, numIterations, 0.01); + + // Evaluate the model on rating data + JavaRDD> userProducts = ratings.map( + new Function>() { + public Tuple2 call(Rating r) { + return new Tuple2(r.user(), r.product()); + } + } + ); + JavaPairRDD, Double> predictions = JavaPairRDD.fromJavaRDD( + model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map( + new Function, Double>>() { + public Tuple2, Double> call(Rating r){ + return new Tuple2, Double>( + new Tuple2(r.user(), r.product()), r.rating()); + } + } + )); + JavaRDD> ratesAndPreds = + JavaPairRDD.fromJavaRDD(ratings.map( + new Function, Double>>() { + public Tuple2, Double> call(Rating r){ + return new Tuple2, Double>( + new Tuple2(r.user(), r.product()), r.rating()); + } + } + )).join(predictions).values(); + double MSE = JavaDoubleRDD.fromRDD(ratesAndPreds.map( + new Function, Object>() { + public Object call(Tuple2 pair) { + Double err = pair._1() - pair._2(); + return err * err; + } + } + ).rdd()).mean(); + System.out.println("Mean Squared Error = " + MSE); + + // Save and load model + model.save(jsc.sc(), "target/tmp/myCollaborativeFilter"); + MatrixFactorizationModel sameModel = MatrixFactorizationModel.load(jsc.sc(), + "target/tmp/myCollaborativeFilter"); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java new file mode 100644 index 000000000000..4e89dd0c37c5 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.regression.LinearRegressionModel; +import org.apache.spark.mllib.regression.LinearRegressionWithSGD; +import org.apache.spark.mllib.evaluation.RegressionMetrics; +import org.apache.spark.SparkConf; +// $example off$ + +public class JavaRegressionMetricsExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Java Regression Metrics Example"); + JavaSparkContext sc = new JavaSparkContext(conf); + // $example on$ + // Load and parse the data + String path = "data/mllib/sample_linear_regression_data.txt"; + JavaRDD data = sc.textFile(path); + JavaRDD parsedData = data.map( + new Function() { + public LabeledPoint call(String line) { + String[] parts = line.split(" "); + double[] v = new double[parts.length - 1]; + for (int i = 1; i < parts.length - 1; i++) { + v[i - 1] = Double.parseDouble(parts[i].split(":")[1]); + } + return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v)); + } + } + ); + parsedData.cache(); + + // Building the model + int numIterations = 100; + final LinearRegressionModel model = LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), + numIterations); + + // Evaluate model on training examples and compute training error + JavaRDD> valuesAndPreds = parsedData.map( + new Function>() { + public Tuple2 call(LabeledPoint point) { + double prediction = model.predict(point.features()); + return new Tuple2(prediction, point.label()); + } + } + ); + + // Instantiate metrics object + RegressionMetrics metrics = new RegressionMetrics(valuesAndPreds.rdd()); + + // Squared error + System.out.format("MSE = %f\n", metrics.meanSquaredError()); + System.out.format("RMSE = %f\n", metrics.rootMeanSquaredError()); + + // R-squared + System.out.format("R Squared = %f\n", metrics.r2()); + + // Mean absolute error + System.out.format("MAE = %f\n", metrics.meanAbsoluteError()); + + // Explained variance + System.out.format("Explained Variance = %f\n", metrics.explainedVariance()); + + // Save and load model + model.save(sc.sc(), "target/tmp/LogisticRegressionModel"); + LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(), + "target/tmp/LogisticRegressionModel"); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaSimpleFPGrowth.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaSimpleFPGrowth.java new file mode 100644 index 000000000000..72edaca5e95b --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaSimpleFPGrowth.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +// $example off$ +import org.apache.spark.api.java.function.Function; +// $example on$ +import org.apache.spark.mllib.fpm.AssociationRules; +import org.apache.spark.mllib.fpm.FPGrowth; +import org.apache.spark.mllib.fpm.FPGrowthModel; +// $example off$ + +import org.apache.spark.SparkConf; + +public class JavaSimpleFPGrowth { + + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("FP-growth Example"); + JavaSparkContext sc = new JavaSparkContext(conf); + + // $example on$ + JavaRDD data = sc.textFile("data/mllib/sample_fpgrowth.txt"); + + JavaRDD> transactions = data.map( + new Function>() { + public List call(String line) { + String[] parts = line.split(" "); + return Arrays.asList(parts); + } + } + ); + + FPGrowth fpg = new FPGrowth() + .setMinSupport(0.2) + .setNumPartitions(10); + FPGrowthModel model = fpg.run(transactions); + + for (FPGrowth.FreqItemset itemset: model.freqItemsets().toJavaRDD().collect()) { + System.out.println("[" + itemset.javaItems() + "], " + itemset.freq()); + } + + double minConfidence = 0.8; + for (AssociationRules.Rule rule + : model.generateAssociationRules(minConfidence).toJavaRDD().collect()) { + System.out.println( + rule.javaAntecedent() + " => " + rule.javaConsequent() + ", " + rule.confidence()); + } + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java index 99df259b4e8e..4b50fbf59f80 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java @@ -18,6 +18,7 @@ package org.apache.spark.examples.streaming; import com.google.common.collect.Lists; +import com.google.common.io.Closeables; import org.apache.spark.SparkConf; import org.apache.spark.api.java.function.FlatMapFunction; @@ -121,23 +122,23 @@ public void onStop() { /** Create a socket connection and receive data until receiver is stopped */ private void receive() { - Socket socket = null; - String userInput = null; - try { - // connect to the server - socket = new Socket(host, port); - - BufferedReader reader = new BufferedReader(new InputStreamReader(socket.getInputStream())); - - // Until stopped or connection broken continue reading - while (!isStopped() && (userInput = reader.readLine()) != null) { - System.out.println("Received data '" + userInput + "'"); - store(userInput); + Socket socket = null; + BufferedReader reader = null; + String userInput = null; + try { + // connect to the server + socket = new Socket(host, port); + reader = new BufferedReader(new InputStreamReader(socket.getInputStream())); + // Until stopped or connection broken continue reading + while (!isStopped() && (userInput = reader.readLine()) != null) { + System.out.println("Received 
data '" + userInput + "'"); + store(userInput); + } + } finally { + Closeables.close(reader, /* swallowIOException = */ true); + Closeables.close(socket, /* swallowIOException = */ true); } - reader.close(); - socket.close(); - // Restart in an attempt to connect again when server is active again restart("Trying to connect again"); } catch(ConnectException ce) { diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java index bab9f2478e77..f9a5e7f69ffe 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java @@ -35,12 +35,12 @@ /** * Consumes messages from one or more topics in Kafka and does wordcount. - * Usage: DirectKafkaWordCount + * Usage: JavaDirectKafkaWordCount * is a list of one or more Kafka brokers * is a list of one or more kafka topics to consume from * * Example: - * $ bin/run-example streaming.KafkaWordCount broker1-host:port,broker2-host:port topic1,topic2 + * $ bin/run-example streaming.JavaDirectKafkaWordCount broker1-host:port,broker2-host:port topic1,topic2 */ public final class JavaDirectKafkaWordCount { @@ -48,7 +48,7 @@ public final class JavaDirectKafkaWordCount { public static void main(String[] args) { if (args.length < 2) { - System.err.println("Usage: DirectKafkaWordCount \n" + + System.err.println("Usage: JavaDirectKafkaWordCount \n" + " is a list of one or more Kafka brokers\n" + " is a list of one or more kafka topics to consume from\n\n"); System.exit(1); @@ -59,7 +59,7 @@ public static void main(String[] args) { String brokers = args[0]; String topics = args[1]; - // Create context with 2 second batch interval + // Create context with a 2 seconds batch interval SparkConf sparkConf = new SparkConf().setAppName("JavaDirectKafkaWordCount"); JavaStreamingContext jssc = new JavaStreamingContext(sparkConf, Durations.seconds(2)); diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java index 16ae9a3319ee..337f8ffb5bfb 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java @@ -66,7 +66,7 @@ public static void main(String[] args) { StreamingExamples.setStreamingLogLevels(); SparkConf sparkConf = new SparkConf().setAppName("JavaKafkaWordCount"); - // Create the context with a 1 second batch size + // Create the context with 2 seconds batch size JavaStreamingContext jssc = new JavaStreamingContext(sparkConf, new Duration(2000)); int numThreads = Integer.parseInt(args[3]); diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java index 46562ddbbcb5..3515d7be45d3 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java @@ -112,8 +112,8 @@ public JavaRecord call(String word) { /** Lazily instantiated singleton instance of SQLContext */ class JavaSQLContextSingleton { - static private transient SQLContext instance = null; - static public SQLContext getInstance(SparkContext sparkContext) 
{ + private static transient SQLContext instance = null; + public static SQLContext getInstance(SparkContext sparkContext) { if (instance == null) { instance = new SQLContext(sparkContext); } diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java index dbf2ef02d7b7..14997c64d505 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java @@ -26,18 +26,15 @@ import com.google.common.base.Optional; import com.google.common.collect.Lists; -import org.apache.spark.HashPartitioner; import org.apache.spark.SparkConf; +import org.apache.spark.api.java.function.*; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.StorageLevels; -import org.apache.spark.api.java.function.FlatMapFunction; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.streaming.Durations; -import org.apache.spark.streaming.api.java.JavaDStream; -import org.apache.spark.streaming.api.java.JavaPairDStream; -import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; -import org.apache.spark.streaming.api.java.JavaStreamingContext; +import org.apache.spark.streaming.State; +import org.apache.spark.streaming.StateSpec; +import org.apache.spark.streaming.Time; +import org.apache.spark.streaming.api.java.*; /** * Counts words cumulatively in UTF8 encoded, '\n' delimited text received from the network every @@ -45,7 +42,7 @@ * Usage: JavaStatefulNetworkWordCount * and describe the TCP server that Spark Streaming would connect to receive * data. - *
<p>
+ * <p>
    * To run this on your local machine, you need to first run a Netcat server * `$ nc -lk 9999` * and then run the example @@ -63,29 +60,16 @@ public static void main(String[] args) { StreamingExamples.setStreamingLogLevels(); - // Update the cumulative count function - final Function2, Optional, Optional> updateFunction = - new Function2, Optional, Optional>() { - @Override - public Optional call(List values, Optional state) { - Integer newSum = state.or(0); - for (Integer value : values) { - newSum += value; - } - return Optional.of(newSum); - } - }; - // Create the context with a 1 second batch size SparkConf sparkConf = new SparkConf().setAppName("JavaStatefulNetworkWordCount"); JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, Durations.seconds(1)); ssc.checkpoint("."); - // Initial RDD input to updateStateByKey + // Initial state RDD input to mapWithState @SuppressWarnings("unchecked") List> tuples = Arrays.asList(new Tuple2("hello", 1), new Tuple2("world", 1)); - JavaPairRDD initialRDD = ssc.sc().parallelizePairs(tuples); + JavaPairRDD initialRDD = ssc.sparkContext().parallelizePairs(tuples); JavaReceiverInputDStream lines = ssc.socketTextStream( args[0], Integer.parseInt(args[1]), StorageLevels.MEMORY_AND_DISK_SER_2); @@ -105,9 +89,22 @@ public Tuple2 call(String s) { } }); - // This will give a Dstream made of state (which is the cumulative count of the words) - JavaPairDStream stateDstream = wordsDstream.updateStateByKey(updateFunction, - new HashPartitioner(ssc.sc().defaultParallelism()), initialRDD); + // Update the cumulative count function + final Function3, State, Tuple2> mappingFunc = + new Function3, State, Tuple2>() { + + @Override + public Tuple2 call(String word, Optional one, State state) { + int sum = one.or(0) + (state.exists() ? state.get() : 0); + Tuple2 output = new Tuple2(word, sum); + state.update(sum); + return output; + } + }; + + // DStream made of get cumulative counts that get updated in every batch + JavaMapWithStateDStream> stateDstream = + wordsDstream.mapWithState(StateSpec.function(mappingFunc).initialState(initialRDD)); stateDstream.print(); ssc.start(); diff --git a/examples/src/main/python/als.py b/examples/src/main/python/als.py index 1c3a787bd0e9..205ca02962be 100755 --- a/examples/src/main/python/als.py +++ b/examples/src/main/python/als.py @@ -36,7 +36,7 @@ def rmse(R, ms, us): diff = R - ms * us.T - return np.sqrt(np.sum(np.power(diff, 2)) / M * U) + return np.sqrt(np.sum(np.power(diff, 2)) / (M * U)) def update(i, vec, mat, ratings): diff --git a/examples/src/main/python/ml/aft_survival_regression.py b/examples/src/main/python/ml/aft_survival_regression.py new file mode 100644 index 000000000000..0ee01fd8258d --- /dev/null +++ b/examples/src/main/python/ml/aft_survival_regression.py @@ -0,0 +1,51 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.regression import AFTSurvivalRegression +from pyspark.mllib.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="AFTSurvivalRegressionExample") + sqlContext = SQLContext(sc) + + # $example on$ + training = sqlContext.createDataFrame([ + (1.218, 1.0, Vectors.dense(1.560, -0.605)), + (2.949, 0.0, Vectors.dense(0.346, 2.158)), + (3.627, 0.0, Vectors.dense(1.380, 0.231)), + (0.273, 1.0, Vectors.dense(0.520, 1.151)), + (4.199, 0.0, Vectors.dense(0.795, -0.226))], ["label", "censor", "features"]) + quantileProbabilities = [0.3, 0.6] + aft = AFTSurvivalRegression(quantileProbabilities=quantileProbabilities, + quantilesCol="quantiles") + + model = aft.fit(training) + + # Print the coefficients, intercept and scale parameter for AFT survival regression + print("Coefficients: " + str(model.coefficients)) + print("Intercept: " + str(model.intercept)) + print("Scale: " + str(model.scale)) + model.transform(training).show(truncate=False) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/binarizer_example.py b/examples/src/main/python/ml/binarizer_example.py new file mode 100644 index 000000000000..317cfa638a5a --- /dev/null +++ b/examples/src/main/python/ml/binarizer_example.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import Binarizer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="BinarizerExample") + sqlContext = SQLContext(sc) + + # $example on$ + continuousDataFrame = sqlContext.createDataFrame([ + (0, 0.1), + (1, 0.8), + (2, 0.2) + ], ["label", "feature"]) + binarizer = Binarizer(threshold=0.5, inputCol="feature", outputCol="binarized_feature") + binarizedDataFrame = binarizer.transform(continuousDataFrame) + binarizedFeatures = binarizedDataFrame.select("binarized_feature") + for binarized_feature, in binarizedFeatures.collect(): + print(binarized_feature) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/bucketizer_example.py b/examples/src/main/python/ml/bucketizer_example.py new file mode 100644 index 000000000000..4304255f350d --- /dev/null +++ b/examples/src/main/python/ml/bucketizer_example.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import Bucketizer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="BucketizerExample") + sqlContext = SQLContext(sc) + + # $example on$ + splits = [-float("inf"), -0.5, 0.0, 0.5, float("inf")] + + data = [(-0.5,), (-0.3,), (0.0,), (0.2,)] + dataFrame = sqlContext.createDataFrame(data, ["features"]) + + bucketizer = Bucketizer(splits=splits, inputCol="features", outputCol="bucketedFeatures") + + # Transform original data into its bucket index. + bucketedData = bucketizer.transform(dataFrame) + bucketedData.show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/mllib/dataset_example.py b/examples/src/main/python/ml/dataframe_example.py similarity index 53% rename from examples/src/main/python/mllib/dataset_example.py rename to examples/src/main/python/ml/dataframe_example.py index e23ecc0c5d30..d2644ca33565 100644 --- a/examples/src/main/python/mllib/dataset_example.py +++ b/examples/src/main/python/ml/dataframe_example.py @@ -16,8 +16,8 @@ # """ -An example of how to use DataFrame as a dataset for ML. Run with:: - bin/spark-submit examples/src/main/python/mllib/dataset_example.py +An example of how to use DataFrame for ML. Run with:: + bin/spark-submit examples/src/main/python/ml/dataframe_example.py """ from __future__ import print_function @@ -28,36 +28,48 @@ from pyspark import SparkContext from pyspark.sql import SQLContext -from pyspark.mllib.util import MLUtils from pyspark.mllib.stat import Statistics - -def summarize(dataset): - print("schema: %s" % dataset.schema().json()) - labels = dataset.map(lambda r: r.label) - print("label average: %f" % labels.mean()) - features = dataset.map(lambda r: r.features) - summary = Statistics.colStats(features) - print("features average: %r" % summary.mean()) - if __name__ == "__main__": if len(sys.argv) > 2: - print("Usage: dataset_example.py ", file=sys.stderr) + print("Usage: dataframe_example.py ", file=sys.stderr) exit(-1) - sc = SparkContext(appName="DatasetExample") + sc = SparkContext(appName="DataFrameExample") sqlContext = SQLContext(sc) if len(sys.argv) == 2: input = sys.argv[1] else: input = "data/mllib/sample_libsvm_data.txt" - points = MLUtils.loadLibSVMFile(sc, input) - dataset0 = sqlContext.inferSchema(points).setName("dataset0").cache() - summarize(dataset0) + + # Load input data + print("Loading LIBSVM file with UDT from " + input + ".") + df = sqlContext.read.format("libsvm").load(input).cache() + print("Schema from LIBSVM:") + df.printSchema() + print("Loaded training data as a DataFrame with " + + str(df.count()) + " records.") + + # Show statistical summary of labels. 
+ labelSummary = df.describe("label") + labelSummary.show() + + # Convert features column to an RDD of vectors. + features = df.select("features").map(lambda r: r.features) + summary = Statistics.colStats(features) + print("Selected features column with average values:\n" + + str(summary.mean())) + + # Save the records in a parquet file. tempdir = tempfile.NamedTemporaryFile(delete=False).name os.unlink(tempdir) - print("Save dataset as a Parquet file to %s." % tempdir) - dataset0.saveAsParquetFile(tempdir) - print("Load it back and summarize it again.") - dataset1 = sqlContext.parquetFile(tempdir).setName("dataset1").cache() - summarize(dataset1) + print("Saving to " + tempdir + " as Parquet file.") + df.write.parquet(tempdir) + + # Load the records back. + print("Loading Parquet file with UDT from " + tempdir) + newDF = sqlContext.read.parquet(tempdir) + print("Schema from Parquet:") + newDF.printSchema() shutil.rmtree(tempdir) + + sc.stop() diff --git a/examples/src/main/python/ml/decision_tree_classification_example.py b/examples/src/main/python/ml/decision_tree_classification_example.py new file mode 100644 index 000000000000..8cda56dbb9bd --- /dev/null +++ b/examples/src/main/python/ml/decision_tree_classification_example.py @@ -0,0 +1,76 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Decision Tree Classification Example. +""" +from __future__ import print_function + +import sys + +# $example on$ +from pyspark import SparkContext, SQLContext +from pyspark.ml import Pipeline +from pyspark.ml.classification import DecisionTreeClassifier +from pyspark.ml.feature import StringIndexer, VectorIndexer +from pyspark.ml.evaluation import MulticlassClassificationEvaluator +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="decision_tree_classification_example") + sqlContext = SQLContext(sc) + + # $example on$ + # Load the data stored in LIBSVM format as a DataFrame. + data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + # Index labels, adding metadata to the label column. + # Fit on whole dataset to include all labels in index. + labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data) + # Automatically identify categorical features, and index them. + # We specify maxCategories so features with > 4 distinct values are treated as continuous. + featureIndexer =\ + VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) + + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a DecisionTree model. 
+ dt = DecisionTreeClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures") + + # Chain indexers and tree in a Pipeline + pipeline = Pipeline(stages=[labelIndexer, featureIndexer, dt]) + + # Train model. This also runs the indexers. + model = pipeline.fit(trainingData) + + # Make predictions. + predictions = model.transform(testData) + + # Select example rows to display. + predictions.select("prediction", "indexedLabel", "features").show(5) + + # Select (prediction, true label) and compute test error + evaluator = MulticlassClassificationEvaluator( + labelCol="indexedLabel", predictionCol="prediction", metricName="precision") + accuracy = evaluator.evaluate(predictions) + print("Test Error = %g " % (1.0 - accuracy)) + + treeModel = model.stages[2] + # summary only + print(treeModel) + # $example off$ diff --git a/examples/src/main/python/ml/decision_tree_regression_example.py b/examples/src/main/python/ml/decision_tree_regression_example.py new file mode 100644 index 000000000000..439e39894749 --- /dev/null +++ b/examples/src/main/python/ml/decision_tree_regression_example.py @@ -0,0 +1,73 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Decision Tree Regression Example. +""" +from __future__ import print_function + +import sys + +from pyspark import SparkContext, SQLContext +# $example on$ +from pyspark.ml import Pipeline +from pyspark.ml.regression import DecisionTreeRegressor +from pyspark.ml.feature import VectorIndexer +from pyspark.ml.evaluation import RegressionEvaluator +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="decision_tree_classification_example") + sqlContext = SQLContext(sc) + + # $example on$ + # Load the data stored in LIBSVM format as a DataFrame. + data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + # Automatically identify categorical features, and index them. + # We specify maxCategories so features with > 4 distinct values are treated as continuous. + featureIndexer =\ + VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) + + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a DecisionTree model. + dt = DecisionTreeRegressor(featuresCol="indexedFeatures") + + # Chain indexer and tree in a Pipeline + pipeline = Pipeline(stages=[featureIndexer, dt]) + + # Train model. This also runs the indexer. + model = pipeline.fit(trainingData) + + # Make predictions. + predictions = model.transform(testData) + + # Select example rows to display. 
+ predictions.select("prediction", "label", "features").show(5) + + # Select (prediction, true label) and compute test error + evaluator = RegressionEvaluator( + labelCol="label", predictionCol="prediction", metricName="rmse") + rmse = evaluator.evaluate(predictions) + print("Root Mean Squared Error (RMSE) on test data = %g" % rmse) + + treeModel = model.stages[1] + # summary only + print(treeModel) + # $example off$ diff --git a/examples/src/main/python/ml/elementwise_product_example.py b/examples/src/main/python/ml/elementwise_product_example.py new file mode 100644 index 000000000000..c85cb0d89543 --- /dev/null +++ b/examples/src/main/python/ml/elementwise_product_example.py @@ -0,0 +1,39 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import ElementwiseProduct +from pyspark.mllib.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="ElementwiseProductExample") + sqlContext = SQLContext(sc) + + # $example on$ + data = [(Vectors.dense([1.0, 2.0, 3.0]),), (Vectors.dense([4.0, 5.0, 6.0]),)] + df = sqlContext.createDataFrame(data, ["vector"]) + transformer = ElementwiseProduct(scalingVec=Vectors.dense([0.0, 1.0, 2.0]), + inputCol="vector", outputCol="transformedVector") + transformer.transform(df).show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/gradient_boosted_tree_classifier_example.py b/examples/src/main/python/ml/gradient_boosted_tree_classifier_example.py new file mode 100644 index 000000000000..028497651fbf --- /dev/null +++ b/examples/src/main/python/ml/gradient_boosted_tree_classifier_example.py @@ -0,0 +1,77 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Gradient Boosted Tree Classifier Example. 
+""" +from __future__ import print_function + +import sys + +from pyspark import SparkContext, SQLContext +# $example on$ +from pyspark.ml import Pipeline +from pyspark.ml.classification import GBTClassifier +from pyspark.ml.feature import StringIndexer, VectorIndexer +from pyspark.ml.evaluation import MulticlassClassificationEvaluator +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="gradient_boosted_tree_classifier_example") + sqlContext = SQLContext(sc) + + # $example on$ + # Load and parse the data file, converting it to a DataFrame. + data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + # Index labels, adding metadata to the label column. + # Fit on whole dataset to include all labels in index. + labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data) + # Automatically identify categorical features, and index them. + # Set maxCategories so features with > 4 distinct values are treated as continuous. + featureIndexer =\ + VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) + + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a GBT model. + gbt = GBTClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures", maxIter=10) + + # Chain indexers and GBT in a Pipeline + pipeline = Pipeline(stages=[labelIndexer, featureIndexer, gbt]) + + # Train model. This also runs the indexers. + model = pipeline.fit(trainingData) + + # Make predictions. + predictions = model.transform(testData) + + # Select example rows to display. + predictions.select("prediction", "indexedLabel", "features").show(5) + + # Select (prediction, true label) and compute test error + evaluator = MulticlassClassificationEvaluator( + labelCol="indexedLabel", predictionCol="prediction", metricName="precision") + accuracy = evaluator.evaluate(predictions) + print("Test Error = %g" % (1.0 - accuracy)) + + gbtModel = model.stages[2] + print(gbtModel) # summary only + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/gradient_boosted_tree_regressor_example.py b/examples/src/main/python/ml/gradient_boosted_tree_regressor_example.py new file mode 100644 index 000000000000..4246e133a903 --- /dev/null +++ b/examples/src/main/python/ml/gradient_boosted_tree_regressor_example.py @@ -0,0 +1,74 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Gradient Boosted Tree Regressor Example. 
+""" +from __future__ import print_function + +import sys + +from pyspark import SparkContext, SQLContext +# $example on$ +from pyspark.ml import Pipeline +from pyspark.ml.regression import GBTRegressor +from pyspark.ml.feature import VectorIndexer +from pyspark.ml.evaluation import RegressionEvaluator +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="gradient_boosted_tree_regressor_example") + sqlContext = SQLContext(sc) + + # $example on$ + # Load and parse the data file, converting it to a DataFrame. + data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + # Automatically identify categorical features, and index them. + # Set maxCategories so features with > 4 distinct values are treated as continuous. + featureIndexer =\ + VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) + + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a GBT model. + gbt = GBTRegressor(featuresCol="indexedFeatures", maxIter=10) + + # Chain indexer and GBT in a Pipeline + pipeline = Pipeline(stages=[featureIndexer, gbt]) + + # Train model. This also runs the indexer. + model = pipeline.fit(trainingData) + + # Make predictions. + predictions = model.transform(testData) + + # Select example rows to display. + predictions.select("prediction", "label", "features").show(5) + + # Select (prediction, true label) and compute test error + evaluator = RegressionEvaluator( + labelCol="label", predictionCol="prediction", metricName="rmse") + rmse = evaluator.evaluate(predictions) + print("Root Mean Squared Error (RMSE) on test data = %g" % rmse) + + gbtModel = model.stages[1] + print(gbtModel) # summary only + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/gradient_boosted_trees.py b/examples/src/main/python/ml/gradient_boosted_trees.py deleted file mode 100644 index 6446f0fe5eea..000000000000 --- a/examples/src/main/python/ml/gradient_boosted_trees.py +++ /dev/null @@ -1,83 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -import sys - -from pyspark import SparkContext -from pyspark.ml.classification import GBTClassifier -from pyspark.ml.feature import StringIndexer -from pyspark.ml.regression import GBTRegressor -from pyspark.mllib.evaluation import BinaryClassificationMetrics, RegressionMetrics -from pyspark.mllib.util import MLUtils -from pyspark.sql import Row, SQLContext - -""" -A simple example demonstrating a Gradient Boosted Trees Classification/Regression Pipeline. 
-Note: GBTClassifier only supports binary classification currently -Run with: - bin/spark-submit examples/src/main/python/ml/gradient_boosted_trees.py -""" - - -def testClassification(train, test): - # Train a GradientBoostedTrees model. - - rf = GBTClassifier(maxIter=30, maxDepth=4, labelCol="indexedLabel") - - model = rf.fit(train) - predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ - .map(lambda x: (x.prediction, x.indexedLabel)) - - metrics = BinaryClassificationMetrics(predictionAndLabels) - print("AUC %.3f" % metrics.areaUnderROC) - - -def testRegression(train, test): - # Train a GradientBoostedTrees model. - - rf = GBTRegressor(maxIter=30, maxDepth=4, labelCol="indexedLabel") - - model = rf.fit(train) - predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ - .map(lambda x: (x.prediction, x.indexedLabel)) - - metrics = RegressionMetrics(predictionAndLabels) - print("rmse %.3f" % metrics.rootMeanSquaredError) - print("r2 %.3f" % metrics.r2) - print("mae %.3f" % metrics.meanAbsoluteError) - - -if __name__ == "__main__": - if len(sys.argv) > 1: - print("Usage: gradient_boosted_trees", file=sys.stderr) - exit(1) - sc = SparkContext(appName="PythonGBTExample") - sqlContext = SQLContext(sc) - - # Load and parse the data file into a dataframe. - df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() - - # Map labels into an indexed column of labels in [0, numLabels) - stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") - si_model = stringIndexer.fit(df) - td = si_model.transform(df) - [train, test] = td.randomSplit([0.7, 0.3]) - testClassification(train, test) - testRegression(train, test) - sc.stop() diff --git a/examples/src/main/python/ml/index_to_string_example.py b/examples/src/main/python/ml/index_to_string_example.py new file mode 100644 index 000000000000..fb0ba2950bbd --- /dev/null +++ b/examples/src/main/python/ml/index_to_string_example.py @@ -0,0 +1,45 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.ml.feature import IndexToString, StringIndexer +# $example off$ +from pyspark.sql import SQLContext + +if __name__ == "__main__": + sc = SparkContext(appName="IndexToStringExample") + sqlContext = SQLContext(sc) + + # $example on$ + df = sqlContext.createDataFrame( + [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")], + ["id", "category"]) + + stringIndexer = StringIndexer(inputCol="category", outputCol="categoryIndex") + model = stringIndexer.fit(df) + indexed = model.transform(df) + + converter = IndexToString(inputCol="categoryIndex", outputCol="originalCategory") + converted = converter.transform(indexed) + + converted.select("id", "originalCategory").show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/linear_regression_with_elastic_net.py b/examples/src/main/python/ml/linear_regression_with_elastic_net.py new file mode 100644 index 000000000000..a4cd40cf2672 --- /dev/null +++ b/examples/src/main/python/ml/linear_regression_with_elastic_net.py @@ -0,0 +1,45 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.regression import LinearRegression +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="LinearRegressionWithElasticNet") + sqlContext = SQLContext(sc) + + # $example on$ + # Load training data + training = sqlContext.read.format("libsvm")\ + .load("data/mllib/sample_linear_regression_data.txt") + + lr = LinearRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8) + + # Fit the model + lrModel = lr.fit(training) + + # Print the coefficients and intercept for linear regression + print("Coefficients: " + str(lrModel.coefficients)) + print("Intercept: " + str(lrModel.intercept)) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/logistic_regression.py b/examples/src/main/python/ml/logistic_regression.py deleted file mode 100644 index 55afe1b207fe..000000000000 --- a/examples/src/main/python/ml/logistic_regression.py +++ /dev/null @@ -1,67 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -import sys - -from pyspark import SparkContext -from pyspark.ml.classification import LogisticRegression -from pyspark.mllib.evaluation import MulticlassMetrics -from pyspark.ml.feature import StringIndexer -from pyspark.mllib.util import MLUtils -from pyspark.sql import SQLContext - -""" -A simple example demonstrating a logistic regression with elastic net regularization Pipeline. -Run with: - bin/spark-submit examples/src/main/python/ml/logistic_regression.py -""" - -if __name__ == "__main__": - - if len(sys.argv) > 1: - print("Usage: logistic_regression", file=sys.stderr) - exit(-1) - - sc = SparkContext(appName="PythonLogisticRegressionExample") - sqlContext = SQLContext(sc) - - # Load and parse the data file into a dataframe. - df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() - - # Map labels into an indexed column of labels in [0, numLabels) - stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") - si_model = stringIndexer.fit(df) - td = si_model.transform(df) - [training, test] = td.randomSplit([0.7, 0.3]) - - lr = LogisticRegression(maxIter=100, regParam=0.3).setLabelCol("indexedLabel") - lr.setElasticNetParam(0.8) - - # Fit the model - lrModel = lr.fit(training) - - predictionAndLabels = lrModel.transform(test).select("prediction", "indexedLabel") \ - .map(lambda x: (x.prediction, x.indexedLabel)) - - metrics = MulticlassMetrics(predictionAndLabels) - print("weighted f-measure %.3f" % metrics.weightedFMeasure()) - print("precision %s" % metrics.precision()) - print("recall %s" % metrics.recall()) - - sc.stop() diff --git a/examples/src/main/python/ml/logistic_regression_with_elastic_net.py b/examples/src/main/python/ml/logistic_regression_with_elastic_net.py new file mode 100644 index 000000000000..b0b1d27e13bb --- /dev/null +++ b/examples/src/main/python/ml/logistic_regression_with_elastic_net.py @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.classification import LogisticRegression +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="LogisticRegressionWithElasticNet") + sqlContext = SQLContext(sc) + + # $example on$ + # Load training data + training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + lr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8) + + # Fit the model + lrModel = lr.fit(training) + + # Print the coefficients and intercept for logistic regression + print("Coefficients: " + str(lrModel.coefficients)) + print("Intercept: " + str(lrModel.intercept)) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/multilayer_perceptron_classification.py b/examples/src/main/python/ml/multilayer_perceptron_classification.py new file mode 100644 index 000000000000..f84588f547ff --- /dev/null +++ b/examples/src/main/python/ml/multilayer_perceptron_classification.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.classification import MultilayerPerceptronClassifier +from pyspark.ml.evaluation import MulticlassClassificationEvaluator +# $example off$ + +if __name__ == "__main__": + + sc = SparkContext(appName="multilayer_perceptron_classification_example") + sqlContext = SQLContext(sc) + + # $example on$ + # Load training data + data = sqlContext.read.format("libsvm")\ + .load("data/mllib/sample_multiclass_classification_data.txt") + # Split the data into train and test + splits = data.randomSplit([0.6, 0.4], 1234) + train = splits[0] + test = splits[1] + # specify layers for the neural network: + # input layer of size 4 (features), two intermediate of size 5 and 4 + # and output of size 3 (classes) + layers = [4, 5, 4, 3] + # create the trainer and set its parameters + trainer = MultilayerPerceptronClassifier(maxIter=100, layers=layers, blockSize=128, seed=1234) + # train the model + model = trainer.fit(train) + # compute precision on the test set + result = model.transform(test) + predictionAndLabels = result.select("prediction", "label") + evaluator = MulticlassClassificationEvaluator(metricName="precision") + print("Precision:" + str(evaluator.evaluate(predictionAndLabels))) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/n_gram_example.py b/examples/src/main/python/ml/n_gram_example.py new file mode 100644 index 000000000000..f2d85f53e721 --- /dev/null +++ b/examples/src/main/python/ml/n_gram_example.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import NGram +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="NGramExample") + sqlContext = SQLContext(sc) + + # $example on$ + wordDataFrame = sqlContext.createDataFrame([ + (0, ["Hi", "I", "heard", "about", "Spark"]), + (1, ["I", "wish", "Java", "could", "use", "case", "classes"]), + (2, ["Logistic", "regression", "models", "are", "neat"]) + ], ["label", "words"]) + ngram = NGram(inputCol="words", outputCol="ngrams") + ngramDataFrame = ngram.transform(wordDataFrame) + for ngrams_label in ngramDataFrame.select("ngrams", "label").take(3): + print(ngrams_label) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/normalizer_example.py b/examples/src/main/python/ml/normalizer_example.py new file mode 100644 index 000000000000..d490221474c2 --- /dev/null +++ b/examples/src/main/python/ml/normalizer_example.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import Normalizer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="NormalizerExample") + sqlContext = SQLContext(sc) + + # $example on$ + dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + # Normalize each Vector using $L^1$ norm. + normalizer = Normalizer(inputCol="features", outputCol="normFeatures", p=1.0) + l1NormData = normalizer.transform(dataFrame) + l1NormData.show() + + # Normalize each Vector using $L^\infty$ norm. + lInfNormData = normalizer.transform(dataFrame, {normalizer.p: float("inf")}) + lInfNormData.show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/onehot_encoder_example.py b/examples/src/main/python/ml/onehot_encoder_example.py new file mode 100644 index 000000000000..0f94c26638d3 --- /dev/null +++ b/examples/src/main/python/ml/onehot_encoder_example.py @@ -0,0 +1,48 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import OneHotEncoder, StringIndexer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="OneHotEncoderExample") + sqlContext = SQLContext(sc) + + # $example on$ + df = sqlContext.createDataFrame([ + (0, "a"), + (1, "b"), + (2, "c"), + (3, "a"), + (4, "a"), + (5, "c") + ], ["id", "category"]) + + stringIndexer = StringIndexer(inputCol="category", outputCol="categoryIndex") + model = stringIndexer.fit(df) + indexed = model.transform(df) + encoder = OneHotEncoder(dropLast=False, inputCol="categoryIndex", outputCol="categoryVec") + encoded = encoder.transform(indexed) + encoded.select("id", "categoryVec").show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/pca_example.py b/examples/src/main/python/ml/pca_example.py new file mode 100644 index 000000000000..a17181f1b8a5 --- /dev/null +++ b/examples/src/main/python/ml/pca_example.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import PCA +from pyspark.mllib.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PCAExample") + sqlContext = SQLContext(sc) + + # $example on$ + data = [(Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),), + (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),), + (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)] + df = sqlContext.createDataFrame(data, ["features"]) + pca = PCA(k=3, inputCol="features", outputCol="pcaFeatures") + model = pca.fit(df) + result = model.transform(df).select("pcaFeatures") + result.show(truncate=False) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/polynomial_expansion_example.py b/examples/src/main/python/ml/polynomial_expansion_example.py new file mode 100644 index 000000000000..89f5cbe8f2f4 --- /dev/null +++ b/examples/src/main/python/ml/polynomial_expansion_example.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import PolynomialExpansion +from pyspark.mllib.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PolynomialExpansionExample") + sqlContext = SQLContext(sc) + + # $example on$ + df = sqlContext\ + .createDataFrame([(Vectors.dense([-2.0, 2.3]),), + (Vectors.dense([0.0, 0.0]),), + (Vectors.dense([0.6, -1.1]),)], + ["features"]) + px = PolynomialExpansion(degree=2, inputCol="features", outputCol="polyFeatures") + polyDF = px.transform(df) + for expanded in polyDF.select("polyFeatures").take(3): + print(expanded) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/random_forest_classifier_example.py b/examples/src/main/python/ml/random_forest_classifier_example.py new file mode 100644 index 000000000000..b3530d4f41c8 --- /dev/null +++ b/examples/src/main/python/ml/random_forest_classifier_example.py @@ -0,0 +1,77 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Random Forest Classifier Example. +""" +from __future__ import print_function + +import sys + +from pyspark import SparkContext, SQLContext +# $example on$ +from pyspark.ml import Pipeline +from pyspark.ml.classification import RandomForestClassifier +from pyspark.ml.feature import StringIndexer, VectorIndexer +from pyspark.ml.evaluation import MulticlassClassificationEvaluator +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="random_forest_classifier_example") + sqlContext = SQLContext(sc) + + # $example on$ + # Load and parse the data file, converting it to a DataFrame. + data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + # Index labels, adding metadata to the label column. + # Fit on whole dataset to include all labels in index. + labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data) + # Automatically identify categorical features, and index them. + # Set maxCategories so features with > 4 distinct values are treated as continuous. 
+ featureIndexer =\ + VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) + + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a RandomForest model. + rf = RandomForestClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures") + + # Chain indexers and forest in a Pipeline + pipeline = Pipeline(stages=[labelIndexer, featureIndexer, rf]) + + # Train model. This also runs the indexers. + model = pipeline.fit(trainingData) + + # Make predictions. + predictions = model.transform(testData) + + # Select example rows to display. + predictions.select("prediction", "indexedLabel", "features").show(5) + + # Select (prediction, true label) and compute test error + evaluator = MulticlassClassificationEvaluator( + labelCol="indexedLabel", predictionCol="prediction", metricName="precision") + accuracy = evaluator.evaluate(predictions) + print("Test Error = %g" % (1.0 - accuracy)) + + rfModel = model.stages[2] + print(rfModel) # summary only + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/random_forest_example.py b/examples/src/main/python/ml/random_forest_example.py deleted file mode 100644 index c7730e1bfacd..000000000000 --- a/examples/src/main/python/ml/random_forest_example.py +++ /dev/null @@ -1,87 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -import sys - -from pyspark import SparkContext -from pyspark.ml.classification import RandomForestClassifier -from pyspark.ml.feature import StringIndexer -from pyspark.ml.regression import RandomForestRegressor -from pyspark.mllib.evaluation import MulticlassMetrics, RegressionMetrics -from pyspark.mllib.util import MLUtils -from pyspark.sql import Row, SQLContext - -""" -A simple example demonstrating a RandomForest Classification/Regression Pipeline. -Run with: - bin/spark-submit examples/src/main/python/ml/random_forest_example.py -""" - - -def testClassification(train, test): - # Train a RandomForest model. - # Setting featureSubsetStrategy="auto" lets the algorithm choose. - # Note: Use larger numTrees in practice. - - rf = RandomForestClassifier(labelCol="indexedLabel", numTrees=3, maxDepth=4) - - model = rf.fit(train) - predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ - .map(lambda x: (x.prediction, x.indexedLabel)) - - metrics = MulticlassMetrics(predictionAndLabels) - print("weighted f-measure %.3f" % metrics.weightedFMeasure()) - print("precision %s" % metrics.precision()) - print("recall %s" % metrics.recall()) - - -def testRegression(train, test): - # Train a RandomForest model. - # Note: Use larger numTrees in practice. 
- - rf = RandomForestRegressor(labelCol="indexedLabel", numTrees=3, maxDepth=4) - - model = rf.fit(train) - predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ - .map(lambda x: (x.prediction, x.indexedLabel)) - - metrics = RegressionMetrics(predictionAndLabels) - print("rmse %.3f" % metrics.rootMeanSquaredError) - print("r2 %.3f" % metrics.r2) - print("mae %.3f" % metrics.meanAbsoluteError) - - -if __name__ == "__main__": - if len(sys.argv) > 1: - print("Usage: random_forest_example", file=sys.stderr) - exit(1) - sc = SparkContext(appName="PythonRandomForestExample") - sqlContext = SQLContext(sc) - - # Load and parse the data file into a dataframe. - df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() - - # Map labels into an indexed column of labels in [0, numLabels) - stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") - si_model = stringIndexer.fit(df) - td = si_model.transform(df) - [train, test] = td.randomSplit([0.7, 0.3]) - testClassification(train, test) - testRegression(train, test) - sc.stop() diff --git a/examples/src/main/python/ml/random_forest_regressor_example.py b/examples/src/main/python/ml/random_forest_regressor_example.py new file mode 100644 index 000000000000..b59c7c941484 --- /dev/null +++ b/examples/src/main/python/ml/random_forest_regressor_example.py @@ -0,0 +1,74 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Random Forest Regressor Example. +""" +from __future__ import print_function + +import sys + +from pyspark import SparkContext, SQLContext +# $example on$ +from pyspark.ml import Pipeline +from pyspark.ml.regression import RandomForestRegressor +from pyspark.ml.feature import VectorIndexer +from pyspark.ml.evaluation import RegressionEvaluator +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="random_forest_regressor_example") + sqlContext = SQLContext(sc) + + # $example on$ + # Load and parse the data file, converting it to a DataFrame. + data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + # Automatically identify categorical features, and index them. + # Set maxCategories so features with > 4 distinct values are treated as continuous. + featureIndexer =\ + VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) + + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a RandomForest model. + rf = RandomForestRegressor(featuresCol="indexedFeatures") + + # Chain indexer and forest in a Pipeline + pipeline = Pipeline(stages=[featureIndexer, rf]) + + # Train model. This also runs the indexer. + model = pipeline.fit(trainingData) + + # Make predictions. 
+ predictions = model.transform(testData) + + # Select example rows to display. + predictions.select("prediction", "label", "features").show(5) + + # Select (prediction, true label) and compute test error + evaluator = RegressionEvaluator( + labelCol="label", predictionCol="prediction", metricName="rmse") + rmse = evaluator.evaluate(predictions) + print("Root Mean Squared Error (RMSE) on test data = %g" % rmse) + + rfModel = model.stages[1] + print(rfModel) # summary only + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/rformula_example.py b/examples/src/main/python/ml/rformula_example.py new file mode 100644 index 000000000000..b544a1470076 --- /dev/null +++ b/examples/src/main/python/ml/rformula_example.py @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import RFormula +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="RFormulaExample") + sqlContext = SQLContext(sc) + + # $example on$ + dataset = sqlContext.createDataFrame( + [(7, "US", 18, 1.0), + (8, "CA", 12, 0.0), + (9, "NZ", 15, 0.0)], + ["id", "country", "hour", "clicked"]) + formula = RFormula( + formula="clicked ~ country + hour", + featuresCol="features", + labelCol="label") + output = formula.fit(dataset).transform(dataset) + output.select("features", "label").show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/simple_params_example.py b/examples/src/main/python/ml/simple_params_example.py index a9f29dab2d60..2d6d115d54d0 100644 --- a/examples/src/main/python/ml/simple_params_example.py +++ b/examples/src/main/python/ml/simple_params_example.py @@ -70,7 +70,7 @@ # We may alternatively specify parameters using a parameter map. # paramMap overrides all lr parameters set earlier. - paramMap = {lr.maxIter: 20, lr.threshold: 0.55, lr.probabilityCol: "myProbability"} + paramMap = {lr.maxIter: 20, lr.thresholds: [0.45, 0.55], lr.probabilityCol: "myProbability"} # Now learn a new model using the new parameters. model2 = lr.fit(training, paramMap) diff --git a/examples/src/main/python/ml/sql_transformer.py b/examples/src/main/python/ml/sql_transformer.py new file mode 100644 index 000000000000..9575d728d815 --- /dev/null +++ b/examples/src/main/python/ml/sql_transformer.py @@ -0,0 +1,40 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.ml.feature import SQLTransformer +# $example off$ +from pyspark.sql import SQLContext + +if __name__ == "__main__": + sc = SparkContext(appName="SQLTransformerExample") + sqlContext = SQLContext(sc) + + # $example on$ + df = sqlContext.createDataFrame([ + (0, 1.0, 3.0), + (2, 2.0, 5.0) + ], ["id", "v1", "v2"]) + sqlTrans = SQLTransformer( + statement="SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__") + sqlTrans.transform(df).show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/standard_scaler_example.py b/examples/src/main/python/ml/standard_scaler_example.py new file mode 100644 index 000000000000..ae7aa85005bc --- /dev/null +++ b/examples/src/main/python/ml/standard_scaler_example.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import StandardScaler +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="StandardScalerExample") + sqlContext = SQLContext(sc) + + # $example on$ + dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures", + withStd=True, withMean=False) + + # Compute summary statistics by fitting the StandardScaler + scalerModel = scaler.fit(dataFrame) + + # Normalize each feature to have unit standard deviation. + scaledData = scalerModel.transform(dataFrame) + scaledData.show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/stopwords_remover_example.py b/examples/src/main/python/ml/stopwords_remover_example.py new file mode 100644 index 000000000000..01f94af8ca75 --- /dev/null +++ b/examples/src/main/python/ml/stopwords_remover_example.py @@ -0,0 +1,40 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
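SQLTransformer substitutes the input DataFrame for the __THIS__ placeholder, so arbitrary projections and filters can be expressed, not only the derived columns shown in sql_transformer.py above. A minimal sketch (not part of the patch), assuming the df and SQLTransformer import from that example:

# Illustrative sketch (not part of the patch): filtering rows through __THIS__.
filterTrans = SQLTransformer(statement="SELECT id, v1 FROM __THIS__ WHERE v2 > 4.0")
filterTrans.transform(df).show()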
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import StopWordsRemover +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="StopWordsRemoverExample") + sqlContext = SQLContext(sc) + + # $example on$ + sentenceData = sqlContext.createDataFrame([ + (0, ["I", "saw", "the", "red", "baloon"]), + (1, ["Mary", "had", "a", "little", "lamb"]) + ], ["label", "raw"]) + + remover = StopWordsRemover(inputCol="raw", outputCol="filtered") + remover.transform(sentenceData).show(truncate=False) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/string_indexer_example.py b/examples/src/main/python/ml/string_indexer_example.py new file mode 100644 index 000000000000..58a8cb5d56b7 --- /dev/null +++ b/examples/src/main/python/ml/string_indexer_example.py @@ -0,0 +1,39 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import StringIndexer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="StringIndexerExample") + sqlContext = SQLContext(sc) + + # $example on$ + df = sqlContext.createDataFrame( + [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")], + ["id", "category"]) + indexer = StringIndexer(inputCol="category", outputCol="categoryIndex") + indexed = indexer.fit(df).transform(df) + indexed.show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/tf_idf_example.py b/examples/src/main/python/ml/tf_idf_example.py new file mode 100644 index 000000000000..c92313378eec --- /dev/null +++ b/examples/src/main/python/ml/tf_idf_example.py @@ -0,0 +1,47 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
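StopWordsRemover in stopwords_remover_example.py above uses the built-in English stop-word list; it also accepts a custom list and a case-sensitivity flag. A minimal sketch (not part of the patch), assuming the sentenceData DataFrame from the example and that the stopWords and caseSensitive constructor parameters are available in this Spark version's Python API:

# Illustrative sketch (not part of the patch): custom stop words, case-insensitive.
customRemover = StopWordsRemover(inputCol="raw", outputCol="filtered",
                                 stopWords=["i", "a", "the", "had"],
                                 caseSensitive=False)
customRemover.transform(sentenceData).show(truncate=False)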
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.ml.feature import HashingTF, IDF, Tokenizer +# $example off$ +from pyspark.sql import SQLContext + +if __name__ == "__main__": + sc = SparkContext(appName="TfIdfExample") + sqlContext = SQLContext(sc) + + # $example on$ + sentenceData = sqlContext.createDataFrame([ + (0, "Hi I heard about Spark"), + (0, "I wish Java could use case classes"), + (1, "Logistic regression models are neat") + ], ["label", "sentence"]) + tokenizer = Tokenizer(inputCol="sentence", outputCol="words") + wordsData = tokenizer.transform(sentenceData) + hashingTF = HashingTF(inputCol="words", outputCol="rawFeatures", numFeatures=20) + featurizedData = hashingTF.transform(wordsData) + idf = IDF(inputCol="rawFeatures", outputCol="features") + idfModel = idf.fit(featurizedData) + rescaledData = idfModel.transform(featurizedData) + for features_label in rescaledData.select("features", "label").take(3): + print(features_label) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/tokenizer_example.py b/examples/src/main/python/ml/tokenizer_example.py new file mode 100644 index 000000000000..ce9b225be535 --- /dev/null +++ b/examples/src/main/python/ml/tokenizer_example.py @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
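HashingTF in tf_idf_example.py above maps terms to a fixed-size vector by hashing, so the very small numFeatures=20 used there will produce collisions; real corpora normally use a much larger dimension, and IDF can ignore rare terms via a minimum document frequency. A minimal sketch (not part of the patch), reusing wordsData from the example; the parameter values are illustrative only:

# Illustrative sketch (not part of the patch): a larger hashing dimension and
# an IDF minimum document frequency.
hashingTFLarge = HashingTF(inputCol="words", outputCol="rawFeatures", numFeatures=20000)
featurized = hashingTFLarge.transform(wordsData)
idfModel = IDF(inputCol="rawFeatures", outputCol="features", minDocFreq=1).fit(featurized)
idfModel.transform(featurized).select("features", "label").show(truncate=False)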
+# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import Tokenizer, RegexTokenizer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="TokenizerExample") + sqlContext = SQLContext(sc) + + # $example on$ + sentenceDataFrame = sqlContext.createDataFrame([ + (0, "Hi I heard about Spark"), + (1, "I wish Java could use case classes"), + (2, "Logistic,regression,models,are,neat") + ], ["label", "sentence"]) + tokenizer = Tokenizer(inputCol="sentence", outputCol="words") + wordsDataFrame = tokenizer.transform(sentenceDataFrame) + for words_label in wordsDataFrame.select("words", "label").take(3): + print(words_label) + regexTokenizer = RegexTokenizer(inputCol="sentence", outputCol="words", pattern="\\W") + # alternatively, pattern="\\w+", gaps(False) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/vector_assembler_example.py b/examples/src/main/python/ml/vector_assembler_example.py new file mode 100644 index 000000000000..04f64839f188 --- /dev/null +++ b/examples/src/main/python/ml/vector_assembler_example.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.mllib.linalg import Vectors +from pyspark.ml.feature import VectorAssembler +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="VectorAssemblerExample") + sqlContext = SQLContext(sc) + + # $example on$ + dataset = sqlContext.createDataFrame( + [(0, 18, 1.0, Vectors.dense([0.0, 10.0, 0.5]), 1.0)], + ["id", "hour", "mobile", "userFeatures", "clicked"]) + assembler = VectorAssembler( + inputCols=["hour", "mobile", "userFeatures"], + outputCol="features") + output = assembler.transform(dataset) + print(output.select("features", "clicked").first()) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/vector_indexer_example.py b/examples/src/main/python/ml/vector_indexer_example.py new file mode 100644 index 000000000000..146f41c1dd90 --- /dev/null +++ b/examples/src/main/python/ml/vector_indexer_example.py @@ -0,0 +1,40 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import VectorIndexer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="VectorIndexerExample") + sqlContext = SQLContext(sc) + + # $example on$ + data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + indexer = VectorIndexer(inputCol="features", outputCol="indexed", maxCategories=10) + indexerModel = indexer.fit(data) + + # Create new column "indexed" with categorical values transformed to indices + indexedData = indexerModel.transform(data) + indexedData.show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/word2vec_example.py b/examples/src/main/python/ml/word2vec_example.py new file mode 100644 index 000000000000..53c77feb1014 --- /dev/null +++ b/examples/src/main/python/ml/word2vec_example.py @@ -0,0 +1,45 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import Word2Vec +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="Word2VecExample") + sqlContext = SQLContext(sc) + + # $example on$ + # Input data: Each row is a bag of words from a sentence or document. + documentDF = sqlContext.createDataFrame([ + ("Hi I heard about Spark".split(" "), ), + ("I wish Java could use case classes".split(" "), ), + ("Logistic regression models are neat".split(" "), ) + ], ["text"]) + # Learn a mapping from words to Vectors. + word2Vec = Word2Vec(vectorSize=3, minCount=0, inputCol="text", outputCol="result") + model = word2Vec.fit(documentDF) + result = model.transform(documentDF) + for feature in result.select("result").take(3): + print(feature) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/mllib/binary_classification_metrics_example.py b/examples/src/main/python/mllib/binary_classification_metrics_example.py new file mode 100644 index 000000000000..437acb998acc --- /dev/null +++ b/examples/src/main/python/mllib/binary_classification_metrics_example.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
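Word2Vec in word2vec_example.py above averages the word vectors of each document, so the vectors in the "result" column can be compared directly, for example with cosine similarity. A minimal sketch (not part of the patch), reusing result from the example and assuming NumPy is installed:

# Illustrative sketch (not part of the patch): cosine similarity between the
# first two document vectors produced by the Word2Vec model.
import numpy as np

vecs = [row.result.toArray() for row in result.select("result").collect()]
cosine = float(vecs[0].dot(vecs[1]) / (np.linalg.norm(vecs[0]) * np.linalg.norm(vecs[1])))
print("Cosine similarity of first two documents: %g" % cosine)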
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Binary Classification Metrics Example. +""" +from __future__ import print_function +import sys +from pyspark import SparkContext, SQLContext +# $example on$ +from pyspark.mllib.classification import LogisticRegressionWithLBFGS +from pyspark.mllib.evaluation import BinaryClassificationMetrics +from pyspark.mllib.util import MLUtils +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="BinaryClassificationMetricsExample") + sqlContext = SQLContext(sc) + # $example on$ + # Several of the methods available in scala are currently missing from pyspark + # Load training data in LIBSVM format + data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt") + + # Split data into training (60%) and test (40%) + training, test = data.randomSplit([0.6, 0.4], seed=11L) + training.cache() + + # Run training algorithm to build the model + model = LogisticRegressionWithLBFGS.train(training) + + # Compute raw scores on the test set + predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label)) + + # Instantiate metrics object + metrics = BinaryClassificationMetrics(predictionAndLabels) + + # Area under precision-recall curve + print("Area under PR = %s" % metrics.areaUnderPR) + + # Area under ROC curve + print("Area under ROC = %s" % metrics.areaUnderROC) + # $example off$ diff --git a/examples/src/main/python/mllib/decision_tree_classification_example.py b/examples/src/main/python/mllib/decision_tree_classification_example.py new file mode 100644 index 000000000000..1b529768b6c6 --- /dev/null +++ b/examples/src/main/python/mllib/decision_tree_classification_example.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Decision Tree Classification Example. 
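A portability note on binary_classification_metrics_example.py above: the 11L long literal in the randomSplit call is Python 2-only syntax, and the script, unlike the other examples, never calls sc.stop(). A minimal Python-2/3-compatible sketch of that split (not part of the patch), reusing data from the example:

# Illustrative sketch (not part of the patch): a plain int seed parses under
# both Python 2 and Python 3.
training, test = data.randomSplit([0.6, 0.4], seed=11)
training.cache()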
+""" +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.tree import DecisionTree, DecisionTreeModel +from pyspark.mllib.util import MLUtils +# $example off$ + +if __name__ == "__main__": + + sc = SparkContext(appName="PythonDecisionTreeClassificationExample") + + # $example on$ + # Load and parse the data file into an RDD of LabeledPoint. + data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a DecisionTree model. + # Empty categoricalFeaturesInfo indicates all features are continuous. + model = DecisionTree.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={}, + impurity='gini', maxDepth=5, maxBins=32) + + # Evaluate model on test instances and compute test error + predictions = model.predict(testData.map(lambda x: x.features)) + labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) + testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(testData.count()) + print('Test Error = ' + str(testErr)) + print('Learned classification tree model:') + print(model.toDebugString()) + + # Save and load model + model.save(sc, "target/tmp/myDecisionTreeClassificationModel") + sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeClassificationModel") + # $example off$ diff --git a/examples/src/main/python/mllib/decision_tree_regression_example.py b/examples/src/main/python/mllib/decision_tree_regression_example.py new file mode 100644 index 000000000000..cf518eac67e8 --- /dev/null +++ b/examples/src/main/python/mllib/decision_tree_regression_example.py @@ -0,0 +1,56 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Decision Tree Regression Example. +""" +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.tree import DecisionTree, DecisionTreeModel +from pyspark.mllib.util import MLUtils +# $example off$ + +if __name__ == "__main__": + + sc = SparkContext(appName="PythonDecisionTreeRegressionExample") + + # $example on$ + # Load and parse the data file into an RDD of LabeledPoint. + data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a DecisionTree model. + # Empty categoricalFeaturesInfo indicates all features are continuous. 
+ model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo={}, + impurity='variance', maxDepth=5, maxBins=32) + + # Evaluate model on test instances and compute test error + predictions = model.predict(testData.map(lambda x: x.features)) + labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) + testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() /\ + float(testData.count()) + print('Test Mean Squared Error = ' + str(testMSE)) + print('Learned regression tree model:') + print(model.toDebugString()) + + # Save and load model + model.save(sc, "target/tmp/myDecisionTreeRegressionModel") + sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeRegressionModel") + # $example off$ diff --git a/examples/src/main/python/mllib/decision_tree_runner.py b/examples/src/main/python/mllib/decision_tree_runner.py deleted file mode 100755 index 513ed8fd5145..000000000000 --- a/examples/src/main/python/mllib/decision_tree_runner.py +++ /dev/null @@ -1,144 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -Decision tree classification and regression using MLlib. - -This example requires NumPy (http://www.numpy.org/). -""" -from __future__ import print_function - -import numpy -import os -import sys - -from operator import add - -from pyspark import SparkContext -from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.tree import DecisionTree -from pyspark.mllib.util import MLUtils - - -def getAccuracy(dtModel, data): - """ - Return accuracy of DecisionTreeModel on the given RDD[LabeledPoint]. - """ - seqOp = (lambda acc, x: acc + (x[0] == x[1])) - predictions = dtModel.predict(data.map(lambda x: x.features)) - truth = data.map(lambda p: p.label) - trainCorrect = predictions.zip(truth).aggregate(0, seqOp, add) - if data.count() == 0: - return 0 - return trainCorrect / (0.0 + data.count()) - - -def getMSE(dtModel, data): - """ - Return mean squared error (MSE) of DecisionTreeModel on the given - RDD[LabeledPoint]. - """ - seqOp = (lambda acc, x: acc + numpy.square(x[0] - x[1])) - predictions = dtModel.predict(data.map(lambda x: x.features)) - truth = data.map(lambda p: p.label) - trainMSE = predictions.zip(truth).aggregate(0, seqOp, add) - if data.count() == 0: - return 0 - return trainMSE / (0.0 + data.count()) - - -def reindexClassLabels(data): - """ - Re-index class labels in a dataset to the range {0,...,numClasses-1}. - If all labels in that range already appear at least once, - then the returned RDD is the same one (without a mapping). - Note: If a label simply does not appear in the data, - the index will not include it. - Be aware of this when reindexing subsampled data. 
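For the regression variant above, the mean squared error can also be obtained from pyspark.mllib.evaluation.RegressionMetrics, which this patch uses later in ranking_metrics_example.py and which additionally exposes RMSE and R-squared. A minimal sketch (not part of the patch), reusing labelsAndPredictions from decision_tree_regression_example.py:

# Illustrative sketch (not part of the patch): RegressionMetrics expects
# (prediction, observation) pairs, so the (label, prediction) pairs are swapped.
from pyspark.mllib.evaluation import RegressionMetrics

metrics = RegressionMetrics(labelsAndPredictions.map(lambda lp: (lp[1], lp[0])))
print("Test Mean Squared Error = %s" % metrics.meanSquaredError)
print("Test RMSE = %s" % metrics.rootMeanSquaredError)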
- :param data: RDD of LabeledPoint where labels are integer values - denoting labels for a classification problem. - :return: Pair (reindexedData, origToNewLabels) where - reindexedData is an RDD of LabeledPoint with labels in - the range {0,...,numClasses-1}, and - origToNewLabels is a dictionary mapping original labels - to new labels. - """ - # classCounts: class --> # examples in class - classCounts = data.map(lambda x: x.label).countByValue() - numExamples = sum(classCounts.values()) - sortedClasses = sorted(classCounts.keys()) - numClasses = len(classCounts) - # origToNewLabels: class --> index in 0,...,numClasses-1 - if (numClasses < 2): - print("Dataset for classification should have at least 2 classes." - " The given dataset had only %d classes." % numClasses, file=sys.stderr) - exit(1) - origToNewLabels = dict([(sortedClasses[i], i) for i in range(0, numClasses)]) - - print("numClasses = %d" % numClasses) - print("Per-class example fractions, counts:") - print("Class\tFrac\tCount") - for c in sortedClasses: - frac = classCounts[c] / (numExamples + 0.0) - print("%g\t%g\t%d" % (c, frac, classCounts[c])) - - if (sortedClasses[0] == 0 and sortedClasses[-1] == numClasses - 1): - return (data, origToNewLabels) - else: - reindexedData = \ - data.map(lambda x: LabeledPoint(origToNewLabels[x.label], x.features)) - return (reindexedData, origToNewLabels) - - -def usage(): - print("Usage: decision_tree_runner [libsvm format data filepath]", file=sys.stderr) - exit(1) - - -if __name__ == "__main__": - if len(sys.argv) > 2: - usage() - sc = SparkContext(appName="PythonDT") - - # Load data. - dataPath = 'data/mllib/sample_libsvm_data.txt' - if len(sys.argv) == 2: - dataPath = sys.argv[1] - if not os.path.isfile(dataPath): - sc.stop() - usage() - points = MLUtils.loadLibSVMFile(sc, dataPath) - - # Re-index class labels if needed. - (reindexedData, origToNewLabels) = reindexClassLabels(points) - numClasses = len(origToNewLabels) - - # Train a classifier. - categoricalFeaturesInfo = {} # no categorical features - model = DecisionTree.trainClassifier(reindexedData, numClasses=numClasses, - categoricalFeaturesInfo=categoricalFeaturesInfo) - # Print learned tree and stats. - print("Trained DecisionTree for classification:") - print(" Model numNodes: %d" % model.numNodes()) - print(" Model depth: %d" % model.depth()) - print(" Training accuracy: %g" % getAccuracy(model, reindexedData)) - if model.numNodes() < 20: - print(model.toDebugString()) - else: - print(model) - - sc.stop() diff --git a/examples/src/main/python/mllib/fpgrowth_example.py b/examples/src/main/python/mllib/fpgrowth_example.py new file mode 100644 index 000000000000..715f5268206c --- /dev/null +++ b/examples/src/main/python/mllib/fpgrowth_example.py @@ -0,0 +1,33 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +# $example on$ +from pyspark.mllib.fpm import FPGrowth +# $example off$ +from pyspark import SparkContext + +if __name__ == "__main__": + sc = SparkContext(appName="FPGrowth") + + # $example on$ + data = sc.textFile("data/mllib/sample_fpgrowth.txt") + transactions = data.map(lambda line: line.strip().split(' ')) + model = FPGrowth.train(transactions, minSupport=0.2, numPartitions=10) + result = model.freqItemsets().collect() + for fi in result: + print(fi) + # $example off$ diff --git a/examples/src/main/python/mllib/gradient_boosted_trees.py b/examples/src/main/python/mllib/gradient_boosted_trees.py deleted file mode 100644 index 781bd61c9d2b..000000000000 --- a/examples/src/main/python/mllib/gradient_boosted_trees.py +++ /dev/null @@ -1,77 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -Gradient boosted Trees classification and regression using MLlib. -""" -from __future__ import print_function - -import sys - -from pyspark.context import SparkContext -from pyspark.mllib.tree import GradientBoostedTrees -from pyspark.mllib.util import MLUtils - - -def testClassification(trainingData, testData): - # Train a GradientBoostedTrees model. - # Empty categoricalFeaturesInfo indicates all features are continuous. - model = GradientBoostedTrees.trainClassifier(trainingData, categoricalFeaturesInfo={}, - numIterations=30, maxDepth=4) - # Evaluate model on test instances and compute test error - predictions = model.predict(testData.map(lambda x: x.features)) - labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) - testErr = labelsAndPredictions.filter(lambda v_p: v_p[0] != v_p[1]).count() \ - / float(testData.count()) - print('Test Error = ' + str(testErr)) - print('Learned classification ensemble model:') - print(model.toDebugString()) - - -def testRegression(trainingData, testData): - # Train a GradientBoostedTrees model. - # Empty categoricalFeaturesInfo indicates all features are continuous. 
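In fpgrowth_example.py above, minSupport=0.2 means an itemset must appear in at least 20% of the transactions, while freqItemsets() reports absolute counts. A minimal sketch that converts those counts back into support fractions (not part of the patch), reusing transactions and result from the example:

# Illustrative sketch (not part of the patch): FreqItemset exposes .items and .freq.
numTransactions = transactions.count()
for fi in result:
    print("%s: support = %.3f" % (fi.items, fi.freq / float(numTransactions)))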
- model = GradientBoostedTrees.trainRegressor(trainingData, categoricalFeaturesInfo={}, - numIterations=30, maxDepth=4) - # Evaluate model on test instances and compute test error - predictions = model.predict(testData.map(lambda x: x.features)) - labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) - testMSE = labelsAndPredictions.map(lambda vp: (vp[0] - vp[1]) * (vp[0] - vp[1])).sum() \ - / float(testData.count()) - print('Test Mean Squared Error = ' + str(testMSE)) - print('Learned regression ensemble model:') - print(model.toDebugString()) - - -if __name__ == "__main__": - if len(sys.argv) > 1: - print("Usage: gradient_boosted_trees", file=sys.stderr) - exit(1) - sc = SparkContext(appName="PythonGradientBoostedTrees") - - # Load and parse the data file into an RDD of LabeledPoint. - data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') - # Split the data into training and test sets (30% held out for testing) - (trainingData, testData) = data.randomSplit([0.7, 0.3]) - - print('\nRunning example of classification using GradientBoostedTrees\n') - testClassification(trainingData, testData) - - print('\nRunning example of regression using GradientBoostedTrees\n') - testRegression(trainingData, testData) - - sc.stop() diff --git a/examples/src/main/python/mllib/gradient_boosting_classification_example.py b/examples/src/main/python/mllib/gradient_boosting_classification_example.py new file mode 100644 index 000000000000..a94ea0d582e5 --- /dev/null +++ b/examples/src/main/python/mllib/gradient_boosting_classification_example.py @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Gradient Boosted Trees Classification Example. +""" +from __future__ import print_function + +import sys + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.tree import GradientBoostedTrees, GradientBoostedTreesModel +from pyspark.mllib.util import MLUtils +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PythonGradientBoostedTreesClassificationExample") + # $example on$ + # Load and parse the data file. + data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a GradientBoostedTrees model. + # Notes: (a) Empty categoricalFeaturesInfo indicates all features are continuous. + # (b) Use more iterations in practice. 
+ model = GradientBoostedTrees.trainClassifier(trainingData, + categoricalFeaturesInfo={}, numIterations=3) + + # Evaluate model on test instances and compute test error + predictions = model.predict(testData.map(lambda x: x.features)) + labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) + testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(testData.count()) + print('Test Error = ' + str(testErr)) + print('Learned classification GBT model:') + print(model.toDebugString()) + + # Save and load model + model.save(sc, "target/tmp/myGradientBoostingClassificationModel") + sameModel = GradientBoostedTreesModel.load(sc, + "target/tmp/myGradientBoostingClassificationModel") + # $example off$ diff --git a/examples/src/main/python/mllib/gradient_boosting_regression_example.py b/examples/src/main/python/mllib/gradient_boosting_regression_example.py new file mode 100644 index 000000000000..86040799dc1d --- /dev/null +++ b/examples/src/main/python/mllib/gradient_boosting_regression_example.py @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Gradient Boosted Trees Regression Example. +""" +from __future__ import print_function + +import sys + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.tree import GradientBoostedTrees, GradientBoostedTreesModel +from pyspark.mllib.util import MLUtils +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PythonGradientBoostedTreesRegressionExample") + # $example on$ + # Load and parse the data file. + data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a GradientBoostedTrees model. + # Notes: (a) Empty categoricalFeaturesInfo indicates all features are continuous. + # (b) Use more iterations in practice. 
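The comment in gradient_boosting_classification_example.py above notes that numIterations=3 is only for demonstration; the gradient_boosted_trees.py example removed by this patch used 30 boosting iterations with deeper trees. A minimal sketch of that more realistic call (not part of the patch), reusing trainingData:

# Illustrative sketch (not part of the patch): more boosting iterations and a
# modest tree depth, matching the removed combined example.
model = GradientBoostedTrees.trainClassifier(trainingData, categoricalFeaturesInfo={},
                                             numIterations=30, maxDepth=4)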
+ model = GradientBoostedTrees.trainRegressor(trainingData, + categoricalFeaturesInfo={}, numIterations=3) + + # Evaluate model on test instances and compute test error + predictions = model.predict(testData.map(lambda x: x.features)) + labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) + testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() /\ + float(testData.count()) + print('Test Mean Squared Error = ' + str(testMSE)) + print('Learned regression GBT model:') + print(model.toDebugString()) + + # Save and load model + model.save(sc, "target/tmp/myGradientBoostingRegressionModel") + sameModel = GradientBoostedTreesModel.load(sc, "target/tmp/myGradientBoostingRegressionModel") + # $example off$ diff --git a/examples/src/main/python/mllib/isotonic_regression_example.py b/examples/src/main/python/mllib/isotonic_regression_example.py new file mode 100644 index 000000000000..89dc9f4b6611 --- /dev/null +++ b/examples/src/main/python/mllib/isotonic_regression_example.py @@ -0,0 +1,56 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Isotonic Regression Example. +""" +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +import math +from pyspark.mllib.regression import IsotonicRegression, IsotonicRegressionModel +# $example off$ + +if __name__ == "__main__": + + sc = SparkContext(appName="PythonIsotonicRegressionExample") + + # $example on$ + data = sc.textFile("data/mllib/sample_isotonic_regression_data.txt") + + # Create label, feature, weight tuples from input data with weight set to default value 1.0. + parsedData = data.map(lambda line: tuple([float(x) for x in line.split(',')]) + (1.0,)) + + # Split data into training (60%) and test (40%) sets. + training, test = parsedData.randomSplit([0.6, 0.4], 11) + + # Create isotonic regression model from training data. + # Isotonic parameter defaults to true so it is only shown for demonstration + model = IsotonicRegression.train(training) + + # Create tuples of predicted and real labels. + predictionAndLabel = test.map(lambda p: (model.predict(p[1]), p[0])) + + # Calculate mean squared error between predicted and real labels. 
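For the boosted-trees regression variant above, trainRegressor also accepts a loss parameter; least-squares error is the default, and absolute error is less sensitive to outliers. A minimal sketch (not part of the patch), reusing trainingData; the loss name string follows the MLlib documentation and should be double-checked against this Spark version:

# Illustrative sketch (not part of the patch): least-absolute-error loss instead
# of the default least-squares loss.
model = GradientBoostedTrees.trainRegressor(trainingData, categoricalFeaturesInfo={},
                                            numIterations=3, loss="leastAbsoluteError")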
+ meanSquaredError = predictionAndLabel.map(lambda pl: math.pow((pl[0] - pl[1]), 2)).mean() + print("Mean Squared Error = " + str(meanSquaredError)) + + # Save and load model + model.save(sc, "target/tmp/myIsotonicRegressionModel") + sameModel = IsotonicRegressionModel.load(sc, "target/tmp/myIsotonicRegressionModel") + # $example off$ diff --git a/examples/src/main/python/mllib/multi_class_metrics_example.py b/examples/src/main/python/mllib/multi_class_metrics_example.py new file mode 100644 index 000000000000..cd56b3c97c77 --- /dev/null +++ b/examples/src/main/python/mllib/multi_class_metrics_example.py @@ -0,0 +1,69 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# $example on$ +from pyspark.mllib.classification import LogisticRegressionWithLBFGS +from pyspark.mllib.util import MLUtils +from pyspark.mllib.evaluation import MulticlassMetrics +# $example off$ + +from pyspark import SparkContext + +if __name__ == "__main__": + sc = SparkContext(appName="MultiClassMetricsExample") + + # Several of the methods available in scala are currently missing from pyspark + # $example on$ + # Load training data in LIBSVM format + data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") + + # Split data into training (60%) and test (40%) + training, test = data.randomSplit([0.6, 0.4], seed=11L) + training.cache() + + # Run training algorithm to build the model + model = LogisticRegressionWithLBFGS.train(training, numClasses=3) + + # Compute raw scores on the test set + predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label)) + + # Instantiate metrics object + metrics = MulticlassMetrics(predictionAndLabels) + + # Overall statistics + precision = metrics.precision() + recall = metrics.recall() + f1Score = metrics.fMeasure() + print("Summary Stats") + print("Precision = %s" % precision) + print("Recall = %s" % recall) + print("F1 Score = %s" % f1Score) + + # Statistics by class + labels = data.map(lambda lp: lp.label).distinct().collect() + for label in sorted(labels): + print("Class %s precision = %s" % (label, metrics.precision(label))) + print("Class %s recall = %s" % (label, metrics.recall(label))) + print("Class %s F1 Measure = %s" % (label, metrics.fMeasure(label, beta=1.0))) + + # Weighted stats + print("Weighted recall = %s" % metrics.weightedRecall) + print("Weighted precision = %s" % metrics.weightedPrecision) + print("Weighted F(1) Score = %s" % metrics.weightedFMeasure()) + print("Weighted F(0.5) Score = %s" % metrics.weightedFMeasure(beta=0.5)) + print("Weighted false positive rate = %s" % metrics.weightedFalsePositiveRate) + # $example off$ diff --git a/examples/src/main/python/mllib/multi_label_metrics_example.py b/examples/src/main/python/mllib/multi_label_metrics_example.py new file 
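multi_class_metrics_example.py above prints per-class and weighted statistics (note that, as in the binary example, seed=11L is Python 2-only syntax); the same metrics object can also produce the full confusion matrix. A minimal sketch (not part of the patch), reusing metrics from the example and assuming confusionMatrix() is exposed by the Python API of this Spark version:

# Illustrative sketch (not part of the patch): rows are true labels, columns are
# predicted labels; toArray() yields a NumPy array for easy printing.
print("Confusion matrix:")
print(metrics.confusionMatrix().toArray())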
mode 100644 index 000000000000..960ade659737 --- /dev/null +++ b/examples/src/main/python/mllib/multi_label_metrics_example.py @@ -0,0 +1,61 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# $example on$ +from pyspark.mllib.evaluation import MultilabelMetrics +# $example off$ +from pyspark import SparkContext + +if __name__ == "__main__": + sc = SparkContext(appName="MultiLabelMetricsExample") + # $example on$ + scoreAndLabels = sc.parallelize([ + ([0.0, 1.0], [0.0, 2.0]), + ([0.0, 2.0], [0.0, 1.0]), + ([], [0.0]), + ([2.0], [2.0]), + ([2.0, 0.0], [2.0, 0.0]), + ([0.0, 1.0, 2.0], [0.0, 1.0]), + ([1.0], [1.0, 2.0])]) + + # Instantiate metrics object + metrics = MultilabelMetrics(scoreAndLabels) + + # Summary stats + print("Recall = %s" % metrics.recall()) + print("Precision = %s" % metrics.precision()) + print("F1 measure = %s" % metrics.f1Measure()) + print("Accuracy = %s" % metrics.accuracy) + + # Individual label stats + labels = scoreAndLabels.flatMap(lambda x: x[1]).distinct().collect() + for label in labels: + print("Class %s precision = %s" % (label, metrics.precision(label))) + print("Class %s recall = %s" % (label, metrics.recall(label))) + print("Class %s F1 Measure = %s" % (label, metrics.f1Measure(label))) + + # Micro stats + print("Micro precision = %s" % metrics.microPrecision) + print("Micro recall = %s" % metrics.microRecall) + print("Micro F1 measure = %s" % metrics.microF1Measure) + + # Hamming loss + print("Hamming loss = %s" % metrics.hammingLoss) + + # Subset accuracy + print("Subset accuracy = %s" % metrics.subsetAccuracy) + # $example off$ diff --git a/examples/src/main/python/mllib/naive_bayes_example.py b/examples/src/main/python/mllib/naive_bayes_example.py new file mode 100644 index 000000000000..f5e120c678fc --- /dev/null +++ b/examples/src/main/python/mllib/naive_bayes_example.py @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +NaiveBayes Example. 
+""" +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.classification import NaiveBayes, NaiveBayesModel +from pyspark.mllib.linalg import Vectors +from pyspark.mllib.regression import LabeledPoint + + +def parseLine(line): + parts = line.split(',') + label = float(parts[0]) + features = Vectors.dense([float(x) for x in parts[1].split(' ')]) + return LabeledPoint(label, features) +# $example off$ + +if __name__ == "__main__": + + sc = SparkContext(appName="PythonNaiveBayesExample") + + # $example on$ + data = sc.textFile('data/mllib/sample_naive_bayes_data.txt').map(parseLine) + + # Split data aproximately into training (60%) and test (40%) + training, test = data.randomSplit([0.6, 0.4], seed=0) + + # Train a naive Bayes model. + model = NaiveBayes.train(training, 1.0) + + # Make prediction and test accuracy. + predictionAndLabel = test.map(lambda p: (model.predict(p.features), p.label)) + accuracy = 1.0 * predictionAndLabel.filter(lambda (x, v): x == v).count() / test.count() + + # Save and load model + model.save(sc, "target/tmp/myNaiveBayesModel") + sameModel = NaiveBayesModel.load(sc, "target/tmp/myNaiveBayesModel") + # $example off$ diff --git a/examples/src/main/python/mllib/random_forest_classification_example.py b/examples/src/main/python/mllib/random_forest_classification_example.py new file mode 100644 index 000000000000..324ba50625d2 --- /dev/null +++ b/examples/src/main/python/mllib/random_forest_classification_example.py @@ -0,0 +1,58 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Random Forest Classification Example. +""" +from __future__ import print_function + +import sys + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.tree import RandomForest, RandomForestModel +from pyspark.mllib.util import MLUtils +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PythonRandomForestClassificationExample") + # $example on$ + # Load and parse the data file into an RDD of LabeledPoint. + data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a RandomForest model. + # Empty categoricalFeaturesInfo indicates all features are continuous. + # Note: Use larger numTrees in practice. + # Setting featureSubsetStrategy="auto" lets the algorithm choose. 
+ model = RandomForest.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={}, + numTrees=3, featureSubsetStrategy="auto", + impurity='gini', maxDepth=4, maxBins=32) + + # Evaluate model on test instances and compute test error + predictions = model.predict(testData.map(lambda x: x.features)) + labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) + testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(testData.count()) + print('Test Error = ' + str(testErr)) + print('Learned classification forest model:') + print(model.toDebugString()) + + # Save and load model + model.save(sc, "target/tmp/myRandomForestClassificationModel") + sameModel = RandomForestModel.load(sc, "target/tmp/myRandomForestClassificationModel") + # $example off$ diff --git a/examples/src/main/python/mllib/random_forest_example.py b/examples/src/main/python/mllib/random_forest_regression_example.py old mode 100755 new mode 100644 similarity index 51% rename from examples/src/main/python/mllib/random_forest_example.py rename to examples/src/main/python/mllib/random_forest_regression_example.py index 4cfdad868c66..f7aa6114eceb --- a/examples/src/main/python/mllib/random_forest_example.py +++ b/examples/src/main/python/mllib/random_forest_regression_example.py @@ -16,42 +16,26 @@ # """ -Random Forest classification and regression using MLlib. - -Note: This example illustrates binary classification. - For information on multiclass classification, please refer to the decision_tree_runner.py - example. +Random Forest Regression Example. """ from __future__ import print_function import sys -from pyspark.context import SparkContext -from pyspark.mllib.tree import RandomForest +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.tree import RandomForest, RandomForestModel from pyspark.mllib.util import MLUtils +# $example off$ +if __name__ == "__main__": + sc = SparkContext(appName="PythonRandomForestRegressionExample") + # $example on$ + # Load and parse the data file into an RDD of LabeledPoint. + data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) -def testClassification(trainingData, testData): - # Train a RandomForest model. - # Empty categoricalFeaturesInfo indicates all features are continuous. - # Note: Use larger numTrees in practice. - # Setting featureSubsetStrategy="auto" lets the algorithm choose. - model = RandomForest.trainClassifier(trainingData, numClasses=2, - categoricalFeaturesInfo={}, - numTrees=3, featureSubsetStrategy="auto", - impurity='gini', maxDepth=4, maxBins=32) - - # Evaluate model on test instances and compute test error - predictions = model.predict(testData.map(lambda x: x.features)) - labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) - testErr = labelsAndPredictions.filter(lambda v_p: v_p[0] != v_p[1]).count()\ - / float(testData.count()) - print('Test Error = ' + str(testErr)) - print('Learned classification forest model:') - print(model.toDebugString()) - - -def testRegression(trainingData, testData): # Train a RandomForest model. # Empty categoricalFeaturesInfo indicates all features are continuous. # Note: Use larger numTrees in practice. 
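As the comments in random_forest_classification_example.py above say, numTrees=3 is deliberately small and featureSubsetStrategy="auto" lets MLlib pick how many features to try at each split. A minimal sketch with more production-like settings (not part of the patch), reusing trainingData; the accepted strategy names follow the MLlib documentation:

# Illustrative sketch (not part of the patch): more trees and an explicit
# feature-subset strategy ("auto", "all", "sqrt", "log2", "onethird").
model = RandomForest.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={},
                                     numTrees=100, featureSubsetStrategy="sqrt",
                                     impurity='gini', maxDepth=4, maxBins=32)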
@@ -63,28 +47,13 @@ def testRegression(trainingData, testData): # Evaluate model on test instances and compute test error predictions = model.predict(testData.map(lambda x: x.features)) labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) - testMSE = labelsAndPredictions.map(lambda v_p1: (v_p1[0] - v_p1[1]) * (v_p1[0] - v_p1[1]))\ - .sum() / float(testData.count()) + testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() /\ + float(testData.count()) print('Test Mean Squared Error = ' + str(testMSE)) print('Learned regression forest model:') print(model.toDebugString()) - -if __name__ == "__main__": - if len(sys.argv) > 1: - print("Usage: random_forest_example", file=sys.stderr) - exit(1) - sc = SparkContext(appName="PythonRandomForestExample") - - # Load and parse the data file into an RDD of LabeledPoint. - data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') - # Split the data into training and test sets (30% held out for testing) - (trainingData, testData) = data.randomSplit([0.7, 0.3]) - - print('\nRunning example of classification using RandomForest\n') - testClassification(trainingData, testData) - - print('\nRunning example of regression using RandomForest\n') - testRegression(trainingData, testData) - - sc.stop() + # Save and load model + model.save(sc, "target/tmp/myRandomForestRegressionModel") + sameModel = RandomForestModel.load(sc, "target/tmp/myRandomForestRegressionModel") + # $example off$ diff --git a/examples/src/main/python/mllib/ranking_metrics_example.py b/examples/src/main/python/mllib/ranking_metrics_example.py new file mode 100644 index 000000000000..327791966c90 --- /dev/null +++ b/examples/src/main/python/mllib/ranking_metrics_example.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# $example on$ +from pyspark.mllib.recommendation import ALS, Rating +from pyspark.mllib.evaluation import RegressionMetrics, RankingMetrics +# $example off$ +from pyspark import SparkContext + +if __name__ == "__main__": + sc = SparkContext(appName="Ranking Metrics Example") + + # Several of the methods available in scala are currently missing from pyspark + # $example on$ + # Read in the ratings data + lines = sc.textFile("data/mllib/sample_movielens_data.txt") + + def parseLine(line): + fields = line.split("::") + return Rating(int(fields[0]), int(fields[1]), float(fields[2]) - 2.5) + ratings = lines.map(lambda r: parseLine(r)) + + # Train a model to predict user-product ratings + model = ALS.train(ratings, 10, 10, 0.01) + + # Get predicted ratings on all existing user-product pairs + testData = ratings.map(lambda p: (p.user, p.product)) + predictions = model.predictAll(testData).map(lambda r: ((r.user, r.product), r.rating)) + + ratingsTuple = ratings.map(lambda r: ((r.user, r.product), r.rating)) + scoreAndLabels = predictions.join(ratingsTuple).map(lambda tup: tup[1]) + + # Instantiate regression metrics to compare predicted and actual ratings + metrics = RegressionMetrics(scoreAndLabels) + + # Root mean squared error + print("RMSE = %s" % metrics.rootMeanSquaredError) + + # R-squared + print("R-squared = %s" % metrics.r2) + # $example off$ diff --git a/examples/src/main/python/mllib/recommendation_example.py b/examples/src/main/python/mllib/recommendation_example.py new file mode 100644 index 000000000000..615db0749b18 --- /dev/null +++ b/examples/src/main/python/mllib/recommendation_example.py @@ -0,0 +1,54 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Collaborative Filtering Classification Example.
+""" +from __future__ import print_function + +import sys + +from pyspark import SparkContext + +# $example on$ +from pyspark.mllib.recommendation import ALS, MatrixFactorizationModel, Rating +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PythonCollaborativeFilteringExample") + # $example on$ + # Load and parse the data + data = sc.textFile("data/mllib/als/test.data") + ratings = data.map(lambda l: l.split(','))\ + .map(lambda l: Rating(int(l[0]), int(l[1]), float(l[2]))) + + # Build the recommendation model using Alternating Least Squares + rank = 10 + numIterations = 10 + model = ALS.train(ratings, rank, numIterations) + + # Evaluate the model on training data + testdata = ratings.map(lambda p: (p[0], p[1])) + predictions = model.predictAll(testdata).map(lambda r: ((r[0], r[1]), r[2])) + ratesAndPreds = ratings.map(lambda r: ((r[0], r[1]), r[2])).join(predictions) + MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).mean() + print("Mean Squared Error = " + str(MSE)) + + # Save and load model + model.save(sc, "target/tmp/myCollaborativeFilter") + sameModel = MatrixFactorizationModel.load(sc, "target/tmp/myCollaborativeFilter") + # $example off$ diff --git a/examples/src/main/python/mllib/regression_metrics_example.py b/examples/src/main/python/mllib/regression_metrics_example.py new file mode 100644 index 000000000000..a3a83aafd7a1 --- /dev/null +++ b/examples/src/main/python/mllib/regression_metrics_example.py @@ -0,0 +1,59 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# $example on$ +from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD +from pyspark.mllib.evaluation import RegressionMetrics +from pyspark.mllib.linalg import DenseVector +# $example off$ + +from pyspark import SparkContext + +if __name__ == "__main__": + sc = SparkContext(appName="Regression Metrics Example") + + # $example on$ + # Load and parse the data + def parsePoint(line): + values = line.split() + return LabeledPoint(float(values[0]), + DenseVector([float(x.split(':')[1]) for x in values[1:]])) + + data = sc.textFile("data/mllib/sample_linear_regression_data.txt") + parsedData = data.map(parsePoint) + + # Build the model + model = LinearRegressionWithSGD.train(parsedData) + + # Get predictions + valuesAndPreds = parsedData.map(lambda p: (float(model.predict(p.features)), p.label)) + + # Instantiate metrics object + metrics = RegressionMetrics(valuesAndPreds) + + # Squared Error + print("MSE = %s" % metrics.meanSquaredError) + print("RMSE = %s" % metrics.rootMeanSquaredError) + + # R-squared + print("R-squared = %s" % metrics.r2) + + # Mean absolute error + print("MAE = %s" % metrics.meanAbsoluteError) + + # Explained variance + print("Explained variance = %s" % metrics.explainedVariance) + # $example off$ diff --git a/examples/src/main/python/streaming/direct_kafka_wordcount.py b/examples/src/main/python/streaming/direct_kafka_wordcount.py index 6ef188a220c5..ea20678b9aca 100644 --- a/examples/src/main/python/streaming/direct_kafka_wordcount.py +++ b/examples/src/main/python/streaming/direct_kafka_wordcount.py @@ -23,8 +23,8 @@ http://kafka.apache.org/documentation.html#quickstart and then run the example - `$ bin/spark-submit --jars external/kafka-assembly/target/scala-*/\ - spark-streaming-kafka-assembly-*.jar \ + `$ bin/spark-submit --jars \ + external/kafka-assembly/target/scala-*/spark-streaming-kafka-assembly-*.jar \ examples/src/main/python/streaming/direct_kafka_wordcount.py \ localhost:9092 test` """ @@ -37,7 +37,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: - print >> sys.stderr, "Usage: direct_kafka_wordcount.py <broker_list> <topic>" + print("Usage: direct_kafka_wordcount.py <broker_list> <topic>", file=sys.stderr) exit(-1) sc = SparkContext(appName="PythonStreamingDirectKafkaWordCount") diff --git a/examples/src/main/python/streaming/flume_wordcount.py b/examples/src/main/python/streaming/flume_wordcount.py index 091b64d8c4af..d75bc6daac13 100644 --- a/examples/src/main/python/streaming/flume_wordcount.py +++ b/examples/src/main/python/streaming/flume_wordcount.py @@ -23,8 +23,9 @@ https://flume.apache.org/documentation.html and then run the example - `$ bin/spark-submit --jars external/flume-assembly/target/scala-*/\ - spark-streaming-flume-assembly-*.jar examples/src/main/python/streaming/flume_wordcount.py \ + `$ bin/spark-submit --jars \ + external/flume-assembly/target/scala-*/spark-streaming-flume-assembly-*.jar \ + examples/src/main/python/streaming/flume_wordcount.py \ localhost 12345 """ from __future__ import print_function diff --git a/examples/src/main/python/streaming/kafka_wordcount.py b/examples/src/main/python/streaming/kafka_wordcount.py index b178e7899b5e..8d697f620f46 100644 --- a/examples/src/main/python/streaming/kafka_wordcount.py +++ b/examples/src/main/python/streaming/kafka_wordcount.py @@ -23,8 +23,9 @@ http://kafka.apache.org/documentation.html#quickstart and then run the example - `$ bin/spark-submit --jars external/kafka-assembly/target/scala-*/\ - spark-streaming-kafka-assembly-*.jar examples/src/main/python/streaming/kafka_wordcount.py \ + `$
bin/spark-submit --jars \ + external/kafka-assembly/target/scala-*/spark-streaming-kafka-assembly-*.jar \ + examples/src/main/python/streaming/kafka_wordcount.py \ localhost:2181 test` """ from __future__ import print_function diff --git a/examples/src/main/python/streaming/mqtt_wordcount.py b/examples/src/main/python/streaming/mqtt_wordcount.py new file mode 100644 index 000000000000..abf9c0e21d30 --- /dev/null +++ b/examples/src/main/python/streaming/mqtt_wordcount.py @@ -0,0 +1,59 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + A sample wordcount with MqttStream stream + Usage: mqtt_wordcount.py <broker url> <topic> + + To run this in your local machine, you need to set up a MQTT broker and publisher first, + Mosquitto is one of the open source MQTT Brokers, see + http://mosquitto.org/ + Eclipse paho project provides a number of clients and utilities for working with MQTT, see + http://www.eclipse.org/paho/#getting-started + + and then run the example + `$ bin/spark-submit --jars \ + external/mqtt-assembly/target/scala-*/spark-streaming-mqtt-assembly-*.jar \ + examples/src/main/python/streaming/mqtt_wordcount.py \ + tcp://localhost:1883 foo` +""" + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext +from pyspark.streaming.mqtt import MQTTUtils + +if __name__ == "__main__": + if len(sys.argv) != 3: + print >> sys.stderr, "Usage: mqtt_wordcount.py <broker url> <topic>" + exit(-1) + + sc = SparkContext(appName="PythonStreamingMQTTWordCount") + ssc = StreamingContext(sc, 1) + + brokerUrl = sys.argv[1] + topic = sys.argv[2] + + lines = MQTTUtils.createStream(ssc, brokerUrl, topic) + counts = lines.flatMap(lambda line: line.split(" ")) \ + .map(lambda word: (word, 1)) \ + .reduceByKey(lambda a, b: a+b) + counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/examples/src/main/python/streaming/queue_stream.py b/examples/src/main/python/streaming/queue_stream.py index dcd6a0fc6ff9..b3808907f74a 100644 --- a/examples/src/main/python/streaming/queue_stream.py +++ b/examples/src/main/python/streaming/queue_stream.py @@ -36,8 +36,8 @@ # Create the queue through which RDDs can be pushed to # a QueueInputDStream rddQueue = [] - for i in xrange(5): - rddQueue += [ssc.sparkContext.parallelize([j for j in xrange(1, 1001)], 10)] + for i in range(5): + rddQueue += [ssc.sparkContext.parallelize([j for j in range(1, 1001)], 10)] # Create the QueueInputDStream and use it to do some processing inputStream = ssc.queueStream(rddQueue) diff --git a/examples/src/main/python/streaming/stateful_network_wordcount.py b/examples/src/main/python/streaming/stateful_network_wordcount.py index 16ef646b7c42..f8bbc659c2ea 100644 --- a/examples/src/main/python/streaming/stateful_network_wordcount.py +++
b/examples/src/main/python/streaming/stateful_network_wordcount.py @@ -44,13 +44,16 @@ ssc = StreamingContext(sc, 1) ssc.checkpoint("checkpoint") + # RDD with initial state (key, value) pairs + initialStateRDD = sc.parallelize([(u'hello', 1), (u'world', 1)]) + def updateFunc(new_values, last_sum): return sum(new_values) + (last_sum or 0) lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2])) running_counts = lines.flatMap(lambda line: line.split(" "))\ .map(lambda word: (word, 1))\ - .updateStateByKey(updateFunc) + .updateStateByKey(updateFunc, initialRDD=initialStateRDD) running_counts.pprint() diff --git a/examples/src/main/r/dataframe.R b/examples/src/main/r/dataframe.R index 53b817144f6a..62f60e57eebe 100644 --- a/examples/src/main/r/dataframe.R +++ b/examples/src/main/r/dataframe.R @@ -35,7 +35,7 @@ printSchema(df) # Create a DataFrame from a JSON file path <- file.path(Sys.getenv("SPARK_HOME"), "examples/src/main/resources/people.json") -peopleDF <- jsonFile(sqlContext, path) +peopleDF <- read.json(sqlContext, path) printSchema(peopleDF) # Register this DataFrame as a table. diff --git a/examples/src/main/r/ml.R b/examples/src/main/r/ml.R new file mode 100644 index 000000000000..a0c903939cbb --- /dev/null +++ b/examples/src/main/r/ml.R @@ -0,0 +1,54 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# To run this example use +# ./bin/sparkR examples/src/main/r/ml.R + +# Load SparkR library into your R session +library(SparkR) + +# Initialize SparkContext and SQLContext +sc <- sparkR.init(appName="SparkR-ML-example") +sqlContext <- sparkRSQL.init(sc) + +# Train GLM of family 'gaussian' +training1 <- suppressWarnings(createDataFrame(sqlContext, iris)) +test1 <- training1 +model1 <- glm(Sepal_Length ~ Sepal_Width + Species, training1, family = "gaussian") + +# Model summary +summary(model1) + +# Prediction +predictions1 <- predict(model1, test1) +head(select(predictions1, "Sepal_Length", "prediction")) + +# Train GLM of family 'binomial' +training2 <- filter(training1, training1$Species != "setosa") +test2 <- training2 +model2 <- glm(Species ~ Sepal_Length + Sepal_Width, data = training2, family = "binomial") + +# Model summary +summary(model2) + +# Prediction (Currently the output of prediction for binomial GLM is the indexed label, +# we need to transform back to the original string label later) +predictions2 <- predict(model2, test2) +head(select(predictions2, "Species", "prediction")) + +# Stop the SparkContext now +sparkR.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala index 36832f51d2ad..d1b9b8d398dd 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala @@ -16,13 +16,11 @@ */ // scalastyle:off println + // scalastyle:off jobcontext package org.apache.spark.examples import java.nio.ByteBuffer - -import scala.collection.JavaConversions._ -import scala.collection.mutable.ListBuffer -import scala.collection.immutable.Map +import java.util.Collections import org.apache.cassandra.hadoop.ConfigHelper import org.apache.cassandra.hadoop.cql3.CqlPagingInputFormat @@ -32,7 +30,6 @@ import org.apache.cassandra.utils.ByteBufferUtil import org.apache.hadoop.mapreduce.Job import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ /* @@ -85,6 +82,7 @@ object CassandraCQLTest { val job = new Job() job.setInputFormatClass(classOf[CqlPagingInputFormat]) + val configuration = job.getConfiguration ConfigHelper.setInputInitialAddress(job.getConfiguration(), cHost) ConfigHelper.setInputRpcPort(job.getConfiguration(), cPort) ConfigHelper.setInputColumnFamily(job.getConfiguration(), KeySpace, InputColumnFamily) @@ -121,12 +119,9 @@ object CassandraCQLTest { val casoutputCF = aggregatedRDD.map { case (productId, saleCount) => { - val outColFamKey = Map("prod_id" -> ByteBufferUtil.bytes(productId)) - val outKey: java.util.Map[String, ByteBuffer] = outColFamKey - var outColFamVal = new ListBuffer[ByteBuffer] - outColFamVal += ByteBufferUtil.bytes(saleCount) - val outVal: java.util.List[ByteBuffer] = outColFamVal - (outKey, outVal) + val outKey = Collections.singletonMap("prod_id", ByteBufferUtil.bytes(productId)) + val outVal = Collections.singletonList(ByteBufferUtil.bytes(saleCount)) + (outKey, outVal) } } @@ -142,3 +137,4 @@ object CassandraCQLTest { } } // scalastyle:on println +// scalastyle:on jobcontext diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala index 96ef3e198e38..1e679bfb5534 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala +++ 
b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala @@ -16,13 +16,13 @@ */ // scalastyle:off println +// scalastyle:off jobcontext package org.apache.spark.examples import java.nio.ByteBuffer +import java.util.Arrays import java.util.SortedMap -import scala.collection.JavaConversions._ - import org.apache.cassandra.db.IColumn import org.apache.cassandra.hadoop.ColumnFamilyOutputFormat import org.apache.cassandra.hadoop.ConfigHelper @@ -32,7 +32,6 @@ import org.apache.cassandra.utils.ByteBufferUtil import org.apache.hadoop.mapreduce.Job import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ /* * This example demonstrates using Spark with Cassandra with the New Hadoop API and Cassandra @@ -118,7 +117,7 @@ object CassandraTest { val outputkey = ByteBufferUtil.bytes(word + "-COUNT-" + System.currentTimeMillis) - val mutations: java.util.List[Mutation] = new Mutation() :: new Mutation() :: Nil + val mutations = Arrays.asList(new Mutation(), new Mutation()) mutations.get(0).setColumn_or_supercolumn(new ColumnOrSuperColumn()) mutations.get(0).column_or_supercolumn.setColumn(colWord) mutations.get(1).setColumn_or_supercolumn(new ColumnOrSuperColumn()) @@ -132,6 +131,7 @@ object CassandraTest { } } // scalastyle:on println +// scalastyle:on jobcontext /* create keyspace casDemo; diff --git a/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala b/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala index c42df2b8845d..bec61f3cd429 100644 --- a/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala @@ -18,7 +18,7 @@ // scalastyle:off println package org.apache.spark.examples -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.util.Utils @@ -36,10 +36,10 @@ object DriverSubmissionTest { val properties = Utils.getSystemProperties println("Environment variables containing SPARK_TEST:") - env.filter{case (k, v) => k.contains("SPARK_TEST")}.foreach(println) + env.asScala.filter { case (k, _) => k.contains("SPARK_TEST")}.foreach(println) println("System properties containing spark.test:") - properties.filter{case (k, v) => k.toString.contains("spark.test")}.foreach(println) + properties.filter { case (k, _) => k.toString.contains("spark.test") }.foreach(println) for (i <- 1 until numSecondsToSleep) { println(s"Alive for $i out of $numSecondsToSleep seconds") diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala new file mode 100644 index 000000000000..f4b3613ccb94 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkContext, SparkConf} +// $example on$ +import org.apache.spark.ml.regression.AFTSurvivalRegression +import org.apache.spark.mllib.linalg.Vectors +// $example off$ + +/** + * An example for AFTSurvivalRegression. + */ +object AFTSurvivalRegressionExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("AFTSurvivalRegressionExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val training = sqlContext.createDataFrame(Seq( + (1.218, 1.0, Vectors.dense(1.560, -0.605)), + (2.949, 0.0, Vectors.dense(0.346, 2.158)), + (3.627, 0.0, Vectors.dense(1.380, 0.231)), + (0.273, 1.0, Vectors.dense(0.520, 1.151)), + (4.199, 0.0, Vectors.dense(0.795, -0.226)) + )).toDF("label", "censor", "features") + val quantileProbabilities = Array(0.3, 0.6) + val aft = new AFTSurvivalRegression() + .setQuantileProbabilities(quantileProbabilities) + .setQuantilesCol("quantiles") + + val model = aft.fit(training) + + // Print the coefficients, intercept and scale parameter for AFT survival regression + println(s"Coefficients: ${model.coefficients} Intercept: " + + s"${model.intercept} Scale: ${model.scale}") + model.transform(training).show(false) + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala new file mode 100644 index 000000000000..e724aa587294 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.Binarizer +// $example off$ +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.{SparkConf, SparkContext} + +object BinarizerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("BinarizerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + // $example on$ + val data = Array((0, 0.1), (1, 0.8), (2, 0.2)) + val dataFrame: DataFrame = sqlContext.createDataFrame(data).toDF("label", "feature") + + val binarizer: Binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + .setThreshold(0.5) + + val binarizedDataFrame = binarizer.transform(dataFrame) + val binarizedFeatures = binarizedDataFrame.select("binarized_feature") + binarizedFeatures.collect().foreach(println) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala new file mode 100644 index 000000000000..7c75e3d72b47 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.Bucketizer +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object BucketizerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("BucketizerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity) + + val data = Array(-0.5, -0.3, 0.0, 0.2) + val dataFrame = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + + val bucketizer = new Bucketizer() + .setInputCol("features") + .setOutputCol("bucketedFeatures") + .setSplits(splits) + + // Transform original data into its bucket index. 
+ val bucketedData = bucketizer.transform(dataFrame) + bucketedData.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala new file mode 100644 index 000000000000..a8d2bc4907e8 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.ChiSqSelector +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object ChiSqSelectorExample { + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("ChiSqSelectorExample") + val sc = new SparkContext(conf) + + val sqlContext = SQLContext.getOrCreate(sc) + import sqlContext.implicits._ + + // $example on$ + val data = Seq( + (7, Vectors.dense(0.0, 0.0, 18.0, 1.0), 1.0), + (8, Vectors.dense(0.0, 1.0, 12.0, 0.0), 0.0), + (9, Vectors.dense(1.0, 0.0, 15.0, 0.1), 0.0) + ) + + val df = sc.parallelize(data).toDF("id", "features", "clicked") + + val selector = new ChiSqSelector() + .setNumTopFeatures(1) + .setFeaturesCol("features") + .setLabelCol("clicked") + .setOutputCol("selectedFeatures") + + val result = selector.fit(df).transform(df) + result.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CountVectorizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CountVectorizerExample.scala new file mode 100644 index 000000000000..ba916f66c4c0 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/CountVectorizerExample.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel} +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + + +object CountVectorizerExample { + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("CounterVectorizerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val df = sqlContext.createDataFrame(Seq( + (0, Array("a", "b", "c")), + (1, Array("a", "b", "b", "c", "a")) + )).toDF("id", "words") + + // fit a CountVectorizerModel from the corpus + val cvModel: CountVectorizerModel = new CountVectorizer() + .setInputCol("words") + .setOutputCol("features") + .setVocabSize(3) + .setMinDF(2) + .fit(df) + + // alternatively, define CountVectorizerModel with a-priori vocabulary + val cvm = new CountVectorizerModel(Array("a", "b", "c")) + .setInputCol("words") + .setOutputCol("features") + + cvModel.transform(df).select("features").show() + // $example off$ + } +} +// scalastyle:on println + + diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala new file mode 100644 index 000000000000..314c2c28a2a1 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.DCT +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object DCTExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("DCTExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val data = Seq( + Vectors.dense(0.0, 1.0, -2.0, 3.0), + Vectors.dense(-1.0, 2.0, 4.0, -7.0), + Vectors.dense(14.0, -2.0, -5.0, 1.0)) + + val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + + val dct = new DCT() + .setInputCol("features") + .setOutputCol("featuresDCT") + .setInverse(false) + + val dctDf = dct.transform(df) + dctDf.select("featuresDCT").show(3) + // $example off$ + sc.stop() + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala new file mode 100644 index 000000000000..0a477abae567 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import java.io.File + +import com.google.common.io.Files +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.examples.mllib.AbstractParams +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer +import org.apache.spark.sql.{DataFrame, Row, SQLContext} + +/** + * An example of how to use [[org.apache.spark.sql.DataFrame]] for ML. Run with + * {{{ + * ./bin/run-example ml.DataFrameExample [options] + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. 
+ */ +object DataFrameExample { + + case class Params(input: String = "data/mllib/sample_libsvm_data.txt") + extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("DataFrameExample") { + head("DataFrameExample: an example app using DataFrame for ML.") + opt[String]("input") + .text(s"input path to dataframe") + .action((x, c) => c.copy(input = x)) + checkConfig { params => + success + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + + val conf = new SparkConf().setAppName(s"DataFrameExample with $params") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // Load input data + println(s"Loading LIBSVM file with UDT from ${params.input}.") + val df: DataFrame = sqlContext.read.format("libsvm").load(params.input).cache() + println("Schema from LIBSVM:") + df.printSchema() + println(s"Loaded training data as a DataFrame with ${df.count()} records.") + + // Show statistical summary of labels. + val labelSummary = df.describe("label") + labelSummary.show() + + // Convert features column to an RDD of vectors. + val features = df.select("features").map { case Row(v: Vector) => v } + val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())( + (summary, feat) => summary.add(feat), + (sum1, sum2) => sum1.merge(sum2)) + println(s"Selected features column with average values:\n ${featureSummary.mean.toString}") + + // Save the records in a parquet file. + val tmpDir = Files.createTempDir() + tmpDir.deleteOnExit() + val outputDir = new File(tmpDir, "dataframe").toString + println(s"Saving to $outputDir as Parquet file.") + df.write.parquet(outputDir) + + // Load the records back. + println(s"Loading Parquet file with UDT from $outputDir.") + val newDF = sqlContext.read.parquet(outputDir) + println(s"Schema from Parquet:") + newDF.printSchema() + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala new file mode 100644 index 000000000000..db024b5cad93 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkContext, SparkConf} +// $example on$ +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.classification.DecisionTreeClassifier +import org.apache.spark.ml.classification.DecisionTreeClassificationModel +import org.apache.spark.ml.feature.{StringIndexer, IndexToString, VectorIndexer} +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator +// $example off$ + +object DecisionTreeClassificationExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("DecisionTreeClassificationExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + // $example on$ + // Load the data stored in LIBSVM format as a DataFrame. + val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + // Index labels, adding metadata to the label column. + // Fit on whole dataset to include all labels in index. + val labelIndexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("indexedLabel") + .fit(data) + // Automatically identify categorical features, and index them. + val featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) // features with > 4 distinct values are treated as continuous + .fit(data) + + // Split the data into training and test sets (30% held out for testing) + val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) + + // Train a DecisionTree model. + val dt = new DecisionTreeClassifier() + .setLabelCol("indexedLabel") + .setFeaturesCol("indexedFeatures") + + // Convert indexed labels back to original labels. + val labelConverter = new IndexToString() + .setInputCol("prediction") + .setOutputCol("predictedLabel") + .setLabels(labelIndexer.labels) + + // Chain indexers and tree in a Pipeline + val pipeline = new Pipeline() + .setStages(Array(labelIndexer, featureIndexer, dt, labelConverter)) + + // Train model. This also runs the indexers. + val model = pipeline.fit(trainingData) + + // Make predictions. + val predictions = model.transform(testData) + + // Select example rows to display. 
+ predictions.select("predictedLabel", "label", "features").show(5) + + // Select (prediction, true label) and compute test error + val evaluator = new MulticlassClassificationEvaluator() + .setLabelCol("indexedLabel") + .setPredictionCol("prediction") + .setMetricName("precision") + val accuracy = evaluator.evaluate(predictions) + println("Test Error = " + (1.0 - accuracy)) + + val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel] + println("Learned classification tree model:\n" + treeModel.toDebugString) + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala index f28671f7869f..c4e98dfaca6c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala @@ -32,10 +32,7 @@ import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTree import org.apache.spark.ml.util.MetadataUtils import org.apache.spark.mllib.evaluation.{RegressionMetrics, MulticlassMetrics} import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.types.StringType import org.apache.spark.sql.{SQLContext, DataFrame} @@ -138,15 +135,18 @@ object DecisionTreeExample { /** Load a dataset from the given path, using the given format */ private[ml] def loadData( - sc: SparkContext, + sqlContext: SQLContext, path: String, format: String, - expectedNumFeatures: Option[Int] = None): RDD[LabeledPoint] = { + expectedNumFeatures: Option[Int] = None): DataFrame = { + import sqlContext.implicits._ + format match { - case "dense" => MLUtils.loadLabeledPoints(sc, path) + case "dense" => MLUtils.loadLabeledPoints(sqlContext.sparkContext, path).toDF() case "libsvm" => expectedNumFeatures match { - case Some(numFeatures) => MLUtils.loadLibSVMFile(sc, path, numFeatures) - case None => MLUtils.loadLibSVMFile(sc, path) + case Some(numFeatures) => sqlContext.read.option("numFeatures", numFeatures.toString) + .format("libsvm").load(path) + case None => sqlContext.read.format("libsvm").load(path) } case _ => throw new IllegalArgumentException(s"Bad data format: $format") } @@ -169,36 +169,22 @@ object DecisionTreeExample { algo: String, fracTest: Double): (DataFrame, DataFrame) = { val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ // Load training data - val origExamples: RDD[LabeledPoint] = loadData(sc, input, dataFormat) + val origExamples: DataFrame = loadData(sqlContext, input, dataFormat) // Load or create test set - val splits: Array[RDD[LabeledPoint]] = if (testInput != "") { + val dataframes: Array[DataFrame] = if (testInput != "") { // Load testInput. - val numFeatures = origExamples.take(1)(0).features.size - val origTestExamples: RDD[LabeledPoint] = - loadData(sc, testInput, dataFormat, Some(numFeatures)) + val numFeatures = origExamples.first().getAs[Vector](1).size + val origTestExamples: DataFrame = + loadData(sqlContext, testInput, dataFormat, Some(numFeatures)) Array(origExamples, origTestExamples) } else { // Split input into training, test. origExamples.randomSplit(Array(1.0 - fracTest, fracTest), seed = 12345) } - // For classification, convert labels to Strings since we will index them later with - // StringIndexer. 
- def labelsToStrings(data: DataFrame): DataFrame = { - algo.toLowerCase match { - case "classification" => - data.withColumn("labelString", data("label").cast(StringType)) - case "regression" => - data - case _ => - throw new IllegalArgumentException("Algo ${params.algo} not supported.") - } - } - val dataframes = splits.map(_.toDF()).map(labelsToStrings) val training = dataframes(0).cache() val test = dataframes(1).cache() @@ -230,7 +216,7 @@ object DecisionTreeExample { val labelColName = if (algo == "classification") "indexedLabel" else "label" if (algo == "classification") { val labelIndexer = new StringIndexer() - .setInputCol("labelString") + .setInputCol("label") .setOutputCol(labelColName) stages += labelIndexer } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala new file mode 100644 index 000000000000..ad01f55df72b --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkContext, SparkConf} +// $example on$ +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.regression.DecisionTreeRegressor +import org.apache.spark.ml.regression.DecisionTreeRegressionModel +import org.apache.spark.ml.feature.VectorIndexer +import org.apache.spark.ml.evaluation.RegressionEvaluator +// $example off$ +object DecisionTreeRegressionExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("DecisionTreeRegressionExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Load the data stored in LIBSVM format as a DataFrame. + val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + // Automatically identify categorical features, and index them. + // Here, we treat features with > 4 distinct values as continuous. + val featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data) + + // Split the data into training and test sets (30% held out for testing) + val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) + + // Train a DecisionTree model. + val dt = new DecisionTreeRegressor() + .setLabelCol("label") + .setFeaturesCol("indexedFeatures") + + // Chain indexer and tree in a Pipeline + val pipeline = new Pipeline() + .setStages(Array(featureIndexer, dt)) + + // Train model. This also runs the indexer. 
+ val model = pipeline.fit(trainingData) + + // Make predictions. + val predictions = model.transform(testData) + + // Select example rows to display. + predictions.select("prediction", "label", "features").show(5) + + // Select (prediction, true label) and compute test error + val evaluator = new RegressionEvaluator() + .setLabelCol("label") + .setPredictionCol("prediction") + .setMetricName("rmse") + val rmse = evaluator.evaluate(predictions) + println("Root Mean Squared Error (RMSE) on test data = " + rmse) + + val treeModel = model.stages(1).asInstanceOf[DecisionTreeRegressionModel] + println("Learned regression tree model:\n" + treeModel.toDebugString) + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index 78f31b4ffe56..c1f63c6a1dce 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -75,7 +75,7 @@ object DeveloperApiExample { prediction }.sum assert(sumPredictions == 0.0, - "MyLogisticRegression predicted something other than 0, even though all weights are 0!") + "MyLogisticRegression predicted something other than 0, even though all coefficients are 0!") sc.stop() } @@ -124,12 +124,12 @@ private class MyLogisticRegression(override val uid: String) // Extract columns from data using helper method. val oldDataset = extractLabeledPoints(dataset) - // Do learning to estimate the weight vector. + // Do learning to estimate the coefficients vector. val numFeatures = oldDataset.take(1)(0).features.size - val weights = Vectors.zeros(numFeatures) // Learning would happen here. + val coefficients = Vectors.zeros(numFeatures) // Learning would happen here. // Create a model, and return it. - new MyLogisticRegressionModel(uid, weights).setParent(this) + new MyLogisticRegressionModel(uid, coefficients).setParent(this) } override def copy(extra: ParamMap): MyLogisticRegression = defaultCopy(extra) @@ -142,7 +142,7 @@ private class MyLogisticRegression(override val uid: String) */ private class MyLogisticRegressionModel( override val uid: String, - val weights: Vector) + val coefficients: Vector) extends ClassificationModel[Vector, MyLogisticRegressionModel] with MyLogisticRegressionParams { @@ -163,7 +163,7 @@ private class MyLogisticRegressionModel( * confidence for that label. */ override protected def predictRaw(features: Vector): Vector = { - val margin = BLAS.dot(features, weights) + val margin = BLAS.dot(features, coefficients) // There are 2 classes (binary classification), so we return a length-2 vector, // where index i corresponds to class i (i = 0, 1). Vectors.dense(-margin, margin) @@ -172,6 +172,9 @@ private class MyLogisticRegressionModel( /** Number of classes the label can take. 2 indicates binary classification. */ override val numClasses: Int = 2 + /** Number of features the model was trained on. */ + override val numFeatures: Int = coefficients.size + /** * Create a copy of the model. * The copy is shallow, except for the embedded paramMap, which gets a deep copy. @@ -179,7 +182,7 @@ private class MyLogisticRegressionModel( * This is used for the default implementation of [[transform()]]. 
*/ override def copy(extra: ParamMap): MyLogisticRegressionModel = { - copyValues(new MyLogisticRegressionModel(uid, weights), extra) + copyValues(new MyLogisticRegressionModel(uid, coefficients), extra).setParent(parent) } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ElementwiseProductExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ElementwiseProductExample.scala new file mode 100644 index 000000000000..872de51dc75d --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ElementwiseProductExample.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.ElementwiseProduct +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object ElementwiseProductExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("ElementwiseProductExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Create some vector data; also works for sparse vectors + val dataFrame = sqlContext.createDataFrame(Seq( + ("a", Vectors.dense(1.0, 2.0, 3.0)), + ("b", Vectors.dense(4.0, 5.0, 6.0)))).toDF("id", "vector") + + val transformingVector = Vectors.dense(0.0, 1.0, 2.0) + val transformer = new ElementwiseProduct() + .setScalingVec(transformingVector) + .setInputCol("vector") + .setOutputCol("transformedVector") + + // Batch transform the vectors to create new column: + transformer.transform(dataFrame).show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala index f4a15f806ea8..6b0be0f34e19 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala @@ -153,7 +153,7 @@ object GBTExample { val labelColName = if (algo == "classification") "indexedLabel" else "label" if (algo == "classification") { val labelIndexer = new StringIndexer() - .setInputCol("labelString") + .setInputCol("label") .setOutputCol(labelColName) stages += labelIndexer } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala new file mode 100644 index 000000000000..474af7db4b49 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala 
@@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier} +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator +import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer} +// $example off$ + +object GradientBoostedTreeClassifierExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("GradientBoostedTreeClassifierExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Load and parse the data file, converting it to a DataFrame. + val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + // Index labels, adding metadata to the label column. + // Fit on whole dataset to include all labels in index. + val labelIndexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("indexedLabel") + .fit(data) + // Automatically identify categorical features, and index them. + // Set maxCategories so features with > 4 distinct values are treated as continuous. + val featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data) + + // Split the data into training and test sets (30% held out for testing) + val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) + + // Train a GBT model. + val gbt = new GBTClassifier() + .setLabelCol("indexedLabel") + .setFeaturesCol("indexedFeatures") + .setMaxIter(10) + + // Convert indexed labels back to original labels. + val labelConverter = new IndexToString() + .setInputCol("prediction") + .setOutputCol("predictedLabel") + .setLabels(labelIndexer.labels) + + // Chain indexers and GBT in a Pipeline + val pipeline = new Pipeline() + .setStages(Array(labelIndexer, featureIndexer, gbt, labelConverter)) + + // Train model. This also runs the indexers. + val model = pipeline.fit(trainingData) + + // Make predictions. + val predictions = model.transform(testData) + + // Select example rows to display. 
+ predictions.select("predictedLabel", "label", "features").show(5) + + // Select (prediction, true label) and compute test error + val evaluator = new MulticlassClassificationEvaluator() + .setLabelCol("indexedLabel") + .setPredictionCol("prediction") + .setMetricName("precision") + val accuracy = evaluator.evaluate(predictions) + println("Test Error = " + (1.0 - accuracy)) + + val gbtModel = model.stages(2).asInstanceOf[GBTClassificationModel] + println("Learned classification GBT model:\n" + gbtModel.toDebugString) + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala new file mode 100644 index 000000000000..da1cd9c2ce52 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.evaluation.RegressionEvaluator +import org.apache.spark.ml.feature.VectorIndexer +import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor} +// $example off$ + +object GradientBoostedTreeRegressorExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("GradientBoostedTreeRegressorExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Load and parse the data file, converting it to a DataFrame. + val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + // Automatically identify categorical features, and index them. + // Set maxCategories so features with > 4 distinct values are treated as continuous. + val featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data) + + // Split the data into training and test sets (30% held out for testing) + val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) + + // Train a GBT model. + val gbt = new GBTRegressor() + .setLabelCol("label") + .setFeaturesCol("indexedFeatures") + .setMaxIter(10) + + // Chain indexer and GBT in a Pipeline + val pipeline = new Pipeline() + .setStages(Array(featureIndexer, gbt)) + + // Train model. This also runs the indexer. + val model = pipeline.fit(trainingData) + + // Make predictions. + val predictions = model.transform(testData) + + // Select example rows to display. 
+ predictions.select("prediction", "label", "features").show(5) + + // Select (prediction, true label) and compute test error + val evaluator = new RegressionEvaluator() + .setLabelCol("label") + .setPredictionCol("prediction") + .setMetricName("rmse") + val rmse = evaluator.evaluate(predictions) + println("Root Mean Squared Error (RMSE) on test data = " + rmse) + + val gbtModel = model.stages(1).asInstanceOf[GBTRegressionModel] + println("Learned regression GBT model:\n" + gbtModel.toDebugString) + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala new file mode 100644 index 000000000000..52537e5bb568 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.feature.{StringIndexer, IndexToString} +// $example off$ + +object IndexToStringExample { + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("IndexToStringExample") + val sc = new SparkContext(conf) + + val sqlContext = SQLContext.getOrCreate(sc) + + // $example on$ + val df = sqlContext.createDataFrame(Seq( + (0, "a"), + (1, "b"), + (2, "c"), + (3, "a"), + (4, "a"), + (5, "c") + )).toDF("id", "category") + + val indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") + .fit(df) + val indexed = indexer.transform(df) + + val converter = new IndexToString() + .setInputCol("categoryIndex") + .setOutputCol("originalCategory") + + val converted = converter.transform(indexed) + converted.select("id", "originalCategory").show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala index 5ce38462d118..af90652b55a1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala @@ -17,57 +17,54 @@ package org.apache.spark.examples.ml -import org.apache.spark.{SparkContext, SparkConf} -import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} -import org.apache.spark.ml.clustering.KMeans -import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.sql.types.{StructField, StructType} +// scalastyle:off println +import org.apache.spark.{SparkConf, SparkContext} +// $example 
on$ +import org.apache.spark.ml.clustering.KMeans +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.{DataFrame, SQLContext} /** * An example demonstrating a k-means clustering. * Run with * {{{ - * bin/run-example ml.KMeansExample + * bin/run-example ml.KMeansExample + * }}} */ object KMeansExample { - final val FEATURES_COL = "features" - def main(args: Array[String]): Unit = { - if (args.length != 2) { - // scalastyle:off println - System.err.println("Usage: ml.KMeansExample ") - // scalastyle:on println - System.exit(1) - } - val input = args(0) - val k = args(1).toInt - // Creates a Spark context and a SQL context val conf = new SparkConf().setAppName(s"${this.getClass.getSimpleName}") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) - // Loads data - val rowRDD = sc.textFile(input).filter(_.nonEmpty) - .map(_.split(" ").map(_.toDouble)).map(Vectors.dense).map(Row(_)) - val schema = StructType(Array(StructField(FEATURES_COL, new VectorUDT, false))) - val dataset = sqlContext.createDataFrame(rowRDD, schema) + // $example on$ + // Creates a DataFrame + val dataset: DataFrame = sqlContext.createDataFrame(Seq( + (1, Vectors.dense(0.0, 0.0, 0.0)), + (2, Vectors.dense(0.1, 0.1, 0.1)), + (3, Vectors.dense(0.2, 0.2, 0.2)), + (4, Vectors.dense(9.0, 9.0, 9.0)), + (5, Vectors.dense(9.1, 9.1, 9.1)), + (6, Vectors.dense(9.2, 9.2, 9.2)) + )).toDF("id", "features") // Trains a k-means model val kmeans = new KMeans() - .setK(k) - .setFeaturesCol(FEATURES_COL) + .setK(2) + .setFeaturesCol("features") + .setPredictionCol("prediction") val model = kmeans.fit(dataset) // Shows the result - // scalastyle:off println println("Final Centers: ") model.clusterCenters.foreach(println) - // scalastyle:on println + // $example off$ sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala new file mode 100644 index 000000000000..419ce3d87a6a --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +// scalastyle:off println +import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} +// $example on$ +import org.apache.spark.ml.clustering.LDA +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.types.{StructField, StructType} +// $example off$ + +/** + * An example demonstrating LDA in the ML pipeline. 
+ * Run with + * {{{ + * bin/run-example ml.LDAExample + * }}} + */ +object LDAExample { + + final val FEATURES_COL = "features" + + def main(args: Array[String]): Unit = { + + val input = "data/mllib/sample_lda_data.txt" + // Creates a Spark context and a SQL context + val conf = new SparkConf().setAppName(s"${this.getClass.getSimpleName}") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Loads data + val rowRDD = sc.textFile(input).filter(_.nonEmpty) + .map(_.split(" ").map(_.toDouble)).map(Vectors.dense).map(Row(_)) + val schema = StructType(Array(StructField(FEATURES_COL, new VectorUDT, false))) + val dataset = sqlContext.createDataFrame(rowRDD, schema) + + // Trains an LDA model + val lda = new LDA() + .setK(10) + .setMaxIter(10) + .setFeaturesCol(FEATURES_COL) + val model = lda.fit(dataset) + val transformed = model.transform(dataset) + + val ll = model.logLikelihood(dataset) + val lp = model.logPerplexity(dataset) + + // describeTopics + val topics = model.describeTopics(3) + + // Shows the result + topics.show(false) + transformed.show(false) + + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala index b73299fb12d3..50998c94de3d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala @@ -131,7 +131,7 @@ object LinearRegressionExample { println(s"Training time: $elapsedTime seconds") // Print the weights and intercept for linear regression. - println(s"Weights: ${lirModel.weights} Intercept: ${lirModel.intercept}") + println(s"Weights: ${lirModel.coefficients} Intercept: ${lirModel.intercept}") println("Training data results:") DecisionTreeExample.evaluateRegressionModel(lirModel, training, "label") diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala new file mode 100644 index 000000000000..22c824cea84d --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.regression.LinearRegression +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object LinearRegressionWithElasticNetExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("LinearRegressionWithElasticNetExample") + val sc = new SparkContext(conf) + val sqlCtx = new SQLContext(sc) + + // $example on$ + // Load training data + val training = sqlCtx.read.format("libsvm") + .load("data/mllib/sample_linear_regression_data.txt") + + val lr = new LinearRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8) + + // Fit the model + val lrModel = lr.fit(training) + + // Print the coefficients and intercept for linear regression + println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}") + + // Summarize the model over the training set and print out some metrics + val trainingSummary = lrModel.summary + println(s"numIterations: ${trainingSummary.totalIterations}") + println(s"objectiveHistory: ${trainingSummary.objectiveHistory.toList}") + trainingSummary.residuals.show() + println(s"RMSE: ${trainingSummary.rootMeanSquaredError}") + println(s"r2: ${trainingSummary.r2}") + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala index 7682557127b5..a380c90662a5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala @@ -125,7 +125,7 @@ object LogisticRegressionExample { val stages = new mutable.ArrayBuffer[PipelineStage]() val labelIndexer = new StringIndexer() - .setInputCol("labelString") + .setInputCol("label") .setOutputCol("indexedLabel") stages += labelIndexer @@ -136,6 +136,7 @@ object LogisticRegressionExample { .setElasticNetParam(params.elasticNetParam) .setMaxIter(params.maxIter) .setTol(params.tol) + .setFitIntercept(params.fitIntercept) stages += lor val pipeline = new Pipeline().setStages(stages.toArray) @@ -148,7 +149,7 @@ object LogisticRegressionExample { val lorModel = pipelineModel.stages.last.asInstanceOf[LogisticRegressionModel] // Print the weights and intercept for logistic regression. - println(s"Weights: ${lorModel.weights} Intercept: ${lorModel.intercept}") + println(s"Weights: ${lorModel.coefficients} Intercept: ${lorModel.intercept}") println("Training data results:") DecisionTreeExample.evaluateClassificationModel(pipelineModel, training, "indexedLabel") diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala new file mode 100644 index 000000000000..4c420421b670 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression} +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.functions.max +import org.apache.spark.{SparkConf, SparkContext} + +object LogisticRegressionSummaryExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("LogisticRegressionSummaryExample") + val sc = new SparkContext(conf) + val sqlCtx = new SQLContext(sc) + import sqlCtx.implicits._ + + // Load training data + val training = sqlCtx.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + val lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8) + + // Fit the model + val lrModel = lr.fit(training) + + // $example on$ + // Extract the summary from the returned LogisticRegressionModel instance trained in the earlier + // example + val trainingSummary = lrModel.summary + + // Obtain the objective per iteration. + val objectiveHistory = trainingSummary.objectiveHistory + objectiveHistory.foreach(loss => println(loss)) + + // Obtain the metrics useful to judge performance on test data. + // We cast the summary to a BinaryLogisticRegressionSummary since the problem is a + // binary classification problem. + val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary] + + // Obtain the receiver-operating characteristic as a dataframe and areaUnderROC. + val roc = binarySummary.roc + roc.show() + println(binarySummary.areaUnderROC) + + // Set the model threshold to maximize F-Measure + val fMeasure = binarySummary.fMeasureByThreshold + val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0) + val bestThreshold = fMeasure.where($"F-Measure" === maxFMeasure) + .select("threshold").head().getDouble(0) + lrModel.setThreshold(bestThreshold) + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala new file mode 100644 index 000000000000..9ee995b52c90 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.classification.LogisticRegression +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object LogisticRegressionWithElasticNetExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("LogisticRegressionWithElasticNetExample") + val sc = new SparkContext(conf) + val sqlCtx = new SQLContext(sc) + + // $example on$ + // Load training data + val training = sqlCtx.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + val lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8) + + // Fit the model + val lrModel = lr.fit(training) + + // Print the coefficients and intercept for logistic regression + println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}") + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala new file mode 100644 index 000000000000..fb7f28c9886b --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.MinMaxScaler +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object MinMaxScalerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("MinMaxScalerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + val scaler = new MinMaxScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + + // Compute summary statistics and generate MinMaxScalerModel + val scalerModel = scaler.fit(dataFrame) + + // rescale each feature to range [min, max]. 
+ val scaledData = scalerModel.transform(dataFrame) + scaledData.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala index cd411397a4b9..02ed746954f2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala @@ -50,7 +50,7 @@ object MovieLensALS { def parseMovie(str: String): Movie = { val fields = str.split("::") assert(fields.size == 3) - Movie(fields(0).toInt, fields(1), fields(2).split("|")) + Movie(fields(0).toInt, fields(1), fields(2).split("\\|")) } } @@ -76,7 +76,7 @@ object MovieLensALS { .text("path to a MovieLens dataset of movies") .action((x, c) => c.copy(movies = x)) opt[Int]("rank") - .text(s"rank, default: ${defaultParams.rank}}") + .text(s"rank, default: ${defaultParams.rank}") .action((x, c) => c.copy(rank = x)) opt[Int]("maxIter") .text(s"max number of iterations, default: ${defaultParams.maxIter}") diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala new file mode 100644 index 000000000000..9c98076bd24b --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.sql.SQLContext +// $example on$ +import org.apache.spark.ml.classification.MultilayerPerceptronClassifier +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator +// $example off$ + +/** + * An example for Multilayer Perceptron Classification. + */ +object MultilayerPerceptronClassifierExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("MultilayerPerceptronClassifierExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Load the data stored in LIBSVM format as a DataFrame. 
+ val data = sqlContext.read.format("libsvm") + .load("data/mllib/sample_multiclass_classification_data.txt") + // Split the data into training and test sets + val splits = data.randomSplit(Array(0.6, 0.4), seed = 1234L) + val train = splits(0) + val test = splits(1) + // specify layers for the neural network: + // input layer of size 4 (features), two intermediate layers of size 5 and 4, + // and an output layer of size 3 (classes) + val layers = Array[Int](4, 5, 4, 3) + // create the trainer and set its parameters + val trainer = new MultilayerPerceptronClassifier() + .setLayers(layers) + .setBlockSize(128) + .setSeed(1234L) + .setMaxIter(100) + // train the model + val model = trainer.fit(train) + // compute precision on the test set + val result = model.transform(test) + val predictionAndLabels = result.select("prediction", "label") + val evaluator = new MulticlassClassificationEvaluator() + .setMetricName("precision") + println("Precision:" + evaluator.evaluate(predictionAndLabels)) + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala new file mode 100644 index 000000000000..8a85f71b56f3 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.NGram +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object NGramExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("NGramExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val wordDataFrame = sqlContext.createDataFrame(Seq( + (0, Array("Hi", "I", "heard", "about", "Spark")), + (1, Array("I", "wish", "Java", "could", "use", "case", "classes")), + (2, Array("Logistic", "regression", "models", "are", "neat")) + )).toDF("label", "words") + + val ngram = new NGram().setInputCol("words").setOutputCol("ngrams") + val ngramDataFrame = ngram.transform(wordDataFrame) + ngramDataFrame.take(3).map(_.getAs[Stream[String]]("ngrams").toList).foreach(println) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala new file mode 100644 index 000000000000..1990b55e8c5e --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.Normalizer +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object NormalizerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("NormalizerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + // Normalize each Vector using $L^1$ norm. + val normalizer = new Normalizer() + .setInputCol("features") + .setOutputCol("normFeatures") + .setP(1.0) + + val l1NormData = normalizer.transform(dataFrame) + l1NormData.show() + + // Normalize each Vector using $L^\infty$ norm. 
+ val lInfNormData = normalizer.transform(dataFrame, normalizer.p -> Double.PositiveInfinity) + lInfNormData.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala new file mode 100644 index 000000000000..66602e211850 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer} +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object OneHotEncoderExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("OneHotEncoderExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val df = sqlContext.createDataFrame(Seq( + (0, "a"), + (1, "b"), + (2, "c"), + (3, "a"), + (4, "a"), + (5, "c") + )).toDF("id", "category") + + val indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") + .fit(df) + val indexed = indexer.transform(df) + + val encoder = new OneHotEncoder() + .setInputCol("categoryIndex") + .setOutputCol("categoryVec") + val encoded = encoder.transform(indexed) + encoded.select("id", "categoryVec").show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala index bab31f585b0e..b46faea5713f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala @@ -23,13 +23,14 @@ import java.util.concurrent.TimeUnit.{NANOSECONDS => NANO} import scopt.OptionParser import org.apache.spark.{SparkContext, SparkConf} +// $example on$ import org.apache.spark.examples.mllib.AbstractParams import org.apache.spark.ml.classification.{OneVsRest, LogisticRegression} import org.apache.spark.ml.util.MetadataUtils import org.apache.spark.mllib.evaluation.MulticlassMetrics -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.sql.DataFrame +// $example off$ import org.apache.spark.sql.SQLContext /** @@ -111,24 +112,25 @@ object OneVsRestExample { private def run(params: Params) { val conf = new 
SparkConf().setAppName(s"OneVsRestExample with $params") val sc = new SparkContext(conf) - val inputData = MLUtils.loadLibSVMFile(sc, params.input) val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ + // $example on$ + val inputData = sqlContext.read.format("libsvm").load(params.input) // compute the train/test split: if testInput is not provided use part of input. val data = params.testInput match { case Some(t) => { // compute the number of features in the training set. - val numFeatures = inputData.first().features.size - val testData = MLUtils.loadLibSVMFile(sc, t, numFeatures) - Array[RDD[LabeledPoint]](inputData, testData) + val numFeatures = inputData.first().getAs[Vector](1).size + val testData = sqlContext.read.option("numFeatures", numFeatures.toString) + .format("libsvm").load(t) + Array[DataFrame](inputData, testData) } case None => { val f = params.fracTest inputData.randomSplit(Array(1 - f, f), seed = 12345) } } - val Array(train, test) = data.map(_.toDF().cache()) + val Array(train, test) = data.map(_.cache()) // instantiate the base classifier val classifier = new LogisticRegression() @@ -173,6 +175,7 @@ object OneVsRestExample { println("label\tfpr") println(fprs.map {case (label, fpr) => label + "\t" + fpr}.mkString("\n")) + // $example off$ sc.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala new file mode 100644 index 000000000000..4c806f71a32c --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.PCA +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object PCAExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("PCAExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val data = Array( + Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) + ) + val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + val pca = new PCA() + .setInputCol("features") + .setOutputCol("pcaFeatures") + .setK(3) + .fit(df) + val pcaDF = pca.transform(df) + val result = pcaDF.select("pcaFeatures") + result.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala new file mode 100644 index 000000000000..39fb79af3576 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.PolynomialExpansion +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object PolynomialExpansionExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("PolynomialExpansionExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val data = Array( + Vectors.dense(-2.0, 2.3), + Vectors.dense(0.0, 0.0), + Vectors.dense(0.6, -1.1) + ) + val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + val polynomialExpansion = new PolynomialExpansion() + .setInputCol("features") + .setOutputCol("polyFeatures") + .setDegree(3) + val polyDF = polynomialExpansion.transform(df) + polyDF.select("polyFeatures").take(3).foreach(println) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala new file mode 100644 index 000000000000..8f29b7eaa6d2 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.QuantileDiscretizer +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object QuantileDiscretizerExample { + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("QuantileDiscretizerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // $example on$ + val data = Array((0, 18.0), (1, 19.0), (2, 8.0), (3, 5.0), (4, 2.2)) + val df = sc.parallelize(data).toDF("id", "hour") + + val discretizer = new QuantileDiscretizer() + .setInputCol("hour") + .setOutputCol("result") + .setNumBuckets(3) + + val result = discretizer.fit(df).transform(df) + result.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala new file mode 100644 index 000000000000..286866edea50 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.RFormula +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object RFormulaExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("RFormulaExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataset = sqlContext.createDataFrame(Seq( + (7, "US", 18, 1.0), + (8, "CA", 12, 0.0), + (9, "NZ", 15, 0.0) + )).toDF("id", "country", "hour", "clicked") + val formula = new RFormula() + .setFormula("clicked ~ country + hour") + .setFeaturesCol("features") + .setLabelCol("label") + val output = formula.fit(dataset).transform(dataset) + output.select("features", "label").show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala new file mode 100644 index 000000000000..e79176ca6ca1 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier} +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator +import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer} +// $example off$ + +object RandomForestClassifierExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("RandomForestClassifierExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Load and parse the data file, converting it to a DataFrame. + val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + // Index labels, adding metadata to the label column. + // Fit on whole dataset to include all labels in index. + val labelIndexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("indexedLabel") + .fit(data) + // Automatically identify categorical features, and index them. + // Set maxCategories so features with > 4 distinct values are treated as continuous. + val featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data) + + // Split the data into training and test sets (30% held out for testing) + val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) + + // Train a RandomForest model. + val rf = new RandomForestClassifier() + .setLabelCol("indexedLabel") + .setFeaturesCol("indexedFeatures") + .setNumTrees(10) + + // Convert indexed labels back to original labels. + val labelConverter = new IndexToString() + .setInputCol("prediction") + .setOutputCol("predictedLabel") + .setLabels(labelIndexer.labels) + + // Chain indexers and forest in a Pipeline + val pipeline = new Pipeline() + .setStages(Array(labelIndexer, featureIndexer, rf, labelConverter)) + + // Train model. This also runs the indexers. + val model = pipeline.fit(trainingData) + + // Make predictions. + val predictions = model.transform(testData) + + // Select example rows to display. 
+ predictions.select("predictedLabel", "label", "features").show(5) + + // Select (prediction, true label) and compute test error + val evaluator = new MulticlassClassificationEvaluator() + .setLabelCol("indexedLabel") + .setPredictionCol("prediction") + .setMetricName("precision") + val accuracy = evaluator.evaluate(predictions) + println("Test Error = " + (1.0 - accuracy)) + + val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel] + println("Learned classification forest model:\n" + rfModel.toDebugString) + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala index 109178f4137b..7a00d99dfe53 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala @@ -159,7 +159,7 @@ object RandomForestExample { val labelColName = if (algo == "classification") "indexedLabel" else "label" if (algo == "classification") { val labelIndexer = new StringIndexer() - .setInputCol("labelString") + .setInputCol("label") .setOutputCol(labelColName) stages += labelIndexer } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala new file mode 100644 index 000000000000..acec1437a1af --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.evaluation.RegressionEvaluator +import org.apache.spark.ml.feature.VectorIndexer +import org.apache.spark.ml.regression.{RandomForestRegressionModel, RandomForestRegressor} +// $example off$ + +object RandomForestRegressorExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("RandomForestRegressorExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Load and parse the data file, converting it to a DataFrame. + val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + // Automatically identify categorical features, and index them. + // Set maxCategories so features with > 4 distinct values are treated as continuous. 
+ val featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data) + + // Split the data into training and test sets (30% held out for testing) + val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) + + // Train a RandomForest model. + val rf = new RandomForestRegressor() + .setLabelCol("label") + .setFeaturesCol("indexedFeatures") + + // Chain indexer and forest in a Pipeline + val pipeline = new Pipeline() + .setStages(Array(featureIndexer, rf)) + + // Train model. This also runs the indexer. + val model = pipeline.fit(trainingData) + + // Make predictions. + val predictions = model.transform(testData) + + // Select example rows to display. + predictions.select("prediction", "label", "features").show(5) + + // Select (prediction, true label) and compute test error + val evaluator = new RegressionEvaluator() + .setLabelCol("label") + .setPredictionCol("prediction") + .setMetricName("rmse") + val rmse = evaluator.evaluate(predictions) + println("Root Mean Squared Error (RMSE) on test data = " + rmse) + + val rfModel = model.stages(1).asInstanceOf[RandomForestRegressionModel] + println("Learned regression forest model:\n" + rfModel.toDebugString) + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala new file mode 100644 index 000000000000..014abd1fdbc6 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.SQLTransformer +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + + +object SQLTransformerExample { + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("SQLTransformerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val df = sqlContext.createDataFrame( + Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2") + + val sqlTrans = new SQLTransformer().setStatement( + "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__") + + sqlTrans.transform(df).show() + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala index 58d7b67674ff..f4d1fe57856a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala @@ -70,7 +70,7 @@ object SimpleParamsExample { // which supports several methods for specifying parameters. val paramMap = ParamMap(lr.maxIter -> 20) paramMap.put(lr.maxIter, 30) // Specify 1 Param. This overwrites the original maxIter. - paramMap.put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params. + paramMap.put(lr.regParam -> 0.1, lr.thresholds -> Array(0.45, 0.55)) // Specify multiple Params. // One can also combine ParamMaps. val paramMap2 = ParamMap(lr.probabilityCol -> "myProbability") // Change output column name diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala new file mode 100644 index 000000000000..e0a41e383a7e --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.StandardScaler +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object StandardScalerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("StandardScalerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + val scaler = new StandardScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + .setWithStd(true) + .setWithMean(false) + + // Compute summary statistics by fitting the StandardScaler. + val scalerModel = scaler.fit(dataFrame) + + // Normalize each feature to have unit standard deviation. + val scaledData = scalerModel.transform(dataFrame) + scaledData.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala new file mode 100644 index 000000000000..655ffce08d3a --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.StopWordsRemover +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object StopWordsRemoverExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("StopWordsRemoverExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered") + + val dataSet = sqlContext.createDataFrame(Seq( + (0, Seq("I", "saw", "the", "red", "baloon")), + (1, Seq("Mary", "had", "a", "little", "lamb")) + )).toDF("id", "raw") + + remover.transform(dataSet).show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala new file mode 100644 index 000000000000..9fa494cd2473 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.StringIndexer +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object StringIndexerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("StringIndexerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val df = sqlContext.createDataFrame( + Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")) + ).toDF("id", "category") + + val indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") + + val indexed = indexer.fit(df).transform(df) + indexed.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/TfIdfExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/TfIdfExample.scala new file mode 100644 index 000000000000..40c33e4e7d44 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/TfIdfExample.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.{HashingTF, IDF, Tokenizer} +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object TfIdfExample { + + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("TfIdfExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val sentenceData = sqlContext.createDataFrame(Seq( + (0, "Hi I heard about Spark"), + (0, "I wish Java could use case classes"), + (1, "Logistic regression models are neat") + )).toDF("label", "sentence") + + val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words") + val wordsData = tokenizer.transform(sentenceData) + val hashingTF = new HashingTF() + .setInputCol("words").setOutputCol("rawFeatures").setNumFeatures(20) + val featurizedData = hashingTF.transform(wordsData) + val idf = new IDF().setInputCol("rawFeatures").setOutputCol("features") + val idfModel = idf.fit(featurizedData) + val rescaledData = idfModel.transform(featurizedData) + rescaledData.select("features", "label").take(3).foreach(println) + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala new file mode 100644 index 000000000000..01e0d1388a2f --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.{RegexTokenizer, Tokenizer} +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object TokenizerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("TokenizerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val sentenceDataFrame = sqlContext.createDataFrame(Seq( + (0, "Hi I heard about Spark"), + (1, "I wish Java could use case classes"), + (2, "Logistic,regression,models,are,neat") + )).toDF("label", "sentence") + + val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words") + val regexTokenizer = new RegexTokenizer() + .setInputCol("sentence") + .setOutputCol("words") + .setPattern("\\W") // alternatively .setPattern("\\w+").setGaps(false) + + val tokenized = tokenizer.transform(sentenceDataFrame) + tokenized.select("words", "label").take(3).foreach(println) + val regexTokenized = regexTokenizer.transform(sentenceDataFrame) + regexTokenized.select("words", "label").take(3).foreach(println) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala new file mode 100644 index 000000000000..cd1b0e9358be --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import org.apache.spark.ml.evaluation.RegressionEvaluator +import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit} +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +/** + * A simple example demonstrating model selection using TrainValidationSplit. + * + * The example is based on [[SimpleParamsExample]] using linear regression. + * Run with + * {{{ + * bin/run-example ml.TrainValidationSplitExample + * }}} + */ +object TrainValidationSplitExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("TrainValidationSplitExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // Prepare training and test data. 
+ val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345) + + val lr = new LinearRegression() + + // We use a ParamGridBuilder to construct a grid of parameters to search over. + // TrainValidationSplit will try all combinations of values and determine best model using + // the evaluator. + val paramGrid = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.01)) + .addGrid(lr.fitIntercept, Array(true, false)) + .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0)) + .build() + + // In this case the estimator is simply the linear regression. + // A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. + val trainValidationSplit = new TrainValidationSplit() + .setEstimator(lr) + .setEvaluator(new RegressionEvaluator) + .setEstimatorParamMaps(paramGrid) + + // 80% of the data will be used for training and the remaining 20% for validation. + trainValidationSplit.setTrainRatio(0.8) + + // Run train validation split, and choose the best set of parameters. + val model = trainValidationSplit.fit(training) + + // Make predictions on test data. model is the model with combination of parameters + // that performed best. + model.transform(test) + .select("features", "label", "prediction") + .show() + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala new file mode 100644 index 000000000000..d527924419f8 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.VectorAssembler +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object VectorAssemblerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("VectorAssemblerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataset = sqlContext.createDataFrame( + Seq((0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0)) + ).toDF("id", "hour", "mobile", "userFeatures", "clicked") + + val assembler = new VectorAssembler() + .setInputCols(Array("hour", "mobile", "userFeatures")) + .setOutputCol("features") + + val output = assembler.transform(dataset) + println(output.select("features", "clicked").first()) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala new file mode 100644 index 000000000000..685891c164e7 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.VectorIndexer +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object VectorIndexerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("VectorIndexerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + val indexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexed") + .setMaxCategories(10) + + val indexerModel = indexer.fit(data) + + val categoricalFeatures: Set[Int] = indexerModel.categoryMaps.keys.toSet + println(s"Chose ${categoricalFeatures.size} categorical features: " + + categoricalFeatures.mkString(", ")) + + // Create new column "indexed" with categorical values transformed to indices + val indexedData = indexerModel.transform(data) + indexedData.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala new file mode 100644 index 000000000000..04f19829eff8 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} +import org.apache.spark.ml.feature.VectorSlicer +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.StructType +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object VectorSlicerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("VectorSlicerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val data = Array(Row(Vectors.dense(-2.0, 2.3, 0.0))) + + val defaultAttr = NumericAttribute.defaultAttr + val attrs = Array("f1", "f2", "f3").map(defaultAttr.withName) + val attrGroup = new AttributeGroup("userFeatures", attrs.asInstanceOf[Array[Attribute]]) + + val dataRDD = sc.parallelize(data) + val dataset = sqlContext.createDataFrame(dataRDD, StructType(Array(attrGroup.toStructField()))) + + val slicer = new VectorSlicer().setInputCol("userFeatures").setOutputCol("features") + + slicer.setIndices(Array(1)).setNames(Array("f3")) + // or slicer.setIndices(Array(1, 2)), or slicer.setNames(Array("f2", "f3")) + + val output = slicer.transform(dataset) + println(output.select("userFeatures", "features").first()) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/Word2VecExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/Word2VecExample.scala new file mode 100644 index 000000000000..631ab4c8efa0 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/Word2VecExample.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.Word2Vec +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object Word2VecExample { + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("Word2Vec example") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Input data: Each row is a bag of words from a sentence or document. + val documentDF = sqlContext.createDataFrame(Seq( + "Hi I heard about Spark".split(" "), + "I wish Java could use case classes".split(" "), + "Logistic regression models are neat".split(" ") + ).map(Tuple1.apply)).toDF("text") + + // Learn a mapping from words to Vectors. 
+ val word2Vec = new Word2Vec() + .setInputCol("text") + .setOutputCol("result") + .setVectorSize(3) + .setMinCount(0) + val model = word2Vec.fit(documentDF) + val result = model.transform(documentDF) + result.select("result").take(3).foreach(println) + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala new file mode 100644 index 000000000000..ca22ddafc3c4 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.fpm.AssociationRules +import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset +// $example off$ + +import org.apache.spark.{SparkConf, SparkContext} + +object AssociationRulesExample { + + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("AssociationRulesExample") + val sc = new SparkContext(conf) + + // $example on$ + val freqItemsets = sc.parallelize(Seq( + new FreqItemset(Array("a"), 15L), + new FreqItemset(Array("b"), 35L), + new FreqItemset(Array("a", "b"), 12L) + )) + + val ar = new AssociationRules() + .setMinConfidence(0.8) + val results = ar.run(freqItemsets) + + results.collect().foreach { rule => + println("[" + rule.antecedent.mkString(",") + + "=>" + + rule.consequent.mkString(",") + "]," + rule.confidence) + } + // $example off$ + } + +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala new file mode 100644 index 000000000000..13a37827ab93 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */

+// scalastyle:off println
+package org.apache.spark.examples.mllib

+// $example on$
+import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
+import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.util.MLUtils
+// $example off$
+import org.apache.spark.{SparkContext, SparkConf}

+object BinaryClassificationMetricsExample {

+  def main(args: Array[String]): Unit = {

+    val conf = new SparkConf().setAppName("BinaryClassificationMetricsExample")
+    val sc = new SparkContext(conf)
+    // $example on$
+    // Load training data in LIBSVM format
+    val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt")

+    // Split data into training (60%) and test (40%)
+    val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L)
+    training.cache()

+    // Run training algorithm to build the model
+    val model = new LogisticRegressionWithLBFGS()
+      .setNumClasses(2)
+      .run(training)

+    // Clear the prediction threshold so the model will return probabilities
+    model.clearThreshold

+    // Compute raw scores on the test set
+    val predictionAndLabels = test.map { case LabeledPoint(label, features) =>
+      val prediction = model.predict(features)
+      (prediction, label)
+    }

+    // Instantiate metrics object
+    val metrics = new BinaryClassificationMetrics(predictionAndLabels)

+    // Precision by threshold
+    val precision = metrics.precisionByThreshold
+    precision.foreach { case (t, p) =>
+      println(s"Threshold: $t, Precision: $p")
+    }

+    // Recall by threshold
+    val recall = metrics.recallByThreshold
+    recall.foreach { case (t, r) =>
+      println(s"Threshold: $t, Recall: $r")
+    }

+    // Precision-Recall Curve
+    val PRC = metrics.pr

+    // F-measure
+    val f1Score = metrics.fMeasureByThreshold
+    f1Score.foreach { case (t, f) =>
+      println(s"Threshold: $t, F-score: $f, Beta = 1")
+    }

+    val beta = 0.5
+    val fScore = metrics.fMeasureByThreshold(beta)
+    fScore.foreach { case (t, f) =>
+      println(s"Threshold: $t, F-score: $f, Beta = 0.5")
+    }

+    // AUPRC
+    val auPRC = metrics.areaUnderPR
+    println("Area under precision-recall curve = " + auPRC)

+    // Compute thresholds used in ROC and PR curves
+    val thresholds = precision.map(_._1)

+    // ROC Curve
+    val roc = metrics.roc

+    // AUROC
+    val auROC = metrics.areaUnderROC
+    println("Area under ROC = " + auROC)
+    // $example off$
+  }
+}
+// scalastyle:on println
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BisectingKMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BisectingKMeansExample.scala
new file mode 100644
index 000000000000..3a596cccb87d
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BisectingKMeansExample.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +// scalastyle:off println +// $example on$ +import org.apache.spark.mllib.clustering.BisectingKMeans +import org.apache.spark.mllib.linalg.{Vector, Vectors} +// $example off$ +import org.apache.spark.{SparkConf, SparkContext} + +/** + * An example demonstrating a bisecting k-means clustering in spark.mllib. + * + * Run with + * {{{ + * bin/run-example mllib.BisectingKMeansExample + * }}} + */ +object BisectingKMeansExample { + + def main(args: Array[String]) { + val sparkConf = new SparkConf().setAppName("mllib.BisectingKMeansExample") + val sc = new SparkContext(sparkConf) + + // $example on$ + // Loads and parses data + def parse(line: String): Vector = Vectors.dense(line.split(" ").map(_.toDouble)) + val data = sc.textFile("data/mllib/kmeans_data.txt").map(parse).cache() + + // Clustering the data into 6 clusters by BisectingKMeans. + val bkm = new BisectingKMeans().setK(6) + val model = bkm.run(data) + + // Show the compute cost and the cluster centers + println(s"Compute Cost: ${model.computeCost(data)}") + model.clusterCenters.zipWithIndex.foreach { case (center, idx) => + println(s"Cluster Center ${idx}: ${center}") + } + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala deleted file mode 100644 index dc13f82488af..000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.mllib - -import java.io.File - -import com.google.common.io.Files -import scopt.OptionParser - -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext, DataFrame} - -/** - * An example of how to use [[org.apache.spark.sql.DataFrame]] as a Dataset for ML. 
Run with - * {{{ - * ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options] - * }}} - * If you use it as a template to create your own app, please use `spark-submit` to submit your app. - */ -object DatasetExample { - - case class Params( - input: String = "data/mllib/sample_libsvm_data.txt", - dataFormat: String = "libsvm") extends AbstractParams[Params] - - def main(args: Array[String]) { - val defaultParams = Params() - - val parser = new OptionParser[Params]("DatasetExample") { - head("Dataset: an example app using DataFrame as a Dataset for ML.") - opt[String]("input") - .text(s"input path to dataset") - .action((x, c) => c.copy(input = x)) - opt[String]("dataFormat") - .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") - .action((x, c) => c.copy(input = x)) - checkConfig { params => - success - } - } - - parser.parse(args, defaultParams).map { params => - run(params) - }.getOrElse { - sys.exit(1) - } - } - - def run(params: Params) { - - val conf = new SparkConf().setAppName(s"DatasetExample with $params") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ // for implicit conversions - - // Load input data - val origData: RDD[LabeledPoint] = params.dataFormat match { - case "dense" => MLUtils.loadLabeledPoints(sc, params.input) - case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input) - } - println(s"Loaded ${origData.count()} instances from file: ${params.input}") - - // Convert input data to DataFrame explicitly. - val df: DataFrame = origData.toDF() - println(s"Inferred schema:\n${df.schema.prettyJson}") - println(s"Converted to DataFrame with ${df.count()} records") - - // Select columns - val labelsDf: DataFrame = df.select("label") - val labels: RDD[Double] = labelsDf.map { case Row(v: Double) => v } - val numLabels = labels.count() - val meanLabel = labels.fold(0.0)(_ + _) / numLabels - println(s"Selected label column with average value $meanLabel") - - val featuresDf: DataFrame = df.select("features") - val features: RDD[Vector] = featuresDf.map { case Row(v: Vector) => v } - val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())( - (summary, feat) => summary.add(feat), - (sum1, sum2) => sum1.merge(sum2)) - println(s"Selected features column with average values:\n ${featureSummary.mean.toString}") - - val tmpDir = Files.createTempDir() - tmpDir.deleteOnExit() - val outputDir = new File(tmpDir, "dataset").toString - println(s"Saving to $outputDir as Parquet file.") - df.write.parquet(outputDir) - - println(s"Loading Parquet file with UDT from $outputDir.") - val newDataset = sqlContext.read.parquet(outputDir) - - println(s"Schema from Parquet: ${newDataset.schema.prettyJson}") - val newFeatures = newDataset.select("features").map { case Row(v: Vector) => v } - val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())( - (summary, feat) => summary.add(feat), - (sum1, sum2) => sum1.merge(sum2)) - println(s"Selected features column with average values:\n ${newFeaturesSummary.mean.toString}") - - sc.stop() - } - -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala new file mode 100644 index 000000000000..d427bbadaa0c --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala @@ -0,0 +1,67 @@ +/* + * 
Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.tree.DecisionTree +import org.apache.spark.mllib.tree.model.DecisionTreeModel +import org.apache.spark.mllib.util.MLUtils +// $example off$ +import org.apache.spark.{SparkConf, SparkContext} + +object DecisionTreeClassificationExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("DecisionTreeClassificationExample") + val sc = new SparkContext(conf) + + // $example on$ + // Load and parse the data file. + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + // Split the data into training and test sets (30% held out for testing) + val splits = data.randomSplit(Array(0.7, 0.3)) + val (trainingData, testData) = (splits(0), splits(1)) + + // Train a DecisionTree model. + // Empty categoricalFeaturesInfo indicates all features are continuous. + val numClasses = 2 + val categoricalFeaturesInfo = Map[Int, Int]() + val impurity = "gini" + val maxDepth = 5 + val maxBins = 32 + + val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo, + impurity, maxDepth, maxBins) + + // Evaluate model on test instances and compute test error + val labelAndPreds = testData.map { point => + val prediction = model.predict(point.features) + (point.label, prediction) + } + val testErr = labelAndPreds.filter(r => r._1 != r._2).count().toDouble / testData.count() + println("Test Error = " + testErr) + println("Learned classification tree model:\n" + model.toDebugString) + + // Save and load model + model.save(sc, "target/tmp/myDecisionTreeClassificationModel") + val sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeClassificationModel") + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala new file mode 100644 index 000000000000..fb05e7d9c506 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.tree.DecisionTree +import org.apache.spark.mllib.tree.model.DecisionTreeModel +import org.apache.spark.mllib.util.MLUtils +// $example off$ +import org.apache.spark.{SparkConf, SparkContext} + +object DecisionTreeRegressionExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("DecisionTreeRegressionExample") + val sc = new SparkContext(conf) + + // $example on$ + // Load and parse the data file. + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + // Split the data into training and test sets (30% held out for testing) + val splits = data.randomSplit(Array(0.7, 0.3)) + val (trainingData, testData) = (splits(0), splits(1)) + + // Train a DecisionTree model. + // Empty categoricalFeaturesInfo indicates all features are continuous. + val categoricalFeaturesInfo = Map[Int, Int]() + val impurity = "variance" + val maxDepth = 5 + val maxBins = 32 + + val model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo, impurity, + maxDepth, maxBins) + + // Evaluate model on test instances and compute test error + val labelsAndPredictions = testData.map { point => + val prediction = model.predict(point.features) + (point.label, prediction) + } + val testMSE = labelsAndPredictions.map{ case (v, p) => math.pow(v - p, 2) }.mean() + println("Test Mean Squared Error = " + testMSE) + println("Learned regression tree model:\n" + model.toDebugString) + + // Save and load model + model.save(sc, "target/tmp/myDecisionTreeRegressionModel") + val sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeRegressionModel") + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 57ffe3dd2524..cc6bce3cb7c9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -100,7 +100,7 @@ object DecisionTreeRunner { .action((x, c) => c.copy(numTrees = x)) opt[String]("featureSubsetStrategy") .text(s"feature subset sampling strategy" + - s" (${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}}), " + + s" (${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}), " + s"default: ${defaultParams.featureSubsetStrategy}") .action((x, c) => c.copy(featureSubsetStrategy = x)) opt[Double]("fracTest") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala new file mode 100644 index 000000000000..139e1f909bdc --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license 
agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkContext, SparkConf} +// $example on$ +import org.apache.spark.mllib.tree.GradientBoostedTrees +import org.apache.spark.mllib.tree.configuration.BoostingStrategy +import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel +import org.apache.spark.mllib.util.MLUtils +// $example off$ + +object GradientBoostingClassificationExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("GradientBoostedTreesClassificationExample") + val sc = new SparkContext(conf) + // $example on$ + // Load and parse the data file. + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + // Split the data into training and test sets (30% held out for testing) + val splits = data.randomSplit(Array(0.7, 0.3)) + val (trainingData, testData) = (splits(0), splits(1)) + + // Train a GradientBoostedTrees model. + // The defaultParams for Classification use LogLoss by default. + val boostingStrategy = BoostingStrategy.defaultParams("Classification") + boostingStrategy.numIterations = 3 // Note: Use more iterations in practice. + boostingStrategy.treeStrategy.numClasses = 2 + boostingStrategy.treeStrategy.maxDepth = 5 + // Empty categoricalFeaturesInfo indicates all features are continuous. + boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]() + + val model = GradientBoostedTrees.train(trainingData, boostingStrategy) + + // Evaluate model on test instances and compute test error + val labelAndPreds = testData.map { point => + val prediction = model.predict(point.features) + (point.label, prediction) + } + val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() + println("Test Error = " + testErr) + println("Learned classification GBT model:\n" + model.toDebugString) + + // Save and load model + model.save(sc, "target/tmp/myGradientBoostingClassificationModel") + val sameModel = GradientBoostedTreesModel.load(sc, + "target/tmp/myGradientBoostingClassificationModel") + // $example off$ + } +} +// scalastyle:on println + + diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala new file mode 100644 index 000000000000..3dc86da8e4d2 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkContext, SparkConf} +// $example on$ +import org.apache.spark.mllib.tree.GradientBoostedTrees +import org.apache.spark.mllib.tree.configuration.BoostingStrategy +import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel +import org.apache.spark.mllib.util.MLUtils +// $example off$ + +object GradientBoostingRegressionExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("GradientBoostedTreesRegressionExample") + val sc = new SparkContext(conf) + // $example on$ + // Load and parse the data file. + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + // Split the data into training and test sets (30% held out for testing) + val splits = data.randomSplit(Array(0.7, 0.3)) + val (trainingData, testData) = (splits(0), splits(1)) + + // Train a GradientBoostedTrees model. + // The defaultParams for Regression use SquaredError by default. + val boostingStrategy = BoostingStrategy.defaultParams("Regression") + boostingStrategy.numIterations = 3 // Note: Use more iterations in practice. + boostingStrategy.treeStrategy.maxDepth = 5 + // Empty categoricalFeaturesInfo indicates all features are continuous. + boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]() + + val model = GradientBoostedTrees.train(trainingData, boostingStrategy) + + // Evaluate model on test instances and compute test error + val labelsAndPredictions = testData.map { point => + val prediction = model.predict(point.features) + (point.label, prediction) + } + val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() + println("Test Mean Squared Error = " + testMSE) + println("Learned regression GBT model:\n" + model.toDebugString) + + // Save and load model + model.save(sc, "target/tmp/myGradientBoostingRegressionModel") + val sameModel = GradientBoostedTreesModel.load(sc, + "target/tmp/myGradientBoostingRegressionModel") + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala new file mode 100644 index 000000000000..52ac9ae7dd2d --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.regression.{IsotonicRegression, IsotonicRegressionModel} +// $example off$ +import org.apache.spark.{SparkConf, SparkContext} + +object IsotonicRegressionExample { + + def main(args: Array[String]) : Unit = { + + val conf = new SparkConf().setAppName("IsotonicRegressionExample") + val sc = new SparkContext(conf) + // $example on$ + val data = sc.textFile("data/mllib/sample_isotonic_regression_data.txt") + + // Create label, feature, weight tuples from input data with weight set to default value 1.0. + val parsedData = data.map { line => + val parts = line.split(',').map(_.toDouble) + (parts(0), parts(1), 1.0) + } + + // Split data into training (60%) and test (40%) sets. + val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L) + val training = splits(0) + val test = splits(1) + + // Create isotonic regression model from training data. + // Isotonic parameter defaults to true so it is only shown for demonstration + val model = new IsotonicRegression().setIsotonic(true).run(training) + + // Create tuples of predicted and real labels. + val predictionAndLabel = test.map { point => + val predictedLabel = model.predict(point._2) + (predictedLabel, point._1) + } + + // Calculate mean squared error between predicted and real labels. + val meanSquaredError = predictionAndLabel.map { case (p, l) => math.pow((p - l), 2) }.mean() + println("Mean Squared Error = " + meanSquaredError) + + // Save and load model + model.save(sc, "target/tmp/myIsotonicRegressionModel") + val sameModel = IsotonicRegressionModel.load(sc, "target/tmp/myIsotonicRegressionModel") + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala new file mode 100644 index 000000000000..61d2e7715f53 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.classification.LogisticRegressionModel +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.optimization.{LBFGS, LogisticGradient, SquaredL2Updater} +import org.apache.spark.mllib.util.MLUtils +// $example off$ + +import org.apache.spark.{SparkConf, SparkContext} + +object LBFGSExample { + + def main(args: Array[String]): Unit = { + + val conf = new SparkConf().setAppName("LBFGSExample") + val sc = new SparkContext(conf) + + // $example on$ + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + val numFeatures = data.take(1)(0).features.size + + // Split data into training (60%) and test (40%). + val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L) + + // Append 1 into the training data as intercept. + val training = splits(0).map(x => (x.label, MLUtils.appendBias(x.features))).cache() + + val test = splits(1) + + // Run training algorithm to build the model + val numCorrections = 10 + val convergenceTol = 1e-4 + val maxNumIterations = 20 + val regParam = 0.1 + val initialWeightsWithIntercept = Vectors.dense(new Array[Double](numFeatures + 1)) + + val (weightsWithIntercept, loss) = LBFGS.runLBFGS( + training, + new LogisticGradient(), + new SquaredL2Updater(), + numCorrections, + convergenceTol, + maxNumIterations, + regParam, + initialWeightsWithIntercept) + + val model = new LogisticRegressionModel( + Vectors.dense(weightsWithIntercept.toArray.slice(0, weightsWithIntercept.size - 1)), + weightsWithIntercept(weightsWithIntercept.size - 1)) + + // Clear the default threshold. + model.clearThreshold() + + // Compute raw scores on the test set. + val scoreAndLabels = test.map { point => + val score = model.predict(point.features) + (score, point.label) + } + + // Get evaluation metrics. + val metrics = new BinaryClassificationMetrics(scoreAndLabels) + val auROC = metrics.areaUnderROC() + + println("Loss of each step in training process") + loss.foreach(println) + println("Area under ROC = " + auROC) + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala index 75b0f69cf91a..70010b05e434 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala @@ -18,19 +18,16 @@ // scalastyle:off println package org.apache.spark.examples.mllib -import java.text.BreakIterator - -import scala.collection.mutable - import scopt.OptionParser import org.apache.log4j.{Level, Logger} - -import org.apache.spark.{SparkContext, SparkConf} -import org.apache.spark.mllib.clustering.{EMLDAOptimizer, OnlineLDAOptimizer, DistributedLDAModel, LDA} -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel, RegexTokenizer, StopWordsRemover} +import org.apache.spark.mllib.clustering.{DistributedLDAModel, EMLDAOptimizer, LDA, OnlineLDAOptimizer} +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.rdd.RDD - +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.{SparkConf, SparkContext} /** * An example Latent Dirichlet Allocation (LDA) app. 
Run with @@ -192,115 +189,45 @@ object LDAExample { vocabSize: Int, stopwordFile: String): (RDD[(Long, Vector)], Array[String], Long) = { + val sqlContext = SQLContext.getOrCreate(sc) + import sqlContext.implicits._ + // Get dataset of document texts // One document per line in each text file. If the input consists of many small files, // this can result in a large number of small partitions, which can degrade performance. // In this case, consider using coalesce() to create fewer, larger partitions. - val textRDD: RDD[String] = sc.textFile(paths.mkString(",")) - - // Split text into words - val tokenizer = new SimpleTokenizer(sc, stopwordFile) - val tokenized: RDD[(Long, IndexedSeq[String])] = textRDD.zipWithIndex().map { case (text, id) => - id -> tokenizer.getWords(text) - } - tokenized.cache() - - // Counts words: RDD[(word, wordCount)] - val wordCounts: RDD[(String, Long)] = tokenized - .flatMap { case (_, tokens) => tokens.map(_ -> 1L) } - .reduceByKey(_ + _) - wordCounts.cache() - val fullVocabSize = wordCounts.count() - // Select vocab - // (vocab: Map[word -> id], total tokens after selecting vocab) - val (vocab: Map[String, Int], selectedTokenCount: Long) = { - val tmpSortedWC: Array[(String, Long)] = if (vocabSize == -1 || fullVocabSize <= vocabSize) { - // Use all terms - wordCounts.collect().sortBy(-_._2) - } else { - // Sort terms to select vocab - wordCounts.sortBy(_._2, ascending = false).take(vocabSize) - } - (tmpSortedWC.map(_._1).zipWithIndex.toMap, tmpSortedWC.map(_._2).sum) - } - - val documents = tokenized.map { case (id, tokens) => - // Filter tokens by vocabulary, and create word count vector representation of document. - val wc = new mutable.HashMap[Int, Int]() - tokens.foreach { term => - if (vocab.contains(term)) { - val termIndex = vocab(term) - wc(termIndex) = wc.getOrElse(termIndex, 0) + 1 - } - } - val indices = wc.keys.toArray.sorted - val values = indices.map(i => wc(i).toDouble) - - val sb = Vectors.sparse(vocab.size, indices, values) - (id, sb) - } - - val vocabArray = new Array[String](vocab.size) - vocab.foreach { case (term, i) => vocabArray(i) = term } - - (documents, vocabArray, selectedTokenCount) - } -} - -/** - * Simple Tokenizer. - * - * TODO: Formalize the interface, and make this a public class in mllib.feature - */ -private class SimpleTokenizer(sc: SparkContext, stopwordFile: String) extends Serializable { - - private val stopwords: Set[String] = if (stopwordFile.isEmpty) { - Set.empty[String] - } else { - val stopwordText = sc.textFile(stopwordFile).collect() - stopwordText.flatMap(_.stripMargin.split("\\s+")).toSet - } - - // Matches sequences of Unicode letters - private val allWordRegex = "^(\\p{L}*)$".r - - // Ignore words shorter than this length. - private val minWordLength = 3 - - def getWords(text: String): IndexedSeq[String] = { - - val words = new mutable.ArrayBuffer[String]() - - // Use Java BreakIterator to tokenize text into words. - val wb = BreakIterator.getWordInstance - wb.setText(text) - - // current,end index start,end of each word - var current = wb.first() - var end = wb.next() - while (end != BreakIterator.DONE) { - // Convert to lowercase - val word: String = text.substring(current, end).toLowerCase - // Remove short words and strings that aren't only letters - word match { - case allWordRegex(w) if w.length >= minWordLength && !stopwords.contains(w) => - words += w - case _ => - } - - current = end - try { - end = wb.next() - } catch { - case e: Exception => - // Ignore remaining text in line. 
- // This is a known bug in BreakIterator (for some Java versions), - // which fails when it sees certain characters. - end = BreakIterator.DONE - } + val df = sc.textFile(paths.mkString(",")).toDF("docs") + val customizedStopWords: Array[String] = if (stopwordFile.isEmpty) { + Array.empty[String] + } else { + val stopWordText = sc.textFile(stopwordFile).collect() + stopWordText.flatMap(_.stripMargin.split("\\s+")) } - words + val tokenizer = new RegexTokenizer() + .setInputCol("docs") + .setOutputCol("rawTokens") + val stopWordsRemover = new StopWordsRemover() + .setInputCol("rawTokens") + .setOutputCol("tokens") + stopWordsRemover.setStopWords(stopWordsRemover.getStopWords ++ customizedStopWords) + val countVectorizer = new CountVectorizer() + .setVocabSize(vocabSize) + .setInputCol("tokens") + .setOutputCol("features") + + val pipeline = new Pipeline() + .setStages(Array(tokenizer, stopWordsRemover, countVectorizer)) + + val model = pipeline.fit(df) + val documents = model.transform(df) + .select("features") + .map { case Row(features: Vector) => features } + .zipWithIndex() + .map(_.swap) + + (documents, + model.stages(2).asInstanceOf[CountVectorizerModel].vocabulary, // vocabulary + documents.map(_._2.numActives).sum().toLong) // total token count } - } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala index e43a6f2864c7..69691ae297f6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala @@ -55,7 +55,7 @@ object MovieLensALS { val parser = new OptionParser[Params]("MovieLensALS") { head("MovieLensALS: an example app for ALS on MovieLens data.") opt[Int]("rank") - .text(s"rank, default: ${defaultParams.rank}}") + .text(s"rank, default: ${defaultParams.rank}") .action((x, c) => c.copy(rank = x)) opt[Int]("numIterations") .text(s"number of iterations, default: ${defaultParams.numIterations}") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MultiLabelMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MultiLabelMetricsExample.scala new file mode 100644 index 000000000000..4503c15360ad --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MultiLabelMetricsExample.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
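[Editor's note on the rewritten LDA preprocessing above] `numActives` on a count vector is the number of stored entries, i.e. the distinct vocabulary terms present in a document, not the sum of their counts, so the third element of the returned tuple is the number of nonzero (document, term) pairs rather than a raw token total. A minimal illustration with a single sparse count vector (values chosen for the example):

import org.apache.spark.mllib.linalg.Vectors

object NumActivesSketch {
  def main(args: Array[String]): Unit = {
    // One document's term-count vector: term 0 appears 3 times, term 4 twice.
    val counts = Vectors.sparse(5, Array(0, 4), Array(3.0, 2.0))
    println(counts.numActives)           // 2 -> distinct terms present
    println(counts.toArray.sum.toLong)   // 5 -> tokens actually in the document
  }
}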
+ */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.evaluation.MultilabelMetrics +import org.apache.spark.rdd.RDD +// $example off$ +import org.apache.spark.{SparkContext, SparkConf} + +object MultiLabelMetricsExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("MultiLabelMetricsExample") + val sc = new SparkContext(conf) + // $example on$ + val scoreAndLabels: RDD[(Array[Double], Array[Double])] = sc.parallelize( + Seq((Array(0.0, 1.0), Array(0.0, 2.0)), + (Array(0.0, 2.0), Array(0.0, 1.0)), + (Array.empty[Double], Array(0.0)), + (Array(2.0), Array(2.0)), + (Array(2.0, 0.0), Array(2.0, 0.0)), + (Array(0.0, 1.0, 2.0), Array(0.0, 1.0)), + (Array(1.0), Array(1.0, 2.0))), 2) + + // Instantiate metrics object + val metrics = new MultilabelMetrics(scoreAndLabels) + + // Summary stats + println(s"Recall = ${metrics.recall}") + println(s"Precision = ${metrics.precision}") + println(s"F1 measure = ${metrics.f1Measure}") + println(s"Accuracy = ${metrics.accuracy}") + + // Individual label stats + metrics.labels.foreach(label => + println(s"Class $label precision = ${metrics.precision(label)}")) + metrics.labels.foreach(label => println(s"Class $label recall = ${metrics.recall(label)}")) + metrics.labels.foreach(label => println(s"Class $label F1-score = ${metrics.f1Measure(label)}")) + + // Micro stats + println(s"Micro recall = ${metrics.microRecall}") + println(s"Micro precision = ${metrics.microPrecision}") + println(s"Micro F1 measure = ${metrics.microF1Measure}") + + // Hamming loss + println(s"Hamming loss = ${metrics.hammingLoss}") + + // Subset accuracy + println(s"Subset accuracy = ${metrics.subsetAccuracy}") + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala new file mode 100644 index 000000000000..090444924598 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
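[Editor's note on MultiLabelMetricsExample above] Subset accuracy only credits exact matches between the predicted and true label sets, while Hamming loss counts individual label mismatches (roughly the size of the symmetric difference, averaged over records and labels). A toy calculation on the first record of the data above:

object MultiLabelIntuitionSketch {
  def main(args: Array[String]): Unit = {
    // First record above: predicted labels {0.0, 1.0}, true labels {0.0, 2.0}.
    val predicted = Set(0.0, 1.0)
    val actual = Set(0.0, 2.0)

    val exactMatch = predicted == actual                                  // false -> no subset-accuracy credit
    val mismatches = (predicted diff actual) ++ (actual diff predicted)   // {1.0, 2.0}

    println(s"exact match = $exactMatch, mismatched labels = ${mismatches.size}")
  }
}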
+ */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLUtils +// $example off$ +import org.apache.spark.{SparkContext, SparkConf} + +object MulticlassMetricsExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("MulticlassMetricsExample") + val sc = new SparkContext(conf) + + // $example on$ + // Load training data in LIBSVM format + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") + + // Split data into training (60%) and test (40%) + val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L) + training.cache() + + // Run training algorithm to build the model + val model = new LogisticRegressionWithLBFGS() + .setNumClasses(3) + .run(training) + + // Compute raw scores on the test set + val predictionAndLabels = test.map { case LabeledPoint(label, features) => + val prediction = model.predict(features) + (prediction, label) + } + + // Instantiate metrics object + val metrics = new MulticlassMetrics(predictionAndLabels) + + // Confusion matrix + println("Confusion matrix:") + println(metrics.confusionMatrix) + + // Overall Statistics + val precision = metrics.precision + val recall = metrics.recall // same as true positive rate + val f1Score = metrics.fMeasure + println("Summary Statistics") + println(s"Precision = $precision") + println(s"Recall = $recall") + println(s"F1 Score = $f1Score") + + // Precision by label + val labels = metrics.labels + labels.foreach { l => + println(s"Precision($l) = " + metrics.precision(l)) + } + + // Recall by label + labels.foreach { l => + println(s"Recall($l) = " + metrics.recall(l)) + } + + // False positive rate by label + labels.foreach { l => + println(s"FPR($l) = " + metrics.falsePositiveRate(l)) + } + + // F-measure by label + labels.foreach { l => + println(s"F1-Score($l) = " + metrics.fMeasure(l)) + } + + // Weighted stats + println(s"Weighted precision: ${metrics.weightedPrecision}") + println(s"Weighted recall: ${metrics.weightedRecall}") + println(s"Weighted F1 score: ${metrics.weightedFMeasure}") + println(s"Weighted false positive rate: ${metrics.weightedFalsePositiveRate}") + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala new file mode 100644 index 000000000000..a7a47c2a3556 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
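[Editor's note on MulticlassMetricsExample above] As a sanity check on the printed summary statistics: when every point gets exactly one predicted class, overall precision and recall reduce to plain accuracy, and per-class precision is the fraction of predictions of that class that were correct. A toy calculation with plain Scala collections (the prediction/label pairs are made up):

object MulticlassByHandSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical (prediction, label) pairs for a 3-class problem.
    val predictionAndLabels = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 1.0), (0.0, 2.0), (2.0, 2.0))

    val accuracy =
      predictionAndLabels.count { case (p, l) => p == l }.toDouble / predictionAndLabels.size

    // Precision(2.0): among points predicted as class 2.0, how many were right?
    val predictedAs2 = predictionAndLabels.filter(_._1 == 2.0)
    val precision2 = predictedAs2.count { case (p, l) => p == l }.toDouble / predictedAs2.size

    println(s"accuracy = $accuracy, precision(2.0) = $precision2")
  }
}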
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.classification.{NaiveBayes, NaiveBayesModel} +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +// $example off$ +import org.apache.spark.{SparkConf, SparkContext} + +object NaiveBayesExample { + + def main(args: Array[String]) : Unit = { + val conf = new SparkConf().setAppName("NaiveBayesExample") + val sc = new SparkContext(conf) + // $example on$ + val data = sc.textFile("data/mllib/sample_naive_bayes_data.txt") + val parsedData = data.map { line => + val parts = line.split(',') + LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble))) + } + + // Split data into training (60%) and test (40%). + val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L) + val training = splits(0) + val test = splits(1) + + val model = NaiveBayes.train(training, lambda = 1.0, modelType = "multinomial") + + val predictionAndLabel = test.map(p => (model.predict(p.features), p.label)) + val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count() + + // Save and load model + model.save(sc, "target/tmp/myNaiveBayesModel") + val sameModel = NaiveBayesModel.load(sc, "target/tmp/myNaiveBayesModel") + // $example off$ + } +} + +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala new file mode 100644 index 000000000000..d237232c430c --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
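[Editor's note on NaiveBayesExample above] The example saves and reloads the model but never uses the reloaded copy; scoring a new point with it is a one-liner. A small hedged sketch (the feature values are made up and assume the same dimensionality as the training data; the path is the one written by the example):

import org.apache.spark.SparkContext
import org.apache.spark.mllib.classification.NaiveBayesModel
import org.apache.spark.mllib.linalg.Vectors

object UseSavedNaiveBayesSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local[2]", "UseSavedNaiveBayesSketch")
    // Path written by NaiveBayesExample above.
    val model = NaiveBayesModel.load(sc, "target/tmp/myNaiveBayesModel")
    // Hypothetical feature vector; must match the training dimensionality.
    val prediction = model.predict(Vectors.dense(1.0, 0.0, 0.0))
    println(s"Predicted class = $prediction")
    sc.stop()
  }
}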
+ */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.fpm.PrefixSpan +// $example off$ + +import org.apache.spark.{SparkConf, SparkContext} + +object PrefixSpanExample { + + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("PrefixSpanExample") + val sc = new SparkContext(conf) + + // $example on$ + val sequences = sc.parallelize(Seq( + Array(Array(1, 2), Array(3)), + Array(Array(1), Array(3, 2), Array(1, 2)), + Array(Array(1, 2), Array(5)), + Array(Array(6)) + ), 2).cache() + val prefixSpan = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5) + val model = prefixSpan.run(sequences) + model.freqSequences.collect().foreach { freqSequence => + println( + freqSequence.sequence.map(_.mkString("[", ", ", "]")).mkString("[", ", ", "]") + + ", " + freqSequence.freq) + } + // $example off$ + } +} +// scalastyle:off println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala new file mode 100644 index 000000000000..5e55abd5121c --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkContext, SparkConf} +// $example on$ +import org.apache.spark.mllib.tree.RandomForest +import org.apache.spark.mllib.tree.model.RandomForestModel +import org.apache.spark.mllib.util.MLUtils +// $example off$ + +object RandomForestClassificationExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("RandomForestClassificationExample") + val sc = new SparkContext(conf) + // $example on$ + // Load and parse the data file. + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + // Split the data into training and test sets (30% held out for testing) + val splits = data.randomSplit(Array(0.7, 0.3)) + val (trainingData, testData) = (splits(0), splits(1)) + + // Train a RandomForest model. + // Empty categoricalFeaturesInfo indicates all features are continuous. + val numClasses = 2 + val categoricalFeaturesInfo = Map[Int, Int]() + val numTrees = 3 // Use more in practice. + val featureSubsetStrategy = "auto" // Let the algorithm choose. 
+ val impurity = "gini" + val maxDepth = 4 + val maxBins = 32 + + val model = RandomForest.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo, + numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins) + + // Evaluate model on test instances and compute test error + val labelAndPreds = testData.map { point => + val prediction = model.predict(point.features) + (point.label, prediction) + } + val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() + println("Test Error = " + testErr) + println("Learned classification forest model:\n" + model.toDebugString) + + // Save and load model + model.save(sc, "target/tmp/myRandomForestClassificationModel") + val sameModel = RandomForestModel.load(sc, "target/tmp/myRandomForestClassificationModel") + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala new file mode 100644 index 000000000000..a54fb3ab7e37 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkContext, SparkConf} +// $example on$ +import org.apache.spark.mllib.tree.RandomForest +import org.apache.spark.mllib.tree.model.RandomForestModel +import org.apache.spark.mllib.util.MLUtils +// $example off$ + +object RandomForestRegressionExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("RandomForestRegressionExample") + val sc = new SparkContext(conf) + // $example on$ + // Load and parse the data file. + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + // Split the data into training and test sets (30% held out for testing) + val splits = data.randomSplit(Array(0.7, 0.3)) + val (trainingData, testData) = (splits(0), splits(1)) + + // Train a RandomForest model. + // Empty categoricalFeaturesInfo indicates all features are continuous. + val numClasses = 2 + val categoricalFeaturesInfo = Map[Int, Int]() + val numTrees = 3 // Use more in practice. + val featureSubsetStrategy = "auto" // Let the algorithm choose. 
+ val impurity = "variance" + val maxDepth = 4 + val maxBins = 32 + + val model = RandomForest.trainRegressor(trainingData, categoricalFeaturesInfo, + numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins) + + // Evaluate model on test instances and compute test error + val labelsAndPredictions = testData.map { point => + val prediction = model.predict(point.features) + (point.label, prediction) + } + val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() + println("Test Mean Squared Error = " + testMSE) + println("Learned regression forest model:\n" + model.toDebugString) + + // Save and load model + model.save(sc, "target/tmp/myRandomForestRegressionModel") + val sameModel = RandomForestModel.load(sc, "target/tmp/myRandomForestRegressionModel") + // $example off$ + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala new file mode 100644 index 000000000000..cffa03d5cc9f --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.evaluation.{RegressionMetrics, RankingMetrics} +import org.apache.spark.mllib.recommendation.{ALS, Rating} +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkContext, SparkConf} + +object RankingMetricsExample { + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("RankingMetricsExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + // $example on$ + // Read in the ratings data + val ratings = sc.textFile("data/mllib/sample_movielens_data.txt").map { line => + val fields = line.split("::") + Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble - 2.5) + }.cache() + + // Map ratings to 1 or 0, 1 indicating a movie that should be recommended + val binarizedRatings = ratings.map(r => Rating(r.user, r.product, + if (r.rating > 0) 1.0 else 0.0)).cache() + + // Summarize ratings + val numRatings = ratings.count() + val numUsers = ratings.map(_.user).distinct().count() + val numMovies = ratings.map(_.product).distinct().count() + println(s"Got $numRatings ratings from $numUsers users on $numMovies movies.") + + // Build the model + val numIterations = 10 + val rank = 10 + val lambda = 0.01 + val model = ALS.train(ratings, rank, numIterations, lambda) + + // Define a function to scale ratings from 0 to 1 + def scaledRating(r: Rating): Rating = { + val scaledRating = math.max(math.min(r.rating, 1.0), 0.0) + Rating(r.user, r.product, scaledRating) + } + + // Get sorted top ten predictions for each user and then scale from [0, 1] + val userRecommended = model.recommendProductsForUsers(10).map { case (user, recs) => + (user, recs.map(scaledRating)) + } + + // Assume that any movie a user rated 3 or higher (which maps to a 1) is a relevant document + // Compare with top ten most relevant documents + val userMovies = binarizedRatings.groupBy(_.user) + val relevantDocuments = userMovies.join(userRecommended).map { case (user, (actual, + predictions)) => + (predictions.map(_.product), actual.filter(_.rating > 0.0).map(_.product).toArray) + } + + // Instantiate metrics object + val metrics = new RankingMetrics(relevantDocuments) + + // Precision at K + Array(1, 3, 5).foreach { k => + println(s"Precision at $k = ${metrics.precisionAt(k)}") + } + + // Mean average precision + println(s"Mean average precision = ${metrics.meanAveragePrecision}") + + // Normalized discounted cumulative gain + Array(1, 3, 5).foreach { k => + println(s"NDCG at $k = ${metrics.ndcgAt(k)}") + } + + // Get predictions for each data point + val allPredictions = model.predict(ratings.map(r => (r.user, r.product))).map(r => ((r.user, + r.product), r.rating)) + val allRatings = ratings.map(r => ((r.user, r.product), r.rating)) + val predictionsAndLabels = allPredictions.join(allRatings).map { case ((user, product), + (predicted, actual)) => + (predicted, actual) + } + + // Get the RMSE using regression metrics + val regressionMetrics = new RegressionMetrics(predictionsAndLabels) + println(s"RMSE = ${regressionMetrics.rootMeanSquaredError}") + + // R-squared + println(s"R-squared = ${regressionMetrics.r2}") + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala new file mode 100644 index 000000000000..64e460246544 
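[Editor's note on RankingMetricsExample above] The rating preprocessing is easy to misread: raw ratings are shifted by -2.5 before ALS training, and a movie is treated as relevant when the shifted rating is positive, i.e. the original rating was 3 or higher. A small sketch of that mapping (plain Scala, rating values made up):

object RatingMappingSketch {
  def main(args: Array[String]): Unit = {
    val rawRatings = Seq(1.0, 2.0, 3.0, 4.5)
    val shifted = rawRatings.map(_ - 2.5)                       // fed to ALS.train
    val relevance = shifted.map(r => if (r > 0) 1.0 else 0.0)   // ground-truth relevance

    println(s"shifted   = $shifted")    // List(-1.5, -0.5, 0.5, 2.0)
    println(s"relevance = $relevance")  // List(0.0, 0.0, 1.0, 1.0)
  }
}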
--- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkContext, SparkConf} +// $example on$ +import org.apache.spark.mllib.recommendation.ALS +import org.apache.spark.mllib.recommendation.MatrixFactorizationModel +import org.apache.spark.mllib.recommendation.Rating +// $example off$ + +object RecommendationExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("CollaborativeFilteringExample") + val sc = new SparkContext(conf) + // $example on$ + // Load and parse the data + val data = sc.textFile("data/mllib/als/test.data") + val ratings = data.map(_.split(',') match { case Array(user, item, rate) => + Rating(user.toInt, item.toInt, rate.toDouble) + }) + + // Build the recommendation model using ALS + val rank = 10 + val numIterations = 10 + val model = ALS.train(ratings, rank, numIterations, 0.01) + + // Evaluate the model on rating data + val usersProducts = ratings.map { case Rating(user, product, rate) => + (user, product) + } + val predictions = + model.predict(usersProducts).map { case Rating(user, product, rate) => + ((user, product), rate) + } + val ratesAndPreds = ratings.map { case Rating(user, product, rate) => + ((user, product), rate) + }.join(predictions) + val MSE = ratesAndPreds.map { case ((user, product), (r1, r2)) => + val err = (r1 - r2) + err * err + }.mean() + println("Mean Squared Error = " + MSE) + + // Save and load model + model.save(sc, "target/tmp/myCollaborativeFilter") + val sameModel = MatrixFactorizationModel.load(sc, "target/tmp/myCollaborativeFilter") + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala new file mode 100644 index 000000000000..47d44532521c --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// scalastyle:off println + +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.regression.LinearRegressionWithSGD +import org.apache.spark.mllib.evaluation.RegressionMetrics +import org.apache.spark.mllib.util.MLUtils +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object RegressionMetricsExample { + def main(args: Array[String]) : Unit = { + val conf = new SparkConf().setAppName("RegressionMetricsExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + // $example on$ + // Load the data + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_linear_regression_data.txt").cache() + + // Build the model + val numIterations = 100 + val model = LinearRegressionWithSGD.train(data, numIterations) + + // Get predictions + val valuesAndPreds = data.map{ point => + val prediction = model.predict(point.features) + (prediction, point.label) + } + + // Instantiate metrics object + val metrics = new RegressionMetrics(valuesAndPreds) + + // Squared error + println(s"MSE = ${metrics.meanSquaredError}") + println(s"RMSE = ${metrics.rootMeanSquaredError}") + + // R-squared + println(s"R-squared = ${metrics.r2}") + + // Mean absolute error + println(s"MAE = ${metrics.meanAbsoluteError}") + + // Explained variance + println(s"Explained variance = ${metrics.explainedVariance}") + // $example off$ + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala new file mode 100644 index 000000000000..b4e06afa7410 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
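[Editor's note on RegressionMetricsExample above] The printed error metrics relate in the expected ways: RMSE is the square root of MSE, and MAE averages the absolute errors. A quick arithmetic check on hypothetical (prediction, label) pairs:

object RegressionMetricsByHandSketch {
  def main(args: Array[String]): Unit = {
    val predictionsAndLabels = Seq((2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0))
    val n = predictionsAndLabels.size

    val mse = predictionsAndLabels.map { case (p, l) => (p - l) * (p - l) }.sum / n
    val rmse = math.sqrt(mse)   // matches rootMeanSquaredError = sqrt(meanSquaredError)
    val mae = predictionsAndLabels.map { case (p, l) => math.abs(p - l) }.sum / n

    println(s"MSE = $mse, RMSE = $rmse, MAE = $mae")
  }
}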
+ */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.fpm.FPGrowth +import org.apache.spark.rdd.RDD +// $example off$ + +import org.apache.spark.{SparkContext, SparkConf} + +object SimpleFPGrowth { + + def main(args: Array[String]) { + + val conf = new SparkConf().setAppName("SimpleFPGrowth") + val sc = new SparkContext(conf) + + // $example on$ + val data = sc.textFile("data/mllib/sample_fpgrowth.txt") + + val transactions: RDD[Array[String]] = data.map(s => s.trim.split(' ')) + + val fpg = new FPGrowth() + .setMinSupport(0.2) + .setNumPartitions(10) + val model = fpg.run(transactions) + + model.freqItemsets.collect().foreach { itemset => + println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq) + } + + val minConfidence = 0.8 + model.generateAssociationRules(minConfidence).collect().foreach { rule => + println( + rule.antecedent.mkString("[", ",", "]") + + " => " + rule.consequent .mkString("[", ",", "]") + + ", " + rule.confidence) + } + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala new file mode 100644 index 000000000000..49f5df39443e --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import org.apache.spark.SparkConf +import org.apache.spark.mllib.stat.test.{BinarySample, StreamingTest} +import org.apache.spark.streaming.{Seconds, StreamingContext} +import org.apache.spark.util.Utils + +/** + * Perform streaming testing using Welch's 2-sample t-test on a stream of data, where the data + * stream arrives as text files in a directory. Stops when the two groups are statistically + * significant (p-value < 0.05) or after a user-specified timeout in number of batches is exceeded. + * + * The rows of the text files must be in the form `Boolean, Double`. For example: + * false, -3.92 + * true, 99.32 + * + * Usage: + * StreamingTestExample + * + * To run on your local machine using the directory `dataDir` with 5 seconds between each batch and + * a timeout after 100 insignificant batches, call: + * $ bin/run-example mllib.StreamingTestExample dataDir 5 100 + * + * As you add text files to `dataDir` the significance test wil continually update every + * `batchDuration` seconds until the test becomes significant (p-value < 0.05) or the number of + * batches processed exceeds `numBatchesTimeout`. 
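[Editor's note on SimpleFPGrowth above] In the association-rule step, a rule's confidence is freq(antecedent ∪ consequent) / freq(antecedent), and rules below minConfidence are dropped. A toy calculation with hypothetical itemset frequencies:

object RuleConfidenceSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical frequent-itemset counts from a small transaction database.
    val freq = Map(Set("t") -> 4L, Set("t", "y") -> 3L)

    val confidence = freq(Set("t", "y")).toDouble / freq(Set("t"))
    println(s"confidence([t] => [y]) = $confidence")   // 0.75 -> dropped at minConfidence = 0.8
  }
}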
+ */ +object StreamingTestExample { + + def main(args: Array[String]) { + if (args.length != 3) { + // scalastyle:off println + System.err.println( + "Usage: StreamingTestExample " + + " ") + // scalastyle:on println + System.exit(1) + } + val dataDir = args(0) + val batchDuration = Seconds(args(1).toLong) + val numBatchesTimeout = args(2).toInt + + val conf = new SparkConf().setMaster("local").setAppName("StreamingTestExample") + val ssc = new StreamingContext(conf, batchDuration) + ssc.checkpoint({ + val dir = Utils.createTempDir() + dir.toString + }) + + // $example on$ + val data = ssc.textFileStream(dataDir).map(line => line.split(",") match { + case Array(label, value) => BinarySample(label.toBoolean, value.toDouble) + }) + + val streamingTest = new StreamingTest() + .setPeacePeriod(0) + .setWindowSize(0) + .setTestMethod("welch") + + val out = streamingTest.registerStream(data) + out.print() + // $example off$ + + // Stop processing if test becomes significant or we time out + var timeoutCounter = numBatchesTimeout + out.foreachRDD { rdd => + timeoutCounter -= 1 + val anySignificant = rdd.map(_.pValue < 0.05).fold(false)(_ || _) + if (timeoutCounter == 0 || anySignificant) rdd.context.stop() + } + + ssc.start() + ssc.awaitTermination() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala index 3ebb112fc069..cf12c98b4af6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala @@ -19,7 +19,7 @@ package org.apache.spark.examples.pythonconverters import java.util.{Collection => JCollection, Map => JMap} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.avro.generic.{GenericFixed, IndexedRecord} import org.apache.avro.mapred.AvroWrapper @@ -58,7 +58,7 @@ object AvroConversionUtil extends Serializable { val map = new java.util.HashMap[String, Any] obj match { case record: IndexedRecord => - record.getSchema.getFields.zipWithIndex.foreach { case (f, i) => + record.getSchema.getFields.asScala.zipWithIndex.foreach { case (f, i) => map.put(f.name, fromAvro(record.get(i), f.schema)) } case other => throw new SparkException( @@ -68,9 +68,9 @@ object AvroConversionUtil extends Serializable { } def unpackMap(obj: Any, schema: Schema): JMap[String, Any] = { - obj.asInstanceOf[JMap[_, _]].map { case (key, value) => + obj.asInstanceOf[JMap[_, _]].asScala.map { case (key, value) => (key.toString, fromAvro(value, schema.getValueType)) - } + }.asJava } def unpackFixed(obj: Any, schema: Schema): Array[Byte] = { @@ -79,7 +79,10 @@ object AvroConversionUtil extends Serializable { def unpackBytes(obj: Any): Array[Byte] = { val bytes: Array[Byte] = obj match { - case buf: java.nio.ByteBuffer => buf.array() + case buf: java.nio.ByteBuffer => + val arr = new Array[Byte](buf.remaining()) + buf.get(arr) + arr case arr: Array[Byte] => arr case other => throw new SparkException( s"Unknown BYTES type ${other.getClass.getName}") @@ -91,17 +94,17 @@ object AvroConversionUtil extends Serializable { def unpackArray(obj: Any, schema: Schema): JCollection[Any] = obj match { case c: JCollection[_] => - c.map(fromAvro(_, schema.getElementType)) + c.asScala.map(fromAvro(_, schema.getElementType)).toSeq.asJava case arr: Array[_] if arr.getClass.getComponentType.isPrimitive => - arr.toSeq + 
arr.toSeq.asJava.asInstanceOf[JCollection[Any]] case arr: Array[_] => - arr.map(fromAvro(_, schema.getElementType)).toSeq + arr.map(fromAvro(_, schema.getElementType)).toSeq.asJava case other => throw new SparkException( s"Unknown ARRAY type ${other.getClass.getName}") } def unpackUnion(obj: Any, schema: Schema): Any = { - schema.getTypes.toList match { + schema.getTypes.asScala.toList match { case List(s) => fromAvro(obj, s) case List(n, s) if n.getType == NULL => fromAvro(obj, s) case List(s, n) if n.getType == NULL => fromAvro(obj, s) diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/CassandraConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/CassandraConverters.scala index 83feb5703b90..00ce47af4813 100644 --- a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/CassandraConverters.scala +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/CassandraConverters.scala @@ -17,11 +17,13 @@ package org.apache.spark.examples.pythonconverters -import org.apache.spark.api.python.Converter import java.nio.ByteBuffer + +import scala.collection.JavaConverters._ + import org.apache.cassandra.utils.ByteBufferUtil -import collection.JavaConversions._ +import org.apache.spark.api.python.Converter /** * Implementation of [[org.apache.spark.api.python.Converter]] that converts Cassandra @@ -30,7 +32,7 @@ import collection.JavaConversions._ class CassandraCQLKeyConverter extends Converter[Any, java.util.Map[String, Int]] { override def convert(obj: Any): java.util.Map[String, Int] = { val result = obj.asInstanceOf[java.util.Map[String, ByteBuffer]] - mapAsJavaMap(result.mapValues(bb => ByteBufferUtil.toInt(bb))) + result.asScala.mapValues(ByteBufferUtil.toInt).asJava } } @@ -41,7 +43,7 @@ class CassandraCQLKeyConverter extends Converter[Any, java.util.Map[String, Int] class CassandraCQLValueConverter extends Converter[Any, java.util.Map[String, String]] { override def convert(obj: Any): java.util.Map[String, String] = { val result = obj.asInstanceOf[java.util.Map[String, ByteBuffer]] - mapAsJavaMap(result.mapValues(bb => ByteBufferUtil.string(bb))) + result.asScala.mapValues(ByteBufferUtil.string).asJava } } @@ -52,7 +54,7 @@ class CassandraCQLValueConverter extends Converter[Any, java.util.Map[String, St class ToCassandraCQLKeyConverter extends Converter[Any, java.util.Map[String, ByteBuffer]] { override def convert(obj: Any): java.util.Map[String, ByteBuffer] = { val input = obj.asInstanceOf[java.util.Map[String, Int]] - mapAsJavaMap(input.mapValues(i => ByteBufferUtil.bytes(i))) + input.asScala.mapValues(ByteBufferUtil.bytes).asJava } } @@ -63,6 +65,6 @@ class ToCassandraCQLKeyConverter extends Converter[Any, java.util.Map[String, By class ToCassandraCQLValueConverter extends Converter[Any, java.util.List[ByteBuffer]] { override def convert(obj: Any): java.util.List[ByteBuffer] = { val input = obj.asInstanceOf[java.util.List[String]] - seqAsJavaList(input.map(s => ByteBufferUtil.bytes(s))) + input.asScala.map(ByteBufferUtil.bytes).asJava } } diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala index 90d48a64106c..0a25ee7ae56f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala @@ -17,7 +17,7 @@ package 
org.apache.spark.examples.pythonconverters -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.util.parsing.json.JSONObject import org.apache.spark.api.python.Converter @@ -33,7 +33,6 @@ import org.apache.hadoop.hbase.CellUtil */ class HBaseResultToStringConverter extends Converter[Any, String] { override def convert(obj: Any): String = { - import collection.JavaConverters._ val result = obj.asInstanceOf[Result] val output = result.listCells.asScala.map(cell => Map( @@ -77,7 +76,7 @@ class StringToImmutableBytesWritableConverter extends Converter[Any, ImmutableBy */ class StringListToPutConverter extends Converter[Any, Put] { override def convert(obj: Any): Put = { - val output = obj.asInstanceOf[java.util.ArrayList[String]].map(Bytes.toBytes(_)).toArray + val output = obj.asInstanceOf[java.util.ArrayList[String]].asScala.map(Bytes.toBytes).toArray val put = new Put(output(0)) put.add(output(1), output(2), output(3)) } diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala index 02ba1c2eed0f..2dce1820d973 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala @@ -44,24 +44,12 @@ object StatefulNetworkWordCount { StreamingExamples.setStreamingLogLevels() - val updateFunc = (values: Seq[Int], state: Option[Int]) => { - val currentCount = values.sum - - val previousCount = state.getOrElse(0) - - Some(currentCount + previousCount) - } - - val newUpdateFunc = (iterator: Iterator[(String, Seq[Int], Option[Int])]) => { - iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s))) - } - val sparkConf = new SparkConf().setAppName("StatefulNetworkWordCount") // Create the context with a 1 second batch size val ssc = new StreamingContext(sparkConf, Seconds(1)) ssc.checkpoint(".") - // Initial RDD input to updateStateByKey + // Initial state RDD for mapWithState operation val initialRDD = ssc.sparkContext.parallelize(List(("hello", 1), ("world", 1))) // Create a ReceiverInputDStream on target ip:port and count the @@ -70,10 +58,17 @@ object StatefulNetworkWordCount { val words = lines.flatMap(_.split(" ")) val wordDstream = words.map(x => (x, 1)) - // Update the cumulative count using updateStateByKey - // This will give a Dstream made of state (which is the cumulative count of the words) - val stateDstream = wordDstream.updateStateByKey[Int](newUpdateFunc, - new HashPartitioner (ssc.sparkContext.defaultParallelism), true, initialRDD) + // Update the cumulative count using mapWithState + // This will give a DStream made of state (which is the cumulative count of the words) + val mappingFunc = (word: String, one: Option[Int], state: State[Int]) => { + val sum = one.getOrElse(0) + state.getOption.getOrElse(0) + val output = (word, sum) + state.update(sum) + output + } + + val stateDstream = wordDstream.mapWithState( + StateSpec.function(mappingFunc).initialState(initialRDD)) stateDstream.print() ssc.start() ssc.awaitTermination() diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala index bea7a47cb285..2fcccb22dddf 100644 --- 
a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala @@ -51,8 +51,8 @@ object PageView extends Serializable { */ // scalastyle:on object PageViewGenerator { - val pages = Map("http://foo.com/" -> .7, - "http://foo.com/news" -> 0.2, + val pages = Map("http://foo.com/" -> .7, + "http://foo.com/news" -> 0.2, "http://foo.com/contact" -> .1) val httpStatus = Map(200 -> .95, 404 -> .05) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala index ec7d39da8b2e..723616817f6a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala @@ -18,7 +18,6 @@ // scalastyle:off println package org.apache.spark.examples.streaming.clickstream -import org.apache.spark.SparkContext._ import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.examples.streaming.StreamingExamples // scalastyle:off @@ -87,8 +86,10 @@ object PageViewStream { .map("Unique active users: " + _) // An external dataset we want to join to this stream - val userList = ssc.sparkContext.parallelize( - Map(1 -> "Patrick Wendell", 2->"Reynold Xin", 3->"Matei Zaharia").toSeq) + val userList = ssc.sparkContext.parallelize(Seq( + 1 -> "Patrick Wendell", + 2 -> "Reynold Xin", + 3 -> "Matei Zaharia")) metric match { case "pageCounts" => pageCounts.print() @@ -106,6 +107,7 @@ object PageViewStream { } ssc.start() + ssc.awaitTermination() } } // scalastyle:on println diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml index 13189595d1d6..dceedcf23ed5 100644 --- a/external/flume-assembly/pom.xml +++ b/external/flume-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml @@ -68,6 +68,11 @@ commons-codec provided + + commons-lang + commons-lang + provided + commons-net commons-net @@ -88,6 +93,12 @@ avro-ipc provided + + org.apache.avro + avro-mapred + ${avro.mapred.classifier} + provided + org.scala-lang scala-library @@ -104,7 +115,6 @@ maven-shade-plugin false - ${project.build.directory}/scala-${scala.binary.version}/spark-streaming-flume-assembly-${project.version}.jar *:* diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 0664cfb2021e..75113ff753e7 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml @@ -90,6 +90,10 @@ 3.4.0.Final test + + org.apache.spark + spark-test-tags_${scala.binary.version} + target/scala-${scala.binary.version}/classes diff --git a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala index fa43629d4977..941fde45cd7b 100644 --- a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala +++ b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala @@ -20,7 +20,7 @@ import java.net.InetSocketAddress import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.{TimeUnit, CountDownLatch, Executors} -import 
scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.concurrent.{ExecutionContext, Future} import scala.util.{Failure, Success} @@ -36,11 +36,11 @@ import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory // Spark core main, which has too many dependencies to require here manually. // For this reason, we continue to use FunSuite and ignore the scalastyle checks // that fail if this is detected. -//scalastyle:off +// scalastyle:off import org.scalatest.FunSuite class SparkSinkSuite extends FunSuite { -//scalastyle:on +// scalastyle:on val eventsPerBatch = 1000 val channelCapacity = 5000 @@ -166,7 +166,7 @@ class SparkSinkSuite extends FunSuite { channelContext.put("capacity", channelCapacity.toString) channelContext.put("transactionCapacity", 1000.toString) channelContext.put("keep-alive", 0.toString) - channelContext.putAll(overrides) + channelContext.putAll(overrides.asJava) channel.setName(scala.util.Random.nextString(10)) channel.configure(channelContext) diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 14f7daaf417e..57f83607365d 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml @@ -67,14 +67,8 @@ test - junit - junit - test - - - com.novocode - junit-interface - test + org.apache.spark + spark-test-tags_${scala.binary.version} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala index 65c49c131518..48df27b26867 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala @@ -19,7 +19,7 @@ package org.apache.spark.streaming.flume import java.io.{ObjectOutput, ObjectInput} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.util.Utils import org.apache.spark.Logging @@ -60,7 +60,7 @@ private[streaming] object EventTransformer extends Logging { out.write(body) val numHeaders = headers.size() out.writeInt(numHeaders) - for ((k, v) <- headers) { + for ((k, v) <- headers.asScala) { val keyBuff = Utils.serialize(k.toString) out.writeInt(keyBuff.length) out.write(keyBuff) diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeBatchFetcher.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeBatchFetcher.scala index 88cc2aa3bf02..b9d4e762ca05 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeBatchFetcher.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeBatchFetcher.scala @@ -16,7 +16,6 @@ */ package org.apache.spark.streaming.flume -import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer import com.google.common.base.Throwables @@ -155,7 +154,7 @@ private[flume] class FlumeBatchFetcher(receiver: FlumePollingReceiver) extends R val buffer = new ArrayBuffer[SparkFlumeEvent](events.size()) var j = 0 while (j < events.size()) { - val event = events(j) + val event = events.get(j) val sparkFlumeEvent = new SparkFlumeEvent() sparkFlumeEvent.event.setBody(event.getBody) sparkFlumeEvent.event.setHeaders(event.getHeaders) diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala 
b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala index 1e32a365a1ee..2b9116eb3c79 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala @@ -22,7 +22,7 @@ import java.io.{ObjectInput, ObjectOutput, Externalizable} import java.nio.ByteBuffer import java.util.concurrent.Executors -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.flume.source.avro.AvroSourceProtocol @@ -43,7 +43,7 @@ import org.jboss.netty.handler.codec.compression._ private[streaming] class FlumeInputDStream[T: ClassTag]( - @transient ssc_ : StreamingContext, + ssc_ : StreamingContext, host: String, port: Int, storageLevel: StorageLevel, @@ -93,13 +93,13 @@ class SparkFlumeEvent() extends Externalizable { /* Serialize to bytes. */ def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { - val body = event.getBody.array() - out.writeInt(body.length) - out.write(body) + val body = event.getBody + out.writeInt(body.remaining()) + Utils.writeByteBuffer(body, out) val numHeaders = event.getHeaders.size() out.writeInt(numHeaders) - for ((k, v) <- event.getHeaders) { + for ((k, v) <- event.getHeaders.asScala) { val keyBuff = Utils.serialize(k.toString) out.writeInt(keyBuff.length) out.write(keyBuff) @@ -127,8 +127,7 @@ class FlumeEventServer(receiver : FlumeReceiver) extends AvroSourceProtocol { } override def appendBatch(events : java.util.List[AvroFlumeEvent]) : Status = { - events.foreach (event => - receiver.store(SparkFlumeEvent.fromAvroFlumeEvent(event))) + events.asScala.foreach(event => receiver.store(SparkFlumeEvent.fromAvroFlumeEvent(event))) Status.OK } } diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala index 583e7dca317a..6737750c3d63 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala @@ -18,9 +18,9 @@ package org.apache.spark.streaming.flume import java.net.InetSocketAddress -import java.util.concurrent.{LinkedBlockingQueue, Executors} +import java.util.concurrent.{Executors, LinkedBlockingQueue, TimeUnit} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import com.google.common.util.concurrent.ThreadFactoryBuilder @@ -46,7 +46,7 @@ import org.apache.spark.streaming.flume.sink._ * @tparam T Class type of the object of this stream */ private[streaming] class FlumePollingInputDStream[T: ClassTag]( - @transient _ssc: StreamingContext, + _ssc: StreamingContext, val addresses: Seq[InetSocketAddress], val maxBatchSize: Int, val parallelism: Int, @@ -93,10 +93,12 @@ private[streaming] class FlumePollingReceiver( override def onStop(): Unit = { logInfo("Shutting down Flume Polling Receiver") - receiverExecutor.shutdownNow() - connections.foreach(connection => { - connection.transceiver.close() - }) + receiverExecutor.shutdown() + // Wait upto a minute for the threads to die + if (!receiverExecutor.awaitTermination(60, TimeUnit.SECONDS)) { + receiverExecutor.shutdownNow() + } + connections.asScala.foreach(_.transceiver.close()) channelFactory.releaseExternalResources() } diff --git 
a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala index 9d9c3b189415..fe5dcc8e4b9d 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala @@ -20,8 +20,9 @@ package org.apache.spark.streaming.flume import java.net.{InetSocketAddress, ServerSocket} import java.nio.ByteBuffer import java.util.{List => JList} +import java.util.Collections -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import com.google.common.base.Charsets.UTF_8 import org.apache.avro.ipc.NettyTransceiver @@ -62,10 +63,10 @@ private[flume] class FlumeTestUtils { def writeInput(input: JList[String], enableCompression: Boolean): Unit = { val testAddress = new InetSocketAddress("localhost", testPort) - val inputEvents = input.map { item => + val inputEvents = input.asScala.map { item => val event = new AvroFlumeEvent event.setBody(ByteBuffer.wrap(item.getBytes(UTF_8))) - event.setHeaders(Map[CharSequence, CharSequence]("test" -> "header")) + event.setHeaders(Collections.singletonMap("test", "header")) event } @@ -88,7 +89,7 @@ private[flume] class FlumeTestUtils { } // Send data - val status = client.appendBatch(inputEvents.toList) + val status = client.appendBatch(inputEvents.asJava) if (status != avro.Status.OK) { throw new AssertionError("Sent events unsuccessfully") } diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala index 095bfb0c73a9..c719b80aca7e 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala @@ -21,7 +21,7 @@ import java.net.InetSocketAddress import java.io.{DataOutputStream, ByteArrayOutputStream} import java.util.{List => JList, Map => JMap} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.api.java.function.PairFunction import org.apache.spark.api.python.PythonRDD @@ -247,7 +247,7 @@ object FlumeUtils { * This is a helper class that wraps the methods in FlumeUtils into more Python-friendly class and * function so that it can be easily instantiated and called from Python's FlumeUtils. 
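[Editor's note on the collection-conversion changes above] A recurring change in these hunks is replacing the implicit scala.collection.JavaConversions wrappers with the explicit scala.collection.JavaConverters `.asScala` / `.asJava` calls, which makes every Java/Scala collection boundary visible at the call site. A minimal sketch of the pattern:

import java.util.{ArrayList => JArrayList}
import scala.collection.JavaConverters._

object JavaConvertersSketch {
  def main(args: Array[String]): Unit = {
    val javaList = new JArrayList[String]()
    javaList.add("a")
    javaList.add("b")

    // Explicitly cross into Scala, transform, and convert back for a Java API.
    val upper = javaList.asScala.map(_.toUpperCase)
    val backToJava: java.util.List[String] = upper.asJava

    println(backToJava)   // [A, B]
  }
}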
*/ -private class FlumeUtilsPythonHelper { +private[flume] class FlumeUtilsPythonHelper { def createStream( jssc: JavaStreamingContext, @@ -268,8 +268,8 @@ private class FlumeUtilsPythonHelper { maxBatchSize: Int, parallelism: Int ): JavaPairDStream[Array[Byte], Array[Byte]] = { - assert(hosts.length == ports.length) - val addresses = hosts.zip(ports).map { + assert(hosts.size() == ports.size()) + val addresses = hosts.asScala.zip(ports.asScala).map { case (host, port) => new InetSocketAddress(host, port) } val dstream = FlumeUtils.createPollingStream( @@ -286,7 +286,7 @@ private object FlumeUtilsPythonHelper { val output = new DataOutputStream(byteStream) try { output.writeInt(map.size) - map.foreach { kv => + map.asScala.foreach { kv => PythonRDD.writeUTF(kv._1.toString, output) PythonRDD.writeUTF(kv._2.toString, output) } diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala index 91d63d49dbec..bfe7548d4f50 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala @@ -18,9 +18,8 @@ package org.apache.spark.streaming.flume import java.util.concurrent._ -import java.util.{List => JList, Map => JMap} +import java.util.{Collections, List => JList, Map => JMap} -import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer import com.google.common.base.Charsets.UTF_8 @@ -77,7 +76,7 @@ private[flume] class PollingFlumeTestUtils { /** * Start 2 sinks and return the ports */ - def startMultipleSinks(): JList[Int] = { + def startMultipleSinks(): Seq[Int] = { channels.clear() sinks.clear() @@ -149,7 +148,7 @@ private[flume] class PollingFlumeTestUtils { var counter = 0 for (k <- 0 until channels.size; i <- 0 until totalEventsPerChannel) { val eventBodyToVerify = s"${channels(k).getName}-$i" - val eventHeaderToVerify: JMap[String, String] = Map[String, String](s"test-$i" -> "header") + val eventHeaderToVerify: JMap[String, String] = Collections.singletonMap(s"test-$i", "header") var found = false var j = 0 while (j < eventSize && !found) { @@ -195,7 +194,7 @@ private[flume] class PollingFlumeTestUtils { tx.begin() for (j <- 0 until eventsPerBatch) { channel.put(EventBuilder.withBody(s"${channel.getName}-$t".getBytes(UTF_8), - Map[String, String](s"test-$t" -> "header"))) + Collections.singletonMap(s"test-$t", "header"))) t += 1 } tx.commit() diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala b/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala index 1a900007b696..79077e4a49e1 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala @@ -37,7 +37,7 @@ class TestOutputStream[T: ClassTag](parent: DStream[T], extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => { val collected = rdd.collect() output += collected - }) { + }, false) { // This is to clear the output buffer every it is read from a checkpoint @throws(classOf[IOException]) diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index d5f9a0aa38f9..bb951a6ef100 100644 --- 
a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -19,16 +19,16 @@ package org.apache.spark.streaming.flume import java.net.InetSocketAddress -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} import scala.concurrent.duration._ import scala.language.postfixOps -import com.google.common.base.Charsets.UTF_8 import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually._ import org.apache.spark.{Logging, SparkConf, SparkFunSuite} +import org.apache.spark.network.util.JavaUtils import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.{Seconds, TestOutputStream, StreamingContext} @@ -116,11 +116,11 @@ class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfter with Log // The eventually is required to ensure that all data in the batch has been processed. eventually(timeout(10 seconds), interval(100 milliseconds)) { val flattenOutputBuffer = outputBuffer.flatten - val headers = flattenOutputBuffer.map(_.event.getHeaders.map { - case kv => (kv._1.toString, kv._2.toString) - }).map(mapAsJavaMap) - val bodies = flattenOutputBuffer.map(e => new String(e.event.getBody.array(), UTF_8)) - utils.assertOutput(headers, bodies) + val headers = flattenOutputBuffer.map(_.event.getHeaders.asScala.map { + case (key, value) => (key.toString, value.toString) + }).map(_.asJava) + val bodies = flattenOutputBuffer.map(e => JavaUtils.bytesToString(e.event.getBody)) + utils.assertOutput(headers.asJava, bodies.asJava) } } finally { ssc.stop() diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index 5bc4cdf65306..b29e591c0737 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.streaming.flume -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.concurrent.duration._ import scala.language.postfixOps -import com.google.common.base.Charsets import org.jboss.netty.channel.ChannelPipeline import org.jboss.netty.channel.socket.SocketChannel import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory @@ -31,6 +30,7 @@ import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ import org.apache.spark.{Logging, SparkConf, SparkFunSuite} +import org.apache.spark.network.util.JavaUtils import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream} @@ -54,7 +54,7 @@ class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers w val outputBuffer = startContext(utils.getTestPort(), testCompression) eventually(timeout(10 seconds), interval(100 milliseconds)) { - utils.writeInput(input, testCompression) + utils.writeInput(input.asJava, testCompression) } eventually(timeout(10 seconds), interval(100 milliseconds)) { @@ -63,7 +63,7 @@ class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers w event => 
event.getHeaders.get("test") should be("header") } - val output = outputEvents.map(event => new String(event.getBody.array(), Charsets.UTF_8)) + val output = outputEvents.map(event => JavaUtils.bytesToString(event.getBody)) output should be (input) } } finally { diff --git a/external/kafka-assembly/pom.xml b/external/kafka-assembly/pom.xml index 977514fa5a1e..a9ed39ef8c9a 100644 --- a/external/kafka-assembly/pom.xml +++ b/external/kafka-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml @@ -47,6 +47,90 @@ ${project.version} provided + + + commons-codec + commons-codec + provided + + + commons-lang + commons-lang + provided + + + com.google.protobuf + protobuf-java + provided + + + com.sun.jersey + jersey-server + provided + + + com.sun.jersey + jersey-core + provided + + + net.jpountz.lz4 + lz4 + provided + + + org.apache.hadoop + hadoop-client + provided + + + org.apache.avro + avro-mapred + ${avro.mapred.classifier} + provided + + + org.apache.curator + curator-recipes + provided + + + org.apache.zookeeper + zookeeper + provided + + + log4j + log4j + provided + + + net.java.dev.jets3t + jets3t + provided + + + org.scala-lang + scala-library + provided + + + org.slf4j + slf4j-api + provided + + + org.slf4j + slf4j-log4j12 + provided + + + org.xerial.snappy + snappy-java + provided + @@ -58,7 +142,6 @@ maven-shade-plugin false - ${project.build.directory}/scala-${scala.binary.version}/spark-streaming-kafka-assembly-${project.version}.jar *:* diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index ded863bd985e..79258c126e04 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml @@ -87,14 +87,8 @@ test - junit - junit - test - - - com.novocode - junit-interface - test + org.apache.spark + spark-test-tags_${scala.binary.version} diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/Broker.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/Broker.scala index 5a74febb4bd4..9159051ba06e 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/Broker.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/Broker.scala @@ -20,11 +20,9 @@ package org.apache.spark.streaming.kafka import org.apache.spark.annotation.Experimental /** - * :: Experimental :: - * Represent the host and port info for a Kafka broker. - * Differs from the Kafka project's internal kafka.cluster.Broker, which contains a server ID + * Represents the host and port info for a Kafka broker. + * Differs from the Kafka project's internal kafka.cluster.Broker, which contains a server ID. 
*/ -@Experimental final class Broker private( /** Broker's hostname */ val host: String, diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala index 48a1933d92f8..8a087474d316 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -29,7 +29,8 @@ import org.apache.spark.{Logging, SparkException} import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset -import org.apache.spark.streaming.scheduler.StreamInputInfo +import org.apache.spark.streaming.scheduler.{RateController, StreamInputInfo} +import org.apache.spark.streaming.scheduler.rate.RateEstimator /** * A stream of {@link org.apache.spark.streaming.kafka.KafkaRDD} where @@ -57,11 +58,11 @@ class DirectKafkaInputDStream[ U <: Decoder[K]: ClassTag, T <: Decoder[V]: ClassTag, R: ClassTag]( - @transient ssc_ : StreamingContext, + ssc_ : StreamingContext, val kafkaParams: Map[String, String], val fromOffsets: Map[TopicAndPartition, Long], messageHandler: MessageAndMetadata[K, V] => R -) extends InputDStream[R](ssc_) with Logging { + ) extends InputDStream[R](ssc_) with Logging { val maxRetries = context.sparkContext.getConf.getInt( "spark.streaming.kafka.maxRetries", 1) @@ -71,14 +72,40 @@ class DirectKafkaInputDStream[ protected[streaming] override val checkpointData = new DirectKafkaInputDStreamCheckpointData + + /** + * Asynchronously maintains & sends new rate limits to the receiver through the receiver tracker. 
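The rateController override added above only takes effect when backpressure is enabled in the Spark configuration, and the existing spark.streaming.kafka.maxRatePerPartition setting still acts as a hard upper bound. A sketch of a configuration that uses both (the app name is illustrative):

    import org.apache.spark.SparkConf

    object BackpressureConfSketch {
      // Backpressure lets the direct Kafka stream adapt its ingestion rate to the observed
      // processing rate; maxRatePerPartition still caps the rate from above.
      val conf = new SparkConf()
        .setAppName("direct-kafka-backpressure")
        .set("spark.streaming.backpressure.enabled", "true")
        .set("spark.streaming.kafka.maxRatePerPartition", "1000")
    }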
+ */ + override protected[streaming] val rateController: Option[RateController] = { + if (RateController.isBackPressureEnabled(ssc.conf)) { + Some(new DirectKafkaRateController(id, + RateEstimator.create(ssc.conf, context.graph.batchDuration))) + } else { + None + } + } + protected val kc = new KafkaCluster(kafkaParams) - protected val maxMessagesPerPartition: Option[Long] = { - val ratePerSec = context.sparkContext.getConf.getInt( + private val maxRateLimitPerPartition: Int = context.sparkContext.getConf.getInt( "spark.streaming.kafka.maxRatePerPartition", 0) - if (ratePerSec > 0) { + protected def maxMessagesPerPartition: Option[Long] = { + val estimatedRateLimit = rateController.map(_.getLatestRate().toInt) + val numPartitions = currentOffsets.keys.size + + val effectiveRateLimitPerPartition = estimatedRateLimit + .filter(_ > 0) + .map { limit => + if (maxRateLimitPerPartition > 0) { + Math.min(maxRateLimitPerPartition, (limit / numPartitions)) + } else { + limit / numPartitions + } + }.getOrElse(maxRateLimitPerPartition) + + if (effectiveRateLimitPerPartition > 0) { val secsPerBatch = context.graph.batchDuration.milliseconds.toDouble / 1000 - Some((secsPerBatch * ratePerSec).toLong) + Some((secsPerBatch * effectiveRateLimitPerPartition).toLong) } else { None } @@ -170,11 +197,18 @@ class DirectKafkaInputDStream[ val leaders = KafkaCluster.checkErrors(kc.findLeaders(topics)) batchForTime.toSeq.sortBy(_._1)(Time.ordering).foreach { case (t, b) => - logInfo(s"Restoring KafkaRDD for time $t ${b.mkString("[", ", ", "]")}") - generatedRDDs += t -> new KafkaRDD[K, V, U, T, R]( - context.sparkContext, kafkaParams, b.map(OffsetRange(_)), leaders, messageHandler) + logInfo(s"Restoring KafkaRDD for time $t ${b.mkString("[", ", ", "]")}") + generatedRDDs += t -> new KafkaRDD[K, V, U, T, R]( + context.sparkContext, kafkaParams, b.map(OffsetRange(_)), leaders, messageHandler) } } } + /** + * A RateController to retrieve the rate from RateEstimator. 
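The refactored maxMessagesPerPartition above combines the estimated rate from the rate controller with the static per-partition limit. The same arithmetic, pulled out into a standalone function for clarity (parameter names are illustrative, not from the patch):

    object MaxMessagesSketch {
      // Divide the estimated rate across partitions, cap it by
      // spark.streaming.kafka.maxRatePerPartition when that is set, then scale by the
      // batch duration to get the per-partition message cap for the next batch.
      def maxMessagesPerPartition(
          estimatedRateLimit: Option[Int],
          maxRateLimitPerPartition: Int,
          numPartitions: Int,
          batchDurationMs: Long): Option[Long] = {
        val effectiveRateLimitPerPartition = estimatedRateLimit
          .filter(_ > 0)
          .map { limit =>
            if (maxRateLimitPerPartition > 0) {
              math.min(maxRateLimitPerPartition, limit / numPartitions)
            } else {
              limit / numPartitions
            }
          }.getOrElse(maxRateLimitPerPartition)

        if (effectiveRateLimitPerPartition > 0) {
          val secsPerBatch = batchDurationMs.toDouble / 1000
          Some((secsPerBatch * effectiveRateLimitPerPartition).toLong)
        } else {
          None
        }
      }
    }

For example, an estimated rate of 1000 messages/sec over 4 partitions with a 500 ms batch and no static cap yields Some(125).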
+ */ + private[streaming] class DirectKafkaRateController(id: Int, estimator: RateEstimator) + extends RateController(id, estimator) { + override def publish(rate: Long): Unit = () + } } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala index 04b2dc10d39e..38730fecf332 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala @@ -48,7 +48,7 @@ class KafkaInputDStream[ V: ClassTag, U <: Decoder[_]: ClassTag, T <: Decoder[_]: ClassTag]( - @transient ssc_ : StreamingContext, + ssc_ : StreamingContext, kafkaParams: Map[String, String], topics: Map[String, Int], useReliableReceiver: Boolean, diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala index 1a9d78c0d4f5..ea5f842c6caf 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala @@ -197,7 +197,11 @@ class KafkaRDD[ .dropWhile(_.offset < requestOffset) } - override def close(): Unit = consumer.close() + override def close(): Unit = { + if (consumer != null) { + consumer.close() + } + } override def getNext(): R = { if (iter == null || !iter.hasNext) { diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala index b608b7595272..45a6982b9afe 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala @@ -20,27 +20,26 @@ package org.apache.spark.streaming.kafka import java.io.File import java.lang.{Integer => JInt} import java.net.InetSocketAddress -import java.util.{Map => JMap} -import java.util.Properties import java.util.concurrent.TimeoutException +import java.util.{Map => JMap, Properties} import scala.annotation.tailrec +import scala.collection.JavaConverters._ import scala.language.postfixOps import scala.util.control.NonFatal import kafka.admin.AdminUtils import kafka.api.Request -import kafka.common.TopicAndPartition import kafka.producer.{KeyedMessage, Producer, ProducerConfig} import kafka.serializer.StringEncoder import kafka.server.{KafkaConfig, KafkaServer} import kafka.utils.{ZKStringSerializer, ZkUtils} -import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} import org.I0Itec.zkclient.ZkClient +import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} -import org.apache.spark.{Logging, SparkConf} import org.apache.spark.streaming.Time import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SparkConf} /** * This is a helper class for Kafka test suites. This has the functionality to set up @@ -48,12 +47,12 @@ import org.apache.spark.util.Utils * * The reason to put Kafka test utility class in src is to test Python related Kafka APIs. 
*/ -private class KafkaTestUtils extends Logging { +private[kafka] class KafkaTestUtils extends Logging { // Zookeeper related configurations private val zkHost = "localhost" private var zkPort: Int = 0 - private val zkConnectionTimeout = 6000 + private val zkConnectionTimeout = 60000 private val zkSessionTimeout = 6000 private var zookeeper: EmbeddedZookeeper = _ @@ -152,7 +151,7 @@ private class KafkaTestUtils extends Logging { } } - /** Create a Kafka topic and wait until it propagated to the whole cluster */ + /** Create a Kafka topic and wait until it is propagated to the whole cluster */ def createTopic(topic: String): Unit = { AdminUtils.createTopic(zkClient, topic, 1, 1) // wait until metadata is propagated @@ -161,8 +160,7 @@ private class KafkaTestUtils extends Logging { /** Java-friendly function for sending messages to the Kafka broker */ def sendMessages(topic: String, messageToFreq: JMap[String, JInt]): Unit = { - import scala.collection.JavaConversions._ - sendMessages(topic, Map(messageToFreq.mapValues(_.intValue()).toSeq: _*)) + sendMessages(topic, Map(messageToFreq.asScala.mapValues(_.intValue()).toSeq: _*)) } /** Send the messages to the Kafka broker */ diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index f3b01bd60b17..fe572220528d 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -17,29 +17,29 @@ package org.apache.spark.streaming.kafka -import java.lang.{Integer => JInt} -import java.lang.{Long => JLong} -import java.util.{Map => JMap} -import java.util.{Set => JSet} -import java.util.{List => JList} +import java.io.OutputStream +import java.lang.{Integer => JInt, Long => JLong} +import java.util.{List => JList, Map => JMap, Set => JSet} +import scala.collection.JavaConverters._ import scala.reflect.ClassTag -import scala.collection.JavaConversions._ +import com.google.common.base.Charsets.UTF_8 import kafka.common.TopicAndPartition import kafka.message.MessageAndMetadata import kafka.serializer.{DefaultDecoder, Decoder, StringDecoder} +import net.razorvine.pickle.{Opcodes, Pickler, IObjectPickler} import org.apache.spark.api.java.function.{Function => JFunction} import org.apache.spark.streaming.util.WriteAheadLogUtils import org.apache.spark.{SparkContext, SparkException} -import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} +import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.{JavaPairInputDStream, JavaInputDStream, JavaPairReceiverInputDStream, JavaStreamingContext} -import org.apache.spark.streaming.dstream.{InputDStream, ReceiverInputDStream} -import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} +import org.apache.spark.streaming.api.java._ +import org.apache.spark.streaming.dstream.{DStream, InputDStream, ReceiverInputDStream} object KafkaUtils { /** @@ -51,6 +51,7 @@ object KafkaUtils { * in its own thread * @param storageLevel Storage level to use for storing the received objects * (default: StorageLevel.MEMORY_AND_DISK_SER_2) + * @return DStream of (Kafka message key, Kafka message value) */ def createStream( ssc: StreamingContext, @@ -74,6 +75,11 @@ 
object KafkaUtils { * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. * @param storageLevel Storage level to use for storing the received objects + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam U type of Kafka message key decoder + * @tparam T type of Kafka message value decoder + * @return DStream of (Kafka message key, Kafka message value) */ def createStream[K: ClassTag, V: ClassTag, U <: Decoder[_]: ClassTag, T <: Decoder[_]: ClassTag]( ssc: StreamingContext, @@ -93,6 +99,7 @@ object KafkaUtils { * @param groupId The group id for this consumer * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread + * @return DStream of (Kafka message key, Kafka message value) */ def createStream( jssc: JavaStreamingContext, @@ -100,7 +107,7 @@ object KafkaUtils { groupId: String, topics: JMap[String, JInt] ): JavaPairReceiverInputDStream[String, String] = { - createStream(jssc.ssc, zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*)) + createStream(jssc.ssc, zkQuorum, groupId, Map(topics.asScala.mapValues(_.intValue()).toSeq: _*)) } /** @@ -111,6 +118,7 @@ object KafkaUtils { * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. * @param storageLevel RDD storage level. + * @return DStream of (Kafka message key, Kafka message value) */ def createStream( jssc: JavaStreamingContext, @@ -119,7 +127,7 @@ object KafkaUtils { topics: JMap[String, JInt], storageLevel: StorageLevel ): JavaPairReceiverInputDStream[String, String] = { - createStream(jssc.ssc, zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*), + createStream(jssc.ssc, zkQuorum, groupId, Map(topics.asScala.mapValues(_.intValue()).toSeq: _*), storageLevel) } @@ -135,6 +143,11 @@ object KafkaUtils { * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread * @param storageLevel RDD storage level. + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam U type of Kafka message key decoder + * @tparam T type of Kafka message value decoder + * @return DStream of (Kafka message key, Kafka message value) */ def createStream[K, V, U <: Decoder[_], T <: Decoder[_]]( jssc: JavaStreamingContext, @@ -153,7 +166,10 @@ object KafkaUtils { implicit val valueCmd: ClassTag[T] = ClassTag(valueDecoderClass) createStream[K, V, U, T]( - jssc.ssc, kafkaParams.toMap, Map(topics.mapValues(_.intValue()).toSeq: _*), storageLevel) + jssc.ssc, + kafkaParams.asScala.toMap, + Map(topics.asScala.mapValues(_.intValue()).toSeq: _*), + storageLevel) } /** get leaders for the given offset ranges, or throw an exception */ @@ -185,6 +201,27 @@ object KafkaUtils { } } + private[kafka] def getFromOffsets( + kc: KafkaCluster, + kafkaParams: Map[String, String], + topics: Set[String] + ): Map[TopicAndPartition, Long] = { + val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase) + val result = for { + topicPartitions <- kc.getPartitions(topics).right + leaderOffsets <- (if (reset == Some("smallest")) { + kc.getEarliestLeaderOffsets(topicPartitions) + } else { + kc.getLatestLeaderOffsets(topicPartitions) + }).right + } yield { + leaderOffsets.map { case (tp, lo) => + (tp, lo.offset) + } + } + KafkaCluster.checkErrors(result) + } + /** * Create a RDD from Kafka using offset ranges for each topic and partition. 
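The next hunks update the docs for the createRDD overloads. For context, a typical call site for the simplest Scala overload; the broker address, topic name, and offsets are placeholders:

    import kafka.serializer.StringDecoder
    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.streaming.kafka.{KafkaUtils, OffsetRange}

    object KafkaRDDSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(
          new SparkConf().setAppName("kafka-rdd-sketch").setMaster("local[2]"))
        val kafkaParams = Map("metadata.broker.list" -> "localhost:9092")
        // One batch worth of offsets for partition 0 of a hypothetical "events" topic.
        val offsetRanges = Array(OffsetRange("events", 0, 0L, 100L))
        val rdd = KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder](
          sc, kafkaParams, offsetRanges)
        println(rdd.map(_._2).count())
        sc.stop()
      }
    }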
* @@ -195,8 +232,12 @@ object KafkaUtils { * host1:port1,host2:port2 form. * @param offsetRanges Each OffsetRange in the batch corresponds to a * range of offsets for a given Kafka topic/partition + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam KD type of Kafka message key decoder + * @tparam VD type of Kafka message value decoder + * @return RDD of (Kafka message key, Kafka message value) */ - @Experimental def createRDD[ K: ClassTag, V: ClassTag, @@ -214,7 +255,6 @@ object KafkaUtils { } /** - * :: Experimental :: * Create a RDD from Kafka using offset ranges for each topic and partition. This allows you * specify the Kafka leader to connect to (to optimize fetching) and access the message as well * as the metadata. @@ -229,8 +269,13 @@ object KafkaUtils { * @param leaders Kafka brokers for each TopicAndPartition in offsetRanges. May be an empty map, * in which case leaders will be looked up on the driver. * @param messageHandler Function for translating each message and metadata into the desired type + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam KD type of Kafka message key decoder + * @tparam VD type of Kafka message value decoder + * @tparam R type returned by messageHandler + * @return RDD of R */ - @Experimental def createRDD[ K: ClassTag, V: ClassTag, @@ -250,7 +295,7 @@ object KafkaUtils { // This could be avoided by refactoring KafkaRDD.leaders and KafkaCluster to use Broker leaders.map { case (tp: TopicAndPartition, Broker(host, port)) => (tp, (host, port)) - }.toMap + } } val cleanedHandler = sc.clean(messageHandler) checkOffsets(kc, offsetRanges) @@ -267,8 +312,16 @@ object KafkaUtils { * host1:port1,host2:port2 form. * @param offsetRanges Each OffsetRange in the batch corresponds to a * range of offsets for a given Kafka topic/partition + * @param keyClass type of Kafka message key + * @param valueClass type of Kafka message value + * @param keyDecoderClass type of Kafka message key decoder + * @param valueDecoderClass type of Kafka message value decoder + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam KD type of Kafka message key decoder + * @tparam VD type of Kafka message value decoder + * @return RDD of (Kafka message key, Kafka message value) */ - @Experimental def createRDD[K, V, KD <: Decoder[K], VD <: Decoder[V]]( jsc: JavaSparkContext, keyClass: Class[K], @@ -283,11 +336,10 @@ object KafkaUtils { implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass) implicit val valueDecoderCmt: ClassTag[VD] = ClassTag(valueDecoderClass) new JavaPairRDD(createRDD[K, V, KD, VD]( - jsc.sc, Map(kafkaParams.toSeq: _*), offsetRanges)) + jsc.sc, Map(kafkaParams.asScala.toSeq: _*), offsetRanges)) } /** - * :: Experimental :: * Create a RDD from Kafka using offset ranges for each topic and partition. This allows you * specify the Kafka leader to connect to (to optimize fetching) and access the message as well * as the metadata. @@ -302,8 +354,13 @@ object KafkaUtils { * @param leaders Kafka brokers for each TopicAndPartition in offsetRanges. May be an empty map, * in which case leaders will be looked up on the driver. 
* @param messageHandler Function for translating each message and metadata into the desired type + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam KD type of Kafka message key decoder + * @tparam VD type of Kafka message value decoder + * @tparam R type returned by messageHandler + * @return RDD of R */ - @Experimental def createRDD[K, V, KD <: Decoder[K], VD <: Decoder[V], R]( jsc: JavaSparkContext, keyClass: Class[K], @@ -321,13 +378,12 @@ object KafkaUtils { implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass) implicit val valueDecoderCmt: ClassTag[VD] = ClassTag(valueDecoderClass) implicit val recordCmt: ClassTag[R] = ClassTag(recordClass) - val leaderMap = Map(leaders.toSeq: _*) + val leaderMap = Map(leaders.asScala.toSeq: _*) createRDD[K, V, KD, VD, R]( - jsc.sc, Map(kafkaParams.toSeq: _*), offsetRanges, leaderMap, messageHandler.call _) + jsc.sc, Map(kafkaParams.asScala.toSeq: _*), offsetRanges, leaderMap, messageHandler.call(_)) } /** - * :: Experimental :: * Create an input stream that directly pulls messages from Kafka Brokers * without using any receiver. This stream can guarantee that each message * from Kafka is included in transformations exactly once (see points below). @@ -356,8 +412,13 @@ object KafkaUtils { * @param fromOffsets Per-topic/partition Kafka offsets defining the (inclusive) * starting point of the stream * @param messageHandler Function for translating each message and metadata into the desired type + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam KD type of Kafka message key decoder + * @tparam VD type of Kafka message value decoder + * @tparam R type returned by messageHandler + * @return DStream of R */ - @Experimental def createDirectStream[ K: ClassTag, V: ClassTag, @@ -375,7 +436,6 @@ object KafkaUtils { } /** - * :: Experimental :: * Create an input stream that directly pulls messages from Kafka Brokers * without using any receiver. This stream can guarantee that each message * from Kafka is included in transformations exactly once (see points below). 
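A sketch of how the fromOffsets/messageHandler variant documented above is typically called from Scala; the broker address, topic, and starting offset are assumptions, not part of this change:

    import kafka.common.TopicAndPartition
    import kafka.message.MessageAndMetadata
    import kafka.serializer.StringDecoder
    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.{Seconds, StreamingContext}
    import org.apache.spark.streaming.kafka.KafkaUtils

    object DirectStreamFromOffsetsSketch {
      def main(args: Array[String]): Unit = {
        val ssc = new StreamingContext(
          new SparkConf().setAppName("direct-from-offsets").setMaster("local[2]"), Seconds(2))
        val kafkaParams = Map("metadata.broker.list" -> "localhost:9092")
        // Resume partition 0 of a hypothetical "events" topic from offset 0.
        val fromOffsets = Map(TopicAndPartition("events", 0) -> 0L)
        // The message handler decides what each Kafka record becomes in the DStream.
        val messageHandler = (mmd: MessageAndMetadata[String, String]) => (mmd.offset, mmd.message)
        val stream = KafkaUtils.createDirectStream[
          String, String, StringDecoder, StringDecoder, (Long, String)](
          ssc, kafkaParams, fromOffsets, messageHandler)
        stream.print()
        ssc.start()
        ssc.awaitTermination()
      }
    }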
@@ -404,8 +464,12 @@ object KafkaUtils { * If not starting from a checkpoint, "auto.offset.reset" may be set to "largest" or "smallest" * to determine where the stream starts (defaults to "largest") * @param topics Names of the topics to consume + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam KD type of Kafka message key decoder + * @tparam VD type of Kafka message value decoder + * @return DStream of (Kafka message key, Kafka message value) */ - @Experimental def createDirectStream[ K: ClassTag, V: ClassTag, @@ -417,27 +481,12 @@ object KafkaUtils { ): InputDStream[(K, V)] = { val messageHandler = (mmd: MessageAndMetadata[K, V]) => (mmd.key, mmd.message) val kc = new KafkaCluster(kafkaParams) - val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase) - - val result = for { - topicPartitions <- kc.getPartitions(topics).right - leaderOffsets <- (if (reset == Some("smallest")) { - kc.getEarliestLeaderOffsets(topicPartitions) - } else { - kc.getLatestLeaderOffsets(topicPartitions) - }).right - } yield { - val fromOffsets = leaderOffsets.map { case (tp, lo) => - (tp, lo.offset) - } - new DirectKafkaInputDStream[K, V, KD, VD, (K, V)]( - ssc, kafkaParams, fromOffsets, messageHandler) - } - KafkaCluster.checkErrors(result) + val fromOffsets = getFromOffsets(kc, kafkaParams, topics) + new DirectKafkaInputDStream[K, V, KD, VD, (K, V)]( + ssc, kafkaParams, fromOffsets, messageHandler) } /** - * :: Experimental :: * Create an input stream that directly pulls messages from Kafka Brokers * without using any receiver. This stream can guarantee that each message * from Kafka is included in transformations exactly once (see points below). @@ -471,8 +520,13 @@ object KafkaUtils { * @param fromOffsets Per-topic/partition Kafka offsets defining the (inclusive) * starting point of the stream * @param messageHandler Function for translating each message and metadata into the desired type + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam KD type of Kafka message key decoder + * @tparam VD type of Kafka message value decoder + * @tparam R type returned by messageHandler + * @return DStream of R */ - @Experimental def createDirectStream[K, V, KD <: Decoder[K], VD <: Decoder[V], R]( jssc: JavaStreamingContext, keyClass: Class[K], @@ -492,14 +546,13 @@ object KafkaUtils { val cleanedHandler = jssc.sparkContext.clean(messageHandler.call _) createDirectStream[K, V, KD, VD, R]( jssc.ssc, - Map(kafkaParams.toSeq: _*), - Map(fromOffsets.mapValues { _.longValue() }.toSeq: _*), + Map(kafkaParams.asScala.toSeq: _*), + Map(fromOffsets.asScala.mapValues(_.longValue()).toSeq: _*), cleanedHandler ) } /** - * :: Experimental :: * Create an input stream that directly pulls messages from Kafka Brokers * without using any receiver. This stream can guarantee that each message * from Kafka is included in transformations exactly once (see points below). 
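And the topics-only overload refactored in the hunk above, which now resolves its starting offsets through the new getFromOffsets helper on the driver; the broker address and topic name are placeholders:

    import kafka.serializer.StringDecoder
    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.{Seconds, StreamingContext}
    import org.apache.spark.streaming.kafka.KafkaUtils

    object DirectStreamSketch {
      def main(args: Array[String]): Unit = {
        val ssc = new StreamingContext(
          new SparkConf().setAppName("direct-kafka").setMaster("local[2]"), Seconds(2))
        val kafkaParams = Map(
          "metadata.broker.list" -> "localhost:9092",
          // Start from the earliest available offsets, as the test suites below also do.
          "auto.offset.reset" -> "smallest")
        val stream = KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder](
          ssc, kafkaParams, Set("events"))
        stream.map(_._2).print()
        ssc.start()
        ssc.awaitTermination()
      }
    }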
@@ -532,8 +585,12 @@ object KafkaUtils { * If not starting from a checkpoint, "auto.offset.reset" may be set to "largest" or "smallest" * to determine where the stream starts (defaults to "largest") * @param topics Names of the topics to consume + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam KD type of Kafka message key decoder + * @tparam VD type of Kafka message value decoder + * @return DStream of (Kafka message key, Kafka message value) */ - @Experimental def createDirectStream[K, V, KD <: Decoder[K], VD <: Decoder[V]]( jssc: JavaStreamingContext, keyClass: Class[K], @@ -549,8 +606,8 @@ object KafkaUtils { implicit val valueDecoderCmt: ClassTag[VD] = ClassTag(valueDecoderClass) createDirectStream[K, V, KD, VD]( jssc.ssc, - Map(kafkaParams.toSeq: _*), - Set(topics.toSeq: _*) + Map(kafkaParams.asScala.toSeq: _*), + Set(topics.asScala.toSeq: _*) ) } } @@ -564,7 +621,9 @@ object KafkaUtils { * classOf[KafkaUtilsPythonHelper].newInstance(), and the createStream() * takes care of known parameters instead of passing them from Python */ -private class KafkaUtilsPythonHelper { +private[kafka] class KafkaUtilsPythonHelper { + import KafkaUtilsPythonHelper._ + def createStream( jssc: JavaStreamingContext, kafkaParams: JMap[String, String], @@ -581,86 +640,92 @@ private class KafkaUtilsPythonHelper { storageLevel) } - def createRDD( + def createRDDWithoutMessageHandler( jsc: JavaSparkContext, kafkaParams: JMap[String, String], offsetRanges: JList[OffsetRange], - leaders: JMap[TopicAndPartition, Broker]): JavaPairRDD[Array[Byte], Array[Byte]] = { - val messageHandler = new JFunction[MessageAndMetadata[Array[Byte], Array[Byte]], - (Array[Byte], Array[Byte])] { - def call(t1: MessageAndMetadata[Array[Byte], Array[Byte]]): (Array[Byte], Array[Byte]) = - (t1.key(), t1.message()) - } + leaders: JMap[TopicAndPartition, Broker]): JavaRDD[(Array[Byte], Array[Byte])] = { + val messageHandler = + (mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) => (mmd.key, mmd.message) + new JavaRDD(createRDD(jsc, kafkaParams, offsetRanges, leaders, messageHandler)) + } + + def createRDDWithMessageHandler( + jsc: JavaSparkContext, + kafkaParams: JMap[String, String], + offsetRanges: JList[OffsetRange], + leaders: JMap[TopicAndPartition, Broker]): JavaRDD[Array[Byte]] = { + val messageHandler = (mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) => + new PythonMessageAndMetadata( + mmd.topic, mmd.partition, mmd.offset, mmd.key(), mmd.message()) + val rdd = createRDD(jsc, kafkaParams, offsetRanges, leaders, messageHandler). 
+ mapPartitions(picklerIterator) + new JavaRDD(rdd) + } + + private def createRDD[V: ClassTag]( + jsc: JavaSparkContext, + kafkaParams: JMap[String, String], + offsetRanges: JList[OffsetRange], + leaders: JMap[TopicAndPartition, Broker], + messageHandler: MessageAndMetadata[Array[Byte], Array[Byte]] => V): RDD[V] = { + KafkaUtils.createRDD[Array[Byte], Array[Byte], DefaultDecoder, DefaultDecoder, V]( + jsc.sc, + kafkaParams.asScala.toMap, + offsetRanges.toArray(new Array[OffsetRange](offsetRanges.size())), + leaders.asScala.toMap, + messageHandler + ) + } + + def createDirectStreamWithoutMessageHandler( + jssc: JavaStreamingContext, + kafkaParams: JMap[String, String], + topics: JSet[String], + fromOffsets: JMap[TopicAndPartition, JLong]): JavaDStream[(Array[Byte], Array[Byte])] = { + val messageHandler = + (mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) => (mmd.key, mmd.message) + new JavaDStream(createDirectStream(jssc, kafkaParams, topics, fromOffsets, messageHandler)) + } - val jrdd = KafkaUtils.createRDD[ - Array[Byte], - Array[Byte], - DefaultDecoder, - DefaultDecoder, - (Array[Byte], Array[Byte])]( - jsc, - classOf[Array[Byte]], - classOf[Array[Byte]], - classOf[DefaultDecoder], - classOf[DefaultDecoder], - classOf[(Array[Byte], Array[Byte])], - kafkaParams, - offsetRanges.toArray(new Array[OffsetRange](offsetRanges.size())), - leaders, - messageHandler - ) - new JavaPairRDD(jrdd.rdd) + def createDirectStreamWithMessageHandler( + jssc: JavaStreamingContext, + kafkaParams: JMap[String, String], + topics: JSet[String], + fromOffsets: JMap[TopicAndPartition, JLong]): JavaDStream[Array[Byte]] = { + val messageHandler = (mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) => + new PythonMessageAndMetadata(mmd.topic, mmd.partition, mmd.offset, mmd.key(), mmd.message()) + val stream = createDirectStream(jssc, kafkaParams, topics, fromOffsets, messageHandler). 
+ mapPartitions(picklerIterator) + new JavaDStream(stream) } - def createDirectStream( + private def createDirectStream[V: ClassTag]( jssc: JavaStreamingContext, kafkaParams: JMap[String, String], topics: JSet[String], - fromOffsets: JMap[TopicAndPartition, JLong] - ): JavaPairInputDStream[Array[Byte], Array[Byte]] = { - - if (!fromOffsets.isEmpty) { - import scala.collection.JavaConversions._ - val topicsFromOffsets = fromOffsets.keySet().map(_.topic) - if (topicsFromOffsets != topics.toSet) { - throw new IllegalStateException(s"The specified topics: ${topics.toSet.mkString(" ")} " + + fromOffsets: JMap[TopicAndPartition, JLong], + messageHandler: MessageAndMetadata[Array[Byte], Array[Byte]] => V): DStream[V] = { + + val currentFromOffsets = if (!fromOffsets.isEmpty) { + val topicsFromOffsets = fromOffsets.keySet().asScala.map(_.topic) + if (topicsFromOffsets != topics.asScala.toSet) { + throw new IllegalStateException( + s"The specified topics: ${topics.asScala.toSet.mkString(" ")} " + s"do not equal to the topic from offsets: ${topicsFromOffsets.mkString(" ")}") } - } - - if (fromOffsets.isEmpty) { - KafkaUtils.createDirectStream[Array[Byte], Array[Byte], DefaultDecoder, DefaultDecoder]( - jssc, - classOf[Array[Byte]], - classOf[Array[Byte]], - classOf[DefaultDecoder], - classOf[DefaultDecoder], - kafkaParams, - topics) + Map(fromOffsets.asScala.mapValues { _.longValue() }.toSeq: _*) } else { - val messageHandler = new JFunction[MessageAndMetadata[Array[Byte], Array[Byte]], - (Array[Byte], Array[Byte])] { - def call(t1: MessageAndMetadata[Array[Byte], Array[Byte]]): (Array[Byte], Array[Byte]) = - (t1.key(), t1.message()) - } - - val jstream = KafkaUtils.createDirectStream[ - Array[Byte], - Array[Byte], - DefaultDecoder, - DefaultDecoder, - (Array[Byte], Array[Byte])]( - jssc, - classOf[Array[Byte]], - classOf[Array[Byte]], - classOf[DefaultDecoder], - classOf[DefaultDecoder], - classOf[(Array[Byte], Array[Byte])], - kafkaParams, - fromOffsets, - messageHandler) - new JavaPairInputDStream(jstream.inputDStream) + val kc = new KafkaCluster(Map(kafkaParams.asScala.toSeq: _*)) + KafkaUtils.getFromOffsets( + kc, Map(kafkaParams.asScala.toSeq: _*), Set(topics.asScala.toSeq: _*)) } + + KafkaUtils.createDirectStream[Array[Byte], Array[Byte], DefaultDecoder, DefaultDecoder, V]( + jssc.ssc, + Map(kafkaParams.asScala.toSeq: _*), + Map(currentFromOffsets.toSeq: _*), + messageHandler) } def createOffsetRange(topic: String, partition: JInt, fromOffset: JLong, untilOffset: JLong @@ -681,6 +746,60 @@ private class KafkaUtilsPythonHelper { "with this RDD, please call this method only on a Kafka RDD.") val kafkaRDD = kafkaRDDs.head.asInstanceOf[KafkaRDD[_, _, _, _, _]] - kafkaRDD.offsetRanges.toSeq + kafkaRDD.offsetRanges.toSeq.asJava + } +} + +private object KafkaUtilsPythonHelper { + private var initialized = false + + def initialize(): Unit = { + SerDeUtil.initialize() + synchronized { + if (!initialized) { + new PythonMessageAndMetadataPickler().register() + initialized = true + } + } + } + + initialize() + + def picklerIterator(iter: Iterator[Any]): Iterator[Array[Byte]] = { + new SerDeUtil.AutoBatchedPickler(iter) + } + + case class PythonMessageAndMetadata( + topic: String, + partition: JInt, + offset: JLong, + key: Array[Byte], + message: Array[Byte]) + + class PythonMessageAndMetadataPickler extends IObjectPickler { + private val module = "pyspark.streaming.kafka" + + def register(): Unit = { + Pickler.registerCustomPickler(classOf[PythonMessageAndMetadata], this) + 
Pickler.registerCustomPickler(this.getClass, this) + } + + def pickle(obj: Object, out: OutputStream, pickler: Pickler) { + if (obj == this) { + out.write(Opcodes.GLOBAL) + out.write(s"$module\nKafkaMessageAndMetadata\n".getBytes(UTF_8)) + } else { + pickler.save(this) + val msgAndMetaData = obj.asInstanceOf[PythonMessageAndMetadata] + out.write(Opcodes.MARK) + pickler.save(msgAndMetaData.topic) + pickler.save(msgAndMetaData.partition) + pickler.save(msgAndMetaData.offset) + pickler.save(msgAndMetaData.key) + pickler.save(msgAndMetaData.message) + out.write(Opcodes.TUPLE) + out.write(Opcodes.REDUCE) + } + } } } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala index f326e7f1f6f8..d9b856e4697a 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala @@ -19,11 +19,8 @@ package org.apache.spark.streaming.kafka import kafka.common.TopicAndPartition -import org.apache.spark.annotation.Experimental - /** - * :: Experimental :: - * Represents any object that has a collection of [[OffsetRange]]s. This can be used access the + * Represents any object that has a collection of [[OffsetRange]]s. This can be used to access the * offset ranges in RDDs generated by the direct Kafka DStream (see * [[KafkaUtils.createDirectStream()]]). * {{{ @@ -33,25 +30,22 @@ import org.apache.spark.annotation.Experimental * } * }}} */ -@Experimental trait HasOffsetRanges { def offsetRanges: Array[OffsetRange] } /** - * :: Experimental :: * Represents a range of offsets from a single Kafka TopicAndPartition. Instances of this class * can be created with `OffsetRange.create()`. + * @param topic Kafka topic name + * @param partition Kafka partition id + * @param fromOffset Inclusive starting offset + * @param untilOffset Exclusive ending offset */ -@Experimental final class OffsetRange private( - /** Kafka topic name */ val topic: String, - /** Kafka partition id */ val partition: Int, - /** inclusive starting offset */ val fromOffset: Long, - /** exclusive ending offset */ val untilOffset: Long) extends Serializable { import OffsetRange.OffsetRangeTuple @@ -84,10 +78,8 @@ final class OffsetRange private( } /** - * :: Experimental :: * Companion object the provides methods to create instances of [[OffsetRange]]. */ -@Experimental object OffsetRange { def create(topic: String, partition: Int, fromOffset: Long, untilOffset: Long): OffsetRange = new OffsetRange(topic, partition, fromOffset, untilOffset) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala index 75f0dfc22b9d..764d170934aa 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala @@ -96,7 +96,7 @@ class ReliableKafkaReceiver[ blockOffsetMap = new ConcurrentHashMap[StreamBlockId, Map[TopicAndPartition, Long]]() // Initialize the block generator for storing Kafka message. 
- blockGenerator = new BlockGenerator(new GeneratedBlockHandler, streamId, conf) + blockGenerator = supervisor.createBlockGenerator(new GeneratedBlockHandler) if (kafkaParams.contains(AUTO_OFFSET_COMMIT) && kafkaParams(AUTO_OFFSET_COMMIT) == "true") { logWarning(s"$AUTO_OFFSET_COMMIT should be set to false in ReliableKafkaReceiver, " + diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java index 02cd24a35906..fbdfbf7e509b 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java @@ -70,16 +70,16 @@ public void testKafkaStream() throws InterruptedException { final String topic1 = "topic1"; final String topic2 = "topic2"; // hold a reference to the current offset ranges, so it can be used downstream - final AtomicReference offsetRanges = new AtomicReference(); + final AtomicReference offsetRanges = new AtomicReference<>(); String[] topic1data = createTopicAndSendData(topic1); String[] topic2data = createTopicAndSendData(topic2); - HashSet sent = new HashSet(); + Set sent = new HashSet<>(); sent.addAll(Arrays.asList(topic1data)); sent.addAll(Arrays.asList(topic2data)); - HashMap kafkaParams = new HashMap(); + Map kafkaParams = new HashMap<>(); kafkaParams.put("metadata.broker.list", kafkaTestUtils.brokerAddress()); kafkaParams.put("auto.offset.reset", "smallest"); @@ -95,17 +95,17 @@ public void testKafkaStream() throws InterruptedException { // Make sure you can get offset ranges from the rdd new Function, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaPairRDD rdd) throws Exception { + public JavaPairRDD call(JavaPairRDD rdd) { OffsetRange[] offsets = ((HasOffsetRanges) rdd.rdd()).offsetRanges(); offsetRanges.set(offsets); - Assert.assertEquals(offsets[0].topic(), topic1); + Assert.assertEquals(topic1, offsets[0].topic()); return rdd; } } ).map( new Function, String>() { @Override - public String call(Tuple2 kv) throws Exception { + public String call(Tuple2 kv) { return kv._2(); } } @@ -119,10 +119,10 @@ public String call(Tuple2 kv) throws Exception { StringDecoder.class, String.class, kafkaParams, - topicOffsetToMap(topic2, (long) 0), + topicOffsetToMap(topic2, 0L), new Function, String>() { @Override - public String call(MessageAndMetadata msgAndMd) throws Exception { + public String call(MessageAndMetadata msgAndMd) { return msgAndMd.message(); } } @@ -133,7 +133,7 @@ public String call(MessageAndMetadata msgAndMd) throws Exception unifiedStream.foreachRDD( new Function, Void>() { @Override - public Void call(JavaRDD rdd) throws Exception { + public Void call(JavaRDD rdd) { result.addAll(rdd.collect()); for (OffsetRange o : offsetRanges.get()) { System.out.println( @@ -155,14 +155,14 @@ public Void call(JavaRDD rdd) throws Exception { ssc.stop(); } - private HashSet topicToSet(String topic) { - HashSet topicSet = new HashSet(); + private static Set topicToSet(String topic) { + Set topicSet = new HashSet<>(); topicSet.add(topic); return topicSet; } - private HashMap topicOffsetToMap(String topic, Long offsetToStart) { - HashMap topicMap = new HashMap(); + private static Map topicOffsetToMap(String topic, Long offsetToStart) { + Map topicMap = new HashMap<>(); topicMap.put(new TopicAndPartition(topic, 0), offsetToStart); return topicMap; } diff --git 
a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java index a9dc6e50613c..afcc6cfccd39 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java @@ -19,6 +19,7 @@ import java.io.Serializable; import java.util.HashMap; +import java.util.Map; import scala.Tuple2; @@ -66,10 +67,10 @@ public void testKafkaRDD() throws InterruptedException { String topic1 = "topic1"; String topic2 = "topic2"; - String[] topic1data = createTopicAndSendData(topic1); - String[] topic2data = createTopicAndSendData(topic2); + createTopicAndSendData(topic1); + createTopicAndSendData(topic2); - HashMap kafkaParams = new HashMap(); + Map kafkaParams = new HashMap<>(); kafkaParams.put("metadata.broker.list", kafkaTestUtils.brokerAddress()); OffsetRange[] offsetRanges = { @@ -77,8 +78,8 @@ public void testKafkaRDD() throws InterruptedException { OffsetRange.create(topic2, 0, 0, 1) }; - HashMap emptyLeaders = new HashMap(); - HashMap leaders = new HashMap(); + Map emptyLeaders = new HashMap<>(); + Map leaders = new HashMap<>(); String[] hostAndPort = kafkaTestUtils.brokerAddress().split(":"); Broker broker = Broker.create(hostAndPort[0], Integer.parseInt(hostAndPort[1])); leaders.put(new TopicAndPartition(topic1, 0), broker); @@ -95,7 +96,7 @@ public void testKafkaRDD() throws InterruptedException { ).map( new Function, String>() { @Override - public String call(Tuple2 kv) throws Exception { + public String call(Tuple2 kv) { return kv._2(); } } @@ -113,7 +114,7 @@ public String call(Tuple2 kv) throws Exception { emptyLeaders, new Function, String>() { @Override - public String call(MessageAndMetadata msgAndMd) throws Exception { + public String call(MessageAndMetadata msgAndMd) { return msgAndMd.message(); } } @@ -131,7 +132,7 @@ public String call(MessageAndMetadata msgAndMd) throws Exception leaders, new Function, String>() { @Override - public String call(MessageAndMetadata msgAndMd) throws Exception { + public String call(MessageAndMetadata msgAndMd) { return msgAndMd.message(); } } diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java index e4c659215b76..1e69de46cd35 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java @@ -67,10 +67,10 @@ public void tearDown() { @Test public void testKafkaStream() throws InterruptedException { String topic = "topic1"; - HashMap topics = new HashMap(); + Map topics = new HashMap<>(); topics.put(topic, 1); - HashMap sent = new HashMap(); + Map sent = new HashMap<>(); sent.put("a", 5); sent.put("b", 3); sent.put("c", 10); @@ -78,7 +78,7 @@ public void testKafkaStream() throws InterruptedException { kafkaTestUtils.createTopic(topic); kafkaTestUtils.sendMessages(topic, sent); - HashMap kafkaParams = new HashMap(); + Map kafkaParams = new HashMap<>(); kafkaParams.put("zookeeper.connect", kafkaTestUtils.zkAddress()); kafkaParams.put("group.id", "test-consumer-" + random.nextInt(10000)); kafkaParams.put("auto.offset.reset", "smallest"); @@ -97,7 +97,7 @@ public void testKafkaStream() throws InterruptedException { JavaDStream words = stream.map( new Function, String>() { 
@Override - public String call(Tuple2 tuple2) throws Exception { + public String call(Tuple2 tuple2) { return tuple2._2(); } } @@ -106,7 +106,7 @@ public String call(Tuple2 tuple2) throws Exception { words.countByValue().foreachRDD( new Function, Void>() { @Override - public Void call(JavaPairRDD rdd) throws Exception { + public Void call(JavaPairRDD rdd) { List> ret = rdd.collect(); for (Tuple2 r : ret) { if (result.containsKey(r._1())) { @@ -130,8 +130,8 @@ public Void call(JavaPairRDD rdd) throws Exception { Thread.sleep(200); } Assert.assertEquals(sent.size(), result.size()); - for (String k : sent.keySet()) { - Assert.assertEquals(sent.get(k).intValue(), result.get(k).intValue()); + for (Map.Entry e : sent.entrySet()) { + Assert.assertEquals(e.getValue().intValue(), result.get(e.getKey()).intValue()); } } } diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala index 5b3c79444aa6..02225d5aa7cc 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -20,6 +20,9 @@ package org.apache.spark.streaming.kafka import java.io.File import java.util.concurrent.atomic.AtomicLong +import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset +import org.apache.spark.streaming.scheduler.rate.RateEstimator + import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ @@ -350,6 +353,77 @@ class DirectKafkaStreamSuite ssc.stop() } + test("using rate controller") { + val topic = "backpressure" + val topicPartition = TopicAndPartition(topic, 0) + kafkaTestUtils.createTopic(topic) + val kafkaParams = Map( + "metadata.broker.list" -> kafkaTestUtils.brokerAddress, + "auto.offset.reset" -> "smallest" + ) + + val batchIntervalMilliseconds = 100 + val estimator = new ConstantEstimator(100) + val messageKeys = (1 to 200).map(_.toString) + val messages = messageKeys.map((_, 1)).toMap + + val sparkConf = new SparkConf() + // Safe, even with streaming, because we're using the direct API. + // Using 1 core is useful to make the test more predictable. + .setMaster("local[1]") + .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.kafka.maxRatePerPartition", "100") + + // Setup the streaming context + ssc = new StreamingContext(sparkConf, Milliseconds(batchIntervalMilliseconds)) + + val kafkaStream = withClue("Error creating direct stream") { + val kc = new KafkaCluster(kafkaParams) + val messageHandler = (mmd: MessageAndMetadata[String, String]) => (mmd.key, mmd.message) + val m = kc.getEarliestLeaderOffsets(Set(topicPartition)) + .fold(e => Map.empty[TopicAndPartition, Long], m => m.mapValues(lo => lo.offset)) + + new DirectKafkaInputDStream[String, String, StringDecoder, StringDecoder, (String, String)]( + ssc, kafkaParams, m, messageHandler) { + override protected[streaming] val rateController = + Some(new DirectKafkaRateController(id, estimator)) + } + } + + val collectedData = + new mutable.ArrayBuffer[Array[String]]() with mutable.SynchronizedBuffer[Array[String]] + + // Used for assertion failure messages. 
+ def dataToString: String = + collectedData.map(_.mkString("[", ",", "]")).mkString("{", ", ", "}") + + // This is to collect the raw data received from Kafka + kafkaStream.foreachRDD { (rdd: RDD[(String, String)], time: Time) => + val data = rdd.map { _._2 }.collect() + collectedData += data + } + + ssc.start() + + // Try different rate limits. + // Send data to Kafka and wait for arrays of data to appear matching the rate. + Seq(100, 50, 20).foreach { rate => + collectedData.clear() // Empty this buffer on each pass. + estimator.updateRate(rate) // Set a new rate. + // Expect blocks of data equal to "rate", scaled by the interval length in secs. + val expectedSize = Math.round(rate * batchIntervalMilliseconds * 0.001) + kafkaTestUtils.sendMessages(topic, messages) + eventually(timeout(5.seconds), interval(batchIntervalMilliseconds.milliseconds)) { + // Assert that rate estimator values are used to determine maxMessagesPerPartition. + // Funky "-" in message makes the complete assertion message read better. + assert(collectedData.exists(_.size == expectedSize), + s" - No arrays of size $expectedSize for rate $rate found in $dataToString") + } + } + + ssc.stop() + } + /** Get the generated offset ranges from the DirectKafkaStream */ private def getOffsetRanges[K, V]( kafkaStream: DStream[(K, V)]): Seq[(Time, Array[OffsetRange])] = { @@ -381,3 +455,18 @@ object DirectKafkaStreamSuite { } } } + +private[streaming] class ConstantEstimator(@volatile private var rate: Long) + extends RateEstimator { + + def updateRate(newRate: Long): Unit = { + rate = newRate + } + + def compute( + time: Long, + elements: Long, + processingDelay: Long, + schedulingDelay: Long): Option[Double] = Some(rate) +} + diff --git a/external/mqtt-assembly/pom.xml b/external/mqtt-assembly/pom.xml new file mode 100644 index 000000000000..89713a28ca6a --- /dev/null +++ b/external/mqtt-assembly/pom.xml @@ -0,0 +1,175 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.10 + 1.6.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-streaming-mqtt-assembly_2.10 + jar + Spark Project External MQTT Assembly + http://spark.apache.org/ + + + streaming-mqtt-assembly + + + + + org.apache.spark + spark-streaming-mqtt_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + provided + + + + commons-lang + commons-lang + provided + + + com.google.protobuf + protobuf-java + provided + + + com.sun.jersey + jersey-server + provided + + + com.sun.jersey + jersey-core + provided + + + org.apache.hadoop + hadoop-client + provided + + + org.apache.avro + avro-mapred + ${avro.mapred.classifier} + provided + + + org.apache.curator + curator-recipes + provided + + + org.apache.zookeeper + zookeeper + provided + + + log4j + log4j + provided + + + net.java.dev.jets3t + jets3t + provided + + + org.scala-lang + scala-library + provided + + + org.slf4j + slf4j-api + provided + + + org.slf4j + slf4j-log4j12 + provided + + + org.xerial.snappy + snappy-java + provided + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-shade-plugin + + false + + + *:* + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + package + + shade + + + + + + reference.conf + + + log4j.properties + + + + + + + + + + + diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 0e41e5781784..59fba8b826b4 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ 
-21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml @@ -58,25 +58,47 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - - - com.novocode - junit-interface - test - org.apache.activemq activemq-core 5.7.0 test + + org.apache.spark + spark-test-tags_${scala.binary.version} + target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes + + + + + org.apache.maven.plugins + maven-assembly-plugin + + + test-jar-with-dependencies + package + + single + + + + spark-streaming-mqtt-test-${project.version} + ${project.build.directory}/scala-${scala.binary.version}/ + false + + false + + src/main/assembly/assembly.xml + + + + + + diff --git a/external/mqtt/src/main/assembly/assembly.xml b/external/mqtt/src/main/assembly/assembly.xml new file mode 100644 index 000000000000..c110b01b34e1 --- /dev/null +++ b/external/mqtt/src/main/assembly/assembly.xml @@ -0,0 +1,44 @@ + + + test-jar-with-dependencies + + jar + + false + + + + ${project.build.directory}/scala-${scala.binary.version}/test-classes + + + + + + + true + test + true + + org.apache.hadoop:*:jar + org.apache.zookeeper:*:jar + org.apache.avro:*:jar + + + + + diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala index 7c2f18cb35bd..116c170489e9 100644 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala +++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala @@ -38,7 +38,7 @@ import org.apache.spark.streaming.receiver.Receiver private[streaming] class MQTTInputDStream( - @transient ssc_ : StreamingContext, + ssc_ : StreamingContext, brokerUrl: String, topic: String, storageLevel: StorageLevel diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala index 1142d0f56ba3..7b8d56d6faf2 100644 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala +++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala @@ -21,8 +21,8 @@ import scala.reflect.ClassTag import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext, JavaDStream} -import org.apache.spark.streaming.dstream.{ReceiverInputDStream, DStream} +import org.apache.spark.streaming.api.java.{JavaDStream, JavaReceiverInputDStream, JavaStreamingContext} +import org.apache.spark.streaming.dstream.ReceiverInputDStream object MQTTUtils { /** @@ -74,3 +74,19 @@ object MQTTUtils { createStream(jssc.ssc, brokerUrl, topic, storageLevel) } } + +/** + * This is a helper class that wraps the methods in MQTTUtils into more Python-friendly class and + * function so that it can be easily instantiated and called from Python's MQTTUtils. 
+ */ +private[mqtt] class MQTTUtilsPythonHelper { + + def createStream( + jssc: JavaStreamingContext, + brokerUrl: String, + topic: String, + storageLevel: StorageLevel + ): JavaDStream[String] = { + MQTTUtils.createStream(jssc, brokerUrl, topic, storageLevel) + } +} diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala index c4bf5aa7869b..a6a9249db8ed 100644 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala @@ -17,46 +17,30 @@ package org.apache.spark.streaming.mqtt -import java.net.{URI, ServerSocket} -import java.util.concurrent.CountDownLatch -import java.util.concurrent.TimeUnit - import scala.concurrent.duration._ import scala.language.postfixOps -import org.apache.activemq.broker.{TransportConnector, BrokerService} -import org.apache.commons.lang3.RandomUtils -import org.eclipse.paho.client.mqttv3._ -import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence - import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually -import org.apache.spark.streaming.{Milliseconds, StreamingContext} -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.dstream.ReceiverInputDStream -import org.apache.spark.streaming.scheduler.StreamingListener -import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.util.Utils +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Milliseconds, StreamingContext} class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter { private val batchDuration = Milliseconds(500) private val master = "local[2]" private val framework = this.getClass.getSimpleName - private val freePort = findFreePort() - private val brokerUri = "//localhost:" + freePort private val topic = "def" - private val persistenceDir = Utils.createTempDir() private var ssc: StreamingContext = _ - private var broker: BrokerService = _ - private var connector: TransportConnector = _ + private var mqttTestUtils: MQTTTestUtils = _ before { ssc = new StreamingContext(master, framework, batchDuration) - setupMQTT() + mqttTestUtils = new MQTTTestUtils + mqttTestUtils.setup() } after { @@ -64,14 +48,17 @@ class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter ssc.stop() ssc = null } - Utils.deleteRecursively(persistenceDir) - tearDownMQTT() + if (mqttTestUtils != null) { + mqttTestUtils.teardown() + mqttTestUtils = null + } } test("mqtt input stream") { val sendMessage = "MQTT demo for spark streaming" - val receiveStream = - MQTTUtils.createStream(ssc, "tcp:" + brokerUri, topic, StorageLevel.MEMORY_ONLY) + val receiveStream = MQTTUtils.createStream(ssc, "tcp://" + mqttTestUtils.brokerUri, topic, + StorageLevel.MEMORY_ONLY) + @volatile var receiveMessage: List[String] = List() receiveStream.foreachRDD { rdd => if (rdd.collect.length > 0) { @@ -79,89 +66,14 @@ class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter receiveMessage } } - ssc.start() - // wait for the receiver to start before publishing data, or we risk failing - // the test nondeterministically. See SPARK-4631 - waitForReceiverToStart() + ssc.start() - publishData(sendMessage) + // Retry it because we don't know when the receiver will start. 
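The reworked MQTT test, continued in the next hunk, wraps publishData and the assertion in ScalaTest's eventually so it tolerates the receiver starting late. Below is a minimal standalone sketch of that retry-until-deadline pattern; eventuallyLike is a hypothetical helper written for illustration, not ScalaTest's eventually.

object RetryUntilSketch {
  @volatile private var receiveMessage: List[String] = List()

  /** Re-run body until it succeeds; once the deadline has passed, the last failure propagates. */
  def eventuallyLike[T](timeoutMs: Long, intervalMs: Long)(body: => T): T = {
    val deadline = System.currentTimeMillis() + timeoutMs
    var result: Option[T] = None
    while (result.isEmpty) {
      try {
        result = Some(body)
      } catch {
        case e: Throwable if System.currentTimeMillis() < deadline =>
          // Not done yet: wait a bit and retry, e.g. until the receiver has started.
          Thread.sleep(intervalMs)
      }
    }
    result.get
  }

  def main(args: Array[String]): Unit = {
    // Simulate a receiver that only starts delivering after a delay.
    new Thread(new Runnable {
      override def run(): Unit = {
        Thread.sleep(300)
        receiveMessage = List("MQTT demo for spark streaming")
      }
    }).start()

    val first = eventuallyLike(timeoutMs = 10000L, intervalMs = 100L) {
      assert(receiveMessage.nonEmpty, "nothing received yet")
      receiveMessage.head
    }
    println(s"received: $first")
  }
}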
eventually(timeout(10000 milliseconds), interval(100 milliseconds)) { + mqttTestUtils.publishData(topic, sendMessage) assert(sendMessage.equals(receiveMessage(0))) } ssc.stop() } - - private def setupMQTT() { - broker = new BrokerService() - broker.setDataDirectoryFile(Utils.createTempDir()) - connector = new TransportConnector() - connector.setName("mqtt") - connector.setUri(new URI("mqtt:" + brokerUri)) - broker.addConnector(connector) - broker.start() - } - - private def tearDownMQTT() { - if (broker != null) { - broker.stop() - broker = null - } - if (connector != null) { - connector.stop() - connector = null - } - } - - private def findFreePort(): Int = { - val candidatePort = RandomUtils.nextInt(1024, 65536) - Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { - val socket = new ServerSocket(trialPort) - socket.close() - (null, trialPort) - }, new SparkConf())._2 - } - - def publishData(data: String): Unit = { - var client: MqttClient = null - try { - val persistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath) - client = new MqttClient("tcp:" + brokerUri, MqttClient.generateClientId(), persistence) - client.connect() - if (client.isConnected) { - val msgTopic = client.getTopic(topic) - val message = new MqttMessage(data.getBytes("utf-8")) - message.setQos(1) - message.setRetained(true) - - for (i <- 0 to 10) { - try { - msgTopic.publish(message) - } catch { - case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT => - // wait for Spark streaming to consume something from the message queue - Thread.sleep(50) - } - } - } - } finally { - client.disconnect() - client.close() - client = null - } - } - - /** - * Block until at least one receiver has started or timeout occurs. - */ - private def waitForReceiverToStart() = { - val latch = new CountDownLatch(1) - ssc.addStreamingListener(new StreamingListener { - override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) { - latch.countDown() - } - }) - - assert(latch.await(10, TimeUnit.SECONDS), "Timeout waiting for receiver to start.") - } } diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala new file mode 100644 index 000000000000..1618e2c088b7 --- /dev/null +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.mqtt + +import java.net.{ServerSocket, URI} + +import scala.language.postfixOps + +import com.google.common.base.Charsets.UTF_8 +import org.apache.activemq.broker.{BrokerService, TransportConnector} +import org.apache.commons.lang3.RandomUtils +import org.eclipse.paho.client.mqttv3._ +import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence + +import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SparkConf} + +/** + * Share codes for Scala and Python unit tests + */ +private[mqtt] class MQTTTestUtils extends Logging { + + private val persistenceDir = Utils.createTempDir() + private val brokerHost = "localhost" + private val brokerPort = findFreePort() + + private var broker: BrokerService = _ + private var connector: TransportConnector = _ + + def brokerUri: String = { + s"$brokerHost:$brokerPort" + } + + def setup(): Unit = { + broker = new BrokerService() + broker.setDataDirectoryFile(Utils.createTempDir()) + connector = new TransportConnector() + connector.setName("mqtt") + connector.setUri(new URI("mqtt://" + brokerUri)) + broker.addConnector(connector) + broker.start() + } + + def teardown(): Unit = { + if (broker != null) { + broker.stop() + broker = null + } + if (connector != null) { + connector.stop() + connector = null + } + Utils.deleteRecursively(persistenceDir) + } + + private def findFreePort(): Int = { + val candidatePort = RandomUtils.nextInt(1024, 65536) + Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { + val socket = new ServerSocket(trialPort) + socket.close() + (null, trialPort) + }, new SparkConf())._2 + } + + def publishData(topic: String, data: String): Unit = { + var client: MqttClient = null + try { + val persistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath) + client = new MqttClient("tcp://" + brokerUri, MqttClient.generateClientId(), persistence) + client.connect() + if (client.isConnected) { + val msgTopic = client.getTopic(topic) + val message = new MqttMessage(data.getBytes(UTF_8)) + message.setQos(1) + message.setRetained(true) + + for (i <- 0 to 10) { + try { + msgTopic.publish(message) + } catch { + case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT => + // wait for Spark streaming to consume something from the message queue + Thread.sleep(50) + } + } + } + } finally { + if (client != null) { + client.disconnect() + client.close() + client = null + } + } + } + +} diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index 178ae8de13b5..087270de90b3 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml @@ -51,7 +51,7 @@ org.twitter4j twitter4j-stream - 3.0.3 + 4.0.4 org.scalacheck @@ -59,14 +59,8 @@ test - junit - junit - test - - - com.novocode - junit-interface - test + org.apache.spark + spark-test-tags_${scala.binary.version} diff --git a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala index 7cf02d85d73d..9a85a6597c27 100644 --- a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala +++ b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala @@ -39,7 +39,7 @@ import org.apache.spark.streaming.receiver.Receiver */ private[streaming] class TwitterInputDStream( - @transient 
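MQTTTestUtils above picks a random candidate port and probes it through Spark's Utils.startServiceOnPort. For comparison only (this is an assumption of the sketch, not what the patch does), a JDK-only way to obtain a free port for a test broker is to bind port 0 and let the OS hand out an ephemeral port:

import java.net.ServerSocket

object FreePortSketch {
  /** Ask the OS for an ephemeral port by binding port 0, then release it for the broker to use. */
  def findFreePort(): Int = {
    val socket = new ServerSocket(0)
    try socket.getLocalPort finally socket.close()
  }

  def main(args: Array[String]): Unit = {
    val port = findFreePort()
    // Note: there is a small race window between closing the socket and the broker binding the port.
    println(s"broker could listen on localhost:$port")
  }
}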
ssc_ : StreamingContext, + ssc_ : StreamingContext, twitterAuth: Option[Authorization], filters: Seq[String], storageLevel: StorageLevel @@ -87,7 +87,7 @@ class TwitterReceiver( val query = new FilterQuery if (filters.size > 0) { - query.track(filters.toArray) + query.track(filters.mkString(",")) newTwitterStream.filter(query) } else { newTwitterStream.sample() diff --git a/external/twitter/src/test/java/org/apache/spark/streaming/twitter/JavaTwitterStreamSuite.java b/external/twitter/src/test/java/org/apache/spark/streaming/twitter/JavaTwitterStreamSuite.java index e46b4e5c7531..26ec8af455bc 100644 --- a/external/twitter/src/test/java/org/apache/spark/streaming/twitter/JavaTwitterStreamSuite.java +++ b/external/twitter/src/test/java/org/apache/spark/streaming/twitter/JavaTwitterStreamSuite.java @@ -17,8 +17,6 @@ package org.apache.spark.streaming.twitter; -import java.util.Arrays; - import org.junit.Test; import twitter4j.Status; import twitter4j.auth.Authorization; @@ -30,7 +28,7 @@ public class JavaTwitterStreamSuite extends LocalJavaStreamingContext { @Test public void testTwitterStream() { - String[] filters = (String[])Arrays.asList("filter1", "filter2").toArray(); + String[] filters = { "filter1", "filter2" }; Authorization auth = NullAuthorization.getInstance(); // tests the API, does not actually test data receiving diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index 37bfd10d4366..02d6b8128157 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml @@ -58,14 +58,8 @@ test - junit - junit - test - - - com.novocode - junit-interface - test + org.apache.spark + spark-test-tags_${scala.binary.version} diff --git a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala index 0469d0af8864..4ea218eaa4de 100644 --- a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala +++ b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala @@ -18,15 +18,17 @@ package org.apache.spark.streaming.zeromq import scala.reflect.ClassTag -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ + import akka.actor.{Props, SupervisorStrategy} import akka.util.ByteString import akka.zeromq.Subscribe + import org.apache.spark.api.java.function.{Function => JFunction} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext} -import org.apache.spark.streaming.dstream.{ReceiverInputDStream} +import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.receiver.ActorSupervisorStrategy object ZeroMQUtils { @@ -75,7 +77,8 @@ object ZeroMQUtils { ): JavaReceiverInputDStream[T] = { implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] - val fn = (x: Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).toIterator + val fn = + (x: Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).iterator().asScala createStream[T](jssc.ssc, publisherUrl, subscribe, fn, storageLevel, supervisorStrategy) } @@ -99,7 +102,8 @@ object ZeroMQUtils { ): JavaReceiverInputDStream[T] = { implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] - val fn = (x: 
Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).toIterator + val fn = + (x: Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).iterator().asScala createStream[T](jssc.ssc, publisherUrl, subscribe, fn, storageLevel) } @@ -122,7 +126,8 @@ object ZeroMQUtils { ): JavaReceiverInputDStream[T] = { implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] - val fn = (x: Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).toIterator + val fn = + (x: Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).iterator().asScala createStream[T](jssc.ssc, publisherUrl, subscribe, fn) } } diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index 3636a9037d43..4ce90e75fd35 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml @@ -59,14 +59,8 @@ test - junit - junit - test - - - com.novocode - junit-interface - test + org.apache.spark + spark-test-tags_${scala.binary.version} diff --git a/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java b/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java index 729bc0459ce5..14975265ab2c 100644 --- a/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java +++ b/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java @@ -77,7 +77,7 @@ public void call(String s) { public void foreach() { foreachCalls = 0; JavaRDD rdd = sc.parallelize(Arrays.asList("Hello", "World")); - rdd.foreach((x) -> foreachCalls++); + rdd.foreach(x -> foreachCalls++); Assert.assertEquals(2, foreachCalls); } @@ -180,7 +180,7 @@ public void map() { JavaPairRDD pairs = rdd.mapToPair(x -> new Tuple2<>(x, x)) .cache(); pairs.collect(); - JavaRDD strings = rdd.map(x -> x.toString()).cache(); + JavaRDD strings = rdd.map(Object::toString).cache(); strings.collect(); } @@ -195,7 +195,9 @@ public void flatMap() { JavaPairRDD pairs = rdd.flatMapToPair(s -> { List> pairs2 = new LinkedList<>(); - for (String word : s.split(" ")) pairs2.add(new Tuple2<>(word, word)); + for (String word : s.split(" ")) { + pairs2.add(new Tuple2<>(word, word)); + } return pairs2; }); @@ -204,11 +206,12 @@ public void flatMap() { JavaDoubleRDD doubles = rdd.flatMapToDouble(s -> { List lengths = new LinkedList<>(); - for (String word : s.split(" ")) lengths.add(word.length() * 1.0); + for (String word : s.split(" ")) { + lengths.add((double) word.length()); + } return lengths; }); - Double x = doubles.first(); Assert.assertEquals(5.0, doubles.first(), 0.01); Assert.assertEquals(11, pairs.count()); } @@ -228,7 +231,7 @@ public void mapsFromPairsToPairs() { swapped.collect(); // There was never a bug here, but it's worth testing: - pairRDD.map(item -> item.swap()).collect(); + pairRDD.map(Tuple2::swap).collect(); } @Test @@ -282,11 +285,11 @@ public void zipPartitions() { FlatMapFunction2, Iterator, Integer> sizesFn = (Iterator i, Iterator s) -> { int sizeI = 0; - int sizeS = 0; while (i.hasNext()) { sizeI += 1; i.next(); } + int sizeS = 0; while (s.hasNext()) { sizeS += 1; s.next(); @@ -301,30 +304,31 @@ public void zipPartitions() { public void accumulators() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); - final Accumulator intAccum = sc.intAccumulator(10); - rdd.foreach(x -> intAccum.add(x)); + Accumulator intAccum = sc.intAccumulator(10); + rdd.foreach(intAccum::add); Assert.assertEquals((Integer) 25, intAccum.value()); - final Accumulator 
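The ZeroMQUtils hunks above replace the blanket implicit import of scala.collection.JavaConversions._ with explicit JavaConverters and an .asScala call on the Java iterator. A minimal self-contained illustration of that conversion (the byte-array payloads here are made up):

import java.util.Arrays
import scala.collection.JavaConverters._

object AsScalaSketch {
  def main(args: Array[String]): Unit = {
    val javaList: java.util.List[Array[Byte]] = Arrays.asList(Array[Byte](1, 2, 3), Array[Byte](4, 5))
    // Explicit, visible conversion instead of an implicit one pulled in by JavaConversions._
    val asScalaIterator: Iterator[Array[Byte]] = javaList.iterator().asScala
    asScalaIterator.foreach(bytes => println(bytes.mkString("[", ",", "]")))
  }
}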
doubleAccum = sc.doubleAccumulator(10.0); + Accumulator doubleAccum = sc.doubleAccumulator(10.0); rdd.foreach(x -> doubleAccum.add((double) x)); Assert.assertEquals((Double) 25.0, doubleAccum.value()); // Try a custom accumulator type AccumulatorParam floatAccumulatorParam = new AccumulatorParam() { + @Override public Float addInPlace(Float r, Float t) { return r + t; } - + @Override public Float addAccumulator(Float r, Float t) { return r + t; } - + @Override public Float zero(Float initialValue) { return 0.0f; } }; - final Accumulator floatAccum = sc.accumulator(10.0f, floatAccumulatorParam); + Accumulator floatAccum = sc.accumulator(10.0f, floatAccumulatorParam); rdd.foreach(x -> floatAccum.add((float) x)); Assert.assertEquals((Float) 25.0f, floatAccum.value()); @@ -336,7 +340,7 @@ public Float zero(Float initialValue) { @Test public void keyBy() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2)); - List> s = rdd.keyBy(x -> x.toString()).collect(); + List> s = rdd.keyBy(Object::toString).collect(); Assert.assertEquals(new Tuple2<>("1", 1), s.get(0)); Assert.assertEquals(new Tuple2<>("2", 2), s.get(1)); } @@ -349,7 +353,7 @@ public void mapOnPairRDD() { JavaPairRDD rdd3 = rdd2.mapToPair(in -> new Tuple2<>(in._2(), in._1())); Assert.assertEquals(Arrays.asList( - new Tuple2(1, 1), + new Tuple2<>(1, 1), new Tuple2<>(0, 2), new Tuple2<>(1, 3), new Tuple2<>(0, 4)), rdd3.collect()); @@ -361,7 +365,7 @@ public void collectPartitions() { JavaPairRDD rdd2 = rdd1.mapToPair(i -> new Tuple2<>(i, i % 2)); - List[] parts = rdd1.collectPartitions(new int[]{0}); + List[] parts = rdd1.collectPartitions(new int[]{0}); Assert.assertEquals(Arrays.asList(1, 2), parts[0]); parts = rdd1.collectPartitions(new int[]{1, 2}); @@ -371,19 +375,19 @@ public void collectPartitions() { Assert.assertEquals(Arrays.asList(new Tuple2<>(1, 1), new Tuple2<>(2, 0)), rdd2.collectPartitions(new int[]{0})[0]); - parts = rdd2.collectPartitions(new int[]{1, 2}); - Assert.assertEquals(Arrays.asList(new Tuple2<>(3, 1), new Tuple2<>(4, 0)), parts[0]); + List>[] parts2 = rdd2.collectPartitions(new int[]{1, 2}); + Assert.assertEquals(Arrays.asList(new Tuple2<>(3, 1), new Tuple2<>(4, 0)), parts2[0]); Assert.assertEquals(Arrays.asList(new Tuple2<>(5, 1), new Tuple2<>(6, 0), new Tuple2<>(7, 1)), - parts[1]); + parts2[1]); } @Test public void collectAsMapWithIntArrayValues() { // Regression test for SPARK-1040 - JavaRDD rdd = sc.parallelize(Arrays.asList(new Integer[]{1})); + JavaRDD rdd = sc.parallelize(Arrays.asList(1)); JavaPairRDD pairRDD = rdd.mapToPair(x -> new Tuple2<>(x, new int[]{x})); pairRDD.collect(); // Works fine - Map map = pairRDD.collectAsMap(); // Used to crash with ClassCastException + pairRDD.collectAsMap(); // Used to crash with ClassCastException } } diff --git a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java index 73091cfe2c09..e8a0dfc0f0a5 100644 --- a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java +++ b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java @@ -28,12 +28,14 @@ import org.junit.Assert; import org.junit.Test; +import org.apache.spark.Accumulator; import org.apache.spark.HashPartitioner; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.streaming.api.java.JavaDStream; import 
org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaMapWithStateDStream; /** * Most of these tests replicate org.apache.spark.streaming.JavaAPISuite using java 8 @@ -357,6 +359,31 @@ public void testFlatMap() { assertOrderInvariantEquals(expected, result); } + @Test + public void testForeachRDD() { + final Accumulator accumRdd = ssc.sc().accumulator(0); + final Accumulator accumEle = ssc.sc().accumulator(0); + List> inputData = Arrays.asList( + Arrays.asList(1,1,1), + Arrays.asList(1,1,1)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaTestUtils.attachTestOutputStream(stream.count()); // dummy output + + stream.foreachRDD(rdd -> { + accumRdd.add(1); + rdd.foreach(x -> accumEle.add(1)); + }); + + // This is a test to make sure foreachRDD(VoidFunction2) can be called from Java + stream.foreachRDD((rdd, time) -> null); + + JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(2, accumRdd.value().intValue()); + Assert.assertEquals(6, accumEle.value().intValue()); + } + @Test public void testPairFlatMap() { List> inputData = Arrays.asList( @@ -412,9 +439,14 @@ public void testPairFlatMap() { */ public static > void assertOrderInvariantEquals( List> expected, List> actual) { - expected.forEach((List list) -> Collections.sort(list)); - actual.forEach((List list) -> Collections.sort(list)); - Assert.assertEquals(expected, actual); + expected.forEach(list -> Collections.sort(list)); + List> sortedActual = new ArrayList<>(); + actual.forEach(list -> { + List sortedList = new ArrayList<>(list); + Collections.sort(sortedList); + sortedActual.add(sortedList); + }); + Assert.assertEquals(expected, sortedActual); } @Test @@ -831,4 +863,44 @@ public void testFlatMapValues() { Assert.assertEquals(expected, result); } + /** + * This test is only for testing the APIs. It's not necessary to run it. + */ + public void testMapWithStateAPI() { + JavaPairRDD initialRDD = null; + JavaPairDStream wordsDstream = null; + + JavaMapWithStateDStream stateDstream = + wordsDstream.mapWithState( + StateSpec. function((time, key, value, state) -> { + // Use all State's methods here + state.exists(); + state.get(); + state.isTimingOut(); + state.remove(); + state.update(true); + return Optional.of(2.0); + }).initialState(initialRDD) + .numPartitions(10) + .partitioner(new HashPartitioner(10)) + .timeout(Durations.seconds(10))); + + JavaPairDStream emittedRecords = stateDstream.stateSnapshots(); + + JavaMapWithStateDStream stateDstream2 = + wordsDstream.mapWithState( + StateSpec.function((key, value, state) -> { + state.exists(); + state.get(); + state.isTimingOut(); + state.remove(); + state.update(true); + return 2.0; + }).initialState(initialRDD) + .numPartitions(10) + .partitioner(new HashPartitioner(10)) + .timeout(Durations.seconds(10))); + + JavaPairDStream mappedDStream = stateDstream2.stateSnapshots(); + } } diff --git a/extras/java8-tests/src/test/scala/org/apache/spark/JDK8ScalaSuite.scala b/extras/java8-tests/src/test/scala/org/apache/spark/JDK8ScalaSuite.scala new file mode 100644 index 000000000000..fa0681db4108 --- /dev/null +++ b/extras/java8-tests/src/test/scala/org/apache/spark/JDK8ScalaSuite.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
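The assertOrderInvariantEquals fix above stops sorting the caller's lists in place and instead compares sorted copies. The same idea expressed in Scala, as a hypothetical helper that is not part of the patch:

object OrderInvariantSketch {
  /** Compare per-batch outputs ignoring ordering inside each batch, without mutating the inputs. */
  def orderInvariantEquals[T: Ordering](expected: Seq[Seq[T]], actual: Seq[Seq[T]]): Boolean =
    expected.map(_.sorted) == actual.map(_.sorted)

  def main(args: Array[String]): Unit = {
    val expected = Seq(Seq(1, 2, 3), Seq(4, 5))
    val actual = Seq(Seq(3, 1, 2), Seq(5, 4))
    assert(orderInvariantEquals(expected, actual))
    println("batches match; neither input was reordered in place")
  }
}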
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +/** + * Test cases where JDK8-compiled Scala user code is used with Spark. + */ +class JDK8ScalaSuite extends SparkFunSuite with SharedSparkContext { + test("basic RDD closure test (SPARK-6152)") { + sc.parallelize(1 to 1000).map(x => x * x).count() + } +} diff --git a/extras/kinesis-asl-assembly/pom.xml b/extras/kinesis-asl-assembly/pom.xml index 70d2c9c58f54..61ba4787fbf9 100644 --- a/extras/kinesis-asl-assembly/pom.xml +++ b/extras/kinesis-asl-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml @@ -47,6 +47,85 @@ ${project.version} provided + + + com.fasterxml.jackson.core + jackson-databind + provided + + + commons-lang + commons-lang + provided + + + com.google.protobuf + protobuf-java + provided + + + com.sun.jersey + jersey-server + provided + + + com.sun.jersey + jersey-core + provided + + + log4j + log4j + provided + + + net.java.dev.jets3t + jets3t + provided + + + org.apache.hadoop + hadoop-client + provided + + + org.apache.avro + avro-ipc + provided + + + org.apache.avro + avro-mapred + ${avro.mapred.classifier} + provided + + + org.apache.curator + curator-recipes + provided + + + org.apache.zookeeper + zookeeper + provided + + + org.slf4j + slf4j-api + provided + + + org.slf4j + slf4j-log4j12 + provided + + + org.xerial.snappy + snappy-java + provided + @@ -58,7 +137,6 @@ maven-shade-plugin false - ${project.build.directory}/scala-${scala.binary.version}/spark-streaming-kinesis-asl-assembly-${project.version}.jar *:* diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index c242e7a57b9a..519a920279c9 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml @@ -31,7 +31,7 @@ Spark Kinesis Integration - kinesis-asl + streaming-kinesis-asl @@ -64,6 +64,12 @@ aws-java-sdk ${aws.java.sdk.version} + + com.amazonaws + amazon-kinesis-producer + ${aws.kinesis.producer.version} + test + org.mockito mockito-core @@ -75,9 +81,8 @@ test - com.novocode - junit-interface - test + org.apache.spark + spark-test-tags_${scala.binary.version} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala index 8f144a4d974a..691c1790b207 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala @@ -17,11 +17,13 @@ package org.apache.spark.streaming.kinesis -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ +import scala.reflect.ClassTag import scala.util.control.NonFatal import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain} import 
com.amazonaws.services.kinesis.AmazonKinesisClient +import com.amazonaws.services.kinesis.clientlibrary.types.UserRecord import com.amazonaws.services.kinesis.model._ import org.apache.spark._ @@ -37,16 +39,18 @@ case class SequenceNumberRange( /** Class representing an array of Kinesis sequence number ranges */ private[kinesis] -case class SequenceNumberRanges(ranges: Array[SequenceNumberRange]) { +case class SequenceNumberRanges(ranges: Seq[SequenceNumberRange]) { def isEmpty(): Boolean = ranges.isEmpty + def nonEmpty(): Boolean = ranges.nonEmpty + override def toString(): String = ranges.mkString("SequenceNumberRanges(", ", ", ")") } private[kinesis] object SequenceNumberRanges { def apply(range: SequenceNumberRange): SequenceNumberRanges = { - new SequenceNumberRanges(Array(range)) + new SequenceNumberRanges(Seq(range)) } } @@ -65,16 +69,17 @@ class KinesisBackedBlockRDDPartition( * sequence numbers of the corresponding blocks. */ private[kinesis] -class KinesisBackedBlockRDD( - sc: SparkContext, - regionId: String, - endpointUrl: String, +class KinesisBackedBlockRDD[T: ClassTag]( + @transient sc: SparkContext, + val regionName: String, + val endpointUrl: String, @transient blockIds: Array[BlockId], - @transient arrayOfseqNumberRanges: Array[SequenceNumberRanges], + @transient val arrayOfseqNumberRanges: Array[SequenceNumberRanges], @transient isBlockIdValid: Array[Boolean] = Array.empty, - retryTimeoutMs: Int = 10000, - awsCredentialsOption: Option[SerializableAWSCredentials] = None - ) extends BlockRDD[Array[Byte]](sc, blockIds) { + val retryTimeoutMs: Int = 10000, + val messageHandler: Record => T = KinesisUtils.defaultMessageHandler _, + val awsCredentialsOption: Option[SerializableAWSCredentials] = None + ) extends BlockRDD[T](sc, blockIds) { require(blockIds.length == arrayOfseqNumberRanges.length, "Number of blockIds is not equal to the number of sequence number ranges") @@ -88,23 +93,23 @@ class KinesisBackedBlockRDD( } } - override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { + override def compute(split: Partition, context: TaskContext): Iterator[T] = { val blockManager = SparkEnv.get.blockManager val partition = split.asInstanceOf[KinesisBackedBlockRDDPartition] val blockId = partition.blockId - def getBlockFromBlockManager(): Option[Iterator[Array[Byte]]] = { + def getBlockFromBlockManager(): Option[Iterator[T]] = { logDebug(s"Read partition data of $this from block manager, block $blockId") - blockManager.get(blockId).map(_.data.asInstanceOf[Iterator[Array[Byte]]]) + blockManager.get(blockId).map(_.data.asInstanceOf[Iterator[T]]) } - def getBlockFromKinesis(): Iterator[Array[Byte]] = { - val credenentials = awsCredentialsOption.getOrElse { + def getBlockFromKinesis(): Iterator[T] = { + val credentials = awsCredentialsOption.getOrElse { new DefaultAWSCredentialsProviderChain().getCredentials() } partition.seqNumberRanges.ranges.iterator.flatMap { range => - new KinesisSequenceRangeIterator( - credenentials, endpointUrl, regionId, range, retryTimeoutMs) + new KinesisSequenceRangeIterator(credentials, endpointUrl, regionName, + range, retryTimeoutMs).map(messageHandler) } } if (partition.isBlockIdValid) { @@ -127,8 +132,7 @@ class KinesisSequenceRangeIterator( endpointUrl: String, regionId: String, range: SequenceNumberRange, - retryTimeoutMs: Int - ) extends NextIterator[Array[Byte]] with Logging { + retryTimeoutMs: Int) extends NextIterator[Record] with Logging { private val client = new AmazonKinesisClient(credentials) private val 
streamName = range.streamName @@ -140,8 +144,8 @@ class KinesisSequenceRangeIterator( client.setEndpoint(endpointUrl, "kinesis", regionId) - override protected def getNext(): Array[Byte] = { - var nextBytes: Array[Byte] = null + override protected def getNext(): Record = { + var nextRecord: Record = null if (toSeqNumberReceived) { finished = true } else { @@ -168,10 +172,7 @@ class KinesisSequenceRangeIterator( } else { // Get the record, copy the data into a byte array and remember its sequence number - val nextRecord: Record = internalIterator.next() - val byteBuffer = nextRecord.getData() - nextBytes = new Array[Byte](byteBuffer.remaining()) - byteBuffer.get(nextBytes) + nextRecord = internalIterator.next() lastSeqNumber = nextRecord.getSequenceNumber() // If the this record's sequence number matches the stopping sequence number, then make sure @@ -180,9 +181,8 @@ class KinesisSequenceRangeIterator( toSeqNumberReceived = true } } - } - nextBytes + nextRecord } override protected def close(): Unit = { @@ -211,7 +211,10 @@ class KinesisSequenceRangeIterator( s"getting records using shard iterator") { client.getRecords(getRecordsRequest) } - (getRecordsResult.getRecords.iterator(), getRecordsResult.getNextShardIterator) + // De-aggregate records, if KPL was used in producing the records. The KCL automatically + // handles de-aggregation during regular operation. This code path is used during recovery + val recordIterator = UserRecord.deaggregate(getRecordsResult.getRecords) + (recordIterator.iterator().asScala, getRecordsResult.getNextShardIterator) } /** diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala deleted file mode 100644 index 83a453755951..000000000000 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.streaming.kinesis - -import org.apache.spark.Logging -import org.apache.spark.streaming.Duration -import org.apache.spark.util.{Clock, ManualClock, SystemClock} - -/** - * This is a helper class for managing checkpoint clocks. - * - * @param checkpointInterval - * @param currentClock. 
Default to current SystemClock if none is passed in (mocking purposes) - */ -private[kinesis] class KinesisCheckpointState( - checkpointInterval: Duration, - currentClock: Clock = new SystemClock()) - extends Logging { - - /* Initialize the checkpoint clock using the given currentClock + checkpointInterval millis */ - val checkpointClock = new ManualClock() - checkpointClock.setTime(currentClock.getTimeMillis() + checkpointInterval.milliseconds) - - /** - * Check if it's time to checkpoint based on the current time and the derived time - * for the next checkpoint - * - * @return true if it's time to checkpoint - */ - def shouldCheckpoint(): Boolean = { - new SystemClock().getTimeMillis() > checkpointClock.getTimeMillis() - } - - /** - * Advance the checkpoint clock by the checkpoint interval. - */ - def advanceCheckpoint(): Unit = { - checkpointClock.advance(checkpointInterval.milliseconds) - } -} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala new file mode 100644 index 000000000000..1ca6d4302c2b --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.kinesis + +import java.util.concurrent._ + +import scala.util.control.NonFatal + +import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer +import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason + +import org.apache.spark.Logging +import org.apache.spark.streaming.Duration +import org.apache.spark.streaming.util.RecurringTimer +import org.apache.spark.util.{Clock, SystemClock, ThreadUtils} + +/** + * This is a helper class for managing Kinesis checkpointing. + * + * @param receiver The receiver that keeps track of which sequence numbers we can checkpoint + * @param checkpointInterval How frequently we will checkpoint to DynamoDB + * @param workerId Worker Id of KCL worker for logging purposes + * @param clock In order to use ManualClocks for the purpose of testing + */ +private[kinesis] class KinesisCheckpointer( + receiver: KinesisReceiver[_], + checkpointInterval: Duration, + workerId: String, + clock: Clock = new SystemClock) extends Logging { + + // a map from shardId's to checkpointers + private val checkpointers = new ConcurrentHashMap[String, IRecordProcessorCheckpointer]() + + private val lastCheckpointedSeqNums = new ConcurrentHashMap[String, String]() + + private val checkpointerThread: RecurringTimer = startCheckpointerThread() + + /** Update the checkpointer instance to the most recent one for the given shardId. 
*/ + def setCheckpointer(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = { + checkpointers.put(shardId, checkpointer) + } + + /** + * Stop tracking the specified shardId. + * + * If a checkpointer is provided, e.g. on IRecordProcessor.shutdown [[ShutdownReason.TERMINATE]], + * we will use that to make the final checkpoint. If `null` is provided, we will not make the + * checkpoint, e.g. in case of [[ShutdownReason.ZOMBIE]]. + */ + def removeCheckpointer(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = { + synchronized { + checkpointers.remove(shardId) + checkpoint(shardId, checkpointer) + } + } + + /** Perform the checkpoint. */ + private def checkpoint(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = { + try { + if (checkpointer != null) { + receiver.getLatestSeqNumToCheckpoint(shardId).foreach { latestSeqNum => + val lastSeqNum = lastCheckpointedSeqNums.get(shardId) + // Kinesis sequence numbers are monotonically increasing strings, therefore we can do + // safely do the string comparison + if (lastSeqNum == null || latestSeqNum > lastSeqNum) { + /* Perform the checkpoint */ + KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(latestSeqNum), 4, 100) + logDebug(s"Checkpoint: WorkerId $workerId completed checkpoint at sequence number" + + s" $latestSeqNum for shardId $shardId") + lastCheckpointedSeqNums.put(shardId, latestSeqNum) + } + } + } else { + logDebug(s"Checkpointing skipped for shardId $shardId. Checkpointer not set.") + } + } catch { + case NonFatal(e) => + logWarning(s"Failed to checkpoint shardId $shardId to DynamoDB.", e) + } + } + + /** Checkpoint the latest saved sequence numbers for all active shardId's. */ + private def checkpointAll(): Unit = synchronized { + // if this method throws an exception, then the scheduled task will not run again + try { + val shardIds = checkpointers.keys() + while (shardIds.hasMoreElements) { + val shardId = shardIds.nextElement() + checkpoint(shardId, checkpointers.get(shardId)) + } + } catch { + case NonFatal(e) => + logWarning("Failed to checkpoint to DynamoDB.", e) + } + } + + /** + * Start the checkpointer thread with the given checkpoint duration. + */ + private def startCheckpointerThread(): RecurringTimer = { + val period = checkpointInterval.milliseconds + val threadName = s"Kinesis Checkpointer - Worker $workerId" + val timer = new RecurringTimer(clock, period, _ => checkpointAll(), threadName) + timer.start() + logDebug(s"Started checkpointer thread: $threadName") + timer + } + + /** + * Shutdown the checkpointer. Should be called on the onStop of the Receiver. + */ + def shutdown(): Unit = { + // the recurring timer checkpoints for us one last time. + checkpointerThread.stop(interruptTimer = false) + checkpointers.clear() + lastCheckpointedSeqNums.clear() + logInfo("Successfully shutdown Kinesis Checkpointer.") + } +} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala new file mode 100644 index 000000000000..72ab6357a53b --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
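KinesisCheckpointer above only checkpoints a shard when the latest successfully stored sequence number has advanced past the last one it checkpointed. A standalone sketch of that bookkeeping, with a println standing in for the KCL checkpointer call (the shard IDs and sequence numbers are made up):

import java.util.concurrent.ConcurrentHashMap

object CheckpointDedupSketch {
  private val lastCheckpointedSeqNums = new ConcurrentHashMap[String, String]()

  /** Stand-in for the retried checkpointer.checkpoint(seqNum) call in the patch. */
  private def doCheckpoint(shardId: String, seqNum: String): Unit =
    println(s"checkpointing shard $shardId at $seqNum")

  def maybeCheckpoint(shardId: String, latestStoredSeqNum: Option[String]): Unit = {
    latestStoredSeqNum.foreach { latestSeqNum =>
      val lastSeqNum = lastCheckpointedSeqNums.get(shardId)
      // The patch treats sequence numbers as monotonically increasing strings, so a string
      // comparison decides whether there is anything new to checkpoint.
      if (lastSeqNum == null || latestSeqNum > lastSeqNum) {
        doCheckpoint(shardId, latestSeqNum)
        lastCheckpointedSeqNums.put(shardId, latestSeqNum)
      }
    }
  }

  def main(args: Array[String]): Unit = {
    maybeCheckpoint("shardId-000", Some("00049")) // checkpoints
    maybeCheckpoint("shardId-000", Some("00049")) // skipped, nothing new
    maybeCheckpoint("shardId-000", Some("00057")) // checkpoints again
    maybeCheckpoint("shardId-001", None)          // nothing stored yet, skipped
  }
}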
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis + +import scala.reflect.ClassTag + +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream +import com.amazonaws.services.kinesis.model.Record + +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.{BlockId, StorageLevel} +import org.apache.spark.streaming.dstream.ReceiverInputDStream +import org.apache.spark.streaming.receiver.Receiver +import org.apache.spark.streaming.scheduler.ReceivedBlockInfo +import org.apache.spark.streaming.{Duration, StreamingContext, Time} + +private[kinesis] class KinesisInputDStream[T: ClassTag]( + @transient _ssc: StreamingContext, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: InitialPositionInStream, + checkpointAppName: String, + checkpointInterval: Duration, + storageLevel: StorageLevel, + messageHandler: Record => T, + awsCredentialsOption: Option[SerializableAWSCredentials] + ) extends ReceiverInputDStream[T](_ssc) { + + private[streaming] + override def createBlockRDD(time: Time, blockInfos: Seq[ReceivedBlockInfo]): RDD[T] = { + + // This returns true even for when blockInfos is empty + val allBlocksHaveRanges = blockInfos.map { _.metadataOption }.forall(_.nonEmpty) + + if (allBlocksHaveRanges) { + // Create a KinesisBackedBlockRDD, even when there are no blocks + val blockIds = blockInfos.map { _.blockId.asInstanceOf[BlockId] }.toArray + val seqNumRanges = blockInfos.map { + _.metadataOption.get.asInstanceOf[SequenceNumberRanges] }.toArray + val isBlockIdValid = blockInfos.map { _.isBlockIdValid() }.toArray + logDebug(s"Creating KinesisBackedBlockRDD for $time with ${seqNumRanges.length} " + + s"seq number ranges: ${seqNumRanges.mkString(", ")} ") + new KinesisBackedBlockRDD( + context.sc, regionName, endpointUrl, blockIds, seqNumRanges, + isBlockIdValid = isBlockIdValid, + retryTimeoutMs = ssc.graph.batchDuration.milliseconds.toInt, + messageHandler = messageHandler, + awsCredentialsOption = awsCredentialsOption) + } else { + logWarning("Kinesis sequence number information was not present with some block metadata," + + " it may not be possible to recover from failures") + super.createBlockRDD(time, blockInfos) + } + } + + override def getReceiver(): Receiver[T] = { + new KinesisReceiver(streamName, endpointUrl, regionName, initialPositionInStream, + checkpointAppName, checkpointInterval, storageLevel, messageHandler, awsCredentialsOption) + } +} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 1a8a4cecc114..80edda59e171 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -17,19 +17,22 @@ package org.apache.spark.streaming.kinesis 
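KinesisInputDStream.createBlockRDD above builds a KinesisBackedBlockRDD only when every block carries sequence-number-range metadata, and otherwise falls back to the generic BlockRDD path. A standalone sketch of that decision (BlockInfo and chooseRddKind are illustrative names, not Spark's API):

object CreateBlockRDDDecision {
  case class BlockInfo(blockId: String, metadataOption: Option[String])

  def chooseRddKind(blockInfos: Seq[BlockInfo]): String = {
    // forall on an empty Seq is true, so an empty batch still takes the Kinesis-backed path,
    // matching the "returns true even when blockInfos is empty" comment in the patch.
    val allBlocksHaveRanges = blockInfos.map(_.metadataOption).forall(_.nonEmpty)
    if (allBlocksHaveRanges) "KinesisBackedBlockRDD" else "BlockRDD (cannot recover data from Kinesis)"
  }

  def main(args: Array[String]): Unit = {
    println(chooseRddKind(Nil))                                 // KinesisBackedBlockRDD
    println(chooseRddKind(Seq(BlockInfo("b0", Some("range"))))) // KinesisBackedBlockRDD
    println(chooseRddKind(Seq(BlockInfo("b0", Some("range")), BlockInfo("b1", None)))) // BlockRDD
  }
}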
import java.util.UUID +import java.util.concurrent.ConcurrentHashMap +import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.util.control.NonFatal -import com.amazonaws.auth.{AWSCredentials, AWSCredentialsProvider, BasicAWSCredentials, DefaultAWSCredentialsProviderChain} -import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorFactory} +import com.amazonaws.auth.{AWSCredentials, AWSCredentialsProvider, DefaultAWSCredentialsProviderChain} +import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessorCheckpointer, IRecordProcessor, IRecordProcessorFactory} import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{InitialPositionInStream, KinesisClientLibConfiguration, Worker} +import com.amazonaws.services.kinesis.model.Record -import org.apache.spark.Logging -import org.apache.spark.storage.StorageLevel +import org.apache.spark.storage.{StorageLevel, StreamBlockId} import org.apache.spark.streaming.Duration -import org.apache.spark.streaming.receiver.Receiver +import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, Receiver} import org.apache.spark.util.Utils - +import org.apache.spark.Logging private[kinesis] case class SerializableAWSCredentials(accessKeyId: String, secretKey: String) @@ -42,42 +45,53 @@ case class SerializableAWSCredentials(accessKeyId: String, secretKey: String) * Custom AWS Kinesis-specific implementation of Spark Streaming's Receiver. * This implementation relies on the Kinesis Client Library (KCL) Worker as described here: * https://github.com/awslabs/amazon-kinesis-client - * This is a custom receiver used with StreamingContext.receiverStream(Receiver) as described here: - * http://spark.apache.org/docs/latest/streaming-custom-receivers.html - * Instances of this class will get shipped to the Spark Streaming Workers to run within a - * Spark Executor. * - * @param appName Kinesis application name. Kinesis Apps are mapped to Kinesis Streams - * by the Kinesis Client Library. If you change the App name or Stream name, - * the KCL will throw errors. This usually requires deleting the backing - * DynamoDB table with the same name this Kinesis application. + * The way this Receiver works is as follows: + * + * - The receiver starts a KCL Worker, which is essentially runs a threadpool of multiple + * KinesisRecordProcessor + * - Each KinesisRecordProcessor receives data from a Kinesis shard in batches. Each batch is + * inserted into a Block Generator, and the corresponding range of sequence numbers is recorded. + * - When the block generator defines a block, then the recorded sequence number ranges that were + * inserted into the block are recorded separately for being used later. + * - When the block is ready to be pushed, the block is pushed and the ranges are reported as + * metadata of the block. In addition, the ranges are used to find out the latest sequence + * number for each shard that can be checkpointed through the DynamoDB. + * - Periodically, each KinesisRecordProcessor checkpoints the latest successfully stored sequence + * number for it own shard. + * * @param streamName Kinesis stream name * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) * @param regionName Region name used by the Kinesis Client Library for * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) - * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. 
- * See the Kinesis Spark Streaming documentation for more - * details on the different types of checkpoints. * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the * worker's initial starting position in the stream. * The values are either the beginning of the stream * per Kinesis' limit of 24 hours * (InitialPositionInStream.TRIM_HORIZON) or * the tip of the stream (InitialPositionInStream.LATEST). + * @param checkpointAppName Kinesis application name. Kinesis Apps are mapped to Kinesis Streams + * by the Kinesis Client Library. If you change the App name or Stream name, + * the KCL will throw errors. This usually requires deleting the backing + * DynamoDB table with the same name this Kinesis application. + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. * @param storageLevel Storage level to use for storing the received objects * @param awsCredentialsOption Optional AWS credentials, used when user directly specifies * the credentials */ -private[kinesis] class KinesisReceiver( - appName: String, - streamName: String, +private[kinesis] class KinesisReceiver[T]( + val streamName: String, endpointUrl: String, regionName: String, initialPositionInStream: InitialPositionInStream, + checkpointAppName: String, checkpointInterval: Duration, storageLevel: StorageLevel, - awsCredentialsOption: Option[SerializableAWSCredentials] - ) extends Receiver[Array[Byte]](storageLevel) with Logging { receiver => + messageHandler: Record => T, + awsCredentialsOption: Option[SerializableAWSCredentials]) + extends Receiver[T](storageLevel) with Logging { receiver => /* * ================================================================================= @@ -90,7 +104,7 @@ private[kinesis] class KinesisReceiver( * workerId is used by the KCL should be based on the ip address of the actual Spark Worker * where this code runs (not the driver's IP address.) */ - private var workerId: String = null + @volatile private var workerId: String = null /** * Worker is the core client abstraction from the Kinesis Client Library (KCL). @@ -98,22 +112,45 @@ private[kinesis] class KinesisReceiver( * Each shard is assigned its own IRecordProcessor and the worker run multiple such * processors. */ - private var worker: Worker = null + @volatile private var worker: Worker = null + @volatile private var workerThread: Thread = null - /** Thread running the worker */ - private var workerThread: Thread = null + /** BlockGenerator used to generates blocks out of Kinesis data */ + @volatile private var blockGenerator: BlockGenerator = null + + /** + * Sequence number ranges added to the current block being generated. + * Accessing and updating of this map is synchronized by locks in BlockGenerator. + */ + private val seqNumRangesInCurrentBlock = new mutable.ArrayBuffer[SequenceNumberRange] + + /** Sequence number ranges of data added to each generated block */ + private val blockIdToSeqNumRanges = new ConcurrentHashMap[StreamBlockId, SequenceNumberRanges] + + /** + * The centralized kinesisCheckpointer that checkpoints based on the given checkpointInterval. + */ + @volatile private var kinesisCheckpointer: KinesisCheckpointer = null + + /** + * Latest sequence number ranges that have been stored successfully. 
+ * This is used for checkpointing through KCL */ + private val shardIdToLatestStoredSeqNum = new ConcurrentHashMap[String, String] /** * This is called when the KinesisReceiver starts and must be non-blocking. * The KCL creates and manages the receiving/processing thread pool through Worker.run(). */ override def onStart() { + blockGenerator = supervisor.createBlockGenerator(new GeneratedBlockHandler) + workerId = Utils.localHostName() + ":" + UUID.randomUUID() + kinesisCheckpointer = new KinesisCheckpointer(receiver, checkpointInterval, workerId) // KCL config instance val awsCredProvider = resolveAWSCredentialsProvider() val kinesisClientLibConfiguration = - new KinesisClientLibConfiguration(appName, streamName, awsCredProvider, workerId) + new KinesisClientLibConfiguration(checkpointAppName, streamName, awsCredProvider, workerId) .withKinesisEndpoint(endpointUrl) .withInitialPositionInStream(initialPositionInStream) .withTaskBackoffTimeMillis(500) @@ -126,8 +163,8 @@ private[kinesis] class KinesisReceiver( * We're using our custom KinesisRecordProcessor in this case. */ val recordProcessorFactory = new IRecordProcessorFactory { - override def createProcessor: IRecordProcessor = new KinesisRecordProcessor(receiver, - workerId, new KinesisCheckpointState(checkpointInterval)) + override def createProcessor: IRecordProcessor = + new KinesisRecordProcessor(receiver, workerId) } worker = new Worker(recordProcessorFactory, kinesisClientLibConfiguration) @@ -141,6 +178,10 @@ private[kinesis] class KinesisReceiver( } } } + + blockIdToSeqNumRanges.clear() + blockGenerator.start() + workerThread.setName(s"Kinesis Receiver ${streamId}") workerThread.setDaemon(true) workerThread.start() @@ -163,6 +204,98 @@ private[kinesis] class KinesisReceiver( logInfo(s"Stopped receiver for workerId $workerId") } workerId = null + if (kinesisCheckpointer != null) { + kinesisCheckpointer.shutdown() + kinesisCheckpointer = null + } + } + + /** Add records of the given shard to the current block being generated */ + private[kinesis] def addRecords(shardId: String, records: java.util.List[Record]): Unit = { + if (records.size > 0) { + val dataIterator = records.iterator().asScala.map(messageHandler) + val metadata = SequenceNumberRange(streamName, shardId, + records.get(0).getSequenceNumber(), records.get(records.size() - 1).getSequenceNumber()) + blockGenerator.addMultipleDataWithCallback(dataIterator, metadata) + } + } + + /** Get the latest sequence number for the given shard that can be checkpointed through KCL */ + private[kinesis] def getLatestSeqNumToCheckpoint(shardId: String): Option[String] = { + Option(shardIdToLatestStoredSeqNum.get(shardId)) + } + + /** + * Set the checkpointer that will be used to checkpoint sequence numbers to DynamoDB for the + * given shardId. + */ + def setCheckpointer(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = { + assert(kinesisCheckpointer != null, "Kinesis Checkpointer not initialized!") + kinesisCheckpointer.setCheckpointer(shardId, checkpointer) + } + + /** + * Remove the checkpointer for the given shardId. The provided checkpointer will be used to + * checkpoint one last time for the given shard. If `checkpointer` is `null`, then we will not + * checkpoint. 
+ */ + def removeCheckpointer(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = { + assert(kinesisCheckpointer != null, "Kinesis Checkpointer not initialized!") + kinesisCheckpointer.removeCheckpointer(shardId, checkpointer) + } + + /** + * Remember the range of sequence numbers that was added to the currently active block. + * Internally, this is synchronized with `finalizeRangesForCurrentBlock()`. + */ + private def rememberAddedRange(range: SequenceNumberRange): Unit = { + seqNumRangesInCurrentBlock += range + } + + /** + * Finalize the ranges added to the block that was active and prepare the ranges buffer + * for next block. Internally, this is synchronized with `rememberAddedRange()`. + */ + private def finalizeRangesForCurrentBlock(blockId: StreamBlockId): Unit = { + blockIdToSeqNumRanges.put(blockId, SequenceNumberRanges(seqNumRangesInCurrentBlock.toArray)) + seqNumRangesInCurrentBlock.clear() + logDebug(s"Generated block $blockId has $blockIdToSeqNumRanges") + } + + /** Store the block along with its associated ranges */ + private def storeBlockWithRanges( + blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[T]): Unit = { + val rangesToReportOption = Option(blockIdToSeqNumRanges.remove(blockId)) + if (rangesToReportOption.isEmpty) { + stop("Error while storing block into Spark, could not find sequence number ranges " + + s"for block $blockId") + return + } + + val rangesToReport = rangesToReportOption.get + var attempt = 0 + var stored = false + var throwable: Throwable = null + while (!stored && attempt <= 3) { + try { + store(arrayBuffer, rangesToReport) + stored = true + } catch { + case NonFatal(th) => + attempt += 1 + throwable = th + } + } + if (!stored) { + stop("Error while storing block into Spark", throwable) + } + + // Update the latest sequence number that have been successfully stored for each shard + // Note that we are doing this sequentially because the array of sequence number ranges + // is assumed to be + rangesToReport.ranges.foreach { range => + shardIdToLatestStoredSeqNum.put(range.shardId, range.toSeqNumber) + } } /** @@ -182,4 +315,46 @@ private[kinesis] class KinesisReceiver( new DefaultAWSCredentialsProviderChain() } } + + + /** + * Class to handle blocks generated by this receiver's block generator. Specifically, in + * the context of the Kinesis Receiver, this handler does the following. + * + * - When an array of records is added to the current active block in the block generator, + * this handler keeps track of the corresponding sequence number range. + * - When the currently active block is ready to sealed (not more records), this handler + * keep track of the list of ranges added into this block in another H + */ + private class GeneratedBlockHandler extends BlockGeneratorListener { + + /** + * Callback method called after a data item is added into the BlockGenerator. + * The data addition, block generation, and calls to onAddData and onGenerateBlock + * are all synchronized through the same lock. + */ + def onAddData(data: Any, metadata: Any): Unit = { + rememberAddedRange(metadata.asInstanceOf[SequenceNumberRange]) + } + + /** + * Callback method called after a block has been generated. + * The data addition, block generation, and calls to onAddData and onGenerateBlock + * are all synchronized through the same lock. + */ + def onGenerateBlock(blockId: StreamBlockId): Unit = { + finalizeRangesForCurrentBlock(blockId) + } + + /** Callback method called when a block is ready to be pushed / stored. 
*/ + def onPushBlock(blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = { + storeBlockWithRanges(blockId, + arrayBuffer.asInstanceOf[mutable.ArrayBuffer[T]]) + } + + /** Callback called in case of any error in internal of the BlockGenerator */ + def onError(message: String, throwable: Throwable): Unit = { + reportError(message, throwable) + } + } } diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala index fe9e3a0c793e..b5b76cb92d86 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala @@ -18,39 +18,33 @@ package org.apache.spark.streaming.kinesis import java.util.List -import scala.collection.JavaConversions.asScalaBuffer import scala.util.Random +import scala.util.control.NonFatal -import org.apache.spark.Logging - -import com.amazonaws.services.kinesis.clientlibrary.exceptions.InvalidStateException -import com.amazonaws.services.kinesis.clientlibrary.exceptions.KinesisClientLibDependencyException -import com.amazonaws.services.kinesis.clientlibrary.exceptions.ShutdownException -import com.amazonaws.services.kinesis.clientlibrary.exceptions.ThrottlingException -import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessor -import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer +import com.amazonaws.services.kinesis.clientlibrary.exceptions.{InvalidStateException, KinesisClientLibDependencyException, ShutdownException, ThrottlingException} +import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorCheckpointer} import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason import com.amazonaws.services.kinesis.model.Record +import org.apache.spark.Logging +import org.apache.spark.streaming.Duration + /** * Kinesis-specific implementation of the Kinesis Client Library (KCL) IRecordProcessor. * This implementation operates on the Array[Byte] from the KinesisReceiver. * The Kinesis Worker creates an instance of this KinesisRecordProcessor for each - * shard in the Kinesis stream upon startup. This is normally done in separate threads, - * but the KCLs within the KinesisReceivers will balance themselves out if you create - * multiple Receivers. + * shard in the Kinesis stream upon startup. This is normally done in separate threads, + * but the KCLs within the KinesisReceivers will balance themselves out if you create + * multiple Receivers. * * @param receiver Kinesis receiver * @param workerId for logging purposes - * @param checkpointState represents the checkpoint state including the next checkpoint time. - * It's injected here for mocking purposes. 
*/ -private[kinesis] class KinesisRecordProcessor( - receiver: KinesisReceiver, - workerId: String, - checkpointState: KinesisCheckpointState) extends IRecordProcessor with Logging { +private[kinesis] class KinesisRecordProcessor[T](receiver: KinesisReceiver[T], workerId: String) + extends IRecordProcessor with Logging { - // shardId to be populated during initialize() + // shardId populated during initialize() + @volatile private var shardId: String = _ /** @@ -75,54 +69,18 @@ private[kinesis] class KinesisRecordProcessor( override def processRecords(batch: List[Record], checkpointer: IRecordProcessorCheckpointer) { if (!receiver.isStopped()) { try { - /* - * Notes: - * 1) If we try to store the raw ByteBuffer from record.getData(), the Spark Streaming - * Receiver.store(ByteBuffer) attempts to deserialize the ByteBuffer using the - * internally-configured Spark serializer (kryo, etc). - * 2) This is not desirable, so we instead store a raw Array[Byte] and decouple - * ourselves from Spark's internal serialization strategy. - * 3) For performance, the BlockGenerator is asynchronously queuing elements within its - * memory before creating blocks. This prevents the small block scenario, but requires - * that you register callbacks to know when a block has been generated and stored - * (WAL is sufficient for storage) before can checkpoint back to the source. - */ - batch.foreach(record => receiver.store(record.getData().array())) - - logDebug(s"Stored: Worker $workerId stored ${batch.size} records for shardId $shardId") - - /* - * Checkpoint the sequence number of the last record successfully processed/stored - * in the batch. - * In this implementation, we're checkpointing after the given checkpointIntervalMillis. - * Note that this logic requires that processRecords() be called AND that it's time to - * checkpoint. I point this out because there is no background thread running the - * checkpointer. Checkpointing is tested and trigger only when a new batch comes in. - * If the worker is shutdown cleanly, checkpoint will happen (see shutdown() below). - * However, if the worker dies unexpectedly, a checkpoint may not happen. - * This could lead to records being processed more than once. - */ - if (checkpointState.shouldCheckpoint()) { - /* Perform the checkpoint */ - KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(), 4, 100) - - /* Update the next checkpoint time */ - checkpointState.advanceCheckpoint() - - logDebug(s"Checkpoint: WorkerId $workerId completed checkpoint of ${batch.size}" + - s" records for shardId $shardId") - logDebug(s"Checkpoint: Next checkpoint is at " + - s" ${checkpointState.checkpointClock.getTimeMillis()} for shardId $shardId") - } + receiver.addRecords(shardId, batch) + logDebug(s"Stored: Worker $workerId stored ${batch.size} records for shardId $shardId") + receiver.setCheckpointer(shardId, checkpointer) } catch { - case e: Throwable => { + case NonFatal(e) => { /* * If there is a failure within the batch, the batch will not be checkpointed. * This will potentially cause records since the last checkpoint to be processed * more than once. */ logError(s"Exception: WorkerId $workerId encountered and exception while storing " + - " or checkpointing a batch for workerId $workerId and shardId $shardId.", e) + s" or checkpointing a batch for workerId $workerId and shardId $shardId.", e) /* Rethrow the exception to the Kinesis Worker that is managing this RecordProcessor. 
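The checkpoint call removed here was wrapped in `KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(), 4, 100)`. The following is a rough, self-contained sketch of what such a retry-with-random-backoff helper can look like, assuming only that throttling and KCL dependency errors are worth retrying; the object name and exact policy are illustrative, not the helper's actual implementation.

import scala.util.{Failure, Random, Success, Try}

import com.amazonaws.services.kinesis.clientlibrary.exceptions.{KinesisClientLibDependencyException, ThrottlingException}

object RetrySketch {

  private def isTransient(e: Throwable): Boolean = e match {
    case _: ThrottlingException | _: KinesisClientLibDependencyException => true
    case _ => false
  }

  /** Evaluate `expression`, retrying transient KCL failures with a random back-off. */
  def retryRandom[T](expression: => T, numRetriesLeft: Int, maxBackOffMillis: Int): T = {
    Try(expression) match {
      case Success(result) =>
        result
      case Failure(e) if isTransient(e) && numRetriesLeft > 1 =>
        Thread.sleep(Random.nextInt(maxBackOffMillis))  // maxBackOffMillis assumed > 0
        retryRandom(expression, numRetriesLeft - 1, maxBackOffMillis)
      case Failure(e) =>
        throw e
    }
  }
}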
*/ throw e @@ -130,7 +88,7 @@ private[kinesis] class KinesisRecordProcessor( } } else { /* RecordProcessor has been stopped. */ - logInfo(s"Stopped: The Spark KinesisReceiver has stopped for workerId $workerId" + + logInfo(s"Stopped: KinesisReceiver has stopped for workerId $workerId" + s" and shardId $shardId. No more records will be processed.") } } @@ -154,19 +112,18 @@ private[kinesis] class KinesisRecordProcessor( * It's now OK to read from the new shards that resulted from a resharding event. */ case ShutdownReason.TERMINATE => - KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(), 4, 100) + receiver.removeCheckpointer(shardId, checkpointer) /* - * ZOMBIE Use Case. NoOp. + * ZOMBIE Use Case or Unknown reason. NoOp. * No checkpoint because other workers may have taken over and already started processing * the same records. * This may lead to records being processed more than once. */ - case ShutdownReason.ZOMBIE => - - /* Unknown reason. NoOp */ case _ => + receiver.removeCheckpointer(shardId, null) // return null so that we don't checkpoint } + } } diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala index 255ac27f793b..0ace453ee928 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala @@ -20,6 +20,7 @@ package org.apache.spark.streaming.kinesis import java.nio.ByteBuffer import java.util.concurrent.TimeUnit +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.{Failure, Random, Success, Try} @@ -34,24 +35,14 @@ import com.amazonaws.services.kinesis.model._ import org.apache.spark.Logging /** - * Shared utility methods for performing Kinesis tests that actually transfer data + * Shared utility methods for performing Kinesis tests that actually transfer data. + * + * PLEASE KEEP THIS FILE UNDER src/main AS PYTHON TESTS NEED ACCESS TO THIS FILE! 
*/ -private class KinesisTestUtils(val endpointUrl: String, _regionName: String) extends Logging { - - def this() { - this("https://kinesis.us-west-2.amazonaws.com", "") - } - - def this(endpointUrl: String) { - this(endpointUrl, "") - } - - val regionName = if (_regionName.length == 0) { - RegionUtils.getRegionByEndpoint(endpointUrl).getName() - } else { - RegionUtils.getRegion(_regionName).getName() - } +private[kinesis] class KinesisTestUtils extends Logging { + val endpointUrl = KinesisTestUtils.endpointUrl + val regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName() val streamShardCount = 2 private val createStreamTimeoutSeconds = 300 @@ -63,7 +54,7 @@ private class KinesisTestUtils(val endpointUrl: String, _regionName: String) ext @volatile private var _streamName: String = _ - private lazy val kinesisClient = { + protected lazy val kinesisClient = { val client = new AmazonKinesisClient(KinesisTestUtils.getAWSCredentials()) client.setEndpoint(endpointUrl) client @@ -75,17 +66,25 @@ private class KinesisTestUtils(val endpointUrl: String, _regionName: String) ext new DynamoDB(dynamoDBClient) } + protected def getProducer(aggregate: Boolean): KinesisDataGenerator = { + if (!aggregate) { + new SimpleDataGenerator(kinesisClient) + } else { + throw new UnsupportedOperationException("Aggregation is not supported through this code path") + } + } + def streamName: String = { require(streamCreated, "Stream not yet created, call createStream() to create one") _streamName } def createStream(): Unit = { - logInfo("Creating stream") require(!streamCreated, "Stream already created") _streamName = findNonExistentStreamName() // Create a stream. The number of shards determines the provisioned throughput. + logInfo(s"Creating stream ${_streamName}") val createStreamRequest = new CreateStreamRequest() createStreamRequest.setStreamName(_streamName) createStreamRequest.setShardCount(2) @@ -94,31 +93,17 @@ private class KinesisTestUtils(val endpointUrl: String, _regionName: String) ext // The stream is now being created. Wait for it to become active. waitForStreamToBeActive(_streamName) streamCreated = true - logInfo("Created stream") + logInfo(s"Created stream ${_streamName}") } /** * Push data to Kinesis stream and return a map of * shardId -> seq of (data, seq number) pushed to corresponding shard */ - def pushData(testData: Seq[Int]): Map[String, Seq[(Int, String)]] = { + def pushData(testData: Seq[Int], aggregate: Boolean): Map[String, Seq[(Int, String)]] = { require(streamCreated, "Stream not yet created, call createStream() to create one") - val shardIdToSeqNumbers = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]() - - testData.foreach { num => - val str = num.toString - val putRecordRequest = new PutRecordRequest().withStreamName(streamName) - .withData(ByteBuffer.wrap(str.getBytes())) - .withPartitionKey(str) - - val putRecordResult = kinesisClient.putRecord(putRecordRequest) - val shardId = putRecordResult.getShardId - val seqNumber = putRecordResult.getSequenceNumber() - val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId, - new ArrayBuffer[(Int, String)]()) - sentSeqNumbers += ((num, seqNumber)) - } - + val producer = getProducer(aggregate) + val shardIdToSeqNumbers = producer.sendData(streamName, testData) logInfo(s"Pushed $testData:\n\t ${shardIdToSeqNumbers.mkString("\n\t")}") shardIdToSeqNumbers.toMap } @@ -127,7 +112,7 @@ private class KinesisTestUtils(val endpointUrl: String, _regionName: String) ext * Expose a Python friendly API. 
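A hedged usage sketch of the reworked test utility, assuming AWS credentials are resolvable and `ENABLE_KINESIS_TESTS=1` is set; the calls below create real, billable Kinesis and DynamoDB resources, and the assertion is only illustrative.

// Runs only inside the org.apache.spark.streaming.kinesis package (the utility is
// private[kinesis]) and only when Kinesis tests are enabled.
val testUtils = new KinesisTestUtils()
try {
  testUtils.createStream()
  // shardId -> Seq of (sent value, sequence number), handy for per-shard ordering checks
  val shardIdToSeqNumbers = testUtils.pushData(1 to 8, aggregate = false)
  assert(shardIdToSeqNumbers.values.flatten.map(_._1).toSet == (1 to 8).toSet)
} finally {
  testUtils.deleteStream()
}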
*/ def pushData(testData: java.util.List[Int]): Unit = { - pushData(scala.collection.JavaConversions.asScalaBuffer(testData)) + pushData(testData.asScala, aggregate = false) } def deleteStream(): Unit = { @@ -191,9 +176,38 @@ private class KinesisTestUtils(val endpointUrl: String, _regionName: String) ext private[kinesis] object KinesisTestUtils { - val envVarName = "ENABLE_KINESIS_TESTS" + val envVarNameForEnablingTests = "ENABLE_KINESIS_TESTS" + val endVarNameForEndpoint = "KINESIS_TEST_ENDPOINT_URL" + val defaultEndpointUrl = "https://kinesis.us-west-2.amazonaws.com" + + lazy val shouldRunTests = { + val isEnvSet = sys.env.get(envVarNameForEnablingTests) == Some("1") + if (isEnvSet) { + // scalastyle:off println + // Print this so that they are easily visible on the console and not hidden in the log4j logs. + println( + s""" + |Kinesis tests that actually send data has been enabled by setting the environment + |variable $envVarNameForEnablingTests to 1. This will create Kinesis Streams and + |DynamoDB tables in AWS. Please be aware that this may incur some AWS costs. + |By default, the tests use the endpoint URL $defaultEndpointUrl to create Kinesis streams. + |To change this endpoint URL to a different region, you can set the environment variable + |$endVarNameForEndpoint to the desired endpoint URL + |(e.g. $endVarNameForEndpoint="https://kinesis.us-west-2.amazonaws.com"). + """.stripMargin) + // scalastyle:on println + } + isEnvSet + } - val shouldRunTests = sys.env.get(envVarName) == Some("1") + lazy val endpointUrl = { + val url = sys.env.getOrElse(endVarNameForEndpoint, defaultEndpointUrl) + // scalastyle:off println + // Print this so that they are easily visible on the console and not hidden in the log4j logs. + println(s"Using endpoint URL $url for creating Kinesis streams for tests.") + // scalastyle:on println + url + } def isAWSCredentialsPresent: Boolean = { Try { new DefaultAWSCredentialsProviderChain().getCredentials() }.isSuccess @@ -205,7 +219,42 @@ private[kinesis] object KinesisTestUtils { Try { new DefaultAWSCredentialsProviderChain().getCredentials() } match { case Success(cred) => cred case Failure(e) => - throw new Exception("Kinesis tests enabled, but could get not AWS credentials") + throw new Exception( + s""" + |Kinesis tests enabled using environment variable $envVarNameForEnablingTests + |but could not find AWS credentials. Please follow instructions in AWS documentation + |to set the credentials in your system such that the DefaultAWSCredentialsProviderChain + |can find the credentials. + """.stripMargin) + } + } +} + +/** A wrapper interface that will allow us to consolidate the code for synthetic data generation. */ +private[kinesis] trait KinesisDataGenerator { + /** Sends the data to Kinesis and returns the metadata for everything that has been sent. 
*/ + def sendData(streamName: String, data: Seq[Int]): Map[String, Seq[(Int, String)]] +} + +private[kinesis] class SimpleDataGenerator( + client: AmazonKinesisClient) extends KinesisDataGenerator { + override def sendData(streamName: String, data: Seq[Int]): Map[String, Seq[(Int, String)]] = { + val shardIdToSeqNumbers = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]() + data.foreach { num => + val str = num.toString + val data = ByteBuffer.wrap(str.getBytes()) + val putRecordRequest = new PutRecordRequest().withStreamName(streamName) + .withData(data) + .withPartitionKey(str) + + val putRecordResult = client.putRecord(putRecordRequest) + val shardId = putRecordResult.getShardId + val seqNumber = putRecordResult.getSequenceNumber() + val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId, + new ArrayBuffer[(Int, String)]()) + sentSeqNumbers += ((num, seqNumber)) } + + shardIdToSeqNumbers.toMap } } diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala index 7dab17eba848..2de6195716e5 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -16,16 +16,120 @@ */ package org.apache.spark.streaming.kinesis +import scala.reflect.ClassTag + import com.amazonaws.regions.RegionUtils import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream +import com.amazonaws.services.kinesis.model.Record +import org.apache.spark.api.java.function.{Function => JFunction} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.{Duration, StreamingContext} - object KinesisUtils { + /** + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. + * + * Note: The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets the AWS credentials. + * + * @param ssc StreamingContext object + * @param kinesisAppName Kinesis application name used by the Kinesis Client Library + * (KCL) to update DynamoDB + * @param streamName Kinesis stream name + * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param regionName Name of region used by the Kinesis Client Library (KCL) to update + * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param storageLevel Storage level to use for storing the received objects. + * StorageLevel.MEMORY_AND_DISK_2 is recommended. 
+ * @param messageHandler A custom message handler that can generate a generic output from a + * Kinesis `Record`, which contains both message data, and metadata. + */ + def createStream[T: ClassTag]( + ssc: StreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: InitialPositionInStream, + checkpointInterval: Duration, + storageLevel: StorageLevel, + messageHandler: Record => T): ReceiverInputDStream[T] = { + val cleanedHandler = ssc.sc.clean(messageHandler) + // Setting scope to override receiver stream's scope of "receiver stream" + ssc.withNamedScope("kinesis stream") { + new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName), + initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, + cleanedHandler, None) + } + } + + /** + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. + * + * Note: + * The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure. + * + * @param ssc StreamingContext object + * @param kinesisAppName Kinesis application name used by the Kinesis Client Library + * (KCL) to update DynamoDB + * @param streamName Kinesis stream name + * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param regionName Name of region used by the Kinesis Client Library (KCL) to update + * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param storageLevel Storage level to use for storing the received objects. + * StorageLevel.MEMORY_AND_DISK_2 is recommended. + * @param messageHandler A custom message handler that can generate a generic output from a + * Kinesis `Record`, which contains both message data, and metadata. + * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) + * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) + */ + // scalastyle:off + def createStream[T: ClassTag]( + ssc: StreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: InitialPositionInStream, + checkpointInterval: Duration, + storageLevel: StorageLevel, + messageHandler: Record => T, + awsAccessKeyId: String, + awsSecretKey: String): ReceiverInputDStream[T] = { + // scalastyle:on + val cleanedHandler = ssc.sc.clean(messageHandler) + ssc.withNamedScope("kinesis stream") { + new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName), + initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, + cleanedHandler, Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey))) + } + } + /** * Create an input stream that pulls messages from a Kinesis stream. 
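A usage sketch for the new generic overload above. The application, stream, endpoint, and region values are placeholders, and `local[4]` is only for local experimentation; the handler turns each Kinesis `Record` into a `(partitionKey, payload)` pair instead of the default raw `Array[Byte]`.

import java.nio.charset.StandardCharsets

import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
import com.amazonaws.services.kinesis.model.Record

import org.apache.spark.SparkConf
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.kinesis.KinesisUtils

object CustomHandlerExample {
  def main(args: Array[String]): Unit = {
    // local[4] is only for trying this out locally; a receiver needs at least one free core.
    val conf = new SparkConf().setAppName("KinesisHandlerExample").setMaster("local[4]")
    val ssc = new StreamingContext(conf, Seconds(2))

    // The handler runs on the receiver, so it must be serializable.
    val handler: Record => (String, String) = { record =>
      val bytes = new Array[Byte](record.getData.remaining())
      record.getData.get(bytes)
      (record.getPartitionKey, new String(bytes, StandardCharsets.UTF_8))
    }

    val pairs = KinesisUtils.createStream(
      ssc, "myKinesisApp", "myKinesisStream",
      "https://kinesis.us-west-2.amazonaws.com", "us-west-2",
      InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2, handler)

    pairs.map(_._2).print()
    ssc.start()
    ssc.awaitTermination()
  }
}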
* This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. @@ -61,13 +165,12 @@ object KinesisUtils { regionName: String, initialPositionInStream: InitialPositionInStream, checkpointInterval: Duration, - storageLevel: StorageLevel - ): ReceiverInputDStream[Array[Byte]] = { + storageLevel: StorageLevel): ReceiverInputDStream[Array[Byte]] = { // Setting scope to override receiver stream's scope of "receiver stream" ssc.withNamedScope("kinesis stream") { - ssc.receiverStream( - new KinesisReceiver(kinesisAppName, streamName, endpointUrl, validateRegion(regionName), - initialPositionInStream, checkpointInterval, storageLevel, None)) + new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, validateRegion(regionName), + initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, + defaultMessageHandler, None) } } @@ -110,12 +213,12 @@ object KinesisUtils { checkpointInterval: Duration, storageLevel: StorageLevel, awsAccessKeyId: String, - awsSecretKey: String - ): ReceiverInputDStream[Array[Byte]] = { - ssc.receiverStream( - new KinesisReceiver(kinesisAppName, streamName, endpointUrl, validateRegion(regionName), - initialPositionInStream, checkpointInterval, storageLevel, - Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey)))) + awsSecretKey: String): ReceiverInputDStream[Array[Byte]] = { + ssc.withNamedScope("kinesis stream") { + new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, validateRegion(regionName), + initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, + defaultMessageHandler, Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey))) + } } /** @@ -123,12 +226,13 @@ object KinesisUtils { * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * * Note: - * - The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain - * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain - * gets AWS credentials. - * - The region of the `endpointUrl` will be used for DynamoDB and CloudWatch. - * - The Kinesis application name used by the Kinesis Client Library (KCL) will be the app name in - * [[org.apache.spark.SparkConf]]. + * + * - The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets AWS credentials. + * - The region of the `endpointUrl` will be used for DynamoDB and CloudWatch. + * - The Kinesis application name used by the Kinesis Client Library (KCL) will be the app name + * in [[org.apache.spark.SparkConf]]. * * @param ssc StreamingContext object * @param streamName Kinesis stream name @@ -155,9 +259,112 @@ object KinesisUtils { initialPositionInStream: InitialPositionInStream, storageLevel: StorageLevel ): ReceiverInputDStream[Array[Byte]] = { - ssc.receiverStream( - new KinesisReceiver(ssc.sc.appName, streamName, endpointUrl, getRegionByEndpoint(endpointUrl), - initialPositionInStream, checkpointInterval, storageLevel, None)) + ssc.withNamedScope("kinesis stream") { + new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, + getRegionByEndpoint(endpointUrl), initialPositionInStream, ssc.sc.appName, + checkpointInterval, storageLevel, defaultMessageHandler, None) + } + } + + /** + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. 
+ * + * Note: The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets the AWS credentials. + * + * @param jssc Java StreamingContext object + * @param kinesisAppName Kinesis application name used by the Kinesis Client Library + * (KCL) to update DynamoDB + * @param streamName Kinesis stream name + * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param regionName Name of region used by the Kinesis Client Library (KCL) to update + * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param storageLevel Storage level to use for storing the received objects. + * StorageLevel.MEMORY_AND_DISK_2 is recommended. + * @param messageHandler A custom message handler that can generate a generic output from a + * Kinesis `Record`, which contains both message data, and metadata. + * @param recordClass Class of the records in DStream + */ + def createStream[T]( + jssc: JavaStreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: InitialPositionInStream, + checkpointInterval: Duration, + storageLevel: StorageLevel, + messageHandler: JFunction[Record, T], + recordClass: Class[T]): JavaReceiverInputDStream[T] = { + implicit val recordCmt: ClassTag[T] = ClassTag(recordClass) + val cleanedHandler = jssc.sparkContext.clean(messageHandler.call(_)) + createStream[T](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, + initialPositionInStream, checkpointInterval, storageLevel, cleanedHandler) + } + + /** + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. + * + * Note: + * The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure. + * + * @param jssc Java StreamingContext object + * @param kinesisAppName Kinesis application name used by the Kinesis Client Library + * (KCL) to update DynamoDB + * @param streamName Kinesis stream name + * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param regionName Name of region used by the Kinesis Client Library (KCL) to update + * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. 
+ * @param storageLevel Storage level to use for storing the received objects. + * StorageLevel.MEMORY_AND_DISK_2 is recommended. + * @param messageHandler A custom message handler that can generate a generic output from a + * Kinesis `Record`, which contains both message data, and metadata. + * @param recordClass Class of the records in DStream + * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) + * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) + */ + // scalastyle:off + def createStream[T]( + jssc: JavaStreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: InitialPositionInStream, + checkpointInterval: Duration, + storageLevel: StorageLevel, + messageHandler: JFunction[Record, T], + recordClass: Class[T], + awsAccessKeyId: String, + awsSecretKey: String): JavaReceiverInputDStream[T] = { + // scalastyle:on + implicit val recordCmt: ClassTag[T] = ClassTag(recordClass) + val cleanedHandler = jssc.sparkContext.clean(messageHandler.call(_)) + createStream[T](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, + initialPositionInStream, checkpointInterval, storageLevel, cleanedHandler, + awsAccessKeyId, awsSecretKey) } /** @@ -197,8 +404,8 @@ object KinesisUtils { checkpointInterval: Duration, storageLevel: StorageLevel ): JavaReceiverInputDStream[Array[Byte]] = { - createStream(jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, - initialPositionInStream, checkpointInterval, storageLevel) + createStream[Array[Byte]](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, + initialPositionInStream, checkpointInterval, storageLevel, defaultMessageHandler(_)) } /** @@ -240,10 +447,10 @@ object KinesisUtils { checkpointInterval: Duration, storageLevel: StorageLevel, awsAccessKeyId: String, - awsSecretKey: String - ): JavaReceiverInputDStream[Array[Byte]] = { - createStream(jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, - initialPositionInStream, checkpointInterval, storageLevel, awsAccessKeyId, awsSecretKey) + awsSecretKey: String): JavaReceiverInputDStream[Array[Byte]] = { + createStream[Array[Byte]](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, + initialPositionInStream, checkpointInterval, storageLevel, + defaultMessageHandler(_), awsAccessKeyId, awsSecretKey) } /** @@ -296,6 +503,14 @@ object KinesisUtils { throw new IllegalArgumentException(s"Region name '$regionName' is not valid") } } + + private[kinesis] def defaultMessageHandler(record: Record): Array[Byte] = { + if (record == null) return null + val byteBuffer = record.getData() + val byteArray = new Array[Byte](byteBuffer.remaining()) + byteBuffer.get(byteArray) + byteArray + } } /** diff --git a/extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java b/extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java index 87954a31f60c..3f0f6793d2d2 100644 --- a/extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java +++ b/extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java @@ -17,14 +17,19 @@ package org.apache.spark.streaming.kinesis; +import com.amazonaws.services.kinesis.model.Record; +import org.junit.Test; + +import org.apache.spark.api.java.function.Function; import org.apache.spark.storage.StorageLevel; import org.apache.spark.streaming.Duration; 
import org.apache.spark.streaming.LocalJavaStreamingContext; import org.apache.spark.streaming.api.java.JavaDStream; -import org.junit.Test; import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; +import java.nio.ByteBuffer; + /** * Demonstrate the use of the KinesisUtils Java API */ @@ -33,9 +38,27 @@ public class JavaKinesisStreamSuite extends LocalJavaStreamingContext { public void testKinesisStream() { // Tests the API, does not actually test data receiving JavaDStream kinesisStream = KinesisUtils.createStream(ssc, "mySparkStream", - "https://kinesis.us-west-2.amazonaws.com", new Duration(2000), + "https://kinesis.us-west-2.amazonaws.com", new Duration(2000), InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2()); - + + ssc.stop(); + } + + + private static Function handler = new Function() { + @Override + public String call(Record record) { + return record.getPartitionKey() + "-" + record.getSequenceNumber(); + } + }; + + @Test + public void testCustomHandler() { + // Tests the API, does not actually test data receiving + JavaDStream kinesisStream = KinesisUtils.createStream(ssc, "testApp", "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", "us-west-2", InitialPositionInStream.LATEST, + new Duration(2000), StorageLevel.MEMORY_AND_DISK_2(), handler, String.class); + ssc.stop(); } } diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala new file mode 100644 index 000000000000..fdb270eaad8c --- /dev/null +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.kinesis + +import java.nio.ByteBuffer + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import com.amazonaws.services.kinesis.producer.{KinesisProducer => KPLProducer, KinesisProducerConfiguration, UserRecordResult} +import com.google.common.util.concurrent.{FutureCallback, Futures} + +private[kinesis] class KPLBasedKinesisTestUtils extends KinesisTestUtils { + override protected def getProducer(aggregate: Boolean): KinesisDataGenerator = { + if (!aggregate) { + new SimpleDataGenerator(kinesisClient) + } else { + new KPLDataGenerator(regionName) + } + } +} + +/** A wrapper for the KinesisProducer provided in the KPL. 
*/ +private[kinesis] class KPLDataGenerator(regionName: String) extends KinesisDataGenerator { + + private lazy val producer: KPLProducer = { + val conf = new KinesisProducerConfiguration() + .setRecordMaxBufferedTime(1000) + .setMaxConnections(1) + .setRegion(regionName) + .setMetricsLevel("none") + + new KPLProducer(conf) + } + + override def sendData(streamName: String, data: Seq[Int]): Map[String, Seq[(Int, String)]] = { + val shardIdToSeqNumbers = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]() + data.foreach { num => + val str = num.toString + val data = ByteBuffer.wrap(str.getBytes()) + val future = producer.addUserRecord(streamName, str, data) + val kinesisCallBack = new FutureCallback[UserRecordResult]() { + override def onFailure(t: Throwable): Unit = {} // do nothing + + override def onSuccess(result: UserRecordResult): Unit = { + val shardId = result.getShardId + val seqNumber = result.getSequenceNumber() + val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId, + new ArrayBuffer[(Int, String)]()) + sentSeqNumbers += ((num, seqNumber)) + } + } + Futures.addCallback(future, kinesisCallBack) + } + producer.flushSync() + shardIdToSeqNumbers.toMap + } +} diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala index e81fb11e5959..d85b4cda8ce9 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala @@ -22,10 +22,9 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId} import org.apache.spark.{SparkConf, SparkContext, SparkException} -class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll { +abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean) + extends KinesisFunSuite with BeforeAndAfterAll { - private val regionId = "us-east-1" - private val endpointUrl = "https://kinesis.us-east-1.amazonaws.com" private val testData = 1 to 8 private var testUtils: KinesisTestUtils = null @@ -39,13 +38,12 @@ class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll private var sc: SparkContext = null private var blockManager: BlockManager = null - override def beforeAll(): Unit = { runIfTestsEnabled("Prepare KinesisTestUtils") { - testUtils = new KinesisTestUtils(endpointUrl) + testUtils = new KPLBasedKinesisTestUtils() testUtils.createStream() - shardIdToDataAndSeqNumbers = testUtils.pushData(testData) + shardIdToDataAndSeqNumbers = testUtils.pushData(testData, aggregate = aggregateTestData) require(shardIdToDataAndSeqNumbers.size > 1, "Need data to be sent to multiple shards") shardIds = shardIdToDataAndSeqNumbers.keySet.toSeq @@ -75,22 +73,22 @@ class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll testIfEnabled("Basic reading from Kinesis") { // Verify all data using multiple ranges in a single RDD partition - val receivedData1 = new KinesisBackedBlockRDD(sc, regionId, endpointUrl, - fakeBlockIds(1), + val receivedData1 = new KinesisBackedBlockRDD[Array[Byte]](sc, testUtils.regionName, + testUtils.endpointUrl, fakeBlockIds(1), Array(SequenceNumberRanges(allRanges.toArray)) ).map { bytes => new String(bytes).toInt }.collect() assert(receivedData1.toSet === testData.toSet) // Verify all data 
using one range in each of the multiple RDD partitions - val receivedData2 = new KinesisBackedBlockRDD(sc, regionId, endpointUrl, - fakeBlockIds(allRanges.size), + val receivedData2 = new KinesisBackedBlockRDD[Array[Byte]](sc, testUtils.regionName, + testUtils.endpointUrl, fakeBlockIds(allRanges.size), allRanges.map { range => SequenceNumberRanges(Array(range)) }.toArray ).map { bytes => new String(bytes).toInt }.collect() assert(receivedData2.toSet === testData.toSet) // Verify ordering within each partition - val receivedData3 = new KinesisBackedBlockRDD(sc, regionId, endpointUrl, - fakeBlockIds(allRanges.size), + val receivedData3 = new KinesisBackedBlockRDD[Array[Byte]](sc, testUtils.regionName, + testUtils.endpointUrl, fakeBlockIds(allRanges.size), allRanges.map { range => SequenceNumberRanges(Array(range)) }.toArray ).map { bytes => new String(bytes).toInt }.collectPartitions() assert(receivedData3.length === allRanges.size) @@ -211,7 +209,8 @@ class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll }, "Incorrect configuration of RDD, unexpected ranges set" ) - val rdd = new KinesisBackedBlockRDD(sc, regionId, endpointUrl, blockIds, ranges) + val rdd = new KinesisBackedBlockRDD[Array[Byte]]( + sc, testUtils.regionName, testUtils.endpointUrl, blockIds, ranges) val collectedData = rdd.map { bytes => new String(bytes).toInt }.collect() @@ -224,8 +223,9 @@ class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll if (testIsBlockValid) { require(numPartitionsInBM === numPartitions, "All partitions must be in BlockManager") require(numPartitionsInKinesis === 0, "No partitions must be in Kinesis") - val rdd2 = new KinesisBackedBlockRDD(sc, regionId, endpointUrl, blockIds.toArray, - ranges, isBlockIdValid = Array.fill(blockIds.length)(false)) + val rdd2 = new KinesisBackedBlockRDD[Array[Byte]]( + sc, testUtils.regionName, testUtils.endpointUrl, blockIds.toArray, ranges, + isBlockIdValid = Array.fill(blockIds.length)(false)) intercept[SparkException] { rdd2.collect() } @@ -247,3 +247,9 @@ class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll Array.tabulate(num) { i => new StreamBlockId(0, i) } } } + +class WithAggregationKinesisBackedBlockRDDSuite + extends KinesisBackedBlockRDDTests(aggregateTestData = true) + +class WithoutAggregationKinesisBackedBlockRDDSuite + extends KinesisBackedBlockRDDTests(aggregateTestData = false) diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala new file mode 100644 index 000000000000..645e64a0bc3a --- /dev/null +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis + +import java.util.concurrent.{TimeoutException, ExecutorService} + +import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.duration._ +import scala.language.postfixOps + +import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer +import org.mockito.Matchers._ +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.{PrivateMethodTester, BeforeAndAfterEach} +import org.scalatest.concurrent.Eventually +import org.scalatest.concurrent.Eventually._ +import org.scalatest.mock.MockitoSugar + +import org.apache.spark.streaming.{Duration, TestSuiteBase} +import org.apache.spark.util.ManualClock + +class KinesisCheckpointerSuite extends TestSuiteBase + with MockitoSugar + with BeforeAndAfterEach + with PrivateMethodTester + with Eventually { + + private val workerId = "dummyWorkerId" + private val shardId = "dummyShardId" + private val seqNum = "123" + private val otherSeqNum = "245" + private val checkpointInterval = Duration(10) + private val someSeqNum = Some(seqNum) + private val someOtherSeqNum = Some(otherSeqNum) + + private var receiverMock: KinesisReceiver[Array[Byte]] = _ + private var checkpointerMock: IRecordProcessorCheckpointer = _ + private var kinesisCheckpointer: KinesisCheckpointer = _ + private var clock: ManualClock = _ + + private val checkpoint = PrivateMethod[Unit]('checkpoint) + + override def beforeEach(): Unit = { + receiverMock = mock[KinesisReceiver[Array[Byte]]] + checkpointerMock = mock[IRecordProcessorCheckpointer] + clock = new ManualClock() + kinesisCheckpointer = new KinesisCheckpointer(receiverMock, checkpointInterval, workerId, clock) + } + + test("checkpoint is not called twice for the same sequence number") { + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) + kinesisCheckpointer.invokePrivate(checkpoint(shardId, checkpointerMock)) + kinesisCheckpointer.invokePrivate(checkpoint(shardId, checkpointerMock)) + + verify(checkpointerMock, times(1)).checkpoint(anyString()) + } + + test("checkpoint is called after sequence number increases") { + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)) + .thenReturn(someSeqNum).thenReturn(someOtherSeqNum) + kinesisCheckpointer.invokePrivate(checkpoint(shardId, checkpointerMock)) + kinesisCheckpointer.invokePrivate(checkpoint(shardId, checkpointerMock)) + + verify(checkpointerMock, times(1)).checkpoint(seqNum) + verify(checkpointerMock, times(1)).checkpoint(otherSeqNum) + } + + test("should checkpoint if we have exceeded the checkpoint interval") { + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)) + .thenReturn(someSeqNum).thenReturn(someOtherSeqNum) + + kinesisCheckpointer.setCheckpointer(shardId, checkpointerMock) + clock.advance(5 * checkpointInterval.milliseconds) + + eventually(timeout(1 second)) { + verify(checkpointerMock, times(1)).checkpoint(seqNum) + verify(checkpointerMock, times(1)).checkpoint(otherSeqNum) + } + } + + test("shouldn't checkpoint if we have not exceeded the checkpoint interval") { + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) + + kinesisCheckpointer.setCheckpointer(shardId, checkpointerMock) + clock.advance(checkpointInterval.milliseconds / 2) + + verify(checkpointerMock, never()).checkpoint(anyString()) + } + + test("should 
not checkpoint for the same sequence number") { + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) + + kinesisCheckpointer.setCheckpointer(shardId, checkpointerMock) + + clock.advance(checkpointInterval.milliseconds * 5) + eventually(timeout(1 second)) { + verify(checkpointerMock, atMost(1)).checkpoint(anyString()) + } + } + + test("removing checkpointer checkpoints one last time") { + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) + + kinesisCheckpointer.removeCheckpointer(shardId, checkpointerMock) + verify(checkpointerMock, times(1)).checkpoint(anyString()) + } + + test("if checkpointing is going on, wait until finished before removing and checkpointing") { + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)) + .thenReturn(someSeqNum).thenReturn(someOtherSeqNum) + when(checkpointerMock.checkpoint(anyString)).thenAnswer(new Answer[Unit] { + override def answer(invocations: InvocationOnMock): Unit = { + clock.waitTillTime(clock.getTimeMillis() + checkpointInterval.milliseconds / 2) + } + }) + + kinesisCheckpointer.setCheckpointer(shardId, checkpointerMock) + clock.advance(checkpointInterval.milliseconds) + eventually(timeout(1 second)) { + verify(checkpointerMock, times(1)).checkpoint(anyString()) + } + // don't block test thread + val f = Future(kinesisCheckpointer.removeCheckpointer(shardId, checkpointerMock))( + ExecutionContext.global) + + intercept[TimeoutException] { + Await.ready(f, 50 millis) + } + + clock.advance(checkpointInterval.milliseconds / 2) + eventually(timeout(1 second)) { + verify(checkpointerMock, times(2)).checkpoint(anyString()) + } + } +} diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala index 8373138785a8..ee428f31d6ce 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala @@ -31,7 +31,7 @@ trait KinesisFunSuite extends SparkFunSuite { if (shouldRunTests) { test(testName)(testBody) } else { - ignore(s"$testName [enable by setting env var $envVarName=1]")(testBody) + ignore(s"$testName [enable by setting env var $envVarNameForEnablingTests=1]")(testBody) } } @@ -40,7 +40,7 @@ trait KinesisFunSuite extends SparkFunSuite { if (shouldRunTests) { body } else { - ignore(s"$message [enable by setting env var $envVarName=1]")() + ignore(s"$message [enable by setting env var $envVarNameForEnablingTests=1]")() } } } diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala index 98f2c7c4f1bf..e5c70db554a2 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala @@ -17,21 +17,21 @@ package org.apache.spark.streaming.kinesis import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets +import java.util.Arrays -import scala.collection.JavaConversions.seqAsJavaList - -import com.amazonaws.services.kinesis.clientlibrary.exceptions.{InvalidStateException, KinesisClientLibDependencyException, ShutdownException, ThrottlingException} +import com.amazonaws.services.kinesis.clientlibrary.exceptions._ import 
com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason import com.amazonaws.services.kinesis.model.Record +import org.mockito.Matchers._ +import org.mockito.Matchers.{eq => meq} import org.mockito.Mockito._ -import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.mock.MockitoSugar +import org.scalatest.{BeforeAndAfter, Matchers} -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{Milliseconds, Seconds, StreamingContext, TestSuiteBase} -import org.apache.spark.util.{Clock, ManualClock, Utils} +import org.apache.spark.streaming.{Duration, TestSuiteBase} +import org.apache.spark.util.Utils /** * Suite of Kinesis streaming receiver tests focusing mostly on the KinesisRecordProcessor @@ -44,33 +44,22 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft val endpoint = "endpoint-url" val workerId = "dummyWorkerId" val shardId = "dummyShardId" + val seqNum = "dummySeqNum" + val checkpointInterval = Duration(10) + val someSeqNum = Some(seqNum) val record1 = new Record() - record1.setData(ByteBuffer.wrap("Spark In Action".getBytes())) + record1.setData(ByteBuffer.wrap("Spark In Action".getBytes(StandardCharsets.UTF_8))) val record2 = new Record() - record2.setData(ByteBuffer.wrap("Learning Spark".getBytes())) - val batch = List[Record](record1, record2) + record2.setData(ByteBuffer.wrap("Learning Spark".getBytes(StandardCharsets.UTF_8))) + val batch = Arrays.asList(record1, record2) - var receiverMock: KinesisReceiver = _ + var receiverMock: KinesisReceiver[Array[Byte]] = _ var checkpointerMock: IRecordProcessorCheckpointer = _ - var checkpointClockMock: ManualClock = _ - var checkpointStateMock: KinesisCheckpointState = _ - var currentClockMock: Clock = _ override def beforeFunction(): Unit = { - receiverMock = mock[KinesisReceiver] + receiverMock = mock[KinesisReceiver[Array[Byte]]] checkpointerMock = mock[IRecordProcessorCheckpointer] - checkpointClockMock = mock[ManualClock] - checkpointStateMock = mock[KinesisCheckpointState] - currentClockMock = mock[Clock] - } - - override def afterFunction(): Unit = { - super.afterFunction() - // Since this suite was originally written using EasyMock, add this to preserve the old - // mocking semantics (see SPARK-5735 for more details) - verifyNoMoreInteractions(receiverMock, checkpointerMock, checkpointClockMock, - checkpointStateMock, currentClockMock) } test("check serializability of SerializableAWSCredentials") { @@ -78,99 +67,67 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft Utils.serialize(new SerializableAWSCredentials("x", "y"))) } - test("process records including store and checkpoint") { + test("process records including store and set checkpointer") { when(receiverMock.isStopped()).thenReturn(false) - when(checkpointStateMock.shouldCheckpoint()).thenReturn(true) - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId) + recordProcessor.initialize(shardId) recordProcessor.processRecords(batch, checkpointerMock) verify(receiverMock, times(1)).isStopped() - verify(receiverMock, times(1)).store(record1.getData().array()) - verify(receiverMock, times(1)).store(record2.getData().array()) - verify(checkpointStateMock, times(1)).shouldCheckpoint() 
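The rewritten suites above pin down one key behaviour: a shard is checkpointed only when its latest stored sequence number has advanced, rather than on a per-processor clock. The sketch below illustrates that deduplication idea; the class and method names are stand-ins and not Spark's internal KinesisCheckpointer API.

import java.util.concurrent.ConcurrentHashMap

import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer

class DedupingCheckpointer(latestStoredSeqNum: String => Option[String]) {

  // shardId -> sequence number most recently checkpointed to DynamoDB
  private val lastCheckpointed = new ConcurrentHashMap[String, String]()

  /** Checkpoint a shard only if something new has been stored since the last checkpoint. */
  def checkpointIfNeeded(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = {
    latestStoredSeqNum(shardId).foreach { seqNum =>
      if (seqNum != lastCheckpointed.get(shardId)) {
        checkpointer.checkpoint(seqNum)
        lastCheckpointed.put(shardId, seqNum)
      }
    }
  }
}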
- verify(checkpointerMock, times(1)).checkpoint() - verify(checkpointStateMock, times(1)).advanceCheckpoint() + verify(receiverMock, times(1)).addRecords(shardId, batch) + verify(receiverMock, times(1)).setCheckpointer(shardId, checkpointerMock) } - test("shouldn't store and checkpoint when receiver is stopped") { + test("shouldn't store and update checkpointer when receiver is stopped") { when(receiverMock.isStopped()).thenReturn(true) - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId) recordProcessor.processRecords(batch, checkpointerMock) verify(receiverMock, times(1)).isStopped() + verify(receiverMock, never).addRecords(anyString, anyListOf(classOf[Record])) + verify(receiverMock, never).setCheckpointer(anyString, meq(checkpointerMock)) } - test("shouldn't checkpoint when exception occurs during store") { + test("shouldn't update checkpointer when exception occurs during store") { when(receiverMock.isStopped()).thenReturn(false) - when(receiverMock.store(record1.getData().array())).thenThrow(new RuntimeException()) + when( + receiverMock.addRecords(anyString, anyListOf(classOf[Record])) + ).thenThrow(new RuntimeException()) intercept[RuntimeException] { - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId) + recordProcessor.initialize(shardId) recordProcessor.processRecords(batch, checkpointerMock) } verify(receiverMock, times(1)).isStopped() - verify(receiverMock, times(1)).store(record1.getData().array()) - } - - test("should set checkpoint time to currentTime + checkpoint interval upon instantiation") { - when(currentClockMock.getTimeMillis()).thenReturn(0) - - val checkpointIntervalMillis = 10 - val checkpointState = - new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock) - assert(checkpointState.checkpointClock.getTimeMillis() == checkpointIntervalMillis) - - verify(currentClockMock, times(1)).getTimeMillis() - } - - test("should checkpoint if we have exceeded the checkpoint interval") { - when(currentClockMock.getTimeMillis()).thenReturn(0) - - val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MinValue), currentClockMock) - assert(checkpointState.shouldCheckpoint()) - - verify(currentClockMock, times(1)).getTimeMillis() - } - - test("shouldn't checkpoint if we have not exceeded the checkpoint interval") { - when(currentClockMock.getTimeMillis()).thenReturn(0) - - val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MaxValue), currentClockMock) - assert(!checkpointState.shouldCheckpoint()) - - verify(currentClockMock, times(1)).getTimeMillis() + verify(receiverMock, times(1)).addRecords(shardId, batch) + verify(receiverMock, never).setCheckpointer(anyString, meq(checkpointerMock)) } - test("should add to time when advancing checkpoint") { - when(currentClockMock.getTimeMillis()).thenReturn(0) + test("shutdown should checkpoint if the reason is TERMINATE") { + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) - val checkpointIntervalMillis = 10 - val checkpointState = - new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock) - assert(checkpointState.checkpointClock.getTimeMillis() == checkpointIntervalMillis) - checkpointState.advanceCheckpoint() - assert(checkpointState.checkpointClock.getTimeMillis() == (2 * 
checkpointIntervalMillis)) + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId) + recordProcessor.initialize(shardId) + recordProcessor.shutdown(checkpointerMock, ShutdownReason.TERMINATE) - verify(currentClockMock, times(1)).getTimeMillis() + verify(receiverMock, times(1)).removeCheckpointer(meq(shardId), meq(checkpointerMock)) } - test("shutdown should checkpoint if the reason is TERMINATE") { - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) - val reason = ShutdownReason.TERMINATE - recordProcessor.shutdown(checkpointerMock, reason) - - verify(checkpointerMock, times(1)).checkpoint() - } test("shutdown should not checkpoint if the reason is something other than TERMINATE") { - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) + + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId) + recordProcessor.initialize(shardId) recordProcessor.shutdown(checkpointerMock, ShutdownReason.ZOMBIE) recordProcessor.shutdown(checkpointerMock, null) - verify(checkpointerMock, never()).checkpoint() + verify(receiverMock, times(2)).removeCheckpointer(meq(shardId), + meq[IRecordProcessorCheckpointer](null)) } test("retry success on first attempt") { diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index b88c9c6478d5..6fe24fe81165 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -22,34 +22,70 @@ import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.Random +import com.amazonaws.regions.RegionUtils import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream +import com.amazonaws.services.kinesis.model.Record +import org.scalatest.Matchers._ import org.scalatest.concurrent.Eventually import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} -import org.apache.spark.storage.StorageLevel +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.{StorageLevel, StreamBlockId} import org.apache.spark.streaming._ -import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.streaming.dstream.ReceiverInputDStream +import org.apache.spark.streaming.kinesis.KinesisTestUtils._ +import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult +import org.apache.spark.streaming.scheduler.ReceivedBlockInfo +import org.apache.spark.util.Utils +import org.apache.spark.{SparkConf, SparkContext} -class KinesisStreamSuite extends KinesisFunSuite +abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFunSuite with Eventually with BeforeAndAfter with BeforeAndAfterAll { - // This is the name that KCL uses to save metadata to DynamoDB - private val kinesisAppName = s"KinesisStreamSuite-${math.abs(Random.nextLong())}" + // This is the name that KCL will use to save metadata to DynamoDB + private val appName = s"KinesisStreamSuite-${math.abs(Random.nextLong())}" + private val batchDuration = Seconds(1) - private var ssc: StreamingContext = _ - private var sc: SparkContext = _ + // Dummy parameters for API testing + private val dummyEndpointUrl = 
defaultEndpointUrl + private val dummyRegionName = RegionUtils.getRegionByEndpoint(dummyEndpointUrl).getName() + private val dummyAWSAccessKey = "dummyAccessKey" + private val dummyAWSSecretKey = "dummySecretKey" + + private var testUtils: KinesisTestUtils = null + private var ssc: StreamingContext = null + private var sc: SparkContext = null override def beforeAll(): Unit = { val conf = new SparkConf() .setMaster("local[4]") .setAppName("KinesisStreamSuite") // Setting Spark app name to Kinesis app name sc = new SparkContext(conf) + + runIfTestsEnabled("Prepare KinesisTestUtils") { + testUtils = new KPLBasedKinesisTestUtils() + testUtils.createStream() + } } override def afterAll(): Unit = { - sc.stop() - // Delete the Kinesis stream as well as the DynamoDB table generated by - // Kinesis Client Library when consuming the stream + if (ssc != null) { + ssc.stop() + } + if (sc != null) { + sc.stop() + } + if (testUtils != null) { + // Delete the Kinesis stream as well as the DynamoDB table generated by + // Kinesis Client Library when consuming the stream + testUtils.deleteStream() + testUtils.deleteDynamoDBTable(appName) + } + } + + before { + ssc = new StreamingContext(sc, batchDuration) } after { @@ -57,21 +93,75 @@ class KinesisStreamSuite extends KinesisFunSuite ssc.stop(stopSparkContext = false) ssc = null } + if (testUtils != null) { + testUtils.deleteDynamoDBTable(appName) + } } test("KinesisUtils API") { - ssc = new StreamingContext(sc, Seconds(1)) // Tests the API, does not actually test data receiving val kinesisStream1 = KinesisUtils.createStream(ssc, "mySparkStream", - "https://kinesis.us-west-2.amazonaws.com", Seconds(2), + dummyEndpointUrl, Seconds(2), InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2) val kinesisStream2 = KinesisUtils.createStream(ssc, "myAppNam", "mySparkStream", - "https://kinesis.us-west-2.amazonaws.com", "us-west-2", + dummyEndpointUrl, dummyRegionName, InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2) val kinesisStream3 = KinesisUtils.createStream(ssc, "myAppNam", "mySparkStream", - "https://kinesis.us-west-2.amazonaws.com", "us-west-2", + dummyEndpointUrl, dummyRegionName, InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2, - "awsAccessKey", "awsSecretKey") + dummyAWSAccessKey, dummyAWSSecretKey) + } + + test("RDD generation") { + val inputStream = KinesisUtils.createStream(ssc, appName, "dummyStream", + dummyEndpointUrl, dummyRegionName, InitialPositionInStream.LATEST, Seconds(2), + StorageLevel.MEMORY_AND_DISK_2, dummyAWSAccessKey, dummyAWSSecretKey) + assert(inputStream.isInstanceOf[KinesisInputDStream[Array[Byte]]]) + + val kinesisStream = inputStream.asInstanceOf[KinesisInputDStream[Array[Byte]]] + val time = Time(1000) + + // Generate block info data for testing + val seqNumRanges1 = SequenceNumberRanges( + SequenceNumberRange("fakeStream", "fakeShardId", "xxx", "yyy")) + val blockId1 = StreamBlockId(kinesisStream.id, 123) + val blockInfo1 = ReceivedBlockInfo( + 0, None, Some(seqNumRanges1), new BlockManagerBasedStoreResult(blockId1, None)) + + val seqNumRanges2 = SequenceNumberRanges( + SequenceNumberRange("fakeStream", "fakeShardId", "aaa", "bbb")) + val blockId2 = StreamBlockId(kinesisStream.id, 345) + val blockInfo2 = ReceivedBlockInfo( + 0, None, Some(seqNumRanges2), new BlockManagerBasedStoreResult(blockId2, None)) + + // Verify that the generated KinesisBackedBlockRDD has the all the right information + val blockInfos = Seq(blockInfo1, blockInfo2) + val nonEmptyRDD = 
kinesisStream.createBlockRDD(time, blockInfos) + nonEmptyRDD shouldBe a [KinesisBackedBlockRDD[Array[Byte]]] + val kinesisRDD = nonEmptyRDD.asInstanceOf[KinesisBackedBlockRDD[Array[Byte]]] + assert(kinesisRDD.regionName === dummyRegionName) + assert(kinesisRDD.endpointUrl === dummyEndpointUrl) + assert(kinesisRDD.retryTimeoutMs === batchDuration.milliseconds) + assert(kinesisRDD.awsCredentialsOption === + Some(SerializableAWSCredentials(dummyAWSAccessKey, dummyAWSSecretKey))) + assert(nonEmptyRDD.partitions.size === blockInfos.size) + nonEmptyRDD.partitions.foreach { _ shouldBe a [KinesisBackedBlockRDDPartition] } + val partitions = nonEmptyRDD.partitions.map { + _.asInstanceOf[KinesisBackedBlockRDDPartition] }.toSeq + assert(partitions.map { _.seqNumberRanges } === Seq(seqNumRanges1, seqNumRanges2)) + assert(partitions.map { _.blockId } === Seq(blockId1, blockId2)) + assert(partitions.forall { _.isBlockIdValid === true }) + + // Verify that KinesisBackedBlockRDD is generated even when there are no blocks + val emptyRDD = kinesisStream.createBlockRDD(time, Seq.empty) + emptyRDD shouldBe a [KinesisBackedBlockRDD[Array[Byte]]] + emptyRDD.partitions shouldBe empty + + // Verify that the KinesisBackedBlockRDD has isBlockValid = false when blocks are invalid + blockInfos.foreach { _.setBlockIdInvalid() } + kinesisStream.createBlockRDD(time, blockInfos).partitions.foreach { partition => + assert(partition.asInstanceOf[KinesisBackedBlockRDDPartition].isBlockIdValid === false) + } } @@ -84,32 +174,120 @@ class KinesisStreamSuite extends KinesisFunSuite * and you have to set the system environment variable RUN_KINESIS_TESTS=1 . */ testIfEnabled("basic operation") { - val kinesisTestUtils = new KinesisTestUtils() - try { - kinesisTestUtils.createStream() - ssc = new StreamingContext(sc, Seconds(1)) - val awsCredentials = KinesisTestUtils.getAWSCredentials() - val stream = KinesisUtils.createStream(ssc, kinesisAppName, kinesisTestUtils.streamName, - kinesisTestUtils.endpointUrl, kinesisTestUtils.regionName, InitialPositionInStream.LATEST, - Seconds(10), StorageLevel.MEMORY_ONLY, - awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey) - - val collected = new mutable.HashSet[Int] with mutable.SynchronizedSet[Int] - stream.map { bytes => new String(bytes).toInt }.foreachRDD { rdd => - collected ++= rdd.collect() - logInfo("Collected = " + rdd.collect().toSeq.mkString(", ")) - } - ssc.start() + val awsCredentials = KinesisTestUtils.getAWSCredentials() + val stream = KinesisUtils.createStream(ssc, appName, testUtils.streamName, + testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST, + Seconds(10), StorageLevel.MEMORY_ONLY, + awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey) + + val collected = new mutable.HashSet[Int] with mutable.SynchronizedSet[Int] + stream.map { bytes => new String(bytes).toInt }.foreachRDD { rdd => + collected ++= rdd.collect() + logInfo("Collected = " + collected.mkString(", ")) + } + ssc.start() - val testData = 1 to 10 - eventually(timeout(120 seconds), interval(10 second)) { - kinesisTestUtils.pushData(testData) - assert(collected === testData.toSet, "\nData received does not match data sent") + val testData = 1 to 10 + eventually(timeout(120 seconds), interval(10 second)) { + testUtils.pushData(testData, aggregateTestData) + assert(collected === testData.toSet, "\nData received does not match data sent") + } + ssc.stop(stopSparkContext = false) + } + + testIfEnabled("custom message handling") { + val awsCredentials = 
KinesisTestUtils.getAWSCredentials() + def addFive(r: Record): Int = JavaUtils.bytesToString(r.getData).toInt + 5 + val stream = KinesisUtils.createStream(ssc, appName, testUtils.streamName, + testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST, + Seconds(10), StorageLevel.MEMORY_ONLY, addFive, + awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey) + + stream shouldBe a [ReceiverInputDStream[Int]] + + val collected = new mutable.HashSet[Int] with mutable.SynchronizedSet[Int] + stream.foreachRDD { rdd => + collected ++= rdd.collect() + logInfo("Collected = " + collected.mkString(", ")) + } + ssc.start() + + val testData = 1 to 10 + eventually(timeout(120 seconds), interval(10 second)) { + testUtils.pushData(testData, aggregateTestData) + val modData = testData.map(_ + 5) + assert(collected === modData.toSet, "\nData received does not match data sent") + } + ssc.stop(stopSparkContext = false) + } + + testIfEnabled("failure recovery") { + val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName) + val checkpointDir = Utils.createTempDir().getAbsolutePath + + ssc = new StreamingContext(sc, Milliseconds(1000)) + ssc.checkpoint(checkpointDir) + + val awsCredentials = KinesisTestUtils.getAWSCredentials() + val collectedData = new mutable.HashMap[Time, (Array[SequenceNumberRanges], Seq[Int])] + with mutable.SynchronizedMap[Time, (Array[SequenceNumberRanges], Seq[Int])] + + val kinesisStream = KinesisUtils.createStream(ssc, appName, testUtils.streamName, + testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST, + Seconds(10), StorageLevel.MEMORY_ONLY, + awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey) + + // Verify that the generated RDDs are KinesisBackedBlockRDDs, and collect the data in each batch + kinesisStream.foreachRDD((rdd: RDD[Array[Byte]], time: Time) => { + val kRdd = rdd.asInstanceOf[KinesisBackedBlockRDD[Array[Byte]]] + val data = rdd.map { bytes => new String(bytes).toInt }.collect().toSeq + collectedData(time) = (kRdd.arrayOfseqNumberRanges, data) + }) + + ssc.remember(Minutes(60)) // remember all the batches so that they are all saved in checkpoint + ssc.start() + + def numBatchesWithData: Int = collectedData.count(_._2._2.nonEmpty) + + def isCheckpointPresent: Boolean = Checkpoint.getCheckpointFiles(checkpointDir).nonEmpty + + // Run until there are at least 10 batches with some data in them + // If this times out because numBatchesWithData is empty, then its likely that foreachRDD + // function failed with exceptions, and nothing got added to `collectedData` + eventually(timeout(2 minutes), interval(1 seconds)) { + testUtils.pushData(1 to 5, aggregateTestData) + assert(isCheckpointPresent && numBatchesWithData > 10) + } + ssc.stop(stopSparkContext = true) // stop the SparkContext so that the blocks are not reused + + // Restart the context from checkpoint and verify whether the + logInfo("Restarting from checkpoint") + ssc = new StreamingContext(checkpointDir) + ssc.start() + val recoveredKinesisStream = ssc.graph.getInputStreams().head + + // Verify that the recomputed RDDs are KinesisBackedBlockRDDs with the same sequence ranges + // and return the same data + val times = collectedData.keySet + times.foreach { time => + val (arrayOfSeqNumRanges, data) = collectedData(time) + val rdd = recoveredKinesisStream.getOrCompute(time).get.asInstanceOf[RDD[Array[Byte]]] + rdd shouldBe a [KinesisBackedBlockRDD[Array[Byte]]] + + // Verify the recovered sequence ranges + val kRdd = 
rdd.asInstanceOf[KinesisBackedBlockRDD[Array[Byte]]] + assert(kRdd.arrayOfseqNumberRanges.size === arrayOfSeqNumRanges.size) + arrayOfSeqNumRanges.zip(kRdd.arrayOfseqNumberRanges).foreach { case (expected, found) => + assert(expected.ranges.toSeq === found.ranges.toSeq) } - ssc.stop() - } finally { - kinesisTestUtils.deleteStream() - kinesisTestUtils.deleteDynamoDBTable(kinesisAppName) + + // Verify the recovered data + assert(rdd.map { bytes => new String(bytes).toInt }.collect().toSeq === data) } + ssc.stop() } } + +class WithAggregationKinesisStreamSuite extends KinesisStreamTests(aggregateTestData = true) + +class WithoutAggregationKinesisStreamSuite extends KinesisStreamTests(aggregateTestData = false) diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml index 478d0019a25f..87a4f05a0596 100644 --- a/extras/spark-ganglia-lgpl/pom.xml +++ b/extras/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index 853dea9a7795..8cd66c5b2e82 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml @@ -47,6 +47,10 @@ test-jar test + + org.apache.xbean + xbean-asm5-shaded + com.google.guava guava @@ -66,6 +70,10 @@ scalacheck_${scala.binary.version} test + + org.apache.spark + spark-test-tags_${scala.binary.version} + target/scala-${scala.binary.version}/classes diff --git a/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.java b/graphx/src/main/java/org/apache/spark/graphx/TripletFields.java similarity index 100% rename from graphx/src/main/scala/org/apache/spark/graphx/TripletFields.java rename to graphx/src/main/java/org/apache/spark/graphx/TripletFields.java diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeActiveness.java b/graphx/src/main/java/org/apache/spark/graphx/impl/EdgeActiveness.java similarity index 100% rename from graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeActiveness.java rename to graphx/src/main/java/org/apache/spark/graphx/impl/EdgeActiveness.java diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala index 4611a3ace219..ee7302a1edbf 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala @@ -38,8 +38,8 @@ import org.apache.spark.graphx.impl.EdgeRDDImpl * `impl.ReplicatedVertexView`. */ abstract class EdgeRDD[ED]( - @transient sc: SparkContext, - @transient deps: Seq[Dependency[_]]) extends RDD[Edge[ED]](sc, deps) { + sc: SparkContext, + deps: Seq[Dependency[_]]) extends RDD[Edge[ED]](sc, deps) { // scalastyle:off structural.type private[graphx] def partitionsRDD: RDD[(PartitionID, EdgePartition[ED, VD])] forSome { type VD } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala index db73a8abc573..869caa340f52 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala @@ -46,7 +46,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * @note vertex ids are unique. * @return an RDD containing the vertices in this graph */ - @transient val vertices: VertexRDD[VD] + val vertices: VertexRDD[VD] /** * An RDD containing the edges and their associated attributes. 
The entries in the RDD contain @@ -59,7 +59,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * along with their vertex data. * */ - @transient val edges: EdgeRDD[ED] + val edges: EdgeRDD[ED] /** * An RDD containing the edge triplets, which are edges along with the vertex data associated with @@ -77,7 +77,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * val numInvalid = graph.triplets.map(e => if (e.src.data == e.dst.data) 1 else 0).sum * }}} */ - @transient val triplets: RDD[EdgeTriplet[VD, ED]] + val triplets: RDD[EdgeTriplet[VD, ED]] /** * Caches the vertices and edges associated with this graph at the specified storage level, diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala index 9451ff1e5c0e..9827dfab8684 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala @@ -282,7 +282,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali * Convert bi-directional edges into uni-directional ones. * Some graph algorithms (e.g., TriangleCount) assume that an input graph * has its edges in canonical direction. - * This function rewrites the vertex ids of edges so that srcIds are bigger + * This function rewrites the vertex ids of edges so that srcIds are smaller * than dstIds, and merges the duplicated edges. * * @param mergeFunc the user defined reduce function which should diff --git a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala index a9f04b559c3d..1ef7a78fbcd0 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala @@ -55,8 +55,8 @@ import org.apache.spark.graphx.impl.VertexRDDImpl * @tparam VD the vertex attribute associated with each vertex in the set. */ abstract class VertexRDD[VD]( - @transient sc: SparkContext, - @transient deps: Seq[Dependency[_]]) extends RDD[(VertexId, VD)](sc, deps) { + sc: SparkContext, + deps: Seq[Dependency[_]]) extends RDD[(VertexId, VD)](sc, deps) { implicit protected def vdTag: ClassTag[VD] diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala index eb3c997e0f3c..4f1260a5a67b 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala @@ -34,7 +34,7 @@ object RoutingTablePartition { /** * A message from an edge partition to a vertex specifying the position in which the edge * partition references the vertex (src, dst, or both). The edge partition is encoded in the lower - * 30 bytes of the Int, and the position is encoded in the upper 2 bytes of the Int. + * 30 bits of the Int, and the position is encoded in the upper 2 bits of the Int. 
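(Aside: the packing described in the comment above can be shown with a short, self-contained sketch. The helpers below are hypothetical and are not GraphX's private routines; they only reproduce the layout, with the edge partition id in the lower 30 bits and the 2-bit position flag in the upper 2 bits.)

    // Hypothetical sketch of the RoutingTableMessage bit layout (flag values are illustrative).
    object RoutingMessageLayoutSketch {
      // Illustrative position flags: 0x1 = referenced as src, 0x2 = as dst, 0x3 = both.
      def pack(edgePartitionId: Int, position: Int): Int =
        (position << 30) | (edgePartitionId & 0x3FFFFFFF)

      def edgePartitionId(packed: Int): Int = packed & 0x3FFFFFFF
      def position(packed: Int): Int = (packed >>> 30) & 0x3

      def main(args: Array[String]): Unit = {
        val msg = pack(12345, 0x3)
        assert(edgePartitionId(msg) == 12345)  // lower 30 bits
        assert(position(msg) == 0x3)           // upper 2 bits
      }
    }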
*/ type RoutingTableMessage = (VertexId, Int) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala index 33ac7b0ed609..7f4e7e9d79d6 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala @@ -87,7 +87,7 @@ class VertexRDDImpl[VD] private[graphx] ( /** The number of vertices in the RDD. */ override def count(): Long = { - partitionsRDD.map(_.size).reduce(_ + _) + partitionsRDD.map(_.size.toLong).reduce(_ + _) } override private[graphx] def mapVertexPartitions[VD2: ClassTag]( diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala index 2bcf8684b8b8..a3ad6bed1c99 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala @@ -43,7 +43,7 @@ object LabelPropagation { */ def run[VD, ED: ClassTag](graph: Graph[VD, ED], maxSteps: Int): Graph[VertexId, ED] = { val lpaGraph = graph.mapVertices { case (vid, _) => vid } - def sendMessage(e: EdgeTriplet[VertexId, ED]): Iterator[(VertexId, Map[VertexId, VertexId])] = { + def sendMessage(e: EdgeTriplet[VertexId, ED]): Iterator[(VertexId, Map[VertexId, Long])] = { Iterator((e.srcId, Map(e.dstAttr -> 1L)), (e.dstId, Map(e.srcAttr -> 1L))) } def mergeMessage(count1: Map[VertexId, Long], count2: Map[VertexId, Long]) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index 8c0a461e99fa..52b237fc1509 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -104,18 +104,23 @@ object PageRank extends Logging { graph: Graph[VD, ED], numIter: Int, resetProb: Double = 0.15, srcId: Option[VertexId] = None): Graph[Double, Double] = { + val personalized = srcId isDefined + val src: VertexId = srcId.getOrElse(-1L) + // Initialize the PageRank graph with each edge attribute having - // weight 1/outDegree and each vertex with attribute 1.0. + // weight 1/outDegree and each vertex with attribute resetProb. + // When running personalized pagerank, only the source vertex + // has an attribute resetProb. All others are set to 0. var rankGraph: Graph[Double, Double] = graph // Associate the degree with each vertex .outerJoinVertices(graph.outDegrees) { (vid, vdata, deg) => deg.getOrElse(0) } // Set the weight on the edges based on the degree .mapTriplets( e => 1.0 / e.srcAttr, TripletFields.Src ) // Set the vertex attributes to the initial pagerank values - .mapVertices( (id, attr) => resetProb ) + .mapVertices { (id, attr) => + if (!(id != src && personalized)) resetProb else 0.0 + } - val personalized = srcId isDefined - val src: VertexId = srcId.getOrElse(-1L) def delta(u: VertexId, v: VertexId): Double = { if (u == v) 1.0 else 0.0 } var iteration = 0 @@ -192,6 +197,9 @@ object PageRank extends Logging { graph: Graph[VD, ED], tol: Double, resetProb: Double = 0.15, srcId: Option[VertexId] = None): Graph[Double, Double] = { + val personalized = srcId.isDefined + val src: VertexId = srcId.getOrElse(-1L) + // Initialize the pagerankGraph with each edge attribute // having weight 1/outDegree and each vertex with attribute 1.0. 
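(Aside: a minimal usage sketch for the two personalized PageRank entry points touched in this file. The toy graph, tolerance, and iteration count are made up for illustration, and a local SparkContext is created only for the sketch.)

    // Minimal sketch: calling dynamic and static personalized PageRank on a toy graph.
    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.graphx.{Edge, Graph}

    object PersonalizedPageRankSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("ppr-sketch"))
        val edges = sc.parallelize(Seq(Edge(1L, 2L, 1), Edge(2L, 3L, 1), Edge(3L, 1L, 1)))
        val graph = Graph.fromEdges(edges, 1)

        // Tolerance-based personalized PageRank from source vertex 1.
        val dynamicRanks = graph.personalizedPageRank(1L, 0.001, 0.15).vertices.collect()
        // Fixed-iteration personalized PageRank from the same source.
        val staticRanks = graph.staticPersonalizedPageRank(1L, 10, 0.15).vertices.collect()

        (dynamicRanks ++ staticRanks).foreach { case (id, rank) => println(s"$id -> $rank") }
        sc.stop()
      }
    }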
val pagerankGraph: Graph[(Double, Double), Double] = graph @@ -202,13 +210,11 @@ object PageRank extends Logging { // Set the weight on the edges based on the degree .mapTriplets( e => 1.0 / e.srcAttr ) // Set the vertex attributes to (initalPR, delta = 0) - .mapVertices( (id, attr) => (0.0, 0.0) ) + .mapVertices { (id, attr) => + if (id == src) (resetProb, Double.NegativeInfinity) else (0.0, 0.0) + } .cache() - val personalized = srcId.isDefined - val src: VertexId = srcId.getOrElse(-1L) - - // Define the three functions needed to implement PageRank in the GraphX // version of Pregel def vertexProgram(id: VertexId, attr: (Double, Double), msgSum: Double): (Double, Double) = { @@ -225,7 +231,8 @@ object PageRank extends Logging { teleport = oldPR*delta val newPR = teleport + (1.0 - resetProb) * msgSum - (newPR, newPR - oldPR) + val newDelta = if (lastDelta == Double.NegativeInfinity) newPR else newPR - oldPR + (newPR, newDelta) } def sendMessage(edge: EdgeTriplet[(Double, Double), Double]) = { @@ -239,7 +246,7 @@ object PageRank extends Logging { def messageCombiner(a: Double, b: Double): Double = a + b // The initial message received by all vertices in PageRank - val initialMessage = resetProb / (1.0 - resetProb) + val initialMessage = if (personalized) 0.0 else resetProb / (1.0 - resetProb) // Execute a dynamic version of Pregel. val vp = if (personalized) { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala index 74a7de18d416..a6d0cb640966 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala @@ -22,11 +22,10 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import scala.collection.mutable.HashSet import scala.language.existentials -import org.apache.spark.util.Utils - -import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor} -import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ +import org.apache.xbean.asm5.{ClassReader, ClassVisitor, MethodVisitor} +import org.apache.xbean.asm5.Opcodes._ +import org.apache.spark.util.Utils /** * Includes an utility function to test whether a function accesses a specific attribute @@ -107,18 +106,19 @@ private[graphx] object BytecodeUtils { * MethodInvocationFinder("spark/graph/Foo", "test") * its methodsInvoked variable will contain the set of methods invoked directly by * Foo.test(). Interface invocations are not returned as part of the result set because we cannot - * determine the actual metod invoked by inspecting the bytecode. + * determine the actual method invoked by inspecting the bytecode. 
*/ private class MethodInvocationFinder(className: String, methodName: String) - extends ClassVisitor(ASM4) { + extends ClassVisitor(ASM5) { val methodsInvoked = new HashSet[(Class[_], String)] override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { if (name == methodName) { - new MethodVisitor(ASM4) { - override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) { + new MethodVisitor(ASM5) { + override def visitMethodInsn( + op: Int, owner: String, name: String, desc: String, itf: Boolean) { if (op == INVOKEVIRTUAL || op == INVOKESPECIAL || op == INVOKESTATIC) { if (!skipClass(owner)) { methodsInvoked.add((Utils.classForName(owner.replace("/", ".")), name)) diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala index 45f1e3011035..bdff31446f8e 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala @@ -109,17 +109,22 @@ class PageRankSuite extends SparkFunSuite with LocalSparkContext { assert(notMatching === 0) val staticErrors = staticRanks2.map { case (vid, pr) => - val correct = (vid > 0 && pr == resetProb) || - (vid == 0 && math.abs(pr - (resetProb + (1.0 - resetProb) * (resetProb * - (nVertices - 1)) )) < 1.0E-5) + val correct = (vid > 0 && pr == 0.0) || + (vid == 0 && pr == resetProb) if (!correct) 1 else 0 } assert(staticErrors.sum === 0) val dynamicRanks = starGraph.personalizedPageRank(0, 0, resetProb).vertices.cache() assert(compareRanks(staticRanks2, dynamicRanks) < errorTol) + + // We have one outbound edge from 1 to 0 + val otherStaticRanks2 = starGraph.staticPersonalizedPageRank(1, numIter = 2, resetProb) + .vertices.cache() + val otherDynamicRanks = starGraph.personalizedPageRank(1, 0, resetProb).vertices.cache() + assert(compareRanks(otherDynamicRanks, otherStaticRanks2) < errorTol) } - } // end of test Star PageRank + } // end of test Star PersonalPageRank test("Grid PageRank") { withSpark { sc => diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala index c47552cf3a3b..608e43cf3ff5 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala @@ -26,7 +26,7 @@ class TriangleCountSuite extends SparkFunSuite with LocalSparkContext { test("Count a single triangle") { withSpark { sc => - val rawEdges = sc.parallelize(Array( 0L->1L, 1L->2L, 2L->0L ), 2) + val rawEdges = sc.parallelize(Array( 0L -> 1L, 1L -> 2L, 2L -> 0L ), 2) val graph = Graph.fromEdgeTuples(rawEdges, true).cache() val triangleCount = graph.triangleCount() val verts = triangleCount.vertices diff --git a/launcher/pom.xml b/launcher/pom.xml index 2fd768d8119c..5739bfc16958 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml @@ -43,13 +43,13 @@ test - junit - junit + org.mockito + mockito-core test - org.mockito - mockito-core + org.slf4j + jul-to-slf4j test @@ -63,6 +63,11 @@ test + + org.apache.spark + spark-test-tags_${scala.binary.version} + + org.apache.hadoop diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java 
b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index 5e793a5c4877..55fe156cf665 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -47,7 +47,7 @@ abstract class AbstractCommandBuilder { String javaHome; String mainClass; String master; - String propertiesFile; + protected String propertiesFile; final List appArgs; final List jars; final List files; @@ -55,6 +55,10 @@ abstract class AbstractCommandBuilder { final Map childEnv; final Map conf; + // The merged configuration for the application. Cached to avoid having to read / parse + // properties files multiple times. + private Map effectiveConfig; + public AbstractCommandBuilder() { this.appArgs = new ArrayList(); this.childEnv = new HashMap(); @@ -116,29 +120,6 @@ List buildJavaCommand(String extraClassPath) throws IOException { return cmd; } - /** - * Adds the default perm gen size option for Spark if the VM requires it and the user hasn't - * set it. - */ - void addPermGenSizeOpt(List cmd) { - // Don't set MaxPermSize for IBM Java, or Oracle Java 8 and later. - if (getJavaVendor() == JavaVendor.IBM) { - return; - } - String[] version = System.getProperty("java.version").split("\\."); - if (Integer.parseInt(version[0]) > 1 || Integer.parseInt(version[1]) > 7) { - return; - } - - for (String arg : cmd) { - if (arg.startsWith("-XX:MaxPermSize=")) { - return; - } - } - - cmd.add("-XX:MaxPermSize=256m"); - } - void addOptionString(List cmd, String options) { if (!isEmpty(options)) { for (String opt : parseOptionString(options)) { @@ -167,11 +148,13 @@ List buildClassPath(String appClassPath) throws IOException { String scala = getScalaVersion(); List projects = Arrays.asList("core", "repl", "mllib", "bagel", "graphx", "streaming", "tools", "sql/catalyst", "sql/core", "sql/hive", "sql/hive-thriftserver", - "yarn", "launcher"); + "yarn", "launcher", "network/common", "network/shuffle", "network/yarn"); if (prependClasses) { - System.err.println( - "NOTE: SPARK_PREPEND_CLASSES is set, placing locally compiled Spark classes ahead of " + - "assembly."); + if (!isTesting) { + System.err.println( + "NOTE: SPARK_PREPEND_CLASSES is set, placing locally compiled Spark classes ahead of " + + "assembly."); + } for (String project : projects) { addToClassPath(cp, String.format("%s/%s/target/scala-%s/classes", sparkHome, project, scala)); @@ -200,7 +183,7 @@ List buildClassPath(String appClassPath) throws IOException { // For the user code case, we fall back to looking for the Spark assembly under SPARK_HOME. // That duplicates some of the code in the shell scripts that look for the assembly, though. 
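(Aside: the getEffectiveConfig() helper added further down in this file merges programmatic configuration with the properties file, and programmatic entries win on duplicate keys. A rough sketch of that precedence rule, with illustrative names that are not the launcher's own:)

    // Sketch of the precedence implemented by getEffectiveConfig(): programmatic keys
    // shadow keys loaded from spark-defaults.conf or the user-supplied properties file.
    object EffectiveConfigSketch {
      def effectiveConfig(programmatic: Map[String, String],
                          fromPropertiesFile: Map[String, String]): Map[String, String] =
        fromPropertiesFile ++ programmatic  // right-hand operand wins on duplicate keys

      def main(args: Array[String]): Unit = {
        val merged = effectiveConfig(
          Map("spark.driver.memory" -> "2g"),
          Map("spark.driver.memory" -> "1g", "spark.master" -> "local[*]"))
        assert(merged("spark.driver.memory") == "2g")
        assert(merged("spark.master") == "local[*]")
      }
    }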
String assembly = getenv(ENV_SPARK_ASSEMBLY); - if (assembly == null && isEmpty(getenv("SPARK_TESTING"))) { + if (assembly == null && !isTesting) { assembly = findAssembly(); } addToClassPath(cp, assembly); @@ -215,12 +198,14 @@ List buildClassPath(String appClassPath) throws IOException { libdir = new File(sparkHome, "lib_managed/jars"); } - checkState(libdir.isDirectory(), "Library directory '%s' does not exist.", - libdir.getAbsolutePath()); - for (File jar : libdir.listFiles()) { - if (jar.getName().startsWith("datanucleus-")) { - addToClassPath(cp, jar.getAbsolutePath()); + if (libdir.isDirectory()) { + for (File jar : libdir.listFiles()) { + if (jar.getName().startsWith("datanucleus-")) { + addToClassPath(cp, jar.getAbsolutePath()); + } } + } else { + checkState(isTesting, "Library directory '%s' does not exist.", libdir.getAbsolutePath()); } addToClassPath(cp, getenv("HADOOP_CONF_DIR")); @@ -256,15 +241,15 @@ String getScalaVersion() { return scala; } String sparkHome = getSparkHome(); - File scala210 = new File(sparkHome, "assembly/target/scala-2.10"); - File scala211 = new File(sparkHome, "assembly/target/scala-2.11"); + File scala210 = new File(sparkHome, "launcher/target/scala-2.10"); + File scala211 = new File(sparkHome, "launcher/target/scala-2.11"); checkState(!scala210.isDirectory() || !scala211.isDirectory(), "Presence of build for both scala versions (2.10 and 2.11) detected.\n" + "Either clean one of them or set SPARK_SCALA_VERSION in your environment."); if (scala210.isDirectory()) { return "2.10"; } else { - checkState(scala211.isDirectory(), "Cannot find any assembly build directories."); + checkState(scala211.isDirectory(), "Cannot find any build directories."); return "2.11"; } } @@ -276,12 +261,34 @@ String getSparkHome() { return path; } + String getenv(String key) { + return firstNonEmpty(childEnv.get(key), System.getenv(key)); + } + + void setPropertiesFile(String path) { + effectiveConfig = null; + this.propertiesFile = path; + } + + Map getEffectiveConfig() throws IOException { + if (effectiveConfig == null) { + effectiveConfig = new HashMap<>(conf); + Properties p = loadPropertiesFile(); + for (String key : p.stringPropertyNames()) { + if (!effectiveConfig.containsKey(key)) { + effectiveConfig.put(key, p.getProperty(key)); + } + } + } + return effectiveConfig; + } + /** * Loads the configuration file for the application, if it exists. This is either the * user-specified properties file, or the spark-defaults.conf file under the Spark configuration * directory. */ - Properties loadPropertiesFile() throws IOException { + private Properties loadPropertiesFile() throws IOException { Properties props = new Properties(); File propsFile; if (propertiesFile != null) { @@ -313,10 +320,6 @@ Properties loadPropertiesFile() throws IOException { return props; } - String getenv(String key) { - return firstNonEmpty(childEnv.get(key), System.getenv(key)); - } - private String findAssembly() { String sparkHome = getSparkHome(); File libdir; diff --git a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java new file mode 100644 index 000000000000..1bfda289dec3 --- /dev/null +++ b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.io.IOException; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ThreadFactory; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * Handle implementation for monitoring apps started as a child process. + */ +class ChildProcAppHandle implements SparkAppHandle { + + private static final Logger LOG = Logger.getLogger(ChildProcAppHandle.class.getName()); + private static final ThreadFactory REDIRECTOR_FACTORY = + new NamedThreadFactory("launcher-proc-%d"); + + private final String secret; + private final LauncherServer server; + + private Process childProc; + private boolean disposed; + private LauncherConnection connection; + private List listeners; + private State state; + private String appId; + private OutputRedirector redirector; + + ChildProcAppHandle(String secret, LauncherServer server) { + this.secret = secret; + this.server = server; + this.state = State.UNKNOWN; + } + + @Override + public synchronized void addListener(Listener l) { + if (listeners == null) { + listeners = new ArrayList<>(); + } + listeners.add(l); + } + + @Override + public State getState() { + return state; + } + + @Override + public String getAppId() { + return appId; + } + + @Override + public void stop() { + CommandBuilderUtils.checkState(connection != null, "Application is still not connected."); + try { + connection.send(new LauncherProtocol.Stop()); + } catch (IOException ioe) { + throw new RuntimeException(ioe); + } + } + + @Override + public synchronized void disconnect() { + if (!disposed) { + disposed = true; + if (connection != null) { + try { + connection.close(); + } catch (IOException ioe) { + // no-op. + } + } + server.unregister(this); + if (redirector != null) { + redirector.stop(); + } + } + } + + @Override + public synchronized void kill() { + if (!disposed) { + disconnect(); + } + if (childProc != null) { + try { + childProc.exitValue(); + } catch (IllegalThreadStateException e) { + // Child is still alive. Try to use Java 8's "destroyForcibly()" if available, + // fall back to the old API if it's not there. 
+ try { + Method destroy = childProc.getClass().getMethod("destroyForcibly"); + destroy.invoke(childProc); + } catch (Exception inner) { + childProc.destroy(); + } + } finally { + childProc = null; + } + } + } + + String getSecret() { + return secret; + } + + void setChildProc(Process childProc, String loggerName) { + this.childProc = childProc; + this.redirector = new OutputRedirector(childProc.getInputStream(), loggerName, + REDIRECTOR_FACTORY); + } + + void setConnection(LauncherConnection connection) { + this.connection = connection; + } + + LauncherServer getServer() { + return server; + } + + LauncherConnection getConnection() { + return connection; + } + + void setState(State s) { + if (!state.isFinal()) { + state = s; + fireEvent(false); + } else { + LOG.log(Level.WARNING, "Backend requested transition from final state {0} to {1}.", + new Object[] { state, s }); + } + } + + void setAppId(String appId) { + this.appId = appId; + fireEvent(true); + } + + private synchronized void fireEvent(boolean isInfoChanged) { + if (listeners != null) { + for (Listener l : listeners) { + if (isInfoChanged) { + l.infoChanged(this); + } else { + l.stateChanged(this); + } + } + } + } + +} diff --git a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java index a16c0d2b5ca0..d30c2ec5f87b 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java +++ b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java @@ -313,4 +313,27 @@ static String quoteForCommandString(String s) { return quoted.append('"').toString(); } + /** + * Adds the default perm gen size option for Spark if the VM requires it and the user hasn't + * set it. + */ + static void addPermGenSizeOpt(List cmd) { + // Don't set MaxPermSize for IBM Java, or Oracle Java 8 and later. + if (getJavaVendor() == JavaVendor.IBM) { + return; + } + String[] version = System.getProperty("java.version").split("\\."); + if (Integer.parseInt(version[0]) > 1 || Integer.parseInt(version[1]) > 7) { + return; + } + + for (String arg : cmd) { + if (arg.startsWith("-XX:MaxPermSize=")) { + return; + } + } + + cmd.add("-XX:MaxPermSize=256m"); + } + } diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java new file mode 100644 index 000000000000..eec264909bbb --- /dev/null +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
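(Aside: ChildProcAppHandle above backs the new SparkAppHandle interface. A sketch of how a caller is expected to consume that handle, assuming the SparkLauncher#startApplication entry point that accompanies this API but is not shown in these hunks; the resource path and class names are placeholders.)

    // Sketch of consuming the handle API; startApplication and the listener callbacks
    // are assumed to exist alongside the classes in this patch.
    import org.apache.spark.launcher.{SparkAppHandle, SparkLauncher}

    object LauncherHandleSketch {
      def main(args: Array[String]): Unit = {
        val handle = new SparkLauncher()
          .setAppResource("/path/to/examples.jar")  // placeholder
          .setMainClass("com.example.MyApp")        // placeholder
          .setMaster("local[2]")
          .startApplication(new SparkAppHandle.Listener {
            override def stateChanged(h: SparkAppHandle): Unit =
              println(s"state changed: ${h.getState}")
            override def infoChanged(h: SparkAppHandle): Unit =
              println(s"app id: ${h.getAppId}")
          })

        // Poll until the application reaches a final state (State.isFinal, as defined above).
        while (!handle.getState.isFinal) {
          Thread.sleep(1000)
        }
      }
    }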
+ */ + +package org.apache.spark.launcher; + +import java.io.Closeable; +import java.io.EOFException; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.net.Socket; +import java.util.logging.Level; +import java.util.logging.Logger; + +import static org.apache.spark.launcher.LauncherProtocol.*; + +/** + * Encapsulates a connection between a launcher server and client. This takes care of the + * communication (sending and receiving messages), while processing of messages is left for + * the implementations. + */ +abstract class LauncherConnection implements Closeable, Runnable { + + private static final Logger LOG = Logger.getLogger(LauncherConnection.class.getName()); + + private final Socket socket; + private final ObjectOutputStream out; + + private volatile boolean closed; + + LauncherConnection(Socket socket) throws IOException { + this.socket = socket; + this.out = new ObjectOutputStream(socket.getOutputStream()); + this.closed = false; + } + + protected abstract void handle(Message msg) throws IOException; + + @Override + public void run() { + try { + ObjectInputStream in = new ObjectInputStream(socket.getInputStream()); + while (!closed) { + Message msg = (Message) in.readObject(); + handle(msg); + } + } catch (EOFException eof) { + // Remote side has closed the connection, just cleanup. + try { + close(); + } catch (Exception unused) { + // no-op. + } + } catch (Exception e) { + if (!closed) { + LOG.log(Level.WARNING, "Error in inbound message handling.", e); + try { + close(); + } catch (Exception unused) { + // no-op. + } + } + } + } + + protected synchronized void send(Message msg) throws IOException { + try { + CommandBuilderUtils.checkState(!closed, "Disconnected."); + out.writeObject(msg); + out.flush(); + } catch (IOException ioe) { + if (!closed) { + LOG.log(Level.WARNING, "Error when sending message.", ioe); + try { + close(); + } catch (Exception unused) { + // no-op. + } + } + throw ioe; + } + } + + @Override + public void close() throws IOException { + if (!closed) { + synchronized (this) { + if (!closed) { + closed = true; + socket.close(); + } + } + } + } + +} diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherProtocol.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherProtocol.java new file mode 100644 index 000000000000..50f136497ec1 --- /dev/null +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherProtocol.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.launcher; + +import java.io.Closeable; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.Serializable; +import java.net.Socket; +import java.util.Map; + +/** + * Message definitions for the launcher communication protocol. These messages must remain + * backwards-compatible, so that the launcher can talk to older versions of Spark that support + * the protocol. + */ +final class LauncherProtocol { + + /** Environment variable where the server port is stored. */ + static final String ENV_LAUNCHER_PORT = "_SPARK_LAUNCHER_PORT"; + + /** Environment variable where the secret for connecting back to the server is stored. */ + static final String ENV_LAUNCHER_SECRET = "_SPARK_LAUNCHER_SECRET"; + + static class Message implements Serializable { + + } + + /** + * Hello message, sent from client to server. + */ + static class Hello extends Message { + + final String secret; + final String sparkVersion; + + Hello(String secret, String version) { + this.secret = secret; + this.sparkVersion = version; + } + + } + + /** + * SetAppId message, sent from client to server. + */ + static class SetAppId extends Message { + + final String appId; + + SetAppId(String appId) { + this.appId = appId; + } + + } + + /** + * SetState message, sent from client to server. + */ + static class SetState extends Message { + + final SparkAppHandle.State state; + + SetState(SparkAppHandle.State state) { + this.state = state; + } + + } + + /** + * Stop message, send from server to client to stop the application. + */ + static class Stop extends Message { + + } + +} diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java new file mode 100644 index 000000000000..d099ee9aa9da --- /dev/null +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java @@ -0,0 +1,348 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.io.Closeable; +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.ServerSocket; +import java.net.Socket; +import java.security.SecureRandom; +import java.util.ArrayList; +import java.util.List; +import java.util.Timer; +import java.util.TimerTask; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.atomic.AtomicLong; +import java.util.logging.Level; +import java.util.logging.Logger; + +import static org.apache.spark.launcher.LauncherProtocol.*; + +/** + * A server that listens locally for connections from client launched by the library. 
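(Aside: the message classes above define the wire format; the client side of the handshake is implemented elsewhere in Spark by LauncherBackend and is not part of this diff. A sketch of that flow, using only the environment variable names and message constructors defined above; because those classes are package-private, the sketch pretends to live in the same package, and the app id it reports is a placeholder.)

    // Sketch of the client-side handshake implied by LauncherProtocol; the real
    // implementation is Spark's LauncherBackend, not this code.
    package org.apache.spark.launcher

    import java.io.ObjectOutputStream
    import java.net.{InetAddress, Socket}

    object LauncherClientSketch {
      def main(args: Array[String]): Unit = {
        // The launcher server injects these into the child process's environment.
        val port = sys.env(LauncherProtocol.ENV_LAUNCHER_PORT).toInt
        val secret = sys.env(LauncherProtocol.ENV_LAUNCHER_SECRET)

        val socket = new Socket(InetAddress.getLoopbackAddress, port)
        val out = new ObjectOutputStream(socket.getOutputStream)
        out.writeObject(new LauncherProtocol.Hello(secret, "1.6.0-SNAPSHOT"))  // authenticate first
        out.writeObject(new LauncherProtocol.SetAppId("local-1446840000000"))  // placeholder app id
        out.flush()
        socket.close()
      }
    }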
Each client + * has a secret that it needs to send to the server to identify itself and establish the session. + * + * I/O is currently blocking (one thread per client). Clients have a limited time to connect back + * to the server, otherwise the server will ignore the connection. + * + * === Architecture Overview === + * + * The launcher server is used when Spark apps are launched as separate processes than the calling + * app. It looks more or less like the following: + * + * ----------------------- ----------------------- + * | User App | spark-submit | Spark App | + * | | -------------------> | | + * | ------------| |------------- | + * | | | hello | | | + * | | L. Server |<----------------------| L. Backend | | + * | | | | | | + * | ------------- ----------------------- + * | | | ^ + * | v | | + * | -------------| | + * | | | | + * | | App Handle |<------------------------------ + * | | | + * ----------------------- + * + * The server is started on demand and remains active while there are active or outstanding clients, + * to avoid opening too many ports when multiple clients are launched. Each client is given a unique + * secret, and have a limited amount of time to connect back + * ({@link SparkLauncher#CHILD_CONNECTION_TIMEOUT}), at which point the server will throw away + * that client's state. A client is only allowed to connect back to the server once. + * + * The launcher server listens on the localhost only, so it doesn't need access controls (aside from + * the per-app secret) nor encryption. It thus requires that the launched app has a local process + * that communicates with the server. In cluster mode, this means that the client that launches the + * application must remain alive for the duration of the application (or until the app handle is + * disconnected). + */ +class LauncherServer implements Closeable { + + private static final Logger LOG = Logger.getLogger(LauncherServer.class.getName()); + private static final String THREAD_NAME_FMT = "LauncherServer-%d"; + private static final long DEFAULT_CONNECT_TIMEOUT = 10000L; + + /** For creating secrets used for communication with child processes. */ + private static final SecureRandom RND = new SecureRandom(); + + private static volatile LauncherServer serverInstance; + + /** + * Creates a handle for an app to be launched. This method will start a server if one hasn't been + * started yet. The server is shared for multiple handles, and once all handles are disposed of, + * the server is shut down. + */ + static synchronized ChildProcAppHandle newAppHandle() throws IOException { + LauncherServer server = serverInstance != null ? 
serverInstance : new LauncherServer(); + server.ref(); + serverInstance = server; + + String secret = server.createSecret(); + while (server.pending.containsKey(secret)) { + secret = server.createSecret(); + } + + return server.newAppHandle(secret); + } + + static LauncherServer getServerInstance() { + return serverInstance; + } + + private final AtomicLong refCount; + private final AtomicLong threadIds; + private final ConcurrentMap pending; + private final List clients; + private final ServerSocket server; + private final Thread serverThread; + private final ThreadFactory factory; + private final Timer timeoutTimer; + + private volatile boolean running; + + private LauncherServer() throws IOException { + this.refCount = new AtomicLong(0); + + ServerSocket server = new ServerSocket(); + try { + server.setReuseAddress(true); + server.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0)); + + this.clients = new ArrayList(); + this.threadIds = new AtomicLong(); + this.factory = new NamedThreadFactory(THREAD_NAME_FMT); + this.pending = new ConcurrentHashMap<>(); + this.timeoutTimer = new Timer("LauncherServer-TimeoutTimer", true); + this.server = server; + this.running = true; + + this.serverThread = factory.newThread(new Runnable() { + @Override + public void run() { + acceptConnections(); + } + }); + serverThread.start(); + } catch (IOException ioe) { + close(); + throw ioe; + } catch (Exception e) { + close(); + throw new IOException(e); + } + } + + /** + * Creates a new app handle. The handle will wait for an incoming connection for a configurable + * amount of time, and if one doesn't arrive, it will transition to an error state. + */ + ChildProcAppHandle newAppHandle(String secret) { + ChildProcAppHandle handle = new ChildProcAppHandle(secret, this); + ChildProcAppHandle existing = pending.putIfAbsent(secret, handle); + CommandBuilderUtils.checkState(existing == null, "Multiple handles with the same secret."); + return handle; + } + + @Override + public void close() throws IOException { + synchronized (this) { + if (running) { + running = false; + timeoutTimer.cancel(); + server.close(); + synchronized (clients) { + List copy = new ArrayList<>(clients); + clients.clear(); + for (ServerConnection client : copy) { + client.close(); + } + } + } + } + if (serverThread != null) { + try { + serverThread.join(); + } catch (InterruptedException ie) { + // no-op + } + } + } + + void ref() { + refCount.incrementAndGet(); + } + + void unref() { + synchronized(LauncherServer.class) { + if (refCount.decrementAndGet() == 0) { + try { + close(); + } catch (IOException ioe) { + // no-op. + } finally { + serverInstance = null; + } + } + } + } + + int getPort() { + return server.getLocalPort(); + } + + /** + * Removes the client handle from the pending list (in case it's still there), and unrefs + * the server. + */ + void unregister(ChildProcAppHandle handle) { + pending.remove(handle.getSecret()); + unref(); + } + + private void acceptConnections() { + try { + while (running) { + final Socket client = server.accept(); + TimerTask timeout = new TimerTask() { + @Override + public void run() { + LOG.warning("Timed out waiting for hello message from client."); + try { + client.close(); + } catch (IOException ioe) { + // no-op. 
+ } + } + }; + ServerConnection clientConnection = new ServerConnection(client, timeout); + Thread clientThread = factory.newThread(clientConnection); + synchronized (timeout) { + clientThread.start(); + synchronized (clients) { + clients.add(clientConnection); + } + long timeoutMs = getConnectionTimeout(); + // 0 is used for testing to avoid issues with clock resolution / thread scheduling, + // and force an immediate timeout. + if (timeoutMs > 0) { + timeoutTimer.schedule(timeout, getConnectionTimeout()); + } else { + timeout.run(); + } + } + } + } catch (IOException ioe) { + if (running) { + LOG.log(Level.SEVERE, "Error in accept loop.", ioe); + } + } + } + + private long getConnectionTimeout() { + String value = SparkLauncher.launcherConfig.get(SparkLauncher.CHILD_CONNECTION_TIMEOUT); + return (value != null) ? Long.parseLong(value) : DEFAULT_CONNECT_TIMEOUT; + } + + private String createSecret() { + byte[] secret = new byte[128]; + RND.nextBytes(secret); + + StringBuilder sb = new StringBuilder(); + for (byte b : secret) { + int ival = b >= 0 ? b : Byte.MAX_VALUE - b; + if (ival < 0x10) { + sb.append("0"); + } + sb.append(Integer.toHexString(ival)); + } + return sb.toString(); + } + + private class ServerConnection extends LauncherConnection { + + private TimerTask timeout; + private ChildProcAppHandle handle; + + ServerConnection(Socket socket, TimerTask timeout) throws IOException { + super(socket); + this.timeout = timeout; + } + + @Override + protected void handle(Message msg) throws IOException { + try { + if (msg instanceof Hello) { + synchronized (timeout) { + timeout.cancel(); + } + timeout = null; + Hello hello = (Hello) msg; + ChildProcAppHandle handle = pending.remove(hello.secret); + if (handle != null) { + handle.setState(SparkAppHandle.State.CONNECTED); + handle.setConnection(this); + this.handle = handle; + } else { + throw new IllegalArgumentException("Received Hello for unknown client."); + } + } else { + if (handle == null) { + throw new IllegalArgumentException("Expected hello, got: " + + msg != null ? msg.getClass().getName() : null); + } + if (msg instanceof SetAppId) { + SetAppId set = (SetAppId) msg; + handle.setAppId(set.appId); + } else if (msg instanceof SetState) { + handle.setState(((SetState)msg).state); + } else { + throw new IllegalArgumentException("Invalid message: " + + msg != null ? msg.getClass().getName() : null); + } + } + } catch (Exception e) { + LOG.log(Level.INFO, "Error handling message from client.", e); + if (timeout != null) { + timeout.cancel(); + } + close(); + } finally { + timeoutTimer.purge(); + } + } + + @Override + public void close() throws IOException { + synchronized (clients) { + clients.remove(this); + } + super.close(); + if (handle != null) { + handle.disconnect(); + } + } + + } + +} diff --git a/launcher/src/main/java/org/apache/spark/launcher/Main.java b/launcher/src/main/java/org/apache/spark/launcher/Main.java index 62492f9baf3b..a4e3acc674f3 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/Main.java +++ b/launcher/src/main/java/org/apache/spark/launcher/Main.java @@ -32,7 +32,7 @@ class Main { /** * Usage: Main [class] [class args] - *
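For reference, the connection-timeout behavior described in the LauncherServer class comment above is driven by the launcher-library setting defined later in this patch (SparkLauncher.CHILD_CONNECTION_TIMEOUT, set through SparkLauncher.setConfig). A hedged sketch of tuning it before starting an application; the resource path and main class are placeholders, not part of this patch:

import org.apache.spark.launcher.SparkAppHandle;
import org.apache.spark.launcher.SparkLauncher;

public class TimeoutExample {
  public static void main(String[] args) throws Exception {
    // Launcher-library setting, not an app config: how long the server waits
    // for the child's "hello" message before discarding that client's state.
    SparkLauncher.setConfig(SparkLauncher.CHILD_CONNECTION_TIMEOUT, "30000");

    SparkAppHandle handle = new SparkLauncher()
      .setAppResource("/path/to/app.jar")   // placeholder
      .setMainClass("example.Main")         // placeholder
      .setMaster("local")
      .startApplication();

    // Per the newAppHandle() javadoc above, if the child never connects back
    // within the window, the handle transitions to an error state.
    System.out.println("Initial state: " + handle.getState());
  }
}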

    + *

    * This CLI works in two different modes: *

      *
    • "spark-submit": if class is "org.apache.spark.deploy.SparkSubmit", the @@ -42,7 +42,7 @@ class Main { * * This class works in tandem with the "bin/spark-class" script on Unix-like systems, and * "bin/spark-class2.cmd" batch script on Windows to execute the final command. - *

      + *

      * On Unix-like systems, the output is a list of command arguments, separated by the NULL * character. On Windows, the output is a command line suitable for direct execution from the * script. diff --git a/launcher/src/main/java/org/apache/spark/launcher/NamedThreadFactory.java b/launcher/src/main/java/org/apache/spark/launcher/NamedThreadFactory.java new file mode 100644 index 000000000000..995f4d73daaa --- /dev/null +++ b/launcher/src/main/java/org/apache/spark/launcher/NamedThreadFactory.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.atomic.AtomicLong; + +class NamedThreadFactory implements ThreadFactory { + + private final String nameFormat; + private final AtomicLong threadIds; + + NamedThreadFactory(String nameFormat) { + this.nameFormat = nameFormat; + this.threadIds = new AtomicLong(); + } + + @Override + public Thread newThread(Runnable r) { + Thread t = new Thread(r, String.format(nameFormat, threadIds.incrementAndGet())); + t.setDaemon(true); + return t; + } + +} diff --git a/launcher/src/main/java/org/apache/spark/launcher/OutputRedirector.java b/launcher/src/main/java/org/apache/spark/launcher/OutputRedirector.java new file mode 100644 index 000000000000..6e7120167d60 --- /dev/null +++ b/launcher/src/main/java/org/apache/spark/launcher/OutputRedirector.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.io.BufferedReader; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.IOException; +import java.util.concurrent.ThreadFactory; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * Redirects lines read from a given input stream to a j.u.l.Logger (at INFO level). 
+ */ +class OutputRedirector { + + private final BufferedReader reader; + private final Logger sink; + private final Thread thread; + + private volatile boolean active; + + OutputRedirector(InputStream in, ThreadFactory tf) { + this(in, OutputRedirector.class.getName(), tf); + } + + OutputRedirector(InputStream in, String loggerName, ThreadFactory tf) { + this.active = true; + this.reader = new BufferedReader(new InputStreamReader(in)); + this.thread = tf.newThread(new Runnable() { + @Override + public void run() { + redirect(); + } + }); + this.sink = Logger.getLogger(loggerName); + thread.start(); + } + + private void redirect() { + try { + String line; + while ((line = reader.readLine()) != null) { + if (active) { + sink.info(line.replaceFirst("\\s*$", "")); + } + } + } catch (IOException e) { + sink.log(Level.FINE, "Error reading child process output.", e); + } + } + + /** + * This method just stops the output of the process from showing up in the local logs. + * The child's output will still be read (and, thus, the redirect thread will still be + * alive) to avoid the child process hanging because of lack of output buffer. + */ + void stop() { + active = false; + } + +} diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java new file mode 100644 index 000000000000..e9caf0b3cb06 --- /dev/null +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +/** + * A handle to a running Spark application. + *

      + * Provides runtime information about the underlying Spark application, and actions to control it. + * + * @since 1.6.0 + */ +public interface SparkAppHandle { + + /** + * Represents the application's state. A state can be "final", in which case it will not change + * after it's reached, and means the application is not running anymore. + * + * @since 1.6.0 + */ + public enum State { + /** The application has not reported back yet. */ + UNKNOWN(false), + /** The application has connected to the handle. */ + CONNECTED(false), + /** The application has been submitted to the cluster. */ + SUBMITTED(false), + /** The application is running. */ + RUNNING(false), + /** The application finished with a successful status. */ + FINISHED(true), + /** The application finished with a failed status. */ + FAILED(true), + /** The application was killed. */ + KILLED(true); + + private final boolean isFinal; + + State(boolean isFinal) { + this.isFinal = isFinal; + } + + /** + * Whether this state is a final state, meaning the application is not running anymore + * once it's reached. + */ + public boolean isFinal() { + return isFinal; + } + } + + /** + * Adds a listener to be notified of changes to the handle's information. Listeners will be called + * from the thread processing updates from the application, so they should avoid blocking or + * long-running operations. + * + * @param l Listener to add. + */ + void addListener(Listener l); + + /** Returns the current application state. */ + State getState(); + + /** Returns the application ID, or null if not yet known. */ + String getAppId(); + + /** + * Asks the application to stop. This is best-effort, since the application may fail to receive + * or act on the command. Callers should watch for a state transition that indicates the + * application has really stopped. + */ + void stop(); + + /** + * Tries to kill the underlying application. Implies {@link #disconnect()}. This will not send + * a {@link #stop()} message to the application, so it's recommended that users first try to + * stop the application cleanly and only resort to this method if that fails. + *

+ * Note that if the application is running as a child process, this method may fail to kill + * the process when using Java 7. This may happen if, for example, the application is deadlocked. + */ + void kill(); + + /** + * Disconnects the handle from the application, without stopping it. After this method is called, + * the handle will not be able to communicate with the application anymore. + */ + void disconnect(); + + /** + * Listener for updates to a handle's state. The callbacks do not receive information about + * what exactly has changed, just that an update has occurred. + * + * @since 1.6.0 + */ + public interface Listener { + + /** + * Callback for changes in the handle's state. + * + * @param handle The updated handle. + * @see SparkAppHandle#getState() + */ + void stateChanged(SparkAppHandle handle); + + /** + * Callback for changes in any information that is not the handle's state. + * + * @param handle The updated handle. + */ + void infoChanged(SparkAppHandle handle); + + } + +} diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java index 5f95e2c74f90..931a24cfd4b1 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java @@ -28,7 +28,7 @@ /** * Command builder for internal Spark classes. - *
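The listener contract above only signals that something changed; callers read the handle from inside the callback. A minimal sketch of waiting for a final state with a latch (resource path and main class are placeholders, not part of this patch):

import java.util.concurrent.CountDownLatch;
import org.apache.spark.launcher.SparkAppHandle;
import org.apache.spark.launcher.SparkLauncher;

public class WaitForFinalState {
  public static void main(String[] args) throws Exception {
    final CountDownLatch done = new CountDownLatch(1);

    SparkAppHandle handle = new SparkLauncher()
      .setAppResource("/path/to/app.jar")   // placeholder
      .setMainClass("example.Main")         // placeholder
      .setMaster("local")
      .startApplication(new SparkAppHandle.Listener() {
        @Override
        public void stateChanged(SparkAppHandle h) {
          // State.isFinal() is true for FINISHED, FAILED and KILLED.
          if (h.getState().isFinal()) {
            done.countDown();
          }
        }

        @Override
        public void infoChanged(SparkAppHandle h) {
          // e.g. the app ID becoming available; nothing to do in this sketch.
        }
      });

    done.await();
    System.out.println("Application " + handle.getAppId() + " ended as " + handle.getState());
  }
}

Keeping the callback non-blocking (a countDown here) matches the addListener() note that listeners run on the update-processing thread.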

      + *

* This class handles building the command to launch all internal Spark classes except for * SparkSubmit (which is handled by the {@link SparkSubmitCommandBuilder} class). */ diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java index c0f89c923069..20e6003a00c1 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java @@ -20,12 +20,15 @@ import java.io.File; import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; import static org.apache.spark.launcher.CommandBuilderUtils.*; -/** +/** * Launcher for Spark applications. *

      * Use this class to start Spark applications programmatically. The class uses a builder pattern @@ -37,6 +40,9 @@ public class SparkLauncher { /** The Spark master. */ public static final String SPARK_MASTER = "spark.master"; + /** The Spark deploy mode. */ + public static final String DEPLOY_MODE = "spark.submit.deployMode"; + /** Configuration key for the driver memory. */ public static final String DRIVER_MEMORY = "spark.driver.memory"; /** Configuration key for the driver class path. */ @@ -57,7 +63,35 @@ public class SparkLauncher { /** Configuration key for the number of executor CPU cores. */ public static final String EXECUTOR_CORES = "spark.executor.cores"; - private final SparkSubmitCommandBuilder builder; + /** Logger name to use when launching a child process. */ + public static final String CHILD_PROCESS_LOGGER_NAME = "spark.launcher.childProcLoggerName"; + + /** + * Maximum time (in ms) to wait for a child process to connect back to the launcher server + * when using @link{#start()}. + */ + public static final String CHILD_CONNECTION_TIMEOUT = "spark.launcher.childConectionTimeout"; + + /** Used internally to create unique logger names. */ + private static final AtomicInteger COUNTER = new AtomicInteger(); + + static final Map launcherConfig = new HashMap(); + + /** + * Set a configuration value for the launcher library. These config values do not affect the + * launched application, but rather the behavior of the launcher library itself when managing + * applications. + * + * @since 1.6.0 + * @param name Config name. + * @param value Config value. + */ + public static void setConfig(String name, String value) { + launcherConfig.put(name, value); + } + + // Visible for testing. + final SparkSubmitCommandBuilder builder; public SparkLauncher() { this(null); @@ -107,7 +141,7 @@ public SparkLauncher setSparkHome(String sparkHome) { */ public SparkLauncher setPropertiesFile(String path) { checkNotNull(path, "path"); - builder.propertiesFile = path; + builder.setPropertiesFile(path); return this; } @@ -187,6 +221,75 @@ public SparkLauncher setMainClass(String mainClass) { return this; } + /** + * Adds a no-value argument to the Spark invocation. If the argument is known, this method + * validates whether the argument is indeed a no-value argument, and throws an exception + * otherwise. + *

      + * Use this method with caution. It is possible to create an invalid Spark command by passing + * unknown arguments to this method, since those are allowed for forward compatibility. + * + * @since 1.5.0 + * @param arg Argument to add. + * @return This launcher. + */ + public SparkLauncher addSparkArg(String arg) { + SparkSubmitOptionParser validator = new ArgumentValidator(false); + validator.parse(Arrays.asList(arg)); + builder.sparkArgs.add(arg); + return this; + } + + /** + * Adds an argument with a value to the Spark invocation. If the argument name corresponds to + * a known argument, the code validates that the argument actually expects a value, and throws + * an exception otherwise. + *

+ * It is safe to add arguments modified by other methods in this class (such as + * {@link #setMaster(String)}); the last invocation will be the one to take effect. + *

      + * Use this method with caution. It is possible to create an invalid Spark command by passing + * unknown arguments to this method, since those are allowed for forward compatibility. + * + * @since 1.5.0 + * @param name Name of argument to add. + * @param value Value of the argument. + * @return This launcher. + */ + public SparkLauncher addSparkArg(String name, String value) { + SparkSubmitOptionParser validator = new ArgumentValidator(true); + if (validator.MASTER.equals(name)) { + setMaster(value); + } else if (validator.PROPERTIES_FILE.equals(name)) { + setPropertiesFile(value); + } else if (validator.CONF.equals(name)) { + String[] vals = value.split("=", 2); + setConf(vals[0], vals[1]); + } else if (validator.CLASS.equals(name)) { + setMainClass(value); + } else if (validator.JARS.equals(name)) { + builder.jars.clear(); + for (String jar : value.split(",")) { + addJar(jar); + } + } else if (validator.FILES.equals(name)) { + builder.files.clear(); + for (String file : value.split(",")) { + addFile(file); + } + } else if (validator.PY_FILES.equals(name)) { + builder.pyFiles.clear(); + for (String file : value.split(",")) { + addPyFile(file); + } + } else { + validator.parse(Arrays.asList(name, value)); + builder.sparkArgs.add(name); + builder.sparkArgs.add(value); + } + return this; + } + /** * Adds command line arguments for the application. * @@ -250,10 +353,81 @@ public SparkLauncher setVerbose(boolean verbose) { /** * Launches a sub-process that will start the configured Spark application. + *
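The two addSparkArg variants documented above behave differently for known options. A small illustrative sketch; the queue value is a placeholder, and the last call shows that unknown options are deliberately accepted for forward compatibility:

import org.apache.spark.launcher.SparkLauncher;

public class SparkArgExample {
  public static void main(String[] args) {
    SparkLauncher launcher = new SparkLauncher()
      .addSparkArg("--verbose")             // known spark-submit switch, takes no value
      .addSparkArg("--queue", "default");   // known spark-submit option that expects a value (queue name is illustrative)

    // Unknown arguments are passed through as-is, so typos are not caught here;
    // they only surface when spark-submit itself runs.
    launcher.addSparkArg("--some-future-option", "value");
  }
}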

      + * The {@link #startApplication(SparkAppHandle.Listener...)} method is preferred when launching + * Spark, since it provides better control of the child application. * * @return A process handle for the Spark app. */ public Process launch() throws IOException { + return createBuilder().start(); + } + + /** + * Starts a Spark application. + *

      + * This method returns a handle that provides information about the running application and can + * be used to do basic interaction with it. + *

      + * The returned handle assumes that the application will instantiate a single SparkContext + * during its lifetime. Once that context reports a final state (one that indicates the + * SparkContext has stopped), the handle will not perform new state transitions, so anything + * that happens after that cannot be monitored. If the underlying application is launched as + * a child process, {@link SparkAppHandle#kill()} can still be used to kill the child process. + *

      + * Currently, all applications are launched as child processes. The child's stdout and stderr + * are merged and written to a logger (see java.util.logging). The logger's name + * can be defined by setting {@link #CHILD_PROCESS_LOGGER_NAME} in the app's configuration. If + * that option is not set, the code will try to derive a name from the application's name or + * main class / script file. If those cannot be determined, an internal, unique name will be + * used. In all cases, the logger name will start with "org.apache.spark.launcher.app", to fit + * more easily into the configuration of commonly-used logging systems. + * + * @since 1.6.0 + * @param listeners Listeners to add to the handle before the app is launched. + * @return A handle for the launched application. + */ + public SparkAppHandle startApplication(SparkAppHandle.Listener... listeners) throws IOException { + ChildProcAppHandle handle = LauncherServer.newAppHandle(); + for (SparkAppHandle.Listener l : listeners) { + handle.addListener(l); + } + + String appName = builder.getEffectiveConfig().get(CHILD_PROCESS_LOGGER_NAME); + if (appName == null) { + if (builder.appName != null) { + appName = builder.appName; + } else if (builder.mainClass != null) { + int dot = builder.mainClass.lastIndexOf("."); + if (dot >= 0 && dot < builder.mainClass.length() - 1) { + appName = builder.mainClass.substring(dot + 1, builder.mainClass.length()); + } else { + appName = builder.mainClass; + } + } else if (builder.appResource != null) { + appName = new File(builder.appResource).getName(); + } else { + appName = String.valueOf(COUNTER.incrementAndGet()); + } + } + + String loggerPrefix = getClass().getPackage().getName(); + String loggerName = String.format("%s.app.%s", loggerPrefix, appName); + ProcessBuilder pb = createBuilder().redirectErrorStream(true); + pb.environment().put(LauncherProtocol.ENV_LAUNCHER_PORT, + String.valueOf(LauncherServer.getServerInstance().getPort())); + pb.environment().put(LauncherProtocol.ENV_LAUNCHER_SECRET, handle.getSecret()); + try { + handle.setChildProc(pb.start(), loggerName); + } catch (IOException ioe) { + handle.kill(); + throw ioe; + } + + return handle; + } + + private ProcessBuilder createBuilder() { List cmd = new ArrayList(); String script = isWindows() ? "spark-submit.cmd" : "spark-submit"; cmd.add(join(File.separator, builder.getSparkHome(), "bin", script)); @@ -274,7 +448,35 @@ public Process launch() throws IOException { for (Map.Entry e : builder.childEnv.entrySet()) { pb.environment().put(e.getKey(), e.getValue()); } - return pb.start(); + return pb; } + private static class ArgumentValidator extends SparkSubmitOptionParser { + + private final boolean hasValue; + + ArgumentValidator(boolean hasValue) { + this.hasValue = hasValue; + } + + @Override + protected boolean handle(String opt, String value) { + if (value == null && hasValue) { + throw new IllegalArgumentException(String.format("'%s' does not expect a value.", opt)); + } + return true; + } + + @Override + protected boolean handleUnknown(String opt) { + // Do not fail on unknown arguments, to support future arguments added to SparkSubmit. + return true; + } + + protected void handleExtraArgs(List extra) { + // No op. 
+ } + + }; + } diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index 87c43aa9980e..a95f0f17517d 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -25,11 +25,11 @@ /** * Special command builder for handling a CLI invocation of SparkSubmit. - *
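As a companion to the startApplication() javadoc above, a hedged sketch of naming the child-process logger explicitly instead of letting the launcher derive a name; the resource path, main class and logger suffix are placeholders:

import java.util.logging.Logger;
import org.apache.spark.launcher.SparkAppHandle;
import org.apache.spark.launcher.SparkLauncher;

public class ChildLoggerExample {
  public static void main(String[] args) throws Exception {
    SparkAppHandle handle = new SparkLauncher()
      .setAppResource("/path/to/app.jar")   // placeholder
      .setMainClass("example.Main")         // placeholder
      .setMaster("local")
      // The child's merged stdout/stderr go to a java.util.logging logger; pick
      // the suffix here rather than having it derived from the app name.
      .setConf(SparkLauncher.CHILD_PROCESS_LOGGER_NAME, "myapp")
      .startApplication();

    // Per the code above, the effective logger name becomes
    // "org.apache.spark.launcher.app.myapp", so it can be configured like any
    // other java.util.logging logger.
    Logger.getLogger("org.apache.spark.launcher.app.myapp")
      .info("handle state: " + handle.getState());
  }
}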

      + *

* This builder adds command line parsing compatible with SparkSubmit. It handles setting * driver-side options and the special parsing behavior needed for special-casing certain internal * Spark applications. - *

      + *

      * This class has also some special features to aid launching pyspark. */ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { @@ -76,8 +76,8 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { "spark-internal"); } - private final List sparkArgs; - private final boolean printHelp; + final List sparkArgs; + private final boolean printInfo; /** * Controls whether mixing spark-submit arguments with app arguments is allowed. This is needed @@ -88,7 +88,7 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { SparkSubmitCommandBuilder() { this.sparkArgs = new ArrayList(); - this.printHelp = false; + this.printInfo = false; } SparkSubmitCommandBuilder(List args) { @@ -108,14 +108,14 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { OptionParser parser = new OptionParser(); parser.parse(submitArgs); - this.printHelp = parser.helpRequested; + this.printInfo = parser.infoRequested; } @Override public List buildCommand(Map env) throws IOException { - if (PYSPARK_SHELL_RESOURCE.equals(appResource) && !printHelp) { + if (PYSPARK_SHELL_RESOURCE.equals(appResource) && !printInfo) { return buildPySparkShellCommand(env); - } else if (SPARKR_SHELL_RESOURCE.equals(appResource) && !printHelp) { + } else if (SPARKR_SHELL_RESOURCE.equals(appResource) && !printInfo) { return buildSparkRCommand(env); } else { return buildSparkSubmitCommand(env); @@ -188,10 +188,9 @@ private List buildSparkSubmitCommand(Map env) throws IOE // Load the properties file and check whether spark-submit will be running the app's driver // or just launching a cluster app. When running the driver, the JVM's argument will be // modified to cover the driver's configuration. - Properties props = loadPropertiesFile(); - boolean isClientMode = isClientMode(props); - String extraClassPath = isClientMode ? - firstNonEmptyValue(SparkLauncher.DRIVER_EXTRA_CLASSPATH, conf, props) : null; + Map config = getEffectiveConfig(); + boolean isClientMode = isClientMode(config); + String extraClassPath = isClientMode ? config.get(SparkLauncher.DRIVER_EXTRA_CLASSPATH) : null; List cmd = buildJavaCommand(extraClassPath); // Take Thrift Server as daemon @@ -212,14 +211,13 @@ private List buildSparkSubmitCommand(Map env) throws IOE // Take Thrift Server as daemon String tsMemory = isThriftServer(mainClass) ? 
System.getenv("SPARK_DAEMON_MEMORY") : null; - String memory = firstNonEmpty(tsMemory, - firstNonEmptyValue(SparkLauncher.DRIVER_MEMORY, conf, props), + String memory = firstNonEmpty(tsMemory, config.get(SparkLauncher.DRIVER_MEMORY), System.getenv("SPARK_DRIVER_MEMORY"), System.getenv("SPARK_MEM"), DEFAULT_MEM); cmd.add("-Xms" + memory); cmd.add("-Xmx" + memory); - addOptionString(cmd, firstNonEmptyValue(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, conf, props)); + addOptionString(cmd, config.get(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS)); mergeEnvPathList(env, getLibPathEnvName(), - firstNonEmptyValue(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH, conf, props)); + config.get(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH)); } addPermGenSizeOpt(cmd); @@ -281,9 +279,8 @@ private List buildSparkRCommand(Map env) throws IOExcept private void constructEnvVarArgs( Map env, String submitArgsEnvVariable) throws IOException { - Properties props = loadPropertiesFile(); mergeEnvPathList(env, getLibPathEnvName(), - firstNonEmptyValue(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH, conf, props)); + getEffectiveConfig().get(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH)); StringBuilder submitArgs = new StringBuilder(); for (String arg : buildSparkSubmitArgs()) { @@ -295,13 +292,13 @@ private void constructEnvVarArgs( env.put(submitArgsEnvVariable, submitArgs.toString()); } - - private boolean isClientMode(Properties userProps) { - String userMaster = firstNonEmpty(master, (String) userProps.get(SparkLauncher.SPARK_MASTER)); - // Default master is "local[*]", so assume client mode in that case. + private boolean isClientMode(Map userProps) { + String userMaster = firstNonEmpty(master, userProps.get(SparkLauncher.SPARK_MASTER)); + String userDeployMode = firstNonEmpty(deployMode, userProps.get(SparkLauncher.DEPLOY_MODE)); + // Default master is "local[*]", so assume client mode in that case return userMaster == null || - "client".equals(deployMode) || - (!userMaster.equals("yarn-cluster") && deployMode == null); + "client".equals(userDeployMode) || + (!userMaster.equals("yarn-cluster") && userDeployMode == null); } /** @@ -315,7 +312,7 @@ private boolean isThriftServer(String mainClass) { private class OptionParser extends SparkSubmitOptionParser { - boolean helpRequested = false; + boolean infoRequested = false; @Override protected boolean handle(String opt, String value) { @@ -348,7 +345,10 @@ protected boolean handle(String opt, String value) { appResource = specialClasses.get(value); } } else if (opt.equals(HELP) || opt.equals(USAGE_ERROR)) { - helpRequested = true; + infoRequested = true; + sparkArgs.add(opt); + } else if (opt.equals(VERSION)) { + infoRequested = true; sparkArgs.add(opt); } else { sparkArgs.add(opt); diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java index b88bba883ac6..6767cc507964 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java @@ -23,7 +23,7 @@ /** * Parser for spark-submit command line options. - *

      + *

      * This class encapsulates the parsing code for spark-submit command line options, so that there * is a single list of options that needs to be maintained (well, sort of, but it makes it harder * to break things). @@ -51,6 +51,7 @@ class SparkSubmitOptionParser { protected final String MASTER = "--master"; protected final String NAME = "--name"; protected final String PACKAGES = "--packages"; + protected final String PACKAGES_EXCLUDE = "--exclude-packages"; protected final String PROPERTIES_FILE = "--properties-file"; protected final String PROXY_USER = "--proxy-user"; protected final String PY_FILES = "--py-files"; @@ -79,10 +80,10 @@ class SparkSubmitOptionParser { * This is the canonical list of spark-submit options. Each entry in the array contains the * different aliases for the same option; the first element of each entry is the "official" * name of the option, passed to {@link #handle(String, String)}. - *

      + *

* Options not listed here nor in the "switch" list below will result in a call to * {@link #handleUnknown(String)}. - *

      + *

      * These two arrays are visible for tests. */ final String[][] opts = { @@ -105,6 +106,7 @@ class SparkSubmitOptionParser { { NAME }, { NUM_EXECUTORS }, { PACKAGES }, + { PACKAGES_EXCLUDE }, { PRINCIPAL }, { PROPERTIES_FILE }, { PROXY_USER }, @@ -128,7 +130,7 @@ class SparkSubmitOptionParser { /** * Parse a list of spark-submit command line options. - *
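The parser above is table-driven: handle() always receives the canonical option name regardless of which alias was typed, unknown options go to handleUnknown(), and leftover arguments go to handleExtraArgs(). A rough sketch of that contract; the class is package-private, so this subclass is purely illustrative and would have to live in the same package:

package org.apache.spark.launcher;

import java.util.List;

// Illustrative only: collects the options it recognizes and ignores the rest,
// relying on handleUnknown() to stay forward-compatible with new spark-submit options.
class CollectingParser extends SparkSubmitOptionParser {

  String master;
  String mainClass;

  @Override
  protected boolean handle(String opt, String value) {
    // "opt" is the canonical name from the opts table above.
    if (opt.equals(MASTER)) {
      master = value;
    } else if (opt.equals(CLASS)) {
      mainClass = value;
    }
    return true;  // keep parsing
  }

  @Override
  protected boolean handleUnknown(String opt) {
    return true;  // tolerate options added to spark-submit later
  }

  @Override
  protected void handleExtraArgs(List<String> extra) {
    // Remaining arguments that are not parsed as options end up here; ignored in this sketch.
  }
}

Driving it is a plain call such as new CollectingParser().parse(java.util.Arrays.asList("--master", "yarn", "--class", "my.Main")), with the argument values being illustrative.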

      + *

      * See SparkSubmitArguments.scala for a more formal description of available options. * * @throws IllegalArgumentException If an error is found during parsing. diff --git a/launcher/src/main/java/org/apache/spark/launcher/package-info.java b/launcher/src/main/java/org/apache/spark/launcher/package-info.java index 7c97dba511b2..d1ac39bdc76a 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/package-info.java +++ b/launcher/src/main/java/org/apache/spark/launcher/package-info.java @@ -17,17 +17,42 @@ /** * Library for launching Spark applications. - * + * *

      * This library allows applications to launch Spark programmatically. There's only one entry * point to the library - the {@link org.apache.spark.launcher.SparkLauncher} class. *

      * *

- * To launch a Spark application, just instantiate a {@link org.apache.spark.launcher.SparkLauncher} - * and configure the application to run. For example: + * The {@link org.apache.spark.launcher.SparkLauncher#startApplication( + * org.apache.spark.launcher.SparkAppHandle.Listener...)} method can be used to start Spark and provide + * a handle to monitor and control the running application: *

      - * + * + *
      + * {@code
      + *   import org.apache.spark.launcher.SparkAppHandle;
      + *   import org.apache.spark.launcher.SparkLauncher;
      + *
      + *   public class MyLauncher {
      + *     public static void main(String[] args) throws Exception {
      + *       SparkAppHandle handle = new SparkLauncher()
      + *         .setAppResource("/my/app.jar")
      + *         .setMainClass("my.spark.app.Main")
      + *         .setMaster("local")
      + *         .setConf(SparkLauncher.DRIVER_MEMORY, "2g")
      + *         .startApplication();
      + *       // Use handle API to monitor / control application.
      + *     }
      + *   }
      + * }
      + * 
      + * + *

      + * It's also possible to launch a raw child process, using the + * {@link org.apache.spark.launcher.SparkLauncher#launch()} method: + *

      + * *
        * {@code
        *   import org.apache.spark.launcher.SparkLauncher;
      @@ -45,5 +70,10 @@
        *   }
        * }
        * 
      + * + *

      This method requires the calling code to manually manage the child process, including its + * output streams (to avoid possible deadlocks). It's recommended that + * {@link org.apache.spark.launcher.SparkLauncher#startApplication( + * org.apache.spark.launcher.SparkAppHandle.Listener...)} be used instead.
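To make the stream-management caveat concrete, a hedged sketch of draining both streams of the Process returned by launch(); the background-thread helper is one possible approach, not part of the launcher API, and the resource path and main class are placeholders:

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;

import org.apache.spark.launcher.SparkLauncher;

public class RawLaunchExample {

  public static void main(String[] args) throws Exception {
    Process spark = new SparkLauncher()
      .setAppResource("/path/to/app.jar")   // placeholder
      .setMainClass("example.Main")         // placeholder
      .setMaster("local")
      .launch();

    // Both streams must be consumed, otherwise the child can block once the
    // OS pipe buffers fill up.
    drain(spark.getInputStream(), "stdout");
    drain(spark.getErrorStream(), "stderr");

    int exitCode = spark.waitFor();
    System.out.println("spark-submit exited with " + exitCode);
  }

  private static void drain(final InputStream in, final String name) {
    Thread t = new Thread(new Runnable() {
      @Override
      public void run() {
        try {
          BufferedReader reader = new BufferedReader(new InputStreamReader(in));
          String line;
          while ((line = reader.readLine()) != null) {
            System.out.println("[" + name + "] " + line);
          }
        } catch (IOException e) {
          // Child exited or stream closed; nothing else to do in this sketch.
        }
      }
    });
    t.setDaemon(true);
    t.start();
  }
}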

      */ package org.apache.spark.launcher; diff --git a/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java b/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java new file mode 100644 index 000000000000..23e2c64d6dcd --- /dev/null +++ b/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import org.slf4j.bridge.SLF4JBridgeHandler; + +/** + * Handles configuring the JUL -> SLF4J bridge. + */ +class BaseSuite { + + static { + SLF4JBridgeHandler.removeHandlersForRootLogger(); + SLF4JBridgeHandler.install(); + } + +} diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java new file mode 100644 index 000000000000..dc8fbb58d880 --- /dev/null +++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.launcher; + +import java.io.Closeable; +import java.io.IOException; +import java.net.InetAddress; +import java.net.Socket; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +import org.junit.Test; +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +import static org.apache.spark.launcher.LauncherProtocol.*; + +public class LauncherServerSuite extends BaseSuite { + + @Test + public void testLauncherServerReuse() throws Exception { + ChildProcAppHandle handle1 = null; + ChildProcAppHandle handle2 = null; + ChildProcAppHandle handle3 = null; + + try { + handle1 = LauncherServer.newAppHandle(); + handle2 = LauncherServer.newAppHandle(); + LauncherServer server1 = handle1.getServer(); + assertSame(server1, handle2.getServer()); + + handle1.kill(); + handle2.kill(); + + handle3 = LauncherServer.newAppHandle(); + assertNotSame(server1, handle3.getServer()); + + handle3.kill(); + + assertNull(LauncherServer.getServerInstance()); + } finally { + kill(handle1); + kill(handle2); + kill(handle3); + } + } + + @Test + public void testCommunication() throws Exception { + ChildProcAppHandle handle = LauncherServer.newAppHandle(); + TestClient client = null; + try { + Socket s = new Socket(InetAddress.getLoopbackAddress(), + LauncherServer.getServerInstance().getPort()); + + final Object waitLock = new Object(); + handle.addListener(new SparkAppHandle.Listener() { + @Override + public void stateChanged(SparkAppHandle handle) { + wakeUp(); + } + + @Override + public void infoChanged(SparkAppHandle handle) { + wakeUp(); + } + + private void wakeUp() { + synchronized (waitLock) { + waitLock.notifyAll(); + } + } + }); + + client = new TestClient(s); + synchronized (waitLock) { + client.send(new Hello(handle.getSecret(), "1.4.0")); + waitLock.wait(TimeUnit.SECONDS.toMillis(10)); + } + + // Make sure the server matched the client to the handle. + assertNotNull(handle.getConnection()); + + synchronized (waitLock) { + client.send(new SetAppId("app-id")); + waitLock.wait(TimeUnit.SECONDS.toMillis(10)); + } + assertEquals("app-id", handle.getAppId()); + + synchronized (waitLock) { + client.send(new SetState(SparkAppHandle.State.RUNNING)); + waitLock.wait(TimeUnit.SECONDS.toMillis(10)); + } + assertEquals(SparkAppHandle.State.RUNNING, handle.getState()); + + handle.stop(); + Message stopMsg = client.inbound.poll(10, TimeUnit.SECONDS); + assertTrue(stopMsg instanceof Stop); + } finally { + kill(handle); + close(client); + client.clientThread.join(); + } + } + + @Test + public void testTimeout() throws Exception { + ChildProcAppHandle handle = null; + TestClient client = null; + try { + // LauncherServer will immediately close the server-side socket when the timeout is set + // to 0. + SparkLauncher.setConfig(SparkLauncher.CHILD_CONNECTION_TIMEOUT, "0"); + + handle = LauncherServer.newAppHandle(); + + Socket s = new Socket(InetAddress.getLoopbackAddress(), + LauncherServer.getServerInstance().getPort()); + client = new TestClient(s); + + // Try a few times since the client-side socket may not reflect the server-side close + // immediately. 
+ boolean helloSent = false; + int maxTries = 10; + for (int i = 0; i < maxTries; i++) { + try { + if (!helloSent) { + client.send(new Hello(handle.getSecret(), "1.4.0")); + helloSent = true; + } else { + client.send(new SetAppId("appId")); + } + fail("Expected exception caused by connection timeout."); + } catch (IllegalStateException | IOException e) { + // Expected. + break; + } catch (AssertionError e) { + if (i < maxTries - 1) { + Thread.sleep(100); + } else { + throw new AssertionError("Test failed after " + maxTries + " attempts.", e); + } + } + } + } finally { + SparkLauncher.launcherConfig.remove(SparkLauncher.CHILD_CONNECTION_TIMEOUT); + kill(handle); + close(client); + } + } + + private void kill(SparkAppHandle handle) { + if (handle != null) { + handle.kill(); + } + } + + private void close(Closeable c) { + if (c != null) { + try { + c.close(); + } catch (Exception e) { + // no-op. + } + } + } + + private static class TestClient extends LauncherConnection { + + final BlockingQueue inbound; + final Thread clientThread; + + TestClient(Socket s) throws IOException { + super(s); + this.inbound = new LinkedBlockingQueue(); + this.clientThread = new Thread(this); + clientThread.setName("TestClient"); + clientThread.setDaemon(true); + clientThread.start(); + } + + @Override + protected void handle(Message msg) throws IOException { + inbound.offer(msg); + } + + } + +} diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java index 7329ac9f7fb8..6aad47adbcc8 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java @@ -30,7 +30,7 @@ import org.junit.Test; import static org.junit.Assert.*; -public class SparkSubmitCommandBuilderSuite { +public class SparkSubmitCommandBuilderSuite extends BaseSuite { private static File dummyPropsFile; private static SparkSubmitOptionParser parser; @@ -48,12 +48,14 @@ public static void cleanUp() throws Exception { @Test public void testDriverCmdBuilder() throws Exception { - testCmdBuilder(true); + testCmdBuilder(true, true); + testCmdBuilder(true, false); } @Test public void testClusterCmdBuilder() throws Exception { - testCmdBuilder(false); + testCmdBuilder(false, true); + testCmdBuilder(false, false); } @Test @@ -149,7 +151,7 @@ public void testPySparkFallback() throws Exception { assertEquals("arg1", cmd.get(cmd.size() - 1)); } - private void testCmdBuilder(boolean isDriver) throws Exception { + private void testCmdBuilder(boolean isDriver, boolean useDefaultPropertyFile) throws Exception { String deployMode = isDriver ? 
"client" : "cluster"; SparkSubmitCommandBuilder launcher = @@ -161,14 +163,20 @@ private void testCmdBuilder(boolean isDriver) throws Exception { launcher.appResource = "/foo"; launcher.appName = "MyApp"; launcher.mainClass = "my.Class"; - launcher.propertiesFile = dummyPropsFile.getAbsolutePath(); launcher.appArgs.add("foo"); launcher.appArgs.add("bar"); - launcher.conf.put(SparkLauncher.DRIVER_MEMORY, "1g"); - launcher.conf.put(SparkLauncher.DRIVER_EXTRA_CLASSPATH, "/driver"); - launcher.conf.put(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, "-Ddriver -XX:MaxPermSize=256m"); - launcher.conf.put(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH, "/native"); launcher.conf.put("spark.foo", "foo"); + // either set the property through "--conf" or through default property file + if (!useDefaultPropertyFile) { + launcher.setPropertiesFile(dummyPropsFile.getAbsolutePath()); + launcher.conf.put(SparkLauncher.DRIVER_MEMORY, "1g"); + launcher.conf.put(SparkLauncher.DRIVER_EXTRA_CLASSPATH, "/driver"); + launcher.conf.put(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, "-Ddriver -XX:MaxPermSize=256m"); + launcher.conf.put(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH, "/native"); + } else { + launcher.childEnv.put("SPARK_CONF_DIR", System.getProperty("spark.test.home") + + "/launcher/src/test/resources"); + } Map env = new HashMap(); List cmd = launcher.buildCommand(env); @@ -216,7 +224,9 @@ private void testCmdBuilder(boolean isDriver) throws Exception { } // Checks below are the same for both driver and non-driver mode. - assertEquals(dummyPropsFile.getAbsolutePath(), findArgValue(cmd, parser.PROPERTIES_FILE)); + if (!useDefaultPropertyFile) { + assertEquals(dummyPropsFile.getAbsolutePath(), findArgValue(cmd, parser.PROPERTIES_FILE)); + } assertEquals("yarn", findArgValue(cmd, parser.MASTER)); assertEquals(deployMode, findArgValue(cmd, parser.DEPLOY_MODE)); assertEquals("my.Class", findArgValue(cmd, parser.CLASS)); diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitOptionParserSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitOptionParserSuite.java index f3d210991705..3ee5b8cf9689 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitOptionParserSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitOptionParserSuite.java @@ -28,7 +28,7 @@ import static org.apache.spark.launcher.SparkSubmitOptionParser.*; -public class SparkSubmitOptionParserSuite { +public class SparkSubmitOptionParserSuite extends BaseSuite { private SparkSubmitOptionParser parser; diff --git a/launcher/src/test/resources/log4j.properties b/launcher/src/test/resources/log4j.properties index 67a6a9821711..c64b1565e146 100644 --- a/launcher/src/test/resources/log4j.properties +++ b/launcher/src/test/resources/log4j.properties @@ -16,16 +16,19 @@ # # Set everything to be logged to the file core/target/unit-tests.log -log4j.rootCategory=INFO, file +test.appender=file +log4j.rootCategory=INFO, ${test.appender} log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=false - -# Some tests will set "test.name" to avoid overwriting the main log file. 
-log4j.appender.file.file=target/unit-tests${test.name}.log - +log4j.appender.file.file=target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n +log4j.appender.childproc=org.apache.log4j.ConsoleAppender +log4j.appender.childproc.target=System.err +log4j.appender.childproc.layout=org.apache.log4j.PatternLayout +log4j.appender.childproc.layout.ConversionPattern=%t: %m%n + # Ignore messages below warning level from Jetty, because it's a bit verbose log4j.logger.org.spark-project.jetty=WARN org.spark-project.jetty.LEVEL=WARN diff --git a/launcher/src/test/resources/spark-defaults.conf b/launcher/src/test/resources/spark-defaults.conf new file mode 100644 index 000000000000..239fc57883e9 --- /dev/null +++ b/launcher/src/test/resources/spark-defaults.conf @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +spark.driver.memory=1g +spark.driver.extraClassPath=/driver +spark.driver.extraJavaOptions=-Ddriver -XX:MaxPermSize=256m +spark.driver.extraLibraryPath=/native \ No newline at end of file diff --git a/licenses/LICENSE-AnchorJS.txt b/licenses/LICENSE-AnchorJS.txt new file mode 100644 index 000000000000..2bf24b9b9f84 --- /dev/null +++ b/licenses/LICENSE-AnchorJS.txt @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-DPark.txt b/licenses/LICENSE-DPark.txt new file mode 100644 index 000000000000..1d916090e4ea --- /dev/null +++ b/licenses/LICENSE-DPark.txt @@ -0,0 +1,30 @@ +Copyright (c) 2011, Douban Inc. +All rights reserved. 
+ +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + + * Neither the name of the Douban Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses/LICENSE-Mockito.txt b/licenses/LICENSE-Mockito.txt new file mode 100644 index 000000000000..e0840a446caf --- /dev/null +++ b/licenses/LICENSE-Mockito.txt @@ -0,0 +1,21 @@ +The MIT License + +Copyright (c) 2007 Mockito contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-SnapTree.txt b/licenses/LICENSE-SnapTree.txt new file mode 100644 index 000000000000..a538825d89ec --- /dev/null +++ b/licenses/LICENSE-SnapTree.txt @@ -0,0 +1,35 @@ +SNAPTREE LICENSE + +Copyright (c) 2009-2012 Stanford University, unless otherwise specified. +All rights reserved. + +This software was developed by the Pervasive Parallelism Laboratory of +Stanford University, California, USA. + +Permission to use, copy, modify, and distribute this software in source +or binary form for any purpose with or without fee is hereby granted, +provided that the following conditions are met: + + 1. 
Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + 3. Neither the name of Stanford University nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + +THIS SOFTWARE IS PROVIDED BY THE REGENTS AND CONTRIBUTORS ``AS IS'' AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF +SUCH DAMAGE. diff --git a/licenses/LICENSE-antlr.txt b/licenses/LICENSE-antlr.txt new file mode 100644 index 000000000000..3021ea04332e --- /dev/null +++ b/licenses/LICENSE-antlr.txt @@ -0,0 +1,8 @@ +[The BSD License] +Copyright (c) 2012 Terence Parr and Sam Harwell +All rights reserved. +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: +Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. +Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. +Neither the name of the author nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
\ No newline at end of file diff --git a/licenses/LICENSE-boto.txt b/licenses/LICENSE-boto.txt new file mode 100644 index 000000000000..7bba0cd9e10a --- /dev/null +++ b/licenses/LICENSE-boto.txt @@ -0,0 +1,20 @@ +Copyright (c) 2006-2008 Mitch Garnaat http://garnaat.org/ + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, dis- +tribute, sublicense, and/or sell copies of the Software, and to permit +persons to whom the Software is furnished to do so, subject to the fol- +lowing conditions: + +The above copyright notice and this permission notice shall be included +in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABIL- +ITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHOR BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-cloudpickle.txt b/licenses/LICENSE-cloudpickle.txt new file mode 100644 index 000000000000..b1e20fa1eda8 --- /dev/null +++ b/licenses/LICENSE-cloudpickle.txt @@ -0,0 +1,28 @@ +Copyright (c) 2012, Regents of the University of California. +Copyright (c) 2009 `PiCloud, Inc. `_. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the University of California, Berkeley nor the + names of its contributors may be used to endorse or promote + products derived from this software without specific prior written + permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses/LICENSE-d3.min.js.txt b/licenses/LICENSE-d3.min.js.txt new file mode 100644 index 000000000000..c71e3f254c06 --- /dev/null +++ b/licenses/LICENSE-d3.min.js.txt @@ -0,0 +1,26 @@ +Copyright (c) 2010-2015, Michael Bostock +All rights reserved. 
+ +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* The name Michael Bostock may not be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL MICHAEL BOSTOCK BE LIABLE FOR ANY DIRECT, +INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses/LICENSE-dagre-d3.txt b/licenses/LICENSE-dagre-d3.txt new file mode 100644 index 000000000000..4864fe05e980 --- /dev/null +++ b/licenses/LICENSE-dagre-d3.txt @@ -0,0 +1,19 @@ +Copyright (c) 2013 Chris Pettitt + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-f2j.txt b/licenses/LICENSE-f2j.txt new file mode 100644 index 000000000000..e28fd3ccdfa6 --- /dev/null +++ b/licenses/LICENSE-f2j.txt @@ -0,0 +1,8 @@ +Copyright © 2015 The University of Tennessee. All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: +· Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. +· Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer listed in this license in the documentation and/or other materials provided with the distribution. 
+· Neither the name of the copyright holders nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +This software is provided by the copyright holders and contributors "as is" and any express or implied warranties, including, but not limited to, the implied warranties of merchantability and fitness for a particular purpose are disclaimed. in no event shall the copyright owner or contributors be liable for any direct, indirect, incidental, special, exemplary, or consequential damages (including, but not limited to, procurement of substitute goods or services; loss of use, data, or profits; or business interruption) however caused and on any theory of liability, whether in contract, strict liability, or tort (including negligence or otherwise) arising in any way out of the use of this software, even if advised of the possibility of such damage. \ No newline at end of file diff --git a/licenses/LICENSE-graphlib-dot.txt b/licenses/LICENSE-graphlib-dot.txt new file mode 100644 index 000000000000..c9e18cd56242 --- /dev/null +++ b/licenses/LICENSE-graphlib-dot.txt @@ -0,0 +1,19 @@ +Copyright (c) 2012-2013 Chris Pettitt + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-heapq.txt b/licenses/LICENSE-heapq.txt new file mode 100644 index 000000000000..0c4c4b954bea --- /dev/null +++ b/licenses/LICENSE-heapq.txt @@ -0,0 +1,280 @@ + +# A. HISTORY OF THE SOFTWARE +# ========================== +# +# Python was created in the early 1990s by Guido van Rossum at Stichting +# Mathematisch Centrum (CWI, see http://www.cwi.nl) in the Netherlands +# as a successor of a language called ABC. Guido remains Python's +# principal author, although it includes many contributions from others. +# +# In 1995, Guido continued his work on Python at the Corporation for +# National Research Initiatives (CNRI, see http://www.cnri.reston.va.us) +# in Reston, Virginia where he released several versions of the +# software. +# +# In May 2000, Guido and the Python core development team moved to +# BeOpen.com to form the BeOpen PythonLabs team. In October of the same +# year, the PythonLabs team moved to Digital Creations (now Zope +# Corporation, see http://www.zope.com). In 2001, the Python Software +# Foundation (PSF, see http://www.python.org/psf/) was formed, a +# non-profit organization created specifically to own Python-related +# Intellectual Property. 
Zope Corporation is a sponsoring member of +# the PSF. +# +# All Python releases are Open Source (see http://www.opensource.org for +# the Open Source Definition). Historically, most, but not all, Python +# releases have also been GPL-compatible; the table below summarizes +# the various releases. +# +# Release Derived Year Owner GPL- +# from compatible? (1) +# +# 0.9.0 thru 1.2 1991-1995 CWI yes +# 1.3 thru 1.5.2 1.2 1995-1999 CNRI yes +# 1.6 1.5.2 2000 CNRI no +# 2.0 1.6 2000 BeOpen.com no +# 1.6.1 1.6 2001 CNRI yes (2) +# 2.1 2.0+1.6.1 2001 PSF no +# 2.0.1 2.0+1.6.1 2001 PSF yes +# 2.1.1 2.1+2.0.1 2001 PSF yes +# 2.2 2.1.1 2001 PSF yes +# 2.1.2 2.1.1 2002 PSF yes +# 2.1.3 2.1.2 2002 PSF yes +# 2.2.1 2.2 2002 PSF yes +# 2.2.2 2.2.1 2002 PSF yes +# 2.2.3 2.2.2 2003 PSF yes +# 2.3 2.2.2 2002-2003 PSF yes +# 2.3.1 2.3 2002-2003 PSF yes +# 2.3.2 2.3.1 2002-2003 PSF yes +# 2.3.3 2.3.2 2002-2003 PSF yes +# 2.3.4 2.3.3 2004 PSF yes +# 2.3.5 2.3.4 2005 PSF yes +# 2.4 2.3 2004 PSF yes +# 2.4.1 2.4 2005 PSF yes +# 2.4.2 2.4.1 2005 PSF yes +# 2.4.3 2.4.2 2006 PSF yes +# 2.4.4 2.4.3 2006 PSF yes +# 2.5 2.4 2006 PSF yes +# 2.5.1 2.5 2007 PSF yes +# 2.5.2 2.5.1 2008 PSF yes +# 2.5.3 2.5.2 2008 PSF yes +# 2.6 2.5 2008 PSF yes +# 2.6.1 2.6 2008 PSF yes +# 2.6.2 2.6.1 2009 PSF yes +# 2.6.3 2.6.2 2009 PSF yes +# 2.6.4 2.6.3 2009 PSF yes +# 2.6.5 2.6.4 2010 PSF yes +# 2.7 2.6 2010 PSF yes +# +# Footnotes: +# +# (1) GPL-compatible doesn't mean that we're distributing Python under +# the GPL. All Python licenses, unlike the GPL, let you distribute +# a modified version without making your changes open source. The +# GPL-compatible licenses make it possible to combine Python with +# other software that is released under the GPL; the others don't. +# +# (2) According to Richard Stallman, 1.6.1 is not GPL-compatible, +# because its license has a choice of law clause. According to +# CNRI, however, Stallman's lawyer has told CNRI's lawyer that 1.6.1 +# is "not incompatible" with the GPL. +# +# Thanks to the many outside volunteers who have worked under Guido's +# direction to make these releases possible. +# +# +# B. TERMS AND CONDITIONS FOR ACCESSING OR OTHERWISE USING PYTHON +# =============================================================== +# +# PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2 +# -------------------------------------------- +# +# 1. This LICENSE AGREEMENT is between the Python Software Foundation +# ("PSF"), and the Individual or Organization ("Licensee") accessing and +# otherwise using this software ("Python") in source or binary form and +# its associated documentation. +# +# 2. Subject to the terms and conditions of this License Agreement, PSF hereby +# grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, +# analyze, test, perform and/or display publicly, prepare derivative works, +# distribute, and otherwise use Python alone or in any derivative version, +# provided, however, that PSF's License Agreement and PSF's notice of copyright, +# i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, +# 2011, 2012, 2013 Python Software Foundation; All Rights Reserved" are retained +# in Python alone or in any derivative version prepared by Licensee. +# +# 3. 
In the event Licensee prepares a derivative work that is based on +# or incorporates Python or any part thereof, and wants to make +# the derivative work available to others as provided herein, then +# Licensee hereby agrees to include in any such work a brief summary of +# the changes made to Python. +# +# 4. PSF is making Python available to Licensee on an "AS IS" +# basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +# IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND +# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT +# INFRINGE ANY THIRD PARTY RIGHTS. +# +# 5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +# FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS +# A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, +# OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. +# +# 6. This License Agreement will automatically terminate upon a material +# breach of its terms and conditions. +# +# 7. Nothing in this License Agreement shall be deemed to create any +# relationship of agency, partnership, or joint venture between PSF and +# Licensee. This License Agreement does not grant permission to use PSF +# trademarks or trade name in a trademark sense to endorse or promote +# products or services of Licensee, or any third party. +# +# 8. By copying, installing or otherwise using Python, Licensee +# agrees to be bound by the terms and conditions of this License +# Agreement. +# +# +# BEOPEN.COM LICENSE AGREEMENT FOR PYTHON 2.0 +# ------------------------------------------- +# +# BEOPEN PYTHON OPEN SOURCE LICENSE AGREEMENT VERSION 1 +# +# 1. This LICENSE AGREEMENT is between BeOpen.com ("BeOpen"), having an +# office at 160 Saratoga Avenue, Santa Clara, CA 95051, and the +# Individual or Organization ("Licensee") accessing and otherwise using +# this software in source or binary form and its associated +# documentation ("the Software"). +# +# 2. Subject to the terms and conditions of this BeOpen Python License +# Agreement, BeOpen hereby grants Licensee a non-exclusive, +# royalty-free, world-wide license to reproduce, analyze, test, perform +# and/or display publicly, prepare derivative works, distribute, and +# otherwise use the Software alone or in any derivative version, +# provided, however, that the BeOpen Python License is retained in the +# Software, alone or in any derivative version prepared by Licensee. +# +# 3. BeOpen is making the Software available to Licensee on an "AS IS" +# basis. BEOPEN MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +# IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, BEOPEN MAKES NO AND +# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF THE SOFTWARE WILL NOT +# INFRINGE ANY THIRD PARTY RIGHTS. +# +# 4. BEOPEN SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF THE +# SOFTWARE FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS +# AS A RESULT OF USING, MODIFYING OR DISTRIBUTING THE SOFTWARE, OR ANY +# DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. +# +# 5. This License Agreement will automatically terminate upon a material +# breach of its terms and conditions. +# +# 6. This License Agreement shall be governed by and interpreted in all +# respects by the law of the State of California, excluding conflict of +# law provisions. 
Nothing in this License Agreement shall be deemed to +# create any relationship of agency, partnership, or joint venture +# between BeOpen and Licensee. This License Agreement does not grant +# permission to use BeOpen trademarks or trade names in a trademark +# sense to endorse or promote products or services of Licensee, or any +# third party. As an exception, the "BeOpen Python" logos available at +# http://www.pythonlabs.com/logos.html may be used according to the +# permissions granted on that web page. +# +# 7. By copying, installing or otherwise using the software, Licensee +# agrees to be bound by the terms and conditions of this License +# Agreement. +# +# +# CNRI LICENSE AGREEMENT FOR PYTHON 1.6.1 +# --------------------------------------- +# +# 1. This LICENSE AGREEMENT is between the Corporation for National +# Research Initiatives, having an office at 1895 Preston White Drive, +# Reston, VA 20191 ("CNRI"), and the Individual or Organization +# ("Licensee") accessing and otherwise using Python 1.6.1 software in +# source or binary form and its associated documentation. +# +# 2. Subject to the terms and conditions of this License Agreement, CNRI +# hereby grants Licensee a nonexclusive, royalty-free, world-wide +# license to reproduce, analyze, test, perform and/or display publicly, +# prepare derivative works, distribute, and otherwise use Python 1.6.1 +# alone or in any derivative version, provided, however, that CNRI's +# License Agreement and CNRI's notice of copyright, i.e., "Copyright (c) +# 1995-2001 Corporation for National Research Initiatives; All Rights +# Reserved" are retained in Python 1.6.1 alone or in any derivative +# version prepared by Licensee. Alternately, in lieu of CNRI's License +# Agreement, Licensee may substitute the following text (omitting the +# quotes): "Python 1.6.1 is made available subject to the terms and +# conditions in CNRI's License Agreement. This Agreement together with +# Python 1.6.1 may be located on the Internet using the following +# unique, persistent identifier (known as a handle): 1895.22/1013. This +# Agreement may also be obtained from a proxy server on the Internet +# using the following URL: http://hdl.handle.net/1895.22/1013". +# +# 3. In the event Licensee prepares a derivative work that is based on +# or incorporates Python 1.6.1 or any part thereof, and wants to make +# the derivative work available to others as provided herein, then +# Licensee hereby agrees to include in any such work a brief summary of +# the changes made to Python 1.6.1. +# +# 4. CNRI is making Python 1.6.1 available to Licensee on an "AS IS" +# basis. CNRI MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +# IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, CNRI MAKES NO AND +# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON 1.6.1 WILL NOT +# INFRINGE ANY THIRD PARTY RIGHTS. +# +# 5. CNRI SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +# 1.6.1 FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS +# A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON 1.6.1, +# OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. +# +# 6. This License Agreement will automatically terminate upon a material +# breach of its terms and conditions. +# +# 7. 
This License Agreement shall be governed by the federal +# intellectual property law of the United States, including without +# limitation the federal copyright law, and, to the extent such +# U.S. federal law does not apply, by the law of the Commonwealth of +# Virginia, excluding Virginia's conflict of law provisions. +# Notwithstanding the foregoing, with regard to derivative works based +# on Python 1.6.1 that incorporate non-separable material that was +# previously distributed under the GNU General Public License (GPL), the +# law of the Commonwealth of Virginia shall govern this License +# Agreement only as to issues arising under or with respect to +# Paragraphs 4, 5, and 7 of this License Agreement. Nothing in this +# License Agreement shall be deemed to create any relationship of +# agency, partnership, or joint venture between CNRI and Licensee. This +# License Agreement does not grant permission to use CNRI trademarks or +# trade name in a trademark sense to endorse or promote products or +# services of Licensee, or any third party. +# +# 8. By clicking on the "ACCEPT" button where indicated, or by copying, +# installing or otherwise using Python 1.6.1, Licensee agrees to be +# bound by the terms and conditions of this License Agreement. +# +# ACCEPT +# +# +# CWI LICENSE AGREEMENT FOR PYTHON 0.9.0 THROUGH 1.2 +# -------------------------------------------------- +# +# Copyright (c) 1991 - 1995, Stichting Mathematisch Centrum Amsterdam, +# The Netherlands. All rights reserved. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose and without fee is hereby granted, +# provided that the above copyright notice appear in all copies and that +# both that copyright notice and this permission notice appear in +# supporting documentation, and that the name of Stichting Mathematisch +# Centrum or CWI not be used in advertising or publicity pertaining to +# distribution of the software without specific, written prior +# permission. +# +# STICHTING MATHEMATISCH CENTRUM DISCLAIMS ALL WARRANTIES WITH REGARD TO +# THIS SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND +# FITNESS, IN NO EVENT SHALL STICHTING MATHEMATISCH CENTRUM BE LIABLE +# FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-javolution.txt b/licenses/LICENSE-javolution.txt new file mode 100644 index 000000000000..b64af4d8298a --- /dev/null +++ b/licenses/LICENSE-javolution.txt @@ -0,0 +1,27 @@ +/* + * Javolution - Java(tm) Solution for Real-Time and Embedded Systems + * Copyright (c) 2012, Javolution (http://javolution.org/) + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ \ No newline at end of file diff --git a/licenses/LICENSE-jbcrypt.txt b/licenses/LICENSE-jbcrypt.txt new file mode 100644 index 000000000000..d332534c0635 --- /dev/null +++ b/licenses/LICENSE-jbcrypt.txt @@ -0,0 +1,17 @@ +jBCrypt is subject to the following license: + +/* + * Copyright (c) 2006 Damien Miller + * + * Permission to use, copy, modify, and distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ diff --git a/licenses/LICENSE-jblas.txt b/licenses/LICENSE-jblas.txt new file mode 100644 index 000000000000..5629dafb65b3 --- /dev/null +++ b/licenses/LICENSE-jblas.txt @@ -0,0 +1,31 @@ +Copyright (c) 2009, Mikio L. Braun and contributors +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials provided + with the distribution. + + * Neither the name of the Technische Universität Berlin nor the + names of its contributors may be used to endorse or promote + products derived from this software without specific prior + written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses/LICENSE-jline.txt b/licenses/LICENSE-jline.txt new file mode 100644 index 000000000000..2ec539d10ac5 --- /dev/null +++ b/licenses/LICENSE-jline.txt @@ -0,0 +1,32 @@ +Copyright (c) 2002-2006, Marc Prud'hommeaux +All rights reserved. + +Redistribution and use in source and binary forms, with or +without modification, are permitted provided that the following +conditions are met: + +Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + +Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with +the distribution. + +Neither the name of JLine nor the names of its contributors +may be used to endorse or promote products derived from this +software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, +BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY +AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO +EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, +OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED +AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING +IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED +OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses/LICENSE-jpmml-model.txt b/licenses/LICENSE-jpmml-model.txt new file mode 100644 index 000000000000..69411d1c6e9a --- /dev/null +++ b/licenses/LICENSE-jpmml-model.txt @@ -0,0 +1,10 @@ +Copyright (c) 2009, University of Tartu +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses/LICENSE-jquery.txt b/licenses/LICENSE-jquery.txt new file mode 100644 index 000000000000..e1dd696d3b6c --- /dev/null +++ b/licenses/LICENSE-jquery.txt @@ -0,0 +1,9 @@ +The MIT License (MIT) + +Copyright (c) + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-junit-interface.txt b/licenses/LICENSE-junit-interface.txt new file mode 100644 index 000000000000..e835350c4e2a --- /dev/null +++ b/licenses/LICENSE-junit-interface.txt @@ -0,0 +1,24 @@ +Copyright (c) 2009-2012, Stefan Zeiger +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. 
\ No newline at end of file diff --git a/licenses/LICENSE-kryo.txt b/licenses/LICENSE-kryo.txt new file mode 100644 index 000000000000..3f6a160c238e --- /dev/null +++ b/licenses/LICENSE-kryo.txt @@ -0,0 +1,10 @@ +Copyright (c) 2008, Nathan Sweet +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + * Neither the name of Esoteric Software nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses/LICENSE-minlog.txt b/licenses/LICENSE-minlog.txt new file mode 100644 index 000000000000..3f6a160c238e --- /dev/null +++ b/licenses/LICENSE-minlog.txt @@ -0,0 +1,10 @@ +Copyright (c) 2008, Nathan Sweet +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + * Neither the name of Esoteric Software nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
\ No newline at end of file diff --git a/licenses/LICENSE-netlib.txt b/licenses/LICENSE-netlib.txt new file mode 100644 index 000000000000..75783ed6bc35 --- /dev/null +++ b/licenses/LICENSE-netlib.txt @@ -0,0 +1,49 @@ +Copyright (c) 2013 Samuel Halliday +Copyright (c) 1992-2011 The University of Tennessee and The University + of Tennessee Research Foundation. All rights + reserved. +Copyright (c) 2000-2011 The University of California Berkeley. All + rights reserved. +Copyright (c) 2006-2011 The University of Colorado Denver. All rights + reserved. + +$COPYRIGHT$ + +Additional copyrights may follow + +$HEADER$ + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +- Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +- Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer listed + in this license in the documentation and/or other materials + provided with the distribution. + +- Neither the name of the copyright holders nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +The copyright holders provide no reassurances that the source code +provided does not infringe any patent, copyright, or any other +intellectual property rights of third parties. The copyright holders +disclaim any liability to any recipient for claims brought against +recipient by any third party for infringement of that parties +intellectual property rights. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses/LICENSE-paranamer.txt b/licenses/LICENSE-paranamer.txt new file mode 100644 index 000000000000..fca18473ba03 --- /dev/null +++ b/licenses/LICENSE-paranamer.txt @@ -0,0 +1,28 @@ +[ ParaNamer used to be 'Pubic Domain', but since it includes a small piece of ASM it is now the same license as that: BSD ] + + Copyright (c) 2006 Paul Hammant & ThoughtWorks Inc + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + 2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + 3. 
Neither the name of the copyright holders nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF + THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses/LICENSE-protobuf.txt b/licenses/LICENSE-protobuf.txt new file mode 100644 index 000000000000..b4350ec83c75 --- /dev/null +++ b/licenses/LICENSE-protobuf.txt @@ -0,0 +1,42 @@ +This license applies to all parts of Protocol Buffers except the following: + + - Atomicops support for generic gcc, located in + src/google/protobuf/stubs/atomicops_internals_generic_gcc.h. + This file is copyrighted by Red Hat Inc. + + - Atomicops support for AIX/POWER, located in + src/google/protobuf/stubs/atomicops_internals_aix.h. + This file is copyrighted by Bloomberg Finance LP. + +Copyright 2014, Google Inc. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +Code generated by the Protocol Buffer compiler is owned by the owner +of the input file used when generating it. This code is not +standalone and requires a support library to be linked with it. This +support library is itself covered by the above license. 
\ No newline at end of file diff --git a/licenses/LICENSE-py4j.txt b/licenses/LICENSE-py4j.txt new file mode 100644 index 000000000000..70af3e69ed67 --- /dev/null +++ b/licenses/LICENSE-py4j.txt @@ -0,0 +1,27 @@ +Copyright (c) 2009-2011, Barthelemy Dagenais All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +- Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. + +- Redistributions in binary form must reproduce the above copyright notice, +this list of conditions and the following disclaimer in the documentation +and/or other materials provided with the distribution. + +- The name of the author may not be used to endorse or promote products +derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. + diff --git a/licenses/LICENSE-pyrolite.txt b/licenses/LICENSE-pyrolite.txt new file mode 100644 index 000000000000..9457c7aa6614 --- /dev/null +++ b/licenses/LICENSE-pyrolite.txt @@ -0,0 +1,28 @@ + +Pyro - Python Remote Objects +Software License, copyright, and disclaimer + + Pyro is Copyright (c) by Irmen de Jong (irmen@razorvine.net). + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in + all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + + +This is the "MIT Software License" which is OSI-certified, and GPL-compatible. +See http://www.opensource.org/licenses/mit-license.php + diff --git a/licenses/LICENSE-reflectasm.txt b/licenses/LICENSE-reflectasm.txt new file mode 100644 index 000000000000..3f6a160c238e --- /dev/null +++ b/licenses/LICENSE-reflectasm.txt @@ -0,0 +1,10 @@ +Copyright (c) 2008, Nathan Sweet +All rights reserved. 
+ +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + * Neither the name of Esoteric Software nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses/LICENSE-sbt-launch-lib.txt b/licenses/LICENSE-sbt-launch-lib.txt new file mode 100644 index 000000000000..3b9156baaab7 --- /dev/null +++ b/licenses/LICENSE-sbt-launch-lib.txt @@ -0,0 +1,26 @@ +// Generated from http://www.opensource.org/licenses/bsd-license.php +Copyright (c) 2011, Paul Phillips. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + * Neither the name of the author nor the names of its contributors may be + used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses/LICENSE-scala.txt b/licenses/LICENSE-scala.txt new file mode 100644 index 000000000000..4846076aba24 --- /dev/null +++ b/licenses/LICENSE-scala.txt @@ -0,0 +1,30 @@ +Copyright (c) 2002-2013 EPFL +Copyright (c) 2011-2013 Typesafe, Inc. 
+ +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +- Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + +- Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +- Neither the name of the EPFL nor the names of its contributors may be + used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. diff --git a/licenses/LICENSE-scalacheck.txt b/licenses/LICENSE-scalacheck.txt new file mode 100644 index 000000000000..cb8f97842f4c --- /dev/null +++ b/licenses/LICENSE-scalacheck.txt @@ -0,0 +1,32 @@ +ScalaCheck LICENSE + +Copyright (c) 2007-2015, Rickard Nilsson +All rights reserved. + +Permission to use, copy, modify, and distribute this software in source +or binary form for any purpose with or without fee is hereby granted, +provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + 3. Neither the name of the author nor the names of its contributors + may be used to endorse or promote products derived from this + software without specific prior written permission. + + +THIS SOFTWARE IS PROVIDED BY THE REGENTS AND CONTRIBUTORS ``AS IS'' AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF +SUCH DAMAGE. 
\ No newline at end of file diff --git a/licenses/LICENSE-scopt.txt b/licenses/LICENSE-scopt.txt new file mode 100644 index 000000000000..2bf24b9b9f84 --- /dev/null +++ b/licenses/LICENSE-scopt.txt @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-slf4j.txt b/licenses/LICENSE-slf4j.txt new file mode 100644 index 000000000000..6548cd3af432 --- /dev/null +++ b/licenses/LICENSE-slf4j.txt @@ -0,0 +1,21 @@ +Copyright (c) 2004-2013 QOS.ch + All rights reserved. + + Permission is hereby granted, free of charge, to any person obtaining + a copy of this software and associated documentation files (the + "Software"), to deal in the Software without restriction, including + without limitation the rights to use, copy, modify, merge, publish, + distribute, sublicense, and/or sell copies of the Software, and to + permit persons to whom the Software is furnished to do so, subject to + the following conditions: + + The above copyright notice and this permission notice shall be + included in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE + LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION + OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION + WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-sorttable.js.txt b/licenses/LICENSE-sorttable.js.txt new file mode 100644 index 000000000000..b31a5b206bf4 --- /dev/null +++ b/licenses/LICENSE-sorttable.js.txt @@ -0,0 +1,16 @@ +Copyright (c) 1997-2007 Stuart Langridge + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/licenses/LICENSE-spire.txt b/licenses/LICENSE-spire.txt new file mode 100644 index 000000000000..40af7746b931 --- /dev/null +++ b/licenses/LICENSE-spire.txt @@ -0,0 +1,19 @@ +Copyright (c) 2011-2012 Erik Osheim, Tom Switzer + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-xmlenc.txt b/licenses/LICENSE-xmlenc.txt new file mode 100644 index 000000000000..3a70c9bfcdad --- /dev/null +++ b/licenses/LICENSE-xmlenc.txt @@ -0,0 +1,27 @@ +Copyright 2003-2005, Ernst de Haan +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors + may be used to endorse or promote products derived from this software + without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
diff --git a/make-distribution.sh b/make-distribution.sh index 4789b0e09cc8..e64ceb802464 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -33,9 +33,9 @@ SPARK_HOME="$(cd "`dirname "$0"`"; pwd)" DISTDIR="$SPARK_HOME/dist" SPARK_TACHYON=false -TACHYON_VERSION="0.7.0" +TACHYON_VERSION="0.8.2" TACHYON_TGZ="tachyon-${TACHYON_VERSION}-bin.tar.gz" -TACHYON_URL="https://github.com/amplab/tachyon/releases/download/v${TACHYON_VERSION}/${TACHYON_TGZ}" +TACHYON_URL="http://tachyon-project.org/downloads/files/${TACHYON_VERSION}/${TACHYON_TGZ}" MAKE_TGZ=false NAME=none @@ -69,9 +69,6 @@ while (( "$#" )); do echo "Error: '--with-hive' is no longer supported, use Maven options -Phive and -Phive-thriftserver" exit_with_usage ;; - --skip-java-test) - SKIP_JAVA_TEST=true - ;; --with-tachyon) SPARK_TACHYON=true ;; @@ -121,7 +118,7 @@ if [ $(command -v git) ]; then fi -if [ ! $(command -v "$MVN") ] ; then +if [ ! "$(command -v "$MVN")" ] ; then echo -e "Could not locate Maven command: '$MVN'." echo -e "Specify the Maven command with the --mvn flag" exit -1; @@ -198,6 +195,7 @@ fi # Copy license and ASF files cp "$SPARK_HOME/LICENSE" "$DISTDIR" +cp -r "$SPARK_HOME/licenses" "$DISTDIR" cp "$SPARK_HOME/NOTICE" "$DISTDIR" if [ -e "$SPARK_HOME"/CHANGES.txt ]; then @@ -240,10 +238,10 @@ if [ "$SPARK_TACHYON" == "true" ]; then fi tar xzf "${TACHYON_TGZ}" - cp "tachyon-${TACHYON_VERSION}/core/target/tachyon-${TACHYON_VERSION}-jar-with-dependencies.jar" "$DISTDIR/lib" + cp "tachyon-${TACHYON_VERSION}/assembly/target/tachyon-assemblies-${TACHYON_VERSION}-jar-with-dependencies.jar" "$DISTDIR/lib" mkdir -p "$DISTDIR/tachyon/src/main/java/tachyon/web" cp -r "tachyon-${TACHYON_VERSION}"/{bin,conf,libexec} "$DISTDIR/tachyon" - cp -r "tachyon-${TACHYON_VERSION}"/core/src/main/java/tachyon/web "$DISTDIR/tachyon/src/main/java/tachyon/web" + cp -r "tachyon-${TACHYON_VERSION}"/servers/src/main/java/tachyon/web "$DISTDIR/tachyon/src/main/java/tachyon/web" if [[ `uname -a` == Darwin* ]]; then # need to run sed differently on osx diff --git a/mllib/pom.xml b/mllib/pom.xml index a5db14407b4f..df50aca1a3f7 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml @@ -94,16 +94,6 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - - - com.novocode - junit-interface - test - org.mockito mockito-core @@ -119,7 +109,7 @@ org.jpmml pmml-model - 1.1.15 + 1.2.7 com.sun.xml.fastinfoset @@ -131,6 +121,10 @@ + + org.apache.spark + spark-test-tags_${scala.binary.version} + diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/mllib/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 000000000000..f632dd603c44 --- /dev/null +++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1 @@ +org.apache.spark.ml.source.libsvm.DefaultSource diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index aef2c019d287..4b2b3f8489fd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -22,10 +22,16 @@ import java.{util => ju} import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer -import org.apache.spark.Logging -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import 
org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.{SparkContext, Logging} +import org.apache.spark.annotation.{Since, DeveloperApi, Experimental} import org.apache.spark.ml.param.{Param, ParamMap, Params} -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util.MLReader +import org.apache.spark.ml.util.MLWriter +import org.apache.spark.ml.util._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType @@ -82,7 +88,7 @@ abstract class PipelineStage extends Params with Logging { * an identity transformer. */ @Experimental -class Pipeline(override val uid: String) extends Estimator[PipelineModel] { +class Pipeline(override val uid: String) extends Estimator[PipelineModel] with MLWritable { def this() = this(Identifiable.randomUID("pipeline")) @@ -166,6 +172,104 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] { "Cannot have duplicate components in a pipeline.") theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur)) } + + @Since("1.6.0") + override def write: MLWriter = new Pipeline.PipelineWriter(this) +} + +@Since("1.6.0") +object Pipeline extends MLReadable[Pipeline] { + + @Since("1.6.0") + override def read: MLReader[Pipeline] = new PipelineReader + + @Since("1.6.0") + override def load(path: String): Pipeline = super.load(path) + + private[Pipeline] class PipelineWriter(instance: Pipeline) extends MLWriter { + + SharedReadWrite.validateStages(instance.getStages) + + override protected def saveImpl(path: String): Unit = + SharedReadWrite.saveImpl(instance, instance.getStages, sc, path) + } + + private class PipelineReader extends MLReader[Pipeline] { + + /** Checked against metadata when loading model */ + private val className = classOf[Pipeline].getName + + override def load(path: String): Pipeline = { + val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path) + new Pipeline(uid).setStages(stages) + } + } + + /** Methods for [[MLReader]] and [[MLWriter]] shared between [[Pipeline]] and [[PipelineModel]] */ + private[ml] object SharedReadWrite { + + import org.json4s.JsonDSL._ + + /** Check that all stages are Writable */ + def validateStages(stages: Array[PipelineStage]): Unit = { + stages.foreach { + case stage: MLWritable => // good + case other => + throw new UnsupportedOperationException("Pipeline write will fail on this Pipeline" + + s" because it contains a stage which does not implement Writable. 
Non-Writable stage:" + + s" ${other.uid} of type ${other.getClass}") + } + } + + /** + * Save metadata and stages for a [[Pipeline]] or [[PipelineModel]] + * - save metadata to path/metadata + * - save stages to stages/IDX_UID + */ + def saveImpl( + instance: Params, + stages: Array[PipelineStage], + sc: SparkContext, + path: String): Unit = { + val stageUids = stages.map(_.uid) + val jsonParams = List("stageUids" -> parse(compact(render(stageUids.toSeq)))) + DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = Some(jsonParams)) + + // Save stages + val stagesDir = new Path(path, "stages").toString + stages.zipWithIndex.foreach { case (stage: MLWritable, idx: Int) => + stage.write.save(getStagePath(stage.uid, idx, stages.length, stagesDir)) + } + } + + /** + * Load metadata and stages for a [[Pipeline]] or [[PipelineModel]] + * @return (UID, list of stages) + */ + def load( + expectedClassName: String, + sc: SparkContext, + path: String): (String, Array[PipelineStage]) = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName) + + implicit val format = DefaultFormats + val stagesDir = new Path(path, "stages").toString + val stageUids: Array[String] = (metadata.params \ "stageUids").extract[Seq[String]].toArray + val stages: Array[PipelineStage] = stageUids.zipWithIndex.map { case (stageUid, idx) => + val stagePath = SharedReadWrite.getStagePath(stageUid, idx, stageUids.length, stagesDir) + DefaultParamsReader.loadParamsInstance[PipelineStage](stagePath, sc) + } + (metadata.uid, stages) + } + + /** Get path for saving the given stage. */ + def getStagePath(stageUid: String, stageIdx: Int, numStages: Int, stagesDir: String): String = { + val stageIdxDigits = numStages.toString.length + val idxFormat = s"%0${stageIdxDigits}d" + val stageDir = idxFormat.format(stageIdx) + "_" + stageUid + new Path(stagesDir, stageDir).toString + } + } } /** @@ -176,7 +280,7 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] { class PipelineModel private[ml] ( override val uid: String, val stages: Array[Transformer]) - extends Model[PipelineModel] with Logging { + extends Model[PipelineModel] with MLWritable with Logging { /** A Java/Python-friendly auxiliary constructor. 
*/ private[ml] def this(uid: String, stages: ju.List[Transformer]) = { @@ -198,6 +302,45 @@ class PipelineModel private[ml] ( } override def copy(extra: ParamMap): PipelineModel = { - new PipelineModel(uid, stages.map(_.copy(extra))) + new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent) + } + + @Since("1.6.0") + override def write: MLWriter = new PipelineModel.PipelineModelWriter(this) +} + +@Since("1.6.0") +object PipelineModel extends MLReadable[PipelineModel] { + + import Pipeline.SharedReadWrite + + @Since("1.6.0") + override def read: MLReader[PipelineModel] = new PipelineModelReader + + @Since("1.6.0") + override def load(path: String): PipelineModel = super.load(path) + + private[PipelineModel] class PipelineModelWriter(instance: PipelineModel) extends MLWriter { + + SharedReadWrite.validateStages(instance.stages.asInstanceOf[Array[PipelineStage]]) + + override protected def saveImpl(path: String): Unit = SharedReadWrite.saveImpl(instance, + instance.stages.asInstanceOf[Array[PipelineStage]], sc, path) + } + + private class PipelineModelReader extends MLReader[PipelineModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[PipelineModel].getName + + override def load(path: String): PipelineModel = { + val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path) + val transformers = stages map { + case stage: Transformer => stage + case other => throw new RuntimeException(s"PipelineModel.read loaded a stage but found it" + + s" was not a Transformer. Bad stage ${other.uid} of type ${other.getClass}") + } + new PipelineModel(uid, transformers) + } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index 19fe039b8fd0..e0dcd427fae2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.SchemaUtils @@ -145,6 +145,10 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, /** @group setParam */ def setPredictionCol(value: String): M = set(predictionCol, value).asInstanceOf[M] + /** Returns the number of features the model was trained on. If unknown, returns -1 */ + @Since("1.6.0") + def numFeatures: Int = -1 + /** * Returns the SQL DataType corresponding to the FeaturesType type parameter. 
* diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala index 457c15830fd3..2c29eeb01a92 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala @@ -183,6 +183,8 @@ class AttributeGroup private ( sum = 37 * sum + attributes.map(_.toSeq).hashCode sum } + + override def toString: String = toMetadata.toString } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala index e479f169021d..a7c10333c0d5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala @@ -124,18 +124,28 @@ private[attribute] trait AttributeFactory { private[attribute] def fromMetadata(metadata: Metadata): Attribute /** - * Creates an [[Attribute]] from a [[StructField]] instance. + * Creates an [[Attribute]] from a [[StructField]] instance, optionally preserving name. */ - def fromStructField(field: StructField): Attribute = { + private[ml] def decodeStructField(field: StructField, preserveName: Boolean): Attribute = { require(field.dataType.isInstanceOf[NumericType]) val metadata = field.metadata val mlAttr = AttributeKeys.ML_ATTR if (metadata.contains(mlAttr)) { - fromMetadata(metadata.getMetadata(mlAttr)).withName(field.name) + val attr = fromMetadata(metadata.getMetadata(mlAttr)) + if (preserveName) { + attr + } else { + attr.withName(field.name) + } } else { UnresolvedAttribute } } + + /** + * Creates an [[Attribute]] from a [[StructField]] instance. + */ + def fromStructField(field: StructField): Attribute = decodeStructField(field, false) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 581d8fa7749b..45df557a8990 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -18,14 +18,13 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor} import org.apache.spark.ml.param.shared.HasRawPredictionCol import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{DataType, DoubleType, StructType} +import org.apache.spark.sql.types.{DataType, StructType} /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index f2b992f8ba24..8c4cec132665 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.classification -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams} import 
org.apache.spark.ml.tree.impl.RandomForest @@ -36,32 +36,46 @@ import org.apache.spark.sql.DataFrame * It supports both binary and multiclass labels, as well as both continuous and categorical * features. */ +@Since("1.4.0") @Experimental -final class DecisionTreeClassifier(override val uid: String) +final class DecisionTreeClassifier @Since("1.4.0") ( + @Since("1.4.0") override val uid: String) extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel] with DecisionTreeParams with TreeClassifierParams { + @Since("1.4.0") def this() = this(Identifiable.randomUID("dtc")) // Override parameter setters from parent trait for Java API compatibility. + @Since("1.4.0") override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + @Since("1.4.0") override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + @Since("1.4.0") override def setMinInstancesPerNode(value: Int): this.type = super.setMinInstancesPerNode(value) + @Since("1.4.0") override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + @Since("1.4.0") override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + @Since("1.4.0") override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + @Since("1.4.0") override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) + @Since("1.4.0") override def setImpurity(value: String): this.type = super.setImpurity(value) + @Since("1.6.0") + override def setSeed(value: Long): this.type = super.setSeed(value) + override protected def train(dataset: DataFrame): DecisionTreeClassificationModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) @@ -75,7 +89,7 @@ final class DecisionTreeClassifier(override val uid: String) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val strategy = getOldStrategy(categoricalFeatures, numClasses) val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all", - seed = 0L, parentUID = Some(uid)) + seed = $(seed), parentUID = Some(uid)) trees.head.asInstanceOf[DecisionTreeClassificationModel] } @@ -87,12 +101,15 @@ final class DecisionTreeClassifier(override val uid: String) subsamplingRate = 1.0) } + @Since("1.4.1") override def copy(extra: ParamMap): DecisionTreeClassifier = defaultCopy(extra) } +@Since("1.4.0") @Experimental object DecisionTreeClassifier { /** Accessor for supported impurities: entropy, gini */ + @Since("1.4.0") final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities } @@ -102,11 +119,13 @@ object DecisionTreeClassifier { * It supports both binary and multiclass labels, as well as both continuous and categorical * features. */ +@Since("1.4.0") @Experimental final class DecisionTreeClassificationModel private[ml] ( - override val uid: String, - override val rootNode: Node, - override val numClasses: Int) + @Since("1.4.0")override val uid: String, + @Since("1.4.0")override val rootNode: Node, + @Since("1.6.0")override val numFeatures: Int, + @Since("1.5.0")override val numClasses: Int) extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel] with DecisionTreeModel with Serializable { @@ -117,8 +136,8 @@ final class DecisionTreeClassificationModel private[ml] ( * Construct a decision tree classification model. * @param rootNode Root node of tree, with other nodes attached. 
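A short usage sketch for the annotated estimator above, assuming a DataFrame named training with the default "label" and "features" columns (the column contents and parameter values are illustrative):

    import org.apache.spark.ml.classification.DecisionTreeClassifier

    val dtc = new DecisionTreeClassifier()
      .setMaxDepth(5)
      .setImpurity("gini")
      .setSeed(42L)                 // now forwarded to RandomForest.run instead of the fixed 0L
    val dtModel = dtc.fit(training)
    dtModel.numFeatures             // -1 only when the feature count could not be inferred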
*/ - def this(rootNode: Node, numClasses: Int) = - this(Identifiable.randomUID("dtc"), rootNode, numClasses) + private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) = + this(Identifiable.randomUID("dtc"), rootNode, numFeatures, numClasses) override protected def predict(features: Vector): Double = { rootNode.predictImpl(features).prediction @@ -139,12 +158,15 @@ final class DecisionTreeClassificationModel private[ml] ( } } + @Since("1.4.0") override def copy(extra: ParamMap): DecisionTreeClassificationModel = { - copyValues(new DecisionTreeClassificationModel(uid, rootNode, numClasses), extra) + copyValues(new DecisionTreeClassificationModel(uid, rootNode, numFeatures, numClasses), extra) + .setParent(parent) } + @Since("1.4.0") override def toString: String = { - s"DecisionTreeClassificationModel of depth $depth with $numNodes nodes" + s"DecisionTreeClassificationModel (uid=$uid) of depth $depth with $numNodes nodes" } /** (private[ml]) Convert to a model in the old API */ @@ -159,12 +181,14 @@ private[ml] object DecisionTreeClassificationModel { def fromOld( oldModel: OldDecisionTreeModel, parent: DecisionTreeClassifier, - categoricalFeatures: Map[Int, Int]): DecisionTreeClassificationModel = { + categoricalFeatures: Map[Int, Int], + numFeatures: Int = -1): DecisionTreeClassificationModel = { require(oldModel.algo == OldAlgo.Classification, s"Cannot convert non-classification DecisionTreeModel (old API) to" + s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}") val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc") - new DecisionTreeClassificationModel(uid, rootNode, -1) + // Can't infer number of features from old model, so default to -1 + new DecisionTreeClassificationModel(uid, rootNode, numFeatures, -1) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index c3891a959926..cda2bca58c50 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.classification import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.regression.DecisionTreeRegressionModel @@ -33,7 +33,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{Row, DataFrame} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.DoubleType @@ -44,36 +44,47 @@ import org.apache.spark.sql.types.DoubleType * It supports binary labels, as well as both continuous and categorical features. * Note: Multiclass labels are not currently supported. 
*/ +@Since("1.4.0") @Experimental -final class GBTClassifier(override val uid: String) +final class GBTClassifier @Since("1.4.0") ( + @Since("1.4.0") override val uid: String) extends Predictor[Vector, GBTClassifier, GBTClassificationModel] with GBTParams with TreeClassifierParams with Logging { + @Since("1.4.0") def this() = this(Identifiable.randomUID("gbtc")) // Override parameter setters from parent trait for Java API compatibility. // Parameters from TreeClassifierParams: + @Since("1.4.0") override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + @Since("1.4.0") override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + @Since("1.4.0") override def setMinInstancesPerNode(value: Int): this.type = super.setMinInstancesPerNode(value) + @Since("1.4.0") override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + @Since("1.4.0") override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + @Since("1.4.0") override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + @Since("1.4.0") override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) /** * The impurity setting is ignored for GBT models. * Individual trees are built using impurity "Variance." */ + @Since("1.4.0") override def setImpurity(value: String): this.type = { logWarning("GBTClassifier.setImpurity should NOT be used") this @@ -81,8 +92,10 @@ final class GBTClassifier(override val uid: String) // Parameters from TreeEnsembleParams: + @Since("1.4.0") override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value) + @Since("1.4.0") override def setSeed(value: Long): this.type = { logWarning("The 'seed' parameter is currently ignored by Gradient Boosting.") super.setSeed(value) @@ -90,8 +103,10 @@ final class GBTClassifier(override val uid: String) // Parameters from GBTParams: + @Since("1.4.0") override def setMaxIter(value: Int): this.type = super.setMaxIter(value) + @Since("1.4.0") override def setStepSize(value: Double): this.type = super.setStepSize(value) // Parameters for GBTClassifier: @@ -102,6 +117,7 @@ final class GBTClassifier(override val uid: String) * (default = logistic) * @group param */ + @Since("1.4.0") val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + " tries to minimize (case-insensitive). Supported options:" + s" ${GBTClassifier.supportedLossTypes.mkString(", ")}", @@ -110,9 +126,11 @@ final class GBTClassifier(override val uid: String) setDefault(lossType -> "logistic") /** @group setParam */ + @Since("1.4.0") def setLossType(value: String): this.type = set(lossType, value) /** @group getParam */ + @Since("1.4.0") def getLossType: String = $(lossType).toLowerCase /** (private[ml]) Convert new loss to old loss. 
*/ @@ -138,19 +156,23 @@ final class GBTClassifier(override val uid: String) require(numClasses == 2, s"GBTClassifier only supports binary classification but was given numClasses = $numClasses") val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) + val numFeatures = oldDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) val oldGBT = new OldGBT(boostingStrategy) val oldModel = oldGBT.run(oldDataset) - GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures) + GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures, numFeatures) } + @Since("1.4.1") override def copy(extra: ParamMap): GBTClassifier = defaultCopy(extra) } +@Since("1.4.0") @Experimental object GBTClassifier { // The losses below should be lowercase. /** Accessor for supported loss settings: logistic */ + @Since("1.4.0") final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase) } @@ -163,11 +185,13 @@ object GBTClassifier { * @param _trees Decision trees in the ensemble. * @param _treeWeights Weights for the decision trees in the ensemble. */ +@Since("1.6.0") @Experimental -final class GBTClassificationModel( - override val uid: String, +final class GBTClassificationModel private[ml]( + @Since("1.6.0") override val uid: String, private val _trees: Array[DecisionTreeRegressionModel], - private val _treeWeights: Array[Double]) + private val _treeWeights: Array[Double], + @Since("1.6.0") override val numFeatures: Int) extends PredictionModel[Vector, GBTClassificationModel] with TreeEnsembleModel with Serializable { @@ -175,8 +199,19 @@ final class GBTClassificationModel( require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" + s" of non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).") + /** + * Construct a GBTClassificationModel + * @param _trees Decision trees in the ensemble. + * @param _treeWeights Weights for the decision trees in the ensemble. 
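The gradient-boosted counterpart is configured the same way; a hedged sketch, reusing the assumed training DataFrame and illustrative parameter values:

    import org.apache.spark.ml.classification.GBTClassifier

    val gbt = new GBTClassifier()
      .setMaxIter(20)               // number of trees in the ensemble
      .setMaxDepth(3)
      .setLossType("logistic")      // the only loss currently supported
    val gbtModel = gbt.fit(training)  // binary labels only; numFeatures is read from the first instance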
+ */ + @Since("1.6.0") + def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) = + this(uid, _trees, _treeWeights, -1) + + @Since("1.4.0") override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] + @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights override protected def transformImpl(dataset: DataFrame): DataFrame = { @@ -195,12 +230,15 @@ final class GBTClassificationModel( if (prediction > 0.0) 1.0 else 0.0 } + @Since("1.4.0") override def copy(extra: ParamMap): GBTClassificationModel = { - copyValues(new GBTClassificationModel(uid, _trees, _treeWeights), extra) + copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures), + extra).setParent(parent) } + @Since("1.4.0") override def toString: String = { - s"GBTClassificationModel with $numTrees trees" + s"GBTClassificationModel (uid=$uid) with $numTrees trees" } /** (private[ml]) Convert to a model in the old API */ @@ -215,7 +253,8 @@ private[ml] object GBTClassificationModel { def fromOld( oldModel: OldGBTModel, parent: GBTClassifier, - categoricalFeatures: Map[Int, Int]): GBTClassificationModel = { + categoricalFeatures: Map[Int, Int], + numFeatures: Int = -1): GBTClassificationModel = { require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" + s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).") val newTrees = oldModel.trees.map { tree => @@ -223,6 +262,6 @@ private[ml] object GBTClassificationModel { DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc") - new GBTClassificationModel(parent.uid, newTrees, oldModel.treeWeights) + new GBTClassificationModel(parent.uid, newTrees, oldModel.treeWeights, numFeatures) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 8fc9199fb460..486043e8d974 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -21,19 +21,22 @@ import scala.collection.mutable import breeze.linalg.{DenseVector => BDV} import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} +import org.apache.hadoop.fs.Path import org.apache.spark.{Logging, SparkException} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.BLAS._ -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.{col, lit} import org.apache.spark.storage.StorageLevel /** @@ -41,18 +44,124 @@ import org.apache.spark.storage.StorageLevel */ private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams with 
HasRegParam with HasElasticNetParam with HasMaxIter with HasFitIntercept with HasTol - with HasThreshold with HasStandardization + with HasStandardization with HasWeightCol with HasThreshold { + + /** + * Set threshold in binary classification, in range [0, 1]. + * + * If the estimated probability of class label 1 is > threshold, then predict 1, else 0. + * A high threshold encourages the model to predict 0 more often; + * a low threshold encourages the model to predict 1 more often. + * + * Note: Calling this with threshold p is equivalent to calling `setThresholds(Array(1-p, p))`. + * When [[setThreshold()]] is called, any user-set value for [[thresholds]] will be cleared. + * If both [[threshold]] and [[thresholds]] are set in a ParamMap, then they must be + * equivalent. + * + * Default is 0.5. + * @group setParam + */ + def setThreshold(value: Double): this.type = { + if (isSet(thresholds)) clear(thresholds) + set(threshold, value) + } + + /** + * Get threshold for binary classification. + * + * If [[threshold]] is set, returns that value. + * Otherwise, if [[thresholds]] is set with length 2 (i.e., binary classification), + * this returns the equivalent threshold: {{{1 / (1 + thresholds(0) / thresholds(1))}}}. + * Otherwise, returns [[threshold]] default value. + * + * @group getParam + * @throws IllegalArgumentException if [[thresholds]] is set to an array of length other than 2. + */ + override def getThreshold: Double = { + checkThresholdConsistency() + if (isSet(thresholds)) { + val ts = $(thresholds) + require(ts.length == 2, "Logistic Regression getThreshold only applies to" + + " binary classification, but thresholds has length != 2. thresholds: " + ts.mkString(",")) + 1.0 / (1.0 + ts(0) / ts(1)) + } else { + $(threshold) + } + } + + /** + * Set thresholds in multiclass (or binary) classification to adjust the probability of + * predicting each class. Array must have length equal to the number of classes, with values >= 0. + * The class with largest value p/t is predicted, where p is the original probability of that + * class and t is the class' threshold. + * + * Note: When [[setThresholds()]] is called, any user-set value for [[threshold]] will be cleared. + * If both [[threshold]] and [[thresholds]] are set in a ParamMap, then they must be + * equivalent. + * + * @group setParam + */ + def setThresholds(value: Array[Double]): this.type = { + if (isSet(threshold)) clear(threshold) + set(thresholds, value) + } + + /** + * Get thresholds for binary or multiclass classification. + * + * If [[thresholds]] is set, return its value. + * Otherwise, if [[threshold]] is set, return the equivalent thresholds for binary + * classification: (1-threshold, threshold). + * If neither are set, throw an exception. + * + * @group getParam + */ + override def getThresholds: Array[Double] = { + checkThresholdConsistency() + if (!isSet(thresholds) && isSet(threshold)) { + val t = $(threshold) + Array(1-t, t) + } else { + $(thresholds) + } + } + + /** + * If [[threshold]] and [[thresholds]] are both set, ensures they are consistent. + * @throws IllegalArgumentException if [[threshold]] and [[thresholds]] are not equivalent + */ + protected def checkThresholdConsistency(): Unit = { + if (isSet(threshold) && isSet(thresholds)) { + val ts = $(thresholds) + require(ts.length == 2, "Logistic Regression found inconsistent values for threshold and" + + s" thresholds. Param threshold is set (${$(threshold)}), indicating binary" + + s" classification, but Param thresholds is set with length ${ts.length}." 
+ + " Clear one Param value to fix this problem.") + val t = 1.0 / (1.0 + ts(0) / ts(1)) + require(math.abs($(threshold) - t) < 1E-5, "Logistic Regression getThreshold found" + + s" inconsistent values for threshold (${$(threshold)}) and thresholds (equivalent to $t)") + } + } + + override def validateParams(): Unit = { + checkThresholdConsistency() + } +} /** * :: Experimental :: * Logistic regression. - * Currently, this class only supports binary classification. + * Currently, this class only supports binary classification. It will support multiclass + * in the future. */ +@Since("1.2.0") @Experimental -class LogisticRegression(override val uid: String) +class LogisticRegression @Since("1.2.0") ( + @Since("1.4.0") override val uid: String) extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel] - with LogisticRegressionParams with Logging { + with LogisticRegressionParams with DefaultParamsWritable with Logging { + @Since("1.4.0") def this() = this(Identifiable.randomUID("logreg")) /** @@ -60,6 +169,7 @@ class LogisticRegression(override val uid: String) * Default is 0.0. * @group setParam */ + @Since("1.2.0") def setRegParam(value: Double): this.type = set(regParam, value) setDefault(regParam -> 0.0) @@ -70,6 +180,7 @@ class LogisticRegression(override val uid: String) * Default is 0.0 which is an L2 penalty. * @group setParam */ + @Since("1.4.0") def setElasticNetParam(value: Double): this.type = set(elasticNetParam, value) setDefault(elasticNetParam -> 0.0) @@ -78,6 +189,7 @@ class LogisticRegression(override val uid: String) * Default is 100. * @group setParam */ + @Since("1.2.0") def setMaxIter(value: Int): this.type = set(maxIter, value) setDefault(maxIter -> 100) @@ -87,6 +199,7 @@ class LogisticRegression(override val uid: String) * Default is 1E-6. * @group setParam */ + @Since("1.4.0") def setTol(value: Double): this.type = set(tol, value) setDefault(tol -> 1E-6) @@ -94,47 +207,69 @@ class LogisticRegression(override val uid: String) * Whether to fit an intercept term. * Default is true. * @group setParam - * */ + */ + @Since("1.4.0") def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) setDefault(fitIntercept -> true) /** * Whether to standardize the training features before fitting the model. * The coefficients of models will be always returned on the original scale, - * so it will be transparent for users. Note that when no regularization, - * with or without standardization, the models should be always converged to - * the same solution. + * so it will be transparent for users. Note that with/without standardization, + * the models should be always converged to the same solution when no regularization + * is applied. In R's GLMNET package, the default behavior is true as well. * Default is true. * @group setParam - * */ + */ + @Since("1.5.0") def setStandardization(value: Boolean): this.type = set(standardization, value) setDefault(standardization -> true) - /** @group setParam */ - def setThreshold(value: Double): this.type = set(threshold, value) - setDefault(threshold -> 0.5) + @Since("1.5.0") + override def setThreshold(value: Double): this.type = super.setThreshold(value) + + @Since("1.5.0") + override def getThreshold: Double = super.getThreshold + + /** + * Whether to over-/under-sample training instances according to the given weights in weightCol. + * If empty, all instances are treated equally (weight 1.0). + * Default is empty, so all instances have weight one. 
+ * @group setParam + */ + @Since("1.6.0") + def setWeightCol(value: String): this.type = set(weightCol, value) + setDefault(weightCol -> "") + + @Since("1.5.0") + override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value) + + @Since("1.5.0") + override def getThresholds: Array[Double] = super.getThresholds override protected def train(dataset: DataFrame): LogisticRegressionModel = { // Extract columns from data. If dataset is persisted, do not persist oldDataset. - val instances = extractLabeledPoints(dataset).map { - case LabeledPoint(label: Double, features: Vector) => (label, features) + val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) + val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).map { + case Row(label: Double, weight: Double, features: Vector) => + Instance(label, weight, features) } + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) - val (summarizer, labelSummarizer) = instances.treeAggregate( - (new MultivariateOnlineSummarizer, new MultiClassSummarizer))( - seqOp = (c, v) => (c, v) match { - case ((summarizer: MultivariateOnlineSummarizer, labelSummarizer: MultiClassSummarizer), - (label: Double, features: Vector)) => - (summarizer.add(features), labelSummarizer.add(label)) - }, - combOp = (c1, c2) => (c1, c2) match { - case ((summarizer1: MultivariateOnlineSummarizer, - classSummarizer1: MultiClassSummarizer), (summarizer2: MultivariateOnlineSummarizer, - classSummarizer2: MultiClassSummarizer)) => - (summarizer1.merge(summarizer2), classSummarizer1.merge(classSummarizer2)) - }) + val (summarizer, labelSummarizer) = { + val seqOp = (c: (MultivariateOnlineSummarizer, MultiClassSummarizer), + instance: Instance) => + (c._1.add(instance.features, instance.weight), c._2.add(instance.label, instance.weight)) + + val combOp = (c1: (MultivariateOnlineSummarizer, MultiClassSummarizer), + c2: (MultivariateOnlineSummarizer, MultiClassSummarizer)) => + (c1._1.merge(c2._1), c1._2.merge(c2._2)) + + instances.treeAggregate( + new MultivariateOnlineSummarizer, new MultiClassSummarizer)(seqOp, combOp) + } val histogram = labelSummarizer.histogram val numInvalid = labelSummarizer.countInvalid @@ -187,12 +322,12 @@ class LogisticRegression(override val uid: String) new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol)) } - val initialWeightsWithIntercept = + val initialCoefficientsWithIntercept = Vectors.zeros(if ($(fitIntercept)) numFeatures + 1 else numFeatures) if ($(fitIntercept)) { /* - For binary logistic regression, when we initialize the weights as zeros, + For binary logistic regression, when we initialize the coefficients as zeros, it will converge faster if we initialize the intercept such that it follows the distribution of the labels. 
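To make the threshold bookkeeping above concrete: with two classes, thresholds (t0, t1) is treated as equivalent to the single threshold 1 / (1 + t0/t1), and getThreshold, getThresholds, and checkThresholdConsistency convert between the two representations. A worked sketch with illustrative values:

    import org.apache.spark.ml.classification.LogisticRegression

    val lr = new LogisticRegression()

    lr.setThresholds(Array(0.4, 0.6))
    lr.getThreshold                  // 1 / (1 + 0.4 / 0.6) = 0.6; any previously set threshold is cleared

    lr.setThreshold(0.25)
    lr.getThresholds                 // Array(0.75, 0.25); thresholds is cleared and re-derived

    // New in this patch: per-instance weights read from a numeric column of the input DataFrame.
    lr.setWeightCol("weight")        // hypothetical column name; the default "" means weight 1.0 for all rows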
@@ -204,14 +339,14 @@ class LogisticRegression(override val uid: String) b = \log{P(1) / P(0)} = \log{count_1 / count_0} }}} */ - initialWeightsWithIntercept.toArray(numFeatures) - = math.log(histogram(1).toDouble / histogram(0).toDouble) + initialCoefficientsWithIntercept.toArray(numFeatures) + = math.log(histogram(1) / histogram(0)) } val states = optimizer.iterations(new CachedDiffFunction(costFun), - initialWeightsWithIntercept.toBreeze.toDenseVector) + initialCoefficientsWithIntercept.toBreeze.toDenseVector) - val (weights, intercept, objectiveHistory) = { + val (coefficients, intercept, objectiveHistory) = { /* Note that in Logistic Regression, the objective history (loss + regularization) is log-likelihood which is invariance under feature standardization. As a result, @@ -231,51 +366,81 @@ class LogisticRegression(override val uid: String) } /* - The weights are trained in the scaled space; we're converting them back to + The coefficients are trained in the scaled space; we're converting them back to the original space. Note that the intercept in scaled space and original space is the same; as a result, no scaling is needed. */ - val rawWeights = state.x.toArray.clone() + val rawCoefficients = state.x.toArray.clone() var i = 0 while (i < numFeatures) { - rawWeights(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 } + rawCoefficients(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 } i += 1 } if ($(fitIntercept)) { - (Vectors.dense(rawWeights.dropRight(1)).compressed, rawWeights.last, arrayBuilder.result()) + (Vectors.dense(rawCoefficients.dropRight(1)).compressed, rawCoefficients.last, + arrayBuilder.result()) } else { - (Vectors.dense(rawWeights).compressed, 0.0, arrayBuilder.result()) + (Vectors.dense(rawCoefficients).compressed, 0.0, arrayBuilder.result()) } } if (handlePersistence) instances.unpersist() - copyValues(new LogisticRegressionModel(uid, weights, intercept)) + val model = copyValues(new LogisticRegressionModel(uid, coefficients, intercept)) + val (summaryModel, probabilityColName) = model.findSummaryModelAndProbabilityCol() + val logRegSummary = new BinaryLogisticRegressionTrainingSummary( + summaryModel.transform(dataset), + probabilityColName, + $(labelCol), + $(featuresCol), + objectiveHistory) + model.setSummary(logRegSummary) } + @Since("1.4.0") override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra) } +@Since("1.6.0") +object LogisticRegression extends DefaultParamsReadable[LogisticRegression] { + + @Since("1.6.0") + override def load(path: String): LogisticRegression = super.load(path) +} + /** * :: Experimental :: * Model produced by [[LogisticRegression]]. 
*/ +@Since("1.4.0") @Experimental class LogisticRegressionModel private[ml] ( - override val uid: String, - val weights: Vector, - val intercept: Double) + @Since("1.4.0") override val uid: String, + @Since("1.6.0") val coefficients: Vector, + @Since("1.3.0") val intercept: Double) extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] - with LogisticRegressionParams { + with LogisticRegressionParams with MLWritable { + + @deprecated("Use coefficients instead.", "1.6.0") + def weights: Vector = coefficients + + @Since("1.5.0") + override def setThreshold(value: Double): this.type = super.setThreshold(value) + + @Since("1.5.0") + override def getThreshold: Double = super.getThreshold + + @Since("1.5.0") + override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value) - /** @group setParam */ - def setThreshold(value: Double): this.type = set(threshold, value) + @Since("1.5.0") + override def getThresholds: Array[Double] = super.getThresholds /** Margin (rawPrediction) for class label 1. For binary classification only. */ private val margin: Vector => Double = (features) => { - BLAS.dot(features, weights) + intercept + BLAS.dot(features, coefficients) + intercept } /** Score (probability) for class label 1. For binary classification only. */ @@ -284,13 +449,68 @@ class LogisticRegressionModel private[ml] ( 1.0 / (1.0 + math.exp(-m)) } + @Since("1.6.0") + override val numFeatures: Int = coefficients.size + + @Since("1.3.0") override val numClasses: Int = 2 + private var trainingSummary: Option[LogisticRegressionTrainingSummary] = None + + /** + * Gets summary of model on training set. An exception is + * thrown if `trainingSummary == None`. + */ + @Since("1.5.0") + def summary: LogisticRegressionTrainingSummary = trainingSummary match { + case Some(summ) => summ + case None => + throw new SparkException( + "No training summary available for this LogisticRegressionModel", + new NullPointerException()) + } + + /** + * If the probability column is set returns the current model and probability column, + * otherwise generates a new column and sets it as the probability column on a new copy + * of the current model. + */ + private[classification] def findSummaryModelAndProbabilityCol(): + (LogisticRegressionModel, String) = { + $(probabilityCol) match { + case "" => + val probabilityColName = "probability_" + java.util.UUID.randomUUID.toString() + (copy(ParamMap.empty).setProbabilityCol(probabilityColName), probabilityColName) + case p => (this, p) + } + } + + private[classification] def setSummary( + summary: LogisticRegressionTrainingSummary): this.type = { + this.trainingSummary = Some(summary) + this + } + + /** Indicates whether a training summary exists for this model instance. */ + @Since("1.5.0") + def hasSummary: Boolean = trainingSummary.isDefined + + /** + * Evaluates the model on a testset. + * @param dataset Test dataset to evaluate model on. + */ + // TODO: decide on a good name before exposing to public API + private[classification] def evaluate(dataset: DataFrame): LogisticRegressionSummary = { + new BinaryLogisticRegressionSummary( + this.transform(dataset), $(probabilityCol), $(labelCol), $(featuresCol)) + } + /** * Predict label for the given feature vector. - * The behavior of this can be adjusted using [[threshold]]. + * The behavior of this can be adjusted using [[thresholds]]. 
*/ override protected def predict(features: Vector): Double = { + // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden. if (score(features) > getThreshold) 1 else 0 } @@ -315,11 +535,15 @@ class LogisticRegressionModel private[ml] ( Vectors.dense(-m, m) } + @Since("1.4.0") override def copy(extra: ParamMap): LogisticRegressionModel = { - copyValues(new LogisticRegressionModel(uid, weights, intercept), extra) + val newModel = copyValues(new LogisticRegressionModel(uid, coefficients, intercept), extra) + if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) + newModel.setParent(parent) } override protected def raw2prediction(rawPrediction: Vector): Double = { + // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden. val t = getThreshold val rawThreshold = if (t == 0.0) { Double.NegativeInfinity @@ -332,10 +556,80 @@ class LogisticRegressionModel private[ml] ( } override protected def probability2prediction(probability: Vector): Double = { + // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden. if (probability(1) > getThreshold) 1 else 0 } + + /** + * Returns a [[MLWriter]] instance for this ML instance. + * + * For [[LogisticRegressionModel]], this does NOT currently save the training [[summary]]. + * An option to save [[summary]] may be added in the future. + * + * This also does not save the [[parent]] currently. + */ + @Since("1.6.0") + override def write: MLWriter = new LogisticRegressionModel.LogisticRegressionModelWriter(this) } + +@Since("1.6.0") +object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] { + + @Since("1.6.0") + override def read: MLReader[LogisticRegressionModel] = new LogisticRegressionModelReader + + @Since("1.6.0") + override def load(path: String): LogisticRegressionModel = super.load(path) + + /** [[MLWriter]] instance for [[LogisticRegressionModel]] */ + private[LogisticRegressionModel] + class LogisticRegressionModelWriter(instance: LogisticRegressionModel) + extends MLWriter with Logging { + + private case class Data( + numClasses: Int, + numFeatures: Int, + intercept: Double, + coefficients: Vector) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: numClasses, numFeatures, intercept, coefficients + val data = Data(instance.numClasses, instance.numFeatures, instance.intercept, + instance.coefficients) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class LogisticRegressionModelReader + extends MLReader[LogisticRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[LogisticRegressionModel].getName + + override def load(path: String): LogisticRegressionModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.format("parquet").load(dataPath) + .select("numClasses", "numFeatures", "intercept", "coefficients").head() + // We will need numClasses, numFeatures in the future for multinomial logreg support. 
+ // val numClasses = data.getInt(0) + // val numFeatures = data.getInt(1) + val intercept = data.getDouble(2) + val coefficients = data.getAs[Vector](3) + val model = new LogisticRegressionModel(metadata.uid, coefficients, intercept) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } +} + + /** * MultiClassSummarizer computes the number of distinct labels and corresponding counts, * and validates the data to see if the labels used for k class multi-label classification @@ -345,22 +639,29 @@ class LogisticRegressionModel private[ml] ( * corresponding joint dataset. */ private[classification] class MultiClassSummarizer extends Serializable { - private val distinctMap = new mutable.HashMap[Int, Long] + // The first element of value in distinctMap is the actually number of instances, + // and the second element of value is sum of the weights. + private val distinctMap = new mutable.HashMap[Int, (Long, Double)] private var totalInvalidCnt: Long = 0L /** * Add a new label into this MultilabelSummarizer, and update the distinct map. * @param label The label for this data point. + * @param weight The weight of this instances. * @return This MultilabelSummarizer */ - def add(label: Double): this.type = { + def add(label: Double, weight: Double = 1.0): this.type = { + require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0") + + if (weight == 0.0) return this + if (label - label.toInt != 0.0 || label < 0) { totalInvalidCnt += 1 this } else { - val counts: Long = distinctMap.getOrElse(label.toInt, 0L) - distinctMap.put(label.toInt, counts + 1) + val (counts: Long, weightSum: Double) = distinctMap.getOrElse(label.toInt, (0L, 0.0)) + distinctMap.put(label.toInt, (counts + 1L, weightSum + weight)) this } } @@ -381,8 +682,8 @@ private[classification] class MultiClassSummarizer extends Serializable { } smallMap.distinctMap.foreach { case (key, value) => - val counts = largeMap.distinctMap.getOrElse(key, 0L) - largeMap.distinctMap.put(key, counts + value) + val (counts: Long, weightSum: Double) = largeMap.distinctMap.getOrElse(key, (0L, 0.0)) + largeMap.distinctMap.put(key, (counts + value._1, weightSum + value._2)) } largeMap.totalInvalidCnt += smallMap.totalInvalidCnt largeMap @@ -394,29 +695,187 @@ private[classification] class MultiClassSummarizer extends Serializable { /** @return The number of distinct labels in the input dataset. */ def numClasses: Int = distinctMap.keySet.max + 1 - /** @return The counts of each label in the input dataset. */ - def histogram: Array[Long] = { - val result = Array.ofDim[Long](numClasses) + /** @return The weightSum of each label in the input dataset. */ + def histogram: Array[Double] = { + val result = Array.ofDim[Double](numClasses) var i = 0 val len = result.length while (i < len) { - result(i) = distinctMap.getOrElse(i, 0L) + result(i) = distinctMap.getOrElse(i, (0L, 0.0))._2 i += 1 } result } } +/** + * Abstraction for multinomial Logistic Regression Training results. + * Currently, the training summary ignores the training weights except + * for the objective trace. + */ +sealed trait LogisticRegressionTrainingSummary extends LogisticRegressionSummary { + + /** objective function (scaled loss + regularization) at each iteration. */ + def objectiveHistory: Array[Double] + + /** Number of training iterations until termination */ + def totalIterations: Int = objectiveHistory.length + +} + +/** + * Abstraction for Logistic Regression Results for a given model. 
+ */ +sealed trait LogisticRegressionSummary extends Serializable { + + /** Dataframe outputted by the model's `transform` method. */ + def predictions: DataFrame + + /** Field in "predictions" which gives the calibrated probability of each instance as a vector. */ + def probabilityCol: String + + /** Field in "predictions" which gives the true label of each instance. */ + def labelCol: String + + /** Field in "predictions" which gives the features of each instance as a vector. */ + def featuresCol: String + +} + +/** + * :: Experimental :: + * Logistic regression training results. + * @param predictions dataframe outputted by the model's `transform` method. + * @param probabilityCol field in "predictions" which gives the calibrated probability of + * each instance as a vector. + * @param labelCol field in "predictions" which gives the true label of each instance. + * @param featuresCol field in "predictions" which gives the features of each instance as a vector. + * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. + */ +@Experimental +@Since("1.5.0") +class BinaryLogisticRegressionTrainingSummary private[classification] ( + @Since("1.5.0") predictions: DataFrame, + @Since("1.5.0") probabilityCol: String, + @Since("1.5.0") labelCol: String, + @Since("1.6.0") featuresCol: String, + @Since("1.5.0") val objectiveHistory: Array[Double]) + extends BinaryLogisticRegressionSummary(predictions, probabilityCol, labelCol, featuresCol) + with LogisticRegressionTrainingSummary { + +} + +/** + * :: Experimental :: + * Binary Logistic regression results for a given model. + * @param predictions dataframe outputted by the model's `transform` method. + * @param probabilityCol field in "predictions" which gives the calibrated probability of + * each instance. + * @param labelCol field in "predictions" which gives the true label of each instance. + * @param featuresCol field in "predictions" which gives the features of each instance as a vector. + */ +@Experimental +@Since("1.5.0") +class BinaryLogisticRegressionSummary private[classification] ( + @Since("1.5.0") @transient override val predictions: DataFrame, + @Since("1.5.0") override val probabilityCol: String, + @Since("1.5.0") override val labelCol: String, + @Since("1.6.0") override val featuresCol: String) extends LogisticRegressionSummary { + + + private val sqlContext = predictions.sqlContext + import sqlContext.implicits._ + + /** + * Returns a BinaryClassificationMetrics object. + */ + // TODO: Allow the user to vary the number of bins using a setBins method in + // BinaryClassificationMetrics. For now the default is set to 100. + @transient private val binaryMetrics = new BinaryClassificationMetrics( + predictions.select(probabilityCol, labelCol).map { + case Row(score: Vector, label: Double) => (score(1), label) + }, 100 + ) + + /** + * Returns the receiver operating characteristic (ROC) curve, + * which is an Dataframe having two fields (FPR, TPR) + * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * This will change in later Spark versions. + * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic + */ + @Since("1.5.0") + @transient lazy val roc: DataFrame = binaryMetrics.roc().toDF("FPR", "TPR") + + /** + * Computes the area under the receiver operating characteristic (ROC) curve. 
+ * + * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * This will change in later Spark versions. + */ + @Since("1.5.0") + lazy val areaUnderROC: Double = binaryMetrics.areaUnderROC() + + /** + * Returns the precision-recall curve, which is an Dataframe containing + * two fields recall, precision with (0.0, 1.0) prepended to it. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * This will change in later Spark versions. + */ + @Since("1.5.0") + @transient lazy val pr: DataFrame = binaryMetrics.pr().toDF("recall", "precision") + + /** + * Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * This will change in later Spark versions. + */ + @Since("1.5.0") + @transient lazy val fMeasureByThreshold: DataFrame = { + binaryMetrics.fMeasureByThreshold().toDF("threshold", "F-Measure") + } + + /** + * Returns a dataframe with two fields (threshold, precision) curve. + * Every possible probability obtained in transforming the dataset are used + * as thresholds used in calculating the precision. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * This will change in later Spark versions. + */ + @Since("1.5.0") + @transient lazy val precisionByThreshold: DataFrame = { + binaryMetrics.precisionByThreshold().toDF("threshold", "precision") + } + + /** + * Returns a dataframe with two fields (threshold, recall) curve. + * Every possible probability obtained in transforming the dataset are used + * as thresholds used in calculating the recall. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * This will change in later Spark versions. + */ + @Since("1.5.0") + @transient lazy val recallByThreshold: DataFrame = { + binaryMetrics.recallByThreshold().toDF("threshold", "recall") + } +} + /** * LogisticAggregator computes the gradient and loss for binary logistic loss function, as used - * in binary classification for samples in sparse or dense vector in a online fashion. + * in binary classification for instances in sparse or dense vector in a online fashion. * * Note that multinomial logistic loss is not supported yet! * * Two LogisticAggregator can be merged together to have a summary of loss and gradient of * the corresponding joint dataset. * - * @param weights The weights/coefficients corresponding to the features. + * @param coefficients The coefficients corresponding to the features. * @param numClasses the number of possible outcomes for k classes classification problem in * Multinomial Logistic Regression. * @param fitIntercept Whether to fit an intercept term. @@ -424,79 +883,84 @@ private[classification] class MultiClassSummarizer extends Serializable { * @param featuresMean The mean values of the features. 
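A brief sketch of reading the training summary introduced above (the model, training data, and column names are assumptions); as noted in the docs above, these binary metrics currently ignore instance weights:

    import org.apache.spark.ml.classification.{BinaryLogisticRegressionTrainingSummary, LogisticRegression}

    val lrModel = new LogisticRegression().setMaxIter(50).fit(training)
    if (lrModel.hasSummary) {
      val summary = lrModel.summary
        .asInstanceOf[BinaryLogisticRegressionTrainingSummary]  // binary classification only for now
      println(summary.totalIterations)                          // objectiveHistory.length
      println(summary.areaUnderROC)
      summary.roc.show()                                        // DataFrame with FPR / TPR columns
    }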
*/ private class LogisticAggregator( - weights: Vector, + coefficients: Vector, numClasses: Int, fitIntercept: Boolean, featuresStd: Array[Double], featuresMean: Array[Double]) extends Serializable { - private var totalCnt: Long = 0L + private var weightSum = 0.0 private var lossSum = 0.0 - private val weightsArray = weights match { + private val coefficientsArray = coefficients match { case dv: DenseVector => dv.values case _ => throw new IllegalArgumentException( - s"weights only supports dense vector but got type ${weights.getClass}.") + s"coefficients only supports dense vector but got type ${coefficients.getClass}.") } - private val dim = if (fitIntercept) weightsArray.length - 1 else weightsArray.length + private val dim = if (fitIntercept) coefficientsArray.length - 1 else coefficientsArray.length - private val gradientSumArray = Array.ofDim[Double](weightsArray.length) + private val gradientSumArray = Array.ofDim[Double](coefficientsArray.length) /** - * Add a new training data to this LogisticAggregator, and update the loss and gradient + * Add a new training instance to this LogisticAggregator, and update the loss and gradient * of the objective function. * - * @param label The label for this data point. - * @param data The features for one data point in dense/sparse vector format to be added - * into this aggregator. + * @param instance The instance of data point to be added. * @return This LogisticAggregator object. */ - def add(label: Double, data: Vector): this.type = { - require(dim == data.size, s"Dimensions mismatch when adding new sample." + - s" Expecting $dim but got ${data.size}.") - - val localWeightsArray = weightsArray - val localGradientSumArray = gradientSumArray - - numClasses match { - case 2 => - // For Binary Logistic Regression. - val margin = - { - var sum = 0.0 - data.foreachActive { (index, value) => - if (featuresStd(index) != 0.0 && value != 0.0) { - sum += localWeightsArray(index) * (value / featuresStd(index)) + def add(instance: Instance): this.type = { + instance match { case Instance(label, weight, features) => + require(dim == features.size, s"Dimensions mismatch when adding new instance." + + s" Expecting $dim but got ${features.size}.") + require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0") + + if (weight == 0.0) return this + + val localCoefficientsArray = coefficientsArray + val localGradientSumArray = gradientSumArray + + numClasses match { + case 2 => + // For Binary Logistic Regression. 
+ val margin = - { + var sum = 0.0 + features.foreachActive { (index, value) => + if (featuresStd(index) != 0.0 && value != 0.0) { + sum += localCoefficientsArray(index) * (value / featuresStd(index)) + } + } + sum + { + if (fitIntercept) localCoefficientsArray(dim) else 0.0 } } - sum + { if (fitIntercept) localWeightsArray(dim) else 0.0 } - } - val multiplier = (1.0 / (1.0 + math.exp(margin))) - label + val multiplier = weight * (1.0 / (1.0 + math.exp(margin)) - label) - data.foreachActive { (index, value) => - if (featuresStd(index) != 0.0 && value != 0.0) { - localGradientSumArray(index) += multiplier * (value / featuresStd(index)) + features.foreachActive { (index, value) => + if (featuresStd(index) != 0.0 && value != 0.0) { + localGradientSumArray(index) += multiplier * (value / featuresStd(index)) + } } - } - if (fitIntercept) { - localGradientSumArray(dim) += multiplier - } + if (fitIntercept) { + localGradientSumArray(dim) += multiplier + } - if (label > 0) { - // The following is equivalent to log(1 + exp(margin)) but more numerically stable. - lossSum += MLUtils.log1pExp(margin) - } else { - lossSum += MLUtils.log1pExp(margin) - margin - } - case _ => - new NotImplementedError("LogisticRegression with ElasticNet in ML package only supports " + - "binary classification for now.") + if (label > 0) { + // The following is equivalent to log(1 + exp(margin)) but more numerically stable. + lossSum += weight * MLUtils.log1pExp(margin) + } else { + lossSum += weight * (MLUtils.log1pExp(margin) - margin) + } + case _ => + new NotImplementedError("LogisticRegression with ElasticNet in ML package " + + "only supports binary classification for now.") + } + weightSum += weight + this } - totalCnt += 1 - this } /** @@ -511,8 +975,8 @@ private class LogisticAggregator( require(dim == other.dim, s"Dimensions mismatch when merging with another " + s"LeastSquaresAggregator. Expecting $dim but got ${other.dim}.") - if (other.totalCnt != 0) { - totalCnt += other.totalCnt + if (other.weightSum != 0.0) { + weightSum += other.weightSum lossSum += other.lossSum var i = 0 @@ -527,13 +991,17 @@ private class LogisticAggregator( this } - def count: Long = totalCnt - - def loss: Double = lossSum / totalCnt + def loss: Double = { + require(weightSum > 0.0, s"The effective number of instances should be " + + s"greater than 0.0, but $weightSum.") + lossSum / weightSum + } def gradient: Vector = { + require(weightSum > 0.0, s"The effective number of instances should be " + + s"greater than 0.0, but $weightSum.") val result = Vectors.dense(gradientSumArray.clone()) - scal(1.0 / totalCnt, result) + scal(1.0 / weightSum, result) result } } @@ -541,11 +1009,11 @@ private class LogisticAggregator( /** * LogisticCostFun implements Breeze's DiffFunction[T] for a multinomial logistic loss function, * as used in multi-class classification (it is also used in binary logistic regression). - * It returns the loss and gradient with L2 regularization at a particular point (weights). + * It returns the loss and gradient with L2 regularization at a particular point (coefficients). * It's used in Breeze's convex optimization routines. 
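For intuition, the weighted update that `add(Instance(label, weight, features))` performs reduces to the following standalone calculation. This is a simplified sketch over plain arrays, ignoring feature standardization and the intercept; the object and method names are illustrative, not part of the patch:

object WeightedLogisticStep {
  // log(1 + exp(m)) computed in a numerically stable way (same trick as MLUtils.log1pExp).
  private def log1pExp(m: Double): Double =
    if (m > 0) m + math.log1p(math.exp(-m)) else math.log1p(math.exp(m))

  /** Returns (lossContribution, gradientContribution) for one weighted instance. */
  def step(
      coefficients: Array[Double],
      label: Double,
      weight: Double,
      features: Array[Double]): (Double, Array[Double]) = {
    // margin = -(w . x); in the aggregator each x(i) is additionally divided by featuresStd(i).
    val margin = -features.indices.map(i => coefficients(i) * features(i)).sum
    // Instance weight scales both the gradient multiplier and the loss term.
    val multiplier = weight * (1.0 / (1.0 + math.exp(margin)) - label)
    val gradient = features.map(_ * multiplier)
    val loss =
      if (label > 0) weight * log1pExp(margin)
      else weight * (log1pExp(margin) - margin)
    (loss, gradient)
  }
}

Summing these per-instance contributions and dividing by the accumulated weightSum is exactly what `loss` and `gradient` return after treeAggregate has merged the partial aggregators.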
*/ private class LogisticCostFun( - data: RDD[(Double, Vector)], + instances: RDD[Instance], numClasses: Int, fitIntercept: Boolean, standardization: Boolean, @@ -553,27 +1021,27 @@ private class LogisticCostFun( featuresMean: Array[Double], regParamL2: Double) extends DiffFunction[BDV[Double]] { - override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = { + override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = { val numFeatures = featuresStd.length - val w = Vectors.fromBreeze(weights) + val coeffs = Vectors.fromBreeze(coefficients) - val logisticAggregator = data.treeAggregate(new LogisticAggregator(w, numClasses, fitIntercept, - featuresStd, featuresMean))( - seqOp = (c, v) => (c, v) match { - case (aggregator, (label, features)) => aggregator.add(label, features) - }, - combOp = (c1, c2) => (c1, c2) match { - case (aggregator1, aggregator2) => aggregator1.merge(aggregator2) - }) + val logisticAggregator = { + val seqOp = (c: LogisticAggregator, instance: Instance) => c.add(instance) + val combOp = (c1: LogisticAggregator, c2: LogisticAggregator) => c1.merge(c2) + + instances.treeAggregate( + new LogisticAggregator(coeffs, numClasses, fitIntercept, featuresStd, featuresMean) + )(seqOp, combOp) + } val totalGradientArray = logisticAggregator.gradient.toArray - // regVal is the sum of weight squares excluding intercept for L2 regularization. + // regVal is the sum of coefficients squares excluding intercept for L2 regularization. val regVal = if (regParamL2 == 0.0) { 0.0 } else { var sum = 0.0 - w.foreachActive { (index, value) => + coeffs.foreachActive { (index, value) => // If `fitIntercept` is true, the last term which is intercept doesn't // contribute to the regularization. if (index != numFeatures) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index 8cd2103d7d5e..a691aa005ef5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -17,7 +17,9 @@ package org.apache.spark.ml.classification -import org.apache.spark.annotation.Experimental +import scala.collection.JavaConverters._ + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.shared.{HasTol, HasMaxIter, HasSeed} import org.apache.spark.ml.{PredictorParams, PredictionModel, Predictor} import org.apache.spark.ml.param.{IntParam, ParamValidators, IntArrayParam, ParamMap} @@ -32,6 +34,7 @@ private[ml] trait MultilayerPerceptronParams extends PredictorParams with HasSeed with HasMaxIter with HasTol { /** * Layer sizes including input size and output size. + * Default: Array(1, 1) * @group param */ final val layers: IntArrayParam = new IntArrayParam(this, "layers", @@ -42,9 +45,6 @@ private[ml] trait MultilayerPerceptronParams extends PredictorParams ParamValidators.arrayLengthGt(1) ) - /** @group setParam */ - def setLayers(value: Array[Int]): this.type = set(layers, value) - /** @group getParam */ final def getLayers: Array[Int] = $(layers) @@ -53,6 +53,7 @@ private[ml] trait MultilayerPerceptronParams extends PredictorParams * Data is stacked within partitions. If block size is more than remaining data in * a partition then it is adjusted to the size of this data. * Recommended size is between 10 and 1000. 
+ * Default: 128 * @group expertParam */ final val blockSize: IntParam = new IntParam(this, "blockSize", @@ -61,33 +62,9 @@ private[ml] trait MultilayerPerceptronParams extends PredictorParams "it is adjusted to the size of this data. Recommended size is between 10 and 1000", ParamValidators.gt(0)) - /** @group setParam */ - def setBlockSize(value: Int): this.type = set(blockSize, value) - /** @group getParam */ final def getBlockSize: Int = $(blockSize) - /** - * Set the maximum number of iterations. - * Default is 100. - * @group setParam - */ - def setMaxIter(value: Int): this.type = set(maxIter, value) - - /** - * Set the convergence tolerance of iterations. - * Smaller value will lead to higher accuracy with the cost of more iterations. - * Default is 1E-4. - * @group setParam - */ - def setTol(value: Double): this.type = set(tol, value) - - /** - * Set the seed for weights initialization. - * @group setParam - */ - def setSeed(value: Long): this.type = set(seed, value) - setDefault(maxIter -> 100, tol -> 1e-4, layers -> Array(1, 1), blockSize -> 128) } @@ -127,15 +104,50 @@ private object LabelConverter { * Each layer has sigmoid activation function, output layer has softmax. * Number of inputs has to be equal to the size of feature vectors. * Number of outputs has to be equal to the total number of labels. - * */ +@Since("1.5.0") @Experimental -class MultilayerPerceptronClassifier(override val uid: String) - extends Predictor[Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassifierModel] +class MultilayerPerceptronClassifier @Since("1.5.0") ( + @Since("1.5.0") override val uid: String) + extends Predictor[Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassificationModel] with MultilayerPerceptronParams { + @Since("1.5.0") def this() = this(Identifiable.randomUID("mlpc")) + /** @group setParam */ + @Since("1.5.0") + def setLayers(value: Array[Int]): this.type = set(layers, value) + + /** @group setParam */ + @Since("1.5.0") + def setBlockSize(value: Int): this.type = set(blockSize, value) + + /** + * Set the maximum number of iterations. + * Default is 100. + * @group setParam + */ + @Since("1.5.0") + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** + * Set the convergence tolerance of iterations. + * Smaller value will lead to higher accuracy with the cost of more iterations. + * Default is 1E-4. + * @group setParam + */ + @Since("1.5.0") + def setTol(value: Double): this.type = set(tol, value) + + /** + * Set the seed for weights initialization. 
+ * @group setParam + */ + @Since("1.5.0") + def setSeed(value: Long): this.type = set(seed, value) + + @Since("1.5.0") override def copy(extra: ParamMap): MultilayerPerceptronClassifier = defaultCopy(extra) /** @@ -146,7 +158,7 @@ class MultilayerPerceptronClassifier(override val uid: String) * @param dataset Training dataset * @return Fitted model */ - override protected def train(dataset: DataFrame): MultilayerPerceptronClassifierModel = { + override protected def train(dataset: DataFrame): MultilayerPerceptronClassificationModel = { val myLayers = $(layers) val labels = myLayers.last val lpData = extractLabeledPoints(dataset) @@ -156,29 +168,40 @@ class MultilayerPerceptronClassifier(override val uid: String) FeedForwardTrainer.LBFGSOptimizer.setConvergenceTol($(tol)).setNumIterations($(maxIter)) FeedForwardTrainer.setStackSize($(blockSize)) val mlpModel = FeedForwardTrainer.train(data) - new MultilayerPerceptronClassifierModel(uid, myLayers, mlpModel.weights()) + new MultilayerPerceptronClassificationModel(uid, myLayers, mlpModel.weights()) } } /** * :: Experimental :: - * Classifier model based on the Multilayer Perceptron. + * Classification model based on the Multilayer Perceptron. * Each layer has sigmoid activation function, output layer has softmax. * @param uid uid * @param layers array of layer sizes including input and output layers * @param weights vector of initial weights for the model that consists of the weights of layers * @return prediction model */ +@Since("1.5.0") @Experimental -class MultilayerPerceptronClassifierModel private[ml] ( - override val uid: String, - layers: Array[Int], - weights: Vector) - extends PredictionModel[Vector, MultilayerPerceptronClassifierModel] +class MultilayerPerceptronClassificationModel private[ml] ( + @Since("1.5.0") override val uid: String, + @Since("1.5.0") val layers: Array[Int], + @Since("1.5.0") val weights: Vector) + extends PredictionModel[Vector, MultilayerPerceptronClassificationModel] with Serializable { + @Since("1.6.0") + override val numFeatures: Int = layers.head + private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).getInstance(weights) + /** + * Returns layers in a Java List. + */ + private[ml] def javaLayers: java.util.List[Int] = { + layers.toList.asJava + } + /** * Predict label for the given features. * This internal method is used to implement [[transform()]] and output [[predictionCol]]. 
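A short training sketch for the renamed MultilayerPerceptronClassificationModel, assuming existing DataFrames `train` and `test` with "label" and "features" columns (layer sizes are illustrative):

import org.apache.spark.ml.classification.MultilayerPerceptronClassifier

// 4 input features, two hidden layers of 5 and 4 units, 3 output classes.
val mlp = new MultilayerPerceptronClassifier()
  .setLayers(Array(4, 5, 4, 3))
  .setBlockSize(128)   // stacking block size; 128 is the default
  .setMaxIter(100)
  .setSeed(1234L)

val mlpModel = mlp.fit(train)            // MultilayerPerceptronClassificationModel
val predictions = mlpModel.transform(test)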
@@ -187,7 +210,8 @@ class MultilayerPerceptronClassifierModel private[ml] ( LabelConverter.decodeLabel(mlpModel.predict(features)) } - override def copy(extra: ParamMap): MultilayerPerceptronClassifierModel = { - copyValues(new MultilayerPerceptronClassifierModel(uid, layers, weights), extra) + @Since("1.5.0") + override def copy(extra: ParamMap): MultilayerPerceptronClassificationModel = { + copyValues(new MultilayerPerceptronClassificationModel(uid, layers, weights), extra) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index b46b676204e0..718f49d3aedc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -17,10 +17,13 @@ package org.apache.spark.ml.classification +import org.apache.hadoop.fs.Path + import org.apache.spark.SparkException -import org.apache.spark.ml.{PredictorParams, PredictionModel, Predictor} -import org.apache.spark.ml.param.{ParamMap, ParamValidators, Param, DoubleParam} -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.PredictorParams +import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes} import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesModel} import org.apache.spark.mllib.linalg._ @@ -59,6 +62,7 @@ private[ml] trait NaiveBayesParams extends PredictorParams { } /** + * :: Experimental :: * Naive Bayes Classifiers. * It supports both Multinomial NB * ([[http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html]]) @@ -68,10 +72,14 @@ private[ml] trait NaiveBayesParams extends PredictorParams { * ([[http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html]]). * The input feature values must be nonnegative. */ -class NaiveBayes(override val uid: String) +@Since("1.5.0") +@Experimental +class NaiveBayes @Since("1.5.0") ( + @Since("1.5.0") override val uid: String) extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel] - with NaiveBayesParams { + with NaiveBayesParams with DefaultParamsWritable { + @Since("1.5.0") def this() = this(Identifiable.randomUID("nb")) /** @@ -79,6 +87,7 @@ class NaiveBayes(override val uid: String) * Default is 1.0. * @group setParam */ + @Since("1.5.0") def setSmoothing(value: Double): this.type = set(smoothing, value) setDefault(smoothing -> 1.0) @@ -86,7 +95,9 @@ class NaiveBayes(override val uid: String) * Set the model type using a string (case-sensitive). * Supported options: "multinomial" and "bernoulli". 
* Default is "multinomial" + * @group setParam */ + @Since("1.5.0") def setModelType(value: String): this.type = set(modelType, value) setDefault(modelType -> OldNaiveBayes.Multinomial) @@ -96,17 +107,32 @@ class NaiveBayes(override val uid: String) NaiveBayesModel.fromOld(oldModel, this) } + @Since("1.5.0") override def copy(extra: ParamMap): NaiveBayes = defaultCopy(extra) } +@Since("1.6.0") +object NaiveBayes extends DefaultParamsReadable[NaiveBayes] { + + @Since("1.6.0") + override def load(path: String): NaiveBayes = super.load(path) +} + /** + * :: Experimental :: * Model produced by [[NaiveBayes]] + * @param pi log of class priors, whose dimension is C (number of classes) + * @param theta log of class conditional probabilities, whose dimension is C (number of classes) + * by D (number of features) */ +@Since("1.5.0") +@Experimental class NaiveBayesModel private[ml] ( - override val uid: String, - val pi: Vector, - val theta: Matrix) - extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] with NaiveBayesParams { + @Since("1.5.0") override val uid: String, + @Since("1.5.0") val pi: Vector, + @Since("1.5.0") val theta: Matrix) + extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] + with NaiveBayesParams with MLWritable { import OldNaiveBayes.{Bernoulli, Multinomial} @@ -129,6 +155,10 @@ class NaiveBayesModel private[ml] ( throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") } + @Since("1.6.0") + override val numFeatures: Int = theta.numCols + + @Since("1.5.0") override val numClasses: Int = pi.size private def multinomialCalculation(features: Vector) = { @@ -185,20 +215,25 @@ class NaiveBayesModel private[ml] ( } } + @Since("1.5.0") override def copy(extra: ParamMap): NaiveBayesModel = { copyValues(new NaiveBayesModel(uid, pi, theta).setParent(this.parent), extra) } + @Since("1.5.0") override def toString: String = { - s"NaiveBayesModel with ${pi.size} classes" + s"NaiveBayesModel (uid=$uid) with ${pi.size} classes" } + @Since("1.6.0") + override def write: MLWriter = new NaiveBayesModel.NaiveBayesModelWriter(this) } -private[ml] object NaiveBayesModel { +@Since("1.6.0") +object NaiveBayesModel extends MLReadable[NaiveBayesModel] { /** Convert a model from the old API */ - def fromOld( + private[ml] def fromOld( oldModel: OldNaiveBayesModel, parent: NaiveBayes): NaiveBayesModel = { val uid = if (parent != null) parent.uid else Identifiable.randomUID("nb") @@ -208,4 +243,44 @@ private[ml] object NaiveBayesModel { oldModel.theta.flatten, true) new NaiveBayesModel(uid, pi, theta) } + + @Since("1.6.0") + override def read: MLReader[NaiveBayesModel] = new NaiveBayesModelReader + + @Since("1.6.0") + override def load(path: String): NaiveBayesModel = super.load(path) + + /** [[MLWriter]] instance for [[NaiveBayesModel]] */ + private[NaiveBayesModel] class NaiveBayesModelWriter(instance: NaiveBayesModel) extends MLWriter { + + private case class Data(pi: Vector, theta: Matrix) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: pi, theta + val data = Data(instance.pi, instance.theta) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class NaiveBayesModelReader extends MLReader[NaiveBayesModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[NaiveBayesModel].getName + + override def load(path: String): 
NaiveBayesModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath).select("pi", "theta").head() + val pi = data.getAs[Vector](0) + val theta = data.getAs[Matrix](1) + val model = new NaiveBayesModel(metadata.uid, pi, theta) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 1741f19dc911..08a51109d6c6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -21,7 +21,7 @@ import java.util.UUID import scala.language.existentials -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.{Param, ParamMap} @@ -70,17 +70,20 @@ private[ml] trait OneVsRestParams extends PredictorParams { * The i-th model is produced by testing the i-th class (taking label 1) vs the rest * (taking label 0). */ +@Since("1.4.0") @Experimental final class OneVsRestModel private[ml] ( - override val uid: String, - labelMetadata: Metadata, - val models: Array[_ <: ClassificationModel[_, _]]) + @Since("1.4.0") override val uid: String, + @Since("1.4.0") labelMetadata: Metadata, + @Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]]) extends Model[OneVsRestModel] with OneVsRestParams { + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType) } + @Since("1.4.0") override def transform(dataset: DataFrame): DataFrame = { // Check schema transformSchema(dataset.schema, logging = true) @@ -91,7 +94,6 @@ final class OneVsRestModel private[ml] ( // add an accumulator column to store predictions of all the models val accColName = "mbc$acc" + UUID.randomUUID().toString val initUDF = udf { () => Map[Int, Double]() } - val mapType = MapType(IntegerType, DoubleType, valueContainsNull = false) val newDataset = dataset.withColumn(accColName, initUDF()) // persist if underlying dataset is not persistent. @@ -131,14 +133,15 @@ final class OneVsRestModel private[ml] ( // output label and label metadata as prediction aggregatedDataset - .withColumn($(predictionCol), labelUDF(col(accColName)).as($(predictionCol), labelMetadata)) + .withColumn($(predictionCol), labelUDF(col(accColName)), labelMetadata) .drop(accColName) } + @Since("1.4.1") override def copy(extra: ParamMap): OneVsRestModel = { val copied = new OneVsRestModel( uid, labelMetadata, models.map(_.copy(extra).asInstanceOf[ClassificationModel[_, _]])) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } @@ -151,30 +154,39 @@ final class OneVsRestModel private[ml] ( * Each example is scored against all k models and the model with highest score * is picked to label the example. 
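The MLWritable/MLReadable support added to NaiveBayes above is used roughly as follows. A sketch only: `training` and the `/tmp/nbModel` path are placeholders:

import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel}

val nb = new NaiveBayes()
  .setSmoothing(1.0)
  .setModelType("multinomial")
val nbModel = nb.fit(training)

// Persists the Params metadata plus the pi/theta data (written as Parquet under <path>/data).
nbModel.write.overwrite().save("/tmp/nbModel")

// Reload through the companion object's MLReader.
val restored = NaiveBayesModel.load("/tmp/nbModel")
println(restored.pi)
println(restored.theta)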
*/ +@Since("1.4.0") @Experimental -final class OneVsRest(override val uid: String) +final class OneVsRest @Since("1.4.0") ( + @Since("1.4.0") override val uid: String) extends Estimator[OneVsRestModel] with OneVsRestParams { + @Since("1.4.0") def this() = this(Identifiable.randomUID("oneVsRest")) /** @group setParam */ + @Since("1.4.0") def setClassifier(value: Classifier[_, _, _]): this.type = { set(classifier, value.asInstanceOf[ClassifierType]) } /** @group setParam */ + @Since("1.5.0") def setLabelCol(value: String): this.type = set(labelCol, value) /** @group setParam */ + @Since("1.5.0") def setFeaturesCol(value: String): this.type = set(featuresCol, value) /** @group setParam */ + @Since("1.5.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType) } + @Since("1.4.0") override def fit(dataset: DataFrame): OneVsRestModel = { // determine number of classes either from metadata if provided, or via computation. val labelSchema = dataset.schema($(labelCol)) @@ -195,16 +207,11 @@ final class OneVsRest(override val uid: String) // create k columns, one for each binary classifier. val models = Range(0, numClasses).par.map { index => - val labelUDF = udf { (label: Double) => - if (label.toInt == index) 1.0 else 0.0 - } - // generate new label metadata for the binary problem. - // TODO: use when ... otherwise after SPARK-7321 is merged val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata() val labelColName = "mc2b$" + index - val labelUDFWithNewMeta = labelUDF(col($(labelCol))).as(labelColName, newLabelMeta) - val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta) + val trainingDataset = multiclassLabeled.withColumn( + labelColName, when(col($(labelCol)) === index.toDouble, 1.0).otherwise(0.0), newLabelMeta) val classifier = getClassifier val paramMap = new ParamMap() paramMap.put(classifier.labelCol -> labelColName) @@ -228,6 +235,7 @@ final class OneVsRest(override val uid: String) copyValues(model) } + @Since("1.4.1") override def copy(extra: ParamMap): OneVsRest = { val copied = defaultCopy(extra).asInstanceOf[OneVsRest] if (isDefined(classifier)) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index f9c9c2371f5c..fdd1851ae550 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -20,17 +20,16 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.SchemaUtils -import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, VectorUDT} +import org.apache.spark.mllib.linalg.{DenseVector, Vector, VectorUDT, Vectors} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{DoubleType, DataType, StructType} +import org.apache.spark.sql.types.{DataType, StructType} /** * (private[classification]) Params for probabilistic classification. 
*/ private[classification] trait ProbabilisticClassifierParams - extends ClassifierParams with HasProbabilityCol { - + extends ClassifierParams with HasProbabilityCol with HasThresholds { override protected def validateAndTransformSchema( schema: StructType, fitting: Boolean, @@ -51,7 +50,7 @@ private[classification] trait ProbabilisticClassifierParams * @tparam M Concrete Model type */ @DeveloperApi -private[spark] abstract class ProbabilisticClassifier[ +abstract class ProbabilisticClassifier[ FeaturesType, E <: ProbabilisticClassifier[FeaturesType, E, M], M <: ProbabilisticClassificationModel[FeaturesType, M]] @@ -59,6 +58,9 @@ private[spark] abstract class ProbabilisticClassifier[ /** @group setParam */ def setProbabilityCol(value: String): E = set(probabilityCol, value).asInstanceOf[E] + + /** @group setParam */ + def setThresholds(value: Array[Double]): E = set(thresholds, value).asInstanceOf[E] } @@ -72,7 +74,7 @@ private[spark] abstract class ProbabilisticClassifier[ * @tparam M Concrete Model type */ @DeveloperApi -private[spark] abstract class ProbabilisticClassificationModel[ +abstract class ProbabilisticClassificationModel[ FeaturesType, M <: ProbabilisticClassificationModel[FeaturesType, M]] extends ClassificationModel[FeaturesType, M] with ProbabilisticClassifierParams { @@ -80,6 +82,9 @@ private[spark] abstract class ProbabilisticClassificationModel[ /** @group setParam */ def setProbabilityCol(value: String): M = set(probabilityCol, value).asInstanceOf[M] + /** @group setParam */ + def setThresholds(value: Array[Double]): M = set(thresholds, value).asInstanceOf[M] + /** * Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by * parameters: @@ -92,6 +97,11 @@ private[spark] abstract class ProbabilisticClassificationModel[ */ override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) + if (isDefined(thresholds)) { + require($(thresholds).length == numClasses, this.getClass.getSimpleName + + ".transform() called with non-matching numClasses and thresholds.length." + + s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") + } // Output selected columns only. // This is a bit complicated since it tries to avoid repeated computation. @@ -155,6 +165,14 @@ private[spark] abstract class ProbabilisticClassificationModel[ raw2probabilityInPlace(probs) } + override protected def raw2prediction(rawPrediction: Vector): Double = { + if (!isDefined(thresholds)) { + rawPrediction.argmax + } else { + probability2prediction(raw2probability(rawPrediction)) + } + } + /** * Predict the probability of each class given the features. * These predictions are also called class conditional probabilities. @@ -170,10 +188,21 @@ private[spark] abstract class ProbabilisticClassificationModel[ /** * Given a vector of class conditional probabilities, select the predicted label. - * This may be overridden to support thresholds which favor particular labels. + * This supports thresholds which favor particular labels. 
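The thresholded prediction wired in below divides each class probability by its threshold and picks the argmax; a threshold of 0.0 maps that class to positive infinity, so it always wins. A standalone sketch of that rule, outside the Params machinery (the function name is illustrative):

// Mirror of probability2prediction when thresholds are set: predict the class whose
// probability/threshold ratio is largest.
def thresholdedPrediction(probability: Array[Double], thresholds: Array[Double]): Int = {
  require(probability.length == thresholds.length,
    "thresholds.length must match the number of classes")
  val scaled = probability.zip(thresholds).map { case (p, t) =>
    if (t == 0.0) Double.PositiveInfinity else p / t
  }
  scaled.indices.maxBy(i => scaled(i))
}

// Example: a lower threshold favors class 1.
// thresholdedPrediction(Array(0.6, 0.4), Array(0.8, 0.3)) == 1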
* @return predicted label */ - protected def probability2prediction(probability: Vector): Double = probability.argmax + protected def probability2prediction(probability: Vector): Double = { + if (!isDefined(thresholds)) { + probability.argmax + } else { + val thresholds: Array[Double] = getThresholds + val scaledProbability: Array[Double] = + probability.toArray.zip(thresholds).map { case (p, t) => + if (t == 0.0) Double.PositiveInfinity else p / t + } + Vectors.dense(scaledProbability).argmax + } + } } private[ml] object ProbabilisticClassificationModel { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 56e80cc8fe6e..d6d85ad2533a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.classification -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel} @@ -38,44 +38,59 @@ import org.apache.spark.sql.functions._ * It supports both binary and multiclass labels, as well as both continuous and categorical * features. */ +@Since("1.4.0") @Experimental -final class RandomForestClassifier(override val uid: String) +final class RandomForestClassifier @Since("1.4.0") ( + @Since("1.4.0") override val uid: String) extends ProbabilisticClassifier[Vector, RandomForestClassifier, RandomForestClassificationModel] with RandomForestParams with TreeClassifierParams { + @Since("1.4.0") def this() = this(Identifiable.randomUID("rfc")) // Override parameter setters from parent trait for Java API compatibility. 
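For reference, the builder-style setters being annotated with @Since here are used as below. Parameter values are illustrative and `train` is an assumed DataFrame with "label" and "features" columns:

import org.apache.spark.ml.classification.RandomForestClassifier

val rf = new RandomForestClassifier()
  .setNumTrees(20)
  .setMaxDepth(5)
  .setFeatureSubsetStrategy("auto")
  .setSubsamplingRate(0.8)
  .setSeed(42L)

val rfModel = rf.fit(train)   // RandomForestClassificationModel
println(rfModel)              // "RandomForestClassificationModel (uid=...) with 20 trees"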
// Parameters from TreeClassifierParams: + @Since("1.4.0") override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + @Since("1.4.0") override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + @Since("1.4.0") override def setMinInstancesPerNode(value: Int): this.type = super.setMinInstancesPerNode(value) + @Since("1.4.0") override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + @Since("1.4.0") override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + @Since("1.4.0") override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + @Since("1.4.0") override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) + @Since("1.4.0") override def setImpurity(value: String): this.type = super.setImpurity(value) // Parameters from TreeEnsembleParams: + @Since("1.4.0") override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value) + @Since("1.4.0") override def setSeed(value: Long): this.type = super.setSeed(value) // Parameters from RandomForestParams: + @Since("1.4.0") override def setNumTrees(value: Int): this.type = super.setNumTrees(value) + @Since("1.4.0") override def setFeatureSubsetStrategy(value: String): this.type = super.setFeatureSubsetStrategy(value) @@ -95,18 +110,23 @@ final class RandomForestClassifier(override val uid: String) val trees = RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed) .map(_.asInstanceOf[DecisionTreeClassificationModel]) - new RandomForestClassificationModel(trees, numClasses) + val numFeatures = oldDataset.first().features.size + new RandomForestClassificationModel(trees, numFeatures, numClasses) } + @Since("1.4.1") override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra) } +@Since("1.4.0") @Experimental object RandomForestClassifier { /** Accessor for supported impurity settings: entropy, gini */ + @Since("1.4.0") final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities /** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */ + @Since("1.4.0") final val supportedFeatureSubsetStrategies: Array[String] = RandomForestParams.supportedFeatureSubsetStrategies } @@ -119,11 +139,13 @@ object RandomForestClassifier { * @param _trees Decision trees in the ensemble. * Warning: These have null parents. */ +@Since("1.4.0") @Experimental final class RandomForestClassificationModel private[ml] ( - override val uid: String, + @Since("1.5.0") override val uid: String, private val _trees: Array[DecisionTreeClassificationModel], - override val numClasses: Int) + @Since("1.6.0") override val numFeatures: Int, + @Since("1.5.0") override val numClasses: Int) extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel] with TreeEnsembleModel with Serializable { @@ -133,14 +155,19 @@ final class RandomForestClassificationModel private[ml] ( * Construct a random forest classification model, with all trees weighted equally. 
* @param trees Component trees */ - def this(trees: Array[DecisionTreeClassificationModel], numClasses: Int) = - this(Identifiable.randomUID("rfc"), trees, numClasses) + private[ml] def this( + trees: Array[DecisionTreeClassificationModel], + numFeatures: Int, + numClasses: Int) = + this(Identifiable.randomUID("rfc"), trees, numFeatures, numClasses) + @Since("1.4.0") override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] // Note: We may add support for weights (based on tree performance) later on. private lazy val _treeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0) + @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights override protected def transformImpl(dataset: DataFrame): DataFrame = { @@ -181,14 +208,34 @@ final class RandomForestClassificationModel private[ml] ( } } + @Since("1.4.0") override def copy(extra: ParamMap): RandomForestClassificationModel = { - copyValues(new RandomForestClassificationModel(uid, _trees, numClasses), extra) + copyValues(new RandomForestClassificationModel(uid, _trees, numFeatures, numClasses), extra) + .setParent(parent) } + @Since("1.4.0") override def toString: String = { - s"RandomForestClassificationModel with $numTrees trees" + s"RandomForestClassificationModel (uid=$uid) with $numTrees trees" } + /** + * Estimate of the importance of each feature. + * + * This generalizes the idea of "Gini" importance to other losses, + * following the explanation of Gini importance from "Random Forests" documentation + * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn. + * + * This feature importance is calculated as follows: + * - Average over trees: + * - importance(feature j) = sum (over nodes which split on feature j) of the gain, + * where gain is scaled by the number of instances passing through node + * - Normalize importances for tree based on total number of training instances used + * to build tree. + * - Normalize feature importance vector to sum to 1. 
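The computation delegated to RandomForest.featureImportances just after this comment follows the recipe listed above. A toy version over pre-extracted per-tree (feature -> summed gain * instance count) maps shows the two normalization steps; this only illustrates the math, not the tree-walking code:

// Step 1: normalize each tree's gains by that tree's training instance count.
// Step 2: combine across trees and renormalize so the importances sum to 1
//         (the final renormalization makes dividing by the number of trees unnecessary).
def toyFeatureImportances(
    perTreeGains: Seq[Map[Int, Double]],
    perTreeInstanceCounts: Seq[Long],
    numFeatures: Int): Array[Double] = {
  val importances = Array.fill(numFeatures)(0.0)
  perTreeGains.zip(perTreeInstanceCounts).foreach { case (gains, count) =>
    gains.foreach { case (feature, gain) =>
      importances(feature) += gain / count
    }
  }
  val total = importances.sum
  if (total > 0.0) importances.map(_ / total) else importances
}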
+ */ + lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures) + /** (private[ml]) Convert to a model in the old API */ private[ml] def toOld: OldRandomForestModel = { new OldRandomForestModel(OldAlgo.Classification, _trees.map(_.toOld)) @@ -202,7 +249,8 @@ private[ml] object RandomForestClassificationModel { oldModel: OldRandomForestModel, parent: RandomForestClassifier, categoricalFeatures: Map[Int, Int], - numClasses: Int): RandomForestClassificationModel = { + numClasses: Int, + numFeatures: Int = -1): RandomForestClassificationModel = { require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" + s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).") val newTrees = oldModel.trees.map { tree => @@ -210,6 +258,6 @@ private[ml] object RandomForestClassificationModel { DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures) } val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc") - new RandomForestClassificationModel(uid, newTrees, numClasses) + new RandomForestClassificationModel(uid, newTrees, numFeatures, numClasses) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index dc192add6ca1..71e968497500 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -17,69 +17,48 @@ package org.apache.spark.ml.clustering -import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.param.{Param, Params, IntParam, DoubleParam, ParamMap} -import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasMaxIter, HasPredictionCol, HasSeed} -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params} +import org.apache.spark.ml.util._ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.util.Utils - /** * Common params for KMeans and KMeansModel */ -private[clustering] trait KMeansParams - extends Params with HasMaxIter with HasFeaturesCol with HasSeed with HasPredictionCol { +private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFeaturesCol + with HasSeed with HasPredictionCol with HasTol { /** * Set the number of clusters to create (k). Must be > 1. Default: 2. * @group param */ + @Since("1.5.0") final val k = new IntParam(this, "k", "number of clusters to create", (x: Int) => x > 1) /** @group getParam */ + @Since("1.5.0") def getK: Int = $(k) - /** - * Param the number of runs of the algorithm to execute in parallel. We initialize the algorithm - * this many times with random starting conditions (configured by the initialization mode), then - * return the best clustering found over any run. Must be >= 1. Default: 1. 
- * @group param - */ - final val runs = new IntParam(this, "runs", - "number of runs of the algorithm to execute in parallel", (value: Int) => value >= 1) - - /** @group getParam */ - def getRuns: Int = $(runs) - - /** - * Param the distance threshold within which we've consider centers to have converged. - * If all centers move less than this Euclidean distance, we stop iterating one run. - * Must be >= 0.0. Default: 1e-4 - * @group param - */ - final val epsilon = new DoubleParam(this, "epsilon", - "distance threshold within which we've consider centers to have converge", - (value: Double) => value >= 0.0) - - /** @group getParam */ - def getEpsilon: Double = $(epsilon) - /** * Param for the initialization algorithm. This can be either "random" to choose random points as * initial cluster centers, or "k-means||" to use a parallel variant of k-means++ * (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||. * @group expertParam */ + @Since("1.5.0") final val initMode = new Param[String](this, "initMode", "initialization algorithm", (value: String) => MLlibKMeans.validateInitMode(value)) /** @group expertGetParam */ + @Since("1.5.0") def getInitMode: String = $(initMode) /** @@ -87,10 +66,12 @@ private[clustering] trait KMeansParams * setting -- the default of 5 is almost always enough. Must be > 0. Default: 5. * @group expertParam */ + @Since("1.5.0") final val initSteps = new IntParam(this, "initSteps", "number of steps for k-means||", (value: Int) => value > 0) /** @group expertGetParam */ + @Since("1.5.0") def getInitSteps: Int = $(initSteps) /** @@ -110,78 +91,152 @@ private[clustering] trait KMeansParams * * @param parentModel a model trained by spark.mllib.clustering.KMeans. */ +@Since("1.5.0") @Experimental class KMeansModel private[ml] ( - override val uid: String, - private val parentModel: MLlibKMeansModel) extends Model[KMeansModel] with KMeansParams { + @Since("1.5.0") override val uid: String, + private val parentModel: MLlibKMeansModel) + extends Model[KMeansModel] with KMeansParams with MLWritable { + @Since("1.5.0") override def copy(extra: ParamMap): KMeansModel = { val copied = new KMeansModel(uid, parentModel) copyValues(copied, extra) } + @Since("1.5.0") override def transform(dataset: DataFrame): DataFrame = { val predictUDF = udf((vector: Vector) => predict(vector)) dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } + @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } private[clustering] def predict(features: Vector): Int = parentModel.predict(features) + @Since("1.5.0") def clusterCenters: Array[Vector] = parentModel.clusterCenters + + /** + * Return the K-means cost (sum of squared distances of points to their nearest center) for this + * model on the given data. + */ + // TODO: Replace the temp fix when we have proper evaluators defined for clustering. 
+ @Since("1.6.0") + def computeCost(dataset: DataFrame): Double = { + SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) + val data = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point } + parentModel.computeCost(data) + } + + @Since("1.6.0") + override def write: MLWriter = new KMeansModel.KMeansModelWriter(this) +} + +@Since("1.6.0") +object KMeansModel extends MLReadable[KMeansModel] { + + @Since("1.6.0") + override def read: MLReader[KMeansModel] = new KMeansModelReader + + @Since("1.6.0") + override def load(path: String): KMeansModel = super.load(path) + + /** [[MLWriter]] instance for [[KMeansModel]] */ + private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter { + + private case class Data(clusterCenters: Array[Vector]) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: cluster centers + val data = Data(instance.clusterCenters) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class KMeansModelReader extends MLReader[KMeansModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[KMeansModel].getName + + override def load(path: String): KMeansModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath).select("clusterCenters").head() + val clusterCenters = data.getAs[Seq[Vector]](0).toArray + val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters)) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } } /** * :: Experimental :: - * K-means clustering with support for multiple parallel runs and a k-means++ like initialization - * mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested, - * they are executed together with joint passes over the data for efficiency. + * K-means clustering with support for k-means|| initialization proposed by Bahmani et al. 
+ * + * @see [[http://dx.doi.org/10.14778/2180912.2180915 Bahmani et al., Scalable k-means++.]] */ +@Since("1.5.0") @Experimental -class KMeans(override val uid: String) extends Estimator[KMeansModel] with KMeansParams { +class KMeans @Since("1.5.0") ( + @Since("1.5.0") override val uid: String) + extends Estimator[KMeansModel] with KMeansParams with DefaultParamsWritable { setDefault( k -> 2, maxIter -> 20, - runs -> 1, initMode -> MLlibKMeans.K_MEANS_PARALLEL, initSteps -> 5, - epsilon -> 1e-4) + tol -> 1e-4) + @Since("1.5.0") override def copy(extra: ParamMap): KMeans = defaultCopy(extra) + @Since("1.5.0") def this() = this(Identifiable.randomUID("kmeans")) /** @group setParam */ + @Since("1.5.0") def setFeaturesCol(value: String): this.type = set(featuresCol, value) /** @group setParam */ + @Since("1.5.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) /** @group setParam */ + @Since("1.5.0") def setK(value: Int): this.type = set(k, value) /** @group expertSetParam */ + @Since("1.5.0") def setInitMode(value: String): this.type = set(initMode, value) /** @group expertSetParam */ + @Since("1.5.0") def setInitSteps(value: Int): this.type = set(initSteps, value) /** @group setParam */ + @Since("1.5.0") def setMaxIter(value: Int): this.type = set(maxIter, value) /** @group setParam */ - def setRuns(value: Int): this.type = set(runs, value) - - /** @group setParam */ - def setEpsilon(value: Double): this.type = set(epsilon, value) + @Since("1.5.0") + def setTol(value: Double): this.type = set(tol, value) /** @group setParam */ + @Since("1.5.0") def setSeed(value: Long): this.type = set(seed, value) + @Since("1.5.0") override def fit(dataset: DataFrame): KMeansModel = { val rdd = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point } @@ -191,15 +246,22 @@ class KMeans(override val uid: String) extends Estimator[KMeansModel] with KMean .setInitializationSteps($(initSteps)) .setMaxIterations($(maxIter)) .setSeed($(seed)) - .setEpsilon($(epsilon)) - .setRuns($(runs)) + .setEpsilon($(tol)) val parentModel = algo.run(rdd) val model = new KMeansModel(uid, parentModel) copyValues(model) } + @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } } +@Since("1.6.0") +object KMeans extends DefaultParamsReadable[KMeans] { + + @Since("1.6.0") + override def load(path: String): KMeans = super.load(path) +} + diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala new file mode 100644 index 000000000000..830510b1698d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -0,0 +1,811 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.clustering + +import org.apache.hadoop.fs.Path +import org.apache.spark.Logging +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasSeed, HasMaxIter} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel, + EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel, + LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel, + OnlineLDAOptimizer => OldOnlineLDAOptimizer} +import org.apache.spark.mllib.linalg.{VectorUDT, Vectors, Matrix, Vector} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{SQLContext, DataFrame, Row} +import org.apache.spark.sql.functions.{col, monotonicallyIncreasingId, udf} +import org.apache.spark.sql.types.StructType + + +private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasMaxIter + with HasSeed with HasCheckpointInterval { + + /** + * Param for the number of topics (clusters) to infer. Must be > 1. Default: 10. + * @group param + */ + @Since("1.6.0") + final val k = new IntParam(this, "k", "number of topics (clusters) to infer", + ParamValidators.gt(1)) + + /** @group getParam */ + @Since("1.6.0") + def getK: Int = $(k) + + /** + * Concentration parameter (commonly named "alpha") for the prior placed on documents' + * distributions over topics ("theta"). + * + * This is the parameter to a Dirichlet distribution, where larger values mean more smoothing + * (more regularization). + * + * If not set by the user, then docConcentration is set automatically. If set to + * singleton vector [alpha], then alpha is replicated to a vector of length k in fitting. + * Otherwise, the [[docConcentration]] vector must be length k. + * (default = automatic) + * + * Optimizer-specific parameter settings: + * - EM + * - Currently only supports symmetric distributions, so all values in the vector should be + * the same. + * - Values should be > 1.0 + * - default = uniformly (50 / k) + 1, where 50/k is common in LDA libraries and +1 follows + * from Asuncion et al. (2009), who recommend a +1 adjustment for EM. + * - Online + * - Values should be >= 0 + * - default = uniformly (1.0 / k), following the implementation from + * [[https://github.com/Blei-Lab/onlineldavb]]. + * @group param + */ + @Since("1.6.0") + final val docConcentration = new DoubleArrayParam(this, "docConcentration", + "Concentration parameter (commonly named \"alpha\") for the prior placed on documents'" + + " distributions over topics (\"theta\").", (alpha: Array[Double]) => alpha.forall(_ >= 0.0)) + + /** @group getParam */ + @Since("1.6.0") + def getDocConcentration: Array[Double] = $(docConcentration) + + /** Get docConcentration used by spark.mllib LDA */ + protected def getOldDocConcentration: Vector = { + if (isSet(docConcentration)) { + Vectors.dense(getDocConcentration) + } else { + Vectors.dense(-1.0) + } + } + + /** + * Concentration parameter (commonly named "beta" or "eta") for the prior placed on topics' + * distributions over terms. + * + * This is the parameter to a symmetric Dirichlet distribution. + * + * Note: The topics' distributions over terms are called "beta" in the original LDA paper + * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. + * + * If not set by the user, then topicConcentration is set automatically. 
+ * (default = automatic) + * + * Optimizer-specific parameter settings: + * - EM + * - Value should be > 1.0 + * - default = 0.1 + 1, where 0.1 gives a small amount of smoothing and +1 follows + * Asuncion et al. (2009), who recommend a +1 adjustment for EM. + * - Online + * - Value should be >= 0 + * - default = (1.0 / k), following the implementation from + * [[https://github.com/Blei-Lab/onlineldavb]]. + * @group param + */ + @Since("1.6.0") + final val topicConcentration = new DoubleParam(this, "topicConcentration", + "Concentration parameter (commonly named \"beta\" or \"eta\") for the prior placed on topic'" + + " distributions over terms.", ParamValidators.gtEq(0)) + + /** @group getParam */ + @Since("1.6.0") + def getTopicConcentration: Double = $(topicConcentration) + + /** Get topicConcentration used by spark.mllib LDA */ + protected def getOldTopicConcentration: Double = { + if (isSet(topicConcentration)) { + getTopicConcentration + } else { + -1.0 + } + } + + /** Supported values for Param [[optimizer]]. */ + @Since("1.6.0") + final val supportedOptimizers: Array[String] = Array("online", "em") + + /** + * Optimizer or inference algorithm used to estimate the LDA model. + * Currently supported (case-insensitive): + * - "online": Online Variational Bayes (default) + * - "em": Expectation-Maximization + * + * For details, see the following papers: + * - Online LDA: + * Hoffman, Blei and Bach. "Online Learning for Latent Dirichlet Allocation." + * Neural Information Processing Systems, 2010. + * [[http://www.cs.columbia.edu/~blei/papers/HoffmanBleiBach2010b.pdf]] + * - EM: + * Asuncion et al. "On Smoothing and Inference for Topic Models." + * Uncertainty in Artificial Intelligence, 2009. + * [[http://arxiv.org/pdf/1205.2662.pdf]] + * + * @group param + */ + @Since("1.6.0") + final val optimizer = new Param[String](this, "optimizer", "Optimizer or inference" + + " algorithm used to estimate the LDA model. Supported: " + supportedOptimizers.mkString(", "), + (o: String) => ParamValidators.inArray(supportedOptimizers).apply(o.toLowerCase)) + + /** @group getParam */ + @Since("1.6.0") + def getOptimizer: String = $(optimizer) + + /** + * Output column with estimates of the topic mixture distribution for each document (often called + * "theta" in the literature). Returns a vector of zeros for an empty document. + * + * This uses a variational approximation following Hoffman et al. (2010), where the approximate + * distribution is called "gamma." Technically, this method returns this approximation "gamma" + * for each document. + * @group param + */ + @Since("1.6.0") + final val topicDistributionCol = new Param[String](this, "topicDistribution", "Output column" + + " with estimates of the topic mixture distribution for each document (often called \"theta\"" + + " in the literature). Returns a vector of zeros for an empty document.") + + setDefault(topicDistributionCol -> "topicDistribution") + + /** @group getParam */ + @Since("1.6.0") + def getTopicDistributionCol: String = $(topicDistributionCol) + + /** + * A (positive) learning parameter that downweights early iterations. Larger values make early + * iterations count less. + * This is called "tau0" in the Online LDA paper (Hoffman et al., 2010) + * Default: 1024, following Hoffman et al. + * @group expertParam + */ + @Since("1.6.0") + final val learningOffset = new DoubleParam(this, "learningOffset", "A (positive) learning" + + " parameter that downweights early iterations. 
Larger values make early iterations count less.", + ParamValidators.gt(0)) + + /** @group expertGetParam */ + @Since("1.6.0") + def getLearningOffset: Double = $(learningOffset) + + /** + * Learning rate, set as an exponential decay rate. + * This should be between (0.5, 1.0] to guarantee asymptotic convergence. + * This is called "kappa" in the Online LDA paper (Hoffman et al., 2010). + * Default: 0.51, based on Hoffman et al. + * @group expertParam + */ + @Since("1.6.0") + final val learningDecay = new DoubleParam(this, "learningDecay", "Learning rate, set as an" + + " exponential decay rate. This should be between (0.5, 1.0] to guarantee asymptotic" + + " convergence.", ParamValidators.gt(0)) + + /** @group expertGetParam */ + @Since("1.6.0") + def getLearningDecay: Double = $(learningDecay) + + /** + * Fraction of the corpus to be sampled and used in each iteration of mini-batch gradient descent, + * in range (0, 1]. + * + * Note that this should be adjusted in synch with [[LDA.maxIter]] + * so the entire corpus is used. Specifically, set both so that + * maxIterations * miniBatchFraction >= 1. + * + * Note: This is the same as the `miniBatchFraction` parameter in + * [[org.apache.spark.mllib.clustering.OnlineLDAOptimizer]]. + * + * Default: 0.05, i.e., 5% of total documents. + * @group param + */ + @Since("1.6.0") + final val subsamplingRate = new DoubleParam(this, "subsamplingRate", "Fraction of the corpus" + + " to be sampled and used in each iteration of mini-batch gradient descent, in range (0, 1].", + ParamValidators.inRange(0.0, 1.0, lowerInclusive = false, upperInclusive = true)) + + /** @group getParam */ + @Since("1.6.0") + def getSubsamplingRate: Double = $(subsamplingRate) + + /** + * Indicates whether the docConcentration (Dirichlet parameter for + * document-topic distribution) will be optimized during training. + * Setting this to true will make the model more expressive and fit the training data better. + * Default: false + * @group expertParam + */ + @Since("1.6.0") + final val optimizeDocConcentration = new BooleanParam(this, "optimizeDocConcentration", + "Indicates whether the docConcentration (Dirichlet parameter for document-topic" + + " distribution) will be optimized during training.") + + /** @group expertGetParam */ + @Since("1.6.0") + def getOptimizeDocConcentration: Boolean = $(optimizeDocConcentration) + + /** + * Validates and transforms the input schema. + * @param schema input schema + * @return output schema + */ + protected def validateAndTransformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT) + } + + @Since("1.6.0") + override def validateParams(): Unit = { + if (isSet(docConcentration)) { + if (getDocConcentration.length != 1) { + require(getDocConcentration.length == getK, s"LDA docConcentration was of length" + + s" ${getDocConcentration.length}, but k = $getK. docConcentration must be an array of" + + s" length either 1 (scalar) or k (num topics).") + } + getOptimizer match { + case "online" => + require(getDocConcentration.forall(_ >= 0), + "For Online LDA optimizer, docConcentration values must be >= 0. Found values: " + + getDocConcentration.mkString(",")) + case "em" => + require(getDocConcentration.forall(_ >= 0), + "For EM optimizer, docConcentration values must be >= 1. 
Found values: " + + getDocConcentration.mkString(",")) + } + } + if (isSet(topicConcentration)) { + getOptimizer match { + case "online" => + require(getTopicConcentration >= 0, s"For Online LDA optimizer, topicConcentration" + + s" must be >= 0. Found value: $getTopicConcentration") + case "em" => + require(getTopicConcentration >= 0, s"For EM optimizer, topicConcentration" + + s" must be >= 1. Found value: $getTopicConcentration") + } + } + } + + private[clustering] def getOldOptimizer: OldLDAOptimizer = getOptimizer match { + case "online" => + new OldOnlineLDAOptimizer() + .setTau0($(learningOffset)) + .setKappa($(learningDecay)) + .setMiniBatchFraction($(subsamplingRate)) + .setOptimizeDocConcentration($(optimizeDocConcentration)) + case "em" => + new OldEMLDAOptimizer() + } +} + + +/** + * :: Experimental :: + * Model fitted by [[LDA]]. + * + * @param vocabSize Vocabulary size (number of terms or terms in the vocabulary) + * @param sqlContext Used to construct local DataFrames for returning query results + */ +@Since("1.6.0") +@Experimental +sealed abstract class LDAModel private[ml] ( + @Since("1.6.0") override val uid: String, + @Since("1.6.0") val vocabSize: Int, + @Since("1.6.0") @transient protected val sqlContext: SQLContext) + extends Model[LDAModel] with LDAParams with Logging with MLWritable { + + // NOTE to developers: + // This abstraction should contain all important functionality for basic LDA usage. + // Specializations of this class can contain expert-only functionality. + + /** + * Underlying spark.mllib model. + * If this model was produced by Online LDA, then this is the only model representation. + * If this model was produced by EM, then this local representation may be built lazily. + */ + @Since("1.6.0") + protected def oldLocalModel: OldLocalLDAModel + + /** Returns underlying spark.mllib model, which may be local or distributed */ + @Since("1.6.0") + protected def getModel: OldLDAModel + + /** + * The features for LDA should be a [[Vector]] representing the word counts in a document. + * The vector should be of length vocabSize, with counts for each term (word). + * @group setParam + */ + @Since("1.6.0") + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + @Since("1.6.0") + def setSeed(value: Long): this.type = set(seed, value) + + /** + * Transforms the input dataset. + * + * WARNING: If this model is an instance of [[DistributedLDAModel]] (produced when [[optimizer]] + * is set to "em"), this involves collecting a large [[topicsMatrix]] to the driver. + * This implementation may be changed in the future. + */ + @Since("1.6.0") + override def transform(dataset: DataFrame): DataFrame = { + if ($(topicDistributionCol).nonEmpty) { + val t = udf(oldLocalModel.getTopicDistributionMethod(sqlContext.sparkContext)) + dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))) + } else { + logWarning("LDAModel.transform was called without any output columns. Set an output column" + + " such as topicDistributionCol to produce results.") + dataset + } + } + + @Since("1.6.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + /** + * Value for [[docConcentration]] estimated from data. + * If Online LDA was used and [[optimizeDocConcentration]] was set to false, + * then this returns the fixed (given) value for the [[docConcentration]] parameter. 
+ */ + @Since("1.6.0") + def estimatedDocConcentration: Vector = getModel.docConcentration + + /** + * Inferred topics, where each topic is represented by a distribution over terms. + * This is a matrix of size vocabSize x k, where each column is a topic. + * No guarantees are given about the ordering of the topics. + * + * WARNING: If this model is actually a [[DistributedLDAModel]] instance produced by + * the Expectation-Maximization ("em") [[optimizer]], then this method could involve + * collecting a large amount of data to the driver (on the order of vocabSize x k). + */ + @Since("1.6.0") + def topicsMatrix: Matrix = oldLocalModel.topicsMatrix + + /** Indicates whether this instance is of type [[DistributedLDAModel]] */ + @Since("1.6.0") + def isDistributed: Boolean + + /** + * Calculates a lower bound on the log likelihood of the entire corpus. + * + * See Equation (16) in the Online LDA paper (Hoffman et al., 2010). + * + * WARNING: If this model is an instance of [[DistributedLDAModel]] (produced when [[optimizer]] + * is set to "em"), this involves collecting a large [[topicsMatrix]] to the driver. + * This implementation may be changed in the future. + * + * @param dataset test corpus to use for calculating log likelihood + * @return variational lower bound on the log likelihood of the entire corpus + */ + @Since("1.6.0") + def logLikelihood(dataset: DataFrame): Double = { + val oldDataset = LDA.getOldDataset(dataset, $(featuresCol)) + oldLocalModel.logLikelihood(oldDataset) + } + + /** + * Calculate an upper bound bound on perplexity. (Lower is better.) + * See Equation (16) in the Online LDA paper (Hoffman et al., 2010). + * + * WARNING: If this model is an instance of [[DistributedLDAModel]] (produced when [[optimizer]] + * is set to "em"), this involves collecting a large [[topicsMatrix]] to the driver. + * This implementation may be changed in the future. + * + * @param dataset test corpus to use for calculating perplexity + * @return Variational upper bound on log perplexity per token. + */ + @Since("1.6.0") + def logPerplexity(dataset: DataFrame): Double = { + val oldDataset = LDA.getOldDataset(dataset, $(featuresCol)) + oldLocalModel.logPerplexity(oldDataset) + } + + /** + * Return the topics described by their top-weighted terms. + * + * @param maxTermsPerTopic Maximum number of terms to collect for each topic. + * Default value of 10. + * @return Local DataFrame with one topic per Row, with columns: + * - "topic": IntegerType: topic index + * - "termIndices": ArrayType(IntegerType): term indices, sorted in order of decreasing + * term importance + * - "termWeights": ArrayType(DoubleType): corresponding sorted term weights + */ + @Since("1.6.0") + def describeTopics(maxTermsPerTopic: Int): DataFrame = { + val topics = getModel.describeTopics(maxTermsPerTopic).zipWithIndex.map { + case ((termIndices, termWeights), topic) => + (topic, termIndices.toSeq, termWeights.toSeq) + } + sqlContext.createDataFrame(topics).toDF("topic", "termIndices", "termWeights") + } + + @Since("1.6.0") + def describeTopics(): DataFrame = describeTopics(10) +} + + +/** + * :: Experimental :: + * + * Local (non-distributed) model fitted by [[LDA]]. + * + * This model stores the inferred topics only; it does not store info about the training dataset. 
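
A short, hypothetical sketch of querying a fitted model through the accessors defined above; `model` and `corpus` (a DataFrame with a count-vector "features" column) are assumed to exist and are not part of the patch:

val topics = model.describeTopics(5)      // columns: topic, termIndices, termWeights
topics.show(false)
println(s"variational log-likelihood lower bound: ${model.logLikelihood(corpus)}")
println(s"log-perplexity upper bound: ${model.logPerplexity(corpus)}")
println(s"topicsMatrix: ${model.topicsMatrix.numRows} x ${model.topicsMatrix.numCols}")
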
+ */ +@Since("1.6.0") +@Experimental +class LocalLDAModel private[ml] ( + uid: String, + vocabSize: Int, + @Since("1.6.0") override protected val oldLocalModel: OldLocalLDAModel, + sqlContext: SQLContext) + extends LDAModel(uid, vocabSize, sqlContext) { + + @Since("1.6.0") + override def copy(extra: ParamMap): LocalLDAModel = { + val copied = new LocalLDAModel(uid, vocabSize, oldLocalModel, sqlContext) + copyValues(copied, extra).setParent(parent).asInstanceOf[LocalLDAModel] + } + + override protected def getModel: OldLDAModel = oldLocalModel + + @Since("1.6.0") + override def isDistributed: Boolean = false + + @Since("1.6.0") + override def write: MLWriter = new LocalLDAModel.LocalLDAModelWriter(this) +} + + +@Since("1.6.0") +object LocalLDAModel extends MLReadable[LocalLDAModel] { + + private[LocalLDAModel] + class LocalLDAModelWriter(instance: LocalLDAModel) extends MLWriter { + + private case class Data( + vocabSize: Int, + topicsMatrix: Matrix, + docConcentration: Vector, + topicConcentration: Double, + gammaShape: Double) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val oldModel = instance.oldLocalModel + val data = Data(instance.vocabSize, oldModel.topicsMatrix, oldModel.docConcentration, + oldModel.topicConcentration, oldModel.gammaShape) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class LocalLDAModelReader extends MLReader[LocalLDAModel] { + + private val className = classOf[LocalLDAModel].getName + + override def load(path: String): LocalLDAModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("vocabSize", "topicsMatrix", "docConcentration", "topicConcentration", + "gammaShape") + .head() + val vocabSize = data.getAs[Int](0) + val topicsMatrix = data.getAs[Matrix](1) + val docConcentration = data.getAs[Vector](2) + val topicConcentration = data.getAs[Double](3) + val gammaShape = data.getAs[Double](4) + val oldModel = new OldLocalLDAModel(topicsMatrix, docConcentration, topicConcentration, + gammaShape) + val model = new LocalLDAModel(metadata.uid, vocabSize, oldModel, sqlContext) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[LocalLDAModel] = new LocalLDAModelReader + + @Since("1.6.0") + override def load(path: String): LocalLDAModel = super.load(path) +} + + +/** + * :: Experimental :: + * + * Distributed model fitted by [[LDA]]. + * This type of model is currently only produced by Expectation-Maximization (EM). + * + * This model stores the inferred topics, the full training dataset, and the topic distribution + * for each training document. + * + * @param oldLocalModelOption Used to implement [[oldLocalModel]] as a lazy val, but keeping + * [[copy()]] cheap. 
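
The writer/reader pair above enables a save/load round trip. A minimal sketch, assuming `localModel` is a fitted LocalLDAModel and the path is only a placeholder:

import org.apache.spark.ml.clustering.LocalLDAModel

localModel.write.save("/tmp/lda-local-model")            // placeholder path
val restored = LocalLDAModel.load("/tmp/lda-local-model")
assert(restored.vocabSize == localModel.vocabSize)
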
+ */ +@Since("1.6.0") +@Experimental +class DistributedLDAModel private[ml] ( + uid: String, + vocabSize: Int, + private val oldDistributedModel: OldDistributedLDAModel, + sqlContext: SQLContext, + private var oldLocalModelOption: Option[OldLocalLDAModel]) + extends LDAModel(uid, vocabSize, sqlContext) { + + override protected def oldLocalModel: OldLocalLDAModel = { + if (oldLocalModelOption.isEmpty) { + oldLocalModelOption = Some(oldDistributedModel.toLocal) + } + oldLocalModelOption.get + } + + override protected def getModel: OldLDAModel = oldDistributedModel + + /** + * Convert this distributed model to a local representation. This discards info about the + * training dataset. + * + * WARNING: This involves collecting a large [[topicsMatrix]] to the driver. + */ + @Since("1.6.0") + def toLocal: LocalLDAModel = new LocalLDAModel(uid, vocabSize, oldLocalModel, sqlContext) + + @Since("1.6.0") + override def copy(extra: ParamMap): DistributedLDAModel = { + val copied = + new DistributedLDAModel(uid, vocabSize, oldDistributedModel, sqlContext, oldLocalModelOption) + copyValues(copied, extra).setParent(parent) + copied + } + + @Since("1.6.0") + override def isDistributed: Boolean = true + + /** + * Log likelihood of the observed tokens in the training set, + * given the current parameter estimates: + * log P(docs | topics, topic distributions for docs, Dirichlet hyperparameters) + * + * Notes: + * - This excludes the prior; for that, use [[logPrior]]. + * - Even with [[logPrior]], this is NOT the same as the data log likelihood given the + * hyperparameters. + * - This is computed from the topic distributions computed during training. If you call + * [[logLikelihood()]] on the same training dataset, the topic distributions will be computed + * again, possibly giving different results. + */ + @Since("1.6.0") + lazy val trainingLogLikelihood: Double = oldDistributedModel.logLikelihood + + /** + * Log probability of the current parameter estimate: + * log P(topics, topic distributions for docs | Dirichlet hyperparameters) + */ + @Since("1.6.0") + lazy val logPrior: Double = oldDistributedModel.logPrior + + @Since("1.6.0") + override def write: MLWriter = new DistributedLDAModel.DistributedWriter(this) +} + + +@Since("1.6.0") +object DistributedLDAModel extends MLReadable[DistributedLDAModel] { + + private[DistributedLDAModel] + class DistributedWriter(instance: DistributedLDAModel) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val modelPath = new Path(path, "oldModel").toString + instance.oldDistributedModel.save(sc, modelPath) + } + } + + private class DistributedLDAModelReader extends MLReader[DistributedLDAModel] { + + private val className = classOf[DistributedLDAModel].getName + + override def load(path: String): DistributedLDAModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val modelPath = new Path(path, "oldModel").toString + val oldModel = OldDistributedLDAModel.load(sc, modelPath) + val model = new DistributedLDAModel( + metadata.uid, oldModel.vocabSize, oldModel, sqlContext, None) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[DistributedLDAModel] = new DistributedLDAModelReader + + @Since("1.6.0") + override def load(path: String): DistributedLDAModel = super.load(path) +} + + +/** + * :: Experimental :: + * + * Latent Dirichlet Allocation (LDA), a topic model designed for text documents. 
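
A hedged sketch of how a caller might distinguish the two model types and reach the EM-only members defined above; `fittedModel: LDAModel` is assumed to exist:

import org.apache.spark.ml.clustering.{DistributedLDAModel, LocalLDAModel}

fittedModel match {
  case dist: DistributedLDAModel =>
    println(s"training log-likelihood: ${dist.trainingLogLikelihood}, log prior: ${dist.logPrior}")
    val asLocal: LocalLDAModel = dist.toLocal   // collects topicsMatrix to the driver
    println(s"local vocabSize: ${asLocal.vocabSize}")
  case local: LocalLDAModel =>
    println("the online optimizer already produced a local model")
}
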
+ * + * Terminology: + * - "term" = "word": an element of the vocabulary + * - "token": instance of a term appearing in a document + * - "topic": multinomial distribution over terms representing some concept + * - "document": one piece of text, corresponding to one row in the input data + * + * References: + * - Original LDA paper (journal version): + * Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003. + * + * Input data (featuresCol): + * LDA is given a collection of documents as input data, via the featuresCol parameter. + * Each document is specified as a [[Vector]] of length vocabSize, where each entry is the + * count for the corresponding term (word) in the document. Feature transformers such as + * [[org.apache.spark.ml.feature.Tokenizer]] and [[org.apache.spark.ml.feature.CountVectorizer]] + * can be useful for converting text to word count vectors. + * + * @see [[http://en.wikipedia.org/wiki/Latent_Dirichlet_allocation Latent Dirichlet allocation + * (Wikipedia)]] + */ +@Since("1.6.0") +@Experimental +class LDA @Since("1.6.0") ( + @Since("1.6.0") override val uid: String) + extends Estimator[LDAModel] with LDAParams with DefaultParamsWritable { + + @Since("1.6.0") + def this() = this(Identifiable.randomUID("lda")) + + setDefault(maxIter -> 20, k -> 10, optimizer -> "online", checkpointInterval -> 10, + learningOffset -> 1024, learningDecay -> 0.51, subsamplingRate -> 0.05, + optimizeDocConcentration -> true) + + /** + * The features for LDA should be a [[Vector]] representing the word counts in a document. + * The vector should be of length vocabSize, with counts for each term (word). + * @group setParam + */ + @Since("1.6.0") + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + @Since("1.6.0") + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** @group setParam */ + @Since("1.6.0") + def setSeed(value: Long): this.type = set(seed, value) + + /** @group setParam */ + @Since("1.6.0") + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + + /** @group setParam */ + @Since("1.6.0") + def setK(value: Int): this.type = set(k, value) + + /** @group setParam */ + @Since("1.6.0") + def setDocConcentration(value: Array[Double]): this.type = set(docConcentration, value) + + /** @group setParam */ + @Since("1.6.0") + def setDocConcentration(value: Double): this.type = set(docConcentration, Array(value)) + + /** @group setParam */ + @Since("1.6.0") + def setTopicConcentration(value: Double): this.type = set(topicConcentration, value) + + /** @group setParam */ + @Since("1.6.0") + def setOptimizer(value: String): this.type = set(optimizer, value) + + /** @group setParam */ + @Since("1.6.0") + def setTopicDistributionCol(value: String): this.type = set(topicDistributionCol, value) + + /** @group expertSetParam */ + @Since("1.6.0") + def setLearningOffset(value: Double): this.type = set(learningOffset, value) + + /** @group expertSetParam */ + @Since("1.6.0") + def setLearningDecay(value: Double): this.type = set(learningDecay, value) + + /** @group setParam */ + @Since("1.6.0") + def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) + + /** @group expertSetParam */ + @Since("1.6.0") + def setOptimizeDocConcentration(value: Boolean): this.type = set(optimizeDocConcentration, value) + + @Since("1.6.0") + override def copy(extra: ParamMap): LDA = defaultCopy(extra) + + @Since("1.6.0") + override def fit(dataset: DataFrame): LDAModel = { + 
transformSchema(dataset.schema, logging = true) + val oldLDA = new OldLDA() + .setK($(k)) + .setDocConcentration(getOldDocConcentration) + .setTopicConcentration(getOldTopicConcentration) + .setMaxIterations($(maxIter)) + .setSeed($(seed)) + .setCheckpointInterval($(checkpointInterval)) + .setOptimizer(getOldOptimizer) + // TODO: persist here, or in old LDA? + val oldData = LDA.getOldDataset(dataset, $(featuresCol)) + val oldModel = oldLDA.run(oldData) + val newModel = oldModel match { + case m: OldLocalLDAModel => + new LocalLDAModel(uid, m.vocabSize, m, dataset.sqlContext) + case m: OldDistributedLDAModel => + new DistributedLDAModel(uid, m.vocabSize, m, dataset.sqlContext, None) + } + copyValues(newModel).setParent(this) + } + + @Since("1.6.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } +} + + +private[clustering] object LDA extends DefaultParamsReadable[LDA] { + + /** Get dataset for spark.mllib LDA */ + def getOldDataset(dataset: DataFrame, featuresCol: String): RDD[(Long, Vector)] = { + dataset + .withColumn("docId", monotonicallyIncreasingId()) + .select("docId", featuresCol) + .map { case Row(docId: Long, features: Vector) => + (docId, features) + } + } + + @Since("1.6.0") + override def load(path: String): LDA = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index 4a82b77f0edc..bfb70963b151 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -17,10 +17,10 @@ package org.apache.spark.ml.evaluation -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.{DataFrame, Row} @@ -28,35 +28,55 @@ import org.apache.spark.sql.types.DoubleType /** * :: Experimental :: - * Evaluator for binary classification, which expects two input columns: score and label. + * Evaluator for binary classification, which expects two input columns: rawPrediction and label. 
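
For context, a minimal end-to-end sketch of the estimator wired up in fit() above; `sqlContext`, the toy data, and the document ids are invented for illustration:

import org.apache.spark.ml.clustering.LDA
import org.apache.spark.mllib.linalg.Vectors

// Each row is one document: a term-count vector of length vocabSize (3 here).
val corpus = sqlContext.createDataFrame(Seq(
  (0L, Vectors.dense(1.0, 0.0, 3.0)),
  (1L, Vectors.dense(0.0, 2.0, 1.0))
)).toDF("id", "features")

val lda = new LDA().setK(2).setMaxIter(20)   // defaults: optimizer = "online", featuresCol = "features"
val ldaModel = lda.fit(corpus)               // LocalLDAModel for "online", DistributedLDAModel for "em"
ldaModel.transform(corpus).select("topicDistribution").show(false)
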
*/ +@Since("1.2.0") @Experimental -class BinaryClassificationEvaluator(override val uid: String) - extends Evaluator with HasRawPredictionCol with HasLabelCol { +class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String) + extends Evaluator with HasRawPredictionCol with HasLabelCol with DefaultParamsWritable { + @Since("1.2.0") def this() = this(Identifiable.randomUID("binEval")) /** * param for metric name in evaluation + * Default: areaUnderROC * @group param */ - val metricName: Param[String] = new Param(this, "metricName", - "metric name in evaluation (areaUnderROC|areaUnderPR)") + @Since("1.2.0") + val metricName: Param[String] = { + val allowedParams = ParamValidators.inArray(Array("areaUnderROC", "areaUnderPR")) + new Param( + this, "metricName", "metric name in evaluation (areaUnderROC|areaUnderPR)", allowedParams) + } /** @group getParam */ + @Since("1.2.0") def getMetricName: String = $(metricName) /** @group setParam */ + @Since("1.2.0") def setMetricName(value: String): this.type = set(metricName, value) /** @group setParam */ + @Since("1.5.0") + def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value) + + /** + * @group setParam + * @deprecated use [[setRawPredictionCol()]] instead + */ + @deprecated("use setRawPredictionCol instead", "1.5.0") + @Since("1.2.0") def setScoreCol(value: String): this.type = set(rawPredictionCol, value) /** @group setParam */ + @Since("1.2.0") def setLabelCol(value: String): this.type = set(labelCol, value) setDefault(metricName -> "areaUnderROC") + @Since("1.2.0") override def evaluate(dataset: DataFrame): Double = { val schema = dataset.schema SchemaUtils.checkColumnType(schema, $(rawPredictionCol), new VectorUDT) @@ -69,16 +89,26 @@ class BinaryClassificationEvaluator(override val uid: String) } val metrics = new BinaryClassificationMetrics(scoreAndLabels) val metric = $(metricName) match { - case "areaUnderROC" => - metrics.areaUnderROC() - case "areaUnderPR" => - metrics.areaUnderPR() - case other => - throw new IllegalArgumentException(s"Does not support metric $other.") + case "areaUnderROC" => metrics.areaUnderROC() + case "areaUnderPR" => metrics.areaUnderPR() } metrics.unpersist() metric } + @Since("1.5.0") + override def isLargerBetter: Boolean = $(metricName) match { + case "areaUnderROC" => true + case "areaUnderPR" => true + } + + @Since("1.4.1") override def copy(extra: ParamMap): BinaryClassificationEvaluator = defaultCopy(extra) } + +@Since("1.6.0") +object BinaryClassificationEvaluator extends DefaultParamsReadable[BinaryClassificationEvaluator] { + + @Since("1.6.0") + override def load(path: String): BinaryClassificationEvaluator = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala index e56c946a063e..0f22cca3a78d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.evaluation -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.ml.param.{ParamMap, Params} import org.apache.spark.sql.DataFrame @@ -25,6 +25,7 @@ import org.apache.spark.sql.DataFrame * :: DeveloperApi :: * Abstract class for evaluators that compute metrics from predictions. 
*/ +@Since("1.5.0") @DeveloperApi abstract class Evaluator extends Params { @@ -35,6 +36,7 @@ abstract class Evaluator extends Params { * @param paramMap parameter map that specifies the input columns and output metrics * @return metric */ + @Since("1.5.0") def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = { this.copy(paramMap).evaluate(dataset) } @@ -44,7 +46,17 @@ abstract class Evaluator extends Params { * @param dataset a dataset that contains labels/observations and predictions. * @return metric */ + @Since("1.5.0") def evaluate(dataset: DataFrame): Double + /** + * Indicates whether the metric returned by [[evaluate()]] should be maximized (true, default) + * or minimized (false). + * A given evaluator may support multiple metrics which may be maximized or minimized. + */ + @Since("1.5.0") + def isLargerBetter: Boolean = true + + @Since("1.5.0") override def copy(extra: ParamMap): Evaluator } diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala index 44f779c1908d..c44db0ec595e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -17,10 +17,10 @@ package org.apache.spark.ml.evaluation -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.{ParamMap, ParamValidators, Param} import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} -import org.apache.spark.ml.util.{SchemaUtils, Identifiable} +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, SchemaUtils, Identifiable} import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.sql.{Row, DataFrame} import org.apache.spark.sql.types.DoubleType @@ -29,10 +29,12 @@ import org.apache.spark.sql.types.DoubleType * :: Experimental :: * Evaluator for multiclass classification, which expects two input columns: score and label. 
*/ +@Since("1.5.0") @Experimental -class MulticlassClassificationEvaluator (override val uid: String) - extends Evaluator with HasPredictionCol with HasLabelCol { +class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") override val uid: String) + extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable { + @Since("1.5.0") def this() = this(Identifiable.randomUID("mcEval")) /** @@ -40,6 +42,7 @@ class MulticlassClassificationEvaluator (override val uid: String) * `"weightedPrecision"`, `"weightedRecall"`) * @group param */ + @Since("1.5.0") val metricName: Param[String] = { val allowedParams = ParamValidators.inArray(Array("f1", "precision", "recall", "weightedPrecision", "weightedRecall")) @@ -48,19 +51,24 @@ class MulticlassClassificationEvaluator (override val uid: String) } /** @group getParam */ + @Since("1.5.0") def getMetricName: String = $(metricName) /** @group setParam */ + @Since("1.5.0") def setMetricName(value: String): this.type = set(metricName, value) /** @group setParam */ + @Since("1.5.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) /** @group setParam */ + @Since("1.5.0") def setLabelCol(value: String): this.type = set(labelCol, value) setDefault(metricName -> "f1") + @Since("1.5.0") override def evaluate(dataset: DataFrame): Double = { val schema = dataset.schema SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType) @@ -81,5 +89,23 @@ class MulticlassClassificationEvaluator (override val uid: String) metric } + @Since("1.5.0") + override def isLargerBetter: Boolean = $(metricName) match { + case "f1" => true + case "precision" => true + case "recall" => true + case "weightedPrecision" => true + case "weightedRecall" => true + } + + @Since("1.5.0") override def copy(extra: ParamMap): MulticlassClassificationEvaluator = defaultCopy(extra) } + +@Since("1.6.0") +object MulticlassClassificationEvaluator + extends DefaultParamsReadable[MulticlassClassificationEvaluator] { + + @Since("1.6.0") + override def load(path: String): MulticlassClassificationEvaluator = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala index 01c000b47514..b6b25ecd01b3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -17,22 +17,25 @@ package org.apache.spark.ml.evaluation -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.types.DoubleType +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{DoubleType, FloatType} /** * :: Experimental :: * Evaluator for regression, which expects two input columns: prediction and label. 
*/ +@Since("1.4.0") @Experimental -final class RegressionEvaluator(override val uid: String) - extends Evaluator with HasPredictionCol with HasLabelCol { +final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String) + extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable { + @Since("1.4.0") def this() = this(Identifiable.randomUID("regEval")) /** @@ -43,47 +46,73 @@ final class RegressionEvaluator(override val uid: String) * we take and output the negative of this metric. * @group param */ + @Since("1.4.0") val metricName: Param[String] = { val allowedParams = ParamValidators.inArray(Array("mse", "rmse", "r2", "mae")) new Param(this, "metricName", "metric name in evaluation (mse|rmse|r2|mae)", allowedParams) } /** @group getParam */ + @Since("1.4.0") def getMetricName: String = $(metricName) /** @group setParam */ + @Since("1.4.0") def setMetricName(value: String): this.type = set(metricName, value) /** @group setParam */ + @Since("1.4.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) /** @group setParam */ + @Since("1.4.0") def setLabelCol(value: String): this.type = set(labelCol, value) setDefault(metricName -> "rmse") + @Since("1.4.0") override def evaluate(dataset: DataFrame): Double = { val schema = dataset.schema - SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType) - SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + val predictionColName = $(predictionCol) + val predictionType = schema($(predictionCol)).dataType + require(predictionType == FloatType || predictionType == DoubleType, + s"Prediction column $predictionColName must be of type float or double, " + + s" but not $predictionType") + val labelColName = $(labelCol) + val labelType = schema($(labelCol)).dataType + require(labelType == FloatType || labelType == DoubleType, + s"Label column $labelColName must be of type float or double, but not $labelType") - val predictionAndLabels = dataset.select($(predictionCol), $(labelCol)) + val predictionAndLabels = dataset + .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType)) .map { case Row(prediction: Double, label: Double) => (prediction, label) } val metrics = new RegressionMetrics(predictionAndLabels) val metric = $(metricName) match { - case "rmse" => - -metrics.rootMeanSquaredError - case "mse" => - -metrics.meanSquaredError - case "r2" => - metrics.r2 - case "mae" => - -metrics.meanAbsoluteError + case "rmse" => metrics.rootMeanSquaredError + case "mse" => metrics.meanSquaredError + case "r2" => metrics.r2 + case "mae" => metrics.meanAbsoluteError } metric } + @Since("1.4.0") + override def isLargerBetter: Boolean = $(metricName) match { + case "rmse" => false + case "mse" => false + case "r2" => true + case "mae" => false + } + + @Since("1.5.0") override def copy(extra: ParamMap): RegressionEvaluator = defaultCopy(extra) } + +@Since("1.6.0") +object RegressionEvaluator extends DefaultParamsReadable[RegressionEvaluator] { + + @Since("1.6.0") + override def load(path: String): RegressionEvaluator = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index 46314854d5e3..63c06581482e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -17,12 +17,12 @@ package org.apache.spark.ml.feature -import 
org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.BinaryAttribute import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructType} @@ -33,7 +33,7 @@ import org.apache.spark.sql.types.{DoubleType, StructType} */ @Experimental final class Binarizer(override val uid: String) - extends Transformer with HasInputCol with HasOutputCol { + extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("binarizer")) @@ -41,6 +41,7 @@ final class Binarizer(override val uid: String) * Param for threshold used to binarize continuous features. * The features greater than the threshold, will be binarized to 1.0. * The features equal to or less than the threshold, will be binarized to 0.0. + * Default: 0.0 * @group param */ val threshold: DoubleParam = @@ -86,3 +87,10 @@ final class Binarizer(override val uid: String) override def copy(extra: ParamMap): Binarizer = defaultCopy(extra) } + +@Since("1.6.0") +object Binarizer extends DefaultParamsReadable[Binarizer] { + + @Since("1.6.0") + override def load(path: String): Binarizer = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 67e4785bc355..324353a96afb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -20,12 +20,12 @@ package org.apache.spark.ml.feature import java.{util => ju} import org.apache.spark.SparkException -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.Model import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} @@ -36,7 +36,7 @@ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} */ @Experimental final class Bucketizer(override val uid: String) - extends Model[Bucketizer] with HasInputCol with HasOutputCol { + extends Model[Bucketizer] with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("bucketizer")) @@ -75,7 +75,7 @@ final class Bucketizer(override val uid: String) } val newCol = bucketizer(dataset($(inputCol))) val newField = prepOutputField(dataset.schema) - dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata)) + dataset.withColumn($(outputCol), newCol, newField.metadata) } private def prepOutputField(schema: StructType): StructField = { @@ -90,12 +90,15 @@ final class Bucketizer(override val uid: String) SchemaUtils.appendColumn(schema, prepOutputField(schema)) } - override def copy(extra: ParamMap): Bucketizer = defaultCopy(extra) + override def copy(extra: ParamMap): Bucketizer = { + defaultCopy[Bucketizer](extra).setParent(parent) + } } -private[feature] 
object Bucketizer { +object Bucketizer extends DefaultParamsReadable[Bucketizer] { + /** We require splits to be of length >= 3 and to be in strictly increasing order. */ - def checkSplits(splits: Array[Double]): Boolean = { + private[feature] def checkSplits(splits: Array[Double]): Boolean = { if (splits.length < 3) { false } else { @@ -113,7 +116,7 @@ private[feature] object Bucketizer { * Binary searching in several buckets to place each data point. * @throws SparkException if a feature is < splits.head or > splits.last */ - def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = { + private[feature] def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = { if (feature == splits.last) { splits.length - 2 } else { @@ -132,4 +135,7 @@ private[feature] object Bucketizer { } } } + + @Since("1.6.0") + override def load(path: String): Bucketizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala new file mode 100644 index 000000000000..dfec03828f4b --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml._ +import org.apache.spark.ml.attribute.{AttributeGroup, _} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.feature +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.sql._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{DoubleType, StructField, StructType} + +/** + * Params for [[ChiSqSelector]] and [[ChiSqSelectorModel]]. + */ +private[feature] trait ChiSqSelectorParams extends Params + with HasFeaturesCol with HasOutputCol with HasLabelCol { + + /** + * Number of features that selector will select (ordered by statistic value descending). If the + * number of features is < numTopFeatures, then this will select all features. The default value + * of numTopFeatures is 50. + * @group param + */ + final val numTopFeatures = new IntParam(this, "numTopFeatures", + "Number of features that selector will select, ordered by statistics value descending. 
If the" + + " number of features is < numTopFeatures, then this will select all features.", + ParamValidators.gtEq(1)) + setDefault(numTopFeatures -> 50) + + /** @group getParam */ + def getNumTopFeatures: Int = $(numTopFeatures) +} + +/** + * :: Experimental :: + * Chi-Squared feature selection, which selects categorical features to use for predicting a + * categorical label. + */ +@Experimental +final class ChiSqSelector(override val uid: String) + extends Estimator[ChiSqSelectorModel] with ChiSqSelectorParams with DefaultParamsWritable { + + def this() = this(Identifiable.randomUID("chiSqSelector")) + + /** @group setParam */ + def setNumTopFeatures(value: Int): this.type = set(numTopFeatures, value) + + /** @group setParam */ + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** @group setParam */ + def setLabelCol(value: String): this.type = set(labelCol, value) + + override def fit(dataset: DataFrame): ChiSqSelectorModel = { + transformSchema(dataset.schema, logging = true) + val input = dataset.select($(labelCol), $(featuresCol)).map { + case Row(label: Double, features: Vector) => + LabeledPoint(label, features) + } + val chiSqSelector = new feature.ChiSqSelector($(numTopFeatures)).fit(input) + copyValues(new ChiSqSelectorModel(uid, chiSqSelector).setParent(this)) + } + + override def transformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) + } + + override def copy(extra: ParamMap): ChiSqSelector = defaultCopy(extra) +} + +@Since("1.6.0") +object ChiSqSelector extends DefaultParamsReadable[ChiSqSelector] { + + @Since("1.6.0") + override def load(path: String): ChiSqSelector = super.load(path) +} + +/** + * :: Experimental :: + * Model fitted by [[ChiSqSelector]]. + */ +@Experimental +final class ChiSqSelectorModel private[ml] ( + override val uid: String, + private val chiSqSelector: feature.ChiSqSelectorModel) + extends Model[ChiSqSelectorModel] with ChiSqSelectorParams with MLWritable { + + import ChiSqSelectorModel._ + + /** list of indices to select (filter). Must be ordered asc */ + val selectedFeatures: Array[Int] = chiSqSelector.selectedFeatures + + /** @group setParam */ + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** @group setParam */ + def setLabelCol(value: String): this.type = set(labelCol, value) + + override def transform(dataset: DataFrame): DataFrame = { + val transformedSchema = transformSchema(dataset.schema, logging = true) + val newField = transformedSchema.last + val selector = udf { chiSqSelector.transform _ } + dataset.withColumn($(outputCol), selector(col($(featuresCol))), newField.metadata) + } + + override def transformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + val newField = prepOutputField(schema) + val outputFields = schema.fields :+ newField + StructType(outputFields) + } + + /** + * Prepare the output column field, including per-feature metadata. 
+ */ + private def prepOutputField(schema: StructType): StructField = { + val selector = chiSqSelector.selectedFeatures.toSet + val origAttrGroup = AttributeGroup.fromStructField(schema($(featuresCol))) + val featureAttributes: Array[Attribute] = if (origAttrGroup.attributes.nonEmpty) { + origAttrGroup.attributes.get.zipWithIndex.filter(x => selector.contains(x._2)).map(_._1) + } else { + Array.fill[Attribute](selector.size)(NominalAttribute.defaultAttr) + } + val newAttributeGroup = new AttributeGroup($(outputCol), featureAttributes) + newAttributeGroup.toStructField() + } + + override def copy(extra: ParamMap): ChiSqSelectorModel = { + val copied = new ChiSqSelectorModel(uid, chiSqSelector) + copyValues(copied, extra).setParent(parent) + } + + @Since("1.6.0") + override def write: MLWriter = new ChiSqSelectorModelWriter(this) +} + +@Since("1.6.0") +object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] { + + private[ChiSqSelectorModel] + class ChiSqSelectorModelWriter(instance: ChiSqSelectorModel) extends MLWriter { + + private case class Data(selectedFeatures: Seq[Int]) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.selectedFeatures.toSeq) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class ChiSqSelectorModelReader extends MLReader[ChiSqSelectorModel] { + + private val className = classOf[ChiSqSelectorModel].getName + + override def load(path: String): ChiSqSelectorModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath).select("selectedFeatures").head() + val selectedFeatures = data.getAs[Seq[Int]](0).toArray + val oldModel = new feature.ChiSqSelectorModel(selectedFeatures) + val model = new ChiSqSelectorModel(metadata.uid, oldModel) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[ChiSqSelectorModel] = new ChiSqSelectorModelReader + + @Since("1.6.0") + override def load(path: String): ChiSqSelectorModel = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala new file mode 100644 index 000000000000..b9e2144c0ad4 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -0,0 +1,289 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.ml.feature + +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ +import org.apache.spark.util.collection.OpenHashMap + +/** + * Params for [[CountVectorizer]] and [[CountVectorizerModel]]. + */ +private[feature] trait CountVectorizerParams extends Params with HasInputCol with HasOutputCol { + + /** + * Max size of the vocabulary. + * CountVectorizer will build a vocabulary that only considers the top + * vocabSize terms ordered by term frequency across the corpus. + * + * Default: 2^18^ + * @group param + */ + val vocabSize: IntParam = + new IntParam(this, "vocabSize", "max size of the vocabulary", ParamValidators.gt(0)) + + /** @group getParam */ + def getVocabSize: Int = $(vocabSize) + + /** + * Specifies the minimum number of different documents a term must appear in to be included + * in the vocabulary. + * If this is an integer >= 1, this specifies the number of documents the term must appear in; + * if this is a double in [0,1), then this specifies the fraction of documents. + * + * Default: 1 + * @group param + */ + val minDF: DoubleParam = new DoubleParam(this, "minDF", "Specifies the minimum number of" + + " different documents a term must appear in to be included in the vocabulary." + + " If this is an integer >= 1, this specifies the number of documents the term must" + + " appear in; if this is a double in [0,1), then this specifies the fraction of documents.", + ParamValidators.gtEq(0.0)) + + /** @group getParam */ + def getMinDF: Double = $(minDF) + + /** Validates and transforms the input schema. */ + protected def validateAndTransformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true)) + SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) + } + + /** + * Filter to ignore rare words in a document. For each document, terms with + * frequency/count less than the given threshold are ignored. + * If this is an integer >= 1, then this specifies a count (of times the term must appear + * in the document); + * if this is a double in [0,1), then this specifies a fraction (out of the document's token + * count). + * + * Note that the parameter is only used in transform of [[CountVectorizerModel]] and does not + * affect fitting. + * + * Default: 1 + * @group param + */ + val minTF: DoubleParam = new DoubleParam(this, "minTF", "Filter to ignore rare words in" + + " a document. For each document, terms with frequency/count less than the given threshold are" + + " ignored. If this is an integer >= 1, then this specifies a count (of times the term must" + + " appear in the document); if this is a double in [0,1), then this specifies a fraction (out" + + " of the document's token count). 
Note that the parameter is only used in transform of" + + " CountVectorizerModel and does not affect fitting.", ParamValidators.gtEq(0.0)) + + setDefault(minTF -> 1) + + /** @group getParam */ + def getMinTF: Double = $(minTF) +} + +/** + * :: Experimental :: + * Extracts a vocabulary from document collections and generates a [[CountVectorizerModel]]. + */ +@Experimental +class CountVectorizer(override val uid: String) + extends Estimator[CountVectorizerModel] with CountVectorizerParams with DefaultParamsWritable { + + def this() = this(Identifiable.randomUID("cntVec")) + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** @group setParam */ + def setVocabSize(value: Int): this.type = set(vocabSize, value) + + /** @group setParam */ + def setMinDF(value: Double): this.type = set(minDF, value) + + /** @group setParam */ + def setMinTF(value: Double): this.type = set(minTF, value) + + setDefault(vocabSize -> (1 << 18), minDF -> 1) + + override def fit(dataset: DataFrame): CountVectorizerModel = { + transformSchema(dataset.schema, logging = true) + val vocSize = $(vocabSize) + val input = dataset.select($(inputCol)).map(_.getAs[Seq[String]](0)) + val minDf = if ($(minDF) >= 1.0) { + $(minDF) + } else { + $(minDF) * input.cache().count() + } + val wordCounts: RDD[(String, Long)] = input.flatMap { case (tokens) => + val wc = new OpenHashMap[String, Long] + tokens.foreach { w => + wc.changeValue(w, 1L, _ + 1L) + } + wc.map { case (word, count) => (word, (count, 1)) } + }.reduceByKey { case ((wc1, df1), (wc2, df2)) => + (wc1 + wc2, df1 + df2) + }.filter { case (word, (wc, df)) => + df >= minDf + }.map { case (word, (count, dfCount)) => + (word, count) + }.cache() + val fullVocabSize = wordCounts.count() + val vocab: Array[String] = { + val tmpSortedWC: Array[(String, Long)] = if (fullVocabSize <= vocSize) { + // Use all terms + wordCounts.collect().sortBy(-_._2) + } else { + // Sort terms to select vocab + wordCounts.sortBy(_._2, ascending = false).take(vocSize) + } + tmpSortedWC.map(_._1) + } + + require(vocab.length > 0, "The vocabulary size should be > 0. Lower minDF as necessary.") + copyValues(new CountVectorizerModel(uid, vocab).setParent(this)) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + override def copy(extra: ParamMap): CountVectorizer = defaultCopy(extra) +} + +@Since("1.6.0") +object CountVectorizer extends DefaultParamsReadable[CountVectorizer] { + + @Since("1.6.0") + override def load(path: String): CountVectorizer = super.load(path) +} + +/** + * :: Experimental :: + * Converts a text document to a sparse vector of token counts. + * @param vocabulary An Array over terms. Only the terms in the vocabulary will be counted. 
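
A short sketch of CountVectorizer / CountVectorizerModel as defined in this file; the data, column names, and `sqlContext` are illustrative assumptions:

import org.apache.spark.ml.feature.CountVectorizer

val docs = sqlContext.createDataFrame(Seq(
  (0, Seq("a", "b", "c")),
  (1, Seq("a", "b", "b", "c", "a"))
)).toDF("id", "words")

val cvModel = new CountVectorizer()
  .setInputCol("words")
  .setOutputCol("features")
  .setVocabSize(3)           // default is 2^18
  .setMinDF(2.0)             // keep terms that appear in at least 2 documents
  .fit(docs)
println(cvModel.vocabulary.mkString(", "))
cvModel.transform(docs).select("features").show(false)
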
+ */ +@Experimental +class CountVectorizerModel(override val uid: String, val vocabulary: Array[String]) + extends Model[CountVectorizerModel] with CountVectorizerParams with MLWritable { + + import CountVectorizerModel._ + + def this(vocabulary: Array[String]) = { + this(Identifiable.randomUID("cntVecModel"), vocabulary) + set(vocabSize, vocabulary.length) + } + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** @group setParam */ + def setMinTF(value: Double): this.type = set(minTF, value) + + /** Dictionary created from [[vocabulary]] and its indices, broadcast once for [[transform()]] */ + private var broadcastDict: Option[Broadcast[Map[String, Int]]] = None + + override def transform(dataset: DataFrame): DataFrame = { + if (broadcastDict.isEmpty) { + val dict = vocabulary.zipWithIndex.toMap + broadcastDict = Some(dataset.sqlContext.sparkContext.broadcast(dict)) + } + val dictBr = broadcastDict.get + val minTf = $(minTF) + val vectorizer = udf { (document: Seq[String]) => + val termCounts = new OpenHashMap[Int, Double] + var tokenCount = 0L + document.foreach { term => + dictBr.value.get(term) match { + case Some(index) => termCounts.changeValue(index, 1.0, _ + 1.0) + case None => // ignore terms not in the vocabulary + } + tokenCount += 1 + } + val effectiveMinTF = if (minTf >= 1.0) { + minTf + } else { + tokenCount * minTf + } + Vectors.sparse(dictBr.value.size, termCounts.filter(_._2 >= effectiveMinTF).toSeq) + } + dataset.withColumn($(outputCol), vectorizer(col($(inputCol)))) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + override def copy(extra: ParamMap): CountVectorizerModel = { + val copied = new CountVectorizerModel(uid, vocabulary).setParent(parent) + copyValues(copied, extra) + } + + @Since("1.6.0") + override def write: MLWriter = new CountVectorizerModelWriter(this) +} + +@Since("1.6.0") +object CountVectorizerModel extends MLReadable[CountVectorizerModel] { + + private[CountVectorizerModel] + class CountVectorizerModelWriter(instance: CountVectorizerModel) extends MLWriter { + + private case class Data(vocabulary: Seq[String]) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.vocabulary) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class CountVectorizerModelReader extends MLReader[CountVectorizerModel] { + + private val className = classOf[CountVectorizerModel].getName + + override def load(path: String): CountVectorizerModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("vocabulary") + .head() + val vocabulary = data.getAs[Seq[String]](0).toArray + val model = new CountVectorizerModel(metadata.uid, vocabulary) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[CountVectorizerModel] = new CountVectorizerModelReader + + @Since("1.6.0") + override def load(path: String): CountVectorizerModel = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala deleted file mode 100644 index 6b77de89a033..000000000000 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.ml.feature - -import scala.collection.mutable - -import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.param.{ParamMap, ParamValidators, IntParam} -import org.apache.spark.ml.util.Identifiable -import org.apache.spark.mllib.linalg.{Vectors, VectorUDT, Vector} -import org.apache.spark.sql.types.{StringType, ArrayType, DataType} - -/** - * :: Experimental :: - * Converts a text document to a sparse vector of token counts. - * @param vocabulary An Array over terms. Only the terms in the vocabulary will be counted. - */ -@Experimental -class CountVectorizerModel (override val uid: String, val vocabulary: Array[String]) - extends UnaryTransformer[Seq[String], Vector, CountVectorizerModel] { - - def this(vocabulary: Array[String]) = - this(Identifiable.randomUID("cntVec"), vocabulary) - - /** - * Corpus-specific filter to ignore scarce words in a document. For each document, terms with - * frequency (count) less than the given threshold are ignored. - * Default: 1 - * @group param - */ - val minTermFreq: IntParam = new IntParam(this, "minTermFreq", - "minimum frequency (count) filter used to neglect scarce words (>= 1). 
For each document, " + - "terms with frequency less than the given threshold are ignored.", ParamValidators.gtEq(1)) - - /** @group setParam */ - def setMinTermFreq(value: Int): this.type = set(minTermFreq, value) - - /** @group getParam */ - def getMinTermFreq: Int = $(minTermFreq) - - setDefault(minTermFreq -> 1) - - override protected def createTransformFunc: Seq[String] => Vector = { - val dict = vocabulary.zipWithIndex.toMap - document => - val termCounts = mutable.HashMap.empty[Int, Double] - document.foreach { term => - dict.get(term) match { - case Some(index) => termCounts.put(index, termCounts.getOrElse(index, 0.0) + 1.0) - case None => // ignore terms not in the vocabulary - } - } - Vectors.sparse(dict.size, termCounts.filter(_._2 >= $(minTermFreq)).toSeq) - } - - override protected def validateInputType(inputType: DataType): Unit = { - require(inputType.sameType(ArrayType(StringType)), - s"Input type must be ArrayType(StringType) but got $inputType.") - } - - override protected def outputDataType: DataType = new VectorUDT() - - override def copy(extra: ParamMap): CountVectorizerModel = { - val copied = new CountVectorizerModel(uid, vocabulary) - copyValues(copied, extra) - } -} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala index 228347635c92..6bed72164a1d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala @@ -19,10 +19,10 @@ package org.apache.spark.ml.feature import edu.emory.mathcs.jtransforms.dct._ -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param.BooleanParam -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} import org.apache.spark.sql.types.DataType @@ -37,7 +37,7 @@ import org.apache.spark.sql.types.DataType */ @Experimental class DCT(override val uid: String) - extends UnaryTransformer[Vector, Vector, DCT] { + extends UnaryTransformer[Vector, Vector, DCT] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("dct")) @@ -70,3 +70,10 @@ class DCT(override val uid: String) override protected def outputDataType: DataType = new VectorUDT } + +@Since("1.6.0") +object DCT extends DefaultParamsReadable[DCT] { + + @Since("1.6.0") + override def load(path: String): DCT = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index 319d23e46cef..9e15835429a3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -17,12 +17,12 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions.{col, udf} @@ -33,7 +33,8 @@ import 
org.apache.spark.sql.types.{ArrayType, StructType} * Maps a sequence of terms to their term frequencies using the hashing trick. */ @Experimental -class HashingTF(override val uid: String) extends Transformer with HasInputCol with HasOutputCol { +class HashingTF(override val uid: String) + extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("hashingTF")) @@ -77,3 +78,10 @@ class HashingTF(override val uid: String) extends Transformer with HasInputCol w override def copy(extra: ParamMap): HashingTF = defaultCopy(extra) } + +@Since("1.6.0") +object HashingTF extends DefaultParamsReadable[HashingTF] { + + @Since("1.6.0") + override def load(path: String): HashingTF = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index ecde80810580..f7b0f29a27c2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -17,11 +17,13 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql._ @@ -35,6 +37,7 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol /** * The minimum of documents in which a term should appear. + * Default: 0 * @group param */ final val minDocFreq = new IntParam( @@ -59,7 +62,8 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol * Compute the Inverse Document Frequency (IDF) given a collection of documents. */ @Experimental -final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase { +final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase + with DefaultParamsWritable { def this() = this(Identifiable.randomUID("idf")) @@ -86,6 +90,13 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa override def copy(extra: ParamMap): IDF = defaultCopy(extra) } +@Since("1.6.0") +object IDF extends DefaultParamsReadable[IDF] { + + @Since("1.6.0") + override def load(path: String): IDF = super.load(path) +} + /** * :: Experimental :: * Model fitted by [[IDF]]. @@ -94,7 +105,9 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa class IDFModel private[ml] ( override val uid: String, idfModel: feature.IDFModel) - extends Model[IDFModel] with IDFBase { + extends Model[IDFModel] with IDFBase with MLWritable { + + import IDFModel._ /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -114,6 +127,52 @@ class IDFModel private[ml] ( override def copy(extra: ParamMap): IDFModel = { val copied = new IDFModel(uid, idfModel) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } + + /** Returns the IDF vector. 
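A minimal usage sketch of the persistence wiring added above for HashingTF and IDF; it is illustrative only and not part of the patch. The DataFrame `wordsDF` (with a tokenized "words" column) and the save path are assumptions.

    import org.apache.spark.ml.feature.{HashingTF, IDF, IDFModel}

    val hashingTF = new HashingTF()
      .setInputCol("words")
      .setOutputCol("rawFeatures")
      .setNumFeatures(1 << 18)
    val featurized = hashingTF.transform(wordsDF)   // wordsDF is assumed

    val idfModel = new IDF()
      .setInputCol("rawFeatures")
      .setOutputCol("features")
      .fit(featurized)

    // Round-trip the fitted model through the new MLWritable/MLReadable API.
    idfModel.write.overwrite().save("/tmp/spark-idf-model")   // hypothetical path
    val restored = IDFModel.load("/tmp/spark-idf-model")
    restored.idf                                              // the learned IDF vector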
*/ + @Since("1.6.0") + def idf: Vector = idfModel.idf + + @Since("1.6.0") + override def write: MLWriter = new IDFModelWriter(this) +} + +@Since("1.6.0") +object IDFModel extends MLReadable[IDFModel] { + + private[IDFModel] class IDFModelWriter(instance: IDFModel) extends MLWriter { + + private case class Data(idf: Vector) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.idf) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class IDFModelReader extends MLReader[IDFModel] { + + private val className = classOf[IDFModel].getName + + override def load(path: String): IDFModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("idf") + .head() + val idf = data.getAs[Vector](0) + val model = new IDFModel(metadata.uid, new feature.IDFModel(idf)) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[IDFModel] = new IDFModelReader + + @Since("1.6.0") + override def load(path: String): IDFModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala new file mode 100644 index 000000000000..12176757aee3 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.mllib.linalg.Vector + +/** + * Class that represents an instance of weighted data point with label and features. + * + * @param label Label for this data point. + * @param weight The weight of this instance. + * @param features The vector of features for this data point. + */ +private[ml] case class Instance(label: Double, weight: Double, features: Vector) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala new file mode 100644 index 000000000000..2181119f04a5 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -0,0 +1,299 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.collection.mutable.ArrayBuilder + +import org.apache.spark.SparkException +import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.ml.attribute._ +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util._ +import org.apache.spark.ml.Transformer +import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + +/** + * :: Experimental :: + * Implements the feature interaction transform. This transformer takes in Double and Vector type + * columns and outputs a flattened vector of their feature interactions. To handle interaction, + * we first one-hot encode any nominal features. Then, a vector of the feature cross-products is + * produced. + * + * For example, given the input feature values `Double(2)` and `Vector(3, 4)`, the output would be + * `Vector(6, 8)` if all input features were numeric. If the first feature was instead nominal + * with four categories, the output would then be `Vector(0, 0, 0, 0, 3, 4, 0, 0)`. + */ +@Since("1.6.0") +@Experimental +class Interaction @Since("1.6.0") (override val uid: String) extends Transformer + with HasInputCols with HasOutputCol with DefaultParamsWritable { + + @Since("1.6.0") + def this() = this(Identifiable.randomUID("interaction")) + + /** @group setParam */ + @Since("1.6.0") + def setInputCols(values: Array[String]): this.type = set(inputCols, values) + + /** @group setParam */ + @Since("1.6.0") + def setOutputCol(value: String): this.type = set(outputCol, value) + + // optimistic schema; does not contain any ML attributes + @Since("1.6.0") + override def transformSchema(schema: StructType): StructType = { + validateParams() + StructType(schema.fields :+ StructField($(outputCol), new VectorUDT, false)) + } + + @Since("1.6.0") + override def transform(dataset: DataFrame): DataFrame = { + validateParams() + val inputFeatures = $(inputCols).map(c => dataset.schema(c)) + val featureEncoders = getFeatureEncoders(inputFeatures) + val featureAttrs = getFeatureAttrs(inputFeatures) + + def interactFunc = udf { row: Row => + var indices = ArrayBuilder.make[Int] + var values = ArrayBuilder.make[Double] + var size = 1 + indices += 0 + values += 1.0 + var featureIndex = row.length - 1 + while (featureIndex >= 0) { + val prevIndices = indices.result() + val prevValues = values.result() + val prevSize = size + val currentEncoder = featureEncoders(featureIndex) + indices = ArrayBuilder.make[Int] + values = ArrayBuilder.make[Double] + size *= currentEncoder.outputSize + currentEncoder.foreachNonzeroOutput(row(featureIndex), (i, a) => { + var j = 0 + while (j < prevIndices.length) { + indices += prevIndices(j) + i * prevSize + values += prevValues(j) * a + j += 1 + } + }) + featureIndex -= 1 + } + Vectors.sparse(size, indices.result(), values.result()).compressed + } + + val featureCols = inputFeatures.map { f => + f.dataType match { + case DoubleType => dataset(f.name) + case _: 
VectorUDT => dataset(f.name) + case _: NumericType | BooleanType => dataset(f.name).cast(DoubleType) + } + } + dataset.select( + col("*"), + interactFunc(struct(featureCols: _*)).as($(outputCol), featureAttrs.toMetadata())) + } + + /** + * Creates a feature encoder for each input column, which supports efficient iteration over + * one-hot encoded feature values. See also the class-level comment of [[FeatureEncoder]]. + * + * @param features The input feature columns to create encoders for. + */ + private def getFeatureEncoders(features: Seq[StructField]): Array[FeatureEncoder] = { + def getNumFeatures(attr: Attribute): Int = { + attr match { + case nominal: NominalAttribute => + math.max(1, nominal.getNumValues.getOrElse( + throw new SparkException("Nominal features must have attr numValues defined."))) + case _ => + 1 // numeric feature + } + } + features.map { f => + val numFeatures = f.dataType match { + case _: NumericType | BooleanType => + Array(getNumFeatures(Attribute.fromStructField(f))) + case _: VectorUDT => + val attrs = AttributeGroup.fromStructField(f).attributes.getOrElse( + throw new SparkException("Vector attributes must be defined for interaction.")) + attrs.map(getNumFeatures).toArray + } + new FeatureEncoder(numFeatures) + }.toArray + } + + /** + * Generates ML attributes for the output vector of all feature interactions. We make a best + * effort to generate reasonable names for output features, based on the concatenation of the + * interacting feature names and values delimited with `_`. When no feature name is specified, + * we fall back to using the feature index (e.g. `foo:bar_2_0` may indicate an interaction + * between the numeric `foo` feature and a nominal third feature from column `bar`. + * + * @param features The input feature columns to the Interaction transformer. + */ + private def getFeatureAttrs(features: Seq[StructField]): AttributeGroup = { + var featureAttrs: Seq[Attribute] = Nil + features.reverse.foreach { f => + val encodedAttrs = f.dataType match { + case _: NumericType | BooleanType => + val attr = Attribute.decodeStructField(f, preserveName = true) + if (attr == UnresolvedAttribute) { + encodedFeatureAttrs(Seq(NumericAttribute.defaultAttr.withName(f.name)), None) + } else if (!attr.name.isDefined) { + encodedFeatureAttrs(Seq(attr.withName(f.name)), None) + } else { + encodedFeatureAttrs(Seq(attr), None) + } + case _: VectorUDT => + val group = AttributeGroup.fromStructField(f) + encodedFeatureAttrs(group.attributes.get, Some(group.name)) + } + if (featureAttrs.isEmpty) { + featureAttrs = encodedAttrs + } else { + featureAttrs = encodedAttrs.flatMap { head => + featureAttrs.map { tail => + NumericAttribute.defaultAttr.withName(head.name.get + ":" + tail.name.get) + } + } + } + } + new AttributeGroup($(outputCol), featureAttrs.toArray) + } + + /** + * Generates the output ML attributes for a single input feature. Each output feature name has + * up to three parts: the group name, feature name, and category name (for nominal features), + * each separated by an underscore. + * + * @param inputAttrs The attributes of the input feature. + * @param groupName Optional name of the input feature group (for Vector type features). 
+ */ + private def encodedFeatureAttrs( + inputAttrs: Seq[Attribute], + groupName: Option[String]): Seq[Attribute] = { + + def format( + index: Int, + attrName: Option[String], + categoryName: Option[String]): String = { + val parts = Seq(groupName, Some(attrName.getOrElse(index.toString)), categoryName) + parts.flatten.mkString("_") + } + + inputAttrs.zipWithIndex.flatMap { + case (nominal: NominalAttribute, i) => + if (nominal.values.isDefined) { + nominal.values.get.map( + v => BinaryAttribute.defaultAttr.withName(format(i, nominal.name, Some(v)))) + } else { + Array.tabulate(nominal.getNumValues.get)( + j => BinaryAttribute.defaultAttr.withName(format(i, nominal.name, Some(j.toString)))) + } + case (a: Attribute, i) => + Seq(NumericAttribute.defaultAttr.withName(format(i, a.name, None))) + } + } + + @Since("1.6.0") + override def copy(extra: ParamMap): Interaction = defaultCopy(extra) + + @Since("1.6.0") + override def validateParams(): Unit = { + require(get(inputCols).isDefined, "Input cols must be defined first.") + require(get(outputCol).isDefined, "Output col must be defined first.") + require($(inputCols).length > 0, "Input cols must have non-zero length.") + require($(inputCols).distinct.length == $(inputCols).length, "Input cols must be distinct.") + } +} + +@Since("1.6.0") +object Interaction extends DefaultParamsReadable[Interaction] { + + @Since("1.6.0") + override def load(path: String): Interaction = super.load(path) +} + +/** + * This class performs on-the-fly one-hot encoding of features as you iterate over them. To + * indicate which input features should be one-hot encoded, an array of the feature counts + * must be passed in ahead of time. + * + * @param numFeatures Array of feature counts for each input feature. For nominal features this + * count is equal to the number of categories. For numeric features the count + * should be set to 1. + */ +private[ml] class FeatureEncoder(numFeatures: Array[Int]) extends Serializable { + assert(numFeatures.forall(_ > 0), "Features counts must all be positive.") + + /** The size of the output vector. */ + val outputSize = numFeatures.sum + + /** Precomputed offsets for the location of each output feature. */ + private val outputOffsets = { + val arr = new Array[Int](numFeatures.length) + var i = 1 + while (i < arr.length) { + arr(i) = arr(i - 1) + numFeatures(i - 1) + i += 1 + } + arr + } + + /** + * Given an input row of features, invokes the specific function for every non-zero output. + * + * @param value The row value to encode, either a Double or Vector. + * @param f The callback to invoke on each non-zero (index, value) output pair. 
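For reference, a small sketch of how the Interaction transformer above might be used; it is not part of the patch. The DataFrame `df` with numeric columns "x1", "x2", and "w" is an assumption, and the vector column comes from VectorAssembler, which attaches the ML attributes that Interaction requires.

    import org.apache.spark.ml.feature.{Interaction, VectorAssembler}

    val assembler = new VectorAssembler()
      .setInputCols(Array("x1", "x2"))
      .setOutputCol("vec")                 // output carries numeric ML attributes

    val interaction = new Interaction()
      .setInputCols(Array("w", "vec"))
      .setOutputCol("wCrossVec")           // produces w*x1, w*x2

    val interacted = interaction.transform(assembler.transform(df))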
+ */ + def foreachNonzeroOutput(value: Any, f: (Int, Double) => Unit): Unit = value match { + case d: Double => + assert(numFeatures.length == 1, "DoubleType columns should only contain one feature.") + val numOutputCols = numFeatures.head + if (numOutputCols > 1) { + assert( + d >= 0.0 && d == d.toInt && d < numOutputCols, + s"Values from column must be indices, but got $d.") + f(d.toInt, 1.0) + } else { + f(0, d) + } + case vec: Vector => + assert(numFeatures.length == vec.size, + s"Vector column size was ${vec.size}, expected ${numFeatures.length}") + vec.foreachActive { (i, v) => + val numOutputCols = numFeatures(i) + if (numOutputCols > 1) { + assert( + v >= 0.0 && v == v.toInt && v < numOutputCols, + s"Values from column must be indices, but got $v.") + f(outputOffsets(i) + v.toInt, 1.0) + } else { + f(outputOffsets(i), v) + } + } + case null => + throw new SparkException("Values to interact cannot be null.") + case o => + throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.") + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index b30adf3df48d..c2866f5eceff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -17,11 +17,14 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.param.{ParamMap, DoubleParam, Params} -import org.apache.spark.ml.util.Identifiable + +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param.{DoubleParam, ParamMap, Params} +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} import org.apache.spark.mllib.stat.Statistics import org.apache.spark.sql._ @@ -41,6 +44,9 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H val min: DoubleParam = new DoubleParam(this, "min", "lower bound of the output feature range") + /** @group getParam */ + def getMin: Double = $(min) + /** * upper bound after transformation, shared by all features * Default: 1.0 @@ -49,6 +55,9 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H val max: DoubleParam = new DoubleParam(this, "max", "upper bound of the output feature range") + /** @group getParam */ + def getMax: Double = $(max) + /** Validates and transforms the input schema. 
*/ protected def validateAndTransformSchema(schema: StructType): StructType = { val inputType = schema($(inputCol)).dataType @@ -79,7 +88,7 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H */ @Experimental class MinMaxScaler(override val uid: String) - extends Estimator[MinMaxScalerModel] with MinMaxScalerParams { + extends Estimator[MinMaxScalerModel] with MinMaxScalerParams with DefaultParamsWritable { def this() = this(Identifiable.randomUID("minMaxScal")) @@ -111,10 +120,20 @@ class MinMaxScaler(override val uid: String) override def copy(extra: ParamMap): MinMaxScaler = defaultCopy(extra) } +@Since("1.6.0") +object MinMaxScaler extends DefaultParamsReadable[MinMaxScaler] { + + @Since("1.6.0") + override def load(path: String): MinMaxScaler = super.load(path) +} + /** * :: Experimental :: * Model fitted by [[MinMaxScaler]]. * + * @param originalMin min value for each original column during fitting + * @param originalMax max value for each original column during fitting + * * TODO: The transformer does not yet set the metadata in the output column (SPARK-8529). */ @Experimental @@ -122,7 +141,9 @@ class MinMaxScalerModel private[ml] ( override val uid: String, val originalMin: Vector, val originalMax: Vector) - extends Model[MinMaxScalerModel] with MinMaxScalerParams { + extends Model[MinMaxScalerModel] with MinMaxScalerParams with MLWritable { + + import MinMaxScalerModel._ /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -136,7 +157,6 @@ class MinMaxScalerModel private[ml] ( /** @group setParam */ def setMax(value: Double): this.type = set(max, value) - override def transform(dataset: DataFrame): DataFrame = { val originalRange = (originalMax.toBreeze - originalMin.toBreeze).toArray val minArray = originalMin.toArray @@ -165,6 +185,48 @@ class MinMaxScalerModel private[ml] ( override def copy(extra: ParamMap): MinMaxScalerModel = { val copied = new MinMaxScalerModel(uid, originalMin, originalMax) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: MLWriter = new MinMaxScalerModelWriter(this) +} + +@Since("1.6.0") +object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] { + + private[MinMaxScalerModel] + class MinMaxScalerModelWriter(instance: MinMaxScalerModel) extends MLWriter { + + private case class Data(originalMin: Vector, originalMax: Vector) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = new Data(instance.originalMin, instance.originalMax) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class MinMaxScalerModelReader extends MLReader[MinMaxScalerModel] { + + private val className = classOf[MinMaxScalerModel].getName + + override def load(path: String): MinMaxScalerModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val Row(originalMin: Vector, originalMax: Vector) = sqlContext.read.parquet(dataPath) + .select("originalMin", "originalMax") + .head() + val model = new MinMaxScalerModel(metadata.uid, originalMin, originalMax) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[MinMaxScalerModel] = new MinMaxScalerModelReader + + @Since("1.6.0") + override def load(path: String): MinMaxScalerModel = 
super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala index 8de10eb51f92..65414ecbefbb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala @@ -17,10 +17,10 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.sql.types.{ArrayType, DataType, StringType} /** @@ -36,7 +36,7 @@ import org.apache.spark.sql.types.{ArrayType, DataType, StringType} */ @Experimental class NGram(override val uid: String) - extends UnaryTransformer[Seq[String], Seq[String], NGram] { + extends UnaryTransformer[Seq[String], Seq[String], NGram] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("ngram")) @@ -67,3 +67,10 @@ class NGram(override val uid: String) override protected def outputDataType: DataType = new ArrayType(StringType, false) } + +@Since("1.6.0") +object NGram extends DefaultParamsReadable[NGram] { + + @Since("1.6.0") + override def load(path: String): NGram = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala index 8282e5ffa17f..c2d514fd9629 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala @@ -17,10 +17,10 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param.{DoubleParam, ParamValidators} -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.types.DataType @@ -30,7 +30,8 @@ import org.apache.spark.sql.types.DataType * Normalize a vector to have unit norm using the given p-norm. 
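A brief, illustrative sketch of the p-norm parameter together with the newly added save/load support; the DataFrame `df` with a vector column "features" and the save path are assumptions.

    import org.apache.spark.ml.feature.Normalizer

    val normalizer = new Normalizer()
      .setInputCol("features")
      .setOutputCol("normFeatures")
      .setP(1.0)                           // L^1 normalization; the default is p = 2

    val l1Normalized = normalizer.transform(df)

    normalizer.write.overwrite().save("/tmp/spark-normalizer")   // hypothetical path
    val restored = Normalizer.load("/tmp/spark-normalizer")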
*/ @Experimental -class Normalizer(override val uid: String) extends UnaryTransformer[Vector, Vector, Normalizer] { +class Normalizer(override val uid: String) + extends UnaryTransformer[Vector, Vector, Normalizer] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("normalizer")) @@ -56,3 +57,10 @@ class Normalizer(override val uid: String) extends UnaryTransformer[Vector, Vect override protected def outputDataType: DataType = new VectorUDT() } + +@Since("1.6.0") +object Normalizer extends DefaultParamsReadable[Normalizer] { + + @Since("1.6.0") + override def load(path: String): Normalizer = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 9c60d4084ec4..d70164eaf022 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -17,12 +17,12 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions.{col, udf} @@ -44,7 +44,7 @@ import org.apache.spark.sql.types.{DoubleType, StructType} */ @Experimental class OneHotEncoder(override val uid: String) extends Transformer - with HasInputCol with HasOutputCol { + with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("oneHot")) @@ -166,3 +166,10 @@ class OneHotEncoder(override val uid: String) extends Transformer override def copy(extra: ParamMap): OneHotEncoder = defaultCopy(extra) } + +@Since("1.6.0") +object OneHotEncoder extends DefaultParamsReadable[OneHotEncoder] { + + @Since("1.6.0") + override def load(path: String): OneHotEncoder = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 2d3bb680cf30..53d33ea2b8f7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -17,13 +17,15 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature -import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.mllib.linalg._ import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{StructField, StructType} @@ -49,7 +51,8 @@ private[feature] trait PCAParams extends Params with HasInputCol with HasOutputC * PCA trains a model to project vectors to a low-dimensional space using PCA. 
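A minimal sketch of fitting the PCA estimator defined below and reading the new public `pc` and `explainedVariance` members; an active `sqlContext` is assumed and the data is made up.

    import org.apache.spark.ml.feature.PCA
    import org.apache.spark.mllib.linalg.Vectors

    val df = sqlContext.createDataFrame(Seq(
      Tuple1(Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)),
      Tuple1(Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)),
      Tuple1(Vectors.dense(6.0, 1.0, 9.0, 8.0, 9.0)))).toDF("features")

    val pcaModel = new PCA()
      .setInputCol("features")
      .setOutputCol("pcaFeatures")
      .setK(3)
      .fit(df)

    pcaModel.pc                  // principal components matrix, one component per column
    pcaModel.explainedVariance   // variance explained by each component
    val projected = pcaModel.transform(df)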
*/ @Experimental -class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams { +class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams + with DefaultParamsWritable { def this() = this(Identifiable.randomUID("pca")) @@ -70,7 +73,7 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v} val pca = new feature.PCA(k = $(k)) val pcaModel = pca.fit(input) - copyValues(new PCAModel(uid, pcaModel).setParent(this)) + copyValues(new PCAModel(uid, pcaModel.pc, pcaModel.explainedVariance).setParent(this)) } override def transformSchema(schema: StructType): StructType = { @@ -86,15 +89,27 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams override def copy(extra: ParamMap): PCA = defaultCopy(extra) } +@Since("1.6.0") +object PCA extends DefaultParamsReadable[PCA] { + + @Since("1.6.0") + override def load(path: String): PCA = super.load(path) +} + /** * :: Experimental :: * Model fitted by [[PCA]]. + * + * @param pc A principal components Matrix. Each column is one principal component. */ @Experimental class PCAModel private[ml] ( override val uid: String, - pcaModel: feature.PCAModel) - extends Model[PCAModel] with PCAParams { + val pc: DenseMatrix, + val explainedVariance: DenseVector) + extends Model[PCAModel] with PCAParams with MLWritable { + + import PCAModel._ /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -109,6 +124,7 @@ class PCAModel private[ml] ( */ override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) + val pcaModel = new feature.PCAModel($(k), pc, explainedVariance) val pcaOp = udf { pcaModel.transform _ } dataset.withColumn($(outputCol), pcaOp(col($(inputCol)))) } @@ -124,7 +140,49 @@ class PCAModel private[ml] ( } override def copy(extra: ParamMap): PCAModel = { - val copied = new PCAModel(uid, pcaModel) - copyValues(copied, extra) + val copied = new PCAModel(uid, pc, explainedVariance) + copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: MLWriter = new PCAModelWriter(this) +} + +@Since("1.6.0") +object PCAModel extends MLReadable[PCAModel] { + + private[PCAModel] class PCAModelWriter(instance: PCAModel) extends MLWriter { + + private case class Data(pc: DenseMatrix, explainedVariance: DenseVector) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.pc, instance.explainedVariance) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class PCAModelReader extends MLReader[PCAModel] { + + private val className = classOf[PCAModel].getName + + override def load(path: String): PCAModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val Row(pc: DenseMatrix, explainedVariance: DenseVector) = + sqlContext.read.parquet(dataPath) + .select("pc", "explainedVariance") + .head() + val model = new PCAModel(metadata.uid, pc, explainedVariance) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[PCAModel] = new PCAModelReader + + @Since("1.6.0") + override def load(path: String): PCAModel = super.load(path) } diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala index d85e468562d4..08610593fadd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala @@ -19,10 +19,10 @@ package org.apache.spark.ml.feature import scala.collection.mutable -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param.{ParamMap, IntParam, ParamValidators} -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg._ import org.apache.spark.sql.types.DataType @@ -36,7 +36,7 @@ import org.apache.spark.sql.types.DataType */ @Experimental class PolynomialExpansion(override val uid: String) - extends UnaryTransformer[Vector, Vector, PolynomialExpansion] { + extends UnaryTransformer[Vector, Vector, PolynomialExpansion] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("poly")) @@ -77,7 +77,8 @@ class PolynomialExpansion(override val uid: String) * To handle sparsity, if c is zero, we can skip all monomials that contain it. We remember the * current index and increment it properly for sparse input. */ -private[feature] object PolynomialExpansion { +@Since("1.6.0") +object PolynomialExpansion extends DefaultParamsReadable[PolynomialExpansion] { private def choose(n: Int, k: Int): Int = { Range(n, n - k, -1).product / Range(k, 1, -1).product @@ -169,11 +170,14 @@ private[feature] object PolynomialExpansion { new SparseVector(polySize - 1, polyIndices.result(), polyValues.result()) } - def expand(v: Vector, degree: Int): Vector = { + private[feature] def expand(v: Vector, degree: Int): Vector = { v match { case dv: DenseVector => expand(dv, degree) case sv: SparseVector => expand(sv, degree) case _ => throw new IllegalArgumentException } } + + @Since("1.6.0") + override def load(path: String): PolynomialExpansion = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala new file mode 100644 index 000000000000..7bf67c6325a3 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
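Since PolynomialExpansion's companion object is now public and DefaultParamsReadable, the transformer can be saved and reloaded like the others. A short sketch, assuming a DataFrame `df` with a vector column "features" and a hypothetical path:

    import org.apache.spark.ml.feature.PolynomialExpansion

    val poly = new PolynomialExpansion()
      .setInputCol("features")
      .setOutputCol("polyFeatures")
      .setDegree(3)

    val expanded = poly.transform(df)

    poly.write.overwrite().save("/tmp/spark-poly")   // hypothetical path
    val restored = PolynomialExpansion.load("/tmp/spark-poly")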
+ */ + +package org.apache.spark.ml.feature + +import scala.collection.mutable + +import org.apache.spark.Logging +import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.ml._ +import org.apache.spark.ml.attribute.NominalAttribute +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.param.{IntParam, _} +import org.apache.spark.ml.util._ +import org.apache.spark.sql.types.{DoubleType, StructType} +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.util.random.XORShiftRandom + +/** + * Params for [[QuantileDiscretizer]]. + */ +private[feature] trait QuantileDiscretizerBase extends Params with HasInputCol with HasOutputCol { + + /** + * Maximum number of buckets (quantiles, or categories) into which data points are grouped. Must + * be >= 2. + * default: 2 + * @group param + */ + val numBuckets = new IntParam(this, "numBuckets", "Maximum number of buckets (quantiles, or " + + "categories) into which data points are grouped. Must be >= 2.", + ParamValidators.gtEq(2)) + setDefault(numBuckets -> 2) + + /** @group getParam */ + def getNumBuckets: Int = getOrDefault(numBuckets) +} + +/** + * :: Experimental :: + * `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned + * categorical features. The bin ranges are chosen by taking a sample of the data and dividing it + * into roughly equal parts. The lower and upper bin bounds will be -Infinity and +Infinity, + * covering all real values. This attempts to find numBuckets partitions based on a sample of data, + * but it may find fewer depending on the data sample values. + */ +@Experimental +final class QuantileDiscretizer(override val uid: String) + extends Estimator[Bucketizer] with QuantileDiscretizerBase with DefaultParamsWritable { + + def this() = this(Identifiable.randomUID("quantileDiscretizer")) + + /** @group setParam */ + def setNumBuckets(value: Int): this.type = set(numBuckets, value) + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def transformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType) + val inputFields = schema.fields + require(inputFields.forall(_.name != $(outputCol)), + s"Output column ${$(outputCol)} already exists.") + val attr = NominalAttribute.defaultAttr.withName($(outputCol)) + val outputFields = inputFields :+ attr.toStructField() + StructType(outputFields) + } + + override def fit(dataset: DataFrame): Bucketizer = { + val samples = QuantileDiscretizer.getSampledInput(dataset.select($(inputCol)), $(numBuckets)) + .map { case Row(feature: Double) => feature } + val candidates = QuantileDiscretizer.findSplitCandidates(samples, $(numBuckets) - 1) + val splits = QuantileDiscretizer.getSplits(candidates) + val bucketizer = new Bucketizer(uid).setSplits(splits) + copyValues(bucketizer) + } + + override def copy(extra: ParamMap): QuantileDiscretizer = defaultCopy(extra) +} + +@Since("1.6.0") +object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging { + /** + * Sampling from the given dataset to collect quantile statistics. 
+ */ + private[feature] def getSampledInput(dataset: DataFrame, numBins: Int): Array[Row] = { + val totalSamples = dataset.count() + require(totalSamples > 0, + "QuantileDiscretizer requires non-empty input dataset but was given an empty input.") + val requiredSamples = math.max(numBins * numBins, 10000) + val fraction = math.min(requiredSamples / dataset.count(), 1.0) + dataset.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()).collect() + } + + /** + * Compute split points with respect to the sample distribution. + */ + private[feature] + def findSplitCandidates(samples: Array[Double], numSplits: Int): Array[Double] = { + val valueCountMap = samples.foldLeft(Map.empty[Double, Int]) { (m, x) => + m + ((x, m.getOrElse(x, 0) + 1)) + } + val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray ++ Array((Double.MaxValue, 1)) + val possibleSplits = valueCounts.length - 1 + if (possibleSplits <= numSplits) { + valueCounts.dropRight(1).map(_._1) + } else { + val stride: Double = math.ceil(samples.length.toDouble / (numSplits + 1)) + val splitsBuilder = mutable.ArrayBuilder.make[Double] + var index = 1 + // currentCount: sum of counts of values that have been visited + var currentCount = valueCounts(0)._2 + // targetCount: target value for `currentCount`. If `currentCount` is closest value to + // `targetCount`, then current value is a split threshold. After finding a split threshold, + // `targetCount` is added by stride. + var targetCount = stride + while (index < valueCounts.length) { + val previousCount = currentCount + currentCount += valueCounts(index)._2 + val previousGap = math.abs(previousCount - targetCount) + val currentGap = math.abs(currentCount - targetCount) + // If adding count of current value to currentCount makes the gap between currentCount and + // targetCount smaller, previous value is a split threshold. + if (previousGap < currentGap) { + splitsBuilder += valueCounts(index - 1)._1 + targetCount += stride + } + index += 1 + } + splitsBuilder.result() + } + } + + /** + * Adjust split candidates to proper splits by: adding positive/negative infinity to both sides as + * needed, and adding a default split value of 0 if no good candidates are found. 
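A minimal end-to-end sketch of the discretizer described above: `fit` returns a Bucketizer whose splits were derived from the sampled quantiles. The DataFrame `df` with a DoubleType column "hour" is an assumption.

    import org.apache.spark.ml.feature.QuantileDiscretizer

    val discretizer = new QuantileDiscretizer()
      .setInputCol("hour")
      .setOutputCol("hourBucket")
      .setNumBuckets(3)

    val bucketizer = discretizer.fit(df)      // a Bucketizer with the learned splits
    val bucketed = bucketizer.transform(df)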
+ */ + private[feature] def getSplits(candidates: Array[Double]): Array[Double] = { + val effectiveValues = if (candidates.size != 0) { + if (candidates.head == Double.NegativeInfinity + && candidates.last == Double.PositiveInfinity) { + candidates.drop(1).dropRight(1) + } else if (candidates.head == Double.NegativeInfinity) { + candidates.drop(1) + } else if (candidates.last == Double.PositiveInfinity) { + candidates.dropRight(1) + } else { + candidates + } + } else { + candidates + } + + if (effectiveValues.size == 0) { + Array(Double.NegativeInfinity, 0, Double.PositiveInfinity) + } else { + Array(Double.NegativeInfinity) ++ effectiveValues ++ Array(Double.PositiveInfinity) + } + } + + @Since("1.6.0") + override def load(path: String): QuantileDiscretizer = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index d1726917e451..5c43a41bee3b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -19,27 +19,21 @@ package org.apache.spark.ml.feature import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import scala.util.parsing.combinator.RegexParsers import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.{Estimator, Model, Transformer, Pipeline, PipelineModel, PipelineStage} +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer} import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol} import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg.VectorUDT import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ /** * Base trait for [[RFormula]] and [[RFormulaModel]]. */ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol { - /** @group getParam */ - def setFeaturesCol(value: String): this.type = set(featuresCol, value) - - /** @group getParam */ - def setLabelCol(value: String): this.type = set(labelCol, value) protected def hasLabelCol(schema: StructType): Boolean = { schema.map(_.name).contains($(labelCol)) @@ -49,8 +43,8 @@ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol { /** * :: Experimental :: * Implements the transforms required for fitting a dataset against an R model formula. Currently - * we support a limited subset of the R operators, including '~' and '+'. Also see the R formula - * docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html + * we support a limited subset of the R operators, including '~', '.', ':', '+', and '-'. Also see + * the R formula docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html */ @Experimental class RFormula(override val uid: String) extends Estimator[RFormulaModel] with RFormulaBase { @@ -63,62 +57,89 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R */ val formula: Param[String] = new Param(this, "formula", "R model formula") - private var parsedFormula: Option[ParsedRFormula] = None - /** * Sets the formula to use for this transformer. Must be called before use. * @group setParam * @param value an R formula in string form (e.g. 
"y ~ x + z") */ - def setFormula(value: String): this.type = { - parsedFormula = Some(RFormulaParser.parse(value)) - set(formula, value) - this - } + def setFormula(value: String): this.type = set(formula, value) /** @group getParam */ def getFormula: String = $(formula) + /** @group setParam */ + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + def setLabelCol(value: String): this.type = set(labelCol, value) + /** Whether the formula specifies fitting an intercept. */ private[ml] def hasIntercept: Boolean = { - require(parsedFormula.isDefined, "Must call setFormula() first.") - parsedFormula.get.hasIntercept + require(isDefined(formula), "Formula must be defined first.") + RFormulaParser.parse($(formula)).hasIntercept } override def fit(dataset: DataFrame): RFormulaModel = { - require(parsedFormula.isDefined, "Must call setFormula() first.") - val resolvedFormula = parsedFormula.get.resolve(dataset.schema) - // StringType terms and terms representing interactions need to be encoded before assembly. - // TODO(ekl) add support for feature interactions + require(isDefined(formula), "Formula must be defined first.") + val parsedFormula = RFormulaParser.parse($(formula)) + val resolvedFormula = parsedFormula.resolve(dataset.schema) val encoderStages = ArrayBuffer[PipelineStage]() + + val prefixesToRewrite = mutable.Map[String, String]() val tempColumns = ArrayBuffer[String]() - val takenNames = mutable.Set(dataset.columns: _*) - val encodedTerms = resolvedFormula.terms.map { term => + def tmpColumn(category: String): String = { + val col = Identifiable.randomUID(category) + tempColumns += col + col + } + + // First we index each string column referenced by the input terms. + val indexed: Map[String, String] = resolvedFormula.terms.flatten.distinct.map { term => dataset.schema(term) match { case column if column.dataType == StringType => - val indexCol = term + "_idx_" + uid - val encodedCol = { - var tmp = term - while (takenNames.contains(tmp)) { - tmp += "_" - } - tmp - } - takenNames.add(indexCol) - takenNames.add(encodedCol) - encoderStages += new StringIndexer().setInputCol(term).setOutputCol(indexCol) - encoderStages += new OneHotEncoder().setInputCol(indexCol).setOutputCol(encodedCol) - tempColumns += indexCol - tempColumns += encodedCol - encodedCol + val indexCol = tmpColumn("stridx") + encoderStages += new StringIndexer() + .setInputCol(term) + .setOutputCol(indexCol) + (term, indexCol) case _ => - term + (term, term) } + }.toMap + + // Then we handle one-hot encoding and interactions between terms. 
+ val encodedTerms = resolvedFormula.terms.map { + case Seq(term) if dataset.schema(term).dataType == StringType => + val encodedCol = tmpColumn("onehot") + encoderStages += new OneHotEncoder() + .setInputCol(indexed(term)) + .setOutputCol(encodedCol) + prefixesToRewrite(encodedCol + "_") = term + "_" + encodedCol + case Seq(term) => + term + case terms => + val interactionCol = tmpColumn("interaction") + encoderStages += new Interaction() + .setInputCols(terms.map(indexed).toArray) + .setOutputCol(interactionCol) + prefixesToRewrite(interactionCol + "_") = "" + interactionCol } + encoderStages += new VectorAssembler(uid) .setInputCols(encodedTerms.toArray) .setOutputCol($(featuresCol)) + encoderStages += new VectorAttributeRewriter($(featuresCol), prefixesToRewrite.toMap) encoderStages += new ColumnPruner(tempColumns.toSet) + + if (dataset.schema.fieldNames.contains(resolvedFormula.label) && + dataset.schema(resolvedFormula.label).dataType == StringType) { + encoderStages += new StringIndexer() + .setInputCol(resolvedFormula.label) + .setOutputCol($(labelCol)) + } + val pipelineModel = new Pipeline(uid).setStages(encoderStages.toArray).fit(dataset) copyValues(new RFormulaModel(uid, resolvedFormula, pipelineModel).setParent(this)) } @@ -135,7 +156,7 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R override def copy(extra: ParamMap): RFormula = defaultCopy(extra) - override def toString: String = s"RFormula(${get(formula)})" + override def toString: String = s"RFormula(${get(formula)}) (uid=$uid)" } /** @@ -159,7 +180,7 @@ class RFormulaModel private[feature]( override def transformSchema(schema: StructType): StructType = { checkCanTransform(schema) val withFeatures = pipelineModel.transformSchema(schema) - if (hasLabelCol(schema)) { + if (hasLabelCol(withFeatures)) { withFeatures } else if (schema.exists(_.name == resolvedFormula.label)) { val nullable = schema(resolvedFormula.label).dataType match { @@ -177,7 +198,7 @@ class RFormulaModel private[feature]( override def copy(extra: ParamMap): RFormulaModel = copyValues( new RFormulaModel(uid, resolvedFormula, pipelineModel)) - override def toString: String = s"RFormulaModel(${resolvedFormula})" + override def toString: String = s"RFormulaModel(${resolvedFormula}) (uid=$uid)" private def transformLabel(dataset: DataFrame): DataFrame = { val labelName = resolvedFormula.label @@ -224,3 +245,53 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer { override def copy(extra: ParamMap): ColumnPruner = defaultCopy(extra) } + +/** + * Utility transformer that rewrites Vector attribute names via prefix replacement. For example, + * it can rewrite attribute names starting with 'foo_' to start with 'bar_' instead. + * + * @param vectorCol name of the vector column to rewrite. + * @param prefixesToRewrite the map of string prefixes to their replacement values. Each attribute + * name defined in vectorCol will be checked against the keys of this + * map. When a key prefixes a name, the matching prefix will be replaced + * by the value in the map. 
+ */ +private class VectorAttributeRewriter( + vectorCol: String, + prefixesToRewrite: Map[String, String]) + extends Transformer { + + override val uid = Identifiable.randomUID("vectorAttrRewriter") + + override def transform(dataset: DataFrame): DataFrame = { + val metadata = { + val group = AttributeGroup.fromStructField(dataset.schema(vectorCol)) + val attrs = group.attributes.get.map { attr => + if (attr.name.isDefined) { + val name = attr.name.get + val replacement = prefixesToRewrite.filter { case (k, _) => name.startsWith(k) } + if (replacement.nonEmpty) { + val (k, v) = replacement.headOption.get + attr.withName(v + name.stripPrefix(k)) + } else { + attr + } + } else { + attr + } + } + new AttributeGroup(vectorCol, attrs).toMetadata() + } + val otherCols = dataset.columns.filter(_ != vectorCol).map(dataset.col) + val rewrittenCol = dataset.col(vectorCol).as(vectorCol, metadata) + dataset.select((otherCols :+ rewrittenCol): _*) + } + + override def transformSchema(schema: StructType): StructType = { + StructType( + schema.fields.filter(_.name != vectorCol) ++ + schema.fields.filter(_.name == vectorCol)) + } + + override def copy(extra: ParamMap): VectorAttributeRewriter = defaultCopy(extra) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala index 1ca3b92a7d92..4079b387e183 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.feature +import scala.collection.mutable import scala.util.parsing.combinator.RegexParsers import org.apache.spark.mllib.linalg.VectorUDT @@ -31,27 +32,35 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) { * of the special '.' term. Duplicate terms will be removed during resolution. */ def resolve(schema: StructType): ResolvedRFormula = { - var includedTerms = Seq[String]() + val dotTerms = expandDot(schema) + var includedTerms = Seq[Seq[String]]() terms.foreach { + case col: ColumnRef => + includedTerms :+= Seq(col.value) + case ColumnInteraction(cols) => + includedTerms ++= expandInteraction(schema, cols) case Dot => - includedTerms ++= simpleTypes(schema).filter(_ != label.value) - case ColumnRef(value) => - includedTerms :+= value + includedTerms ++= dotTerms.map(Seq(_)) case Deletion(term: Term) => term match { - case ColumnRef(value) => - includedTerms = includedTerms.filter(_ != value) + case inner: ColumnRef => + includedTerms = includedTerms.filter(_ != Seq(inner.value)) + case ColumnInteraction(cols) => + val fromInteraction = expandInteraction(schema, cols).map(_.toSet) + includedTerms = includedTerms.filter(t => !fromInteraction.contains(t.toSet)) case Dot => // e.g. "- .", which removes all first-order terms - val fromSchema = simpleTypes(schema) - includedTerms = includedTerms.filter(fromSchema.contains(_)) + includedTerms = includedTerms.filter { + case Seq(t) => !dotTerms.contains(t) + case _ => true + } case _: Deletion => - assert(false, "Deletion terms cannot be nested") + throw new RuntimeException("Deletion terms cannot be nested") case _: Intercept => } case _: Intercept => } - ResolvedRFormula(label.value, includedTerms.distinct) + ResolvedRFormula(label.value, includedTerms.distinct, hasIntercept) } /** Whether this formula specifies fitting with an intercept term. 
*/ @@ -67,19 +76,54 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) { intercept } + // expands the Dot operators in interaction terms + private def expandInteraction( + schema: StructType, terms: Seq[InteractableTerm]): Seq[Seq[String]] = { + if (terms.isEmpty) { + return Seq(Nil) + } + + val rest = expandInteraction(schema, terms.tail) + val validInteractions = (terms.head match { + case Dot => + expandDot(schema).flatMap { t => + rest.map { r => + Seq(t) ++ r + } + } + case ColumnRef(value) => + rest.map(Seq(value) ++ _) + }).map(_.distinct) + + // Deduplicates feature interactions, for example, a:b is the same as b:a. + var seen = mutable.Set[Set[String]]() + validInteractions.flatMap { + case t if seen.contains(t.toSet) => + None + case t => + seen += t.toSet + Some(t) + }.sortBy(_.length) + } + // the dot operator excludes complex column types - private def simpleTypes(schema: StructType): Seq[String] = { + private def expandDot(schema: StructType): Seq[String] = { schema.fields.filter(_.dataType match { case _: NumericType | StringType | BooleanType | _: VectorUDT => true case _ => false - }).map(_.name) + }).map(_.name).filter(_ != label.value) } } /** * Represents a fully evaluated and simplified R formula. + * @param label the column name of the R formula label (response variable). + * @param terms the simplified terms of the R formula. Interactions terms are represented as Seqs + * of column names; non-interaction terms as length 1 Seqs. + * @param hasIntercept whether the formula specifies fitting with an intercept. */ -private[ml] case class ResolvedRFormula(label: String, terms: Seq[String]) +private[ml] case class ResolvedRFormula( + label: String, terms: Seq[Seq[String]], hasIntercept: Boolean) /** * R formula terms. See the R formula docs here for more information: @@ -87,11 +131,17 @@ private[ml] case class ResolvedRFormula(label: String, terms: Seq[String]) */ private[ml] sealed trait Term +/** A term that may be part of an interaction, e.g. 'x' in 'x:y' */ +private[ml] sealed trait InteractableTerm extends Term + /* R formula reference to all available columns, e.g. "." in a formula */ -private[ml] case object Dot extends Term +private[ml] case object Dot extends InteractableTerm /* R formula reference to a column, e.g. "+ Species" in a formula */ -private[ml] case class ColumnRef(value: String) extends Term +private[ml] case class ColumnRef(value: String) extends InteractableTerm + +/* R formula interaction of several columns, e.g. "Sepal_Length:Species" in a formula */ +private[ml] case class ColumnInteraction(terms: Seq[InteractableTerm]) extends Term /* R formula intercept toggle, e.g. "+ 0" in a formula */ private[ml] case class Intercept(enabled: Boolean) extends Term @@ -100,25 +150,30 @@ private[ml] case class Intercept(enabled: Boolean) extends Term private[ml] case class Deletion(term: Term) extends Term /** - * Limited implementation of R formula parsing. Currently supports: '~', '+', '-', '.'. + * Limited implementation of R formula parsing. Currently supports: '~', '+', '-', '.', ':'. 
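With the ':' operator now supported, an interaction can be written directly in the formula string. An illustrative sketch, with assumed column names: "y" is a numeric label, "a" is numeric, and "b" is a string column that gets indexed and then crossed with "a".

    import org.apache.spark.ml.feature.RFormula

    val formula = new RFormula()
      .setFormula("y ~ a + a:b")
      .setFeaturesCol("features")
      .setLabelCol("label")

    val prepared = formula.fit(df).transform(df)   // adds "features" and "label" columns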
*/ private[ml] object RFormulaParser extends RegexParsers { - def intercept: Parser[Intercept] = + private val intercept: Parser[Intercept] = "([01])".r ^^ { case a => Intercept(a == "1") } - def columnRef: Parser[ColumnRef] = + private val columnRef: Parser[ColumnRef] = "([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r ^^ { case a => ColumnRef(a) } - def term: Parser[Term] = intercept | columnRef | "\\.".r ^^ { case _ => Dot } + private val dot: Parser[InteractableTerm] = "\\.".r ^^ { case _ => Dot } + + private val interaction: Parser[List[InteractableTerm]] = rep1sep(columnRef | dot, ":") + + private val term: Parser[Term] = intercept | + interaction ^^ { case terms => ColumnInteraction(terms) } | dot | columnRef - def terms: Parser[List[Term]] = (term ~ rep("+" ~ term | "-" ~ term)) ^^ { + private val terms: Parser[List[Term]] = (term ~ rep("+" ~ term | "-" ~ term)) ^^ { case op ~ list => list.foldLeft(List(op)) { case (left, "+" ~ right) => left ++ Seq(right) case (left, "-" ~ right) => left ++ Seq(Deletion(right)) } } - def formula: Parser[ParsedRFormula] = + private val formula: Parser[ParsedRFormula] = (columnRef ~ "~" ~ terms) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) } def parse(value: String): ParsedRFormula = parseAll(formula, value) match { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala new file mode 100644 index 000000000000..c09f4d076c96 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkContext +import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.ml.param.{ParamMap, Param} +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{SQLContext, DataFrame, Row} +import org.apache.spark.sql.types.StructType + +/** + * :: Experimental :: + * Implements the transformations which are defined by SQL statement. + * Currently we only support SQL syntax like 'SELECT ... FROM __THIS__ ...' + * where '__THIS__' represents the underlying table of the input dataset. + * The select clause specifies the fields, constants, and expressions to display in + * the output, it can be any select clause that Spark SQL supports. Users can also + * use Spark SQL built-in function and UDFs to operate on these selected columns. 
+ * For example, [[SQLTransformer]] supports statements like: + * - SELECT a, a + b AS a_b FROM __THIS__ + * - SELECT a, SQRT(b) AS b_sqrt FROM __THIS__ where a > 5 + * - SELECT a, b, SUM(c) AS c_sum FROM __THIS__ GROUP BY a, b + */ +@Experimental +@Since("1.6.0") +class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transformer + with DefaultParamsWritable { + + @Since("1.6.0") + def this() = this(Identifiable.randomUID("sql")) + + /** + * SQL statement parameter. The statement is provided in string form. + * @group param + */ + @Since("1.6.0") + final val statement: Param[String] = new Param[String](this, "statement", "SQL statement") + + /** @group setParam */ + @Since("1.6.0") + def setStatement(value: String): this.type = set(statement, value) + + /** @group getParam */ + @Since("1.6.0") + def getStatement: String = $(statement) + + private val tableIdentifier: String = "__THIS__" + + @Since("1.6.0") + override def transform(dataset: DataFrame): DataFrame = { + val tableName = Identifiable.randomUID(uid) + dataset.registerTempTable(tableName) + val realStatement = $(statement).replace(tableIdentifier, tableName) + val outputDF = dataset.sqlContext.sql(realStatement) + outputDF + } + + @Since("1.6.0") + override def transformSchema(schema: StructType): StructType = { + val sc = SparkContext.getOrCreate() + val sqlContext = SQLContext.getOrCreate(sc) + val dummyRDD = sc.parallelize(Seq(Row.empty)) + val dummyDF = sqlContext.createDataFrame(dummyRDD, schema) + dummyDF.registerTempTable(tableIdentifier) + val outputSchema = sqlContext.sql($(statement)).schema + outputSchema + } + + @Since("1.6.0") + override def copy(extra: ParamMap): SQLTransformer = defaultCopy(extra) +} + +@Since("1.6.0") +object SQLTransformer extends DefaultParamsReadable[SQLTransformer] { + + @Since("1.6.0") + override def load(path: String): SQLTransformer = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 72b545e5db3e..d76a9c6275e6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -17,11 +17,13 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql._ @@ -34,20 +36,30 @@ import org.apache.spark.sql.types.{StructField, StructType} private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol { /** - * Centers the data with mean before scaling. + * Whether to center the data with mean before scaling. * It will build a dense output, so this does not work on sparse input * and will raise an exception. * Default: false * @group param */ - val withMean: BooleanParam = new BooleanParam(this, "withMean", "Center data with mean") + val withMean: BooleanParam = new BooleanParam(this, "withMean", + "Whether to center data with mean") + + /** @group getParam */ + def getWithMean: Boolean = $(withMean) /** - * Scales the data to unit standard deviation. + * Whether to scale the data to unit standard deviation. 
* Default: true * @group param */ - val withStd: BooleanParam = new BooleanParam(this, "withStd", "Scale to unit standard deviation") + val withStd: BooleanParam = new BooleanParam(this, "withStd", + "Whether to scale the data to unit standard deviation") + + /** @group getParam */ + def getWithStd: Boolean = $(withStd) + + setDefault(withMean -> false, withStd -> true) } /** @@ -57,12 +69,10 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with */ @Experimental class StandardScaler(override val uid: String) extends Estimator[StandardScalerModel] - with StandardScalerParams { + with StandardScalerParams with DefaultParamsWritable { def this() = this(Identifiable.randomUID("stdScal")) - setDefault(withMean -> false, withStd -> true) - /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -80,7 +90,7 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v } val scaler = new feature.StandardScaler(withMean = $(withMean), withStd = $(withStd)) val scalerModel = scaler.fit(input) - copyValues(new StandardScalerModel(uid, scalerModel).setParent(this)) + copyValues(new StandardScalerModel(uid, scalerModel.std, scalerModel.mean).setParent(this)) } override def transformSchema(schema: StructType): StructType = { @@ -96,21 +106,28 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM override def copy(extra: ParamMap): StandardScaler = defaultCopy(extra) } +@Since("1.6.0") +object StandardScaler extends DefaultParamsReadable[StandardScaler] { + + @Since("1.6.0") + override def load(path: String): StandardScaler = super.load(path) +} + /** * :: Experimental :: * Model fitted by [[StandardScaler]]. 
+ * + * @param std Standard deviation of the StandardScalerModel + * @param mean Mean of the StandardScalerModel */ @Experimental class StandardScalerModel private[ml] ( override val uid: String, - scaler: feature.StandardScalerModel) - extends Model[StandardScalerModel] with StandardScalerParams { - - /** Standard deviation of the StandardScalerModel */ - val std: Vector = scaler.std + val std: Vector, + val mean: Vector) + extends Model[StandardScalerModel] with StandardScalerParams with MLWritable { - /** Mean of the StandardScalerModel */ - val mean: Vector = scaler.mean + import StandardScalerModel._ /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -120,6 +137,7 @@ class StandardScalerModel private[ml] ( override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) + val scaler = new feature.StandardScalerModel(std, mean, $(withStd), $(withMean)) val scale = udf { scaler.transform _ } dataset.withColumn($(outputCol), scale(col($(inputCol)))) } @@ -135,7 +153,49 @@ class StandardScalerModel private[ml] ( } override def copy(extra: ParamMap): StandardScalerModel = { - val copied = new StandardScalerModel(uid, scaler) - copyValues(copied, extra) + val copied = new StandardScalerModel(uid, std, mean) + copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: MLWriter = new StandardScalerModelWriter(this) +} + +@Since("1.6.0") +object StandardScalerModel extends MLReadable[StandardScalerModel] { + + private[StandardScalerModel] + class StandardScalerModelWriter(instance: StandardScalerModel) extends MLWriter { + + private case class Data(std: Vector, mean: Vector) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.std, instance.mean) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class StandardScalerModelReader extends MLReader[StandardScalerModel] { + + private val className = classOf[StandardScalerModel].getName + + override def load(path: String): StandardScalerModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val Row(std: Vector, mean: Vector) = sqlContext.read.parquet(dataPath) + .select("std", "mean") + .head() + val model = new StandardScalerModel(metadata.uid, std, mean) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[StandardScalerModel] = new StandardScalerModelReader + + @Since("1.6.0") + override def load(path: String): StandardScalerModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 3cc41424460f..318808596dc6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -17,26 +17,26 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.Transformer +import org.apache.spark.ml.param.{BooleanParam, ParamMap, StringArrayParam} import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.param.{ParamMap, BooleanParam, Param} -import 
org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.types.{StringType, StructField, ArrayType, StructType} import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType} /** * stop words list */ -private object StopWords { +private[spark] object StopWords { /** * Use the same default stopwords list as scikit-learn. * The original list can be found from "Glasgow Information Retrieval Group" * [[http://ir.dcs.gla.ac.uk/resources/linguistic_utils/stop_words]] */ - val EnglishStopWords = Array( "a", "about", "above", "across", "after", "afterwards", "again", + val English = Array( "a", "about", "above", "across", "after", "afterwards", "again", "against", "all", "almost", "alone", "along", "already", "also", "although", "always", "am", "among", "amongst", "amoungst", "amount", "an", "and", "another", "any", "anyhow", "anyone", "anything", "anyway", "anywhere", "are", @@ -86,7 +86,7 @@ private object StopWords { */ @Experimental class StopWordsRemover(override val uid: String) - extends Transformer with HasInputCol with HasOutputCol { + extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("stopWords")) @@ -98,9 +98,10 @@ class StopWordsRemover(override val uid: String) /** * the stop words set to be filtered out + * Default: [[StopWords.English]] * @group param */ - val stopWords: Param[Array[String]] = new Param(this, "stopWords", "stop words") + val stopWords: StringArrayParam = new StringArrayParam(this, "stopWords", "stop words") /** @group setParam */ def setStopWords(value: Array[String]): this.type = set(stopWords, value) @@ -110,6 +111,7 @@ class StopWordsRemover(override val uid: String) /** * whether to do a case sensitive comparison over the stop words + * Default: false * @group param */ val caseSensitive: BooleanParam = new BooleanParam(this, "caseSensitive", @@ -121,7 +123,7 @@ class StopWordsRemover(override val uid: String) /** @group getParam */ def getCaseSensitive: Boolean = $(caseSensitive) - setDefault(stopWords -> StopWords.EnglishStopWords, caseSensitive -> false) + setDefault(stopWords -> StopWords.English, caseSensitive -> false) override def transform(dataset: DataFrame): DataFrame = { val outputSchema = transformSchema(dataset.schema) @@ -153,3 +155,10 @@ class StopWordsRemover(override val uid: String) override def copy(extra: ParamMap): StopWordsRemover = defaultCopy(extra) } + +@Since("1.6.0") +object StopWordsRemover extends DefaultParamsReadable[StopWordsRemover] { + + @Since("1.6.0") + override def load(path: String): StopWordsRemover = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index ebfa97253235..5c40c35eeaa4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -17,23 +17,25 @@ package org.apache.spark.ml.feature +import org.apache.hadoop.fs.Path + import org.apache.spark.SparkException -import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.{Estimator, Model, Transformer} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param._ 
import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.Transformer -import org.apache.spark.ml.util.{Identifiable, MetadataUtils} +import org.apache.spark.ml.util._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{DoubleType, NumericType, StringType, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashMap /** * Base trait for [[StringIndexer]] and [[StringIndexerModel]]. */ -private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol { +private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol + with HasHandleInvalid { /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { @@ -58,20 +60,25 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha * If the input column is numeric, we cast it to string and index the string values. * The indices are in [0, numLabels), ordered by label frequencies. * So the most frequent label gets index 0. + * + * @see [[IndexToString]] for the inverse transformation */ @Experimental class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel] - with StringIndexerBase { + with StringIndexerBase with DefaultParamsWritable { def this() = this(Identifiable.randomUID("strIdx")) + /** @group setParam */ + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + setDefault(handleInvalid, "error") + /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - // TODO: handle unseen labels override def fit(dataset: DataFrame): StringIndexerModel = { val counts = dataset.select(col($(inputCol)).cast(StringType)) @@ -88,17 +95,32 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod override def copy(extra: ParamMap): StringIndexer = defaultCopy(extra) } +@Since("1.6.0") +object StringIndexer extends DefaultParamsReadable[StringIndexer] { + + @Since("1.6.0") + override def load(path: String): StringIndexer = super.load(path) +} + /** * :: Experimental :: * Model fitted by [[StringIndexer]]. + * * NOTE: During transformation, if the input column does not exist, * [[StringIndexerModel.transform]] would return the input dataset unmodified. * This is a temporary fix for the case when target labels do not exist during prediction. + * + * @param labels Ordered list of labels, corresponding to indices to be assigned. 
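+ * For example, a model built with labels Array("a", "b", "c") maps the input value "b" to 1.0.
+ * When handleInvalid is set to "skip", rows whose input value is not among the labels are
+ * filtered out during transform; with the default "error", an unseen label raises an exception.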
*/ @Experimental -class StringIndexerModel private[ml] ( +class StringIndexerModel ( override val uid: String, - labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase { + val labels: Array[String]) + extends Model[StringIndexerModel] with StringIndexerBase with MLWritable { + + import StringIndexerModel._ + + def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), labels) private val labelToIndex: OpenHashMap[String, Double] = { val n = labels.length @@ -111,6 +133,10 @@ class StringIndexerModel private[ml] ( map } + /** @group setParam */ + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + setDefault(handleInvalid, "error") + /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -128,15 +154,24 @@ class StringIndexerModel private[ml] ( if (labelToIndex.contains(label)) { labelToIndex(label) } else { - // TODO: handle unseen labels throw new SparkException(s"Unseen label: $label.") } } - val outputColName = $(outputCol) + val metadata = NominalAttribute.defaultAttr - .withName(outputColName).withValues(labels).toMetadata() - dataset.select(col("*"), - indexer(dataset($(inputCol)).cast(StringType)).as(outputColName, metadata)) + .withName($(inputCol)).withValues(labels).toMetadata() + // If we are skipping invalid records, filter them out. + val filteredDataset = (getHandleInvalid) match { + case "skip" => { + val filterer = udf { label: String => + labelToIndex.contains(label) + } + dataset.where(filterer(dataset($(inputCol)))) + } + case _ => dataset + } + filteredDataset.select(col("*"), + indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata)) } override def transformSchema(schema: StructType): StructType = { @@ -150,36 +185,68 @@ class StringIndexerModel private[ml] ( override def copy(extra: ParamMap): StringIndexerModel = { val copied = new StringIndexerModel(uid, labels) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } - /** - * Return a model to perform the inverse transformation. - * Note: By default we keep the original columns during this transformation, so the inverse - * should only be used on new columns such as predicted labels. 
- */ - def invert(inputCol: String, outputCol: String): StringIndexerInverse = { - new StringIndexerInverse() - .setInputCol(inputCol) - .setOutputCol(outputCol) - .setLabels(labels) + @Since("1.6.0") + override def write: StringIndexModelWriter = new StringIndexModelWriter(this) +} + +@Since("1.6.0") +object StringIndexerModel extends MLReadable[StringIndexerModel] { + + private[StringIndexerModel] + class StringIndexModelWriter(instance: StringIndexerModel) extends MLWriter { + + private case class Data(labels: Array[String]) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.labels) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } } + + private class StringIndexerModelReader extends MLReader[StringIndexerModel] { + + private val className = classOf[StringIndexerModel].getName + + override def load(path: String): StringIndexerModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("labels") + .head() + val labels = data.getAs[Seq[String]](0).toArray + val model = new StringIndexerModel(metadata.uid, labels) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[StringIndexerModel] = new StringIndexerModelReader + + @Since("1.6.0") + override def load(path: String): StringIndexerModel = super.load(path) } /** * :: Experimental :: - * Transform a provided column back to the original input types using either the metadata - * on the input column, or if provided using the labels supplied by the user. - * Note: By default we keep the original columns during this transformation, - * so the inverse should only be used on new columns such as predicted labels. + * A [[Transformer]] that maps a column of indices back to a new column of corresponding + * string values. + * The index-string mapping is either from the ML attributes of the input column, + * or from user-supplied labels (which take precedence over ML attributes). + * + * @see [[StringIndexer]] for converting strings into indices */ @Experimental -class StringIndexerInverse private[ml] ( - override val uid: String) extends Transformer - with HasInputCol with HasOutputCol { +class IndexToString private[ml] (override val uid: String) + extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = - this(Identifiable.randomUID("strIdxInv")) + this(Identifiable.randomUID("idxToStr")) /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -187,32 +254,23 @@ class StringIndexerInverse private[ml] ( /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - /** - * Optional labels to be provided by the user, if not supplied column - * metadata is read for labels. The default value is an empty array, - * but the empty array is ignored and column metadata used instead. - * @group setParam - */ + /** @group setParam */ def setLabels(value: Array[String]): this.type = set(labels, value) /** - * Param for array of labels. - * Optional labels to be provided by the user, if not supplied column - * metadata is read for labels. + * Optional param for array of labels specifying index-string mapping. + * + * Default: Empty array, in which case [[inputCol]] metadata is used for labels. 
* @group param */ final val labels: StringArrayParam = new StringArrayParam(this, "labels", - "array of labels, if not provided metadata from inputCol is used instead.") + "Optional array of labels specifying index-string mapping." + + " If not provided or if empty, then metadata from inputCol is used instead.") setDefault(labels, Array.empty[String]) - /** - * Optional labels to be provided by the user, if not supplied column - * metadata is read for labels. - * @group getParam - */ + /** @group getParam */ final def getLabels: Array[String] = $(labels) - /** Transform the schema for the inverse transformation */ override def transformSchema(schema: StructType): StructType = { val inputColName = $(inputCol) val inputDataType = schema(inputColName).dataType @@ -223,8 +281,7 @@ class StringIndexerInverse private[ml] ( val outputColName = $(outputCol) require(inputFields.forall(_.name != outputColName), s"Output column $outputColName already exists.") - val attr = NominalAttribute.defaultAttr.withName($(outputCol)) - val outputFields = inputFields :+ attr.toStructField() + val outputFields = inputFields :+ StructField($(outputCol), StringType) StructType(outputFields) } @@ -239,7 +296,7 @@ class StringIndexerInverse private[ml] ( } val indexer = udf { index: Double => val idx = index.toInt - if (0 <= idx && idx < values.size) { + if (0 <= idx && idx < values.length) { values(idx) } else { throw new SparkException(s"Unseen index: $index ??") @@ -250,7 +307,14 @@ class StringIndexerInverse private[ml] ( indexer(dataset($(inputCol)).cast(DoubleType)).as(outputColName)) } - override def copy(extra: ParamMap): StringIndexerInverse = { + override def copy(extra: ParamMap): IndexToString = { defaultCopy(extra) } } + +@Since("1.6.0") +object IndexToString extends DefaultParamsReadable[IndexToString] { + + @Since("1.6.0") + override def load(path: String): IndexToString = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 248288ca73e9..8ad7bbedaab5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -17,10 +17,10 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.sql.types.{ArrayType, DataType, StringType} /** @@ -30,7 +30,8 @@ import org.apache.spark.sql.types.{ArrayType, DataType, StringType} * @see [[RegexTokenizer]] */ @Experimental -class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[String], Tokenizer] { +class Tokenizer(override val uid: String) + extends UnaryTransformer[String, Seq[String], Tokenizer] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("tok")) @@ -47,6 +48,13 @@ class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[S override def copy(extra: ParamMap): Tokenizer = defaultCopy(extra) } +@Since("1.6.0") +object Tokenizer extends DefaultParamsReadable[Tokenizer] { + + @Since("1.6.0") + override def load(path: String): Tokenizer = super.load(path) +} + /** * :: Experimental :: * A regex based tokenizer that extracts tokens either by using the provided regex pattern to split @@ -56,7 +64,7 @@ class Tokenizer(override val 
uid: String) extends UnaryTransformer[String, Seq[S */ @Experimental class RegexTokenizer(override val uid: String) - extends UnaryTransformer[String, Seq[String], RegexTokenizer] { + extends UnaryTransformer[String, Seq[String], RegexTokenizer] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("regexTok")) @@ -100,10 +108,25 @@ class RegexTokenizer(override val uid: String) /** @group getParam */ def getPattern: String = $(pattern) - setDefault(minTokenLength -> 1, gaps -> true, pattern -> "\\s+") + /** + * Indicates whether to convert all characters to lowercase before tokenizing. + * Default: true + * @group param + */ + final val toLowercase: BooleanParam = new BooleanParam(this, "toLowercase", + "whether to convert all characters to lowercase before tokenizing.") - override protected def createTransformFunc: String => Seq[String] = { str => + /** @group setParam */ + def setToLowercase(value: Boolean): this.type = set(toLowercase, value) + + /** @group getParam */ + def getToLowercase: Boolean = $(toLowercase) + + setDefault(minTokenLength -> 1, gaps -> true, pattern -> "\\s+", toLowercase -> true) + + override protected def createTransformFunc: String => Seq[String] = { originStr => val re = $(pattern).r + val str = if ($(toLowercase)) originStr.toLowerCase() else originStr val tokens = if ($(gaps)) re.split(str).toSeq else re.findAllIn(str).toSeq val minLength = $(minTokenLength) tokens.filter(_.length >= minLength) @@ -117,3 +140,10 @@ class RegexTokenizer(override val uid: String) override def copy(extra: ParamMap): RegexTokenizer = defaultCopy(extra) } + +@Since("1.6.0") +object RegexTokenizer extends DefaultParamsReadable[RegexTokenizer] { + + @Since("1.6.0") + override def load(path: String): RegexTokenizer = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 086917fa680f..801096fed27b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -20,12 +20,12 @@ package org.apache.spark.ml.feature import scala.collection.mutable.ArrayBuilder import org.apache.spark.SparkException -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ @@ -37,7 +37,7 @@ import org.apache.spark.sql.types._ */ @Experimental class VectorAssembler(override val uid: String) - extends Transformer with HasInputCols with HasOutputCol { + extends Transformer with HasInputCols with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("vecAssembler")) @@ -84,6 +84,8 @@ class VectorAssembler(override val uid: String) val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size) Array.fill(numAttrs)(NumericAttribute.defaultAttr) } + case otherType => + throw new SparkException(s"VectorAssembler does not support the $otherType type") } } val metadata = new AttributeGroup($(outputCol), attrs).toMetadata() @@ -122,7 
+124,11 @@ class VectorAssembler(override val uid: String) override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra) } -private object VectorAssembler { +@Since("1.6.0") +object VectorAssembler extends DefaultParamsReadable[VectorAssembler] { + + @Since("1.6.0") + override def load(path: String): VectorAssembler = super.load(path) private[feature] def assemble(vv: Any*): Vector = { val indices = ArrayBuilder.make[Int] diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index c73bdccdef5f..a637a6f2881d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -22,12 +22,14 @@ import java.util.{Map => JMap} import scala.collection.JavaConverters._ -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators, Params} +import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, VectorUDT} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions.udf @@ -43,6 +45,7 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu * Must be >= 2. * * (default = 20) + * @group param */ val maxCategories = new IntParam(this, "maxCategories", "Threshold for the number of values a categorical feature can take (>= 2)." + @@ -92,7 +95,7 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu */ @Experimental class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerModel] - with VectorIndexerParams { + with VectorIndexerParams with DefaultParamsWritable { def this() = this(Identifiable.randomUID("vecIdx")) @@ -135,7 +138,11 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod override def copy(extra: ParamMap): VectorIndexer = defaultCopy(extra) } -private object VectorIndexer { +@Since("1.6.0") +object VectorIndexer extends DefaultParamsReadable[VectorIndexer] { + + @Since("1.6.0") + override def load(path: String): VectorIndexer = super.load(path) /** * Helper class for tracking unique values for each feature. @@ -145,7 +152,7 @@ private object VectorIndexer { * @param numFeatures This class fails if it encounters a Vector whose length is not numFeatures. * @param maxCategories This class caps the number of unique values collected at maxCategories. 
*/ - class CategoryStats(private val numFeatures: Int, private val maxCategories: Int) + private class CategoryStats(private val numFeatures: Int, private val maxCategories: Int) extends Serializable { /** featureValueSets[feature index] = set of unique values */ @@ -251,7 +258,9 @@ class VectorIndexerModel private[ml] ( override val uid: String, val numFeatures: Int, val categoryMaps: Map[Int, Map[Double, Int]]) - extends Model[VectorIndexerModel] with VectorIndexerParams { + extends Model[VectorIndexerModel] with VectorIndexerParams with MLWritable { + + import VectorIndexerModel._ /** Java-friendly version of [[categoryMaps]] */ def javaCategoryMaps: JMap[JInt, JMap[JDouble, JInt]] = { @@ -341,7 +350,7 @@ class VectorIndexerModel private[ml] ( val newField = prepOutputField(dataset.schema) val transformUDF = udf { (vector: Vector) => transformFunc(vector) } val newCol = transformUDF(dataset($(inputCol))) - dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata)) + dataset.withColumn($(outputCol), newCol, newField.metadata) } override def transformSchema(schema: StructType): StructType = { @@ -405,6 +414,50 @@ class VectorIndexerModel private[ml] ( override def copy(extra: ParamMap): VectorIndexerModel = { val copied = new VectorIndexerModel(uid, numFeatures, categoryMaps) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) + } + + @Since("1.6.0") + override def write: MLWriter = new VectorIndexerModelWriter(this) +} + +@Since("1.6.0") +object VectorIndexerModel extends MLReadable[VectorIndexerModel] { + + private[VectorIndexerModel] + class VectorIndexerModelWriter(instance: VectorIndexerModel) extends MLWriter { + + private case class Data(numFeatures: Int, categoryMaps: Map[Int, Map[Double, Int]]) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.numFeatures, instance.categoryMaps) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class VectorIndexerModelReader extends MLReader[VectorIndexerModel] { + + private val className = classOf[VectorIndexerModel].getName + + override def load(path: String): VectorIndexerModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("numFeatures", "categoryMaps") + .head() + val numFeatures = data.getAs[Int](0) + val categoryMaps = data.getAs[Map[Int, Map[Double, Int]]](1) + val model = new VectorIndexerModel(metadata.uid, numFeatures, categoryMaps) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } } + + @Since("1.6.0") + override def read: MLReader[VectorIndexerModel] = new VectorIndexerModelReader + + @Since("1.6.0") + override def load(path: String): VectorIndexerModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala new file mode 100644 index 000000000000..5410a50bc2e4 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup} +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.param.{IntArrayParam, ParamMap, StringArrayParam} +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.linalg._ +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.StructType + +/** + * :: Experimental :: + * This class takes a feature vector and outputs a new feature vector with a subarray of the + * original features. + * + * The subset of features can be specified with either indices ([[setIndices()]]) + * or names ([[setNames()]]). At least one feature must be selected. Duplicate features + * are not allowed, so there can be no overlap between selected indices and names. + * + * The output vector will order features with the selected indices first (in the order given), + * followed by the selected names (in the order given). + */ +@Experimental +final class VectorSlicer(override val uid: String) + extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { + + def this() = this(Identifiable.randomUID("vectorSlicer")) + + /** + * An array of indices to select features from a vector column. + * There can be no overlap with [[names]]. + * Default: Empty array + * @group param + */ + val indices = new IntArrayParam(this, "indices", + "An array of indices to select features from a vector column." + + " There can be no overlap with names.", VectorSlicer.validIndices) + + setDefault(indices -> Array.empty[Int]) + + /** @group getParam */ + def getIndices: Array[Int] = $(indices) + + /** @group setParam */ + def setIndices(value: Array[Int]): this.type = set(indices, value) + + /** + * An array of feature names to select features from a vector column. + * These names must be specified by ML [[org.apache.spark.ml.attribute.Attribute]]s. + * There can be no overlap with [[indices]]. + * Default: Empty Array + * @group param + */ + val names = new StringArrayParam(this, "names", + "An array of feature names to select features from a vector column." 
+ + " There can be no overlap with indices.", VectorSlicer.validNames) + + setDefault(names -> Array.empty[String]) + + /** @group getParam */ + def getNames: Array[String] = $(names) + + /** @group setParam */ + def setNames(value: Array[String]): this.type = set(names, value) + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def validateParams(): Unit = { + require($(indices).length > 0 || $(names).length > 0, + s"VectorSlicer requires that at least one feature be selected.") + } + + override def transform(dataset: DataFrame): DataFrame = { + // Validity checks + transformSchema(dataset.schema) + val inputAttr = AttributeGroup.fromStructField(dataset.schema($(inputCol))) + inputAttr.numAttributes.foreach { numFeatures => + val maxIndex = $(indices).max + require(maxIndex < numFeatures, + s"Selected feature index $maxIndex invalid for only $numFeatures input features.") + } + + // Prepare output attributes + val inds = getSelectedFeatureIndices(dataset.schema) + val selectedAttrs: Option[Array[Attribute]] = inputAttr.attributes.map { attrs => + inds.map(index => attrs(index)) + } + val outputAttr = selectedAttrs match { + case Some(attrs) => new AttributeGroup($(outputCol), attrs) + case None => new AttributeGroup($(outputCol), inds.length) + } + + // Select features + val slicer = udf { vec: Vector => + vec match { + case features: DenseVector => Vectors.dense(inds.map(features.apply)) + case features: SparseVector => features.slice(inds) + } + } + dataset.withColumn($(outputCol), slicer(dataset($(inputCol))), outputAttr.toMetadata()) + } + + /** Get the feature indices in order: indices, names */ + private def getSelectedFeatureIndices(schema: StructType): Array[Int] = { + val nameFeatures = MetadataUtils.getFeatureIndicesFromNames(schema($(inputCol)), $(names)) + val indFeatures = $(indices) + val numDistinctFeatures = (nameFeatures ++ indFeatures).distinct.length + lazy val errMsg = "VectorSlicer requires indices and names to be disjoint" + + s" sets of features, but they overlap." + + s" indices: ${indFeatures.mkString("[", ",", "]")}." 
+ + s" names: " + + nameFeatures.zip($(names)).map { case (i, n) => s"$i:$n" }.mkString("[", ",", "]") + require(nameFeatures.length + indFeatures.length == numDistinctFeatures, errMsg) + indFeatures ++ nameFeatures + } + + override def transformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT) + + if (schema.fieldNames.contains($(outputCol))) { + throw new IllegalArgumentException(s"Output column ${$(outputCol)} already exists.") + } + val numFeaturesSelected = $(indices).length + $(names).length + val outputAttr = new AttributeGroup($(outputCol), numFeaturesSelected) + val outputFields = schema.fields :+ outputAttr.toStructField() + StructType(outputFields) + } + + override def copy(extra: ParamMap): VectorSlicer = defaultCopy(extra) +} + +@Since("1.6.0") +object VectorSlicer extends DefaultParamsReadable[VectorSlicer] { + + /** Return true if given feature indices are valid */ + private[feature] def validIndices(indices: Array[Int]): Boolean = { + if (indices.isEmpty) { + true + } else { + indices.length == indices.distinct.length && indices.forall(_ >= 0) + } + } + + /** Return true if given feature names are valid */ + private[feature] def validNames(names: Array[String]): Boolean = { + names.forall(_.nonEmpty) && names.length == names.distinct.length + } + + @Since("1.6.0") + override def load(path: String): VectorSlicer = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 6ea659095630..f105a983a34f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -17,15 +17,17 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path + +import org.apache.spark.SparkContext +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature -import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} -import org.apache.spark.mllib.linalg.BLAS._ -import org.apache.spark.sql.DataFrame +import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors} +import org.apache.spark.sql.{DataFrame, Row, SQLContext} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -37,6 +39,7 @@ private[feature] trait Word2VecBase extends Params /** * The dimension of the code that you want to transform from words. + * Default: 100 * @group param */ final val vectorSize = new IntParam( @@ -46,8 +49,20 @@ private[feature] trait Word2VecBase extends Params /** @group getParam */ def getVectorSize: Int = $(vectorSize) + /** + * The window size (context words from [-window, window]) default 5. + * @group expertParam + */ + final val windowSize = new IntParam( + this, "windowSize", "the window size (context words from [-window, window])") + setDefault(windowSize -> 5) + + /** @group expertGetParam */ + def getWindowSize: Int = $(windowSize) + /** * Number of partitions for sentences of words. 
+ * Default: 1 * @group param */ final val numPartitions = new IntParam( @@ -60,6 +75,7 @@ private[feature] trait Word2VecBase extends Params /** * The minimum number of times a token must appear to be included in the word2vec model's * vocabulary. + * Default: 5 * @group param */ final val minCount = new IntParam(this, "minCount", "the minimum number of times a token must " + @@ -87,7 +103,8 @@ private[feature] trait Word2VecBase extends Params * natural language processing or machine learning process. */ @Experimental -final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] with Word2VecBase { +final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] with Word2VecBase + with DefaultParamsWritable { def this() = this(Identifiable.randomUID("w2v")) @@ -100,6 +117,9 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] /** @group setParam */ def setVectorSize(value: Int): this.type = set(vectorSize, value) + /** @group expertSetParam */ + def setWindowSize(value: Int): this.type = set(windowSize, value) + /** @group setParam */ def setStepSize(value: Double): this.type = set(stepSize, value) @@ -125,6 +145,7 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] .setNumPartitions($(numPartitions)) .setSeed($(seed)) .setVectorSize($(vectorSize)) + .setWindowSize($(windowSize)) .fit(input) copyValues(new Word2VecModel(uid, wordVectors).setParent(this)) } @@ -136,6 +157,13 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] override def copy(extra: ParamMap): Word2Vec = defaultCopy(extra) } +@Since("1.6.0") +object Word2Vec extends DefaultParamsReadable[Word2Vec] { + + @Since("1.6.0") + override def load(path: String): Word2Vec = super.load(path) +} + /** * :: Experimental :: * Model fitted by [[Word2Vec]]. @@ -143,8 +171,43 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] @Experimental class Word2VecModel private[ml] ( override val uid: String, - wordVectors: feature.Word2VecModel) - extends Model[Word2VecModel] with Word2VecBase { + @transient private val wordVectors: feature.Word2VecModel) + extends Model[Word2VecModel] with Word2VecBase with MLWritable { + + import Word2VecModel._ + + /** + * Returns a dataframe with two fields, "word" and "vector", with "word" being a String and + * and the vector the DenseVector that it is mapped to. + */ + @transient lazy val getVectors: DataFrame = { + val sc = SparkContext.getOrCreate() + val sqlContext = SQLContext.getOrCreate(sc) + import sqlContext.implicits._ + val wordVec = wordVectors.getVectors.mapValues(vec => Vectors.dense(vec.map(_.toDouble))) + sc.parallelize(wordVec.toSeq).toDF("word", "vector") + } + + /** + * Find "num" number of words closest in similarity to the given word. + * Returns a dataframe with the words and the cosine similarities between the + * synonyms and the given word. + */ + def findSynonyms(word: String, num: Int): DataFrame = { + findSynonyms(wordVectors.transform(word), num) + } + + /** + * Find "num" number of words closest to similarity to the given vector representation + * of the word. Returns a dataframe with the words and the cosine similarities between the + * synonyms and the given word vector. 
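+ *
+ * For example, findSynonyms(vec, 2) returns a DataFrame holding the two words closest to vec,
+ * with a string column "word" and a double column "similarity", where vec is a query vector
+ * of the model's vector size.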
+ */ + def findSynonyms(word: Vector, num: Int): DataFrame = { + val sc = SparkContext.getOrCreate() + val sqlContext = SQLContext.getOrCreate(sc) + import sqlContext.implicits._ + sc.parallelize(wordVectors.findSynonyms(word, num)).toDF("word", "similarity") + } /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -158,22 +221,23 @@ class Word2VecModel private[ml] ( */ override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) - val bWordVectors = dataset.sqlContext.sparkContext.broadcast(wordVectors) + val vectors = wordVectors.getVectors + .mapValues(vv => Vectors.dense(vv.map(_.toDouble))) + .map(identity) // mapValues doesn't return a serializable map (SI-7005) + val bVectors = dataset.sqlContext.sparkContext.broadcast(vectors) + val d = $(vectorSize) val word2Vec = udf { sentence: Seq[String] => if (sentence.size == 0) { - Vectors.sparse($(vectorSize), Array.empty[Int], Array.empty[Double]) + Vectors.sparse(d, Array.empty[Int], Array.empty[Double]) } else { - val cum = Vectors.zeros($(vectorSize)) - val model = bWordVectors.value.getVectors - for (word <- sentence) { - if (model.contains(word)) { - axpy(1.0, bWordVectors.value.transform(word), cum) - } else { - // pass words which not belong to model + val sum = Vectors.zeros(d) + sentence.foreach { word => + bVectors.value.get(word).foreach { v => + BLAS.axpy(1.0, v, sum) } } - scal(1.0 / sentence.size, cum) - cum + BLAS.scal(1.0 / sentence.size, sum) + sum } } dataset.withColumn($(outputCol), word2Vec(col($(inputCol)))) @@ -185,6 +249,51 @@ class Word2VecModel private[ml] ( override def copy(extra: ParamMap): Word2VecModel = { val copied = new Word2VecModel(uid, wordVectors) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) + } + + @Since("1.6.0") + override def write: MLWriter = new Word2VecModelWriter(this) +} + +@Since("1.6.0") +object Word2VecModel extends MLReadable[Word2VecModel] { + + private[Word2VecModel] + class Word2VecModelWriter(instance: Word2VecModel) extends MLWriter { + + private case class Data(wordIndex: Map[String, Int], wordVectors: Seq[Float]) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.wordVectors.wordIndex, instance.wordVectors.wordVectors.toSeq) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } } + + private class Word2VecModelReader extends MLReader[Word2VecModel] { + + private val className = classOf[Word2VecModel].getName + + override def load(path: String): Word2VecModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("wordIndex", "wordVectors") + .head() + val wordIndex = data.getAs[Map[String, Int]](0) + val wordVectors = data.getAs[Seq[Float]](1).toArray + val oldModel = new feature.Word2VecModel(wordIndex, wordVectors) + val model = new Word2VecModel(metadata.uid, oldModel) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[Word2VecModel] = new Word2VecModelReader + + @Since("1.6.0") + override def load(path: String): Word2VecModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/package-info.java b/mllib/src/main/scala/org/apache/spark/ml/feature/package-info.java new file mode 
100644 index 000000000000..7a35f2d448f9 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/package-info.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +/** + * Feature transformers + * + * The `ml.feature` package provides common feature transformers that help convert raw data or + * features into more suitable forms for model fitting. + * Most feature transformers are implemented as {@link org.apache.spark.ml.Transformer}s, which + * transforms one {@link org.apache.spark.sql.DataFrame} into another, e.g., + * {@link org.apache.spark.ml.feature.HashingTF}. + * Some feature transformers are implemented as {@link org.apache.spark.ml.Estimator}}s, because the + * transformation requires some aggregated information of the dataset, e.g., document + * frequencies in {@link org.apache.spark.ml.feature.IDF}. + * For those feature transformers, calling {@link org.apache.spark.ml.Estimator#fit} is required to + * obtain the model first, e.g., {@link org.apache.spark.ml.feature.IDFModel}, in order to apply + * transformation. + * The transformation is usually done by appending new columns to the input + * {@link org.apache.spark.sql.DataFrame}, so all input columns are carried over. + * + * We try to make each transformer minimal, so it becomes flexible to assemble feature + * transformation pipelines. + * {@link org.apache.spark.ml.Pipeline} can be used to chain feature transformers, and + * {@link org.apache.spark.ml.feature.VectorAssembler} can be used to combine multiple feature + * transformations, for example: + * + *
      + * <pre>
      + * <code>
      + *
      + *   import org.apache.spark.api.java.JavaRDD;
      + *   import static org.apache.spark.sql.types.DataTypes.*;
      + *   import org.apache.spark.sql.types.StructType;
      + *   import org.apache.spark.sql.DataFrame;
      + *   import org.apache.spark.sql.RowFactory;
      + *   import org.apache.spark.sql.Row;
      + *
      + *   import org.apache.spark.ml.feature.*;
      + *   import org.apache.spark.ml.Pipeline;
      + *   import org.apache.spark.ml.PipelineStage;
      + *   import org.apache.spark.ml.PipelineModel;
      + *
      + *  // a DataFrame with three columns: id (integer), text (string), and rating (double).
      + *  StructType schema = createStructType(
      + *    Arrays.asList(
      + *      createStructField("id", IntegerType, false),
      + *      createStructField("text", StringType, false),
      + *      createStructField("rating", DoubleType, false)));
      + *  JavaRDD<Row> rowRDD = jsc.parallelize(
      + *    Arrays.asList(
      + *      RowFactory.create(0, "Hi I heard about Spark", 3.0),
      + *      RowFactory.create(1, "I wish Java could use case classes", 4.0),
      + *      RowFactory.create(2, "Logistic regression models are neat", 4.0)));
      + *  DataFrame df = jsql.createDataFrame(rowRDD, schema);
      + *  // define feature transformers
      + *  RegexTokenizer tok = new RegexTokenizer()
      + *    .setInputCol("text")
      + *    .setOutputCol("words");
      + *  StopWordsRemover sw = new StopWordsRemover()
      + *    .setInputCol("words")
      + *    .setOutputCol("filtered_words");
      + *  HashingTF tf = new HashingTF()
      + *    .setInputCol("filtered_words")
      + *    .setOutputCol("tf")
      + *    .setNumFeatures(10000);
      + *  IDF idf = new IDF()
      + *    .setInputCol("tf")
      + *    .setOutputCol("tf_idf");
      + *  VectorAssembler assembler = new VectorAssembler()
      + *    .setInputCols(new String[] {"tf_idf", "rating"})
      + *    .setOutputCol("features");
      + *
      + *  // assemble and fit the feature transformation pipeline
      + *  Pipeline pipeline = new Pipeline()
      + *    .setStages(new PipelineStage[] {tok, sw, tf, idf, assembler});
      + *  PipelineModel model = pipeline.fit(df);
      + *
      + *  // save transformed features with raw data
      + *  model.transform(df)
      + *    .select("id", "text", "rating", "features")
      + *    .write().format("parquet").save("/output/path");
+ * </code>
+ * </pre>
      + * + * Some feature transformers implemented in MLlib are inspired by those implemented in scikit-learn. + * The major difference is that most scikit-learn feature transformers operate eagerly on the entire + * input dataset, while MLlib's feature transformers operate lazily on individual columns, + * which is more efficient and flexible to handle large and complex datasets. + * + * @see + * scikit-learn.preprocessing + */ +package org.apache.spark.ml.feature; diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/package.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/package.scala new file mode 100644 index 000000000000..4571ab26800c --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/package.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import org.apache.spark.ml.feature.{HashingTF, IDF, IDFModel, VectorAssembler} +import org.apache.spark.sql.DataFrame + +/** + * == Feature transformers == + * + * The `ml.feature` package provides common feature transformers that help convert raw data or + * features into more suitable forms for model fitting. + * Most feature transformers are implemented as [[Transformer]]s, which transform one [[DataFrame]] + * into another, e.g., [[HashingTF]]. + * Some feature transformers are implemented as [[Estimator]]s, because the transformation requires + * some aggregated information of the dataset, e.g., document frequencies in [[IDF]]. + * For those feature transformers, calling [[Estimator!.fit]] is required to obtain the model first, + * e.g., [[IDFModel]], in order to apply transformation. + * The transformation is usually done by appending new columns to the input [[DataFrame]], so all + * input columns are carried over. + * + * We try to make each transformer minimal, so it becomes flexible to assemble feature + * transformation pipelines. + * [[Pipeline]] can be used to chain feature transformers, and [[VectorAssembler]] can be used to + * combine multiple feature transformations, for example: + * + * {{{ + * import org.apache.spark.ml.feature._ + * import org.apache.spark.ml.Pipeline + * + * // a DataFrame with three columns: id (integer), text (string), and rating (double). 
+ * val df = sqlContext.createDataFrame(Seq( + * (0, "Hi I heard about Spark", 3.0), + * (1, "I wish Java could use case classes", 4.0), + * (2, "Logistic regression models are neat", 4.0) + * )).toDF("id", "text", "rating") + * + * // define feature transformers + * val tok = new RegexTokenizer() + * .setInputCol("text") + * .setOutputCol("words") + * val sw = new StopWordsRemover() + * .setInputCol("words") + * .setOutputCol("filtered_words") + * val tf = new HashingTF() + * .setInputCol("filtered_words") + * .setOutputCol("tf") + * .setNumFeatures(10000) + * val idf = new IDF() + * .setInputCol("tf") + * .setOutputCol("tf_idf") + * val assembler = new VectorAssembler() + * .setInputCols(Array("tf_idf", "rating")) + * .setOutputCol("features") + * + * // assemble and fit the feature transformation pipeline + * val pipeline = new Pipeline() + * .setStages(Array(tok, sw, tf, idf, assembler)) + * val model = pipeline.fit(df) + * + * // save transformed features with raw data + * model.transform(df) + * .select("id", "text", "rating", "features") + * .write.format("parquet").save("/output/path") + * }}} + * + * Some feature transformers implemented in MLlib are inspired by those implemented in scikit-learn. + * The major difference is that most scikit-learn feature transformers operate eagerly on the entire + * input dataset, while MLlib's feature transformers operate lazily on individual columns, + * which is more efficient and flexible to handle large and complex datasets. + * + * @see [[http://scikit-learn.org/stable/modules/preprocessing.html scikit-learn.preprocessing]] + */ +package object feature diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala new file mode 100644 index 000000000000..8617722ae542 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -0,0 +1,277 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.optim + +import org.apache.spark.Logging +import org.apache.spark.ml.feature.Instance +import org.apache.spark.mllib.linalg._ +import org.apache.spark.rdd.RDD + +/** + * Model fitted by [[WeightedLeastSquares]]. + * @param coefficients model coefficients + * @param intercept model intercept + * @param diagInvAtWA diagonal of matrix (A^T * W * A)^-1 + */ +private[ml] class WeightedLeastSquaresModel( + val coefficients: DenseVector, + val intercept: Double, + val diagInvAtWA: DenseVector) extends Serializable + +/** + * Weighted least squares solver via normal equation. 
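+ * A typical invocation (a minimal sketch; `instances` is assumed to be an existing
+ * `RDD[Instance]`):
+ * {{{
+ *   val wls = new WeightedLeastSquares(fitIntercept = true, regParam = 0.1,
+ *     standardizeFeatures = true, standardizeLabel = true)
+ *   val model = wls.fit(instances)
+ *   // model.coefficients, model.intercept and model.diagInvAtWA are now available
+ * }}}
+ *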
+ * Given weighted observations (w,,i,,, a,,i,,, b,,i,,), we use the following weighted least squares + * formulation: + * + * min,,x,z,, 1/2 sum,,i,, w,,i,, (a,,i,,^T^ x + z - b,,i,,)^2^ / sum,,i,, w_i + * + 1/2 lambda / delta sum,,j,, (sigma,,j,, x,,j,,)^2^, + * + * where lambda is the regularization parameter, and delta and sigma,,j,, are controlled by + * [[standardizeLabel]] and [[standardizeFeatures]], respectively. + * + * Set [[regParam]] to 0.0 and turn off both [[standardizeFeatures]] and [[standardizeLabel]] to + * match R's `lm`. + * Turn on [[standardizeLabel]] to match R's `glmnet`. + * + * @param fitIntercept whether to fit intercept. If false, z is 0.0. + * @param regParam L2 regularization parameter (lambda) + * @param standardizeFeatures whether to standardize features. If true, sigma_,,j,, is the + * population standard deviation of the j-th column of A. Otherwise, + * sigma,,j,, is 1.0. + * @param standardizeLabel whether to standardize label. If true, delta is the population standard + * deviation of the label column b. Otherwise, delta is 1.0. + */ +private[ml] class WeightedLeastSquares( + val fitIntercept: Boolean, + val regParam: Double, + val standardizeFeatures: Boolean, + val standardizeLabel: Boolean) extends Logging with Serializable { + import WeightedLeastSquares._ + + require(regParam >= 0.0, s"regParam cannot be negative: $regParam") + if (regParam == 0.0) { + logWarning("regParam is zero, which might cause numerical instability and overfitting.") + } + + /** + * Creates a [[WeightedLeastSquaresModel]] from an RDD of [[Instance]]s. + */ + def fit(instances: RDD[Instance]): WeightedLeastSquaresModel = { + val summary = instances.treeAggregate(new Aggregator)(_.add(_), _.merge(_)) + summary.validate() + logInfo(s"Number of instances: ${summary.count}.") + val k = if (fitIntercept) summary.k + 1 else summary.k + val triK = summary.triK + val wSum = summary.wSum + val bBar = summary.bBar + val bStd = summary.bStd + val aBar = summary.aBar + val aVar = summary.aVar + val abBar = summary.abBar + val aaBar = summary.aaBar + val aaValues = aaBar.values + + // add regularization to diagonals + var i = 0 + var j = 2 + while (i < triK) { + var lambda = regParam + if (standardizeFeatures) { + lambda *= aVar(j - 2) + } + if (standardizeLabel) { + // TODO: handle the case when bStd = 0 + lambda /= bStd + } + aaValues(i) += lambda + i += j + j += 1 + } + + val aa = if (fitIntercept) { + Array.concat(aaBar.values, aBar.values, Array(1.0)) + } else { + aaBar.values + } + val ab = if (fitIntercept) { + Array.concat(abBar.values, Array(bBar)) + } else { + abBar.values + } + + val x = CholeskyDecomposition.solve(aa, ab) + + val aaInv = CholeskyDecomposition.inverse(aa, k) + + // aaInv is a packed upper triangular matrix, here we get all elements on diagonal + val diagInvAtWA = new DenseVector((1 to k).map { i => + aaInv(i + (i - 1) * i / 2 - 1) / wSum }.toArray) + + val (coefficients, intercept) = if (fitIntercept) { + (new DenseVector(x.slice(0, x.length - 1)), x.last) + } else { + (new DenseVector(x), 0.0) + } + + new WeightedLeastSquaresModel(coefficients, intercept, diagInvAtWA) + } +} + +private[ml] object WeightedLeastSquares { + + /** + * Aggregator to provide necessary summary statistics for solving [[WeightedLeastSquares]]. 
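+   * Second-order statistics are kept in packed form: `aaSum` stores only the upper triangle of
+   * sum,,i,, w,,i,, a,,i,, a,,i,,^T^, column by column, so it has k(k+1)/2 (`triK`) entries.
+   * For k = 3 the layout is (1,1), (1,2), (2,2), (1,3), (2,3), (3,3); the diagonal therefore
+   * sits at 0-based positions 0, 2, 5, which is why the regularization loop in `fit` walks the
+   * packed array with increasing strides.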
+ */ + // TODO: consolidate aggregates for summary statistics + private class Aggregator extends Serializable { + var initialized: Boolean = false + var k: Int = _ + var count: Long = _ + var triK: Int = _ + var wSum: Double = _ + private var wwSum: Double = _ + private var bSum: Double = _ + private var bbSum: Double = _ + private var aSum: DenseVector = _ + private var abSum: DenseVector = _ + private var aaSum: DenseVector = _ + + private def init(k: Int): Unit = { + require(k <= 4096, "In order to take the normal equation approach efficiently, " + + s"we set the max number of features to 4096 but got $k.") + this.k = k + triK = k * (k + 1) / 2 + count = 0L + wSum = 0.0 + wwSum = 0.0 + bSum = 0.0 + bbSum = 0.0 + aSum = new DenseVector(Array.ofDim(k)) + abSum = new DenseVector(Array.ofDim(k)) + aaSum = new DenseVector(Array.ofDim(triK)) + initialized = true + } + + /** + * Adds an instance. + */ + def add(instance: Instance): this.type = { + val Instance(l, w, f) = instance + val ak = f.size + if (!initialized) { + init(ak) + } + assert(ak == k, s"Dimension mismatch. Expect vectors of size $k but got $ak.") + count += 1L + wSum += w + wwSum += w * w + bSum += w * l + bbSum += w * l * l + BLAS.axpy(w, f, aSum) + BLAS.axpy(w * l, f, abSum) + BLAS.spr(w, f, aaSum) + this + } + + /** + * Merges another [[Aggregator]]. + */ + def merge(other: Aggregator): this.type = { + if (!other.initialized) { + this + } else { + if (!initialized) { + init(other.k) + } + assert(k == other.k, s"dimension mismatch: this.k = $k but other.k = ${other.k}") + count += other.count + wSum += other.wSum + wwSum += other.wwSum + bSum += other.bSum + bbSum += other.bbSum + BLAS.axpy(1.0, other.aSum, aSum) + BLAS.axpy(1.0, other.abSum, abSum) + BLAS.axpy(1.0, other.aaSum, aaSum) + this + } + } + + /** + * Validates that we have seen observations. + */ + def validate(): Unit = { + assert(initialized, "Training dataset is empty.") + assert(wSum > 0.0, "Sum of weights cannot be zero.") + } + + /** + * Weighted mean of features. + */ + def aBar: DenseVector = { + val output = aSum.copy + BLAS.scal(1.0 / wSum, output) + output + } + + /** + * Weighted mean of labels. + */ + def bBar: Double = bSum / wSum + + /** + * Weighted population standard deviation of labels. + */ + def bStd: Double = math.sqrt(bbSum / wSum - bBar * bBar) + + /** + * Weighted mean of (label * features). + */ + def abBar: DenseVector = { + val output = abSum.copy + BLAS.scal(1.0 / wSum, output) + output + } + + /** + * Weighted mean of (features * features^T^). + */ + def aaBar: DenseVector = { + val output = aaSum.copy + BLAS.scal(1.0 / wSum, output) + output + } + + /** + * Weighted population variance of features. 
+ */ + def aVar: DenseVector = { + val variance = Array.ofDim[Double](k) + var i = 0 + var j = 2 + val aaValues = aaSum.values + while (i < triK) { + val l = j - 2 + val aw = aSum(l) / wSum + variance(l) = aaValues(i) / wSum - aw * aw + i += j + j += 1 + } + new DenseVector(variance) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index d68f5ff0053c..ee7e89edd879 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -24,8 +24,12 @@ import scala.annotation.varargs import scala.collection.mutable import scala.collection.JavaConverters._ +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.linalg.{Vector, Vectors} /** * :: DeveloperApi :: @@ -65,7 +69,12 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali */ private[param] def validate(value: T): Unit = { if (!isValid(value)) { - throw new IllegalArgumentException(s"$parent parameter $name given invalid value $value.") + val valueToString = value match { + case v: Array[_] => v.mkString("[", ",", "]") + case _ => value.toString + } + throw new IllegalArgumentException( + s"$parent parameter $name given invalid value $valueToString.") } } @@ -73,7 +82,40 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali def w(value: T): ParamPair[T] = this -> value /** Creates a param pair with the given value (for Scala). */ + // scalastyle:off def ->(value: T): ParamPair[T] = ParamPair(this, value) + // scalastyle:on + + /** Encodes a param value into JSON, which can be decoded by [[jsonDecode()]]. */ + def jsonEncode(value: T): String = { + value match { + case x: String => + compact(render(JString(x))) + case v: Vector => + v.toJson + case _ => + throw new NotImplementedError( + "The default jsonEncode only supports string and vector. " + + s"${this.getClass.getName} must override jsonEncode for ${value.getClass.getName}.") + } + } + + /** Decodes a param value from JSON. */ + def jsonDecode(json: String): T = { + parse(json) match { + case JString(x) => + x.asInstanceOf[T] + case JObject(v) => + val keys = v.map(_._1) + assert(keys.contains("type") && keys.contains("values"), + s"Expect a JSON serialized vector but cannot find fields 'type' and 'values' in $json.") + Vectors.fromJson(json).asInstanceOf[T] + case _ => + throw new NotImplementedError( + "The default jsonDecode only supports string and vector. " + + s"${this.getClass.getName} must override jsonDecode to support its value type.") + } + } override final def toString: String = s"${parent}__$name" @@ -193,6 +235,46 @@ class DoubleParam(parent: String, name: String, doc: String, isValid: Double => /** Creates a param pair with the given value (for Java). */ override def w(value: Double): ParamPair[Double] = super.w(value) + + override def jsonEncode(value: Double): String = { + compact(render(DoubleParam.jValueEncode(value))) + } + + override def jsonDecode(json: String): Double = { + DoubleParam.jValueDecode(parse(json)) + } +} + +private[param] object DoubleParam { + /** Encodes a param value into JValue. 
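+   * Non-finite values are mapped to the JSON strings "NaN", "Inf" and "-Inf" so that they survive
+   * a round trip; for example (illustrative values):
+   * {{{
+   *   jValueEncode(1.5)                      // JDouble(1.5)
+   *   jValueEncode(Double.PositiveInfinity)  // JString("Inf")
+   *   jValueDecode(JString("NaN")).isNaN     // true
+   * }}}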
*/ + def jValueEncode(value: Double): JValue = { + value match { + case _ if value.isNaN => + JString("NaN") + case Double.NegativeInfinity => + JString("-Inf") + case Double.PositiveInfinity => + JString("Inf") + case _ => + JDouble(value) + } + } + + /** Decodes a param value from JValue. */ + def jValueDecode(jValue: JValue): Double = { + jValue match { + case JString("NaN") => + Double.NaN + case JString("-Inf") => + Double.NegativeInfinity + case JString("Inf") => + Double.PositiveInfinity + case JDouble(x) => + x + case _ => + throw new IllegalArgumentException(s"Cannot decode $jValue to Double.") + } + } } /** @@ -213,6 +295,15 @@ class IntParam(parent: String, name: String, doc: String, isValid: Int => Boolea /** Creates a param pair with the given value (for Java). */ override def w(value: Int): ParamPair[Int] = super.w(value) + + override def jsonEncode(value: Int): String = { + compact(render(JInt(value))) + } + + override def jsonDecode(json: String): Int = { + implicit val formats = DefaultFormats + parse(json).extract[Int] + } } /** @@ -233,6 +324,47 @@ class FloatParam(parent: String, name: String, doc: String, isValid: Float => Bo /** Creates a param pair with the given value (for Java). */ override def w(value: Float): ParamPair[Float] = super.w(value) + + override def jsonEncode(value: Float): String = { + compact(render(FloatParam.jValueEncode(value))) + } + + override def jsonDecode(json: String): Float = { + FloatParam.jValueDecode(parse(json)) + } +} + +private object FloatParam { + + /** Encodes a param value into JValue. */ + def jValueEncode(value: Float): JValue = { + value match { + case _ if value.isNaN => + JString("NaN") + case Float.NegativeInfinity => + JString("-Inf") + case Float.PositiveInfinity => + JString("Inf") + case _ => + JDouble(value) + } + } + + /** Decodes a param value from JValue. */ + def jValueDecode(jValue: JValue): Float = { + jValue match { + case JString("NaN") => + Float.NaN + case JString("-Inf") => + Float.NegativeInfinity + case JString("Inf") => + Float.PositiveInfinity + case JDouble(x) => + x.toFloat + case _ => + throw new IllegalArgumentException(s"Cannot decode $jValue to Float.") + } + } } /** @@ -253,6 +385,15 @@ class LongParam(parent: String, name: String, doc: String, isValid: Long => Bool /** Creates a param pair with the given value (for Java). */ override def w(value: Long): ParamPair[Long] = super.w(value) + + override def jsonEncode(value: Long): String = { + compact(render(JInt(value))) + } + + override def jsonDecode(json: String): Long = { + implicit val formats = DefaultFormats + parse(json).extract[Long] + } } /** @@ -267,6 +408,15 @@ class BooleanParam(parent: String, name: String, doc: String) // No need for isV /** Creates a param pair with the given value (for Java). */ override def w(value: Boolean): ParamPair[Boolean] = super.w(value) + + override def jsonEncode(value: Boolean): String = { + compact(render(JBool(value))) + } + + override def jsonDecode(json: String): Boolean = { + implicit val formats = DefaultFormats + parse(json).extract[Boolean] + } } /** @@ -282,6 +432,16 @@ class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array /** Creates a param pair with a [[java.util.List]] of values (for Java and Python). 
*/ def w(value: java.util.List[String]): ParamPair[Array[String]] = w(value.asScala.toArray) + + override def jsonEncode(value: Array[String]): String = { + import org.json4s.JsonDSL._ + compact(render(value.toSeq)) + } + + override def jsonDecode(json: String): Array[String] = { + implicit val formats = DefaultFormats + parse(json).extract[Seq[String]].toArray + } } /** @@ -298,6 +458,20 @@ class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array /** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */ def w(value: java.util.List[java.lang.Double]): ParamPair[Array[Double]] = w(value.asScala.map(_.asInstanceOf[Double]).toArray) + + override def jsonEncode(value: Array[Double]): String = { + import org.json4s.JsonDSL._ + compact(render(value.toSeq.map(DoubleParam.jValueEncode))) + } + + override def jsonDecode(json: String): Array[Double] = { + parse(json) match { + case JArray(values) => + values.map(DoubleParam.jValueDecode).toArray + case _ => + throw new IllegalArgumentException(s"Cannot decode $json to Array[Double].") + } + } } /** @@ -314,6 +488,16 @@ class IntArrayParam(parent: Params, name: String, doc: String, isValid: Array[In /** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */ def w(value: java.util.List[java.lang.Integer]): ParamPair[Array[Int]] = w(value.asScala.map(_.asInstanceOf[Int]).toArray) + + override def jsonEncode(value: Array[Int]): String = { + import org.json4s.JsonDSL._ + compact(render(value.toSeq)) + } + + override def jsonDecode(json: String): Array[Int] = { + implicit val formats = DefaultFormats + parse(json).extract[Seq[Int]].toArray + } } /** @@ -418,7 +602,7 @@ trait Params extends Identifiable with Serializable { /** * Sets a parameter in the embedded param map. */ - protected final def set[T](param: Param[T], value: T): this.type = { + final def set[T](param: Param[T], value: T): this.type = { set(param -> value) } @@ -449,7 +633,7 @@ trait Params extends Identifiable with Serializable { /** * Clears the user-supplied value for the input param. */ - protected final def clear(param: Param[_]): this.type = { + final def clear(param: Param[_]): this.type = { shouldOwn(param) paramMap.remove(param) this @@ -461,7 +645,8 @@ trait Params extends Identifiable with Serializable { */ final def getOrDefault[T](param: Param[T]): T = { shouldOwn(param) - get(param).orElse(getDefault(param)).get + get(param).orElse(getDefault(param)).getOrElse( + throw new NoSuchElementException(s"Failed to find a default value for ${param.name}")) } /** An alias for [[getOrDefault()]]. */ @@ -559,13 +744,26 @@ trait Params extends Identifiable with Serializable { /** * Copies param values from this instance to another instance for params shared by them. - * @param to the target instance - * @param extra extra params to be copied + * + * This handles default Params and explicitly set Params separately. + * Default Params are copied from and to [[defaultParamMap]], and explicitly set Params are + * copied from and to [[paramMap]]. + * Warning: This implicitly assumes that this [[Params]] instance and the target instance + * share the same set of default Params. 
+ * + * @param to the target instance, which should work with the same set of default Params as this + * source instance + * @param extra extra params to be copied to the target's [[paramMap]] * @return the target instance with param values copied */ protected def copyValues[T <: Params](to: T, extra: ParamMap = ParamMap.empty): T = { - val map = extractParamMap(extra) + val map = paramMap ++ extra params.foreach { param => + // copy default Params + if (defaultParamMap.contains(param) && to.hasParam(param.name)) { + to.defaultParamMap.put(to.getParam(param.name), defaultParamMap(param)) + } + // copy explicitly set Params if (map.contains(param) && to.hasParam(param.name)) { to.set(param.name, map(param)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index f7ae1de522e0..c7bca1243092 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -42,25 +42,40 @@ private[shared] object SharedParamsCodeGen { Some("\"rawPrediction\"")), ParamDesc[String]("probabilityCol", "Column name for predicted class conditional" + " probabilities. Note: Not all models output well-calibrated probability estimates!" + - " These probabilities should be treated as confidences, not precise probabilities.", + " These probabilities should be treated as confidences, not precise probabilities", Some("\"probability\"")), ParamDesc[Double]("threshold", - "threshold in binary classification prediction, in range [0, 1]", - isValid = "ParamValidators.inRange(0, 1)"), + "threshold in binary classification prediction, in range [0, 1]", Some("0.5"), + isValid = "ParamValidators.inRange(0, 1)", finalMethods = false), + ParamDesc[Array[Double]]("thresholds", "Thresholds in multi-class classification" + + " to adjust the probability of predicting each class." + + " Array must have length equal to the number of classes, with values >= 0." + + " The class with largest value p/t is predicted, where p is the original probability" + + " of that class and t is the class' threshold.", + isValid = "(t: Array[Double]) => t.forall(_ >= 0)", finalMethods = false), ParamDesc[String]("inputCol", "input column name"), ParamDesc[Array[String]]("inputCols", "input column names"), ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")), - ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1)", - isValid = "ParamValidators.gtEq(1)"), + ParamDesc[Int]("checkpointInterval", "set checkpoint interval (>= 1) or " + + "disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed " + + "every 10 iterations", isValid = "(interval: Int) => interval == -1 || interval >= 1"), ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), + ParamDesc[String]("handleInvalid", "how to handle invalid entries. Options are skip (which " + + "will filter out rows with bad values), or error (which will throw an errror). 
More " + + "options may be added later.", + isValid = "ParamValidators.inArray(Array(\"skip\", \"error\"))"), ParamDesc[Boolean]("standardization", "whether to standardize the training features" + - " before fitting the model.", Some("true")), + " before fitting the model", Some("true")), ParamDesc[Long]("seed", "random seed", Some("this.getClass.getName.hashCode.toLong")), ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]." + - " For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", + " For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty", isValid = "ParamValidators.inRange(0, 1)"), ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms"), - ParamDesc[Double]("stepSize", "Step size to be used for each iteration of optimization.")) + ParamDesc[Double]("stepSize", "Step size to be used for each iteration of optimization."), + ParamDesc[String]("weightCol", "weight column name. If this is not set or empty, we treat " + + "all instance weights as 1.0."), + ParamDesc[String]("solver", "the solver algorithm for optimization. If this is not set or " + + "empty, default value is 'auto'.", Some("\"auto\""))) val code = genSharedParams(params) val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala" @@ -74,7 +89,8 @@ private[shared] object SharedParamsCodeGen { name: String, doc: String, defaultValueStr: Option[String] = None, - isValid: String = "") { + isValid: String = "", + finalMethods: Boolean = true) { require(name.matches("[a-z][a-zA-Z0-9]*"), s"Param name $name is invalid.") require(doc.nonEmpty) // TODO: more rigorous on doc @@ -88,6 +104,7 @@ private[shared] object SharedParamsCodeGen { case _ if c == classOf[Double] => "DoubleParam" case _ if c == classOf[Boolean] => "BooleanParam" case _ if c.isArray && c.getComponentType == classOf[String] => s"StringArrayParam" + case _ if c.isArray && c.getComponentType == classOf[Double] => s"DoubleArrayParam" case _ => s"Param[${getTypeString(c)}]" } } @@ -131,6 +148,11 @@ private[shared] object SharedParamsCodeGen { } else { "" } + val methodStr = if (param.finalMethods) { + "final def" + } else { + "def" + } s""" |/** @@ -145,7 +167,7 @@ private[shared] object SharedParamsCodeGen { | final val $name: $Param = new $Param(this, "$name", "$doc"$isValid) |$setDefault | /** @group getParam */ - | final def get$Name: $T = $$($name) + | $methodStr get$Name: $T = $$($name) |} |""".stripMargin } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 65e48e4ee508..cb2a060a34dd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -127,10 +127,10 @@ private[ml] trait HasRawPredictionCol extends Params { private[ml] trait HasProbabilityCol extends Params { /** - * Param for Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.. + * Param for Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities. 
* @group param */ - final val probabilityCol: Param[String] = new Param[String](this, "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.") + final val probabilityCol: Param[String] = new Param[String](this, "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities") setDefault(probabilityCol, "probability") @@ -139,7 +139,7 @@ private[ml] trait HasProbabilityCol extends Params { } /** - * Trait for shared param threshold. + * Trait for shared param threshold (default: 0.5). */ private[ml] trait HasThreshold extends Params { @@ -149,8 +149,25 @@ private[ml] trait HasThreshold extends Params { */ final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in binary classification prediction, in range [0, 1]", ParamValidators.inRange(0, 1)) + setDefault(threshold, 0.5) + + /** @group getParam */ + def getThreshold: Double = $(threshold) +} + +/** + * Trait for shared param thresholds. + */ +private[ml] trait HasThresholds extends Params { + + /** + * Param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.. + * @group param + */ + final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.", (t: Array[Double]) => t.forall(_ >= 0)) + /** @group getParam */ - final def getThreshold: Double = $(threshold) + def getThresholds: Array[Double] = $(thresholds) } /** @@ -206,10 +223,10 @@ private[ml] trait HasOutputCol extends Params { private[ml] trait HasCheckpointInterval extends Params { /** - * Param for checkpoint interval (>= 1). + * Param for set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. * @group param */ - final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval (>= 1)", ParamValidators.gtEq(1)) + final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations", (interval: Int) => interval == -1 || interval >= 1) /** @group getParam */ final def getCheckpointInterval: Int = $(checkpointInterval) @@ -232,16 +249,31 @@ private[ml] trait HasFitIntercept extends Params { final def getFitIntercept: Boolean = $(fitIntercept) } +/** + * Trait for shared param handleInvalid. + */ +private[ml] trait HasHandleInvalid extends Params { + + /** + * Param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.. 
+ * @group param + */ + final val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.", ParamValidators.inArray(Array("skip", "error"))) + + /** @group getParam */ + final def getHandleInvalid: String = $(handleInvalid) +} + /** * Trait for shared param standardization (default: true). */ private[ml] trait HasStandardization extends Params { /** - * Param for whether to standardize the training features before fitting the model.. + * Param for whether to standardize the training features before fitting the model. * @group param */ - final val standardization: BooleanParam = new BooleanParam(this, "standardization", "whether to standardize the training features before fitting the model.") + final val standardization: BooleanParam = new BooleanParam(this, "standardization", "whether to standardize the training features before fitting the model") setDefault(standardization, true) @@ -272,10 +304,10 @@ private[ml] trait HasSeed extends Params { private[ml] trait HasElasticNetParam extends Params { /** - * Param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.. + * Param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty. * @group param */ - final val elasticNetParam: DoubleParam = new DoubleParam(this, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", ParamValidators.inRange(0, 1)) + final val elasticNetParam: DoubleParam = new DoubleParam(this, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty", ParamValidators.inRange(0, 1)) /** @group getParam */ final def getElasticNetParam: Double = $(elasticNetParam) @@ -310,4 +342,36 @@ private[ml] trait HasStepSize extends Params { /** @group getParam */ final def getStepSize: Double = $(stepSize) } + +/** + * Trait for shared param weightCol. + */ +private[ml] trait HasWeightCol extends Params { + + /** + * Param for weight column name. If this is not set or empty, we treat all instance weights as 1.0.. + * @group param + */ + final val weightCol: Param[String] = new Param[String](this, "weightCol", "weight column name. If this is not set or empty, we treat all instance weights as 1.0.") + + /** @group getParam */ + final def getWeightCol: String = $(weightCol) +} + +/** + * Trait for shared param solver (default: "auto"). + */ +private[ml] trait HasSolver extends Params { + + /** + * Param for the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.. + * @group param + */ + final val solver: Param[String] = new Param[String](this, "solver", "the solver algorithm for optimization. 
If this is not set or empty, default value is 'auto'.") + + setDefault(solver, "auto") + + /** @group getParam */ + final def getSolver: String = $(solver) +} // scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala index f5a022c31ed9..4d82b90bfdf2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -30,29 +30,58 @@ private[r] object SparkRWrappers { df: DataFrame, family: String, lambda: Double, - alpha: Double): PipelineModel = { + alpha: Double, + standardize: Boolean, + solver: String): PipelineModel = { val formula = new RFormula().setFormula(value) val estimator = family match { case "gaussian" => new LinearRegression() .setRegParam(lambda) .setElasticNetParam(alpha) .setFitIntercept(formula.hasIntercept) + .setStandardization(standardize) + .setSolver(solver) case "binomial" => new LogisticRegression() .setRegParam(lambda) .setElasticNetParam(alpha) .setFitIntercept(formula.hasIntercept) + .setStandardization(standardize) } val pipeline = new Pipeline().setStages(Array(formula, estimator)) pipeline.fit(df) } - def getModelWeights(model: PipelineModel): Array[Double] = { + def getModelCoefficients(model: PipelineModel): Array[Double] = { + model.stages.last match { + case m: LinearRegressionModel => { + val coefficientStandardErrorsR = Array(m.summary.coefficientStandardErrors.last) ++ + m.summary.coefficientStandardErrors.dropRight(1) + val tValuesR = Array(m.summary.tValues.last) ++ m.summary.tValues.dropRight(1) + val pValuesR = Array(m.summary.pValues.last) ++ m.summary.pValues.dropRight(1) + if (m.getFitIntercept) { + Array(m.intercept) ++ m.coefficients.toArray ++ coefficientStandardErrorsR ++ + tValuesR ++ pValuesR + } else { + m.coefficients.toArray ++ coefficientStandardErrorsR ++ tValuesR ++ pValuesR + } + } + case m: LogisticRegressionModel => { + if (m.getFitIntercept) { + Array(m.intercept) ++ m.coefficients.toArray + } else { + m.coefficients.toArray + } + } + } + } + + def getModelDevianceResiduals(model: PipelineModel): Array[Double] = { model.stages.last match { case m: LinearRegressionModel => - Array(m.intercept) ++ m.weights.toArray - case _: LogisticRegressionModel => + m.summary.devianceResiduals + case m: LogisticRegressionModel => throw new UnsupportedOperationException( - "No weights available for LogisticRegressionModel") // SPARK-9492 + "No deviance residuals available for LogisticRegressionModel") } } @@ -61,10 +90,28 @@ private[r] object SparkRWrappers { case m: LinearRegressionModel => val attrs = AttributeGroup.fromStructField( m.summary.predictions.schema(m.summary.featuresCol)) - Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) - case _: LogisticRegressionModel => - throw new UnsupportedOperationException( - "No features names available for LogisticRegressionModel") // SPARK-9492 + if (m.getFitIntercept) { + Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) + } else { + attrs.attributes.get.map(_.name.get) + } + case m: LogisticRegressionModel => + val attrs = AttributeGroup.fromStructField( + m.summary.predictions.schema(m.summary.featuresCol)) + if (m.getFitIntercept) { + Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) + } else { + attrs.attributes.get.map(_.name.get) + } + } + } + + def getModelName(model: PipelineModel): String = { + model.stages.last match { + case m: LinearRegressionModel => + 
"LinearRegressionModel" + case m: LogisticRegressionModel => + "LogisticRegressionModel" } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 2e44cd4cc6a2..b798aa1fab76 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -26,16 +26,17 @@ import scala.util.Sorting import scala.util.hashing.byteswap64 import com.github.fommil.netlib.BLAS.{getInstance => blas} -import com.github.fommil.netlib.LAPACK.{getInstance => lapack} import org.apache.hadoop.fs.{FileSystem, Path} -import org.netlib.util.intW +import org.json4s.DefaultFormats +import org.json4s.JsonDSL._ import org.apache.spark.{Logging, Partitioner} -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{Since, DeveloperApi, Experimental} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.linalg.CholeskyDecomposition import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame @@ -183,7 +184,7 @@ class ALSModel private[ml] ( val rank: Int, @transient val userFactors: DataFrame, @transient val itemFactors: DataFrame) - extends Model[ALSModel] with ALSModelParams { + extends Model[ALSModel] with ALSModelParams with MLWritable { /** @group setParam */ def setUserCol(value: String): this.type = set(userCol, value) @@ -219,10 +220,55 @@ class ALSModel private[ml] ( override def copy(extra: ParamMap): ALSModel = { val copied = new ALSModel(uid, rank, userFactors, itemFactors) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: MLWriter = new ALSModel.ALSModelWriter(this) } +@Since("1.6.0") +object ALSModel extends MLReadable[ALSModel] { + + @Since("1.6.0") + override def read: MLReader[ALSModel] = new ALSModelReader + + @Since("1.6.0") + override def load(path: String): ALSModel = super.load(path) + + private[ALSModel] class ALSModelWriter(instance: ALSModel) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val extraMetadata = "rank" -> instance.rank + DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata)) + val userPath = new Path(path, "userFactors").toString + instance.userFactors.write.format("parquet").save(userPath) + val itemPath = new Path(path, "itemFactors").toString + instance.itemFactors.write.format("parquet").save(itemPath) + } + } + + private class ALSModelReader extends MLReader[ALSModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[ALSModel].getName + + override def load(path: String): ALSModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + implicit val format = DefaultFormats + val rank = (metadata.metadata \ "rank").extract[Int] + val userPath = new Path(path, "userFactors").toString + val userFactors = sqlContext.read.format("parquet").load(userPath) + val itemPath = new Path(path, "itemFactors").toString + val itemFactors = sqlContext.read.format("parquet").load(itemPath) + + val model = new ALSModel(metadata.uid, rank, userFactors, itemFactors) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } +} /** * :: 
Experimental :: @@ -255,7 +301,8 @@ class ALSModel private[ml] ( * preferences rather than explicit ratings given to items. */ @Experimental -class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { +class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams + with DefaultParamsWritable { import org.apache.spark.ml.recommendation.ALS.Rating @@ -315,9 +362,9 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { override def fit(dataset: DataFrame): ALSModel = { import dataset.sqlContext.implicits._ + val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f) val ratings = dataset - .select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType), - col($(ratingCol)).cast(FloatType)) + .select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType), r) .map { row => Rating(row.getInt(0), row.getInt(1), row.getFloat(2)) } @@ -339,6 +386,7 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { override def copy(extra: ParamMap): ALS = defaultCopy(extra) } + /** * :: DeveloperApi :: * An implementation of ALS that supports generic ID types, specialized for Int and Long. This is @@ -348,7 +396,7 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { * than 2 billion. */ @DeveloperApi -object ALS extends Logging { +object ALS extends DefaultParamsReadable[ALS] with Logging { /** * :: DeveloperApi :: @@ -357,6 +405,9 @@ object ALS extends Logging { @DeveloperApi case class Rating[@specialized(Int, Long) ID](user: ID, item: ID, rating: Float) + @Since("1.6.0") + override def load(path: String): ALS = super.load(path) + /** Trait for least squares solvers applied to the normal equation. */ private[recommendation] trait LeastSquaresNESolver extends Serializable { /** Solves a least squares problem with regularization (possibly with other constraints). */ @@ -366,8 +417,6 @@ object ALS extends Logging { /** Cholesky solver for least square problems. */ private[recommendation] class CholeskySolver extends LeastSquaresNESolver { - private val upper = "U" - /** * Solves a least squares problem with L2 regularization: * @@ -387,10 +436,7 @@ object ALS extends Logging { i += j j += 1 } - val info = new intW(0) - lapack.dppsv(upper, k, 1, ne.ata, ne.atb, k, info) - val code = info.`val` - assert(code == 0, s"lapack.dppsv returned $code.") + CholeskyDecomposition.solve(ne.ata, ne.atb) val x = new Array[Float](k) i = 0 while (i < k) { @@ -561,7 +607,7 @@ object ALS extends Logging { var itemFactors = initialize(itemInBlocks, rank, seedGen.nextLong()) var previousCheckpointFile: Option[String] = None val shouldCheckpoint: Int => Boolean = (iter) => - sc.checkpointDir.isDefined && (iter % checkpointInterval == 0) + sc.checkpointDir.isDefined && checkpointInterval != -1 && (iter % checkpointInterval == 0) val deletePreviousCheckpointFile: () => Unit = () => previousCheckpointFile.foreach { file => try { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala new file mode 100644 index 000000000000..aedfb48058dc --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -0,0 +1,542 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import scala.collection.mutable + +import breeze.linalg.{DenseVector => BDV} +import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS} +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util._ +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{DoubleType, StructType} +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.{Logging, SparkException} + +/** + * Params for accelerated failure time (AFT) regression. + */ +private[regression] trait AFTSurvivalRegressionParams extends Params + with HasFeaturesCol with HasLabelCol with HasPredictionCol with HasMaxIter + with HasTol with HasFitIntercept with Logging { + + /** + * Param for censor column name. + * The value of this column could be 0 or 1. + * If the value is 1, it means the event has occurred i.e. uncensored; otherwise censored. + * @group param + */ + @Since("1.6.0") + final val censorCol: Param[String] = new Param(this, "censorCol", "censor column name") + + /** @group getParam */ + @Since("1.6.0") + def getCensorCol: String = $(censorCol) + setDefault(censorCol -> "censor") + + /** + * Param for quantile probabilities array. + * Values of the quantile probabilities array should be in the range (0, 1) + * and the array should be non-empty. + * @group param + */ + @Since("1.6.0") + final val quantileProbabilities: DoubleArrayParam = new DoubleArrayParam(this, + "quantileProbabilities", "quantile probabilities array", + (t: Array[Double]) => t.forall(ParamValidators.inRange(0, 1, false, false)) && t.length > 0) + + /** @group getParam */ + @Since("1.6.0") + def getQuantileProbabilities: Array[Double] = $(quantileProbabilities) + setDefault(quantileProbabilities -> Array(0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99)) + + /** + * Param for quantiles column name. + * This column will output quantiles of corresponding quantileProbabilities if it is set. + * @group param + */ + @Since("1.6.0") + final val quantilesCol: Param[String] = new Param(this, "quantilesCol", "quantiles column name") + + /** @group getParam */ + @Since("1.6.0") + def getQuantilesCol: String = $(quantilesCol) + + /** Checks whether the input has quantiles column name. */ + protected[regression] def hasQuantilesCol: Boolean = { + isDefined(quantilesCol) && $(quantilesCol) != "" + } + + /** + * Validates and transforms the input schema with the provided param map. 
+ * @param schema input schema + * @param fitting whether this is in fitting or prediction + * @return output schema + */ + protected def validateAndTransformSchema( + schema: StructType, + fitting: Boolean): StructType = { + SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + if (fitting) { + SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType) + SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + } + if (hasQuantilesCol) { + SchemaUtils.appendColumn(schema, $(quantilesCol), new VectorUDT) + } + SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType) + } +} + +/** + * :: Experimental :: + * Fit a parametric survival regression model named accelerated failure time (AFT) model + * ([[https://en.wikipedia.org/wiki/Accelerated_failure_time_model]]) + * based on the Weibull distribution of the survival time. + */ +@Experimental +@Since("1.6.0") +class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: String) + extends Estimator[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams + with DefaultParamsWritable with Logging { + + @Since("1.6.0") + def this() = this(Identifiable.randomUID("aftSurvReg")) + + /** @group setParam */ + @Since("1.6.0") + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + @Since("1.6.0") + def setLabelCol(value: String): this.type = set(labelCol, value) + + /** @group setParam */ + @Since("1.6.0") + def setCensorCol(value: String): this.type = set(censorCol, value) + + /** @group setParam */ + @Since("1.6.0") + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + @Since("1.6.0") + def setQuantileProbabilities(value: Array[Double]): this.type = set(quantileProbabilities, value) + + /** @group setParam */ + @Since("1.6.0") + def setQuantilesCol(value: String): this.type = set(quantilesCol, value) + + /** + * Set if we should fit the intercept + * Default is true. + * @group setParam + */ + @Since("1.6.0") + def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) + setDefault(fitIntercept -> true) + + /** + * Set the maximum number of iterations. + * Default is 100. + * @group setParam + */ + @Since("1.6.0") + def setMaxIter(value: Int): this.type = set(maxIter, value) + setDefault(maxIter -> 100) + + /** + * Set the convergence tolerance of iterations. + * Smaller value will lead to higher accuracy with the cost of more iterations. + * Default is 1E-6. + * @group setParam + */ + @Since("1.6.0") + def setTol(value: Double): this.type = set(tol, value) + setDefault(tol -> 1E-6) + + /** + * Extract [[featuresCol]], [[labelCol]] and [[censorCol]] from input dataset, + * and put it in an RDD with strong types. 
+ */ + protected[ml] def extractAFTPoints(dataset: DataFrame): RDD[AFTPoint] = { + dataset.select($(featuresCol), $(labelCol), $(censorCol)).map { + case Row(features: Vector, label: Double, censor: Double) => + AFTPoint(features, label, censor) + } + } + + @Since("1.6.0") + override def fit(dataset: DataFrame): AFTSurvivalRegressionModel = { + validateAndTransformSchema(dataset.schema, fitting = true) + val instances = extractAFTPoints(dataset) + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE + if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) + + val costFun = new AFTCostFun(instances, $(fitIntercept)) + val optimizer = new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) + + val numFeatures = dataset.select($(featuresCol)).take(1)(0).getAs[Vector](0).size + /* + The parameters vector has three parts: + the first element: Double, log(sigma), the log of scale parameter + the second element: Double, intercept of the beta parameter + the third to the end elements: Doubles, regression coefficients vector of the beta parameter + */ + val initialParameters = Vectors.zeros(numFeatures + 2) + + val states = optimizer.iterations(new CachedDiffFunction(costFun), + initialParameters.toBreeze.toDenseVector) + + val parameters = { + val arrayBuilder = mutable.ArrayBuilder.make[Double] + var state: optimizer.State = null + while (states.hasNext) { + state = states.next() + arrayBuilder += state.adjustedValue + } + if (state == null) { + val msg = s"${optimizer.getClass.getName} failed." + throw new SparkException(msg) + } + + state.x.toArray.clone() + } + + if (handlePersistence) instances.unpersist() + + val coefficients = Vectors.dense(parameters.slice(2, parameters.length)) + val intercept = parameters(1) + val scale = math.exp(parameters(0)) + val model = new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale) + copyValues(model.setParent(this)) + } + + @Since("1.6.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema, fitting = true) + } + + @Since("1.6.0") + override def copy(extra: ParamMap): AFTSurvivalRegression = defaultCopy(extra) +} + +@Since("1.6.0") +object AFTSurvivalRegression extends DefaultParamsReadable[AFTSurvivalRegression] { + + @Since("1.6.0") + override def load(path: String): AFTSurvivalRegression = super.load(path) +} + +/** + * :: Experimental :: + * Model produced by [[AFTSurvivalRegression]]. 
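+ *
+ * A minimal usage sketch (`training` is assumed to be an existing DataFrame with the default
+ * "features", "label" and "censor" columns):
+ * {{{
+ *   val model = new AFTSurvivalRegression()
+ *     .setQuantileProbabilities(Array(0.3, 0.6))
+ *     .setQuantilesCol("quantiles")
+ *     .fit(training)
+ *   model.transform(training).select("prediction", "quantiles").show()
+ * }}}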
+ */ +@Experimental +@Since("1.6.0") +class AFTSurvivalRegressionModel private[ml] ( + @Since("1.6.0") override val uid: String, + @Since("1.6.0") val coefficients: Vector, + @Since("1.6.0") val intercept: Double, + @Since("1.6.0") val scale: Double) + extends Model[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams with MLWritable { + + /** @group setParam */ + @Since("1.6.0") + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + @Since("1.6.0") + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + @Since("1.6.0") + def setQuantileProbabilities(value: Array[Double]): this.type = set(quantileProbabilities, value) + + /** @group setParam */ + @Since("1.6.0") + def setQuantilesCol(value: String): this.type = set(quantilesCol, value) + + @Since("1.6.0") + def predictQuantiles(features: Vector): Vector = { + // scale parameter for the Weibull distribution of lifetime + val lambda = math.exp(BLAS.dot(coefficients, features) + intercept) + // shape parameter for the Weibull distribution of lifetime + val k = 1 / scale + val quantiles = $(quantileProbabilities).map { + q => lambda * math.exp(math.log(-math.log(1 - q)) / k) + } + Vectors.dense(quantiles) + } + + @Since("1.6.0") + def predict(features: Vector): Double = { + math.exp(BLAS.dot(coefficients, features) + intercept) + } + + @Since("1.6.0") + override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema) + val predictUDF = udf { features: Vector => predict(features) } + val predictQuantilesUDF = udf { features: Vector => predictQuantiles(features)} + if (hasQuantilesCol) { + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + .withColumn($(quantilesCol), predictQuantilesUDF(col($(featuresCol)))) + } else { + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + } + } + + @Since("1.6.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema, fitting = false) + } + + @Since("1.6.0") + override def copy(extra: ParamMap): AFTSurvivalRegressionModel = { + copyValues(new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale), extra) + .setParent(parent) + } + + @Since("1.6.0") + override def write: MLWriter = + new AFTSurvivalRegressionModel.AFTSurvivalRegressionModelWriter(this) +} + +@Since("1.6.0") +object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] { + + @Since("1.6.0") + override def read: MLReader[AFTSurvivalRegressionModel] = new AFTSurvivalRegressionModelReader + + @Since("1.6.0") + override def load(path: String): AFTSurvivalRegressionModel = super.load(path) + + /** [[MLWriter]] instance for [[AFTSurvivalRegressionModel]] */ + private[AFTSurvivalRegressionModel] class AFTSurvivalRegressionModelWriter ( + instance: AFTSurvivalRegressionModel + ) extends MLWriter with Logging { + + private case class Data(coefficients: Vector, intercept: Double, scale: Double) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: coefficients, intercept, scale + val data = Data(instance.coefficients, instance.intercept, instance.scale) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class AFTSurvivalRegressionModelReader extends MLReader[AFTSurvivalRegressionModel] { + + /** Checked 
against metadata when loading model */
+    private val className = classOf[AFTSurvivalRegressionModel].getName
+
+    override def load(path: String): AFTSurvivalRegressionModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath)
+        .select("coefficients", "intercept", "scale").head()
+      val coefficients = data.getAs[Vector](0)
+      val intercept = data.getDouble(1)
+      val scale = data.getDouble(2)
+      val model = new AFTSurvivalRegressionModel(metadata.uid, coefficients, intercept, scale)
+
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
+}
+
+/**
+ * AFTAggregator computes the gradient and loss for an AFT loss function,
+ * as used in AFT survival regression for samples in sparse or dense vectors, in an online fashion.
+ *
+ * The loss function and likelihood function under the AFT model are based on:
+ * Lawless, J. F., Statistical Models and Methods for Lifetime Data,
+ * New York: John Wiley & Sons, Inc. 2003.
+ *
+ * Two AFTAggregators can be merged together to produce a summary of the loss and gradient of
+ * the corresponding joint dataset.
+ *
+ * Given the values of the covariates x^{'}, for random lifetime t_{i} of subjects i = 1, ..., n,
+ * with possible right-censoring, the likelihood function under the AFT model is given as
+ * {{{
+ *   L(\beta,\sigma)=\prod_{i=1}^n[\frac{1}{\sigma}f_{0}
+ *   (\frac{\log{t_{i}}-x^{'}\beta}{\sigma})]^{\delta_{i}}S_{0}
+ *   (\frac{\log{t_{i}}-x^{'}\beta}{\sigma})^{1-\delta_{i}}
+ * }}}
+ * where \delta_{i} indicates whether the event has occurred, i.e. whether the observation
+ * is uncensored.
+ * Using \epsilon_{i}=\frac{\log{t_{i}}-x^{'}\beta}{\sigma}, the log-likelihood function
+ * assumes the form
+ * {{{
+ *   \iota(\beta,\sigma)=\sum_{i=1}^{n}[-\delta_{i}\log\sigma+
+ *   \delta_{i}\log{f_{0}}(\epsilon_{i})+(1-\delta_{i})\log{S_{0}(\epsilon_{i})}]
+ * }}}
+ * where S_{0}(\epsilon_{i}) is the baseline survivor function,
+ * and f_{0}(\epsilon_{i}) is the corresponding density function.
+ *
+ * The most commonly used log-linear survival regression method is based on the Weibull
+ * distribution of the survival time. A Weibull distribution for the lifetime corresponds
+ * to an extreme value distribution for the log of the lifetime,
+ * and the S_{0}(\epsilon) function is
+ * {{{
+ *   S_{0}(\epsilon_{i})=\exp(-e^{\epsilon_{i}})
+ * }}}
+ * and the f_{0}(\epsilon_{i}) function is
+ * {{{
+ *   f_{0}(\epsilon_{i})=e^{\epsilon_{i}}\exp(-e^{\epsilon_{i}})
+ * }}}
+ * The log-likelihood function for the Weibull distribution of lifetime is
+ * {{{
+ *   \iota(\beta,\sigma)=
+ *   -\sum_{i=1}^n[\delta_{i}\log\sigma-\delta_{i}\epsilon_{i}+e^{\epsilon_{i}}]
+ * }}}
+ * Because minimizing the negative log-likelihood is equivalent to maximizing the likelihood,
+ * the loss function we optimize is -\iota(\beta,\sigma).
+ * The gradient functions for \beta and \log\sigma respectively are
+ * {{{
+ *   \frac{\partial (-\iota)}{\partial \beta}=
+ *   \sum_{i=1}^{n}[\delta_{i}-e^{\epsilon_{i}}]\frac{x_{i}}{\sigma}
+ * }}}
+ * {{{
+ *   \frac{\partial (-\iota)}{\partial (\log\sigma)}=
+ *   \sum_{i=1}^{n}[\delta_{i}+(\delta_{i}-e^{\epsilon_{i}})\epsilon_{i}]
+ * }}}
+ * @param parameters Parameters consisting of three parts: the log of the scale parameter,
+ *                   the intercept, and the regression coefficients corresponding to the features.
+ * @param fitIntercept Whether to fit an intercept term.
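+ * For intuition, a single uncensored observation (\delta_{i} = 1) contributes the following,
+ * per the formulas above (this mirrors what the `add` method below accumulates):
+ * {{{
+ *   loss              += \log\sigma - \epsilon_{i} + e^{\epsilon_{i}}
+ *   grad_\beta        += (1 - e^{\epsilon_{i}}) x_{i} / \sigma
+ *   grad_{\log\sigma} += 1 + (1 - e^{\epsilon_{i}}) \epsilon_{i}
+ * }}}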
+ */
+private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
+  extends Serializable {
+
+  // beta holds the intercept and the regression coefficients for the covariates
+  private val beta = parameters.slice(1, parameters.length)
+  // sigma is the scale parameter of the AFT model
+  private val sigma = math.exp(parameters(0))
+
+  private var totalCnt: Long = 0L
+  private var lossSum = 0.0
+  private var gradientBetaSum = BDV.zeros[Double](beta.length)
+  private var gradientLogSigmaSum = 0.0
+
+  def count: Long = totalCnt
+
+  def loss: Double = if (totalCnt == 0) 1.0 else lossSum / totalCnt
+
+  // Here we optimize the loss function over beta and log(sigma)
+  def gradient: BDV[Double] = BDV.vertcat(BDV(Array(gradientLogSigmaSum / totalCnt.toDouble)),
+    gradientBetaSum / totalCnt.toDouble)
+
+  /**
+   * Add a new data point to this AFTAggregator, and update the loss and gradient
+   * of the objective function.
+   *
+   * @param data The AFTPoint representation for one data point to be added into this aggregator.
+   * @return This AFTAggregator object.
+   */
+  def add(data: AFTPoint): this.type = {
+
+    // TODO: Don't create a new xi vector each time.
+    val xi = if (fitIntercept) {
+      Vectors.dense(Array(1.0) ++ data.features.toArray).toBreeze
+    } else {
+      Vectors.dense(Array(0.0) ++ data.features.toArray).toBreeze
+    }
+    val ti = data.label
+    val delta = data.censor
+    val epsilon = (math.log(ti) - beta.dot(xi)) / sigma
+
+    lossSum += math.log(sigma) * delta
+    lossSum += (math.exp(epsilon) - delta * epsilon)
+
+    // Sanity check (should never occur):
+    assert(!lossSum.isInfinity,
+      s"AFTAggregator loss sum is infinity. Error for unknown reason.")
+
+    gradientBetaSum += xi * (delta - math.exp(epsilon)) / sigma
+    gradientLogSigmaSum += delta + (delta - math.exp(epsilon)) * epsilon
+
+    totalCnt += 1
+    this
+  }
+
+  /**
+   * Merge another AFTAggregator, and update the loss and gradient
+   * of the objective function.
+   * (Note that it's in-place merging; as a result, `this` object will be modified.)
+   *
+   * @param other The other AFTAggregator to be merged.
+   * @return This AFTAggregator object.
+   */
+  def merge(other: AFTAggregator): this.type = {
+    // Merge only when `other` holds data; checking this.totalCnt instead would silently drop
+    // other's contributions whenever this aggregator is the empty (zero) value.
+    if (other.count != 0) {
+      totalCnt += other.totalCnt
+      lossSum += other.lossSum
+
+      gradientBetaSum += other.gradientBetaSum
+      gradientLogSigmaSum += other.gradientLogSigmaSum
+    }
+    this
+  }
+}
+
+/**
+ * AFTCostFun implements Breeze's DiffFunction[T] for the AFT cost.
+ * It returns the loss and gradient at a particular point (parameters).
+ * It's used in Breeze's convex optimization routines.
+ */
+private class AFTCostFun(data: RDD[AFTPoint], fitIntercept: Boolean)
+  extends DiffFunction[BDV[Double]] {
+
+  override def calculate(parameters: BDV[Double]): (Double, BDV[Double]) = {
+
+    val aftAggregator = data.treeAggregate(new AFTAggregator(parameters, fitIntercept))(
+      seqOp = (c, v) => (c, v) match {
+        case (aggregator, instance) => aggregator.add(instance)
+      },
+      combOp = (c1, c2) => (c1, c2) match {
+        case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
+      })
+
+    (aftAggregator.loss, aftAggregator.gradient)
+  }
+}
+
+/**
+ * Class that represents the (features, label, censor) of a data point.
+ *
+ * @param features List of features for this data point.
+ * @param label Label for this data point.
+ * @param censor Indicator of whether the event has occurred. If the value is 1, the event has
+ *               occurred, i.e. the observation is uncensored; otherwise it is censored.
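+ *
+ * A hypothetical example of an uncensored observation with lifetime 2.0:
+ * {{{
+ *   AFTPoint(Vectors.dense(1.2, -0.3), label = 2.0, censor = 1.0)
+ * }}}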
+ */ +private[regression] case class AFTPoint(features: Vector, label: Double, censor: Double) { + require(censor == 1.0 || censor == 0.0, "censor of class AFTPoint must be 1.0 or 0.0") +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 4d30e4b5548a..477030d9ea3e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.regression -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeRegressorParams} @@ -36,39 +36,50 @@ import org.apache.spark.sql.DataFrame * for regression. * It supports both continuous and categorical features. */ +@Since("1.4.0") @Experimental -final class DecisionTreeRegressor(override val uid: String) +final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel] with DecisionTreeParams with TreeRegressorParams { + @Since("1.4.0") def this() = this(Identifiable.randomUID("dtr")) // Override parameter setters from parent trait for Java API compatibility. - + @Since("1.4.0") override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + @Since("1.4.0") override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + @Since("1.4.0") override def setMinInstancesPerNode(value: Int): this.type = super.setMinInstancesPerNode(value) + @Since("1.4.0") override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + @Since("1.4.0") override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + @Since("1.4.0") override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + @Since("1.4.0") override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) + @Since("1.4.0") override def setImpurity(value: String): this.type = super.setImpurity(value) + override def setSeed(value: Long): this.type = super.setSeed(value) + override protected def train(dataset: DataFrame): DecisionTreeRegressionModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val strategy = getOldStrategy(categoricalFeatures) val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all", - seed = 0L, parentUID = Some(uid)) + seed = $(seed), parentUID = Some(uid)) trees.head.asInstanceOf[DecisionTreeRegressionModel] } @@ -78,9 +89,11 @@ final class DecisionTreeRegressor(override val uid: String) subsamplingRate = 1.0) } + @Since("1.4.0") override def copy(extra: ParamMap): DecisionTreeRegressor = defaultCopy(extra) } +@Since("1.4.0") @Experimental object DecisionTreeRegressor { /** Accessor for supported impurities: variance */ @@ -93,10 +106,12 @@ object DecisionTreeRegressor { * It supports both continuous and categorical features. 
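+ * A hypothetical inspection sketch (the members below are defined in this class):
+ * {{{
+ *   val model: DecisionTreeRegressionModel = ...  // e.g. produced by DecisionTreeRegressor.fit
+ *   model.rootNode      // root Node of the fitted tree
+ *   model.numFeatures   // number of features; -1 when converted from the old API without a count
+ * }}}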
* @param rootNode Root of the decision tree */ +@Since("1.4.0") @Experimental final class DecisionTreeRegressionModel private[ml] ( override val uid: String, - override val rootNode: Node) + override val rootNode: Node, + override val numFeatures: Int) extends PredictionModel[Vector, DecisionTreeRegressionModel] with DecisionTreeModel with Serializable { @@ -107,18 +122,21 @@ final class DecisionTreeRegressionModel private[ml] ( * Construct a decision tree regression model. * @param rootNode Root node of tree, with other nodes attached. */ - def this(rootNode: Node) = this(Identifiable.randomUID("dtr"), rootNode) + private[ml] def this(rootNode: Node, numFeatures: Int) = + this(Identifiable.randomUID("dtr"), rootNode, numFeatures) override protected def predict(features: Vector): Double = { rootNode.predictImpl(features).prediction } + @Since("1.4.0") override def copy(extra: ParamMap): DecisionTreeRegressionModel = { - copyValues(new DecisionTreeRegressionModel(uid, rootNode), extra) + copyValues(new DecisionTreeRegressionModel(uid, rootNode, numFeatures), extra).setParent(parent) } + @Since("1.4.0") override def toString: String = { - s"DecisionTreeRegressionModel of depth $depth with $numNodes nodes" + s"DecisionTreeRegressionModel (uid=$uid) of depth $depth with $numNodes nodes" } /** Convert to a model in the old API */ @@ -133,12 +151,13 @@ private[ml] object DecisionTreeRegressionModel { def fromOld( oldModel: OldDecisionTreeModel, parent: DecisionTreeRegressor, - categoricalFeatures: Map[Int, Int]): DecisionTreeRegressionModel = { + categoricalFeatures: Map[Int, Int], + numFeatures: Int = -1): DecisionTreeRegressionModel = { require(oldModel.algo == OldAlgo.Regression, s"Cannot convert non-regression DecisionTreeModel (old API) to" + s" DecisionTreeRegressionModel (new API). Algo is: ${oldModel.algo}") val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtr") - new DecisionTreeRegressionModel(uid, rootNode) + new DecisionTreeRegressionModel(uid, rootNode, numFeatures) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 5633bc320273..07144cc7cfbd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.regression import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeEnsembleModel, TreeRegressorParams} @@ -42,54 +42,65 @@ import org.apache.spark.sql.types.DoubleType * learning algorithm for regression. * It supports both continuous and categorical features. */ +@Since("1.4.0") @Experimental -final class GBTRegressor(override val uid: String) +final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Predictor[Vector, GBTRegressor, GBTRegressionModel] with GBTParams with TreeRegressorParams with Logging { + @Since("1.4.0") def this() = this(Identifiable.randomUID("gbtr")) // Override parameter setters from parent trait for Java API compatibility. 
// Parameters from TreeRegressorParams: - + @Since("1.4.0") override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + @Since("1.4.0") override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + @Since("1.4.0") override def setMinInstancesPerNode(value: Int): this.type = super.setMinInstancesPerNode(value) + @Since("1.4.0") override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + @Since("1.4.0") override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + @Since("1.4.0") override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + @Since("1.4.0") override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) /** * The impurity setting is ignored for GBT models. * Individual trees are built using impurity "Variance." */ + @Since("1.4.0") override def setImpurity(value: String): this.type = { logWarning("GBTRegressor.setImpurity should NOT be used") this } // Parameters from TreeEnsembleParams: - + @Since("1.4.0") override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value) + @Since("1.4.0") override def setSeed(value: Long): this.type = { logWarning("The 'seed' parameter is currently ignored by Gradient Boosting.") super.setSeed(value) } // Parameters from GBTParams: - + @Since("1.4.0") override def setMaxIter(value: Int): this.type = super.setMaxIter(value) + @Since("1.4.0") override def setStepSize(value: Double): this.type = super.setStepSize(value) // Parameters for GBTRegressor: @@ -100,6 +111,7 @@ final class GBTRegressor(override val uid: String) * (default = squared) * @group param */ + @Since("1.4.0") val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + " tries to minimize (case-insensitive). Supported options:" + s" ${GBTRegressor.supportedLossTypes.mkString(", ")}", @@ -108,9 +120,11 @@ final class GBTRegressor(override val uid: String) setDefault(lossType -> "squared") /** @group setParam */ + @Since("1.4.0") def setLossType(value: String): this.type = set(lossType, value) /** @group getParam */ + @Since("1.4.0") def getLossType: String = $(lossType).toLowerCase /** (private[ml]) Convert new loss to old loss. */ @@ -128,19 +142,23 @@ final class GBTRegressor(override val uid: String) val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) + val numFeatures = oldDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) val oldGBT = new OldGBT(boostingStrategy) val oldModel = oldGBT.run(oldDataset) - GBTRegressionModel.fromOld(oldModel, this, categoricalFeatures) + GBTRegressionModel.fromOld(oldModel, this, categoricalFeatures, numFeatures) } + @Since("1.4.0") override def copy(extra: ParamMap): GBTRegressor = defaultCopy(extra) } +@Since("1.4.0") @Experimental object GBTRegressor { // The losses below should be lowercase. /** Accessor for supported loss settings: squared (L2), absolute (L1) */ + @Since("1.4.0") final val supportedLossTypes: Array[String] = Array("squared", "absolute").map(_.toLowerCase) } @@ -153,11 +171,13 @@ object GBTRegressor { * @param _trees Decision trees in the ensemble. * @param _treeWeights Weights for the decision trees in the ensemble. 
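+ *
+ * Informally, the prediction for a feature vector `x` is the weighted sum computed via
+ * `blas.ddot` below (a restatement, not new behavior):
+ * {{{
+ *   prediction(x) = \sum_m _treeWeights(m) * (prediction of _trees(m) at x)
+ * }}}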
*/ +@Since("1.4.0") @Experimental -final class GBTRegressionModel( +final class GBTRegressionModel private[ml]( override val uid: String, private val _trees: Array[DecisionTreeRegressionModel], - private val _treeWeights: Array[Double]) + private val _treeWeights: Array[Double], + override val numFeatures: Int) extends PredictionModel[Vector, GBTRegressionModel] with TreeEnsembleModel with Serializable { @@ -165,8 +185,19 @@ final class GBTRegressionModel( require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" + s" non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).") + /** + * Construct a GBTRegressionModel + * @param _trees Decision trees in the ensemble. + * @param _treeWeights Weights for the decision trees in the ensemble. + */ + @Since("1.4.0") + def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) = + this(uid, _trees, _treeWeights, -1) + + @Since("1.4.0") override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] + @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights override protected def transformImpl(dataset: DataFrame): DataFrame = { @@ -184,12 +215,15 @@ final class GBTRegressionModel( blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) } + @Since("1.4.0") override def copy(extra: ParamMap): GBTRegressionModel = { - copyValues(new GBTRegressionModel(uid, _trees, _treeWeights), extra) + copyValues(new GBTRegressionModel(uid, _trees, _treeWeights, numFeatures), + extra).setParent(parent) } + @Since("1.4.0") override def toString: String = { - s"GBTRegressionModel with $numTrees trees" + s"GBTRegressionModel (uid=$uid) with $numTrees trees" } /** (private[ml]) Convert to a model in the old API */ @@ -204,7 +238,8 @@ private[ml] object GBTRegressionModel { def fromOld( oldModel: OldGBTModel, parent: GBTRegressor, - categoricalFeatures: Map[Int, Int]): GBTRegressionModel = { + categoricalFeatures: Map[Int, Int], + numFeatures: Int = -1): GBTRegressionModel = { require(oldModel.algo == OldAlgo.Regression, "Cannot convert GradientBoostedTreesModel" + s" with algo=${oldModel.algo} (old API) to GBTRegressionModel (new API).") val newTrees = oldModel.trees.map { tree => @@ -212,6 +247,6 @@ private[ml] object GBTRegressionModel { DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtr") - new GBTRegressionModel(parent.uid, newTrees, oldModel.treeWeights) + new GBTRegressionModel(parent.uid, newTrees, oldModel.treeWeights, numFeatures) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index 4ece8cf8cf0b..bbb1c7ac0a51 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -17,44 +17,106 @@ package org.apache.spark.ml.regression -import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.PredictorParams -import org.apache.spark.ml.param.{Param, ParamMap, BooleanParam} -import org.apache.spark.ml.util.{SchemaUtils, Identifiable} +import org.apache.hadoop.fs.Path + +import org.apache.spark.Logging +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import 
org.apache.spark.ml.regression.IsotonicRegressionModel.IsotonicRegressionModelWriter +import org.apache.spark.ml.util._ +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression} import org.apache.spark.mllib.regression.{IsotonicRegressionModel => MLlibIsotonicRegressionModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.types.{DoubleType, DataType} -import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.sql.functions.{col, lit, udf} +import org.apache.spark.sql.types.{DoubleType, StructType} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.storage.StorageLevel /** * Params for isotonic regression. */ -private[regression] trait IsotonicRegressionParams extends PredictorParams { +private[regression] trait IsotonicRegressionBase extends Params with HasFeaturesCol + with HasLabelCol with HasPredictionCol with HasWeightCol with Logging { /** - * Param for weight column name. - * TODO: Move weightCol to sharedParams. - * + * Param for whether the output sequence should be isotonic/increasing (true) or + * antitonic/decreasing (false). + * Default: true * @group param */ - final val weightCol: Param[String] = - new Param[String](this, "weightCol", "weight column name") + final val isotonic: BooleanParam = + new BooleanParam(this, "isotonic", + "whether the output sequence should be isotonic/increasing (true) or" + + "antitonic/decreasing (false)") /** @group getParam */ - final def getWeightCol: String = $(weightCol) + final def getIsotonic: Boolean = $(isotonic) /** - * Param for isotonic parameter. - * Isotonic (increasing) or antitonic (decreasing) sequence. + * Param for the index of the feature if [[featuresCol]] is a vector column (default: `0`), no + * effect otherwise. * @group param */ - final val isotonic: BooleanParam = - new BooleanParam(this, "isotonic", "isotonic (increasing) or antitonic (decreasing) sequence") + final val featureIndex: IntParam = new IntParam(this, "featureIndex", + "The index of the feature if featuresCol is a vector column, no effect otherwise.") /** @group getParam */ - final def getIsotonicParam: Boolean = $(isotonic) + final def getFeatureIndex: Int = $(featureIndex) + + setDefault(isotonic -> true, featureIndex -> 0) + + /** Checks whether the input has weight column. */ + protected[ml] def hasWeightCol: Boolean = { + isDefined(weightCol) && $(weightCol) != "" + } + + /** + * Extracts (label, feature, weight) from input dataset. + */ + protected[ml] def extractWeightedLabeledPoints( + dataset: DataFrame): RDD[(Double, Double, Double)] = { + val f = if (dataset.schema($(featuresCol)).dataType.isInstanceOf[VectorUDT]) { + val idx = $(featureIndex) + val extract = udf { v: Vector => v(idx) } + extract(col($(featuresCol))) + } else { + col($(featuresCol)) + } + val w = if (hasWeightCol) { + col($(weightCol)) + } else { + lit(1.0) + } + dataset.select(col($(labelCol)), f, w) + .map { case Row(label: Double, feature: Double, weight: Double) => + (label, feature, weight) + } + } + + /** + * Validates and transforms input schema. 
+ * @param schema input schema + * @param fitting whether this is in fitting or prediction + * @return output schema + */ + protected[ml] def validateAndTransformSchema( + schema: StructType, + fitting: Boolean): StructType = { + if (fitting) { + SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + if (hasWeightCol) { + SchemaUtils.checkColumnType(schema, $(weightCol), DoubleType) + } else { + logInfo("The weight column is not defined. Treat all instance weights as 1.0.") + } + } + val featuresType = schema($(featuresCol)).dataType + require(featuresType == DoubleType || featuresType.isInstanceOf[VectorUDT]) + SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType) + } } /** @@ -66,56 +128,69 @@ private[regression] trait IsotonicRegressionParams extends PredictorParams { * * Uses [[org.apache.spark.mllib.regression.IsotonicRegression]]. */ +@Since("1.5.0") @Experimental -class IsotonicRegression(override val uid: String) - extends Regressor[Double, IsotonicRegression, IsotonicRegressionModel] - with IsotonicRegressionParams { +class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: String) + extends Estimator[IsotonicRegressionModel] + with IsotonicRegressionBase with DefaultParamsWritable { + @Since("1.5.0") def this() = this(Identifiable.randomUID("isoReg")) - /** - * Set the isotonic parameter. - * Default is true. - * @group setParam - */ - def setIsotonicParam(value: Boolean): this.type = set(isotonic, value) - setDefault(isotonic -> true) + /** @group setParam */ + @Since("1.5.0") + def setLabelCol(value: String): this.type = set(labelCol, value) - /** - * Set weight column param. - * Default is weight. - * @group setParam - */ - def setWeightParam(value: String): this.type = set(weightCol, value) - setDefault(weightCol -> "weight") + /** @group setParam */ + @Since("1.5.0") + def setFeaturesCol(value: String): this.type = set(featuresCol, value) - override private[ml] def featuresDataType: DataType = DoubleType + /** @group setParam */ + @Since("1.5.0") + def setPredictionCol(value: String): this.type = set(predictionCol, value) - override def copy(extra: ParamMap): IsotonicRegression = defaultCopy(extra) + /** @group setParam */ + @Since("1.5.0") + def setIsotonic(value: Boolean): this.type = set(isotonic, value) - private[this] def extractWeightedLabeledPoints( - dataset: DataFrame): RDD[(Double, Double, Double)] = { + /** @group setParam */ + @Since("1.5.0") + def setWeightCol(value: String): this.type = set(weightCol, value) - dataset.select($(labelCol), $(featuresCol), $(weightCol)) - .map { case Row(label: Double, features: Double, weights: Double) => - (label, features, weights) - } - } + /** @group setParam */ + @Since("1.5.0") + def setFeatureIndex(value: Int): this.type = set(featureIndex, value) + + @Since("1.5.0") + override def copy(extra: ParamMap): IsotonicRegression = defaultCopy(extra) - override protected def train(dataset: DataFrame): IsotonicRegressionModel = { - SchemaUtils.checkColumnType(dataset.schema, $(weightCol), DoubleType) + @Since("1.5.0") + override def fit(dataset: DataFrame): IsotonicRegressionModel = { + validateAndTransformSchema(dataset.schema, fitting = true) // Extract columns from data. If dataset is persisted, do not persist oldDataset. 
val instances = extractWeightedLabeledPoints(dataset) val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) val isotonicRegression = new MLlibIsotonicRegression().setIsotonic($(isotonic)) - val parentModel = isotonicRegression.run(instances) + val oldModel = isotonicRegression.run(instances) - new IsotonicRegressionModel(uid, parentModel) + copyValues(new IsotonicRegressionModel(uid, oldModel).setParent(this)) + } + + @Since("1.5.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema, fitting = true) } } +@Since("1.6.0") +object IsotonicRegression extends DefaultParamsReadable[IsotonicRegression] { + + @Since("1.6.0") + override def load(path: String): IsotonicRegression = super.load(path) +} + /** * :: Experimental :: * Model fitted by IsotonicRegression. @@ -123,22 +198,115 @@ class IsotonicRegression(override val uid: String) * * For detailed rules see [[org.apache.spark.mllib.regression.IsotonicRegressionModel.predict()]]. * - * @param parentModel A [[org.apache.spark.mllib.regression.IsotonicRegressionModel]] - * model trained by [[org.apache.spark.mllib.regression.IsotonicRegression]]. + * @param oldModel A [[org.apache.spark.mllib.regression.IsotonicRegressionModel]] + * model trained by [[org.apache.spark.mllib.regression.IsotonicRegression]]. */ +@Since("1.5.0") +@Experimental class IsotonicRegressionModel private[ml] ( override val uid: String, - private[ml] val parentModel: MLlibIsotonicRegressionModel) - extends RegressionModel[Double, IsotonicRegressionModel] - with IsotonicRegressionParams { + private val oldModel: MLlibIsotonicRegressionModel) + extends Model[IsotonicRegressionModel] with IsotonicRegressionBase with MLWritable { - override def featuresDataType: DataType = DoubleType + /** @group setParam */ + @Since("1.5.0") + def setFeaturesCol(value: String): this.type = set(featuresCol, value) - override protected def predict(features: Double): Double = { - parentModel.predict(features) - } + /** @group setParam */ + @Since("1.5.0") + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + @Since("1.5.0") + def setFeatureIndex(value: Int): this.type = set(featureIndex, value) + + /** Boundaries in increasing order for which predictions are known. */ + @Since("1.5.0") + def boundaries: Vector = Vectors.dense(oldModel.boundaries) + + /** + * Predictions associated with the boundaries at the same index, monotone because of isotonic + * regression. 
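+   * For illustration only, with hypothetical values (see the mllib predict() rules referenced
+   * in the class doc):
+   * {{{
+   *   boundaries  = [1.0, 3.0]
+   *   predictions = [2.0, 6.0]
+   *   transform: feature 2.0  =>  prediction 4.0   (linear interpolation between boundaries)
+   * }}}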
+ */ + @Since("1.5.0") + def predictions: Vector = Vectors.dense(oldModel.predictions) + @Since("1.5.0") override def copy(extra: ParamMap): IsotonicRegressionModel = { - copyValues(new IsotonicRegressionModel(uid, parentModel), extra) + copyValues(new IsotonicRegressionModel(uid, oldModel), extra).setParent(parent) + } + + @Since("1.5.0") + override def transform(dataset: DataFrame): DataFrame = { + val predict = dataset.schema($(featuresCol)).dataType match { + case DoubleType => + udf { feature: Double => oldModel.predict(feature) } + case _: VectorUDT => + val idx = $(featureIndex) + udf { features: Vector => oldModel.predict(features(idx)) } + } + dataset.withColumn($(predictionCol), predict(col($(featuresCol)))) + } + + @Since("1.5.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema, fitting = false) + } + + @Since("1.6.0") + override def write: MLWriter = + new IsotonicRegressionModelWriter(this) +} + +@Since("1.6.0") +object IsotonicRegressionModel extends MLReadable[IsotonicRegressionModel] { + + @Since("1.6.0") + override def read: MLReader[IsotonicRegressionModel] = new IsotonicRegressionModelReader + + @Since("1.6.0") + override def load(path: String): IsotonicRegressionModel = super.load(path) + + /** [[MLWriter]] instance for [[IsotonicRegressionModel]] */ + private[IsotonicRegressionModel] class IsotonicRegressionModelWriter ( + instance: IsotonicRegressionModel + ) extends MLWriter with Logging { + + private case class Data( + boundaries: Array[Double], + predictions: Array[Double], + isotonic: Boolean) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: boundaries, predictions, isotonic + val data = Data( + instance.oldModel.boundaries, instance.oldModel.predictions, instance.oldModel.isotonic) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class IsotonicRegressionModelReader extends MLReader[IsotonicRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[IsotonicRegressionModel].getName + + override def load(path: String): IsotonicRegressionModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("boundaries", "predictions", "isotonic").head() + val boundaries = data.getAs[Seq[Double]](0).toArray + val predictions = data.getAs[Seq[Double]](1).toArray + val isotonic = data.getBoolean(2) + val model = new IsotonicRegressionModel( + metadata.uid, new MLlibIsotonicRegressionModel(boundaries, predictions, isotonic)) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 3b85ba001b12..5e5850963edc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -19,33 +19,34 @@ package org.apache.spark.ml.regression import scala.collection.mutable -import breeze.linalg.{DenseVector => BDV, norm => brzNorm} +import breeze.linalg.{DenseVector => BDV} import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => 
BreezeOWLQN} +import breeze.stats.distributions.StudentsT +import org.apache.hadoop.fs.Path import org.apache.spark.{Logging, SparkException} -import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.optim.WeightedLeastSquares +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS._ -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.functions.{col, udf} -import org.apache.spark.sql.types.StructField +import org.apache.spark.sql.functions._ import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.StatCounter /** * Params for linear regression. */ private[regression] trait LinearRegressionParams extends PredictorParams with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol - with HasFitIntercept + with HasFitIntercept with HasStandardization with HasWeightCol with HasSolver /** * :: Experimental :: @@ -53,7 +54,7 @@ private[regression] trait LinearRegressionParams extends PredictorParams * * The learning objective is to minimize the squared error, with regularization. * The specific squared error loss function used is: - * L = 1/2n ||A weights - y||^2^ + * L = 1/2n ||A coefficients - y||^2^ * * This support multiple types of regularization: * - none (a.k.a. ordinary least squares) @@ -61,11 +62,13 @@ private[regression] trait LinearRegressionParams extends PredictorParams * - L1 (Lasso) * - L2 + L1 (elastic net) */ +@Since("1.3.0") @Experimental -class LinearRegression(override val uid: String) +class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String) extends Regressor[Vector, LinearRegression, LinearRegressionModel] - with LinearRegressionParams with Logging { + with LinearRegressionParams with DefaultParamsWritable with Logging { + @Since("1.4.0") def this() = this(Identifiable.randomUID("linReg")) /** @@ -73,6 +76,7 @@ class LinearRegression(override val uid: String) * Default is 0.0. * @group setParam */ + @Since("1.3.0") def setRegParam(value: Double): this.type = set(regParam, value) setDefault(regParam -> 0.0) @@ -81,9 +85,23 @@ class LinearRegression(override val uid: String) * Default is true. * @group setParam */ + @Since("1.5.0") def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) setDefault(fitIntercept -> true) + /** + * Whether to standardize the training features before fitting the model. + * The coefficients of models will be always returned on the original scale, + * so it will be transparent for users. Note that with/without standardization, + * the models should be always converged to the same solution when no regularization + * is applied. In R's GLMNET package, the default behavior is true as well. + * Default is true. + * @group setParam + */ + @Since("1.5.0") + def setStandardization(value: Boolean): this.type = set(standardization, value) + setDefault(standardization -> true) + /** * Set the ElasticNet mixing parameter. * For alpha = 0, the penalty is an L2 penalty. 
For alpha = 1, it is an L1 penalty. @@ -91,6 +109,7 @@ class LinearRegression(override val uid: String) * Default is 0.0 which is an L2 penalty. * @group setParam */ + @Since("1.4.0") def setElasticNetParam(value: Double): this.type = set(elasticNetParam, value) setDefault(elasticNetParam -> 0.0) @@ -99,6 +118,7 @@ class LinearRegression(override val uid: String) * Default is 100. * @group setParam */ + @Since("1.3.0") def setMaxIter(value: Int): this.type = set(maxIter, value) setDefault(maxIter -> 100) @@ -108,55 +128,126 @@ class LinearRegression(override val uid: String) * Default is 1E-6. * @group setParam */ + @Since("1.4.0") def setTol(value: Double): this.type = set(tol, value) setDefault(tol -> 1E-6) + /** + * Whether to over-/under-sample training instances according to the given weights in weightCol. + * If empty, all instances are treated equally (weight 1.0). + * Default is empty, so all instances have weight one. + * @group setParam + */ + @Since("1.6.0") + def setWeightCol(value: String): this.type = set(weightCol, value) + setDefault(weightCol -> "") + + /** + * Set the solver algorithm used for optimization. + * In case of linear regression, this can be "l-bfgs", "normal" and "auto". + * "l-bfgs" denotes Limited-memory BFGS which is a limited-memory quasi-Newton + * optimization method. "normal" denotes using Normal Equation as an analytical + * solution to the linear regression problem. + * The default value is "auto" which means that the solver algorithm is + * selected automatically. + * @group setParam + */ + @Since("1.6.0") + def setSolver(value: String): this.type = set(solver, value) + setDefault(solver -> "auto") + override protected def train(dataset: DataFrame): LinearRegressionModel = { - // Extract columns from data. If dataset is persisted, do not persist instances. - val instances = extractLabeledPoints(dataset).map { - case LabeledPoint(label: Double, features: Vector) => (label, features) + // Extract the number of features before deciding optimization solver. + val numFeatures = dataset.select(col($(featuresCol))).limit(1).map { + case Row(features: Vector) => features.size + }.first() + val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) + + if (($(solver) == "auto" && $(elasticNetParam) == 0.0 && numFeatures <= 4096) || + $(solver) == "normal") { + require($(elasticNetParam) == 0.0, "Only L2 regularization can be used when normal " + + "solver is used.'") + // For low dimensional data, WeightedLeastSquares is more efficiently since the + // training algorithm only requires one pass through the data. (SPARK-10668) + val instances: RDD[Instance] = dataset.select( + col($(labelCol)), w, col($(featuresCol))).map { + case Row(label: Double, weight: Double, features: Vector) => + Instance(label, weight, features) + } + + val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), + $(standardization), true) + val model = optimizer.fit(instances) + // When it is trained by WeightedLeastSquares, training summary does not + // attached returned model. + val lrModel = copyValues(new LinearRegressionModel(uid, model.coefficients, model.intercept)) + // WeightedLeastSquares does not run through iterations. So it does not generate + // an objective history. 
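+      // Usage note (restating the branch condition above): this analytical path is taken when
+      // the caller sets .setSolver("normal"), or under "auto" when elasticNetParam == 0.0 and
+      // there are at most 4096 features; only L2 regularization is supported here, and the
+      // summary below therefore receives Array(0D) as its objective history.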
+ val (summaryModel, predictionColName) = lrModel.findSummaryModelAndPredictionCol() + val trainingSummary = new LinearRegressionTrainingSummary( + summaryModel.transform(dataset), + predictionColName, + $(labelCol), + summaryModel, + model.diagInvAtWA.toArray, + $(featuresCol), + Array(0D)) + + return lrModel.setSummary(trainingSummary) + } + + val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).map { + case Row(label: Double, weight: Double, features: Vector) => + Instance(label, weight, features) } + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) - val (summarizer, statCounter) = instances.treeAggregate( - (new MultivariateOnlineSummarizer, new StatCounter))( - seqOp = (c, v) => (c, v) match { - case ((summarizer: MultivariateOnlineSummarizer, statCounter: StatCounter), - (label: Double, features: Vector)) => - (summarizer.add(features), statCounter.merge(label)) - }, - combOp = (c1, c2) => (c1, c2) match { - case ((summarizer1: MultivariateOnlineSummarizer, statCounter1: StatCounter), - (summarizer2: MultivariateOnlineSummarizer, statCounter2: StatCounter)) => - (summarizer1.merge(summarizer2), statCounter1.merge(statCounter2)) - }) - - val numFeatures = summarizer.mean.size - val yMean = statCounter.mean - val yStd = math.sqrt(statCounter.variance) - - // If the yStd is zero, then the intercept is yMean with zero weights; + val (featuresSummarizer, ySummarizer) = { + val seqOp = (c: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer), + instance: Instance) => + (c._1.add(instance.features, instance.weight), + c._2.add(Vectors.dense(instance.label), instance.weight)) + + val combOp = (c1: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer), + c2: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer)) => + (c1._1.merge(c2._1), c1._2.merge(c2._2)) + + instances.treeAggregate( + new MultivariateOnlineSummarizer, new MultivariateOnlineSummarizer)(seqOp, combOp) + } + + val yMean = ySummarizer.mean(0) + val yStd = math.sqrt(ySummarizer.variance(0)) + + // If the yStd is zero, then the intercept is yMean with zero coefficient; // as a result, training is not needed. 
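+    // For example, if every label equals 2.5, the returned model has intercept = 2.5 and
+    // all-zero coefficients.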
if (yStd == 0.0) { - logWarning(s"The standard deviation of the label is zero, so the weights will be zeros " + - s"and the intercept will be the mean of the label; as a result, training is not needed.") + logWarning(s"The standard deviation of the label is zero, so the coefficients will be " + + s"zeros and the intercept will be the mean of the label; as a result, " + + s"training is not needed.") if (handlePersistence) instances.unpersist() - val weights = Vectors.sparse(numFeatures, Seq()) + val coefficients = Vectors.sparse(numFeatures, Seq()) val intercept = yMean - val model = new LinearRegressionModel(uid, weights, intercept) + val model = new LinearRegressionModel(uid, coefficients, intercept) + // Handle possible missing or invalid prediction columns + val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() + val trainingSummary = new LinearRegressionTrainingSummary( - model.transform(dataset), - $(predictionCol), + summaryModel.transform(dataset), + predictionColName, $(labelCol), + model, + Array(0D), $(featuresCol), Array(0D)) return copyValues(model.setSummary(trainingSummary)) } - val featuresMean = summarizer.mean.toArray - val featuresStd = summarizer.variance.toArray.map(math.sqrt) + val featuresMean = featuresSummarizer.mean.toArray + val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt) // Since we implicitly do the feature scaling when we compute the cost function // to improve the convergence, the effective regParam will be changed. @@ -165,19 +256,31 @@ class LinearRegression(override val uid: String) val effectiveL2RegParam = (1.0 - $(elasticNetParam)) * effectiveRegParam val costFun = new LeastSquaresCostFun(instances, yStd, yMean, $(fitIntercept), - featuresStd, featuresMean, effectiveL2RegParam) + $(standardization), featuresStd, featuresMean, effectiveL2RegParam) val optimizer = if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) { new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) } else { - new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, effectiveL1RegParam, $(tol)) + def effectiveL1RegFun = (index: Int) => { + if ($(standardization)) { + effectiveL1RegParam + } else { + // If `standardization` is false, we still standardize the data + // to improve the rate of convergence; as a result, we have to + // perform this reverse standardization by penalizing each component + // differently to get effectively the same objective function when + // the training dataset is not standardized. + if (featuresStd(index) != 0.0) effectiveL1RegParam / featuresStd(index) else 0.0 + } + } + new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, effectiveL1RegFun, $(tol)) } - val initialWeights = Vectors.zeros(numFeatures) + val initialCoefficients = Vectors.zeros(numFeatures) val states = optimizer.iterations(new CachedDiffFunction(costFun), - initialWeights.toBreeze.toDenseVector) + initialCoefficients.toBreeze.toDenseVector) - val (weights, objectiveHistory) = { + val (coefficients, objectiveHistory) = { /* Note that in Linear Regression, the objective history (loss + regularization) returned from optimizer is computed in the scaled space given by the following formula. @@ -198,18 +301,18 @@ class LinearRegression(override val uid: String) } /* - The weights are trained in the scaled space; we're converting them back to + The coefficients are trained in the scaled space; we're converting them back to the original space. 
*/ - val rawWeights = state.x.toArray.clone() + val rawCoefficients = state.x.toArray.clone() var i = 0 - val len = rawWeights.length + val len = rawCoefficients.length while (i < len) { - rawWeights(i) *= { if (featuresStd(i) != 0.0) yStd / featuresStd(i) else 0.0 } + rawCoefficients(i) *= { if (featuresStd(i) != 0.0) yStd / featuresStd(i) else 0.0 } i += 1 } - (Vectors.dense(rawWeights).compressed, arrayBuilder.result()) + (Vectors.dense(rawCoefficients).compressed, arrayBuilder.result()) } /* @@ -217,41 +320,65 @@ class LinearRegression(override val uid: String) converged. See the following discussion for detail. http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet */ - val intercept = if ($(fitIntercept)) yMean - dot(weights, Vectors.dense(featuresMean)) else 0.0 + val intercept = if ($(fitIntercept)) { + yMean - dot(coefficients, Vectors.dense(featuresMean)) + } else { + 0.0 + } if (handlePersistence) instances.unpersist() - val model = copyValues(new LinearRegressionModel(uid, weights, intercept)) + val model = copyValues(new LinearRegressionModel(uid, coefficients, intercept)) + // Handle possible missing or invalid prediction columns + val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() + val trainingSummary = new LinearRegressionTrainingSummary( - model.transform(dataset), - $(predictionCol), + summaryModel.transform(dataset), + predictionColName, $(labelCol), + model, + Array(0D), $(featuresCol), objectiveHistory) model.setSummary(trainingSummary) } + @Since("1.4.0") override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra) } +@Since("1.6.0") +object LinearRegression extends DefaultParamsReadable[LinearRegression] { + + @Since("1.6.0") + override def load(path: String): LinearRegression = super.load(path) +} + /** * :: Experimental :: * Model produced by [[LinearRegression]]. */ +@Since("1.3.0") @Experimental class LinearRegressionModel private[ml] ( override val uid: String, - val weights: Vector, + val coefficients: Vector, val intercept: Double) extends RegressionModel[Vector, LinearRegressionModel] - with LinearRegressionParams { + with LinearRegressionParams with MLWritable { private var trainingSummary: Option[LinearRegressionTrainingSummary] = None + @deprecated("Use coefficients instead.", "1.6.0") + def weights: Vector = coefficients + + override val numFeatures: Int = coefficients.size + /** * Gets summary (e.g. residuals, mse, r-squared ) of model on training set. An exception is * thrown if `trainingSummary == None`. */ + @Since("1.5.0") def summary: LinearRegressionTrainingSummary = trainingSummary match { case Some(summ) => summ case None => @@ -266,6 +393,7 @@ class LinearRegressionModel private[ml] ( } /** Indicates whether a training summary exists for this model instance. 
*/ + @Since("1.5.0") def hasSummary: Boolean = trainingSummary.isDefined /** @@ -274,40 +402,117 @@ class LinearRegressionModel private[ml] ( */ // TODO: decide on a good name before exposing to public API private[regression] def evaluate(dataset: DataFrame): LinearRegressionSummary = { - val t = udf { features: Vector => predict(features) } - val predictionAndObservations = dataset - .select(col($(labelCol)), t(col($(featuresCol))).as($(predictionCol))) + // Handle possible missing or invalid prediction columns + val (summaryModel, predictionColName) = findSummaryModelAndPredictionCol() + new LinearRegressionSummary(summaryModel.transform(dataset), predictionColName, + $(labelCol), this, Array(0D)) + } - new LinearRegressionSummary(predictionAndObservations, $(predictionCol), $(labelCol)) + /** + * If the prediction column is set returns the current model and prediction column, + * otherwise generates a new column and sets it as the prediction column on a new copy + * of the current model. + */ + private[regression] def findSummaryModelAndPredictionCol(): (LinearRegressionModel, String) = { + $(predictionCol) match { + case "" => + val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString() + (copy(ParamMap.empty).setPredictionCol(predictionColName), predictionColName) + case p => (this, p) + } } + override protected def predict(features: Vector): Double = { - dot(features, weights) + intercept + dot(features, coefficients) + intercept } + @Since("1.4.0") override def copy(extra: ParamMap): LinearRegressionModel = { - val newModel = copyValues(new LinearRegressionModel(uid, weights, intercept)) + val newModel = copyValues(new LinearRegressionModel(uid, coefficients, intercept), extra) if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) - newModel + newModel.setParent(parent) + } + + /** + * Returns a [[MLWriter]] instance for this ML instance. + * + * For [[LinearRegressionModel]], this does NOT currently save the training [[summary]]. + * An option to save [[summary]] may be added in the future. + * + * This also does not save the [[parent]] currently. 
+ */ + @Since("1.6.0") + override def write: MLWriter = new LinearRegressionModel.LinearRegressionModelWriter(this) +} + +@Since("1.6.0") +object LinearRegressionModel extends MLReadable[LinearRegressionModel] { + + @Since("1.6.0") + override def read: MLReader[LinearRegressionModel] = new LinearRegressionModelReader + + @Since("1.6.0") + override def load(path: String): LinearRegressionModel = super.load(path) + + /** [[MLWriter]] instance for [[LinearRegressionModel]] */ + private[LinearRegressionModel] class LinearRegressionModelWriter(instance: LinearRegressionModel) + extends MLWriter with Logging { + + private case class Data(intercept: Double, coefficients: Vector) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: intercept, coefficients + val data = Data(instance.intercept, instance.coefficients) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class LinearRegressionModelReader extends MLReader[LinearRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[LinearRegressionModel].getName + + override def load(path: String): LinearRegressionModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.format("parquet").load(dataPath) + .select("intercept", "coefficients").head() + val intercept = data.getDouble(0) + val coefficients = data.getAs[Vector](1) + val model = new LinearRegressionModel(metadata.uid, coefficients, intercept) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } } } /** * :: Experimental :: - * Linear regression training results. + * Linear regression training results. Currently, the training summary ignores the + * training coefficients except for the objective trace. * @param predictions predictions outputted by the model's `transform` method. * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. */ +@Since("1.5.0") @Experimental class LinearRegressionTrainingSummary private[regression] ( predictions: DataFrame, predictionCol: String, labelCol: String, + model: LinearRegressionModel, + diagInvAtWA: Array[Double], val featuresCol: String, val objectiveHistory: Array[Double]) - extends LinearRegressionSummary(predictions, predictionCol, labelCol) { + extends LinearRegressionSummary(predictions, predictionCol, labelCol, model, diagInvAtWA) { /** Number of training iterations until termination */ + @Since("1.5.0") val totalIterations = objectiveHistory.length } @@ -317,11 +522,14 @@ class LinearRegressionTrainingSummary private[regression] ( * Linear regression results evaluated on a dataset. * @param predictions predictions outputted by the model's `transform` method. */ +@Since("1.5.0") @Experimental class LinearRegressionSummary private[regression] ( @transient val predictions: DataFrame, val predictionCol: String, - val labelCol: String) extends Serializable { + val labelCol: String, + val model: LinearRegressionModel, + private val diagInvAtWA: Array[Double]) extends Serializable { @transient private val metrics = new RegressionMetrics( predictions @@ -332,39 +540,133 @@ class LinearRegressionSummary private[regression] ( * Returns the explained variance regression score. 
* explainedVariance = 1 - variance(y - \hat{y}) / variance(y) * Reference: [[http://en.wikipedia.org/wiki/Explained_variation]] + * + * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ + @Since("1.5.0") val explainedVariance: Double = metrics.explainedVariance /** * Returns the mean absolute error, which is a risk function corresponding to the * expected value of the absolute error loss or l1-norm loss. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ + @Since("1.5.0") val meanAbsoluteError: Double = metrics.meanAbsoluteError /** * Returns the mean squared error, which is a risk function corresponding to the * expected value of the squared error loss or quadratic loss. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ + @Since("1.5.0") val meanSquaredError: Double = metrics.meanSquaredError /** * Returns the root mean squared error, which is defined as the square root of * the mean squared error. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ + @Since("1.5.0") val rootMeanSquaredError: Double = metrics.rootMeanSquaredError /** * Returns R^2^, the coefficient of determination. * Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]] + * + * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ + @Since("1.5.0") val r2: Double = metrics.r2 /** Residuals (label - predicted value) */ + @Since("1.5.0") @transient lazy val residuals: DataFrame = { val t = udf { (pred: Double, label: Double) => label - pred } predictions.select(t(col(predictionCol), col(labelCol)).as("residuals")) } + /** Number of instances in DataFrame predictions */ + lazy val numInstances: Long = predictions.count() + + /** Degrees of freedom */ + private val degreesOfFreedom: Long = if (model.getFitIntercept) { + numInstances - model.coefficients.size - 1 + } else { + numInstances - model.coefficients.size + } + + /** + * The weighted residuals, the usual residuals rescaled by + * the square root of the instance weights. + */ + lazy val devianceResiduals: Array[Double] = { + val weighted = if (model.getWeightCol.isEmpty) lit(1.0) else sqrt(col(model.getWeightCol)) + val dr = predictions.select(col(model.getLabelCol).minus(col(model.getPredictionCol)) + .multiply(weighted).as("weightedResiduals")) + .select(min(col("weightedResiduals")).as("min"), max(col("weightedResiduals")).as("max")) + .first() + Array(dr.getDouble(0), dr.getDouble(1)) + } + + /** + * Standard error of estimated coefficients and intercept. + */ + lazy val coefficientStandardErrors: Array[Double] = { + if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) { + throw new UnsupportedOperationException( + "No Std. 
Error of coefficients available for this LinearRegressionModel") + } else { + val rss = if (model.getWeightCol.isEmpty) { + meanSquaredError * numInstances + } else { + val t = udf { (pred: Double, label: Double, weight: Double) => + math.pow(label - pred, 2.0) * weight } + predictions.select(t(col(model.getPredictionCol), col(model.getLabelCol), + col(model.getWeightCol)).as("wse")).agg(sum(col("wse"))).first().getDouble(0) + } + val sigma2 = rss / degreesOfFreedom + diagInvAtWA.map(_ * sigma2).map(math.sqrt(_)) + } + } + + /** + * T-statistic of estimated coefficients and intercept. + */ + lazy val tValues: Array[Double] = { + if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) { + throw new UnsupportedOperationException( + "No t-statistic available for this LinearRegressionModel") + } else { + val estimate = if (model.getFitIntercept) { + Array.concat(model.coefficients.toArray, Array(model.intercept)) + } else { + model.coefficients.toArray + } + estimate.zip(coefficientStandardErrors).map { x => x._1 / x._2 } + } + } + + /** + * Two-sided p-value of estimated coefficients and intercept. + */ + lazy val pValues: Array[Double] = { + if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) { + throw new UnsupportedOperationException( + "No p-value available for this LinearRegressionModel") + } else { + tValues.map { x => 2.0 * (1.0 - StudentsT(degreesOfFreedom.toDouble).cdf(math.abs(x))) } + } + } + } /** @@ -377,7 +679,7 @@ class LinearRegressionSummary private[regression] ( * For improving the convergence rate during the optimization process, and also preventing against * features with very large variances exerting an overly large influence during model training, * package like R's GLMNET performs the scaling to unit variance and removing the mean to reduce - * the condition number, and then trains the model in scaled space but returns the weights in + * the condition number, and then trains the model in scaled space but returns the coefficients in * the original scale. See page 9 in http://cran.r-project.org/web/packages/glmnet/glmnet.pdf * * However, we don't want to apply the `StandardScaler` on the training dataset, and then cache @@ -408,7 +710,7 @@ class LinearRegressionSummary private[regression] ( * + \bar{y} / \hat{y}||^2 * = 1/2n ||\sum_i w_i^\prime x_i - y / \hat{y} + offset||^2 = 1/2n diff^2 * }}} - * where w_i^\prime^ is the effective weights defined by w_i/\hat{x_i}, offset is + * where w_i^\prime^ is the effective coefficients defined by w_i/\hat{x_i}, offset is * {{{ * - \sum_i (w_i/\hat{x_i})\bar{x_i} + \bar{y} / \hat{y}. * }}}, and diff is @@ -417,7 +719,7 @@ class LinearRegressionSummary private[regression] ( * }}} * * - * Note that the effective weights and offset don't depend on training dataset, + * Note that the effective coefficients and offset don't depend on training dataset, * so they can be precomputed. * * Now, the first derivative of the objective function in scaled space is @@ -453,14 +755,15 @@ class LinearRegressionSummary private[regression] ( * \frac{\partial L}{\partial\w_i} = 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) * }}}, * - * @param weights The weights/coefficients corresponding to the features. + * @param coefficients The coefficients corresponding to the features. * @param labelStd The standard deviation value of the label. * @param labelMean The mean value of the label. + * @param fitIntercept Whether to fit an intercept term. * @param featuresStd The standard deviation values of the features. * @param featuresMean The mean values of the features. 
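// --- Standalone sketch of the coefficientStandardErrors / tValues / pValues formulas
// defined above, on plain arrays. diagInvAtWA, the estimates, rss and the degrees of
// freedom are hypothetical example values, not taken from the patch.
import breeze.stats.distributions.StudentsT

val diagInvAtWA = Array(0.02, 0.15)   // diag((A^T W A)^-1) for coefficient and intercept
val estimates   = Array(1.8, -0.4)    // fitted coefficient followed by the intercept
val rss = 12.5                        // (weighted) residual sum of squares
val dof = 47.0                        // numInstances - numCoefficients - 1 with an intercept

val sigma2  = rss / dof
val stdErrs = diagInvAtWA.map(v => math.sqrt(v * sigma2))
val tVals   = estimates.zip(stdErrs).map { case (est, se) => est / se }
val pVals   = tVals.map(t => 2.0 * (1.0 - StudentsT(dof).cdf(math.abs(t))))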
*/ private class LeastSquaresAggregator( - weights: Vector, + coefficients: Vector, labelStd: Double, labelMean: Double, fitIntercept: Boolean, @@ -468,56 +771,62 @@ private class LeastSquaresAggregator( featuresMean: Array[Double]) extends Serializable { private var totalCnt: Long = 0L + private var weightSum: Double = 0.0 private var lossSum = 0.0 - private val (effectiveWeightsArray: Array[Double], offset: Double, dim: Int) = { - val weightsArray = weights.toArray.clone() + private val (effectiveCoefficientsArray: Array[Double], offset: Double, dim: Int) = { + val coefficientsArray = coefficients.toArray.clone() var sum = 0.0 var i = 0 - val len = weightsArray.length + val len = coefficientsArray.length while (i < len) { if (featuresStd(i) != 0.0) { - weightsArray(i) /= featuresStd(i) - sum += weightsArray(i) * featuresMean(i) + coefficientsArray(i) /= featuresStd(i) + sum += coefficientsArray(i) * featuresMean(i) } else { - weightsArray(i) = 0.0 + coefficientsArray(i) = 0.0 } i += 1 } - (weightsArray, if (fitIntercept) labelMean / labelStd - sum else 0.0, weightsArray.length) + val offset = if (fitIntercept) labelMean / labelStd - sum else 0.0 + (coefficientsArray, offset, coefficientsArray.length) } - private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray) + private val effectiveCoefficientsVector = Vectors.dense(effectiveCoefficientsArray) private val gradientSumArray = Array.ofDim[Double](dim) /** - * Add a new training data to this LeastSquaresAggregator, and update the loss and gradient + * Add a new training instance to this LeastSquaresAggregator, and update the loss and gradient * of the objective function. * - * @param label The label for this data point. - * @param data The features for one data point in dense/sparse vector format to be added - * into this aggregator. + * @param instance The instance of data point to be added. * @return This LeastSquaresAggregator object. */ - def add(label: Double, data: Vector): this.type = { - require(dim == data.size, s"Dimensions mismatch when adding new sample." + - s" Expecting $dim but got ${data.size}.") - - val diff = dot(data, effectiveWeightsVector) - label / labelStd + offset - - if (diff != 0) { - val localGradientSumArray = gradientSumArray - data.foreachActive { (index, value) => - if (featuresStd(index) != 0.0 && value != 0.0) { - localGradientSumArray(index) += diff * value / featuresStd(index) + def add(instance: Instance): this.type = { + instance match { case Instance(label, weight, features) => + require(dim == features.size, s"Dimensions mismatch when adding new sample." + + s" Expecting $dim but got ${features.size}.") + require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0") + + if (weight == 0.0) return this + + val diff = dot(features, effectiveCoefficientsVector) - label / labelStd + offset + + if (diff != 0) { + val localGradientSumArray = gradientSumArray + features.foreachActive { (index, value) => + if (featuresStd(index) != 0.0 && value != 0.0) { + localGradientSumArray(index) += weight * diff * value / featuresStd(index) + } } + lossSum += weight * diff * diff / 2.0 } - lossSum += diff * diff / 2.0 - } - totalCnt += 1 - this + totalCnt += 1 + weightSum += weight + this + } } /** @@ -532,8 +841,9 @@ private class LeastSquaresAggregator( require(dim == other.dim, s"Dimensions mismatch when merging with another " + s"LeastSquaresAggregator. 
Expecting $dim but got ${other.dim}.") - if (other.totalCnt != 0) { + if (other.weightSum != 0) { totalCnt += other.totalCnt + weightSum += other.weightSum lossSum += other.lossSum var i = 0 @@ -549,49 +859,80 @@ private class LeastSquaresAggregator( def count: Long = totalCnt - def loss: Double = lossSum / totalCnt + def loss: Double = { + require(weightSum > 0.0, s"The effective number of instances should be " + + s"greater than 0.0, but $weightSum.") + lossSum / weightSum + } def gradient: Vector = { + require(weightSum > 0.0, s"The effective number of instances should be " + + s"greater than 0.0, but $weightSum.") val result = Vectors.dense(gradientSumArray.clone()) - scal(1.0 / totalCnt, result) + scal(1.0 / weightSum, result) result } } /** * LeastSquaresCostFun implements Breeze's DiffFunction[T] for Least Squares cost. - * It returns the loss and gradient with L2 regularization at a particular point (weights). + * It returns the loss and gradient with L2 regularization at a particular point (coefficients). * It's used in Breeze's convex optimization routines. */ private class LeastSquaresCostFun( - data: RDD[(Double, Vector)], + instances: RDD[Instance], labelStd: Double, labelMean: Double, fitIntercept: Boolean, + standardization: Boolean, featuresStd: Array[Double], featuresMean: Array[Double], effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] { - override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = { - val w = Vectors.fromBreeze(weights) + override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = { + val coeffs = Vectors.fromBreeze(coefficients) - val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(w, labelStd, - labelMean, fitIntercept, featuresStd, featuresMean))( - seqOp = (c, v) => (c, v) match { - case (aggregator, (label, features)) => aggregator.add(label, features) - }, - combOp = (c1, c2) => (c1, c2) match { - case (aggregator1, aggregator2) => aggregator1.merge(aggregator2) - }) + val leastSquaresAggregator = { + val seqOp = (c: LeastSquaresAggregator, instance: Instance) => c.add(instance) + val combOp = (c1: LeastSquaresAggregator, c2: LeastSquaresAggregator) => c1.merge(c2) - // regVal is the sum of weight squares for L2 regularization - val norm = brzNorm(weights, 2.0) - val regVal = 0.5 * effectiveL2regParam * norm * norm + instances.treeAggregate( + new LeastSquaresAggregator(coeffs, labelStd, labelMean, fitIntercept, featuresStd, + featuresMean))(seqOp, combOp) + } + + val totalGradientArray = leastSquaresAggregator.gradient.toArray - val loss = leastSquaresAggregator.loss + regVal - val gradient = leastSquaresAggregator.gradient - axpy(effectiveL2regParam, w, gradient) + val regVal = if (effectiveL2regParam == 0.0) { + 0.0 + } else { + var sum = 0.0 + coeffs.foreachActive { (index, value) => + // The following code will compute the loss of the regularization; also + // the gradient of the regularization, and add back to totalGradientArray. + sum += { + if (standardization) { + totalGradientArray(index) += effectiveL2regParam * value + value * value + } else { + if (featuresStd(index) != 0.0) { + // If `standardization` is false, we still standardize the data + // to improve the rate of convergence; as a result, we have to + // perform this reverse standardization by penalizing each component + // differently to get effectively the same objective function when + // the training dataset is not standardized. 
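// --- Small sketch of the L2 regularization term computed in LeastSquaresCostFun above:
// with standardization the penalty is the plain squared norm; without it, each component
// is divided by its feature variance so the scaled-space objective matches the
// unstandardized problem. coeffs, featuresStd and lambda are hypothetical values.
val coeffs      = Array(0.5, -1.2, 0.0)
val featuresStd = Array(2.0, 0.5, 1.0)
val lambda          = 0.1
val standardization = false

val totalGradient = Array.ofDim[Double](coeffs.length)
var sum = 0.0
coeffs.zipWithIndex.foreach { case (value, i) =>
  sum += {
    if (standardization) {
      totalGradient(i) += lambda * value
      value * value
    } else if (featuresStd(i) != 0.0) {
      val temp = value / (featuresStd(i) * featuresStd(i))
      totalGradient(i) += lambda * temp
      value * temp
    } else {
      0.0
    }
  }
}
val regVal = 0.5 * lambda * sum   // added to the aggregator's loss in calculate()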
+ val temp = value / (featuresStd(index) * featuresStd(index)) + totalGradientArray(index) += effectiveL2regParam * temp + value * temp + } else { + 0.0 + } + } + } + } + 0.5 * effectiveL2regParam * sum + } - (loss, gradient.toBreeze.asInstanceOf[BDV[Double]]) + (leastSquaresAggregator.loss + regVal, new BDV(totalGradientArray)) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 17fb1ad5e15d..71e40b513ee0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.regression -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeEnsembleModel, TreeRegressorParams} @@ -30,51 +30,62 @@ import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestMo import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.DoubleType + /** * :: Experimental :: * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] learning algorithm for regression. * It supports both continuous and categorical features. */ +@Since("1.4.0") @Experimental -final class RandomForestRegressor(override val uid: String) +final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Predictor[Vector, RandomForestRegressor, RandomForestRegressionModel] with RandomForestParams with TreeRegressorParams { + @Since("1.4.0") def this() = this(Identifiable.randomUID("rfr")) // Override parameter setters from parent trait for Java API compatibility. 
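// --- Illustrative configuration of the RandomForestRegressor API annotated in this hunk;
// the DataFrame `training` with "features" and "label" columns is an assumption.
import org.apache.spark.ml.regression.RandomForestRegressor

val rf = new RandomForestRegressor()
  .setNumTrees(50)
  .setMaxDepth(6)
  .setFeatureSubsetStrategy("onethird")
  .setSubsamplingRate(0.8)
  .setSeed(42L)

val rfModel = rf.fit(training)
// numFeatures is now recorded at training time (taken from the first training instance),
// and toString reports the uid along with the number of trees.
println(s"${rfModel.toString}, numFeatures = ${rfModel.numFeatures}")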
// Parameters from TreeRegressorParams: - + @Since("1.4.0") override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + @Since("1.4.0") override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + @Since("1.4.0") override def setMinInstancesPerNode(value: Int): this.type = super.setMinInstancesPerNode(value) + @Since("1.4.0") override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + @Since("1.4.0") override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + @Since("1.4.0") override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + @Since("1.4.0") override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) + @Since("1.4.0") override def setImpurity(value: String): this.type = super.setImpurity(value) // Parameters from TreeEnsembleParams: - + @Since("1.4.0") override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value) + @Since("1.4.0") override def setSeed(value: Long): this.type = super.setSeed(value) // Parameters from RandomForestParams: - + @Since("1.4.0") override def setNumTrees(value: Int): this.type = super.setNumTrees(value) + @Since("1.4.0") override def setFeatureSubsetStrategy(value: String): this.type = super.setFeatureSubsetStrategy(value) @@ -87,18 +98,23 @@ final class RandomForestRegressor(override val uid: String) val trees = RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed) .map(_.asInstanceOf[DecisionTreeRegressionModel]) - new RandomForestRegressionModel(trees) + val numFeatures = oldDataset.first().features.size + new RandomForestRegressionModel(trees, numFeatures) } + @Since("1.4.0") override def copy(extra: ParamMap): RandomForestRegressor = defaultCopy(extra) } +@Since("1.4.0") @Experimental object RandomForestRegressor { /** Accessor for supported impurity settings: variance */ + @Since("1.4.0") final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities /** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */ + @Since("1.4.0") final val supportedFeatureSubsetStrategies: Array[String] = RandomForestParams.supportedFeatureSubsetStrategies } @@ -108,11 +124,14 @@ object RandomForestRegressor { * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for regression. * It supports both continuous and categorical features. * @param _trees Decision trees in the ensemble. + * @param numFeatures Number of features used by this model */ +@Since("1.4.0") @Experimental final class RandomForestRegressionModel private[ml] ( override val uid: String, - private val _trees: Array[DecisionTreeRegressionModel]) + private val _trees: Array[DecisionTreeRegressionModel], + override val numFeatures: Int) extends PredictionModel[Vector, RandomForestRegressionModel] with TreeEnsembleModel with Serializable { @@ -122,13 +141,16 @@ final class RandomForestRegressionModel private[ml] ( * Construct a random forest regression model, with all trees weighted equally. 
* @param trees Component trees */ - def this(trees: Array[DecisionTreeRegressionModel]) = this(Identifiable.randomUID("rfr"), trees) + private[ml] def this(trees: Array[DecisionTreeRegressionModel], numFeatures: Int) = + this(Identifiable.randomUID("rfr"), trees, numFeatures) + @Since("1.4.0") override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] // Note: We may add support for weights (based on tree performance) later on. private lazy val _treeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0) + @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights override protected def transformImpl(dataset: DataFrame): DataFrame = { @@ -146,14 +168,33 @@ final class RandomForestRegressionModel private[ml] ( _trees.map(_.rootNode.predictImpl(features).prediction).sum / numTrees } + @Since("1.4.0") override def copy(extra: ParamMap): RandomForestRegressionModel = { - copyValues(new RandomForestRegressionModel(uid, _trees), extra) + copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra).setParent(parent) } + @Since("1.4.0") override def toString: String = { - s"RandomForestRegressionModel with $numTrees trees" + s"RandomForestRegressionModel (uid=$uid) with $numTrees trees" } + /** + * Estimate of the importance of each feature. + * + * This generalizes the idea of "Gini" importance to other losses, + * following the explanation of Gini importance from "Random Forests" documentation + * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn. + * + * This feature importance is calculated as follows: + * - Average over trees: + * - importance(feature j) = sum (over nodes which split on feature j) of the gain, + * where gain is scaled by the number of instances passing through node + * - Normalize importances for tree based on total number of training instances used + * to build tree. + * - Normalize feature importance vector to sum to 1. + */ + lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures) + /** (private[ml]) Convert to a model in the old API */ private[ml] def toOld: OldRandomForestModel = { new OldRandomForestModel(OldAlgo.Regression, _trees.map(_.toOld)) @@ -166,13 +207,14 @@ private[ml] object RandomForestRegressionModel { def fromOld( oldModel: OldRandomForestModel, parent: RandomForestRegressor, - categoricalFeatures: Map[Int, Int]): RandomForestRegressionModel = { + categoricalFeatures: Map[Int, Int], + numFeatures: Int = -1): RandomForestRegressionModel = { require(oldModel.algo == OldAlgo.Regression, "Cannot convert RandomForestModel" + s" with algo=${oldModel.algo} (old API) to RandomForestRegressionModel (new API).") val newTrees = oldModel.trees.map { tree => // parent for each tree is null since there is no good way to set this. DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } - new RandomForestRegressionModel(parent.uid, newTrees) + new RandomForestRegressionModel(parent.uid, newTrees, numFeatures) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala new file mode 100644 index 000000000000..11b9815ecc83 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.source.libsvm + +import com.google.common.base.Objects + +import org.apache.spark.Logging +import org.apache.spark.annotation.Since +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrameReader, DataFrame, Row, SQLContext} +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types.{DoubleType, StructField, StructType} + +/** + * LibSVMRelation provides the DataFrame constructed from LibSVM format data. + * @param path File path of LibSVM format + * @param numFeatures The number of features + * @param vectorType The type of vector. It can be 'sparse' or 'dense' + * @param sqlContext The Spark SQLContext + */ +private[libsvm] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String) + (@transient val sqlContext: SQLContext) + extends BaseRelation with TableScan with Logging with Serializable { + + override def schema: StructType = StructType( + StructField("label", DoubleType, nullable = false) :: + StructField("features", new VectorUDT(), nullable = false) :: Nil + ) + + override def buildScan(): RDD[Row] = { + val sc = sqlContext.sparkContext + val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures) + val sparse = vectorType == "sparse" + baseRdd.map { pt => + val features = if (sparse) pt.features.toSparse else pt.features.toDense + Row(pt.label, features) + } + } + + override def hashCode(): Int = { + Objects.hashCode(path, Double.box(numFeatures), vectorType) + } + + override def equals(other: Any): Boolean = other match { + case that: LibSVMRelation => + path == that.path && + numFeatures == that.numFeatures && + vectorType == that.vectorType + case _ => + false + } +} + +/** + * `libsvm` package implements Spark SQL data source API for loading LIBSVM data as [[DataFrame]]. + * The loaded [[DataFrame]] has two columns: `label` containing labels stored as doubles and + * `features` containing feature vectors stored as [[Vector]]s. + * + * To use LIBSVM data source, you need to set "libsvm" as the format in [[DataFrameReader]] and + * optionally specify options, for example: + * {{{ + * // Scala + * val df = sqlContext.read.format("libsvm") + * .option("numFeatures", "780") + * .load("data/mllib/sample_libsvm_data.txt") + * + * // Java + * DataFrame df = sqlContext.read().format("libsvm") + * .option("numFeatures, "780") + * .load("data/mllib/sample_libsvm_data.txt"); + * }}} + * + * LIBSVM data source supports the following options: + * - "numFeatures": number of features. + * If unspecified or nonpositive, the number of features will be determined automatically at the + * cost of one additional pass. 
+ * This is also useful when the dataset is already split into multiple files and you want to load + * them separately, because some features may not present in certain files, which leads to + * inconsistent feature dimensions. + * - "vectorType": feature vector type, "sparse" (default) or "dense". + * + * @see [[https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/ LIBSVM datasets]] + */ +@Since("1.6.0") +class DefaultSource extends RelationProvider with DataSourceRegister { + + @Since("1.6.0") + override def shortName(): String = "libsvm" + + @Since("1.6.0") + override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]) + : BaseRelation = { + val path = parameters.getOrElse("path", + throw new IllegalArgumentException("'path' must be specified")) + val numFeatures = parameters.getOrElse("numFeatures", "-1").toInt + val vectorType = parameters.getOrElse("vectorType", "sparse") + new LibSVMRelation(path, numFeatures, vectorType)(sqlContext) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index 8879352a600a..d89682611e3f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -44,7 +44,7 @@ sealed abstract class Node extends Serializable { * and probabilities. * For classification, the array of class counts must be normalized to a probability distribution. */ - private[tree] def impurityStats: ImpurityCalculator + private[ml] def impurityStats: ImpurityCalculator /** Recursive prediction helper method */ private[ml] def predictImpl(features: Vector): LeafNode @@ -72,6 +72,12 @@ sealed abstract class Node extends Serializable { * @param id Node ID using old format IDs */ private[ml] def toOld(id: Int): OldNode + + /** + * Trace down the tree, and return the largest feature index used in any split. + * @return Max feature index used in a split, or -1 if there are no splits (single leaf node). 
+ */ + private[ml] def maxSplitFeatureIndex(): Int } private[ml] object Node { @@ -109,7 +115,7 @@ private[ml] object Node { final class LeafNode private[ml] ( override val prediction: Double, override val impurity: Double, - override val impurityStats: ImpurityCalculator) extends Node { + override private[ml] val impurityStats: ImpurityCalculator) extends Node { override def toString: String = s"LeafNode(prediction = $prediction, impurity = $impurity)" @@ -129,6 +135,8 @@ final class LeafNode private[ml] ( new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)), impurity, isLeaf = true, None, None, None, None) } + + override private[ml] def maxSplitFeatureIndex(): Int = -1 } /** @@ -150,7 +158,7 @@ final class InternalNode private[ml] ( val leftChild: Node, val rightChild: Node, val split: Split, - override val impurityStats: ImpurityCalculator) extends Node { + override private[ml] val impurityStats: ImpurityCalculator) extends Node { override def toString: String = { s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)" @@ -190,6 +198,11 @@ final class InternalNode private[ml] ( new OldPredict(leftChild.prediction, prob = 0.0), new OldPredict(rightChild.prediction, prob = 0.0)))) } + + override private[ml] def maxSplitFeatureIndex(): Int = { + math.max(split.featureIndex, + math.max(leftChild.maxSplitFeatureIndex(), rightChild.maxSplitFeatureIndex())) + } } private object InternalNode { @@ -266,6 +279,43 @@ private[tree] class LearningNode( } } + /** + * Get the node index corresponding to this data point. + * This function mimics prediction, passing an example from the root node down to a leaf + * or unsplit node; that node's index is returned. + * + * @param binnedFeatures Binned feature vector for data point. + * @param splits possible splits for all features, indexed (numFeatures)(numSplits) + * @return Leaf index if the data point reaches a leaf. + * Otherwise, last node reachable in tree matching this example. + * Note: This is the global node index, i.e., the index used in the tree. + * This index is different from the index used during training a particular + * group of nodes on one call to [[findBestSplits()]]. + */ + def predictImpl(binnedFeatures: Array[Int], splits: Array[Array[Split]]): Int = { + if (this.isLeaf || this.split.isEmpty) { + this.id + } else { + val split = this.split.get + val featureIndex = split.featureIndex + val splitLeft = split.shouldGoLeft(binnedFeatures(featureIndex), splits(featureIndex)) + if (this.leftChild.isEmpty) { + // Not yet split. Return next layer of nodes to train + if (splitLeft) { + LearningNode.leftChildIndex(this.id) + } else { + LearningNode.rightChildIndex(this.id) + } + } else { + if (splitLeft) { + this.leftChild.get.predictImpl(binnedFeatures, splits) + } else { + this.rightChild.get.predictImpl(binnedFeatures, splits) + } + } + } + } + } private[tree] object LearningNode { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala index 488e8e4fb5dc..1ee01131d633 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala @@ -122,7 +122,7 @@ private[spark] class NodeIdCache( rddUpdateCount += 1 // Handle checkpointing if the directory is not None. 
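// --- Standalone illustration of the maxSplitFeatureIndex() recursion added to Node above,
// written against a tiny stand-in tree ADT because the real LeafNode / InternalNode
// constructors are private[ml]. The example tree is made up.
sealed trait TinyNode
case class TinyLeaf(prediction: Double) extends TinyNode
case class TinyInternal(featureIndex: Int, left: TinyNode, right: TinyNode) extends TinyNode

def maxSplitFeatureIndex(node: TinyNode): Int = node match {
  case TinyLeaf(_) => -1                      // a leaf contributes no split
  case TinyInternal(feature, left, right) =>
    math.max(feature, math.max(maxSplitFeatureIndex(left), maxSplitFeatureIndex(right)))
}

val tree = TinyInternal(2, TinyLeaf(1.0), TinyInternal(5, TinyLeaf(0.0), TinyLeaf(2.0)))
assert(maxSplitFeatureIndex(tree) == 5)      // largest feature index used in any split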
- if (canCheckpoint && (rddUpdateCount % checkpointInterval) == 0) { + if (canCheckpoint && checkpointInterval != -1 && (rddUpdateCount % checkpointInterval) == 0) { // Let's see if we can delete previous checkpoints. var canDelete = true while (checkpointQueue.size > 1 && canDelete) { @@ -164,10 +164,10 @@ private[spark] class NodeIdCache( } } } - } - if (prevNodeIdsForInstances != null) { - // Unpersist the previous one if one exists. - prevNodeIdsForInstances.unpersist() + if (prevNodeIdsForInstances != null) { + // Unpersist the previous one if one exists. + prevNodeIdsForInstances.unpersist() + } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index a8b90d9d266a..4a3b12d1440b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -26,6 +26,7 @@ import org.apache.spark.Logging import org.apache.spark.ml.classification.DecisionTreeClassificationModel import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree._ +import org.apache.spark.mllib.linalg.{Vectors, Vector} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impl.{BaggedPoint, DTStatsAggregator, DecisionTreeMetadata, @@ -34,6 +35,7 @@ import org.apache.spark.mllib.tree.impurity.ImpurityCalculator import org.apache.spark.mllib.tree.model.ImpurityStats import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.collection.OpenHashMap import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom} @@ -72,7 +74,7 @@ private[ml] object RandomForest extends Logging { // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. timer.start("findSplitsBins") - val splits = findSplits(retaggedInput, metadata) + val splits = findSplits(retaggedInput, metadata, seed) timer.stop("findSplitsBins") logDebug("numBins: feature: number of bins") logDebug(Range(0, metadata.numFeatures).map { featureIndex => @@ -177,67 +179,32 @@ private[ml] object RandomForest extends Logging { } } + val numFeatures = metadata.numFeatures + parentUID match { case Some(uid) => if (strategy.algo == OldAlgo.Classification) { topNodes.map { rootNode => - new DecisionTreeClassificationModel(uid, rootNode.toNode, strategy.getNumClasses) + new DecisionTreeClassificationModel(uid, rootNode.toNode, numFeatures, + strategy.getNumClasses) } } else { - topNodes.map(rootNode => new DecisionTreeRegressionModel(uid, rootNode.toNode)) + topNodes.map { rootNode => + new DecisionTreeRegressionModel(uid, rootNode.toNode, numFeatures) + } } case None => if (strategy.algo == OldAlgo.Classification) { topNodes.map { rootNode => - new DecisionTreeClassificationModel(rootNode.toNode, strategy.getNumClasses) + new DecisionTreeClassificationModel(rootNode.toNode, numFeatures, + strategy.getNumClasses) } } else { - topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode)) + topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode, numFeatures)) } } } - /** - * Get the node index corresponding to this data point. - * This function mimics prediction, passing an example from the root node down to a leaf - * or unsplit node; that node's index is returned. 
- * - * @param node Node in tree from which to classify the given data point. - * @param binnedFeatures Binned feature vector for data point. - * @param splits possible splits for all features, indexed (numFeatures)(numSplits) - * @return Leaf index if the data point reaches a leaf. - * Otherwise, last node reachable in tree matching this example. - * Note: This is the global node index, i.e., the index used in the tree. - * This index is different from the index used during training a particular - * group of nodes on one call to [[findBestSplits()]]. - */ - private def predictNodeIndex( - node: LearningNode, - binnedFeatures: Array[Int], - splits: Array[Array[Split]]): Int = { - if (node.isLeaf || node.split.isEmpty) { - node.id - } else { - val split = node.split.get - val featureIndex = split.featureIndex - val splitLeft = split.shouldGoLeft(binnedFeatures(featureIndex), splits(featureIndex)) - if (node.leftChild.isEmpty) { - // Not yet split. Return index from next layer of nodes to train - if (splitLeft) { - LearningNode.leftChildIndex(node.id) - } else { - LearningNode.rightChildIndex(node.id) - } - } else { - if (splitLeft) { - predictNodeIndex(node.leftChild.get, binnedFeatures, splits) - } else { - predictNodeIndex(node.rightChild.get, binnedFeatures, splits) - } - } - } - } - /** * Helper for binSeqOp, for data which can contain a mix of ordered and unordered features. * @@ -445,8 +412,7 @@ private[ml] object RandomForest extends Logging { agg: Array[DTStatsAggregator], baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = { treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) => - val nodeIndex = - predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures, splits) + val nodeIndex = topNodes(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures, splits) nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint) } agg @@ -849,6 +815,7 @@ private[ml] object RandomForest extends Logging { * * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] * @param metadata Learning and dataset metadata + * @param seed random seed * @return A tuple of (splits, bins). * Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]] * of size (numFeatures, numSplits). @@ -857,7 +824,8 @@ private[ml] object RandomForest extends Logging { */ protected[tree] def findSplits( input: RDD[LabeledPoint], - metadata: DecisionTreeMetadata): Array[Array[Split]] = { + metadata: DecisionTreeMetadata, + seed : Long): Array[Array[Split]] = { logDebug("isMulticlass = " + metadata.isMulticlass) @@ -874,7 +842,7 @@ private[ml] object RandomForest extends Logging { 1.0 } logDebug("fraction of data used for calculating quantiles = " + fraction) - input.sample(withReplacement = false, fraction, new XORShiftRandom(1).nextInt()).collect() + input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect() } else { new Array[LabeledPoint](0) } @@ -1113,4 +1081,94 @@ private[ml] object RandomForest extends Logging { } } + /** + * Given a Random Forest model, compute the importance of each feature. + * This generalizes the idea of "Gini" importance to other losses, + * following the explanation of Gini importance from "Random Forests" documentation + * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn. 
+ * + * This feature importance is calculated as follows: + * - Average over trees: + * - importance(feature j) = sum (over nodes which split on feature j) of the gain, + * where gain is scaled by the number of instances passing through node + * - Normalize importances for tree based on total number of training instances used + * to build tree. + * - Normalize feature importance vector to sum to 1. + * + * Note: This should not be used with Gradient-Boosted Trees. It only makes sense for + * independently trained trees. + * @param trees Unweighted forest of trees + * @param numFeatures Number of features in model (even if not all are explicitly used by + * the model). + * If -1, then numFeatures is set based on the max feature index in all trees. + * @return Feature importance values, of length numFeatures. + */ + private[ml] def featureImportances(trees: Array[DecisionTreeModel], numFeatures: Int): Vector = { + val totalImportances = new OpenHashMap[Int, Double]() + trees.foreach { tree => + // Aggregate feature importance vector for this tree + val importances = new OpenHashMap[Int, Double]() + computeFeatureImportance(tree.rootNode, importances) + // Normalize importance vector for this tree, and add it to total. + // TODO: In the future, also support normalizing by tree.rootNode.impurityStats.count? + val treeNorm = importances.map(_._2).sum + if (treeNorm != 0) { + importances.foreach { case (idx, impt) => + val normImpt = impt / treeNorm + totalImportances.changeValue(idx, normImpt, _ + normImpt) + } + } + } + // Normalize importances + normalizeMapValues(totalImportances) + // Construct vector + val d = if (numFeatures != -1) { + numFeatures + } else { + // Find max feature index used in trees + val maxFeatureIndex = trees.map(_.maxSplitFeatureIndex()).max + maxFeatureIndex + 1 + } + if (d == 0) { + assert(totalImportances.size == 0, s"Unknown error in computing RandomForest feature" + + s" importance: No splits in forest, but some non-zero importances.") + } + val (indices, values) = totalImportances.iterator.toSeq.sortBy(_._1).unzip + Vectors.sparse(d, indices.toArray, values.toArray) + } + + /** + * Recursive method for computing feature importances for one tree. + * This walks down the tree, adding to the importance of 1 feature at each node. + * @param node Current node in recursion + * @param importances Aggregate feature importances, modified by this method + */ + private[impl] def computeFeatureImportance( + node: Node, + importances: OpenHashMap[Int, Double]): Unit = { + node match { + case n: InternalNode => + val feature = n.split.featureIndex + val scaledGain = n.gain * n.impurityStats.count + importances.changeValue(feature, scaledGain, _ + scaledGain) + computeFeatureImportance(n.leftChild, importances) + computeFeatureImportance(n.rightChild, importances) + case n: LeafNode => + // do nothing + } + } + + /** + * Normalize the values of this map to sum to 1, in place. + * If all values are 0, this method does nothing. + * @param map Map with non-negative values. 
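// --- Sketch of the per-tree normalization and final renormalization performed by
// RandomForest.featureImportances above, using a plain mutable Map in place of
// OpenHashMap. The per-tree scaled-gain maps are hypothetical example values.
import scala.collection.mutable

val perTreeGains = Seq(
  mutable.Map(0 -> 3.0, 2 -> 1.0),   // tree 1: sum of scaled gains per split feature
  mutable.Map(2 -> 2.0)              // tree 2
)

val total = mutable.Map.empty[Int, Double].withDefaultValue(0.0)
perTreeGains.foreach { gains =>
  val treeNorm = gains.values.sum
  if (treeNorm != 0.0) {
    gains.foreach { case (idx, gain) => total(idx) += gain / treeNorm }
  }
}
val grandTotal  = total.values.sum
val importances = total.map { case (idx, v) => idx -> v / grandTotal }.toMap
// importances == Map(0 -> 0.375, 2 -> 0.625); its values sum to 1.0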
+ */ + private[impl] def normalizeMapValues(map: OpenHashMap[Int, Double]): Unit = { + val total = map.map(_._2).sum + if (total != 0) { + val keys = map.iterator.map(_._1).toArray + keys.foreach { key => map.changeValue(key, 0.0, _ / total) } + } + } + } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index 22873909c33f..b77191156f68 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -53,6 +53,12 @@ private[ml] trait DecisionTreeModel { val header = toString + "\n" header + rootNode.subtreeToString(2) } + + /** + * Trace down the tree, and return the largest feature index used in any split. + * @return Max feature index used in a split, or -1 if there are no splits (single leaf node). + */ + private[ml] def maxSplitFeatureIndex(): Int = rootNode.maxSplitFeatureIndex() } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index a0c5238d966b..1da97db9277d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.tree import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed} +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance} import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} @@ -29,7 +29,8 @@ import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -private[ml] trait DecisionTreeParams extends PredictorParams { +private[ml] trait DecisionTreeParams extends PredictorParams + with HasCheckpointInterval with HasSeed { /** * Maximum depth of the tree (>= 0). @@ -86,7 +87,8 @@ private[ml] trait DecisionTreeParams extends PredictorParams { /** * If false, the algorithm will pass trees to executors to match instances with nodes. * If true, the algorithm will cache node IDs for each instance. - * Caching can speed up training of deeper trees. + * Caching can speed up training of deeper trees. Users can set how often should the + * cache be checkpointed or disable it by setting checkpointInterval. * (default = false) * @group expertParam */ @@ -95,21 +97,6 @@ private[ml] trait DecisionTreeParams extends PredictorParams { " algorithm will cache node IDs for each instance. Caching can speed up training of deeper" + " trees.") - /** - * Specifies how often to checkpoint the cached node IDs. - * E.g. 10 means that the cache will get checkpointed every 10 iterations. - * This is only used if cacheNodeIds is true and if the checkpoint directory is set in - * [[org.apache.spark.SparkContext]]. - * Must be >= 1. - * (default = 10) - * @group expertParam - */ - final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "Specifies" + - " how often to checkpoint the cached node IDs. E.g. 10 means that the cache will get" + - " checkpointed every 10 iterations. 
This is only used if cacheNodeIds is true and if the" + - " checkpoint directory is set in the SparkContext. Must be >= 1.", - ParamValidators.gtEq(1)) - setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0, maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10) @@ -137,6 +124,9 @@ private[ml] trait DecisionTreeParams extends PredictorParams { /** @group getParam */ final def getMinInfoGain: Double = $(minInfoGain) + /** @group setParam */ + def setSeed(value: Long): this.type = set(seed, value) + /** @group expertSetParam */ def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) @@ -149,12 +139,17 @@ private[ml] trait DecisionTreeParams extends PredictorParams { /** @group expertGetParam */ final def getCacheNodeIds: Boolean = $(cacheNodeIds) - /** @group expertSetParam */ + /** + * Specifies how often to checkpoint the cached node IDs. + * E.g. 10 means that the cache will get checkpointed every 10 iterations. + * This is only used if cacheNodeIds is true and if the checkpoint directory is set in + * [[org.apache.spark.SparkContext]]. + * Must be >= 1. + * (default = 10) + * @group expertSetParam + */ def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) - /** @group expertGetParam */ - final def getCheckpointInterval: Int = $(checkpointInterval) - /** (private[ml]) Create a Strategy instance to use with the old API. */ private[ml] def getOldStrategy( categoricalFeatures: Map[Int, Int], @@ -162,7 +157,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams { oldAlgo: OldAlgo.Algo, oldImpurity: OldImpurity, subsamplingRate: Double): OldStrategy = { - val strategy = OldStrategy.defaultStategy(oldAlgo) + val strategy = OldStrategy.defaultStrategy(oldAlgo) strategy.impurity = oldImpurity strategy.checkpointInterval = getCheckpointInterval strategy.maxBins = getMaxBins @@ -266,7 +261,7 @@ private[ml] object TreeRegressorParams { * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed { +private[ml] trait TreeEnsembleParams extends DecisionTreeParams { /** * Fraction of the training data used for learning each decision tree, in range (0, 1]. @@ -285,9 +280,6 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed { /** @group getParam */ final def getSubsamplingRate: Double = $(subsamplingRate) - /** @group setParam */ - def setSeed(value: Long): this.type = set(seed, value) - /** * Create a Strategy instance to use with the old API. * NOTE: The caller should set impurity and seed. @@ -374,17 +366,7 @@ private[ml] object RandomForestParams { * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { - - /** - * Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each - * estimator. - * (default = 0.1) - * @group param - */ - final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size (a.k.a." + - " learning rate) in interval (0, 1] for shrinking the contribution of each estimator", - ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)) +private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasStepSize { /* TODO: Add this doc when we add this param. SPARK-7132 * Threshold for stopping early when runWithValidation is used. 
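// --- Sketch of the cacheNodeIds / checkpointInterval interplay described above: node IDs
// are cached per instance and periodically checkpointed when a checkpoint directory is
// set. The directory, the SparkContext `sc` and the DataFrame `training` are assumptions.
import org.apache.spark.ml.regression.DecisionTreeRegressor

sc.setCheckpointDir("/tmp/tree-checkpoints")

val dt = new DecisionTreeRegressor()
  .setCacheNodeIds(true)        // cache node IDs for each instance on the executors
  .setCheckpointInterval(10)    // checkpoint the cached IDs every 10 iterations
  .setMaxDepth(8)

val dtModel = dt.fit(training)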
@@ -402,11 +384,19 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { /** @group setParam */ def setMaxIter(value: Int): this.type = set(maxIter, value) - /** @group setParam */ + /** + * Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each + * estimator. + * (default = 0.1) + * @group setParam + */ def setStepSize(value: Double): this.type = set(stepSize, value) - /** @group getParam */ - final def getStepSize: Double = $(stepSize) + override def validateParams(): Unit = { + require(ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)( + getStepSize), "GBT parameter stepSize should be in interval (0, 1], " + + s"but it given invalid value $getStepSize.") + } /** (private[ml]) Create a BoostingStrategy instance to use with the old API. */ private[ml] def getOldBoostingStrategy( diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index f979319cc4b5..40f8857fc586 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -18,21 +18,29 @@ package org.apache.spark.ml.tuning import com.github.fommil.netlib.F2jBLAS +import org.apache.hadoop.fs.Path +import org.json4s.jackson.JsonMethods._ +import org.json4s.{DefaultFormats, JObject} -import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ +import org.apache.spark.ml.classification.OneVsRestParams import org.apache.spark.ml.evaluation.Evaluator +import org.apache.spark.ml.feature.RFormulaModel import org.apache.spark.ml.param._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.param.shared.HasSeed +import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType + /** * Params for [[CrossValidator]] and [[CrossValidatorModel]]. */ -private[ml] trait CrossValidatorParams extends ValidatorParams { +private[ml] trait CrossValidatorParams extends ValidatorParams with HasSeed { /** * Param for number of folds for cross validation. Must be >= 2. * Default: 3 @@ -51,26 +59,38 @@ private[ml] trait CrossValidatorParams extends ValidatorParams { * :: Experimental :: * K-fold cross validation. 
*/ +@Since("1.2.0") @Experimental -class CrossValidator(override val uid: String) extends Estimator[CrossValidatorModel] - with CrossValidatorParams with Logging { +class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) + extends Estimator[CrossValidatorModel] + with CrossValidatorParams with MLWritable with Logging { + @Since("1.2.0") def this() = this(Identifiable.randomUID("cv")) private val f2jBLAS = new F2jBLAS /** @group setParam */ + @Since("1.2.0") def setEstimator(value: Estimator[_]): this.type = set(estimator, value) /** @group setParam */ + @Since("1.2.0") def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value) /** @group setParam */ + @Since("1.2.0") def setEvaluator(value: Evaluator): this.type = set(evaluator, value) /** @group setParam */ + @Since("1.2.0") def setNumFolds(value: Int): this.type = set(numFolds, value) + /** @group setParam */ + @Since("2.0.0") + def setSeed(value: Long): this.type = set(seed, value) + + @Since("1.4.0") override def fit(dataset: DataFrame): CrossValidatorModel = { val schema = dataset.schema transformSchema(schema, logging = true) @@ -80,7 +100,7 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM val epm = $(estimatorParamMaps) val numModels = epm.length val metrics = new Array[Double](epm.length) - val splits = MLUtils.kFold(dataset.rdd, $(numFolds), 0) + val splits = MLUtils.kFold(dataset.rdd, $(numFolds), $(seed)) splits.zipWithIndex.foreach { case ((training, validation), splitIndex) => val trainingDataset = sqlCtx.createDataFrame(training, schema).cache() val validationDataset = sqlCtx.createDataFrame(validation, schema).cache() @@ -100,17 +120,21 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM } f2jBLAS.dscal(numModels, 1.0 / $(numFolds), metrics, 1) logInfo(s"Average cross-validation metrics: ${metrics.toSeq}") - val (bestMetric, bestIndex) = metrics.zipWithIndex.maxBy(_._1) + val (bestMetric, bestIndex) = + if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1) + else metrics.zipWithIndex.minBy(_._1) logInfo(s"Best set of parameters:\n${epm(bestIndex)}") logInfo(s"Best cross-validation metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this)) } + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { $(estimator).transformSchema(schema) } + @Since("1.4.0") override def validateParams(): Unit = { super.validateParams() val est = $(estimator) @@ -119,6 +143,7 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM } } + @Since("1.4.0") override def copy(extra: ParamMap): CrossValidator = { val copied = defaultCopy(extra).asInstanceOf[CrossValidator] if (copied.isDefined(estimator)) { @@ -129,37 +154,256 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM } copied } + + // Currently, this only works if all [[Param]]s in [[estimatorParamMaps]] are simple types. + // E.g., this may fail if a [[Param]] is an instance of an [[Estimator]]. + // However, this case should be unusual. 
+ @Since("1.6.0") + override def write: MLWriter = new CrossValidator.CrossValidatorWriter(this) +} + +@Since("1.6.0") +object CrossValidator extends MLReadable[CrossValidator] { + + @Since("1.6.0") + override def read: MLReader[CrossValidator] = new CrossValidatorReader + + @Since("1.6.0") + override def load(path: String): CrossValidator = super.load(path) + + private[CrossValidator] class CrossValidatorWriter(instance: CrossValidator) extends MLWriter { + + SharedReadWrite.validateParams(instance) + + override protected def saveImpl(path: String): Unit = + SharedReadWrite.saveImpl(path, instance, sc) + } + + private class CrossValidatorReader extends MLReader[CrossValidator] { + + /** Checked against metadata when loading model */ + private val className = classOf[CrossValidator].getName + + override def load(path: String): CrossValidator = { + val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) = + SharedReadWrite.load(path, sc, className) + new CrossValidator(metadata.uid) + .setEstimator(estimator) + .setEvaluator(evaluator) + .setEstimatorParamMaps(estimatorParamMaps) + .setNumFolds(numFolds) + } + } + + private object CrossValidatorReader { + /** + * Examine the given estimator (which may be a compound estimator) and extract a mapping + * from UIDs to corresponding [[Params]] instances. + */ + def getUidMap(instance: Params): Map[String, Params] = { + val uidList = getUidMapImpl(instance) + val uidMap = uidList.toMap + if (uidList.size != uidMap.size) { + throw new RuntimeException("CrossValidator.load found a compound estimator with stages" + + s" with duplicate UIDs. List of UIDs: ${uidList.map(_._1).mkString(", ")}") + } + uidMap + } + + def getUidMapImpl(instance: Params): List[(String, Params)] = { + val subStages: Array[Params] = instance match { + case p: Pipeline => p.getStages.asInstanceOf[Array[Params]] + case pm: PipelineModel => pm.stages.asInstanceOf[Array[Params]] + case v: ValidatorParams => Array(v.getEstimator, v.getEvaluator) + case ovr: OneVsRestParams => + // TODO: SPARK-11892: This case may require special handling. + throw new UnsupportedOperationException("CrossValidator write will fail because it" + + " cannot yet handle an estimator containing type: ${ovr.getClass.getName}") + case rform: RFormulaModel => + // TODO: SPARK-11891: This case may require special handling. + throw new UnsupportedOperationException("CrossValidator write will fail because it" + + " cannot yet handle an estimator containing an RFormulaModel") + case _: Params => Array() + } + val subStageMaps = subStages.map(getUidMapImpl).foldLeft(List.empty[(String, Params)])(_ ++ _) + List((instance.uid, instance)) ++ subStageMaps + } + } + + private[tuning] object SharedReadWrite { + + /** + * Check that [[CrossValidator.evaluator]] and [[CrossValidator.estimator]] are Writable. + * This does not check [[CrossValidator.estimatorParamMaps]]. + */ + def validateParams(instance: ValidatorParams): Unit = { + def checkElement(elem: Params, name: String): Unit = elem match { + case stage: MLWritable => // good + case other => + throw new UnsupportedOperationException("CrossValidator write will fail " + + s" because it contains $name which does not implement Writable." + + s" Non-Writable $name: ${other.uid} of type ${other.getClass}") + } + checkElement(instance.getEvaluator, "evaluator") + checkElement(instance.getEstimator, "estimator") + // Check to make sure all Params apply to this estimator. Throw an error if any do not. 
+ // Extraneous Params would cause problems when loading the estimatorParamMaps. + val uidToInstance: Map[String, Params] = CrossValidatorReader.getUidMap(instance) + instance.getEstimatorParamMaps.foreach { case pMap: ParamMap => + pMap.toSeq.foreach { case ParamPair(p, v) => + require(uidToInstance.contains(p.parent), s"CrossValidator save requires all Params in" + + s" estimatorParamMaps to apply to this CrossValidator, its Estimator, or its" + + s" Evaluator. An extraneous Param was found: $p") + } + } + } + + private[tuning] def saveImpl( + path: String, + instance: CrossValidatorParams, + sc: SparkContext, + extraMetadata: Option[JObject] = None): Unit = { + import org.json4s.JsonDSL._ + + val estimatorParamMapsJson = compact(render( + instance.getEstimatorParamMaps.map { case paramMap => + paramMap.toSeq.map { case ParamPair(p, v) => + Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v)) + } + }.toSeq + )) + val jsonParams = List( + "numFolds" -> parse(instance.numFolds.jsonEncode(instance.getNumFolds)), + "estimatorParamMaps" -> parse(estimatorParamMapsJson) + ) + DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams)) + + val evaluatorPath = new Path(path, "evaluator").toString + instance.getEvaluator.asInstanceOf[MLWritable].save(evaluatorPath) + val estimatorPath = new Path(path, "estimator").toString + instance.getEstimator.asInstanceOf[MLWritable].save(estimatorPath) + } + + private[tuning] def load[M <: Model[M]]( + path: String, + sc: SparkContext, + expectedClassName: String): (Metadata, Estimator[M], Evaluator, Array[ParamMap], Int) = { + + val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName) + + implicit val format = DefaultFormats + val evaluatorPath = new Path(path, "evaluator").toString + val evaluator = DefaultParamsReader.loadParamsInstance[Evaluator](evaluatorPath, sc) + val estimatorPath = new Path(path, "estimator").toString + val estimator = DefaultParamsReader.loadParamsInstance[Estimator[M]](estimatorPath, sc) + + val uidToParams = Map(evaluator.uid -> evaluator) ++ CrossValidatorReader.getUidMap(estimator) + + val numFolds = (metadata.params \ "numFolds").extract[Int] + val estimatorParamMaps: Array[ParamMap] = + (metadata.params \ "estimatorParamMaps").extract[Seq[Seq[Map[String, String]]]].map { + pMap => + val paramPairs = pMap.map { case pInfo: Map[String, String] => + val est = uidToParams(pInfo("parent")) + val param = est.getParam(pInfo("name")) + val value = param.jsonDecode(pInfo("value")) + param -> value + } + ParamMap(paramPairs: _*) + }.toArray + (metadata, estimator, evaluator, estimatorParamMaps, numFolds) + } + } } /** * :: Experimental :: * Model from k-fold cross validation. + * + * @param bestModel The best model selected from k-fold cross validation. + * @param avgMetrics Average cross-validation metrics for each paramMap in + * [[CrossValidator.estimatorParamMaps]], in the corresponding order. 
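// --- Hedged sketch of the persistence flow described above: the CrossValidator (or its
// fitted CrossValidatorModel) writes JSON metadata with numFolds and the encoded
// estimatorParamMaps plus "estimator"/"evaluator" (and "bestModel") subdirectories.
// `cv` and `cvModel` are assumed from the previous sketch, the paths are made up, and the
// configured estimator and evaluator are assumed to be MLWritable themselves (otherwise
// the write fails, as validateParams above enforces).
import org.apache.spark.ml.tuning.{CrossValidator, CrossValidatorModel}

cv.write.overwrite().save("/tmp/cv")
val cvRestored = CrossValidator.load("/tmp/cv")

cvModel.write.overwrite().save("/tmp/cv-model")
val cvModelRestored = CrossValidatorModel.load("/tmp/cv-model")
println(cvModelRestored.avgMetrics.toSeq)   // avgMetrics round-trips through the metadata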
*/ +@Since("1.2.0") @Experimental class CrossValidatorModel private[ml] ( - override val uid: String, - val bestModel: Model[_], - val avgMetrics: Array[Double]) - extends Model[CrossValidatorModel] with CrossValidatorParams { + @Since("1.4.0") override val uid: String, + @Since("1.2.0") val bestModel: Model[_], + @Since("1.5.0") val avgMetrics: Array[Double]) + extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable { + @Since("1.4.0") override def validateParams(): Unit = { bestModel.validateParams() } + @Since("1.4.0") override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) bestModel.transform(dataset) } + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { bestModel.transformSchema(schema) } + @Since("1.4.0") override def copy(extra: ParamMap): CrossValidatorModel = { val copied = new CrossValidatorModel( uid, bestModel.copy(extra).asInstanceOf[Model[_]], avgMetrics.clone()) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) + } + + @Since("1.6.0") + override def write: MLWriter = new CrossValidatorModel.CrossValidatorModelWriter(this) +} + +@Since("1.6.0") +object CrossValidatorModel extends MLReadable[CrossValidatorModel] { + + import CrossValidator.SharedReadWrite + + @Since("1.6.0") + override def read: MLReader[CrossValidatorModel] = new CrossValidatorModelReader + + @Since("1.6.0") + override def load(path: String): CrossValidatorModel = super.load(path) + + private[CrossValidatorModel] + class CrossValidatorModelWriter(instance: CrossValidatorModel) extends MLWriter { + + SharedReadWrite.validateParams(instance) + + override protected def saveImpl(path: String): Unit = { + import org.json4s.JsonDSL._ + val extraMetadata = "avgMetrics" -> instance.avgMetrics.toSeq + SharedReadWrite.saveImpl(path, instance, sc, Some(extraMetadata)) + val bestModelPath = new Path(path, "bestModel").toString + instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath) + } + } + + private class CrossValidatorModelReader extends MLReader[CrossValidatorModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[CrossValidatorModel].getName + + override def load(path: String): CrossValidatorModel = { + implicit val format = DefaultFormats + + val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) = + SharedReadWrite.load(path, sc, className) + val bestModelPath = new Path(path, "bestModel").toString + val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc) + val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray + val cv = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics) + cv.set(cv.estimator, estimator) + .set(cv.evaluator, evaluator) + .set(cv.estimatorParamMaps, estimatorParamMaps) + .set(cv.numFolds, numFolds) + } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala index 98a8f0330ca4..b836d2a2340e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala @@ -20,21 +20,23 @@ package org.apache.spark.ml.tuning import scala.annotation.varargs import scala.collection.mutable -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param._ /** * :: Experimental :: * Builder for a 
param grid used in grid search-based model selection. */ +@Since("1.2.0") @Experimental -class ParamGridBuilder { +class ParamGridBuilder @Since("1.2.0") { private val paramGrid = mutable.Map.empty[Param[_], Iterable[_]] /** * Sets the given parameters in this grid to fixed values. */ + @Since("1.2.0") def baseOn(paramMap: ParamMap): this.type = { baseOn(paramMap.toSeq: _*) this @@ -43,6 +45,7 @@ class ParamGridBuilder { /** * Sets the given parameters in this grid to fixed values. */ + @Since("1.2.0") @varargs def baseOn(paramPairs: ParamPair[_]*): this.type = { paramPairs.foreach { p => @@ -54,6 +57,7 @@ class ParamGridBuilder { /** * Adds a param with multiple values (overwrites if the input param exists). */ + @Since("1.2.0") def addGrid[T](param: Param[T], values: Iterable[T]): this.type = { paramGrid.put(param, values) this @@ -64,6 +68,7 @@ class ParamGridBuilder { /** * Adds a double param with multiple values. */ + @Since("1.2.0") def addGrid(param: DoubleParam, values: Array[Double]): this.type = { addGrid[Double](param, values) } @@ -71,6 +76,7 @@ class ParamGridBuilder { /** * Adds a int param with multiple values. */ + @Since("1.2.0") def addGrid(param: IntParam, values: Array[Int]): this.type = { addGrid[Int](param, values) } @@ -78,6 +84,7 @@ class ParamGridBuilder { /** * Adds a float param with multiple values. */ + @Since("1.2.0") def addGrid(param: FloatParam, values: Array[Float]): this.type = { addGrid[Float](param, values) } @@ -85,6 +92,7 @@ class ParamGridBuilder { /** * Adds a long param with multiple values. */ + @Since("1.2.0") def addGrid(param: LongParam, values: Array[Long]): this.type = { addGrid[Long](param, values) } @@ -92,6 +100,7 @@ class ParamGridBuilder { /** * Adds a boolean param with true and false. */ + @Since("1.2.0") def addGrid(param: BooleanParam): this.type = { addGrid[Boolean](param, Array(true, false)) } @@ -99,6 +108,7 @@ class ParamGridBuilder { /** * Builds and returns all combinations of parameters specified by the param grid. */ + @Since("1.2.0") def build(): Array[ParamMap] = { var paramMaps = Array(new ParamMap) paramGrid.foreach { case (param, values) => diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index c0edc730b6fd..adf06302047a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.tuning import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators} @@ -51,24 +51,32 @@ private[ml] trait TrainValidationSplitParams extends ValidatorParams { * and uses evaluation metric on the validation set to select the best model. * Similar to [[CrossValidator]], but only splits the set once. 
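Since each addGrid overload above simply records candidate values and build() takes their cross product, a small sketch of the expansion, assuming an ml LogisticRegression stage named `lr` (any Params-bearing stage would do):

    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.tuning.ParamGridBuilder

    val lr = new LogisticRegression()
    val grid = new ParamGridBuilder()
      .addGrid(lr.regParam, Array(0.01, 0.1))   // two values
      .addGrid(lr.fitIntercept)                 // BooleanParam overload: expands to true and false
      .build()
    // 2 regParam values x 2 fitIntercept values = 4 ParamMaps
    assert(grid.length == 4)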
*/ +@Since("1.5.0") @Experimental -class TrainValidationSplit(override val uid: String) extends Estimator[TrainValidationSplitModel] +class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: String) + extends Estimator[TrainValidationSplitModel] with TrainValidationSplitParams with Logging { + @Since("1.5.0") def this() = this(Identifiable.randomUID("tvs")) /** @group setParam */ + @Since("1.5.0") def setEstimator(value: Estimator[_]): this.type = set(estimator, value) /** @group setParam */ + @Since("1.5.0") def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value) /** @group setParam */ + @Since("1.5.0") def setEvaluator(value: Evaluator): this.type = set(evaluator, value) /** @group setParam */ + @Since("1.5.0") def setTrainRatio(value: Double): this.type = set(trainRatio, value) + @Since("1.5.0") override def fit(dataset: DataFrame): TrainValidationSplitModel = { val schema = dataset.schema transformSchema(schema, logging = true) @@ -99,17 +107,21 @@ class TrainValidationSplit(override val uid: String) extends Estimator[TrainVali validationDataset.unpersist() logInfo(s"Train validation split metrics: ${metrics.toSeq}") - val (bestMetric, bestIndex) = metrics.zipWithIndex.maxBy(_._1) + val (bestMetric, bestIndex) = + if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1) + else metrics.zipWithIndex.minBy(_._1) logInfo(s"Best set of parameters:\n${epm(bestIndex)}") logInfo(s"Best train validation split metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] copyValues(new TrainValidationSplitModel(uid, bestModel, metrics).setParent(this)) } + @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { $(estimator).transformSchema(schema) } + @Since("1.5.0") override def validateParams(): Unit = { super.validateParams() val est = $(estimator) @@ -118,6 +130,7 @@ class TrainValidationSplit(override val uid: String) extends Estimator[TrainVali } } + @Since("1.5.0") override def copy(extra: ParamMap): TrainValidationSplit = { val copied = defaultCopy(extra).asInstanceOf[TrainValidationSplit] if (copied.isDefined(estimator)) { @@ -138,26 +151,31 @@ class TrainValidationSplit(override val uid: String) extends Estimator[TrainVali * @param bestModel Estimator determined best model. * @param validationMetrics Evaluated validation metrics. 
*/ +@Since("1.5.0") @Experimental class TrainValidationSplitModel private[ml] ( - override val uid: String, - val bestModel: Model[_], - val validationMetrics: Array[Double]) + @Since("1.5.0") override val uid: String, + @Since("1.5.0") val bestModel: Model[_], + @Since("1.5.0") val validationMetrics: Array[Double]) extends Model[TrainValidationSplitModel] with TrainValidationSplitParams { + @Since("1.5.0") override def validateParams(): Unit = { bestModel.validateParams() } + @Since("1.5.0") override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) bestModel.transform(dataset) } + @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { bestModel.transformSchema(schema) } + @Since("1.5.0") override def copy(extra: ParamMap): TrainValidationSplitModel = { val copied = new TrainValidationSplitModel ( uid, diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala index ddd34a54503a..bd213e7362e9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala @@ -19,11 +19,19 @@ package org.apache.spark.ml.util import java.util.UUID +import org.apache.spark.annotation.DeveloperApi + /** + * :: DeveloperApi :: + * * Trait for an object with an immutable unique ID that identifies itself and its derivatives. + * + * WARNING: There have not yet been final discussions on this API, so it may be broken in future + * releases. */ -private[spark] trait Identifiable { +@DeveloperApi +trait Identifiable { /** * An immutable unique ID for the object and its derivatives. @@ -33,7 +41,11 @@ private[spark] trait Identifiable { override def toString: String = uid } -private[spark] object Identifiable { +/** + * :: DeveloperApi :: + */ +@DeveloperApi +object Identifiable { /** * Returns a random UID that concatenates the given prefix, "_", and 12 random hex chars. diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala index 2a1db90f2ca2..96a38a3bde96 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala @@ -20,11 +20,12 @@ package org.apache.spark.ml.util import scala.collection.immutable.HashMap import org.apache.spark.ml.attribute._ +import org.apache.spark.mllib.linalg.VectorUDT import org.apache.spark.sql.types.StructField /** - * Helper utilities for tree-based algorithms + * Helper utilities for algorithms using ML metadata */ private[spark] object MetadataUtils { @@ -74,4 +75,20 @@ private[spark] object MetadataUtils { } } + /** + * Takes a Vector column and a list of feature names, and returns the corresponding list of + * feature indices in the column, in order. 
+ * @param col Vector column which must have feature names specified via attributes + * @param names List of feature names + */ + def getFeatureIndicesFromNames(col: StructField, names: Array[String]): Array[Int] = { + require(col.dataType.isInstanceOf[VectorUDT], s"getFeatureIndicesFromNames expected column $col" + + s" to be Vector type, but it was type ${col.dataType} instead.") + val inputAttr = AttributeGroup.fromStructField(col) + names.map { name => + require(inputAttr.hasAttr(name), + s"getFeatureIndicesFromNames found no feature with name $name in column $col.") + inputAttr.getAttr(name).index.get + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala new file mode 100644 index 000000000000..8484b1f80106 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -0,0 +1,329 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import java.io.IOException + +import org.apache.hadoop.fs.{FileSystem, Path} +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.param.{ParamPair, Params} +import org.apache.spark.sql.SQLContext +import org.apache.spark.util.Utils + +/** + * Trait for [[MLWriter]] and [[MLReader]]. + */ +private[util] sealed trait BaseReadWrite { + private var optionSQLContext: Option[SQLContext] = None + + /** + * Sets the SQL context to use for saving/loading. + */ + @Since("1.6.0") + def context(sqlContext: SQLContext): this.type = { + optionSQLContext = Option(sqlContext) + this + } + + /** + * Returns the user-specified SQL context or the default. + */ + protected final def sqlContext: SQLContext = { + if (optionSQLContext.isEmpty) { + optionSQLContext = Some(SQLContext.getOrCreate(SparkContext.getOrCreate())) + } + optionSQLContext.get + } + + /** Returns the [[SparkContext]] underlying [[sqlContext]] */ + protected final def sc: SparkContext = sqlContext.sparkContext +} + +/** + * Abstract class for utility classes that can save ML instances. + */ +@Experimental +@Since("1.6.0") +abstract class MLWriter extends BaseReadWrite with Logging { + + protected var shouldOverwrite: Boolean = false + + /** + * Saves the ML instances to the input path. + */ + @Since("1.6.0") + @throws[IOException]("If the input path already exists but overwrite is not enabled.") + def save(path: String): Unit = { + val hadoopConf = sc.hadoopConfiguration + val fs = FileSystem.get(hadoopConf) + val p = new Path(path) + if (fs.exists(p)) { + if (shouldOverwrite) { + logInfo(s"Path $path already exists. 
It will be overwritten.") + // TODO: Revert back to the original content if save is not successful. + fs.delete(p, true) + } else { + throw new IOException( + s"Path $path already exists. Please use write.overwrite().save(path) to overwrite it.") + } + } + saveImpl(path) + } + + /** + * [[save()]] handles overwriting and then calls this method. Subclasses should override this + * method to implement the actual saving of the instance. + */ + @Since("1.6.0") + protected def saveImpl(path: String): Unit + + /** + * Overwrites if the output path already exists. + */ + @Since("1.6.0") + def overwrite(): this.type = { + shouldOverwrite = true + this + } + + // override for Java compatibility + override def context(sqlContext: SQLContext): this.type = super.context(sqlContext) +} + +/** + * Trait for classes that provide [[MLWriter]]. + */ +@Since("1.6.0") +trait MLWritable { + + /** + * Returns an [[MLWriter]] instance for this ML instance. + */ + @Since("1.6.0") + def write: MLWriter + + /** + * Saves this ML instance to the input path, a shortcut of `write.save(path)`. + */ + @Since("1.6.0") + @throws[IOException]("If the input path already exists but overwrite is not enabled.") + def save(path: String): Unit = write.save(path) +} + +private[ml] trait DefaultParamsWritable extends MLWritable { self: Params => + + override def write: MLWriter = new DefaultParamsWriter(this) +} + +/** + * Abstract class for utility classes that can load ML instances. + * @tparam T ML instance type + */ +@Experimental +@Since("1.6.0") +abstract class MLReader[T] extends BaseReadWrite { + + /** + * Loads the ML component from the input path. + */ + @Since("1.6.0") + def load(path: String): T + + // override for Java compatibility + override def context(sqlContext: SQLContext): this.type = super.context(sqlContext) +} + +/** + * Trait for objects that provide [[MLReader]]. + * @tparam T ML instance type + */ +@Experimental +@Since("1.6.0") +trait MLReadable[T] { + + /** + * Returns an [[MLReader]] instance for this class. + */ + @Since("1.6.0") + def read: MLReader[T] + + /** + * Reads an ML instance from the input path, a shortcut of `read.load(path)`. + * + * Note: Implementing classes should override this to be Java-friendly. + */ + @Since("1.6.0") + def load(path: String): T = read.load(path) +} + +private[ml] trait DefaultParamsReadable[T] extends MLReadable[T] { + + override def read: MLReader[T] = new DefaultParamsReader +} + +/** + * Default [[MLWriter]] implementation for transformers and estimators that contain basic + * (json4s-serializable) params and no data. This will not handle more complex params or types with + * data (e.g., models with coefficients). + * @param instance object to save + */ +private[ml] class DefaultParamsWriter(instance: Params) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + } +} + +private[ml] object DefaultParamsWriter { + + /** + * Saves metadata + Params to: path + "/metadata" + * - class + * - timestamp + * - sparkVersion + * - uid + * - paramMap + * - (optionally, extra metadata) + * @param extraMetadata Extra metadata to be saved at same level as uid, paramMap, etc. + * @param paramMap If given, this is saved in the "paramMap" field. + * Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using + * [[org.apache.spark.ml.param.Param.jsonEncode()]]. 
+ */ + def saveMetadata( + instance: Params, + path: String, + sc: SparkContext, + extraMetadata: Option[JObject] = None, + paramMap: Option[JValue] = None): Unit = { + val uid = instance.uid + val cls = instance.getClass.getName + val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]] + val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) => + p.name -> parse(p.jsonEncode(v)) + }.toList)) + val basicMetadata = ("class" -> cls) ~ + ("timestamp" -> System.currentTimeMillis()) ~ + ("sparkVersion" -> sc.version) ~ + ("uid" -> uid) ~ + ("paramMap" -> jsonParams) + val metadata = extraMetadata match { + case Some(jObject) => + basicMetadata ~ jObject + case None => + basicMetadata + } + val metadataPath = new Path(path, "metadata").toString + val metadataJson = compact(render(metadata)) + sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath) + } +} + +/** + * Default [[MLReader]] implementation for transformers and estimators that contain basic + * (json4s-serializable) params and no data. This will not handle more complex params or types with + * data (e.g., models with coefficients). + * @tparam T ML instance type + * TODO: Consider adding check for correct class name. + */ +private[ml] class DefaultParamsReader[T] extends MLReader[T] { + + override def load(path: String): T = { + val metadata = DefaultParamsReader.loadMetadata(path, sc) + val cls = Utils.classForName(metadata.className) + val instance = + cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params] + DefaultParamsReader.getAndSetParams(instance, metadata) + instance.asInstanceOf[T] + } +} + +private[ml] object DefaultParamsReader { + + /** + * All info from metadata file. + * @param params paramMap, as a [[JValue]] + * @param metadata All metadata, including the other fields + * @param metadataJson Full metadata file String (for debugging) + */ + case class Metadata( + className: String, + uid: String, + timestamp: Long, + sparkVersion: String, + params: JValue, + metadata: JValue, + metadataJson: String) + + /** + * Load metadata from file. + * @param expectedClassName If non empty, this is checked against the loaded metadata. + * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata + */ + def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = { + val metadataPath = new Path(path, "metadata").toString + val metadataStr = sc.textFile(metadataPath, 1).first() + val metadata = parse(metadataStr) + + implicit val format = DefaultFormats + val className = (metadata \ "class").extract[String] + val uid = (metadata \ "uid").extract[String] + val timestamp = (metadata \ "timestamp").extract[Long] + val sparkVersion = (metadata \ "sparkVersion").extract[String] + val params = metadata \ "paramMap" + if (expectedClassName.nonEmpty) { + require(className == expectedClassName, s"Error loading metadata: Expected class name" + + s" $expectedClassName but found class name $className") + } + + Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr) + } + + /** + * Extract Params from metadata, and set them in the instance. + * This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]]. 
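Because saveMetadata above writes a single JSON line under path + "/metadata", the stored parameters are easy to inspect; a sketch, assuming the hypothetical "/tmp/cv-example" directory from before and a SparkContext `sc`:

    val metadataJson = sc.textFile("/tmp/cv-example/metadata").first()
    // One JSON object with "class", "timestamp", "sparkVersion", "uid" and "paramMap" fields;
    // for CrossValidator the paramMap carries the numFolds and estimatorParamMaps entries
    // produced by SharedReadWrite.saveImpl.
    println(metadataJson)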
+ */ + def getAndSetParams(instance: Params, metadata: Metadata): Unit = { + implicit val format = DefaultFormats + metadata.params match { + case JObject(pairs) => + pairs.foreach { case (paramName, jsonValue) => + val param = instance.getParam(paramName) + val value = param.jsonDecode(compact(render(jsonValue))) + instance.set(param, value) + } + case _ => + throw new IllegalArgumentException( + s"Cannot recognize JSON metadata: ${metadata.metadataJson}.") + } + } + + /** + * Load a [[Params]] instance from the given path, and return it. + * This assumes the instance implements [[MLReadable]]. + */ + def loadParamsInstance[T](path: String, sc: SparkContext): T = { + val metadata = DefaultParamsReader.loadMetadata(path, sc) + val cls = Utils.classForName(metadata.className) + cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala index 0ec88ef77d69..6a3b20c88d2d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala @@ -17,14 +17,11 @@ package org.apache.spark.mllib.api.python -import java.util.{List => JList} - -import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer +import scala.collection.JavaConverters import org.apache.spark.SparkContext -import org.apache.spark.mllib.linalg.{Vector, Vectors, Matrix} import org.apache.spark.mllib.clustering.GaussianMixtureModel +import org.apache.spark.mllib.linalg.{Vector, Vectors} /** * Wrapper around GaussianMixtureModel to provide helper methods in Python @@ -36,17 +33,11 @@ private[python] class GaussianMixtureModelWrapper(model: GaussianMixtureModel) { /** * Returns gaussians as a List of Vectors and Matrices corresponding each MultivariateGaussian */ - val gaussians: JList[Object] = { - val modelGaussians = model.gaussians - var i = 0 - var mu = ArrayBuffer.empty[Vector] - var sigma = ArrayBuffer.empty[Matrix] - while (i < k) { - mu += modelGaussians(i).mu - sigma += modelGaussians(i).sigma - i += 1 + val gaussians: Array[Byte] = { + val modelGaussians = model.gaussians.map { gaussian => + Array[Any](gaussian.mu, gaussian.sigma) } - List(mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava + SerDe.dumps(JavaConverters.seqAsJavaListConverter(modelGaussians).asJava) } def save(sc: SparkContext, path: String): Unit = model.save(sc, path) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala new file mode 100644 index 000000000000..63282eee6e65 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.api.python + +import scala.collection.JavaConverters + +import org.apache.spark.SparkContext +import org.apache.spark.mllib.clustering.LDAModel +import org.apache.spark.mllib.linalg.Matrix + +/** + * Wrapper around LDAModel to provide helper methods in Python + */ +private[python] class LDAModelWrapper(model: LDAModel) { + + def topicsMatrix(): Matrix = model.topicsMatrix + + def vocabSize(): Int = model.vocabSize + + def describeTopics(): Array[Byte] = describeTopics(this.model.vocabSize) + + def describeTopics(maxTermsPerTopic: Int): Array[Byte] = { + val topics = model.describeTopics(maxTermsPerTopic).map { case (terms, termWeights) => + val jTerms = JavaConverters.seqAsJavaListConverter(terms).asJava + val jTermWeights = JavaConverters.seqAsJavaListConverter(termWeights).asJava + Array[Any](jTerms, jTermWeights) + } + SerDe.dumps(JavaConverters.seqAsJavaListConverter(topics).asJava) + } + + def save(sc: SparkContext, path: String): Unit = model.save(sc, path) +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala index 534edac56bc5..eeb7cba882ce 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala @@ -42,4 +42,12 @@ private[python] class MatrixFactorizationModelWrapper(model: MatrixFactorization case (product, feature) => (product, Vectors.dense(feature)) }.asInstanceOf[RDD[(Any, Any)]]) } + + def wrappedRecommendProductsForUsers(num: Int): RDD[Array[Any]] = { + SerDe.fromTuple2RDD(recommendProductsForUsers(num).asInstanceOf[RDD[(Any, Any)]]) + } + + def wrappedRecommendUsersForProducts(num: Int): RDD[Array[Any]] = { + SerDe.fromTuple2RDD(recommendUsersForProducts(num).asInstanceOf[RDD[(Any, Any)]]) + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PrefixSpanModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PrefixSpanModelWrapper.scala new file mode 100644 index 000000000000..0027602a04f8 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PrefixSpanModelWrapper.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.api.python + +import org.apache.spark.mllib.fpm.PrefixSpanModel +import org.apache.spark.rdd.RDD + +/** + * A Wrapper of PrefixSpanModel to provide helper method for Python + */ +private[python] class PrefixSpanModelWrapper(model: PrefixSpanModel[Any]) + extends PrefixSpanModel(model.freqSequences) { + + def getFreqSequences: RDD[Array[Any]] = { + SerDe.fromTuple2RDD(model.freqSequences.map(x => (x.javaSequence, x.freq))) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 6f080d32bbf4..29160a10e16b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -35,8 +35,9 @@ import org.apache.spark.mllib.classification._ import org.apache.spark.mllib.clustering._ import org.apache.spark.mllib.evaluation.RankingMetrics import org.apache.spark.mllib.feature._ -import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel} +import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel, PrefixSpan} import org.apache.spark.mllib.linalg._ +import org.apache.spark.mllib.linalg.distributed._ import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.random.{RandomRDDs => RG} import org.apache.spark.mllib.recommendation._ @@ -54,7 +55,7 @@ import org.apache.spark.mllib.tree.{DecisionTree, GradientBoostedTrees, RandomFo import org.apache.spark.mllib.util.MLUtils import org.apache.spark.mllib.util.LinearDataGenerator import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Row, SQLContext} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -131,7 +132,8 @@ private[python] class PythonMLLibAPI extends Serializable { regParam: Double, regType: String, intercept: Boolean, - validateData: Boolean): JList[Object] = { + validateData: Boolean, + convergenceTol: Double): JList[Object] = { val lrAlg = new LinearRegressionWithSGD() lrAlg.setIntercept(intercept) .setValidateData(validateData) @@ -140,6 +142,7 @@ private[python] class PythonMLLibAPI extends Serializable { .setRegParam(regParam) .setStepSize(stepSize) .setMiniBatchFraction(miniBatchFraction) + .setConvergenceTol(convergenceTol) lrAlg.optimizer.setUpdater(getUpdaterFromString(regType)) trainRegressionModel( lrAlg, @@ -158,7 +161,8 @@ private[python] class PythonMLLibAPI extends Serializable { miniBatchFraction: Double, initialWeights: Vector, intercept: Boolean, - validateData: Boolean): JList[Object] = { + validateData: Boolean, + convergenceTol: Double): JList[Object] = { val lassoAlg = new LassoWithSGD() lassoAlg.setIntercept(intercept) .setValidateData(validateData) @@ -167,6 +171,7 @@ private[python] class PythonMLLibAPI extends Serializable { .setRegParam(regParam) .setStepSize(stepSize) .setMiniBatchFraction(miniBatchFraction) + .setConvergenceTol(convergenceTol) trainRegressionModel( lassoAlg, data, @@ -184,7 +189,8 @@ private[python] class PythonMLLibAPI extends Serializable { miniBatchFraction: Double, initialWeights: Vector, intercept: Boolean, - validateData: Boolean): JList[Object] = { + validateData: Boolean, + convergenceTol: Double): JList[Object] = { val ridgeAlg = new RidgeRegressionWithSGD() ridgeAlg.setIntercept(intercept) .setValidateData(validateData) 
@@ -193,6 +199,7 @@ private[python] class PythonMLLibAPI extends Serializable { .setRegParam(regParam) .setStepSize(stepSize) .setMiniBatchFraction(miniBatchFraction) + .setConvergenceTol(convergenceTol) trainRegressionModel( ridgeAlg, data, @@ -211,7 +218,8 @@ private[python] class PythonMLLibAPI extends Serializable { initialWeights: Vector, regType: String, intercept: Boolean, - validateData: Boolean): JList[Object] = { + validateData: Boolean, + convergenceTol: Double): JList[Object] = { val SVMAlg = new SVMWithSGD() SVMAlg.setIntercept(intercept) .setValidateData(validateData) @@ -220,6 +228,7 @@ private[python] class PythonMLLibAPI extends Serializable { .setRegParam(regParam) .setStepSize(stepSize) .setMiniBatchFraction(miniBatchFraction) + .setConvergenceTol(convergenceTol) SVMAlg.optimizer.setUpdater(getUpdaterFromString(regType)) trainRegressionModel( SVMAlg, @@ -239,7 +248,8 @@ private[python] class PythonMLLibAPI extends Serializable { regParam: Double, regType: String, intercept: Boolean, - validateData: Boolean): JList[Object] = { + validateData: Boolean, + convergenceTol: Double): JList[Object] = { val LogRegAlg = new LogisticRegressionWithSGD() LogRegAlg.setIntercept(intercept) .setValidateData(validateData) @@ -248,6 +258,7 @@ private[python] class PythonMLLibAPI extends Serializable { .setRegParam(regParam) .setStepSize(stepSize) .setMiniBatchFraction(miniBatchFraction) + .setConvergenceTol(convergenceTol) LogRegAlg.optimizer.setUpdater(getUpdaterFromString(regType)) trainRegressionModel( LogRegAlg, @@ -325,7 +336,8 @@ private[python] class PythonMLLibAPI extends Serializable { initializationMode: String, seed: java.lang.Long, initializationSteps: Int, - epsilon: Double): KMeansModel = { + epsilon: Double, + initialModel: java.util.ArrayList[Vector]): KMeansModel = { val kMeansAlg = new KMeans() .setK(k) .setMaxIterations(maxIterations) @@ -335,6 +347,7 @@ private[python] class PythonMLLibAPI extends Serializable { .setEpsilon(epsilon) if (seed != null) kMeansAlg.setSeed(seed) + if (!initialModel.isEmpty()) kMeansAlg.setInitialModel(new KMeansModel(initialModel)) try { kMeansAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK)) @@ -504,7 +517,7 @@ private[python] class PythonMLLibAPI extends Serializable { topicConcentration: Double, seed: java.lang.Long, checkpointInterval: Int, - optimizer: String): LDAModel = { + optimizer: String): LDAModelWrapper = { val algo = new LDA() .setK(k) .setMaxIterations(maxIterations) @@ -522,7 +535,16 @@ private[python] class PythonMLLibAPI extends Serializable { case _ => throw new IllegalArgumentException("input values contains invalid type value.") } } - algo.run(documents) + val model = algo.run(documents) + new LDAModelWrapper(model) + } + + /** + * Load a LDA model + */ + def loadLDAModel(jsc: JavaSparkContext, path: String): LDAModelWrapper = { + val model = DistributedLDAModel.load(jsc.sc, path) + new LDAModelWrapper(model) } @@ -544,6 +566,27 @@ private[python] class PythonMLLibAPI extends Serializable { new FPGrowthModelWrapper(model) } + /** + * Java stub for Python mllib PrefixSpan.train(). This stub returns a handle + * to the Java object instead of the content of the Java object. Extra care + * needs to be taken in the Python code to ensure it gets freed on exit; see + * the Py4J documentation. 
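The convergenceTol argument threaded through these Python stubs is applied to the underlying GradientDescent optimizer; the direct Scala equivalent, assuming an RDD[LabeledPoint] named `training` (hypothetical):

    import org.apache.spark.mllib.regression.LinearRegressionWithSGD

    val alg = new LinearRegressionWithSGD()
    alg.optimizer
      .setNumIterations(100)
      .setStepSize(1.0)
      .setConvergenceTol(1e-3)   // allow early termination once the solution stops changing much
    val model = alg.run(training)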
+ */ + def trainPrefixSpanModel( + data: JavaRDD[java.util.ArrayList[java.util.ArrayList[Any]]], + minSupport: Double, + maxPatternLength: Int, + localProjDBSize: Int ): PrefixSpanModelWrapper = { + val prefixSpan = new PrefixSpan() + .setMinSupport(minSupport) + .setMaxPatternLength(maxPatternLength) + .setMaxLocalProjDBSize(localProjDBSize) + + val trainData = data.rdd.map(_.asScala.toArray.map(_.asScala.toArray)) + val model = prefixSpan.run(trainData) + new PrefixSpanModelWrapper(model) + } + /** * Java stub for Normalizer.transform() */ @@ -637,39 +680,6 @@ private[python] class PythonMLLibAPI extends Serializable { } } - private[python] class Word2VecModelWrapper(model: Word2VecModel) { - def transform(word: String): Vector = { - model.transform(word) - } - - /** - * Transforms an RDD of words to its vector representation - * @param rdd an RDD of words - * @return an RDD of vector representations of words - */ - def transform(rdd: JavaRDD[String]): JavaRDD[Vector] = { - rdd.rdd.map(model.transform) - } - - def findSynonyms(word: String, num: Int): JList[Object] = { - val vec = transform(word) - findSynonyms(vec, num) - } - - def findSynonyms(vector: Vector, num: Int): JList[Object] = { - val result = model.findSynonyms(vector, num) - val similarity = Vectors.dense(result.map(_._2)) - val words = result.map(_._1) - List(words, similarity).map(_.asInstanceOf[Object]).asJava - } - - def getVectors: JMap[String, JList[Float]] = { - model.getVectors.map({case (k, v) => (k, v.toList.asJava)}).asJava - } - - def save(sc: SparkContext, path: String): Unit = model.save(sc, path) - } - /** * Java stub for Python mllib DecisionTree.train(). * This stub returns a handle to the Java object instead of the content of the Java object. @@ -1096,6 +1106,81 @@ private[python] class PythonMLLibAPI extends Serializable { Statistics.kolmogorovSmirnovTest(data, distName, paramsSeq: _*) } + /** + * Wrapper around RowMatrix constructor. + */ + def createRowMatrix(rows: JavaRDD[Vector], numRows: Long, numCols: Int): RowMatrix = { + new RowMatrix(rows.rdd.retag(classOf[Vector]), numRows, numCols) + } + + /** + * Wrapper around IndexedRowMatrix constructor. + */ + def createIndexedRowMatrix(rows: DataFrame, numRows: Long, numCols: Int): IndexedRowMatrix = { + // We use DataFrames for serialization of IndexedRows from Python, + // so map each Row in the DataFrame back to an IndexedRow. + val indexedRows = rows.map { + case Row(index: Long, vector: Vector) => IndexedRow(index, vector) + } + new IndexedRowMatrix(indexedRows, numRows, numCols) + } + + /** + * Wrapper around CoordinateMatrix constructor. + */ + def createCoordinateMatrix(rows: DataFrame, numRows: Long, numCols: Long): CoordinateMatrix = { + // We use DataFrames for serialization of MatrixEntry entries from + // Python, so map each Row in the DataFrame back to a MatrixEntry. + val entries = rows.map { + case Row(i: Long, j: Long, value: Double) => MatrixEntry(i, j, value) + } + new CoordinateMatrix(entries, numRows, numCols) + } + + /** + * Wrapper around BlockMatrix constructor. + */ + def createBlockMatrix(blocks: DataFrame, rowsPerBlock: Int, colsPerBlock: Int, + numRows: Long, numCols: Long): BlockMatrix = { + // We use DataFrames for serialization of sub-matrix blocks from + // Python, so map each Row in the DataFrame back to a + // ((blockRowIndex, blockColIndex), sub-matrix) tuple. 
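The new wrappers above move distributed matrices between Python and the JVM as DataFrames of rows, entries, or blocks; on the Scala side the same structures are built directly, for example (assuming a SparkContext `sc`):

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.linalg.distributed.{IndexedRow, IndexedRowMatrix}

    val rows = sc.parallelize(Seq(
      IndexedRow(0L, Vectors.dense(1.0, 2.0)),
      IndexedRow(1L, Vectors.dense(3.0, 4.0))))
    val mat = new IndexedRowMatrix(rows)       // dimensions are computed lazily when not supplied
    println((mat.numRows(), mat.numCols()))    // (2, 2)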
+ val blockTuples = blocks.map { + case Row(Row(blockRowIndex: Long, blockColIndex: Long), subMatrix: Matrix) => + ((blockRowIndex.toInt, blockColIndex.toInt), subMatrix) + } + new BlockMatrix(blockTuples, rowsPerBlock, colsPerBlock, numRows, numCols) + } + + /** + * Return the rows of an IndexedRowMatrix. + */ + def getIndexedRows(indexedRowMatrix: IndexedRowMatrix): DataFrame = { + // We use DataFrames for serialization of IndexedRows to Python, + // so return a DataFrame. + val sqlContext = SQLContext.getOrCreate(indexedRowMatrix.rows.sparkContext) + sqlContext.createDataFrame(indexedRowMatrix.rows) + } + + /** + * Return the entries of a CoordinateMatrix. + */ + def getMatrixEntries(coordinateMatrix: CoordinateMatrix): DataFrame = { + // We use DataFrames for serialization of MatrixEntry entries to + // Python, so return a DataFrame. + val sqlContext = SQLContext.getOrCreate(coordinateMatrix.entries.sparkContext) + sqlContext.createDataFrame(coordinateMatrix.entries) + } + + /** + * Return the sub-matrix blocks of a BlockMatrix. + */ + def getMatrixBlocks(blockMatrix: BlockMatrix): DataFrame = { + // We use DataFrames for serialization of sub-matrix blocks to + // Python, so return a DataFrame. + val sqlContext = SQLContext.getOrCreate(blockMatrix.blocks.sparkContext) + sqlContext.createDataFrame(blockMatrix.blocks) + } } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala new file mode 100644 index 000000000000..0f55980481dc --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.api.python + +import java.util.{ArrayList => JArrayList, List => JList, Map => JMap} +import scala.collection.JavaConverters._ + +import org.apache.spark.SparkContext +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.feature.Word2VecModel +import org.apache.spark.mllib.linalg.{Vector, Vectors} + +/** + * Wrapper around Word2VecModel to provide helper methods in Python + */ +private[python] class Word2VecModelWrapper(model: Word2VecModel) { + def transform(word: String): Vector = { + model.transform(word) + } + + /** + * Transforms an RDD of words to its vector representation + * @param rdd an RDD of words + * @return an RDD of vector representations of words + */ + def transform(rdd: JavaRDD[String]): JavaRDD[Vector] = { + rdd.rdd.map(model.transform) + } + + def findSynonyms(word: String, num: Int): JList[Object] = { + val vec = transform(word) + findSynonyms(vec, num) + } + + def findSynonyms(vector: Vector, num: Int): JList[Object] = { + val result = model.findSynonyms(vector, num) + val similarity = Vectors.dense(result.map(_._2)) + val words = result.map(_._1) + List(words, similarity).map(_.asInstanceOf[Object]).asJava + } + + def getVectors: JMap[String, JList[Float]] = { + model.getVectors.map({case (k, v) => (k, v.toList.asJava)}).asJava + } + + def save(sc: SparkContext, path: String): Unit = model.save(sc, path) +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala index ba73024e3c04..5161bc72659c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala @@ -19,25 +19,24 @@ package org.apache.spark.mllib.classification import org.json4s.{DefaultFormats, JValue} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector import org.apache.spark.rdd.RDD /** - * :: Experimental :: * Represents a classification model that predicts to which of a set of categories an example * belongs. The categories are represented by double values: 0.0, 1.0, 2.0, etc. */ -@Experimental +@Since("0.8.0") trait ClassificationModel extends Serializable { /** * Predict values for the given data set using the model trained. * * @param testData RDD representing data points to be predicted * @return an RDD[Double] where each entry contains the corresponding prediction - * @since 0.8.0 */ + @Since("1.0.0") def predict(testData: RDD[Vector]): RDD[Double] /** @@ -45,16 +44,16 @@ trait ClassificationModel extends Serializable { * * @param testData array representing a single data point * @return predicted category from the trained model - * @since 0.8.0 */ + @Since("1.0.0") def predict(testData: Vector): Double /** * Predict values for examples stored in a JavaRDD. 
* @param testData JavaRDD representing data points to be predicted * @return a JavaRDD[java.lang.Double] where each entry contains the corresponding prediction - * @since 0.8.0 */ + @Since("1.0.0") def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] = predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]] } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 268642ac6a2f..2d52abc122bf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.classification import org.apache.spark.SparkContext -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.Since import org.apache.spark.mllib.classification.impl.GLMClassificationModel import org.apache.spark.mllib.linalg.BLAS.dot import org.apache.spark.mllib.linalg.{DenseVector, Vector} @@ -41,11 +41,12 @@ import org.apache.spark.rdd.RDD * Multinomial Logistic Regression. By default, it is binary logistic regression * so numClasses will be set to 2. */ -class LogisticRegressionModel ( - override val weights: Vector, - override val intercept: Double, - val numFeatures: Int, - val numClasses: Int) +@Since("0.8.0") +class LogisticRegressionModel @Since("1.3.0") ( + @Since("1.0.0") override val weights: Vector, + @Since("1.0.0") override val intercept: Double, + @Since("1.3.0") val numFeatures: Int, + @Since("1.3.0") val numClasses: Int) extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable with Saveable with PMMLExportable { @@ -75,40 +76,35 @@ class LogisticRegressionModel ( /** * Constructs a [[LogisticRegressionModel]] with weights and intercept for binary classification. */ + @Since("1.0.0") def this(weights: Vector, intercept: Double) = this(weights, intercept, weights.size, 2) private var threshold: Option[Double] = Some(0.5) /** - * :: Experimental :: * Sets the threshold that separates positive predictions from negative predictions * in Binary Logistic Regression. An example with prediction score greater than or equal to * this threshold is identified as an positive, and negative otherwise. The default value is 0.5. * It is only used for binary classification. - * @since 1.0.0 */ - @Experimental + @Since("1.0.0") def setThreshold(threshold: Double): this.type = { this.threshold = Some(threshold) this } /** - * :: Experimental :: * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions. * It is only used for binary classification. - * @since 1.3.0 */ - @Experimental + @Since("1.3.0") def getThreshold: Option[Double] = threshold /** - * :: Experimental :: * Clears the threshold so that `predict` will output raw prediction scores. * It is only used for binary classification. 
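A small sketch of the threshold behaviour documented above, assuming an already trained mllib LogisticRegressionModel `model` and a feature Vector `features` (both hypothetical):

    model.setThreshold(0.7)        // scores >= 0.7 are labelled 1.0, otherwise 0.0
    val hardPrediction = model.predict(features)

    model.clearThreshold()         // predict now returns the raw score instead of 0/1
    val rawScore = model.predict(features)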
- * @since 1.0.0 */ - @Experimental + @Since("1.0.0") def clearThreshold(): this.type = { threshold = None this @@ -158,9 +154,7 @@ class LogisticRegressionModel ( } } - /** - * @since 1.3.0 - */ + @Since("1.3.0") override def save(sc: SparkContext, path: String): Unit = { GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, numFeatures, numClasses, weights, intercept, threshold) @@ -168,19 +162,15 @@ class LogisticRegressionModel ( override protected def formatVersion: String = "1.0" - /** - * @since 1.4.0 - */ override def toString: String = { s"${super.toString}, numClasses = ${numClasses}, threshold = ${threshold.getOrElse("None")}" } } +@Since("1.3.0") object LogisticRegressionModel extends Loader[LogisticRegressionModel] { - /** - * @since 1.3.0 - */ + @Since("1.3.0") override def load(sc: SparkContext, path: String): LogisticRegressionModel = { val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) // Hard-code class name string in case it changes in the future @@ -213,6 +203,7 @@ object LogisticRegressionModel extends Loader[LogisticRegressionModel] { * for k classes multi-label classification problem. * Using [[LogisticRegressionWithLBFGS]] is recommended over this. */ +@Since("0.8.0") class LogisticRegressionWithSGD private[mllib] ( private var stepSize: Double, private var numIterations: Int, @@ -222,6 +213,7 @@ class LogisticRegressionWithSGD private[mllib] ( private val gradient = new LogisticGradient() private val updater = new SquaredL2Updater() + @Since("0.8.0") override val optimizer = new GradientDescent(gradient, updater) .setStepSize(stepSize) .setNumIterations(numIterations) @@ -233,6 +225,7 @@ class LogisticRegressionWithSGD private[mllib] ( * Construct a LogisticRegression object with default parameters: {stepSize: 1.0, * numIterations: 100, regParm: 0.01, miniBatchFraction: 1.0}. */ + @Since("0.8.0") def this() = this(1.0, 100, 0.01, 1.0) override protected[mllib] def createModel(weights: Vector, intercept: Double) = { @@ -244,6 +237,7 @@ class LogisticRegressionWithSGD private[mllib] ( * Top-level methods for calling Logistic Regression using Stochastic Gradient Descent. * NOTE: Labels used in Logistic Regression should be {0, 1} */ +@Since("0.8.0") object LogisticRegressionWithSGD { // NOTE(shivaram): We use multiple train methods instead of default arguments to support // Java programs. @@ -261,8 +255,8 @@ object LogisticRegressionWithSGD { * @param miniBatchFraction Fraction of data to be used per iteration. * @param initialWeights Initial set of weights to be used. Array should be equal in size to * the number of features in the data. - * @since 1.0.0 */ + @Since("1.0.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -284,8 +278,8 @@ object LogisticRegressionWithSGD { * @param stepSize Step size to be used for each iteration of gradient descent. * @param miniBatchFraction Fraction of data to be used per iteration. - * @since 1.0.0 */ + @Since("1.0.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -306,8 +300,8 @@ object LogisticRegressionWithSGD { * @param numIterations Number of iterations of gradient descent to run. * @return a LogisticRegressionModel which has the weights and offset from training. - * @since 1.0.0 */ + @Since("1.0.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -324,8 +318,8 @@ object LogisticRegressionWithSGD { * @param input RDD of (label, array of features) pairs. * @param numIterations Number of iterations of gradient descent to run. 
* @return a LogisticRegressionModel which has the weights and offset from training. - * @since 1.0.0 */ + @Since("1.0.0") def train( input: RDD[LabeledPoint], numIterations: Int): LogisticRegressionModel = { @@ -339,11 +333,13 @@ object LogisticRegressionWithSGD { * NOTE: Labels used in Logistic Regression should be {0, 1, ..., k - 1} * for k classes multi-label classification problem. */ +@Since("1.1.0") class LogisticRegressionWithLBFGS extends GeneralizedLinearAlgorithm[LogisticRegressionModel] with Serializable { this.setFeatureScaling(true) + @Since("1.1.0") override val optimizer = new LBFGS(new LogisticGradient, new SquaredL2Updater) override protected val validators = List(multiLabelValidator) @@ -357,13 +353,11 @@ class LogisticRegressionWithLBFGS } /** - * :: Experimental :: * Set the number of possible outcomes for k classes classification problem in * Multinomial Logistic Regression. * By default, it is binary logistic regression so k will be set to 2. - * @since 1.3.0 */ - @Experimental + @Since("1.3.0") def setNumClasses(numClasses: Int): this.type = { require(numClasses > 1) numOfLinearPredictor = numClasses - 1 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 2df91c09421e..aef9ef2cb052 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -25,6 +25,7 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, SparkContext, SparkException} +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, SparseVector, Vector} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{Loader, Saveable} @@ -40,11 +41,12 @@ import org.apache.spark.sql.{DataFrame, SQLContext} * where D is number of features * @param modelType The type of NB model to fit can be "multinomial" or "bernoulli" */ +@Since("0.9.0") class NaiveBayesModel private[spark] ( - val labels: Array[Double], - val pi: Array[Double], - val theta: Array[Array[Double]], - val modelType: String) + @Since("1.0.0") val labels: Array[Double], + @Since("0.9.0") val pi: Array[Double], + @Since("0.9.0") val theta: Array[Array[Double]], + @Since("1.4.0") val modelType: String) extends ClassificationModel with Serializable with Saveable { import NaiveBayes.{Bernoulli, Multinomial, supportedModelTypes} @@ -82,6 +84,7 @@ class NaiveBayesModel private[spark] ( throw new UnknownError(s"Invalid modelType: $modelType.") } + @Since("1.0.0") override def predict(testData: RDD[Vector]): RDD[Double] = { val bcModel = testData.context.broadcast(this) testData.mapPartitions { iter => @@ -90,6 +93,7 @@ class NaiveBayesModel private[spark] ( } } + @Since("1.0.0") override def predict(testData: Vector): Double = { modelType match { case Multinomial => @@ -106,6 +110,7 @@ class NaiveBayesModel private[spark] ( * @return an RDD[Vector] where each entry contains the predicted posterior class probabilities, * in the same order as class labels */ + @Since("1.5.0") def predictProbabilities(testData: RDD[Vector]): RDD[Vector] = { val bcModel = testData.context.broadcast(this) testData.mapPartitions { iter => @@ -121,6 +126,7 @@ class NaiveBayesModel private[spark] ( * @return predicted posterior class probabilities from the trained model, * in the same order as class labels */ + 
@Since("1.5.0") def predictProbabilities(testData: Vector): Vector = { modelType match { case Multinomial => @@ -157,6 +163,7 @@ class NaiveBayesModel private[spark] ( new DenseVector(scaledProbs.map(_ / probSum)) } + @Since("1.3.0") override def save(sc: SparkContext, path: String): Unit = { val data = NaiveBayesModel.SaveLoadV2_0.Data(labels, pi, theta, modelType) NaiveBayesModel.SaveLoadV2_0.save(sc, path, data) @@ -165,6 +172,7 @@ class NaiveBayesModel private[spark] ( override protected def formatVersion: String = "2.0" } +@Since("1.3.0") object NaiveBayesModel extends Loader[NaiveBayesModel] { import org.apache.spark.mllib.util.Loader._ @@ -184,7 +192,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { modelType: String) def save(sc: SparkContext, path: String, data: Data): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // Create JSON metadata. @@ -198,8 +206,9 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { dataRDD.write.parquet(dataPath(path)) } + @Since("1.3.0") def load(sc: SparkContext, path: String): NaiveBayesModel = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) // Load Parquet data. val dataRDD = sqlContext.read.parquet(dataPath(path)) // Check schema explicitly since erasure makes it hard to use match-case for checking. @@ -230,7 +239,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { theta: Array[Array[Double]]) def save(sc: SparkContext, path: String, data: Data): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // Create JSON metadata. @@ -245,7 +254,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { } def load(sc: SparkContext, path: String): NaiveBayesModel = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) // Load Parquet data. val dataRDD = sqlContext.read.parquet(dataPath(path)) // Check schema explicitly since erasure makes it hard to use match-case for checking. @@ -300,30 +309,35 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { * document classification. By making every vector a 0-1 vector, it can also be used as * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The input feature values must be nonnegative. */ - +@Since("0.9.0") class NaiveBayes private ( private var lambda: Double, private var modelType: String) extends Serializable with Logging { import NaiveBayes.{Bernoulli, Multinomial} + @Since("1.4.0") def this(lambda: Double) = this(lambda, NaiveBayes.Multinomial) + @Since("0.9.0") def this() = this(1.0, NaiveBayes.Multinomial) /** Set the smoothing parameter. Default: 1.0. */ + @Since("0.9.0") def setLambda(lambda: Double): NaiveBayes = { this.lambda = lambda this } /** Get the smoothing parameter. */ + @Since("1.4.0") def getLambda: Double = lambda /** * Set the model type using a string (case-sensitive). * Supported options: "multinomial" (default) and "bernoulli". */ + @Since("1.4.0") def setModelType(modelType: String): NaiveBayes = { require(NaiveBayes.supportedModelTypes.contains(modelType), s"NaiveBayes was created with an unknown modelType: $modelType.") @@ -332,6 +346,7 @@ class NaiveBayes private ( } /** Get the model type. */ + @Since("1.4.0") def getModelType: String = this.modelType /** @@ -339,6 +354,7 @@ class NaiveBayes private ( * * @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. 
*/ + @Since("0.9.0") def run(data: RDD[LabeledPoint]): NaiveBayesModel = { val requireNonnegativeValues: Vector => Unit = (v: Vector) => { val values = v match { @@ -422,6 +438,7 @@ class NaiveBayes private ( /** * Top-level methods for calling naive Bayes. */ +@Since("0.9.0") object NaiveBayes { /** String name for multinomial model type. */ @@ -444,8 +461,8 @@ object NaiveBayes { * * @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency * vector or a count vector. - * @since 0.9.0 */ + @Since("0.9.0") def train(input: RDD[LabeledPoint]): NaiveBayesModel = { new NaiveBayes().run(input) } @@ -460,8 +477,8 @@ object NaiveBayes { * @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency * vector or a count vector. * @param lambda The smoothing parameter - * @since 0.9.0 */ + @Since("0.9.0") def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = { new NaiveBayes(lambda, Multinomial).run(input) } @@ -483,8 +500,8 @@ object NaiveBayes { * * @param modelType The type of NB model to fit from the enumeration NaiveBayesModels, can be * multinomial or bernoulli - * @since 0.9.0 */ + @Since("1.4.0") def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = { require(supportedModelTypes.contains(modelType), s"NaiveBayes was created with an unknown modelType: $modelType.") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index 5b54feeb1046..a8d3fd4177a2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.classification import org.apache.spark.SparkContext -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.Since import org.apache.spark.mllib.classification.impl.GLMClassificationModel import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ @@ -33,41 +33,36 @@ import org.apache.spark.rdd.RDD * @param weights Weights computed for every feature. * @param intercept Intercept computed for this model. */ -class SVMModel ( - override val weights: Vector, - override val intercept: Double) +@Since("0.8.0") +class SVMModel @Since("1.1.0") ( + @Since("1.0.0") override val weights: Vector, + @Since("0.8.0") override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable with Saveable with PMMLExportable { private var threshold: Option[Double] = Some(0.0) /** - * :: Experimental :: * Sets the threshold that separates positive predictions from negative predictions. An example * with prediction score greater than or equal to this threshold is identified as an positive, * and negative otherwise. The default value is 0.0. - * @since 1.3.0 */ - @Experimental + @Since("1.0.0") def setThreshold(threshold: Double): this.type = { this.threshold = Some(threshold) this } /** - * :: Experimental :: * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions. - * @since 1.3.0 */ - @Experimental + @Since("1.3.0") def getThreshold: Option[Double] = threshold /** - * :: Experimental :: * Clears the threshold so that `predict` will output raw prediction scores. 
- * @since 1.0.0 */ - @Experimental + @Since("1.0.0") def clearThreshold(): this.type = { threshold = None this @@ -84,9 +79,7 @@ class SVMModel ( } } - /** - * @since 1.3.0 - */ + @Since("1.3.0") override def save(sc: SparkContext, path: String): Unit = { GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, numFeatures = weights.size, numClasses = 2, weights, intercept, threshold) @@ -94,19 +87,15 @@ class SVMModel ( override protected def formatVersion: String = "1.0" - /** - * @since 1.4.0 - */ override def toString: String = { s"${super.toString}, numClasses = 2, threshold = ${threshold.getOrElse("None")}" } } +@Since("1.3.0") object SVMModel extends Loader[SVMModel] { - /** - * @since 1.3.0 - */ + @Since("1.3.0") override def load(sc: SparkContext, path: String): SVMModel = { val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) // Hard-code class name string in case it changes in the future @@ -138,6 +127,7 @@ object SVMModel extends Loader[SVMModel] { * regularization is used, which can be changed via [[SVMWithSGD.optimizer]]. * NOTE: Labels used in SVM should be {0, 1}. */ +@Since("0.8.0") class SVMWithSGD private ( private var stepSize: Double, private var numIterations: Int, @@ -147,6 +137,7 @@ class SVMWithSGD private ( private val gradient = new HingeGradient() private val updater = new SquaredL2Updater() + @Since("0.8.0") override val optimizer = new GradientDescent(gradient, updater) .setStepSize(stepSize) .setNumIterations(numIterations) @@ -158,6 +149,7 @@ class SVMWithSGD private ( * Construct a SVM object with default parameters: {stepSize: 1.0, numIterations: 100, * regParm: 0.01, miniBatchFraction: 1.0}. */ + @Since("0.8.0") def this() = this(1.0, 100, 0.01, 1.0) override protected def createModel(weights: Vector, intercept: Double) = { @@ -168,6 +160,7 @@ class SVMWithSGD private ( /** * Top-level methods for calling SVM. NOTE: Labels used in SVM should be {0, 1}. */ +@Since("0.8.0") object SVMWithSGD { /** @@ -185,8 +178,8 @@ object SVMWithSGD { * @param miniBatchFraction Fraction of data to be used per iteration. * @param initialWeights Initial set of weights to be used. Array should be equal in size to * the number of features in the data. - * @since 0.8.0 */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -209,8 +202,8 @@ object SVMWithSGD { * @param stepSize Step size to be used for each iteration of gradient descent. * @param regParam Regularization parameter. * @param miniBatchFraction Fraction of data to be used per iteration. - * @since 0.8.0 */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -231,8 +224,8 @@ object SVMWithSGD { * @param regParam Regularization parameter. * @param numIterations Number of iterations of gradient descent to run. * @return a SVMModel which has the weights and offset from training. - * @since 0.8.0 */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -250,8 +243,8 @@ object SVMWithSGD { * @param input RDD of (label, array of features) pairs. * @param numIterations Number of iterations of gradient descent to run. * @return a SVMModel which has the weights and offset from training. 
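A short sketch of the SVMModel threshold handling documented above, assuming an active SparkContext `sc` and a toy {0, 1}-labeled training set.

import org.apache.spark.mllib.classification.SVMWithSGD
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

// Assumed: an active SparkContext `sc`.
val training = sc.parallelize(Seq(
  LabeledPoint(1.0, Vectors.dense(2.0, 3.0)),
  LabeledPoint(0.0, Vectors.dense(-2.0, -3.0))))

// numIterations = 100, stepSize = 1.0, regParam = 0.01
val svmModel = SVMWithSGD.train(training, 100, 1.0, 0.01)

svmModel.clearThreshold()                          // predict() now returns raw margins
val rawScore = svmModel.predict(Vectors.dense(1.0, 2.0))

svmModel.setThreshold(0.0)                         // back to 0/1 predictions
val predicted = svmModel.predict(Vectors.dense(1.0, 2.0))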
- * @since 0.8.0 */ + @Since("0.8.0") def train(input: RDD[LabeledPoint], numIterations: Int): SVMModel = { train(input, numIterations, 1.0, 0.01, 1.0) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala index 7d33df3221fb..47bff5ebdde4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.classification -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.StreamingLinearAlgorithm /** - * :: Experimental :: * Train or predict a logistic regression model on streaming data. Training uses * Stochastic Gradient Descent to update the model based on each new batch of * incoming data from a DStream (see `LogisticRegressionWithSGD` for model equation) @@ -43,7 +42,7 @@ import org.apache.spark.mllib.regression.StreamingLinearAlgorithm * .trainOn(DStream) * }}} */ -@Experimental +@Since("1.3.0") class StreamingLogisticRegressionWithSGD private[mllib] ( private var stepSize: Double, private var numIterations: Int, @@ -58,6 +57,7 @@ class StreamingLogisticRegressionWithSGD private[mllib] ( * Initial weights must be set before using trainOn or predictOn * (see `StreamingLinearAlgorithm`) */ + @Since("1.3.0") def this() = this(0.1, 50, 1.0, 0.0) protected val algorithm = new LogisticRegressionWithSGD( @@ -66,30 +66,35 @@ class StreamingLogisticRegressionWithSGD private[mllib] ( protected var model: Option[LogisticRegressionModel] = None /** Set the step size for gradient descent. Default: 0.1. */ + @Since("1.3.0") def setStepSize(stepSize: Double): this.type = { this.algorithm.optimizer.setStepSize(stepSize) this } /** Set the number of iterations of gradient descent to run per update. Default: 50. */ + @Since("1.3.0") def setNumIterations(numIterations: Int): this.type = { this.algorithm.optimizer.setNumIterations(numIterations) this } /** Set the fraction of each batch to use for updates. Default: 1.0. */ + @Since("1.3.0") def setMiniBatchFraction(miniBatchFraction: Double): this.type = { this.algorithm.optimizer.setMiniBatchFraction(miniBatchFraction) this } /** Set the regularization parameter. Default: 0.0. */ + @Since("1.3.0") def setRegParam(regParam: Double): this.type = { this.algorithm.optimizer.setRegParam(regParam) this } /** Set the initial weights. Default: [0.0, 0.0]. 
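The streaming setters above compose in the usual builder style. A hedged sketch follows; `trainingStream` (a DStream[LabeledPoint]), `testStream` (a DStream[Vector]) and `numFeatures` are assumed to exist in the surrounding streaming application.

import org.apache.spark.mllib.classification.StreamingLogisticRegressionWithSGD
import org.apache.spark.mllib.linalg.Vectors

// Assumed: trainingStream: DStream[LabeledPoint], testStream: DStream[Vector], numFeatures: Int.
val streamingModel = new StreamingLogisticRegressionWithSGD()
  .setInitialWeights(Vectors.zeros(numFeatures))
  .setStepSize(0.5)
  .setNumIterations(10)
  .setMiniBatchFraction(1.0)
  .setRegParam(0.0)

streamingModel.trainOn(trainingStream)        // update the model on every new batch
streamingModel.predictOn(testStream).print()  // emit predictions for each test batch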
*/ + @Since("1.3.0") def setInitialWeights(initialWeights: Vector): this.type = { this.model = Some(algorithm.createModel(initialWeights, 0.0)) this diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala index fe09f6b75d28..2910c027ae06 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala @@ -51,7 +51,7 @@ private[classification] object GLMClassificationModel { weights: Vector, intercept: Double, threshold: Option[Double]): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // Create JSON metadata. @@ -74,7 +74,7 @@ private[classification] object GLMClassificationModel { */ def loadData(sc: SparkContext, path: String, modelClass: String): Data = { val datapath = Loader.dataPath(path) - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val dataRDD = sqlContext.read.parquet(datapath) val dataArray = dataRDD.select("weights", "intercept", "threshold").take(1) assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala new file mode 100644 index 000000000000..54bf5102cc56 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala @@ -0,0 +1,491 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import java.util.Random + +import scala.collection.mutable + +import org.apache.spark.Logging +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel + +/** + * A bisecting k-means algorithm based on the paper "A comparison of document clustering techniques" + * by Steinbach, Karypis, and Kumar, with modification to fit Spark. + * The algorithm starts from a single cluster that contains all points. + * Iteratively it finds divisible clusters on the bottom level and bisects each of them using + * k-means, until there are `k` leaf clusters in total or no leaf clusters are divisible. + * The bisecting steps of clusters on the same level are grouped together to increase parallelism. 
+ * If bisecting all divisible clusters on the bottom level would result in more than `k` leaf clusters, + * larger clusters get higher priority. + * + * @param k the desired number of leaf clusters (default: 4). The actual number could be smaller if + * there are no divisible leaf clusters. + * @param maxIterations the max number of k-means iterations to split clusters (default: 20) + * @param minDivisibleClusterSize the minimum number of points (if >= 1.0) or the minimum proportion + * of points (if < 1.0) of a divisible cluster (default: 1) + * @param seed a random seed (default: hash value of the class name) + * + * @see [[http://glaros.dtc.umn.edu/gkhome/fetch/papers/docclusterKDDTMW00.pdf + * Steinbach, Karypis, and Kumar, A comparison of document clustering techniques, + * KDD Workshop on Text Mining, 2000.]] + */ +@Since("1.6.0") +@Experimental +class BisectingKMeans private ( + private var k: Int, + private var maxIterations: Int, + private var minDivisibleClusterSize: Double, + private var seed: Long) extends Logging { + + import BisectingKMeans._ + + /** + * Constructs a BisectingKMeans instance with the default configuration. + */ + @Since("1.6.0") + def this() = this(4, 20, 1.0, classOf[BisectingKMeans].getName.##) + + /** + * Sets the desired number of leaf clusters (default: 4). + * The actual number could be smaller if there are no divisible leaf clusters. + */ + @Since("1.6.0") + def setK(k: Int): this.type = { + require(k > 0, s"k must be positive but got $k.") + this.k = k + this + } + + /** + * Gets the desired number of leaf clusters. + */ + @Since("1.6.0") + def getK: Int = this.k + + /** + * Sets the max number of k-means iterations to split clusters (default: 20). + */ + @Since("1.6.0") + def setMaxIterations(maxIterations: Int): this.type = { + require(maxIterations > 0, s"maxIterations must be positive but got $maxIterations.") + this.maxIterations = maxIterations + this + } + + /** + * Gets the max number of k-means iterations to split clusters. + */ + @Since("1.6.0") + def getMaxIterations: Int = this.maxIterations + + /** + * Sets the minimum number of points (if >= `1.0`) or the minimum proportion of points + * (if < `1.0`) of a divisible cluster (default: 1). + */ + @Since("1.6.0") + def setMinDivisibleClusterSize(minDivisibleClusterSize: Double): this.type = { + require(minDivisibleClusterSize > 0.0, + s"minDivisibleClusterSize must be positive but got $minDivisibleClusterSize.") + this.minDivisibleClusterSize = minDivisibleClusterSize + this + } + + /** + * Gets the minimum number of points (if >= `1.0`) or the minimum proportion of points + * (if < `1.0`) of a divisible cluster. + */ + @Since("1.6.0") + def getMinDivisibleClusterSize: Double = minDivisibleClusterSize + + /** + * Sets the random seed (default: hash value of the class name). + */ + @Since("1.6.0") + def setSeed(seed: Long): this.type = { + this.seed = seed + this + } + + /** + * Gets the random seed. + */ + @Since("1.6.0") + def getSeed: Long = this.seed + + /** + * Runs the bisecting k-means algorithm. + * @param input RDD of vectors + * @return model for the bisecting k-means + */ + @Since("1.6.0") + def run(input: RDD[Vector]): BisectingKMeansModel = { + if (input.getStorageLevel == StorageLevel.NONE) { + logWarning(s"The input RDD ${input.id} is not directly cached, which may hurt performance if" + + " its parent RDDs are also not cached.") + } + val d = input.map(_.size).first() + logInfo(s"Feature dimension: $d.") + // Compute and cache vector norms for fast distance computation.
+ val norms = input.map(v => Vectors.norm(v, 2.0)).persist(StorageLevel.MEMORY_AND_DISK) + val vectors = input.zip(norms).map { case (x, norm) => new VectorWithNorm(x, norm) } + var assignments = vectors.map(v => (ROOT_INDEX, v)) + var activeClusters = summarize(d, assignments) + val rootSummary = activeClusters(ROOT_INDEX) + val n = rootSummary.size + logInfo(s"Number of points: $n.") + logInfo(s"Initial cost: ${rootSummary.cost}.") + val minSize = if (minDivisibleClusterSize >= 1.0) { + math.ceil(minDivisibleClusterSize).toLong + } else { + math.ceil(minDivisibleClusterSize * n).toLong + } + logInfo(s"The minimum number of points of a divisible cluster is $minSize.") + var inactiveClusters = mutable.Seq.empty[(Long, ClusterSummary)] + val random = new Random(seed) + var numLeafClustersNeeded = k - 1 + var level = 1 + while (activeClusters.nonEmpty && numLeafClustersNeeded > 0 && level < LEVEL_LIMIT) { + // Divisible clusters are sufficiently large and have non-trivial cost. + var divisibleClusters = activeClusters.filter { case (_, summary) => + (summary.size >= minSize) && (summary.cost > MLUtils.EPSILON * summary.size) + } + // If we don't need all divisible clusters, take the larger ones. + if (divisibleClusters.size > numLeafClustersNeeded) { + divisibleClusters = divisibleClusters.toSeq.sortBy { case (_, summary) => + -summary.size + }.take(numLeafClustersNeeded) + .toMap + } + if (divisibleClusters.nonEmpty) { + val divisibleIndices = divisibleClusters.keys.toSet + logInfo(s"Dividing ${divisibleIndices.size} clusters on level $level.") + var newClusterCenters = divisibleClusters.flatMap { case (index, summary) => + val (left, right) = splitCenter(summary.center, random) + Iterator((leftChildIndex(index), left), (rightChildIndex(index), right)) + }.map(identity) // workaround for a Scala bug (SI-7005) that produces a not serializable map + var newClusters: Map[Long, ClusterSummary] = null + var newAssignments: RDD[(Long, VectorWithNorm)] = null + for (iter <- 0 until maxIterations) { + newAssignments = updateAssignments(assignments, divisibleIndices, newClusterCenters) + .filter { case (index, _) => + divisibleIndices.contains(parentIndex(index)) + } + newClusters = summarize(d, newAssignments) + newClusterCenters = newClusters.mapValues(_.center).map(identity) + } + // TODO: Unpersist old indices. + val indices = updateAssignments(assignments, divisibleIndices, newClusterCenters).keys + .persist(StorageLevel.MEMORY_AND_DISK) + assignments = indices.zip(vectors) + inactiveClusters ++= activeClusters + activeClusters = newClusters + numLeafClustersNeeded -= divisibleClusters.size + } else { + logInfo(s"None active and divisible clusters left on level $level. Stop iterations.") + inactiveClusters ++= activeClusters + activeClusters = Map.empty + } + level += 1 + } + val clusters = activeClusters ++ inactiveClusters + val root = buildTree(clusters) + new BisectingKMeansModel(root) + } + + /** + * Java-friendly version of [[run()]]. + */ + def run(data: JavaRDD[Vector]): BisectingKMeansModel = run(data.rdd) +} + +private object BisectingKMeans extends Serializable { + + /** The index of the root node of a tree. */ + private val ROOT_INDEX: Long = 1 + + private val MAX_DIVISIBLE_CLUSTER_INDEX: Long = Long.MaxValue / 2 + + private val LEVEL_LIMIT = math.log10(Long.MaxValue) / math.log10(2) + + /** Returns the left child index of the given node index. 
*/ + private def leftChildIndex(index: Long): Long = { + require(index <= MAX_DIVISIBLE_CLUSTER_INDEX, s"Child index out of bound: 2 * $index.") + 2 * index + } + + /** Returns the right child index of the given node index. */ + private def rightChildIndex(index: Long): Long = { + require(index <= MAX_DIVISIBLE_CLUSTER_INDEX, s"Child index out of bound: 2 * $index + 1.") + 2 * index + 1 + } + + /** Returns the parent index of the given node index, or 0 if the input is 1 (root). */ + private def parentIndex(index: Long): Long = { + index / 2 + } + + /** + * Summarizes data by each cluster as Map. + * @param d feature dimension + * @param assignments pairs of point and its cluster index + * @return a map from cluster indices to corresponding cluster summaries + */ + private def summarize( + d: Int, + assignments: RDD[(Long, VectorWithNorm)]): Map[Long, ClusterSummary] = { + assignments.aggregateByKey(new ClusterSummaryAggregator(d))( + seqOp = (agg, v) => agg.add(v), + combOp = (agg1, agg2) => agg1.merge(agg2) + ).mapValues(_.summary) + .collect().toMap + } + + /** + * Cluster summary aggregator. + * @param d feature dimension + */ + private class ClusterSummaryAggregator(val d: Int) extends Serializable { + private var n: Long = 0L + private val sum: Vector = Vectors.zeros(d) + private var sumSq: Double = 0.0 + + /** Adds a point. */ + def add(v: VectorWithNorm): this.type = { + n += 1L + // TODO: use a numerically stable approach to estimate cost + sumSq += v.norm * v.norm + BLAS.axpy(1.0, v.vector, sum) + this + } + + /** Merges another aggregator. */ + def merge(other: ClusterSummaryAggregator): this.type = { + n += other.n + sumSq += other.sumSq + BLAS.axpy(1.0, other.sum, sum) + this + } + + /** Returns the summary. */ + def summary: ClusterSummary = { + val mean = sum.copy + if (n > 0L) { + BLAS.scal(1.0 / n, mean) + } + val center = new VectorWithNorm(mean) + val cost = math.max(sumSq - n * center.norm * center.norm, 0.0) + new ClusterSummary(n, center, cost) + } + } + + /** + * Bisects a cluster center. + * + * @param center current cluster center + * @param random a random number generator + * @return initial centers + */ + private def splitCenter( + center: VectorWithNorm, + random: Random): (VectorWithNorm, VectorWithNorm) = { + val d = center.vector.size + val norm = center.norm + val level = 1e-4 * norm + val noise = Vectors.dense(Array.fill(d)(random.nextDouble())) + val left = center.vector.copy + BLAS.axpy(-level, noise, left) + val right = center.vector.copy + BLAS.axpy(level, noise, right) + (new VectorWithNorm(left), new VectorWithNorm(right)) + } + + /** + * Updates assignments. + * @param assignments current assignments + * @param divisibleIndices divisible cluster indices + * @param newClusterCenters new cluster centers + * @return new assignments + */ + private def updateAssignments( + assignments: RDD[(Long, VectorWithNorm)], + divisibleIndices: Set[Long], + newClusterCenters: Map[Long, VectorWithNorm]): RDD[(Long, VectorWithNorm)] = { + assignments.map { case (index, v) => + if (divisibleIndices.contains(index)) { + val children = Seq(leftChildIndex(index), rightChildIndex(index)) + val selected = children.minBy { child => + KMeans.fastSquaredDistance(newClusterCenters(child), v) + } + (selected, v) + } else { + (index, v) + } + } + } + + /** + * Builds a clustering tree by re-indexing internal and leaf clusters. 
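The index helpers above are the standard 1-based binary-heap numbering; a tiny self-contained restatement (the function names here are local to this sketch, not part of the patch):

// Root is node 1; node i has children 2*i and 2*i + 1, and parent i / 2.
def leftChild(i: Long): Long = 2 * i
def rightChild(i: Long): Long = 2 * i + 1
def parent(i: Long): Long = i / 2

// For example, node 5 has children 10 and 11, and both report 5 as their parent.
assert(leftChild(5L) == 10L && rightChild(5L) == 11L)
assert(parent(10L) == 5L && parent(11L) == 5L)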
+ * @param clusters a map from cluster indices to corresponding cluster summaries + * @return the root node of the clustering tree + */ + private def buildTree(clusters: Map[Long, ClusterSummary]): ClusteringTreeNode = { + var leafIndex = 0 + var internalIndex = -1 + + /** + * Builds a subtree from the given node index. + */ + def buildSubTree(rawIndex: Long): ClusteringTreeNode = { + val cluster = clusters(rawIndex) + val size = cluster.size + val center = cluster.center + val cost = cluster.cost + val isInternal = clusters.contains(leftChildIndex(rawIndex)) + if (isInternal) { + val index = internalIndex + internalIndex -= 1 + val leftIndex = leftChildIndex(rawIndex) + val rightIndex = rightChildIndex(rawIndex) + val height = math.sqrt(Seq(leftIndex, rightIndex).map { childIndex => + KMeans.fastSquaredDistance(center, clusters(childIndex).center) + }.max) + val left = buildSubTree(leftIndex) + val right = buildSubTree(rightIndex) + new ClusteringTreeNode(index, size, center, cost, height, Array(left, right)) + } else { + val index = leafIndex + leafIndex += 1 + val height = 0.0 + new ClusteringTreeNode(index, size, center, cost, height, Array.empty) + } + } + + buildSubTree(ROOT_INDEX) + } + + /** + * Summary of a cluster. + * + * @param size the number of points within this cluster + * @param center the center of the points within this cluster + * @param cost the sum of squared distances to the center + */ + private case class ClusterSummary(size: Long, center: VectorWithNorm, cost: Double) +} + +/** + * Represents a node in a clustering tree. + * + * @param index node index, negative for internal nodes and non-negative for leaf nodes + * @param size size of the cluster + * @param centerWithNorm cluster center with norm + * @param cost cost of the cluster, i.e., the sum of squared distances to the center + * @param height height of the node in the dendrogram. Currently this is defined as the max distance + * from the center to the centers of its children, but subject to change. + * @param children children nodes + */ +@Since("1.6.0") +@Experimental +private[clustering] class ClusteringTreeNode private[clustering] ( + val index: Int, + val size: Long, + private val centerWithNorm: VectorWithNorm, + val cost: Double, + val height: Double, + val children: Array[ClusteringTreeNode]) extends Serializable { + + /** Whether this is a leaf node. */ + val isLeaf: Boolean = children.isEmpty + + require((isLeaf && index >= 0) || (!isLeaf && index < 0)) + + /** Cluster center. */ + def center: Vector = centerWithNorm.vector + + /** Predicts the leaf cluster node index that the input point belongs to. */ + def predict(point: Vector): Int = { + val (index, _) = predict(new VectorWithNorm(point)) + index + } + + /** Returns the full prediction path from root to leaf. */ + def predictPath(point: Vector): Array[ClusteringTreeNode] = { + predictPath(new VectorWithNorm(point)).toArray + } + + /** Returns the full prediction path from root to leaf. */ + private def predictPath(pointWithNorm: VectorWithNorm): List[ClusteringTreeNode] = { + if (isLeaf) { + this :: Nil + } else { + val selected = children.minBy { child => + KMeans.fastSquaredDistance(child.centerWithNorm, pointWithNorm) + } + selected :: selected.predictPath(pointWithNorm) + } + } + + /** + * Computes the cost (squared distance to the predicted leaf cluster center) of the input point.
+ */ + def computeCost(point: Vector): Double = { + val (_, cost) = predict(new VectorWithNorm(point)) + cost + } + + /** + * Predicts the cluster index and the cost of the input point. + */ + private def predict(pointWithNorm: VectorWithNorm): (Int, Double) = { + predict(pointWithNorm, KMeans.fastSquaredDistance(centerWithNorm, pointWithNorm)) + } + + /** + * Predicts the cluster index and the cost of the input point. + * @param pointWithNorm input point + * @param cost the cost to the current center + * @return (predicted leaf cluster index, cost) + */ + private def predict(pointWithNorm: VectorWithNorm, cost: Double): (Int, Double) = { + if (isLeaf) { + (index, cost) + } else { + val (selectedChild, minCost) = children.map { child => + (child, KMeans.fastSquaredDistance(child.centerWithNorm, pointWithNorm)) + }.minBy(_._2) + selectedChild.predict(pointWithNorm, minCost) + } + } + + /** + * Returns all leaf nodes from this node. + */ + def leafNodes: Array[ClusteringTreeNode] = { + if (isLeaf) { + Array(this) + } else { + children.flatMap(_.leafNodes) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala new file mode 100644 index 000000000000..9ccf96b9395b --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import org.apache.spark.Logging +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.rdd.RDD + +/** + * Clustering model produced by [[BisectingKMeans]]. + * The prediction is done level-by-level from the root node to a leaf node, and at each node among + * its children the closest to the input point is selected. + * + * @param root the root node of the clustering tree + */ +@Since("1.6.0") +@Experimental +class BisectingKMeansModel private[clustering] ( + private[clustering] val root: ClusteringTreeNode + ) extends Serializable with Logging { + + /** + * Leaf cluster centers. + */ + @Since("1.6.0") + def clusterCenters: Array[Vector] = root.leafNodes.map(_.center) + + /** + * Number of leaf clusters. + */ + lazy val k: Int = clusterCenters.length + + /** + * Predicts the index of the cluster that the input point belongs to. + */ + @Since("1.6.0") + def predict(point: Vector): Int = { + root.predict(point) + } + + /** + * Predicts the indices of the clusters that the input points belong to. 
+ */ + @Since("1.6.0") + def predict(points: RDD[Vector]): RDD[Int] = { + points.map { p => root.predict(p) } + } + + /** + * Java-friendly version of [[predict()]]. + */ + @Since("1.6.0") + def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] = + predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]] + + /** + * Computes the squared distance between the input point and the cluster center it belongs to. + */ + @Since("1.6.0") + def computeCost(point: Vector): Double = { + root.computeCost(point) + } + + /** + * Computes the sum of squared distances between the input points and their corresponding cluster + * centers. + */ + @Since("1.6.0") + def computeCost(data: RDD[Vector]): Double = { + data.map(root.computeCost).sum() + } + + /** + * Java-friendly version of [[computeCost()]]. + */ + @Since("1.6.0") + def computeCost(data: JavaRDD[Vector]): Double = this.computeCost(data.rdd) +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index e459367333d2..7b203e2f4081 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -21,7 +21,7 @@ import scala.collection.mutable.IndexedSeq import breeze.linalg.{diag, DenseMatrix => BreezeMatrix, DenseVector => BDV, Vector => BV} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, Matrices, Vector, Vectors} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian @@ -30,8 +30,6 @@ import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils /** - * :: Experimental :: - * * This class performs expectation maximization for multivariate Gaussian * Mixture Models (GMMs). A GMM represents a composite distribution of * independent Gaussian distributions with associated "mixing" weights @@ -52,7 +50,7 @@ import org.apache.spark.util.Utils * is considered to have occurred. * @param maxIterations The maximum number of iterations to perform */ -@Experimental +@Since("1.3.0") class GaussianMixture private ( private var k: Int, private var convergenceTol: Double, @@ -63,6 +61,7 @@ class GaussianMixture private ( * Constructs a default instance. The default parameters are {k: 2, convergenceTol: 0.01, * maxIterations: 100, seed: random}. */ + @Since("1.3.0") def this() = this(2, 0.01, 100, Utils.random.nextLong()) // number of samples per cluster to use when initializing Gaussians @@ -72,10 +71,12 @@ class GaussianMixture private ( // default random starting point private var initialModel: Option[GaussianMixtureModel] = None - /** Set the initial GMM starting point, bypassing the random initialization. - * You must call setK() prior to calling this method, and the condition - * (model.k == this.k) must be met; failure will result in an IllegalArgumentException + /** + * Set the initial GMM starting point, bypassing the random initialization. 
+ * You must call setK() prior to calling this method, and the condition + * (model.k == this.k) must be met; failure will result in an IllegalArgumentException */ + @Since("1.3.0") def setInitialModel(model: GaussianMixtureModel): this.type = { if (model.k == k) { initialModel = Some(model) @@ -85,31 +86,47 @@ class GaussianMixture private ( this } - /** Return the user supplied initial GMM, if supplied */ + /** + * Return the user supplied initial GMM, if supplied + */ + @Since("1.3.0") def getInitialModel: Option[GaussianMixtureModel] = initialModel - /** Set the number of Gaussians in the mixture model. Default: 2 */ + /** + * Set the number of Gaussians in the mixture model. Default: 2 + */ + @Since("1.3.0") def setK(k: Int): this.type = { this.k = k this } - /** Return the number of Gaussians in the mixture model */ + /** + * Return the number of Gaussians in the mixture model + */ + @Since("1.3.0") def getK: Int = k - /** Set the maximum number of iterations to run. Default: 100 */ + /** + * Set the maximum number of iterations to run. Default: 100 + */ + @Since("1.3.0") def setMaxIterations(maxIterations: Int): this.type = { this.maxIterations = maxIterations this } - /** Return the maximum number of iterations to run */ + /** + * Return the maximum number of iterations to run + */ + @Since("1.3.0") def getMaxIterations: Int = maxIterations /** * Set the largest change in log-likelihood at which convergence is * considered to have occurred. */ + @Since("1.3.0") def setConvergenceTol(convergenceTol: Double): this.type = { this.convergenceTol = convergenceTol this @@ -119,18 +136,28 @@ class GaussianMixture private ( * Return the largest change in log-likelihood at which convergence is * considered to have occurred. */ + @Since("1.3.0") def getConvergenceTol: Double = convergenceTol - /** Set the random seed */ + /** + * Set the random seed + */ + @Since("1.3.0") def setSeed(seed: Long): this.type = { this.seed = seed this } - /** Return the random seed */ + /** + * Return the random seed + */ + @Since("1.3.0") def getSeed: Long = seed - /** Perform expectation maximization */ + /** + * Perform expectation maximization + */ + @Since("1.3.0") def run(data: RDD[Vector]): GaussianMixtureModel = { val sc = data.sparkContext @@ -140,9 +167,7 @@ class GaussianMixture private ( // Get length of the input vectors val d = breezeData.first().length - // Heuristic to distribute the computation of the [[MultivariateGaussian]]s, approximately when - // d > 25 except for when k is very small - val distributeGaussians = ((k - 1.0) / k) * d > 25 + val shouldDistributeGaussians = GaussianMixture.shouldDistributeGaussians(k, d) // Determine initial weights and corresponding Gaussians. 
// If the user supplied an initial GMM, we use those values, otherwise @@ -176,15 +201,15 @@ class GaussianMixture private ( // (often referred to as the "M" step in literature) val sumWeights = sums.weights.sum - if (distributeGaussians) { + if (shouldDistributeGaussians) { val numPartitions = math.min(k, 1024) val tuples = Seq.tabulate(k)(i => (sums.means(i), sums.sigmas(i), sums.weights(i))) val (ws, gs) = sc.parallelize(tuples, numPartitions).map { case (mean, sigma, weight) => updateWeightsAndGaussians(mean, sigma, weight, sumWeights) - }.collect.unzip - Array.copy(ws, 0, weights, 0, ws.length) - Array.copy(gs, 0, gaussians, 0, gs.length) + }.collect().unzip + Array.copy(ws.toArray, 0, weights, 0, ws.length) + Array.copy(gs.toArray, 0, gaussians, 0, gs.length) } else { var i = 0 while (i < k) { @@ -204,7 +229,10 @@ class GaussianMixture private ( new GaussianMixtureModel(weights, gaussians) } - /** Java-friendly version of [[run()]] */ + /** + * Java-friendly version of [[run()]] + */ + @Since("1.3.0") def run(data: JavaRDD[Vector]): GaussianMixtureModel = run(data.rdd) private def updateWeightsAndGaussians( @@ -239,6 +267,16 @@ class GaussianMixture private ( } } +private[clustering] object GaussianMixture { + /** + * Heuristic to distribute the computation of the [[MultivariateGaussian]]s, approximately when + * d > 25 except for when k is very small. + * @param k Number of Gaussians in the mixture + * @param d Number of features + */ + def shouldDistributeGaussians(k: Int, d: Int): Boolean = ((k - 1.0) / k) * d > 25 +} + // companion class to provide zero constructor for ExpectationSum private object ExpectationSum { def zero(k: Int, d: Int): ExpectationSum = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index cb807c803810..74d13e4f7794 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -24,7 +24,7 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{Vector, Matrices, Matrix} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian @@ -33,8 +33,6 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{SQLContext, Row} /** - * :: Experimental :: - * * Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points * are drawn from each Gaussian i=1..k with probability w(i); mu(i) and sigma(i) are * the respective mean and covariance for each Gaussian distribution i=1..k.
@@ -44,29 +42,48 @@ import org.apache.spark.sql.{SQLContext, Row} * @param gaussians Array of MultivariateGaussian where gaussians(i) represents * the Multivariate Gaussian (Normal) Distribution for Gaussian i */ -@Experimental -class GaussianMixtureModel( - val weights: Array[Double], - val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable { +@Since("1.3.0") +class GaussianMixtureModel @Since("1.3.0") ( + @Since("1.3.0") val weights: Array[Double], + @Since("1.3.0") val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable { require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must match") override protected def formatVersion = "1.0" + @Since("1.4.0") override def save(sc: SparkContext, path: String): Unit = { GaussianMixtureModel.SaveLoadV1_0.save(sc, path, weights, gaussians) } - /** Number of gaussians in mixture */ + /** + * Number of gaussians in mixture + */ + @Since("1.3.0") def k: Int = weights.length - /** Maps given points to their cluster indices. */ + /** + * Maps given points to their cluster indices. + */ + @Since("1.3.0") def predict(points: RDD[Vector]): RDD[Int] = { val responsibilityMatrix = predictSoft(points) responsibilityMatrix.map(r => r.indexOf(r.max)) } - /** Java-friendly version of [[predict()]] */ + /** + * Maps given point to its cluster index. + */ + @Since("1.5.0") + def predict(point: Vector): Int = { + val r = computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k) + r.indexOf(r.max) + } + + /** + * Java-friendly version of [[predict()]] + */ + @Since("1.4.0") def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] = predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]] @@ -74,6 +91,7 @@ class GaussianMixtureModel( * Given the input vectors, return the membership value of each vector * to all mixture components. */ + @Since("1.3.0") def predictSoft(points: RDD[Vector]): RDD[Array[Double]] = { val sc = points.sparkContext val bcDists = sc.broadcast(gaussians) @@ -83,6 +101,14 @@ class GaussianMixtureModel( } } + /** + * Given the input vector, return the membership values to all mixture components. + */ + @Since("1.4.0") + def predictSoft(point: Vector): Array[Double] = { + computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k) + } + /** * Compute the partial assignments for each vector */ @@ -102,7 +128,7 @@ class GaussianMixtureModel( } } -@Experimental +@Since("1.4.0") object GaussianMixtureModel extends Loader[GaussianMixtureModel] { private object SaveLoadV1_0 { @@ -119,7 +145,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] { weights: Array[Double], gaussians: Array[MultivariateGaussian]): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // Create JSON metadata. @@ -136,22 +162,22 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] { def load(sc: SparkContext, path: String): GaussianMixtureModel = { val dataPath = Loader.dataPath(path) - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val dataFrame = sqlContext.read.parquet(dataPath) - val dataArray = dataFrame.select("weight", "mu", "sigma").collect() - // Check schema explicitly since erasure makes it hard to use match-case for checking. 
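For reference, a hedged sketch of the GaussianMixture / GaussianMixtureModel round trip annotated here; `sc` and the toy data are assumptions of the example.

import org.apache.spark.mllib.clustering.GaussianMixture
import org.apache.spark.mllib.linalg.Vectors

// Assumed: an active SparkContext `sc`.
val gmmData = sc.parallelize(Seq(
  Vectors.dense(-5.0, -4.8), Vectors.dense(-4.9, -5.1),
  Vectors.dense(5.0, 5.2), Vectors.dense(4.8, 4.9))).cache()

val gmm = new GaussianMixture().setK(2).setConvergenceTol(0.01).setSeed(42L)
val gmmModel = gmm.run(gmmData)

val hardAssignment = gmmModel.predict(Vectors.dense(5.0, 5.0))      // cluster index
val softAssignment = gmmModel.predictSoft(Vectors.dense(5.0, 5.0))  // membership weights

gmmModel.gaussians.zip(gmmModel.weights).foreach { case (g, w) =>
  println(s"weight=$w mean=${g.mu} cov=${g.sigma}")
}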
Loader.checkSchema[Data](dataFrame.schema) + val dataArray = dataFrame.select("weight", "mu", "sigma").collect() val (weights, gaussians) = dataArray.map { case Row(weight: Double, mu: Vector, sigma: Matrix) => (weight, new MultivariateGaussian(mu, sigma)) }.unzip - return new GaussianMixtureModel(weights.toArray, gaussians.toArray) + new GaussianMixtureModel(weights.toArray, gaussians.toArray) } } + @Since("1.4.0") override def load(sc: SparkContext, path: String) : GaussianMixtureModel = { val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) implicit val formats = DefaultFormats diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 0a65403f4ec9..2895db7c9061 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering import scala.collection.mutable.ArrayBuffer import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS.{axpy, scal} import org.apache.spark.mllib.util.MLUtils @@ -37,6 +37,7 @@ import org.apache.spark.util.random.XORShiftRandom * This is an iterative algorithm that will make multiple passes over the data, so any RDDs given * to it should be cached by the user. */ +@Since("0.8.0") class KMeans private ( private var k: Int, private var maxIterations: Int, @@ -50,14 +51,19 @@ class KMeans private ( * Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20, runs: 1, * initializationMode: "k-means||", initializationSteps: 5, epsilon: 1e-4, seed: random}. */ + @Since("0.8.0") def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4, Utils.random.nextLong()) /** * Number of clusters to create (k). */ + @Since("1.4.0") def getK: Int = k - /** Set the number of clusters to create (k). Default: 2. */ + /** + * Set the number of clusters to create (k). Default: 2. + */ + @Since("0.8.0") def setK(k: Int): this.type = { this.k = k this @@ -66,9 +72,13 @@ class KMeans private ( /** * Maximum number of iterations to run. */ + @Since("1.4.0") def getMaxIterations: Int = maxIterations - /** Set maximum number of iterations to run. Default: 20. */ + /** + * Set maximum number of iterations to run. Default: 20. + */ + @Since("0.8.0") def setMaxIterations(maxIterations: Int): this.type = { this.maxIterations = maxIterations this @@ -77,6 +87,7 @@ class KMeans private ( /** * The initialization algorithm. This can be either "random" or "k-means||". */ + @Since("1.4.0") def getInitializationMode: String = initializationMode /** @@ -84,6 +95,7 @@ class KMeans private ( * initial cluster centers, or "k-means||" to use a parallel variant of k-means++ * (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||. */ + @Since("0.8.0") def setInitializationMode(initializationMode: String): this.type = { KMeans.validateInitMode(initializationMode) this.initializationMode = initializationMode @@ -94,7 +106,8 @@ class KMeans private ( * :: Experimental :: * Number of runs of the algorithm to execute in parallel. */ - @Experimental + @Since("1.4.0") + @deprecated("Support for runs is deprecated. 
This param will have no effect in 1.7.0.", "1.6.0") def getRuns: Int = runs /** @@ -103,7 +116,8 @@ class KMeans private ( * this many times with random starting conditions (configured by the initialization mode), then * return the best clustering found over any run. Default: 1. */ - @Experimental + @Since("0.8.0") + @deprecated("Support for runs is deprecated. This param will have no effect in 1.7.0.", "1.6.0") def setRuns(runs: Int): this.type = { if (runs <= 0) { throw new IllegalArgumentException("Number of runs must be positive") @@ -115,12 +129,14 @@ class KMeans private ( /** * Number of steps for the k-means|| initialization mode */ + @Since("1.4.0") def getInitializationSteps: Int = initializationSteps /** * Set the number of steps for the k-means|| initialization mode. This is an advanced * setting -- the default of 5 is almost always enough. Default: 5. */ + @Since("0.8.0") def setInitializationSteps(initializationSteps: Int): this.type = { if (initializationSteps <= 0) { throw new IllegalArgumentException("Number of initialization steps must be positive") @@ -132,12 +148,14 @@ class KMeans private ( /** * The distance threshold within which we've consider centers to have converged. */ + @Since("1.4.0") def getEpsilon: Double = epsilon /** * Set the distance threshold within which we've consider centers to have converged. * If all centers move less than this Euclidean distance, we stop iterating one run. */ + @Since("0.8.0") def setEpsilon(epsilon: Double): this.type = { this.epsilon = epsilon this @@ -146,9 +164,13 @@ class KMeans private ( /** * The random seed for cluster initialization. */ + @Since("1.4.0") def getSeed: Long = seed - /** Set the random seed for cluster initialization. */ + /** + * Set the random seed for cluster initialization. + */ + @Since("1.4.0") def setSeed(seed: Long): this.type = { this.seed = seed this @@ -163,6 +185,7 @@ class KMeans private ( * The condition model.k == this.k must be met, failure results * in an IllegalArgumentException. */ + @Since("1.4.0") def setInitialModel(model: KMeansModel): this.type = { require(model.k == k, "mismatched cluster count") initialModel = Some(model) @@ -173,6 +196,7 @@ class KMeans private ( * Train a K-means model on the given set of points; `data` should be cached for high * performance, because this is an iterative algorithm. */ + @Since("0.8.0") def run(data: RDD[Vector]): KMeansModel = { if (data.getStorageLevel == StorageLevel.NONE) { @@ -345,7 +369,7 @@ class KMeans private ( : Array[Array[VectorWithNorm]] = { // Initialize empty centers and point costs. val centers = Array.tabulate(runs)(r => ArrayBuffer.empty[VectorWithNorm]) - var costs = data.map(_ => Vectors.dense(Array.fill(runs)(Double.PositiveInfinity))).cache() + var costs = data.map(_ => Array.fill(runs)(Double.PositiveInfinity)) // Initialize each run's first center to a random point. 
val seed = new XORShiftRandom(this.seed).nextInt() @@ -370,21 +394,28 @@ class KMeans private ( val bcNewCenters = data.context.broadcast(newCenters) val preCosts = costs costs = data.zip(preCosts).map { case (point, cost) => - Vectors.dense( Array.tabulate(runs) { r => math.min(KMeans.pointCost(bcNewCenters.value(r), point), cost(r)) - }) - }.cache() + } + }.persist(StorageLevel.MEMORY_AND_DISK) val sumCosts = costs - .aggregate(Vectors.zeros(runs))( + .aggregate(new Array[Double](runs))( seqOp = (s, v) => { // s += v - axpy(1.0, v, s) + var r = 0 + while (r < runs) { + s(r) += v(r) + r += 1 + } s }, combOp = (s0, s1) => { // s0 += s1 - axpy(1.0, s1, s0) + var r = 0 + while (r < runs) { + s0(r) += s1(r) + r += 1 + } s0 } ) @@ -431,10 +462,13 @@ class KMeans private ( /** * Top-level methods for calling K-means clustering. */ +@Since("0.8.0") object KMeans { // Initialization mode names + @Since("0.8.0") val RANDOM = "random" + @Since("0.8.0") val K_MEANS_PARALLEL = "k-means||" /** @@ -447,6 +481,7 @@ object KMeans { * @param initializationMode initialization model, either "random" or "k-means||" (default). * @param seed random seed value for cluster initialization */ + @Since("1.3.0") def train( data: RDD[Vector], k: Int, @@ -471,6 +506,7 @@ object KMeans { * @param runs number of parallel runs, defaults to 1. The best model is returned. * @param initializationMode initialization model, either "random" or "k-means||" (default). */ + @Since("0.8.0") def train( data: RDD[Vector], k: Int, @@ -487,6 +523,7 @@ object KMeans { /** * Trains a k-means model using specified parameters and the default values for unspecified. */ + @Since("0.8.0") def train( data: RDD[Vector], k: Int, @@ -497,6 +534,7 @@ object KMeans { /** * Trains a k-means model using specified parameters and the default values for unspecified. */ + @Since("0.8.0") def train( data: RDD[Vector], k: Int, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index 8ecb3df11d95..91fa9b0d3590 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -23,6 +23,7 @@ import org.json4s._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.pmml.PMMLExportable @@ -35,28 +36,44 @@ import org.apache.spark.sql.Row /** * A clustering model for K-means. Each point belongs to the cluster with the closest center. */ -class KMeansModel ( - val clusterCenters: Array[Vector]) extends Saveable with Serializable with PMMLExportable { +@Since("0.8.0") +class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vector]) + extends Saveable with Serializable with PMMLExportable { - /** A Java-friendly constructor that takes an Iterable of Vectors. */ + /** + * A Java-friendly constructor that takes an Iterable of Vectors. + */ + @Since("1.4.0") def this(centers: java.lang.Iterable[Vector]) = this(centers.asScala.toArray) - /** Total number of clusters. */ + /** + * Total number of clusters. + */ + @Since("0.8.0") def k: Int = clusterCenters.length - /** Returns the cluster index that a given point belongs to. */ + /** + * Returns the cluster index that a given point belongs to. 
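As a quick reference for the KMeans / KMeansModel API annotated in these hunks, a minimal sketch; `sc` and the save path are assumptions.

import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
import org.apache.spark.mllib.linalg.Vectors

// Assumed: an active SparkContext `sc`.
val kmData = sc.parallelize(Seq(
  Vectors.dense(0.0, 0.0), Vectors.dense(0.1, 0.1),
  Vectors.dense(9.0, 9.0), Vectors.dense(9.1, 9.1))).cache()

val kmModel = KMeans.train(kmData, k = 2, maxIterations = 20)

val wssse = kmModel.computeCost(kmData)          // within-set sum of squared errors
val assignment = kmModel.predict(Vectors.dense(9.0, 8.9))

kmModel.save(sc, "/tmp/kmeansModel")             // illustrative path
val reloaded = KMeansModel.load(sc, "/tmp/kmeansModel")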
+ */ + @Since("0.8.0") def predict(point: Vector): Int = { KMeans.findClosest(clusterCentersWithNorm, new VectorWithNorm(point))._1 } - /** Maps given points to their cluster indices. */ + /** + * Maps given points to their cluster indices. + */ + @Since("1.0.0") def predict(points: RDD[Vector]): RDD[Int] = { val centersWithNorm = clusterCentersWithNorm val bcCentersWithNorm = points.context.broadcast(centersWithNorm) points.map(p => KMeans.findClosest(bcCentersWithNorm.value, new VectorWithNorm(p))._1) } - /** Maps given points to their cluster indices. */ + /** + * Maps given points to their cluster indices. + */ + @Since("1.0.0") def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] = predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]] @@ -64,6 +81,7 @@ class KMeansModel ( * Return the K-means cost (sum of squared distances of points to their nearest center) for this * model on the given data. */ + @Since("0.8.0") def computeCost(data: RDD[Vector]): Double = { val centersWithNorm = clusterCentersWithNorm val bcCentersWithNorm = data.context.broadcast(centersWithNorm) @@ -73,6 +91,7 @@ class KMeansModel ( private def clusterCentersWithNorm: Iterable[VectorWithNorm] = clusterCenters.map(new VectorWithNorm(_)) + @Since("1.4.0") override def save(sc: SparkContext, path: String): Unit = { KMeansModel.SaveLoadV1_0.save(sc, this, path) } @@ -80,7 +99,10 @@ class KMeansModel ( override protected def formatVersion: String = "1.0" } +@Since("1.4.0") object KMeansModel extends Loader[KMeansModel] { + + @Since("1.4.0") override def load(sc: SparkContext, path: String): KMeansModel = { KMeansModel.SaveLoadV1_0.load(sc, path) } @@ -102,7 +124,7 @@ object KMeansModel extends Loader[KMeansModel] { val thisClassName = "org.apache.spark.mllib.clustering.KMeansModel" def save(sc: SparkContext, model: KMeansModel, path: String): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k))) @@ -115,16 +137,16 @@ object KMeansModel extends Loader[KMeansModel] { def load(sc: SparkContext, path: String): KMeansModel = { implicit val formats = DefaultFormats - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) assert(className == thisClassName) assert(formatVersion == thisFormatVersion) val k = (metadata \ "k").extract[Int] - val centriods = sqlContext.read.parquet(Loader.dataPath(path)) - Loader.checkSchema[Cluster](centriods.schema) - val localCentriods = centriods.map(Cluster.apply).collect() - assert(k == localCentriods.size) - new KMeansModel(localCentriods.sortBy(_.id).map(_.point)) + val centroids = sqlContext.read.parquet(Loader.dataPath(path)) + Loader.checkSchema[Cluster](centroids.schema) + val localCentroids = centroids.map(Cluster.apply).collect() + assert(k == localCentroids.size) + new KMeansModel(localCentroids.sortBy(_.id).map(_.point)) } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index ab124e6d77c5..eb802a365ed6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering import breeze.linalg.{DenseVector => BDV} import 
org.apache.spark.Logging -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.api.java.JavaPairRDD import org.apache.spark.graphx._ import org.apache.spark.mllib.linalg.{Vector, Vectors} @@ -28,8 +28,6 @@ import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils /** - * :: Experimental :: - * * Latent Dirichlet Allocation (LDA), a topic model designed for text documents. * * Terminology: @@ -44,7 +42,7 @@ import org.apache.spark.util.Utils * @see [[http://en.wikipedia.org/wiki/Latent_Dirichlet_allocation Latent Dirichlet allocation * (Wikipedia)]] */ -@Experimental +@Since("1.3.0") class LDA private ( private var k: Int, private var maxIterations: Int, @@ -54,19 +52,26 @@ class LDA private ( private var checkpointInterval: Int, private var ldaOptimizer: LDAOptimizer) extends Logging { + /** + * Constructs a LDA instance with default parameters. + */ + @Since("1.3.0") def this() = this(k = 10, maxIterations = 20, docConcentration = Vectors.dense(-1), topicConcentration = -1, seed = Utils.random.nextLong(), checkpointInterval = 10, ldaOptimizer = new EMLDAOptimizer) /** * Number of topics to infer. I.e., the number of soft cluster centers. + * */ + @Since("1.3.0") def getK: Int = k /** * Number of topics to infer. I.e., the number of soft cluster centers. * (default = 10) */ + @Since("1.3.0") def setK(k: Int): this.type = { require(k > 0, s"LDA k (number of clusters) must be > 0, but was set to $k") this.k = k @@ -79,7 +84,26 @@ class LDA private ( * * This is the parameter to a Dirichlet distribution. */ - def getDocConcentration: Vector = this.docConcentration + @Since("1.5.0") + def getAsymmetricDocConcentration: Vector = this.docConcentration + + /** + * Concentration parameter (commonly named "alpha") for the prior placed on documents' + * distributions over topics ("theta"). + * + * This method assumes the Dirichlet distribution is symmetric and can be described by a single + * [[Double]] parameter. It should fail if docConcentration is asymmetric. + */ + @Since("1.3.0") + def getDocConcentration: Double = { + val parameter = docConcentration(0) + if (docConcentration.size == 1) { + parameter + } else { + require(docConcentration.toArray.forall(_ == parameter)) + parameter + } + } /** * Concentration parameter (commonly named "alpha") for the prior placed on documents' @@ -105,24 +129,44 @@ class LDA private ( * - default = uniformly (1.0 / k), following the implementation from * [[https://github.com/Blei-Lab/onlineldavb]]. */ + @Since("1.5.0") def setDocConcentration(docConcentration: Vector): this.type = { + require(docConcentration.size > 0, "docConcentration must have > 0 elements") this.docConcentration = docConcentration this } - /** Replicates Double to create a symmetric prior */ + /** + * Replicates a [[Double]] docConcentration to create a symmetric prior. 
+ */ + @Since("1.3.0") def setDocConcentration(docConcentration: Double): this.type = { this.docConcentration = Vectors.dense(docConcentration) this } - /** Alias for [[getDocConcentration]] */ - def getAlpha: Vector = getDocConcentration + /** + * Alias for [[getAsymmetricDocConcentration]] + */ + @Since("1.5.0") + def getAsymmetricAlpha: Vector = getAsymmetricDocConcentration - /** Alias for [[setDocConcentration()]] */ + /** + * Alias for [[getDocConcentration]] + */ + @Since("1.3.0") + def getAlpha: Double = getDocConcentration + + /** + * Alias for [[setDocConcentration()]] + */ + @Since("1.5.0") def setAlpha(alpha: Vector): this.type = setDocConcentration(alpha) - /** Alias for [[setDocConcentration()]] */ + /** + * Alias for [[setDocConcentration()]] + */ + @Since("1.3.0") def setAlpha(alpha: Double): this.type = setDocConcentration(alpha) /** @@ -134,6 +178,7 @@ class LDA private ( * Note: The topics' distributions over terms are called "beta" in the original LDA paper * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. */ + @Since("1.3.0") def getTopicConcentration: Double = this.topicConcentration /** @@ -158,35 +203,50 @@ class LDA private ( * - default = (1.0 / k), following the implementation from * [[https://github.com/Blei-Lab/onlineldavb]]. */ + @Since("1.3.0") def setTopicConcentration(topicConcentration: Double): this.type = { this.topicConcentration = topicConcentration this } - /** Alias for [[getTopicConcentration]] */ + /** + * Alias for [[getTopicConcentration]] + */ + @Since("1.3.0") def getBeta: Double = getTopicConcentration - /** Alias for [[setTopicConcentration()]] */ + /** + * Alias for [[setTopicConcentration()]] + */ + @Since("1.3.0") def setBeta(beta: Double): this.type = setTopicConcentration(beta) /** * Maximum number of iterations for learning. */ + @Since("1.3.0") def getMaxIterations: Int = maxIterations /** * Maximum number of iterations for learning. * (default = 20) */ + @Since("1.3.0") def setMaxIterations(maxIterations: Int): this.type = { this.maxIterations = maxIterations this } - /** Random seed */ + /** + * Random seed + */ + @Since("1.3.0") def getSeed: Long = seed - /** Random seed */ + /** + * Random seed + */ + @Since("1.3.0") def setSeed(seed: Long): this.type = { this.seed = seed this @@ -195,6 +255,7 @@ class LDA private ( /** * Period (in iterations) between checkpoints. */ + @Since("1.3.0") def getCheckpointInterval: Int = checkpointInterval /** @@ -205,6 +266,7 @@ class LDA private ( * * @see [[org.apache.spark.SparkContext#setCheckpointDir]] */ + @Since("1.3.0") def setCheckpointInterval(checkpointInterval: Int): this.type = { this.checkpointInterval = checkpointInterval this @@ -216,6 +278,7 @@ class LDA private ( * * LDAOptimizer used to perform the actual calculation */ + @Since("1.4.0") @DeveloperApi def getOptimizer: LDAOptimizer = ldaOptimizer @@ -224,6 +287,7 @@ class LDA private ( * * LDAOptimizer used to perform the actual calculation (default = EMLDAOptimizer) */ + @Since("1.4.0") @DeveloperApi def setOptimizer(optimizer: LDAOptimizer): this.type = { this.ldaOptimizer = optimizer @@ -234,6 +298,7 @@ class LDA private ( * Set the LDAOptimizer used to perform the actual calculation by algorithm name. * Currently "em", "online" are supported. */ + @Since("1.4.0") def setOptimizer(optimizerName: String): this.type = { this.ldaOptimizer = optimizerName.toLowerCase match { @@ -254,6 +319,7 @@ class LDA private ( * Document IDs must be unique and >= 0. 
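+ *
+ * A minimal usage sketch; `docs` is an assumed RDD[(Long, Vector)] of
+ * (document ID, term-count vector) pairs:
+ * {{{
+ *   val model: LDAModel = new LDA().setK(3).setMaxIterations(20).run(docs)
+ * }}}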
* @return Inferred LDA model */ + @Since("1.3.0") def run(documents: RDD[(Long, Vector)]): LDAModel = { val state = ldaOptimizer.initialize(documents, this) var iter = 0 @@ -268,7 +334,10 @@ class LDA private ( state.getLDAModel(iterationTimes) } - /** Java-friendly version of [[run()]] */ + /** + * Java-friendly version of [[run()]] + */ + @Since("1.3.0") def run(documents: JavaPairRDD[java.lang.Long, Vector]): LDAModel = { run(documents.rdd.asInstanceOf[RDD[(Long, Vector)]]) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 6af90d7287ff..7384d065a2ea 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.clustering -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argtopk, normalize, sum} +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax, argtopk, normalize, sum} import breeze.numerics.{exp, lgamma} import org.apache.hadoop.fs.Path import org.json4s.DefaultFormats @@ -25,8 +25,8 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext -import org.apache.spark.annotation.Experimental -import org.apache.spark.api.java.JavaPairRDD +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.api.java.{JavaPairRDD, JavaRDD} import org.apache.spark.graphx.{Edge, EdgeContext, Graph, VertexId} import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors} import org.apache.spark.mllib.util.{Loader, Saveable} @@ -35,20 +35,20 @@ import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.util.BoundedPriorityQueue /** - * :: Experimental :: - * * Latent Dirichlet Allocation (LDA) model. * * This abstraction permits for different underlying representations, * including local and distributed data structures. */ -@Experimental +@Since("1.3.0") abstract class LDAModel private[clustering] extends Saveable { /** Number of topics */ + @Since("1.3.0") def k: Int /** Vocabulary size (number of terms or terms in the vocabulary) */ + @Since("1.3.0") def vocabSize: Int /** @@ -57,6 +57,7 @@ abstract class LDAModel private[clustering] extends Saveable { * * This is the parameter to a Dirichlet distribution. */ + @Since("1.5.0") def docConcentration: Vector /** @@ -68,6 +69,7 @@ abstract class LDAModel private[clustering] extends Saveable { * Note: The topics' distributions over terms are called "beta" in the original LDA paper * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. */ + @Since("1.5.0") def topicConcentration: Double /** @@ -81,6 +83,7 @@ abstract class LDAModel private[clustering] extends Saveable { * This is a matrix of size vocabSize x k, where each column is a topic. * No guarantees are given about the ordering of the topics. */ + @Since("1.3.0") def topicsMatrix: Matrix /** @@ -91,6 +94,7 @@ abstract class LDAModel private[clustering] extends Saveable { * (term indices, term weights in topic). * Each topic's terms are sorted in order of decreasing weight. */ + @Since("1.3.0") def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] /** @@ -102,6 +106,7 @@ abstract class LDAModel private[clustering] extends Saveable { * (term indices, term weights in topic). * Each topic's terms are sorted in order of decreasing weight. 
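+ *
+ * Illustrative call on a trained model (the receiver name is hypothetical):
+ * {{{
+ *   val topics: Array[(Array[Int], Array[Double])] = ldaModel.describeTopics()
+ * }}}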
*/ + @Since("1.3.0") def describeTopics(): Array[(Array[Int], Array[Double])] = describeTopics(vocabSize) /* TODO (once LDA can be trained with Strings or given a dictionary) @@ -176,27 +181,29 @@ abstract class LDAModel private[clustering] extends Saveable { } /** - * :: Experimental :: - * * Local LDA model. * This model stores only the inferred topics. - * It may be used for computing topics for new documents, but it may give less accurate answers - * than the [[DistributedLDAModel]]. + * * @param topics Inferred topics (vocabSize x k matrix). */ -@Experimental -class LocalLDAModel private[clustering] ( - val topics: Matrix, - override val docConcentration: Vector, - override val topicConcentration: Double, - override protected[clustering] val gammaShape: Double) extends LDAModel with Serializable { - +@Since("1.3.0") +class LocalLDAModel private[spark] ( + @Since("1.3.0") val topics: Matrix, + @Since("1.5.0") override val docConcentration: Vector, + @Since("1.5.0") override val topicConcentration: Double, + override protected[spark] val gammaShape: Double = 100) + extends LDAModel with Serializable { + + @Since("1.3.0") override def k: Int = topics.numCols + @Since("1.3.0") override def vocabSize: Int = topics.numRows + @Since("1.3.0") override def topicsMatrix: Matrix = topics + @Since("1.3.0") override def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] = { val brzTopics = topics.toBreeze.toDenseMatrix Range(0, k).map { topicIndex => @@ -209,6 +216,7 @@ class LocalLDAModel private[clustering] ( override protected def formatVersion = "1.0" + @Since("1.5.0") override def save(sc: SparkContext, path: String): Unit = { LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix, docConcentration, topicConcentration, gammaShape) @@ -217,26 +225,44 @@ class LocalLDAModel private[clustering] ( // TODO: declare in LDAModel and override once implemented in DistributedLDAModel /** * Calculates a lower bound on the log likelihood of the entire corpus. + * + * See Equation (16) in original Online LDA paper. + * * @param documents test corpus to use for calculating log likelihood * @return variational lower bound on the log likelihood of the entire corpus */ - def logLikelihood(documents: RDD[(Long, Vector)]): Double = bound(documents, + @Since("1.5.0") + def logLikelihood(documents: RDD[(Long, Vector)]): Double = logLikelihoodBound(documents, docConcentration, topicConcentration, topicsMatrix.toBreeze.toDenseMatrix, gammaShape, k, vocabSize) /** - * Calculate an upper bound bound on perplexity. See Equation (16) in original Online - * LDA paper. + * Java-friendly version of [[logLikelihood]] + */ + @Since("1.5.0") + def logLikelihood(documents: JavaPairRDD[java.lang.Long, Vector]): Double = { + logLikelihood(documents.rdd.asInstanceOf[RDD[(Long, Vector)]]) + } + + /** + * Calculate an upper bound bound on perplexity. (Lower is better.) + * See Equation (16) in original Online LDA paper. + * * @param documents test corpus to use for calculating perplexity - * @return variational upper bound on log perplexity per word + * @return Variational upper bound on log perplexity per token. 
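+ *
+ * Illustrative call (the held-out corpus name is hypothetical):
+ * {{{
+ *   val bound = localModel.logPerplexity(testDocs)  // lower is better
+ * }}}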
*/ + @Since("1.5.0") def logPerplexity(documents: RDD[(Long, Vector)]): Double = { - val corpusWords = documents + val corpusTokenCount = documents .map { case (_, termCounts) => termCounts.toArray.sum } .sum() - val perWordBound = -logLikelihood(documents) / corpusWords + -logLikelihood(documents) / corpusTokenCount + } - perWordBound + /** Java-friendly version of [[logPerplexity]] */ + @Since("1.5.0") + def logPerplexity(documents: JavaPairRDD[java.lang.Long, Vector]): Double = { + logPerplexity(documents.rdd.asInstanceOf[RDD[(Long, Vector)]]) } /** @@ -244,17 +270,20 @@ class LocalLDAModel private[clustering] ( * log p(documents) >= E_q[log p(documents)] - E_q[log q(documents)] * This bound is derived by decomposing the LDA model to: * log p(documents) = E_q[log p(documents)] - E_q[log q(documents)] + D(q|p) - * and noting that the KL-divergence D(q|p) >= 0. See Equation (16) in original Online LDA paper. + * and noting that the KL-divergence D(q|p) >= 0. + * + * See Equation (16) in original Online LDA paper, as well as Appendix A.3 in the JMLR version of + * the original LDA paper. * @param documents a subset of the test corpus * @param alpha document-topic Dirichlet prior parameters - * @param eta topic-word Dirichlet prior parameters + * @param eta topic-word Dirichlet prior parameter * @param lambda parameters for variational q(beta | lambda) topic-word distributions * @param gammaShape shape parameter for random initialization of variational q(theta | gamma) * topic mixture distributions * @param k number of topics * @param vocabSize number of unique terms in the entire test corpus */ - private def bound( + private def logLikelihoodBound( documents: RDD[(Long, Vector)], alpha: Vector, eta: Double, @@ -266,33 +295,38 @@ class LocalLDAModel private[clustering] ( // transpose because dirichletExpectation normalizes by row and we need to normalize // by topic (columns of lambda) val Elogbeta = LDAUtils.dirichletExpectation(lambda.t).t + val ElogbetaBc = documents.sparkContext.broadcast(Elogbeta) + + // Sum bound components for each document: + // component for prob(tokens) + component for prob(document-topic distribution) + val corpusPart = + documents.filter(_._2.numNonzeros > 0).map { case (id: Long, termCounts: Vector) => + val localElogbeta = ElogbetaBc.value + var docBound = 0.0D + val (gammad: BDV[Double], _) = OnlineLDAOptimizer.variationalTopicInference( + termCounts, exp(localElogbeta), brzAlpha, gammaShape, k) + val Elogthetad: BDV[Double] = LDAUtils.dirichletExpectation(gammad) + + // E[log p(doc | theta, beta)] + termCounts.foreachActive { case (idx, count) => + docBound += count * LDAUtils.logSumExp(Elogthetad + localElogbeta(idx, ::).t) + } + // E[log p(theta | alpha) - log q(theta | gamma)] + docBound += sum((brzAlpha - gammad) :* Elogthetad) + docBound += sum(lgamma(gammad) - lgamma(brzAlpha)) + docBound += lgamma(sum(brzAlpha)) - lgamma(sum(gammad)) - var score = documents.filter(_._2.numNonzeros > 0).map { case (id: Long, termCounts: Vector) => - var docScore = 0.0D - val (gammad: BDV[Double], _) = OnlineLDAOptimizer.variationalTopicInference( - termCounts, exp(Elogbeta), brzAlpha, gammaShape, k) - val Elogthetad: BDV[Double] = LDAUtils.dirichletExpectation(gammad) - - // E[log p(doc | theta, beta)] - termCounts.foreachActive { case (idx, count) => - docScore += count * LDAUtils.logSumExp(Elogthetad + Elogbeta(idx, ::).t) - } - // E[log p(theta | alpha) - log q(theta | gamma)]; assumes alpha is a vector - docScore += sum((brzAlpha - gammad) :* Elogthetad) - docScore 
+= sum(lgamma(gammad) - lgamma(brzAlpha)) - docScore += lgamma(sum(brzAlpha)) - lgamma(sum(gammad)) - - docScore - }.sum() - - // E[log p(beta | eta) - log q (beta | lambda)]; assumes eta is a scalar - score += sum((eta - lambda) :* Elogbeta) - score += sum(lgamma(lambda) - lgamma(eta)) + docBound + }.sum() + // Bound component for prob(topic-term distributions): + // E[log p(beta | eta) - log q(beta | lambda)] val sumEta = eta * vocabSize - score += sum(lgamma(sumEta) - lgamma(sum(lambda(::, breeze.linalg.*)))) + val topicsPart = sum((eta - lambda) :* Elogbeta) + + sum(lgamma(lambda) - lgamma(eta)) + + sum(lgamma(sumEta) - lgamma(sum(lambda(::, breeze.linalg.*)))) - score + corpusPart + topicsPart } /** @@ -305,22 +339,24 @@ class LocalLDAModel private[clustering] ( * @param documents documents to predict topic mixture distributions for * @return An RDD of (document ID, topic mixture distribution for document) */ + @Since("1.3.0") // TODO: declare in LDAModel and override once implemented in DistributedLDAModel def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = { // Double transpose because dirichletExpectation normalizes by row and we need to normalize // by topic (columns of lambda) val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.toBreeze.toDenseMatrix.t).t) + val expElogbetaBc = documents.sparkContext.broadcast(expElogbeta) val docConcentrationBrz = this.docConcentration.toBreeze val gammaShape = this.gammaShape val k = this.k documents.map { case (id: Long, termCounts: Vector) => if (termCounts.numNonzeros == 0) { - (id, Vectors.zeros(k)) + (id, Vectors.zeros(k)) } else { val (gamma, _) = OnlineLDAOptimizer.variationalTopicInference( termCounts, - expElogbeta, + expElogbetaBc.value, docConcentrationBrz, gammaShape, k) @@ -329,10 +365,42 @@ class LocalLDAModel private[clustering] ( } } -} + /** Get a method usable as a UDF for [[topicDistributions()]] */ + private[spark] def getTopicDistributionMethod(sc: SparkContext): Vector => Vector = { + val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.toBreeze.toDenseMatrix.t).t) + val expElogbetaBc = sc.broadcast(expElogbeta) + val docConcentrationBrz = this.docConcentration.toBreeze + val gammaShape = this.gammaShape + val k = this.k + (termCounts: Vector) => + if (termCounts.numNonzeros == 0) { + Vectors.zeros(k) + } else { + val (gamma, _) = OnlineLDAOptimizer.variationalTopicInference( + termCounts, + expElogbetaBc.value, + docConcentrationBrz, + gammaShape, + k) + Vectors.dense(normalize(gamma, 1.0).toArray) + } + } + + /** + * Java-friendly version of [[topicDistributions]] + */ + @Since("1.4.1") + def topicDistributions( + documents: JavaPairRDD[java.lang.Long, Vector]): JavaPairRDD[java.lang.Long, Vector] = { + val distributions = topicDistributions(documents.rdd.asInstanceOf[RDD[(Long, Vector)]]) + JavaPairRDD.fromRDD(distributions.asInstanceOf[RDD[(java.lang.Long, Vector)]]) + } + +} @Experimental +@Since("1.5.0") object LocalLDAModel extends Loader[LocalLDAModel] { private object SaveLoadV1_0 { @@ -384,7 +452,7 @@ object LocalLDAModel extends Loader[LocalLDAModel] { Loader.checkSchema[Data](dataFrame.schema) val topics = dataFrame.collect() val vocabSize = topics(0).getAs[Vector](0).size - val k = topics.size + val k = topics.length val brzTopics = BDM.zeros[Double](vocabSize, k) topics.foreach { case Row(vec: Vector, ind: Int) => @@ -392,11 +460,11 @@ object LocalLDAModel extends Loader[LocalLDAModel] { } val topicsMat = Matrices.fromBreeze(brzTopics) - // TODO: initialize 
with docConcentration, topicConcentration, and gammaShape after SPARK-9940 new LocalLDAModel(topicsMat, docConcentration, topicConcentration, gammaShape) } } + @Since("1.5.0") override def load(sc: SparkContext, path: String): LocalLDAModel = { val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path) implicit val formats = DefaultFormats @@ -428,23 +496,20 @@ object LocalLDAModel extends Loader[LocalLDAModel] { } /** - * :: Experimental :: - * * Distributed LDA model. * This model stores the inferred topics, the full training dataset, and the topic distributions. - * When computing topics for new documents, it may give more accurate answers - * than the [[LocalLDAModel]]. */ -@Experimental +@Since("1.3.0") class DistributedLDAModel private[clustering] ( private[clustering] val graph: Graph[LDA.TopicCounts, LDA.TokenCount], private[clustering] val globalTopicTotals: LDA.TopicCounts, - val k: Int, - val vocabSize: Int, - override val docConcentration: Vector, - override val topicConcentration: Double, - override protected[clustering] val gammaShape: Double, - private[spark] val iterationTimes: Array[Double]) extends LDAModel { + @Since("1.3.0") val k: Int, + @Since("1.3.0") val vocabSize: Int, + @Since("1.5.0") override val docConcentration: Vector, + @Since("1.5.0") override val topicConcentration: Double, + private[spark] val iterationTimes: Array[Double], + override protected[clustering] val gammaShape: Double = 100) + extends LDAModel { import LDA._ @@ -453,6 +518,7 @@ class DistributedLDAModel private[clustering] ( * The local model stores the inferred topics but not the topic distributions for training * documents. */ + @Since("1.3.0") def toLocal: LocalLDAModel = new LocalLDAModel(topicsMatrix, docConcentration, topicConcentration, gammaShape) @@ -463,6 +529,7 @@ class DistributedLDAModel private[clustering] ( * * WARNING: This matrix is collected from an RDD. Beware memory usage when vocabSize, k are large. */ + @Since("1.3.0") override lazy val topicsMatrix: Matrix = { // Collect row-major topics val termTopicCounts: Array[(Int, TopicCounts)] = @@ -481,6 +548,7 @@ class DistributedLDAModel private[clustering] ( Matrices.fromBreeze(brzTopics) } + @Since("1.3.0") override def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] = { val numTopics = k // Note: N_k is not needed to find the top terms, but it is needed to normalize weights @@ -520,6 +588,7 @@ class DistributedLDAModel private[clustering] ( * (IDs for the documents, weights of the topic in these documents). * For each topic, documents are sorted in order of decreasing topic weights. */ + @Since("1.5.0") def topDocumentsPerTopic(maxDocumentsPerTopic: Int): Array[(Array[Long], Array[Double])] = { val numTopics = k val topicsInQueues: Array[BoundedPriorityQueue[(Double, Long)]] = @@ -546,6 +615,52 @@ class DistributedLDAModel private[clustering] ( } } + /** + * Return the top topic for each (doc, term) pair. I.e., for each document, what is the most + * likely topic generating each term? + * + * @return RDD of (doc ID, assignment of top topic index for each term), + * where the assignment is specified via a pair of zippable arrays + * (term indices, topic indices). Note that terms will be omitted if not present in + * the document. + */ + @Since("1.5.0") + lazy val topicAssignments: RDD[(Long, Array[Int], Array[Int])] = { + // For reference, compare the below code with the core part of EMLDAOptimizer.next(). 
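+ // Sketch of the idea: each edge of the bipartite doc-term graph is one (doc, term) pair;
+ // for every edge we recompute the smoothed topic distribution gamma_{wjk}, keep only its
+ // argmax, and aggregate the per-term top topics onto the corresponding document vertex.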
+ val eta = topicConcentration + val W = vocabSize + val alpha = docConcentration(0) + val N_k = globalTopicTotals + val sendMsg: EdgeContext[TopicCounts, TokenCount, (Array[Int], Array[Int])] => Unit = + (edgeContext) => { + // E-STEP: Compute gamma_{wjk} (smoothed topic distributions). + val scaledTopicDistribution: TopicCounts = + computePTopic(edgeContext.srcAttr, edgeContext.dstAttr, N_k, W, eta, alpha) + // For this (doc j, term w), send top topic k to doc vertex. + val topTopic: Int = argmax(scaledTopicDistribution) + val term: Int = index2term(edgeContext.dstId) + edgeContext.sendToSrc((Array(term), Array(topTopic))) + } + val mergeMsg: ((Array[Int], Array[Int]), (Array[Int], Array[Int])) => (Array[Int], Array[Int]) = + (terms_topics0, terms_topics1) => { + (terms_topics0._1 ++ terms_topics1._1, terms_topics0._2 ++ terms_topics1._2) + } + // M-STEP: Aggregation computes new N_{kj}, N_{wk} counts. + val perDocAssignments = + graph.aggregateMessages[(Array[Int], Array[Int])](sendMsg, mergeMsg).filter(isDocumentVertex) + perDocAssignments.map { case (docID: Long, (terms: Array[Int], topics: Array[Int])) => + // TODO: Avoid zip, which is inefficient. + val (sortedTerms, sortedTopics) = terms.zip(topics).sortBy(_._1).unzip + (docID, sortedTerms.toArray, sortedTopics.toArray) + } + } + + /** Java-friendly version of [[topicAssignments]] */ + @Since("1.5.0") + lazy val javaTopicAssignments: JavaRDD[(java.lang.Long, Array[Int], Array[Int])] = { + topicAssignments.asInstanceOf[RDD[(java.lang.Long, Array[Int], Array[Int])]].toJavaRDD() + } + // TODO // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ??? @@ -559,6 +674,7 @@ class DistributedLDAModel private[clustering] ( * - Even with [[logPrior]], this is NOT the same as the data log likelihood given the * hyperparameters. */ + @Since("1.3.0") lazy val logLikelihood: Double = { // TODO: generalize this for asymmetric (non-scalar) alpha val alpha = this.docConcentration(0) // To avoid closure capture of enclosing object @@ -583,8 +699,9 @@ class DistributedLDAModel private[clustering] ( /** * Log probability of the current parameter estimate: - * log P(topics, topic distributions for docs | alpha, eta) + * log P(topics, topic distributions for docs | alpha, eta) */ + @Since("1.3.0") lazy val logPrior: Double = { // TODO: generalize this for asymmetric (non-scalar) alpha val alpha = this.docConcentration(0) // To avoid closure capture of enclosing object @@ -616,13 +733,17 @@ class DistributedLDAModel private[clustering] ( * * @return RDD of (document ID, topic distribution) pairs */ + @Since("1.3.0") def topicDistributions: RDD[(Long, Vector)] = { graph.vertices.filter(LDA.isDocumentVertex).map { case (docID, topicCounts) => (docID.toLong, Vectors.fromBreeze(normalize(topicCounts, 1.0))) } } - /** Java-friendly version of [[topicDistributions]] */ + /** + * Java-friendly version of [[topicDistributions]] + */ + @Since("1.4.1") def javaTopicDistributions: JavaPairRDD[java.lang.Long, Vector] = { JavaPairRDD.fromRDD(topicDistributions.asInstanceOf[RDD[(java.lang.Long, Vector)]]) } @@ -631,6 +752,7 @@ class DistributedLDAModel private[clustering] ( * For each document, return the top k weighted topics for that document and their weights. 
* @return RDD of (doc ID, topic indices, topic weights) */ + @Since("1.5.0") def topTopicsPerDocument(k: Int): RDD[(Long, Array[Int], Array[Double])] = { graph.vertices.filter(LDA.isDocumentVertex).map { case (docID, topicCounts) => val topIndices = argtopk(topicCounts, k) @@ -644,11 +766,24 @@ class DistributedLDAModel private[clustering] ( } } + /** + * Java-friendly version of [[topTopicsPerDocument]] + */ + @Since("1.5.0") + def javaTopTopicsPerDocument(k: Int): JavaRDD[(java.lang.Long, Array[Int], Array[Double])] = { + val topics = topTopicsPerDocument(k) + topics.asInstanceOf[RDD[(java.lang.Long, Array[Int], Array[Double])]].toJavaRDD() + } + // TODO: // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ??? override protected def formatVersion = "1.0" + /** + * Java-friendly version of [[topicDistributions]] + */ + @Since("1.5.0") override def save(sc: SparkContext, path: String): Unit = { DistributedLDAModel.SaveLoadV1_0.save( sc, path, graph, globalTopicTotals, k, vocabSize, docConcentration, topicConcentration, @@ -658,6 +793,7 @@ class DistributedLDAModel private[clustering] ( @Experimental +@Since("1.5.0") object DistributedLDAModel extends Loader[DistributedLDAModel] { private object SaveLoadV1_0 { @@ -744,11 +880,12 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { val graph: Graph[LDA.TopicCounts, LDA.TokenCount] = Graph(vertices, edges) new DistributedLDAModel(graph, globalTopicTotals, globalTopicTotals.length, vocabSize, - docConcentration, topicConcentration, gammaShape, iterationTimes) + docConcentration, topicConcentration, iterationTimes, gammaShape) } } + @Since("1.5.0") override def load(sc: SparkContext, path: String): DistributedLDAModel = { val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path) implicit val formats = DefaultFormats @@ -762,10 +899,9 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { val classNameV1_0 = SaveLoadV1_0.thisClassName val model = (loadedClassName, loadedVersion) match { - case (className, "1.0") if className == classNameV1_0 => { + case (className, "1.0") if className == classNameV1_0 => DistributedLDAModel.SaveLoadV1_0.load(sc, path, vocabSize, docConcentration, topicConcentration, iterationTimes.toArray, gammaShape) - } case _ => throw new Exception( s"DistributedLDAModel.load did not recognize model with (className, format version):" + s"($loadedClassName, $loadedVersion). Supported: ($classNameV1_0, 1.0)") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index b0e14cb8296a..17c0609800e9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -23,7 +23,7 @@ import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, all, normalize, su import breeze.numerics.{trigamma, abs, exp} import breeze.stats.distributions.{Gamma, RandBasis} -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.graphx._ import org.apache.spark.graphx.impl.GraphImpl import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer @@ -36,6 +36,7 @@ import org.apache.spark.rdd.RDD * An LDAOptimizer specifies which optimization/learning/inference algorithm to use, and it can * hold optimizer-specific parameters for users to set. 
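+ *
+ * For example (illustrative), an optimizer is typically configured through [[LDA]]:
+ * {{{
+ *   val lda = new LDA().setOptimizer(new OnlineLDAOptimizer().setMiniBatchFraction(0.1))
+ * }}}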
*/ +@Since("1.4.0") @DeveloperApi sealed trait LDAOptimizer { @@ -73,8 +74,8 @@ sealed trait LDAOptimizer { * - Paper which clearly explains several algorithms, including EM: * Asuncion, Welling, Smyth, and Teh. * "On Smoothing and Inference for Topic Models." UAI, 2009. - * */ +@Since("1.4.0") @DeveloperApi final class EMLDAOptimizer extends LDAOptimizer { @@ -95,10 +96,8 @@ final class EMLDAOptimizer extends LDAOptimizer { * Compute bipartite term/doc graph. */ override private[clustering] def initialize(docs: RDD[(Long, Vector)], lda: LDA): LDAOptimizer = { - val docConcentration = lda.getDocConcentration(0) - require({ - lda.getDocConcentration.toArray.forall(_ == docConcentration) - }, "EMLDAOptimizer currently only supports symmetric document-topic priors") + // EMLDAOptimizer currently only supports symmetric document-topic priors + val docConcentration = lda.getDocConcentration val topicConcentration = lda.getTopicConcentration val k = lda.getK @@ -168,7 +167,7 @@ final class EMLDAOptimizer extends LDAOptimizer { edgeContext.sendToDst((false, scaledTopicDistribution)) edgeContext.sendToSrc((false, scaledTopicDistribution)) } - // This is a hack to detect whether we could modify the values in-place. + // The Boolean is a hack to detect whether we could modify the values in-place. // TODO: Add zero/seqOp/combOp option to aggregateMessages. (SPARK-5438) val mergeMsg: ((Boolean, TopicCounts), (Boolean, TopicCounts)) => (Boolean, TopicCounts) = (m0, m1) => { @@ -209,11 +208,11 @@ final class EMLDAOptimizer extends LDAOptimizer { override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = { require(graph != null, "graph is null, EMLDAOptimizer not initialized.") this.graphCheckpointer.deleteAllCheckpoints() - // This assumes gammaShape = 100 in OnlineLDAOptimizer to ensure equivalence in LDAModel.toLocal - // conversion + // The constructor's default arguments assume gammaShape = 100 to ensure equivalence in + // LDAModel.toLocal conversion new DistributedLDAModel(this.graph, this.globalTopicTotals, this.k, this.vocabSize, Vectors.dense(Array.fill(this.k)(this.docConcentration)), this.topicConcentration, - 100, iterationTimes) + iterationTimes) } } @@ -228,6 +227,7 @@ final class EMLDAOptimizer extends LDAOptimizer { * Original Online LDA paper: * Hoffman, Blei and Bach, "Online Learning for Latent Dirichlet Allocation." NIPS, 2010. */ +@Since("1.4.0") @DeveloperApi final class OnlineLDAOptimizer extends LDAOptimizer { @@ -258,7 +258,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { private var tau0: Double = 1024 private var kappa: Double = 0.51 private var miniBatchFraction: Double = 0.05 - private var optimizeAlpha: Boolean = false + private var optimizeDocConcentration: Boolean = false // internal data structure private var docs: RDD[(Long, Vector)] = null @@ -277,6 +277,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { * A (positive) learning parameter that downweights early iterations. Larger values make early * iterations count less. */ + @Since("1.4.0") def getTau0: Double = this.tau0 /** @@ -284,6 +285,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { * iterations count less. * Default: 1024, following the original Online LDA paper. 
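+ *
+ * Following the original Online LDA paper, each mini-batch is weighted by roughly
+ * (tau0 + iteration)^(-kappa), so e.g. (illustrative values):
+ * {{{
+ *   new OnlineLDAOptimizer().setTau0(1024).setKappa(0.51)
+ * }}}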
*/ + @Since("1.4.0") def setTau0(tau0: Double): this.type = { require(tau0 > 0, s"LDA tau0 must be positive, but was set to $tau0") this.tau0 = tau0 @@ -293,6 +295,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { /** * Learning rate: exponential decay rate */ + @Since("1.4.0") def getKappa: Double = this.kappa /** @@ -300,6 +303,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { * (0.5, 1.0] to guarantee asymptotic convergence. * Default: 0.51, based on the original Online LDA paper. */ + @Since("1.4.0") def setKappa(kappa: Double): this.type = { require(kappa >= 0, s"Online LDA kappa must be nonnegative, but was set to $kappa") this.kappa = kappa @@ -309,6 +313,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { /** * Mini-batch fraction, which sets the fraction of document sampled and used in each iteration */ + @Since("1.4.0") def getMiniBatchFraction: Double = this.miniBatchFraction /** @@ -321,6 +326,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { * * Default: 0.05, i.e., 5% of total documents. */ + @Since("1.4.0") def setMiniBatchFraction(miniBatchFraction: Double): this.type = { require(miniBatchFraction > 0.0 && miniBatchFraction <= 1.0, s"Online LDA miniBatchFraction must be in range (0,1], but was set to $miniBatchFraction") @@ -329,18 +335,20 @@ final class OnlineLDAOptimizer extends LDAOptimizer { } /** - * Optimize alpha, indicates whether alpha (Dirichlet parameter for document-topic distribution) - * will be optimized during training. + * Optimize docConcentration, indicates whether docConcentration (Dirichlet parameter for + * document-topic distribution) will be optimized during training. */ - def getOptimzeAlpha: Boolean = this.optimizeAlpha + @Since("1.5.0") + def getOptimizeDocConcentration: Boolean = this.optimizeDocConcentration /** - * Sets whether to optimize alpha parameter during training. + * Sets whether to optimize docConcentration parameter during training. 
* * Default: false */ - def setOptimzeAlpha(optimizeAlpha: Boolean): this.type = { - this.optimizeAlpha = optimizeAlpha + @Since("1.5.0") + def setOptimizeDocConcentration(optimizeDocConcentration: Boolean): this.type = { + this.optimizeDocConcentration = optimizeDocConcentration this } @@ -378,18 +386,20 @@ final class OnlineLDAOptimizer extends LDAOptimizer { this.k = lda.getK this.corpusSize = docs.count() this.vocabSize = docs.first()._2.size - this.alpha = if (lda.getDocConcentration.size == 1) { - if (lda.getDocConcentration(0) == -1) Vectors.dense(Array.fill(k)(1.0 / k)) + this.alpha = if (lda.getAsymmetricDocConcentration.size == 1) { + if (lda.getAsymmetricDocConcentration(0) == -1) Vectors.dense(Array.fill(k)(1.0 / k)) else { - require(lda.getDocConcentration(0) >= 0, s"all entries in alpha must be >=0, got: $alpha") - Vectors.dense(Array.fill(k)(lda.getDocConcentration(0))) + require(lda.getAsymmetricDocConcentration(0) >= 0, + s"all entries in alpha must be >=0, got: $alpha") + Vectors.dense(Array.fill(k)(lda.getAsymmetricDocConcentration(0))) } } else { - require(lda.getDocConcentration.size == k, s"alpha must have length k, got: $alpha") - lda.getDocConcentration.foreachActive { case (_, x) => + require(lda.getAsymmetricDocConcentration.size == k, + s"alpha must have length k, got: $alpha") + lda.getAsymmetricDocConcentration.foreachActive { case (_, x) => require(x >= 0, s"all entries in alpha must be >= 0, got: $alpha") } - lda.getDocConcentration + lda.getAsymmetricDocConcentration } this.eta = if (lda.getTopicConcentration == -1) 1.0 / k else lda.getTopicConcentration this.randomGenerator = new Random(lda.getSeed) @@ -419,6 +429,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { val k = this.k val vocabSize = this.vocabSize val expElogbeta = exp(LDAUtils.dirichletExpectation(lambda)).t + val expElogbetaBc = batch.sparkContext.broadcast(expElogbeta) val alpha = this.alpha.toBreeze val gammaShape = this.gammaShape @@ -427,26 +438,27 @@ final class OnlineLDAOptimizer extends LDAOptimizer { val stat = BDM.zeros[Double](k, vocabSize) var gammaPart = List[BDV[Double]]() - nonEmptyDocs.zipWithIndex.foreach { case ((_, termCounts: Vector), idx: Int) => + nonEmptyDocs.foreach { case (_, termCounts: Vector) => val ids: List[Int] = termCounts match { case v: DenseVector => (0 until v.size).toList case v: SparseVector => v.indices.toList } val (gammad, sstats) = OnlineLDAOptimizer.variationalTopicInference( - termCounts, expElogbeta, alpha, gammaShape, k) + termCounts, expElogbetaBc.value, alpha, gammaShape, k) stat(::, ids) := stat(::, ids).toDenseMatrix + sstats gammaPart = gammad :: gammaPart } Iterator((stat, gammaPart)) } val statsSum: BDM[Double] = stats.map(_._1).reduce(_ += _) + expElogbetaBc.unpersist() val gammat: BDM[Double] = breeze.linalg.DenseMatrix.vertcat( stats.map(_._2).reduce(_ ++ _).map(_.toDenseMatrix): _*) val batchResult = statsSum :* expElogbeta.t // Note that this is an optimization to avoid batch.count updateLambda(batchResult, (miniBatchFraction * corpusSize).ceil.toInt) - if (optimizeAlpha) updateAlpha(gammat) + if (optimizeDocConcentration) updateAlpha(gammat) this } @@ -540,21 +552,22 @@ private[clustering] object OnlineLDAOptimizer { val expElogthetad: BDV[Double] = exp(LDAUtils.dirichletExpectation(gammad)) // K val expElogbetad = expElogbeta(ids, ::).toDenseMatrix // ids * K - val phinorm: BDV[Double] = expElogbetad * expElogthetad :+ 1e-100 // ids - var meanchange = 1D + val phiNorm: BDV[Double] = expElogbetad * expElogthetad :+ 1e-100 // ids 
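+ // phiNorm(i) is the normalizer of the per-token topic responsibilities for term ids(i);
+ // the 1e-100 floor guards against division by zero in the updates below.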
+ var meanGammaChange = 1D val ctsVector = new BDV[Double](cts) // ids // Iterate between gamma and phi until convergence - while (meanchange > 1e-3) { + while (meanGammaChange > 1e-3) { val lastgamma = gammad.copy // K K * ids ids - gammad := (expElogthetad :* (expElogbetad.t * (ctsVector :/ phinorm))) :+ alpha + gammad := (expElogthetad :* (expElogbetad.t * (ctsVector :/ phiNorm))) :+ alpha expElogthetad := exp(LDAUtils.dirichletExpectation(gammad)) - phinorm := expElogbetad * expElogthetad :+ 1e-100 - meanchange = sum(abs(gammad - lastgamma)) / k + // TODO: Keep more values in log space, and only exponentiate when needed. + phiNorm := expElogbetad * expElogthetad :+ 1e-100 + meanGammaChange = sum(abs(gammad - lastgamma)) / k } - val sstatsd = expElogthetad.asDenseMatrix.t * (ctsVector :/ phinorm).asDenseMatrix + val sstatsd = expElogthetad.asDenseMatrix.t * (ctsVector :/ phiNorm).asDenseMatrix (gammad, sstatsd) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala index f7e5ce1665fe..a9ba7b60bad0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala @@ -22,7 +22,7 @@ import breeze.numerics._ /** * Utility methods for LDA. */ -object LDAUtils { +private[clustering] object LDAUtils { /** * Log Sum Exp with overflow protection using the identity: * For any a: \log \sum_{n=1}^N \exp\{x_n\} = a + \log \sum_{n=1}^N \exp\{x_n - a\} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index 407e43a024a2..bb1804505948 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -21,7 +21,7 @@ import org.json4s.JsonDSL._ import org.json4s._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.graphx._ import org.apache.spark.graphx.impl.GraphImpl @@ -33,18 +33,18 @@ import org.apache.spark.util.random.XORShiftRandom import org.apache.spark.{Logging, SparkContext, SparkException} /** - * :: Experimental :: - * * Model produced by [[PowerIterationClustering]]. 
* * @param k number of clusters * @param assignments an RDD of clustering [[PowerIterationClustering#Assignment]]s */ -@Experimental -class PowerIterationClusteringModel( - val k: Int, - val assignments: RDD[PowerIterationClustering.Assignment]) extends Saveable with Serializable { +@Since("1.3.0") +class PowerIterationClusteringModel @Since("1.3.0") ( + @Since("1.3.0") val k: Int, + @Since("1.3.0") val assignments: RDD[PowerIterationClustering.Assignment]) + extends Saveable with Serializable { + @Since("1.4.0") override def save(sc: SparkContext, path: String): Unit = { PowerIterationClusteringModel.SaveLoadV1_0.save(sc, this, path) } @@ -52,7 +52,10 @@ class PowerIterationClusteringModel( override protected def formatVersion: String = "1.0" } +@Since("1.4.0") object PowerIterationClusteringModel extends Loader[PowerIterationClusteringModel] { + + @Since("1.4.0") override def load(sc: SparkContext, path: String): PowerIterationClusteringModel = { PowerIterationClusteringModel.SaveLoadV1_0.load(sc, path) } @@ -65,8 +68,9 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode private[clustering] val thisClassName = "org.apache.spark.mllib.clustering.PowerIterationClusteringModel" + @Since("1.4.0") def save(sc: SparkContext, model: PowerIterationClusteringModel, path: String): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ val metadata = compact(render( @@ -77,9 +81,10 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode dataRDD.write.parquet(Loader.dataPath(path)) } + @Since("1.4.0") def load(sc: SparkContext, path: String): PowerIterationClusteringModel = { implicit val formats = DefaultFormats - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) assert(className == thisClassName) @@ -99,8 +104,6 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode } /** - * :: Experimental :: - * * Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by * [[http://www.icml2010.org/papers/387.pdf Lin and Cohen]]. From the abstract: PIC finds a very * low-dimensional embedding of a dataset using truncated power iteration on a normalized pair-wise @@ -112,7 +115,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode * * @see [[http://en.wikipedia.org/wiki/Spectral_clustering Spectral clustering (Wikipedia)]] */ -@Experimental +@Since("1.3.0") class PowerIterationClustering private[clustering] ( private var k: Int, private var maxIterations: Int, @@ -120,14 +123,17 @@ class PowerIterationClustering private[clustering] ( import org.apache.spark.mllib.clustering.PowerIterationClustering._ - /** Constructs a PIC instance with default parameters: {k: 2, maxIterations: 100, - * initMode: "random"}. + /** + * Constructs a PIC instance with default parameters: {k: 2, maxIterations: 100, + * initMode: "random"}. */ + @Since("1.3.0") def this() = this(k = 2, maxIterations = 100, initMode = "random") /** * Set the number of clusters. 
*/ + @Since("1.3.0") def setK(k: Int): this.type = { this.k = k this @@ -136,6 +142,7 @@ class PowerIterationClustering private[clustering] ( /** * Set maximum number of iterations of the power iteration loop */ + @Since("1.3.0") def setMaxIterations(maxIterations: Int): this.type = { this.maxIterations = maxIterations this @@ -145,6 +152,7 @@ class PowerIterationClustering private[clustering] ( * Set the initialization mode. This can be either "random" to use a random vector * as vertex properties, or "degree" to use normalized sum similarities. Default: random. */ + @Since("1.3.0") def setInitializationMode(mode: String): this.type = { this.initMode = mode match { case "random" | "degree" => mode @@ -165,6 +173,7 @@ class PowerIterationClustering private[clustering] ( * * @return a [[PowerIterationClusteringModel]] that contains the clustering result */ + @Since("1.5.0") def run(graph: Graph[Double, Double]): PowerIterationClusteringModel = { val w = normalize(graph) val w0 = initMode match { @@ -186,6 +195,7 @@ class PowerIterationClustering private[clustering] ( * * @return a [[PowerIterationClusteringModel]] that contains the clustering result */ + @Since("1.3.0") def run(similarities: RDD[(Long, Long, Double)]): PowerIterationClusteringModel = { val w = normalize(similarities) val w0 = initMode match { @@ -198,6 +208,7 @@ class PowerIterationClustering private[clustering] ( /** * A Java-friendly version of [[PowerIterationClustering.run]]. */ + @Since("1.3.0") def run(similarities: JavaRDD[(java.lang.Long, java.lang.Long, java.lang.Double)]) : PowerIterationClusteringModel = { run(similarities.rdd.asInstanceOf[RDD[(Long, Long, Double)]]) @@ -221,16 +232,15 @@ class PowerIterationClustering private[clustering] ( } } -@Experimental +@Since("1.3.0") object PowerIterationClustering extends Logging { /** - * :: Experimental :: * Cluster assignment. * @param id node id * @param cluster assigned cluster id */ - @Experimental + @Since("1.3.0") case class Assignment(id: Long, cluster: Int) /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index d9b34cec6489..80843719f50b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering import scala.reflect.ClassTag import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaSparkContext._ import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} import org.apache.spark.rdd.RDD @@ -30,8 +30,6 @@ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom /** - * :: Experimental :: - * * StreamingKMeansModel extends MLlib's KMeansModel for streaming * algorithms, so it can keep track of a continuously updated weight * associated with each cluster, and also update the model by @@ -63,14 +61,17 @@ import org.apache.spark.util.random.XORShiftRandom * such that at time t + h the discount applied to the data from t is 0.5. * The definition remains the same whether the time unit is given * as batches or points. 
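+ *
+ * For example, a half life of 5 batches corresponds to a per-batch decay factor of
+ * 0.5^(1/5) ~= 0.87 (see setHalfLife below).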
- * */ -@Experimental -class StreamingKMeansModel( - override val clusterCenters: Array[Vector], - val clusterWeights: Array[Double]) extends KMeansModel(clusterCenters) with Logging { +@Since("1.2.0") +class StreamingKMeansModel @Since("1.2.0") ( + @Since("1.2.0") override val clusterCenters: Array[Vector], + @Since("1.2.0") val clusterWeights: Array[Double]) + extends KMeansModel(clusterCenters) with Logging { - /** Perform a k-means update on a batch of data. */ + /** + * Perform a k-means update on a batch of data. + */ + @Since("1.2.0") def update(data: RDD[Vector], decayFactor: Double, timeUnit: String): StreamingKMeansModel = { // find nearest cluster to each point @@ -82,6 +83,7 @@ class StreamingKMeansModel( (p1._1, p1._2 + p2._2) } val dim = clusterCenters(0).size + val pointStats: Array[(Int, (Vector, Long))] = closest .aggregateByKey((Vectors.zeros(dim), 0L))(mergeContribs, mergeContribs) .collect() @@ -144,8 +146,6 @@ class StreamingKMeansModel( } /** - * :: Experimental :: - * * StreamingKMeans provides methods for configuring a * streaming k-means analysis, training the model on streaming, * and using the model to make predictions on streaming data. @@ -162,29 +162,39 @@ class StreamingKMeansModel( * .trainOn(DStream) * }}} */ -@Experimental -class StreamingKMeans( - var k: Int, - var decayFactor: Double, - var timeUnit: String) extends Logging with Serializable { +@Since("1.2.0") +class StreamingKMeans @Since("1.2.0") ( + @Since("1.2.0") var k: Int, + @Since("1.2.0") var decayFactor: Double, + @Since("1.2.0") var timeUnit: String) extends Logging with Serializable { + @Since("1.2.0") def this() = this(2, 1.0, StreamingKMeans.BATCHES) protected var model: StreamingKMeansModel = new StreamingKMeansModel(null, null) - /** Set the number of clusters. */ + /** + * Set the number of clusters. + */ + @Since("1.2.0") def setK(k: Int): this.type = { this.k = k this } - /** Set the decay factor directly (for forgetful algorithms). */ + /** + * Set the decay factor directly (for forgetful algorithms). + */ + @Since("1.2.0") def setDecayFactor(a: Double): this.type = { this.decayFactor = a this } - /** Set the half life and time unit ("batches" or "points") for forgetful algorithms. */ + /** + * Set the half life and time unit ("batches" or "points") for forgetful algorithms. + */ + @Since("1.2.0") def setHalfLife(halfLife: Double, timeUnit: String): this.type = { if (timeUnit != StreamingKMeans.BATCHES && timeUnit != StreamingKMeans.POINTS) { throw new IllegalArgumentException("Invalid time unit for decay: " + timeUnit) @@ -195,7 +205,10 @@ class StreamingKMeans( this } - /** Specify initial centers directly. */ + /** + * Specify initial centers directly. + */ + @Since("1.2.0") def setInitialCenters(centers: Array[Vector], weights: Array[Double]): this.type = { model = new StreamingKMeansModel(centers, weights) this @@ -208,6 +221,7 @@ class StreamingKMeans( * @param weight Weight for each center * @param seed Random seed */ + @Since("1.2.0") def setRandomCenters(dim: Int, weight: Double, seed: Long = Utils.random.nextLong): this.type = { val random = new XORShiftRandom(seed) val centers = Array.fill(k)(Vectors.dense(Array.fill(dim)(random.nextGaussian()))) @@ -216,7 +230,10 @@ class StreamingKMeans( this } - /** Return the latest model. */ + /** + * Return the latest model. 
+ */ + @Since("1.2.0") def latestModel(): StreamingKMeansModel = { model } @@ -229,6 +246,7 @@ class StreamingKMeans( * * @param data DStream containing vector data */ + @Since("1.2.0") def trainOn(data: DStream[Vector]) { assertInitialized() data.foreachRDD { (rdd, time) => @@ -236,7 +254,10 @@ class StreamingKMeans( } } - /** Java-friendly version of `trainOn`. */ + /** + * Java-friendly version of `trainOn`. + */ + @Since("1.4.0") def trainOn(data: JavaDStream[Vector]): Unit = trainOn(data.dstream) /** @@ -245,12 +266,16 @@ class StreamingKMeans( * @param data DStream containing vector data * @return DStream containing predictions */ + @Since("1.2.0") def predictOn(data: DStream[Vector]): DStream[Int] = { assertInitialized() data.map(model.predict) } - /** Java-friendly version of `predictOn`. */ + /** + * Java-friendly version of `predictOn`. + */ + @Since("1.4.0") def predictOn(data: JavaDStream[Vector]): JavaDStream[java.lang.Integer] = { JavaDStream.fromDStream(predictOn(data.dstream).asInstanceOf[DStream[java.lang.Integer]]) } @@ -262,12 +287,16 @@ class StreamingKMeans( * @tparam K key type * @return DStream containing the input keys and the predictions as values */ + @Since("1.2.0") def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Int)] = { assertInitialized() data.mapValues(model.predict) } - /** Java-friendly version of `predictOnValues`. */ + /** + * Java-friendly version of `predictOnValues`. + */ + @Since("1.4.0") def predictOnValues[K]( data: JavaPairDStream[K, Vector]): JavaPairDStream[K, java.lang.Integer] = { implicit val tag = fakeClassTag[K] diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala index c1d1a224817e..12cf22095720 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala @@ -17,15 +17,13 @@ package org.apache.spark.mllib.evaluation -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.Since import org.apache.spark.Logging -import org.apache.spark.SparkContext._ import org.apache.spark.mllib.evaluation.binary._ import org.apache.spark.rdd.{RDD, UnionRDD} import org.apache.spark.sql.DataFrame /** - * :: Experimental :: * Evaluator for binary classification. * * @param scoreAndLabels an RDD of (score, label) pairs. @@ -42,16 +40,17 @@ import org.apache.spark.sql.DataFrame * be smaller as a result, meaning there may be an extra sample at * partition boundaries. */ -@Experimental -class BinaryClassificationMetrics( - val scoreAndLabels: RDD[(Double, Double)], - val numBins: Int) extends Logging { +@Since("1.0.0") +class BinaryClassificationMetrics @Since("1.3.0") ( + @Since("1.3.0") val scoreAndLabels: RDD[(Double, Double)], + @Since("1.3.0") val numBins: Int) extends Logging { require(numBins >= 0, "numBins must be nonnegative") /** * Defaults `numBins` to 0. */ + @Since("1.0.0") def this(scoreAndLabels: RDD[(Double, Double)]) = this(scoreAndLabels, 0) /** @@ -61,12 +60,18 @@ class BinaryClassificationMetrics( private[mllib] def this(scoreAndLabels: DataFrame) = this(scoreAndLabels.map(r => (r.getDouble(0), r.getDouble(1)))) - /** Unpersist intermediate RDDs used in the computation. */ + /** + * Unpersist intermediate RDDs used in the computation. 
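+ *
+ * Illustrative usage (the input name is hypothetical):
+ * {{{
+ *   val metrics = new BinaryClassificationMetrics(scoreAndLabels)
+ *   val auROC = metrics.areaUnderROC()
+ *   metrics.unpersist()  // release the cached cumulative counts once finished
+ * }}}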
+ */ + @Since("1.0.0") def unpersist() { cumulativeCounts.unpersist() } - /** Returns thresholds in descending order. */ + /** + * Returns thresholds in descending order. + */ + @Since("1.0.0") def thresholds(): RDD[Double] = cumulativeCounts.map(_._1) /** @@ -75,6 +80,7 @@ class BinaryClassificationMetrics( * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it. * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic */ + @Since("1.0.0") def roc(): RDD[(Double, Double)] = { val rocCurve = createCurve(FalsePositiveRate, Recall) val sc = confusions.context @@ -86,6 +92,7 @@ class BinaryClassificationMetrics( /** * Computes the area under the receiver operating characteristic (ROC) curve. */ + @Since("1.0.0") def areaUnderROC(): Double = AreaUnderCurve.of(roc()) /** @@ -93,6 +100,7 @@ class BinaryClassificationMetrics( * NOT (precision, recall), with (0.0, 1.0) prepended to it. * @see http://en.wikipedia.org/wiki/Precision_and_recall */ + @Since("1.0.0") def pr(): RDD[(Double, Double)] = { val prCurve = createCurve(Recall, Precision) val sc = confusions.context @@ -103,6 +111,7 @@ class BinaryClassificationMetrics( /** * Computes the area under the precision-recall curve. */ + @Since("1.0.0") def areaUnderPR(): Double = AreaUnderCurve.of(pr()) /** @@ -111,15 +120,25 @@ class BinaryClassificationMetrics( * @return an RDD of (threshold, F-Measure) pairs. * @see http://en.wikipedia.org/wiki/F1_score */ + @Since("1.0.0") def fMeasureByThreshold(beta: Double): RDD[(Double, Double)] = createCurve(FMeasure(beta)) - /** Returns the (threshold, F-Measure) curve with beta = 1.0. */ + /** + * Returns the (threshold, F-Measure) curve with beta = 1.0. + */ + @Since("1.0.0") def fMeasureByThreshold(): RDD[(Double, Double)] = fMeasureByThreshold(1.0) - /** Returns the (threshold, precision) curve. */ + /** + * Returns the (threshold, precision) curve. + */ + @Since("1.0.0") def precisionByThreshold(): RDD[(Double, Double)] = createCurve(Precision) - /** Returns the (threshold, recall) curve. */ + /** + * Returns the (threshold, recall) curve. + */ + @Since("1.0.0") def recallByThreshold(): RDD[(Double, Double)] = createCurve(Recall) private lazy val ( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala index 4628dc569091..c5104960cfcb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala @@ -19,8 +19,7 @@ package org.apache.spark.mllib.evaluation import scala.collection.Map -import org.apache.spark.SparkContext._ -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.{Matrices, Matrix} import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame @@ -31,8 +30,8 @@ import org.apache.spark.sql.DataFrame * * @param predictionAndLabels an RDD of (prediction, label) pairs. */ -@Experimental -class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { +@Since("1.1.0") +class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Double)]) { /** * An auxiliary constructor taking a DataFrame. 
@@ -65,6 +64,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { * they are ordered by class label ascending, * as in "labels" */ + @Since("1.1.0") def confusionMatrix: Matrix = { val n = labels.size val values = Array.ofDim[Double](n * n) @@ -84,12 +84,14 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { * Returns true positive rate for a given label (category) * @param label the label. */ + @Since("1.1.0") def truePositiveRate(label: Double): Double = recall(label) /** * Returns false positive rate for a given label (category) * @param label the label. */ + @Since("1.1.0") def falsePositiveRate(label: Double): Double = { val fp = fpByClass.getOrElse(label, 0) fp.toDouble / (labelCount - labelCountByClass(label)) @@ -99,6 +101,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { * Returns precision for a given label (category) * @param label the label. */ + @Since("1.1.0") def precision(label: Double): Double = { val tp = tpByClass(label) val fp = fpByClass.getOrElse(label, 0) @@ -109,6 +112,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { * Returns recall for a given label (category) * @param label the label. */ + @Since("1.1.0") def recall(label: Double): Double = tpByClass(label).toDouble / labelCountByClass(label) /** @@ -116,6 +120,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { * @param label the label. * @param beta the beta parameter. */ + @Since("1.1.0") def fMeasure(label: Double, beta: Double): Double = { val p = precision(label) val r = recall(label) @@ -127,11 +132,13 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { * Returns f1-measure for a given label (category) * @param label the label. */ + @Since("1.1.0") def fMeasure(label: Double): Double = fMeasure(label, 1.0) /** * Returns precision */ + @Since("1.1.0") lazy val precision: Double = tpByClass.values.sum.toDouble / labelCount /** @@ -140,23 +147,27 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { * because sum of all false positives is equal to sum * of all false negatives) */ + @Since("1.1.0") lazy val recall: Double = precision /** * Returns f-measure * (equals to precision and recall because precision equals recall) */ + @Since("1.1.0") lazy val fMeasure: Double = precision /** * Returns weighted true positive rate * (equals to precision, recall and f-measure) */ + @Since("1.1.0") lazy val weightedTruePositiveRate: Double = weightedRecall /** * Returns weighted false positive rate */ + @Since("1.1.0") lazy val weightedFalsePositiveRate: Double = labelCountByClass.map { case (category, count) => falsePositiveRate(category) * count.toDouble / labelCount }.sum @@ -165,6 +176,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { * Returns weighted averaged recall * (equals to precision, recall and f-measure) */ + @Since("1.1.0") lazy val weightedRecall: Double = labelCountByClass.map { case (category, count) => recall(category) * count.toDouble / labelCount }.sum @@ -172,6 +184,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { /** * Returns weighted averaged precision */ + @Since("1.1.0") lazy val weightedPrecision: Double = labelCountByClass.map { case (category, count) => precision(category) * count.toDouble / labelCount }.sum @@ -180,6 +193,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { * Returns weighted averaged f-measure * @param beta the beta parameter. 
*/ + @Since("1.1.0") def weightedFMeasure(beta: Double): Double = labelCountByClass.map { case (category, count) => fMeasure(category, beta) * count.toDouble / labelCount }.sum @@ -187,6 +201,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { /** * Returns weighted averaged f1-measure */ + @Since("1.1.0") lazy val weightedFMeasure: Double = labelCountByClass.map { case (category, count) => fMeasure(category, 1.0) * count.toDouble / labelCount }.sum @@ -194,5 +209,6 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { /** * Returns the sequence of labels in ascending order */ + @Since("1.1.0") lazy val labels: Array[Double] = tpByClass.keys.toArray.sorted } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala index bf6eb1d5bd2a..c100b3c9ec14 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala @@ -17,6 +17,7 @@ package org.apache.spark.mllib.evaluation +import org.apache.spark.annotation.Since import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext._ import org.apache.spark.sql.DataFrame @@ -26,7 +27,8 @@ import org.apache.spark.sql.DataFrame * @param predictionAndLabels an RDD of (predictions, labels) pairs, * both are non-null Arrays, each with unique elements. */ -class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]) { +@Since("1.2.0") +class MultilabelMetrics @Since("1.2.0") (predictionAndLabels: RDD[(Array[Double], Array[Double])]) { /** * An auxiliary constructor taking a DataFrame. @@ -44,6 +46,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] * Returns subset accuracy * (for equal sets of labels) */ + @Since("1.2.0") lazy val subsetAccuracy: Double = predictionAndLabels.filter { case (predictions, labels) => predictions.deep == labels.deep }.count().toDouble / numDocs @@ -51,6 +54,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] /** * Returns accuracy */ + @Since("1.2.0") lazy val accuracy: Double = predictionAndLabels.map { case (predictions, labels) => labels.intersect(predictions).size.toDouble / (labels.size + predictions.size - labels.intersect(predictions).size)}.sum / numDocs @@ -59,6 +63,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] /** * Returns Hamming-loss */ + @Since("1.2.0") lazy val hammingLoss: Double = predictionAndLabels.map { case (predictions, labels) => labels.size + predictions.size - 2 * labels.intersect(predictions).size }.sum / (numDocs * numLabels) @@ -66,6 +71,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] /** * Returns document-based precision averaged by the number of documents */ + @Since("1.2.0") lazy val precision: Double = predictionAndLabels.map { case (predictions, labels) => if (predictions.size > 0) { predictions.intersect(labels).size.toDouble / predictions.size @@ -77,6 +83,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] /** * Returns document-based recall averaged by the number of documents */ + @Since("1.2.0") lazy val recall: Double = predictionAndLabels.map { case (predictions, labels) => labels.intersect(predictions).size.toDouble / labels.size }.sum / numDocs @@ -84,6 +91,7 @@ class MultilabelMetrics(predictionAndLabels: 
RDD[(Array[Double], Array[Double])] /** * Returns document-based f1-measure averaged by the number of documents */ + @Since("1.2.0") lazy val f1Measure: Double = predictionAndLabels.map { case (predictions, labels) => 2.0 * predictions.intersect(labels).size / (predictions.size + labels.size) }.sum / numDocs @@ -104,6 +112,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] * Returns precision for a given label (category) * @param label the label. */ + @Since("1.2.0") def precision(label: Double): Double = { val tp = tpPerClass(label) val fp = fpPerClass.getOrElse(label, 0L) @@ -114,6 +123,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] * Returns recall for a given label (category) * @param label the label. */ + @Since("1.2.0") def recall(label: Double): Double = { val tp = tpPerClass(label) val fn = fnPerClass.getOrElse(label, 0L) @@ -124,6 +134,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] * Returns f1-measure for a given label (category) * @param label the label. */ + @Since("1.2.0") def f1Measure(label: Double): Double = { val p = precision(label) val r = recall(label) @@ -138,6 +149,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] * Returns micro-averaged label-based precision * (equals to micro-averaged document-based precision) */ + @Since("1.2.0") lazy val microPrecision: Double = { val sumFp = fpPerClass.foldLeft(0L){ case(cum, (_, fp)) => cum + fp} sumTp.toDouble / (sumTp + sumFp) @@ -147,6 +159,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] * Returns micro-averaged label-based recall * (equals to micro-averaged document-based recall) */ + @Since("1.2.0") lazy val microRecall: Double = { val sumFn = fnPerClass.foldLeft(0.0){ case(cum, (_, fn)) => cum + fn} sumTp.toDouble / (sumTp + sumFn) @@ -156,10 +169,12 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] * Returns micro-averaged label-based f1-measure * (equals to micro-averaged document-based f1-measure) */ + @Since("1.2.0") lazy val microF1Measure: Double = 2.0 * sumTp / (2 * sumTp + sumFnClass + sumFpClass) /** * Returns the sequence of labels in ascending order */ + @Since("1.2.0") lazy val labels: Array[Double] = tpPerClass.keys.toArray.sorted } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala index 5b5a2a1450f7..cc01936dd34b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.Since import org.apache.spark.api.java.{JavaSparkContext, JavaRDD} import org.apache.spark.rdd.RDD @@ -35,7 +35,7 @@ import org.apache.spark.rdd.RDD * * @param predictionAndLabels an RDD of (predicted ranking, ground truth set) pairs. 
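// Illustrative usage sketch (not part of this patch): uses the MultilabelMetrics API
// annotated above. Assumes a SparkContext `sc`; each pair is (predicted labels, true
// labels), both arrays with unique elements, and the values are made up.
import org.apache.spark.mllib.evaluation.MultilabelMetrics

val multilabelData = sc.parallelize(Seq(
  (Array(0.0, 1.0), Array(0.0, 2.0)),
  (Array(0.0, 2.0), Array(0.0, 2.0)),
  (Array.empty[Double], Array(1.0))))
val mlMetrics = new MultilabelMetrics(multilabelData)
println(s"subset accuracy = ${mlMetrics.subsetAccuracy}")
println(s"Hamming loss    = ${mlMetrics.hammingLoss}")
println(s"micro F1        = ${mlMetrics.microF1Measure}")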
*/ -@Experimental +@Since("1.2.0") class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]) extends Logging with Serializable { @@ -56,6 +56,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])] * @param k the position to compute the truncated precision, must be positive * @return the average precision at the first k ranking positions */ + @Since("1.2.0") def precisionAt(k: Int): Double = { require(k > 0, "ranking position k should be positive") predictionAndLabels.map { case (pred, lab) => @@ -125,6 +126,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])] * @param k the position to compute the truncated ndcg, must be positive * @return the average ndcg at the first k ranking positions */ + @Since("1.2.0") def ndcgAt(k: Int): Double = { require(k > 0, "ranking position k should be positive") predictionAndLabels.map { case (pred, lab) => @@ -156,13 +158,13 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])] } -@Experimental object RankingMetrics { /** * Creates a [[RankingMetrics]] instance (for Java users). * @param predictionAndLabels a JavaRDD of (predicted ranking, ground truth set) pairs */ + @Since("1.4.0") def of[E, T <: jl.Iterable[E]](predictionAndLabels: JavaRDD[(T, T)]): RankingMetrics[E] = { implicit val tag = JavaSparkContext.fakeClassTag[E] val rdd = predictionAndLabels.rdd.map { case (predictions, labels) => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala index 408847afa800..1d8f4fe340fb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.evaluation -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.Since import org.apache.spark.rdd.RDD import org.apache.spark.Logging import org.apache.spark.mllib.linalg.Vectors @@ -25,13 +25,13 @@ import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Multivariate import org.apache.spark.sql.DataFrame /** - * :: Experimental :: * Evaluator for regression. * * @param predictionAndObservations an RDD of (prediction, observation) pairs. */ -@Experimental -class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extends Logging { +@Since("1.2.0") +class RegressionMetrics @Since("1.2.0") ( + predictionAndObservations: RDD[(Double, Double)]) extends Logging { /** * An auxiliary constructor taking a DataFrame. @@ -67,6 +67,7 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend * explainedVariance = \sum_i (\hat{y_i} - \bar{y})^2 / n * @see [[https://en.wikipedia.org/wiki/Fraction_of_variance_unexplained]] */ + @Since("1.2.0") def explainedVariance: Double = { SSreg / summary.count } @@ -75,6 +76,7 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend * Returns the mean absolute error, which is a risk function corresponding to the * expected value of the absolute error loss or l1-norm loss. 
*/ + @Since("1.2.0") def meanAbsoluteError: Double = { summary.normL1(1) / summary.count } @@ -83,6 +85,7 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend * Returns the mean squared error, which is a risk function corresponding to the * expected value of the squared error loss or quadratic loss. */ + @Since("1.2.0") def meanSquaredError: Double = { SSerr / summary.count } @@ -91,6 +94,7 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend * Returns the root mean squared error, which is defined as the square root of * the mean squared error. */ + @Since("1.2.0") def rootMeanSquaredError: Double = { math.sqrt(this.meanSquaredError) } @@ -99,6 +103,7 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend * Returns R^2^, the unadjusted coefficient of determination. * @see [[http://en.wikipedia.org/wiki/Coefficient_of_determination]] */ + @Since("1.2.0") def r2: Double = { 1 - SSerr / SStot } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index 5f8c1dea237b..eaa99cfe82e2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -19,20 +19,27 @@ package org.apache.spark.mllib.feature import scala.collection.mutable.ArrayBuilder -import org.apache.spark.annotation.Experimental +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.Statistics +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext +import org.apache.spark.sql.{SQLContext, Row} /** - * :: Experimental :: * Chi Squared selector model. * * @param selectedFeatures list of indices to select (filter). Must be ordered asc */ -@Experimental -class ChiSqSelectorModel (val selectedFeatures: Array[Int]) extends VectorTransformer { +@Since("1.3.0") +class ChiSqSelectorModel @Since("1.3.0") ( + @Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer with Saveable { require(isSorted(selectedFeatures), "Array has to be sorted asc") @@ -52,6 +59,7 @@ class ChiSqSelectorModel (val selectedFeatures: Array[Int]) extends VectorTransf * @param vector vector to be transformed. * @return transformed vector. 
*/ + @Since("1.3.0") override def transform(vector: Vector): Vector = { compress(vector, selectedFeatures) } @@ -99,16 +107,79 @@ class ChiSqSelectorModel (val selectedFeatures: Array[Int]) extends VectorTransf s"Only sparse and dense vectors are supported but got ${other.getClass}.") } } + + @Since("1.6.0") + override def save(sc: SparkContext, path: String): Unit = { + ChiSqSelectorModel.SaveLoadV1_0.save(sc, this, path) + } + + override protected def formatVersion: String = "1.0" +} + +object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { + @Since("1.6.0") + override def load(sc: SparkContext, path: String): ChiSqSelectorModel = { + ChiSqSelectorModel.SaveLoadV1_0.load(sc, path) + } + + private[feature] + object SaveLoadV1_0 { + + private val thisFormatVersion = "1.0" + + /** Model data for import/export */ + case class Data(feature: Int) + + private[feature] + val thisClassName = "org.apache.spark.mllib.feature.ChiSqSelectorModel" + + def save(sc: SparkContext, model: ChiSqSelectorModel, path: String): Unit = { + val sqlContext = SQLContext.getOrCreate(sc) + import sqlContext.implicits._ + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + // Create Parquet data. + val dataArray = Array.tabulate(model.selectedFeatures.length) { i => + Data(model.selectedFeatures(i)) + } + sc.parallelize(dataArray, 1).toDF().write.parquet(Loader.dataPath(path)) + + } + + def load(sc: SparkContext, path: String): ChiSqSelectorModel = { + implicit val formats = DefaultFormats + val sqlContext = SQLContext.getOrCreate(sc) + val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) + assert(className == thisClassName) + assert(formatVersion == thisFormatVersion) + + val dataFrame = sqlContext.read.parquet(Loader.dataPath(path)) + val dataArray = dataFrame.select("feature") + + // Check schema explicitly since erasure makes it hard to use match-case for checking. + Loader.checkSchema[Data](dataFrame.schema) + + val features = dataArray.map { + case Row(feature: Int) => (feature) + }.collect() + + return new ChiSqSelectorModel(features) + } + } } /** - * :: Experimental :: * Creates a ChiSquared feature selector. * @param numTopFeatures number of features that selector will select * (ordered by statistic value descending) + * Note that if the number of features is < numTopFeatures, then this will + * select all features. */ -@Experimental -class ChiSqSelector (val numTopFeatures: Int) extends Serializable { +@Since("1.3.0") +class ChiSqSelector @Since("1.3.0") ( + @Since("1.3.0") val numTopFeatures: Int) extends Serializable { /** * Returns a ChiSquared feature selector. @@ -117,6 +188,7 @@ class ChiSqSelector (val numTopFeatures: Int) extends Serializable { * Real-valued features will be treated as categorical for each distinct value. * Apply feature discretizer before using this function. 
*/ + @Since("1.3.0") def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = { val indices = Statistics.chiSqTest(data) .zipWithIndex.sortBy { case (res, _) => -res.statistic } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala index d67fe6c3ee4f..c757fc7f06c5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala @@ -17,18 +17,18 @@ package org.apache.spark.mllib.feature -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg._ /** - * :: Experimental :: * Outputs the Hadamard product (i.e., the element-wise product) of each input vector with a * provided "weight" vector. In other words, it scales each column of the dataset by a scalar * multiplier. * @param scalingVec The values used to scale the reference vector's individual components. */ -@Experimental -class ElementwiseProduct(val scalingVec: Vector) extends VectorTransformer { +@Since("1.4.0") +class ElementwiseProduct @Since("1.4.0") ( + @Since("1.4.0") val scalingVec: Vector) extends VectorTransformer { /** * Does the hadamard product transformation. @@ -36,6 +36,7 @@ class ElementwiseProduct(val scalingVec: Vector) extends VectorTransformer { * @param vector vector to be transformed. * @return transformed vector. */ + @Since("1.4.0") override def transform(vector: Vector): Vector = { require(vector.size == scalingVec.size, s"vector sizes do not match: Expected ${scalingVec.size} but found ${vector.size}") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala index c53475818395..c93ed64183ad 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala @@ -22,31 +22,35 @@ import java.lang.{Iterable => JavaIterable} import scala.collection.JavaConverters._ import scala.collection.mutable -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils /** - * :: Experimental :: * Maps a sequence of terms to their term frequencies using the hashing trick. * * @param numFeatures number of features (default: 2^20^) */ -@Experimental +@Since("1.1.0") class HashingTF(val numFeatures: Int) extends Serializable { + /** + */ + @Since("1.1.0") def this() = this(1 << 20) /** * Returns the index of the input term. */ + @Since("1.1.0") def indexOf(term: Any): Int = Utils.nonNegativeMod(term.##, numFeatures) /** * Transforms the input document into a sparse term frequency vector. */ + @Since("1.1.0") def transform(document: Iterable[_]): Vector = { val termFrequencies = mutable.HashMap.empty[Int, Double] document.foreach { term => @@ -59,6 +63,7 @@ class HashingTF(val numFeatures: Int) extends Serializable { /** * Transforms the input document into a sparse term frequency vector (Java version). */ + @Since("1.1.0") def transform(document: JavaIterable[_]): Vector = { transform(document.asScala) } @@ -66,6 +71,7 @@ class HashingTF(val numFeatures: Int) extends Serializable { /** * Transforms the input document to term frequency vectors. 
*/ + @Since("1.1.0") def transform[D <: Iterable[_]](dataset: RDD[D]): RDD[Vector] = { dataset.map(this.transform) } @@ -73,6 +79,7 @@ class HashingTF(val numFeatures: Int) extends Serializable { /** * Transforms the input document to term frequency vectors (Java version). */ + @Since("1.1.0") def transform[D <: JavaIterable[_]](dataset: JavaRDD[D]): JavaRDD[Vector] = { dataset.rdd.map(this.transform).toJavaRDD() } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala index 3fab7ea79bef..cffa9fba05c8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala @@ -19,13 +19,12 @@ package org.apache.spark.mllib.feature import breeze.linalg.{DenseVector => BDV} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.rdd.RDD /** - * :: Experimental :: * Inverse document frequency (IDF). * The standard formulation is used: `idf = log((m + 1) / (d(t) + 1))`, where `m` is the total * number of documents and `d(t)` is the number of documents that contain term `t`. @@ -37,9 +36,10 @@ import org.apache.spark.rdd.RDD * @param minDocFreq minimum of documents in which a term * should appear for filtering */ -@Experimental -class IDF(val minDocFreq: Int) { +@Since("1.1.0") +class IDF @Since("1.2.0") (@Since("1.2.0") val minDocFreq: Int) { + @Since("1.1.0") def this() = this(0) // TODO: Allow different IDF formulations. @@ -48,6 +48,7 @@ class IDF(val minDocFreq: Int) { * Computes the inverse document frequency. * @param dataset an RDD of term frequency vectors */ + @Since("1.1.0") def fit(dataset: RDD[Vector]): IDFModel = { val idf = dataset.treeAggregate(new IDF.DocumentFrequencyAggregator( minDocFreq = minDocFreq))( @@ -61,6 +62,7 @@ class IDF(val minDocFreq: Int) { * Computes the inverse document frequency. * @param dataset a JavaRDD of term frequency vectors */ + @Since("1.1.0") def fit(dataset: JavaRDD[Vector]): IDFModel = { fit(dataset.rdd) } @@ -155,11 +157,10 @@ private object IDF { } /** - * :: Experimental :: * Represents an IDF model that can transform term frequency vectors. */ -@Experimental -class IDFModel private[spark] (val idf: Vector) extends Serializable { +@Since("1.1.0") +class IDFModel private[spark] (@Since("1.1.0") val idf: Vector) extends Serializable { /** * Transforms term frequency (TF) vectors to TF-IDF vectors. 
@@ -171,6 +172,7 @@ class IDFModel private[spark] (val idf: Vector) extends Serializable { * @param dataset an RDD of term frequency vectors * @return an RDD of TF-IDF vectors */ + @Since("1.1.0") def transform(dataset: RDD[Vector]): RDD[Vector] = { val bcIdf = dataset.context.broadcast(idf) dataset.mapPartitions(iter => iter.map(v => IDFModel.transform(bcIdf.value, v))) @@ -182,6 +184,7 @@ class IDFModel private[spark] (val idf: Vector) extends Serializable { * @param v a term frequency vector * @return a TF-IDF vector */ + @Since("1.3.0") def transform(v: Vector): Vector = IDFModel.transform(idf, v) /** @@ -189,6 +192,7 @@ class IDFModel private[spark] (val idf: Vector) extends Serializable { * @param dataset a JavaRDD of term frequency vectors * @return a JavaRDD of TF-IDF vectors */ + @Since("1.1.0") def transform(dataset: JavaRDD[Vector]): JavaRDD[Vector] = { transform(dataset.rdd).toJavaRDD() } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala index 32848e039eb8..af0c8e1d8a9d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala @@ -17,11 +17,10 @@ package org.apache.spark.mllib.feature -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} /** - * :: Experimental :: * Normalizes samples individually to unit L^p^ norm * * For any 1 <= p < Double.PositiveInfinity, normalizes samples using @@ -31,9 +30,10 @@ import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors * * @param p Normalization in L^p^ space, p = 2 by default. */ -@Experimental -class Normalizer(p: Double) extends VectorTransformer { +@Since("1.1.0") +class Normalizer @Since("1.1.0") (p: Double) extends VectorTransformer { + @Since("1.1.0") def this() = this(2) require(p >= 1.0) @@ -44,6 +44,7 @@ class Normalizer(p: Double) extends VectorTransformer { * @param vector vector to be normalized. * @return normalized vector. If the norm of the input is zero, it will return the input vector. 
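// Illustrative usage sketch (not part of this patch): uses the Normalizer defined above.
// The default constructor normalizes to unit L2 norm; the vectors are made up.
import org.apache.spark.mllib.feature.Normalizer
import org.apache.spark.mllib.linalg.Vectors

val l2Normalizer = new Normalizer()                         // p = 2 by default
val unitVector = l2Normalizer.transform(Vectors.dense(3.0, 4.0))
// unitVector == [0.6, 0.8]
val l1Normalizer = new Normalizer(1.0)                      // L1 normalization instead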
*/ + @Since("1.1.0") override def transform(vector: Vector): Vector = { val norm = Vectors.norm(vector, p) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala index 2a66263d8b7d..24e0a98c39bf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala @@ -17,6 +17,7 @@ package org.apache.spark.mllib.feature +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.distributed.RowMatrix @@ -27,7 +28,8 @@ import org.apache.spark.rdd.RDD * * @param k number of principal components */ -class PCA(val k: Int) { +@Since("1.4.0") +class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) { require(k >= 1, s"PCA requires a number of principal components k >= 1 but was given $k") /** @@ -35,12 +37,14 @@ class PCA(val k: Int) { * * @param sources source vectors */ + @Since("1.4.0") def fit(sources: RDD[Vector]): PCAModel = { require(k <= sources.first().size, s"source vector size is ${sources.first().size} must be greater than k=$k") val mat = new RowMatrix(sources) - val pc = mat.computePrincipalComponents(k) match { + val (pc, explainedVariance) = mat.computePrincipalComponentsAndExplainedVariance(k) + val densePC = pc match { case dm: DenseMatrix => dm case sm: SparseMatrix => @@ -55,10 +59,19 @@ class PCA(val k: Int) { s"SparseMatrix or DenseMatrix. Instead got: ${m.getClass}") } - new PCAModel(k, pc) + val denseExplainedVariance = explainedVariance match { + case dv: DenseVector => + dv + case sv: SparseVector => + sv.toDense + } + new PCAModel(k, densePC, denseExplainedVariance) } - /** Java-friendly version of [[fit()]] */ + /** + * Java-friendly version of [[fit()]] + */ + @Since("1.4.0") def fit(sources: JavaRDD[Vector]): PCAModel = fit(sources.rdd) } @@ -68,7 +81,11 @@ class PCA(val k: Int) { * @param k number of principal components. * @param pc a principal components Matrix. Each column is one principal component. */ -class PCAModel private[spark] (val k: Int, val pc: DenseMatrix) extends VectorTransformer { +@Since("1.4.0") +class PCAModel private[spark] ( + @Since("1.4.0") val k: Int, + @Since("1.4.0") val pc: DenseMatrix, + @Since("1.6.0") val explainedVariance: DenseVector) extends VectorTransformer { /** * Transform a vector by computed Principal Components. * @@ -76,6 +93,7 @@ class PCAModel private[spark] (val k: Int, val pc: DenseMatrix) extends VectorTr * Vector must be the same length as the source vectors given to [[PCA.fit()]]. * @return transformed vector. Vector will be of length k. 
*/ + @Since("1.4.0") override def transform(vector: Vector): Vector = { vector match { case dv: DenseVector => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala index c73b8f258060..6fe573c52894 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala @@ -18,13 +18,12 @@ package org.apache.spark.mllib.feature import org.apache.spark.Logging -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD /** - * :: Experimental :: * Standardizes features by removing the mean and scaling to unit std using column summary * statistics on the samples in the training set. * @@ -32,9 +31,10 @@ import org.apache.spark.rdd.RDD * dense output, so this does not work on sparse input and will raise an exception. * @param withStd True by default. Scales the data to unit standard deviation. */ -@Experimental -class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging { +@Since("1.1.0") +class StandardScaler @Since("1.1.0") (withMean: Boolean, withStd: Boolean) extends Logging { + @Since("1.1.0") def this() = this(false, true) if (!(withMean || withStd)) { @@ -47,6 +47,7 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging { * @param data The data used to compute the mean and variance to build the transformation model. * @return a StandardScalarModel */ + @Since("1.1.0") def fit(data: RDD[Vector]): StandardScalerModel = { // TODO: skip computation if both withMean and withStd are false val summary = data.treeAggregate(new MultivariateOnlineSummarizer)( @@ -61,7 +62,6 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging { } /** - * :: Experimental :: * Represents a StandardScaler model that can transform vectors. * * @param std column standard deviation values @@ -69,13 +69,16 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging { * @param withStd whether to scale the data to have unit standard deviation * @param withMean whether to center the data before scaling */ -@Experimental -class StandardScalerModel ( - val std: Vector, - val mean: Vector, - var withStd: Boolean, - var withMean: Boolean) extends VectorTransformer { +@Since("1.1.0") +class StandardScalerModel @Since("1.3.0") ( + @Since("1.3.0") val std: Vector, + @Since("1.1.0") val mean: Vector, + @Since("1.3.0") var withStd: Boolean, + @Since("1.3.0") var withMean: Boolean) extends VectorTransformer { + /** + */ + @Since("1.3.0") def this(std: Vector, mean: Vector) { this(std, mean, withStd = std != null, withMean = mean != null) require(this.withStd || this.withMean, @@ -86,8 +89,10 @@ class StandardScalerModel ( } } + @Since("1.3.0") def this(std: Vector) = this(std, null) + @Since("1.3.0") @DeveloperApi def setWithMean(withMean: Boolean): this.type = { require(!(withMean && this.mean == null), "cannot set withMean to true while mean is null") @@ -95,6 +100,7 @@ class StandardScalerModel ( this } + @Since("1.3.0") @DeveloperApi def setWithStd(withStd: Boolean): this.type = { require(!(withStd && this.std == null), @@ -115,6 +121,7 @@ class StandardScalerModel ( * @return Standardized vector. 
If the std of a column is zero, it will return default `0.0` * for the column with zero std. */ + @Since("1.1.0") override def transform(vector: Vector): Vector = { require(mean.size == vector.size) if (withMean) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala index 7358c1c84f79..5778fd1d0925 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.feature -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector import org.apache.spark.rdd.RDD @@ -26,6 +26,7 @@ import org.apache.spark.rdd.RDD * :: DeveloperApi :: * Trait for transformation of a vector */ +@Since("1.1.0") @DeveloperApi trait VectorTransformer extends Serializable { @@ -35,6 +36,7 @@ trait VectorTransformer extends Serializable { * @param vector vector to be transformed. * @return transformed vector. */ + @Since("1.1.0") def transform(vector: Vector): Vector /** @@ -43,6 +45,7 @@ trait VectorTransformer extends Serializable { * @param data RDD[Vector] to be transformed. * @return transformed RDD[Vector]. */ + @Since("1.1.0") def transform(data: RDD[Vector]): RDD[Vector] = { // Later in #1498 , all RDD objects are sent via broadcasting instead of akka. // So it should be no longer necessary to explicitly broadcast `this` object. @@ -55,6 +58,7 @@ trait VectorTransformer extends Serializable { * @param data JavaRDD[Vector] to be transformed. * @return transformed JavaRDD[Vector]. */ + @Since("1.1.0") def transform(data: JavaRDD[Vector]): JavaRDD[Vector] = { transform(data.rdd) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index cbbd2b0c8d06..1f400e1430eb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -31,15 +31,14 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.Logging import org.apache.spark.SparkContext -import org.apache.spark.SparkContext._ -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD -import org.apache.spark.mllib.linalg.{Vector, Vectors, DenseMatrix, BLAS, DenseVector} +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd._ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom -import org.apache.spark.sql.{SQLContext, Row} +import org.apache.spark.sql.SQLContext /** * Entry in vocabulary @@ -53,7 +52,6 @@ private case class VocabWord( ) /** - * :: Experimental :: * Word2Vec creates vector representation of words in a text corpus. * The algorithm first constructs a vocabulary from the corpus * and then learns vector representation of words in the vocabulary. @@ -70,7 +68,7 @@ private case class VocabWord( * and * Distributed Representations of Words and Phrases and their Compositionality. 
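// Illustrative usage sketch (not part of this patch): fits the StandardScaler annotated
// earlier in this diff. withMean = true produces dense output, so it is used here on
// small dense vectors only. Assumes a SparkContext `sc`; the samples are made up.
import org.apache.spark.mllib.feature.StandardScaler
import org.apache.spark.mllib.linalg.Vectors

val samples = sc.parallelize(Seq(
  Vectors.dense(1.0, 10.0),
  Vectors.dense(2.0, 20.0),
  Vectors.dense(3.0, 30.0)))
val scalerModel = new StandardScaler(withMean = true, withStd = true).fit(samples)
val standardized = samples.map(scalerModel.transform)   // zero mean, unit std per column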
*/ -@Experimental +@Since("1.1.0") class Word2Vec extends Serializable with Logging { private var vectorSize = 100 @@ -83,6 +81,7 @@ class Word2Vec extends Serializable with Logging { /** * Sets vector size (default: 100). */ + @Since("1.1.0") def setVectorSize(vectorSize: Int): this.type = { this.vectorSize = vectorSize this @@ -91,6 +90,7 @@ class Word2Vec extends Serializable with Logging { /** * Sets initial learning rate (default: 0.025). */ + @Since("1.1.0") def setLearningRate(learningRate: Double): this.type = { this.learningRate = learningRate this @@ -99,6 +99,7 @@ class Word2Vec extends Serializable with Logging { /** * Sets number of partitions (default: 1). Use a small number for accuracy. */ + @Since("1.1.0") def setNumPartitions(numPartitions: Int): this.type = { require(numPartitions > 0, s"numPartitions must be greater than 0 but got $numPartitions") this.numPartitions = numPartitions @@ -109,6 +110,7 @@ class Word2Vec extends Serializable with Logging { * Sets number of iterations (default: 1), which should be smaller than or equal to number of * partitions. */ + @Since("1.1.0") def setNumIterations(numIterations: Int): this.type = { this.numIterations = numIterations this @@ -117,15 +119,26 @@ class Word2Vec extends Serializable with Logging { /** * Sets random seed (default: a random long integer). */ + @Since("1.1.0") def setSeed(seed: Long): this.type = { this.seed = seed this } + /** + * Sets the window of words (default: 5) + */ + @Since("1.6.0") + def setWindowSize(window: Int): this.type = { + this.window = window + this + } + /** * Sets minCount, the minimum number of times a token must appear to be included in the word2vec * model's vocabulary (default: 5). */ + @Since("1.3.0") def setMinCount(minCount: Int): this.type = { this.minCount = minCount this @@ -137,12 +150,12 @@ class Word2Vec extends Serializable with Logging { private val MAX_SENTENCE_LENGTH = 1000 /** context words from [-window, window] */ - private val window = 5 + private var window = 5 private var trainWordsCount = 0 private var vocabSize = 0 - private var vocab: Array[VocabWord] = null - private var vocabHash = mutable.HashMap.empty[String, Int] + @transient private var vocab: Array[VocabWord] = null + @transient private var vocabHash = mutable.HashMap.empty[String, Int] private def learnVocab(words: RDD[String]): Unit = { vocab = words.map(w => (w, 1)) @@ -263,6 +276,7 @@ class Word2Vec extends Serializable with Logging { * @param dataset an RDD of words * @return a Word2VecModel */ + @Since("1.1.0") def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = { val words = dataset.flatMap(x => x) @@ -301,22 +315,25 @@ class Word2Vec extends Serializable with Logging { val newSentences = sentences.repartition(numPartitions).cache() val initRandom = new XORShiftRandom(seed) - if (vocabSize.toLong * vectorSize * 8 >= Int.MaxValue) { + if (vocabSize.toLong * vectorSize >= Int.MaxValue) { throw new RuntimeException("Please increase minCount or decrease vectorSize in Word2Vec" + " to avoid an OOM. 
You are highly recommended to make your vocabSize*vectorSize, " + - "which is " + vocabSize + "*" + vectorSize + " for now, less than `Int.MaxValue/8`.") + "which is " + vocabSize + "*" + vectorSize + " for now, less than `Int.MaxValue`.") } val syn0Global = Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize) val syn1Global = new Array[Float](vocabSize * vectorSize) var alpha = learningRate + for (k <- 1 to numIterations) { + val bcSyn0Global = sc.broadcast(syn0Global) + val bcSyn1Global = sc.broadcast(syn1Global) val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) => val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8)) val syn0Modify = new Array[Int](vocabSize) val syn1Modify = new Array[Int](vocabSize) - val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) { + val model = iter.foldLeft((bcSyn0Global.value, bcSyn1Global.value, 0, 0)) { case ((syn0, syn1, lastWordCount, wordCount), sentence) => var lwc = lastWordCount var wc = wordCount @@ -400,6 +417,8 @@ class Word2Vec extends Serializable with Logging { } i += 1 } + bcSyn0Global.unpersist(false) + bcSyn1Global.unpersist(false) } newSentences.unpersist() @@ -412,13 +431,13 @@ class Word2Vec extends Serializable with Logging { * @param dataset a JavaRDD of words * @return a Word2VecModel */ + @Since("1.1.0") def fit[S <: JavaIterable[String]](dataset: JavaRDD[S]): Word2VecModel = { fit(dataset.rdd.map(_.asScala)) } } /** - * :: Experimental :: * Word2Vec model * @param wordIndex maps each word to an index, which can retrieve the corresponding * vector from wordVectors @@ -426,10 +445,10 @@ class Word2Vec extends Serializable with Logging { * to the word mapped with index i can be retrieved by the slice * (i * vectorSize, i * vectorSize + vectorSize) */ -@Experimental -class Word2VecModel private[mllib] ( - private val wordIndex: Map[String, Int], - private val wordVectors: Array[Float]) extends Serializable with Saveable { +@Since("1.1.0") +class Word2VecModel private[spark] ( + private[spark] val wordIndex: Map[String, Int], + private[spark] val wordVectors: Array[Float]) extends Serializable with Saveable { private val numWords = wordIndex.size // vectorSize: Dimension of each word's vector. 
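// Illustrative usage sketch (not part of this patch): trains Word2Vec through the setters
// shown above, including the setWindowSize parameter this patch adds. Assumes a
// SparkContext `sc` and a tiny made-up corpus; real corpora need far more text for
// meaningful vectors.
import org.apache.spark.mllib.feature.Word2Vec

val corpus = sc.parallelize(Seq(
  "spark is a fast engine".split(" ").toSeq,
  "spark runs word2vec on rdds".split(" ").toSeq))
val word2Vec = new Word2Vec()
  .setVectorSize(10)
  .setMinCount(1)
  .setWindowSize(5)
  .setSeed(42L)
val w2vModel = word2Vec.fit(corpus)
val sparkVector = w2vModel.transform("spark")
val synonyms = w2vModel.findSynonyms("spark", 2)   // Array of (word, cosineSimilarity)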
@@ -454,6 +473,7 @@ class Word2VecModel private[mllib] ( wordVecNorms } + @Since("1.5.0") def this(model: Map[String, Array[Float]]) = { this(Word2VecModel.buildWordIndex(model), Word2VecModel.buildWordVectors(model)) } @@ -469,6 +489,7 @@ class Word2VecModel private[mllib] ( override protected def formatVersion = "1.0" + @Since("1.4.0") def save(sc: SparkContext, path: String): Unit = { Word2VecModel.SaveLoadV1_0.save(sc, path, getVectors) } @@ -478,6 +499,7 @@ class Word2VecModel private[mllib] ( * @param word a word * @return vector representation of word */ + @Since("1.1.0") def transform(word: String): Vector = { wordIndex.get(word) match { case Some(ind) => @@ -494,6 +516,7 @@ class Word2VecModel private[mllib] ( * @param num number of synonyms to find * @return array of (word, cosineSimilarity) */ + @Since("1.1.0") def findSynonyms(word: String, num: Int): Array[(String, Double)] = { val vector = transform(word) findSynonyms(vector, num) @@ -505,6 +528,7 @@ class Word2VecModel private[mllib] ( * @param num number of synonyms to find * @return array of (word, cosineSimilarity) */ + @Since("1.1.0") def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = { require(num > 0, "Number of similar words should > 0") // TODO: optimize top-k @@ -534,6 +558,7 @@ class Word2VecModel private[mllib] ( /** * Returns a map of words to their vector representations. */ + @Since("1.2.0") def getVectors: Map[String, Array[Float]] = { wordIndex.map { case (word, ind) => (word, wordVectors.slice(vectorSize * ind, vectorSize * ind + vectorSize)) @@ -541,7 +566,7 @@ class Word2VecModel private[mllib] ( } } -@Experimental +@Since("1.4.0") object Word2VecModel extends Loader[Word2VecModel] { private def buildWordIndex(model: Map[String, Array[Float]]): Map[String, Int] = { @@ -571,35 +596,42 @@ object Word2VecModel extends Loader[Word2VecModel] { def load(sc: SparkContext, path: String): Word2VecModel = { val dataPath = Loader.dataPath(path) - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val dataFrame = sqlContext.read.parquet(dataPath) - - val dataArray = dataFrame.select("word", "vector").collect() - // Check schema explicitly since erasure makes it hard to use match-case for checking. 
Loader.checkSchema[Data](dataFrame.schema) + val dataArray = dataFrame.select("word", "vector").collect() val word2VecMap = dataArray.map(i => (i.getString(0), i.getSeq[Float](1).toArray)).toMap new Word2VecModel(word2VecMap) } def save(sc: SparkContext, path: String, model: Map[String, Array[Float]]): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ val vectorSize = model.values.head.size val numWords = model.size - val metadata = compact(render - (("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~ - ("vectorSize" -> vectorSize) ~ ("numWords" -> numWords))) + val metadata = compact(render( + ("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~ + ("vectorSize" -> vectorSize) ~ ("numWords" -> numWords))) sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + // We want to partition the model in partitions of size 32MB + val partitionSize = (1L << 25) + // We calculate the approximate size of the model + // We only calculate the array size, not considering + // the string size, the formula is: + // floatSize * numWords * vectorSize + val approxSize = 4L * numWords * vectorSize + val nPartitions = ((approxSize / partitionSize) + 1).toInt val dataArray = model.toSeq.map { case (w, v) => Data(w, v) } - sc.parallelize(dataArray.toSeq, 1).toDF().write.parquet(Loader.dataPath(path)) + sc.parallelize(dataArray.toSeq, nPartitions).toDF().write.parquet(Loader.dataPath(path)) } } + @Since("1.4.0") override def load(sc: SparkContext, path: String): Word2VecModel = { val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala index 72d0ea0c12e1..07eb750b06a3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala @@ -16,10 +16,11 @@ */ package org.apache.spark.mllib.fpm +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.mllib.fpm.AssociationRules.Rule @@ -32,24 +33,22 @@ import org.apache.spark.rdd.RDD * Generates association rules from a [[RDD[FreqItemset[Item]]]. This method only generates * association rules which have a single item as the consequent. * - * @since 1.5.0 */ +@Since("1.5.0") @Experimental class AssociationRules private[fpm] ( private var minConfidence: Double) extends Logging with Serializable { /** * Constructs a default instance with default parameters {minConfidence = 0.8}. - * - * @since 1.5.0 */ + @Since("1.5.0") def this() = this(0.8) /** * Sets the minimal confidence (default: `0.8`). - * - * @since 1.5.0 */ + @Since("1.5.0") def setMinConfidence(minConfidence: Double): this.type = { require(minConfidence >= 0.0 && minConfidence <= 1.0) this.minConfidence = minConfidence @@ -61,8 +60,8 @@ class AssociationRules private[fpm] ( * @param freqItemsets frequent itemset model obtained from [[FPGrowth]] * @return a [[Set[Rule[Item]]] containing the assocation rules. 
* - * @since 1.5.0 */ + @Since("1.5.0") def run[Item: ClassTag](freqItemsets: RDD[FreqItemset[Item]]): RDD[Rule[Item]] = { // For candidate rule X => Y, generate (X, (Y, freq(X union Y))) val candidates = freqItemsets.flatMap { itemset => @@ -83,31 +82,41 @@ class AssociationRules private[fpm] ( }.filter(_.confidence >= minConfidence) } + /** Java-friendly version of [[run]]. */ + @Since("1.5.0") def run[Item](freqItemsets: JavaRDD[FreqItemset[Item]]): JavaRDD[Rule[Item]] = { val tag = fakeClassTag[Item] run(freqItemsets.rdd)(tag) } } +@Since("1.5.0") object AssociationRules { /** * :: Experimental :: * * An association rule between sets of items. - * @param antecedent hypotheses of the rule - * @param consequent conclusion of the rule + * @param antecedent hypotheses of the rule. Java users should call [[Rule#javaAntecedent]] + * instead. + * @param consequent conclusion of the rule. Java users should call [[Rule#javaConsequent]] + * instead. * @tparam Item item type * - * @since 1.5.0 */ + @Since("1.5.0") @Experimental class Rule[Item] private[fpm] ( - val antecedent: Array[Item], - val consequent: Array[Item], + @Since("1.5.0") val antecedent: Array[Item], + @Since("1.5.0") val consequent: Array[Item], freqUnion: Double, freqAntecedent: Double) extends Serializable { + /** + * Returns the confidence of the rule. + * + */ + @Since("1.5.0") def confidence: Double = freqUnion.toDouble / freqAntecedent require(antecedent.toSet.intersect(consequent.toSet).isEmpty, { @@ -115,5 +124,28 @@ object AssociationRules { s"A valid association rule must have disjoint antecedent and " + s"consequent but ${sharedItems} is present in both." }) + + /** + * Returns antecedent in a Java List. + * + */ + @Since("1.5.0") + def javaAntecedent: java.util.List[Item] = { + antecedent.toList.asJava + } + + /** + * Returns consequent in a Java List. + * + */ + @Since("1.5.0") + def javaConsequent: java.util.List[Item] = { + consequent.toList.asJava + } + + override def toString: String = { + s"${antecedent.mkString("{", ",", "}")} => " + + s"${consequent.mkString("{", ",", "}")}: ${confidence}" + } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index e2370a52f493..70ef1ed30c71 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -25,7 +25,7 @@ import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.mllib.fpm.FPGrowth._ @@ -33,21 +33,18 @@ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel /** - * :: Experimental :: - * * Model trained by [[FPGrowth]], which holds frequent itemsets. * @param freqItemsets frequent itemset, which is an RDD of [[FreqItemset]] * @tparam Item item type - * - * @since 1.3.0 */ -@Experimental -class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable { +@Since("1.3.0") +class FPGrowthModel[Item: ClassTag] @Since("1.3.0") ( + @Since("1.3.0") val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable { /** * Generates association rules for the [[Item]]s in [[freqItemsets]]. 
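// Illustrative usage sketch (not part of this patch): runs the FPGrowth and FPGrowthModel
// APIs shown in the surrounding diff, then generates association rules from the mined
// itemsets. Assumes a SparkContext `sc`; the transactions are made up.
import org.apache.spark.mllib.fpm.FPGrowth

val transactions = sc.parallelize(Seq(
  Array("a", "b", "c"),
  Array("a", "b"),
  Array("a", "c"),
  Array("b", "c")))
val fpModel = new FPGrowth().setMinSupport(0.5).setNumPartitions(2).run(transactions)
fpModel.freqItemsets.collect().foreach { itemset =>
  println(s"${itemset.items.mkString("{", ",", "}")}: ${itemset.freq}")
}
fpModel.generateAssociationRules(0.8).collect().foreach { rule =>
  println(s"${rule.antecedent.mkString(",")} => ${rule.consequent.mkString(",")}: ${rule.confidence}")
}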
* @param confidence minimal confidence of the rules produced - * @since 1.5.0 */ + @Since("1.5.0") def generateAssociationRules(confidence: Double): RDD[AssociationRules.Rule[Item]] = { val associationRules = new AssociationRules(confidence) associationRules.run(freqItemsets) @@ -55,8 +52,6 @@ class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) ex } /** - * :: Experimental :: - * * A parallel FP-growth algorithm to mine frequent itemsets. The algorithm is described in * [[http://dx.doi.org/10.1145/1454008.1454027 Li et al., PFP: Parallel FP-Growth for Query * Recommendation]]. PFP distributes computation in such a way that each worker executes an @@ -71,9 +66,8 @@ class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) ex * @see [[http://en.wikipedia.org/wiki/Association_rule_learning Association rule learning * (Wikipedia)]] * - * @since 1.3.0 */ -@Experimental +@Since("1.3.0") class FPGrowth private ( private var minSupport: Double, private var numPartitions: Int) extends Logging with Serializable { @@ -82,15 +76,15 @@ class FPGrowth private ( * Constructs a default instance with default parameters {minSupport: `0.3`, numPartitions: same * as the input data}. * - * @since 1.3.0 */ + @Since("1.3.0") def this() = this(0.3, -1) /** * Sets the minimal support level (default: `0.3`). * - * @since 1.3.0 */ + @Since("1.3.0") def setMinSupport(minSupport: Double): this.type = { this.minSupport = minSupport this @@ -99,8 +93,8 @@ class FPGrowth private ( /** * Sets the number of partitions used by parallel FP-growth (default: same as input data). * - * @since 1.3.0 */ + @Since("1.3.0") def setNumPartitions(numPartitions: Int): this.type = { this.numPartitions = numPartitions this @@ -111,8 +105,8 @@ class FPGrowth private ( * @param data input data set, each element contains a transaction * @return an [[FPGrowthModel]] * - * @since 1.3.0 */ + @Since("1.3.0") def run[Item: ClassTag](data: RDD[Array[Item]]): FPGrowthModel[Item] = { if (data.getStorageLevel == StorageLevel.NONE) { logWarning("Input data is not cached.") @@ -126,6 +120,8 @@ class FPGrowth private ( new FPGrowthModel(freqItemsets) } + /** Java-friendly version of [[run]]. */ + @Since("1.3.0") def run[Item, Basket <: JavaIterable[Item]](data: JavaRDD[Basket]): FPGrowthModel[Item] = { implicit val tag = fakeClassTag[Item] run(data.rdd.map(_.asScala.toArray)) @@ -210,12 +206,7 @@ class FPGrowth private ( } } -/** - * :: Experimental :: - * - * @since 1.3.0 - */ -@Experimental +@Since("1.3.0") object FPGrowth { /** @@ -224,15 +215,17 @@ object FPGrowth { * @param freq frequency * @tparam Item item type * - * @since 1.3.0 */ - class FreqItemset[Item](val items: Array[Item], val freq: Long) extends Serializable { + @Since("1.3.0") + class FreqItemset[Item] @Since("1.3.0") ( + @Since("1.3.0") val items: Array[Item], + @Since("1.3.0") val freq: Long) extends Serializable { /** * Returns items in a Java List. * - * @since 1.3.0 */ + @Since("1.3.0") def javaItems: java.util.List[Item] = { items.toList.asJava } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala index ccebf951c850..3ea10779a183 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala @@ -22,85 +22,89 @@ import scala.collection.mutable import org.apache.spark.Logging /** - * Calculate all patterns of a projected database in local. 
+ * Calculate all patterns of a projected database in local mode. + * + * @param minCount minimal count for a frequent pattern + * @param maxPatternLength max pattern length for a frequent pattern */ -private[fpm] object LocalPrefixSpan extends Logging with Serializable { - import PrefixSpan._ +private[fpm] class LocalPrefixSpan( + val minCount: Long, + val maxPatternLength: Int) extends Logging with Serializable { + import PrefixSpan.Postfix + import LocalPrefixSpan.ReversedPrefix + /** - * Calculate all patterns of a projected database. - * @param minCount minimum count - * @param maxPatternLength maximum pattern length - * @param prefixes prefixes in reversed order - * @param database the projected database - * @return a set of sequential pattern pairs, - * the key of pair is sequential pattern (a list of items in reversed order), - * the value of pair is the pattern's count. + * Generates frequent patterns on the input array of postfixes. + * @param postfixes an array of postfixes + * @return an iterator of (frequent pattern, count) */ - def run( - minCount: Long, - maxPatternLength: Int, - prefixes: List[Set[Int]], - database: Iterable[List[Set[Int]]]): Iterator[(List[Set[Int]], Long)] = { - if (prefixes.length == maxPatternLength || database.isEmpty) { - return Iterator.empty - } - val freqItemSetsAndCounts = getFreqItemAndCounts(minCount, database) - val freqItems = freqItemSetsAndCounts.keys.flatten.toSet - val filteredDatabase = database.map { suffix => - suffix - .map(item => freqItems.intersect(item)) - .filter(_.nonEmpty) - } - freqItemSetsAndCounts.iterator.flatMap { case (item, count) => - val newPrefixes = item :: prefixes - val newProjected = project(filteredDatabase, item) - Iterator.single((newPrefixes, count)) ++ - run(minCount, maxPatternLength, newPrefixes, newProjected) + def run(postfixes: Array[Postfix]): Iterator[(Array[Int], Long)] = { + genFreqPatterns(ReversedPrefix.empty, postfixes).map { case (prefix, count) => + (prefix.toSequence, count) } } /** - * Calculate suffix sequence immediately after the first occurrence of an item. - * @param item itemset to get suffix after - * @param sequence sequence to extract suffix from - * @return suffix sequence + * Recursively generates frequent patterns. + * @param prefix current prefix + * @param postfixes projected postfixes w.r.t. 
the prefix + * @return an iterator of (prefix, count) */ - def getSuffix(item: Set[Int], sequence: List[Set[Int]]): List[Set[Int]] = { - val itemsetSeq = sequence - val index = itemsetSeq.indexWhere(item.subsetOf(_)) - if (index == -1) { - List() - } else { - itemsetSeq.drop(index + 1) + private def genFreqPatterns( + prefix: ReversedPrefix, + postfixes: Array[Postfix]): Iterator[(ReversedPrefix, Long)] = { + if (maxPatternLength == prefix.length || postfixes.length < minCount) { + return Iterator.empty + } + // find frequent items + val counts = mutable.Map.empty[Int, Long].withDefaultValue(0) + postfixes.foreach { postfix => + postfix.genPrefixItems.foreach { case (x, _) => + counts(x) += 1L + } + } + val freqItems = counts.toSeq.filter { case (_, count) => + count >= minCount + }.sorted + // project and recursively call genFreqPatterns + freqItems.toIterator.flatMap { case (item, count) => + val newPrefix = prefix :+ item + Iterator.single((newPrefix, count)) ++ { + val projected = postfixes.map(_.project(item)).filter(_.nonEmpty) + genFreqPatterns(newPrefix, projected) + } } } +} - def project( - database: Iterable[List[Set[Int]]], - prefix: Set[Int]): Iterable[List[Set[Int]]] = { - database - .map(getSuffix(prefix, _)) - .filter(_.nonEmpty) - } +private object LocalPrefixSpan { /** - * Generates frequent items by filtering the input data using minimal count level. - * @param minCount the minimum count for an item to be frequent - * @param database database of sequences - * @return freq item to count map + * Represents a prefix stored as a list in reversed order. + * @param items items in the prefix in reversed order + * @param length length of the prefix, not counting delimiters */ - private def getFreqItemAndCounts( - minCount: Long, - database: Iterable[List[Set[Int]]]): Map[Set[Int], Long] = { - // TODO: use PrimitiveKeyOpenHashMap - val counts = mutable.Map[Set[Int], Long]().withDefaultValue(0L) - database.foreach { sequence => - sequence.flatMap(nonemptySubsets(_)).distinct.foreach { item => - counts(item) += 1L + class ReversedPrefix private (val items: List[Int], val length: Int) extends Serializable { + /** + * Expands the prefix by one item. + */ + def :+(item: Int): ReversedPrefix = { + require(item != 0) + if (item < 0) { + new ReversedPrefix(-item :: items, length + 1) + } else { + new ReversedPrefix(item :: 0 :: items, length + 1) } } - counts - .filter { case (_, count) => count >= minCount } - .toMap + + /** + * Converts this prefix to a sequence. + */ + def toSequence: Array[Int] = (0 :: items).toArray.reverse + } + + object ReversedPrefix { + /** An empty prefix. 
*/ + val empty: ReversedPrefix = new ReversedPrefix(List.empty, 0) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 9eaf733fada2..97916daa2e9a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -18,64 +18,67 @@ package org.apache.spark.mllib.fpm import java.{lang => jl, util => ju} +import java.util.concurrent.atomic.AtomicInteger +import scala.collection.mutable import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuilder import scala.reflect.ClassTag import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel /** - * * :: Experimental :: * - * A parallel PrefixSpan algorithm to mine sequential pattern. - * The PrefixSpan algorithm is described in - * [[http://doi.org/10.1109/ICDE.2001.914830]]. + * A parallel PrefixSpan algorithm to mine frequent sequential patterns. + * The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns + * Efficiently by Prefix-Projected Pattern Growth ([[http://doi.org/10.1109/ICDE.2001.914830]]). * * @param minSupport the minimal support level of the sequential pattern, any pattern appears * more than (minSupport * size-of-the-dataset) times will be output * @param maxPatternLength the maximal length of the sequential pattern, any pattern appears - * less than maxPatternLength will be output + * less than maxPatternLength will be output + * @param maxLocalProjDBSize The maximum number of items (including delimiters used in the internal + * storage format) allowed in a projected database before local + * processing. If a projected database exceeds this size, another + * iteration of distributed prefix growth is run. * * @see [[https://en.wikipedia.org/wiki/Sequential_Pattern_Mining Sequential Pattern Mining * (Wikipedia)]] */ @Experimental +@Since("1.5.0") class PrefixSpan private ( private var minSupport: Double, - private var maxPatternLength: Int) extends Logging with Serializable { + private var maxPatternLength: Int, + private var maxLocalProjDBSize: Long) extends Logging with Serializable { import PrefixSpan._ - /** - * The maximum number of items allowed in a projected database before local processing. If a - * projected database exceeds this size, another iteration of distributed PrefixSpan is run. - */ - // TODO: make configurable with a better default value - private val maxLocalProjDBSize: Long = 32000000L - /** * Constructs a default instance with default parameters - * {minSupport: `0.1`, maxPatternLength: `10`}. + * {minSupport: `0.1`, maxPatternLength: `10`, maxLocalProjDBSize: `32000000L`}. */ - def this() = this(0.1, 10) + @Since("1.5.0") + def this() = this(0.1, 10, 32000000L) /** * Get the minimal support (i.e. the frequency of occurrence before a pattern is considered * frequent). */ - def getMinSupport: Double = this.minSupport + @Since("1.5.0") + def getMinSupport: Double = minSupport /** * Sets the minimal support level (default: `0.1`). 
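// Illustrative usage sketch (not part of this patch): drives the PrefixSpan API defined in
// the surrounding diff on a tiny made-up sequence database. Assumes a SparkContext `sc`;
// the result accessors (`freqSequences`, `sequence`, `freq`) come from PrefixSpanModel and
// FreqSequence, which sit outside the hunks shown here, and are assumed from the 1.5 API.
import org.apache.spark.mllib.fpm.PrefixSpan

val sequences = sc.parallelize(Seq(
  Array(Array(1, 2), Array(3)),
  Array(Array(1), Array(3, 2), Array(1, 2)),
  Array(Array(1, 2), Array(5)),
  Array(Array(6))), 2).cache()                 // run() warns if the input is not cached
val prefixSpan = new PrefixSpan()
  .setMinSupport(0.5)
  .setMaxPatternLength(5)
val psModel = prefixSpan.run(sequences)
psModel.freqSequences.collect().foreach { freqSeq =>
  val pattern = freqSeq.sequence.map(_.mkString("(", ",", ")")).mkString("<", "", ">")
  println(s"$pattern: ${freqSeq.freq}")
}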
*/ + @Since("1.5.0") def setMinSupport(minSupport: Double): this.type = { - require(minSupport >= 0 && minSupport <= 1, "The minimum support value must be in [0, 1].") + require(minSupport >= 0 && minSupport <= 1, + s"The minimum support value must be in [0, 1], but got $minSupport.") this.minSupport = minSupport this } @@ -83,45 +86,120 @@ class PrefixSpan private ( /** * Gets the maximal pattern length (i.e. the length of the longest sequential pattern to consider. */ - def getMaxPatternLength: Double = this.maxPatternLength + @Since("1.5.0") + def getMaxPatternLength: Int = maxPatternLength /** * Sets maximal pattern length (default: `10`). */ + @Since("1.5.0") def setMaxPatternLength(maxPatternLength: Int): this.type = { // TODO: support unbounded pattern length when maxPatternLength = 0 - require(maxPatternLength >= 1, "The maximum pattern length value must be greater than 0.") + require(maxPatternLength >= 1, + s"The maximum pattern length value must be greater than 0, but got $maxPatternLength.") this.maxPatternLength = maxPatternLength this } /** - * Find the complete set of sequential patterns in the input sequences of itemsets. - * @param data ordered sequences of itemsets. - * @return a [[PrefixSpanModel]] that contains the frequent sequences + * Gets the maximum number of items allowed in a projected database before local processing. + */ + @Since("1.5.0") + def getMaxLocalProjDBSize: Long = maxLocalProjDBSize + + /** + * Sets the maximum number of items (including delimiters used in the internal storage format) + * allowed in a projected database before local processing (default: `32000000L`). + */ + @Since("1.5.0") + def setMaxLocalProjDBSize(maxLocalProjDBSize: Long): this.type = { + require(maxLocalProjDBSize >= 0L, + s"The maximum local projected database size must be nonnegative, but got $maxLocalProjDBSize") + this.maxLocalProjDBSize = maxLocalProjDBSize + this + } + + /** + * Finds the complete set of frequent sequential patterns in the input sequences of itemsets. + * @param data sequences of itemsets. + * @return a [[PrefixSpanModel]] that contains the frequent patterns */ + @Since("1.5.0") def run[Item: ClassTag](data: RDD[Array[Array[Item]]]): PrefixSpanModel[Item] = { - val itemToInt = data.aggregate(Set[Item]())( - seqOp = { (uniqItems, item) => uniqItems ++ item.flatten.toSet }, - combOp = { _ ++ _ } - ).zipWithIndex.toMap - val intToItem = Map() ++ (itemToInt.map { case (k, v) => (v, k) }) - - val dataInternalRepr = data.map { seq => - seq.map(itemset => itemset.map(itemToInt)).reduce((a, b) => a ++ (DELIMITER +: b)) + if (data.getStorageLevel == StorageLevel.NONE) { + logWarning("Input data is not cached.") } - val results = run(dataInternalRepr) - def toPublicRepr(pattern: Iterable[Int]): List[Array[Item]] = { - pattern.span(_ != DELIMITER) match { - case (x, xs) if xs.size > 1 => x.map(intToItem).toArray :: toPublicRepr(xs.tail) - case (x, xs) => List(x.map(intToItem).toArray) + val totalCount = data.count() + logInfo(s"number of sequences: $totalCount") + val minCount = math.ceil(minSupport * totalCount).toLong + logInfo(s"minimum count for a frequent pattern: $minCount") + + // Find frequent items. 
+ val freqItemAndCounts = data.flatMap { itemsets => + val uniqItems = mutable.Set.empty[Item] + itemsets.foreach { _.foreach { item => + uniqItems += item + }} + uniqItems.toIterator.map((_, 1L)) + }.reduceByKey(_ + _) + .filter { case (_, count) => + count >= minCount + }.collect() + val freqItems = freqItemAndCounts.sortBy(-_._2).map(_._1) + logInfo(s"number of frequent items: ${freqItems.length}") + + // Keep only frequent items from input sequences and convert them to internal storage. + val itemToInt = freqItems.zipWithIndex.toMap + val dataInternalRepr = data.flatMap { itemsets => + val allItems = mutable.ArrayBuilder.make[Int] + var containsFreqItems = false + allItems += 0 + itemsets.foreach { itemsets => + val items = mutable.ArrayBuilder.make[Int] + itemsets.foreach { item => + if (itemToInt.contains(item)) { + items += itemToInt(item) + 1 // using 1-indexing in internal format + } + } + val result = items.result() + if (result.nonEmpty) { + containsFreqItems = true + allItems ++= result.sorted + } + allItems += 0 + } + if (containsFreqItems) { + Iterator.single(allItems.result()) + } else { + Iterator.empty + } + }.persist(StorageLevel.MEMORY_AND_DISK) + + val results = genFreqPatterns(dataInternalRepr, minCount, maxPatternLength, maxLocalProjDBSize) + + def toPublicRepr(pattern: Array[Int]): Array[Array[Item]] = { + val sequenceBuilder = mutable.ArrayBuilder.make[Array[Item]] + val itemsetBuilder = mutable.ArrayBuilder.make[Item] + val n = pattern.length + var i = 1 + while (i < n) { + val x = pattern(i) + if (x == 0) { + sequenceBuilder += itemsetBuilder.result() + itemsetBuilder.clear() + } else { + itemsetBuilder += freqItems(x - 1) // using 1-indexing in internal format + } + i += 1 } + sequenceBuilder.result() } + val freqSequences = results.map { case (seq: Array[Int], count: Long) => - new FreqSequence[Item](toPublicRepr(seq).toArray, count) + new FreqSequence(toPublicRepr(seq), count) } - new PrefixSpanModel[Item](freqSequences) + new PrefixSpanModel(freqSequences) } /** @@ -131,208 +209,335 @@ class PrefixSpan private ( * @tparam Item item type * @tparam Itemset itemset type, which is an Iterable of Items * @tparam Sequence sequence type, which is an Iterable of Itemsets - * @return a [[PrefixSpanModel]] that contains the frequent sequences + * @return a [[PrefixSpanModel]] that contains the frequent sequential patterns */ + @Since("1.5.0") def run[Item, Itemset <: jl.Iterable[Item], Sequence <: jl.Iterable[Itemset]]( data: JavaRDD[Sequence]): PrefixSpanModel[Item] = { implicit val tag = fakeClassTag[Item] run(data.rdd.map(_.asScala.map(_.asScala.toArray).toArray)) } +} + +@Experimental +@Since("1.5.0") +object PrefixSpan extends Logging { + /** - * Find the complete set of sequential patterns in the input sequences. This method utilizes - * the internal representation of itemsets as Array[Int] where each itemset is represented by - * a contiguous sequence of non-negative integers and delimiters represented by [[DELIMITER]]. - * @param data ordered sequences of itemsets. Items are represented by non-negative integers. - * Each itemset has one or more items and is delimited by [[DELIMITER]]. - * @return a set of sequential pattern pairs, - * the key of pair is pattern (a list of elements), - * the value of pair is the pattern's count. + * Find the complete set of frequent sequential patterns in the input sequences. + * @param data ordered sequences of itemsets. 
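A small standalone mirror of the conversion that `run` performs above (frequent items mapped to 1-based ids, each itemset sorted, 0 at every itemset boundary) and of the decoding done by `toPublicRepr`; the sample items and the assumption that "d" is infrequent are illustrative:

    val freqItems = Array("a", "b", "c")                    // already sorted by descending count
    val itemToInt = freqItems.zipWithIndex.toMap            // a -> 0, b -> 1, c -> 2
    val sequence = Array(Array("b", "a"), Array("c", "d"))  // assume "d" did not reach minCount

    // Encoding: 1-based ids, each itemset sorted, 0 at every itemset boundary.
    val internal = Array(0) ++ sequence.flatMap { itemset =>
      itemset.collect { case item if itemToInt.contains(item) => itemToInt(item) + 1 }.sorted :+ 0
    }
    println(internal.mkString(","))   // 0,1,2,0,3,0

    // Decoding, as toPublicRepr does: 0 closes the current itemset, x maps back to freqItems(x - 1).
    val decoded = internal.drop(1).foldLeft((Vector.empty[Vector[String]], Vector.empty[String])) {
      case ((seq, cur), 0) => (seq :+ cur, Vector.empty[String])
      case ((seq, cur), x) => (seq, cur :+ freqItems(x - 1))
    }._1
    println(decoded)                  // Vector(Vector(a, b), Vector(c))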
We represent a sequence internally as Array[Int], + * where each itemset is represented by a contiguous sequence of distinct and ordered + * positive integers. We use 0 as the delimiter at itemset boundaries, including the + * first and the last position. + * @return an RDD of (frequent sequential pattern, count) pairs, + * @see [[Postfix]] */ - private[fpm] def run(data: RDD[Array[Int]]): RDD[(Array[Int], Long)] = { + private[fpm] def genFreqPatterns( + data: RDD[Array[Int]], + minCount: Long, + maxPatternLength: Int, + maxLocalProjDBSize: Long): RDD[(Array[Int], Long)] = { val sc = data.sparkContext if (data.getStorageLevel == StorageLevel.NONE) { logWarning("Input data is not cached.") } - // Use List[Set[Item]] for internal computation - val sequences = data.map { seq => splitSequence(seq.toList) } - - // Convert min support to a min number of transactions for this dataset - val minCount = if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong - - // (Frequent items -> number of occurrences, all items here satisfy the `minSupport` threshold - val freqItemCounts = sequences - .flatMap(seq => seq.flatMap(nonemptySubsets(_)).distinct.map(item => (item, 1L))) - .reduceByKey(_ + _) - .filter { case (item, count) => (count >= minCount) } - .collect() - .toMap - - // Pairs of (length 1 prefix, suffix consisting of frequent items) - val itemSuffixPairs = { - val freqItemSets = freqItemCounts.keys.toSet - val freqItems = freqItemSets.flatten - sequences.flatMap { seq => - val filteredSeq = seq.map(item => freqItems.intersect(item)).filter(_.nonEmpty) - freqItemSets.flatMap { item => - val candidateSuffix = LocalPrefixSpan.getSuffix(item, filteredSeq) - candidateSuffix match { - case suffix if !suffix.isEmpty => Some((List(item), suffix)) - case _ => None + val postfixes = data.map(items => new Postfix(items)) + + // Local frequent patterns (prefixes) and their counts. + val localFreqPatterns = mutable.ArrayBuffer.empty[(Array[Int], Long)] + // Prefixes whose projected databases are small. + val smallPrefixes = mutable.Map.empty[Int, Prefix] + val emptyPrefix = Prefix.empty + // Prefixes whose projected databases are large. + var largePrefixes = mutable.Map(emptyPrefix.id -> emptyPrefix) + while (largePrefixes.nonEmpty) { + val numLocalFreqPatterns = localFreqPatterns.length + logInfo(s"number of local frequent patterns: $numLocalFreqPatterns") + if (numLocalFreqPatterns > 1000000) { + logWarning( + s""" + | Collected $numLocalFreqPatterns local frequent patterns. You may want to consider: + | 1. increase minSupport, + | 2. decrease maxPatternLength, + | 3. increase maxLocalProjDBSize. 
+ """.stripMargin) + } + logInfo(s"number of small prefixes: ${smallPrefixes.size}") + logInfo(s"number of large prefixes: ${largePrefixes.size}") + val largePrefixArray = largePrefixes.values.toArray + val freqPrefixes = postfixes.flatMap { postfix => + largePrefixArray.flatMap { prefix => + postfix.project(prefix).genPrefixItems.map { case (item, postfixSize) => + ((prefix.id, item), (1L, postfixSize)) + } + } + }.reduceByKey { case ((c0, s0), (c1, s1)) => + (c0 + c1, s0 + s1) + }.filter { case (_, (c, _)) => c >= minCount } + .collect() + val newLargePrefixes = mutable.Map.empty[Int, Prefix] + freqPrefixes.foreach { case ((id, item), (count, projDBSize)) => + val newPrefix = largePrefixes(id) :+ item + localFreqPatterns += ((newPrefix.items :+ 0, count)) + if (newPrefix.length < maxPatternLength) { + if (projDBSize > maxLocalProjDBSize) { + newLargePrefixes += newPrefix.id -> newPrefix + } else { + smallPrefixes += newPrefix.id -> newPrefix } } } + largePrefixes = newLargePrefixes } - // Accumulator for the computed results to be returned, initialized to the frequent items (i.e. - // frequent length-one prefixes) - var resultsAccumulator = freqItemCounts.map { case (item, count) => (List(item), count) }.toList - - // Remaining work to be locally and distributively processed respectfully - var (pairsForLocal, pairsForDistributed) = partitionByProjDBSize(itemSuffixPairs) - - // Continue processing until no pairs for distributed processing remain (i.e. all prefixes have - // projected database sizes <= `maxLocalProjDBSize`) or `maxPatternLength` is reached - var patternLength = 1 - while (pairsForDistributed.count() != 0 && patternLength < maxPatternLength) { - val (nextPatternAndCounts, nextPrefixSuffixPairs) = - extendPrefixes(minCount, pairsForDistributed) - pairsForDistributed.unpersist() - val (smallerPairsPart, largerPairsPart) = partitionByProjDBSize(nextPrefixSuffixPairs) - pairsForDistributed = largerPairsPart - pairsForDistributed.persist(StorageLevel.MEMORY_AND_DISK) - pairsForLocal ++= smallerPairsPart - resultsAccumulator ++= nextPatternAndCounts.collect() - patternLength += 1 // pattern length grows one per iteration + var freqPatterns = sc.parallelize(localFreqPatterns, 1) + + val numSmallPrefixes = smallPrefixes.size + logInfo(s"number of small prefixes for local processing: $numSmallPrefixes") + if (numSmallPrefixes > 0) { + // Switch to local processing. + val bcSmallPrefixes = sc.broadcast(smallPrefixes) + val distributedFreqPattern = postfixes.flatMap { postfix => + bcSmallPrefixes.value.values.map { prefix => + (prefix.id, postfix.project(prefix).compressed) + }.filter(_._2.nonEmpty) + }.groupByKey().flatMap { case (id, projPostfixes) => + val prefix = bcSmallPrefixes.value(id) + val localPrefixSpan = new LocalPrefixSpan(minCount, maxPatternLength - prefix.length) + // TODO: We collect projected postfixes into memory. We should also compare the performance + // TODO: of keeping them on shuffle files. + localPrefixSpan.run(projPostfixes.toArray).map { case (pattern, count) => + (prefix.items ++ pattern, count) + } + } + // Union local frequent patterns and distributed ones. 
+ freqPatterns = freqPatterns ++ distributedFreqPattern } - // Process the small projected databases locally - val remainingResults = getPatternsInLocal( - minCount, sc.parallelize(pairsForLocal, 1).groupByKey()) - - (sc.parallelize(resultsAccumulator, 1) ++ remainingResults) - .map { case (pattern, count) => (flattenSequence(pattern.reverse).toArray, count) } + freqPatterns } - /** - * Partitions the prefix-suffix pairs by projected database size. - * @param prefixSuffixPairs prefix (length n) and suffix pairs, - * @return prefix-suffix pairs partitioned by whether their projected database size is <= or - * greater than [[maxLocalProjDBSize]] + * Represents a prefix. + * @param items items in this prefix, using the internal format + * @param length length of this prefix, not counting 0 */ - private def partitionByProjDBSize(prefixSuffixPairs: RDD[(List[Set[Int]], List[Set[Int]])]) - : (List[(List[Set[Int]], List[Set[Int]])], RDD[(List[Set[Int]], List[Set[Int]])]) = { - val prefixToSuffixSize = prefixSuffixPairs - .aggregateByKey(0)( - seqOp = { case (count, suffix) => count + suffix.length }, - combOp = { _ + _ }) - val smallPrefixes = prefixToSuffixSize - .filter(_._2 <= maxLocalProjDBSize) - .keys - .collect() - .toSet - val small = prefixSuffixPairs.filter { case (prefix, _) => smallPrefixes.contains(prefix) } - val large = prefixSuffixPairs.filter { case (prefix, _) => !smallPrefixes.contains(prefix) } - (small.collect().toList, large) + private[fpm] class Prefix private (val items: Array[Int], val length: Int) extends Serializable { + + /** A unique id for this prefix. */ + val id: Int = Prefix.nextId + + /** Expands this prefix by the input item. */ + def :+(item: Int): Prefix = { + require(item != 0) + if (item < 0) { + new Prefix(items :+ -item, length + 1) + } else { + new Prefix(items ++ Array(0, item), length + 1) + } + } } - /** - * Extends all prefixes by one itemset from their suffix and computes the resulting frequent - * prefixes and remaining work. - * @param minCount minimum count - * @param prefixSuffixPairs prefix (length N) and suffix pairs, - * @return (frequent length N+1 extended prefix, count) pairs and (frequent length N+1 extended - * prefix, corresponding suffix) pairs. 
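To make the sign convention concrete (a positive item starts a new itemset, a negative item extends the last one, as used by `Prefix.:+` above and by `genPrefixItems` further down), here is the growth of the empty prefix into `<(12)3>`, with plain arrays standing in for the private `Prefix` class:

    val p0 = Array.empty[Int]         // Prefix.empty.items            -> <>
    val p1 = p0 ++ Array(0, 1)        // what p0 :+ 1 stores           -> <1>,     [0, 1]
    val p2 = p1 :+ 2                  // what p1 :+ -2 stores          -> <(12)>,  [0, 1, 2]
    val p3 = p2 ++ Array(0, 3)        // what p2 :+ 3 stores           -> <(12)3>, [0, 1, 2, 0, 3]
    // genFreqPatterns records the pattern as p3 :+ 0, i.e. [0, 1, 2, 0, 3, 0] in the internal format.
    println(p3.mkString(", "))        // 0, 1, 2, 0, 3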
- */ - private def extendPrefixes( - minCount: Long, - prefixSuffixPairs: RDD[(List[Set[Int]], List[Set[Int]])]) - : (RDD[(List[Set[Int]], Long)], RDD[(List[Set[Int]], List[Set[Int]])]) = { - - // (length N prefix, itemset from suffix) pairs and their corresponding number of occurrences - // Every (prefix :+ suffix) is guaranteed to have support exceeding `minSupport` - val prefixItemPairAndCounts = prefixSuffixPairs - .flatMap { case (prefix, suffix) => - suffix.flatMap(nonemptySubsets(_)).distinct.map(y => ((prefix, y), 1L)) } - .reduceByKey(_ + _) - .filter { case (item, count) => (count >= minCount) } - - // Map from prefix to set of possible next items from suffix - val prefixToNextItems = prefixItemPairAndCounts - .keys - .groupByKey() - .mapValues(_.toSet) - .collect() - .toMap - - // Frequent patterns with length N+1 and their corresponding counts - val extendedPrefixAndCounts = prefixItemPairAndCounts - .map { case ((prefix, item), count) => (item :: prefix, count) } - - // Remaining work, all prefixes will have length N+1 - val extendedPrefixAndSuffix = prefixSuffixPairs - .filter(x => prefixToNextItems.contains(x._1)) - .flatMap { case (prefix, suffix) => - val frequentNextItemSets = prefixToNextItems(prefix) - val frequentNextItems = frequentNextItemSets.flatten - val filteredSuffix = suffix - .map(item => frequentNextItems.intersect(item)) - .filter(_.nonEmpty) - frequentNextItemSets.flatMap { item => - LocalPrefixSpan.getSuffix(item, filteredSuffix) match { - case suffix if !suffix.isEmpty => Some(item :: prefix, suffix) - case _ => None - } - } - } + private[fpm] object Prefix { + /** Internal counter to generate unique IDs. */ + private val counter: AtomicInteger = new AtomicInteger(-1) - (extendedPrefixAndCounts, extendedPrefixAndSuffix) + /** Gets the next unique ID. */ + private def nextId: Int = counter.incrementAndGet() + + /** An empty [[Prefix]] instance. */ + val empty: Prefix = new Prefix(Array.empty, 0) } /** - * Calculate the patterns in local. - * @param minCount the absolute minimum count - * @param data prefixes and projected sequences data data - * @return patterns + * An internal representation of a postfix from some projection. + * We use one int array to store the items, which might also contains other items from the + * original sequence. + * Items are represented by positive integers, and items in each itemset must be distinct and + * ordered. + * we use 0 as the delimiter between itemsets. + * For example, a sequence `<(12)(31)1>` is represented by `[0, 1, 2, 0, 1, 3, 0, 1, 0]`. + * The postfix of this sequence w.r.t. to prefix `<1>` is `<(_2)(13)1>`. + * We may reuse the original items array `[0, 1, 2, 0, 1, 3, 0, 1, 0]` to represent the postfix, + * and mark the start index of the postfix, which is `2` in this example. + * So the active items in this postfix are `[2, 0, 1, 3, 0, 1, 0]`. + * We also remember the start indices of partial projections, the ones that split an itemset. + * For example, another possible partial projection w.r.t. `<1>` is `<(_3)1>`. + * We remember the start indices of partial projections, which is `[2, 5]` in this example. + * This data structure makes it easier to do projections. 
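A concrete rendering of the example in the comment above, using plain arrays (the `Postfix` class itself is private to this file, so the snippet only shows which suffixes the start and partial-start indices select):

    // The sequence <(12)(31)1> in the internal, 0-delimited format described above.
    val items = Array(0, 1, 2, 0, 1, 3, 0, 1, 0)
    //    index:      0  1  2  3  4  5  6  7  8

    // Projecting onto the prefix <1> reuses this array: the postfix keeps a start index
    // (2 in the comment above) plus the start indices of partial projections, [2, 5].
    println(items.drop(2).mkString(","))   // 2,0,1,3,0,1,0  -- the postfix <(_2)(13)1>
    println(items.drop(5).mkString(","))   // 3,0,1,0        -- the partial projection <(_3)1>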
+ * + * @param items a sequence stored as `Array[Int]` containing this postfix + * @param start the start index of this postfix in items + * @param partialStarts start indices of possible partial projections, strictly increasing */ - private def getPatternsInLocal( - minCount: Long, - data: RDD[(List[Set[Int]], Iterable[List[Set[Int]]])]): RDD[(List[Set[Int]], Long)] = { - data.flatMap { - case (prefix, projDB) => LocalPrefixSpan.run(minCount, maxPatternLength, prefix, projDB) + private[fpm] class Postfix( + val items: Array[Int], + val start: Int = 0, + val partialStarts: Array[Int] = Array.empty) extends Serializable { + + require(items.last == 0, s"The last item in a postfix must be zero, but got ${items.last}.") + if (partialStarts.nonEmpty) { + require(partialStarts.head >= start, + "The first partial start cannot be smaller than the start index," + + s"but got partialStarts.head = ${partialStarts.head} < start = $start.") } - } -} + /** + * Start index of the first full itemset contained in this postfix. + */ + private[this] def fullStart: Int = { + var i = start + while (items(i) != 0) { + i += 1 + } + i + } + + /** + * Generates length-1 prefix items of this postfix with the corresponding postfix sizes. + * There are two types of prefix items: + * a) The item can be assembled to the last itemset of the prefix. For example, + * the postfix of `<(12)(123)>1` w.r.t. `<1>` is `<(_2)(123)1>`. The prefix items of this + * postfix can be assembled to `<1>` is `_2` and `_3`, resulting new prefixes `<(12)>` and + * `<(13)>`. We flip the sign in the output to indicate that this is a partial prefix item. + * b) The item can be appended to the prefix. Taking the same example above, the prefix items + * can be appended to `<1>` is `1`, `2`, and `3`, resulting new prefixes `<11>`, `<12>`, + * and `<13>`. + * @return an iterator of (prefix item, corresponding postfix size). If the item is negative, it + * indicates a partial prefix item, which should be assembled to the last itemset of the + * current prefix. Otherwise, the item should be appended to the current prefix. + */ + def genPrefixItems: Iterator[(Int, Long)] = { + val n1 = items.length - 1 + // For each unique item (subject to sign) in this sequence, we output exact one split. + // TODO: use PrimitiveKeyOpenHashMap + val prefixes = mutable.Map.empty[Int, Long] + // a) items that can be assembled to the last itemset of the prefix + partialStarts.foreach { start => + var i = start + var x = -items(i) + while (x != 0) { + if (!prefixes.contains(x)) { + prefixes(x) = n1 - i + } + i += 1 + x = -items(i) + } + } + // b) items that can be appended to the prefix + var i = fullStart + while (i < n1) { + val x = items(i) + if (x != 0 && !prefixes.contains(x)) { + prefixes(x) = n1 - i + } + i += 1 + } + prefixes.toIterator + } -object PrefixSpan { - private[fpm] val DELIMITER = -1 + /** Tests whether this postfix is non-empty. */ + def nonEmpty: Boolean = items.length > start + 1 - /** Splits an array of itemsets delimited by [[DELIMITER]]. */ - private[fpm] def splitSequence(sequence: List[Int]): List[Set[Int]] = { - sequence.span(_ != DELIMITER) match { - case (x, xs) if xs.length > 1 => x.toSet :: splitSequence(xs.tail) - case (x, xs) => List(x.toSet) + /** + * Projects this postfix with respect to the input prefix item. + * @param prefix prefix item. If prefix is positive, we match items in any full itemset; if it + * is negative, we do partial projections. 
+ * @return the projected postfix + */ + def project(prefix: Int): Postfix = { + require(prefix != 0) + val n1 = items.length - 1 + var matched = false + var newStart = n1 + val newPartialStarts = mutable.ArrayBuilder.make[Int] + if (prefix < 0) { + // Search for partial projections. + val target = -prefix + partialStarts.foreach { start => + var i = start + var x = items(i) + while (x != target && x != 0) { + i += 1 + x = items(i) + } + if (x == target) { + i += 1 + if (!matched) { + newStart = i + matched = true + } + if (items(i) != 0) { + newPartialStarts += i + } + } + } + } else { + // Search for items in full itemsets. + // Though the items are ordered in each itemsets, they should be small in practice. + // So a sequential scan is sufficient here, compared to bisection search. + val target = prefix + var i = fullStart + while (i < n1) { + val x = items(i) + if (x == target) { + if (!matched) { + newStart = i + matched = true + } + if (items(i + 1) != 0) { + newPartialStarts += i + 1 + } + } + i += 1 + } + } + new Postfix(items, newStart, newPartialStarts.result()) } - } - /** Flattens a sequence of itemsets into an Array, inserting[[DELIMITER]] between itemsets. */ - private[fpm] def flattenSequence(sequence: List[Set[Int]]): List[Int] = { - val builder = ArrayBuilder.make[Int]() - for (itemSet <- sequence) { - builder += DELIMITER - builder ++= itemSet.toSeq.sorted + /** + * Projects this postfix with respect to the input prefix. + */ + private def project(prefix: Array[Int]): Postfix = { + var partial = true + var cur = this + var i = 0 + val np = prefix.length + while (i < np && cur.nonEmpty) { + val x = prefix(i) + if (x == 0) { + partial = false + } else { + if (partial) { + cur = cur.project(-x) + } else { + cur = cur.project(x) + partial = true + } + } + i += 1 + } + cur } - builder.result().toList.drop(1) // drop trailing delimiter - } - /** Returns an iterator over all non-empty subsets of `itemSet` */ - private[fpm] def nonemptySubsets(itemSet: Set[Int]): Iterator[Set[Int]] = { - // TODO: improve complexity by using partial prefixes, considering one item at a time - itemSet.subsets.filter(_ != Set.empty[Int]) + /** + * Projects this postfix with respect to the input prefix. + */ + def project(prefix: Prefix): Postfix = project(prefix.items) + + /** + * Returns the same sequence with compressed storage if possible. + */ + def compressed: Postfix = { + if (start > 0) { + new Postfix(items.slice(start, items.length), 0, partialStarts.map(_ - start)) + } else { + this + } + } } /** @@ -341,10 +546,14 @@ object PrefixSpan { * @param freq frequency * @tparam Item item type */ - class FreqSequence[Item](val sequence: Array[Array[Item]], val freq: Long) extends Serializable { + @Since("1.5.0") + class FreqSequence[Item] @Since("1.5.0") ( + @Since("1.5.0") val sequence: Array[Array[Item]], + @Since("1.5.0") val freq: Long) extends Serializable { /** * Returns sequence as a Java List of lists for Java users. 
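Looking back at `project(prefix: Array[Int])` above, a short sketch of how it decomposes the prefix `<(12)3>` (items `[0, 1, 2, 0, 3]`) into single-item projections; print statements stand in for the actual `Postfix` calls:

    val prefixItems = Array(0, 1, 2, 0, 3)
    var partial = true
    prefixItems.foreach { x =>
      if (x == 0) {
        partial = false                               // delimiter: the next item opens a new itemset
      } else if (partial) {
        println(s"project(${-x})   // partial projection within the current itemset")
      } else {
        println(s"project($x)    // full projection")
        partial = true
      }
    }
    // prints: project(1) full, project(-2) partial, project(3) full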
*/ + @Since("1.5.0") def javaSequence: ju.List[ju.List[Item]] = sequence.map(_.toList.asJava).toList.asJava } } @@ -354,5 +563,7 @@ object PrefixSpan { * @param freqSequences frequent sequences * @tparam Item item type */ -class PrefixSpanModel[Item](val freqSequences: RDD[PrefixSpan.FreqSequence[Item]]) +@Since("1.5.0") +class PrefixSpanModel[Item] @Since("1.5.0") ( + @Since("1.5.0") val freqSequences: RDD[PrefixSpan.FreqSequence[Item]]) extends Serializable diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index 9029093e0fa0..df9f4ae145b8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -92,6 +92,13 @@ private[spark] object BLAS extends Serializable with Logging { } } + /** Y += a * x */ + private[spark] def axpy(a: Double, X: DenseMatrix, Y: DenseMatrix): Unit = { + require(X.numRows == Y.numRows && X.numCols == Y.numCols, "Dimension mismatch: " + + s"size(X) = ${(X.numRows, X.numCols)} but size(Y) = ${(Y.numRows, Y.numCols)}.") + f2jBLAS.daxpy(X.numRows * X.numCols, a, X.values, 1, Y.values, 1) + } + /** * dot(x, y) */ @@ -229,6 +236,50 @@ private[spark] object BLAS extends Serializable with Logging { _nativeBLAS } + /** + * Adds alpha * x * x.t to a matrix in-place. This is the same as BLAS's ?SPR. + * + * @param U the upper triangular part of the matrix in a [[DenseVector]](column major) + */ + def spr(alpha: Double, v: Vector, U: DenseVector): Unit = { + spr(alpha, v, U.values) + } + + /** + * Adds alpha * x * x.t to a matrix in-place. This is the same as BLAS's ?SPR. + * + * @param U the upper triangular part of the matrix packed in an array (column major) + */ + def spr(alpha: Double, v: Vector, U: Array[Double]): Unit = { + val n = v.size + v match { + case DenseVector(values) => + NativeBLAS.dspr("U", n, alpha, values, 1, U) + case SparseVector(size, indices, values) => + val nnz = indices.length + var colStartIdx = 0 + var prevCol = 0 + var col = 0 + var j = 0 + var i = 0 + var av = 0.0 + while (j < nnz) { + col = indices(j) + // Skip empty columns. + colStartIdx += (col - prevCol) * (col + prevCol + 1) / 2 + col = indices(j) + av = alpha * values(j) + i = 0 + while (i <= j) { + U(colStartIdx + indices(i)) += av * values(i) + i += 1 + } + j += 1 + prevCol = col + } + } + } + /** * A := alpha * x * x^T^ + A * @param alpha a real scalar that will be multiplied to x * x^T^. @@ -305,6 +356,8 @@ private[spark] object BLAS extends Serializable with Logging { "The matrix C cannot be the product of a transpose() call. C.isTransposed must be false.") if (alpha == 0.0 && beta == 1.0) { logDebug("gemm: alpha is equal to 0 and beta is equal to 1. Returning C.") + } else if (alpha == 0.0) { + f2jBLAS.dscal(C.values.length, beta, C.values, 1) } else { A match { case sparse: SparseMatrix => gemm(alpha, sparse, B, beta, C) @@ -408,8 +461,8 @@ private[spark] object BLAS extends Serializable with Logging { } } } else { - // Scale matrix first if `beta` is not equal to 0.0 - if (beta != 0.0) { + // Scale matrix first if `beta` is not equal to 1.0 + if (beta != 1.0) { f2jBLAS.dscal(C.values.length, beta, C.values, 1) } // Perform matrix multiplication and add to C. The rows of A are multiplied by the columns of @@ -469,9 +522,11 @@ private[spark] object BLAS extends Serializable with Logging { require(A.numCols == x.size, s"The columns of A don't match the number of elements of x. 
A: ${A.numCols}, x: ${x.size}") require(A.numRows == y.size, - s"The rows of A don't match the number of elements of y. A: ${A.numRows}, y:${y.size}}") - if (alpha == 0.0) { - logDebug("gemv: alpha is equal to 0. Returning y.") + s"The rows of A don't match the number of elements of y. A: ${A.numRows}, y:${y.size}") + if (alpha == 0.0 && beta == 1.0) { + logDebug("gemv: alpha is equal to 0 and beta is equal to 1. Returning y.") + } else if (alpha == 0.0) { + scal(beta, y) } else { (A, x) match { case (smA: SparseMatrix, dvx: DenseVector) => @@ -526,11 +581,6 @@ private[spark] object BLAS extends Serializable with Logging { val xValues = x.values val yValues = y.values - if (alpha == 0.0) { - scal(beta, y) - return - } - if (A.isTransposed) { var rowCounterForA = 0 while (rowCounterForA < mA) { @@ -581,11 +631,6 @@ private[spark] object BLAS extends Serializable with Logging { val Arows = if (!A.isTransposed) A.rowIndices else A.colPtrs val Acols = if (!A.isTransposed) A.colPtrs else A.rowIndices - if (alpha == 0.0) { - scal(beta, y) - return - } - if (A.isTransposed) { var rowCounter = 0 while (rowCounter < mA) { @@ -604,7 +649,7 @@ private[spark] object BLAS extends Serializable with Logging { rowCounter += 1 } } else { - scal(beta, y) + if (beta != 1.0) scal(beta, y) var colCounterForA = 0 var k = 0 @@ -659,7 +704,7 @@ private[spark] object BLAS extends Serializable with Logging { rowCounter += 1 } } else { - scal(beta, y) + if (beta != 1.0) scal(beta, y) // Perform matrix-vector multiplication and add to y var colCounterForA = 0 while (colCounterForA < nA) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala new file mode 100644 index 000000000000..0cd371e9cce3 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.linalg + +import com.github.fommil.netlib.LAPACK.{getInstance => lapack} +import org.netlib.util.intW + +/** + * Compute Cholesky decomposition. + */ +private[spark] object CholeskyDecomposition { + + /** + * Solves a symmetric positive definite linear system via Cholesky factorization. + * The input arguments are modified in-place to store the factorization and the solution. 
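Both `spr` earlier in this file and the packed Cholesky routines here address only the upper triangle in LAPACK packed, column-major order. A small sketch of that layout, assuming the standard `i + j * (j + 1) / 2` indexing (which matches the `colStartIdx` arithmetic in `spr`):

    // Packed upper-triangular, column-major: entry (i, j) with i <= j lives at i + j * (j + 1) / 2,
    // so a 3 x 3 symmetric matrix is stored as [a00, a01, a11, a02, a12, a22].
    def packedIndex(i: Int, j: Int): Int = i + j * (j + 1) / 2

    // The update spr performs: U += alpha * v * v^T, touching only the packed upper triangle.
    val v = Array(1.0, 2.0, 3.0)
    val alpha = 0.5
    val U = new Array[Double](v.length * (v.length + 1) / 2)
    for (j <- v.indices; i <- 0 to j) {
      U(packedIndex(i, j)) += alpha * v(i) * v(j)
    }
    println(U.mkString(", "))   // 0.5, 1.0, 2.0, 1.5, 3.0, 4.5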
+ * @param A the upper triangular part of A + * @param bx right-hand side + * @return the solution array + */ + def solve(A: Array[Double], bx: Array[Double]): Array[Double] = { + val k = bx.size + val info = new intW(0) + lapack.dppsv("U", k, 1, A, bx, k, info) + val code = info.`val` + assert(code == 0, s"lapack.dpotrs returned $code.") + bx + } + + /** + * Computes the inverse of a real symmetric positive definite matrix A + * using the Cholesky factorization A = U**T*U. + * The input arguments are modified in-place to store the inverse matrix. + * @param UAi the upper triangular factor U from the Cholesky factorization A = U**T*U + * @param k the dimension of A + * @return the upper triangle of the (symmetric) inverse of A + */ + def inverse(UAi: Array[Double], k: Int): Array[Double] = { + val info = new intW(0) + lapack.dpptri("U", k, UAi, info) + val code = info.`val` + assert(code == 0, s"lapack.dpptri returned $code.") + UAi + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala index ae3ba3099c87..863abe86d38d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala @@ -21,13 +21,9 @@ import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV} import com.github.fommil.netlib.ARPACK import org.netlib.util.{intW, doubleW} -import org.apache.spark.annotation.Experimental - /** - * :: Experimental :: * Compute eigen-decomposition. */ -@Experimental private[mllib] object EigenValueDecomposition { /** * Compute the leading k eigenvalues and eigenvectors on a symmetric square matrix using ARPACK. @@ -46,7 +42,7 @@ private[mllib] object EigenValueDecomposition { * for more details). The maximum number of Arnoldi update iterations is set to 300 in this * function. */ - private[mllib] def symmetricEigs( + def symmetricEigs( mul: BDV[Double] => BDV[Double], n: Int, k: Int, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 1c858348bf20..8879dcf75c9b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -23,27 +23,33 @@ import scala.collection.mutable.{ArrayBuilder => MArrayBuilder, HashSet => MHash import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM} -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ /** * Trait for a local matrix. */ @SQLUserDefinedType(udt = classOf[MatrixUDT]) +@Since("1.0.0") sealed trait Matrix extends Serializable { /** Number of rows. */ + @Since("1.0.0") def numRows: Int /** Number of columns. */ + @Since("1.0.0") def numCols: Int /** Flag that keeps track whether the matrix is transposed or not. False by default. */ + @Since("1.3.0") val isTransposed: Boolean = false /** Converts to a dense array in column major. 
*/ + @Since("1.0.0") def toArray: Array[Double] = { val newArray = new Array[Double](numRows * numCols) foreachActive { (i, j, v) => @@ -56,6 +62,7 @@ sealed trait Matrix extends Serializable { private[mllib] def toBreeze: BM[Double] /** Gets the (i, j)-th element. */ + @Since("1.3.0") def apply(i: Int, j: Int): Double /** Return the index for the (i, j)-th element in the backing array. */ @@ -65,12 +72,15 @@ sealed trait Matrix extends Serializable { private[mllib] def update(i: Int, j: Int, v: Double): Unit /** Get a deep copy of the matrix. */ + @Since("1.2.0") def copy: Matrix /** Transpose the Matrix. Returns a new `Matrix` instance sharing the same underlying data. */ + @Since("1.3.0") def transpose: Matrix /** Convenience method for `Matrix`-`DenseMatrix` multiplication. */ + @Since("1.2.0") def multiply(y: DenseMatrix): DenseMatrix = { val C: DenseMatrix = DenseMatrix.zeros(numRows, y.numCols) BLAS.gemm(1.0, this, y, 0.0, C) @@ -78,11 +88,13 @@ sealed trait Matrix extends Serializable { } /** Convenience method for `Matrix`-`DenseVector` multiplication. For binary compatibility. */ + @Since("1.2.0") def multiply(y: DenseVector): DenseVector = { multiply(y.asInstanceOf[Vector]) } /** Convenience method for `Matrix`-`Vector` multiplication. */ + @Since("1.4.0") def multiply(y: Vector): DenseVector = { val output = new DenseVector(new Array[Double](numRows)) BLAS.gemv(1.0, this, y, 0.0, output) @@ -93,6 +105,7 @@ sealed trait Matrix extends Serializable { override def toString: String = toBreeze.toString() /** A human readable representation of the matrix with maximum lines and width */ + @Since("1.4.0") def toString(maxLines: Int, maxLineWidth: Int): String = toBreeze.toString(maxLines, maxLineWidth) /** Map the values of this matrix using a function. Generates a new matrix. Performs the @@ -118,11 +131,13 @@ sealed trait Matrix extends Serializable { /** * Find the number of non-zero active values. */ + @Since("1.5.0") def numNonzeros: Int /** * Find the number of values stored explicitly. These values can be zero as well. */ + @Since("1.5.0") def numActives: Int } @@ -228,12 +243,13 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { * @param isTransposed whether the matrix is transposed. If true, `values` stores the matrix in * row major. */ +@Since("1.0.0") @SQLUserDefinedType(udt = classOf[MatrixUDT]) -class DenseMatrix( - val numRows: Int, - val numCols: Int, - val values: Array[Double], - override val isTransposed: Boolean) extends Matrix { +class DenseMatrix @Since("1.3.0") ( + @Since("1.0.0") val numRows: Int, + @Since("1.0.0") val numCols: Int, + @Since("1.0.0") val values: Array[Double], + @Since("1.3.0") override val isTransposed: Boolean) extends Matrix { require(values.length == numRows * numCols, "The number of values supplied doesn't match the " + s"size of the matrix! 
values.length: ${values.length}, numRows * numCols: ${numRows * numCols}") @@ -253,12 +269,12 @@ class DenseMatrix( * @param numCols number of columns * @param values matrix entries in column major */ + @Since("1.0.0") def this(numRows: Int, numCols: Int, values: Array[Double]) = this(numRows, numCols, values, false) override def equals(o: Any): Boolean = o match { - case m: DenseMatrix => - m.numRows == numRows && m.numCols == numCols && Arrays.equals(toArray, m.toArray) + case m: Matrix => toBreeze == m.toBreeze case _ => false } @@ -277,9 +293,12 @@ class DenseMatrix( private[mllib] def apply(i: Int): Double = values(i) + @Since("1.3.0") override def apply(i: Int, j: Int): Double = values(index(i, j)) private[mllib] def index(i: Int, j: Int): Int = { + require(i >= 0 && i < numRows, s"Expected 0 <= i < $numRows, got i = $i.") + require(j >= 0 && j < numCols, s"Expected 0 <= j < $numCols, got j = $j.") if (!isTransposed) i + numRows * j else j + numCols * i } @@ -287,6 +306,7 @@ class DenseMatrix( values(index(i, j)) = v } + @Since("1.4.0") override def copy: DenseMatrix = new DenseMatrix(numRows, numCols, values.clone()) private[spark] def map(f: Double => Double) = new DenseMatrix(numRows, numCols, values.map(f), @@ -302,6 +322,7 @@ class DenseMatrix( this } + @Since("1.3.0") override def transpose: DenseMatrix = new DenseMatrix(numCols, numRows, values, !isTransposed) private[spark] override def foreachActive(f: (Int, Int, Double) => Unit): Unit = { @@ -332,14 +353,17 @@ class DenseMatrix( } } + @Since("1.5.0") override def numNonzeros: Int = values.count(_ != 0) + @Since("1.5.0") override def numActives: Int = values.length /** * Generate a `SparseMatrix` from the given `DenseMatrix`. The new matrix will have isTransposed * set to false. */ + @Since("1.3.0") def toSparse: SparseMatrix = { val spVals: MArrayBuilder[Double] = new MArrayBuilder.ofDouble val colPtrs: Array[Int] = new Array[Int](numCols + 1) @@ -367,6 +391,7 @@ class DenseMatrix( /** * Factory methods for [[org.apache.spark.mllib.linalg.DenseMatrix]]. 
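A brief usage sketch of the column-major layout and the methods annotated above, using the public `DenseMatrix` API in this file:

    import org.apache.spark.mllib.linalg.DenseMatrix

    // values are column major, so this is the 2 x 3 matrix
    //   1.0  3.0  5.0
    //   2.0  4.0  6.0
    val dm = new DenseMatrix(2, 3, Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
    println(dm(0, 1))         // 3.0
    val dt = dm.transpose     // a 3 x 2 view sharing the same values array (isTransposed = true)
    println(dt(1, 0))         // 3.0
    val sp = dm.toSparse      // CSC copy keeping only the explicit non-zeros
    println(sp.numNonzeros)   // 6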
*/ +@Since("1.3.0") object DenseMatrix { /** @@ -375,6 +400,7 @@ object DenseMatrix { * @param numCols number of columns of the matrix * @return `DenseMatrix` with size `numRows` x `numCols` and values of zeros */ + @Since("1.3.0") def zeros(numRows: Int, numCols: Int): DenseMatrix = { require(numRows.toLong * numCols <= Int.MaxValue, s"$numRows x $numCols dense matrix is too large to allocate") @@ -387,6 +413,7 @@ object DenseMatrix { * @param numCols number of columns of the matrix * @return `DenseMatrix` with size `numRows` x `numCols` and values of ones */ + @Since("1.3.0") def ones(numRows: Int, numCols: Int): DenseMatrix = { require(numRows.toLong * numCols <= Int.MaxValue, s"$numRows x $numCols dense matrix is too large to allocate") @@ -398,6 +425,7 @@ object DenseMatrix { * @param n number of rows and columns of the matrix * @return `DenseMatrix` with size `n` x `n` and values of ones on the diagonal */ + @Since("1.3.0") def eye(n: Int): DenseMatrix = { val identity = DenseMatrix.zeros(n, n) var i = 0 @@ -415,6 +443,7 @@ object DenseMatrix { * @param rng a random number generator * @return `DenseMatrix` with size `numRows` x `numCols` and values in U(0, 1) */ + @Since("1.3.0") def rand(numRows: Int, numCols: Int, rng: Random): DenseMatrix = { require(numRows.toLong * numCols <= Int.MaxValue, s"$numRows x $numCols dense matrix is too large to allocate") @@ -428,6 +457,7 @@ object DenseMatrix { * @param rng a random number generator * @return `DenseMatrix` with size `numRows` x `numCols` and values in N(0, 1) */ + @Since("1.3.0") def randn(numRows: Int, numCols: Int, rng: Random): DenseMatrix = { require(numRows.toLong * numCols <= Int.MaxValue, s"$numRows x $numCols dense matrix is too large to allocate") @@ -440,6 +470,7 @@ object DenseMatrix { * @return Square `DenseMatrix` with size `values.length` x `values.length` and `values` * on the diagonal */ + @Since("1.3.0") def diag(vector: Vector): DenseMatrix = { val n = vector.size val matrix = DenseMatrix.zeros(n, n) @@ -475,14 +506,15 @@ object DenseMatrix { * Compressed Sparse Row (CSR) format, where `colPtrs` behaves as rowPtrs, * and `rowIndices` behave as colIndices, and `values` are stored in row major. */ +@Since("1.2.0") @SQLUserDefinedType(udt = classOf[MatrixUDT]) -class SparseMatrix( - val numRows: Int, - val numCols: Int, - val colPtrs: Array[Int], - val rowIndices: Array[Int], - val values: Array[Double], - override val isTransposed: Boolean) extends Matrix { +class SparseMatrix @Since("1.3.0") ( + @Since("1.2.0") val numRows: Int, + @Since("1.2.0") val numCols: Int, + @Since("1.2.0") val colPtrs: Array[Int], + @Since("1.2.0") val rowIndices: Array[Int], + @Since("1.2.0") val values: Array[Double], + @Since("1.3.0") override val isTransposed: Boolean) extends Matrix { require(values.length == rowIndices.length, "The number of row indices and values don't match! 
" + s"values.length: ${values.length}, rowIndices.length: ${rowIndices.length}") @@ -512,6 +544,7 @@ class SparseMatrix( * order for each column * @param values non-zero matrix entries in column major */ + @Since("1.2.0") def this( numRows: Int, numCols: Int, @@ -519,6 +552,11 @@ class SparseMatrix( rowIndices: Array[Int], values: Array[Double]) = this(numRows, numCols, colPtrs, rowIndices, values, false) + override def equals(o: Any): Boolean = o match { + case m: Matrix => toBreeze == m.toBreeze + case _ => false + } + private[mllib] def toBreeze: BM[Double] = { if (!isTransposed) { new BSM[Double](values, numRows, numCols, colPtrs, rowIndices) @@ -528,12 +566,15 @@ class SparseMatrix( } } + @Since("1.3.0") override def apply(i: Int, j: Int): Double = { val ind = index(i, j) if (ind < 0) 0.0 else values(ind) } private[mllib] def index(i: Int, j: Int): Int = { + require(i >= 0 && i < numRows, s"Expected 0 <= i < $numRows, got i = $i.") + require(j >= 0 && j < numCols, s"Expected 0 <= j < $numCols, got j = $j.") if (!isTransposed) { Arrays.binarySearch(rowIndices, colPtrs(j), colPtrs(j + 1), i) } else { @@ -551,6 +592,7 @@ class SparseMatrix( } } + @Since("1.4.0") override def copy: SparseMatrix = { new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.clone()) } @@ -568,6 +610,7 @@ class SparseMatrix( this } + @Since("1.3.0") override def transpose: SparseMatrix = new SparseMatrix(numCols, numRows, colPtrs, rowIndices, values, !isTransposed) @@ -602,12 +645,15 @@ class SparseMatrix( * Generate a `DenseMatrix` from the given `SparseMatrix`. The new matrix will have isTransposed * set to false. */ + @Since("1.3.0") def toDense: DenseMatrix = { new DenseMatrix(numRows, numCols, toArray) } + @Since("1.5.0") override def numNonzeros: Int = values.count(_ != 0) + @Since("1.5.0") override def numActives: Int = values.length } @@ -615,6 +661,7 @@ class SparseMatrix( /** * Factory methods for [[org.apache.spark.mllib.linalg.SparseMatrix]]. 
*/ +@Since("1.3.0") object SparseMatrix { /** @@ -626,6 +673,7 @@ object SparseMatrix { * @param entries Array of (i, j, value) tuples * @return The corresponding `SparseMatrix` */ + @Since("1.3.0") def fromCOO(numRows: Int, numCols: Int, entries: Iterable[(Int, Int, Double)]): SparseMatrix = { val sortedEntries = entries.toSeq.sortBy(v => (v._2, v._1)) val numEntries = sortedEntries.size @@ -674,6 +722,7 @@ object SparseMatrix { * @param n number of rows and columns of the matrix * @return `SparseMatrix` with size `n` x `n` and values of ones on the diagonal */ + @Since("1.3.0") def speye(n: Int): SparseMatrix = { new SparseMatrix(n, n, (0 to n).toArray, (0 until n).toArray, Array.fill(n)(1.0)) } @@ -743,6 +792,7 @@ object SparseMatrix { * @param rng a random number generator * @return `SparseMatrix` with size `numRows` x `numCols` and values in U(0, 1) */ + @Since("1.3.0") def sprand(numRows: Int, numCols: Int, density: Double, rng: Random): SparseMatrix = { val mat = genRandMatrix(numRows, numCols, density, rng) mat.update(i => rng.nextDouble()) @@ -756,6 +806,7 @@ object SparseMatrix { * @param rng a random number generator * @return `SparseMatrix` with size `numRows` x `numCols` and values in N(0, 1) */ + @Since("1.3.0") def sprandn(numRows: Int, numCols: Int, density: Double, rng: Random): SparseMatrix = { val mat = genRandMatrix(numRows, numCols, density, rng) mat.update(i => rng.nextGaussian()) @@ -767,6 +818,7 @@ object SparseMatrix { * @return Square `SparseMatrix` with size `values.length` x `values.length` and non-zero * `values` on the diagonal */ + @Since("1.3.0") def spdiag(vector: Vector): SparseMatrix = { val n = vector.size vector match { @@ -783,6 +835,7 @@ object SparseMatrix { /** * Factory methods for [[org.apache.spark.mllib.linalg.Matrix]]. 
*/ +@Since("1.0.0") object Matrices { /** @@ -792,6 +845,7 @@ object Matrices { * @param numCols number of columns * @param values matrix entries in column major */ + @Since("1.0.0") def dense(numRows: Int, numCols: Int, values: Array[Double]): Matrix = { new DenseMatrix(numRows, numCols, values) } @@ -805,6 +859,7 @@ object Matrices { * @param rowIndices the row index of the entry * @param values non-zero matrix entries in column major */ + @Since("1.2.0") def sparse( numRows: Int, numCols: Int, @@ -838,6 +893,7 @@ object Matrices { * @param numCols number of columns of the matrix * @return `Matrix` with size `numRows` x `numCols` and values of zeros */ + @Since("1.2.0") def zeros(numRows: Int, numCols: Int): Matrix = DenseMatrix.zeros(numRows, numCols) /** @@ -846,6 +902,7 @@ object Matrices { * @param numCols number of columns of the matrix * @return `Matrix` with size `numRows` x `numCols` and values of ones */ + @Since("1.2.0") def ones(numRows: Int, numCols: Int): Matrix = DenseMatrix.ones(numRows, numCols) /** @@ -853,6 +910,7 @@ object Matrices { * @param n number of rows and columns of the matrix * @return `Matrix` with size `n` x `n` and values of ones on the diagonal */ + @Since("1.2.0") def eye(n: Int): Matrix = DenseMatrix.eye(n) /** @@ -860,6 +918,7 @@ object Matrices { * @param n number of rows and columns of the matrix * @return `Matrix` with size `n` x `n` and values of ones on the diagonal */ + @Since("1.3.0") def speye(n: Int): Matrix = SparseMatrix.speye(n) /** @@ -869,6 +928,7 @@ object Matrices { * @param rng a random number generator * @return `Matrix` with size `numRows` x `numCols` and values in U(0, 1) */ + @Since("1.2.0") def rand(numRows: Int, numCols: Int, rng: Random): Matrix = DenseMatrix.rand(numRows, numCols, rng) @@ -880,6 +940,7 @@ object Matrices { * @param rng a random number generator * @return `Matrix` with size `numRows` x `numCols` and values in U(0, 1) */ + @Since("1.3.0") def sprand(numRows: Int, numCols: Int, density: Double, rng: Random): Matrix = SparseMatrix.sprand(numRows, numCols, density, rng) @@ -890,6 +951,7 @@ object Matrices { * @param rng a random number generator * @return `Matrix` with size `numRows` x `numCols` and values in N(0, 1) */ + @Since("1.2.0") def randn(numRows: Int, numCols: Int, rng: Random): Matrix = DenseMatrix.randn(numRows, numCols, rng) @@ -901,6 +963,7 @@ object Matrices { * @param rng a random number generator * @return `Matrix` with size `numRows` x `numCols` and values in N(0, 1) */ + @Since("1.3.0") def sprandn(numRows: Int, numCols: Int, density: Double, rng: Random): Matrix = SparseMatrix.sprandn(numRows, numCols, density, rng) @@ -910,6 +973,7 @@ object Matrices { * @return Square `Matrix` with size `values.length` x `values.length` and `values` * on the diagonal */ + @Since("1.2.0") def diag(vector: Vector): Matrix = DenseMatrix.diag(vector) /** @@ -919,6 +983,7 @@ object Matrices { * @param matrices array of matrices * @return a single `Matrix` composed of the matrices that were horizontally concatenated */ + @Since("1.3.0") def horzcat(matrices: Array[Matrix]): Matrix = { if (matrices.isEmpty) { return new DenseMatrix(0, 0, Array[Double]()) @@ -977,6 +1042,7 @@ object Matrices { * @param matrices array of matrices * @return a single `Matrix` composed of the matrices that were vertically concatenated */ + @Since("1.3.0") def vertcat(matrices: Array[Matrix]): Matrix = { if (matrices.isEmpty) { return new DenseMatrix(0, 0, Array[Double]()) diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala index b416d50a5631..4591cb88ef15 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala @@ -17,19 +17,19 @@ package org.apache.spark.mllib.linalg -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} /** - * :: Experimental :: * Represents singular value decomposition (SVD) factors. */ -@Experimental +@Since("1.0.0") case class SingularValueDecomposition[UType, VType](U: UType, s: Vector, V: VType) /** * :: Experimental :: * Represents QR factors. */ +@Since("1.5.0") @Experimental -case class QRDecomposition[UType, VType](Q: UType, R: VType) +case class QRDecomposition[QType, RType](Q: QType, R: RType) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 96d1f48ba2ba..4dcf351df43f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -24,12 +24,16 @@ import scala.annotation.varargs import scala.collection.JavaConverters._ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} +import org.json4s.DefaultFormats +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.{compact, render, parse => parseJson} import org.apache.spark.SparkException -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{AlphaComponent, Since} import org.apache.spark.mllib.util.NumericParser import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ /** @@ -38,16 +42,19 @@ import org.apache.spark.sql.types._ * Note: Users should not implement this interface. */ @SQLUserDefinedType(udt = classOf[VectorUDT]) +@Since("1.0.0") sealed trait Vector extends Serializable { /** * Size of the vector. */ + @Since("1.0.0") def size: Int /** * Converts the instance to a double array. */ + @Since("1.0.0") def toArray: Array[Double] override def equals(other: Any): Boolean = { @@ -68,20 +75,22 @@ sealed trait Vector extends Serializable { } /** - * Returns a hash code value for the vector. The hash code is based on its size and its nonzeros - * in the first 16 entries, using a hash algorithm similar to [[java.util.Arrays.hashCode]]. + * Returns a hash code value for the vector. The hash code is based on its size and its first 128 + * nonzero entries, using a hash algorithm similar to [[java.util.Arrays.hashCode]]. */ override def hashCode(): Int = { // This is a reference implementation. It calls return in foreachActive, which is slow. // Subclasses should override it with optimized implementation. var result: Int = 31 + size + var nnz = 0 this.foreachActive { (index, value) => - if (index < 16) { + if (nnz < Vectors.MAX_HASH_NNZ) { // ignore explicit 0 for comparison between sparse and dense if (value != 0) { result = 31 * result + index val bits = java.lang.Double.doubleToLongBits(value) result = 31 * result + (bits ^ (bits >>> 32)).toInt + nnz += 1 } } else { return result @@ -99,11 +108,13 @@ sealed trait Vector extends Serializable { * Gets the value of the ith element. 
* @param i index */ + @Since("1.1.0") def apply(i: Int): Double = toBreeze(i) /** * Makes a deep copy of this vector. */ + @Since("1.1.0") def copy: Vector = { throw new NotImplementedError(s"copy is not implemented for ${this.getClass}.") } @@ -115,32 +126,38 @@ sealed trait Vector extends Serializable { * the vector with type `Int`, and the second parameter is the corresponding value * with type `Double`. */ - private[spark] def foreachActive(f: (Int, Double) => Unit) + @Since("1.6.0") + def foreachActive(f: (Int, Double) => Unit): Unit /** * Number of active entries. An "active entry" is an element which is explicitly stored, * regardless of its value. Note that inactive entries have value 0. */ + @Since("1.4.0") def numActives: Int /** * Number of nonzero elements. This scans all active values and count nonzeros. */ + @Since("1.4.0") def numNonzeros: Int /** * Converts this vector to a sparse vector with all explicit zeros removed. */ + @Since("1.4.0") def toSparse: SparseVector /** * Converts this vector to a dense vector. */ + @Since("1.4.0") def toDense: DenseVector = new DenseVector(this.toArray) /** * Returns a vector in either dense or sparse format, whichever uses less storage. */ + @Since("1.4.0") def compressed: Vector = { val nnz = numNonzeros // A dense vector needs 8 * size + 8 bytes, while a sparse vector needs 12 * nnz + 20 bytes. @@ -155,19 +172,24 @@ sealed trait Vector extends Serializable { * Find the index of a maximal element. Returns the first maximal element in case of a tie. * Returns -1 if vector has length 0. */ + @Since("1.5.0") def argmax: Int + + /** + * Converts the vector to a JSON string. + */ + @Since("1.6.0") + def toJson: String } /** - * :: DeveloperApi :: + * :: AlphaComponent :: * * User-defined type for [[Vector]] which allows easy interaction with SQL * via [[org.apache.spark.sql.DataFrame]]. - * - * NOTE: This is currently private[spark] but will be made public later once it is stabilized. */ -@DeveloperApi -private[spark] class VectorUDT extends UserDefinedType[Vector] { +@AlphaComponent +class VectorUDT extends UserDefinedType[Vector] { override def sqlType: StructType = { // type: 0 = sparse, 1 = dense @@ -243,11 +265,13 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { * We don't use the name `Vector` because Scala imports * [[scala.collection.immutable.Vector]] by default. */ +@Since("1.0.0") object Vectors { /** * Creates a dense vector from its values. */ + @Since("1.0.0") @varargs def dense(firstValue: Double, otherValues: Double*): Vector = new DenseVector((firstValue +: otherValues).toArray) @@ -256,6 +280,7 @@ object Vectors { /** * Creates a dense vector from a double array. */ + @Since("1.0.0") def dense(values: Array[Double]): Vector = new DenseVector(values) /** @@ -265,6 +290,7 @@ object Vectors { * @param indices index array, must be strictly increasing. * @param values value array, must have the same length as indices. */ + @Since("1.0.0") def sparse(size: Int, indices: Array[Int], values: Array[Double]): Vector = new SparseVector(size, indices, values) @@ -274,6 +300,7 @@ object Vectors { * @param size vector size. * @param elements vector elements in (index, value) pairs. */ + @Since("1.0.0") def sparse(size: Int, elements: Seq[(Int, Double)]): Vector = { require(size > 0, "The size of the requested sparse vector must be greater than 0.") @@ -295,6 +322,7 @@ object Vectors { * @param size vector size. * @param elements vector elements in (index, value) pairs. 
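A short usage sketch of the factory methods and the `compressed` heuristic described above; the byte estimates in the comment restate the ones in the code:

    import org.apache.spark.mllib.linalg.Vectors

    val dv = Vectors.dense(1.0, 0.0, 0.0, 0.0, 3.0)           // dense storage
    val sv = Vectors.sparse(5, Array(0, 4), Array(1.0, 3.0))  // the same vector, sparse storage
    println(dv.numNonzeros)        // 2
    println(dv.toSparse)           // (5,[0,4],[1.0,3.0])
    // compressed keeps whichever form is smaller: roughly 12 * nnz + 20 bytes for sparse
    // versus 8 * size + 8 bytes for dense, so this 5-element vector with 2 non-zeros goes sparse.
    println(dv.compressed == sv)   // true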
*/ + @Since("1.0.0") def sparse(size: Int, elements: JavaIterable[(JavaInteger, JavaDouble)]): Vector = { sparse(size, elements.asScala.map { case (i, x) => (i.intValue(), x.doubleValue()) @@ -307,6 +335,7 @@ object Vectors { * @param size vector size * @return a zero vector */ + @Since("1.1.0") def zeros(size: Int): Vector = { new DenseVector(new Array[Double](size)) } @@ -314,10 +343,32 @@ object Vectors { /** * Parses a string resulted from [[Vector.toString]] into a [[Vector]]. */ + @Since("1.1.0") def parse(s: String): Vector = { parseNumeric(NumericParser.parse(s)) } + /** + * Parses the JSON representation of a vector into a [[Vector]]. + */ + @Since("1.6.0") + def fromJson(json: String): Vector = { + implicit val formats = DefaultFormats + val jValue = parseJson(json) + (jValue \ "type").extract[Int] match { + case 0 => // sparse + val size = (jValue \ "size").extract[Int] + val indices = (jValue \ "indices").extract[Seq[Int]].toArray + val values = (jValue \ "values").extract[Seq[Double]].toArray + sparse(size, indices, values) + case 1 => // dense + val values = (jValue \ "values").extract[Seq[Double]].toArray + dense(values) + case _ => + throw new IllegalArgumentException(s"Cannot parse $json into a vector.") + } + } + private[mllib] def parseNumeric(any: Any): Vector = { any match { case values: Array[Double] => @@ -357,6 +408,7 @@ object Vectors { * @param p norm. * @return norm in L^p^ space. */ + @Since("1.3.0") def norm(vector: Vector, p: Double): Double = { require(p >= 1.0, "To compute the p-norm of the vector, we require that you specify a p>=1. " + s"You specified p=$p.") @@ -409,6 +461,7 @@ object Vectors { * @param v2 second Vector. * @return squared distance between two Vectors. */ + @Since("1.3.0") def sqdist(v1: Vector, v2: Vector): Double = { require(v1.size == v2.size, s"Vector dimensions do not match: Dim(v1)=${v1.size} and Dim(v2)" + s"=${v2.size}.") @@ -517,29 +570,39 @@ object Vectors { } allEqual } + + /** Max number of nonzero entries used in computing hash code. */ + private[linalg] val MAX_HASH_NNZ = 128 } /** * A dense vector represented by a value array. 
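A round-trip sketch of the JSON support added here (`Vectors.fromJson` above and the per-class `toJson` implementations that follow); the exact field order shown in the comment assumes the json4s DSL output as written in this patch:

    import org.apache.spark.mllib.linalg.Vectors

    val v = Vectors.dense(1.0, 2.0)
    val json = v.toJson                   // e.g. {"type":1,"values":[1.0,2.0]}
    println(Vectors.fromJson(json) == v)  // true: the representation round-trips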
*/ +@Since("1.0.0") @SQLUserDefinedType(udt = classOf[VectorUDT]) -class DenseVector(val values: Array[Double]) extends Vector { +class DenseVector @Since("1.0.0") ( + @Since("1.0.0") val values: Array[Double]) extends Vector { + @Since("1.0.0") override def size: Int = values.length override def toString: String = values.mkString("[", ",", "]") + @Since("1.0.0") override def toArray: Array[Double] = values private[spark] override def toBreeze: BV[Double] = new BDV[Double](values) + @Since("1.0.0") override def apply(i: Int): Double = values(i) + @Since("1.1.0") override def copy: DenseVector = { new DenseVector(values.clone()) } - private[spark] override def foreachActive(f: (Int, Double) => Unit) = { + @Since("1.6.0") + override def foreachActive(f: (Int, Double) => Unit): Unit = { var i = 0 val localValuesSize = values.length val localValues = values @@ -553,21 +616,25 @@ class DenseVector(val values: Array[Double]) extends Vector { override def hashCode(): Int = { var result: Int = 31 + size var i = 0 - val end = math.min(values.length, 16) - while (i < end) { + val end = values.length + var nnz = 0 + while (i < end && nnz < Vectors.MAX_HASH_NNZ) { val v = values(i) if (v != 0.0) { result = 31 * result + i val bits = java.lang.Double.doubleToLongBits(values(i)) result = 31 * result + (bits ^ (bits >>> 32)).toInt + nnz += 1 } i += 1 } result } + @Since("1.4.0") override def numActives: Int = size + @Since("1.4.0") override def numNonzeros: Int = { // same as values.count(_ != 0.0) but faster var nnz = 0 @@ -579,6 +646,7 @@ class DenseVector(val values: Array[Double]) extends Vector { nnz } + @Since("1.4.0") override def toSparse: SparseVector = { val nnz = numNonzeros val ii = new Array[Int](nnz) @@ -594,6 +662,7 @@ class DenseVector(val values: Array[Double]) extends Vector { new SparseVector(size, ii, vv) } + @Since("1.5.0") override def argmax: Int = { if (size == 0) { -1 @@ -611,10 +680,19 @@ class DenseVector(val values: Array[Double]) extends Vector { maxIdx } } + + @Since("1.6.0") + override def toJson: String = { + val jValue = ("type" -> 1) ~ ("values" -> values.toSeq) + compact(render(jValue)) + } } +@Since("1.3.0") object DenseVector { + /** Extracts the value array from a dense vector. */ + @Since("1.3.0") def unapply(dv: DenseVector): Option[Array[Double]] = Some(dv.values) } @@ -625,11 +703,12 @@ object DenseVector { * @param indices index array, assume to be strictly increasing. * @param values value array, must have the same length as the index array. */ +@Since("1.0.0") @SQLUserDefinedType(udt = classOf[VectorUDT]) -class SparseVector( - override val size: Int, - val indices: Array[Int], - val values: Array[Double]) extends Vector { +class SparseVector @Since("1.0.0") ( + @Since("1.0.0") override val size: Int, + @Since("1.0.0") val indices: Array[Int], + @Since("1.0.0") val values: Array[Double]) extends Vector { require(indices.length == values.length, "Sparse vectors require that the dimension of the" + s" indices match the dimension of the values. 
You provided ${indices.length} indices and " + @@ -640,6 +719,7 @@ class SparseVector( override def toString: String = s"($size,${indices.mkString("[", ",", "]")},${values.mkString("[", ",", "]")})" + @Since("1.0.0") override def toArray: Array[Double] = { val data = new Array[Double](size) var i = 0 @@ -651,13 +731,15 @@ class SparseVector( data } + @Since("1.1.0") override def copy: SparseVector = { new SparseVector(size, indices.clone(), values.clone()) } private[spark] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size) - private[spark] override def foreachActive(f: (Int, Double) => Unit) = { + @Since("1.6.0") + override def foreachActive(f: (Int, Double) => Unit): Unit = { var i = 0 val localValuesSize = values.length val localIndices = indices @@ -672,27 +754,26 @@ class SparseVector( override def hashCode(): Int = { var result: Int = 31 + size val end = values.length - var continue = true var k = 0 - while ((k < end) & continue) { - val i = indices(k) - if (i < 16) { - val v = values(k) - if (v != 0.0) { - result = 31 * result + i - val bits = java.lang.Double.doubleToLongBits(v) - result = 31 * result + (bits ^ (bits >>> 32)).toInt - } - } else { - continue = false + var nnz = 0 + while (k < end && nnz < Vectors.MAX_HASH_NNZ) { + val v = values(k) + if (v != 0.0) { + val i = indices(k) + result = 31 * result + i + val bits = java.lang.Double.doubleToLongBits(v) + result = 31 * result + (bits ^ (bits >>> 32)).toInt + nnz += 1 } k += 1 } result } + @Since("1.4.0") override def numActives: Int = values.length + @Since("1.4.0") override def numNonzeros: Int = { var nnz = 0 values.foreach { v => @@ -703,6 +784,7 @@ class SparseVector( nnz } + @Since("1.4.0") override def toSparse: SparseVector = { val nnz = numNonzeros if (nnz == numActives) { @@ -722,6 +804,7 @@ class SparseVector( } } + @Since("1.5.0") override def argmax: Int = { if (size == 0) { -1 @@ -766,9 +849,44 @@ class SparseVector( maxIdx } } + + /** + * Create a slice of this vector based on the given indices. + * @param selectedIndices Unsorted list of indices into the vector. + * This does NOT do bound checking. + * @return New SparseVector with values in the order specified by the given indices. + * + * NOTE: The API needs to be discussed before making this public. + * Also, if we have a version assuming indices are sorted, we should optimize it. 
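The point of the hashCode rewrite above (and of its dense counterpart earlier in this file) is that equal vectors keep equal hashes while only the first `Vectors.MAX_HASH_NNZ` = 128 nonzeros are folded in; a tiny consistency check:

import org.apache.spark.mllib.linalg.Vectors

object VectorHashSketch {
  def main(args: Array[String]): Unit = {
    val dense  = Vectors.dense(0.0, 9.0, 0.0, 0.0, 4.0)
    val sparse = Vectors.sparse(5, Array(1, 4), Array(9.0, 4.0))

    // Both implementations walk the nonzeros in index order, skip explicit zeros,
    // and stop after 128 of them, so the two representations must agree.
    assert(dense == sparse)
    assert(dense.hashCode == sparse.hashCode)
  }
}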
+ */ + private[spark] def slice(selectedIndices: Array[Int]): SparseVector = { + var currentIdx = 0 + val (sliceInds, sliceVals) = selectedIndices.flatMap { origIdx => + val iIdx = java.util.Arrays.binarySearch(this.indices, origIdx) + val i_v = if (iIdx >= 0) { + Iterator((currentIdx, this.values(iIdx))) + } else { + Iterator() + } + currentIdx += 1 + i_v + }.unzip + new SparseVector(selectedIndices.length, sliceInds.toArray, sliceVals.toArray) + } + + @Since("1.6.0") + override def toJson: String = { + val jValue = ("type" -> 0) ~ + ("size" -> size) ~ + ("indices" -> indices.toSeq) ~ + ("values" -> values.toSeq) + compact(render(jValue)) + } } +@Since("1.3.0") object SparseVector { + @Since("1.3.0") def unapply(sv: SparseVector): Option[(Int, Array[Int], Array[Double])] = Some((sv.size, sv.indices, sv.values)) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala index 3323ae7b1fba..09527dcf5d9e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer import breeze.linalg.{DenseMatrix => BDM} import org.apache.spark.{Logging, Partitioner, SparkException} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Matrix, SparseMatrix} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -54,12 +54,14 @@ private[mllib] class GridPartitioner( /** * Returns the index of the partition the input coordinate belongs to. * - * @param key The coordinate (i, j) or a tuple (i, j, k), where k is the inner index used in - * multiplication. k is ignored in computing partitions. + * @param key The partition id i (calculated through this method for coordinate (i, j) in + * `simulateMultiply`, the coordinate (i, j) or a tuple (i, j, k), where k is + * the inner index used in multiplication. k is ignored in computing partitions. * @return The index of the partition, which the coordinate belongs to. */ override def getPartition(key: Any): Int = { key match { + case i: Int => i case (i: Int, j: Int) => getPartitionId(i, j) case (i: Int, j: Int, _: Int) => @@ -113,8 +115,6 @@ private[mllib] object GridPartitioner { } /** - * :: Experimental :: - * * Represents a distributed matrix in blocks of local matrices. * * @param blocks The RDD of sub-matrix blocks ((blockRowIndex, blockColIndex), sub-matrix) that @@ -129,11 +129,11 @@ private[mllib] object GridPartitioner { * @param nCols Number of columns of this matrix. If the supplied value is less than or equal to * zero, the number of columns will be calculated when `numCols` is invoked. */ -@Experimental -class BlockMatrix( - val blocks: RDD[((Int, Int), Matrix)], - val rowsPerBlock: Int, - val colsPerBlock: Int, +@Since("1.3.0") +class BlockMatrix @Since("1.3.0") ( + @Since("1.3.0") val blocks: RDD[((Int, Int), Matrix)], + @Since("1.3.0") val rowsPerBlock: Int, + @Since("1.3.0") val colsPerBlock: Int, private var nRows: Long, private var nCols: Long) extends DistributedMatrix with Logging { @@ -150,6 +150,7 @@ class BlockMatrix( * @param colsPerBlock Number of columns that make up each block. 
The blocks forming the final * columns are not required to have the given number of columns */ + @Since("1.3.0") def this( blocks: RDD[((Int, Int), Matrix)], rowsPerBlock: Int, @@ -157,17 +158,21 @@ class BlockMatrix( this(blocks, rowsPerBlock, colsPerBlock, 0L, 0L) } + @Since("1.3.0") override def numRows(): Long = { if (nRows <= 0L) estimateDim() nRows } + @Since("1.3.0") override def numCols(): Long = { if (nCols <= 0L) estimateDim() nCols } + @Since("1.3.0") val numRowBlocks = math.ceil(numRows() * 1.0 / rowsPerBlock).toInt + @Since("1.3.0") val numColBlocks = math.ceil(numCols() * 1.0 / colsPerBlock).toInt private[mllib] def createPartitioner(): GridPartitioner = @@ -193,6 +198,7 @@ class BlockMatrix( * Validates the block matrix info against the matrix data (`blocks`) and throws an exception if * any error is found. */ + @Since("1.3.0") def validate(): Unit = { logDebug("Validating BlockMatrix...") // check if the matrix is larger than the claimed dimensions @@ -229,18 +235,21 @@ class BlockMatrix( } /** Caches the underlying RDD. */ + @Since("1.3.0") def cache(): this.type = { blocks.cache() this } /** Persists the underlying RDD with the specified storage level. */ + @Since("1.3.0") def persist(storageLevel: StorageLevel): this.type = { blocks.persist(storageLevel) this } /** Converts to CoordinateMatrix. */ + @Since("1.3.0") def toCoordinateMatrix(): CoordinateMatrix = { val entryRDD = blocks.flatMap { case ((blockRowIndex, blockColIndex), mat) => val rowStart = blockRowIndex.toLong * rowsPerBlock @@ -255,6 +264,7 @@ class BlockMatrix( } /** Converts to IndexedRowMatrix. The number of columns must be within the integer range. */ + @Since("1.3.0") def toIndexedRowMatrix(): IndexedRowMatrix = { require(numCols() < Int.MaxValue, "The number of columns must be within the integer range. " + s"numCols: ${numCols()}") @@ -263,6 +273,7 @@ class BlockMatrix( } /** Collect the distributed matrix on the driver as a `DenseMatrix`. */ + @Since("1.3.0") def toLocalMatrix(): Matrix = { require(numRows() < Int.MaxValue, "The number of rows of this matrix should be less than " + s"Int.MaxValue. Currently numRows: ${numRows()}") @@ -287,8 +298,11 @@ class BlockMatrix( new DenseMatrix(m, n, values) } - /** Transpose this `BlockMatrix`. Returns a new `BlockMatrix` instance sharing the - * same underlying data. Is a lazy operation. */ + /** + * Transpose this `BlockMatrix`. Returns a new `BlockMatrix` instance sharing the + * same underlying data. Is a lazy operation. + */ + @Since("1.3.0") def transpose: BlockMatrix = { val transposedBlocks = blocks.map { case ((blockRowIndex, blockColIndex), mat) => ((blockColIndex, blockRowIndex), mat.transpose) @@ -302,12 +316,14 @@ class BlockMatrix( new BDM[Double](localMat.numRows, localMat.numCols, localMat.toArray) } - /** Adds two block matrices together. The matrices must have the same size and matching - * `rowsPerBlock` and `colsPerBlock` values. If one of the blocks that are being added are - * instances of [[SparseMatrix]], the resulting sub matrix will also be a [[SparseMatrix]], even - * if it is being added to a [[DenseMatrix]]. If two dense matrices are added, the output will - * also be a [[DenseMatrix]]. - */ + /** + * Adds two block matrices together. The matrices must have the same size and matching + * `rowsPerBlock` and `colsPerBlock` values. If one of the blocks that are being added are + * instances of [[SparseMatrix]], the resulting sub matrix will also be a [[SparseMatrix]], even + * if it is being added to a [[DenseMatrix]]. 
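A compact usage sketch for the `BlockMatrix` operations annotated in this hunk (construction, `validate`, `add`, `transpose`, `multiply`, `toLocalMatrix`). It assumes an already-created SparkContext named `sc`, which is not part of this patch.

import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Matrices
import org.apache.spark.mllib.linalg.distributed.BlockMatrix

def blockMatrixSketch(sc: SparkContext): Unit = {
  // A 4 x 4 matrix stored as four 2 x 2 dense blocks, keyed by (blockRow, blockCol).
  val blocks = sc.parallelize(Seq(
    ((0, 0), Matrices.dense(2, 2, Array(1.0, 2.0, 3.0, 4.0))),
    ((0, 1), Matrices.dense(2, 2, Array(0.0, 1.0, 0.0, 1.0))),
    ((1, 0), Matrices.dense(2, 2, Array(5.0, 6.0, 7.0, 8.0))),
    ((1, 1), Matrices.dense(2, 2, Array(9.0, 0.0, 0.0, 9.0)))))
  val A = new BlockMatrix(blocks, 2, 2).cache()
  A.validate()                        // checks block sizes against rowsPerBlock/colsPerBlock

  val doubled = A.add(A)              // same block dimensions on both sides, so this is legal
  val gram = A.transpose.multiply(A)  // lazy transpose, then block-wise multiplication
  println(doubled.toLocalMatrix())
  println(gram.numRows() + " x " + gram.numCols())
}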
If two dense matrices are added, the output will + * also be a [[DenseMatrix]]. + */ + @Since("1.3.0") def add(other: BlockMatrix): BlockMatrix = { require(numRows() == other.numRows(), "Both matrices must have the same number of rows. " + s"A.numRows: ${numRows()}, B.numRows: ${other.numRows()}") @@ -335,12 +351,51 @@ class BlockMatrix( } } - /** Left multiplies this [[BlockMatrix]] to `other`, another [[BlockMatrix]]. The `colsPerBlock` - * of this matrix must equal the `rowsPerBlock` of `other`. If `other` contains - * [[SparseMatrix]], they will have to be converted to a [[DenseMatrix]]. The output - * [[BlockMatrix]] will only consist of blocks of [[DenseMatrix]]. This may cause - * some performance issues until support for multiplying two sparse matrices is added. - */ + /** Block (i,j) --> Set of destination partitions */ + private type BlockDestinations = Map[(Int, Int), Set[Int]] + + /** + * Simulate the multiplication with just block indices in order to cut costs on communication, + * when we are actually shuffling the matrices. + * The `colsPerBlock` of this matrix must equal the `rowsPerBlock` of `other`. + * Exposed for tests. + * + * @param other The BlockMatrix to multiply + * @param partitioner The partitioner that will be used for the resulting matrix `C = A * B` + * @return A tuple of [[BlockDestinations]]. The first element is the Map of the set of partitions + * that we need to shuffle each blocks of `this`, and the second element is the Map for + * `other`. + */ + private[distributed] def simulateMultiply( + other: BlockMatrix, + partitioner: GridPartitioner): (BlockDestinations, BlockDestinations) = { + val leftMatrix = blockInfo.keys.collect() // blockInfo should already be cached + val rightMatrix = other.blocks.keys.collect() + val leftDestinations = leftMatrix.map { case (rowIndex, colIndex) => + val rightCounterparts = rightMatrix.filter(_._1 == colIndex) + val partitions = rightCounterparts.map(b => partitioner.getPartition((rowIndex, b._2))) + ((rowIndex, colIndex), partitions.toSet) + }.toMap + val rightDestinations = rightMatrix.map { case (rowIndex, colIndex) => + val leftCounterparts = leftMatrix.filter(_._2 == rowIndex) + val partitions = leftCounterparts.map(b => partitioner.getPartition((b._1, colIndex))) + ((rowIndex, colIndex), partitions.toSet) + }.toMap + (leftDestinations, rightDestinations) + } + + /** + * Left multiplies this [[BlockMatrix]] to `other`, another [[BlockMatrix]]. The `colsPerBlock` + * of this matrix must equal the `rowsPerBlock` of `other`. If `other` contains + * [[SparseMatrix]], they will have to be converted to a [[DenseMatrix]]. The output + * [[BlockMatrix]] will only consist of blocks of [[DenseMatrix]]. This may cause + * some performance issues until support for multiplying two sparse matrices is added. + * + * Note: The behavior of multiply has changed in 1.6.0. `multiply` used to throw an error when + * there were blocks with duplicate indices. Now, the blocks with duplicate indices will be added + * with each other. + */ + @Since("1.3.0") def multiply(other: BlockMatrix): BlockMatrix = { require(numCols() == other.numRows(), "The number of columns of A and the number of rows " + s"of B must be equal. A.numCols: ${numCols()}, B.numRows: ${other.numRows()}. 
If you " + @@ -349,33 +404,30 @@ class BlockMatrix( if (colsPerBlock == other.rowsPerBlock) { val resultPartitioner = GridPartitioner(numRowBlocks, other.numColBlocks, math.max(blocks.partitions.length, other.blocks.partitions.length)) - // Each block of A must be multiplied with the corresponding blocks in each column of B. - // TODO: Optimize to send block to a partition once, similar to ALS + val (leftDestinations, rightDestinations) = simulateMultiply(other, resultPartitioner) + // Each block of A must be multiplied with the corresponding blocks in the columns of B. val flatA = blocks.flatMap { case ((blockRowIndex, blockColIndex), block) => - Iterator.tabulate(other.numColBlocks)(j => ((blockRowIndex, j, blockColIndex), block)) + val destinations = leftDestinations.getOrElse((blockRowIndex, blockColIndex), Set.empty) + destinations.map(j => (j, (blockRowIndex, blockColIndex, block))) } // Each block of B must be multiplied with the corresponding blocks in each row of A. val flatB = other.blocks.flatMap { case ((blockRowIndex, blockColIndex), block) => - Iterator.tabulate(numRowBlocks)(i => ((i, blockColIndex, blockRowIndex), block)) + val destinations = rightDestinations.getOrElse((blockRowIndex, blockColIndex), Set.empty) + destinations.map(j => (j, (blockRowIndex, blockColIndex, block))) } - val newBlocks: RDD[MatrixBlock] = flatA.cogroup(flatB, resultPartitioner) - .flatMap { case ((blockRowIndex, blockColIndex, _), (a, b)) => - if (a.size > 1 || b.size > 1) { - throw new SparkException("There are multiple MatrixBlocks with indices: " + - s"($blockRowIndex, $blockColIndex). Please remove them.") - } - if (a.nonEmpty && b.nonEmpty) { - val C = b.head match { - case dense: DenseMatrix => a.head.multiply(dense) - case sparse: SparseMatrix => a.head.multiply(sparse.toDense) - case _ => throw new SparkException(s"Unrecognized matrix type ${b.head.getClass}.") + val newBlocks = flatA.cogroup(flatB, resultPartitioner).flatMap { case (pId, (a, b)) => + a.flatMap { case (leftRowIndex, leftColIndex, leftBlock) => + b.filter(_._1 == leftColIndex).map { case (rightRowIndex, rightColIndex, rightBlock) => + val C = rightBlock match { + case dense: DenseMatrix => leftBlock.multiply(dense) + case sparse: SparseMatrix => leftBlock.multiply(sparse.toDense) + case _ => + throw new SparkException(s"Unrecognized matrix type ${rightBlock.getClass}.") } - Iterator(((blockRowIndex, blockColIndex), C.toBreeze)) - } else { - Iterator() + ((leftRowIndex, rightColIndex), C.toBreeze) } - }.reduceByKey(resultPartitioner, (a, b) => a + b) - .mapValues(Matrices.fromBreeze) + } + }.reduceByKey(resultPartitioner, (a, b) => a + b).mapValues(Matrices.fromBreeze) // TODO: Try to use aggregateByKey instead of reduceByKey to get rid of intermediate matrices new BlockMatrix(newBlocks, rowsPerBlock, other.colsPerBlock, numRows(), other.numCols()) } else { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala index 078d1fac4444..8a70f34e70f6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala @@ -19,22 +19,20 @@ package org.apache.spark.mllib.linalg.distributed import breeze.linalg.{DenseMatrix => BDM} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.Since import org.apache.spark.rdd.RDD import 
org.apache.spark.mllib.linalg.{Matrix, SparseMatrix, Vectors} /** - * :: Experimental :: * Represents an entry in an distributed matrix. * @param i row index * @param j column index * @param value value of the entry */ -@Experimental +@Since("1.0.0") case class MatrixEntry(i: Long, j: Long, value: Double) /** - * :: Experimental :: * Represents a matrix in coordinate format. * * @param entries matrix entries @@ -43,16 +41,18 @@ case class MatrixEntry(i: Long, j: Long, value: Double) * @param nCols number of columns. A non-positive value means unknown, and then the number of * columns will be determined by the max column index plus one. */ -@Experimental -class CoordinateMatrix( - val entries: RDD[MatrixEntry], +@Since("1.0.0") +class CoordinateMatrix @Since("1.0.0") ( + @Since("1.0.0") val entries: RDD[MatrixEntry], private var nRows: Long, private var nCols: Long) extends DistributedMatrix { /** Alternative constructor leaving matrix dimensions to be determined automatically. */ + @Since("1.0.0") def this(entries: RDD[MatrixEntry]) = this(entries, 0L, 0L) /** Gets or computes the number of columns. */ + @Since("1.0.0") override def numCols(): Long = { if (nCols <= 0L) { computeSize() @@ -61,6 +61,7 @@ class CoordinateMatrix( } /** Gets or computes the number of rows. */ + @Since("1.0.0") override def numRows(): Long = { if (nRows <= 0L) { computeSize() @@ -69,11 +70,13 @@ class CoordinateMatrix( } /** Transposes this CoordinateMatrix. */ + @Since("1.3.0") def transpose(): CoordinateMatrix = { new CoordinateMatrix(entries.map(x => MatrixEntry(x.j, x.i, x.value)), numCols(), numRows()) } /** Converts to IndexedRowMatrix. The number of columns must be within the integer range. */ + @Since("1.0.0") def toIndexedRowMatrix(): IndexedRowMatrix = { val nl = numCols() if (nl > Int.MaxValue) { @@ -93,11 +96,13 @@ class CoordinateMatrix( * Converts to RowMatrix, dropping row indices after grouping by row index. * The number of columns must be within the integer range. */ + @Since("1.0.0") def toRowMatrix(): RowMatrix = { toIndexedRowMatrix().toRowMatrix() } /** Converts to BlockMatrix. Creates blocks of [[SparseMatrix]] with size 1024 x 1024. */ + @Since("1.3.0") def toBlockMatrix(): BlockMatrix = { toBlockMatrix(1024, 1024) } @@ -110,6 +115,7 @@ class CoordinateMatrix( * a smaller value. Must be an integer value greater than 0. * @return a [[BlockMatrix]] */ + @Since("1.3.0") def toBlockMatrix(rowsPerBlock: Int, colsPerBlock: Int): BlockMatrix = { require(rowsPerBlock > 0, s"rowsPerBlock needs to be greater than 0. rowsPerBlock: $rowsPerBlock") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/DistributedMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/DistributedMatrix.scala index a0e26ce3bc46..db3433a5e245 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/DistributedMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/DistributedMatrix.scala @@ -19,15 +19,20 @@ package org.apache.spark.mllib.linalg.distributed import breeze.linalg.{DenseMatrix => BDM} +import org.apache.spark.annotation.Since + /** * Represents a distributively stored matrix backed by one or more RDDs. */ +@Since("1.0.0") trait DistributedMatrix extends Serializable { /** Gets or computes the number of rows. */ + @Since("1.0.0") def numRows(): Long /** Gets or computes the number of columns. */ + @Since("1.0.0") def numCols(): Long /** Collects data and assembles a local dense breeze matrix (for test only). 
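To connect the `CoordinateMatrix` entry points above with the block form, a short conversion sketch; as before, `sc` is an assumed, pre-existing SparkContext.

import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.distributed.{CoordinateMatrix, MatrixEntry}

def coordinateMatrixSketch(sc: SparkContext): Unit = {
  val entries = sc.parallelize(Seq(
    MatrixEntry(0, 0, 1.0), MatrixEntry(1, 2, 3.5), MatrixEntry(2, 1, -2.0)))
  val coo = new CoordinateMatrix(entries)          // dimensions inferred from the max indices
  println(coo.numRows() + " x " + coo.numCols())   // 3 x 3

  val blocked = coo.toBlockMatrix(2, 2)            // SparseMatrix blocks of at most 2 x 2
  // A^T * A through the updated multiply, which only ships blocks where they are needed.
  val gram = blocked.transpose.multiply(blocked)
  println(gram.toLocalMatrix())
}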
*/ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index 1c33b43ea7a8..976299124ced 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -19,20 +19,18 @@ package org.apache.spark.mllib.linalg.distributed import breeze.linalg.{DenseMatrix => BDM} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.Since import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.SingularValueDecomposition /** - * :: Experimental :: * Represents a row of [[org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix]]. */ -@Experimental +@Since("1.0.0") case class IndexedRow(index: Long, vector: Vector) /** - * :: Experimental :: * Represents a row-oriented [[org.apache.spark.mllib.linalg.distributed.DistributedMatrix]] with * indexed rows. * @@ -42,15 +40,17 @@ case class IndexedRow(index: Long, vector: Vector) * @param nCols number of columns. A non-positive value means unknown, and then the number of * columns will be determined by the size of the first row. */ -@Experimental -class IndexedRowMatrix( - val rows: RDD[IndexedRow], +@Since("1.0.0") +class IndexedRowMatrix @Since("1.0.0") ( + @Since("1.0.0") val rows: RDD[IndexedRow], private var nRows: Long, private var nCols: Int) extends DistributedMatrix { /** Alternative constructor leaving matrix dimensions to be determined automatically. */ + @Since("1.0.0") def this(rows: RDD[IndexedRow]) = this(rows, 0L, 0) + @Since("1.0.0") override def numCols(): Long = { if (nCols <= 0) { // Calling `first` will throw an exception if `rows` is empty. @@ -59,6 +59,7 @@ class IndexedRowMatrix( nCols } + @Since("1.0.0") override def numRows(): Long = { if (nRows <= 0L) { // Reduce will throw an exception if `rows` is empty. @@ -67,15 +68,30 @@ class IndexedRowMatrix( nRows } + + /** + * Compute all cosine similarities between columns of this matrix using the brute-force + * approach of computing normalized dot products. + * + * @return An n x n sparse upper-triangular matrix of cosine similarities between + * columns of this matrix. + */ + @Since("1.6.0") + def columnSimilarities(): CoordinateMatrix = { + toRowMatrix().columnSimilarities() + } + /** * Drops row indices and converts this matrix to a * [[org.apache.spark.mllib.linalg.distributed.RowMatrix]]. */ + @Since("1.0.0") def toRowMatrix(): RowMatrix = { new RowMatrix(rows.map(_.vector), 0L, nCols) } /** Converts to BlockMatrix. Creates blocks of [[SparseMatrix]] with size 1024 x 1024. */ + @Since("1.3.0") def toBlockMatrix(): BlockMatrix = { toBlockMatrix(1024, 1024) } @@ -88,6 +104,7 @@ class IndexedRowMatrix( * a smaller value. Must be an integer value greater than 0. * @return a [[BlockMatrix]] */ + @Since("1.3.0") def toBlockMatrix(rowsPerBlock: Int, colsPerBlock: Int): BlockMatrix = { // TODO: This implementation may be optimized toCoordinateMatrix().toBlockMatrix(rowsPerBlock, colsPerBlock) @@ -97,6 +114,7 @@ class IndexedRowMatrix( * Converts this matrix to a * [[org.apache.spark.mllib.linalg.distributed.CoordinateMatrix]]. 
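The new `columnSimilarities` delegation and the existing SVD entry point can be exercised together; a sketch assuming an existing SparkContext `sc`:

import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.distributed.{IndexedRow, IndexedRowMatrix}

def indexedRowMatrixSketch(sc: SparkContext): Unit = {
  val rows = sc.parallelize(Seq(
    IndexedRow(0L, Vectors.dense(1.0, 0.0, 2.0)),
    IndexedRow(2L, Vectors.dense(0.0, 3.0, 4.0)),
    IndexedRow(5L, Vectors.dense(5.0, 6.0, 0.0))))
  val mat = new IndexedRowMatrix(rows)

  // Added in this patch: drops indices and reuses RowMatrix's brute-force similarities.
  mat.columnSimilarities().entries.collect().foreach(println)

  // Truncated SVD with the left singular vectors materialized.
  val svd = mat.computeSVD(k = 2, computeU = true)
  println(svd.s)   // singular values in descending order
}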
*/ + @Since("1.3.0") def toCoordinateMatrix(): CoordinateMatrix = { val entries = rows.flatMap { row => val rowIndex = row.index @@ -133,6 +151,7 @@ class IndexedRowMatrix( * are treated as zero, where sigma(0) is the largest singular value. * @return SingularValueDecomposition(U, s, V) */ + @Since("1.0.0") def computeSVD( k: Int, computeU: Boolean = false, @@ -159,6 +178,7 @@ class IndexedRowMatrix( * @param B a local matrix whose number of rows must match the number of columns of this matrix * @return an IndexedRowMatrix representing the product, which preserves partitioning */ + @Since("1.0.0") def multiply(B: Matrix): IndexedRowMatrix = { val mat = toRowMatrix().multiply(B) val indexedRows = rows.map(_.index).zip(mat.rows).map { case (i, v) => @@ -170,6 +190,7 @@ class IndexedRowMatrix( /** * Computes the Gramian matrix `A^T A`. */ + @Since("1.0.0") def computeGramianMatrix(): Matrix = { toRowMatrix().computeGramianMatrix() } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index bfc90c9ef852..2018a678688e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -24,11 +24,9 @@ import scala.collection.mutable.ListBuffer import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV, axpy => brzAxpy, svd => brzSvd, MatrixSingularException, inv} import breeze.numerics.{sqrt => brzSqrt} -import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.Logging -import org.apache.spark.SparkContext._ -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary} import org.apache.spark.rdd.RDD @@ -36,7 +34,6 @@ import org.apache.spark.util.random.XORShiftRandom import org.apache.spark.storage.StorageLevel /** - * :: Experimental :: * Represents a row-oriented distributed Matrix with no meaningful row indices. * * @param rows rows stored as an RDD[Vector] @@ -45,16 +42,18 @@ import org.apache.spark.storage.StorageLevel * @param nCols number of columns. A non-positive value means unknown, and then the number of * columns will be determined by the size of the first row. */ -@Experimental -class RowMatrix( - val rows: RDD[Vector], +@Since("1.0.0") +class RowMatrix @Since("1.0.0") ( + @Since("1.0.0") val rows: RDD[Vector], private var nRows: Long, private var nCols: Int) extends DistributedMatrix with Logging { /** Alternative constructor leaving matrix dimensions to be determined automatically. */ + @Since("1.0.0") def this(rows: RDD[Vector]) = this(rows, 0L, 0) /** Gets or computes the number of columns. */ + @Since("1.0.0") override def numCols(): Long = { if (nCols <= 0) { try { @@ -70,6 +69,7 @@ class RowMatrix( } /** Gets or computes the number of rows. */ + @Since("1.0.0") override def numRows(): Long = { if (nRows <= 0L) { nRows = rows.count() @@ -106,8 +106,10 @@ class RowMatrix( } /** - * Computes the Gramian matrix `A^T A`. + * Computes the Gramian matrix `A^T A`. Note that this cannot be computed on matrices with + * more than 65535 columns. */ + @Since("1.0.0") def computeGramianMatrix(): Matrix = { val n = numCols().toInt checkNumColumns(n) @@ -118,7 +120,7 @@ class RowMatrix( // Compute the upper triangular part of the gram matrix. 
val GU = rows.treeAggregate(new BDV[Double](new Array[Double](nt)))( seqOp = (U, v) => { - RowMatrix.dspr(1.0, v, U.data) + BLAS.spr(1.0, v, U.data) U }, combOp = (U1, U2) => U1 += U2) @@ -146,7 +148,8 @@ class RowMatrix( * - s is a Vector of size k, holding the singular values in descending order, * - V is a Matrix of size n x k that satisfies V' * V = eye(k). * - * We assume n is smaller than m. The singular values and the right singular vectors are derived + * We assume n is smaller than m, though this is not strictly required. + * The singular values and the right singular vectors are derived * from the eigenvalues and the eigenvectors of the Gramian matrix A' * A. U, the matrix * storing the right singular vectors, is computed via matrix multiplication as * U = A * (V * S^-1^), if requested by user. The actual method to use is determined @@ -178,6 +181,7 @@ class RowMatrix( * are treated as zero, where sigma(0) is the largest singular value. * @return SingularValueDecomposition(U, s, V). U = null if computeU = false. */ + @Since("1.0.0") def computeSVD( k: Int, computeU: Boolean = false, @@ -315,9 +319,11 @@ class RowMatrix( } /** - * Computes the covariance matrix, treating each row as an observation. + * Computes the covariance matrix, treating each row as an observation. Note that this cannot + * be computed on matrices with more than 65535 columns. * @return a local dense matrix of size n x n */ + @Since("1.0.0") def computeCovariance(): Matrix = { val n = numCols().toInt checkNumColumns(n) @@ -348,9 +354,11 @@ class RowMatrix( var alpha = 0.0 while (i < n) { alpha = m / m1 * mean(i) - j = 0 + j = i while (j < n) { - G(i, j) = G(i, j) / m1 - alpha * mean(j) + val Gij = G(i, j) / m1 - alpha * mean(j) + G(i, j) = Gij + G(j, i) = Gij j += 1 } i += 1 @@ -360,7 +368,8 @@ class RowMatrix( } /** - * Computes the top k principal components. + * Computes the top k principal components and a vector of proportions of + * variance explained by each principal component. * Rows correspond to observations and columns correspond to variables. * The principal components are stored a local matrix of size n-by-k. * Each column corresponds for one principal component, @@ -368,27 +377,49 @@ class RowMatrix( * The row data do not need to be "centered" first; it is not necessary for * the mean of each column to be 0. * + * Note that this cannot be computed on matrices with more than 65535 columns. + * * @param k number of top principal components. 
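For the RowMatrix statistics touched here (the symmetrized covariance fill and the new explained-variance companion to PCA), a driver-side sketch; `sc` is again an assumed SparkContext and the tiny data set is only illustrative.

import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.distributed.RowMatrix

def rowMatrixStatsSketch(sc: SparkContext): Unit = {
  val rows = sc.parallelize(Seq(
    Vectors.dense(1.0, 2.0, 0.0),
    Vectors.dense(2.0, 1.0, 1.0),
    Vectors.dense(4.0, 3.0, 0.5)))
  val mat = new RowMatrix(rows)

  // Both results are local n x n matrices, hence the 65535-column limit noted above.
  val cov = mat.computeCovariance()
  val gram = mat.computeGramianMatrix()

  // Top-2 principal components plus the fraction of total variance each one explains.
  val (pc, explained) = mat.computePrincipalComponentsAndExplainedVariance(2)
  val projected = mat.multiply(pc)      // project the rows onto the 2 components
  println(cov)
  println(explained)                    // k proportions, summing to at most 1
  println(projected.numCols())          // 2
}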
- * @return a matrix of size n-by-k, whose columns are principal components + * @return a matrix of size n-by-k, whose columns are principal components, and + * a vector of values which indicate how much variance each principal component + * explains */ - def computePrincipalComponents(k: Int): Matrix = { + @Since("1.6.0") + def computePrincipalComponentsAndExplainedVariance(k: Int): (Matrix, Vector) = { val n = numCols().toInt require(k > 0 && k <= n, s"k = $k out of range (0, n = $n]") val Cov = computeCovariance().toBreeze.asInstanceOf[BDM[Double]] - val brzSvd.SVD(u: BDM[Double], _, _) = brzSvd(Cov) + val brzSvd.SVD(u: BDM[Double], s: BDV[Double], _) = brzSvd(Cov) + + val eigenSum = s.data.sum + val explainedVariance = s.data.map(_ / eigenSum) if (k == n) { - Matrices.dense(n, k, u.data) + (Matrices.dense(n, k, u.data), Vectors.dense(explainedVariance)) } else { - Matrices.dense(n, k, Arrays.copyOfRange(u.data, 0, n * k)) + (Matrices.dense(n, k, Arrays.copyOfRange(u.data, 0, n * k)), + Vectors.dense(Arrays.copyOfRange(explainedVariance, 0, k))) } } + /** + * Computes the top k principal components only. + * + * @param k number of top principal components. + * @return a matrix of size n-by-k, whose columns are principal components + * @see computePrincipalComponentsAndExplainedVariance + */ + @Since("1.0.0") + def computePrincipalComponents(k: Int): Matrix = { + computePrincipalComponentsAndExplainedVariance(k)._1 + } + /** * Computes column-wise summary statistics. */ + @Since("1.0.0") def computeColumnSummaryStatistics(): MultivariateStatisticalSummary = { val summary = rows.treeAggregate(new MultivariateOnlineSummarizer)( (aggregator, data) => aggregator.add(data), @@ -404,6 +435,7 @@ class RowMatrix( * @return a [[org.apache.spark.mllib.linalg.distributed.RowMatrix]] representing the product, * which preserves partitioning */ + @Since("1.0.0") def multiply(B: Matrix): RowMatrix = { val n = numCols().toInt val k = B.numCols @@ -436,6 +468,7 @@ class RowMatrix( * @return An n x n sparse upper-triangular matrix of cosine similarities between * columns of this matrix. */ + @Since("1.2.0") def columnSimilarities(): CoordinateMatrix = { columnSimilarities(0.0) } @@ -479,6 +512,7 @@ class RowMatrix( * @return An n x n sparse upper-triangular matrix of cosine similarities * between columns of this matrix. */ + @Since("1.2.0") def columnSimilarities(threshold: Double): CoordinateMatrix = { require(threshold >= 0, s"Threshold cannot be negative: $threshold") @@ -507,6 +541,7 @@ class RowMatrix( * @param computeQ whether to computeQ * @return QRDecomposition(Q, R), Q = null if computeQ = false. */ + @Since("1.5.0") def tallSkinnyQR(computeQ: Boolean = false): QRDecomposition[RowMatrix, Matrix] = { val col = numCols().toInt // split rows horizontally into smaller matrices, and compute QR for each of them @@ -656,45 +691,9 @@ class RowMatrix( } } -@Experimental +@Since("1.0.0") object RowMatrix { - /** - * Adds alpha * x * x.t to a matrix in-place. This is the same as BLAS's DSPR. - * - * @param U the upper triangular part of the matrix packed in an array (column major) - */ - private def dspr(alpha: Double, v: Vector, U: Array[Double]): Unit = { - // TODO: Find a better home (breeze?) for this method. 
- val n = v.size - v match { - case DenseVector(values) => - blas.dspr("U", n, alpha, values, 1, U) - case SparseVector(size, indices, values) => - val nnz = indices.length - var colStartIdx = 0 - var prevCol = 0 - var col = 0 - var j = 0 - var i = 0 - var av = 0.0 - while (j < nnz) { - col = indices(j) - // Skip empty columns. - colStartIdx += (col - prevCol) * (col + prevCol + 1) / 2 - col = indices(j) - av = alpha * values(j) - i = 0 - while (i <= j) { - U(colStartIdx + indices(i)) += av * values(i) - i += 1 - } - j += 1 - prevCol = col - } - } - } - /** * Fills a full square matrix from its upper triangular part. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index 8f0d1e4aa010..37bb6f6097f6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -81,11 +81,13 @@ class GradientDescent private[spark] (private var gradient: Gradient, private va * Set the convergence tolerance. Default 0.001 * convergenceTol is a condition which decides iteration termination. * The end of iteration is decided based on below logic. - * - If the norm of the new solution vector is >1, the diff of solution vectors - * is compared to relative tolerance which means normalizing by the norm of - * the new solution vector. - * - If the norm of the new solution vector is <=1, the diff of solution vectors - * is compared to absolute tolerance which is not normalizing. + * + * - If the norm of the new solution vector is >1, the diff of solution vectors + * is compared to relative tolerance which means normalizing by the norm of + * the new solution vector. + * - If the norm of the new solution vector is <=1, the diff of solution vectors + * is compared to absolute tolerance which is not normalizing. + * * Must be between 0.0 and 1.0 inclusively. */ def setConvergenceTol(tolerance: Double): this.type = { @@ -235,7 +237,7 @@ object GradientDescent extends Logging { if (miniBatchSize > 0) { /** - * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration + * lossSum is computed using the weights from the previous iteration * and regVal is the regularization value computed in the previous iteration as well. */ stochasticLossHistory.append(lossSum / miniBatchSize + regVal) @@ -264,6 +266,9 @@ object GradientDescent extends Logging { } + /** + * Alias of [[runMiniBatchSGD]] with convergenceTol set to default value of 0.001. + */ def runMiniBatchSGD( data: RDD[(Double, Vector)], gradient: Gradient, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala index 5e882d4ebb10..274ac7c99553 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala @@ -23,7 +23,7 @@ import javax.xml.transform.stream.StreamResult import org.jpmml.model.JAXBUtil import org.apache.spark.SparkContext -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory /** @@ -33,6 +33,7 @@ import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory * developed by the Data Mining Group (www.dmg.org). 
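The convergence rule spelled out in the `setConvergenceTol` scaladoc above reduces to a few lines. This is a paraphrase for illustration, not the private check GradientDescent itself runs, and it only uses the public `Vectors` helpers.

import org.apache.spark.mllib.linalg.{Vector, Vectors}

object ConvergenceSketch {
  // Relative tolerance when the new solution has norm > 1, absolute tolerance otherwise.
  def converged(previous: Vector, current: Vector, convergenceTol: Double): Boolean = {
    val diff = math.sqrt(Vectors.sqdist(previous, current))
    diff < convergenceTol * math.max(Vectors.norm(current, 2.0), 1.0)
  }

  def main(args: Array[String]): Unit = {
    val prev = Vectors.dense(10.0, 10.0)
    val next = Vectors.dense(10.0, 10.005)
    println(converged(prev, next, 0.001))   // true: 0.005 < 0.001 * ||next||
  }
}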
*/ @DeveloperApi +@Since("1.4.0") trait PMMLExportable { /** @@ -48,6 +49,7 @@ trait PMMLExportable { * Export the model to a local file in PMML format */ @Experimental + @Since("1.4.0") def toPMML(localPath: String): Unit = { toPMML(new StreamResult(new File(localPath))) } @@ -57,6 +59,7 @@ trait PMMLExportable { * Export the model to a directory on a distributed file system in PMML format */ @Experimental + @Since("1.4.0") def toPMML(sc: SparkContext, path: String): Unit = { val pmml = toPMML() sc.parallelize(Array(pmml), 1).saveAsTextFile(path) @@ -67,6 +70,7 @@ trait PMMLExportable { * Export the model to the OutputStream in PMML format */ @Experimental + @Since("1.4.0") def toPMML(outputStream: OutputStream): Unit = { toPMML(new StreamResult(outputStream)) } @@ -76,6 +80,7 @@ trait PMMLExportable { * Export the model to a String in PMML format */ @Experimental + @Since("1.4.0") def toPMML(): String = { val writer = new StringWriter toPMML(new StreamResult(writer)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala index 622b53a252ac..7abb1bf7ce96 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala @@ -45,7 +45,7 @@ private[mllib] class BinaryClassificationPMMLModelExport( val fields = new SArray[FieldName](model.weights.size) val dataDictionary = new DataDictionary val miningSchema = new MiningSchema - val regressionTableYES = new RegressionTable(model.intercept).withTargetCategory("1") + val regressionTableYES = new RegressionTable(model.intercept).setTargetCategory("1") var interceptNO = threshold if (RegressionNormalizationMethodType.LOGIT == normalizationMethod) { if (threshold <= 0) { @@ -56,35 +56,35 @@ private[mllib] class BinaryClassificationPMMLModelExport( interceptNO = -math.log(1 / threshold - 1) } } - val regressionTableNO = new RegressionTable(interceptNO).withTargetCategory("0") + val regressionTableNO = new RegressionTable(interceptNO).setTargetCategory("0") val regressionModel = new RegressionModel() - .withFunctionName(MiningFunctionType.CLASSIFICATION) - .withMiningSchema(miningSchema) - .withModelName(description) - .withNormalizationMethod(normalizationMethod) - .withRegressionTables(regressionTableYES, regressionTableNO) + .setFunctionName(MiningFunctionType.CLASSIFICATION) + .setMiningSchema(miningSchema) + .setModelName(description) + .setNormalizationMethod(normalizationMethod) + .addRegressionTables(regressionTableYES, regressionTableNO) for (i <- 0 until model.weights.size) { fields(i) = FieldName.create("field_" + i) - dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) + dataDictionary.addDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) miningSchema - .withMiningFields(new MiningField(fields(i)) - .withUsageType(FieldUsageType.ACTIVE)) - regressionTableYES.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) + .addMiningFields(new MiningField(fields(i)) + .setUsageType(FieldUsageType.ACTIVE)) + regressionTableYES.addNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) } // add target field val targetField = FieldName.create("target") dataDictionary - .withDataFields(new DataField(targetField, OpType.CATEGORICAL, DataType.STRING)) + 
.addDataFields(new DataField(targetField, OpType.CATEGORICAL, DataType.STRING)) miningSchema - .withMiningFields(new MiningField(targetField) - .withUsageType(FieldUsageType.TARGET)) + .addMiningFields(new MiningField(targetField) + .setUsageType(FieldUsageType.TARGET)) - dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) + dataDictionary.setNumberOfFields(dataDictionary.getDataFields.size) pmml.setDataDictionary(dataDictionary) - pmml.withModels(regressionModel) + pmml.addModels(regressionModel) } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala index 1874786af000..4d951d2973a6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala @@ -45,31 +45,31 @@ private[mllib] class GeneralizedLinearPMMLModelExport( val miningSchema = new MiningSchema val regressionTable = new RegressionTable(model.intercept) val regressionModel = new RegressionModel() - .withFunctionName(MiningFunctionType.REGRESSION) - .withMiningSchema(miningSchema) - .withModelName(description) - .withRegressionTables(regressionTable) + .setFunctionName(MiningFunctionType.REGRESSION) + .setMiningSchema(miningSchema) + .setModelName(description) + .addRegressionTables(regressionTable) for (i <- 0 until model.weights.size) { fields(i) = FieldName.create("field_" + i) - dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) + dataDictionary.addDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) miningSchema - .withMiningFields(new MiningField(fields(i)) - .withUsageType(FieldUsageType.ACTIVE)) - regressionTable.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) + .addMiningFields(new MiningField(fields(i)) + .setUsageType(FieldUsageType.ACTIVE)) + regressionTable.addNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) } // for completeness add target field val targetField = FieldName.create("target") - dataDictionary.withDataFields(new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE)) + dataDictionary.addDataFields(new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE)) miningSchema - .withMiningFields(new MiningField(targetField) - .withUsageType(FieldUsageType.TARGET)) + .addMiningFields(new MiningField(targetField) + .setUsageType(FieldUsageType.TARGET)) - dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) + dataDictionary.setNumberOfFields(dataDictionary.getDataFields.size) pmml.setDataDictionary(dataDictionary) - pmml.withModels(regressionModel) + pmml.addModels(regressionModel) } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala index 069e7afc9fca..b5b824bb9c9b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala @@ -42,42 +42,42 @@ private[mllib] class KMeansPMMLModelExport(model : KMeansModel) extends PMMLMode val dataDictionary = new DataDictionary val miningSchema = new MiningSchema val comparisonMeasure = new ComparisonMeasure() - .withKind(ComparisonMeasure.Kind.DISTANCE) - .withMeasure(new 
SquaredEuclidean()) + .setKind(ComparisonMeasure.Kind.DISTANCE) + .setMeasure(new SquaredEuclidean()) val clusteringModel = new ClusteringModel() - .withModelName("k-means") - .withMiningSchema(miningSchema) - .withComparisonMeasure(comparisonMeasure) - .withFunctionName(MiningFunctionType.CLUSTERING) - .withModelClass(ClusteringModel.ModelClass.CENTER_BASED) - .withNumberOfClusters(model.clusterCenters.length) + .setModelName("k-means") + .setMiningSchema(miningSchema) + .setComparisonMeasure(comparisonMeasure) + .setFunctionName(MiningFunctionType.CLUSTERING) + .setModelClass(ClusteringModel.ModelClass.CENTER_BASED) + .setNumberOfClusters(model.clusterCenters.length) for (i <- 0 until clusterCenter.size) { fields(i) = FieldName.create("field_" + i) - dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) + dataDictionary.addDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) miningSchema - .withMiningFields(new MiningField(fields(i)) - .withUsageType(FieldUsageType.ACTIVE)) - clusteringModel.withClusteringFields( - new ClusteringField(fields(i)).withCompareFunction(CompareFunctionType.ABS_DIFF)) + .addMiningFields(new MiningField(fields(i)) + .setUsageType(FieldUsageType.ACTIVE)) + clusteringModel.addClusteringFields( + new ClusteringField(fields(i)).setCompareFunction(CompareFunctionType.ABS_DIFF)) } - dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) + dataDictionary.setNumberOfFields(dataDictionary.getDataFields.size) - for (i <- 0 until model.clusterCenters.length) { + for (i <- model.clusterCenters.indices) { val cluster = new Cluster() - .withName("cluster_" + i) - .withArray(new org.dmg.pmml.Array() - .withType(Array.Type.REAL) - .withN(clusterCenter.size) - .withValue(model.clusterCenters(i).toArray.mkString(" "))) + .setName("cluster_" + i) + .setArray(new org.dmg.pmml.Array() + .setType(Array.Type.REAL) + .setN(clusterCenter.size) + .setValue(model.clusterCenters(i).toArray.mkString(" "))) // we don't have the size of the single cluster but only the centroids (withValue) // .withSize(value) - clusteringModel.withClusters(cluster) + clusteringModel.addClusters(cluster) } pmml.setDataDictionary(dataDictionary) - pmml.withModels(clusteringModel) + pmml.addModels(clusteringModel) } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala index c5fdecd3ca17..426bb818c926 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala @@ -30,18 +30,14 @@ private[mllib] trait PMMLModelExport { * Holder of the exported model in PMML format */ @BeanProperty - val pmml: PMML = new PMML - - setHeader(pmml) - - private def setHeader(pmml: PMML): Unit = { + val pmml: PMML = { val version = getClass.getPackage.getImplementationVersion - val app = new Application().withName("Apache Spark MLlib").withVersion(version) + val app = new Application("Apache Spark MLlib").setVersion(version) val timestamp = new Timestamp() - .withContent(new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss").format(new Date())) + .addContent(new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss").format(new Date())) val header = new Header() - .withApplication(app) - .withTimestamp(timestamp) - pmml.setHeader(header) + .setApplication(app) + .setTimestamp(timestamp) + new PMML("4.2", header, null) } } diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala index 9349ecaa13f5..9eab7efc160d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala @@ -17,10 +17,9 @@ package org.apache.spark.mllib.random -import org.apache.commons.math3.distribution.{ExponentialDistribution, - GammaDistribution, LogNormalDistribution, PoissonDistribution} +import org.apache.commons.math3.distribution._ -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{Since, DeveloperApi} import org.apache.spark.util.random.{XORShiftRandom, Pseudorandom} /** @@ -28,17 +27,20 @@ import org.apache.spark.util.random.{XORShiftRandom, Pseudorandom} * Trait for random data generators that generate i.i.d. data. */ @DeveloperApi +@Since("1.1.0") trait RandomDataGenerator[T] extends Pseudorandom with Serializable { /** * Returns an i.i.d. sample as a generic type from an underlying distribution. */ + @Since("1.1.0") def nextValue(): T /** * Returns a copy of the RandomDataGenerator with a new instance of the rng object used in the * class when applicable for non-locking concurrent usage. */ + @Since("1.1.0") def copy(): RandomDataGenerator[T] } @@ -47,17 +49,21 @@ trait RandomDataGenerator[T] extends Pseudorandom with Serializable { * Generates i.i.d. samples from U[0.0, 1.0] */ @DeveloperApi +@Since("1.1.0") class UniformGenerator extends RandomDataGenerator[Double] { // XORShiftRandom for better performance. Thread safety isn't necessary here. private val random = new XORShiftRandom() + @Since("1.1.0") override def nextValue(): Double = { random.nextDouble() } + @Since("1.1.0") override def setSeed(seed: Long): Unit = random.setSeed(seed) + @Since("1.1.0") override def copy(): UniformGenerator = new UniformGenerator() } @@ -66,17 +72,21 @@ class UniformGenerator extends RandomDataGenerator[Double] { * Generates i.i.d. samples from the standard normal distribution. */ @DeveloperApi +@Since("1.1.0") class StandardNormalGenerator extends RandomDataGenerator[Double] { // XORShiftRandom for better performance. Thread safety isn't necessary here. private val random = new XORShiftRandom() + @Since("1.1.0") override def nextValue(): Double = { random.nextGaussian() } + @Since("1.1.0") override def setSeed(seed: Long): Unit = random.setSeed(seed) + @Since("1.1.0") override def copy(): StandardNormalGenerator = new StandardNormalGenerator() } @@ -87,16 +97,21 @@ class StandardNormalGenerator extends RandomDataGenerator[Double] { * @param mean mean for the Poisson distribution. */ @DeveloperApi -class PoissonGenerator(val mean: Double) extends RandomDataGenerator[Double] { +@Since("1.1.0") +class PoissonGenerator @Since("1.1.0") ( + @Since("1.1.0") val mean: Double) extends RandomDataGenerator[Double] { private val rng = new PoissonDistribution(mean) + @Since("1.1.0") override def nextValue(): Double = rng.sample() + @Since("1.1.0") override def setSeed(seed: Long) { rng.reseedRandomGenerator(seed) } + @Since("1.1.0") override def copy(): PoissonGenerator = new PoissonGenerator(mean) } @@ -107,16 +122,21 @@ class PoissonGenerator(val mean: Double) extends RandomDataGenerator[Double] { * @param mean mean for the exponential distribution. 
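These generators are plain local objects, so the trait contract being annotated here (`nextValue`, `setSeed`, `copy`) can be exercised without a cluster:

import org.apache.spark.mllib.random.{PoissonGenerator, UniformGenerator}

object GeneratorSketch {
  def main(args: Array[String]): Unit = {
    val poisson = new PoissonGenerator(4.0)
    poisson.setSeed(42L)
    println(Array.fill(5)(poisson.nextValue()).mkString(", "))  // five i.i.d. Poisson(4) draws

    // copy() returns a generator with its own RNG instance, for concurrent use.
    val uniform = new UniformGenerator()
    val clone = uniform.copy()
    println(uniform.nextValue() + " " + clone.nextValue())
  }
}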
*/ @DeveloperApi -class ExponentialGenerator(val mean: Double) extends RandomDataGenerator[Double] { +@Since("1.3.0") +class ExponentialGenerator @Since("1.3.0") ( + @Since("1.3.0") val mean: Double) extends RandomDataGenerator[Double] { private val rng = new ExponentialDistribution(mean) + @Since("1.3.0") override def nextValue(): Double = rng.sample() + @Since("1.3.0") override def setSeed(seed: Long) { rng.reseedRandomGenerator(seed) } + @Since("1.3.0") override def copy(): ExponentialGenerator = new ExponentialGenerator(mean) } @@ -128,16 +148,22 @@ class ExponentialGenerator(val mean: Double) extends RandomDataGenerator[Double] * @param scale scale for the gamma distribution */ @DeveloperApi -class GammaGenerator(val shape: Double, val scale: Double) extends RandomDataGenerator[Double] { +@Since("1.3.0") +class GammaGenerator @Since("1.3.0") ( + @Since("1.3.0") val shape: Double, + @Since("1.3.0") val scale: Double) extends RandomDataGenerator[Double] { private val rng = new GammaDistribution(shape, scale) + @Since("1.3.0") override def nextValue(): Double = rng.sample() + @Since("1.3.0") override def setSeed(seed: Long) { rng.reseedRandomGenerator(seed) } + @Since("1.3.0") override def copy(): GammaGenerator = new GammaGenerator(shape, scale) } @@ -150,15 +176,45 @@ class GammaGenerator(val shape: Double, val scale: Double) extends RandomDataGen * @param std standard deviation for the log normal distribution */ @DeveloperApi -class LogNormalGenerator(val mean: Double, val std: Double) extends RandomDataGenerator[Double] { +@Since("1.3.0") +class LogNormalGenerator @Since("1.3.0") ( + @Since("1.3.0") val mean: Double, + @Since("1.3.0") val std: Double) extends RandomDataGenerator[Double] { private val rng = new LogNormalDistribution(mean, std) + @Since("1.3.0") override def nextValue(): Double = rng.sample() + @Since("1.3.0") override def setSeed(seed: Long) { rng.reseedRandomGenerator(seed) } + @Since("1.3.0") override def copy(): LogNormalGenerator = new LogNormalGenerator(mean, std) } + +/** + * :: DeveloperApi :: + * Generates i.i.d. samples from the Weibull distribution with the + * given shape and scale parameter. + * + * @param alpha shape parameter for the Weibull distribution. + * @param beta scale parameter for the Weibull distribution. 
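Following the shape of the `WeibullGenerator` added just below, a user-defined distribution only needs the same three members. This Bernoulli generator is hypothetical, not part of MLlib; it simply shows the contract.

import java.util.Random

import org.apache.spark.mllib.random.RandomDataGenerator

// Hypothetical Bernoulli(p) generator written against the RandomDataGenerator trait above.
class BernoulliGenerator(val p: Double) extends RandomDataGenerator[Double] {
  require(p >= 0.0 && p <= 1.0, s"p must be in [0, 1] but got $p")

  private val rng = new Random()

  override def nextValue(): Double = if (rng.nextDouble() < p) 1.0 else 0.0

  override def setSeed(seed: Long): Unit = rng.setSeed(seed)

  override def copy(): BernoulliGenerator = new BernoulliGenerator(p)
}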
+ */ +@DeveloperApi +class WeibullGenerator( + val alpha: Double, + val beta: Double) extends RandomDataGenerator[Double] { + + private val rng = new WeibullDistribution(alpha, beta) + + override def nextValue(): Double = rng.sample() + + override def setSeed(seed: Long): Unit = { + rng.reseedRandomGenerator(seed) + } + + override def copy(): WeibullGenerator = new WeibullGenerator(alpha, beta) +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala index 174d5e0f6c9f..b0a716936ae6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala @@ -20,18 +20,18 @@ package org.apache.spark.mllib.random import scala.reflect.ClassTag import org.apache.spark.SparkContext -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD, JavaSparkContext} +import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.rdd.{RandomRDD, RandomVectorRDD} import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils /** - * :: Experimental :: * Generator methods for creating RDDs comprised of `i.i.d.` samples from some distribution. */ -@Experimental +@Since("1.1.0") object RandomRDDs { /** @@ -46,6 +46,7 @@ object RandomRDDs { * @param seed Random seed (default: a random long integer). * @return RDD[Double] comprised of `i.i.d.` samples ~ `U(0.0, 1.0)`. */ + @Since("1.1.0") def uniformRDD( sc: SparkContext, size: Long, @@ -58,6 +59,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#uniformRDD]]. */ + @Since("1.1.0") def uniformJavaRDD( jsc: JavaSparkContext, size: Long, @@ -69,6 +71,7 @@ object RandomRDDs { /** * [[RandomRDDs#uniformJavaRDD]] with the default seed. */ + @Since("1.1.0") def uniformJavaRDD(jsc: JavaSparkContext, size: Long, numPartitions: Int): JavaDoubleRDD = { JavaDoubleRDD.fromRDD(uniformRDD(jsc.sc, size, numPartitions)) } @@ -76,6 +79,7 @@ object RandomRDDs { /** * [[RandomRDDs#uniformJavaRDD]] with the default number of partitions and the default seed. */ + @Since("1.1.0") def uniformJavaRDD(jsc: JavaSparkContext, size: Long): JavaDoubleRDD = { JavaDoubleRDD.fromRDD(uniformRDD(jsc.sc, size)) } @@ -92,6 +96,7 @@ object RandomRDDs { * @param seed Random seed (default: a random long integer). * @return RDD[Double] comprised of `i.i.d.` samples ~ N(0.0, 1.0). */ + @Since("1.1.0") def normalRDD( sc: SparkContext, size: Long, @@ -104,6 +109,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#normalRDD]]. */ + @Since("1.1.0") def normalJavaRDD( jsc: JavaSparkContext, size: Long, @@ -115,6 +121,7 @@ object RandomRDDs { /** * [[RandomRDDs#normalJavaRDD]] with the default seed. */ + @Since("1.1.0") def normalJavaRDD(jsc: JavaSparkContext, size: Long, numPartitions: Int): JavaDoubleRDD = { JavaDoubleRDD.fromRDD(normalRDD(jsc.sc, size, numPartitions)) } @@ -122,6 +129,7 @@ object RandomRDDs { /** * [[RandomRDDs#normalJavaRDD]] with the default number of partitions and the default seed. */ + @Since("1.1.0") def normalJavaRDD(jsc: JavaSparkContext, size: Long): JavaDoubleRDD = { JavaDoubleRDD.fromRDD(normalRDD(jsc.sc, size)) } @@ -137,6 +145,7 @@ object RandomRDDs { * @param seed Random seed (default: a random long integer). 
* @return RDD[Double] comprised of `i.i.d.` samples ~ Pois(mean). */ + @Since("1.1.0") def poissonRDD( sc: SparkContext, mean: Double, @@ -150,6 +159,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#poissonRDD]]. */ + @Since("1.1.0") def poissonJavaRDD( jsc: JavaSparkContext, mean: Double, @@ -162,6 +172,7 @@ object RandomRDDs { /** * [[RandomRDDs#poissonJavaRDD]] with the default seed. */ + @Since("1.1.0") def poissonJavaRDD( jsc: JavaSparkContext, mean: Double, @@ -173,6 +184,7 @@ object RandomRDDs { /** * [[RandomRDDs#poissonJavaRDD]] with the default number of partitions and the default seed. */ + @Since("1.1.0") def poissonJavaRDD(jsc: JavaSparkContext, mean: Double, size: Long): JavaDoubleRDD = { JavaDoubleRDD.fromRDD(poissonRDD(jsc.sc, mean, size)) } @@ -188,6 +200,7 @@ object RandomRDDs { * @param seed Random seed (default: a random long integer). * @return RDD[Double] comprised of `i.i.d.` samples ~ Pois(mean). */ + @Since("1.3.0") def exponentialRDD( sc: SparkContext, mean: Double, @@ -201,6 +214,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#exponentialRDD]]. */ + @Since("1.3.0") def exponentialJavaRDD( jsc: JavaSparkContext, mean: Double, @@ -213,6 +227,7 @@ object RandomRDDs { /** * [[RandomRDDs#exponentialJavaRDD]] with the default seed. */ + @Since("1.3.0") def exponentialJavaRDD( jsc: JavaSparkContext, mean: Double, @@ -224,6 +239,7 @@ object RandomRDDs { /** * [[RandomRDDs#exponentialJavaRDD]] with the default number of partitions and the default seed. */ + @Since("1.3.0") def exponentialJavaRDD(jsc: JavaSparkContext, mean: Double, size: Long): JavaDoubleRDD = { JavaDoubleRDD.fromRDD(exponentialRDD(jsc.sc, mean, size)) } @@ -240,6 +256,7 @@ object RandomRDDs { * @param seed Random seed (default: a random long integer). * @return RDD[Double] comprised of `i.i.d.` samples ~ Pois(mean). */ + @Since("1.3.0") def gammaRDD( sc: SparkContext, shape: Double, @@ -254,6 +271,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#gammaRDD]]. */ + @Since("1.3.0") def gammaJavaRDD( jsc: JavaSparkContext, shape: Double, @@ -267,6 +285,7 @@ object RandomRDDs { /** * [[RandomRDDs#gammaJavaRDD]] with the default seed. */ + @Since("1.3.0") def gammaJavaRDD( jsc: JavaSparkContext, shape: Double, @@ -279,11 +298,12 @@ object RandomRDDs { /** * [[RandomRDDs#gammaJavaRDD]] with the default number of partitions and the default seed. */ + @Since("1.3.0") def gammaJavaRDD( - jsc: JavaSparkContext, - shape: Double, - scale: Double, - size: Long): JavaDoubleRDD = { + jsc: JavaSparkContext, + shape: Double, + scale: Double, + size: Long): JavaDoubleRDD = { JavaDoubleRDD.fromRDD(gammaRDD(jsc.sc, shape, scale, size)) } @@ -299,6 +319,7 @@ object RandomRDDs { * @param seed Random seed (default: a random long integer). * @return RDD[Double] comprised of `i.i.d.` samples ~ Pois(mean). */ + @Since("1.3.0") def logNormalRDD( sc: SparkContext, mean: Double, @@ -313,6 +334,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#logNormalRDD]]. */ + @Since("1.3.0") def logNormalJavaRDD( jsc: JavaSparkContext, mean: Double, @@ -326,6 +348,7 @@ object RandomRDDs { /** * [[RandomRDDs#logNormalJavaRDD]] with the default seed. */ + @Since("1.3.0") def logNormalJavaRDD( jsc: JavaSparkContext, mean: Double, @@ -338,11 +361,12 @@ object RandomRDDs { /** * [[RandomRDDs#logNormalJavaRDD]] with the default number of partitions and the default seed. 
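A short usage sketch of the scalar factory methods annotated above, assuming a live `SparkContext` named `sc`; sizes, distribution parameters, and seeds are illustrative.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.random.RandomRDDs._

def scalarSamples(sc: SparkContext): Unit = {
  // numPartitions and seed default to sc.defaultParallelism and a random seed.
  val u = uniformRDD(sc, 1000000L)                         // ~ U(0.0, 1.0)
  val n = normalRDD(sc, 1000000L, 4, seed = 11L)           // ~ N(0.0, 1.0)
  val e = exponentialRDD(sc, mean = 2.0, size = 1000000L)  // ~ Exp(mean)
  val g = gammaRDD(sc, shape = 9.0, scale = 0.5, size = 1000000L)
  println(s"sample mean of the exponential RDD: ${e.mean()}")  // close to 2.0
}
```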
*/ + @Since("1.3.0") def logNormalJavaRDD( - jsc: JavaSparkContext, - mean: Double, - std: Double, - size: Long): JavaDoubleRDD = { + jsc: JavaSparkContext, + mean: Double, + std: Double, + size: Long): JavaDoubleRDD = { JavaDoubleRDD.fromRDD(logNormalRDD(jsc.sc, mean, std, size)) } @@ -356,9 +380,10 @@ object RandomRDDs { * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Double] comprised of `i.i.d.` samples produced by generator. + * @return RDD[T] comprised of `i.i.d.` samples produced by generator. */ @DeveloperApi + @Since("1.1.0") def randomRDD[T: ClassTag]( sc: SparkContext, generator: RandomDataGenerator[T], @@ -368,6 +393,55 @@ object RandomRDDs { new RandomRDD[T](sc, size, numPartitionsOrDefault(sc, numPartitions), generator, seed) } + /** + * :: DeveloperApi :: + * Generates an RDD comprised of `i.i.d.` samples produced by the input RandomDataGenerator. + * + * @param jsc JavaSparkContext used to create the RDD. + * @param generator RandomDataGenerator used to populate the RDD. + * @param size Size of the RDD. + * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). + * @param seed Random seed (default: a random long integer). + * @return RDD[T] comprised of `i.i.d.` samples produced by generator. + */ + @DeveloperApi + @Since("1.6.0") + def randomJavaRDD[T]( + jsc: JavaSparkContext, + generator: RandomDataGenerator[T], + size: Long, + numPartitions: Int, + seed: Long): JavaRDD[T] = { + implicit val ctag: ClassTag[T] = fakeClassTag + val rdd = randomRDD(jsc.sc, generator, size, numPartitions, seed) + JavaRDD.fromRDD(rdd) + } + + /** + * [[RandomRDDs#randomJavaRDD]] with the default seed. + */ + @DeveloperApi + @Since("1.6.0") + def randomJavaRDD[T]( + jsc: JavaSparkContext, + generator: RandomDataGenerator[T], + size: Long, + numPartitions: Int): JavaRDD[T] = { + randomJavaRDD(jsc, generator, size, numPartitions, Utils.random.nextLong()) + } + + /** + * [[RandomRDDs#randomJavaRDD]] with the default seed & numPartitions + */ + @DeveloperApi + @Since("1.6.0") + def randomJavaRDD[T]( + jsc: JavaSparkContext, + generator: RandomDataGenerator[T], + size: Long): JavaRDD[T] = { + randomJavaRDD(jsc, generator, size, 0); + } + // TODO Generate RDD[Vector] from multivariate distributions. /** @@ -381,6 +455,7 @@ object RandomRDDs { * @param seed Seed for the RNG that generates the seed for the generator in each partition. * @return RDD[Vector] with vectors containing i.i.d samples ~ `U(0.0, 1.0)`. */ + @Since("1.1.0") def uniformVectorRDD( sc: SparkContext, numRows: Long, @@ -394,6 +469,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#uniformVectorRDD]]. */ + @Since("1.1.0") def uniformJavaVectorRDD( jsc: JavaSparkContext, numRows: Long, @@ -406,6 +482,7 @@ object RandomRDDs { /** * [[RandomRDDs#uniformJavaVectorRDD]] with the default seed. */ + @Since("1.1.0") def uniformJavaVectorRDD( jsc: JavaSparkContext, numRows: Long, @@ -417,6 +494,7 @@ object RandomRDDs { /** * [[RandomRDDs#uniformJavaVectorRDD]] with the default number of partitions and the default seed. */ + @Since("1.1.0") def uniformJavaVectorRDD( jsc: JavaSparkContext, numRows: Long, @@ -435,6 +513,7 @@ object RandomRDDs { * @param seed Random seed (default: a random long integer). * @return RDD[Vector] with vectors containing `i.i.d.` samples ~ `N(0.0, 1.0)`. 
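The new `randomJavaRDD` overloads above wrap the existing Scala `randomRDD` entry point; a sketch of that call with the `WeibullGenerator` added earlier in this patch, assuming a `SparkContext` `sc` and illustrative parameters.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.random.{RandomRDDs, WeibullGenerator}
import org.apache.spark.rdd.RDD

def customSamples(sc: SparkContext): RDD[Double] = {
  // Any RandomDataGenerator works here; it is copied once per partition.
  RandomRDDs.randomRDD(sc, new WeibullGenerator(1.0, 2.0), 100000L, 8, seed = 7L)
}
```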
*/ + @Since("1.1.0") def normalVectorRDD( sc: SparkContext, numRows: Long, @@ -448,6 +527,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#normalVectorRDD]]. */ + @Since("1.1.0") def normalJavaVectorRDD( jsc: JavaSparkContext, numRows: Long, @@ -460,6 +540,7 @@ object RandomRDDs { /** * [[RandomRDDs#normalJavaVectorRDD]] with the default seed. */ + @Since("1.1.0") def normalJavaVectorRDD( jsc: JavaSparkContext, numRows: Long, @@ -471,6 +552,7 @@ object RandomRDDs { /** * [[RandomRDDs#normalJavaVectorRDD]] with the default number of partitions and the default seed. */ + @Since("1.1.0") def normalJavaVectorRDD( jsc: JavaSparkContext, numRows: Long, @@ -491,6 +573,7 @@ object RandomRDDs { * @param seed Random seed (default: a random long integer). * @return RDD[Vector] with vectors containing `i.i.d.` samples. */ + @Since("1.3.0") def logNormalVectorRDD( sc: SparkContext, mean: Double, @@ -507,6 +590,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#logNormalVectorRDD]]. */ + @Since("1.3.0") def logNormalJavaVectorRDD( jsc: JavaSparkContext, mean: Double, @@ -521,6 +605,7 @@ object RandomRDDs { /** * [[RandomRDDs#logNormalJavaVectorRDD]] with the default seed. */ + @Since("1.3.0") def logNormalJavaVectorRDD( jsc: JavaSparkContext, mean: Double, @@ -535,6 +620,7 @@ object RandomRDDs { * [[RandomRDDs#logNormalJavaVectorRDD]] with the default number of partitions and * the default seed. */ + @Since("1.3.0") def logNormalJavaVectorRDD( jsc: JavaSparkContext, mean: Double, @@ -556,6 +642,7 @@ object RandomRDDs { * @param seed Random seed (default: a random long integer). * @return RDD[Vector] with vectors containing `i.i.d.` samples ~ Pois(mean). */ + @Since("1.1.0") def poissonVectorRDD( sc: SparkContext, mean: Double, @@ -570,6 +657,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#poissonVectorRDD]]. */ + @Since("1.1.0") def poissonJavaVectorRDD( jsc: JavaSparkContext, mean: Double, @@ -583,6 +671,7 @@ object RandomRDDs { /** * [[RandomRDDs#poissonJavaVectorRDD]] with the default seed. */ + @Since("1.1.0") def poissonJavaVectorRDD( jsc: JavaSparkContext, mean: Double, @@ -595,6 +684,7 @@ object RandomRDDs { /** * [[RandomRDDs#poissonJavaVectorRDD]] with the default number of partitions and the default seed. */ + @Since("1.1.0") def poissonJavaVectorRDD( jsc: JavaSparkContext, mean: Double, @@ -615,6 +705,7 @@ object RandomRDDs { * @param seed Random seed (default: a random long integer). * @return RDD[Vector] with vectors containing `i.i.d.` samples ~ Exp(mean). */ + @Since("1.3.0") def exponentialVectorRDD( sc: SparkContext, mean: Double, @@ -630,6 +721,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#exponentialVectorRDD]]. */ + @Since("1.3.0") def exponentialJavaVectorRDD( jsc: JavaSparkContext, mean: Double, @@ -643,6 +735,7 @@ object RandomRDDs { /** * [[RandomRDDs#exponentialJavaVectorRDD]] with the default seed. */ + @Since("1.3.0") def exponentialJavaVectorRDD( jsc: JavaSparkContext, mean: Double, @@ -656,6 +749,7 @@ object RandomRDDs { * [[RandomRDDs#exponentialJavaVectorRDD]] with the default number of partitions * and the default seed. */ + @Since("1.3.0") def exponentialJavaVectorRDD( jsc: JavaSparkContext, mean: Double, @@ -678,6 +772,7 @@ object RandomRDDs { * @param seed Random seed (default: a random long integer). * @return RDD[Vector] with vectors containing `i.i.d.` samples ~ Exp(mean). 
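Likewise for the vector-valued factories annotated above, assuming a `SparkContext` `sc`; row counts and dimensions are illustrative.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.random.RandomRDDs

def vectorSamples(sc: SparkContext): Unit = {
  // 10000 rows of 5-dimensional vectors with i.i.d. entries.
  val normal    = RandomRDDs.normalVectorRDD(sc, numRows = 10000L, numCols = 5)
  val logNormal = RandomRDDs.logNormalVectorRDD(sc, 0.0, 1.0, 10000L, 5)
  val poisson   = RandomRDDs.poissonVectorRDD(sc, mean = 3.0, numRows = 10000L, numCols = 5)
  println(normal.first())
}
```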
*/ + @Since("1.3.0") def gammaVectorRDD( sc: SparkContext, shape: Double, @@ -693,6 +788,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#gammaVectorRDD]]. */ + @Since("1.3.0") def gammaJavaVectorRDD( jsc: JavaSparkContext, shape: Double, @@ -707,6 +803,7 @@ object RandomRDDs { /** * [[RandomRDDs#gammaJavaVectorRDD]] with the default seed. */ + @Since("1.3.0") def gammaJavaVectorRDD( jsc: JavaSparkContext, shape: Double, @@ -720,6 +817,7 @@ object RandomRDDs { /** * [[RandomRDDs#gammaJavaVectorRDD]] with the default number of partitions and the default seed. */ + @Since("1.3.0") def gammaJavaVectorRDD( jsc: JavaSparkContext, shape: Double, @@ -744,6 +842,7 @@ object RandomRDDs { * @return RDD[Vector] with vectors containing `i.i.d.` samples produced by generator. */ @DeveloperApi + @Since("1.1.0") def randomVectorRDD(sc: SparkContext, generator: RandomDataGenerator[Double], numRows: Long, @@ -754,6 +853,48 @@ object RandomRDDs { sc, numRows, numCols, numPartitionsOrDefault(sc, numPartitions), generator, seed) } + /** + * Java-friendly version of [[RandomRDDs#randomVectorRDD]]. + */ + @DeveloperApi + @Since("1.6.0") + def randomJavaVectorRDD( + jsc: JavaSparkContext, + generator: RandomDataGenerator[Double], + numRows: Long, + numCols: Int, + numPartitions: Int, + seed: Long): JavaRDD[Vector] = { + randomVectorRDD(jsc.sc, generator, numRows, numCols, numPartitions, seed).toJavaRDD() + } + + /** + * [[RandomRDDs#randomJavaVectorRDD]] with the default seed. + */ + @DeveloperApi + @Since("1.6.0") + def randomJavaVectorRDD( + jsc: JavaSparkContext, + generator: RandomDataGenerator[Double], + numRows: Long, + numCols: Int, + numPartitions: Int): JavaRDD[Vector] = { + randomVectorRDD(jsc.sc, generator, numRows, numCols, numPartitions).toJavaRDD() + } + + /** + * [[RandomRDDs#randomJavaVectorRDD]] with the default number of partitions and the default seed. + */ + @DeveloperApi + @Since("1.6.0") + def randomJavaVectorRDD( + jsc: JavaSparkContext, + generator: RandomDataGenerator[Double], + numRows: Long, + numCols: Int): JavaRDD[Vector] = { + randomVectorRDD(jsc.sc, generator, numRows, numCols).toJavaRDD() + } + /** * Returns `numPartitions` if it is positive, or `sc.defaultParallelism` otherwise. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala index 78172843be56..19a047ded257 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala @@ -37,15 +37,20 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) extends Serializable { * trigger a Spark job if the parent RDD has more than one partitions and the window size is * greater than 1. */ - def sliding(windowSize: Int): RDD[Array[T]] = { + def sliding(windowSize: Int, step: Int): RDD[Array[T]] = { require(windowSize > 0, s"Sliding window size must be positive, but got $windowSize.") - if (windowSize == 1) { + if (windowSize == 1 && step == 1) { self.map(Array(_)) } else { - new SlidingRDD[T](self, windowSize) + new SlidingRDD[T](self, windowSize, step) } } + /** + * [[sliding(Int, Int)*]] with step = 1. + */ + def sliding(windowSize: Int): RDD[Array[T]] = sliding(windowSize, 1) + /** * Reduces the elements of this RDD in a multi-level tree pattern. 
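A sketch of the new two-argument `sliding(windowSize, step)` added to `RDDFunctions` above, assuming a `SparkContext` `sc`; the data and partition count are illustrative.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.rdd.RDDFunctions._

def slidingExample(sc: SparkContext): Unit = {
  val data = sc.parallelize(1 to 10, numSlices = 3)
  // Windows of size 3 advancing by 2; partial trailing windows are dropped,
  // so the result is [1,2,3], [3,4,5], [5,6,7], [7,8,9].
  val windows = data.sliding(3, 2).collect()
  windows.foreach(w => println(w.mkString(",")))
  // The single-argument form is now simply sliding(windowSize, 1).
}
```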
* diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala index 910eff9540a4..f8cea7ecea6b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala @@ -35,11 +35,11 @@ private[mllib] class RandomRDDPartition[T](override val index: Int, } // These two classes are necessary since Range objects in Scala cannot have size > Int.MaxValue -private[mllib] class RandomRDD[T: ClassTag](@transient sc: SparkContext, +private[mllib] class RandomRDD[T: ClassTag](sc: SparkContext, size: Long, numPartitions: Int, - @transient rng: RandomDataGenerator[T], - @transient seed: Long = Utils.random.nextLong) extends RDD[T](sc, Nil) { + @transient private val rng: RandomDataGenerator[T], + @transient private val seed: Long = Utils.random.nextLong) extends RDD[T](sc, Nil) { require(size > 0, "Positive RDD size required.") require(numPartitions > 0, "Positive number of partitions required") @@ -56,12 +56,12 @@ private[mllib] class RandomRDD[T: ClassTag](@transient sc: SparkContext, } } -private[mllib] class RandomVectorRDD(@transient sc: SparkContext, +private[mllib] class RandomVectorRDD(sc: SparkContext, size: Long, vectorSize: Int, numPartitions: Int, - @transient rng: RandomDataGenerator[Double], - @transient seed: Long = Utils.random.nextLong) extends RDD[Vector](sc, Nil) { + @transient private val rng: RandomDataGenerator[Double], + @transient private val seed: Long = Utils.random.nextLong) extends RDD[Vector](sc, Nil) { require(size > 0, "Positive RDD size required.") require(numPartitions > 0, "Positive number of partitions required") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala index 1facf83d806d..ead8db634499 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala @@ -24,13 +24,13 @@ import org.apache.spark.{TaskContext, Partition} import org.apache.spark.rdd.RDD private[mllib] -class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T]) +class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T], val offset: Int) extends Partition with Serializable { override val index: Int = idx } /** - * Represents a RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding + * Represents an RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding * window over them. The ordering is first based on the partition index and then the ordering of * items within each partition. This is similar to sliding in Scala collections, except that it * becomes an empty RDD if the window size is greater than the total number of items. 
It needs to @@ -40,19 +40,24 @@ class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T] * * @param parent the parent RDD * @param windowSize the window size, must be greater than 1 + * @param step step size for windows * - * @see [[org.apache.spark.mllib.rdd.RDDFunctions#sliding]] + * @see [[org.apache.spark.mllib.rdd.RDDFunctions.sliding(Int, Int)*]] + * @see [[scala.collection.IterableLike.sliding(Int, Int)*]] */ private[mllib] -class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int) +class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int, val step: Int) extends RDD[Array[T]](parent) { - require(windowSize > 1, s"Window size must be greater than 1, but got $windowSize.") + require(windowSize > 0 && step > 0 && !(windowSize == 1 && step == 1), + "Window size and step must be greater than 0, " + + s"and they cannot be both 1, but got windowSize = $windowSize and step = $step.") override def compute(split: Partition, context: TaskContext): Iterator[Array[T]] = { val part = split.asInstanceOf[SlidingRDDPartition[T]] (firstParent[T].iterator(part.prev, context) ++ part.tail) - .sliding(windowSize) + .drop(part.offset) + .sliding(windowSize, step) .withPartial(false) .map(_.toArray) } @@ -62,40 +67,42 @@ class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int override def getPartitions: Array[Partition] = { val parentPartitions = parent.partitions - val n = parentPartitions.size + val n = parentPartitions.length if (n == 0) { Array.empty } else if (n == 1) { - Array(new SlidingRDDPartition[T](0, parentPartitions(0), Seq.empty)) + Array(new SlidingRDDPartition[T](0, parentPartitions(0), Seq.empty, 0)) } else { - val n1 = n - 1 val w1 = windowSize - 1 - // Get the first w1 items of each partition, starting from the second partition. - val nextHeads = - parent.context.runJob(parent, (iter: Iterator[T]) => iter.take(w1).toArray, 1 until n) - val partitions = mutable.ArrayBuffer[SlidingRDDPartition[T]]() + // Get partition sizes and first w1 elements. + val (sizes, heads) = parent.mapPartitions { iter => + val w1Array = iter.take(w1).toArray + Iterator.single((w1Array.length + iter.length, w1Array)) + }.collect().unzip + val partitions = mutable.ArrayBuffer.empty[SlidingRDDPartition[T]] var i = 0 + var cumSize = 0 var partitionIndex = 0 - while (i < n1) { - var j = i - val tail = mutable.ListBuffer[T]() - // Keep appending to the current tail until appended a head of size w1. - while (j < n1 && nextHeads(j).size < w1) { - tail ++= nextHeads(j) - j += 1 + while (i < n) { + val mod = cumSize % step + val offset = if (mod == 0) 0 else step - mod + val size = sizes(i) + if (offset < size) { + val tail = mutable.ListBuffer.empty[T] + // Keep appending to the current tail until it has w1 elements. + var j = i + 1 + while (j < n && tail.length < w1) { + tail ++= heads(j).take(w1 - tail.length) + j += 1 + } + if (sizes(i) + tail.length >= offset + windowSize) { + partitions += + new SlidingRDDPartition[T](partitionIndex, parentPartitions(i), tail, offset) + partitionIndex += 1 + } } - if (j < n1) { - tail ++= nextHeads(j) - j += 1 - } - partitions += new SlidingRDDPartition[T](partitionIndex, parentPartitions(i), tail) - partitionIndex += 1 - // Skip appended heads. - i = j - } - // If the head of last partition has size w1, we also need to add this partition. 
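To make the new offset bookkeeping in `getPartitions` concrete, a standalone sketch with made-up partition sizes (windowSize = 3, step = 2); it reproduces only the `cumSize`/`offset` arithmetic, not the tail handling or the final window-count check.

```scala
// Hypothetical partition sizes, purely for illustration.
val sizes = Array(4, 1, 5)
val (windowSize, step) = (3, 2)
var cumSize = 0
sizes.zipWithIndex.foreach { case (size, i) =>
  val mod = cumSize % step
  // First local index at which a window may start, keeping window starts
  // aligned to `step` across partition boundaries.
  val offset = if (mod == 0) 0 else step - mod
  println(s"partition $i: size = $size, offset = $offset, " +
    s"window may start here = ${offset < size}")
  cumSize += size
}
// partition 0: offset 0 (global starts 0, 2); partition 1: offset 0 (start 4);
// partition 2: offset 1 (global starts 6, 8).
```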
- if (nextHeads.last.size == w1) { - partitions += new SlidingRDDPartition[T](partitionIndex, parentPartitions(n1), Seq.empty) + cumSize += size + i += 1 } partitions.toArray } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index 56c549ef99cb..33aaf853e599 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.recommendation import org.apache.spark.Logging -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.ml.recommendation.{ALS => NewALS} import org.apache.spark.rdd.RDD @@ -26,9 +26,12 @@ import org.apache.spark.storage.StorageLevel /** * A more compact class to represent a rating than Tuple3[Int, Int, Double]. - * @since 0.8.0 */ -case class Rating(user: Int, product: Int, rating: Double) +@Since("0.8.0") +case class Rating @Since("0.8.0") ( + @Since("0.8.0") user: Int, + @Since("0.8.0") product: Int, + @Since("0.8.0") rating: Double) /** * Alternating Least Squares matrix factorization. @@ -59,6 +62,7 @@ case class Rating(user: Int, product: Int, rating: Double) * indicated user * preferences rather than explicit ratings given to items. */ +@Since("0.8.0") class ALS private ( private var numUserBlocks: Int, private var numProductBlocks: Int, @@ -74,6 +78,7 @@ class ALS private ( * Constructs an ALS instance with default parameters: {numBlocks: -1, rank: 10, iterations: 10, * lambda: 0.01, implicitPrefs: false, alpha: 1.0}. */ + @Since("0.8.0") def this() = this(-1, -1, 10, 10, 0.01, false, 1.0) /** If true, do alternating nonnegative least squares. */ @@ -90,6 +95,7 @@ class ALS private ( * Set the number of blocks for both user blocks and product blocks to parallelize the computation * into; pass -1 for an auto-configured number of blocks. Default: -1. */ + @Since("0.8.0") def setBlocks(numBlocks: Int): this.type = { this.numUserBlocks = numBlocks this.numProductBlocks = numBlocks @@ -99,6 +105,7 @@ class ALS private ( /** * Set the number of user blocks to parallelize the computation. */ + @Since("1.1.0") def setUserBlocks(numUserBlocks: Int): this.type = { this.numUserBlocks = numUserBlocks this @@ -107,30 +114,35 @@ class ALS private ( /** * Set the number of product blocks to parallelize the computation. */ + @Since("1.1.0") def setProductBlocks(numProductBlocks: Int): this.type = { this.numProductBlocks = numProductBlocks this } /** Set the rank of the feature matrices computed (number of features). Default: 10. */ + @Since("0.8.0") def setRank(rank: Int): this.type = { this.rank = rank this } /** Set the number of iterations to run. Default: 10. */ + @Since("0.8.0") def setIterations(iterations: Int): this.type = { this.iterations = iterations this } /** Set the regularization parameter, lambda. Default: 0.01. */ + @Since("0.8.0") def setLambda(lambda: Double): this.type = { this.lambda = lambda this } /** Sets whether to use implicit preference. Default: false. */ + @Since("0.8.1") def setImplicitPrefs(implicitPrefs: Boolean): this.type = { this.implicitPrefs = implicitPrefs this @@ -139,12 +151,14 @@ class ALS private ( /** * Sets the constant used in computing confidence in implicit ALS. Default: 1.0. 
*/ + @Since("0.8.1") def setAlpha(alpha: Double): this.type = { this.alpha = alpha this } /** Sets a random seed to have deterministic results. */ + @Since("1.0.0") def setSeed(seed: Long): this.type = { this.seed = seed this @@ -154,6 +168,7 @@ class ALS private ( * Set whether the least-squares problems solved at each iteration should have * nonnegativity constraints. */ + @Since("1.1.0") def setNonnegative(b: Boolean): this.type = { this.nonnegative = b this @@ -166,6 +181,7 @@ class ALS private ( * set `spark.rdd.compress` to `true` to reduce the space requirement, at the cost of speed. */ @DeveloperApi + @Since("1.1.0") def setIntermediateRDDStorageLevel(storageLevel: StorageLevel): this.type = { require(storageLevel != StorageLevel.NONE, "ALS is not designed to run without persisting intermediate RDDs.") @@ -181,6 +197,7 @@ class ALS private ( * at the cost of speed. */ @DeveloperApi + @Since("1.3.0") def setFinalRDDStorageLevel(storageLevel: StorageLevel): this.type = { this.finalRDDStorageLevel = storageLevel this @@ -194,6 +211,7 @@ class ALS private ( * this setting is ignored. */ @DeveloperApi + @Since("1.4.0") def setCheckpointInterval(checkpointInterval: Int): this.type = { this.checkpointInterval = checkpointInterval this @@ -203,6 +221,7 @@ class ALS private ( * Run ALS with the configured parameters on an input RDD of (user, product, rating) triples. * Returns a MatrixFactorizationModel with feature vectors for each user and product. */ + @Since("0.8.0") def run(ratings: RDD[Rating]): MatrixFactorizationModel = { val sc = ratings.context @@ -250,13 +269,14 @@ class ALS private ( /** * Java-friendly version of [[ALS.run]]. */ + @Since("1.3.0") def run(ratings: JavaRDD[Rating]): MatrixFactorizationModel = run(ratings.rdd) } /** * Top-level methods for calling Alternating Least Squares (ALS) matrix factorization. 
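A sketch of the builder-style ALS API whose setters are annotated above; `ratings` is an assumed `RDD[Rating]` and the hyperparameters are illustrative.

```scala
import org.apache.spark.mllib.recommendation.{ALS, Rating}
import org.apache.spark.rdd.RDD

def trainAls(ratings: RDD[Rating]) = {
  new ALS()
    .setRank(10)
    .setIterations(10)
    .setLambda(0.01)
    .setImplicitPrefs(false)
    .setSeed(42L)
    .run(ratings)  // returns a MatrixFactorizationModel
}
```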
- * @since 0.8.0 */ +@Since("0.8.0") object ALS { /** * Train a matrix factorization model given an RDD of ratings given by users to some products, @@ -271,8 +291,8 @@ object ALS { * @param lambda regularization factor (recommended: 0.01) * @param blocks level of parallelism to split computation into * @param seed random seed - * @since 0.9.1 */ + @Since("0.9.1") def train( ratings: RDD[Rating], rank: Int, @@ -296,8 +316,8 @@ object ALS { * @param iterations number of iterations of ALS (recommended: 10-20) * @param lambda regularization factor (recommended: 0.01) * @param blocks level of parallelism to split computation into - * @since 0.8.0 */ + @Since("0.8.0") def train( ratings: RDD[Rating], rank: Int, @@ -319,8 +339,8 @@ object ALS { * @param rank number of features to use * @param iterations number of iterations of ALS (recommended: 10-20) * @param lambda regularization factor (recommended: 0.01) - * @since 0.8.0 */ + @Since("0.8.0") def train(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double) : MatrixFactorizationModel = { train(ratings, rank, iterations, lambda, -1) @@ -336,8 +356,8 @@ object ALS { * @param ratings RDD of (userID, productID, rating) pairs * @param rank number of features to use * @param iterations number of iterations of ALS (recommended: 10-20) - * @since 0.8.0 */ + @Since("0.8.0") def train(ratings: RDD[Rating], rank: Int, iterations: Int) : MatrixFactorizationModel = { train(ratings, rank, iterations, 0.01, -1) @@ -357,8 +377,8 @@ object ALS { * @param blocks level of parallelism to split computation into * @param alpha confidence parameter * @param seed random seed - * @since 0.8.1 */ + @Since("0.8.1") def trainImplicit( ratings: RDD[Rating], rank: Int, @@ -384,8 +404,8 @@ object ALS { * @param lambda regularization factor (recommended: 0.01) * @param blocks level of parallelism to split computation into * @param alpha confidence parameter - * @since 0.8.1 */ + @Since("0.8.1") def trainImplicit( ratings: RDD[Rating], rank: Int, @@ -409,8 +429,8 @@ object ALS { * @param iterations number of iterations of ALS (recommended: 10-20) * @param lambda regularization factor (recommended: 0.01) * @param alpha confidence parameter - * @since 0.8.1 */ + @Since("0.8.1") def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double, alpha: Double) : MatrixFactorizationModel = { trainImplicit(ratings, rank, iterations, lambda, -1, alpha) @@ -427,8 +447,8 @@ object ALS { * @param ratings RDD of (userID, productID, rating) pairs * @param rank number of features to use * @param iterations number of iterations of ALS (recommended: 10-20) - * @since 0.8.1 */ + @Since("0.8.1") def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int) : MatrixFactorizationModel = { trainImplicit(ratings, rank, iterations, 0.01, -1, 1.0) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index 261ca9cef0c5..0dc40483dd0f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -30,6 +30,7 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.annotation.Since import org.apache.spark.api.java.{JavaPairRDD, JavaRDD} import org.apache.spark.mllib.linalg._ import 
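And the static `train`/`trainImplicit` entry points annotated above, assuming a `SparkContext` `sc`; the input path is hypothetical and the hyperparameters are illustrative.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.recommendation.{ALS, Rating}

def staticTrain(sc: SparkContext) = {
  // "data/ratings.csv" is a hypothetical file of "user,product,rating" lines.
  val ratings = sc.textFile("data/ratings.csv").map { line =>
    val Array(user, product, rating) = line.split(',')
    Rating(user.toInt, product.toInt, rating.toDouble)
  }
  val explicitModel = ALS.train(ratings, rank = 10, iterations = 10, lambda = 0.01)
  val implicitModel = ALS.trainImplicit(ratings, 10, 10, 0.01, alpha = 1.0)
  (explicitModel, implicitModel)
}
```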
org.apache.spark.mllib.rdd.MLPairRDDFunctions._ @@ -49,12 +50,12 @@ import org.apache.spark.storage.StorageLevel * the features computed for this user. * @param productFeatures RDD of tuples where each tuple represents the productId * and the features computed for this product. - * @since 0.8.0 */ -class MatrixFactorizationModel( - val rank: Int, - val userFeatures: RDD[(Int, Array[Double])], - val productFeatures: RDD[(Int, Array[Double])]) +@Since("0.8.0") +class MatrixFactorizationModel @Since("0.8.0") ( + @Since("0.8.0") val rank: Int, + @Since("0.8.0") val userFeatures: RDD[(Int, Array[Double])], + @Since("0.8.0") val productFeatures: RDD[(Int, Array[Double])]) extends Saveable with Serializable with Logging { require(rank > 0) @@ -74,9 +75,8 @@ class MatrixFactorizationModel( } } - /** Predict the rating of one user for one product. - * @since 0.8.0 - */ + /** Predict the rating of one user for one product. */ + @Since("0.8.0") def predict(user: Int, product: Int): Double = { val userVector = userFeatures.lookup(user).head val productVector = productFeatures.lookup(product).head @@ -114,8 +114,8 @@ class MatrixFactorizationModel( * * @param usersProducts RDD of (user, product) pairs. * @return RDD of Ratings. - * @since 0.9.0 */ + @Since("0.9.0") def predict(usersProducts: RDD[(Int, Int)]): RDD[Rating] = { // Previously the partitions of ratings are only based on the given products. // So if the usersProducts given for prediction contains only few products or @@ -146,8 +146,8 @@ class MatrixFactorizationModel( /** * Java-friendly version of [[MatrixFactorizationModel.predict]]. - * @since 1.2.0 */ + @Since("1.2.0") def predict(usersProducts: JavaPairRDD[JavaInteger, JavaInteger]): JavaRDD[Rating] = { predict(usersProducts.rdd.asInstanceOf[RDD[(Int, Int)]]).toJavaRDD() } @@ -162,8 +162,8 @@ class MatrixFactorizationModel( * by score, decreasing. The first returned is the one predicted to be most strongly * recommended to the user. The score is an opaque value that indicates how strongly * recommended the product is. - * @since 1.1.0 */ + @Since("1.1.0") def recommendProducts(user: Int, num: Int): Array[Rating] = MatrixFactorizationModel.recommend(userFeatures.lookup(user).head, productFeatures, num) .map(t => Rating(user, t._1, t._2)) @@ -179,8 +179,8 @@ class MatrixFactorizationModel( * by score, decreasing. The first returned is the one predicted to be most strongly * recommended to the product. The score is an opaque value that indicates how strongly * recommended the user is. - * @since 1.1.0 */ + @Since("1.1.0") def recommendUsers(product: Int, num: Int): Array[Rating] = MatrixFactorizationModel.recommend(productFeatures.lookup(product).head, userFeatures, num) .map(t => Rating(t._1, product, t._2)) @@ -199,8 +199,8 @@ class MatrixFactorizationModel( * @param sc Spark context used to save model data. * @param path Path specifying the directory in which to save this model. * If the directory already exists, this method throws an exception. - * @since 1.3.0 */ + @Since("1.3.0") override def save(sc: SparkContext, path: String): Unit = { MatrixFactorizationModel.SaveLoadV1_0.save(this, path) } @@ -212,8 +212,8 @@ class MatrixFactorizationModel( * @return [(Int, Array[Rating])] objects, where every tuple contains a userID and an array of * rating objects which contains the same userId, recommended productID and a "score" in the * rating field. 
Semantics of score is same as recommendProducts API - * @since 1.4.0 */ + @Since("1.4.0") def recommendProductsForUsers(num: Int): RDD[(Int, Array[Rating])] = { MatrixFactorizationModel.recommendForAll(rank, userFeatures, productFeatures, num).map { case (user, top) => @@ -230,8 +230,8 @@ class MatrixFactorizationModel( * @return [(Int, Array[Rating])] objects, where every tuple contains a productID and an array * of rating objects which contains the recommended userId, same productID and a "score" in the * rating field. Semantics of score is same as recommendUsers API - * @since 1.4.0 */ + @Since("1.4.0") def recommendUsersForProducts(num: Int): RDD[(Int, Array[Rating])] = { MatrixFactorizationModel.recommendForAll(rank, productFeatures, userFeatures, num).map { case (product, top) => @@ -241,9 +241,7 @@ class MatrixFactorizationModel( } } -/** - * @since 1.3.0 - */ +@Since("1.3.0") object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { import org.apache.spark.mllib.util.Loader._ @@ -326,8 +324,8 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { * @param sc Spark context used for loading model files. * @param path Path specifying the directory to which the model was saved. * @return Model instance - * @since 1.3.0 */ + @Since("1.3.0") override def load(sc: SparkContext, path: String): MatrixFactorizationModel = { val (loadedClassName, formatVersion, _) = loadMetadata(sc, path) val classNameV1_0 = SaveLoadV1_0.thisClassName @@ -355,7 +353,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { */ def save(model: MatrixFactorizationModel, path: String): Unit = { val sc = model.userFeatures.sparkContext - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("rank" -> model.rank))) @@ -366,7 +364,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { def load(sc: SparkContext, path: String): MatrixFactorizationModel = { implicit val formats = DefaultFormats - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val (className, formatVersion, metadata) = loadMetadata(sc, path) assert(className == thisClassName) assert(formatVersion == thisFormatVersion) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 6709bd79bc82..8f657bfb9c73 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.regression -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.feature.StandardScaler import org.apache.spark.{Logging, SparkException} import org.apache.spark.rdd.RDD @@ -34,9 +34,13 @@ import org.apache.spark.storage.StorageLevel * * @param weights Weights computed for every feature. * @param intercept Intercept computed for this model. 
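A sketch of the `MatrixFactorizationModel` methods annotated above, assuming an already trained `model` and a `SparkContext` `sc`; the user/product ids and the save path are hypothetical.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel

def useModel(sc: SparkContext, model: MatrixFactorizationModel): Unit = {
  val score = model.predict(user = 1, product = 42)
  val topForUser = model.recommendProducts(1, 5)        // Array[Rating], best first
  val topUsersPerProduct = model.recommendUsersForProducts(3)
  // save/load now go through SQLContext.getOrCreate(sc) rather than new SQLContext(sc).
  model.save(sc, "target/tmp/alsModel")                 // hypothetical path
  val reloaded = MatrixFactorizationModel.load(sc, "target/tmp/alsModel")
  println(s"predicted $score, reloaded rank ${reloaded.rank}")
}
```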
+ * */ +@Since("0.8.0") @DeveloperApi -abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double) +abstract class GeneralizedLinearModel @Since("1.0.0") ( + @Since("1.0.0") val weights: Vector, + @Since("0.8.0") val intercept: Double) extends Serializable { /** @@ -53,7 +57,9 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double * * @param testData RDD representing data points to be predicted * @return RDD[Double] where each entry contains the corresponding prediction + * */ + @Since("1.0.0") def predict(testData: RDD[Vector]): RDD[Double] = { // A small optimization to avoid serializing the entire model. Only the weightsMatrix // and intercept is needed. @@ -71,7 +77,9 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double * * @param testData array representing a single data point * @return Double prediction from the trained model + * */ + @Since("1.0.0") def predict(testData: Vector): Double = { predictPoint(testData, weights, intercept) } @@ -88,14 +96,20 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double * :: DeveloperApi :: * GeneralizedLinearAlgorithm implements methods to train a Generalized Linear Model (GLM). * This class should be extended with an Optimizer to create a new GLM. + * */ +@Since("0.8.0") @DeveloperApi abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] extends Logging with Serializable { protected val validators: Seq[RDD[LabeledPoint] => Boolean] = List() - /** The optimizer to solve the problem. */ + /** + * The optimizer to solve the problem. + * + */ + @Since("0.8.0") def optimizer: Optimizer /** Whether to add intercept (default: false). */ @@ -130,7 +144,9 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] /** * The dimension of training features. + * */ + @Since("1.4.0") def getNumFeatures: Int = this.numFeatures /** @@ -153,13 +169,17 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] /** * Get if the algorithm uses addIntercept + * */ + @Since("1.4.0") def isAddIntercept: Boolean = this.addIntercept /** * Set if the algorithm should add an intercept. Default false. * We set the default to false because adding the intercept will cause memory allocation. + * */ + @Since("0.8.0") def setIntercept(addIntercept: Boolean): this.type = { this.addIntercept = addIntercept this @@ -167,7 +187,9 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] /** * Set if the algorithm should validate data before training. Default true. + * */ + @Since("0.8.0") def setValidateData(validateData: Boolean): this.type = { this.validateData = validateData this @@ -176,7 +198,9 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] /** * Run the algorithm with the configured parameters on an input * RDD of LabeledPoint entries. + * */ + @Since("0.8.0") def run(input: RDD[LabeledPoint]): M = { if (numFeatures < 0) { numFeatures = input.map(_.features.size).first() @@ -208,7 +232,9 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] /** * Run the algorithm with the configured parameters on an input RDD * of LabeledPoint entries starting from the initial weights provided. 
+ * */ + @Since("1.0.0") def run(input: RDD[LabeledPoint], initialWeights: Vector): M = { if (numFeatures < 0) { @@ -333,6 +359,11 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] + " parent RDDs are also uncached.") } + // Unpersist cached data + if (data.getStorageLevel != StorageLevel.NONE) { + data.unpersist(false) + } + createModel(weights, intercept) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala index f3b46c75c05f..f235089873ab 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala @@ -29,7 +29,7 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.Since import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.{Loader, Saveable} @@ -37,8 +37,6 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext /** - * :: Experimental :: - * * Regression model for isotonic regression. * * @param boundaries Array of boundaries for which predictions are known. @@ -46,12 +44,13 @@ import org.apache.spark.sql.SQLContext * @param predictions Array of predictions associated to the boundaries at the same index. * Results of isotonic regression and therefore monotone. * @param isotonic indicates whether this is isotonic or antitonic. + * */ -@Experimental -class IsotonicRegressionModel ( - val boundaries: Array[Double], - val predictions: Array[Double], - val isotonic: Boolean) extends Serializable with Saveable { +@Since("1.3.0") +class IsotonicRegressionModel @Since("1.3.0") ( + @Since("1.3.0") val boundaries: Array[Double], + @Since("1.3.0") val predictions: Array[Double], + @Since("1.3.0") val isotonic: Boolean) extends Serializable with Saveable { private val predictionOrd = if (isotonic) Ordering[Double] else Ordering[Double].reverse @@ -59,7 +58,10 @@ class IsotonicRegressionModel ( assertOrdered(boundaries) assertOrdered(predictions)(predictionOrd) - /** A Java-friendly constructor that takes two Iterable parameters and one Boolean parameter. */ + /** + * A Java-friendly constructor that takes two Iterable parameters and one Boolean parameter. + */ + @Since("1.4.0") def this(boundaries: java.lang.Iterable[Double], predictions: java.lang.Iterable[Double], isotonic: java.lang.Boolean) = { @@ -83,7 +85,9 @@ class IsotonicRegressionModel ( * * @param testData Features to be labeled. * @return Predicted labels. + * */ + @Since("1.3.0") def predict(testData: RDD[Double]): RDD[Double] = { testData.map(predict) } @@ -94,7 +98,9 @@ class IsotonicRegressionModel ( * * @param testData Features to be labeled. * @return Predicted labels. + * */ + @Since("1.3.0") def predict(testData: JavaDoubleRDD): JavaDoubleRDD = { JavaDoubleRDD.fromRDD(predict(testData.rdd.retag.asInstanceOf[RDD[Double]])) } @@ -114,7 +120,9 @@ class IsotonicRegressionModel ( * 3) If testData falls between two values in boundary array then prediction is treated * as piecewise linear function and interpolated value is returned. In case there are * multiple values with the same boundary then the same rules as in 2) are used. 
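A sketch of the `GeneralizedLinearAlgorithm` usage pattern annotated above, shown through one of its concrete subclasses; the training RDD, optimizer settings, and zero initial weights are assumptions for illustration.

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.{LabeledPoint, LinearRegressionWithSGD}
import org.apache.spark.rdd.RDD

def glmPattern(training: RDD[LabeledPoint]) = {
  val algo = new LinearRegressionWithSGD()
  algo.setIntercept(true).setValidateData(true)
  algo.optimizer.setNumIterations(200).setStepSize(0.1)
  // run(data, initialWeights) is the overload annotated @Since("1.0.0") above.
  algo.run(training, Vectors.zeros(training.first().features.size))
}
```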
+ * */ + @Since("1.3.0") def predict(testData: Double): Double = { def linearInterpolation(x1: Double, y1: Double, x2: Double, y2: Double, x: Double): Double = { @@ -148,6 +156,7 @@ class IsotonicRegressionModel ( /** A convenient method for boundaries called by the Python API. */ private[mllib] def predictionVector: Vector = Vectors.dense(predictions) + @Since("1.4.0") override def save(sc: SparkContext, path: String): Unit = { IsotonicRegressionModel.SaveLoadV1_0.save(sc, path, boundaries, predictions, isotonic) } @@ -155,6 +164,7 @@ class IsotonicRegressionModel ( override protected def formatVersion: String = "1.0" } +@Since("1.4.0") object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { import org.apache.spark.mllib.util.Loader._ @@ -175,7 +185,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { boundaries: Array[Double], predictions: Array[Double], isotonic: Boolean): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ @@ -188,7 +198,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { } def load(sc: SparkContext, path: String): (Array[Double], Array[Double]) = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val dataRDD = sqlContext.read.parquet(dataPath(path)) checkSchema[Data](dataRDD.schema) @@ -200,6 +210,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { } } + @Since("1.4.0") override def load(sc: SparkContext, path: String): IsotonicRegressionModel = { implicit val formats = DefaultFormats val (loadedClassName, version, metadata) = loadMetadata(sc, path) @@ -219,8 +230,6 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { } /** - * :: Experimental :: - * * Isotonic regression. * Currently implemented using parallelized pool adjacent violators algorithm. * Only univariate (single feature) algorithm supported. @@ -238,7 +247,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { * * @see [[http://en.wikipedia.org/wiki/Isotonic_regression Isotonic regression (Wikipedia)]] */ -@Experimental +@Since("1.3.0") class IsotonicRegression private (private var isotonic: Boolean) extends Serializable { /** @@ -246,6 +255,7 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali * * @return New instance of IsotonicRegression. */ + @Since("1.3.0") def this() = this(true) /** @@ -254,6 +264,7 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali * @param isotonic Isotonic (increasing) or antitonic (decreasing) sequence. * @return This instance of IsotonicRegression. */ + @Since("1.3.0") def setIsotonic(isotonic: Boolean): this.type = { this.isotonic = isotonic this @@ -269,6 +280,7 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali * the algorithm is executed. * @return Isotonic regression model. */ + @Since("1.3.0") def run(input: RDD[(Double, Double, Double)]): IsotonicRegressionModel = { val preprocessedInput = if (isotonic) { input @@ -294,6 +306,7 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali * the algorithm is executed. * @return Isotonic regression model. 
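A sketch of the isotonic regression API annotated above, assuming a `SparkContext` `sc`; the (label, feature, weight) triples are toy values.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.regression.IsotonicRegression

def isotonicExample(sc: SparkContext): Unit = {
  // (label, feature, weight) triples; toy data for illustration.
  val input = sc.parallelize(Seq(
    (1.0, 1.0, 1.0), (2.0, 2.0, 1.0), (1.5, 3.0, 1.0), (4.0, 4.0, 1.0)))
  val model = new IsotonicRegression().setIsotonic(true).run(input)
  // Between known boundaries the prediction is linearly interpolated.
  println(model.predict(2.5))
  model.boundaries.zip(model.predictions).foreach(println)
}
```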
*/ + @Since("1.3.0") def run(input: JavaRDD[(JDouble, JDouble, JDouble)]): IsotonicRegressionModel = { run(input.rdd.retag.asInstanceOf[RDD[(Double, Double, Double)]]) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala index d5fea822ad77..c284ad232537 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala @@ -19,6 +19,7 @@ package org.apache.spark.mllib.regression import scala.beans.BeanInfo +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.{Vectors, Vector} import org.apache.spark.mllib.util.NumericParser import org.apache.spark.SparkException @@ -29,8 +30,11 @@ import org.apache.spark.SparkException * @param label Label for this data point. * @param features List of features for this data point. */ +@Since("0.8.0") @BeanInfo -case class LabeledPoint(label: Double, features: Vector) { +case class LabeledPoint @Since("1.0.0") ( + @Since("0.8.0") label: Double, + @Since("1.0.0") features: Vector) { override def toString: String = { s"($label,$features)" } @@ -38,12 +42,16 @@ case class LabeledPoint(label: Double, features: Vector) { /** * Parser for [[org.apache.spark.mllib.regression.LabeledPoint]]. + * */ +@Since("1.1.0") object LabeledPoint { /** * Parses a string resulted from `LabeledPoint#toString` into * an [[org.apache.spark.mllib.regression.LabeledPoint]]. + * */ + @Since("1.1.0") def parse(s: String): LabeledPoint = { if (s.startsWith("(")) { NumericParser.parse(s) match { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index 4f482384f0f3..a9aba173fa0e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.regression import org.apache.spark.SparkContext +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.pmml.PMMLExportable @@ -30,10 +31,12 @@ import org.apache.spark.rdd.RDD * * @param weights Weights computed for every feature. * @param intercept Intercept computed for this model. 
+ * */ -class LassoModel ( - override val weights: Vector, - override val intercept: Double) +@Since("0.8.0") +class LassoModel @Since("1.1.0") ( + @Since("1.0.0") override val weights: Vector, + @Since("0.8.0") override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable with Saveable with PMMLExportable { @@ -44,6 +47,7 @@ class LassoModel ( weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept } + @Since("1.3.0") override def save(sc: SparkContext, path: String): Unit = { GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept) } @@ -51,8 +55,10 @@ class LassoModel ( override protected def formatVersion: String = "1.0" } +@Since("1.3.0") object LassoModel extends Loader[LassoModel] { + @Since("1.3.0") override def load(sc: SparkContext, path: String): LassoModel = { val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) // Hard-code class name string in case it changes in the future @@ -78,6 +84,7 @@ object LassoModel extends Loader[LassoModel] { * its corresponding right hand side label y. * See also the documentation for the precise formulation. */ +@Since("0.8.0") class LassoWithSGD private ( private var stepSize: Double, private var numIterations: Int, @@ -87,6 +94,7 @@ class LassoWithSGD private ( private val gradient = new LeastSquaresGradient() private val updater = new L1Updater() + @Since("0.8.0") override val optimizer = new GradientDescent(gradient, updater) .setStepSize(stepSize) .setNumIterations(numIterations) @@ -97,6 +105,7 @@ class LassoWithSGD private ( * Construct a Lasso object with default parameters: {stepSize: 1.0, numIterations: 100, * regParam: 0.01, miniBatchFraction: 1.0}. */ + @Since("0.8.0") def this() = this(1.0, 100, 0.01, 1.0) override protected def createModel(weights: Vector, intercept: Double) = { @@ -106,7 +115,9 @@ class LassoWithSGD private ( /** * Top-level methods for calling Lasso. + * */ +@Since("0.8.0") object LassoWithSGD { /** @@ -123,7 +134,9 @@ object LassoWithSGD { * @param miniBatchFraction Fraction of data to be used per iteration. * @param initialWeights Initial set of weights to be used. Array should be equal in size to * the number of features in the data. + * */ + @Since("1.0.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -146,7 +159,9 @@ object LassoWithSGD { * @param stepSize Step size to be used for each iteration of gradient descent. * @param regParam Regularization parameter. * @param miniBatchFraction Fraction of data to be used per iteration. + * */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -167,7 +182,9 @@ object LassoWithSGD { * @param regParam Regularization parameter. * @param numIterations Number of iterations of gradient descent to run. * @return a LassoModel which has the weights and offset from training. + * */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -185,7 +202,9 @@ object LassoWithSGD { * matrix A as well as the corresponding right hand side label y * @param numIterations Number of iterations of gradient descent to run. * @return a LassoModel which has the weights and offset from training. 
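A sketch of the `LassoWithSGD.train` overloads annotated above; the training RDD and hyperparameters are assumed for illustration.

```scala
import org.apache.spark.mllib.regression.{LabeledPoint, LassoWithSGD}
import org.apache.spark.rdd.RDD

def trainLasso(training: RDD[LabeledPoint]) = {
  // (numIterations, stepSize, regParam): illustrative hyperparameters.
  val model = LassoWithSGD.train(training, 100, 1.0, 0.1)
  // L1 regularization tends to drive some of these coefficients to zero.
  println(model.weights)
  model
}
```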
+ * */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int): LassoModel = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index 9453c4f66c21..4996ace5df85 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.regression import org.apache.spark.SparkContext +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.pmml.PMMLExportable @@ -30,10 +31,12 @@ import org.apache.spark.rdd.RDD * * @param weights Weights computed for every feature. * @param intercept Intercept computed for this model. + * */ -class LinearRegressionModel ( - override val weights: Vector, - override val intercept: Double) +@Since("0.8.0") +class LinearRegressionModel @Since("1.1.0") ( + @Since("1.0.0") override val weights: Vector, + @Since("0.8.0") override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable with Saveable with PMMLExportable { @@ -44,6 +47,7 @@ class LinearRegressionModel ( weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept } + @Since("1.3.0") override def save(sc: SparkContext, path: String): Unit = { GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept) } @@ -51,8 +55,10 @@ class LinearRegressionModel ( override protected def formatVersion: String = "1.0" } +@Since("1.3.0") object LinearRegressionModel extends Loader[LinearRegressionModel] { + @Since("1.3.0") override def load(sc: SparkContext, path: String): LinearRegressionModel = { val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) // Hard-code class name string in case it changes in the future @@ -79,6 +85,7 @@ object LinearRegressionModel extends Loader[LinearRegressionModel] { * its corresponding right hand side label y. * See also the documentation for the precise formulation. */ +@Since("0.8.0") class LinearRegressionWithSGD private[mllib] ( private var stepSize: Double, private var numIterations: Int, @@ -87,6 +94,7 @@ class LinearRegressionWithSGD private[mllib] ( private val gradient = new LeastSquaresGradient() private val updater = new SimpleUpdater() + @Since("0.8.0") override val optimizer = new GradientDescent(gradient, updater) .setStepSize(stepSize) .setNumIterations(numIterations) @@ -96,6 +104,7 @@ class LinearRegressionWithSGD private[mllib] ( * Construct a LinearRegression object with default parameters: {stepSize: 1.0, * numIterations: 100, miniBatchFraction: 1.0}. */ + @Since("0.8.0") def this() = this(1.0, 100, 1.0) override protected[mllib] def createModel(weights: Vector, intercept: Double) = { @@ -105,7 +114,9 @@ class LinearRegressionWithSGD private[mllib] ( /** * Top-level methods for calling LinearRegression. + * */ +@Since("0.8.0") object LinearRegressionWithSGD { /** @@ -121,7 +132,9 @@ object LinearRegressionWithSGD { * @param miniBatchFraction Fraction of data to be used per iteration. * @param initialWeights Initial set of weights to be used. Array should be equal in size to * the number of features in the data. 
+ * */ + @Since("1.0.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -142,7 +155,9 @@ object LinearRegressionWithSGD { * @param numIterations Number of iterations of gradient descent to run. * @param stepSize Step size to be used for each iteration of gradient descent. * @param miniBatchFraction Fraction of data to be used per iteration. + * */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -161,7 +176,9 @@ object LinearRegressionWithSGD { * @param stepSize Step size to be used for each iteration of Gradient Descent. * @param numIterations Number of iterations of gradient descent to run. * @return a LinearRegressionModel which has the weights and offset from training. + * */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -178,7 +195,9 @@ object LinearRegressionWithSGD { * matrix A as well as the corresponding right hand side label y * @param numIterations Number of iterations of gradient descent to run. * @return a LinearRegressionModel which has the weights and offset from training. + * */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int): LinearRegressionModel = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala index 214ac4d0ed7d..a95a54225a08 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala @@ -19,19 +19,21 @@ package org.apache.spark.mllib.regression import org.json4s.{DefaultFormats, JValue} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector import org.apache.spark.rdd.RDD -@Experimental +@Since("0.8.0") trait RegressionModel extends Serializable { /** * Predict values for the given data set using the model trained. * * @param testData RDD representing data points to be predicted * @return RDD[Double] where each entry contains the corresponding prediction + * */ + @Since("1.0.0") def predict(testData: RDD[Vector]): RDD[Double] /** @@ -39,14 +41,18 @@ trait RegressionModel extends Serializable { * * @param testData array representing a single data point * @return Double prediction from the trained model + * */ + @Since("1.0.0") def predict(testData: Vector): Double /** * Predict values for examples stored in a JavaRDD. 
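A sketch covering `LinearRegressionWithSGD.train` and both `predict` overloads of the `RegressionModel` trait annotated above; the training and test RDDs are assumed.

```scala
import org.apache.spark.mllib.regression.{LabeledPoint, LinearRegressionWithSGD, RegressionModel}
import org.apache.spark.rdd.RDD

def trainAndPredict(training: RDD[LabeledPoint], test: RDD[LabeledPoint]): Unit = {
  val model: RegressionModel = LinearRegressionWithSGD.train(training, 100)
  // Both predict overloads of the RegressionModel trait.
  val batch = model.predict(test.map(_.features))
  val single = model.predict(test.first().features)
  println(s"first prediction: $single, batch count: ${batch.count()}")
}
```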
* @param testData JavaRDD representing data points to be predicted * @return a JavaRDD[java.lang.Double] where each entry contains the corresponding prediction + * */ + @Since("1.0.0") def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] = predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]] } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index 7d28ffad45c9..0a44ff559d55 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.regression import org.apache.spark.SparkContext +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.pmml.PMMLExportable @@ -31,10 +32,12 @@ import org.apache.spark.rdd.RDD * * @param weights Weights computed for every feature. * @param intercept Intercept computed for this model. + * */ -class RidgeRegressionModel ( - override val weights: Vector, - override val intercept: Double) +@Since("0.8.0") +class RidgeRegressionModel @Since("1.1.0") ( + @Since("1.0.0") override val weights: Vector, + @Since("0.8.0") override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable with Saveable with PMMLExportable { @@ -45,6 +48,7 @@ class RidgeRegressionModel ( weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept } + @Since("1.3.0") override def save(sc: SparkContext, path: String): Unit = { GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept) } @@ -52,8 +56,10 @@ class RidgeRegressionModel ( override protected def formatVersion: String = "1.0" } +@Since("1.3.0") object RidgeRegressionModel extends Loader[RidgeRegressionModel] { + @Since("1.3.0") override def load(sc: SparkContext, path: String): RidgeRegressionModel = { val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) // Hard-code class name string in case it changes in the future @@ -79,6 +85,7 @@ object RidgeRegressionModel extends Loader[RidgeRegressionModel] { * its corresponding right hand side label y. * See also the documentation for the precise formulation. */ +@Since("0.8.0") class RidgeRegressionWithSGD private ( private var stepSize: Double, private var numIterations: Int, @@ -88,7 +95,7 @@ class RidgeRegressionWithSGD private ( private val gradient = new LeastSquaresGradient() private val updater = new SquaredL2Updater() - + @Since("0.8.0") override val optimizer = new GradientDescent(gradient, updater) .setStepSize(stepSize) .setNumIterations(numIterations) @@ -99,6 +106,7 @@ class RidgeRegressionWithSGD private ( * Construct a RidgeRegression object with default parameters: {stepSize: 1.0, numIterations: 100, * regParam: 0.01, miniBatchFraction: 1.0}. */ + @Since("0.8.0") def this() = this(1.0, 100, 0.01, 1.0) override protected def createModel(weights: Vector, intercept: Double) = { @@ -108,7 +116,9 @@ class RidgeRegressionWithSGD private ( /** * Top-level methods for calling RidgeRegression. + * */ +@Since("0.8.0") object RidgeRegressionWithSGD { /** @@ -124,7 +134,9 @@ object RidgeRegressionWithSGD { * @param miniBatchFraction Fraction of data to be used per iteration. * @param initialWeights Initial set of weights to be used. 
Array should be equal in size to * the number of features in the data. + * */ + @Since("1.0.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -146,7 +158,9 @@ object RidgeRegressionWithSGD { * @param stepSize Step size to be used for each iteration of gradient descent. * @param regParam Regularization parameter. * @param miniBatchFraction Fraction of data to be used per iteration. + * */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -166,7 +180,9 @@ object RidgeRegressionWithSGD { * @param regParam Regularization parameter. * @param numIterations Number of iterations of gradient descent to run. * @return a RidgeRegressionModel which has the weights and offset from training. + * */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -183,7 +199,9 @@ object RidgeRegressionWithSGD { * @param input RDD of (label, array of features) pairs. * @param numIterations Number of iterations of gradient descent to run. * @return a RidgeRegressionModel which has the weights and offset from training. + * */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int): RidgeRegressionModel = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala index 141052ba813e..73948b2d9851 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala @@ -20,9 +20,9 @@ package org.apache.spark.mllib.regression import scala.reflect.ClassTag import org.apache.spark.Logging -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.api.java.JavaSparkContext.fakeClassTag -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.streaming.api.java.{JavaDStream, JavaPairDStream} import org.apache.spark.streaming.dstream.DStream @@ -53,7 +53,9 @@ import org.apache.spark.streaming.dstream.DStream * It is also ok to call trainOn on different streams; this will update * the model using each of the different sources, in sequence. * + * */ +@Since("1.1.0") @DeveloperApi abstract class StreamingLinearAlgorithm[ M <: GeneralizedLinearModel, @@ -65,7 +67,11 @@ abstract class StreamingLinearAlgorithm[ /** The algorithm to use for updating. */ protected val algorithm: A - /** Return the latest model. */ + /** + * Return the latest model. + * + */ + @Since("1.1.0") def latestModel(): M = { model.get } @@ -78,6 +84,7 @@ abstract class StreamingLinearAlgorithm[ * * @param data DStream containing labeled data */ + @Since("1.1.0") def trainOn(data: DStream[LabeledPoint]): Unit = { if (model.isEmpty) { throw new IllegalArgumentException("Model must be initialized before starting training.") @@ -95,7 +102,10 @@ abstract class StreamingLinearAlgorithm[ } } - /** Java-friendly version of `trainOn`. */ + /** + * Java-friendly version of `trainOn`. 
+ */ + @Since("1.3.0") def trainOn(data: JavaDStream[LabeledPoint]): Unit = trainOn(data.dstream) /** @@ -103,7 +113,9 @@ abstract class StreamingLinearAlgorithm[ * * @param data DStream containing feature vectors * @return DStream containing predictions + * */ + @Since("1.1.0") def predictOn(data: DStream[Vector]): DStream[Double] = { if (model.isEmpty) { throw new IllegalArgumentException("Model must be initialized before starting prediction.") @@ -111,7 +123,11 @@ abstract class StreamingLinearAlgorithm[ data.map{x => model.get.predict(x)} } - /** Java-friendly version of `predictOn`. */ + /** + * Java-friendly version of `predictOn`. + * + */ + @Since("1.3.0") def predictOn(data: JavaDStream[Vector]): JavaDStream[java.lang.Double] = { JavaDStream.fromDStream(predictOn(data.dstream).asInstanceOf[DStream[java.lang.Double]]) } @@ -121,7 +137,9 @@ abstract class StreamingLinearAlgorithm[ * @param data DStream containing feature vectors * @tparam K key type * @return DStream containing the input keys and the predictions as values + * */ + @Since("1.1.0") def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Double)] = { if (model.isEmpty) { throw new IllegalArgumentException("Model must be initialized before starting prediction") @@ -130,7 +148,11 @@ abstract class StreamingLinearAlgorithm[ } - /** Java-friendly version of `predictOnValues`. */ + /** + * Java-friendly version of `predictOnValues`. + * + */ + @Since("1.3.0") def predictOnValues[K](data: JavaPairDStream[K, Vector]): JavaPairDStream[K, java.lang.Double] = { implicit val tag = fakeClassTag[K] JavaPairDStream.fromPairDStream( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala index c6d04464a12b..fe2a46b9eecc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala @@ -17,11 +17,10 @@ package org.apache.spark.mllib.regression -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.Vector /** - * :: Experimental :: * Train or predict a linear regression model on streaming data. Training uses * Stochastic Gradient Descent to update the model based on each new batch of * incoming data from a DStream (see `LinearRegressionWithSGD` for model equation) @@ -39,9 +38,8 @@ import org.apache.spark.mllib.linalg.Vector * .setNumIterations(10) * .setInitialWeights(Vectors.dense(...)) * .trainOn(DStream) - * */ -@Experimental +@Since("1.1.0") class StreamingLinearRegressionWithSGD private[mllib] ( private var stepSize: Double, private var numIterations: Int, @@ -55,40 +53,56 @@ class StreamingLinearRegressionWithSGD private[mllib] ( * Initial weights must be set before using trainOn or predictOn * (see `StreamingLinearAlgorithm`) */ + @Since("1.1.0") def this() = this(0.1, 50, 1.0) + @Since("1.1.0") val algorithm = new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction) protected var model: Option[LinearRegressionModel] = None - /** Set the step size for gradient descent. Default: 0.1. */ + /** + * Set the step size for gradient descent. Default: 0.1. 
+ */ + @Since("1.1.0") def setStepSize(stepSize: Double): this.type = { this.algorithm.optimizer.setStepSize(stepSize) this } - /** Set the number of iterations of gradient descent to run per update. Default: 50. */ + /** + * Set the number of iterations of gradient descent to run per update. Default: 50. + */ + @Since("1.1.0") def setNumIterations(numIterations: Int): this.type = { this.algorithm.optimizer.setNumIterations(numIterations) this } - /** Set the fraction of each batch to use for updates. Default: 1.0. */ + /** + * Set the fraction of each batch to use for updates. Default: 1.0. + */ + @Since("1.1.0") def setMiniBatchFraction(miniBatchFraction: Double): this.type = { this.algorithm.optimizer.setMiniBatchFraction(miniBatchFraction) this } - /** Set the initial weights. */ + /** + * Set the initial weights. + */ + @Since("1.1.0") def setInitialWeights(initialWeights: Vector): this.type = { this.model = Some(algorithm.createModel(initialWeights, 0.0)) this } - /** Set the convergence tolerance. */ + /** + * Set the convergence tolerance. Default: 0.001. + */ + @Since("1.5.0") def setConvergenceTol(tolerance: Double): this.type = { this.algorithm.optimizer.setConvergenceTol(tolerance) this } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala index 317d3a570263..02af281fb726 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala @@ -47,7 +47,7 @@ private[regression] object GLMRegressionModel { modelClass: String, weights: Vector, intercept: Double): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // Create JSON metadata. @@ -71,7 +71,7 @@ private[regression] object GLMRegressionModel { */ def loadData(sc: SparkContext, path: String, modelClass: String, numFeatures: Int): Data = { val datapath = Loader.dataPath(path) - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val dataRDD = sqlContext.read.parquet(datapath) val dataArray = dataRDD.select("weights", "intercept").take(1) assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala index 93a6753efd4d..f253963270bc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala @@ -19,12 +19,11 @@ package org.apache.spark.mllib.stat import com.github.fommil.netlib.BLAS.{getInstance => blas} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.rdd.RDD /** - * :: Experimental :: * Kernel density estimation. Given a sample from a population, estimate its probability density * function at each of the given evaluation points using kernels. Only Gaussian kernel is supported. 
* @@ -37,9 +36,8 @@ import org.apache.spark.rdd.RDD * .setBandwidth(3.0) * val densities = kd.estimate(Array(-1.0, 2.0, 5.0)) * }}} - * @since 1.4.0 */ -@Experimental +@Since("1.4.0") class KernelDensity extends Serializable { import KernelDensity._ @@ -52,8 +50,8 @@ class KernelDensity extends Serializable { /** * Sets the bandwidth (standard deviation) of the Gaussian kernel (default: `1.0`). - * @since 1.4.0 */ + @Since("1.4.0") def setBandwidth(bandwidth: Double): this.type = { require(bandwidth > 0, s"Bandwidth must be positive, but got $bandwidth.") this.bandwidth = bandwidth @@ -62,8 +60,8 @@ class KernelDensity extends Serializable { /** * Sets the sample to use for density estimation. - * @since 1.4.0 */ + @Since("1.4.0") def setSample(sample: RDD[Double]): this.type = { this.sample = sample this @@ -71,8 +69,8 @@ class KernelDensity extends Serializable { /** * Sets the sample to use for density estimation (for Java users). - * @since 1.4.0 */ + @Since("1.4.0") def setSample(sample: JavaRDD[java.lang.Double]): this.type = { this.sample = sample.rdd.asInstanceOf[RDD[Double]] this @@ -80,8 +78,8 @@ class KernelDensity extends Serializable { /** * Estimates probability density function at the given array of points. - * @since 1.4.0 */ + @Since("1.4.0") def estimate(points: Array[Double]): Array[Double] = { val sample = this.sample val bandwidth = this.bandwidth diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 62da9f2ef22a..201333c3690d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -17,24 +17,27 @@ package org.apache.spark.mllib.stat -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.linalg.{Vectors, Vector} /** * :: DeveloperApi :: * MultivariateOnlineSummarizer implements [[MultivariateStatisticalSummary]] to compute the mean, - * variance, minimum, maximum, counts, and nonzero counts for samples in sparse or dense vector + * variance, minimum, maximum, counts, and nonzero counts for instances in sparse or dense vector * format in a online fashion. * * Two MultivariateOnlineSummarizer can be merged together to have a statistical summary of * the corresponding joint dataset. * - * A numerically stable algorithm is implemented to compute sample mean and variance: + * A numerically stable algorithm is implemented to compute the mean and variance of instances: * Reference: [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]] * Zero elements (including explicit zero values) are skipped when calling add(), * to have time complexity O(nnz) instead of O(n) for each column. - * @since 1.1.0 + * + * For weighted instances, the unbiased estimation of variance is defined by the reliability + * weights: [[https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights]]. 
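Concretely, with per-instance weights w_i the summarizer tracks W (the new weightSum field below) and W_2 (weightSquareSum), and the reliability-weighted estimates it reports are, roughly:

\bar{x} = \frac{\sum_i w_i x_i}{W},
\qquad
\widehat{\sigma}^2 = \frac{\sum_i w_i (x_i - \bar{x})^2}{W - W_2 / W},
\qquad
W = \sum_i w_i, \quad W_2 = \sum_i w_i^2 .

With unit weights, W = n and W_2 = n, so the denominator reduces to the familiar n - 1 of the unbiased sample variance.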
*/ +@Since("1.1.0") @DeveloperApi class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with Serializable { @@ -44,6 +47,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S private var currM2: Array[Double] = _ private var currL1: Array[Double] = _ private var totalCnt: Long = 0 + private var weightSum: Double = 0.0 + private var weightSquareSum: Double = 0.0 private var nnz: Array[Double] = _ private var currMax: Array[Double] = _ private var currMin: Array[Double] = _ @@ -53,12 +58,17 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S * * @param sample The sample in dense/sparse vector format to be added into this summarizer. * @return This MultivariateOnlineSummarizer object. - * @since 1.1.0 */ - def add(sample: Vector): this.type = { + @Since("1.1.0") + def add(sample: Vector): this.type = add(sample, 1.0) + + private[spark] def add(instance: Vector, weight: Double): this.type = { + require(weight >= 0.0, s"sample weight, ${weight} has to be >= 0.0") + if (weight == 0.0) return this + if (n == 0) { - require(sample.size > 0, s"Vector should have dimension larger than zero.") - n = sample.size + require(instance.size > 0, s"Vector should have dimension larger than zero.") + n = instance.size currMean = Array.ofDim[Double](n) currM2n = Array.ofDim[Double](n) @@ -69,8 +79,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S currMin = Array.fill[Double](n)(Double.MaxValue) } - require(n == sample.size, s"Dimensions mismatch when adding new sample." + - s" Expecting $n but got ${sample.size}.") + require(n == instance.size, s"Dimensions mismatch when adding new sample." + + s" Expecting $n but got ${instance.size}.") val localCurrMean = currMean val localCurrM2n = currM2n @@ -79,7 +89,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S val localNnz = nnz val localCurrMax = currMax val localCurrMin = currMin - sample.foreachActive { (index, value) => + instance.foreachActive { (index, value) => if (value != 0.0) { if (localCurrMax(index) < value) { localCurrMax(index) = value @@ -90,15 +100,17 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S val prevMean = localCurrMean(index) val diff = value - prevMean - localCurrMean(index) = prevMean + diff / (localNnz(index) + 1.0) - localCurrM2n(index) += (value - localCurrMean(index)) * diff - localCurrM2(index) += value * value - localCurrL1(index) += math.abs(value) + localCurrMean(index) = prevMean + weight * diff / (localNnz(index) + weight) + localCurrM2n(index) += weight * (value - localCurrMean(index)) * diff + localCurrM2(index) += weight * value * value + localCurrL1(index) += weight * math.abs(value) - localNnz(index) += 1.0 + localNnz(index) += weight } } + weightSum += weight + weightSquareSum += weight * weight totalCnt += 1 this } @@ -109,13 +121,15 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S * * @param other The other MultivariateOnlineSummarizer to be merged. * @return This MultivariateOnlineSummarizer object. - * @since 1.1.0 */ + @Since("1.1.0") def merge(other: MultivariateOnlineSummarizer): this.type = { - if (this.totalCnt != 0 && other.totalCnt != 0) { + if (this.weightSum != 0.0 && other.weightSum != 0.0) { require(n == other.n, s"Dimensions mismatch when merging with another summarizer. 
" + s"Expecting $n but got ${other.n}.") totalCnt += other.totalCnt + weightSum += other.weightSum + weightSquareSum += other.weightSquareSum var i = 0 while (i < n) { val thisNnz = nnz(i) @@ -138,13 +152,15 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S nnz(i) = totalNnz i += 1 } - } else if (totalCnt == 0 && other.totalCnt != 0) { + } else if (weightSum == 0.0 && other.weightSum != 0.0) { this.n = other.n this.currMean = other.currMean.clone() this.currM2n = other.currM2n.clone() this.currM2 = other.currM2.clone() this.currL1 = other.currL1.clone() this.totalCnt = other.totalCnt + this.weightSum = other.weightSum + this.weightSquareSum = other.weightSquareSum this.nnz = other.nnz.clone() this.currMax = other.currMax.clone() this.currMin = other.currMin.clone() @@ -153,29 +169,33 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } /** - * @since 1.1.0 + * Sample mean of each dimension. + * */ + @Since("1.1.0") override def mean: Vector = { - require(totalCnt > 0, s"Nothing has been added to this summarizer.") + require(weightSum > 0, s"Nothing has been added to this summarizer.") val realMean = Array.ofDim[Double](n) var i = 0 while (i < n) { - realMean(i) = currMean(i) * (nnz(i) / totalCnt) + realMean(i) = currMean(i) * (nnz(i) / weightSum) i += 1 } Vectors.dense(realMean) } /** - * @since 1.1.0 + * Unbiased estimate of sample variance of each dimension. + * */ + @Since("1.1.0") override def variance: Vector = { - require(totalCnt > 0, s"Nothing has been added to this summarizer.") + require(weightSum > 0, s"Nothing has been added to this summarizer.") val realVariance = Array.ofDim[Double](n) - val denominator = totalCnt - 1.0 + val denominator = weightSum - (weightSquareSum / weightSum) // Sample variance is computed, if the denominator is less than 0, the variance is just 0. if (denominator > 0.0) { @@ -183,9 +203,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S var i = 0 val len = currM2n.length while (i < len) { - realVariance(i) = - currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt - realVariance(i) /= denominator + realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * + (weightSum - nnz(i)) / weightSum) / denominator i += 1 } } @@ -193,52 +212,62 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } /** - * @since 1.1.0 + * Sample size. + * */ + @Since("1.1.0") override def count: Long = totalCnt /** - * @since 1.1.0 + * Number of nonzero elements in each dimension. + * */ + @Since("1.1.0") override def numNonzeros: Vector = { - require(totalCnt > 0, s"Nothing has been added to this summarizer.") + require(weightSum > 0, s"Nothing has been added to this summarizer.") Vectors.dense(nnz) } /** - * @since 1.1.0 + * Maximum value of each dimension. + * */ + @Since("1.1.0") override def max: Vector = { - require(totalCnt > 0, s"Nothing has been added to this summarizer.") + require(weightSum > 0, s"Nothing has been added to this summarizer.") var i = 0 while (i < n) { - if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0 + if ((nnz(i) < weightSum) && (currMax(i) < 0.0)) currMax(i) = 0.0 i += 1 } Vectors.dense(currMax) } /** - * @since 1.1.0 + * Minimum value of each dimension. 
+ * */ + @Since("1.1.0") override def min: Vector = { - require(totalCnt > 0, s"Nothing has been added to this summarizer.") + require(weightSum > 0, s"Nothing has been added to this summarizer.") var i = 0 while (i < n) { - if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0 + if ((nnz(i) < weightSum) && (currMin(i) > 0.0)) currMin(i) = 0.0 i += 1 } Vectors.dense(currMin) } /** - * @since 1.2.0 + * L2 (Euclidian) norm of each dimension. + * */ + @Since("1.2.0") override def normL2: Vector = { - require(totalCnt > 0, s"Nothing has been added to this summarizer.") + require(weightSum > 0, s"Nothing has been added to this summarizer.") val realMagnitude = Array.ofDim[Double](n) @@ -252,10 +281,12 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } /** - * @since 1.2.0 + * L1 norm of each dimension. + * */ + @Since("1.2.0") override def normL1: Vector = { - require(totalCnt > 0, s"Nothing has been added to this summarizer.") + require(weightSum > 0, s"Nothing has been added to this summarizer.") Vectors.dense(currL1) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala index 3bb49f12289e..39a16fb743d6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala @@ -17,59 +17,60 @@ package org.apache.spark.mllib.stat +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.Vector /** * Trait for multivariate statistical summary of a data matrix. - * @since 1.0.0 */ +@Since("1.0.0") trait MultivariateStatisticalSummary { /** * Sample mean vector. - * @since 1.0.0 */ + @Since("1.0.0") def mean: Vector /** * Sample variance vector. Should return a zero vector if the sample size is 1. - * @since 1.0.0 */ + @Since("1.0.0") def variance: Vector /** * Sample size. - * @since 1.0.0 */ + @Since("1.0.0") def count: Long /** * Number of nonzero elements (including explicitly presented zero values) in each column. - * @since 1.0.0 */ + @Since("1.0.0") def numNonzeros: Vector /** * Maximum value of each column. - * @since 1.0.0 */ + @Since("1.0.0") def max: Vector /** * Minimum value of each column. 
- * @since 1.0.0 */ + @Since("1.0.0") def min: Vector /** * Euclidean magnitude of each column - * @since 1.2.0 */ + @Since("1.2.0") def normL2: Vector /** * L1 norm of each column - * @since 1.2.0 */ + @Since("1.2.0") def normL1: Vector } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index f84502919e38..bcb33a7a0467 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -19,8 +19,8 @@ package org.apache.spark.mllib.stat import scala.annotation.varargs -import org.apache.spark.annotation.Experimental -import org.apache.spark.api.java.JavaRDD +import org.apache.spark.annotation.Since +import org.apache.spark.api.java.{JavaRDD, JavaDoubleRDD} import org.apache.spark.mllib.linalg.distributed.RowMatrix import org.apache.spark.mllib.linalg.{Matrix, Vector} import org.apache.spark.mllib.regression.LabeledPoint @@ -30,11 +30,9 @@ import org.apache.spark.mllib.stat.test.{ChiSqTest, ChiSqTestResult, KolmogorovS import org.apache.spark.rdd.RDD /** - * :: Experimental :: * API for statistical functions in MLlib. - * @since 1.1.0 */ -@Experimental +@Since("1.1.0") object Statistics { /** @@ -42,8 +40,8 @@ object Statistics { * * @param X an RDD[Vector] for which column-wise summary statistics are to be computed. * @return [[MultivariateStatisticalSummary]] object containing column-wise summary statistics. - * @since 1.1.0 */ + @Since("1.1.0") def colStats(X: RDD[Vector]): MultivariateStatisticalSummary = { new RowMatrix(X).computeColumnSummaryStatistics() } @@ -54,8 +52,8 @@ object Statistics { * * @param X an RDD[Vector] for which the correlation matrix is to be computed. * @return Pearson correlation matrix comparing columns in X. - * @since 1.1.0 */ + @Since("1.1.0") def corr(X: RDD[Vector]): Matrix = Correlations.corrMatrix(X) /** @@ -71,8 +69,8 @@ object Statistics { * @param method String specifying the method to use for computing correlation. * Supported: `pearson` (default), `spearman` * @return Correlation matrix comparing columns in X. - * @since 1.1.0 */ + @Since("1.1.0") def corr(X: RDD[Vector], method: String): Matrix = Correlations.corrMatrix(X, method) /** @@ -85,14 +83,14 @@ object Statistics { * @param x RDD[Double] of the same cardinality as y. * @param y RDD[Double] of the same cardinality as x. * @return A Double containing the Pearson correlation between the two input RDD[Double]s - * @since 1.1.0 */ + @Since("1.1.0") def corr(x: RDD[Double], y: RDD[Double]): Double = Correlations.corr(x, y) /** * Java-friendly version of [[corr()]] - * @since 1.4.1 */ + @Since("1.4.1") def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double]): Double = corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]]) @@ -109,14 +107,14 @@ object Statistics { * Supported: `pearson` (default), `spearman` * @return A Double containing the correlation between the two input RDD[Double]s using the * specified method. 
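A small sketch of the colStats and corr entry points documented in this hunk (toy data, assuming sc as before):

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.stat.Statistics

val observations = sc.parallelize(Seq(
  Vectors.dense(1.0, 10.0, 100.0),
  Vectors.dense(2.0, 20.0, 200.0),
  Vectors.dense(3.0, 30.0, 300.0)))

// Column-wise summary statistics.
val summary = Statistics.colStats(observations)
println(summary.mean)
println(summary.variance)

// Correlation matrix over the columns: Pearson by default, Spearman on request.
val pearsonMatrix = Statistics.corr(observations)
val spearmanMatrix = Statistics.corr(observations, "spearman")

// Pairwise correlation between two RDD[Double]s of the same cardinality.
val x = sc.parallelize(Seq(1.0, 2.0, 3.0, 4.0))
val y = sc.parallelize(Seq(2.0, 4.0, 6.0, 8.0))
val r = Statistics.corr(x, y, "pearson")   // 1.0 for perfectly linear data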
- * @since 1.1.0 */ + @Since("1.1.0") def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method) /** * Java-friendly version of [[corr()]] - * @since 1.4.1 */ + @Since("1.4.1") def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double], method: String): Double = corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]], method) @@ -133,8 +131,8 @@ object Statistics { * `expected` is rescaled if the `expected` sum differs from the `observed` sum. * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. - * @since 1.1.0 */ + @Since("1.1.0") def chiSqTest(observed: Vector, expected: Vector): ChiSqTestResult = { ChiSqTest.chiSquared(observed, expected) } @@ -148,8 +146,8 @@ object Statistics { * @param observed Vector containing the observed categorical counts/relative frequencies. * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. - * @since 1.1.0 */ + @Since("1.1.0") def chiSqTest(observed: Vector): ChiSqTestResult = ChiSqTest.chiSquared(observed) /** @@ -159,8 +157,8 @@ object Statistics { * @param observed The contingency matrix (containing either counts or relative frequencies). * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. - * @since 1.1.0 */ + @Since("1.1.0") def chiSqTest(observed: Matrix): ChiSqTestResult = ChiSqTest.chiSquaredMatrix(observed) /** @@ -172,12 +170,16 @@ object Statistics { * Real-valued features will be treated as categorical for each distinct value. * @return an array containing the ChiSquaredTestResult for every feature against the label. * The order of the elements in the returned array reflects the order of input features. - * @since 1.1.0 */ + @Since("1.1.0") def chiSqTest(data: RDD[LabeledPoint]): Array[ChiSqTestResult] = { ChiSqTest.chiSquaredFeatures(data) } + /** Java-friendly version of [[chiSqTest()]] */ + @Since("1.5.0") + def chiSqTest(data: JavaRDD[LabeledPoint]): Array[ChiSqTestResult] = chiSqTest(data.rdd) + /** * Conduct the two-sided Kolmogorov-Smirnov (KS) test for data sampled from a * continuous distribution. By comparing the largest difference between the empirical cumulative @@ -191,6 +193,7 @@ object Statistics { * @return [[org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult]] object containing test * statistic, p-value, and null hypothesis. */ + @Since("1.5.0") def kolmogorovSmirnovTest(data: RDD[Double], cdf: Double => Double) : KolmogorovSmirnovTestResult = { KolmogorovSmirnovTest.testOneSample(data, cdf) @@ -207,9 +210,20 @@ object Statistics { * @return [[org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult]] object containing test * statistic, p-value, and null hypothesis. 
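And a sketch of the hypothesis-testing entry points annotated here, including the Kolmogorov-Smirnov test that gains @Since("1.5.0") tags in this patch (toy data, assuming sc as before):

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.stat.Statistics

// Goodness-of-fit test against a uniform expected distribution.
val observed = Vectors.dense(4.0, 6.0, 5.0)
val gof = Statistics.chiSqTest(observed)
println(s"${gof.pValue} ${gof.statistic} ${gof.degreesOfFreedom}")

// One-sample, two-sided KS test against a standard normal CDF ("norm", mean, stddev).
val samples = sc.parallelize(Seq(-0.5, 0.1, 0.3, 1.2, -1.1))
val ks = Statistics.kolmogorovSmirnovTest(samples, "norm", 0.0, 1.0)
println(ks)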
*/ + @Since("1.5.0") @varargs def kolmogorovSmirnovTest(data: RDD[Double], distName: String, params: Double*) : KolmogorovSmirnovTestResult = { KolmogorovSmirnovTest.testOneSample(data, distName, params: _*) } + + /** Java-friendly version of [[kolmogorovSmirnovTest()]] */ + @Since("1.5.0") + @varargs + def kolmogorovSmirnovTest( + data: JavaDoubleRDD, + distName: String, + params: Double*): KolmogorovSmirnovTestResult = { + kolmogorovSmirnovTest(data.rdd.asInstanceOf[RDD[Double]], distName, params: _*) + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala index 9aa7763d7890..0724af93088c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.stat.distribution import breeze.linalg.{DenseVector => DBV, DenseMatrix => DBM, diag, max, eigSym, Vector => BV} -import org.apache.spark.annotation.DeveloperApi; +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix} import org.apache.spark.mllib.util.MLUtils @@ -32,12 +32,12 @@ import org.apache.spark.mllib.util.MLUtils * * @param mu The mean vector of the distribution * @param sigma The covariance matrix of the distribution - * @since 1.3.0 */ +@Since("1.3.0") @DeveloperApi -class MultivariateGaussian ( - val mu: Vector, - val sigma: Matrix) extends Serializable { +class MultivariateGaussian @Since("1.3.0") ( + @Since("1.3.0") val mu: Vector, + @Since("1.3.0") val sigma: Matrix) extends Serializable { require(sigma.numCols == sigma.numRows, "Covariance matrix must be square") require(mu.size == sigma.numCols, "Mean vector length must match covariance matrix size") @@ -56,21 +56,21 @@ class MultivariateGaussian ( /** * Compute distribution dependent constants: - * rootSigmaInv = D^(-1/2)^ * U, where sigma = U * D * U.t + * rootSigmaInv = D^(-1/2)^ * U.t, where sigma = U * D * U.t * u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^) */ private val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants /** Returns density of this multivariate Gaussian at given point, x - * @since 1.3.0 */ + @Since("1.3.0") def pdf(x: Vector): Double = { pdf(x.toBreeze) } /** Returns the log-density of this multivariate Gaussian at given point, x - * @since 1.3.0 */ + @Since("1.3.0") def logpdf(x: Vector): Double = { logpdf(x.toBreeze) } @@ -104,11 +104,11 @@ class MultivariateGaussian ( * * sigma = U * D * U.t * inv(Sigma) = U * inv(D) * U.t - * = (D^{-1/2}^ * U).t * (D^{-1/2}^ * U) + * = (D^{-1/2}^ * U.t).t * (D^{-1/2}^ * U.t) * * and thus * - * -0.5 * (x-mu).t * inv(Sigma) * (x-mu) = -0.5 * norm(D^{-1/2}^ * U * (x-mu))^2^ + * -0.5 * (x-mu).t * inv(Sigma) * (x-mu) = -0.5 * norm(D^{-1/2}^ * U.t * (x-mu))^2^ * * To guard against singular covariance matrices, this method computes both the * pseudo-determinant and the pseudo-inverse (Moore-Penrose). 
Singular values are considered @@ -130,7 +130,7 @@ class MultivariateGaussian ( // by inverting the square root of all non-zero values val pinvS = diag(new DBV(d.map(v => if (v > tol) math.sqrt(1.0 / v) else 0.0).toArray)) - (pinvS * u, -0.5 * (mu.size * math.log(2.0 * math.Pi) + logPseudoDetSigma)) + (pinvS * u.t, -0.5 * (mu.size * math.log(2.0 * math.Pi) + logPseudoDetSigma)) } catch { case uex: UnsupportedOperationException => throw new IllegalArgumentException("Covariance matrix has no non-zero singular values") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala new file mode 100644 index 000000000000..e990fe0768bc --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat.test + +import scala.beans.BeanInfo + +import org.apache.spark.Logging +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.streaming.api.java.JavaDStream +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.StatCounter + +/** + * Class that represents the group and value of a sample. + * + * @param isExperiment if the sample is of the experiment group. + * @param value numeric value of the observation. + */ +@Since("1.6.0") +@BeanInfo +case class BinarySample @Since("1.6.0") ( + @Since("1.6.0") isExperiment: Boolean, + @Since("1.6.0") value: Double) { + override def toString: String = { + s"($isExperiment, $value)" + } +} + +/** + * :: Experimental :: + * Performs online 2-sample significance testing for a stream of (Boolean, Double) pairs. The + * Boolean identifies which sample each observation comes from, and the Double is the numeric value + * of the observation. + * + * To address novelty affects, the `peacePeriod` specifies a set number of initial + * [[org.apache.spark.rdd.RDD]] batches of the [[DStream]] to be dropped from significance testing. + * + * The `windowSize` sets the number of batches each significance test is to be performed over. The + * window is sliding with a stride length of 1 batch. Setting windowSize to 0 will perform + * cumulative processing, using all batches seen so far. + * + * Different tests may be used for assessing statistical significance depending on assumptions + * satisfied by data. For more details, see [[StreamingTestMethod]]. The `testMethod` specifies + * which test will be used. 
+ * + * Use a builder pattern to construct a streaming test in an application, for example: + * {{{ + * val model = new StreamingTest() + * .setPeacePeriod(10) + * .setWindowSize(0) + * .setTestMethod("welch") + * .registerStream(DStream) + * }}} + */ +@Experimental +@Since("1.6.0") +class StreamingTest @Since("1.6.0") () extends Logging with Serializable { + private var peacePeriod: Int = 0 + private var windowSize: Int = 0 + private var testMethod: StreamingTestMethod = WelchTTest + + /** Set the number of initial batches to ignore. Default: 0. */ + @Since("1.6.0") + def setPeacePeriod(peacePeriod: Int): this.type = { + this.peacePeriod = peacePeriod + this + } + + /** + * Set the number of batches to compute significance tests over. Default: 0. + * A value of 0 will use all batches seen so far. + */ + @Since("1.6.0") + def setWindowSize(windowSize: Int): this.type = { + this.windowSize = windowSize + this + } + + /** Set the statistical method used for significance testing. Default: "welch" */ + @Since("1.6.0") + def setTestMethod(method: String): this.type = { + this.testMethod = StreamingTestMethod.getTestMethodFromName(method) + this + } + + /** + * Register a [[DStream]] of values for significance testing. + * + * @param data stream of BinarySample(key,value) pairs where the key denotes group membership + * (true = experiment, false = control) and the value is the numerical metric to + * test for significance + * @return stream of significance testing results + */ + @Since("1.6.0") + def registerStream(data: DStream[BinarySample]): DStream[StreamingTestResult] = { + val dataAfterPeacePeriod = dropPeacePeriod(data) + val summarizedData = summarizeByKeyAndWindow(dataAfterPeacePeriod) + val pairedSummaries = pairSummaries(summarizedData) + + testMethod.doTest(pairedSummaries) + } + + /** + * Register a [[JavaDStream]] of values for significance testing. + * + * @param data stream of BinarySample(isExperiment,value) pairs where the isExperiment denotes + * group (true = experiment, false = control) and the value is the numerical metric + * to test for significance + * @return stream of significance testing results + */ + @Since("1.6.0") + def registerStream(data: JavaDStream[BinarySample]): JavaDStream[StreamingTestResult] = { + JavaDStream.fromDStream(registerStream(data.dstream)) + } + + /** Drop all batches inside the peace period. */ + private[stat] def dropPeacePeriod( + data: DStream[BinarySample]): DStream[BinarySample] = { + data.transform { (rdd, time) => + if (time.milliseconds > data.slideDuration.milliseconds * peacePeriod) { + rdd + } else { + data.context.sparkContext.parallelize(Seq()) + } + } + } + + /** Compute summary statistics over each key and the specified test window size. 
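Beyond the builder example in the class scaladoc above, a sketch of how an application might actually feed registerStream, assuming an existing StreamingContext ssc and a hypothetical text stream lines: DStream[String] whose records look like "true,0.42" (group flag, metric value):

import org.apache.spark.mllib.stat.test.{BinarySample, StreamingTest}

val samples = lines.map { line =>
  val Array(group, value) = line.split(",")
  BinarySample(group.toBoolean, value.toDouble)
}

val results = new StreamingTest()
  .setPeacePeriod(2)       // drop the first two batches (novelty effects)
  .setWindowSize(0)        // 0 = cumulative over all batches seen so far
  .setTestMethod("welch")  // or "student"
  .registerStream(samples)

results.print()
ssc.start()
ssc.awaitTermination()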
*/ + private[stat] def summarizeByKeyAndWindow( + data: DStream[BinarySample]): DStream[(Boolean, StatCounter)] = { + val categoryValuePair = data.map(sample => (sample.isExperiment, sample.value)) + if (this.windowSize == 0) { + categoryValuePair.updateStateByKey[StatCounter]( + (newValues: Seq[Double], oldSummary: Option[StatCounter]) => { + val newSummary = oldSummary.getOrElse(new StatCounter()) + newSummary.merge(newValues) + Some(newSummary) + }) + } else { + val windowDuration = data.slideDuration * this.windowSize + categoryValuePair + .groupByKeyAndWindow(windowDuration) + .mapValues { values => + val summary = new StatCounter() + values.foreach(value => summary.merge(value)) + summary + } + } + } + + /** + * Transform a stream of summaries into pairs representing summary statistics for control group + * and experiment group up to this batch. + */ + private[stat] def pairSummaries(summarizedData: DStream[(Boolean, StatCounter)]) + : DStream[(StatCounter, StatCounter)] = { + summarizedData + .map[(Int, StatCounter)](x => (0, x._2)) + .groupByKey() // should be length two (control/experiment group) + .map(x => (x._2.head, x._2.last)) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala new file mode 100644 index 000000000000..911b4b923735 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat.test + +import java.io.Serializable + +import scala.language.implicitConversions +import scala.math.pow + +import com.twitter.chill.MeatLocker +import org.apache.commons.math3.stat.descriptive.StatisticalSummaryValues +import org.apache.commons.math3.stat.inference.TTest + +import org.apache.spark.Logging +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.StatCounter + +/** + * Significance testing methods for [[StreamingTest]]. New 2-sample statistical significance tests + * should extend [[StreamingTestMethod]] and introduce a new entry in + * [[StreamingTestMethod.TEST_NAME_TO_OBJECT]] + */ +private[stat] sealed trait StreamingTestMethod extends Serializable { + + val methodName: String + val nullHypothesis: String + + protected type SummaryPairStream = + DStream[(StatCounter, StatCounter)] + + /** + * Perform streaming 2-sample statistical significance testing. 
+ * + * @param sampleSummaries stream pairs of summary statistics for the 2 samples + * @return stream of rest results + */ + def doTest(sampleSummaries: SummaryPairStream): DStream[StreamingTestResult] + + /** + * Implicit adapter to convert between streaming summary statistics type and the type required by + * the t-testing libraries. + */ + protected implicit def toApacheCommonsStats( + summaryStats: StatCounter): StatisticalSummaryValues = { + new StatisticalSummaryValues( + summaryStats.mean, + summaryStats.variance, + summaryStats.count, + summaryStats.max, + summaryStats.min, + summaryStats.mean * summaryStats.count + ) + } +} + +/** + * Performs Welch's 2-sample t-test. The null hypothesis is that the two data sets have equal mean. + * This test does not assume equal variance between the two samples and does not assume equal + * sample size. + * + * @see http://en.wikipedia.org/wiki/Welch%27s_t_test + */ +private[stat] object WelchTTest extends StreamingTestMethod with Logging { + + override final val methodName = "Welch's 2-sample t-test" + override final val nullHypothesis = "Both groups have same mean" + + private final val tTester = MeatLocker(new TTest()) + + override def doTest(data: SummaryPairStream): DStream[StreamingTestResult] = + data.map[StreamingTestResult]((test _).tupled) + + private def test( + statsA: StatCounter, + statsB: StatCounter): StreamingTestResult = { + def welchDF(sample1: StatisticalSummaryValues, sample2: StatisticalSummaryValues): Double = { + val s1 = sample1.getVariance + val n1 = sample1.getN + val s2 = sample2.getVariance + val n2 = sample2.getN + + val a = pow(s1, 2) / n1 + val b = pow(s2, 2) / n2 + + pow(a + b, 2) / ((pow(a, 2) / (n1 - 1)) + (pow(b, 2) / (n2 - 1))) + } + + new StreamingTestResult( + tTester.get.tTest(statsA, statsB), + welchDF(statsA, statsB), + tTester.get.t(statsA, statsB), + methodName, + nullHypothesis + ) + } +} + +/** + * Performs Students's 2-sample t-test. The null hypothesis is that the two data sets have equal + * mean. This test assumes equal variance between the two samples and does not assume equal sample + * size. For unequal variances, Welch's t-test should be used instead. + * + * @see http://en.wikipedia.org/wiki/Student%27s_t-test + */ +private[stat] object StudentTTest extends StreamingTestMethod with Logging { + + override final val methodName = "Student's 2-sample t-test" + override final val nullHypothesis = "Both groups have same mean" + + private final val tTester = MeatLocker(new TTest()) + + override def doTest(data: SummaryPairStream): DStream[StreamingTestResult] = + data.map[StreamingTestResult]((test _).tupled) + + private def test( + statsA: StatCounter, + statsB: StatCounter): StreamingTestResult = { + def studentDF(sample1: StatisticalSummaryValues, sample2: StatisticalSummaryValues): Double = + sample1.getN + sample2.getN - 2 + + new StreamingTestResult( + tTester.get.homoscedasticTTest(statsA, statsB), + studentDF(statsA, statsB), + tTester.get.homoscedasticT(statsA, statsB), + methodName, + nullHypothesis + ) + } +} + +/** + * Companion object holding supported [[StreamingTestMethod]] names and handles conversion between + * strings used in [[StreamingTest]] configuration and actual method implementation. + * + * Currently supported tests: `welch`, `student`. + */ +private[stat] object StreamingTestMethod { + // Note: after new `StreamingTestMethod`s are implemented, please update this map. 
+ private final val TEST_NAME_TO_OBJECT: Map[String, StreamingTestMethod] = Map( + "welch" -> WelchTTest, + "student" -> StudentTTest) + + def getTestMethodFromName(method: String): StreamingTestMethod = + TEST_NAME_TO_OBJECT.get(method) match { + case Some(test) => test + case None => + throw new IllegalArgumentException( + "Unrecognized method name. Supported streaming test methods: " + + TEST_NAME_TO_OBJECT.keys.mkString(", ")) + } +} + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala index f44be1370669..8a29fd39a910 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala @@ -17,36 +17,39 @@ package org.apache.spark.mllib.stat.test -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} /** - * :: Experimental :: * Trait for hypothesis test results. * @tparam DF Return type of `degreesOfFreedom`. */ -@Experimental +@Since("1.1.0") trait TestResult[DF] { /** * The probability of obtaining a test statistic result at least as extreme as the one that was * actually observed, assuming that the null hypothesis is true. */ + @Since("1.1.0") def pValue: Double /** * Returns the degree(s) of freedom of the hypothesis test. * Return type should be Number(e.g. Int, Double) or tuples of Numbers for toString compatibility. */ + @Since("1.1.0") def degreesOfFreedom: DF /** * Test statistic. */ + @Since("1.1.0") def statistic: Double /** * Null hypothesis of the test. */ + @Since("1.1.0") def nullHypothesis: String /** @@ -74,15 +77,14 @@ trait TestResult[DF] { } /** - * :: Experimental :: * Object containing the test results for the chi-squared hypothesis test. */ -@Experimental +@Since("1.1.0") class ChiSqTestResult private[stat] (override val pValue: Double, - override val degreesOfFreedom: Int, - override val statistic: Double, - val method: String, - override val nullHypothesis: String) extends TestResult[Int] { + @Since("1.1.0") override val degreesOfFreedom: Int, + @Since("1.1.0") override val statistic: Double, + @Since("1.1.0") val method: String, + @Since("1.1.0") override val nullHypothesis: String) extends TestResult[Int] { override def toString: String = { "Chi squared test summary:\n" + @@ -96,14 +98,38 @@ class ChiSqTestResult private[stat] (override val pValue: Double, * Object containing the test results for the Kolmogorov-Smirnov test. */ @Experimental +@Since("1.5.0") class KolmogorovSmirnovTestResult private[stat] ( - override val pValue: Double, - override val statistic: Double, - override val nullHypothesis: String) extends TestResult[Int] { + @Since("1.5.0") override val pValue: Double, + @Since("1.5.0") override val statistic: Double, + @Since("1.5.0") override val nullHypothesis: String) extends TestResult[Int] { + @Since("1.5.0") override val degreesOfFreedom = 0 override def toString: String = { "Kolmogorov-Smirnov test summary:\n" + super.toString } } + +/** + * :: Experimental :: + * Object containing the test results for streaming testing. 
+ */ +@Experimental +@Since("1.6.0") +private[stat] class StreamingTestResult @Since("1.6.0") ( + @Since("1.6.0") override val pValue: Double, + @Since("1.6.0") override val degreesOfFreedom: Double, + @Since("1.6.0") override val statistic: Double, + @Since("1.6.0") val method: String, + @Since("1.6.0") override val nullHypothesis: String) + extends TestResult[Double] with Serializable { + + override def toString: String = { + "Streaming test summary:\n" + + s"method: $method\n" + + super.toString + } +} + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index cecd1fed896d..af1f7e74c004 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -19,10 +19,9 @@ package org.apache.spark.mllib.tree import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.collection.mutable.ArrayBuilder import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo @@ -37,15 +36,15 @@ import org.apache.spark.rdd.RDD import org.apache.spark.util.random.XORShiftRandom /** - * :: Experimental :: * A class which implements a decision tree learning algorithm for classification and regression. * It supports both continuous and categorical features. * @param strategy The configuration parameters for the tree algorithm which specify the type * of algorithm (classification, regression, etc.), feature type (continuous, * categorical), depth of the tree, quantile calculation strategy, etc. */ -@Experimental -class DecisionTree (private val strategy: Strategy) extends Serializable with Logging { +@Since("1.0.0") +class DecisionTree @Since("1.0.0") (private val strategy: Strategy) + extends Serializable with Logging { strategy.assertValid() @@ -54,6 +53,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] * @return DecisionTreeModel that can be used for prediction */ + @Since("1.2.0") def run(input: RDD[LabeledPoint]): DecisionTreeModel = { // Note: random seed will not be used since numTrees = 1. val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0) @@ -62,6 +62,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo } } +@Since("1.0.0") object DecisionTree extends Serializable with Logging { /** @@ -79,7 +80,8 @@ object DecisionTree extends Serializable with Logging { * of algorithm (classification, regression, etc.), feature type (continuous, * categorical), depth of the tree, quantile calculation strategy, etc. * @return DecisionTreeModel that can be used for prediction - */ + */ + @Since("1.0.0") def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = { new DecisionTree(strategy).run(input) } @@ -101,6 +103,7 @@ object DecisionTree extends Serializable with Logging { * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. 
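For the convenience methods annotated in these DecisionTree hunks, a minimal classification sketch (toy data, assuming sc; trainRegressor is analogous, with "variance" as the impurity):

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree

val data = sc.parallelize(Seq(
  LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
  LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
  LabeledPoint(1.0, Vectors.dense(1.0, 1.0))))

val model = DecisionTree.trainClassifier(
  data,
  numClasses = 2,
  categoricalFeaturesInfo = Map[Int, Int](),  // empty map: all features continuous
  impurity = "gini",
  maxDepth = 5,
  maxBins = 32)

println(model.predict(Vectors.dense(1.0, 0.5)))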
* @return DecisionTreeModel that can be used for prediction */ + @Since("1.0.0") def train( input: RDD[LabeledPoint], algo: Algo, @@ -128,6 +131,7 @@ object DecisionTree extends Serializable with Logging { * @param numClasses number of classes for classification. Default value of 2. * @return DecisionTreeModel that can be used for prediction */ + @Since("1.2.0") def train( input: RDD[LabeledPoint], algo: Algo, @@ -161,6 +165,7 @@ object DecisionTree extends Serializable with Logging { * with k categories indexed from 0: {0, 1, ..., k-1}. * @return DecisionTreeModel that can be used for prediction */ + @Since("1.0.0") def train( input: RDD[LabeledPoint], algo: Algo, @@ -193,6 +198,7 @@ object DecisionTree extends Serializable with Logging { * (suggested value: 32) * @return DecisionTreeModel that can be used for prediction */ + @Since("1.1.0") def trainClassifier( input: RDD[LabeledPoint], numClasses: Int, @@ -208,6 +214,7 @@ object DecisionTree extends Serializable with Logging { /** * Java-friendly API for [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] */ + @Since("1.1.0") def trainClassifier( input: JavaRDD[LabeledPoint], numClasses: Int, @@ -237,6 +244,7 @@ object DecisionTree extends Serializable with Logging { * (suggested value: 32) * @return DecisionTreeModel that can be used for prediction */ + @Since("1.1.0") def trainRegressor( input: RDD[LabeledPoint], categoricalFeaturesInfo: Map[Int, Int], @@ -250,6 +258,7 @@ object DecisionTree extends Serializable with Logging { /** * Java-friendly API for [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] */ + @Since("1.1.0") def trainRegressor( input: JavaRDD[LabeledPoint], categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer], @@ -631,8 +640,8 @@ object DecisionTree extends Serializable with Logging { val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b)) .map { case (nodeIndex, aggStats) => - val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures => - Some(nodeToFeatures(nodeIndex)) + val featuresForNode = nodeToFeaturesBc.value.map { nodeToFeatures => + nodeToFeatures(nodeIndex) } // find best split for each node @@ -964,8 +973,8 @@ object DecisionTree extends Serializable with Logging { val numFeatures = metadata.numFeatures // Sample the input only if there are continuous features. - val hasContinuousFeatures = Range(0, numFeatures).exists(metadata.isContinuous) - val sampledInput = if (hasContinuousFeatures) { + val continuousFeatures = Range(0, numFeatures).filter(metadata.isContinuous) + val sampledInput = if (continuousFeatures.nonEmpty) { // Calculate the number of samples for approximate quantile calculation. val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000) val fraction = if (requiredSamples < metadata.numExamples) { @@ -974,81 +983,14 @@ object DecisionTree extends Serializable with Logging { 1.0 } logDebug("fraction of data used for calculating quantiles = " + fraction) - input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()).collect() + input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()) } else { - new Array[LabeledPoint](0) + input.sparkContext.emptyRDD[LabeledPoint] } metadata.quantileStrategy match { case Sort => - val splits = new Array[Array[Split]](numFeatures) - val bins = new Array[Array[Bin]](numFeatures) - - // Find all splits. - // Iterate over all features. 
- var featureIndex = 0 - while (featureIndex < numFeatures) { - if (metadata.isContinuous(featureIndex)) { - val featureSamples = sampledInput.map(lp => lp.features(featureIndex)) - val featureSplits = findSplitsForContinuousFeature(featureSamples, - metadata, featureIndex) - - val numSplits = featureSplits.length - val numBins = numSplits + 1 - logDebug(s"featureIndex = $featureIndex, numSplits = $numSplits") - splits(featureIndex) = new Array[Split](numSplits) - bins(featureIndex) = new Array[Bin](numBins) - - var splitIndex = 0 - while (splitIndex < numSplits) { - val threshold = featureSplits(splitIndex) - splits(featureIndex)(splitIndex) = - new Split(featureIndex, threshold, Continuous, List()) - splitIndex += 1 - } - bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous), - splits(featureIndex)(0), Continuous, Double.MinValue) - - splitIndex = 1 - while (splitIndex < numSplits) { - bins(featureIndex)(splitIndex) = - new Bin(splits(featureIndex)(splitIndex - 1), splits(featureIndex)(splitIndex), - Continuous, Double.MinValue) - splitIndex += 1 - } - bins(featureIndex)(numSplits) = new Bin(splits(featureIndex)(numSplits - 1), - new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue) - } else { - val numSplits = metadata.numSplits(featureIndex) - val numBins = metadata.numBins(featureIndex) - // Categorical feature - val featureArity = metadata.featureArity(featureIndex) - if (metadata.isUnordered(featureIndex)) { - // Unordered features - // 2^(maxFeatureValue - 1) - 1 combinations - splits(featureIndex) = new Array[Split](numSplits) - var splitIndex = 0 - while (splitIndex < numSplits) { - val categories: List[Double] = - extractMultiClassCategories(splitIndex + 1, featureArity) - splits(featureIndex)(splitIndex) = - new Split(featureIndex, Double.MinValue, Categorical, categories) - splitIndex += 1 - } - } else { - // Ordered features - // Bins correspond to feature values, so we do not need to compute splits or bins - // beforehand. Splits are constructed as needed during training. - splits(featureIndex) = new Array[Split](0) - } - // For ordered features, bins correspond to feature values. - // For unordered categorical features, there is no need to construct the bins. - // since there is a one-to-one correspondence between the splits and the bins. 
- bins(featureIndex) = new Array[Bin](0) - } - featureIndex += 1 - } - (splits, bins) + findSplitsBinsBySorting(sampledInput, metadata, continuousFeatures) case MinMax => throw new UnsupportedOperationException("minmax not supported yet.") case ApproxHist => @@ -1056,6 +998,82 @@ object DecisionTree extends Serializable with Logging { } } + private def findSplitsBinsBySorting( + input: RDD[LabeledPoint], + metadata: DecisionTreeMetadata, + continuousFeatures: IndexedSeq[Int]): (Array[Array[Split]], Array[Array[Bin]]) = { + def findSplits( + featureIndex: Int, + featureSamples: Iterable[Double]): (Int, (Array[Split], Array[Bin])) = { + val splits = { + val featureSplits = findSplitsForContinuousFeature( + featureSamples.toArray, + metadata, + featureIndex) + logDebug(s"featureIndex = $featureIndex, numSplits = ${featureSplits.length}") + + featureSplits.map(threshold => new Split(featureIndex, threshold, Continuous, Nil)) + } + + val bins = { + val lowSplit = new DummyLowSplit(featureIndex, Continuous) + val highSplit = new DummyHighSplit(featureIndex, Continuous) + + // tack the dummy splits on either side of the computed splits + val allSplits = lowSplit +: splits.toSeq :+ highSplit + + // slide across the split points pairwise to allocate the bins + allSplits.sliding(2).map { + case Seq(left, right) => new Bin(left, right, Continuous, Double.MinValue) + }.toArray + } + + (featureIndex, (splits, bins)) + } + + val continuousSplits = { + // reduce the parallelism for split computations when there are less + // continuous features than input partitions. this prevents tasks from + // being spun up that will definitely do no work. + val numPartitions = math.min(continuousFeatures.length, input.partitions.length) + + input + .flatMap(point => continuousFeatures.map(idx => (idx, point.features(idx)))) + .groupByKey(numPartitions) + .map { case (k, v) => findSplits(k, v) } + .collectAsMap() + } + + val numFeatures = metadata.numFeatures + val (splits, bins) = Range(0, numFeatures).unzip { + case i if metadata.isContinuous(i) => + val (split, bin) = continuousSplits(i) + metadata.setNumSplits(i, split.length) + (split, bin) + + case i if metadata.isCategorical(i) && metadata.isUnordered(i) => + // Unordered features + // 2^(maxFeatureValue - 1) - 1 combinations + val featureArity = metadata.featureArity(i) + val split = Range(0, metadata.numSplits(i)).map { splitIndex => + val categories = extractMultiClassCategories(splitIndex + 1, featureArity) + new Split(i, Double.MinValue, Categorical, categories) + } + + // For unordered categorical features, there is no need to construct the bins. + // since there is a one-to-one correspondence between the splits and the bins. + (split.toArray, Array.empty[Bin]) + + case i if metadata.isCategorical(i) => + // Ordered features + // Bins correspond to feature values, so we do not need to compute splits or bins + // beforehand. Splits are constructed as needed during training. + (Array.empty[Split], Array.empty[Bin]) + } + + (splits.toArray, bins.toArray) + } + /** * Nested method to extract list of eligible categories given an index. It extracts the * position of ones in a binary representation of the input. 
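The bin construction in findSplitsBinsBySorting above slides pairwise over the dummy-padded split points; a plain-Scala illustration with hypothetical thresholds:

    // Hypothetical sorted thresholds for one continuous feature.
    val thresholds = Seq(0.5, 1.5, 2.5)
    // Pad with low/high sentinels, then pair adjacent boundaries into bins.
    val boundaries = Double.NegativeInfinity +: thresholds :+ Double.PositiveInfinity
    val bins = boundaries.sliding(2).map { case Seq(lo, hi) => (lo, hi) }.toArray
    // bins: Array((-Infinity, 0.5), (0.5, 1.5), (1.5, 2.5), (2.5, Infinity))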
If binary @@ -1119,7 +1137,7 @@ object DecisionTree extends Serializable with Logging { logDebug("stride = " + stride) // iterate `valueCount` to find splits - val splitsBuilder = ArrayBuilder.make[Double] + val splitsBuilder = Array.newBuilder[Double] var index = 1 // currentCount: sum of counts of values that have been visited var currentCount = valueCounts(0)._2 @@ -1151,8 +1169,8 @@ object DecisionTree extends Serializable with Logging { assert(splits.length > 0, s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." + " Please remove this feature and then try again.") - // set number of splits accordingly - metadata.setNumSplits(featureIndex, splits.length) + + // the split metadata must be updated on the driver splits } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index 9ce6faa137c4..729a21157482 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.tree import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer import org.apache.spark.mllib.regression.LabeledPoint @@ -31,7 +31,6 @@ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel /** - * :: Experimental :: * A class that implements * [[http://en.wikipedia.org/wiki/Gradient_boosting Stochastic Gradient Boosting]] * for regression and binary classification. @@ -49,8 +48,8 @@ import org.apache.spark.storage.StorageLevel * * @param boostingStrategy Parameters for the gradient boosting algorithm. */ -@Experimental -class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) +@Since("1.2.0") +class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: BoostingStrategy) extends Serializable with Logging { /** @@ -58,6 +57,7 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * @return a gradient boosted trees model that can be used for prediction */ + @Since("1.2.0") def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = { val algo = boostingStrategy.treeStrategy.algo algo match { @@ -75,6 +75,7 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) /** * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#run]]. */ + @Since("1.2.0") def run(input: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = { run(input.rdd) } @@ -89,6 +90,7 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) * by using [[org.apache.spark.rdd.RDD.randomSplit()]] * @return a gradient boosted trees model that can be used for prediction */ + @Since("1.4.0") def runWithValidation( input: RDD[LabeledPoint], validationInput: RDD[LabeledPoint]): GradientBoostedTreesModel = { @@ -112,6 +114,7 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) /** * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#runWithValidation]]. 
*/ + @Since("1.4.0") def runWithValidation( input: JavaRDD[LabeledPoint], validationInput: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = { @@ -119,6 +122,7 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) } } +@Since("1.2.0") object GradientBoostedTrees extends Logging { /** @@ -130,6 +134,7 @@ object GradientBoostedTrees extends Logging { * @param boostingStrategy Configuration options for the boosting algorithm. * @return a gradient boosted trees model that can be used for prediction */ + @Since("1.2.0") def train( input: RDD[LabeledPoint], boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = { @@ -139,6 +144,7 @@ object GradientBoostedTrees extends Logging { /** * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees$#train]] */ + @Since("1.2.0") def train( input: JavaRDD[LabeledPoint], boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = { @@ -254,7 +260,8 @@ object GradientBoostedTrees extends Logging { validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss) validatePredErrorCheckpointer.update(validatePredError) val currentValidateError = validatePredError.values.mean() - if (bestValidateError - currentValidateError < validationTol) { + if (bestValidateError - currentValidateError < validationTol * Math.max( + currentValidateError, 0.01)) { doneLearning = true } else if (currentValidateError < bestValidateError) { bestValidateError = currentValidateError diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index 069959976a18..a684cdd18c2f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -23,7 +23,7 @@ import scala.collection.mutable import scala.collection.JavaConverters._ import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Strategy @@ -39,7 +39,6 @@ import org.apache.spark.util.Utils import org.apache.spark.util.random.SamplingUtils /** - * :: Experimental :: * A class that implements a [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] * learning algorithm for classification and regression. * It supports both continuous and categorical features. @@ -66,7 +65,6 @@ import org.apache.spark.util.random.SamplingUtils * to "onethird" for regression. * @param seed Random seed for bootstrapping and choosing feature subsets. */ -@Experimental private class RandomForest ( private val strategy: Strategy, private val numTrees: Int, @@ -260,6 +258,7 @@ private class RandomForest ( } +@Since("1.2.0") object RandomForest extends Serializable with Logging { /** @@ -277,6 +276,7 @@ object RandomForest extends Serializable with Logging { * @param seed Random seed for bootstrapping and choosing feature subsets. * @return a random forest model that can be used for prediction */ + @Since("1.2.0") def trainClassifier( input: RDD[LabeledPoint], strategy: Strategy, @@ -314,6 +314,7 @@ object RandomForest extends Serializable with Logging { * @param seed Random seed for bootstrapping and choosing feature subsets. 
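A usage sketch for runWithValidation and the new default validationTol (assuming `training` and `validation` are existing RDD[LabeledPoint] splits):

    import org.apache.spark.mllib.tree.GradientBoostedTrees
    import org.apache.spark.mllib.tree.configuration.BoostingStrategy

    val boostingStrategy = BoostingStrategy.defaultParams("Regression")
    boostingStrategy.numIterations = 100
    boostingStrategy.validationTol = 0.001  // relative tolerance once validation loss exceeds 0.01

    // Stops early when the improvement on `validation` falls below the tolerance.
    val gbtModel = new GradientBoostedTrees(boostingStrategy).runWithValidation(training, validation)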
* @return a random forest model that can be used for prediction */ + @Since("1.2.0") def trainClassifier( input: RDD[LabeledPoint], numClasses: Int, @@ -333,6 +334,7 @@ object RandomForest extends Serializable with Logging { /** * Java-friendly API for [[org.apache.spark.mllib.tree.RandomForest$#trainClassifier]] */ + @Since("1.2.0") def trainClassifier( input: JavaRDD[LabeledPoint], numClasses: Int, @@ -363,6 +365,7 @@ object RandomForest extends Serializable with Logging { * @param seed Random seed for bootstrapping and choosing feature subsets. * @return a random forest model that can be used for prediction */ + @Since("1.2.0") def trainRegressor( input: RDD[LabeledPoint], strategy: Strategy, @@ -399,6 +402,7 @@ object RandomForest extends Serializable with Logging { * @param seed Random seed for bootstrapping and choosing feature subsets. * @return a random forest model that can be used for prediction */ + @Since("1.2.0") def trainRegressor( input: RDD[LabeledPoint], categoricalFeaturesInfo: Map[Int, Int], @@ -417,6 +421,7 @@ object RandomForest extends Serializable with Logging { /** * Java-friendly API for [[org.apache.spark.mllib.tree.RandomForest$#trainRegressor]] */ + @Since("1.2.0") def trainRegressor( input: JavaRDD[LabeledPoint], categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer], @@ -434,6 +439,7 @@ object RandomForest extends Serializable with Logging { /** * List of supported feature subset sampling strategies. */ + @Since("1.2.0") val supportedFeatureSubsetStrategies: Array[String] = Array("auto", "all", "sqrt", "log2", "onethird") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala index b6099259971b..853c7319ec44 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala @@ -17,15 +17,18 @@ package org.apache.spark.mllib.tree.configuration -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} /** * :: Experimental :: * Enum to select the algorithm for the decision tree */ +@Since("1.0.0") @Experimental object Algo extends Enumeration { + @Since("1.0.0") type Algo = Value + @Since("1.0.0") val Classification, Regression = Value private[mllib] def fromString(name: String): Algo = name match { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index 9fd30c9b5631..d2513a9d5c5b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -19,12 +19,11 @@ package org.apache.spark.mllib.tree.configuration import scala.beans.BeanProperty -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.Since import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} /** - * :: Experimental :: * Configuration options for [[org.apache.spark.mllib.tree.GradientBoostedTrees]]. * * @param treeStrategy Parameters for the tree algorithm. We support regression and binary @@ -34,20 +33,27 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} * weak hypotheses used in the final model. 
* @param learningRate Learning rate for shrinking the contribution of each estimator. The * learning rate should be between in the interval (0, 1] - * @param validationTol Useful when runWithValidation is used. If the error rate on the - * validation input between two iterations is less than the validationTol - * then stop. Ignored when + * @param validationTol validationTol is a condition which decides iteration termination when + * runWithValidation is used. + * The end of iteration is decided based on below logic: + * If the current loss on the validation set is > 0.01, the diff + * of validation error is compared to relative tolerance which is + * validationTol * (current loss on the validation set). + * If the current loss on the validation set is <= 0.01, the diff + * of validation error is compared to absolute tolerance which is + * validationTol * 0.01. + * Ignored when * [[org.apache.spark.mllib.tree.GradientBoostedTrees.run()]] is used. */ -@Experimental -case class BoostingStrategy( +@Since("1.2.0") +case class BoostingStrategy @Since("1.4.0") ( // Required boosting parameters - @BeanProperty var treeStrategy: Strategy, - @BeanProperty var loss: Loss, + @Since("1.2.0") @BeanProperty var treeStrategy: Strategy, + @Since("1.2.0") @BeanProperty var loss: Loss, // Optional boosting parameters - @BeanProperty var numIterations: Int = 100, - @BeanProperty var learningRate: Double = 0.1, - @BeanProperty var validationTol: Double = 1e-5) extends Serializable { + @Since("1.2.0") @BeanProperty var numIterations: Int = 100, + @Since("1.2.0") @BeanProperty var learningRate: Double = 0.1, + @Since("1.4.0") @BeanProperty var validationTol: Double = 0.001) extends Serializable { /** * Check validity of parameters. @@ -70,7 +76,7 @@ case class BoostingStrategy( } } -@Experimental +@Since("1.2.0") object BoostingStrategy { /** @@ -78,6 +84,7 @@ object BoostingStrategy { * @param algo Learning goal. 
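The stopping rule described above can be restated compactly (this mirrors the comparison used in GradientBoostedTrees, not the full boosting loop):

    def shouldStop(bestError: Double, currentError: Double, validationTol: Double): Boolean =
      bestError - currentError < validationTol * math.max(currentError, 0.01)

    // Hypothetical values: an improvement of 1e-4 on a loss near 0.2 is below 0.001 * 0.2.
    shouldStop(bestError = 0.20, currentError = 0.1999, validationTol = 0.001)  // true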
Supported: "Classification" or "Regression" * @return Configuration for boosting algorithm */ + @Since("1.2.0") def defaultParams(algo: String): BoostingStrategy = { defaultParams(Algo.fromString(algo)) } @@ -89,8 +96,9 @@ object BoostingStrategy { * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] * @return Configuration for boosting algorithm */ + @Since("1.3.0") def defaultParams(algo: Algo): BoostingStrategy = { - val treeStrategy = Strategy.defaultStategy(algo) + val treeStrategy = Strategy.defaultStrategy(algo) treeStrategy.maxDepth = 3 algo match { case Algo.Classification => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala index f4c877232750..1470295d8a93 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala @@ -17,14 +17,15 @@ package org.apache.spark.mllib.tree.configuration -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.Since /** - * :: Experimental :: * Enum to describe whether a feature is "continuous" or "categorical" */ -@Experimental +@Since("1.0.0") object FeatureType extends Enumeration { + @Since("1.0.0") type FeatureType = Value + @Since("1.0.0") val Continuous, Categorical = Value } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala index 7da976e55a72..1c16f136eb3e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala @@ -17,14 +17,15 @@ package org.apache.spark.mllib.tree.configuration -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.Since /** - * :: Experimental :: * Enum for selecting the quantile calculation strategy */ -@Experimental +@Since("1.0.0") object QuantileStrategy extends Enumeration { + @Since("1.0.0") type QuantileStrategy = Value + @Since("1.0.0") val Sort, MinMax, ApproxHist = Value } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index ada227c200a7..372d6617a401 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -20,13 +20,12 @@ package org.apache.spark.mllib.tree.configuration import scala.beans.BeanProperty import scala.collection.JavaConverters._ -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.Since import org.apache.spark.mllib.tree.impurity.{Variance, Entropy, Gini, Impurity} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ /** - * :: Experimental :: * Stores all the configuration options for tree construction * @param algo Learning goal. Supported: * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]], @@ -67,26 +66,32 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * the checkpoint directory is not set in * [[org.apache.spark.SparkContext]], this setting is ignored. 
*/ -@Experimental -class Strategy ( - @BeanProperty var algo: Algo, - @BeanProperty var impurity: Impurity, - @BeanProperty var maxDepth: Int, - @BeanProperty var numClasses: Int = 2, - @BeanProperty var maxBins: Int = 32, - @BeanProperty var quantileCalculationStrategy: QuantileStrategy = Sort, - @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), - @BeanProperty var minInstancesPerNode: Int = 1, - @BeanProperty var minInfoGain: Double = 0.0, - @BeanProperty var maxMemoryInMB: Int = 256, - @BeanProperty var subsamplingRate: Double = 1, - @BeanProperty var useNodeIdCache: Boolean = false, - @BeanProperty var checkpointInterval: Int = 10) extends Serializable { +@Since("1.0.0") +class Strategy @Since("1.3.0") ( + @Since("1.0.0") @BeanProperty var algo: Algo, + @Since("1.0.0") @BeanProperty var impurity: Impurity, + @Since("1.0.0") @BeanProperty var maxDepth: Int, + @Since("1.2.0") @BeanProperty var numClasses: Int = 2, + @Since("1.0.0") @BeanProperty var maxBins: Int = 32, + @Since("1.0.0") @BeanProperty var quantileCalculationStrategy: QuantileStrategy = Sort, + @Since("1.0.0") @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), + @Since("1.2.0") @BeanProperty var minInstancesPerNode: Int = 1, + @Since("1.2.0") @BeanProperty var minInfoGain: Double = 0.0, + @Since("1.0.0") @BeanProperty var maxMemoryInMB: Int = 256, + @Since("1.2.0") @BeanProperty var subsamplingRate: Double = 1, + @Since("1.2.0") @BeanProperty var useNodeIdCache: Boolean = false, + @Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10) extends Serializable { + /** + */ + @Since("1.2.0") def isMulticlassClassification: Boolean = { algo == Classification && numClasses > 2 } + /** + */ + @Since("1.2.0") def isMulticlassWithCategoricalFeatures: Boolean = { isMulticlassClassification && (categoricalFeaturesInfo.size > 0) } @@ -94,6 +99,7 @@ class Strategy ( /** * Java-friendly constructor for [[org.apache.spark.mllib.tree.configuration.Strategy]] */ + @Since("1.1.0") def this( algo: Algo, impurity: Impurity, @@ -108,6 +114,7 @@ class Strategy ( /** * Sets Algorithm using a String. */ + @Since("1.2.0") def setAlgo(algo: String): Unit = algo match { case "Classification" => setAlgo(Classification) case "Regression" => setAlgo(Regression) @@ -116,6 +123,7 @@ class Strategy ( /** * Sets categoricalFeaturesInfo using a Java Map. */ + @Since("1.2.0") def setCategoricalFeaturesInfo( categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]): Unit = { this.categoricalFeaturesInfo = @@ -148,11 +156,6 @@ class Strategy ( s" Valid values are integers >= 0.") require(maxBins >= 2, s"DecisionTree Strategy given invalid maxBins parameter: $maxBins." + s" Valid values are integers >= 2.") - categoricalFeaturesInfo.foreach { case (feature, arity) => - require(arity >= 2, - s"DecisionTree Strategy given invalid categoricalFeaturesInfo setting:" + - s" feature $feature has $arity categories. The number of categories should be >= 2.") - } require(minInstancesPerNode >= 1, s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode") require(maxMemoryInMB <= 10240, @@ -162,7 +165,10 @@ class Strategy ( s"$subsamplingRate") } - /** Returns a shallow copy of this instance. */ + /** + * Returns a shallow copy of this instance. 
+ */ + @Since("1.2.0") def copy: Strategy = { new Strategy(algo, impurity, maxDepth, numClasses, maxBins, quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain, @@ -170,22 +176,24 @@ class Strategy ( } } -@Experimental +@Since("1.2.0") object Strategy { /** * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] * @param algo "Classification" or "Regression" */ + @Since("1.2.0") def defaultStrategy(algo: String): Strategy = { - defaultStategy(Algo.fromString(algo)) + defaultStrategy(Algo.fromString(algo)) } /** * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] * @param algo Algo.Classification or Algo.Regression */ - def defaultStategy(algo: Algo): Strategy = algo match { + @Since("1.3.0") + def defaultStrategy(algo: Algo): Strategy = algo match { case Algo.Classification => new Strategy(algo = Classification, impurity = Gini, maxDepth = 10, numClasses = 2) @@ -193,4 +201,9 @@ object Strategy { new Strategy(algo = Regression, impurity = Variance, maxDepth = 10, numClasses = 0) } + + @deprecated("Use Strategy.defaultStrategy instead.", "1.5.0") + @Since("1.2.0") + def defaultStategy(algo: Algo): Strategy = defaultStrategy(algo) + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index 9fe264656ede..21ee49c45788 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -144,21 +144,28 @@ private[spark] object DecisionTreeMetadata extends Logging { val maxCategoriesForUnorderedFeature = ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) => - // Decide if some categorical features should be treated as unordered features, - // which require 2 * ((1 << numCategories - 1) - 1) bins. - // We do this check with log values to prevent overflows in case numCategories is large. - // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins - if (numCategories <= maxCategoriesForUnorderedFeature) { - unorderedFeatures.add(featureIndex) - numBins(featureIndex) = numUnorderedBins(numCategories) - } else { - numBins(featureIndex) = numCategories + // Hack: If a categorical feature has only 1 category, we treat it as continuous. + // TODO(SPARK-9957): Handle this properly by filtering out those features. + if (numCategories > 1) { + // Decide if some categorical features should be treated as unordered features, + // which require 2 * ((1 << numCategories - 1) - 1) bins. + // We do this check with log values to prevent overflows in case numCategories is large. 
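A usage sketch for the renamed defaultStrategy (the old defaultStategy spelling remains only as a deprecated alias):

    import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
    import org.apache.spark.mllib.tree.impurity.Entropy

    val strategy = Strategy.defaultStrategy(Algo.Classification)
    strategy.setMaxBins(64)                        // @BeanProperty setter, usable from Java
    strategy.impurity = Entropy                    // plain var access from Scala
    strategy.categoricalFeaturesInfo = Map(0 -> 3) // feature 0 has 3 categories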
+ // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins + if (numCategories <= maxCategoriesForUnorderedFeature) { + unorderedFeatures.add(featureIndex) + numBins(featureIndex) = numUnorderedBins(numCategories) + } else { + numBins(featureIndex) = numCategories + } } } } else { // Binary classification or regression strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) => - numBins(featureIndex) = numCategories + // If a categorical feature has only 1 category, we treat it as continuous: SPARK-9957 + if (numCategories > 1) { + numBins(featureIndex) = numCategories + } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala index 8f9eb24b57b5..1c611976a930 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala @@ -108,21 +108,21 @@ private[spark] class NodeIdCache( prevNodeIdsForInstances = nodeIdsForInstances nodeIdsForInstances = data.zip(nodeIdsForInstances).map { - dataPoint => { + case (point, node) => { var treeId = 0 while (treeId < nodeIdUpdaters.length) { - val nodeIdUpdater = nodeIdUpdaters(treeId).getOrElse(dataPoint._2(treeId), null) + val nodeIdUpdater = nodeIdUpdaters(treeId).getOrElse(node(treeId), null) if (nodeIdUpdater != null) { val newNodeIndex = nodeIdUpdater.updateNodeIndex( - binnedFeatures = dataPoint._1.datum.binnedFeatures, + binnedFeatures = point.datum.binnedFeatures, bins = bins) - dataPoint._2(treeId) = newNodeIndex + node(treeId) = newNodeIndex } treeId += 1 } - dataPoint._2 + node } } @@ -138,7 +138,7 @@ private[spark] class NodeIdCache( while (checkpointQueue.size > 1 && canDelete) { // We can delete the oldest checkpoint iff // the next checkpoint actually exists in the file system. - if (checkpointQueue.get(1).get.getCheckpointFile != None) { + if (checkpointQueue.get(1).get.getCheckpointFile.isDefined) { val old = checkpointQueue.dequeue() // Since the old checkpoint is not deleted by Spark, @@ -159,13 +159,17 @@ private[spark] class NodeIdCache( * Call this after training is finished to delete any remaining checkpoints. */ def deleteAllCheckpoints(): Unit = { - while (checkpointQueue.size > 0) { + while (checkpointQueue.nonEmpty) { val old = checkpointQueue.dequeue() - if (old.getCheckpointFile != None) { + for (checkpointFile <- old.getCheckpointFile) { val fs = FileSystem.get(old.sparkContext.hadoopConfiguration) - fs.delete(new Path(old.getCheckpointFile.get), true) + fs.delete(new Path(checkpointFile), true) } } + if (prevNodeIdsForInstances != null) { + // Unpersist the previous one if one exists. + prevNodeIdsForInstances.unpersist() + } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala index aac84243d5ce..70afaa162b2e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala @@ -19,12 +19,9 @@ package org.apache.spark.mllib.tree.impl import scala.collection.mutable.{HashMap => MutableHashMap} -import org.apache.spark.annotation.Experimental - /** * Time tracker implementation which holds labeled timers. 
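To make the unordered-feature guard above concrete, a standalone sketch of the arithmetic (numUnorderedBins here follows the 2 * ((1 << numCategories - 1) - 1) formula quoted in the comment):

    def maxCategoriesForUnordered(maxPossibleBins: Int): Int =
      ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt

    def numUnorderedBins(arity: Int): Int = 2 * ((1 << (arity - 1)) - 1)

    maxCategoriesForUnordered(32)  // 5: with maxBins = 32, arity <= 5 can be treated as unordered
    numUnorderedBins(5)            // 30 bins, which still fits within maxBins = 32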
*/ -@Experimental private[spark] class TimeTracker extends Serializable { private val starts: MutableHashMap[String, Long] = new MutableHashMap[String, Long]() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 0768204c3391..73df6b054a8c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -17,13 +17,14 @@ package org.apache.spark.mllib.tree.impurity -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} /** * :: Experimental :: * Class for calculating [[http://en.wikipedia.org/wiki/Binary_entropy_function entropy]] during * binary classification. */ +@Since("1.0.0") @Experimental object Entropy extends Impurity { @@ -36,6 +37,7 @@ object Entropy extends Impurity { * @param totalCount sum of counts for all labels * @return information value, or 0 if totalCount = 0 */ + @Since("1.1.0") @DeveloperApi override def calculate(counts: Array[Double], totalCount: Double): Double = { if (totalCount == 0) { @@ -63,6 +65,7 @@ object Entropy extends Impurity { * @param sumSquares summation of squares of the labels * @return information value, or 0 if count = 0 */ + @Since("1.0.0") @DeveloperApi override def calculate(count: Double, sum: Double, sumSquares: Double): Double = throw new UnsupportedOperationException("Entropy.calculate") @@ -71,6 +74,7 @@ object Entropy extends Impurity { * Get this impurity instance. * This is useful for passing impurity parameters to a Strategy in Java. */ + @Since("1.1.0") def instance: this.type = this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index d0077db6832e..f21845b21a80 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.tree.impurity -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} /** * :: Experimental :: @@ -25,6 +25,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} * [[http://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity Gini impurity]] * during binary classification. */ +@Since("1.0.0") @Experimental object Gini extends Impurity { @@ -35,6 +36,7 @@ object Gini extends Impurity { * @param totalCount sum of counts for all labels * @return information value, or 0 if totalCount = 0 */ + @Since("1.1.0") @DeveloperApi override def calculate(counts: Array[Double], totalCount: Double): Double = { if (totalCount == 0) { @@ -59,6 +61,7 @@ object Gini extends Impurity { * @param sumSquares summation of squares of the labels * @return information value, or 0 if count = 0 */ + @Since("1.0.0") @DeveloperApi override def calculate(count: Double, sum: Double, sumSquares: Double): Double = throw new UnsupportedOperationException("Gini.calculate") @@ -67,6 +70,7 @@ object Gini extends Impurity { * Get this impurity instance. * This is useful for passing impurity parameters to a Strategy in Java. 
*/ + @Since("1.1.0") def instance: this.type = this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 86cee7e430b0..4637dcceea7f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.tree.impurity -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} /** * :: Experimental :: @@ -26,6 +26,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} * (a) setting the impurity parameter in [[org.apache.spark.mllib.tree.configuration.Strategy]] * (b) calculating impurity values from sufficient statistics. */ +@Since("1.0.0") @Experimental trait Impurity extends Serializable { @@ -36,6 +37,7 @@ trait Impurity extends Serializable { * @param totalCount sum of counts for all labels * @return information value, or 0 if totalCount = 0 */ + @Since("1.1.0") @DeveloperApi def calculate(counts: Array[Double], totalCount: Double): Double @@ -47,6 +49,7 @@ trait Impurity extends Serializable { * @param sumSquares summation of squares of the labels * @return information value, or 0 if count = 0 */ + @Since("1.0.0") @DeveloperApi def calculate(count: Double, sum: Double, sumSquares: Double): Double } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 04d0cd24e663..a74197278d6f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -17,12 +17,13 @@ package org.apache.spark.mllib.tree.impurity -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} /** * :: Experimental :: * Class for calculating variance during regression */ +@Since("1.0.0") @Experimental object Variance extends Impurity { @@ -33,6 +34,7 @@ object Variance extends Impurity { * @param totalCount sum of counts for all labels * @return information value, or 0 if totalCount = 0 */ + @Since("1.1.0") @DeveloperApi override def calculate(counts: Array[Double], totalCount: Double): Double = throw new UnsupportedOperationException("Variance.calculate") @@ -45,6 +47,7 @@ object Variance extends Impurity { * @param sumSquares summation of squares of the labels * @return information value, or 0 if count = 0 */ + @Since("1.0.0") @DeveloperApi override def calculate(count: Double, sum: Double, sumSquares: Double): Double = { if (count == 0) { @@ -58,6 +61,7 @@ object Variance extends Impurity { * Get this impurity instance. * This is useful for passing impurity parameters to a Strategy in Java. 
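For reference, the count-based calculate overloads above can be exercised directly; with label counts (30, 10):

    import org.apache.spark.mllib.tree.impurity.{Entropy, Gini}

    val counts = Array(30.0, 10.0)   // per-label counts at a node
    val total = counts.sum

    Gini.calculate(counts, total)     // 1 - (0.75^2 + 0.25^2) = 0.375
    Entropy.calculate(counts, total)  // -0.75*log2(0.75) - 0.25*log2(0.25), about 0.811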
*/ + @Since("1.0.0") def instance: this.type = this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala index 2bdef73c4a8f..bab7b8c6cadf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.tree.loss -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.TreeEnsembleModel @@ -30,6 +30,7 @@ import org.apache.spark.mllib.tree.model.TreeEnsembleModel * |y - F(x)| * where y is the label and F(x) is the model prediction for features x. */ +@Since("1.2.0") @DeveloperApi object AbsoluteError extends Loss { @@ -41,6 +42,7 @@ object AbsoluteError extends Loss { * @param label True label. * @return Loss gradient */ + @Since("1.2.0") override def gradient(prediction: Double, label: Double): Double = { if (label - prediction < 0) 1.0 else -1.0 } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala index 778c24526de7..b2b4594712f0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.tree.loss -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.TreeEnsembleModel import org.apache.spark.mllib.util.MLUtils @@ -32,6 +32,7 @@ import org.apache.spark.mllib.util.MLUtils * 2 log(1 + exp(-2 y F(x))) * where y is a label in {-1, 1} and F(x) is the model prediction for features x. */ +@Since("1.2.0") @DeveloperApi object LogLoss extends Loss { @@ -43,6 +44,7 @@ object LogLoss extends Loss { * @param label True label. * @return Loss gradient */ + @Since("1.2.0") override def gradient(prediction: Double, label: Double): Double = { - 4.0 * label / (1.0 + math.exp(2.0 * label * prediction)) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala index 64ffccbce073..687cde325ffe 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.tree.loss -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.TreeEnsembleModel import org.apache.spark.rdd.RDD @@ -27,6 +27,7 @@ import org.apache.spark.rdd.RDD * :: DeveloperApi :: * Trait for adding "pluggable" loss functions for the gradient boosting algorithm. */ +@Since("1.2.0") @DeveloperApi trait Loss extends Serializable { @@ -36,6 +37,7 @@ trait Loss extends Serializable { * @param label true label. * @return Loss gradient. */ + @Since("1.2.0") def gradient(prediction: Double, label: Double): Double /** @@ -46,6 +48,7 @@ trait Loss extends Serializable { * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. 
* @return Measure of model error on data */ + @Since("1.2.0") def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = { data.map(point => computeError(model.predict(point.features), point.label)).mean() } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala index 42c9ead9884b..2b112fbe1220 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala @@ -17,8 +17,12 @@ package org.apache.spark.mllib.tree.loss +import org.apache.spark.annotation.Since + +@Since("1.2.0") object Losses { + @Since("1.2.0") def fromString(name: String): Loss = name match { case "leastSquaresError" => SquaredError case "leastAbsoluteError" => AbsoluteError diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala index 011a5d57422f..3f7d3d38be16 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.tree.loss -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.TreeEnsembleModel @@ -30,6 +30,7 @@ import org.apache.spark.mllib.tree.model.TreeEnsembleModel * (y - F(x))**2 * where y is the label and F(x) is the model prediction for features x. */ +@Since("1.2.0") @DeveloperApi object SquaredError extends Loss { @@ -41,6 +42,7 @@ object SquaredError extends Loss { * @param label True label. * @return Loss gradient */ + @Since("1.2.0") override def gradient(prediction: Double, label: Double): Double = { - 2.0 * (label - prediction) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index f2c78bbabff0..89c470d57343 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -24,7 +24,7 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, SparkContext} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.tree.configuration.{Algo, FeatureType} @@ -35,14 +35,15 @@ import org.apache.spark.sql.{DataFrame, Row, SQLContext} import org.apache.spark.util.Utils /** - * :: Experimental :: * Decision tree model for classification or regression. * This model stores the decision tree structure and parameters. * @param topNode root node * @param algo algorithm type -- classification or regression */ -@Experimental -class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable with Saveable { +@Since("1.0.0") +class DecisionTreeModel @Since("1.0.0") ( + @Since("1.0.0") val topNode: Node, + @Since("1.0.0") val algo: Algo) extends Serializable with Saveable { /** * Predict values for a single data point using the model trained. 
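A quick check of the loss objects above (the gradient values follow the formulas in their scaladoc):

    import org.apache.spark.mllib.tree.loss.{AbsoluteError, Losses, SquaredError}

    Losses.fromString("leastSquaresError")                  // resolves to SquaredError
    SquaredError.gradient(prediction = 1.5, label = 1.0)    // -2 * (1.0 - 1.5) = 1.0
    AbsoluteError.gradient(prediction = 1.5, label = 1.0)   // 1.0, since label - prediction < 0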
@@ -50,6 +51,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable * @param features array representing a single data point * @return Double prediction from the trained model */ + @Since("1.0.0") def predict(features: Vector): Double = { topNode.predict(features) } @@ -60,6 +62,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable * @param features RDD representing data points to be predicted * @return RDD of predictions for each of the given data points */ + @Since("1.0.0") def predict(features: RDD[Vector]): RDD[Double] = { features.map(x => predict(x)) } @@ -70,6 +73,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable * @param features JavaRDD representing data points to be predicted * @return JavaRDD of predictions for each of the given data points */ + @Since("1.2.0") def predict(features: JavaRDD[Vector]): JavaRDD[Double] = { predict(features.rdd) } @@ -77,6 +81,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable /** * Get number of nodes in tree, including leaf nodes. */ + @Since("1.1.0") def numNodes: Int = { 1 + topNode.numDescendants } @@ -85,6 +90,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable * Get depth of tree. * E.g.: Depth 0 means 1 leaf node. Depth 1 means 1 internal node and 2 leaf nodes. */ + @Since("1.1.0") def depth: Int = { topNode.subtreeDepth } @@ -104,11 +110,18 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable /** * Print the full model to a string. */ + @Since("1.2.0") def toDebugString: String = { val header = toString + "\n" header + topNode.subtreeToString(2) } + /** + * @param sc Spark context used to save model data. + * @param path Path specifying the directory in which to save this model. + * If the directory already exists, this method throws an exception. + */ + @Since("1.3.0") override def save(sc: SparkContext, path: String): Unit = { DecisionTreeModel.SaveLoadV1_0.save(sc, path, this) } @@ -116,6 +129,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable override protected def formatVersion: String = DecisionTreeModel.formatVersion } +@Since("1.3.0") object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { private[spark] def formatVersion: String = "1.0" @@ -187,7 +201,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { } def save(sc: SparkContext, path: String, model: DecisionTreeModel): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // SPARK-6120: We do a hacky check here so users understand why save() is failing @@ -228,7 +242,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { def load(sc: SparkContext, path: String, algo: String, numNodes: Int): DecisionTreeModel = { val datapath = Loader.dataPath(path) - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) // Load Parquet data. val dataRDD = sqlContext.read.parquet(datapath) // Check schema explicitly since erasure makes it hard to use match-case for checking. @@ -297,6 +311,13 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { } } + /** + * + * @param sc Spark context used for loading model files. + * @param path Path specifying the directory to which the model was saved. 
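A save/load sketch for the annotated methods above (assuming `model` is a trained DecisionTreeModel, `sc` is the active SparkContext, and the output path is hypothetical and does not yet exist):

    import org.apache.spark.mllib.tree.model.DecisionTreeModel

    model.save(sc, "/tmp/myDecisionTreeModel")
    val reloaded = DecisionTreeModel.load(sc, "/tmp/myDecisionTreeModel")
    println(reloaded.toDebugString)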
+ * @return Model instance + */ + @Since("1.3.0") override def load(sc: SparkContext, path: String): DecisionTreeModel = { implicit val formats = DefaultFormats val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index 508bf9c1bdb4..091a0462c204 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.tree.model -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.tree.impurity.ImpurityCalculator /** @@ -30,6 +30,7 @@ import org.apache.spark.mllib.tree.impurity.ImpurityCalculator * @param leftPredict left node predict * @param rightPredict right node predict */ +@Since("1.0.0") @DeveloperApi class InformationGainStats( val gain: Double, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index a6d1398fc267..ea6e5aa5d94e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.tree.model -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.Logging import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.linalg.Vector @@ -39,16 +39,17 @@ import org.apache.spark.mllib.linalg.Vector * @param rightNode right child * @param stats information gain stats */ +@Since("1.0.0") @DeveloperApi -class Node ( - val id: Int, - var predict: Predict, - var impurity: Double, - var isLeaf: Boolean, - var split: Option[Split], - var leftNode: Option[Node], - var rightNode: Option[Node], - var stats: Option[InformationGainStats]) extends Serializable with Logging { +class Node @Since("1.2.0") ( + @Since("1.0.0") val id: Int, + @Since("1.0.0") var predict: Predict, + @Since("1.2.0") var impurity: Double, + @Since("1.0.0") var isLeaf: Boolean, + @Since("1.0.0") var split: Option[Split], + @Since("1.0.0") var leftNode: Option[Node], + @Since("1.0.0") var rightNode: Option[Node], + @Since("1.0.0") var stats: Option[InformationGainStats]) extends Serializable with Logging { override def toString: String = { s"id = $id, isLeaf = $isLeaf, predict = $predict, impurity = $impurity, " + @@ -59,6 +60,7 @@ class Node ( * build the left node and right nodes if not leaf * @param nodes array of nodes */ + @Since("1.0.0") @deprecated("build should no longer be used since trees are constructed on-the-fly in training", "1.2.0") def build(nodes: Array[Node]): Unit = { @@ -80,6 +82,7 @@ class Node ( * @param features feature value * @return predicted value */ + @Since("1.1.0") def predict(features: Vector) : Double = { if (isLeaf) { predict.predict diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala index 5cbe7c280dbe..06ceff19d863 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala @@ -17,17 +17,18 @@ package 
org.apache.spark.mllib.tree.model -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} /** * Predicted value for a node * @param predict predicted value * @param prob probability of the label (classification only) */ +@Since("1.2.0") @DeveloperApi -class Predict( - val predict: Double, - val prob: Double = 0.0) extends Serializable { +class Predict @Since("1.2.0") ( + @Since("1.2.0") val predict: Double, + @Since("1.2.0") val prob: Double = 0.0) extends Serializable { override def toString: String = s"$predict (prob = $prob)" diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala index be6c9b3de547..b85a66c05a81 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.tree.model -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType import org.apache.spark.mllib.tree.configuration.FeatureType import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType @@ -31,12 +31,13 @@ import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType * @param featureType type of feature -- categorical or continuous * @param categories Split left if categorical feature value is in this set, else right. */ +@Since("1.0.0") @DeveloperApi case class Split( - feature: Int, - threshold: Double, - featureType: FeatureType, - categories: List[Double]) { + @Since("1.0.0") feature: Int, + @Since("1.0.0") threshold: Double, + @Since("1.0.0") featureType: FeatureType, + @Since("1.0.0") categories: List[Double]) { override def toString: String = { s"Feature = $feature, threshold = $threshold, featureType = $featureType, " + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index 905c5fb42bd4..feabcee24fa2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -25,7 +25,7 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, SparkContext} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint @@ -38,22 +38,29 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext import org.apache.spark.util.Utils - /** - * :: Experimental :: * Represents a random forest model. 
* * @param algo algorithm for the ensemble model, either Classification or Regression * @param trees tree ensembles */ -@Experimental -class RandomForestModel(override val algo: Algo, override val trees: Array[DecisionTreeModel]) +@Since("1.2.0") +class RandomForestModel @Since("1.2.0") ( + @Since("1.2.0") override val algo: Algo, + @Since("1.2.0") override val trees: Array[DecisionTreeModel]) extends TreeEnsembleModel(algo, trees, Array.fill(trees.length)(1.0), combiningStrategy = if (algo == Classification) Vote else Average) with Saveable { require(trees.forall(_.algo == algo)) + /** + * + * @param sc Spark context used to save model data. + * @param path Path specifying the directory in which to save this model. + * If the directory already exists, this method throws an exception. + */ + @Since("1.3.0") override def save(sc: SparkContext, path: String): Unit = { TreeEnsembleModel.SaveLoadV1_0.save(sc, path, this, RandomForestModel.SaveLoadV1_0.thisClassName) @@ -62,10 +69,18 @@ class RandomForestModel(override val algo: Algo, override val trees: Array[Decis override protected def formatVersion: String = RandomForestModel.formatVersion } +@Since("1.3.0") object RandomForestModel extends Loader[RandomForestModel] { private[mllib] def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion + /** + * + * @param sc Spark context used for loading model files. + * @param path Path specifying the directory to which the model was saved. + * @return Model instance + */ + @Since("1.3.0") override def load(sc: SparkContext, path: String): RandomForestModel = { val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path) val classNameV1_0 = SaveLoadV1_0.thisClassName @@ -90,23 +105,28 @@ object RandomForestModel extends Loader[RandomForestModel] { } /** - * :: Experimental :: * Represents a gradient boosted trees model. * * @param algo algorithm for the ensemble model, either Classification or Regression * @param trees tree ensembles * @param treeWeights tree ensemble weights */ -@Experimental -class GradientBoostedTreesModel( - override val algo: Algo, - override val trees: Array[DecisionTreeModel], - override val treeWeights: Array[Double]) +@Since("1.2.0") +class GradientBoostedTreesModel @Since("1.2.0") ( + @Since("1.2.0") override val algo: Algo, + @Since("1.2.0") override val trees: Array[DecisionTreeModel], + @Since("1.2.0") override val treeWeights: Array[Double]) extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum) with Saveable { require(trees.length == treeWeights.length) + /** + * @param sc Spark context used to save model data. + * @param path Path specifying the directory in which to save this model. + * If the directory already exists, this method throws an exception. 
+ */ + @Since("1.3.0") override def save(sc: SparkContext, path: String): Unit = { TreeEnsembleModel.SaveLoadV1_0.save(sc, path, this, GradientBoostedTreesModel.SaveLoadV1_0.thisClassName) @@ -119,6 +139,7 @@ class GradientBoostedTreesModel( * @return an array with index i having the losses or errors for the ensemble * containing the first i+1 trees */ + @Since("1.4.0") def evaluateEachIteration( data: RDD[LabeledPoint], loss: Loss): Array[Double] = { @@ -159,9 +180,13 @@ class GradientBoostedTreesModel( override protected def formatVersion: String = GradientBoostedTreesModel.formatVersion } +/** + */ +@Since("1.3.0") object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { /** + * :: DeveloperApi :: * Compute the initial predictions and errors for a dataset for the first * iteration of gradient boosting. * @param data: training data. @@ -171,6 +196,8 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { * @return a RDD with each element being a zip of the prediction and error * corresponding to every sample. */ + @Since("1.4.0") + @DeveloperApi def computeInitialPredictionAndError( data: RDD[LabeledPoint], initTreeWeight: Double, @@ -184,6 +211,7 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { } /** + * :: DeveloperApi :: * Update a zipped predictionError RDD * (as obtained with computeInitialPredictionAndError) * @param data: training data. @@ -194,6 +222,8 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { * @return a RDD with each element being a zip of the prediction and error * corresponding to each sample. */ + @Since("1.4.0") + @DeveloperApi def updatePredictionError( data: RDD[LabeledPoint], predictionAndError: RDD[(Double, Double)], @@ -213,6 +243,12 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { private[mllib] def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion + /** + * @param sc Spark context used for loading model files. + * @param path Path specifying the directory to which the model was saved. 
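A sketch of evaluateEachIteration, which is useful for choosing the number of boosting iterations (assuming `gbtModel` and `validation` as in the earlier sketches):

    import org.apache.spark.mllib.tree.loss.SquaredError

    // errors(i) is the validation loss of the ensemble truncated to the first i + 1 trees.
    val errors = gbtModel.evaluateEachIteration(validation, SquaredError)
    val bestNumIterations = errors.zipWithIndex.minBy(_._1)._2 + 1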
+ * @return Model instance + */ + @Since("1.3.0") override def load(sc: SparkContext, path: String): GradientBoostedTreesModel = { val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path) val classNameV1_0 = SaveLoadV1_0.thisClassName @@ -376,7 +412,7 @@ private[tree] object TreeEnsembleModel extends Logging { case class EnsembleNodeData(treeId: Int, node: NodeData) def save(sc: SparkContext, path: String, model: TreeEnsembleModel, className: String): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // SPARK-6120: We do a hacky check here so users understand why save() is failing @@ -436,7 +472,7 @@ private[tree] object TreeEnsembleModel extends Logging { path: String, treeAlgo: String): Array[DecisionTreeModel] = { val datapath = Loader.dataPath(path) - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val nodes = sqlContext.read.parquet(datapath).map(NodeData.apply) val trees = constructTrees(nodes) trees.map(new DecisionTreeModel(_, Algo.fromString(treeAlgo))) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala index be335a1aca58..dffe6e78939e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala @@ -17,16 +17,17 @@ package org.apache.spark.mllib.util -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.Logging -import org.apache.spark.rdd.RDD +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD /** * :: DeveloperApi :: * A collection of methods used to validate data before applying ML algorithms. */ @DeveloperApi +@Since("0.8.0") object DataValidators extends Logging { /** @@ -34,6 +35,7 @@ object DataValidators extends Logging { * * @return True if labels are all zero or one, false otherwise. */ + @Since("1.0.0") val binaryLabelValidator: RDD[LabeledPoint] => Boolean = { data => val numInvalid = data.filter(x => x.label != 1.0 && x.label != 0.0).count() if (numInvalid != 0) { @@ -48,6 +50,7 @@ object DataValidators extends Logging { * * @return True if labels are all in the range of {0, 1, ..., k-1}, false otherwise. */ + @Since("1.3.0") def multiLabelValidator(k: Int): RDD[LabeledPoint] => Boolean = { data => val numInvalid = data.filter(x => x.label - x.label.toInt != 0.0 || x.label < 0 || x.label > k - 1).count() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala index e6bcff48b022..00fd1606a369 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala @@ -19,8 +19,8 @@ package org.apache.spark.mllib.util import scala.util.Random -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.SparkContext +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.rdd.RDD /** @@ -30,6 +30,7 @@ import org.apache.spark.rdd.RDD * cluster with scale 1 around each center. 
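The validators above are plain RDD[LabeledPoint] => Boolean functions, so they can be invoked directly before training; a small sketch over a placeholder RDD named data:

import org.apache.spark.mllib.util.DataValidators

// Each validator scans the RDD once and logs a warning when invalid labels are found.
val binaryOk = DataValidators.binaryLabelValidator(data)
val fiveClassOk = DataValidators.multiLabelValidator(5)(data)

require(binaryOk, "labels must be 0.0 or 1.0 for binary classifiers")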
*/ @DeveloperApi +@Since("0.8.0") object KMeansDataGenerator { /** @@ -42,6 +43,7 @@ object KMeansDataGenerator { * @param r Scaling factor for the distribution of the initial centers * @param numPartitions Number of partitions of the generated RDD; default 2 */ + @Since("0.8.0") def generateKMeansRDD( sc: SparkContext, numPoints: Int, @@ -62,6 +64,7 @@ object KMeansDataGenerator { } } + @Since("0.8.0") def main(args: Array[String]) { if (args.length < 6) { // scalastyle:off println diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala index 87eeb5db05d2..094528e2ece0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala @@ -17,16 +17,16 @@ package org.apache.spark.mllib.util -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.util.Random import com.github.fommil.netlib.BLAS.{getInstance => blas} -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.SparkContext -import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.annotation.{DeveloperApi, Since} +import org.apache.spark.mllib.linalg.{BLAS, Vectors} import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD /** * :: DeveloperApi :: @@ -35,6 +35,7 @@ import org.apache.spark.mllib.regression.LabeledPoint * response variable `Y`. */ @DeveloperApi +@Since("0.8.0") object LinearDataGenerator { /** @@ -46,13 +47,14 @@ object LinearDataGenerator { * @param seed Random seed * @return Java List of input. */ + @Since("0.8.0") def generateLinearInputAsList( intercept: Double, weights: Array[Double], nPoints: Int, seed: Int, eps: Double): java.util.List[LabeledPoint] = { - seqAsJavaList(generateLinearInput(intercept, weights, nPoints, seed, eps)) + generateLinearInput(intercept, weights, nPoints, seed, eps).asJava } /** @@ -68,19 +70,18 @@ object LinearDataGenerator { * @param eps Epsilon scaling factor. * @return Seq of input. */ + @Since("0.8.0") def generateLinearInput( intercept: Double, weights: Array[Double], nPoints: Int, seed: Int, eps: Double = 0.1): Seq[LabeledPoint] = { - generateLinearInput(intercept, weights, - Array.fill[Double](weights.length)(0.0), - Array.fill[Double](weights.length)(1.0 / 3.0), - nPoints, seed, eps)} + generateLinearInput(intercept, weights, Array.fill[Double](weights.length)(0.0), + Array.fill[Double](weights.length)(1.0 / 3.0), nPoints, seed, eps) + } /** - * * @param intercept Data intercept * @param weights Weights to be applied. * @param xMean the mean of the generated features. Lots of time, if the features are not properly @@ -92,6 +93,7 @@ object LinearDataGenerator { * @param eps Epsilon scaling factor. * @return Seq of input. */ + @Since("0.8.0") def generateLinearInput( intercept: Double, weights: Array[Double], @@ -100,24 +102,58 @@ object LinearDataGenerator { nPoints: Int, seed: Int, eps: Double): Seq[LabeledPoint] = { + generateLinearInput(intercept, weights, xMean, xVariance, nPoints, seed, eps, 0.0) + } + + + /** + * @param intercept Data intercept + * @param weights Weights to be applied. + * @param xMean the mean of the generated features. Lots of time, if the features are not properly + * standardized, the algorithm with poor implementation will have difficulty + * to converge. 
+ * @param xVariance the variance of the generated features. + * @param nPoints Number of points in sample. + * @param seed Random seed + * @param eps Epsilon scaling factor. + * @param sparsity The ratio of zero elements. If it is 0.0, LabeledPoints with + * DenseVector is returned. + * @return Seq of input. + */ + @Since("1.6.0") + def generateLinearInput( + intercept: Double, + weights: Array[Double], + xMean: Array[Double], + xVariance: Array[Double], + nPoints: Int, + seed: Int, + eps: Double, + sparsity: Double): Seq[LabeledPoint] = { + require(0.0 <= sparsity && sparsity <= 1.0) val rnd = new Random(seed) - val x = Array.fill[Array[Double]](nPoints)( - Array.fill[Double](weights.length)(rnd.nextDouble())) - - x.foreach { v => - var i = 0 - val len = v.length - while (i < len) { - v(i) = (v(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i) - i += 1 + def rndElement(i: Int) = {(rnd.nextDouble() - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i)} + + if (sparsity == 0.0) { + (0 until nPoints).map { _ => + val features = Vectors.dense(weights.indices.map { rndElement(_) }.toArray) + val label = BLAS.dot(Vectors.dense(weights), features) + + intercept + eps * rnd.nextGaussian() + // Return LabeledPoints with DenseVector + LabeledPoint(label, features) + } + } else { + (0 until nPoints).map { _ => + val indices = weights.indices.filter { _ => rnd.nextDouble() <= sparsity} + val values = indices.map { rndElement(_) } + val features = Vectors.sparse(weights.length, indices.toArray, values.toArray) + val label = BLAS.dot(Vectors.dense(weights), features) + + intercept + eps * rnd.nextGaussian() + // Return LabeledPoints with SparseVector + LabeledPoint(label, features) } } - - val y = x.map { xi => - blas.ddot(weights.length, xi, 1, weights, 1) + intercept + eps * rnd.nextGaussian() - } - y.zip(x).map(p => LabeledPoint(p._1, Vectors.dense(p._2))) } /** @@ -132,6 +168,7 @@ object LinearDataGenerator { * * @return RDD of LabeledPoint containing sample data. */ + @Since("0.8.0") def generateLinearRDD( sc: SparkContext, nexamples: Int, @@ -151,6 +188,7 @@ object LinearDataGenerator { data } + @Since("0.8.0") def main(args: Array[String]) { if (args.length < 2) { // scalastyle:off println diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala index c09cbe69bb97..33477ee20ebb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.util import scala.util.Random -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{Since, DeveloperApi} import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.mllib.regression.LabeledPoint @@ -31,6 +31,7 @@ import org.apache.spark.mllib.linalg.Vectors * with probability `probOne` and scales features for positive examples by `eps`. */ @DeveloperApi +@Since("0.8.0") object LogisticRegressionDataGenerator { /** @@ -43,6 +44,7 @@ object LogisticRegressionDataGenerator { * @param nparts Number of partitions of the generated RDD. Default value is 2. * @param probOne Probability that a label is 1 (and not 0). Default value is 0.5. 
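A brief sketch of the new @Since("1.6.0") overload above; a non-zero sparsity makes the generator drop feature entries at random and return LabeledPoints backed by SparseVector (all values below are illustrative):

import org.apache.spark.mllib.util.LinearDataGenerator

val sparsePoints = LinearDataGenerator.generateLinearInput(
  intercept = 0.5,
  weights = Array(1.0, -2.0, 3.0),
  xMean = Array(0.0, 0.0, 0.0),
  xVariance = Array(1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0),
  nPoints = 100,
  seed = 17,
  eps = 0.1,
  sparsity = 0.2) // sparsity == 0.0 would fall back to DenseVector features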
*/ + @Since("0.8.0") def generateLogisticRDD( sc: SparkContext, nexamples: Int, @@ -62,6 +64,7 @@ object LogisticRegressionDataGenerator { data } + @Since("0.8.0") def main(args: Array[String]) { if (args.length != 5) { // scalastyle:off println diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala index 16f430599a51..906bd30563bd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala @@ -23,7 +23,7 @@ import scala.language.postfixOps import scala.util.Random import org.apache.spark.SparkContext -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{Since, DeveloperApi} import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix} import org.apache.spark.rdd.RDD @@ -52,7 +52,9 @@ import org.apache.spark.rdd.RDD * testSampFact (Double) Percentage of training data to use as test data. */ @DeveloperApi +@Since("0.8.0") object MFDataGenerator { + @Since("0.8.0") def main(args: Array[String]) { if (args.length < 2) { // scalastyle:off println diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 7c5cfa7bd84c..4c9151f0cb4f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -19,9 +19,7 @@ package org.apache.spark.mllib.util import scala.reflect.ClassTag -import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} - -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.Since import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.rdd.PartitionwiseSampledRDD @@ -30,12 +28,11 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS.dot import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.dstream.DStream /** * Helper methods to load, save and pre-process data used in ML Lib. */ +@Since("0.8.0") object MLUtils { private[mllib] lazy val EPSILON = { @@ -65,6 +62,7 @@ object MLUtils { * @param minPartitions min number of partitions * @return labeled data stored as an RDD[LabeledPoint] */ + @Since("1.0.0") def loadLibSVMFile( sc: SparkContext, path: String, @@ -114,6 +112,7 @@ object MLUtils { // Convenient methods for `loadLibSVMFile`. + @Since("1.0.0") @deprecated("use method without multiclass argument, which no longer has effect", "1.1.0") def loadLibSVMFile( sc: SparkContext, @@ -127,12 +126,14 @@ object MLUtils { * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint], with the default number of * partitions. 
*/ + @Since("1.0.0") def loadLibSVMFile( sc: SparkContext, path: String, numFeatures: Int): RDD[LabeledPoint] = loadLibSVMFile(sc, path, numFeatures, sc.defaultMinPartitions) + @Since("1.0.0") @deprecated("use method without multiclass argument, which no longer has effect", "1.1.0") def loadLibSVMFile( sc: SparkContext, @@ -141,6 +142,7 @@ object MLUtils { numFeatures: Int): RDD[LabeledPoint] = loadLibSVMFile(sc, path, numFeatures) + @Since("1.0.0") @deprecated("use method without multiclass argument, which no longer has effect", "1.1.0") def loadLibSVMFile( sc: SparkContext, @@ -152,6 +154,7 @@ object MLUtils { * Loads binary labeled data in the LIBSVM format into an RDD[LabeledPoint], with number of * features determined automatically and the default number of partitions. */ + @Since("1.0.0") def loadLibSVMFile(sc: SparkContext, path: String): RDD[LabeledPoint] = loadLibSVMFile(sc, path, -1) @@ -162,6 +165,7 @@ object MLUtils { * * @see [[org.apache.spark.mllib.util.MLUtils#loadLibSVMFile]] */ + @Since("1.0.0") def saveAsLibSVMFile(data: RDD[LabeledPoint], dir: String) { // TODO: allow to specify label precision and feature precision. val dataStr = data.map { case LabeledPoint(label, features) => @@ -182,12 +186,14 @@ object MLUtils { * @param minPartitions min number of partitions * @return vectors stored as an RDD[Vector] */ + @Since("1.1.0") def loadVectors(sc: SparkContext, path: String, minPartitions: Int): RDD[Vector] = sc.textFile(path, minPartitions).map(Vectors.parse) /** * Loads vectors saved using `RDD[Vector].saveAsTextFile` with the default number of partitions. */ + @Since("1.1.0") def loadVectors(sc: SparkContext, path: String): RDD[Vector] = sc.textFile(path, sc.defaultMinPartitions).map(Vectors.parse) @@ -198,6 +204,7 @@ object MLUtils { * @param minPartitions min number of partitions * @return labeled points stored as an RDD[LabeledPoint] */ + @Since("1.1.0") def loadLabeledPoints(sc: SparkContext, path: String, minPartitions: Int): RDD[LabeledPoint] = sc.textFile(path, minPartitions).map(LabeledPoint.parse) @@ -205,6 +212,7 @@ object MLUtils { * Loads labeled points saved using `RDD[LabeledPoint].saveAsTextFile` with the default number of * partitions. */ + @Since("1.1.0") def loadLabeledPoints(sc: SparkContext, dir: String): RDD[LabeledPoint] = loadLabeledPoints(sc, dir, sc.defaultMinPartitions) @@ -221,6 +229,7 @@ object MLUtils { * @deprecated Should use [[org.apache.spark.rdd.RDD#saveAsTextFile]] for saving and * [[org.apache.spark.mllib.util.MLUtils#loadLabeledPoints]] for loading. */ + @Since("1.0.0") @deprecated("Should use MLUtils.loadLabeledPoints instead.", "1.0.1") def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = { sc.textFile(dir).map { line => @@ -242,6 +251,7 @@ object MLUtils { * @deprecated Should use [[org.apache.spark.rdd.RDD#saveAsTextFile]] for saving and * [[org.apache.spark.mllib.util.MLUtils#loadLabeledPoints]] for loading. */ + @Since("1.0.0") @deprecated("Should use RDD[LabeledPoint].saveAsTextFile instead.", "1.0.1") def saveLabeledData(data: RDD[LabeledPoint], dir: String) { val dataStr = data.map(x => x.label + "," + x.features.toArray.mkString(" ")) @@ -249,13 +259,20 @@ object MLUtils { } /** - * :: Experimental :: * Return a k element array of pairs of RDDs with the first element of each pair * containing the training data, a complement of the validation data and the second * element, the validation data, containing a unique 1/kth of the data. Where k=numFolds. 
*/ - @Experimental + @Since("1.0.0") def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] = { + kFold(rdd, numFolds, seed.toLong) + } + + /** + * Version of [[kFold()]] taking a Long seed. + */ + @Since("2.0.0") + def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Long): Array[(RDD[T], RDD[T])] = { val numFoldsF = numFolds.toFloat (1 to numFolds).map { fold => val sampler = new BernoulliCellSampler[T]((fold - 1) / numFoldsF, fold / numFoldsF, @@ -269,6 +286,7 @@ object MLUtils { /** * Returns a new vector with `1.0` (bias) appended to the input vector. */ + @Since("1.0.0") def appendBias(vector: Vector): Vector = { vector match { case dv: DenseVector => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala index ad20b7694a77..cde597939617 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala @@ -21,11 +21,11 @@ import scala.util.Random import com.github.fommil.netlib.BLAS.{getInstance => blas} -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.SparkContext -import org.apache.spark.rdd.RDD +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD /** * :: DeveloperApi :: @@ -33,8 +33,10 @@ import org.apache.spark.mllib.regression.LabeledPoint * for the features and adds Gaussian noise with weight 0.1 to generate labels. */ @DeveloperApi +@Since("0.8.0") object SVMDataGenerator { + @Since("0.8.0") def main(args: Array[String]) { if (args.length < 2) { // scalastyle:off println diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala index 30d642c754b7..4d71d534a077 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala @@ -24,7 +24,7 @@ import org.json4s._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.types.{DataType, StructField, StructType} @@ -35,6 +35,7 @@ import org.apache.spark.sql.types.{DataType, StructField, StructType} * This should be inherited by the class which implements model instances. */ @DeveloperApi +@Since("1.3.0") trait Saveable { /** @@ -50,6 +51,7 @@ trait Saveable { * @param path Path specifying the directory in which to save this model. * If the directory already exists, this method throws an exception. */ + @Since("1.3.0") def save(sc: SparkContext, path: String): Unit /** Current version of model save/load format. */ @@ -64,6 +66,7 @@ trait Saveable { * This should be inherited by an object paired with the model class. */ @DeveloperApi +@Since("1.3.0") trait Loader[M <: Saveable] { /** @@ -75,6 +78,7 @@ trait Loader[M <: Saveable] { * @param path Path specifying the directory to which the model was saved. 
* @return Model instance */ + @Since("1.3.0") def load(sc: SparkContext, path: String): M } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index f75e024a713e..fd22eb6dca01 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -22,6 +22,7 @@ import java.util.List; import org.junit.After; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -63,16 +64,16 @@ public void tearDown() { @Test public void logisticRegressionDefaultParams() { LogisticRegression lr = new LogisticRegression(); - assert(lr.getLabelCol().equals("label")); + Assert.assertEquals(lr.getLabelCol(), "label"); LogisticRegressionModel model = lr.fit(dataset); model.transform(dataset).registerTempTable("prediction"); DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction"); predictions.collectAsList(); // Check defaults - assert(model.getThreshold() == 0.5); - assert(model.getFeaturesCol().equals("features")); - assert(model.getPredictionCol().equals("prediction")); - assert(model.getProbabilityCol().equals("probability")); + Assert.assertEquals(0.5, model.getThreshold(), eps); + Assert.assertEquals("features", model.getFeaturesCol()); + Assert.assertEquals("prediction", model.getPredictionCol()); + Assert.assertEquals("probability", model.getProbabilityCol()); } @Test @@ -85,17 +86,19 @@ public void logisticRegressionWithSetters() { .setProbabilityCol("myProbability"); LogisticRegressionModel model = lr.fit(dataset); LogisticRegression parent = (LogisticRegression) model.parent(); - assert(parent.getMaxIter() == 10); - assert(parent.getRegParam() == 1.0); - assert(parent.getThreshold() == 0.6); - assert(model.getThreshold() == 0.6); + Assert.assertEquals(10, parent.getMaxIter()); + Assert.assertEquals(1.0, parent.getRegParam(), eps); + Assert.assertEquals(0.4, parent.getThresholds()[0], eps); + Assert.assertEquals(0.6, parent.getThresholds()[1], eps); + Assert.assertEquals(0.6, parent.getThreshold(), eps); + Assert.assertEquals(0.6, model.getThreshold(), eps); // Modify model params, and check that the params worked. model.setThreshold(1.0); model.transform(dataset).registerTempTable("predAllZero"); DataFrame predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero"); for (Row r: predAllZero.collectAsList()) { - assert(r.getDouble(0) == 0.0); + Assert.assertEquals(0.0, r.getDouble(0), eps); } // Call transform with params, and check that the params worked. model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb")) @@ -105,17 +108,17 @@ public void logisticRegressionWithSetters() { for (Row r: predNotAllZero.collectAsList()) { if (r.getDouble(0) != 0.0) foundNonZero = true; } - assert(foundNonZero); + Assert.assertTrue(foundNonZero); // Call fit() with new params, and check as many params as we can. 
LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.threshold().w(0.4), lr.probabilityCol().w("theProb")); LogisticRegression parent2 = (LogisticRegression) model2.parent(); - assert(parent2.getMaxIter() == 5); - assert(parent2.getRegParam() == 0.1); - assert(parent2.getThreshold() == 0.4); - assert(model2.getThreshold() == 0.4); - assert(model2.getProbabilityCol().equals("theProb")); + Assert.assertEquals(5, parent2.getMaxIter()); + Assert.assertEquals(0.1, parent2.getRegParam(), eps); + Assert.assertEquals(0.4, parent2.getThreshold(), eps); + Assert.assertEquals(0.4, model2.getThreshold(), eps); + Assert.assertEquals("theProb", model2.getProbabilityCol()); } @SuppressWarnings("unchecked") @@ -123,18 +126,18 @@ public void logisticRegressionWithSetters() { public void logisticRegressionPredictorClassifierMethods() { LogisticRegression lr = new LogisticRegression(); LogisticRegressionModel model = lr.fit(dataset); - assert(model.numClasses() == 2); + Assert.assertEquals(2, model.numClasses()); model.transform(dataset).registerTempTable("transformed"); DataFrame trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed"); for (Row row: trans1.collect()) { Vector raw = (Vector)row.get(0); Vector prob = (Vector)row.get(1); - assert(raw.size() == 2); - assert(prob.size() == 2); + Assert.assertEquals(raw.size(), 2); + Assert.assertEquals(prob.size(), 2); double probFromRaw1 = 1.0 / (1.0 + Math.exp(-raw.apply(1))); - assert(Math.abs(prob.apply(1) - probFromRaw1) < eps); - assert(Math.abs(prob.apply(0) - (1.0 - probFromRaw1)) < eps); + Assert.assertEquals(0, Math.abs(prob.apply(1) - probFromRaw1), eps); + Assert.assertEquals(0, Math.abs(prob.apply(0) - (1.0 - probFromRaw1)), eps); } DataFrame trans2 = jsql.sql("SELECT prediction, probability FROM transformed"); @@ -143,8 +146,17 @@ public void logisticRegressionPredictorClassifierMethods() { Vector prob = (Vector)row.get(1); double probOfPred = prob.apply((int)pred); for (int i = 0; i < prob.size(); ++i) { - assert(probOfPred >= prob.apply(i)); + Assert.assertTrue(probOfPred >= prob.apply(i)); } } } + + @Test + public void logisticRegressionTrainingSummary() { + LogisticRegression lr = new LogisticRegression(); + LogisticRegressionModel model = lr.fit(dataset); + + LogisticRegressionTrainingSummary summary = model.summary(); + Assert.assertEquals(summary.totalIterations(), summary.objectiveHistory().length); + } } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java new file mode 100644 index 000000000000..ec6b4bf3c0f8 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
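The new logisticRegressionTrainingSummary test above exercises the training-summary API; the equivalent Scala usage, assuming a placeholder DataFrame named training with label and features columns:

import org.apache.spark.ml.classification.LogisticRegression

val lrModel = new LogisticRegression().setMaxIter(10).fit(training)

val summary = lrModel.summary
// objectiveHistory has one loss value per iteration; its length equals totalIterations.
println(s"iterations: ${summary.totalIterations}")
println(summary.objectiveHistory.mkString(", "))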
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification; + +import java.io.Serializable; +import java.util.Arrays; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; + +public class JavaMultilayerPerceptronClassifierSuite implements Serializable { + + private transient JavaSparkContext jsc; + private transient SQLContext sqlContext; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite"); + sqlContext = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + sqlContext = null; + } + + @Test + public void testMLPC() { + DataFrame dataFrame = sqlContext.createDataFrame( + jsc.parallelize(Arrays.asList( + new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), + new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), + new LabeledPoint(1.0, Vectors.dense(1.0, 0.0)), + new LabeledPoint(0.0, Vectors.dense(1.0, 1.0)))), + LabeledPoint.class); + MultilayerPerceptronClassifier mlpc = new MultilayerPerceptronClassifier() + .setLayers(new int[] {2, 5, 2}) + .setBlockSize(1) + .setSeed(11L) + .setMaxIter(100); + MultilayerPerceptronClassificationModel model = mlpc.fit(dataFrame); + DataFrame result = model.transform(dataFrame); + Row[] predictionAndLabels = result.select("prediction", "label").collect(); + for (Row r: predictionAndLabels) { + Assert.assertEquals((int) r.getDouble(0), (int) r.getDouble(1)); + } + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java index a700c9cddb20..f5f690eabd12 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java @@ -18,11 +18,13 @@ package org.apache.spark.ml.classification; import java.io.Serializable; +import java.util.Arrays; +import java.util.List; -import com.google.common.collect.Lists; import org.junit.After; import org.junit.Before; import org.junit.Test; +import static org.junit.Assert.assertEquals; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -58,37 +60,36 @@ public void validatePrediction(DataFrame predictionAndLabels) { for (Row r : predictionAndLabels.collect()) { double prediction = r.getAs(0); double label = r.getAs(1); - assert(prediction == label); + assertEquals(label, prediction, 1E-5); } } @Test public void naiveBayesDefaultParams() { NaiveBayes nb = new NaiveBayes(); - assert(nb.getLabelCol() == "label"); - assert(nb.getFeaturesCol() == "features"); - assert(nb.getPredictionCol() == "prediction"); - assert(nb.getSmoothing() == 1.0); - assert(nb.getModelType() == "multinomial"); + assertEquals("label", nb.getLabelCol()); + assertEquals("features", 
nb.getFeaturesCol()); + assertEquals("prediction", nb.getPredictionCol()); + assertEquals(1.0, nb.getSmoothing(), 1E-5); + assertEquals("multinomial", nb.getModelType()); } @Test public void testNaiveBayes() { - JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( + List data = Arrays.asList( RowFactory.create(0.0, Vectors.dense(1.0, 0.0, 0.0)), RowFactory.create(0.0, Vectors.dense(2.0, 0.0, 0.0)), RowFactory.create(1.0, Vectors.dense(0.0, 1.0, 0.0)), RowFactory.create(1.0, Vectors.dense(0.0, 2.0, 0.0)), RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 1.0)), - RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 2.0)) - )); + RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 2.0))); StructType schema = new StructType(new StructField[]{ new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), new StructField("features", new VectorUDT(), false, Metadata.empty()) }); - DataFrame dataset = jsql.createDataFrame(jrdd, schema); + DataFrame dataset = jsql.createDataFrame(data, schema); NaiveBayes nb = new NaiveBayes().setSmoothing(0.5).setModelType("multinomial"); NaiveBayesModel model = nb.fit(dataset); diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java index a1ee55415237..cbabafe1b541 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java @@ -20,7 +20,7 @@ import java.io.Serializable; import java.util.List; -import static scala.collection.JavaConversions.seqAsJavaList; +import scala.collection.JavaConverters; import org.junit.After; import org.junit.Assert; @@ -47,16 +47,17 @@ public void setUp() { jsql = new SQLContext(jsc); int nPoints = 3; - // The following weights and xMean/xVariance are computed from iris dataset with lambda=0.2. + // The following coefficients and xMean/xVariance are computed from iris dataset with lambda=0.2. // As a result, we are drawing samples from probability distribution of an actual model. 
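For reference, the Scala counterpart of the NaiveBayes settings used in the Java test above, on a placeholder DataFrame named dataset with label and features columns:

import org.apache.spark.ml.classification.NaiveBayes

val nb = new NaiveBayes()
  .setSmoothing(0.5)
  .setModelType("multinomial") // "bernoulli" is the other supported model type

val nbModel = nb.fit(dataset)
nbModel.transform(dataset).select("label", "probability", "prediction").show()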
- double[] weights = { + double[] coefficients = { -0.57997, 0.912083, -0.371077, -0.819866, 2.688191, -0.16624, -0.84355, -0.048509, -0.301789, 4.170682 }; double[] xMean = {5.843, 3.057, 3.758, 1.199}; double[] xVariance = {0.6856, 0.1899, 3.116, 0.581}; - List points = seqAsJavaList(generateMultinomialLogisticInput( - weights, xMean, xVariance, true, nPoints, 42)); + List points = JavaConverters.seqAsJavaListConverter( + generateMultinomialLogisticInput(coefficients, xMean, xVariance, true, nPoints, 42) + ).asJava(); datasetRDD = jsc.parallelize(points, 2); dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class); } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java index 32d0b3856b7e..a66a1e12927b 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java @@ -29,6 +29,7 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.impl.TreeTests; import org.apache.spark.mllib.classification.LogisticRegressionSuite; +import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.DataFrame; @@ -85,6 +86,7 @@ public void runDT() { model.toDebugString(); model.trees(); model.treeWeights(); + Vector importances = model.featureImportances(); /* // TODO: Add test once save/load are implemented. SPARK-6725 diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java index d5bd230a957a..8a1e5ef01565 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java @@ -17,7 +17,8 @@ package org.apache.spark.ml.feature; -import com.google.common.collect.Lists; +import java.util.Arrays; + import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -54,16 +55,16 @@ public void tearDown() { public void bucketizerTest() { double[] splits = {-0.5, 0.0, 0.5}; - JavaRDD data = jsc.parallelize(Lists.newArrayList( - RowFactory.create(-0.5), - RowFactory.create(-0.3), - RowFactory.create(0.0), - RowFactory.create(0.2) - )); StructType schema = new StructType(new StructField[] { new StructField("feature", DataTypes.DoubleType, false, Metadata.empty()) }); - DataFrame dataset = jsql.createDataFrame(data, schema); + DataFrame dataset = jsql.createDataFrame( + Arrays.asList( + RowFactory.create(-0.5), + RowFactory.create(-0.3), + RowFactory.create(0.0), + RowFactory.create(0.2)), + schema); Bucketizer bucketizer = new Bucketizer() .setInputCol("feature") diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java index 845eed61c45c..39da47381b12 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java @@ -17,7 +17,8 @@ package org.apache.spark.ml.feature; -import com.google.common.collect.Lists; +import java.util.Arrays; + import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D; import org.junit.After; import org.junit.Assert; @@ -56,12 +57,11 @@ public void tearDown() { @Test public void javaCompatibilityTest() { double[] input = new double[] {1D, 
2D, 3D, 4D}; - JavaRDD data = jsc.parallelize(Lists.newArrayList( - RowFactory.create(Vectors.dense(input)) - )); - DataFrame dataset = jsql.createDataFrame(data, new StructType(new StructField[]{ - new StructField("vec", (new VectorUDT()), false, Metadata.empty()) - })); + DataFrame dataset = jsql.createDataFrame( + Arrays.asList(RowFactory.create(Vectors.dense(input))), + new StructType(new StructField[]{ + new StructField("vec", (new VectorUDT()), false, Metadata.empty()) + })); double[] expectedResult = input.clone(); (new DoubleDCT_1D(input.length)).forward(expectedResult, true); diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java index 599e9cfd23ad..d12332c2a02a 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java @@ -17,7 +17,9 @@ package org.apache.spark.ml.feature; -import com.google.common.collect.Lists; +import java.util.Arrays; +import java.util.List; + import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -54,17 +56,17 @@ public void tearDown() { @Test public void hashingTF() { - JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( + List data = Arrays.asList( RowFactory.create(0.0, "Hi I heard about Spark"), RowFactory.create(0.0, "I wish Java could use case classes"), RowFactory.create(1.0, "Logistic regression models are neat") - )); + ); StructType schema = new StructType(new StructField[]{ new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) }); - DataFrame sentenceData = jsql.createDataFrame(jrdd, schema); + DataFrame sentenceData = jsql.createDataFrame(data, schema); Tokenizer tokenizer = new Tokenizer() .setInputCol("sentence") .setOutputCol("words"); diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java index d82f3b7e8c07..e17d549c5059 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java @@ -17,15 +17,15 @@ package org.apache.spark.ml.feature; -import java.util.List; +import java.util.Arrays; -import com.google.common.collect.Lists; import org.junit.After; import org.junit.Before; import org.junit.Test; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.api.java.JavaRDD; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; @@ -48,13 +48,12 @@ public void tearDown() { @Test public void normalizer() { // The tests are to check Java compatibility. 
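The Normalizer compatibility test here corresponds to the following Scala usage, assuming a placeholder DataFrame named dataFrame with a features column:

import org.apache.spark.ml.feature.Normalizer

// L^2 normalization by default; setP(1.0) would switch to L^1.
val normalizer = new Normalizer()
  .setInputCol("features")
  .setOutputCol("normFeatures")

normalizer.transform(dataFrame).show()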
- List points = Lists.newArrayList( + JavaRDD points = jsc.parallelize(Arrays.asList( new VectorIndexerSuite.FeatureData(Vectors.dense(0.0, -2.0)), new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)), new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0)) - ); - DataFrame dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2), - VectorIndexerSuite.FeatureData.class); + )); + DataFrame dataFrame = jsql.createDataFrame(points, VectorIndexerSuite.FeatureData.class); Normalizer normalizer = new Normalizer() .setInputCol("features") .setOutputCol("normFeatures"); diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java index 5cf43fec6f29..e8f329f9cf29 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java @@ -18,11 +18,11 @@ package org.apache.spark.ml.feature; import java.io.Serializable; +import java.util.Arrays; import java.util.List; import scala.Tuple2; -import com.google.common.collect.Lists; import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -78,7 +78,7 @@ public Vector getExpected() { @Test public void testPCA() { - List points = Lists.newArrayList( + List points = Arrays.asList( Vectors.sparse(5, new int[]{1, 3}, new double[]{1.0, 7.0}), Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java index 5e8211c2c511..bf8eefd71905 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java @@ -17,7 +17,9 @@ package org.apache.spark.ml.feature; -import com.google.common.collect.Lists; +import java.util.Arrays; +import java.util.List; + import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -59,7 +61,7 @@ public void polynomialExpansionTest() { .setOutputCol("polyFeatures") .setDegree(3); - JavaRDD data = jsc.parallelize(Lists.newArrayList( + List data = Arrays.asList( RowFactory.create( Vectors.dense(-2.0, 2.3), Vectors.dense(-2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17) @@ -69,7 +71,7 @@ public void polynomialExpansionTest() { Vectors.dense(0.6, -1.1), Vectors.dense(0.6, 0.36, 0.216, -1.1, -0.66, -0.396, 1.21, 0.726, -1.331) ) - )); + ); StructType schema = new StructType(new StructField[] { new StructField("features", new VectorUDT(), false, Metadata.empty()), diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java index 74eb2733f06e..ed74363f59e3 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java @@ -17,9 +17,9 @@ package org.apache.spark.ml.feature; +import java.util.Arrays; import java.util.List; -import com.google.common.collect.Lists; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -48,7 +48,7 @@ public void tearDown() { @Test public void standardScaler() { // The tests are to check Java compatibility. 
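Likewise for StandardScaler, again on a placeholder DataFrame named dataFrame with a features column:

import org.apache.spark.ml.feature.StandardScaler

val scaler = new StandardScaler()
  .setInputCol("features")
  .setOutputCol("scaledFeatures")
  .setWithMean(false) // centering would densify sparse vectors, so it defaults to off
  .setWithStd(true)

val scalerModel = scaler.fit(dataFrame)
scalerModel.transform(dataFrame).show()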
- List points = Lists.newArrayList( + List points = Arrays.asList( new VectorIndexerSuite.FeatureData(Vectors.dense(0.0, -2.0)), new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)), new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0)) diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java new file mode 100644 index 000000000000..848d9f8aa928 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import java.util.Arrays; +import java.util.List; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + + +public class JavaStopWordsRemoverSuite { + + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaStopWordsRemoverSuite"); + jsql = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void javaCompatibilityTest() { + StopWordsRemover remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered"); + + List data = Arrays.asList( + RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")), + RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb")) + ); + StructType schema = new StructType(new StructField[] { + new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) + }); + DataFrame dataset = jsql.createDataFrame(data, schema); + + remover.transform(dataset).collect(); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java index 35b18c5308f6..b2df79ba74fe 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.ml.feature; import java.util.Arrays; +import java.util.List; import org.junit.After; import org.junit.Assert; @@ -56,9 +57,9 @@ public void testStringIndexer() { createStructField("id", 
IntegerType, false), createStructField("label", StringType, false) }); - JavaRDD rdd = jsc.parallelize( - Arrays.asList(c(0, "a"), c(1, "b"), c(2, "c"), c(3, "a"), c(4, "a"), c(5, "c"))); - DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + List data = Arrays.asList( + cr(0, "a"), cr(1, "b"), cr(2, "c"), cr(3, "a"), cr(4, "a"), cr(5, "c")); + DataFrame dataset = sqlContext.createDataFrame(data, schema); StringIndexer indexer = new StringIndexer() .setInputCol("label") @@ -66,12 +67,12 @@ public void testStringIndexer() { DataFrame output = indexer.fit(dataset).transform(dataset); Assert.assertArrayEquals( - new Row[] { c(0, 0.0), c(1, 2.0), c(2, 1.0), c(3, 0.0), c(4, 0.0), c(5, 1.0) }, + new Row[] { cr(0, 0.0), cr(1, 2.0), cr(2, 1.0), cr(3, 0.0), cr(4, 0.0), cr(5, 1.0) }, output.orderBy("id").select("id", "labelIndex").collect()); } /** An alias for RowFactory.create. */ - private Row c(Object... values) { + private Row cr(Object... values) { return RowFactory.create(values); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java index 3806f650025b..c407d98f1b79 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java @@ -17,7 +17,8 @@ package org.apache.spark.ml.feature; -import com.google.common.collect.Lists; +import java.util.Arrays; + import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -52,9 +53,11 @@ public void regexTokenizer() { .setOutputCol("tokens") .setPattern("\\s") .setGaps(true) + .setToLowercase(false) .setMinTokenLength(3); - JavaRDD rdd = jsc.parallelize(Lists.newArrayList( + + JavaRDD rdd = jsc.parallelize(Arrays.asList( new TokenizerTestData("Test of tok.", new String[] {"Test", "tok."}), new TokenizerTestData("Te,st. 
punct", new String[] {"Te,st.", "punct"}) )); diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java index b7c564caad3b..e28377757093 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java @@ -65,8 +65,7 @@ public void testVectorAssembler() { Row row = RowFactory.create( 0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, new int[] {1}, new double[] {3.0}), 10L); - JavaRDD rdd = jsc.parallelize(Arrays.asList(row)); - DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + DataFrame dataset = sqlContext.createDataFrame(Arrays.asList(row), schema); VectorAssembler assembler = new VectorAssembler() .setInputCols(new String[] {"x", "y", "z", "n"}) .setOutputCol("features"); diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java index c7ae5468b942..bfcca62fa1c9 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.ml.feature; import java.io.Serializable; +import java.util.Arrays; import java.util.List; import java.util.Map; @@ -26,8 +27,6 @@ import org.junit.Before; import org.junit.Test; -import com.google.common.collect.Lists; - import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.feature.VectorIndexerSuite.FeatureData; import org.apache.spark.mllib.linalg.Vectors; @@ -52,7 +51,7 @@ public void tearDown() { @Test public void vectorIndexerAPI() { // The tests are to check Java compatibility. - List points = Lists.newArrayList( + List points = Arrays.asList( new FeatureData(Vectors.dense(0.0, -2.0)), new FeatureData(Vectors.dense(1.0, 3.0)), new FeatureData(Vectors.dense(1.0, 4.0)) diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java new file mode 100644 index 000000000000..00174e6a683d --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature; + +import java.util.Arrays; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.attribute.Attribute; +import org.apache.spark.ml.attribute.AttributeGroup; +import org.apache.spark.ml.attribute.NumericAttribute; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.StructType; + + +public class JavaVectorSlicerSuite { + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaVectorSlicerSuite"); + jsql = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void vectorSlice() { + Attribute[] attrs = new Attribute[]{ + NumericAttribute.defaultAttr().withName("f1"), + NumericAttribute.defaultAttr().withName("f2"), + NumericAttribute.defaultAttr().withName("f3") + }; + AttributeGroup group = new AttributeGroup("userFeatures", attrs); + + List data = Arrays.asList( + RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})), + RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0)) + ); + + DataFrame dataset = jsql.createDataFrame(data, (new StructType()).add(group.toStructField())); + + VectorSlicer vectorSlicer = new VectorSlicer() + .setInputCol("userFeatures").setOutputCol("features"); + + vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"}); + + DataFrame output = vectorSlicer.transform(dataset); + + for (Row r : output.select("userFeatures", "features").take(2)) { + Vector features = r.getAs(1); + Assert.assertEquals(features.size(), 2); + } + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java index 39c70157f83c..0c0c1c4d12d0 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java @@ -17,7 +17,8 @@ package org.apache.spark.ml.feature; -import com.google.common.collect.Lists; +import java.util.Arrays; + import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -50,15 +51,15 @@ public void tearDown() { @Test public void testJavaWord2Vec() { - JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( - RowFactory.create(Lists.newArrayList("Hi I heard about Spark".split(" "))), - RowFactory.create(Lists.newArrayList("I wish Java could use case classes".split(" "))), - RowFactory.create(Lists.newArrayList("Logistic regression models are neat".split(" "))) - )); StructType schema = new StructType(new StructField[]{ new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) }); - DataFrame documentDF = sqlContext.createDataFrame(jrdd, schema); + DataFrame documentDF = sqlContext.createDataFrame( + Arrays.asList( + RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))), + RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))), + RowFactory.create(Arrays.asList("Logistic regression models are neat".split(" ")))), + schema); Word2Vec word2Vec = new 
Word2Vec() .setInputCol("text") diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java index 9890155e9f86..fa777f3d42a9 100644 --- a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java @@ -17,7 +17,8 @@ package org.apache.spark.ml.param; -import com.google.common.collect.Lists; +import java.util.Arrays; + import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -61,7 +62,7 @@ public void testParamValidate() { ParamValidators.ltEq(1.0); ParamValidators.inRange(0, 1, true, false); ParamValidators.inRange(0, 1); - ParamValidators.inArray(Lists.newArrayList(0, 1, 3)); - ParamValidators.inArray(Lists.newArrayList("a", "b")); + ParamValidators.inArray(Arrays.asList(0, 1, 3)); + ParamValidators.inArray(Arrays.asList("a", "b")); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java index dc6ce8061f62..65841182df9b 100644 --- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java @@ -17,10 +17,9 @@ package org.apache.spark.ml.param; +import java.util.Arrays; import java.util.List; -import com.google.common.collect.Lists; - import org.apache.spark.ml.util.Identifiable$; /** @@ -89,7 +88,7 @@ private void init() { myIntParam_ = new IntParam(this, "myIntParam", "this is an int param", ParamValidators.gt(0)); myDoubleParam_ = new DoubleParam(this, "myDoubleParam", "this is a double param", ParamValidators.inRange(0.0, 1.0)); - List validStrings = Lists.newArrayList("a", "b"); + List validStrings = Arrays.asList("a", "b"); myStringParam_ = new Param(this, "myStringParam", "this is a string param", ParamValidators.inArray(validStrings)); myDoubleArrayParam_ = diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java index d591a456864e..4fb0b0d1092b 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java @@ -60,7 +60,8 @@ public void tearDown() { @Test public void linearRegressionDefaultParams() { LinearRegression lr = new LinearRegression(); - assert(lr.getLabelCol().equals("label")); + assertEquals("label", lr.getLabelCol()); + assertEquals("auto", lr.getSolver()); LinearRegressionModel model = lr.fit(dataset); model.transform(dataset).registerTempTable("prediction"); DataFrame predictions = jsql.sql("SELECT label, prediction FROM prediction"); @@ -75,7 +76,7 @@ public void linearRegressionWithSetters() { // Set params, train, and check as many params as we can. 
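The solver parameter asserted above defaults to "auto"; a minimal Scala sketch that forces L-BFGS instead, mirroring the setter added in the hunk that follows, on a placeholder DataFrame named training:

import org.apache.spark.ml.regression.LinearRegression

val lr = new LinearRegression()
  .setMaxIter(10)
  .setRegParam(1.0)
  .setSolver("l-bfgs") // "auto" and "normal" are the other accepted values

val lrModel = lr.fit(training)
println(s"coefficients: ${lrModel.coefficients} intercept: ${lrModel.intercept}")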
LinearRegression lr = new LinearRegression() .setMaxIter(10) - .setRegParam(1.0); + .setRegParam(1.0).setSolver("l-bfgs"); LinearRegressionModel model = lr.fit(dataset); LinearRegression parent = (LinearRegression) model.parent(); assertEquals(10, parent.getMaxIter()); diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java index e306ebadfe7c..a00ce5e249c3 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java @@ -29,6 +29,7 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.DataFrame; @@ -85,6 +86,7 @@ public void runDT() { model.toDebugString(); model.trees(); model.treeWeights(); + Vector importances = model.featureImportances(); /* // TODO: Add test once save/load are implemented. SPARK-6725 diff --git a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java new file mode 100644 index 000000000000..2976b38e4503 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.source.libsvm; + +import java.io.File; +import java.io.IOException; + +import com.google.common.base.Charsets; +import com.google.common.io.Files; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.DenseVector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.util.Utils; + + +/** + * Test LibSVMRelation in Java. 
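+ * The "libsvm" text format read here encodes one labeled point per line as
+ * "label index1:value1 index2:value2 ..." with 1-based feature indices, and the
+ * "vectorType" option (assumed to default to "sparse") switches the features
+ * column to dense vectors when set to "dense", as exercised below.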
+ */ +public class JavaLibSVMRelationSuite { + private transient JavaSparkContext jsc; + private transient SQLContext sqlContext; + + private File tempDir; + private String path; + + @Before + public void setUp() throws IOException { + jsc = new JavaSparkContext("local", "JavaLibSVMRelationSuite"); + sqlContext = new SQLContext(jsc); + + tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource"); + File file = new File(tempDir, "part-00000"); + String s = "1 1:1.0 3:2.0 5:3.0\n0\n0 2:4.0 4:5.0 6:6.0"; + Files.write(s, file, Charsets.US_ASCII); + path = tempDir.toURI().toString(); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + Utils.deleteRecursively(tempDir); + } + + @Test + public void verifyLibSVMDF() { + DataFrame dataset = sqlContext.read().format("libsvm").option("vectorType", "dense") + .load(path); + Assert.assertEquals("label", dataset.columns()[0]); + Assert.assertEquals("features", dataset.columns()[1]); + Row r = dataset.first(); + Assert.assertEquals(1.0, r.getDouble(0), 1e-15); + DenseVector v = r.getAs(1); + Assert.assertEquals(Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0), v); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java new file mode 100644 index 000000000000..01ff1ea65861 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.util; + +import java.io.File; +import java.io.IOException; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.util.Utils; + +public class JavaDefaultReadWriteSuite { + + JavaSparkContext jsc = null; + SQLContext sqlContext = null; + File tempDir = null; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local[2]", "JavaDefaultReadWriteSuite"); + SQLContext.clearActive(); + sqlContext = new SQLContext(jsc); + SQLContext.setActive(sqlContext); + tempDir = Utils.createTempDir( + System.getProperty("java.io.tmpdir"), "JavaDefaultReadWriteSuite"); + } + + @After + public void tearDown() { + sqlContext = null; + SQLContext.clearActive(); + if (jsc != null) { + jsc.stop(); + jsc = null; + } + Utils.deleteRecursively(tempDir); + } + + @Test + public void testDefaultReadWrite() throws IOException { + String uid = "my_params"; + MyParams instance = new MyParams(uid); + instance.set(instance.intParam(), 2); + String outputPath = new File(tempDir, uid).getPath(); + instance.save(outputPath); + try { + instance.save(outputPath); + Assert.fail( + "Write without overwrite enabled should fail if the output directory already exists."); + } catch (IOException e) { + // expected + } + instance.write().context(sqlContext).overwrite().save(outputPath); + MyParams newInstance = MyParams.load(outputPath); + Assert.assertEquals("UID should match.", instance.uid(), newInstance.uid()); + Assert.assertEquals("Params should be preserved.", + 2, newInstance.getOrDefault(newInstance.intParam())); + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java index 55787f8606d4..c9e5ee22f327 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java @@ -18,11 +18,11 @@ package org.apache.spark.mllib.classification; import java.io.Serializable; +import java.util.Arrays; import java.util.List; import scala.Tuple2; -import com.google.common.collect.Lists; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -60,16 +60,16 @@ public void tearDown() { @Test @SuppressWarnings("unchecked") public void javaAPI() { - List trainingBatch = Lists.newArrayList( + List trainingBatch = Arrays.asList( new LabeledPoint(1.0, Vectors.dense(1.0)), new LabeledPoint(0.0, Vectors.dense(0.0))); JavaDStream training = - attachTestInputStream(ssc, Lists.newArrayList(trainingBatch, trainingBatch), 2); - List> testBatch = Lists.newArrayList( + attachTestInputStream(ssc, Arrays.asList(trainingBatch, trainingBatch), 2); + List> testBatch = Arrays.asList( new Tuple2(10, Vectors.dense(1.0)), new Tuple2(11, Vectors.dense(0.0))); JavaPairDStream test = JavaPairDStream.fromJavaDStream( - attachTestInputStream(ssc, Lists.newArrayList(testBatch, testBatch), 2)); + attachTestInputStream(ssc, Arrays.asList(testBatch, testBatch), 2)); StreamingLogisticRegressionWithSGD slr = new StreamingLogisticRegressionWithSGD() .setNumIterations(2) .setInitialWeights(Vectors.dense(0.0)); diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java 
b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java new file mode 100644 index 000000000000..a714620ff7e4 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering; + +import java.io.Serializable; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; + +public class JavaBisectingKMeansSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", this.getClass().getSimpleName()); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void twoDimensionalData() { + JavaRDD points = sc.parallelize(Lists.newArrayList( + Vectors.dense(4, -1), + Vectors.dense(4, 1), + Vectors.sparse(2, new int[] {0}, new double[] {1.0}) + ), 2); + + BisectingKMeans bkm = new BisectingKMeans() + .setK(4) + .setMaxIterations(2) + .setSeed(1L); + BisectingKMeansModel model = bkm.run(points); + Assert.assertEquals(3, model.k()); + Assert.assertArrayEquals(new double[] {3.0, 0.0}, model.root().center().toArray(), 1e-12); + for (ClusteringTreeNode child: model.root().children()) { + double[] center = child.center().toArray(); + if (center[0] > 2) { + Assert.assertEquals(2, child.size()); + Assert.assertArrayEquals(new double[] {4.0, 0.0}, center, 1e-12); + } else { + Assert.assertEquals(1, child.size()); + Assert.assertArrayEquals(new double[] {1.0, 0.0}, center, 1e-12); + } + } + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java index 467a7a69e8f3..123f78da54e3 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java @@ -18,9 +18,9 @@ package org.apache.spark.mllib.clustering; import java.io.Serializable; +import java.util.Arrays; import java.util.List; -import com.google.common.collect.Lists; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -48,7 +48,7 @@ public void tearDown() { @Test public void runGaussianMixture() { - List points = Lists.newArrayList( + List points = Arrays.asList( Vectors.dense(1.0, 2.0, 6.0), Vectors.dense(1.0, 3.0, 0.0), Vectors.dense(1.0, 4.0, 6.0) diff --git 
a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java index 31676e64025d..ad06676c72ac 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.mllib.clustering; import java.io.Serializable; +import java.util.Arrays; import java.util.List; import org.junit.After; @@ -25,8 +26,6 @@ import org.junit.Test; import static org.junit.Assert.*; -import com.google.common.collect.Lists; - import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; @@ -48,7 +47,7 @@ public void tearDown() { @Test public void runKMeansUsingStaticMethods() { - List points = Lists.newArrayList( + List points = Arrays.asList( Vectors.dense(1.0, 2.0, 6.0), Vectors.dense(1.0, 3.0, 0.0), Vectors.dense(1.0, 4.0, 6.0) @@ -67,7 +66,7 @@ public void runKMeansUsingStaticMethods() { @Test public void runKMeansUsingConstructor() { - List points = Lists.newArrayList( + List points = Arrays.asList( Vectors.dense(1.0, 2.0, 6.0), Vectors.dense(1.0, 3.0, 0.0), Vectors.dense(1.0, 4.0, 6.0) @@ -90,7 +89,7 @@ public void runKMeansUsingConstructor() { @Test public void testPredictJavaRDD() { - List points = Lists.newArrayList( + List points = Arrays.asList( Vectors.dense(1.0, 2.0, 6.0), Vectors.dense(1.0, 3.0, 0.0), Vectors.dense(1.0, 4.0, 6.0) diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java index d272a42c8576..225a216270b3 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -22,12 +22,14 @@ import java.util.Arrays; import scala.Tuple2; +import scala.Tuple3; import org.junit.After; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertArrayEquals; import org.junit.Before; import org.junit.Test; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.JavaPairRDD; @@ -44,9 +46,9 @@ public class JavaLDASuite implements Serializable { public void setUp() { sc = new JavaSparkContext("local", "JavaLDA"); ArrayList> tinyCorpus = new ArrayList>(); - for (int i = 0; i < LDASuite$.MODULE$.tinyCorpus().length; i++) { - tinyCorpus.add(new Tuple2((Long)LDASuite$.MODULE$.tinyCorpus()[i]._1(), - LDASuite$.MODULE$.tinyCorpus()[i]._2())); + for (int i = 0; i < LDASuite.tinyCorpus().length; i++) { + tinyCorpus.add(new Tuple2((Long)LDASuite.tinyCorpus()[i]._1(), + LDASuite.tinyCorpus()[i]._2())); } JavaRDD> tmpCorpus = sc.parallelize(tinyCorpus, 2); corpus = JavaPairRDD.fromJavaRDD(tmpCorpus); @@ -60,7 +62,7 @@ public void tearDown() { @Test public void localLDAModel() { - Matrix topics = LDASuite$.MODULE$.tinyTopics(); + Matrix topics = LDASuite.tinyTopics(); double[] topicConcentration = new double[topics.numRows()]; Arrays.fill(topicConcentration, 1.0D / topics.numRows()); LocalLDAModel model = new LocalLDAModel(topics, Vectors.dense(topicConcentration), 1D, 100D); @@ -110,8 +112,8 @@ public void distributedLDAModel() { assertEquals(roundedLocalTopicSummary.length, k); // Check: log probabilities - 
assert(model.logLikelihood() < 0.0); - assert(model.logPrior() < 0.0); + assertTrue(model.logLikelihood() < 0.0); + assertTrue(model.logPrior() < 0.0); // Check: topic distributions JavaPairRDD topicDistributions = model.javaTopicDistributions(); @@ -124,10 +126,25 @@ public Boolean call(Tuple2 tuple2) { } }); assertEquals(topicDistributions.count(), nonEmptyCorpus.count()); + + // Check: javaTopTopicsPerDocuments + Tuple3 topTopics = model.javaTopTopicsPerDocument(3).first(); + Long docId = topTopics._1(); // confirm doc ID type + int[] topicIndices = topTopics._2(); + double[] topicWeights = topTopics._3(); + assertEquals(3, topicIndices.length); + assertEquals(3, topicWeights.length); + + // Check: topTopicAssignments + Tuple3 topicAssignment = model.javaTopicAssignments().first(); + Long docId2 = topicAssignment._1(); + int[] termIndices2 = topicAssignment._2(); + int[] topicIndices2 = topicAssignment._3(); + assertEquals(termIndices2.length, topicIndices2.length); } @Test - public void OnlineOptimizerCompatibility() { + public void onlineOptimizerCompatibility() { int k = 3; double topicSmoothing = 1.2; double termSmoothing = 1.2; @@ -160,11 +177,31 @@ public void OnlineOptimizerCompatibility() { assertEquals(roundedLocalTopicSummary.length, k); } - private static int tinyK = LDASuite$.MODULE$.tinyK(); - private static int tinyVocabSize = LDASuite$.MODULE$.tinyVocabSize(); - private static Matrix tinyTopics = LDASuite$.MODULE$.tinyTopics(); + @Test + public void localLdaMethods() { + JavaRDD> docs = sc.parallelize(toyData, 2); + JavaPairRDD pairedDocs = JavaPairRDD.fromJavaRDD(docs); + + // check: topicDistributions + assertEquals(toyModel.topicDistributions(pairedDocs).count(), pairedDocs.count()); + + // check: logPerplexity + double logPerplexity = toyModel.logPerplexity(pairedDocs); + + // check: logLikelihood. 
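+    // (Like the logPerplexity check above, only Java compatibility is exercised
+    //  here; the returned value is not compared against a reference.)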
+ ArrayList> docsSingleWord = new ArrayList>(); + docsSingleWord.add(new Tuple2(0L, Vectors.dense(1.0, 0.0, 0.0))); + JavaPairRDD single = JavaPairRDD.fromJavaRDD(sc.parallelize(docsSingleWord)); + double logLikelihood = toyModel.logLikelihood(single); + } + + private static int tinyK = LDASuite.tinyK(); + private static int tinyVocabSize = LDASuite.tinyVocabSize(); + private static Matrix tinyTopics = LDASuite.tinyTopics(); private static Tuple2[] tinyTopicDescription = - LDASuite$.MODULE$.tinyTopicDescription(); + LDASuite.tinyTopicDescription(); private JavaPairRDD corpus; + private LocalLDAModel toyModel = LDASuite.toyModel(); + private ArrayList> toyData = LDASuite.javaToyData(); } diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java index 3b0e879eec77..d644766d1e54 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java @@ -18,11 +18,11 @@ package org.apache.spark.mllib.clustering; import java.io.Serializable; +import java.util.Arrays; import java.util.List; import scala.Tuple2; -import com.google.common.collect.Lists; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -60,16 +60,16 @@ public void tearDown() { @Test @SuppressWarnings("unchecked") public void javaAPI() { - List trainingBatch = Lists.newArrayList( + List trainingBatch = Arrays.asList( Vectors.dense(1.0), Vectors.dense(0.0)); JavaDStream training = - attachTestInputStream(ssc, Lists.newArrayList(trainingBatch, trainingBatch), 2); - List> testBatch = Lists.newArrayList( + attachTestInputStream(ssc, Arrays.asList(trainingBatch, trainingBatch), 2); + List> testBatch = Arrays.asList( new Tuple2(10, Vectors.dense(1.0)), new Tuple2(11, Vectors.dense(0.0))); JavaPairDStream test = JavaPairDStream.fromJavaDStream( - attachTestInputStream(ssc, Lists.newArrayList(testBatch, testBatch), 2)); + attachTestInputStream(ssc, Arrays.asList(testBatch, testBatch), 2)); StreamingKMeans skmeans = new StreamingKMeans() .setK(1) .setDecayFactor(1.0) diff --git a/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java index effc8a1a6dab..fa4d334801ce 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java @@ -18,12 +18,12 @@ package org.apache.spark.mllib.evaluation; import java.io.Serializable; -import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; import scala.Tuple2; import scala.Tuple2$; -import com.google.common.collect.Lists; import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -34,18 +34,18 @@ public class JavaRankingMetricsSuite implements Serializable { private transient JavaSparkContext sc; - private transient JavaRDD, ArrayList>> predictionAndLabels; + private transient JavaRDD, List>> predictionAndLabels; @Before public void setUp() { sc = new JavaSparkContext("local", "JavaRankingMetricsSuite"); - predictionAndLabels = sc.parallelize(Lists.newArrayList( + predictionAndLabels = sc.parallelize(Arrays.asList( Tuple2$.MODULE$.apply( - Lists.newArrayList(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Lists.newArrayList(1, 2, 3, 4, 5)), + Arrays.asList(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), 
Arrays.asList(1, 2, 3, 4, 5)), Tuple2$.MODULE$.apply( - Lists.newArrayList(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Lists.newArrayList(1, 2, 3)), + Arrays.asList(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Arrays.asList(1, 2, 3)), Tuple2$.MODULE$.apply( - Lists.newArrayList(1, 2, 3, 4, 5), Lists.newArrayList())), 2); + Arrays.asList(1, 2, 3, 4, 5), Arrays.asList())), 2); } @After diff --git a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java index fbc26167ce66..8a320afa4b13 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java @@ -18,14 +18,13 @@ package org.apache.spark.mllib.feature; import java.io.Serializable; -import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import com.google.common.collect.Lists; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -50,10 +49,10 @@ public void tfIdf() { // The tests are to check Java compatibility. HashingTF tf = new HashingTF(); @SuppressWarnings("unchecked") - JavaRDD> documents = sc.parallelize(Lists.newArrayList( - Lists.newArrayList("this is a sentence".split(" ")), - Lists.newArrayList("this is another sentence".split(" ")), - Lists.newArrayList("this is still a sentence".split(" "))), 2); + JavaRDD> documents = sc.parallelize(Arrays.asList( + Arrays.asList("this is a sentence".split(" ")), + Arrays.asList("this is another sentence".split(" ")), + Arrays.asList("this is still a sentence".split(" "))), 2); JavaRDD termFreqs = tf.transform(documents); termFreqs.collect(); IDF idf = new IDF(); @@ -70,10 +69,10 @@ public void tfIdfMinimumDocumentFrequency() { // The tests are to check Java compatibility. HashingTF tf = new HashingTF(); @SuppressWarnings("unchecked") - JavaRDD> documents = sc.parallelize(Lists.newArrayList( - Lists.newArrayList("this is a sentence".split(" ")), - Lists.newArrayList("this is another sentence".split(" ")), - Lists.newArrayList("this is still a sentence".split(" "))), 2); + JavaRDD> documents = sc.parallelize(Arrays.asList( + Arrays.asList("this is a sentence".split(" ")), + Arrays.asList("this is another sentence".split(" ")), + Arrays.asList("this is still a sentence".split(" "))), 2); JavaRDD termFreqs = tf.transform(documents); termFreqs.collect(); IDF idf = new IDF(2); diff --git a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java index fb7afe8c6434..e13ed07e283d 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java @@ -18,11 +18,11 @@ package org.apache.spark.mllib.feature; import java.io.Serializable; +import java.util.Arrays; import java.util.List; import scala.Tuple2; -import com.google.common.collect.Lists; import com.google.common.base.Strings; import org.junit.After; import org.junit.Assert; @@ -51,8 +51,8 @@ public void tearDown() { public void word2Vec() { // The tests are to check Java compatibility. 
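    // Guava's Strings.repeat builds a corpus of 100 "a b" pairs and 10 "a c" pairs,
    // giving Word2Vec enough co-occurrences of "a" with "b" and "c" to train on.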
String sentence = Strings.repeat("a b ", 100) + Strings.repeat("a c ", 10); - List words = Lists.newArrayList(sentence.split(" ")); - List> localDoc = Lists.newArrayList(words, words); + List words = Arrays.asList(sentence.split(" ")); + List> localDoc = Arrays.asList(words, words); JavaRDD> doc = sc.parallelize(localDoc); Word2Vec word2vec = new Word2Vec() .setVectorSize(10) diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java index b3815ae6039c..2bef7a860975 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java @@ -17,17 +17,16 @@ package org.apache.spark.mllib.fpm; import java.io.Serializable; +import java.util.Arrays; import org.junit.After; import org.junit.Before; import org.junit.Test; -import com.google.common.collect.Lists; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset; - public class JavaAssociationRulesSuite implements Serializable { private transient JavaSparkContext sc; @@ -46,10 +45,10 @@ public void tearDown() { public void runAssociationRules() { @SuppressWarnings("unchecked") - JavaRDD> freqItemsets = sc.parallelize(Lists.newArrayList( + JavaRDD> freqItemsets = sc.parallelize(Arrays.asList( new FreqItemset(new String[] {"a"}, 15L), new FreqItemset(new String[] {"b"}, 35L), - new FreqItemset(new String[] {"a", "b"}, 18L) + new FreqItemset(new String[] {"a", "b"}, 12L) )); JavaRDD> results = (new AssociationRules()).run(freqItemsets); diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java index 9ce2c52dca8b..154f75d75e4a 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java @@ -18,13 +18,12 @@ package org.apache.spark.mllib.fpm; import java.io.Serializable; -import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import org.junit.After; import org.junit.Before; import org.junit.Test; -import com.google.common.collect.Lists; import static org.junit.Assert.*; import org.apache.spark.api.java.JavaRDD; @@ -48,13 +47,13 @@ public void tearDown() { public void runFPGrowth() { @SuppressWarnings("unchecked") - JavaRDD> rdd = sc.parallelize(Lists.newArrayList( - Lists.newArrayList("r z h k p".split(" ")), - Lists.newArrayList("z y x w v u t s".split(" ")), - Lists.newArrayList("s x o n r".split(" ")), - Lists.newArrayList("x z y m t s q e".split(" ")), - Lists.newArrayList("z".split(" ")), - Lists.newArrayList("x z y r q t p".split(" "))), 2); + JavaRDD> rdd = sc.parallelize(Arrays.asList( + Arrays.asList("r z h k p".split(" ")), + Arrays.asList("z y x w v u t s".split(" ")), + Arrays.asList("s x o n r".split(" ")), + Arrays.asList("x z y m t s q e".split(" ")), + Arrays.asList("z".split(" ")), + Arrays.asList("x z y r q t p".split(" "))), 2); FPGrowthModel model = new FPGrowth() .setMinSupport(0.5) diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java index 3349c5022423..8beea102efd0 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java +++ 
b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java @@ -80,10 +80,10 @@ public void diagonalMatrixConstruction() { assertArrayEquals(sd.toArray(), s.toArray(), 0.0); assertArrayEquals(s.toArray(), ss.toArray(), 0.0); assertArrayEquals(s.values(), ss.values(), 0.0); - assert(s.values().length == 2); - assert(ss.values().length == 2); - assert(s.colPtrs().length == 4); - assert(ss.colPtrs().length == 4); + assertEquals(2, s.values().length); + assertEquals(2, ss.values().length); + assertEquals(4, s.colPtrs().length); + assertEquals(4, ss.colPtrs().length); } @Test @@ -137,27 +137,27 @@ public void concatenateMatrices() { Matrix deHorz2 = Matrices.horzcat(new Matrix[]{spMat1, deMat2}); Matrix deHorz3 = Matrices.horzcat(new Matrix[]{deMat1, spMat2}); - assert(deHorz1.numRows() == 3); - assert(deHorz2.numRows() == 3); - assert(deHorz3.numRows() == 3); - assert(spHorz.numRows() == 3); - assert(deHorz1.numCols() == 5); - assert(deHorz2.numCols() == 5); - assert(deHorz3.numCols() == 5); - assert(spHorz.numCols() == 5); + assertEquals(3, deHorz1.numRows()); + assertEquals(3, deHorz2.numRows()); + assertEquals(3, deHorz3.numRows()); + assertEquals(3, spHorz.numRows()); + assertEquals(5, deHorz1.numCols()); + assertEquals(5, deHorz2.numCols()); + assertEquals(5, deHorz3.numCols()); + assertEquals(5, spHorz.numCols()); Matrix spVert = Matrices.vertcat(new Matrix[]{spMat1, spMat3}); Matrix deVert1 = Matrices.vertcat(new Matrix[]{deMat1, deMat3}); Matrix deVert2 = Matrices.vertcat(new Matrix[]{spMat1, deMat3}); Matrix deVert3 = Matrices.vertcat(new Matrix[]{deMat1, spMat3}); - assert(deVert1.numRows() == 5); - assert(deVert2.numRows() == 5); - assert(deVert3.numRows() == 5); - assert(spVert.numRows() == 5); - assert(deVert1.numCols() == 2); - assert(deVert2.numCols() == 2); - assert(deVert3.numCols() == 2); - assert(spVert.numCols() == 2); + assertEquals(5, deVert1.numRows()); + assertEquals(5, deVert2.numRows()); + assertEquals(5, deVert3.numRows()); + assertEquals(5, spVert.numRows()); + assertEquals(2, deVert1.numCols()); + assertEquals(2, deVert2.numCols()); + assertEquals(2, deVert3.numCols()); + assertEquals(2, spVert.numCols()); } } diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java index 1421067dc61e..77c8c6274f37 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java @@ -18,11 +18,10 @@ package org.apache.spark.mllib.linalg; import java.io.Serializable; +import java.util.Arrays; import scala.Tuple2; -import com.google.common.collect.Lists; - import org.junit.Test; import static org.junit.Assert.*; @@ -37,7 +36,7 @@ public void denseArrayConstruction() { @Test public void sparseArrayConstruction() { @SuppressWarnings("unchecked") - Vector v = Vectors.sparse(3, Lists.>newArrayList( + Vector v = Vectors.sparse(3, Arrays.asList( new Tuple2(0, 2.0), new Tuple2(2, 3.0))); assertArrayEquals(new double[]{2.0, 0.0, 3.0}, v.toArray(), 0.0); diff --git a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java index fcc13c00cbdc..5728df5aeebd 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java @@ -17,7 +17,9 @@ package org.apache.spark.mllib.random; -import 
com.google.common.collect.Lists; +import java.io.Serializable; +import java.util.Arrays; + import org.apache.spark.api.java.JavaRDD; import org.junit.Assert; import org.junit.After; @@ -51,7 +53,7 @@ public void testUniformRDD() { JavaDoubleRDD rdd1 = uniformJavaRDD(sc, m); JavaDoubleRDD rdd2 = uniformJavaRDD(sc, m, p); JavaDoubleRDD rdd3 = uniformJavaRDD(sc, m, p, seed); - for (JavaDoubleRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); } } @@ -64,7 +66,7 @@ public void testNormalRDD() { JavaDoubleRDD rdd1 = normalJavaRDD(sc, m); JavaDoubleRDD rdd2 = normalJavaRDD(sc, m, p); JavaDoubleRDD rdd3 = normalJavaRDD(sc, m, p, seed); - for (JavaDoubleRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); } } @@ -79,7 +81,7 @@ public void testLNormalRDD() { JavaDoubleRDD rdd1 = logNormalJavaRDD(sc, mean, std, m); JavaDoubleRDD rdd2 = logNormalJavaRDD(sc, mean, std, m, p); JavaDoubleRDD rdd3 = logNormalJavaRDD(sc, mean, std, m, p, seed); - for (JavaDoubleRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); } } @@ -93,7 +95,7 @@ public void testPoissonRDD() { JavaDoubleRDD rdd1 = poissonJavaRDD(sc, mean, m); JavaDoubleRDD rdd2 = poissonJavaRDD(sc, mean, m, p); JavaDoubleRDD rdd3 = poissonJavaRDD(sc, mean, m, p, seed); - for (JavaDoubleRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); } } @@ -107,7 +109,7 @@ public void testExponentialRDD() { JavaDoubleRDD rdd1 = exponentialJavaRDD(sc, mean, m); JavaDoubleRDD rdd2 = exponentialJavaRDD(sc, mean, m, p); JavaDoubleRDD rdd3 = exponentialJavaRDD(sc, mean, m, p, seed); - for (JavaDoubleRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); } } @@ -122,7 +124,7 @@ public void testGammaRDD() { JavaDoubleRDD rdd1 = gammaJavaRDD(sc, shape, scale, m); JavaDoubleRDD rdd2 = gammaJavaRDD(sc, shape, scale, m, p); JavaDoubleRDD rdd3 = gammaJavaRDD(sc, shape, scale, m, p, seed); - for (JavaDoubleRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); } } @@ -138,7 +140,7 @@ public void testUniformVectorRDD() { JavaRDD rdd1 = uniformJavaVectorRDD(sc, m, n); JavaRDD rdd2 = uniformJavaVectorRDD(sc, m, n, p); JavaRDD rdd3 = uniformJavaVectorRDD(sc, m, n, p, seed); - for (JavaRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); Assert.assertEquals(n, rdd.first().size()); } @@ -154,7 +156,7 @@ public void testNormalVectorRDD() { JavaRDD rdd1 = normalJavaVectorRDD(sc, m, n); JavaRDD rdd2 = normalJavaVectorRDD(sc, m, n, p); JavaRDD rdd3 = normalJavaVectorRDD(sc, m, n, p, seed); - for (JavaRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); Assert.assertEquals(n, rdd.first().size()); } @@ -172,7 +174,7 @@ public void testLogNormalVectorRDD() { JavaRDD rdd1 = logNormalJavaVectorRDD(sc, mean, std, m, n); JavaRDD rdd2 = logNormalJavaVectorRDD(sc, mean, std, m, n, p); JavaRDD rdd3 = logNormalJavaVectorRDD(sc, mean, std, m, n, p, seed); - for (JavaRDD 
rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); Assert.assertEquals(n, rdd.first().size()); } @@ -189,7 +191,7 @@ public void testPoissonVectorRDD() { JavaRDD rdd1 = poissonJavaVectorRDD(sc, mean, m, n); JavaRDD rdd2 = poissonJavaVectorRDD(sc, mean, m, n, p); JavaRDD rdd3 = poissonJavaVectorRDD(sc, mean, m, n, p, seed); - for (JavaRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); Assert.assertEquals(n, rdd.first().size()); } @@ -206,7 +208,7 @@ public void testExponentialVectorRDD() { JavaRDD rdd1 = exponentialJavaVectorRDD(sc, mean, m, n); JavaRDD rdd2 = exponentialJavaVectorRDD(sc, mean, m, n, p); JavaRDD rdd3 = exponentialJavaVectorRDD(sc, mean, m, n, p, seed); - for (JavaRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); Assert.assertEquals(n, rdd.first().size()); } @@ -224,10 +226,56 @@ public void testGammaVectorRDD() { JavaRDD rdd1 = gammaJavaVectorRDD(sc, shape, scale, m, n); JavaRDD rdd2 = gammaJavaVectorRDD(sc, shape, scale, m, n, p); JavaRDD rdd3 = gammaJavaVectorRDD(sc, shape, scale, m, n, p, seed); - for (JavaRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); Assert.assertEquals(n, rdd.first().size()); } } + @Test + public void testArbitrary() { + long size = 10; + long seed = 1L; + int numPartitions = 0; + StringGenerator gen = new StringGenerator(); + JavaRDD rdd1 = randomJavaRDD(sc, gen, size); + JavaRDD rdd2 = randomJavaRDD(sc, gen, size, numPartitions); + JavaRDD rdd3 = randomJavaRDD(sc, gen, size, numPartitions, seed); + for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + Assert.assertEquals(size, rdd.count()); + Assert.assertEquals(2, rdd.first().length()); + } + } + + @Test + @SuppressWarnings("unchecked") + public void testRandomVectorRDD() { + UniformGenerator generator = new UniformGenerator(); + long m = 100L; + int n = 10; + int p = 2; + long seed = 1L; + JavaRDD rdd1 = randomJavaVectorRDD(sc, generator, m, n); + JavaRDD rdd2 = randomJavaVectorRDD(sc, generator, m, n, p); + JavaRDD rdd3 = randomJavaVectorRDD(sc, generator, m, n, p, seed); + for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + Assert.assertEquals(m, rdd.count()); + Assert.assertEquals(n, rdd.first().size()); + } + } +} + +// This is just a test generator, it always returns a string of 42 +class StringGenerator implements RandomDataGenerator, Serializable { + @Override + public String nextValue() { + return "42"; + } + @Override + public StringGenerator copy() { + return new StringGenerator(); + } + @Override + public void setSeed(long seed) { + } } diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java index af688c504cf1..271dda4662e0 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java @@ -18,12 +18,12 @@ package org.apache.spark.mllib.recommendation; import java.io.Serializable; +import java.util.ArrayList; import java.util.List; import scala.Tuple2; import scala.Tuple3; -import com.google.common.collect.Lists; import org.jblas.DoubleMatrix; import org.junit.After; import org.junit.Assert; @@ -56,8 +56,7 @@ void 
validatePrediction( double matchThreshold, boolean implicitPrefs, DoubleMatrix truePrefs) { - List> localUsersProducts = - Lists.newArrayListWithCapacity(users * products); + List> localUsersProducts = new ArrayList(users * products); for (int u=0; u < users; ++u) { for (int p=0; p < products; ++p) { localUsersProducts.add(new Tuple2(u, p)); diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java index d38fc91ace3c..32c2f4f3395b 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java @@ -18,11 +18,12 @@ package org.apache.spark.mllib.regression; import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import scala.Tuple3; -import com.google.common.collect.Lists; import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -36,7 +37,7 @@ public class JavaIsotonicRegressionSuite implements Serializable { private transient JavaSparkContext sc; private List> generateIsotonicInput(double[] labels) { - List> input = Lists.newArrayList(); + ArrayList> input = new ArrayList(labels.length); for (int i = 1; i <= labels.length; i++) { input.add(new Tuple3(labels[i-1], (double) i, 1d)); @@ -77,7 +78,7 @@ public void testIsotonicRegressionPredictionsJavaRDD() { IsotonicRegressionModel model = runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12}); - JavaDoubleRDD testRDD = sc.parallelizeDoubles(Lists.newArrayList(0.0, 1.0, 9.5, 12.0, 13.0)); + JavaDoubleRDD testRDD = sc.parallelizeDoubles(Arrays.asList(0.0, 1.0, 9.5, 12.0, 13.0)); List predictions = model.predict(testRDD).collect(); Assert.assertTrue(predictions.get(0) == 1d); diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java index 899c4ea60786..dbf6488d4108 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java @@ -18,11 +18,11 @@ package org.apache.spark.mllib.regression; import java.io.Serializable; +import java.util.Arrays; import java.util.List; import scala.Tuple2; -import com.google.common.collect.Lists; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -59,16 +59,16 @@ public void tearDown() { @Test @SuppressWarnings("unchecked") public void javaAPI() { - List trainingBatch = Lists.newArrayList( + List trainingBatch = Arrays.asList( new LabeledPoint(1.0, Vectors.dense(1.0)), new LabeledPoint(0.0, Vectors.dense(0.0))); JavaDStream training = - attachTestInputStream(ssc, Lists.newArrayList(trainingBatch, trainingBatch), 2); - List> testBatch = Lists.newArrayList( + attachTestInputStream(ssc, Arrays.asList(trainingBatch, trainingBatch), 2); + List> testBatch = Arrays.asList( new Tuple2(10, Vectors.dense(1.0)), new Tuple2(11, Vectors.dense(0.0))); JavaPairDStream test = JavaPairDStream.fromJavaDStream( - attachTestInputStream(ssc, Lists.newArrayList(testBatch, testBatch), 2)); + attachTestInputStream(ssc, Arrays.asList(testBatch, testBatch), 2)); StreamingLinearRegressionWithSGD slr = new StreamingLinearRegressionWithSGD() .setNumIterations(2) 
.setInitialWeights(Vectors.dense(0.0)); diff --git a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java index 62f7f26b7c98..66b2ceacb05f 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java @@ -18,39 +18,94 @@ package org.apache.spark.mllib.stat; import java.io.Serializable; +import java.util.Arrays; +import java.util.List; -import com.google.common.collect.Lists; import org.junit.After; import org.junit.Before; import org.junit.Test; +import static org.apache.spark.streaming.JavaTestUtils.*; import static org.junit.Assert.assertEquals; +import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaDoubleRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.stat.test.BinarySample; +import org.apache.spark.mllib.stat.test.ChiSqTestResult; +import org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult; +import org.apache.spark.mllib.stat.test.StreamingTest; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; public class JavaStatisticsSuite implements Serializable { private transient JavaSparkContext sc; + private transient JavaStreamingContext ssc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaStatistics"); + SparkConf conf = new SparkConf() + .setMaster("local[2]") + .setAppName("JavaStatistics") + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); + sc = new JavaSparkContext(conf); + ssc = new JavaStreamingContext(sc, new Duration(1000)); + ssc.checkpoint("checkpoint"); } @After public void tearDown() { - sc.stop(); + ssc.stop(); + ssc = null; sc = null; } @Test public void testCorr() { - JavaRDD x = sc.parallelize(Lists.newArrayList(1.0, 2.0, 3.0, 4.0)); - JavaRDD y = sc.parallelize(Lists.newArrayList(1.1, 2.2, 3.1, 4.3)); + JavaRDD x = sc.parallelize(Arrays.asList(1.0, 2.0, 3.0, 4.0)); + JavaRDD y = sc.parallelize(Arrays.asList(1.1, 2.2, 3.1, 4.3)); Double corr1 = Statistics.corr(x, y); Double corr2 = Statistics.corr(x, y, "pearson"); // Check default method assertEquals(corr1, corr2); } + + @Test + public void kolmogorovSmirnovTest() { + JavaDoubleRDD data = sc.parallelizeDoubles(Arrays.asList(0.2, 1.0, -1.0, 2.0)); + KolmogorovSmirnovTestResult testResult1 = Statistics.kolmogorovSmirnovTest(data, "norm"); + KolmogorovSmirnovTestResult testResult2 = Statistics.kolmogorovSmirnovTest( + data, "norm", 0.0, 1.0); + } + + @Test + public void chiSqTest() { + JavaRDD data = sc.parallelize(Arrays.asList( + new LabeledPoint(0.0, Vectors.dense(0.1, 2.3)), + new LabeledPoint(1.0, Vectors.dense(1.5, 5.1)), + new LabeledPoint(0.0, Vectors.dense(2.4, 8.1)))); + ChiSqTestResult[] testResults = Statistics.chiSqTest(data); + } + + @Test + public void streamingTest() { + List trainingBatch = Arrays.asList( + new BinarySample(true, 1.0), + new BinarySample(false, 2.0)); + JavaDStream training = + attachTestInputStream(ssc, Arrays.asList(trainingBatch, trainingBatch), 2); + int numBatches = 2; + StreamingTest model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(0) + .setTestMethod("welch"); + model.registerStream(training); + 
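+    // Assumed reading of the configuration above: windowSize = 0 keeps every batch,
+    // peacePeriod = 0 skips no warm-up batches, and "welch" selects Welch's t-test
+    // over the (isExperiment, value) pairs carried by BinarySample.
+    // attachTestOutputStream and runStreams are the statically imported
+    // JavaTestUtils helpers that drive the ManualClock configured in setUp().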
attachTestOutputStream(training); + runStreams(ssc, numBatches, numBatches); + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 63d2fa31c749..8c8676745636 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -19,16 +19,21 @@ package org.apache.spark.ml import scala.collection.JavaConverters._ +import org.apache.hadoop.fs.Path import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito.when import org.scalatest.mock.MockitoSugar.mock import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.Pipeline.SharedReadWrite import org.apache.spark.ml.feature.HashingTF -import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.{IntParam, ParamMap} +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.types.StructType -class PipelineSuite extends SparkFunSuite { +class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { abstract class MyModel extends Model[MyModel] @@ -65,6 +70,8 @@ class PipelineSuite extends SparkFunSuite { .setStages(Array(estimator0, transformer1, estimator2, transformer3)) val pipelineModel = pipeline.fit(dataset0) + MLTestingUtils.checkCopy(pipelineModel) + assert(pipelineModel.stages.length === 4) assert(pipelineModel.stages(0).eq(model0)) assert(pipelineModel.stages(1).eq(transformer1)) @@ -108,4 +115,105 @@ class PipelineSuite extends SparkFunSuite { assert(pipelineModel1.uid === "pipeline1") assert(pipelineModel1.stages === stages) } + + test("Pipeline read/write") { + val writableStage = new WritableStage("writableStage").setIntParam(56) + val pipeline = new Pipeline().setStages(Array(writableStage)) + + val pipeline2 = testDefaultReadWrite(pipeline, testParams = false) + assert(pipeline2.getStages.length === 1) + assert(pipeline2.getStages(0).isInstanceOf[WritableStage]) + val writableStage2 = pipeline2.getStages(0).asInstanceOf[WritableStage] + assert(writableStage.getIntParam === writableStage2.getIntParam) + } + + test("Pipeline read/write with non-Writable stage") { + val unWritableStage = new UnWritableStage("unwritableStage") + val unWritablePipeline = new Pipeline().setStages(Array(unWritableStage)) + withClue("Pipeline.write should fail when Pipeline contains non-Writable stage") { + intercept[UnsupportedOperationException] { + unWritablePipeline.write + } + } + } + + test("PipelineModel read/write") { + val writableStage = new WritableStage("writableStage").setIntParam(56) + val pipeline = + new PipelineModel("pipeline_89329327", Array(writableStage.asInstanceOf[Transformer])) + + val pipeline2 = testDefaultReadWrite(pipeline, testParams = false) + assert(pipeline2.stages.length === 1) + assert(pipeline2.stages(0).isInstanceOf[WritableStage]) + val writableStage2 = pipeline2.stages(0).asInstanceOf[WritableStage] + assert(writableStage.getIntParam === writableStage2.getIntParam) + } + + test("PipelineModel read/write: getStagePath") { + val stageUid = "myStage" + val stagesDir = new Path("pipeline", "stages").toString + def testStage(stageIdx: Int, numStages: Int, expectedPrefix: String): Unit = { + val path = SharedReadWrite.getStagePath(stageUid, stageIdx, numStages, stagesDir) + val expected = new Path(stagesDir, expectedPrefix + "_" + stageUid).toString + assert(path === expected) + } + 
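+    // The expected prefixes below show the stage index zero-padded to the number of
+    // digits in numStages, so the per-stage directories sort in pipeline order.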
testStage(0, 1, "0") + testStage(0, 9, "0") + testStage(0, 10, "00") + testStage(1, 10, "01") + testStage(12, 999, "012") + } + + test("PipelineModel read/write with non-Writable stage") { + val unWritableStage = new UnWritableStage("unwritableStage") + val unWritablePipeline = + new PipelineModel("pipeline_328957", Array(unWritableStage.asInstanceOf[Transformer])) + withClue("PipelineModel.write should fail when PipelineModel contains non-Writable stage") { + intercept[UnsupportedOperationException] { + unWritablePipeline.write + } + } + } +} + + +/** Used to test [[Pipeline]] with [[MLWritable]] stages */ +class WritableStage(override val uid: String) extends Transformer with MLWritable { + + final val intParam: IntParam = new IntParam(this, "intParam", "doc") + + def getIntParam: Int = $(intParam) + + def setIntParam(value: Int): this.type = set(intParam, value) + + setDefault(intParam -> 0) + + override def copy(extra: ParamMap): WritableStage = defaultCopy(extra) + + override def write: MLWriter = new DefaultParamsWriter(this) + + override def transform(dataset: DataFrame): DataFrame = dataset + + override def transformSchema(schema: StructType): StructType = schema +} + +object WritableStage extends MLReadable[WritableStage] { + + override def read: MLReader[WritableStage] = new DefaultParamsReader[WritableStage] + + override def load(path: String): WritableStage = super.load(path) +} + +/** Used to test [[Pipeline]] with non-[[MLWritable]] stages */ +class UnWritableStage(override val uid: String) extends Transformer { + + final val intParam: IntParam = new IntParam(this, "intParam", "doc") + + setDefault(intParam -> 0) + + override def copy(extra: ParamMap): UnWritableStage = defaultCopy(extra) + + override def transform(dataset: DataFrame): DataFrame = dataset + + override def transformSchema(schema: StructType): StructType = schema } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala new file mode 100644 index 000000000000..d0e3fe7ad14b --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +object ClassifierSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. 
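+ * For example, a save/load test can set each (param name, value) pair below on a
+ * classifier and assert that the reloaded instance returns the same non-default values.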
+ */ + val allParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPrediction", + "rawPredictionCol" -> "myRawPrediction" + ) + +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index c7bbf1ce07a2..fda2711fed0f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.tree.LeafNode +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} @@ -58,7 +59,7 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte test("params") { ParamsSuite.checkParams(new DecisionTreeClassifier) - val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2) + val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 1, 2) ParamsSuite.checkParams(model) } @@ -71,7 +72,8 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte .setImpurity("gini") .setMaxDepth(2) .setMaxBins(100) - val categoricalFeatures = Map(0 -> 3, 1-> 3) + .setSeed(1) + val categoricalFeatures = Map(0 -> 3, 1 -> 3) val numClasses = 2 compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures, numClasses) } @@ -212,7 +214,7 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte .setMaxBins(2) .setMaxDepth(2) .setMinInstancesPerNode(2) - val categoricalFeatures = Map(0 -> 2, 1-> 2) + val categoricalFeatures = Map(0 -> 2, 1 -> 2) val numClasses = 2 compareAPIs(rdd, dt, categoricalFeatures, numClasses) } @@ -244,6 +246,9 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte val newData: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses) val newTree = dt.fit(newData) + // copied model must have the same parent. 
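+    // (checkCopy is assumed to call copy(ParamMap.empty) and verify the copy keeps
+    //  the same parent estimator.)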
+ MLTestingUtils.checkCopy(newTree) + val predictions = newTree.transform(newData) .select(newTree.getPredictionCol, newTree.getRawPredictionCol, newTree.getProbabilityCol) .collect() @@ -257,6 +262,19 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte } } + test("training with 1-category categorical feature") { + val data = sc.parallelize(Seq( + LabeledPoint(0, Vectors.dense(0, 2, 3)), + LabeledPoint(1, Vectors.dense(0, 3, 1)), + LabeledPoint(0, Vectors.dense(0, 2, 2)), + LabeledPoint(1, Vectors.dense(0, 3, 9)), + LabeledPoint(0, Vectors.dense(0, 2, 6)) + )) + val df = TreeTests.setMetadata(data, Map(0 -> 1), 2) + val dt = new DecisionTreeClassifier().setMaxDepth(3) + val model = dt.fit(df) + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// @@ -293,6 +311,7 @@ private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite { dt: DecisionTreeClassifier, categoricalFeatures: Map[Int, Int], numClasses: Int): Unit = { + val numFeatures = data.first().features.size val oldStrategy = dt.getOldStrategy(categoricalFeatures, numClasses) val oldTree = OldDecisionTree.train(data, oldStrategy) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) @@ -301,5 +320,6 @@ private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite { val oldTreeAsNew = DecisionTreeClassificationModel.fromOld( oldTree, newTree.parent.asInstanceOf[DecisionTreeClassifier], categoricalFeatures) TreeTests.checkEqual(oldTreeAsNew, newTree) + assert(newTree.numFeatures === numFeatures) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index d4b5896c12c0..039141aeb6f6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.ml.impl.TreeTests import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree.LeafNode +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} @@ -58,8 +59,8 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { test("params") { ParamsSuite.checkParams(new GBTClassifier) val model = new GBTClassificationModel("gbtc", - Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null))), - Array(1.0)) + Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null), 1)), + Array(1.0), 1) ParamsSuite.checkParams(model) } @@ -92,6 +93,9 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { .setCheckpointInterval(2) val model = gbt.fit(df) + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + sc.checkpointDir = None Utils.deleteRecursively(tempDir) } @@ -141,7 +145,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { */ } -private object GBTClassifierSuite { +private object GBTClassifierSuite extends SparkFunSuite { /** * Train 2 models on the given dataset, one using the old API and one using the new API. 
@@ -152,6 +156,7 @@ private object GBTClassifierSuite { validationData: Option[RDD[LabeledPoint]], gbt: GBTClassifier, categoricalFeatures: Map[Int, Int]): Unit = { + val numFeatures = data.first().features.size val oldBoostingStrategy = gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) val oldGBT = new OldGBT(oldBoostingStrategy) @@ -160,7 +165,9 @@ private object GBTClassifierSuite { val newModel = gbt.fit(newData) // Use parent from newTree since this is not checked anyways. val oldModelAsNew = GBTClassificationModel.fromOld( - oldModel, newModel.parent.asInstanceOf[GBTClassifier], categoricalFeatures) + oldModel, newModel.parent.asInstanceOf[GBTClassifier], categoricalFeatures, numFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) + assert(newModel.numFeatures === numFeatures) + assert(oldModelAsNew.numFeatures === numFeatures) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index b7dd44753896..1087afb0cdf7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -17,15 +17,22 @@ package org.apache.spark.ml.classification +import scala.language.existentials +import scala.util.Random + import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.{Identifiable, DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.classification.LogisticRegressionSuite._ -import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} -class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { +class LogisticRegressionSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @transient var dataset: DataFrame = _ @transient var binaryDataset: DataFrame = _ @@ -42,24 +49,24 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { import org.apache.spark.mllib.classification.LogisticRegressionSuite val nPoints = 10000 - val weights = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191) + val coefficients = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191) val xMean = Array(5.843, 3.057, 3.758, 1.199) val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) val data = sc.parallelize(LogisticRegressionSuite.generateMultinomialLogisticInput( - weights, xMean, xVariance, true, nPoints, 42), 1) + coefficients, xMean, xVariance, true, nPoints, 42), 1) data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1) + ", " + x.features(2) + ", " + x.features(3)).saveAsTextFile("path") */ binaryDataset = { val nPoints = 10000 - val weights = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191) + val coefficients = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191) val xMean = Array(5.843, 3.057, 3.758, 1.199) val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) - val testData = generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42) + val testData = + generateMultinomialLogisticInput(coefficients, xMean, xVariance, true, nPoints, 42) - 
sqlContext.createDataFrame( - generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42)) + sqlContext.createDataFrame(sc.parallelize(testData, 4)) } } @@ -76,6 +83,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(lr.getPredictionCol === "prediction") assert(lr.getRawPredictionCol === "rawPrediction") assert(lr.getProbabilityCol === "probability") + assert(lr.getWeightCol === "") assert(lr.getFitIntercept) assert(lr.getStandardization) val model = lr.fit(dataset) @@ -91,11 +99,64 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(model.hasParent) } + test("empty probabilityCol") { + val lr = new LogisticRegression().setProbabilityCol("") + val model = lr.fit(dataset) + assert(model.hasSummary) + // Validate that we re-insert a probability column for evaluation + val fieldNames = model.summary.predictions.schema.fieldNames + assert((dataset.schema.fieldNames.toSet).subsetOf( + fieldNames.toSet)) + assert(fieldNames.exists(s => s.startsWith("probability_"))) + } + + test("setThreshold, getThreshold") { + val lr = new LogisticRegression + // default + assert(lr.getThreshold === 0.5, "LogisticRegression.threshold should default to 0.5") + withClue("LogisticRegression should not have thresholds set by default.") { + intercept[java.util.NoSuchElementException] { // Note: The exception type may change in future + lr.getThresholds + } + } + // Set via threshold. + // Intuition: Large threshold or large thresholds(1) makes class 0 more likely. + lr.setThreshold(1.0) + assert(lr.getThresholds === Array(0.0, 1.0)) + lr.setThreshold(0.0) + assert(lr.getThresholds === Array(1.0, 0.0)) + lr.setThreshold(0.5) + assert(lr.getThresholds === Array(0.5, 0.5)) + // Set via thresholds + val lr2 = new LogisticRegression + lr2.setThresholds(Array(0.3, 0.7)) + val expectedThreshold = 1.0 / (1.0 + 0.3 / 0.7) + assert(lr2.getThreshold ~== expectedThreshold relTol 1E-7) + // thresholds and threshold must be consistent + lr2.setThresholds(Array(0.1, 0.2, 0.3)) + withClue("getThreshold should throw error if thresholds has length != 2.") { + intercept[IllegalArgumentException] { + lr2.getThreshold + } + } + // thresholds and threshold must be consistent: values + withClue("fit with ParamMap should throw error if threshold, thresholds do not match.") { + intercept[IllegalArgumentException] { + val lr2model = lr2.fit(dataset, + lr2.thresholds -> Array(0.3, 0.7), lr2.threshold -> (expectedThreshold / 2.0)) + lr2model.getThreshold + } + } + } + test("logistic regression doesn't fit intercept when fitIntercept is off") { val lr = new LogisticRegression lr.setFitIntercept(false) val model = lr.fit(dataset) assert(model.intercept === 0.0) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) } test("logistic regression with setters") { @@ -123,14 +184,16 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { s" ${predAllZero.count(_ === 0)} of ${dataset.count()} were 0.") // Call transform with params, and check that the params worked. val predNotAllZero = - model.transform(dataset, model.threshold -> 0.0, model.probabilityCol -> "myProb") + model.transform(dataset, model.threshold -> 0.0, + model.probabilityCol -> "myProb") .select("prediction", "myProb") .collect() .map { case Row(pred: Double, prob: Vector) => pred } assert(predNotAllZero.exists(_ !== 0.0)) // Call fit() with new params, and check as many params as we can. 
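// The "setThreshold, getThreshold" test above exercises the binary threshold <-> thresholds
// conversion. A minimal standalone sketch of that relationship, in pure Scala and matching the
// values asserted in the test:
def thresholdToThresholds(t: Double): Array[Double] = Array(1.0 - t, t)

def thresholdsToThreshold(ts: Array[Double]): Double = {
  require(ts.length == 2, "only defined for binary thresholds")
  1.0 / (1.0 + ts(0) / ts(1)) // equivalently ts(1) / (ts(0) + ts(1))
}
// thresholdToThresholds(1.0) gives Array(0.0, 1.0) and thresholdToThresholds(0.5) gives
// Array(0.5, 0.5), while thresholdsToThreshold(Array(0.3, 0.7)) is approximately 0.7, i.e. the
// test's expectedThreshold (compared there with relTol 1E-7).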
- val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.threshold -> 0.4, + lr.setThresholds(Array(0.6, 0.4)) + val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.probabilityCol -> "theProb") val parent2 = model2.parent.asInstanceOf[LogisticRegression] assert(parent2.getMaxIter === 5) @@ -146,6 +209,8 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val model = lr.fit(dataset) assert(model.numClasses === 2) + val numFeatures = dataset.select("features").first().getAs[Vector](0).size + assert(model.numFeatures === numFeatures) val threshold = model.getThreshold val results = model.transform(dataset) @@ -171,43 +236,65 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { test("MultiClassSummarizer") { val summarizer1 = (new MultiClassSummarizer) .add(0.0).add(3.0).add(4.0).add(3.0).add(6.0) - assert(summarizer1.histogram.zip(Array[Long](1, 0, 0, 2, 1, 0, 1)).forall(x => x._1 === x._2)) + assert(summarizer1.histogram === Array[Double](1, 0, 0, 2, 1, 0, 1)) assert(summarizer1.countInvalid === 0) assert(summarizer1.numClasses === 7) val summarizer2 = (new MultiClassSummarizer) .add(1.0).add(5.0).add(3.0).add(0.0).add(4.0).add(1.0) - assert(summarizer2.histogram.zip(Array[Long](1, 2, 0, 1, 1, 1)).forall(x => x._1 === x._2)) + assert(summarizer2.histogram === Array[Double](1, 2, 0, 1, 1, 1)) assert(summarizer2.countInvalid === 0) assert(summarizer2.numClasses === 6) val summarizer3 = (new MultiClassSummarizer) .add(0.0).add(1.3).add(5.2).add(2.5).add(2.0).add(4.0).add(4.0).add(4.0).add(1.0) - assert(summarizer3.histogram.zip(Array[Long](1, 1, 1, 0, 3)).forall(x => x._1 === x._2)) + assert(summarizer3.histogram === Array[Double](1, 1, 1, 0, 3)) assert(summarizer3.countInvalid === 3) assert(summarizer3.numClasses === 5) val summarizer4 = (new MultiClassSummarizer) .add(3.1).add(4.3).add(2.0).add(1.0).add(3.0) - assert(summarizer4.histogram.zip(Array[Long](0, 1, 1, 1)).forall(x => x._1 === x._2)) + assert(summarizer4.histogram === Array[Double](0, 1, 1, 1)) assert(summarizer4.countInvalid === 2) assert(summarizer4.numClasses === 4) // small map merges large one val summarizerA = summarizer1.merge(summarizer2) assert(summarizerA.hashCode() === summarizer2.hashCode()) - assert(summarizerA.histogram.zip(Array[Long](2, 2, 0, 3, 2, 1, 1)).forall(x => x._1 === x._2)) + assert(summarizerA.histogram === Array[Double](2, 2, 0, 3, 2, 1, 1)) assert(summarizerA.countInvalid === 0) assert(summarizerA.numClasses === 7) // large map merges small one val summarizerB = summarizer3.merge(summarizer4) assert(summarizerB.hashCode() === summarizer3.hashCode()) - assert(summarizerB.histogram.zip(Array[Long](1, 2, 2, 1, 3)).forall(x => x._1 === x._2)) + assert(summarizerB.histogram === Array[Double](1, 2, 2, 1, 3)) assert(summarizerB.countInvalid === 5) assert(summarizerB.numClasses === 5) } + test("MultiClassSummarizer with weighted samples") { + val summarizer1 = (new MultiClassSummarizer) + .add(label = 0.0, weight = 0.2).add(3.0, 0.8).add(4.0, 3.2).add(3.0, 1.3).add(6.0, 3.1) + assert(Vectors.dense(summarizer1.histogram) ~== + Vectors.dense(Array(0.2, 0, 0, 2.1, 3.2, 0, 3.1)) absTol 1E-10) + assert(summarizer1.countInvalid === 0) + assert(summarizer1.numClasses === 7) + + val summarizer2 = (new MultiClassSummarizer) + .add(1.0, 1.1).add(5.0, 2.3).add(3.0).add(0.0).add(4.0).add(1.0).add(2, 0.0) + assert(Vectors.dense(summarizer2.histogram) ~== + Vectors.dense(Array[Double](1.0, 2.1, 0.0, 1, 1, 2.3)) absTol 1E-10) + 
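// A small standalone model of the weighted histogram semantics asserted in these summarizer
// tests: add(label, weight) accumulates weight (1.0 by default) into that label's bucket, and
// non-integral or negative labels are counted as invalid instead. This is an illustrative
// re-implementation for reading the expected arrays, not the MultiClassSummarizer itself.
class WeightedLabelHistogram {
  private val buckets = scala.collection.mutable.Map.empty[Int, Double].withDefaultValue(0.0)
  var countInvalid: Long = 0L

  def add(label: Double, weight: Double = 1.0): this.type = {
    if (label < 0 || label != label.toInt) countInvalid += 1 else buckets(label.toInt) += weight
    this
  }

  def histogram: Array[Double] =
    Array.tabulate(if (buckets.isEmpty) 0 else buckets.keys.max + 1)(buckets(_))
}
// e.g. new WeightedLabelHistogram().add(0.0, 0.2).add(3.0, 0.8).add(4.0, 3.2).add(3.0, 1.3)
//   .add(6.0, 3.1).histogram yields (0.2, 0, 0, 2.1, 3.2, 0, 3.1) up to rounding, matching
//   summarizer1 in the weighted-samples test above.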
assert(summarizer2.countInvalid === 0) + assert(summarizer2.numClasses === 6) + + val summarizer = summarizer1.merge(summarizer2) + assert(Vectors.dense(summarizer.histogram) ~== + Vectors.dense(Array(1.2, 2.1, 0.0, 3.1, 4.2, 2.3, 3.1)) absTol 1E-10) + assert(summarizer.countInvalid === 0) + assert(summarizer.numClasses === 7) + } + test("binary logistic regression with intercept without regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(true).setStandardization(true) val trainer2 = (new LogisticRegression).setFitIntercept(true).setStandardization(false) @@ -222,8 +309,8 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data <- read.csv("path", header=FALSE) label = factor(data$V1) features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0)) - weights + coefficients = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0)) + coefficients 5 x 1 sparse Matrix of class "dgCMatrix" s0 @@ -234,14 +321,14 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.7996864 */ val interceptR = 2.8366423 - val weightsR = Vectors.dense(-0.5895848, 0.8931147, -0.3925051, -0.7996864) + val coefficientsR = Vectors.dense(-0.5895848, 0.8931147, -0.3925051, -0.7996864) assert(model1.intercept ~== interceptR relTol 1E-3) - assert(model1.weights ~= weightsR relTol 1E-3) + assert(model1.coefficients ~= coefficientsR relTol 1E-3) // Without regularization, with or without standardization will converge to the same solution. assert(model2.intercept ~== interceptR relTol 1E-3) - assert(model2.weights ~= weightsR relTol 1E-3) + assert(model2.coefficients ~= coefficientsR relTol 1E-3) } test("binary logistic regression without intercept without regularization") { @@ -258,9 +345,9 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data <- read.csv("path", header=FALSE) label = factor(data$V1) features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - weights = + coefficients = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0, intercept=FALSE)) - weights + coefficients 5 x 1 sparse Matrix of class "dgCMatrix" s0 @@ -271,14 +358,14 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.7407946 */ val interceptR = 0.0 - val weightsR = Vectors.dense(-0.3534996, 1.2964482, -0.3571741, -0.7407946) + val coefficientsR = Vectors.dense(-0.3534996, 1.2964482, -0.3571741, -0.7407946) assert(model1.intercept ~== interceptR relTol 1E-3) - assert(model1.weights ~= weightsR relTol 1E-2) + assert(model1.coefficients ~= coefficientsR relTol 1E-2) // Without regularization, with or without standardization should converge to the same solution. 
assert(model2.intercept ~== interceptR relTol 1E-3) - assert(model2.weights ~= weightsR relTol 1E-2) + assert(model2.coefficients ~= coefficientsR relTol 1E-2) } test("binary logistic regression with intercept with L1 regularization") { @@ -297,8 +384,8 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data <- read.csv("path", header=FALSE) label = factor(data$V1) features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12)) - weights + coefficients = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12)) + coefficients 5 x 1 sparse Matrix of class "dgCMatrix" s0 @@ -309,10 +396,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.02481551 */ val interceptR1 = -0.05627428 - val weightsR1 = Vectors.dense(0.0, 0.0, -0.04325749, -0.02481551) + val coefficientsR1 = Vectors.dense(0.0, 0.0, -0.04325749, -0.02481551) assert(model1.intercept ~== interceptR1 relTol 1E-2) - assert(model1.weights ~= weightsR1 absTol 2E-2) + assert(model1.coefficients ~= coefficientsR1 absTol 2E-2) /* Using the following R code to load the data and train the model using glmnet package. @@ -321,9 +408,9 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data <- read.csv("path", header=FALSE) label = factor(data$V1) features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12, + coefficients = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12, standardize=FALSE)) - weights + coefficients 5 x 1 sparse Matrix of class "dgCMatrix" s0 @@ -334,10 +421,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 . */ val interceptR2 = 0.3722152 - val weightsR2 = Vectors.dense(0.0, 0.0, -0.1665453, 0.0) + val coefficientsR2 = Vectors.dense(0.0, 0.0, -0.1665453, 0.0) assert(model2.intercept ~== interceptR2 relTol 1E-2) - assert(model2.weights ~= weightsR2 absTol 1E-3) + assert(model2.coefficients ~= coefficientsR2 absTol 1E-3) } test("binary logistic regression without intercept with L1 regularization") { @@ -356,9 +443,9 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data <- read.csv("path", header=FALSE) label = factor(data$V1) features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12, + coefficients = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12, intercept=FALSE)) - weights + coefficients 5 x 1 sparse Matrix of class "dgCMatrix" s0 @@ -369,10 +456,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.03891782 */ val interceptR1 = 0.0 - val weightsR1 = Vectors.dense(0.0, 0.0, -0.05189203, -0.03891782) + val coefficientsR1 = Vectors.dense(0.0, 0.0, -0.05189203, -0.03891782) assert(model1.intercept ~== interceptR1 relTol 1E-3) - assert(model1.weights ~= weightsR1 absTol 1E-3) + assert(model1.coefficients ~= coefficientsR1 absTol 1E-3) /* Using the following R code to load the data and train the model using glmnet package. 
@@ -381,9 +468,9 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data <- read.csv("path", header=FALSE) label = factor(data$V1) features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12, + coefficients = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12, intercept=FALSE, standardize=FALSE)) - weights + coefficients 5 x 1 sparse Matrix of class "dgCMatrix" s0 @@ -394,10 +481,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 . */ val interceptR2 = 0.0 - val weightsR2 = Vectors.dense(0.0, 0.0, -0.08420782, 0.0) + val coefficientsR2 = Vectors.dense(0.0, 0.0, -0.08420782, 0.0) assert(model2.intercept ~== interceptR2 absTol 1E-3) - assert(model2.weights ~= weightsR2 absTol 1E-3) + assert(model2.coefficients ~= coefficientsR2 absTol 1E-3) } test("binary logistic regression with intercept with L2 regularization") { @@ -416,8 +503,8 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data <- read.csv("path", header=FALSE) label = factor(data$V1) features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37)) - weights + coefficients = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37)) + coefficients 5 x 1 sparse Matrix of class "dgCMatrix" s0 @@ -428,10 +515,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.10062872 */ val interceptR1 = 0.15021751 - val weightsR1 = Vectors.dense(-0.07251837, 0.10724191, -0.04865309, -0.10062872) + val coefficientsR1 = Vectors.dense(-0.07251837, 0.10724191, -0.04865309, -0.10062872) assert(model1.intercept ~== interceptR1 relTol 1E-3) - assert(model1.weights ~= weightsR1 relTol 1E-3) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-3) /* Using the following R code to load the data and train the model using glmnet package. 
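// For readers cross-checking the R snippets in these tests against the Scala side: the glmnet
// arguments line up with the estimator's params roughly as
//   alpha -> setElasticNetParam (0 = ridge/L2, 1 = lasso/L1), lambda -> setRegParam,
//   intercept -> setFitIntercept, standardize -> setStandardization.
// For example, the L2 case above (alpha = 0, lambda = 1.37) presumably corresponds to a trainer
// configured along these lines (LogisticRegression is org.apache.spark.ml.classification
// .LogisticRegression, already in scope in this suite):
val l2Trainer = new LogisticRegression()
  .setFitIntercept(true)
  .setStandardization(true)
  .setElasticNetParam(0.0)
  .setRegParam(1.37)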
@@ -440,9 +527,9 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data <- read.csv("path", header=FALSE) label = factor(data$V1) features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37, + coefficients = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37, standardize=FALSE)) - weights + coefficients 5 x 1 sparse Matrix of class "dgCMatrix" s0 @@ -453,10 +540,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.06266838 */ val interceptR2 = 0.48657516 - val weightsR2 = Vectors.dense(-0.05155371, 0.02301057, -0.11482896, -0.06266838) + val coefficientsR2 = Vectors.dense(-0.05155371, 0.02301057, -0.11482896, -0.06266838) assert(model2.intercept ~== interceptR2 relTol 1E-3) - assert(model2.weights ~= weightsR2 relTol 1E-3) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-3) } test("binary logistic regression without intercept with L2 regularization") { @@ -475,9 +562,9 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data <- read.csv("path", header=FALSE) label = factor(data$V1) features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37, + coefficients = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37, intercept=FALSE)) - weights + coefficients 5 x 1 sparse Matrix of class "dgCMatrix" s0 @@ -488,10 +575,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.09799775 */ val interceptR1 = 0.0 - val weightsR1 = Vectors.dense(-0.06099165, 0.12857058, -0.04708770, -0.09799775) + val coefficientsR1 = Vectors.dense(-0.06099165, 0.12857058, -0.04708770, -0.09799775) assert(model1.intercept ~== interceptR1 absTol 1E-3) - assert(model1.weights ~= weightsR1 relTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) /* Using the following R code to load the data and train the model using glmnet package. 
@@ -500,9 +587,9 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data <- read.csv("path", header=FALSE) label = factor(data$V1) features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37, + coefficients = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37, intercept=FALSE, standardize=FALSE)) - weights + coefficients 5 x 1 sparse Matrix of class "dgCMatrix" s0 @@ -513,10 +600,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.053314311 */ val interceptR2 = 0.0 - val weightsR2 = Vectors.dense(-0.005679651, 0.048967094, -0.093714016, -0.053314311) + val coefficientsR2 = Vectors.dense(-0.005679651, 0.048967094, -0.093714016, -0.053314311) assert(model2.intercept ~== interceptR2 absTol 1E-3) - assert(model2.weights ~= weightsR2 relTol 1E-2) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) } test("binary logistic regression with intercept with ElasticNet regularization") { @@ -535,8 +622,8 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data <- read.csv("path", header=FALSE) label = factor(data$V1) features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21)) - weights + coefficients = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21)) + coefficients 5 x 1 sparse Matrix of class "dgCMatrix" s0 @@ -547,10 +634,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.15458796 */ val interceptR1 = 0.57734851 - val weightsR1 = Vectors.dense(-0.05310287, 0.0, -0.08849250, -0.15458796) + val coefficientsR1 = Vectors.dense(-0.05310287, 0.0, -0.08849250, -0.15458796) assert(model1.intercept ~== interceptR1 relTol 6E-3) - assert(model1.weights ~== weightsR1 absTol 5E-3) + assert(model1.coefficients ~== coefficientsR1 absTol 5E-3) /* Using the following R code to load the data and train the model using glmnet package. 
@@ -559,9 +646,9 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data <- read.csv("path", header=FALSE) label = factor(data$V1) features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21, + coefficients = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21, standardize=FALSE)) - weights + coefficients 5 x 1 sparse Matrix of class "dgCMatrix" s0 @@ -572,10 +659,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.05350074 */ val interceptR2 = 0.51555993 - val weightsR2 = Vectors.dense(0.0, 0.0, -0.18807395, -0.05350074) + val coefficientsR2 = Vectors.dense(0.0, 0.0, -0.18807395, -0.05350074) assert(model2.intercept ~== interceptR2 relTol 6E-3) - assert(model2.weights ~= weightsR2 absTol 1E-3) + assert(model2.coefficients ~= coefficientsR2 absTol 1E-3) } test("binary logistic regression without intercept with ElasticNet regularization") { @@ -594,9 +681,9 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data <- read.csv("path", header=FALSE) label = factor(data$V1) features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21, + coefficients = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21, intercept=FALSE)) - weights + coefficients 5 x 1 sparse Matrix of class "dgCMatrix" s0 @@ -607,10 +694,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.142534158 */ val interceptR1 = 0.0 - val weightsR1 = Vectors.dense(-0.001005743, 0.072577857, -0.081203769, -0.142534158) + val coefficientsR1 = Vectors.dense(-0.001005743, 0.072577857, -0.081203769, -0.142534158) assert(model1.intercept ~== interceptR1 relTol 1E-3) - assert(model1.weights ~= weightsR1 absTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 absTol 1E-2) /* Using the following R code to load the data and train the model using glmnet package. @@ -619,9 +706,9 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data <- read.csv("path", header=FALSE) label = factor(data$V1) features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21, + coefficients = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21, intercept=FALSE, standardize=FALSE)) - weights + coefficients 5 x 1 sparse Matrix of class "dgCMatrix" s0 @@ -632,10 +719,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 . */ val interceptR2 = 0.0 - val weightsR2 = Vectors.dense(0.0, 0.03345223, -0.11304532, 0.0) + val coefficientsR2 = Vectors.dense(0.0, 0.03345223, -0.11304532, 0.0) assert(model2.intercept ~== interceptR2 absTol 1E-3) - assert(model2.weights ~= weightsR2 absTol 1E-3) + assert(model2.coefficients ~= coefficientsR2 absTol 1E-3) } test("binary logistic regression with intercept with strong L1 regularization") { @@ -658,8 +745,8 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { }).histogram /* - For binary logistic regression with strong L1 regularization, all the weights will be zeros. - As a result, + For binary logistic regression with strong L1 regularization, all the coefficients + will be zeros. 
As a result, {{{ P(0) = 1 / (1 + \exp(b)), and P(1) = \exp(b) / (1 + \exp(b)) @@ -668,14 +755,14 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { b = \log{P(1) / P(0)} = \log{count_1 / count_0} }}} */ - val interceptTheory = math.log(histogram(1).toDouble / histogram(0).toDouble) - val weightsTheory = Vectors.dense(0.0, 0.0, 0.0, 0.0) + val interceptTheory = math.log(histogram(1) / histogram(0)) + val coefficientsTheory = Vectors.dense(0.0, 0.0, 0.0, 0.0) assert(model1.intercept ~== interceptTheory relTol 1E-5) - assert(model1.weights ~= weightsTheory absTol 1E-6) + assert(model1.coefficients ~= coefficientsTheory absTol 1E-6) assert(model2.intercept ~== interceptTheory relTol 1E-5) - assert(model2.weights ~= weightsTheory absTol 1E-6) + assert(model2.coefficients ~= coefficientsTheory absTol 1E-6) /* Using the following R code to load the data and train the model using glmnet package. @@ -684,8 +771,8 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data <- read.csv("path", header=FALSE) label = factor(data$V1) features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - weights = coef(glmnet(features,label, family="binomial", alpha = 1.0, lambda = 6.0)) - weights + coefficients = coef(glmnet(features,label, family="binomial", alpha = 1.0, lambda = 6.0)) + coefficients 5 x 1 sparse Matrix of class "dgCMatrix" s0 @@ -696,9 +783,135 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 . */ val interceptR = -0.248065 - val weightsR = Vectors.dense(0.0, 0.0, 0.0, 0.0) + val coefficientsR = Vectors.dense(0.0, 0.0, 0.0, 0.0) assert(model1.intercept ~== interceptR relTol 1E-5) - assert(model1.weights ~= weightsR absTol 1E-6) + assert(model1.coefficients ~== coefficientsR absTol 1E-6) + } + + test("evaluate on test set") { + // Evaluate on test set should be same as that of the transformed training data. + val lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(1.0) + .setThreshold(0.6) + val model = lr.fit(dataset) + val summary = model.summary.asInstanceOf[BinaryLogisticRegressionSummary] + + val sameSummary = model.evaluate(dataset).asInstanceOf[BinaryLogisticRegressionSummary] + assert(summary.areaUnderROC === sameSummary.areaUnderROC) + assert(summary.roc.collect() === sameSummary.roc.collect()) + assert(summary.pr.collect === sameSummary.pr.collect()) + assert( + summary.fMeasureByThreshold.collect() === sameSummary.fMeasureByThreshold.collect()) + assert(summary.recallByThreshold.collect() === sameSummary.recallByThreshold.collect()) + assert( + summary.precisionByThreshold.collect() === sameSummary.precisionByThreshold.collect()) + } + + test("statistics on training data") { + // Test that loss is monotonically decreasing. + val lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(1.0) + .setThreshold(0.6) + val model = lr.fit(dataset) + assert( + model.summary + .objectiveHistory + .sliding(2) + .forall(x => x(0) >= x(1))) + } + + test("binary logistic regression with weighted samples") { + val (dataset, weightedDataset) = { + val nPoints = 1000 + val coefficients = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191) + val xMean = Array(5.843, 3.057, 3.758, 1.199) + val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) + val testData = + generateMultinomialLogisticInput(coefficients, xMean, xVariance, true, nPoints, 42) + + // Let's over-sample the positive samples twice. 
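// Why duplication and weighting should agree (the property this weighted-samples test relies
// on): for any per-example loss, replicating an example is the same as giving it an integer
// weight in a weighted objective. In data2 below, each positive example's weights sum to 2.0
// and each negative's to 1.0, with zero-weight rows expected to be ignored. A tiny standalone
// illustration:
def weightedObjective(losses: Seq[Double], weights: Seq[Double]): Double =
  losses.zip(weights).map { case (l, w) => l * w }.sum

val losses = Seq(0.5, 0.75)
val viaWeight = weightedObjective(losses, Seq(1.0, 2.0))                   // weight the 2nd example by 2
val viaDuplication = weightedObjective(losses :+ 0.75, Seq(1.0, 1.0, 1.0)) // or list it twice
assert(viaWeight == viaDuplication)                                        // both equal 2.0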
+ val data1 = testData.flatMap { case labeledPoint: LabeledPoint => + if (labeledPoint.label == 1.0) { + Iterator(labeledPoint, labeledPoint) + } else { + Iterator(labeledPoint) + } + } + + val rnd = new Random(8392) + val data2 = testData.flatMap { case LabeledPoint(label: Double, features: Vector) => + if (rnd.nextGaussian() > 0.0) { + if (label == 1.0) { + Iterator( + Instance(label, 1.2, features), + Instance(label, 0.8, features), + Instance(0.0, 0.0, features)) + } else { + Iterator( + Instance(label, 0.3, features), + Instance(1.0, 0.0, features), + Instance(label, 0.1, features), + Instance(label, 0.6, features)) + } + } else { + if (label == 1.0) { + Iterator(Instance(label, 2.0, features)) + } else { + Iterator(Instance(label, 1.0, features)) + } + } + } + + (sqlContext.createDataFrame(sc.parallelize(data1, 4)), + sqlContext.createDataFrame(sc.parallelize(data2, 4))) + } + + val trainer1a = (new LogisticRegression).setFitIntercept(true) + .setRegParam(0.0).setStandardization(true) + val trainer1b = (new LogisticRegression).setFitIntercept(true).setWeightCol("weight") + .setRegParam(0.0).setStandardization(true) + val model1a0 = trainer1a.fit(dataset) + val model1a1 = trainer1a.fit(weightedDataset) + val model1b = trainer1b.fit(weightedDataset) + assert(model1a0.coefficients !~= model1a1.coefficients absTol 1E-3) + assert(model1a0.intercept !~= model1a1.intercept absTol 1E-3) + assert(model1a0.coefficients ~== model1b.coefficients absTol 1E-3) + assert(model1a0.intercept ~== model1b.intercept absTol 1E-3) + } + + test("read/write") { + def checkModelData(model: LogisticRegressionModel, model2: LogisticRegressionModel): Unit = { + assert(model.intercept === model2.intercept) + assert(model.coefficients.toArray === model2.coefficients.toArray) + assert(model.numClasses === model2.numClasses) + assert(model.numFeatures === model2.numFeatures) + } + val lr = new LogisticRegression() + testEstimatorAndModelReadWrite(lr, dataset, LogisticRegressionSuite.allParamSettings, + checkModelData) + } +} + +object LogisticRegressionSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. 
+ */ + val allParamSettings: Map[String, Any] = ProbabilisticClassifierSuite.allParamSettings ++ Map( + "probabilityCol" -> "myProbability", + "thresholds" -> Array(0.4, 0.6), + "regParam" -> 0.01, + "elasticNetParam" -> 0.1, + "maxIter" -> 2, // intentionally small + "fitIntercept" -> true, + "tol" -> 0.8, + "standardization" -> false, + "threshold" -> 0.6 + ) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index ddc948f65df4..a326432d017f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.evaluation.MulticlassMetrics -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.Row @@ -53,16 +53,17 @@ class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSp test("3 class classification with 2 hidden layers") { val nPoints = 1000 - // The following weights are taken from OneVsRestSuite.scala + // The following coefficients are taken from OneVsRestSuite.scala // they represent 3-class iris dataset - val weights = Array( + val coefficients = Array( -0.57997, 0.912083, -0.371077, -0.819866, 2.688191, -0.16624, -0.84355, -0.048509, -0.301789, 4.170682) val xMean = Array(5.843, 3.057, 3.758, 1.199) val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) + // the input seed is somewhat magic, to make this test pass val rdd = sc.parallelize(generateMultinomialLogisticInput( - weights, xMean, xVariance, true, nPoints, 42), 2) + coefficients, xMean, xVariance, true, nPoints, 1), 2) val dataFrame = sqlContext.createDataFrame(rdd).toDF("label", "features") val numClasses = 3 val numIterations = 100 @@ -70,9 +71,11 @@ class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSp val trainer = new MultilayerPerceptronClassifier() .setLayers(layers) .setBlockSize(1) - .setSeed(11L) + .setSeed(11L) // currently this seed is ignored .setMaxIter(numIterations) val model = trainer.fit(dataFrame) + val numFeatures = dataFrame.select("features").first().getAs[Vector](0).size + assert(model.numFeatures === numFeatures) val mlpPredictionAndLabels = model.transform(dataFrame).select("prediction", "label") .map { case Row(p: Double, l: Double) => (p, l) } // train multinomial logistic regression diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index aea3d9b69449..082a6bcd211a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -21,17 +21,30 @@ import breeze.linalg.{Vector => BV} import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.mllib.classification.NaiveBayes +import org.apache.spark.ml.util.DefaultReadWriteTest +import 
org.apache.spark.mllib.classification.NaiveBayes.{Bernoulli, Multinomial} +import org.apache.spark.mllib.classification.NaiveBayesSuite._ import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.mllib.classification.NaiveBayesSuite._ -import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.Row +import org.apache.spark.sql.{DataFrame, Row} + +class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + @transient var dataset: DataFrame = _ -class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { + override def beforeAll(): Unit = { + super.beforeAll() - import NaiveBayes.{Multinomial, Bernoulli} + val pi = Array(0.5, 0.1, 0.4).map(math.log) + val theta = Array( + Array(0.70, 0.10, 0.10, 0.10), // label 0 + Array(0.10, 0.70, 0.10, 0.10), // label 1 + Array(0.10, 0.10, 0.70, 0.10) // label 2 + ).map(_.map(math.log)) + + dataset = sqlContext.createDataFrame(generateNaiveBayesInput(pi, theta, 100, 42)) + } def validatePrediction(predictionAndLabels: DataFrame): Unit = { val numOfErrorPredictions = predictionAndLabels.collect().count { @@ -163,4 +176,26 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { .select("features", "probability") validateProbabilities(featureAndProbabilities, model, "bernoulli") } + + test("read/write") { + def checkModelData(model: NaiveBayesModel, model2: NaiveBayesModel): Unit = { + assert(model.pi === model2.pi) + assert(model.theta === model2.theta) + } + val nb = new NaiveBayes() + testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings, checkModelData) + } +} + +object NaiveBayesSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPrediction", + "smoothing" -> 0.1 + ) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 3775292f6dca..5ea71c5317b7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.feature.StringIndexer import org.apache.spark.ml.param.{ParamMap, ParamsSuite} -import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.ml.util.{MLTestingUtils, MetadataUtils} import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.evaluation.MulticlassMetrics @@ -43,16 +43,16 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { val nPoints = 1000 - // The following weights and xMean/xVariance are computed from iris dataset with lambda=0.2. + // The following coefficients and xMean/xVariance are computed from iris dataset with lambda=0.2 // As a result, we are drawing samples from probability distribution of an actual model. 
- val weights = Array( + val coefficients = Array( -0.57997, 0.912083, -0.371077, -0.819866, 2.688191, -0.16624, -0.84355, -0.048509, -0.301789, 4.170682) val xMean = Array(5.843, 3.057, 3.758, 1.199) val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) rdd = sc.parallelize(generateMultinomialLogisticInput( - weights, xMean, xVariance, true, nPoints, 42), 2) + coefficients, xMean, xVariance, true, nPoints, 42), 2) dataset = sqlContext.createDataFrame(rdd) } @@ -70,6 +70,10 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(ova.getLabelCol === "label") assert(ova.getPredictionCol === "prediction") val ovaModel = ova.fit(dataset) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(ovaModel) + assert(ovaModel.models.size === numClasses) val transformedDataset = ovaModel.transform(dataset) @@ -151,7 +155,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { require(ovr1.getClassifier.getOrDefault(lr.maxIter) === 10, "copy should handle extra classifier params") - val ovrModel = ovr1.fit(dataset).copy(ParamMap(lr.threshold -> 0.1)) + val ovrModel = ovr1.fit(dataset).copy(ParamMap(lr.thresholds -> Array(0.9, 0.1))) ovrModel.models.foreach { case m: LogisticRegressionModel => require(m.getThreshold === 0.1, "copy should handle extra model params") } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala new file mode 100644 index 000000000000..cfa75ecf387c --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.classification + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.{Vector, Vectors} + +final class TestProbabilisticClassificationModel( + override val uid: String, + override val numFeatures: Int, + override val numClasses: Int) + extends ProbabilisticClassificationModel[Vector, TestProbabilisticClassificationModel] { + + override def copy(extra: org.apache.spark.ml.param.ParamMap): this.type = defaultCopy(extra) + + override protected def predictRaw(input: Vector): Vector = { + input + } + + override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = { + rawPrediction + } + + def friendlyPredict(input: Vector): Double = { + predict(input) + } +} + + +class ProbabilisticClassifierSuite extends SparkFunSuite { + + test("test thresholding") { + val thresholds = Array(0.5, 0.2) + val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2) + .setThresholds(thresholds) + assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 1.0))) === 1.0) + assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 0.2))) === 0.0) + } + + test("test thresholding not required") { + val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2) + assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 2.0))) === 1.0) + } +} + +object ProbabilisticClassifierSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = ClassifierSuite.allParamSettings ++ Map( + "probabilityCol" -> "myProbability", + "thresholds" -> Array(0.4, 0.6) + ) + +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index edf848b21a90..deb8ec771cb2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.tree.LeafNode +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} @@ -67,7 +68,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte test("params") { ParamsSuite.checkParams(new RandomForestClassifier) val model = new RandomForestClassificationModel("rfc", - Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)), 2) + Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 1, 2)), 2, 2) ParamsSuite.checkParams(model) } @@ -135,6 +136,9 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte val df: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses) val model = rf.fit(df) + // copied model must have the same parent. 
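// Referring back to the ProbabilisticClassifierSuite thresholding tests above: the expected
// predictions there are consistent with scaling each class probability by its threshold and
// taking the argmax. A standalone sketch of that rule (assumed semantics, included only to make
// the expected values easy to verify by hand):
def thresholdedPrediction(probabilities: Array[Double], thresholds: Array[Double]): Int = {
  val scaled = probabilities.zip(thresholds).map { case (p, t) => p / t }
  scaled.indexOf(scaled.max)
}
// thresholdedPrediction(Array(1.0, 1.0), Array(0.5, 0.2)) == 1
// thresholdedPrediction(Array(1.0, 0.2), Array(0.5, 0.2)) == 0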
+ MLTestingUtils.checkCopy(model) + val predictions = model.transform(df) .select(rf.getPredictionCol, rf.getRawPredictionCol, rf.getProbabilityCol) .collect() @@ -149,6 +153,35 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte } } + ///////////////////////////////////////////////////////////////////////////// + // Tests of feature importance + ///////////////////////////////////////////////////////////////////////////// + test("Feature importance with toy data") { + val numClasses = 2 + val rf = new RandomForestClassifier() + .setImpurity("Gini") + .setMaxDepth(3) + .setNumTrees(3) + .setFeatureSubsetStrategy("all") + .setSubsamplingRate(1.0) + .setSeed(123) + + // In this data, feature 1 is very important. + val data: RDD[LabeledPoint] = sc.parallelize(Seq( + new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 1)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 1, 0)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)), + new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)) + )) + val categoricalFeatures = Map.empty[Int, Int] + val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) + + val importances = rf.fit(df).featureImportances + val mostImportantFeature = importances.argmax + assert(mostImportantFeature === 1) + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// @@ -176,7 +209,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte */ } -private object RandomForestClassifierSuite { +private object RandomForestClassifierSuite extends SparkFunSuite { /** * Train 2 models on the given dataset, one using the old API and one using the new API. 
@@ -187,6 +220,7 @@ private object RandomForestClassifierSuite { rf: RandomForestClassifier, categoricalFeatures: Map[Int, Int], numClasses: Int): Unit = { + val numFeatures = data.first().features.size val oldStrategy = rf.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, rf.getOldImpurity) val oldModel = OldRandomForest.trainClassifier( @@ -200,6 +234,7 @@ private object RandomForestClassifierSuite { TreeTests.checkEqual(oldModelAsNew, newModel) assert(newModel.hasParent) assert(!newModel.trees.head.asInstanceOf[DecisionTreeClassificationModel].hasParent) - assert(newModel.numClasses == numClasses) + assert(newModel.numClasses === numClasses) + assert(newModel.numFeatures === numFeatures) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 1f15ac02f400..2724e51f31aa 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.clustering import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -25,16 +26,7 @@ import org.apache.spark.sql.{DataFrame, SQLContext} private[clustering] case class TestRow(features: Vector) -object KMeansSuite { - def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = { - val sc = sql.sparkContext - val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble))) - .map(v => new TestRow(v)) - sql.createDataFrame(rdd) - } -} - -class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { +class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { final val k = 5 @transient var dataset: DataFrame = _ @@ -52,10 +44,9 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { assert(kmeans.getFeaturesCol === "features") assert(kmeans.getPredictionCol === "prediction") assert(kmeans.getMaxIter === 20) - assert(kmeans.getRuns === 1) assert(kmeans.getInitMode === MLlibKMeans.K_MEANS_PARALLEL) assert(kmeans.getInitSteps === 5) - assert(kmeans.getEpsilon === 1e-4) + assert(kmeans.getTol === 1e-4) } test("set parameters") { @@ -64,21 +55,19 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { .setFeaturesCol("test_feature") .setPredictionCol("test_prediction") .setMaxIter(33) - .setRuns(7) .setInitMode(MLlibKMeans.RANDOM) .setInitSteps(3) .setSeed(123) - .setEpsilon(1e-3) + .setTol(1e-3) assert(kmeans.getK === 9) assert(kmeans.getFeaturesCol === "test_feature") assert(kmeans.getPredictionCol === "test_prediction") assert(kmeans.getMaxIter === 33) - assert(kmeans.getRuns === 7) assert(kmeans.getInitMode === MLlibKMeans.RANDOM) assert(kmeans.getInitSteps === 3) assert(kmeans.getSeed === 123) - assert(kmeans.getEpsilon === 1e-3) + assert(kmeans.getTol === 1e-3) } test("parameters validation") { @@ -91,9 +80,6 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { intercept[IllegalArgumentException] { new KMeans().setInitSteps(0) } - intercept[IllegalArgumentException] { - new KMeans().setRuns(0) - } } test("fit & transform") { @@ -110,5 +96,35 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { val clusters = 
transformed.select(predictionColName).map(_.getInt(0)).distinct().collect().toSet assert(clusters.size === k) assert(clusters === Set(0, 1, 2, 3, 4)) + assert(model.computeCost(dataset) < 0.1) } + + test("read/write") { + def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = { + assert(model.clusterCenters === model2.clusterCenters) + } + val kmeans = new KMeans() + testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, checkModelData) + } +} + +object KMeansSuite { + def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = { + val sc = sql.sparkContext + val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble))) + .map(v => new TestRow(v)) + sql.createDataFrame(rdd) + } + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPrediction", + "k" -> 3, + "maxIter" -> 2, + "tol" -> 0.01 + ) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala new file mode 100644 index 000000000000..97dbfd9a4314 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -0,0 +1,261 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.clustering + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{DataFrame, Row, SQLContext} + + +object LDASuite { + def generateLDAData( + sql: SQLContext, + rows: Int, + k: Int, + vocabSize: Int): DataFrame = { + val avgWC = 1 // average instances of each word in a doc + val sc = sql.sparkContext + val rng = new java.util.Random() + rng.setSeed(1) + val rdd = sc.parallelize(1 to rows).map { i => + Vectors.dense(Array.fill(vocabSize)(rng.nextInt(2 * avgWC).toDouble)) + }.map(v => new TestRow(v)) + sql.createDataFrame(rdd) + } + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. 
+ */ + val allParamSettings: Map[String, Any] = Map( + "k" -> 3, + "maxIter" -> 2, + "checkpointInterval" -> 30, + "learningOffset" -> 1023.0, + "learningDecay" -> 0.52, + "subsamplingRate" -> 0.051 + ) +} + + +class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + val k: Int = 5 + val vocabSize: Int = 30 + @transient var dataset: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + dataset = LDASuite.generateLDAData(sqlContext, 50, k, vocabSize) + } + + test("default parameters") { + val lda = new LDA() + + assert(lda.getFeaturesCol === "features") + assert(lda.getMaxIter === 20) + assert(lda.isDefined(lda.seed)) + assert(lda.getCheckpointInterval === 10) + assert(lda.getK === 10) + assert(!lda.isSet(lda.docConcentration)) + assert(!lda.isSet(lda.topicConcentration)) + assert(lda.getOptimizer === "online") + assert(lda.getLearningDecay === 0.51) + assert(lda.getLearningOffset === 1024) + assert(lda.getSubsamplingRate === 0.05) + assert(lda.getOptimizeDocConcentration) + assert(lda.getTopicDistributionCol === "topicDistribution") + } + + test("set parameters") { + val lda = new LDA() + .setFeaturesCol("test_feature") + .setMaxIter(33) + .setSeed(123) + .setCheckpointInterval(7) + .setK(9) + .setTopicConcentration(0.56) + .setTopicDistributionCol("myOutput") + + assert(lda.getFeaturesCol === "test_feature") + assert(lda.getMaxIter === 33) + assert(lda.getSeed === 123) + assert(lda.getCheckpointInterval === 7) + assert(lda.getK === 9) + assert(lda.getTopicConcentration === 0.56) + assert(lda.getTopicDistributionCol === "myOutput") + + + // setOptimizer + lda.setOptimizer("em") + assert(lda.getOptimizer === "em") + lda.setOptimizer("online") + assert(lda.getOptimizer === "online") + lda.setLearningDecay(0.53) + assert(lda.getLearningDecay === 0.53) + lda.setLearningOffset(1027) + assert(lda.getLearningOffset === 1027) + lda.setSubsamplingRate(0.06) + assert(lda.getSubsamplingRate === 0.06) + lda.setOptimizeDocConcentration(false) + assert(!lda.getOptimizeDocConcentration) + } + + test("parameters validation") { + val lda = new LDA() + + // misc Params + intercept[IllegalArgumentException] { + new LDA().setK(1) + } + intercept[IllegalArgumentException] { + new LDA().setOptimizer("no_such_optimizer") + } + intercept[IllegalArgumentException] { + new LDA().setDocConcentration(-1.1) + } + intercept[IllegalArgumentException] { + new LDA().setTopicConcentration(-1.1) + } + + // validateParams() + lda.validateParams() + lda.setDocConcentration(1.1) + lda.validateParams() + lda.setDocConcentration(Range(0, lda.getK).map(_ + 2.0).toArray) + lda.validateParams() + lda.setDocConcentration(Range(0, lda.getK - 1).map(_ + 2.0).toArray) + withClue("LDA docConcentration validity check failed for bad array length") { + intercept[IllegalArgumentException] { + lda.validateParams() + } + } + + // Online LDA + intercept[IllegalArgumentException] { + new LDA().setLearningOffset(0) + } + intercept[IllegalArgumentException] { + new LDA().setLearningDecay(0) + } + intercept[IllegalArgumentException] { + new LDA().setSubsamplingRate(0) + } + intercept[IllegalArgumentException] { + new LDA().setSubsamplingRate(1.1) + } + } + + test("fit & transform with Online LDA") { + val lda = new LDA().setK(k).setSeed(1).setOptimizer("online").setMaxIter(2) + val model = lda.fit(dataset) + + MLTestingUtils.checkCopy(model) + + assert(model.isInstanceOf[LocalLDAModel]) + assert(model.vocabSize === vocabSize) + assert(model.estimatedDocConcentration.size === k) + 
assert(model.topicsMatrix.numRows === vocabSize) + assert(model.topicsMatrix.numCols === k) + assert(!model.isDistributed) + + // transform() + val transformed = model.transform(dataset) + val expectedColumns = Array("features", lda.getTopicDistributionCol) + expectedColumns.foreach { column => + assert(transformed.columns.contains(column)) + } + transformed.select(lda.getTopicDistributionCol).collect().foreach { r => + val topicDistribution = r.getAs[Vector](0) + assert(topicDistribution.size === k) + assert(topicDistribution.toArray.forall(w => w >= 0.0 && w <= 1.0)) + } + + // logLikelihood, logPerplexity + val ll = model.logLikelihood(dataset) + assert(ll <= 0.0 && ll != Double.NegativeInfinity) + val lp = model.logPerplexity(dataset) + assert(lp >= 0.0 && lp != Double.PositiveInfinity) + + // describeTopics + val topics = model.describeTopics(3) + assert(topics.count() === k) + assert(topics.select("topic").map(_.getInt(0)).collect().toSet === Range(0, k).toSet) + topics.select("termIndices").collect().foreach { case r: Row => + val termIndices = r.getAs[Seq[Int]](0) + assert(termIndices.length === 3 && termIndices.toSet.size === 3) + } + topics.select("termWeights").collect().foreach { case r: Row => + val termWeights = r.getAs[Seq[Double]](0) + assert(termWeights.length === 3 && termWeights.forall(w => w >= 0.0 && w <= 1.0)) + } + } + + test("fit & transform with EM LDA") { + val lda = new LDA().setK(k).setSeed(1).setOptimizer("em").setMaxIter(2) + val model_ = lda.fit(dataset) + + MLTestingUtils.checkCopy(model_) + + assert(model_.isInstanceOf[DistributedLDAModel]) + val model = model_.asInstanceOf[DistributedLDAModel] + assert(model.vocabSize === vocabSize) + assert(model.estimatedDocConcentration.size === k) + assert(model.topicsMatrix.numRows === vocabSize) + assert(model.topicsMatrix.numCols === k) + assert(model.isDistributed) + + val localModel = model.toLocal + assert(localModel.isInstanceOf[LocalLDAModel]) + + // training logLikelihood, logPrior + val ll = model.trainingLogLikelihood + assert(ll <= 0.0 && ll != Double.NegativeInfinity) + val lp = model.logPrior + assert(lp <= 0.0 && lp != Double.NegativeInfinity) + } + + test("read/write LocalLDAModel") { + def checkModelData(model: LDAModel, model2: LDAModel): Unit = { + assert(model.vocabSize === model2.vocabSize) + assert(Vectors.dense(model.topicsMatrix.toArray) ~== + Vectors.dense(model2.topicsMatrix.toArray) absTol 1e-6) + assert(Vectors.dense(model.getDocConcentration) ~== + Vectors.dense(model2.getDocConcentration) absTol 1e-6) + } + val lda = new LDA() + testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings, checkModelData) + } + + test("read/write DistributedLDAModel") { + def checkModelData(model: LDAModel, model2: LDAModel): Unit = { + assert(model.vocabSize === model2.vocabSize) + assert(Vectors.dense(model.topicsMatrix.toArray) ~== + Vectors.dense(model2.topicsMatrix.toArray) absTol 1e-6) + assert(Vectors.dense(model.getDocConcentration) ~== + Vectors.dense(model2.getDocConcentration) absTol 1e-6) + } + val lda = new LDA() + testEstimatorAndModelReadWrite(lda, dataset, + LDASuite.allParamSettings ++ Map("optimizer" -> "em"), checkModelData) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala index def869fe6677..a535c1218ecf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala +++ 
b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala @@ -19,10 +19,21 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext -class BinaryClassificationEvaluatorSuite extends SparkFunSuite { +class BinaryClassificationEvaluatorSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new BinaryClassificationEvaluator) } + + test("read/write") { + val evaluator = new BinaryClassificationEvaluator() + .setRawPredictionCol("myRawPrediction") + .setLabelCol("myLabel") + .setMetricName("areaUnderPR") + testDefaultReadWrite(evaluator) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala index 6d8412b0b370..7ee65975d22f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala @@ -19,10 +19,21 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext -class MulticlassClassificationEvaluatorSuite extends SparkFunSuite { +class MulticlassClassificationEvaluatorSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new MulticlassClassificationEvaluator) } + + test("read/write") { + val evaluator = new MulticlassClassificationEvaluator() + .setPredictionCol("myPrediction") + .setLabelCol("myLabel") + .setMetricName("recall") + testDefaultReadWrite(evaluator) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala index 5b203784559e..954d3bedc14b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala @@ -20,10 +20,12 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ -class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext { +class RegressionEvaluatorSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new RegressionEvaluator) @@ -63,14 +65,22 @@ class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext // default = rmse val evaluator = new RegressionEvaluator() - assert(evaluator.evaluate(predictions) ~== -0.1019382 absTol 0.001) + assert(evaluator.evaluate(predictions) ~== 0.1013829 absTol 0.01) // r2 score evaluator.setMetricName("r2") - assert(evaluator.evaluate(predictions) ~== 0.9998196 absTol 0.001) + assert(evaluator.evaluate(predictions) ~== 0.9998387 absTol 0.01) // mae 
evaluator.setMetricName("mae") - assert(evaluator.evaluate(predictions) ~== -0.08036075 absTol 0.001) + assert(evaluator.evaluate(predictions) ~== 0.08399089 absTol 0.01) + } + + test("read/write") { + val evaluator = new RegressionEvaluator() + .setPredictionCol("myPrediction") + .setLabelCol("myLabel") + .setMetricName("r2") + testDefaultReadWrite(evaluator) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala index 208604398366..6d2d8fe71444 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala @@ -19,10 +19,11 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} -class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext { +class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @transient var data: Array[Double] = _ @@ -66,4 +67,12 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext { assert(x === y, "The feature value is not correct after binarization.") } } + + test("read/write") { + val t = new Binarizer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setThreshold(0.1) + testDefaultReadWrite(t) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index ec85e0d151e0..9ea7d431763a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -21,12 +21,13 @@ import scala.util.Random import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} -class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext { +class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new Bucketizer) @@ -111,6 +112,14 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext { val lsResult = Vectors.dense(data.map(x => BucketizerSuite.linearSearchForBuckets(splits, x))) assert(bsResult ~== lsResult absTol 1e-5) } + + test("read/write") { + val t = new Bucketizer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setSplits(Array(0.1, 0.8, 0.9)) + testDefaultReadWrite(t) + } } private object BucketizerSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala new file mode 100644 index 000000000000..7827db2794cf --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.feature +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{Row, SQLContext} + +class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest { + + test("Test Chi-Square selector") { + val sqlContext = SQLContext.getOrCreate(sc) + import sqlContext.implicits._ + + val data = Seq( + LabeledPoint(0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0)))), + LabeledPoint(1.0, Vectors.sparse(3, Array((1, 9.0), (2, 6.0)))), + LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0))), + LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0))) + ) + + val preFilteredData = Seq( + Vectors.dense(0.0), + Vectors.dense(6.0), + Vectors.dense(8.0), + Vectors.dense(5.0) + ) + + val df = sc.parallelize(data.zip(preFilteredData)) + .map(x => (x._1.label, x._1.features, x._2)) + .toDF("label", "data", "preFilteredData") + + val model = new ChiSqSelector() + .setNumTopFeatures(1) + .setFeaturesCol("data") + .setLabelCol("label") + .setOutputCol("filtered") + + model.fit(df).transform(df).select("filtered", "preFilteredData").collect().foreach { + case Row(vec1: Vector, vec2: Vector) => + assert(vec1 ~== vec2 absTol 1e-1) + } + } + + test("ChiSqSelector read/write") { + val t = new ChiSqSelector() + .setFeaturesCol("myFeaturesCol") + .setLabelCol("myLabelCol") + .setOutputCol("myOutputCol") + .setNumTopFeatures(2) + testDefaultReadWrite(t) + } + + test("ChiSqSelectorModel read/write") { + val oldModel = new feature.ChiSqSelectorModel(Array(1, 3)) + val instance = new ChiSqSelectorModel("myChiSqSelectorModel", oldModel) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.selectedFeatures === instance.selectedFeatures) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala new file mode 100644 index 000000000000..9c9999017317 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.Row + +class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest { + + test("params") { + ParamsSuite.checkParams(new CountVectorizer) + ParamsSuite.checkParams(new CountVectorizerModel(Array("empty"))) + } + + private def split(s: String): Seq[String] = s.split("\\s+") + + test("CountVectorizerModel common cases") { + val df = sqlContext.createDataFrame(Seq( + (0, split("a b c d"), + Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))), + (1, split("a b b c d a"), + Vectors.sparse(4, Seq((0, 2.0), (1, 2.0), (2, 1.0), (3, 1.0)))), + (2, split("a"), Vectors.sparse(4, Seq((0, 1.0)))), + (3, split(""), Vectors.sparse(4, Seq())), // empty string + (4, split("a notInDict d"), + Vectors.sparse(4, Seq((0, 1.0), (3, 1.0)))) // with words not in vocabulary + )).toDF("id", "words", "expected") + val cv = new CountVectorizerModel(Array("a", "b", "c", "d")) + .setInputCol("words") + .setOutputCol("features") + cv.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } + } + + test("CountVectorizer common cases") { + val df = sqlContext.createDataFrame(Seq( + (0, split("a b c d e"), + Vectors.sparse(5, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0), (4, 1.0)))), + (1, split("a a a a a a"), Vectors.sparse(5, Seq((0, 6.0)))), + (2, split("c"), Vectors.sparse(5, Seq((2, 1.0)))), + (3, split("b b b b b"), Vectors.sparse(5, Seq((1, 5.0))))) + ).toDF("id", "words", "expected") + val cv = new CountVectorizer() + .setInputCol("words") + .setOutputCol("features") + .fit(df) + assert(cv.vocabulary === Array("a", "b", "c", "d", "e")) + + cv.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } + } + + test("CountVectorizer vocabSize and minDF") { + val df = sqlContext.createDataFrame(Seq( + (0, split("a b c d"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))), + (1, split("a b c"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))), + (2, split("a b"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))), + (3, split("a"), Vectors.sparse(3, Seq((0, 1.0))))) + ).toDF("id", "words", "expected") + val cvModel = new CountVectorizer() + .setInputCol("words") + .setOutputCol("features") + .setVocabSize(3) // limit vocab size to 3 + .fit(df) + assert(cvModel.vocabulary === Array("a", "b", "c")) + + // minDF: ignore terms with count less than 3 + val cvModel2 = new CountVectorizer() + .setInputCol("words") + .setOutputCol("features") + .setMinDF(3) + .fit(df) + assert(cvModel2.vocabulary === Array("a", "b")) + + cvModel2.transform(df).select("features", "expected").collect().foreach { + case 
Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } + + // minDF: ignore terms with freq < 0.75 + val cvModel3 = new CountVectorizer() + .setInputCol("words") + .setOutputCol("features") + .setMinDF(3.0 / df.count()) + .fit(df) + assert(cvModel3.vocabulary === Array("a", "b")) + + cvModel3.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } + } + + test("CountVectorizer throws exception when vocab is empty") { + intercept[IllegalArgumentException] { + val df = sqlContext.createDataFrame(Seq( + (0, split("a a b b c c")), + (1, split("aa bb cc"))) + ).toDF("id", "words") + val cvModel = new CountVectorizer() + .setInputCol("words") + .setOutputCol("features") + .setVocabSize(3) // limit vocab size to 3 + .setMinDF(3) + .fit(df) + } + } + + test("CountVectorizerModel with minTF count") { + val df = sqlContext.createDataFrame(Seq( + (0, split("a a a b b c c c d "), Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))), + (1, split("c c c c c c"), Vectors.sparse(4, Seq((2, 6.0)))), + (2, split("a"), Vectors.sparse(4, Seq())), + (3, split("e e e e e"), Vectors.sparse(4, Seq()))) + ).toDF("id", "words", "expected") + + // minTF: count + val cv = new CountVectorizerModel(Array("a", "b", "c", "d")) + .setInputCol("words") + .setOutputCol("features") + .setMinTF(3) + cv.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } + } + + test("CountVectorizerModel with minTF freq") { + val df = sqlContext.createDataFrame(Seq( + (0, split("a a a b b c c c d "), Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))), + (1, split("c c c c c c"), Vectors.sparse(4, Seq((2, 6.0)))), + (2, split("a"), Vectors.sparse(4, Seq((0, 1.0)))), + (3, split("e e e e e"), Vectors.sparse(4, Seq()))) + ).toDF("id", "words", "expected") + + // minTF: count + val cv = new CountVectorizerModel(Array("a", "b", "c", "d")) + .setInputCol("words") + .setOutputCol("features") + .setMinTF(0.3) + cv.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } + } + + test("CountVectorizer read/write") { + val t = new CountVectorizer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMinDF(0.5) + .setMinTF(3.0) + .setVocabSize(10) + testDefaultReadWrite(t) + } + + test("CountVectorizerModel read/write") { + val instance = new CountVectorizerModel("myCountVectorizerModel", Array("a", "b", "c")) + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMinTF(3.0) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.vocabulary === instance.vocabulary) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala deleted file mode 100644 index e90d9d4ef21f..000000000000 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.ml.feature - -import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.mllib.linalg.{Vector, Vectors} -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.mllib.util.TestingUtils._ - -class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { - - test("params") { - ParamsSuite.checkParams(new CountVectorizerModel(Array("empty"))) - } - - test("CountVectorizerModel common cases") { - val df = sqlContext.createDataFrame(Seq( - (0, "a b c d".split(" ").toSeq, - Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))), - (1, "a b b c d a".split(" ").toSeq, - Vectors.sparse(4, Seq((0, 2.0), (1, 2.0), (2, 1.0), (3, 1.0)))), - (2, "a".split(" ").toSeq, Vectors.sparse(4, Seq((0, 1.0)))), - (3, "".split(" ").toSeq, Vectors.sparse(4, Seq())), // empty string - (4, "a notInDict d".split(" ").toSeq, - Vectors.sparse(4, Seq((0, 1.0), (3, 1.0)))) // with words not in vocabulary - )).toDF("id", "words", "expected") - val cv = new CountVectorizerModel(Array("a", "b", "c", "d")) - .setInputCol("words") - .setOutputCol("features") - val output = cv.transform(df).collect() - output.foreach { p => - val features = p.getAs[Vector]("features") - val expected = p.getAs[Vector]("expected") - assert(features ~== expected absTol 1e-14) - } - } - - test("CountVectorizerModel with minTermFreq") { - val df = sqlContext.createDataFrame(Seq( - (0, "a a a b b c c c d ".split(" ").toSeq, Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))), - (1, "c c c c c c".split(" ").toSeq, Vectors.sparse(4, Seq((2, 6.0)))), - (2, "a".split(" ").toSeq, Vectors.sparse(4, Seq())), - (3, "e e e e e".split(" ").toSeq, Vectors.sparse(4, Seq()))) - ).toDF("id", "words", "expected") - val cv = new CountVectorizerModel(Array("a", "b", "c", "d")) - .setInputCol("words") - .setOutputCol("features") - .setMinTermFreq(3) - val output = cv.transform(df).collect() - output.foreach { p => - val features = p.getAs[Vector]("features") - val expected = p.getAs[Vector]("expected") - assert(features ~== expected absTol 1e-14) - } - } -} - - diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala index 37ed2367c33f..0f2aafebafe6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala @@ -22,6 +22,7 @@ import scala.beans.BeanInfo import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} @@ -29,7 +30,7 @@ import org.apache.spark.sql.{DataFrame, Row} @BeanInfo case class DCTTestData(vec: Vector, wantedVec: Vector) -class DCTSuite extends 
SparkFunSuite with MLlibTestSparkContext { +class DCTSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("forward transform of discrete cosine matches jTransforms result") { val data = Vectors.dense((0 until 128).map(_ => 2D * math.random - 1D).toArray) @@ -45,6 +46,14 @@ class DCTSuite extends SparkFunSuite with MLlibTestSparkContext { testDCT(data, inverse) } + test("read/write") { + val t = new DCT() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setInverse(true) + testDefaultReadWrite(t) + } + private def testDCT(data: Vector, inverse: Boolean): Unit = { val expectedResultBuffer = data.toArray.clone() if (inverse) { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala index 4157b84b29d0..0dcd0f49465e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala @@ -20,12 +20,13 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils -class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext { +class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new HashingTF) @@ -50,4 +51,12 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext { Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0))) assert(features ~== expected absTol 1e-14) } + + test("read/write") { + val t = new HashingTF() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setNumFeatures(10) + testDefaultReadWrite(t) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala index 08f80af03429..bc958c15857b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala @@ -19,13 +19,14 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.feature.{IDFModel => OldIDFModel} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.Row -class IDFSuite extends SparkFunSuite with MLlibTestSparkContext { +class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { def scaleDataWithIDF(dataSet: Array[Vector], model: Vector): Array[Vector] = { dataSet.map { @@ -98,4 +99,20 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext { assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") } } + + test("IDF read/write") { + val t = new IDF() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMinDocFreq(5) + testDefaultReadWrite(t) + } + + test("IDFModel read/write") { + val instance = new IDFModel("myIDFModel", new 
OldIDFModel(Vectors.dense(1.0, 2.0))) + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.idf === instance.idf) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala new file mode 100644 index 000000000000..932d331b472b --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.collection.mutable.ArrayBuilder + +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.attribute._ +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.functions.col + +class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + test("params") { + ParamsSuite.checkParams(new Interaction()) + } + + test("feature encoder") { + def encode(cardinalities: Array[Int], value: Any): Vector = { + var indices = ArrayBuilder.make[Int] + var values = ArrayBuilder.make[Double] + val encoder = new FeatureEncoder(cardinalities) + encoder.foreachNonzeroOutput(value, (i, v) => { + indices += i + values += v + }) + Vectors.sparse(encoder.outputSize, indices.result(), values.result()).compressed + } + assert(encode(Array(1), 2.2) === Vectors.dense(2.2)) + assert(encode(Array(3), Vectors.dense(1)) === Vectors.dense(0, 1, 0)) + assert(encode(Array(1, 1), Vectors.dense(1.1, 2.2)) === Vectors.dense(1.1, 2.2)) + assert(encode(Array(3, 1), Vectors.dense(1, 2.2)) === Vectors.dense(0, 1, 0, 2.2)) + assert(encode(Array(2, 1), Vectors.dense(1, 2.2)) === Vectors.dense(0, 1, 2.2)) + assert(encode(Array(2, 1, 1), Vectors.dense(0, 2.2, 0)) === Vectors.dense(1, 0, 2.2, 0)) + intercept[SparkException] { encode(Array(1), "foo") } + intercept[SparkException] { encode(Array(1), null) } + intercept[AssertionError] { encode(Array(2), 2.2) } + intercept[AssertionError] { encode(Array(3), Vectors.dense(2.2)) } + intercept[AssertionError] { encode(Array(1), Vectors.dense(1.0, 2.0, 3.0)) } + intercept[AssertionError] { encode(Array(3), Vectors.dense(-1)) } + intercept[AssertionError] { encode(Array(3), Vectors.dense(3)) } + } + + test("numeric interaction") { + val data = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(3.0, 4.0)), + (1, Vectors.dense(1.0, 5.0))) + ).toDF("a", "b") + val groupAttr = new AttributeGroup( + "b", + Array[Attribute]( + NumericAttribute.defaultAttr.withName("foo"), + 
NumericAttribute.defaultAttr.withName("bar"))) + val df = data.select( + col("a").as("a", NumericAttribute.defaultAttr.toMetadata()), + col("b").as("b", groupAttr.toMetadata())) + val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features") + val res = trans.transform(df) + val expected = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(3.0, 4.0), Vectors.dense(6.0, 8.0)), + (1, Vectors.dense(1.0, 5.0), Vectors.dense(1.0, 5.0))) + ).toDF("a", "b", "features") + assert(res.collect() === expected.collect()) + val attrs = AttributeGroup.fromStructField(res.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array[Attribute]( + new NumericAttribute(Some("a:b_foo"), Some(1)), + new NumericAttribute(Some("a:b_bar"), Some(2)))) + assert(attrs === expectedAttrs) + } + + test("nominal interaction") { + val data = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(3.0, 4.0)), + (1, Vectors.dense(1.0, 5.0))) + ).toDF("a", "b") + val groupAttr = new AttributeGroup( + "b", + Array[Attribute]( + NumericAttribute.defaultAttr.withName("foo"), + NumericAttribute.defaultAttr.withName("bar"))) + val df = data.select( + col("a").as( + "a", NominalAttribute.defaultAttr.withValues(Array("up", "down", "left")).toMetadata()), + col("b").as("b", groupAttr.toMetadata())) + val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features") + val res = trans.transform(df) + val expected = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(3.0, 4.0), Vectors.dense(0, 0, 0, 0, 3, 4)), + (1, Vectors.dense(1.0, 5.0), Vectors.dense(0, 0, 1, 5, 0, 0))) + ).toDF("a", "b", "features") + assert(res.collect() === expected.collect()) + val attrs = AttributeGroup.fromStructField(res.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array[Attribute]( + new NumericAttribute(Some("a_up:b_foo"), Some(1)), + new NumericAttribute(Some("a_up:b_bar"), Some(2)), + new NumericAttribute(Some("a_down:b_foo"), Some(3)), + new NumericAttribute(Some("a_down:b_bar"), Some(4)), + new NumericAttribute(Some("a_left:b_foo"), Some(5)), + new NumericAttribute(Some("a_left:b_bar"), Some(6)))) + assert(attrs === expectedAttrs) + } + + test("default attr names") { + val data = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(0.0, 4.0), 1.0), + (1, Vectors.dense(1.0, 5.0), 10.0)) + ).toDF("a", "b", "c") + val groupAttr = new AttributeGroup( + "b", + Array[Attribute]( + NominalAttribute.defaultAttr.withNumValues(2), + NumericAttribute.defaultAttr)) + val df = data.select( + col("a").as("a", NominalAttribute.defaultAttr.withNumValues(3).toMetadata()), + col("b").as("b", groupAttr.toMetadata()), + col("c").as("c", NumericAttribute.defaultAttr.toMetadata())) + val trans = new Interaction().setInputCols(Array("a", "b", "c")).setOutputCol("features") + val res = trans.transform(df) + val expected = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(0.0, 4.0), 1.0, Vectors.dense(0, 0, 0, 0, 0, 0, 1, 0, 4)), + (1, Vectors.dense(1.0, 5.0), 10.0, Vectors.dense(0, 0, 0, 0, 10, 50, 0, 0, 0))) + ).toDF("a", "b", "c", "features") + assert(res.collect() === expected.collect()) + val attrs = AttributeGroup.fromStructField(res.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array[Attribute]( + new NumericAttribute(Some("a_0:b_0_0:c"), Some(1)), + new NumericAttribute(Some("a_0:b_0_1:c"), Some(2)), + new NumericAttribute(Some("a_0:b_1:c"), Some(3)), + new NumericAttribute(Some("a_1:b_0_0:c"), Some(4)), + new 
NumericAttribute(Some("a_1:b_0_1:c"), Some(5)), + new NumericAttribute(Some("a_1:b_1:c"), Some(6)), + new NumericAttribute(Some("a_2:b_0_0:c"), Some(7)), + new NumericAttribute(Some("a_2:b_0_1:c"), Some(8)), + new NumericAttribute(Some("a_2:b_1:c"), Some(9)))) + assert(attrs === expectedAttrs) + } + + test("read/write") { + val t = new Interaction() + .setInputCols(Array("myInputCol", "myInputCol2")) + .setOutputCol("myOutputCol") + testDefaultReadWrite(t) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala index c452054bec92..035bfc07b684 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala @@ -18,15 +18,14 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.Row -class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext { +class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("MinMaxScaler fit basic case") { - val sqlContext = new SQLContext(sc) - val data = Array( Vectors.dense(1, 0, Long.MinValue), Vectors.dense(2, 0, 0), @@ -51,6 +50,9 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext { .foreach { case Row(vector1: Vector, vector2: Vector) => assert(vector1.equals(vector2), "Transformed vector is different with expected.") } + + // copied model must have the same parent. 
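+ // (checkCopy copies the model with an empty ParamMap and asserts the copy still points to the original estimator as its parent.)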
+ MLTestingUtils.checkCopy(model) } test("MinMaxScaler arguments max must be larger than min") { @@ -65,4 +67,25 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext { } } } + + test("MinMaxScaler read/write") { + val t = new MinMaxScaler() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMax(1.0) + .setMin(-1.0) + testDefaultReadWrite(t) + } + + test("MinMaxScalerModel read/write") { + val instance = new MinMaxScalerModel( + "myMinMaxScalerModel", Vectors.dense(-1.0, 0.0), Vectors.dense(1.0, 10.0)) + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMin(-1.0) + .setMax(1.0) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.originalMin === instance.originalMin) + assert(newInstance.originalMax === instance.originalMax) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala index ab97e3dbc6ee..58fda29aa1e6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala @@ -20,13 +20,14 @@ package org.apache.spark.ml.feature import scala.beans.BeanInfo import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} @BeanInfo case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String]) -class NGramSuite extends SparkFunSuite with MLlibTestSparkContext { +class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { import org.apache.spark.ml.feature.NGramSuite._ test("default behavior yields bigram features") { @@ -79,6 +80,14 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext { ))) testNGram(nGram, dataset) } + + test("read/write") { + val t = new NGram() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setN(3) + testDefaultReadWrite(t) + } } object NGramSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala index 9f03470b7f32..468833901995 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala @@ -18,13 +18,14 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Row} -class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext { +class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @transient var data: Array[Vector] = _ @transient var dataFrame: DataFrame = _ @@ -60,7 +61,6 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext { Vectors.sparse(3, Seq()) ) - val sqlContext = new SQLContext(sc) dataFrame = sqlContext.createDataFrame(sc.parallelize(data, 2).map(NormalizerSuite.FeatureData)) normalizer = new Normalizer() .setInputCol("features") @@ -104,6 +104,14 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext { assertValues(result, 
l1Normalized) } + + test("read/write") { + val t = new Normalizer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setP(3.0) + testDefaultReadWrite(t) + } } private object NormalizerSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala index 321eeb843941..76d12050f967 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala @@ -20,12 +20,14 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute} import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions.col -class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext { +class OneHotEncoderSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { def stringIndexed(): DataFrame = { val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) @@ -101,4 +103,12 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext { assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0)) assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1)) } + + test("read/write") { + val t = new OneHotEncoder() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setDropLast(false) + testDefaultReadWrite(t) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala index d0ae36b28c7a..9f6618b92929 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala @@ -19,19 +19,20 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.distributed.RowMatrix -import org.apache.spark.mllib.linalg.{Vector, Vectors, DenseMatrix, Matrices} +import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.mllib.feature.{PCAModel => OldPCAModel} import org.apache.spark.sql.Row -class PCASuite extends SparkFunSuite with MLlibTestSparkContext { +class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new PCA) val mat = Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix] - val model = new PCAModel("pca", new OldPCAModel(2, mat)) + val explainedVariance = Vectors.dense(0.5, 0.5).asInstanceOf[DenseVector] + val model = new PCAModel("pca", mat, explainedVariance) ParamsSuite.checkParams(model) } @@ -56,9 +57,28 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext { .setK(3) .fit(df) + // copied model must have the same parent. 
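+ // (As in the other suites, checkCopy asserts that copying the fitted model preserves its parent estimator.)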
+ MLTestingUtils.checkCopy(pca) + pca.transform(df).select("pca_features", "expected").collect().foreach { case Row(x: Vector, y: Vector) => assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") } } + + test("PCA read/write") { + val t = new PCA() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setK(3) + testDefaultReadWrite(t) + } + + test("PCAModel read/write") { + val instance = new PCAModel("myPCAModel", + Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix], + Vectors.dense(0.5, 0.5).asInstanceOf[DenseVector]) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.pc === instance.pc) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala index 29eebd8960eb..70892dc57170 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala @@ -21,12 +21,14 @@ import org.apache.spark.ml.param.ParamsSuite import org.scalatest.exceptions.TestFailedException import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.Row -class PolynomialExpansionSuite extends SparkFunSuite with MLlibTestSparkContext { +class PolynomialExpansionSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new PolynomialExpansion) @@ -98,5 +100,13 @@ class PolynomialExpansionSuite extends SparkFunSuite with MLlibTestSparkContext throw new TestFailedException("Unmatched data types after polynomial expansion", 0) } } + + test("read/write") { + val t = new PolynomialExpansion() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setDegree(3) + testDefaultReadWrite(t) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala new file mode 100644 index 000000000000..3a4f6d235aa6 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature + +import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.{SparkContext, SparkFunSuite} + +class QuantileDiscretizerSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + import org.apache.spark.ml.feature.QuantileDiscretizerSuite._ + + test("Test quantile discretizer") { + checkDiscretizedData(sc, + Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3), + 10, + Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3), + Array("-Infinity, 1.0", "1.0, 2.0", "2.0, 3.0", "3.0, Infinity")) + + checkDiscretizedData(sc, + Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3), + 4, + Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3), + Array("-Infinity, 1.0", "1.0, 2.0", "2.0, 3.0", "3.0, Infinity")) + + checkDiscretizedData(sc, + Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3), + 3, + Array[Double](0, 1, 2, 2, 2, 2, 2, 2, 2), + Array("-Infinity, 2.0", "2.0, 3.0", "3.0, Infinity")) + + checkDiscretizedData(sc, + Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3), + 2, + Array[Double](0, 1, 1, 1, 1, 1, 1, 1, 1), + Array("-Infinity, 2.0", "2.0, Infinity")) + + } + + test("Test getting splits") { + val splitTestPoints = Array( + Array[Double]() -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity), + Array(Double.NegativeInfinity) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity), + Array(Double.PositiveInfinity) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity), + Array(Double.NegativeInfinity, Double.PositiveInfinity) + -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity), + Array(0.0) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity), + Array(1.0) -> Array(Double.NegativeInfinity, 1, Double.PositiveInfinity), + Array(0.0, 1.0) -> Array(Double.NegativeInfinity, 0, 1, Double.PositiveInfinity) + ) + for ((ori, res) <- splitTestPoints) { + assert(QuantileDiscretizer.getSplits(ori) === res, "Returned splits are invalid.") + } + } + + test("read/write") { + val t = new QuantileDiscretizer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setNumBuckets(6) + testDefaultReadWrite(t) + } +} + +private object QuantileDiscretizerSuite extends SparkFunSuite { + + def checkDiscretizedData( + sc: SparkContext, + data: Array[Double], + numBucket: Int, + expectedResult: Array[Double], + expectedAttrs: Array[String]): Unit = { + val sqlCtx = SQLContext.getOrCreate(sc) + import sqlCtx.implicits._ + + val df = sc.parallelize(data.map(Tuple1.apply)).toDF("input") + val discretizer = new QuantileDiscretizer().setInputCol("input").setOutputCol("result") + .setNumBuckets(numBucket) + val result = discretizer.fit(df).transform(df) + + val transformedFeatures = result.select("result").collect() + .map { case Row(transformedFeature: Double) => transformedFeature } + val transformedAttrs = Attribute.fromStructField(result.schema("result")) + .asInstanceOf[NominalAttribute].values.get + + assert(transformedFeatures === expectedResult, + "Transformed features do not equal expected features.") + assert(transformedAttrs === expectedAttrs, + "Transformed attributes do not equal expected attributes.") + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala index 436e66bab09b..53798c659d4f 100644 --- 
a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala @@ -25,16 +25,24 @@ class RFormulaParserSuite extends SparkFunSuite { formula: String, label: String, terms: Seq[String], - schema: StructType = null) { + schema: StructType = new StructType) { val resolved = RFormulaParser.parse(formula).resolve(schema) assert(resolved.label == label) - assert(resolved.terms == terms) + val simpleTerms = terms.map { t => + if (t.contains(":")) { + t.split(":").toSeq + } else { + Seq(t) + } + } + assert(resolved.terms == simpleTerms) } test("parse simple formulas") { checkParse("y ~ x", "y", Seq("x")) checkParse("y ~ x + x", "y", Seq("x")) - checkParse("y ~ ._foo ", "y", Seq("._foo")) + checkParse("y~x+z", "y", Seq("x", "z")) + checkParse("y ~ ._fo..o ", "y", Seq("._fo..o")) checkParse("resp ~ A_VAR + B + c123", "resp", Seq("A_VAR", "B", "c123")) } @@ -79,4 +87,79 @@ class RFormulaParserSuite extends SparkFunSuite { assert(!RFormulaParser.parse("a ~ b - 1").hasIntercept) assert(!RFormulaParser.parse("a ~ b + 1 - 1").hasIntercept) } + + test("parse interactions") { + checkParse("y ~ a:b", "y", Seq("a:b")) + checkParse("y ~ ._a:._x", "y", Seq("._a:._x")) + checkParse("y ~ foo:bar", "y", Seq("foo:bar")) + checkParse("y ~ a : b : c", "y", Seq("a:b:c")) + checkParse("y ~ q + a:b:c + b:c + c:d + z", "y", Seq("q", "a:b:c", "b:c", "c:d", "z")) + } + + test("parse basic interactions with dot") { + val schema = (new StructType) + .add("a", "int", true) + .add("b", "long", false) + .add("c", "string", true) + .add("d", "string", true) + checkParse("a ~ .:b", "a", Seq("b", "c:b", "d:b"), schema) + checkParse("a ~ b:.", "a", Seq("b", "b:c", "b:d"), schema) + checkParse("a ~ .:b:.:.:c:d:.", "a", Seq("b:c:d"), schema) + } + + // Test data generated in R with terms.formula(y ~ .:., data = iris) + test("parse all to all iris interactions") { + val schema = (new StructType) + .add("Sepal.Length", "double", true) + .add("Sepal.Width", "double", true) + .add("Petal.Length", "double", true) + .add("Petal.Width", "double", true) + .add("Species", "string", true) + checkParse( + "y ~ .:.", + "y", + Seq( + "Sepal.Length", + "Sepal.Width", + "Petal.Length", + "Petal.Width", + "Species", + "Sepal.Length:Sepal.Width", + "Sepal.Length:Petal.Length", + "Sepal.Length:Petal.Width", + "Sepal.Length:Species", + "Sepal.Width:Petal.Length", + "Sepal.Width:Petal.Width", + "Sepal.Width:Species", + "Petal.Length:Petal.Width", + "Petal.Length:Species", + "Petal.Width:Species"), + schema) + } + + // Test data generated in R with terms.formula(y ~ .:. - Species:., data = iris) + test("parse interaction negation with iris") { + val schema = (new StructType) + .add("Sepal.Length", "double", true) + .add("Sepal.Width", "double", true) + .add("Petal.Length", "double", true) + .add("Petal.Width", "double", true) + .add("Species", "string", true) + checkParse("y ~ .:. - .:.", "y", Nil, schema) + checkParse( + "y ~ .:. 
- Species:.", + "y", + Seq( + "Sepal.Length", + "Sepal.Width", + "Petal.Length", + "Petal.Width", + "Sepal.Length:Sepal.Width", + "Sepal.Length:Petal.Length", + "Sepal.Length:Petal.Width", + "Sepal.Width:Petal.Length", + "Sepal.Width:Petal.Width", + "Petal.Length:Petal.Width"), + schema) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index 6aed3243afce..dc20a5ec2152 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -107,6 +107,25 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext { assert(result.collect() === expected.collect()) } + test("index string label") { + val formula = new RFormula().setFormula("id ~ a + b") + val original = sqlContext.createDataFrame( + Seq(("male", "foo", 4), ("female", "bar", 4), ("female", "bar", 5), ("male", "baz", 5)) + ).toDF("id", "a", "b") + val model = formula.fit(original) + val result = model.transform(original) + val resultSchema = model.transformSchema(original.schema) + val expected = sqlContext.createDataFrame( + Seq( + ("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), + ("female", "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0), + ("female", "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 0.0), + ("male", "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 1.0)) + ).toDF("id", "a", "b", "features", "label") + // assert(result.schema.toString == resultSchema.toString) + assert(result.collect() === expected.collect()) + } + test("attribute generation") { val formula = new RFormula().setFormula("id ~ a + b") val original = sqlContext.createDataFrame( @@ -118,9 +137,81 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext { val expectedAttrs = new AttributeGroup( "features", Array( - new BinaryAttribute(Some("a__bar"), Some(1)), - new BinaryAttribute(Some("a__foo"), Some(2)), + new BinaryAttribute(Some("a_bar"), Some(1)), + new BinaryAttribute(Some("a_foo"), Some(2)), new NumericAttribute(Some("b"), Some(3)))) assert(attrs === expectedAttrs) } + + test("numeric interaction") { + val formula = new RFormula().setFormula("a ~ b:c:d") + val original = sqlContext.createDataFrame( + Seq((1, 2, 4, 2), (2, 3, 4, 1)) + ).toDF("a", "b", "c", "d") + val model = formula.fit(original) + val result = model.transform(original) + val expected = sqlContext.createDataFrame( + Seq( + (1, 2, 4, 2, Vectors.dense(16.0), 1.0), + (2, 3, 4, 1, Vectors.dense(12.0), 2.0)) + ).toDF("a", "b", "c", "d", "features", "label") + assert(result.collect() === expected.collect()) + val attrs = AttributeGroup.fromStructField(result.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array[Attribute](new NumericAttribute(Some("b:c:d"), Some(1)))) + assert(attrs === expectedAttrs) + } + + test("factor numeric interaction") { + val formula = new RFormula().setFormula("id ~ a:b") + val original = sqlContext.createDataFrame( + Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5), (4, "baz", 5), (4, "baz", 5)) + ).toDF("id", "a", "b") + val model = formula.fit(original) + val result = model.transform(original) + val expected = sqlContext.createDataFrame( + Seq( + (1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0), 1.0), + (2, "bar", 4, Vectors.dense(0.0, 4.0, 0.0), 2.0), + (3, "bar", 5, Vectors.dense(0.0, 5.0, 0.0), 3.0), + (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0), + (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0), + (4, 
"baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0)) + ).toDF("id", "a", "b", "features", "label") + assert(result.collect() === expected.collect()) + val attrs = AttributeGroup.fromStructField(result.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array[Attribute]( + new NumericAttribute(Some("a_baz:b"), Some(1)), + new NumericAttribute(Some("a_bar:b"), Some(2)), + new NumericAttribute(Some("a_foo:b"), Some(3)))) + assert(attrs === expectedAttrs) + } + + test("factor factor interaction") { + val formula = new RFormula().setFormula("id ~ a:b") + val original = sqlContext.createDataFrame( + Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")) + ).toDF("id", "a", "b") + val model = formula.fit(original) + val result = model.transform(original) + val expected = sqlContext.createDataFrame( + Seq( + (1, "foo", "zq", Vectors.dense(0.0, 0.0, 1.0, 0.0), 1.0), + (2, "bar", "zq", Vectors.dense(1.0, 0.0, 0.0, 0.0), 2.0), + (3, "bar", "zz", Vectors.dense(0.0, 1.0, 0.0, 0.0), 3.0)) + ).toDF("id", "a", "b", "features", "label") + assert(result.collect() === expected.collect()) + val attrs = AttributeGroup.fromStructField(result.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array[Attribute]( + new NumericAttribute(Some("a_bar:b_zq"), Some(1)), + new NumericAttribute(Some("a_bar:b_zz"), Some(2)), + new NumericAttribute(Some("a_foo:b_zq"), Some(3)), + new NumericAttribute(Some("a_foo:b_zz"), Some(4)))) + assert(attrs === expectedAttrs) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala new file mode 100644 index 000000000000..553e0b870216 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class SQLTransformerSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + test("params") { + ParamsSuite.checkParams(new SQLTransformer()) + } + + test("transform numeric data") { + val original = sqlContext.createDataFrame( + Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2") + val sqlTrans = new SQLTransformer().setStatement( + "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__") + val result = sqlTrans.transform(original) + val resultSchema = sqlTrans.transformSchema(original.schema) + val expected = sqlContext.createDataFrame( + Seq((0, 1.0, 3.0, 4.0, 3.0), (2, 2.0, 5.0, 7.0, 10.0))) + .toDF("id", "v1", "v2", "v3", "v4") + assert(result.schema.toString == resultSchema.toString) + assert(resultSchema == expected.schema) + assert(result.collect().toSeq == expected.collect().toSeq) + } + + test("read/write") { + val t = new SQLTransformer() + .setStatement("select * from __THIS__") + testDefaultReadWrite(t) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala new file mode 100644 index 000000000000..1eae125a524e --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature + + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.feature +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{DataFrame, Row} + +class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest { + + @transient var data: Array[Vector] = _ + @transient var resWithStd: Array[Vector] = _ + @transient var resWithMean: Array[Vector] = _ + @transient var resWithBoth: Array[Vector] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + data = Array( + Vectors.dense(-2.0, 2.3, 0.0), + Vectors.dense(0.0, -5.1, 1.0), + Vectors.dense(1.7, -0.6, 3.3) + ) + resWithMean = Array( + Vectors.dense(-1.9, 3.433333333333, -1.433333333333), + Vectors.dense(0.1, -3.966666666667, -0.433333333333), + Vectors.dense(1.8, 0.533333333333, 1.866666666667) + ) + resWithStd = Array( + Vectors.dense(-1.079898494312, 0.616834091415, 0.0), + Vectors.dense(0.0, -1.367762550529, 0.590968109266), + Vectors.dense(0.917913720165, -0.160913241239, 1.950194760579) + ) + resWithBoth = Array( + Vectors.dense(-1.0259035695965, 0.920781324866, -0.8470542899497), + Vectors.dense(0.0539949247156, -1.063815317078, -0.256086180682), + Vectors.dense(0.9719086448809, 0.143033992212, 1.103140470631) + ) + } + + def assertResult(df: DataFrame): Unit = { + df.select("standardized_features", "expected").collect().foreach { + case Row(vector1: Vector, vector2: Vector) => + assert(vector1 ~== vector2 absTol 1E-5, + "The vector value is not correct after standardization.") + } + } + + test("params") { + ParamsSuite.checkParams(new StandardScaler) + ParamsSuite.checkParams(new StandardScalerModel("empty", + Vectors.dense(1.0), Vectors.dense(2.0))) + } + + test("Standardization with default parameter") { + val df0 = sqlContext.createDataFrame(data.zip(resWithStd)).toDF("features", "expected") + + val standardScaler0 = new StandardScaler() + .setInputCol("features") + .setOutputCol("standardized_features") + .fit(df0) + + assertResult(standardScaler0.transform(df0)) + } + + test("Standardization with setter") { + val df1 = sqlContext.createDataFrame(data.zip(resWithBoth)).toDF("features", "expected") + val df2 = sqlContext.createDataFrame(data.zip(resWithMean)).toDF("features", "expected") + val df3 = sqlContext.createDataFrame(data.zip(data)).toDF("features", "expected") + + val standardScaler1 = new StandardScaler() + .setInputCol("features") + .setOutputCol("standardized_features") + .setWithMean(true) + .setWithStd(true) + .fit(df1) + + val standardScaler2 = new StandardScaler() + .setInputCol("features") + .setOutputCol("standardized_features") + .setWithMean(true) + .setWithStd(false) + .fit(df2) + + val standardScaler3 = new StandardScaler() + .setInputCol("features") + .setOutputCol("standardized_features") + .setWithMean(false) + .setWithStd(false) + .fit(df3) + + assertResult(standardScaler1.transform(df1)) + assertResult(standardScaler2.transform(df2)) + assertResult(standardScaler3.transform(df3)) + } + + test("StandardScaler read/write") { + val t = new StandardScaler() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setWithStd(false) + .setWithMean(true) + testDefaultReadWrite(t) + } + + test("StandardScalerModel read/write") { + val instance = new 
StandardScalerModel("myStandardScalerModel", + Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0)) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.std === instance.std) + assert(newInstance.mean === instance.mean) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala index f01306f89cb5..fb217e0c1de9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} @@ -32,7 +33,9 @@ object StopWordsRemoverSuite extends SparkFunSuite { } } -class StopWordsRemoverSuite extends SparkFunSuite with MLlibTestSparkContext { +class StopWordsRemoverSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import StopWordsRemoverSuite._ test("StopWordsRemover default") { @@ -65,7 +68,7 @@ class StopWordsRemoverSuite extends SparkFunSuite with MLlibTestSparkContext { } test("StopWordsRemover with additional words") { - val stopWords = StopWords.EnglishStopWords ++ Array("python", "scala") + val stopWords = StopWords.English ++ Array("python", "scala") val remover = new StopWordsRemover() .setInputCol("raw") .setOutputCol("filtered") @@ -77,4 +80,13 @@ class StopWordsRemoverSuite extends SparkFunSuite with MLlibTestSparkContext { testStopWordsRemover(remover, dataSet) } + + test("read/write") { + val t = new StopWordsRemover() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setStopWords(Array("the", "a")) + .setCaseSensitive(true) + testDefaultReadWrite(t) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index d0295a0fe2fc..749bfac74782 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -17,17 +17,24 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types.{StringType, StructType, StructField, DoubleType} +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Row +import org.apache.spark.sql.functions.col -class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { +class StringIndexerSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new StringIndexer) val model = new StringIndexerModel("indexer", Array("a", "b")) + val modelWithoutUid = new StringIndexerModel(Array("a", "b")) ParamsSuite.checkParams(model) + ParamsSuite.checkParams(modelWithoutUid) } test("StringIndexer") { @@ -37,6 +44,10 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { .setInputCol("label") .setOutputCol("labelIndex") .fit(df) + + // copied model must have the same parent. 
+ MLTestingUtils.checkCopy(indexer) + val transformed = indexer.transform(df) val attr = Attribute.fromStructField(transformed.schema("labelIndex")) .asInstanceOf[NominalAttribute] @@ -47,19 +58,37 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { // a -> 0, b -> 2, c -> 1 val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) assert(output === expected) - // convert reverse our transform - val reversed = indexer.invert("labelIndex", "label2") - .transform(transformed) - .select("id", "label2") - assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet === - reversed.collect().map(r => (r.getInt(0), r.getString(1))).toSet) - // Check invert using only metadata - val inverse2 = new StringIndexerInverse() - .setInputCol("labelIndex") - .setOutputCol("label2") - val reversed2 = inverse2.transform(transformed).select("id", "label2") - assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet === - reversed2.collect().map(r => (r.getInt(0), r.getString(1))).toSet) + } + + test("StringIndexerUnseen") { + val data = sc.parallelize(Seq((0, "a"), (1, "b"), (4, "b")), 2) + val data2 = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c")), 2) + val df = sqlContext.createDataFrame(data).toDF("id", "label") + val df2 = sqlContext.createDataFrame(data2).toDF("id", "label") + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + .fit(df) + // Verify we throw by default with unseen values + intercept[SparkException] { + indexer.transform(df2).collect() + } + val indexerSkipInvalid = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + .setHandleInvalid("skip") + .fit(df) + // Verify that we skip the c record + val transformed = indexerSkipInvalid.transform(df2) + val attr = Attribute.fromStructField(transformed.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attr.values.get === Array("b", "a")) + val output = transformed.select("id", "labelIndex").map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + // a -> 1, b -> 0 + val expected = Set((0, 1.0), (1, 0.0)) + assert(output === expected) } test("StringIndexer with a numeric input column") { @@ -88,4 +117,86 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { val df = sqlContext.range(0L, 10L) assert(indexerModel.transform(df).eq(df)) } + + test("StringIndexer read/write") { + val t = new StringIndexer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setHandleInvalid("skip") + testDefaultReadWrite(t) + } + + test("StringIndexerModel read/write") { + val instance = new StringIndexerModel("myStringIndexerModel", Array("a", "b", "c")) + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setHandleInvalid("skip") + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.labels === instance.labels) + } + + test("IndexToString params") { + val idxToStr = new IndexToString() + ParamsSuite.checkParams(idxToStr) + } + + test("IndexToString.transform") { + val labels = Array("a", "b", "c") + val df0 = sqlContext.createDataFrame(Seq( + (0, "a"), (1, "b"), (2, "c"), (0, "a") + )).toDF("index", "expected") + + val idxToStr0 = new IndexToString() + .setInputCol("index") + .setOutputCol("actual") + .setLabels(labels) + idxToStr0.transform(df0).select("actual", "expected").collect().foreach { + case Row(actual, expected) => + assert(actual === expected) + } + + val attr = NominalAttribute.defaultAttr.withValues(labels) + val df1 = 
df0.select(col("index").as("indexWithAttr", attr.toMetadata()), col("expected")) + + val idxToStr1 = new IndexToString() + .setInputCol("indexWithAttr") + .setOutputCol("actual") + idxToStr1.transform(df1).select("actual", "expected").collect().foreach { + case Row(actual, expected) => + assert(actual === expected) + } + } + + test("StringIndexer, IndexToString are inverses") { + val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) + val df = sqlContext.createDataFrame(data).toDF("id", "label") + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + .fit(df) + val transformed = indexer.transform(df) + val idx2str = new IndexToString() + .setInputCol("labelIndex") + .setOutputCol("sameLabel") + .setLabels(indexer.labels) + idx2str.transform(transformed).select("label", "sameLabel").collect().foreach { + case Row(a: String, b: String) => + assert(a === b) + } + } + + test("IndexToString.transformSchema (SPARK-10573)") { + val idxToStr = new IndexToString().setInputCol("input").setOutputCol("output") + val inSchema = StructType(Seq(StructField("input", DoubleType))) + val outSchema = idxToStr.transformSchema(inSchema) + assert(outSchema("output").dataType === StringType) + } + + test("IndexToString read/write") { + val t = new IndexToString() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setLabels(Array("a", "b", "c")) + testDefaultReadWrite(t) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala index e5fd21c3f6fc..36e8e5d86838 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala @@ -21,20 +21,30 @@ import scala.beans.BeanInfo import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} @BeanInfo case class TokenizerTestData(rawText: String, wantedTokens: Array[String]) -class TokenizerSuite extends SparkFunSuite { +class TokenizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new Tokenizer) } + + test("read/write") { + val t = new Tokenizer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + testDefaultReadWrite(t) + } } -class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext { +class RegexTokenizerSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import org.apache.spark.ml.feature.RegexTokenizerSuite._ test("params") { @@ -48,13 +58,13 @@ class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext { .setInputCol("rawText") .setOutputCol("tokens") val dataset0 = sqlContext.createDataFrame(Seq( - TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization", ".")), - TokenizerTestData("Te,st. punct", Array("Te", ",", "st", ".", "punct")) + TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization", ".")), + TokenizerTestData("Te,st. 
punct", Array("te", ",", "st", ".", "punct")) )) testRegexTokenizer(tokenizer0, dataset0) val dataset1 = sqlContext.createDataFrame(Seq( - TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization")), + TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization")), TokenizerTestData("Te,st. punct", Array("punct")) )) tokenizer0.setMinTokenLength(3) @@ -64,11 +74,34 @@ class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext { .setInputCol("rawText") .setOutputCol("tokens") val dataset2 = sqlContext.createDataFrame(Seq( - TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization.")), - TokenizerTestData("Te,st. punct", Array("Te,st.", "punct")) + TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization.")), + TokenizerTestData("Te,st. punct", Array("te,st.", "punct")) )) testRegexTokenizer(tokenizer2, dataset2) } + + test("RegexTokenizer with toLowercase false") { + val tokenizer = new RegexTokenizer() + .setInputCol("rawText") + .setOutputCol("tokens") + .setToLowercase(false) + val dataset = sqlContext.createDataFrame(Seq( + TokenizerTestData("JAVA SCALA", Array("JAVA", "SCALA")), + TokenizerTestData("java scala", Array("java", "scala")) + )) + testRegexTokenizer(tokenizer, dataset) + } + + test("read/write") { + val t = new RegexTokenizer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMinTokenLength(2) + .setGaps(false) + .setPattern("hi") + .setToLowercase(false) + testDefaultReadWrite(t) + } } object RegexTokenizerSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index bb4d5b983e0d..9c1c00f41ab1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.feature +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} import org.apache.spark.ml.param.ParamsSuite @@ -25,7 +26,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row import org.apache.spark.sql.functions.col -class VectorAssemblerSuite extends SparkFunSuite with MLlibTestSparkContext { +class VectorAssemblerSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new VectorAssembler) @@ -67,6 +69,17 @@ class VectorAssemblerSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("transform should throw an exception in case of unsupported type") { + val df = sqlContext.createDataFrame(Seq(("a", "b", "c"))).toDF("a", "b", "c") + val assembler = new VectorAssembler() + .setInputCols(Array("a", "b", "c")) + .setOutputCol("features") + val thrown = intercept[SparkException] { + assembler.transform(df) + } + assert(thrown.getMessage contains "VectorAssembler does not support the StringType type") + } + test("ML attributes") { val browser = NominalAttribute.defaultAttr.withValues("chrome", "firefox", "safari") val hour = NumericAttribute.defaultAttr.withMin(0.0).withMax(24.0) @@ -101,4 +114,11 @@ class VectorAssemblerSuite extends SparkFunSuite with MLlibTestSparkContext { assert(features.getAttr(5) === NumericAttribute.defaultAttr.withIndex(5)) 
assert(features.getAttr(6) === NumericAttribute.defaultAttr.withIndex(6)) } + + test("read/write") { + val t = new VectorAssembler() + .setInputCols(Array("myInputCol", "myInputCol2")) + .setOutputCol("myOutputCol") + testDefaultReadWrite(t) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index 03120c828ca9..67817fa4baf5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -22,12 +22,14 @@ import scala.beans.{BeanInfo, BeanProperty} import org.apache.spark.{Logging, SparkException, SparkFunSuite} import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame -class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { +class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest with Logging { import VectorIndexerSuite.FeatureData @@ -109,6 +111,10 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext with L test("Throws error when given RDDs with different size vectors") { val vectorIndexer = getIndexer val model = vectorIndexer.fit(densePoints1) // vectors of length 3 + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + model.transform(densePoints1) // should work model.transform(sparsePoints1) // should work intercept[SparkException] { @@ -246,6 +252,23 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext with L } } } + + test("VectorIndexer read/write") { + val t = new VectorIndexer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMaxCategories(30) + testDefaultReadWrite(t) + } + + test("VectorIndexerModel read/write") { + val categoryMaps = Map(0 -> Map(0.0 -> 0, 1.0 -> 1), 1 -> Map(0.0 -> 0, 1.0 -> 1, + 2.0 -> 2, 3.0 -> 3), 2 -> Map(0.0 -> 0, -1.0 -> 1, 2.0 -> 2)) + val instance = new VectorIndexerModel("myVectorIndexerModel", 3, categoryMaps) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.numFeatures === instance.numFeatures) + assert(newInstance.categoryMaps === instance.categoryMaps) + } } private[feature] object VectorIndexerSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala new file mode 100644 index 000000000000..8acc3369c489 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{DataFrame, Row} + +class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + test("params") { + val slicer = new VectorSlicer + ParamsSuite.checkParams(slicer) + assert(slicer.getIndices.length === 0) + assert(slicer.getNames.length === 0) + withClue("VectorSlicer should not have any features selected by default") { + intercept[IllegalArgumentException] { + slicer.validateParams() + } + } + } + + test("feature validity checks") { + import VectorSlicer._ + assert(validIndices(Array(0, 1, 8, 2))) + assert(validIndices(Array.empty[Int])) + assert(!validIndices(Array(-1))) + assert(!validIndices(Array(1, 2, 1))) + + assert(validNames(Array("a", "b"))) + assert(validNames(Array.empty[String])) + assert(!validNames(Array("", "b"))) + assert(!validNames(Array("a", "b", "a"))) + } + + test("Test vector slicer") { + val data = Array( + Vectors.sparse(5, Seq((0, -2.0), (1, 2.3))), + Vectors.dense(-2.0, 2.3, 0.0, 0.0, 1.0), + Vectors.dense(0.0, 0.0, 0.0, 0.0, 0.0), + Vectors.dense(0.6, -1.1, -3.0, 4.5, 3.3), + Vectors.sparse(5, Seq()) + ) + + // Expected after selecting indices 1, 4 + val expected = Array( + Vectors.sparse(2, Seq((0, 2.3))), + Vectors.dense(2.3, 1.0), + Vectors.dense(0.0, 0.0), + Vectors.dense(-1.1, 3.3), + Vectors.sparse(2, Seq()) + ) + + val defaultAttr = NumericAttribute.defaultAttr + val attrs = Array("f0", "f1", "f2", "f3", "f4").map(defaultAttr.withName) + val attrGroup = new AttributeGroup("features", attrs.asInstanceOf[Array[Attribute]]) + + val resultAttrs = Array("f1", "f4").map(defaultAttr.withName) + val resultAttrGroup = new AttributeGroup("expected", resultAttrs.asInstanceOf[Array[Attribute]]) + + val rdd = sc.parallelize(data.zip(expected)).map { case (a, b) => Row(a, b) } + val df = sqlContext.createDataFrame(rdd, + StructType(Array(attrGroup.toStructField(), resultAttrGroup.toStructField()))) + + val vectorSlicer = new VectorSlicer().setInputCol("features").setOutputCol("result") + + def validateResults(df: DataFrame): Unit = { + df.select("result", "expected").collect().foreach { case Row(vec1: Vector, vec2: Vector) => + assert(vec1 === vec2) + } + val resultMetadata = AttributeGroup.fromStructField(df.schema("result")) + val expectedMetadata = AttributeGroup.fromStructField(df.schema("expected")) + assert(resultMetadata.numAttributes === expectedMetadata.numAttributes) + resultMetadata.attributes.get.zip(expectedMetadata.attributes.get).foreach { case (a, b) => + assert(a === b) + } + } + + vectorSlicer.setIndices(Array(1, 4)).setNames(Array.empty) + validateResults(vectorSlicer.transform(df)) + + vectorSlicer.setIndices(Array(1)).setNames(Array("f4")) + 
validateResults(vectorSlicer.transform(df)) + + vectorSlicer.setIndices(Array.empty).setNames(Array("f1", "f4")) + validateResults(vectorSlicer.transform(df)) + } + + test("read/write") { + val t = new VectorSlicer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setIndices(Array(1, 3)) + .setNames(Array("a", "d")) + testDefaultReadWrite(t) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index aa6ce533fd88..d561bbbb2552 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -19,13 +19,14 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel} -class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { +class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new Word2Vec) @@ -34,7 +35,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { } test("Word2Vec") { - val sqlContext = new SQLContext(sc) + + val sqlContext = this.sqlContext import sqlContext.implicits._ val sentence = "a b " * 100 + "a c " * 10 @@ -62,10 +64,147 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { .setSeed(42L) .fit(docDF) + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + + // These expectations are just magic values, characterizing the current + // behavior. The test needs to be updated to be more general, see SPARK-11502 + val magicExp = Vectors.dense(0.30153007534417237, -0.6833061711354689, 0.5116530778733167) model.transform(docDF).select("result", "expected").collect().foreach { case Row(vector1: Vector, vector2: Vector) => - assert(vector1 ~== vector2 absTol 1E-5, "Transformed vector is different with expected.") + assert(vector1 ~== magicExp absTol 1E-5, "Transformed vector is different with expected.") + } + } + + test("getVectors") { + + val sqlContext = this.sqlContext + import sqlContext.implicits._ + + val sentence = "a b " * 100 + "a c " * 10 + val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) + + val codes = Map( + "a" -> Array(-0.2811822295188904, -0.6356269121170044, -0.3020961284637451), + "b" -> Array(1.0309048891067505, -1.29472815990448, 0.22276712954044342), + "c" -> Array(-0.08456747233867645, 0.5137411952018738, 0.11731560528278351) + ) + val expectedVectors = codes.toSeq.sortBy(_._1).map { case (w, v) => Vectors.dense(v) } + + val docDF = doc.zip(doc).toDF("text", "alsotext") + + val model = new Word2Vec() + .setVectorSize(3) + .setInputCol("text") + .setOutputCol("result") + .setSeed(42L) + .fit(docDF) + + val realVectors = model.getVectors.sort("word").select("vector").map { + case Row(v: Vector) => v + }.collect() + // These expectations are just magic values, characterizing the current + // behavior. 
The test needs to be updated to be more general, see SPARK-11502 + val magicExpected = Seq( + Vectors.dense(0.3326166272163391, -0.5603077411651611, -0.2309209555387497), + Vectors.dense(0.32463887333869934, -0.9306551218032837, 1.393115520477295), + Vectors.dense(-0.27150997519493103, 0.4372006058692932, -0.13465698063373566) + ) + + realVectors.zip(magicExpected).foreach { + case (real, expected) => + assert(real ~== expected absTol 1E-5, "Actual vector is different from expected.") } } + + test("findSynonyms") { + + val sqlContext = this.sqlContext + import sqlContext.implicits._ + + val sentence = "a b " * 100 + "a c " * 10 + val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) + val docDF = doc.zip(doc).toDF("text", "alsotext") + + val model = new Word2Vec() + .setVectorSize(3) + .setInputCol("text") + .setOutputCol("result") + .setSeed(42L) + .fit(docDF) + + val expectedSimilarity = Array(0.18032623242822343, -0.5717976464798823) + val (synonyms, similarity) = model.findSynonyms("a", 2).map { + case Row(w: String, sim: Double) => (w, sim) + }.collect().unzip + + assert(synonyms.toArray === Array("b", "c")) + expectedSimilarity.zip(similarity).map { + case (expected, actual) => assert(math.abs((expected - actual) / expected) < 1E-5) + } + } + + test("window size") { + + val sqlContext = this.sqlContext + import sqlContext.implicits._ + + val sentence = "a q s t q s t b b b s t m s t m q " * 100 + "a c " * 10 + val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) + val docDF = doc.zip(doc).toDF("text", "alsotext") + + val model = new Word2Vec() + .setVectorSize(3) + .setWindowSize(2) + .setInputCol("text") + .setOutputCol("result") + .setSeed(42L) + .fit(docDF) + + val (synonyms, similarity) = model.findSynonyms("a", 6).map { + case Row(w: String, sim: Double) => (w, sim) + }.collect().unzip + + // Increase the window size + val biggerModel = new Word2Vec() + .setVectorSize(3) + .setInputCol("text") + .setOutputCol("result") + .setSeed(42L) + .setWindowSize(10) + .fit(docDF) + + val (synonymsLarger, similarityLarger) = model.findSynonyms("a", 6).map { + case Row(w: String, sim: Double) => (w, sim) + }.collect().unzip + // The similarity score should be very different with the larger window + assert(math.abs(similarity(5) - similarityLarger(5) / similarity(5)) > 1E-5) + } + + test("Word2Vec read/write") { + val t = new Word2Vec() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMaxIter(2) + .setMinCount(8) + .setNumPartitions(1) + .setSeed(42L) + .setStepSize(0.01) + .setVectorSize(100) + testDefaultReadWrite(t) + } + + test("Word2VecModel read/write") { + val word2VecMap = Map( + ("china", Array(0.50f, 0.50f, 0.50f, 0.50f)), + ("japan", Array(0.40f, 0.50f, 0.50f, 0.50f)), + ("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)), + ("korea", Array(0.45f, 0.60f, 0.60f, 0.60f)) + ) + val oldModel = new OldWord2VecModel(word2VecMap) + val instance = new Word2VecModel("myWord2VecModel", oldModel) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.getVectors.collect() === instance.getVectors.collect()) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala index 778abcba22c1..4e2d0e93bd41 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala @@ -42,7 +42,7 @@ private[ml] object TreeTests extends SparkFunSuite { data: RDD[LabeledPoint], 
categoricalFeatures: Map[Int, Int], numClasses: Int): DataFrame = { - val sqlContext = new SQLContext(data.sparkContext) + val sqlContext = SQLContext.getOrCreate(data.sparkContext) import sqlContext.implicits._ val df = data.toDF() val numFeatures = data.first().features.size @@ -124,4 +124,22 @@ private[ml] object TreeTests extends SparkFunSuite { "checkEqual failed since the two tree ensembles were not identical") } } + + /** + * Helper method for constructing a tree for testing. + * Given left, right children, construct a parent node. + * @param split Split for parent node + * @return Parent node with children attached + */ + def buildParentNode(left: Node, right: Node, split: Split): Node = { + val leftImp = left.impurityStats + val rightImp = right.impurityStats + val parentImp = leftImp.copy.add(rightImp) + val leftWeight = leftImp.count / parentImp.count.toDouble + val rightWeight = rightImp.count / parentImp.count.toDouble + val gain = parentImp.calculate() - + (leftWeight * leftImp.calculate() + rightWeight * rightImp.calculate()) + val pred = parentImp.predict + new InternalNode(pred, parentImp.calculate(), gain, left, right, split, parentImp) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala new file mode 100644 index 000000000000..b542ba3dc54d --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.optim + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.feature.Instance +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.rdd.RDD + +class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext { + + private var instances: RDD[Instance] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + /* + R code: + + A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) + b <- c(17, 19, 23, 29) + w <- c(1, 2, 3, 4) + */ + instances = sc.parallelize(Seq( + Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)), + Instance(29.0, 4.0, Vectors.dense(3.0, 13.0)) + ), 2) + } + + test("WLS against lm") { + /* + R code: + + df <- as.data.frame(cbind(A, b)) + for (formula in c(b ~ . 
-1, b ~ .)) { + model <- lm(formula, data=df, weights=w) + print(as.vector(coef(model))) + } + + [1] -3.727121 3.009983 + [1] 18.08 6.08 -0.60 + */ + + val expected = Seq( + Vectors.dense(0.0, -3.727121, 3.009983), + Vectors.dense(18.08, 6.08, -0.60)) + + var idx = 0 + for (fitIntercept <- Seq(false, true)) { + val wls = new WeightedLeastSquares( + fitIntercept, regParam = 0.0, standardizeFeatures = false, standardizeLabel = false) + .fit(instances) + val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1)) + assert(actual ~== expected(idx) absTol 1e-4) + idx += 1 + } + } + + test("WLS against glmnet") { + /* + R code: + + library(glmnet) + + for (intercept in c(FALSE, TRUE)) { + for (lambda in c(0.0, 0.1, 1.0)) { + for (standardize in c(FALSE, TRUE)) { + model <- glmnet(A, b, weights=w, intercept=intercept, lambda=lambda, + standardize=standardize, alpha=0, thresh=1E-14) + print(as.vector(coef(model))) + } + } + } + + [1] 0.000000 -3.727117 3.009982 + [1] 0.000000 -3.727117 3.009982 + [1] 0.000000 -3.307532 2.924206 + [1] 0.000000 -2.914790 2.840627 + [1] 0.000000 -1.526575 2.558158 + [1] 0.00000000 0.06984238 2.20488344 + [1] 18.0799727 6.0799832 -0.5999941 + [1] 18.0799727 6.0799832 -0.5999941 + [1] 13.5356178 3.2714044 0.3770744 + [1] 14.064629 3.565802 0.269593 + [1] 10.1238013 0.9708569 1.1475466 + [1] 13.1860638 2.1761382 0.6213134 + */ + + val expected = Seq( + Vectors.dense(0.0, -3.727117, 3.009982), + Vectors.dense(0.0, -3.727117, 3.009982), + Vectors.dense(0.0, -3.307532, 2.924206), + Vectors.dense(0.0, -2.914790, 2.840627), + Vectors.dense(0.0, -1.526575, 2.558158), + Vectors.dense(0.0, 0.06984238, 2.20488344), + Vectors.dense(18.0799727, 6.0799832, -0.5999941), + Vectors.dense(18.0799727, 6.0799832, -0.5999941), + Vectors.dense(13.5356178, 3.2714044, 0.3770744), + Vectors.dense(14.064629, 3.565802, 0.269593), + Vectors.dense(10.1238013, 0.9708569, 1.1475466), + Vectors.dense(13.1860638, 2.1761382, 0.6213134)) + + var idx = 0 + for (fitIntercept <- Seq(false, true); + regParam <- Seq(0.0, 0.1, 1.0); + standardizeFeatures <- Seq(false, true)) { + val wls = new WeightedLeastSquares( + fitIntercept, regParam, standardizeFeatures, standardizeLabel = true) + .fit(instances) + val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1)) + assert(actual ~== expected(idx) absTol 1e-4) + idx += 1 + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 050d4170ea01..a1878be747ce 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -18,13 +18,141 @@ package org.apache.spark.ml.param import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.{Vector, Vectors} class ParamsSuite extends SparkFunSuite { + test("json encode/decode") { + val dummy = new Params { + override def copy(extra: ParamMap): Params = defaultCopy(extra) + + override val uid: String = "dummy" + } + + { // BooleanParam + val param = new BooleanParam(dummy, "name", "doc") + for (value <- Seq(true, false)) { + val json = param.jsonEncode(value) + assert(param.jsonDecode(json) === value) + } + } + + { // IntParam + val param = new IntParam(dummy, "name", "doc") + for (value <- Seq(Int.MinValue, -1, 0, 1, Int.MaxValue)) { + val json = param.jsonEncode(value) + assert(param.jsonDecode(json) === value) + } + } + + { // LongParam + val param = new 
LongParam(dummy, "name", "doc") + for (value <- Seq(Long.MinValue, -1L, 0L, 1L, Long.MaxValue)) { + val json = param.jsonEncode(value) + assert(param.jsonDecode(json) === value) + } + } + + { // FloatParam + val param = new FloatParam(dummy, "name", "doc") + for (value <- Seq(Float.NaN, Float.NegativeInfinity, Float.MinValue, -1.0f, -0.5f, 0.0f, + Float.MinPositiveValue, 0.5f, 1.0f, Float.MaxValue, Float.PositiveInfinity)) { + val json = param.jsonEncode(value) + val decoded = param.jsonDecode(json) + if (value.isNaN) { + assert(decoded.isNaN) + } else { + assert(decoded === value) + } + } + } + + { // DoubleParam + val param = new DoubleParam(dummy, "name", "doc") + for (value <- Seq(Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, -0.5, 0.0, + Double.MinPositiveValue, 0.5, 1.0, Double.MaxValue, Double.PositiveInfinity)) { + val json = param.jsonEncode(value) + val decoded = param.jsonDecode(json) + if (value.isNaN) { + assert(decoded.isNaN) + } else { + assert(decoded === value) + } + } + } + + { // Param[String] + val param = new Param[String](dummy, "name", "doc") + // Currently we do not support null. + for (value <- Seq("", "1", "abc", "quote\"", "newline\n")) { + val json = param.jsonEncode(value) + assert(param.jsonDecode(json) === value) + } + } + + { // Param[Vector] + val param = new Param[Vector](dummy, "name", "doc") + val values = Seq( + Vectors.dense(Array.empty[Double]), + Vectors.dense(0.0, 2.0), + Vectors.sparse(0, Array.empty, Array.empty), + Vectors.sparse(2, Array(1), Array(2.0))) + for (value <- values) { + val json = param.jsonEncode(value) + assert(param.jsonDecode(json) === value) + } + } + + { // IntArrayParam + val param = new IntArrayParam(dummy, "name", "doc") + val values: Seq[Array[Int]] = Seq( + Array(), + Array(1), + Array(Int.MinValue, 0, Int.MaxValue)) + for (value <- values) { + val json = param.jsonEncode(value) + assert(param.jsonDecode(json) === value) + } + } + + { // DoubleArrayParam + val param = new DoubleArrayParam(dummy, "name", "doc") + val values: Seq[Array[Double]] = Seq( + Array(), + Array(1.0), + Array(Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, 0.0, + Double.MinPositiveValue, 1.0, Double.MaxValue, Double.PositiveInfinity)) + for (value <- values) { + val json = param.jsonEncode(value) + val decoded = param.jsonDecode(json) + assert(decoded.length === value.length) + decoded.zip(value).foreach { case (actual, expected) => + if (expected.isNaN) { + assert(actual.isNaN) + } else { + assert(actual === expected) + } + } + } + } + + { // StringArrayParam + val param = new StringArrayParam(dummy, "name", "doc") + val values: Seq[Array[String]] = Seq( + Array(), + Array(""), + Array("", "1", "abc", "quote\"", "newline\n")) + for (value <- values) { + val json = param.jsonEncode(value) + assert(param.jsonDecode(json) === value) + } + } + } + test("param") { val solver = new TestParams() val uid = solver.uid - import solver.{maxIter, inputCol} + import solver.{inputCol, maxIter} assert(maxIter.name === "maxIter") assert(maxIter.doc === "maximum number of iterations (>= 0)") @@ -40,6 +168,10 @@ class ParamsSuite extends SparkFunSuite { assert(inputCol.toString === s"${uid}__inputCol") + intercept[java.util.NoSuchElementException] { + solver.getOrDefault(solver.handleInvalid) + } + intercept[IllegalArgumentException] { solver.setMaxIter(-1) } @@ -63,7 +195,7 @@ class ParamsSuite extends SparkFunSuite { test("param map") { val solver = new TestParams() - import solver.{maxIter, inputCol} + import solver.{inputCol, maxIter} 
val map0 = ParamMap.empty @@ -102,12 +234,13 @@ class ParamsSuite extends SparkFunSuite { test("params") { val solver = new TestParams() - import solver.{maxIter, inputCol} + import solver.{handleInvalid, inputCol, maxIter} val params = solver.params - assert(params.length === 2) - assert(params(0).eq(inputCol), "params must be ordered by name") - assert(params(1).eq(maxIter)) + assert(params.length === 3) + assert(params(0).eq(handleInvalid), "params must be ordered by name") + assert(params(1).eq(inputCol), "params must be ordered by name") + assert(params(2).eq(maxIter)) assert(!solver.isSet(maxIter)) assert(solver.isDefined(maxIter)) @@ -122,7 +255,7 @@ class ParamsSuite extends SparkFunSuite { assert(solver.explainParam(maxIter) === "maxIter: maximum number of iterations (>= 0) (default: 10, current: 100)") assert(solver.explainParams() === - Seq(inputCol, maxIter).map(solver.explainParam).mkString("\n")) + Seq(handleInvalid, inputCol, maxIter).map(solver.explainParam).mkString("\n")) assert(solver.getParam("inputCol").eq(inputCol)) assert(solver.getParam("maxIter").eq(maxIter)) @@ -151,6 +284,11 @@ class ParamsSuite extends SparkFunSuite { solver.clearMaxIter() assert(!solver.isSet(maxIter)) + // Re-set and clear maxIter using the generic clear API + solver.setMaxIter(10) + solver.clear(maxIter) + assert(!solver.isSet(maxIter)) + val copied = solver.copy(ParamMap(solver.maxIter -> 50)) assert(copied.uid === solver.uid) assert(copied.getInputCol === solver.getInputCol) @@ -199,6 +337,17 @@ class ParamsSuite extends SparkFunSuite { val inArray = ParamValidators.inArray[Int](Array(1, 2)) assert(inArray(1) && inArray(2) && !inArray(0)) + + val arrayLengthGt = ParamValidators.arrayLengthGt[Int](2.0) + assert(arrayLengthGt(Array(0, 1, 2)) && !arrayLengthGt(Array(0, 1))) + } + + test("Params.copyValues") { + val t = new TestParams() + val t2 = t.copy(ParamMap.empty) + assert(!t2.isSet(t2.maxIter)) + val t3 = t.copy(ParamMap(t.maxIter -> 20)) + assert(t3.isSet(t3.maxIter)) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala index 275924834453..9d23547f2844 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala @@ -17,11 +17,12 @@ package org.apache.spark.ml.param -import org.apache.spark.ml.param.shared.{HasInputCol, HasMaxIter} +import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasMaxIter} import org.apache.spark.ml.util.Identifiable /** A subclass of Params for testing. 
*/ -class TestParams(override val uid: String) extends Params with HasMaxIter with HasInputCol { +class TestParams(override val uid: String) extends Params with HasHandleInvalid with HasMaxIter + with HasInputCol { def this() = this(Identifiable.randomUID("testParams")) diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 2e5cfe7027eb..2c3fb84160dc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.ml.recommendation -import java.io.File import java.util.Random import scala.collection.mutable @@ -26,27 +25,26 @@ import scala.language.existentials import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.apache.spark.util.Utils import org.apache.spark.{Logging, SparkException, SparkFunSuite} import org.apache.spark.ml.recommendation.ALS._ +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.util.Utils +import org.apache.spark.sql.{DataFrame, Row} -class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { - private var tempDir: File = _ +class ALSSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest with Logging { override def beforeAll(): Unit = { super.beforeAll() - tempDir = Utils.createTempDir() sc.setCheckpointDir(tempDir.getAbsolutePath) } override def afterAll(): Unit = { - Utils.deleteRecursively(tempDir) super.afterAll() } @@ -185,7 +183,7 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { assert(compressed.dstPtrs.toSeq === Seq(0, 2, 3, 4, 5)) var decompressed = ArrayBuffer.empty[(Int, Int, Int, Float)] var i = 0 - while (i < compressed.srcIds.size) { + while (i < compressed.srcIds.length) { var j = compressed.dstPtrs(i) while (j < compressed.dstPtrs(i + 1)) { val dstEncodedIndex = compressed.dstEncodedIndices(j) @@ -374,6 +372,9 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { } logInfo(s"Test RMSE is $rmse.") assert(rmse < targetRMSE) + + // copied model must have the same parent. 
+ MLTestingUtils.checkCopy(model) } test("exact rank-1 matrix") { @@ -479,4 +480,67 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { ALS.train(ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2, implicitPrefs = true, seed = 0) } + + test("read/write") { + import ALSSuite._ + val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1) + val als = new ALS() + allEstimatorParamSettings.foreach { case (p, v) => + als.set(als.getParam(p), v) + } + val sqlContext = this.sqlContext + import sqlContext.implicits._ + val model = als.fit(ratings.toDF()) + + // Test Estimator save/load + val als2 = testDefaultReadWrite(als) + allEstimatorParamSettings.foreach { case (p, v) => + val param = als.getParam(p) + assert(als.get(param).get === als2.get(param).get) + } + + // Test Model save/load + val model2 = testDefaultReadWrite(model) + allModelParamSettings.foreach { case (p, v) => + val param = model.getParam(p) + assert(model.get(param).get === model2.get(param).get) + } + assert(model.rank === model2.rank) + def getFactors(df: DataFrame): Set[(Int, Array[Float])] = { + df.select("id", "features").collect().map { case r => + (r.getInt(0), r.getAs[Array[Float]](1)) + }.toSet + } + assert(getFactors(model.userFactors) === getFactors(model2.userFactors)) + assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors)) + } +} + +object ALSSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allModelParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPredictionCol" + ) + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allEstimatorParamSettings: Map[String, Any] = allModelParamSettings ++ Map( + "maxIter" -> 1, + "rank" -> 1, + "regParam" -> 0.01, + "numUserBlocks" -> 2, + "numItemBlocks" -> 2, + "implicitPrefs" -> true, + "alpha" -> 0.9, + "nonnegative" -> true, + "checkpointInterval" -> 20 + ) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala new file mode 100644 index 000000000000..d718ef63b531 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -0,0 +1,364 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.regression + +import scala.util.Random + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{DataFrame, Row} + +class AFTSurvivalRegressionSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + @transient var datasetUnivariate: DataFrame = _ + @transient var datasetMultivariate: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + datasetUnivariate = sqlContext.createDataFrame( + sc.parallelize(generateAFTInput( + 1, Array(5.5), Array(0.8), 1000, 42, 1.0, 2.0, 2.0))) + datasetMultivariate = sqlContext.createDataFrame( + sc.parallelize(generateAFTInput( + 2, Array(0.9, -1.3), Array(0.7, 1.2), 1000, 42, 1.5, 2.5, 2.0))) + } + + test("params") { + ParamsSuite.checkParams(new AFTSurvivalRegression) + val model = new AFTSurvivalRegressionModel("aftSurvReg", Vectors.dense(0.0), 0.0, 0.0) + ParamsSuite.checkParams(model) + } + + test("aft survival regression: default params") { + val aftr = new AFTSurvivalRegression + assert(aftr.getLabelCol === "label") + assert(aftr.getFeaturesCol === "features") + assert(aftr.getPredictionCol === "prediction") + assert(aftr.getCensorCol === "censor") + assert(aftr.getFitIntercept) + assert(aftr.getMaxIter === 100) + assert(aftr.getTol === 1E-6) + val model = aftr.setQuantileProbabilities(Array(0.1, 0.8)) + .setQuantilesCol("quantiles") + .fit(datasetUnivariate) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + + model.transform(datasetUnivariate) + .select("label", "prediction", "quantiles") + .collect() + assert(model.getFeaturesCol === "features") + assert(model.getPredictionCol === "prediction") + assert(model.getQuantileProbabilities === Array(0.1, 0.8)) + assert(model.getQuantilesCol === "quantiles") + assert(model.intercept !== 0.0) + assert(model.hasParent) + } + + def generateAFTInput( + numFeatures: Int, + xMean: Array[Double], + xVariance: Array[Double], + nPoints: Int, + seed: Int, + weibullShape: Double, + weibullScale: Double, + exponentialMean: Double): Seq[AFTPoint] = { + + def censor(x: Double, y: Double): Double = { if (x <= y) 1.0 else 0.0 } + + val weibull = new WeibullGenerator(weibullShape, weibullScale) + weibull.setSeed(seed) + + val exponential = new ExponentialGenerator(exponentialMean) + exponential.setSeed(seed) + + val rnd = new Random(seed) + val x = Array.fill[Array[Double]](nPoints)(Array.fill[Double](numFeatures)(rnd.nextDouble())) + + x.foreach { v => + var i = 0 + val len = v.length + while (i < len) { + v(i) = (v(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i) + i += 1 + } + } + val y = (1 to nPoints).map { i => (weibull.nextValue(), exponential.nextValue()) } + + y.zip(x).map { p => AFTPoint(Vectors.dense(p._2), p._1._1, censor(p._1._1, p._1._2)) } + } + + test("aft survival regression with univariate") { + val quantileProbabilities = Array(0.1, 0.5, 0.9) + val trainer = new AFTSurvivalRegression() + .setQuantileProbabilities(quantileProbabilities) + .setQuantilesCol("quantiles") + val model = trainer.fit(datasetUnivariate) + + /* + Using the following R code to load the data and train the model using survival package. 
+ + library("survival") + data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) + features <- data$V1 + censor <- data$V2 + label <- data$V3 + sr.fit <- survreg(Surv(label, censor) ~ features, dist='weibull') + summary(sr.fit) + + Value Std. Error z p + (Intercept) 1.759 0.4141 4.247 2.16e-05 + features -0.039 0.0735 -0.531 5.96e-01 + Log(scale) 0.344 0.0379 9.073 1.16e-19 + + Scale= 1.41 + + Weibull distribution + Loglik(model)= -1152.2 Loglik(intercept only)= -1152.3 + Chisq= 0.28 on 1 degrees of freedom, p= 0.6 + Number of Newton-Raphson Iterations: 5 + n= 1000 + */ + val coefficientsR = Vectors.dense(-0.039) + val interceptR = 1.759 + val scaleR = 1.41 + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.coefficients ~== coefficientsR relTol 1E-3) + assert(model.scale ~== scaleR relTol 1E-3) + + /* + Using the following R code to predict. + + testdata <- list(features=6.559282795753792) + responsePredict <- predict(sr.fit, newdata=testdata) + responsePredict + + 1 + 4.494763 + + quantilePredict <- predict(sr.fit, newdata=testdata, type='quantile', p=c(0.1, 0.5, 0.9)) + quantilePredict + + [1] 0.1879174 2.6801195 14.5779394 + */ + val features = Vectors.dense(6.559282795753792) + val responsePredictR = 4.494763 + val quantilePredictR = Vectors.dense(0.1879174, 2.6801195, 14.5779394) + + assert(model.predict(features) ~== responsePredictR relTol 1E-3) + assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3) + + model.transform(datasetUnivariate).select("features", "prediction", "quantiles") + .collect().foreach { + case Row(features: Vector, prediction: Double, quantiles: Vector) => + assert(prediction ~== model.predict(features) relTol 1E-5) + assert(quantiles ~== model.predictQuantiles(features) relTol 1E-5) + } + } + + test("aft survival regression with multivariate") { + val quantileProbabilities = Array(0.1, 0.5, 0.9) + val trainer = new AFTSurvivalRegression() + .setQuantileProbabilities(quantileProbabilities) + .setQuantilesCol("quantiles") + val model = trainer.fit(datasetMultivariate) + + /* + Using the following R code to load the data and train the model using survival package. + + library("survival") + data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) + feature1 <- data$V1 + feature2 <- data$V2 + censor <- data$V3 + label <- data$V4 + sr.fit <- survreg(Surv(label, censor) ~ feature1 + feature2, dist='weibull') + summary(sr.fit) + + Value Std. Error z p + (Intercept) 1.9206 0.1057 18.171 8.78e-74 + feature1 -0.0844 0.0611 -1.381 1.67e-01 + feature2 0.0677 0.0468 1.447 1.48e-01 + Log(scale) -0.0236 0.0436 -0.542 5.88e-01 + + Scale= 0.977 + + Weibull distribution + Loglik(model)= -1070.7 Loglik(intercept only)= -1072.7 + Chisq= 3.91 on 2 degrees of freedom, p= 0.14 + Number of Newton-Raphson Iterations: 5 + n= 1000 + */ + val coefficientsR = Vectors.dense(-0.0844, 0.0677) + val interceptR = 1.9206 + val scaleR = 0.977 + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.coefficients ~== coefficientsR relTol 1E-3) + assert(model.scale ~== scaleR relTol 1E-3) + + /* + Using the following R code to predict. 
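     (Sanity check, not part of this patch: for a Weibull AFT model the R predictions follow
      directly from the fit above. The response prediction is lambda = exp(intercept +
      coefficients . features) and the p-th quantile is lambda * (-log(1 - p))^scale, so

        exp(1.9206 + (-0.0844) * 2.2334 + 0.0677 * (-2.5321)) ~= 4.76
        4.76 * (-log(0.5))^0.977 ~= 3.33

      which matches responsePredict and the middle quantile below.)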
+ testdata <- list(feature1=2.233396950271428, feature2=-2.5321374085997683) + responsePredict <- predict(sr.fit, newdata=testdata) + responsePredict + + 1 + 4.761219 + + quantilePredict <- predict(sr.fit, newdata=testdata, type='quantile', p=c(0.1, 0.5, 0.9)) + quantilePredict + + [1] 0.5287044 3.3285858 10.7517072 + */ + val features = Vectors.dense(2.233396950271428, -2.5321374085997683) + val responsePredictR = 4.761219 + val quantilePredictR = Vectors.dense(0.5287044, 3.3285858, 10.7517072) + + assert(model.predict(features) ~== responsePredictR relTol 1E-3) + assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3) + + model.transform(datasetMultivariate).select("features", "prediction", "quantiles") + .collect().foreach { + case Row(features: Vector, prediction: Double, quantiles: Vector) => + assert(prediction ~== model.predict(features) relTol 1E-5) + assert(quantiles ~== model.predictQuantiles(features) relTol 1E-5) + } + } + + test("aft survival regression w/o intercept") { + val quantileProbabilities = Array(0.1, 0.5, 0.9) + val trainer = new AFTSurvivalRegression() + .setQuantileProbabilities(quantileProbabilities) + .setQuantilesCol("quantiles") + .setFitIntercept(false) + val model = trainer.fit(datasetMultivariate) + + /* + Using the following R code to load the data and train the model using survival package. + + library("survival") + data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) + feature1 <- data$V1 + feature2 <- data$V2 + censor <- data$V3 + label <- data$V4 + sr.fit <- survreg(Surv(label, censor) ~ feature1 + feature2 - 1, dist='weibull') + summary(sr.fit) + + Value Std. Error z p + feature1 0.896 0.0685 13.1 3.93e-39 + feature2 -0.709 0.0522 -13.6 5.78e-42 + Log(scale) 0.420 0.0401 10.5 1.23e-25 + + Scale= 1.52 + + Weibull distribution + Loglik(model)= -1292.4 Loglik(intercept only)= -1072.7 + Chisq= -439.57 on 1 degrees of freedom, p= 1 + Number of Newton-Raphson Iterations: 6 + n= 1000 + */ + val coefficientsR = Vectors.dense(0.896, -0.709) + val interceptR = 0.0 + val scaleR = 1.52 + + assert(model.intercept === interceptR) + assert(model.coefficients ~== coefficientsR relTol 1E-3) + assert(model.scale ~== scaleR relTol 1E-3) + + /* + Using the following R code to predict. 
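     (Same sanity check as above, now without an intercept term:
      exp(0.896 * 2.2334 + (-0.709) * (-2.5321)) ~= exp(3.80) ~= 44.5,
      matching responsePredict below. Not part of this patch.)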
+ testdata <- list(feature1=2.233396950271428, feature2=-2.5321374085997683) + responsePredict <- predict(sr.fit, newdata=testdata) + responsePredict + + 1 + 44.54465 + + quantilePredict <- predict(sr.fit, newdata=testdata, type='quantile', p=c(0.1, 0.5, 0.9)) + quantilePredict + + [1] 1.452103 25.506077 158.428600 + */ + val features = Vectors.dense(2.233396950271428, -2.5321374085997683) + val responsePredictR = 44.54465 + val quantilePredictR = Vectors.dense(1.452103, 25.506077, 158.428600) + + assert(model.predict(features) ~== responsePredictR relTol 1E-3) + assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3) + + model.transform(datasetMultivariate).select("features", "prediction", "quantiles") + .collect().foreach { + case Row(features: Vector, prediction: Double, quantiles: Vector) => + assert(prediction ~== model.predict(features) relTol 1E-5) + assert(quantiles ~== model.predictQuantiles(features) relTol 1E-5) + } + } + + test("aft survival regression w/o quantiles column") { + val trainer = new AFTSurvivalRegression + val model = trainer.fit(datasetUnivariate) + val outputDf = model.transform(datasetUnivariate) + + assert(outputDf.schema.fieldNames.contains("quantiles") === false) + + outputDf.select("features", "prediction") + .collect().foreach { + case Row(features: Vector, prediction: Double) => + assert(prediction ~== model.predict(features) relTol 1E-5) + } + } + + test("read/write") { + def checkModelData( + model: AFTSurvivalRegressionModel, + model2: AFTSurvivalRegressionModel): Unit = { + assert(model.intercept === model2.intercept) + assert(model.coefficients === model2.coefficients) + assert(model.scale === model2.scale) + } + val aft = new AFTSurvivalRegression() + testEstimatorAndModelReadWrite(aft, datasetMultivariate, + AFTSurvivalRegressionSuite.allParamSettings, checkModelData) + } +} + +object AFTSurvivalRegressionSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. 
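   *
   * (Usage sketch, not part of this patch: `testEstimatorAndModelReadWrite` is expected to set
   * each of these Params on a fresh estimator, fit it on the supplied dataset, round-trip both
   * the estimator and the fitted model through save/load, compare every Param value, and run
   * the given `checkModelData` on the original and reloaded models -- mirroring the manual
   * pattern in the ALS "read/write" test earlier in this patch.)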
+ */ + val allParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPrediction", + "fitIntercept" -> true, + "maxIter" -> 2, + "tol" -> 0.01 + ) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 33aa9d0d6234..6999a910c34a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} @@ -48,7 +49,8 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContex .setImpurity("variance") .setMaxDepth(2) .setMaxBins(100) - val categoricalFeatures = Map(0 -> 3, 1-> 3) + .setSeed(1) + val categoricalFeatures = Map(0 -> 3, 1 -> 3) compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures) } @@ -57,10 +59,20 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContex .setImpurity("variance") .setMaxDepth(2) .setMaxBins(100) - val categoricalFeatures = Map(0 -> 2, 1-> 2) + val categoricalFeatures = Map(0 -> 2, 1 -> 2) compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures) } + test("copied model must have the same parent") { + val categoricalFeatures = Map(0 -> 2, 1 -> 2) + val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0) + val model = new DecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(8).fit(df) + MLTestingUtils.checkCopy(model) + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// @@ -78,6 +90,7 @@ private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite { data: RDD[LabeledPoint], dt: DecisionTreeRegressor, categoricalFeatures: Map[Int, Int]): Unit = { + val numFeatures = data.first().features.size val oldStrategy = dt.getOldStrategy(categoricalFeatures) val oldTree = OldDecisionTree.train(data, oldStrategy) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) @@ -86,5 +99,6 @@ private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite { val oldTreeAsNew = DecisionTreeRegressionModel.fromOld( oldTree, newTree.parent.asInstanceOf[DecisionTreeRegressor], categoricalFeatures) TreeTests.checkEqual(oldTreeAsNew, newTree) + assert(newTree.numFeatures === numFeatures) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index dbdce0c9dea5..09326600e620 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import 
org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} @@ -82,6 +83,9 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { .setMaxDepth(2) .setMaxIter(2) val model = gbt.fit(df) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) val preds = model.transform(df) val predictions = preds.select("prediction").map(_.getDouble(0)) // Checks based on SPARK-8736 (to ensure it is not doing classification) @@ -104,6 +108,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { sc.checkpointDir = None Utils.deleteRecursively(tempDir) + } // TODO: Reinstate test once runWithValidation is implemented SPARK-7132 @@ -151,7 +156,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { */ } -private object GBTRegressorSuite { +private object GBTRegressorSuite extends SparkFunSuite { /** * Train 2 models on the given dataset, one using the old API and one using the new API. @@ -162,6 +167,7 @@ private object GBTRegressorSuite { validationData: Option[RDD[LabeledPoint]], gbt: GBTRegressor, categoricalFeatures: Map[Int, Int]): Unit = { + val numFeatures = data.first().features.size val oldBoostingStrategy = gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) val oldGBT = new OldGBT(oldBoostingStrategy) val oldModel = oldGBT.run(data) @@ -169,7 +175,9 @@ private object GBTRegressorSuite { val newModel = gbt.fit(newData) // Use parent from newTree since this is not checked anyways. val oldModelAsNew = GBTRegressionModel.fromOld( - oldModel, newModel.parent.asInstanceOf[GBTRegressor], categoricalFeatures) + oldModel, newModel.parent.asInstanceOf[GBTRegressor], categoricalFeatures, numFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) + assert(newModel.numFeatures === numFeatures) + assert(oldModelAsNew.numFeatures === numFeatures) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala index 66e4b170bae8..f067c29d27a7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala @@ -19,57 +19,49 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.types.{DoubleType, StructField, StructType} import org.apache.spark.sql.{DataFrame, Row} -class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { - private val schema = StructType( - Array( - StructField("label", DoubleType), - StructField("features", DoubleType), - StructField("weight", DoubleType))) - - private val predictionSchema = StructType(Array(StructField("features", DoubleType))) +class IsotonicRegressionSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { private def generateIsotonicInput(labels: Seq[Double]): DataFrame = { - val data = Seq.tabulate(labels.size)(i => Row(labels(i), i.toDouble, 1d)) - val parallelData = sc.parallelize(data) - - sqlContext.createDataFrame(parallelData, schema) + sqlContext.createDataFrame( + labels.zipWithIndex.map { case (label, i) => (label, i.toDouble, 1.0) } + ).toDF("label", "features", "weight") } private 
def generatePredictionInput(features: Seq[Double]): DataFrame = { - val data = Seq.tabulate(features.size)(i => Row(features(i))) - - val parallelData = sc.parallelize(data) - sqlContext.createDataFrame(parallelData, predictionSchema) + sqlContext.createDataFrame(features.map(Tuple1.apply)) + .toDF("features") } test("isotonic regression predictions") { val dataset = generateIsotonicInput(Seq(1, 2, 3, 1, 6, 17, 16, 17, 18)) - val trainer = new IsotonicRegression().setIsotonicParam(true) + val ir = new IsotonicRegression().setIsotonic(true) - val model = trainer.fit(dataset) + val model = ir.fit(dataset) val predictions = model .transform(dataset) - .select("prediction").map { - case Row(pred) => pred + .select("prediction").map { case Row(pred) => + pred }.collect() assert(predictions === Array(1, 2, 2, 2, 6, 16.5, 16.5, 17, 18)) - assert(model.parentModel.boundaries === Array(0, 1, 3, 4, 5, 6, 7, 8)) - assert(model.parentModel.predictions === Array(1, 2, 2, 6, 16.5, 16.5, 17.0, 18.0)) - assert(model.parentModel.isotonic) + assert(model.boundaries === Vectors.dense(0, 1, 3, 4, 5, 6, 7, 8)) + assert(model.predictions === Vectors.dense(1, 2, 2, 6, 16.5, 16.5, 17.0, 18.0)) + assert(model.getIsotonic) } test("antitonic regression predictions") { val dataset = generateIsotonicInput(Seq(7, 5, 3, 5, 1)) - val trainer = new IsotonicRegression().setIsotonicParam(false) + val ir = new IsotonicRegression().setIsotonic(false) - val model = trainer.fit(dataset) + val model = ir.fit(dataset) val features = generatePredictionInput(Seq(-2.0, -1.0, 0.5, 0.75, 1.0, 2.0, 9.0)) val predictions = model @@ -94,32 +86,38 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val ir = new IsotonicRegression() assert(ir.getLabelCol === "label") assert(ir.getFeaturesCol === "features") - assert(ir.getWeightCol === "weight") assert(ir.getPredictionCol === "prediction") - assert(ir.getIsotonicParam === true) + assert(!ir.isDefined(ir.weightCol)) + assert(ir.getIsotonic) + assert(ir.getFeatureIndex === 0) val model = ir.fit(dataset) + + // copied model must have the same parent. 
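    // (Sketch of the shared utility's contract, not its exact code: checkCopy copies the model
    // with an empty ParamMap and asserts that the copy still reports the fitting estimator as
    // its parent.)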
+ MLTestingUtils.checkCopy(model) + model.transform(dataset) .select("label", "features", "prediction", "weight") .collect() assert(model.getLabelCol === "label") assert(model.getFeaturesCol === "features") - assert(model.getWeightCol === "weight") assert(model.getPredictionCol === "prediction") - assert(model.getIsotonicParam === true) + assert(!model.isDefined(model.weightCol)) + assert(model.getIsotonic) + assert(model.getFeatureIndex === 0) assert(model.hasParent) } test("set parameters") { val isotonicRegression = new IsotonicRegression() - .setIsotonicParam(false) - .setWeightParam("w") + .setIsotonic(false) + .setWeightCol("w") .setFeaturesCol("f") .setLabelCol("l") .setPredictionCol("p") - assert(isotonicRegression.getIsotonicParam === false) + assert(!isotonicRegression.getIsotonic) assert(isotonicRegression.getWeightCol === "w") assert(isotonicRegression.getFeaturesCol === "f") assert(isotonicRegression.getLabelCol === "l") @@ -130,7 +128,7 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val dataset = generateIsotonicInput(Seq(1, 2, 3)) intercept[IllegalArgumentException] { - new IsotonicRegression().setWeightParam("w").fit(dataset) + new IsotonicRegression().setWeightCol("w").fit(dataset) } intercept[IllegalArgumentException] { @@ -145,4 +143,55 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { new IsotonicRegression().fit(dataset).setFeaturesCol("f").transform(dataset) } } + + test("vector features column with feature index") { + val dataset = sqlContext.createDataFrame(Seq( + (4.0, Vectors.dense(0.0, 1.0)), + (3.0, Vectors.dense(0.0, 2.0)), + (5.0, Vectors.sparse(2, Array(1), Array(3.0)))) + ).toDF("label", "features") + + val ir = new IsotonicRegression() + .setFeatureIndex(1) + + val model = ir.fit(dataset) + + val features = generatePredictionInput(Seq(2.0, 3.0, 4.0, 5.0)) + + val predictions = model + .transform(features) + .select("prediction").map { + case Row(pred) => pred + }.collect() + + assert(predictions === Array(3.5, 5.0, 5.0, 5.0)) + } + + test("read/write") { + val dataset = generateIsotonicInput(Seq(1, 2, 3, 1, 6, 17, 16, 17, 18)) + + def checkModelData(model: IsotonicRegressionModel, model2: IsotonicRegressionModel): Unit = { + assert(model.boundaries === model2.boundaries) + assert(model.predictions === model2.predictions) + assert(model.isotonic === model2.isotonic) + } + + val ir = new IsotonicRegression() + testEstimatorAndModelReadWrite(ir, dataset, IsotonicRegressionSuite.allParamSettings, + checkModelData) + } +} + +object IsotonicRegressionSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. 
+ */ + val allParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPrediction", + "isotonic" -> true, + "featureIndex" -> 0 + ) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 7cdda3db88ad..2f3e703f4c25 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -17,22 +17,32 @@ package org.apache.spark.ml.regression +import scala.util.Random + import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.mllib.linalg.{DenseVector, Vectors} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.linalg.{Vector, DenseVector, Vectors} import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} -class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { +class LinearRegressionSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { - @transient var dataset: DataFrame = _ - @transient var datasetWithoutIntercept: DataFrame = _ + private val seed: Int = 42 + @transient var datasetWithDenseFeature: DataFrame = _ + @transient var datasetWithDenseFeatureWithoutIntercept: DataFrame = _ + @transient var datasetWithSparseFeature: DataFrame = _ + @transient var datasetWithWeight: DataFrame = _ /* In `LinearRegressionSuite`, we will make sure that the model trained by SparkML is the same as the one trained by R's glmnet package. The following instruction describes how to reproduce the data in R. + In a spark-shell, use the following code: import org.apache.spark.mllib.util.LinearDataGenerator val data = @@ -43,17 +53,45 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { */ override def beforeAll(): Unit = { super.beforeAll() - dataset = sqlContext.createDataFrame( + datasetWithDenseFeature = sqlContext.createDataFrame( sc.parallelize(LinearDataGenerator.generateLinearInput( - 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2)) + intercept = 6.3, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, eps = 0.1), 2)) /* datasetWithoutIntercept is not needed for correctness testing but is useful for illustrating training model without intercept */ - datasetWithoutIntercept = sqlContext.createDataFrame( + datasetWithDenseFeatureWithoutIntercept = sqlContext.createDataFrame( + sc.parallelize(LinearDataGenerator.generateLinearInput( + intercept = 0.0, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, eps = 0.1), 2)) + + val r = new Random(seed) + // When feature size is larger than 4096, normal optimizer is choosed + // as the solver of linear regression in the case of "auto" mode. 
+ val featureSize = 4100 + datasetWithSparseFeature = sqlContext.createDataFrame( sc.parallelize(LinearDataGenerator.generateLinearInput( - 0.0, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2)) + intercept = 0.0, weights = Seq.fill(featureSize)(r.nextDouble).toArray, + xMean = Seq.fill(featureSize)(r.nextDouble).toArray, + xVariance = Seq.fill(featureSize)(r.nextDouble).toArray, nPoints = 200, + seed, eps = 0.1, sparsity = 0.7), 2)) + + /* + R code: + A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) + b <- c(17, 19, 23, 29) + w <- c(1, 2, 3, 4) + df <- as.data.frame(cbind(A, b)) + */ + datasetWithWeight = sqlContext.createDataFrame( + sc.parallelize(Seq( + Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)), + Instance(29.0, 4.0, Vectors.dense(3.0, 13.0)) + ), 2)) } test("params") { @@ -70,307 +108,808 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(lir.getRegParam === 0.0) assert(lir.getElasticNetParam === 0.0) assert(lir.getFitIntercept) - val model = lir.fit(dataset) - model.transform(dataset) + assert(lir.getStandardization) + assert(lir.getSolver == "auto") + val model = lir.fit(datasetWithDenseFeature) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + + model.transform(datasetWithDenseFeature) .select("label", "prediction") .collect() assert(model.getFeaturesCol === "features") assert(model.getPredictionCol === "prediction") assert(model.intercept !== 0.0) assert(model.hasParent) + val numFeatures = datasetWithDenseFeature.select("features").first().getAs[Vector](0).size + assert(model.numFeatures === numFeatures) } test("linear regression with intercept without regularization") { - val trainer = new LinearRegression - val model = trainer.fit(dataset) - - /* - Using the following R code to load the data and train the model using glmnet package. - - library("glmnet") - data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) - features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3))) - label <- as.numeric(data$V1) - weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0)) - > weights - 3 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) 6.300528 - as.numeric.data.V2. 4.701024 - as.numeric.data.V3. 7.198257 - */ - val interceptR = 6.298698 - val weightsR = Vectors.dense(4.700706, 7.199082) - - assert(model.intercept ~== interceptR relTol 1E-3) - assert(model.weights ~= weightsR relTol 1E-3) - - model.transform(dataset).select("features", "prediction").collect().foreach { - case Row(features: DenseVector, prediction1: Double) => - val prediction2 = - features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept - assert(prediction1 ~== prediction2 relTol 1E-5) + Seq("auto", "l-bfgs", "normal").foreach { solver => + val trainer1 = new LinearRegression().setSolver(solver) + // The result should be the same regardless of standardization without regularization + val trainer2 = (new LinearRegression).setStandardization(false).setSolver(solver) + val model1 = trainer1.fit(datasetWithDenseFeature) + val model2 = trainer2.fit(datasetWithDenseFeature) + + /* + Using the following R code to load the data and train the model using glmnet package. 
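     (Why trainer1 and trainer2 above are expected to agree -- note, not part of this patch:
      with regParam = 0 the penalty term vanishes, so both solve the plain least-squares
      problem, whose solution is invariant to whether features are standardized internally
      before solving; standardization only changes the optimum once an L1/L2 penalty is
      applied in the scaled space.)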
+ + library("glmnet") + data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) + features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3))) + label <- as.numeric(data$V1) + coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0)) + > coefficients + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 6.298698 + as.numeric.data.V2. 4.700706 + as.numeric.data.V3. 7.199082 + */ + val interceptR = 6.298698 + val coefficientsR = Vectors.dense(4.700706, 7.199082) + + assert(model1.intercept ~== interceptR relTol 1E-3) + assert(model1.coefficients ~= coefficientsR relTol 1E-3) + assert(model2.intercept ~== interceptR relTol 1E-3) + assert(model2.coefficients ~= coefficientsR relTol 1E-3) + + model1.transform(datasetWithDenseFeature).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) + + model1.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } } } test("linear regression without intercept without regularization") { - val trainer = (new LinearRegression).setFitIntercept(false) - val model = trainer.fit(dataset) - val modelWithoutIntercept = trainer.fit(datasetWithoutIntercept) - - /* - weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0, - intercept = FALSE)) - > weights - 3 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) . - as.numeric.data.V2. 6.995908 - as.numeric.data.V3. 5.275131 - */ - val weightsR = Vectors.dense(6.995908, 5.275131) - - assert(model.intercept ~== 0 absTol 1E-3) - assert(model.weights ~= weightsR relTol 1E-3) - /* - Then again with the data with no intercept: - > weightsWithoutIntercept - 3 x 1 sparse Matrix of class "dgCMatrix" + Seq("auto", "l-bfgs", "normal").foreach { solver => + val trainer1 = (new LinearRegression).setFitIntercept(false).setSolver(solver) + // Without regularization the results should be the same + val trainer2 = (new LinearRegression).setFitIntercept(false).setStandardization(false) + .setSolver(solver) + val model1 = trainer1.fit(datasetWithDenseFeature) + val modelWithoutIntercept1 = trainer1.fit(datasetWithDenseFeatureWithoutIntercept) + val model2 = trainer2.fit(datasetWithDenseFeature) + val modelWithoutIntercept2 = trainer2.fit(datasetWithDenseFeatureWithoutIntercept) + + /* + coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0, + intercept = FALSE)) + > coefficients + 3 x 1 sparse Matrix of class "dgCMatrix" s0 - (Intercept) . - as.numeric.data3.V2. 4.70011 - as.numeric.data3.V3. 7.19943 - */ - val weightsWithoutInterceptR = Vectors.dense(4.70011, 7.19943) - - assert(modelWithoutIntercept.intercept ~== 0 absTol 1E-3) - assert(modelWithoutIntercept.weights ~= weightsWithoutInterceptR relTol 1E-3) + (Intercept) . + as.numeric.data.V2. 6.973403 + as.numeric.data.V3. 5.284370 + */ + val coefficientsR = Vectors.dense(6.973403, 5.284370) + + assert(model1.intercept ~== 0 absTol 1E-2) + assert(model1.coefficients ~= coefficientsR relTol 1E-2) + assert(model2.intercept ~== 0 absTol 1E-2) + assert(model2.coefficients ~= coefficientsR relTol 1E-2) + + /* + Then again with the data with no intercept: + > coefficientsWithourIntercept + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data3.V2. 4.70011 + as.numeric.data3.V3. 
7.19943 + */ + val coefficientsWithourInterceptR = Vectors.dense(4.70011, 7.19943) + + assert(modelWithoutIntercept1.intercept ~== 0 absTol 1E-3) + assert(modelWithoutIntercept1.coefficients ~= coefficientsWithourInterceptR relTol 1E-3) + assert(modelWithoutIntercept2.intercept ~== 0 absTol 1E-3) + assert(modelWithoutIntercept2.coefficients ~= coefficientsWithourInterceptR relTol 1E-3) + } } test("linear regression with intercept with L1 regularization") { - val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57) - val model = trainer.fit(dataset) - - /* - weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57)) - > weights - 3 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) 6.24300 - as.numeric.data.V2. 4.024821 - as.numeric.data.V3. 6.679841 - */ - val interceptR = 6.24300 - val weightsR = Vectors.dense(4.024821, 6.679841) - - assert(model.intercept ~== interceptR relTol 1E-3) - assert(model.weights ~= weightsR relTol 1E-3) - - model.transform(dataset).select("features", "prediction").collect().foreach { - case Row(features: DenseVector, prediction1: Double) => - val prediction2 = - features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept - assert(prediction1 ~== prediction2 relTol 1E-5) + Seq("auto", "l-bfgs", "normal").foreach { solver => + val trainer1 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57) + .setSolver(solver) + val trainer2 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57) + .setSolver(solver).setStandardization(false) + + // Normal optimizer is not supported with only L1 regularization case. + if (solver == "normal") { + intercept[IllegalArgumentException] { + trainer1.fit(datasetWithDenseFeature) + trainer2.fit(datasetWithDenseFeature) + } + } else { + val model1 = trainer1.fit(datasetWithDenseFeature) + val model2 = trainer2.fit(datasetWithDenseFeature) + + /* + coefficients <- coef(glmnet(features, label, family="gaussian", + alpha = 1.0, lambda = 0.57 )) + > coefficients + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 6.242284 + as.numeric.d1.V2. 4.019605 + as.numeric.d1.V3. 6.679538 + */ + val interceptR1 = 6.242284 + val coefficientsR1 = Vectors.dense(4.019605, 6.679538) + assert(model1.intercept ~== interceptR1 relTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) + + /* + coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, + lambda = 0.57, standardize=FALSE )) + > coefficients + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 6.416948 + as.numeric.data.V2. 3.893869 + as.numeric.data.V3. 
6.724286 + */ + val interceptR2 = 6.416948 + val coefficientsR2 = Vectors.dense(3.893869, 6.724286) + + assert(model2.intercept ~== interceptR2 relTol 1E-3) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-3) + + model1.transform(datasetWithDenseFeature).select("features", "prediction") + .collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) + + model1.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } } } test("linear regression without intercept with L1 regularization") { - val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57) - .setFitIntercept(false) - val model = trainer.fit(dataset) - - /* - weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57, - intercept=FALSE)) - > weights - 3 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) . - as.numeric.data.V2. 6.299752 - as.numeric.data.V3. 4.772913 - */ - val interceptR = 0.0 - val weightsR = Vectors.dense(6.299752, 4.772913) - - assert(model.intercept ~== interceptR absTol 1E-5) - assert(model.weights ~= weightsR relTol 1E-3) - - model.transform(dataset).select("features", "prediction").collect().foreach { - case Row(features: DenseVector, prediction1: Double) => - val prediction2 = - features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept - assert(prediction1 ~== prediction2 relTol 1E-5) + Seq("auto", "l-bfgs", "normal").foreach { solver => + val trainer1 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57) + .setFitIntercept(false).setSolver(solver) + val trainer2 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57) + .setFitIntercept(false).setStandardization(false).setSolver(solver) + + // Normal optimizer is not supported with only L1 regularization case. + if (solver == "normal") { + intercept[IllegalArgumentException] { + trainer1.fit(datasetWithDenseFeature) + trainer2.fit(datasetWithDenseFeature) + } + } else { + val model1 = trainer1.fit(datasetWithDenseFeature) + val model2 = trainer2.fit(datasetWithDenseFeature) + + /* + coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, + lambda = 0.57, intercept=FALSE )) + > coefficients + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data.V2. 6.272927 + as.numeric.data.V3. 4.782604 + */ + val interceptR1 = 0.0 + val coefficientsR1 = Vectors.dense(6.272927, 4.782604) + + assert(model1.intercept ~== interceptR1 absTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) + + /* + coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, + lambda = 0.57, intercept=FALSE, standardize=FALSE )) + > coefficients + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data.V2. 6.207817 + as.numeric.data.V3. 
4.775780 + */ + val interceptR2 = 0.0 + val coefficientsR2 = Vectors.dense(6.207817, 4.775780) + + assert(model2.intercept ~== interceptR2 absTol 1E-2) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) + + model1.transform(datasetWithDenseFeature).select("features", "prediction") + .collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) + + model1.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } } } test("linear regression with intercept with L2 regularization") { - val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3) - val model = trainer.fit(dataset) - - /* - weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3)) - > weights - 3 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) 6.328062 - as.numeric.data.V2. 3.222034 - as.numeric.data.V3. 4.926260 - */ - val interceptR = 5.269376 - val weightsR = Vectors.dense(3.736216, 5.712356) + Seq("auto", "l-bfgs", "normal").foreach { solver => + val trainer1 = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3) + .setSolver(solver) + val trainer2 = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3) + .setStandardization(false).setSolver(solver) + val model1 = trainer1.fit(datasetWithDenseFeature) + val model2 = trainer2.fit(datasetWithDenseFeature) + + /* + coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3)) + > coefficients + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 5.260103 + as.numeric.d1.V2. 3.725522 + as.numeric.d1.V3. 5.711203 + */ + val interceptR1 = 5.260103 + val coefficientsR1 = Vectors.dense(3.725522, 5.711203) + + assert(model1.intercept ~== interceptR1 relTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) + + /* + coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3, + standardize=FALSE)) + > coefficients + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 5.790885 + as.numeric.d1.V2. 3.432373 + as.numeric.d1.V3. 5.919196 + */ + val interceptR2 = 5.790885 + val coefficientsR2 = Vectors.dense(3.432373, 5.919196) + + assert(model2.intercept ~== interceptR2 relTol 1E-2) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) + + model1.transform(datasetWithDenseFeature).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) + + model1.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } + } - assert(model.intercept ~== interceptR relTol 1E-3) - assert(model.weights ~= weightsR relTol 1E-3) + test("linear regression without intercept with L2 regularization") { + Seq("auto", "l-bfgs", "normal").foreach { solver => + val trainer1 = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3) + .setFitIntercept(false).setSolver(solver) + val trainer2 = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3) + .setFitIntercept(false).setStandardization(false).setSolver(solver) + val model1 = trainer1.fit(datasetWithDenseFeature) + val model2 = trainer2.fit(datasetWithDenseFeature) + + /* + coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3, + intercept = FALSE)) + > coefficients + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.d1.V2. 
5.493430 + as.numeric.d1.V3. 4.223082 + */ + val interceptR1 = 0.0 + val coefficientsR1 = Vectors.dense(5.493430, 4.223082) + + assert(model1.intercept ~== interceptR1 absTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) + + /* + coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3, + intercept = FALSE, standardize=FALSE)) + > coefficients + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.d1.V2. 5.244324 + as.numeric.d1.V3. 4.203106 + */ + val interceptR2 = 0.0 + val coefficientsR2 = Vectors.dense(5.244324, 4.203106) + + assert(model2.intercept ~== interceptR2 absTol 1E-2) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) + + model1.transform(datasetWithDenseFeature).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) + + model1.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } + } - model.transform(dataset).select("features", "prediction").collect().foreach { - case Row(features: DenseVector, prediction1: Double) => - val prediction2 = - features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept - assert(prediction1 ~== prediction2 relTol 1E-5) + test("linear regression with intercept with ElasticNet regularization") { + Seq("auto", "l-bfgs", "normal").foreach { solver => + val trainer1 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6) + .setSolver(solver) + val trainer2 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6) + .setStandardization(false).setSolver(solver) + + // Normal optimizer is not supported with non-zero elasticnet parameter. + if (solver == "normal") { + intercept[IllegalArgumentException] { + trainer1.fit(datasetWithDenseFeature) + trainer2.fit(datasetWithDenseFeature) + } + } else { + val model1 = trainer1.fit(datasetWithDenseFeature) + val model2 = trainer2.fit(datasetWithDenseFeature) + + /* + coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, + lambda = 1.6 )) + > coefficients + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 5.689855 + as.numeric.d1.V2. 3.661181 + as.numeric.d1.V3. 6.000274 + */ + val interceptR1 = 5.689855 + val coefficientsR1 = Vectors.dense(3.661181, 6.000274) + + assert(model1.intercept ~== interceptR1 relTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) + + /* + coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6 + standardize=FALSE)) + > coefficients + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 6.113890 + as.numeric.d1.V2. 3.407021 + as.numeric.d1.V3. 
6.152512 + */ + val interceptR2 = 6.113890 + val coefficientsR2 = Vectors.dense(3.407021, 6.152512) + + assert(model2.intercept ~== interceptR2 relTol 1E-2) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) + + model1.transform(datasetWithDenseFeature).select("features", "prediction") + .collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) + + model1.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } } } - test("linear regression without intercept with L2 regularization") { - val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3) - .setFitIntercept(false) - val model = trainer.fit(dataset) + test("linear regression without intercept with ElasticNet regularization") { + Seq("auto", "l-bfgs", "normal").foreach { solver => + val trainer1 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6) + .setFitIntercept(false).setSolver(solver) + val trainer2 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6) + .setFitIntercept(false).setStandardization(false).setSolver(solver) + + // Normal optimizer is not supported with non-zero elasticnet parameter. + if (solver == "normal") { + intercept[IllegalArgumentException] { + trainer1.fit(datasetWithDenseFeature) + trainer2.fit(datasetWithDenseFeature) + } + } else { + val model1 = trainer1.fit(datasetWithDenseFeature) + val model2 = trainer2.fit(datasetWithDenseFeature) + + /* + coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, + lambda = 1.6, intercept=FALSE )) + > coefficients + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.d1.V2. 5.643748 + as.numeric.d1.V3. 4.331519 + */ + val interceptR1 = 0.0 + val coefficientsR1 = Vectors.dense(5.643748, 4.331519) + + assert(model1.intercept ~== interceptR1 absTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) + + /* + coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, + lambda = 1.6, intercept=FALSE, standardize=FALSE )) + > coefficients + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.d1.V2. 5.455902 + as.numeric.d1.V3. 4.312266 + + */ + val interceptR2 = 0.0 + val coefficientsR2 = Vectors.dense(5.455902, 4.312266) + + assert(model2.intercept ~== interceptR2 absTol 1E-2) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) + + model1.transform(datasetWithDenseFeature).select("features", "prediction") + .collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) + + model1.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } + } + } - /* - weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3, - intercept = FALSE)) - > weights - 3 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) . - as.numeric.data.V2. 5.522875 - as.numeric.data.V3. 
4.214502 - */ - val interceptR = 0.0 - val weightsR = Vectors.dense(5.522875, 4.214502) + test("linear regression model training summary") { + Seq("auto", "l-bfgs", "normal").foreach { solver => + val trainer = new LinearRegression().setSolver(solver) + val model = trainer.fit(datasetWithDenseFeature) + val trainerNoPredictionCol = trainer.setPredictionCol("") + val modelNoPredictionCol = trainerNoPredictionCol.fit(datasetWithDenseFeature) + + // Training results for the model should be available + assert(model.hasSummary) + assert(modelNoPredictionCol.hasSummary) + + // Schema should be a superset of the input dataset + assert((datasetWithDenseFeature.schema.fieldNames.toSet + "prediction").subsetOf( + model.summary.predictions.schema.fieldNames.toSet)) + // Validate that we re-insert a prediction column for evaluation + val modelNoPredictionColFieldNames + = modelNoPredictionCol.summary.predictions.schema.fieldNames + assert((datasetWithDenseFeature.schema.fieldNames.toSet).subsetOf( + modelNoPredictionColFieldNames.toSet)) + assert(modelNoPredictionColFieldNames.exists(s => s.startsWith("prediction_"))) + + // Residuals in [[LinearRegressionResults]] should equal those manually computed + val expectedResiduals = datasetWithDenseFeature.select("features", "label") + .map { case Row(features: DenseVector, label: Double) => + val prediction = + features(0) * model.coefficients(0) + features(1) * model.coefficients(1) + + model.intercept + label - prediction + } + .zip(model.summary.residuals.map(_.getDouble(0))) + .collect() + .foreach { case (manualResidual: Double, resultResidual: Double) => + assert(manualResidual ~== resultResidual relTol 1E-5) + } + + /* + # Use the following R code to generate model training results. + + # path/part-00000 is the file generated by running LinearDataGenerator.generateLinearInput + # as described before the beforeAll() method. + d1 <- read.csv("path/part-00000", header=FALSE, stringsAsFactors=FALSE) + fit <- glm(V1 ~ V2 + V3, data = d1, family = "gaussian") + names(f1)[1] = c("V2") + names(f1)[2] = c("V3") + f1 <- data.frame(as.numeric(d1$V2), as.numeric(d1$V3)) + predictions <- predict(fit, newdata=f1) + l1 <- as.numeric(d1$V1) + + residuals <- l1 - predictions + > mean(residuals^2) # MSE + [1] 0.00985449 + > mean(abs(residuals)) # MAD + [1] 0.07961668 + > cor(predictions, l1)^2 # r^2 + [1] 0.9998737 + + > summary(fit) + + Call: + glm(formula = V1 ~ V2 + V3, family = "gaussian", data = d1) + + Deviance Residuals: + Min 1Q Median 3Q Max + -0.47082 -0.06797 0.00002 0.06725 0.34635 + + Coefficients: + Estimate Std. Error t value Pr(>|t|) + (Intercept) 6.3022157 0.0018600 3388 <2e-16 *** + V2 4.6982442 0.0011805 3980 <2e-16 *** + V3 7.1994344 0.0009044 7961 <2e-16 *** + --- + + .... + */ + assert(model.summary.meanSquaredError ~== 0.00985449 relTol 1E-4) + assert(model.summary.meanAbsoluteError ~== 0.07961668 relTol 1E-4) + assert(model.summary.r2 ~== 0.9998737 relTol 1E-4) + + // Normal solver uses "WeightedLeastSquares". This algorithm does not generate + // objective history because it does not run through iterations. + if (solver == "l-bfgs") { + // Objective function should be monotonically decreasing for linear regression + assert( + model.summary + .objectiveHistory + .sliding(2) + .forall(x => x(0) >= x(1))) + } else { + // To clalify that the normal solver is used here. 
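          // (The "normal" branch solves the least-squares problem in closed form via
          // WeightedLeastSquares, so there is no iteration history to record: the summary
          // carries a single placeholder objective value of 0.0 and instead exposes the
          // R-style statistics checked below -- deviance residuals, coefficient standard
          // errors, t-values and p-values.)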
+ assert(model.summary.objectiveHistory.length == 1) + assert(model.summary.objectiveHistory(0) == 0.0) + val devianceResidualsR = Array(-0.47082, 0.34635) + val seCoefR = Array(0.0011805, 0.0009044, 0.0018600) + val tValsR = Array(3980, 7961, 3388) + val pValsR = Array(0, 0, 0) + model.summary.devianceResiduals.zip(devianceResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-4) } + model.summary.coefficientStandardErrors.zip(seCoefR).foreach{ x => + assert(x._1 ~== x._2 absTol 1E-4) } + model.summary.tValues.map(_.round).zip(tValsR).foreach{ x => assert(x._1 === x._2) } + model.summary.pValues.map(_.round).zip(pValsR).foreach{ x => assert(x._1 === x._2) } + } + } + } - assert(model.intercept ~== interceptR absTol 1E-3) - assert(model.weights ~== weightsR relTol 1E-3) + test("linear regression model testset evaluation summary") { + Seq("auto", "l-bfgs", "normal").foreach { solver => + val trainer = new LinearRegression().setSolver(solver) + val model = trainer.fit(datasetWithDenseFeature) + + // Evaluating on training dataset should yield results summary equal to training summary + val testSummary = model.evaluate(datasetWithDenseFeature) + assert(model.summary.meanSquaredError ~== testSummary.meanSquaredError relTol 1E-5) + assert(model.summary.r2 ~== testSummary.r2 relTol 1E-5) + model.summary.residuals.select("residuals").collect() + .zip(testSummary.residuals.select("residuals").collect()) + .forall { case (Row(r1: Double), Row(r2: Double)) => r1 ~== r2 relTol 1E-5 } + } + } - model.transform(dataset).select("features", "prediction").collect().foreach { - case Row(features: DenseVector, prediction1: Double) => - val prediction2 = - features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept - assert(prediction1 ~== prediction2 relTol 1E-5) + test("linear regression with weighted samples") { + Seq("auto", "l-bfgs", "normal").foreach { solver => + val (data, weightedData) = { + val activeData = LinearDataGenerator.generateLinearInput( + 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 500, 1, 0.1) + + val rnd = new Random(8392) + val signedData = activeData.map { case p: LabeledPoint => + (rnd.nextGaussian() > 0.0, p) + } + + val data1 = signedData.flatMap { + case (true, p) => Iterator(p, p) + case (false, p) => Iterator(p) + } + + val weightedSignedData = signedData.flatMap { + case (true, LabeledPoint(label, features)) => + Iterator( + Instance(label, weight = 1.2, features), + Instance(label, weight = 0.8, features) + ) + case (false, LabeledPoint(label, features)) => + Iterator( + Instance(label, weight = 0.3, features), + Instance(label, weight = 0.1, features), + Instance(label, weight = 0.6, features) + ) + } + + val noiseData = LinearDataGenerator.generateLinearInput( + 2, Array(1, 3), Array(0.9, -1.3), Array(0.7, 1.2), 500, 1, 0.1) + val weightedNoiseData = noiseData.map { + case LabeledPoint(label, features) => Instance(label, weight = 0, features) + } + val data2 = weightedSignedData ++ weightedNoiseData + + (sqlContext.createDataFrame(sc.parallelize(data1, 4)), + sqlContext.createDataFrame(sc.parallelize(data2, 4))) + } + + val trainer1a = (new LinearRegression).setFitIntercept(true) + .setElasticNetParam(0.0).setRegParam(0.21).setStandardization(true).setSolver(solver) + val trainer1b = (new LinearRegression).setFitIntercept(true).setWeightCol("weight") + .setElasticNetParam(0.0).setRegParam(0.21).setStandardization(true).setSolver(solver) + + // Normal optimizer is not supported with non-zero elasticnet parameter. 
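      // (Here elasticNetParam is 0.0, so the "normal" solver is exercised as well. `data`
      // duplicates the points flagged true, while `weightedData` gives the same points
      // weights summing to the same totals (1.2 + 0.8 = 2.0 for duplicated points,
      // 0.3 + 0.1 + 0.6 = 1.0 for single points) plus zero-weight noise, so a correct
      // weighting implementation should produce matching fits -- which is what the
      // assertions below verify.)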
+ val model1a0 = trainer1a.fit(data) + val model1a1 = trainer1a.fit(weightedData) + val model1b = trainer1b.fit(weightedData) + + assert(model1a0.coefficients !~= model1a1.coefficients absTol 1E-3) + assert(model1a0.intercept !~= model1a1.intercept absTol 1E-3) + assert(model1a0.coefficients ~== model1b.coefficients absTol 1E-3) + assert(model1a0.intercept ~== model1b.intercept absTol 1E-3) + + val trainer2a = (new LinearRegression).setFitIntercept(true) + .setElasticNetParam(0.0).setRegParam(0.21).setStandardization(false).setSolver(solver) + val trainer2b = (new LinearRegression).setFitIntercept(true).setWeightCol("weight") + .setElasticNetParam(0.0).setRegParam(0.21).setStandardization(false).setSolver(solver) + val model2a0 = trainer2a.fit(data) + val model2a1 = trainer2a.fit(weightedData) + val model2b = trainer2b.fit(weightedData) + assert(model2a0.coefficients !~= model2a1.coefficients absTol 1E-3) + assert(model2a0.intercept !~= model2a1.intercept absTol 1E-3) + assert(model2a0.coefficients ~== model2b.coefficients absTol 1E-3) + assert(model2a0.intercept ~== model2b.intercept absTol 1E-3) + + val trainer3a = (new LinearRegression).setFitIntercept(false) + .setElasticNetParam(0.0).setRegParam(0.21).setStandardization(true).setSolver(solver) + val trainer3b = (new LinearRegression).setFitIntercept(false).setWeightCol("weight") + .setElasticNetParam(0.0).setRegParam(0.21).setStandardization(true).setSolver(solver) + val model3a0 = trainer3a.fit(data) + val model3a1 = trainer3a.fit(weightedData) + val model3b = trainer3b.fit(weightedData) + assert(model3a0.coefficients !~= model3a1.coefficients absTol 1E-3) + assert(model3a0.coefficients ~== model3b.coefficients absTol 1E-3) + + val trainer4a = (new LinearRegression).setFitIntercept(false) + .setElasticNetParam(0.0).setRegParam(0.21).setStandardization(false).setSolver(solver) + val trainer4b = (new LinearRegression).setFitIntercept(false).setWeightCol("weight") + .setElasticNetParam(0.0).setRegParam(0.21).setStandardization(false).setSolver(solver) + val model4a0 = trainer4a.fit(data) + val model4a1 = trainer4a.fit(weightedData) + val model4b = trainer4b.fit(weightedData) + assert(model4a0.coefficients !~= model4a1.coefficients absTol 1E-3) + assert(model4a0.coefficients ~== model4b.coefficients absTol 1E-3) } } - test("linear regression with intercept with ElasticNet regularization") { - val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6) - val model = trainer.fit(dataset) + test("linear regression model with l-bfgs with big feature datasets") { + val trainer = new LinearRegression().setSolver("auto") + val model = trainer.fit(datasetWithSparseFeature) + // Training results for the model should be available + assert(model.hasSummary) + // When LBFGS is used as optimizer, objective history can be restored. + assert( + model.summary + .objectiveHistory + .sliding(2) + .forall(x => x(0) >= x(1))) + } + + test("linear regression summary with weighted samples and intercept by normal solver") { /* - weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6)) - > weights - 3 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) 6.324108 - as.numeric.data.V2. 3.168435 - as.numeric.data.V3. 
5.200403 - */ - val interceptR = 5.696056 - val weightsR = Vectors.dense(3.670489, 6.001122) + R code: - assert(model.intercept ~== interceptR relTol 1E-3) - assert(model.weights ~== weightsR relTol 1E-3) + model <- glm(formula = "b ~ .", data = df, weights = w) + summary(model) - model.transform(dataset).select("features", "prediction").collect().foreach { - case Row(features: DenseVector, prediction1: Double) => - val prediction2 = - features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept - assert(prediction1 ~== prediction2 relTol 1E-5) - } - } + Call: + glm(formula = "b ~ .", data = df, weights = w) - test("linear regression without intercept with ElasticNet regularization") { - val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6) - .setFitIntercept(false) - val model = trainer.fit(dataset) + Deviance Residuals: + 1 2 3 4 + 1.920 -1.358 -1.109 0.960 - /* - weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6, - intercept=FALSE)) - > weights - 3 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) . - as.numeric.dataM.V2. 5.673348 - as.numeric.dataM.V3. 4.322251 + Coefficients: + Estimate Std. Error t value Pr(>|t|) + (Intercept) 18.080 9.608 1.882 0.311 + V1 6.080 5.556 1.094 0.471 + V2 -0.600 1.960 -0.306 0.811 + + (Dispersion parameter for gaussian family taken to be 7.68) + + Null deviance: 202.00 on 3 degrees of freedom + Residual deviance: 7.68 on 1 degrees of freedom + AIC: 18.783 + + Number of Fisher Scoring iterations: 2 */ - val interceptR = 0.0 - val weightsR = Vectors.dense(5.673348, 4.322251) + val model = new LinearRegression() + .setWeightCol("weight") + .setSolver("normal") + .fit(datasetWithWeight) + val coefficientsR = Vectors.dense(Array(6.080, -0.600)) + val interceptR = 18.080 + val devianceResidualsR = Array(-1.358, 1.920) + val seCoefR = Array(5.556, 1.960, 9.608) + val tValsR = Array(1.094, -0.306, 1.882) + val pValsR = Array(0.471, 0.811, 0.311) + + assert(model.coefficients ~== coefficientsR absTol 1E-3) assert(model.intercept ~== interceptR absTol 1E-3) - assert(model.weights ~= weightsR relTol 1E-3) - - model.transform(dataset).select("features", "prediction").collect().foreach { - case Row(features: DenseVector, prediction1: Double) => - val prediction2 = - features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept - assert(prediction1 ~== prediction2 relTol 1E-5) - } + model.summary.devianceResiduals.zip(devianceResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + model.summary.coefficientStandardErrors.zip(seCoefR).foreach{ x => + assert(x._1 ~== x._2 absTol 1E-3) } + model.summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } + model.summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } } - test("linear regression model training summary") { - val trainer = new LinearRegression - val model = trainer.fit(dataset) + test("linear regression summary with weighted samples and w/o intercept by normal solver") { + /* + R code: - // Training results for the model should be available - assert(model.hasSummary) + model <- glm(formula = "b ~ . 
-1", data = df, weights = w) + summary(model) - // Residuals in [[LinearRegressionResults]] should equal those manually computed - val expectedResiduals = dataset.select("features", "label") - .map { case Row(features: DenseVector, label: Double) => - val prediction = - features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept - label - prediction - } - .zip(model.summary.residuals.map(_.getDouble(0))) - .collect() - .foreach { case (manualResidual: Double, resultResidual: Double) => - assert(manualResidual ~== resultResidual relTol 1E-5) - } + Call: + glm(formula = "b ~ . -1", data = df, weights = w) - /* - Use the following R code to generate model training results. - - predictions <- predict(fit, newx=features) - residuals <- label - predictions - > mean(residuals^2) # MSE - [1] 0.009720325 - > mean(abs(residuals)) # MAD - [1] 0.07863206 - > cor(predictions, label)^2# r^2 - [,1] - s0 0.9998749 + Deviance Residuals: + 1 2 3 4 + 1.950 2.344 -4.600 2.103 + + Coefficients: + Estimate Std. Error t value Pr(>|t|) + V1 -3.7271 2.9032 -1.284 0.3279 + V2 3.0100 0.6022 4.998 0.0378 * + --- + Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1 + + (Dispersion parameter for gaussian family taken to be 17.4376) + + Null deviance: 5962.000 on 4 degrees of freedom + Residual deviance: 34.875 on 2 degrees of freedom + AIC: 22.835 + + Number of Fisher Scoring iterations: 2 */ - assert(model.summary.meanSquaredError ~== 0.00972035 relTol 1E-5) - assert(model.summary.meanAbsoluteError ~== 0.07863206 relTol 1E-5) - assert(model.summary.r2 ~== 0.9998749 relTol 1E-5) - // Objective function should be monotonically decreasing for linear regression - assert( - model.summary - .objectiveHistory - .sliding(2) - .forall(x => x(0) >= x(1))) + val model = new LinearRegression() + .setWeightCol("weight") + .setSolver("normal") + .setFitIntercept(false) + .fit(datasetWithWeight) + val coefficientsR = Vectors.dense(Array(-3.7271, 3.0100)) + val interceptR = 0.0 + val devianceResidualsR = Array(-4.600, 2.344) + val seCoefR = Array(2.9032, 0.6022) + val tValsR = Array(-1.284, 4.998) + val pValsR = Array(0.3279, 0.0378) + + assert(model.coefficients ~== coefficientsR absTol 1E-3) + assert(model.intercept === interceptR) + model.summary.devianceResiduals.zip(devianceResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + model.summary.coefficientStandardErrors.zip(seCoefR).foreach{ x => + assert(x._1 ~== x._2 absTol 1E-3) } + model.summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } + model.summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } } - test("linear regression model testset evaluation summary") { - val trainer = new LinearRegression - val model = trainer.fit(dataset) - - // Evaluating on training dataset should yield results summary equal to training summary - val testSummary = model.evaluate(dataset) - assert(model.summary.meanSquaredError ~== testSummary.meanSquaredError relTol 1E-5) - assert(model.summary.r2 ~== testSummary.r2 relTol 1E-5) - model.summary.residuals.select("residuals").collect() - .zip(testSummary.residuals.select("residuals").collect()) - .forall { case (Row(r1: Double), Row(r2: Double)) => r1 ~== r2 relTol 1E-5 } + test("read/write") { + def checkModelData(model: LinearRegressionModel, model2: LinearRegressionModel): Unit = { + assert(model.intercept === model2.intercept) + assert(model.coefficients === model2.coefficients) + } + val lr = new LinearRegression() + testEstimatorAndModelReadWrite(lr, 
datasetWithWeight, LinearRegressionSuite.allParamSettings, + checkModelData) } +} + +object LinearRegressionSuite { + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPrediction", + "regParam" -> 0.01, + "elasticNetParam" -> 0.1, + "maxIter" -> 2, // intentionally small + "fitIntercept" -> true, + "tol" -> 0.8, + "standardization" -> false, + "solver" -> "l-bfgs" + ) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index b24ecaa57c89..7e751e4b553b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} @@ -26,7 +28,6 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame - /** * Test suite for [[RandomForestRegressor]]. */ @@ -71,6 +72,35 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex regressionTestWithContinuousFeatures(rf) } + test("Feature importance with toy data") { + val rf = new RandomForestRegressor() + .setImpurity("variance") + .setMaxDepth(3) + .setNumTrees(3) + .setFeatureSubsetStrategy("all") + .setSubsamplingRate(1.0) + .setSeed(123) + + // In this data, feature 1 is very important. + val data: RDD[LabeledPoint] = sc.parallelize(Seq( + new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 1)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 1, 0)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)), + new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)) + )) + val categoricalFeatures = Map.empty[Int, Int] + val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0) + + val model = rf.fit(df) + + // copied model must have the same parent. 
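+    // (MLTestingUtils.checkCopy copies the model with an empty ParamMap and asserts the copy keeps the same parent; see MLTestingUtils later in this patch.)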
+ MLTestingUtils.checkCopy(model) + val importances = model.featureImportances + val mostImportantFeature = importances.argmax + assert(mostImportantFeature === 1) + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// @@ -107,6 +137,7 @@ private object RandomForestRegressorSuite extends SparkFunSuite { data: RDD[LabeledPoint], rf: RandomForestRegressor, categoricalFeatures: Map[Int, Int]): Unit = { + val numFeatures = data.first().features.size val oldStrategy = rf.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, rf.getOldImpurity) val oldModel = OldRandomForest.trainRegressor( @@ -117,5 +148,6 @@ private object RandomForestRegressorSuite extends SparkFunSuite { val oldModelAsNew = RandomForestRegressionModel.fromOld( oldModel, newModel.parent.asInstanceOf[RandomForestRegressor], categoricalFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) + assert(newModel.numFeatures === numFeatures) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala new file mode 100644 index 000000000000..997f574e51f6 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.source.libsvm + +import java.io.File + +import com.google.common.base.Charsets +import com.google.common.io.Files + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.Utils + +class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { + var tempDir: File = _ + var path: String = _ + + override def beforeAll(): Unit = { + super.beforeAll() + val lines = + """ + |1 1:1.0 3:2.0 5:3.0 + |0 + |0 2:4.0 4:5.0 6:6.0 + """.stripMargin + tempDir = Utils.createTempDir() + val file = new File(tempDir, "part-00000") + Files.write(lines, file, Charsets.US_ASCII) + path = tempDir.toURI.toString + } + + override def afterAll(): Unit = { + Utils.deleteRecursively(tempDir) + super.afterAll() + } + + test("select as sparse vector") { + val df = sqlContext.read.format("libsvm").load(path) + assert(df.columns(0) == "label") + assert(df.columns(1) == "features") + val row1 = df.first() + assert(row1.getDouble(0) == 1.0) + val v = row1.getAs[SparseVector](1) + assert(v == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0)))) + } + + test("select as dense vector") { + val df = sqlContext.read.format("libsvm").options(Map("vectorType" -> "dense")) + .load(path) + assert(df.columns(0) == "label") + assert(df.columns(1) == "features") + assert(df.count() == 3) + val row1 = df.first() + assert(row1.getDouble(0) == 1.0) + val v = row1.getAs[DenseVector](1) + assert(v == Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0)) + } + + test("select a vector with specifying the longer dimension") { + val df = sqlContext.read.option("numFeatures", "100").format("libsvm") + .load(path) + val row1 = df.first() + val v = row1.getAs[SparseVector](1) + assert(v == Vectors.sparse(100, Seq((0, 1.0), (2, 2.0), (4, 3.0)))) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala new file mode 100644 index 000000000000..d5c238e9ae16 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.classification.DecisionTreeClassificationModel +import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.ml.tree.{ContinuousSplit, DecisionTreeModel, LeafNode, Node} +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.tree.impurity.GiniCalculator +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.collection.OpenHashMap + +/** + * Test suite for [[RandomForest]]. + */ +class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { + + import RandomForestSuite.mapToVec + + test("computeFeatureImportance, featureImportances") { + /* Build tree for testing, with this structure: + grandParent + left2 parent + left right + */ + val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0)) + val left = new LeafNode(0.0, leftImp.calculate(), leftImp) + + val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0)) + val right = new LeafNode(2.0, rightImp.calculate(), rightImp) + + val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5)) + val parentImp = parent.impurityStats + + val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0)) + val left2 = new LeafNode(0.0, left2Imp.calculate(), left2Imp) + + val grandParent = TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0)) + val grandImp = grandParent.impurityStats + + // Test feature importance computed at different subtrees. + def testNode(node: Node, expected: Map[Int, Double]): Unit = { + val map = new OpenHashMap[Int, Double]() + RandomForest.computeFeatureImportance(node, map) + assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01) + } + + // Leaf node + testNode(left, Map.empty[Int, Double]) + + // Internal node with 2 leaf children + val feature0importance = parentImp.calculate() * parentImp.count - + (leftImp.calculate() * leftImp.count + rightImp.calculate() * rightImp.count) + testNode(parent, Map(0 -> feature0importance)) + + // Full tree + val feature1importance = grandImp.calculate() * grandImp.count - + (left2Imp.calculate() * left2Imp.count + parentImp.calculate() * parentImp.count) + testNode(grandParent, Map(0 -> feature0importance, 1 -> feature1importance)) + + // Forest consisting of (full tree) + (internal node with 2 leafs) + val trees = Array(parent, grandParent).map { root => + new DecisionTreeClassificationModel(root, numFeatures = 2, numClasses = 3) + .asInstanceOf[DecisionTreeModel] + } + val importances: Vector = RandomForest.featureImportances(trees, 2) + val tree2norm = feature0importance + feature1importance + val expected = Vectors.dense((1.0 + feature0importance / tree2norm) / 2.0, + (feature1importance / tree2norm) / 2.0) + assert(importances ~== expected relTol 0.01) + } + + test("normalizeMapValues") { + val map = new OpenHashMap[Int, Double]() + map(0) = 1.0 + map(2) = 2.0 + RandomForest.normalizeMapValues(map) + val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0) + assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01) + } + +} + +private object RandomForestSuite { + + def mapToVec(map: Map[Int, Double]): Vector = { + val size = (map.keys.toSeq :+ 0).max + 1 + val (indices, values) = map.toSeq.sortBy(_._1).unzip + Vectors.sparse(size, indices.toArray, values.toArray) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index db64511a7605..d281084f913c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -18,24 +18,27 @@ package org.apache.spark.ml.tuning import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.feature.HashingTF +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.{Pipeline, Estimator, Model} +import org.apache.spark.ml.classification.{LogisticRegressionModel, LogisticRegression} import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} -import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.{ParamPair, ParamMap} import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType -class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { +class CrossValidatorSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @transient var dataset: DataFrame = _ override def beforeAll(): Unit = { super.beforeAll() - val sqlContext = new SQLContext(sc) dataset = sqlContext.createDataFrame( sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)) } @@ -53,6 +56,10 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { .setEvaluator(eval) .setNumFolds(3) val cvModel = cv.fit(dataset) + + // copied model must have the same parent.
+ MLTestingUtils.checkCopy(cvModel) + val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression] assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10) @@ -64,7 +71,7 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { sc.parallelize(LinearDataGenerator.generateLinearInput( 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2)) - val trainer = new LinearRegression + val trainer = new LinearRegression().setSolver("l-bfgs") val lrParamMaps = new ParamGridBuilder() .addGrid(trainer.regParam, Array(1000.0, 0.001)) .addGrid(trainer.maxIter, Array(0, 10)) @@ -90,7 +97,7 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { } test("validateParams should check estimatorParamMaps") { - import CrossValidatorSuite._ + import CrossValidatorSuite.{MyEstimator, MyEvaluator} val est = new MyEstimator("est") val eval = new MyEvaluator @@ -111,9 +118,194 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { cv.validateParams() } } + + test("read/write: CrossValidator with simple estimator") { + val lr = new LogisticRegression().setMaxIter(3) + val evaluator = new BinaryClassificationEvaluator() + .setMetricName("areaUnderPR") // not default metric + val paramMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.2)) + .build() + val cv = new CrossValidator() + .setEstimator(lr) + .setEvaluator(evaluator) + .setNumFolds(20) + .setEstimatorParamMaps(paramMaps) + + val cv2 = testDefaultReadWrite(cv, testParams = false) + + assert(cv.uid === cv2.uid) + assert(cv.getNumFolds === cv2.getNumFolds) + + assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) + val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator] + assert(evaluator.uid === evaluator2.uid) + assert(evaluator.getMetricName === evaluator2.getMetricName) + + cv2.getEstimator match { + case lr2: LogisticRegression => + assert(lr.uid === lr2.uid) + assert(lr.getMaxIter === lr2.getMaxIter) + case other => + throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + + s" LogisticRegression but found ${other.getClass.getName}") + } + + CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps) + } + + test("read/write: CrossValidator with complex estimator") { + // workflow: CrossValidator[Pipeline[HashingTF, CrossValidator[LogisticRegression]]] + val lrEvaluator = new BinaryClassificationEvaluator() + .setMetricName("areaUnderPR") // not default metric + + val lr = new LogisticRegression().setMaxIter(3) + val lrParamMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.2)) + .build() + val lrcv = new CrossValidator() + .setEstimator(lr) + .setEvaluator(lrEvaluator) + .setEstimatorParamMaps(lrParamMaps) + + val hashingTF = new HashingTF() + val pipeline = new Pipeline().setStages(Array(hashingTF, lrcv)) + val paramMaps = new ParamGridBuilder() + .addGrid(hashingTF.numFeatures, Array(10, 20)) + .addGrid(lr.elasticNetParam, Array(0.0, 1.0)) + .build() + val evaluator = new BinaryClassificationEvaluator() + + val cv = new CrossValidator() + .setEstimator(pipeline) + .setEvaluator(evaluator) + .setNumFolds(20) + .setEstimatorParamMaps(paramMaps) + + val cv2 = testDefaultReadWrite(cv, testParams = false) + + assert(cv.uid === cv2.uid) + assert(cv.getNumFolds === cv2.getNumFolds) + + assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) + assert(cv.getEvaluator.uid === cv2.getEvaluator.uid) + + 
CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps) + + cv2.getEstimator match { + case pipeline2: Pipeline => + assert(pipeline.uid === pipeline2.uid) + pipeline2.getStages match { + case Array(hashingTF2: HashingTF, lrcv2: CrossValidator) => + assert(hashingTF.uid === hashingTF2.uid) + lrcv2.getEstimator match { + case lr2: LogisticRegression => + assert(lr.uid === lr2.uid) + assert(lr.getMaxIter === lr2.getMaxIter) + case other => + throw new AssertionError(s"Loaded internal CrossValidator expected to be" + + s" LogisticRegression but found type ${other.getClass.getName}") + } + assert(lrcv.uid === lrcv2.uid) + assert(lrcv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) + assert(lrEvaluator.uid === lrcv2.getEvaluator.uid) + CrossValidatorSuite.compareParamMaps(lrParamMaps, lrcv2.getEstimatorParamMaps) + case other => + throw new AssertionError("Loaded Pipeline expected stages (HashingTF, CrossValidator)" + + " but found: " + other.map(_.getClass.getName).mkString(", ")) + } + case other => + throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + + s" CrossValidator but found ${other.getClass.getName}") + } + } + + test("read/write: CrossValidator fails for extraneous Param") { + val lr = new LogisticRegression() + val lr2 = new LogisticRegression() + val evaluator = new BinaryClassificationEvaluator() + val paramMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.2)) + .addGrid(lr2.regParam, Array(0.1, 0.2)) + .build() + val cv = new CrossValidator() + .setEstimator(lr) + .setEvaluator(evaluator) + .setEstimatorParamMaps(paramMaps) + withClue("CrossValidator.write failed to catch extraneous Param error") { + intercept[IllegalArgumentException] { + cv.write + } + } + } + + test("read/write: CrossValidatorModel") { + val lr = new LogisticRegression() + .setThreshold(0.6) + val lrModel = new LogisticRegressionModel(lr.uid, Vectors.dense(1.0, 2.0), 1.2) + .setThreshold(0.6) + val evaluator = new BinaryClassificationEvaluator() + .setMetricName("areaUnderPR") // not default metric + val paramMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.2)) + .build() + val cv = new CrossValidatorModel("cvUid", lrModel, Array(0.3, 0.6)) + cv.set(cv.estimator, lr) + .set(cv.evaluator, evaluator) + .set(cv.numFolds, 20) + .set(cv.estimatorParamMaps, paramMaps) + + val cv2 = testDefaultReadWrite(cv, testParams = false) + + assert(cv.uid === cv2.uid) + assert(cv.getNumFolds === cv2.getNumFolds) + + assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) + val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator] + assert(evaluator.uid === evaluator2.uid) + assert(evaluator.getMetricName === evaluator2.getMetricName) + + cv2.getEstimator match { + case lr2: LogisticRegression => + assert(lr.uid === lr2.uid) + assert(lr.getThreshold === lr2.getThreshold) + case other => + throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + + s" LogisticRegression but found ${other.getClass.getName}") + } + + CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps) + + cv2.bestModel match { + case lrModel2: LogisticRegressionModel => + assert(lrModel.uid === lrModel2.uid) + assert(lrModel.getThreshold === lrModel2.getThreshold) + assert(lrModel.coefficients === lrModel2.coefficients) + assert(lrModel.intercept === lrModel2.intercept) + case other => + throw new AssertionError(s"Loaded CrossValidator expected bestModel of type" + + s" 
LogisticRegressionModel but found ${other.getClass.getName}") + } + assert(cv.avgMetrics === cv2.avgMetrics) + } } -object CrossValidatorSuite { +object CrossValidatorSuite extends SparkFunSuite { + + /** + * Assert sequences of estimatorParamMaps are identical. + * Params must be simple types comparable with `===`. + */ + def compareParamMaps(pMaps: Array[ParamMap], pMaps2: Array[ParamMap]): Unit = { + assert(pMaps.length === pMaps2.length) + pMaps.zip(pMaps2).foreach { case (pMap, pMap2) => + assert(pMap.size === pMap2.size) + pMap.toSeq.foreach { case ParamPair(p, v) => + assert(pMap2.contains(p)) + assert(pMap2(p) === v) + } + } + } abstract class MyModel extends Model[MyModel] @@ -138,6 +330,8 @@ object CrossValidatorSuite { throw new UnsupportedOperationException } + override def isLargerBetter: Boolean = true + override val uid: String = "eval" override def copy(extra: ParamMap): MyEvaluator = defaultCopy(extra) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index c8e58f216cce..5fb80091d0b4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -58,7 +58,7 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext sc.parallelize(LinearDataGenerator.generateLinearInput( 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2)) - val trainer = new LinearRegression + val trainer = new LinearRegression().setSolver("l-bfgs") val lrParamMaps = new ParamGridBuilder() .addGrid(trainer.regParam, Array(1000.0, 0.001)) .addGrid(trainer.maxIter, Array(0, 10)) @@ -132,6 +132,8 @@ object TrainValidationSplitSuite { throw new UnsupportedOperationException } + override def isLargerBetter: Boolean = true + override val uid: String = "eval" override def copy(extra: ParamMap): MyEvaluator = defaultCopy(extra) diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala new file mode 100644 index 000000000000..84d06b43d622 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.util + +import java.io.{File, IOException} + +import org.scalatest.Suite + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.{Model, Estimator} +import org.apache.spark.ml.param._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.DataFrame + +trait DefaultReadWriteTest extends TempDirectory { self: Suite => + + /** + * Checks "overwrite" option and params. + * This saves to and loads from [[tempDir]], but creates a subdirectory with a random name + * in order to avoid conflicts from multiple calls to this method. + * @param instance ML instance to test saving/loading + * @param testParams If true, then test values of Params. Otherwise, just test overwrite option. + * @tparam T ML instance type + * @return Instance loaded from file + */ + def testDefaultReadWrite[T <: Params with MLWritable]( + instance: T, + testParams: Boolean = true): T = { + val uid = instance.uid + val subdirName = Identifiable.randomUID("test") + + val subdir = new File(tempDir, subdirName) + val path = new File(subdir, uid).getPath + + instance.save(path) + intercept[IOException] { + instance.save(path) + } + instance.write.overwrite().save(path) + val loader = instance.getClass.getMethod("read").invoke(null).asInstanceOf[MLReader[T]] + val newInstance = loader.load(path) + + assert(newInstance.uid === instance.uid) + if (testParams) { + instance.params.foreach { p => + if (instance.isDefined(p)) { + (instance.getOrDefault(p), newInstance.getOrDefault(p)) match { + case (Array(values), Array(newValues)) => + assert(values === newValues, s"Values do not match on param ${p.name}.") + case (value, newValue) => + assert(value === newValue, s"Values do not match on param ${p.name}.") + } + } else { + assert(!newInstance.isDefined(p), s"Param ${p.name} shouldn't be defined.") + } + } + } + + val load = instance.getClass.getMethod("load", classOf[String]) + val another = load.invoke(instance, path).asInstanceOf[T] + assert(another.uid === instance.uid) + another + } + + /** + * Default test for Estimator, Model pairs: + * - Explicitly set Params, and train model + * - Test save/load using [[testDefaultReadWrite()]] on Estimator and Model + * - Check Params on Estimator and Model + * + * This requires that the [[Estimator]] and [[Model]] share the same set of [[Param]]s. + * @param estimator Estimator to test + * @param dataset Dataset to pass to [[Estimator.fit()]] + * @param testParams Set of [[Param]] values to set in estimator + * @param checkModelData Method which takes the original and loaded [[Model]] and compares their + * data. This method does not need to check [[Param]] values. + * @tparam E Type of [[Estimator]] + * @tparam M Type of [[Model]] produced by estimator + */ + def testEstimatorAndModelReadWrite[ + E <: Estimator[M] with MLWritable, M <: Model[M] with MLWritable]( + estimator: E, + dataset: DataFrame, + testParams: Map[String, Any], + checkModelData: (M, M) => Unit): Unit = { + // Set some Params to make sure set Params are serialized. 
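+    // testParams keys are Param names; getParam looks each Param up by name before setting its value.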
+ testParams.foreach { case (p, v) => + estimator.set(estimator.getParam(p), v) + } + val model = estimator.fit(dataset) + + // Test Estimator save/load + val estimator2 = testDefaultReadWrite(estimator) + testParams.foreach { case (p, v) => + val param = estimator.getParam(p) + assert(estimator.get(param).get === estimator2.get(param).get) + } + + // Test Model save/load + val model2 = testDefaultReadWrite(model) + testParams.foreach { case (p, v) => + val param = model.getParam(p) + assert(model.get(param).get === model2.get(param).get) + } + } +} + +class MyParams(override val uid: String) extends Params with MLWritable { + + final val intParamWithDefault: IntParam = new IntParam(this, "intParamWithDefault", "doc") + final val intParam: IntParam = new IntParam(this, "intParam", "doc") + final val floatParam: FloatParam = new FloatParam(this, "floatParam", "doc") + final val doubleParam: DoubleParam = new DoubleParam(this, "doubleParam", "doc") + final val longParam: LongParam = new LongParam(this, "longParam", "doc") + final val stringParam: Param[String] = new Param[String](this, "stringParam", "doc") + final val intArrayParam: IntArrayParam = new IntArrayParam(this, "intArrayParam", "doc") + final val doubleArrayParam: DoubleArrayParam = + new DoubleArrayParam(this, "doubleArrayParam", "doc") + final val stringArrayParam: StringArrayParam = + new StringArrayParam(this, "stringArrayParam", "doc") + + setDefault(intParamWithDefault -> 0) + set(intParam -> 1) + set(floatParam -> 2.0f) + set(doubleParam -> 3.0) + set(longParam -> 4L) + set(stringParam -> "5") + set(intArrayParam -> Array(6, 7)) + set(doubleArrayParam -> Array(8.0, 9.0)) + set(stringArrayParam -> Array("10", "11")) + + override def copy(extra: ParamMap): Params = defaultCopy(extra) + + override def write: MLWriter = new DefaultParamsWriter(this) +} + +object MyParams extends MLReadable[MyParams] { + + override def read: MLReader[MyParams] = new DefaultParamsReader[MyParams] + + override def load(path: String): MyParams = super.load(path) +} + +class DefaultReadWriteSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest { + + test("default read/write") { + val myParams = new MyParams("my_params") + testDefaultReadWrite(myParams) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala new file mode 100644 index 000000000000..d290cc9b06e7 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.util + +import org.apache.spark.ml.Model +import org.apache.spark.ml.param.ParamMap + +object MLTestingUtils { + def checkCopy(model: Model[_]): Unit = { + val copied = model.copy(ParamMap.empty) + .asInstanceOf[Model[_]] + assert(copied.parent.uid == model.parent.uid) + assert(copied.parent == model.parent) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala new file mode 100644 index 000000000000..c8a0bb16247b --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import java.io.File + +import org.scalatest.{BeforeAndAfterAll, Suite} + +import org.apache.spark.util.Utils + +/** + * Trait that creates a temporary directory before all tests and deletes it after all. + */ +trait TempDirectory extends BeforeAndAfterAll { self: Suite => + + private var _tempDir: File = _ + + /** Returns the temporary directory as a [[File]] instance. 
*/ + protected def tempDir: File = _tempDir + + override def beforeAll(): Unit = { + super.beforeAll() + _tempDir = Utils.createTempDir(namePrefix = this.getClass.getName) + } + + override def afterAll(): Unit = { + Utils.deleteRecursively(_tempDir) + super.afterAll() + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 2473510e1351..8d14bb657215 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.classification -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.util.Random import scala.util.control.Breaks._ @@ -38,7 +38,7 @@ object LogisticRegressionSuite { scale: Double, nPoints: Int, seed: Int): java.util.List[LabeledPoint] = { - seqAsJavaList(generateLogisticInput(offset, scale, nPoints, seed)) + generateLogisticInput(offset, scale, nPoints, seed).asJava } // Generate input of the form Y = logistic(offset + scale*X) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala index b1d78cba9e3d..ee3c85d09a46 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.classification -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.util.Random import org.jblas.DoubleMatrix @@ -35,7 +35,7 @@ object SVMSuite { weights: Array[Double], nPoints: Int, seed: Int): java.util.List[LabeledPoint] = { - seqAsJavaList(generateSVMInput(intercept, weights, nPoints, seed)) + generateSVMInput(intercept, weights, nPoints, seed).asJava } // Generate noisy input of the form Y = signum(x.dot(weights) + intercept + noise) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala new file mode 100644 index 000000000000..41b9d5c0d93b --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.clustering + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("default values") { + val bkm0 = new BisectingKMeans() + assert(bkm0.getK === 4) + assert(bkm0.getMaxIterations === 20) + assert(bkm0.getMinDivisibleClusterSize === 1.0) + val bkm1 = new BisectingKMeans() + assert(bkm0.getSeed === bkm1.getSeed, "The default seed should be constant.") + } + + test("setter/getter") { + val bkm = new BisectingKMeans() + + val k = 10 + assert(bkm.getK !== k) + assert(bkm.setK(k).getK === k) + val maxIter = 100 + assert(bkm.getMaxIterations !== maxIter) + assert(bkm.setMaxIterations(maxIter).getMaxIterations === maxIter) + val minSize = 2.0 + assert(bkm.getMinDivisibleClusterSize !== minSize) + assert(bkm.setMinDivisibleClusterSize(minSize).getMinDivisibleClusterSize === minSize) + val seed = 10L + assert(bkm.getSeed !== seed) + assert(bkm.setSeed(seed).getSeed === seed) + + intercept[IllegalArgumentException] { + bkm.setK(0) + } + intercept[IllegalArgumentException] { + bkm.setMaxIterations(0) + } + intercept[IllegalArgumentException] { + bkm.setMinDivisibleClusterSize(0.0) + } + } + + test("1D data") { + val points = Vectors.sparse(1, Array.empty, Array.empty) +: + (1 until 8).map(i => Vectors.dense(i)) + val data = sc.parallelize(points, 2) + val bkm = new BisectingKMeans() + .setK(4) + .setMaxIterations(1) + .setSeed(1L) + // The clusters should be + // (0, 1, 2, 3, 4, 5, 6, 7) + // - (0, 1, 2, 3) + // - (0, 1) + // - (2, 3) + // - (4, 5, 6, 7) + // - (4, 5) + // - (6, 7) + val model = bkm.run(data) + assert(model.k === 4) + // The total cost should be 8 * 0.5 * 0.5 = 2.0. 
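+    // (Each of the 8 points is 0.5 from its two-point cluster center, so the summed squared distance is 8 * 0.25 = 2.0.)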
+ assert(model.computeCost(data) ~== 2.0 relTol 1e-12) + val predictions = data.map(v => (v(0), model.predict(v))).collectAsMap() + Range(0, 8, 2).foreach { i => + assert(predictions(i) === predictions(i + 1), + s"$i and ${i + 1} should belong to the same cluster.") + } + val root = model.root + assert(root.center(0) ~== 3.5 relTol 1e-12) + assert(root.height ~== 2.0 relTol 1e-12) + assert(root.children.length === 2) + assert(root.children(0).height ~== 1.0 relTol 1e-12) + assert(root.children(1).height ~== 1.0 relTol 1e-12) + } + + test("points are the same") { + val data = sc.parallelize(Seq.fill(8)(Vectors.dense(1.0, 1.0)), 2) + val bkm = new BisectingKMeans() + .setK(2) + .setMaxIterations(1) + .setSeed(1L) + val model = bkm.run(data) + assert(model.k === 1) + } + + test("more desired clusters than points") { + val data = sc.parallelize(Seq.tabulate(4)(i => Vectors.dense(i)), 2) + val bkm = new BisectingKMeans() + .setK(8) + .setMaxIterations(2) + .setSeed(1L) + val model = bkm.run(data) + assert(model.k === 4) + } + + test("min divisible cluster") { + val data = sc.parallelize( + Seq.tabulate(16)(i => Vectors.dense(i)) ++ Seq.tabulate(4)(i => Vectors.dense(-100.0 - i)), + 2) + val bkm = new BisectingKMeans() + .setK(4) + .setMinDivisibleClusterSize(10) + .setMaxIterations(1) + .setSeed(1L) + val model = bkm.run(data) + assert(model.k === 3) + assert(model.predict(Vectors.dense(-100)) === model.predict(Vectors.dense(-97))) + assert(model.predict(Vectors.dense(7)) !== model.predict(Vectors.dense(8))) + + bkm.setMinDivisibleClusterSize(0.5) + val sameModel = bkm.run(data) + assert(sameModel.k === 3) + } + + test("larger clusters get selected first") { + val data = sc.parallelize( + Seq.tabulate(16)(i => Vectors.dense(i)) ++ Seq.tabulate(4)(i => Vectors.dense(-100.0 - i)), + 2) + val bkm = new BisectingKMeans() + .setK(3) + .setMaxIterations(1) + .setSeed(1L) + val model = bkm.run(data) + assert(model.k === 3) + assert(model.predict(Vectors.dense(-100)) === model.predict(Vectors.dense(-97))) + assert(model.predict(Vectors.dense(7)) !== model.predict(Vectors.dense(8))) + } + + test("2D data") { + val points = Seq( + (11, 10), (9, 10), (10, 9), (10, 11), + (11, -10), (9, -10), (10, -9), (10, -11), + (0, 1), (0, -1) + ).map { case (x, y) => + if (x == 0) { + Vectors.sparse(2, Array(1), Array(y)) + } else { + Vectors.dense(x, y) + } + } + val data = sc.parallelize(points, 2) + val bkm = new BisectingKMeans() + .setK(3) + .setMaxIterations(4) + .setSeed(1L) + val model = bkm.run(data) + assert(model.k === 3) + assert(model.root.center ~== Vectors.dense(8, 0) relTol 1e-12) + model.root.leafNodes.foreach { node => + if (node.center(0) < 5) { + assert(node.size === 2) + assert(node.center ~== Vectors.dense(0, 0) relTol 1e-12) + } else if (node.center(1) > 0) { + assert(node.size === 4) + assert(node.center ~== Vectors.dense(10, 10) relTol 1e-12) + } else { + assert(node.size === 4) + assert(node.center ~== Vectors.dense(10, -10) relTol 1e-12) + } + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala index b218d72f1268..a72723eb00da 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.clustering import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.linalg.{Vectors, 
Matrices} +import org.apache.spark.mllib.linalg.{Vector, Vectors, Matrices} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ @@ -76,6 +76,20 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext { assert(gmm.gaussians(1).sigma ~== Esigma(1) absTol 1E-3) } + test("two clusters with distributed decompositions") { + val data = sc.parallelize(GaussianTestData.data2, 2) + + val k = 5 + val d = data.first().size + assert(GaussianMixture.shouldDistributeGaussians(k, d)) + + val gmm = new GaussianMixture() + .setK(k) + .run(data) + + assert(gmm.k === k) + } + test("single cluster with sparse data") { val data = sc.parallelize(Array( Vectors.sparse(3, Array(0, 2), Array(4.0, 2.0)), @@ -116,7 +130,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext { val sparseGMM = new GaussianMixture() .setK(2) .setInitialModel(initialGmm) - .run(data) + .run(sparseData) assert(sparseGMM.weights(0) ~== Ew(0) absTol 1E-3) assert(sparseGMM.weights(1) ~== Ew(1) absTol 1E-3) @@ -148,6 +162,16 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("model prediction, parallel and local") { + val data = sc.parallelize(GaussianTestData.data) + val gmm = new GaussianMixture().setK(2).setSeed(0).run(data) + + val batchPredictions = gmm.predict(data) + batchPredictions.zip(data).collect().foreach { case (batchPred, datum) => + assert(batchPred === gmm.predict(datum)) + } + } + object GaussianTestData { val data = Array( @@ -158,5 +182,9 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext { Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734) ) + val data2: Array[Vector] = Array.tabulate(25){ i: Int => + Vectors.dense(Array.tabulate(50)(i + _.toDouble)) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index fdc2554ab853..37fb69d68f6b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.clustering +import java.util.{ArrayList => JArrayList} + import breeze.linalg.{DenseMatrix => BDM, argtopk, max, argmax} import org.apache.spark.SparkFunSuite @@ -66,6 +68,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { // Train a model val lda = new LDA() lda.setK(k) + .setOptimizer(new EMLDAOptimizer) .setDocConcentration(topicSmoothing) .setTopicConcentration(termSmoothing) .setMaxIterations(5) @@ -133,17 +136,34 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { } // Top 3 documents per topic - model.topDocumentsPerTopic(3).zip(topDocsByTopicDistributions(3)).foreach {case (t1, t2) => + model.topDocumentsPerTopic(3).zip(topDocsByTopicDistributions(3)).foreach { case (t1, t2) => assert(t1._1 === t2._1) assert(t1._2 === t2._2) } // All documents per topic val q = tinyCorpus.length - model.topDocumentsPerTopic(q).zip(topDocsByTopicDistributions(q)).foreach {case (t1, t2) => + model.topDocumentsPerTopic(q).zip(topDocsByTopicDistributions(q)).foreach { case (t1, t2) => assert(t1._1 === t2._1) assert(t1._2 === t2._2) } + + // Check: topTopicAssignments + // Make sure it assigns a topic to each term appearing in each doc. 
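+    // topicAssignments gives (docID, termIndices, topicAssignments) triples; collect them into a map keyed by docID.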
+ val topTopicAssignments: Map[Long, (Array[Int], Array[Int])] = + model.topicAssignments.collect().map(x => x._1 -> (x._2, x._3)).toMap + assert(topTopicAssignments.keys.max < tinyCorpus.length) + tinyCorpus.foreach { case (docID: Long, doc: Vector) => + if (topTopicAssignments.contains(docID)) { + val (inds, vals) = topTopicAssignments(docID) + assert(inds.length === doc.numNonzeros) + // For "term" in actual doc, + // check that it has a topic assigned. + doc.foreachActive((term, wcnt) => assert(wcnt === 0 || inds.contains(term))) + } else { + assert(doc.numNonzeros === 0) + } + } } test("vertex indexing") { @@ -160,8 +180,8 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { test("setter alias") { val lda = new LDA().setAlpha(2.0).setBeta(3.0) - assert(lda.getAlpha.toArray.forall(_ === 2.0)) - assert(lda.getDocConcentration.toArray.forall(_ === 2.0)) + assert(lda.getAsymmetricAlpha.toArray.forall(_ === 2.0)) + assert(lda.getAsymmetricDocConcentration.toArray.forall(_ === 2.0)) assert(lda.getBeta === 3.0) assert(lda.getTopicConcentration === 3.0) } @@ -404,7 +424,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { val k = 2 val docs = sc.parallelize(toyData) val op = new OnlineLDAOptimizer().setMiniBatchFraction(1).setTau0(1024).setKappa(0.51) - .setGammaShape(100).setOptimzeAlpha(true).setSampleWithReplacement(false) + .setGammaShape(100).setOptimizeDocConcentration(true).setSampleWithReplacement(false) val lda = new LDA().setK(k) .setDocConcentration(1D / k) .setTopicConcentration(0.01) @@ -575,6 +595,17 @@ private[clustering] object LDASuite { Vectors.sparse(6, Array(4, 5), Array(1, 1)) ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } + /** Used in the Java Test Suite */ + def javaToyData: JArrayList[(java.lang.Long, Vector)] = { + val javaData = new JArrayList[(java.lang.Long, Vector)] + var i = 0 + while (i < toyData.length) { + javaData.add((toyData(i)._1, toyData(i)._2)) + i += 1 + } + javaData + } + def toyModel: LocalLDAModel = { val k = 2 val vocabSize = 6 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala index 3645d29dccdb..65e37c64d404 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala @@ -98,9 +98,16 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase { runStreams(ssc, numBatches, numBatches) // check that estimated centers are close to true centers - // NOTE exact assignment depends on the initialization! 
- assert(centers(0) ~== kMeans.latestModel().clusterCenters(0) absTol 1E-1) - assert(centers(1) ~== kMeans.latestModel().clusterCenters(1) absTol 1E-1) + // cluster ordering is arbitrary, so choose closest cluster + val d0 = Vectors.sqdist(kMeans.latestModel().clusterCenters(0), centers(0)) + val d1 = Vectors.sqdist(kMeans.latestModel().clusterCenters(0), centers(1)) + val (c0, c1) = if (d0 < d1) { + (centers(0), centers(1)) + } else { + (centers(1), centers(0)) + } + assert(c0 ~== kMeans.latestModel().clusterCenters(0) absTol 1E-1) + assert(c1 ~== kMeans.latestModel().clusterCenters(1) absTol 1E-1) } test("detecting dying clusters") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala index 889727fb5582..734800a9afad 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.Utils class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -63,4 +64,29 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext { }.collect().toSet assert(filteredData == preFilteredData) } + + test("model load / save") { + val model = ChiSqSelectorSuite.createModel() + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + try { + model.save(sc, path) + val sameModel = ChiSqSelectorModel.load(sc, path) + ChiSqSelectorSuite.checkEqual(model, sameModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } +} + +object ChiSqSelectorSuite extends SparkFunSuite { + + def createModel(): ChiSqSelectorModel = { + val arr = Array(1, 2, 3, 4) + new ChiSqSelectorModel(arr) + } + + def checkEqual(a: ChiSqSelectorModel, b: ChiSqSelectorModel): Unit = { + assert(a.selectedFeatures.deep == b.selectedFeatures.deep) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala index e57f49191378..a8d82932d390 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala @@ -37,11 +37,12 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext { val pca = new PCA(k).fit(dataRDD) val mat = new RowMatrix(dataRDD) - val pc = mat.computePrincipalComponents(k) + val (pc, explainedVariance) = mat.computePrincipalComponentsAndExplainedVariance(k) val pca_transform = pca.transform(dataRDD).collect() val mat_multiply = mat.multiply(pc).rows.collect() assert(pca_transform.toSet === mat_multiply.toSet) + assert(pca.explainedVariance === explainedVariance) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala index a864eec460f2..37d01e287669 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala @@ -92,4 +92,23 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { } } + + test("big model load / save") { + // create a model bigger than 32MB since 9000 * 1000 * 4 > 2^25 + val word2VecMap = 
Map((0 to 9000).map(i => s"$i" -> Array.fill(1000)(0.1f)): _*) + val model = new Word2VecModel(word2VecMap) + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + try { + model.save(sc, path) + val sameModel = Word2VecModel.load(sc, path) + assert(sameModel.getVectors.mapValues(_.toSeq) === model.getVectors.mapValues(_.toSeq)) + } finally { + Utils.deleteRecursively(tempDir) + } + } + + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala index 0ae48d62cc6b..a83e543859b8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { - test("PrefixSpan internal (integer seq, -1 delim) run, singleton itemsets") { + test("PrefixSpan internal (integer seq, 0 delim) run, singleton itemsets") { /* library("arulesSequences") @@ -35,83 +35,81 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { */ val sequences = Array( - Array(1, -1, 3, -1, 4, -1, 5), - Array(2, -1, 3, -1, 1), - Array(2, -1, 4, -1, 1), - Array(3, -1, 1, -1, 3, -1, 4, -1, 5), - Array(3, -1, 4, -1, 4, -1, 3), - Array(6, -1, 5, -1, 3)) + Array(0, 1, 0, 3, 0, 4, 0, 5, 0), + Array(0, 2, 0, 3, 0, 1, 0), + Array(0, 2, 0, 4, 0, 1, 0), + Array(0, 3, 0, 1, 0, 3, 0, 4, 0, 5, 0), + Array(0, 3, 0, 4, 0, 4, 0, 3, 0), + Array(0, 6, 0, 5, 0, 3, 0)) val rdd = sc.parallelize(sequences, 2).cache() - val prefixspan = new PrefixSpan() - .setMinSupport(0.33) - .setMaxPatternLength(50) - val result1 = prefixspan.run(rdd) + val result1 = PrefixSpan.genFreqPatterns( + rdd, minCount = 2L, maxPatternLength = 50, maxLocalProjDBSize = 16L) val expectedValue1 = Array( - (Array(1), 4L), - (Array(1, -1, 3), 2L), - (Array(1, -1, 3, -1, 4), 2L), - (Array(1, -1, 3, -1, 4, -1, 5), 2L), - (Array(1, -1, 3, -1, 5), 2L), - (Array(1, -1, 4), 2L), - (Array(1, -1, 4, -1, 5), 2L), - (Array(1, -1, 5), 2L), - (Array(2), 2L), - (Array(2, -1, 1), 2L), - (Array(3), 5L), - (Array(3, -1, 1), 2L), - (Array(3, -1, 3), 2L), - (Array(3, -1, 4), 3L), - (Array(3, -1, 4, -1, 5), 2L), - (Array(3, -1, 5), 2L), - (Array(4), 4L), - (Array(4, -1, 5), 2L), - (Array(5), 3L) + (Array(0, 1, 0), 4L), + (Array(0, 1, 0, 3, 0), 2L), + (Array(0, 1, 0, 3, 0, 4, 0), 2L), + (Array(0, 1, 0, 3, 0, 4, 0, 5, 0), 2L), + (Array(0, 1, 0, 3, 0, 5, 0), 2L), + (Array(0, 1, 0, 4, 0), 2L), + (Array(0, 1, 0, 4, 0, 5, 0), 2L), + (Array(0, 1, 0, 5, 0), 2L), + (Array(0, 2, 0), 2L), + (Array(0, 2, 0, 1, 0), 2L), + (Array(0, 3, 0), 5L), + (Array(0, 3, 0, 1, 0), 2L), + (Array(0, 3, 0, 3, 0), 2L), + (Array(0, 3, 0, 4, 0), 3L), + (Array(0, 3, 0, 4, 0, 5, 0), 2L), + (Array(0, 3, 0, 5, 0), 2L), + (Array(0, 4, 0), 4L), + (Array(0, 4, 0, 5, 0), 2L), + (Array(0, 5, 0), 3L) ) compareInternalResults(expectedValue1, result1.collect()) - prefixspan.setMinSupport(0.5).setMaxPatternLength(50) - val result2 = prefixspan.run(rdd) + val result2 = PrefixSpan.genFreqPatterns( + rdd, minCount = 3, maxPatternLength = 50, maxLocalProjDBSize = 32L) val expectedValue2 = Array( - (Array(1), 4L), - (Array(3), 5L), - (Array(3, -1, 4), 3L), - (Array(4), 4L), - (Array(5), 3L) + (Array(0, 1, 0), 4L), + (Array(0, 3, 0), 5L), + (Array(0, 3, 0, 4, 0), 3L), + (Array(0, 4, 0), 4L), + (Array(0, 5, 0), 3L) ) compareInternalResults(expectedValue2, result2.collect()) - 
prefixspan.setMinSupport(0.33).setMaxPatternLength(2) - val result3 = prefixspan.run(rdd) + val result3 = PrefixSpan.genFreqPatterns( + rdd, minCount = 2, maxPatternLength = 2, maxLocalProjDBSize = 32L) val expectedValue3 = Array( - (Array(1), 4L), - (Array(1, -1, 3), 2L), - (Array(1, -1, 4), 2L), - (Array(1, -1, 5), 2L), - (Array(2, -1, 1), 2L), - (Array(2), 2L), - (Array(3), 5L), - (Array(3, -1, 1), 2L), - (Array(3, -1, 3), 2L), - (Array(3, -1, 4), 3L), - (Array(3, -1, 5), 2L), - (Array(4), 4L), - (Array(4, -1, 5), 2L), - (Array(5), 3L) + (Array(0, 1, 0), 4L), + (Array(0, 1, 0, 3, 0), 2L), + (Array(0, 1, 0, 4, 0), 2L), + (Array(0, 1, 0, 5, 0), 2L), + (Array(0, 2, 0, 1, 0), 2L), + (Array(0, 2, 0), 2L), + (Array(0, 3, 0), 5L), + (Array(0, 3, 0, 1, 0), 2L), + (Array(0, 3, 0, 3, 0), 2L), + (Array(0, 3, 0, 4, 0), 3L), + (Array(0, 3, 0, 5, 0), 2L), + (Array(0, 4, 0), 4L), + (Array(0, 4, 0, 5, 0), 2L), + (Array(0, 5, 0), 3L) ) compareInternalResults(expectedValue3, result3.collect()) } test("PrefixSpan internal (integer seq, -1 delim) run, variable-size itemsets") { val sequences = Array( - Array(1, -1, 1, 2, 3, -1, 1, 3, -1, 4, -1, 3, 6), - Array(1, 4, -1, 3, -1, 2, 3, -1, 1, 5), - Array(5, 6, -1, 1, 2, -1, 4, 6, -1, 3, -1, 2), - Array(5, -1, 7, -1, 1, 6, -1, 3, -1, 2, -1, 3)) + Array(0, 1, 0, 1, 2, 3, 0, 1, 3, 0, 4, 0, 3, 6, 0), + Array(0, 1, 4, 0, 3, 0, 2, 3, 0, 1, 5, 0), + Array(0, 5, 6, 0, 1, 2, 0, 4, 6, 0, 3, 0, 2, 0), + Array(0, 5, 0, 7, 0, 1, 6, 0, 3, 0, 2, 0, 3, 0)) val rdd = sc.parallelize(sequences, 2).cache() - val prefixspan = new PrefixSpan().setMinSupport(0.5).setMaxPatternLength(5) - val result = prefixspan.run(rdd) + val result = PrefixSpan.genFreqPatterns( + rdd, minCount = 2, maxPatternLength = 5, maxLocalProjDBSize = 128L) /* To verify results, create file "prefixSpanSeqs" with content @@ -200,63 +198,87 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { 53 <{1},{2},{1}> 0.50 */ val expectedValue = Array( - (Array(1), 4L), - (Array(2), 4L), - (Array(3), 4L), - (Array(4), 3L), - (Array(5), 3L), - (Array(6), 3L), - (Array(1, -1, 6), 2L), - (Array(2, -1, 6), 2L), - (Array(5, -1, 6), 2L), - (Array(1, 2, -1, 6), 2L), - (Array(1, -1, 4), 2L), - (Array(2, -1, 4), 2L), - (Array(1, 2, -1, 4), 2L), - (Array(1, -1, 3), 4L), - (Array(2, -1, 3), 3L), - (Array(2, 3), 2L), - (Array(3, -1, 3), 3L), - (Array(4, -1, 3), 3L), - (Array(5, -1, 3), 2L), - (Array(6, -1, 3), 2L), - (Array(5, -1, 6, -1, 3), 2L), - (Array(6, -1, 2, -1, 3), 2L), - (Array(5, -1, 2, -1, 3), 2L), - (Array(5, -1, 1, -1, 3), 2L), - (Array(2, -1, 4, -1, 3), 2L), - (Array(1, -1, 4, -1, 3), 2L), - (Array(1, 2, -1, 4, -1, 3), 2L), - (Array(1, -1, 3, -1, 3), 3L), - (Array(1, 2, -1, 3), 2L), - (Array(1, -1, 2, -1, 3), 2L), - (Array(1, -1, 2, 3), 2L), - (Array(1, -1, 2), 4L), - (Array(1, 2), 2L), - (Array(3, -1, 2), 3L), - (Array(4, -1, 2), 2L), - (Array(5, -1, 2), 2L), - (Array(6, -1, 2), 2L), - (Array(5, -1, 6, -1, 2), 2L), - (Array(6, -1, 3, -1, 2), 2L), - (Array(5, -1, 3, -1, 2), 2L), - (Array(5, -1, 1, -1, 2), 2L), - (Array(4, -1, 3, -1, 2), 2L), - (Array(1, -1, 3, -1, 2), 3L), - (Array(5, -1, 6, -1, 3, -1, 2), 2L), - (Array(5, -1, 1, -1, 3, -1, 2), 2L), - (Array(1, -1, 1), 2L), - (Array(2, -1, 1), 2L), - (Array(3, -1, 1), 2L), - (Array(5, -1, 1), 2L), - (Array(2, 3, -1, 1), 2L), - (Array(1, -1, 3, -1, 1), 2L), - (Array(1, -1, 2, 3, -1, 1), 2L), - (Array(1, -1, 2, -1, 1), 2L)) + (Array(0, 1, 0), 4L), + (Array(0, 2, 0), 4L), + (Array(0, 3, 0), 4L), + (Array(0, 4, 0), 3L), + (Array(0, 5, 0), 3L), + 
(Array(0, 6, 0), 3L), + (Array(0, 1, 0, 6, 0), 2L), + (Array(0, 2, 0, 6, 0), 2L), + (Array(0, 5, 0, 6, 0), 2L), + (Array(0, 1, 2, 0, 6, 0), 2L), + (Array(0, 1, 0, 4, 0), 2L), + (Array(0, 2, 0, 4, 0), 2L), + (Array(0, 1, 2, 0, 4, 0), 2L), + (Array(0, 1, 0, 3, 0), 4L), + (Array(0, 2, 0, 3, 0), 3L), + (Array(0, 2, 3, 0), 2L), + (Array(0, 3, 0, 3, 0), 3L), + (Array(0, 4, 0, 3, 0), 3L), + (Array(0, 5, 0, 3, 0), 2L), + (Array(0, 6, 0, 3, 0), 2L), + (Array(0, 5, 0, 6, 0, 3, 0), 2L), + (Array(0, 6, 0, 2, 0, 3, 0), 2L), + (Array(0, 5, 0, 2, 0, 3, 0), 2L), + (Array(0, 5, 0, 1, 0, 3, 0), 2L), + (Array(0, 2, 0, 4, 0, 3, 0), 2L), + (Array(0, 1, 0, 4, 0, 3, 0), 2L), + (Array(0, 1, 2, 0, 4, 0, 3, 0), 2L), + (Array(0, 1, 0, 3, 0, 3, 0), 3L), + (Array(0, 1, 2, 0, 3, 0), 2L), + (Array(0, 1, 0, 2, 0, 3, 0), 2L), + (Array(0, 1, 0, 2, 3, 0), 2L), + (Array(0, 1, 0, 2, 0), 4L), + (Array(0, 1, 2, 0), 2L), + (Array(0, 3, 0, 2, 0), 3L), + (Array(0, 4, 0, 2, 0), 2L), + (Array(0, 5, 0, 2, 0), 2L), + (Array(0, 6, 0, 2, 0), 2L), + (Array(0, 5, 0, 6, 0, 2, 0), 2L), + (Array(0, 6, 0, 3, 0, 2, 0), 2L), + (Array(0, 5, 0, 3, 0, 2, 0), 2L), + (Array(0, 5, 0, 1, 0, 2, 0), 2L), + (Array(0, 4, 0, 3, 0, 2, 0), 2L), + (Array(0, 1, 0, 3, 0, 2, 0), 3L), + (Array(0, 5, 0, 6, 0, 3, 0, 2, 0), 2L), + (Array(0, 5, 0, 1, 0, 3, 0, 2, 0), 2L), + (Array(0, 1, 0, 1, 0), 2L), + (Array(0, 2, 0, 1, 0), 2L), + (Array(0, 3, 0, 1, 0), 2L), + (Array(0, 5, 0, 1, 0), 2L), + (Array(0, 2, 3, 0, 1, 0), 2L), + (Array(0, 1, 0, 3, 0, 1, 0), 2L), + (Array(0, 1, 0, 2, 3, 0, 1, 0), 2L), + (Array(0, 1, 0, 2, 0, 1, 0), 2L)) compareInternalResults(expectedValue, result.collect()) } + test("PrefixSpan projections with multiple partial starts") { + val sequences = Seq( + Array(Array(1, 2), Array(1, 2, 3))) + val rdd = sc.parallelize(sequences, 2) + val prefixSpan = new PrefixSpan() + .setMinSupport(1.0) + .setMaxPatternLength(2) + val model = prefixSpan.run(rdd) + val expected = Array( + (Array(Array(1)), 1L), + (Array(Array(1, 2)), 1L), + (Array(Array(1), Array(1)), 1L), + (Array(Array(1), Array(2)), 1L), + (Array(Array(1), Array(3)), 1L), + (Array(Array(1, 3)), 1L), + (Array(Array(2)), 1L), + (Array(Array(2, 3)), 1L), + (Array(Array(2), Array(1)), 1L), + (Array(Array(2), Array(2)), 1L), + (Array(Array(2), Array(3)), 1L), + (Array(Array(3)), 1L)) + compareResults(expected, model.freqSequences.collect()) + } + test("PrefixSpan Integer type, variable-size itemsets") { val sequences = Seq( Array(Array(1, 2), Array(3)), @@ -265,7 +287,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { Array(Array(6))) val rdd = sc.parallelize(sequences, 2).cache() - val prefixspan = new PrefixSpan() + val prefixSpan = new PrefixSpan() .setMinSupport(0.5) .setMaxPatternLength(5) @@ -296,7 +318,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { 5 <{1,2}> 0.75 */ - val model = prefixspan.run(rdd) + val model = prefixSpan.run(rdd) val expected = Array( (Array(Array(1)), 3L), (Array(Array(2)), 3L), @@ -304,7 +326,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { (Array(Array(1), Array(3)), 2L), (Array(Array(1, 2)), 3L) ) - compareResults(expected, model.freqSequences.collect().map(x => (x.sequence, x.freq))) + compareResults(expected, model.freqSequences.collect()) } test("PrefixSpan String type, variable-size itemsets") { @@ -318,11 +340,11 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { Array(Array(6))).map(seq => seq.map(itemSet => itemSet.map(intToString))) val rdd = 
sc.parallelize(sequences, 2).cache() - val prefixspan = new PrefixSpan() + val prefixSpan = new PrefixSpan() .setMinSupport(0.5) .setMaxPatternLength(5) - val model = prefixspan.run(rdd) + val model = prefixSpan.run(rdd) val expected = Array( (Array(Array(1)), 3L), (Array(Array(2)), 3L), @@ -332,17 +354,17 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { ).map { case (pattern, count) => (pattern.map(itemSet => itemSet.map(intToString)), count) } - compareResults(expected, model.freqSequences.collect().map(x => (x.sequence, x.freq))) + compareResults(expected, model.freqSequences.collect()) } private def compareResults[Item]( expectedValue: Array[(Array[Array[Item]], Long)], - actualValue: Array[(Array[Array[Item]], Long)]): Unit = { + actualValue: Array[PrefixSpan.FreqSequence[Item]]): Unit = { val expectedSet = expectedValue.map { case (pattern: Array[Array[Item]], count: Long) => (pattern.map(itemSet => itemSet.toSet).toSeq, count) }.toSet - val actualSet = actualValue.map { case (pattern: Array[Array[Item]], count: Long) => - (pattern.map(itemSet => itemSet.toSet).toSeq, count) + val actualSet = actualValue.map { x => + (x.sequence.map(_.toSet).toSeq, x.freq) }.toSet assert(expectedSet === actualSet) } @@ -354,11 +376,4 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { val actualSet = actualValue.map(x => (x._1.toSeq, x._2)).toSet assert(expectedSet === actualSet) } - - private def insertDelimiter(sequence: Array[Int]): Array[Int] = { - sequence.zip(Seq.fill(sequence.length)(PrefixSpan.DELIMITER)).map { case (a, b) => - List(a, b) - }.flatten - } - } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala index d119e0b50a39..96e5ffef7a13 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -126,6 +126,31 @@ class BLASSuite extends SparkFunSuite { } } + test("spr") { + // test dense vector + val alpha = 0.1 + val x = new DenseVector(Array(1.0, 2, 2.1, 4)) + val U = new DenseVector(Array(1.0, 2, 2, 3, 3, 3, 4, 4, 4, 4)) + val expected = new DenseVector(Array(1.1, 2.2, 2.4, 3.21, 3.42, 3.441, 4.4, 4.8, 4.84, 5.6)) + + spr(alpha, x, U) + assert(U ~== expected absTol 1e-9) + + val matrix33 = new DenseVector(Array(1.0, 2, 3, 4, 5)) + withClue("Size of vector must match the rank of matrix") { + intercept[Exception] { + spr(alpha, x, matrix33) + } + } + + // test sparse vector + val sv = new SparseVector(4, Array(0, 3), Array(1.0, 2)) + val U2 = new DenseVector(Array(1.0, 2, 2, 3, 3, 3, 4, 4, 4, 4)) + spr(0.1, sv, U2) + val expectedSparse = new DenseVector(Array(1.1, 2.0, 2.0, 3.0, 3.0, 3.0, 4.2, 4.0, 4.0, 4.4)) + assert(U2 ~== expectedSparse absTol 1e-15) + } + test("syr") { val dA = new DenseMatrix(4, 4, Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0, 3.1, 4.6, 3.0, 0.8)) @@ -204,6 +229,7 @@ class BLASSuite extends SparkFunSuite { val C14 = C1.copy val C15 = C1.copy val C16 = C1.copy + val C17 = C1.copy val expected2 = new DenseMatrix(4, 2, Array(2.0, 1.0, 4.0, 2.0, 4.0, 0.0, 4.0, 3.0)) val expected3 = new DenseMatrix(4, 2, Array(2.0, 2.0, 4.0, 2.0, 8.0, 0.0, 6.0, 6.0)) val expected4 = new DenseMatrix(4, 2, Array(5.0, 0.0, 10.0, 5.0, 0.0, 0.0, 5.0, 0.0)) @@ -217,6 +243,10 @@ class BLASSuite extends SparkFunSuite { assert(C2 ~== expected2 absTol 1e-15) assert(C3 ~== expected3 absTol 1e-15) assert(C4 ~== expected3 absTol 1e-15) + 
gemm(1.0, dA, B, 0.0, C17) + assert(C17 ~== expected absTol 1e-15) + gemm(1.0, sA, B, 0.0, C17) + assert(C17 ~== expected absTol 1e-15) withClue("columns of A don't match the rows of B") { intercept[Exception] { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index a270ba2562db..1833cf383367 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -74,6 +74,35 @@ class MatricesSuite extends SparkFunSuite { } } + test("index in matrices incorrect input") { + val sm = Matrices.sparse(3, 2, Array(0, 2, 3), Array(1, 2, 1), Array(0.0, 1.0, 2.0)) + val dm = Matrices.dense(3, 2, Array(0.0, 2.3, 1.4, 3.2, 1.0, 9.1)) + Array(sm, dm).foreach { mat => + intercept[IllegalArgumentException] { mat.index(4, 1) } + intercept[IllegalArgumentException] { mat.index(1, 4) } + intercept[IllegalArgumentException] { mat.index(-1, 2) } + intercept[IllegalArgumentException] { mat.index(1, -2) } + } + } + + test("equals") { + val dm1 = Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)) + assert(dm1 === dm1) + assert(dm1 !== dm1.transpose) + + val dm2 = Matrices.dense(2, 2, Array(0.0, 2.0, 1.0, 3.0)) + assert(dm1 === dm2.transpose) + + val sm1 = dm1.asInstanceOf[DenseMatrix].toSparse + assert(sm1 === sm1) + assert(sm1 === dm1) + assert(sm1 !== sm1.transpose) + + val sm2 = dm2.asInstanceOf[DenseMatrix].toSparse + assert(sm1 === sm2.transpose) + assert(sm1 === dm2.transpose) + } + test("matrix copies are deep copies") { val m = 3 val n = 2 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 1c37ea5123e8..f895e2a8e4af 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.mllib.linalg import scala.util.Random import breeze.linalg.{DenseMatrix => BDM, squaredDistance => breezeSquaredDistance} +import org.json4s.jackson.JsonMethods.{parse => parseJson} import org.apache.spark.{Logging, SparkException, SparkFunSuite} import org.apache.spark.mllib.util.TestingUtils._ @@ -367,4 +368,27 @@ class VectorsSuite extends SparkFunSuite with Logging { val sv1c = sv1.compressed.asInstanceOf[DenseVector] assert(sv1 === sv1c) } + + test("SparseVector.slice") { + val v = new SparseVector(5, Array(1, 2, 4), Array(1.1, 2.2, 4.4)) + assert(v.slice(Array(0, 2)) === new SparseVector(2, Array(1), Array(2.2))) + assert(v.slice(Array(2, 0)) === new SparseVector(2, Array(0), Array(2.2))) + assert(v.slice(Array(2, 0, 3, 4)) === new SparseVector(4, Array(0, 3), Array(2.2, 4.4))) + } + + test("toJson/fromJson") { + val sv0 = Vectors.sparse(0, Array.empty, Array.empty) + val sv1 = Vectors.sparse(1, Array.empty, Array.empty) + val sv2 = Vectors.sparse(2, Array(1), Array(2.0)) + val dv0 = Vectors.dense(Array.empty[Double]) + val dv1 = Vectors.dense(1.0) + val dv2 = Vectors.dense(0.0, 2.0) + for (v <- Seq(sv0, sv1, sv2, dv0, dv1, dv2)) { + val json = v.toJson + parseJson(json) // `json` should be a valid JSON string + val u = Vectors.fromJson(json) + assert(u.getClass === v.getClass, "toJson/fromJson should preserve vector types.") + assert(u === v, "toJson/fromJson should preserve vector values.") + } + } } diff --git 
a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala index 93fe04c139b9..b8eb10305801 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala @@ -235,6 +235,24 @@ class BlockMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { assert(localC ~== result absTol 1e-8) } + test("simulate multiply") { + val blocks: Seq[((Int, Int), Matrix)] = Seq( + ((0, 0), new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 1.0))), + ((1, 1), new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 1.0)))) + val rdd = sc.parallelize(blocks, 2) + val B = new BlockMatrix(rdd, colPerPart, rowPerPart) + val resultPartitioner = GridPartitioner(gridBasedMat.numRowBlocks, B.numColBlocks, + math.max(numPartitions, 2)) + val (destinationsA, destinationsB) = gridBasedMat.simulateMultiply(B, resultPartitioner) + assert(destinationsA((0, 0)) === Set(0)) + assert(destinationsA((0, 1)) === Set(2)) + assert(destinationsA((1, 0)) === Set(0)) + assert(destinationsA((1, 1)) === Set(2)) + assert(destinationsA((2, 1)) === Set(3)) + assert(destinationsB((0, 0)) === Set(0)) + assert(destinationsB((1, 1)) === Set(2, 3)) + } + test("validate") { // No error gridBasedMat.validate() diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala index 0ecb7a221a50..6de6cf2fa863 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala @@ -153,6 +153,18 @@ class IndexedRowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("similar columns") { + val A = new IndexedRowMatrix(indexedRows) + val gram = A.computeGramianMatrix().toBreeze.toDenseMatrix + + val G = A.columnSimilarities().toBreeze() + + for (i <- 0 until n; j <- i + 1 until n) { + val trueResult = gram(i, j) / scala.math.sqrt(gram(i, i) * gram(j, j)) + assert(math.abs(G(i, j) - trueResult) < 1e-6) + } + } + def closeToZero(G: BDM[Double]): Boolean = { G.valuesIterator.map(math.abs).sum < 1e-6 } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala index 283ffec1d49d..0ff901ddc497 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.linalg.distributed +import java.util.Arrays + import scala.util.Random import breeze.numerics.abs @@ -24,6 +26,7 @@ import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, norm => brzNorm, s import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Matrices, Vectors, Vector} +import org.apache.spark.mllib.random.RandomRDDs import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -48,6 +51,7 @@ class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { (0.0, 1.0, 0.0), (math.sqrt(2.0) / 2.0, 0.0, math.sqrt(2.0) / 2.0), (math.sqrt(2.0) / 2.0, 0.0, - math.sqrt(2.0) / 2.0)) + 
val explainedVariance = BDV(4.0 / 7.0, 3.0 / 7.0, 0.0) var denseMat: RowMatrix = _ var sparseMat: RowMatrix = _ @@ -200,10 +204,15 @@ class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { test("pca") { for (mat <- Seq(denseMat, sparseMat); k <- 1 to n) { - val pc = denseMat.computePrincipalComponents(k) + val (pc, expVariance) = mat.computePrincipalComponentsAndExplainedVariance(k) assert(pc.numRows === n) assert(pc.numCols === k) assertColumnEqualUpToSign(pc.toBreeze.asInstanceOf[BDM[Double]], principalComponents, k) + assert( + closeToZero(BDV(expVariance.toArray) - + BDV(Arrays.copyOfRange(explainedVariance.data, 0, k)))) + // Check that this method returns the same answer + assert(pc === mat.computePrincipalComponents(k)) } } @@ -255,6 +264,23 @@ class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { assert(closeToZero(abs(expected.r) - abs(rOnly.R.toBreeze.asInstanceOf[BDM[Double]]))) } } + + test("compute covariance") { + for (mat <- Seq(denseMat, sparseMat)) { + val result = mat.computeCovariance() + val expected = breeze.linalg.cov(mat.toBreeze()) + assert(closeToZero(abs(expected) - abs(result.toBreeze.asInstanceOf[BDM[Double]]))) + } + } + + test("covariance matrix is symmetric (SPARK-10875)") { + val rdd = RandomRDDs.normalVectorRDD(sc, 100, 10, 0, 0) + val matrix = new RowMatrix(rdd) + val cov = matrix.computeCovariance() + for (i <- 0 until cov.numRows; j <- 0 until i) { + assert(cov(i, j) === cov(j, i)) + } + } } class RowMatrixClusterSuite extends SparkFunSuite with LocalClusterSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala index 13b754a03943..36ac7d267243 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.optimization -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.util.Random import org.scalatest.Matchers @@ -35,7 +35,7 @@ object GradientDescentSuite { scale: Double, nPoints: Int, seed: Int): java.util.List[LabeledPoint] = { - seqAsJavaList(generateGDInput(offset, scale, nPoints, seed)) + generateGDInput(offset, scale, nPoints, seed).asJava } // Generate input of the form Y = logistic(offset + scale * X) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala index a5ca1518f82f..8416771552fd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.random -import scala.math +import org.apache.commons.math3.special.Gamma import org.apache.spark.SparkFunSuite import org.apache.spark.util.StatCounter @@ -136,4 +136,18 @@ class RandomDataGeneratorSuite extends SparkFunSuite { distributionChecks(gamma, expectedMean, expectedStd, 0.1) } } + + test("WeibullGenerator") { + List((1.0, 2.0), (2.0, 3.0), (2.5, 3.5), (10.4, 2.222)).map { + case (alpha: Double, beta: Double) => + val weibull = new WeibullGenerator(alpha, beta) + apiChecks(weibull) + + val expectedMean = math.exp(Gamma.logGamma(1 + (1 / alpha))) * beta + val expectedVariance = math.exp( + Gamma.logGamma(1 + (2 / 
alpha))) * beta * beta - expectedMean * expectedMean + val expectedStd = math.sqrt(expectedVariance) + distributionChecks(weibull, expectedMean, expectedStd, 0.1) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala index bc6417261483..ac93733bab5f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala @@ -28,9 +28,12 @@ class RDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext { for (numPartitions <- 1 to 8) { val rdd = sc.parallelize(data, numPartitions) for (windowSize <- 1 to 6) { - val sliding = rdd.sliding(windowSize).collect().map(_.toList).toList - val expected = data.sliding(windowSize).map(_.toList).toList - assert(sliding === expected) + for (step <- 1 to 3) { + val sliding = rdd.sliding(windowSize, step).collect().map(_.toList).toList + val expected = data.sliding(windowSize, step) + .map(_.toList).toList.filter(l => l.size == windowSize) + assert(sliding === expected) + } } assert(rdd.sliding(7).collect().isEmpty, "Should return an empty RDD if the window size is greater than the number of items.") @@ -40,7 +43,7 @@ class RDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext { test("sliding with empty partitions") { val data = Seq(Seq(1, 2, 3), Seq.empty[Int], Seq(4), Seq.empty[Int], Seq(5, 6, 7)) val rdd = sc.parallelize(data, data.length).flatMap(s => s) - assert(rdd.partitions.size === data.length) + assert(rdd.partitions.length === data.length) val sliding = rdd.sliding(3).collect().toSeq.map(_.toSeq) val expected = data.flatMap(x => x).sliding(3).toSeq.map(_.toSeq) assert(sliding === expected) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala index 05b87728d6fd..045135f7f8d6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.recommendation -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.math.abs import scala.util.Random @@ -38,7 +38,7 @@ object ALSSuite { negativeWeights: Boolean): (java.util.List[Rating], DoubleMatrix, DoubleMatrix) = { val (sampledRatings, trueRatings, truePrefs) = generateRatings(users, products, features, samplingRate, implicitPrefs) - (seqAsJavaList(sampledRatings), trueRatings, truePrefs) + (sampledRatings.asJava, trueRatings, truePrefs) } def generateRatings( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala index 07efde4f5e6d..b6d41db69be0 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala @@ -218,4 +218,31 @@ class MultivariateOnlineSummarizerSuite extends SparkFunSuite { s0.merge(s1) assert(s0.mean(0) ~== 1.0 absTol 1e-14) } + + test("merging summarizer with weighted samples") { + val summarizer = (new MultivariateOnlineSummarizer) + .add(instance = Vectors.sparse(3, Seq((0, -0.8), (1, 1.7))), weight = 0.1) + .add(Vectors.dense(0.0, -1.2, -1.7), 0.2).merge( + (new MultivariateOnlineSummarizer) 
+ .add(Vectors.sparse(3, Seq((0, -0.7), (1, 0.01), (2, 1.3))), 0.15) + .add(Vectors.dense(-0.5, 0.3, -1.5), 0.05)) + + assert(summarizer.count === 4) + + // The following values are hand calculated using the formula: + // [[https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights]] + // which defines the reliability weight used for computing the unbiased estimation of variance + // for weighted instances. + assert(summarizer.mean ~== Vectors.dense(Array(-0.42, -0.107, -0.44)) + absTol 1E-10, "mean mismatch") + assert(summarizer.variance ~== Vectors.dense(Array(0.17657142857, 1.645115714, 2.42057142857)) + absTol 1E-8, "variance mismatch") + assert(summarizer.numNonzeros ~== Vectors.dense(Array(0.3, 0.5, 0.4)) + absTol 1E-10, "numNonzeros mismatch") + assert(summarizer.max ~== Vectors.dense(Array(0.0, 1.7, 1.3)) absTol 1E-10, "max mismatch") + assert(summarizer.min ~== Vectors.dense(Array(-0.8, -1.2, -1.7)) absTol 1E-10, "min mismatch") + assert(summarizer.normL2 ~== Vectors.dense(0.387298335, 0.762571308141, 0.9715966241192) + absTol 1E-8, "normL2 mismatch") + assert(summarizer.normL1 ~== Vectors.dense(0.21, 0.4265, 0.61) absTol 1E-10, "normL1 mismatch") + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala new file mode 100644 index 000000000000..3c657c8cfe74 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.stat + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.stat.test.{StreamingTest, StreamingTestResult, StudentTTest, + WelchTTest, BinarySample} +import org.apache.spark.streaming.TestSuiteBase +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.StatCounter +import org.apache.spark.util.random.XORShiftRandom + +class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { + + override def maxWaitTimeMillis : Int = 30000 + + test("accuracy for null hypothesis using welch t-test") { + // set parameters + val testMethod = "welch" + val numBatches = 2 + val pointsPerBatch = 1000 + val meanA = 0 + val stdevA = 0.001 + val meanB = 0 + val stdevB = 0.001 + + val model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(0) + .setTestMethod(testMethod) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[BinarySample]) => model.registerStream(inputDStream)) + val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) + + assert(outputBatches.flatten.forall(res => + res.pValue > 0.05 && res.method == WelchTTest.methodName)) + } + + test("accuracy for alternative hypothesis using welch t-test") { + // set parameters + val testMethod = "welch" + val numBatches = 2 + val pointsPerBatch = 1000 + val meanA = -10 + val stdevA = 1 + val meanB = 10 + val stdevB = 1 + + val model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(0) + .setTestMethod(testMethod) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[BinarySample]) => model.registerStream(inputDStream)) + val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) + + assert(outputBatches.flatten.forall(res => + res.pValue < 0.05 && res.method == WelchTTest.methodName)) + } + + test("accuracy for null hypothesis using student t-test") { + // set parameters + val testMethod = "student" + val numBatches = 2 + val pointsPerBatch = 1000 + val meanA = 0 + val stdevA = 0.001 + val meanB = 0 + val stdevB = 0.001 + + val model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(0) + .setTestMethod(testMethod) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[BinarySample]) => model.registerStream(inputDStream)) + val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) + + + assert(outputBatches.flatten.forall(res => + res.pValue > 0.05 && res.method == StudentTTest.methodName)) + } + + test("accuracy for alternative hypothesis using student t-test") { + // set parameters + val testMethod = "student" + val numBatches = 2 + val pointsPerBatch = 1000 + val meanA = -10 + val stdevA = 1 + val meanB = 10 + val stdevB = 1 + + val model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(0) + .setTestMethod(testMethod) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[BinarySample]) => model.registerStream(inputDStream)) + val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) + + 
assert(outputBatches.flatten.forall(res => + res.pValue < 0.05 && res.method == StudentTTest.methodName)) + } + + test("batches within same test window are grouped") { + // set parameters + val testWindow = 3 + val numBatches = 5 + val pointsPerBatch = 100 + val meanA = -10 + val stdevA = 1 + val meanB = 10 + val stdevB = 1 + + val model = new StreamingTest() + .setWindowSize(testWindow) + .setPeacePeriod(0) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, + (inputDStream: DStream[BinarySample]) => model.summarizeByKeyAndWindow(inputDStream)) + val outputBatches = runStreams[(Boolean, StatCounter)](ssc, numBatches, numBatches) + val outputCounts = outputBatches.flatten.map(_._2.count) + + // number of batches seen so far does not exceed testWindow, expect counts to continue growing + for (i <- 0 until testWindow) { + assert(outputCounts.drop(2 * i).take(2).forall(_ == (i + 1) * pointsPerBatch / 2)) + } + + // number of batches seen exceeds testWindow, expect counts to be constant + assert(outputCounts.drop(2 * (testWindow - 1)).forall(_ == testWindow * pointsPerBatch / 2)) + } + + + test("entries in peace period are dropped") { + // set parameters + val peacePeriod = 3 + val numBatches = 7 + val pointsPerBatch = 1000 + val meanA = -10 + val stdevA = 1 + val meanB = 10 + val stdevB = 1 + + val model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(peacePeriod) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[BinarySample]) => model.dropPeacePeriod(inputDStream)) + val outputBatches = runStreams[(Boolean, Double)](ssc, numBatches, numBatches) + + assert(outputBatches.flatten.length == (numBatches - peacePeriod) * pointsPerBatch) + } + + test("null hypothesis when only data from one group is present") { + // set parameters + val numBatches = 2 + val pointsPerBatch = 1000 + val meanA = 0 + val stdevA = 0.001 + val meanB = 0 + val stdevB = 0.001 + + val model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(0) + + val input = generateTestData(numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + .map(batch => batch.filter(_.isExperiment)) // only keep one test group + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[BinarySample]) => model.registerStream(inputDStream)) + val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) + + assert(outputBatches.flatten.forall(result => (result.pValue - 1.0).abs < 0.001)) + } + + // Generate testing input with half of the entries in group A and half in group B + private def generateTestData( + numBatches: Int, + pointsPerBatch: Int, + meanA: Double, + stdevA: Double, + meanB: Double, + stdevB: Double, + seed: Int): (IndexedSeq[IndexedSeq[BinarySample]]) = { + val rand = new XORShiftRandom(seed) + val numTrues = pointsPerBatch / 2 + val data = (0 until numBatches).map { i => + (0 until numTrues).map { idx => BinarySample(true, meanA + stdevA * rand.nextGaussian())} ++ + (pointsPerBatch / 2 until pointsPerBatch).map { idx => + BinarySample(false, meanB + stdevB * rand.nextGaussian()) + } + } + + data + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala index 
aa60deb665ae..6e7a00347545 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala @@ -65,4 +65,19 @@ class MultivariateGaussianSuite extends SparkFunSuite with MLlibTestSparkContext assert(dist.pdf(x1) ~== 0.11254 absTol 1E-5) assert(dist.pdf(x2) ~== 0.068259 absTol 1E-5) } + + test("SPARK-11302") { + val x = Vectors.dense(629, 640, 1.7188, 618.19) + val mu = Vectors.dense( + 1055.3910505836575, 1070.489299610895, 1.39020554474708, 1040.5907503867697) + val sigma = Matrices.dense(4, 4, Array( + 166769.00466698944, 169336.6705268059, 12.820670788921873, 164243.93314092053, + 169336.6705268059, 172041.5670061245, 21.62590020524533, 166678.01075856484, + 12.820670788921873, 21.62590020524533, 0.872524191943962, 4.283255814732373, + 164243.93314092053, 166678.01075856484, 4.283255814732373, 161848.9196719207)) + val dist = new MultivariateGaussian(mu, sigma) + // Agrees with R's dmvnorm: 7.154782e-05 + assert(dist.pdf(x) ~== 7.154782224045512E-5 absTol 1E-9) + } + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 356d957f1590..bf8fe1acac2f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -64,7 +64,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { maxDepth = 2, numClasses = 2, maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) + categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2)) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) @@ -135,8 +135,6 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble) val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) assert(splits.length === 3) - assert(fakeMetadata.numSplits(0) === 3) - assert(fakeMetadata.numBins(0) === 4) // check returned splits are distinct assert(splits.distinct.length === splits.length) } @@ -151,8 +149,6 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble) val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) assert(splits.length === 2) - assert(fakeMetadata.numSplits(0) === 2) - assert(fakeMetadata.numBins(0) === 3) assert(splits(0) === 2.0) assert(splits(1) === 3.0) } @@ -167,8 +163,6 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble) val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) assert(splits.length === 1) - assert(fakeMetadata.numSplits(0) === 1) - assert(fakeMetadata.numBins(0) === 2) assert(splits(0) === 1.0) } } @@ -184,7 +178,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { maxDepth = 2, numClasses = 100, maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(metadata.isUnordered(featureIndex = 0)) @@ -243,7 +237,7 @@ class DecisionTreeSuite 
extends SparkFunSuite with MLlibTestSparkContext { maxDepth = 2, numClasses = 100, maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 10, 1-> 10)) + categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) // 2^(10-1) - 1 > 100, so categorical features will be ordered val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) @@ -427,7 +421,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { numClasses = 2, maxDepth = 2, maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) @@ -461,7 +455,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { Variance, maxDepth = 2, maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) @@ -490,7 +484,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { Variance, maxDepth = 2, maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) + categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2)) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) @@ -794,7 +788,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, - maxBins = 2, maxDepth = 2, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2), + maxBins = 2, maxDepth = 2, categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2), numClasses = 2, minInstancesPerNode = 2) val rootNode = DecisionTree.train(rdd, strategy).topNode diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala index 334bf3790fc7..3d3f80063f90 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala @@ -69,8 +69,8 @@ object EnsembleTestHelper { required: Double, metricName: String = "mse") { val predictions = input.map(x => model.predict(x.features)) - val errors = predictions.zip(input.map(_.label)).map { case (prediction, label) => - label - prediction + val errors = predictions.zip(input).map { case (prediction, point) => + point.label - prediction } val metric = metricName match { case "mse" => diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala index 49aff21fe791..14152cdd63bc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala @@ -19,12 +19,11 @@ package org.apache.spark.mllib.tree import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.tree.impurity.{EntropyAggregator, GiniAggregator} -import org.apache.spark.mllib.util.MLlibTestSparkContext /** * Test suites for [[GiniAggregator]] and [[EntropyAggregator]]. 
*/ -class ImpuritySuite extends SparkFunSuite with MLlibTestSparkContext { +class ImpuritySuite extends SparkFunSuite { test("Gini impurity does not support negative labels") { val gini = new GiniAggregator(2) intercept[IllegalArgumentException] { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala index 5d1796ef6572..378139593b26 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala @@ -32,11 +32,14 @@ trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite => .setMaster("local[2]") .setAppName("MLlibUnitTest") sc = new SparkContext(conf) + SQLContext.clearActive() sqlContext = new SQLContext(sc) + SQLContext.setActive(sqlContext) } override def afterAll() { sqlContext = null + SQLContext.clearActive() if (sc != null) { sc.stop() } diff --git a/network/common/pom.xml b/network/common/pom.xml index 7dc3068ab8cb..9af6cc5e925f 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml @@ -48,6 +48,10 @@ slf4j-api provided
      + + com.google.code.findbugs + jsr305 + - - junit - junit - test - - - com.novocode - junit-interface - test - log4j log4j test + + org.apache.spark + spark-test-tags_${scala.binary.version} + org.mockito mockito-core diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java index b8d073fa16b4..238710d17249 100644 --- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java @@ -39,6 +39,7 @@ import org.apache.spark.network.server.TransportServerBootstrap; import org.apache.spark.network.util.NettyUtils; import org.apache.spark.network.util.TransportConf; +import org.apache.spark.network.util.TransportFrameDecoder; /** * Contains the context to create a {@link TransportServer}, {@link TransportClientFactory}, and to @@ -58,15 +59,24 @@ public class TransportContext { private final TransportConf conf; private final RpcHandler rpcHandler; + private final boolean closeIdleConnections; private final MessageEncoder encoder; private final MessageDecoder decoder; public TransportContext(TransportConf conf, RpcHandler rpcHandler) { + this(conf, rpcHandler, false); + } + + public TransportContext( + TransportConf conf, + RpcHandler rpcHandler, + boolean closeIdleConnections) { this.conf = conf; this.rpcHandler = rpcHandler; this.encoder = new MessageEncoder(); this.decoder = new MessageDecoder(); + this.closeIdleConnections = closeIdleConnections; } /** @@ -84,7 +94,13 @@ public TransportClientFactory createClientFactory() { /** Create a server which will attempt to bind to a specific port. */ public TransportServer createServer(int port, List bootstraps) { - return new TransportServer(this, port, rpcHandler, bootstraps); + return new TransportServer(this, null, port, rpcHandler, bootstraps); + } + + /** Create a server which will attempt to bind to a specific host and port. */ + public TransportServer createServer( + String host, int port, List bootstraps) { + return new TransportServer(this, host, port, rpcHandler, bootstraps); } /** Creates a new server, binding to any available ephemeral port. 
*/ @@ -119,7 +135,7 @@ public TransportChannelHandler initializePipeline( TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler); channel.pipeline() .addLast("encoder", encoder) - .addLast("frameDecoder", NettyUtils.createFrameDecoder()) + .addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder()) .addLast("decoder", decoder) .addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000)) // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this @@ -143,7 +159,7 @@ private TransportChannelHandler createChannelHandler(Channel channel, RpcHandler TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client, rpcHandler); return new TransportChannelHandler(client, responseHandler, requestHandler, - conf.connectionTimeoutMs()); + conf.connectionTimeoutMs(), closeIdleConnections); } public TransportConf getConf() { return conf; } diff --git a/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java b/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java index 6ec960d79542..47e93f9846fa 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java +++ b/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java @@ -17,13 +17,15 @@ package org.apache.spark.network.client; +import java.nio.ByteBuffer; + /** * Callback for the result of a single RPC. This will be invoked once with either success or * failure. */ public interface RpcResponseCallback { /** Successful serialized result from server. */ - void onSuccess(byte[] response); + void onSuccess(ByteBuffer response); /** Exception either propagated from server or raised on client side. */ void onFailure(Throwable e); diff --git a/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java b/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java new file mode 100644 index 000000000000..51d34cac6e63 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +import java.io.IOException; +import java.nio.ByteBuffer; + +/** + * Callback for streaming data. Stream data will be offered to the {@link onData(String, ByteBuffer)} + * method as it arrives. Once all the stream data is received, {@link onComplete(String)} will be + * called. + *
      + * The network library guarantees that a single thread will call these methods at a time, but + * different call may be made by different threads. + */ +public interface StreamCallback { + /** Called upon receipt of stream data. */ + void onData(String streamId, ByteBuffer buf) throws IOException; + + /** Called when all data from the stream has been received. */ + void onComplete(String streamId) throws IOException; + + /** Called if there's an error reading data from the stream. */ + void onFailure(String streamId, Throwable cause) throws IOException; +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java b/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java new file mode 100644 index 000000000000..88ba3ccebdf2 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +import java.nio.ByteBuffer; +import java.nio.channels.ClosedChannelException; + +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.util.TransportFrameDecoder; + +/** + * An interceptor that is registered with the frame decoder to feed stream data to a + * callback. + */ +class StreamInterceptor implements TransportFrameDecoder.Interceptor { + + private final TransportResponseHandler handler; + private final String streamId; + private final long byteCount; + private final StreamCallback callback; + + private volatile long bytesRead; + + StreamInterceptor( + TransportResponseHandler handler, + String streamId, + long byteCount, + StreamCallback callback) { + this.handler = handler; + this.streamId = streamId; + this.byteCount = byteCount; + this.callback = callback; + this.bytesRead = 0; + } + + @Override + public void exceptionCaught(Throwable cause) throws Exception { + handler.deactivateStream(); + callback.onFailure(streamId, cause); + } + + @Override + public void channelInactive() throws Exception { + handler.deactivateStream(); + callback.onFailure(streamId, new ClosedChannelException()); + } + + @Override + public boolean handle(ByteBuf buf) throws Exception { + int toRead = (int) Math.min(buf.readableBytes(), byteCount - bytesRead); + ByteBuffer nioBuffer = buf.readSlice(toRead).nioBuffer(); + + int available = nioBuffer.remaining(); + callback.onData(streamId, nioBuffer); + bytesRead += available; + if (bytesRead > byteCount) { + RuntimeException re = new IllegalStateException(String.format( + "Read too many bytes? 
Expected %d, but read %d.", byteCount, bytesRead)); + callback.onFailure(streamId, re); + handler.deactivateStream(); + throw re; + } else if (bytesRead == byteCount) { + handler.deactivateStream(); + callback.onComplete(streamId); + } + + return bytesRead != byteCount; + } + +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java index e8e7f06247d3..c49ca4d5ee92 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -20,10 +20,13 @@ import java.io.Closeable; import java.io.IOException; import java.net.SocketAddress; +import java.nio.ByteBuffer; import java.util.UUID; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; +import javax.annotation.Nullable; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Objects; import com.google.common.base.Preconditions; import com.google.common.base.Throwables; @@ -34,9 +37,12 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.protocol.ChunkFetchRequest; +import org.apache.spark.network.protocol.OneWayMessage; import org.apache.spark.network.protocol.RpcRequest; import org.apache.spark.network.protocol.StreamChunkId; +import org.apache.spark.network.protocol.StreamRequest; import org.apache.spark.network.util.NettyUtils; /** @@ -70,20 +76,46 @@ public class TransportClient implements Closeable { private final Channel channel; private final TransportResponseHandler handler; + @Nullable private String clientId; + private volatile boolean timedOut; public TransportClient(Channel channel, TransportResponseHandler handler) { this.channel = Preconditions.checkNotNull(channel); this.handler = Preconditions.checkNotNull(handler); + this.timedOut = false; + } + + public Channel getChannel() { + return channel; } public boolean isActive() { - return channel.isOpen() || channel.isActive(); + return !timedOut && (channel.isOpen() || channel.isActive()); } public SocketAddress getSocketAddress() { return channel.remoteAddress(); } + /** + * Returns the ID used by the client to authenticate itself when authentication is enabled. + * + * @return The client ID, or null if authentication is disabled. + */ + public String getClientId() { + return clientId; + } + + /** + * Sets the authenticated client ID. This is meant to be used by the authentication layer. + * + * Trying to set a different client ID after it's been set will result in an exception. + */ + public void setClientId(String id) { + Preconditions.checkState(clientId == null, "Client ID has already been set."); + this.clientId = id; + } + /** * Requests a single chunk from the remote side, from the pre-negotiated streamId. * @@ -134,11 +166,55 @@ public void operationComplete(ChannelFuture future) throws Exception { }); } + /** + * Request to stream the data with the given stream ID from the remote end. + * + * @param streamId The stream to fetch. + * @param callback Object to call with the stream data. 
+ */ + public void stream(final String streamId, final StreamCallback callback) { + final String serverAddr = NettyUtils.getRemoteAddress(channel); + final long startTime = System.currentTimeMillis(); + logger.debug("Sending stream request for {} to {}", streamId, serverAddr); + + // Need to synchronize here so that the callback is added to the queue and the RPC is + // written to the socket atomically, so that callbacks are called in the right order + // when responses arrive. + synchronized (this) { + handler.addStreamCallback(callback); + channel.writeAndFlush(new StreamRequest(streamId)).addListener( + new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (future.isSuccess()) { + long timeTaken = System.currentTimeMillis() - startTime; + logger.trace("Sending request for {} to {} took {} ms", streamId, serverAddr, + timeTaken); + } else { + String errorMsg = String.format("Failed to send request for %s to %s: %s", streamId, + serverAddr, future.cause()); + logger.error(errorMsg, future.cause()); + channel.close(); + try { + callback.onFailure(streamId, new IOException(errorMsg, future.cause())); + } catch (Exception e) { + logger.error("Uncaught exception in RPC response callback handler!", e); + } + } + } + }); + } + } + /** * Sends an opaque message to the RpcHandler on the server-side. The callback will be invoked * with the server's response or upon any failure. + * + * @param message The message to send. + * @param callback Callback to handle the RPC's reply. + * @return The RPC's id. */ - public void sendRpc(byte[] message, final RpcResponseCallback callback) { + public long sendRpc(ByteBuffer message, final RpcResponseCallback callback) { final String serverAddr = NettyUtils.getRemoteAddress(channel); final long startTime = System.currentTimeMillis(); logger.trace("Sending RPC to {}", serverAddr); @@ -146,7 +222,7 @@ public void sendRpc(byte[] message, final RpcResponseCallback callback) { final long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits()); handler.addRpcRequest(requestId, callback); - channel.writeAndFlush(new RpcRequest(requestId, message)).addListener( + channel.writeAndFlush(new RpcRequest(requestId, new NioManagedBuffer(message))).addListener( new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture future) throws Exception { @@ -167,18 +243,20 @@ public void operationComplete(ChannelFuture future) throws Exception { } } }); + + return requestId; } /** * Synchronously sends an opaque message to the RpcHandler on the server-side, waiting for up to * a specified timeout for a response. */ - public byte[] sendRpcSync(byte[] message, long timeoutMs) { - final SettableFuture result = SettableFuture.create(); + public ByteBuffer sendRpcSync(ByteBuffer message, long timeoutMs) { + final SettableFuture result = SettableFuture.create(); sendRpc(message, new RpcResponseCallback() { @Override - public void onSuccess(byte[] response) { + public void onSuccess(ByteBuffer response) { result.set(response); } @@ -197,6 +275,35 @@ public void onFailure(Throwable e) { } } + /** + * Sends an opaque message to the RpcHandler on the server-side. No reply is expected for the + * message, and no delivery guarantees are made. + * + * @param message The message to send. + */ + public void send(ByteBuffer message) { + channel.writeAndFlush(new OneWayMessage(new NioManagedBuffer(message))); + } + + /** + * Removes any state associated with the given RPC. 
+ * + * @param requestId The RPC id returned by {@link #sendRpc(byte[], RpcResponseCallback)}. + */ + public void removeRpcRequest(long requestId) { + handler.removeRpcRequest(requestId); + } + + /** Mark this channel as having timed out. */ + public void timeOut() { + this.timedOut = true; + } + + @VisibleForTesting + public TransportResponseHandler getHandler() { + return handler; + } + @Override public void close() { // close is a local operation and should finish with milliseconds; timeout just to be safe @@ -207,6 +314,7 @@ public void close() { public String toString() { return Objects.toStringHelper(this) .add("remoteAdress", channel.remoteAddress()) + .add("clientId", clientId) .add("isActive", isActive()) .toString(); } diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 4952ffb44bb8..61bafc838004 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -136,8 +136,19 @@ public TransportClient createClient(String remoteHost, int remotePort) throws IO TransportClient cachedClient = clientPool.clients[clientIndex]; if (cachedClient != null && cachedClient.isActive()) { - logger.trace("Returning cached connection to {}: {}", address, cachedClient); - return cachedClient; + // Make sure that the channel will not timeout by updating the last use time of the + // handler. Then check that the client is still alive, in case it timed out before + // this code was able to update things. + TransportChannelHandler handler = cachedClient.getChannel().pipeline() + .get(TransportChannelHandler.class); + synchronized (handler) { + handler.getResponseHandler().updateTimeOfLastRequest(); + } + + if (cachedClient.isActive()) { + logger.trace("Returning cached connection to {}: {}", address, cachedClient); + return cachedClient; + } } // If we reach here, we don't have an existing connection open. Let's create a new one. @@ -158,6 +169,18 @@ public TransportClient createClient(String remoteHost, int remotePort) throws IO } } + /** + * Create a completely new {@link TransportClient} to the given remote host / port. + * This connection is not pooled. + * + * As with {@link #createClient(String, int)}, this method is blocking. + */ + public TransportClient createUnmanagedClient(String remoteHost, int remotePort) + throws IOException { + final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort); + return createClient(address); + } + /** Create a completely new {@link TransportClient} to the remote address. 
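A brief sketch contrasting the two factory entry points above. The host and port are placeholders, and the assumption that the caller closes an unmanaged client itself is an inference from it never entering the pool.

import java.io.IOException;

import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientFactory;

public class ClientFactoryExample {
  static void connect(TransportClientFactory clientFactory) throws IOException {
    // Pooled: may return a cached, still-active connection whose idle timer was just refreshed.
    TransportClient shared = clientFactory.createClient("shuffle-host", 7337);

    // Unpooled: a fresh connection the factory will not hand to anyone else;
    // presumably the caller closes it when done, since it is not tracked in the pool.
    TransportClient exclusive = clientFactory.createUnmanagedClient("shuffle-host", 7337);
    exclusive.close();
  }
}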
*/ private TransportClient createClient(InetSocketAddress address) throws IOException { logger.debug("Creating new connection to " + address); diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index 94fc21af5e60..23a8dba59344 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -19,9 +19,12 @@ import java.io.IOException; import java.util.Map; +import java.util.Queue; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicLong; +import com.google.common.annotations.VisibleForTesting; import io.netty.channel.Channel; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -32,8 +35,11 @@ import org.apache.spark.network.protocol.RpcFailure; import org.apache.spark.network.protocol.RpcResponse; import org.apache.spark.network.protocol.StreamChunkId; +import org.apache.spark.network.protocol.StreamFailure; +import org.apache.spark.network.protocol.StreamResponse; import org.apache.spark.network.server.MessageHandler; import org.apache.spark.network.util.NettyUtils; +import org.apache.spark.network.util.TransportFrameDecoder; /** * Handler that processes server responses, in response to requests issued from a @@ -50,6 +56,9 @@ public class TransportResponseHandler extends MessageHandler { private final Map outstandingRpcs; + private final Queue streamCallbacks; + private volatile boolean streamActive; + /** Records the time (in system nanoseconds) that the last fetch or RPC request was sent. */ private final AtomicLong timeOfLastRequestNs; @@ -57,11 +66,12 @@ public TransportResponseHandler(Channel channel) { this.channel = channel; this.outstandingFetches = new ConcurrentHashMap(); this.outstandingRpcs = new ConcurrentHashMap(); + this.streamCallbacks = new ConcurrentLinkedQueue(); this.timeOfLastRequestNs = new AtomicLong(0); } public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) { - timeOfLastRequestNs.set(System.nanoTime()); + updateTimeOfLastRequest(); outstandingFetches.put(streamChunkId, callback); } @@ -70,7 +80,7 @@ public void removeFetchRequest(StreamChunkId streamChunkId) { } public void addRpcRequest(long requestId, RpcResponseCallback callback) { - timeOfLastRequestNs.set(System.nanoTime()); + updateTimeOfLastRequest(); outstandingRpcs.put(requestId, callback); } @@ -78,6 +88,16 @@ public void removeRpcRequest(long requestId) { outstandingRpcs.remove(requestId); } + public void addStreamCallback(StreamCallback callback) { + timeOfLastRequestNs.set(System.nanoTime()); + streamCallbacks.offer(callback); + } + + @VisibleForTesting + public void deactivateStream() { + streamActive = false; + } + /** * Fire the failure callback for all outstanding requests. This is called when we have an * uncaught exception or pre-mature connection termination. 
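To illustrate the client side of the stream path handled above, a callback might buffer the streamed bytes as they arrive. This is only a sketch: the onData(streamId, buf) signature is assumed (it is not shown in this diff), while onComplete/onFailure match the calls made by the handler; the stream id in the usage comment is hypothetical and must be agreed on out of band.

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;

import org.apache.spark.network.client.StreamCallback;

/** Illustrative only: buffers an entire streamed payload in memory. */
public class CollectingStreamCallback implements StreamCallback {
  private final ByteArrayOutputStream out = new ByteArrayOutputStream();

  @Override
  public void onData(String streamId, ByteBuffer buf) throws IOException {
    // Invoked as data arrives; drain whatever is currently readable.
    byte[] chunk = new byte[buf.remaining()];
    buf.get(chunk);
    out.write(chunk);
  }

  @Override
  public void onComplete(String streamId) throws IOException {
    // All bytes announced by the StreamResponse have arrived; out.toByteArray() holds them.
  }

  @Override
  public void onFailure(String streamId, Throwable cause) throws IOException {
    out.reset();
  }
}

// Usage (hypothetical stream id): client.stream("someStreamId", new CollectingStreamCallback());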
@@ -116,7 +136,7 @@ public void exceptionCaught(Throwable cause) { } @Override - public void handle(ResponseMessage message) { + public void handle(ResponseMessage message) throws Exception { String remoteAddress = NettyUtils.getRemoteAddress(channel); if (message instanceof ChunkFetchSuccess) { ChunkFetchSuccess resp = (ChunkFetchSuccess) message; @@ -124,11 +144,11 @@ public void handle(ResponseMessage message) { if (listener == null) { logger.warn("Ignoring response for block {} from {} since it is not outstanding", resp.streamChunkId, remoteAddress); - resp.buffer.release(); + resp.body().release(); } else { outstandingFetches.remove(resp.streamChunkId); - listener.onSuccess(resp.streamChunkId.chunkIndex, resp.buffer); - resp.buffer.release(); + listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body()); + resp.body().release(); } } else if (message instanceof ChunkFetchFailure) { ChunkFetchFailure resp = (ChunkFetchFailure) message; @@ -146,10 +166,14 @@ public void handle(ResponseMessage message) { RpcResponseCallback listener = outstandingRpcs.get(resp.requestId); if (listener == null) { logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding", - resp.requestId, remoteAddress, resp.response.length); + resp.requestId, remoteAddress, resp.body().size()); } else { outstandingRpcs.remove(resp.requestId); - listener.onSuccess(resp.response); + try { + listener.onSuccess(resp.body().nioByteBuffer()); + } finally { + resp.body().release(); + } } } else if (message instanceof RpcFailure) { RpcFailure resp = (RpcFailure) message; @@ -161,6 +185,44 @@ public void handle(ResponseMessage message) { outstandingRpcs.remove(resp.requestId); listener.onFailure(new RuntimeException(resp.errorString)); } + } else if (message instanceof StreamResponse) { + StreamResponse resp = (StreamResponse) message; + StreamCallback callback = streamCallbacks.poll(); + if (callback != null) { + if (resp.byteCount > 0) { + StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount, + callback); + try { + TransportFrameDecoder frameDecoder = (TransportFrameDecoder) + channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); + frameDecoder.setInterceptor(interceptor); + streamActive = true; + } catch (Exception e) { + logger.error("Error installing stream handler.", e); + deactivateStream(); + } + } else { + try { + callback.onComplete(resp.streamId); + } catch (Exception e) { + logger.warn("Error in stream handler onComplete().", e); + } + } + } else { + logger.error("Could not find callback for StreamResponse."); + } + } else if (message instanceof StreamFailure) { + StreamFailure resp = (StreamFailure) message; + StreamCallback callback = streamCallbacks.poll(); + if (callback != null) { + try { + callback.onFailure(resp.streamId, new RuntimeException(resp.error)); + } catch (IOException ioe) { + logger.warn("Error in stream failure handler.", ioe); + } + } else { + logger.warn("Stream failure with unknown callback: {}", resp.error); + } } else { throw new IllegalStateException("Unknown response type: " + message.type()); } @@ -168,11 +230,18 @@ public void handle(ResponseMessage message) { /** Returns total number of outstanding requests (fetch requests + rpcs) */ public int numOutstandingRequests() { - return outstandingFetches.size() + outstandingRpcs.size(); + return outstandingFetches.size() + outstandingRpcs.size() + streamCallbacks.size() + + (streamActive ? 
1 : 0); } /** Returns the time in nanoseconds of when the last request was sent out. */ public long getTimeOfLastRequestNs() { return timeOfLastRequestNs.get(); } + + /** Updates the time of the last request to the current system time. */ + public void updateTimeOfLastRequest() { + timeOfLastRequestNs.set(System.nanoTime()); + } + } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java new file mode 100644 index 000000000000..2924218c2f08 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Objects; + +import org.apache.spark.network.buffer.ManagedBuffer; + +/** + * Abstract class for messages which optionally contain a body kept in a separate buffer. + */ +public abstract class AbstractMessage implements Message { + private final ManagedBuffer body; + private final boolean isBodyInFrame; + + protected AbstractMessage() { + this(null, false); + } + + protected AbstractMessage(ManagedBuffer body, boolean isBodyInFrame) { + this.body = body; + this.isBodyInFrame = isBodyInFrame; + } + + @Override + public ManagedBuffer body() { + return body; + } + + @Override + public boolean isBodyInFrame() { + return isBodyInFrame; + } + + protected boolean equals(AbstractMessage other) { + return isBodyInFrame == other.isBodyInFrame && Objects.equal(body, other.body); + } + +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java new file mode 100644 index 000000000000..c362c92fc4f5 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import org.apache.spark.network.buffer.ManagedBuffer; + +/** + * Abstract class for response messages. + */ +public abstract class AbstractResponseMessage extends AbstractMessage implements ResponseMessage { + + protected AbstractResponseMessage(ManagedBuffer body, boolean isBodyInFrame) { + super(body, isBodyInFrame); + } + + public abstract ResponseMessage createFailureResponse(String error); +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java index f76bb49e874f..7b28a9a96948 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java @@ -23,7 +23,7 @@ /** * Response to {@link ChunkFetchRequest} when there is an error fetching the chunk. */ -public final class ChunkFetchFailure implements ResponseMessage { +public final class ChunkFetchFailure extends AbstractMessage implements ResponseMessage { public final StreamChunkId streamChunkId; public final String errorString; @@ -52,6 +52,11 @@ public static ChunkFetchFailure decode(ByteBuf buf) { return new ChunkFetchFailure(streamChunkId, errorString); } + @Override + public int hashCode() { + return Objects.hashCode(streamChunkId, errorString); + } + @Override public boolean equals(Object other) { if (other instanceof ChunkFetchFailure) { diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java index 980947cf13f6..26d063feb5fe 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java @@ -24,7 +24,7 @@ * Request to fetch a sequence of a single chunk of a stream. This will correspond to a single * {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure). */ -public final class ChunkFetchRequest implements RequestMessage { +public final class ChunkFetchRequest extends AbstractMessage implements RequestMessage { public final StreamChunkId streamChunkId; public ChunkFetchRequest(StreamChunkId streamChunkId) { @@ -48,6 +48,11 @@ public static ChunkFetchRequest decode(ByteBuf buf) { return new ChunkFetchRequest(StreamChunkId.decode(buf)); } + @Override + public int hashCode() { + return streamChunkId.hashCode(); + } + @Override public boolean equals(Object other) { if (other instanceof ChunkFetchRequest) { diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java index ff4936470c69..94c2ac9b20e4 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java @@ -30,13 +30,12 @@ * may be written by Netty in a more efficient manner (i.e., zero-copy write). * Similarly, the client-side decoding will reuse the Netty ByteBuf as the buffer. 
*/ -public final class ChunkFetchSuccess implements ResponseMessage { +public final class ChunkFetchSuccess extends AbstractResponseMessage { public final StreamChunkId streamChunkId; - public final ManagedBuffer buffer; public ChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer) { + super(buffer, true); this.streamChunkId = streamChunkId; - this.buffer = buffer; } @Override @@ -53,6 +52,11 @@ public void encode(ByteBuf buf) { streamChunkId.encode(buf); } + @Override + public ResponseMessage createFailureResponse(String error) { + return new ChunkFetchFailure(streamChunkId, error); + } + /** Decoding uses the given ByteBuf as our data, and will retain() it. */ public static ChunkFetchSuccess decode(ByteBuf buf) { StreamChunkId streamChunkId = StreamChunkId.decode(buf); @@ -61,11 +65,16 @@ public static ChunkFetchSuccess decode(ByteBuf buf) { return new ChunkFetchSuccess(streamChunkId, managedBuf); } + @Override + public int hashCode() { + return Objects.hashCode(streamChunkId, body()); + } + @Override public boolean equals(Object other) { if (other instanceof ChunkFetchSuccess) { ChunkFetchSuccess o = (ChunkFetchSuccess) other; - return streamChunkId.equals(o.streamChunkId) && buffer.equals(o.buffer); + return streamChunkId.equals(o.streamChunkId) && super.equals(o); } return false; } @@ -74,7 +83,7 @@ public boolean equals(Object other) { public String toString() { return Objects.toStringHelper(this) .add("streamChunkId", streamChunkId) - .add("buffer", buffer) + .add("buffer", body()) .toString(); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java index d568370125fd..66f5b8b3a59c 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java @@ -19,15 +19,25 @@ import io.netty.buffer.ByteBuf; +import org.apache.spark.network.buffer.ManagedBuffer; + /** An on-the-wire transmittable message. */ public interface Message extends Encodable { /** Used to identify this request type. */ Type type(); + /** An optional body for the message. */ + ManagedBuffer body(); + + /** Whether to include the body of the message in the same frame as the message. */ + boolean isBodyInFrame(); + /** Preceding every serialized Message is its type, which allows us to deserialize it. 
*/ public static enum Type implements Encodable { ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2), - RpcRequest(3), RpcResponse(4), RpcFailure(5); + RpcRequest(3), RpcResponse(4), RpcFailure(5), + StreamRequest(6), StreamResponse(7), StreamFailure(8), + OneWayMessage(9), User(-1); private final byte id; @@ -51,6 +61,11 @@ public static Type decode(ByteBuf buf) { case 3: return RpcRequest; case 4: return RpcResponse; case 5: return RpcFailure; + case 6: return StreamRequest; + case 7: return StreamResponse; + case 8: return StreamFailure; + case 9: return OneWayMessage; + case -1: throw new IllegalArgumentException("User type messages cannot be decoded."); default: throw new IllegalArgumentException("Unknown message type: " + id); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java index 81f8d7f96350..074780f2b95c 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java @@ -63,6 +63,18 @@ private Message decode(Message.Type msgType, ByteBuf in) { case RpcFailure: return RpcFailure.decode(in); + case OneWayMessage: + return OneWayMessage.decode(in); + + case StreamRequest: + return StreamRequest.decode(in); + + case StreamResponse: + return StreamResponse.decode(in); + + case StreamFailure: + return StreamFailure.decode(in); + default: throw new IllegalArgumentException("Unexpected message type: " + msgType); } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java index 0f999f5dfe8d..abca22347b78 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java @@ -42,30 +42,38 @@ public final class MessageEncoder extends MessageToMessageEncoder { * data to 'out', in order to enable zero-copy transfer. */ @Override - public void encode(ChannelHandlerContext ctx, Message in, List out) { + public void encode(ChannelHandlerContext ctx, Message in, List out) throws Exception { Object body = null; long bodyLength = 0; + boolean isBodyInFrame = false; - // Only ChunkFetchSuccesses have data besides the header. - // The body is used in order to enable zero-copy transfer for the payload. - if (in instanceof ChunkFetchSuccess) { - ChunkFetchSuccess resp = (ChunkFetchSuccess) in; + // If the message has a body, take it out to enable zero-copy transfer for the payload. + if (in.body() != null) { try { - bodyLength = resp.buffer.size(); - body = resp.buffer.convertToNetty(); + bodyLength = in.body().size(); + body = in.body().convertToNetty(); + isBodyInFrame = in.isBodyInFrame(); } catch (Exception e) { - // Re-encode this message as BlockFetchFailure. - logger.error(String.format("Error opening block %s for client %s", - resp.streamChunkId, ctx.channel().remoteAddress()), e); - encode(ctx, new ChunkFetchFailure(resp.streamChunkId, e.getMessage()), out); + if (in instanceof AbstractResponseMessage) { + AbstractResponseMessage resp = (AbstractResponseMessage) in; + // Re-encode this message as a failure response. + String error = e.getMessage() != null ? 
e.getMessage() : "null"; + logger.error(String.format("Error processing %s for client %s", + in, ctx.channel().remoteAddress()), e); + encode(ctx, resp.createFailureResponse(error), out); + } else { + throw e; + } return; } } Message.Type msgType = in.type(); - // All messages have the frame length, message type, and message itself. + // All messages have the frame length, message type, and message itself. The frame length + // may optionally include the length of the body data, depending on what message is being + // sent. int headerLength = 8 + msgType.encodedLength() + in.encodedLength(); - long frameLength = headerLength + bodyLength; + long frameLength = headerLength + (isBodyInFrame ? bodyLength : 0); ByteBuf header = ctx.alloc().heapBuffer(headerLength); header.writeLong(frameLength); msgType.encode(header); diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java new file mode 100644 index 000000000000..efe0470f3587 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; + +/** + * A RPC that does not expect a reply, which is handled by a remote + * {@link org.apache.spark.network.server.RpcHandler}. + */ +public final class OneWayMessage extends AbstractMessage implements RequestMessage { + + public OneWayMessage(ManagedBuffer body) { + super(body, true); + } + + @Override + public Type type() { return Type.OneWayMessage; } + + @Override + public int encodedLength() { + // The integer (a.k.a. the body size) is not really used, since that information is already + // encoded in the frame length. But this maintains backwards compatibility with versions of + // RpcRequest that use Encoders.ByteArrays. + return 4; + } + + @Override + public void encode(ByteBuf buf) { + // See comment in encodedLength(). + buf.writeInt((int) body().size()); + } + + public static OneWayMessage decode(ByteBuf buf) { + // See comment in encodedLength(). 
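To make the new framing concrete, the arithmetic below walks through what the encoder above writes for a OneWayMessage carrying an 11-byte payload; the payload size is arbitrary, and the one-byte type id is inferred from the enum's byte-valued id.

// Frame layout on the wire: [frameLength: long][type: byte][message header][body, if in frame]
long bodyLength = 11;                          // body().size() of the ManagedBuffer
int headerLength = 8 + 1 + 4;                  // frame-length long + type byte + OneWayMessage.encodedLength()
long frameLength = headerLength + bodyLength;  // 24, since OneWayMessage reports isBodyInFrame() == true
// The 4-byte size int inside the header is vestigial: decode() skips it, because the body's
// extent is already implied by the frame length. A StreamResponse, by contrast, reports
// isBodyInFrame() == false, so its frame covers only the header and the data follows separately.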
+ buf.readInt(); + return new OneWayMessage(new NettyManagedBuffer(buf.retain())); + } + + @Override + public int hashCode() { + return Objects.hashCode(body()); + } + + @Override + public boolean equals(Object other) { + if (other instanceof OneWayMessage) { + OneWayMessage o = (OneWayMessage) other; + return super.equals(o); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("body", body()) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java index 6b991375fc48..a76624ef5dc9 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java @@ -21,7 +21,7 @@ import io.netty.buffer.ByteBuf; /** Response to {@link RpcRequest} for a failed RPC. */ -public final class RpcFailure implements ResponseMessage { +public final class RpcFailure extends AbstractMessage implements ResponseMessage { public final long requestId; public final String errorString; @@ -50,6 +50,11 @@ public static RpcFailure decode(ByteBuf buf) { return new RpcFailure(requestId, errorString); } + @Override + public int hashCode() { + return Objects.hashCode(requestId, errorString); + } + @Override public boolean equals(Object other) { if (other instanceof RpcFailure) { diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java index cdee0b0e0316..96213794a801 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java @@ -17,26 +17,25 @@ package org.apache.spark.network.protocol; -import java.util.Arrays; - import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; /** * A generic RPC which is handled by a remote {@link org.apache.spark.network.server.RpcHandler}. * This will correspond to a single * {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure). */ -public final class RpcRequest implements RequestMessage { +public final class RpcRequest extends AbstractMessage implements RequestMessage { /** Used to link an RPC request with its response. */ public final long requestId; - /** Serialized message to send to remote RpcHandler. */ - public final byte[] message; - - public RpcRequest(long requestId, byte[] message) { + public RpcRequest(long requestId, ManagedBuffer message) { + super(message, true); this.requestId = requestId; - this.message = message; } @Override @@ -44,26 +43,36 @@ public RpcRequest(long requestId, byte[] message) { @Override public int encodedLength() { - return 8 + Encoders.ByteArrays.encodedLength(message); + // The integer (a.k.a. the body size) is not really used, since that information is already + // encoded in the frame length. But this maintains backwards compatibility with versions of + // RpcRequest that use Encoders.ByteArrays. + return 8 + 4; } @Override public void encode(ByteBuf buf) { buf.writeLong(requestId); - Encoders.ByteArrays.encode(buf, message); + // See comment in encodedLength(). 
+ buf.writeInt((int) body().size()); } public static RpcRequest decode(ByteBuf buf) { long requestId = buf.readLong(); - byte[] message = Encoders.ByteArrays.decode(buf); - return new RpcRequest(requestId, message); + // See comment in encodedLength(). + buf.readInt(); + return new RpcRequest(requestId, new NettyManagedBuffer(buf.retain())); + } + + @Override + public int hashCode() { + return Objects.hashCode(requestId, body()); } @Override public boolean equals(Object other) { if (other instanceof RpcRequest) { RpcRequest o = (RpcRequest) other; - return requestId == o.requestId && Arrays.equals(message, o.message); + return requestId == o.requestId && super.equals(o); } return false; } @@ -72,7 +81,7 @@ public boolean equals(Object other) { public String toString() { return Objects.toStringHelper(this) .add("requestId", requestId) - .add("message", message) + .add("body", body()) .toString(); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java index 0a62e09a8115..bae866e14a1e 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java @@ -17,44 +17,62 @@ package org.apache.spark.network.protocol; -import java.util.Arrays; - import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; /** Response to {@link RpcRequest} for a successful RPC. */ -public final class RpcResponse implements ResponseMessage { +public final class RpcResponse extends AbstractResponseMessage { public final long requestId; - public final byte[] response; - public RpcResponse(long requestId, byte[] response) { + public RpcResponse(long requestId, ManagedBuffer message) { + super(message, true); this.requestId = requestId; - this.response = response; } @Override public Type type() { return Type.RpcResponse; } @Override - public int encodedLength() { return 8 + Encoders.ByteArrays.encodedLength(response); } + public int encodedLength() { + // The integer (a.k.a. the body size) is not really used, since that information is already + // encoded in the frame length. But this maintains backwards compatibility with versions of + // RpcRequest that use Encoders.ByteArrays. + return 8 + 4; + } @Override public void encode(ByteBuf buf) { buf.writeLong(requestId); - Encoders.ByteArrays.encode(buf, response); + // See comment in encodedLength(). + buf.writeInt((int) body().size()); + } + + @Override + public ResponseMessage createFailureResponse(String error) { + return new RpcFailure(requestId, error); } public static RpcResponse decode(ByteBuf buf) { long requestId = buf.readLong(); - byte[] response = Encoders.ByteArrays.decode(buf); - return new RpcResponse(requestId, response); + // See comment in encodedLength(). 
+ buf.readInt(); + return new RpcResponse(requestId, new NettyManagedBuffer(buf.retain())); + } + + @Override + public int hashCode() { + return Objects.hashCode(requestId, body()); } @Override public boolean equals(Object other) { if (other instanceof RpcResponse) { RpcResponse o = (RpcResponse) other; - return requestId == o.requestId && Arrays.equals(response, o.response); + return requestId == o.requestId && super.equals(o); } return false; } @@ -63,7 +81,7 @@ public boolean equals(Object other) { public String toString() { return Objects.toStringHelper(this) .add("requestId", requestId) - .add("response", response) + .add("body", body()) .toString(); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java new file mode 100644 index 000000000000..26747ee55b4d --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; + +/** + * Message indicating an error when transferring a stream. 
+ */ +public final class StreamFailure extends AbstractMessage implements ResponseMessage { + public final String streamId; + public final String error; + + public StreamFailure(String streamId, String error) { + this.streamId = streamId; + this.error = error; + } + + @Override + public Type type() { return Type.StreamFailure; } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(streamId) + Encoders.Strings.encodedLength(error); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, streamId); + Encoders.Strings.encode(buf, error); + } + + public static StreamFailure decode(ByteBuf buf) { + String streamId = Encoders.Strings.decode(buf); + String error = Encoders.Strings.decode(buf); + return new StreamFailure(streamId, error); + } + + @Override + public int hashCode() { + return Objects.hashCode(streamId, error); + } + + @Override + public boolean equals(Object other) { + if (other instanceof StreamFailure) { + StreamFailure o = (StreamFailure) other; + return streamId.equals(o.streamId) && error.equals(o.error); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("streamId", streamId) + .add("error", error) + .toString(); + } + +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java new file mode 100644 index 000000000000..35af5a84ba6b --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; + +/** + * Request to stream data from the remote end. + *


      + * The stream ID is an arbitrary string that needs to be negotiated between the two endpoints before + * the data can be streamed. + */ +public final class StreamRequest extends AbstractMessage implements RequestMessage { + public final String streamId; + + public StreamRequest(String streamId) { + this.streamId = streamId; + } + + @Override + public Type type() { return Type.StreamRequest; } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(streamId); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, streamId); + } + + public static StreamRequest decode(ByteBuf buf) { + String streamId = Encoders.Strings.decode(buf); + return new StreamRequest(streamId); + } + + @Override + public int hashCode() { + return Objects.hashCode(streamId); + } + + @Override + public boolean equals(Object other) { + if (other instanceof StreamRequest) { + StreamRequest o = (StreamRequest) other; + return streamId.equals(o.streamId); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("streamId", streamId) + .toString(); + } + +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java new file mode 100644 index 000000000000..51b899930f72 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; + +/** + * Response to {@link StreamRequest} when the stream has been successfully opened. + *
      + * Note the message itself does not contain the stream data. That is written separately by the + * sender. The receiver is expected to set a temporary channel handler that will consume the + * number of bytes this message says the stream has. + */ +public final class StreamResponse extends AbstractResponseMessage { + public final String streamId; + public final long byteCount; + + public StreamResponse(String streamId, long byteCount, ManagedBuffer buffer) { + super(buffer, false); + this.streamId = streamId; + this.byteCount = byteCount; + } + + @Override + public Type type() { return Type.StreamResponse; } + + @Override + public int encodedLength() { + return 8 + Encoders.Strings.encodedLength(streamId); + } + + /** Encoding does NOT include 'buffer' itself. See {@link MessageEncoder}. */ + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, streamId); + buf.writeLong(byteCount); + } + + @Override + public ResponseMessage createFailureResponse(String error) { + return new StreamFailure(streamId, error); + } + + public static StreamResponse decode(ByteBuf buf) { + String streamId = Encoders.Strings.decode(buf); + long byteCount = buf.readLong(); + return new StreamResponse(streamId, byteCount, null); + } + + @Override + public int hashCode() { + return Objects.hashCode(byteCount, streamId, body()); + } + + @Override + public boolean equals(Object other) { + if (other instanceof StreamResponse) { + StreamResponse o = (StreamResponse) other; + return byteCount == o.byteCount && streamId.equals(o.streamId); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("streamId", streamId) + .add("byteCount", byteCount) + .add("body", body()) + .toString(); + } + +} diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java index 185ba2ef3bb1..68381037d689 100644 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java @@ -17,6 +17,8 @@ package org.apache.spark.network.sasl; +import java.io.IOException; +import java.nio.ByteBuffer; import javax.security.sasl.Sasl; import javax.security.sasl.SaslException; @@ -28,6 +30,7 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.TransportConf; /** @@ -70,13 +73,16 @@ public void doBootstrap(TransportClient client, Channel channel) { while (!saslClient.isComplete()) { SaslMessage msg = new SaslMessage(appId, payload); - ByteBuf buf = Unpooled.buffer(msg.encodedLength()); + ByteBuf buf = Unpooled.buffer(msg.encodedLength() + (int) msg.body().size()); msg.encode(buf); + buf.writeBytes(msg.body().nioByteBuffer()); - byte[] response = client.sendRpcSync(buf.array(), conf.saslRTTimeoutMs()); - payload = saslClient.response(response); + ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), conf.saslRTTimeoutMs()); + payload = saslClient.response(JavaUtils.bufferToArray(response)); } + client.setClientId(appId); + if (encrypt) { if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslClient.getNegotiatedProperty(Sasl.QOP))) { throw new RuntimeException( @@ -86,6 +92,8 @@ public void doBootstrap(TransportClient client, Channel channel) { saslClient = null; logger.debug("Channel {} 
configured for SASL encryption.", client); } + } catch (IOException ioe) { + throw new RuntimeException(ioe); } finally { if (saslClient != null) { try { diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java index cad76ab7aa54..e52b526f09c7 100644 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java @@ -18,38 +18,50 @@ package org.apache.spark.network.sasl; import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; -import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.buffer.NettyManagedBuffer; import org.apache.spark.network.protocol.Encoders; +import org.apache.spark.network.protocol.AbstractMessage; /** * Encodes a Sasl-related message which is attempting to authenticate using some credentials tagged * with the given appId. This appId allows a single SaslRpcHandler to multiplex different * applications which may be using different sets of credentials. */ -class SaslMessage implements Encodable { +class SaslMessage extends AbstractMessage { /** Serialization tag used to catch incorrect payloads. */ private static final byte TAG_BYTE = (byte) 0xEA; public final String appId; - public final byte[] payload; - public SaslMessage(String appId, byte[] payload) { + public SaslMessage(String appId, byte[] message) { + this(appId, Unpooled.wrappedBuffer(message)); + } + + public SaslMessage(String appId, ByteBuf message) { + super(new NettyManagedBuffer(message), true); this.appId = appId; - this.payload = payload; } + @Override + public Type type() { return Type.User; } + @Override public int encodedLength() { - return 1 + Encoders.Strings.encodedLength(appId) + Encoders.ByteArrays.encodedLength(payload); + // The integer (a.k.a. the body size) is not really used, since that information is already + // encoded in the frame length. But this maintains backwards compatibility with versions of + // RpcRequest that use Encoders.ByteArrays. + return 1 + Encoders.Strings.encodedLength(appId) + 4; } @Override public void encode(ByteBuf buf) { buf.writeByte(TAG_BYTE); Encoders.Strings.encode(buf, appId); - Encoders.ByteArrays.encode(buf, payload); + // See comment in encodedLength(). + buf.writeInt((int) body().size()); } public static SaslMessage decode(ByteBuf buf) { @@ -59,7 +71,8 @@ public static SaslMessage decode(ByteBuf buf) { } String appId = Encoders.Strings.decode(buf); - byte[] payload = Encoders.ByteArrays.decode(buf); - return new SaslMessage(appId, payload); + // See comment in encodedLength(). 
+ buf.readInt(); + return new SaslMessage(appId, buf.retain()); } } diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index be6165caf3c7..c215bd9d1504 100644 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -17,8 +17,11 @@ package org.apache.spark.network.sasl; +import java.io.IOException; +import java.nio.ByteBuffer; import javax.security.sasl.Sasl; +import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.Channel; import org.slf4j.Logger; @@ -28,6 +31,7 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.TransportConf; /** @@ -70,23 +74,36 @@ class SaslRpcHandler extends RpcHandler { } @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { if (isComplete) { // Authentication complete, delegate to base handler. delegate.receive(client, message, callback); return; } - SaslMessage saslMessage = SaslMessage.decode(Unpooled.wrappedBuffer(message)); + ByteBuf nettyBuf = Unpooled.wrappedBuffer(message); + SaslMessage saslMessage; + try { + saslMessage = SaslMessage.decode(nettyBuf); + } finally { + nettyBuf.release(); + } if (saslServer == null) { // First message in the handshake, setup the necessary state. + client.setClientId(saslMessage.appId); saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder, conf.saslServerAlwaysEncrypt()); } - byte[] response = saslServer.response(saslMessage.payload); - callback.onSuccess(response); + byte[] response; + try { + response = saslServer.response(JavaUtils.bufferToArray( + saslMessage.body().nioByteBuffer())); + } catch (IOException ioe) { + throw new RuntimeException(ioe); + } + callback.onSuccess(ByteBuffer.wrap(response)); // Setup encryption after the SASL response is sent, otherwise the client can't parse the // response. 
It's ok to change the channel pipeline here since we are processing an incoming @@ -107,6 +124,11 @@ public void receive(TransportClient client, byte[] message, RpcResponseCallback } } + @Override + public void receive(TransportClient client, ByteBuffer message) { + delegate.receive(client, message); + } + @Override public StreamManager getStreamManager() { return delegate.getStreamManager(); @@ -114,9 +136,18 @@ public StreamManager getStreamManager() { @Override public void connectionTerminated(TransportClient client) { - if (saslServer != null) { - saslServer.dispose(); + try { + delegate.connectionTerminated(client); + } finally { + if (saslServer != null) { + saslServer.dispose(); + } } } + @Override + public void exceptionCaught(Throwable cause, TransportClient client) { + delegate.exceptionCaught(cause, client); + } + } diff --git a/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java index b80c15106ecb..3843406b2740 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java @@ -26,7 +26,7 @@ */ public abstract class MessageHandler { /** Handles the receipt of a single message. */ - public abstract void handle(T message); + public abstract void handle(T message) throws Exception; /** Invoked when an exception was caught on the Channel. */ public abstract void exceptionCaught(Throwable cause); diff --git a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java index 1502b7489e86..6ed61da5c7ef 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java @@ -1,5 +1,3 @@ -package org.apache.spark.network.server; - /* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with @@ -17,6 +15,10 @@ * limitations under the License. 
*/ +package org.apache.spark.network.server; + +import java.nio.ByteBuffer; + import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; @@ -29,7 +31,7 @@ public NoOpRpcHandler() { } @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { throw new UnsupportedOperationException("Cannot handle messages"); } diff --git a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java index c95e64e8e2cd..e671854da1ca 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java +++ b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java @@ -24,13 +24,13 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; +import com.google.common.base.Preconditions; import io.netty.channel.Channel; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.buffer.ManagedBuffer; - -import com.google.common.base.Preconditions; +import org.apache.spark.network.client.TransportClient; /** * StreamManager which allows registration of an Iterator<ManagedBuffer>, which are individually @@ -44,6 +44,7 @@ public class OneForOneStreamManager extends StreamManager { /** State of a single stream. */ private static class StreamState { + final String appId; final Iterator buffers; // The channel associated to the stream @@ -53,7 +54,8 @@ private static class StreamState { // that the caller only requests each chunk one at a time, in order. int curChunk = 0; - StreamState(Iterator buffers) { + StreamState(String appId, Iterator buffers) { + this.appId = appId; this.buffers = Preconditions.checkNotNull(buffers); } } @@ -109,15 +111,34 @@ public void connectionTerminated(Channel channel) { } } + @Override + public void checkAuthorization(TransportClient client, long streamId) { + if (client.getClientId() != null) { + StreamState state = streams.get(streamId); + Preconditions.checkArgument(state != null, "Unknown stream ID."); + if (!client.getClientId().equals(state.appId)) { + throw new SecurityException(String.format( + "Client %s not authorized to read stream %d (app %s).", + client.getClientId(), + streamId, + state.appId)); + } + } + } + /** * Registers a stream of ManagedBuffers which are served as individual chunks one at a time to * callers. Each ManagedBuffer will be release()'d after it is transferred on the wire. If a * client connection is closed before the iterator is fully drained, then the remaining buffers * will all be release()'d. + * + * If an app ID is provided, only callers who've authenticated with the given app ID will be + * allowed to fetch from this stream. 
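A usage sketch of the app-scoped registration described above; the buffer contents and the "app-1" id are made up for illustration.

import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.buffer.NioManagedBuffer;
import org.apache.spark.network.server.OneForOneStreamManager;

public class StreamRegistrationExample {
  static long registerChunks(OneForOneStreamManager streamManager) {
    List<ManagedBuffer> chunks = Arrays.<ManagedBuffer>asList(
        new NioManagedBuffer(ByteBuffer.wrap(new byte[] {1, 2, 3})),
        new NioManagedBuffer(ByteBuffer.wrap(new byte[] {4, 5, 6})));
    // Tie the stream to the SASL client id; checkAuthorization() then rejects any
    // authenticated client whose id differs, while unauthenticated clients are not checked.
    return streamManager.registerStream("app-1", chunks.iterator());
  }
}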
*/ - public long registerStream(Iterator buffers) { + public long registerStream(String appId, Iterator buffers) { long myStreamId = nextStreamId.getAndIncrement(); - streams.put(myStreamId, new StreamState(buffers)); + streams.put(myStreamId, new StreamState(appId, buffers)); return myStreamId; } + } diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java index 2ba92a40f8b0..ee1c68369947 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -17,6 +17,11 @@ package org.apache.spark.network.server; +import java.nio.ByteBuffer; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; @@ -24,6 +29,9 @@ * Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.TransportClient}s. */ public abstract class RpcHandler { + + private static final RpcResponseCallback ONE_WAY_CALLBACK = new OneWayRpcCallback(); + /** * Receive a single RPC message. Any exception thrown while in this method will be sent back to * the client in string form as a standard RPC failure. @@ -38,7 +46,7 @@ public abstract class RpcHandler { */ public abstract void receive( TransportClient client, - byte[] message, + ByteBuffer message, RpcResponseCallback callback); /** @@ -47,9 +55,41 @@ public abstract void receive( */ public abstract StreamManager getStreamManager(); + /** + * Receives an RPC message that does not expect a reply. The default implementation will + * call "{@link receive(TransportClient, byte[], RpcResponseCallback)}" and log a warning if + * any of the callback methods are called. + * + * @param client A channel client which enables the handler to make requests back to the sender + * of this RPC. This will always be the exact same object for a particular channel. + * @param message The serialized bytes of the RPC. + */ + public void receive(TransportClient client, ByteBuffer message) { + receive(client, message, ONE_WAY_CALLBACK); + } + /** * Invoked when the connection associated with the given client has been invalidated. * No further requests will come from this client. */ public void connectionTerminated(TransportClient client) { } + + public void exceptionCaught(Throwable cause, TransportClient client) { } + + private static class OneWayRpcCallback implements RpcResponseCallback { + + private final Logger logger = LoggerFactory.getLogger(OneWayRpcCallback.class); + + @Override + public void onSuccess(ByteBuffer response) { + logger.warn("Response provided for one-way RPC."); + } + + @Override + public void onFailure(Throwable e) { + logger.error("Error response provided for one-way RPC.", e); + } + + } + } diff --git a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java index 929f789bf9d2..3f0155957a14 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java +++ b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java @@ -20,6 +20,7 @@ import io.netty.channel.Channel; import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.TransportClient; /** * The StreamManager is used to fetch individual chunks from a stream. 
This is used in @@ -45,6 +46,19 @@ public abstract class StreamManager { */ public abstract ManagedBuffer getChunk(long streamId, int chunkIndex); + /** + * Called in response to a stream() request. The returned data is streamed to the client + * through a single TCP connection. + * + * Note the streamId argument is not related to the similarly named argument in the + * {@link #getChunk(long, int)} method. + * + * @param streamId id of a stream that has been previously registered with the StreamManager. + */ + public ManagedBuffer openStream(String streamId) { + throw new UnsupportedOperationException(); + } + /** * Associates a stream with a single client connection, which is guaranteed to be the only reader * of the stream. The getChunk() method will be called serially on this connection and once the @@ -60,4 +74,12 @@ public void registerChannel(Channel channel, long streamId) { } * to read from the associated streams again, so any state can be cleaned up. */ public void connectionTerminated(Channel channel) { } + + /** + * Verify that the client is authorized to read from the given stream. + * + * @throws SecurityException If client is not authorized. + */ + public void checkAuthorization(TransportClient client, long streamId) { } + } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index 8e0ee709e38e..09435bcbab35 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -55,16 +55,19 @@ public class TransportChannelHandler extends SimpleChannelInboundHandler 0; - boolean isActuallyOverdue = - System.nanoTime() - responseHandler.getTimeOfLastRequestNs() > requestTimeoutNs; - if (e.state() == IdleState.ALL_IDLE && hasInFlightRequests && isActuallyOverdue) { - String address = NettyUtils.getRemoteAddress(ctx.channel()); - logger.error("Connection to {} has been quiet for {} ms while there are outstanding " + - "requests. Assuming connection is dead; please adjust spark.network.timeout if this " + - "is wrong.", address, requestTimeoutNs / 1000 / 1000); - ctx.close(); + // there's no race between the idle timeout and incrementing the numOutstandingRequests + // (see SPARK-7003). + // + // To avoid a race between TransportClientFactory.createClient() and this code which could + // result in an inactive client being returned, this needs to run in a synchronized block. + synchronized (this) { + boolean isActuallyOverdue = + System.nanoTime() - responseHandler.getTimeOfLastRequestNs() > requestTimeoutNs; + if (e.state() == IdleState.ALL_IDLE && isActuallyOverdue) { + if (responseHandler.numOutstandingRequests() > 0) { + String address = NettyUtils.getRemoteAddress(ctx.channel()); + logger.error("Connection to {} has been quiet for {} ms while there are outstanding " + + "requests. 
Assuming connection is dead; please adjust spark.network.timeout if this " + + "is wrong.", address, requestTimeoutNs / 1000 / 1000); + client.timeOut(); + ctx.close(); + } else if (closeIdleConnections) { + // While CloseIdleConnections is enable, we also close idle connection + client.timeOut(); + ctx.close(); + } + } } } + ctx.fireUserEventTriggered(evt); } + + public TransportResponseHandler getResponseHandler() { + return responseHandler; + } + } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index e5159ab56d0d..c864d7ce16bd 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -17,6 +17,9 @@ package org.apache.spark.network.server; +import java.nio.ByteBuffer; + +import com.google.common.base.Preconditions; import com.google.common.base.Throwables; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; @@ -25,16 +28,21 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.protocol.Encodable; -import org.apache.spark.network.protocol.RequestMessage; import org.apache.spark.network.protocol.ChunkFetchRequest; -import org.apache.spark.network.protocol.RpcRequest; import org.apache.spark.network.protocol.ChunkFetchFailure; import org.apache.spark.network.protocol.ChunkFetchSuccess; +import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.protocol.OneWayMessage; +import org.apache.spark.network.protocol.RequestMessage; import org.apache.spark.network.protocol.RpcFailure; +import org.apache.spark.network.protocol.RpcRequest; import org.apache.spark.network.protocol.RpcResponse; +import org.apache.spark.network.protocol.StreamFailure; +import org.apache.spark.network.protocol.StreamRequest; +import org.apache.spark.network.protocol.StreamResponse; import org.apache.spark.network.util.NettyUtils; /** @@ -71,11 +79,18 @@ public TransportRequestHandler( @Override public void exceptionCaught(Throwable cause) { + rpcHandler.exceptionCaught(cause, reverseClient); } @Override public void channelUnregistered() { - streamManager.connectionTerminated(channel); + if (streamManager != null) { + try { + streamManager.connectionTerminated(channel); + } catch (RuntimeException e) { + logger.error("StreamManager connectionTerminated() callback failed.", e); + } + } rpcHandler.connectionTerminated(reverseClient); } @@ -85,6 +100,10 @@ public void handle(RequestMessage request) { processFetchRequest((ChunkFetchRequest) request); } else if (request instanceof RpcRequest) { processRpcRequest((RpcRequest) request); + } else if (request instanceof OneWayMessage) { + processOneWayMessage((OneWayMessage) request); + } else if (request instanceof StreamRequest) { + processStreamRequest((StreamRequest) request); } else { throw new IllegalArgumentException("Unknown request type: " + request); } @@ -97,6 +116,7 @@ private void processFetchRequest(final ChunkFetchRequest req) { ManagedBuffer buf; try { + streamManager.checkAuthorization(reverseClient, req.streamChunkId.streamId); streamManager.registerChannel(channel, req.streamChunkId.streamId); 
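The reworked idle-timeout logic above also closes connections with no outstanding requests when the transport context is created with its closeIdleConnections flag set, as the TransportClientFactorySuite test later in this patch does. A hedged client-side sketch of that setup; the localhost:7077 endpoint is invented, while the config key and the three-argument TransportContext constructor are the ones the test exercises.

import java.util.HashMap;
import java.util.Map;

import org.apache.spark.network.TransportContext;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientFactory;
import org.apache.spark.network.server.NoOpRpcHandler;
import org.apache.spark.network.util.MapConfigProvider;
import org.apache.spark.network.util.TransportConf;

public class IdleConnectionExample {
  public static void main(String[] args) throws Exception {
    // Connections quiet for longer than the module's connectionTimeout get closed.
    Map<String, String> config = new HashMap<>();
    config.put("spark.shuffle.io.connectionTimeout", "120s");
    TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(config));

    // The third argument asks TransportChannelHandler to also close channels that
    // are idle with no requests in flight.
    TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true);
    TransportClientFactory factory = context.createClientFactory();

    TransportClient client = factory.createClient("localhost", 7077); // hypothetical endpoint
    // ... use the client; it is timed out and closed once it sits idle too long ...
    factory.close();
  }
}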
buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex); } catch (Exception e) { @@ -109,12 +129,27 @@ private void processFetchRequest(final ChunkFetchRequest req) { respond(new ChunkFetchSuccess(req.streamChunkId, buf)); } + private void processStreamRequest(final StreamRequest req) { + final String client = NettyUtils.getRemoteAddress(channel); + ManagedBuffer buf; + try { + buf = streamManager.openStream(req.streamId); + } catch (Exception e) { + logger.error(String.format( + "Error opening stream %s for request from %s", req.streamId, client), e); + respond(new StreamFailure(req.streamId, Throwables.getStackTraceAsString(e))); + return; + } + + respond(new StreamResponse(req.streamId, buf.size(), buf)); + } + private void processRpcRequest(final RpcRequest req) { try { - rpcHandler.receive(reverseClient, req.message, new RpcResponseCallback() { + rpcHandler.receive(reverseClient, req.body().nioByteBuffer(), new RpcResponseCallback() { @Override - public void onSuccess(byte[] response) { - respond(new RpcResponse(req.requestId, response)); + public void onSuccess(ByteBuffer response) { + respond(new RpcResponse(req.requestId, new NioManagedBuffer(response))); } @Override @@ -125,6 +160,18 @@ public void onFailure(Throwable e) { } catch (Exception e) { logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e); respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); + } finally { + req.body().release(); + } + } + + private void processOneWayMessage(OneWayMessage req) { + try { + rpcHandler.receive(reverseClient, req.body().nioByteBuffer()); + } catch (Exception e) { + logger.error("Error while invoking RpcHandler#receive() for one-way message.", e); + } finally { + req.body().release(); } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java index f4fadb1ee3b8..baae235e0220 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -55,9 +55,13 @@ public class TransportServer implements Closeable { private ChannelFuture channelFuture; private int port = -1; - /** Creates a TransportServer that binds to the given port, or to any available if 0. */ + /** + * Creates a TransportServer that binds to the given host and the given port, or to any available + * if 0. If you don't want to bind to any special host, set "hostToBind" to null. + * */ public TransportServer( TransportContext context, + String hostToBind, int portToBind, RpcHandler appRpcHandler, List bootstraps) { @@ -67,7 +71,7 @@ public TransportServer( this.bootstraps = Lists.newArrayList(Preconditions.checkNotNull(bootstraps)); try { - init(portToBind); + init(hostToBind, portToBind); } catch (RuntimeException e) { JavaUtils.closeQuietly(this); throw e; @@ -81,7 +85,7 @@ public int getPort() { return port; } - private void init(int portToBind) { + private void init(String hostToBind, int portToBind) { IOMode ioMode = IOMode.valueOf(conf.ioMode()); EventLoopGroup bossGroup = @@ -120,7 +124,9 @@ protected void initChannel(SocketChannel ch) throws Exception { } }); - channelFuture = bootstrap.bind(new InetSocketAddress(portToBind)); + InetSocketAddress address = hostToBind == null ? 
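processStreamRequest() above answers the new StreamRequest by fetching a ManagedBuffer from the StreamManager and streaming it back over the connection. On the client side the bytes arrive through a StreamCallback, roughly like this sketch: FileWritingStreamCallback and its fetch() helper are hypothetical, while client.stream() and the callback methods mirror the StreamSuite test added later in this patch.

import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;

import org.apache.spark.network.client.StreamCallback;
import org.apache.spark.network.client.TransportClient;

// Hypothetical callback that spools a server-side stream into a local file.
public class FileWritingStreamCallback implements StreamCallback {

  private final FileChannel out;

  public FileWritingStreamCallback(String path) throws IOException {
    this.out = new FileOutputStream(path).getChannel();
  }

  @Override
  public void onData(String streamId, ByteBuffer buf) throws IOException {
    out.write(buf); // called as data arrives from the remote end
  }

  @Override
  public void onComplete(String streamId) throws IOException {
    out.close();
  }

  @Override
  public void onFailure(String streamId, Throwable cause) {
    cause.printStackTrace();
    try {
      out.close();
    } catch (IOException e) {
      // ignore; the stream already failed
    }
  }

  // Usage: streamId is whatever the server's StreamManager.openStream() understands.
  public static void fetch(TransportClient client, String streamId, String path) throws IOException {
    client.stream(streamId, new FileWritingStreamCallback(path));
  }
}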
+ new InetSocketAddress(portToBind): new InetSocketAddress(hostToBind, portToBind); + channelFuture = bootstrap.bind(address); channelFuture.syncUninterruptibly(); port = ((InetSocketAddress) channelFuture.channel().localAddress()).getPort(); diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java index 7d27439cfde7..b3d8e0cd7cdc 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -132,7 +132,7 @@ private static boolean isSymlink(File file) throws IOException { return !fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile()); } - private static final ImmutableMap timeSuffixes = + private static final ImmutableMap timeSuffixes = ImmutableMap.builder() .put("us", TimeUnit.MICROSECONDS) .put("ms", TimeUnit.MILLISECONDS) @@ -164,32 +164,32 @@ private static boolean isSymlink(File file) throws IOException { */ private static long parseTimeString(String str, TimeUnit unit) { String lower = str.toLowerCase().trim(); - + try { Matcher m = Pattern.compile("(-?[0-9]+)([a-z]+)?").matcher(lower); if (!m.matches()) { throw new NumberFormatException("Failed to parse time string: " + str); } - + long val = Long.parseLong(m.group(1)); String suffix = m.group(2); - + // Check for invalid suffixes if (suffix != null && !timeSuffixes.containsKey(suffix)) { throw new NumberFormatException("Invalid suffix: \"" + suffix + "\""); } - + // If suffix is valid use that, otherwise none was provided and use the default passed return unit.convert(val, suffix != null ? timeSuffixes.get(suffix) : unit); } catch (NumberFormatException e) { String timeError = "Time must be specified as seconds (s), " + "milliseconds (ms), microseconds (us), minutes (m or min), hour (h), or day (d). " + "E.g. 50s, 100ms, or 250us."; - + throw new NumberFormatException(timeError + "\n" + e.getMessage()); } } - + /** * Convert a time parameter such as (50s, 100ms, or 250us) to milliseconds for internal use. If * no suffix is provided, the passed number is assumed to be in ms. @@ -205,10 +205,10 @@ public static long timeStringAsMs(String str) { public static long timeStringAsSec(String str) { return parseTimeString(str, TimeUnit.SECONDS); } - + /** * Convert a passed byte string (e.g. 50b, 100kb, or 250mb) to a ByteUnit for - * internal use. If no suffix is provided a direct conversion of the provided default is + * internal use. If no suffix is provided a direct conversion of the provided default is * attempted. */ private static long parseByteString(String str, ByteUnit unit) { @@ -217,7 +217,7 @@ private static long parseByteString(String str, ByteUnit unit) { try { Matcher m = Pattern.compile("([0-9]+)([a-z]+)?").matcher(lower); Matcher fractionMatcher = Pattern.compile("([0-9]+\\.[0-9]+)([a-z]+)?").matcher(lower); - + if (m.matches()) { long val = Long.parseLong(m.group(1)); String suffix = m.group(2); @@ -228,14 +228,14 @@ private static long parseByteString(String str, ByteUnit unit) { } // If suffix is valid use that, otherwise none was provided and use the default passed - return unit.convertFrom(val, suffix != null ? byteSuffixes.get(suffix) : unit); + return unit.convertFrom(val, suffix != null ? byteSuffixes.get(suffix) : unit); } else if (fractionMatcher.matches()) { - throw new NumberFormatException("Fractional values are not supported. 
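TransportServer can now be bound to a specific interface instead of the wildcard address. A small sketch, assuming the constructor stays public as shown in the diff and that the bootstrap list is typed as TransportServerBootstrap; the 127.0.0.1 address is only illustrative, and passing null for hostToBind keeps the previous wildcard behaviour.

import java.util.Collections;

import org.apache.spark.network.TransportContext;
import org.apache.spark.network.server.NoOpRpcHandler;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.server.TransportServerBootstrap;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;

public class BindToHostExample {
  public static void main(String[] args) {
    TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider());
    RpcHandler handler = new NoOpRpcHandler();
    TransportContext context = new TransportContext(conf, handler);

    // Bind to an explicit interface; port 0 still means "pick any free port".
    TransportServer server = new TransportServer(
      context, "127.0.0.1", 0, handler, Collections.<TransportServerBootstrap>emptyList());
    System.out.println("Listening on port " + server.getPort());
    server.close();
  }
}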
Input was: " + throw new NumberFormatException("Fractional values are not supported. Input was: " + fractionMatcher.group(1)); } else { - throw new NumberFormatException("Failed to parse byte string: " + str); + throw new NumberFormatException("Failed to parse byte string: " + str); } - + } catch (NumberFormatException e) { String timeError = "Size must be specified as bytes (b), " + "kibibytes (k), mebibytes (m), gibibytes (g), tebibytes (t), or pebibytes(p). " + @@ -248,7 +248,7 @@ private static long parseByteString(String str, ByteUnit unit) { /** * Convert a passed byte string (e.g. 50b, 100k, or 250m) to bytes for * internal use. - * + * * If no suffix is provided, the passed number is assumed to be in bytes. */ public static long byteStringAsBytes(String str) { @@ -264,7 +264,7 @@ public static long byteStringAsBytes(String str) { public static long byteStringAsKb(String str) { return parseByteString(str, ByteUnit.KiB); } - + /** * Convert a passed byte string (e.g. 50b, 100k, or 250m) to mebibytes for * internal use. @@ -284,4 +284,20 @@ public static long byteStringAsMb(String str) { public static long byteStringAsGb(String str) { return parseByteString(str, ByteUnit.GiB); } + + /** + * Returns a byte array with the buffer's contents, trying to avoid copying the data if + * possible. + */ + public static byte[] bufferToArray(ByteBuffer buffer) { + if (buffer.hasArray() && buffer.arrayOffset() == 0 && + buffer.array().length == buffer.remaining()) { + return buffer.array(); + } else { + byte[] bytes = new byte[buffer.remaining()]; + buffer.get(bytes); + return bytes; + } + } + } diff --git a/network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java b/network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java index 57113ed12d41..922c37a10efd 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java +++ b/network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java @@ -15,6 +15,24 @@ * limitations under the License. */ +/* + * Based on LimitedInputStream.java from Google Guava + * + * Copyright (C) 2007 The Guava Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.network.util; import java.io.FilterInputStream; diff --git a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java index 26c6399ce7db..caa7260bc828 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java @@ -89,13 +89,8 @@ public static Class getServerChannelClass(IOMode mode) * Creates a LengthFieldBasedFrameDecoder where the first 8 bytes are the length of the frame. * This is used before all decoders. 
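For reference, the JavaUtils helpers touched above (and the new bufferToArray()) are used like this; the values in the comments follow from the parsing rules shown in the diff.

import java.nio.ByteBuffer;

import org.apache.spark.network.util.JavaUtils;

public class JavaUtilsExample {
  public static void main(String[] args) {
    // Time strings: a bare number falls back to the method's default unit.
    long timeoutMs = JavaUtils.timeStringAsMs("120s");    // 120000
    long timeoutSec = JavaUtils.timeStringAsSec("120s");  // 120

    // Byte strings: suffixes are binary units (k = KiB, m = MiB, ...).
    long bufferBytes = JavaUtils.byteStringAsBytes("32k"); // 32768

    // bufferToArray() returns the backing array without copying when the buffer
    // is a full, array-backed buffer; otherwise it copies the remaining bytes.
    ByteBuffer buf = ByteBuffer.wrap(new byte[] {1, 2, 3});
    byte[] bytes = JavaUtils.bufferToArray(buf);
    System.out.println(timeoutMs + " " + timeoutSec + " " + bufferBytes + " " + bytes.length);
  }
}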
*/ - public static ByteToMessageDecoder createFrameDecoder() { - // maxFrameLength = 2G - // lengthFieldOffset = 0 - // lengthFieldLength = 8 - // lengthAdjustment = -8, i.e. exclude the 8 byte length itself - // initialBytesToStrip = 8, i.e. strip out the length field itself - return new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 8, -8, 8); + public static TransportFrameDecoder createFrameDecoder() { + return new TransportFrameDecoder(); } /** Returns the remote address on the channel or "<unknown remote>" if none exists. */ diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java index 3b2eff377955..115135d44adb 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -23,18 +23,53 @@ * A central location that tracks all the settings we expose to users. */ public class TransportConf { + + private final String SPARK_NETWORK_IO_MODE_KEY; + private final String SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY; + private final String SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY; + private final String SPARK_NETWORK_IO_BACKLOG_KEY; + private final String SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY; + private final String SPARK_NETWORK_IO_SERVERTHREADS_KEY; + private final String SPARK_NETWORK_IO_CLIENTTHREADS_KEY; + private final String SPARK_NETWORK_IO_RECEIVEBUFFER_KEY; + private final String SPARK_NETWORK_IO_SENDBUFFER_KEY; + private final String SPARK_NETWORK_SASL_TIMEOUT_KEY; + private final String SPARK_NETWORK_IO_MAXRETRIES_KEY; + private final String SPARK_NETWORK_IO_RETRYWAIT_KEY; + private final String SPARK_NETWORK_IO_LAZYFD_KEY; + private final ConfigProvider conf; - public TransportConf(ConfigProvider conf) { + private final String module; + + public TransportConf(String module, ConfigProvider conf) { + this.module = module; this.conf = conf; + SPARK_NETWORK_IO_MODE_KEY = getConfKey("io.mode"); + SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY = getConfKey("io.preferDirectBufs"); + SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY = getConfKey("io.connectionTimeout"); + SPARK_NETWORK_IO_BACKLOG_KEY = getConfKey("io.backLog"); + SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY = getConfKey("io.numConnectionsPerPeer"); + SPARK_NETWORK_IO_SERVERTHREADS_KEY = getConfKey("io.serverThreads"); + SPARK_NETWORK_IO_CLIENTTHREADS_KEY = getConfKey("io.clientThreads"); + SPARK_NETWORK_IO_RECEIVEBUFFER_KEY = getConfKey("io.receiveBuffer"); + SPARK_NETWORK_IO_SENDBUFFER_KEY = getConfKey("io.sendBuffer"); + SPARK_NETWORK_SASL_TIMEOUT_KEY = getConfKey("sasl.timeout"); + SPARK_NETWORK_IO_MAXRETRIES_KEY = getConfKey("io.maxRetries"); + SPARK_NETWORK_IO_RETRYWAIT_KEY = getConfKey("io.retryWait"); + SPARK_NETWORK_IO_LAZYFD_KEY = getConfKey("io.lazyFD"); + } + + private String getConfKey(String suffix) { + return "spark." + module + "." + suffix; } /** IO mode: nio or epoll */ - public String ioMode() { return conf.get("spark.shuffle.io.mode", "NIO").toUpperCase(); } + public String ioMode() { return conf.get(SPARK_NETWORK_IO_MODE_KEY, "NIO").toUpperCase(); } /** If true, we will prefer allocating off-heap byte buffers within Netty. */ public boolean preferDirectBufs() { - return conf.getBoolean("spark.shuffle.io.preferDirectBufs", true); + return conf.getBoolean(SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY, true); } /** Connect timeout in milliseconds. Default 120 secs. 
*/ @@ -42,23 +77,23 @@ public int connectionTimeoutMs() { long defaultNetworkTimeoutS = JavaUtils.timeStringAsSec( conf.get("spark.network.timeout", "120s")); long defaultTimeoutMs = JavaUtils.timeStringAsSec( - conf.get("spark.shuffle.io.connectionTimeout", defaultNetworkTimeoutS + "s")) * 1000; + conf.get(SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY, defaultNetworkTimeoutS + "s")) * 1000; return (int) defaultTimeoutMs; } /** Number of concurrent connections between two nodes for fetching data. */ public int numConnectionsPerPeer() { - return conf.getInt("spark.shuffle.io.numConnectionsPerPeer", 1); + return conf.getInt(SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY, 1); } /** Requested maximum length of the queue of incoming connections. Default -1 for no backlog. */ - public int backLog() { return conf.getInt("spark.shuffle.io.backLog", -1); } + public int backLog() { return conf.getInt(SPARK_NETWORK_IO_BACKLOG_KEY, -1); } /** Number of threads used in the server thread pool. Default to 0, which is 2x#cores. */ - public int serverThreads() { return conf.getInt("spark.shuffle.io.serverThreads", 0); } + public int serverThreads() { return conf.getInt(SPARK_NETWORK_IO_SERVERTHREADS_KEY, 0); } /** Number of threads used in the client thread pool. Default to 0, which is 2x#cores. */ - public int clientThreads() { return conf.getInt("spark.shuffle.io.clientThreads", 0); } + public int clientThreads() { return conf.getInt(SPARK_NETWORK_IO_CLIENTTHREADS_KEY, 0); } /** * Receive buffer size (SO_RCVBUF). @@ -67,28 +102,28 @@ public int numConnectionsPerPeer() { * Assuming latency = 1ms, network_bandwidth = 10Gbps * buffer size should be ~ 1.25MB */ - public int receiveBuf() { return conf.getInt("spark.shuffle.io.receiveBuffer", -1); } + public int receiveBuf() { return conf.getInt(SPARK_NETWORK_IO_RECEIVEBUFFER_KEY, -1); } /** Send buffer size (SO_SNDBUF). */ - public int sendBuf() { return conf.getInt("spark.shuffle.io.sendBuffer", -1); } + public int sendBuf() { return conf.getInt(SPARK_NETWORK_IO_SENDBUFFER_KEY, -1); } /** Timeout for a single round trip of SASL token exchange, in milliseconds. */ public int saslRTTimeoutMs() { - return (int) JavaUtils.timeStringAsSec(conf.get("spark.shuffle.sasl.timeout", "30s")) * 1000; + return (int) JavaUtils.timeStringAsSec(conf.get(SPARK_NETWORK_SASL_TIMEOUT_KEY, "30s")) * 1000; } /** * Max number of times we will try IO exceptions (such as connection timeouts) per request. * If set to 0, we will not do any retries. */ - public int maxIORetries() { return conf.getInt("spark.shuffle.io.maxRetries", 3); } + public int maxIORetries() { return conf.getInt(SPARK_NETWORK_IO_MAXRETRIES_KEY, 3); } /** * Time (in milliseconds) that we will wait in order to perform a retry after an IOException. * Only relevant if maxIORetries > 0. */ public int ioRetryWaitTimeMs() { - return (int) JavaUtils.timeStringAsSec(conf.get("spark.shuffle.io.retryWait", "5s")) * 1000; + return (int) JavaUtils.timeStringAsSec(conf.get(SPARK_NETWORK_IO_RETRYWAIT_KEY, "5s")) * 1000; } /** @@ -101,11 +136,11 @@ public int memoryMapBytes() { } /** - * Whether to initialize shuffle FileDescriptor lazily or not. If true, file descriptors are + * Whether to initialize FileDescriptor lazily or not. If true, file descriptors are * created only when data is going to be transferred. This can reduce the number of open files. 
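With the module name threaded through TransportConf, the same transport code can be configured under per-module namespaces ("spark.<module>.io.*") instead of the hard-coded "spark.shuffle.*" keys. A minimal sketch using the MapConfigProvider that the test code in this patch relies on; the specific settings are only illustrative.

import java.util.HashMap;
import java.util.Map;

import org.apache.spark.network.util.MapConfigProvider;
import org.apache.spark.network.util.TransportConf;

public class TransportConfExample {
  public static void main(String[] args) {
    Map<String, String> settings = new HashMap<>();
    // For module "shuffle", getConfKey("io.mode") resolves to "spark.shuffle.io.mode".
    settings.put("spark.shuffle.io.mode", "epoll");
    settings.put("spark.shuffle.io.numConnectionsPerPeer", "2");

    TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(settings));
    System.out.println(conf.ioMode());                // EPOLL
    System.out.println(conf.numConnectionsPerPeer()); // 2
  }
}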
*/ public boolean lazyFileDescriptor() { - return conf.getBoolean("spark.shuffle.io.lazyFD", true); + return conf.getBoolean(SPARK_NETWORK_IO_LAZYFD_KEY, true); } /** diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java b/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java new file mode 100644 index 000000000000..a466c729154a --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.util.Iterator; +import java.util.LinkedList; + +import com.google.common.base.Preconditions; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; + +/** + * A customized frame decoder that allows intercepting raw data. + *
<p>
      + * This behaves like Netty's frame decoder (with hardcoded parameters that match this library's + * needs), except it allows an interceptor to be installed to read data directly before it's + * framed. + *
<p>
      + * Unlike Netty's frame decoder, each frame is dispatched to child handlers as soon as it's + * decoded, instead of building as many frames as the current buffer allows and dispatching + * all of them. This allows a child handler to install an interceptor if needed. + *
<p>
      + * If an interceptor is installed, framing stops, and data is instead fed directly to the + * interceptor. When the interceptor indicates that it doesn't need to read any more data, + * framing resumes. Interceptors should not hold references to the data buffers provided + * to their handle() method. + */ +public class TransportFrameDecoder extends ChannelInboundHandlerAdapter { + + public static final String HANDLER_NAME = "frameDecoder"; + private static final int LENGTH_SIZE = 8; + private static final int MAX_FRAME_SIZE = Integer.MAX_VALUE; + private static final int UNKNOWN_FRAME_SIZE = -1; + + private final LinkedList buffers = new LinkedList<>(); + private final ByteBuf frameLenBuf = Unpooled.buffer(LENGTH_SIZE, LENGTH_SIZE); + + private long totalSize = 0; + private long nextFrameSize = UNKNOWN_FRAME_SIZE; + private volatile Interceptor interceptor; + + @Override + public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception { + ByteBuf in = (ByteBuf) data; + buffers.add(in); + totalSize += in.readableBytes(); + + while (!buffers.isEmpty()) { + // First, feed the interceptor, and if it's still, active, try again. + if (interceptor != null) { + ByteBuf first = buffers.getFirst(); + int available = first.readableBytes(); + if (feedInterceptor(first)) { + assert !first.isReadable() : "Interceptor still active but buffer has data."; + } + + int read = available - first.readableBytes(); + if (read == available) { + buffers.removeFirst().release(); + } + totalSize -= read; + } else { + // Interceptor is not active, so try to decode one frame. + ByteBuf frame = decodeNext(); + if (frame == null) { + break; + } + ctx.fireChannelRead(frame); + } + } + } + + private long decodeFrameSize() { + if (nextFrameSize != UNKNOWN_FRAME_SIZE || totalSize < LENGTH_SIZE) { + return nextFrameSize; + } + + // We know there's enough data. If the first buffer contains all the data, great. Otherwise, + // hold the bytes for the frame length in a composite buffer until we have enough data to read + // the frame size. Normally, it should be rare to need more than one buffer to read the frame + // size. + ByteBuf first = buffers.getFirst(); + if (first.readableBytes() >= LENGTH_SIZE) { + nextFrameSize = first.readLong() - LENGTH_SIZE; + totalSize -= LENGTH_SIZE; + if (!first.isReadable()) { + buffers.removeFirst().release(); + } + return nextFrameSize; + } + + while (frameLenBuf.readableBytes() < LENGTH_SIZE) { + ByteBuf next = buffers.getFirst(); + int toRead = Math.min(next.readableBytes(), LENGTH_SIZE - frameLenBuf.readableBytes()); + frameLenBuf.writeBytes(next, toRead); + if (!next.isReadable()) { + buffers.removeFirst().release(); + } + } + + nextFrameSize = frameLenBuf.readLong() - LENGTH_SIZE; + totalSize -= LENGTH_SIZE; + frameLenBuf.clear(); + return nextFrameSize; + } + + private ByteBuf decodeNext() throws Exception { + long frameSize = decodeFrameSize(); + if (frameSize == UNKNOWN_FRAME_SIZE || totalSize < frameSize) { + return null; + } + + // Reset size for next frame. + nextFrameSize = UNKNOWN_FRAME_SIZE; + + Preconditions.checkArgument(frameSize < MAX_FRAME_SIZE, "Too large frame: %s", frameSize); + Preconditions.checkArgument(frameSize > 0, "Frame length should be positive: %s", frameSize); + + // If the first buffer holds the entire frame, return it. + int remaining = (int) frameSize; + if (buffers.getFirst().readableBytes() >= remaining) { + return nextBufferForFrame(remaining); + } + + // Otherwise, create a composite buffer. 
+ CompositeByteBuf frame = buffers.getFirst().alloc().compositeBuffer(); + while (remaining > 0) { + ByteBuf next = nextBufferForFrame(remaining); + remaining -= next.readableBytes(); + frame.addComponent(next).writerIndex(frame.writerIndex() + next.readableBytes()); + } + assert remaining == 0; + return frame; + } + + /** + * Takes the first buffer in the internal list, and either adjust it to fit in the frame + * (by taking a slice out of it) or remove it from the internal list. + */ + private ByteBuf nextBufferForFrame(int bytesToRead) { + ByteBuf buf = buffers.getFirst(); + ByteBuf frame; + + if (buf.readableBytes() > bytesToRead) { + frame = buf.retain().readSlice(bytesToRead); + totalSize -= bytesToRead; + } else { + frame = buf; + buffers.removeFirst(); + totalSize -= frame.readableBytes(); + } + + return frame; + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + for (ByteBuf b : buffers) { + b.release(); + } + if (interceptor != null) { + interceptor.channelInactive(); + } + frameLenBuf.release(); + super.channelInactive(ctx); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + if (interceptor != null) { + interceptor.exceptionCaught(cause); + } + super.exceptionCaught(ctx, cause); + } + + public void setInterceptor(Interceptor interceptor) { + Preconditions.checkState(this.interceptor == null, "Already have an interceptor."); + this.interceptor = interceptor; + } + + /** + * @return Whether the interceptor is still active after processing the data. + */ + private boolean feedInterceptor(ByteBuf buf) throws Exception { + if (interceptor != null && !interceptor.handle(buf)) { + interceptor = null; + } + return interceptor != null; + } + + public static interface Interceptor { + + /** + * Handles data received from the remote end. + * + * @param data Buffer containing data. + * @return "true" if the interceptor expects more data, "false" to uninstall the interceptor. + */ + boolean handle(ByteBuf data) throws Exception; + + /** Called if an exception is thrown in the channel pipeline. */ + void exceptionCaught(Throwable cause) throws Exception; + + /** Called if the channel is closed and the interceptor is still installed. 
*/ + void channelInactive() throws Exception; + + } + +} diff --git a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index dfb7740344ed..70c849d60e0a 100644 --- a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -31,6 +31,7 @@ import com.google.common.collect.Lists; import com.google.common.collect.Sets; +import com.google.common.io.Closeables; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; @@ -78,12 +79,17 @@ public static void setUp() throws Exception { testFile = File.createTempFile("shuffle-test-file", "txt"); testFile.deleteOnExit(); RandomAccessFile fp = new RandomAccessFile(testFile, "rw"); - byte[] fileContent = new byte[1024]; - new Random().nextBytes(fileContent); - fp.write(fileContent); - fp.close(); + boolean shouldSuppressIOException = true; + try { + byte[] fileContent = new byte[1024]; + new Random().nextBytes(fileContent); + fp.write(fileContent); + shouldSuppressIOException = false; + } finally { + Closeables.close(fp, shouldSuppressIOException); + } - final TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); fileChunk = new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25); streamManager = new StreamManager() { @@ -101,7 +107,10 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) { }; RpcHandler handler = new RpcHandler() { @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { throw new UnsupportedOperationException(); } @@ -117,6 +126,7 @@ public StreamManager getStreamManager() { @AfterClass public static void tearDown() { + bufferChunk.release(); server.close(); clientFactory.close(); testFile.delete(); diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java index d500bc3c98a7..6c8dd742f4b6 100644 --- a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -35,10 +35,14 @@ import org.apache.spark.network.protocol.Message; import org.apache.spark.network.protocol.MessageDecoder; import org.apache.spark.network.protocol.MessageEncoder; +import org.apache.spark.network.protocol.OneWayMessage; import org.apache.spark.network.protocol.RpcFailure; import org.apache.spark.network.protocol.RpcRequest; import org.apache.spark.network.protocol.RpcResponse; import org.apache.spark.network.protocol.StreamChunkId; +import org.apache.spark.network.protocol.StreamFailure; +import org.apache.spark.network.protocol.StreamRequest; +import org.apache.spark.network.protocol.StreamResponse; import org.apache.spark.network.util.ByteArrayWritableChannel; import org.apache.spark.network.util.NettyUtils; @@ -78,8 +82,10 @@ private void testClientToServer(Message msg) { @Test public void requests() { testClientToServer(new ChunkFetchRequest(new StreamChunkId(1, 2))); - testClientToServer(new RpcRequest(12345, new byte[0])); - testClientToServer(new RpcRequest(12345, new byte[100])); + 
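The new TransportFrameDecoder above frames messages with an 8-byte length prefix that counts itself and fires each frame downstream as soon as it is complete. A small sketch of that wire format using Netty's EmbeddedChannel (the byte values are arbitrary); a handler that needs the raw bytes, such as the stream support added in this patch, can presumably look the decoder up under TransportFrameDecoder.HANDLER_NAME and call setInterceptor() to temporarily bypass framing.

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;

import org.apache.spark.network.util.TransportFrameDecoder;

public class FrameDecoderExample {
  public static void main(String[] args) {
    EmbeddedChannel channel = new EmbeddedChannel(new TransportFrameDecoder());

    // A frame is an 8-byte length (which includes the length field itself)
    // followed by the frame body.
    byte[] body = new byte[] {1, 2, 3, 4};
    ByteBuf frame = Unpooled.buffer();
    frame.writeLong(8 + body.length);
    frame.writeBytes(body);

    // The decoder accepts arbitrary slices; here the whole frame arrives at once.
    channel.writeInbound(frame);
    ByteBuf decoded = (ByteBuf) channel.readInbound();
    System.out.println("decoded frame bytes: " + decoded.readableBytes()); // 4
    decoded.release();
    channel.finish();
  }
}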
testClientToServer(new RpcRequest(12345, new TestManagedBuffer(0))); + testClientToServer(new RpcRequest(12345, new TestManagedBuffer(10))); + testClientToServer(new StreamRequest("abcde")); + testClientToServer(new OneWayMessage(new TestManagedBuffer(10))); } @Test @@ -88,10 +94,14 @@ public void responses() { testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(0))); testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "this is an error")); testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "")); - testServerToClient(new RpcResponse(12345, new byte[0])); - testServerToClient(new RpcResponse(12345, new byte[1000])); + testServerToClient(new RpcResponse(12345, new TestManagedBuffer(0))); + testServerToClient(new RpcResponse(12345, new TestManagedBuffer(100))); testServerToClient(new RpcFailure(0, "this is an error")); testServerToClient(new RpcFailure(0, "")); + // Note: buffer size must be "0" since StreamResponse's buffer is written differently to the + // channel and cannot be tested like this. + testServerToClient(new StreamResponse("anId", 12345L, new TestManagedBuffer(0))); + testServerToClient(new StreamFailure("anId", "this is an error")); } /** diff --git a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java index 84ebb337e6d5..f9b5bf96d621 100644 --- a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java @@ -31,6 +31,7 @@ import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; import org.junit.*; +import static org.junit.Assert.*; import java.io.IOException; import java.nio.ByteBuffer; @@ -60,7 +61,7 @@ public class RequestTimeoutIntegrationSuite { public void setUp() throws Exception { Map configMap = Maps.newHashMap(); configMap.put("spark.shuffle.io.connectionTimeout", "2s"); - conf = new TransportConf(new MapConfigProvider(configMap)); + conf = new TransportConf("shuffle", new MapConfigProvider(configMap)); defaultManager = new StreamManager() { @Override @@ -84,13 +85,16 @@ public void tearDown() { @Test public void timeoutInactiveRequests() throws Exception { final Semaphore semaphore = new Semaphore(1); - final byte[] response = new byte[16]; + final int responseSize = 16; RpcHandler handler = new RpcHandler() { @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { try { semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS); - callback.onSuccess(response); + callback.onSuccess(ByteBuffer.allocate(responseSize)); } catch (InterruptedException e) { // do nothing } @@ -110,15 +114,15 @@ public StreamManager getStreamManager() { // First completes quickly (semaphore starts at 1). TestCallback callback0 = new TestCallback(); synchronized (callback0) { - client.sendRpc(new byte[0], callback0); + client.sendRpc(ByteBuffer.allocate(0), callback0); callback0.wait(FOREVER); - assert (callback0.success.length == response.length); + assertEquals(responseSize, callback0.successLength); } // Second times out after 2 seconds, with slack. Must be IOException. 
TestCallback callback1 = new TestCallback(); synchronized (callback1) { - client.sendRpc(new byte[0], callback1); + client.sendRpc(ByteBuffer.allocate(0), callback1); callback1.wait(4 * 1000); assert (callback1.failure != null); assert (callback1.failure instanceof IOException); @@ -131,13 +135,16 @@ public StreamManager getStreamManager() { @Test public void timeoutCleanlyClosesClient() throws Exception { final Semaphore semaphore = new Semaphore(0); - final byte[] response = new byte[16]; + final int responseSize = 16; RpcHandler handler = new RpcHandler() { @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { try { semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS); - callback.onSuccess(response); + callback.onSuccess(ByteBuffer.allocate(responseSize)); } catch (InterruptedException e) { // do nothing } @@ -158,7 +165,7 @@ public StreamManager getStreamManager() { clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); TestCallback callback0 = new TestCallback(); synchronized (callback0) { - client0.sendRpc(new byte[0], callback0); + client0.sendRpc(ByteBuffer.allocate(0), callback0); callback0.wait(FOREVER); assert (callback0.failure instanceof IOException); assert (!client0.isActive()); @@ -170,10 +177,10 @@ public StreamManager getStreamManager() { clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); TestCallback callback1 = new TestCallback(); synchronized (callback1) { - client1.sendRpc(new byte[0], callback1); + client1.sendRpc(ByteBuffer.allocate(0), callback1); callback1.wait(FOREVER); - assert (callback1.success.length == response.length); - assert (callback1.failure == null); + assertEquals(responseSize, callback1.successLength); + assertNull(callback1.failure); } } @@ -191,7 +198,10 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) { }; RpcHandler handler = new RpcHandler() { @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { throw new UnsupportedOperationException(); } @@ -218,9 +228,10 @@ public StreamManager getStreamManager() { synchronized (callback0) { // not complete yet, but should complete soon - assert (callback0.success == null && callback0.failure == null); + assertEquals(-1, callback0.successLength); + assertNull(callback0.failure); callback0.wait(2 * 1000); - assert (callback0.failure instanceof IOException); + assertTrue(callback0.failure instanceof IOException); } synchronized (callback1) { @@ -235,13 +246,13 @@ public StreamManager getStreamManager() { */ class TestCallback implements RpcResponseCallback, ChunkReceivedCallback { - byte[] success; + int successLength = -1; Throwable failure; @Override - public void onSuccess(byte[] response) { + public void onSuccess(ByteBuffer response) { synchronized(this) { - success = response; + successLength = response.remaining(); this.notifyAll(); } } @@ -258,7 +269,7 @@ public void onFailure(Throwable e) { public void onSuccess(int chunkIndex, ManagedBuffer buffer) { synchronized(this) { try { - success = buffer.nioByteBuffer().array(); + successLength = buffer.nioByteBuffer().remaining(); this.notifyAll(); } catch (IOException e) { // weird diff --git a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java 
b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index 64b457b4b3f0..9e9be98c140b 100644 --- a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -17,14 +17,16 @@ package org.apache.spark.network; +import java.nio.ByteBuffer; +import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; import java.util.Iterator; +import java.util.List; import java.util.Set; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; -import com.google.common.base.Charsets; import com.google.common.collect.Sets; import org.junit.AfterClass; import org.junit.BeforeClass; @@ -39,6 +41,7 @@ import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; @@ -46,17 +49,21 @@ public class RpcIntegrationSuite { static TransportServer server; static TransportClientFactory clientFactory; static RpcHandler rpcHandler; + static List oneWayMsgs; @BeforeClass public static void setUp() throws Exception { - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); rpcHandler = new RpcHandler() { @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { - String msg = new String(message, Charsets.UTF_8); + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { + String msg = JavaUtils.bytesToString(message); String[] parts = msg.split("/"); if (parts[0].equals("hello")) { - callback.onSuccess(("Hello, " + parts[1] + "!").getBytes(Charsets.UTF_8)); + callback.onSuccess(JavaUtils.stringToBytes("Hello, " + parts[1] + "!")); } else if (parts[0].equals("return error")) { callback.onFailure(new RuntimeException("Returned: " + parts[1])); } else if (parts[0].equals("throw error")) { @@ -64,12 +71,18 @@ public void receive(TransportClient client, byte[] message, RpcResponseCallback } } + @Override + public void receive(TransportClient client, ByteBuffer message) { + oneWayMsgs.add(JavaUtils.bytesToString(message)); + } + @Override public StreamManager getStreamManager() { return new OneForOneStreamManager(); } }; TransportContext context = new TransportContext(conf, rpcHandler); server = context.createServer(); clientFactory = context.createClientFactory(); + oneWayMsgs = new ArrayList<>(); } @AfterClass @@ -93,8 +106,9 @@ private RpcResult sendRPC(String ... 
commands) throws Exception { RpcResponseCallback callback = new RpcResponseCallback() { @Override - public void onSuccess(byte[] message) { - res.successMessages.add(new String(message, Charsets.UTF_8)); + public void onSuccess(ByteBuffer message) { + String response = JavaUtils.bytesToString(message); + res.successMessages.add(response); sem.release(); } @@ -106,7 +120,7 @@ public void onFailure(Throwable e) { }; for (String command : commands) { - client.sendRpc(command.getBytes(Charsets.UTF_8), callback); + client.sendRpc(JavaUtils.stringToBytes(command), callback); } if (!sem.tryAcquire(commands.length, 5, TimeUnit.SECONDS)) { @@ -158,6 +172,27 @@ public void sendSuccessAndFailure() throws Exception { assertErrorsContain(res.errorMessages, Sets.newHashSet("Thrown: the", "Returned: !")); } + @Test + public void sendOneWayMessage() throws Exception { + final String message = "no reply"; + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + try { + client.send(JavaUtils.stringToBytes(message)); + assertEquals(0, client.getHandler().numOutstandingRequests()); + + // Make sure the message arrives. + long deadline = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS); + while (System.nanoTime() < deadline && oneWayMsgs.size() == 0) { + TimeUnit.MILLISECONDS.sleep(10); + } + + assertEquals(1, oneWayMsgs.size()); + assertEquals(message, oneWayMsgs.get(0)); + } finally { + client.close(); + } + } + private void assertErrorsContain(Set errors, Set contains) { assertEquals(contains.size(), errors.size()); diff --git a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java new file mode 100644 index 000000000000..9c49556927f0 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java @@ -0,0 +1,349 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
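Putting the client-side API changes together: sendRpc() and sendRpcSync() now take and return ByteBuffers, and send() delivers a one-way message that the server routes to RpcHandler#receive(TransportClient, ByteBuffer). A sketch with invented message contents; the "hello/..." format just mirrors the RpcIntegrationSuite handler above.

import java.nio.ByteBuffer;
import java.util.concurrent.TimeUnit;

import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.util.JavaUtils;

public class RpcClientExample {

  // Asynchronous RPC: both the request and the response are ByteBuffers now.
  static void ping(TransportClient client) {
    client.sendRpc(JavaUtils.stringToBytes("hello/world"), new RpcResponseCallback() {
      @Override
      public void onSuccess(ByteBuffer response) {
        System.out.println("server said: " + JavaUtils.bytesToString(response));
      }

      @Override
      public void onFailure(Throwable e) {
        e.printStackTrace();
      }
    });
  }

  // Synchronous variant (as used by the SASL tests), plus a fire-and-forget
  // message that never produces a response.
  static void other(TransportClient client) {
    ByteBuffer reply =
      client.sendRpcSync(JavaUtils.stringToBytes("hello/sync"), TimeUnit.SECONDS.toMillis(10));
    System.out.println(JavaUtils.bytesToString(reply));
    client.send(JavaUtils.stringToBytes("no reply expected"));
  }
}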
+ */ + +package org.apache.spark.network; + +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Random; +import java.util.concurrent.Executors; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; + +import com.google.common.io.Files; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import static org.junit.Assert.*; + +import org.apache.spark.network.buffer.FileSegmentManagedBuffer; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.StreamCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class StreamSuite { + private static final String[] STREAMS = { "largeBuffer", "smallBuffer", "emptyBuffer", "file" }; + + private static TransportServer server; + private static TransportClientFactory clientFactory; + private static File testFile; + private static File tempDir; + + private static ByteBuffer emptyBuffer; + private static ByteBuffer smallBuffer; + private static ByteBuffer largeBuffer; + + private static ByteBuffer createBuffer(int bufSize) { + ByteBuffer buf = ByteBuffer.allocate(bufSize); + for (int i = 0; i < bufSize; i ++) { + buf.put((byte) i); + } + buf.flip(); + return buf; + } + + @BeforeClass + public static void setUp() throws Exception { + tempDir = Files.createTempDir(); + emptyBuffer = createBuffer(0); + smallBuffer = createBuffer(100); + largeBuffer = createBuffer(100000); + + testFile = File.createTempFile("stream-test-file", "txt", tempDir); + FileOutputStream fp = new FileOutputStream(testFile); + try { + Random rnd = new Random(); + for (int i = 0; i < 512; i++) { + byte[] fileContent = new byte[1024]; + rnd.nextBytes(fileContent); + fp.write(fileContent); + } + } finally { + fp.close(); + } + + final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); + final StreamManager streamManager = new StreamManager() { + @Override + public ManagedBuffer getChunk(long streamId, int chunkIndex) { + throw new UnsupportedOperationException(); + } + + @Override + public ManagedBuffer openStream(String streamId) { + switch (streamId) { + case "largeBuffer": + return new NioManagedBuffer(largeBuffer); + case "smallBuffer": + return new NioManagedBuffer(smallBuffer); + case "emptyBuffer": + return new NioManagedBuffer(emptyBuffer); + case "file": + return new FileSegmentManagedBuffer(conf, testFile, 0, testFile.length()); + default: + throw new IllegalArgumentException("Invalid stream: " + streamId); + } + } + }; + RpcHandler handler = new RpcHandler() { + @Override + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { + throw new UnsupportedOperationException(); + } + + @Override + public StreamManager getStreamManager() { + return streamManager; + } + }; + TransportContext context = 
new TransportContext(conf, handler); + server = context.createServer(); + clientFactory = context.createClientFactory(); + } + + @AfterClass + public static void tearDown() { + server.close(); + clientFactory.close(); + if (tempDir != null) { + for (File f : tempDir.listFiles()) { + f.delete(); + } + tempDir.delete(); + } + } + + @Test + public void testZeroLengthStream() throws Throwable { + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + try { + StreamTask task = new StreamTask(client, "emptyBuffer", TimeUnit.SECONDS.toMillis(5)); + task.run(); + task.check(); + } finally { + client.close(); + } + } + + @Test + public void testSingleStream() throws Throwable { + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + try { + StreamTask task = new StreamTask(client, "largeBuffer", TimeUnit.SECONDS.toMillis(5)); + task.run(); + task.check(); + } finally { + client.close(); + } + } + + @Test + public void testMultipleStreams() throws Throwable { + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + try { + for (int i = 0; i < 20; i++) { + StreamTask task = new StreamTask(client, STREAMS[i % STREAMS.length], + TimeUnit.SECONDS.toMillis(5)); + task.run(); + task.check(); + } + } finally { + client.close(); + } + } + + @Test + public void testConcurrentStreams() throws Throwable { + ExecutorService executor = Executors.newFixedThreadPool(20); + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + + try { + List tasks = new ArrayList<>(); + for (int i = 0; i < 20; i++) { + StreamTask task = new StreamTask(client, STREAMS[i % STREAMS.length], + TimeUnit.SECONDS.toMillis(20)); + tasks.add(task); + executor.submit(task); + } + + executor.shutdown(); + assertTrue("Timed out waiting for tasks.", executor.awaitTermination(30, TimeUnit.SECONDS)); + for (StreamTask task : tasks) { + task.check(); + } + } finally { + executor.shutdownNow(); + client.close(); + } + } + + private static class StreamTask implements Runnable { + + private final TransportClient client; + private final String streamId; + private final long timeoutMs; + private Throwable error; + + StreamTask(TransportClient client, String streamId, long timeoutMs) { + this.client = client; + this.streamId = streamId; + this.timeoutMs = timeoutMs; + } + + @Override + public void run() { + ByteBuffer srcBuffer = null; + OutputStream out = null; + File outFile = null; + try { + ByteArrayOutputStream baos = null; + + switch (streamId) { + case "largeBuffer": + baos = new ByteArrayOutputStream(); + out = baos; + srcBuffer = largeBuffer; + break; + case "smallBuffer": + baos = new ByteArrayOutputStream(); + out = baos; + srcBuffer = smallBuffer; + break; + case "file": + outFile = File.createTempFile("data", ".tmp", tempDir); + out = new FileOutputStream(outFile); + break; + case "emptyBuffer": + baos = new ByteArrayOutputStream(); + out = baos; + srcBuffer = emptyBuffer; + break; + default: + throw new IllegalArgumentException(streamId); + } + + TestCallback callback = new TestCallback(out); + client.stream(streamId, callback); + waitForCompletion(callback); + + if (srcBuffer == null) { + assertTrue("File stream did not match.", Files.equal(testFile, outFile)); + } else { + ByteBuffer base; + synchronized (srcBuffer) { + base = srcBuffer.duplicate(); + } + byte[] result = baos.toByteArray(); + byte[] expected = new byte[base.remaining()]; + base.get(expected); + 
assertEquals(expected.length, result.length); + assertTrue("buffers don't match", Arrays.equals(expected, result)); + } + } catch (Throwable t) { + error = t; + } finally { + if (out != null) { + try { + out.close(); + } catch (Exception e) { + // ignore. + } + } + if (outFile != null) { + outFile.delete(); + } + } + } + + public void check() throws Throwable { + if (error != null) { + throw error; + } + } + + private void waitForCompletion(TestCallback callback) throws Exception { + long now = System.currentTimeMillis(); + long deadline = now + timeoutMs; + synchronized (callback) { + while (!callback.completed && now < deadline) { + callback.wait(deadline - now); + now = System.currentTimeMillis(); + } + } + assertTrue("Timed out waiting for stream.", callback.completed); + assertNull(callback.error); + } + + } + + private static class TestCallback implements StreamCallback { + + private final OutputStream out; + public volatile boolean completed; + public volatile Throwable error; + + TestCallback(OutputStream out) { + this.out = out; + this.completed = false; + } + + @Override + public void onData(String streamId, ByteBuffer buf) throws IOException { + byte[] tmp = new byte[buf.remaining()]; + buf.get(tmp); + out.write(tmp); + } + + @Override + public void onComplete(String streamId) throws IOException { + out.close(); + synchronized (this) { + completed = true; + notifyAll(); + } + } + + @Override + public void onFailure(String streamId, Throwable cause) { + error = cause; + synchronized (this) { + completed = true; + notifyAll(); + } + } + + } + +} diff --git a/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java b/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java index 38113a918f79..83c90f9eff2b 100644 --- a/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java +++ b/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java @@ -80,6 +80,11 @@ public Object convertToNetty() throws IOException { return underlying.convertToNetty(); } + @Override + public int hashCode() { + return underlying.hashCode(); + } + @Override public boolean equals(Object other) { if (other instanceof ManagedBuffer) { diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java index 35de5e57ccb9..dac7d4a5b0a0 100644 --- a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java +++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.HashSet; import java.util.Map; +import java.util.NoSuchElementException; import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; @@ -37,6 +38,7 @@ import org.apache.spark.network.server.NoOpRpcHandler; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.util.ConfigProvider; import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.MapConfigProvider; @@ -50,7 +52,7 @@ public class TransportClientFactorySuite { @Before public void setUp() { - conf = new TransportConf(new SystemPropertyConfigProvider()); + conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); RpcHandler rpcHandler = new NoOpRpcHandler(); context = new 
TransportContext(conf, rpcHandler); server1 = context.createServer(); @@ -74,7 +76,7 @@ private void testClientReuse(final int maxConnections, boolean concurrent) Map configMap = Maps.newHashMap(); configMap.put("spark.shuffle.io.numConnectionsPerPeer", Integer.toString(maxConnections)); - TransportConf conf = new TransportConf(new MapConfigProvider(configMap)); + TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(configMap)); RpcHandler rpcHandler = new NoOpRpcHandler(); TransportContext context = new TransportContext(conf, rpcHandler); @@ -177,4 +179,36 @@ public void closeBlockClientsWithFactory() throws IOException { assertFalse(c1.isActive()); assertFalse(c2.isActive()); } + + @Test + public void closeIdleConnectionForRequestTimeOut() throws IOException, InterruptedException { + TransportConf conf = new TransportConf("shuffle", new ConfigProvider() { + + @Override + public String get(String name) { + if ("spark.shuffle.io.connectionTimeout".equals(name)) { + // We should make sure there is enough time for us to observe the channel is active + return "1s"; + } + String value = System.getProperty(name); + if (value == null) { + throw new NoSuchElementException(name); + } + return value; + } + }); + TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true); + TransportClientFactory factory = context.createClientFactory(); + try { + TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + assertTrue(c1.isActive()); + long expiredTime = System.currentTimeMillis() + 10000; // 10 seconds + while (c1.isActive() && System.currentTimeMillis() < expiredTime) { + Thread.sleep(10); + } + assertFalse(c1.isActive()); + } finally { + factory.close(); + } + } } diff --git a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java index 17a03ebe88a9..128f7cba7435 100644 --- a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java @@ -17,6 +17,9 @@ package org.apache.spark.network; +import java.nio.ByteBuffer; + +import io.netty.channel.Channel; import io.netty.channel.local.LocalChannel; import org.junit.Test; @@ -26,18 +29,23 @@ import static org.mockito.Mockito.*; import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.ChunkReceivedCallback; import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.StreamCallback; import org.apache.spark.network.client.TransportResponseHandler; import org.apache.spark.network.protocol.ChunkFetchFailure; import org.apache.spark.network.protocol.ChunkFetchSuccess; import org.apache.spark.network.protocol.RpcFailure; import org.apache.spark.network.protocol.RpcResponse; import org.apache.spark.network.protocol.StreamChunkId; +import org.apache.spark.network.protocol.StreamFailure; +import org.apache.spark.network.protocol.StreamResponse; +import org.apache.spark.network.util.TransportFrameDecoder; public class TransportResponseHandlerSuite { @Test - public void handleSuccessfulFetch() { + public void handleSuccessfulFetch() throws Exception { StreamChunkId streamChunkId = new StreamChunkId(1, 0); TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); @@ -51,7 +59,7 @@ public void 
handleSuccessfulFetch() { } @Test - public void handleFailedFetch() { + public void handleFailedFetch() throws Exception { StreamChunkId streamChunkId = new StreamChunkId(1, 0); TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); @@ -64,7 +72,7 @@ public void handleFailedFetch() { } @Test - public void clearAllOutstandingRequests() { + public void clearAllOutstandingRequests() throws Exception { TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); handler.addFetchRequest(new StreamChunkId(1, 0), callback); @@ -83,23 +91,24 @@ public void clearAllOutstandingRequests() { } @Test - public void handleSuccessfulRPC() { + public void handleSuccessfulRPC() throws Exception { TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); RpcResponseCallback callback = mock(RpcResponseCallback.class); handler.addRpcRequest(12345, callback); assertEquals(1, handler.numOutstandingRequests()); - handler.handle(new RpcResponse(54321, new byte[7])); // should be ignored + // This response should be ignored. + handler.handle(new RpcResponse(54321, new NioManagedBuffer(ByteBuffer.allocate(7)))); assertEquals(1, handler.numOutstandingRequests()); - byte[] arr = new byte[10]; - handler.handle(new RpcResponse(12345, arr)); - verify(callback, times(1)).onSuccess(eq(arr)); + ByteBuffer resp = ByteBuffer.allocate(10); + handler.handle(new RpcResponse(12345, new NioManagedBuffer(resp))); + verify(callback, times(1)).onSuccess(eq(ByteBuffer.allocate(10))); assertEquals(0, handler.numOutstandingRequests()); } @Test - public void handleFailedRPC() { + public void handleFailedRPC() throws Exception { TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); RpcResponseCallback callback = mock(RpcResponseCallback.class); handler.addRpcRequest(12345, callback); @@ -112,4 +121,26 @@ public void handleFailedRPC() { verify(callback, times(1)).onFailure((Throwable) any()); assertEquals(0, handler.numOutstandingRequests()); } + + @Test + public void testActiveStreams() throws Exception { + Channel c = new LocalChannel(); + c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder()); + TransportResponseHandler handler = new TransportResponseHandler(c); + + StreamResponse response = new StreamResponse("stream", 1234L, null); + StreamCallback cb = mock(StreamCallback.class); + handler.addStreamCallback(cb); + assertEquals(1, handler.numOutstandingRequests()); + handler.handle(response); + assertEquals(1, handler.numOutstandingRequests()); + handler.deactivateStream(); + assertEquals(0, handler.numOutstandingRequests()); + + StreamFailure failure = new StreamFailure("stream", "uh-oh"); + handler.addStreamCallback(cb); + assertEquals(1, handler.numOutstandingRequests()); + handler.handle(failure); + assertEquals(0, handler.numOutstandingRequests()); + } } diff --git a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index be6632bb8cf4..751516b9d82a 100644 --- a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -17,11 +17,12 @@ package org.apache.spark.network.sasl; -import static com.google.common.base.Charsets.UTF_8; import static 
org.junit.Assert.*; import static org.mockito.Mockito.*; import java.io.File; +import java.lang.reflect.Method; +import java.nio.ByteBuffer; import java.util.Arrays; import java.util.List; import java.util.Random; @@ -56,6 +57,7 @@ import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.server.TransportServerBootstrap; import org.apache.spark.network.util.ByteArrayWritableChannel; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; @@ -122,36 +124,53 @@ public void testNonMatching() { } @Test - public void testSaslAuthentication() throws Exception { + public void testSaslAuthentication() throws Throwable { testBasicSasl(false); } @Test - public void testSaslEncryption() throws Exception { + public void testSaslEncryption() throws Throwable { testBasicSasl(true); } - private void testBasicSasl(boolean encrypt) throws Exception { + private void testBasicSasl(boolean encrypt) throws Throwable { RpcHandler rpcHandler = mock(RpcHandler.class); doAnswer(new Answer() { @Override public Void answer(InvocationOnMock invocation) { - byte[] message = (byte[]) invocation.getArguments()[1]; + ByteBuffer message = (ByteBuffer) invocation.getArguments()[1]; RpcResponseCallback cb = (RpcResponseCallback) invocation.getArguments()[2]; - assertEquals("Ping", new String(message, UTF_8)); - cb.onSuccess("Pong".getBytes(UTF_8)); + assertEquals("Ping", JavaUtils.bytesToString(message)); + cb.onSuccess(JavaUtils.stringToBytes("Pong")); return null; } }) .when(rpcHandler) - .receive(any(TransportClient.class), any(byte[].class), any(RpcResponseCallback.class)); + .receive(any(TransportClient.class), any(ByteBuffer.class), any(RpcResponseCallback.class)); SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false); try { - byte[] response = ctx.client.sendRpcSync("Ping".getBytes(UTF_8), TimeUnit.SECONDS.toMillis(10)); - assertEquals("Pong", new String(response, UTF_8)); + ByteBuffer response = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), + TimeUnit.SECONDS.toMillis(10)); + assertEquals("Pong", JavaUtils.bytesToString(response)); } finally { ctx.close(); + // There should be 2 terminated events; one for the client, one for the server. 
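+ // connectionTerminated() fires asynchronously as the channels shut down, so the verification
+ // below is polled under a deadline instead of being asserted immediately.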
+ Throwable error = null; + long deadline = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS); + while (deadline > System.nanoTime()) { + try { + verify(rpcHandler, times(2)).connectionTerminated(any(TransportClient.class)); + error = null; + break; + } catch (Throwable t) { + error = t; + TimeUnit.MILLISECONDS.sleep(10); + } + } + if (error != null) { + throw error; + } } } @@ -204,7 +223,7 @@ public void testEncryptedMessage() throws Exception { public void testEncryptedMessageChunking() throws Exception { File file = File.createTempFile("sasltest", ".txt"); try { - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); byte[] data = new byte[8 * 1024]; new Random().nextBytes(data); @@ -235,11 +254,11 @@ public void testFileRegionEncryption() throws Exception { final String blockSizeConf = "spark.network.sasl.maxEncryptedBlockSize"; System.setProperty(blockSizeConf, "1k"); - final AtomicReference response = new AtomicReference(); + final AtomicReference response = new AtomicReference<>(); final File file = File.createTempFile("sasltest", ".txt"); SaslTestCtx ctx = null; try { - final TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); StreamManager sm = mock(StreamManager.class); when(sm.getChunk(anyLong(), anyInt())).thenAnswer(new Answer() { @Override @@ -321,7 +340,8 @@ public void testDataEncryptionIsActuallyEnabled() throws Exception { SaslTestCtx ctx = null; try { ctx = new SaslTestCtx(mock(RpcHandler.class), true, true); - ctx.client.sendRpcSync("Ping".getBytes(UTF_8), TimeUnit.SECONDS.toMillis(10)); + ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), + TimeUnit.SECONDS.toMillis(10)); fail("Should have failed to send RPC to server."); } catch (Exception e) { assertFalse(e.getCause() instanceof TimeoutException); @@ -332,6 +352,31 @@ public void testDataEncryptionIsActuallyEnabled() throws Exception { } } + @Test + public void testRpcHandlerDelegate() throws Exception { + // Tests all delegates except for receive(), which is more complicated and already handled
+ RpcHandler handler = mock(RpcHandler.class); + RpcHandler saslHandler = new SaslRpcHandler(null, null, handler, null); + + saslHandler.getStreamManager(); + verify(handler).getStreamManager(); + + saslHandler.connectionTerminated(null); + verify(handler).connectionTerminated(any(TransportClient.class)); + + saslHandler.exceptionCaught(null, null); + verify(handler).exceptionCaught(any(Throwable.class), any(TransportClient.class)); + } + + @Test + public void testDelegates() throws Exception { + Method[] rpcHandlerMethods = RpcHandler.class.getDeclaredMethods(); + for (Method m : rpcHandlerMethods) { + SaslRpcHandler.class.getDeclaredMethod(m.getName(), m.getParameterTypes()); + } + } + private static class SaslTestCtx { final TransportClient client; @@ -347,7 +392,7 @@ private static class SaslTestCtx { boolean disableClientEncryption) throws Exception { - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); SecretKeyHolder keyHolder = mock(SecretKeyHolder.class); when(keyHolder.getSaslUser(anyString())).thenReturn("user"); diff --git a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java new file mode 100644 index 000000000000..d4de4a941d48 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java @@ -0,0 +1,258 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.util; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; +import java.util.concurrent.atomic.AtomicInteger; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import org.junit.AfterClass; +import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +public class TransportFrameDecoderSuite { + + private static Random RND = new Random(); + + @AfterClass + public static void cleanup() { + RND = null; + } + + @Test + public void testFrameDecoding() throws Exception { + TransportFrameDecoder decoder = new TransportFrameDecoder(); + ChannelHandlerContext ctx = mockChannelHandlerContext(); + ByteBuf data = createAndFeedFrames(100, decoder, ctx); + verifyAndCloseDecoder(decoder, ctx, data); + } + + @Test + public void testInterception() throws Exception { + final int interceptedReads = 3; + TransportFrameDecoder decoder = new TransportFrameDecoder(); + TransportFrameDecoder.Interceptor interceptor = spy(new MockInterceptor(interceptedReads)); + ChannelHandlerContext ctx = mockChannelHandlerContext(); + + byte[] data = new byte[8]; + ByteBuf len = Unpooled.copyLong(8 + data.length); + ByteBuf dataBuf = Unpooled.wrappedBuffer(data); + + try { + decoder.setInterceptor(interceptor); + for (int i = 0; i < interceptedReads; i++) { + decoder.channelRead(ctx, dataBuf); + assertEquals(0, dataBuf.refCnt()); + dataBuf = Unpooled.wrappedBuffer(data); + } + decoder.channelRead(ctx, len); + decoder.channelRead(ctx, dataBuf); + verify(interceptor, times(interceptedReads)).handle(any(ByteBuf.class)); + verify(ctx).fireChannelRead(any(ByteBuffer.class)); + assertEquals(0, len.refCnt()); + assertEquals(0, dataBuf.refCnt()); + } finally { + release(len); + release(dataBuf); + } + } + + @Test + public void testRetainedFrames() throws Exception { + TransportFrameDecoder decoder = new TransportFrameDecoder(); + + final AtomicInteger count = new AtomicInteger(); + final List retained = new ArrayList<>(); + + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + when(ctx.fireChannelRead(any())).thenAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock in) { + // Retain a few frames but not others. + ByteBuf buf = (ByteBuf) in.getArguments()[0]; + if (count.incrementAndGet() % 2 == 0) { + retained.add(buf); + } else { + buf.release(); + } + return null; + } + }); + + ByteBuf data = createAndFeedFrames(100, decoder, ctx); + try { + // Verify all retained buffers are readable. 
+ for (ByteBuf b : retained) { + byte[] tmp = new byte[b.readableBytes()]; + b.readBytes(tmp); + b.release(); + } + verifyAndCloseDecoder(decoder, ctx, data); + } finally { + for (ByteBuf b : retained) { + release(b); + } + } + } + + @Test + public void testSplitLengthField() throws Exception { + byte[] frame = new byte[1024 * (RND.nextInt(31) + 1)]; + ByteBuf buf = Unpooled.buffer(frame.length + 8); + buf.writeLong(frame.length + 8); + buf.writeBytes(frame); + + TransportFrameDecoder decoder = new TransportFrameDecoder(); + ChannelHandlerContext ctx = mockChannelHandlerContext(); + try { + decoder.channelRead(ctx, buf.readSlice(RND.nextInt(7)).retain()); + verify(ctx, never()).fireChannelRead(any(ByteBuf.class)); + decoder.channelRead(ctx, buf); + verify(ctx).fireChannelRead(any(ByteBuf.class)); + assertEquals(0, buf.refCnt()); + } finally { + decoder.channelInactive(ctx); + release(buf); + } + } + + @Test(expected = IllegalArgumentException.class) + public void testNegativeFrameSize() throws Exception { + testInvalidFrame(-1); + } + + @Test(expected = IllegalArgumentException.class) + public void testEmptyFrame() throws Exception { + // 8 because frame size includes the frame length. + testInvalidFrame(8); + } + + @Test(expected = IllegalArgumentException.class) + public void testLargeFrame() throws Exception { + // Frame length includes the frame size field, so need to add a few more bytes. + testInvalidFrame(Integer.MAX_VALUE + 9); + } + + /** + * Creates a number of randomly sized frames and feed them to the given decoder, verifying + * that the frames were read. + */ + private ByteBuf createAndFeedFrames( + int frameCount, + TransportFrameDecoder decoder, + ChannelHandlerContext ctx) throws Exception { + ByteBuf data = Unpooled.buffer(); + for (int i = 0; i < frameCount; i++) { + byte[] frame = new byte[1024 * (RND.nextInt(31) + 1)]; + data.writeLong(frame.length + 8); + data.writeBytes(frame); + } + + try { + while (data.isReadable()) { + int size = RND.nextInt(4 * 1024) + 256; + decoder.channelRead(ctx, data.readSlice(Math.min(data.readableBytes(), size)).retain()); + } + + verify(ctx, times(frameCount)).fireChannelRead(any(ByteBuf.class)); + } catch (Exception e) { + release(data); + throw e; + } + return data; + } + + private void verifyAndCloseDecoder( + TransportFrameDecoder decoder, + ChannelHandlerContext ctx, + ByteBuf data) throws Exception { + try { + decoder.channelInactive(ctx); + assertTrue("There shouldn't be dangling references to the data.", data.release()); + } finally { + release(data); + } + } + + private void testInvalidFrame(long size) throws Exception { + TransportFrameDecoder decoder = new TransportFrameDecoder(); + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + ByteBuf frame = Unpooled.copyLong(size); + try { + decoder.channelRead(ctx, frame); + } finally { + release(frame); + } + } + + private ChannelHandlerContext mockChannelHandlerContext() { + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + when(ctx.fireChannelRead(any())).thenAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock in) { + ByteBuf buf = (ByteBuf) in.getArguments()[0]; + buf.release(); + return null; + } + }); + return ctx; + } + + private void release(ByteBuf buf) { + if (buf.refCnt() > 0) { + buf.release(buf.refCnt()); + } + } + + private static class MockInterceptor implements TransportFrameDecoder.Interceptor { + + private int remainingReads; + + MockInterceptor(int readCount) { + this.remainingReads = readCount; + } + + @Override + 
public boolean handle(ByteBuf data) throws Exception { + data.readerIndex(data.readerIndex() + data.readableBytes()); + assertFalse(data.isReadable()); + remainingReads -= 1; + return remainingReads != 0; + } + + @Override + public void exceptionCaught(Throwable cause) throws Exception { + + } + + @Override + public void channelInactive() throws Exception { + + } + + } + +} diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index 532463e96fbb..70ba5cb1995b 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml @@ -43,6 +43,22 @@ ${project.version} + + org.fusesource.leveldbjni + leveldbjni-all + 1.8 + + + + com.fasterxml.jackson.core + jackson-databind + + + + com.fasterxml.jackson.core + jackson-annotations + + org.slf4j @@ -63,14 +79,8 @@ test - junit - junit - test - - - com.novocode - junit-interface - test + org.apache.spark + spark-test-tags_${scala.binary.version} log4j diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index db9dc4f17cee..f22187a01db0 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -17,11 +17,13 @@ package org.apache.spark.network.shuffle; +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; import java.util.List; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Lists; -import org.apache.spark.network.util.TransportConf; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -31,10 +33,10 @@ import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; -import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; -import org.apache.spark.network.shuffle.protocol.OpenBlocks; -import org.apache.spark.network.shuffle.protocol.RegisterExecutor; -import org.apache.spark.network.shuffle.protocol.StreamHandle; +import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId; +import org.apache.spark.network.shuffle.protocol.*; +import org.apache.spark.network.util.TransportConf; + /** * RPC Handler for a server which can serve shuffle blocks from outside of an Executor process. @@ -46,16 +48,18 @@ public class ExternalShuffleBlockHandler extends RpcHandler { private final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockHandler.class); - private final ExternalShuffleBlockResolver blockManager; + @VisibleForTesting + final ExternalShuffleBlockResolver blockManager; private final OneForOneStreamManager streamManager; - public ExternalShuffleBlockHandler(TransportConf conf) { - this(new OneForOneStreamManager(), new ExternalShuffleBlockResolver(conf)); + public ExternalShuffleBlockHandler(TransportConf conf, File registeredExecutorFile) throws IOException { + this(new OneForOneStreamManager(), + new ExternalShuffleBlockResolver(conf, registeredExecutorFile)); } /** Enables mocking out the StreamManager and BlockManager. 
*/ @VisibleForTesting - ExternalShuffleBlockHandler( + public ExternalShuffleBlockHandler( OneForOneStreamManager streamManager, ExternalShuffleBlockResolver blockManager) { this.streamManager = streamManager; @@ -63,8 +67,8 @@ public ExternalShuffleBlockHandler(TransportConf conf) { } @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { - BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteArray(message); + public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { + BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteBuffer(message); handleMessage(msgObj, client, callback); } @@ -74,19 +78,21 @@ protected void handleMessage( RpcResponseCallback callback) { if (msgObj instanceof OpenBlocks) { OpenBlocks msg = (OpenBlocks) msgObj; - List blocks = Lists.newArrayList(); + checkAuth(client, msg.appId); + List blocks = Lists.newArrayList(); for (String blockId : msg.blockIds) { blocks.add(blockManager.getBlockData(msg.appId, msg.execId, blockId)); } - long streamId = streamManager.registerStream(blocks.iterator()); + long streamId = streamManager.registerStream(client.getClientId(), blocks.iterator()); logger.trace("Registered streamId {} with {} buffers", streamId, msg.blockIds.length); - callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteArray()); + callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteBuffer()); } else if (msgObj instanceof RegisterExecutor) { RegisterExecutor msg = (RegisterExecutor) msgObj; + checkAuth(client, msg.appId); blockManager.registerExecutor(msg.appId, msg.execId, msg.executorInfo); - callback.onSuccess(new byte[0]); + callback.onSuccess(ByteBuffer.wrap(new byte[0])); } else { throw new UnsupportedOperationException("Unexpected message: " + msgObj); @@ -105,4 +111,30 @@ public StreamManager getStreamManager() { public void applicationRemoved(String appId, boolean cleanupLocalDirs) { blockManager.applicationRemoved(appId, cleanupLocalDirs); } + + /** + * Register an (application, executor) with the given shuffle info. + * + * The "re-" is meant to highlight the intended use of this method -- when this service is + * restarted, this is used to restore the state of executors from before the restart. 
Normal + * registration will happen via a message handled in receive() + * + * @param appExecId + * @param executorInfo + */ + public void reregisterExecutor(AppExecId appExecId, ExecutorShuffleInfo executorInfo) { + blockManager.registerExecutor(appExecId.appId, appExecId.execId, executorInfo); + } + + public void close() { + blockManager.close(); + } + + private void checkAuth(TransportClient client, String appId) { + if (client.getClientId() != null && !client.getClientId().equals(appId)) { + throw new SecurityException(String.format( + "Client for %s not authorized for application %s.", client.getClientId(), appId)); + } + } + } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 022ed88a1648..fe933ed650ca 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -17,19 +17,24 @@ package org.apache.spark.network.shuffle; -import java.io.DataInputStream; -import java.io.File; -import java.io.FileInputStream; -import java.io.IOException; -import java.util.Iterator; -import java.util.Map; +import java.io.*; +import java.util.*; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.Executor; import java.util.concurrent.Executors; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Charsets; import com.google.common.base.Objects; import com.google.common.collect.Maps; +import org.fusesource.leveldbjni.JniDBFactory; +import org.fusesource.leveldbjni.internal.NativeDB; +import org.iq80.leveldb.DB; +import org.iq80.leveldb.DBIterator; +import org.iq80.leveldb.Options; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -45,32 +50,95 @@ * of Executors. Each Executor must register its own configuration about where it stores its files * (local dirs) and how (shuffle manager). The logic for retrieval of individual files is replicated * from Spark's FileShuffleBlockResolver and IndexShuffleBlockResolver. - * - * Executors with shuffle file consolidation are not currently supported, as the index is stored in - * the Executor's memory, unlike the IndexShuffleBlockResolver. */ public class ExternalShuffleBlockResolver { private static final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockResolver.class); + private static final ObjectMapper mapper = new ObjectMapper(); + /** + * This is a common prefix to the key for each app registration we stick in leveldb, so they + * are easy to find, since leveldb lets you search based on prefix. + */ + private static final String APP_KEY_PREFIX = "AppExecShuffleInfo"; + private static final StoreVersion CURRENT_VERSION = new StoreVersion(1, 0); + // Map containing all registered executors' metadata. - private final ConcurrentMap executors; + @VisibleForTesting + final ConcurrentMap executors; // Single-threaded Java executor used to perform expensive recursive directory deletion.
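// (Presumably this keeps applicationRemoved() from blocking on filesystem I/O; the cleanup
// tests swap in a same-thread executor when they need deletion to happen synchronously.)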
private final Executor directoryCleaner; private final TransportConf conf; - public ExternalShuffleBlockResolver(TransportConf conf) { - this(conf, Executors.newSingleThreadExecutor( + @VisibleForTesting + final File registeredExecutorFile; + @VisibleForTesting + final DB db; + + public ExternalShuffleBlockResolver(TransportConf conf, File registeredExecutorFile) + throws IOException { + this(conf, registeredExecutorFile, Executors.newSingleThreadExecutor( // Add `spark` prefix because it will run in NM in Yarn mode. NettyUtils.createThreadFactory("spark-shuffle-directory-cleaner"))); } // Allows tests to have more control over when directories are cleaned up. @VisibleForTesting - ExternalShuffleBlockResolver(TransportConf conf, Executor directoryCleaner) { + ExternalShuffleBlockResolver( + TransportConf conf, + File registeredExecutorFile, + Executor directoryCleaner) throws IOException { this.conf = conf; - this.executors = Maps.newConcurrentMap(); + this.registeredExecutorFile = registeredExecutorFile; + if (registeredExecutorFile != null) { + Options options = new Options(); + options.createIfMissing(false); + options.logger(new LevelDBLogger()); + DB tmpDb; + try { + tmpDb = JniDBFactory.factory.open(registeredExecutorFile, options); + } catch (NativeDB.DBException e) { + if (e.isNotFound() || e.getMessage().contains(" does not exist ")) { + logger.info("Creating state database at " + registeredExecutorFile); + options.createIfMissing(true); + try { + tmpDb = JniDBFactory.factory.open(registeredExecutorFile, options); + } catch (NativeDB.DBException dbExc) { + throw new IOException("Unable to create state store", dbExc); + } + } else { + // the leveldb file seems to be corrupt somehow. Lets just blow it away and create a new + // one, so we can keep processing new apps + logger.error("error opening leveldb file {}. 
Creating new file, will not be able to " + + "recover state for existing applications", registeredExecutorFile, e); + if (registeredExecutorFile.isDirectory()) { + for (File f : registeredExecutorFile.listFiles()) { + if (!f.delete()) { + logger.warn("error deleting {}", f.getPath()); + } + } + } + if (!registeredExecutorFile.delete()) { + logger.warn("error deleting {}", registeredExecutorFile.getPath()); + } + options.createIfMissing(true); + try { + tmpDb = JniDBFactory.factory.open(registeredExecutorFile, options); + } catch (NativeDB.DBException dbExc) { + throw new IOException("Unable to create state store", dbExc); + } + + } + } + // if there is a version mismatch, we throw an exception, which means the service is unusable + checkVersion(tmpDb); + executors = reloadRegisteredExecutors(tmpDb); + db = tmpDb; + } else { + db = null; + executors = Maps.newConcurrentMap(); + } this.directoryCleaner = directoryCleaner; } @@ -81,6 +149,15 @@ public void registerExecutor( ExecutorShuffleInfo executorInfo) { AppExecId fullId = new AppExecId(appId, execId); logger.info("Registered executor {} with {}", fullId, executorInfo); + try { + if (db != null) { + byte[] key = dbAppExecKey(fullId); + byte[] value = mapper.writeValueAsString(executorInfo).getBytes(Charsets.UTF_8); + db.put(key, value); + } + } catch (Exception e) { + logger.error("Error saving registered executors", e); + } executors.put(fullId, executorInfo); } @@ -106,11 +183,10 @@ public ManagedBuffer getBlockData(String appId, String execId, String blockId) { String.format("Executor is not registered (appId=%s, execId=%s)", appId, execId)); } - if ("org.apache.spark.shuffle.hash.HashShuffleManager".equals(executor.shuffleManager)) { - return getHashBasedShuffleBlockData(executor, blockId); - } else if ("org.apache.spark.shuffle.sort.SortShuffleManager".equals(executor.shuffleManager) - || "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager".equals(executor.shuffleManager)) { + if ("sort".equals(executor.shuffleManager) || "tungsten-sort".equals(executor.shuffleManager)) { return getSortBasedShuffleBlockData(executor, shuffleId, mapId, reduceId); + } else if ("hash".equals(executor.shuffleManager)) { + return getHashBasedShuffleBlockData(executor, blockId); } else { throw new UnsupportedOperationException( "Unsupported shuffle manager: " + executor.shuffleManager); @@ -136,6 +212,13 @@ public void applicationRemoved(String appId, boolean cleanupLocalDirs) { // Only touch executors associated with the appId that was removed. if (appId.equals(fullId.appId)) { it.remove(); + if (db != null) { + try { + db.delete(dbAppExecKey(fullId)); + } catch (IOException e) { + logger.error("Error deleting {} from executor state db", appId, e); + } + } if (cleanupLocalDirs) { logger.info("Cleaning up executor {}'s {} local dirs", fullId, executor.localDirs.length); @@ -171,7 +254,6 @@ private void deleteExecutorDirs(String[] dirs) { * Hash-based shuffle data is simply stored as one file per block. * This logic is from FileShuffleBlockResolver. 
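* For example, block "shuffle_1_2_0" resolves directly to a file of that name under one of the
* executor's local dirs (via getFile() below).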
*/ - // TODO: Support consolidated hash shuffle files private ManagedBuffer getHashBasedShuffleBlockData(ExecutorShuffleInfo executor, String blockId) { File shuffleFile = getFile(executor.localDirs, executor.subDirsPerLocalDir, blockId); return new FileSegmentManagedBuffer(conf, shuffleFile, 0, shuffleFile.length()); @@ -220,12 +302,23 @@ static File getFile(String[] localDirs, int subDirsPerLocalDir, String filename) return new File(new File(localDir, String.format("%02x", subDirId)), filename); } + void close() { + if (db != null) { + try { + db.close(); + } catch (IOException e) { + logger.error("Exception closing leveldb with registered executors", e); + } + } + } + /** Simply encodes an executor's full ID, which is appId + execId. */ - private static class AppExecId { - final String appId; - final String execId; + public static class AppExecId { + public final String appId; + public final String execId; - private AppExecId(String appId, String execId) { + @JsonCreator + public AppExecId(@JsonProperty("appId") String appId, @JsonProperty("execId") String execId) { this.appId = appId; this.execId = execId; } @@ -252,4 +345,105 @@ public String toString() { .toString(); } } + + private static byte[] dbAppExecKey(AppExecId appExecId) throws IOException { + // we stick a common prefix on all the keys so we can find them in the DB + String appExecJson = mapper.writeValueAsString(appExecId); + String key = (APP_KEY_PREFIX + ";" + appExecJson); + return key.getBytes(Charsets.UTF_8); + } + + private static AppExecId parseDbAppExecKey(String s) throws IOException { + if (!s.startsWith(APP_KEY_PREFIX)) { + throw new IllegalArgumentException("expected a string starting with " + APP_KEY_PREFIX); + } + String json = s.substring(APP_KEY_PREFIX.length() + 1); + AppExecId parsed = mapper.readValue(json, AppExecId.class); + return parsed; + } + + @VisibleForTesting + static ConcurrentMap reloadRegisteredExecutors(DB db) + throws IOException { + ConcurrentMap registeredExecutors = Maps.newConcurrentMap(); + if (db != null) { + DBIterator itr = db.iterator(); + itr.seek(APP_KEY_PREFIX.getBytes(Charsets.UTF_8)); + while (itr.hasNext()) { + Map.Entry e = itr.next(); + String key = new String(e.getKey(), Charsets.UTF_8); + if (!key.startsWith(APP_KEY_PREFIX)) { + break; + } + AppExecId id = parseDbAppExecKey(key); + ExecutorShuffleInfo shuffleInfo = mapper.readValue(e.getValue(), ExecutorShuffleInfo.class); + registeredExecutors.put(id, shuffleInfo); + } + } + return registeredExecutors; + } + + private static class LevelDBLogger implements org.iq80.leveldb.Logger { + private static final Logger LOG = LoggerFactory.getLogger(LevelDBLogger.class); + + @Override + public void log(String message) { + LOG.info(message); + } + } + + /** + * Simple major.minor versioning scheme. Any incompatible changes should be across major + * versions. Minor version differences are allowed -- meaning we should be able to read + * dbs that are either earlier *or* later on the minor version. 
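+ * For example, a version 1.0 reader can open a 1.1 store, but opening a 2.x store fails with
+ * an IOException in checkVersion() below.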
+ */ + private static void checkVersion(DB db) throws IOException { + byte[] bytes = db.get(StoreVersion.KEY); + if (bytes == null) { + storeVersion(db); + } else { + StoreVersion version = mapper.readValue(bytes, StoreVersion.class); + if (version.major != CURRENT_VERSION.major) { + throw new IOException("cannot read state DB with version " + version + ", incompatible " + + "with current version " + CURRENT_VERSION); + } + storeVersion(db); + } + } + + private static void storeVersion(DB db) throws IOException { + db.put(StoreVersion.KEY, mapper.writeValueAsBytes(CURRENT_VERSION)); + } + + + public static class StoreVersion { + + static final byte[] KEY = "StoreVersion".getBytes(Charsets.UTF_8); + + public final int major; + public final int minor; + + @JsonCreator public StoreVersion(@JsonProperty("major") int major, @JsonProperty("minor") int minor) { + this.major = major; + this.minor = minor; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + StoreVersion that = (StoreVersion) o; + + return major == that.major && minor == that.minor; + } + + @Override + public int hashCode() { + int result = major; + result = 31 * result + minor; + return result; + } + } + } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index ea6d248d66be..58ca87d9d3b1 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -18,6 +18,7 @@ package org.apache.spark.network.shuffle; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.List; import com.google.common.base.Preconditions; @@ -78,7 +79,7 @@ protected void checkInit() { @Override public void init(String appId) { this.appId = appId; - TransportContext context = new TransportContext(conf, new NoOpRpcHandler()); + TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true); List bootstraps = Lists.newArrayList(); if (saslEnabled) { bootstraps.add(new SaslClientBootstrap(conf, appId, secretKeyHolder, saslEncryptionEnabled)); @@ -137,9 +138,13 @@ public void registerWithShuffleServer( String execId, ExecutorShuffleInfo executorInfo) throws IOException { checkInit(); - TransportClient client = clientFactory.createClient(host, port); - byte[] registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteArray(); - client.sendRpcSync(registerMessage, 5000 /* timeoutMs */); + TransportClient client = clientFactory.createUnmanagedClient(host, port); + try { + ByteBuffer registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteBuffer(); + client.sendRpcSync(registerMessage, 5000 /* timeoutMs */); + } finally { + client.close(); + } } @Override diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index e653f5cb147e..1b2ddbf1ed91 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -17,6 +17,7 @@ package org.apache.spark.network.shuffle; +import java.nio.ByteBuffer; import java.util.Arrays; import org.slf4j.Logger; @@ -89,11 +90,11 
@@ public void start() { throw new IllegalArgumentException("Zero-sized blockIds array"); } - client.sendRpc(openMessage.toByteArray(), new RpcResponseCallback() { + client.sendRpc(openMessage.toByteBuffer(), new RpcResponseCallback() { @Override - public void onSuccess(byte[] response) { + public void onSuccess(ByteBuffer response) { try { - streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response); + streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response); logger.trace("Successfully opened blocks {}, preparing to fetch chunks.", streamHandle); // Immediately request all chunks -- we expect that the total size of the request is diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java index 7543b6be4f2a..675820308bd4 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java @@ -18,6 +18,7 @@ package org.apache.spark.network.shuffle.mesos; import java.io.IOException; +import java.nio.ByteBuffer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -54,11 +55,11 @@ public MesosExternalShuffleClient( public void registerDriverWithShuffleService(String host, int port) throws IOException { checkInit(); - byte[] registerDriver = new RegisterDriver(appId).toByteArray(); + ByteBuffer registerDriver = new RegisterDriver(appId).toByteBuffer(); TransportClient client = clientFactory.createClient(host, port); client.sendRpc(registerDriver, new RpcResponseCallback() { @Override - public void onSuccess(byte[] response) { + public void onSuccess(ByteBuffer response) { logger.info("Successfully registered app " + appId + " with external shuffle service."); } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java index fcb52363e632..7fbe3384b4d4 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java @@ -17,6 +17,8 @@ package org.apache.spark.network.shuffle.protocol; +import java.nio.ByteBuffer; + import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; @@ -53,7 +55,7 @@ private Type(int id) { // NB: Java does not support static methods in interfaces, so we must put this in a static class. public static class Decoder { /** Deserializes the 'type' byte followed by the message itself. */ - public static BlockTransferMessage fromByteArray(byte[] msg) { + public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) { ByteBuf buf = Unpooled.wrappedBuffer(msg); byte type = buf.readByte(); switch (type) { @@ -68,12 +70,12 @@ public static BlockTransferMessage fromByteArray(byte[] msg) { } /** Serializes the 'type' byte followed by the message itself. 
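* The layout mirrors what Decoder.fromByteBuffer() reads back: a single type id byte
* (type().id) immediately followed by the encoded message body.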
*/ - public byte[] toByteArray() { + public ByteBuffer toByteBuffer() { // Allow room for encoded message, plus the type byte ByteBuf buf = Unpooled.buffer(encodedLength() + 1); buf.writeByte(type().id); encode(buf); assert buf.writableBytes() == 0 : "Writable bytes remain: " + buf.writableBytes(); - return buf.array(); + return buf.nioBuffer(); } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java index cadc8e8369c6..102d4efb8bf3 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java @@ -19,6 +19,8 @@ import java.util.Arrays; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; @@ -34,7 +36,11 @@ public class ExecutorShuffleInfo implements Encodable { /** Shuffle manager (SortShuffleManager or HashShuffleManager) that the executor is using. */ public final String shuffleManager; - public ExecutorShuffleInfo(String[] localDirs, int subDirsPerLocalDir, String shuffleManager) { + @JsonCreator + public ExecutorShuffleInfo( + @JsonProperty("localDirs") String[] localDirs, + @JsonProperty("subDirsPerLocalDir") int subDirsPerLocalDir, + @JsonProperty("shuffleManager") String shuffleManager) { this.localDirs = localDirs; this.subDirsPerLocalDir = subDirsPerLocalDir; this.shuffleManager = shuffleManager; diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java index cca8b17c4f12..167ef3310422 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java @@ -27,7 +27,7 @@ /** * Initial registration message between an executor and its local shuffle server. - * Returns nothing (empty bye array). + * Returns nothing (empty byte array). */ public class RegisterExecutor extends BlockTransferMessage { public final String appId; diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java index 1c28fc1dff24..94a61d6caadc 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java @@ -23,6 +23,9 @@ import org.apache.spark.network.protocol.Encoders; import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + /** * A message sent from the driver to register with the MesosExternalShuffleService. 
*/ diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index 382f613ecbb1..0ea631ea14d7 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -18,7 +18,9 @@ package org.apache.spark.network.sasl; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.Arrays; +import java.util.concurrent.atomic.AtomicReference; import com.google.common.collect.Lists; import org.junit.After; @@ -27,9 +29,12 @@ import org.junit.Test; import static org.junit.Assert.*; +import static org.mockito.Mockito.*; import org.apache.spark.network.TestUtils; import org.apache.spark.network.TransportContext; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.ChunkReceivedCallback; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientBootstrap; @@ -39,44 +44,45 @@ import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.server.TransportServerBootstrap; +import org.apache.spark.network.shuffle.BlockFetchingListener; import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler; +import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver; +import org.apache.spark.network.shuffle.OneForOneBlockFetcher; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; +import org.apache.spark.network.shuffle.protocol.OpenBlocks; +import org.apache.spark.network.shuffle.protocol.RegisterExecutor; +import org.apache.spark.network.shuffle.protocol.StreamHandle; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; public class SaslIntegrationSuite { - static ExternalShuffleBlockHandler handler; + + // Use a long timeout to account for slow / overloaded build machines. In the normal case, + // tests should finish way before the timeout expires. + private static final long TIMEOUT_MS = 10_000; + static TransportServer server; static TransportConf conf; static TransportContext context; + static SecretKeyHolder secretKeyHolder; TransportClientFactory clientFactory; - /** Provides a secret key holder which always returns the given secret key. 
*/ - static class TestSecretKeyHolder implements SecretKeyHolder { - - private final String secretKey; - - TestSecretKeyHolder(String secretKey) { - this.secretKey = secretKey; - } - - @Override - public String getSaslUser(String appId) { - return "user"; - } - @Override - public String getSecretKey(String appId) { - return secretKey; - } - } - - @BeforeClass public static void beforeAll() throws IOException { - SecretKeyHolder secretKeyHolder = new TestSecretKeyHolder("good-key"); - conf = new TransportConf(new SystemPropertyConfigProvider()); + conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); context = new TransportContext(conf, new TestRpcHandler()); + secretKeyHolder = mock(SecretKeyHolder.class); + when(secretKeyHolder.getSaslUser(eq("app-1"))).thenReturn("app-1"); + when(secretKeyHolder.getSecretKey(eq("app-1"))).thenReturn("app-1"); + when(secretKeyHolder.getSaslUser(eq("app-2"))).thenReturn("app-2"); + when(secretKeyHolder.getSecretKey(eq("app-2"))).thenReturn("app-2"); + when(secretKeyHolder.getSaslUser(anyString())).thenReturn("other-app"); + when(secretKeyHolder.getSecretKey(anyString())).thenReturn("correct-password"); + TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, secretKeyHolder); server = context.createServer(Arrays.asList(bootstrap)); } @@ -99,23 +105,27 @@ public void afterEach() { public void testGoodClient() throws IOException { clientFactory = context.createClientFactory( Lists.newArrayList( - new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("good-key")))); + new SaslClientBootstrap(conf, "app-1", secretKeyHolder))); TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); String msg = "Hello, World!"; - byte[] resp = client.sendRpcSync(msg.getBytes(), 1000); - assertEquals(msg, new String(resp)); // our rpc handler should just return the given msg + ByteBuffer resp = client.sendRpcSync(JavaUtils.stringToBytes(msg), TIMEOUT_MS); + assertEquals(msg, JavaUtils.bytesToString(resp)); } @Test public void testBadClient() { + SecretKeyHolder badKeyHolder = mock(SecretKeyHolder.class); + when(badKeyHolder.getSaslUser(anyString())).thenReturn("other-app"); + when(badKeyHolder.getSecretKey(anyString())).thenReturn("wrong-password"); clientFactory = context.createClientFactory( Lists.newArrayList( - new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("bad-key")))); + new SaslClientBootstrap(conf, "unknown-app", badKeyHolder))); try { // Bootstrap should fail on startup. clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + fail("Connection should have failed."); } catch (Exception e) { assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response")); } @@ -128,7 +138,7 @@ public void testNoSaslClient() throws IOException { TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); try { - client.sendRpcSync(new byte[13], 1000); + client.sendRpcSync(ByteBuffer.allocate(13), TIMEOUT_MS); fail("Should have failed"); } catch (Exception e) { assertTrue(e.getMessage(), e.getMessage().contains("Expected SaslMessage")); @@ -136,7 +146,7 @@ public void testNoSaslClient() throws IOException { try { // Guessing the right tag byte doesn't magically get you in... 
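// (0xEA is presumably SaslMessage's tag byte, so the frame passes the tag check and then
// fails while decoding the rest of the payload, hence the IndexOutOfBoundsException below.)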
- client.sendRpcSync(new byte[] { (byte) 0xEA }, 1000); + client.sendRpcSync(ByteBuffer.wrap(new byte[] { (byte) 0xEA }), TIMEOUT_MS); fail("Should have failed"); } catch (Exception e) { assertTrue(e.getMessage(), e.getMessage().contains("java.lang.IndexOutOfBoundsException")); @@ -149,7 +159,7 @@ public void testNoSaslServer() { TransportContext context = new TransportContext(conf, handler); clientFactory = context.createClientFactory( Lists.newArrayList( - new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("key")))); + new SaslClientBootstrap(conf, "app-1", secretKeyHolder))); TransportServer server = context.createServer(); try { clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); @@ -160,10 +170,113 @@ public void testNoSaslServer() { } } + /** + * This test is not actually testing SASL behavior, but testing that the shuffle service + * performs correct authorization checks based on the SASL authentication data. + */ + @Test + public void testAppIsolation() throws Exception { + // Start a new server with the correct RPC handler to serve block data. + ExternalShuffleBlockResolver blockResolver = mock(ExternalShuffleBlockResolver.class); + ExternalShuffleBlockHandler blockHandler = new ExternalShuffleBlockHandler( + new OneForOneStreamManager(), blockResolver); + TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, secretKeyHolder); + TransportContext blockServerContext = new TransportContext(conf, blockHandler); + TransportServer blockServer = blockServerContext.createServer(Arrays.asList(bootstrap)); + + TransportClient client1 = null; + TransportClient client2 = null; + TransportClientFactory clientFactory2 = null; + try { + // Create a client, and make a request to fetch blocks from a different app. + clientFactory = blockServerContext.createClientFactory( + Lists.newArrayList( + new SaslClientBootstrap(conf, "app-1", secretKeyHolder))); + client1 = clientFactory.createClient(TestUtils.getLocalHost(), + blockServer.getPort()); + + final AtomicReference exception = new AtomicReference<>(); + + BlockFetchingListener listener = new BlockFetchingListener() { + @Override + public synchronized void onBlockFetchSuccess(String blockId, ManagedBuffer data) { + notifyAll(); + } + + @Override + public synchronized void onBlockFetchFailure(String blockId, Throwable t) { + exception.set(t); + notifyAll(); + } + }; + + String[] blockIds = new String[] { "shuffle_2_3_4", "shuffle_6_7_8" }; + OneForOneBlockFetcher fetcher = new OneForOneBlockFetcher(client1, "app-2", "0", + blockIds, listener); + synchronized (listener) { + fetcher.start(); + listener.wait(); + } + checkSecurityException(exception.get()); + + // Register an executor so that the next steps work. + ExecutorShuffleInfo executorInfo = new ExecutorShuffleInfo( + new String[] { System.getProperty("java.io.tmpdir") }, 1, "sort"); + RegisterExecutor regmsg = new RegisterExecutor("app-1", "0", executorInfo); + client1.sendRpcSync(regmsg.toByteBuffer(), TIMEOUT_MS); + + // Make a successful request to fetch blocks, which creates a new stream. But do not actually + // fetch any blocks, to keep the stream open. 
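+ // The stream is registered under app-1's client ID, so the chunk fetch attempted by the
+ // app-2 client further down should be rejected with a SecurityException.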
+ OpenBlocks openMessage = new OpenBlocks("app-1", "0", blockIds); + ByteBuffer response = client1.sendRpcSync(openMessage.toByteBuffer(), TIMEOUT_MS); + StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response); + long streamId = stream.streamId; + + // Create a second client, authenticated with a different app ID, and try to read from + // the stream created for the previous app. + clientFactory2 = blockServerContext.createClientFactory( + Lists.newArrayList( + new SaslClientBootstrap(conf, "app-2", secretKeyHolder))); + client2 = clientFactory2.createClient(TestUtils.getLocalHost(), + blockServer.getPort()); + + ChunkReceivedCallback callback = new ChunkReceivedCallback() { + @Override + public synchronized void onSuccess(int chunkIndex, ManagedBuffer buffer) { + notifyAll(); + } + + @Override + public synchronized void onFailure(int chunkIndex, Throwable t) { + exception.set(t); + notifyAll(); + } + }; + + exception.set(null); + synchronized (callback) { + client2.fetchChunk(streamId, 0, callback); + callback.wait(); + } + checkSecurityException(exception.get()); + } finally { + if (client1 != null) { + client1.close(); + } + if (client2 != null) { + client2.close(); + } + if (clientFactory2 != null) { + clientFactory2.close(); + } + blockServer.close(); + } + } + /** RPC handler which simply responds with the message it received. */ public static class TestRpcHandler extends RpcHandler { @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { callback.onSuccess(message); } @@ -172,4 +285,10 @@ public StreamManager getStreamManager() { return new OneForOneStreamManager(); } } + + private void checkSecurityException(Throwable t) { + assertNotNull("No exception was caught.", t); + assertTrue("Expected SecurityException.", + t.getMessage().contains(SecurityException.class.getName())); + } } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java index d65de9ca550a..86c8609e7070 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java @@ -36,7 +36,7 @@ public void serializeOpenShuffleBlocks() { } private void checkSerializeDeserialize(BlockTransferMessage msg) { - BlockTransferMessage msg2 = BlockTransferMessage.Decoder.fromByteArray(msg.toByteArray()); + BlockTransferMessage msg2 = BlockTransferMessage.Decoder.fromByteBuffer(msg.toByteBuffer()); assertEquals(msg, msg2); assertEquals(msg.hashCode(), msg2.hashCode()); assertEquals(msg.toString(), msg2.toString()); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java index 73374cdc77a2..9379412155e8 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java @@ -60,12 +60,12 @@ public void testRegisterExecutor() { RpcResponseCallback callback = mock(RpcResponseCallback.class); ExecutorShuffleInfo config = new ExecutorShuffleInfo(new String[] {"/a", 
"/b"}, 16, "sort"); - byte[] registerMessage = new RegisterExecutor("app0", "exec1", config).toByteArray(); + ByteBuffer registerMessage = new RegisterExecutor("app0", "exec1", config).toByteBuffer(); handler.receive(client, registerMessage, callback); verify(blockResolver, times(1)).registerExecutor("app0", "exec1", config); - verify(callback, times(1)).onSuccess((byte[]) any()); - verify(callback, never()).onFailure((Throwable) any()); + verify(callback, times(1)).onSuccess(any(ByteBuffer.class)); + verify(callback, never()).onFailure(any(Throwable.class)); } @SuppressWarnings("unchecked") @@ -77,22 +77,25 @@ public void testOpenShuffleBlocks() { ManagedBuffer block1Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[7])); when(blockResolver.getBlockData("app0", "exec1", "b0")).thenReturn(block0Marker); when(blockResolver.getBlockData("app0", "exec1", "b1")).thenReturn(block1Marker); - byte[] openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" }).toByteArray(); + ByteBuffer openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" }) + .toByteBuffer(); handler.receive(client, openBlocks, callback); verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b0"); verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b1"); - ArgumentCaptor response = ArgumentCaptor.forClass(byte[].class); + ArgumentCaptor response = ArgumentCaptor.forClass(ByteBuffer.class); verify(callback, times(1)).onSuccess(response.capture()); verify(callback, never()).onFailure((Throwable) any()); StreamHandle handle = - (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response.getValue()); + (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response.getValue()); assertEquals(2, handle.numChunks); - ArgumentCaptor stream = ArgumentCaptor.forClass(Iterator.class); - verify(streamManager, times(1)).registerStream(stream.capture()); - Iterator buffers = (Iterator) stream.getValue(); + @SuppressWarnings("unchecked") + ArgumentCaptor> stream = (ArgumentCaptor>) + (ArgumentCaptor) ArgumentCaptor.forClass(Iterator.class); + verify(streamManager, times(1)).registerStream(anyString(), stream.capture()); + Iterator buffers = stream.getValue(); assertEquals(block0Marker, buffers.next()); assertEquals(block1Marker, buffers.next()); assertFalse(buffers.hasNext()); @@ -102,7 +105,7 @@ public void testOpenShuffleBlocks() { public void testBadMessages() { RpcResponseCallback callback = mock(RpcResponseCallback.class); - byte[] unserializableMsg = new byte[] { 0x12, 0x34, 0x56 }; + ByteBuffer unserializableMsg = ByteBuffer.wrap(new byte[] { 0x12, 0x34, 0x56 }); try { handler.receive(client, unserializableMsg, callback); fail("Should have thrown"); @@ -110,7 +113,7 @@ public void testBadMessages() { // pass } - byte[] unexpectedMsg = new UploadBlock("a", "e", "b", new byte[1], new byte[2]).toByteArray(); + ByteBuffer unexpectedMsg = new UploadBlock("a", "e", "b", new byte[1], new byte[2]).toByteBuffer(); try { handler.receive(client, unexpectedMsg, callback); fail("Should have thrown"); @@ -118,7 +121,7 @@ public void testBadMessages() { // pass } - verify(callback, never()).onSuccess((byte[]) any()); - verify(callback, never()).onFailure((Throwable) any()); + verify(callback, never()).onSuccess(any(ByteBuffer.class)); + verify(callback, never()).onFailure(any(Throwable.class)); } } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java 
b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java index d02f4f0fdb68..60a1b8b0451f 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java @@ -21,9 +21,12 @@ import java.io.InputStream; import java.io.InputStreamReader; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.io.CharStreams; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; +import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; @@ -39,7 +42,7 @@ public class ExternalShuffleBlockResolverSuite { static TestShuffleDataContext dataContext; - static TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + static TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); @BeforeClass public static void beforeAll() throws IOException { @@ -59,8 +62,8 @@ public static void afterAll() { } @Test - public void testBadRequests() { - ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf); + public void testBadRequests() throws IOException { + ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null); // Unregistered executor try { resolver.getBlockData("app0", "exec1", "shuffle_1_1_0"); @@ -80,7 +83,7 @@ public void testBadRequests() { // Nonexistent shuffle block resolver.registerExecutor("app0", "exec3", - dataContext.createExecutorInfo("org.apache.spark.shuffle.sort.SortShuffleManager")); + dataContext.createExecutorInfo("sort")); try { resolver.getBlockData("app0", "exec3", "shuffle_1_1_0"); fail("Should have failed"); @@ -91,9 +94,9 @@ public void testBadRequests() { @Test public void testSortShuffleBlocks() throws IOException { - ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf); + ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null); resolver.registerExecutor("app0", "exec0", - dataContext.createExecutorInfo("org.apache.spark.shuffle.sort.SortShuffleManager")); + dataContext.createExecutorInfo("sort")); InputStream block0Stream = resolver.getBlockData("app0", "exec0", "shuffle_0_0_0").createInputStream(); @@ -110,9 +113,9 @@ public void testSortShuffleBlocks() throws IOException { @Test public void testHashShuffleBlocks() throws IOException { - ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf); + ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null); resolver.registerExecutor("app0", "exec0", - dataContext.createExecutorInfo("org.apache.spark.shuffle.hash.HashShuffleManager")); + dataContext.createExecutorInfo("hash")); InputStream block0Stream = resolver.getBlockData("app0", "exec0", "shuffle_1_0_0").createInputStream(); @@ -126,4 +129,28 @@ public void testHashShuffleBlocks() throws IOException { block1Stream.close(); assertEquals(hashBlock1, block1); } + + @Test + public void jsonSerializationOfExecutorRegistration() throws IOException { + ObjectMapper mapper = new ObjectMapper(); + AppExecId appId = new AppExecId("foo", "bar"); + String appIdJson = mapper.writeValueAsString(appId); + AppExecId parsedAppId = 
mapper.readValue(appIdJson, AppExecId.class); + assertEquals(parsedAppId, appId); + + ExecutorShuffleInfo shuffleInfo = + new ExecutorShuffleInfo(new String[]{"/bippy", "/flippy"}, 7, "hash"); + String shuffleJson = mapper.writeValueAsString(shuffleInfo); + ExecutorShuffleInfo parsedShuffleInfo = + mapper.readValue(shuffleJson, ExecutorShuffleInfo.class); + assertEquals(parsedShuffleInfo, shuffleInfo); + + // Intentionally keep these hard-coded strings in here, to check backwards-compatibility. + // it's not legacy yet, but keeping this here in case anybody changes it + String legacyAppIdJson = "{\"appId\":\"foo\", \"execId\":\"bar\"}"; + assertEquals(appId, mapper.readValue(legacyAppIdJson, AppExecId.class)); + String legacyShuffleJson = "{\"localDirs\": [\"/bippy\", \"/flippy\"], " + + "\"subDirsPerLocalDir\": 7, \"shuffleManager\": \"hash\"}"; + assertEquals(shuffleInfo, mapper.readValue(legacyShuffleJson, ExecutorShuffleInfo.class)); + } } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java index d9d9c1bf2f17..532d7ab8d01b 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java @@ -35,14 +35,14 @@ public class ExternalShuffleCleanupSuite { // Same-thread Executor used to ensure cleanup happens synchronously in test thread. Executor sameThreadExecutor = MoreExecutors.sameThreadExecutor(); - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); @Test public void noCleanupAndCleanup() throws IOException { TestShuffleDataContext dataContext = createSomeData(); ExternalShuffleBlockResolver resolver = - new ExternalShuffleBlockResolver(conf, sameThreadExecutor); + new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor); resolver.registerExecutor("app", "exec0", dataContext.createExecutorInfo("shuffleMgr")); resolver.applicationRemoved("app", false /* cleanup */); @@ -65,7 +65,8 @@ public void cleanupUsesExecutor() throws IOException { @Override public void execute(Runnable runnable) { cleanupCalled.set(true); } }; - ExternalShuffleBlockResolver manager = new ExternalShuffleBlockResolver(conf, noThreadExecutor); + ExternalShuffleBlockResolver manager = + new ExternalShuffleBlockResolver(conf, null, noThreadExecutor); manager.registerExecutor("app", "exec0", dataContext.createExecutorInfo("shuffleMgr")); manager.applicationRemoved("app", true); @@ -83,7 +84,7 @@ public void cleanupMultipleExecutors() throws IOException { TestShuffleDataContext dataContext1 = createSomeData(); ExternalShuffleBlockResolver resolver = - new ExternalShuffleBlockResolver(conf, sameThreadExecutor); + new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor); resolver.registerExecutor("app", "exec0", dataContext0.createExecutorInfo("shuffleMgr")); resolver.registerExecutor("app", "exec1", dataContext1.createExecutorInfo("shuffleMgr")); @@ -99,7 +100,7 @@ public void cleanupOnlyRemovedApp() throws IOException { TestShuffleDataContext dataContext1 = createSomeData(); ExternalShuffleBlockResolver resolver = - new ExternalShuffleBlockResolver(conf, sameThreadExecutor); + new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor); resolver.registerExecutor("app-0", "exec0",
dataContext0.createExecutorInfo("shuffleMgr")); resolver.registerExecutor("app-1", "exec0", dataContext1.createExecutorInfo("shuffleMgr")); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index 39aa49911d9c..5e706bf40169 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -49,8 +49,8 @@ public class ExternalShuffleIntegrationSuite { static String APP_ID = "app-id"; - static String SORT_MANAGER = "org.apache.spark.shuffle.sort.SortShuffleManager"; - static String HASH_MANAGER = "org.apache.spark.shuffle.hash.HashShuffleManager"; + static String SORT_MANAGER = "sort"; + static String HASH_MANAGER = "hash"; // Executor 0 is sort-based static TestShuffleDataContext dataContext0; @@ -91,8 +91,8 @@ public static void beforeAll() throws IOException { dataContext1.create(); dataContext1.insertHashShuffleData(1, 0, exec1Blocks); - conf = new TransportConf(new SystemPropertyConfigProvider()); - handler = new ExternalShuffleBlockHandler(conf); + conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); + handler = new ExternalShuffleBlockHandler(conf, null); TransportContext transportContext = new TransportContext(conf, handler); server = transportContext.createServer(); } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java index d4ec1956c1e2..08ddb3755bd0 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java @@ -39,12 +39,13 @@ public class ExternalShuffleSecuritySuite { - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); TransportServer server; @Before - public void beforeEach() { - TransportContext context = new TransportContext(conf, new ExternalShuffleBlockHandler(conf)); + public void beforeEach() throws IOException { + TransportContext context = + new TransportContext(conf, new ExternalShuffleBlockHandler(conf, null)); TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, new TestSecretKeyHolder("my-app-id", "secret")); this.server = context.createServer(Arrays.asList(bootstrap)); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java index b35a6d685dd0..2590b9ce4c1f 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java @@ -134,14 +134,14 @@ private BlockFetchingListener fetchBlocks(final LinkedHashMap() { @Override public Void answer(InvocationOnMock invocationOnMock) throws Throwable { - BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteArray( - (byte[]) invocationOnMock.getArguments()[0]); + BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteBuffer( + (ByteBuffer) 
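Aside (not part of the patch): the fixtures above now construct TransportConf with a module name ("shuffle") in addition to the config provider, and pass null as the registered-executor file to ExternalShuffleBlockHandler. The sketch below wires a standalone shuffle server the way ExternalShuffleIntegrationSuite.beforeAll() does; the suggestion that the module name scopes which configuration keys the transport reads is my assumption and is not stated in the patch.

import java.io.IOException;

import org.apache.spark.network.TransportContext;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;

public class ShuffleServerWiring {
  public static TransportServer start() throws IOException {
    // "shuffle" is the new module-name argument; presumably it namespaces the
    // transport settings read from the provider (assumption, see note above).
    TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider());

    // A null registered-executor file keeps executor registration in memory only,
    // exactly as the updated test suites do.
    ExternalShuffleBlockHandler handler = new ExternalShuffleBlockHandler(conf, null);

    // Bind the handler to a server on an ephemeral port.
    TransportContext context = new TransportContext(conf, handler);
    return context.createServer();
  }
}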
invocationOnMock.getArguments()[0]); RpcResponseCallback callback = (RpcResponseCallback) invocationOnMock.getArguments()[1]; - callback.onSuccess(new StreamHandle(123, blocks.size()).toByteArray()); + callback.onSuccess(new StreamHandle(123, blocks.size()).toByteBuffer()); assertEquals(new OpenBlocks("app-id", "exec-id", blockIds), message); return null; } - }).when(client).sendRpc((byte[]) any(), (RpcResponseCallback) any()); + }).when(client).sendRpc(any(ByteBuffer.class), any(RpcResponseCallback.class)); // Respond to each chunk request with a single buffer from our blocks array. final AtomicInteger expectedChunkIndex = new AtomicInteger(0); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java index 1ad0d72ae5ec..3a6ef0d3f847 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java @@ -20,7 +20,9 @@ import java.io.IOException; import java.nio.ByteBuffer; +import java.util.Arrays; import java.util.LinkedHashSet; +import java.util.List; import java.util.Map; import com.google.common.collect.ImmutableMap; @@ -67,13 +69,13 @@ public void afterEach() { public void testNoFailures() throws IOException { BlockFetchingListener listener = mock(BlockFetchingListener.class); - Map[] interactions = new Map[] { + List> interactions = Arrays.asList( // Immediately return both blocks successfully. ImmutableMap.builder() .put("b0", block0) .put("b1", block1) - .build(), - }; + .build() + ); performInteractions(interactions, listener); @@ -86,13 +88,13 @@ public void testNoFailures() throws IOException { public void testUnrecoverableFailure() throws IOException { BlockFetchingListener listener = mock(BlockFetchingListener.class); - Map[] interactions = new Map[] { + List> interactions = Arrays.asList( // b0 throws a non-IOException error, so it will be failed without retry. ImmutableMap.builder() .put("b0", new RuntimeException("Ouch!")) .put("b1", block1) - .build(), - }; + .build() + ); performInteractions(interactions, listener); @@ -105,7 +107,7 @@ public void testUnrecoverableFailure() throws IOException { public void testSingleIOExceptionOnFirst() throws IOException { BlockFetchingListener listener = mock(BlockFetchingListener.class); - Map[] interactions = new Map[] { + List> interactions = Arrays.asList( // IOException will cause a retry. Since b0 fails, we will retry both. ImmutableMap.builder() .put("b0", new IOException("Connection failed or something")) @@ -114,8 +116,8 @@ public void testSingleIOExceptionOnFirst() throws IOException { ImmutableMap.builder() .put("b0", block0) .put("b1", block1) - .build(), - }; + .build() + ); performInteractions(interactions, listener); @@ -128,7 +130,7 @@ public void testSingleIOExceptionOnFirst() throws IOException { public void testSingleIOExceptionOnSecond() throws IOException { BlockFetchingListener listener = mock(BlockFetchingListener.class); - Map[] interactions = new Map[] { + List> interactions = Arrays.asList( // IOException will cause a retry. Since b1 fails, we will not retry b0. 
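Aside (not part of the patch): the RetryingBlockFetcherSuite changes around this point replace the raw Map[] interaction arrays with Arrays.asList(...). A plausible motivation, my reading rather than anything stated in the diff, is that Java forbids creating arrays of a parameterized type, so the array form needs raw types and unchecked suppressions, while a List keeps the element type intact:

import java.util.Arrays;
import java.util.List;
import java.util.Map;

import com.google.common.collect.ImmutableMap;

public class InteractionListExample {
  public static void main(String[] args) {
    // Generic array creation is illegal in Java:
    //   Map<String, Object>[] bad = new Map<String, Object>[] { ... };  // does not compile
    // so the old Map[] form relied on raw types and unchecked suppressions.
    // A List carries the full element type without any of that.
    List<? extends Map<String, Object>> interactions = Arrays.asList(
      ImmutableMap.<String, Object>builder()
        .put("b0", new byte[] { 1, 2, 3 })   // placeholder "block" values
        .build(),
      ImmutableMap.<String, Object>builder()
        .put("b1", new byte[] { 4, 5, 6 })
        .build()
    );

    System.out.println("configured " + interactions.size() + " interactions");
  }
}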
ImmutableMap.builder() .put("b0", block0) @@ -136,8 +138,8 @@ public void testSingleIOExceptionOnSecond() throws IOException { .build(), ImmutableMap.builder() .put("b1", block1) - .build(), - }; + .build() + ); performInteractions(interactions, listener); @@ -150,7 +152,7 @@ public void testSingleIOExceptionOnSecond() throws IOException { public void testTwoIOExceptions() throws IOException { BlockFetchingListener listener = mock(BlockFetchingListener.class); - Map[] interactions = new Map[] { + List> interactions = Arrays.asList( // b0's IOException will trigger retry, b1's will be ignored. ImmutableMap.builder() .put("b0", new IOException()) @@ -164,8 +166,8 @@ public void testTwoIOExceptions() throws IOException { // b1 returns successfully within 2 retries. ImmutableMap.builder() .put("b1", block1) - .build(), - }; + .build() + ); performInteractions(interactions, listener); @@ -178,7 +180,7 @@ public void testTwoIOExceptions() throws IOException { public void testThreeIOExceptions() throws IOException { BlockFetchingListener listener = mock(BlockFetchingListener.class); - Map[] interactions = new Map[] { + List> interactions = Arrays.asList( // b0's IOException will trigger retry, b1's will be ignored. ImmutableMap.builder() .put("b0", new IOException()) @@ -196,8 +198,8 @@ public void testThreeIOExceptions() throws IOException { // This is not reached -- b1 has failed. ImmutableMap.builder() .put("b1", block1) - .build(), - }; + .build() + ); performInteractions(interactions, listener); @@ -210,7 +212,7 @@ public void testThreeIOExceptions() throws IOException { public void testRetryAndUnrecoverable() throws IOException { BlockFetchingListener listener = mock(BlockFetchingListener.class); - Map[] interactions = new Map[] { + List> interactions = Arrays.asList( // b0's IOException will trigger retry, subsequent messages will be ignored. ImmutableMap.builder() .put("b0", new IOException()) @@ -226,8 +228,8 @@ public void testRetryAndUnrecoverable() throws IOException { // b2 succeeds in its last retry. ImmutableMap.builder() .put("b2", block2) - .build(), - }; + .build() + ); performInteractions(interactions, listener); @@ -248,10 +250,11 @@ public void testRetryAndUnrecoverable() throws IOException { * subset of the original blocks in a second interaction. 
*/ @SuppressWarnings("unchecked") - private void performInteractions(final Map[] interactions, BlockFetchingListener listener) + private static void performInteractions(List> interactions, + BlockFetchingListener listener) throws IOException { - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); BlockFetchStarter fetchStarter = mock(BlockFetchStarter.class); Stubber stub = null; diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java index 3fdde054ab6c..7ac1ca128aed 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java @@ -23,6 +23,7 @@ import java.io.IOException; import java.io.OutputStream; +import com.google.common.io.Closeables; import com.google.common.io.Files; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; @@ -60,21 +61,28 @@ public void cleanup() { public void insertSortShuffleData(int shuffleId, int mapId, byte[][] blocks) throws IOException { String blockId = "shuffle_" + shuffleId + "_" + mapId + "_0"; - OutputStream dataStream = new FileOutputStream( - ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId + ".data")); - DataOutputStream indexStream = new DataOutputStream(new FileOutputStream( - ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId + ".index"))); + OutputStream dataStream = null; + DataOutputStream indexStream = null; + boolean suppressExceptionsDuringClose = true; - long offset = 0; - indexStream.writeLong(offset); - for (byte[] block : blocks) { - offset += block.length; - dataStream.write(block); + try { + dataStream = new FileOutputStream( + ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId + ".data")); + indexStream = new DataOutputStream(new FileOutputStream( + ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId + ".index"))); + + long offset = 0; indexStream.writeLong(offset); + for (byte[] block : blocks) { + offset += block.length; + dataStream.write(block); + indexStream.writeLong(offset); + } + suppressExceptionsDuringClose = false; + } finally { + Closeables.close(dataStream, suppressExceptionsDuringClose); + Closeables.close(indexStream, suppressExceptionsDuringClose); } - - dataStream.close(); - indexStream.close(); } /** Creates reducer blocks in a hash-based data format within our local dirs. 
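Aside (not part of the patch): the rewritten insertSortShuffleData above closes its streams through Guava's Closeables.close(closeable, swallowException) so that an exception thrown while writing is not masked by a second failure in close(). A condensed sketch of the same pattern, using a hypothetical file name:

import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;

import com.google.common.io.Closeables;

public class SuppressOnCloseExample {
  public static void write(byte[] payload) throws IOException {
    OutputStream out = null;
    // Start pessimistic: if the body throws, failures from close() are swallowed so
    // the original exception propagates. Flip the flag only once the body succeeds.
    boolean suppressExceptionsDuringClose = true;
    try {
      out = new FileOutputStream("example.data");  // hypothetical file name
      out.write(payload);
      suppressExceptionsDuringClose = false;
    } finally {
      Closeables.close(out, suppressExceptionsDuringClose);
    }
  }
}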
*/ diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml index a99f7c4392d3..e2360eff5cfe 100644 --- a/network/yarn/pom.xml +++ b/network/yarn/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml @@ -44,12 +44,21 @@ spark-network-shuffle_${scala.binary.version} ${project.version} + + org.apache.spark + spark-test-tags_${scala.binary.version} + org.apache.hadoop hadoop-client + + org.slf4j + slf4j-api + provided + diff --git a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java index 463f99ef3352..ba6d30a74c67 100644 --- a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java +++ b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -17,25 +17,21 @@ package org.apache.spark.network.yarn; +import java.io.File; import java.nio.ByteBuffer; import java.util.List; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Lists; import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.yarn.api.records.ApplicationId; import org.apache.hadoop.yarn.api.records.ContainerId; -import org.apache.hadoop.yarn.server.api.AuxiliaryService; -import org.apache.hadoop.yarn.server.api.ApplicationInitializationContext; -import org.apache.hadoop.yarn.server.api.ApplicationTerminationContext; -import org.apache.hadoop.yarn.server.api.ContainerInitializationContext; -import org.apache.hadoop.yarn.server.api.ContainerTerminationContext; +import org.apache.hadoop.yarn.server.api.*; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.TransportContext; import org.apache.spark.network.sasl.SaslServerBootstrap; import org.apache.spark.network.sasl.ShuffleSecretManager; -import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.server.TransportServerBootstrap; import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler; @@ -79,11 +75,26 @@ public class YarnShuffleService extends AuxiliaryService { private TransportServer shuffleServer = null; // Handles registering executors and opening shuffle blocks - private ExternalShuffleBlockHandler blockHandler; + @VisibleForTesting + ExternalShuffleBlockHandler blockHandler; + + // Where to store & reload executor info for recovering state after an NM restart + @VisibleForTesting + File registeredExecutorFile; + + // just for testing when you want to find an open port + @VisibleForTesting + static int boundPort = -1; + + // just for integration tests that want to look at this file -- in general not sensible as + // a static + @VisibleForTesting + static YarnShuffleService instance; public YarnShuffleService() { super("spark_shuffle"); logger.info("Initializing YARN shuffle service for Spark"); + instance = this; } /** @@ -100,11 +111,24 @@ private boolean isAuthenticationEnabled() { */ @Override protected void serviceInit(Configuration conf) { - TransportConf transportConf = new TransportConf(new HadoopConfigProvider(conf)); + + // In case this NM was killed while there were running spark applications, we need to restore + // lost state for the existing executors. We look for an existing file in the NM's local dirs. + // If we don't find one, then we choose a file to use to save the state next time. 
Even if + an application was stopped while the NM was down, we expect yarn to call stopApplication() + when it comes back + registeredExecutorFile = + findRegisteredExecutorFile(conf.getStrings("yarn.nodemanager.local-dirs")); + + TransportConf transportConf = new TransportConf("shuffle", new HadoopConfigProvider(conf)); // If authentication is enabled, set up the shuffle server to use a // special RPC handler that filters out unauthenticated fetch requests boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE); - blockHandler = new ExternalShuffleBlockHandler(transportConf); + try { + blockHandler = new ExternalShuffleBlockHandler(transportConf, registeredExecutorFile); + } catch (Exception e) { + logger.error("Failed to initialize external shuffle service", e); + } List bootstraps = Lists.newArrayList(); if (authEnabled) { @@ -116,9 +140,13 @@ protected void serviceInit(Configuration conf) { SPARK_SHUFFLE_SERVICE_PORT_KEY, DEFAULT_SPARK_SHUFFLE_SERVICE_PORT); TransportContext transportContext = new TransportContext(transportConf, blockHandler); shuffleServer = transportContext.createServer(port, bootstraps); + // the port should normally be fixed, but for tests it's useful to find an open port + port = shuffleServer.getPort(); + boundPort = port; String authEnabledString = authEnabled ? "enabled" : "not enabled"; logger.info("Started YARN shuffle service for Spark on port {}. " + - "Authentication is {}.", port, authEnabledString); + "Authentication is {}. Registered executor file is {}", port, authEnabledString, + registeredExecutorFile); } @Override @@ -161,6 +189,16 @@ public void stopContainer(ContainerTerminationContext context) { logger.info("Stopping container {}", containerId); } + private File findRegisteredExecutorFile(String[] localDirs) { + for (String dir: localDirs) { + File f = new File(dir, "registeredExecutors.ldb"); + if (f.exists()) { + return f; + } + } + return new File(localDirs[0], "registeredExecutors.ldb"); + } + /** * Close the shuffle server to clean up any associated state.
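Aside (not part of the patch): serviceInit above recovers executor registrations after a NodeManager restart by looking for an existing registeredExecutors.ldb in any of the NM local dirs, falling back to the first dir for a fresh file. The standalone sketch below restates that lookup outside the service class so the fallback behaviour is easy to see; names mirror the patch, the surrounding class is illustrative only.

import java.io.File;

public class RecoveryFileLookup {
  // Mirrors YarnShuffleService.findRegisteredExecutorFile: reuse an existing DB if
  // any NM local dir already holds one, otherwise place a new one in the first dir.
  public static File findRegisteredExecutorFile(String[] localDirs) {
    for (String dir : localDirs) {
      File f = new File(dir, "registeredExecutors.ldb");
      if (f.exists()) {
        return f;
      }
    }
    return new File(localDirs[0], "registeredExecutors.ldb");
  }

  public static void main(String[] args) {
    // Hypothetical values for yarn.nodemanager.local-dirs.
    String[] dirs = { "/tmp/nm-local-dir-1", "/tmp/nm-local-dir-2" };
    System.out.println("registered executor file: " + findRegisteredExecutorFile(dirs));
  }
}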
*/ @@ -170,6 +208,9 @@ protected void serviceStop() { if (shuffleServer != null) { shuffleServer.close(); } + if (blockHandler != null) { + blockHandler.close(); + } } catch (Exception e) { logger.error("Exception when stopping service", e); } @@ -180,5 +221,4 @@ protected void serviceStop() { public ByteBuffer getMetaData() { return ByteBuffer.allocate(0); } - } diff --git a/pom.xml b/pom.xml index be0dac953abf..c560e13641c6 100644 --- a/pom.xml +++ b/pom.xml @@ -26,7 +26,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT pom Spark Project Parent POM http://spark.apache.org/ @@ -86,8 +86,9 @@ + tags core - bagel + bagel graphx mllib tools @@ -97,6 +98,7 @@ sql/catalyst sql/core sql/hive + docker-integration-tests unsafe assembly external/twitter @@ -104,6 +106,7 @@ external/flume-sink external/flume-assembly external/mqtt + external/mqtt-assembly external/zeromq examples repl @@ -134,11 +137,12 @@ 2.4.0 org.spark-project.hive - 0.13.1a + 1.2.1.spark - 0.13.1 + 1.2.1 10.10.1.1 1.7.0 + 1.6.0 1.2.4 8.1.14.v20131031 3.0.0.v201112011016 @@ -149,19 +153,41 @@ 1.7.7 hadoop2 0.7.1 - 1.9.16 - 1.2.1 + 1.9.40 + 1.4.0 + + 0.10.1 + 4.3.2 + + 3.1 3.4.1 - 2.10.4 + + 3.2.2 + 2.10.5 2.10 ${scala.version} org.scala-lang 1.9.13 2.4.4 - 1.1.1.7 + 1.1.2 1.1.2 + 1.2.0-incubating + 1.10 + + 2.6 + + 3.3.2 + 3.2.10 + 2.7.8 + 1.9 + 2.9 + 3.5.2 + 1.3.9 + 0.9.2 + ${java.home} + spring-releases Spring Release Repository https://repo.spring.io/libs-release - true + false false @@ -312,9 +346,25 @@ scalatest_${scala.binary.version} test + + junit + junit + test + + + com.novocode + junit-interface + test + + + org.apache.spark + spark-test-tags_${scala.binary.version} + ${project.version} + test + com.twitter chill_${scala.binary.version} @@ -345,6 +395,14 @@ + + + org.apache.xbean + xbean-asm5-shaded + 4.4 + @@ -561,7 +644,7 @@ org.roaringbitmap RoaringBitmap - 0.4.5 + 0.5.11 commons-net @@ -608,6 +691,11 @@ jackson-databind ${fasterxml.jackson.version} + + com.fasterxml.jackson.core + jackson-annotations + ${fasterxml.jackson.version} + @@ -624,15 +712,26 @@ com.sun.jersey jersey-server - 1.9 + ${jersey.version} ${hadoop.deps.scope} com.sun.jersey jersey-core - 1.9 + ${jersey.version} ${hadoop.deps.scope} + + com.sun.jersey + jersey-json + ${jersey.version} + + + stax + stax-api + + + org.scala-lang scala-compiler @@ -679,7 +778,7 @@ junit junit - 4.10 + 4.11 test @@ -697,7 +796,48 @@ com.novocode junit-interface - 0.10 + 0.11 + test + + + com.spotify + docker-client + shaded + 3.2.1 + test + + + guava + com.google.guava + + + org.apache.httpcomponents + httpclient + + + org.apache.httpcomponents + httpcore + + + commons-logging + httpclient + + + commons-logging + commons-logging + + + + + mysql + mysql-connector-java + 5.1.34 + test + + + org.postgresql + postgresql + 9.3-1102-jdbc41 test @@ -1022,58 +1162,503 @@ hive-beeline ${hive.version} ${hive.deps.scope} + + + ${hive.group} + hive-common + + + ${hive.group} + hive-exec + + + ${hive.group} + hive-jdbc + + + ${hive.group} + hive-metastore + + + ${hive.group} + hive-service + + + ${hive.group} + hive-shims + + + org.apache.thrift + libthrift + + + org.slf4j + slf4j-api + + + org.slf4j + slf4j-log4j12 + + + log4j + log4j + + + commons-logging + commons-logging + + ${hive.group} hive-cli ${hive.version} ${hive.deps.scope} + + + ${hive.group} + hive-common + + + ${hive.group} + hive-exec + + + ${hive.group} + hive-jdbc + + + ${hive.group} + hive-metastore + + + ${hive.group} + hive-serde + + + ${hive.group} + hive-service + + + ${hive.group} 
+ hive-shims + + + org.apache.thrift + libthrift + + + org.slf4j + slf4j-api + + + org.slf4j + slf4j-log4j12 + + + log4j + log4j + + + commons-logging + commons-logging + + ${hive.group} - hive-exec + hive-common ${hive.version} ${hive.deps.scope} + + ${hive.group} + hive-shims + + + org.apache.ant + ant + + + org.apache.zookeeper + zookeeper + + + org.slf4j + slf4j-api + + + org.slf4j + slf4j-log4j12 + + + log4j + log4j + commons-logging commons-logging + + + + + ${hive.group} + hive-exec + + ${hive.version} + ${hive.deps.scope} + + + + + ${hive.group} + hive-metastore + + + ${hive.group} + hive-shims + + + ${hive.group} + hive-ant + + + + ${hive.group} + spark-client + + + + + ant + ant + + + org.apache.ant + ant + com.esotericsoftware.kryo kryo + + commons-codec + commons-codec + + + commons-httpclient + commons-httpclient + org.apache.avro avro-mapred + + + org.apache.calcite + calcite-core + + + org.apache.curator + apache-curator + + + org.apache.curator + curator-client + + + org.apache.curator + curator-framework + + + org.apache.thrift + libthrift + + + org.apache.thrift + libfb303 + + + org.apache.zookeeper + zookeeper + + + org.slf4j + slf4j-api + + + org.slf4j + slf4j-log4j12 + + + log4j + log4j + + + commons-logging + commons-logging + ${hive.group} hive-jdbc ${hive.version} - ${hive.deps.scope} + + + ${hive.group} + hive-common + + + ${hive.group} + hive-common + + + ${hive.group} + hive-metastore + + + ${hive.group} + hive-serde + + + ${hive.group} + hive-service + + + ${hive.group} + hive-shims + + + org.apache.httpcomponents + httpclient + + + org.apache.httpcomponents + httpcore + + + org.apache.curator + curator-framework + + + org.apache.thrift + libthrift + + + org.apache.thrift + libfb303 + + + org.apache.zookeeper + zookeeper + + + org.slf4j + slf4j-api + + + org.slf4j + slf4j-log4j12 + + + log4j + log4j + + + commons-logging + commons-logging + + + ${hive.group} hive-metastore ${hive.version} ${hive.deps.scope} + + + ${hive.group} + hive-serde + + + ${hive.group} + hive-shims + + + org.apache.thrift + libfb303 + + + org.apache.thrift + libthrift + + + org.mortbay.jetty + servlet-api + + + com.google.guava + guava + + + org.slf4j + slf4j-api + + + org.slf4j + slf4j-log4j12 + + + ${hive.group} hive-serde ${hive.version} ${hive.deps.scope} + + ${hive.group} + hive-common + + + ${hive.group} + hive-shims + + + commons-codec + commons-codec + + + com.google.code.findbugs + jsr305 + + + org.apache.avro + avro + + + org.apache.thrift + libthrift + + + org.apache.thrift + libfb303 + + + org.slf4j + slf4j-api + + + org.slf4j + slf4j-log4j12 + + + log4j + log4j + commons-logging commons-logging + + + + + ${hive.group} + hive-service + ${hive.version} + ${hive.deps.scope} + + + ${hive.group} + hive-common + + + ${hive.group} + hive-exec + + + ${hive.group} + hive-metastore + + + ${hive.group} + hive-shims + + + commons-codec + commons-codec + + + org.apache.curator + curator-framework + + + org.apache.curator + curator-recipes + + + org.apache.thrift + libfb303 + + + org.apache.thrift + libthrift + + + + + + + ${hive.group} + hive-shims + ${hive.version} + ${hive.deps.scope} + + + com.google.guava + guava + + + org.apache.hadoop + hadoop-yarn-server-resourcemanager + + + org.apache.curator + curator-framework + + + org.apache.thrift + libthrift + + + org.apache.zookeeper + zookeeper + + + org.slf4j + slf4j-api + + + org.slf4j + slf4j-log4j12 + + + log4j + log4j + commons-logging - commons-logging-api + commons-logging @@ -1095,6 +1680,12 @@ ${parquet.version} 
${parquet.test.deps.scope} + + com.twitter + parquet-hadoop-bundle + ${hive.parquet.version} + compile + org.apache.flume flume-ng-core @@ -1135,6 +1726,125 @@ + + org.apache.calcite + calcite-core + ${calcite.version} + + + com.fasterxml.jackson.core + jackson-annotations + + + com.fasterxml.jackson.core + jackson-core + + + com.fasterxml.jackson.core + jackson-databind + + + com.google.guava + guava + + + com.google.code.findbugs + jsr305 + + + org.codehaus.janino + janino + + + + org.hsqldb + hsqldb + + + org.pentaho + pentaho-aggdesigner-algorithm + + + + + org.apache.calcite + calcite-avatica + ${calcite.version} + + + com.fasterxml.jackson.core + jackson-annotations + + + com.fasterxml.jackson.core + jackson-core + + + com.fasterxml.jackson.core + jackson-databind + + + + + org.codehaus.janino + janino + ${janino.version} + + + joda-time + joda-time + ${joda.version} + + + org.jodd + jodd-core + ${jodd.version} + + + org.datanucleus + datanucleus-core + ${datanucleus-core.version} + + + org.apache.thrift + libthrift + ${libthrift.version} + + + org.apache.httpcomponents + httpclient + + + org.apache.httpcomponents + httpcore + + + org.slf4j + slf4j-api + + + + + org.apache.thrift + libfb303 + ${libthrift.version} + + + org.apache.httpcomponents + httpclient + + + org.apache.httpcomponents + httpcore + + + org.slf4j + slf4j-api + + + @@ -1223,6 +1933,7 @@ ${java.version} -target ${java.version} + -Xlint:all,-serial,-path @@ -1236,6 +1947,9 @@ UTF-8 1024m true + + -Xlint:all,-serial,-path + @@ -1259,6 +1973,8 @@ launched by the tests have access to the correct test-time classpath. --> ${test_classpath} + 1 + 1 ${test.java.home} @@ -1267,12 +1983,15 @@ ${project.build.directory}/tmp ${spark.test.home} 1 + false false false - true true + + src false + ${test.exclude.tags} @@ -1293,6 +2012,8 @@ launched by the tests have access to the correct test-time classpath. 
--> ${test_classpath} + 1 + 1 ${test.java.home} @@ -1303,9 +2024,11 @@ 1 false false - true true + + __not_used__ + ${test.exclude.tags} @@ -1540,7 +2263,7 @@ org.scalastyle scalastyle-maven-plugin - 0.7.0 + 0.8.0 false true @@ -1561,6 +2284,30 @@ + + org.apache.maven.plugins + maven-checkstyle-plugin + 2.17 + + false + false + true + false + ${basedir}/src/main/java + ${basedir}/src/test/java + checkstyle.xml + ${basedir}/target/checkstyle-output.xml + ${project.build.sourceEncoding} + ${project.reporting.outputEncoding} + + + + + check + + + + org.apache.maven.plugins @@ -1748,44 +2495,6 @@ - - mapr3 - - 1.0.3-mapr-3.0.3 - 2.4.1-mapr-1408 - 0.98.4-mapr-1408 - 3.4.5-mapr-1406 - - - - - mapr4 - - 2.4.1-mapr-1408 - 2.4.1-mapr-1408 - 0.98.4-mapr-1408 - 3.4.5-mapr-1406 - - - - org.apache.curator - curator-recipes - ${curator.version} - - - org.apache.zookeeper - zookeeper - - - - - org.apache.zookeeper - zookeeper - 3.4.5-mapr-1406 - - - - hive-thriftserver @@ -1799,7 +2508,7 @@ !scala-2.11 - 2.10.4 + 2.10.5 2.10 ${scala.version} org.scala-lang diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index f16bf989f200..519052620246 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -91,7 +91,7 @@ object MimaBuild { def mimaSettings(sparkHome: File, projectRef: ProjectRef) = { val organization = "org.apache.spark" - val previousSparkVersion = "1.4.0" + val previousSparkVersion = "1.5.0" val fullId = "spark-" + projectRef.project + "_2.10" mimaDefaultSettings ++ Seq(previousArtifact := Some(organization % fullId % previousSparkVersion), diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 280aac931915..edae59d88266 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -32,602 +32,786 @@ import com.typesafe.tools.mima.core.ProblemFilters._ * MimaBuild.excludeSparkClass("graphx.util.collection.GraphXPrimitiveKeyOpenHashMap") */ object MimaExcludes { - def excludes(version: String) = - version match { - case v if v.startsWith("1.5") => - Seq( - MimaBuild.excludeSparkPackage("deploy"), - // These are needed if checking against the sbt build, since they are part of - // the maven-generated artifacts in 1.3. - excludePackage("org.spark-project.jetty"), - MimaBuild.excludeSparkPackage("unused"), - // JavaRDDLike is not meant to be extended by user programs - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.partitioner"), - // Modification of private static method - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.streaming.kafka.KafkaUtils.org$apache$spark$streaming$kafka$KafkaUtils$$leadersForRanges"), - // Mima false positive (was a private[spark] class) - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.util.collection.PairIterator"), - // Removing a testing method from a private class - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.streaming.kafka.KafkaTestUtils.waitUntilLeaderOffset"), - // While private MiMa is still not happy about the changes, - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.ml.regression.LeastSquaresAggregator.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.ml.regression.LeastSquaresCostFun.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.ml.classification.LogisticCostFun.this"), - // SQL execution is considered private. - excludePackage("org.apache.spark.sql.execution"), - // Parquet support is considered private. 
- excludePackage("org.apache.spark.sql.parquet"), - // The old JSON RDD is removed in favor of streaming Jackson - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.json.JsonRDD$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.json.JsonRDD"), - // local function inside a method - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.SQLContext.org$apache$spark$sql$SQLContext$$needsConversion$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$24") - ) ++ Seq( - // SPARK-8479 Add numNonzeros and numActives to Matrix. - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Matrix.numNonzeros"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Matrix.numActives") - ) ++ Seq( - // SPARK-8914 Remove RDDApi - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.RDDApi") - ) ++ Seq( - // SPARK-7292 Provide operator to truncate lineage cheaply - ProblemFilters.exclude[AbstractClassProblem]( - "org.apache.spark.rdd.RDDCheckpointData"), - ProblemFilters.exclude[AbstractClassProblem]( - "org.apache.spark.rdd.CheckpointRDD") - ) ++ Seq( - // SPARK-8701 Add input metadata in the batch page. - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.streaming.scheduler.InputInfo$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.streaming.scheduler.InputInfo") - ) ++ Seq( - // SPARK-6797 Support YARN modes for SparkR - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.r.PairwiseRRDD.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.r.RRDD.createRWorker"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.r.RRDD.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.r.StringRRDD.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.r.BaseRRDD.this") - ) ++ Seq( - // SPARK-7422 add argmax for sparse vectors - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Vector.argmax") - ) ++ Seq( - // SPARK-8906 Move all internal data source classes into execution.datasources - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.ResolvedDataSource"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreInsertCastAndRename$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsingAsSelect$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoDataSource$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopPartition"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$PartitionValues$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DefaultWriterContainer"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$PartitionValues"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.RefreshTable$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsing$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitionSpec"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DynamicPartitionWriterContainer"), - 
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsingAsSelect"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DescribeCommand$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreInsertCastAndRename"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Partition$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LogicalRelation$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LogicalRelation"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Partition"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.BaseWriterContainer"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreWriteCheck"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsing"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.RefreshTable"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$NewHadoopMapPartitionsWithSplitRDD"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DataSourceStrategy$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsing"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsingAsSelect$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsingAsSelect"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsing$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.ResolvedDataSource$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreWriteCheck$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoDataSource"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoHadoopFsRelation"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DDLParser"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CaseInsensitiveMap"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoHadoopFsRelation$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DataSourceStrategy"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$NewHadoopMapPartitionsWithSplitRDD$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitionSpec$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DescribeCommand"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DDLException") - ) ++ Seq( - // SPARK-4751 Dynamic allocation for standalone mode - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.SparkContext.supportDynamicAllocation") - ) + def excludes(version: String) = version match { + case v if v.startsWith("1.6") => + Seq( + 
MimaBuild.excludeSparkPackage("deploy"), + MimaBuild.excludeSparkPackage("network"), + MimaBuild.excludeSparkPackage("unsafe"), + // These are needed if checking against the sbt build, since they are part of + // the maven-generated artifacts in 1.3. + excludePackage("org.spark-project.jetty"), + MimaBuild.excludeSparkPackage("unused"), + // SQL execution is considered private. + excludePackage("org.apache.spark.sql.execution"), + // SQL columnar is considered private. + excludePackage("org.apache.spark.sql.columnar"), + // The shuffle package is considered private. + excludePackage("org.apache.spark.shuffle"), + // The collections utilities are considered private. + excludePackage("org.apache.spark.util.collection") + ) ++ + MimaBuild.excludeSparkClass("streaming.flume.FlumeTestUtils") ++ + MimaBuild.excludeSparkClass("streaming.flume.PollingFlumeTestUtils") ++ + Seq( + // MiMa does not deal properly with sealed traits + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.classification.LogisticRegressionSummary.featuresCol") + ) ++ Seq( + // SPARK-11530 + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.feature.PCAModel.this") + ) ++ Seq( + // SPARK-10381 Fix types / units in private AskPermissionToCommitOutput RPC message. + // This class is marked as `private` but MiMa still seems to be confused by the change. + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.task"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy$default$2"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.taskAttempt"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy$default$3"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.apply") + ) ++ Seq( + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.shuffle.FileShuffleBlockResolver$ShuffleFileGroup") + ) ++ Seq( + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.regression.LeastSquaresAggregator.add"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.regression.LeastSquaresCostFun.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.SQLContext.clearLastInstantiatedContext"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.SQLContext.setLastInstantiatedContext"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.SQLContext$SQLSession"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.SQLContext.detachSession"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.SQLContext.tlSession"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.SQLContext.defaultSession"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.SQLContext.currentSession"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.SQLContext.openSession"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.SQLContext.setSession"),
ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.SQLContext.createSession") + ) ++ Seq( + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.SparkContext.preferredNodeLocationData_="), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.rdd.MapPartitionsWithPreparationRDD"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.rdd.MapPartitionsWithPreparationRDD$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SparkSQLParser") + ) ++ Seq( + // SPARK-11485 + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.DataFrameHolder.df"), + // SPARK-11541 mark various JDBC dialects as private + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.productElement"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.productArity"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.canEqual"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.productIterator"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.productPrefix"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.toString"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.hashCode"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.jdbc.PostgresDialect$"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.productElement"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.productArity"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.canEqual"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.productIterator"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.productPrefix"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.toString"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.hashCode"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.jdbc.NoopDialect$") + ) ++ Seq ( + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.status.api.v1.ApplicationInfo.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.status.api.v1.StageData.this") + ) ++ Seq( + // SPARK-11766 add toJson to Vector + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.toJson") + ) ++ Seq( + // SPARK-9065 Support message handler in Kafka Python API + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper.createDirectStream"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper.createRDD") + ) ++ Seq( + // SPARK-4557 Changed foreachRDD to use VoidFunction + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.api.java.JavaDStreamLike.foreachRDD") + ) ++ Seq( + // SPARK-11996 Make the executor thread dump work again + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.ExecutorEndpoint"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.ExecutorEndpoint$"), + ProblemFilters.exclude[MissingClassProblem]( + 
"org.apache.spark.storage.BlockManagerMessages$GetRpcHostPortForExecutor"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.storage.BlockManagerMessages$GetRpcHostPortForExecutor$") + ) ++ Seq( + // SPARK-3580 Add getNumPartitions method to JavaRDD + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.getNumPartitions") + ) ++ + // SPARK-11314: YARN backend moved to yarn sub-module and MiMA complains even though it's a + // private class. + MimaBuild.excludeSparkClass("scheduler.cluster.YarnSchedulerBackend$YarnSchedulerEndpoint") + case v if v.startsWith("1.5") => + Seq( + MimaBuild.excludeSparkPackage("network"), + MimaBuild.excludeSparkPackage("deploy"), + // These are needed if checking against the sbt build, since they are part of + // the maven-generated artifacts in 1.3. + excludePackage("org.spark-project.jetty"), + MimaBuild.excludeSparkPackage("unused"), + // JavaRDDLike is not meant to be extended by user programs + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.partitioner"), + // Modification of private static method + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.streaming.kafka.KafkaUtils.org$apache$spark$streaming$kafka$KafkaUtils$$leadersForRanges"), + // Mima false positive (was a private[spark] class) + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.util.collection.PairIterator"), + // Removing a testing method from a private class + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.kafka.KafkaTestUtils.waitUntilLeaderOffset"), + // While private MiMa is still not happy about the changes, + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.regression.LeastSquaresAggregator.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.regression.LeastSquaresCostFun.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.classification.LogisticCostFun.this"), + // SQL execution is considered private. + excludePackage("org.apache.spark.sql.execution"), + // The old JSON RDD is removed in favor of streaming Jackson + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.json.JsonRDD$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.json.JsonRDD"), + // local function inside a method + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.SQLContext.org$apache$spark$sql$SQLContext$$needsConversion$1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$24") + ) ++ Seq( + // SPARK-8479 Add numNonzeros and numActives to Matrix. + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.numNonzeros"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.numActives") + ) ++ Seq( + // SPARK-8914 Remove RDDApi + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.RDDApi") + ) ++ Seq( + // SPARK-7292 Provide operator to truncate lineage cheaply + ProblemFilters.exclude[AbstractClassProblem]( + "org.apache.spark.rdd.RDDCheckpointData"), + ProblemFilters.exclude[AbstractClassProblem]( + "org.apache.spark.rdd.CheckpointRDD") + ) ++ Seq( + // SPARK-8701 Add input metadata in the batch page. 
+ ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.streaming.scheduler.InputInfo$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.streaming.scheduler.InputInfo") + ) ++ Seq( + // SPARK-6797 Support YARN modes for SparkR + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.r.PairwiseRRDD.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.r.RRDD.createRWorker"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.r.RRDD.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.r.StringRRDD.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.r.BaseRRDD.this") + ) ++ Seq( + // SPARK-7422 add argmax for sparse vectors + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.argmax") + ) ++ Seq( + // SPARK-8906 Move all internal data source classes into execution.datasources + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.ResolvedDataSource"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreInsertCastAndRename$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsingAsSelect$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoDataSource$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopPartition"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$PartitionValues$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DefaultWriterContainer"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$PartitionValues"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.RefreshTable$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsing$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitionSpec"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DynamicPartitionWriterContainer"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsingAsSelect"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DescribeCommand$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreInsertCastAndRename"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Partition$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LogicalRelation$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LogicalRelation"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Partition"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.BaseWriterContainer"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreWriteCheck"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsing"), + 
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.RefreshTable"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$NewHadoopMapPartitionsWithSplitRDD"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DataSourceStrategy$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsing"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsingAsSelect$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsingAsSelect"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsing$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.ResolvedDataSource$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreWriteCheck$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoDataSource"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoHadoopFsRelation"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DDLParser"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CaseInsensitiveMap"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoHadoopFsRelation$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DataSourceStrategy"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$NewHadoopMapPartitionsWithSplitRDD$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitionSpec$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DescribeCommand"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DDLException"), + // SPARK-9763 Minimize exposure of internal SQL classes + excludePackage("org.apache.spark.sql.parquet"), + excludePackage("org.apache.spark.sql.json"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$DecimalConversion$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartition"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JdbcUtils$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$DecimalConversion"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartitioningInfo$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartition$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.package"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$JDBCConversion"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.package$DriverWrapper"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartitioningInfo"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JdbcUtils"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DefaultSource"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRelation$"), + 
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.package$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRelation") + ) ++ Seq( + // SPARK-4751 Dynamic allocation for standalone mode + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.SparkContext.supportDynamicAllocation") + ) ++ Seq( + // SPARK-9580: Remove SQL test singletons + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.test.LocalSQLContext$SQLSession"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.test.LocalSQLContext"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.test.TestSQLContext"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.test.TestSQLContext$") + ) ++ Seq( + // SPARK-9704 Made ProbabilisticClassifier, Identifiable, VectorUDT public APIs + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.mllib.linalg.VectorUDT.serialize") + ) ++ Seq( + // SPARK-10381 Fix types / units in private AskPermissionToCommitOutput RPC message. + // This class is marked as `private` but MiMa still seems to be confused by the change. + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.task"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy$default$2"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.taskAttempt"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy$default$3"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.apply") + ) - case v if v.startsWith("1.4") => - Seq( - MimaBuild.excludeSparkPackage("deploy"), - MimaBuild.excludeSparkPackage("ml"), - // SPARK-7910 Adding a method to get the partioner to JavaRDD, - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitioner"), - // SPARK-5922 Adding a generalized diff(other: RDD[(VertexId, VD)]) to VertexRDD - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.diff"), - // These are needed if checking against the sbt build, since they are part of - // the maven-generated artifacts in 1.3. 
- excludePackage("org.spark-project.jetty"), - MimaBuild.excludeSparkPackage("unused"), - ProblemFilters.exclude[MissingClassProblem]("com.google.common.base.Optional"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.rdd.JdbcRDD.compute"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.broadcast.HttpBroadcastFactory.newBroadcast"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.broadcast.TorrentBroadcastFactory.newBroadcast"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.scheduler.OutputCommitCoordinator$OutputCommitCoordinatorActor") - ) ++ Seq( - // SPARK-4655 - Making Stage an Abstract class broke binary compatility even though - // the stage class is defined as private[spark] - ProblemFilters.exclude[AbstractClassProblem]("org.apache.spark.scheduler.Stage") - ) ++ Seq( - // SPARK-6510 Add a Graph#minus method acting as Set#difference - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.minus") - ) ++ Seq( - // SPARK-6492 Fix deadlock in SparkContext.stop() - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.org$" + - "apache$spark$SparkContext$$SPARK_CONTEXT_CONSTRUCTOR_LOCK") - )++ Seq( - // SPARK-6693 add tostring with max lines and width for matrix - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Matrix.toString") - )++ Seq( - // SPARK-6703 Add getOrCreate method to SparkContext - ProblemFilters.exclude[IncompatibleResultTypeProblem] - ("org.apache.spark.SparkContext.org$apache$spark$SparkContext$$activeContext") - )++ Seq( - // SPARK-7090 Introduce LDAOptimizer to LDA to further improve extensibility - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.mllib.clustering.LDA$EMOptimizer") - ) ++ Seq( - // SPARK-6756 add toSparse, toDense, numActives, numNonzeros, and compressed to Vector - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Vector.compressed"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Vector.toDense"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Vector.numNonzeros"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Vector.toSparse"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Vector.numActives"), - // SPARK-7681 add SparseVector support for gemv - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Matrix.multiply"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.DenseMatrix.multiply"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.SparseMatrix.multiply") - ) ++ Seq( - // Execution should never be included as its always internal. 
- MimaBuild.excludeSparkPackage("sql.execution"), - // This `protected[sql]` method was removed in 1.3.1 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.SQLContext.checkAnalysis"), - // These `private[sql]` class were removed in 1.4.0: - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.execution.AddExchange"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.execution.AddExchange$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.PartitionSpec"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.PartitionSpec$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.Partition"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.Partition$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.ParquetRelation2$PartitionValues"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.ParquetRelation2$PartitionValues$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.ParquetRelation2"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.ParquetRelation2$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.ParquetRelation2$MetadataCache"), - // These test support classes were moved out of src/main and into src/test: - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.ParquetTestData"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.ParquetTestData$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.TestGroupWriteSupport"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CachedData"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CachedData$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CacheManager"), - // TODO: Remove the following rule once ParquetTest has been moved to src/test. - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.ParquetTest") - ) ++ Seq( - // SPARK-7530 Added StreamingContext.getState() - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.streaming.StreamingContext.state_=") - ) ++ Seq( - // SPARK-7081 changed ShuffleWriter from a trait to an abstract class and removed some - // unnecessary type bounds in order to fix some compiler warnings that occurred when - // implementing this interface in Java. Note that ShuffleWriter is private[spark]. - ProblemFilters.exclude[IncompatibleTemplateDefProblem]( - "org.apache.spark.shuffle.ShuffleWriter") - ) ++ Seq( - // SPARK-6888 make jdbc driver handling user definable - // This patch renames some classes to API friendly names. 
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DriverQuirks$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DriverQuirks"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.PostgresQuirks"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.NoQuirks"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.MySQLQuirks") - ) + case v if v.startsWith("1.4") => + Seq( + MimaBuild.excludeSparkPackage("deploy"), + MimaBuild.excludeSparkPackage("ml"), + // SPARK-7910 Adding a method to get the partioner to JavaRDD, + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitioner"), + // SPARK-5922 Adding a generalized diff(other: RDD[(VertexId, VD)]) to VertexRDD + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.diff"), + // These are needed if checking against the sbt build, since they are part of + // the maven-generated artifacts in 1.3. + excludePackage("org.spark-project.jetty"), + MimaBuild.excludeSparkPackage("unused"), + ProblemFilters.exclude[MissingClassProblem]("com.google.common.base.Optional"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.rdd.JdbcRDD.compute"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.broadcast.HttpBroadcastFactory.newBroadcast"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.broadcast.TorrentBroadcastFactory.newBroadcast"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.scheduler.OutputCommitCoordinator$OutputCommitCoordinatorEndpoint") + ) ++ Seq( + // SPARK-4655 - Making Stage an Abstract class broke binary compatility even though + // the stage class is defined as private[spark] + ProblemFilters.exclude[AbstractClassProblem]("org.apache.spark.scheduler.Stage") + ) ++ Seq( + // SPARK-6510 Add a Graph#minus method acting as Set#difference + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.minus") + ) ++ Seq( + // SPARK-6492 Fix deadlock in SparkContext.stop() + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.org$" + + "apache$spark$SparkContext$$SPARK_CONTEXT_CONSTRUCTOR_LOCK") + )++ Seq( + // SPARK-6693 add tostring with max lines and width for matrix + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.toString") + )++ Seq( + // SPARK-6703 Add getOrCreate method to SparkContext + ProblemFilters.exclude[IncompatibleResultTypeProblem] + ("org.apache.spark.SparkContext.org$apache$spark$SparkContext$$activeContext") + )++ Seq( + // SPARK-7090 Introduce LDAOptimizer to LDA to further improve extensibility + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.mllib.clustering.LDA$EMOptimizer") + ) ++ Seq( + // SPARK-6756 add toSparse, toDense, numActives, numNonzeros, and compressed to Vector + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.compressed"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.toDense"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.numNonzeros"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.toSparse"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.numActives"), + // SPARK-7681 add SparseVector support for gemv + 
ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.multiply"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.DenseMatrix.multiply"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.SparseMatrix.multiply") + ) ++ Seq( + // Execution should never be included as its always internal. + MimaBuild.excludeSparkPackage("sql.execution"), + // This `protected[sql]` method was removed in 1.3.1 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.SQLContext.checkAnalysis"), + // These `private[sql]` class were removed in 1.4.0: + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.execution.AddExchange"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.execution.AddExchange$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.PartitionSpec"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.PartitionSpec$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.Partition"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.Partition$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetRelation2$PartitionValues"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetRelation2$PartitionValues$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetRelation2"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetRelation2$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetRelation2$MetadataCache"), + // These test support classes were moved out of src/main and into src/test: + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetTestData"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetTestData$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.TestGroupWriteSupport"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CachedData"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CachedData$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CacheManager"), + // TODO: Remove the following rule once ParquetTest has been moved to src/test. + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetTest") + ) ++ Seq( + // SPARK-7530 Added StreamingContext.getState() + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.StreamingContext.state_=") + ) ++ Seq( + // SPARK-7081 changed ShuffleWriter from a trait to an abstract class and removed some + // unnecessary type bounds in order to fix some compiler warnings that occurred when + // implementing this interface in Java. Note that ShuffleWriter is private[spark]. + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.shuffle.ShuffleWriter") + ) ++ Seq( + // SPARK-6888 make jdbc driver handling user definable + // This patch renames some classes to API friendly names. 
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DriverQuirks$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DriverQuirks"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.PostgresQuirks"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.NoQuirks"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.MySQLQuirks") + ) - case v if v.startsWith("1.3") => - Seq( - MimaBuild.excludeSparkPackage("deploy"), - MimaBuild.excludeSparkPackage("ml"), - // These are needed if checking against the sbt build, since they are part of - // the maven-generated artifacts in the 1.2 build. - MimaBuild.excludeSparkPackage("unused"), - ProblemFilters.exclude[MissingClassProblem]("com.google.common.base.Optional") - ) ++ Seq( - // SPARK-2321 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.SparkStageInfoImpl.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.SparkStageInfo.submissionTime") - ) ++ Seq( - // SPARK-4614 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Matrices.randn"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Matrices.rand") - ) ++ Seq( - // SPARK-5321 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.SparseMatrix.transposeMultiply"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Matrix.transpose"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.DenseMatrix.transposeMultiply"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix." + - "org$apache$spark$mllib$linalg$Matrix$_setter_$isTransposed_="), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Matrix.isTransposed"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Matrix.foreachActive") - ) ++ Seq( - // SPARK-5540 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.solveLeastSquares"), - // SPARK-5536 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateBlock") - ) ++ Seq( - // SPARK-3325 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.streaming.api.java.JavaDStreamLike.print"), - // SPARK-2757 - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.streaming.flume.sink.SparkAvroCallbackHandler." 
+ - "removeAndGetProcessor") - ) ++ Seq( - // SPARK-5123 (SparkSQL data type change) - alpha component only - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.ml.feature.HashingTF.outputDataType"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.ml.feature.Tokenizer.outputDataType"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.ml.feature.Tokenizer.validateInputType"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.ml.classification.LogisticRegressionModel.validateAndTransformSchema"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.ml.classification.LogisticRegression.validateAndTransformSchema") - ) ++ Seq( - // SPARK-4014 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.TaskContext.taskAttemptId"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.TaskContext.attemptNumber") - ) ++ Seq( - // SPARK-5166 Spark SQL API stabilization - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Transformer.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Estimator.fit"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Transformer.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Pipeline.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.PipelineModel.transform"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Estimator.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Evaluator.evaluate"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Evaluator.evaluate"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidator.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidatorModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScaler.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScalerModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegression.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.evaluate") - ) ++ Seq( - // SPARK-5270 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.isEmpty") - ) ++ Seq( - // SPARK-5430 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.treeReduce"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.treeAggregate") - ) ++ Seq( - // SPARK-5297 Java FileStream do not work with custom key/values - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.streaming.api.java.JavaStreamingContext.fileStream") - ) ++ Seq( - // SPARK-5315 Spark Streaming Java API returns Scala DStream - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.streaming.api.java.JavaDStreamLike.reduceByWindow") - ) ++ Seq( - // SPARK-5461 Graph should have isCheckpointed, getCheckpointFiles methods - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.graphx.Graph.getCheckpointFiles"), - 
ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.graphx.Graph.isCheckpointed") - ) ++ Seq( - // SPARK-4789 Standardize ML Prediction APIs - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.mllib.linalg.VectorUDT"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.serialize"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.sqlType") - ) ++ Seq( - // SPARK-5814 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$wrapDoubleArray"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$fillFullMatrix"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$iterations"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$makeOutLinkBlock"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$computeYtY"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$makeLinkRDDs"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$alpha"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$randomFactor"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$makeInLinkBlock"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$dspr"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$lambda"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$implicitPrefs"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$rank") - ) ++ Seq( - // SPARK-4682 - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.RealClock"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.Clock"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.TestClock") - ) ++ Seq( - // SPARK-5922 Adding a generalized diff(other: RDD[(VertexId, VD)]) to VertexRDD - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.diff") - ) + case v if v.startsWith("1.3") => + Seq( + MimaBuild.excludeSparkPackage("deploy"), + MimaBuild.excludeSparkPackage("ml"), + // These are needed if checking against the sbt build, since they are part of + // the maven-generated artifacts in the 1.2 build. 
+ MimaBuild.excludeSparkPackage("unused"), + ProblemFilters.exclude[MissingClassProblem]("com.google.common.base.Optional") + ) ++ Seq( + // SPARK-2321 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.SparkStageInfoImpl.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.SparkStageInfo.submissionTime") + ) ++ Seq( + // SPARK-4614 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrices.randn"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrices.rand") + ) ++ Seq( + // SPARK-5321 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.SparseMatrix.transposeMultiply"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.transpose"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.DenseMatrix.transposeMultiply"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix." + + "org$apache$spark$mllib$linalg$Matrix$_setter_$isTransposed_="), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.isTransposed"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.foreachActive") + ) ++ Seq( + // SPARK-5540 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.solveLeastSquares"), + // SPARK-5536 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateBlock") + ) ++ Seq( + // SPARK-3325 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.api.java.JavaDStreamLike.print"), + // SPARK-2757 + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.streaming.flume.sink.SparkAvroCallbackHandler." 
+ + "removeAndGetProcessor") + ) ++ Seq( + // SPARK-5123 (SparkSQL data type change) - alpha component only + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.ml.feature.HashingTF.outputDataType"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.ml.feature.Tokenizer.outputDataType"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.ml.feature.Tokenizer.validateInputType"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.ml.classification.LogisticRegressionModel.validateAndTransformSchema"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.ml.classification.LogisticRegression.validateAndTransformSchema") + ) ++ Seq( + // SPARK-4014 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.TaskContext.taskAttemptId"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.TaskContext.attemptNumber") + ) ++ Seq( + // SPARK-5166 Spark SQL API stabilization + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Transformer.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Estimator.fit"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Transformer.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Pipeline.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.PipelineModel.transform"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Estimator.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Evaluator.evaluate"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Evaluator.evaluate"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidator.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidatorModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScaler.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScalerModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegression.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.evaluate") + ) ++ Seq( + // SPARK-5270 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.isEmpty") + ) ++ Seq( + // SPARK-5430 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.treeReduce"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.treeAggregate") + ) ++ Seq( + // SPARK-5297 Java FileStream do not work with custom key/values + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.api.java.JavaStreamingContext.fileStream") + ) ++ Seq( + // SPARK-5315 Spark Streaming Java API returns Scala DStream + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.api.java.JavaDStreamLike.reduceByWindow") + ) ++ Seq( + // SPARK-5461 Graph should have isCheckpointed, getCheckpointFiles methods + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.graphx.Graph.getCheckpointFiles"), + 
ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.graphx.Graph.isCheckpointed") + ) ++ Seq( + // SPARK-4789 Standardize ML Prediction APIs + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.mllib.linalg.VectorUDT"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.serialize"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.sqlType") + ) ++ Seq( + // SPARK-5814 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$wrapDoubleArray"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$fillFullMatrix"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$iterations"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$makeOutLinkBlock"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$computeYtY"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$makeLinkRDDs"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$alpha"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$randomFactor"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$makeInLinkBlock"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$dspr"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$lambda"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$implicitPrefs"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$rank") + ) ++ Seq( + // SPARK-4682 + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.RealClock"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.Clock"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.TestClock") + ) ++ Seq( + // SPARK-5922 Adding a generalized diff(other: RDD[(VertexId, VD)]) to VertexRDD + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.diff") + ) - case v if v.startsWith("1.2") => - Seq( - MimaBuild.excludeSparkPackage("deploy"), - MimaBuild.excludeSparkPackage("graphx") - ) ++ - MimaBuild.excludeSparkClass("mllib.linalg.Matrix") ++ - MimaBuild.excludeSparkClass("mllib.linalg.Vector") ++ - Seq( - ProblemFilters.exclude[IncompatibleTemplateDefProblem]( - "org.apache.spark.scheduler.TaskLocation"), - // Added normL1 and normL2 to trait MultivariateStatisticalSummary - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL2"), - 
// MapStatus should be private[spark] - ProblemFilters.exclude[IncompatibleTemplateDefProblem]( - "org.apache.spark.scheduler.MapStatus"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.network.netty.PathResolver"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.network.netty.client.BlockClientListener"), + case v if v.startsWith("1.2") => + Seq( + MimaBuild.excludeSparkPackage("deploy"), + MimaBuild.excludeSparkPackage("graphx") + ) ++ + MimaBuild.excludeSparkClass("mllib.linalg.Matrix") ++ + MimaBuild.excludeSparkClass("mllib.linalg.Vector") ++ + Seq( + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.scheduler.TaskLocation"), + // Added normL1 and normL2 to trait MultivariateStatisticalSummary + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL2"), + // MapStatus should be private[spark] + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.scheduler.MapStatus"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.network.netty.PathResolver"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.network.netty.client.BlockClientListener"), - // TaskContext was promoted to Abstract class - ProblemFilters.exclude[AbstractClassProblem]( - "org.apache.spark.TaskContext"), - ProblemFilters.exclude[IncompatibleTemplateDefProblem]( - "org.apache.spark.util.collection.SortDataFormat") - ) ++ Seq( - // Adding new methods to the JavaRDDLike trait: - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.takeAsync"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.foreachPartitionAsync"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.countAsync"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.foreachAsync"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.collectAsync") - ) ++ Seq( - // SPARK-3822 - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.SparkContext.org$apache$spark$SparkContext$$createTaskScheduler") - ) ++ Seq( - // SPARK-1209 - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.hadoop.mapreduce.SparkHadoopMapReduceUtil"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.hadoop.mapred.SparkHadoopMapRedUtil"), - ProblemFilters.exclude[MissingTypesProblem]( - "org.apache.spark.rdd.PairRDDFunctions") - ) ++ Seq( - // SPARK-4062 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.streaming.kafka.KafkaReceiver#MessageHandler.this") - ) + // TaskContext was promoted to Abstract class + ProblemFilters.exclude[AbstractClassProblem]( + "org.apache.spark.TaskContext"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.util.collection.SortDataFormat") + ) ++ Seq( + // Adding new methods to the JavaRDDLike trait: + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.takeAsync"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.foreachPartitionAsync"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.countAsync"), + ProblemFilters.exclude[MissingMethodProblem]( + 
"org.apache.spark.api.java.JavaRDDLike.foreachAsync"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.collectAsync") + ) ++ Seq( + // SPARK-3822 + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.SparkContext.org$apache$spark$SparkContext$$createTaskScheduler") + ) ++ Seq( + // SPARK-1209 + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.hadoop.mapreduce.SparkHadoopMapReduceUtil"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.hadoop.mapred.SparkHadoopMapRedUtil"), + ProblemFilters.exclude[MissingTypesProblem]( + "org.apache.spark.rdd.PairRDDFunctions") + ) ++ Seq( + // SPARK-4062 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.kafka.KafkaReceiver#MessageHandler.this") + ) - case v if v.startsWith("1.1") => - Seq( - MimaBuild.excludeSparkPackage("deploy"), - MimaBuild.excludeSparkPackage("graphx") - ) ++ - Seq( - // Adding new method to JavaRDLike trait - we should probably mark this as a developer API. - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitions"), - // Should probably mark this as Experimental - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.foreachAsync"), - // We made a mistake earlier (ed06500d3) in the Java API to use default parameter values - // for countApproxDistinct* functions, which does not work in Java. We later removed - // them, and use the following to tell Mima to not care about them. - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaPairRDD.countApproxDistinct$default$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey$default$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDD.countApproxDistinct$default$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.countApproxDistinct$default$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaDoubleRDD.countApproxDistinct$default$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.storage.DiskStore.getValues"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.storage.MemoryStore.Entry") - ) ++ - Seq( - // Serializer interface change. See SPARK-3045. 
- ProblemFilters.exclude[IncompatibleTemplateDefProblem]( - "org.apache.spark.serializer.DeserializationStream"), - ProblemFilters.exclude[IncompatibleTemplateDefProblem]( - "org.apache.spark.serializer.Serializer"), - ProblemFilters.exclude[IncompatibleTemplateDefProblem]( - "org.apache.spark.serializer.SerializationStream"), - ProblemFilters.exclude[IncompatibleTemplateDefProblem]( - "org.apache.spark.serializer.SerializerInstance") - )++ - Seq( - // Renamed putValues -> putArray + putIterator - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.storage.MemoryStore.putValues"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.storage.DiskStore.putValues"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.storage.TachyonStore.putValues") - ) ++ - Seq( - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.streaming.flume.FlumeReceiver.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.streaming.kafka.KafkaUtils.createStream"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.streaming.kafka.KafkaReceiver.this") - ) ++ - Seq( // Ignore some private methods in ALS. - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures"), - ProblemFilters.exclude[MissingMethodProblem]( // The only public constructor is the one without arguments. - "org.apache.spark.mllib.recommendation.ALS.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$$default$7"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures") - ) ++ - MimaBuild.excludeSparkClass("mllib.linalg.distributed.ColumnStatisticsAggregator") ++ - MimaBuild.excludeSparkClass("rdd.ZippedRDD") ++ - MimaBuild.excludeSparkClass("rdd.ZippedPartition") ++ - MimaBuild.excludeSparkClass("util.SerializableHyperLogLog") ++ - MimaBuild.excludeSparkClass("storage.Values") ++ - MimaBuild.excludeSparkClass("storage.Entry") ++ - MimaBuild.excludeSparkClass("storage.MemoryStore$Entry") ++ - // Class was missing "@DeveloperApi" annotation in 1.0. 
- MimaBuild.excludeSparkClass("scheduler.SparkListenerApplicationStart") ++ - Seq( - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.mllib.tree.impurity.Gini.calculate"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.mllib.tree.impurity.Entropy.calculate"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.mllib.tree.impurity.Variance.calculate") - ) ++ - Seq( // Package-private classes removed in SPARK-2341 - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.BinaryLabelParser"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.BinaryLabelParser$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.LabelParser"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.LabelParser$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser$") - ) ++ - Seq( // package-private classes removed in MLlib - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.regression.GeneralizedLinearAlgorithm.org$apache$spark$mllib$regression$GeneralizedLinearAlgorithm$$prependOne") - ) ++ - Seq( // new Vector methods in MLlib (binary compatible assuming users do not implement Vector) - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Vector.copy") - ) ++ - Seq( // synthetic methods generated in LabeledPoint - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.mllib.regression.LabeledPoint$"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.regression.LabeledPoint.apply"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.regression.LabeledPoint.toString") - ) ++ - Seq ( // Scala 2.11 compatibility fix - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.$default$2") - ) - case v if v.startsWith("1.0") => - Seq( - MimaBuild.excludeSparkPackage("api.java"), - MimaBuild.excludeSparkPackage("mllib"), - MimaBuild.excludeSparkPackage("streaming") - ) ++ - MimaBuild.excludeSparkClass("rdd.ClassTags") ++ - MimaBuild.excludeSparkClass("util.XORShiftRandom") ++ - MimaBuild.excludeSparkClass("graphx.EdgeRDD") ++ - MimaBuild.excludeSparkClass("graphx.VertexRDD") ++ - MimaBuild.excludeSparkClass("graphx.impl.GraphImpl") ++ - MimaBuild.excludeSparkClass("graphx.impl.RoutingTable") ++ - MimaBuild.excludeSparkClass("graphx.util.collection.PrimitiveKeyOpenHashMap") ++ - MimaBuild.excludeSparkClass("graphx.util.collection.GraphXPrimitiveKeyOpenHashMap") ++ - MimaBuild.excludeSparkClass("mllib.recommendation.MFDataGenerator") ++ - MimaBuild.excludeSparkClass("mllib.optimization.SquaredGradient") ++ - MimaBuild.excludeSparkClass("mllib.regression.RidgeRegressionWithSGD") ++ - MimaBuild.excludeSparkClass("mllib.regression.LassoWithSGD") ++ - MimaBuild.excludeSparkClass("mllib.regression.LinearRegressionWithSGD") - case _ => Seq() - } + case v if v.startsWith("1.1") => + Seq( + MimaBuild.excludeSparkPackage("deploy"), + MimaBuild.excludeSparkPackage("graphx") + ) ++ + Seq( + // Adding new method to JavaRDLike trait - we should probably mark this as a developer API. 
+ ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitions"), + // Should probably mark this as Experimental + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.foreachAsync"), + // We made a mistake earlier (ed06500d3) in the Java API to use default parameter values + // for countApproxDistinct* functions, which does not work in Java. We later removed + // them, and use the following to tell Mima to not care about them. + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaPairRDD.countApproxDistinct$default$1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey$default$1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDD.countApproxDistinct$default$1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.countApproxDistinct$default$1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaDoubleRDD.countApproxDistinct$default$1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.storage.DiskStore.getValues"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.storage.MemoryStore.Entry") + ) ++ + Seq( + // Serializer interface change. See SPARK-3045. + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.serializer.DeserializationStream"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.serializer.Serializer"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.serializer.SerializationStream"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.serializer.SerializerInstance") + )++ + Seq( + // Renamed putValues -> putArray + putIterator + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.storage.MemoryStore.putValues"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.storage.DiskStore.putValues"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.storage.TachyonStore.putValues") + ) ++ + Seq( + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.flume.FlumeReceiver.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.streaming.kafka.KafkaUtils.createStream"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.streaming.kafka.KafkaReceiver.this") + ) ++ + Seq( // Ignore some private methods in ALS. + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures"), + ProblemFilters.exclude[MissingMethodProblem]( // The only public constructor is the one without arguments. 
+ "org.apache.spark.mllib.recommendation.ALS.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$$default$7"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures") + ) ++ + MimaBuild.excludeSparkClass("mllib.linalg.distributed.ColumnStatisticsAggregator") ++ + MimaBuild.excludeSparkClass("rdd.ZippedRDD") ++ + MimaBuild.excludeSparkClass("rdd.ZippedPartition") ++ + MimaBuild.excludeSparkClass("util.SerializableHyperLogLog") ++ + MimaBuild.excludeSparkClass("storage.Values") ++ + MimaBuild.excludeSparkClass("storage.Entry") ++ + MimaBuild.excludeSparkClass("storage.MemoryStore$Entry") ++ + // Class was missing "@DeveloperApi" annotation in 1.0. + MimaBuild.excludeSparkClass("scheduler.SparkListenerApplicationStart") ++ + Seq( + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.mllib.tree.impurity.Gini.calculate"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.mllib.tree.impurity.Entropy.calculate"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.mllib.tree.impurity.Variance.calculate") + ) ++ + Seq( // Package-private classes removed in SPARK-2341 + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.BinaryLabelParser"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.BinaryLabelParser$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.LabelParser"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.LabelParser$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser$") + ) ++ + Seq( // package-private classes removed in MLlib + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.regression.GeneralizedLinearAlgorithm.org$apache$spark$mllib$regression$GeneralizedLinearAlgorithm$$prependOne") + ) ++ + Seq( // new Vector methods in MLlib (binary compatible assuming users do not implement Vector) + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Vector.copy") + ) ++ + Seq( // synthetic methods generated in LabeledPoint + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.mllib.regression.LabeledPoint$"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.regression.LabeledPoint.apply"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.regression.LabeledPoint.toString") + ) ++ + Seq ( // Scala 2.11 compatibility fix + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.$default$2") + ) + case v if v.startsWith("1.0") => + Seq( + MimaBuild.excludeSparkPackage("api.java"), + MimaBuild.excludeSparkPackage("mllib"), + MimaBuild.excludeSparkPackage("streaming") + ) ++ + MimaBuild.excludeSparkClass("rdd.ClassTags") ++ + MimaBuild.excludeSparkClass("util.XORShiftRandom") ++ + MimaBuild.excludeSparkClass("graphx.EdgeRDD") ++ + MimaBuild.excludeSparkClass("graphx.VertexRDD") ++ + MimaBuild.excludeSparkClass("graphx.impl.GraphImpl") ++ + MimaBuild.excludeSparkClass("graphx.impl.RoutingTable") ++ + MimaBuild.excludeSparkClass("graphx.util.collection.PrimitiveKeyOpenHashMap") ++ + 
MimaBuild.excludeSparkClass("graphx.util.collection.GraphXPrimitiveKeyOpenHashMap") ++ + MimaBuild.excludeSparkClass("mllib.recommendation.MFDataGenerator") ++ + MimaBuild.excludeSparkClass("mllib.optimization.SquaredGradient") ++ + MimaBuild.excludeSparkClass("mllib.regression.RidgeRegressionWithSGD") ++ + MimaBuild.excludeSparkClass("mllib.regression.LassoWithSGD") ++ + MimaBuild.excludeSparkClass("mllib.regression.LinearRegressionWithSGD") + case _ => Seq() + } } diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 9a33baa7c6ce..b1dcaedcba75 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -16,15 +16,16 @@ */ import java.io._ +import java.nio.file.Files import scala.util.Properties -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import sbt._ import sbt.Classpaths.publishTask import sbt.Keys._ import sbtunidoc.Plugin.UnidocKeys.unidocGenjavadocVersion -import com.typesafe.sbt.pom.{loadEffectivePom, PomBuild, SbtPomKeys} +import com.typesafe.sbt.pom.{PomBuild, SbtPomKeys} import net.virtualvoid.sbt.graph.Plugin.graphSettings import spray.revolver.RevolverPlugin._ @@ -35,18 +36,19 @@ object BuildCommons { val allProjects@Seq(bagel, catalyst, core, graphx, hive, hiveThriftServer, mllib, repl, sql, networkCommon, networkShuffle, streaming, streamingFlumeSink, streamingFlume, streamingKafka, - streamingMqtt, streamingTwitter, streamingZeromq, launcher, unsafe) = + streamingMqtt, streamingTwitter, streamingZeromq, launcher, unsafe, testTags) = Seq("bagel", "catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl", "sql", "network-common", "network-shuffle", "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter", - "streaming-zeromq", "launcher", "unsafe").map(ProjectRef(buildLocation, _)) + "streaming-zeromq", "launcher", "unsafe", "test-tags").map(ProjectRef(buildLocation, _)) - val optionallyEnabledProjects@Seq(yarn, yarnStable, java8Tests, sparkGangliaLgpl, - sparkKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl", - "kinesis-asl").map(ProjectRef(buildLocation, _)) + val optionallyEnabledProjects@Seq(yarn, java8Tests, sparkGangliaLgpl, + streamingKinesisAsl, dockerIntegrationTests) = + Seq("yarn", "java8-tests", "ganglia-lgpl", "streaming-kinesis-asl", + "docker-integration-tests").map(ProjectRef(buildLocation, _)) - val assemblyProjects@Seq(assembly, examples, networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingKinesisAslAssembly) = - Seq("assembly", "examples", "network-yarn", "streaming-flume-assembly", "streaming-kafka-assembly", "streaming-kinesis-asl-assembly") + val assemblyProjects@Seq(assembly, examples, networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingMqttAssembly, streamingKinesisAslAssembly) = + Seq("assembly", "examples", "network-yarn", "streaming-flume-assembly", "streaming-kafka-assembly", "streaming-mqtt-assembly", "streaming-kinesis-asl-assembly") .map(ProjectRef(buildLocation, _)) val tools = ProjectRef(buildLocation, "tools") @@ -55,6 +57,9 @@ object BuildCommons { val sparkHome = buildLocation val testTempDir = s"$sparkHome/target/tmp" + + val javacJVMVersion = settingKey[String]("source and target JVM version for javac") + val scalacJVMVersion = settingKey[String]("source and target JVM version for scalac") } object SparkBuild extends PomBuild { @@ -67,7 +72,6 @@ object SparkBuild extends PomBuild { // Provides compatibility for older versions of the Spark 
build def backwardCompatibility = { import scala.collection.mutable - var isAlphaYarn = false var profiles: mutable.Seq[String] = mutable.Seq("sbt") // scalastyle:off println if (Properties.envOrNone("SPARK_GANGLIA_LGPL").isDefined) { @@ -80,7 +84,6 @@ object SparkBuild extends PomBuild { } Properties.envOrNone("SPARK_HADOOP_VERSION") match { case Some(v) => - if (v.matches("0.23.*")) isAlphaYarn = true println("NOTE: SPARK_HADOOP_VERSION is deprecated, please use -Dhadoop.version=" + v) System.setProperty("hadoop.version", v) case None => @@ -120,7 +123,7 @@ object SparkBuild extends PomBuild { case _ => } - override val userPropertiesMap = System.getProperties.toMap + override val userPropertiesMap = System.getProperties.asScala.toMap lazy val MavenCompile = config("m2r") extend(Compile) lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy") @@ -135,8 +138,6 @@ object SparkBuild extends PomBuild { .orElse(sys.props.get("java.home").map { p => new File(p).getParentFile().getAbsolutePath() }) .map(file), incOptions := incOptions.value.withNameHashing(true), - retrieveManaged := true, - retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", publishMavenStyle := true, unidocGenjavadocVersion := "0.9-spark0", @@ -154,13 +155,30 @@ object SparkBuild extends PomBuild { if (major.toInt >= 1 && minor.toInt >= 8) Seq("-Xdoclint:all", "-Xdoclint:-missing") else Seq.empty }, - javacOptions in Compile ++= Seq("-encoding", "UTF-8"), + javacJVMVersion := "1.7", + scalacJVMVersion := "1.7", + + javacOptions in Compile ++= Seq( + "-encoding", "UTF-8", + "-source", javacJVMVersion.value + ), + // This -target option cannot be set in the Compile configuration scope since `javadoc` doesn't + // play nicely with it; see https://github.com/sbt/sbt/issues/355#issuecomment-3817629 for + // additional discussion and explanation. + javacOptions in (Compile, compile) ++= Seq( + "-target", javacJVMVersion.value + ), + + scalacOptions in Compile ++= Seq( + s"-target:jvm-${scalacJVMVersion.value}", + "-sourcepath", (baseDirectory in ThisBuild).value.getAbsolutePath // Required for relative source links in scaladoc + ), // Implements -Xfatal-warnings, ignoring deprecation warnings. // Code snippet taken from https://issues.scala-lang.org/browse/SI-8410. compile in Compile := { val analysis = (compile in Compile).value - val s = streams.value + val out = streams.value def logProblem(l: (=> String) => Unit, f: File, p: xsbti.Problem) = { l(f.toString + ":" + p.position.line.fold("")(_ + ":") + " " + p.message) @@ -177,7 +195,14 @@ object SparkBuild extends PomBuild { failed = failed + 1 } - logProblem(if (deprecation) s.log.warn else s.log.error, k, p) + val printer: (=> String) => Unit = s => if (deprecation) { + out.log.warn(s) + } else { + out.log.error("[warn] " + s) + } + + logProblem(printer, k, p) + } } @@ -196,13 +221,14 @@ object SparkBuild extends PomBuild { // Note ordering of these settings matter. 
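
Editorial note (not part of the patch): the javacJVMVersion/scalacJVMVersion keys introduced above follow a common sbt pattern — one custom settingKey feeds several compiler-option settings, with "-target" kept out of the broad Compile scope because javadoc does not accept it. A minimal, self-contained sketch of that pattern under assumed names (the key name jvmVersion is made up for illustration):

import sbt._
import sbt.Keys._

object JvmTargetSketch {
  // Hypothetical key mirroring the javacJVMVersion/scalacJVMVersion idea above.
  val jvmVersion = settingKey[String]("source and target JVM version")

  lazy val settings: Seq[Setting[_]] = Seq(
    jvmVersion := "1.7",
    // -source is safe in the broad Compile scope (used by both compile and javadoc)...
    javacOptions in Compile ++= Seq("-source", jvmVersion.value),
    // ...while -target is scoped to (Compile, compile) only, so javadoc never sees it.
    javacOptions in (Compile, compile) ++= Seq("-target", jvmVersion.value),
    scalacOptions in Compile += s"-target:jvm-${jvmVersion.value}"
  )
}
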
/* Enable shared settings on all projects */ (allProjects ++ optionallyEnabledProjects ++ assemblyProjects ++ Seq(spark, tools)) - .foreach(enable(sharedSettings ++ ExcludedDependencies.settings ++ Revolver.settings)) + .foreach(enable(sharedSettings ++ DependencyOverrides.settings ++ + ExcludedDependencies.settings ++ Revolver.settings)) /* Enable tests settings for all projects except examples, assembly and tools */ (allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings)) allProjects.filterNot(x => Seq(spark, hive, hiveThriftServer, catalyst, repl, - networkCommon, networkShuffle, networkYarn, unsafe).contains(x)).foreach { + networkCommon, networkShuffle, networkYarn, unsafe, testTags).contains(x)).foreach { x => enable(MimaBuild.mimaSettings(sparkHome, x))(x) } @@ -212,6 +238,9 @@ object SparkBuild extends PomBuild { /* Enable Assembly for all assembly projects */ assemblyProjects.foreach(enable(Assembly.settings)) + /* Enable Assembly for streamingMqtt test */ + enable(inConfig(Test)(Assembly.settings))(streamingMqtt) + /* Package pyspark artifacts in a separate zip file for YARN. */ enable(PySparkAssembly.settings)(assembly) @@ -226,6 +255,9 @@ object SparkBuild extends PomBuild { enable(Flume.settings)(streamingFlumeSink) + enable(Java8TestSettings.settings)(java8Tests) + + enable(DockerIntegrationTests.settings)(dockerIntegrationTests) /** * Adds the ability to run the spark shell directly from SBT without building an assembly @@ -277,6 +309,21 @@ object Flume { lazy val settings = sbtavro.SbtAvro.avroSettings } +object DockerIntegrationTests { + // This serves to override the override specified in DependencyOverrides: + lazy val settings = Seq( + dependencyOverrides += "com.google.guava" % "guava" % "18.0" + ) +} + +/** + * Overrides to work around sbt's dependency resolution being different from Maven's. + */ +object DependencyOverrides { + lazy val settings = Seq( + dependencyOverrides += "com.google.guava" % "guava" % "14.0.1") +} + /** This excludes library dependencies in sbt, which are specified in maven but are not needed by sbt build. 
@@ -302,9 +349,7 @@ object OldDeps { def oldDepsSettings() = Defaults.coreDefaultSettings ++ Seq( name := "old-deps", - scalaVersion := "2.10.4", - retrieveManaged := true, - retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", + scalaVersion := "2.10.5", libraryDependencies := Seq("spark-streaming-mqtt", "spark-streaming-zeromq", "spark-streaming-flume", "spark-streaming-kafka", "spark-streaming-twitter", "spark-streaming", "spark-mllib", "spark-bagel", "spark-graphx", @@ -316,6 +361,8 @@ object SQL { lazy val settings = Seq( initialCommands in console := """ + |import org.apache.spark.SparkContext + |import org.apache.spark.sql.SQLContext |import org.apache.spark.sql.catalyst.analysis._ |import org.apache.spark.sql.catalyst.dsl._ |import org.apache.spark.sql.catalyst.errors._ @@ -325,9 +372,14 @@ object SQL { |import org.apache.spark.sql.catalyst.util._ |import org.apache.spark.sql.execution |import org.apache.spark.sql.functions._ - |import org.apache.spark.sql.test.TestSQLContext._ - |import org.apache.spark.sql.types._""".stripMargin, - cleanupCommands in console := "sparkContext.stop()" + |import org.apache.spark.sql.types._ + | + |val sc = new SparkContext("local[*]", "dev-shell") + |val sqlContext = new SQLContext(sc) + |import sqlContext.implicits._ + |import sqlContext._ + """.stripMargin, + cleanupCommands in console := "sc.stop()" ) } @@ -337,8 +389,6 @@ object Hive { javaOptions += "-XX:MaxPermSize=256m", // Specially disable assertions since some Hive tests fail them javaOptions in Test := (javaOptions in Test).value.filterNot(_ == "-ea"), - // Multiple queries rely on the TestHive singleton. See comments there for more details. - parallelExecution in Test := false, // Supporting all SerDes requires us to depend on deprecated APIs, so we turn off the warnings // only for this subproject. 
scalacOptions <<= scalacOptions map { currentOpts: Seq[String] => @@ -346,6 +396,7 @@ object Hive { }, initialCommands in console := """ + |import org.apache.spark.SparkContext |import org.apache.spark.sql.catalyst.analysis._ |import org.apache.spark.sql.catalyst.dsl._ |import org.apache.spark.sql.catalyst.errors._ @@ -375,6 +426,8 @@ object Assembly { val hadoopVersion = taskKey[String]("The version of hadoop that spark is compiled against.") + val deployDatanucleusJars = taskKey[Unit]("Deploy datanucleus jars to the spark/lib_managed/jars directory") + lazy val settings = assemblySettings ++ Seq( test in assembly := {}, hadoopVersion := { @@ -382,13 +435,16 @@ object Assembly { .getOrElse(SbtPomKeys.effectivePom.value.getProperties.get("hadoop.version").asInstanceOf[String]) }, jarName in assembly <<= (version, moduleName, hadoopVersion) map { (v, mName, hv) => - if (mName.contains("streaming-flume-assembly") || mName.contains("streaming-kafka-assembly") || mName.contains("streaming-kinesis-asl-assembly")) { + if (mName.contains("streaming-flume-assembly") || mName.contains("streaming-kafka-assembly") || mName.contains("streaming-mqtt-assembly") || mName.contains("streaming-kinesis-asl-assembly")) { // This must match the same name used in maven (see external/kafka-assembly/pom.xml) s"${mName}-${v}.jar" } else { s"${mName}-${v}-hadoop${hv}.jar" } }, + jarName in (Test, assembly) <<= (version, moduleName, hadoopVersion) map { (v, mName, hv) => + s"${mName}-test-${v}.jar" + }, mergeStrategy in assembly := { case PathList("org", "datanucleus", xs @ _*) => MergeStrategy.discard case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard @@ -397,7 +453,20 @@ object Assembly { case m if m.toLowerCase.startsWith("meta-inf/services/") => MergeStrategy.filterDistinctLines case "reference.conf" => MergeStrategy.concat case _ => MergeStrategy.first - } + }, + deployDatanucleusJars := { + val jars: Seq[File] = (fullClasspath in assembly).value.map(_.data) + .filter(_.getPath.contains("org.datanucleus")) + var libManagedJars = new File(BuildCommons.sparkHome, "lib_managed/jars") + libManagedJars.mkdirs() + jars.foreach { jar => + val dest = new File(libManagedJars, jar.getName) + if (!dest.exists()) { + Files.copy(jar.toPath, dest.toPath) + } + } + }, + assembly <<= assembly.dependsOn(deployDatanucleusJars) ) } @@ -466,6 +535,8 @@ object Unidoc { .map(_.filterNot(_.getName.contains("$"))) .map(_.filterNot(_.getCanonicalPath.contains("akka"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/deploy"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/examples"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/memory"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/network"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/shuffle"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/executor"))) @@ -477,13 +548,15 @@ object Unidoc { .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/hive/test"))) } + val unidocSourceBase = settingKey[String]("Base URL of source links in Scaladoc.") + lazy val settings = scalaJavaUnidocSettings ++ Seq ( publish := {}, unidocProjectFilter in(ScalaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, testTags), unidocProjectFilter in(JavaUnidoc, unidoc) := - inAnyProject -- 
inProjects(OldDeps.project, repl, bagel, examples, tools, streamingFlumeSink, yarn), + inAnyProject -- inProjects(OldDeps.project, repl, bagel, examples, tools, streamingFlumeSink, yarn, testTags), // Skip actual catalyst, but include the subproject. // Catalyst is not public API and contains quasiquotes which break scaladoc. @@ -519,8 +592,29 @@ object Unidoc { "-noqualifier", "java.lang" ), - // Group similar methods together based on the @group annotation. - scalacOptions in (ScalaUnidoc, unidoc) ++= Seq("-groups") + // Use GitHub repository for Scaladoc source linke + unidocSourceBase := s"https://github.com/apache/spark/tree/v${version.value}", + + scalacOptions in (ScalaUnidoc, unidoc) ++= Seq( + "-groups" // Group similar methods together based on the @group annotation. + ) ++ ( + // Add links to sources when generating Scaladoc for a non-snapshot release + if (!isSnapshot.value) { + Opts.doc.sourceUrl(unidocSourceBase.value + "€{FILE_PATH}.scala") + } else { + Seq() + } + ) + ) +} + +object Java8TestSettings { + import BuildCommons._ + + lazy val settings = Seq( + javacJVMVersion := "1.8", + // Targeting Java 8 bytecode is only supported in Scala 2.11.4 and higher: + scalacJVMVersion := (if (System.getProperty("scala-2.11") == "true") "1.8" else "1.7") ) } @@ -535,28 +629,39 @@ object TestSettings { envVars in Test ++= Map( "SPARK_DIST_CLASSPATH" -> (fullClasspath in Test).value.files.map(_.getAbsolutePath).mkString(":").stripSuffix(":"), + "SPARK_PREPEND_CLASSES" -> "1", + "SPARK_TESTING" -> "1", "JAVA_HOME" -> sys.env.get("JAVA_HOME").getOrElse(sys.props("java.home"))), javaOptions in Test += s"-Djava.io.tmpdir=$testTempDir", javaOptions in Test += "-Dspark.test.home=" + sparkHome, javaOptions in Test += "-Dspark.testing=1", javaOptions in Test += "-Dspark.port.maxRetries=100", + javaOptions in Test += "-Dspark.master.rest.enabled=false", javaOptions in Test += "-Dspark.ui.enabled=false", javaOptions in Test += "-Dspark.ui.showConsoleProgress=false", - javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true", javaOptions in Test += "-Dspark.unsafe.exceptionOnMemoryLeak=true", javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true", javaOptions in Test += "-Dderby.system.durability=test", - javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark") + javaOptions in Test ++= System.getProperties.asScala.filter(_._1.startsWith("spark")) .map { case (k,v) => s"-D$k=$v" }.toSeq, javaOptions in Test += "-ea", javaOptions in Test ++= "-Xmx3g -Xss4096k -XX:PermSize=128M -XX:MaxNewSize=256m -XX:MaxPermSize=1g" .split(" ").toSeq, javaOptions += "-Xmx3g", + // Exclude tags defined in a system property + testOptions in Test += Tests.Argument(TestFrameworks.ScalaTest, + sys.props.get("test.exclude.tags").map { tags => + tags.split(",").flatMap { tag => Seq("-l", tag) }.toSeq + }.getOrElse(Nil): _*), + testOptions in Test += Tests.Argument(TestFrameworks.JUnit, + sys.props.get("test.exclude.tags").map { tags => + Seq("--exclude-categories=" + tags) + }.getOrElse(Nil): _*), // Show full stack trace and duration in test cases. testOptions in Test += Tests.Argument("-oDF"), - testOptions += Tests.Argument(TestFrameworks.JUnit, "-v", "-a"), + testOptions in Test += Tests.Argument(TestFrameworks.JUnit, "-v", "-a"), // Enable Junit testing. 
- libraryDependencies += "com.novocode" % "junit-interface" % "0.9" % "test", + libraryDependencies += "com.novocode" % "junit-interface" % "0.11" % "test", // Only allow one test at a time, even across projects, since they run in the same JVM parallelExecution in Test := false, // Make sure the test temp directory exists. diff --git a/project/build.properties b/project/build.properties index 064ec843da9e..86ca8755820a 100644 --- a/project/build.properties +++ b/project/build.properties @@ -14,4 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -sbt.version=0.13.7 +sbt.version=0.13.9 diff --git a/project/plugins.sbt b/project/plugins.sbt index 51820460ca1a..5e23224cf8aa 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,5 +1,3 @@ -scalaVersion := "2.10.4" - resolvers += Resolver.url("artifactory", url("http://scalasbt.artifactoryonline.com/scalasbt/sbt-plugin-releases"))(Resolver.ivyStylePatterns) resolvers += "Typesafe Repository" at "http://repo.typesafe.com/typesafe/releases/" @@ -12,14 +10,9 @@ addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "2.2.0") addSbtPlugin("com.github.mpeltonen" % "sbt-idea" % "1.6.0") -// For Sonatype publishing -//resolvers += Resolver.url("sbt-plugin-releases", new URL("http://scalasbt.artifactoryonline.com/scalasbt/sbt-plugin-releases/"))(Resolver.ivyStylePatterns) - -//addSbtPlugin("com.jsuereth" % "xsbt-gpg-plugin" % "0.6") - addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.7.4") -addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.7.0") +addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.8.0") addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.6") diff --git a/project/project/SparkPluginBuild.scala b/project/project/SparkPluginBuild.scala index 471d00bd8223..cbb88dc7dd1d 100644 --- a/project/project/SparkPluginBuild.scala +++ b/project/project/SparkPluginBuild.scala @@ -19,9 +19,8 @@ import sbt._ import sbt.Keys._ /** - * This plugin project is there to define new scala style rules for spark. This is - * a plugin project so that this gets compiled first and is put on the classpath and - * becomes available for scalastyle sbt plugin. + * This plugin project is there because we use our custom fork of sbt-pom-reader plugin. This is + * a plugin project so that this gets compiled first and is available on the classpath for SBT build. */ object SparkPluginDef extends Build { lazy val root = Project("plugins", file(".")) dependsOn(sbtPomReader) diff --git a/python/docs/Makefile b/python/docs/Makefile index 8a1324eecd32..4cec74f057fb 100644 --- a/python/docs/Makefile +++ b/python/docs/Makefile @@ -7,7 +7,7 @@ SPHINXBUILD = sphinx-build PAPER = BUILDDIR = _build -export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.8.2.1-src.zip) +export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.9-src.zip) # User-friendly check for sphinx-build ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) diff --git a/python/docs/_static/pyspark.css b/python/docs/_static/pyspark.css new file mode 100644 index 000000000000..41106f2f6e26 --- /dev/null +++ b/python/docs/_static/pyspark.css @@ -0,0 +1,90 @@ +/* + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. 
+ The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +body { + background-color: #ffffff; +} + +div.sphinxsidebar { + width: 274px; +} + +div.bodywrapper { + margin: 0 0 0 274px; +} + +div.sphinxsidebar ul { + margin-right: 10px; +} + +div.sphinxsidebar li a { + word-break: break-all; +} + +span.pys-tag { + font-size: 11px; + font-weight: bold; + margin: 0 0 0 2px; + padding: 1px 3px 1px 3px; + -moz-border-radius: 3px; + -webkit-border-radius: 3px; + border-radius: 3px; + text-align: center; + text-decoration: none; +} + +span.pys-tag-experimental { + background-color: rgb(37, 112, 128); + color: rgb(255, 255, 255); +} + +span.pys-tag-deprecated { + background-color: rgb(238, 238, 238); + color: rgb(62, 67, 73); +} + +div.pys-note-experimental { + background-color: rgb(88, 151, 165); + border-color: rgb(59, 115, 127); + color: rgb(255, 255, 255); +} + +div.pys-note-deprecated { +} + +.hasTooltip { + position:relative; +} +.hasTooltip span { + display:none; +} + +.hasTooltip:hover span.tooltip { + display: inline-block; + -moz-border-radius: 2px; + -webkit-border-radius: 2px; + border-radius: 2px; + background-color: rgb(250, 250, 250); + color: rgb(68, 68, 68); + font-weight: normal; + box-shadow: 1px 1px 3px rgb(127, 127, 127); + position: absolute; + padding: 0 3px 0 3px; + top: 1.3em; + left: 14px; + z-index: 9999 +} diff --git a/python/docs/_static/pyspark.js b/python/docs/_static/pyspark.js new file mode 100644 index 000000000000..75e4c42492a4 --- /dev/null +++ b/python/docs/_static/pyspark.js @@ -0,0 +1,99 @@ +/* + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ + +$(function (){ + + function startsWith(s, prefix) { + return s && s.indexOf(prefix) === 0; + } + + function buildSidebarLinkMap() { + var linkMap = {}; + $('div.sphinxsidebar a.reference.internal').each(function (i,a) { + var href = $(a).attr('href'); + if (startsWith(href, '#module-')) { + var id = href.substr(8); + linkMap[id] = [$(a), null]; + } + }) + return linkMap; + }; + + function getAdNoteDivs(dd) { + var noteDivs = {}; + dd.find('> div.admonition.note > p.last').each(function (i, p) { + var text = $(p).text(); + if (!noteDivs.experimental && startsWith(text, 'Experimental')) { + noteDivs.experimental = $(p).parent(); + } + if (!noteDivs.deprecated && startsWith(text, 'Deprecated')) { + noteDivs.deprecated = $(p).parent(); + } + }); + return noteDivs; + } + + function getParentId(name) { + var last_idx = name.lastIndexOf('.'); + return last_idx == -1? '': name.substr(0, last_idx); + } + + function buildTag(text, cls, tooltip) { + return '' + text + '' + + tooltip + '' + } + + + var sidebarLinkMap = buildSidebarLinkMap(); + + $('dl.class, dl.function').each(function (i,dl) { + + dl = $(dl); + dt = dl.children('dt').eq(0); + dd = dl.children('dd').eq(0); + var id = dt.attr('id'); + var desc = dt.find('> .descname').text(); + var adNoteDivs = getAdNoteDivs(dd); + + if (id) { + var parent_id = getParentId(id); + + var r = sidebarLinkMap[parent_id]; + if (r) { + if (r[1] === null) { + r[1] = $('

        '); + r[0].parent().append(r[1]); + } + var tags = ''; + if (adNoteDivs.experimental) { + tags += buildTag('E', 'pys-tag-experimental', 'Experimental'); + adNoteDivs.experimental.addClass('pys-note pys-note-experimental'); + } + if (adNoteDivs.deprecated) { + tags += buildTag('D', 'pys-tag-deprecated', 'Deprecated'); + adNoteDivs.deprecated.addClass('pys-note pys-note-deprecated'); + } + var li = $('
      • '); + var a = $('' + desc + ''); + li.append(a); + li.append(tags); + r[1].append(li); + sidebarLinkMap[id] = [a, null]; + } + } + }); +}); diff --git a/python/docs/_templates/layout.html b/python/docs/_templates/layout.html new file mode 100644 index 000000000000..ab36ebababf8 --- /dev/null +++ b/python/docs/_templates/layout.html @@ -0,0 +1,6 @@ +{% extends "!layout.html" %} +{% set script_files = script_files + ["_static/pyspark.js"] %} +{% set css_files = css_files + ['_static/pyspark.css'] %} +{% block rootrellink %} + {{ super() }} +{% endblock %} diff --git a/python/docs/conf.py b/python/docs/conf.py index 163987dd8e5f..365d6af51417 100644 --- a/python/docs/conf.py +++ b/python/docs/conf.py @@ -23,7 +23,7 @@ # -- General configuration ------------------------------------------------ # If your documentation needs a minimal Sphinx version, state it here. -#needs_sphinx = '1.0' +needs_sphinx = '1.2' # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom @@ -135,7 +135,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -#html_static_path = ['_static'] +html_static_path = ['_static'] # Add any extra paths that contain custom files (such as robots.txt or # .htaccess) here, relative to this directory. These files are copied diff --git a/python/docs/index.rst b/python/docs/index.rst index f7eede9c3c82..306ffdb0e0f1 100644 --- a/python/docs/index.rst +++ b/python/docs/index.rst @@ -29,6 +29,14 @@ Core classes: A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. + :class:`pyspark.streaming.StreamingContext` + + Main entry point for Spark Streaming functionality. + + :class:`pyspark.streaming.DStream` + + A Discretized Stream (DStream), the basic abstraction in Spark Streaming. + :class:`pyspark.sql.SQLContext` Main entry point for DataFrame and SQL functionality. diff --git a/python/docs/pyspark.mllib.rst b/python/docs/pyspark.mllib.rst index 26ece4c2c389..2d54ab118b94 100644 --- a/python/docs/pyspark.mllib.rst +++ b/python/docs/pyspark.mllib.rst @@ -46,6 +46,14 @@ pyspark.mllib.linalg module :undoc-members: :show-inheritance: +pyspark.mllib.linalg.distributed module +--------------------------------------- + +.. automodule:: pyspark.mllib.linalg.distributed + :members: + :undoc-members: + :show-inheritance: + pyspark.mllib.random module --------------------------- diff --git a/python/docs/pyspark.streaming.rst b/python/docs/pyspark.streaming.rst index 50822c93faba..fc52a647543e 100644 --- a/python/docs/pyspark.streaming.rst +++ b/python/docs/pyspark.streaming.rst @@ -15,3 +15,24 @@ pyspark.streaming.kafka module :members: :undoc-members: :show-inheritance: + +pyspark.streaming.kinesis module +-------------------------------- +.. automodule:: pyspark.streaming.kinesis + :members: + :undoc-members: + :show-inheritance: + +pyspark.streaming.flume.module +------------------------------ +.. automodule:: pyspark.streaming.flume + :members: + :undoc-members: + :show-inheritance: + +pyspark.streaming.mqtt module +----------------------------- +.. 
automodule:: pyspark.streaming.mqtt + :members: + :undoc-members: + :show-inheritance: diff --git a/python/lib/py4j-0.8.2.1-src.zip b/python/lib/py4j-0.8.2.1-src.zip deleted file mode 100644 index 5203b84d9119..000000000000 Binary files a/python/lib/py4j-0.8.2.1-src.zip and /dev/null differ diff --git a/python/lib/py4j-0.9-src.zip b/python/lib/py4j-0.9-src.zip new file mode 100644 index 000000000000..dace2d0fe3b0 Binary files /dev/null and b/python/lib/py4j-0.9-src.zip differ diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 5f70ac6ed8fe..8475dfb1c6ad 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -48,6 +48,22 @@ from pyspark.status import * from pyspark.profiler import Profiler, BasicProfiler + +def since(version): + """ + A decorator that annotates a function to append the version of Spark the function was added. + """ + import re + indent_p = re.compile(r'\n( +)') + + def deco(f): + indents = indent_p.findall(f.__doc__) + indent = ' ' * (min(len(m) for m in indents) if indents else 0) + f.__doc__ = f.__doc__.rstrip() + "\n\n%s.. versionadded:: %s" % (indent, version) + return f + return deco + + # for back compatibility from pyspark.sql import SQLContext, HiveContext, SchemaRDD, Row diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index 3b647985801b..95b3abc74244 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -350,6 +350,11 @@ def save_global(self, obj, name=None, pack=struct.pack): if new_override: d['__new__'] = obj.__new__ + # workaround for namedtuple (hijacked by PySpark) + if getattr(obj, '_is_namedtuple_', False): + self.save_reduce(_load_namedtuple, (obj.__name__, obj._fields)) + return + self.save(_load_class) self.save_reduce(typ, (obj.__name__, obj.__bases__, {"__doc__": obj.__doc__}), obj=obj) d.pop('__doc__', None) @@ -382,7 +387,7 @@ def save_instancemethod(self, obj): self.save_reduce(types.MethodType, (obj.__func__, obj.__self__), obj=obj) else: self.save_reduce(types.MethodType, (obj.__func__, obj.__self__, obj.__self__.__class__), - obj=obj) + obj=obj) dispatch[types.MethodType] = save_instancemethod def save_inst(self, obj): @@ -744,6 +749,14 @@ def _load_class(cls, d): return cls +def _load_namedtuple(name, fields): + """ + Loads a class generated by namedtuple + """ + from collections import namedtuple + return namedtuple(name, fields) + + """Constructors for 3rd party libraries Note: These can never be renamed due to client compatibility issues""" diff --git a/python/pyspark/context.py b/python/pyspark/context.py index eb5b0bbbdac4..529d16b48039 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -19,8 +19,10 @@ import os import shutil +import signal import sys -from threading import Lock +import threading +from threading import RLock from tempfile import NamedTemporaryFile from pyspark import accumulators @@ -64,7 +66,7 @@ class SparkContext(object): _jvm = None _next_accum_id = 0 _active_spark_context = None - _lock = Lock() + _lock = RLock() _python_includes = None # zip and egg files that need to be added to PYTHONPATH PACKAGE_EXTENSIONS = ('.zip', '.egg', '.jar') @@ -217,6 +219,15 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, else: self.profiler_collector = None + # create a signal handler which would be invoked on receiving SIGINT + def signal_handler(signal, frame): + self.cancelAllJobs() + raise KeyboardInterrupt() + + # see http://stackoverflow.com/questions/23206787/ + if 
isinstance(threading.current_thread(), threading._MainThread): + signal.signal(signal.SIGINT, signal_handler) + def _initialize_context(self, jconf): """ Initialize SparkContext in function to allow subclass specific initialization @@ -255,7 +266,7 @@ def __getnewargs__(self): # This method is called when attempting to pickle SparkContext, which is always an error: raise Exception( "It appears that you are attempting to reference SparkContext from a broadcast " - "variable, action, or transforamtion. SparkContext can only be used on the driver, " + "variable, action, or transformation. SparkContext can only be used on the driver, " "not in code that it run on workers. For more information, see SPARK-5063." ) @@ -273,6 +284,18 @@ def __exit__(self, type, value, trace): """ self.stop() + @classmethod + def getOrCreate(cls, conf=None): + """ + Get or instantiate a SparkContext and register it as a singleton object. + + :param conf: SparkConf (optional) + """ + with SparkContext._lock: + if SparkContext._active_spark_context is None: + SparkContext(conf=conf or SparkConf()) + return SparkContext._active_spark_context + def setLogLevel(self, logLevel): """ Control our logLevel. This overrides any user-defined log settings. @@ -302,10 +325,10 @@ def applicationId(self): """ A unique identifier for the Spark application. Its format depends on the scheduler implementation. - (i.e. - in case of local spark app something like 'local-1433865536131' - in case of YARN something like 'application_1433865536131_34483' - ) + + * in case of local spark app something like 'local-1433865536131' + * in case of YARN something like 'application_1433865536131_34483' + >>> sc.applicationId # doctest: +ELLIPSIS u'local-...' """ diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 60be85e53e2a..cd4c55f79f18 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -54,7 +54,6 @@ def launch_gateway(): if os.environ.get("SPARK_TESTING"): submit_args = ' '.join([ "--conf spark.ui.enabled=false", - "--conf spark.buffer.pageSize=4mb", submit_args ]) command = [os.path.join(SPARK_HOME, script)] + shlex.split(submit_args) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index b5814f76de00..5599b8f3ecd8 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -15,34 +15,41 @@ # limitations under the License. 
# +import warnings + +from pyspark import since from pyspark.ml.util import keyword_only from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * from pyspark.ml.regression import ( - RandomForestParams, DecisionTreeModel, TreeEnsembleModels) + RandomForestParams, TreeEnsembleParams, DecisionTreeModel, TreeEnsembleModels) from pyspark.mllib.common import inherit_doc __all__ = ['LogisticRegression', 'LogisticRegressionModel', 'DecisionTreeClassifier', 'DecisionTreeClassificationModel', 'GBTClassifier', 'GBTClassificationModel', 'RandomForestClassifier', 'RandomForestClassificationModel', 'NaiveBayes', - 'NaiveBayesModel'] + 'NaiveBayesModel', 'MultilayerPerceptronClassifier', + 'MultilayerPerceptronClassificationModel'] @inherit_doc class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, - HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol): + HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol, + HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds, + HasWeightCol): """ Logistic regression. + Currently, this class only supports binary classification. >>> from pyspark.sql import Row >>> from pyspark.mllib.linalg import Vectors >>> df = sc.parallelize([ - ... Row(label=1.0, features=Vectors.dense(1.0)), - ... Row(label=0.0, features=Vectors.sparse(1, [], []))]).toDF() - >>> lr = LogisticRegression(maxIter=5, regParam=0.01) + ... Row(label=1.0, weight=2.0, features=Vectors.dense(1.0)), + ... Row(label=0.0, weight=2.0, features=Vectors.sparse(1, [], []))]).toDF() + >>> lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight") >>> model = lr.fit(df) - >>> model.weights + >>> model.coefficients DenseVector([5.5...]) >>> model.intercept -2.68... @@ -61,114 +68,155 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti Traceback (most recent call last): ... TypeError: Method setParams forces keyword arguments. + + .. versionadded:: 1.3.0 """ # a placeholder to make it appear in the generated doc - elasticNetParam = \ - Param(Params._dummy(), "elasticNetParam", - "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + - "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") - fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.") threshold = Param(Params._dummy(), "threshold", - "threshold in binary classification prediction, in range [0, 1].") + "Threshold in binary classification prediction, in range [0, 1]." + + " If threshold and thresholds are both set, they must match.") @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - threshold=0.5, probabilityCol="probability", rawPredictionCol="rawPrediction"): + threshold=0.5, thresholds=None, probabilityCol="probability", + rawPredictionCol="rawPrediction", standardization=True, weightCol=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - threshold=0.5, probabilityCol="probability", rawPredictionCol="rawPrediction") + threshold=0.5, thresholds=None, probabilityCol="probability", \ + rawPredictionCol="rawPrediction", standardization=True, weightCol=None) + If the threshold and thresholds Params are both set, they must be equivalent. 
""" super(LogisticRegression, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.LogisticRegression", self.uid) - #: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty - # is an L2 penalty. For alpha = 1, it is an L1 penalty. - self.elasticNetParam = \ - Param(self, "elasticNetParam", - "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty " + - "is an L2 penalty. For alpha = 1, it is an L1 penalty.") - #: param for whether to fit an intercept term. - self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.") - #: param for threshold in binary classification prediction, in range [0, 1]. + #: param for threshold in binary classification, in range [0, 1]. self.threshold = Param(self, "threshold", - "threshold in binary classification prediction, in range [0, 1].") - self._setDefault(maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1E-6, - fitIntercept=True, threshold=0.5) + "Threshold in binary classification prediction, in range [0, 1]." + + " If threshold and thresholds are both set, they must match.") + self._setDefault(maxIter=100, regParam=0.1, tol=1E-6, threshold=0.5) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) + self._checkThresholdConsistency() @keyword_only + @since("1.3.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - threshold=0.5, probabilityCol="probability", rawPredictionCol="rawPrediction"): + threshold=0.5, thresholds=None, probabilityCol="probability", + rawPredictionCol="rawPrediction", standardization=True, weightCol=None): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - threshold=0.5, probabilityCol="probability", rawPredictionCol="rawPrediction") + threshold=0.5, thresholds=None, probabilityCol="probability", \ + rawPredictionCol="rawPrediction", standardization=True, weightCol=None) Sets params for logistic regression. + If the threshold and thresholds Params are both set, they must be equivalent. """ kwargs = self.setParams._input_kwargs - return self._set(**kwargs) + self._set(**kwargs) + self._checkThresholdConsistency() + return self def _create_model(self, java_model): return LogisticRegressionModel(java_model) - def setElasticNetParam(self, value): - """ - Sets the value of :py:attr:`elasticNetParam`. - """ - self._paramMap[self.elasticNetParam] = value - return self - - def getElasticNetParam(self): - """ - Gets the value of elasticNetParam or its default value. - """ - return self.getOrDefault(self.elasticNetParam) - - def setFitIntercept(self, value): - """ - Sets the value of :py:attr:`fitIntercept`. - """ - self._paramMap[self.fitIntercept] = value - return self - - def getFitIntercept(self): - """ - Gets the value of fitIntercept or its default value. - """ - return self.getOrDefault(self.fitIntercept) - + @since("1.4.0") def setThreshold(self, value): """ Sets the value of :py:attr:`threshold`. + Clears value of :py:attr:`thresholds` if it has been set. """ self._paramMap[self.threshold] = value + if self.isSet(self.thresholds): + del self._paramMap[self.thresholds] return self + @since("1.4.0") def getThreshold(self): """ Gets the value of threshold or its default value. 
""" - return self.getOrDefault(self.threshold) + self._checkThresholdConsistency() + if self.isSet(self.thresholds): + ts = self.getOrDefault(self.thresholds) + if len(ts) != 2: + raise ValueError("Logistic Regression getThreshold only applies to" + + " binary classification, but thresholds has length != 2." + + " thresholds: " + ",".join(ts)) + return 1.0/(1.0 + ts[0]/ts[1]) + else: + return self.getOrDefault(self.threshold) + + @since("1.5.0") + def setThresholds(self, value): + """ + Sets the value of :py:attr:`thresholds`. + Clears value of :py:attr:`threshold` if it has been set. + """ + self._paramMap[self.thresholds] = value + if self.isSet(self.threshold): + del self._paramMap[self.threshold] + return self + + @since("1.5.0") + def getThresholds(self): + """ + If :py:attr:`thresholds` is set, return its value. + Otherwise, if :py:attr:`threshold` is set, return the equivalent thresholds for binary + classification: (1-threshold, threshold). + If neither are set, throw an error. + """ + self._checkThresholdConsistency() + if not self.isSet(self.thresholds) and self.isSet(self.threshold): + t = self.getOrDefault(self.threshold) + return [1.0-t, t] + else: + return self.getOrDefault(self.thresholds) + + def _checkThresholdConsistency(self): + if self.isSet(self.threshold) and self.isSet(self.thresholds): + ts = self.getParam(self.thresholds) + if len(ts) != 2: + raise ValueError("Logistic Regression getThreshold only applies to" + + " binary classification, but thresholds has length != 2." + + " thresholds: " + ",".join(ts)) + t = 1.0/(1.0 + ts[0]/ts[1]) + t2 = self.getParam(self.threshold) + if abs(t2 - t) >= 1E-5: + raise ValueError("Logistic Regression getThreshold found inconsistent values for" + + " threshold (%g) and thresholds (equivalent to %g)" % (t2, t)) class LogisticRegressionModel(JavaModel): """ Model fitted by LogisticRegression. + + .. versionadded:: 1.3.0 """ @property + @since("1.4.0") def weights(self): """ Model weights. """ + + warnings.warn("weights is deprecated. Use coefficients instead.") return self._call_java("weights") @property + @since("1.6.0") + def coefficients(self): + """ + Model coefficients. + """ + return self._call_java("coefficients") + + @property + @since("1.4.0") def intercept(self): """ Model intercept. @@ -179,13 +227,45 @@ def intercept(self): class TreeClassifierParams(object): """ Private class to track supported impurity measures. + + .. versionadded:: 1.4.0 """ supportedImpurities = ["entropy", "gini"] + # a placeholder to make it appear in the generated doc + impurity = Param(Params._dummy(), "impurity", + "Criterion used for information gain calculation (case-insensitive). " + + "Supported options: " + + ", ".join(supportedImpurities)) + + def __init__(self): + super(TreeClassifierParams, self).__init__() + #: param for Criterion used for information gain calculation (case-insensitive). + self.impurity = Param(self, "impurity", "Criterion used for information " + + "gain calculation (case-insensitive). Supported options: " + + ", ".join(self.supportedImpurities)) + + @since("1.6.0") + def setImpurity(self, value): + """ + Sets the value of :py:attr:`impurity`. + """ + self._paramMap[self.impurity] = value + return self + + @since("1.6.0") + def getImpurity(self): + """ + Gets the value of impurity or its default value. + """ + return self.getOrDefault(self.impurity) + -class GBTParams(object): +class GBTParams(TreeEnsembleParams): """ Private class to track supported GBT params. + + .. 
versionadded:: 1.4.0 """ supportedLossTypes = ["logistic"] @@ -193,7 +273,7 @@ class GBTParams(object): @inherit_doc class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol, HasRawPredictionCol, DecisionTreeParams, - HasCheckpointInterval): + TreeClassifierParams, HasCheckpointInterval): """ `http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree` learning algorithm for classification. @@ -225,12 +305,9 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 - """ - # a placeholder to make it appear in the generated doc - impurity = Param(Params._dummy(), "impurity", - "Criterion used for information gain calculation (case-insensitive). " + - "Supported options: " + ", ".join(TreeClassifierParams.supportedImpurities)) + .. versionadded:: 1.4.0 + """ @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", @@ -246,11 +323,6 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred super(DecisionTreeClassifier, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.DecisionTreeClassifier", self.uid) - #: param for Criterion used for information gain calculation (case-insensitive). - self.impurity = \ - Param(self, "impurity", - "Criterion used for information gain calculation (case-insensitive). " + - "Supported options: " + ", ".join(TreeClassifierParams.supportedImpurities)) self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini") @@ -258,6 +330,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, @@ -276,36 +349,27 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return DecisionTreeClassificationModel(java_model) - def setImpurity(self, value): - """ - Sets the value of :py:attr:`impurity`. - """ - self._paramMap[self.impurity] = value - return self - - def getImpurity(self): - """ - Gets the value of impurity or its default value. - """ - return self.getOrDefault(self.impurity) - @inherit_doc class DecisionTreeClassificationModel(DecisionTreeModel): """ Model fitted by DecisionTreeClassifier. + + .. versionadded:: 1.4.0 """ @inherit_doc class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed, - DecisionTreeParams, HasCheckpointInterval): + HasRawPredictionCol, HasProbabilityCol, + RandomForestParams, TreeClassifierParams, HasCheckpointInterval): """ `http://en.wikipedia.org/wiki/Random_forest Random Forest` learning algorithm for classification. It supports both binary and multiclass labels, as well as both continuous and categorical features. 
+ >>> import numpy >>> from numpy import allclose >>> from pyspark.mllib.linalg import Vectors >>> from pyspark.ml.feature import StringIndexer @@ -320,33 +384,29 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred >>> allclose(model.treeWeights, [1.0, 1.0, 1.0]) True >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) - >>> model.transform(test0).head().prediction + >>> result = model.transform(test0).head() + >>> result.prediction 0.0 + >>> numpy.argmax(result.probability) + 0 + >>> numpy.argmax(result.rawPrediction) + 0 >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 - """ - # a placeholder to make it appear in the generated doc - impurity = Param(Params._dummy(), "impurity", - "Criterion used for information gain calculation (case-insensitive). " + - "Supported options: " + ", ".join(TreeClassifierParams.supportedImpurities)) - subsamplingRate = Param(Params._dummy(), "subsamplingRate", - "Fraction of the training data used for learning each decision tree, " + - "in range (0, 1].") - numTrees = Param(Params._dummy(), "numTrees", "Number of trees to train (>= 1)") - featureSubsetStrategy = \ - Param(Params._dummy(), "featureSubsetStrategy", - "The number of features to consider for splits at each tree node. Supported " + - "options: " + ", ".join(RandomForestParams.supportedFeatureSubsetStrategies)) + .. versionadded:: 1.4.0 + """ @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", + probabilityCol="probability", rawPredictionCol="rawPrediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", numTrees=20, featureSubsetStrategy="auto", seed=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + probabilityCol="probability", rawPredictionCol="rawPrediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \ numTrees=20, featureSubsetStrategy="auto", seed=None) @@ -354,23 +414,6 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred super(RandomForestClassifier, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.RandomForestClassifier", self.uid) - #: param for Criterion used for information gain calculation (case-insensitive). - self.impurity = \ - Param(self, "impurity", - "Criterion used for information gain calculation (case-insensitive). " + - "Supported options: " + ", ".join(TreeClassifierParams.supportedImpurities)) - #: param for Fraction of the training data used for learning each decision tree, - # in range (0, 1] - self.subsamplingRate = Param(self, "subsamplingRate", - "Fraction of the training data used for learning each " + - "decision tree, in range (0, 1].") - #: param for Number of trees to train (>= 1) - self.numTrees = Param(self, "numTrees", "Number of trees to train (>= 1)") - #: param for The number of features to consider for splits at each tree node - self.featureSubsetStrategy = \ - Param(self, "featureSubsetStrategy", - "The number of features to consider for splits at each tree node. 
Supported " + - "options: " + ", ".join(RandomForestParams.supportedFeatureSubsetStrategies)) self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, impurity="gini", numTrees=20, featureSubsetStrategy="auto") @@ -378,12 +421,15 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", + probabilityCol="probability", rawPredictionCol="rawPrediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, impurity="gini", numTrees=20, featureSubsetStrategy="auto"): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + probabilityCol="probability", rawPredictionCol="rawPrediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, \ impurity="gini", numTrees=20, featureSubsetStrategy="auto") @@ -395,68 +441,18 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return RandomForestClassificationModel(java_model) - def setImpurity(self, value): - """ - Sets the value of :py:attr:`impurity`. - """ - self._paramMap[self.impurity] = value - return self - - def getImpurity(self): - """ - Gets the value of impurity or its default value. - """ - return self.getOrDefault(self.impurity) - - def setSubsamplingRate(self, value): - """ - Sets the value of :py:attr:`subsamplingRate`. - """ - self._paramMap[self.subsamplingRate] = value - return self - - def getSubsamplingRate(self): - """ - Gets the value of subsamplingRate or its default value. - """ - return self.getOrDefault(self.subsamplingRate) - - def setNumTrees(self, value): - """ - Sets the value of :py:attr:`numTrees`. - """ - self._paramMap[self.numTrees] = value - return self - - def getNumTrees(self): - """ - Gets the value of numTrees or its default value. - """ - return self.getOrDefault(self.numTrees) - - def setFeatureSubsetStrategy(self, value): - """ - Sets the value of :py:attr:`featureSubsetStrategy`. - """ - self._paramMap[self.featureSubsetStrategy] = value - return self - - def getFeatureSubsetStrategy(self): - """ - Gets the value of featureSubsetStrategy or its default value. - """ - return self.getOrDefault(self.featureSubsetStrategy) - class RandomForestClassificationModel(TreeEnsembleModels): """ Model fitted by RandomForestClassifier. + + .. versionadded:: 1.4.0 """ @inherit_doc class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, - DecisionTreeParams, HasCheckpointInterval): + GBTParams, HasCheckpointInterval, HasStepSize, HasSeed): """ `http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)` learning algorithm for classification. @@ -482,18 +478,14 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc lossType = Param(Params._dummy(), "lossType", "Loss function which GBT tries to minimize (case-insensitive). 
" + "Supported options: " + ", ".join(GBTParams.supportedLossTypes)) - subsamplingRate = Param(Params._dummy(), "subsamplingRate", - "Fraction of the training data used for learning each decision tree, " + - "in range (0, 1].") - stepSize = Param(Params._dummy(), "stepSize", - "Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the " + - "contribution of each estimator") @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", @@ -513,15 +505,6 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.lossType = Param(self, "lossType", "Loss function which GBT tries to minimize (case-insensitive). " + "Supported options: " + ", ".join(GBTParams.supportedLossTypes)) - #: Fraction of the training data used for learning each decision tree, in range (0, 1]. - self.subsamplingRate = Param(self, "subsamplingRate", - "Fraction of the training data used for learning each " + - "decision tree, in range (0, 1].") - #: Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of - # each estimator - self.stepSize = Param(self, "stepSize", - "Step size (a.k.a. learning rate) in interval (0, 1] for shrinking " + - "the contribution of each estimator") self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic", maxIter=20, stepSize=0.1) @@ -529,6 +512,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, @@ -546,6 +530,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return GBTClassificationModel(java_model) + @since("1.4.0") def setLossType(self, value): """ Sets the value of :py:attr:`lossType`. @@ -553,42 +538,19 @@ def setLossType(self, value): self._paramMap[self.lossType] = value return self + @since("1.4.0") def getLossType(self): """ Gets the value of lossType or its default value. """ return self.getOrDefault(self.lossType) - def setSubsamplingRate(self, value): - """ - Sets the value of :py:attr:`subsamplingRate`. - """ - self._paramMap[self.subsamplingRate] = value - return self - - def getSubsamplingRate(self): - """ - Gets the value of subsamplingRate or its default value. - """ - return self.getOrDefault(self.subsamplingRate) - - def setStepSize(self, value): - """ - Sets the value of :py:attr:`stepSize`. - """ - self._paramMap[self.stepSize] = value - return self - - def getStepSize(self): - """ - Gets the value of stepSize or its default value. - """ - return self.getOrDefault(self.stepSize) - class GBTClassificationModel(TreeEnsembleModels): """ Model fitted by GBTClassifier. + + .. versionadded:: 1.4.0 """ @@ -597,6 +559,13 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H HasRawPredictionCol): """ Naive Bayes Classifiers. + It supports both Multinomial and Bernoulli NB. Multinomial NB + (`http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html`) + can handle finitely supported discrete data. For example, by converting documents into + TF-IDF vectors, it can be used for document classification. 
By making every vector a + binary (0/1) data, it can also be used as Bernoulli NB + (`http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html`). + The input feature values must be nonnegative. >>> from pyspark.sql import Row >>> from pyspark.mllib.linalg import Vectors @@ -621,6 +590,8 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF() >>> model.transform(test1).head().prediction 1.0 + + .. versionadded:: 1.5.0 """ # a placeholder to make it appear in the generated doc @@ -653,6 +624,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.setParams(**kwargs) @keyword_only + @since("1.5.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, modelType="multinomial"): @@ -668,6 +640,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return NaiveBayesModel(java_model) + @since("1.5.0") def setSmoothing(self, value): """ Sets the value of :py:attr:`smoothing`. @@ -675,12 +648,14 @@ def setSmoothing(self, value): self._paramMap[self.smoothing] = value return self + @since("1.5.0") def getSmoothing(self): """ Gets the value of smoothing or its default value. """ return self.getOrDefault(self.smoothing) + @since("1.5.0") def setModelType(self, value): """ Sets the value of :py:attr:`modelType`. @@ -688,6 +663,7 @@ def setModelType(self, value): self._paramMap[self.modelType] = value return self + @since("1.5.0") def getModelType(self): """ Gets the value of modelType or its default value. @@ -698,9 +674,12 @@ def getModelType(self): class NaiveBayesModel(JavaModel): """ Model fitted by NaiveBayes. + + .. versionadded:: 1.5.0 """ @property + @since("1.5.0") def pi(self): """ log of class priors. @@ -708,6 +687,7 @@ def pi(self): return self._call_java("pi") @property + @since("1.5.0") def theta(self): """ log of class conditional probabilities. @@ -715,6 +695,146 @@ def theta(self): return self._call_java("theta") +@inherit_doc +class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, + HasMaxIter, HasTol, HasSeed): + """ + Classifier trainer based on the Multilayer Perceptron. + Each layer has sigmoid activation function, output layer has softmax. + Number of inputs has to be equal to the size of feature vectors. + Number of outputs has to be equal to the total number of labels. + + >>> from pyspark.mllib.linalg import Vectors + >>> df = sqlContext.createDataFrame([ + ... (0.0, Vectors.dense([0.0, 0.0])), + ... (1.0, Vectors.dense([0.0, 1.0])), + ... (1.0, Vectors.dense([1.0, 0.0])), + ... (0.0, Vectors.dense([1.0, 1.0]))], ["label", "features"]) + >>> mlp = MultilayerPerceptronClassifier(maxIter=100, layers=[2, 5, 2], blockSize=1, seed=11) + >>> model = mlp.fit(df) + >>> model.layers + [2, 5, 2] + >>> model.weights.size + 27 + >>> testDF = sqlContext.createDataFrame([ + ... (Vectors.dense([1.0, 0.0]),), + ... (Vectors.dense([0.0, 0.0]),)], ["features"]) + >>> model.transform(testDF).show() + +---------+----------+ + | features|prediction| + +---------+----------+ + |[1.0,0.0]| 1.0| + |[0.0,0.0]| 0.0| + +---------+----------+ + ... + + .. 
versionadded:: 1.6.0 + """ + + # a placeholder to make it appear in the generated doc + layers = Param(Params._dummy(), "layers", "Sizes of layers from input layer to output layer " + + "E.g., Array(780, 100, 10) means 780 inputs, one hidden layer with 100 " + + "neurons and output layer of 10 neurons, default is [1, 1].") + blockSize = Param(Params._dummy(), "blockSize", "Block size for stacking input data in " + + "matrices. Data is stacked within partitions. If block size is more than " + + "remaining data in a partition then it is adjusted to the size of this " + + "data. Recommended size is between 10 and 1000, default is 128.") + + @keyword_only + def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxIter=100, tol=1e-4, seed=None, layers=None, blockSize=128): + """ + __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxIter=100, tol=1e-4, seed=None, layers=[1, 1], blockSize=128) + """ + super(MultilayerPerceptronClassifier, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.classification.MultilayerPerceptronClassifier", self.uid) + self.layers = Param(self, "layers", "Sizes of layers from input layer to output layer " + + "E.g., Array(780, 100, 10) means 780 inputs, one hidden layer with " + + "100 neurons and output layer of 10 neurons, default is [1, 1].") + self.blockSize = Param(self, "blockSize", "Block size for stacking input data in " + + "matrices. Data is stacked within partitions. If block size is " + + "more than remaining data in a partition then it is adjusted to " + + "the size of this data. Recommended size is between 10 and 1000, " + + "default is 128.") + self._setDefault(maxIter=100, tol=1E-4, layers=[1, 1], blockSize=128) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("1.6.0") + def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxIter=100, tol=1e-4, seed=None, layers=None, blockSize=128): + """ + setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxIter=100, tol=1e-4, seed=None, layers=[1, 1], blockSize=128) + Sets params for MultilayerPerceptronClassifier. + """ + kwargs = self.setParams._input_kwargs + if layers is None: + return self._set(**kwargs).setLayers([1, 1]) + else: + return self._set(**kwargs) + + def _create_model(self, java_model): + return MultilayerPerceptronClassificationModel(java_model) + + @since("1.6.0") + def setLayers(self, value): + """ + Sets the value of :py:attr:`layers`. + """ + self._paramMap[self.layers] = value + return self + + @since("1.6.0") + def getLayers(self): + """ + Gets the value of layers or its default value. + """ + return self.getOrDefault(self.layers) + + @since("1.6.0") + def setBlockSize(self, value): + """ + Sets the value of :py:attr:`blockSize`. + """ + self._paramMap[self.blockSize] = value + return self + + @since("1.6.0") + def getBlockSize(self): + """ + Gets the value of blockSize or its default value. + """ + return self.getOrDefault(self.blockSize) + + +class MultilayerPerceptronClassificationModel(JavaModel): + """ + Model fitted by MultilayerPerceptronClassifier. + + .. versionadded:: 1.6.0 + """ + + @property + @since("1.6.0") + def layers(self): + """ + array of layer sizes including input and output layers. 
+ """ + return self._call_java("javaLayers") + + @property + @since("1.6.0") + def weights(self): + """ + vector of initial weights for the model that consists of the weights of layers. + """ + return self._call_java("weights") + + if __name__ == "__main__": import doctest from pyspark.context import SparkContext diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index b5e9b6549d9f..7bb8ab94e17d 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -15,11 +15,11 @@ # limitations under the License. # +from pyspark import since from pyspark.ml.util import keyword_only from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * from pyspark.mllib.common import inherit_doc -from pyspark.mllib.linalg import _convert_to_vector __all__ = ['KMeans', 'KMeansModel'] @@ -27,23 +27,28 @@ class KMeansModel(JavaModel): """ Model fitted by KMeans. + + .. versionadded:: 1.5.0 """ + @since("1.5.0") def clusterCenters(self): """Get the cluster centers, represented as a list of NumPy arrays.""" return [c.toArray() for c in self._call_java("clusterCenters")] @inherit_doc -class KMeans(JavaEstimator, HasFeaturesCol, HasMaxIter, HasSeed): +class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed): """ - K-means Clustering + K-means clustering with support for multiple parallel runs and a k-means++ like initialization + mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested, + they are executed together with joint passes over the data for efficiency. >>> from pyspark.mllib.linalg import Vectors >>> data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),), ... (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)] >>> df = sqlContext.createDataFrame(data, ["features"]) - >>> kmeans = KMeans().setK(2).setSeed(1).setFeaturesCol("features") + >>> kmeans = KMeans(k=2, seed=1) >>> model = kmeans.fit(df) >>> centers = model.clusterCenters() >>> len(centers) @@ -54,14 +59,12 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasMaxIter, HasSeed): True >>> rows[2].prediction == rows[3].prediction True + + .. versionadded:: 1.5.0 """ # a placeholder to make it appear in the generated doc k = Param(Params._dummy(), "k", "number of clusters to create") - epsilon = Param(Params._dummy(), "epsilon", - "distance threshold within which " + - "we've consider centers to have converged") - runs = Param(Params._dummy(), "runs", "number of runs of the algorithm to execute in parallel") initMode = Param(Params._dummy(), "initMode", "the initialization algorithm. 
This can be either \"random\" to " + "choose random points as initial cluster centers, or \"k-means||\" " + @@ -69,21 +72,21 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasMaxIter, HasSeed): initSteps = Param(Params._dummy(), "initSteps", "steps for k-means initialization mode") @keyword_only - def __init__(self, k=2, maxIter=20, runs=1, epsilon=1e-4, initMode="k-means||", initStep=5): + def __init__(self, featuresCol="features", predictionCol="prediction", k=2, + initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20, seed=None): + """ + __init__(self, featuresCol="features", predictionCol="prediction", k=2, \ + initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20, seed=None) + """ super(KMeans, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.KMeans", self.uid) self.k = Param(self, "k", "number of clusters to create") - self.epsilon = Param(self, "epsilon", - "distance threshold within which " + - "we've consider centers to have converged") - self.runs = Param(self, "runs", "number of runs of the algorithm to execute in parallel") - self.seed = Param(self, "seed", "random seed") self.initMode = Param(self, "initMode", "the initialization algorithm. This can be either \"random\" to " + "choose random points as initial cluster centers, or \"k-means||\" " + "to use a parallel variant of k-means++") self.initSteps = Param(self, "initSteps", "steps for k-means initialization mode") - self._setDefault(k=2, maxIter=20, runs=1, epsilon=1e-4, initMode="k-means||", initSteps=5) + self._setDefault(k=2, initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -91,15 +94,19 @@ def _create_model(self, java_model): return KMeansModel(java_model) @keyword_only - def setParams(self, k=2, maxIter=20, runs=1, epsilon=1e-4, initMode="k-means||", initSteps=5): + @since("1.5.0") + def setParams(self, featuresCol="features", predictionCol="prediction", k=2, + initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20, seed=None): """ - setParams(self, k=2, maxIter=20, runs=1, epsilon=1e-4, initMode="k-means||", initSteps=5): + setParams(self, featuresCol="features", predictionCol="prediction", k=2, \ + initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20, seed=None) Sets params for KMeans. """ kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.5.0") def setK(self, value): """ Sets the value of :py:attr:`k`. @@ -111,46 +118,14 @@ def setK(self, value): self._paramMap[self.k] = value return self + @since("1.5.0") def getK(self): """ Gets the value of `k` """ return self.getOrDefault(self.k) - def setEpsilon(self, value): - """ - Sets the value of :py:attr:`epsilon`. - - >>> algo = KMeans().setEpsilon(1e-5) - >>> abs(algo.getEpsilon() - 1e-5) < 1e-5 - True - """ - self._paramMap[self.epsilon] = value - return self - - def getEpsilon(self): - """ - Gets the value of `epsilon` - """ - return self.getOrDefault(self.epsilon) - - def setRuns(self, value): - """ - Sets the value of :py:attr:`runs`. - - >>> algo = KMeans().setRuns(10) - >>> algo.getRuns() - 10 - """ - self._paramMap[self.runs] = value - return self - - def getRuns(self): - """ - Gets the value of `runs` - """ - return self.getOrDefault(self.runs) - + @since("1.5.0") def setInitMode(self, value): """ Sets the value of :py:attr:`initMode`. 
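For orientation, a minimal usage sketch of the reworked KMeans API introduced in this hunk (illustrative only, not part of the patch): it assumes a live `sqlContext`, and the data and column names follow the class doctest above.

# Illustrative sketch, not part of the patch. Assumes an active SparkContext/SQLContext.
from pyspark.ml.clustering import KMeans
from pyspark.mllib.linalg import Vectors

data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),),
        (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)]
df = sqlContext.createDataFrame(data, ["features"])

# runs/epsilon are gone; convergence is now governed by the shared `tol` param,
# and featuresCol/predictionCol/seed come from the shared param mixins.
kmeans = KMeans(k=2, seed=1, maxIter=20, tol=1e-4, initMode="k-means||", initSteps=5)
model = kmeans.fit(df)
centers = model.clusterCenters()   # list of NumPy arrays, one per cluster
clustered = model.transform(df)    # appends the "prediction" column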
@@ -165,12 +140,14 @@ def setInitMode(self, value): self._paramMap[self.initMode] = value return self + @since("1.5.0") def getInitMode(self): """ Gets the value of `initMode` """ return self.getOrDefault(self.initMode) + @since("1.5.0") def setInitSteps(self, value): """ Sets the value of :py:attr:`initSteps`. @@ -182,6 +159,7 @@ def setInitSteps(self, value): self._paramMap[self.initSteps] = value return self + @since("1.5.0") def getInitSteps(self): """ Gets the value of `initSteps` diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 06e809352225..dcc1738ec518 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -17,19 +17,23 @@ from abc import abstractmethod, ABCMeta +from pyspark import since from pyspark.ml.wrapper import JavaWrapper from pyspark.ml.param import Param, Params from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol from pyspark.ml.util import keyword_only from pyspark.mllib.common import inherit_doc -__all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator'] +__all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator', + 'MulticlassClassificationEvaluator'] @inherit_doc class Evaluator(Params): """ Base class for evaluators that compute metrics from predictions. + + .. versionadded:: 1.4.0 """ __metaclass__ = ABCMeta @@ -45,7 +49,8 @@ def _evaluate(self, dataset): """ raise NotImplementedError() - def evaluate(self, dataset, params={}): + @since("1.4.0") + def evaluate(self, dataset, params=None): """ Evaluates the output with optional parameters. @@ -55,6 +60,8 @@ def evaluate(self, dataset, params={}): params :return: metric """ + if params is None: + params = dict() if isinstance(params, dict): if params: return self.copy(params)._evaluate(dataset) @@ -63,6 +70,15 @@ def evaluate(self, dataset, params={}): else: raise ValueError("Params must be a param map but got %s." % type(params)) + @since("1.5.0") + def isLargerBetter(self): + """ + Indicates whether the metric returned by :py:meth:`evaluate` should be maximized + (True, default) or minimized (False). + A given evaluator may support multiple metrics which may be maximized or minimized. + """ + return True + @inherit_doc class JavaEvaluator(Evaluator, JavaWrapper): @@ -82,6 +98,10 @@ def _evaluate(self, dataset): self._transfer_params_to_java() return self._java_obj.evaluate(dataset._jdf) + def isLargerBetter(self): + self._transfer_params_to_java() + return self._java_obj.isLargerBetter() + @inherit_doc class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol): @@ -99,6 +119,8 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction 0.70... >>> evaluator.evaluate(dataset, {evaluator.metricName: "areaUnderPR"}) 0.83... + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -123,6 +145,7 @@ def __init__(self, rawPredictionCol="rawPrediction", labelCol="label", kwargs = self.__init__._input_kwargs self._set(**kwargs) + @since("1.4.0") def setMetricName(self, value): """ Sets the value of :py:attr:`metricName`. @@ -130,6 +153,7 @@ def setMetricName(self, value): self._paramMap[self.metricName] = value return self + @since("1.4.0") def getMetricName(self): """ Gets the value of metricName or its default value. 
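A hedged sketch of how the evaluator API shown above behaves once this patch is applied: evaluate() now takes an optional param map (defaulting to None) and isLargerBetter() reports whether the chosen metric should be maximized. The `dataset` DataFrame with "raw" and "label" columns is assumed, as in the BinaryClassificationEvaluator doctest; it is not part of the patch itself.

# Illustrative sketch, not part of the patch. `dataset` is assumed to carry
# "raw" (rawPrediction) and "label" columns, as in the doctest above.
from pyspark.ml.evaluation import BinaryClassificationEvaluator

evaluator = BinaryClassificationEvaluator(rawPredictionCol="raw")
auc = evaluator.evaluate(dataset)                       # default metric: areaUnderROC
aupr = evaluator.evaluate(dataset, {evaluator.metricName: "areaUnderPR"})
direction = "maximize" if evaluator.isLargerBetter() else "minimize"
print("%s (%s this metric): %f" % (evaluator.getMetricName(), direction, auc))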
@@ -137,6 +161,7 @@ def getMetricName(self): return self.getOrDefault(self.metricName) @keyword_only + @since("1.4.0") def setParams(self, rawPredictionCol="rawPrediction", labelCol="label", metricName="areaUnderROC"): """ @@ -160,11 +185,13 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol): ... >>> evaluator = RegressionEvaluator(predictionCol="raw") >>> evaluator.evaluate(dataset) - -2.842... + 2.842... >>> evaluator.evaluate(dataset, {evaluator.metricName: "r2"}) 0.993... >>> evaluator.evaluate(dataset, {evaluator.metricName: "mae"}) - -2.649... + 2.649... + + .. versionadded:: 1.4.0 """ # Because we will maximize evaluation value (ref: `CrossValidator`), # when we evaluate a metric that is needed to minimize (e.g., `"rmse"`, `"mse"`, `"mae"`), @@ -190,6 +217,7 @@ def __init__(self, predictionCol="prediction", labelCol="label", kwargs = self.__init__._input_kwargs self._set(**kwargs) + @since("1.4.0") def setMetricName(self, value): """ Sets the value of :py:attr:`metricName`. @@ -197,6 +225,7 @@ def setMetricName(self, value): self._paramMap[self.metricName] = value return self + @since("1.4.0") def getMetricName(self): """ Gets the value of metricName or its default value. @@ -204,6 +233,7 @@ def getMetricName(self): return self.getOrDefault(self.metricName) @keyword_only + @since("1.4.0") def setParams(self, predictionCol="prediction", labelCol="label", metricName="rmse"): """ @@ -231,6 +261,8 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio 0.66... >>> evaluator.evaluate(dataset, {evaluator.metricName: "recall"}) 0.66... + + .. versionadded:: 1.5.0 """ # a placeholder to make it appear in the generated doc metricName = Param(Params._dummy(), "metricName", @@ -256,6 +288,7 @@ def __init__(self, predictionCol="prediction", labelCol="label", kwargs = self.__init__._input_kwargs self._set(**kwargs) + @since("1.5.0") def setMetricName(self, value): """ Sets the value of :py:attr:`metricName`. @@ -263,6 +296,7 @@ def setMetricName(self, value): self._paramMap[self.metricName] = value return self + @since("1.5.0") def getMetricName(self): """ Gets the value of metricName or its default value. @@ -270,6 +304,7 @@ def getMetricName(self): return self.getOrDefault(self.metricName) @keyword_only + @since("1.5.0") def setParams(self, predictionCol="prediction", labelCol="label", metricName="f1"): """ diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 015e7a9d4900..b02d41b52ab2 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -15,21 +15,32 @@ # limitations under the License. 
# +import sys +if sys.version > '3': + basestring = str + +from pyspark import since from pyspark.rdd import ignore_unicode_prefix from pyspark.ml.param.shared import * from pyspark.ml.util import keyword_only -from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer, _jvm from pyspark.mllib.common import inherit_doc +from pyspark.mllib.linalg import _convert_to_vector -__all__ = ['Binarizer', 'HashingTF', 'IDF', 'IDFModel', 'NGram', 'Normalizer', 'OneHotEncoder', - 'PolynomialExpansion', 'RegexTokenizer', 'StandardScaler', 'StandardScalerModel', - 'StringIndexer', 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', - 'Word2Vec', 'Word2VecModel', 'PCA', 'PCAModel'] +__all__ = ['Binarizer', 'Bucketizer', 'CountVectorizer', 'CountVectorizerModel', 'DCT', + 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel', 'IndexToString', 'MinMaxScaler', + 'MinMaxScalerModel', 'NGram', 'Normalizer', 'OneHotEncoder', 'PCA', 'PCAModel', + 'PolynomialExpansion', 'RegexTokenizer', 'RFormula', 'RFormulaModel', 'SQLTransformer', + 'StandardScaler', 'StandardScalerModel', 'StopWordsRemover', 'StringIndexer', + 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'VectorSlicer', + 'Word2Vec', 'Word2VecModel'] @inherit_doc class Binarizer(JavaTransformer, HasInputCol, HasOutputCol): """ + .. note:: Experimental + Binarize a column of continuous features given a threshold. >>> df = sqlContext.createDataFrame([(0.5,)], ["values"]) @@ -41,6 +52,8 @@ class Binarizer(JavaTransformer, HasInputCol, HasOutputCol): >>> params = {binarizer.threshold: -0.5, binarizer.outputCol: "vector"} >>> binarizer.transform(df, params).head().vector 1.0 + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -61,6 +74,7 @@ def __init__(self, threshold=0.0, inputCol=None, outputCol=None): self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, threshold=0.0, inputCol=None, outputCol=None): """ setParams(self, threshold=0.0, inputCol=None, outputCol=None) @@ -69,6 +83,7 @@ def setParams(self, threshold=0.0, inputCol=None, outputCol=None): kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.4.0") def setThreshold(self, value): """ Sets the value of :py:attr:`threshold`. @@ -76,6 +91,7 @@ def setThreshold(self, value): self._paramMap[self.threshold] = value return self + @since("1.4.0") def getThreshold(self): """ Gets the value of threshold or its default value. @@ -86,6 +102,8 @@ def getThreshold(self): @inherit_doc class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol): """ + .. note:: Experimental + Maps a column of continuous features to a column of feature buckets. >>> df = sqlContext.createDataFrame([(0.1,), (0.4,), (1.2,), (1.5,)], ["values"]) @@ -102,6 +120,8 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol): 2.0 >>> bucketizer.setParams(outputCol="b").transform(df).head().b 0.0 + + .. 
versionadded:: 1.3.0 """ # a placeholder to make it appear in the generated doc @@ -138,6 +158,7 @@ def __init__(self, splits=None, inputCol=None, outputCol=None): self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, splits=None, inputCol=None, outputCol=None): """ setParams(self, splits=None, inputCol=None, outputCol=None) @@ -146,6 +167,7 @@ def setParams(self, splits=None, inputCol=None, outputCol=None): kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.4.0") def setSplits(self, value): """ Sets the value of :py:attr:`splits`. @@ -153,6 +175,7 @@ def setSplits(self, value): self._paramMap[self.splits] = value return self + @since("1.4.0") def getSplits(self): """ Gets the value of threshold or its default value. @@ -160,9 +183,293 @@ def getSplits(self): return self.getOrDefault(self.splits) +@inherit_doc +class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol): + """ + .. note:: Experimental + + Extracts a vocabulary from document collections and generates a :py:attr:`CountVectorizerModel`. + + >>> df = sqlContext.createDataFrame( + ... [(0, ["a", "b", "c"]), (1, ["a", "b", "b", "c", "a"])], + ... ["label", "raw"]) + >>> cv = CountVectorizer(inputCol="raw", outputCol="vectors") + >>> model = cv.fit(df) + >>> model.transform(df).show(truncate=False) + +-----+---------------+-------------------------+ + |label|raw |vectors | + +-----+---------------+-------------------------+ + |0 |[a, b, c] |(3,[0,1,2],[1.0,1.0,1.0])| + |1 |[a, b, b, c, a]|(3,[0,1,2],[2.0,2.0,1.0])| + +-----+---------------+-------------------------+ + ... + >>> sorted(map(str, model.vocabulary)) + ['a', 'b', 'c'] + + .. versionadded:: 1.6.0 + """ + + # a placeholder to make it appear in the generated doc + minTF = Param( + Params._dummy(), "minTF", "Filter to ignore rare words in" + + " a document. For each document, terms with frequency/count less than the given" + + " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" + + " times the term must appear in the document); if this is a double in [0,1), then this " + + "specifies a fraction (out of the document's token count). Note that the parameter is " + + "only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0") + minDF = Param( + Params._dummy(), "minDF", "Specifies the minimum number of" + + " different documents a term must appear in to be included in the vocabulary." + + " If this is an integer >= 1, this specifies the number of documents the term must" + + " appear in; if this is a double in [0,1), then this specifies the fraction of documents." + + " Default 1.0") + vocabSize = Param( + Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.") + + @keyword_only + def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None): + """ + __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None) + """ + super(CountVectorizer, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer", + self.uid) + self.minTF = Param( + self, "minTF", "Filter to ignore rare words in" + + " a document. For each document, terms with frequency/count less than the given" + + " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" + + " times the term must appear in the document); if this is a double in [0,1), then " + + "this specifies a fraction (out of the document's token count). 
Note that the " + + "parameter is only used in transform of CountVectorizerModel and does not affect" + + "fitting. Default 1.0") + self.minDF = Param( + self, "minDF", "Specifies the minimum number of" + + " different documents a term must appear in to be included in the vocabulary." + + " If this is an integer >= 1, this specifies the number of documents the term must" + + " appear in; if this is a double in [0,1), then this specifies the fraction of " + + "documents. Default 1.0") + self.vocabSize = Param( + self, "vocabSize", "max size of the vocabulary. Default 1 << 18.") + self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("1.6.0") + def setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None): + """ + setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None) + Set the params for the CountVectorizer + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + @since("1.6.0") + def setMinTF(self, value): + """ + Sets the value of :py:attr:`minTF`. + """ + self._paramMap[self.minTF] = value + return self + + @since("1.6.0") + def getMinTF(self): + """ + Gets the value of minTF or its default value. + """ + return self.getOrDefault(self.minTF) + + @since("1.6.0") + def setMinDF(self, value): + """ + Sets the value of :py:attr:`minDF`. + """ + self._paramMap[self.minDF] = value + return self + + @since("1.6.0") + def getMinDF(self): + """ + Gets the value of minDF or its default value. + """ + return self.getOrDefault(self.minDF) + + @since("1.6.0") + def setVocabSize(self, value): + """ + Sets the value of :py:attr:`vocabSize`. + """ + self._paramMap[self.vocabSize] = value + return self + + @since("1.6.0") + def getVocabSize(self): + """ + Gets the value of vocabSize or its default value. + """ + return self.getOrDefault(self.vocabSize) + + def _create_model(self, java_model): + return CountVectorizerModel(java_model) + + +class CountVectorizerModel(JavaModel): + """ + .. note:: Experimental + + Model fitted by CountVectorizer. + + .. versionadded:: 1.6.0 + """ + + @property + @since("1.6.0") + def vocabulary(self): + """ + An array of terms in the vocabulary. + """ + return self._call_java("vocabulary") + + +@inherit_doc +class DCT(JavaTransformer, HasInputCol, HasOutputCol): + """ + .. note:: Experimental + + A feature transformer that takes the 1D discrete cosine transform + of a real vector. No zero padding is performed on the input vector. + It returns a real vector of the same length representing the DCT. + The return vector is scaled such that the transform matrix is + unitary (aka scaled DCT-II). + + More information on + `https://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II Wikipedia`. + + >>> from pyspark.mllib.linalg import Vectors + >>> df1 = sqlContext.createDataFrame([(Vectors.dense([5.0, 8.0, 6.0]),)], ["vec"]) + >>> dct = DCT(inverse=False, inputCol="vec", outputCol="resultVec") + >>> df2 = dct.transform(df1) + >>> df2.head().resultVec + DenseVector([10.969..., -0.707..., -2.041...]) + >>> df3 = DCT(inverse=True, inputCol="resultVec", outputCol="origVec").transform(df2) + >>> df3.head().origVec + DenseVector([5.0, 8.0, 6.0]) + + .. 
versionadded:: 1.6.0 + """ + + # a placeholder to make it appear in the generated doc + inverse = Param(Params._dummy(), "inverse", "Set transformer to perform inverse DCT, " + + "default False.") + + @keyword_only + def __init__(self, inverse=False, inputCol=None, outputCol=None): + """ + __init__(self, inverse=False, inputCol=None, outputCol=None) + """ + super(DCT, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.DCT", self.uid) + self.inverse = Param(self, "inverse", "Set transformer to perform inverse DCT, " + + "default False.") + self._setDefault(inverse=False) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("1.6.0") + def setParams(self, inverse=False, inputCol=None, outputCol=None): + """ + setParams(self, inverse=False, inputCol=None, outputCol=None) + Sets params for this DCT. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + @since("1.6.0") + def setInverse(self, value): + """ + Sets the value of :py:attr:`inverse`. + """ + self._paramMap[self.inverse] = value + return self + + @since("1.6.0") + def getInverse(self): + """ + Gets the value of inverse or its default value. + """ + return self.getOrDefault(self.inverse) + + +@inherit_doc +class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol): + """ + .. note:: Experimental + + Outputs the Hadamard product (i.e., the element-wise product) of each input vector + with a provided "weight" vector. In other words, it scales each column of the dataset + by a scalar multiplier. + + >>> from pyspark.mllib.linalg import Vectors + >>> df = sqlContext.createDataFrame([(Vectors.dense([2.0, 1.0, 3.0]),)], ["values"]) + >>> ep = ElementwiseProduct(scalingVec=Vectors.dense([1.0, 2.0, 3.0]), + ... inputCol="values", outputCol="eprod") + >>> ep.transform(df).head().eprod + DenseVector([2.0, 2.0, 9.0]) + >>> ep.setParams(scalingVec=Vectors.dense([2.0, 3.0, 5.0])).transform(df).head().eprod + DenseVector([4.0, 3.0, 15.0]) + + .. versionadded:: 1.5.0 + """ + + # a placeholder to make it appear in the generated doc + scalingVec = Param(Params._dummy(), "scalingVec", "vector for hadamard product, " + + "it must be MLlib Vector type.") + + @keyword_only + def __init__(self, scalingVec=None, inputCol=None, outputCol=None): + """ + __init__(self, scalingVec=None, inputCol=None, outputCol=None) + """ + super(ElementwiseProduct, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ElementwiseProduct", + self.uid) + self.scalingVec = Param(self, "scalingVec", "vector for hadamard product, " + + "it must be MLlib Vector type.") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("1.5.0") + def setParams(self, scalingVec=None, inputCol=None, outputCol=None): + """ + setParams(self, scalingVec=None, inputCol=None, outputCol=None) + Sets params for this ElementwiseProduct. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + @since("1.5.0") + def setScalingVec(self, value): + """ + Sets the value of :py:attr:`scalingVec`. + """ + self._paramMap[self.scalingVec] = value + return self + + @since("1.5.0") + def getScalingVec(self): + """ + Gets the value of scalingVec or its default value. + """ + return self.getOrDefault(self.scalingVec) + + @inherit_doc class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures): """ + .. note:: Experimental + Maps a sequence of terms to their term frequencies using the hashing trick. 
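To make the newly added transformers concrete, a minimal sketch (illustrative only, not part of the patch; assumes a live `sqlContext`) wiring up CountVectorizer, DCT and ElementwiseProduct on toy data matching their doctests above.

# Illustrative sketch, not part of the patch. Assumes an active SQLContext as `sqlContext`.
from pyspark.ml.feature import CountVectorizer, DCT, ElementwiseProduct
from pyspark.mllib.linalg import Vectors

docs = sqlContext.createDataFrame(
    [(0, ["a", "b", "c"]), (1, ["a", "b", "b", "c", "a"])], ["label", "raw"])
cvModel = CountVectorizer(inputCol="raw", outputCol="tf", minDF=1.0).fit(docs)
counted = cvModel.transform(docs)          # sparse term-count vectors in "tf"
vocab = cvModel.vocabulary                 # learned vocabulary, e.g. ['a', 'b', 'c']

vecs = sqlContext.createDataFrame([(Vectors.dense([5.0, 8.0, 6.0]),)], ["vec"])
dctOut = DCT(inverse=False, inputCol="vec", outputCol="dctVec").transform(vecs)

scaled = ElementwiseProduct(scalingVec=Vectors.dense([1.0, 2.0, 3.0]),
                            inputCol="vec", outputCol="eprod").transform(vecs)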
@@ -175,6 +482,8 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures): >>> params = {hashingTF.numFeatures: 5, hashingTF.outputCol: "vector"} >>> hashingTF.transform(df, params).head().vector SparseVector(5, {2: 1.0, 3: 1.0, 4: 1.0}) + + .. versionadded:: 1.3.0 """ @keyword_only @@ -189,6 +498,7 @@ def __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None): self.setParams(**kwargs) @keyword_only + @since("1.3.0") def setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None): """ setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None) @@ -201,6 +511,8 @@ def setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None): @inherit_doc class IDF(JavaEstimator, HasInputCol, HasOutputCol): """ + .. note:: Experimental + Compute the Inverse Document Frequency (IDF) given a collection of documents. >>> from pyspark.mllib.linalg import DenseVector @@ -214,6 +526,8 @@ class IDF(JavaEstimator, HasInputCol, HasOutputCol): >>> params = {idf.minDocFreq: 1, idf.outputCol: "vector"} >>> idf.fit(df, params).transform(df).head().vector DenseVector([0.2877, 0.0]) + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -234,6 +548,7 @@ def __init__(self, minDocFreq=0, inputCol=None, outputCol=None): self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, minDocFreq=0, inputCol=None, outputCol=None): """ setParams(self, minDocFreq=0, inputCol=None, outputCol=None) @@ -242,6 +557,7 @@ def setParams(self, minDocFreq=0, inputCol=None, outputCol=None): kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.4.0") def setMinDocFreq(self, value): """ Sets the value of :py:attr:`minDocFreq`. @@ -249,6 +565,7 @@ def setMinDocFreq(self, value): self._paramMap[self.minDocFreq] = value return self + @since("1.4.0") def getMinDocFreq(self): """ Gets the value of minDocFreq or its default value. @@ -261,7 +578,114 @@ def _create_model(self, java_model): class IDFModel(JavaModel): """ + .. note:: Experimental + Model fitted by IDF. + + .. versionadded:: 1.4.0 + """ + + +@inherit_doc +class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol): + """ + .. note:: Experimental + + Rescale each feature individually to a common range [min, max] linearly using column summary + statistics, which is also known as min-max normalization or Rescaling. The rescaled value for + feature E is calculated as, + + Rescaled(e_i) = (e_i - E_min) / (E_max - E_min) * (max - min) + min + + For the case E_max == E_min, Rescaled(e_i) = 0.5 * (max + min) + + Note that since zero values will probably be transformed to non-zero values, output of the + transformer will be DenseVector even for sparse input. + + >>> from pyspark.mllib.linalg import Vectors + >>> df = sqlContext.createDataFrame([(Vectors.dense([0.0]),), (Vectors.dense([2.0]),)], ["a"]) + >>> mmScaler = MinMaxScaler(inputCol="a", outputCol="scaled") + >>> model = mmScaler.fit(df) + >>> model.transform(df).show() + +-----+------+ + | a|scaled| + +-----+------+ + |[0.0]| [0.0]| + |[2.0]| [1.0]| + +-----+------+ + ... + + .. 
versionadded:: 1.6.0 + """ + + # a placeholder to make it appear in the generated doc + min = Param(Params._dummy(), "min", "Lower bound of the output feature range") + max = Param(Params._dummy(), "max", "Upper bound of the output feature range") + + @keyword_only + def __init__(self, min=0.0, max=1.0, inputCol=None, outputCol=None): + """ + __init__(self, min=0.0, max=1.0, inputCol=None, outputCol=None) + """ + super(MinMaxScaler, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.MinMaxScaler", self.uid) + self.min = Param(self, "min", "Lower bound of the output feature range") + self.max = Param(self, "max", "Upper bound of the output feature range") + self._setDefault(min=0.0, max=1.0) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("1.6.0") + def setParams(self, min=0.0, max=1.0, inputCol=None, outputCol=None): + """ + setParams(self, min=0.0, max=1.0, inputCol=None, outputCol=None) + Sets params for this MinMaxScaler. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + @since("1.6.0") + def setMin(self, value): + """ + Sets the value of :py:attr:`min`. + """ + self._paramMap[self.min] = value + return self + + @since("1.6.0") + def getMin(self): + """ + Gets the value of min or its default value. + """ + return self.getOrDefault(self.min) + + @since("1.6.0") + def setMax(self, value): + """ + Sets the value of :py:attr:`max`. + """ + self._paramMap[self.max] = value + return self + + @since("1.6.0") + def getMax(self): + """ + Gets the value of max or its default value. + """ + return self.getOrDefault(self.max) + + def _create_model(self, java_model): + return MinMaxScalerModel(java_model) + + +class MinMaxScalerModel(JavaModel): + """ + .. note:: Experimental + + Model fitted by :py:class:`MinMaxScaler`. + + .. versionadded:: 1.6.0 """ @@ -269,6 +693,8 @@ class IDFModel(JavaModel): @ignore_unicode_prefix class NGram(JavaTransformer, HasInputCol, HasOutputCol): """ + .. note:: Experimental + A feature transformer that converts the input array of strings into an array of n-grams. Null values in the input array are ignored. It returns an array of n-grams where each n-gram is represented by a space-separated string of @@ -294,6 +720,8 @@ class NGram(JavaTransformer, HasInputCol, HasOutputCol): Traceback (most recent call last): ... TypeError: Method setParams forces keyword arguments. + + .. versionadded:: 1.5.0 """ # a placeholder to make it appear in the generated doc @@ -312,6 +740,7 @@ def __init__(self, n=2, inputCol=None, outputCol=None): self.setParams(**kwargs) @keyword_only + @since("1.5.0") def setParams(self, n=2, inputCol=None, outputCol=None): """ setParams(self, n=2, inputCol=None, outputCol=None) @@ -320,6 +749,7 @@ def setParams(self, n=2, inputCol=None, outputCol=None): kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.5.0") def setN(self, value): """ Sets the value of :py:attr:`n`. @@ -327,6 +757,7 @@ def setN(self, value): self._paramMap[self.n] = value return self + @since("1.5.0") def getN(self): """ Gets the value of n or its default value. @@ -337,6 +768,8 @@ def getN(self): @inherit_doc class Normalizer(JavaTransformer, HasInputCol, HasOutputCol): """ + .. note:: Experimental + Normalize a vector to have unit norm using the given p-norm. 
>>> from pyspark.mllib.linalg import Vectors @@ -350,6 +783,8 @@ class Normalizer(JavaTransformer, HasInputCol, HasOutputCol): >>> params = {normalizer.p: 1.0, normalizer.inputCol: "dense", normalizer.outputCol: "vector"} >>> normalizer.transform(df, params).head().vector DenseVector([0.4286, -0.5714]) + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -368,6 +803,7 @@ def __init__(self, p=2.0, inputCol=None, outputCol=None): self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, p=2.0, inputCol=None, outputCol=None): """ setParams(self, p=2.0, inputCol=None, outputCol=None) @@ -376,6 +812,7 @@ def setParams(self, p=2.0, inputCol=None, outputCol=None): kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.4.0") def setP(self, value): """ Sets the value of :py:attr:`p`. @@ -383,6 +820,7 @@ def setP(self, value): self._paramMap[self.p] = value return self + @since("1.4.0") def getP(self): """ Gets the value of p or its default value. @@ -393,6 +831,8 @@ def getP(self): @inherit_doc class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol): """ + .. note:: Experimental + A one-hot encoder that maps a column of category indices to a column of binary vectors, with at most a single one-value per row that indicates the input category index. @@ -422,6 +862,8 @@ class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol): >>> params = {encoder.dropLast: False, encoder.outputCol: "test"} >>> encoder.transform(td, params).head().test SparseVector(3, {0: 1.0}) + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -440,6 +882,7 @@ def __init__(self, dropLast=True, inputCol=None, outputCol=None): self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, dropLast=True, inputCol=None, outputCol=None): """ setParams(self, dropLast=True, inputCol=None, outputCol=None) @@ -448,6 +891,7 @@ def setParams(self, dropLast=True, inputCol=None, outputCol=None): kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.4.0") def setDropLast(self, value): """ Sets the value of :py:attr:`dropLast`. @@ -455,6 +899,7 @@ def setDropLast(self, value): self._paramMap[self.dropLast] = value return self + @since("1.4.0") def getDropLast(self): """ Gets the value of dropLast or its default value. @@ -465,6 +910,8 @@ def getDropLast(self): @inherit_doc class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol): """ + .. note:: Experimental + Perform feature expansion in a polynomial space. As said in wikipedia of Polynomial Expansion, which is available at `http://en.wikipedia.org/wiki/Polynomial_expansion`, "In mathematics, an expansion of a product of sums expresses it as a sum of products by using the fact that @@ -478,6 +925,8 @@ class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol): DenseVector([0.5, 0.25, 2.0, 1.0, 4.0]) >>> px.setParams(outputCol="test").transform(df).head().test DenseVector([0.5, 0.25, 2.0, 1.0, 4.0]) + + .. 
versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -497,6 +946,7 @@ def __init__(self, degree=2, inputCol=None, outputCol=None): self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, degree=2, inputCol=None, outputCol=None): """ setParams(self, degree=2, inputCol=None, outputCol=None) @@ -505,6 +955,7 @@ def setParams(self, degree=2, inputCol=None, outputCol=None): kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.4.0") def setDegree(self, value): """ Sets the value of :py:attr:`degree`. @@ -512,6 +963,7 @@ def setDegree(self, value): self._paramMap[self.degree] = value return self + @since("1.4.0") def getDegree(self): """ Gets the value of degree or its default value. @@ -523,6 +975,8 @@ def getDegree(self): @ignore_unicode_prefix class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol): """ + .. note:: Experimental + A regex based tokenizer that extracts tokens either by using the provided regex pattern (in Java dialect) to split the text (default) or repeatedly matching the regex (if gaps is false). @@ -547,6 +1001,8 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol): Traceback (most recent call last): ... TypeError: Method setParams forces keyword arguments. + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -569,6 +1025,7 @@ def __init__(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, o self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, outputCol=None): """ setParams(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, outputCol=None) @@ -577,6 +1034,7 @@ def setParams(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.4.0") def setMinTokenLength(self, value): """ Sets the value of :py:attr:`minTokenLength`. @@ -584,12 +1042,14 @@ def setMinTokenLength(self, value): self._paramMap[self.minTokenLength] = value return self + @since("1.4.0") def getMinTokenLength(self): """ Gets the value of minTokenLength or its default value. """ return self.getOrDefault(self.minTokenLength) + @since("1.4.0") def setGaps(self, value): """ Sets the value of :py:attr:`gaps`. @@ -597,12 +1057,14 @@ def setGaps(self, value): self._paramMap[self.gaps] = value return self + @since("1.4.0") def getGaps(self): """ Gets the value of gaps or its default value. """ return self.getOrDefault(self.gaps) + @since("1.4.0") def setPattern(self, value): """ Sets the value of :py:attr:`pattern`. @@ -610,6 +1072,7 @@ def setPattern(self, value): self._paramMap[self.pattern] = value return self + @since("1.4.0") def getPattern(self): """ Gets the value of pattern or its default value. @@ -617,9 +1080,69 @@ def getPattern(self): return self.getOrDefault(self.pattern) +@inherit_doc +class SQLTransformer(JavaTransformer): + """ + .. note:: Experimental + + Implements the transforms which are defined by SQL statement. + Currently we only support SQL syntax like 'SELECT ... FROM __THIS__' + where '__THIS__' represents the underlying table of the input dataset. + + >>> df = sqlContext.createDataFrame([(0, 1.0, 3.0), (2, 2.0, 5.0)], ["id", "v1", "v2"]) + >>> sqlTrans = SQLTransformer( + ... statement="SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__") + >>> sqlTrans.transform(df).head() + Row(id=0, v1=1.0, v2=3.0, v3=4.0, v4=3.0) + + .. 
versionadded:: 1.6.0 + """ + + # a placeholder to make it appear in the generated doc + statement = Param(Params._dummy(), "statement", "SQL statement") + + @keyword_only + def __init__(self, statement=None): + """ + __init__(self, statement=None) + """ + super(SQLTransformer, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.SQLTransformer", self.uid) + self.statement = Param(self, "statement", "SQL statement") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("1.6.0") + def setParams(self, statement=None): + """ + setParams(self, statement=None) + Sets params for this SQLTransformer. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + @since("1.6.0") + def setStatement(self, value): + """ + Sets the value of :py:attr:`statement`. + """ + self._paramMap[self.statement] = value + return self + + @since("1.6.0") + def getStatement(self): + """ + Gets the value of statement or its default value. + """ + return self.getOrDefault(self.statement) + + @inherit_doc class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol): """ + .. note:: Experimental + Standardizes features by removing the mean and scaling to unit variance using column summary statistics on the samples in the training set. @@ -633,6 +1156,8 @@ class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol): DenseVector([1.4142]) >>> model.transform(df).collect()[1].scaled DenseVector([1.4142]) + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -653,6 +1178,7 @@ def __init__(self, withMean=False, withStd=True, inputCol=None, outputCol=None): self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, withMean=False, withStd=True, inputCol=None, outputCol=None): """ setParams(self, withMean=False, withStd=True, inputCol=None, outputCol=None) @@ -661,6 +1187,7 @@ def setParams(self, withMean=False, withStd=True, inputCol=None, outputCol=None) kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.4.0") def setWithMean(self, value): """ Sets the value of :py:attr:`withMean`. @@ -668,12 +1195,14 @@ def setWithMean(self, value): self._paramMap[self.withMean] = value return self + @since("1.4.0") def getWithMean(self): """ Gets the value of withMean or its default value. """ return self.getOrDefault(self.withMean) + @since("1.4.0") def setWithStd(self, value): """ Sets the value of :py:attr:`withStd`. @@ -681,6 +1210,7 @@ def setWithStd(self, value): self._paramMap[self.withStd] = value return self + @since("1.4.0") def getWithStd(self): """ Gets the value of withStd or its default value. @@ -693,10 +1223,15 @@ def _create_model(self, java_model): class StandardScalerModel(JavaModel): """ + .. note:: Experimental + Model fitted by StandardScaler. + + .. versionadded:: 1.4.0 """ @property + @since("1.5.0") def std(self): """ Standard deviation of the StandardScalerModel. @@ -704,6 +1239,7 @@ def std(self): return self._call_java("std") @property + @since("1.5.0") def mean(self): """ Mean of the StandardScalerModel. @@ -712,8 +1248,10 @@ def mean(self): @inherit_doc -class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol): +class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid): """ + .. note:: Experimental + A label indexer that maps a string column of labels to an ML column of label indices. If the input column is numeric, we cast it to string and index the string values. 
The indices are in [0, numLabels), ordered by label frequencies. @@ -725,22 +1263,31 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol): >>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]), ... key=lambda x: x[0]) [(0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)] + >>> inverter = IndexToString(inputCol="indexed", outputCol="label2", labels=model.labels()) + >>> itd = inverter.transform(td) + >>> sorted(set([(i[0], str(i[1])) for i in itd.select(itd.id, itd.label2).collect()]), + ... key=lambda x: x[0]) + [(0, 'a'), (1, 'b'), (2, 'c'), (3, 'a'), (4, 'a'), (5, 'c')] + + .. versionadded:: 1.4.0 """ @keyword_only - def __init__(self, inputCol=None, outputCol=None): + def __init__(self, inputCol=None, outputCol=None, handleInvalid="error"): """ - __init__(self, inputCol=None, outputCol=None) + __init__(self, inputCol=None, outputCol=None, handleInvalid="error") """ super(StringIndexer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StringIndexer", self.uid) + self._setDefault(handleInvalid="error") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only - def setParams(self, inputCol=None, outputCol=None): + @since("1.4.0") + def setParams(self, inputCol=None, outputCol=None, handleInvalid="error"): """ - setParams(self, inputCol=None, outputCol=None) + setParams(self, inputCol=None, outputCol=None, handleInvalid="error") Sets params for this StringIndexer. """ kwargs = self.setParams._input_kwargs @@ -752,14 +1299,162 @@ def _create_model(self, java_model): class StringIndexerModel(JavaModel): """ + .. note:: Experimental + Model fitted by StringIndexer. + + .. versionadded:: 1.4.0 """ + @property + @since("1.5.0") + def labels(self): + """ + Ordered list of labels, corresponding to indices to be assigned. + """ + return self._java_obj.labels + + +@inherit_doc +class IndexToString(JavaTransformer, HasInputCol, HasOutputCol): + """ + .. note:: Experimental + + A :py:class:`Transformer` that maps a column of indices back to a new column of + corresponding string values. + The index-string mapping is either from the ML attributes of the input column, + or from user-supplied labels (which take precedence over ML attributes). + See L{StringIndexer} for converting strings into indices. + + .. versionadded:: 1.6.0 + """ + + # a placeholder to make the labels show up in generated doc + labels = Param(Params._dummy(), "labels", + "Optional array of labels specifying index-string mapping." + + " If not provided or if empty, then metadata from inputCol is used instead.") + + @keyword_only + def __init__(self, inputCol=None, outputCol=None, labels=None): + """ + __init__(self, inputCol=None, outputCol=None, labels=None) + """ + super(IndexToString, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IndexToString", + self.uid) + self.labels = Param(self, "labels", + "Optional array of labels specifying index-string mapping. If not" + + " provided or if empty, then metadata from inputCol is used instead.") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("1.6.0") + def setParams(self, inputCol=None, outputCol=None, labels=None): + """ + setParams(self, inputCol=None, outputCol=None, labels=None) + Sets params for this IndexToString. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + @since("1.6.0") + def setLabels(self, value): + """ + Sets the value of :py:attr:`labels`. 
+ """ + self._paramMap[self.labels] = value + return self + + @since("1.6.0") + def getLabels(self): + """ + Gets the value of :py:attr:`labels` or its default value. + """ + return self.getOrDefault(self.labels) + + +class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol): + """ + .. note:: Experimental + + A feature transformer that filters out stop words from input. + Note: null values from input array are preserved unless adding null to stopWords explicitly. + + .. versionadded:: 1.6.0 + """ + # a placeholder to make the stopwords show up in generated doc + stopWords = Param(Params._dummy(), "stopWords", "The words to be filtered out") + caseSensitive = Param(Params._dummy(), "caseSensitive", "whether to do a case sensitive " + + "comparison over the stop words") + + @keyword_only + def __init__(self, inputCol=None, outputCol=None, stopWords=None, + caseSensitive=False): + """ + __init__(self, inputCol=None, outputCol=None, stopWords=None,\ + caseSensitive=false) + """ + super(StopWordsRemover, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover", + self.uid) + self.stopWords = Param(self, "stopWords", "The words to be filtered out") + self.caseSensitive = Param(self, "caseSensitive", "whether to do a case " + + "sensitive comparison over the stop words") + stopWordsObj = _jvm().org.apache.spark.ml.feature.StopWords + defaultStopWords = stopWordsObj.English() + self._setDefault(stopWords=defaultStopWords) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("1.6.0") + def setParams(self, inputCol=None, outputCol=None, stopWords=None, + caseSensitive=False): + """ + setParams(self, inputCol="input", outputCol="output", stopWords=None,\ + caseSensitive=false) + Sets params for this StopWordRemover. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + @since("1.6.0") + def setStopWords(self, value): + """ + Specify the stopwords to be filtered. + """ + self._paramMap[self.stopWords] = value + return self + + @since("1.6.0") + def getStopWords(self): + """ + Get the stopwords. + """ + return self.getOrDefault(self.stopWords) + + @since("1.6.0") + def setCaseSensitive(self, value): + """ + Set whether to do a case sensitive comparison over the stop words + """ + self._paramMap[self.caseSensitive] = value + return self + + @since("1.6.0") + def getCaseSensitive(self): + """ + Get whether to do a case sensitive comparison over the stop words. + """ + return self.getOrDefault(self.caseSensitive) @inherit_doc @ignore_unicode_prefix class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol): """ + .. note:: Experimental + A tokenizer that converts the input string to lowercase and then splits it by white spaces. @@ -780,6 +1475,8 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol): Traceback (most recent call last): ... TypeError: Method setParams forces keyword arguments. + + .. versionadded:: 1.3.0 """ @keyword_only @@ -793,6 +1490,7 @@ def __init__(self, inputCol=None, outputCol=None): self.setParams(**kwargs) @keyword_only + @since("1.3.0") def setParams(self, inputCol=None, outputCol=None): """ setParams(self, inputCol="input", outputCol="output") @@ -805,6 +1503,8 @@ def setParams(self, inputCol=None, outputCol=None): @inherit_doc class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol): """ + .. note:: Experimental + A feature transformer that merges multiple columns into a vector column. 
>>> df = sqlContext.createDataFrame([(1, 0, 3)], ["a", "b", "c"]) @@ -816,6 +1516,8 @@ class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol): >>> params = {vecAssembler.inputCols: ["b", "a"], vecAssembler.outputCol: "vector"} >>> vecAssembler.transform(df, params).head().vector DenseVector([0.0, 1.0]) + + .. versionadded:: 1.4.0 """ @keyword_only @@ -829,6 +1531,7 @@ def __init__(self, inputCols=None, outputCol=None): self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, inputCols=None, outputCol=None): """ setParams(self, inputCols=None, outputCol=None) @@ -841,6 +1544,8 @@ def setParams(self, inputCols=None, outputCol=None): @inherit_doc class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol): """ + .. note:: Experimental + Class for indexing categorical feature columns in a dataset of [[Vector]]. This has 2 usage modes: @@ -883,12 +1588,18 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol): >>> model = indexer.fit(df) >>> model.transform(df).head().indexed DenseVector([1.0, 0.0]) + >>> model.numFeatures + 2 + >>> model.categoryMaps + {0: {0.0: 0, -1.0: 1}} >>> indexer.setParams(outputCol="test").fit(df).transform(df).collect()[1].test DenseVector([0.0, 1.0]) >>> params = {indexer.maxCategories: 3, indexer.outputCol: "vector"} >>> model2 = indexer.fit(df, params) >>> model2.transform(df).head().vector DenseVector([1.0, 0.0]) + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -913,6 +1624,7 @@ def __init__(self, maxCategories=20, inputCol=None, outputCol=None): self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, maxCategories=20, inputCol=None, outputCol=None): """ setParams(self, maxCategories=20, inputCol=None, outputCol=None) @@ -921,6 +1633,7 @@ def setParams(self, maxCategories=20, inputCol=None, outputCol=None): kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.4.0") def setMaxCategories(self, value): """ Sets the value of :py:attr:`maxCategories`. @@ -928,6 +1641,7 @@ def setMaxCategories(self, value): self._paramMap[self.maxCategories] = value return self + @since("1.4.0") def getMaxCategories(self): """ Gets the value of maxCategories or its default value. @@ -940,22 +1654,157 @@ def _create_model(self, java_model): class VectorIndexerModel(JavaModel): """ + .. note:: Experimental + Model fitted by VectorIndexer. + + .. versionadded:: 1.4.0 """ + @property + @since("1.4.0") + def numFeatures(self): + """ + Number of features, i.e., length of Vectors which this transforms. + """ + return self._call_java("numFeatures") + + @property + @since("1.4.0") + def categoryMaps(self): + """ + Feature value index. Keys are categorical feature indices (column indices). + Values are maps from original features values to 0-based category indices. + If a feature is not in this map, it is treated as continuous. + """ + return self._call_java("javaCategoryMaps") + + +@inherit_doc +class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol): + """ + .. note:: Experimental + + This class takes a feature vector and outputs a new feature vector with a subarray + of the original features. + + The subset of features can be specified with either indices (`setIndices()`) + or names (`setNames()`). At least one feature must be selected. Duplicate features + are not allowed, so there can be no overlap between selected indices and names. 
+ + The output vector will order features with the selected indices first (in the order given), + followed by the selected names (in the order given). + + >>> from pyspark.mllib.linalg import Vectors + >>> df = sqlContext.createDataFrame([ + ... (Vectors.dense([-2.0, 2.3, 0.0, 0.0, 1.0]),), + ... (Vectors.dense([0.0, 0.0, 0.0, 0.0, 0.0]),), + ... (Vectors.dense([0.6, -1.1, -3.0, 4.5, 3.3]),)], ["features"]) + >>> vs = VectorSlicer(inputCol="features", outputCol="sliced", indices=[1, 4]) + >>> vs.transform(df).head().sliced + DenseVector([2.3, 1.0]) + + .. versionadded:: 1.6.0 + """ + + # a placeholder to make it appear in the generated doc + indices = Param(Params._dummy(), "indices", "An array of indices to select features from " + + "a vector column. There can be no overlap with names.") + names = Param(Params._dummy(), "names", "An array of feature names to select features from " + + "a vector column. These names must be specified by ML " + + "org.apache.spark.ml.attribute.Attribute. There can be no overlap with " + + "indices.") + + @keyword_only + def __init__(self, inputCol=None, outputCol=None, indices=None, names=None): + """ + __init__(self, inputCol=None, outputCol=None, indices=None, names=None) + """ + super(VectorSlicer, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorSlicer", self.uid) + self.indices = Param(self, "indices", "An array of indices to select features from " + + "a vector column. There can be no overlap with names.") + self.names = Param(self, "names", "An array of feature names to select features from " + + "a vector column. These names must be specified by ML " + + "org.apache.spark.ml.attribute.Attribute. There can be no overlap " + + "with indices.") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("1.6.0") + def setParams(self, inputCol=None, outputCol=None, indices=None, names=None): + """ + setParams(self, inputCol=None, outputCol=None, indices=None, names=None): + Sets params for this VectorSlicer. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + @since("1.6.0") + def setIndices(self, value): + """ + Sets the value of :py:attr:`indices`. + """ + self._paramMap[self.indices] = value + return self + + @since("1.6.0") + def getIndices(self): + """ + Gets the value of indices or its default value. + """ + return self.getOrDefault(self.indices) + + @since("1.6.0") + def setNames(self, value): + """ + Sets the value of :py:attr:`names`. + """ + self._paramMap[self.names] = value + return self + + @since("1.6.0") + def getNames(self): + """ + Gets the value of names or its default value. + """ + return self.getOrDefault(self.names) + @inherit_doc @ignore_unicode_prefix class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, HasOutputCol): """ + .. note:: Experimental + Word2Vec trains a model of `Map(String, Vector)`, i.e. transforms a word into a code for further natural language processing or machine learning process. >>> sent = ("a b " * 100 + "a c " * 10).split(" ") >>> doc = sqlContext.createDataFrame([(sent,), (sent,)], ["sentence"]) >>> model = Word2Vec(vectorSize=5, seed=42, inputCol="sentence", outputCol="model").fit(doc) + >>> model.getVectors().show() + +----+--------------------+ + |word| vector| + +----+--------------------+ + | a|[0.09461779892444...| + | b|[1.15474212169647...| + | c|[-0.3794820010662...| + +----+--------------------+ + ... 
+ >>> model.findSynonyms("a", 2).show() + +----+--------------------+ + |word| similarity| + +----+--------------------+ + | b| 0.16782984556103436| + | c|-0.46761559092107646| + +----+--------------------+ + ... >>> model.transform(doc).head().model - DenseVector([-0.0422, -0.5138, -0.2546, 0.6885, 0.276]) + DenseVector([0.5524, -0.4995, -0.3599, 0.0241, 0.3461]) + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -989,6 +1838,7 @@ def __init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, seed=None, inputCol=None, outputCol=None): """ @@ -999,6 +1849,7 @@ def setParams(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.4.0") def setVectorSize(self, value): """ Sets the value of :py:attr:`vectorSize`. @@ -1006,12 +1857,14 @@ def setVectorSize(self, value): self._paramMap[self.vectorSize] = value return self + @since("1.4.0") def getVectorSize(self): """ Gets the value of vectorSize or its default value. """ return self.getOrDefault(self.vectorSize) + @since("1.4.0") def setNumPartitions(self, value): """ Sets the value of :py:attr:`numPartitions`. @@ -1019,12 +1872,14 @@ def setNumPartitions(self, value): self._paramMap[self.numPartitions] = value return self + @since("1.4.0") def getNumPartitions(self): """ Gets the value of numPartitions or its default value. """ return self.getOrDefault(self.numPartitions) + @since("1.4.0") def setMinCount(self, value): """ Sets the value of :py:attr:`minCount`. @@ -1032,6 +1887,7 @@ def setMinCount(self, value): self._paramMap[self.minCount] = value return self + @since("1.4.0") def getMinCount(self): """ Gets the value of minCount or its default value. @@ -1044,13 +1900,39 @@ def _create_model(self, java_model): class Word2VecModel(JavaModel): """ + .. note:: Experimental + Model fitted by Word2Vec. + + .. versionadded:: 1.4.0 """ + @since("1.5.0") + def getVectors(self): + """ + Returns the vector representation of the words as a dataframe + with two fields, word and vector. + """ + return self._call_java("getVectors") + + @since("1.5.0") + def findSynonyms(self, word, num): + """ + Find "num" number of words closest in similarity to "word". + word can be a string or vector representation. + Returns a dataframe with two fields word and similarity (which + gives the cosine similarity). + """ + if not isinstance(word, basestring): + word = _convert_to_vector(word) + return self._call_java("findSynonyms", word, num) + @inherit_doc class PCA(JavaEstimator, HasInputCol, HasOutputCol): """ + .. note:: Experimental + PCA trains a model to project vectors to a low-dimensional space using PCA. >>> from pyspark.mllib.linalg import Vectors @@ -1062,6 +1944,8 @@ class PCA(JavaEstimator, HasInputCol, HasOutputCol): >>> model = pca.fit(df) >>> model.transform(df).collect()[0].pca_features DenseVector([1.648..., -4.013...]) + + .. 
versionadded:: 1.5.0 """ # a placeholder to make it appear in the generated doc @@ -1079,6 +1963,7 @@ def __init__(self, k=None, inputCol=None, outputCol=None): self.setParams(**kwargs) @keyword_only + @since("1.5.0") def setParams(self, k=None, inputCol=None, outputCol=None): """ setParams(self, k=None, inputCol=None, outputCol=None) @@ -1087,6 +1972,7 @@ def setParams(self, k=None, inputCol=None, outputCol=None): kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.5.0") def setK(self, value): """ Sets the value of :py:attr:`k`. @@ -1094,6 +1980,7 @@ def setK(self, value): self._paramMap[self.k] = value return self + @since("1.5.0") def getK(self): """ Gets the value of k or its default value. @@ -1106,7 +1993,103 @@ def _create_model(self, java_model): class PCAModel(JavaModel): """ + .. note:: Experimental + Model fitted by PCA. + + .. versionadded:: 1.5.0 + """ + + +@inherit_doc +class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol): + """ + .. note:: Experimental + + Implements the transforms required for fitting a dataset against an + R model formula. Currently we support a limited subset of the R + operators, including '~', '.', ':', '+', and '-'. Also see the R formula + docs: + http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html + + >>> df = sqlContext.createDataFrame([ + ... (1.0, 1.0, "a"), + ... (0.0, 2.0, "b"), + ... (0.0, 0.0, "a") + ... ], ["y", "x", "s"]) + >>> rf = RFormula(formula="y ~ x + s") + >>> rf.fit(df).transform(df).show() + +---+---+---+---------+-----+ + | y| x| s| features|label| + +---+---+---+---------+-----+ + |1.0|1.0| a|[1.0,1.0]| 1.0| + |0.0|2.0| b|[2.0,0.0]| 0.0| + |0.0|0.0| a|[0.0,1.0]| 0.0| + +---+---+---+---------+-----+ + ... + >>> rf.fit(df, {rf.formula: "y ~ . - s"}).transform(df).show() + +---+---+---+--------+-----+ + | y| x| s|features|label| + +---+---+---+--------+-----+ + |1.0|1.0| a| [1.0]| 1.0| + |0.0|2.0| b| [2.0]| 0.0| + |0.0|0.0| a| [0.0]| 0.0| + +---+---+---+--------+-----+ + ... + + .. versionadded:: 1.5.0 + """ + + # a placeholder to make it appear in the generated doc + formula = Param(Params._dummy(), "formula", "R model formula") + + @keyword_only + def __init__(self, formula=None, featuresCol="features", labelCol="label"): + """ + __init__(self, formula=None, featuresCol="features", labelCol="label") + """ + super(RFormula, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RFormula", self.uid) + self.formula = Param(self, "formula", "R model formula") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("1.5.0") + def setParams(self, formula=None, featuresCol="features", labelCol="label"): + """ + setParams(self, formula=None, featuresCol="features", labelCol="label") + Sets params for RFormula. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + @since("1.5.0") + def setFormula(self, value): + """ + Sets the value of :py:attr:`formula`. + """ + self._paramMap[self.formula] = value + return self + + @since("1.5.0") + def getFormula(self): + """ + Gets the value of :py:attr:`formula`. + """ + return self.getOrDefault(self.formula) + + def _create_model(self, java_model): + return RFormulaModel(java_model) + + +class RFormulaModel(JavaModel): + """ + .. note:: Experimental + + Model fitted by :py:class:`RFormula`. + + .. 
versionadded:: 1.5.0 """ diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index 7845536161e0..35c9b776a3d5 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -18,6 +18,7 @@ from abc import ABCMeta import copy +from pyspark import since from pyspark.ml.util import Identifiable @@ -27,6 +28,8 @@ class Param(object): """ A param with self-contained documentation. + + .. versionadded:: 1.3.0 """ def __init__(self, parent, name, doc): @@ -56,20 +59,25 @@ class Params(Identifiable): """ Components that take parameters. This also provides an internal param map to store parameter values attached to the instance. + + .. versionadded:: 1.3.0 """ __metaclass__ = ABCMeta - #: internal param map for user-supplied values param map - _paramMap = {} + def __init__(self): + super(Params, self).__init__() + #: internal param map for user-supplied values param map + self._paramMap = {} - #: internal param map for default values - _defaultParamMap = {} + #: internal param map for default values + self._defaultParamMap = {} - #: value returned by :py:func:`params` - _params = None + #: value returned by :py:func:`params` + self._params = None @property + @since("1.3.0") def params(self): """ Returns all params ordered by name. The default implementation @@ -81,6 +89,7 @@ def params(self): [getattr(self, x) for x in dir(self) if x != "params"])) return self._params + @since("1.4.0") def explainParam(self, param): """ Explains a single param and returns its name, doc, and optional @@ -98,6 +107,7 @@ def explainParam(self, param): valueStr = "(" + ", ".join(values) + ")" return "%s: %s %s" % (param.name, param.doc, valueStr) + @since("1.4.0") def explainParams(self): """ Returns the documentation of all params with their optionally @@ -105,6 +115,7 @@ def explainParams(self): """ return "\n".join([self.explainParam(param) for param in self.params]) + @since("1.4.0") def getParam(self, paramName): """ Gets a param by its name. @@ -115,6 +126,7 @@ def getParam(self, paramName): else: raise ValueError("Cannot find param with name %s." % paramName) + @since("1.4.0") def isSet(self, param): """ Checks whether a param is explicitly set by user. @@ -122,6 +134,7 @@ def isSet(self, param): param = self._resolveParam(param) return param in self._paramMap + @since("1.4.0") def hasDefault(self, param): """ Checks whether a param has a default value. 
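A minimal sketch of how the introspection helpers annotated in the hunks above behave on a concrete estimator; PCA from this same patch is used only as a convenient example, and the trailing comments describe expected results rather than exact output strings:

    from pyspark.ml.feature import PCA

    pca = PCA(k=2, inputCol="features", outputCol="pca_features")
    pca.explainParams()            # one line per param: name, doc, and current/default value
    pca.getParam("k")              # the Param instance registered under the name "k"
    pca.isSet(pca.k)               # True  -- k was supplied by the user
    pca.hasDefault(pca.inputCol)   # False -- inputCol has no default value
    pca.getOrDefault(pca.k)        # 2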
@@ -129,6 +142,7 @@ def hasDefault(self, param): param = self._resolveParam(param) return param in self._defaultParamMap + @since("1.4.0") def isDefined(self, param): """ Checks whether a param is explicitly set by user or has @@ -136,6 +150,7 @@ def isDefined(self, param): """ return self.isSet(param) or self.hasDefault(param) + @since("1.4.0") def hasParam(self, paramName): """ Tests whether this instance contains a param with a given @@ -144,6 +159,7 @@ def hasParam(self, paramName): param = self._resolveParam(paramName) return param in self.params + @since("1.4.0") def getOrDefault(self, param): """ Gets the value of a param in the user-supplied param map or its @@ -155,22 +171,27 @@ def getOrDefault(self, param): else: return self._defaultParamMap[param] - def extractParamMap(self, extra={}): + @since("1.4.0") + def extractParamMap(self, extra=None): """ Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra. + :param extra: extra param values :return: merged param map """ + if extra is None: + extra = dict() paramMap = self._defaultParamMap.copy() paramMap.update(self._paramMap) paramMap.update(extra) return paramMap - def copy(self, extra={}): + @since("1.4.0") + def copy(self, extra=None): """ Creates a copy of this instance with the same uid and some extra params. The default implementation creates a @@ -178,9 +199,12 @@ def copy(self, extra={}): embedded and extra parameters over and returns the copy. Subclasses should override this method if the default approach is not sufficient. + :param extra: Extra parameters to copy to the new instance :return: Copy of this instance """ + if extra is None: + extra = dict() that = copy.copy(self) that._paramMap = self.extractParamMap(extra) return that @@ -195,6 +219,7 @@ def _shouldOwn(self, param): def _resolveParam(self, param): """ Resolves a param and validates the ownership. + :param param: param name or the param instance, which must belong to this Params instance :return: resolved param instance @@ -233,14 +258,17 @@ def _setDefault(self, **kwargs): self._defaultParamMap[getattr(self, param)] = value return self - def _copyValues(self, to, extra={}): + def _copyValues(self, to, extra=None): """ Copies param values from this instance to another instance for params shared by them. + :param to: the target instance :param extra: extra params to be copied :return: the target instance with param values copied """ + if extra is None: + extra = dict() paramMap = self.extractParamMap(extra) for p in self.params: if p in paramMap and to.hasParam(p.name): diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 69efc424ec4e..0528dc1e3a6b 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -47,7 +47,7 @@ def _gen_param_header(name, doc, defaultValueStr): """ template = '''class Has$Name(Params): """ - Mixin for param $name: $doc. + Mixin for param $name: $doc """ # a placeholder to make it appear in the generated doc @@ -105,23 +105,41 @@ def get$Name(self): print("\n# DO NOT MODIFY THIS FILE! 
It was generated by _shared_params_code_gen.py.\n") print("from pyspark.ml.param import Param, Params\n\n") shared = [ - ("maxIter", "max number of iterations (>= 0)", None), - ("regParam", "regularization parameter (>= 0)", None), - ("featuresCol", "features column name", "'features'"), - ("labelCol", "label column name", "'label'"), - ("predictionCol", "prediction column name", "'prediction'"), + ("maxIter", "max number of iterations (>= 0).", None), + ("regParam", "regularization parameter (>= 0).", None), + ("featuresCol", "features column name.", "'features'"), + ("labelCol", "label column name.", "'label'"), + ("predictionCol", "prediction column name.", "'prediction'"), ("probabilityCol", "Column name for predicted class conditional probabilities. " + "Note: Not all models output well-calibrated probability estimates! These probabilities " + "should be treated as confidences, not precise probabilities.", "'probability'"), - ("rawPredictionCol", "raw prediction (a.k.a. confidence) column name", "'rawPrediction'"), - ("inputCol", "input column name", None), - ("inputCols", "input column names", None), - ("outputCol", "output column name", "self.uid + '__output'"), - ("numFeatures", "number of features", None), - ("checkpointInterval", "checkpoint interval (>= 1)", None), - ("seed", "random seed", "hash(type(self).__name__)"), - ("tol", "the convergence tolerance for iterative algorithms", None), - ("stepSize", "Step size to be used for each iteration of optimization.", None)] + ("rawPredictionCol", "raw prediction (a.k.a. confidence) column name.", "'rawPrediction'"), + ("inputCol", "input column name.", None), + ("inputCols", "input column names.", None), + ("outputCol", "output column name.", "self.uid + '__output'"), + ("numFeatures", "number of features.", None), + ("checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). " + + "E.g. 10 means that the cache will get checkpointed every 10 iterations.", None), + ("seed", "random seed.", "hash(type(self).__name__)"), + ("tol", "the convergence tolerance for iterative algorithms.", None), + ("stepSize", "Step size to be used for each iteration of optimization.", None), + ("handleInvalid", "how to handle invalid entries. Options are skip (which will filter " + + "out rows with bad values), or error (which will throw an errror). More options may be " + + "added later.", None), + ("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + + "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", "0.0"), + ("fitIntercept", "whether to fit an intercept term.", "True"), + ("standardization", "whether to standardize the training features before fitting the " + + "model.", "True"), + ("thresholds", "Thresholds in multi-class classification to adjust the probability of " + + "predicting each class. Array must have length equal to the number of classes, with " + + "values >= 0. The class with largest value p/t is predicted, where p is the original " + + "probability of that class and t is the class' threshold.", None), + ("weightCol", "weight column name. If this is not set or empty, we treat " + + "all instance weights as 1.0.", None), + ("solver", "the solver algorithm for optimization. 
If this is not set or empty, " + + "default value is 'auto'.", "'auto'")] + code = [] for name, doc, defaultValueStr in shared: param_code = _gen_param_header(name, doc, defaultValueStr) @@ -140,7 +158,8 @@ def get$Name(self): ("maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation."), ("cacheNodeIds", "If false, the algorithm will pass trees to executors to match " + "instances with nodes. If true, the algorithm will cache node IDs for each instance. " + - "Caching can speed up training of deeper trees.")] + "Caching can speed up training of deeper trees. Users can set how often should the " + + "cache be checkpointed or disable it by setting checkpointInterval.")] decisionTreeCode = '''class DecisionTreeParams(Params): """ diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 595124726366..4d960801502c 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -26,12 +26,12 @@ class HasMaxIter(Params): """ # a placeholder to make it appear in the generated doc - maxIter = Param(Params._dummy(), "maxIter", "max number of iterations (>= 0)") + maxIter = Param(Params._dummy(), "maxIter", "max number of iterations (>= 0).") def __init__(self): super(HasMaxIter, self).__init__() - #: param for max number of iterations (>= 0) - self.maxIter = Param(self, "maxIter", "max number of iterations (>= 0)") + #: param for max number of iterations (>= 0). + self.maxIter = Param(self, "maxIter", "max number of iterations (>= 0).") def setMaxIter(self, value): """ @@ -53,12 +53,12 @@ class HasRegParam(Params): """ # a placeholder to make it appear in the generated doc - regParam = Param(Params._dummy(), "regParam", "regularization parameter (>= 0)") + regParam = Param(Params._dummy(), "regParam", "regularization parameter (>= 0).") def __init__(self): super(HasRegParam, self).__init__() - #: param for regularization parameter (>= 0) - self.regParam = Param(self, "regParam", "regularization parameter (>= 0)") + #: param for regularization parameter (>= 0). + self.regParam = Param(self, "regParam", "regularization parameter (>= 0).") def setRegParam(self, value): """ @@ -80,12 +80,12 @@ class HasFeaturesCol(Params): """ # a placeholder to make it appear in the generated doc - featuresCol = Param(Params._dummy(), "featuresCol", "features column name") + featuresCol = Param(Params._dummy(), "featuresCol", "features column name.") def __init__(self): super(HasFeaturesCol, self).__init__() - #: param for features column name - self.featuresCol = Param(self, "featuresCol", "features column name") + #: param for features column name. + self.featuresCol = Param(self, "featuresCol", "features column name.") self._setDefault(featuresCol='features') def setFeaturesCol(self, value): @@ -108,12 +108,12 @@ class HasLabelCol(Params): """ # a placeholder to make it appear in the generated doc - labelCol = Param(Params._dummy(), "labelCol", "label column name") + labelCol = Param(Params._dummy(), "labelCol", "label column name.") def __init__(self): super(HasLabelCol, self).__init__() - #: param for label column name - self.labelCol = Param(self, "labelCol", "label column name") + #: param for label column name. 
+ self.labelCol = Param(self, "labelCol", "label column name.") self._setDefault(labelCol='label') def setLabelCol(self, value): @@ -136,12 +136,12 @@ class HasPredictionCol(Params): """ # a placeholder to make it appear in the generated doc - predictionCol = Param(Params._dummy(), "predictionCol", "prediction column name") + predictionCol = Param(Params._dummy(), "predictionCol", "prediction column name.") def __init__(self): super(HasPredictionCol, self).__init__() - #: param for prediction column name - self.predictionCol = Param(self, "predictionCol", "prediction column name") + #: param for prediction column name. + self.predictionCol = Param(self, "predictionCol", "prediction column name.") self._setDefault(predictionCol='prediction') def setPredictionCol(self, value): @@ -160,7 +160,7 @@ def getPredictionCol(self): class HasProbabilityCol(Params): """ - Mixin for param probabilityCol: Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.. + Mixin for param probabilityCol: Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities. """ # a placeholder to make it appear in the generated doc @@ -192,12 +192,12 @@ class HasRawPredictionCol(Params): """ # a placeholder to make it appear in the generated doc - rawPredictionCol = Param(Params._dummy(), "rawPredictionCol", "raw prediction (a.k.a. confidence) column name") + rawPredictionCol = Param(Params._dummy(), "rawPredictionCol", "raw prediction (a.k.a. confidence) column name.") def __init__(self): super(HasRawPredictionCol, self).__init__() - #: param for raw prediction (a.k.a. confidence) column name - self.rawPredictionCol = Param(self, "rawPredictionCol", "raw prediction (a.k.a. confidence) column name") + #: param for raw prediction (a.k.a. confidence) column name. + self.rawPredictionCol = Param(self, "rawPredictionCol", "raw prediction (a.k.a. confidence) column name.") self._setDefault(rawPredictionCol='rawPrediction') def setRawPredictionCol(self, value): @@ -220,12 +220,12 @@ class HasInputCol(Params): """ # a placeholder to make it appear in the generated doc - inputCol = Param(Params._dummy(), "inputCol", "input column name") + inputCol = Param(Params._dummy(), "inputCol", "input column name.") def __init__(self): super(HasInputCol, self).__init__() - #: param for input column name - self.inputCol = Param(self, "inputCol", "input column name") + #: param for input column name. + self.inputCol = Param(self, "inputCol", "input column name.") def setInputCol(self, value): """ @@ -247,12 +247,12 @@ class HasInputCols(Params): """ # a placeholder to make it appear in the generated doc - inputCols = Param(Params._dummy(), "inputCols", "input column names") + inputCols = Param(Params._dummy(), "inputCols", "input column names.") def __init__(self): super(HasInputCols, self).__init__() - #: param for input column names - self.inputCols = Param(self, "inputCols", "input column names") + #: param for input column names. 
+ self.inputCols = Param(self, "inputCols", "input column names.") def setInputCols(self, value): """ @@ -274,12 +274,12 @@ class HasOutputCol(Params): """ # a placeholder to make it appear in the generated doc - outputCol = Param(Params._dummy(), "outputCol", "output column name") + outputCol = Param(Params._dummy(), "outputCol", "output column name.") def __init__(self): super(HasOutputCol, self).__init__() - #: param for output column name - self.outputCol = Param(self, "outputCol", "output column name") + #: param for output column name. + self.outputCol = Param(self, "outputCol", "output column name.") self._setDefault(outputCol=self.uid + '__output') def setOutputCol(self, value): @@ -302,12 +302,12 @@ class HasNumFeatures(Params): """ # a placeholder to make it appear in the generated doc - numFeatures = Param(Params._dummy(), "numFeatures", "number of features") + numFeatures = Param(Params._dummy(), "numFeatures", "number of features.") def __init__(self): super(HasNumFeatures, self).__init__() - #: param for number of features - self.numFeatures = Param(self, "numFeatures", "number of features") + #: param for number of features. + self.numFeatures = Param(self, "numFeatures", "number of features.") def setNumFeatures(self, value): """ @@ -325,16 +325,16 @@ def getNumFeatures(self): class HasCheckpointInterval(Params): """ - Mixin for param checkpointInterval: checkpoint interval (>= 1). + Mixin for param checkpointInterval: set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. """ # a placeholder to make it appear in the generated doc - checkpointInterval = Param(Params._dummy(), "checkpointInterval", "checkpoint interval (>= 1)") + checkpointInterval = Param(Params._dummy(), "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations.") def __init__(self): super(HasCheckpointInterval, self).__init__() - #: param for checkpoint interval (>= 1) - self.checkpointInterval = Param(self, "checkpointInterval", "checkpoint interval (>= 1)") + #: param for set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. + self.checkpointInterval = Param(self, "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations.") def setCheckpointInterval(self, value): """ @@ -356,12 +356,12 @@ class HasSeed(Params): """ # a placeholder to make it appear in the generated doc - seed = Param(Params._dummy(), "seed", "random seed") + seed = Param(Params._dummy(), "seed", "random seed.") def __init__(self): super(HasSeed, self).__init__() - #: param for random seed - self.seed = Param(self, "seed", "random seed") + #: param for random seed. 
+ self.seed = Param(self, "seed", "random seed.") self._setDefault(seed=hash(type(self).__name__)) def setSeed(self, value): @@ -384,12 +384,12 @@ class HasTol(Params): """ # a placeholder to make it appear in the generated doc - tol = Param(Params._dummy(), "tol", "the convergence tolerance for iterative algorithms") + tol = Param(Params._dummy(), "tol", "the convergence tolerance for iterative algorithms.") def __init__(self): super(HasTol, self).__init__() - #: param for the convergence tolerance for iterative algorithms - self.tol = Param(self, "tol", "the convergence tolerance for iterative algorithms") + #: param for the convergence tolerance for iterative algorithms. + self.tol = Param(self, "tol", "the convergence tolerance for iterative algorithms.") def setTol(self, value): """ @@ -407,7 +407,7 @@ def getTol(self): class HasStepSize(Params): """ - Mixin for param stepSize: Step size to be used for each iteration of optimization.. + Mixin for param stepSize: Step size to be used for each iteration of optimization. """ # a placeholder to make it appear in the generated doc @@ -432,6 +432,199 @@ def getStepSize(self): return self.getOrDefault(self.stepSize) +class HasHandleInvalid(Params): + """ + Mixin for param handleInvalid: how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later. + """ + + # a placeholder to make it appear in the generated doc + handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.") + + def __init__(self): + super(HasHandleInvalid, self).__init__() + #: param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later. + self.handleInvalid = Param(self, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.") + + def setHandleInvalid(self, value): + """ + Sets the value of :py:attr:`handleInvalid`. + """ + self._paramMap[self.handleInvalid] = value + return self + + def getHandleInvalid(self): + """ + Gets the value of handleInvalid or its default value. + """ + return self.getOrDefault(self.handleInvalid) + + +class HasElasticNetParam(Params): + """ + Mixin for param elasticNetParam: the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty. + """ + + # a placeholder to make it appear in the generated doc + elasticNetParam = Param(Params._dummy(), "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") + + def __init__(self): + super(HasElasticNetParam, self).__init__() + #: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty. + self.elasticNetParam = Param(self, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") + self._setDefault(elasticNetParam=0.0) + + def setElasticNetParam(self, value): + """ + Sets the value of :py:attr:`elasticNetParam`. 
+ """ + self._paramMap[self.elasticNetParam] = value + return self + + def getElasticNetParam(self): + """ + Gets the value of elasticNetParam or its default value. + """ + return self.getOrDefault(self.elasticNetParam) + + +class HasFitIntercept(Params): + """ + Mixin for param fitIntercept: whether to fit an intercept term. + """ + + # a placeholder to make it appear in the generated doc + fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.") + + def __init__(self): + super(HasFitIntercept, self).__init__() + #: param for whether to fit an intercept term. + self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.") + self._setDefault(fitIntercept=True) + + def setFitIntercept(self, value): + """ + Sets the value of :py:attr:`fitIntercept`. + """ + self._paramMap[self.fitIntercept] = value + return self + + def getFitIntercept(self): + """ + Gets the value of fitIntercept or its default value. + """ + return self.getOrDefault(self.fitIntercept) + + +class HasStandardization(Params): + """ + Mixin for param standardization: whether to standardize the training features before fitting the model. + """ + + # a placeholder to make it appear in the generated doc + standardization = Param(Params._dummy(), "standardization", "whether to standardize the training features before fitting the model.") + + def __init__(self): + super(HasStandardization, self).__init__() + #: param for whether to standardize the training features before fitting the model. + self.standardization = Param(self, "standardization", "whether to standardize the training features before fitting the model.") + self._setDefault(standardization=True) + + def setStandardization(self, value): + """ + Sets the value of :py:attr:`standardization`. + """ + self._paramMap[self.standardization] = value + return self + + def getStandardization(self): + """ + Gets the value of standardization or its default value. + """ + return self.getOrDefault(self.standardization) + + +class HasThresholds(Params): + """ + Mixin for param thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold. + """ + + # a placeholder to make it appear in the generated doc + thresholds = Param(Params._dummy(), "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.") + + def __init__(self): + super(HasThresholds, self).__init__() + #: param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold. + self.thresholds = Param(self, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. 
The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.") + + def setThresholds(self, value): + """ + Sets the value of :py:attr:`thresholds`. + """ + self._paramMap[self.thresholds] = value + return self + + def getThresholds(self): + """ + Gets the value of thresholds or its default value. + """ + return self.getOrDefault(self.thresholds) + + +class HasWeightCol(Params): + """ + Mixin for param weightCol: weight column name. If this is not set or empty, we treat all instance weights as 1.0. + """ + + # a placeholder to make it appear in the generated doc + weightCol = Param(Params._dummy(), "weightCol", "weight column name. If this is not set or empty, we treat all instance weights as 1.0.") + + def __init__(self): + super(HasWeightCol, self).__init__() + #: param for weight column name. If this is not set or empty, we treat all instance weights as 1.0. + self.weightCol = Param(self, "weightCol", "weight column name. If this is not set or empty, we treat all instance weights as 1.0.") + + def setWeightCol(self, value): + """ + Sets the value of :py:attr:`weightCol`. + """ + self._paramMap[self.weightCol] = value + return self + + def getWeightCol(self): + """ + Gets the value of weightCol or its default value. + """ + return self.getOrDefault(self.weightCol) + + +class HasSolver(Params): + """ + Mixin for param solver: the solver algorithm for optimization. If this is not set or empty, default value is 'auto'. + """ + + # a placeholder to make it appear in the generated doc + solver = Param(Params._dummy(), "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.") + + def __init__(self): + super(HasSolver, self).__init__() + #: param for the solver algorithm for optimization. If this is not set or empty, default value is 'auto'. + self.solver = Param(self, "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.") + self._setDefault(solver='auto') + + def setSolver(self, value): + """ + Sets the value of :py:attr:`solver`. + """ + self._paramMap[self.solver] = value + return self + + def getSolver(self): + """ + Gets the value of solver or its default value. + """ + return self.getOrDefault(self.solver) + + class DecisionTreeParams(Params): """ Mixin for Decision Tree parameters. @@ -443,8 +636,8 @@ class DecisionTreeParams(Params): minInstancesPerNode = Param(Params._dummy(), "minInstancesPerNode", "Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.") minInfoGain = Param(Params._dummy(), "minInfoGain", "Minimum information gain for a split to be considered at a tree node.") maxMemoryInMB = Param(Params._dummy(), "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.") - cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.") - + cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. 
Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.") + def __init__(self): super(DecisionTreeParams, self).__init__() @@ -458,9 +651,9 @@ def __init__(self): self.minInfoGain = Param(self, "minInfoGain", "Minimum information gain for a split to be considered at a tree node.") #: param for Maximum memory in MB allocated to histogram aggregation. self.maxMemoryInMB = Param(self, "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.") - #: param for If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. - self.cacheNodeIds = Param(self, "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.") - + #: param for If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval. + self.cacheNodeIds = Param(self, "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.") + def setMaxDepth(self, value): """ Sets the value of :py:attr:`maxDepth`. diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index 9889f56cac9e..4475451edb78 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -17,6 +17,7 @@ from abc import ABCMeta, abstractmethod +from pyspark import since from pyspark.ml.param import Param, Params from pyspark.ml.util import keyword_only from pyspark.mllib.common import inherit_doc @@ -26,6 +27,8 @@ class Estimator(Params): """ Abstract class for estimators that fit models to data. + + .. versionadded:: 1.3.0 """ __metaclass__ = ABCMeta @@ -42,6 +45,7 @@ def _fit(self, dataset): """ raise NotImplementedError() + @since("1.3.0") def fit(self, dataset, params=None): """ Fits a model to the input dataset with optional parameters. @@ -73,6 +77,8 @@ class Transformer(Params): """ Abstract class for transformers that transform one dataset into another. + + .. versionadded:: 1.3.0 """ __metaclass__ = ABCMeta @@ -88,6 +94,7 @@ def _transform(self, dataset): """ raise NotImplementedError() + @since("1.3.0") def transform(self, dataset, params=None): """ Transforms the input dataset with optional parameters. @@ -113,6 +120,8 @@ def transform(self, dataset, params=None): class Model(Transformer): """ Abstract class for models that are fitted by estimators. + + .. versionadded:: 1.4.0 """ __metaclass__ = ABCMeta @@ -136,12 +145,14 @@ class Pipeline(Estimator): consists of fitted models and transformers, corresponding to the pipeline stages. If there are no stages, the pipeline acts as an identity transformer. + + .. 
versionadded:: 1.3.0 """ @keyword_only def __init__(self, stages=None): """ - __init__(self, stages=[]) + __init__(self, stages=None) """ if stages is None: stages = [] @@ -151,15 +162,18 @@ def __init__(self, stages=None): kwargs = self.__init__._input_kwargs self.setParams(**kwargs) + @since("1.3.0") def setStages(self, value): """ Set pipeline stages. + :param value: a list of transformers or estimators :return: the pipeline instance """ self._paramMap[self.stages] = value return self + @since("1.3.0") def getStages(self): """ Get pipeline stages. @@ -168,9 +182,10 @@ def getStages(self): return self._paramMap[self.stages] @keyword_only + @since("1.3.0") def setParams(self, stages=None): """ - setParams(self, stages=[]) + setParams(self, stages=None) Sets params for Pipeline. """ if stages is None: @@ -203,7 +218,14 @@ def _fit(self, dataset): transformers.append(stage) return PipelineModel(transformers) + @since("1.4.0") def copy(self, extra=None): + """ + Creates a copy of this instance. + + :param extra: extra parameters + :returns: new instance + """ if extra is None: extra = dict() that = Params.copy(self, extra) @@ -215,6 +237,8 @@ def copy(self, extra=None): class PipelineModel(Model): """ Represents a compiled pipeline with transformers and fitted models. + + .. versionadded:: 1.3.0 """ def __init__(self, stages): @@ -226,7 +250,14 @@ def _transform(self, dataset): dataset = t.transform(dataset) return dataset + @since("1.4.0") def copy(self, extra=None): + """ + Creates a copy of this instance. + + :param extra: extra parameters + :returns: new instance + """ if extra is None: extra = dict() stages = [stage.copy(extra) for stage in self.stages] diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index b06099ac0aee..b44c66f73cc4 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -15,6 +15,7 @@ # limitations under the License. # +from pyspark import since from pyspark.ml.util import keyword_only from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * @@ -75,11 +76,13 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha >>> test = sqlContext.createDataFrame([(0, 2), (1, 0), (2, 0)], ["user", "item"]) >>> predictions = sorted(model.transform(test).collect(), key=lambda r: r[0]) >>> predictions[0] - Row(user=0, item=2, prediction=0.39...) + Row(user=0, item=2, prediction=-0.13807615637779236) >>> predictions[1] - Row(user=1, item=0, prediction=3.19...) + Row(user=1, item=0, prediction=2.6258413791656494) >>> predictions[2] - Row(user=2, item=0, prediction=-1.15...) + Row(user=2, item=0, prediction=-1.5018409490585327) + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -122,6 +125,7 @@ def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemB self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, ratingCol="rating", nonnegative=False, checkpointInterval=10): @@ -137,6 +141,7 @@ def setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItem def _create_model(self, java_model): return ALSModel(java_model) + @since("1.4.0") def setRank(self, value): """ Sets the value of :py:attr:`rank`. 
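A short end-to-end sketch of the ALS surface covered by these hunks, assuming the same `sqlContext` the doctests use; the ratings below are made up for illustration, and concrete prediction values depend on the seed (which defaults to a hash of the class name, per HasSeed above), so the literal values in the doctest are reproducible only for a fixed seed:

    from pyspark.ml.recommendation import ALS

    ratings = sqlContext.createDataFrame(
        [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)],
        ["user", "item", "rating"])
    als = ALS(rank=10, maxIter=5, regParam=0.1)
    model = als.fit(ratings)
    model.rank                          # 10
    model.userFactors.count()           # 3 -- one row of latent factors per user id
    test = sqlContext.createDataFrame([(0, 2), (1, 0)], ["user", "item"])
    model.transform(test).collect()     # adds a "prediction" column to the test rows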
@@ -144,12 +149,14 @@ def setRank(self, value): self._paramMap[self.rank] = value return self + @since("1.4.0") def getRank(self): """ Gets the value of rank or its default value. """ return self.getOrDefault(self.rank) + @since("1.4.0") def setNumUserBlocks(self, value): """ Sets the value of :py:attr:`numUserBlocks`. @@ -157,12 +164,14 @@ def setNumUserBlocks(self, value): self._paramMap[self.numUserBlocks] = value return self + @since("1.4.0") def getNumUserBlocks(self): """ Gets the value of numUserBlocks or its default value. """ return self.getOrDefault(self.numUserBlocks) + @since("1.4.0") def setNumItemBlocks(self, value): """ Sets the value of :py:attr:`numItemBlocks`. @@ -170,12 +179,14 @@ def setNumItemBlocks(self, value): self._paramMap[self.numItemBlocks] = value return self + @since("1.4.0") def getNumItemBlocks(self): """ Gets the value of numItemBlocks or its default value. """ return self.getOrDefault(self.numItemBlocks) + @since("1.4.0") def setNumBlocks(self, value): """ Sets both :py:attr:`numUserBlocks` and :py:attr:`numItemBlocks` to the specific value. @@ -183,6 +194,7 @@ def setNumBlocks(self, value): self._paramMap[self.numUserBlocks] = value self._paramMap[self.numItemBlocks] = value + @since("1.4.0") def setImplicitPrefs(self, value): """ Sets the value of :py:attr:`implicitPrefs`. @@ -190,12 +202,14 @@ def setImplicitPrefs(self, value): self._paramMap[self.implicitPrefs] = value return self + @since("1.4.0") def getImplicitPrefs(self): """ Gets the value of implicitPrefs or its default value. """ return self.getOrDefault(self.implicitPrefs) + @since("1.4.0") def setAlpha(self, value): """ Sets the value of :py:attr:`alpha`. @@ -203,12 +217,14 @@ def setAlpha(self, value): self._paramMap[self.alpha] = value return self + @since("1.4.0") def getAlpha(self): """ Gets the value of alpha or its default value. """ return self.getOrDefault(self.alpha) + @since("1.4.0") def setUserCol(self, value): """ Sets the value of :py:attr:`userCol`. @@ -216,12 +232,14 @@ def setUserCol(self, value): self._paramMap[self.userCol] = value return self + @since("1.4.0") def getUserCol(self): """ Gets the value of userCol or its default value. """ return self.getOrDefault(self.userCol) + @since("1.4.0") def setItemCol(self, value): """ Sets the value of :py:attr:`itemCol`. @@ -229,12 +247,14 @@ def setItemCol(self, value): self._paramMap[self.itemCol] = value return self + @since("1.4.0") def getItemCol(self): """ Gets the value of itemCol or its default value. """ return self.getOrDefault(self.itemCol) + @since("1.4.0") def setRatingCol(self, value): """ Sets the value of :py:attr:`ratingCol`. @@ -242,12 +262,14 @@ def setRatingCol(self, value): self._paramMap[self.ratingCol] = value return self + @since("1.4.0") def getRatingCol(self): """ Gets the value of ratingCol or its default value. """ return self.getOrDefault(self.ratingCol) + @since("1.4.0") def setNonnegative(self, value): """ Sets the value of :py:attr:`nonnegative`. @@ -255,6 +277,7 @@ def setNonnegative(self, value): self._paramMap[self.nonnegative] = value return self + @since("1.4.0") def getNonnegative(self): """ Gets the value of nonnegative or its default value. @@ -265,14 +288,18 @@ def getNonnegative(self): class ALSModel(JavaModel): """ Model fitted by ALS. + + .. 
versionadded:: 1.4.0 """ @property + @since("1.4.0") def rank(self): """rank of the matrix factorization model""" return self._call_java("rank") @property + @since("1.4.0") def userFactors(self): """ a DataFrame that stores user factors in two columns: `id` and @@ -281,6 +308,7 @@ def userFactors(self): return self._call_java("userFactors") @property + @since("1.4.0") def itemFactors(self): """ a DataFrame that stores item factors in two columns: `id` and diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 44f60a769566..a0bb8ceed886 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -15,25 +15,32 @@ # limitations under the License. # +import warnings + +from pyspark import since from pyspark.ml.util import keyword_only from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * from pyspark.mllib.common import inherit_doc -__all__ = ['DecisionTreeRegressor', 'DecisionTreeRegressionModel', 'GBTRegressor', - 'GBTRegressionModel', 'LinearRegression', 'LinearRegressionModel', +__all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel', + 'DecisionTreeRegressor', 'DecisionTreeRegressionModel', + 'GBTRegressor', 'GBTRegressionModel', + 'IsotonicRegression', 'IsotonicRegressionModel', + 'LinearRegression', 'LinearRegressionModel', 'RandomForestRegressor', 'RandomForestRegressionModel'] @inherit_doc class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, - HasRegParam, HasTol): + HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept, + HasStandardization, HasSolver, HasWeightCol): """ Linear regression. The learning objective is to minimize the squared error, with regularization. - The specific squared error loss function used is: L = 1/2n ||A weights - y||^2^ + The specific squared error loss function used is: L = 1/2n ||A coefficients - y||^2^ This support multiple types of regularization: - none (a.k.a. ordinary least squares) @@ -43,58 +50,53 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction >>> from pyspark.mllib.linalg import Vectors >>> df = sqlContext.createDataFrame([ - ... (1.0, Vectors.dense(1.0)), - ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) - >>> lr = LinearRegression(maxIter=5, regParam=0.0) + ... (1.0, 2.0, Vectors.dense(1.0)), + ... (0.0, 2.0, Vectors.sparse(1, [], []))], ["label", "weight", "features"]) + >>> lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal", weightCol="weight") >>> model = lr.fit(df) >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) - >>> model.transform(test0).head().prediction - -1.0 - >>> model.weights - DenseVector([1.0]) - >>> model.intercept - 0.0 + >>> abs(model.transform(test0).head().prediction - (-1.0)) < 0.001 + True + >>> abs(model.coefficients[0] - 1.0) < 0.001 + True + >>> abs(model.intercept - 0.0) < 0.001 + True >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) - >>> model.transform(test1).head().prediction - 1.0 + >>> abs(model.transform(test1).head().prediction - 1.0) < 0.001 + True >>> lr.setParams("vector") Traceback (most recent call last): ... TypeError: Method setParams forces keyword arguments. - """ - # a placeholder to make it appear in the generated doc - elasticNetParam = \ - Param(Params._dummy(), "elasticNetParam", - "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + - "the penalty is an L2 penalty. 
For alpha = 1, it is an L1 penalty.") + .. versionadded:: 1.4.0 + """ @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", - maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6): + maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, + standardization=True, solver="auto", weightCol=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ - maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6) + maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ + standardization=True, solver="auto", weightCol=None) """ super(LinearRegression, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.LinearRegression", self.uid) - #: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty - # is an L2 penalty. For alpha = 1, it is an L1 penalty. - self.elasticNetParam = \ - Param(self, "elasticNetParam", - "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty " + - "is an L2 penalty. For alpha = 1, it is an L1 penalty.") - self._setDefault(maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6) + self._setDefault(maxIter=100, regParam=0.0, tol=1e-6) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", - maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6): + maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, + standardization=True, solver="auto", weightCol=None): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ - maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6) + maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ + standardization=True, solver="auto", weightCol=None) Sets params for linear regression. """ kwargs = self.setParams._input_kwargs @@ -103,33 +105,34 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return LinearRegressionModel(java_model) - def setElasticNetParam(self, value): - """ - Sets the value of :py:attr:`elasticNetParam`. - """ - self._paramMap[self.elasticNetParam] = value - return self - - def getElasticNetParam(self): - """ - Gets the value of elasticNetParam or its default value. - """ - return self.getOrDefault(self.elasticNetParam) - class LinearRegressionModel(JavaModel): """ Model fitted by LinearRegression. + + .. versionadded:: 1.4.0 """ @property + @since("1.4.0") def weights(self): """ Model weights. """ + + warnings.warn("weights is deprecated. Use coefficients instead.") return self._call_java("weights") @property + @since("1.6.0") + def coefficients(self): + """ + Model coefficients. + """ + return self._call_java("coefficients") + + @property + @since("1.4.0") def intercept(self): """ Model intercept. @@ -137,21 +140,244 @@ def intercept(self): return self._call_java("intercept") -class TreeRegressorParams(object): +@inherit_doc +class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, + HasWeightCol): + """ + .. note:: Experimental + + Currently implemented using parallelized pool adjacent violators algorithm. + Only univariate (single feature) algorithm supported. + + >>> from pyspark.mllib.linalg import Vectors + >>> df = sqlContext.createDataFrame([ + ... (1.0, Vectors.dense(1.0)), + ... 
(0.0, Vectors.sparse(1, [], []))], ["label", "features"]) + >>> ir = IsotonicRegression() + >>> model = ir.fit(df) + >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) + >>> model.transform(test0).head().prediction + 0.0 + >>> model.boundaries + DenseVector([0.0, 1.0]) + """ + + # a placeholder to make it appear in the generated doc + isotonic = \ + Param(Params._dummy(), "isotonic", + "whether the output sequence should be isotonic/increasing (true) or" + + "antitonic/decreasing (false).") + featureIndex = \ + Param(Params._dummy(), "featureIndex", + "The index of the feature if featuresCol is a vector column, no effect otherwise.") + + @keyword_only + def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", + weightCol=None, isotonic=True, featureIndex=0): + """ + __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + weightCol=None, isotonic=True, featureIndex=0): + """ + super(IsotonicRegression, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.regression.IsotonicRegression", self.uid) + self.isotonic = \ + Param(self, "isotonic", + "whether the output sequence should be isotonic/increasing (true) or" + + "antitonic/decreasing (false).") + self.featureIndex = \ + Param(self, "featureIndex", + "The index of the feature if featuresCol is a vector column, no effect " + + "otherwise.") + self._setDefault(isotonic=True, featureIndex=0) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", + weightCol=None, isotonic=True, featureIndex=0): + """ + setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + weightCol=None, isotonic=True, featureIndex=0): + Set the params for IsotonicRegression. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def _create_model(self, java_model): + return IsotonicRegressionModel(java_model) + + def setIsotonic(self, value): + """ + Sets the value of :py:attr:`isotonic`. + """ + self._paramMap[self.isotonic] = value + return self + + def getIsotonic(self): + """ + Gets the value of isotonic or its default value. + """ + return self.getOrDefault(self.isotonic) + + def setFeatureIndex(self, value): + """ + Sets the value of :py:attr:`featureIndex`. + """ + self._paramMap[self.featureIndex] = value + return self + + def getFeatureIndex(self): + """ + Gets the value of featureIndex or its default value. + """ + return self.getOrDefault(self.featureIndex) + + +class IsotonicRegressionModel(JavaModel): + """ + .. note:: Experimental + + Model fitted by IsotonicRegression. + """ + + @property + def boundaries(self): + """ + Model boundaries. + """ + return self._call_java("boundaries") + + @property + def predictions(self): + """ + Predictions associated with the boundaries at the same index, monotone because of isotonic + regression. + """ + return self._call_java("predictions") + + +class TreeEnsembleParams(DecisionTreeParams): + """ + Mixin for Decision Tree-based ensemble algorithms parameters. + """ + + # a placeholder to make it appear in the generated doc + subsamplingRate = Param(Params._dummy(), "subsamplingRate", "Fraction of the training data " + + "used for learning each decision tree, in range (0, 1].") + + def __init__(self): + super(TreeEnsembleParams, self).__init__() + #: param for Fraction of the training data, in range (0, 1]. 
+ self.subsamplingRate = Param(self, "subsamplingRate", "Fraction of the training data " + + "used for learning each decision tree, in range (0, 1].") + + @since("1.4.0") + def setSubsamplingRate(self, value): + """ + Sets the value of :py:attr:`subsamplingRate`. + """ + self._paramMap[self.subsamplingRate] = value + return self + + @since("1.4.0") + def getSubsamplingRate(self): + """ + Gets the value of subsamplingRate or its default value. + """ + return self.getOrDefault(self.subsamplingRate) + + +class TreeRegressorParams(Params): """ Private class to track supported impurity measures. """ + supportedImpurities = ["variance"] + # a placeholder to make it appear in the generated doc + impurity = Param(Params._dummy(), "impurity", + "Criterion used for information gain calculation (case-insensitive). " + + "Supported options: " + + ", ".join(supportedImpurities)) + + def __init__(self): + super(TreeRegressorParams, self).__init__() + #: param for Criterion used for information gain calculation (case-insensitive). + self.impurity = Param(self, "impurity", "Criterion used for information " + + "gain calculation (case-insensitive). Supported options: " + + ", ".join(self.supportedImpurities)) + + @since("1.4.0") + def setImpurity(self, value): + """ + Sets the value of :py:attr:`impurity`. + """ + self._paramMap[self.impurity] = value + return self + + @since("1.4.0") + def getImpurity(self): + """ + Gets the value of impurity or its default value. + """ + return self.getOrDefault(self.impurity) -class RandomForestParams(object): +class RandomForestParams(TreeEnsembleParams): """ Private class to track supported random forest parameters. """ + supportedFeatureSubsetStrategies = ["auto", "all", "onethird", "sqrt", "log2"] + # a placeholder to make it appear in the generated doc + numTrees = Param(Params._dummy(), "numTrees", "Number of trees to train (>= 1).") + featureSubsetStrategy = \ + Param(Params._dummy(), "featureSubsetStrategy", + "The number of features to consider for splits at each tree node. Supported " + + "options: " + ", ".join(supportedFeatureSubsetStrategies)) + + def __init__(self): + super(RandomForestParams, self).__init__() + #: param for Number of trees to train (>= 1). + self.numTrees = Param(self, "numTrees", "Number of trees to train (>= 1).") + #: param for The number of features to consider for splits at each tree node. + self.featureSubsetStrategy = \ + Param(self, "featureSubsetStrategy", + "The number of features to consider for splits at each tree node. Supported " + + "options: " + ", ".join(self.supportedFeatureSubsetStrategies)) + + @since("1.4.0") + def setNumTrees(self, value): + """ + Sets the value of :py:attr:`numTrees`. + """ + self._paramMap[self.numTrees] = value + return self + + @since("1.4.0") + def getNumTrees(self): + """ + Gets the value of numTrees or its default value. + """ + return self.getOrDefault(self.numTrees) + @since("1.4.0") + def setFeatureSubsetStrategy(self, value): + """ + Sets the value of :py:attr:`featureSubsetStrategy`. + """ + self._paramMap[self.featureSubsetStrategy] = value + return self -class GBTParams(object): + @since("1.4.0") + def getFeatureSubsetStrategy(self): + """ + Gets the value of featureSubsetStrategy or its default value. + """ + return self.getOrDefault(self.featureSubsetStrategy) + + +class GBTParams(TreeEnsembleParams): """ Private class to track supported GBT params. 
""" @@ -160,7 +386,7 @@ class GBTParams(object): @inherit_doc class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, - DecisionTreeParams, HasCheckpointInterval): + DecisionTreeParams, TreeRegressorParams, HasCheckpointInterval): """ `http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree` learning algorithm for regression. @@ -182,12 +408,9 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 - """ - # a placeholder to make it appear in the generated doc - impurity = Param(Params._dummy(), "impurity", - "Criterion used for information gain calculation (case-insensitive). " + - "Supported options: " + ", ".join(TreeRegressorParams.supportedImpurities)) + .. versionadded:: 1.4.0 + """ @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", @@ -201,11 +424,6 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred super(DecisionTreeRegressor, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.DecisionTreeRegressor", self.uid) - #: param for Criterion used for information gain calculation (case-insensitive). - self.impurity = \ - Param(self, "impurity", - "Criterion used for information gain calculation (case-insensitive). " + - "Supported options: " + ", ".join(TreeRegressorParams.supportedImpurities)) self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance") @@ -213,6 +431,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, @@ -229,29 +448,22 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return DecisionTreeRegressionModel(java_model) - def setImpurity(self, value): - """ - Sets the value of :py:attr:`impurity`. - """ - self._paramMap[self.impurity] = value - return self - - def getImpurity(self): - """ - Gets the value of impurity or its default value. - """ - return self.getOrDefault(self.impurity) - @inherit_doc class DecisionTreeModel(JavaModel): + """Abstraction for Decision Tree models. + + .. versionadded:: 1.5.0 + """ @property + @since("1.5.0") def numNodes(self): """Return number of nodes of the decision tree.""" return self._call_java("numNodes") @property + @since("1.5.0") def depth(self): """Return depth of the decision tree.""" return self._call_java("depth") @@ -262,8 +474,13 @@ def __repr__(self): @inherit_doc class TreeEnsembleModels(JavaModel): + """Represents a tree ensemble model. + + .. versionadded:: 1.5.0 + """ @property + @since("1.5.0") def treeWeights(self): """Return the weights for each tree""" return list(self._call_java("javaTreeWeights")) @@ -276,12 +493,14 @@ def __repr__(self): class DecisionTreeRegressionModel(DecisionTreeModel): """ Model fitted by DecisionTreeRegressor. + + .. 
versionadded:: 1.4.0 """ @inherit_doc class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed, - DecisionTreeParams, HasCheckpointInterval): + RandomForestParams, TreeRegressorParams, HasCheckpointInterval): """ `http://en.wikipedia.org/wiki/Random_forest Random Forest` learning algorithm for regression. @@ -302,69 +521,46 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 0.5 - """ - # a placeholder to make it appear in the generated doc - impurity = Param(Params._dummy(), "impurity", - "Criterion used for information gain calculation (case-insensitive). " + - "Supported options: " + ", ".join(TreeRegressorParams.supportedImpurities)) - subsamplingRate = Param(Params._dummy(), "subsamplingRate", - "Fraction of the training data used for learning each decision tree, " + - "in range (0, 1].") - numTrees = Param(Params._dummy(), "numTrees", "Number of trees to train (>= 1)") - featureSubsetStrategy = \ - Param(Params._dummy(), "featureSubsetStrategy", - "The number of features to consider for splits at each tree node. Supported " + - "options: " + ", ".join(RandomForestParams.supportedFeatureSubsetStrategies)) + .. versionadded:: 1.4.0 + """ @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance", - numTrees=20, featureSubsetStrategy="auto", seed=None): + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, + impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, + featureSubsetStrategy="auto"): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ - impurity="variance", numTrees=20, \ - featureSubsetStrategy="auto", seed=None) + impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, \ + featureSubsetStrategy="auto") """ super(RandomForestRegressor, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.RandomForestRegressor", self.uid) - #: param for Criterion used for information gain calculation (case-insensitive). - self.impurity = \ - Param(self, "impurity", - "Criterion used for information gain calculation (case-insensitive). " + - "Supported options: " + ", ".join(TreeRegressorParams.supportedImpurities)) - #: param for Fraction of the training data used for learning each decision tree, - # in range (0, 1] - self.subsamplingRate = Param(self, "subsamplingRate", - "Fraction of the training data used for learning each " + - "decision tree, in range (0, 1].") - #: param for Number of trees to train (>= 1) - self.numTrees = Param(self, "numTrees", "Number of trees to train (>= 1)") - #: param for The number of features to consider for splits at each tree node - self.featureSubsetStrategy = \ - Param(self, "featureSubsetStrategy", - "The number of features to consider for splits at each tree node. 
Supported " + - "options: " + ", ".join(RandomForestParams.supportedFeatureSubsetStrategies)) self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, - impurity="variance", numTrees=20, featureSubsetStrategy="auto") + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, + impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, + featureSubsetStrategy="auto") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, - impurity="variance", numTrees=20, featureSubsetStrategy="auto"): + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, + impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, + featureSubsetStrategy="auto"): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, \ - impurity="variance", numTrees=20, featureSubsetStrategy="auto") + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ + impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, \ + featureSubsetStrategy="auto") Sets params for linear regression. """ kwargs = self.setParams._input_kwargs @@ -373,68 +569,18 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return RandomForestRegressionModel(java_model) - def setImpurity(self, value): - """ - Sets the value of :py:attr:`impurity`. - """ - self._paramMap[self.impurity] = value - return self - - def getImpurity(self): - """ - Gets the value of impurity or its default value. - """ - return self.getOrDefault(self.impurity) - - def setSubsamplingRate(self, value): - """ - Sets the value of :py:attr:`subsamplingRate`. - """ - self._paramMap[self.subsamplingRate] = value - return self - - def getSubsamplingRate(self): - """ - Gets the value of subsamplingRate or its default value. - """ - return self.getOrDefault(self.subsamplingRate) - - def setNumTrees(self, value): - """ - Sets the value of :py:attr:`numTrees`. - """ - self._paramMap[self.numTrees] = value - return self - - def getNumTrees(self): - """ - Gets the value of numTrees or its default value. - """ - return self.getOrDefault(self.numTrees) - - def setFeatureSubsetStrategy(self, value): - """ - Sets the value of :py:attr:`featureSubsetStrategy`. - """ - self._paramMap[self.featureSubsetStrategy] = value - return self - - def getFeatureSubsetStrategy(self): - """ - Gets the value of featureSubsetStrategy or its default value. - """ - return self.getOrDefault(self.featureSubsetStrategy) - class RandomForestRegressionModel(TreeEnsembleModels): """ Model fitted by RandomForestRegressor. + + .. versionadded:: 1.4.0 """ @inherit_doc class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, - DecisionTreeParams, HasCheckpointInterval): + GBTParams, HasCheckpointInterval, HasStepSize, HasSeed): """ `http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)` learning algorithm for regression. 
@@ -455,29 +601,25 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc lossType = Param(Params._dummy(), "lossType", "Loss function which GBT tries to minimize (case-insensitive). " + "Supported options: " + ", ".join(GBTParams.supportedLossTypes)) - subsamplingRate = Param(Params._dummy(), "subsamplingRate", - "Fraction of the training data used for learning each decision tree, " + - "in range (0, 1].") - stepSize = Param(Params._dummy(), "stepSize", - "Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the " + - "contribution of each estimator") @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="squared", - maxIter=20, stepSize=0.1): + maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, + checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ - lossType="squared", maxIter=20, stepSize=0.1) + maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \ + checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1) """ super(GBTRegressor, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.GBTRegressor", self.uid) @@ -485,31 +627,23 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.lossType = Param(self, "lossType", "Loss function which GBT tries to minimize (case-insensitive). " + "Supported options: " + ", ".join(GBTParams.supportedLossTypes)) - #: Fraction of the training data used for learning each decision tree, in range (0, 1]. - self.subsamplingRate = Param(self, "subsamplingRate", - "Fraction of the training data used for learning each " + - "decision tree, in range (0, 1].") - #: Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of - # each estimator - self.stepSize = Param(self, "stepSize", - "Step size (a.k.a. 
learning rate) in interval (0, 1] for shrinking " + - "the contribution of each estimator") self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, - lossType="squared", maxIter=20, stepSize=0.1) + maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, + checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, - lossType="squared", maxIter=20, stepSize=0.1): + maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, + checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ - lossType="squared", maxIter=20, stepSize=0.1) + maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \ + checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1) Sets params for Gradient Boosted Tree Regression. """ kwargs = self.setParams._input_kwargs @@ -518,6 +652,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return GBTRegressionModel(java_model) + @since("1.4.0") def setLossType(self, value): """ Sets the value of :py:attr:`lossType`. @@ -525,44 +660,206 @@ def setLossType(self, value): self._paramMap[self.lossType] = value return self + @since("1.4.0") def getLossType(self): """ Gets the value of lossType or its default value. """ return self.getOrDefault(self.lossType) - def setSubsamplingRate(self, value): + +class GBTRegressionModel(TreeEnsembleModels): + """ + Model fitted by GBTRegressor. + + .. versionadded:: 1.4.0 + """ + + +@inherit_doc +class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, + HasFitIntercept, HasMaxIter, HasTol): + """ + Accelerated Failure Time (AFT) Model Survival Regression + + Fit a parametric AFT survival regression model based on the Weibull distribution + of the survival time. + + .. seealso:: `AFT Model `_ + + >>> from pyspark.mllib.linalg import Vectors + >>> df = sqlContext.createDataFrame([ + ... (1.0, Vectors.dense(1.0), 1.0), + ... (0.0, Vectors.sparse(1, [], []), 0.0)], ["label", "features", "censor"]) + >>> aftsr = AFTSurvivalRegression() + >>> model = aftsr.fit(df) + >>> model.predict(Vectors.dense(6.3)) + 1.0 + >>> model.predictQuantiles(Vectors.dense(6.3)) + DenseVector([0.0101, 0.0513, 0.1054, 0.2877, 0.6931, 1.3863, 2.3026, 2.9957, 4.6052]) + >>> model.transform(df).show() + +-----+---------+------+----------+ + |label| features|censor|prediction| + +-----+---------+------+----------+ + | 1.0| [1.0]| 1.0| 1.0| + | 0.0|(1,[],[])| 0.0| 1.0| + +-----+---------+------+----------+ + ... + + .. versionadded:: 1.6.0 + """ + + # a placeholder to make it appear in the generated doc + censorCol = Param(Params._dummy(), "censorCol", + "censor column name. The value of this column could be 0 or 1. " + + "If the value is 1, it means the event has occurred i.e. 
" + + "uncensored; otherwise censored.") + quantileProbabilities = \ + Param(Params._dummy(), "quantileProbabilities", + "quantile probabilities array. Values of the quantile probabilities array " + + "should be in the range (0, 1) and the array should be non-empty.") + quantilesCol = Param(Params._dummy(), "quantilesCol", + "quantiles column name. This column will output quantiles of " + + "corresponding quantileProbabilities if it is set.") + + @keyword_only + def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", + fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", + quantileProbabilities=None, quantilesCol=None): """ - Sets the value of :py:attr:`subsamplingRate`. + __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \ + quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \ + quantilesCol=None) """ - self._paramMap[self.subsamplingRate] = value + super(AFTSurvivalRegression, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.regression.AFTSurvivalRegression", self.uid) + #: Param for censor column name + self.censorCol = Param(self, "censorCol", + "censor column name. The value of this column could be 0 or 1. " + + "If the value is 1, it means the event has occurred i.e. " + + "uncensored; otherwise censored.") + #: Param for quantile probabilities array + self.quantileProbabilities = \ + Param(self, "quantileProbabilities", + "quantile probabilities array. Values of the quantile probabilities array " + + "should be in the range (0, 1) and the array should be non-empty.") + #: Param for quantiles column name + self.quantilesCol = Param(self, "quantilesCol", + "quantiles column name. This column will output quantiles of " + + "corresponding quantileProbabilities if it is set.") + self._setDefault(censorCol="censor", + quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("1.6.0") + def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", + fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", + quantileProbabilities=None, quantilesCol=None): + """ + setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \ + quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \ + quantilesCol=None): + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def _create_model(self, java_model): + return AFTSurvivalRegressionModel(java_model) + + @since("1.6.0") + def setCensorCol(self, value): + """ + Sets the value of :py:attr:`censorCol`. + """ + self._paramMap[self.censorCol] = value return self - def getSubsamplingRate(self): + @since("1.6.0") + def getCensorCol(self): """ - Gets the value of subsamplingRate or its default value. + Gets the value of censorCol or its default value. """ - return self.getOrDefault(self.subsamplingRate) + return self.getOrDefault(self.censorCol) - def setStepSize(self, value): + @since("1.6.0") + def setQuantileProbabilities(self, value): """ - Sets the value of :py:attr:`stepSize`. + Sets the value of :py:attr:`quantileProbabilities`. 
""" - self._paramMap[self.stepSize] = value + self._paramMap[self.quantileProbabilities] = value return self - def getStepSize(self): + @since("1.6.0") + def getQuantileProbabilities(self): """ - Gets the value of stepSize or its default value. + Gets the value of quantileProbabilities or its default value. """ - return self.getOrDefault(self.stepSize) + return self.getOrDefault(self.quantileProbabilities) + @since("1.6.0") + def setQuantilesCol(self, value): + """ + Sets the value of :py:attr:`quantilesCol`. + """ + self._paramMap[self.quantilesCol] = value + return self -class GBTRegressionModel(TreeEnsembleModels): + @since("1.6.0") + def getQuantilesCol(self): + """ + Gets the value of quantilesCol or its default value. + """ + return self.getOrDefault(self.quantilesCol) + + +class AFTSurvivalRegressionModel(JavaModel): """ - Model fitted by GBTRegressor. + Model fitted by AFTSurvivalRegression. + + .. versionadded:: 1.6.0 """ + @property + @since("1.6.0") + def coefficients(self): + """ + Model coefficients. + """ + return self._call_java("coefficients") + + @property + @since("1.6.0") + def intercept(self): + """ + Model intercept. + """ + return self._call_java("intercept") + + @property + @since("1.6.0") + def scale(self): + """ + Model scale paramter. + """ + return self._call_java("scale") + + def predictQuantiles(self, features): + """ + Predicted Quantiles + """ + return self._call_java("predictQuantiles", features) + + def predict(self, features): + """ + Predicted value + """ + return self._call_java("predict", features) + if __name__ == "__main__": import doctest diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index c151d21fd661..7a16cf52cccb 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -20,6 +20,10 @@ """ import sys +try: + import xmlrunner +except ImportError: + xmlrunner = None if sys.version_info[:2] <= (2, 6): try: @@ -31,12 +35,15 @@ import unittest from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase -from pyspark.sql import DataFrame, SQLContext +from pyspark.sql import DataFrame, SQLContext, Row +from pyspark.sql.functions import rand +from pyspark.ml.evaluation import RegressionEvaluator from pyspark.ml.param import Param, Params from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed from pyspark.ml.util import keyword_only from pyspark.ml import Estimator, Model, Pipeline, Transformer from pyspark.ml.feature import * +from pyspark.ml.tuning import ParamGridBuilder, CrossValidator, CrossValidatorModel from pyspark.mllib.linalg import DenseVector @@ -160,7 +167,7 @@ def test_param(self): testParams = TestParams() maxIter = testParams.maxIter self.assertEqual(maxIter.name, "maxIter") - self.assertEqual(maxIter.doc, "max number of iterations (>= 0)") + self.assertEqual(maxIter.doc, "max number of iterations (>= 0).") self.assertTrue(maxIter.parent == testParams.uid) def test_params(self): @@ -179,7 +186,7 @@ def test_params(self): self.assertEqual(testParams.getMaxIter(), 10) testParams.setMaxIter(100) self.assertTrue(testParams.isSet(maxIter)) - self.assertEquals(testParams.getMaxIter(), 100) + self.assertEqual(testParams.getMaxIter(), 100) self.assertTrue(testParams.hasParam(inputCol)) self.assertFalse(testParams.hasDefault(inputCol)) @@ -192,11 +199,11 @@ def test_params(self): testParams._setDefault(seed=41) testParams.setSeed(43) - self.assertEquals( + self.assertEqual( testParams.explainParams(), - "\n".join(["inputCol: input column name (undefined)", - "maxIter: max number of 
iterations (>= 0) (default: 10, current: 100)", - "seed: random seed (default: 41, current: 43)"])) + "\n".join(["inputCol: input column name. (undefined)", + "maxIter: max number of iterations (>= 0). (default: 10, current: 100)", + "seed: random seed. (default: 41, current: 43)"])) def test_hasseed(self): noSeedSpecd = TestParams() @@ -255,14 +262,117 @@ def test_idf(self): def test_ngram(self): sqlContext = SQLContext(self.sc) dataset = sqlContext.createDataFrame([ - ([["a", "b", "c", "d", "e"]])], ["input"]) + Row(input=["a", "b", "c", "d", "e"])]) ngram0 = NGram(n=4, inputCol="input", outputCol="output") self.assertEqual(ngram0.getN(), 4) self.assertEqual(ngram0.getInputCol(), "input") self.assertEqual(ngram0.getOutputCol(), "output") transformedDF = ngram0.transform(dataset) - self.assertEquals(transformedDF.head().output, ["a b c d", "b c d e"]) + self.assertEqual(transformedDF.head().output, ["a b c d", "b c d e"]) + + def test_stopwordsremover(self): + sqlContext = SQLContext(self.sc) + dataset = sqlContext.createDataFrame([Row(input=["a", "panda"])]) + stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output") + # Default + self.assertEqual(stopWordRemover.getInputCol(), "input") + transformedDF = stopWordRemover.transform(dataset) + self.assertEqual(transformedDF.head().output, ["panda"]) + # Custom + stopwords = ["panda"] + stopWordRemover.setStopWords(stopwords) + self.assertEqual(stopWordRemover.getInputCol(), "input") + self.assertEqual(stopWordRemover.getStopWords(), stopwords) + transformedDF = stopWordRemover.transform(dataset) + self.assertEqual(transformedDF.head().output, ["a"]) + + +class HasInducedError(Params): + + def __init__(self): + super(HasInducedError, self).__init__() + self.inducedError = Param(self, "inducedError", + "Uniformly-distributed error added to feature") + + def getInducedError(self): + return self.getOrDefault(self.inducedError) + + +class InducedErrorModel(Model, HasInducedError): + + def __init__(self): + super(InducedErrorModel, self).__init__() + + def _transform(self, dataset): + return dataset.withColumn("prediction", + dataset.feature + (rand(0) * self.getInducedError())) + + +class InducedErrorEstimator(Estimator, HasInducedError): + + def __init__(self, inducedError=1.0): + super(InducedErrorEstimator, self).__init__() + self._set(inducedError=inducedError) + + def _fit(self, dataset): + model = InducedErrorModel() + self._copyValues(model) + return model + + +class CrossValidatorTests(PySparkTestCase): + + def test_fit_minimize_metric(self): + sqlContext = SQLContext(self.sc) + dataset = sqlContext.createDataFrame([ + (10, 10.0), + (50, 50.0), + (100, 100.0), + (500, 500.0)] * 10, + ["feature", "label"]) + + iee = InducedErrorEstimator() + evaluator = RegressionEvaluator(metricName="rmse") + + grid = (ParamGridBuilder() + .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) + .build()) + cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) + cvModel = cv.fit(dataset) + bestModel = cvModel.bestModel + bestModelMetric = evaluator.evaluate(bestModel.transform(dataset)) + + self.assertEqual(0.0, bestModel.getOrDefault('inducedError'), + "Best model should have zero induced error") + self.assertEqual(0.0, bestModelMetric, "Best model has RMSE of 0") + + def test_fit_maximize_metric(self): + sqlContext = SQLContext(self.sc) + dataset = sqlContext.createDataFrame([ + (10, 10.0), + (50, 50.0), + (100, 100.0), + (500, 500.0)] * 10, + ["feature", "label"]) + + iee = InducedErrorEstimator() + evaluator = 
RegressionEvaluator(metricName="r2") + + grid = (ParamGridBuilder() + .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) + .build()) + cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) + cvModel = cv.fit(dataset) + bestModel = cvModel.bestModel + bestModelMetric = evaluator.evaluate(bestModel.transform(dataset)) + + self.assertEqual(0.0, bestModel.getOrDefault('inducedError'), + "Best model should have zero induced error") + self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1") if __name__ == "__main__": - unittest.main() + if xmlrunner: + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) + else: + unittest.main() diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 0bf988fd72f1..08f8db57f440 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -18,8 +18,10 @@ import itertools import numpy as np -from pyspark.ml.param import Params, Param +from pyspark import since from pyspark.ml import Estimator, Model +from pyspark.ml.param import Params, Param +from pyspark.ml.param.shared import HasSeed from pyspark.ml.util import keyword_only from pyspark.sql.functions import rand @@ -47,11 +49,14 @@ class ParamGridBuilder(object): True >>> all([m in expected for m in output]) True + + .. versionadded:: 1.4.0 """ def __init__(self): self._param_grid = {} + @since("1.4.0") def addGrid(self, param, values): """ Sets the given parameters in this grid to fixed values. @@ -60,6 +65,7 @@ def addGrid(self, param, values): return self + @since("1.4.0") def baseOn(self, *args): """ Sets the given parameters in this grid to fixed values. @@ -73,6 +79,7 @@ def baseOn(self, *args): return self + @since("1.4.0") def build(self): """ Builds and returns all combinations of parameters specified @@ -83,7 +90,7 @@ def build(self): return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)] -class CrossValidator(Estimator): +class CrossValidator(Estimator, HasSeed): """ K-fold cross validation. @@ -104,6 +111,8 @@ class CrossValidator(Estimator): >>> cvModel = cv.fit(dataset) >>> evaluator.evaluate(cvModel.transform(dataset)) 0.8333... + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -121,9 +130,11 @@ class CrossValidator(Estimator): numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation") @keyword_only - def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3): + def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, + seed=None): """ - __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3) + __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\ + seed=None) """ super(CrossValidator, self).__init__() #: param for estimator to be cross-validated @@ -142,14 +153,18 @@ def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numF self._set(**kwargs) @keyword_only - def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3): + @since("1.4.0") + def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, + seed=None): """ - setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3): + setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\ + seed=None): Sets params for cross validator. 
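Since CrossValidator now mixes in HasSeed and threads the seed into the fold split, a small illustrative sketch of requesting reproducible folds; the estimator choice and the `dataset` DataFrame (with "features"/"label" columns) are assumptions for the example.

# Hypothetical sketch; assumes `sqlContext` and a DataFrame `dataset`.
from pyspark.ml.regression import LinearRegression
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator

lr = LinearRegression()
grid = ParamGridBuilder().addGrid(lr.regParam, [0.01, 0.1]).build()
evaluator = RegressionEvaluator(metricName="rmse")
cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,
                    numFolds=3, seed=42)   # seed now controls the random fold assignment
cvModel = cv.fit(dataset)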
""" kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.4.0") def setEstimator(self, value): """ Sets the value of :py:attr:`estimator`. @@ -157,12 +172,14 @@ def setEstimator(self, value): self._paramMap[self.estimator] = value return self + @since("1.4.0") def getEstimator(self): """ Gets the value of estimator or its default value. """ return self.getOrDefault(self.estimator) + @since("1.4.0") def setEstimatorParamMaps(self, value): """ Sets the value of :py:attr:`estimatorParamMaps`. @@ -170,12 +187,14 @@ def setEstimatorParamMaps(self, value): self._paramMap[self.estimatorParamMaps] = value return self + @since("1.4.0") def getEstimatorParamMaps(self): """ Gets the value of estimatorParamMaps or its default value. """ return self.getOrDefault(self.estimatorParamMaps) + @since("1.4.0") def setEvaluator(self, value): """ Sets the value of :py:attr:`evaluator`. @@ -183,12 +202,14 @@ def setEvaluator(self, value): self._paramMap[self.evaluator] = value return self + @since("1.4.0") def getEvaluator(self): """ Gets the value of evaluator or its default value. """ return self.getOrDefault(self.evaluator) + @since("1.4.0") def setNumFolds(self, value): """ Sets the value of :py:attr:`numFolds`. @@ -196,6 +217,7 @@ def setNumFolds(self, value): self._paramMap[self.numFolds] = value return self + @since("1.4.0") def getNumFolds(self): """ Gets the value of numFolds or its default value. @@ -208,9 +230,10 @@ def _fit(self, dataset): numModels = len(epm) eva = self.getOrDefault(self.evaluator) nFolds = self.getOrDefault(self.numFolds) + seed = self.getOrDefault(self.seed) h = 1.0 / nFolds randCol = self.uid + "_rand" - df = dataset.select("*", rand(0).alias(randCol)) + df = dataset.select("*", rand(seed).alias(randCol)) metrics = np.zeros(numModels) for i in range(nFolds): validateLB = i * h @@ -223,11 +246,26 @@ def _fit(self, dataset): # TODO: duplicate evaluator to take extra params from input metric = eva.evaluate(model.transform(validation, epm[j])) metrics[j] += metric - bestIndex = np.argmax(metrics) + + if eva.isLargerBetter(): + bestIndex = np.argmax(metrics) + else: + bestIndex = np.argmin(metrics) bestModel = est.fit(dataset, epm[bestIndex]) return CrossValidatorModel(bestModel) - def copy(self, extra={}): + @since("1.4.0") + def copy(self, extra=None): + """ + Creates a copy of this instance with a randomly generated uid + and some extra params. This copies creates a deep copy of + the embedded paramMap, and copies the embedded and extra parameters over. + + :param extra: Extra parameters to copy to the new instance + :return: Copy of this instance + """ + if extra is None: + extra = dict() newCV = Params.copy(self, extra) if self.isSet(self.estimator): newCV.setEstimator(self.getEstimator().copy(extra)) @@ -240,6 +278,8 @@ def copy(self, extra={}): class CrossValidatorModel(Model): """ Model from k-fold cross validation. + + .. versionadded:: 1.4.0 """ def __init__(self, bestModel): @@ -250,15 +290,19 @@ def __init__(self, bestModel): def _transform(self, dataset): return self.bestModel.transform(dataset) - def copy(self, extra={}): + @since("1.4.0") + def copy(self, extra=None): """ Creates a copy of this instance with a randomly generated uid and some extra params. This copies the underlying bestModel, creates a deep copy of the embedded paramMap, and copies the embedded and extra parameters over. 
+ :param extra: Extra parameters to copy to the new instance :return: Copy of this instance """ + if extra is None: + extra = dict() return CrossValidatorModel(self.bestModel.copy(extra)) diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 253705bde913..4bcb4aaec89d 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -119,6 +119,7 @@ def _create_model(self, java_model): def _fit_java(self, dataset): """ Fits a Java model to the input dataset. + :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame` :param params: additional params (overwriting embedded values) @@ -136,7 +137,8 @@ def _fit(self, dataset): class JavaTransformer(Transformer, JavaWrapper): """ Base class for :py:class:`Transformer`s that wrap Java/Scala - implementations. + implementations. Subclasses should ensure they have the transformer Java object + available as _java_obj. """ __metaclass__ = ABCMeta @@ -172,6 +174,7 @@ def copy(self, extra=None): extra params. This implementation first calls Params.copy and then make a copy of the companion Java model with extra params. So both the Python wrapper and the Java model get copied. + :param extra: Extra parameters to copy to the new instance :return: Copy of this instance """ diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 8f27c446a66e..9e6f17ef6e94 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -20,7 +20,7 @@ import numpy from numpy import array -from pyspark import RDD +from pyspark import RDD, since from pyspark.streaming import DStream from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py from pyspark.mllib.linalg import DenseVector, SparseVector, _convert_to_vector @@ -44,6 +44,7 @@ def __init__(self, weights, intercept): super(LinearClassificationModel, self).__init__(weights, intercept) self._threshold = None + @since('1.4.0') def setThreshold(self, value): """ .. note:: Experimental @@ -57,6 +58,7 @@ def setThreshold(self, value): self._threshold = value @property + @since('1.4.0') def threshold(self): """ .. note:: Experimental @@ -67,6 +69,7 @@ def threshold(self): """ return self._threshold + @since('1.4.0') def clearThreshold(self): """ .. note:: Experimental @@ -76,6 +79,7 @@ def clearThreshold(self): """ self._threshold = None + @since('1.4.0') def predict(self, test): """ Predict values for a single data point or an RDD of points @@ -157,6 +161,8 @@ class LogisticRegressionModel(LinearClassificationModel): 1 >>> mcm.predict([0.0, 0.0, 0.3]) 2 + + .. versionadded:: 0.9.0 """ def __init__(self, weights, intercept, numFeatures, numClasses): super(LogisticRegressionModel, self).__init__(weights, intercept) @@ -172,13 +178,23 @@ def __init__(self, weights, intercept, numFeatures, numClasses): self._dataWithBiasSize) @property + @since('1.4.0') def numFeatures(self): + """ + Dimension of the features. + """ return self._numFeatures @property + @since('1.4.0') def numClasses(self): + """ + Number of possible outcomes for k classes classification problem in Multinomial + Logistic Regression. + """ return self._numClasses + @since('0.9.0') def predict(self, x): """ Predict values for a single data point or an RDD of points @@ -217,13 +233,21 @@ def predict(self, x): best_class = i + 1 return best_class + @since('1.4.0') def save(self, sc, path): + """ + Save this model to the given path. 
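A brief, hypothetical sketch of the threshold handling and save/load round trip documented above for the mllib classification models; it assumes `sc` and a writable scratch directory `path`.

# Hypothetical sketch; `path` is an assumed temporary directory.
from pyspark.mllib.classification import LogisticRegressionWithLBFGS, LogisticRegressionModel
from pyspark.mllib.regression import LabeledPoint

data = sc.parallelize([LabeledPoint(0.0, [0.0, 1.0]), LabeledPoint(1.0, [1.0, 0.0])])
model = LogisticRegressionWithLBFGS.train(data, iterations=10)
model.setThreshold(0.7)      # scores above 0.7 map to class 1
model.clearThreshold()       # predict() then returns the raw score instead of a label
model.save(sc, path)
sameModel = LogisticRegressionModel.load(sc, path)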
+ """ java_model = sc._jvm.org.apache.spark.mllib.classification.LogisticRegressionModel( _py2java(sc, self._coeff), self.intercept, self.numFeatures, self.numClasses) java_model.save(sc._jsc.sc(), path) @classmethod + @since('1.4.0') def load(cls, sc, path): + """ + Load a model from the given path. + """ java_model = sc._jvm.org.apache.spark.mllib.classification.LogisticRegressionModel.load( sc._jsc.sc(), path) weights = _java2py(sc, java_model.weights()) @@ -237,11 +261,14 @@ def load(cls, sc, path): class LogisticRegressionWithSGD(object): - + """ + .. versionadded:: 0.9.0 + """ @classmethod + @since('0.9.0') def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, initialWeights=None, regParam=0.01, regType="l2", intercept=False, - validateData=True): + validateData=True, convergenceTol=0.001): """ Train a logistic regression model on the given data. @@ -274,18 +301,23 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, :param validateData: Boolean parameter which indicates if the algorithm should validate data before training. (default: True) + :param convergenceTol: A condition which decides iteration termination. + (default: 0.001) """ def train(rdd, i): return callMLlibFunc("trainLogisticRegressionModelWithSGD", rdd, int(iterations), float(step), float(miniBatchFraction), i, float(regParam), regType, - bool(intercept), bool(validateData)) + bool(intercept), bool(validateData), float(convergenceTol)) return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights) class LogisticRegressionWithLBFGS(object): - + """ + .. versionadded:: 1.2.0 + """ @classmethod + @since('1.2.0') def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType="l2", intercept=False, corrections=10, tolerance=1e-4, validateData=True, numClasses=2): """ @@ -397,11 +429,14 @@ class SVMModel(LinearClassificationModel): ... rmtree(path) ... except: ... pass + + .. versionadded:: 0.9.0 """ def __init__(self, weights, intercept): super(SVMModel, self).__init__(weights, intercept) self._threshold = 0.0 + @since('0.9.0') def predict(self, x): """ Predict values for a single data point or an RDD of points @@ -417,13 +452,21 @@ def predict(self, x): else: return 1 if margin > self._threshold else 0 + @since('1.4.0') def save(self, sc, path): + """ + Save this model to the given path. + """ java_model = sc._jvm.org.apache.spark.mllib.classification.SVMModel( _py2java(sc, self._coeff), self.intercept) java_model.save(sc._jsc.sc(), path) @classmethod + @since('1.4.0') def load(cls, sc, path): + """ + Load a model from the given path. + """ java_model = sc._jvm.org.apache.spark.mllib.classification.SVMModel.load( sc._jsc.sc(), path) weights = _java2py(sc, java_model.weights()) @@ -435,11 +478,15 @@ def load(cls, sc, path): class SVMWithSGD(object): + """ + .. versionadded:: 0.9.0 + """ @classmethod + @since('0.9.0') def train(cls, data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None, regType="l2", - intercept=False, validateData=True): + intercept=False, validateData=True, convergenceTol=0.001): """ Train a support vector machine on the given data. @@ -472,11 +519,13 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, :param validateData: Boolean parameter which indicates if the algorithm should validate data before training. (default: True) + :param convergenceTol: A condition which decides iteration termination. 
+ (default: 0.001) """ def train(rdd, i): return callMLlibFunc("trainSVMModelWithSGD", rdd, int(iterations), float(step), float(regParam), float(miniBatchFraction), i, regType, - bool(intercept), bool(validateData)) + bool(intercept), bool(validateData), float(convergenceTol)) return _regression_train_wrapper(train, SVMModel, data, initialWeights) @@ -526,13 +575,15 @@ class NaiveBayesModel(Saveable, Loader): ... rmtree(path) ... except OSError: ... pass - """ + .. versionadded:: 0.9.0 + """ def __init__(self, labels, pi, theta): self.labels = labels self.pi = pi self.theta = theta + @since('0.9.0') def predict(self, x): """ Return the most likely class for a data vector @@ -544,6 +595,9 @@ def predict(self, x): return self.labels[numpy.argmax(self.pi + x.dot(self.theta.transpose()))] def save(self, sc, path): + """ + Save this model to the given path. + """ java_labels = _py2java(sc, self.labels.tolist()) java_pi = _py2java(sc, self.pi.tolist()) java_theta = _py2java(sc, self.theta.tolist()) @@ -552,7 +606,11 @@ def save(self, sc, path): java_model.save(sc._jsc.sc(), path) @classmethod + @since('1.4.0') def load(cls, sc, path): + """ + Load a model from the given path. + """ java_model = sc._jvm.org.apache.spark.mllib.classification.NaiveBayesModel.load( sc._jsc.sc(), path) # Can not unpickle array.array from Pyrolite in Python3 with "bytes" @@ -563,8 +621,12 @@ def load(cls, sc, path): class NaiveBayes(object): + """ + .. versionadded:: 0.9.0 + """ @classmethod + @since('0.9.0') def train(cls, data, lambda_=1.0): """ Train a Naive Bayes model given an RDD of (label, features) @@ -590,26 +652,45 @@ def train(cls, data, lambda_=1.0): @inherit_doc class StreamingLogisticRegressionWithSGD(StreamingLinearAlgorithm): """ - Run LogisticRegression with SGD on a batch of data. - - The weights obtained at the end of training a stream are used as initial - weights for the next batch. - - :param stepSize: Step size for each iteration of gradient descent. - :param numIterations: Number of iterations run for each batch of data. - :param miniBatchFraction: Fraction of data on which SGD is run for each - iteration. - :param regParam: L2 Regularization parameter. + Train or predict a logistic regression model on streaming data. Training uses + Stochastic Gradient Descent to update the model based on each new batch of + incoming data from a DStream. + + Each batch of data is assumed to be an RDD of LabeledPoints. + The number of data points per batch can vary, but the number + of features must be constant. An initial weight + vector must be provided. + + :param stepSize: + Step size for each iteration of gradient descent. + (default: 0.1) + :param numIterations: + Number of iterations run for each batch of data. + (default: 50) + :param miniBatchFraction: + Fraction of each batch of data to use for updates. + (default: 1.0) + :param regParam: + L2 Regularization parameter. + (default: 0.0) + :param convergenceTol: + Value used to determine when to terminate iterations. + (default: 0.001) + + .. 
versionadded:: 1.5.0 """ - def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0, regParam=0.01): + def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0, regParam=0.0, + convergenceTol=0.001): self.stepSize = stepSize self.numIterations = numIterations self.regParam = regParam self.miniBatchFraction = miniBatchFraction + self.convergenceTol = convergenceTol self._model = None super(StreamingLogisticRegressionWithSGD, self).__init__( model=self._model) + @since('1.5.0') def setInitialWeights(self, initialWeights): """ Set the initial value of weights. @@ -623,6 +704,7 @@ def setInitialWeights(self, initialWeights): initialWeights, 0, initialWeights.size, 2) return self + @since('1.5.0') def trainOn(self, dstream): """Train the model on the incoming dstream.""" self._validate(dstream) @@ -632,7 +714,8 @@ def update(rdd): if not rdd.isEmpty(): self._model = LogisticRegressionWithSGD.train( rdd, self.numIterations, self.stepSize, - self.miniBatchFraction, self._model.weights) + self.miniBatchFraction, self._model.weights, + regParam=self.regParam, convergenceTol=self.convergenceTol) dstream.foreachRDD(update) diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 900ade248c38..c9e6f1dec6bf 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -17,6 +17,7 @@ import sys import array as pyarray +import warnings if sys.version > '3': xrange = range @@ -28,7 +29,7 @@ from collections import namedtuple -from pyspark import SparkContext +from pyspark import SparkContext, since from pyspark.rdd import RDD, ignore_unicode_prefix from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, callJavaFunc, _py2java, _java2py from pyspark.mllib.linalg import SparseVector, _convert_to_vector, DenseVector @@ -90,21 +91,32 @@ class KMeansModel(Saveable, Loader): ... rmtree(path) ... except OSError: ... pass + + >>> data = array([-383.1,-382.9, 28.7,31.2, 366.2,367.3]).reshape(3, 2) + >>> model = KMeans.train(sc.parallelize(data), 3, maxIterations=0, + ... initialModel = KMeansModel([(-1000.0,-1000.0),(5.0,5.0),(1000.0,1000.0)])) + >>> model.clusterCenters + [array([-1000., -1000.]), array([ 5., 5.]), array([ 1000., 1000.])] + + .. versionadded:: 0.9.0 """ def __init__(self, centers): self.centers = centers @property + @since('1.0.0') def clusterCenters(self): """Get the cluster centers, represented as a list of NumPy arrays.""" return self.centers @property + @since('1.4.0') def k(self): """Total number of clusters.""" return len(self.centers) + @since('0.9.0') def predict(self, x): """Find the cluster to which x belongs in this model.""" best = 0 @@ -120,6 +132,7 @@ def predict(self, x): best_distance = distance return best + @since('1.4.0') def computeCost(self, rdd): """ Return the K-means cost (sum of squared distances of points to @@ -129,25 +142,47 @@ def computeCost(self, rdd): [_convert_to_vector(c) for c in self.centers]) return cost + @since('1.4.0') def save(self, sc, path): + """ + Save this model to the given path. + """ java_centers = _py2java(sc, [_convert_to_vector(c) for c in self.centers]) java_model = sc._jvm.org.apache.spark.mllib.clustering.KMeansModel(java_centers) java_model.save(sc._jsc.sc(), path) @classmethod + @since('1.4.0') def load(cls, sc, path): + """ + Load a model from the given path. 
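KMeans.train above now accepts an initialModel (and deprecates runs != 1); a small hypothetical sketch of seeding a run with explicit centers, assuming `sc`.

# Hypothetical sketch; an explicit KMeansModel replaces the random initialization step.
from numpy import array
from pyspark.mllib.clustering import KMeans, KMeansModel

points = sc.parallelize(array([0.0, 0.0, 1.0, 1.0, 9.0, 8.0, 8.0, 9.0]).reshape(4, 2))
init = KMeansModel([(0.5, 0.5), (8.5, 8.5)])
model = KMeans.train(points, 2, maxIterations=10, initialModel=init)
print(model.clusterCenters)
print(model.computeCost(points))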
+ """ java_model = sc._jvm.org.apache.spark.mllib.clustering.KMeansModel.load(sc._jsc.sc(), path) return KMeansModel(_java2py(sc, java_model.clusterCenters())) class KMeans(object): + """ + .. versionadded:: 0.9.0 + """ @classmethod + @since('0.9.0') def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||", - seed=None, initializationSteps=5, epsilon=1e-4): + seed=None, initializationSteps=5, epsilon=1e-4, initialModel=None): """Train a k-means clustering model.""" + if runs != 1: + warnings.warn( + "Support for runs is deprecated in 1.6.0. This param will have no effect in 1.7.0.") + clusterInitialModel = [] + if initialModel is not None: + if not isinstance(initialModel, KMeansModel): + raise Exception("initialModel is of "+str(type(initialModel))+". It needs " + "to be of ") + clusterInitialModel = [_convert_to_vector(c) for c in initialModel.clusterCenters] model = callMLlibFunc("trainKMeansModel", rdd.map(_convert_to_vector), k, maxIterations, - runs, initializationMode, seed, initializationSteps, epsilon) + runs, initializationMode, seed, initializationSteps, epsilon, + clusterInitialModel) centers = callJavaFunc(rdd.context, model.clusterCenters) return KMeansModel([c.toArray() for c in centers]) @@ -205,13 +240,16 @@ class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader): >>> model = GaussianMixture.train(clusterdata_2, 2, convergenceTol=0.0001, ... maxIterations=150, seed=10) >>> labels = model.predict(clusterdata_2).collect() - >>> labels[0]==labels[1]==labels[2] + >>> labels[0]==labels[1] True - >>> labels[3]==labels[4] + >>> labels[2]==labels[3]==labels[4] True + + .. versionadded:: 1.3.0 """ @property + @since('1.4.0') def weights(self): """ Weights for each Gaussian distribution in the mixture, where weights[i] is @@ -220,6 +258,7 @@ def weights(self): return array(self.call("weights")) @property + @since('1.4.0') def gaussians(self): """ Array of MultivariateGaussian where gaussians[i] represents @@ -227,13 +266,15 @@ def gaussians(self): """ return [ MultivariateGaussian(gaussian[0], gaussian[1]) - for gaussian in zip(*self.call("gaussians"))] + for gaussian in self.call("gaussians")] @property + @since('1.4.0') def k(self): """Number of gaussians in mixture.""" return len(self.weights) + @since('1.3.0') def predict(self, x): """ Find the cluster to which the points in 'x' has maximum membership @@ -249,6 +290,7 @@ def predict(self, x): raise TypeError("x should be represented by an RDD, " "but got %s." % type(x)) + @since('1.3.0') def predictSoft(self, x): """ Find the membership of each point in 'x' to all mixture components. @@ -266,6 +308,7 @@ def predictSoft(self, x): "but got %s." % type(x)) @classmethod + @since('1.5.0') def load(cls, sc, path): """Load the GaussianMixtureModel from disk. @@ -289,8 +332,11 @@ class GaussianMixture(object): :param maxIterations: Number of iterations. Default to 100 :param seed: Random Seed :param initialModel: GaussianMixtureModel for initializing learning + + .. versionadded:: 1.3.0 """ @classmethod + @since('1.3.0') def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initialModel=None): """Train a Gaussian Mixture clustering model.""" initialModelWeights = None @@ -345,15 +391,19 @@ class PowerIterationClusteringModel(JavaModelWrapper, JavaSaveable, JavaLoader): ... rmtree(path) ... except OSError: ... pass + + .. versionadded:: 1.5.0 """ @property + @since('1.5.0') def k(self): """ Returns the number of clusters. 
""" return self.call("k") + @since('1.5.0') def assignments(self): """ Returns the cluster assignments of this model. @@ -362,7 +412,11 @@ def assignments(self): lambda x: (PowerIterationClustering.Assignment(*x))) @classmethod + @since('1.5.0') def load(cls, sc, path): + """ + Load a model from the given path. + """ model = cls._load_java(sc, path) wrapper = sc._jvm.PowerIterationClusteringModelWrapper(model) return PowerIterationClusteringModel(wrapper) @@ -377,9 +431,12 @@ class PowerIterationClustering(object): From the abstract: PIC finds a very low-dimensional embedding of a dataset using truncated power iteration on a normalized pair-wise similarity matrix of the data. + + .. versionadded:: 1.5.0 """ @classmethod + @since('1.5.0') def train(cls, rdd, k, maxIterations=100, initMode="random"): """ :param rdd: an RDD of (i, j, s,,ij,,) tuples representing the @@ -402,6 +459,8 @@ def train(cls, rdd, k, maxIterations=100, initMode="random"): class Assignment(namedtuple("Assignment", ["id", "cluster"])): """ Represents an (id, cluster) tuple. + + .. versionadded:: 1.5.0 """ @@ -461,17 +520,21 @@ class StreamingKMeansModel(KMeansModel): 0 >>> stkm.predict([1.5, 1.5]) 1 + + .. versionadded:: 1.5.0 """ def __init__(self, clusterCenters, clusterWeights): super(StreamingKMeansModel, self).__init__(centers=clusterCenters) self._clusterWeights = list(clusterWeights) @property + @since('1.5.0') def clusterWeights(self): """Return the cluster weights.""" return self._clusterWeights @ignore_unicode_prefix + @since('1.5.0') def update(self, data, decayFactor, timeUnit): """Update the centroids, according to data @@ -510,6 +573,8 @@ class StreamingKMeans(object): :param decayFactor: float, forgetfulness of the previous centroids. :param timeUnit: can be "batches" or "points". If points, then the decayfactor is raised to the power of no. of new points. + + .. versionadded:: 1.5.0 """ def __init__(self, k=2, decayFactor=1.0, timeUnit="batches"): self._k = k @@ -520,6 +585,7 @@ def __init__(self, k=2, decayFactor=1.0, timeUnit="batches"): self._timeUnit = timeUnit self._model = None + @since('1.5.0') def latestModel(self): """Return the latest model""" return self._model @@ -534,16 +600,19 @@ def _validate(self, dstream): "Expected dstream to be of type DStream, " "got type %s" % type(dstream)) + @since('1.5.0') def setK(self, k): """Set number of clusters.""" self._k = k return self + @since('1.5.0') def setDecayFactor(self, decayFactor): """Set decay factor.""" self._decayFactor = decayFactor return self + @since('1.5.0') def setHalfLife(self, halfLife, timeUnit): """ Set number of batches after which the centroids of that @@ -553,6 +622,7 @@ def setHalfLife(self, halfLife, timeUnit): self._decayFactor = exp(log(0.5) / halfLife) return self + @since('1.5.0') def setInitialCenters(self, centers, weights): """ Set initial centers. Should be set before calling trainOn. 
@@ -560,6 +630,7 @@ def setInitialCenters(self, centers, weights): self._model = StreamingKMeansModel(centers, weights) return self + @since('1.5.0') def setRandomCenters(self, dim, weight, seed): """ Set the initial centres to be random samples from @@ -571,6 +642,7 @@ def setRandomCenters(self, dim, weight, seed): self._model = StreamingKMeansModel(clusterCenters, clusterWeights) return self + @since('1.5.0') def trainOn(self, dstream): """Train the model on the incoming dstream.""" self._validate(dstream) @@ -580,6 +652,7 @@ def update(rdd): dstream.foreachRDD(update) + @since('1.5.0') def predictOn(self, dstream): """ Make predictions on a dstream. @@ -588,6 +661,7 @@ def predictOn(self, dstream): self._validate(dstream) return dstream.map(lambda x: self._model.predict(x)) + @since('1.5.0') def predictOnValues(self, dstream): """ Make predictions on a keyed dstream. @@ -597,7 +671,7 @@ def predictOnValues(self, dstream): return dstream.mapValues(lambda x: self._model.predict(x)) -class LDAModel(JavaModelWrapper): +class LDAModel(JavaModelWrapper, JavaSaveable, Loader): """ A clustering model derived from the LDA method. @@ -617,9 +691,14 @@ class LDAModel(JavaModelWrapper): ... [2, SparseVector(2, {0: 1.0})], ... ] >>> rdd = sc.parallelize(data) - >>> model = LDA.train(rdd, k=2) + >>> model = LDA.train(rdd, k=2, seed=1) >>> model.vocabSize() 2 + >>> model.describeTopics() + [([1, 0], [0.5..., 0.49...]), ([0, 1], [0.5..., 0.49...])] + >>> model.describeTopics(1) + [([1], [0.5...]), ([0], [0.5...])] + >>> topics = model.topicsMatrix() >>> topics_expect = array([[0.5, 0.5], [0.5, 0.5]]) >>> assert_almost_equal(topics, topics_expect, 1) @@ -636,29 +715,40 @@ class LDAModel(JavaModelWrapper): ... rmtree(path) ... except OSError: ... pass + + .. versionadded:: 1.5.0 """ + @since('1.5.0') def topicsMatrix(self): """Inferred topics, where each topic is represented by a distribution over terms.""" return self.call("topicsMatrix").toArray() + @since('1.5.0') def vocabSize(self): """Vocabulary size (number of terms or terms in the vocabulary)""" return self.call("vocabSize") - def save(self, sc, path): - """Save the LDAModel on to disk. + @since('1.6.0') + def describeTopics(self, maxTermsPerTopic=None): + """Return the topics described by weighted terms. - :param sc: SparkContext - :param path: str, path to where the model needs to be stored. + WARNING: If vocabSize and k are large, this can return a large object! + + :param maxTermsPerTopic: Maximum number of terms to collect for each topic. + (default: vocabulary size) + :return: Array over topics. Each topic is represented as a pair of matching arrays: + (term indices, term weights in topic). + Each topic's terms are sorted in order of decreasing weight. """ - if not isinstance(sc, SparkContext): - raise TypeError("sc should be a SparkContext, got type %s" % type(sc)) - if not isinstance(path, basestring): - raise TypeError("path should be a basestring, got type %s" % type(path)) - self._java_model.save(sc._jsc.sc(), path) + if maxTermsPerTopic is None: + topics = self.call("describeTopics") + else: + topics = self.call("describeTopics", maxTermsPerTopic) + return topics @classmethod + @since('1.5.0') def load(cls, sc, path): """Load the LDAModel from disk. 
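A hypothetical sketch of the new describeTopics accessor on LDAModel, using the same (doc id, term-count vector) corpus layout as the doctest above; it assumes `sc`.

# Hypothetical sketch; describeTopics returns (term indices, term weights) pairs per topic.
from pyspark.mllib.clustering import LDA
from pyspark.mllib.linalg import Vectors

corpus = sc.parallelize([[1, Vectors.dense([1.0, 2.0])],
                         [2, Vectors.dense([3.0, 1.0])]])   # (doc id, term-count vector)
model = LDA.train(corpus, k=2, seed=1)
print(model.vocabSize())
print(model.describeTopics(maxTermsPerTopic=1))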
@@ -669,14 +759,17 @@ def load(cls, sc, path): raise TypeError("sc should be a SparkContext, got type %s" % type(sc)) if not isinstance(path, basestring): raise TypeError("path should be a basestring, got type %s" % type(path)) - java_model = sc._jvm.org.apache.spark.mllib.clustering.DistributedLDAModel.load( - sc._jsc.sc(), path) - return cls(java_model) + model = callMLlibFunc("loadLDAModel", sc, path) + return LDAModel(model) class LDA(object): + """ + .. versionadded:: 1.5.0 + """ @classmethod + @since('1.5.0') def train(cls, rdd, k=10, maxIterations=20, docConcentration=-1.0, topicConcentration=-1.0, seed=None, checkpointInterval=10, optimizer="em"): """Train a LDA model. diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py index 855e85f57155..9fda1b1682f5 100644 --- a/python/pyspark/mllib/common.py +++ b/python/pyspark/mllib/common.py @@ -73,6 +73,8 @@ def _py2java(sc, obj): """ Convert Python object into Java """ if isinstance(obj, RDD): obj = _to_java_object_rdd(obj) + elif isinstance(obj, DataFrame): + obj = obj._jdf elif isinstance(obj, SparkContext): obj = obj._jsc elif isinstance(obj, list): @@ -100,7 +102,7 @@ def _java2py(sc, r, encoding="bytes"): return RDD(jrdd, sc) if clsName == 'DataFrame': - return DataFrame(r, SQLContext(sc)) + return DataFrame(r, SQLContext.getOrCreate(sc)) if clsName in _picklable_classes: r = sc._jvm.SerDe.dumps(r) @@ -123,7 +125,7 @@ def callJavaFunc(sc, func, *args): def callMLlibFunc(name, *args): """ Call API in PythonMLLibAPI """ - sc = SparkContext._active_spark_context + sc = SparkContext.getOrCreate() api = getattr(sc._jvm.PythonMLLibAPI(), name) return callJavaFunc(sc, api, *args) @@ -133,7 +135,7 @@ class JavaModelWrapper(object): Wrapper for the model in JVM """ def __init__(self, java_model): - self._sc = SparkContext._active_spark_context + self._sc = SparkContext.getOrCreate() self._java_model = java_model def __del__(self): diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index 4398ca86f2ec..22e68ea5b451 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -15,6 +15,7 @@ # limitations under the License. # +from pyspark import since from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc from pyspark.sql import SQLContext from pyspark.sql.types import StructField, StructType, DoubleType, IntegerType, ArrayType @@ -37,11 +38,13 @@ class BinaryClassificationMetrics(JavaModelWrapper): >>> metrics.areaUnderPR 0.83... >>> metrics.unpersist() + + .. versionadded:: 1.4.0 """ def __init__(self, scoreAndLabels): sc = scoreAndLabels.ctx - sql_ctx = SQLContext(sc) + sql_ctx = SQLContext.getOrCreate(sc) df = sql_ctx.createDataFrame(scoreAndLabels, schema=StructType([ StructField("score", DoubleType(), nullable=False), StructField("label", DoubleType(), nullable=False)])) @@ -50,6 +53,7 @@ def __init__(self, scoreAndLabels): super(BinaryClassificationMetrics, self).__init__(java_model) @property + @since('1.4.0') def areaUnderROC(self): """ Computes the area under the receiver operating characteristic @@ -58,12 +62,14 @@ def areaUnderROC(self): return self.call("areaUnderROC") @property + @since('1.4.0') def areaUnderPR(self): """ Computes the area under the precision-recall curve. """ return self.call("areaUnderPR") + @since('1.4.0') def unpersist(self): """ Unpersists intermediate RDDs used in the computation. @@ -91,11 +97,13 @@ class RegressionMetrics(JavaModelWrapper): 0.61... >>> metrics.r2 0.94... + + .. 
versionadded:: 1.4.0 """ def __init__(self, predictionAndObservations): sc = predictionAndObservations.ctx - sql_ctx = SQLContext(sc) + sql_ctx = SQLContext.getOrCreate(sc) df = sql_ctx.createDataFrame(predictionAndObservations, schema=StructType([ StructField("prediction", DoubleType(), nullable=False), StructField("observation", DoubleType(), nullable=False)])) @@ -104,6 +112,7 @@ def __init__(self, predictionAndObservations): super(RegressionMetrics, self).__init__(java_model) @property + @since('1.4.0') def explainedVariance(self): """ Returns the explained variance regression score. @@ -112,6 +121,7 @@ def explainedVariance(self): return self.call("explainedVariance") @property + @since('1.4.0') def meanAbsoluteError(self): """ Returns the mean absolute error, which is a risk function corresponding to the @@ -120,6 +130,7 @@ def meanAbsoluteError(self): return self.call("meanAbsoluteError") @property + @since('1.4.0') def meanSquaredError(self): """ Returns the mean squared error, which is a risk function corresponding to the @@ -128,6 +139,7 @@ def meanSquaredError(self): return self.call("meanSquaredError") @property + @since('1.4.0') def rootMeanSquaredError(self): """ Returns the root mean squared error, which is defined as the square root of @@ -136,6 +148,7 @@ def rootMeanSquaredError(self): return self.call("rootMeanSquaredError") @property + @since('1.4.0') def r2(self): """ Returns R^2^, the coefficient of determination. @@ -147,7 +160,7 @@ class MulticlassMetrics(JavaModelWrapper): """ Evaluator for multiclass classification. - :param predictionAndLabels an RDD of (prediction, label) pairs. + :param predictionAndLabels: an RDD of (prediction, label) pairs. >>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0), ... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)]) @@ -178,11 +191,13 @@ class MulticlassMetrics(JavaModelWrapper): 0.66... >>> metrics.weightedFMeasure(2.0) 0.65... + + .. versionadded:: 1.4.0 """ def __init__(self, predictionAndLabels): sc = predictionAndLabels.ctx - sql_ctx = SQLContext(sc) + sql_ctx = SQLContext.getOrCreate(sc) df = sql_ctx.createDataFrame(predictionAndLabels, schema=StructType([ StructField("prediction", DoubleType(), nullable=False), StructField("label", DoubleType(), nullable=False)])) @@ -190,6 +205,7 @@ def __init__(self, predictionAndLabels): java_model = java_class(df._jdf) super(MulticlassMetrics, self).__init__(java_model) + @since('1.4.0') def confusionMatrix(self): """ Returns confusion matrix: predicted classes are in columns, @@ -197,18 +213,21 @@ def confusionMatrix(self): """ return self.call("confusionMatrix") + @since('1.4.0') def truePositiveRate(self, label): """ Returns true positive rate for a given label (category). """ return self.call("truePositiveRate", label) + @since('1.4.0') def falsePositiveRate(self, label): """ Returns false positive rate for a given label (category). """ return self.call("falsePositiveRate", label) + @since('1.4.0') def precision(self, label=None): """ Returns precision or precision for a given label (category) if specified. @@ -218,6 +237,7 @@ def precision(self, label=None): else: return self.call("precision", float(label)) + @since('1.4.0') def recall(self, label=None): """ Returns recall or recall for a given label (category) if specified. 
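For readers skimming the @since annotations above, a brief sketch of how MulticlassMetrics is driven end to end (it mirrors the doctest; `sc` is assumed to be an active SparkContext):

    from pyspark.mllib.evaluation import MulticlassMetrics

    # An RDD of (prediction, label) pairs.
    predictionAndLabels = sc.parallelize([
        (0.0, 0.0), (0.0, 1.0), (1.0, 0.0), (1.0, 1.0), (2.0, 2.0),
    ])
    metrics = MulticlassMetrics(predictionAndLabels)
    print(metrics.confusionMatrix().toArray())
    print(metrics.precision())          # overall precision
    print(metrics.recall(1.0))          # recall for label 1.0
    print(metrics.fMeasure(0.0, 2.0))   # f-measure for label 0.0 with beta=2.0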
@@ -227,6 +247,7 @@ def recall(self, label=None): else: return self.call("recall", float(label)) + @since('1.4.0') def fMeasure(self, label=None, beta=None): """ Returns f-measure or f-measure for a given label (category) if specified. @@ -243,6 +264,7 @@ def fMeasure(self, label=None, beta=None): return self.call("fMeasure", label, beta) @property + @since('1.4.0') def weightedTruePositiveRate(self): """ Returns weighted true positive rate. @@ -251,6 +273,7 @@ def weightedTruePositiveRate(self): return self.call("weightedTruePositiveRate") @property + @since('1.4.0') def weightedFalsePositiveRate(self): """ Returns weighted false positive rate. @@ -258,6 +281,7 @@ def weightedFalsePositiveRate(self): return self.call("weightedFalsePositiveRate") @property + @since('1.4.0') def weightedRecall(self): """ Returns weighted averaged recall. @@ -266,12 +290,14 @@ def weightedRecall(self): return self.call("weightedRecall") @property + @since('1.4.0') def weightedPrecision(self): """ Returns weighted averaged precision. """ return self.call("weightedPrecision") + @since('1.4.0') def weightedFMeasure(self, beta=None): """ Returns weighted averaged f-measure. @@ -307,16 +333,18 @@ class RankingMetrics(JavaModelWrapper): >>> metrics.ndcgAt(10) 0.48... + .. versionadded:: 1.4.0 """ def __init__(self, predictionAndLabels): sc = predictionAndLabels.ctx - sql_ctx = SQLContext(sc) + sql_ctx = SQLContext.getOrCreate(sc) df = sql_ctx.createDataFrame(predictionAndLabels, schema=sql_ctx._inferSchema(predictionAndLabels)) java_model = callMLlibFunc("newRankingMetrics", df._jdf) super(RankingMetrics, self).__init__(java_model) + @since('1.4.0') def precisionAt(self, k): """ Compute the average precision of all the queries, truncated at ranking position k. @@ -331,6 +359,7 @@ def precisionAt(self, k): return self.call("precisionAt", int(k)) @property + @since('1.4.0') def meanAveragePrecision(self): """ Returns the mean average precision (MAP) of all the queries. @@ -339,6 +368,7 @@ def meanAveragePrecision(self): """ return self.call("meanAveragePrecision") + @since('1.4.0') def ndcgAt(self, k): """ Compute the average NDCG value of all the queries, truncated at ranking position k. @@ -388,17 +418,20 @@ class MultilabelMetrics(JavaModelWrapper): 0.28... >>> metrics.accuracy 0.54... + + .. versionadded:: 1.4.0 """ def __init__(self, predictionAndLabels): sc = predictionAndLabels.ctx - sql_ctx = SQLContext(sc) + sql_ctx = SQLContext.getOrCreate(sc) df = sql_ctx.createDataFrame(predictionAndLabels, schema=sql_ctx._inferSchema(predictionAndLabels)) java_class = sc._jvm.org.apache.spark.mllib.evaluation.MultilabelMetrics java_model = java_class(df._jdf) super(MultilabelMetrics, self).__init__(java_model) + @since('1.4.0') def precision(self, label=None): """ Returns precision or precision for a given label (category) if specified. @@ -408,6 +441,7 @@ def precision(self, label=None): else: return self.call("precision", float(label)) + @since('1.4.0') def recall(self, label=None): """ Returns recall or recall for a given label (category) if specified. @@ -417,6 +451,7 @@ def recall(self, label=None): else: return self.call("recall", float(label)) + @since('1.4.0') def f1Measure(self, label=None): """ Returns f1Measure or f1Measure for a given label (category) if specified. @@ -427,6 +462,7 @@ def f1Measure(self, label=None): return self.call("f1Measure", float(label)) @property + @since('1.4.0') def microPrecision(self): """ Returns micro-averaged label-based precision. 
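A compact sketch of the RankingMetrics usage annotated above; each element pairs a predicted ranking with the ground-truth relevant items, and `sc` is again assumed to be an active SparkContext:

    from pyspark.mllib.evaluation import RankingMetrics

    predictionAndLabels = sc.parallelize([
        ([1, 6, 2, 7, 8, 3, 9, 10, 4, 5], [1, 2, 3, 4, 5]),
        ([4, 1, 5, 6, 2, 7, 3, 8, 9, 10], [1, 2, 3]),
        ([1, 2, 3, 4, 5], []),
    ])
    metrics = RankingMetrics(predictionAndLabels)
    print(metrics.precisionAt(5))
    print(metrics.meanAveragePrecision)   # a property, not a method
    print(metrics.ndcgAt(10))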
@@ -435,6 +471,7 @@ def microPrecision(self): return self.call("microPrecision") @property + @since('1.4.0') def microRecall(self): """ Returns micro-averaged label-based recall. @@ -443,6 +480,7 @@ def microRecall(self): return self.call("microRecall") @property + @since('1.4.0') def microF1Measure(self): """ Returns micro-averaged label-based f1-measure. @@ -451,6 +489,7 @@ def microF1Measure(self): return self.call("microF1Measure") @property + @since('1.4.0') def hammingLoss(self): """ Returns Hamming-loss. @@ -458,6 +497,7 @@ def hammingLoss(self): return self.call("hammingLoss") @property + @since('1.4.0') def subsetAccuracy(self): """ Returns subset accuracy. @@ -466,6 +506,7 @@ def subsetAccuracy(self): return self.call("subsetAccuracy") @property + @since('1.4.0') def accuracy(self): """ Returns accuracy. diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index f921e3ad1a31..acd7ec57d69d 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -30,7 +30,7 @@ from py4j.protocol import Py4JJavaError -from pyspark import SparkContext +from pyspark import since from pyspark.rdd import RDD, ignore_unicode_prefix from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper from pyspark.mllib.linalg import ( @@ -84,11 +84,14 @@ class Normalizer(VectorTransformer): >>> nor2 = Normalizer(float("inf")) >>> nor2.transform(v) DenseVector([0.0, 0.5, 1.0]) + + .. versionadded:: 1.2.0 """ def __init__(self, p=2.0): assert p >= 1.0, "p should be greater than 1.0" self.p = float(p) + @since('1.2.0') def transform(self, vector): """ Applies unit length normalization on a vector. @@ -97,8 +100,6 @@ def transform(self, vector): :return: normalized vector. If the norm of the input is zero, it will return the input vector. """ - sc = SparkContext._active_spark_context - assert sc is not None, "SparkContext should be initialized first" if isinstance(vector, RDD): vector = vector.map(_convert_to_vector) else: @@ -133,7 +134,11 @@ class StandardScalerModel(JavaVectorTransformer): .. note:: Experimental Represents a StandardScaler model that can transform vectors. + + .. versionadded:: 1.2.0 """ + + @since('1.2.0') def transform(self, vector): """ Applies standardization transformation on a vector. @@ -149,6 +154,7 @@ def transform(self, vector): """ return JavaVectorTransformer.transform(self, vector) + @since('1.4.0') def setWithMean(self, withMean): """ Setter of the boolean which decides @@ -157,6 +163,7 @@ def setWithMean(self, withMean): self.call("setWithMean", withMean) return self + @since('1.4.0') def setWithStd(self, withStd): """ Setter of the boolean which decides @@ -189,6 +196,8 @@ class StandardScaler(object): >>> for r in result.collect(): r DenseVector([-0.7071, 0.7071, -0.7071]) DenseVector([0.7071, -0.7071, 0.7071]) + + .. versionadded:: 1.2.0 """ def __init__(self, withMean=False, withStd=True): if not (withMean or withStd): @@ -196,6 +205,7 @@ def __init__(self, withMean=False, withStd=True): self.withMean = withMean self.withStd = withStd + @since('1.2.0') def fit(self, dataset): """ Computes the mean and variance and stores as a model to be used @@ -215,7 +225,11 @@ class ChiSqSelectorModel(JavaVectorTransformer): .. note:: Experimental Represents a Chi Squared selector model. + + .. versionadded:: 1.4.0 """ + + @since('1.4.0') def transform(self, vector): """ Applies transformation on a vector. 
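To round out the feature transformers touched above, a minimal StandardScaler fit/transform round trip patterned on the doctest (`sc` is assumed to be an active SparkContext):

    from pyspark.mllib.feature import StandardScaler
    from pyspark.mllib.linalg import Vectors

    vs = [Vectors.dense([-2.0, 2.3, 0]), Vectors.dense([3.8, 0.0, 1.9])]
    dataset = sc.parallelize(vs)
    standardizer = StandardScaler(withMean=True, withStd=True)
    model = standardizer.fit(dataset)    # computes column means and variances
    for v in model.transform(dataset).collect():
        print(v)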
@@ -245,10 +259,13 @@ class ChiSqSelector(object): SparseVector(1, {0: 6.0}) >>> model.transform(DenseVector([8.0, 9.0, 5.0])) DenseVector([5.0]) + + .. versionadded:: 1.4.0 """ def __init__(self, numTopFeatures): self.numTopFeatures = int(numTopFeatures) + @since('1.4.0') def fit(self, data): """ Returns a ChiSquared feature selector. @@ -265,6 +282,8 @@ def fit(self, data): class PCAModel(JavaVectorTransformer): """ Model fitted by [[PCA]] that can project vectors to a low-dimensional space using PCA. + + .. versionadded:: 1.5.0 """ @@ -281,6 +300,8 @@ class PCA(object): 1.648... >>> pcArray[1] -4.013... + + .. versionadded:: 1.5.0 """ def __init__(self, k): """ @@ -288,6 +309,7 @@ def __init__(self, k): """ self.k = int(k) + @since('1.5.0') def fit(self, data): """ Computes a [[PCAModel]] that contains the principal components of the input vectors. @@ -312,14 +334,18 @@ class HashingTF(object): >>> doc = "a a b b c d".split(" ") >>> htf.transform(doc) SparseVector(100, {...}) + + .. versionadded:: 1.2.0 """ def __init__(self, numFeatures=1 << 20): self.numFeatures = numFeatures + @since('1.2.0') def indexOf(self, term): """ Returns the index of the input term. """ return hash(term) % self.numFeatures + @since('1.2.0') def transform(self, document): """ Transforms the input document (list of terms) to term frequency @@ -339,7 +365,10 @@ def transform(self, document): class IDFModel(JavaVectorTransformer): """ Represents an IDF model that can transform term frequency vectors. + + .. versionadded:: 1.2.0 """ + @since('1.2.0') def transform(self, x): """ Transforms term frequency (TF) vectors to TF-IDF vectors. @@ -358,6 +387,7 @@ def transform(self, x): """ return JavaVectorTransformer.transform(self, x) + @since('1.4.0') def idf(self): """ Returns the current IDF vector. @@ -401,10 +431,13 @@ class IDF(object): DenseVector([0.0, 0.0, 1.3863, 0.863]) >>> model.transform(Vectors.sparse(n, (1, 3), (1.0, 2.0))) SparseVector(4, {1: 0.0, 3: 0.5754}) + + .. versionadded:: 1.2.0 """ def __init__(self, minDocFreq=0): self.minDocFreq = minDocFreq + @since('1.2.0') def fit(self, dataset): """ Computes the inverse document frequency. @@ -420,7 +453,10 @@ def fit(self, dataset): class Word2VecModel(JavaVectorTransformer, JavaSaveable, JavaLoader): """ class for Word2Vec model + + .. versionadded:: 1.2.0 """ + @since('1.2.0') def transform(self, word): """ Transforms a word to its vector representation @@ -435,6 +471,7 @@ def transform(self, word): except Py4JJavaError: raise ValueError("%s not found" % word) + @since('1.2.0') def findSynonyms(self, word, num): """ Find synonyms of a word @@ -450,6 +487,7 @@ def findSynonyms(self, word, num): words, similarity = self.call("findSynonyms", word, num) return zip(words, similarity) + @since('1.4.0') def getVectors(self): """ Returns a map of words to their vector representations. @@ -457,10 +495,15 @@ def getVectors(self): return self.call("getVectors") @classmethod + @since('1.5.0') def load(cls, sc, path): + """ + Load a model from the given path. + """ jmodel = sc._jvm.org.apache.spark.mllib.feature \ .Word2VecModel.load(sc._jsc.sc(), path) - return Word2VecModel(jmodel) + model = sc._jvm.Word2VecModelWrapper(jmodel) + return Word2VecModel(model) @ignore_unicode_prefix @@ -502,11 +545,16 @@ class Word2Vec(object): >>> sameModel = Word2VecModel.load(sc, path) >>> model.transform("a") == sameModel.transform("a") True + >>> syms = sameModel.findSynonyms("a", 2) + >>> [s[0] for s in syms] + [u'b', u'c'] >>> from shutil import rmtree >>> try: ... 
rmtree(path) ... except OSError: ... pass + + .. versionadded:: 1.2.0 """ def __init__(self): """ @@ -519,6 +567,7 @@ def __init__(self): self.seed = random.randint(0, sys.maxsize) self.minCount = 5 + @since('1.2.0') def setVectorSize(self, vectorSize): """ Sets vector size (default: 100). @@ -526,6 +575,7 @@ def setVectorSize(self, vectorSize): self.vectorSize = vectorSize return self + @since('1.2.0') def setLearningRate(self, learningRate): """ Sets initial learning rate (default: 0.025). @@ -533,6 +583,7 @@ def setLearningRate(self, learningRate): self.learningRate = learningRate return self + @since('1.2.0') def setNumPartitions(self, numPartitions): """ Sets number of partitions (default: 1). Use a small number for @@ -541,6 +592,7 @@ def setNumPartitions(self, numPartitions): self.numPartitions = numPartitions return self + @since('1.2.0') def setNumIterations(self, numIterations): """ Sets number of iterations (default: 1), which should be smaller @@ -549,6 +601,7 @@ def setNumIterations(self, numIterations): self.numIterations = numIterations return self + @since('1.2.0') def setSeed(self, seed): """ Sets random seed. @@ -556,6 +609,7 @@ def setSeed(self, seed): self.seed = seed return self + @since('1.4.0') def setMinCount(self, minCount): """ Sets minCount, the minimum number of times a token must appear @@ -564,6 +618,7 @@ def setMinCount(self, minCount): self.minCount = minCount return self + @since('1.2.0') def fit(self, data): """ Computes the vector representation of each word in vocabulary. @@ -596,10 +651,13 @@ class ElementwiseProduct(VectorTransformer): >>> rdd = sc.parallelize([a, b]) >>> eprod.transform(rdd).collect() [DenseVector([2.0, 2.0, 9.0]), DenseVector([9.0, 6.0, 12.0])] + + .. versionadded:: 1.5.0 """ def __init__(self, scalingVector): self.scalingVector = _convert_to_vector(scalingVector) + @since('1.5.0') def transform(self, vector): """ Computes the Hadamard product of the vector. diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py index bdc4a132b1b1..2039decc0cb3 100644 --- a/python/pyspark/mllib/fpm.py +++ b/python/pyspark/mllib/fpm.py @@ -19,11 +19,11 @@ from numpy import array from collections import namedtuple -from pyspark import SparkContext +from pyspark import SparkContext, since from pyspark.rdd import ignore_unicode_prefix from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc -__all__ = ['FPGrowth', 'FPGrowthModel'] +__all__ = ['FPGrowth', 'FPGrowthModel', 'PrefixSpan', 'PrefixSpanModel'] @inherit_doc @@ -41,8 +41,11 @@ class FPGrowthModel(JavaModelWrapper): >>> model = FPGrowth.train(rdd, 0.6, 2) >>> sorted(model.freqItemsets().collect()) [FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'c'], freq=3), ... + + .. versionadded:: 1.4.0 """ + @since("1.4.0") def freqItemsets(self): """ Returns the frequent itemsets of this model. @@ -55,9 +58,12 @@ class FPGrowth(object): .. note:: Experimental A Parallel FP-growth algorithm to mine frequent itemsets. + + .. versionadded:: 1.4.0 """ @classmethod + @since("1.4.0") def train(cls, data, minSupport=0.3, numPartitions=-1): """ Computes an FP-Growth model that contains frequent itemsets. @@ -74,6 +80,75 @@ def train(cls, data, minSupport=0.3, numPartitions=-1): class FreqItemset(namedtuple("FreqItemset", ["items", "freq"])): """ Represents an (items, freq) tuple. + + .. versionadded:: 1.4.0 + """ + + +@inherit_doc +@ignore_unicode_prefix +class PrefixSpanModel(JavaModelWrapper): + """ + .. 
note:: Experimental + + Model fitted by PrefixSpan + + >>> data = [ + ... [["a", "b"], ["c"]], + ... [["a"], ["c", "b"], ["a", "b"]], + ... [["a", "b"], ["e"]], + ... [["f"]]] + >>> rdd = sc.parallelize(data, 2) + >>> model = PrefixSpan.train(rdd) + >>> sorted(model.freqSequences().collect()) + [FreqSequence(sequence=[[u'a']], freq=3), FreqSequence(sequence=[[u'a'], [u'a']], freq=1), ... + + .. versionadded:: 1.6.0 + """ + + @since("1.6.0") + def freqSequences(self): + """Gets frequent sequences""" + return self.call("getFreqSequences").map(lambda x: PrefixSpan.FreqSequence(x[0], x[1])) + + + class PrefixSpan(object): + """ + .. note:: Experimental + + A parallel PrefixSpan algorithm to mine frequent sequential patterns. + The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: + Mining Sequential Patterns Efficiently by Prefix-Projected Pattern Growth + ([[http://doi.org/10.1109/ICDE.2001.914830]]). + + .. versionadded:: 1.6.0 + """ + + @classmethod + @since("1.6.0") + def train(cls, data, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000): + """ + Finds the complete set of frequent sequential patterns in the input sequences of itemsets. + + :param data: The input data set, each element contains a sequence of itemsets. + :param minSupport: the minimal support level of the sequential pattern, any pattern that appears + more than (minSupport * size-of-the-dataset) times will be output (default: `0.1`) + :param maxPatternLength: the maximal length of the sequential pattern, any pattern with length + at most maxPatternLength will be output. (default: `10`) + :param maxLocalProjDBSize: The maximum number of items (including delimiters used in + the internal storage format) allowed in a projected database before local + processing. If a projected database exceeds this size, another + iteration of distributed prefix growth is run. (default: `32000000`) + """ + model = callMLlibFunc("trainPrefixSpanModel", + data, minSupport, maxPatternLength, maxLocalProjDBSize) + return PrefixSpanModel(model) + + class FreqSequence(namedtuple("FreqSequence", ["sequence", "freq"])): + """ + Represents a (sequence, freq) tuple. + + .. versionadded:: 1.6.0 """ diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index 334dc8e38bb8..ae9ce5845090 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -25,6 +25,7 @@ import sys import array +import struct if sys.version >= '3': basestring = str @@ -122,6 +123,13 @@ def _format_float_list(l): return [_format_float(x) for x in l] +def _double_to_long_bits(value): + if np.isnan(value): + value = float('nan') + # pack double into 64 bits, then unpack as long int + return struct.unpack('Q', struct.pack('d', value))[0] + + class VectorUDT(UserDefinedType): """ SQL user-defined type (UDT) for Vector. @@ -232,6 +240,7 @@ class Vector(object): def toArray(self): """ Convert the vector into an numpy.ndarray + :return: numpy.ndarray """ raise NotImplementedError @@ -293,11 +302,14 @@ def __reduce__(self): return DenseVector, (self.array.tostring(),) def numNonzeros(self): + """ + Number of nonzero elements. This scans all active values and counts nonzeros + """ return np.count_nonzero(self.array) def norm(self, p): """ - Calculte the norm of a DenseVector. + Calculates the norm of a DenseVector. 
>>> a = DenseVector([0, -1, 2, -3]) >>> a.norm(2) @@ -389,6 +401,16 @@ def squared_distance(self, other): return np.dot(diff, diff) def toArray(self): + """ + Returns a numpy.ndarray + """ + return self.array + + @property + def values(self): + """ + Returns a list of values + """ return self.array def __getitem__(self, item): @@ -404,11 +426,31 @@ def __repr__(self): return "DenseVector([%s])" % (', '.join(_format_float(i) for i in self.array)) def __eq__(self, other): - return isinstance(other, DenseVector) and np.array_equal(self.array, other.array) + if isinstance(other, DenseVector): + return np.array_equal(self.array, other.array) + elif isinstance(other, SparseVector): + if len(self) != other.size: + return False + return Vectors._equals(list(xrange(len(self))), self.array, other.indices, other.values) + return False def __ne__(self, other): return not self == other + def __hash__(self): + size = len(self) + result = 31 + size + nnz = 0 + i = 0 + while i < size and nnz < 128: + if self.array[i] != 0: + result = 31 * result + i + bits = _double_to_long_bits(self.array[i]) + result = 31 * result + (bits ^ (bits >> 32)) + nnz += 1 + i += 1 + return result + def __getattr__(self, item): return getattr(self.array, item) @@ -447,8 +489,8 @@ def __init__(self, size, *args): :param size: Size of the vector. :param args: Active entries, as a dictionary {index: value, ...}, - a list of tuples [(index, value), ...], or a list of strictly i - ncreasing indices and a list of corresponding values [index, ...], + a list of tuples [(index, value), ...], or a list of strictly + increasing indices and a list of corresponding values [index, ...], [value, ...]. Inactive entries are treated as zeros. >>> SparseVector(4, {1: 1.0, 3: 5.5}) @@ -489,11 +531,14 @@ def __init__(self, size, *args): raise TypeError("indices array must be sorted") def numNonzeros(self): + """ + Number of nonzero elements. This scans all active values and counts nonzeros. + """ return np.count_nonzero(self.values) def norm(self, p): """ - Calculte the norm of a SparseVector. + Calculates the norm of a SparseVector. >>> a = SparseVector(4, [0, 1], [3., -4.]) >>> a.norm(1) @@ -704,20 +749,14 @@ def __repr__(self): return "SparseVector({0}, {{{1}}})".format(self.size, entries) def __eq__(self, other): - """ - Test SparseVectors for equality. - - >>> v1 = SparseVector(4, [(1, 1.0), (3, 5.5)]) - >>> v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) - >>> v1 == v2 - True - >>> v1 != v2 - False - """ - return (isinstance(other, self.__class__) - and other.size == self.size - and np.array_equal(other.indices, self.indices) - and np.array_equal(other.values, self.values)) + if isinstance(other, SparseVector): + return other.size == self.size and np.array_equal(other.indices, self.indices) \ + and np.array_equal(other.values, self.values) + elif isinstance(other, DenseVector): + if self.size != len(other): + return False + return Vectors._equals(self.indices, self.values, list(xrange(len(other))), other.array) + return False def __getitem__(self, index): inds = self.indices @@ -725,10 +764,14 @@ def __getitem__(self, index): if not isinstance(index, int): raise TypeError( "Indices must be of type integer, got type %s" % type(index)) + + if index >= self.size or index < -self.size: + raise ValueError("Index %d out of bounds." % index) if index < 0: index += self.size - if index >= self.size or index < 0: - raise ValueError("Index %d out of bounds." % index) + + if (inds.size == 0) or (index > inds.item(-1)): + return 0. 
insert_index = np.searchsorted(inds, index) row_ind = inds[insert_index] @@ -739,6 +782,19 @@ def __getitem__(self, index): def __ne__(self, other): return not self.__eq__(other) + def __hash__(self): + result = 31 + self.size + nnz = 0 + i = 0 + while i < len(self.values) and nnz < 128: + if self.values[i] != 0: + result = 31 * result + int(self.indices[i]) + bits = _double_to_long_bits(self.values[i]) + result = 31 * result + (bits ^ (bits >> 32)) + nnz += 1 + i += 1 + return result + class Vectors(object): @@ -758,7 +814,7 @@ def sparse(size, *args): values (sorted by index). :param size: Size of the vector. - :param args: Non-zero entries, as a dictionary, list of tupes, + :param args: Non-zero entries, as a dictionary, list of tuples, or two sorted lists containing indices and values. >>> Vectors.sparse(4, {1: 1.0, 3: 5.5}) @@ -841,6 +897,31 @@ def parse(s): def zeros(size): return DenseVector(np.zeros(size)) + @staticmethod + def _equals(v1_indices, v1_values, v2_indices, v2_values): + """ + Check equality between sparse/dense vectors, + v1_indices and v2_indices assume to be strictly increasing. + """ + v1_size = len(v1_values) + v2_size = len(v2_values) + k1 = 0 + k2 = 0 + all_equal = True + while all_equal: + while k1 < v1_size and v1_values[k1] == 0: + k1 += 1 + while k2 < v2_size and v2_values[k2] == 0: + k2 += 1 + + if k1 >= v1_size or k2 >= v2_size: + return k1 >= v1_size and k2 >= v2_size + + all_equal = v1_indices[k1] == v2_indices[k2] and v1_values[k1] == v2_values[k2] + k1 += 1 + k2 += 1 + return all_equal + class Matrix(object): diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py new file mode 100644 index 000000000000..0e7605078863 --- /dev/null +++ b/python/pyspark/mllib/linalg/distributed.py @@ -0,0 +1,921 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Package for distributed linear algebra. +""" + +import sys + +if sys.version >= '3': + long = int + +from py4j.java_gateway import JavaObject + +from pyspark import RDD +from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper +from pyspark.mllib.linalg import _convert_to_vector, Matrix + + +__all__ = ['DistributedMatrix', 'RowMatrix', 'IndexedRow', + 'IndexedRowMatrix', 'MatrixEntry', 'CoordinateMatrix', + 'BlockMatrix'] + + +class DistributedMatrix(object): + """ + .. note:: Experimental + + Represents a distributively stored matrix backed by one or + more RDDs. + + """ + def numRows(self): + """Get or compute the number of rows.""" + raise NotImplementedError + + def numCols(self): + """Get or compute the number of cols.""" + raise NotImplementedError + + +class RowMatrix(DistributedMatrix): + """ + .. note:: Experimental + + Represents a row-oriented distributed Matrix with no meaningful + row indices. 
+ + :param rows: An RDD of vectors. + :param numRows: Number of rows in the matrix. A non-positive + value means unknown, at which point the number + of rows will be determined by the number of + records in the `rows` RDD. + :param numCols: Number of columns in the matrix. A non-positive + value means unknown, at which point the number + of columns will be determined by the size of + the first row. + """ + def __init__(self, rows, numRows=0, numCols=0): + """ + Note: This docstring is not shown publicly. + + Create a wrapper over a Java RowMatrix. + + Publicly, we require that `rows` be an RDD. However, for + internal usage, `rows` can also be a Java RowMatrix + object, in which case we can wrap it directly. This + assists in clean matrix conversions. + + >>> rows = sc.parallelize([[1, 2, 3], [4, 5, 6]]) + >>> mat = RowMatrix(rows) + + >>> mat_diff = RowMatrix(rows) + >>> (mat_diff._java_matrix_wrapper._java_model == + ... mat._java_matrix_wrapper._java_model) + False + + >>> mat_same = RowMatrix(mat._java_matrix_wrapper._java_model) + >>> (mat_same._java_matrix_wrapper._java_model == + ... mat._java_matrix_wrapper._java_model) + True + """ + if isinstance(rows, RDD): + rows = rows.map(_convert_to_vector) + java_matrix = callMLlibFunc("createRowMatrix", rows, long(numRows), int(numCols)) + elif (isinstance(rows, JavaObject) + and rows.getClass().getSimpleName() == "RowMatrix"): + java_matrix = rows + else: + raise TypeError("rows should be an RDD of vectors, got %s" % type(rows)) + + self._java_matrix_wrapper = JavaModelWrapper(java_matrix) + + @property + def rows(self): + """ + Rows of the RowMatrix stored as an RDD of vectors. + + >>> mat = RowMatrix(sc.parallelize([[1, 2, 3], [4, 5, 6]])) + >>> rows = mat.rows + >>> rows.first() + DenseVector([1.0, 2.0, 3.0]) + """ + return self._java_matrix_wrapper.call("rows") + + def numRows(self): + """ + Get or compute the number of rows. + + >>> rows = sc.parallelize([[1, 2, 3], [4, 5, 6], + ... [7, 8, 9], [10, 11, 12]]) + + >>> mat = RowMatrix(rows) + >>> print(mat.numRows()) + 4 + + >>> mat = RowMatrix(rows, 7, 6) + >>> print(mat.numRows()) + 7 + """ + return self._java_matrix_wrapper.call("numRows") + + def numCols(self): + """ + Get or compute the number of cols. + + >>> rows = sc.parallelize([[1, 2, 3], [4, 5, 6], + ... [7, 8, 9], [10, 11, 12]]) + + >>> mat = RowMatrix(rows) + >>> print(mat.numCols()) + 3 + + >>> mat = RowMatrix(rows, 7, 6) + >>> print(mat.numCols()) + 6 + """ + return self._java_matrix_wrapper.call("numCols") + + +class IndexedRow(object): + """ + .. note:: Experimental + + Represents a row of an IndexedRowMatrix. + + Just a wrapper over a (long, vector) tuple. + + :param index: The index for the given row. + :param vector: The row in the matrix at the given index. + """ + def __init__(self, index, vector): + self.index = long(index) + self.vector = _convert_to_vector(vector) + + def __repr__(self): + return "IndexedRow(%s, %s)" % (self.index, self.vector) + + +def _convert_to_indexed_row(row): + if isinstance(row, IndexedRow): + return row + elif isinstance(row, tuple) and len(row) == 2: + return IndexedRow(*row) + else: + raise TypeError("Cannot convert type %s into IndexedRow" % type(row)) + + +class IndexedRowMatrix(DistributedMatrix): + """ + .. note:: Experimental + + Represents a row-oriented distributed Matrix with indexed rows. + + :param rows: An RDD of IndexedRows or (long, vector) tuples. + :param numRows: Number of rows in the matrix. 
A non-positive + value means unknown, at which point the number + of rows will be determined by the max row + index plus one. + :param numCols: Number of columns in the matrix. A non-positive + value means unknown, at which point the number + of columns will be determined by the size of + the first row. + """ + def __init__(self, rows, numRows=0, numCols=0): + """ + Note: This docstring is not shown publicly. + + Create a wrapper over a Java IndexedRowMatrix. + + Publicly, we require that `rows` be an RDD. However, for + internal usage, `rows` can also be a Java IndexedRowMatrix + object, in which case we can wrap it directly. This + assists in clean matrix conversions. + + >>> rows = sc.parallelize([IndexedRow(0, [1, 2, 3]), + ... IndexedRow(1, [4, 5, 6])]) + >>> mat = IndexedRowMatrix(rows) + + >>> mat_diff = IndexedRowMatrix(rows) + >>> (mat_diff._java_matrix_wrapper._java_model == + ... mat._java_matrix_wrapper._java_model) + False + + >>> mat_same = IndexedRowMatrix(mat._java_matrix_wrapper._java_model) + >>> (mat_same._java_matrix_wrapper._java_model == + ... mat._java_matrix_wrapper._java_model) + True + """ + if isinstance(rows, RDD): + rows = rows.map(_convert_to_indexed_row) + # We use DataFrames for serialization of IndexedRows from + # Python, so first convert the RDD to a DataFrame on this + # side. This will convert each IndexedRow to a Row + # containing the 'index' and 'vector' values, which can + # both be easily serialized. We will convert back to + # IndexedRows on the Scala side. + java_matrix = callMLlibFunc("createIndexedRowMatrix", rows.toDF(), + long(numRows), int(numCols)) + elif (isinstance(rows, JavaObject) + and rows.getClass().getSimpleName() == "IndexedRowMatrix"): + java_matrix = rows + else: + raise TypeError("rows should be an RDD of IndexedRows or (long, vector) tuples, " + "got %s" % type(rows)) + + self._java_matrix_wrapper = JavaModelWrapper(java_matrix) + + @property + def rows(self): + """ + Rows of the IndexedRowMatrix stored as an RDD of IndexedRows. + + >>> mat = IndexedRowMatrix(sc.parallelize([IndexedRow(0, [1, 2, 3]), + ... IndexedRow(1, [4, 5, 6])])) + >>> rows = mat.rows + >>> rows.first() + IndexedRow(0, [1.0,2.0,3.0]) + """ + # We use DataFrames for serialization of IndexedRows from + # Java, so we first convert the RDD of rows to a DataFrame + # on the Scala/Java side. Then we map each Row in the + # DataFrame back to an IndexedRow on this side. + rows_df = callMLlibFunc("getIndexedRows", self._java_matrix_wrapper._java_model) + rows = rows_df.map(lambda row: IndexedRow(row[0], row[1])) + return rows + + def numRows(self): + """ + Get or compute the number of rows. + + >>> rows = sc.parallelize([IndexedRow(0, [1, 2, 3]), + ... IndexedRow(1, [4, 5, 6]), + ... IndexedRow(2, [7, 8, 9]), + ... IndexedRow(3, [10, 11, 12])]) + + >>> mat = IndexedRowMatrix(rows) + >>> print(mat.numRows()) + 4 + + >>> mat = IndexedRowMatrix(rows, 7, 6) + >>> print(mat.numRows()) + 7 + """ + return self._java_matrix_wrapper.call("numRows") + + def numCols(self): + """ + Get or compute the number of cols. + + >>> rows = sc.parallelize([IndexedRow(0, [1, 2, 3]), + ... IndexedRow(1, [4, 5, 6]), + ... IndexedRow(2, [7, 8, 9]), + ... IndexedRow(3, [10, 11, 12])]) + + >>> mat = IndexedRowMatrix(rows) + >>> print(mat.numCols()) + 3 + + >>> mat = IndexedRowMatrix(rows, 7, 6) + >>> print(mat.numCols()) + 6 + """ + return self._java_matrix_wrapper.call("numCols") + + def toRowMatrix(self): + """ + Convert this matrix to a RowMatrix. 
+ + >>> rows = sc.parallelize([IndexedRow(0, [1, 2, 3]), + ... IndexedRow(6, [4, 5, 6])]) + >>> mat = IndexedRowMatrix(rows).toRowMatrix() + >>> mat.rows.collect() + [DenseVector([1.0, 2.0, 3.0]), DenseVector([4.0, 5.0, 6.0])] + """ + java_row_matrix = self._java_matrix_wrapper.call("toRowMatrix") + return RowMatrix(java_row_matrix) + + def toCoordinateMatrix(self): + """ + Convert this matrix to a CoordinateMatrix. + + >>> rows = sc.parallelize([IndexedRow(0, [1, 0]), + ... IndexedRow(6, [0, 5])]) + >>> mat = IndexedRowMatrix(rows).toCoordinateMatrix() + >>> mat.entries.take(3) + [MatrixEntry(0, 0, 1.0), MatrixEntry(0, 1, 0.0), MatrixEntry(6, 0, 0.0)] + """ + java_coordinate_matrix = self._java_matrix_wrapper.call("toCoordinateMatrix") + return CoordinateMatrix(java_coordinate_matrix) + + def toBlockMatrix(self, rowsPerBlock=1024, colsPerBlock=1024): + """ + Convert this matrix to a BlockMatrix. + + :param rowsPerBlock: Number of rows that make up each block. + The blocks forming the final rows are not + required to have the given number of rows. + :param colsPerBlock: Number of columns that make up each block. + The blocks forming the final columns are not + required to have the given number of columns. + + >>> rows = sc.parallelize([IndexedRow(0, [1, 2, 3]), + ... IndexedRow(6, [4, 5, 6])]) + >>> mat = IndexedRowMatrix(rows).toBlockMatrix() + + >>> # This IndexedRowMatrix will have 7 effective rows, due to + >>> # the highest row index being 6, and the ensuing + >>> # BlockMatrix will have 7 rows as well. + >>> print(mat.numRows()) + 7 + + >>> print(mat.numCols()) + 3 + """ + java_block_matrix = self._java_matrix_wrapper.call("toBlockMatrix", + rowsPerBlock, + colsPerBlock) + return BlockMatrix(java_block_matrix, rowsPerBlock, colsPerBlock) + + +class MatrixEntry(object): + """ + .. note:: Experimental + + Represents an entry of a CoordinateMatrix. + + Just a wrapper over a (long, long, float) tuple. + + :param i: The row index of the matrix. + :param j: The column index of the matrix. + :param value: The (i, j)th entry of the matrix, as a float. + """ + def __init__(self, i, j, value): + self.i = long(i) + self.j = long(j) + self.value = float(value) + + def __repr__(self): + return "MatrixEntry(%s, %s, %s)" % (self.i, self.j, self.value) + + +def _convert_to_matrix_entry(entry): + if isinstance(entry, MatrixEntry): + return entry + elif isinstance(entry, tuple) and len(entry) == 3: + return MatrixEntry(*entry) + else: + raise TypeError("Cannot convert type %s into MatrixEntry" % type(entry)) + + +class CoordinateMatrix(DistributedMatrix): + """ + .. note:: Experimental + + Represents a matrix in coordinate format. + + :param entries: An RDD of MatrixEntry inputs or + (long, long, float) tuples. + :param numRows: Number of rows in the matrix. A non-positive + value means unknown, at which point the number + of rows will be determined by the max row + index plus one. + :param numCols: Number of columns in the matrix. A non-positive + value means unknown, at which point the number + of columns will be determined by the max row + index plus one. + """ + def __init__(self, entries, numRows=0, numCols=0): + """ + Note: This docstring is not shown publicly. + + Create a wrapper over a Java CoordinateMatrix. + + Publicly, we require that `rows` be an RDD. However, for + internal usage, `rows` can also be a Java CoordinateMatrix + object, in which case we can wrap it directly. This + assists in clean matrix conversions. + + >>> entries = sc.parallelize([MatrixEntry(0, 0, 1.2), + ... 
MatrixEntry(6, 4, 2.1)]) + >>> mat = CoordinateMatrix(entries) + + >>> mat_diff = CoordinateMatrix(entries) + >>> (mat_diff._java_matrix_wrapper._java_model == + ... mat._java_matrix_wrapper._java_model) + False + + >>> mat_same = CoordinateMatrix(mat._java_matrix_wrapper._java_model) + >>> (mat_same._java_matrix_wrapper._java_model == + ... mat._java_matrix_wrapper._java_model) + True + """ + if isinstance(entries, RDD): + entries = entries.map(_convert_to_matrix_entry) + # We use DataFrames for serialization of MatrixEntry entries + # from Python, so first convert the RDD to a DataFrame on + # this side. This will convert each MatrixEntry to a Row + # containing the 'i', 'j', and 'value' values, which can + # each be easily serialized. We will convert back to + # MatrixEntry inputs on the Scala side. + java_matrix = callMLlibFunc("createCoordinateMatrix", entries.toDF(), + long(numRows), long(numCols)) + elif (isinstance(entries, JavaObject) + and entries.getClass().getSimpleName() == "CoordinateMatrix"): + java_matrix = entries + else: + raise TypeError("entries should be an RDD of MatrixEntry entries or " + "(long, long, float) tuples, got %s" % type(entries)) + + self._java_matrix_wrapper = JavaModelWrapper(java_matrix) + + @property + def entries(self): + """ + Entries of the CoordinateMatrix stored as an RDD of + MatrixEntries. + + >>> mat = CoordinateMatrix(sc.parallelize([MatrixEntry(0, 0, 1.2), + ... MatrixEntry(6, 4, 2.1)])) + >>> entries = mat.entries + >>> entries.first() + MatrixEntry(0, 0, 1.2) + """ + # We use DataFrames for serialization of MatrixEntry entries + # from Java, so we first convert the RDD of entries to a + # DataFrame on the Scala/Java side. Then we map each Row in + # the DataFrame back to a MatrixEntry on this side. + entries_df = callMLlibFunc("getMatrixEntries", self._java_matrix_wrapper._java_model) + entries = entries_df.map(lambda row: MatrixEntry(row[0], row[1], row[2])) + return entries + + def numRows(self): + """ + Get or compute the number of rows. + + >>> entries = sc.parallelize([MatrixEntry(0, 0, 1.2), + ... MatrixEntry(1, 0, 2), + ... MatrixEntry(2, 1, 3.7)]) + + >>> mat = CoordinateMatrix(entries) + >>> print(mat.numRows()) + 3 + + >>> mat = CoordinateMatrix(entries, 7, 6) + >>> print(mat.numRows()) + 7 + """ + return self._java_matrix_wrapper.call("numRows") + + def numCols(self): + """ + Get or compute the number of cols. + + >>> entries = sc.parallelize([MatrixEntry(0, 0, 1.2), + ... MatrixEntry(1, 0, 2), + ... MatrixEntry(2, 1, 3.7)]) + + >>> mat = CoordinateMatrix(entries) + >>> print(mat.numCols()) + 2 + + >>> mat = CoordinateMatrix(entries, 7, 6) + >>> print(mat.numCols()) + 6 + """ + return self._java_matrix_wrapper.call("numCols") + + def toRowMatrix(self): + """ + Convert this matrix to a RowMatrix. + + >>> entries = sc.parallelize([MatrixEntry(0, 0, 1.2), + ... MatrixEntry(6, 4, 2.1)]) + >>> mat = CoordinateMatrix(entries).toRowMatrix() + + >>> # This CoordinateMatrix will have 7 effective rows, due to + >>> # the highest row index being 6, but the ensuing RowMatrix + >>> # will only have 2 rows since there are only entries on 2 + >>> # unique rows. + >>> print(mat.numRows()) + 2 + + >>> # This CoordinateMatrix will have 5 columns, due to the + >>> # highest column index being 4, and the ensuing RowMatrix + >>> # will have 5 columns as well. 
+ >>> print(mat.numCols()) + 5 + """ + java_row_matrix = self._java_matrix_wrapper.call("toRowMatrix") + return RowMatrix(java_row_matrix) + + def toIndexedRowMatrix(self): + """ + Convert this matrix to an IndexedRowMatrix. + + >>> entries = sc.parallelize([MatrixEntry(0, 0, 1.2), + ... MatrixEntry(6, 4, 2.1)]) + >>> mat = CoordinateMatrix(entries).toIndexedRowMatrix() + + >>> # This CoordinateMatrix will have 7 effective rows, due to + >>> # the highest row index being 6, and the ensuing + >>> # IndexedRowMatrix will have 7 rows as well. + >>> print(mat.numRows()) + 7 + + >>> # This CoordinateMatrix will have 5 columns, due to the + >>> # highest column index being 4, and the ensuing + >>> # IndexedRowMatrix will have 5 columns as well. + >>> print(mat.numCols()) + 5 + """ + java_indexed_row_matrix = self._java_matrix_wrapper.call("toIndexedRowMatrix") + return IndexedRowMatrix(java_indexed_row_matrix) + + def toBlockMatrix(self, rowsPerBlock=1024, colsPerBlock=1024): + """ + Convert this matrix to a BlockMatrix. + + :param rowsPerBlock: Number of rows that make up each block. + The blocks forming the final rows are not + required to have the given number of rows. + :param colsPerBlock: Number of columns that make up each block. + The blocks forming the final columns are not + required to have the given number of columns. + + >>> entries = sc.parallelize([MatrixEntry(0, 0, 1.2), + ... MatrixEntry(6, 4, 2.1)]) + >>> mat = CoordinateMatrix(entries).toBlockMatrix() + + >>> # This CoordinateMatrix will have 7 effective rows, due to + >>> # the highest row index being 6, and the ensuing + >>> # BlockMatrix will have 7 rows as well. + >>> print(mat.numRows()) + 7 + + >>> # This CoordinateMatrix will have 5 columns, due to the + >>> # highest column index being 4, and the ensuing + >>> # BlockMatrix will have 5 columns as well. + >>> print(mat.numCols()) + 5 + """ + java_block_matrix = self._java_matrix_wrapper.call("toBlockMatrix", + rowsPerBlock, + colsPerBlock) + return BlockMatrix(java_block_matrix, rowsPerBlock, colsPerBlock) + + +def _convert_to_matrix_block_tuple(block): + if (isinstance(block, tuple) and len(block) == 2 + and isinstance(block[0], tuple) and len(block[0]) == 2 + and isinstance(block[1], Matrix)): + blockRowIndex = int(block[0][0]) + blockColIndex = int(block[0][1]) + subMatrix = block[1] + return ((blockRowIndex, blockColIndex), subMatrix) + else: + raise TypeError("Cannot convert type %s into a sub-matrix block tuple" % type(block)) + + +class BlockMatrix(DistributedMatrix): + """ + .. note:: Experimental + + Represents a distributed matrix in blocks of local matrices. + + :param blocks: An RDD of sub-matrix blocks + ((blockRowIndex, blockColIndex), sub-matrix) that + form this distributed matrix. If multiple blocks + with the same index exist, the results for + operations like add and multiply will be + unpredictable. + :param rowsPerBlock: Number of rows that make up each block. + The blocks forming the final rows are not + required to have the given number of rows. + :param colsPerBlock: Number of columns that make up each block. + The blocks forming the final columns are not + required to have the given number of columns. + :param numRows: Number of rows of this matrix. If the supplied + value is less than or equal to zero, the number + of rows will be calculated when `numRows` is + invoked. + :param numCols: Number of columns of this matrix. 
If the supplied + value is less than or equal to zero, the number + of columns will be calculated when `numCols` is + invoked. + """ + def __init__(self, blocks, rowsPerBlock, colsPerBlock, numRows=0, numCols=0): + """ + Note: This docstring is not shown publicly. + + Create a wrapper over a Java BlockMatrix. + + Publicly, we require that `blocks` be an RDD. However, for + internal usage, `blocks` can also be a Java BlockMatrix + object, in which case we can wrap it directly. This + assists in clean matrix conversions. + + >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + >>> mat = BlockMatrix(blocks, 3, 2) + + >>> mat_diff = BlockMatrix(blocks, 3, 2) + >>> (mat_diff._java_matrix_wrapper._java_model == + ... mat._java_matrix_wrapper._java_model) + False + + >>> mat_same = BlockMatrix(mat._java_matrix_wrapper._java_model, 3, 2) + >>> (mat_same._java_matrix_wrapper._java_model == + ... mat._java_matrix_wrapper._java_model) + True + """ + if isinstance(blocks, RDD): + blocks = blocks.map(_convert_to_matrix_block_tuple) + # We use DataFrames for serialization of sub-matrix blocks + # from Python, so first convert the RDD to a DataFrame on + # this side. This will convert each sub-matrix block + # tuple to a Row containing the 'blockRowIndex', + # 'blockColIndex', and 'subMatrix' values, which can + # each be easily serialized. We will convert back to + # ((blockRowIndex, blockColIndex), sub-matrix) tuples on + # the Scala side. + java_matrix = callMLlibFunc("createBlockMatrix", blocks.toDF(), + int(rowsPerBlock), int(colsPerBlock), + long(numRows), long(numCols)) + elif (isinstance(blocks, JavaObject) + and blocks.getClass().getSimpleName() == "BlockMatrix"): + java_matrix = blocks + else: + raise TypeError("blocks should be an RDD of sub-matrix blocks as " + "((int, int), matrix) tuples, got %s" % type(blocks)) + + self._java_matrix_wrapper = JavaModelWrapper(java_matrix) + + @property + def blocks(self): + """ + The RDD of sub-matrix blocks + ((blockRowIndex, blockColIndex), sub-matrix) that form this + distributed matrix. + + >>> mat = BlockMatrix( + ... sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]), 3, 2) + >>> blocks = mat.blocks + >>> blocks.first() + ((0, 0), DenseMatrix(3, 2, [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 0)) + + """ + # We use DataFrames for serialization of sub-matrix blocks + # from Java, so we first convert the RDD of blocks to a + # DataFrame on the Scala/Java side. Then we map each Row in + # the DataFrame back to a sub-matrix block on this side. + blocks_df = callMLlibFunc("getMatrixBlocks", self._java_matrix_wrapper._java_model) + blocks = blocks_df.map(lambda row: ((row[0][0], row[0][1]), row[1])) + return blocks + + @property + def rowsPerBlock(self): + """ + Number of rows that make up each block. + + >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + >>> mat = BlockMatrix(blocks, 3, 2) + >>> mat.rowsPerBlock + 3 + """ + return self._java_matrix_wrapper.call("rowsPerBlock") + + @property + def colsPerBlock(self): + """ + Number of columns that make up each block. + + >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ... 
((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + >>> mat = BlockMatrix(blocks, 3, 2) + >>> mat.colsPerBlock + 2 + """ + return self._java_matrix_wrapper.call("colsPerBlock") + + @property + def numRowBlocks(self): + """ + Number of rows of blocks in the BlockMatrix. + + >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + >>> mat = BlockMatrix(blocks, 3, 2) + >>> mat.numRowBlocks + 2 + """ + return self._java_matrix_wrapper.call("numRowBlocks") + + @property + def numColBlocks(self): + """ + Number of columns of blocks in the BlockMatrix. + + >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + >>> mat = BlockMatrix(blocks, 3, 2) + >>> mat.numColBlocks + 1 + """ + return self._java_matrix_wrapper.call("numColBlocks") + + def numRows(self): + """ + Get or compute the number of rows. + + >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + + >>> mat = BlockMatrix(blocks, 3, 2) + >>> print(mat.numRows()) + 6 + + >>> mat = BlockMatrix(blocks, 3, 2, 7, 6) + >>> print(mat.numRows()) + 7 + """ + return self._java_matrix_wrapper.call("numRows") + + def numCols(self): + """ + Get or compute the number of cols. + + >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + + >>> mat = BlockMatrix(blocks, 3, 2) + >>> print(mat.numCols()) + 2 + + >>> mat = BlockMatrix(blocks, 3, 2, 7, 6) + >>> print(mat.numCols()) + 6 + """ + return self._java_matrix_wrapper.call("numCols") + + def add(self, other): + """ + Adds two block matrices together. The matrices must have the + same size and matching `rowsPerBlock` and `colsPerBlock` values. + If one of the sub matrix blocks that are being added is a + SparseMatrix, the resulting sub matrix block will also be a + SparseMatrix, even if it is being added to a DenseMatrix. If + two dense sub matrix blocks are added, the output block will + also be a DenseMatrix. + + >>> dm1 = Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6]) + >>> dm2 = Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]) + >>> sm = Matrices.sparse(3, 2, [0, 1, 3], [0, 1, 2], [7, 11, 12]) + >>> blocks1 = sc.parallelize([((0, 0), dm1), ((1, 0), dm2)]) + >>> blocks2 = sc.parallelize([((0, 0), dm1), ((1, 0), dm2)]) + >>> blocks3 = sc.parallelize([((0, 0), sm), ((1, 0), dm2)]) + >>> mat1 = BlockMatrix(blocks1, 3, 2) + >>> mat2 = BlockMatrix(blocks2, 3, 2) + >>> mat3 = BlockMatrix(blocks3, 3, 2) + + >>> mat1.add(mat2).toLocalMatrix() + DenseMatrix(6, 2, [2.0, 4.0, 6.0, 14.0, 16.0, 18.0, 8.0, 10.0, 12.0, 20.0, 22.0, 24.0], 0) + + >>> mat1.add(mat3).toLocalMatrix() + DenseMatrix(6, 2, [8.0, 2.0, 3.0, 14.0, 16.0, 18.0, 4.0, 16.0, 18.0, 20.0, 22.0, 24.0], 0) + """ + if not isinstance(other, BlockMatrix): + raise TypeError("Other should be a BlockMatrix, got %s" % type(other)) + + other_java_block_matrix = other._java_matrix_wrapper._java_model + java_block_matrix = self._java_matrix_wrapper.call("add", other_java_block_matrix) + return BlockMatrix(java_block_matrix, self.rowsPerBlock, self.colsPerBlock) + + def multiply(self, other): + """ + Left multiplies this BlockMatrix by `other`, another + BlockMatrix. The `colsPerBlock` of this matrix must equal the + `rowsPerBlock` of `other`. 
If `other` contains any SparseMatrix + blocks, they will have to be converted to DenseMatrix blocks. + The output BlockMatrix will only consist of DenseMatrix blocks. + This may cause some performance issues until support for + multiplying two sparse matrices is added. + + >>> dm1 = Matrices.dense(2, 3, [1, 2, 3, 4, 5, 6]) + >>> dm2 = Matrices.dense(2, 3, [7, 8, 9, 10, 11, 12]) + >>> dm3 = Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6]) + >>> dm4 = Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]) + >>> sm = Matrices.sparse(3, 2, [0, 1, 3], [0, 1, 2], [7, 11, 12]) + >>> blocks1 = sc.parallelize([((0, 0), dm1), ((0, 1), dm2)]) + >>> blocks2 = sc.parallelize([((0, 0), dm3), ((1, 0), dm4)]) + >>> blocks3 = sc.parallelize([((0, 0), sm), ((1, 0), dm4)]) + >>> mat1 = BlockMatrix(blocks1, 2, 3) + >>> mat2 = BlockMatrix(blocks2, 3, 2) + >>> mat3 = BlockMatrix(blocks3, 3, 2) + + >>> mat1.multiply(mat2).toLocalMatrix() + DenseMatrix(2, 2, [242.0, 272.0, 350.0, 398.0], 0) + + >>> mat1.multiply(mat3).toLocalMatrix() + DenseMatrix(2, 2, [227.0, 258.0, 394.0, 450.0], 0) + """ + if not isinstance(other, BlockMatrix): + raise TypeError("Other should be a BlockMatrix, got %s" % type(other)) + + other_java_block_matrix = other._java_matrix_wrapper._java_model + java_block_matrix = self._java_matrix_wrapper.call("multiply", other_java_block_matrix) + return BlockMatrix(java_block_matrix, self.rowsPerBlock, self.colsPerBlock) + + def toLocalMatrix(self): + """ + Collect the distributed matrix on the driver as a DenseMatrix. + + >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + >>> mat = BlockMatrix(blocks, 3, 2).toLocalMatrix() + + >>> # This BlockMatrix will have 6 effective rows, due to + >>> # having two sub-matrix blocks stacked, each with 3 rows. + >>> # The ensuing DenseMatrix will also have 6 rows. + >>> print(mat.numRows) + 6 + + >>> # This BlockMatrix will have 2 effective columns, due to + >>> # having two sub-matrix blocks stacked, each with 2 + >>> # columns. The ensuing DenseMatrix will also have 2 columns. + >>> print(mat.numCols) + 2 + """ + return self._java_matrix_wrapper.call("toLocalMatrix") + + def toIndexedRowMatrix(self): + """ + Convert this matrix to an IndexedRowMatrix. + + >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + >>> mat = BlockMatrix(blocks, 3, 2).toIndexedRowMatrix() + + >>> # This BlockMatrix will have 6 effective rows, due to + >>> # having two sub-matrix blocks stacked, each with 3 rows. + >>> # The ensuing IndexedRowMatrix will also have 6 rows. + >>> print(mat.numRows()) + 6 + + >>> # This BlockMatrix will have 2 effective columns, due to + >>> # having two sub-matrix blocks stacked, each with 2 columns. + >>> # The ensuing IndexedRowMatrix will also have 2 columns. + >>> print(mat.numCols()) + 2 + """ + java_indexed_row_matrix = self._java_matrix_wrapper.call("toIndexedRowMatrix") + return IndexedRowMatrix(java_indexed_row_matrix) + + def toCoordinateMatrix(self): + """ + Convert this matrix to a CoordinateMatrix. + + >>> blocks = sc.parallelize([((0, 0), Matrices.dense(1, 2, [1, 2])), + ... 
((1, 0), Matrices.dense(1, 2, [7, 8]))]) + >>> mat = BlockMatrix(blocks, 1, 2).toCoordinateMatrix() + >>> mat.entries.take(3) + [MatrixEntry(0, 0, 1.0), MatrixEntry(0, 1, 2.0), MatrixEntry(1, 0, 7.0)] + """ + java_coordinate_matrix = self._java_matrix_wrapper.call("toCoordinateMatrix") + return CoordinateMatrix(java_coordinate_matrix) + + +def _test(): + import doctest + from pyspark import SparkContext + from pyspark.sql import SQLContext + from pyspark.mllib.linalg import Matrices + import pyspark.mllib.linalg.distributed + globs = pyspark.mllib.linalg.distributed.__dict__.copy() + globs['sc'] = SparkContext('local[2]', 'PythonTest', batchSize=2) + globs['sqlContext'] = SQLContext(globs['sc']) + globs['Matrices'] = Matrices + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py index 06fbc0eb6aef..6a3c643b6641 100644 --- a/python/pyspark/mllib/random.py +++ b/python/pyspark/mllib/random.py @@ -21,6 +21,7 @@ from functools import wraps +from pyspark import since from pyspark.mllib.common import callMLlibFunc @@ -39,9 +40,12 @@ class RandomRDDs(object): """ Generator methods for creating RDDs comprised of i.i.d samples from some distribution. + + .. versionadded:: 1.1.0 """ @staticmethod + @since("1.1.0") def uniformRDD(sc, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the @@ -72,6 +76,7 @@ def uniformRDD(sc, size, numPartitions=None, seed=None): return callMLlibFunc("uniformRDD", sc._jsc, size, numPartitions, seed) @staticmethod + @since("1.1.0") def normalRDD(sc, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the standard normal @@ -100,6 +105,7 @@ def normalRDD(sc, size, numPartitions=None, seed=None): return callMLlibFunc("normalRDD", sc._jsc, size, numPartitions, seed) @staticmethod + @since("1.3.0") def logNormalRDD(sc, mean, std, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the log normal @@ -132,6 +138,7 @@ def logNormalRDD(sc, mean, std, size, numPartitions=None, seed=None): size, numPartitions, seed) @staticmethod + @since("1.1.0") def poissonRDD(sc, mean, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the Poisson @@ -158,6 +165,7 @@ def poissonRDD(sc, mean, size, numPartitions=None, seed=None): return callMLlibFunc("poissonRDD", sc._jsc, float(mean), size, numPartitions, seed) @staticmethod + @since("1.3.0") def exponentialRDD(sc, mean, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the Exponential @@ -184,6 +192,7 @@ def exponentialRDD(sc, mean, size, numPartitions=None, seed=None): return callMLlibFunc("exponentialRDD", sc._jsc, float(mean), size, numPartitions, seed) @staticmethod + @since("1.3.0") def gammaRDD(sc, shape, scale, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the Gamma @@ -216,6 +225,7 @@ def gammaRDD(sc, shape, scale, size, numPartitions=None, seed=None): @staticmethod @toArray + @since("1.1.0") def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): """ Generates an RDD comprised of vectors containing i.i.d. 
samples drawn @@ -241,6 +251,7 @@ def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): @staticmethod @toArray + @since("1.1.0") def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): """ Generates an RDD comprised of vectors containing i.i.d. samples drawn @@ -266,6 +277,7 @@ def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): @staticmethod @toArray + @since("1.3.0") def logNormalVectorRDD(sc, mean, std, numRows, numCols, numPartitions=None, seed=None): """ Generates an RDD comprised of vectors containing i.i.d. samples drawn @@ -300,6 +312,7 @@ def logNormalVectorRDD(sc, mean, std, numRows, numCols, numPartitions=None, seed @staticmethod @toArray + @since("1.1.0") def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): """ Generates an RDD comprised of vectors containing i.i.d. samples drawn @@ -330,6 +343,7 @@ def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): @staticmethod @toArray + @since("1.3.0") def exponentialVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): """ Generates an RDD comprised of vectors containing i.i.d. samples drawn @@ -360,6 +374,7 @@ def exponentialVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=No @staticmethod @toArray + @since("1.3.0") def gammaVectorRDD(sc, shape, scale, numRows, numCols, numPartitions=None, seed=None): """ Generates an RDD comprised of vectors containing i.i.d. samples drawn diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 506ca2151cce..93e47a797f49 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -18,7 +18,7 @@ import array from collections import namedtuple -from pyspark import SparkContext +from pyspark import SparkContext, since from pyspark.rdd import RDD from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc from pyspark.mllib.util import JavaLoader, JavaSaveable @@ -36,6 +36,8 @@ class Rating(namedtuple("Rating", ["user", "product", "rating"])): (1, 2, 5.0) >>> (r[0], r[1], r[2]) (1, 2, 5.0) + + .. versionadded:: 1.2.0 """ def __reduce__(self): @@ -74,25 +76,37 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader): >>> first_user = model.userFeatures().take(1)[0] >>> latents = first_user[1] - >>> len(latents) == 4 - True + >>> len(latents) + 4 >>> model.productFeatures().collect() [(1, array('d', [...])), (2, array('d', [...]))] >>> first_product = model.productFeatures().take(1)[0] >>> latents = first_product[1] - >>> len(latents) == 4 - True + >>> len(latents) + 4 + + >>> products_for_users = model.recommendProductsForUsers(1).collect() + >>> len(products_for_users) + 2 + >>> products_for_users[0] + (1, (Rating(user=1, product=2, rating=...),)) + + >>> users_for_products = model.recommendUsersForProducts(1).collect() + >>> len(users_for_products) + 2 + >>> users_for_products[0] + (1, (Rating(user=2, product=1, rating=...),)) >>> model = ALS.train(ratings, 1, nonnegative=True, seed=10) >>> model.predict(2, 2) - 3.8... + 3.73... >>> df = sqlContext.createDataFrame([Rating(1, 1, 1.0), Rating(1, 2, 2.0), Rating(2, 1, 2.0)]) >>> model = ALS.train(df, 1, nonnegative=True, seed=10) >>> model.predict(2, 2) - 3.8... + 3.73... >>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10) >>> model.predict(2, 2) @@ -111,13 +125,17 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader): ... rmtree(path) ... 
except OSError: ... pass + + .. versionadded:: 0.9.0 """ + @since("0.9.0") def predict(self, user, product): """ Predicts rating for the given user and product. """ return self._java_model.predict(int(user), int(product)) + @since("0.9.0") def predictAll(self, user_product): """ Returns a list of predicted ratings for input user and product pairs. @@ -128,6 +146,7 @@ def predictAll(self, user_product): user_product = user_product.map(lambda u_p: (int(u_p[0]), int(u_p[1]))) return self.call("predict", user_product) + @since("1.2.0") def userFeatures(self): """ Returns a paired RDD, where the first element is the user and the @@ -135,6 +154,7 @@ def userFeatures(self): """ return self.call("getUserFeatures").mapValues(lambda v: array.array('d', v)) + @since("1.2.0") def productFeatures(self): """ Returns a paired RDD, where the first element is the product and the @@ -142,6 +162,7 @@ def productFeatures(self): """ return self.call("getProductFeatures").mapValues(lambda v: array.array('d', v)) + @since("1.4.0") def recommendUsers(self, product, num): """ Recommends the top "num" number of users for a given product and returns a list @@ -149,6 +170,7 @@ def recommendUsers(self, product, num): """ return list(self.call("recommendUsers", product, num)) + @since("1.4.0") def recommendProducts(self, user, num): """ Recommends the top "num" number of products for a given user and returns a list @@ -156,18 +178,38 @@ def recommendProducts(self, user, num): """ return list(self.call("recommendProducts", user, num)) + def recommendProductsForUsers(self, num): + """ + Recommends top "num" products for all users. The number returned may be less than this. + """ + return self.call("wrappedRecommendProductsForUsers", num) + + def recommendUsersForProducts(self, num): + """ + Recommends top "num" users for all products. The number returned may be less than this. + """ + return self.call("wrappedRecommendUsersForProducts", num) + @property + @since("1.4.0") def rank(self): + """Rank for the features in this model""" return self.call("rank") @classmethod + @since("1.3.1") def load(cls, sc, path): + """Load a model from the given path""" model = cls._load_java(sc, path) wrapper = sc._jvm.MatrixFactorizationModelWrapper(model) return MatrixFactorizationModel(wrapper) class ALS(object): + """Alternating Least Squares matrix factorization + + .. versionadded:: 0.9.0 + """ @classmethod def _prepare(cls, ratings): @@ -188,15 +230,31 @@ def _prepare(cls, ratings): return ratings @classmethod + @since("0.9.0") def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, nonnegative=False, seed=None): + """ + Train a matrix factorization model given an RDD of ratings given by users to some products, + in the form of (userID, productID, rating) pairs. We approximate the ratings matrix as the + product of two lower-rank matrices of a given rank (number of features). To solve for these + features, we run a given number of iterations of ALS. This is done using a level of + parallelism given by `blocks`. + """ model = callMLlibFunc("trainALSModel", cls._prepare(ratings), rank, iterations, lambda_, blocks, nonnegative, seed) return MatrixFactorizationModel(model) @classmethod + @since("0.9.0") def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01, nonnegative=False, seed=None): + """ + Train a matrix factorization model given an RDD of 'implicit preferences' given by users + to some products, in the form of (userID, productID, preference) pairs. 
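The doctests above introduce the newly wrapped recommendProductsForUsers and recommendUsersForProducts calls. As a standalone illustration, a minimal sketch of driving them outside a doctest might look as follows; the app name and toy ratings are assumptions for illustration, not part of the patch.

from pyspark import SparkContext
from pyspark.mllib.recommendation import ALS, Rating

sc = SparkContext("local[2]", "ALSRecommendSketch")

# Toy ratings in the same style as the doctest data above.
ratings = sc.parallelize([Rating(1, 1, 1.0), Rating(1, 2, 2.0), Rating(2, 1, 2.0)])
model = ALS.train(ratings, rank=1, seed=10)

# Top-1 product per user: an RDD of (userID, (Rating, ...)) pairs.
for user, recs in model.recommendProductsForUsers(1).collect():
    print(user, recs)

# Symmetric call: top-1 user per product.
for product, recs in model.recommendUsersForProducts(1).collect():
    print(product, recs)

sc.stop()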
We approximate the + ratings matrix as the product of two lower-rank matrices of a given rank (number of + features). To solve for these features, we run a given number of iterations of ALS. + This is done using a level of parallelism given by `blocks`. + """ model = callMLlibFunc("trainImplicitALSModel", cls._prepare(ratings), rank, iterations, lambda_, blocks, alpha, nonnegative, seed) return MatrixFactorizationModel(model) diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 5b7afc15ddfb..13b3397501c0 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -18,7 +18,7 @@ import numpy as np from numpy import array -from pyspark import RDD +from pyspark import RDD, since from pyspark.streaming.dstream import DStream from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py, inherit_doc from pyspark.mllib.linalg import SparseVector, Vectors, _convert_to_vector @@ -28,7 +28,8 @@ 'LinearRegressionModel', 'LinearRegressionWithSGD', 'RidgeRegressionModel', 'RidgeRegressionWithSGD', 'LassoModel', 'LassoWithSGD', 'IsotonicRegressionModel', - 'IsotonicRegression'] + 'IsotonicRegression', 'StreamingLinearAlgorithm', + 'StreamingLinearRegressionWithSGD'] class LabeledPoint(object): @@ -42,6 +43,8 @@ class LabeledPoint(object): column matrix) Note: 'label' and 'features' are accessible as class attributes. + + .. versionadded:: 1.0.0 """ def __init__(self, label, features): @@ -65,6 +68,8 @@ class LinearModel(object): :param weights: Weights computed for every feature. :param intercept: Intercept computed for this model. + + .. versionadded:: 0.9.0 """ def __init__(self, weights, intercept): @@ -72,11 +77,15 @@ def __init__(self, weights, intercept): self._intercept = float(intercept) @property + @since("1.0.0") def weights(self): + """Weights computed for every feature.""" return self._coeff @property + @since("1.0.0") def intercept(self): + """Intercept computed for this model.""" return self._intercept def __repr__(self): @@ -93,8 +102,11 @@ class LinearRegressionModelBase(LinearModel): True >>> abs(lrmb.predict(SparseVector(2, {0: -1.03, 1: 7.777})) - 14.624) < 1e-6 True + + .. versionadded:: 0.9.0 """ + @since("0.9.0") def predict(self, x): """ Predict the value of the dependent variable given a vector or @@ -162,14 +174,20 @@ class LinearRegressionModel(LinearRegressionModelBase): True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + + .. versionadded:: 0.9.0 """ + @since("1.4.0") def save(self, sc, path): + """Save a LinearRegressionModel.""" java_model = sc._jvm.org.apache.spark.mllib.regression.LinearRegressionModel( _py2java(sc, self._coeff), self.intercept) java_model.save(sc._jsc.sc(), path) @classmethod + @since("1.4.0") def load(cls, sc, path): + """Load a LinearRegressionModel.""" java_model = sc._jvm.org.apache.spark.mllib.regression.LinearRegressionModel.load( sc._jsc.sc(), path) weights = _java2py(sc, java_model.weights()) @@ -198,17 +216,31 @@ def _regression_train_wrapper(train_func, modelClass, data, initial_weights): class LinearRegressionWithSGD(object): + """ + Train a linear regression model with no regularization using Stochastic Gradient Descent. + This solves the least squares regression formulation + f(weights) = 1/n ||A weights-y||^2^ + (which is the mean squared error). + Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with + its corresponding right hand side label y. + See also the documentation for the precise formulation. + + .. 
versionadded:: 0.9.0 + """ @classmethod + @since("0.9.0") def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, initialWeights=None, regParam=0.0, regType=None, intercept=False, - validateData=True): + validateData=True, convergenceTol=0.001): """ Train a linear regression model using Stochastic Gradient Descent (SGD). This solves the least squares regression formulation - f(weights) = 1/n ||A weights-y||^2^ - (which is the mean squared error). + + f(weights) = 1/(2n) ||A weights - y||^2, + + which is the mean squared error. Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with its corresponding right hand side label y. See also the documentation for the precise formulation. @@ -242,11 +274,14 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, :param validateData: Boolean parameter which indicates if the algorithm should validate data before training. (default: True) + :param convergenceTol: A condition which decides iteration termination. + (default: 0.001) """ def train(rdd, i): return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, int(iterations), float(step), float(miniBatchFraction), i, float(regParam), - regType, bool(intercept), bool(validateData)) + regType, bool(intercept), bool(validateData), + float(convergenceTol)) return _regression_train_wrapper(train, LinearRegressionModel, data, initialWeights) @@ -307,14 +342,20 @@ class LassoModel(LinearRegressionModelBase): True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + + .. versionadded:: 0.9.0 """ + @since("1.4.0") def save(self, sc, path): + """Save a LassoModel.""" java_model = sc._jvm.org.apache.spark.mllib.regression.LassoModel( _py2java(sc, self._coeff), self.intercept) java_model.save(sc._jsc.sc(), path) @classmethod + @since("1.4.0") def load(cls, sc, path): + """Load a LassoModel.""" java_model = sc._jvm.org.apache.spark.mllib.regression.LassoModel.load( sc._jsc.sc(), path) weights = _java2py(sc, java_model.weights()) @@ -324,17 +365,30 @@ def load(cls, sc, path): class LassoWithSGD(object): + """ + Train a regression model with L1-regularization using Stochastic Gradient Descent. + This solves the l1-regularized least squares regression formulation + f(weights) = 1/2n ||A weights-y||^2^ + regParam ||weights||_1 + Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with + its corresponding right hand side label y. + See also the documentation for the precise formulation. + + .. versionadded:: 0.9.0 + """ @classmethod + @since("0.9.0") def train(cls, data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None, intercept=False, - validateData=True): + validateData=True, convergenceTol=0.001): """ Train a regression model with L1-regularization using Stochastic Gradient Descent. This solves the l1-regularized least squares regression formulation - f(weights) = 1/2n ||A weights-y||^2^ + regParam ||weights||_1 + + f(weights) = 1/(2n) ||A weights - y||^2 + regParam ||weights||_1. + Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with its corresponding right hand side label y. See also the documentation for the precise formulation. @@ -358,11 +412,13 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, :param validateData: Boolean parameter which indicates if the algorithm should validate data before training. (default: True) + :param convergenceTol: A condition which decides iteration termination. 
+ (default: 0.001) """ def train(rdd, i): return callMLlibFunc("trainLassoModelWithSGD", rdd, int(iterations), float(step), float(regParam), float(miniBatchFraction), i, bool(intercept), - bool(validateData)) + bool(validateData), float(convergenceTol)) return _regression_train_wrapper(train, LassoModel, data, initialWeights) @@ -424,14 +480,20 @@ class RidgeRegressionModel(LinearRegressionModelBase): True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + + .. versionadded:: 0.9.0 """ + @since("1.4.0") def save(self, sc, path): + """Save a RidgeRegressionMode.""" java_model = sc._jvm.org.apache.spark.mllib.regression.RidgeRegressionModel( _py2java(sc, self._coeff), self.intercept) java_model.save(sc._jsc.sc(), path) @classmethod + @since("1.4.0") def load(cls, sc, path): + """Load a RidgeRegressionMode.""" java_model = sc._jvm.org.apache.spark.mllib.regression.RidgeRegressionModel.load( sc._jsc.sc(), path) weights = _java2py(sc, java_model.weights()) @@ -441,17 +503,30 @@ def load(cls, sc, path): class RidgeRegressionWithSGD(object): + """ + Train a regression model with L2-regularization using Stochastic Gradient Descent. + This solves the l2-regularized least squares regression formulation + f(weights) = 1/2n ||A weights-y||^2^ + regParam/2 ||weights||^2^ + Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with + its corresponding right hand side label y. + See also the documentation for the precise formulation. + + .. versionadded:: 0.9.0 + """ @classmethod + @since("0.9.0") def train(cls, data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None, intercept=False, - validateData=True): + validateData=True, convergenceTol=0.001): """ Train a regression model with L2-regularization using Stochastic Gradient Descent. This solves the l2-regularized least squares regression formulation - f(weights) = 1/2n ||A weights-y||^2^ + regParam/2 ||weights||^2^ + + f(weights) = 1/(2n) ||A weights - y||^2 + regParam/2 ||weights||^2. + Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with its corresponding right hand side label y. See also the documentation for the precise formulation. @@ -475,11 +550,13 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, :param validateData: Boolean parameter which indicates if the algorithm should validate data before training. (default: True) + :param convergenceTol: A condition which decides iteration termination. + (default: 0.001) """ def train(rdd, i): return callMLlibFunc("trainRidgeModelWithSGD", rdd, int(iterations), float(step), float(regParam), float(miniBatchFraction), i, bool(intercept), - bool(validateData)) + bool(validateData), float(convergenceTol)) return _regression_train_wrapper(train, RidgeRegressionModel, data, initialWeights) @@ -517,6 +594,8 @@ class IsotonicRegressionModel(Saveable, Loader): ... rmtree(path) ... except OSError: ... pass + + .. versionadded:: 1.4.0 """ def __init__(self, boundaries, predictions, isotonic): @@ -524,6 +603,7 @@ def __init__(self, boundaries, predictions, isotonic): self.predictions = predictions self.isotonic = isotonic + @since("1.4.0") def predict(self, x): """ Predict labels for provided features. 
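Each of the SGD trainers touched above (linear, Lasso, Ridge) gains a convergenceTol keyword in this patch. A minimal sketch of passing it explicitly; the dataset and parameter values here are made up for illustration.

from pyspark import SparkContext
from pyspark.mllib.regression import LabeledPoint, RidgeRegressionWithSGD

sc = SparkContext("local[2]", "ConvergenceTolSketch")

data = sc.parallelize([
    LabeledPoint(0.0, [0.0]),
    LabeledPoint(1.0, [1.0]),
    LabeledPoint(3.0, [2.0]),
    LabeledPoint(2.0, [3.0]),
])

# convergenceTol lets SGD stop before `iterations` once successive
# solutions stop improving by more than the given tolerance.
model = RidgeRegressionWithSGD.train(data, iterations=100, step=1.0,
                                     regParam=0.01, convergenceTol=1e-3)
print(model.weights, model.intercept)

sc.stop()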
@@ -548,7 +628,9 @@ def predict(self, x): return x.map(lambda v: self.predict(v)) return np.interp(x, self.boundaries, self.predictions) + @since("1.4.0") def save(self, sc, path): + """Save a IsotonicRegressionModel.""" java_boundaries = _py2java(sc, self.boundaries.tolist()) java_predictions = _py2java(sc, self.predictions.tolist()) java_model = sc._jvm.org.apache.spark.mllib.regression.IsotonicRegressionModel( @@ -556,7 +638,9 @@ def save(self, sc, path): java_model.save(sc._jsc.sc(), path) @classmethod + @since("1.4.0") def load(cls, sc, path): + """Load a IsotonicRegressionModel.""" java_model = sc._jvm.org.apache.spark.mllib.regression.IsotonicRegressionModel.load( sc._jsc.sc(), path) py_boundaries = _java2py(sc, java_model.boundaryVector()).toArray() @@ -565,8 +649,29 @@ def load(cls, sc, path): class IsotonicRegression(object): + """ + Isotonic regression. + Currently implemented using parallelized pool adjacent violators algorithm. + Only univariate (single feature) algorithm supported. + + Sequential PAV implementation based on: + Tibshirani, Ryan J., Holger Hoefling, and Robert Tibshirani. + "Nearly-isotonic regression." Technometrics 53.1 (2011): 54-61. + Available from [[http://www.stat.cmu.edu/~ryantibs/papers/neariso.pdf]] + + Sequential PAV parallelization based on: + Kearsley, Anthony J., Richard A. Tapia, and Michael W. Trosset. + "An approach to parallelizing isotonic regression." + Applied Mathematics and Parallel Computing. Physica-Verlag HD, 1996. 141-147. + Available from [[http://softlib.rice.edu/pub/CRPC-TRs/reports/CRPC-TR96640.pdf]] + + @see [[http://en.wikipedia.org/wiki/Isotonic_regression Isotonic regression (Wikipedia)]] + + .. versionadded:: 1.4.0 + """ @classmethod + @since("1.4.0") def train(cls, data, isotonic=True): """ Train a isotonic regression model on the given data. @@ -584,10 +689,13 @@ class StreamingLinearAlgorithm(object): Base class that has to be inherited by any StreamingLinearAlgorithm. Prevents reimplementation of methods predictOn and predictOnValues. + + .. versionadded:: 1.5.0 """ def __init__(self, model): self._model = model + @since("1.5.0") def latestModel(self): """ Returns the latest model. @@ -602,6 +710,7 @@ def _validate(self, dstream): raise ValueError( "Model must be intialized using setInitialWeights") + @since("1.5.0") def predictOn(self, dstream): """ Make predictions on a dstream. @@ -611,6 +720,7 @@ def predictOn(self, dstream): self._validate(dstream) return dstream.map(lambda x: self._model.predict(x)) + @since("1.5.0") def predictOnValues(self, dstream): """ Make predictions on a keyed dstream. @@ -624,25 +734,40 @@ def predictOnValues(self, dstream): @inherit_doc class StreamingLinearRegressionWithSGD(StreamingLinearAlgorithm): """ - Run LinearRegression with SGD on a batch of data. - - The problem minimized is (1 / n_samples) * (y - weights'X)**2. - After training on a batch of data, the weights obtained at the end of - training are used as initial weights for the next batch. - - :param: stepSize Step size for each iteration of gradient descent. - :param: numIterations Total number of iterations run. - :param: miniBatchFraction Fraction of data on which SGD is run for each - iteration. + Train or predict a linear regression model on streaming data. Training uses + Stochastic Gradient Descent to update the model based on each new batch of + incoming data from a DStream (see `LinearRegressionWithSGD` for model equation). + + Each batch of data is assumed to be an RDD of LabeledPoints. 
+ The number of data points per batch can vary, but the number + of features must be constant. An initial weight + vector must be provided. + + :param stepSize: + Step size for each iteration of gradient descent. + (default: 0.1) + :param numIterations: + Number of iterations run for each batch of data. + (default: 50) + :param miniBatchFraction: + Fraction of each batch of data to use for updates. + (default: 1.0) + :param convergenceTol: + Value used to determine when to terminate iterations. + (default: 0.001) + + .. versionadded:: 1.5.0 """ - def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0): + def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0, convergenceTol=0.001): self.stepSize = stepSize self.numIterations = numIterations self.miniBatchFraction = miniBatchFraction + self.convergenceTol = convergenceTol self._model = None super(StreamingLinearRegressionWithSGD, self).__init__( model=self._model) + @since("1.5.0") def setInitialWeights(self, initialWeights): """ Set the initial value of weights. @@ -653,6 +778,7 @@ def setInitialWeights(self, initialWeights): self._model = LinearRegressionModel(initialWeights, 0) return self + @since("1.5.0") def trainOn(self, dstream): """Train the model on the incoming dstream.""" self._validate(dstream) @@ -663,7 +789,7 @@ def update(rdd): self._model = LinearRegressionWithSGD.train( rdd, self.numIterations, self.stepSize, self.miniBatchFraction, self._model.weights, - self._model.intercept) + intercept=self._model.intercept, convergenceTol=self.convergenceTol) dstream.foreachRDD(update) diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 3f5a02af12e3..f8e8e0e0adbe 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -31,6 +31,13 @@ from numpy import sum as array_sum from py4j.protocol import Py4JJavaError +try: + import xmlrunner +except ImportError: + xmlrunner = None + +if sys.version > '3': + basestring = str if sys.version_info[:2] <= (2, 6): try: @@ -86,9 +93,42 @@ def tearDown(self): self.ssc.stop(False) @staticmethod - def _ssc_wait(start_time, end_time, sleep_time): - while time() - start_time < end_time: + def _eventually(condition, timeout=30.0, catch_assertions=False): + """ + Wait a given amount of time for a condition to pass, else fail with an error. + This is a helper utility for streaming ML tests. + :param condition: Function that checks for termination conditions. + condition() can return: + - True: Conditions met. Return without error. + - other value: Conditions not met yet. Continue. Upon timeout, + include last such value in error message. + Note that this method may be called at any time during + streaming execution (e.g., even before any results + have been created). + :param timeout: Number of seconds to wait. Default 30 seconds. + :param catch_assertions: If False (default), do not catch AssertionErrors. + If True, catch AssertionErrors; continue, but save + error to throw upon timeout. 
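The reworked StreamingLinearRegressionWithSGD docstring above, together with the polling-style test helper described here, outlines the streaming workflow that the tests below exercise. A compact sketch of that workflow using a queue-backed toy stream; the batch interval, parameter values and data are illustrative assumptions only.

from pyspark import SparkContext
from pyspark.streaming import StreamingContext
from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD

sc = SparkContext("local[2]", "StreamingLRSketch")
ssc = StreamingContext(sc, 1)  # 1-second batch interval

# An initial weight vector must be set before training starts.
slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25,
                                       convergenceTol=1e-3)
slr.setInitialWeights([0.0, 0.0])

# Each streamed batch is an RDD of LabeledPoints with a fixed number of features.
batches = [sc.parallelize([LabeledPoint(2.0, [1.0, 1.0]),
                           LabeledPoint(4.0, [2.0, 2.0])])]
slr.trainOn(ssc.queueStream(batches))

ssc.start()
ssc.awaitTermination(5)  # give the batch time to be processed
print(slr.latestModel().weights)
ssc.stop(stopSparkContext=True, stopGraceFully=False)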
+ """ + start_time = time() + lastValue = None + while time() - start_time < timeout: + if catch_assertions: + try: + lastValue = condition() + except AssertionError as e: + lastValue = e + else: + lastValue = condition() + if lastValue is True: + return sleep(0.01) + if isinstance(lastValue, AssertionError): + raise lastValue + else: + raise AssertionError( + "Test failed due to timeout after %g sec, with last condition returning: %s" + % (timeout, lastValue)) def _squared_distance(a, b): @@ -130,13 +170,13 @@ def test_dot(self): [1., 2., 3., 4.], [1., 2., 3., 4.]]) arr = pyarray.array('d', [0, 1, 2, 3]) - self.assertEquals(10.0, sv.dot(dv)) + self.assertEqual(10.0, sv.dot(dv)) self.assertTrue(array_equal(array([3., 6., 9., 12.]), sv.dot(mat))) - self.assertEquals(30.0, dv.dot(dv)) + self.assertEqual(30.0, dv.dot(dv)) self.assertTrue(array_equal(array([10., 20., 30., 40.]), dv.dot(mat))) - self.assertEquals(30.0, lst.dot(dv)) + self.assertEqual(30.0, lst.dot(dv)) self.assertTrue(array_equal(array([10., 20., 30., 40.]), lst.dot(mat))) - self.assertEquals(7.0, sv.dot(arr)) + self.assertEqual(7.0, sv.dot(arr)) def test_squared_distance(self): sv = SparseVector(4, {1: 1, 3: 2}) @@ -145,18 +185,50 @@ def test_squared_distance(self): lst1 = [4, 3, 2, 1] arr = pyarray.array('d', [0, 2, 1, 3]) narr = array([0, 2, 1, 3]) - self.assertEquals(15.0, _squared_distance(sv, dv)) - self.assertEquals(25.0, _squared_distance(sv, lst)) - self.assertEquals(20.0, _squared_distance(dv, lst)) - self.assertEquals(15.0, _squared_distance(dv, sv)) - self.assertEquals(25.0, _squared_distance(lst, sv)) - self.assertEquals(20.0, _squared_distance(lst, dv)) - self.assertEquals(0.0, _squared_distance(sv, sv)) - self.assertEquals(0.0, _squared_distance(dv, dv)) - self.assertEquals(0.0, _squared_distance(lst, lst)) - self.assertEquals(25.0, _squared_distance(sv, lst1)) - self.assertEquals(3.0, _squared_distance(sv, arr)) - self.assertEquals(3.0, _squared_distance(sv, narr)) + self.assertEqual(15.0, _squared_distance(sv, dv)) + self.assertEqual(25.0, _squared_distance(sv, lst)) + self.assertEqual(20.0, _squared_distance(dv, lst)) + self.assertEqual(15.0, _squared_distance(dv, sv)) + self.assertEqual(25.0, _squared_distance(lst, sv)) + self.assertEqual(20.0, _squared_distance(lst, dv)) + self.assertEqual(0.0, _squared_distance(sv, sv)) + self.assertEqual(0.0, _squared_distance(dv, dv)) + self.assertEqual(0.0, _squared_distance(lst, lst)) + self.assertEqual(25.0, _squared_distance(sv, lst1)) + self.assertEqual(3.0, _squared_distance(sv, arr)) + self.assertEqual(3.0, _squared_distance(sv, narr)) + + def test_hash(self): + v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) + v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v4 = SparseVector(4, [(1, 1.0), (3, 2.5)]) + self.assertEqual(hash(v1), hash(v2)) + self.assertEqual(hash(v1), hash(v3)) + self.assertEqual(hash(v2), hash(v3)) + self.assertFalse(hash(v1) == hash(v4)) + self.assertFalse(hash(v2) == hash(v4)) + + def test_eq(self): + v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) + v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v4 = SparseVector(6, [(1, 1.0), (3, 5.5)]) + v5 = DenseVector([0.0, 1.0, 0.0, 2.5]) + v6 = SparseVector(4, [(1, 1.0), (3, 2.5)]) + self.assertEqual(v1, v2) + self.assertEqual(v1, v3) + self.assertFalse(v2 == v4) + self.assertFalse(v1 == v5) + self.assertFalse(v1 == v6) + + def test_equals(self): + indices = [1, 2, 4] + values = [1., 3., 2.] 
+ self.assertTrue(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 0., 2.])) + self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 1., 0., 2.])) + self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 0., 2.])) + self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 2., 2.])) def test_conversion(self): # numpy arrays should be automatically upcast to float64 @@ -169,25 +241,37 @@ def test_conversion(self): self.assertTrue(dv.array.dtype == 'float64') def test_sparse_vector_indexing(self): - sv = SparseVector(4, {1: 1, 3: 2}) - self.assertEquals(sv[0], 0.) - self.assertEquals(sv[3], 2.) - self.assertEquals(sv[1], 1.) - self.assertEquals(sv[2], 0.) - self.assertEquals(sv[-1], 2) - self.assertEquals(sv[-2], 0) - self.assertEquals(sv[-4], 0) - for ind in [4, -5]: + sv = SparseVector(5, {1: 1, 3: 2}) + self.assertEqual(sv[0], 0.) + self.assertEqual(sv[3], 2.) + self.assertEqual(sv[1], 1.) + self.assertEqual(sv[2], 0.) + self.assertEqual(sv[4], 0.) + self.assertEqual(sv[-1], 0.) + self.assertEqual(sv[-2], 2.) + self.assertEqual(sv[-3], 0.) + self.assertEqual(sv[-5], 0.) + for ind in [5, -6]: self.assertRaises(ValueError, sv.__getitem__, ind) for ind in [7.8, '1']: self.assertRaises(TypeError, sv.__getitem__, ind) + zeros = SparseVector(4, {}) + self.assertEqual(zeros[0], 0.0) + self.assertEqual(zeros[3], 0.0) + for ind in [4, -5]: + self.assertRaises(ValueError, zeros.__getitem__, ind) + + empty = SparseVector(0, {}) + for ind in [-1, 0, 1]: + self.assertRaises(ValueError, empty.__getitem__, ind) + def test_matrix_indexing(self): mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) expected = [[0, 6], [1, 8], [4, 10]] for i in range(3): for j in range(2): - self.assertEquals(mat[i, j], expected[i][j]) + self.assertEqual(mat[i, j], expected[i][j]) def test_repr_dense_matrix(self): mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) @@ -240,11 +324,11 @@ def test_sparse_matrix(self): # Test sparse matrix creation. sm1 = SparseMatrix( 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) - self.assertEquals(sm1.numRows, 3) - self.assertEquals(sm1.numCols, 4) - self.assertEquals(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4]) - self.assertEquals(sm1.rowIndices.tolist(), [1, 2, 1, 2]) - self.assertEquals(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0]) + self.assertEqual(sm1.numRows, 3) + self.assertEqual(sm1.numCols, 4) + self.assertEqual(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4]) + self.assertEqual(sm1.rowIndices.tolist(), [1, 2, 1, 2]) + self.assertEqual(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0]) self.assertTrue( repr(sm1), 'SparseMatrix(3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0], False)') @@ -257,13 +341,13 @@ def test_sparse_matrix(self): for i in range(3): for j in range(4): - self.assertEquals(expected[i][j], sm1[i, j]) + self.assertEqual(expected[i][j], sm1[i, j]) self.assertTrue(array_equal(sm1.toArray(), expected)) # Test conversion to dense and sparse. 
smnew = sm1.toDense().toSparse() - self.assertEquals(sm1.numRows, smnew.numRows) - self.assertEquals(sm1.numCols, smnew.numCols) + self.assertEqual(sm1.numRows, smnew.numRows) + self.assertEqual(sm1.numCols, smnew.numCols) self.assertTrue(array_equal(sm1.colPtrs, smnew.colPtrs)) self.assertTrue(array_equal(sm1.rowIndices, smnew.rowIndices)) self.assertTrue(array_equal(sm1.values, smnew.values)) @@ -271,11 +355,11 @@ def test_sparse_matrix(self): sm1t = SparseMatrix( 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], isTransposed=True) - self.assertEquals(sm1t.numRows, 3) - self.assertEquals(sm1t.numCols, 4) - self.assertEquals(sm1t.colPtrs.tolist(), [0, 2, 3, 5]) - self.assertEquals(sm1t.rowIndices.tolist(), [0, 1, 2, 0, 2]) - self.assertEquals(sm1t.values.tolist(), [3.0, 2.0, 4.0, 9.0, 8.0]) + self.assertEqual(sm1t.numRows, 3) + self.assertEqual(sm1t.numCols, 4) + self.assertEqual(sm1t.colPtrs.tolist(), [0, 2, 3, 5]) + self.assertEqual(sm1t.rowIndices.tolist(), [0, 1, 2, 0, 2]) + self.assertEqual(sm1t.values.tolist(), [3.0, 2.0, 4.0, 9.0, 8.0]) expected = [ [3, 2, 0, 0], @@ -284,18 +368,18 @@ def test_sparse_matrix(self): for i in range(3): for j in range(4): - self.assertEquals(expected[i][j], sm1t[i, j]) + self.assertEqual(expected[i][j], sm1t[i, j]) self.assertTrue(array_equal(sm1t.toArray(), expected)) def test_dense_matrix_is_transposed(self): mat1 = DenseMatrix(3, 2, [0, 4, 1, 6, 3, 9], isTransposed=True) mat = DenseMatrix(3, 2, [0, 1, 3, 4, 6, 9]) - self.assertEquals(mat1, mat) + self.assertEqual(mat1, mat) expected = [[0, 4], [1, 6], [3, 9]] for i in range(3): for j in range(2): - self.assertEquals(mat1[i, j], expected[i][j]) + self.assertEqual(mat1[i, j], expected[i][j]) self.assertTrue(array_equal(mat1.toArray(), expected)) sm = mat1.toSparse() @@ -344,8 +428,8 @@ def test_kmeans(self): ] clusters = KMeans.train(self.sc.parallelize(data), 2, initializationMode="k-means||", initializationSteps=7, epsilon=1e-4) - self.assertEquals(clusters.predict(data[0]), clusters.predict(data[1])) - self.assertEquals(clusters.predict(data[2]), clusters.predict(data[3])) + self.assertEqual(clusters.predict(data[0]), clusters.predict(data[1])) + self.assertEqual(clusters.predict(data[2]), clusters.predict(data[3])) def test_kmeans_deterministic(self): from pyspark.mllib.clustering import KMeans @@ -375,8 +459,8 @@ def test_gmm(self): clusters = GaussianMixture.train(data, 2, convergenceTol=0.001, maxIterations=10, seed=56) labels = clusters.predict(data).collect() - self.assertEquals(labels[0], labels[1]) - self.assertEquals(labels[2], labels[3]) + self.assertEqual(labels[0], labels[1]) + self.assertEqual(labels[2], labels[3]) def test_gmm_deterministic(self): from pyspark.mllib.clustering import GaussianMixture @@ -388,7 +472,7 @@ def test_gmm_deterministic(self): clusters2 = GaussianMixture.train(data, 5, convergenceTol=0.001, maxIterations=10, seed=63) for c1, c2 in zip(clusters1.weights, clusters2.weights): - self.assertEquals(round(c1, 7), round(c2, 7)) + self.assertEqual(round(c1, 7), round(c2, 7)) def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes @@ -643,18 +727,18 @@ def test_serialize(self): lil[1, 0] = 1 lil[3, 0] = 2 sv = SparseVector(4, {1: 1, 3: 2}) - self.assertEquals(sv, _convert_to_vector(lil)) - self.assertEquals(sv, _convert_to_vector(lil.tocsc())) - self.assertEquals(sv, _convert_to_vector(lil.tocoo())) - self.assertEquals(sv, _convert_to_vector(lil.tocsr())) - self.assertEquals(sv, 
_convert_to_vector(lil.todok())) + self.assertEqual(sv, _convert_to_vector(lil)) + self.assertEqual(sv, _convert_to_vector(lil.tocsc())) + self.assertEqual(sv, _convert_to_vector(lil.tocoo())) + self.assertEqual(sv, _convert_to_vector(lil.tocsr())) + self.assertEqual(sv, _convert_to_vector(lil.todok())) def serialize(l): return ser.loads(ser.dumps(_convert_to_vector(l))) - self.assertEquals(sv, serialize(lil)) - self.assertEquals(sv, serialize(lil.tocsc())) - self.assertEquals(sv, serialize(lil.tocsr())) - self.assertEquals(sv, serialize(lil.todok())) + self.assertEqual(sv, serialize(lil)) + self.assertEqual(sv, serialize(lil.tocsc())) + self.assertEqual(sv, serialize(lil.tocsr())) + self.assertEqual(sv, serialize(lil.todok())) def test_dot(self): from scipy.sparse import lil_matrix @@ -662,7 +746,7 @@ def test_dot(self): lil[1, 0] = 1 lil[3, 0] = 2 dv = DenseVector(array([1., 2., 3., 4.])) - self.assertEquals(10.0, dv.dot(lil)) + self.assertEqual(10.0, dv.dot(lil)) def test_squared_distance(self): from scipy.sparse import lil_matrix @@ -671,8 +755,8 @@ def test_squared_distance(self): lil[3, 0] = 2 dv = DenseVector(array([1., 2., 3., 4.])) sv = SparseVector(4, {0: 1, 1: 2, 2: 3, 3: 4}) - self.assertEquals(15.0, dv.squared_distance(lil)) - self.assertEquals(15.0, sv.squared_distance(lil)) + self.assertEqual(15.0, dv.squared_distance(lil)) + self.assertEqual(15.0, sv.squared_distance(lil)) def scipy_matrix(self, size, values): """Create a column SciPy matrix from a dictionary of values""" @@ -691,8 +775,8 @@ def test_clustering(self): self.scipy_matrix(3, {2: 1.1}) ] clusters = KMeans.train(self.sc.parallelize(data), 2, initializationMode="k-means||") - self.assertEquals(clusters.predict(data[0]), clusters.predict(data[1])) - self.assertEquals(clusters.predict(data[2]), clusters.predict(data[3])) + self.assertEqual(clusters.predict(data[0]), clusters.predict(data[1])) + self.assertEqual(clusters.predict(data[2]), clusters.predict(data[3])) def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes @@ -916,12 +1000,12 @@ def test_word2vec_setters(self): .setNumIterations(10) \ .setSeed(1024) \ .setMinCount(3) - self.assertEquals(model.vectorSize, 2) + self.assertEqual(model.vectorSize, 2) self.assertTrue(model.learningRate < 0.02) - self.assertEquals(model.numPartitions, 2) - self.assertEquals(model.numIterations, 10) - self.assertEquals(model.seed, 1024) - self.assertEquals(model.minCount, 3) + self.assertEqual(model.numPartitions, 2) + self.assertEqual(model.numIterations, 10) + self.assertEqual(model.seed, 1024) + self.assertEqual(model.minCount, 3) def test_word2vec_get_vectors(self): data = [ @@ -934,7 +1018,7 @@ def test_word2vec_get_vectors(self): ["a"] ] model = Word2Vec().fit(self.sc.parallelize(data)) - self.assertEquals(len(model.getVectors()), 3) + self.assertEqual(len(model.getVectors()), 3) class StandardScalerTests(MLlibTestCase): @@ -976,8 +1060,8 @@ def test_model_params(self): """Test that the model params are set correctly""" stkm = StreamingKMeans() stkm.setK(5).setDecayFactor(0.0) - self.assertEquals(stkm._k, 5) - self.assertEquals(stkm._decayFactor, 0.0) + self.assertEqual(stkm._k, 5) + self.assertEqual(stkm._decayFactor, 0.0) # Model not set yet. 
self.assertIsNone(stkm.latestModel()) @@ -985,9 +1069,9 @@ def test_model_params(self): stkm.setInitialCenters( centers=[[0.0, 0.0], [1.0, 1.0]], weights=[1.0, 1.0]) - self.assertEquals( + self.assertEqual( stkm.latestModel().centers, [[0.0, 0.0], [1.0, 1.0]]) - self.assertEquals(stkm.latestModel().clusterWeights, [1.0, 1.0]) + self.assertEqual(stkm.latestModel().clusterWeights, [1.0, 1.0]) def test_accuracy_for_single_center(self): """Test that parameters obtained are correct for a single center.""" @@ -999,10 +1083,13 @@ def test_accuracy_for_single_center(self): [self.sc.parallelize(batch, 1) for batch in batches]) stkm.trainOn(input_stream) - t = time() self.ssc.start() - self._ssc_wait(t, 10.0, 0.01) - self.assertEquals(stkm.latestModel().clusterWeights, [25.0]) + + def condition(): + self.assertEqual(stkm.latestModel().clusterWeights, [25.0]) + return True + self._eventually(condition, catch_assertions=True) + realCenters = array_sum(array(centers), axis=0) for i in range(5): modelCenters = stkm.latestModel().centers[0][i] @@ -1027,7 +1114,7 @@ def test_trainOn_model(self): stkm.setInitialCenters( centers=initCenters, weights=[1.0, 1.0, 1.0, 1.0]) - # Create a toy dataset by setting a tiny offest for each point. + # Create a toy dataset by setting a tiny offset for each point. offsets = [[0, 0.1], [0, -0.1], [0.1, 0], [-0.1, 0]] batches = [] for offset in offsets: @@ -1037,14 +1124,15 @@ def test_trainOn_model(self): batches = [self.sc.parallelize(batch, 1) for batch in batches] input_stream = self.ssc.queueStream(batches) stkm.trainOn(input_stream) - t = time() self.ssc.start() # Give enough time to train the model. - self._ssc_wait(t, 6.0, 0.01) - finalModel = stkm.latestModel() - self.assertTrue(all(finalModel.centers == array(initCenters))) - self.assertEquals(finalModel.clusterWeights, [5.0, 5.0, 5.0, 5.0]) + def condition(): + finalModel = stkm.latestModel() + self.assertTrue(all(finalModel.centers == array(initCenters))) + self.assertEqual(finalModel.clusterWeights, [5.0, 5.0, 5.0, 5.0]) + return True + self._eventually(condition, catch_assertions=True) def test_predictOn_model(self): """Test that the model predicts correctly on toy data.""" @@ -1066,10 +1154,13 @@ def update(rdd): result.append(rdd_collect) predict_val.foreachRDD(update) - t = time() self.ssc.start() - self._ssc_wait(t, 6.0, 0.01) - self.assertEquals(result, [[0], [1], [2], [3]]) + + def condition(): + self.assertEqual(result, [[0], [1], [2], [3]]) + return True + + self._eventually(condition, catch_assertions=True) def test_trainOn_predictOn(self): """Test that prediction happens on the updated model.""" @@ -1095,10 +1186,13 @@ def collect(rdd): predict_stream = stkm.predictOn(input_stream) predict_stream.foreachRDD(collect) - t = time() self.ssc.start() - self._ssc_wait(t, 6.0, 0.01) - self.assertEqual(predict_results, [[0, 1, 1], [1, 0, 1]]) + + def condition(): + self.assertEqual(predict_results, [[0, 1, 1], [1, 0, 1]]) + return True + + self._eventually(condition, catch_assertions=True) class LinearDataGeneratorTests(MLlibTestCase): @@ -1156,11 +1250,14 @@ def test_parameter_accuracy(self): slr.setInitialWeights([0.0]) slr.trainOn(input_stream) - t = time() self.ssc.start() - self._ssc_wait(t, 20.0, 0.01) - rel = (1.5 - slr.latestModel().weights.array[0]) / 1.5 - self.assertAlmostEqual(rel, 0.1, 1) + + def condition(): + rel = (1.5 - slr.latestModel().weights.array[0]) / 1.5 + self.assertAlmostEqual(rel, 0.1, 1) + return True + + self._eventually(condition, catch_assertions=True) def 
test_convergence(self): """ @@ -1179,13 +1276,18 @@ def test_convergence(self): input_stream.foreachRDD( lambda x: models.append(slr.latestModel().weights[0])) - t = time() self.ssc.start() - self._ssc_wait(t, 15.0, 0.01) + + def condition(): + self.assertEqual(len(models), len(input_batches)) + return True + + # We want all batches to finish for this test. + self._eventually(condition, 60.0, catch_assertions=True) + t_models = array(models) diff = t_models[1:] - t_models[:-1] - - # Test that weights improve with a small tolerance, + # Test that weights improve with a small tolerance self.assertTrue(all(diff >= -0.1)) self.assertTrue(array_sum(diff > 0) > 1) @@ -1208,9 +1310,13 @@ def test_predictions(self): predict_stream = slr.predictOnValues(input_stream) true_predicted = [] predict_stream.foreachRDD(lambda x: true_predicted.append(x.collect())) - t = time() self.ssc.start() - self._ssc_wait(t, 5.0, 0.01) + + def condition(): + self.assertEqual(len(true_predicted), len(input_batches)) + return True + + self._eventually(condition, catch_assertions=True) # Test that the accuracy error is no more than 0.4 on each batch. for batch in true_predicted: @@ -1242,12 +1348,17 @@ def collect_errors(rdd): ps = slr.predictOnValues(predict_stream) ps.foreachRDD(lambda x: collect_errors(x)) - t = time() self.ssc.start() - self._ssc_wait(t, 20.0, 0.01) - # Test that the improvement in error is atleast 0.3 - self.assertTrue(errors[1] - errors[-1] > 0.3) + def condition(): + # Test that the improvement in error is > 0.3 + if len(errors) == len(predict_batches): + self.assertGreater(errors[1] - errors[-1], 0.3) + if len(errors) >= 3 and errors[1] - errors[-1] > 0.3: + return True + return "Latest errors: " + ", ".join(map(lambda x: str(x), errors)) + + self._eventually(condition) class StreamingLinearRegressionWithTests(MLLibStreamingTestCase): @@ -1274,13 +1385,16 @@ def test_parameter_accuracy(self): batches.append(sc.parallelize(batch)) input_stream = self.ssc.queueStream(batches) - t = time() slr.trainOn(input_stream) self.ssc.start() - self._ssc_wait(t, 10, 0.01) - self.assertArrayAlmostEqual( - slr.latestModel().weights.array, [10., 10.], 1) - self.assertAlmostEqual(slr.latestModel().intercept, 0.0, 1) + + def condition(): + self.assertArrayAlmostEqual( + slr.latestModel().weights.array, [10., 10.], 1) + self.assertAlmostEqual(slr.latestModel().intercept, 0.0, 1) + return True + + self._eventually(condition, catch_assertions=True) def test_parameter_convergence(self): """Test that the model parameters improve with streaming data.""" @@ -1298,13 +1412,18 @@ def test_parameter_convergence(self): input_stream = self.ssc.queueStream(batches) input_stream.foreachRDD( lambda x: model_weights.append(slr.latestModel().weights[0])) - t = time() slr.trainOn(input_stream) self.ssc.start() - self._ssc_wait(t, 10, 0.01) - model_weights = array(model_weights) - diff = model_weights[1:] - model_weights[:-1] + def condition(): + self.assertEqual(len(model_weights), len(batches)) + return True + + # We want all batches to finish for this test. 
+ self._eventually(condition, catch_assertions=True) + + w = array(model_weights) + diff = w[1:] - w[:-1] self.assertTrue(all(diff >= -0.1)) def test_prediction(self): @@ -1323,13 +1442,18 @@ def test_prediction(self): sc.parallelize(batch).map(lambda lp: (lp.label, lp.features))) input_stream = self.ssc.queueStream(batches) - t = time() output_stream = slr.predictOnValues(input_stream) samples = [] output_stream.foreachRDD(lambda x: samples.append(x.collect())) self.ssc.start() - self._ssc_wait(t, 5, 0.01) + + def condition(): + self.assertEqual(len(samples), len(batches)) + return True + + # We want all batches to finish for this test. + self._eventually(condition, catch_assertions=True) # Test that mean absolute error on each batch is less than 0.1 for batch in samples: @@ -1350,22 +1474,27 @@ def test_train_prediction(self): predict_batches = [ b.map(lambda lp: (lp.label, lp.features)) for b in batches] - mean_absolute_errors = [] + errors = [] def func(rdd): true, predicted = zip(*rdd.collect()) - mean_absolute_errors.append(mean(abs(true) - abs(predicted))) + errors.append(mean(abs(true) - abs(predicted))) - model_weights = [] input_stream = self.ssc.queueStream(batches) output_stream = self.ssc.queueStream(predict_batches) - t = time() slr.trainOn(input_stream) output_stream = slr.predictOnValues(output_stream) output_stream.foreachRDD(func) self.ssc.start() - self._ssc_wait(t, 10, 0.01) - self.assertTrue(mean_absolute_errors[1] - mean_absolute_errors[-1] > 2) + + def condition(): + if len(errors) == len(predict_batches): + self.assertGreater(errors[1] - errors[-1], 2) + if len(errors) >= 3 and errors[1] - errors[-1] > 2: + return True + return "Latest errors: " + ", ".join(map(lambda x: str(x), errors)) + + self._eventually(condition) class MLUtilsTests(MLlibTestCase): @@ -1413,7 +1542,10 @@ def test_load_vectors(self): if __name__ == "__main__": if not _have_scipy: print("NOTE: Skipping SciPy tests as it does not seem to be installed") - unittest.main() + if xmlrunner: + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) + else: + unittest.main() if not _have_scipy: print("NOTE: SciPy tests were skipped as it does not seem to be installed") sc.stop() diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 372b86a7c95d..0001b60093a6 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -19,7 +19,7 @@ import random -from pyspark import SparkContext, RDD +from pyspark import SparkContext, RDD, since from pyspark.mllib.common import callMLlibFunc, inherit_doc, JavaModelWrapper from pyspark.mllib.linalg import _convert_to_vector from pyspark.mllib.regression import LabeledPoint @@ -30,6 +30,11 @@ class TreeEnsembleModel(JavaModelWrapper, JavaSaveable): + """TreeEnsembleModel + + .. versionadded:: 1.3.0 + """ + @since("1.3.0") def predict(self, x): """ Predict values for a single data point or an RDD of points using @@ -45,12 +50,14 @@ def predict(self, x): else: return self.call("predict", _convert_to_vector(x)) + @since("1.3.0") def numTrees(self): """ Get number of trees in ensemble. """ return self.call("numTrees") + @since("1.3.0") def totalNumNodes(self): """ Get total number of nodes, summed over all trees in the @@ -62,6 +69,7 @@ def __repr__(self): """ Summary of model """ return self._java_model.toString() + @since("1.3.0") def toDebugString(self): """ Full model """ return self._java_model.toDebugString() @@ -72,7 +80,10 @@ class DecisionTreeModel(JavaModelWrapper, JavaSaveable, JavaLoader): .. 
note:: Experimental A decision tree model for classification or regression. + + .. versionadded:: 1.1.0 """ + @since("1.1.0") def predict(self, x): """ Predict the label of one or more examples. @@ -90,16 +101,23 @@ def predict(self, x): else: return self.call("predict", _convert_to_vector(x)) + @since("1.1.0") def numNodes(self): + """Get number of nodes in tree, including leaf nodes.""" return self._java_model.numNodes() + @since("1.1.0") def depth(self): + """Get depth of tree. + E.g.: Depth 0 means 1 leaf node. Depth 1 means 1 internal node and 2 leaf nodes. + """ return self._java_model.depth() def __repr__(self): """ summary of model. """ return self._java_model.toString() + @since("1.2.0") def toDebugString(self): """ full model. """ return self._java_model.toDebugString() @@ -115,6 +133,8 @@ class DecisionTree(object): Learning algorithm for a decision tree model for classification or regression. + + .. versionadded:: 1.1.0 """ @classmethod @@ -127,6 +147,7 @@ def _train(cls, data, type, numClasses, features, impurity="gini", maxDepth=5, m return DecisionTreeModel(model) @classmethod + @since("1.1.0") def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0): @@ -185,6 +206,7 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) @classmethod + @since("1.1.0") def trainRegressor(cls, data, categoricalFeaturesInfo, impurity="variance", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0): @@ -239,6 +261,8 @@ class RandomForestModel(TreeEnsembleModel, JavaLoader): .. note:: Experimental Represents a random forest model. + + .. versionadded:: 1.2.0 """ @classmethod @@ -252,6 +276,8 @@ class RandomForest(object): Learning algorithm for a random forest model for classification or regression. + + .. versionadded:: 1.2.0 """ supportedFeatureSubsetStrategies = ("auto", "all", "sqrt", "log2", "onethird") @@ -271,6 +297,7 @@ def _train(cls, data, algo, numClasses, categoricalFeaturesInfo, numTrees, return RandomForestModel(model) @classmethod + @since("1.2.0") def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, featureSubsetStrategy="auto", impurity="gini", maxDepth=4, maxBins=32, seed=None): @@ -352,6 +379,7 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, maxDepth, maxBins, seed) @classmethod + @since("1.2.0") def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetStrategy="auto", impurity="variance", maxDepth=4, maxBins=32, seed=None): """ @@ -418,6 +446,8 @@ class GradientBoostedTreesModel(TreeEnsembleModel, JavaLoader): .. note:: Experimental Represents a gradient-boosted tree model. + + .. versionadded:: 1.3.0 """ @classmethod @@ -431,6 +461,8 @@ class GradientBoostedTrees(object): Learning algorithm for a gradient boosted trees model for classification or regression. + + .. 
versionadded:: 1.3.0 """ @classmethod @@ -443,6 +475,7 @@ def _train(cls, data, algo, categoricalFeaturesInfo, return GradientBoostedTreesModel(model) @classmethod + @since("1.3.0") def trainClassifier(cls, data, categoricalFeaturesInfo, loss="logLoss", numIterations=100, learningRate=0.1, maxDepth=3, maxBins=32): @@ -505,6 +538,7 @@ def trainClassifier(cls, data, categoricalFeaturesInfo, loss, numIterations, learningRate, maxDepth, maxBins) @classmethod + @since("1.3.0") def trainRegressor(cls, data, categoricalFeaturesInfo, loss="leastSquaresError", numIterations=100, learningRate=0.1, maxDepth=3, maxBins=32): diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 916de2d6fcdb..39bc6586dd58 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -23,7 +23,7 @@ xrange = range basestring = str -from pyspark import SparkContext +from pyspark import SparkContext, since from pyspark.mllib.common import callMLlibFunc, inherit_doc from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector @@ -32,6 +32,8 @@ class MLUtils(object): """ Helper methods to load, save and pre-process data used in MLlib. + + .. versionadded:: 1.0.0 """ @staticmethod @@ -69,6 +71,7 @@ def _convert_labeled_point_to_libsvm(p): return " ".join(items) @staticmethod + @since("1.0.0") def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None, multiclass=None): """ Loads labeled data in the LIBSVM format into an RDD of @@ -123,6 +126,7 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None, multiclass=None return parsed.map(lambda x: LabeledPoint(x[0], Vectors.sparse(numFeatures, x[1], x[2]))) @staticmethod + @since("1.0.0") def saveAsLibSVMFile(data, dir): """ Save labeled data in LIBSVM format. @@ -147,6 +151,7 @@ def saveAsLibSVMFile(data, dir): lines.saveAsTextFile(dir) @staticmethod + @since("1.1.0") def loadLabeledPoints(sc, path, minPartitions=None): """ Load labeled points saved using RDD.saveAsTextFile. @@ -172,6 +177,7 @@ def loadLabeledPoints(sc, path, minPartitions=None): return callMLlibFunc("loadLabeledPoints", sc, path, minPartitions) @staticmethod + @since("1.5.0") def appendBias(data): """ Returns a new vector with `1.0` (bias) appended to @@ -186,6 +192,7 @@ def appendBias(data): return _convert_to_vector(np.append(vec.toArray(), 1.0)) @staticmethod + @since("1.5.0") def loadVectors(sc, path): """ Loads vectors saved using `RDD[Vector].saveAsTextFile` @@ -197,6 +204,8 @@ def loadVectors(sc, path): class Saveable(object): """ Mixin for models and transformers which may be saved as files. + + .. versionadded:: 1.3.0 """ def save(self, sc, path): @@ -222,9 +231,13 @@ class JavaSaveable(Saveable): """ Mixin for models that provide save() through their Scala implementation. + + .. versionadded:: 1.3.0 """ + @since("1.3.0") def save(self, sc, path): + """Save this model to the given path.""" if not isinstance(sc, SparkContext): raise TypeError("sc should be a SparkContext, got type %s" % type(sc)) if not isinstance(path, basestring): @@ -235,6 +248,8 @@ def save(self, sc, path): class Loader(object): """ Mixin for classes which can load saved models from files. + + .. versionadded:: 1.3.0 """ @classmethod @@ -256,6 +271,8 @@ class JavaLoader(Loader): """ Mixin for classes which can load saved models using its Scala implementation. + + .. 
versionadded:: 1.3.0 """ @classmethod @@ -280,15 +297,21 @@ def _load_java(cls, sc, path): return java_obj.load(sc._jsc.sc(), path) @classmethod + @since("1.3.0") def load(cls, sc, path): + """Load a model from the given path.""" java_model = cls._load_java(sc, path) return cls(java_model) class LinearDataGenerator(object): - """Utils for generating linear data""" + """Utils for generating linear data. + + .. versionadded:: 1.5.0 + """ @staticmethod + @since("1.5.0") def generateLinearInput(intercept, weights, xMean, xVariance, nPoints, seed, eps): """ @@ -300,6 +323,7 @@ def generateLinearInput(intercept, weights, xMean, xVariance, :param: seed Random Seed :param: eps Used to scale the noise. If eps is set high, the amount of gaussian noise added is more. + Returns a list of LabeledPoints of length nPoints """ weights = [float(weight) for weight in weights] @@ -310,6 +334,7 @@ def generateLinearInput(intercept, weights, xMean, xVariance, xVariance, int(nPoints), int(seed), float(eps))) @staticmethod + @since("1.5.0") def generateLinearRDD(sc, nexamples, nfeatures, eps, nParts=2, intercept=0.0): """ diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index fa8e0a0574a6..00bb9a62e904 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -48,7 +48,7 @@ from pyspark.rddsampler import RDDSampler, RDDRangeSampler, RDDStratifiedSampler from pyspark.storagelevel import StorageLevel from pyspark.resultiterable import ResultIterable -from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \ +from pyspark.shuffle import Aggregator, ExternalMerger, \ get_used_memory, ExternalSorter, ExternalGroupBy from pyspark.traceback_utils import SCCallSiteSync @@ -84,7 +84,7 @@ def portable_hash(x): h ^= len(x) if h == -1: h = -2 - return h + return int(h) return hash(x) @@ -580,12 +580,11 @@ def repartitionAndSortWithinPartitions(self, numPartitions=None, partitionFunc=p if numPartitions is None: numPartitions = self._defaultReducePartitions() - spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() == "true") memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m")) serializer = self._jrdd_deserializer def sortPartition(iterator): - sort = ExternalSorter(memory * 0.9, serializer).sorted if spill else sorted + sort = ExternalSorter(memory * 0.9, serializer).sorted return iter(sort(iterator, key=lambda k_v: keyfunc(k_v[0]), reverse=(not ascending))) return self.partitionBy(numPartitions, partitionFunc).mapPartitions(sortPartition, True) @@ -610,12 +609,11 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x): if numPartitions is None: numPartitions = self._defaultReducePartitions() - spill = self._can_spill() memory = self._memory_limit() serializer = self._jrdd_deserializer def sortPartition(iterator): - sort = ExternalSorter(memory * 0.9, serializer).sorted if spill else sorted + sort = ExternalSorter(memory * 0.9, serializer).sorted return iter(sort(iterator, key=lambda kv: keyfunc(kv[0]), reverse=(not ascending))) if numPartitions == 1: @@ -688,7 +686,7 @@ def cartesian(self, other): other._jrdd_deserializer) return RDD(self._jrdd.cartesian(other._jrdd), self.ctx, deserializer) - def groupBy(self, f, numPartitions=None): + def groupBy(self, f, numPartitions=None, partitionFunc=portable_hash): """ Return an RDD of grouped items. 
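groupBy (just above) now forwards an optional partitionFunc to groupByKey, so callers are no longer tied to portable_hash. A minimal sketch reproducing the doctest data with a hypothetical parity-based partitioner:

from pyspark import SparkContext

sc = SparkContext("local[2]", "GroupByPartitionFuncSketch")
rdd = sc.parallelize([1, 1, 2, 3, 5, 8])

def parity(key):  # hypothetical partitioner; partitionBy applies result % numPartitions
    return key % 2

result = rdd.groupBy(lambda x: x % 2, numPartitions=2, partitionFunc=parity)
print(sorted((k, sorted(v)) for k, v in result.collect()))
# [(0, [2, 8]), (1, [1, 1, 3, 5])]

sc.stop()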
@@ -697,10 +695,10 @@ def groupBy(self, f, numPartitions=None): >>> sorted([(x, sorted(y)) for (x, y) in result]) [(0, [2, 8]), (1, [1, 1, 3, 5])] """ - return self.map(lambda x: (f(x), x)).groupByKey(numPartitions) + return self.map(lambda x: (f(x), x)).groupByKey(numPartitions, partitionFunc) @ignore_unicode_prefix - def pipe(self, command, env={}, checkCode=False): + def pipe(self, command, env=None, checkCode=False): """ Return an RDD created by piping elements to a forked external process. @@ -709,6 +707,9 @@ def pipe(self, command, env={}, checkCode=False): :param checkCode: whether or not to check the return value of the shell command. """ + if env is None: + env = dict() + def func(iterator): pipe = Popen( shlex.split(command), env=env, stdin=PIPE, stdout=PIPE) @@ -1538,22 +1539,23 @@ def values(self): """ return self.map(lambda x: x[1]) - def reduceByKey(self, func, numPartitions=None): + def reduceByKey(self, func, numPartitions=None, partitionFunc=portable_hash): """ Merge the values for each key using an associative reduce function. This will also perform the merging locally on each mapper before sending results to a reducer, similarly to a "combiner" in MapReduce. - Output will be hash-partitioned with C{numPartitions} partitions, or + Output will be partitioned with C{numPartitions} partitions, or the default parallelism level if C{numPartitions} is not specified. + Default partitioner is hash-partition. >>> from operator import add >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) >>> sorted(rdd.reduceByKey(add).collect()) [('a', 2), ('b', 1)] """ - return self.combineByKey(lambda x: x, func, func, numPartitions) + return self.combineByKey(lambda x: x, func, func, numPartitions, partitionFunc) def reduceByKeyLocally(self, func): """ @@ -1738,7 +1740,7 @@ def add_shuffle_key(split, iterator): # TODO: add control over map-side aggregation def combineByKey(self, createCombiner, mergeValue, mergeCombiners, - numPartitions=None): + numPartitions=None, partitionFunc=portable_hash): """ Generic function to combine the elements for each key using a custom set of aggregation functions. @@ -1758,7 +1760,6 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners, In addition, users can control the partitioning of the output RDD. 
>>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) - >>> def f(x): return x >>> def add(a, b): return a + str(b) >>> sorted(x.combineByKey(str, add, add).collect()) [('a', '11'), ('b', '1')] @@ -1767,28 +1768,26 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners, numPartitions = self._defaultReducePartitions() serializer = self.ctx.serializer - spill = self._can_spill() memory = self._memory_limit() agg = Aggregator(createCombiner, mergeValue, mergeCombiners) def combineLocally(iterator): - merger = ExternalMerger(agg, memory * 0.9, serializer) \ - if spill else InMemoryMerger(agg) + merger = ExternalMerger(agg, memory * 0.9, serializer) merger.mergeValues(iterator) return merger.items() locally_combined = self.mapPartitions(combineLocally, preservesPartitioning=True) - shuffled = locally_combined.partitionBy(numPartitions) + shuffled = locally_combined.partitionBy(numPartitions, partitionFunc) def _mergeCombiners(iterator): - merger = ExternalMerger(agg, memory, serializer) \ - if spill else InMemoryMerger(agg) + merger = ExternalMerger(agg, memory, serializer) merger.mergeCombiners(iterator) return merger.items() return shuffled.mapPartitions(_mergeCombiners, preservesPartitioning=True) - def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None): + def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None, + partitionFunc=portable_hash): """ Aggregate the values of each key, using given combine functions and a neutral "zero value". This function can return a different result type, U, than the type @@ -1802,9 +1801,9 @@ def createZero(): return copy.deepcopy(zeroValue) return self.combineByKey( - lambda v: seqFunc(createZero(), v), seqFunc, combFunc, numPartitions) + lambda v: seqFunc(createZero(), v), seqFunc, combFunc, numPartitions, partitionFunc) - def foldByKey(self, zeroValue, func, numPartitions=None): + def foldByKey(self, zeroValue, func, numPartitions=None, partitionFunc=portable_hash): """ Merge the values for each key using an associative function "func" and a neutral "zeroValue" which may be added to the result an @@ -1819,16 +1818,14 @@ def foldByKey(self, zeroValue, func, numPartitions=None): def createZero(): return copy.deepcopy(zeroValue) - return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions) - - def _can_spill(self): - return self.ctx._conf.get("spark.shuffle.spill", "True").lower() == "true" + return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions, + partitionFunc) def _memory_limit(self): return _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m")) # TODO: support variant with custom partitioner - def groupByKey(self, numPartitions=None): + def groupByKey(self, numPartitions=None, partitionFunc=portable_hash): """ Group the values for each key in the RDD into a single sequence. Hash-partitions the resulting RDD with numPartitions partitions. 
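# A minimal usage sketch (not part of the patch): with the new `partitionFunc`
# argument, key-based operations such as reduceByKey, combineByKey,
# aggregateByKey and foldByKey can route keys with a caller-supplied partition
# function instead of the default portable_hash. Assumes an existing
# SparkContext `sc`, as in the doctests above.
from operator import add

def first_letter_partitioner(key):
    # keys starting with 'a'..'m' go to partition 0, the rest to partition 1
    return 0 if key[0] <= 'm' else 1

pairs = sc.parallelize([("apple", 1), ("zebra", 1), ("apple", 1)])
counts = pairs.reduceByKey(add, numPartitions=2, partitionFunc=first_letter_partitioner)
# counts.glom().collect() shows 'apple' and 'zebra' land in different partitions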
@@ -1854,23 +1851,20 @@ def mergeCombiners(a, b): a.extend(b) return a - spill = self._can_spill() memory = self._memory_limit() serializer = self._jrdd_deserializer agg = Aggregator(createCombiner, mergeValue, mergeCombiners) def combine(iterator): - merger = ExternalMerger(agg, memory * 0.9, serializer) \ - if spill else InMemoryMerger(agg) + merger = ExternalMerger(agg, memory * 0.9, serializer) merger.mergeValues(iterator) return merger.items() locally_combined = self.mapPartitions(combine, preservesPartitioning=True) - shuffled = locally_combined.partitionBy(numPartitions) + shuffled = locally_combined.partitionBy(numPartitions, partitionFunc) def groupByKey(it): - merger = ExternalGroupBy(agg, memory, serializer)\ - if spill else InMemoryMerger(agg) + merger = ExternalGroupBy(agg, memory, serializer) merger.mergeCombiners(it) return merger.items() @@ -2021,7 +2015,7 @@ def coalesce(self, numPartitions, shuffle=False): >>> sc.parallelize([1, 2, 3, 4, 5], 3).coalesce(1).glom().collect() [[1, 2, 3, 4, 5]] """ - jrdd = self._jrdd.coalesce(numPartitions) + jrdd = self._jrdd.coalesce(numPartitions, shuffle) return RDD(jrdd, self.ctx, self._jrdd_deserializer) def zip(self, other): @@ -2189,6 +2183,9 @@ def lookup(self, key): [42] >>> sorted.lookup(1024) [] + >>> rdd2 = sc.parallelize([(('a', 'b'), 'c')]).groupByKey() + >>> list(rdd2.lookup(('a', 'b'))[0]) + ['c'] """ values = self.filter(lambda kv: kv[0] == key).values() diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 411b4dbf481f..2a1326947f4f 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -359,6 +359,7 @@ def _hack_namedtuple(cls): def __reduce__(self): return (_restore, (name, fields, tuple(self))) cls.__reduce__ = __reduce__ + cls._is_namedtuple_ = True return cls diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index b8118bdb7ca7..e974cda9fc3e 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -131,36 +131,6 @@ def items(self): raise NotImplementedError -class InMemoryMerger(Merger): - - """ - In memory merger based on in-memory dict. - """ - - def __init__(self, aggregator): - Merger.__init__(self, aggregator) - self.data = {} - - def mergeValues(self, iterator): - """ Combine the items by creator and combiner """ - # speed up attributes lookup - d, creator = self.data, self.agg.createCombiner - comb = self.agg.mergeValue - for k, v in iterator: - d[k] = comb(d[k], v) if k in d else creator(v) - - def mergeCombiners(self, iterator): - """ Merge the combined items by mergeCombiner """ - # speed up attributes lookup - d, comb = self.data, self.agg.mergeCombiners - for k, v in iterator: - d[k] = comb(d[k], v) if k in d else v - - def items(self): - """ Return the merged items ad iterator """ - return iter(self.data.items()) - - def _compressed_serializer(self, serializer=None): # always use PickleSerializer to simplify implementation ser = PickleSerializer() diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py index ad9c891ba1c0..98eaf52866d2 100644 --- a/python/pyspark/sql/__init__.py +++ b/python/pyspark/sql/__init__.py @@ -44,21 +44,6 @@ from __future__ import absolute_import -def since(version): - """ - A decorator that annotates a function to append the version of Spark the function was added. 
- """ - import re - indent_p = re.compile(r'\n( +)') - - def deco(f): - indents = indent_p.findall(f.__doc__) - indent = ' ' * (min(len(m) for m in indents) if indents else 0) - f.__doc__ = f.__doc__.rstrip() + "\n\n%s.. versionadded:: %s" % (indent, version) - return f - return deco - - from pyspark.sql.types import Row from pyspark.sql.context import SQLContext, HiveContext from pyspark.sql.column import Column diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 0a85da7443d3..81fd4e782628 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -16,14 +16,15 @@ # import sys +import warnings if sys.version >= '3': basestring = str long = int +from pyspark import since from pyspark.context import SparkContext from pyspark.rdd import ignore_unicode_prefix -from pyspark.sql import since from pyspark.sql.types import * __all__ = ["DataFrame", "Column", "SchemaRDD", "DataFrameNaFunctions", @@ -60,6 +61,18 @@ def _to_seq(sc, cols, converter=None): return sc._jvm.PythonUtils.toSeq(cols) +def _to_list(sc, cols, converter=None): + """ + Convert a list of Column (or names) into a JVM (Scala) List of Column. + + An optional `converter` could be used to convert items in `cols` + into JVM Column objects. + """ + if converter: + cols = [converter(c) for c in cols] + return sc._jvm.PythonUtils.toList(cols) + + def _unary_op(name, doc="unary operator"): """ Create a method for given unary operator """ def _(self): @@ -78,6 +91,17 @@ def _(self): return _ +def _bin_func_op(name, reverse=False, doc="binary function"): + def _(self, other): + sc = SparkContext._active_spark_context + fn = getattr(sc._jvm.functions, name) + jc = other._jc if isinstance(other, Column) else _create_column_from_literal(other) + njc = fn(self._jc, jc) if not reverse else fn(jc, self._jc) + return Column(njc) + _.__doc__ = doc + return _ + + def _bin_op(name, doc="binary operator"): """ Create a method for given binary operator """ @@ -138,6 +162,8 @@ def __init__(self, jc): __rdiv__ = _reverse_op("divide") __rtruediv__ = _reverse_op("divide") __rmod__ = _reverse_op("mod") + __pow__ = _bin_func_op("pow") + __rpow__ = _bin_func_op("pow", reverse=True) # logistic operators __eq__ = _bin_op("equalTo") @@ -213,6 +239,9 @@ def __getattr__(self, item): raise AttributeError(item) return self.getField(item) + def __iter__(self): + raise TypeError("Column is not iterable") + # string methods rlike = _bin_op("rlike") like = _bin_op("like") @@ -254,12 +283,29 @@ def inSet(self, *cols): [Row(age=5, name=u'Bob')] >>> df[df.age.inSet([1, 2, 3])].collect() [Row(age=2, name=u'Alice')] + + .. note:: Deprecated in 1.5, use :func:`Column.isin` instead. + """ + warnings.warn("inSet is deprecated. Use isin() instead.") + return self.isin(*cols) + + @ignore_unicode_prefix + @since(1.5) + def isin(self, *cols): + """ + A boolean expression that is evaluated to true if the value of this + expression is contained by the evaluated values of the arguments. 
+ + >>> df[df.name.isin("Bob", "Mike")].collect() + [Row(age=5, name=u'Bob')] + >>> df[df.age.isin([1, 2, 3])].collect() + [Row(age=2, name=u'Alice')] """ if len(cols) == 1 and isinstance(cols[0], (list, set)): cols = cols[0] cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols] sc = SparkContext._active_spark_context - jc = getattr(self._jc, "in")(_to_seq(sc, cols)) + jc = getattr(self._jc, "isin")(_to_seq(sc, cols)) return Column(jc) # order @@ -300,9 +346,10 @@ def cast(self, dataType): if isinstance(dataType, basestring): jc = self._jc.cast(dataType) elif isinstance(dataType, DataType): - sc = SparkContext._active_spark_context - ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) - jdt = ssql_ctx.parseDataType(dataType.json()) + from pyspark.sql import SQLContext + sc = SparkContext.getOrCreate() + ctx = SQLContext.getOrCreate(sc) + jdt = ctx._ssql_ctx.parseDataType(dataType.json()) jc = self._jc.cast(jdt) else: raise TypeError("unexpected type: %s" % type(dataType)) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 917de24f3536..b05aa2f5c4cd 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -26,9 +26,9 @@ from py4j.protocol import Py4JError +from pyspark import since from pyspark.rdd import RDD, _prepare_for_python_RDD, ignore_unicode_prefix from pyspark.serializers import AutoBatchedSerializer, PickleSerializer -from pyspark.sql import since from pyspark.sql.types import Row, StringType, StructType, _verify_type, \ _infer_schema, _has_nulltype, _merge_type, _create_converter from pyspark.sql.dataframe import DataFrame @@ -39,7 +39,7 @@ try: import pandas has_pandas = True -except ImportError: +except Exception: has_pandas = False __all__ = ["SQLContext", "HiveContext", "UDFRegistration"] @@ -75,6 +75,8 @@ class SQLContext(object): SQLContext in the JVM, instead we make all calls to this object. """ + _instantiatedContext = None + @ignore_unicode_prefix def __init__(self, sparkContext, sqlContext=None): """Creates a new SQLContext. @@ -99,6 +101,8 @@ def __init__(self, sparkContext, sqlContext=None): self._scala_SQLContext = sqlContext _monkey_patch_RDD(self) install_exception_handler() + if SQLContext._instantiatedContext is None: + SQLContext._instantiatedContext = self @property def _ssql_ctx(self): @@ -111,6 +115,29 @@ def _ssql_ctx(self): self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc()) return self._scala_SQLContext + @classmethod + @since(1.6) + def getOrCreate(cls, sc): + """ + Get the existing SQLContext or create a new one with given SparkContext. + + :param sc: SparkContext + """ + if cls._instantiatedContext is None: + jsqlContext = sc._jvm.SQLContext.getOrCreate(sc._jsc.sc()) + cls(sc, jsqlContext) + return cls._instantiatedContext + + @since(1.6) + def newSession(self): + """ + Returns a new SQLContext as new session, that has separate SQLConf, + registered temporary tables and UDFs, but shared SparkContext and + table cache. + """ + jsqlContext = self._ssql_ctx.newSession() + return self.__class__(self._sc, jsqlContext) + @since(1.3) def setConf(self, key, value): """Sets the given Spark SQL configuration property. @@ -168,14 +195,15 @@ def range(self, start, end=None, step=1, numPartitions=None): @ignore_unicode_prefix @since(1.2) def registerFunction(self, name, f, returnType=StringType()): - """Registers a lambda function as a UDF so it can be used in SQL statements. 
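# A minimal usage sketch (not part of the patch): getOrCreate reuses a single
# SQLContext per SparkContext, and newSession creates an isolated set of
# SQLConf settings, temporary tables and UDFs on top of the same SparkContext.
# Assumes an existing SparkContext `sc`.
from pyspark.sql import SQLContext

ctx = SQLContext.getOrCreate(sc)           # first call creates, later calls reuse
assert SQLContext.getOrCreate(sc) is ctx
session = ctx.newSession()                 # separate temp tables and UDFs
session.registerFunction("strlen", lambda s: len(s))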
+ """Registers a python function (including lambda function) as a UDF + so it can be used in SQL statements. In addition to a name and the function itself, the return type can be optionally specified. When the return type is not given it default to a string and conversion will automatically be done. For any other return type, the produced object must match the specified type. :param name: name of the UDF - :param samplingRatio: lambda function + :param f: python function :param returnType: a :class:`DataType` object >>> sqlContext.registerFunction("stringLengthString", lambda x: len(x)) @@ -291,13 +319,7 @@ def _createFromRDD(self, rdd, schema, samplingRatio): struct.names[i] = name schema = struct - elif isinstance(schema, StructType): - # take the first few rows to verify schema - rows = rdd.take(10) - for row in rows: - _verify_type(row, schema) - - else: + elif not isinstance(schema, StructType): raise TypeError("schema should be StructType or list or None, but got: %s" % schema) # convert python objects to sql data @@ -394,7 +416,7 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): >>> sqlContext.createDataFrame(df.toPandas()).collect() # doctest: +SKIP [Row(name=u'Alice', age=1)] - >>> sqlContext.createDataFrame(pandas.DataFrame([[1, 2]]).collect()) # doctest: +SKIP + >>> sqlContext.createDataFrame(pandas.DataFrame([[1, 2]])).collect() # doctest: +SKIP [Row(0=1, 1=2)] """ if isinstance(data, DataFrame): @@ -423,6 +445,15 @@ def registerDataFrameAsTable(self, df, tableName): else: raise ValueError("Can only register DataFrame as table") + @since(1.6) + def dropTempTable(self, tableName): + """ Remove the temp table from catalog. + + >>> sqlContext.registerDataFrameAsTable(df, "table1") + >>> sqlContext.dropTempTable("table1") + """ + self._ssql_ctx.dropTempTable(tableName) + def parquetFile(self, *paths): """Loads a Parquet file, returning the result as a :class:`DataFrame`. diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 0f3480c23918..78ab475eb466 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -26,13 +26,13 @@ else: from itertools import imap as map +from pyspark import since from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync -from pyspark.sql import since from pyspark.sql.types import _parse_datatype_json_string -from pyspark.sql.column import Column, _to_seq, _to_java_column +from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column from pyspark.sql.readwriter import DataFrameWriter from pyspark.sql.types import * @@ -212,8 +212,8 @@ def explain(self, extended=False): :param extended: boolean, default ``False``. If ``False``, prints only the physical plan. >>> df.explain() - PhysicalRDD [age#0,name#1], MapPartitionsRDD[...] at applySchemaToPythonRDD at\ - NativeMethodAccessorImpl.java:... + == Physical Plan == + Scan ExistingRDD[age#0,name#1] >>> df.explain(True) == Parsed Logical Plan == @@ -224,12 +224,11 @@ def explain(self, extended=False): ... == Physical Plan == ... 
- == RDD == """ if extended: print(self._jdf.queryExecution().toString()) else: - print(self._jdf.queryExecution().executedPlan().toString()) + print(self._jdf.queryExecution().simpleString()) @since(1.3) def isLocal(self): @@ -278,7 +277,7 @@ def collect(self): [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ with SCCallSiteSync(self._sc) as css: - port = self._sc._jvm.PythonRDD.collectAndServe(self._jdf.javaToPython().rdd()) + port = self._jdf.collectToPython() return list(_load_from_socket(port, BatchedSerializer(PickleSerializer()))) @ignore_unicode_prefix @@ -302,7 +301,10 @@ def take(self, num): >>> df.take(2) [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ - return self.limit(num).collect() + with SCCallSiteSync(self._sc) as css: + port = self._sc._jvm.org.apache.spark.sql.execution.EvaluatePython.takeAndServe( + self._jdf, num) + return list(_load_from_socket(port, BatchedSerializer(PickleSerializer()))) @ignore_unicode_prefix @since(1.3) @@ -420,6 +422,67 @@ def repartition(self, numPartitions): """ return DataFrame(self._jdf.repartition(numPartitions), self.sql_ctx) + @since(1.3) + def repartition(self, numPartitions, *cols): + """ + Returns a new :class:`DataFrame` partitioned by the given partitioning expressions. The + resulting DataFrame is hash partitioned. + + ``numPartitions`` can be an int to specify the target number of partitions or a Column. + If it is a Column, it will be used as the first partitioning column. If not specified, + the default number of partitions is used. + + .. versionchanged:: 1.6 + Added optional arguments to specify the partitioning columns. Also made numPartitions + optional if partitioning columns are specified. + + >>> df.repartition(10).rdd.getNumPartitions() + 10 + >>> data = df.unionAll(df).repartition("age") + >>> data.show() + +---+-----+ + |age| name| + +---+-----+ + | 2|Alice| + | 2|Alice| + | 5| Bob| + | 5| Bob| + +---+-----+ + >>> data = data.repartition(7, "age") + >>> data.show() + +---+-----+ + |age| name| + +---+-----+ + | 5| Bob| + | 5| Bob| + | 2|Alice| + | 2|Alice| + +---+-----+ + >>> data.rdd.getNumPartitions() + 7 + >>> data = data.repartition("name", "age") + >>> data.show() + +---+-----+ + |age| name| + +---+-----+ + | 5| Bob| + | 5| Bob| + | 2|Alice| + | 2|Alice| + +---+-----+ + """ + if isinstance(numPartitions, int): + if len(cols) == 0: + return DataFrame(self._jdf.repartition(numPartitions), self.sql_ctx) + else: + return DataFrame( + self._jdf.repartition(numPartitions, self._jcols(*cols)), self.sql_ctx) + elif isinstance(numPartitions, (basestring, Column)): + cols = (numPartitions, ) + cols + return DataFrame(self._jdf.repartition(self._jcols(*cols)), self.sql_ctx) + else: + raise TypeError("numPartitions should be an int or Column") + @since(1.3) def distinct(self): """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`. @@ -434,7 +497,7 @@ def sample(self, withReplacement, fraction, seed=None): """Returns a sampled subset of this :class:`DataFrame`. >>> df.sample(False, 0.5, 42).count() - 1 + 2 """ assert fraction >= 0.0, "Negative fraction value: %s" % fraction seed = seed if seed is not None else random.randint(0, sys.maxsize) @@ -461,8 +524,8 @@ def sampleBy(self, col, fractions, seed=None): +---+-----+ |key|count| +---+-----+ - | 0| 3| - | 1| 8| + | 0| 5| + | 1| 9| +---+-----+ """ @@ -496,7 +559,7 @@ def randomSplit(self, weights, seed=None): if w < 0.0: raise ValueError("Weights must be positive. 
Found weight value: %s" % w) seed = seed if seed is not None else random.randint(0, sys.maxsize) - rdd_array = self._jdf.randomSplit(_to_seq(self.sql_ctx._sc, weights), long(seed)) + rdd_array = self._jdf.randomSplit(_to_list(self.sql_ctx._sc, weights), long(seed)) return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array] @property @@ -568,9 +631,12 @@ def join(self, other, on=None, how=None): if on is None or len(on) == 0: jdf = self._jdf.join(other._jdf) - - if isinstance(on[0], basestring): - jdf = self._jdf.join(other._jdf, self._jseq(on)) + elif isinstance(on[0], basestring): + if how is None: + jdf = self._jdf.join(other._jdf, self._jseq(on), "inner") + else: + assert isinstance(how, basestring), "how should be basestring" + jdf = self._jdf.join(other._jdf, self._jseq(on), how) else: assert isinstance(on[0], Column), "on should be Column or list of Column" if len(on) > 1: @@ -584,6 +650,26 @@ def join(self, other, on=None, how=None): jdf = self._jdf.join(other._jdf, on._jc, how) return DataFrame(jdf, self.sql_ctx) + @since(1.6) + def sortWithinPartitions(self, *cols, **kwargs): + """Returns a new :class:`DataFrame` with each partition sorted by the specified column(s). + + :param cols: list of :class:`Column` or column names to sort by. + :param ascending: boolean or list of boolean (default True). + Sort ascending vs. descending. Specify list for multiple sort orders. + If a list is specified, length of the list must equal length of the `cols`. + + >>> df.sortWithinPartitions("age", ascending=False).show() + +---+-----+ + |age| name| + +---+-----+ + | 2|Alice| + | 5| Bob| + +---+-----+ + """ + jdf = self._jdf.sortWithinPartitions(self._sort_cols(cols, kwargs)) + return DataFrame(jdf, self.sql_ctx) + @ignore_unicode_prefix @since(1.3) def sort(self, *cols, **kwargs): @@ -608,22 +694,7 @@ def sort(self, *cols, **kwargs): >>> df.orderBy(["age", "name"], ascending=[0, 1]).collect() [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] """ - if not cols: - raise ValueError("should sort by at least one column") - if len(cols) == 1 and isinstance(cols[0], list): - cols = cols[0] - jcols = [_to_java_column(c) for c in cols] - ascending = kwargs.get('ascending', True) - if isinstance(ascending, (bool, int)): - if not ascending: - jcols = [jc.desc() for jc in jcols] - elif isinstance(ascending, list): - jcols = [jc if asc else jc.desc() - for asc, jc in zip(ascending, jcols)] - else: - raise TypeError("ascending can only be boolean or list, but got %s" % type(ascending)) - - jdf = self._jdf.sort(self._jseq(jcols)) + jdf = self._jdf.sort(self._sort_cols(cols, kwargs)) return DataFrame(jdf, self.sql_ctx) orderBy = sort @@ -645,6 +716,25 @@ def _jcols(self, *cols): cols = cols[0] return self._jseq(cols, _to_java_column) + def _sort_cols(self, cols, kwargs): + """ Return a JVM Seq of Columns that describes the sort order + """ + if not cols: + raise ValueError("should sort by at least one column") + if len(cols) == 1 and isinstance(cols[0], list): + cols = cols[0] + jcols = [_to_java_column(c) for c in cols] + ascending = kwargs.get('ascending', True) + if isinstance(ascending, (bool, int)): + if not ascending: + jcols = [jc.desc() for jc in jcols] + elif isinstance(ascending, list): + jcols = [jc if asc else jc.desc() + for asc, jc in zip(ascending, jcols)] + else: + raise TypeError("ascending can only be boolean or list, but got %s" % type(ascending)) + return self._jseq(jcols) + @since("1.3.1") def describe(self, *cols): """Computes statistics for numeric columns. 
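# A minimal usage sketch (not part of the patch): repartition now accepts
# partitioning columns, sortWithinPartitions orders rows inside each partition
# without a full global sort, and join takes an explicit join type together
# with a list of join column names. Assumes the doctest DataFrame `df`
# (age, name) and an existing `sqlContext`.
by_age = df.repartition(4, "age").sortWithinPartitions("name")
genders = sqlContext.createDataFrame([("Alice", "F"), ("Bob", "M")], ["name", "gender"])
joined = df.join(genders, on=["name"], how="left_outer")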
@@ -656,25 +746,25 @@ def describe(self, *cols): guarantee about the backward compatibility of the schema of the resulting DataFrame. >>> df.describe().show() - +-------+---+ - |summary|age| - +-------+---+ - | count| 2| - | mean|3.5| - | stddev|1.5| - | min| 2| - | max| 5| - +-------+---+ + +-------+------------------+ + |summary| age| + +-------+------------------+ + | count| 2| + | mean| 3.5| + | stddev|2.1213203435596424| + | min| 2| + | max| 5| + +-------+------------------+ >>> df.describe(['age', 'name']).show() - +-------+---+-----+ - |summary|age| name| - +-------+---+-----+ - | count| 2| 2| - | mean|3.5| null| - | stddev|1.5| null| - | min| 2|Alice| - | max| 5| Bob| - +-------+---+-----+ + +-------+------------------+-----+ + |summary| age| name| + +-------+------------------+-----+ + | count| 2| 2| + | mean| 3.5| null| + | stddev|2.1213203435596424| null| + | min| 2|Alice| + | max| 5| Bob| + +-------+------------------+-----+ """ if len(cols) == 1 and isinstance(cols[0], list): cols = cols[0] @@ -725,8 +815,6 @@ def __getitem__(self, item): [Row(age=5, name=u'Bob')] """ if isinstance(item, basestring): - if item not in self.columns: - raise IndexError("no such column: %s" % item) jc = self._jdf.apply(item) return Column(jc) elif isinstance(item, Column): @@ -778,7 +866,7 @@ def selectExpr(self, *expr): This is a variant of :func:`select` that accepts SQL expressions. >>> df.selectExpr("age * 2", "abs(age)").collect() - [Row((age * 2)=4, 'abs(age)=2), Row((age * 2)=10, 'abs(age)=5)] + [Row((age * 2)=4, abs(age)=2), Row((age * 2)=10, abs(age)=5)] """ if len(expr) == 1 and isinstance(expr[0], list): expr = expr[0] @@ -929,6 +1017,8 @@ def dropDuplicates(self, subset=None): """Return a new :class:`DataFrame` with duplicate rows removed, optionally only considering certain columns. + :func:`drop_duplicates` is an alias for :func:`dropDuplicates`. + >>> from pyspark.sql import Row >>> df = sc.parallelize([ \ Row(name='Alice', age=5, height=80), \ @@ -1207,7 +1297,9 @@ def freqItems(self, cols, support=None): @ignore_unicode_prefix @since(1.3) def withColumn(self, colName, col): - """Returns a new :class:`DataFrame` by adding a column. + """ + Returns a new :class:`DataFrame` by adding a column or replacing the + existing column that has the same name. :param colName: string, name of the new column. :param col: a :class:`Column` expression for the new column. 
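# A minimal usage sketch (not part of the patch): because withColumn now
# delegates to DataFrame.withColumn on the JVM side, passing an existing
# column name replaces that column instead of appending a duplicate. Assumes
# the doctest DataFrame `df`.
bumped = df.withColumn("age", df.age + 1)            # replaces the original "age"
renamed = bumped.withColumnRenamed("age", "age_plus_one")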
@@ -1215,7 +1307,8 @@ def withColumn(self, colName, col): >>> df.withColumn('age2', df.age + 2).collect() [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)] """ - return self.select('*', col.alias(colName)) + assert isinstance(col, Column), "col should be Column" + return DataFrame(self._jdf.withColumn(colName, col._jc), self.sql_ctx) @ignore_unicode_prefix @since(1.3) @@ -1228,10 +1321,7 @@ def withColumnRenamed(self, existing, new): >>> df.withColumnRenamed('age', 'age2').collect() [Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')] """ - cols = [Column(_to_java_column(c)).alias(new) - if c == existing else c - for c in self.columns] - return self.select(*cols) + return DataFrame(self._jdf.withColumnRenamed(existing, new), self.sql_ctx) @since(1.4) @ignore_unicode_prefix @@ -1261,6 +1351,18 @@ def drop(self, col): raise TypeError("col should be a string or a Column") return DataFrame(jdf, self.sql_ctx) + @ignore_unicode_prefix + def toDF(self, *cols): + """Returns a new class:`DataFrame` that with new specified column names + + :param cols: list of new column names (string) + + >>> df.toDF('f1', 'f2').collect() + [Row(f1=2, f2=u'Alice'), Row(f1=5, f2=u'Bob')] + """ + jdf = self._jdf.toDF(self._jseq(cols)) + return DataFrame(jdf, self.sql_ctx) + @since(1.3) def toPandas(self): """Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``. diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index a73ecc7d9336..90625949f747 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -24,47 +24,12 @@ if sys.version < "3": from itertools import imap as map -from pyspark import SparkContext +from pyspark import since, SparkContext from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix from pyspark.serializers import PickleSerializer, AutoBatchedSerializer -from pyspark.sql import since from pyspark.sql.types import StringType from pyspark.sql.column import Column, _to_java_column, _to_seq - - -__all__ = [ - 'array', - 'approxCountDistinct', - 'bin', - 'coalesce', - 'countDistinct', - 'explode', - 'format_number', - 'length', - 'log2', - 'md5', - 'monotonicallyIncreasingId', - 'rand', - 'randn', - 'regexp_extract', - 'regexp_replace', - 'sha1', - 'sha2', - 'size', - 'sort_array', - 'sparkPartitionId', - 'struct', - 'udf', - 'when'] - -__all__ += ['lag', 'lead', 'ntile'] - -__all__ += [ - 'date_format', 'date_add', 'date_sub', 'add_months', 'months_between', - 'year', 'quarter', 'month', 'hour', 'minute', 'second', - 'dayofmonth', 'dayofyear', 'weekofyear'] - -__all__ += ['soundex', 'substring', 'substring_index'] +from pyspark.sql.dataframe import DataFrame def _create_function(name, doc=""): @@ -157,6 +122,24 @@ def _(): 'bitwiseNOT': 'Computes bitwise not.', } +_functions_1_6 = { + # unary math functions + 'stddev': 'Aggregate function: returns the unbiased sample standard deviation of' + + ' the expression in a group.', + 'stddev_samp': 'Aggregate function: returns the unbiased sample standard deviation of' + + ' the expression in a group.', + 'stddev_pop': 'Aggregate function: returns population standard deviation of' + + ' the expression in a group.', + 'variance': 'Aggregate function: returns the population variance of the values in a group.', + 'var_samp': 'Aggregate function: returns the unbiased variance of the values in a group.', + 'var_pop': 'Aggregate function: returns the population variance of the values in a group.', + 'skewness': 'Aggregate function: returns the 
skewness of the values in a group.', + 'kurtosis': 'Aggregate function: returns the kurtosis of the values in a group.', + 'collect_list': 'Aggregate function: returns a list of objects with duplicates.', + 'collect_set': 'Aggregate function: returns a set of objects with duplicate elements' + + ' eliminated.' +} + # math functions that take two arguments as input _binary_mathfunctions = { 'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to' + @@ -167,18 +150,18 @@ def _(): _window_functions = { 'rowNumber': - """returns a sequential number starting at 1 within a window partition. - - This is equivalent to the ROW_NUMBER function in SQL.""", + """.. note:: Deprecated in 1.6, use row_number instead.""", + 'row_number': + """returns a sequential number starting at 1 within a window partition.""", 'denseRank': + """.. note:: Deprecated in 1.6, use dense_rank instead.""", + 'dense_rank': """returns the rank of rows within a window partition, without any gaps. The difference between rank and denseRank is that denseRank leaves no gaps in ranking sequence when there are ties. That is, if you were ranking a competition using denseRank and had three people tie for second place, you would say that all three were in second - place and that the next person came in third. - - This is equivalent to the DENSE_RANK function in SQL.""", + place and that the next person came in third.""", 'rank': """returns the rank of rows within a window partition. @@ -189,14 +172,14 @@ def _(): This is equivalent to the RANK function in SQL.""", 'cumeDist': + """.. note:: Deprecated in 1.6, use cume_dist instead.""", + 'cume_dist': """returns the cumulative distribution of values within a window partition, - i.e. the fraction of rows that are below the current row. - - This is equivalent to the CUME_DIST function in SQL.""", + i.e. the fraction of rows that are below the current row.""", 'percentRank': - """returns the relative rank (i.e. percentile) of rows within a window partition. - - This is equivalent to the PERCENT_RANK function in SQL.""", + """.. note:: Deprecated in 1.6, use percent_rank instead.""", + 'percent_rank': + """returns the relative rank (i.e. percentile) of rows within a window partition.""", } for _name, _doc in _functions.items(): @@ -206,32 +189,10 @@ def _(): for _name, _doc in _binary_mathfunctions.items(): globals()[_name] = since(1.4)(_create_binary_mathfunction(_name, _doc)) for _name, _doc in _window_functions.items(): - globals()[_name] = since(1.4)(_create_window_function(_name, _doc)) + globals()[_name] = since(1.6)(_create_window_function(_name, _doc)) +for _name, _doc in _functions_1_6.items(): + globals()[_name] = since(1.6)(_create_function(_name, _doc)) del _name, _doc -__all__ += _functions.keys() -__all__ += _functions_1_4.keys() -__all__ += _binary_mathfunctions.keys() -__all__ += _window_functions.keys() -__all__.sort() - - -@since(1.4) -def array(*cols): - """Creates a new array column. - - :param cols: list of column names (string) or list of :class:`Column` expressions that have - the same data type. 
- - >>> df.select(array('age', 'age').alias("arr")).collect() - [Row(arr=[2, 2]), Row(arr=[5, 5])] - >>> df.select(array([df.age, df.age]).alias("arr")).collect() - [Row(arr=[2, 2]), Row(arr=[5, 5])] - """ - sc = SparkContext._active_spark_context - if len(cols) == 1 and isinstance(cols[0], (list, set)): - cols = cols[0] - jc = sc._jvm.functions.array(_to_seq(sc, cols, _to_java_column)) - return Column(jc) @since(1.3) @@ -249,17 +210,12 @@ def approxCountDistinct(col, rsd=None): return Column(jc) -@ignore_unicode_prefix -@since(1.5) -def bin(col): - """Returns the string representation of the binary value of the given column. +@since(1.6) +def broadcast(df): + """Marks a DataFrame as small enough for use in broadcast joins.""" - >>> df.select(bin(df.age).alias('c')).collect() - [Row(c=u'10'), Row(c=u'101')] - """ sc = SparkContext._active_spark_context - jc = sc._jvm.functions.bin(_to_java_column(col)) - return Column(jc) + return DataFrame(sc._jvm.functions.broadcast(df._jdf), df.sql_ctx) @since(1.4) @@ -299,6 +255,22 @@ def coalesce(*cols): return Column(jc) +@since(1.6) +def corr(col1, col2): + """Returns a new :class:`Column` for the Pearson Correlation Coefficient for ``col1`` + and ``col2``. + + >>> a = [x * x - 2 * x + 3.5 for x in range(20)] + >>> b = range(20) + >>> corrDf = sqlContext.createDataFrame(zip(a, b)) + >>> corrDf = corrDf.agg(corr(corrDf._1, corrDf._2).alias('c')) + >>> corrDf.selectExpr('abs(c - 0.9572339139475857) < 1e-16 as t').collect() + [Row(t=True)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.corr(_to_java_column(col1), _to_java_column(col2))) + + @since(1.3) def countDistinct(col, *cols): """Returns a new :class:`Column` for distinct count of ``col`` or ``cols``. @@ -314,84 +286,48 @@ def countDistinct(col, *cols): return Column(jc) -@since(1.4) -def explode(col): - """Returns a new row for each element in the given array or map. - - >>> from pyspark.sql import Row - >>> eDF = sqlContext.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})]) - >>> eDF.select(explode(eDF.intlist).alias("anInt")).collect() - [Row(anInt=1), Row(anInt=2), Row(anInt=3)] - - >>> eDF.select(explode(eDF.mapfield).alias("key", "value")).show() - +---+-----+ - |key|value| - +---+-----+ - | a| b| - +---+-----+ +@since(1.6) +def input_file_name(): + """Creates a string column for the file name of the current Spark task. """ sc = SparkContext._active_spark_context - jc = sc._jvm.functions.explode(_to_java_column(col)) - return Column(jc) + return Column(sc._jvm.functions.input_file_name()) -@ignore_unicode_prefix -@since(1.5) -def levenshtein(left, right): - """Computes the Levenshtein distance of the two given strings. +@since(1.6) +def isnan(col): + """An expression that returns true iff the column is NaN. 
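# A minimal usage sketch (not part of the patch): broadcast() marks a small
# DataFrame for replication to every executor during a join, and the new 1.6
# aggregates (corr, stddev, variance, kurtosis, ...) are ordinary column
# expressions usable inside agg(). Assumes an existing `sqlContext` and the
# doctest DataFrame `df`.
from pyspark.sql.functions import broadcast, corr, stddev

labels = sqlContext.createDataFrame([(2, "minor"), (5, "adult")], ["age", "label"])
hinted = df.join(broadcast(labels), on=["age"], how="inner")
stats = df.agg(stddev(df.age).alias("sd"), corr(df.age, df.age).alias("r"))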
- >>> df0 = sqlContext.createDataFrame([('kitten', 'sitting',)], ['l', 'r']) - >>> df0.select(levenshtein('l', 'r').alias('d')).collect() - [Row(d=3)] + >>> df = sqlContext.createDataFrame([(1.0, float('nan')), (float('nan'), 2.0)], ("a", "b")) + >>> df.select(isnan("a").alias("r1"), isnan(df.a).alias("r2")).collect() + [Row(r1=False, r2=False), Row(r1=True, r2=True)] """ sc = SparkContext._active_spark_context - jc = sc._jvm.functions.levenshtein(_to_java_column(left), _to_java_column(right)) - return Column(jc) + return Column(sc._jvm.functions.isnan(_to_java_column(col))) -@ignore_unicode_prefix -@since(1.5) -def regexp_extract(str, pattern, idx): - """Extract a specific(idx) group identified by a java regex, from the specified string column. +@since(1.6) +def isnull(col): + """An expression that returns true iff the column is null. - >>> df = sqlContext.createDataFrame([('100-200',)], ['str']) - >>> df.select(regexp_extract('str', '(\d+)-(\d+)', 1).alias('d')).collect() - [Row(d=u'100')] + >>> df = sqlContext.createDataFrame([(1, None), (None, 2)], ("a", "b")) + >>> df.select(isnull("a").alias("r1"), isnull(df.a).alias("r2")).collect() + [Row(r1=False, r2=False), Row(r1=True, r2=True)] """ sc = SparkContext._active_spark_context - jc = sc._jvm.functions.regexp_extract(_to_java_column(str), pattern, idx) - return Column(jc) - + return Column(sc._jvm.functions.isnull(_to_java_column(col))) -@ignore_unicode_prefix -@since(1.5) -def regexp_replace(str, pattern, replacement): - """Replace all substrings of the specified string value that match regexp with rep. - >>> df = sqlContext.createDataFrame([('100-200',)], ['str']) - >>> df.select(regexp_replace('str', '(\\d+)', '##').alias('d')).collect() - [Row(d=u'##-##')] +@since(1.4) +def monotonicallyIncreasingId(): """ - sc = SparkContext._active_spark_context - jc = sc._jvm.functions.regexp_replace(_to_java_column(str), pattern, replacement) - return Column(jc) - - -@ignore_unicode_prefix -@since(1.5) -def md5(col): - """Calculates the MD5 digest and returns the value as a 32 character hex string. - - >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(md5('a').alias('hash')).collect() - [Row(hash=u'902fbdd2b1df0c4f70b4a5d23525e932')] + .. note:: Deprecated in 1.6, use monotonically_increasing_id instead. """ - sc = SparkContext._active_spark_context - jc = sc._jvm.functions.md5(_to_java_column(col)) - return Column(jc) + return monotonically_increasing_id() -@since(1.4) -def monotonicallyIncreasingId(): +@since(1.6) +def monotonically_increasing_id(): """A column that generates monotonically increasing 64-bit integers. The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. @@ -404,11 +340,25 @@ def monotonicallyIncreasingId(): 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. >>> df0 = sc.parallelize(range(2), 2).mapPartitions(lambda x: [(1,), (2,), (3,)]).toDF(['col1']) - >>> df0.select(monotonicallyIncreasingId().alias('id')).collect() + >>> df0.select(monotonically_increasing_id().alias('id')).collect() [Row(id=0), Row(id=1), Row(id=2), Row(id=8589934592), Row(id=8589934593), Row(id=8589934594)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.monotonicallyIncreasingId()) + return Column(sc._jvm.functions.monotonically_increasing_id()) + + +@since(1.6) +def nanvl(col1, col2): + """Returns col1 if it is not NaN, or col2 if col1 is NaN. + + Both inputs should be floating point columns (DoubleType or FloatType). 
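# A minimal usage sketch (not part of the patch): the camelCase helpers are
# kept as thin deprecated wrappers around the new snake_case names, so older
# scripts keep working while new code can migrate. Assumes the doctest
# DataFrame `df`.
from pyspark.sql.functions import monotonically_increasing_id, monotonicallyIncreasingId

with_ids = df.withColumn("id", monotonically_increasing_id())
legacy_ids = df.withColumn("id", monotonicallyIncreasingId())   # same expression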
+ + >>> df = sqlContext.createDataFrame([(1.0, float('nan')), (float('nan'), 2.0)], ("a", "b")) + >>> df.select(nanvl("a", "b").alias("r1"), nanvl(df.a, df.b).alias("r2")).collect() + [Row(r1=1.0, r2=1.0), Row(r1=2.0, r2=2.0)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.nanvl(_to_java_column(col1), _to_java_column(col2))) @since(1.4) @@ -416,7 +366,7 @@ def rand(seed=None): """Generates a random column with i.i.d. samples from U[0.0, 1.0]. """ sc = SparkContext._active_spark_context - if seed: + if seed is not None: jc = sc._jvm.functions.rand(seed) else: jc = sc._jvm.functions.rand() @@ -428,70 +378,24 @@ def randn(seed=None): """Generates a column with i.i.d. samples from the standard normal distribution. """ sc = SparkContext._active_spark_context - if seed: + if seed is not None: jc = sc._jvm.functions.randn(seed) else: jc = sc._jvm.functions.randn() return Column(jc) -@ignore_unicode_prefix -@since(1.5) -def hex(col): - """Computes hex value of the given column, which could be StringType, - BinaryType, IntegerType or LongType. - - >>> sqlContext.createDataFrame([('ABC', 3)], ['a', 'b']).select(hex('a'), hex('b')).collect() - [Row(hex(a)=u'414243', hex(b)=u'3')] - """ - sc = SparkContext._active_spark_context - jc = sc._jvm.functions.hex(_to_java_column(col)) - return Column(jc) - - -@ignore_unicode_prefix -@since(1.5) -def unhex(col): - """Inverse of hex. Interprets each pair of characters as a hexadecimal number - and converts to the byte representation of number. - - >>> sqlContext.createDataFrame([('414243',)], ['a']).select(unhex('a')).collect() - [Row(unhex(a)=bytearray(b'ABC'))] - """ - sc = SparkContext._active_spark_context - jc = sc._jvm.functions.unhex(_to_java_column(col)) - return Column(jc) - - -@ignore_unicode_prefix @since(1.5) -def sha1(col): - """Returns the hex string result of SHA-1. - - >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(sha1('a').alias('hash')).collect() - [Row(hash=u'3c01bdbb26f358bab27f267924aa2c9a03fcfdb8')] +def round(col, scale=0): """ - sc = SparkContext._active_spark_context - jc = sc._jvm.functions.sha1(_to_java_column(col)) - return Column(jc) - - -@ignore_unicode_prefix -@since(1.5) -def sha2(col, numBits): - """Returns the hex string result of SHA-2 family of hash functions (SHA-224, SHA-256, SHA-384, - and SHA-512). The numBits indicates the desired bit length of the result, which must have a - value of 224, 256, 384, 512, or 0 (which is equivalent to 256). + Round the value of `e` to `scale` decimal places if `scale` >= 0 + or at integral part when `scale` < 0. 
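# A minimal usage sketch (not part of the patch): a negative `scale` rounds at
# the integral part, which the doctest below does not show. `round` is
# imported under another name here to avoid shadowing the Python builtin.
# Assumes an existing `sqlContext`.
from pyspark.sql.functions import round as sql_round

hundreds = sqlContext.createDataFrame([(1234.5,)], ['a']).select(sql_round('a', -2).alias('r'))
# rounds 1234.5 to the hundreds place, i.e. 1200.0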
- >>> digests = df.select(sha2(df.name, 256).alias('s')).collect() - >>> digests[0] - Row(s=u'3bc51062973c458d5a6f2d8d64a023246354ad7e064b1e4e009ec8a0699a3043') - >>> digests[1] - Row(s=u'cd9fb1e148ccd8442e5aa74904cc73bf6fb54d1d54d333bd596aa9bb4bb4e961') + >>> sqlContext.createDataFrame([(2.546,)], ['a']).select(round('a', 1).alias('r')).collect() + [Row(r=2.5)] """ sc = SparkContext._active_spark_context - jc = sc._jvm.functions.sha2(_to_java_column(col), numBits) - return Column(jc) + return Column(sc._jvm.functions.round(_to_java_column(col), scale)) @since(1.5) @@ -502,8 +406,7 @@ def shiftLeft(col, numBits): [Row(r=42)] """ sc = SparkContext._active_spark_context - jc = sc._jvm.functions.shiftLeft(_to_java_column(col), numBits) - return Column(jc) + return Column(sc._jvm.functions.shiftLeft(_to_java_column(col), numBits)) @since(1.5) @@ -522,8 +425,8 @@ def shiftRight(col, numBits): def shiftRightUnsigned(col, numBits): """Unsigned shift the the given value numBits right. - >>> sqlContext.createDataFrame([(-42,)], ['a']).select(shiftRightUnsigned('a', 1).alias('r'))\ - .collect() + >>> df = sqlContext.createDataFrame([(-42,)], ['a']) + >>> df.select(shiftRightUnsigned('a', 1).alias('r')).collect() [Row(r=9223372036854775787)] """ sc = SparkContext._active_spark_context @@ -533,55 +436,36 @@ def shiftRightUnsigned(col, numBits): @since(1.4) def sparkPartitionId(): + """ + .. note:: Deprecated in 1.6, use spark_partition_id instead. + """ + return spark_partition_id() + + +@since(1.6) +def spark_partition_id(): """A column for partition ID of the Spark task. Note that this is indeterministic because it depends on data partitioning and task scheduling. - >>> df.repartition(1).select(sparkPartitionId().alias("pid")).collect() + >>> df.repartition(1).select(spark_partition_id().alias("pid")).collect() [Row(pid=0), Row(pid=0)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.sparkPartitionId()) + return Column(sc._jvm.functions.spark_partition_id()) +@since(1.5) def expr(str): """Parses the expression string into the column that it represents >>> df.select(expr("length(name)")).collect() - [Row('length(name)=5), Row('length(name)=3)] + [Row(length(name)=5), Row(length(name)=3)] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.expr(str)) -@ignore_unicode_prefix -@since(1.5) -def length(col): - """Calculates the length of a string or binary expression. - - >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(length('a').alias('length')).collect() - [Row(length=3)] - """ - sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.length(_to_java_column(col))) - - -@ignore_unicode_prefix -@since(1.5) -def format_number(col, d): - """Formats the number X to a format like '#,###,###.##', rounded to d decimal places, - and returns the result as a string. - - :param col: the column name of the numeric value to be formatted - :param d: the N decimal places - - >>> sqlContext.createDataFrame([(5,)], ['a']).select(format_number('a', 4).alias('v')).collect() - [Row(v=u'5.0000')] - """ - sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.format_number(_to_java_column(col), d)) - - @ignore_unicode_prefix @since(1.4) def struct(*cols): @@ -601,6 +485,38 @@ def struct(*cols): return Column(jc) +@since(1.5) +def greatest(*cols): + """ + Returns the greatest value of the list of column names, skipping null values. + This function takes at least 2 parameters. It will return null iff all parameters are null. 
+ + >>> df = sqlContext.createDataFrame([(1, 4, 3)], ['a', 'b', 'c']) + >>> df.select(greatest(df.a, df.b, df.c).alias("greatest")).collect() + [Row(greatest=4)] + """ + if len(cols) < 2: + raise ValueError("greatest should take at least two columns") + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.greatest(_to_seq(sc, cols, _to_java_column))) + + +@since(1.5) +def least(*cols): + """ + Returns the least value of the list of column names, skipping null values. + This function takes at least 2 parameters. It will return null iff all parameters are null. + + >>> df = sqlContext.createDataFrame([(1, 4, 3)], ['a', 'b', 'c']) + >>> df.select(least(df.a, df.b, df.c).alias("least")).collect() + [Row(least=1)] + """ + if len(cols) < 2: + raise ValueError("least should take at least two columns") + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.least(_to_seq(sc, cols, _to_java_column))) + + @since(1.4) def when(condition, value): """Evaluates a list of conditions and returns one of multiple possible result expressions. @@ -654,6 +570,35 @@ def log2(col): return Column(sc._jvm.functions.log2(_to_java_column(col))) +@since(1.5) +@ignore_unicode_prefix +def conv(col, fromBase, toBase): + """ + Convert a number in a string column from one base to another. + + >>> df = sqlContext.createDataFrame([("010101",)], ['n']) + >>> df.select(conv(df.n, 2, 16).alias('hex')).collect() + [Row(hex=u'15')] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.conv(_to_java_column(col), fromBase, toBase)) + + +@since(1.5) +def factorial(col): + """ + Computes the factorial of the given value. + + >>> df = sqlContext.createDataFrame([(5,)], ['n']) + >>> df.select(factorial(df.n).alias('f')).collect() + [Row(f=120)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.factorial(_to_java_column(col))) + + +# --------------- Window functions ------------------------ + @since(1.4) def lag(col, count=1, default=None): """ @@ -691,9 +636,10 @@ def lead(col, count=1, default=None): @since(1.4) def ntile(n): """ - Window function: returns a group id from 1 to `n` (inclusive) in a round-robin fashion in - a window partition. Fow example, if `n` is 3, the first row will get 1, the second row will - get 2, the third row will get 3, and the fourth row will get 1... + Window function: returns the ntile group id (from 1 to `n` inclusive) + in an ordered window partition. For example, if `n` is 4, the first + quarter of the rows will get value 1, the second quarter will get 2, + the third quarter will get 3, and the last quarter will get 4. This is equivalent to the NTILE function in SQL. @@ -703,9 +649,28 @@ def ntile(n): return Column(sc._jvm.functions.ntile(int(n))) +# ---------------------- Date/Timestamp functions ------------------------------ + +@since(1.5) +def current_date(): + """ + Returns the current date as a date column. + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.current_date()) + + +def current_timestamp(): + """ + Returns the current timestamp as a timestamp column. + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.current_timestamp()) + + @ignore_unicode_prefix @since(1.5) -def date_format(dateCol, format): +def date_format(date, format): """ Converts a date/timestamp/string to a value of string in the format specified by the date format given by the second argument. 
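# A minimal usage sketch (not part of the patch): lag, lead and ntile are
# window functions, so they are evaluated over an ordered window specification
# via Column.over(). Assumes the doctest DataFrame `df`; Window is re-exported
# from pyspark.sql.
from pyspark.sql import Window
from pyspark.sql.functions import lag, ntile

w = Window.orderBy("age")
windowed = df.select(
    "age",
    lag("age", 1).over(w).alias("prev_age"),
    ntile(2).over(w).alias("bucket"))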
@@ -721,7 +686,7 @@ def date_format(dateCol, format): [Row(date=u'04/08/2015')] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.date_format(_to_java_column(dateCol), format)) + return Column(sc._jvm.functions.date_format(_to_java_column(date), format)) @since(1.5) @@ -867,6 +832,19 @@ def date_sub(start, days): return Column(sc._jvm.functions.date_sub(_to_java_column(start), days)) +@since(1.5) +def datediff(end, start): + """ + Returns the number of days from `start` to `end`. + + >>> df = sqlContext.createDataFrame([('2015-04-08','2015-05-10')], ['d1', 'd2']) + >>> df.select(datediff(df.d2, df.d1).alias('diff')).collect() + [Row(diff=32)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.datediff(_to_java_column(end), _to_java_column(start))) + + @since(1.5) def add_months(start, months): """ @@ -924,27 +902,290 @@ def trunc(date, format): @since(1.5) -@ignore_unicode_prefix -def substring(str, pos, len): +def next_day(date, dayOfWeek): """ - Substring starts at `pos` and is of length `len` when str is String type or - returns the slice of byte array that starts at `pos` in byte and is of length `len` - when str is Binary type + Returns the first date which is later than the value of the date column. - >>> df = sqlContext.createDataFrame([('abcd',)], ['s',]) - >>> df.select(substring(df.s, 1, 2).alias('s')).collect() - [Row(s=u'ab')] + Day of the week parameter is case insensitive, and accepts: + "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun". + + >>> df = sqlContext.createDataFrame([('2015-07-27',)], ['d']) + >>> df.select(next_day(df.d, 'Sun').alias('date')).collect() + [Row(date=datetime.date(2015, 8, 2))] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.substring(_to_java_column(str), pos, len)) + return Column(sc._jvm.functions.next_day(_to_java_column(date), dayOfWeek)) @since(1.5) -@ignore_unicode_prefix -def substring_index(str, delim, count): +def last_day(date): """ - Returns the substring from string str before count occurrences of the delimiter delim. - If count is positive, everything the left of the final delimiter (counting from left) is + Returns the last day of the month which the given date belongs to. + + >>> df = sqlContext.createDataFrame([('1997-02-10',)], ['d']) + >>> df.select(last_day(df.d).alias('date')).collect() + [Row(date=datetime.date(1997, 2, 28))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.last_day(_to_java_column(date))) + + +@since(1.5) +def from_unixtime(timestamp, format="yyyy-MM-dd HH:mm:ss"): + """ + Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string + representing the timestamp of that moment in the current system time zone in the given + format. + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.from_unixtime(_to_java_column(timestamp), format)) + + +@since(1.5) +def unix_timestamp(timestamp=None, format='yyyy-MM-dd HH:mm:ss'): + """ + Convert time string with given pattern ('yyyy-MM-dd HH:mm:ss', by default) + to Unix time stamp (in seconds), using the default timezone and the default + locale, return null if fail. + + if `timestamp` is None, then it returns current timestamp. 
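# A minimal usage sketch (not part of the patch): for a fixed pattern,
# unix_timestamp and from_unixtime round-trip a timestamp string in the
# session time zone, which the docstrings above describe but do not
# demonstrate. Assumes an existing `sqlContext`.
from pyspark.sql.functions import from_unixtime, unix_timestamp

times = sqlContext.createDataFrame([('2015-04-08 13:08:15',)], ['t'])
round_trip = times.select(
    from_unixtime(unix_timestamp(times.t), 'yyyy-MM-dd HH:mm:ss').alias('t2'))
# round_trip.collect() returns the original string for valid inputs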
+ """ + sc = SparkContext._active_spark_context + if timestamp is None: + return Column(sc._jvm.functions.unix_timestamp()) + return Column(sc._jvm.functions.unix_timestamp(_to_java_column(timestamp), format)) + + +@since(1.5) +def from_utc_timestamp(timestamp, tz): + """ + Assumes given timestamp is UTC and converts to given timezone. + + >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00',)], ['t']) + >>> df.select(from_utc_timestamp(df.t, "PST").alias('t')).collect() + [Row(t=datetime.datetime(1997, 2, 28, 2, 30))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.from_utc_timestamp(_to_java_column(timestamp), tz)) + + +@since(1.5) +def to_utc_timestamp(timestamp, tz): + """ + Assumes given timestamp is in given timezone and converts to UTC. + + >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00',)], ['t']) + >>> df.select(to_utc_timestamp(df.t, "PST").alias('t')).collect() + [Row(t=datetime.datetime(1997, 2, 28, 18, 30))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.to_utc_timestamp(_to_java_column(timestamp), tz)) + + +# ---------------------------- misc functions ---------------------------------- + +@since(1.5) +@ignore_unicode_prefix +def crc32(col): + """ + Calculates the cyclic redundancy check value (CRC32) of a binary column and + returns the value as a bigint. + + >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(crc32('a').alias('crc32')).collect() + [Row(crc32=2743272264)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.crc32(_to_java_column(col))) + + +@ignore_unicode_prefix +@since(1.5) +def md5(col): + """Calculates the MD5 digest and returns the value as a 32 character hex string. + + >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(md5('a').alias('hash')).collect() + [Row(hash=u'902fbdd2b1df0c4f70b4a5d23525e932')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.md5(_to_java_column(col)) + return Column(jc) + + +@ignore_unicode_prefix +@since(1.5) +def sha1(col): + """Returns the hex string result of SHA-1. + + >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(sha1('a').alias('hash')).collect() + [Row(hash=u'3c01bdbb26f358bab27f267924aa2c9a03fcfdb8')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.sha1(_to_java_column(col)) + return Column(jc) + + +@ignore_unicode_prefix +@since(1.5) +def sha2(col, numBits): + """Returns the hex string result of SHA-2 family of hash functions (SHA-224, SHA-256, SHA-384, + and SHA-512). The numBits indicates the desired bit length of the result, which must have a + value of 224, 256, 384, 512, or 0 (which is equivalent to 256). 
+ + >>> digests = df.select(sha2(df.name, 256).alias('s')).collect() + >>> digests[0] + Row(s=u'3bc51062973c458d5a6f2d8d64a023246354ad7e064b1e4e009ec8a0699a3043') + >>> digests[1] + Row(s=u'cd9fb1e148ccd8442e5aa74904cc73bf6fb54d1d54d333bd596aa9bb4bb4e961') + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.sha2(_to_java_column(col), numBits) + return Column(jc) + + +# ---------------------- String/Binary functions ------------------------------ + +_string_functions = { + 'ascii': 'Computes the numeric value of the first character of the string column.', + 'base64': 'Computes the BASE64 encoding of a binary column and returns it as a string column.', + 'unbase64': 'Decodes a BASE64 encoded string column and returns it as a binary column.', + 'initcap': 'Returns a new string column by converting the first letter of each word to ' + + 'uppercase. Words are delimited by whitespace.', + 'lower': 'Converts a string column to lower case.', + 'upper': 'Converts a string column to upper case.', + 'reverse': 'Reverses the string column and returns it as a new string column.', + 'ltrim': 'Trim the spaces from right end for the specified string value.', + 'rtrim': 'Trim the spaces from right end for the specified string value.', + 'trim': 'Trim the spaces from both ends for the specified string column.', +} + + +for _name, _doc in _string_functions.items(): + globals()[_name] = since(1.5)(_create_function(_name, _doc)) +del _name, _doc + + +@since(1.5) +@ignore_unicode_prefix +def concat(*cols): + """ + Concatenates multiple input string columns together into a single string column. + + >>> df = sqlContext.createDataFrame([('abcd','123')], ['s', 'd']) + >>> df.select(concat(df.s, df.d).alias('s')).collect() + [Row(s=u'abcd123')] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column))) + + +@since(1.5) +@ignore_unicode_prefix +def concat_ws(sep, *cols): + """ + Concatenates multiple input string columns together into a single string column, + using the given separator. + + >>> df = sqlContext.createDataFrame([('abcd','123')], ['s', 'd']) + >>> df.select(concat_ws('-', df.s, df.d).alias('s')).collect() + [Row(s=u'abcd-123')] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.concat_ws(sep, _to_seq(sc, cols, _to_java_column))) + + +@since(1.5) +def decode(col, charset): + """ + Computes the first argument into a string from a binary using the provided character set + (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.decode(_to_java_column(col), charset)) + + +@since(1.5) +def encode(col, charset): + """ + Computes the first argument into a binary from a string using the provided character set + (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.encode(_to_java_column(col), charset)) + + +@ignore_unicode_prefix +@since(1.5) +def format_number(col, d): + """ + Formats the number X to a format like '#,--#,--#.--', rounded to d decimal places, + and returns the result as a string. 
+
+    :param col: the column name of the numeric value to be formatted
+    :param d: the number of decimal places to round to
+
+    >>> sqlContext.createDataFrame([(5,)], ['a']).select(format_number('a', 4).alias('v')).collect()
+    [Row(v=u'5.0000')]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.format_number(_to_java_column(col), d))
+
+
+@ignore_unicode_prefix
+@since(1.5)
+def format_string(format, *cols):
+    """
+    Formats the arguments in printf-style and returns the result as a string column.
+
+    :param format: string that can contain embedded format tags and used as result column's value
+    :param cols: list of column names (string) or list of :class:`Column` expressions to
+        be used in formatting
+
+    >>> df = sqlContext.createDataFrame([(5, "hello")], ['a', 'b'])
+    >>> df.select(format_string('%d %s', df.a, df.b).alias('v')).collect()
+    [Row(v=u'5 hello')]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.format_string(format, _to_seq(sc, cols, _to_java_column)))
+
+
+@since(1.5)
+def instr(str, substr):
+    """
+    Locate the position of the first occurrence of substr in the given string column.
+    Returns null if either of the arguments are null.
+
+    NOTE: The position is not zero based, but 1 based index. Returns 0 if substr
+    could not be found in str.
+
+    >>> df = sqlContext.createDataFrame([('abcd',)], ['s',])
+    >>> df.select(instr(df.s, 'b').alias('s')).collect()
+    [Row(s=2)]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.instr(_to_java_column(str), substr))
+
+
+@since(1.5)
+@ignore_unicode_prefix
+def substring(str, pos, len):
+    """
+    Substring starts at `pos` and is of length `len` when str is String type, or
+    returns the slice of byte array that starts at `pos` in byte and is of length `len`
+    when str is Binary type.
+
+    >>> df = sqlContext.createDataFrame([('abcd',)], ['s',])
+    >>> df.select(substring(df.s, 1, 2).alias('s')).collect()
+    [Row(s=u'ab')]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.substring(_to_java_column(str), pos, len))
+
+
+@since(1.5)
+@ignore_unicode_prefix
+def substring_index(str, delim, count):
+    """
+    Returns the substring from string str before count occurrences of the delimiter delim.
+    If count is positive, everything to the left of the final delimiter (counting from the left)
+    is returned. If count is negative, everything to the right of the final delimiter (counting
+    from the right) is returned. substring_index performs a case-sensitive match when searching
+    for delim.
@@ -958,6 +1199,126 @@ def substring_index(str, delim, count):
     return Column(sc._jvm.functions.substring_index(_to_java_column(str), delim, count))
+@ignore_unicode_prefix
+@since(1.5)
+def levenshtein(left, right):
+    """Computes the Levenshtein distance of the two given strings.
+
+    >>> df0 = sqlContext.createDataFrame([('kitten', 'sitting',)], ['l', 'r'])
+    >>> df0.select(levenshtein('l', 'r').alias('d')).collect()
+    [Row(d=3)]
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.levenshtein(_to_java_column(left), _to_java_column(right))
+    return Column(jc)
+
+
+@since(1.5)
+def locate(substr, str, pos=0):
+    """
+    Locate the position of the first occurrence of substr in a string column, after position pos.
+
+    NOTE: The position is not zero based, but 1 based index. Returns 0 if substr
+    could not be found in str.
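Because the 1-based convention of instr and locate is easy to misread, here is a small illustrative sketch (not part of the patch) contrasting the return values:

from pyspark.sql.functions import instr, locate

df = sqlContext.createDataFrame([('abcd',)], ['s'])
df.select(instr(df.s, 'b').alias('i'),      # 2: 'b' is the second character
          locate('d', df.s).alias('l1'),    # 4
          locate('z', df.s).alias('l2')     # 0: not found
          ).collect()
# [Row(i=2, l1=4, l2=0)]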
+ + :param substr: a string + :param str: a Column of StringType + :param pos: start position (zero based) + + >>> df = sqlContext.createDataFrame([('abcd',)], ['s',]) + >>> df.select(locate('b', df.s, 1).alias('s')).collect() + [Row(s=2)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.locate(substr, _to_java_column(str), pos)) + + +@since(1.5) +@ignore_unicode_prefix +def lpad(col, len, pad): + """ + Left-pad the string column to width `len` with `pad`. + + >>> df = sqlContext.createDataFrame([('abcd',)], ['s',]) + >>> df.select(lpad(df.s, 6, '#').alias('s')).collect() + [Row(s=u'##abcd')] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.lpad(_to_java_column(col), len, pad)) + + +@since(1.5) +@ignore_unicode_prefix +def rpad(col, len, pad): + """ + Right-pad the string column to width `len` with `pad`. + + >>> df = sqlContext.createDataFrame([('abcd',)], ['s',]) + >>> df.select(rpad(df.s, 6, '#').alias('s')).collect() + [Row(s=u'abcd##')] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.rpad(_to_java_column(col), len, pad)) + + +@since(1.5) +@ignore_unicode_prefix +def repeat(col, n): + """ + Repeats a string column n times, and returns it as a new string column. + + >>> df = sqlContext.createDataFrame([('ab',)], ['s',]) + >>> df.select(repeat(df.s, 3).alias('s')).collect() + [Row(s=u'ababab')] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.repeat(_to_java_column(col), n)) + + +@since(1.5) +@ignore_unicode_prefix +def split(str, pattern): + """ + Splits str around pattern (pattern is a regular expression). + + NOTE: pattern is a string represent the regular expression. + + >>> df = sqlContext.createDataFrame([('ab12cd',)], ['s',]) + >>> df.select(split(df.s, '[0-9]+').alias('s')).collect() + [Row(s=[u'ab', u'cd'])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.split(_to_java_column(str), pattern)) + + +@ignore_unicode_prefix +@since(1.5) +def regexp_extract(str, pattern, idx): + """Extract a specific(idx) group identified by a java regex, from the specified string column. + + >>> df = sqlContext.createDataFrame([('100-200',)], ['str']) + >>> df.select(regexp_extract('str', '(\d+)-(\d+)', 1).alias('d')).collect() + [Row(d=u'100')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.regexp_extract(_to_java_column(str), pattern, idx) + return Column(jc) + + +@ignore_unicode_prefix +@since(1.5) +def regexp_replace(str, pattern, replacement): + """Replace all substrings of the specified string value that match regexp with rep. 
+
+    >>> df = sqlContext.createDataFrame([('100-200',)], ['str'])
+    >>> df.select(regexp_replace('str', '(\\d+)', '--').alias('d')).collect()
+    [Row(d=u'-----')]
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.regexp_replace(_to_java_column(str), pattern, replacement)
+    return Column(jc)
+
+
 @ignore_unicode_prefix
 @since(1.5)
 def initcap(col):
@@ -970,6 +1331,186 @@ def initcap(col):
     return Column(sc._jvm.functions.initcap(_to_java_column(col)))
+
+@since(1.5)
+@ignore_unicode_prefix
+def soundex(col):
+    """
+    Returns the SoundEx encoding for a string.
+
+    >>> df = sqlContext.createDataFrame([("Peters",),("Uhrbach",)], ['name'])
+    >>> df.select(soundex(df.name).alias("soundex")).collect()
+    [Row(soundex=u'P362'), Row(soundex=u'U612')]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.soundex(_to_java_column(col)))
+
+
+@ignore_unicode_prefix
+@since(1.5)
+def bin(col):
+    """Returns the string representation of the binary value of the given column.
+
+    >>> df.select(bin(df.age).alias('c')).collect()
+    [Row(c=u'10'), Row(c=u'101')]
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.bin(_to_java_column(col))
+    return Column(jc)
+
+
+@ignore_unicode_prefix
+@since(1.5)
+def hex(col):
+    """Computes hex value of the given column, which could be StringType,
+    BinaryType, IntegerType or LongType.
+
+    >>> sqlContext.createDataFrame([('ABC', 3)], ['a', 'b']).select(hex('a'), hex('b')).collect()
+    [Row(hex(a)=u'414243', hex(b)=u'3')]
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.hex(_to_java_column(col))
+    return Column(jc)
+
+
+@ignore_unicode_prefix
+@since(1.5)
+def unhex(col):
+    """Inverse of hex. Interprets each pair of characters as a hexadecimal number
+    and converts to the byte representation of number.
+
+    >>> sqlContext.createDataFrame([('414243',)], ['a']).select(unhex('a')).collect()
+    [Row(unhex(a)=bytearray(b'ABC'))]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.unhex(_to_java_column(col)))
+
+
+@ignore_unicode_prefix
+@since(1.5)
+def length(col):
+    """Calculates the length of a string or binary expression.
+
+    >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(length('a').alias('length')).collect()
+    [Row(length=3)]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.length(_to_java_column(col)))
+
+
+@ignore_unicode_prefix
+@since(1.5)
+def translate(srcCol, matching, replace):
+    """Translates any character in `srcCol` that occurs in `matching` into the corresponding
+    character in `replace`. Characters in `matching` that have no counterpart in `replace`
+    (because `replace` is shorter) are removed from the result.
+
+    >>> sqlContext.createDataFrame([('translate',)], ['a']).select(translate('a', "rnlt", "123")\
+        .alias('r')).collect()
+    [Row(r=u'1a2s3ae')]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.translate(_to_java_column(srcCol), matching, replace))
+
+
+# ---------------------- Collection functions ------------------------------
+
+@since(1.4)
+def array(*cols):
+    """Creates a new array column.
+
+    :param cols: list of column names (string) or list of :class:`Column` expressions that have
+        the same data type.
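To make the translate semantics documented above concrete, a short sketch (illustration only) showing both substitution and removal when `replace` is shorter than `matching`:

from pyspark.sql.functions import translate

df = sqlContext.createDataFrame([('translate',)], ['a'])
# 'r' -> '1', 'n' -> '2', 'l' -> '3'; 't' has no counterpart in "123", so it is dropped.
df.select(translate(df.a, 'rnlt', '123').alias('r')).collect()
# [Row(r=u'1a2s3ae')]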
+ + >>> df.select(array('age', 'age').alias("arr")).collect() + [Row(arr=[2, 2]), Row(arr=[5, 5])] + >>> df.select(array([df.age, df.age]).alias("arr")).collect() + [Row(arr=[2, 2]), Row(arr=[5, 5])] + """ + sc = SparkContext._active_spark_context + if len(cols) == 1 and isinstance(cols[0], (list, set)): + cols = cols[0] + jc = sc._jvm.functions.array(_to_seq(sc, cols, _to_java_column)) + return Column(jc) + + +@since(1.5) +def array_contains(col, value): + """ + Collection function: returns True if the array contains the given value. The collection + elements and value must be of the same type. + + :param col: name of column containing array + :param value: value to check for in array + + >>> df = sqlContext.createDataFrame([(["a", "b", "c"],), ([],)], ['data']) + >>> df.select(array_contains(df.data, "a")).collect() + [Row(array_contains(data,a)=True), Row(array_contains(data,a)=False)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) + + +@since(1.4) +def explode(col): + """Returns a new row for each element in the given array or map. + + >>> from pyspark.sql import Row + >>> eDF = sqlContext.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})]) + >>> eDF.select(explode(eDF.intlist).alias("anInt")).collect() + [Row(anInt=1), Row(anInt=2), Row(anInt=3)] + + >>> eDF.select(explode(eDF.mapfield).alias("key", "value")).show() + +---+-----+ + |key|value| + +---+-----+ + | a| b| + +---+-----+ + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.explode(_to_java_column(col)) + return Column(jc) + + +@ignore_unicode_prefix +@since(1.6) +def get_json_object(col, path): + """ + Extracts json object from a json string based on json path specified, and returns json string + of the extracted json object. It will return null if the input json string is invalid. + + :param col: string column in json format + :param path: path to the json object to extract + + >>> data = [("1", '''{"f1": "value1", "f2": "value2"}'''), ("2", '''{"f1": "value12"}''')] + >>> df = sqlContext.createDataFrame(data, ("key", "jstring")) + >>> df.select(df.key, get_json_object(df.jstring, '$.f1').alias("c0"), \ + get_json_object(df.jstring, '$.f2').alias("c1") ).collect() + [Row(key=u'1', c0=u'value1', c1=u'value2'), Row(key=u'2', c0=u'value12', c1=None)] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.get_json_object(_to_java_column(col), path) + return Column(jc) + + +@ignore_unicode_prefix +@since(1.6) +def json_tuple(col, *fields): + """Creates a new row for a json column according to the given field names. 
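As a usage note on the two JSON helpers above, get_json_object evaluates one JSONPath per call while json_tuple pulls several top-level fields in a single expression. A hedged sketch reusing the doctest data:

from pyspark.sql.functions import get_json_object, json_tuple

data = [("1", '{"f1": "value1", "f2": "value2"}'), ("2", '{"f1": "value12"}')]
df = sqlContext.createDataFrame(data, ("key", "jstring"))

# One column per path expression:
df.select(get_json_object(df.jstring, '$.f1').alias('f1'),
          get_json_object(df.jstring, '$.f2').alias('f2')).collect()

# Or all requested fields at once (missing fields come back as null):
df.select(df.key, json_tuple(df.jstring, 'f1', 'f2')).collect()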
+ + :param col: string column in json format + :param fields: list of fields to extract + + >>> data = [("1", '''{"f1": "value1", "f2": "value2"}'''), ("2", '''{"f1": "value12"}''')] + >>> df = sqlContext.createDataFrame(data, ("key", "jstring")) + >>> df.select(df.key, json_tuple(df.jstring, 'f1', 'f2')).collect() + [Row(key=u'1', c0=u'value1', c1=u'value2'), Row(key=u'2', c0=u'value12', c1=None)] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.json_tuple(_to_java_column(col), _to_seq(sc, fields)) + return Column(jc) + + @since(1.5) def size(col): """ @@ -1002,19 +1543,7 @@ def sort_array(col, asc=True): return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc)) -@since -@ignore_unicode_prefix -def soundex(col): - """ - Returns the SoundEx encoding for a string - - >>> df = sqlContext.createDataFrame([("Peters",),("Uhrbach",)], ['name']) - >>> df.select(soundex(df.name).alias("soundex")).collect() - [Row(soundex=u'P362'), Row(soundex=u'U612')] - """ - sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.size(_to_java_column(col))) - +# ---------------------------- User Defined Function ---------------------------------- class UserDefinedFunction(object): """ @@ -1029,14 +1558,15 @@ def __init__(self, func, returnType, name=None): self._judf = self._create_judf(name) def _create_judf(self, name): + from pyspark.sql import SQLContext f, returnType = self.func, self.returnType # put them in closure `func` func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it) ser = AutoBatchedSerializer(PickleSerializer()) command = (func, None, ser, ser) - sc = SparkContext._active_spark_context + sc = SparkContext.getOrCreate() pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self) - ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) - jdt = ssql_ctx.parseDataType(self.returnType.json()) + ctx = SQLContext.getOrCreate(sc) + jdt = ctx._ssql_ctx.parseDataType(self.returnType.json()) if name is None: name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__ judf = sc._jvm.UserDefinedPythonFunction(name, bytearray(pickled_command), env, includes, @@ -1066,6 +1596,11 @@ def udf(f, returnType=StringType()): """ return UserDefinedFunction(f, returnType) +blacklist = ['map', 'since', 'ignore_unicode_prefix'] +__all__ = [k for k, v in globals().items() + if not k.startswith('_') and k[0].islower() and callable(v) and k not in blacklist] +__all__.sort() + def _test(): import doctest diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 04594d5a836c..9ca303a974cd 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -15,9 +15,9 @@ # limitations under the License. # +from pyspark import since from pyspark.rdd import ignore_unicode_prefix -from pyspark.sql import since -from pyspark.sql.column import Column, _to_seq +from pyspark.sql.column import Column, _to_seq, _to_java_column, _create_column_from_literal from pyspark.sql.dataframe import DataFrame from pyspark.sql.types import * @@ -167,6 +167,31 @@ def sum(self, *cols): [Row(sum(age)=7, sum(height)=165)] """ + @since(1.6) + def pivot(self, pivot_col, values=None): + """ + Pivots a column of the current [[DataFrame]] and perform the specified aggregation. + There are two versions of pivot function: one that requires the caller to specify the list + of distinct values to pivot on, and one that does not. 
The latter is more concise but less + efficient, because Spark needs to first compute the list of distinct values internally. + + :param pivot_col: Name of the column to pivot. + :param values: List of values that will be translated to columns in the output DataFrame. + + // Compute the sum of earnings for each year by course with each course as a separate column + >>> df4.groupBy("year").pivot("course", ["dotNET", "Java"]).sum("earnings").collect() + [Row(year=2012, dotNET=15000, Java=20000), Row(year=2013, dotNET=48000, Java=30000)] + + // Or without specifying column values (less efficient) + >>> df4.groupBy("year").pivot("course").sum("earnings").collect() + [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)] + """ + if values is None: + jgd = self._jdf.pivot(pivot_col) + else: + jgd = self._jdf.pivot(pivot_col, values) + return GroupedData(jgd, self.sql_ctx) + def _test(): import doctest @@ -182,6 +207,11 @@ def _test(): StructField('name', StringType())])) globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80), Row(name='Bob', age=5, height=85)]).toDF() + globs['df4'] = sc.parallelize([Row(course="dotNET", year=2012, earnings=10000), + Row(course="Java", year=2012, earnings=20000), + Row(course="dotNET", year=2012, earnings=5000), + Row(course="dotNET", year=2013, earnings=48000), + Row(course="Java", year=2013, earnings=30000)]).toDF() (failure_count, test_count) = doctest.testmod( pyspark.sql.group, globs=globs, diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index dea8bad79e18..a3d7eca04b61 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -15,15 +15,32 @@ # limitations under the License. # +import sys + +if sys.version >= '3': + basestring = unicode = str + from py4j.java_gateway import JavaClass -from pyspark.sql import since +from pyspark import RDD, since +from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import _to_seq from pyspark.sql.types import * +from pyspark.sql import utils __all__ = ["DataFrameReader", "DataFrameWriter"] +def to_str(value): + """ + A wrapper over str(), but convert bool values to lower case string + """ + if isinstance(value, bool): + return str(value).lower() + else: + return str(value) + + class DataFrameReader(object): """ Interface used to load a :class:`DataFrame` from external storage systems @@ -77,7 +94,7 @@ def schema(self, schema): def option(self, key, value): """Adds an input option for the underlying data source. """ - self._jreader = self._jreader.option(key, value) + self._jreader = self._jreader.option(key, to_str(value)) return self @since(1.4) @@ -85,21 +102,27 @@ def options(self, **options): """Adds input options for the underlying data source. """ for k in options: - self._jreader = self._jreader.option(k, options[k]) + self._jreader = self._jreader.option(k, to_str(options[k])) return self @since(1.4) def load(self, path=None, format=None, schema=None, **options): """Loads data from a data source and returns it as a :class`DataFrame`. - :param path: optional string for file-system backed data sources. + :param path: optional string or a list of string for file-system backed data sources. :param format: optional string for format of the data source. Default to 'parquet'. :param schema: optional :class:`StructType` for the input schema. 
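For the new pivot method, a brief sketch (mirroring the df4 doctest data) of the trade-off between the two call styles; listing the values up front avoids the extra job Spark otherwise runs to compute the distinct pivot values:

from pyspark.sql import Row

df4 = sc.parallelize([Row(course="dotNET", year=2012, earnings=10000),
                      Row(course="Java", year=2012, earnings=20000),
                      Row(course="dotNET", year=2013, earnings=48000),
                      Row(course="Java", year=2013, earnings=30000)]).toDF()

# Explicit values: no extra pass over the data to discover them.
df4.groupBy("year").pivot("course", ["dotNET", "Java"]).sum("earnings").collect()

# Without values: Spark first computes distinct("course") internally.
df4.groupBy("year").pivot("course").sum("earnings").collect()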
:param options: all other string options - >>> df = sqlContext.read.load('python/test_support/sql/parquet_partitioned') + >>> df = sqlContext.read.load('python/test_support/sql/parquet_partitioned', opt1=True, + ... opt2=1, opt3='str') >>> df.dtypes [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] + + >>> df = sqlContext.read.format('json').load(['python/test_support/sql/people.json', + ... 'python/test_support/sql/people1.json']) + >>> df.dtypes + [('age', 'bigint'), ('aka', 'string'), ('name', 'string')] """ if format is not None: self.format(format) @@ -107,30 +130,56 @@ def load(self, path=None, format=None, schema=None, **options): self.schema(schema) self.options(**options) if path is not None: - return self._df(self._jreader.load(path)) + if type(path) == list: + return self._df( + self._jreader.load(self._sqlContext._sc._jvm.PythonUtils.toSeq(path))) + else: + return self._df(self._jreader.load(path)) else: return self._df(self._jreader.load()) @since(1.4) def json(self, path, schema=None): """ - Loads a JSON file (one object per line) and returns the result as - a :class`DataFrame`. + Loads a JSON file (one object per line) or an RDD of Strings storing JSON objects + (one object per record) and returns the result as a :class`DataFrame`. If the ``schema`` parameter is not specified, this function goes through the input once to determine the input schema. - :param path: string, path to the JSON dataset. + :param path: string represents path to the JSON dataset, + or RDD of Strings storing JSON objects. :param schema: an optional :class:`StructType` for the input schema. - >>> df = sqlContext.read.json('python/test_support/sql/people.json') - >>> df.dtypes + You can set the following JSON-specific options to deal with non-standard JSON files: + * ``primitivesAsString`` (default ``false``): infers all primitive values as a string \ + type + * ``allowComments`` (default ``false``): ignores Java/C++ style comment in JSON records + * ``allowUnquotedFieldNames`` (default ``false``): allows unquoted JSON field names + * ``allowSingleQuotes`` (default ``true``): allows single quotes in addition to double \ + quotes + * ``allowNumericLeadingZeros`` (default ``false``): allows leading zeros in numbers \ + (e.g. 00012) + + >>> df1 = sqlContext.read.json('python/test_support/sql/people.json') + >>> df1.dtypes + [('age', 'bigint'), ('name', 'string')] + >>> rdd = sc.textFile('python/test_support/sql/people.json') + >>> df2 = sqlContext.read.json(rdd) + >>> df2.dtypes [('age', 'bigint'), ('name', 'string')] """ if schema is not None: self.schema(schema) - return self._df(self._jreader.json(path)) + if isinstance(path, basestring): + return self._df(self._jreader.json(path)) + elif type(path) == list: + return self._df(self._jreader.json(self._sqlContext._sc._jvm.PythonUtils.toSeq(path))) + elif isinstance(path, RDD): + return self._df(self._jreader.json(path._jrdd)) + else: + raise TypeError("path can be only string or RDD") @since(1.4) def table(self, tableName): @@ -155,10 +204,26 @@ def parquet(self, *paths): """ return self._df(self._jreader.parquet(_to_seq(self._sqlContext._sc, paths))) + @ignore_unicode_prefix + @since(1.6) + def text(self, paths): + """Loads a text file and returns a [[DataFrame]] with a single string column named "value". + + Each line in the text file is a new row in the resulting DataFrame. + + :param paths: string, or list of strings, for input path(s). 
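Tying the reader changes together, a hedged sketch of the new behaviour: to_str lowercases boolean options before they reach the JVM, load accepts a list of paths, and json accepts an RDD of JSON strings. The file paths below are placeholders:

# Boolean and numeric options are converted by to_str(): True -> 'true', 1 -> '1'.
df = (sqlContext.read
      .format('json')
      .option('primitivesAsString', True)   # becomes the string 'true'
      .load(['/tmp/people1.json', '/tmp/people2.json']))  # hypothetical paths

# Reading JSON straight from an RDD of JSON strings also works now.
rdd = sc.textFile('/tmp/people1.json')  # hypothetical path
df2 = sqlContext.read.json(rdd)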
+ + >>> df = sqlContext.read.text('python/test_support/sql/text-test.txt') + >>> df.collect() + [Row(value=u'hello'), Row(value=u'this')] + """ + if isinstance(paths, basestring): + paths = [paths] + return self._df(self._jreader.text(self._sqlContext._sc._jvm.PythonUtils.toSeq(paths))) + @since(1.5) def orc(self, path): - """ - Loads an ORC file, returning the result as a :class:`DataFrame`. + """Loads an ORC file, returning the result as a :class:`DataFrame`. ::Note: Currently ORC support is only available together with :class:`HiveContext`. @@ -171,7 +236,7 @@ def orc(self, path): @since(1.4) def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPartitions=None, - predicates=None, properties={}): + predicates=None, properties=None): """ Construct a :class:`DataFrame` representing the database table accessible via JDBC URL `url` named `table` and connection `properties`. @@ -197,6 +262,8 @@ def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPar should be included. :return: a DataFrame """ + if properties is None: + properties = dict() jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)() for k in properties: jprop.setProperty(k, properties[k]) @@ -206,8 +273,9 @@ def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPar return self._df(self._jreader.jdbc(url, table, column, int(lowerBound), int(upperBound), int(numPartitions), jprop)) if predicates is not None: - arr = self._sqlContext._sc._jvm.PythonUtils.toArray(predicates) - return self._df(self._jreader.jdbc(url, table, arr, jprop)) + gateway = self._sqlContext._sc._gateway + jpredicates = utils.toJArray(gateway, gateway.jvm.java.lang.String, predicates) + return self._df(self._jreader.jdbc(url, table, jpredicates, jprop)) return self._df(self._jreader.jdbc(url, table, jprop)) @@ -392,6 +460,16 @@ def parquet(self, path, mode=None, partitionBy=None): self.partitionBy(partitionBy) self._jwrite.parquet(path) + @since(1.6) + def text(self, path): + """Saves the content of the DataFrame in a text file at the specified path. + + The DataFrame must have only one column that is of string type. + Each row becomes a new line in the output file. + """ + self._jwrite.text(path) + + @since(1.5) def orc(self, path, mode=None, partitionBy=None): """Saves the content of the :class:`DataFrame` in ORC format at the specified path. @@ -416,7 +494,7 @@ def orc(self, path, mode=None, partitionBy=None): self._jwrite.orc(path) @since(1.4) - def jdbc(self, url, table, mode=None, properties={}): + def jdbc(self, url, table, mode=None, properties=None): """Saves the content of the :class:`DataFrame` to a external database table via JDBC. .. note:: Don't create too many partitions in parallel on a large cluster;\ @@ -434,6 +512,8 @@ def jdbc(self, url, table, mode=None, properties={}): arbitrary string tag/value. Normally at least a "user" and "password" property should be included. 
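The jdbc signatures above move from `properties={}` to `properties=None`, the same fix applied elsewhere in this patch. A small generic sketch of why the mutable default is a problem (names here are illustrative only):

# Anti-pattern: the same dict object is reused by every call.
def connect_bad(url, properties={}):
    properties.setdefault("user", "guest")
    return properties

# Fix used throughout this patch: default to None and allocate per call.
def connect_good(url, properties=None):
    if properties is None:
        properties = dict()
    properties.setdefault("user", "guest")
    return properties

connect_bad("jdbc:a")["password"] = "x"
print(connect_bad("jdbc:b"))   # leaks {'user': 'guest', 'password': 'x'}
print(connect_good("jdbc:b"))  # always a fresh {'user': 'guest'}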
""" + if properties is None: + properties = dict() jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)() for k in properties: jprop.setProperty(k, properties[k]) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index ebd3ea8db6a4..9f5f7cfdf7a6 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -31,6 +31,10 @@ import datetime import py4j +try: + import xmlrunner +except ImportError: + xmlrunner = None if sys.version_info[:2] <= (2, 6): try: @@ -50,16 +54,17 @@ from pyspark.sql.utils import AnalysisException, IllegalArgumentException -class UTC(datetime.tzinfo): - """UTC""" - ZERO = datetime.timedelta(0) +class UTCOffsetTimezone(datetime.tzinfo): + """ + Specifies timezone in UTC offset + """ + + def __init__(self, offset=0): + self.ZERO = datetime.timedelta(hours=offset) def utcoffset(self, dt): return self.ZERO - def tzname(self, dt): - return "UTC" - def dst(self, dt): return self.ZERO @@ -145,12 +150,18 @@ class PythonOnlyPoint(ExamplePoint): __UDT__ = PythonOnlyUDT() +class MyObject(object): + def __init__(self, key, value): + self.key = key + self.value = value + + class DataTypeTests(unittest.TestCase): # regression test for SPARK-6055 def test_data_type_eq(self): lt = LongType() lt2 = pickle.loads(pickle.dumps(LongType())) - self.assertEquals(lt, lt2) + self.assertEqual(lt, lt2) # regression test for SPARK-7978 def test_decimal_type(self): @@ -161,6 +172,25 @@ def test_decimal_type(self): t3 = DecimalType(8) self.assertNotEqual(t2, t3) + # regression test for SPARK-10392 + def test_datetype_equal_zero(self): + dt = DateType() + self.assertEqual(dt.fromInternal(0), datetime.date(1970, 1, 1)) + + +class SQLContextTests(ReusedPySparkTestCase): + def test_get_or_create(self): + sqlCtx = SQLContext.getOrCreate(self.sc) + self.assertTrue(SQLContext.getOrCreate(self.sc) is sqlCtx) + + def test_new_session(self): + sqlCtx = SQLContext.getOrCreate(self.sc) + sqlCtx.setConf("test_key", "a") + sqlCtx2 = sqlCtx.newSession() + sqlCtx2.setConf("test_key", "b") + self.assertEqual(sqlCtx.getConf("test_key", ""), "a") + self.assertEqual(sqlCtx2.getConf("test_key", ""), "b") + class SQLTests(ReusedPySparkTestCase): @@ -179,6 +209,21 @@ def tearDownClass(cls): ReusedPySparkTestCase.tearDownClass() shutil.rmtree(cls.tempdir.name, ignore_errors=True) + def test_row_should_be_read_only(self): + row = Row(a=1, b=2) + self.assertEqual(1, row.a) + + def foo(): + row.a = 3 + self.assertRaises(Exception, foo) + + row2 = self.sqlCtx.range(10).first() + self.assertEqual(0, row2.id) + + def foo2(): + row2.id = 2 + self.assertRaises(Exception, foo2) + def test_range(self): self.assertEqual(self.sqlCtx.range(1, 1).count(), 0) self.assertEqual(self.sqlCtx.range(1, 0, -1).count(), 1) @@ -366,11 +411,17 @@ def test_infer_nested_schema(self): CustomRow(field1=2, field2="row2"), CustomRow(field1=3, field2="row3")]) df = self.sqlCtx.inferSchema(rdd) - self.assertEquals(Row(field1=1, field2=u'row1'), df.first()) + self.assertEqual(Row(field1=1, field2=u'row1'), df.first()) + + def test_create_dataframe_from_objects(self): + data = [MyObject(1, "1"), MyObject(2, "2")] + df = self.sqlCtx.createDataFrame(data) + self.assertEqual(df.dtypes, [("key", "bigint"), ("value", "string")]) + self.assertEqual(df.first(), Row(key=1, value="1")) def test_select_null_literal(self): df = self.sqlCtx.sql("select null as col") - self.assertEquals(Row(col=None), df.first()) + self.assertEqual(Row(col=None), df.first()) def test_apply_schema(self): from 
datetime import date, datetime @@ -486,14 +537,14 @@ def test_apply_schema_with_udt(self): StructField("point", ExamplePointUDT(), False)]) df = self.sqlCtx.createDataFrame([row], schema) point = df.head().point - self.assertEquals(point, ExamplePoint(1.0, 2.0)) + self.assertEqual(point, ExamplePoint(1.0, 2.0)) row = (1.0, PythonOnlyPoint(1.0, 2.0)) schema = StructType([StructField("label", DoubleType(), False), StructField("point", PythonOnlyUDT(), False)]) df = self.sqlCtx.createDataFrame([row], schema) point = df.head().point - self.assertEquals(point, PythonOnlyPoint(1.0, 2.0)) + self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) def test_udf_with_udt(self): from pyspark.sql.tests import ExamplePoint, ExamplePointUDT @@ -521,21 +572,21 @@ def test_parquet_with_udt(self): df0.write.parquet(output_dir) df1 = self.sqlCtx.parquetFile(output_dir) point = df1.head().point - self.assertEquals(point, ExamplePoint(1.0, 2.0)) + self.assertEqual(point, ExamplePoint(1.0, 2.0)) row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) df0 = self.sqlCtx.createDataFrame([row]) df0.write.parquet(output_dir, mode='overwrite') df1 = self.sqlCtx.parquetFile(output_dir) point = df1.head().point - self.assertEquals(point, PythonOnlyPoint(1.0, 2.0)) + self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) def test_column_operators(self): ci = self.df.key cs = self.df.value c = ci == cs self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column)) - rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci) + rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci), (1 ** ci), (ci ** 1) self.assertTrue(all(isinstance(c, Column) for c in rcc)) cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7] self.assertTrue(all(isinstance(c, Column) for c in cb)) @@ -629,6 +680,16 @@ def test_rand_functions(self): for row in rndn: assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1] + # If the specified seed is 0, we should use it. 
+ # https://issues.apache.org/jira/browse/SPARK-9691 + rnd1 = df.select('key', functions.rand(0)).collect() + rnd2 = df.select('key', functions.rand(0)).collect() + self.assertEqual(sorted(rnd1), sorted(rnd2)) + + rndn1 = df.select('key', functions.randn(0)).collect() + rndn2 = df.select('key', functions.randn(0)).collect() + self.assertEqual(sorted(rndn1), sorted(rndn2)) + def test_between_function(self): df = self.sc.parallelize([ Row(a=1, b=2, c=3), @@ -745,7 +806,7 @@ def test_access_column(self): self.assertTrue(isinstance(df['key'], Column)) self.assertTrue(isinstance(df[0], Column)) self.assertRaises(IndexError, lambda: df[2]) - self.assertRaises(IndexError, lambda: df["bad_key"]) + self.assertRaises(AnalysisException, lambda: df["bad_key"]) self.assertRaises(TypeError, lambda: df[{}]) def test_column_name_with_non_ascii(self): @@ -769,7 +830,9 @@ def test_field_accessor(self): df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF() self.assertEqual(1, df.select(df.l[0]).first()[0]) self.assertEqual(1, df.select(df.r["a"]).first()[0]) + self.assertEqual(1, df.select(df["r.a"]).first()[0]) self.assertEqual("b", df.select(df.r["b"]).first()[0]) + self.assertEqual("b", df.select(df["r.b"]).first()[0]) self.assertEqual("v", df.select(df.d["k"]).first()[0]) def test_infer_long_type(self): @@ -781,8 +844,8 @@ def test_infer_long_type(self): output_dir = os.path.join(self.tempdir.name, "infer_long_type") df.saveAsParquetFile(output_dir) df1 = self.sqlCtx.parquetFile(output_dir) - self.assertEquals('a', df1.first().f1) - self.assertEquals(100000000000000, df1.first().f2) + self.assertEqual('a', df1.first().f1) + self.assertEqual(100000000000000, df1.first().f2) self.assertEqual(_infer_type(1), LongType()) self.assertEqual(_infer_type(2**10), LongType()) @@ -802,13 +865,22 @@ def test_filter_with_datetime(self): self.assertEqual(0, df.filter(df.date > date).count()) self.assertEqual(0, df.filter(df.time > time).count()) + def test_filter_with_datetime_timezone(self): + dt1 = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000, tzinfo=UTCOffsetTimezone(0)) + dt2 = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000, tzinfo=UTCOffsetTimezone(1)) + row = Row(date=dt1) + df = self.sqlCtx.createDataFrame([row]) + self.assertEqual(0, df.filter(df.date == dt2).count()) + self.assertEqual(1, df.filter(df.date > dt2).count()) + self.assertEqual(0, df.filter(df.date < dt2).count()) + def test_time_with_timezone(self): day = datetime.date.today() now = datetime.datetime.now() ts = time.mktime(now.timetuple()) # class in __main__ is not serializable - from pyspark.sql.tests import UTC - utc = UTC() + from pyspark.sql.tests import UTCOffsetTimezone + utc = UTCOffsetTimezone() utcnow = datetime.datetime.utcfromtimestamp(ts) # without microseconds # add microseconds to utcnow (keeping year,month,day,hour,minute,second) utcnow = datetime.datetime(*(utcnow.timetuple()[:6] + (now.microsecond, utc))) @@ -945,7 +1017,7 @@ def test_expr(self): row = Row(a="length string", b=75) df = self.sqlCtx.createDataFrame([row]) result = df.select(functions.expr("length(a)")).collect()[0].asDict() - self.assertEqual(13, result["'length(a)"]) + self.assertEqual(13, result["length(a)"]) def test_replace(self): schema = StructType([ @@ -1007,6 +1079,43 @@ def test_capture_illegalargument_exception(self): df = self.sqlCtx.createDataFrame([(1, 2)], ["a", "b"]) self.assertRaisesRegexp(IllegalArgumentException, "1024 is not in the permitted values", lambda: df.select(sha2(df.a, 1024)).collect()) + try: + 
df.select(sha2(df.a, 1024)).collect() + except IllegalArgumentException as e: + self.assertRegexpMatches(e.desc, "1024 is not in the permitted values") + self.assertRegexpMatches(e.stackTrace, + "org.apache.spark.sql.functions") + + def test_with_column_with_existing_name(self): + keys = self.df.withColumn("key", self.df.key).select("key").collect() + self.assertEqual([r.key for r in keys], list(range(100))) + + # regression test for SPARK-10417 + def test_column_iterator(self): + + def foo(): + for x in self.df.key: + break + + self.assertRaises(TypeError, foo) + + # add test for SPARK-10577 (test broadcast join hint) + def test_functions_broadcast(self): + from pyspark.sql.functions import broadcast + + df1 = self.sqlCtx.createDataFrame([(1, "1"), (2, "2")], ("key", "value")) + df2 = self.sqlCtx.createDataFrame([(1, "1"), (2, "2")], ("key", "value")) + + # equijoin - should be converted into broadcast join + plan1 = df1.join(broadcast(df2), "key")._jdf.queryExecution().executedPlan() + self.assertEqual(1, plan1.toString().count("BroadcastHashJoin")) + + # no join key -- should not be a broadcast join + plan2 = df1.join(broadcast(df2))._jdf.queryExecution().executedPlan() + self.assertEqual(0, plan2.toString().count("BroadcastHashJoin")) + + # planner should not crash without a join + broadcast(df1)._jdf.queryExecution().executedPlan() class HiveContextSQLTests(ReusedPySparkTestCase): @@ -1099,5 +1208,48 @@ def test_window_functions(self): for r, ex in zip(rs, expected): self.assertEqual(tuple(r), ex[:len(r)]) + def test_window_functions_without_partitionBy(self): + df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) + w = Window.orderBy("key", df.value) + from pyspark.sql import functions as F + sel = df.select(df.value, df.key, + F.max("key").over(w.rowsBetween(0, 1)), + F.min("key").over(w.rowsBetween(0, 1)), + F.count("key").over(w.rowsBetween(float('-inf'), float('inf'))), + F.rowNumber().over(w), + F.rank().over(w), + F.denseRank().over(w), + F.ntile(2).over(w)) + rs = sorted(sel.collect()) + expected = [ + ("1", 1, 1, 1, 4, 1, 1, 1, 1), + ("2", 1, 1, 1, 4, 2, 2, 2, 1), + ("2", 1, 2, 1, 4, 3, 2, 2, 2), + ("2", 2, 2, 2, 4, 4, 4, 3, 2) + ] + for r, ex in zip(rs, expected): + self.assertEqual(tuple(r), ex[:len(r)]) + + def test_collect_functions(self): + df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) + from pyspark.sql import functions + + self.assertEqual( + sorted(df.select(functions.collect_set(df.key).alias('r')).collect()[0].r), + [1, 2]) + self.assertEqual( + sorted(df.select(functions.collect_list(df.key).alias('r')).collect()[0].r), + [1, 1, 1, 2]) + self.assertEqual( + sorted(df.select(functions.collect_set(df.value).alias('r')).collect()[0].r), + ["1", "2"]) + self.assertEqual( + sorted(df.select(functions.collect_list(df.value).alias('r')).collect()[0].r), + ["1", "2", "2", "2"]) + + if __name__ == "__main__": - unittest.main() + if xmlrunner: + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) + else: + unittest.main() diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 6f74b7162f7c..5bc0773fa866 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -168,10 +168,12 @@ def needConversion(self): return True def toInternal(self, d): - return d and d.toordinal() - self.EPOCH_ORDINAL + if d is not None: + return d.toordinal() - self.EPOCH_ORDINAL def fromInternal(self, v): - return v and datetime.date.fromordinal(v 
+ self.EPOCH_ORDINAL) + if v is not None: + return datetime.date.fromordinal(v + self.EPOCH_ORDINAL) class TimestampType(AtomicType): @@ -467,9 +469,11 @@ def add(self, field, data_type=None, nullable=True, metadata=None): """ Construct a StructType by adding new elements to it to define the schema. The method accepts either: + a) A single parameter which is a StructField object. b) Between 2 and 4 parameters as (name, data_type, nullable (optional), - metadata(optional). The data_type parameter may be either a String or a DataType object + metadata(optional). The data_type parameter may be either a String or a + DataType object. >>> struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) >>> struct2 = StructType([StructField("f1", StringType(), True),\ @@ -535,6 +539,9 @@ def toInternal(self, obj): return tuple(f.toInternal(obj.get(n)) for n, f in zip(self.names, self.fields)) elif isinstance(obj, (tuple, list)): return tuple(f.toInternal(v) for f, v in zip(self.fields, obj)) + elif hasattr(obj, "__dict__"): + d = obj.__dict__ + return tuple(f.toInternal(d.get(n)) for n, f in zip(self.names, self.fields)) else: raise ValueError("Unexpected tuple %r with StructType" % obj) else: @@ -542,6 +549,9 @@ def toInternal(self, obj): return tuple(obj.get(n) for n in self.names) elif isinstance(obj, (list, tuple)): return tuple(obj) + elif hasattr(obj, "__dict__"): + d = obj.__dict__ + return tuple(d.get(n) for n in self.names) else: raise ValueError("Unexpected tuple %r with StructType" % obj) @@ -1117,15 +1127,15 @@ def _verify_type(obj, dataType): return _type = type(dataType) - assert _type in _acceptable_types, "unknown datatype: %s" % dataType + assert _type in _acceptable_types, "unknown datatype: %s for object %r" % (dataType, obj) if _type is StructType: if not isinstance(obj, (tuple, list)): - raise TypeError("StructType can not accept object in type %s" % type(obj)) + raise TypeError("StructType can not accept object %r in type %s" % (obj, type(obj))) else: # subclass of them can not be fromInternald in JVM if type(obj) not in _acceptable_types[_type]: - raise TypeError("%s can not accept object in type %s" % (dataType, type(obj))) + raise TypeError("%s can not accept object %r in type %s" % (dataType, obj, type(obj))) if isinstance(dataType, ArrayType): for i in obj: @@ -1166,6 +1176,8 @@ class Row(tuple): >>> row = Row(name="Alice", age=11) >>> row Row(age=11, name='Alice') + >>> row['name'], row['age'] + ('Alice', 11) >>> row.name, row.age ('Alice', 11) @@ -1197,19 +1209,55 @@ def __new__(self, *args, **kwargs): else: raise ValueError("No args or kwargs") - def asDict(self): + def asDict(self, recursive=False): """ Return as an dict + + :param recursive: turns the nested Row as dict (default: False). 
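The `hasattr(obj, "__dict__")` branches added to StructType.toInternal are what allow createDataFrame to accept plain Python objects, as exercised by the MyObject test earlier in this patch. A minimal sketch with an illustrative class:

class Person(object):          # illustrative class, not part of the patch
    def __init__(self, key, value):
        self.key = key
        self.value = value

# Attribute names become column names, as with Row or namedtuple input.
df = sqlContext.createDataFrame([Person(1, "1"), Person(2, "2")])
df.dtypes   # [('key', 'bigint'), ('value', 'string')]
df.first()  # Row(key=1, value='1')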
+ + >>> Row(name="Alice", age=11).asDict() == {'name': 'Alice', 'age': 11} + True + >>> row = Row(key=1, value=Row(name='a', age=2)) + >>> row.asDict() == {'key': 1, 'value': Row(age=2, name='a')} + True + >>> row.asDict(True) == {'key': 1, 'value': {'name': 'a', 'age': 2}} + True """ if not hasattr(self, "__fields__"): raise TypeError("Cannot convert a Row class into dict") - return dict(zip(self.__fields__, self)) + + if recursive: + def conv(obj): + if isinstance(obj, Row): + return obj.asDict(True) + elif isinstance(obj, list): + return [conv(o) for o in obj] + elif isinstance(obj, dict): + return dict((k, conv(v)) for k, v in obj.items()) + else: + return obj + return dict(zip(self.__fields__, (conv(o) for o in self))) + else: + return dict(zip(self.__fields__, self)) # let object acts like class def __call__(self, *args): """create new Row object""" return _create_row(self, args) + def __getitem__(self, item): + if isinstance(item, (int, slice)): + return super(Row, self).__getitem__(item) + try: + # it will be slow when it has many fields, + # but this will not be used in normal cases + idx = self.__fields__.index(item) + return super(Row, self).__getitem__(idx) + except IndexError: + raise KeyError(item) + except ValueError: + raise ValueError(item) + def __getattr__(self, item): if item.startswith("__"): raise AttributeError(item) @@ -1223,6 +1271,11 @@ def __getattr__(self, item): except ValueError: raise AttributeError(item) + def __setattr__(self, key, value): + if key != '__fields__': + raise Exception("Row is read-only") + self.__dict__[key] = value + def __reduce__(self): """Returns a tuple so Python knows how to pickle Row.""" if hasattr(self, "__fields__"): @@ -1254,8 +1307,11 @@ def can_convert(self, obj): def convert(self, obj, gateway_client): Timestamp = JavaClass("java.sql.Timestamp", gateway_client) - return Timestamp(int(time.mktime(obj.timetuple())) * 1000 + obj.microsecond // 1000) - + seconds = (calendar.timegm(obj.utctimetuple()) if obj.tzinfo + else time.mktime(obj.timetuple())) + t = Timestamp(int(seconds) * 1000) + t.setNanos(obj.microsecond * 1000) + return t # datetime is a subclass of date, we should register DatetimeConverter first register_input_converter(DatetimeConverter()) diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 0f795ca35b38..b0a0373372d2 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -18,13 +18,22 @@ import py4j -class AnalysisException(Exception): +class CapturedException(Exception): + def __init__(self, desc, stackTrace): + self.desc = desc + self.stackTrace = stackTrace + + def __str__(self): + return repr(self.desc) + + +class AnalysisException(CapturedException): """ Failed to analyze a SQL query plan. """ -class IllegalArgumentException(Exception): +class IllegalArgumentException(CapturedException): """ Passed an illegal or inappropriate argument. 
""" @@ -36,10 +45,12 @@ def deco(*a, **kw): return f(*a, **kw) except py4j.protocol.Py4JJavaError as e: s = e.java_exception.toString() + stackTrace = '\n\t at '.join(map(lambda x: x.toString(), + e.java_exception.getStackTrace())) if s.startswith('org.apache.spark.sql.AnalysisException: '): - raise AnalysisException(s.split(': ', 1)[1]) + raise AnalysisException(s.split(': ', 1)[1], stackTrace) if s.startswith('java.lang.IllegalArgumentException: '): - raise IllegalArgumentException(s.split(': ', 1)[1]) + raise IllegalArgumentException(s.split(': ', 1)[1], stackTrace) raise return deco @@ -60,3 +71,16 @@ def install_exception_handler(): patched = capture_sql_exception(original) # only patch the one used in in py4j.java_gateway (call Java API) py4j.java_gateway.get_return_value = patched + + +def toJArray(gateway, jtype, arr): + """ + Convert python list to java type array + :param gateway: Py4j Gateway + :param jtype: java type of element in array + :param arr: python type list + """ + jarr = gateway.new_array(jtype, len(arr)) + for i in range(0, len(arr)): + jarr[i] = arr[i] + return jarr diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py index c74745c726a0..57bbe340bbd4 100644 --- a/python/pyspark/sql/window.py +++ b/python/pyspark/sql/window.py @@ -17,8 +17,7 @@ import sys -from pyspark import SparkContext -from pyspark.sql import since +from pyspark import since, SparkContext from pyspark.sql.column import _to_seq, _to_java_column __all__ = ["Window", "WindowSpec"] @@ -64,7 +63,7 @@ def orderBy(*cols): Creates a :class:`WindowSpec` with the partitioning defined. """ sc = SparkContext._active_spark_context - jspec = sc._jvm.org.apache.spark.sql.expressions.Window.partitionBy(_to_java_cols(cols)) + jspec = sc._jvm.org.apache.spark.sql.expressions.Window.orderBy(_to_java_cols(cols)) return WindowSpec(jspec) diff --git a/python/pyspark/statcounter.py b/python/pyspark/statcounter.py index 944fa414b0c0..03ea0b6d33c9 100644 --- a/python/pyspark/statcounter.py +++ b/python/pyspark/statcounter.py @@ -30,7 +30,9 @@ class StatCounter(object): - def __init__(self, values=[]): + def __init__(self, values=None): + if values is None: + values = list() self.n = 0 # Running count of our values self.mu = 0.0 # Running mean of our values self.m2 = 0.0 # Running variance numerator (sum of (x - mean)^2) @@ -129,6 +131,28 @@ def stdev(self): def sampleStdev(self): return sqrt(self.sampleVariance()) + def asDict(self, sample=False): + """Returns the :class:`StatCounter` members as a ``dict``. 
+ + >>> sc.parallelize([1., 2., 3., 4.]).stats().asDict() + {'count': 4L, + 'max': 4.0, + 'mean': 2.5, + 'min': 1.0, + 'stdev': 1.2909944487358056, + 'sum': 10.0, + 'variance': 1.6666666666666667} + """ + return { + 'count': self.count(), + 'mean': self.mean(), + 'sum': self.sum(), + 'min': self.min(), + 'max': self.max(), + 'stdev': self.stdev() if sample else self.sampleStdev(), + 'variance': self.variance() if sample else self.sampleVariance() + } + def __repr__(self): return ("(count: %s, mean: %s, stdev: %s, max: %s, min: %s)" % (self.count(), self.mean(), self.stdev(), self.max(), self.min())) diff --git a/python/pyspark/streaming/__init__.py b/python/pyspark/streaming/__init__.py index d2644a1d4ffa..66e8f8ef001e 100644 --- a/python/pyspark/streaming/__init__.py +++ b/python/pyspark/streaming/__init__.py @@ -17,5 +17,6 @@ from pyspark.streaming.context import StreamingContext from pyspark.streaming.dstream import DStream +from pyspark.streaming.listener import StreamingListener -__all__ = ['StreamingContext', 'DStream'] +__all__ = ['StreamingContext', 'DStream', 'StreamingListener'] diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index ac5ba69e8dbb..1388b6d044e0 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -32,48 +32,6 @@ __all__ = ["StreamingContext"] -def _daemonize_callback_server(): - """ - Hack Py4J to daemonize callback server - - The thread of callback server has daemon=False, it will block the driver - from exiting if it's not shutdown. The following code replace `start()` - of CallbackServer with a new version, which set daemon=True for this - thread. - - Also, it will update the port number (0) with real port - """ - # TODO: create a patch for Py4J - import socket - import py4j.java_gateway - logger = py4j.java_gateway.logger - from py4j.java_gateway import Py4JNetworkError - from threading import Thread - - def start(self): - """Starts the CallbackServer. This method should be called by the - client instead of run().""" - self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, - 1) - try: - self.server_socket.bind((self.address, self.port)) - if not self.port: - # update port with real port - self.port = self.server_socket.getsockname()[1] - except Exception as e: - msg = 'An error occurred while trying to start the callback server: %s' % e - logger.exception(msg) - raise Py4JNetworkError(msg) - - # Maybe thread needs to be cleanup up? - self.thread = Thread(target=self.run) - self.thread.daemon = True - self.thread.start() - - py4j.java_gateway.CallbackServer.start = start - - class StreamingContext(object): """ Main entry point for Spark Streaming functionality. A StreamingContext @@ -86,6 +44,9 @@ class StreamingContext(object): """ _transformerSerializer = None + # Reference to a currently active StreamingContext + _activeContext = None + def __init__(self, sparkContext, batchDuration=None, jssc=None): """ Create a new StreamingContext. 
@@ -120,10 +81,14 @@ def _ensure_initialized(cls): # start callback server # getattr will fallback to JVM, so we cannot test by hasattr() - if "_callback_server" not in gw.__dict__: - _daemonize_callback_server() - # use random port - gw._start_callback_server(0) + if "_callback_server" not in gw.__dict__ or gw._callback_server is None: + gw.callback_server_parameters.eager_load = True + gw.callback_server_parameters.daemonize = True + gw.callback_server_parameters.daemonize_connections = True + gw.callback_server_parameters.port = 0 + gw.start_callback_server(gw.callback_server_parameters) + cbport = gw._callback_server.server_socket.getsockname()[1] + gw._callback_server.port = cbport # gateway with real port gw._python_proxy_port = gw._callback_server.port # get the GatewayServer object in JVM by ID @@ -142,34 +107,84 @@ def getOrCreate(cls, checkpointPath, setupFunc): Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be recreated from the checkpoint data. If the data does not exist, then the provided setupFunc - will be used to create a JavaStreamingContext. + will be used to create a new context. - @param checkpointPath: Checkpoint directory used in an earlier JavaStreamingContext program - @param setupFunc: Function to create a new JavaStreamingContext and setup DStreams + @param checkpointPath: Checkpoint directory used in an earlier streaming program + @param setupFunc: Function to create a new context and setup DStreams """ - # TODO: support checkpoint in HDFS - if not os.path.exists(checkpointPath) or not os.listdir(checkpointPath): + cls._ensure_initialized() + gw = SparkContext._gateway + + # Check whether valid checkpoint information exists in the given path + if gw.jvm.CheckpointReader.read(checkpointPath).isEmpty(): ssc = setupFunc() ssc.checkpoint(checkpointPath) return ssc - cls._ensure_initialized() - gw = SparkContext._gateway - try: jssc = gw.jvm.JavaStreamingContext(checkpointPath) except Exception: print("failed to load StreamingContext from checkpoint", file=sys.stderr) raise - jsc = jssc.sparkContext() - conf = SparkConf(_jconf=jsc.getConf()) - sc = SparkContext(conf=conf, gateway=gw, jsc=jsc) + # If there is already an active instance of Python SparkContext use it, or create a new one + if not SparkContext._active_spark_context: + jsc = jssc.sparkContext() + conf = SparkConf(_jconf=jsc.getConf()) + SparkContext(conf=conf, gateway=gw, jsc=jsc) + + sc = SparkContext._active_spark_context + # update ctx in serializer - SparkContext._active_spark_context = sc cls._transformerSerializer.ctx = sc return StreamingContext(sc, None, jssc) + @classmethod + def getActive(cls): + """ + Return either the currently active StreamingContext (i.e., if there is a context started + but not stopped) or None. 
+ """ + activePythonContext = cls._activeContext + if activePythonContext is not None: + # Verify that the current running Java StreamingContext is active and is the same one + # backing the supposedly active Python context + activePythonContextJavaId = activePythonContext._jssc.ssc().hashCode() + activeJvmContextOption = activePythonContext._jvm.StreamingContext.getActive() + + if activeJvmContextOption.isEmpty(): + cls._activeContext = None + elif activeJvmContextOption.get().hashCode() != activePythonContextJavaId: + cls._activeContext = None + raise Exception("JVM's active JavaStreamingContext is not the JavaStreamingContext " + "backing the action Python StreamingContext. This is unexpected.") + return cls._activeContext + + @classmethod + def getActiveOrCreate(cls, checkpointPath, setupFunc): + """ + Either return the active StreamingContext (i.e. currently started but not stopped), + or recreate a StreamingContext from checkpoint data or create a new StreamingContext + using the provided setupFunc function. If the checkpointPath is None or does not contain + valid checkpoint data, then setupFunc will be called to create a new context and setup + DStreams. + + @param checkpointPath: Checkpoint directory used in an earlier streaming program. Can be + None if the intention is to always create a new context when there + is no active context. + @param setupFunc: Function to create a new JavaStreamingContext and setup DStreams + """ + + if setupFunc is None: + raise Exception("setupFunc cannot be None") + activeContext = cls.getActive() + if activeContext is not None: + return activeContext + elif checkpointPath is not None: + return cls.getOrCreate(checkpointPath, setupFunc) + else: + return setupFunc() + @property def sparkContext(self): """ @@ -182,10 +197,12 @@ def start(self): Start the execution of the streams. """ self._jssc.start() + StreamingContext._activeContext = self def awaitTermination(self, timeout=None): """ Wait for the execution to stop. + @param timeout: time to wait in seconds """ if timeout is None: @@ -198,9 +215,10 @@ def awaitTerminationOrTimeout(self, timeout): Wait for the execution to stop. Return `true` if it's stopped; or throw the reported error during the execution; or `false` if the waiting time elapsed before returning from the method. + @param timeout: time to wait in seconds """ - self._jssc.awaitTerminationOrTimeout(int(timeout * 1000)) + return self._jssc.awaitTerminationOrTimeout(int(timeout * 1000)) def stop(self, stopSparkContext=True, stopGraceFully=False): """ @@ -212,6 +230,7 @@ def stop(self, stopSparkContext=True, stopGraceFully=False): of all received data to be completed """ self._jssc.stop(stopSparkContext, stopGraceFully) + StreamingContext._activeContext = None if stopSparkContext: self._sc.stop() @@ -344,3 +363,11 @@ def union(self, *dstreams): first = dstreams[0] jrest = [d._jdstream for d in dstreams[1:]] return DStream(self._jssc.union(first._jdstream, jrest), self, first._jrdd_deserializer) + + def addStreamingListener(self, streamingListener): + """ + Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for + receiving system events related to streaming. 
+ """ + self._jssc.addStreamingListener(self._jvm.JavaStreamingListenerWrapper( + self._jvm.PythonStreamingListenerWrapper(streamingListener))) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 8dcb9645cdc6..b994a53bf2b8 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -524,8 +524,8 @@ def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration=None `invFunc` can be None, then it will reduce all the RDDs in window, could be slower than having `invFunc`. - @param reduceFunc: associative reduce function - @param invReduceFunc: inverse function of `reduceFunc` + @param func: associative reduce function + @param invFunc: inverse function of `reduceFunc` @param windowDuration: width of the window; must be a multiple of this DStream's batching interval @param slideDuration: sliding interval of the window (i.e., the interval after which @@ -542,33 +542,34 @@ def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration=None reduced = self.reduceByKey(func, numPartitions) - def reduceFunc(t, a, b): - b = b.reduceByKey(func, numPartitions) - r = a.union(b).reduceByKey(func, numPartitions) if a else b - if filterFunc: - r = r.filter(filterFunc) - return r - - def invReduceFunc(t, a, b): - b = b.reduceByKey(func, numPartitions) - joined = a.leftOuterJoin(b, numPartitions) - return joined.mapValues(lambda kv: invFunc(kv[0], kv[1]) - if kv[1] is not None else kv[0]) - - jreduceFunc = TransformFunction(self._sc, reduceFunc, reduced._jrdd_deserializer) - if invReduceFunc: + if invFunc: + def reduceFunc(t, a, b): + b = b.reduceByKey(func, numPartitions) + r = a.union(b).reduceByKey(func, numPartitions) if a else b + if filterFunc: + r = r.filter(filterFunc) + return r + + def invReduceFunc(t, a, b): + b = b.reduceByKey(func, numPartitions) + joined = a.leftOuterJoin(b, numPartitions) + return joined.mapValues(lambda kv: invFunc(kv[0], kv[1]) + if kv[1] is not None else kv[0]) + + jreduceFunc = TransformFunction(self._sc, reduceFunc, reduced._jrdd_deserializer) jinvReduceFunc = TransformFunction(self._sc, invReduceFunc, reduced._jrdd_deserializer) + if slideDuration is None: + slideDuration = self._slideDuration + dstream = self._sc._jvm.PythonReducedWindowedDStream( + reduced._jdstream.dstream(), + jreduceFunc, jinvReduceFunc, + self._ssc._jduration(windowDuration), + self._ssc._jduration(slideDuration)) + return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer) else: - jinvReduceFunc = None - if slideDuration is None: - slideDuration = self._slideDuration - dstream = self._sc._jvm.PythonReducedWindowedDStream(reduced._jdstream.dstream(), - jreduceFunc, jinvReduceFunc, - self._ssc._jduration(windowDuration), - self._ssc._jduration(slideDuration)) - return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer) + return reduced.window(windowDuration, slideDuration).reduceByKey(func, numPartitions) - def updateStateByKey(self, updateFunc, numPartitions=None): + def updateStateByKey(self, updateFunc, numPartitions=None, initialRDD=None): """ Return a new "state" DStream where the state for each key is updated by applying the given function on the previous state of the key and the new values of the key. 
@@ -579,6 +580,9 @@ def updateStateByKey(self, updateFunc, numPartitions=None): if numPartitions is None: numPartitions = self._sc.defaultParallelism + if initialRDD and not isinstance(initialRDD, RDD): + initialRDD = self._sc.parallelize(initialRDD) + def reduceFunc(t, a, b): if a is None: g = b.groupByKey(numPartitions).mapValues(lambda vs: (list(vs), None)) @@ -590,7 +594,13 @@ def reduceFunc(t, a, b): jreduceFunc = TransformFunction(self._sc, reduceFunc, self._sc.serializer, self._jrdd_deserializer) - dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc) + if initialRDD: + initialRDD = initialRDD._reserialize(self._jrdd_deserializer) + dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc, + initialRDD._jrdd) + else: + dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc) + return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer) @@ -610,7 +620,10 @@ def __init__(self, prev, func): self.is_checkpointed = False self._jdstream_val = None - if (isinstance(prev, TransformedDStream) and + # Using type() to avoid folding the functions and compacting the DStreams which is not + # not strictly a object of TransformedDStream. + # Changed here is to avoid bug in KafkaTransformedDStream when calling offsetRanges(). + if (type(prev) is TransformedDStream and not prev.is_cached and not prev.is_checkpointed): prev_func = prev.func self.func = lambda t, rdd: func(t, prev_func(t, rdd)) diff --git a/python/pyspark/streaming/flume.py b/python/pyspark/streaming/flume.py index cbb573f226bb..b3d190536592 100644 --- a/python/pyspark/streaming/flume.py +++ b/python/pyspark/streaming/flume.py @@ -20,7 +20,7 @@ from io import BytesIO else: from StringIO import StringIO -from py4j.java_gateway import Py4JJavaError +from py4j.protocol import Py4JJavaError from pyspark.storagelevel import StorageLevel from pyspark.serializers import PairDeserializer, NoOpSerializer, UTF8Deserializer, read_int @@ -31,7 +31,9 @@ def utf8_decoder(s): """ Decode the unicode as UTF-8 """ - return s and s.decode('utf-8') + if s is None: + return None + return s.decode('utf-8') class FlumeUtils(object): diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index 33dd596335b4..cdf97ec73aaf 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -15,27 +15,31 @@ # limitations under the License. 
# -from py4j.java_gateway import Py4JJavaError +from py4j.protocol import Py4JJavaError from pyspark.rdd import RDD from pyspark.storagelevel import StorageLevel -from pyspark.serializers import PairDeserializer, NoOpSerializer +from pyspark.serializers import AutoBatchedSerializer, PickleSerializer, PairDeserializer, \ + NoOpSerializer from pyspark.streaming import DStream from pyspark.streaming.dstream import TransformedDStream from pyspark.streaming.util import TransformFunction -__all__ = ['Broker', 'KafkaUtils', 'OffsetRange', 'TopicAndPartition', 'utf8_decoder'] +__all__ = ['Broker', 'KafkaMessageAndMetadata', 'KafkaUtils', 'OffsetRange', + 'TopicAndPartition', 'utf8_decoder'] def utf8_decoder(s): """ Decode the unicode as UTF-8 """ - return s and s.decode('utf-8') + if s is None: + return None + return s.decode('utf-8') class KafkaUtils(object): @staticmethod - def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, + def createStream(ssc, zkQuorum, groupId, topics, kafkaParams=None, storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2, keyDecoder=utf8_decoder, valueDecoder=utf8_decoder): """ @@ -52,6 +56,8 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, :param valueDecoder: A function used to decode value (default is utf8_decoder) :return: A DStream object """ + if kafkaParams is None: + kafkaParams = dict() kafkaParams.update({ "zookeeper.connect": zkQuorum, "group.id": groupId, @@ -77,8 +83,9 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) @staticmethod - def createDirectStream(ssc, topics, kafkaParams, fromOffsets={}, - keyDecoder=utf8_decoder, valueDecoder=utf8_decoder): + def createDirectStream(ssc, topics, kafkaParams, fromOffsets=None, + keyDecoder=utf8_decoder, valueDecoder=utf8_decoder, + messageHandler=None): """ .. note:: Experimental @@ -103,13 +110,25 @@ def createDirectStream(ssc, topics, kafkaParams, fromOffsets={}, point of the stream. :param keyDecoder: A function used to decode key (default is utf8_decoder). :param valueDecoder: A function used to decode value (default is utf8_decoder). + :param messageHandler: A function used to convert KafkaMessageAndMetadata. You can assess + meta using messageHandler (default is None). 
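
A short sketch of the new messageHandler hook on createDirectStream(), mirroring test_kafka_direct_stream_message_handler later in this patch; ssc, the broker address and the topic name are placeholders:

from pyspark.streaming.kafka import KafkaUtils


def with_metadata(m):
    # m is a KafkaMessageAndMetadata; key and message are decoded lazily.
    return (m.topic, m.partition, m.offset, m.key, m.message)

kafka_params = {"metadata.broker.list": "localhost:9092"}   # placeholder broker
stream = KafkaUtils.createDirectStream(ssc, ["my-topic"], kafka_params,
                                       messageHandler=with_metadata)
stream.pprint()
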
:return: A DStream object """ + if fromOffsets is None: + fromOffsets = dict() if not isinstance(topics, list): raise TypeError("topics should be list") if not isinstance(kafkaParams, dict): raise TypeError("kafkaParams should be dict") + def funcWithoutMessageHandler(k_v): + return (keyDecoder(k_v[0]), valueDecoder(k_v[1])) + + def funcWithMessageHandler(m): + m._set_key_decoder(keyDecoder) + m._set_value_decoder(valueDecoder) + return messageHandler(m) + try: helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper") @@ -117,20 +136,28 @@ def createDirectStream(ssc, topics, kafkaParams, fromOffsets={}, jfromOffsets = dict([(k._jTopicAndPartition(helper), v) for (k, v) in fromOffsets.items()]) - jstream = helper.createDirectStream(ssc._jssc, kafkaParams, set(topics), jfromOffsets) + if messageHandler is None: + ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) + func = funcWithoutMessageHandler + jstream = helper.createDirectStreamWithoutMessageHandler( + ssc._jssc, kafkaParams, set(topics), jfromOffsets) + else: + ser = AutoBatchedSerializer(PickleSerializer()) + func = funcWithMessageHandler + jstream = helper.createDirectStreamWithMessageHandler( + ssc._jssc, kafkaParams, set(topics), jfromOffsets) except Py4JJavaError as e: if 'ClassNotFoundException' in str(e.java_exception): KafkaUtils._printErrorMsg(ssc.sparkContext) raise e - ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) - stream = DStream(jstream, ssc, ser) \ - .map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) + stream = DStream(jstream, ssc, ser).map(func) return KafkaDStream(stream._jdstream, ssc, stream._jrdd_deserializer) @staticmethod - def createRDD(sc, kafkaParams, offsetRanges, leaders={}, - keyDecoder=utf8_decoder, valueDecoder=utf8_decoder): + def createRDD(sc, kafkaParams, offsetRanges, leaders=None, + keyDecoder=utf8_decoder, valueDecoder=utf8_decoder, + messageHandler=None): """ .. note:: Experimental @@ -143,13 +170,25 @@ def createRDD(sc, kafkaParams, offsetRanges, leaders={}, map, in which case leaders will be looked up on the driver. :param keyDecoder: A function used to decode key (default is utf8_decoder) :param valueDecoder: A function used to decode value (default is utf8_decoder) + :param messageHandler: A function used to convert KafkaMessageAndMetadata. You can assess + meta using messageHandler (default is None). 
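
The batch-oriented KafkaUtils.createRDD() accepts the same messageHandler; a sketch along the lines of test_kafka_rdd_message_handler below, assuming six messages in partition 0 of a placeholder topic:

from pyspark.streaming.kafka import KafkaUtils, OffsetRange


def double_message(m):
    return (m.key, m.message * 2)

offset_ranges = [OffsetRange("my-topic", 0, 0, 6)]          # placeholder offsets
rdd = KafkaUtils.createRDD(sc, {"metadata.broker.list": "localhost:9092"},
                           offset_ranges, messageHandler=double_message)
print(rdd.collect())
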
:return: A RDD object """ + if leaders is None: + leaders = dict() if not isinstance(kafkaParams, dict): raise TypeError("kafkaParams should be dict") if not isinstance(offsetRanges, list): raise TypeError("offsetRanges should be list") + def funcWithoutMessageHandler(k_v): + return (keyDecoder(k_v[0]), valueDecoder(k_v[1])) + + def funcWithMessageHandler(m): + m._set_key_decoder(keyDecoder) + m._set_value_decoder(valueDecoder) + return messageHandler(m) + try: helperClass = sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper") @@ -157,15 +196,21 @@ def createRDD(sc, kafkaParams, offsetRanges, leaders={}, joffsetRanges = [o._jOffsetRange(helper) for o in offsetRanges] jleaders = dict([(k._jTopicAndPartition(helper), v._jBroker(helper)) for (k, v) in leaders.items()]) - jrdd = helper.createRDD(sc._jsc, kafkaParams, joffsetRanges, jleaders) + if messageHandler is None: + jrdd = helper.createRDDWithoutMessageHandler( + sc._jsc, kafkaParams, joffsetRanges, jleaders) + ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) + rdd = RDD(jrdd, sc, ser).map(funcWithoutMessageHandler) + else: + jrdd = helper.createRDDWithMessageHandler( + sc._jsc, kafkaParams, joffsetRanges, jleaders) + rdd = RDD(jrdd, sc).map(funcWithMessageHandler) except Py4JJavaError as e: if 'ClassNotFoundException' in str(e.java_exception): KafkaUtils._printErrorMsg(sc) raise e - ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) - rdd = RDD(jrdd, sc, ser).map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) - return KafkaRDD(rdd._jrdd, rdd.ctx, rdd._jrdd_deserializer) + return KafkaRDD(rdd._jrdd, sc, rdd._jrdd_deserializer) @staticmethod def _printErrorMsg(sc): @@ -246,6 +291,16 @@ def __init__(self, topic, partition): def _jTopicAndPartition(self, helper): return helper.createTopicAndPartition(self._topic, self._partition) + def __eq__(self, other): + if isinstance(other, self.__class__): + return (self._topic == other._topic + and self._partition == other._partition) + else: + return False + + def __ne__(self, other): + return not self.__eq__(other) + class Broker(object): """ @@ -347,3 +402,53 @@ def _jdstream(self): dstream = self._sc._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(), jfunc) self._jdstream_val = dstream.asJavaDStream() return self._jdstream_val + + +class KafkaMessageAndMetadata(object): + """ + Kafka message and metadata information. Including topic, partition, offset and message + """ + + def __init__(self, topic, partition, offset, key, message): + """ + Python wrapper of Kafka MessageAndMetadata + :param topic: topic name of this Kafka message + :param partition: partition id of this Kafka message + :param offset: Offset of this Kafka message in the specific partition + :param key: key payload of this Kafka message, can be null if this Kafka message has no key + specified, the return data is undecoded bytearry. + :param message: actual message payload of this Kafka message, the return data is + undecoded bytearray. 
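
The __eq__/__ne__ added to TopicAndPartition enable value-based comparison, as exercised by test_topic_and_partition_equality later in this patch; a trivial sketch:

from pyspark.streaming.kafka import TopicAndPartition

assert TopicAndPartition("foo", 0) == TopicAndPartition("foo", 0)
assert TopicAndPartition("foo", 0) != TopicAndPartition("bar", 0)
assert TopicAndPartition("foo", 0) != TopicAndPartition("foo", 1)
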
+ """ + self.topic = topic + self.partition = partition + self.offset = offset + self._rawKey = key + self._rawMessage = message + self._keyDecoder = utf8_decoder + self._valueDecoder = utf8_decoder + + def __str__(self): + return "KafkaMessageAndMetadata(topic: %s, partition: %d, offset: %d, key and message...)" \ + % (self.topic, self.partition, self.offset) + + def __repr__(self): + return self.__str__() + + def __reduce__(self): + return (KafkaMessageAndMetadata, + (self.topic, self.partition, self.offset, self._rawKey, self._rawMessage)) + + def _set_key_decoder(self, decoder): + self._keyDecoder = decoder + + def _set_value_decoder(self, decoder): + self._valueDecoder = decoder + + @property + def key(self): + return self._keyDecoder(self._rawKey) + + @property + def message(self): + return self._valueDecoder(self._rawMessage) diff --git a/python/pyspark/streaming/kinesis.py b/python/pyspark/streaming/kinesis.py index bcfe2703fecf..af72c3d6903f 100644 --- a/python/pyspark/streaming/kinesis.py +++ b/python/pyspark/streaming/kinesis.py @@ -15,7 +15,7 @@ # limitations under the License. # -from py4j.java_gateway import Py4JJavaError +from py4j.protocol import Py4JJavaError from pyspark.serializers import PairDeserializer, NoOpSerializer from pyspark.storagelevel import StorageLevel @@ -26,7 +26,9 @@ def utf8_decoder(s): """ Decode the unicode as UTF-8 """ - return s and s.decode('utf-8') + if s is None: + return None + return s.decode('utf-8') class KinesisUtils(object): diff --git a/python/pyspark/streaming/listener.py b/python/pyspark/streaming/listener.py new file mode 100644 index 000000000000..b830797f5c0a --- /dev/null +++ b/python/pyspark/streaming/listener.py @@ -0,0 +1,75 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +__all__ = ["StreamingListener"] + + +class StreamingListener(object): + + def __init__(self): + pass + + def onReceiverStarted(self, receiverStarted): + """ + Called when a receiver has been started + """ + pass + + def onReceiverError(self, receiverError): + """ + Called when a receiver has reported an error + """ + pass + + def onReceiverStopped(self, receiverStopped): + """ + Called when a receiver has been stopped + """ + pass + + def onBatchSubmitted(self, batchSubmitted): + """ + Called when a batch of jobs has been submitted for processing. + """ + pass + + def onBatchStarted(self, batchStarted): + """ + Called when processing of a batch of jobs has started. + """ + pass + + def onBatchCompleted(self, batchCompleted): + """ + Called when processing of a batch of jobs has completed. + """ + pass + + def onOutputOperationStarted(self, outputOperationStarted): + """ + Called when processing of a job of a batch has started. 
+ """ + pass + + def onOutputOperationCompleted(self, outputOperationCompleted): + """ + Called when processing of a job of a batch has completed + """ + pass + + class Java: + implements = ["org.apache.spark.streaming.api.java.PythonStreamingListener"] diff --git a/python/pyspark/streaming/mqtt.py b/python/pyspark/streaming/mqtt.py new file mode 100644 index 000000000000..1ce4093196e6 --- /dev/null +++ b/python/pyspark/streaming/mqtt.py @@ -0,0 +1,73 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from py4j.protocol import Py4JJavaError + +from pyspark.storagelevel import StorageLevel +from pyspark.serializers import UTF8Deserializer +from pyspark.streaming import DStream + +__all__ = ['MQTTUtils'] + + +class MQTTUtils(object): + + @staticmethod + def createStream(ssc, brokerUrl, topic, + storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2): + """ + Create an input stream that pulls messages from a Mqtt Broker. + + :param ssc: StreamingContext object + :param brokerUrl: Url of remote mqtt publisher + :param topic: topic name to subscribe to + :param storageLevel: RDD storage level. + :return: A DStream object + """ + jlevel = ssc._sc._getJavaStorageLevel(storageLevel) + + try: + helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.mqtt.MQTTUtilsPythonHelper") + helper = helperClass.newInstance() + jstream = helper.createStream(ssc._jssc, brokerUrl, topic, jlevel) + except Py4JJavaError as e: + if 'ClassNotFoundException' in str(e.java_exception): + MQTTUtils._printErrorMsg(ssc.sparkContext) + raise e + + return DStream(jstream, ssc, UTF8Deserializer()) + + @staticmethod + def _printErrorMsg(sc): + print(""" +________________________________________________________________________________________________ + + Spark Streaming's MQTT libraries not found in class path. Try one of the following. + + 1. Include the MQTT library and its dependencies with in the + spark-submit command as + + $ bin/spark-submit --packages org.apache.spark:spark-streaming-mqtt:%s ... + + 2. Download the JAR of the artifact from Maven Central http://search.maven.org/, + Group Id = org.apache.spark, Artifact Id = spark-streaming-mqtt-assembly, Version = %s. + Then, include the jar in the spark-submit command as + + $ bin/spark-submit --jars ... 
+________________________________________________________________________________________________ +""" % (sc.version, sc.version)) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 5cd544b2144e..4949cd68e321 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -24,8 +24,14 @@ import tempfile import random import struct +import shutil from functools import reduce +try: + import xmlrunner +except ImportError: + xmlrunner = None + if sys.version_info[:2] <= (2, 6): try: import unittest2 as unittest @@ -40,7 +46,9 @@ from pyspark.streaming.context import StreamingContext from pyspark.streaming.kafka import Broker, KafkaUtils, OffsetRange, TopicAndPartition from pyspark.streaming.flume import FlumeUtils +from pyspark.streaming.mqtt import MQTTUtils from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream +from pyspark.streaming.listener import StreamingListener class PySparkStreamingTestCase(unittest.TestCase): @@ -58,12 +66,27 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): cls.sc.stop() + # Clean up in the JVM just in case there has been some issues in Python API + try: + jSparkContextOption = SparkContext._jvm.SparkContext.get() + if jSparkContextOption.nonEmpty(): + jSparkContextOption.get().stop() + except: + pass def setUp(self): self.ssc = StreamingContext(self.sc, self.duration) def tearDown(self): - self.ssc.stop(False) + if self.ssc is not None: + self.ssc.stop(False) + # Clean up in the JVM just in case there has been some issues in Python API + try: + jStreamingContextOption = StreamingContext._jvm.SparkContext.getActive() + if jStreamingContextOption.nonEmpty(): + jStreamingContextOption.get().stop(False) + except: + pass def wait_for(self, result, n): start_time = time.time() @@ -380,6 +403,216 @@ def func(dstream): expected = [[('k', v)] for v in expected] self._test_func(input, func, expected) + def test_update_state_by_key_initial_rdd(self): + + def updater(vs, s): + if not s: + s = [] + s.extend(vs) + return s + + initial = [('k', [0, 1])] + initial = self.sc.parallelize(initial, 1) + + input = [[('k', i)] for i in range(2, 5)] + + def func(dstream): + return dstream.updateStateByKey(updater, initialRDD=initial) + + expected = [[0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]] + expected = [[('k', v)] for v in expected] + self._test_func(input, func, expected) + + def test_failed_func(self): + # Test failure in + # TransformFunction.apply(rdd: Option[RDD[_]], time: Time) + input = [self.sc.parallelize([d], 1) for d in range(4)] + input_stream = self.ssc.queueStream(input) + + def failed_func(i): + raise ValueError("This is a special error") + + input_stream.map(failed_func).pprint() + self.ssc.start() + try: + self.ssc.awaitTerminationOrTimeout(10) + except: + import traceback + failure = traceback.format_exc() + self.assertTrue("This is a special error" in failure) + return + + self.fail("a failed func should throw an error") + + def test_failed_func2(self): + # Test failure in + # TransformFunction.apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time) + input = [self.sc.parallelize([d], 1) for d in range(4)] + input_stream1 = self.ssc.queueStream(input) + input_stream2 = self.ssc.queueStream(input) + + def failed_func(rdd1, rdd2): + raise ValueError("This is a special error") + + input_stream1.transformWith(failed_func, input_stream2, True).pprint() + self.ssc.start() + try: + self.ssc.awaitTerminationOrTimeout(10) + except: + import traceback + failure = 
traceback.format_exc() + self.assertTrue("This is a special error" in failure) + return + + self.fail("a failed func should throw an error") + + def test_failed_func_with_reseting_failure(self): + input = [self.sc.parallelize([d], 1) for d in range(4)] + input_stream = self.ssc.queueStream(input) + + def failed_func(i): + if i == 1: + # Make it fail in the second batch + raise ValueError("This is a special error") + else: + return i + + # We should be able to see the results of the 3rd and 4th batches even if the second batch + # fails + expected = [[0], [2], [3]] + self.assertEqual(expected, self._collect(input_stream.map(failed_func), 3)) + try: + self.ssc.awaitTerminationOrTimeout(10) + except: + import traceback + failure = traceback.format_exc() + self.assertTrue("This is a special error" in failure) + return + + self.fail("a failed func should throw an error") + + +class StreamingListenerTests(PySparkStreamingTestCase): + + duration = .5 + + class BatchInfoCollector(StreamingListener): + + def __init__(self): + super(StreamingListener, self).__init__() + self.batchInfosCompleted = [] + self.batchInfosStarted = [] + self.batchInfosSubmitted = [] + + def onBatchSubmitted(self, batchSubmitted): + self.batchInfosSubmitted.append(batchSubmitted.batchInfo()) + + def onBatchStarted(self, batchStarted): + self.batchInfosStarted.append(batchStarted.batchInfo()) + + def onBatchCompleted(self, batchCompleted): + self.batchInfosCompleted.append(batchCompleted.batchInfo()) + + def test_batch_info_reports(self): + batch_collector = self.BatchInfoCollector() + self.ssc.addStreamingListener(batch_collector) + input = [[1], [2], [3], [4]] + + def func(dstream): + return dstream.map(int) + expected = [[1], [2], [3], [4]] + self._test_func(input, func, expected) + + batchInfosSubmitted = batch_collector.batchInfosSubmitted + batchInfosStarted = batch_collector.batchInfosStarted + batchInfosCompleted = batch_collector.batchInfosCompleted + + self.wait_for(batchInfosCompleted, 4) + + self.assertGreaterEqual(len(batchInfosSubmitted), 4) + for info in batchInfosSubmitted: + self.assertGreaterEqual(info.batchTime().milliseconds(), 0) + self.assertGreaterEqual(info.submissionTime(), 0) + + for streamId in info.streamIdToInputInfo(): + streamInputInfo = info.streamIdToInputInfo()[streamId] + self.assertGreaterEqual(streamInputInfo.inputStreamId(), 0) + self.assertGreaterEqual(streamInputInfo.numRecords, 0) + for key in streamInputInfo.metadata(): + self.assertIsNotNone(streamInputInfo.metadata()[key]) + self.assertIsNotNone(streamInputInfo.metadataDescription()) + + for outputOpId in info.outputOperationInfos(): + outputInfo = info.outputOperationInfos()[outputOpId] + self.assertGreaterEqual(outputInfo.batchTime().milliseconds(), 0) + self.assertGreaterEqual(outputInfo.id(), 0) + self.assertIsNotNone(outputInfo.name()) + self.assertIsNotNone(outputInfo.description()) + self.assertGreaterEqual(outputInfo.startTime(), -1) + self.assertGreaterEqual(outputInfo.endTime(), -1) + self.assertIsNone(outputInfo.failureReason()) + + self.assertEqual(info.schedulingDelay(), -1) + self.assertEqual(info.processingDelay(), -1) + self.assertEqual(info.totalDelay(), -1) + self.assertEqual(info.numRecords(), 0) + + self.assertGreaterEqual(len(batchInfosStarted), 4) + for info in batchInfosStarted: + self.assertGreaterEqual(info.batchTime().milliseconds(), 0) + self.assertGreaterEqual(info.submissionTime(), 0) + + for streamId in info.streamIdToInputInfo(): + streamInputInfo = info.streamIdToInputInfo()[streamId] + 
self.assertGreaterEqual(streamInputInfo.inputStreamId(), 0) + self.assertGreaterEqual(streamInputInfo.numRecords, 0) + for key in streamInputInfo.metadata(): + self.assertIsNotNone(streamInputInfo.metadata()[key]) + self.assertIsNotNone(streamInputInfo.metadataDescription()) + + for outputOpId in info.outputOperationInfos(): + outputInfo = info.outputOperationInfos()[outputOpId] + self.assertGreaterEqual(outputInfo.batchTime().milliseconds(), 0) + self.assertGreaterEqual(outputInfo.id(), 0) + self.assertIsNotNone(outputInfo.name()) + self.assertIsNotNone(outputInfo.description()) + self.assertGreaterEqual(outputInfo.startTime(), -1) + self.assertGreaterEqual(outputInfo.endTime(), -1) + self.assertIsNone(outputInfo.failureReason()) + + self.assertGreaterEqual(info.schedulingDelay(), 0) + self.assertEqual(info.processingDelay(), -1) + self.assertEqual(info.totalDelay(), -1) + self.assertEqual(info.numRecords(), 0) + + self.assertGreaterEqual(len(batchInfosCompleted), 4) + for info in batchInfosCompleted: + self.assertGreaterEqual(info.batchTime().milliseconds(), 0) + self.assertGreaterEqual(info.submissionTime(), 0) + + for streamId in info.streamIdToInputInfo(): + streamInputInfo = info.streamIdToInputInfo()[streamId] + self.assertGreaterEqual(streamInputInfo.inputStreamId(), 0) + self.assertGreaterEqual(streamInputInfo.numRecords, 0) + for key in streamInputInfo.metadata(): + self.assertIsNotNone(streamInputInfo.metadata()[key]) + self.assertIsNotNone(streamInputInfo.metadataDescription()) + + for outputOpId in info.outputOperationInfos(): + outputInfo = info.outputOperationInfos()[outputOpId] + self.assertGreaterEqual(outputInfo.batchTime().milliseconds(), 0) + self.assertGreaterEqual(outputInfo.id(), 0) + self.assertIsNotNone(outputInfo.name()) + self.assertIsNotNone(outputInfo.description()) + self.assertGreaterEqual(outputInfo.startTime(), 0) + self.assertGreaterEqual(outputInfo.endTime(), 0) + self.assertIsNone(outputInfo.failureReason()) + + self.assertGreaterEqual(info.schedulingDelay(), 0) + self.assertGreaterEqual(info.processingDelay(), 0) + self.assertGreaterEqual(info.totalDelay(), 0) + self.assertEqual(info.numRecords(), 0) + class WindowFunctionTests(PySparkStreamingTestCase): @@ -437,10 +670,22 @@ def test_reduce_by_invalid_window(self): self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 0.1, 0.1)) self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 1, 0.1)) + def test_reduce_by_key_and_window_with_none_invFunc(self): + input = [range(1), range(2), range(3), range(4), range(5), range(6)] + + def func(dstream): + return dstream.map(lambda x: (x, 1))\ + .reduceByKeyAndWindow(operator.add, None, 5, 1)\ + .filter(lambda kv: kv[1] > 0).count() + + expected = [[2], [4], [6], [6], [6], [6]] + self._test_func(input, func, expected) + class StreamingContextTests(PySparkStreamingTestCase): duration = 0.1 + setupCalled = False def _add_input_stream(self): inputs = [range(1, x) for x in range(101)] @@ -514,10 +759,128 @@ def func(rdds): self.assertEqual([2, 3, 1], self._take(dstream, 3)) + def test_get_active(self): + self.assertEqual(StreamingContext.getActive(), None) + + # Verify that getActive() returns the active context + self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) + self.ssc.start() + self.assertEqual(StreamingContext.getActive(), self.ssc) + + # Verify that getActive() returns None + self.ssc.stop(False) + self.assertEqual(StreamingContext.getActive(), None) + + # Verify that if the Java context is stopped, then 
getActive() returns None + self.ssc = StreamingContext(self.sc, self.duration) + self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) + self.ssc.start() + self.assertEqual(StreamingContext.getActive(), self.ssc) + self.ssc._jssc.stop(False) + self.assertEqual(StreamingContext.getActive(), None) + + def test_get_active_or_create(self): + # Test StreamingContext.getActiveOrCreate() without checkpoint data + # See CheckpointTests for tests with checkpoint data + self.ssc = None + self.assertEqual(StreamingContext.getActive(), None) + + def setupFunc(): + ssc = StreamingContext(self.sc, self.duration) + ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) + self.setupCalled = True + return ssc + + # Verify that getActiveOrCreate() (w/o checkpoint) calls setupFunc when no context is active + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc) + self.assertTrue(self.setupCalled) + + # Verify that getActiveOrCreate() retuns active context and does not call the setupFunc + self.ssc.start() + self.setupCalled = False + self.assertEqual(StreamingContext.getActiveOrCreate(None, setupFunc), self.ssc) + self.assertFalse(self.setupCalled) + + # Verify that getActiveOrCreate() calls setupFunc after active context is stopped + self.ssc.stop(False) + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc) + self.assertTrue(self.setupCalled) + + # Verify that if the Java context is stopped, then getActive() returns None + self.ssc = StreamingContext(self.sc, self.duration) + self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) + self.ssc.start() + self.assertEqual(StreamingContext.getActive(), self.ssc) + self.ssc._jssc.stop(False) + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc) + self.assertTrue(self.setupCalled) + + def test_await_termination_or_timeout(self): + self._add_input_stream() + self.ssc.start() + self.assertFalse(self.ssc.awaitTerminationOrTimeout(0.001)) + self.ssc.stop(False) + self.assertTrue(self.ssc.awaitTerminationOrTimeout(0.001)) + class CheckpointTests(unittest.TestCase): - def test_get_or_create(self): + setupCalled = False + + @staticmethod + def tearDownClass(): + # Clean up in the JVM just in case there has been some issues in Python API + if SparkContext._jvm is not None: + jStreamingContextOption = \ + SparkContext._jvm.org.apache.spark.streaming.StreamingContext.getActive() + if jStreamingContextOption.nonEmpty(): + jStreamingContextOption.get().stop() + + def setUp(self): + self.ssc = None + self.sc = None + self.cpd = None + + def tearDown(self): + if self.ssc is not None: + self.ssc.stop(True) + if self.sc is not None: + self.sc.stop() + if self.cpd is not None: + shutil.rmtree(self.cpd) + + def test_transform_function_serializer_failure(self): + inputd = tempfile.mkdtemp() + self.cpd = tempfile.mkdtemp("test_transform_function_serializer_failure") + + def setup(): + conf = SparkConf().set("spark.default.parallelism", 1) + sc = SparkContext(conf=conf) + ssc = StreamingContext(sc, 0.5) + + # A function that cannot be serialized + def process(time, rdd): + sc.parallelize(range(1, 10)) + + ssc.textFileStream(inputd).foreachRDD(process) + return ssc + + self.ssc = StreamingContext.getOrCreate(self.cpd, setup) + try: + self.ssc.start() + except: + import traceback + failure = traceback.format_exc() + self.assertTrue( + "It appears that you are attempting to reference SparkContext" in failure) + return + + self.fail("using SparkContext in 
process should fail because it's not Serializable") + + def test_get_or_create_and_get_active_or_create(self): inputd = tempfile.mkdtemp() outputd = tempfile.mkdtemp() + "/" @@ -532,11 +895,16 @@ def setup(): wc = dstream.updateStateByKey(updater) wc.map(lambda x: "%s,%d" % x).saveAsTextFiles(outputd + "test") wc.checkpoint(.5) + self.setupCalled = True return ssc - cpd = tempfile.mkdtemp("test_streaming_cps") - ssc = StreamingContext.getOrCreate(cpd, setup) - ssc.start() + # Verify that getOrCreate() calls setup() in absence of checkpoint files + self.cpd = tempfile.mkdtemp("test_streaming_cps") + self.setupCalled = False + self.ssc = StreamingContext.getOrCreate(self.cpd, setup) + self.assertTrue(self.setupCalled) + + self.ssc.start() def check_output(n): while not os.listdir(outputd): @@ -551,7 +919,7 @@ def check_output(n): # not finished time.sleep(0.01) continue - ordd = ssc.sparkContext.textFile(p).map(lambda line: line.split(",")) + ordd = self.ssc.sparkContext.textFile(p).map(lambda line: line.split(",")) d = ordd.values().map(int).collect() if not d: time.sleep(0.01) @@ -567,13 +935,58 @@ def check_output(n): check_output(1) check_output(2) - ssc.stop(True, True) + # Verify the getOrCreate() recovers from checkpoint files + self.ssc.stop(True, True) time.sleep(1) - ssc = StreamingContext.getOrCreate(cpd, setup) - ssc.start() + self.setupCalled = False + self.ssc = StreamingContext.getOrCreate(self.cpd, setup) + self.assertFalse(self.setupCalled) + self.ssc.start() check_output(3) - ssc.stop(True, True) + + # Verify that getOrCreate() uses existing SparkContext + self.ssc.stop(True, True) + time.sleep(1) + self.sc = SparkContext(conf=SparkConf()) + self.setupCalled = False + self.ssc = StreamingContext.getOrCreate(self.cpd, setup) + self.assertFalse(self.setupCalled) + self.assertTrue(self.ssc.sparkContext == self.sc) + + # Verify the getActiveOrCreate() recovers from checkpoint files + self.ssc.stop(True, True) + time.sleep(1) + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(self.cpd, setup) + self.assertFalse(self.setupCalled) + self.ssc.start() + check_output(4) + + # Verify that getActiveOrCreate() returns active context + self.setupCalled = False + self.assertEqual(StreamingContext.getActiveOrCreate(self.cpd, setup), self.ssc) + self.assertFalse(self.setupCalled) + + # Verify that getActiveOrCreate() uses existing SparkContext + self.ssc.stop(True, True) + time.sleep(1) + self.sc = SparkContext(conf=SparkConf()) + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(self.cpd, setup) + self.assertFalse(self.setupCalled) + self.assertTrue(self.ssc.sparkContext == self.sc) + + # Verify that getActiveOrCreate() calls setup() in absence of checkpoint files + self.ssc.stop(True, True) + shutil.rmtree(self.cpd) # delete checkpoint directory + time.sleep(1) + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(self.cpd, setup) + self.assertTrue(self.setupCalled) + + # Stop everything + self.ssc.stop(True, True) class KafkaStreamTests(PySparkStreamingTestCase): @@ -738,12 +1151,108 @@ def transformWithOffsetRanges(rdd): offsetRanges.append(o) return rdd - stream.transform(transformWithOffsetRanges).foreachRDD(lambda rdd: rdd.count()) + # Test whether it is ok mixing KafkaTransformedDStream and TransformedDStream together, + # only the TransformedDstreams can be folded together. 
+ stream.transform(transformWithOffsetRanges).map(lambda kv: kv[1]).count().pprint() self.ssc.start() self.wait_for(offsetRanges, 1) self.assertEqual(offsetRanges, [OffsetRange(topic, 0, long(0), long(6))]) + def test_topic_and_partition_equality(self): + topic_and_partition_a = TopicAndPartition("foo", 0) + topic_and_partition_b = TopicAndPartition("foo", 0) + topic_and_partition_c = TopicAndPartition("bar", 0) + topic_and_partition_d = TopicAndPartition("foo", 1) + + self.assertEqual(topic_and_partition_a, topic_and_partition_b) + self.assertNotEqual(topic_and_partition_a, topic_and_partition_c) + self.assertNotEqual(topic_and_partition_a, topic_and_partition_d) + + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_direct_stream_transform_with_checkpoint(self): + """Test the Python direct Kafka stream transform with checkpoint correctly recovered.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2, "c": 3} + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(), + "auto.offset.reset": "smallest"} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + offsetRanges = [] + + def transformWithOffsetRanges(rdd): + for o in rdd.offsetRanges(): + offsetRanges.append(o) + return rdd + + self.ssc.stop(False) + self.ssc = None + tmpdir = "checkpoint-test-%d" % random.randint(0, 10000) + + def setup(): + ssc = StreamingContext(self.sc, 0.5) + ssc.checkpoint(tmpdir) + stream = KafkaUtils.createDirectStream(ssc, [topic], kafkaParams) + stream.transform(transformWithOffsetRanges).count().pprint() + return ssc + + try: + ssc1 = StreamingContext.getOrCreate(tmpdir, setup) + ssc1.start() + self.wait_for(offsetRanges, 1) + self.assertEqual(offsetRanges, [OffsetRange(topic, 0, long(0), long(6))]) + + # To make sure some checkpoint is written + time.sleep(3) + ssc1.stop(False) + ssc1 = None + + # Restart again to make sure the checkpoint is recovered correctly + ssc2 = StreamingContext.getOrCreate(tmpdir, setup) + ssc2.start() + ssc2.awaitTermination(3) + ssc2.stop(stopSparkContext=False, stopGraceFully=True) + ssc2 = None + finally: + shutil.rmtree(tmpdir) + + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_rdd_message_handler(self): + """Test Python direct Kafka RDD MessageHandler.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 1, "c": 2} + offsetRanges = [OffsetRange(topic, 0, long(0), long(sum(sendData.values())))] + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress()} + + def getKeyAndDoubleMessage(m): + return m and (m.key, m.message * 2) + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges, + messageHandler=getKeyAndDoubleMessage) + self._validateRddResult({"aa": 1, "bb": 1, "cc": 2}, rdd) + + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_direct_stream_message_handler(self): + """Test the Python direct Kafka stream MessageHandler.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2, "c": 3} + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(), + "auto.offset.reset": "smallest"} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + def getKeyAndDoubleMessage(m): + return m and (m.key, m.message * 2) + + stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams, + 
messageHandler=getKeyAndDoubleMessage) + self._validateStreamResult({"aa": 1, "bb": 2, "cc": 3}, stream) + class FlumeStreamTests(PySparkStreamingTestCase): timeout = 20 # seconds @@ -893,6 +1402,68 @@ def test_flume_polling_multiple_hosts(self): self._testMultipleTimes(self._testFlumePollingMultipleHosts) +class MQTTStreamTests(PySparkStreamingTestCase): + timeout = 20 # seconds + duration = 1 + + def setUp(self): + super(MQTTStreamTests, self).setUp() + + MQTTTestUtilsClz = self.ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.mqtt.MQTTTestUtils") + self._MQTTTestUtils = MQTTTestUtilsClz.newInstance() + self._MQTTTestUtils.setup() + + def tearDown(self): + if self._MQTTTestUtils is not None: + self._MQTTTestUtils.teardown() + self._MQTTTestUtils = None + + super(MQTTStreamTests, self).tearDown() + + def _randomTopic(self): + return "topic-%d" % random.randint(0, 10000) + + def _startContext(self, topic): + # Start the StreamingContext and also collect the result + stream = MQTTUtils.createStream(self.ssc, "tcp://" + self._MQTTTestUtils.brokerUri(), topic) + result = [] + + def getOutput(_, rdd): + for data in rdd.collect(): + result.append(data) + + stream.foreachRDD(getOutput) + self.ssc.start() + return result + + def test_mqtt_stream(self): + """Test the Python MQTT stream API.""" + sendData = "MQTT demo for spark streaming" + topic = self._randomTopic() + result = self._startContext(topic) + + def retry(): + self._MQTTTestUtils.publishData(topic, sendData) + # Because "publishData" sends duplicate messages, here we should use > 0 + self.assertTrue(len(result) > 0) + self.assertEqual(sendData, result[0]) + + # Retry it because we don't know when the receiver will start. + self._retry_or_timeout(retry) + + def _retry_or_timeout(self, test_func): + start_time = time.time() + while True: + try: + test_func() + break + except: + if time.time() - start_time > self.timeout: + raise + time.sleep(0.01) + + class KinesisStreamTests(PySparkStreamingTestCase): def test_kinesis_stream_api(self): @@ -908,8 +1479,10 @@ def test_kinesis_stream_api(self): "awsAccessKey", "awsSecretKey") def test_kinesis_stream(self): - if os.environ.get('ENABLE_KINESIS_TESTS') != '1': - print("Skip test_kinesis_stream") + if not are_kinesis_tests_enabled: + sys.stderr.write( + "Skipped test_kinesis_stream (enable by setting environment variable %s=1" + % kinesis_test_environ_var) return import random @@ -950,24 +1523,34 @@ def get_output(_, rdd): traceback.print_exc() raise finally: + self.ssc.stop(False) kinesisTestUtils.deleteStream() kinesisTestUtils.deleteDynamoDBTable(kinesisAppName) +# Search jar in the project dir using the jar name_prefix for both sbt build and maven build because +# the artifact jars are in different directories. 
+def search_jar(dir, name_prefix): + # We should ignore the following jars + ignored_jar_suffixes = ("javadoc.jar", "sources.jar", "test-sources.jar", "tests.jar") + jars = (glob.glob(os.path.join(dir, "target/scala-*/" + name_prefix + "-*.jar")) + # sbt build + glob.glob(os.path.join(dir, "target/" + name_prefix + "_*.jar"))) # maven build + return [jar for jar in jars if not jar.endswith(ignored_jar_suffixes)] + + def search_kafka_assembly_jar(): SPARK_HOME = os.environ["SPARK_HOME"] kafka_assembly_dir = os.path.join(SPARK_HOME, "external/kafka-assembly") - jars = glob.glob( - os.path.join(kafka_assembly_dir, "target/scala-*/spark-streaming-kafka-assembly-*.jar")) + jars = search_jar(kafka_assembly_dir, "spark-streaming-kafka-assembly") if not jars: raise Exception( ("Failed to find Spark Streaming kafka assembly jar in %s. " % kafka_assembly_dir) + "You need to build Spark with " "'build/sbt assembly/assembly streaming-kafka-assembly/assembly' or " - "'build/mvn package' before running this test") + "'build/mvn package' before running this test.") elif len(jars) > 1: - raise Exception(("Found multiple Spark Streaming Kafka assembly JARs in %s; please " - "remove all but one") % kafka_assembly_dir) + raise Exception(("Found multiple Spark Streaming Kafka assembly JARs: %s; please " + "remove all but one") % (", ".join(jars))) else: return jars[0] @@ -975,45 +1558,119 @@ def search_kafka_assembly_jar(): def search_flume_assembly_jar(): SPARK_HOME = os.environ["SPARK_HOME"] flume_assembly_dir = os.path.join(SPARK_HOME, "external/flume-assembly") - jars = glob.glob( - os.path.join(flume_assembly_dir, "target/scala-*/spark-streaming-flume-assembly-*.jar")) + jars = search_jar(flume_assembly_dir, "spark-streaming-flume-assembly") if not jars: raise Exception( ("Failed to find Spark Streaming Flume assembly jar in %s. " % flume_assembly_dir) + "You need to build Spark with " "'build/sbt assembly/assembly streaming-flume-assembly/assembly' or " + "'build/mvn package' before running this test.") + elif len(jars) > 1: + raise Exception(("Found multiple Spark Streaming Flume assembly JARs: %s; please " + "remove all but one") % (", ".join(jars))) + else: + return jars[0] + + +def search_mqtt_assembly_jar(): + SPARK_HOME = os.environ["SPARK_HOME"] + mqtt_assembly_dir = os.path.join(SPARK_HOME, "external/mqtt-assembly") + jars = search_jar(mqtt_assembly_dir, "spark-streaming-mqtt-assembly") + if not jars: + raise Exception( + ("Failed to find Spark Streaming MQTT assembly jar in %s. " % mqtt_assembly_dir) + + "You need to build Spark with " + "'build/sbt assembly/assembly streaming-mqtt-assembly/assembly' or " "'build/mvn package' before running this test") elif len(jars) > 1: - raise Exception(("Found multiple Spark Streaming Flume assembly JARs in %s; please " - "remove all but one") % flume_assembly_dir) + raise Exception(("Found multiple Spark Streaming MQTT assembly JARs: %s; please " + "remove all but one") % (", ".join(jars))) else: return jars[0] -def search_kinesis_asl_assembly_jar(): +def search_mqtt_test_jar(): SPARK_HOME = os.environ["SPARK_HOME"] - kinesis_asl_assembly_dir = os.path.join(SPARK_HOME, "extras/kinesis-asl-assembly") + mqtt_test_dir = os.path.join(SPARK_HOME, "external/mqtt") jars = glob.glob( - os.path.join(kinesis_asl_assembly_dir, - "target/scala-*/spark-streaming-kinesis-asl-assembly-*.jar")) + os.path.join(mqtt_test_dir, "target/scala-*/spark-streaming-mqtt-test-*.jar")) if not jars: raise Exception( - ("Failed to find Spark Streaming Kinesis ASL assembly jar in %s. 
" % - kinesis_asl_assembly_dir) + "You need to build Spark with " - "'build/sbt -Pkinesis-asl assembly/assembly streaming-kinesis-asl-assembly/assembly' " - "or 'build/mvn -Pkinesis-asl package' before running this test") + ("Failed to find Spark Streaming MQTT test jar in %s. " % mqtt_test_dir) + + "You need to build Spark with " + "'build/sbt assembly/assembly streaming-mqtt/test:assembly'") elif len(jars) > 1: - raise Exception(("Found multiple Spark Streaming Kinesis ASL assembly JARs in %s; please " - "remove all but one") % kinesis_asl_assembly_dir) + raise Exception(("Found multiple Spark Streaming MQTT test JARs: %s; please " + "remove all but one") % (", ".join(jars))) else: return jars[0] +def search_kinesis_asl_assembly_jar(): + SPARK_HOME = os.environ["SPARK_HOME"] + kinesis_asl_assembly_dir = os.path.join(SPARK_HOME, "extras/kinesis-asl-assembly") + jars = search_jar(kinesis_asl_assembly_dir, "spark-streaming-kinesis-asl-assembly") + if not jars: + return None + elif len(jars) > 1: + raise Exception(("Found multiple Spark Streaming Kinesis ASL assembly JARs: %s; please " + "remove all but one") % (", ".join(jars))) + else: + return jars[0] + + +# Must be same as the variable and condition defined in KinesisTestUtils.scala +kinesis_test_environ_var = "ENABLE_KINESIS_TESTS" +are_kinesis_tests_enabled = os.environ.get(kinesis_test_environ_var) == '1' + if __name__ == "__main__": kafka_assembly_jar = search_kafka_assembly_jar() flume_assembly_jar = search_flume_assembly_jar() + mqtt_assembly_jar = search_mqtt_assembly_jar() + mqtt_test_jar = search_mqtt_test_jar() kinesis_asl_assembly_jar = search_kinesis_asl_assembly_jar() - jars = "%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, kinesis_asl_assembly_jar) + + if kinesis_asl_assembly_jar is None: + kinesis_jar_present = False + jars = "%s,%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, mqtt_assembly_jar, + mqtt_test_jar) + else: + kinesis_jar_present = True + jars = "%s,%s,%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, mqtt_assembly_jar, + mqtt_test_jar, kinesis_asl_assembly_jar) os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars - unittest.main() + testcases = [BasicOperationTests, WindowFunctionTests, StreamingContextTests, CheckpointTests, + KafkaStreamTests, FlumeStreamTests, FlumePollingStreamTests, MQTTStreamTests, + StreamingListenerTests] + + if kinesis_jar_present is True: + testcases.append(KinesisStreamTests) + elif are_kinesis_tests_enabled is False: + sys.stderr.write("Skipping all Kinesis Python tests as the optional Kinesis project was " + "not compiled into a JAR. To run these tests, " + "you need to build Spark with 'build/sbt -Pkinesis-asl assembly/assembly " + "streaming-kinesis-asl-assembly/assembly' or " + "'build/mvn -Pkinesis-asl package' before running this test.") + else: + raise Exception( + ("Failed to find Spark Streaming Kinesis assembly jar in %s. 
" + % kinesis_asl_assembly_dir) + + "You need to build Spark with 'build/sbt -Pkinesis-asl " + "assembly/assembly streaming-kinesis-asl-assembly/assembly'" + "or 'build/mvn -Pkinesis-asl package' before running this test.") + + sys.stderr.write("Running tests: %s \n" % (str(testcases))) + failed = False + for testcase in testcases: + sys.stderr.write("[Running %s]\n" % (testcase)) + tests = unittest.TestLoader().loadTestsFromTestCase(testcase) + if xmlrunner: + result = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=3).run(tests) + if not result.wasSuccessful(): + failed = True + else: + result = unittest.TextTestRunner(verbosity=3).run(tests) + if not result.wasSuccessful(): + failed = True + sys.exit(failed) diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index b20613b1283b..abbbf6eb9394 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -37,13 +37,16 @@ def __init__(self, ctx, func, *deserializers): self.ctx = ctx self.func = func self.deserializers = deserializers - self._rdd_wrapper = lambda jrdd, ctx, ser: RDD(jrdd, ctx, ser) + self.rdd_wrap_func = lambda jrdd, ctx, ser: RDD(jrdd, ctx, ser) + self.failure = None def rdd_wrapper(self, func): - self._rdd_wrapper = func + self.rdd_wrap_func = func return self def call(self, milliseconds, jrdds): + # Clear the failure + self.failure = None try: if self.ctx is None: self.ctx = SparkContext._active_spark_context @@ -56,14 +59,17 @@ def call(self, milliseconds, jrdds): if len(sers) < len(jrdds): sers += (sers[0],) * (len(jrdds) - len(sers)) - rdds = [self._rdd_wrapper(jrdd, self.ctx, ser) if jrdd else None + rdds = [self.rdd_wrap_func(jrdd, self.ctx, ser) if jrdd else None for jrdd, ser in zip(jrdds, sers)] t = datetime.fromtimestamp(milliseconds / 1000.0) r = self.func(t, *rdds) if r: return r._jrdd - except Exception: - traceback.print_exc() + except: + self.failure = traceback.format_exc() + + def getLastFailure(self): + return self.failure def __repr__(self): return "TransformFunction(%s)" % self.func @@ -88,20 +94,29 @@ def __init__(self, ctx, serializer, gateway=None): self.serializer = serializer self.gateway = gateway or self.ctx._gateway self.gateway.jvm.PythonDStream.registerSerializer(self) + self.failure = None def dumps(self, id): + # Clear the failure + self.failure = None try: func = self.gateway.gateway_property.pool[id] - return bytearray(self.serializer.dumps((func.func, func.deserializers))) - except Exception: - traceback.print_exc() + return bytearray(self.serializer.dumps(( + func.func, func.rdd_wrap_func, func.deserializers))) + except: + self.failure = traceback.format_exc() def loads(self, data): + # Clear the failure + self.failure = None try: - f, deserializers = self.serializer.loads(bytes(data)) - return TransformFunction(self.ctx, f, *deserializers) - except Exception: - traceback.print_exc() + f, wrap_func, deserializers = self.serializer.loads(bytes(data)) + return TransformFunction(self.ctx, f, *deserializers).rdd_wrapper(wrap_func) + except: + self.failure = traceback.format_exc() + + def getLastFailure(self): + return self.failure def __repr__(self): return "TransformFunctionSerializer(%s)" % self.serializer diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 8bfed074c905..5bd94476597a 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -35,6 +35,10 @@ import hashlib from py4j.protocol import Py4JJavaError +try: + import xmlrunner +except ImportError: + xmlrunner = None if 
sys.version_info[:2] <= (2, 6): try: @@ -62,7 +66,7 @@ CloudPickleSerializer, CompressedSerializer, UTF8Deserializer, NoOpSerializer, \ PairDeserializer, CartesianDeserializer, AutoBatchedSerializer, AutoSerializer, \ FlattenedValuesSerializer -from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter +from pyspark.shuffle import Aggregator, ExternalMerger, ExternalSorter from pyspark import shuffle from pyspark.profiler import BasicProfiler @@ -95,17 +99,6 @@ def setUp(self): lambda x, y: x.append(y) or x, lambda x, y: x.extend(y) or x) - def test_in_memory(self): - m = InMemoryMerger(self.agg) - m.mergeValues(self.data) - self.assertEqual(sum(sum(v) for k, v in m.items()), - sum(xrange(self.N))) - - m = InMemoryMerger(self.agg) - m.mergeCombiners(map(lambda x_y: (x_y[0], [x_y[1]]), self.data)) - self.assertEqual(sum(sum(v) for k, v in m.items()), - sum(xrange(self.N))) - def test_small_dataset(self): m = ExternalMerger(self.agg, 1000) m.mergeValues(self.data) @@ -218,6 +211,11 @@ def test_namedtuple(self): p2 = loads(dumps(p1, 2)) self.assertEqual(p1, p2) + from pyspark.cloudpickle import dumps + P2 = loads(dumps(P)) + p3 = P2(1, 3) + self.assertEqual(p1, p3) + def test_itemgetter(self): from operator import itemgetter ser = CloudPickleSerializer() @@ -255,10 +253,12 @@ def __getattr__(self, item): # Regression test for SPARK-3415 def test_pickling_file_handles(self): - ser = CloudPickleSerializer() - out1 = sys.stderr - out2 = ser.loads(ser.dumps(out1)) - self.assertEqual(out1, out2) + # to be corrected with SPARK-11160 + if not xmlrunner: + ser = CloudPickleSerializer() + out1 = sys.stderr + out2 = ser.loads(ser.dumps(out1)) + self.assertEqual(out1, out2) def test_func_globals(self): @@ -1889,6 +1889,10 @@ def test_failed_sparkcontext_creation(self): # Regression test for SPARK-1550 self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name")) + def test_get_or_create(self): + with SparkContext.getOrCreate() as sc: + self.assertTrue(SparkContext.getOrCreate() is sc) + def test_stop(self): sc = SparkContext() self.assertNotEqual(SparkContext._active_spark_context, None) @@ -1982,13 +1986,36 @@ def test_statcounter_array(self): self.assertSequenceEqual([3.0, 3.0], s.max().tolist()) self.assertSequenceEqual([1.0, 1.0], s.sampleStdev().tolist()) + stats_dict = s.asDict() + self.assertEqual(3, stats_dict['count']) + self.assertSequenceEqual([2.0, 2.0], stats_dict['mean'].tolist()) + self.assertSequenceEqual([1.0, 1.0], stats_dict['min'].tolist()) + self.assertSequenceEqual([3.0, 3.0], stats_dict['max'].tolist()) + self.assertSequenceEqual([6.0, 6.0], stats_dict['sum'].tolist()) + self.assertSequenceEqual([1.0, 1.0], stats_dict['stdev'].tolist()) + self.assertSequenceEqual([1.0, 1.0], stats_dict['variance'].tolist()) + + stats_sample_dict = s.asDict(sample=True) + self.assertEqual(3, stats_dict['count']) + self.assertSequenceEqual([2.0, 2.0], stats_sample_dict['mean'].tolist()) + self.assertSequenceEqual([1.0, 1.0], stats_sample_dict['min'].tolist()) + self.assertSequenceEqual([3.0, 3.0], stats_sample_dict['max'].tolist()) + self.assertSequenceEqual([6.0, 6.0], stats_sample_dict['sum'].tolist()) + self.assertSequenceEqual( + [0.816496580927726, 0.816496580927726], stats_sample_dict['stdev'].tolist()) + self.assertSequenceEqual( + [0.6666666666666666, 0.6666666666666666], stats_sample_dict['variance'].tolist()) + if __name__ == "__main__": if not _have_scipy: print("NOTE: Skipping SciPy tests as it does not seem to be installed") if not _have_numpy: 
print("NOTE: Skipping NumPy tests as it does not seem to be installed") - unittest.main() + if xmlrunner: + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) + else: + unittest.main() if not _have_scipy: print("NOTE: SciPy tests were skipped as it does not seem to be installed") if not _have_numpy: diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 93df9002be37..42c2f8b75933 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -146,5 +146,5 @@ def process(): java_port = int(sys.stdin.readline()) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(("127.0.0.1", java_port)) - sock_file = sock.makefile("a+", 65536) + sock_file = sock.makefile("rwb", 65536) main(sock_file, sock_file) diff --git a/python/run-tests.py b/python/run-tests.py index cc560779373b..ee73eb1506ca 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -31,23 +31,6 @@ import Queue else: import queue as Queue -if sys.version_info >= (2, 7): - subprocess_check_output = subprocess.check_output -else: - # SPARK-8763 - # backported from subprocess module in Python 2.7 - def subprocess_check_output(*popenargs, **kwargs): - if 'stdout' in kwargs: - raise ValueError('stdout argument not allowed, it will be overridden.') - process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs) - output, unused_err = process.communicate() - retcode = process.poll() - if retcode: - cmd = kwargs.get("args") - if cmd is None: - cmd = popenargs[0] - raise subprocess.CalledProcessError(retcode, cmd, output=output) - return output # Append `SPARK_HOME/dev` to the Python path so that we can import the sparktestsupport module @@ -55,7 +38,7 @@ def subprocess_check_output(*popenargs, **kwargs): from sparktestsupport import SPARK_HOME # noqa (suppress pep8 warnings) -from sparktestsupport.shellutils import which # noqa +from sparktestsupport.shellutils import which, subprocess_check_output # noqa from sparktestsupport.modules import all_modules # noqa @@ -73,7 +56,8 @@ def print_red(text): def run_individual_python_test(test_name, pyspark_python): env = dict(os.environ) - env.update({'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python)}) + env.update({'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python), + 'PYSPARK_DRIVER_PYTHON': which(pyspark_python)}) LOGGER.debug("Starting test(%s): %s", pyspark_python, test_name) start_time = time.time() try: @@ -158,7 +142,7 @@ def main(): else: log_level = logging.INFO logging.basicConfig(stream=sys.stdout, level=log_level, format="%(message)s") - LOGGER.info("Running PySpark tests. Output is in python/%s", LOG_FILE) + LOGGER.info("Running PySpark tests. Output is in %s", LOG_FILE) if os.path.exists(LOG_FILE): os.remove(LOG_FILE) python_execs = opts.python_executables.split(',') @@ -167,7 +151,8 @@ def main(): if module_name in python_modules: modules_to_test.append(python_modules[module_name]) else: - print("Error: unrecognized module %s" % module_name) + print("Error: unrecognized module '%s'. 
Supported modules: %s" % + (module_name, ", ".join(python_modules))) sys.exit(-1) LOGGER.info("Will test against the following Python executables: %s", python_execs) LOGGER.info("Will test the following Python modules: %s", [x.name for x in modules_to_test]) diff --git a/python/test_support/sql/orc_partitioned/._SUCCESS.crc b/python/test_support/sql/orc_partitioned/._SUCCESS.crc deleted file mode 100644 index 3b7b044936a8..000000000000 Binary files a/python/test_support/sql/orc_partitioned/._SUCCESS.crc and /dev/null differ diff --git a/python/test_support/sql/people1.json b/python/test_support/sql/people1.json new file mode 100644 index 000000000000..6d217da77d15 --- /dev/null +++ b/python/test_support/sql/people1.json @@ -0,0 +1,2 @@ +{"name":"Jonathan", "aka": "John"} + diff --git a/python/test_support/sql/text-test.txt b/python/test_support/sql/text-test.txt new file mode 100644 index 000000000000..ae1e76c9e93a --- /dev/null +++ b/python/test_support/sql/text-test.txt @@ -0,0 +1,2 @@ +hello +this \ No newline at end of file diff --git a/repl/pom.xml b/repl/pom.xml index a5a0f1fc2c85..154c99d23c7f 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml @@ -91,6 +91,14 @@ mockito-core test + + org.apache.spark + spark-test-tags_${scala.binary.version} + + + org.apache.xbean + xbean-asm5-shaded + diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 8130868fe148..22749c460934 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -253,7 +253,7 @@ class SparkILoop( case xs => xs find (_.name == cmd) } } - private var fallbackMode = false + private var fallbackMode = false private def toggleFallbackMode() { val old = fallbackMode @@ -261,9 +261,9 @@ class SparkILoop( System.setProperty("spark.repl.fallback", fallbackMode.toString) echo(s""" |Switched ${if (old) "off" else "on"} fallback mode without restarting. - | If you have defined classes in the repl, it would + | If you have defined classes in the repl, it would |be good to redefine them incase you plan to use them. If you still run - |into issues it would be good to restart the repl and turn on `:fallback` + |into issues it would be good to restart the repl and turn on `:fallback` |mode as first command. """.stripMargin) } @@ -350,7 +350,7 @@ class SparkILoop( shCommand, nullary("silent", "disable/enable automatic printing of results", verbosity), nullary("fallback", """ - |disable/enable advanced repl changes, these fix some issues but may introduce others. + |disable/enable advanced repl changes, these fix some issues but may introduce others. |This mode will be removed once these fixes stablize""".stripMargin, toggleFallbackMode), cmd("type", "[-v] ", "display the type of an expression without evaluating it", typeCommand), nullary("warnings", "show the suppressed warnings from the most recent line which had any", warningsCommand) @@ -981,7 +981,7 @@ class SparkILoop( // which spins off a separate thread, then print the prompt and try // our best to look ready. The interlocking lazy vals tend to // inter-deadlock, so we break the cycle with a single asynchronous - // message to an actor. + // message to an rpcEndpoint. 
if (isAsync) { intp initialize initializedCallback() createAsyncListener() // listens for signal to run postInitialization @@ -1009,8 +1009,13 @@ class SparkILoop( val conf = new SparkConf() .setMaster(getMaster()) .setJars(jars) - .set("spark.repl.class.uri", intp.classServerUri) .setIfMissing("spark.app.name", "Spark shell") + // SparkContext will detect this configuration and register it with the RpcEnv's + // file server, setting spark.repl.class.uri to the actual URI for executors to + // use. This is sort of ugly but since executors are started as part of SparkContext + // initialization in certain cases, there's an initialization order issue that prevents + // this from being set after SparkContext is instantiated. + .set("spark.repl.class.outputDir", intp.outputDir.getAbsolutePath()) if (execUri != null) { conf.set("spark.executor.uri", execUri) } @@ -1025,7 +1030,7 @@ class SparkILoop( val loader = Utils.getContextOrSparkClassLoader try { sqlContext = loader.loadClass(name).getConstructor(classOf[SparkContext]) - .newInstance(sparkContext).asInstanceOf[SQLContext] + .newInstance(sparkContext).asInstanceOf[SQLContext] logInfo("Created sql context (with Hive support)..") } catch { diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala index bd3314d94eed..99e1e1df33fd 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala @@ -123,18 +123,19 @@ private[repl] trait SparkILoopInit { def initializeSpark() { intp.beQuietDuring { command(""" - @transient val sc = { - val _sc = org.apache.spark.repl.Main.interp.createSparkContext() - println("Spark context available as sc.") - _sc - } + @transient val sc = { + val _sc = org.apache.spark.repl.Main.interp.createSparkContext() + println("Spark context available as sc " + + s"(master = ${_sc.master}, app id = ${_sc.applicationId}).") + _sc + } """) command(""" - @transient val sqlContext = { - val _sqlContext = org.apache.spark.repl.Main.interp.createSQLContext() - println("SQL context available as sqlContext.") - _sqlContext - } + @transient val sqlContext = { + val _sqlContext = org.apache.spark.repl.Main.interp.createSQLContext() + println("SQL context available as sqlContext.") + _sqlContext + } """) command("import org.apache.spark.SparkContext._") command("import sqlContext.implicits._") diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 4ee605fd7f11..7fcb423575d3 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -37,7 +37,7 @@ import scala.reflect.{ ClassTag, classTag } import scala.tools.reflect.StdRuntimeTags._ import scala.util.control.ControlThrowable -import org.apache.spark.{Logging, HttpServer, SecurityManager, SparkConf} +import org.apache.spark.{Logging, SparkConf, SparkContext} import org.apache.spark.util.Utils import org.apache.spark.annotation.DeveloperApi @@ -96,10 +96,9 @@ import org.apache.spark.annotation.DeveloperApi private val SPARK_DEBUG_REPL: Boolean = (System.getenv("SPARK_DEBUG_REPL") == "1") /** Local directory to save .class files too */ - private lazy val outputDir = { - val tmp = System.getProperty("java.io.tmpdir") - val rootDir = conf.get("spark.repl.classdir", tmp) - 
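[Editor's note] The comment above describes the new class-distribution path: instead of running a dedicated HTTP class server, the REPL only records its compile output directory in spark.repl.class.outputDir and relies on SparkContext to publish it through the RpcEnv file server as spark.repl.class.uri. A minimal, hypothetical driver-side sketch in Scala; the property names come from this patch, while the output path and the getOption check are illustrative only:

    // Sketch only: set the output dir before the context exists, then (assuming the
    // mechanism described in the comment above) read back the URI SparkContext publishes.
    import org.apache.spark.{SparkConf, SparkContext}

    val conf = new SparkConf()
      .setMaster("local[*]")
      .setAppName("repl-class-server-sketch")
      .set("spark.repl.class.outputDir", "/tmp/repl-classes")   // hypothetical directory
    val sc = new SparkContext(conf)
    // Expected (per the comment) to be the resolved file-server URI for executors.
    sc.getConf.getOption("spark.repl.class.uri")
      .foreach(uri => println(s"executors would fetch REPL classes from $uri"))
    sc.stop()
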
Utils.createTempDir(rootDir) + private[repl] val outputDir = { + val rootDir = conf.getOption("spark.repl.classdir").getOrElse(Utils.getLocalDir(conf)) + Utils.createTempDir(root = rootDir, namePrefix = "repl") } if (SPARK_DEBUG_REPL) { echo("Output directory: " + outputDir) @@ -114,8 +113,6 @@ import org.apache.spark.annotation.DeveloperApi private val virtualDirectory = new PlainFile(outputDir) // "directory" for classfiles /** Jetty server that will serve our classes to worker nodes */ - private val classServerPort = conf.getInt("spark.replClassServer.port", 0) - private val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf), classServerPort, "HTTP class server") private var currentSettings: Settings = initialSettings private var printResults = true // whether to print result lines private var totalSilence = false // whether to print anything @@ -124,22 +121,6 @@ import org.apache.spark.annotation.DeveloperApi private var bindExceptions = true // whether to bind the lastException variable private var _executionWrapper = "" // code to be wrapped around all lines - - // Start the classServer and store its URI in a spark system property - // (which will be passed to executors so that they can connect to it) - classServer.start() - if (SPARK_DEBUG_REPL) { - echo("Class server started, URI = " + classServer.uri) - } - - /** - * URI of the class server used to feed REPL compiled classes. - * - * @return The string representing the class server uri - */ - @DeveloperApi - def classServerUri = classServer.uri - /** We're going to go to some trouble to initialize the compiler asynchronously. * It's critical that nothing call into it until it's been initialized or we will * run into unrecoverable issues, but the perceived repl startup time goes @@ -994,7 +975,6 @@ import org.apache.spark.annotation.DeveloperApi @DeveloperApi def close() { reporter.flush() - classServer.stop() } /** @@ -1221,10 +1201,16 @@ import org.apache.spark.annotation.DeveloperApi ) } - val preamble = """ - |class %s extends Serializable { - | %s%s%s - """.stripMargin.format(lineRep.readName, envLines.map(" " + _ + ";\n").mkString, importsPreamble, indentCode(toCompute)) + val preamble = s""" + |class ${lineRep.readName} extends Serializable { + | ${envLines.map(" " + _ + ";\n").mkString} + | $importsPreamble + | + | // If we need to construct any objects defined in the REPL on an executor we will need + | // to pass the outer scope to the appropriate encoder. 
+ | org.apache.spark.sql.catalyst.encoders.OuterScopes.addOuterScope(this) + | ${indentCode(toCompute)} + """.stripMargin val postamble = importsTrailer + "\n}" + "\n" + "object " + lineRep.readName + " {\n" + " val INSTANCE = new " + lineRep.readName + "();\n" + diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 5674dcd669be..cbcccb11f14a 100644 --- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -262,6 +262,9 @@ class ReplSuite extends SparkFunSuite { |import sqlContext.implicits._ |case class TestCaseClass(value: Int) |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toDF().collect() + | + |// Test Dataset Serialization in the REPL + |Seq(TestCaseClass(1)).toDS().collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -278,6 +281,27 @@ class ReplSuite extends SparkFunSuite { assertDoesNotContain("java.lang.ClassNotFoundException", output) } + test("Datasets and encoders") { + val output = runInterpreter("local", + """ + |import org.apache.spark.sql.functions._ + |import org.apache.spark.sql.Encoder + |import org.apache.spark.sql.expressions.Aggregator + |import org.apache.spark.sql.TypedColumn + |val simpleSum = new Aggregator[Int, Int, Int] with Serializable { + | def zero: Int = 0 // The initial value. + | def reduce(b: Int, a: Int) = b + a // Add an element to the running total + | def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values. + | def finish(b: Int) = b // Return the final result. + |}.toColumn + | + |val ds = Seq(1, 2, 3, 4).toDS() + |ds.select(simpleSum).collect + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + } + test("SPARK-2632 importing a method from non serializable class and not using it.") { val output = runInterpreter("local", """ @@ -315,6 +339,30 @@ class ReplSuite extends SparkFunSuite { } } + test("Datasets agg type-inference") { + val output = runInterpreter("local", + """ + |import org.apache.spark.sql.functions._ + |import org.apache.spark.sql.Encoder + |import org.apache.spark.sql.expressions.Aggregator + |import org.apache.spark.sql.TypedColumn + |/** An `Aggregator` that adds up any numeric type returned by the given function. 
*/ + |class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializable { + | val numeric = implicitly[Numeric[N]] + | override def zero: N = numeric.zero + | override def reduce(b: N, a: I): N = numeric.plus(b, f(a)) + | override def merge(b1: N,b2: N): N = numeric.plus(b1, b2) + | override def finish(reduction: N): N = reduction + |} + | + |def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn + |val ds = Seq((1, 1, 2L), (1, 2, 3L), (1, 3, 4L), (2, 1, 5L)).toDS() + |ds.groupBy(_._1).agg(sum(_._2), sum(_._3)).collect() + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + } + test("collecting objects of class defined in repl") { val output = runInterpreter("local[2]", """ diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala index be31eb2eda54..44650f25f7a1 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala @@ -28,26 +28,39 @@ import org.apache.spark.sql.SQLContext object Main extends Logging { val conf = new SparkConf() - val tmp = System.getProperty("java.io.tmpdir") - val rootDir = conf.get("spark.repl.classdir", tmp) - val outputDir = Utils.createTempDir(rootDir) + val rootDir = conf.getOption("spark.repl.classdir").getOrElse(Utils.getLocalDir(conf)) + val outputDir = Utils.createTempDir(root = rootDir, namePrefix = "repl") val s = new Settings() s.processArguments(List("-Yrepl-class-based", "-Yrepl-outdir", s"${outputDir.getAbsolutePath}", "-classpath", getAddedJars.mkString(File.pathSeparator)), true) - val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf)) + // the creation of SecurityManager has to be lazy so SPARK_YARN_MODE is set if needed var sparkContext: SparkContext = _ var sqlContext: SQLContext = _ var interp = new SparkILoop // this is a public var because tests reset it. + private var hasErrors = false + + private def scalaOptionError(msg: String): Unit = { + hasErrors = true + Console.err.println(msg) + } + def main(args: Array[String]) { - if (getMaster == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true") - // Start the classServer and store its URI in a spark system property - // (which will be passed to executors so that they can connect to it) - classServer.start() - interp.process(s) // Repl starts and goes in loop of R.E.P.L - classServer.stop() - Option(sparkContext).map(_.stop) + val interpArguments = List( + "-Yrepl-class-based", + "-Yrepl-outdir", s"${outputDir.getAbsolutePath}", + "-classpath", getAddedJars.mkString(File.pathSeparator) + ) ++ args.toList + + val settings = new Settings(scalaOptionError) + settings.processArguments(interpArguments, true) + + if (!hasErrors) { + if (getMaster == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true") + interp.process(settings) // Repl starts and goes in loop of R.E.P.L + Option(sparkContext).map(_.stop) + } } def getAddedJars: Array[String] = { @@ -66,9 +79,13 @@ object Main extends Logging { val conf = new SparkConf() .setMaster(getMaster) .setJars(jars) - .set("spark.repl.class.uri", classServer.uri) .setIfMissing("spark.app.name", "Spark shell") - logInfo("Spark class server started at " + classServer.uri) + // SparkContext will detect this configuration and register it with the RpcEnv's + // file server, setting spark.repl.class.uri to the actual URI for executors to + // use. 
This is sort of ugly but since executors are started as part of SparkContext + // initialization in certain cases, there's an initialization order issue that prevents + // this from being set after SparkContext is instantiated. + .set("spark.repl.class.outputDir", outputDir.getAbsolutePath()) if (execUri != null) { conf.set("spark.executor.uri", execUri) } diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala index bf609ff0f65f..e91139fb29f6 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -37,18 +37,19 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) def initializeSpark() { intp.beQuietDuring { processLine(""" - @transient val sc = { - val _sc = org.apache.spark.repl.Main.createSparkContext() - println("Spark context available as sc.") - _sc - } + @transient val sc = { + val _sc = org.apache.spark.repl.Main.createSparkContext() + println("Spark context available as sc " + + s"(master = ${_sc.master}, app id = ${_sc.applicationId}).") + _sc + } """) processLine(""" - @transient val sqlContext = { - val _sqlContext = org.apache.spark.repl.Main.createSQLContext() - println("SQL context available as sqlContext.") - _sqlContext - } + @transient val sqlContext = { + val _sqlContext = org.apache.spark.repl.Main.createSQLContext() + println("SQL context available as sqlContext.") + _sqlContext + } """) processLine("import org.apache.spark.SparkContext._") processLine("import sqlContext.implicits._") @@ -85,7 +86,7 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) /** Available commands */ override def commands: List[LoopCommand] = sparkStandardCommands - /** + /** * We override `loadFiles` because we need to initialize Spark *before* the REPL * sees any files, so that the Spark context is visible in those files. This is a bit of a * hack, but there isn't another hook available to us at this point. @@ -98,7 +99,7 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) object SparkILoop { - /** + /** * Creates an interpreter loop with default settings and feeds * the given code to it as input. 
*/ @@ -118,5 +119,5 @@ object SparkILoop { } } } - def run(lines: List[String]): String = run(lines map (_ + "\n") mkString) + def run(lines: List[String]): String = run(lines.map(_ + "\n").mkString) } diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala index bf8997998e00..63f3688c9e61 100644 --- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -54,8 +54,7 @@ class ReplSuite extends SparkFunSuite { new SparkILoop(in, new PrintWriter(out)) } org.apache.spark.repl.Main.interp = interp - Main.s.processArguments(List("-classpath", classpath), true) - Main.main(Array()) // call main + Main.main(Array("-classpath", classpath)) // call main org.apache.spark.repl.Main.interp = null if (oldExecutorClasspath != null) { diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index 004941d5f50a..da8f0aa1e336 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -19,25 +19,31 @@ package org.apache.spark.repl import java.io.{IOException, ByteArrayOutputStream, InputStream} import java.net.{HttpURLConnection, URI, URL, URLEncoder} +import java.nio.channels.Channels import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.xbean.asm5._ +import org.apache.xbean.asm5.Opcodes._ import org.apache.spark.{SparkConf, SparkEnv, Logging} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.util.Utils import org.apache.spark.util.ParentClassLoader -import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm._ -import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ - /** * A ClassLoader that reads classes from a Hadoop FileSystem or HTTP URI, * used to load classes defined by the interpreter when the REPL is used. - * Allows the user to specify if user class path should be first + * Allows the user to specify if user class path should be first. + * This class loader delegates getting/finding resources to parent loader, + * which makes sense until REPL never provide resource dynamically. 
*/ -class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader, +class ExecutorClassLoader( + conf: SparkConf, + env: SparkEnv, + classUri: String, + parent: ClassLoader, userClassPathFirst: Boolean) extends ClassLoader with Logging { val uri = new URI(classUri) val directory = uri.getPath @@ -47,13 +53,20 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader // Allows HTTP connect and read timeouts to be controlled for testing / debugging purposes private[repl] var httpUrlConnectionTimeoutMillis: Int = -1 - // Hadoop FileSystem object for our URI, if it isn't using HTTP - var fileSystem: FileSystem = { - if (Set("http", "https", "ftp").contains(uri.getScheme)) { - null - } else { - FileSystem.get(uri, SparkHadoopUtil.get.newConfiguration(conf)) - } + private val fetchFn: (String) => InputStream = uri.getScheme() match { + case "spark" => getClassFileInputStreamFromSparkRPC + case "http" | "https" | "ftp" => getClassFileInputStreamFromHttpServer + case _ => + val fileSystem = FileSystem.get(uri, SparkHadoopUtil.get.newConfiguration(conf)) + getClassFileInputStreamFromFileSystem(fileSystem) + } + + override def getResource(name: String): URL = { + parentLoader.getResource(name) + } + + override def getResources(name: String): java.util.Enumeration[URL] = { + parentLoader.getResources(name) } override def findClass(name: String): Class[_] = { @@ -66,7 +79,13 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader case e: ClassNotFoundException => { val classOption = findClassLocally(name) classOption match { - case None => throw new ClassNotFoundException(name, e) + case None => + // If this class has a cause, it will break the internal assumption of Janino + // (the compiler used for Spark SQL code-gen). + // See org.codehaus.janino.ClassLoaderIClassLoader's findIClass, you will see + // its behavior will be changed if there is a cause and the compilation + // of generated class will fail. 
+ throw new ClassNotFoundException(name) case Some(a) => a } } @@ -75,6 +94,11 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader } } + private def getClassFileInputStreamFromSparkRPC(path: String): InputStream = { + val channel = env.rpcEnv.openChannel(s"$classUri/$path") + Channels.newInputStream(channel) + } + private def getClassFileInputStreamFromHttpServer(pathInDirectory: String): InputStream = { val url = if (SparkEnv.get.securityManager.isAuthenticationEnabled()) { val uri = new URI(classUri + "/" + urlEncode(pathInDirectory)) @@ -111,7 +135,8 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader } } - private def getClassFileInputStreamFromFileSystem(pathInDirectory: String): InputStream = { + private def getClassFileInputStreamFromFileSystem(fileSystem: FileSystem)( + pathInDirectory: String): InputStream = { val path = new Path(directory, pathInDirectory) if (fileSystem.exists(path)) { fileSystem.open(path) @@ -124,13 +149,7 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader val pathInDirectory = name.replace('.', '/') + ".class" var inputStream: InputStream = null try { - inputStream = { - if (fileSystem != null) { - getClassFileInputStreamFromFileSystem(pathInDirectory) - } else { - getClassFileInputStreamFromHttpServer(pathInDirectory) - } - } + inputStream = fetchFn(pathInDirectory) val bytes = readAndTransformClass(name, inputStream) Some(defineClass(name, bytes, 0, bytes.length)) } catch { @@ -192,7 +211,7 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader } class ConstructorCleaner(className: String, cv: ClassVisitor) -extends ClassVisitor(ASM4, cv) { +extends ClassVisitor(ASM5, cv) { override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { val mv = cv.visitMethod(access, name, desc, sig, exceptions) @@ -202,7 +221,7 @@ extends ClassVisitor(ASM4, cv) { // field in the class to point to it, but do nothing otherwise. 
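[Editor's note] For the new "spark" URI scheme, the loader above fetches class bytes through the RpcEnv rather than HTTP or a Hadoop FileSystem. A hedged sketch of that read path, using only calls shown in this patch (RpcEnv.openChannel) plus standard NIO; the helper name and the byte-copy loop are illustrative, not part of the change:

    // Illustrative helper (not in the patch): stream a REPL class file over an RpcEnv channel.
    import java.io.ByteArrayOutputStream
    import java.nio.channels.Channels
    import org.apache.spark.rpc.RpcEnv

    def fetchClassBytes(rpcEnv: RpcEnv, classUri: String, pathInDirectory: String): Array[Byte] = {
      // openChannel returns a ReadableByteChannel backed by the driver's file server.
      val in = Channels.newInputStream(rpcEnv.openChannel(s"$classUri/$pathInDirectory"))
      try {
        val out = new ByteArrayOutputStream()
        val buf = new Array[Byte](8192)
        var n = in.read(buf)
        while (n != -1) { out.write(buf, 0, n); n = in.read(buf) }
        out.toByteArray
      } finally {
        in.close()
      }
    }
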
mv.visitCode() mv.visitVarInsn(ALOAD, 0) // load this - mv.visitMethodInsn(INVOKESPECIAL, "java/lang/Object", "", "()V") + mv.visitMethodInsn(INVOKESPECIAL, "java/lang/Object", "", "()V", false) mv.visitVarInsn(ALOAD, 0) // load this // val classType = className.replace('.', '/') // mv.visitFieldInsn(PUTSTATIC, classType, "MODULE$", "L" + classType + ";") diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala index a58eda12b112..1360f09e7fa1 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala @@ -18,19 +18,29 @@ package org.apache.spark.repl import java.io.File -import java.net.{URL, URLClassLoader} +import java.net.{URI, URL, URLClassLoader} +import java.nio.channels.{FileChannel, ReadableByteChannel} +import java.nio.charset.StandardCharsets +import java.nio.file.{Paths, StandardOpenOption} +import java.util import scala.concurrent.duration._ +import scala.io.Source import scala.language.implicitConversions import scala.language.postfixOps +import com.google.common.io.Files import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Interruptor import org.scalatest.concurrent.Timeouts._ import org.scalatest.mock.MockitoSugar +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.mockito.Matchers.anyString import org.mockito.Mockito._ import org.apache.spark._ +import org.apache.spark.rpc.RpcEnv import org.apache.spark.util.Utils class ExecutorClassLoaderSuite @@ -41,6 +51,7 @@ class ExecutorClassLoaderSuite val childClassNames = List("ReplFakeClass1", "ReplFakeClass2") val parentClassNames = List("ReplFakeClass1", "ReplFakeClass2", "ReplFakeClass3") + val parentResourceNames = List("fake-resource.txt") var tempDir1: File = _ var tempDir2: File = _ var url1: String = _ @@ -54,6 +65,9 @@ class ExecutorClassLoaderSuite url1 = "file://" + tempDir1 urls2 = List(tempDir2.toURI.toURL).toArray childClassNames.foreach(TestUtils.createCompiledClass(_, tempDir1, "1")) + parentResourceNames.foreach { x => + Files.write("resource".getBytes(StandardCharsets.UTF_8), new File(tempDir2, x)) + } parentClassNames.foreach(TestUtils.createCompiledClass(_, tempDir2, "2")) } @@ -69,7 +83,7 @@ class ExecutorClassLoaderSuite test("child first") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, true) + val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true) val fakeClass = classLoader.loadClass("ReplFakeClass2").newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "1") @@ -77,7 +91,7 @@ class ExecutorClassLoaderSuite test("parent first") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, false) + val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, false) val fakeClass = classLoader.loadClass("ReplFakeClass1").newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "2") @@ -85,7 +99,7 @@ class ExecutorClassLoaderSuite test("child first can fall back") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, true) + val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, 
parentLoader, true) val fakeClass = classLoader.loadClass("ReplFakeClass3").newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "2") @@ -93,12 +107,32 @@ class ExecutorClassLoaderSuite test("child first can fail") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, true) + val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true) intercept[java.lang.ClassNotFoundException] { classLoader.loadClass("ReplFakeClassDoesNotExist").newInstance() } } + test("resource from parent") { + val parentLoader = new URLClassLoader(urls2, null) + val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true) + val resourceName: String = parentResourceNames.head + val is = classLoader.getResourceAsStream(resourceName) + assert(is != null, s"Resource $resourceName not found") + val content = Source.fromInputStream(is, "UTF-8").getLines().next() + assert(content.contains("resource"), "File doesn't contain 'resource'") + } + + test("resources from parent") { + val parentLoader = new URLClassLoader(urls2, null) + val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true) + val resourceName: String = parentResourceNames.head + val resources: util.Enumeration[URL] = classLoader.getResources(resourceName) + assert(resources.hasMoreElements, s"Resource $resourceName not found") + val fileReader = Source.fromInputStream(resources.nextElement().openStream()).bufferedReader() + assert(fileReader.readLine().contains("resource"), "File doesn't contain 'resource'") + } + test("failing to fetch classes from HTTP server should not leak resources (SPARK-6209)") { // This is a regression test for SPARK-6209, a bug where each failed attempt to load a class // from the driver's class server would leak a HTTP connection, causing the class server's @@ -113,7 +147,7 @@ class ExecutorClassLoaderSuite SparkEnv.set(mockEnv) // Create an ExecutorClassLoader that's configured to load classes from the HTTP server val parentLoader = new URLClassLoader(Array.empty, null) - val classLoader = new ExecutorClassLoader(conf, classServer.uri, parentLoader, false) + val classLoader = new ExecutorClassLoader(conf, null, classServer.uri, parentLoader, false) classLoader.httpUrlConnectionTimeoutMillis = 500 // Check that this class loader can actually load classes that exist val fakeClass = classLoader.loadClass("ReplFakeClass2").newInstance() @@ -148,4 +182,27 @@ class ExecutorClassLoaderSuite failAfter(10 seconds)(tryAndFailToLoadABunchOfClasses())(interruptor) } + test("fetch classes using Spark's RpcEnv") { + val env = mock[SparkEnv] + val rpcEnv = mock[RpcEnv] + when(env.rpcEnv).thenReturn(rpcEnv) + when(rpcEnv.openChannel(anyString())).thenAnswer(new Answer[ReadableByteChannel]() { + override def answer(invocation: InvocationOnMock): ReadableByteChannel = { + val uri = new URI(invocation.getArguments()(0).asInstanceOf[String]) + val path = Paths.get(tempDir1.getAbsolutePath(), uri.getPath().stripPrefix("/")) + FileChannel.open(path, StandardOpenOption.READ) + } + }) + + val classLoader = new ExecutorClassLoader(new SparkConf(), env, "spark://localhost:1234", + getClass().getClassLoader(), false) + + val fakeClass = classLoader.loadClass("ReplFakeClass2").newInstance() + val fakeClassVersion = fakeClass.toString + assert(fakeClassVersion === "1") + intercept[java.lang.ClassNotFoundException] { + 
classLoader.loadClass("ReplFakeClassDoesNotExist").newInstance() + } + } + } diff --git a/sbin/slaves.sh b/sbin/slaves.sh index cdad47ee2e59..c971aa3296b0 100755 --- a/sbin/slaves.sh +++ b/sbin/slaves.sh @@ -36,10 +36,11 @@ if [ $# -le 0 ]; then exit 1 fi -sbin="`dirname "$0"`" -sbin="`cd "$sbin"; pwd`" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi -. "$sbin/spark-config.sh" +. "${SPARK_HOME}/sbin/spark-config.sh" # If the slaves file is specified in the command line, # then it takes precedence over the definition in @@ -65,7 +66,7 @@ then shift fi -. "$SPARK_PREFIX/bin/load-spark-env.sh" +. "${SPARK_HOME}/bin/load-spark-env.sh" if [ "$HOSTLIST" = "" ]; then if [ "$SPARK_SLAVES" = "" ]; then diff --git a/sbin/spark-config.sh b/sbin/spark-config.sh index b0361d72d3f2..d8d9d00d64eb 100755 --- a/sbin/spark-config.sh +++ b/sbin/spark-config.sh @@ -19,21 +19,12 @@ # should not be executable directly # also should not be passed any arguments, since we need original $* -# resolve links - $0 may be a softlink -this="${BASH_SOURCE:-$0}" -common_bin="$(cd -P -- "$(dirname -- "$this")" && pwd -P)" -script="$(basename -- "$this")" -this="$common_bin/$script" +# symlink and absolute path should rely on SPARK_HOME to resolve +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi -# convert relative path to absolute path -config_bin="`dirname "$this"`" -script="`basename "$this"`" -config_bin="`cd "$config_bin"; pwd`" -this="$config_bin/$script" - -export SPARK_PREFIX="`dirname "$this"`"/.. -export SPARK_HOME="${SPARK_PREFIX}" -export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"$SPARK_HOME/conf"}" +export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"${SPARK_HOME}/conf"}" # Add the PySpark classes to the PYTHONPATH: -export PYTHONPATH="$SPARK_HOME/python:$PYTHONPATH" -export PYTHONPATH="$SPARK_HOME/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH" +export PYTHONPATH="${SPARK_HOME}/python:${PYTHONPATH}" +export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.9-src.zip:${PYTHONPATH}" diff --git a/sbin/spark-daemon.sh b/sbin/spark-daemon.sh index de762acc8fa0..6ab57df40952 100755 --- a/sbin/spark-daemon.sh +++ b/sbin/spark-daemon.sh @@ -29,7 +29,7 @@ # SPARK_NICENESS The scheduling priority for daemons. Defaults to 0. ## -usage="Usage: spark-daemon.sh [--config ] (start|stop|status) " +usage="Usage: spark-daemon.sh [--config ] (start|stop|submit|status) " # if no args specified, show usage if [ $# -le 1 ]; then @@ -37,10 +37,11 @@ if [ $# -le 1 ]; then exit 1 fi -sbin="`dirname "$0"`" -sbin="`cd "$sbin"; pwd`" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi -. "$sbin/spark-config.sh" +. "${SPARK_HOME}/sbin/spark-config.sh" # get arguments @@ -86,7 +87,7 @@ spark_rotate_log () fi } -. "$SPARK_PREFIX/bin/load-spark-env.sh" +. 
"${SPARK_HOME}/bin/load-spark-env.sh" if [ "$SPARK_IDENT_STRING" = "" ]; then export SPARK_IDENT_STRING="$USER" @@ -97,7 +98,7 @@ export SPARK_PRINT_LAUNCH_COMMAND="1" # get log directory if [ "$SPARK_LOG_DIR" = "" ]; then - export SPARK_LOG_DIR="$SPARK_HOME/logs" + export SPARK_LOG_DIR="${SPARK_HOME}/logs" fi mkdir -p "$SPARK_LOG_DIR" touch "$SPARK_LOG_DIR"/.spark_test > /dev/null 2>&1 @@ -137,7 +138,7 @@ run_command() { if [ "$SPARK_MASTER" != "" ]; then echo rsync from "$SPARK_MASTER" - rsync -a -e ssh --delete --exclude=.svn --exclude='logs/*' --exclude='contrib/hod/logs/*' "$SPARK_MASTER/" "$SPARK_HOME" + rsync -a -e ssh --delete --exclude=.svn --exclude='logs/*' --exclude='contrib/hod/logs/*' "$SPARK_MASTER/" "${SPARK_HOME}" fi spark_rotate_log "$log" @@ -145,12 +146,12 @@ run_command() { case "$mode" in (class) - nohup nice -n "$SPARK_NICENESS" "$SPARK_PREFIX"/bin/spark-class $command "$@" >> "$log" 2>&1 < /dev/null & + nohup nice -n "$SPARK_NICENESS" "${SPARK_HOME}"/bin/spark-class $command "$@" >> "$log" 2>&1 < /dev/null & newpid="$!" ;; (submit) - nohup nice -n "$SPARK_NICENESS" "$SPARK_PREFIX"/bin/spark-submit --class $command "$@" >> "$log" 2>&1 < /dev/null & + nohup nice -n "$SPARK_NICENESS" "${SPARK_HOME}"/bin/spark-submit --class $command "$@" >> "$log" 2>&1 < /dev/null & newpid="$!" ;; @@ -205,13 +206,13 @@ case $option in else echo $pid file is present but $command not running exit 1 - fi + fi else echo $command not running. exit 2 - fi + fi ;; - + (*) echo $usage exit 1 diff --git a/sbin/spark-daemons.sh b/sbin/spark-daemons.sh index 5d9f2bb51cae..dec2f4432df3 100755 --- a/sbin/spark-daemons.sh +++ b/sbin/spark-daemons.sh @@ -27,9 +27,10 @@ if [ $# -le 1 ]; then exit 1 fi -sbin=`dirname "$0"` -sbin=`cd "$sbin"; pwd` +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi -. "$sbin/spark-config.sh" +. "${SPARK_HOME}/sbin/spark-config.sh" -exec "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin/spark-daemon.sh" "$@" +exec "${SPARK_HOME}/sbin/slaves.sh" cd "${SPARK_HOME}" \; "${SPARK_HOME}/sbin/spark-daemon.sh" "$@" diff --git a/sbin/start-all.sh b/sbin/start-all.sh index 1baf57cea09e..6217f9bf28e3 100755 --- a/sbin/start-all.sh +++ b/sbin/start-all.sh @@ -21,8 +21,9 @@ # Starts the master on this node. # Starts a worker on each node specified in conf/slaves -sbin="`dirname "$0"`" -sbin="`cd "$sbin"; pwd`" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi TACHYON_STR="" @@ -36,10 +37,10 @@ shift done # Load the Spark configuration -. "$sbin/spark-config.sh" +. "${SPARK_HOME}/sbin/spark-config.sh" # Start Master -"$sbin"/start-master.sh $TACHYON_STR +"${SPARK_HOME}/sbin"/start-master.sh $TACHYON_STR # Start Workers -"$sbin"/start-slaves.sh $TACHYON_STR +"${SPARK_HOME}/sbin"/start-slaves.sh $TACHYON_STR diff --git a/sbin/start-history-server.sh b/sbin/start-history-server.sh index 7172ad15d88f..6851d99b7e8f 100755 --- a/sbin/start-history-server.sh +++ b/sbin/start-history-server.sh @@ -24,16 +24,11 @@ # Use the SPARK_HISTORY_OPTS environment variable to set history server configuration. # -sbin="`dirname "$0"`" -sbin="`cd "$sbin"; pwd`" - -. "$sbin/spark-config.sh" -. "$SPARK_PREFIX/bin/load-spark-env.sh" - -if [ $# != 0 ]; then - echo "Using command line arguments for setting the log directory is deprecated. Please " - echo "set the spark.history.fs.logDirectory configuration option instead." 
- export SPARK_HISTORY_OPTS="$SPARK_HISTORY_OPTS -Dspark.history.fs.logDirectory=$1" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" fi -exec "$sbin"/spark-daemon.sh start org.apache.spark.deploy.history.HistoryServer 1 +. "${SPARK_HOME}/sbin/spark-config.sh" +. "${SPARK_HOME}/bin/load-spark-env.sh" + +exec "${SPARK_HOME}/sbin"/spark-daemon.sh start org.apache.spark.deploy.history.HistoryServer 1 $@ diff --git a/sbin/start-master.sh b/sbin/start-master.sh index a7f5d5702fd8..9f2e14dff609 100755 --- a/sbin/start-master.sh +++ b/sbin/start-master.sh @@ -19,8 +19,23 @@ # Starts the master on the machine this script is executed on. -sbin="`dirname "$0"`" -sbin="`cd "$sbin"; pwd`" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi + +# NOTE: This exact class name is matched downstream by SparkSubmit. +# Any changes need to be reflected there. +CLASS="org.apache.spark.deploy.master.Master" + +if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then + echo "Usage: ./sbin/start-master.sh [options]" + pattern="Usage:" + pattern+="\|Using Spark's default log4j profile:" + pattern+="\|Registered signal handlers for" + + "${SPARK_HOME}"/bin/spark-class $CLASS --help 2>&1 | grep -v "$pattern" 1>&2 + exit 1 +fi ORIGINAL_ARGS="$@" @@ -29,7 +44,7 @@ START_TACHYON=false while (( "$#" )); do case $1 in --with-tachyon) - if [ ! -e "$sbin"/../tachyon/bin/tachyon ]; then + if [ ! -e "${SPARK_HOME}"/tachyon/bin/tachyon ]; then echo "Error: --with-tachyon specified, but tachyon not found." exit -1 fi @@ -39,9 +54,9 @@ case $1 in shift done -. "$sbin/spark-config.sh" +. "${SPARK_HOME}/sbin/spark-config.sh" -. "$SPARK_PREFIX/bin/load-spark-env.sh" +. "${SPARK_HOME}/bin/load-spark-env.sh" if [ "$SPARK_MASTER_PORT" = "" ]; then SPARK_MASTER_PORT=7077 @@ -55,12 +70,12 @@ if [ "$SPARK_MASTER_WEBUI_PORT" = "" ]; then SPARK_MASTER_WEBUI_PORT=8080 fi -"$sbin"/spark-daemon.sh start org.apache.spark.deploy.master.Master 1 \ +"${SPARK_HOME}/sbin"/spark-daemon.sh start $CLASS 1 \ --ip $SPARK_MASTER_IP --port $SPARK_MASTER_PORT --webui-port $SPARK_MASTER_WEBUI_PORT \ $ORIGINAL_ARGS if [ "$START_TACHYON" == "true" ]; then - "$sbin"/../tachyon/bin/tachyon bootstrap-conf $SPARK_MASTER_IP - "$sbin"/../tachyon/bin/tachyon format -s - "$sbin"/../tachyon/bin/tachyon-start.sh master + "${SPARK_HOME}"/tachyon/bin/tachyon bootstrap-conf $SPARK_MASTER_IP + "${SPARK_HOME}"/tachyon/bin/tachyon format -s + "${SPARK_HOME}"/tachyon/bin/tachyon-start.sh master fi diff --git a/sbin/start-mesos-dispatcher.sh b/sbin/start-mesos-dispatcher.sh index ef1fc573d5c6..4777e1668c70 100755 --- a/sbin/start-mesos-dispatcher.sh +++ b/sbin/start-mesos-dispatcher.sh @@ -21,12 +21,13 @@ # Rest server to handle driver requests for Mesos cluster mode. # Only one cluster dispatcher is needed per Mesos cluster. -sbin="`dirname "$0"`" -sbin="`cd "$sbin"; pwd`" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi -. "$sbin/spark-config.sh" +. "${SPARK_HOME}/sbin/spark-config.sh" -. "$SPARK_PREFIX/bin/load-spark-env.sh" +. 
"${SPARK_HOME}/bin/load-spark-env.sh" if [ "$SPARK_MESOS_DISPATCHER_PORT" = "" ]; then SPARK_MESOS_DISPATCHER_PORT=7077 @@ -37,4 +38,4 @@ if [ "$SPARK_MESOS_DISPATCHER_HOST" = "" ]; then fi -"$sbin"/spark-daemon.sh start org.apache.spark.deploy.mesos.MesosClusterDispatcher 1 --host $SPARK_MESOS_DISPATCHER_HOST --port $SPARK_MESOS_DISPATCHER_PORT "$@" +"${SPARK_HOME}/sbin"/spark-daemon.sh start org.apache.spark.deploy.mesos.MesosClusterDispatcher 1 --host $SPARK_MESOS_DISPATCHER_HOST --port $SPARK_MESOS_DISPATCHER_PORT "$@" diff --git a/sbin/start-mesos-shuffle-service.sh b/sbin/start-mesos-shuffle-service.sh index 64580762c5dc..184584567602 100755 --- a/sbin/start-mesos-shuffle-service.sh +++ b/sbin/start-mesos-shuffle-service.sh @@ -26,10 +26,11 @@ # Use the SPARK_SHUFFLE_OPTS environment variable to set shuffle service configuration. # -sbin="`dirname "$0"`" -sbin="`cd "$sbin"; pwd`" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi -. "$sbin/spark-config.sh" -. "$SPARK_PREFIX/bin/load-spark-env.sh" +. "${SPARK_HOME}/sbin/spark-config.sh" +. "${SPARK_HOME}/bin/load-spark-env.sh" -exec "$sbin"/spark-daemon.sh start org.apache.spark.deploy.mesos.MesosExternalShuffleService 1 +exec "${SPARK_HOME}/sbin"/spark-daemon.sh start org.apache.spark.deploy.mesos.MesosExternalShuffleService 1 diff --git a/sbin/start-shuffle-service.sh b/sbin/start-shuffle-service.sh index 4fddcf7f95d4..793e165be6c7 100755 --- a/sbin/start-shuffle-service.sh +++ b/sbin/start-shuffle-service.sh @@ -24,10 +24,11 @@ # Use the SPARK_SHUFFLE_OPTS environment variable to set shuffle server configuration. # -sbin="`dirname "$0"`" -sbin="`cd "$sbin"; pwd`" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi -. "$sbin/spark-config.sh" -. "$SPARK_PREFIX/bin/load-spark-env.sh" +. "${SPARK_HOME}/sbin/spark-config.sh" +. "${SPARK_HOME}/bin/load-spark-env.sh" -exec "$sbin"/spark-daemon.sh start org.apache.spark.deploy.ExternalShuffleService 1 +exec "${SPARK_HOME}/sbin"/spark-daemon.sh start org.apache.spark.deploy.ExternalShuffleService 1 diff --git a/sbin/start-slave.sh b/sbin/start-slave.sh index 4c919ff76a8f..8c268b885915 100755 --- a/sbin/start-slave.sh +++ b/sbin/start-slave.sh @@ -21,30 +21,37 @@ # # Environment Variables # -# SPARK_WORKER_INSTANCES The number of worker instances to run on this +# SPARK_WORKER_INSTANCES The number of worker instances to run on this # slave. Default is 1. -# SPARK_WORKER_PORT The base port number for the first worker. If set, +# SPARK_WORKER_PORT The base port number for the first worker. If set, # subsequent workers will increment this number. If # unset, Spark will find a valid port number, but # with no guarantee of a predictable pattern. # SPARK_WORKER_WEBUI_PORT The base port for the web interface of the first -# worker. Subsequent workers will increment this +# worker. Subsequent workers will increment this # number. Default is 8081. -usage="Usage: start-slave.sh where is like spark://localhost:7077" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi + +# NOTE: This exact class name is matched downstream by SparkSubmit. +# Any changes need to be reflected there. 
+CLASS="org.apache.spark.deploy.worker.Worker" + +if [[ $# -lt 1 ]] || [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then + echo "Usage: ./sbin/start-slave.sh [options] " + pattern="Usage:" + pattern+="\|Using Spark's default log4j profile:" + pattern+="\|Registered signal handlers for" -if [ $# -lt 1 ]; then - echo $usage - echo Called as start-slave.sh $* + "${SPARK_HOME}"/bin/spark-class $CLASS --help 2>&1 | grep -v "$pattern" 1>&2 exit 1 fi -sbin="`dirname "$0"`" -sbin="`cd "$sbin"; pwd`" +. "${SPARK_HOME}/sbin/spark-config.sh" -. "$sbin/spark-config.sh" - -. "$SPARK_PREFIX/bin/load-spark-env.sh" +. "${SPARK_HOME}/bin/load-spark-env.sh" # First argument should be the master; we need to store it aside because we may # need to insert arguments between it and the other arguments @@ -71,7 +78,7 @@ function start_instance { fi WEBUI_PORT=$(( $SPARK_WORKER_WEBUI_PORT + $WORKER_NUM - 1 )) - "$sbin"/spark-daemon.sh start org.apache.spark.deploy.worker.Worker $WORKER_NUM \ + "${SPARK_HOME}/sbin"/spark-daemon.sh start $CLASS $WORKER_NUM \ --webui-port "$WEBUI_PORT" $PORT_FLAG $PORT_NUM $MASTER "$@" } @@ -82,4 +89,3 @@ else start_instance $(( 1 + $i )) "$@" done fi - diff --git a/sbin/start-slaves.sh b/sbin/start-slaves.sh index 24d6268815ed..51ca81e053b7 100755 --- a/sbin/start-slaves.sh +++ b/sbin/start-slaves.sh @@ -19,16 +19,16 @@ # Starts a slave instance on each machine specified in the conf/slaves file. -sbin="`dirname "$0"`" -sbin="`cd "$sbin"; pwd`" - +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi START_TACHYON=false while (( "$#" )); do case $1 in --with-tachyon) - if [ ! -e "$sbin"/../tachyon/bin/tachyon ]; then + if [ ! -e "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon ]; then echo "Error: --with-tachyon specified, but tachyon not found." exit -1 fi @@ -38,9 +38,8 @@ case $1 in shift done -. "$sbin/spark-config.sh" - -. "$SPARK_PREFIX/bin/load-spark-env.sh" +. "${SPARK_HOME}/sbin/spark-config.sh" +. "${SPARK_HOME}/bin/load-spark-env.sh" # Find the port number for the master if [ "$SPARK_MASTER_PORT" = "" ]; then @@ -52,11 +51,11 @@ if [ "$SPARK_MASTER_IP" = "" ]; then fi if [ "$START_TACHYON" == "true" ]; then - "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin"/../tachyon/bin/tachyon bootstrap-conf "$SPARK_MASTER_IP" + "${SPARK_HOME}/sbin/slaves.sh" cd "${SPARK_HOME}" \; "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon bootstrap-conf "$SPARK_MASTER_IP" # set -t so we can call sudo - SPARK_SSH_OPTS="-o StrictHostKeyChecking=no -t" "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin/../tachyon/bin/tachyon-start.sh" worker SudoMount \; sleep 1 + SPARK_SSH_OPTS="-o StrictHostKeyChecking=no -t" "${SPARK_HOME}/sbin/slaves.sh" cd "${SPARK_HOME}" \; "${SPARK_HOME}/tachyon/bin/tachyon-start.sh" worker SudoMount \; sleep 1 fi # Launch the slaves -"$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin/start-slave.sh" "spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT" +"${SPARK_HOME}/sbin/slaves.sh" cd "${SPARK_HOME}" \; "${SPARK_HOME}/sbin/start-slave.sh" "spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT" diff --git a/sbin/start-thriftserver.sh b/sbin/start-thriftserver.sh index 5b0aeb177fff..ad7e7c5277eb 100755 --- a/sbin/start-thriftserver.sh +++ b/sbin/start-thriftserver.sh @@ -23,8 +23,9 @@ # Enter posix mode for bash set -o posix -# Figure out where Spark is installed -FWDIR="$(cd "`dirname "$0"`"/..; pwd)" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi # NOTE: This exact class name is matched downstream by SparkSubmit. 
# Any changes need to be reflected there. @@ -39,10 +40,10 @@ function usage { pattern+="\|=======" pattern+="\|--help" - "$FWDIR"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + "${SPARK_HOME}"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 echo echo "Thrift server options:" - "$FWDIR"/bin/spark-class $CLASS --help 2>&1 | grep -v "$pattern" 1>&2 + "${SPARK_HOME}"/bin/spark-class $CLASS --help 2>&1 | grep -v "$pattern" 1>&2 } if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then @@ -52,4 +53,4 @@ fi export SUBMIT_USAGE_FUNCTION=usage -exec "$FWDIR"/sbin/spark-daemon.sh submit $CLASS 1 "$@" +exec "${SPARK_HOME}"/sbin/spark-daemon.sh submit $CLASS 1 "$@" diff --git a/sbin/stop-all.sh b/sbin/stop-all.sh index 1a9abe07db84..4e476ca05cb0 100755 --- a/sbin/stop-all.sh +++ b/sbin/stop-all.sh @@ -20,23 +20,23 @@ # Stop all spark daemons. # Run this on the master node. - -sbin="`dirname "$0"`" -sbin="`cd "$sbin"; pwd`" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi # Load the Spark configuration -. "$sbin/spark-config.sh" +. "${SPARK_HOME}/sbin/spark-config.sh" # Stop the slaves, then the master -"$sbin"/stop-slaves.sh -"$sbin"/stop-master.sh +"${SPARK_HOME}/sbin"/stop-slaves.sh +"${SPARK_HOME}/sbin"/stop-master.sh if [ "$1" == "--wait" ] then printf "Waiting for workers to shut down..." while true do - running=`$sbin/slaves.sh ps -ef | grep -v grep | grep deploy.worker.Worker` + running=`${SPARK_HOME}/sbin/slaves.sh ps -ef | grep -v grep | grep deploy.worker.Worker` if [ -z "$running" ] then printf "\nAll workers successfully shut down.\n" diff --git a/sbin/stop-history-server.sh b/sbin/stop-history-server.sh index 6e6056359510..14e3af4be910 100755 --- a/sbin/stop-history-server.sh +++ b/sbin/stop-history-server.sh @@ -19,7 +19,8 @@ # Stops the history server on the machine this script is executed on. -sbin="`dirname "$0"`" -sbin="`cd "$sbin"; pwd`" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi -"$sbin"/spark-daemon.sh stop org.apache.spark.deploy.history.HistoryServer 1 +"${SPARK_HOME}/sbin/spark-daemon.sh" stop org.apache.spark.deploy.history.HistoryServer 1 diff --git a/sbin/stop-master.sh b/sbin/stop-master.sh index 729702d92191..e57962bb354d 100755 --- a/sbin/stop-master.sh +++ b/sbin/stop-master.sh @@ -19,13 +19,14 @@ # Stops the master on the machine this script is executed on. -sbin=`dirname "$0"` -sbin=`cd "$sbin"; pwd` +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi -. "$sbin/spark-config.sh" +. "${SPARK_HOME}/sbin/spark-config.sh" -"$sbin"/spark-daemon.sh stop org.apache.spark.deploy.master.Master 1 +"${SPARK_HOME}/sbin"/spark-daemon.sh stop org.apache.spark.deploy.master.Master 1 -if [ -e "$sbin"/../tachyon/bin/tachyon ]; then - "$sbin"/../tachyon/bin/tachyon killAll tachyon.master.Master +if [ -e "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon ]; then + "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon killAll tachyon.master.Master fi diff --git a/sbin/stop-mesos-dispatcher.sh b/sbin/stop-mesos-dispatcher.sh index cb65d95b5e52..5c0b4e051db3 100755 --- a/sbin/stop-mesos-dispatcher.sh +++ b/sbin/stop-mesos-dispatcher.sh @@ -18,10 +18,11 @@ # # Stop the Mesos Cluster dispatcher on the machine this script is executed on. -sbin=`dirname "$0"` -sbin=`cd "$sbin"; pwd` +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi -. "$sbin/spark-config.sh" +. 
"${SPARK_HOME}/sbin/spark-config.sh" -"$sbin"/spark-daemon.sh stop org.apache.spark.deploy.mesos.MesosClusterDispatcher 1 +"${SPARK_HOME}/sbin"/spark-daemon.sh stop org.apache.spark.deploy.mesos.MesosClusterDispatcher 1 diff --git a/sbin/stop-mesos-shuffle-service.sh b/sbin/stop-mesos-shuffle-service.sh index 0e965d5ec588..d23cad375e1b 100755 --- a/sbin/stop-mesos-shuffle-service.sh +++ b/sbin/stop-mesos-shuffle-service.sh @@ -19,7 +19,8 @@ # Stops the Mesos external shuffle service on the machine this script is executed on. -sbin="`dirname "$0"`" -sbin="`cd "$sbin"; pwd`" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi -"$sbin"/spark-daemon.sh stop org.apache.spark.deploy.mesos.MesosExternalShuffleService 1 +"${SPARK_HOME}/sbin"/spark-daemon.sh stop org.apache.spark.deploy.mesos.MesosExternalShuffleService 1 diff --git a/sbin/stop-shuffle-service.sh b/sbin/stop-shuffle-service.sh index 4cb6891ae27f..50d69cf34e0a 100755 --- a/sbin/stop-shuffle-service.sh +++ b/sbin/stop-shuffle-service.sh @@ -19,7 +19,8 @@ # Stops the external shuffle service on the machine this script is executed on. -sbin="`dirname "$0"`" -sbin="`cd "$sbin"; pwd`" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi -"$sbin"/spark-daemon.sh stop org.apache.spark.deploy.ExternalShuffleService 1 +"${SPARK_HOME}/sbin"/spark-daemon.sh stop org.apache.spark.deploy.ExternalShuffleService 1 diff --git a/sbin/stop-slave.sh b/sbin/stop-slave.sh index 3d1da5b254f2..685bcf59b33a 100755 --- a/sbin/stop-slave.sh +++ b/sbin/stop-slave.sh @@ -21,23 +21,24 @@ # # Environment variables # -# SPARK_WORKER_INSTANCES The number of worker instances that should be +# SPARK_WORKER_INSTANCES The number of worker instances that should be # running on this slave. Default is 1. # Usage: stop-slave.sh # Stops all slaves on this worker machine -sbin="`dirname "$0"`" -sbin="`cd "$sbin"; pwd`" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi -. "$sbin/spark-config.sh" +. "${SPARK_HOME}/sbin/spark-config.sh" -. "$SPARK_PREFIX/bin/load-spark-env.sh" +. "${SPARK_HOME}/bin/load-spark-env.sh" if [ "$SPARK_WORKER_INSTANCES" = "" ]; then - "$sbin"/spark-daemon.sh stop org.apache.spark.deploy.worker.Worker 1 + "${SPARK_HOME}/sbin"/spark-daemon.sh stop org.apache.spark.deploy.worker.Worker 1 else for ((i=0; i<$SPARK_WORKER_INSTANCES; i++)); do - "$sbin"/spark-daemon.sh stop org.apache.spark.deploy.worker.Worker $(( $i + 1 )) + "${SPARK_HOME}/sbin"/spark-daemon.sh stop org.apache.spark.deploy.worker.Worker $(( $i + 1 )) done fi diff --git a/sbin/stop-slaves.sh b/sbin/stop-slaves.sh index 54c9bd46803a..63956377629d 100755 --- a/sbin/stop-slaves.sh +++ b/sbin/stop-slaves.sh @@ -17,16 +17,17 @@ # limitations under the License. # -sbin="`dirname "$0"`" -sbin="`cd "$sbin"; pwd`" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi -. "$sbin/spark-config.sh" +. "${SPARK_HOME}/sbin/spark-config.sh" -. "$SPARK_PREFIX/bin/load-spark-env.sh" +. 
"${SPARK_HOME}/bin/load-spark-env.sh" # do before the below calls as they exec -if [ -e "$sbin"/../tachyon/bin/tachyon ]; then - "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin"/../tachyon/bin/tachyon killAll tachyon.worker.Worker +if [ -e "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon ]; then + "${SPARK_HOME}/sbin/slaves.sh" cd "${SPARK_HOME}" \; "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon killAll tachyon.worker.Worker fi -"$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin"/stop-slave.sh +"${SPARK_HOME}/sbin/slaves.sh" cd "${SPARK_HOME}" \; "${SPARK_HOME}/sbin"/stop-slave.sh diff --git a/sbin/stop-thriftserver.sh b/sbin/stop-thriftserver.sh index 4031a00d4a68..cf45058f882a 100755 --- a/sbin/stop-thriftserver.sh +++ b/sbin/stop-thriftserver.sh @@ -19,7 +19,8 @@ # Stops the thrift server on the machine this script is executed on. -sbin="`dirname "$0"`" -sbin="`cd "$sbin"; pwd`" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi -"$sbin"/spark-daemon.sh stop org.apache.spark.sql.hive.thriftserver.HiveThriftServer2 1 +"${SPARK_HOME}/sbin"/spark-daemon.sh stop org.apache.spark.sql.hive.thriftserver.HiveThriftServer2 1 diff --git a/scalastyle-config.xml b/scalastyle-config.xml index b5e2e882d225..6925e18737b7 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -150,6 +150,25 @@ This file is divided into 3 sections: // scalastyle:on println]]> + + @VisibleForTesting + + + + + Runtime\.getRuntime\.addShutdownHook + + + Class\.forName + + + JavaConversions + Instead of importing implicits in scala.collection.JavaConversions._, import + scala.collection.JavaConverters._ and use .asScala / .asJava methods + + + + + ^getConfiguration$|^getTaskAttemptID$ + Instead of calling .getConfiguration() or .getTaskAttemptID() directly, + use SparkHadoopUtil's getConfigurationFromJobContext() and getTaskAttemptIDFromTaskAttemptContext() methods. 
+ + + @@ -181,6 +215,18 @@ This file is divided into 3 sections: + + + + java,scala,3rdParty,spark + javax?\..+ + scala\..+ + (?!org\.apache\.spark\.).* + org\.apache\.spark\..* + + + + diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index f4b1cc3a4ffe..61d6fc63554b 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml @@ -53,6 +53,10 @@ test-jar test + + org.apache.spark + spark-test-tags_${scala.binary.version} + org.apache.spark spark-unsafe_${scala.binary.version} @@ -66,7 +70,6 @@ org.codehaus.janino janino - 2.7.8 diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java index 8f1027f3164c..eea7149d0259 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java @@ -18,10 +18,10 @@ package org.apache.spark.sql.catalyst.expressions; import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.types.ArrayData; +import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.Decimal; -import org.apache.spark.sql.types.MapData; +import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 0374846d7167..3513960b4181 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -19,31 +19,30 @@ import java.math.BigDecimal; import java.math.BigInteger; +import java.nio.ByteBuffer; -import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.types.*; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; -import org.apache.spark.unsafe.hash.Murmur3_x86_32; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; /** * An Unsafe implementation of Array which is backed by raw memory instead of Java objects. * - * Each tuple has two parts: [offsets] [values] + * Each tuple has three parts: [numElements] [offsets] [values] * - * In the `offsets` region, we store 4 bytes per element, represents the start address of this - * element in `values` region. We can get the length of this element by subtracting next offset. + * The `numElements` is 4 bytes storing the number of elements of this array. + * + * In the `offsets` region, we store 4 bytes per element, represents the relative offset (w.r.t. the + * base address of the array) of this element in `values` region. We can get the length of this + * element by subtracting next offset. * Note that offset can by negative which means this element is null. * * In the `values` region, we store the content of elements. As we can get length info, so elements * can be variable-length. 
* - * Note that when we write out this array, we should write out the `numElements` at first 4 bytes, - * then follows content. When we read in an array, we should read first 4 bytes as `numElements` - * and take the rest as content. - * * Instances of `UnsafeArrayData` act as pointers to row data stored in this format. */ // todo: there is a lof of duplicated code between UnsafeRow and UnsafeArrayData. @@ -55,11 +54,16 @@ public class UnsafeArrayData extends ArrayData { // The number of elements in this array private int numElements; - // The size of this array's backing data, in bytes + // The size of this array's backing data, in bytes. + // The 4-bytes header of `numElements` is also included. private int sizeInBytes; + public Object getBaseObject() { return baseObject; } + public long getBaseOffset() { return baseOffset; } + public int getSizeInBytes() { return sizeInBytes; } + private int getElementOffset(int ordinal) { - return PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + ordinal * 4L); + return Platform.getInt(baseObject, baseOffset + 4 + ordinal * 4L); } private int getElementSize(int offset, int ordinal) { @@ -75,6 +79,10 @@ private void assertIndexIsValid(int ordinal) { assert ordinal < numElements : "ordinal (" + ordinal + ") should < " + numElements; } + public Object[] array() { + throw new UnsupportedOperationException("Only supported on GenericArrayData."); + } + /** * Construct a new UnsafeArrayData. The resulting UnsafeArrayData won't be usable until * `pointTo()` has been called, since the value returned by this constructor is equivalent @@ -82,10 +90,6 @@ private void assertIndexIsValid(int ordinal) { */ public UnsafeArrayData() { } - public Object getBaseObject() { return baseObject; } - public long getBaseOffset() { return baseOffset; } - public int getSizeInBytes() { return sizeInBytes; } - @Override public int numElements() { return numElements; } @@ -94,10 +98,13 @@ public UnsafeArrayData() { } * * @param baseObject the base object * @param baseOffset the offset within the base object - * @param sizeInBytes the size of this row's backing data, in bytes + * @param sizeInBytes the size of this array's backing data, in bytes */ - public void pointTo(Object baseObject, long baseOffset, int numElements, int sizeInBytes) { + public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) { + // Read the number of elements from the first 4 bytes. 
+ final int numElements = Platform.getInt(baseObject, baseOffset); assert numElements >= 0 : "numElements (" + numElements + ") should >= 0"; + this.numElements = numElements; this.baseObject = baseObject; this.baseOffset = baseOffset; @@ -147,6 +154,8 @@ public Object get(int ordinal, DataType dataType) { return getArray(ordinal); } else if (dataType instanceof MapType) { return getMap(ordinal); + } else if (dataType instanceof UserDefinedType) { + return get(ordinal, ((UserDefinedType)dataType).sqlType()); } else { throw new UnsupportedOperationException("Unsupported data type " + dataType.simpleString()); } @@ -157,7 +166,7 @@ public boolean getBoolean(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return false; - return PlatformDependent.UNSAFE.getBoolean(baseObject, baseOffset + offset); + return Platform.getBoolean(baseObject, baseOffset + offset); } @Override @@ -165,7 +174,7 @@ public byte getByte(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return 0; - return PlatformDependent.UNSAFE.getByte(baseObject, baseOffset + offset); + return Platform.getByte(baseObject, baseOffset + offset); } @Override @@ -173,7 +182,7 @@ public short getShort(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return 0; - return PlatformDependent.UNSAFE.getShort(baseObject, baseOffset + offset); + return Platform.getShort(baseObject, baseOffset + offset); } @Override @@ -181,7 +190,7 @@ public int getInt(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return 0; - return PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + offset); + return Platform.getInt(baseObject, baseOffset + offset); } @Override @@ -189,7 +198,7 @@ public long getLong(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return 0; - return PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset); + return Platform.getLong(baseObject, baseOffset + offset); } @Override @@ -197,7 +206,7 @@ public float getFloat(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return 0; - return PlatformDependent.UNSAFE.getFloat(baseObject, baseOffset + offset); + return Platform.getFloat(baseObject, baseOffset + offset); } @Override @@ -205,7 +214,7 @@ public double getDouble(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return 0; - return PlatformDependent.UNSAFE.getDouble(baseObject, baseOffset + offset); + return Platform.getDouble(baseObject, baseOffset + offset); } @Override @@ -215,7 +224,7 @@ public Decimal getDecimal(int ordinal, int precision, int scale) { if (offset < 0) return null; if (precision <= Decimal.MAX_LONG_DIGITS()) { - final long value = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset); + final long value = Platform.getLong(baseObject, baseOffset + offset); return Decimal.apply(value, precision, scale); } else { final byte[] bytes = getBinary(ordinal); @@ -241,12 +250,7 @@ public byte[] getBinary(int ordinal) { if (offset < 0) return null; final int size = getElementSize(offset, ordinal); final byte[] bytes = new byte[size]; - PlatformDependent.copyMemory( - baseObject, - baseOffset + offset, - bytes, - PlatformDependent.BYTE_ARRAY_OFFSET, - size); + Platform.copyMemory(baseObject, baseOffset + offset, 
bytes, Platform.BYTE_ARRAY_OFFSET, size); return bytes; } @@ -255,14 +259,13 @@ public CalendarInterval getInterval(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return null; - final int months = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset); - final long microseconds = - PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset + 8); + final int months = (int) Platform.getLong(baseObject, baseOffset + offset); + final long microseconds = Platform.getLong(baseObject, baseOffset + offset + 8); return new CalendarInterval(months, microseconds); } @Override - public InternalRow getStruct(int ordinal, int numFields) { + public UnsafeRow getStruct(int ordinal, int numFields) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return null; @@ -273,26 +276,34 @@ public InternalRow getStruct(int ordinal, int numFields) { } @Override - public ArrayData getArray(int ordinal) { + public UnsafeArrayData getArray(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return null; final int size = getElementSize(offset, ordinal); - return UnsafeReaders.readArray(baseObject, baseOffset + offset, size); + final UnsafeArrayData array = new UnsafeArrayData(); + array.pointTo(baseObject, baseOffset + offset, size); + return array; } @Override - public MapData getMap(int ordinal) { + public UnsafeMapData getMap(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return null; final int size = getElementSize(offset, ordinal); - return UnsafeReaders.readMap(baseObject, baseOffset + offset, size); + final UnsafeMapData map = new UnsafeMapData(); + map.pointTo(baseObject, baseOffset + offset, size); + return map; } @Override public int hashCode() { - return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, 42); + int result = 37; + for (int i = 0; i < sizeInBytes; i++) { + result = 37 * result + Platform.getByte(baseObject, baseOffset + i); + } + return result; } @Override @@ -307,27 +318,25 @@ public boolean equals(Object other) { } public void writeToMemory(Object target, long targetOffset) { - PlatformDependent.copyMemory( - baseObject, - baseOffset, - target, - targetOffset, - sizeInBytes - ); + Platform.copyMemory(baseObject, baseOffset, target, targetOffset, sizeInBytes); + } + + public void writeTo(ByteBuffer buffer) { + assert(buffer.hasArray()); + byte[] target = buffer.array(); + int offset = buffer.arrayOffset(); + int pos = buffer.position(); + writeToMemory(target, Platform.BYTE_ARRAY_OFFSET + offset + pos); + buffer.position(pos + sizeInBytes); } @Override public UnsafeArrayData copy() { UnsafeArrayData arrayCopy = new UnsafeArrayData(); final byte[] arrayDataCopy = new byte[sizeInBytes]; - PlatformDependent.copyMemory( - baseObject, - baseOffset, - arrayDataCopy, - PlatformDependent.BYTE_ARRAY_OFFSET, - sizeInBytes - ); - arrayCopy.pointTo(arrayDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numElements, sizeInBytes); + Platform.copyMemory( + baseObject, baseOffset, arrayDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes); + arrayCopy.pointTo(arrayDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes); return arrayCopy; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java index 46216054ab38..651eb1ff0c56 100644 
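The writeTo(ByteBuffer) helper added to UnsafeArrayData above asserts buffer.hasArray(), so it only works with heap buffers. A minimal sketch (not part of the patch), reusing the array built in the earlier sketch:

    // Only heap byte buffers work; a direct buffer has no backing array and would trip the assert.
    ByteBuffer heap = ByteBuffer.allocate(array.getSizeInBytes());
    array.writeTo(heap);   // copies sizeInBytes bytes into the buffer and advances its position
    // ByteBuffer.allocateDirect(array.getSizeInBytes()) would fail here: hasArray() is false.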
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java @@ -17,50 +17,105 @@ package org.apache.spark.sql.catalyst.expressions; -import org.apache.spark.sql.types.ArrayData; -import org.apache.spark.sql.types.MapData; +import java.nio.ByteBuffer; + +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.unsafe.Platform; /** * An Unsafe implementation of Map which is backed by raw memory instead of Java objects. * - * Currently we just use 2 UnsafeArrayData to represent UnsafeMapData. + * Currently we just use 2 UnsafeArrayData to represent UnsafeMapData, with extra 4 bytes at head + * to indicate the number of bytes of the unsafe key array. + * [unsafe key array numBytes] [unsafe key array] [unsafe value array] */ +// TODO: Use a more efficient format which doesn't depend on unsafe array. public class UnsafeMapData extends MapData { - public final UnsafeArrayData keys; - public final UnsafeArrayData values; - // The number of elements in this array - private int numElements; - // The size of this array's backing data, in bytes + private Object baseObject; + private long baseOffset; + + // The size of this map's backing data, in bytes. + // The 4-bytes header of key array `numBytes` is also included, so it's actually equal to + // 4 + key array numBytes + value array numBytes. private int sizeInBytes; + public Object getBaseObject() { return baseObject; } + public long getBaseOffset() { return baseOffset; } public int getSizeInBytes() { return sizeInBytes; } - public UnsafeMapData(UnsafeArrayData keys, UnsafeArrayData values) { + private final UnsafeArrayData keys; + private final UnsafeArrayData values; + + /** + * Construct a new UnsafeMapData. The resulting UnsafeMapData won't be usable until + * `pointTo()` has been called, since the value returned by this constructor is equivalent + * to a null pointer. + */ + public UnsafeMapData() { + keys = new UnsafeArrayData(); + values = new UnsafeArrayData(); + } + + /** + * Update this UnsafeMapData to point to different backing data. + * + * @param baseObject the base object + * @param baseOffset the offset within the base object + * @param sizeInBytes the size of this map's backing data, in bytes + */ + public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) { + // Read the numBytes of key array from the first 4 bytes. 
+ final int keyArraySize = Platform.getInt(baseObject, baseOffset); + final int valueArraySize = sizeInBytes - keyArraySize - 4; + assert keyArraySize >= 0 : "keyArraySize (" + keyArraySize + ") should >= 0"; + assert valueArraySize >= 0 : "valueArraySize (" + valueArraySize + ") should >= 0"; + + keys.pointTo(baseObject, baseOffset + 4, keyArraySize); + values.pointTo(baseObject, baseOffset + 4 + keyArraySize, valueArraySize); + assert keys.numElements() == values.numElements(); - this.sizeInBytes = keys.getSizeInBytes() + values.getSizeInBytes(); - this.numElements = keys.numElements(); - this.keys = keys; - this.values = values; + + this.baseObject = baseObject; + this.baseOffset = baseOffset; + this.sizeInBytes = sizeInBytes; } @Override public int numElements() { - return numElements; + return keys.numElements(); } @Override - public ArrayData keyArray() { + public UnsafeArrayData keyArray() { return keys; } @Override - public ArrayData valueArray() { + public UnsafeArrayData valueArray() { return values; } + public void writeToMemory(Object target, long targetOffset) { + Platform.copyMemory(baseObject, baseOffset, target, targetOffset, sizeInBytes); + } + + public void writeTo(ByteBuffer buffer) { + assert(buffer.hasArray()); + byte[] target = buffer.array(); + int offset = buffer.arrayOffset(); + int pos = buffer.position(); + writeToMemory(target, Platform.BYTE_ARRAY_OFFSET + offset + pos); + buffer.position(pos + sizeInBytes); + } + @Override public UnsafeMapData copy() { - return new UnsafeMapData(keys.copy(), values.copy()); + UnsafeMapData mapCopy = new UnsafeMapData(); + final byte[] mapDataCopy = new byte[sizeInBytes]; + Platform.copyMemory( + baseObject, baseOffset, mapDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes); + mapCopy.pointTo(mapDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes); + return mapCopy; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java deleted file mode 100644 index b521b703389d..000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions; - -import org.apache.spark.unsafe.PlatformDependent; - -public class UnsafeReaders { - - public static UnsafeArrayData readArray(Object baseObject, long baseOffset, int numBytes) { - // Read the number of elements from first 4 bytes. - final int numElements = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset); - final UnsafeArrayData array = new UnsafeArrayData(); - // Skip the first 4 bytes. 
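Putting the new map layout together with the UnsafeMapData.pointTo shown above (a sketch, not part of the patch; keyBytes and valueBytes are assumed to already hold two well-formed UnsafeArrayData images with equal element counts):

    // [key array numBytes] [unsafe key array] [unsafe value array]
    byte[] buf = new byte[4 + keyBytes.length + valueBytes.length];
    Platform.putInt(buf, Platform.BYTE_ARRAY_OFFSET, keyBytes.length);          // key array numBytes
    Platform.copyMemory(keyBytes, Platform.BYTE_ARRAY_OFFSET,
        buf, Platform.BYTE_ARRAY_OFFSET + 4, keyBytes.length);
    Platform.copyMemory(valueBytes, Platform.BYTE_ARRAY_OFFSET,
        buf, Platform.BYTE_ARRAY_OFFSET + 4 + keyBytes.length, valueBytes.length);

    UnsafeMapData map = new UnsafeMapData();
    map.pointTo(buf, Platform.BYTE_ARRAY_OFFSET, buf.length);
    map.numElements();   // == number of entries in the key array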
- array.pointTo(baseObject, baseOffset + 4, numElements, numBytes - 4); - return array; - } - - public static UnsafeMapData readMap(Object baseObject, long baseOffset, int numBytes) { - // Read the number of elements from first 4 bytes. - final int numElements = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset); - // Read the numBytes of key array in second 4 bytes. - final int keyArraySize = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + 4); - final int valueArraySize = numBytes - 8 - keyArraySize; - - final UnsafeArrayData keyArray = new UnsafeArrayData(); - keyArray.pointTo(baseObject, baseOffset + 8, numElements, keyArraySize); - - final UnsafeArrayData valueArray = new UnsafeArrayData(); - valueArray.pointTo(baseObject, baseOffset + 8 + keyArraySize, numElements, valueArraySize); - - return new UnsafeMapData(keyArray, valueArray); - } -} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index f4230cfaba37..b6979d0c8297 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -17,24 +17,62 @@ package org.apache.spark.sql.catalyst.expressions; +import java.io.Externalizable; import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; import java.io.OutputStream; import java.math.BigDecimal; import java.math.BigInteger; +import java.nio.ByteBuffer; import java.util.Arrays; import java.util.Collections; import java.util.HashSet; import java.util.Set; -import org.apache.spark.sql.types.*; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.BinaryType; +import org.apache.spark.sql.types.BooleanType; +import org.apache.spark.sql.types.ByteType; +import org.apache.spark.sql.types.CalendarIntervalType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DateType; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.FloatType; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.NullType; +import org.apache.spark.sql.types.ShortType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.TimestampType; +import org.apache.spark.sql.types.UserDefinedType; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; -import static org.apache.spark.sql.types.DataTypes.*; +import static org.apache.spark.sql.types.DataTypes.BooleanType; +import static org.apache.spark.sql.types.DataTypes.ByteType; +import static org.apache.spark.sql.types.DataTypes.DateType; +import static org.apache.spark.sql.types.DataTypes.DoubleType; +import static org.apache.spark.sql.types.DataTypes.FloatType; +import static org.apache.spark.sql.types.DataTypes.IntegerType; +import static org.apache.spark.sql.types.DataTypes.LongType; +import static 
org.apache.spark.sql.types.DataTypes.NullType; +import static org.apache.spark.sql.types.DataTypes.ShortType; +import static org.apache.spark.sql.types.DataTypes.TimestampType; +import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET; + +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.KryoSerializable; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; /** * An Unsafe implementation of Row which is backed by raw memory instead of Java objects. @@ -52,24 +90,24 @@ * * Instances of `UnsafeRow` act as pointers to row data stored in this format. */ -public final class UnsafeRow extends MutableRow { +public final class UnsafeRow extends MutableRow implements Externalizable, KryoSerializable { ////////////////////////////////////////////////////////////////////////////// // Static methods ////////////////////////////////////////////////////////////////////////////// public static int calculateBitSetWidthInBytes(int numFields) { - return ((numFields / 64) + (numFields % 64 == 0 ? 0 : 1)) * 8; + return ((numFields + 63)/ 64) * 8; } /** * Field types that can be updated in place in UnsafeRows (e.g. we support set() for these types) */ - public static final Set settableFieldTypes; + public static final Set mutableFieldTypes; - // DecimalType(precision <= 18) is settable + // DecimalType is also mutable static { - settableFieldTypes = Collections.unmodifiableSet( + mutableFieldTypes = Collections.unmodifiableSet( new HashSet<>( Arrays.asList(new DataType[] { NullType, @@ -87,12 +125,16 @@ public static int calculateBitSetWidthInBytes(int numFields) { public static boolean isFixedLength(DataType dt) { if (dt instanceof DecimalType) { - return ((DecimalType) dt).precision() < Decimal.MAX_LONG_DIGITS(); + return ((DecimalType) dt).precision() <= Decimal.MAX_LONG_DIGITS(); } else { - return settableFieldTypes.contains(dt); + return mutableFieldTypes.contains(dt); } } + public static boolean isMutable(DataType dt) { + return mutableFieldTypes.contains(dt) || dt instanceof DecimalType; + } + ////////////////////////////////////////////////////////////////////////////// // Private fields and methods ////////////////////////////////////////////////////////////////////////////// @@ -106,11 +148,6 @@ public static boolean isFixedLength(DataType dt) { /** The size of this row's backing data, in bytes) */ private int sizeInBytes; - private void setNotNullAt(int i) { - assertIndexIsValid(i); - BitSetMethods.unset(baseObject, baseOffset, i); - } - /** The width of the null tracking bit set, in bytes */ private int bitSetWidthInBytes; @@ -165,7 +202,22 @@ public void pointTo(Object baseObject, long baseOffset, int numFields, int sizeI * @param sizeInBytes the number of bytes valid in the byte array */ public void pointTo(byte[] buf, int numFields, int sizeInBytes) { - pointTo(buf, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, sizeInBytes); + pointTo(buf, Platform.BYTE_ARRAY_OFFSET, numFields, sizeInBytes); + } + + /** + * Updates this UnsafeRow preserving the number of fields. + * @param buf byte array to point to + * @param sizeInBytes the number of bytes valid in the byte array + */ + public void pointTo(byte[] buf, int sizeInBytes) { + pointTo(buf, numFields, sizeInBytes); + } + + + public void setNotNullAt(int i) { + assertIndexIsValid(i); + BitSetMethods.unset(baseObject, baseOffset, i); } @Override @@ -175,7 +227,7 @@ public void setNullAt(int i) { // To preserve row equality, zero out the value when setting the column to null. 
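The rewritten calculateBitSetWidthInBytes above is equivalent to the old branching form: it rounds numFields up to whole 64-bit words and returns the null-tracking width in bytes. A few illustrative values (sketch, not part of the patch):

    UnsafeRow.calculateBitSetWidthInBytes(0);    // 0
    UnsafeRow.calculateBitSetWidthInBytes(1);    // 8  -- one 64-bit word
    UnsafeRow.calculateBitSetWidthInBytes(64);   // 8
    UnsafeRow.calculateBitSetWidthInBytes(65);   // 16 -- two words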
// Since this row does does not currently support updates to variable-length values, we don't // have to worry about zeroing out that data. - PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(i), 0); + Platform.putLong(baseObject, getFieldOffset(i), 0); } @Override @@ -187,14 +239,14 @@ public void update(int ordinal, Object value) { public void setInt(int ordinal, int value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - PlatformDependent.UNSAFE.putInt(baseObject, getFieldOffset(ordinal), value); + Platform.putInt(baseObject, getFieldOffset(ordinal), value); } @Override public void setLong(int ordinal, long value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(ordinal), value); + Platform.putLong(baseObject, getFieldOffset(ordinal), value); } @Override @@ -204,28 +256,28 @@ public void setDouble(int ordinal, double value) { if (Double.isNaN(value)) { value = Double.NaN; } - PlatformDependent.UNSAFE.putDouble(baseObject, getFieldOffset(ordinal), value); + Platform.putDouble(baseObject, getFieldOffset(ordinal), value); } @Override public void setBoolean(int ordinal, boolean value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - PlatformDependent.UNSAFE.putBoolean(baseObject, getFieldOffset(ordinal), value); + Platform.putBoolean(baseObject, getFieldOffset(ordinal), value); } @Override public void setShort(int ordinal, short value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - PlatformDependent.UNSAFE.putShort(baseObject, getFieldOffset(ordinal), value); + Platform.putShort(baseObject, getFieldOffset(ordinal), value); } @Override public void setByte(int ordinal, byte value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - PlatformDependent.UNSAFE.putByte(baseObject, getFieldOffset(ordinal), value); + Platform.putByte(baseObject, getFieldOffset(ordinal), value); } @Override @@ -235,29 +287,51 @@ public void setFloat(int ordinal, float value) { if (Float.isNaN(value)) { value = Float.NaN; } - PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value); + Platform.putFloat(baseObject, getFieldOffset(ordinal), value); } + /** + * Updates the decimal column. + * + * Note: In order to support update a decimal with precision > 18, CAN NOT call + * setNullAt() for this column. + */ @Override public void setDecimal(int ordinal, Decimal value, int precision) { assertIndexIsValid(ordinal); - if (value == null) { - setNullAt(ordinal); - } else { - if (precision <= Decimal.MAX_LONG_DIGITS()) { + if (precision <= Decimal.MAX_LONG_DIGITS()) { + // compact format + if (value == null) { + setNullAt(ordinal); + } else { setLong(ordinal, value.toUnscaledLong()); + } + } else { + // fixed length + long cursor = getLong(ordinal) >>> 32; + assert cursor > 0 : "invalid cursor " + cursor; + // zero-out the bytes + Platform.putLong(baseObject, baseOffset + cursor, 0L); + Platform.putLong(baseObject, baseOffset + cursor + 8, 0L); + + if (value == null) { + setNullAt(ordinal); + // keep the offset for future update + Platform.putLong(baseObject, getFieldOffset(ordinal), cursor << 32); } else { - // TODO(davies): support update decimal (hold a bounded space even it's null) - throw new UnsupportedOperationException(); + + final BigInteger integer = value.toJavaBigDecimal().unscaledValue(); + byte[] bytes = integer.toByteArray(); + assert(bytes.length <= 16); + + // Write the bytes to the variable length portion. 
+ Platform.copyMemory( + bytes, Platform.BYTE_ARRAY_OFFSET, baseObject, baseOffset + cursor, bytes.length); + setLong(ordinal, (cursor << 32) | ((long) bytes.length)); } } } - @Override - public Object genericGet(int ordinal) { - throw new UnsupportedOperationException(); - } - @Override public Object get(int ordinal, DataType dataType) { if (isNullAt(ordinal) || dataType instanceof NullType) { @@ -295,6 +369,8 @@ public Object get(int ordinal, DataType dataType) { return getArray(ordinal); } else if (dataType instanceof MapType) { return getMap(ordinal); + } else if (dataType instanceof UserDefinedType) { + return get(ordinal, ((UserDefinedType)dataType).sqlType()); } else { throw new UnsupportedOperationException("Unsupported data type " + dataType.simpleString()); } @@ -309,43 +385,43 @@ public boolean isNullAt(int ordinal) { @Override public boolean getBoolean(int ordinal) { assertIndexIsValid(ordinal); - return PlatformDependent.UNSAFE.getBoolean(baseObject, getFieldOffset(ordinal)); + return Platform.getBoolean(baseObject, getFieldOffset(ordinal)); } @Override public byte getByte(int ordinal) { assertIndexIsValid(ordinal); - return PlatformDependent.UNSAFE.getByte(baseObject, getFieldOffset(ordinal)); + return Platform.getByte(baseObject, getFieldOffset(ordinal)); } @Override public short getShort(int ordinal) { assertIndexIsValid(ordinal); - return PlatformDependent.UNSAFE.getShort(baseObject, getFieldOffset(ordinal)); + return Platform.getShort(baseObject, getFieldOffset(ordinal)); } @Override public int getInt(int ordinal) { assertIndexIsValid(ordinal); - return PlatformDependent.UNSAFE.getInt(baseObject, getFieldOffset(ordinal)); + return Platform.getInt(baseObject, getFieldOffset(ordinal)); } @Override public long getLong(int ordinal) { assertIndexIsValid(ordinal); - return PlatformDependent.UNSAFE.getLong(baseObject, getFieldOffset(ordinal)); + return Platform.getLong(baseObject, getFieldOffset(ordinal)); } @Override public float getFloat(int ordinal) { assertIndexIsValid(ordinal); - return PlatformDependent.UNSAFE.getFloat(baseObject, getFieldOffset(ordinal)); + return Platform.getFloat(baseObject, getFieldOffset(ordinal)); } @Override public double getDouble(int ordinal) { assertIndexIsValid(ordinal); - return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(ordinal)); + return Platform.getDouble(baseObject, getFieldOffset(ordinal)); } @Override @@ -359,7 +435,7 @@ public Decimal getDecimal(int ordinal, int precision, int scale) { byte[] bytes = getBinary(ordinal); BigInteger bigInteger = new BigInteger(bytes); BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); - return Decimal.apply(new scala.math.BigDecimal(javaDecimal), precision, scale); + return Decimal.apply(javaDecimal, precision, scale); } } @@ -368,7 +444,7 @@ public UTF8String getUTF8String(int ordinal) { if (isNullAt(ordinal)) return null; final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); - final int size = (int) (offsetAndSize & ((1L << 32) - 1)); + final int size = (int) offsetAndSize; return UTF8String.fromAddress(baseObject, baseOffset + offset, size); } @@ -379,13 +455,13 @@ public byte[] getBinary(int ordinal) { } else { final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); - final int size = (int) (offsetAndSize & ((1L << 32) - 1)); + final int size = (int) offsetAndSize; final byte[] bytes = new byte[size]; - PlatformDependent.copyMemory( + Platform.copyMemory( baseObject, baseOffset + offset, bytes, - 
PlatformDependent.BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, size ); return bytes; @@ -399,9 +475,8 @@ public CalendarInterval getInterval(int ordinal) { } else { final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); - final int months = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset); - final long microseconds = - PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset + 8); + final int months = (int) Platform.getLong(baseObject, baseOffset + offset); + final long microseconds = Platform.getLong(baseObject, baseOffset + offset + 8); return new CalendarInterval(months, microseconds); } } @@ -413,7 +488,7 @@ public UnsafeRow getStruct(int ordinal, int numFields) { } else { final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); - final int size = (int) (offsetAndSize & ((1L << 32) - 1)); + final int size = (int) offsetAndSize; final UnsafeRow row = new UnsafeRow(); row.pointTo(baseObject, baseOffset + offset, numFields, size); return row; @@ -421,26 +496,30 @@ public UnsafeRow getStruct(int ordinal, int numFields) { } @Override - public ArrayData getArray(int ordinal) { + public UnsafeArrayData getArray(int ordinal) { if (isNullAt(ordinal)) { return null; } else { final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); - final int size = (int) (offsetAndSize & ((1L << 32) - 1)); - return UnsafeReaders.readArray(baseObject, baseOffset + offset, size); + final int size = (int) offsetAndSize; + final UnsafeArrayData array = new UnsafeArrayData(); + array.pointTo(baseObject, baseOffset + offset, size); + return array; } } @Override - public MapData getMap(int ordinal) { + public UnsafeMapData getMap(int ordinal) { if (isNullAt(ordinal)) { return null; } else { final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); - final int size = (int) (offsetAndSize & ((1L << 32) - 1)); - return UnsafeReaders.readMap(baseObject, baseOffset + offset, size); + final int size = (int) offsetAndSize; + final UnsafeMapData map = new UnsafeMapData(); + map.pointTo(baseObject, baseOffset + offset, size); + return map; } } @@ -452,14 +531,14 @@ public MapData getMap(int ordinal) { public UnsafeRow copy() { UnsafeRow rowCopy = new UnsafeRow(); final byte[] rowDataCopy = new byte[sizeInBytes]; - PlatformDependent.copyMemory( + Platform.copyMemory( baseObject, baseOffset, rowDataCopy, - PlatformDependent.BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, sizeInBytes ); - rowCopy.pointTo(rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, sizeInBytes); + rowCopy.pointTo(rowDataCopy, Platform.BYTE_ARRAY_OFFSET, numFields, sizeInBytes); return rowCopy; } @@ -479,18 +558,13 @@ public static UnsafeRow createFromByteArray(int numBytes, int numFields) { */ public void copyFrom(UnsafeRow row) { // copyFrom is only available for UnsafeRow created from byte array. - assert (baseObject instanceof byte[]) && baseOffset == PlatformDependent.BYTE_ARRAY_OFFSET; + assert (baseObject instanceof byte[]) && baseOffset == Platform.BYTE_ARRAY_OFFSET; if (row.sizeInBytes > this.sizeInBytes) { // resize the underlying byte[] if it's not large enough. 
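Throughout the getters above, replacing (offsetAndSize & ((1L << 32) - 1)) with (int) offsetAndSize is behavior-preserving, since casting a long to int keeps the low 32 bits. The 8-byte fixed slot of every variable-length field packs its offset (high 32 bits, relative to baseOffset) and its size (low 32 bits); the new setDecimal path for precision > 18 reuses the same packing and keeps the cursor so a null can later be overwritten in place. A small sketch with assumed values:

    long offsetAndSize = (64L << 32) | 13L;      // field data starts 64 bytes in and is 13 bytes long
    int offset = (int) (offsetAndSize >> 32);    // 64
    int size   = (int) offsetAndSize;            // 13, same as (int) (offsetAndSize & ((1L << 32) - 1))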
this.baseObject = new byte[row.sizeInBytes]; } - PlatformDependent.copyMemory( - row.baseObject, - row.baseOffset, - this.baseObject, - this.baseOffset, - row.sizeInBytes - ); + Platform.copyMemory( + row.baseObject, row.baseOffset, this.baseObject, this.baseOffset, row.sizeInBytes); // update the sizeInBytes. this.sizeInBytes = row.sizeInBytes; } @@ -505,19 +579,15 @@ public void copyFrom(UnsafeRow row) { */ public void writeToStream(OutputStream out, byte[] writeBuffer) throws IOException { if (baseObject instanceof byte[]) { - int offsetInByteArray = (int) (PlatformDependent.BYTE_ARRAY_OFFSET - baseOffset); + int offsetInByteArray = (int) (Platform.BYTE_ARRAY_OFFSET - baseOffset); out.write((byte[]) baseObject, offsetInByteArray, sizeInBytes); } else { int dataRemaining = sizeInBytes; long rowReadPosition = baseOffset; while (dataRemaining > 0) { int toTransfer = Math.min(writeBuffer.length, dataRemaining); - PlatformDependent.copyMemory( - baseObject, - rowReadPosition, - writeBuffer, - PlatformDependent.BYTE_ARRAY_OFFSET, - toTransfer); + Platform.copyMemory( + baseObject, rowReadPosition, writeBuffer, Platform.BYTE_ARRAY_OFFSET, toTransfer); out.write(writeBuffer, 0, toTransfer); rowReadPosition += toTransfer; dataRemaining -= toTransfer; @@ -545,13 +615,12 @@ public boolean equals(Object other) { * Returns the underlying bytes for this UnsafeRow. */ public byte[] getBytes() { - if (baseObject instanceof byte[] && baseOffset == PlatformDependent.BYTE_ARRAY_OFFSET + if (baseObject instanceof byte[] && baseOffset == Platform.BYTE_ARRAY_OFFSET && (((byte[]) baseObject).length == sizeInBytes)) { return (byte[]) baseObject; } else { byte[] bytes = new byte[sizeInBytes]; - PlatformDependent.copyMemory(baseObject, baseOffset, bytes, - PlatformDependent.BYTE_ARRAY_OFFSET, sizeInBytes); + Platform.copyMemory(baseObject, baseOffset, bytes, Platform.BYTE_ARRAY_OFFSET, sizeInBytes); return bytes; } } @@ -561,10 +630,10 @@ public byte[] getBytes() { public String toString() { StringBuilder build = new StringBuilder("["); for (int i = 0; i < sizeInBytes; i += 8) { - build.append(java.lang.Long.toHexString( - PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + i))); + build.append(java.lang.Long.toHexString(Platform.getLong(baseObject, baseOffset + i))); build.append(','); } + build.deleteCharAt(build.length() - 1); build.append(']'); return build.toString(); } @@ -580,12 +649,72 @@ public boolean anyNull() { * bytes in this string. 
*/ public void writeToMemory(Object target, long targetOffset) { - PlatformDependent.copyMemory( + Platform.copyMemory(baseObject, baseOffset, target, targetOffset, sizeInBytes); + } + + public void writeTo(ByteBuffer buffer) { + assert (buffer.hasArray()); + byte[] target = buffer.array(); + int offset = buffer.arrayOffset(); + int pos = buffer.position(); + writeToMemory(target, Platform.BYTE_ARRAY_OFFSET + offset + pos); + buffer.position(pos + sizeInBytes); + } + + /** + * Write the bytes of var-length field into ByteBuffer + * + * Note: only work with HeapByteBuffer + */ + public void writeFieldTo(int ordinal, ByteBuffer buffer) { + final long offsetAndSize = getLong(ordinal); + final int offset = (int) (offsetAndSize >> 32); + final int size = (int) offsetAndSize; + + buffer.putInt(size); + int pos = buffer.position(); + buffer.position(pos + size); + Platform.copyMemory( baseObject, - baseOffset, - target, - targetOffset, - sizeInBytes - ); + baseOffset + offset, + buffer.array(), + Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + pos, + size); + } + + @Override + public void writeExternal(ObjectOutput out) throws IOException { + byte[] bytes = getBytes(); + out.writeInt(bytes.length); + out.writeInt(this.numFields); + out.write(bytes); + } + + @Override + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + this.baseOffset = BYTE_ARRAY_OFFSET; + this.sizeInBytes = in.readInt(); + this.numFields = in.readInt(); + this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields); + this.baseObject = new byte[sizeInBytes]; + in.readFully((byte[]) baseObject); + } + + @Override + public void write(Kryo kryo, Output out) { + byte[] bytes = getBytes(); + out.writeInt(bytes.length); + out.writeInt(this.numFields); + out.write(bytes); + } + + @Override + public void read(Kryo kryo, Input in) { + this.baseOffset = BYTE_ARRAY_OFFSET; + this.sizeInBytes = in.readInt(); + this.numFields = in.readInt(); + this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields); + this.baseObject = new byte[sizeInBytes]; + in.read((byte[]) baseObject); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java deleted file mode 100644 index 31928731545d..000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java +++ /dev/null @@ -1,259 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
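With the Externalizable and KryoSerializable hooks added above, an UnsafeRow serializes as (byte length, numFields, raw bytes) and deserializes onto a fresh heap array. A minimal round trip through plain Java serialization (a sketch, not part of the patch; java.io.* imports assumed and row is an existing heap-backed UnsafeRow):

    ByteArrayOutputStream bos = new ByteArrayOutputStream();
    ObjectOutputStream out = new ObjectOutputStream(bos);
    out.writeObject(row);                            // writeExternal: length, numFields, bytes
    out.close();

    ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray()));
    UnsafeRow copy = (UnsafeRow) in.readObject();    // readExternal points the row at a new byte[]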
- */ - -package org.apache.spark.sql.catalyst.expressions; - -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.types.Decimal; -import org.apache.spark.sql.types.MapData; -import org.apache.spark.unsafe.PlatformDependent; -import org.apache.spark.unsafe.array.ByteArrayMethods; -import org.apache.spark.unsafe.types.ByteArray; -import org.apache.spark.unsafe.types.CalendarInterval; -import org.apache.spark.unsafe.types.UTF8String; - -/** - * A set of helper methods to write data into {@link UnsafeRow}s, - * used by {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}. - */ -public class UnsafeRowWriters { - - /** Writer for Decimal with precision under 18. */ - public static class CompactDecimalWriter { - - public static int getSize(Decimal input) { - return 0; - } - - public static int write(UnsafeRow target, int ordinal, int cursor, Decimal input) { - target.setLong(ordinal, input.toUnscaledLong()); - return 0; - } - } - - /** Writer for Decimal with precision larger than 18. */ - public static class DecimalWriter { - - public static int getSize(Decimal input) { - // bounded size - return 16; - } - - public static int write(UnsafeRow target, int ordinal, int cursor, Decimal input) { - final long offset = target.getBaseOffset() + cursor; - final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray(); - final int numBytes = bytes.length; - assert(numBytes <= 16); - - // zero-out the bytes - PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset, 0L); - PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset + 8, 0L); - - // Write the bytes to the variable length portion. - PlatformDependent.copyMemory(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, - target.getBaseObject(), offset, numBytes); - - // Set the fixed length portion. - target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes)); - return 16; - } - } - - /** Writer for UTF8String. */ - public static class UTF8StringWriter { - - public static int getSize(UTF8String input) { - return ByteArrayMethods.roundNumberOfBytesToNearestWord(input.numBytes()); - } - - public static int write(UnsafeRow target, int ordinal, int cursor, UTF8String input) { - final long offset = target.getBaseOffset() + cursor; - final int numBytes = input.numBytes(); - - // zero-out the padding bytes - if ((numBytes & 0x07) > 0) { - PlatformDependent.UNSAFE.putLong( - target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); - } - - // Write the bytes to the variable length portion. - input.writeToMemory(target.getBaseObject(), offset); - - // Set the fixed length portion. - target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes)); - return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); - } - } - - /** Writer for binary (byte array) type. */ - public static class BinaryWriter { - - public static int getSize(byte[] input) { - return ByteArrayMethods.roundNumberOfBytesToNearestWord(input.length); - } - - public static int write(UnsafeRow target, int ordinal, int cursor, byte[] input) { - final long offset = target.getBaseOffset() + cursor; - final int numBytes = input.length; - - // zero-out the padding bytes - if ((numBytes & 0x07) > 0) { - PlatformDependent.UNSAFE.putLong( - target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); - } - - // Write the bytes to the variable length portion. - ByteArray.writeToMemory(input, target.getBaseObject(), offset); - - // Set the fixed length portion. 
- target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes)); - return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); - } - } - - /** - * Writer for struct type where the struct field is backed by an {@link UnsafeRow}. - * - * We throw UnsupportedOperationException for inputs that are not backed by {@link UnsafeRow}. - * Non-UnsafeRow struct fields are handled directly in - * {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection} - * by generating the Java code needed to convert them into UnsafeRow. - */ - public static class StructWriter { - public static int getSize(InternalRow input) { - int numBytes = 0; - if (input instanceof UnsafeRow) { - numBytes = ((UnsafeRow) input).getSizeInBytes(); - } else { - // This is handled directly in GenerateUnsafeProjection. - throw new UnsupportedOperationException(); - } - return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); - } - - public static int write(UnsafeRow target, int ordinal, int cursor, InternalRow input) { - int numBytes = 0; - final long offset = target.getBaseOffset() + cursor; - if (input instanceof UnsafeRow) { - final UnsafeRow row = (UnsafeRow) input; - numBytes = row.getSizeInBytes(); - - // zero-out the padding bytes - if ((numBytes & 0x07) > 0) { - PlatformDependent.UNSAFE.putLong( - target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); - } - - // Write the bytes to the variable length portion. - row.writeToMemory(target.getBaseObject(), offset); - - // Set the fixed length portion. - target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes)); - } else { - // This is handled directly in GenerateUnsafeProjection. - throw new UnsupportedOperationException(); - } - return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); - } - } - - /** Writer for interval type. */ - public static class IntervalWriter { - - public static int write(UnsafeRow target, int ordinal, int cursor, CalendarInterval input) { - final long offset = target.getBaseOffset() + cursor; - - // Write the months and microseconds fields of Interval to the variable length portion. - PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset, input.months); - PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset + 8, input.microseconds); - - // Set the fixed length portion. - target.setLong(ordinal, ((long) cursor) << 32); - return 16; - } - } - - public static class ArrayWriter { - - public static int getSize(UnsafeArrayData input) { - // we need extra 4 bytes the store the number of elements in this array. - return ByteArrayMethods.roundNumberOfBytesToNearestWord(input.getSizeInBytes() + 4); - } - - public static int write(UnsafeRow target, int ordinal, int cursor, UnsafeArrayData input) { - final int numBytes = input.getSizeInBytes() + 4; - final long offset = target.getBaseOffset() + cursor; - - // write the number of elements into first 4 bytes. - PlatformDependent.UNSAFE.putInt(target.getBaseObject(), offset, input.numElements()); - - // zero-out the padding bytes - if ((numBytes & 0x07) > 0) { - PlatformDependent.UNSAFE.putLong( - target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); - } - - // Write the bytes to the variable length portion. - input.writeToMemory(target.getBaseObject(), offset + 4); - - // Set the fixed length portion. 
- target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes)); - - return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); - } - } - - public static class MapWriter { - - public static int getSize(UnsafeMapData input) { - // we need extra 8 bytes to store number of elements and numBytes of key array. - final int sizeInBytes = 4 + 4 + input.getSizeInBytes(); - return ByteArrayMethods.roundNumberOfBytesToNearestWord(sizeInBytes); - } - - public static int write(UnsafeRow target, int ordinal, int cursor, UnsafeMapData input) { - final long offset = target.getBaseOffset() + cursor; - final UnsafeArrayData keyArray = input.keys; - final UnsafeArrayData valueArray = input.values; - final int keysNumBytes = keyArray.getSizeInBytes(); - final int valuesNumBytes = valueArray.getSizeInBytes(); - final int numBytes = 4 + 4 + keysNumBytes + valuesNumBytes; - - // write the number of elements into first 4 bytes. - PlatformDependent.UNSAFE.putInt(target.getBaseObject(), offset, input.numElements()); - // write the numBytes of key array into second 4 bytes. - PlatformDependent.UNSAFE.putInt(target.getBaseObject(), offset + 4, keysNumBytes); - - // zero-out the padding bytes - if ((numBytes & 0x07) > 0) { - PlatformDependent.UNSAFE.putLong( - target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); - } - - // Write the bytes of key array to the variable length portion. - keyArray.writeToMemory(target.getBaseObject(), offset + 8); - - // Write the bytes of value array to the variable length portion. - valueArray.writeToMemory(target.getBaseObject(), offset + 8 + keysNumBytes); - - // Set the fixed length portion. - target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes)); - - return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); - } - } -} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java deleted file mode 100644 index 0e8e405d055d..000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java +++ /dev/null @@ -1,208 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions; - -import org.apache.spark.sql.types.Decimal; -import org.apache.spark.unsafe.PlatformDependent; -import org.apache.spark.unsafe.array.ByteArrayMethods; -import org.apache.spark.unsafe.types.CalendarInterval; -import org.apache.spark.unsafe.types.UTF8String; - -/** - * A set of helper methods to write data into the variable length portion. 
- */ -public class UnsafeWriters { - public static void writeToMemory( - Object inputObject, - long inputOffset, - Object targetObject, - long targetOffset, - int numBytes) { - - // zero-out the padding bytes -// if ((numBytes & 0x07) > 0) { -// PlatformDependent.UNSAFE.putLong(targetObject, targetOffset + ((numBytes >> 3) << 3), 0L); -// } - - // Write the UnsafeData to the target memory. - PlatformDependent.copyMemory( - inputObject, - inputOffset, - targetObject, - targetOffset, - numBytes - ); - } - - public static int getRoundedSize(int size) { - //return ByteArrayMethods.roundNumberOfBytesToNearestWord(size); - // todo: do word alignment - return size; - } - - /** Writer for Decimal with precision larger than 18. */ - public static class DecimalWriter { - - public static int getSize(Decimal input) { - return 16; - } - - public static int write(Object targetObject, long targetOffset, Decimal input) { - final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray(); - final int numBytes = bytes.length; - assert(numBytes <= 16); - - // zero-out the bytes - PlatformDependent.UNSAFE.putLong(targetObject, targetOffset, 0L); - PlatformDependent.UNSAFE.putLong(targetObject, targetOffset + 8, 0L); - - // Write the bytes to the variable length portion. - PlatformDependent.copyMemory(bytes, - PlatformDependent.BYTE_ARRAY_OFFSET, - targetObject, - targetOffset, - numBytes); - - return 16; - } - } - - /** Writer for UTF8String. */ - public static class UTF8StringWriter { - - public static int getSize(UTF8String input) { - return getRoundedSize(input.numBytes()); - } - - public static int write(Object targetObject, long targetOffset, UTF8String input) { - final int numBytes = input.numBytes(); - - // Write the bytes to the variable length portion. - writeToMemory(input.getBaseObject(), input.getBaseOffset(), - targetObject, targetOffset, numBytes); - - return getRoundedSize(numBytes); - } - } - - /** Writer for binary (byte array) type. */ - public static class BinaryWriter { - - public static int getSize(byte[] input) { - return getRoundedSize(input.length); - } - - public static int write(Object targetObject, long targetOffset, byte[] input) { - final int numBytes = input.length; - - // Write the bytes to the variable length portion. - writeToMemory(input, PlatformDependent.BYTE_ARRAY_OFFSET, - targetObject, targetOffset, numBytes); - - return getRoundedSize(numBytes); - } - } - - /** Writer for UnsafeRow. */ - public static class StructWriter { - - public static int getSize(UnsafeRow input) { - return getRoundedSize(input.getSizeInBytes()); - } - - public static int write(Object targetObject, long targetOffset, UnsafeRow input) { - final int numBytes = input.getSizeInBytes(); - - // Write the bytes to the variable length portion. - writeToMemory(input.getBaseObject(), input.getBaseOffset(), - targetObject, targetOffset, numBytes); - - return getRoundedSize(numBytes); - } - } - - /** Writer for interval type. */ - public static class IntervalWriter { - - public static int getSize(UnsafeRow input) { - return 16; - } - - public static int write(Object targetObject, long targetOffset, CalendarInterval input) { - - // Write the months and microseconds fields of Interval to the variable length portion. - PlatformDependent.UNSAFE.putLong(targetObject, targetOffset, input.months); - PlatformDependent.UNSAFE.putLong(targetObject, targetOffset + 8, input.microseconds); - - return 16; - } - } - - /** Writer for UnsafeArrayData. 
*/ - public static class ArrayWriter { - - public static int getSize(UnsafeArrayData input) { - // we need extra 4 bytes the store the number of elements in this array. - return getRoundedSize(input.getSizeInBytes() + 4); - } - - public static int write(Object targetObject, long targetOffset, UnsafeArrayData input) { - final int numBytes = input.getSizeInBytes(); - - // write the number of elements into first 4 bytes. - PlatformDependent.UNSAFE.putInt(targetObject, targetOffset, input.numElements()); - - // Write the bytes to the variable length portion. - writeToMemory(input.getBaseObject(), input.getBaseOffset(), - targetObject, targetOffset + 4, numBytes); - - return getRoundedSize(numBytes + 4); - } - } - - public static class MapWriter { - - public static int getSize(UnsafeMapData input) { - // we need extra 8 bytes to store number of elements and numBytes of key array. - return getRoundedSize(4 + 4 + input.getSizeInBytes()); - } - - public static int write(Object targetObject, long targetOffset, UnsafeMapData input) { - final UnsafeArrayData keyArray = input.keys; - final UnsafeArrayData valueArray = input.values; - final int keysNumBytes = keyArray.getSizeInBytes(); - final int valuesNumBytes = valueArray.getSizeInBytes(); - final int numBytes = 4 + 4 + keysNumBytes + valuesNumBytes; - - // write the number of elements into first 4 bytes. - PlatformDependent.UNSAFE.putInt(targetObject, targetOffset, input.numElements()); - // write the numBytes of key array into second 4 bytes. - PlatformDependent.UNSAFE.putInt(targetObject, targetOffset + 4, keysNumBytes); - - // Write the bytes of key array to the variable length portion. - writeToMemory(keyArray.getBaseObject(), keyArray.getBaseOffset(), - targetObject, targetOffset + 8, keysNumBytes); - - // Write the bytes of value array to the variable length portion. - writeToMemory(valueArray.getBaseObject(), valueArray.getBaseOffset(), - targetObject, targetOffset + 8 + keysNumBytes, valuesNumBytes); - - return getRoundedSize(numBytes); - } - } -} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java new file mode 100644 index 000000000000..d26b1b187c27 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen; + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.unsafe.Platform; + +/** + * A helper class to manage the row buffer when construct unsafe rows. 
+ */ +public class BufferHolder { + public byte[] buffer; + public int cursor = Platform.BYTE_ARRAY_OFFSET; + + public BufferHolder() { + this(64); + } + + public BufferHolder(int size) { + buffer = new byte[size]; + } + + /** + * Grows the buffer to at least neededSize. If row is non-null, points the row to the buffer. + */ + public void grow(int neededSize, UnsafeRow row) { + final int length = totalSize() + neededSize; + if (buffer.length < length) { + // This will not happen frequently, because the buffer is re-used. + final byte[] tmp = new byte[length * 2]; + Platform.copyMemory( + buffer, + Platform.BYTE_ARRAY_OFFSET, + tmp, + Platform.BYTE_ARRAY_OFFSET, + totalSize()); + buffer = tmp; + if (row != null) { + row.pointTo(buffer, length * 2); + } + } + } + + public void grow(int neededSize) { + grow(neededSize, null); + } + + public void reset() { + cursor = Platform.BYTE_ARRAY_OFFSET; + } + public void resetTo(int offset) { + assert(offset <= buffer.length); + cursor = Platform.BYTE_ARRAY_OFFSET + offset; + } + + public int totalSize() { + return cursor - Platform.BYTE_ARRAY_OFFSET; + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java new file mode 100644 index 000000000000..7dd932d1981b --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen; + +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A helper class to write data into global row buffer using `UnsafeArrayData` format, + * used by {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}. + */ +public class UnsafeArrayWriter { + + private BufferHolder holder; + + // The offset of the global buffer where we start to write this array. + private int startingOffset; + + public void initialize(BufferHolder holder, int numElements, int fixedElementSize) { + // We need 4 bytes to store numElements and 4 bytes each element to store offset. + final int fixedSize = 4 + 4 * numElements; + + this.holder = holder; + this.startingOffset = holder.cursor; + + holder.grow(fixedSize); + Platform.putInt(holder.buffer, holder.cursor, numElements); + holder.cursor += fixedSize; + + // Grows the global buffer ahead for fixed size data. 
+ holder.grow(fixedElementSize * numElements); + } + + private long getElementOffset(int ordinal) { + return startingOffset + 4 + 4 * ordinal; + } + + public void setNullAt(int ordinal) { + final int relativeOffset = holder.cursor - startingOffset; + // Writes negative offset value to represent null element. + Platform.putInt(holder.buffer, getElementOffset(ordinal), -relativeOffset); + } + + public void setOffset(int ordinal) { + final int relativeOffset = holder.cursor - startingOffset; + Platform.putInt(holder.buffer, getElementOffset(ordinal), relativeOffset); + } + + public void write(int ordinal, boolean value) { + Platform.putBoolean(holder.buffer, holder.cursor, value); + setOffset(ordinal); + holder.cursor += 1; + } + + public void write(int ordinal, byte value) { + Platform.putByte(holder.buffer, holder.cursor, value); + setOffset(ordinal); + holder.cursor += 1; + } + + public void write(int ordinal, short value) { + Platform.putShort(holder.buffer, holder.cursor, value); + setOffset(ordinal); + holder.cursor += 2; + } + + public void write(int ordinal, int value) { + Platform.putInt(holder.buffer, holder.cursor, value); + setOffset(ordinal); + holder.cursor += 4; + } + + public void write(int ordinal, long value) { + Platform.putLong(holder.buffer, holder.cursor, value); + setOffset(ordinal); + holder.cursor += 8; + } + + public void write(int ordinal, float value) { + if (Float.isNaN(value)) { + value = Float.NaN; + } + Platform.putFloat(holder.buffer, holder.cursor, value); + setOffset(ordinal); + holder.cursor += 4; + } + + public void write(int ordinal, double value) { + if (Double.isNaN(value)) { + value = Double.NaN; + } + Platform.putDouble(holder.buffer, holder.cursor, value); + setOffset(ordinal); + holder.cursor += 8; + } + + public void write(int ordinal, Decimal input, int precision, int scale) { + // make sure Decimal object has the same scale as DecimalType + if (input.changePrecision(precision, scale)) { + if (precision <= Decimal.MAX_LONG_DIGITS()) { + Platform.putLong(holder.buffer, holder.cursor, input.toUnscaledLong()); + setOffset(ordinal); + holder.cursor += 8; + } else { + final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray(); + assert bytes.length <= 16; + holder.grow(bytes.length); + + // Write the bytes to the variable length portion. + Platform.copyMemory( + bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length); + setOffset(ordinal); + holder.cursor += bytes.length; + } + } else { + setNullAt(ordinal); + } + } + + public void write(int ordinal, UTF8String input) { + final int numBytes = input.numBytes(); + + // grow the global buffer before writing data. + holder.grow(numBytes); + + // Write the bytes to the variable length portion. + input.writeToMemory(holder.buffer, holder.cursor); + + setOffset(ordinal); + + // move the cursor forward. + holder.cursor += numBytes; + } + + public void write(int ordinal, byte[] input) { + // grow the global buffer before writing data. + holder.grow(input.length); + + // Write the bytes to the variable length portion. + Platform.copyMemory( + input, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, input.length); + + setOffset(ordinal); + + // move the cursor forward. + holder.cursor += input.length; + } + + public void write(int ordinal, CalendarInterval input) { + // grow the global buffer before writing data. + holder.grow(16); + + // Write the months and microseconds fields of Interval to the variable length portion. 
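[Editor's note] A minimal Scala sketch of how this writer is meant to be driven (normally GenerateUnsafeProjection emits equivalent code); the element count and values are arbitrary:

{{{
import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeArrayWriter}

val holder = new BufferHolder
val arrayWriter = new UnsafeArrayWriter

// 3 int elements: 4 bytes for numElements + 3 * 4 bytes of offsets, then 4 bytes per value
arrayWriter.initialize(holder, 3, 4)
arrayWriter.write(0, 10)
arrayWriter.setNullAt(1)   // records a negative relative offset to mark the null element
arrayWriter.write(2, 30)
}}}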
+ Platform.putLong(holder.buffer, holder.cursor, input.months); + Platform.putLong(holder.buffer, holder.cursor + 8, input.microseconds); + + setOffset(ordinal); + + // move the cursor forward. + holder.cursor += 16; + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java new file mode 100644 index 000000000000..e227c0dec974 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen; + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.bitset.BitSetMethods; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A helper class to write data into global row buffer using `UnsafeRow` format, + * used by {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}. + */ +public class UnsafeRowWriter { + + private BufferHolder holder; + // The offset of the global buffer where we start to write this row. + private int startingOffset; + private int nullBitsSize; + private UnsafeRow row; + + public void initialize(BufferHolder holder, int numFields) { + this.holder = holder; + this.startingOffset = holder.cursor; + this.nullBitsSize = UnsafeRow.calculateBitSetWidthInBytes(numFields); + + // grow the global buffer to make sure it has enough space to write fixed-length data. 
+ final int fixedSize = nullBitsSize + 8 * numFields; + holder.grow(fixedSize, row); + holder.cursor += fixedSize; + + // zero-out the null bits region + for (int i = 0; i < nullBitsSize; i += 8) { + Platform.putLong(holder.buffer, startingOffset + i, 0L); + } + } + + public void initialize(UnsafeRow row, BufferHolder holder, int numFields) { + initialize(holder, numFields); + this.row = row; + } + + private void zeroOutPaddingBytes(int numBytes) { + if ((numBytes & 0x07) > 0) { + Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L); + } + } + + public BufferHolder holder() { return holder; } + + public boolean isNullAt(int ordinal) { + return BitSetMethods.isSet(holder.buffer, startingOffset, ordinal); + } + + public void setNullAt(int ordinal) { + BitSetMethods.set(holder.buffer, startingOffset, ordinal); + Platform.putLong(holder.buffer, getFieldOffset(ordinal), 0L); + } + + public long getFieldOffset(int ordinal) { + return startingOffset + nullBitsSize + 8 * ordinal; + } + + public void setOffsetAndSize(int ordinal, long size) { + setOffsetAndSize(ordinal, holder.cursor, size); + } + + public void setOffsetAndSize(int ordinal, long currentCursor, long size) { + final long relativeOffset = currentCursor - startingOffset; + final long fieldOffset = getFieldOffset(ordinal); + final long offsetAndSize = (relativeOffset << 32) | size; + + Platform.putLong(holder.buffer, fieldOffset, offsetAndSize); + } + + // Do word alignment for this row and grow the row buffer if needed. + // todo: remove this after we make unsafe array data word align. + public void alignToWords(int numBytes) { + final int remainder = numBytes & 0x07; + + if (remainder > 0) { + final int paddingBytes = 8 - remainder; + holder.grow(paddingBytes, row); + + for (int i = 0; i < paddingBytes; i++) { + Platform.putByte(holder.buffer, holder.cursor, (byte) 0); + holder.cursor++; + } + } + } + + public void write(int ordinal, boolean value) { + final long offset = getFieldOffset(ordinal); + Platform.putLong(holder.buffer, offset, 0L); + Platform.putBoolean(holder.buffer, offset, value); + } + + public void write(int ordinal, byte value) { + final long offset = getFieldOffset(ordinal); + Platform.putLong(holder.buffer, offset, 0L); + Platform.putByte(holder.buffer, offset, value); + } + + public void write(int ordinal, short value) { + final long offset = getFieldOffset(ordinal); + Platform.putLong(holder.buffer, offset, 0L); + Platform.putShort(holder.buffer, offset, value); + } + + public void write(int ordinal, int value) { + final long offset = getFieldOffset(ordinal); + Platform.putLong(holder.buffer, offset, 0L); + Platform.putInt(holder.buffer, offset, value); + } + + public void write(int ordinal, long value) { + Platform.putLong(holder.buffer, getFieldOffset(ordinal), value); + } + + public void write(int ordinal, float value) { + if (Float.isNaN(value)) { + value = Float.NaN; + } + final long offset = getFieldOffset(ordinal); + Platform.putLong(holder.buffer, offset, 0L); + Platform.putFloat(holder.buffer, offset, value); + } + + public void write(int ordinal, double value) { + if (Double.isNaN(value)) { + value = Double.NaN; + } + Platform.putDouble(holder.buffer, getFieldOffset(ordinal), value); + } + + public void write(int ordinal, Decimal input, int precision, int scale) { + if (precision <= Decimal.MAX_LONG_DIGITS()) { + // make sure Decimal object has the same scale as DecimalType + if (input.changePrecision(precision, scale)) { + Platform.putLong(holder.buffer, getFieldOffset(ordinal), 
input.toUnscaledLong()); + } else { + setNullAt(ordinal); + } + } else { + // grow the global buffer before writing data. + holder.grow(16, row); + + // zero-out the bytes + Platform.putLong(holder.buffer, holder.cursor, 0L); + Platform.putLong(holder.buffer, holder.cursor + 8, 0L); + + // Make sure Decimal object has the same scale as DecimalType. + // Note that we may pass in null Decimal object to set null for it. + if (input == null || !input.changePrecision(precision, scale)) { + BitSetMethods.set(holder.buffer, startingOffset, ordinal); + // keep the offset for future update + setOffsetAndSize(ordinal, 0L); + } else { + final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray(); + assert bytes.length <= 16; + + // Write the bytes to the variable length portion. + Platform.copyMemory( + bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length); + setOffsetAndSize(ordinal, bytes.length); + } + + // move the cursor forward. + holder.cursor += 16; + } + } + + public void write(int ordinal, UTF8String input) { + final int numBytes = input.numBytes(); + final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); + + // grow the global buffer before writing data. + holder.grow(roundedSize, row); + + zeroOutPaddingBytes(numBytes); + + // Write the bytes to the variable length portion. + input.writeToMemory(holder.buffer, holder.cursor); + + setOffsetAndSize(ordinal, numBytes); + + // move the cursor forward. + holder.cursor += roundedSize; + } + + public void write(int ordinal, byte[] input) { + write(ordinal, input, 0, input.length); + } + + public void write(int ordinal, byte[] input, int offset, int numBytes) { + final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); + + // grow the global buffer before writing data. + holder.grow(roundedSize, row); + + zeroOutPaddingBytes(numBytes); + + // Write the bytes to the variable length portion. + Platform.copyMemory(input, Platform.BYTE_ARRAY_OFFSET + offset, + holder.buffer, holder.cursor, numBytes); + + setOffsetAndSize(ordinal, numBytes); + + // move the cursor forward. + holder.cursor += roundedSize; + } + + public void write(int ordinal, CalendarInterval input) { + // grow the global buffer before writing data. + holder.grow(16, row); + + // Write the months and microseconds fields of Interval to the variable length portion. + Platform.putLong(holder.buffer, holder.cursor, input.months); + Platform.putLong(holder.buffer, holder.cursor + 8, input.microseconds); + + setOffsetAndSize(ordinal, 16); + + // move the cursor forward. 
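[Editor's note] To make the layout above concrete, a rough Scala sketch of building a single row by hand (GenerateUnsafeProjection emits equivalent code in practice); the field values are arbitrary:

{{{
import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter}
import org.apache.spark.unsafe.types.UTF8String

val holder = new BufferHolder
val rowWriter = new UnsafeRowWriter
rowWriter.initialize(holder, 2)          // null bits plus two fixed-length 8-byte slots

rowWriter.write(0, 123L)                 // fixed-length value stored directly in its slot
rowWriter.write(1, UTF8String.fromString("spark"))  // bytes appended; the slot stores (offset << 32) | size

// holder.totalSize() now gives the number of bytes occupied by this row
}}}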
+ holder.cursor += 16; + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 5e4c6232c947..352002b3499a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -26,12 +26,12 @@ import org.apache.spark.SparkEnv; import org.apache.spark.TaskContext; -import org.apache.spark.sql.AbstractScalaRowIterator; +import org.apache.spark.sql.catalyst.util.AbstractScalaRowIterator; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.types.StructType; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.util.collection.unsafe.sort.PrefixComparator; import org.apache.spark.util.collection.unsafe.sort.RecordComparator; import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter; @@ -51,7 +51,7 @@ final class UnsafeExternalRowSorter { private final PrefixComputer prefixComputer; private final UnsafeExternalSorter sorter; - public static abstract class PrefixComputer { + public abstract static class PrefixComputer { abstract long computePrefix(InternalRow row); } @@ -67,7 +67,6 @@ public UnsafeExternalRowSorter( final TaskContext taskContext = TaskContext.get(); sorter = UnsafeExternalSorter.create( taskContext.taskMemoryManager(), - sparkEnv.shuffleMemoryManager(), sparkEnv.blockManager(), taskContext, new RowComparator(ordering, schema.length()), @@ -97,17 +96,19 @@ void insertRow(UnsafeRow row) throws IOException { ); numRowsInserted++; if (testSpillFrequency > 0 && (numRowsInserted % testSpillFrequency) == 0) { - spill(); + sorter.spill(); } } - @VisibleForTesting - void spill() throws IOException { - sorter.spill(); + /** + * Return the peak memory used so far, in bytes. + */ + public long getPeakMemoryUsage() { + return sorter.getPeakMemoryUsedBytes(); } private void cleanupResources() { - sorter.freeMemory(); + sorter.cleanupResources(); } @VisibleForTesting @@ -150,7 +151,7 @@ public UnsafeRow next() { cleanupResources(); // Scala iterators don't declare any checked exceptions, so we need to use this hack // to re-throw the exception: - PlatformDependent.throwException(e); + Platform.throwException(e); } throw new RuntimeException("Exception should have been re-thrown in next()"); }; @@ -169,13 +170,6 @@ public Iterator sort(Iterator inputIterator) throws IOExce return sort(); } - /** - * Return true if UnsafeExternalRowSorter can sort rows with the given schema, false otherwise. 
- */ - public static boolean supportsSchema(StructType schema) { - return UnsafeProjection.canSupport(schema); - } - private static final class RowComparator extends RecordComparator { private final Ordering ordering; private final int numFields; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/SQLUserDefinedType.java b/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java similarity index 97% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/SQLUserDefinedType.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java index df64a878b6b3..1e4e5ede8cc1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/SQLUserDefinedType.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java @@ -41,5 +41,5 @@ * Returns an instance of the UserDefinedType which can serialize and deserialize the user * class to and from Catalyst built-in types. */ - Class > udt(); + Class> udt(); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala new file mode 100644 index 000000000000..bb0fdc4c3d83 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.lang.reflect.Modifier + +import scala.annotation.implicitNotFound +import scala.reflect.{ClassTag, classTag} + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor} +import org.apache.spark.sql.catalyst.expressions.{DecodeUsingSerializer, BoundReference, EncodeUsingSerializer} +import org.apache.spark.sql.types._ + +/** + * :: Experimental :: + * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation. + * + * == Scala == + * Encoders are generally created automatically through implicits from a `SQLContext`. + * + * {{{ + * import sqlContext.implicits._ + * + * val ds = Seq(1, 2, 3).toDS() // implicitly provided (sqlContext.implicits.newIntEncoder) + * }}} + * + * == Java == + * Encoders are specified by calling static methods on [[Encoders]]. 
+ * + * {{{ + * List data = Arrays.asList("abc", "abc", "xyz"); + * Dataset ds = context.createDataset(data, Encoders.STRING()); + * }}} + * + * Encoders can be composed into tuples: + * + * {{{ + * Encoder> encoder2 = Encoders.tuple(Encoders.INT(), Encoders.STRING()); + * List> data2 = Arrays.asList(new scala.Tuple2(1, "a"); + * Dataset> ds2 = context.createDataset(data2, encoder2); + * }}} + * + * Or constructed from Java Beans: + * + * {{{ + * Encoders.bean(MyClass.class); + * }}} + * + * == Implementation == + * - Encoders are not required to be thread-safe and thus they do not need to use locks to guard + * against concurrent access if they reuse internal buffers to improve performance. + * + * @since 1.6.0 + */ +@Experimental +@implicitNotFound("Unable to find encoder for type stored in a Dataset. Primitive types " + + "(Int, String, etc) and Product types (case classes) are supported by importing " + + "sqlContext.implicits._ Support for serializing other types will be added in future " + + "releases.") +trait Encoder[T] extends Serializable { + + /** Returns the schema of encoding this type of object as a Row. */ + def schema: StructType + + /** A ClassTag that can be used to construct and Array to contain a collection of `T`. */ + def clsTag: ClassTag[T] +} + +/** + * :: Experimental :: + * Methods for creating an [[Encoder]]. + * + * @since 1.6.0 + */ +@Experimental +object Encoders { + + /** + * An encoder for nullable boolean type. + * @since 1.6.0 + */ + def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder() + + /** + * An encoder for nullable byte type. + * @since 1.6.0 + */ + def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder() + + /** + * An encoder for nullable short type. + * @since 1.6.0 + */ + def SHORT: Encoder[java.lang.Short] = ExpressionEncoder() + + /** + * An encoder for nullable int type. + * @since 1.6.0 + */ + def INT: Encoder[java.lang.Integer] = ExpressionEncoder() + + /** + * An encoder for nullable long type. + * @since 1.6.0 + */ + def LONG: Encoder[java.lang.Long] = ExpressionEncoder() + + /** + * An encoder for nullable float type. + * @since 1.6.0 + */ + def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder() + + /** + * An encoder for nullable double type. + * @since 1.6.0 + */ + def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder() + + /** + * An encoder for nullable string type. + * @since 1.6.0 + */ + def STRING: Encoder[java.lang.String] = ExpressionEncoder() + + /** + * An encoder for nullable decimal type. + * @since 1.6.0 + */ + def DECIMAL: Encoder[java.math.BigDecimal] = ExpressionEncoder() + + /** + * An encoder for nullable date type. + * @since 1.6.0 + */ + def DATE: Encoder[java.sql.Date] = ExpressionEncoder() + + /** + * An encoder for nullable timestamp type. + * @since 1.6.0 + */ + def TIMESTAMP: Encoder[java.sql.Timestamp] = ExpressionEncoder() + + /** + * Creates an encoder for Java Bean of type T. + * + * T must be publicly accessible. + * + * supported types for java bean field: + * - primitive types: boolean, int, double, etc. + * - boxed types: Boolean, Integer, Double, etc. + * - String + * - java.math.BigDecimal + * - time related: java.sql.Date, java.sql.Timestamp + * - collection types: only array and java.util.List currently, map support is in progress + * - nested java bean. + * + * @since 1.6.0 + */ + def bean[T](beanClass: Class[T]): Encoder[T] = ExpressionEncoder.javaBean(beanClass) + + /** + * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo. 
+ * This encoder maps T into a single byte array (binary) field. + * + * T must be publicly accessible. + * + * @since 1.6.0 + */ + def kryo[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = true) + + /** + * Creates an encoder that serializes objects of type T using Kryo. + * This encoder maps T into a single byte array (binary) field. + * + * T must be publicly accessible. + * + * @since 1.6.0 + */ + def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz)) + + /** + * (Scala-specific) Creates an encoder that serializes objects of type T using generic Java + * serialization. This encoder maps T into a single byte array (binary) field. + * + * Note that this is extremely inefficient and should only be used as the last resort. + * + * T must be publicly accessible. + * + * @since 1.6.0 + */ + def javaSerialization[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = false) + + /** + * Creates an encoder that serializes objects of type T using generic Java serialization. + * This encoder maps T into a single byte array (binary) field. + * + * Note that this is extremely inefficient and should only be used as the last resort. + * + * T must be publicly accessible. + * + * @since 1.6.0 + */ + def javaSerialization[T](clazz: Class[T]): Encoder[T] = javaSerialization(ClassTag[T](clazz)) + + /** Throws an exception if T is not a public class. */ + private def validatePublicClass[T: ClassTag](): Unit = { + if (!Modifier.isPublic(classTag[T].runtimeClass.getModifiers)) { + throw new UnsupportedOperationException( + s"${classTag[T].runtimeClass.getName} is not a public class. " + + "Only public classes are supported.") + } + } + + /** A way to construct encoders using generic serializers. */ + private def genericSerializer[T: ClassTag](useKryo: Boolean): Encoder[T] = { + if (classTag[T].runtimeClass.isPrimitive) { + throw new UnsupportedOperationException("Primitive types are not supported.") + } + + validatePublicClass[T]() + + ExpressionEncoder[T]( + schema = new StructType().add("value", BinaryType), + flat = true, + toRowExpressions = Seq( + EncodeUsingSerializer( + BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)), + fromRowExpression = + DecodeUsingSerializer[T]( + BoundReference(0, BinaryType, nullable = true), classTag[T], kryo = useKryo), + clsTag = classTag[T] + ) + } + + /** + * An encoder for 2-ary tuples. + * @since 1.6.0 + */ + def tuple[T1, T2]( + e1: Encoder[T1], + e2: Encoder[T2]): Encoder[(T1, T2)] = { + ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2)) + } + + /** + * An encoder for 3-ary tuples. + * @since 1.6.0 + */ + def tuple[T1, T2, T3]( + e1: Encoder[T1], + e2: Encoder[T2], + e3: Encoder[T3]): Encoder[(T1, T2, T3)] = { + ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3)) + } + + /** + * An encoder for 4-ary tuples. + * @since 1.6.0 + */ + def tuple[T1, T2, T3, T4]( + e1: Encoder[T1], + e2: Encoder[T2], + e3: Encoder[T3], + e4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = { + ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4)) + } + + /** + * An encoder for 5-ary tuples. 
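[Editor's note] A short Scala sketch of the factory methods defined in this object; the class passed to kryo is arbitrary, chosen only because it is public and non-primitive:

{{{
import org.apache.spark.sql.Encoders

val intEncoder  = Encoders.INT                                  // Encoder[java.lang.Integer]
val pairEncoder = Encoders.tuple(Encoders.INT, Encoders.STRING) // Encoder[(java.lang.Integer, String)]
val kryoEncoder = Encoders.kryo[java.util.UUID]                 // single binary column, Kryo-serialized
kryoEncoder.schema                                              // StructType with one "value" field of BinaryType
}}}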
+ * @since 1.6.0 + */ + def tuple[T1, T2, T3, T4, T5]( + e1: Encoder[T1], + e2: Encoder[T2], + e3: Encoder[T3], + e4: Encoder[T4], + e5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = { + ExpressionEncoder.tuple( + encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4), encoderFor(e5)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index 91449479fa53..b14c66cc5ac8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql +import scala.collection.JavaConverters._ +import scala.util.hashing.MurmurHash3 + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.types.StructType @@ -149,7 +152,7 @@ trait Row extends Serializable { * BinaryType -> byte array * ArrayType -> scala.collection.Seq (use getList for java.util.List) * MapType -> scala.collection.Map (use getJavaMap for java.util.Map) - * StructType -> org.apache.spark.sql.Row + * StructType -> org.apache.spark.sql.Row (or Product) * }}} */ def apply(i: Int): Any = get(i) @@ -174,7 +177,7 @@ trait Row extends Serializable { * BinaryType -> byte array * ArrayType -> scala.collection.Seq (use getList for java.util.List) * MapType -> scala.collection.Map (use getJavaMap for java.util.Map) - * StructType -> org.apache.spark.sql.Row + * StructType -> org.apache.spark.sql.Row (or Product) * }}} */ def get(i: Int): Any @@ -188,7 +191,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getBoolean(i: Int): Boolean = getAs[Boolean](i) + def getBoolean(i: Int): Boolean = getAnyValAs[Boolean](i) /** * Returns the value at position i as a primitive byte. @@ -196,7 +199,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getByte(i: Int): Byte = getAs[Byte](i) + def getByte(i: Int): Byte = getAnyValAs[Byte](i) /** * Returns the value at position i as a primitive short. @@ -204,7 +207,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getShort(i: Int): Short = getAs[Short](i) + def getShort(i: Int): Short = getAnyValAs[Short](i) /** * Returns the value at position i as a primitive int. @@ -212,7 +215,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getInt(i: Int): Int = getAs[Int](i) + def getInt(i: Int): Int = getAnyValAs[Int](i) /** * Returns the value at position i as a primitive long. @@ -220,7 +223,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getLong(i: Int): Long = getAs[Long](i) + def getLong(i: Int): Long = getAnyValAs[Long](i) /** * Returns the value at position i as a primitive float. @@ -229,7 +232,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getFloat(i: Int): Float = getAs[Float](i) + def getFloat(i: Int): Float = getAnyValAs[Float](i) /** * Returns the value at position i as a primitive double. 
@@ -237,13 +240,12 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getDouble(i: Int): Double = getAs[Double](i) + def getDouble(i: Int): Double = getAnyValAs[Double](i) /** * Returns the value at position i as a String object. * * @throws ClassCastException when data type does not match. - * @throws NullPointerException when value is null. */ def getString(i: Int): String = getAs[String](i) @@ -280,9 +282,8 @@ trait Row extends Serializable { * * @throws ClassCastException when data type does not match. */ - def getList[T](i: Int): java.util.List[T] = { - scala.collection.JavaConversions.seqAsJavaList(getSeq[T](i)) - } + def getList[T](i: Int): java.util.List[T] = + getSeq[T](i).asJava /** * Returns the value at position i of map type as a Scala Map. @@ -296,19 +297,28 @@ trait Row extends Serializable { * * @throws ClassCastException when data type does not match. */ - def getJavaMap[K, V](i: Int): java.util.Map[K, V] = { - scala.collection.JavaConversions.mapAsJavaMap(getMap[K, V](i)) - } + def getJavaMap[K, V](i: Int): java.util.Map[K, V] = + getMap[K, V](i).asJava /** * Returns the value at position i of struct type as an [[Row]] object. * * @throws ClassCastException when data type does not match. */ - def getStruct(i: Int): Row = getAs[Row](i) + def getStruct(i: Int): Row = { + // Product and Row both are recoginized as StructType in a Row + val t = get(i) + if (t.isInstanceOf[Product]) { + Row.fromTuple(t.asInstanceOf[Product]) + } else { + t.asInstanceOf[Row] + } + } /** * Returns the value at position i. + * For primitive types if value is null it returns 'zero value' specific for primitive + * ie. 0 for Int - use isNullAt to ensure that value is not null * * @throws ClassCastException when data type does not match. */ @@ -316,6 +326,8 @@ trait Row extends Serializable { /** * Returns the value of a given fieldName. + * For primitive types if value is null it returns 'zero value' specific for primitive + * ie. 0 for Int - use isNullAt to ensure that value is not null * * @throws UnsupportedOperationException when schema is not defined. * @throws IllegalArgumentException when fieldName do not exist. @@ -335,6 +347,8 @@ trait Row extends Serializable { /** * Returns a Map(name -> value) for the requested fieldNames + * For primitive types if value is null it returns 'zero value' specific for primitive + * ie. 0 for Int - use isNullAt to ensure that value is not null * * @throws UnsupportedOperationException when schema is not defined. * @throws IllegalArgumentException when fieldName do not exist. @@ -364,31 +378,10 @@ trait Row extends Serializable { false } - /** - * Returns true if we can check equality for these 2 rows. - * Equality check between external row and internal row is not allowed. - * Here we do this check to prevent call `equals` on external row with internal row. - */ - protected def canEqual(other: Row) = { - // Note that `Row` is not only the interface of external row but also the parent - // of `InternalRow`, so we have to ensure `other` is not a internal row here to prevent - // call `equals` on external row with internal row. - // `InternalRow` overrides canEqual, and these two canEquals together makes sure that - // equality check between external Row and InternalRow will always fail. - // In the future, InternalRow should not extend Row. In that case, we can remove these - // canEqual methods. 
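[Editor's note] The practical effect of routing the typed primitive getters through getAnyValAs is sketched below (hypothetical values): getAs still silently unboxes a null to the primitive's zero value, while the typed getters now fail fast.

{{{
import org.apache.spark.sql.Row

val row = Row(null, 42)
row.isNullAt(0)     // true
row.getAs[Int](0)   // 0 -- null unboxes to the "zero value" for Int
// row.getInt(0)    // with this change: throws NullPointerException instead of returning 0
row.getInt(1)       // 42
}}}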
- !other.isInstanceOf[InternalRow] - } - override def equals(o: Any): Boolean = { if (!o.isInstanceOf[Row]) return false val other = o.asInstanceOf[Row] - if (!canEqual(other)) { - throw new UnsupportedOperationException( - "cannot check equality between external and internal rows") - } - if (other eq null) return false if (length != other.length) { @@ -417,6 +410,10 @@ trait Row extends Serializable { if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) { return false } + case d1: java.math.BigDecimal if o2.isInstanceOf[java.math.BigDecimal] => + if (d1.compareTo(o2.asInstanceOf[java.math.BigDecimal]) != 0) { + return false + } case _ => if (o1 != o2) { return false } @@ -427,6 +424,18 @@ trait Row extends Serializable { true } + override def hashCode: Int = { + // Using Scala's Seq hash code implementation. + var n = 0 + var h = MurmurHash3.seqSeed + val len = length + while (n < len) { + h = MurmurHash3.mix(h, apply(n).##) + n += 1 + } + MurmurHash3.finalizeHash(h, n) + } + /* ---------------------- utility methods for Scala ---------------------- */ /** @@ -454,4 +463,15 @@ trait Row extends Serializable { * start, end, and separator strings. */ def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end) + + /** + * Returns the value of a given fieldName. + * + * @throws UnsupportedOperationException when schema is not defined. + * @throws ClassCastException when data type does not match. + * @throws NullPointerException when value is null. + */ + private def getAnyValAs[T <: AnyVal](i: Int): T = + if (isNullAt(i)) throw new NullPointerException(s"Value at index $i in null") + else getAs[T](i) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala index d494ae7b71d1..bdc52c08acb6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ private[sql] abstract class AbstractSparkSQLParser extends StandardTokenParsers with PackratParsers { - def parse(input: String): LogicalPlan = { + def parse(input: String): LogicalPlan = synchronized { // Initialize the Keywords. initLexical phrase(start)(new lexical.Scanner(input)) match { @@ -78,7 +78,7 @@ private[sql] abstract class AbstractSparkSQLParser } class SqlLexical extends StdLexical { - case class FloatLit(chars: String) extends Token { + case class DecimalLit(chars: String) extends Token { override def toString: String = chars } @@ -102,11 +102,16 @@ class SqlLexical extends StdLexical { } override lazy val token: Parser[Token] = - ( identChar ~ (identChar | digit).* ^^ - { case first ~ rest => processIdent((first :: rest).mkString) } + ( rep1(digit) ~ scientificNotation ^^ { case i ~ s => DecimalLit(i.mkString + s) } + | '.' ~> (rep1(digit) ~ scientificNotation) ^^ + { case i ~ s => DecimalLit("0." + i.mkString + s) } + | rep1(digit) ~ ('.' ~> digit.*) ~ scientificNotation ^^ + { case i1 ~ i2 ~ s => DecimalLit(i1.mkString + "." + i2.mkString + s) } + | digit.* ~ identChar ~ (identChar | digit).* ^^ + { case first ~ middle ~ rest => processIdent((first ++ (middle :: rest)).mkString) } | rep1(digit) ~ ('.' ~> digit.*).? ^^ { case i ~ None => NumericLit(i.mkString) - case i ~ Some(d) => FloatLit(i.mkString + "." 
+ d.mkString) + case i ~ Some(d) => DecimalLit(i.mkString + "." + d.mkString) } | '\'' ~> chrExcept('\'', '\n', EofCh).* <~ '\'' ^^ { case chars => StringLit(chars mkString "") } @@ -123,6 +128,11 @@ class SqlLexical extends StdLexical { override def identChar: Parser[Elem] = letter | elem('_') + private lazy val scientificNotation: Parser[String] = + (elem('e') | elem('E')) ~> (elem('+') | elem('-')).? ~ rep1(digit) ^^ { + case s ~ rest => "e" + s.mkString + rest.mkString + } + override def whitespace: Parser[Any] = ( whitespaceChar | '/' ~ '*' ~ comment diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala index 3f351b07b37d..2c7c58e66b85 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala @@ -32,4 +32,5 @@ object EmptyConf extends CatalystConf { } /** A CatalystConf that can be used for local testing. */ -case class SimpleCatalystConf(caseSensitiveAnalysis: Boolean) extends CatalystConf +case class SimpleCatalystConf(caseSensitiveAnalysis: Boolean) extends CatalystConf { +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index c666864e43ab..2ec0ff53c89c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -27,7 +27,7 @@ import scala.language.existentials import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -138,8 +138,13 @@ object CatalystTypeConverters { private case class UDTConverter( udt: UserDefinedType[_]) extends CatalystTypeConverter[Any, Any, Any] { + // toCatalyst (it calls toCatalystImpl) will do null check. 
override def toCatalystImpl(scalaValue: Any): Any = udt.serialize(scalaValue) - override def toScala(catalystValue: Any): Any = udt.deserialize(catalystValue) + + override def toScala(catalystValue: Any): Any = { + if (catalystValue == null) null else udt.deserialize(catalystValue) + } + override def toScalaImpl(row: InternalRow, column: Int): Any = toScala(row.get(column, udt.sqlType)) } @@ -317,18 +322,26 @@ object CatalystTypeConverters { private class DecimalConverter(dataType: DecimalType) extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] { - override def toCatalystImpl(scalaValue: Any): Decimal = scalaValue match { - case d: BigDecimal => Decimal(d) - case d: JavaBigDecimal => Decimal(d) - case d: Decimal => d + override def toCatalystImpl(scalaValue: Any): Decimal = { + val decimal = scalaValue match { + case d: BigDecimal => Decimal(d) + case d: JavaBigDecimal => Decimal(d) + case d: Decimal => d + } + if (decimal.changePrecision(dataType.precision, dataType.scale)) { + decimal + } else { + null + } + } + override def toScala(catalystValue: Decimal): JavaBigDecimal = { + if (catalystValue == null) null + else catalystValue.toJavaBigDecimal } - override def toScala(catalystValue: Decimal): JavaBigDecimal = catalystValue.toJavaBigDecimal override def toScalaImpl(row: InternalRow, column: Int): JavaBigDecimal = row.getDecimal(column, dataType.precision, dataType.scale).toJavaBigDecimal } - private object BigDecimalConverter extends DecimalConverter(DecimalType.SYSTEM_DEFAULT) - private abstract class PrimitiveConverter[T] extends CatalystTypeConverter[T, Any, Any] { final override def toScala(catalystValue: Any): Any = catalystValue final override def toCatalystImpl(scalaValue: T): Any = scalaValue @@ -413,8 +426,8 @@ object CatalystTypeConverters { case s: String => StringConverter.toCatalyst(s) case d: Date => DateConverter.toCatalyst(d) case t: Timestamp => TimestampConverter.toCatalyst(t) - case d: BigDecimal => BigDecimalConverter.toCatalyst(d) - case d: JavaBigDecimal => BigDecimalConverter.toCatalyst(d) + case d: BigDecimal => new DecimalConverter(DecimalType(d.precision, d.scale)).toCatalyst(d) + case d: JavaBigDecimal => new DecimalConverter(DecimalType(d.precision, d.scale)).toCatalyst(d) case seq: Seq[Any] => new GenericArrayData(seq.map(convertToCatalyst).toArray) case r: Row => InternalRow(r.toSeq.map(convertToCatalyst): _*) case arr: Array[Any] => new GenericArrayData(arr.map(convertToCatalyst)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 7656d054dc36..eba95c5c8b90 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -17,157 +17,60 @@ package org.apache.spark.sql.catalyst -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.{DataType, StructType} /** * An abstract class for row used internal in Spark SQL, which only contain the columns as * internal types. */ -// todo: make InternalRow just extends SpecializedGetters, remove generic getter -abstract class InternalRow extends GenericSpecializedGetters with Serializable { +abstract class InternalRow extends SpecializedGetters with Serializable { def numFields: Int // This is only use for test and will throw a null pointer exception if the position is null. 
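[Editor's note] A sketch of the new DecimalConverter overflow behaviour through the public conversion entry point (assuming createToCatalystConverter itself is unchanged by this patch): a value that does not fit the declared precision now converts to null instead of an out-of-range Decimal.

{{{
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.types.DecimalType

val toCatalyst = CatalystTypeConverters.createToCatalystConverter(DecimalType(5, 2))
toCatalyst(BigDecimal("123.45"))     // Decimal(123.45) -- fits precision 5, scale 2
toCatalyst(BigDecimal("123456.78"))  // null -- changePrecision(5, 2) fails, so the value is dropped
}}}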
def getString(ordinal: Int): String = getUTF8String(ordinal).toString - override def toString: String = mkString("[", ",", "]") - /** * Make a copy of the current [[InternalRow]] object. */ def copy(): InternalRow /** Returns true if there are any NULL values in this row. */ - def anyNull: Boolean = { - val len = numFields - var i = 0 - while (i < len) { - if (isNullAt(i)) { return true } - i += 1 - } - false - } - - override def equals(o: Any): Boolean = { - if (!o.isInstanceOf[InternalRow]) { - return false - } - - val other = o.asInstanceOf[InternalRow] - if (other eq null) { - return false - } - - val len = numFields - if (len != other.numFields) { - return false - } - - var i = 0 - while (i < len) { - if (isNullAt(i) != other.isNullAt(i)) { - return false - } - if (!isNullAt(i)) { - val o1 = genericGet(i) - val o2 = other.genericGet(i) - o1 match { - case b1: Array[Byte] => - if (!o2.isInstanceOf[Array[Byte]] || - !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { - return false - } - case f1: Float if java.lang.Float.isNaN(f1) => - if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) { - return false - } - case d1: Double if java.lang.Double.isNaN(d1) => - if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) { - return false - } - case _ => if (o1 != o2) { - return false - } - } - } - i += 1 - } - true - } - - // Custom hashCode function that matches the efficient code generated version. - override def hashCode: Int = { - var result: Int = 37 - var i = 0 - val len = numFields - while (i < len) { - val update: Int = - if (isNullAt(i)) { - 0 - } else { - genericGet(i) match { - case b: Boolean => if (b) 0 else 1 - case b: Byte => b.toInt - case s: Short => s.toInt - case i: Int => i - case l: Long => (l ^ (l >>> 32)).toInt - case f: Float => java.lang.Float.floatToIntBits(f) - case d: Double => - val b = java.lang.Double.doubleToLongBits(d) - (b ^ (b >>> 32)).toInt - case a: Array[Byte] => java.util.Arrays.hashCode(a) - case other => other.hashCode() - } - } - result = 37 * result + update - i += 1 - } - result - } + def anyNull: Boolean /* ---------------------- utility methods for Scala ---------------------- */ /** * Return a Scala Seq representing the row. Elements are placed in the same order in the Seq. */ - // todo: remove this as it needs the generic getter - def toSeq: Seq[Any] = { - val n = numFields - val values = new Array[Any](n) + def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = { + val len = numFields + assert(len == fieldTypes.length) + + val values = new Array[Any](len) var i = 0 - while (i < n) { - values.update(i, genericGet(i)) + while (i < len) { + values(i) = get(i, fieldTypes(i)) i += 1 } values } - /** Displays all elements of this sequence in a string (without a separator). */ - def mkString: String = toSeq.mkString - - /** Displays all elements of this sequence in a string using a separator string. */ - def mkString(sep: String): String = toSeq.mkString(sep) - - /** - * Displays all elements of this traversable or iterator in a string using - * start, end, and separator strings. - */ - def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end) + def toSeq(schema: StructType): Seq[Any] = toSeq(schema.map(_.dataType)) } object InternalRow { /** - * This method can be used to construct a [[Row]] with the given values. + * This method can be used to construct a [[InternalRow]] with the given values. 
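[Editor's note] Because toSeq can no longer fall back on a generic getter, callers now have to supply the field types (or a schema); a small sketch with arbitrary values:

{{{
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.{IntegerType, LongType, StringType}
import org.apache.spark.unsafe.types.UTF8String

val row = InternalRow(1, 2L, UTF8String.fromString("a"))
row.toSeq(Seq(IntegerType, LongType, StringType))   // Seq(1, 2, a)
}}}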
*/ def apply(values: Any*): InternalRow = new GenericInternalRow(values.toArray) /** - * This method can be used to construct a [[Row]] from a [[Seq]] of values. + * This method can be used to construct a [[InternalRow]] from a [[Seq]] of values. */ def fromSeq(values: Seq[Any]): InternalRow = new GenericInternalRow(values.toArray) - /** Returns an empty row. */ + /** Returns an empty [[InternalRow]]. */ val empty = apply() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 88a457f87ce4..c8ee87e8819f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -17,29 +17,36 @@ package org.apache.spark.sql.catalyst -import java.beans.Introspector +import java.beans.{PropertyDescriptor, Introspector} import java.lang.{Iterable => JIterable} -import java.util.{Iterator => JIterator, Map => JMap} +import java.util.{Iterator => JIterator, Map => JMap, List => JList} import scala.language.existentials import com.google.common.reflect.TypeToken + import org.apache.spark.sql.types._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils} +import org.apache.spark.unsafe.types.UTF8String + /** * Type-inference utilities for POJOs and Java collections. */ -private [sql] object JavaTypeInference { +object JavaTypeInference { private val iterableType = TypeToken.of(classOf[JIterable[_]]) private val mapType = TypeToken.of(classOf[JMap[_, _]]) + private val listType = TypeToken.of(classOf[JList[_]]) private val iteratorReturnType = classOf[JIterable[_]].getMethod("iterator").getGenericReturnType private val nextReturnType = classOf[JIterator[_]].getMethod("next").getGenericReturnType private val keySetReturnType = classOf[JMap[_, _]].getMethod("keySet").getGenericReturnType private val valuesReturnType = classOf[JMap[_, _]].getMethod("values").getGenericReturnType /** - * Infers the corresponding SQL data type of a JavaClean class. + * Infers the corresponding SQL data type of a JavaBean class. * @param beanClass Java type * @return (SQL data type, nullable) */ @@ -53,12 +60,13 @@ private [sql] object JavaTypeInference { * @return (SQL data type, nullable) */ private def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = { - // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific. 
typeToken.getRawType match { case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true) case c: Class[_] if c == classOf[java.lang.String] => (StringType, true) + case c: Class[_] if c == classOf[Array[Byte]] => (BinaryType, true) + case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false) case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false) case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false) @@ -88,15 +96,14 @@ private [sql] object JavaTypeInference { (ArrayType(dataType, nullable), true) case _ if mapType.isAssignableFrom(typeToken) => - val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]] - val mapSupertype = typeToken2.getSupertype(classOf[JMap[_, _]]) - val keyType = elementType(mapSupertype.resolveType(keySetReturnType)) - val valueType = elementType(mapSupertype.resolveType(valuesReturnType)) + val (keyType, valueType) = mapKeyValueType(typeToken) val (keyDataType, _) = inferDataType(keyType) val (valueDataType, nullable) = inferDataType(valueType) (MapType(keyDataType, valueDataType, nullable), true) case _ => + // TODO: we should only collect properties that have getter and setter. However, some tests + // pass in scala case class as java bean class which doesn't have getter and setter. val beanInfo = Introspector.getBeanInfo(typeToken.getRawType) val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") val fields = properties.map { property => @@ -108,11 +115,294 @@ private [sql] object JavaTypeInference { } } + private def getJavaBeanProperties(beanClass: Class[_]): Array[PropertyDescriptor] = { + val beanInfo = Introspector.getBeanInfo(beanClass) + beanInfo.getPropertyDescriptors + .filter(p => p.getReadMethod != null && p.getWriteMethod != null) + } + private def elementType(typeToken: TypeToken[_]): TypeToken[_] = { val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JIterable[_]]] - val iterableSupertype = typeToken2.getSupertype(classOf[JIterable[_]]) - val iteratorType = iterableSupertype.resolveType(iteratorReturnType) - val itemType = iteratorType.resolveType(nextReturnType) - itemType + val iterableSuperType = typeToken2.getSupertype(classOf[JIterable[_]]) + val iteratorType = iterableSuperType.resolveType(iteratorReturnType) + iteratorType.resolveType(nextReturnType) + } + + private def mapKeyValueType(typeToken: TypeToken[_]): (TypeToken[_], TypeToken[_]) = { + val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]] + val mapSuperType = typeToken2.getSupertype(classOf[JMap[_, _]]) + val keyType = elementType(mapSuperType.resolveType(keySetReturnType)) + val valueType = elementType(mapSuperType.resolveType(valuesReturnType)) + keyType -> valueType + } + + /** + * Returns the Spark SQL DataType for a given java class. Where this is not an exact mapping + * to a native type, an ObjectType is returned. + * + * Unlike `inferDataType`, this function doesn't do any massaging of types into the Spark SQL type + * system. As a result, ObjectType will be returned for things like boxed Integers. 
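[Editor's note] JavaTypeInference is what ultimately backs the Encoders.bean(...) path shown earlier; a rough sketch using a Scala class with @BeanProperty accessors in place of a real Java bean (the Person class below is illustrative, not part of this patch):

{{{
import scala.beans.BeanProperty
import org.apache.spark.sql.Encoders

// Illustrative bean-style class: @BeanProperty generates getAge/setAge and getName/setName
class Person {
  @BeanProperty var age: Int = 0
  @BeanProperty var name: String = null
}

val personEncoder = Encoders.bean(classOf[Person])
personEncoder.schema   // expected: StructType with an "age" (IntegerType) and a "name" (StringType) field
}}}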
+ */ + private def inferExternalType(cls: Class[_]): DataType = cls match { + case c if c == java.lang.Boolean.TYPE => BooleanType + case c if c == java.lang.Byte.TYPE => ByteType + case c if c == java.lang.Short.TYPE => ShortType + case c if c == java.lang.Integer.TYPE => IntegerType + case c if c == java.lang.Long.TYPE => LongType + case c if c == java.lang.Float.TYPE => FloatType + case c if c == java.lang.Double.TYPE => DoubleType + case c if c == classOf[Array[Byte]] => BinaryType + case _ => ObjectType(cls) + } + + /** + * Returns an expression that can be used to construct an object of java bean `T` given an input + * row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes + * of the same name as the constructor arguments. Nested classes will have their fields accessed + * using UnresolvedExtractValue. + */ + def constructorFor(beanClass: Class[_]): Expression = { + constructorFor(TypeToken.of(beanClass), None) + } + + private def constructorFor(typeToken: TypeToken[_], path: Option[Expression]): Expression = { + /** Returns the current path with a sub-field extracted. */ + def addToPath(part: String): Expression = path + .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) + .getOrElse(UnresolvedAttribute(part)) + + /** Returns the current path or `BoundReference`. */ + def getPath: Expression = path.getOrElse(BoundReference(0, inferDataType(typeToken)._1, true)) + + typeToken.getRawType match { + case c if !inferExternalType(c).isInstanceOf[ObjectType] => getPath + + case c if c == classOf[java.lang.Short] => + NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + case c if c == classOf[java.lang.Integer] => + NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + case c if c == classOf[java.lang.Long] => + NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + case c if c == classOf[java.lang.Double] => + NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + case c if c == classOf[java.lang.Byte] => + NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + case c if c == classOf[java.lang.Float] => + NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + case c if c == classOf[java.lang.Boolean] => + NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + + case c if c == classOf[java.sql.Date] => + StaticInvoke( + DateTimeUtils, + ObjectType(c), + "toJavaDate", + getPath :: Nil, + propagateNull = true) + + case c if c == classOf[java.sql.Timestamp] => + StaticInvoke( + DateTimeUtils, + ObjectType(c), + "toJavaTimestamp", + getPath :: Nil, + propagateNull = true) + + case c if c == classOf[java.lang.String] => + Invoke(getPath, "toString", ObjectType(classOf[String])) + + case c if c == classOf[java.math.BigDecimal] => + Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) + + case c if c.isArray => + val elementType = c.getComponentType + val primitiveMethod = elementType match { + case c if c == java.lang.Boolean.TYPE => Some("toBooleanArray") + case c if c == java.lang.Byte.TYPE => Some("toByteArray") + case c if c == java.lang.Short.TYPE => Some("toShortArray") + case c if c == java.lang.Integer.TYPE => Some("toIntArray") + case c if c == java.lang.Long.TYPE => Some("toLongArray") + case c if c == java.lang.Float.TYPE => Some("toFloatArray") + case c if c == java.lang.Double.TYPE => Some("toDoubleArray") + case _ => None + } + + primitiveMethod.map { method => + Invoke(getPath, method, ObjectType(c)) 
+ }.getOrElse { + Invoke( + MapObjects( + p => constructorFor(typeToken.getComponentType, Some(p)), + getPath, + inferDataType(elementType)._1), + "array", + ObjectType(c)) + } + + case c if listType.isAssignableFrom(typeToken) => + val et = elementType(typeToken) + val array = + Invoke( + MapObjects( + p => constructorFor(et, Some(p)), + getPath, + inferDataType(et)._1), + "array", + ObjectType(classOf[Array[Any]])) + + StaticInvoke(classOf[java.util.Arrays], ObjectType(c), "asList", array :: Nil) + + case _ if mapType.isAssignableFrom(typeToken) => + val (keyType, valueType) = mapKeyValueType(typeToken) + val keyDataType = inferDataType(keyType)._1 + val valueDataType = inferDataType(valueType)._1 + + val keyData = + Invoke( + MapObjects( + p => constructorFor(keyType, Some(p)), + Invoke(getPath, "keyArray", ArrayType(keyDataType)), + keyDataType), + "array", + ObjectType(classOf[Array[Any]])) + + val valueData = + Invoke( + MapObjects( + p => constructorFor(valueType, Some(p)), + Invoke(getPath, "valueArray", ArrayType(valueDataType)), + valueDataType), + "array", + ObjectType(classOf[Array[Any]])) + + StaticInvoke( + ArrayBasedMapData, + ObjectType(classOf[JMap[_, _]]), + "toJavaMap", + keyData :: valueData :: Nil) + + case other => + val properties = getJavaBeanProperties(other) + assert(properties.length > 0) + + val setters = properties.map { p => + val fieldName = p.getName + val fieldType = typeToken.method(p.getReadMethod).getReturnType + p.getWriteMethod.getName -> constructorFor(fieldType, Some(addToPath(fieldName))) + }.toMap + + val newInstance = NewInstance(other, Nil, propagateNull = false, ObjectType(other)) + val result = InitializeJavaBean(newInstance, setters) + + if (path.nonEmpty) { + expressions.If( + IsNull(getPath), + expressions.Literal.create(null, ObjectType(other)), + result + ) + } else { + result + } + } + } + + /** + * Returns expressions for extracting all the fields from the given type. 
+ */ + def extractorsFor(beanClass: Class[_]): CreateNamedStruct = { + val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true) + extractorFor(inputObject, TypeToken.of(beanClass)).asInstanceOf[CreateNamedStruct] + } + + private def extractorFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = { + + def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = { + val (dataType, nullable) = inferDataType(elementType) + if (ScalaReflection.isNativeType(dataType)) { + NewInstance( + classOf[GenericArrayData], + input :: Nil, + dataType = ArrayType(dataType, nullable)) + } else { + MapObjects(extractorFor(_, elementType), input, ObjectType(elementType.getRawType)) + } + } + + if (!inputObject.dataType.isInstanceOf[ObjectType]) { + inputObject + } else { + typeToken.getRawType match { + case c if c == classOf[String] => + StaticInvoke( + classOf[UTF8String], + StringType, + "fromString", + inputObject :: Nil) + + case c if c == classOf[java.sql.Timestamp] => + StaticInvoke( + DateTimeUtils, + TimestampType, + "fromJavaTimestamp", + inputObject :: Nil) + + case c if c == classOf[java.sql.Date] => + StaticInvoke( + DateTimeUtils, + DateType, + "fromJavaDate", + inputObject :: Nil) + + case c if c == classOf[java.math.BigDecimal] => + StaticInvoke( + Decimal, + DecimalType.SYSTEM_DEFAULT, + "apply", + inputObject :: Nil) + + case c if c == classOf[java.lang.Boolean] => + Invoke(inputObject, "booleanValue", BooleanType) + case c if c == classOf[java.lang.Byte] => + Invoke(inputObject, "byteValue", ByteType) + case c if c == classOf[java.lang.Short] => + Invoke(inputObject, "shortValue", ShortType) + case c if c == classOf[java.lang.Integer] => + Invoke(inputObject, "intValue", IntegerType) + case c if c == classOf[java.lang.Long] => + Invoke(inputObject, "longValue", LongType) + case c if c == classOf[java.lang.Float] => + Invoke(inputObject, "floatValue", FloatType) + case c if c == classOf[java.lang.Double] => + Invoke(inputObject, "doubleValue", DoubleType) + + case _ if typeToken.isArray => + toCatalystArray(inputObject, typeToken.getComponentType) + + case _ if listType.isAssignableFrom(typeToken) => + toCatalystArray(inputObject, elementType(typeToken)) + + case _ if mapType.isAssignableFrom(typeToken) => + // TODO: for java map, if we get the keys and values by `keySet` and `values`, we can + // not guarantee they have same iteration order(which is different from scala map). + // A possible solution is creating a new `MapObjects` that can iterate a map directly. 
+ throw new UnsupportedOperationException("map type is not supported currently") + + case other => + val properties = getJavaBeanProperties(other) + if (properties.length > 0) { + CreateNamedStruct(properties.flatMap { p => + val fieldName = p.getName + val fieldType = typeToken.method(p.getReadMethod).getReturnType + val fieldValue = Invoke( + inputObject, + p.getReadMethod.getName, + inferExternalType(fieldType.getRawType)) + expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil + }) + } else { + throw new UnsupportedOperationException(s"no encoder found for ${other.getName}") + } + } + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala index 554fb4eb25eb..e21d3c05464b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala @@ -61,7 +61,7 @@ abstract class ParserDialect { */ private[spark] class DefaultParserDialect extends ParserDialect { @transient - protected val sqlParser = new SqlParser + protected val sqlParser = SqlParser override def parse(sqlText: String): LogicalPlan = { sqlParser.parse(sqlText) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 2442341da106..ecff8605706d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -17,11 +17,12 @@ package org.apache.spark.sql.catalyst -import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.util.Utils +import org.apache.spark.sql.catalyst.analysis.{UnresolvedExtractValue, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.Utils /** * A default version of ScalaReflection that uses the runtime universe. @@ -33,10 +34,572 @@ object ScalaReflection extends ScalaReflection { // class loader of the current thread. override def mirror: universe.Mirror = universe.runtimeMirror(Thread.currentThread().getContextClassLoader) + + import universe._ + + // The Predef.Map is scala.collection.immutable.Map. + // Since the map values can be mutable, we explicitly import scala.collection.Map at here. + import scala.collection.Map + + /** + * Returns the Spark SQL DataType for a given scala type. Where this is not an exact mapping + * to a native type, an ObjectType is returned. Special handling is also used for Arrays including + * those that hold primitive types. + * + * Unlike `schemaFor`, this function doesn't do any massaging of types into the Spark SQL type + * system. 
As a result, ObjectType will be returned for things like boxed Integers. + */ + def dataTypeFor[T : TypeTag]: DataType = dataTypeFor(localTypeOf[T]) + + private def dataTypeFor(tpe: `Type`): DataType = ScalaReflectionLock.synchronized { + tpe match { + case t if t <:< definitions.IntTpe => IntegerType + case t if t <:< definitions.LongTpe => LongType + case t if t <:< definitions.DoubleTpe => DoubleType + case t if t <:< definitions.FloatTpe => FloatType + case t if t <:< definitions.ShortTpe => ShortType + case t if t <:< definitions.ByteTpe => ByteType + case t if t <:< definitions.BooleanTpe => BooleanType + case t if t <:< localTypeOf[Array[Byte]] => BinaryType + case _ => + val className = getClassNameFromType(tpe) + className match { + case "scala.Array" => + val TypeRef(_, _, Seq(elementType)) = tpe + arrayClassFor(elementType) + case other => + val clazz = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) + ObjectType(clazz) + } + } + } + + /** + * Given a type `T` this function constructs an ObjectType that holds a class of type + * Array[T]. Special handling is performed for primitive types to map them back to their raw + * JVM form instead of the Scala Array that handles auto boxing. + */ + private def arrayClassFor(tpe: `Type`): DataType = ScalaReflectionLock.synchronized { + val cls = tpe match { + case t if t <:< definitions.IntTpe => classOf[Array[Int]] + case t if t <:< definitions.LongTpe => classOf[Array[Long]] + case t if t <:< definitions.DoubleTpe => classOf[Array[Double]] + case t if t <:< definitions.FloatTpe => classOf[Array[Float]] + case t if t <:< definitions.ShortTpe => classOf[Array[Short]] + case t if t <:< definitions.ByteTpe => classOf[Array[Byte]] + case t if t <:< definitions.BooleanTpe => classOf[Array[Boolean]] + case other => + // There is probably a better way to do this, but I couldn't find it... + val elementType = dataTypeFor(other).asInstanceOf[ObjectType].cls + java.lang.reflect.Array.newInstance(elementType, 1).getClass + + } + ObjectType(cls) + } + + /** + * Returns true if the value of this data type is the same between internal and external. + */ + def isNativeType(dt: DataType): Boolean = dt match { + case BooleanType | ByteType | ShortType | IntegerType | LongType | + FloatType | DoubleType | BinaryType => true + case _ => false + } + + /** + * Returns an expression that can be used to construct an object of type `T` given an input + * row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes + * of the same name as the constructor arguments. Nested classes will have their fields accessed + * using UnresolvedExtractValue. + * + * When used on a primitive type, the constructor will instead default to extracting the value + * from ordinal 0 (since there are no names to map to). The actual location can be moved by + * calling resolve/bind with a new schema. + */ + def constructorFor[T : TypeTag]: Expression = { + val tpe = localTypeOf[T] + val clsName = getClassNameFromType(tpe) + val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil + constructorFor(tpe, None, walkedTypePath) + } + + private def constructorFor( + tpe: `Type`, + path: Option[Expression], + walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized { + + /** Returns the current path with a sub-field extracted.
*/ + def addToPath(part: String, dataType: DataType, walkedTypePath: Seq[String]): Expression = { + val newPath = path + .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) + .getOrElse(UnresolvedAttribute(part)) + upCastToExpectedType(newPath, dataType, walkedTypePath) + } + + /** Returns the current path with a field at ordinal extracted. */ + def addToPathOrdinal( + ordinal: Int, + dataType: DataType, + walkedTypePath: Seq[String]): Expression = { + val newPath = path + .map(p => GetStructField(p, ordinal)) + .getOrElse(BoundReference(ordinal, dataType, false)) + upCastToExpectedType(newPath, dataType, walkedTypePath) + } + + /** Returns the current path or `BoundReference`. */ + def getPath: Expression = { + val dataType = schemaFor(tpe).dataType + if (path.isDefined) { + path.get + } else { + upCastToExpectedType(BoundReference(0, dataType, true), dataType, walkedTypePath) + } + } + + /** + * When we build the `fromRowExpression` for an encoder, we set up a lot of "unresolved" stuff + * and lose the required data type, which may lead to a runtime error if the real type doesn't + * match the encoder's schema. + * For example, we build an encoder for `case class Data(a: Int, b: String)` and the real type + * is [a: int, b: long], then we will hit a runtime error saying that we can't construct class + * `Data` with int and long, because we lost the information that `b` should be a string. + * + * This method helps us "remember" the required data type by adding an `UpCast`. Note that we + * don't need to cast struct types because there must be an `UnresolvedExtractValue` or + * `GetStructField` wrapping it, thus we only need to handle leaf types. + */ + def upCastToExpectedType( + expr: Expression, + expected: DataType, + walkedTypePath: Seq[String]): Expression = expected match { + case _: StructType => expr + case _ => UpCast(expr, expected, walkedTypePath) + } + + tpe match { + case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath + + case t if t <:< localTypeOf[Option[_]] => + val TypeRef(_, _, Seq(optType)) = t + val className = getClassNameFromType(optType) + val newTypePath = s"""- option value class: "$className"""" +: walkedTypePath + WrapOption(constructorFor(optType, path, newTypePath), dataTypeFor(optType)) + + case t if t <:< localTypeOf[java.lang.Integer] => + val boxedType = classOf[java.lang.Integer] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Long] => + val boxedType = classOf[java.lang.Long] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Double] => + val boxedType = classOf[java.lang.Double] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Float] => + val boxedType = classOf[java.lang.Float] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Short] => + val boxedType = classOf[java.lang.Short] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Byte] => + val boxedType = classOf[java.lang.Byte] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t
<:< localTypeOf[java.lang.Boolean] => + val boxedType = classOf[java.lang.Boolean] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.sql.Date] => + StaticInvoke( + DateTimeUtils, + ObjectType(classOf[java.sql.Date]), + "toJavaDate", + getPath :: Nil, + propagateNull = true) + + case t if t <:< localTypeOf[java.sql.Timestamp] => + StaticInvoke( + DateTimeUtils, + ObjectType(classOf[java.sql.Timestamp]), + "toJavaTimestamp", + getPath :: Nil, + propagateNull = true) + + case t if t <:< localTypeOf[java.lang.String] => + Invoke(getPath, "toString", ObjectType(classOf[String])) + + case t if t <:< localTypeOf[java.math.BigDecimal] => + Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) + + case t if t <:< localTypeOf[BigDecimal] => + Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal])) + + case t if t <:< localTypeOf[Array[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val primitiveMethod = elementType match { + case t if t <:< definitions.IntTpe => Some("toIntArray") + case t if t <:< definitions.LongTpe => Some("toLongArray") + case t if t <:< definitions.DoubleTpe => Some("toDoubleArray") + case t if t <:< definitions.FloatTpe => Some("toFloatArray") + case t if t <:< definitions.ShortTpe => Some("toShortArray") + case t if t <:< definitions.ByteTpe => Some("toByteArray") + case t if t <:< definitions.BooleanTpe => Some("toBooleanArray") + case _ => None + } + + primitiveMethod.map { method => + Invoke(getPath, method, arrayClassFor(elementType)) + }.getOrElse { + val className = getClassNameFromType(elementType) + val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath + Invoke( + MapObjects( + p => constructorFor(elementType, Some(p), newTypePath), + getPath, + schemaFor(elementType).dataType), + "array", + arrayClassFor(elementType)) + } + + case t if t <:< localTypeOf[Seq[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val className = getClassNameFromType(elementType) + val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath + val arrayData = + Invoke( + MapObjects( + p => constructorFor(elementType, Some(p), newTypePath), + getPath, + schemaFor(elementType).dataType), + "array", + ObjectType(classOf[Array[Any]])) + + StaticInvoke( + scala.collection.mutable.WrappedArray, + ObjectType(classOf[Seq[_]]), + "make", + arrayData :: Nil) + + case t if t <:< localTypeOf[Map[_, _]] => + // TODO: add walked type path for map + val TypeRef(_, _, Seq(keyType, valueType)) = t + + val keyData = + Invoke( + MapObjects( + p => constructorFor(keyType, Some(p), walkedTypePath), + Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)), + schemaFor(keyType).dataType), + "array", + ObjectType(classOf[Array[Any]])) + + val valueData = + Invoke( + MapObjects( + p => constructorFor(valueType, Some(p), walkedTypePath), + Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)), + schemaFor(valueType).dataType), + "array", + ObjectType(classOf[Array[Any]])) + + StaticInvoke( + ArrayBasedMapData, + ObjectType(classOf[Map[_, _]]), + "toScalaMap", + keyData :: valueData :: Nil) + + case t if t <:< localTypeOf[Product] => + val formalTypeArgs = t.typeSymbol.asClass.typeParams + val TypeRef(_, _, actualTypeArgs) = t + val constructorSymbol = t.member(nme.CONSTRUCTOR) + val params = if (constructorSymbol.isMethod) { + constructorSymbol.asMethod.paramss + } else { + // Find the primary constructor, 
and use its parameter ordering. + val primaryConstructorSymbol: Option[Symbol] = + constructorSymbol.asTerm.alternatives.find(s => + s.isMethod && s.asMethod.isPrimaryConstructor) + + if (primaryConstructorSymbol.isEmpty) { + sys.error("Internal SQL error: Product object did not have a primary constructor.") + } else { + primaryConstructorSymbol.get.asMethod.paramss + } + } + + val cls = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) + + val arguments = params.head.zipWithIndex.map { case (p, i) => + val fieldName = p.name.toString + val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) + val dataType = schemaFor(fieldType).dataType + val clsName = getClassNameFromType(fieldType) + val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath + // For tuples, we grab the inner fields by ordinal instead of name. + if (cls.getName startsWith "scala.Tuple") { + constructorFor( + fieldType, + Some(addToPathOrdinal(i, dataType, newTypePath)), + newTypePath) + } else { + constructorFor( + fieldType, + Some(addToPath(fieldName, dataType, newTypePath)), + newTypePath) + } + } + + val newInstance = NewInstance(cls, arguments, propagateNull = false, ObjectType(cls)) + + if (path.nonEmpty) { + expressions.If( + IsNull(getPath), + expressions.Literal.create(null, ObjectType(cls)), + newInstance + ) + } else { + newInstance + } + } + } + + /** + * Returns expressions for extracting all the fields from the given type. + * + * If the given type is not supported, i.e. no encoder can be built for this type, + * an [[UnsupportedOperationException]] will be thrown with a detailed error message to explain + * the type path walked so far and which class we are not supporting. + * There are 4 kinds of type path: + * * the root type: `root class: "abc.xyz.MyClass"` + * * the value type of [[Option]]: `option value class: "abc.xyz.MyClass"` + * * the element type of [[Array]] or [[Seq]]: `array element class: "abc.xyz.MyClass"` + * * the field of [[Product]]: `field (class: "abc.xyz.MyClass", name: "myField")` + */ + def extractorsFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = { + val tpe = localTypeOf[T] + val clsName = getClassNameFromType(tpe) + val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil + extractorFor(inputObject, tpe, walkedTypePath) match { + case s: CreateNamedStruct => s + case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil) + } + } + + /** Helper for extracting internal fields from a case class. */ + private def extractorFor( + inputObject: Expression, + tpe: `Type`, + walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized { + + def toCatalystArray(input: Expression, elementType: `Type`): Expression = { + val externalDataType = dataTypeFor(elementType) + val Schema(catalystType, nullable) = silentSchemaFor(elementType) + if (isNativeType(catalystType)) { + NewInstance( + classOf[GenericArrayData], + input :: Nil, + dataType = ArrayType(catalystType, nullable)) + } else { + val clsName = getClassNameFromType(elementType) + val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath + MapObjects(extractorFor(_, elementType, newPath), input, externalDataType) + } + } + + if (!inputObject.dataType.isInstanceOf[ObjectType]) { + inputObject + } else { + tpe match { + case t if t <:< localTypeOf[Option[_]] => + val TypeRef(_, _, Seq(optType)) = t + optType match { + // For primitive types, we must manually unbox the value of the object.
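+ // e.g. an Option[Int] value is unwrapped to a java.lang.Integer and `intValue` is invoked on + // it; the remaining primitive cases below follow the same pattern.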
+ case t if t <:< definitions.IntTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Integer]), inputObject), + "intValue", + IntegerType) + case t if t <:< definitions.LongTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Long]), inputObject), + "longValue", + LongType) + case t if t <:< definitions.DoubleTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Double]), inputObject), + "doubleValue", + DoubleType) + case t if t <:< definitions.FloatTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Float]), inputObject), + "floatValue", + FloatType) + case t if t <:< definitions.ShortTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Short]), inputObject), + "shortValue", + ShortType) + case t if t <:< definitions.ByteTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Byte]), inputObject), + "byteValue", + ByteType) + case t if t <:< definitions.BooleanTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Boolean]), inputObject), + "booleanValue", + BooleanType) + + // For non-primitives, we can just extract the object from the Option and then recurse. + case other => + val className = getClassNameFromType(optType) + val classObj = Utils.classForName(className) + val optionObjectType = ObjectType(classObj) + val newPath = s"""- option value class: "$className"""" +: walkedTypePath + + val unwrapped = UnwrapOption(optionObjectType, inputObject) + expressions.If( + IsNull(unwrapped), + expressions.Literal.create(null, silentSchemaFor(optType).dataType), + extractorFor(unwrapped, optType, newPath)) + } + + case t if t <:< localTypeOf[Product] => + val formalTypeArgs = t.typeSymbol.asClass.typeParams + val TypeRef(_, _, actualTypeArgs) = t + val constructorSymbol = t.member(nme.CONSTRUCTOR) + val params = if (constructorSymbol.isMethod) { + constructorSymbol.asMethod.paramss + } else { + // Find the primary constructor, and use its parameter ordering. 
+ val primaryConstructorSymbol: Option[Symbol] = + constructorSymbol.asTerm.alternatives.find(s => + s.isMethod && s.asMethod.isPrimaryConstructor) + + if (primaryConstructorSymbol.isEmpty) { + sys.error("Internal SQL error: Product object did not have a primary constructor.") + } else { + primaryConstructorSymbol.get.asMethod.paramss + } + } + + CreateNamedStruct(params.head.flatMap { p => + val fieldName = p.name.toString + val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) + val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType)) + val clsName = getClassNameFromType(fieldType) + val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath + + expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType, newPath) :: Nil + }) + + case t if t <:< localTypeOf[Array[_]] => + val TypeRef(_, _, Seq(elementType)) = t + toCatalystArray(inputObject, elementType) + + case t if t <:< localTypeOf[Seq[_]] => + val TypeRef(_, _, Seq(elementType)) = t + toCatalystArray(inputObject, elementType) + + case t if t <:< localTypeOf[Map[_, _]] => + val TypeRef(_, _, Seq(keyType, valueType)) = t + + val keys = + Invoke( + Invoke(inputObject, "keysIterator", + ObjectType(classOf[scala.collection.Iterator[_]])), + "toSeq", + ObjectType(classOf[scala.collection.Seq[_]])) + val convertedKeys = toCatalystArray(keys, keyType) + + val values = + Invoke( + Invoke(inputObject, "valuesIterator", + ObjectType(classOf[scala.collection.Iterator[_]])), + "toSeq", + ObjectType(classOf[scala.collection.Seq[_]])) + val convertedValues = toCatalystArray(values, valueType) + + val Schema(keyDataType, _) = schemaFor(keyType) + val Schema(valueDataType, valueNullable) = schemaFor(valueType) + NewInstance( + classOf[ArrayBasedMapData], + convertedKeys :: convertedValues :: Nil, + dataType = MapType(keyDataType, valueDataType, valueNullable)) + + case t if t <:< localTypeOf[String] => + StaticInvoke( + classOf[UTF8String], + StringType, + "fromString", + inputObject :: Nil) + + case t if t <:< localTypeOf[java.sql.Timestamp] => + StaticInvoke( + DateTimeUtils, + TimestampType, + "fromJavaTimestamp", + inputObject :: Nil) + + case t if t <:< localTypeOf[java.sql.Date] => + StaticInvoke( + DateTimeUtils, + DateType, + "fromJavaDate", + inputObject :: Nil) + + case t if t <:< localTypeOf[BigDecimal] => + StaticInvoke( + Decimal, + DecimalType.SYSTEM_DEFAULT, + "apply", + inputObject :: Nil) + + case t if t <:< localTypeOf[java.math.BigDecimal] => + StaticInvoke( + Decimal, + DecimalType.SYSTEM_DEFAULT, + "apply", + inputObject :: Nil) + + case t if t <:< localTypeOf[java.lang.Integer] => + Invoke(inputObject, "intValue", IntegerType) + case t if t <:< localTypeOf[java.lang.Long] => + Invoke(inputObject, "longValue", LongType) + case t if t <:< localTypeOf[java.lang.Double] => + Invoke(inputObject, "doubleValue", DoubleType) + case t if t <:< localTypeOf[java.lang.Float] => + Invoke(inputObject, "floatValue", FloatType) + case t if t <:< localTypeOf[java.lang.Short] => + Invoke(inputObject, "shortValue", ShortType) + case t if t <:< localTypeOf[java.lang.Byte] => + Invoke(inputObject, "byteValue", ByteType) + case t if t <:< localTypeOf[java.lang.Boolean] => + Invoke(inputObject, "booleanValue", BooleanType) + + case other => + throw new UnsupportedOperationException( + s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n")) + } + } + } } /** - * Support for generating catalyst schemas for scala objects. 
+ * Support for generating catalyst schemas for scala objects. Note that unlike its companion + * object, this trait is able to work in both the runtime and the compile time (macro) universe. */ trait ScalaReflection { /** The universe we work in (runtime or macro) */ @@ -60,8 +623,7 @@ trait ScalaReflection { } /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ - def schemaFor[T: TypeTag]: Schema = - ScalaReflectionLock.synchronized { schemaFor(localTypeOf[T]) } + def schemaFor[T: TypeTag]: Schema = schemaFor(localTypeOf[T]) /** * Return the Scala Type for `T` in the current classloader mirror. * * @see SPARK-5281 */ - private def localTypeOf[T: TypeTag]: `Type` = typeTag[T].in(mirror).tpe + def localTypeOf[T: TypeTag]: `Type` = typeTag[T].in(mirror).tpe /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized { - val className: String = tpe.erasure.typeSymbol.asClass.fullName + val className = getClassNameFromType(tpe) tpe match { case t if Utils.classIsLoadable(className) && Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) => @@ -91,7 +653,6 @@ trait ScalaReflection { case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t Schema(schemaFor(optType).dataType, nullable = true) - // Need to decide if we actually need a special type here. case t if t <:< localTypeOf[Array[Byte]] => Schema(BinaryType, nullable = true) case t if t <:< localTypeOf[Array[_]] => val TypeRef(_, _, Seq(elementType)) = t @@ -154,38 +715,29 @@ trait ScalaReflection { } } - def typeOfObject: PartialFunction[Any, DataType] = { - // The data type can be determined without ambiguity. - case obj: Boolean => BooleanType - case obj: Array[Byte] => BinaryType - case obj: String => StringType - case obj: UTF8String => StringType - case obj: Byte => ByteType - case obj: Short => ShortType - case obj: Int => IntegerType - case obj: Long => LongType - case obj: Float => FloatType - case obj: Double => DoubleType - case obj: java.sql.Date => DateType - case obj: java.math.BigDecimal => DecimalType.SYSTEM_DEFAULT - case obj: Decimal => DecimalType.SYSTEM_DEFAULT - case obj: java.sql.Timestamp => TimestampType - case null => NullType - // For other cases, there is no obvious mapping from the type of the given object to a - // Catalyst data type. A user should provide his/her specific rules - // (in a user-defined PartialFunction) to infer the Catalyst data type for other types of - // objects and then compose the user-defined PartialFunction with this one. + /** + * Returns a catalyst DataType and its nullability for the given Scala Type using reflection. + * + * Unlike `schemaFor`, this method won't throw an exception for an unsupported type; it will return + * `NullType` silently instead. + */ + def silentSchemaFor(tpe: `Type`): Schema = try { + schemaFor(tpe) + } catch { + case _: UnsupportedOperationException => Schema(NullType, nullable = true) } - implicit class CaseClassRelation[A <: Product : TypeTag](data: Seq[A]) { + /** Returns the full class name for a type. */ + def getClassNameFromType(tpe: `Type`): String = { + tpe.erasure.typeSymbol.asClass.fullName + } - /** - * Implicitly added to Sequences of case class objects. Returns a catalyst logical relation - * for the the data in the sequence.
- */ - def asRelation: LocalRelation = { - val output = attributesFor[A] - LocalRelation.fromProduct(output, data) - } + /** + * Returns classes of input parameters of scala function object. + */ + def getParameterTypes(func: AnyRef): Seq[Class[_]] = { + val methods = func.getClass.getMethods.filter(m => m.getName == "apply" && !m.isBridge) + assert(methods.length == 1) + methods.head.getParameterTypes } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index f2498861c957..2a132d8b82be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -22,8 +22,10 @@ import scala.language.implicitConversions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.DataTypeParser import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -37,9 +39,9 @@ import org.apache.spark.unsafe.types.CalendarInterval * This is currently included mostly for illustrative purposes. Users wanting more complete support * for a SQL like language should checkout the HiveQL support in the sql/hive sub-project. */ -class SqlParser extends AbstractSparkSQLParser with DataTypeParser { +object SqlParser extends AbstractSparkSQLParser with DataTypeParser { - def parseExpression(input: String): Expression = { + def parseExpression(input: String): Expression = synchronized { // Initialize the Keywords. initLexical phrase(projection)(new lexical.Scanner(input)) match { @@ -48,7 +50,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { } } - def parseTableIdentifier(input: String): TableIdentifier = { + def parseTableIdentifier(input: String): TableIdentifier = synchronized { // Initialize the Keywords. initLexical phrase(tableIdentifier)(new lexical.Scanner(input)) match { @@ -170,7 +172,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { joinedRelation | relationFactor protected lazy val relationFactor: Parser[LogicalPlan] = - ( rep1sep(ident, ".") ~ (opt(AS) ~> opt(ident)) ^^ { + ( tableIdentifier ~ (opt(AS) ~> opt(ident)) ^^ { case tableIdent ~ alias => UnresolvedRelation(tableIdent, alias) } | ("(" ~> start <~ ")") ~ (AS.? ~> ident) ^^ { case s ~ a => Subquery(a, s) } @@ -218,7 +220,10 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { andExpression * (OR ^^^ { (e1: Expression, e2: Expression) => Or(e1, e2) }) protected lazy val andExpression: Parser[Expression] = - comparisonExpression * (AND ^^^ { (e1: Expression, e2: Expression) => And(e1, e2) }) + notExpression * (AND ^^^ { (e1: Expression, e2: Expression) => And(e1, e2) }) + + protected lazy val notExpression: Parser[Expression] = + NOT.? 
~ comparisonExpression ^^ { case maybeNot ~ e => maybeNot.map(_ => Not(e)).getOrElse(e) } protected lazy val comparisonExpression: Parser[Expression] = ( termExpression ~ ("=" ~> termExpression) ^^ { case e1 ~ e2 => EqualTo(e1, e2) } @@ -246,7 +251,6 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { } | termExpression <~ IS ~ NULL ^^ { case e => IsNull(e) } | termExpression <~ IS ~ NOT ~ NULL ^^ { case e => IsNotNull(e) } - | NOT ~> termExpression ^^ {e => Not(e)} | termExpression ) @@ -269,7 +273,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val function: Parser[Expression] = ( ident <~ ("(" ~ "*" ~ ")") ^^ { case udfName => if (lexical.normalizeKeyword(udfName) == "count") { - Count(Literal(1)) + AggregateExpression(Count(Literal(1)), mode = Complete, isDistinct = false) } else { throw new AnalysisException(s"invalid expression $udfName(*)") } @@ -278,14 +282,14 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { { case udfName ~ exprs => UnresolvedFunction(udfName, exprs, isDistinct = false) } | ident ~ ("(" ~ DISTINCT ~> repsep(expression, ",")) <~ ")" ^^ { case udfName ~ exprs => lexical.normalizeKeyword(udfName) match { - case "sum" => SumDistinct(exprs.head) - case "count" => CountDistinct(exprs) + case "count" => + aggregate.Count(exprs).toAggregateExpression(isDistinct = true) case _ => UnresolvedFunction(udfName, exprs, isDistinct = true) } } | APPROXIMATE ~> ident ~ ("(" ~ DISTINCT ~> expression <~ ")") ^^ { case udfName ~ exp => if (lexical.normalizeKeyword(udfName) == "count") { - ApproxCountDistinct(exp) + AggregateExpression(new HyperLogLogPlusPlus(exp), mode = Complete, isDistinct = false) } else { throw new AnalysisException(s"invalid function approximate $udfName") } @@ -293,7 +297,10 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { | APPROXIMATE ~> "(" ~> unsignedFloat ~ ")" ~ ident ~ "(" ~ DISTINCT ~ expression <~ ")" ^^ { case s ~ _ ~ udfName ~ _ ~ _ ~ exp => if (lexical.normalizeKeyword(udfName) == "count") { - ApproxCountDistinct(exp, s.toDouble) + AggregateExpression( + HyperLogLogPlusPlus(exp, s.toDouble, 0, 0), + mode = Complete, + isDistinct = false) } else { throw new AnalysisException(s"invalid function approximate($s) $udfName") } @@ -319,7 +326,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val literal: Parser[Literal] = ( numericLiteral | booleanLiteral - | stringLit ^^ {case s => Literal.create(s, StringType) } + | stringLit ^^ { case s => Literal.create(s, StringType) } | intervalLiteral | NULL ^^^ Literal.create(null, NullType) ) @@ -331,14 +338,13 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val numericLiteral: Parser[Literal] = ( integral ^^ { case i => Literal(toNarrowestIntegerType(i)) } - | sign.? ~ unsignedFloat ^^ { - case s ~ f => Literal(toDecimalOrDouble(s.getOrElse("") + f)) - } + | sign.? ~ unsignedFloat ^^ + { case s ~ f => Literal(toDecimalOrDouble(s.getOrElse("") + f)) } ) protected lazy val unsignedFloat: Parser[String] = ( "." ~> numericLit ^^ { u => "0." + u } - | elem("decimal", _.isInstanceOf[lexical.FloatLit]) ^^ (_.chars) + | elem("decimal", _.isInstanceOf[lexical.DecimalLit]) ^^ (_.chars) ) protected lazy val sign: Parser[String] = ("+" | "-") @@ -346,13 +352,12 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val integral: Parser[String] = sign.? 
~ numericLit ^^ { case s ~ n => s.getOrElse("") + n } - private def intervalUnit(unitName: String) = - acceptIf { - case lexical.Identifier(str) => - val normalized = lexical.normalizeKeyword(str) - normalized == unitName || normalized == unitName + "s" - case _ => false - } {_ => "wrong interval unit"} + private def intervalUnit(unitName: String) = acceptIf { + case lexical.Identifier(str) => + val normalized = lexical.normalizeKeyword(str) + normalized == unitName || normalized == unitName + "s" + case _ => false + } {_ => "wrong interval unit"} protected lazy val month: Parser[Int] = integral <~ intervalUnit("month") ^^ { case num => num.toInt } @@ -393,21 +398,53 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { case num => num.toLong * CalendarInterval.MICROS_PER_WEEK } + private def intervalKeyword(keyword: String) = acceptIf { + case lexical.Identifier(str) => + lexical.normalizeKeyword(str) == keyword + case _ => false + } {_ => "wrong interval keyword"} + protected lazy val intervalLiteral: Parser[Literal] = - INTERVAL ~> year.? ~ month.? ~ week.? ~ day.? ~ hour.? ~ minute.? ~ second.? ~ - millisecond.? ~ microsecond.? ^^ { - case year ~ month ~ week ~ day ~ hour ~ minute ~ second ~ + ( INTERVAL ~> stringLit <~ intervalKeyword("year") ~ intervalKeyword("to") ~ + intervalKeyword("month") ^^ { case s => + Literal(CalendarInterval.fromYearMonthString(s)) + } + | INTERVAL ~> stringLit <~ intervalKeyword("day") ~ intervalKeyword("to") ~ + intervalKeyword("second") ^^ { case s => + Literal(CalendarInterval.fromDayTimeString(s)) + } + | INTERVAL ~> stringLit <~ intervalKeyword("year") ^^ { case s => + Literal(CalendarInterval.fromSingleUnitString("year", s)) + } + | INTERVAL ~> stringLit <~ intervalKeyword("month") ^^ { case s => + Literal(CalendarInterval.fromSingleUnitString("month", s)) + } + | INTERVAL ~> stringLit <~ intervalKeyword("day") ^^ { case s => + Literal(CalendarInterval.fromSingleUnitString("day", s)) + } + | INTERVAL ~> stringLit <~ intervalKeyword("hour") ^^ { case s => + Literal(CalendarInterval.fromSingleUnitString("hour", s)) + } + | INTERVAL ~> stringLit <~ intervalKeyword("minute") ^^ { case s => + Literal(CalendarInterval.fromSingleUnitString("minute", s)) + } + | INTERVAL ~> stringLit <~ intervalKeyword("second") ^^ { case s => + Literal(CalendarInterval.fromSingleUnitString("second", s)) + } + | INTERVAL ~> year.? ~ month.? ~ week.? ~ day.? ~ hour.? ~ minute.? ~ second.? ~ + millisecond.? ~ microsecond.? 
^^ { case year ~ month ~ week ~ day ~ hour ~ minute ~ second ~ millisecond ~ microsecond => - if (!Seq(year, month, week, day, hour, minute, second, - millisecond, microsecond).exists(_.isDefined)) { - throw new AnalysisException( - "at least one time unit should be given for interval literal") - } - val months = Seq(year, month).map(_.getOrElse(0)).sum - val microseconds = Seq(week, day, hour, minute, second, millisecond, microsecond) - .map(_.getOrElse(0L)).sum - Literal.create(new CalendarInterval(months, microseconds), CalendarIntervalType) + if (!Seq(year, month, week, day, hour, minute, second, + millisecond, microsecond).exists(_.isDefined)) { + throw new AnalysisException( + "at least one time unit should be given for interval literal") } + val months = Seq(year, month).map(_.getOrElse(0)).sum + val microseconds = Seq(week, day, hour, minute, second, millisecond, microsecond) + .map(_.getOrElse(0L)).sum + Literal(new CalendarInterval(months, microseconds)) + } + ) private def toNarrowestIntegerType(value: String): Any = { val bigIntValue = BigDecimal(value) @@ -432,9 +469,9 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val baseExpression: Parser[Expression] = ( "*" ^^^ UnresolvedStar(None) - | ident <~ "." ~ "*" ^^ { case tableName => UnresolvedStar(Option(tableName)) } + | rep1(ident <~ ".") <~ "*" ^^ { case target => UnresolvedStar(Option(target))} | primary - ) + ) protected lazy val signedPrimary: Parser[Expression] = sign ~ primary ^^ { case s ~ e => if (s == "-") UnaryMinus(e) else e } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/TableIdentifier.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/TableIdentifier.scala index aebcdeb9d070..4d4e4ded9947 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/TableIdentifier.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/TableIdentifier.scala @@ -20,12 +20,16 @@ package org.apache.spark.sql.catalyst /** * Identifies a `table` in `database`. If `database` is not defined, the current database is used. 
*/ -private[sql] case class TableIdentifier(table: String, database: Option[String] = None) { - def withDatabase(database: String): TableIdentifier = this.copy(database = Some(database)) +private[sql] case class TableIdentifier(table: String, database: Option[String]) { + def this(table: String) = this(table, None) - def toSeq: Seq[String] = database.toSeq :+ table + override def toString: String = quotedString - override def toString: String = toSeq.map("`" + _ + "`").mkString(".") + def quotedString: String = database.map(db => s"`$db`.`$table`").getOrElse(s"`$table`") - def unquotedString: String = toSeq.mkString(".") + def unquotedString: String = database.map(db => s"$db.$table").getOrElse(table) +} + +private[sql] object TableIdentifier { + def apply(tableName: String): TableIdentifier = new TableIdentifier(tableName) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index f5daba1543da..64dd83a91571 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -20,12 +20,12 @@ package org.apache.spark.sql.catalyst.analysis import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2, AggregateFunction2} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef -import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf} +import org.apache.spark.sql.catalyst.{ScalaReflection, SimpleCatalystConf, CatalystConf} import org.apache.spark.sql.types._ /** @@ -65,25 +65,33 @@ class Analyzer( lazy val batches: Seq[Batch] = Seq( Batch("Substitution", fixedPoint, - CTESubstitution :: - WindowsSubstitution :: - Nil : _*), + CTESubstitution, + WindowsSubstitution), Batch("Resolution", fixedPoint, ResolveRelations :: ResolveReferences :: ResolveGroupingAnalytics :: + ResolvePivot :: + ResolveUpCast :: ResolveSortReferences :: ResolveGenerate :: ResolveFunctions :: ResolveAliases :: + ResolveWindowOrder :: + ResolveWindowFrame :: ExtractWindowExpressions :: GlobalAggregates :: - UnresolvedHavingClauseAttributes :: - RemoveEvaluationFromSort :: + ResolveAggregateFunctions :: + DistinctAggregationRewriter(conf) :: HiveTypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Nondeterministic", Once, - PullOutNondeterministic) + PullOutNondeterministic, + ComputeCurrentTime), + Batch("UDF", Once, + HandleNullInputsForUDF), + Batch("Cleanup", fixedPoint, + CleanupAliases) ) /** @@ -91,7 +99,7 @@ class Analyzer( */ object CTESubstitution extends Rule[LogicalPlan] { // TODO allow subquery to define CTE - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case With(child, relations) => substituteCTE(child, relations) case other => other } @@ -104,7 +112,7 @@ class Analyzer( // here use the CTE definition first, check table name only and ignore database name // see https://github.com/apache/spark/pull/4929#discussion_r27186638 for more info case u : UnresolvedRelation => - val substituted = 
cteRelations.get(u.tableIdentifier.last).map { relation => + val substituted = cteRelations.get(u.tableIdentifier.table).map { relation => val withAlias = u.alias.map(Subquery(_, relation)) withAlias.getOrElse(relation) } @@ -117,18 +125,16 @@ class Analyzer( * Substitute child plan with WindowSpecDefinitions. */ object WindowsSubstitution extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { // Lookup WindowSpecDefinitions. This rule works with unresolved children. case WithWindowDefinition(windowDefinitions, child) => child.transform { - case plan => plan.transformExpressions { + case p => p.transformExpressions { case UnresolvedWindowExpression(c, WindowSpecReference(windowName)) => val errorMessage = s"Window specification $windowName is not defined in the WINDOW clause." val windowSpecDefinition = - windowDefinitions - .get(windowName) - .getOrElse(failAnalysis(errorMessage)) + windowDefinitions.getOrElse(windowName, failAnalysis(errorMessage)) WindowExpression(c, windowSpecDefinition) } } @@ -140,34 +146,35 @@ class Analyzer( */ object ResolveAliases extends Rule[LogicalPlan] { private def assignAliases(exprs: Seq[NamedExpression]) = { - // The `UnresolvedAlias`s will appear only at root of a expression tree, we don't need - // to transform down the whole tree. exprs.zipWithIndex.map { - case (u @ UnresolvedAlias(child), i) => - child match { - case _: UnresolvedAttribute => u - case ne: NamedExpression => ne - case g: GetStructField => Alias(g, g.field.name)() - case g: GetArrayStructFields => Alias(g, g.field.name)() - case g: Generator if g.resolved && g.elementTypes.size > 1 => MultiAlias(g, Nil) - case e if !e.resolved => u - case other => Alias(other, s"_c$i")() + case (expr, i) => + expr transform { + case u @ UnresolvedAlias(child) => child match { + case ne: NamedExpression => ne + case e if !e.resolved => u + case g: Generator => MultiAlias(g, Nil) + case c @ Cast(ne: NamedExpression, _) => Alias(c, ne.name)() + case other => Alias(other, s"_c$i")() + } } - case (other, _) => other - } + }.asInstanceOf[Seq[NamedExpression]] } - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case Aggregate(groups, aggs, child) - if child.resolved && aggs.exists(_.isInstanceOf[UnresolvedAlias]) => + private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) = + exprs.exists(_.find(_.isInstanceOf[UnresolvedAlias]).isDefined) + + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) => Aggregate(groups, assignAliases(aggs), child) - case g: GroupingAnalytics - if g.child.resolved && g.aggregations.exists(_.isInstanceOf[UnresolvedAlias]) => + case g: GroupingAnalytics if g.child.resolved && hasUnresolvedAlias(g.aggregations) => g.withNewAggs(assignAliases(g.aggregations)) - case Project(projectList, child) - if child.resolved && projectList.exists(_.isInstanceOf[UnresolvedAlias]) => + case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) + if child.resolved && hasUnresolvedAlias(groupByExprs) => + Pivot(assignAliases(groupByExprs), pivotColumn, pivotValues, aggregates, child) + + case Project(projectList, child) if child.resolved && hasUnresolvedAlias(projectList) => Project(assignAliases(projectList), child) } } @@ -199,7 +206,7 @@ class Analyzer( Seq.tabulate(1 << c.groupByExprs.length)(i => i) } - def apply(plan: LogicalPlan): LogicalPlan = plan transform { 
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case a if !a.childrenResolved => a // be sure all of the children are resolved. case a: Cube => GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations) @@ -207,45 +214,83 @@ class Analyzer( GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations) case x: GroupingSets => val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)() - // We will insert another Projection if the GROUP BY keys contains the - // non-attribute expressions. And the top operators can references those - // expressions by its alias. - // e.g. SELECT key%5 as c1 FROM src GROUP BY key%5 ==> - // SELECT a as c1 FROM (SELECT key%5 AS a FROM src) GROUP BY a - - // find all of the non-attribute expressions in the GROUP BY keys - val nonAttributeGroupByExpressions = new ArrayBuffer[Alias]() - - // The pair of (the original GROUP BY key, associated attribute) - val groupByExprPairs = x.groupByExprs.map(_ match { - case e: NamedExpression => (e, e.toAttribute) - case other => { - val alias = Alias(other, other.toString)() - nonAttributeGroupByExpressions += alias // add the non-attributes expression alias - (other, alias.toAttribute) - } - }) - // substitute the non-attribute expressions for aggregations. - val aggregation = x.aggregations.map(expr => expr.transformDown { - case e => groupByExprPairs.find(_._1.semanticEquals(e)).map(_._2).getOrElse(e) - }.asInstanceOf[NamedExpression]) + // Expand works by setting grouping expressions to null as determined by the bitmasks. To + // prevent these null values from being used in an aggregate instead of the original value, + // we need to create new aliases for all group by expressions that will only be used for + // the intended purpose. + val groupByAliases: Seq[Alias] = x.groupByExprs.map { + case e: NamedExpression => Alias(e, e.name)() + case other => Alias(other, other.toString)() + } + + val nonNullBitmask = x.bitmasks.reduce(_ & _) - // substitute the group by expressions. - val newGroupByExprs = groupByExprPairs.map(_._2) + val attributeMap = groupByAliases.zipWithIndex.map { case (a, idx) => + if ((nonNullBitmask & 1 << idx) == 0) { + (a -> a.toAttribute.withNullability(true)) + } else { + (a -> a.toAttribute) + } + }.toMap - val child = if (nonAttributeGroupByExpressions.length > 0) { - // insert additional projection if contains the - // non-attribute expressions in the GROUP BY keys - Project(x.child.output ++ nonAttributeGroupByExpressions, x.child) - } else { - x.child + val aggregations: Seq[NamedExpression] = x.aggregations.map { + // If an expression is an aggregate (contains an AggregateExpression) then we don't change + // it so that the aggregation is computed on the unmodified value of its argument + // expressions. + case expr if expr.find(_.isInstanceOf[AggregateExpression]).nonEmpty => expr + // If not, then it's a grouping expression and we need to use the modified (with nulls from + // Expand) value of the expression.
+ case expr => expr.transformDown { + case e => + groupByAliases.find(_.child.semanticEquals(e)).map(attributeMap(_)).getOrElse(e) + }.asInstanceOf[NamedExpression] } + val child = Project(x.child.output ++ groupByAliases, x.child) + val groupByAttributes = groupByAliases.map(attributeMap(_)) + Aggregate( - newGroupByExprs :+ VirtualColumn.groupingIdAttribute, - aggregation, - Expand(x.bitmasks, newGroupByExprs, gid, child)) + groupByAttributes :+ VirtualColumn.groupingIdAttribute, + aggregations, + Expand(x.bitmasks, groupByAttributes, gid, child)) + } + } + + object ResolvePivot extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case p: Pivot if !p.childrenResolved | !p.aggregates.forall(_.resolved) => p + case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) => + val singleAgg = aggregates.size == 1 + val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value => + def ifExpr(expr: Expression) = { + If(EqualTo(pivotColumn, value), expr, Literal(null)) + } + aggregates.map { aggregate => + val filteredAggregate = aggregate.transformDown { + // Assumption is the aggregate function ignores nulls. This is true for all current + // AggregateFunction's with the exception of First and Last in their default mode + // (which we handle) and possibly some Hive UDAF's. + case First(expr, _) => + First(ifExpr(expr), Literal(true)) + case Last(expr, _) => + Last(ifExpr(expr), Literal(true)) + case a: AggregateFunction => + a.withNewChildren(a.children.map(ifExpr)) + } + if (filteredAggregate.fastEquals(aggregate)) { + throw new AnalysisException( + s"Aggregate expression required for pivot, found '$aggregate'") + } + val name = if (singleAgg) value.toString else value + "_" + aggregate.prettyString + Alias(filteredAggregate, name)() + } + } + val newGroupByExprs = groupByExprs.map { + case UnresolvedAlias(e) => e + case e => e + } + Aggregate(newGroupByExprs, groupByExprs ++ pivotAggregates, child) } } @@ -258,15 +303,21 @@ class Analyzer( catalog.lookupRelation(u.tableIdentifier, u.alias) } catch { case _: NoSuchTableException => - u.failAnalysis(s"no such table ${u.tableName}") + u.failAnalysis(s"Table not found: ${u.tableName}") } } - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case i @ InsertIntoTable(u: UnresolvedRelation, _, _, _, _) => i.copy(table = EliminateSubQueries(getTable(u))) case u: UnresolvedRelation => - getTable(u) + try { + getTable(u) + } catch { + case _: AnalysisException if u.tableIdentifier.database.isDefined => + // delay the exception into CheckAnalysis, then it could be resolved as data source. + u + } } } @@ -275,51 +326,67 @@ class Analyzer( * a logical plan node's children. */ object ResolveReferences extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + /** + * Foreach expression, expands the matching attribute.*'s in `child`'s input for the subtree + * rooted at each expression. 
+ */ + def expandStarExpressions(exprs: Seq[Expression], child: LogicalPlan): Seq[Expression] = { + exprs.flatMap { + case s: Star => s.expand(child, resolver) + case e => + e.transformDown { + case f1: UnresolvedFunction if containsStar(f1.children) => + f1.copy(children = f1.children.flatMap { + case s: Star => s.expand(child, resolver) + case o => o :: Nil + }) + } :: Nil + } + } + + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case p: LogicalPlan if !p.childrenResolved => p // If the projection list contains Stars, expand it. case p @ Project(projectList, child) if containsStar(projectList) => Project( projectList.flatMap { - case s: Star => s.expand(child.output, resolver) + case s: Star => s.expand(child, resolver) case UnresolvedAlias(f @ UnresolvedFunction(_, args, _)) if containsStar(args) => - val expandedArgs = args.flatMap { - case s: Star => s.expand(child.output, resolver) - case o => o :: Nil - } - UnresolvedAlias(child = f.copy(children = expandedArgs)) :: Nil + val newChildren = expandStarExpressions(args, child) + UnresolvedAlias(child = f.copy(children = newChildren)) :: Nil + case Alias(f @ UnresolvedFunction(_, args, _), name) if containsStar(args) => + val newChildren = expandStarExpressions(args, child) + Alias(child = f.copy(children = newChildren), name)() :: Nil case UnresolvedAlias(c @ CreateArray(args)) if containsStar(args) => val expandedArgs = args.flatMap { - case s: Star => s.expand(child.output, resolver) + case s: Star => s.expand(child, resolver) case o => o :: Nil } UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil case UnresolvedAlias(c @ CreateStruct(args)) if containsStar(args) => val expandedArgs = args.flatMap { - case s: Star => s.expand(child.output, resolver) + case s: Star => s.expand(child, resolver) case o => o :: Nil } UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil case o => o :: Nil }, child) + case t: ScriptTransformation if containsStar(t.input) => t.copy( input = t.input.flatMap { - case s: Star => s.expand(t.child.output, resolver) + case s: Star => s.expand(t.child, resolver) case o => o :: Nil } ) // If the aggregate function argument contains Stars, expand it. case a: Aggregate if containsStar(a.aggregateExpressions) => - a.copy( - aggregateExpressions = a.aggregateExpressions.flatMap { - case s: Star => s.expand(a.child.output, resolver) - case o => o :: Nil - } - ) + val expanded = expandStarExpressions(a.aggregateExpressions, a.child) + .map(_.asInstanceOf[NamedExpression]) + a.copy(aggregateExpressions = expanded) // Special handling for cases when self-join introduce duplicate expression ids. case j @ Join(left, right, _, _) if !j.selfJoinResolved => @@ -379,15 +446,29 @@ class Analyzer( val newOrdering = resolveSortOrders(ordering, child, throws = false) Sort(newOrdering, global, child) + // A special case for Generate, because the output of Generate should not be resolved by + // ResolveReferences. Attributes in the output will be resolved by ResolveGenerate. 
+ case g @ Generate(generator, join, outer, qualifier, output, child) + if child.resolved && !generator.resolved => + val newG = generator transformUp { + case u @ UnresolvedAttribute(nameParts) => + withPosition(u) { child.resolve(nameParts, resolver).getOrElse(u) } + case UnresolvedExtractValue(child, fieldExpr) => + ExtractValue(child, fieldExpr, resolver) + } + if (newG.fastEquals(generator)) { + g + } else { + Generate(newG.asInstanceOf[Generator], join, outer, qualifier, output, child) + } + case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") q transformExpressionsUp { case u @ UnresolvedAttribute(nameParts) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. val result = - withPosition(u) { - q.resolveChildren(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u) - } + withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) } logDebug(s"Resolving $u to $result") result case UnresolvedExtractValue(child, fieldExpr) if child.resolved => @@ -409,15 +490,10 @@ class Analyzer( /** * Returns true if `exprs` contains a [[Star]]. */ - protected def containsStar(exprs: Seq[Expression]): Boolean = + def containsStar(exprs: Seq[Expression]): Boolean = exprs.exists(_.collect { case _: Star => true }.nonEmpty) } - private def trimUnresolvedAlias(ne: NamedExpression) = ne match { - case UnresolvedAlias(child) => child - case other => other - } - private def resolveSortOrders(ordering: Seq[SortOrder], plan: LogicalPlan, throws: Boolean) = { ordering.map { order => // Resolve SortOrder in one round. @@ -427,7 +503,7 @@ class Analyzer( try { val newOrder = order transformUp { case u @ UnresolvedAttribute(nameParts) => - plan.resolve(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u) + plan.resolve(nameParts, resolver).getOrElse(u) case UnresolvedExtractValue(child, fieldName) if child.resolved => ExtractValue(child, fieldName, resolver) } @@ -445,7 +521,7 @@ class Analyzer( * remove these attributes after sorting. */ object ResolveSortReferences extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case s @ Sort(ordering, global, p @ Project(projectList, child)) if !s.resolved && p.resolved => val (newOrdering, missing) = resolveAndFindMissing(ordering, p, child) @@ -460,37 +536,6 @@ class Analyzer( logDebug(s"Failed to find $missing in ${p.output.mkString(", ")}") s // Nothing we can do here. Return original plan. } - case s @ Sort(ordering, global, a @ Aggregate(grouping, aggs, child)) - if !s.resolved && a.resolved => - // A small hack to create an object that will allow us to resolve any references that - // refer to named expressions that are present in the grouping expressions. - val groupingRelation = LocalRelation( - grouping.collect { case ne: NamedExpression => ne.toAttribute } - ) - - // Find sort attributes that are projected away so we can temporarily add them back in. - val (newOrdering, missingAttr) = resolveAndFindMissing(ordering, a, groupingRelation) - - // Find aggregate expressions and evaluate them early, since they can't be evaluated in a - // Sort. 
- val (withAggsRemoved, aliasedAggregateList) = newOrdering.map { - case aggOrdering if aggOrdering.collect { case a: AggregateExpression => a }.nonEmpty => - val aliased = Alias(aggOrdering.child, "_aggOrdering")() - (aggOrdering.copy(child = aliased.toAttribute), Some(aliased)) - - case other => (other, None) - }.unzip - - val missing = missingAttr ++ aliasedAggregateList.flatten - - if (missing.nonEmpty) { - // Add missing grouping exprs and then project them away after the sort. - Project(a.output, - Sort(withAggsRemoved, global, - Aggregate(grouping, aggs ++ missing, child))) - } else { - s // Nothing we can do here. Return original plan. - } } /** @@ -505,7 +550,7 @@ class Analyzer( val newOrdering = resolveSortOrders(ordering, grandchild, throws = true) // Construct a set that contains all of the attributes that we need to evaluate the // ordering. - val requiredAttributes = AttributeSet(newOrdering.filter(_.resolved)) + val requiredAttributes = AttributeSet(newOrdering).filter(_.resolved) // Figure out which ones are missing from the projection, so that we can add them and // remove them after the sort. val missingInProject = requiredAttributes -- child.output @@ -520,27 +565,25 @@ class Analyzer( * Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s. */ object ResolveFunctions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case q: LogicalPlan => q transformExpressions { + case u if !u.childrenResolved => u // Skip until children are resolved. case u @ UnresolvedFunction(name, children, isDistinct) => withPosition(u) { registry.lookupFunction(name, children) match { - // We get an aggregate function built based on AggregateFunction2 interface. - // So, we wrap it in AggregateExpression2. - case agg2: AggregateFunction2 => AggregateExpression2(agg2, Complete, isDistinct) - // Currently, our old aggregate function interface supports SUM(DISTINCT ...) - // and COUTN(DISTINCT ...). - case sumDistinct: SumDistinct => sumDistinct - case countDistinct: CountDistinct => countDistinct - // DISTINCT is not meaningful with Max and Min. - case max: Max if isDistinct => max - case min: Min if isDistinct => min - // For other aggregate functions, DISTINCT keyword is not supported for now. - // Once we converted to the new code path, we will allow using DISTINCT keyword. - case other: AggregateExpression1 if isDistinct => - failAnalysis(s"$name does not support DISTINCT keyword.") - // If it does not have DISTINCT keyword, we will return it as is. + // DISTINCT is not meaningful for a Max or a Min. + case max: Max if isDistinct => + AggregateExpression(max, Complete, isDistinct = false) + case min: Min if isDistinct => + AggregateExpression(min, Complete, isDistinct = false) + // AggregateWindowFunctions are AggregateFunctions that can only be evaluated within + // the context of a Window clause. They do not need to be wrapped in an + // AggregateExpression. + case wf: AggregateWindowFunction => wf + // We get an aggregate function, we need to wrap it in an AggregateExpression. + case agg: AggregateFunction => AggregateExpression(agg, Complete, isDistinct) + // This function is not an aggregate function, just return the resolved one. case other => other } } @@ -552,42 +595,121 @@ class Analyzer( * Turns projections that contain aggregate expressions into aggregations. 
*/ object GlobalAggregates extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case Project(projectList, child) if containsAggregates(projectList) => Aggregate(Nil, projectList, child) } def containsAggregates(exprs: Seq[Expression]): Boolean = { - exprs.foreach(_.foreach { - case agg: AggregateExpression => return true - case _ => - }) - false + // Collect all Windowed Aggregate Expressions. + val windowedAggExprs = exprs.flatMap { expr => + expr.collect { + case WindowExpression(ae: AggregateExpression, _) => ae + } + }.toSet + + // Find the first Aggregate Expression that is not Windowed. + exprs.exists(_.collectFirst { + case ae: AggregateExpression if !windowedAggExprs.contains(ae) => ae + }.isDefined) } } /** - * This rule finds expressions in HAVING clause filters that depend on - * unresolved attributes. It pushes these expressions down to the underlying - * aggregates and then projects them away above the filter. + * This rule finds aggregate expressions that are not in an aggregate operator. For example, + * those in a HAVING clause or ORDER BY clause. These expressions are pushed down to the + * underlying aggregate operator and then projected away after the original operator. */ - object UnresolvedHavingClauseAttributes extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _)) - if aggregate.resolved && containsAggregate(havingCondition) => + object ResolveAggregateFunctions extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case filter @ Filter(havingCondition, + aggregate @ Aggregate(grouping, originalAggExprs, child)) + if aggregate.resolved => + + // Try resolving the condition of the filter as though it is in the aggregate clause + val aggregatedCondition = + Aggregate(grouping, Alias(havingCondition, "havingCondition")() :: Nil, child) + val resolvedOperator = execute(aggregatedCondition) + def resolvedAggregateFilter = + resolvedOperator + .asInstanceOf[Aggregate] + .aggregateExpressions.head + + // If resolution was successful and we see the filter has an aggregate in it, add it to + // the original aggregate operator. + if (resolvedOperator.resolved && containsAggregate(resolvedAggregateFilter)) { + val aggExprsWithHaving = resolvedAggregateFilter +: originalAggExprs + + Project(aggregate.output, + Filter(resolvedAggregateFilter.toAttribute, + aggregate.copy(aggregateExpressions = aggExprsWithHaving))) + } else { + filter + } + + case sort @ Sort(sortOrder, global, aggregate: Aggregate) + if aggregate.resolved => + + // Try resolving the ordering as though it is in the aggregate clause. + try { + val unresolvedSortOrders = sortOrder.filter(s => !s.resolved || containsAggregate(s)) + val aliasedOrdering = unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")()) + val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering) + val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate] + val resolvedAliasedOrdering: Seq[Alias] = + resolvedAggregate.aggregateExpressions.asInstanceOf[Seq[Alias]] + + // If we pass the analysis check, then the ordering expressions should only reference to + // aggregate expressions or grouping expressions, and it's safe to push them down to + // Aggregate. 
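For the HAVING branch of ResolveAggregateFunctions above, a sketch of the rewrite (the table and column names are hypothetical): the resolved filter condition is added as an extra aggregate expression, referenced by the Filter, and then projected away again:
{{{
// SELECT key, sum(v) AS total FROM data GROUP BY key HAVING sum(v) > 10
//
// Filter(sum(v) > 10, Aggregate([key], [key, sum(v) AS total], data))
// is conceptually rewritten to
// Project([key, total],
//   Filter('havingCondition,
//     Aggregate([key], [(sum(v) > 10) AS havingCondition, key, sum(v) AS total], data)))
}}}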
+ checkAnalysis(resolvedAggregate) + + val originalAggExprs = aggregate.aggregateExpressions.map( + CleanupAliases.trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) + + // If the ordering expression is same with original aggregate expression, we don't need + // to push down this ordering expression and can reference the original aggregate + // expression instead. + val needsPushDown = ArrayBuffer.empty[NamedExpression] + val evaluatedOrderings = resolvedAliasedOrdering.zip(sortOrder).map { + case (evaluated, order) => + val index = originalAggExprs.indexWhere { + case Alias(child, _) => child semanticEquals evaluated.child + case other => other semanticEquals evaluated.child + } - val evaluatedCondition = Alias(havingCondition, "havingCondition")() - val aggExprsWithHaving = evaluatedCondition +: originalAggExprs + if (index == -1) { + needsPushDown += evaluated + order.copy(child = evaluated.toAttribute) + } else { + order.copy(child = originalAggExprs(index).toAttribute) + } + } - Project(aggregate.output, - Filter(evaluatedCondition.toAttribute, - aggregate.copy(aggregateExpressions = aggExprsWithHaving))) + val sortOrdersMap = unresolvedSortOrders + .map(new TreeNodeRef(_)) + .zip(evaluatedOrderings) + .toMap + val finalSortOrders = sortOrder.map(s => sortOrdersMap.getOrElse(new TreeNodeRef(s), s)) + + // Since we don't rely on sort.resolved as the stop condition for this rule, + // we need to check this and prevent applying this rule multiple times + if (sortOrder == finalSortOrders) { + sort + } else { + Project(aggregate.output, + Sort(finalSortOrders, global, + aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown))) + } + } catch { + // Attempting to resolve in the aggregate can result in ambiguity. When this happens, + // just return the original plan. + case ae: AnalysisException => sort + } } protected def containsAggregate(condition: Expression): Boolean = { - condition - .collect { case ae: AggregateExpression => ae } - .nonEmpty + condition.find(_.isInstanceOf[AggregateExpression]).isDefined } } @@ -602,7 +724,9 @@ class Analyzer( * [[AnalysisException]] is throw. */ object ResolveGenerate extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case g: Generate if ResolveReferences.containsStar(g.generator.children) => + failAnalysis("Cannot explode *, explode can only be applied on a specific column.") case p: Generate if !p.child.resolved || !p.generator.resolved => p case g: Generate if !g.resolved => g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) @@ -655,7 +779,7 @@ class Analyzer( /** * Construct the output attributes for a [[Generator]], given a list of names. If the list of - * names is empty names are assigned by ordinal (i.e., _c0, _c1, ...) to match Hive's defaults. + * names is empty names are assigned from field names in generator. 
*/ private def makeGeneratorOutput( generator: Generator, @@ -664,13 +788,12 @@ class Analyzer( if (names.length == elementTypes.length) { names.zip(elementTypes).map { - case (name, (t, nullable)) => + case (name, (t, nullable, _)) => AttributeReference(name, t, nullable)() } } else if (names.isEmpty) { - elementTypes.zipWithIndex.map { - // keep the default column names as Hive does _c0, _c1, _cN - case ((t, nullable), i) => AttributeReference(s"_c$i", t, nullable)() + elementTypes.map { + case (t, nullable, name) => AttributeReference(name, t, nullable)() } } else { failAnalysis( @@ -762,29 +885,44 @@ class Analyzer( // Now, we extract regular expressions from expressionsWithWindowFunctions // by using extractExpr. + val seenWindowAggregates = new ArrayBuffer[AggregateExpression] val newExpressionsWithWindowFunctions = expressionsWithWindowFunctions.map { _.transform { // Extracts children expressions of a WindowFunction (input parameters of // a WindowFunction). case wf : WindowFunction => - val newChildren = wf.children.map(extractExpr(_)) + val newChildren = wf.children.map(extractExpr) wf.withNewChildren(newChildren) // Extracts expressions from the partition spec and order spec. case wsc @ WindowSpecDefinition(partitionSpec, orderSpec, _) => - val newPartitionSpec = partitionSpec.map(extractExpr(_)) + val newPartitionSpec = partitionSpec.map(extractExpr) val newOrderSpec = orderSpec.map { so => val newChild = extractExpr(so.child) so.copy(child = newChild) } wsc.copy(partitionSpec = newPartitionSpec, orderSpec = newOrderSpec) + // Extract Windowed AggregateExpression + case we @ WindowExpression( + AggregateExpression(function, mode, isDistinct), + spec: WindowSpecDefinition) => + val newChildren = function.children.map(extractExpr) + val newFunction = function.withNewChildren(newChildren).asInstanceOf[AggregateFunction] + val newAgg = AggregateExpression(newFunction, mode, isDistinct) + seenWindowAggregates += newAgg + WindowExpression(newAgg, spec) + // Extracts AggregateExpression. For example, for SUM(x) - Sum(y) OVER (...), // we need to extract SUM(x). - case agg: AggregateExpression => + case agg: AggregateExpression if !seenWindowAggregates.contains(agg) => val withName = Alias(agg, s"_w${extractedExprBuffer.length}")() extractedExprBuffer += withName withName.toAttribute + + // Extracts other attributes + case attr: Attribute => extractExpr(attr) + }.asInstanceOf[NamedExpression] } @@ -873,6 +1011,7 @@ class Analyzer( // We have to use transformDown at here to make sure the rule of // "Aggregate with Having clause" will be triggered. def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + // Aggregate with Having clause. This rule works with an unresolved Aggregate because // a resolved Aggregate will not have Window Functions. case f @ Filter(condition, a @ Aggregate(groupingExprs, aggregateExprs, child)) @@ -928,7 +1067,8 @@ class Analyzer( * put them into an inner Project and finally project them away at the outer Project. */ object PullOutNondeterministic extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case p if !p.resolved => p // Skip unresolved nodes. case p: Project => p case f: Filter => f @@ -957,59 +1097,66 @@ class Analyzer( } /** - * Removes all still-need-evaluate ordering expressions from sort and use an inner project to - * materialize them, finally use a outer project to project them away to keep the result same. 
- * Then we can make sure we only sort by [[AttributeReference]]s. - * - * As an example, - * {{{ - * Sort('a, 'b + 1, - * Relation('a, 'b)) - * }}} - * will be turned into: - * {{{ - * Project('a, 'b, - * Sort('a, '_sortCondition, - * Project('a, 'b, ('b + 1).as("_sortCondition"), - * Relation('a, 'b)))) - * }}} + * Correctly handle null primitive inputs for UDF by adding extra [[If]] expression to do the + * null check. When user defines a UDF with primitive parameters, there is no way to tell if the + * primitive parameter is null or not, so here we assume the primitive input is null-propagatable + * and we should return null if the input is null. */ - object RemoveEvaluationFromSort extends Rule[LogicalPlan] { - private def hasAlias(expr: Expression) = { - expr.find { - case a: Alias => true - case _ => false - }.isDefined + object HandleNullInputsForUDF extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case p if !p.resolved => p // Skip unresolved nodes. + + case p => p transformExpressionsUp { + + case udf @ ScalaUDF(func, _, inputs, _) => + val parameterTypes = ScalaReflection.getParameterTypes(func) + assert(parameterTypes.length == inputs.length) + + val inputsNullCheck = parameterTypes.zip(inputs) + // TODO: skip null handling for not-nullable primitive inputs after we can completely + // trust the `nullable` information. + // .filter { case (cls, expr) => cls.isPrimitive && expr.nullable } + .filter { case (cls, _) => cls.isPrimitive } + .map { case (_, expr) => IsNull(expr) } + .reduceLeftOption[Expression]((e1, e2) => Or(e1, e2)) + inputsNullCheck.map(If(_, Literal.create(null, udf.dataType), udf)).getOrElse(udf) + } } + } - override def apply(plan: LogicalPlan): LogicalPlan = plan transform { - // The ordering expressions have no effect to the output schema of `Sort`, - // so `Alias`s in ordering expressions are unnecessary and we should remove them. - case s @ Sort(ordering, _, _) if ordering.exists(hasAlias) => - val newOrdering = ordering.map(_.transformUp { - case Alias(child, _) => child - }.asInstanceOf[SortOrder]) - s.copy(order = newOrdering) - - case s @ Sort(ordering, global, child) - if s.expressions.forall(_.resolved) && s.childrenResolved && !s.hasNoEvaluation => - - val (ref, needEval) = ordering.partition(_.child.isInstanceOf[AttributeReference]) - - val namedExpr = needEval.map(_.child match { - case n: NamedExpression => n - case e => Alias(e, "_sortCondition")() - }) - - val newOrdering = ref ++ needEval.zip(namedExpr).map { case (order, ne) => - order.copy(child = ne.toAttribute) - } + /** + * Check and add proper window frames for all window functions. 
+ */ + object ResolveWindowFrame extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case logical: LogicalPlan => logical transformExpressions { + case WindowExpression(wf: WindowFunction, + WindowSpecDefinition(_, _, f: SpecifiedWindowFrame)) + if wf.frame != UnspecifiedFrame && wf.frame != f => + failAnalysis(s"Window Frame $f must match the required frame ${wf.frame}") + case WindowExpression(wf: WindowFunction, + s @ WindowSpecDefinition(_, o, UnspecifiedFrame)) + if wf.frame != UnspecifiedFrame => + WindowExpression(wf, s.copy(frameSpecification = wf.frame)) + case we @ WindowExpression(e, s @ WindowSpecDefinition(_, o, UnspecifiedFrame)) => + val frame = SpecifiedWindowFrame.defaultWindowFrame(o.nonEmpty, acceptWindowFrame = true) + we.copy(windowSpec = s.copy(frameSpecification = frame)) + } + } + } - // Add still-need-evaluate ordering expressions into inner project and then project - // them away after the sort. - Project(child.output, - Sort(newOrdering, global, - Project(child.output ++ namedExpr, child))) + /** + * Check and add order to [[AggregateWindowFunction]]s. + */ + object ResolveWindowOrder extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case logical: LogicalPlan => logical transformExpressions { + case WindowExpression(wf: WindowFunction, spec) if spec.orderSpec.isEmpty => + failAnalysis(s"WindowFunction $wf requires window to be ordered") + case WindowExpression(rank: RankLike, spec) if spec.resolved => + val order = spec.orderSpec.map(_.child) + WindowExpression(rank.withOrder(order), spec) + } } } } @@ -1023,3 +1170,117 @@ object EliminateSubQueries extends Rule[LogicalPlan] { case Subquery(_, child) => child } } + +/** + * Cleans up unnecessary Aliases inside the plan. Basically we only need Alias as a top level + * expression in Project(project list) or Aggregate(aggregate expressions) or + * Window(window expressions). + */ +object CleanupAliases extends Rule[LogicalPlan] { + private def trimAliases(e: Expression): Expression = { + var stop = false + e.transformDown { + // CreateStruct is a special case, we need to retain its top level Aliases as they decide the + // name of StructField. We also need to stop transform down this expression, or the Aliases + // under CreateStruct will be mistakenly trimmed. 
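A sketch of what CleanupAliases keeps and what it trims (the expressions are hypothetical, chosen only to illustrate the comment above):
{{{
// Project list:  (a + 1) AS b            -- top-level Alias kept, it names the output column
// Filter cond:   ((a + 1) AS b) > 1      -- an Alias below other operators adds nothing and
//                                        -- is trimmed to (a + 1) > 1
// CreateStruct:  struct(a AS x, b AS y)  -- Aliases kept, they decide the StructField names
}}}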
+ case c: CreateStruct if !stop => + stop = true + c.copy(children = c.children.map(trimNonTopLevelAliases)) + case c: CreateStructUnsafe if !stop => + stop = true + c.copy(children = c.children.map(trimNonTopLevelAliases)) + case Alias(child, _) if !stop => child + } + } + + def trimNonTopLevelAliases(e: Expression): Expression = e match { + case a: Alias => + Alias(trimAliases(a.child), a.name)(a.exprId, a.qualifiers, a.explicitMetadata) + case other => trimAliases(other) + } + + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case Project(projectList, child) => + val cleanedProjectList = + projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) + Project(cleanedProjectList, child) + + case Aggregate(grouping, aggs, child) => + val cleanedAggs = aggs.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) + Aggregate(grouping.map(trimAliases), cleanedAggs, child) + + case w @ Window(projectList, windowExprs, partitionSpec, orderSpec, child) => + val cleanedWindowExprs = + windowExprs.map(e => trimNonTopLevelAliases(e).asInstanceOf[NamedExpression]) + Window(projectList, cleanedWindowExprs, partitionSpec.map(trimAliases), + orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child) + + case other => + var stop = false + other transformExpressionsDown { + case c: CreateStruct if !stop => + stop = true + c.copy(children = c.children.map(trimNonTopLevelAliases)) + case c: CreateStructUnsafe if !stop => + stop = true + c.copy(children = c.children.map(trimNonTopLevelAliases)) + case Alias(child, _) if !stop => child + } + } +} + +/** + * Computes the current date and time to make sure we return the same result in a single query. + */ +object ComputeCurrentTime extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = { + val dateExpr = CurrentDate() + val timeExpr = CurrentTimestamp() + val currentDate = Literal.create(dateExpr.eval(EmptyRow), dateExpr.dataType) + val currentTime = Literal.create(timeExpr.eval(EmptyRow), timeExpr.dataType) + + plan transformAllExpressions { + case CurrentDate() => currentDate + case CurrentTimestamp() => currentTime + } + } +} + +/** + * Replace the `UpCast` expression by `Cast`, and throw exceptions if the cast may truncate. 
+ */ +object ResolveUpCast extends Rule[LogicalPlan] { + private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = { + throw new AnalysisException(s"Cannot up cast `${from.prettyString}` from " + + s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" + + "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") + + "You can either add an explicit cast to the input data or choose a higher precision " + + "type of the field in the target object") + } + + private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = { + val fromPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(from) + val toPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(to) + toPrecedence > 0 && fromPrecedence > toPrecedence + } + + def apply(plan: LogicalPlan): LogicalPlan = { + plan transformAllExpressions { + case u @ UpCast(child, _, _) if !child.resolved => u + + case UpCast(child, dataType, walkedTypePath) => (child.dataType, dataType) match { + case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) => + fail(child, to, walkedTypePath) + case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) => + fail(child, to, walkedTypePath) + case (from, to) if illegalNumericPrecedence(from, to) => + fail(child, to, walkedTypePath) + case (TimestampType, DateType) => + fail(child, DateType, walkedTypePath) + case (StringType, to: NumericType) => + fail(child, to, walkedTypePath) + case _ => Cast(child, dataType) + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index 5766e6a2dd51..8f4ce74a2ea3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -19,10 +19,11 @@ package org.apache.spark.sql.catalyst.analysis import java.util.concurrent.ConcurrentHashMap -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{TableIdentifier, CatalystConf, EmptyConf} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery} @@ -41,11 +42,9 @@ trait Catalog { val conf: CatalystConf - def tableExists(tableIdentifier: Seq[String]): Boolean + def tableExists(tableIdent: TableIdentifier): Boolean - def lookupRelation( - tableIdentifier: Seq[String], - alias: Option[String] = None): LogicalPlan + def lookupRelation(tableIdent: TableIdentifier, alias: Option[String] = None): LogicalPlan /** * Returns tuples of (tableName, isTemporary) for all tables in the given database. @@ -55,68 +54,59 @@ trait Catalog { def refreshTable(tableIdent: TableIdentifier): Unit - def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit + def registerTable(tableIdent: TableIdentifier, plan: LogicalPlan): Unit - def unregisterTable(tableIdentifier: Seq[String]): Unit + def unregisterTable(tableIdent: TableIdentifier): Unit def unregisterAllTables(): Unit - protected def processTableIdentifier(tableIdentifier: Seq[String]): Seq[String] = { - if (conf.caseSensitiveAnalysis) { - tableIdentifier - } else { - tableIdentifier.map(_.toLowerCase) + /** + * Get the table name of TableIdentifier for temporary tables. 
+ */ + protected def getTableName(tableIdent: TableIdentifier): String = { + // It is not allowed to specify database name for temporary tables. + // We check it here and throw exception if database is defined. + if (tableIdent.database.isDefined) { + throw new AnalysisException("Specifying database name or other qualifiers are not allowed " + + "for temporary tables. If the table name has dots (.) in it, please quote the " + + "table name with backticks (`).") } - } - - protected def getDbTableName(tableIdent: Seq[String]): String = { - val size = tableIdent.size - if (size <= 2) { - tableIdent.mkString(".") + if (conf.caseSensitiveAnalysis) { + tableIdent.table } else { - tableIdent.slice(size - 2, size).mkString(".") + tableIdent.table.toLowerCase } } - - protected def getDBTable(tableIdent: Seq[String]) : (Option[String], String) = { - (tableIdent.lift(tableIdent.size - 2), tableIdent.last) - } } class SimpleCatalog(val conf: CatalystConf) extends Catalog { - val tables = new ConcurrentHashMap[String, LogicalPlan] + private[this] val tables = new ConcurrentHashMap[String, LogicalPlan] - override def registerTable( - tableIdentifier: Seq[String], - plan: LogicalPlan): Unit = { - val tableIdent = processTableIdentifier(tableIdentifier) - tables.put(getDbTableName(tableIdent), plan) + override def registerTable(tableIdent: TableIdentifier, plan: LogicalPlan): Unit = { + tables.put(getTableName(tableIdent), plan) } - override def unregisterTable(tableIdentifier: Seq[String]): Unit = { - val tableIdent = processTableIdentifier(tableIdentifier) - tables.remove(getDbTableName(tableIdent)) + override def unregisterTable(tableIdent: TableIdentifier): Unit = { + tables.remove(getTableName(tableIdent)) } override def unregisterAllTables(): Unit = { tables.clear() } - override def tableExists(tableIdentifier: Seq[String]): Boolean = { - val tableIdent = processTableIdentifier(tableIdentifier) - tables.containsKey(getDbTableName(tableIdent)) + override def tableExists(tableIdent: TableIdentifier): Boolean = { + tables.containsKey(getTableName(tableIdent)) } override def lookupRelation( - tableIdentifier: Seq[String], + tableIdent: TableIdentifier, alias: Option[String] = None): LogicalPlan = { - val tableIdent = processTableIdentifier(tableIdentifier) - val tableFullName = getDbTableName(tableIdent) - val table = tables.get(tableFullName) + val tableName = getTableName(tableIdent) + val table = tables.get(tableName) if (table == null) { - sys.error(s"Table Not Found: $tableFullName") + throw new NoSuchTableException } - val tableWithQualifiers = Subquery(tableIdent.last, table) + val tableWithQualifiers = Subquery(tableName, table) // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are // properly qualified with this alias. @@ -124,11 +114,7 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { } override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { - val result = ArrayBuffer.empty[(String, Boolean)] - for (name <- tables.keySet()) { - result += ((name, true)) - } - result + tables.keySet().asScala.map(_ -> true).toSeq } override def refreshTable(tableIdent: TableIdentifier): Unit = { @@ -143,64 +129,51 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { * lost when the JVM exits. */ trait OverrideCatalog extends Catalog { + private[this] val overrides = new ConcurrentHashMap[String, LogicalPlan] - // TODO: This doesn't work when the database changes... 
- val overrides = new mutable.HashMap[(Option[String], String), LogicalPlan]() + private def getOverriddenTable(tableIdent: TableIdentifier): Option[LogicalPlan] = { + if (tableIdent.database.isDefined) { + None + } else { + Option(overrides.get(getTableName(tableIdent))) + } + } - abstract override def tableExists(tableIdentifier: Seq[String]): Boolean = { - val tableIdent = processTableIdentifier(tableIdentifier) - overrides.get(getDBTable(tableIdent)) match { + abstract override def tableExists(tableIdent: TableIdentifier): Boolean = { + getOverriddenTable(tableIdent) match { case Some(_) => true - case None => super.tableExists(tableIdentifier) + case None => super.tableExists(tableIdent) } } abstract override def lookupRelation( - tableIdentifier: Seq[String], + tableIdent: TableIdentifier, alias: Option[String] = None): LogicalPlan = { - val tableIdent = processTableIdentifier(tableIdentifier) - val overriddenTable = overrides.get(getDBTable(tableIdent)) - val tableWithQualifers = overriddenTable.map(r => Subquery(tableIdent.last, r)) + getOverriddenTable(tableIdent) match { + case Some(table) => + val tableName = getTableName(tableIdent) + val tableWithQualifiers = Subquery(tableName, table) - // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are - // properly qualified with this alias. - val withAlias = - tableWithQualifers.map(r => alias.map(a => Subquery(a, r)).getOrElse(r)) + // If an alias was specified by the lookup, wrap the plan in a sub-query so that attributes + // are properly qualified with this alias. + alias.map(a => Subquery(a, tableWithQualifiers)).getOrElse(tableWithQualifiers) - withAlias.getOrElse(super.lookupRelation(tableIdentifier, alias)) + case None => super.lookupRelation(tableIdent, alias) + } } abstract override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { - val dbName = if (conf.caseSensitiveAnalysis) { - databaseName - } else { - if (databaseName.isDefined) Some(databaseName.get.toLowerCase) else None - } - - val temporaryTables = overrides.filter { - // If a temporary table does not have an associated database, we should return its name. - case ((None, _), _) => true - // If a temporary table does have an associated database, we should return it if the database - // matches the given database name. 
- case ((db: Some[String], _), _) if db == dbName => true - case _ => false - }.map { - case ((_, tableName), _) => (tableName, true) - }.toSeq - - temporaryTables ++ super.getTables(databaseName) + overrides.keySet().asScala.map(_ -> true).toSeq ++ super.getTables(databaseName) } - override def registerTable( - tableIdentifier: Seq[String], - plan: LogicalPlan): Unit = { - val tableIdent = processTableIdentifier(tableIdentifier) - overrides.put(getDBTable(tableIdent), plan) + override def registerTable(tableIdent: TableIdentifier, plan: LogicalPlan): Unit = { + overrides.put(getTableName(tableIdent), plan) } - override def unregisterTable(tableIdentifier: Seq[String]): Unit = { - val tableIdent = processTableIdentifier(tableIdentifier) - overrides.remove(getDBTable(tableIdent)) + override def unregisterTable(tableIdent: TableIdentifier): Unit = { + if (tableIdent.database.isEmpty) { + overrides.remove(getTableName(tableIdent)) + } } override def unregisterAllTables(): Unit = { @@ -216,12 +189,12 @@ object EmptyCatalog extends Catalog { override val conf: CatalystConf = EmptyConf - override def tableExists(tableIdentifier: Seq[String]): Boolean = { + override def tableExists(tableIdent: TableIdentifier): Boolean = { throw new UnsupportedOperationException } override def lookupRelation( - tableIdentifier: Seq[String], + tableIdent: TableIdentifier, alias: Option[String] = None): LogicalPlan = { throw new UnsupportedOperationException } @@ -230,15 +203,17 @@ object EmptyCatalog extends Catalog { throw new UnsupportedOperationException } - override def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = { + override def registerTable(tableIdent: TableIdentifier, plan: LogicalPlan): Unit = { throw new UnsupportedOperationException } - override def unregisterTable(tableIdentifier: Seq[String]): Unit = { + override def unregisterTable(tableIdent: TableIdentifier): Unit = { throw new UnsupportedOperationException } - override def unregisterAllTables(): Unit = {} + override def unregisterAllTables(): Unit = { + throw new UnsupportedOperationException + } override def refreshTable(tableIdent: TableIdentifier): Unit = { throw new UnsupportedOperationException diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 187b238045f8..440f67991380 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, AggregateExpression} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -47,6 +48,10 @@ trait CheckAnalysis { // We transform up and order the rules so as to catch the first possible failure instead // of the result of cascading resolution failures. 
plan.foreachUp { + case p if p.analyzed => // Skip already analyzed sub-plans + + case u: UnresolvedRelation => + u.failAnalysis(s"Table not found: ${u.tableIdentifier}") case operator: LogicalPlan => operator transformExpressionsUp { @@ -65,15 +70,32 @@ trait CheckAnalysis { failAnalysis( s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}") - case WindowExpression(UnresolvedWindowFunction(name, _), _) => - failAnalysis( - s"Could not resolve window function '$name'. " + - "Note that, using window functions currently requires a HiveContext") + case w @ WindowExpression(AggregateExpression(_, _, true), _) => + failAnalysis(s"Distinct window functions are not supported: $w") + + case w @ WindowExpression(_: OffsetWindowFunction, WindowSpecDefinition(_, order, + SpecifiedWindowFrame(frame, + FrameBoundary(l), + FrameBoundary(h)))) + if order.isEmpty || frame != RowFrame || l != h => + failAnalysis("An offset window function can only be evaluated in an ordered " + + s"row-based window frame with a single offset: $w") + + case w @ WindowExpression(e, s) => + // Only allow window functions with an aggregate expression or an offset window + // function. + e match { + case _: AggregateExpression | _: OffsetWindowFunction | _: AggregateWindowFunction => + case _ => + failAnalysis(s"Expression '$e' not supported within a window function.") + } + // Make sure the window specification is valid. + s.validate match { + case Some(m) => + failAnalysis(s"Window specification $s is not valid because $m") + case None => w + } - case w @ WindowExpression(windowFunction, windowSpec) if windowSpec.validate.nonEmpty => - // The window spec is not valid. - val reason = windowSpec.validate.get - failAnalysis(s"Window specification $windowSpec is not valid because $reason") } operator match { @@ -104,25 +126,50 @@ trait CheckAnalysis { case Aggregate(groupingExprs, aggregateExprs, child) => def checkValidAggregateExpression(expr: Expression): Unit = expr match { - case _: AggregateExpression => // OK + case aggExpr: AggregateExpression => + aggExpr.aggregateFunction.children.foreach { child => + child.foreach { + case agg: AggregateExpression => + failAnalysis( + s"It is not allowed to use an aggregate function in the argument of " + + s"another aggregate function. Please use the inner aggregate function " + + s"in a sub-query.") + case other => // OK + } + + if (!child.deterministic) { + failAnalysis( + s"nondeterministic expression ${expr.prettyString} should not " + + s"appear in the arguments of an aggregate function.") + } + } case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) => failAnalysis( s"expression '${e.prettyString}' is neither present in the group by, " + s"nor is it an aggregate function. 
" + - "Add to group by or wrap in first() if you don't care which value you get.") + "Add to group by or wrap in first() (or first_value) if you don't care " + + "which value you get.") case e if groupingExprs.exists(_.semanticEquals(e)) => // OK case e if e.references.isEmpty => // OK case e => e.children.foreach(checkValidAggregateExpression) } - def checkValidGroupingExprs(expr: Expression): Unit = expr.dataType match { - case BinaryType => - failAnalysis(s"binary type expression ${expr.prettyString} cannot be used " + - "in grouping expression") - case m: MapType => - failAnalysis(s"map type expression ${expr.prettyString} cannot be used " + - "in grouping expression") - case _ => // OK + def checkValidGroupingExprs(expr: Expression): Unit = { + // Check if the data type of expr is orderable. + if (!RowOrdering.isOrderable(expr.dataType)) { + failAnalysis( + s"expression ${expr.prettyString} cannot be used as a grouping expression " + + s"because its data type ${expr.dataType.simpleString} is not a orderable " + + s"data type.") + } + + if (!expr.deterministic) { + // This is just a sanity check, our analysis rule PullOutNondeterministic should + // already pull out those nondeterministic expressions and evaluate them in + // a Project node. + failAnalysis(s"nondeterministic expression ${expr.prettyString} should not " + + s"appear in grouping expression.") + } } aggregateExprs.foreach(checkValidAggregateExpression) @@ -136,6 +183,12 @@ trait CheckAnalysis { } } + case s @ SetOperation(left, right) if left.output.length != right.output.length => + failAnalysis( + s"${s.nodeName} can only be performed on tables with the same number of columns, " + + s"but the left table has ${left.output.length} columns and the right has " + + s"${right.output.length}") + case _ => // Fallbacks to the following checks } @@ -168,9 +221,12 @@ trait CheckAnalysis { s"unresolved operator ${operator.simpleString}") case o if o.expressions.exists(!_.deterministic) && - !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] => + !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] && + !o.isInstanceOf[Aggregate] && !o.isInstanceOf[Window] => + // The rule above is used to check Aggregate operator. failAnalysis( - s"""nondeterministic expressions are only allowed in Project or Filter, found: + s"""nondeterministic expressions are only allowed in + |Project, Filter, Aggregate or Window, found: | ${o.expressions.map(_.prettyString).mkString(",")} |in operator ${operator.simpleString} """.stripMargin) @@ -179,5 +235,7 @@ trait CheckAnalysis { } } extendedCheckRules.foreach(_(plan)) + + plan.foreach(_.setAnalyzed()) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala new file mode 100644 index 000000000000..4e7d1341028c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala @@ -0,0 +1,272 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.CatalystConf +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types.IntegerType + +/** + * This rule rewrites an aggregate query with distinct aggregations into an expanded double + * aggregation in which the regular aggregation expressions and every distinct clause is aggregated + * in a separate group. The results are then combined in a second aggregate. + * + * For example (in scala): + * {{{ + * val data = Seq( + * ("a", "ca1", "cb1", 10), + * ("a", "ca1", "cb2", 5), + * ("b", "ca1", "cb1", 13)) + * .toDF("key", "cat1", "cat2", "value") + * data.registerTempTable("data") + * + * val agg = data.groupBy($"key") + * .agg( + * countDistinct($"cat1").as("cat1_cnt"), + * countDistinct($"cat2").as("cat2_cnt"), + * sum($"value").as("total")) + * }}} + * + * This translates to the following (pseudo) logical plan: + * {{{ + * Aggregate( + * key = ['key] + * functions = [COUNT(DISTINCT 'cat1), + * COUNT(DISTINCT 'cat2), + * sum('value)] + * output = ['key, 'cat1_cnt, 'cat2_cnt, 'total]) + * LocalTableScan [...] + * }}} + * + * This rule rewrites this logical plan to the following (pseudo) logical plan: + * {{{ + * Aggregate( + * key = ['key] + * functions = [count(if (('gid = 1)) 'cat1 else null), + * count(if (('gid = 2)) 'cat2 else null), + * first(if (('gid = 0)) 'total else null) ignore nulls] + * output = ['key, 'cat1_cnt, 'cat2_cnt, 'total]) + * Aggregate( + * key = ['key, 'cat1, 'cat2, 'gid] + * functions = [sum('value)] + * output = ['key, 'cat1, 'cat2, 'gid, 'total]) + * Expand( + * projections = [('key, null, null, 0, cast('value as bigint)), + * ('key, 'cat1, null, 1, null), + * ('key, null, 'cat2, 2, null)] + * output = ['key, 'cat1, 'cat2, 'gid, 'value]) + * LocalTableScan [...] + * }}} + * + * The rule does the following things here: + * 1. Expand the data. There are three aggregation groups in this query: + * i. the non-distinct group; + * ii. the distinct 'cat1 group; + * iii. the distinct 'cat2 group. + * An expand operator is inserted to expand the child data for each group. The expand will null + * out all unused columns for the given group; this must be done in order to ensure correctness + * later on. Groups can by identified by a group id (gid) column added by the expand operator. + * 2. De-duplicate the distinct paths and aggregate the non-aggregate path. The group by clause of + * this aggregate consists of the original group by clause, all the requested distinct columns + * and the group id. Both de-duplication of distinct column and the aggregation of the + * non-distinct group take advantage of the fact that we group by the group id (gid) and that we + * have nulled out all non-relevant columns for the the given group. + * 3. 
Aggregating the distinct groups and combining this with the results of the non-distinct + * aggregation. In this step we use the group id to filter the inputs for the aggregate + * functions. The results of the non-distinct group are 'aggregated' by using the first operator; + * it might be more elegant to use the native UDAF merge mechanism for this in the future. + * + * This rule duplicates the input data two or more times (# distinct groups + an optional + * non-distinct group). This will put quite a bit of memory pressure on the used aggregate and + * exchange operators. Keeping the number of distinct groups as low as possible should be a priority; + * we could improve this in the current rule by applying more advanced expression canonicalization + * techniques. + */ +case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalPlan] { + + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case p if !p.resolved => p + // We need to wait until this Aggregate operator is resolved. + case a: Aggregate => rewrite(a) + case p => p + } + + def rewrite(a: Aggregate): Aggregate = { + + // Collect all aggregate expressions. + val aggExpressions = a.aggregateExpressions.flatMap { e => + e.collect { + case ae: AggregateExpression => ae + } + } + + // Extract distinct aggregate expressions. + val distinctAggGroups = aggExpressions + .filter(_.isDistinct) + .groupBy(_.aggregateFunction.children.toSet) + + // Aggregation strategy can handle the query with single distinct + if (distinctAggGroups.size > 1) { + // Create the attributes for the grouping id and the group by clause. + val gid = new AttributeReference("gid", IntegerType, false)() + val groupByMap = a.groupingExpressions.collect { + case ne: NamedExpression => ne -> ne.toAttribute + case e => e -> new AttributeReference(e.prettyString, e.dataType, e.nullable)() + } + val groupByAttrs = groupByMap.map(_._2) + + // Functions used to modify aggregate functions and their inputs. + def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e)) + def patchAggregateFunctionChildren( + af: AggregateFunction)( + attrs: Expression => Expression): AggregateFunction = { + af.withNewChildren(af.children.map { + case afc => attrs(afc) + }).asInstanceOf[AggregateFunction] + } + + // Setup unique distinct aggregate children. + val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct + val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair) + val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2) + + // Setup expand & aggregate operators for distinct aggregate expressions. + val distinctAggChildAttrLookup = distinctAggChildAttrMap.toMap + val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map { + case ((group, expressions), i) => + val id = Literal(i + 1) + + // Expand projection + val projection = distinctAggChildren.map { + case e if group.contains(e) => e + case e => nullify(e) + } :+ id + + // Final aggregate + val operators = expressions.map { e => + val af = e.aggregateFunction + val naf = patchAggregateFunctionChildren(af) { x => + evalWithinGroup(id, distinctAggChildAttrLookup(x)) + } + (e, e.copy(aggregateFunction = naf, isDistinct = false)) + } + + (projection, operators) + } + + // Setup expand for the 'regular' aggregate expressions.
+ val regularAggExprs = aggExpressions.filter(!_.isDistinct) + val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct + val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair) + + // Setup aggregates for 'regular' aggregate expressions. + val regularGroupId = Literal(0) + val regularAggChildAttrLookup = regularAggChildAttrMap.toMap + val regularAggOperatorMap = regularAggExprs.map { e => + // Perform the actual aggregation in the initial aggregate. + val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup) + val operator = Alias(e.copy(aggregateFunction = af), e.prettyString)() + + // Select the result of the first aggregate in the last aggregate. + val result = AggregateExpression( + aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute), Literal(true)), + mode = Complete, + isDistinct = false) + + // Some aggregate functions (COUNT) have the special property that they can return a + // non-null result without any input. We need to make sure we return a result in this case. + val resultWithDefault = af.defaultResult match { + case Some(lit) => Coalesce(Seq(result, lit)) + case None => result + } + + // Return a Tuple3 containing: + // i. The original aggregate expression (used for look ups). + // ii. The actual aggregation operator (used in the first aggregate). + // iii. The operator that selects and returns the result (used in the second aggregate). + (e, operator, resultWithDefault) + } + + // Construct the regular aggregate input projection only if we need one. + val regularAggProjection = if (regularAggExprs.nonEmpty) { + Seq(a.groupingExpressions ++ + distinctAggChildren.map(nullify) ++ + Seq(regularGroupId) ++ + regularAggChildren) + } else { + Seq.empty[Seq[Expression]] + } + + // Construct the distinct aggregate input projections. + val regularAggNulls = regularAggChildren.map(nullify) + val distinctAggProjections = distinctAggOperatorMap.map { + case (projection, _) => + a.groupingExpressions ++ + projection ++ + regularAggNulls + } + + // Construct the expand operator. + val expand = Expand( + regularAggProjection ++ distinctAggProjections, + groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.map(_._2), + a.child) + + // Construct the first aggregate operator. This de-duplicates the all the children of + // distinct operators, and applies the regular aggregate operators. + val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildAttrs :+ gid + val firstAggregate = Aggregate( + firstAggregateGroupBy, + firstAggregateGroupBy ++ regularAggOperatorMap.map(_._2), + expand) + + // Construct the second aggregate + val transformations: Map[Expression, Expression] = + (distinctAggOperatorMap.flatMap(_._2) ++ + regularAggOperatorMap.map(e => (e._1, e._3))).toMap + + val patchedAggExpressions = a.aggregateExpressions.map { e => + e.transformDown { + case e: Expression => + // The same GROUP BY clauses can have different forms (different names for instance) in + // the groupBy and aggregate expressions of an aggregate. This makes a map lookup + // tricky. So we do a linear search for a semantically equal group by expression. 
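A brief, hypothetical illustration of why the lookup described in the comment above uses a linear search with semanticEquals rather than a plain Map:
{{{
// GROUP BY data.key   ...   SELECT key AS k, count(DISTINCT cat1), count(DISTINCT cat2)
//
// 'data.key (recorded in groupByMap) and 'key (inside the aggregate list) are not identical
// trees, but they refer to the same underlying attribute, so the rewrite matches them with
// semanticEquals instead of a structural Map lookup, which would treat them as distinct keys.
}}}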
+ groupByMap + .find(ge => e.semanticEquals(ge._1)) + .map(_._2) + .getOrElse(transformations.getOrElse(e, e)) + }.asInstanceOf[NamedExpression] + } + Aggregate(groupByAttrs, patchedAggExpressions, firstAggregate) + } else { + a + } + } + + private def nullify(e: Expression) = Literal.create(null, e.dataType) + + private def expressionAttributePair(e: Expression) = + // We are creating a new reference here instead of reusing the attribute in case of a + // NamedExpression. This is done to prevent collisions between distinct and regular aggregate + // children, in this case attribute reuse causes the input of the regular aggregate to bound to + // the (nulled out) input of the distinct aggregate. + e -> new AttributeReference(e.prettyString, e.dataType, true)() +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index bc0846646174..12c24cc76822 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -24,6 +24,7 @@ import scala.util.{Failure, Success, Try} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.util.StringKeyHashMap @@ -51,23 +52,37 @@ class SimpleFunctionRegistry extends FunctionRegistry { private val functionBuilders = StringKeyHashMap[(ExpressionInfo, FunctionBuilder)](caseSensitive = false) - override def registerFunction(name: String, info: ExpressionInfo, builder: FunctionBuilder) - : Unit = { + override def registerFunction( + name: String, + info: ExpressionInfo, + builder: FunctionBuilder): Unit = synchronized { functionBuilders.put(name, (info, builder)) } override def lookupFunction(name: String, children: Seq[Expression]): Expression = { - val func = functionBuilders.get(name).map(_._2).getOrElse { - throw new AnalysisException(s"undefined function $name") + val func = synchronized { + functionBuilders.get(name).map(_._2).getOrElse { + throw new AnalysisException(s"undefined function $name") + } } func(children) } - override def listFunction(): Seq[String] = functionBuilders.iterator.map(_._1).toList.sorted + override def listFunction(): Seq[String] = synchronized { + functionBuilders.iterator.map(_._1).toList.sorted + } - override def lookupFunction(name: String): Option[ExpressionInfo] = { + override def lookupFunction(name: String): Option[ExpressionInfo] = synchronized { functionBuilders.get(name).map(_._1) } + + def copy(): SimpleFunctionRegistry = synchronized { + val registry = new SimpleFunctionRegistry + functionBuilders.iterator.foreach { case (name, (info, builder)) => + registry.registerFunction(name, info, builder) + } + registry + } } /** @@ -128,6 +143,7 @@ object FunctionRegistry { expression[Ceil]("ceil"), expression[Ceil]("ceiling"), expression[Cos]("cos"), + expression[Cosh]("cosh"), expression[Conv]("conv"), expression[EulerNumber]("e"), expression[Exp]("exp"), @@ -162,13 +178,26 @@ object FunctionRegistry { expression[ToRadians]("radians"), // aggregate functions + expression[HyperLogLogPlusPlus]("approx_count_distinct"), expression[Average]("avg"), + expression[Corr]("corr"), expression[Count]("count"), expression[First]("first"), + 
expression[First]("first_value"), expression[Last]("last"), + expression[Last]("last_value"), expression[Max]("max"), + expression[Average]("mean"), expression[Min]("min"), + expression[StddevSamp]("stddev"), + expression[StddevPop]("stddev_pop"), + expression[StddevSamp]("stddev_samp"), expression[Sum]("sum"), + expression[VarianceSamp]("variance"), + expression[VariancePop]("var_pop"), + expression[VarianceSamp]("var_samp"), + expression[Skewness]("skewness"), + expression[Kurtosis]("kurtosis"), // string functions expression[Ascii]("ascii"), @@ -177,8 +206,11 @@ object FunctionRegistry { expression[ConcatWs]("concat_ws"), expression[Encode]("encode"), expression[Decode]("decode"), + expression[FindInSet]("find_in_set"), expression[FormatNumber]("format_number"), + expression[GetJsonObject]("get_json_object"), expression[InitCap]("initcap"), + expression[JsonTuple]("json_tuple"), expression[Lower]("lcase"), expression[Lower]("lower"), expression[Length]("length"), @@ -201,6 +233,7 @@ object FunctionRegistry { expression[Substring]("substr"), expression[Substring]("substring"), expression[SubstringIndex]("substring_index"), + expression[StringTranslate]("translate"), expression[StringTrim]("trim"), expression[UnBase64]("unbase64"), expression[Upper]("ucase"), @@ -211,6 +244,7 @@ object FunctionRegistry { expression[AddMonths]("add_months"), expression[CurrentDate]("current_date"), expression[CurrentTimestamp]("current_timestamp"), + expression[CurrentTimestamp]("now"), expression[DateDiff]("datediff"), expression[DateAdd]("date_add"), expression[DateFormatClass]("date_format"), @@ -229,6 +263,7 @@ object FunctionRegistry { expression[Quarter]("quarter"), expression[Second]("second"), expression[ToDate]("to_date"), + expression[ToUnixTimestamp]("to_unix_timestamp"), expression[ToUTCTimestamp]("to_utc_timestamp"), expression[TruncDate]("trunc"), expression[UnixTimestamp]("unix_timestamp"), @@ -238,6 +273,7 @@ object FunctionRegistry { // collection functions expression[Size]("size"), expression[SortArray]("sort_array"), + expression[ArrayContains]("array_contains"), // misc functions expression[Crc32]("crc32"), @@ -246,10 +282,21 @@ object FunctionRegistry { expression[Sha1]("sha1"), expression[Sha2]("sha2"), expression[SparkPartitionID]("spark_partition_id"), - expression[InputFileName]("input_file_name") + expression[InputFileName]("input_file_name"), + expression[MonotonicallyIncreasingID]("monotonically_increasing_id"), + + // window functions + expression[Lead]("lead"), + expression[Lag]("lag"), + expression[RowNumber]("row_number"), + expression[CumeDist]("cume_dist"), + expression[NTile]("ntile"), + expression[Rank]("rank"), + expression[DenseRank]("dense_rank"), + expression[PercentRank]("percent_rank") ) - val builtin: FunctionRegistry = { + val builtin: SimpleFunctionRegistry = { val fr = new SimpleFunctionRegistry expressions.foreach { case (name, (info, builder)) => fr.registerFunction(name, info, builder) } fr diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 422d42374702..dbcbd6854b47 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import javax.annotation.Nullable import org.apache.spark.sql.catalyst.expressions._ 
+import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ @@ -52,7 +53,7 @@ object HiveTypeCoercion { // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. // The conversion for integral and floating point types have a linear widening hierarchy: - private val numericPrecedence = + private[sql] val numericPrecedence = IndexedSeq( ByteType, ShortType, @@ -144,7 +145,8 @@ object HiveTypeCoercion { * instances higher in the query tree. */ object PropagateTypes extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + // No propagation required for leaf nodes. case q: LogicalPlan if q.children.isEmpty => q @@ -163,7 +165,7 @@ object HiveTypeCoercion { // Leave the same if the dataTypes match. case Some(newType) if a.dataType == newType.dataType => a case Some(newType) => - logDebug(s"Promoting $a to $newType in ${q.simpleString}}") + logDebug(s"Promoting $a to $newType in ${q.simpleString}") newType } } @@ -202,6 +204,7 @@ object HiveTypeCoercion { planName: String, left: LogicalPlan, right: LogicalPlan): (LogicalPlan, LogicalPlan) = { + require(left.output.length == right.output.length) val castedTypes = left.output.zip(right.output).map { case (lhs, rhs) if lhs.dataType != rhs.dataType => @@ -225,16 +228,13 @@ object HiveTypeCoercion { } } - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case u @ Union(left, right) if u.childrenResolved && !u.resolved => - val (newLeft, newRight) = widenOutputTypes(u.nodeName, left, right) - Union(newLeft, newRight) - case e @ Except(left, right) if e.childrenResolved && !e.resolved => - val (newLeft, newRight) = widenOutputTypes(e.nodeName, left, right) - Except(newLeft, newRight) - case i @ Intersect(left, right) if i.childrenResolved && !i.resolved => - val (newLeft, newRight) = widenOutputTypes(i.nodeName, left, right) - Intersect(newLeft, newRight) + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case p if p.analyzed => p + + case s @ SetOperation(left, right) if s.childrenResolved + && left.output.length == right.output.length && !s.resolved => + val (newLeft, newRight) = widenOutputTypes(s.nodeName, left, right) + s.makeCopy(Array(newLeft, newRight)) } } @@ -242,7 +242,7 @@ object HiveTypeCoercion { * Promotes strings that appear in arithmetic expressions. */ object PromoteStrings extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. 
case e if !e.childrenResolved => e @@ -281,6 +281,12 @@ object HiveTypeCoercion { case p @ BinaryComparison(left @ DateType(), right @ TimestampType()) => p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType))) + // Checking NullType + case p @ BinaryComparison(left @ StringType(), right @ NullType()) => + p.makeCopy(Array(left, Literal.create(null, StringType))) + case p @ BinaryComparison(left @ NullType(), right @ StringType()) => + p.makeCopy(Array(Literal.create(null, StringType), right)) + case p @ BinaryComparison(left @ StringType(), right) if right.dataType != StringType => p.makeCopy(Array(Cast(left, DoubleType), right)) case p @ BinaryComparison(left, right @ StringType()) if left.dataType != StringType => @@ -296,21 +302,39 @@ object HiveTypeCoercion { i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType)))) case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) - case SumDistinct(e @ StringType()) => Sum(Cast(e, DoubleType)) case Average(e @ StringType()) => Average(Cast(e, DoubleType)) + case StddevPop(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => + StddevPop(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) + case StddevSamp(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => + StddevSamp(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) + case VariancePop(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => + VariancePop(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) + case VarianceSamp(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => + VarianceSamp(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) + case Skewness(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => + Skewness(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) + case Kurtosis(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => + Kurtosis(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) } } /** - * Convert all expressions in in() list to the left operator type + * Convert the value and in list expressions to the common operator type + * by looking at all the argument types and finding the closest one that + * all the arguments can be cast to. When no common operator type is found + * the original expression will be returned and an Analysis Exception will + * be raised at type checking phase. */ object InConversion extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. 
case e if !e.childrenResolved => e case i @ In(a, b) if b.exists(_.dataType != a.dataType) => - i.makeCopy(Array(a, b.map(Cast(_, a.dataType)))) + findWiderCommonType(i.children.map(_.dataType)) match { + case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType))) + case None => i + } } } @@ -368,47 +392,60 @@ object HiveTypeCoercion { DecimalType.bounded(range + scale, scale) } - private def changePrecision(e: Expression, dataType: DataType): Expression = { - ChangeDecimalPrecision(Cast(e, dataType)) + private def promotePrecision(e: Expression, dataType: DataType): Expression = { + PromotePrecision(Cast(e, dataType)) } - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + // fix decimal precision for expressions case q => q.transformExpressions { // Skip nodes whose children have not been resolved yet case e if !e.childrenResolved => e // Skip nodes who is already promoted - case e: BinaryArithmetic if e.left.isInstanceOf[ChangeDecimalPrecision] => e + case e: BinaryArithmetic if e.left.isInstanceOf[PromotePrecision] => e case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) - Add(changePrecision(e1, dt), changePrecision(e2, dt)) + CheckOverflow(Add(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt) case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) - Subtract(changePrecision(e1, dt), changePrecision(e2, dt)) + CheckOverflow(Subtract(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt) case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val dt = DecimalType.bounded(p1 + p2 + 1, s1 + s2) - Multiply(changePrecision(e1, dt), changePrecision(e2, dt)) + val resultType = DecimalType.bounded(p1 + p2 + 1, s1 + s2) + val widerType = widerDecimalType(p1, s1, p2, s2) + CheckOverflow(Multiply(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), + resultType) case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val dt = DecimalType.bounded(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1)) - Divide(changePrecision(e1, dt), changePrecision(e2, dt)) + var intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2) + var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + p2 + 1)) + val diff = (intDig + decDig) - DecimalType.MAX_SCALE + if (diff > 0) { + decDig -= diff / 2 + 1 + intDig = DecimalType.MAX_SCALE - decDig + } + val resultType = DecimalType.bounded(intDig + decDig, decDig) + val widerType = widerDecimalType(p1, s1, p2, s2) + CheckOverflow(Divide(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), + resultType) case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) // resultType may have lower precision, so we cast them into wider type first. 
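A worked example (not from the patch) of the result types computed by the DecimalPrecision formulas above; the real rule additionally caps precision and scale through DecimalType.bounded and wraps the arithmetic in CheckOverflow.

import org.apache.spark.sql.types.DecimalType

// decimal(5,2) combined with decimal(7,4)
val (p1, s1) = (5, 2)
val (p2, s2) = (7, 4)

// Add/Subtract: scale = max(s1,s2); precision = max(p1-s1, p2-s2) + scale + 1
val addScale     = math.max(s1, s2)                            // 4
val addPrecision = math.max(p1 - s1, p2 - s2) + addScale + 1   // 8
val addType      = DecimalType(addPrecision, addScale)         // decimal(8,4)

// Multiply: precision = p1 + p2 + 1; scale = s1 + s2
val mulType = DecimalType(p1 + p2 + 1, s1 + s2)                // decimal(13,6)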
val widerType = widerDecimalType(p1, s1, p2, s2) - Cast(Remainder(changePrecision(e1, widerType), changePrecision(e2, widerType)), + CheckOverflow(Remainder(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), resultType) case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) // resultType may have lower precision, so we cast them into wider type first. val widerType = widerDecimalType(p1, s1, p2, s2) - Cast(Pmod(changePrecision(e1, widerType), changePrecision(e2, widerType)), resultType) + CheckOverflow(Pmod(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), + resultType) case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => @@ -442,8 +479,8 @@ object HiveTypeCoercion { * Changes numeric values to booleans so that expressions like true = 1 can be evaluated. */ object BooleanEquality extends Rule[LogicalPlan] { - private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1)) - private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, Decimal(0)) + private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, Decimal.ONE) + private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, Decimal.ZERO) private def buildCaseKeyWhen(booleanExpr: Expression, numericExpr: Expression) = { CaseKeyWhen(numericExpr, Seq( @@ -466,7 +503,7 @@ object HiveTypeCoercion { )) } - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -508,7 +545,7 @@ object HiveTypeCoercion { * truncated version of this number. */ object StringToIntegralCasts extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -521,7 +558,7 @@ object HiveTypeCoercion { * This ensure that the types for various functions are as expected. */ object FunctionArgumentConversion extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -537,12 +574,6 @@ object HiveTypeCoercion { case Sum(e @ IntegralType()) if e.dataType != LongType => Sum(Cast(e, LongType)) case Sum(e @ FractionalType()) if e.dataType != DoubleType => Sum(Cast(e, DoubleType)) - case s @ SumDistinct(e @ DecimalType()) => s // Decimal is already the biggest. - case SumDistinct(e @ IntegralType()) if e.dataType != LongType => - SumDistinct(Cast(e, LongType)) - case SumDistinct(e @ FractionalType()) if e.dataType != DoubleType => - SumDistinct(Cast(e, DoubleType)) - case s @ Average(e @ DecimalType()) => s // Decimal is already the biggest. 
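Illustrative sketch of the FunctionArgumentConversion intent shown above (assumes spark-catalyst on the classpath): integral inputs to SUM/AVG are widened to long and fractional inputs to double, while decimal inputs are left alone because decimal is already the widest type.

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast}
import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Sum}
import org.apache.spark.sql.types._

val i = AttributeReference("i", IntegerType)()
val widenedSum = Sum(Cast(i, LongType))        // sum(i)  analyzes to sum(cast(i as bigint))

val f = AttributeReference("f", FloatType)()
val widenedAvg = Average(Cast(f, DoubleType))  // avg(f)  analyzes to avg(cast(f as double))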
case Average(e @ IntegralType()) if e.dataType != LongType => Average(Cast(e, LongType)) @@ -563,6 +594,20 @@ object HiveTypeCoercion { case None => c } + case g @ Greatest(children) if children.map(_.dataType).distinct.size > 1 => + val types = children.map(_.dataType) + findTightestCommonType(types) match { + case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType))) + case None => g + } + + case l @ Least(children) if children.map(_.dataType).distinct.size > 1 => + val types = children.map(_.dataType) + findTightestCommonType(types) match { + case Some(finalDataType) => Least(children.map(Cast(_, finalDataType))) + case None => l + } + case NaNvl(l, r) if l.dataType == DoubleType && r.dataType == FloatType => NaNvl(l, Cast(r, DoubleType)) case NaNvl(l, r) if l.dataType == FloatType && r.dataType == DoubleType => @@ -575,7 +620,7 @@ object HiveTypeCoercion { * converted to fractional types. */ object Division extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who has not been resolved yet, // as this is an extra rule which should be applied at last. case e if !e.resolved => e @@ -592,10 +637,10 @@ object HiveTypeCoercion { * Coerces the type of different branches of a CASE WHEN statement to a common type. */ object CaseWhenCoercion extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { case c: CaseWhenLike if c.childrenResolved && !c.valueTypesEqual => logDebug(s"Input values for null casting ${c.valueTypes.mkString(",")}") - val maybeCommonType = findTightestCommonTypeAndPromoteToString(c.valueTypes) + val maybeCommonType = findWiderCommonType(c.valueTypes) maybeCommonType.map { commonType => val castedBranches = c.branches.grouped(2).map { case Seq(when, value) if value.dataType != commonType => @@ -612,7 +657,7 @@ object HiveTypeCoercion { case c: CaseKeyWhen if c.childrenResolved && !c.resolved => val maybeCommonType = - findTightestCommonTypeAndPromoteToString((c.key +: c.whenList).map(_.dataType)) + findWiderCommonType((c.key +: c.whenList).map(_.dataType)) maybeCommonType.map { commonType => val castedBranches = c.branches.grouped(2).map { case Seq(whenExpr, thenExpr) if whenExpr.dataType != commonType => @@ -628,7 +673,8 @@ object HiveTypeCoercion { * Coerces the type of different branches of If statement to a common type. */ object IfCoercion extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + case e if !e.childrenResolved => e // Find tightest common type for If, if the true value and false value have different types. case i @ If(pred, left, right) if left.dataType != right.dataType => findTightestCommonTypeToString(left.dataType, right.dataType).map { widestType => @@ -652,7 +698,7 @@ object HiveTypeCoercion { private val acceptedTypes = Seq(DateType, TimestampType, StringType) - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -669,7 +715,7 @@ object HiveTypeCoercion { * Casts types according to the expected input types for [[Expression]]s. 
*/ object ImplicitTypeCasts extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala index 35b74024a4ca..394be47a588b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 03da45b09f92..4f89b462a6ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.catalyst.errors +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.LeafNode +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LeafNode} import org.apache.spark.sql.catalyst.trees.TreeNode -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.catalyst.{TableIdentifier, errors} +import org.apache.spark.sql.types.{DataType, StructType} /** * Thrown when an invalid attempt is made to access a property of a tree that has yet to be fully @@ -35,11 +36,11 @@ class UnresolvedException[TreeType <: TreeNode[_]](tree: TreeType, function: Str * Holds the name of a relation that has yet to be looked up in a [[Catalog]]. */ case class UnresolvedRelation( - tableIdentifier: Seq[String], + tableIdentifier: TableIdentifier, alias: Option[String] = None) extends LeafNode { /** Returns a `.` separated name for this relation. */ - def tableName: String = tableIdentifier.mkString(".") + def tableName: String = tableIdentifier.unquotedString override def output: Seq[Attribute] = Nil @@ -69,8 +70,64 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Un } object UnresolvedAttribute { + /** + * Creates an [[UnresolvedAttribute]], parsing segments separated by dots ('.'). + */ def apply(name: String): UnresolvedAttribute = new UnresolvedAttribute(name.split("\\.")) + + /** + * Creates an [[UnresolvedAttribute]], from a single quoted string (for example using backticks in + * HiveQL. Since the string is consider quoted, no processing is done on the name. + */ def quoted(name: String): UnresolvedAttribute = new UnresolvedAttribute(Seq(name)) + + /** + * Creates an [[UnresolvedAttribute]] from a string in an embedded language. In this case + * we treat it as a quoted identifier, except for '.', which must be further quoted using + * backticks if it is part of a column name. 
+ */ + def quotedString(name: String): UnresolvedAttribute = + new UnresolvedAttribute(parseAttributeName(name)) + + /** + * Used to split attribute name by dot with backticks rule. + * Backticks must appear in pairs, and the quoted string must be a complete name part, + * which means `ab..c`e.f is not allowed. + * Escape character is not supported now, so we can't use backtick inside name part. + */ + def parseAttributeName(name: String): Seq[String] = { + def e = new AnalysisException(s"syntax error in attribute name: $name") + val nameParts = scala.collection.mutable.ArrayBuffer.empty[String] + val tmp = scala.collection.mutable.ArrayBuffer.empty[Char] + var inBacktick = false + var i = 0 + while (i < name.length) { + val char = name(i) + if (inBacktick) { + if (char == '`') { + inBacktick = false + if (i + 1 < name.length && name(i + 1) != '.') throw e + } else { + tmp += char + } + } else { + if (char == '`') { + if (tmp.nonEmpty) throw e + inBacktick = true + } else if (char == '.') { + if (name(i - 1) == '.' || i == name.length - 1) throw e + nameParts += tmp.mkString + tmp.clear() + } else { + tmp += char + } + } + i += 1 + } + if (inBacktick) throw e + nameParts += tmp.mkString + nameParts.toSeq + } } case class UnresolvedFunction( @@ -84,6 +141,10 @@ case class UnresolvedFunction( override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false + override def prettyString: String = { + s"${name}(${children.map(_.prettyString).mkString(",")})" + } + override def toString: String = s"'$name(${children.mkString(",")})" } @@ -101,7 +162,7 @@ abstract class Star extends LeafExpression with NamedExpression { override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute") override lazy val resolved = false - def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] + def expand(input: LogicalPlan, resolver: Resolver): Seq[NamedExpression] } @@ -109,26 +170,56 @@ abstract class Star extends LeafExpression with NamedExpression { * Represents all of the input attributes to a given relational operator, for example in * "SELECT * FROM ...". * - * @param table an optional table that should be the target of the expansion. If omitted all - * tables' columns are produced. + * This is also used to expand structs. For example: + * "SELECT record.* from (SELECT struct(a,b,c) as record ...) + * + * @param target an optional name that should be the target of the expansion. If omitted all + * targets' columns are produced. This can either be a table name or struct name. This + * is a list of identifiers that is the path of the expansion. */ -case class UnresolvedStar(table: Option[String]) extends Star with Unevaluable { +case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevaluable { + + override def expand(input: LogicalPlan, resolver: Resolver): Seq[NamedExpression] = { - override def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = { - val expandedAttributes: Seq[Attribute] = table match { + // First try to expand assuming it is table.*. + val expandedAttributes: Seq[Attribute] = target match { // If there is no table specified, use all input attributes. - case None => input + case None => input.output // If there is a table, pick out attributes that are part of this table. 
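Hypothetical usage of the parseAttributeName helper added above, showing how backticks protect dots inside a single name part (whereas UnresolvedAttribute("a.b") splits on the dot and UnresolvedAttribute.quoted("a.b") keeps it as one part):

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute

val plain = UnresolvedAttribute.parseAttributeName("a.b.c")    // Seq("a", "b", "c")
val front = UnresolvedAttribute.parseAttributeName("`a.b`.c")  // Seq("a.b", "c")
val back  = UnresolvedAttribute.parseAttributeName("a.`b.c`")  // Seq("a", "b.c")
// Malformed names such as "`ab..c`e.f" or an unterminated backtick throw AnalysisException.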
- case Some(t) => input.filter(_.qualifiers.filter(resolver(_, t)).nonEmpty) + case Some(t) => if (t.size == 1) { + input.output.filter(_.qualifiers.exists(resolver(_, t.head))) + } else { + List() + } } - expandedAttributes.zip(input).map { - case (n: NamedExpression, _) => n - case (e, originalAttribute) => - Alias(e, originalAttribute.name)(qualifiers = originalAttribute.qualifiers) + if (expandedAttributes.nonEmpty) return expandedAttributes + + // Try to resolve it as a struct expansion. If there is a conflict and both are possible, + // (i.e. [name].* is both a table and a struct), the struct path can always be qualified. + require(target.isDefined) + val attribute = input.resolve(target.get, resolver) + if (attribute.isDefined) { + // This target resolved to an attribute in child. It must be a struct. Expand it. + attribute.get.dataType match { + case s: StructType => s.zipWithIndex.map { + case (f, i) => + val extract = GetStructField(attribute.get, i) + Alias(extract, f.name)() + } + + case _ => { + throw new AnalysisException("Can only star expand struct data types. Attribute: `" + + target.get + "`") + } + } + } else { + val from = input.inputSet.map(_.name).mkString(", ") + val targetString = target.get.mkString(".") + throw new AnalysisException(s"cannot resolve '$targetString.*' give input columns '$from'") } } - override def toString: String = table.map(_ + ".").getOrElse("") + "*" + override def toString: String = target.map(_ + ".").getOrElse("") + "*" } /** @@ -168,7 +259,7 @@ case class MultiAlias(child: Expression, names: Seq[String]) * @param expressions Expressions to expand. */ case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star with Unevaluable { - override def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = expressions + override def expand(input: LogicalPlan, resolver: Resolver): Seq[NamedExpression] = expressions override def toString: String = expressions.mkString("ResolvedStar(", ", ", ")") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index a7e3a4932765..8102c93c6f10 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -23,6 +23,7 @@ import scala.language.implicitConversions import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedExtractValue, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.types._ @@ -144,17 +145,18 @@ package object dsl { } } - def sum(e: Expression): Expression = Sum(e) - def sumDistinct(e: Expression): Expression = SumDistinct(e) - def count(e: Expression): Expression = Count(e) - def countDistinct(e: Expression*): Expression = CountDistinct(e) + def sum(e: Expression): Expression = Sum(e).toAggregateExpression() + def sumDistinct(e: Expression): Expression = Sum(e).toAggregateExpression(isDistinct = true) + def count(e: Expression): Expression = Count(e).toAggregateExpression() + def countDistinct(e: Expression*): Expression = + Count(e).toAggregateExpression(isDistinct = true) def approxCountDistinct(e: Expression, rsd: Double = 0.05): Expression = - ApproxCountDistinct(e, rsd) - def avg(e: Expression): 
Expression = Average(e) - def first(e: Expression): Expression = First(e) - def last(e: Expression): Expression = Last(e) - def min(e: Expression): Expression = Min(e) - def max(e: Expression): Expression = Max(e) + HyperLogLogPlusPlus(e, rsd).toAggregateExpression() + def avg(e: Expression): Expression = Average(e).toAggregateExpression() + def first(e: Expression): Expression = new First(e).toAggregateExpression() + def last(e: Expression): Expression = new Last(e).toAggregateExpression() + def min(e: Expression): Expression = Min(e).toAggregateExpression() + def max(e: Expression): Expression = Max(e).toAggregateExpression() def upper(e: Expression): Expression = Upper(e) def lower(e: Expression): Expression = Lower(e) def sqrt(e: Expression): Expression = Sqrt(e) @@ -225,9 +227,10 @@ package object dsl { AttributeReference(s, mapType, nullable = true)() /** Creates a new AttributeReference of type struct */ - def struct(fields: StructField*): AttributeReference = struct(StructType(fields)) def struct(structType: StructType): AttributeReference = AttributeReference(s, structType, nullable = true)() + def struct(attrs: AttributeReference*): AttributeReference = + struct(StructType.fromAttributes(attrs)) } implicit class DslAttribute(a: AttributeReference) { @@ -273,17 +276,19 @@ package object dsl { def unionAll(otherPlan: LogicalPlan): LogicalPlan = Union(logicalPlan, otherPlan) - // TODO specify the output column names def generate( generator: Generator, join: Boolean = false, outer: Boolean = false, - alias: Option[String] = None): LogicalPlan = - Generate(generator, join = join, outer = outer, alias, Nil, logicalPlan) + alias: Option[String] = None, + outputNames: Seq[String] = Nil): LogicalPlan = + Generate(generator, join = join, outer = outer, alias, + outputNames.map(UnresolvedAttribute(_)), logicalPlan) def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan = InsertIntoTable( - analysis.UnresolvedRelation(Seq(tableName)), Map.empty, logicalPlan, overwrite, false) + analysis.UnresolvedRelation(TableIdentifier(tableName)), + Map.empty, logicalPlan, overwrite, false) def analyze: LogicalPlan = EliminateSubQueries(analysis.SimpleAnalyzer.execute(logicalPlan)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala new file mode 100644 index 000000000000..7a4401cf5810 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -0,0 +1,337 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.encoders + +import java.util.concurrent.ConcurrentMap + +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.{typeTag, TypeTag} + +import org.apache.spark.util.Utils +import org.apache.spark.sql.{AnalysisException, Encoder} +import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtractValue, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} +import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts +import org.apache.spark.sql.catalyst.{JavaTypeInference, InternalRow, ScalaReflection} +import org.apache.spark.sql.types.{StructField, ObjectType, StructType} + +/** + * A factory for constructing encoders that convert objects and primitives to and from the + * internal row format using catalyst expressions and code generation. By default, the + * expressions used to retrieve values from an input row when producing an object will be created as + * follows: + * - Classes will have their sub fields extracted by name using [[UnresolvedAttribute]] expressions + * and [[UnresolvedExtractValue]] expressions. + * - Tuples will have their subfields extracted by position using [[BoundReference]] expressions. + * - Primitives will have their values extracted from the first ordinal with a schema that defaults + * to the name `value`. + */ +object ExpressionEncoder { + def apply[T : TypeTag](): ExpressionEncoder[T] = { + // We convert the not-serializable TypeTag into StructType and ClassTag. + val mirror = typeTag[T].mirror + val cls = mirror.runtimeClass(typeTag[T].tpe) + val flat = !classOf[Product].isAssignableFrom(cls) + + val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = true) + val toRowExpression = ScalaReflection.extractorsFor[T](inputObject) + val fromRowExpression = ScalaReflection.constructorFor[T] + + val schema = ScalaReflection.schemaFor[T] match { + case ScalaReflection.Schema(s: StructType, _) => s + case ScalaReflection.Schema(dt, nullable) => new StructType().add("value", dt, nullable) + } + + new ExpressionEncoder[T]( + schema, + flat, + toRowExpression.flatten, + fromRowExpression, + ClassTag[T](cls)) + } + + // TODO: improve error message for java bean encoder. + def javaBean[T](beanClass: Class[T]): ExpressionEncoder[T] = { + val schema = JavaTypeInference.inferDataType(beanClass)._1 + assert(schema.isInstanceOf[StructType]) + + val toRowExpression = JavaTypeInference.extractorsFor(beanClass) + val fromRowExpression = JavaTypeInference.constructorFor(beanClass) + + new ExpressionEncoder[T]( + schema.asInstanceOf[StructType], + flat = false, + toRowExpression.flatten, + fromRowExpression, + ClassTag[T](beanClass)) + } + + /** + * Given a set of N encoders, constructs a new encoder that produce objects as items in an + * N-tuple. Note that these encoders should be unresolved so that information about + * name/positional binding is preserved. 
+ */ + def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = { + encoders.foreach(_.assertUnresolved()) + + val schema = StructType(encoders.zipWithIndex.map { + case (e, i) => + val (dataType, nullable) = if (e.flat) { + e.schema.head.dataType -> e.schema.head.nullable + } else { + e.schema -> true + } + StructField(s"_${i + 1}", dataType, nullable) + }) + + val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}") + + val toRowExpressions = encoders.map { + case e if e.flat => e.toRowExpressions.head + case other => CreateStruct(other.toRowExpressions) + }.zipWithIndex.map { case (expr, index) => + expr.transformUp { + case BoundReference(0, t, _) => + Invoke( + BoundReference(0, ObjectType(cls), nullable = true), + s"_${index + 1}", + t) + } + } + + val fromRowExpressions = encoders.zipWithIndex.map { case (enc, index) => + if (enc.flat) { + enc.fromRowExpression.transform { + case b: BoundReference => b.copy(ordinal = index) + } + } else { + val input = BoundReference(index, enc.schema, nullable = true) + enc.fromRowExpression.transformUp { + case UnresolvedAttribute(nameParts) => + assert(nameParts.length == 1) + UnresolvedExtractValue(input, Literal(nameParts.head)) + case BoundReference(ordinal, dt, _) => GetStructField(input, ordinal) + } + } + } + + val fromRowExpression = + NewInstance(cls, fromRowExpressions, propagateNull = false, ObjectType(cls)) + + new ExpressionEncoder[Any]( + schema, + flat = false, + toRowExpressions, + fromRowExpression, + ClassTag(cls)) + } + + def tuple[T1, T2]( + e1: ExpressionEncoder[T1], + e2: ExpressionEncoder[T2]): ExpressionEncoder[(T1, T2)] = + tuple(Seq(e1, e2)).asInstanceOf[ExpressionEncoder[(T1, T2)]] + + def tuple[T1, T2, T3]( + e1: ExpressionEncoder[T1], + e2: ExpressionEncoder[T2], + e3: ExpressionEncoder[T3]): ExpressionEncoder[(T1, T2, T3)] = + tuple(Seq(e1, e2, e3)).asInstanceOf[ExpressionEncoder[(T1, T2, T3)]] + + def tuple[T1, T2, T3, T4]( + e1: ExpressionEncoder[T1], + e2: ExpressionEncoder[T2], + e3: ExpressionEncoder[T3], + e4: ExpressionEncoder[T4]): ExpressionEncoder[(T1, T2, T3, T4)] = + tuple(Seq(e1, e2, e3, e4)).asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4)]] + + def tuple[T1, T2, T3, T4, T5]( + e1: ExpressionEncoder[T1], + e2: ExpressionEncoder[T2], + e3: ExpressionEncoder[T3], + e4: ExpressionEncoder[T4], + e5: ExpressionEncoder[T5]): ExpressionEncoder[(T1, T2, T3, T4, T5)] = + tuple(Seq(e1, e2, e3, e4, e5)).asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]] +} + +/** + * A generic encoder for JVM objects. + * + * @param schema The schema after converting `T` to a Spark SQL row. + * @param toRowExpressions A set of expressions, one for each top-level field that can be used to + * extract the values from a raw object into an [[InternalRow]]. + * @param fromRowExpression An expression that will construct an object given an [[InternalRow]]. + * @param clsTag A classtag for `T`. + */ +case class ExpressionEncoder[T]( + schema: StructType, + flat: Boolean, + toRowExpressions: Seq[Expression], + fromRowExpression: Expression, + clsTag: ClassTag[T]) + extends Encoder[T] { + + if (flat) require(toRowExpressions.size == 1) + + @transient + private lazy val extractProjection = GenerateUnsafeProjection.generate(toRowExpressions) + + @transient + private lazy val inputRow = new GenericMutableRow(1) + + @transient + private lazy val constructProjection = GenerateSafeProjection.generate(fromRowExpression :: Nil) + + /** + * Returns an encoded version of `t` as a Spark SQL row. 
Note that multiple calls to + * toRow are allowed to return the same actual [[InternalRow]] object. Thus, the caller should + * copy the result before making another call if required. + */ + def toRow(t: T): InternalRow = try { + inputRow(0) = t + extractProjection(inputRow) + } catch { + case e: Exception => + throw new RuntimeException( + s"Error while encoding: $e\n${toRowExpressions.map(_.treeString).mkString("\n")}", e) + } + + /** + * Returns an object of type `T`, extracting the required values from the provided row. Note that + * you must `resolve` and `bind` an encoder to a specific schema before you can call this + * function. + */ + def fromRow(row: InternalRow): T = try { + constructProjection(row).get(0, ObjectType(clsTag.runtimeClass)).asInstanceOf[T] + } catch { + case e: Exception => + throw new RuntimeException(s"Error while decoding: $e\n${fromRowExpression.treeString}", e) + } + + /** + * The process of resolution to a given schema throws away information about where a given field + * is being bound by ordinal instead of by name. This method checks to make sure this process + * has not been done already in places where we plan to do later composition of encoders. + */ + def assertUnresolved(): Unit = { + (fromRowExpression +: toRowExpressions).foreach(_.foreach { + case a: AttributeReference if a.name != "loopVar" => + sys.error(s"Unresolved encoder expected, but $a was found.") + case _ => + }) + } + + /** + * Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the + * given schema. + */ + def resolve( + schema: Seq[Attribute], + outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = { + def fail(st: StructType, maxOrdinal: Int): Unit = { + throw new AnalysisException(s"Try to map ${st.simpleString} to Tuple${maxOrdinal + 1}, " + + "but failed as the number of fields does not line up.\n" + + " - Input schema: " + StructType.fromAttributes(schema).simpleString + "\n" + + " - Target schema: " + this.schema.simpleString) + } + + var maxOrdinal = -1 + fromRowExpression.foreach { + case b: BoundReference => if (b.ordinal > maxOrdinal) maxOrdinal = b.ordinal + case _ => + } + if (maxOrdinal >= 0 && maxOrdinal != schema.length - 1) { + fail(StructType.fromAttributes(schema), maxOrdinal) + } + + val unbound = fromRowExpression transform { + case b: BoundReference => schema(b.ordinal) + } + + val exprToMaxOrdinal = scala.collection.mutable.HashMap.empty[Expression, Int] + unbound.foreach { + case g: GetStructField => + val maxOrdinal = exprToMaxOrdinal.getOrElse(g.child, -1) + if (maxOrdinal < g.ordinal) { + exprToMaxOrdinal.update(g.child, g.ordinal) + } + case _ => + } + exprToMaxOrdinal.foreach { + case (expr, maxOrdinal) => + val schema = expr.dataType.asInstanceOf[StructType] + if (maxOrdinal != schema.length - 1) { + fail(schema, maxOrdinal) + } + } + + val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema)) + val analyzedPlan = SimpleAnalyzer.execute(plan) + SimpleAnalyzer.checkAnalysis(analyzedPlan) + val optimizedPlan = SimplifyCasts(analyzedPlan) + + // In order to construct instances of inner classes (for example those declared in a REPL cell), + // we need an instance of the outer scope. This rule substitues those outer objects into + // expressions that are missing them by looking up the name in the SQLContexts `outerScopes` + // registry. 
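A rough end-to-end sketch of the encoder lifecycle described above (a usage pattern mirroring catalyst's own tests, not part of this patch): create an encoder, resolve it against a schema, bind it, then round-trip a value. Assumes spark-catalyst on the classpath and a top-level case class so no outer scope registration is needed.

import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, OuterScopes}
import org.apache.spark.sql.catalyst.expressions.AttributeReference

case class Point(x: Int, y: Int)

val enc   = ExpressionEncoder[Point]()
val attrs = enc.schema.map(f => AttributeReference(f.name, f.dataType, f.nullable)())

val row   = enc.toRow(Point(1, 2))                     // internal row [1, 2]
val bound = enc.resolve(attrs, OuterScopes.outerScopes).bind(attrs)
val back  = bound.fromRow(row.copy())                  // Point(1,2); copy because toRow may reuse its buffer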
+ copy(fromRowExpression = optimizedPlan.expressions.head.children.head transform { + case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass => + val outer = outerScopes.get(n.cls.getDeclaringClass.getName) + if (outer == null) { + throw new AnalysisException( + s"Unable to generate an encoder for inner class `${n.cls.getName}` without access " + + s"to the scope that this class was defined in. " + "" + + "Try moving this class out of its parent class.") + } + + n.copy(outerPointer = Some(Literal.fromObject(outer))) + }) + } + + /** + * Returns a copy of this encoder where the expressions used to construct an object from an input + * row have been bound to the ordinals of the given schema. Note that you need to first call + * resolve before bind. + */ + def bind(schema: Seq[Attribute]): ExpressionEncoder[T] = { + copy(fromRowExpression = BindReferences.bindReference(fromRowExpression, schema)) + } + + /** + * Returns a new encoder with input columns shifted by `delta` ordinals + */ + def shift(delta: Int): ExpressionEncoder[T] = { + copy(fromRowExpression = fromRowExpression transform { + case r: BoundReference => r.copy(ordinal = r.ordinal + delta) + }) + } + + protected val attrs = toRowExpressions.flatMap(_.collect { + case _: UnresolvedAttribute => "" + case a: Attribute => s"#${a.exprId}" + case b: BoundReference => s"[${b.ordinal}]" + }) + + protected val schemaString = + schema + .zip(attrs) + .map { case(f, a) => s"${f.name}$a: ${f.dataType.simpleString}"}.mkString(", ") + + override def toString: String = s"class[$schemaString]" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala new file mode 100644 index 000000000000..a753b187bcd3 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import java.util.concurrent.ConcurrentMap + +import com.google.common.collect.MapMaker + +object OuterScopes { + @transient + lazy val outerScopes: ConcurrentMap[String, AnyRef] = + new MapMaker().weakValues().makeMap() + + /** + * Adds a new outer scope to this context that can be used when instantiating an `inner class` + * during deserialialization. Inner classes are created when a case class is defined in the + * Spark REPL and registering the outer scope that this class was defined in allows us to create + * new instances on the spark executors. In normal use, users should not need to call this + * function. + * + * Warning: this function operates on the assumption that there is only ever one instance of any + * given wrapper class. 
+ */ + def addOuterScope(outer: AnyRef): Unit = { + outerScopes.putIfAbsent(outer.getClass.getName, outer) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala new file mode 100644 index 000000000000..d34ec9408ae1 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import scala.collection.Map +import scala.reflect.ClassTag + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils} +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * A factory for constructing encoders that convert external row to/from the Spark SQL + * internal binary representation. 
+ */ +object RowEncoder { + def apply(schema: StructType): ExpressionEncoder[Row] = { + val cls = classOf[Row] + val inputObject = BoundReference(0, ObjectType(cls), nullable = true) + val extractExpressions = extractorsFor(inputObject, schema) + val constructExpression = constructorFor(schema) + new ExpressionEncoder[Row]( + schema, + flat = false, + extractExpressions.asInstanceOf[CreateStruct].children, + constructExpression, + ClassTag(cls)) + } + + private def extractorsFor( + inputObject: Expression, + inputType: DataType): Expression = inputType match { + case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | + FloatType | DoubleType | BinaryType => inputObject + + case udt: UserDefinedType[_] => + val obj = NewInstance( + udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), + Nil, + false, + dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) + Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil) + + case TimestampType => + StaticInvoke( + DateTimeUtils, + TimestampType, + "fromJavaTimestamp", + inputObject :: Nil) + + case DateType => + StaticInvoke( + DateTimeUtils, + DateType, + "fromJavaDate", + inputObject :: Nil) + + case _: DecimalType => + StaticInvoke( + Decimal, + DecimalType.SYSTEM_DEFAULT, + "apply", + inputObject :: Nil) + + case StringType => + StaticInvoke( + classOf[UTF8String], + StringType, + "fromString", + inputObject :: Nil) + + case t @ ArrayType(et, _) => et match { + case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => + NewInstance( + classOf[GenericArrayData], + inputObject :: Nil, + dataType = t) + case _ => MapObjects(extractorsFor(_, et), inputObject, externalDataTypeFor(et)) + } + + case t @ MapType(kt, vt, valueNullable) => + val keys = + Invoke( + Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]])), + "toSeq", + ObjectType(classOf[scala.collection.Seq[_]])) + val convertedKeys = extractorsFor(keys, ArrayType(kt, false)) + + val values = + Invoke( + Invoke(inputObject, "valuesIterator", ObjectType(classOf[scala.collection.Iterator[_]])), + "toSeq", + ObjectType(classOf[scala.collection.Seq[_]])) + val convertedValues = extractorsFor(values, ArrayType(vt, valueNullable)) + + NewInstance( + classOf[ArrayBasedMapData], + convertedKeys :: convertedValues :: Nil, + dataType = t) + + case StructType(fields) => + val convertedFields = fields.zipWithIndex.map { case (f, i) => + val method = if (f.dataType.isInstanceOf[StructType]) { + "getStruct" + } else { + "get" + } + If( + Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil), + Literal.create(null, f.dataType), + extractorsFor( + Invoke(inputObject, method, externalDataTypeFor(f.dataType), Literal(i) :: Nil), + f.dataType)) + } + CreateStruct(convertedFields) + } + + private def externalDataTypeFor(dt: DataType): DataType = dt match { + case _ if ScalaReflection.isNativeType(dt) => dt + case TimestampType => ObjectType(classOf[java.sql.Timestamp]) + case DateType => ObjectType(classOf[java.sql.Date]) + case _: DecimalType => ObjectType(classOf[java.math.BigDecimal]) + case StringType => ObjectType(classOf[java.lang.String]) + case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]]) + case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]]) + case _: StructType => ObjectType(classOf[Row]) + case udt: UserDefinedType[_] => ObjectType(udt.userClass) + case _: NullType => ObjectType(classOf[java.lang.Object]) + } + + private def 
constructorFor(schema: StructType): Expression = { + val fields = schema.zipWithIndex.map { case (f, i) => + val field = BoundReference(i, f.dataType, f.nullable) + If( + IsNull(field), + Literal.create(null, externalDataTypeFor(f.dataType)), + constructorFor(BoundReference(i, f.dataType, f.nullable)) + ) + } + CreateExternalRow(fields) + } + + private def constructorFor(input: Expression): Expression = input.dataType match { + case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | + FloatType | DoubleType | BinaryType => input + + case udt: UserDefinedType[_] => + val obj = NewInstance( + udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), + Nil, + false, + dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) + Invoke(obj, "deserialize", ObjectType(udt.userClass), input :: Nil) + + case TimestampType => + StaticInvoke( + DateTimeUtils, + ObjectType(classOf[java.sql.Timestamp]), + "toJavaTimestamp", + input :: Nil) + + case DateType => + StaticInvoke( + DateTimeUtils, + ObjectType(classOf[java.sql.Date]), + "toJavaDate", + input :: Nil) + + case _: DecimalType => + Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) + + case StringType => + Invoke(input, "toString", ObjectType(classOf[String])) + + case ArrayType(et, nullable) => + val arrayData = + Invoke( + MapObjects(constructorFor(_), input, et), + "array", + ObjectType(classOf[Array[_]])) + StaticInvoke( + scala.collection.mutable.WrappedArray, + ObjectType(classOf[Seq[_]]), + "make", + arrayData :: Nil) + + case MapType(kt, vt, valueNullable) => + val keyArrayType = ArrayType(kt, false) + val keyData = constructorFor(Invoke(input, "keyArray", keyArrayType)) + + val valueArrayType = ArrayType(vt, valueNullable) + val valueData = constructorFor(Invoke(input, "valueArray", valueArrayType)) + + StaticInvoke( + ArrayBasedMapData, + ObjectType(classOf[Map[_, _]]), + "toScalaMap", + keyData :: valueData :: Nil) + + case StructType(fields) => + val convertedFields = fields.zipWithIndex.map { case (f, i) => + If( + Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil), + Literal.create(null, externalDataTypeFor(f.dataType)), + constructorFor(GetStructField(input, i))) + } + CreateExternalRow(convertedFields) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala new file mode 100644 index 000000000000..9e283f5eb634 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst + +import org.apache.spark.sql.Encoder +import org.apache.spark.sql.catalyst.expressions.AttributeReference + +package object encoders { + /** + * Returns an internal encoder object that can be used to serialize / deserialize JVM objects + * into Spark SQL rows. The implicit encoder should always be unresolved (i.e. have no attribute + * references from a specific schema.) This requirement allows us to preserve whether a given + * object type is being bound by name or by ordinal when doing resolution. + */ + private[sql] def encoderFor[A : Encoder]: ExpressionEncoder[A] = implicitly[Encoder[A]] match { + case e: ExpressionEncoder[A] => + e.assertUnresolved() + e + case _ => sys.error(s"Only expression encoders are supported today") + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala index 96a11e352ec5..ef3cc554b79c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala @@ -26,6 +26,13 @@ object AttributeMap { def apply[A](kvs: Seq[(Attribute, A)]): AttributeMap[A] = { new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap) } + + /** Given a schema, constructs an [[AttributeMap]] from [[Attribute]] to ordinal */ + def byIndex(schema: Seq[Attribute]): AttributeMap[Int] = apply(schema.zipWithIndex) + + /** Given a schema, constructs a map from ordinal to Attribute. */ + def toIndex(schema: Seq[Attribute]): Map[Int, Attribute] = + schema.zipWithIndex.map { case (a, i) => i -> a }.toMap } class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala index 5345696570b4..383153557420 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala @@ -31,6 +31,10 @@ protected class AttributeEquals(val a: Attribute) { } object AttributeSet { + /** Returns an empty [[AttributeSet]]. */ + val empty = apply(Iterable.empty) + + /** Constructs a new [[AttributeSet]] that contains a single [[Attribute]]. */ def apply(a: Attribute): AttributeSet = new AttributeSet(Set(new AttributeEquals(a))) /** Constructs a new [[AttributeSet]] given a sequence of [[Expression Expressions]]. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 473b9b787058..ff1f28ddbbf3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -68,10 +68,10 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val javaType = ctx.javaType(dataType) - val value = ctx.getValue("i", dataType, ordinal.toString) + val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) s""" - boolean ${ev.isNull} = i.isNullAt($ordinal); - $javaType ${ev.primitive} = ${ev.isNull} ? 
${ctx.defaultValue(dataType)} : ($value); + boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); + $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value); """ } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 88429bb84b1e..cb60d5958d53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -22,12 +22,10 @@ import java.math.{BigDecimal => JavaBigDecimal} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} -import scala.collection.mutable - object Cast { @@ -106,8 +104,9 @@ object Cast { } /** Cast the child expression to the target data type. */ -case class Cast(child: Expression, dataType: DataType) - extends UnaryExpression with CodegenFallback { +case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { + + override def toString: String = s"cast($child as ${dataType.simpleString})" override def checkInputDataTypes(): TypeCheckResult = { if (Cast.canCast(child.dataType, dataType)) { @@ -120,8 +119,6 @@ case class Cast(child: Expression, dataType: DataType) override def nullable: Boolean = Cast.forceNullable(child.dataType, dataType) || child.nullable - override def toString: String = s"CAST($child, $dataType)" - // [[func]] assumes the input is no longer null because eval already does the null check. 
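Quick illustration (not part of the patch) of the AttributeMap helpers added above: byIndex maps each attribute of a schema to its ordinal and toIndex gives the reverse view, which is handy when binding references by position.

import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference}
import org.apache.spark.sql.types._

val schema = Seq(
  AttributeReference("id", LongType)(),
  AttributeReference("name", StringType)())

val byIndex = AttributeMap.byIndex(schema)   // id -> 0, name -> 1
val ordinal = byIndex(schema(1))             // 1
val toIndex = AttributeMap.toIndex(schema)   // 0 -> id, 1 -> name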
@inline private[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T]) @@ -142,7 +139,15 @@ case class Cast(child: Expression, dataType: DataType) // UDFToBoolean private[this] def castToBoolean(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, _.numBytes() != 0) + buildCast[UTF8String](_, s => { + if (StringUtils.isTrueString(s)) { + true + } else if (StringUtils.isFalseString(s)) { + false + } else { + null + } + }) case TimestampType => buildCast[Long](_, t => t != 0) case DateType => @@ -157,7 +162,7 @@ case class Cast(child: Expression, dataType: DataType) case ByteType => buildCast[Byte](_, _ != 0) case DecimalType() => - buildCast[Decimal](_, _ != Decimal(0)) + buildCast[Decimal](_, !_.isZero) case DoubleType => buildCast[Double](_, _ != 0) case FloatType => @@ -198,8 +203,8 @@ case class Cast(child: Expression, dataType: DataType) if (d.isNaN || d.isInfinite) null else (d * 1000000L).toLong } - // converting milliseconds to us - private[this] def longToTimestamp(t: Long): Long = t * 1000L + // converting seconds to us + private[this] def longToTimestamp(t: Long): Long = t * 1000000L // converting us to seconds private[this] def timestampToLong(ts: Long): Long = math.floor(ts.toDouble / 1000000L).toLong // converting us to seconds in double @@ -311,19 +316,19 @@ case class Cast(child: Expression, dataType: DataType) case _: NumberFormatException => null }) case BooleanType => - buildCast[Boolean](_, b => changePrecision(if (b) Decimal(1) else Decimal(0), target)) + buildCast[Boolean](_, b => changePrecision(if (b) Decimal.ONE else Decimal.ZERO, target)) case DateType => buildCast[Int](_, d => null) // date can't cast to decimal in Hive case TimestampType => // Note that we lose precision here. 
buildCast[Long](_, t => changePrecision(Decimal(timestampToDouble(t)), target)) - case DecimalType() => + case dt: DecimalType => b => changePrecision(b.asInstanceOf[Decimal].clone(), target) - case LongType => - b => changePrecision(Decimal(b.asInstanceOf[Long]), target) - case x: NumericType => // All other numeric types can be represented precisely as Doubles + case t: IntegralType => + b => changePrecision(Decimal(t.integral.asInstanceOf[Integral[Any]].toLong(b)), target) + case x: FractionalType => b => try { - changePrecision(Decimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)), target) + changePrecision(Decimal(x.fractional.asInstanceOf[Fractional[Any]].toDouble(b)), target) } catch { case _: NumberFormatException => null } @@ -432,7 +437,7 @@ case class Cast(child: Expression, dataType: DataType) val eval = child.gen(ctx) val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx) eval.code + - castCode(ctx, eval.primitive, eval.isNull, ev.primitive, ev.isNull, dataType, nullSafeCast) + castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast) } // three function arguments are: child.primitive, result.primitive and result.isNull @@ -449,7 +454,7 @@ case class Cast(child: Expression, dataType: DataType) case StringType => castToStringCode(from, ctx) case BinaryType => castToBinaryCode(from) case DateType => castToDateCode(from, ctx) - case decimal: DecimalType => castToDecimalCode(from, decimal) + case decimal: DecimalType => castToDecimalCode(from, decimal, ctx) case TimestampType => castToTimestampCode(from, ctx) case CalendarIntervalType => castToIntervalCode(from) case BooleanType => castToBooleanCode(from) @@ -530,17 +535,18 @@ case class Cast(child: Expression, dataType: DataType) } """ - private[this] def castToDecimalCode(from: DataType, target: DecimalType): CastFunction = { + private[this] def castToDecimalCode( + from: DataType, + target: DecimalType, + ctx: CodeGenContext): CastFunction = { + val tmp = ctx.freshName("tmpDecimal") from match { case StringType => (c, evPrim, evNull) => s""" try { - org.apache.spark.sql.types.Decimal tmpDecimal = - new org.apache.spark.sql.types.Decimal().set( - new scala.math.BigDecimal( - new java.math.BigDecimal($c.toString()))); - ${changePrecision("tmpDecimal", target, evPrim, evNull)} + Decimal $tmp = Decimal.apply(new java.math.BigDecimal($c.toString())); + ${changePrecision(tmp, target, evPrim, evNull)} } catch (java.lang.NumberFormatException e) { $evNull = true; } @@ -548,13 +554,8 @@ case class Cast(child: Expression, dataType: DataType) case BooleanType => (c, evPrim, evNull) => s""" - org.apache.spark.sql.types.Decimal tmpDecimal = null; - if ($c) { - tmpDecimal = new org.apache.spark.sql.types.Decimal().set(1); - } else { - tmpDecimal = new org.apache.spark.sql.types.Decimal().set(0); - } - ${changePrecision("tmpDecimal", target, evPrim, evNull)} + Decimal $tmp = $c ? Decimal.apply(1) : Decimal.apply(0); + ${changePrecision(tmp, target, evPrim, evNull)} """ case DateType => // date can't cast to decimal in Hive @@ -563,33 +564,29 @@ case class Cast(child: Expression, dataType: DataType) // Note that we lose precision here. 
(c, evPrim, evNull) => s""" - org.apache.spark.sql.types.Decimal tmpDecimal = - new org.apache.spark.sql.types.Decimal().set( - scala.math.BigDecimal.valueOf(${timestampToDoubleCode(c)})); - ${changePrecision("tmpDecimal", target, evPrim, evNull)} + Decimal $tmp = Decimal.apply( + scala.math.BigDecimal.valueOf(${timestampToDoubleCode(c)})); + ${changePrecision(tmp, target, evPrim, evNull)} """ case DecimalType() => (c, evPrim, evNull) => s""" - org.apache.spark.sql.types.Decimal tmpDecimal = $c.clone(); - ${changePrecision("tmpDecimal", target, evPrim, evNull)} + Decimal $tmp = $c.clone(); + ${changePrecision(tmp, target, evPrim, evNull)} """ - case LongType => + case x: IntegralType => (c, evPrim, evNull) => s""" - org.apache.spark.sql.types.Decimal tmpDecimal = - new org.apache.spark.sql.types.Decimal().set($c); - ${changePrecision("tmpDecimal", target, evPrim, evNull)} + Decimal $tmp = Decimal.apply((long) $c); + ${changePrecision(tmp, target, evPrim, evNull)} """ - case x: NumericType => + case x: FractionalType => // All other numeric types can be represented precisely as Doubles (c, evPrim, evNull) => s""" try { - org.apache.spark.sql.types.Decimal tmpDecimal = - new org.apache.spark.sql.types.Decimal().set( - scala.math.BigDecimal.valueOf((double) $c)); - ${changePrecision("tmpDecimal", target, evPrim, evNull)} + Decimal $tmp = Decimal.apply(scala.math.BigDecimal.valueOf((double) $c)); + ${changePrecision(tmp, target, evPrim, evNull)} } catch (java.lang.NumberFormatException e) { $evNull = true; } @@ -649,14 +646,24 @@ case class Cast(child: Expression, dataType: DataType) private[this] def decimalToTimestampCode(d: String): String = s"($d.toBigDecimal().bigDecimal().multiply(new java.math.BigDecimal(1000000L))).longValue()" - private[this] def longToTimeStampCode(l: String): String = s"$l * 1000L" + private[this] def longToTimeStampCode(l: String): String = s"$l * 1000000L" private[this] def timestampToIntegerCode(ts: String): String = s"java.lang.Math.floor((double) $ts / 1000000L)" private[this] def timestampToDoubleCode(ts: String): String = s"$ts / 1000000.0" private[this] def castToBooleanCode(from: DataType): CastFunction = from match { case StringType => - (c, evPrim, evNull) => s"$evPrim = $c.numBytes() != 0;" + val stringUtils = StringUtils.getClass.getName.stripSuffix("$") + (c, evPrim, evNull) => + s""" + if ($stringUtils.isTrueString($c)) { + $evPrim = true; + } else if ($stringUtils.isFalseString($c)) { + $evPrim = false; + } else { + $evNull = true; + } + """ case TimestampType => (c, evPrim, evNull) => s"$evPrim = $c != 0;" case DateType => @@ -907,3 +914,12 @@ case class Cast(child: Expression, dataType: DataType) """ } } + +/** + * Cast the child expression to the target data type, but will throw error if the cast might + * truncate, e.g. long -> int, timestamp -> data. + */ +case class UpCast(child: Expression, dataType: DataType, walkedTypePath: Seq[String]) + extends UnaryExpression with Unevaluable { + override lazy val resolved = false +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala new file mode 100644 index 000000000000..f7162e420d19 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import scala.collection.mutable + +/** + * This class is used to compute equality of (sub)expression trees. Expressions can be added + * to this class and they subsequently query for expression equality. Expression trees are + * considered equal if for the same input(s), the same result is produced. + */ +class EquivalentExpressions { + /** + * Wrapper around an Expression that provides semantic equality. + */ + case class Expr(e: Expression) { + override def equals(o: Any): Boolean = o match { + case other: Expr => e.semanticEquals(other.e) + case _ => false + } + override val hashCode: Int = e.semanticHash() + } + + // For each expression, the set of equivalent expressions. + private val equivalenceMap = mutable.HashMap.empty[Expr, mutable.MutableList[Expression]] + + /** + * Adds each expression to this data structure, grouping them with existing equivalent + * expressions. Non-recursive. + * Returns true if there was already a matching expression. + */ + def addExpr(expr: Expression): Boolean = { + if (expr.deterministic) { + val e: Expr = Expr(expr) + val f = equivalenceMap.get(e) + if (f.isDefined) { + f.get += expr + true + } else { + equivalenceMap.put(e, mutable.MutableList(expr)) + false + } + } else { + false + } + } + + /** + * Adds the expression to this data structure recursively. Stops if a matching expression + * is found. That is, if `expr` has already been added, its children are not added. + * If ignoreLeaf is true, leaf nodes are ignored. + */ + def addExprTree(root: Expression, ignoreLeaf: Boolean = true): Unit = { + val skip = root.isInstanceOf[LeafExpression] && ignoreLeaf + if (!skip && !addExpr(root)) { + root.children.foreach(addExprTree(_, ignoreLeaf)) + } + } + + /** + * Returns all of the expression trees that are equivalent to `e`. Returns + * an empty collection if there are none. + */ + def getEquivalentExprs(e: Expression): Seq[Expression] = { + equivalenceMap.getOrElse(Expr(e), mutable.MutableList()) + } + + /** + * Returns all the equivalent sets of expressions. + */ + def getAllEquivalentExprs: Seq[Seq[Expression]] = { + equivalenceMap.values.map(_.toSeq).toSeq + } + + /** + * Returns the state of the data structure as a string. If `all` is false, skips sets of + * equivalent expressions with cardinality 1. 
+ */ + def debugString(all: Boolean = false): String = { + val sb: mutable.StringBuilder = new StringBuilder() + sb.append("Equivalent expressions:\n") + equivalenceMap.foreach { case (k, v) => { + if (all || v.length > 1) { + sb.append(" " + v.mkString(", ")).append("\n") + } + }} + sb.toString() + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index ef2fc2e8c29d..6d807c9ecf30 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -92,12 +92,19 @@ abstract class Expression extends TreeNode[Expression] { * @return [[GeneratedExpressionCode]] */ def gen(ctx: CodeGenContext): GeneratedExpressionCode = { - val isNull = ctx.freshName("isNull") - val primitive = ctx.freshName("primitive") - val ve = GeneratedExpressionCode("", isNull, primitive) - ve.code = genCode(ctx, ve) - // Add `this` in the comment. - ve.copy(s"/* $this */\n" + ve.code) + ctx.subExprEliminationExprs.get(this).map { subExprState => + // This expression is repeated, meaning the code to evaluate it has already been added + // as a function and called in advance. Just use it. + val code = s"/* ${this.toCommentSafeString} */" + GeneratedExpressionCode(code, subExprState.isNull, subExprState.value) + }.getOrElse { + val isNull = ctx.freshName("isNull") + val primitive = ctx.freshName("primitive") + val ve = GeneratedExpressionCode("", isNull, primitive) + ve.code = genCode(ctx, ve) + // Add `this` in the comment. + ve.copy(s"/* ${this.toCommentSafeString} */\n" + ve.code.trim) + } } /** @@ -145,11 +152,37 @@ abstract class Expression extends TreeNode[Expression] { case (i1, i2) => i1 == i2 } } + // Non-deterministic expressions cannot be semantically equal + if (!deterministic || !other.deterministic) return false val elements1 = this.productIterator.toSeq val elements2 = other.asInstanceOf[Product].productIterator.toSeq checkSemantic(elements1, elements2) } + /** + * Returns the hash for this expression. Expressions that compute the same result, even if + * they differ cosmetically, should return the same hash. + */ + def semanticHash() : Int = { + def computeHash(e: Seq[Any]): Int = { + // See http://stackoverflow.com/questions/113511/hash-code-implementation + var hash: Int = 17 + e.foreach(i => { + val h: Int = i match { + case e: Expression => e.semanticHash() + case Some(e: Expression) => e.semanticHash() + case t: Traversable[_] => computeHash(t.toSeq) + case null => 0 + case other => other.hashCode() + } + hash = hash * 37 + h + }) + hash + } + + computeHash(this.productIterator.toSeq) + } + /** * Checks the input data types, returns `TypeCheckResult.success` if it's valid, * or returns a `TypeCheckResult` with an error message if invalid.
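A minimal sketch (illustrative only, not part of the patch itself) of the contract the new semanticHash/semanticEquals pair and EquivalentExpressions are meant to satisfy; it assumes Catalyst's Add and Literal expressions and constructs a single AttributeReference as a stand-in input column:

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types.IntegerType

    val a = AttributeReference("a", IntegerType)()   // stand-in input column
    val e1 = Add(a, Literal(1))
    val e2 = Add(a, Literal(1))                      // separate tree, same semantics

    // Cosmetically distinct but semantically equal trees agree on both methods.
    assert(e1.semanticEquals(e2))
    assert(e1.semanticHash() == e2.semanticHash())

    // EquivalentExpressions groups them, so codegen can evaluate the common subtree once.
    val equiv = new EquivalentExpressions
    equiv.addExprTree(e1)
    equiv.addExprTree(e2)
    assert(equiv.getEquivalentExprs(e1).size == 2)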
@@ -169,12 +202,27 @@ abstract class Expression extends TreeNode[Expression] { */ def prettyString: String = { transform { - case a: AttributeReference => PrettyAttribute(a.name) + case a: AttributeReference => PrettyAttribute(a.name, a.dataType) case u: UnresolvedAttribute => PrettyAttribute(u.name) }.toString } - override def toString: String = prettyName + children.mkString("(", ",", ")") + private def flatArguments = productIterator.flatMap { + case t: Traversable[_] => t + case single => single :: Nil + } + + override def simpleString: String = toString + + override def toString: String = prettyName + flatArguments.mkString("(", ",", ")") + + /** + * Returns the string representation of this expression that is safe to be put in + * code comments of generated code. + */ + protected def toCommentSafeString: String = this.toString + .replace("*/", "\\*\\/") + .replace("\\u", "\\\\u") } @@ -276,7 +324,7 @@ abstract class UnaryExpression extends Expression { ev: GeneratedExpressionCode, f: String => String): String = { nullSafeCodeGen(ctx, ev, eval => { - s"${ev.primitive} = ${f(eval)};" + s"${ev.value} = ${f(eval)};" }) } @@ -292,10 +340,10 @@ abstract class UnaryExpression extends Expression { ev: GeneratedExpressionCode, f: String => String): String = { val eval = child.gen(ctx) - val resultCode = f(eval.primitive) + val resultCode = f(eval.value) eval.code + s""" boolean ${ev.isNull} = ${eval.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { $resultCode } @@ -357,7 +405,7 @@ abstract class BinaryExpression extends Expression { ev: GeneratedExpressionCode, f: (String, String) => String): String = { nullSafeCodeGen(ctx, ev, (eval1, eval2) => { - s"${ev.primitive} = ${f(eval1, eval2)};" + s"${ev.value} = ${f(eval1, eval2)};" }) } @@ -375,11 +423,11 @@ abstract class BinaryExpression extends Expression { f: (String, String) => String): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) - val resultCode = f(eval1.primitive, eval2.primitive) + val resultCode = f(eval1.value, eval2.value) s""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { ${eval2.code} if (!${eval2.isNull}) { @@ -444,7 +492,7 @@ abstract class TernaryExpression extends Expression { override def nullable: Boolean = children.exists(_.nullable) /** - * Default behavior of evaluation according to the default nullability of BinaryExpression. + * Default behavior of evaluation according to the default nullability of TernaryExpression. * If subclass of BinaryExpression override nullable, probably should also override this. */ override def eval(input: InternalRow): Any = { @@ -463,7 +511,7 @@ abstract class TernaryExpression extends Expression { } /** - * Called by default [[eval]] implementation. If subclass of BinaryExpression keep the default + * Called by default [[eval]] implementation. If subclass of TernaryExpression keep the default * nullability, they can override this method to save null-check code. If we need full control * of evaluation process, we should override [[eval]]. 
*/ @@ -482,7 +530,7 @@ abstract class TernaryExpression extends Expression { ev: GeneratedExpressionCode, f: (String, String, String) => String): String = { nullSafeCodeGen(ctx, ev, (eval1, eval2, eval3) => { - s"${ev.primitive} = ${f(eval1, eval2, eval3)};" + s"${ev.value} = ${f(eval1, eval2, eval3)};" }) } @@ -499,11 +547,11 @@ abstract class TernaryExpression extends Expression { ev: GeneratedExpressionCode, f: (String, String, String) => String): String = { val evals = children.map(_.gen(ctx)) - val resultCode = f(evals(0).primitive, evals(1).primitive, evals(2).primitive) + val resultCode = f(evals(0).value, evals(1).value, evals(2).value) s""" ${evals(0).code} boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${evals(0).isNull}) { ${evals(1).code} if (!${evals(1).isNull}) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala deleted file mode 100644 index 3caf0fb3410c..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.types._ - -case class FromUnsafe(child: Expression) extends UnaryExpression - with ExpectsInputTypes with CodegenFallback { - - override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(ArrayType, StructType, MapType)) - - override def dataType: DataType = child.dataType - - private def convert(value: Any, dt: DataType): Any = dt match { - case StructType(fields) => - val row = value.asInstanceOf[UnsafeRow] - val result = new Array[Any](fields.length) - fields.map(_.dataType).zipWithIndex.foreach { case (dt, i) => - if (!row.isNullAt(i)) { - result(i) = convert(row.get(i, dt), dt) - } - } - new GenericInternalRow(result) - - case ArrayType(elementType, _) => - val array = value.asInstanceOf[UnsafeArrayData] - val length = array.numElements() - val result = new Array[Any](length) - var i = 0 - while (i < length) { - if (!array.isNullAt(i)) { - result(i) = convert(array.get(i, elementType), elementType) - } - i += 1 - } - new GenericArrayData(result) - - case MapType(kt, vt, _) => - val map = value.asInstanceOf[UnsafeMapData] - val safeKeyArray = convert(map.keys, ArrayType(kt)).asInstanceOf[GenericArrayData] - val safeValueArray = convert(map.values, ArrayType(vt)).asInstanceOf[GenericArrayData] - new ArrayBasedMapData(safeKeyArray, safeValueArray) - - case _ => value - } - - override def nullSafeEval(input: Any): Any = { - convert(input, dataType) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/GenericSpecializedGetters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/GenericSpecializedGetters.scala deleted file mode 100644 index 6e957928e02a..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/GenericSpecializedGetters.scala +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.types.{DataType, MapData, ArrayData, Decimal} -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} - -trait GenericSpecializedGetters extends SpecializedGetters { - - def genericGet(ordinal: Int): Any - - private def getAs[T](ordinal: Int) = genericGet(ordinal).asInstanceOf[T] - - override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null - - override def get(ordinal: Int, elementType: DataType): AnyRef = getAs(ordinal) - - override def getBoolean(ordinal: Int): Boolean = getAs(ordinal) - - override def getByte(ordinal: Int): Byte = getAs(ordinal) - - override def getShort(ordinal: Int): Short = getAs(ordinal) - - override def getInt(ordinal: Int): Int = getAs(ordinal) - - override def getLong(ordinal: Int): Long = getAs(ordinal) - - override def getFloat(ordinal: Int): Float = getAs(ordinal) - - override def getDouble(ordinal: Int): Double = getAs(ordinal) - - override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal) - - override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal) - - override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal) - - override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal) - - override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) - - override def getArray(ordinal: Int): ArrayData = getAs(ordinal) - - override def getMap(ordinal: Int): MapData = getAs(ordinal) -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala index 1e74f716955e..bf215783fc27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.rdd.SqlNewHadoopRDD +import org.apache.spark.rdd.SqlNewHadoopRDDState import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types.{DataType, StringType} @@ -37,13 +37,13 @@ case class InputFileName() extends LeafExpression with Nondeterministic { override protected def initInternal(): Unit = {} override protected def evalInternal(input: InternalRow): UTF8String = { - SqlNewHadoopRDD.getInputFileName() + SqlNewHadoopRDDState.getInputFileName() } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { ev.isNull = "false" - s"final ${ctx.javaType(dataType)} ${ev.primitive} = " + - "org.apache.spark.rdd.SqlNewHadoopRDD.getInputFileName();" + s"final ${ctx.javaType(dataType)} ${ev.value} = " + + "org.apache.spark.rdd.SqlNewHadoopRDDState.getInputFileName();" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala new file mode 100644 index 000000000000..935c3aa28c99 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.{MapData, ArrayData} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} + + +/** + * A mutable wrapper that makes two rows appear as a single concatenated row. Designed to + * be instantiated once per thread and reused. + */ +class JoinedRow extends InternalRow { + private[this] var row1: InternalRow = _ + private[this] var row2: InternalRow = _ + + def this(left: InternalRow, right: InternalRow) = { + this() + row1 = left + row2 = right + } + + /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ + def apply(r1: InternalRow, r2: InternalRow): JoinedRow = { + row1 = r1 + row2 = r2 + this + } + + /** Updates this JoinedRow by updating its left base row. Returns itself. */ + def withLeft(newLeft: InternalRow): JoinedRow = { + row1 = newLeft + this + } + + /** Updates this JoinedRow by updating its right base row. Returns itself. */ + def withRight(newRight: InternalRow): JoinedRow = { + row2 = newRight + this + } + + override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = { + assert(fieldTypes.length == row1.numFields + row2.numFields) + val (left, right) = fieldTypes.splitAt(row1.numFields) + row1.toSeq(left) ++ row2.toSeq(right) + } + + override def numFields: Int = row1.numFields + row2.numFields + + override def get(i: Int, dt: DataType): AnyRef = + if (i < row1.numFields) row1.get(i, dt) else row2.get(i - row1.numFields, dt) + + override def isNullAt(i: Int): Boolean = + if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields) + + override def getBoolean(i: Int): Boolean = + if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields) + + override def getByte(i: Int): Byte = + if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields) + + override def getShort(i: Int): Short = + if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields) + + override def getInt(i: Int): Int = + if (i < row1.numFields) row1.getInt(i) else row2.getInt(i - row1.numFields) + + override def getLong(i: Int): Long = + if (i < row1.numFields) row1.getLong(i) else row2.getLong(i - row1.numFields) + + override def getFloat(i: Int): Float = + if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields) + + override def getDouble(i: Int): Double = + if (i < row1.numFields) row1.getDouble(i) else row2.getDouble(i - row1.numFields) + + override def getDecimal(i: Int, precision: Int, scale: Int): Decimal = { + if (i < row1.numFields) { + row1.getDecimal(i, precision, scale) + } else { + row2.getDecimal(i - row1.numFields, precision, scale) + } + } + + override def getUTF8String(i: Int): UTF8String = 
+ if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields) + + override def getBinary(i: Int): Array[Byte] = + if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields) + + override def getArray(i: Int): ArrayData = + if (i < row1.numFields) row1.getArray(i) else row2.getArray(i - row1.numFields) + + override def getInterval(i: Int): CalendarInterval = + if (i < row1.numFields) row1.getInterval(i) else row2.getInterval(i - row1.numFields) + + override def getMap(i: Int): MapData = + if (i < row1.numFields) row1.getMap(i) else row2.getMap(i - row1.numFields) + + override def getStruct(i: Int, numFields: Int): InternalRow = { + if (i < row1.numFields) { + row1.getStruct(i, numFields) + } else { + row2.getStruct(i - row1.numFields, numFields) + } + } + + override def anyNull: Boolean = row1.anyNull || row2.anyNull + + override def copy(): InternalRow = { + val copy1 = row1.copy() + val copy2 = row2.copy() + new JoinedRow(copy1, copy2) + } + + override def toString: String = { + // Make sure toString never throws NullPointerException. + if ((row1 eq null) && (row2 eq null)) { + "[ empty row ]" + } else if (row1 eq null) { + row2.toString + } else if (row2 eq null) { + row1.toString + } else { + s"{${row1.toString} + ${row2.toString}}" + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 291b7a5bc3af..2d7679fdfe04 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -66,7 +66,7 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression with ev.isNull = "false" s""" - final ${ctx.javaType(dataType)} ${ev.primitive} = $partitionMaskTerm + $countTerm; + final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm; $countTerm++; """ } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 79649741025a..053e612f3ecb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} -import org.apache.spark.sql.types.{DataType, Decimal, StructType, _} -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.sql.types.{DataType, StructType} /** * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions. 
@@ -62,6 +61,8 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = this(expressions.map(BindReferences.bindReference(_, inputSchema))) + private[this] val buffer = new Array[Any](expressions.size) + expressions.foreach(_.foreach { case n: Nondeterministic => n.setInitialValues() case _ => @@ -79,7 +80,13 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu override def apply(input: InternalRow): InternalRow = { var i = 0 while (i < exprArray.length) { - mutableRow(i) = exprArray(i).eval(input) + // Store the result into buffer first, to make the projection atomic (needed by aggregation) + buffer(i) = exprArray(i).eval(input) + i += 1 + } + i = 0 + while (i < exprArray.length) { + mutableRow(i) = buffer(i) i += 1 } mutableRow @@ -95,16 +102,6 @@ abstract class UnsafeProjection extends Projection { object UnsafeProjection { - /* - * Returns whether UnsafeProjection can support given StructType, Array[DataType] or - * Seq[Expression]. - */ - def canSupport(schema: StructType): Boolean = canSupport(schema.fields.map(_.dataType)) - def canSupport(exprs: Seq[Expression]): Boolean = canSupport(exprs.map(_.dataType).toArray) - private def canSupport(types: Array[DataType]): Boolean = { - types.forall(GenerateUnsafeProjection.canSupport) - } - /** * Returns an UnsafeProjection for given StructType. */ @@ -121,7 +118,11 @@ object UnsafeProjection { * Returns an UnsafeProjection for given sequence of Expressions (bounded). */ def create(exprs: Seq[Expression]): UnsafeProjection = { - GenerateUnsafeProjection.generate(exprs) + val unsafeExprs = exprs.map(_ transform { + case CreateStruct(children) => CreateStructUnsafe(children) + case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) + }) + GenerateUnsafeProjection.generate(unsafeExprs) } def create(expr: Expression): UnsafeProjection = create(Seq(expr)) @@ -133,6 +134,22 @@ object UnsafeProjection { def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = { create(exprs.map(BindReferences.bindReference(_, inputSchema))) } + + /** + * Same as other create()'s but allowing enabling/disabling subexpression elimination. + * TODO: refactor the plumbing and clean this up. + */ + def create( + exprs: Seq[Expression], + inputSchema: Seq[Attribute], + subexpressionEliminationEnabled: Boolean): UnsafeProjection = { + val e = exprs.map(BindReferences.bindReference(_, inputSchema)) + .map(_ transform { + case CreateStruct(children) => CreateStructUnsafe(children) + case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) + }) + GenerateUnsafeProjection.generate(e, subexpressionEliminationEnabled) + } } /** @@ -152,13 +169,7 @@ object FromUnsafeProjection { */ def apply(fields: Seq[DataType]): Projection = { create(fields.zipWithIndex.map(x => { - val b = new BoundReference(x._2, x._1, true) - // todo: this is quite slow, maybe remove this whole projection after remove generic getter of - // InternalRow? - b.dataType match { - case _: StructType | _: ArrayType | _: MapType => FromUnsafe(b) - case _ => b - } + new BoundReference(x._2, x._1, true) })) } @@ -169,118 +180,3 @@ object FromUnsafeProjection { GenerateSafeProjection.generate(exprs) } } - -/** - * A mutable wrapper that makes two rows appear as a single concatenated row. Designed to - * be instantiated once per thread and reused. 
- */ -class JoinedRow extends InternalRow { - private[this] var row1: InternalRow = _ - private[this] var row2: InternalRow = _ - - def this(left: InternalRow, right: InternalRow) = { - this() - row1 = left - row2 = right - } - - /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ - def apply(r1: InternalRow, r2: InternalRow): InternalRow = { - row1 = r1 - row2 = r2 - this - } - - /** Updates this JoinedRow by updating its left base row. Returns itself. */ - def withLeft(newLeft: InternalRow): InternalRow = { - row1 = newLeft - this - } - - /** Updates this JoinedRow by updating its right base row. Returns itself. */ - def withRight(newRight: InternalRow): InternalRow = { - row2 = newRight - this - } - - override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - - override def numFields: Int = row1.numFields + row2.numFields - - override def genericGet(i: Int): Any = - if (i < row1.numFields) row1.genericGet(i) else row2.genericGet(i - row1.numFields) - - override def isNullAt(i: Int): Boolean = - if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields) - - override def getBoolean(i: Int): Boolean = - if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields) - - override def getByte(i: Int): Byte = - if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields) - - override def getShort(i: Int): Short = - if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields) - - override def getInt(i: Int): Int = - if (i < row1.numFields) row1.getInt(i) else row2.getInt(i - row1.numFields) - - override def getLong(i: Int): Long = - if (i < row1.numFields) row1.getLong(i) else row2.getLong(i - row1.numFields) - - override def getFloat(i: Int): Float = - if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields) - - override def getDouble(i: Int): Double = - if (i < row1.numFields) row1.getDouble(i) else row2.getDouble(i - row1.numFields) - - override def getDecimal(i: Int, precision: Int, scale: Int): Decimal = { - if (i < row1.numFields) { - row1.getDecimal(i, precision, scale) - } else { - row2.getDecimal(i - row1.numFields, precision, scale) - } - } - - override def getUTF8String(i: Int): UTF8String = - if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields) - - override def getBinary(i: Int): Array[Byte] = - if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields) - - override def getArray(i: Int): ArrayData = - if (i < row1.numFields) row1.getArray(i) else row2.getArray(i - row1.numFields) - - override def getInterval(i: Int): CalendarInterval = - if (i < row1.numFields) row1.getInterval(i) else row2.getInterval(i - row1.numFields) - - override def getMap(i: Int): MapData = - if (i < row1.numFields) row1.getMap(i) else row2.getMap(i - row1.numFields) - - override def getStruct(i: Int, numFields: Int): InternalRow = { - if (i < row1.numFields) { - row1.getStruct(i, numFields) - } else { - row2.getStruct(i - row1.numFields, numFields) - } - } - - override def copy(): InternalRow = { - val copy1 = row1.copy() - val copy2 = row2.copy() - new JoinedRow(copy1, copy2) - } - - override def toString: String = { - // Make sure toString never throws NullPointerException. 
- if ((row1 eq null) && (row2 eq null)) { - "[ empty row ]" - } else if (row1 eq null) { - row2.mkString("[", ",", "]") - } else if (row2 eq null) { - row1.mkString("[", ",", "]") - } else { - mkString("[", ",", "]") - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 11c7950c0613..85faa19bbf5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -19,19 +19,25 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types.DataType /** * User-defined function. + * @param function The user-defined Scala function to run. + * Note that if you use primitive parameters, you are not able to check if it is + * null or not, and the UDF will return null for you if the primitive input is + * null. Use a boxed type or [[Option]] if you want to do the null-handling yourself. * @param dataType Return type of function. + * @param children The input expressions of this UDF. + * @param inputTypes The expected input types of this UDF. */ case class ScalaUDF( function: AnyRef, dataType: DataType, children: Seq[Expression], inputTypes: Seq[DataType] = Nil) - extends Expression with ImplicitCastInputTypes with CodegenFallback { + extends Expression with ImplicitCastInputTypes { override def nullable: Boolean = true @@ -60,6 +66,10 @@ case class ScalaUDF( */ + // Accessors used in genCode + def userDefinedFunc(): AnyRef = function + def getChildren(): Seq[Expression] = children + private[this] val f = children.size match { case 0 => val func = function.asInstanceOf[() => Any] @@ -960,6 +970,89 @@ case class ScalaUDF( } // scalastyle:on + + // Generates code used to convert the arguments to Scala types for user-defined functions + private[this] def genCodeForConverter(ctx: CodeGenContext, index: Int): String = { + val converterClassName = classOf[Any => Any].getName + val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$" + val expressionClassName = classOf[Expression].getName + val scalaUDFClassName = classOf[ScalaUDF].getName + + val converterTerm = ctx.freshName("converter") + val expressionIdx = ctx.references.size - 1 + ctx.addMutableState(converterClassName, converterTerm, + s"this.$converterTerm = ($converterClassName)$typeConvertersClassName" + + s".createToScalaConverter(((${expressionClassName})((($scalaUDFClassName)" + + s"expressions[$expressionIdx]).getChildren().apply($index))).dataType());") + converterTerm + } + + override def genCode( + ctx: CodeGenContext, + ev: GeneratedExpressionCode): String = { + + ctx.references += this + + val scalaUDFClassName = classOf[ScalaUDF].getName + val converterClassName = classOf[Any => Any].getName + val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$" + val expressionClassName = classOf[Expression].getName + + // Generates code used to convert the returned value of the user-defined function to Catalyst type + val catalystConverterTerm = ctx.freshName("catalystConverter") + val catalystConverterTermIdx = ctx.references.size - 1 +
ctx.addMutableState(converterClassName, catalystConverterTerm, + s"this.$catalystConverterTerm = ($converterClassName)$typeConvertersClassName" + + s".createToCatalystConverter((($scalaUDFClassName)expressions" + + s"[$catalystConverterTermIdx]).dataType());") + + val resultTerm = ctx.freshName("result") + + // This must be called before children expressions' codegen + // because ctx.references is used in genCodeForConverter + val converterTerms = (0 until children.size).map(genCodeForConverter(ctx, _)) + + // Initialize user-defined function + val funcClassName = s"scala.Function${children.size}" + + val funcTerm = ctx.freshName("udf") + val funcExpressionIdx = ctx.references.size - 1 + ctx.addMutableState(funcClassName, funcTerm, + s"this.$funcTerm = ($funcClassName)((($scalaUDFClassName)expressions" + + s"[$funcExpressionIdx]).userDefinedFunc());") + + // codegen for children expressions + val evals = children.map(_.gen(ctx)) + + // Generate the codes for expressions and calling user-defined function + // We need to get the boxedType of dataType's javaType here. Because for the dataType + // such as IntegerType, its javaType is `int` and the returned type of user-defined + // function is Object. Trying to convert an Object to `int` will cause casting exception. + val evalCode = evals.map(_.code).mkString + val (converters, funcArguments) = converterTerms.zipWithIndex.map { case (converter, i) => + val eval = evals(i) + val argTerm = ctx.freshName("arg") + val convert = s"Object $argTerm = ${eval.isNull} ? null : $converter.apply(${eval.value});" + (convert, argTerm) + }.unzip + + val callFunc = s"${ctx.boxedType(dataType)} $resultTerm = " + + s"(${ctx.boxedType(dataType)})${catalystConverterTerm}" + + s".apply($funcTerm.apply(${funcArguments.mkString(", ")}));" + + s""" + $evalCode + ${converters.mkString("\n")} + $callFunc + + boolean ${ev.isNull} = $resultTerm == null; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.value} = $resultTerm; + } + """ + } + private[this] val converter = CatalystTypeConverters.createToCatalystConverter(dataType) override def eval(input: InternalRow): Any = converter(f(input)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index f6a872ba446e..290c128d65b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types._ +import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.BinaryPrefixComparator import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.DoublePrefixComparator abstract sealed class SortDirection @@ -62,7 +63,8 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val childCode = child.child.gen(ctx) - val input = childCode.primitive + val input = childCode.value + val BinaryPrefixCmp = classOf[BinaryPrefixComparator].getName val DoublePrefixCmp = classOf[DoublePrefixComparator].getName val (nullValue: Long, prefixCode: String) = 
child.child.dataType match { @@ -76,6 +78,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { (DoublePrefixComparator.computePrefix(Double.NegativeInfinity), s"$DoublePrefixCmp.computePrefix((double)$input)") case StringType => (0L, s"$input.getPrefix()") + case BinaryType => (0L, s"$BinaryPrefixCmp.computePrefix($input)") case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS => val prefix = if (dt.precision <= Decimal.MAX_LONG_DIGITS) { s"$input.toUnscaledLong()" @@ -94,10 +97,10 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { childCode.code + s""" - |long ${ev.primitive} = ${nullValue}L; + |long ${ev.value} = ${nullValue}L; |boolean ${ev.isNull} = false; |if (!${childCode.isNull}) { - | ${ev.primitive} = $prefixCode; + | ${ev.value} = $prefixCode; |} """.stripMargin } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index 4b1772a2deed..8bff173d64eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -47,6 +47,6 @@ private[sql] case class SparkPartitionID() extends LeafExpression with Nondeterm ctx.addMutableState(ctx.JAVA_INT, idTerm, s"$idTerm = org.apache.spark.TaskContext.getPartitionId();") ev.isNull = "false" - s"final ${ctx.javaType(dataType)} ${ev.primitive} = $idTerm;" + s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index d149a5b179b6..475cbe005a6e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -41,7 +41,7 @@ import org.apache.spark.unsafe.types.UTF8String * val newCopy = new Mutable$tpe * newCopy.isNull = isNull * newCopy.value = value - * newCopy.asInstanceOf[this.type] + * newCopy * } * }""" * }.foreach(println) @@ -78,7 +78,7 @@ final class MutableInt extends MutableValue { val newCopy = new MutableInt newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[MutableInt] + newCopy } } @@ -93,7 +93,7 @@ final class MutableFloat extends MutableValue { val newCopy = new MutableFloat newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[MutableFloat] + newCopy } } @@ -108,7 +108,7 @@ final class MutableBoolean extends MutableValue { val newCopy = new MutableBoolean newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[MutableBoolean] + newCopy } } @@ -123,7 +123,7 @@ final class MutableDouble extends MutableValue { val newCopy = new MutableDouble newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[MutableDouble] + newCopy } } @@ -138,7 +138,7 @@ final class MutableShort extends MutableValue { val newCopy = new MutableShort newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[MutableShort] + newCopy } } @@ -153,7 +153,7 @@ final class MutableLong extends MutableValue { val newCopy = new MutableLong newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[MutableLong] + newCopy } } @@ -168,7 +168,7 @@ final class MutableByte extends MutableValue { val 
newCopy = new MutableByte newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[MutableByte] + newCopy } } @@ -183,7 +183,7 @@ final class MutableAny extends MutableValue { val newCopy = new MutableAny newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[MutableAny] + newCopy } } @@ -192,7 +192,8 @@ final class MutableAny extends MutableValue { * based on the dataTypes of each column. The intent is to decrease garbage when modifying the * values of primitive columns. */ -final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableRow { +final class SpecificMutableRow(val values: Array[MutableValue]) + extends MutableRow with BaseGenericInternalRow { def this(dataTypes: Seq[DataType]) = this( @@ -213,8 +214,6 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR override def numFields: Int = values.length - override def toSeq: Seq[Any] = values.map(_.boxed) - override def setNullAt(i: Int): Unit = { values(i).isNull = true } @@ -232,7 +231,7 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR new GenericInternalRow(newValues) } - override def genericGet(i: Int): Any = values(i).boxed + override protected def genericGet(i: Int): Any = values(i).boxed override def update(ordinal: Int, value: Any) { if (value == null) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala new file mode 100644 index 000000000000..94ac4bf09b90 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.types._ + +case class Average(child: Expression) extends DeclarativeAggregate { + + override def prettyName: String = "avg" + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + // Return data type. 
+ override def dataType: DataType = resultType + + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function average") + + private lazy val resultType = child.dataType match { + case DecimalType.Fixed(p, s) => + DecimalType.bounded(p + 4, s + 4) + case _ => DoubleType + } + + private lazy val sumDataType = child.dataType match { + case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s) + case _ => DoubleType + } + + private lazy val sum = AttributeReference("sum", sumDataType)() + private lazy val count = AttributeReference("count", LongType)() + + override lazy val aggBufferAttributes = sum :: count :: Nil + + override lazy val initialValues = Seq( + /* sum = */ Cast(Literal(0), sumDataType), + /* count = */ Literal(0L) + ) + + override lazy val updateExpressions = Seq( + /* sum = */ + Add( + sum, + Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)), + /* count = */ If(IsNull(child), count, count + 1L) + ) + + override lazy val mergeExpressions = Seq( + /* sum = */ sum.left + sum.right, + /* count = */ count.left + count.right + ) + + // If all inputs are null, count will be 0 and we will get null after the division. + override lazy val evaluateExpression = child.dataType match { + case DecimalType.Fixed(p, s) => + // increase the precision and scale to prevent precision loss + val dt = DecimalType.bounded(p + 14, s + 4) + Cast(Cast(sum, dt) / Cast(count, dt), resultType) + case _ => + Cast(sum, resultType) / Cast(count, resultType) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala new file mode 100644 index 000000000000..d07d4c338cdf --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -0,0 +1,229 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.types._ + +/** + * A central moment is the expected value of a specified power of the deviation of a random + * variable from the mean. Central moments are often used to characterize the shape of + * a distribution. + * + * This class implements online, one-pass algorithms for computing the central moments of a set of + * points.
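+ *
+ * As an illustration of the update rule used below (first two moments only): when a new
+ * value x arrives,
+ *   delta = x - mean; n += 1; mean += delta / n; M2 += delta * (x - mean)
+ * where M2 is the un-normalized second central moment (the sum of squared deviations).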
+ * + * Behavior: + * - null values are ignored + * - returns `Double.NaN` when the column contains `Double.NaN` values + * + * References: + * - Xiangrui Meng. "Simpler Online Updates for Arbitrary-Order Central Moments." + * 2015. http://arxiv.org/abs/1510.04923 + * + * @see [[https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance + * Algorithms for calculating variance (Wikipedia)]] + * + * @param child to compute central moments of. + */ +abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate with Serializable { + + /** + * The central moment order to be computed. + */ + protected def momentOrder: Int + + override def children: Seq[Expression] = Seq(child) + + override def nullable: Boolean = false + + override def dataType: DataType = DoubleType + + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, s"function $prettyName") + + override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) + + /** + * Size of aggregation buffer. + */ + private[this] val bufferSize = 5 + + override val aggBufferAttributes: Seq[AttributeReference] = Seq.tabulate(bufferSize) { i => + AttributeReference(s"M$i", DoubleType)() + } + + // Note: although this simply copies aggBufferAttributes, this common code can not be placed + // in the superclass because that will lead to initialization ordering issues. + override val inputAggBufferAttributes: Seq[AttributeReference] = + aggBufferAttributes.map(_.newInstance()) + + // buffer offsets + private[this] val nOffset = mutableAggBufferOffset + private[this] val meanOffset = mutableAggBufferOffset + 1 + private[this] val secondMomentOffset = mutableAggBufferOffset + 2 + private[this] val thirdMomentOffset = mutableAggBufferOffset + 3 + private[this] val fourthMomentOffset = mutableAggBufferOffset + 4 + + // frequently used values for online updates + private[this] var delta = 0.0 + private[this] var deltaN = 0.0 + private[this] var delta2 = 0.0 + private[this] var deltaN2 = 0.0 + private[this] var n = 0.0 + private[this] var mean = 0.0 + private[this] var m2 = 0.0 + private[this] var m3 = 0.0 + private[this] var m4 = 0.0 + + /** + * Initialize all moments to zero. + */ + override def initialize(buffer: MutableRow): Unit = { + for (aggIndex <- 0 until bufferSize) { + buffer.setDouble(mutableAggBufferOffset + aggIndex, 0.0) + } + } + + /** + * Update the central moments buffer. 
+ */ + override def update(buffer: MutableRow, input: InternalRow): Unit = { + val v = Cast(child, DoubleType).eval(input) + if (v != null) { + val updateValue = v match { + case d: Double => d + } + + n = buffer.getDouble(nOffset) + mean = buffer.getDouble(meanOffset) + + n += 1.0 + buffer.setDouble(nOffset, n) + delta = updateValue - mean + deltaN = delta / n + mean += deltaN + buffer.setDouble(meanOffset, mean) + + if (momentOrder >= 2) { + m2 = buffer.getDouble(secondMomentOffset) + m2 += delta * (delta - deltaN) + buffer.setDouble(secondMomentOffset, m2) + } + + if (momentOrder >= 3) { + delta2 = delta * delta + deltaN2 = deltaN * deltaN + m3 = buffer.getDouble(thirdMomentOffset) + m3 += -3.0 * deltaN * m2 + delta * (delta2 - deltaN2) + buffer.setDouble(thirdMomentOffset, m3) + } + + if (momentOrder >= 4) { + m4 = buffer.getDouble(fourthMomentOffset) + m4 += -4.0 * deltaN * m3 - 6.0 * deltaN2 * m2 + + delta * (delta * delta2 - deltaN * deltaN2) + buffer.setDouble(fourthMomentOffset, m4) + } + } + } + + /** + * Merge two central moment buffers. + */ + override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { + val n1 = buffer1.getDouble(nOffset) + val n2 = buffer2.getDouble(inputAggBufferOffset) + val mean1 = buffer1.getDouble(meanOffset) + val mean2 = buffer2.getDouble(inputAggBufferOffset + 1) + + var secondMoment1 = 0.0 + var secondMoment2 = 0.0 + + var thirdMoment1 = 0.0 + var thirdMoment2 = 0.0 + + var fourthMoment1 = 0.0 + var fourthMoment2 = 0.0 + + n = n1 + n2 + buffer1.setDouble(nOffset, n) + delta = mean2 - mean1 + deltaN = if (n == 0.0) 0.0 else delta / n + mean = mean1 + deltaN * n2 + buffer1.setDouble(mutableAggBufferOffset + 1, mean) + + // higher order moments computed according to: + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics + if (momentOrder >= 2) { + secondMoment1 = buffer1.getDouble(secondMomentOffset) + secondMoment2 = buffer2.getDouble(inputAggBufferOffset + 2) + m2 = secondMoment1 + secondMoment2 + delta * deltaN * n1 * n2 + buffer1.setDouble(secondMomentOffset, m2) + } + + if (momentOrder >= 3) { + thirdMoment1 = buffer1.getDouble(thirdMomentOffset) + thirdMoment2 = buffer2.getDouble(inputAggBufferOffset + 3) + m3 = thirdMoment1 + thirdMoment2 + deltaN * deltaN * delta * n1 * n2 * + (n1 - n2) + 3.0 * deltaN * (n1 * secondMoment2 - n2 * secondMoment1) + buffer1.setDouble(thirdMomentOffset, m3) + } + + if (momentOrder >= 4) { + fourthMoment1 = buffer1.getDouble(fourthMomentOffset) + fourthMoment2 = buffer2.getDouble(inputAggBufferOffset + 4) + m4 = fourthMoment1 + fourthMoment2 + deltaN * deltaN * deltaN * delta * n1 * + n2 * (n1 * n1 - n1 * n2 + n2 * n2) + deltaN * deltaN * 6.0 * + (n1 * n1 * secondMoment2 + n2 * n2 * secondMoment1) + + 4.0 * deltaN * (n1 * thirdMoment2 - n2 * thirdMoment1) + buffer1.setDouble(fourthMomentOffset, m4) + } + } + + /** + * Compute aggregate statistic from sufficient moments. + * @param centralMoments Length `momentOrder + 1` array of central moments (un-normalized) + * needed to compute the aggregate stat. 
+ */ + def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Any + + override final def eval(buffer: InternalRow): Any = { + val n = buffer.getDouble(nOffset) + val mean = buffer.getDouble(meanOffset) + val moments = Array.ofDim[Double](momentOrder + 1) + moments(0) = 1.0 + moments(1) = 0.0 + if (momentOrder >= 2) { + moments(2) = buffer.getDouble(secondMomentOffset) + } + if (momentOrder >= 3) { + moments(3) = buffer.getDouble(thirdMomentOffset) + } + if (momentOrder >= 4) { + moments(4) = buffer.getDouble(fourthMomentOffset) + } + + getStatistic(n, mean, moments) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala new file mode 100644 index 000000000000..00d7436b710d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.types._ + +/** + * Compute Pearson correlation between two expressions. + * When applied on empty data (i.e., count is zero), it returns NULL. 
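+ *
+ * Sketch of the one-pass update implemented below: for each non-null pair (x, y),
+ *   count += 1; xAvg += (x - xAvg) / count; yAvg += (y - yAvg) / count
+ * while Ck accumulates the co-moment and MkX, MkY the second moments of x and y;
+ * the final result is Ck / sqrt(MkX * MkY).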
+ * + * Definition of Pearson correlation can be found at + * http://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient + */ +case class Corr( + left: Expression, + right: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) + extends ImperativeAggregate { + + def this(left: Expression, right: Expression) = + this(left, right, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) + + override def children: Seq[Expression] = Seq(left, right) + + override def nullable: Boolean = false + + override def dataType: DataType = DoubleType + + override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) + + override def checkInputDataTypes(): TypeCheckResult = { + if (left.dataType.isInstanceOf[DoubleType] && right.dataType.isInstanceOf[DoubleType]) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure( + s"corr requires that both arguments are double type, " + + s"not (${left.dataType}, ${right.dataType}).") + } + } + + override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) + + override def inputAggBufferAttributes: Seq[AttributeReference] = { + aggBufferAttributes.map(_.newInstance()) + } + + override val aggBufferAttributes: Seq[AttributeReference] = Seq( + AttributeReference("xAvg", DoubleType)(), + AttributeReference("yAvg", DoubleType)(), + AttributeReference("Ck", DoubleType)(), + AttributeReference("MkX", DoubleType)(), + AttributeReference("MkY", DoubleType)(), + AttributeReference("count", LongType)()) + + // Local cache of mutableAggBufferOffset(s) that will be used in update and merge + private[this] val mutableAggBufferOffsetPlus1 = mutableAggBufferOffset + 1 + private[this] val mutableAggBufferOffsetPlus2 = mutableAggBufferOffset + 2 + private[this] val mutableAggBufferOffsetPlus3 = mutableAggBufferOffset + 3 + private[this] val mutableAggBufferOffsetPlus4 = mutableAggBufferOffset + 4 + private[this] val mutableAggBufferOffsetPlus5 = mutableAggBufferOffset + 5 + + // Local cache of inputAggBufferOffset(s) that will be used in update and merge + private[this] val inputAggBufferOffsetPlus1 = inputAggBufferOffset + 1 + private[this] val inputAggBufferOffsetPlus2 = inputAggBufferOffset + 2 + private[this] val inputAggBufferOffsetPlus3 = inputAggBufferOffset + 3 + private[this] val inputAggBufferOffsetPlus4 = inputAggBufferOffset + 4 + private[this] val inputAggBufferOffsetPlus5 = inputAggBufferOffset + 5 + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def initialize(buffer: MutableRow): Unit = { + buffer.setDouble(mutableAggBufferOffset, 0.0) + buffer.setDouble(mutableAggBufferOffsetPlus1, 0.0) + buffer.setDouble(mutableAggBufferOffsetPlus2, 0.0) + buffer.setDouble(mutableAggBufferOffsetPlus3, 0.0) + buffer.setDouble(mutableAggBufferOffsetPlus4, 0.0) + buffer.setLong(mutableAggBufferOffsetPlus5, 0L) + } + + override def update(buffer: MutableRow, input: InternalRow): Unit = { + val leftEval = left.eval(input) + val rightEval = right.eval(input) + + if (leftEval != null && rightEval != null) { + val x = leftEval.asInstanceOf[Double] + val y = rightEval.asInstanceOf[Double] + + var xAvg = buffer.getDouble(mutableAggBufferOffset) + var yAvg = buffer.getDouble(mutableAggBufferOffsetPlus1) + var 
Ck = buffer.getDouble(mutableAggBufferOffsetPlus2)
+ var MkX = buffer.getDouble(mutableAggBufferOffsetPlus3)
+ var MkY = buffer.getDouble(mutableAggBufferOffsetPlus4)
+ var count = buffer.getLong(mutableAggBufferOffsetPlus5)
+
+ val deltaX = x - xAvg
+ val deltaY = y - yAvg
+ count += 1
+ xAvg += deltaX / count
+ yAvg += deltaY / count
+ Ck += deltaX * (y - yAvg)
+ MkX += deltaX * (x - xAvg)
+ MkY += deltaY * (y - yAvg)
+
+ buffer.setDouble(mutableAggBufferOffset, xAvg)
+ buffer.setDouble(mutableAggBufferOffsetPlus1, yAvg)
+ buffer.setDouble(mutableAggBufferOffsetPlus2, Ck)
+ buffer.setDouble(mutableAggBufferOffsetPlus3, MkX)
+ buffer.setDouble(mutableAggBufferOffsetPlus4, MkY)
+ buffer.setLong(mutableAggBufferOffsetPlus5, count)
+ }
+ }
+
+ // Merge counters from other partitions. Formula can be found at:
+ // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
+ override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
+ val count2 = buffer2.getLong(inputAggBufferOffsetPlus5)
+
+ // We only merge the two buffers if there is at least one record aggregated in buffer2.
+ // We don't need to check the count in buffer1: if count2 is greater than zero, then
+ // totalCount is greater than zero as well, so we won't hit a divide-by-zero.
+ if (count2 > 0) {
+ var xAvg = buffer1.getDouble(mutableAggBufferOffset)
+ var yAvg = buffer1.getDouble(mutableAggBufferOffsetPlus1)
+ var Ck = buffer1.getDouble(mutableAggBufferOffsetPlus2)
+ var MkX = buffer1.getDouble(mutableAggBufferOffsetPlus3)
+ var MkY = buffer1.getDouble(mutableAggBufferOffsetPlus4)
+ var count = buffer1.getLong(mutableAggBufferOffsetPlus5)
+
+ val xAvg2 = buffer2.getDouble(inputAggBufferOffset)
+ val yAvg2 = buffer2.getDouble(inputAggBufferOffsetPlus1)
+ val Ck2 = buffer2.getDouble(inputAggBufferOffsetPlus2)
+ val MkX2 = buffer2.getDouble(inputAggBufferOffsetPlus3)
+ val MkY2 = buffer2.getDouble(inputAggBufferOffsetPlus4)
+
+ val totalCount = count + count2
+ val deltaX = xAvg - xAvg2
+ val deltaY = yAvg - yAvg2
+ Ck += Ck2 + deltaX * deltaY * count / totalCount * count2
+ xAvg = (xAvg * count + xAvg2 * count2) / totalCount
+ yAvg = (yAvg * count + yAvg2 * count2) / totalCount
+ MkX += MkX2 + deltaX * deltaX * count / totalCount * count2
+ MkY += MkY2 + deltaY * deltaY * count / totalCount * count2
+ count = totalCount
+
+ buffer1.setDouble(mutableAggBufferOffset, xAvg)
+ buffer1.setDouble(mutableAggBufferOffsetPlus1, yAvg)
+ buffer1.setDouble(mutableAggBufferOffsetPlus2, Ck)
+ buffer1.setDouble(mutableAggBufferOffsetPlus3, MkX)
+ buffer1.setDouble(mutableAggBufferOffsetPlus4, MkY)
+ buffer1.setLong(mutableAggBufferOffsetPlus5, count)
+ }
+ }
+
+ override def eval(buffer: InternalRow): Any = {
+ val count = buffer.getLong(mutableAggBufferOffsetPlus5)
+ if (count > 0) {
+ val Ck = buffer.getDouble(mutableAggBufferOffsetPlus2)
+ val MkX = buffer.getDouble(mutableAggBufferOffsetPlus3)
+ val MkY = buffer.getDouble(mutableAggBufferOffsetPlus4)
+ val corr = Ck / math.sqrt(MkX * MkY)
+ if (corr.isNaN) {
+ null
+ } else {
+ corr
+ }
+ } else {
+ null
+ }
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
new file mode 100644
index 000000000000..441f52ab5ca5
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ *
contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +case class Count(children: Seq[Expression]) extends DeclarativeAggregate { + + override def nullable: Boolean = false + + // Return data type. + override def dataType: DataType = LongType + + // Expected input data type. + override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(AnyDataType) + + private lazy val count = AttributeReference("count", LongType)() + + override lazy val aggBufferAttributes = count :: Nil + + override lazy val initialValues = Seq( + /* count = */ Literal(0L) + ) + + override lazy val updateExpressions = Seq( + /* count = */ If(children.map(IsNull).reduce(Or), count, count + 1L) + ) + + override lazy val mergeExpressions = Seq( + /* count = */ count.left + count.right + ) + + override lazy val evaluateExpression = Cast(count, LongType) + + override def defaultResult: Option[Literal] = Option(Literal(0L)) +} + +object Count { + def apply(child: Expression): Count = Count(child :: Nil) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala new file mode 100644 index 000000000000..35f57426feaf --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +/** + * Returns the first value of `child` for a group of rows. If the first value of `child` + * is `null`, it returns `null` (respecting nulls). 
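When `ignoreNullsExpr` evaluates to true, null values are skipped and the first non-null value of `child` (if any) is returned instead.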
Even if [[First]] is used on an already
+ * sorted column, if we do partial aggregation and final aggregation (when mergeExpression
+ * is used), its result will not be deterministic (unless the input table is sorted and has
+ * a single partition, and we use a single reducer to do the aggregation).
+ */
+case class First(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate {
+
+ def this(child: Expression) = this(child, Literal.create(false, BooleanType))
+
+ private val ignoreNulls: Boolean = ignoreNullsExpr match {
+ case Literal(b: Boolean, BooleanType) => b
+ case _ =>
+ throw new AnalysisException("The second argument of First should be a boolean literal.")
+ }
+
+ override def children: Seq[Expression] = child :: Nil
+
+ override def nullable: Boolean = true
+
+ // First is not a deterministic function.
+ override def deterministic: Boolean = false
+
+ // Return data type.
+ override def dataType: DataType = child.dataType
+
+ // Expected input data type.
+ override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+
+ private lazy val first = AttributeReference("first", child.dataType)()
+
+ private lazy val valueSet = AttributeReference("valueSet", BooleanType)()
+
+ override lazy val aggBufferAttributes: Seq[AttributeReference] = first :: valueSet :: Nil
+
+ override lazy val initialValues: Seq[Literal] = Seq(
+ /* first = */ Literal.create(null, child.dataType),
+ /* valueSet = */ Literal.create(false, BooleanType)
+ )
+
+ override lazy val updateExpressions: Seq[Expression] = {
+ if (ignoreNulls) {
+ Seq(
+ /* first = */ If(Or(valueSet, IsNull(child)), first, child),
+ /* valueSet = */ Or(valueSet, IsNotNull(child))
+ )
+ } else {
+ Seq(
+ /* first = */ If(valueSet, first, child),
+ /* valueSet = */ Literal.create(true, BooleanType)
+ )
+ }
+ }
+
+ override lazy val mergeExpressions: Seq[Expression] = {
+ // For first, we can just check if valueSet.left is set to true. If it is set
+ // to true, we use first.left. If not, we use first.right (even if valueSet.right is
+ // false, we are safe to do so because first.right will be null in this case).
+ Seq(
+ /* first = */ If(valueSet.left, first.left, first.right),
+ /* valueSet = */ Or(valueSet.left, valueSet.right)
+ )
+ }
+
+ override lazy val evaluateExpression: AttributeReference = first
+
+ override def toString: String = s"first($child)${if (ignoreNulls) " ignore nulls" else ""}"
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
new file mode 100644
index 000000000000..e1fd22e36764
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
@@ -0,0 +1,454 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import java.lang.{Long => JLong} +import java.util + +import com.clearspring.analytics.hash.MurmurHash + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +// scalastyle:off +/** + * HyperLogLog++ (HLL++) is a state of the art cardinality estimation algorithm. This class + * implements the dense version of the HLL++ algorithm as an Aggregate Function. + * + * This implementation has been based on the following papers: + * HyperLogLog: the analysis of a near-optimal cardinality estimation algorithm + * http://algo.inria.fr/flajolet/Publications/FlFuGaMe07.pdf + * + * HyperLogLog in Practice: Algorithmic Engineering of a State of The Art Cardinality Estimation + * Algorithm + * http://static.googleusercontent.com/external_content/untrusted_dlcp/research.google.com/en/us/pubs/archive/40671.pdf + * + * Appendix to HyperLogLog in Practice: Algorithmic Engineering of a State of the Art Cardinality + * Estimation Algorithm + * https://docs.google.com/document/d/1gyjfMHy43U9OWBXxfaeG-3MjGzejW1dlpyMwEYAAWEI/view?fullscreen# + * + * @param child to estimate the cardinality of. + * @param relativeSD the maximum estimation error allowed. + */ +// scalastyle:on +case class HyperLogLogPlusPlus( + child: Expression, + relativeSD: Double = 0.05, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) + extends ImperativeAggregate { + import HyperLogLogPlusPlus._ + + def this(child: Expression) = { + this(child = child, relativeSD = 0.05, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) + } + + def this(child: Expression, relativeSD: Expression) = { + this( + child = child, + relativeSD = HyperLogLogPlusPlus.validateDoubleLiteral(relativeSD), + mutableAggBufferOffset = 0, + inputAggBufferOffset = 0) + } + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + /** + * HLL++ uses 'p' bits for addressing. The more addressing bits we use, the more precise the + * algorithm will be, and the more memory it will require. The 'p' value is based on the relative + * error requested. + * + * HLL++ requires that we use at least 4 bits of addressing space (a minimum precision of 27%). + * + * This method rounds up to the nearest integer. This means that the error is always equal to or + * lower than the requested error. Use the trueRsd method to get the actual RSD + * value. + */ + private[this] val p = Math.ceil(2.0d * Math.log(1.106d / relativeSD) / Math.log(2.0d)).toInt + + require(p >= 4, "HLL++ requires at least 4 bits for addressing. " + + "Use a lower error, at most 27%.") + + /** + * Shift used to extract the index of the register from the hashed value. + * + * This assumes the use of 64-bit hashcodes. + */ + private[this] val idxShift = JLong.SIZE - p + + /** + * Value to pad the 'w' value with before the number of leading zeros is determined. + */ + private[this] val wPadding = 1L << (p - 1) + + /** + * The number of registers used. 
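+ * This equals 2^p; each register records the maximum number of leading zeros seen
+ * among the hashed values assigned to it.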
+ */ + private[this] val m = 1 << p + + /** + * The pre-calculated combination of: alpha * m * m + * + * 'alpha' corrects the raw cardinality estimate 'Z'. See the FlFuGaMe07 paper for its + * derivation. + */ + private[this] val alphaM2 = p match { + case 4 => 0.673d * m * m + case 5 => 0.697d * m * m + case 6 => 0.709d * m * m + case _ => (0.7213d / (1.0d + 1.079d / m)) * m * m + } + + /** + * The number of words used to store the registers. We use Longs for storage because this is the + * most compact way of storage; Spark aligns to 8-byte words or uses Long wrappers. + * + * We only store whole registers per word in order to prevent overly complex bitwise operations. + * In practice this means we only use 60 out of 64 bits. + */ + private[this] val numWords = m / REGISTERS_PER_WORD + 1 + + override def children: Seq[Expression] = Seq(child) + + override def nullable: Boolean = false + + override def dataType: DataType = LongType + + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + + override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) + + /** Allocate enough words to store all registers. */ + override val aggBufferAttributes: Seq[AttributeReference] = Seq.tabulate(numWords) { i => + AttributeReference(s"MS[$i]", LongType)() + } + + // Note: although this simply copies aggBufferAttributes, this common code can not be placed + // in the superclass because that will lead to initialization ordering issues. + override val inputAggBufferAttributes: Seq[AttributeReference] = + aggBufferAttributes.map(_.newInstance()) + + /** Fill all words with zeros. */ + override def initialize(buffer: MutableRow): Unit = { + var word = 0 + while (word < numWords) { + buffer.setLong(mutableAggBufferOffset + word, 0) + word += 1 + } + } + + /** + * Update the HLL++ buffer. + * + * Variable names in the HLL++ paper match variable names in the code. + */ + override def update(buffer: MutableRow, input: InternalRow): Unit = { + val v = child.eval(input) + if (v != null) { + // Create the hashed value 'x'. + val x = MurmurHash.hash64(v) + + // Determine the index of the register we are going to use. + val idx = (x >>> idxShift).toInt + + // Determine the number of leading zeros in the remaining bits 'w'. + val pw = JLong.numberOfLeadingZeros((x << p) | wPadding) + 1L + + // Get the word containing the register we are interested in. + val wordOffset = idx / REGISTERS_PER_WORD + val word = buffer.getLong(mutableAggBufferOffset + wordOffset) + + // Extract the M[J] register value from the word. + val shift = REGISTER_SIZE * (idx - (wordOffset * REGISTERS_PER_WORD)) + val mask = REGISTER_WORD_MASK << shift + val Midx = (word & mask) >>> shift + + // Assign the maximum number of leading zeros to the register. + if (pw > Midx) { + buffer.setLong(mutableAggBufferOffset + wordOffset, (word & ~mask) | (pw << shift)) + } + } + } + + /** + * Merge the HLL buffers by iterating through the registers in both buffers and select the + * maximum number of leading zeros for each register. 
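+ *
+ * Because a register value only ever increases, taking the register-wise maximum of two
+ * sketches yields the same sketch that would have been built from the combined inputs,
+ * which is what makes partial aggregates safe to merge.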
+ */
+ override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
+ var idx = 0
+ var wordOffset = 0
+ while (wordOffset < numWords) {
+ val word1 = buffer1.getLong(mutableAggBufferOffset + wordOffset)
+ val word2 = buffer2.getLong(inputAggBufferOffset + wordOffset)
+ var word = 0L
+ var i = 0
+ var mask = REGISTER_WORD_MASK
+ while (idx < m && i < REGISTERS_PER_WORD) {
+ word |= Math.max(word1 & mask, word2 & mask)
+ mask <<= REGISTER_SIZE
+ i += 1
+ idx += 1
+ }
+ buffer1.setLong(mutableAggBufferOffset + wordOffset, word)
+ wordOffset += 1
+ }
+ }
+
+ /**
+ * Estimate the bias using the raw estimates with their respective biases from the HLL++
+ * appendix. We currently use KNN interpolation to determine the bias (as suggested in the
+ * paper).
+ */
+ def estimateBias(e: Double): Double = {
+ val estimates = RAW_ESTIMATE_DATA(p - 4)
+ val numEstimates = estimates.length
+
+ // The estimates are sorted so we can use a binary search to find the index of the
+ // interpolation estimate closest to the current estimate.
+ val nearestEstimateIndex = util.Arrays.binarySearch(estimates, 0, numEstimates, e) match {
+ case ix if ix < 0 => -(ix + 1)
+ case ix => ix
+ }
+
+ // Use the square of the difference between the current estimate and the estimate at the
+ // given index as the distance metric.
+ def distance(i: Int): Double = {
+ val diff = e - estimates(i)
+ diff * diff
+ }
+
+ // Keep moving bounds as long as the (exclusive) high bound is closer to the estimate than
+ // the lower (inclusive) bound.
+ var low = math.max(nearestEstimateIndex - K + 1, 0)
+ var high = math.min(low + K, numEstimates)
+ while (high < numEstimates && distance(high) < distance(low)) {
+ low += 1
+ high += 1
+ }
+
+ // Calculate the sum of the biases in the low-high interval.
+ val biases = BIAS_DATA(p - 4)
+ var i = low
+ var biasSum = 0.0
+ while (i < high) {
+ biasSum += biases(i)
+ i += 1
+ }
+
+ // Calculate the bias.
+ biasSum / (high - low)
+ }
+
+ /**
+ * Compute the HyperLogLog estimate.
+ *
+ * Variable names in the HLL++ paper match variable names in the code.
+ */
+ override def eval(buffer: InternalRow): Any = {
+ // Compute the inverse of the indicator value 'z' and count the number of zeros 'V'.
+ var zInverse = 0.0d
+ var V = 0.0d
+ var idx = 0
+ var wordOffset = 0
+ while (wordOffset < numWords) {
+ val word = buffer.getLong(mutableAggBufferOffset + wordOffset)
+ var i = 0
+ var shift = 0
+ while (idx < m && i < REGISTERS_PER_WORD) {
+ val Midx = (word >>> shift) & REGISTER_WORD_MASK
+ zInverse += 1.0 / (1 << Midx)
+ if (Midx == 0) {
+ V += 1.0d
+ }
+ shift += REGISTER_SIZE
+ i += 1
+ idx += 1
+ }
+ wordOffset += 1
+ }
+
+ // We integrate two steps from the paper:
+ // val Z = 1.0d / zInverse
+ // val E = alphaM2 * Z
+ @inline
+ def EBiasCorrected = alphaM2 / zInverse match {
+ case e if p < 19 && e < 5.0d * m => e - estimateBias(e)
+ case e => e
+ }
+
+ // Estimate the cardinality.
+ val estimate = if (V > 0) {
+ // Use linear counting for small cardinality estimates.
+ val H = m * Math.log(m / V)
+ if (H <= THRESHOLDS(p - 4)) {
+ H
+ } else {
+ EBiasCorrected
+ }
+ } else {
+ EBiasCorrected
+ }
+
+ // Round to the nearest long value.
+ Math.round(estimate)
+ }
+
+ /**
+ * The rsd of HLL++ is always equal to or better than the rsd requested.
+ * This method returns the rsd this instance actually guarantees.
+ *
+ * @return the actual rsd.
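+ * For example, with the default relativeSD of 0.05 the precision works out to p = 9,
+ * so m = 512 and the guaranteed rsd is 1.04 / sqrt(512), roughly 0.046.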
+ */ + def trueRsd: Double = 1.04 / math.sqrt(m) +} + +// scalastyle:off +/** + * Constants used in the implementation of the HyperLogLogPlusPlus aggregate function. + * + * See the Appendix to HyperLogLog in Practice: Algorithmic Engineering of a State of the Art + * Cardinality (https://docs.google.com/document/d/1gyjfMHy43U9OWBXxfaeG-3MjGzejW1dlpyMwEYAAWEI/view?fullscreen) + * for more information. + */ +// scalastyle:on +object HyperLogLogPlusPlus { + /** + * The size of a word used for storing registers: 64 Bits. + */ + val WORD_SIZE = java.lang.Long.SIZE + + /** + * The number of bits that is required per register. + * + * This number is determined by the maximum number of leading binary zeros a hashcode can + * produce. This is equal to the number of bits the hashcode returns. The current + * implementation uses a 64-bit hashcode, this means 6-bits are (at most) needed to store the + * number of leading zeros. + */ + val REGISTER_SIZE = 6 + + /** + * Value used to mask a register stored in a word. + */ + val REGISTER_WORD_MASK: Long = (1 << REGISTER_SIZE) - 1 + + /** + * The number of registers which can be stored in one word. + */ + val REGISTERS_PER_WORD = WORD_SIZE / REGISTER_SIZE + + /** + * Number of points used for interpolating the bias value. + */ + val K = 6 + + // Sacrificing style for readability here. + // scalastyle:off + + /** + * Thresholds which decide if the linear counting or the regular algorithm is used. + */ + val THRESHOLDS = Array[Double](10, 20, 40, 80, 220, 400, 900, 1800, 3100, 6500, 15500, 20000, 50000, 120000, 350000) + + /** + * Lookup table used to find the (index of the) bias correction for a given precision (exact) + * and estimate (nearest). + */ + val RAW_ESTIMATE_DATA = Array( + // precision 4 + Array(11, 11.717, 12.207, 12.7896, 13.2882, 13.8204, 14.3772, 14.9342, 15.5202, 16.161, 16.7722, 17.4636, 18.0396, 18.6766, 19.3566, 20.0454, 20.7936, 21.4856, 22.2666, 22.9946, 23.766, 24.4692, 25.3638, 26.0764, 26.7864, 27.7602, 28.4814, 29.433, 30.2926, 31.0664, 31.9996, 32.7956, 33.5366, 34.5894, 35.5738, 36.2698, 37.3682, 38.0544, 39.2342, 40.0108, 40.7966, 41.9298, 42.8704, 43.6358, 44.5194, 45.773, 46.6772, 47.6174, 48.4888, 49.3304, 50.2506, 51.4996, 52.3824, 53.3078, 54.3984, 55.5838, 56.6618, 57.2174, 58.3514, 59.0802, 60.1482, 61.0376, 62.3598, 62.8078, 63.9744, 64.914, 65.781, 67.1806, 68.0594, 68.8446, 69.7928, 70.8248, 71.8324, 72.8598, 73.6246, 74.7014, 75.393, 76.6708, 77.2394), + // precision 5 + Array(23, 23.1194, 23.8208, 24.2318, 24.77, 25.2436, 25.7774, 26.2848, 26.8224, 27.3742, 27.9336, 28.503, 29.0494, 29.6292, 30.2124, 30.798, 31.367, 31.9728, 32.5944, 33.217, 33.8438, 34.3696, 35.0956, 35.7044, 36.324, 37.0668, 37.6698, 38.3644, 39.049, 39.6918, 40.4146, 41.082, 41.687, 42.5398, 43.2462, 43.857, 44.6606, 45.4168, 46.1248, 46.9222, 47.6804, 48.447, 49.3454, 49.9594, 50.7636, 51.5776, 52.331, 53.19, 53.9676, 54.7564, 55.5314, 56.4442, 57.3708, 57.9774, 58.9624, 59.8796, 60.755, 61.472, 62.2076, 63.1024, 63.8908, 64.7338, 65.7728, 66.629, 67.413, 68.3266, 69.1524, 70.2642, 71.1806, 72.0566, 72.9192, 73.7598, 74.3516, 75.5802, 76.4386, 77.4916, 78.1524, 79.1892, 79.8414, 80.8798, 81.8376, 82.4698, 83.7656, 84.331, 85.5914, 86.6012, 87.7016, 88.5582, 89.3394, 90.3544, 91.4912, 92.308, 93.3552, 93.9746, 95.2052, 95.727, 97.1322, 98.3944, 98.7588, 100.242, 101.1914, 102.2538, 102.8776, 103.6292, 105.1932, 105.9152, 107.0868, 107.6728, 108.7144, 110.3114, 110.8716, 111.245, 112.7908, 113.7064, 114.636, 115.7464, 116.1788, 117.7464, 
118.4896, 119.6166, 120.5082, 121.7798, 122.9028, 123.4426, 124.8854, 125.705, 126.4652, 128.3464, 128.3462, 130.0398, 131.0342, 131.0042, 132.4766, 133.511, 134.7252, 135.425, 136.5172, 138.0572, 138.6694, 139.3712, 140.8598, 141.4594, 142.554, 143.4006, 144.7374, 146.1634, 146.8994, 147.605, 147.9304, 149.1636, 150.2468, 151.5876, 152.2096, 153.7032, 154.7146, 155.807, 156.9228, 157.0372, 158.5852), + // precision 6 + Array(46, 46.1902, 47.271, 47.8358, 48.8142, 49.2854, 50.317, 51.354, 51.8924, 52.9436, 53.4596, 54.5262, 55.6248, 56.1574, 57.2822, 57.837, 58.9636, 60.074, 60.7042, 61.7976, 62.4772, 63.6564, 64.7942, 65.5004, 66.686, 67.291, 68.5672, 69.8556, 70.4982, 71.8204, 72.4252, 73.7744, 75.0786, 75.8344, 77.0294, 77.8098, 79.0794, 80.5732, 81.1878, 82.5648, 83.2902, 84.6784, 85.3352, 86.8946, 88.3712, 89.0852, 90.499, 91.2686, 92.6844, 94.2234, 94.9732, 96.3356, 97.2286, 98.7262, 100.3284, 101.1048, 102.5962, 103.3562, 105.1272, 106.4184, 107.4974, 109.0822, 109.856, 111.48, 113.2834, 114.0208, 115.637, 116.5174, 118.0576, 119.7476, 120.427, 122.1326, 123.2372, 125.2788, 126.6776, 127.7926, 129.1952, 129.9564, 131.6454, 133.87, 134.5428, 136.2, 137.0294, 138.6278, 139.6782, 141.792, 143.3516, 144.2832, 146.0394, 147.0748, 148.4912, 150.849, 151.696, 153.5404, 154.073, 156.3714, 157.7216, 158.7328, 160.4208, 161.4184, 163.9424, 165.2772, 166.411, 168.1308, 168.769, 170.9258, 172.6828, 173.7502, 175.706, 176.3886, 179.0186, 180.4518, 181.927, 183.4172, 184.4114, 186.033, 188.5124, 189.5564, 191.6008, 192.4172, 193.8044, 194.997, 197.4548, 198.8948, 200.2346, 202.3086, 203.1548, 204.8842, 206.6508, 206.6772, 209.7254, 210.4752, 212.7228, 214.6614, 215.1676, 217.793, 218.0006, 219.9052, 221.66, 223.5588, 225.1636, 225.6882, 227.7126, 229.4502, 231.1978, 232.9756, 233.1654, 236.727, 238.1974, 237.7474, 241.1346, 242.3048, 244.1948, 245.3134, 246.879, 249.1204, 249.853, 252.6792, 253.857, 254.4486, 257.2362, 257.9534, 260.0286, 260.5632, 262.663, 264.723, 265.7566, 267.2566, 267.1624, 270.62, 272.8216, 273.2166, 275.2056, 276.2202, 278.3726, 280.3344, 281.9284, 283.9728, 284.1924, 286.4872, 287.587, 289.807, 291.1206, 292.769, 294.8708, 296.665, 297.1182, 299.4012, 300.6352, 302.1354, 304.1756, 306.1606, 307.3462, 308.5214, 309.4134, 310.8352, 313.9684, 315.837, 316.7796, 318.9858), + // precision 7 + Array(92, 93.4934, 94.9758, 96.4574, 97.9718, 99.4954, 101.5302, 103.0756, 104.6374, 106.1782, 107.7888, 109.9522, 111.592, 113.2532, 114.9086, 116.5938, 118.9474, 120.6796, 122.4394, 124.2176, 125.9768, 128.4214, 130.2528, 132.0102, 133.8658, 135.7278, 138.3044, 140.1316, 142.093, 144.0032, 145.9092, 148.6306, 150.5294, 152.5756, 154.6508, 156.662, 159.552, 161.3724, 163.617, 165.5754, 167.7872, 169.8444, 172.7988, 174.8606, 177.2118, 179.3566, 181.4476, 184.5882, 186.6816, 189.0824, 191.0258, 193.6048, 196.4436, 198.7274, 200.957, 203.147, 205.4364, 208.7592, 211.3386, 213.781, 215.8028, 218.656, 221.6544, 223.996, 226.4718, 229.1544, 231.6098, 234.5956, 237.0616, 239.5758, 242.4878, 244.5244, 248.2146, 250.724, 252.8722, 255.5198, 258.0414, 261.941, 264.9048, 266.87, 269.4304, 272.028, 274.4708, 278.37, 281.0624, 283.4668, 286.5532, 289.4352, 293.2564, 295.2744, 298.2118, 300.7472, 304.1456, 307.2928, 309.7504, 312.5528, 315.979, 318.2102, 322.1834, 324.3494, 327.325, 330.6614, 332.903, 337.2544, 339.9042, 343.215, 345.2864, 348.0814, 352.6764, 355.301, 357.139, 360.658, 363.1732, 366.5902, 369.9538, 373.0828, 375.922, 378.9902, 382.7328, 386.4538, 388.1136, 391.2234, 394.0878, 
396.708, 401.1556, 404.1852, 406.6372, 409.6822, 412.7796, 416.6078, 418.4916, 422.131, 424.5376, 428.1988, 432.211, 434.4502, 438.5282, 440.912, 444.0448, 447.7432, 450.8524, 453.7988, 456.7858, 458.8868, 463.9886, 466.5064, 468.9124, 472.6616, 475.4682, 478.582, 481.304, 485.2738, 488.6894, 490.329, 496.106, 497.6908, 501.1374, 504.5322, 506.8848, 510.3324, 513.4512, 516.179, 520.4412, 522.6066, 526.167, 528.7794, 533.379, 536.067, 538.46, 542.9116, 545.692, 547.9546, 552.493, 555.2722, 557.335, 562.449, 564.2014, 569.0738, 571.0974, 574.8564, 578.2996, 581.409, 583.9704, 585.8098, 589.6528, 594.5998, 595.958, 600.068, 603.3278, 608.2016, 609.9632, 612.864, 615.43, 620.7794, 621.272, 625.8644, 629.206, 633.219, 634.5154, 638.6102), + // precision 8 + Array(184.2152, 187.2454, 190.2096, 193.6652, 196.6312, 199.6822, 203.249, 206.3296, 210.0038, 213.2074, 216.4612, 220.27, 223.5178, 227.4412, 230.8032, 234.1634, 238.1688, 241.6074, 245.6946, 249.2664, 252.8228, 257.0432, 260.6824, 264.9464, 268.6268, 272.2626, 276.8376, 280.4034, 284.8956, 288.8522, 292.7638, 297.3552, 301.3556, 305.7526, 309.9292, 313.8954, 318.8198, 322.7668, 327.298, 331.6688, 335.9466, 340.9746, 345.1672, 349.3474, 354.3028, 358.8912, 364.114, 368.4646, 372.9744, 378.4092, 382.6022, 387.843, 392.5684, 397.1652, 402.5426, 407.4152, 412.5388, 417.3592, 422.1366, 427.486, 432.3918, 437.5076, 442.509, 447.3834, 453.3498, 458.0668, 463.7346, 469.1228, 473.4528, 479.7, 484.644, 491.0518, 495.5774, 500.9068, 506.432, 512.1666, 517.434, 522.6644, 527.4894, 533.6312, 538.3804, 544.292, 550.5496, 556.0234, 562.8206, 566.6146, 572.4188, 579.117, 583.6762, 590.6576, 595.7864, 601.509, 607.5334, 612.9204, 619.772, 624.2924, 630.8654, 636.1836, 642.745, 649.1316, 655.0386, 660.0136, 666.6342, 671.6196, 678.1866, 684.4282, 689.3324, 695.4794, 702.5038, 708.129, 713.528, 720.3204, 726.463, 732.7928, 739.123, 744.7418, 751.2192, 756.5102, 762.6066, 769.0184, 775.2224, 781.4014, 787.7618, 794.1436, 798.6506, 805.6378, 811.766, 819.7514, 824.5776, 828.7322, 837.8048, 843.6302, 849.9336, 854.4798, 861.3388, 867.9894, 873.8196, 880.3136, 886.2308, 892.4588, 899.0816, 905.4076, 912.0064, 917.3878, 923.619, 929.998, 937.3482, 943.9506, 947.991, 955.1144, 962.203, 968.8222, 975.7324, 981.7826, 988.7666, 994.2648, 1000.3128, 1007.4082, 1013.7536, 1020.3376, 1026.7156, 1031.7478, 1037.4292, 1045.393, 1051.2278, 1058.3434, 1062.8726, 1071.884, 1076.806, 1082.9176, 1089.1678, 1095.5032, 1102.525, 1107.2264, 1115.315, 1120.93, 1127.252, 1134.1496, 1139.0408, 1147.5448, 1153.3296, 1158.1974, 1166.5262, 1174.3328, 1175.657, 1184.4222, 1190.9172, 1197.1292, 1204.4606, 1210.4578, 1218.8728, 1225.3336, 1226.6592, 1236.5768, 1241.363, 1249.4074, 1254.6566, 1260.8014, 1266.5454, 1274.5192), + // precision 9 + Array(369, 374.8294, 381.2452, 387.6698, 394.1464, 400.2024, 406.8782, 413.6598, 420.462, 427.2826, 433.7102, 440.7416, 447.9366, 455.1046, 462.285, 469.0668, 476.306, 483.8448, 491.301, 498.9886, 506.2422, 513.8138, 521.7074, 529.7428, 537.8402, 545.1664, 553.3534, 561.594, 569.6886, 577.7876, 585.65, 594.228, 602.8036, 611.1666, 620.0818, 628.0824, 637.2574, 646.302, 655.1644, 664.0056, 672.3802, 681.7192, 690.5234, 700.2084, 708.831, 718.485, 728.1112, 737.4764, 746.76, 756.3368, 766.5538, 775.5058, 785.2646, 795.5902, 804.3818, 814.8998, 824.9532, 835.2062, 845.2798, 854.4728, 864.9582, 875.3292, 886.171, 896.781, 906.5716, 916.7048, 927.5322, 937.875, 949.3972, 958.3464, 969.7274, 980.2834, 992.1444, 1003.4264, 1013.0166, 1024.018, 1035.0438, 
1046.34, 1057.6856, 1068.9836, 1079.0312, 1091.677, 1102.3188, 1113.4846, 1124.4424, 1135.739, 1147.1488, 1158.9202, 1169.406, 1181.5342, 1193.2834, 1203.8954, 1216.3286, 1226.2146, 1239.6684, 1251.9946, 1262.123, 1275.4338, 1285.7378, 1296.076, 1308.9692, 1320.4964, 1333.0998, 1343.9864, 1357.7754, 1368.3208, 1380.4838, 1392.7388, 1406.0758, 1416.9098, 1428.9728, 1440.9228, 1453.9292, 1462.617, 1476.05, 1490.2996, 1500.6128, 1513.7392, 1524.5174, 1536.6322, 1548.2584, 1562.3766, 1572.423, 1587.1232, 1596.5164, 1610.5938, 1622.5972, 1633.1222, 1647.7674, 1658.5044, 1671.57, 1683.7044, 1695.4142, 1708.7102, 1720.6094, 1732.6522, 1747.841, 1756.4072, 1769.9786, 1782.3276, 1797.5216, 1808.3186, 1819.0694, 1834.354, 1844.575, 1856.2808, 1871.1288, 1880.7852, 1893.9622, 1906.3418, 1920.6548, 1932.9302, 1945.8584, 1955.473, 1968.8248, 1980.6446, 1995.9598, 2008.349, 2019.8556, 2033.0334, 2044.0206, 2059.3956, 2069.9174, 2082.6084, 2093.7036, 2106.6108, 2118.9124, 2132.301, 2144.7628, 2159.8422, 2171.0212, 2183.101, 2193.5112, 2208.052, 2221.3194, 2233.3282, 2247.295, 2257.7222, 2273.342, 2286.5638, 2299.6786, 2310.8114, 2322.3312, 2335.516, 2349.874, 2363.5968, 2373.865, 2387.1918, 2401.8328, 2414.8496, 2424.544, 2436.7592, 2447.1682, 2464.1958, 2474.3438, 2489.0006, 2497.4526, 2513.6586, 2527.19, 2540.7028, 2553.768), + // precision 10 + Array(738.1256, 750.4234, 763.1064, 775.4732, 788.4636, 801.0644, 814.488, 827.9654, 841.0832, 854.7864, 868.1992, 882.2176, 896.5228, 910.1716, 924.7752, 938.899, 953.6126, 968.6492, 982.9474, 998.5214, 1013.1064, 1028.6364, 1044.2468, 1059.4588, 1075.3832, 1091.0584, 1106.8606, 1123.3868, 1139.5062, 1156.1862, 1172.463, 1189.339, 1206.1936, 1223.1292, 1240.1854, 1257.2908, 1275.3324, 1292.8518, 1310.5204, 1328.4854, 1345.9318, 1364.552, 1381.4658, 1400.4256, 1419.849, 1438.152, 1456.8956, 1474.8792, 1494.118, 1513.62, 1532.5132, 1551.9322, 1570.7726, 1590.6086, 1610.5332, 1630.5918, 1650.4294, 1669.7662, 1690.4106, 1710.7338, 1730.9012, 1750.4486, 1770.1556, 1791.6338, 1812.7312, 1833.6264, 1853.9526, 1874.8742, 1896.8326, 1918.1966, 1939.5594, 1961.07, 1983.037, 2003.1804, 2026.071, 2047.4884, 2070.0848, 2091.2944, 2114.333, 2135.9626, 2158.2902, 2181.0814, 2202.0334, 2224.4832, 2246.39, 2269.7202, 2292.1714, 2314.2358, 2338.9346, 2360.891, 2384.0264, 2408.3834, 2430.1544, 2454.8684, 2476.9896, 2501.4368, 2522.8702, 2548.0408, 2570.6738, 2593.5208, 2617.0158, 2640.2302, 2664.0962, 2687.4986, 2714.2588, 2735.3914, 2759.6244, 2781.8378, 2808.0072, 2830.6516, 2856.2454, 2877.2136, 2903.4546, 2926.785, 2951.2294, 2976.468, 3000.867, 3023.6508, 3049.91, 3073.5984, 3098.162, 3121.5564, 3146.2328, 3170.9484, 3195.5902, 3221.3346, 3242.7032, 3271.6112, 3296.5546, 3317.7376, 3345.072, 3369.9518, 3394.326, 3418.1818, 3444.6926, 3469.086, 3494.2754, 3517.8698, 3544.248, 3565.3768, 3588.7234, 3616.979, 3643.7504, 3668.6812, 3695.72, 3719.7392, 3742.6224, 3770.4456, 3795.6602, 3819.9058, 3844.002, 3869.517, 3895.6824, 3920.8622, 3947.1364, 3973.985, 3995.4772, 4021.62, 4046.628, 4074.65, 4096.2256, 4121.831, 4146.6406, 4173.276, 4195.0744, 4223.9696, 4251.3708, 4272.9966, 4300.8046, 4326.302, 4353.1248, 4374.312, 4403.0322, 4426.819, 4450.0598, 4478.5206, 4504.8116, 4528.8928, 4553.9584, 4578.8712, 4603.8384, 4632.3872, 4655.5128, 4675.821, 4704.6222, 4731.9862, 4755.4174, 4781.2628, 4804.332, 4832.3048, 4862.8752, 4883.4148, 4906.9544, 4935.3516, 4954.3532, 4984.0248, 5011.217, 5035.3258, 5057.3672, 5084.1828), + // precision 11 + Array(1477, 1501.6014, 1526.5802, 
1551.7942, 1577.3042, 1603.2062, 1629.8402, 1656.2292, 1682.9462, 1709.9926, 1737.3026, 1765.4252, 1793.0578, 1821.6092, 1849.626, 1878.5568, 1908.527, 1937.5154, 1967.1874, 1997.3878, 2027.37, 2058.1972, 2089.5728, 2120.1012, 2151.9668, 2183.292, 2216.0772, 2247.8578, 2280.6562, 2313.041, 2345.714, 2380.3112, 2414.1806, 2447.9854, 2481.656, 2516.346, 2551.5154, 2586.8378, 2621.7448, 2656.6722, 2693.5722, 2729.1462, 2765.4124, 2802.8728, 2838.898, 2876.408, 2913.4926, 2951.4938, 2989.6776, 3026.282, 3065.7704, 3104.1012, 3143.7388, 3181.6876, 3221.1872, 3261.5048, 3300.0214, 3339.806, 3381.409, 3421.4144, 3461.4294, 3502.2286, 3544.651, 3586.6156, 3627.337, 3670.083, 3711.1538, 3753.5094, 3797.01, 3838.6686, 3882.1678, 3922.8116, 3967.9978, 4009.9204, 4054.3286, 4097.5706, 4140.6014, 4185.544, 4229.5976, 4274.583, 4316.9438, 4361.672, 4406.2786, 4451.8628, 4496.1834, 4543.505, 4589.1816, 4632.5188, 4678.2294, 4724.8908, 4769.0194, 4817.052, 4861.4588, 4910.1596, 4956.4344, 5002.5238, 5048.13, 5093.6374, 5142.8162, 5187.7894, 5237.3984, 5285.6078, 5331.0858, 5379.1036, 5428.6258, 5474.6018, 5522.7618, 5571.5822, 5618.59, 5667.9992, 5714.88, 5763.454, 5808.6982, 5860.3644, 5910.2914, 5953.571, 6005.9232, 6055.1914, 6104.5882, 6154.5702, 6199.7036, 6251.1764, 6298.7596, 6350.0302, 6398.061, 6448.4694, 6495.933, 6548.0474, 6597.7166, 6646.9416, 6695.9208, 6742.6328, 6793.5276, 6842.1934, 6894.2372, 6945.3864, 6996.9228, 7044.2372, 7094.1374, 7142.2272, 7192.2942, 7238.8338, 7288.9006, 7344.0908, 7394.8544, 7443.5176, 7490.4148, 7542.9314, 7595.6738, 7641.9878, 7694.3688, 7743.0448, 7797.522, 7845.53, 7899.594, 7950.3132, 7996.455, 8050.9442, 8092.9114, 8153.1374, 8197.4472, 8252.8278, 8301.8728, 8348.6776, 8401.4698, 8453.551, 8504.6598, 8553.8944, 8604.1276, 8657.6514, 8710.3062, 8758.908, 8807.8706, 8862.1702, 8910.4668, 8960.77, 9007.2766, 9063.164, 9121.0534, 9164.1354, 9218.1594, 9267.767, 9319.0594, 9372.155, 9419.7126, 9474.3722, 9520.1338, 9572.368, 9622.7702, 9675.8448, 9726.5396, 9778.7378, 9827.6554, 9878.1922, 9928.7782, 9978.3984, 10026.578, 10076.5626, 10137.1618, 10177.5244, 10229.9176), + // precision 12 + Array(2954, 3003.4782, 3053.3568, 3104.3666, 3155.324, 3206.9598, 3259.648, 3312.539, 3366.1474, 3420.2576, 3474.8376, 3530.6076, 3586.451, 3643.38, 3700.4104, 3757.5638, 3815.9676, 3875.193, 3934.838, 3994.8548, 4055.018, 4117.1742, 4178.4482, 4241.1294, 4304.4776, 4367.4044, 4431.8724, 4496.3732, 4561.4304, 4627.5326, 4693.949, 4761.5532, 4828.7256, 4897.6182, 4965.5186, 5034.4528, 5104.865, 5174.7164, 5244.6828, 5316.6708, 5387.8312, 5459.9036, 5532.476, 5604.8652, 5679.6718, 5753.757, 5830.2072, 5905.2828, 5980.0434, 6056.6264, 6134.3192, 6211.5746, 6290.0816, 6367.1176, 6447.9796, 6526.5576, 6606.1858, 6686.9144, 6766.1142, 6847.0818, 6927.9664, 7010.9096, 7091.0816, 7175.3962, 7260.3454, 7344.018, 7426.4214, 7511.3106, 7596.0686, 7679.8094, 7765.818, 7852.4248, 7936.834, 8022.363, 8109.5066, 8200.4554, 8288.5832, 8373.366, 8463.4808, 8549.7682, 8642.0522, 8728.3288, 8820.9528, 8907.727, 9001.0794, 9091.2522, 9179.988, 9269.852, 9362.6394, 9453.642, 9546.9024, 9640.6616, 9732.6622, 9824.3254, 9917.7484, 10007.9392, 10106.7508, 10196.2152, 10289.8114, 10383.5494, 10482.3064, 10576.8734, 10668.7872, 10764.7156, 10862.0196, 10952.793, 11049.9748, 11146.0702, 11241.4492, 11339.2772, 11434.2336, 11530.741, 11627.6136, 11726.311, 11821.5964, 11918.837, 12015.3724, 12113.0162, 12213.0424, 12306.9804, 12408.4518, 12504.8968, 12604.586, 12700.9332, 12798.705, 12898.5142, 
12997.0488, 13094.788, 13198.475, 13292.7764, 13392.9698, 13486.8574, 13590.1616, 13686.5838, 13783.6264, 13887.2638, 13992.0978, 14081.0844, 14189.9956, 14280.0912, 14382.4956, 14486.4384, 14588.1082, 14686.2392, 14782.276, 14888.0284, 14985.1864, 15088.8596, 15187.0998, 15285.027, 15383.6694, 15495.8266, 15591.3736, 15694.2008, 15790.3246, 15898.4116, 15997.4522, 16095.5014, 16198.8514, 16291.7492, 16402.6424, 16499.1266, 16606.2436, 16697.7186, 16796.3946, 16902.3376, 17005.7672, 17100.814, 17206.8282, 17305.8262, 17416.0744, 17508.4092, 17617.0178, 17715.4554, 17816.758, 17920.1748, 18012.9236, 18119.7984, 18223.2248, 18324.2482, 18426.6276, 18525.0932, 18629.8976, 18733.2588, 18831.0466, 18940.1366, 19032.2696, 19131.729, 19243.4864, 19349.6932, 19442.866, 19547.9448, 19653.2798, 19754.4034, 19854.0692, 19965.1224, 20065.1774, 20158.2212, 20253.353, 20366.3264, 20463.22), + // precision 13 + Array(5908.5052, 6007.2672, 6107.347, 6208.5794, 6311.2622, 6414.5514, 6519.3376, 6625.6952, 6732.5988, 6841.3552, 6950.5972, 7061.3082, 7173.5646, 7287.109, 7401.8216, 7516.4344, 7633.3802, 7751.2962, 7870.3784, 7990.292, 8110.79, 8233.4574, 8356.6036, 8482.2712, 8607.7708, 8735.099, 8863.1858, 8993.4746, 9123.8496, 9255.6794, 9388.5448, 9522.7516, 9657.3106, 9792.6094, 9930.5642, 10068.794, 10206.7256, 10347.81, 10490.3196, 10632.0778, 10775.9916, 10920.4662, 11066.124, 11213.073, 11358.0362, 11508.1006, 11659.1716, 11808.7514, 11959.4884, 12112.1314, 12265.037, 12420.3756, 12578.933, 12734.311, 12890.0006, 13047.2144, 13207.3096, 13368.5144, 13528.024, 13689.847, 13852.7528, 14018.3168, 14180.5372, 14346.9668, 14513.5074, 14677.867, 14846.2186, 15017.4186, 15184.9716, 15356.339, 15529.2972, 15697.3578, 15871.8686, 16042.187, 16216.4094, 16389.4188, 16565.9126, 16742.3272, 16919.0042, 17094.7592, 17273.965, 17451.8342, 17634.4254, 17810.5984, 17988.9242, 18171.051, 18354.7938, 18539.466, 18721.0408, 18904.9972, 19081.867, 19271.9118, 19451.8694, 19637.9816, 19821.2922, 20013.1292, 20199.3858, 20387.8726, 20572.9514, 20770.7764, 20955.1714, 21144.751, 21329.9952, 21520.709, 21712.7016, 21906.3868, 22096.2626, 22286.0524, 22475.051, 22665.5098, 22862.8492, 23055.5294, 23249.6138, 23437.848, 23636.273, 23826.093, 24020.3296, 24213.3896, 24411.7392, 24602.9614, 24805.7952, 24998.1552, 25193.9588, 25389.0166, 25585.8392, 25780.6976, 25981.2728, 26175.977, 26376.5252, 26570.1964, 26773.387, 26962.9812, 27163.0586, 27368.164, 27565.0534, 27758.7428, 27961.1276, 28163.2324, 28362.3816, 28565.7668, 28758.644, 28956.9768, 29163.4722, 29354.7026, 29561.1186, 29767.9948, 29959.9986, 30164.0492, 30366.9818, 30562.5338, 30762.9928, 30976.1592, 31166.274, 31376.722, 31570.3734, 31770.809, 31974.8934, 32179.5286, 32387.5442, 32582.3504, 32794.076, 32989.9528, 33191.842, 33392.4684, 33595.659, 33801.8672, 34000.3414, 34200.0922, 34402.6792, 34610.0638, 34804.0084, 35011.13, 35218.669, 35418.6634, 35619.0792, 35830.6534, 36028.4966, 36229.7902, 36438.6422, 36630.7764, 36833.3102, 37048.6728, 37247.3916, 37453.5904, 37669.3614, 37854.5526, 38059.305, 38268.0936, 38470.2516, 38674.7064, 38876.167, 39068.3794, 39281.9144, 39492.8566, 39684.8628, 39898.4108, 40093.1836, 40297.6858, 40489.7086, 40717.2424), + // precision 14 + Array(11817.475, 12015.0046, 12215.3792, 12417.7504, 12623.1814, 12830.0086, 13040.0072, 13252.503, 13466.178, 13683.2738, 13902.0344, 14123.9798, 14347.394, 14573.7784, 14802.6894, 15033.6824, 15266.9134, 15502.8624, 15741.4944, 15980.7956, 16223.8916, 16468.6316, 16715.733, 16965.5726, 
17217.204, 17470.666, 17727.8516, 17986.7886, 18247.6902, 18510.9632, 18775.304, 19044.7486, 19314.4408, 19587.202, 19862.2576, 20135.924, 20417.0324, 20697.9788, 20979.6112, 21265.0274, 21550.723, 21841.6906, 22132.162, 22428.1406, 22722.127, 23020.5606, 23319.7394, 23620.4014, 23925.2728, 24226.9224, 24535.581, 24845.505, 25155.9618, 25470.3828, 25785.9702, 26103.7764, 26420.4132, 26742.0186, 27062.8852, 27388.415, 27714.6024, 28042.296, 28365.4494, 28701.1526, 29031.8008, 29364.2156, 29704.497, 30037.1458, 30380.111, 30723.8168, 31059.5114, 31404.9498, 31751.6752, 32095.2686, 32444.7792, 32794.767, 33145.204, 33498.4226, 33847.6502, 34209.006, 34560.849, 34919.4838, 35274.9778, 35635.1322, 35996.3266, 36359.1394, 36722.8266, 37082.8516, 37447.7354, 37815.9606, 38191.0692, 38559.4106, 38924.8112, 39294.6726, 39663.973, 40042.261, 40416.2036, 40779.2036, 41161.6436, 41540.9014, 41921.1998, 42294.7698, 42678.5264, 43061.3464, 43432.375, 43818.432, 44198.6598, 44583.0138, 44970.4794, 45353.924, 45729.858, 46118.2224, 46511.5724, 46900.7386, 47280.6964, 47668.1472, 48055.6796, 48446.9436, 48838.7146, 49217.7296, 49613.7796, 50010.7508, 50410.0208, 50793.7886, 51190.2456, 51583.1882, 51971.0796, 52376.5338, 52763.319, 53165.5534, 53556.5594, 53948.2702, 54346.352, 54748.7914, 55138.577, 55543.4824, 55941.1748, 56333.7746, 56745.1552, 57142.7944, 57545.2236, 57935.9956, 58348.5268, 58737.5474, 59158.5962, 59542.6896, 59958.8004, 60349.3788, 60755.0212, 61147.6144, 61548.194, 61946.0696, 62348.6042, 62763.603, 63162.781, 63560.635, 63974.3482, 64366.4908, 64771.5876, 65176.7346, 65597.3916, 65995.915, 66394.0384, 66822.9396, 67203.6336, 67612.2032, 68019.0078, 68420.0388, 68821.22, 69235.8388, 69640.0724, 70055.155, 70466.357, 70863.4266, 71276.2482, 71677.0306, 72080.2006, 72493.0214, 72893.5952, 73314.5856, 73714.9852, 74125.3022, 74521.2122, 74933.6814, 75341.5904, 75743.0244, 76166.0278, 76572.1322, 76973.1028, 77381.6284, 77800.6092, 78189.328, 78607.0962, 79012.2508, 79407.8358, 79825.725, 80238.701, 80646.891, 81035.6436, 81460.0448, 81876.3884), + // precision 15 + Array(23635.0036, 24030.8034, 24431.4744, 24837.1524, 25246.7928, 25661.326, 26081.3532, 26505.2806, 26933.9892, 27367.7098, 27805.318, 28248.799, 28696.4382, 29148.8244, 29605.5138, 30066.8668, 30534.2344, 31006.32, 31480.778, 31962.2418, 32447.3324, 32938.0232, 33432.731, 33930.728, 34433.9896, 34944.1402, 35457.5588, 35974.5958, 36497.3296, 37021.9096, 37554.326, 38088.0826, 38628.8816, 39171.3192, 39723.2326, 40274.5554, 40832.3142, 41390.613, 41959.5908, 42532.5466, 43102.0344, 43683.5072, 44266.694, 44851.2822, 45440.7862, 46038.0586, 46640.3164, 47241.064, 47846.155, 48454.7396, 49076.9168, 49692.542, 50317.4778, 50939.65, 51572.5596, 52210.2906, 52843.7396, 53481.3996, 54127.236, 54770.406, 55422.6598, 56078.7958, 56736.7174, 57397.6784, 58064.5784, 58730.308, 59404.9784, 60077.0864, 60751.9158, 61444.1386, 62115.817, 62808.7742, 63501.4774, 64187.5454, 64883.6622, 65582.7468, 66274.5318, 66976.9276, 67688.7764, 68402.138, 69109.6274, 69822.9706, 70543.6108, 71265.5202, 71983.3848, 72708.4656, 73433.384, 74158.4664, 74896.4868, 75620.9564, 76362.1434, 77098.3204, 77835.7662, 78582.6114, 79323.9902, 80067.8658, 80814.9246, 81567.0136, 82310.8536, 83061.9952, 83821.4096, 84580.8608, 85335.547, 86092.5802, 86851.6506, 87612.311, 88381.2016, 89146.3296, 89907.8974, 90676.846, 91451.4152, 92224.5518, 92995.8686, 93763.5066, 94551.2796, 95315.1944, 96096.1806, 96881.0918, 97665.679, 98442.68, 99229.3002, 100011.0994, 
100790.6386, 101580.1564, 102377.7484, 103152.1392, 103944.2712, 104730.216, 105528.6336, 106324.9398, 107117.6706, 107890.3988, 108695.2266, 109485.238, 110294.7876, 111075.0958, 111878.0496, 112695.2864, 113464.5486, 114270.0474, 115068.608, 115884.3626, 116673.2588, 117483.3716, 118275.097, 119085.4092, 119879.2808, 120687.5868, 121499.9944, 122284.916, 123095.9254, 123912.5038, 124709.0454, 125503.7182, 126323.259, 127138.9412, 127943.8294, 128755.646, 129556.5354, 130375.3298, 131161.4734, 131971.1962, 132787.5458, 133588.1056, 134431.351, 135220.2906, 136023.398, 136846.6558, 137667.0004, 138463.663, 139283.7154, 140074.6146, 140901.3072, 141721.8548, 142543.2322, 143356.1096, 144173.7412, 144973.0948, 145794.3162, 146609.5714, 147420.003, 148237.9784, 149050.5696, 149854.761, 150663.1966, 151494.0754, 152313.1416, 153112.6902, 153935.7206, 154746.9262, 155559.547, 156401.9746, 157228.7036, 158008.7254, 158820.75, 159646.9184, 160470.4458, 161279.5348, 162093.3114, 162918.542, 163729.2842), + // precision 16 + Array(47271, 48062.3584, 48862.7074, 49673.152, 50492.8416, 51322.9514, 52161.03, 53009.407, 53867.6348, 54734.206, 55610.5144, 56496.2096, 57390.795, 58297.268, 59210.6448, 60134.665, 61068.0248, 62010.4472, 62962.5204, 63923.5742, 64895.0194, 65876.4182, 66862.6136, 67862.6968, 68868.8908, 69882.8544, 70911.271, 71944.0924, 72990.0326, 74040.692, 75100.6336, 76174.7826, 77252.5998, 78340.2974, 79438.2572, 80545.4976, 81657.2796, 82784.6336, 83915.515, 85059.7362, 86205.9368, 87364.4424, 88530.3358, 89707.3744, 90885.9638, 92080.197, 93275.5738, 94479.391, 95695.918, 96919.2236, 98148.4602, 99382.3474, 100625.6974, 101878.0284, 103141.6278, 104409.4588, 105686.2882, 106967.5402, 108261.6032, 109548.1578, 110852.0728, 112162.231, 113479.0072, 114806.2626, 116137.9072, 117469.5048, 118813.5186, 120165.4876, 121516.2556, 122875.766, 124250.5444, 125621.2222, 127003.2352, 128387.848, 129775.2644, 131181.7776, 132577.3086, 133979.9458, 135394.1132, 136800.9078, 138233.217, 139668.5308, 141085.212, 142535.2122, 143969.0684, 145420.2872, 146878.1542, 148332.7572, 149800.3202, 151269.66, 152743.6104, 154213.0948, 155690.288, 157169.4246, 158672.1756, 160160.059, 161650.6854, 163145.7772, 164645.6726, 166159.1952, 167682.1578, 169177.3328, 170700.0118, 172228.8964, 173732.6664, 175265.5556, 176787.799, 178317.111, 179856.6914, 181400.865, 182943.4612, 184486.742, 186033.4698, 187583.7886, 189148.1868, 190688.4526, 192250.1926, 193810.9042, 195354.2972, 196938.7682, 198493.5898, 200079.2824, 201618.912, 203205.5492, 204765.5798, 206356.1124, 207929.3064, 209498.7196, 211086.229, 212675.1324, 214256.7892, 215826.2392, 217412.8474, 218995.6724, 220618.6038, 222207.1166, 223781.0364, 225387.4332, 227005.7928, 228590.4336, 230217.8738, 231805.1054, 233408.9, 234995.3432, 236601.4956, 238190.7904, 239817.2548, 241411.2832, 243002.4066, 244640.1884, 246255.3128, 247849.3508, 249479.9734, 251106.8822, 252705.027, 254332.9242, 255935.129, 257526.9014, 259154.772, 260777.625, 262390.253, 264004.4906, 265643.59, 267255.4076, 268873.426, 270470.7252, 272106.4804, 273722.4456, 275337.794, 276945.7038, 278592.9154, 280204.3726, 281841.1606, 283489.171, 285130.1716, 286735.3362, 288364.7164, 289961.1814, 291595.5524, 293285.683, 294899.6668, 296499.3434, 298128.0462, 299761.8946, 301394.2424, 302997.6748, 304615.1478, 306269.7724, 307886.114, 309543.1028, 311153.2862, 312782.8546, 314421.2008, 316033.2438, 317692.9636, 319305.2648, 320948.7406, 322566.3364, 324228.4224, 325847.1542), + // precision 
17 + Array(94542, 96125.811, 97728.019, 99348.558, 100987.9705, 102646.7565, 104324.5125, 106021.7435, 107736.7865, 109469.272, 111223.9465, 112995.219, 114787.432, 116593.152, 118422.71, 120267.2345, 122134.6765, 124020.937, 125927.2705, 127851.255, 129788.9485, 131751.016, 133726.8225, 135722.592, 137736.789, 139770.568, 141821.518, 143891.343, 145982.1415, 148095.387, 150207.526, 152355.649, 154515.6415, 156696.05, 158887.7575, 161098.159, 163329.852, 165569.053, 167837.4005, 170121.6165, 172420.4595, 174732.6265, 177062.77, 179412.502, 181774.035, 184151.939, 186551.6895, 188965.691, 191402.8095, 193857.949, 196305.0775, 198774.6715, 201271.2585, 203764.78, 206299.3695, 208818.1365, 211373.115, 213946.7465, 216532.076, 219105.541, 221714.5375, 224337.5135, 226977.5125, 229613.0655, 232270.2685, 234952.2065, 237645.3555, 240331.1925, 243034.517, 245756.0725, 248517.6865, 251232.737, 254011.3955, 256785.995, 259556.44, 262368.335, 265156.911, 267965.266, 270785.583, 273616.0495, 276487.4835, 279346.639, 282202.509, 285074.3885, 287942.2855, 290856.018, 293774.0345, 296678.5145, 299603.6355, 302552.6575, 305492.9785, 308466.8605, 311392.581, 314347.538, 317319.4295, 320285.9785, 323301.7325, 326298.3235, 329301.3105, 332301.987, 335309.791, 338370.762, 341382.923, 344431.1265, 347464.1545, 350507.28, 353619.2345, 356631.2005, 359685.203, 362776.7845, 365886.488, 368958.2255, 372060.6825, 375165.4335, 378237.935, 381328.311, 384430.5225, 387576.425, 390683.242, 393839.648, 396977.8425, 400101.9805, 403271.296, 406409.8425, 409529.5485, 412678.7, 415847.423, 419020.8035, 422157.081, 425337.749, 428479.6165, 431700.902, 434893.1915, 438049.582, 441210.5415, 444379.2545, 447577.356, 450741.931, 453959.548, 457137.0935, 460329.846, 463537.4815, 466732.3345, 469960.5615, 473164.681, 476347.6345, 479496.173, 482813.1645, 486025.6995, 489249.4885, 492460.1945, 495675.8805, 498908.0075, 502131.802, 505374.3855, 508550.9915, 511806.7305, 515026.776, 518217.0005, 521523.9855, 524705.9855, 527950.997, 531210.0265, 534472.497, 537750.7315, 540926.922, 544207.094, 547429.4345, 550666.3745, 553975.3475, 557150.7185, 560399.6165, 563662.697, 566916.7395, 570146.1215, 573447.425, 576689.6245, 579874.5745, 583202.337, 586503.0255, 589715.635, 592910.161, 596214.3885, 599488.035, 602740.92, 605983.0685, 609248.67, 612491.3605, 615787.912, 619107.5245, 622307.9555, 625577.333, 628840.4385, 632085.2155, 635317.6135, 638691.7195, 641887.467, 645139.9405, 648441.546, 651666.252, 654941.845), + // precision 18 + Array(189084, 192250.913, 195456.774, 198696.946, 201977.762, 205294.444, 208651.754, 212042.099, 215472.269, 218941.91, 222443.912, 225996.845, 229568.199, 233193.568, 236844.457, 240543.233, 244279.475, 248044.27, 251854.588, 255693.2, 259583.619, 263494.621, 267445.385, 271454.061, 275468.769, 279549.456, 283646.446, 287788.198, 291966.099, 296181.164, 300431.469, 304718.618, 309024.004, 313393.508, 317760.803, 322209.731, 326675.061, 331160.627, 335654.47, 340241.442, 344841.833, 349467.132, 354130.629, 358819.432, 363574.626, 368296.587, 373118.482, 377914.93, 382782.301, 387680.669, 392601.981, 397544.323, 402529.115, 407546.018, 412593.658, 417638.657, 422762.865, 427886.169, 433017.167, 438213.273, 443441.254, 448692.421, 453937.533, 459239.049, 464529.569, 469910.083, 475274.03, 480684.473, 486070.26, 491515.237, 496995.651, 502476.617, 507973.609, 513497.19, 519083.233, 524726.509, 530305.505, 535945.728, 541584.404, 547274.055, 552967.236, 558667.862, 564360.216, 570128.148, 575965.08, 
581701.952, 587532.523, 593361.144, 599246.128, 605033.418, 610958.779, 616837.117, 622772.818, 628672.04, 634675.369, 640574.831, 646585.739, 652574.547, 658611.217, 664642.684, 670713.914, 676737.681, 682797.313, 688837.897, 694917.874, 701009.882, 707173.648, 713257.254, 719415.392, 725636.761, 731710.697, 737906.209, 744103.074, 750313.39, 756504.185, 762712.579, 768876.985, 775167.859, 781359, 787615.959, 793863.597, 800245.477, 806464.582, 812785.294, 819005.925, 825403.057, 831676.197, 837936.284, 844266.968, 850642.711, 856959.756, 863322.774, 869699.931, 876102.478, 882355.787, 888694.463, 895159.952, 901536.143, 907872.631, 914293.672, 920615.14, 927130.974, 933409.404, 939922.178, 946331.47, 952745.93, 959209.264, 965590.224, 972077.284, 978501.961, 984953.19, 991413.271, 997817.479, 1004222.658, 1010725.676, 1017177.138, 1023612.529, 1030098.236, 1036493.719, 1043112.207, 1049537.036, 1056008.096, 1062476.184, 1068942.337, 1075524.95, 1081932.864, 1088426.025, 1094776.005, 1101327.448, 1107901.673, 1114423.639, 1120884.602, 1127324.923, 1133794.24, 1140328.886, 1146849.376, 1153346.682, 1159836.502, 1166478.703, 1172953.304, 1179391.502, 1185950.982, 1192544.052, 1198913.41, 1205430.994, 1212015.525, 1218674.042, 1225121.683, 1231551.101, 1238126.379, 1244673.795, 1251260.649, 1257697.86, 1264320.983, 1270736.319, 1277274.694, 1283804.95, 1290211.514, 1296858.568, 1303455.691) + ) + + /** + * Bias corrections given a precision and the index of the raw estimate table. + */ + val BIAS_DATA = Array( + // precision 4 + Array(10, 9.717, 9.207, 8.7896, 8.2882, 7.8204, 7.3772, 6.9342, 6.5202, 6.161, 5.7722, 5.4636, 5.0396, 4.6766, 4.3566, 4.0454, 3.7936, 3.4856, 3.2666, 2.9946, 2.766, 2.4692, 2.3638, 2.0764, 1.7864, 1.7602, 1.4814, 1.433, 1.2926, 1.0664, 0.999600000000001, 0.7956, 0.5366, 0.589399999999998, 0.573799999999999, 0.269799999999996, 0.368200000000002, 0.0544000000000011, 0.234200000000001, 0.0108000000000033, -0.203400000000002, -0.0701999999999998, -0.129600000000003, -0.364199999999997, -0.480600000000003, -0.226999999999997, -0.322800000000001, -0.382599999999996, -0.511200000000002, -0.669600000000003, -0.749400000000001, -0.500399999999999, -0.617600000000003, -0.6922, -0.601599999999998, -0.416200000000003, -0.338200000000001, -0.782600000000002, -0.648600000000002, -0.919800000000002, -0.851799999999997, -0.962400000000002, -0.6402, -1.1922, -1.0256, -1.086, -1.21899999999999, -0.819400000000002, -0.940600000000003, -1.1554, -1.2072, -1.1752, -1.16759999999999, -1.14019999999999, -1.3754, -1.29859999999999, -1.607, -1.3292, -1.7606), + // precision 5 + Array(22, 21.1194, 20.8208, 20.2318, 19.77, 19.2436, 18.7774, 18.2848, 17.8224, 17.3742, 16.9336, 16.503, 16.0494, 15.6292, 15.2124, 14.798, 14.367, 13.9728, 13.5944, 13.217, 12.8438, 12.3696, 12.0956, 11.7044, 11.324, 11.0668, 10.6698, 10.3644, 10.049, 9.6918, 9.4146, 9.082, 8.687, 8.5398, 8.2462, 7.857, 7.6606, 7.4168, 7.1248, 6.9222, 6.6804, 6.447, 6.3454, 5.9594, 5.7636, 5.5776, 5.331, 5.19, 4.9676, 4.7564, 4.5314, 4.4442, 4.3708, 3.9774, 3.9624, 3.8796, 3.755, 3.472, 3.2076, 3.1024, 2.8908, 2.7338, 2.7728, 2.629, 2.413, 2.3266, 2.1524, 2.2642, 2.1806, 2.0566, 1.9192, 1.7598, 1.3516, 1.5802, 1.43859999999999, 1.49160000000001, 1.1524, 1.1892, 0.841399999999993, 0.879800000000003, 0.837599999999995, 0.469800000000006, 0.765600000000006, 0.331000000000003, 0.591399999999993, 0.601200000000006, 0.701599999999999, 0.558199999999999, 0.339399999999998, 0.354399999999998, 0.491200000000006, 0.308000000000007, 
0.355199999999996, -0.0254000000000048, 0.205200000000005, -0.272999999999996, 0.132199999999997, 0.394400000000005, -0.241200000000006, 0.242000000000004, 0.191400000000002, 0.253799999999998, -0.122399999999999, -0.370800000000003, 0.193200000000004, -0.0848000000000013, 0.0867999999999967, -0.327200000000005, -0.285600000000002, 0.311400000000006, -0.128399999999999, -0.754999999999995, -0.209199999999996, -0.293599999999998, -0.364000000000004, -0.253600000000006, -0.821200000000005, -0.253600000000006, -0.510400000000004, -0.383399999999995, -0.491799999999998, -0.220200000000006, -0.0972000000000008, -0.557400000000001, -0.114599999999996, -0.295000000000002, -0.534800000000004, 0.346399999999988, -0.65379999999999, 0.0398000000000138, 0.0341999999999985, -0.995800000000003, -0.523400000000009, -0.489000000000004, -0.274799999999999, -0.574999999999989, -0.482799999999997, 0.0571999999999946, -0.330600000000004, -0.628800000000012, -0.140199999999993, -0.540600000000012, -0.445999999999998, -0.599400000000003, -0.262599999999992, 0.163399999999996, -0.100599999999986, -0.39500000000001, -1.06960000000001, -0.836399999999998, -0.753199999999993, -0.412399999999991, -0.790400000000005, -0.29679999999999, -0.28540000000001, -0.193000000000012, -0.0772000000000048, -0.962799999999987, -0.414800000000014), + // precision 6 + Array(45, 44.1902, 43.271, 42.8358, 41.8142, 41.2854, 40.317, 39.354, 38.8924, 37.9436, 37.4596, 36.5262, 35.6248, 35.1574, 34.2822, 33.837, 32.9636, 32.074, 31.7042, 30.7976, 30.4772, 29.6564, 28.7942, 28.5004, 27.686, 27.291, 26.5672, 25.8556, 25.4982, 24.8204, 24.4252, 23.7744, 23.0786, 22.8344, 22.0294, 21.8098, 21.0794, 20.5732, 20.1878, 19.5648, 19.2902, 18.6784, 18.3352, 17.8946, 17.3712, 17.0852, 16.499, 16.2686, 15.6844, 15.2234, 14.9732, 14.3356, 14.2286, 13.7262, 13.3284, 13.1048, 12.5962, 12.3562, 12.1272, 11.4184, 11.4974, 11.0822, 10.856, 10.48, 10.2834, 10.0208, 9.637, 9.51739999999999, 9.05759999999999, 8.74760000000001, 8.42700000000001, 8.1326, 8.2372, 8.2788, 7.6776, 7.79259999999999, 7.1952, 6.9564, 6.6454, 6.87, 6.5428, 6.19999999999999, 6.02940000000001, 5.62780000000001, 5.6782, 5.792, 5.35159999999999, 5.28319999999999, 5.0394, 5.07480000000001, 4.49119999999999, 4.84899999999999, 4.696, 4.54040000000001, 4.07300000000001, 4.37139999999999, 3.7216, 3.7328, 3.42080000000001, 3.41839999999999, 3.94239999999999, 3.27719999999999, 3.411, 3.13079999999999, 2.76900000000001, 2.92580000000001, 2.68279999999999, 2.75020000000001, 2.70599999999999, 2.3886, 3.01859999999999, 2.45179999999999, 2.92699999999999, 2.41720000000001, 2.41139999999999, 2.03299999999999, 2.51240000000001, 2.5564, 2.60079999999999, 2.41720000000001, 1.80439999999999, 1.99700000000001, 2.45480000000001, 1.8948, 2.2346, 2.30860000000001, 2.15479999999999, 1.88419999999999, 1.6508, 0.677199999999999, 1.72540000000001, 1.4752, 1.72280000000001, 1.66139999999999, 1.16759999999999, 1.79300000000001, 1.00059999999999, 0.905200000000008, 0.659999999999997, 1.55879999999999, 1.1636, 0.688199999999995, 0.712600000000009, 0.450199999999995, 1.1978, 0.975599999999986, 0.165400000000005, 1.727, 1.19739999999999, -0.252600000000001, 1.13460000000001, 1.3048, 1.19479999999999, 0.313400000000001, 0.878999999999991, 1.12039999999999, 0.853000000000009, 1.67920000000001, 0.856999999999999, 0.448599999999999, 1.2362, 0.953399999999988, 1.02859999999998, 0.563199999999995, 0.663000000000011, 0.723000000000013, 0.756599999999992, 0.256599999999992, -0.837600000000009, 0.620000000000005, 
0.821599999999989, 0.216600000000028, 0.205600000000004, 0.220199999999977, 0.372599999999977, 0.334400000000016, 0.928400000000011, 0.972800000000007, 0.192400000000021, 0.487199999999973, -0.413000000000011, 0.807000000000016, 0.120600000000024, 0.769000000000005, 0.870799999999974, 0.66500000000002, 0.118200000000002, 0.401200000000017, 0.635199999999998, 0.135400000000004, 0.175599999999974, 1.16059999999999, 0.34620000000001, 0.521400000000028, -0.586599999999976, -1.16480000000001, 0.968399999999974, 0.836999999999989, 0.779600000000016, 0.985799999999983), + // precision 7 + Array(91, 89.4934, 87.9758, 86.4574, 84.9718, 83.4954, 81.5302, 80.0756, 78.6374, 77.1782, 75.7888, 73.9522, 72.592, 71.2532, 69.9086, 68.5938, 66.9474, 65.6796, 64.4394, 63.2176, 61.9768, 60.4214, 59.2528, 58.0102, 56.8658, 55.7278, 54.3044, 53.1316, 52.093, 51.0032, 49.9092, 48.6306, 47.5294, 46.5756, 45.6508, 44.662, 43.552, 42.3724, 41.617, 40.5754, 39.7872, 38.8444, 37.7988, 36.8606, 36.2118, 35.3566, 34.4476, 33.5882, 32.6816, 32.0824, 31.0258, 30.6048, 29.4436, 28.7274, 27.957, 27.147, 26.4364, 25.7592, 25.3386, 24.781, 23.8028, 23.656, 22.6544, 21.996, 21.4718, 21.1544, 20.6098, 19.5956, 19.0616, 18.5758, 18.4878, 17.5244, 17.2146, 16.724, 15.8722, 15.5198, 15.0414, 14.941, 14.9048, 13.87, 13.4304, 13.028, 12.4708, 12.37, 12.0624, 11.4668, 11.5532, 11.4352, 11.2564, 10.2744, 10.2118, 9.74720000000002, 10.1456, 9.2928, 8.75040000000001, 8.55279999999999, 8.97899999999998, 8.21019999999999, 8.18340000000001, 7.3494, 7.32499999999999, 7.66140000000001, 6.90300000000002, 7.25439999999998, 6.9042, 7.21499999999997, 6.28640000000001, 6.08139999999997, 6.6764, 6.30099999999999, 5.13900000000001, 5.65800000000002, 5.17320000000001, 4.59019999999998, 4.9538, 5.08280000000002, 4.92200000000003, 4.99020000000002, 4.7328, 5.4538, 4.11360000000002, 4.22340000000003, 4.08780000000002, 3.70800000000003, 4.15559999999999, 4.18520000000001, 3.63720000000001, 3.68220000000002, 3.77960000000002, 3.6078, 2.49160000000001, 3.13099999999997, 2.5376, 3.19880000000001, 3.21100000000001, 2.4502, 3.52820000000003, 2.91199999999998, 3.04480000000001, 2.7432, 2.85239999999999, 2.79880000000003, 2.78579999999999, 1.88679999999999, 2.98860000000002, 2.50639999999999, 1.91239999999999, 2.66160000000002, 2.46820000000002, 1.58199999999999, 1.30399999999997, 2.27379999999999, 2.68939999999998, 1.32900000000001, 3.10599999999999, 1.69080000000002, 2.13740000000001, 2.53219999999999, 1.88479999999998, 1.33240000000001, 1.45119999999997, 1.17899999999997, 2.44119999999998, 1.60659999999996, 2.16700000000003, 0.77940000000001, 2.37900000000002, 2.06700000000001, 1.46000000000004, 2.91160000000002, 1.69200000000001, 0.954600000000028, 2.49300000000005, 2.2722, 1.33500000000004, 2.44899999999996, 1.20140000000004, 3.07380000000001, 2.09739999999999, 2.85640000000001, 2.29960000000005, 2.40899999999999, 1.97040000000004, 0.809799999999996, 1.65279999999996, 2.59979999999996, 0.95799999999997, 2.06799999999998, 2.32780000000002, 4.20159999999998, 1.96320000000003, 1.86400000000003, 1.42999999999995, 3.77940000000001, 1.27200000000005, 1.86440000000005, 2.20600000000002, 3.21900000000005, 1.5154, 2.61019999999996), + // precision 8 + Array(183.2152, 180.2454, 177.2096, 173.6652, 170.6312, 167.6822, 164.249, 161.3296, 158.0038, 155.2074, 152.4612, 149.27, 146.5178, 143.4412, 140.8032, 138.1634, 135.1688, 132.6074, 129.6946, 127.2664, 124.8228, 122.0432, 119.6824, 116.9464, 114.6268, 112.2626, 109.8376, 107.4034, 104.8956, 102.8522, 100.7638, 
98.3552, 96.3556, 93.7526, 91.9292, 89.8954, 87.8198, 85.7668, 83.298, 81.6688, 79.9466, 77.9746, 76.1672, 74.3474, 72.3028, 70.8912, 69.114, 67.4646, 65.9744, 64.4092, 62.6022, 60.843, 59.5684, 58.1652, 56.5426, 55.4152, 53.5388, 52.3592, 51.1366, 49.486, 48.3918, 46.5076, 45.509, 44.3834, 43.3498, 42.0668, 40.7346, 40.1228, 38.4528, 37.7, 36.644, 36.0518, 34.5774, 33.9068, 32.432, 32.1666, 30.434, 29.6644, 28.4894, 27.6312, 26.3804, 26.292, 25.5496000000001, 25.0234, 24.8206, 22.6146, 22.4188, 22.117, 20.6762, 20.6576, 19.7864, 19.509, 18.5334, 17.9204, 17.772, 16.2924, 16.8654, 15.1836, 15.745, 15.1316, 15.0386, 14.0136, 13.6342, 12.6196, 12.1866, 12.4281999999999, 11.3324, 10.4794000000001, 11.5038, 10.129, 9.52800000000002, 10.3203999999999, 9.46299999999997, 9.79280000000006, 9.12300000000005, 8.74180000000001, 9.2192, 7.51020000000005, 7.60659999999996, 7.01840000000004, 7.22239999999999, 7.40139999999997, 6.76179999999999, 7.14359999999999, 5.65060000000005, 5.63779999999997, 5.76599999999996, 6.75139999999999, 5.57759999999996, 3.73220000000003, 5.8048, 5.63019999999995, 4.93359999999996, 3.47979999999995, 4.33879999999999, 3.98940000000005, 3.81960000000004, 3.31359999999995, 3.23080000000004, 3.4588, 3.08159999999998, 3.4076, 3.00639999999999, 2.38779999999997, 2.61900000000003, 1.99800000000005, 3.34820000000002, 2.95060000000001, 0.990999999999985, 2.11440000000005, 2.20299999999997, 2.82219999999995, 2.73239999999998, 2.7826, 3.76660000000004, 2.26480000000004, 2.31280000000004, 2.40819999999997, 2.75360000000001, 3.33759999999995, 2.71559999999999, 1.7478000000001, 1.42920000000004, 2.39300000000003, 2.22779999999989, 2.34339999999997, 0.87259999999992, 3.88400000000001, 1.80600000000004, 1.91759999999999, 1.16779999999994, 1.50320000000011, 2.52500000000009, 0.226400000000012, 2.31500000000005, 0.930000000000064, 1.25199999999995, 2.14959999999996, 0.0407999999999902, 2.5447999999999, 1.32960000000003, 0.197400000000016, 2.52620000000002, 3.33279999999991, -1.34300000000007, 0.422199999999975, 0.917200000000093, 1.12920000000008, 1.46060000000011, 1.45779999999991, 2.8728000000001, 3.33359999999993, -1.34079999999994, 1.57680000000005, 0.363000000000056, 1.40740000000005, 0.656600000000026, 0.801400000000058, -0.454600000000028, 1.51919999999996), + // precision 9 + Array(368, 361.8294, 355.2452, 348.6698, 342.1464, 336.2024, 329.8782, 323.6598, 317.462, 311.2826, 305.7102, 299.7416, 293.9366, 288.1046, 282.285, 277.0668, 271.306, 265.8448, 260.301, 254.9886, 250.2422, 244.8138, 239.7074, 234.7428, 229.8402, 225.1664, 220.3534, 215.594, 210.6886, 205.7876, 201.65, 197.228, 192.8036, 188.1666, 184.0818, 180.0824, 176.2574, 172.302, 168.1644, 164.0056, 160.3802, 156.7192, 152.5234, 149.2084, 145.831, 142.485, 139.1112, 135.4764, 131.76, 129.3368, 126.5538, 122.5058, 119.2646, 116.5902, 113.3818, 110.8998, 107.9532, 105.2062, 102.2798, 99.4728, 96.9582, 94.3292, 92.171, 89.7809999999999, 87.5716, 84.7048, 82.5322, 79.875, 78.3972, 75.3464, 73.7274, 71.2834, 70.1444, 68.4263999999999, 66.0166, 64.018, 62.0437999999999, 60.3399999999999, 58.6856, 57.9836, 55.0311999999999, 54.6769999999999, 52.3188, 51.4846, 49.4423999999999, 47.739, 46.1487999999999, 44.9202, 43.4059999999999, 42.5342000000001, 41.2834, 38.8954000000001, 38.3286000000001, 36.2146, 36.6684, 35.9946, 33.123, 33.4338, 31.7378000000001, 29.076, 28.9692, 27.4964, 27.0998, 25.9864, 26.7754, 24.3208, 23.4838, 22.7388000000001, 24.0758000000001, 21.9097999999999, 20.9728, 19.9228000000001, 19.9292, 16.617, 17.05, 
18.2996000000001, 15.6128000000001, 15.7392, 14.5174, 13.6322, 12.2583999999999, 13.3766000000001, 11.423, 13.1232, 9.51639999999998, 10.5938000000001, 9.59719999999993, 8.12220000000002, 9.76739999999995, 7.50440000000003, 7.56999999999994, 6.70440000000008, 6.41419999999994, 6.71019999999999, 5.60940000000005, 4.65219999999999, 6.84099999999989, 3.4072000000001, 3.97859999999991, 3.32760000000007, 5.52160000000003, 3.31860000000006, 2.06940000000009, 4.35400000000004, 1.57500000000005, 0.280799999999999, 2.12879999999996, -0.214799999999968, -0.0378000000000611, -0.658200000000079, 0.654800000000023, -0.0697999999999865, 0.858400000000074, -2.52700000000004, -2.1751999999999, -3.35539999999992, -1.04019999999991, -0.651000000000067, -2.14439999999991, -1.96659999999997, -3.97939999999994, -0.604400000000169, -3.08260000000018, -3.39159999999993, -5.29640000000018, -5.38920000000007, -5.08759999999984, -4.69900000000007, -5.23720000000003, -3.15779999999995, -4.97879999999986, -4.89899999999989, -7.48880000000008, -5.94799999999987, -5.68060000000014, -6.67180000000008, -4.70499999999993, -7.27779999999984, -4.6579999999999, -4.4362000000001, -4.32139999999981, -5.18859999999995, -6.66879999999992, -6.48399999999992, -5.1260000000002, -4.4032000000002, -6.13500000000022, -5.80819999999994, -4.16719999999987, -4.15039999999999, -7.45600000000013, -7.24080000000004, -9.83179999999993, -5.80420000000004, -8.6561999999999, -6.99940000000015, -10.5473999999999, -7.34139999999979, -6.80999999999995, -6.29719999999998, -6.23199999999997), + // precision 10 + Array(737.1256, 724.4234, 711.1064, 698.4732, 685.4636, 673.0644, 660.488, 647.9654, 636.0832, 623.7864, 612.1992, 600.2176, 588.5228, 577.1716, 565.7752, 554.899, 543.6126, 532.6492, 521.9474, 511.5214, 501.1064, 490.6364, 480.2468, 470.4588, 460.3832, 451.0584, 440.8606, 431.3868, 422.5062, 413.1862, 404.463, 395.339, 386.1936, 378.1292, 369.1854, 361.2908, 353.3324, 344.8518, 337.5204, 329.4854, 321.9318, 314.552, 306.4658, 299.4256, 292.849, 286.152, 278.8956, 271.8792, 265.118, 258.62, 252.5132, 245.9322, 239.7726, 233.6086, 227.5332, 222.5918, 216.4294, 210.7662, 205.4106, 199.7338, 194.9012, 188.4486, 183.1556, 178.6338, 173.7312, 169.6264, 163.9526, 159.8742, 155.8326, 151.1966, 147.5594, 143.07, 140.037, 134.1804, 131.071, 127.4884, 124.0848, 120.2944, 117.333, 112.9626, 110.2902, 107.0814, 103.0334, 99.4832000000001, 96.3899999999999, 93.7202000000002, 90.1714000000002, 87.2357999999999, 85.9346, 82.8910000000001, 80.0264000000002, 78.3834000000002, 75.1543999999999, 73.8683999999998, 70.9895999999999, 69.4367999999999, 64.8701999999998, 65.0408000000002, 61.6738, 59.5207999999998, 57.0158000000001, 54.2302, 53.0962, 50.4985999999999, 52.2588000000001, 47.3914, 45.6244000000002, 42.8377999999998, 43.0072, 40.6516000000001, 40.2453999999998, 35.2136, 36.4546, 33.7849999999999, 33.2294000000002, 32.4679999999998, 30.8670000000002, 28.6507999999999, 28.9099999999999, 27.5983999999999, 26.1619999999998, 24.5563999999999, 23.2328000000002, 21.9484000000002, 21.5902000000001, 21.3346000000001, 17.7031999999999, 20.6111999999998, 19.5545999999999, 15.7375999999999, 17.0720000000001, 16.9517999999998, 15.326, 13.1817999999998, 14.6925999999999, 13.0859999999998, 13.2754, 10.8697999999999, 11.248, 7.3768, 4.72339999999986, 7.97899999999981, 8.7503999999999, 7.68119999999999, 9.7199999999998, 7.73919999999998, 5.6224000000002, 7.44560000000001, 6.6601999999998, 5.9058, 4.00199999999995, 4.51699999999983, 4.68240000000014, 3.86220000000003, 
5.13639999999987, 5.98500000000013, 2.47719999999981, 2.61999999999989, 1.62800000000016, 4.65000000000009, 0.225599999999758, 0.831000000000131, -0.359400000000278, 1.27599999999984, -2.92559999999958, -0.0303999999996449, 2.37079999999969, -2.0033999999996, 0.804600000000391, 0.30199999999968, 1.1247999999996, -2.6880000000001, 0.0321999999996478, -1.18099999999959, -3.9402, -1.47940000000017, -0.188400000000001, -2.10720000000038, -2.04159999999956, -3.12880000000041, -4.16160000000036, -0.612799999999879, -3.48719999999958, -8.17900000000009, -5.37780000000021, -4.01379999999972, -5.58259999999973, -5.73719999999958, -7.66799999999967, -5.69520000000011, -1.1247999999996, -5.58520000000044, -8.04560000000038, -4.64840000000004, -11.6468000000004, -7.97519999999986, -5.78300000000036, -7.67420000000038, -10.6328000000003, -9.81720000000041), + // precision 11 + Array(1476, 1449.6014, 1423.5802, 1397.7942, 1372.3042, 1347.2062, 1321.8402, 1297.2292, 1272.9462, 1248.9926, 1225.3026, 1201.4252, 1178.0578, 1155.6092, 1132.626, 1110.5568, 1088.527, 1066.5154, 1045.1874, 1024.3878, 1003.37, 982.1972, 962.5728, 942.1012, 922.9668, 903.292, 884.0772, 864.8578, 846.6562, 828.041, 809.714, 792.3112, 775.1806, 757.9854, 740.656, 724.346, 707.5154, 691.8378, 675.7448, 659.6722, 645.5722, 630.1462, 614.4124, 600.8728, 585.898, 572.408, 558.4926, 544.4938, 531.6776, 517.282, 505.7704, 493.1012, 480.7388, 467.6876, 456.1872, 445.5048, 433.0214, 420.806, 411.409, 400.4144, 389.4294, 379.2286, 369.651, 360.6156, 350.337, 342.083, 332.1538, 322.5094, 315.01, 305.6686, 298.1678, 287.8116, 280.9978, 271.9204, 265.3286, 257.5706, 249.6014, 242.544, 235.5976, 229.583, 220.9438, 214.672, 208.2786, 201.8628, 195.1834, 191.505, 186.1816, 178.5188, 172.2294, 167.8908, 161.0194, 158.052, 151.4588, 148.1596, 143.4344, 138.5238, 133.13, 127.6374, 124.8162, 118.7894, 117.3984, 114.6078, 109.0858, 105.1036, 103.6258, 98.6018000000004, 95.7618000000002, 93.5821999999998, 88.5900000000001, 86.9992000000002, 82.8800000000001, 80.4539999999997, 74.6981999999998, 74.3644000000004, 73.2914000000001, 65.5709999999999, 66.9232000000002, 65.1913999999997, 62.5882000000001, 61.5702000000001, 55.7035999999998, 56.1764000000003, 52.7596000000003, 53.0302000000001, 49.0609999999997, 48.4694, 44.933, 46.0474000000004, 44.7165999999997, 41.9416000000001, 39.9207999999999, 35.6328000000003, 35.5276000000003, 33.1934000000001, 33.2371999999996, 33.3864000000003, 33.9228000000003, 30.2371999999996, 29.1373999999996, 25.2272000000003, 24.2942000000003, 19.8338000000003, 18.9005999999999, 23.0907999999999, 21.8544000000002, 19.5176000000001, 15.4147999999996, 16.9314000000004, 18.6737999999996, 12.9877999999999, 14.3688000000002, 12.0447999999997, 15.5219999999999, 12.5299999999997, 14.5940000000001, 14.3131999999996, 9.45499999999993, 12.9441999999999, 3.91139999999996, 13.1373999999996, 5.44720000000052, 9.82779999999912, 7.87279999999919, 3.67760000000089, 5.46980000000076, 5.55099999999948, 5.65979999999945, 3.89439999999922, 3.1275999999998, 5.65140000000065, 6.3062000000009, 3.90799999999945, 1.87060000000019, 5.17020000000048, 2.46680000000015, 0.770000000000437, -3.72340000000077, 1.16400000000067, 8.05340000000069, 0.135399999999208, 2.15940000000046, 0.766999999999825, 1.0594000000001, 3.15500000000065, -0.287399999999252, 2.37219999999979, -2.86620000000039, -1.63199999999961, -2.22979999999916, -0.15519999999924, -1.46039999999994, -0.262199999999211, -2.34460000000036, -2.8078000000005, -3.22179999999935, -5.60159999999996, 
-8.42200000000048, -9.43740000000071, 0.161799999999857, -10.4755999999998, -10.0823999999993), + // precision 12 + Array(2953, 2900.4782, 2848.3568, 2796.3666, 2745.324, 2694.9598, 2644.648, 2595.539, 2546.1474, 2498.2576, 2450.8376, 2403.6076, 2357.451, 2311.38, 2266.4104, 2221.5638, 2176.9676, 2134.193, 2090.838, 2048.8548, 2007.018, 1966.1742, 1925.4482, 1885.1294, 1846.4776, 1807.4044, 1768.8724, 1731.3732, 1693.4304, 1657.5326, 1621.949, 1586.5532, 1551.7256, 1517.6182, 1483.5186, 1450.4528, 1417.865, 1385.7164, 1352.6828, 1322.6708, 1291.8312, 1260.9036, 1231.476, 1201.8652, 1173.6718, 1145.757, 1119.2072, 1092.2828, 1065.0434, 1038.6264, 1014.3192, 988.5746, 965.0816, 940.1176, 917.9796, 894.5576, 871.1858, 849.9144, 827.1142, 805.0818, 783.9664, 763.9096, 742.0816, 724.3962, 706.3454, 688.018, 667.4214, 650.3106, 633.0686, 613.8094, 597.818, 581.4248, 563.834, 547.363, 531.5066, 520.455400000001, 505.583199999999, 488.366, 476.480799999999, 459.7682, 450.0522, 434.328799999999, 423.952799999999, 408.727000000001, 399.079400000001, 387.252200000001, 373.987999999999, 360.852000000001, 351.6394, 339.642, 330.902400000001, 322.661599999999, 311.662200000001, 301.3254, 291.7484, 279.939200000001, 276.7508, 263.215200000001, 254.811400000001, 245.5494, 242.306399999999, 234.8734, 223.787200000001, 217.7156, 212.0196, 200.793, 195.9748, 189.0702, 182.449199999999, 177.2772, 170.2336, 164.741, 158.613600000001, 155.311, 147.5964, 142.837, 137.3724, 132.0162, 130.0424, 121.9804, 120.451800000001, 114.8968, 111.585999999999, 105.933199999999, 101.705, 98.5141999999996, 95.0488000000005, 89.7880000000005, 91.4750000000004, 83.7764000000006, 80.9698000000008, 72.8574000000008, 73.1615999999995, 67.5838000000003, 62.6263999999992, 63.2638000000006, 66.0977999999996, 52.0843999999997, 58.9956000000002, 47.0912000000008, 46.4956000000002, 48.4383999999991, 47.1082000000006, 43.2392, 37.2759999999998, 40.0283999999992, 35.1864000000005, 35.8595999999998, 32.0998, 28.027, 23.6694000000007, 33.8266000000003, 26.3736000000008, 27.2008000000005, 21.3245999999999, 26.4115999999995, 23.4521999999997, 19.5013999999992, 19.8513999999996, 10.7492000000002, 18.6424000000006, 13.1265999999996, 18.2436000000016, 6.71860000000015, 3.39459999999963, 6.33759999999893, 7.76719999999841, 0.813999999998487, 3.82819999999992, 0.826199999999517, 8.07440000000133, -1.59080000000176, 5.01780000000144, 0.455399999998917, -0.24199999999837, 0.174800000000687, -9.07640000000174, -4.20160000000033, -3.77520000000004, -4.75179999999818, -5.3724000000002, -8.90680000000066, -6.10239999999976, -5.74120000000039, -9.95339999999851, -3.86339999999836, -13.7304000000004, -16.2710000000006, -7.51359999999841, -3.30679999999847, -13.1339999999982, -10.0551999999989, -6.72019999999975, -8.59660000000076, -10.9307999999983, -1.8775999999998, -4.82259999999951, -13.7788, -21.6470000000008, -10.6735999999983, -15.7799999999988), + // precision 13 + Array(5907.5052, 5802.2672, 5697.347, 5593.5794, 5491.2622, 5390.5514, 5290.3376, 5191.6952, 5093.5988, 4997.3552, 4902.5972, 4808.3082, 4715.5646, 4624.109, 4533.8216, 4444.4344, 4356.3802, 4269.2962, 4183.3784, 4098.292, 4014.79, 3932.4574, 3850.6036, 3771.2712, 3691.7708, 3615.099, 3538.1858, 3463.4746, 3388.8496, 3315.6794, 3244.5448, 3173.7516, 3103.3106, 3033.6094, 2966.5642, 2900.794, 2833.7256, 2769.81, 2707.3196, 2644.0778, 2583.9916, 2523.4662, 2464.124, 2406.073, 2347.0362, 2292.1006, 2238.1716, 2182.7514, 2128.4884, 2077.1314, 2025.037, 1975.3756, 1928.933, 1879.311, 
1831.0006, 1783.2144, 1738.3096, 1694.5144, 1649.024, 1606.847, 1564.7528, 1525.3168, 1482.5372, 1443.9668, 1406.5074, 1365.867, 1329.2186, 1295.4186, 1257.9716, 1225.339, 1193.2972, 1156.3578, 1125.8686, 1091.187, 1061.4094, 1029.4188, 1000.9126, 972.3272, 944.004199999999, 915.7592, 889.965, 862.834200000001, 840.4254, 812.598399999999, 785.924200000001, 763.050999999999, 741.793799999999, 721.466, 699.040799999999, 677.997200000002, 649.866999999998, 634.911800000002, 609.8694, 591.981599999999, 570.2922, 557.129199999999, 538.3858, 521.872599999999, 502.951400000002, 495.776399999999, 475.171399999999, 459.751, 439.995200000001, 426.708999999999, 413.7016, 402.3868, 387.262599999998, 372.0524, 357.050999999999, 342.5098, 334.849200000001, 322.529399999999, 311.613799999999, 295.848000000002, 289.273000000001, 274.093000000001, 263.329600000001, 251.389599999999, 245.7392, 231.9614, 229.7952, 217.155200000001, 208.9588, 199.016599999999, 190.839199999999, 180.6976, 176.272799999999, 166.976999999999, 162.5252, 151.196400000001, 149.386999999999, 133.981199999998, 130.0586, 130.164000000001, 122.053400000001, 110.7428, 108.1276, 106.232400000001, 100.381600000001, 98.7668000000012, 86.6440000000002, 79.9768000000004, 82.4722000000002, 68.7026000000005, 70.1186000000016, 71.9948000000004, 58.998599999999, 59.0492000000013, 56.9818000000014, 47.5338000000011, 42.9928, 51.1591999999982, 37.2740000000013, 42.7220000000016, 31.3734000000004, 26.8090000000011, 25.8934000000008, 26.5286000000015, 29.5442000000003, 19.3503999999994, 26.0760000000009, 17.9527999999991, 14.8419999999969, 10.4683999999979, 8.65899999999965, 9.86720000000059, 4.34139999999752, -0.907800000000861, -3.32080000000133, -0.936199999996461, -11.9916000000012, -8.87000000000262, -6.33099999999831, -11.3366000000024, -15.9207999999999, -9.34659999999712, -15.5034000000014, -19.2097999999969, -15.357799999998, -28.2235999999975, -30.6898000000001, -19.3271999999997, -25.6083999999973, -24.409599999999, -13.6385999999984, -33.4473999999973, -32.6949999999997, -28.9063999999998, -31.7483999999968, -32.2935999999972, -35.8329999999987, -47.620600000002, -39.0855999999985, -33.1434000000008, -46.1371999999974, -37.5892000000022, -46.8164000000033, -47.3142000000007, -60.2914000000019, -37.7575999999972), + // precision 14 + Array(11816.475, 11605.0046, 11395.3792, 11188.7504, 10984.1814, 10782.0086, 10582.0072, 10384.503, 10189.178, 9996.2738, 9806.0344, 9617.9798, 9431.394, 9248.7784, 9067.6894, 8889.6824, 8712.9134, 8538.8624, 8368.4944, 8197.7956, 8031.8916, 7866.6316, 7703.733, 7544.5726, 7386.204, 7230.666, 7077.8516, 6926.7886, 6778.6902, 6631.9632, 6487.304, 6346.7486, 6206.4408, 6070.202, 5935.2576, 5799.924, 5671.0324, 5541.9788, 5414.6112, 5290.0274, 5166.723, 5047.6906, 4929.162, 4815.1406, 4699.127, 4588.5606, 4477.7394, 4369.4014, 4264.2728, 4155.9224, 4055.581, 3955.505, 3856.9618, 3761.3828, 3666.9702, 3575.7764, 3482.4132, 3395.0186, 3305.8852, 3221.415, 3138.6024, 3056.296, 2970.4494, 2896.1526, 2816.8008, 2740.2156, 2670.497, 2594.1458, 2527.111, 2460.8168, 2387.5114, 2322.9498, 2260.6752, 2194.2686, 2133.7792, 2074.767, 2015.204, 1959.4226, 1898.6502, 1850.006, 1792.849, 1741.4838, 1687.9778, 1638.1322, 1589.3266, 1543.1394, 1496.8266, 1447.8516, 1402.7354, 1361.9606, 1327.0692, 1285.4106, 1241.8112, 1201.6726, 1161.973, 1130.261, 1094.2036, 1048.2036, 1020.6436, 990.901400000002, 961.199800000002, 924.769800000002, 899.526400000002, 872.346400000002, 834.375, 810.432000000001, 780.659800000001, 
756.013800000001, 733.479399999997, 707.923999999999, 673.858, 652.222399999999, 636.572399999997, 615.738599999997, 586.696400000001, 564.147199999999, 541.679600000003, 523.943599999999, 505.714599999999, 475.729599999999, 461.779600000002, 449.750800000002, 439.020799999998, 412.7886, 400.245600000002, 383.188199999997, 362.079599999997, 357.533799999997, 334.319000000003, 327.553399999997, 308.559399999998, 291.270199999999, 279.351999999999, 271.791400000002, 252.576999999997, 247.482400000001, 236.174800000001, 218.774599999997, 220.155200000001, 208.794399999999, 201.223599999998, 182.995600000002, 185.5268, 164.547400000003, 176.5962, 150.689599999998, 157.8004, 138.378799999999, 134.021200000003, 117.614399999999, 108.194000000003, 97.0696000000025, 89.6042000000016, 95.6030000000028, 84.7810000000027, 72.635000000002, 77.3482000000004, 59.4907999999996, 55.5875999999989, 50.7346000000034, 61.3916000000027, 50.9149999999936, 39.0384000000049, 58.9395999999979, 29.633600000001, 28.2032000000036, 26.0078000000067, 17.0387999999948, 9.22000000000116, 13.8387999999977, 8.07240000000456, 14.1549999999988, 15.3570000000036, 3.42660000000615, 6.24820000000182, -2.96940000000177, -8.79940000000352, -5.97860000000219, -14.4048000000039, -3.4143999999942, -13.0148000000045, -11.6977999999945, -25.7878000000055, -22.3185999999987, -24.409599999999, -31.9756000000052, -18.9722000000038, -22.8678000000073, -30.8972000000067, -32.3715999999986, -22.3907999999938, -43.6720000000059, -35.9038, -39.7492000000057, -54.1641999999993, -45.2749999999942, -42.2989999999991, -44.1089999999967, -64.3564000000042, -49.9551999999967, -42.6116000000038), + // precision 15 + Array(23634.0036, 23210.8034, 22792.4744, 22379.1524, 21969.7928, 21565.326, 21165.3532, 20770.2806, 20379.9892, 19994.7098, 19613.318, 19236.799, 18865.4382, 18498.8244, 18136.5138, 17778.8668, 17426.2344, 17079.32, 16734.778, 16397.2418, 16063.3324, 15734.0232, 15409.731, 15088.728, 14772.9896, 14464.1402, 14157.5588, 13855.5958, 13559.3296, 13264.9096, 12978.326, 12692.0826, 12413.8816, 12137.3192, 11870.2326, 11602.5554, 11340.3142, 11079.613, 10829.5908, 10583.5466, 10334.0344, 10095.5072, 9859.694, 9625.2822, 9395.7862, 9174.0586, 8957.3164, 8738.064, 8524.155, 8313.7396, 8116.9168, 7913.542, 7718.4778, 7521.65, 7335.5596, 7154.2906, 6968.7396, 6786.3996, 6613.236, 6437.406, 6270.6598, 6107.7958, 5945.7174, 5787.6784, 5635.5784, 5482.308, 5337.9784, 5190.0864, 5045.9158, 4919.1386, 4771.817, 4645.7742, 4518.4774, 4385.5454, 4262.6622, 4142.74679999999, 4015.5318, 3897.9276, 3790.7764, 3685.13800000001, 3573.6274, 3467.9706, 3368.61079999999, 3271.5202, 3170.3848, 3076.4656, 2982.38400000001, 2888.4664, 2806.4868, 2711.9564, 2634.1434, 2551.3204, 2469.7662, 2396.61139999999, 2318.9902, 2243.8658, 2171.9246, 2105.01360000001, 2028.8536, 1960.9952, 1901.4096, 1841.86079999999, 1777.54700000001, 1714.5802, 1654.65059999999, 1596.311, 1546.2016, 1492.3296, 1433.8974, 1383.84600000001, 1339.4152, 1293.5518, 1245.8686, 1193.50659999999, 1162.27959999999, 1107.19439999999, 1069.18060000001, 1035.09179999999, 999.679000000004, 957.679999999993, 925.300199999998, 888.099400000006, 848.638600000006, 818.156400000007, 796.748399999997, 752.139200000005, 725.271200000003, 692.216, 671.633600000001, 647.939799999993, 621.670599999998, 575.398799999995, 561.226599999995, 532.237999999998, 521.787599999996, 483.095799999996, 467.049599999998, 465.286399999997, 415.548599999995, 401.047399999996, 380.607999999993, 377.362599999993, 
347.258799999996, 338.371599999999, 310.096999999994, 301.409199999995, 276.280799999993, 265.586800000005, 258.994399999996, 223.915999999997, 215.925399999993, 213.503800000006, 191.045400000003, 166.718200000003, 166.259000000005, 162.941200000001, 148.829400000002, 141.645999999993, 123.535399999993, 122.329800000007, 89.473399999988, 80.1962000000058, 77.5457999999926, 59.1056000000099, 83.3509999999951, 52.2906000000075, 36.3979999999865, 40.6558000000077, 42.0003999999899, 19.6630000000005, 19.7153999999864, -8.38539999999921, -0.692799999989802, 0.854800000000978, 3.23219999999856, -3.89040000000386, -5.25880000001052, -24.9052000000083, -22.6837999999989, -26.4286000000138, -34.997000000003, -37.0216000000073, -43.430400000012, -58.2390000000014, -68.8034000000043, -56.9245999999985, -57.8583999999973, -77.3097999999882, -73.2793999999994, -81.0738000000129, -87.4530000000086, -65.0254000000132, -57.296399999992, -96.2746000000043, -103.25, -96.081600000005, -91.5542000000132, -102.465200000006, -107.688599999994, -101.458000000013, -109.715800000005), + // precision 16 + Array(47270, 46423.3584, 45585.7074, 44757.152, 43938.8416, 43130.9514, 42330.03, 41540.407, 40759.6348, 39988.206, 39226.5144, 38473.2096, 37729.795, 36997.268, 36272.6448, 35558.665, 34853.0248, 34157.4472, 33470.5204, 32793.5742, 32127.0194, 31469.4182, 30817.6136, 30178.6968, 29546.8908, 28922.8544, 28312.271, 27707.0924, 27114.0326, 26526.692, 25948.6336, 25383.7826, 24823.5998, 24272.2974, 23732.2572, 23201.4976, 22674.2796, 22163.6336, 21656.515, 21161.7362, 20669.9368, 20189.4424, 19717.3358, 19256.3744, 18795.9638, 18352.197, 17908.5738, 17474.391, 17052.918, 16637.2236, 16228.4602, 15823.3474, 15428.6974, 15043.0284, 14667.6278, 14297.4588, 13935.2882, 13578.5402, 13234.6032, 12882.1578, 12548.0728, 12219.231, 11898.0072, 11587.2626, 11279.9072, 10973.5048, 10678.5186, 10392.4876, 10105.2556, 9825.766, 9562.5444, 9294.2222, 9038.2352, 8784.848, 8533.2644, 8301.7776, 8058.30859999999, 7822.94579999999, 7599.11319999999, 7366.90779999999, 7161.217, 6957.53080000001, 6736.212, 6548.21220000001, 6343.06839999999, 6156.28719999999, 5975.15419999999, 5791.75719999999, 5621.32019999999, 5451.66, 5287.61040000001, 5118.09479999999, 4957.288, 4798.4246, 4662.17559999999, 4512.05900000001, 4364.68539999999, 4220.77720000001, 4082.67259999999, 3957.19519999999, 3842.15779999999, 3699.3328, 3583.01180000001, 3473.8964, 3338.66639999999, 3233.55559999999, 3117.799, 3008.111, 2909.69140000001, 2814.86499999999, 2719.46119999999, 2624.742, 2532.46979999999, 2444.7886, 2370.1868, 2272.45259999999, 2196.19260000001, 2117.90419999999, 2023.2972, 1969.76819999999, 1885.58979999999, 1833.2824, 1733.91200000001, 1682.54920000001, 1604.57980000001, 1556.11240000001, 1491.3064, 1421.71960000001, 1371.22899999999, 1322.1324, 1264.7892, 1196.23920000001, 1143.8474, 1088.67240000001, 1073.60380000001, 1023.11660000001, 959.036400000012, 927.433199999999, 906.792799999996, 853.433599999989, 841.873800000001, 791.1054, 756.899999999994, 704.343200000003, 672.495599999995, 622.790399999998, 611.254799999995, 567.283200000005, 519.406599999988, 519.188400000014, 495.312800000014, 451.350799999986, 443.973399999988, 431.882199999993, 392.027000000002, 380.924200000009, 345.128999999986, 298.901400000002, 287.771999999997, 272.625, 247.253000000026, 222.490600000019, 223.590000000026, 196.407599999977, 176.425999999978, 134.725199999986, 132.4804, 110.445599999977, 86.7939999999944, 56.7038000000175, 64.915399999998, 38.3726000000024, 
37.1606000000029, 46.170999999973, 49.1716000000015, 15.3362000000197, 6.71639999997569, -34.8185999999987, -39.4476000000141, 12.6830000000191, -12.3331999999937, -50.6565999999875, -59.9538000000175, -65.1054000000004, -70.7576000000117, -106.325200000021, -126.852200000023, -110.227599999984, -132.885999999999, -113.897200000007, -142.713800000027, -151.145399999979, -150.799200000009, -177.756200000003, -156.036399999983, -182.735199999996, -177.259399999981, -198.663600000029, -174.577600000019, -193.84580000001), + // precision 17 + Array(94541, 92848.811, 91174.019, 89517.558, 87879.9705, 86262.7565, 84663.5125, 83083.7435, 81521.7865, 79977.272, 78455.9465, 76950.219, 75465.432, 73994.152, 72546.71, 71115.2345, 69705.6765, 68314.937, 66944.2705, 65591.255, 64252.9485, 62938.016, 61636.8225, 60355.592, 59092.789, 57850.568, 56624.518, 55417.343, 54231.1415, 53067.387, 51903.526, 50774.649, 49657.6415, 48561.05, 47475.7575, 46410.159, 45364.852, 44327.053, 43318.4005, 42325.6165, 41348.4595, 40383.6265, 39436.77, 38509.502, 37594.035, 36695.939, 35818.6895, 34955.691, 34115.8095, 33293.949, 32465.0775, 31657.6715, 30877.2585, 30093.78, 29351.3695, 28594.1365, 27872.115, 27168.7465, 26477.076, 25774.541, 25106.5375, 24452.5135, 23815.5125, 23174.0655, 22555.2685, 21960.2065, 21376.3555, 20785.1925, 20211.517, 19657.0725, 19141.6865, 18579.737, 18081.3955, 17578.995, 17073.44, 16608.335, 16119.911, 15651.266, 15194.583, 14749.0495, 14343.4835, 13925.639, 13504.509, 13099.3885, 12691.2855, 12328.018, 11969.0345, 11596.5145, 11245.6355, 10917.6575, 10580.9785, 10277.8605, 9926.58100000001, 9605.538, 9300.42950000003, 8989.97850000003, 8728.73249999998, 8448.3235, 8175.31050000002, 7898.98700000002, 7629.79100000003, 7413.76199999999, 7149.92300000001, 6921.12650000001, 6677.1545, 6443.28000000003, 6278.23450000002, 6014.20049999998, 5791.20299999998, 5605.78450000001, 5438.48800000001, 5234.2255, 5059.6825, 4887.43349999998, 4682.935, 4496.31099999999, 4322.52250000002, 4191.42499999999, 4021.24200000003, 3900.64799999999, 3762.84250000003, 3609.98050000001, 3502.29599999997, 3363.84250000003, 3206.54849999998, 3079.70000000001, 2971.42300000001, 2867.80349999998, 2727.08100000001, 2630.74900000001, 2496.6165, 2440.902, 2356.19150000002, 2235.58199999999, 2120.54149999999, 2012.25449999998, 1933.35600000003, 1820.93099999998, 1761.54800000001, 1663.09350000002, 1578.84600000002, 1509.48149999999, 1427.3345, 1379.56150000001, 1306.68099999998, 1212.63449999999, 1084.17300000001, 1124.16450000001, 1060.69949999999, 1007.48849999998, 941.194499999983, 879.880500000028, 836.007500000007, 782.802000000025, 748.385499999975, 647.991500000004, 626.730500000005, 570.776000000013, 484.000500000024, 513.98550000001, 418.985499999952, 386.996999999974, 370.026500000036, 355.496999999974, 356.731499999994, 255.92200000002, 259.094000000041, 205.434499999974, 165.374500000034, 197.347500000033, 95.718499999959, 67.6165000000037, 54.6970000000438, 31.7395000000251, -15.8784999999916, 8.42500000004657, -26.3754999999655, -118.425500000012, -66.6629999999423, -42.9745000000112, -107.364999999991, -189.839000000036, -162.611499999999, -164.964999999967, -189.079999999958, -223.931499999948, -235.329999999958, -269.639500000048, -249.087999999989, -206.475499999942, -283.04449999996, -290.667000000016, -304.561499999953, -336.784499999951, -380.386500000022, -283.280499999993, -364.533000000054, -389.059499999974, -364.454000000027, -415.748000000021, -417.155000000028), + // precision 18 + Array(189083, 
185696.913, 182348.774, 179035.946, 175762.762, 172526.444, 169329.754, 166166.099, 163043.269, 159958.91, 156907.912, 153906.845, 150924.199, 147996.568, 145093.457, 142239.233, 139421.475, 136632.27, 133889.588, 131174.2, 128511.619, 125868.621, 123265.385, 120721.061, 118181.769, 115709.456, 113252.446, 110840.198, 108465.099, 106126.164, 103823.469, 101556.618, 99308.004, 97124.508, 94937.803, 92833.731, 90745.061, 88677.627, 86617.47, 84650.442, 82697.833, 80769.132, 78879.629, 77014.432, 75215.626, 73384.587, 71652.482, 69895.93, 68209.301, 66553.669, 64921.981, 63310.323, 61742.115, 60205.018, 58698.658, 57190.657, 55760.865, 54331.169, 52908.167, 51550.273, 50225.254, 48922.421, 47614.533, 46362.049, 45098.569, 43926.083, 42736.03, 41593.473, 40425.26, 39316.237, 38243.651, 37170.617, 36114.609, 35084.19, 34117.233, 33206.509, 32231.505, 31318.728, 30403.404, 29540.0550000001, 28679.236, 27825.862, 26965.216, 26179.148, 25462.08, 24645.952, 23922.523, 23198.144, 22529.128, 21762.4179999999, 21134.779, 20459.117, 19840.818, 19187.04, 18636.3689999999, 17982.831, 17439.7389999999, 16874.547, 16358.2169999999, 15835.684, 15352.914, 14823.681, 14329.313, 13816.897, 13342.874, 12880.882, 12491.648, 12021.254, 11625.392, 11293.7610000001, 10813.697, 10456.209, 10099.074, 9755.39000000001, 9393.18500000006, 9047.57900000003, 8657.98499999999, 8395.85900000005, 8033, 7736.95900000003, 7430.59699999995, 7258.47699999996, 6924.58200000005, 6691.29399999999, 6357.92500000005, 6202.05700000003, 5921.19700000004, 5628.28399999999, 5404.96799999999, 5226.71100000001, 4990.75600000005, 4799.77399999998, 4622.93099999998, 4472.478, 4171.78700000001, 3957.46299999999, 3868.95200000005, 3691.14300000004, 3474.63100000005, 3341.67200000002, 3109.14000000001, 3071.97400000005, 2796.40399999998, 2756.17799999996, 2611.46999999997, 2471.93000000005, 2382.26399999997, 2209.22400000005, 2142.28399999999, 2013.96100000001, 1911.18999999994, 1818.27099999995, 1668.47900000005, 1519.65800000005, 1469.67599999998, 1367.13800000004, 1248.52899999998, 1181.23600000003, 1022.71900000004, 1088.20700000005, 959.03600000008, 876.095999999903, 791.183999999892, 703.337000000058, 731.949999999953, 586.86400000006, 526.024999999907, 323.004999999888, 320.448000000091, 340.672999999952, 309.638999999966, 216.601999999955, 102.922999999952, 19.2399999999907, -0.114000000059605, -32.6240000000689, -89.3179999999702, -153.497999999905, -64.2970000000205, -143.695999999996, -259.497999999905, -253.017999999924, -213.948000000091, -397.590000000084, -434.006000000052, -403.475000000093, -297.958000000101, -404.317000000039, -528.898999999976, -506.621000000043, -513.205000000075, -479.351000000024, -596.139999999898, -527.016999999993, -664.681000000099, -680.306000000099, -704.050000000047, -850.486000000034, -757.43200000003, -713.308999999892) + ) + // scalastyle:on + + private def validateDoubleLiteral(exp: Expression): Double = exp match { + case Literal(d: Double, DoubleType) => d + case _ => + throw new AnalysisException("The second argument should be a double literal.") + } + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala new file mode 100644 index 000000000000..c2bf2cb94116 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation 
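// Illustrative sketch (not part of the patch above): how a raw-estimate table and its
// parallel BIAS_DATA table can be consumed. The real HyperLogLog++ code may average
// several nearby entries (e.g. a nearest-neighbour scheme); the linear interpolation
// below, with ad-hoc names, only shows the intended lookup: find where the raw estimate
// falls in the ascending estimate table and read off a bias to subtract from it.
object BiasLookupSketch {
  def estimateBias(rawEstimate: Double, rawEstimates: Array[Double], biases: Array[Double]): Double = {
    require(rawEstimates.length == biases.length && rawEstimates.nonEmpty,
      "the two tables must be parallel and non-empty")
    if (rawEstimate <= rawEstimates.head) biases.head
    else if (rawEstimate >= rawEstimates.last) biases.last
    else {
      val hi = rawEstimates.indexWhere(_ >= rawEstimate) // first entry at or above the estimate
      val lo = hi - 1
      val t = (rawEstimate - rawEstimates(lo)) / (rawEstimates(hi) - rawEstimates(lo))
      biases(lo) + t * (biases(hi) - biases(lo)) // interpolate between the bracketing biases
    }
  }
}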
(ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.expressions._ + +case class Kurtosis(child: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) + extends CentralMomentAgg(child) { + + def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def prettyName: String = "kurtosis" + + override protected val momentOrder = 4 + + // NOTE: this is the formula for excess kurtosis, which is default for R and SciPy + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { + require(moments.length == momentOrder + 1, + s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") + val m2 = moments(2) + val m4 = moments(4) + + if (n == 0.0) { + null + } else if (m2 == 0.0) { + Double.NaN + } else { + n * m4 / (m2 * m2) - 3.0 + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala new file mode 100644 index 000000000000..be7e12d7a233 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +/** + * Returns the last value of `child` for a group of rows. If the last value of `child` + * is `null`, it returns `null` (respecting nulls). 
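// Illustrative sketch (not part of the patch above): the excess-kurtosis formula used by
// Kurtosis.getStatistic, n * m4 / (m2 * m2) - 3.0, restated over plain Doubles. m2 and m4
// are taken as sums of squared and fourth-power deviations from the mean, which is what
// makes the expression equal to the usual population excess kurtosis.
object KurtosisSketch {
  def excessKurtosis(xs: Seq[Double]): Double = {
    val n = xs.size.toDouble
    val mean = xs.sum / n
    val m2 = xs.map(x => math.pow(x - mean, 2)).sum
    val m4 = xs.map(x => math.pow(x - mean, 4)).sum
    if (m2 == 0.0) Double.NaN else n * m4 / (m2 * m2) - 3.0
  }
  // excessKurtosis(Seq(1.0, 2.0, 3.0, 4.0)) == -1.36 (a flat sample has negative excess kurtosis)
}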
Even if [[Last]] is used on a already + * sorted column, if we do partial aggregation and final aggregation (when mergeExpression + * is used) its result will not be deterministic (unless the input table is sorted and has + * a single partition, and we use a single reducer to do the aggregation.). + */ +case class Last(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate { + + def this(child: Expression) = this(child, Literal.create(false, BooleanType)) + + private val ignoreNulls: Boolean = ignoreNullsExpr match { + case Literal(b: Boolean, BooleanType) => b + case _ => + throw new AnalysisException("The second argument of First should be a boolean literal.") + } + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + // Last is not a deterministic function. + override def deterministic: Boolean = false + + // Return data type. + override def dataType: DataType = child.dataType + + // Expected input data type. + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + + private lazy val last = AttributeReference("last", child.dataType)() + + override lazy val aggBufferAttributes: Seq[AttributeReference] = last :: Nil + + override lazy val initialValues: Seq[Literal] = Seq( + /* last = */ Literal.create(null, child.dataType) + ) + + override lazy val updateExpressions: Seq[Expression] = { + if (ignoreNulls) { + Seq( + /* last = */ If(IsNull(child), last, child) + ) + } else { + Seq( + /* last = */ child + ) + } + } + + override lazy val mergeExpressions: Seq[Expression] = { + if (ignoreNulls) { + Seq( + /* last = */ If(IsNull(last.right), last.left, last.right) + ) + } else { + Seq( + /* last = */ last.right + ) + } + } + + override lazy val evaluateExpression: AttributeReference = last + + override def toString: String = s"last($child)${if (ignoreNulls) " ignore nulls"}" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala new file mode 100644 index 000000000000..906003188d4f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.types._ + +case class Max(child: Expression) extends DeclarativeAggregate { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + // Return data type. 
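// Illustrative sketch (not part of the patch above): the per-row update rule of Last,
// restated over Option values with ad-hoc names. With ignoreNulls the buffer keeps its
// previous value whenever the incoming value is null; without it the buffer is always
// overwritten, so a trailing null makes the whole result null.
object LastSketch {
  def last[A](rows: Seq[Option[A]], ignoreNulls: Boolean): Option[A] =
    rows.foldLeft(Option.empty[A]) { (buf, row) =>
      if (ignoreNulls) row.orElse(buf) // If(IsNull(child), last, child)
      else row                         // child: overwrite unconditionally
    }
  // last(Seq(Some(1), None), ignoreNulls = true)  == Some(1)
  // last(Seq(Some(1), None), ignoreNulls = false) == None
}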
+ override def dataType: DataType = child.dataType + + // Expected input data type. + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForOrderingExpr(child.dataType, "function max") + + private lazy val max = AttributeReference("max", child.dataType)() + + override lazy val aggBufferAttributes: Seq[AttributeReference] = max :: Nil + + override lazy val initialValues: Seq[Literal] = Seq( + /* max = */ Literal.create(null, child.dataType) + ) + + override lazy val updateExpressions: Seq[Expression] = Seq( + /* max = */ Greatest(Seq(max, child)) + ) + + override lazy val mergeExpressions: Seq[Expression] = { + Seq( + /* max = */ Greatest(Seq(max.left, max.right)) + ) + } + + override lazy val evaluateExpression: AttributeReference = max +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala new file mode 100644 index 000000000000..39f7afbd081c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.types._ + + +case class Min(child: Expression) extends DeclarativeAggregate { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + // Return data type. + override def dataType: DataType = child.dataType + + // Expected input data type. 
+ override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForOrderingExpr(child.dataType, "function min") + + private lazy val min = AttributeReference("min", child.dataType)() + + override lazy val aggBufferAttributes: Seq[AttributeReference] = min :: Nil + + override lazy val initialValues: Seq[Expression] = Seq( + /* min = */ Literal.create(null, child.dataType) + ) + + override lazy val updateExpressions: Seq[Expression] = Seq( + /* min = */ Least(Seq(min, child)) + ) + + override lazy val mergeExpressions: Seq[Expression] = { + Seq( + /* min = */ Least(Seq(min.left, min.right)) + ) + } + + override lazy val evaluateExpression: AttributeReference = min +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala new file mode 100644 index 000000000000..9411bcea2539 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
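// Illustrative sketch (not part of the patch above): why the new Max and Min can merge
// buffers with a single Greatest/Least call, assuming Greatest and Least skip null
// children and return null only when every child is null. The explicit IsNull checks of
// the old functions.scala (removed further below) then become unnecessary. Restated over
// Options with an ordinary Ordering and ad-hoc names:
object MinMaxSketch {
  def greatest[A: Ordering](xs: Option[A]*): Option[A] =
    xs.flatten.reduceOption((a, b) => Ordering[A].max(a, b)) // None (null) children are dropped
  // greatest(Some(3), None, Some(7)) == Some(7)
  // greatest[Int](None, None)        == None
}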
+ */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.expressions._ + +case class Skewness(child: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) + extends CentralMomentAgg(child) { + + def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def prettyName: String = "skewness" + + override protected val momentOrder = 3 + + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { + require(moments.length == momentOrder + 1, + s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") + val m2 = moments(2) + val m3 = moments(3) + + if (n == 0.0) { + null + } else if (m2 == 0.0) { + Double.NaN + } else { + math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala new file mode 100644 index 000000000000..eec79a9033e3 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
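// Illustrative sketch (not part of the patch above): the skewness formula used by
// Skewness.getStatistic, sqrt(n) * m3 / sqrt(m2^3), restated over plain Doubles with m2
// and m3 as sums of squared and cubed deviations. It is algebraically the population
// skewness ((1/n) * sum d^3) / ((1/n) * sum d^2)^(3/2).
object SkewnessSketch {
  def skewness(xs: Seq[Double]): Double = {
    val n = xs.size.toDouble
    val mean = xs.sum / n
    val m2 = xs.map(x => math.pow(x - mean, 2)).sum
    val m3 = xs.map(x => math.pow(x - mean, 3)).sum
    if (m2 == 0.0) Double.NaN else math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2)
  }
  // skewness(Seq(1.0, 2.0, 6.0)) > 0: the long right tail gives a positive value
}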
+ */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.expressions._ + +case class StddevSamp(child: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) + extends CentralMomentAgg(child) { + + def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def prettyName: String = "stddev_samp" + + override protected val momentOrder = 2 + + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { + require(moments.length == momentOrder + 1, + s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") + + if (n == 0.0) { + null + } else if (n == 1.0) { + Double.NaN + } else { + math.sqrt(moments(2) / (n - 1.0)) + } + } +} + +case class StddevPop( + child: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) + extends CentralMomentAgg(child) { + + def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def prettyName: String = "stddev_pop" + + override protected val momentOrder = 2 + + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { + require(moments.length == momentOrder + 1, + s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") + + if (n == 0.0) { + null + } else { + math.sqrt(moments(2) / n) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala new file mode 100644 index 000000000000..cfb042e0aa78 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
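// Illustrative sketch (not part of the patch above): stddev_samp vs stddev_pop as defined
// by the two getStatistic methods above -- sqrt(m2 / (n - 1)) vs sqrt(m2 / n) with
// m2 = sum of squared deviations. Dropping the square root gives var_samp and var_pop,
// added in Variance.scala further below. The aggregates return null for n == 0; a plain
// Double cannot be null, so this sketch uses NaN there instead.
object StddevSketch {
  def stddev(xs: Seq[Double], sample: Boolean): Double = {
    val n = xs.size.toDouble
    if (n == 0.0 || (sample && n == 1.0)) return Double.NaN
    val mean = xs.sum / n
    val m2 = xs.map(x => math.pow(x - mean, 2)).sum
    math.sqrt(m2 / (if (sample) n - 1.0 else n))
  }
  // stddev(Seq(1.0, 2.0, 3.0), sample = true)  == 1.0
  // stddev(Seq(1.0, 2.0, 3.0), sample = false) == math.sqrt(2.0 / 3.0), roughly 0.816
}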
+ */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.types._ + +case class Sum(child: Expression) extends DeclarativeAggregate { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + // Return data type. + override def dataType: DataType = resultType + + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(LongType, DoubleType, DecimalType)) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function sum") + + private lazy val resultType = child.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType.bounded(precision + 10, scale) + // TODO: Remove this line once we remove the NullType from inputTypes. + case NullType => IntegerType + case _ => child.dataType + } + + private lazy val sumDataType = resultType + + private lazy val sum = AttributeReference("sum", sumDataType)() + + private lazy val zero = Cast(Literal(0), sumDataType) + + override lazy val aggBufferAttributes = sum :: Nil + + override lazy val initialValues: Seq[Expression] = Seq( + /* sum = */ Literal.create(null, sumDataType) + ) + + override lazy val updateExpressions: Seq[Expression] = Seq( + /* sum = */ + Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum)) + ) + + override lazy val mergeExpressions: Seq[Expression] = { + val add = Add(Coalesce(Seq(sum.left, zero)), Cast(sum.right, sumDataType)) + Seq( + /* sum = */ + Coalesce(Seq(add, sum.left)) + ) + } + + override lazy val evaluateExpression: Expression = Cast(sum, resultType) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala new file mode 100644 index 000000000000..cf3a74030539 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
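// Illustrative sketch (not part of the patch above): the null behaviour encoded by Sum's
// Coalesce/Add expressions. The buffer starts as null and stays null until the first
// non-null input; after that, null inputs leave the running sum untouched, so the result
// is null only when every input was null. (The widening of the decimal result type by 10
// digits of precision is not modelled here.) Restated over Option[Long] with ad-hoc names:
object SumSketch {
  def sum(rows: Seq[Option[Long]]): Option[Long] =
    rows.foldLeft(Option.empty[Long]) {
      case (buf, None)    => buf                         // outer Coalesce falls back to the old sum
      case (buf, Some(v)) => Some(buf.getOrElse(0L) + v) // Coalesce(sum, zero) + child
    }
  // sum(Seq(None, None))           == None     (all inputs null => result is null)
  // sum(Seq(None, Some(2L), None)) == Some(2L)
}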
+ */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.expressions._ + +case class VarianceSamp(child: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) + extends CentralMomentAgg(child) { + + def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def prettyName: String = "var_samp" + + override protected val momentOrder = 2 + + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { + require(moments.length == momentOrder + 1, + s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") + + if (n == 0.0) { + null + } else if (n == 1.0) { + Double.NaN + } else { + moments(2) / (n - 1.0) + } + } +} + +case class VariancePop( + child: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) + extends CentralMomentAgg(child) { + + def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def prettyName: String = "var_pop" + + override protected val momentOrder = 2 + + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { + require(moments.length == momentOrder + 1, + s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") + + if (n == 0.0) { + null + } else { + moments(2) / n + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala deleted file mode 100644 index 88fb516e64aa..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ /dev/null @@ -1,294 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.expressions.aggregate - -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types._ - -case class Average(child: Expression) extends AlgebraicAggregate { - - override def children: Seq[Expression] = child :: Nil - - override def nullable: Boolean = true - - // Return data type. - override def dataType: DataType = resultType - - // Expected input data type. - // TODO: Once we remove the old code path, we can use our analyzer to cast NullType - // to the default data type of the NumericType. - override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) - - private val resultType = child.dataType match { - case DecimalType.Fixed(p, s) => - DecimalType.bounded(p + 4, s + 4) - case _ => DoubleType - } - - private val sumDataType = child.dataType match { - case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s) - case _ => DoubleType - } - - private val currentSum = AttributeReference("currentSum", sumDataType)() - private val currentCount = AttributeReference("currentCount", LongType)() - - override val bufferAttributes = currentSum :: currentCount :: Nil - - override val initialValues = Seq( - /* currentSum = */ Cast(Literal(0), sumDataType), - /* currentCount = */ Literal(0L) - ) - - override val updateExpressions = Seq( - /* currentSum = */ - Add( - currentSum, - Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)), - /* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L) - ) - - override val mergeExpressions = Seq( - /* currentSum = */ currentSum.left + currentSum.right, - /* currentCount = */ currentCount.left + currentCount.right - ) - - // If all input are nulls, currentCount will be 0 and we will get null after the division. - override val evaluateExpression = child.dataType match { - case DecimalType.Fixed(p, s) => - // increase the precision and scale to prevent precision loss - val dt = DecimalType.bounded(p + 14, s + 4) - Cast(Cast(currentSum, dt) / Cast(currentCount, dt), resultType) - case _ => - Cast(currentSum, resultType) / Cast(currentCount, resultType) - } -} - -case class Count(child: Expression) extends AlgebraicAggregate { - override def children: Seq[Expression] = child :: Nil - - override def nullable: Boolean = false - - // Return data type. - override def dataType: DataType = LongType - - // Expected input data type. - override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - - private val currentCount = AttributeReference("currentCount", LongType)() - - override val bufferAttributes = currentCount :: Nil - - override val initialValues = Seq( - /* currentCount = */ Literal(0L) - ) - - override val updateExpressions = Seq( - /* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L) - ) - - override val mergeExpressions = Seq( - /* currentCount = */ currentCount.left + currentCount.right - ) - - override val evaluateExpression = Cast(currentCount, LongType) -} - -case class First(child: Expression) extends AlgebraicAggregate { - - override def children: Seq[Expression] = child :: Nil - - override def nullable: Boolean = true - - // First is not a deterministic function. - override def deterministic: Boolean = false - - // Return data type. - override def dataType: DataType = child.dataType - - // Expected input data type. 
- override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - - private val first = AttributeReference("first", child.dataType)() - - override val bufferAttributes = first :: Nil - - override val initialValues = Seq( - /* first = */ Literal.create(null, child.dataType) - ) - - override val updateExpressions = Seq( - /* first = */ If(IsNull(first), child, first) - ) - - override val mergeExpressions = Seq( - /* first = */ If(IsNull(first.left), first.right, first.left) - ) - - override val evaluateExpression = first -} - -case class Last(child: Expression) extends AlgebraicAggregate { - - override def children: Seq[Expression] = child :: Nil - - override def nullable: Boolean = true - - // Last is not a deterministic function. - override def deterministic: Boolean = false - - // Return data type. - override def dataType: DataType = child.dataType - - // Expected input data type. - override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - - private val last = AttributeReference("last", child.dataType)() - - override val bufferAttributes = last :: Nil - - override val initialValues = Seq( - /* last = */ Literal.create(null, child.dataType) - ) - - override val updateExpressions = Seq( - /* last = */ If(IsNull(child), last, child) - ) - - override val mergeExpressions = Seq( - /* last = */ If(IsNull(last.right), last.left, last.right) - ) - - override val evaluateExpression = last -} - -case class Max(child: Expression) extends AlgebraicAggregate { - - override def children: Seq[Expression] = child :: Nil - - override def nullable: Boolean = true - - // Return data type. - override def dataType: DataType = child.dataType - - // Expected input data type. - override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - - private val max = AttributeReference("max", child.dataType)() - - override val bufferAttributes = max :: Nil - - override val initialValues = Seq( - /* max = */ Literal.create(null, child.dataType) - ) - - override val updateExpressions = Seq( - /* max = */ If(IsNull(child), max, If(IsNull(max), child, Greatest(Seq(max, child)))) - ) - - override val mergeExpressions = { - val greatest = Greatest(Seq(max.left, max.right)) - Seq( - /* max = */ If(IsNull(max.right), max.left, If(IsNull(max.left), max.right, greatest)) - ) - } - - override val evaluateExpression = max -} - -case class Min(child: Expression) extends AlgebraicAggregate { - - override def children: Seq[Expression] = child :: Nil - - override def nullable: Boolean = true - - // Return data type. - override def dataType: DataType = child.dataType - - // Expected input data type. - override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - - private val min = AttributeReference("min", child.dataType)() - - override val bufferAttributes = min :: Nil - - override val initialValues = Seq( - /* min = */ Literal.create(null, child.dataType) - ) - - override val updateExpressions = Seq( - /* min = */ If(IsNull(child), min, If(IsNull(min), child, Least(Seq(min, child)))) - ) - - override val mergeExpressions = { - val least = Least(Seq(min.left, min.right)) - Seq( - /* min = */ If(IsNull(min.right), min.left, If(IsNull(min.left), min.right, least)) - ) - } - - override val evaluateExpression = min -} - -case class Sum(child: Expression) extends AlgebraicAggregate { - - override def children: Seq[Expression] = child :: Nil - - override def nullable: Boolean = true - - // Return data type. - override def dataType: DataType = resultType - - // Expected input data type. 
- override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(LongType, DoubleType, DecimalType, NullType)) - - private val resultType = child.dataType match { - case DecimalType.Fixed(precision, scale) => - DecimalType.bounded(precision + 10, scale) - case _ => child.dataType - } - - private val sumDataType = resultType - - private val currentSum = AttributeReference("currentSum", sumDataType)() - - private val zero = Cast(Literal(0), sumDataType) - - override val bufferAttributes = currentSum :: Nil - - override val initialValues = Seq( - /* currentSum = */ Literal.create(null, sumDataType) - ) - - override val updateExpressions = Seq( - /* currentSum = */ - Coalesce(Seq(Add(Coalesce(Seq(currentSum, zero)), Cast(child, sumDataType)), currentSum)) - ) - - override val mergeExpressions = { - val add = Add(Coalesce(Seq(currentSum.left, zero)), Cast(currentSum.right, sumDataType)) - Seq( - /* currentSum = */ - Coalesce(Seq(add, currentSum.left)) - ) - } - - override val evaluateExpression = Cast(currentSum, resultType) -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 4abfdfe87d5e..b6d2ddc5b136 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -17,24 +17,24 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ -/** The mode of an [[AggregateFunction2]]. */ +/** The mode of an [[AggregateFunction]]. */ private[sql] sealed trait AggregateMode /** - * An [[AggregateFunction2]] with [[Partial]] mode is used for partial aggregation. + * An [[AggregateFunction]] with [[Partial]] mode is used for partial aggregation. * This function updates the given aggregation buffer with the original input of this * function. When it has processed all input rows, the aggregation buffer is returned. */ private[sql] case object Partial extends AggregateMode /** - * An [[AggregateFunction2]] with [[PartialMerge]] mode is used to merge aggregation buffers + * An [[AggregateFunction]] with [[PartialMerge]] mode is used to merge aggregation buffers * containing intermediate results for this function. * This function updates the given aggregation buffer by merging multiple aggregation buffers. * When it has processed all input rows, the aggregation buffer is returned. @@ -42,7 +42,7 @@ private[sql] case object Partial extends AggregateMode private[sql] case object PartialMerge extends AggregateMode /** - * An [[AggregateFunction2]] with [[Final]] mode is used to merge aggregation buffers + * An [[AggregateFunction]] with [[Final]] mode is used to merge aggregation buffers * containing intermediate results for this function and then generate final result. * This function updates the given aggregation buffer by merging multiple aggregation buffers. 
* When it has processed all input rows, the final result of this function is returned. @@ -50,7 +50,7 @@ private[sql] case object PartialMerge extends AggregateMode private[sql] case object Final extends AggregateMode /** - * An [[AggregateFunction2]] with [[Complete]] mode is used to evaluate this function directly + * An [[AggregateFunction]] with [[Complete]] mode is used to evaluate this function directly * from original input rows without any partial aggregation. * This function updates the given aggregation buffer with the original input of this * function. When it has processed all input rows, the final result of this function is returned. @@ -68,16 +68,15 @@ private[sql] case object NoOp extends Expression with Unevaluable { } /** - * A container for an [[AggregateFunction2]] with its [[AggregateMode]] and a field + * A container for an [[AggregateFunction]] with its [[AggregateMode]] and a field * (`isDistinct`) indicating if DISTINCT keyword is specified for this function. - * @param aggregateFunction - * @param mode - * @param isDistinct */ -private[sql] case class AggregateExpression2( - aggregateFunction: AggregateFunction2, +private[sql] case class AggregateExpression( + aggregateFunction: AggregateFunction, mode: AggregateMode, - isDistinct: Boolean) extends AggregateExpression { + isDistinct: Boolean) + extends Expression + with Unevaluable { override def children: Seq[Expression] = aggregateFunction :: Nil override def dataType: DataType = aggregateFunction.dataType @@ -87,101 +86,241 @@ private[sql] case class AggregateExpression2( override def references: AttributeSet = { val childReferences = mode match { case Partial | Complete => aggregateFunction.references.toSeq - case PartialMerge | Final => aggregateFunction.bufferAttributes + case PartialMerge | Final => aggregateFunction.aggBufferAttributes } AttributeSet(childReferences) } - override def toString: String = s"(${aggregateFunction}2,mode=$mode,isDistinct=$isDistinct)" + override def prettyString: String = aggregateFunction.prettyString + + override def toString: String = s"(${aggregateFunction},mode=$mode,isDistinct=$isDistinct)" } -abstract class AggregateFunction2 - extends Expression with ImplicitCastInputTypes { +/** + * AggregateFunction2 is the superclass of two aggregation function interfaces: + * + * - [[ImperativeAggregate]] is for aggregation functions that are specified in terms of + * initialize(), update(), and merge() functions that operate on Row-based aggregation buffers. + * - [[DeclarativeAggregate]] is for aggregation functions that are specified using + * Catalyst expressions. + * + * In both interfaces, aggregates must define the schema ([[aggBufferSchema]]) and attributes + * ([[aggBufferAttributes]]) of an aggregation buffer which is used to hold partial aggregate + * results. At runtime, multiple aggregate functions are evaluated by the same operator using a + * combined aggregation buffer which concatenates the aggregation buffers of the individual + * aggregate functions. + * + * Code which accepts [[AggregateFunction]] instances should be prepared to handle both types of + * aggregate functions. + */ +sealed abstract class AggregateFunction extends Expression with ImplicitCastInputTypes { /** An aggregate function is not foldable. */ final override def foldable: Boolean = false + /** The schema of the aggregation buffer. */ + def aggBufferSchema: StructType + + /** Attributes of fields in aggBufferSchema. 
*/ + def aggBufferAttributes: Seq[AttributeReference] + /** - * The offset of this function's start buffer value in the - * underlying shared mutable aggregation buffer. - * For example, we have two aggregate functions `avg(x)` and `avg(y)`, which share - * the same aggregation buffer. In this shared buffer, the position of the first - * buffer value of `avg(x)` will be 0 and the position of the first buffer value of `avg(y)` - * will be 2. + * Attributes of fields in input aggregation buffers (immutable aggregation buffers that are + * merged with mutable aggregation buffers in the merge() function or merge expressions). + * These attributes are created automatically by cloning the [[aggBufferAttributes]]. */ - protected var mutableBufferOffset: Int = 0 + def inputAggBufferAttributes: Seq[AttributeReference] - def withNewMutableBufferOffset(newMutableBufferOffset: Int): Unit = { - mutableBufferOffset = newMutableBufferOffset - } + /** + * Indicates if this function supports partial aggregation. + * Currently Hive UDAF is the only one that doesn't support partial aggregation. + */ + def supportsPartial: Boolean = true /** - * The offset of this function's start buffer value in the - * underlying shared input aggregation buffer. An input aggregation buffer is used - * when we merge two aggregation buffers and it is basically the immutable one - * (we merge an input aggregation buffer and a mutable aggregation buffer and - * then store the new buffer values to the mutable aggregation buffer). - * Usually, an input aggregation buffer also contain extra elements like grouping - * keys at the beginning. So, mutableBufferOffset and inputBufferOffset are often - * different. - * For example, we have a grouping expression `key``, and two aggregate functions - * `avg(x)` and `avg(y)`. In this shared input aggregation buffer, the position of the first - * buffer value of `avg(x)` will be 1 and the position of the first buffer value of `avg(y)` - * will be 3 (position 0 is used for the value of key`). + * Result of the aggregate function when the input is empty. This is currently only used for the + * proper rewriting of distinct aggregate functions. */ - protected var inputBufferOffset: Int = 0 + def defaultResult: Option[Literal] = None - def withNewInputBufferOffset(newInputBufferOffset: Int): Unit = { - inputBufferOffset = newInputBufferOffset + /** + * Wraps this [[AggregateFunction]] in an [[AggregateExpression]] because + * [[AggregateExpression]] is the container of an [[AggregateFunction]], aggregation mode, + * and the flag indicating if this aggregation is distinct aggregation or not. + * An [[AggregateFunction]] should not be used without being wrapped in + * an [[AggregateExpression]]. + */ + def toAggregateExpression(): AggregateExpression = toAggregateExpression(isDistinct = false) + + /** + * Wraps this [[AggregateFunction]] in an [[AggregateExpression]] and set isDistinct + * field of the [[AggregateExpression]] to the given value because + * [[AggregateExpression]] is the container of an [[AggregateFunction]], aggregation mode, + * and the flag indicating if this aggregation is distinct aggregation or not. + * An [[AggregateFunction]] should not be used without being wrapped in + * an [[AggregateExpression]]. + */ + def toAggregateExpression(isDistinct: Boolean): AggregateExpression = { + AggregateExpression(aggregateFunction = this, mode = Complete, isDistinct = isDistinct) } +} - /** The schema of the aggregation buffer. 
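As the comments above state, an AggregateFunction is always planned inside an AggregateExpression, and the wrapper is what carries the mode and the DISTINCT flag. A usage sketch from inside the catalyst aggregate package (the attribute `x` is hypothetical, and the usual expressions/types imports from this file are assumed):

    // Assume `x` is some resolved numeric attribute from the child plan.
    val x: Expression = AttributeReference("x", DoubleType)()

    // SUM(x): Complete mode, non-distinct.
    val sumX: AggregateExpression = Sum(x).toAggregateExpression()

    // SUM(DISTINCT x): same function, distinct flag set on the wrapper,
    // equivalent to AggregateExpression(Sum(x), mode = Complete, isDistinct = true).
    val sumDistinctX: AggregateExpression = Sum(x).toAggregateExpression(isDistinct = true)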
*/ - def bufferSchema: StructType +/** + * API for aggregation functions that are expressed in terms of imperative initialize(), update(), + * and merge() functions which operate on Row-based aggregation buffers. + * + * Within these functions, code should access fields of the mutable aggregation buffer by adding the + * bufferSchema-relative field number to `mutableAggBufferOffset` then using this new field number + * to access the buffer Row. This is necessary because this aggregation function's buffer is + * embedded inside of a larger shared aggregation buffer when an aggregation operator evaluates + * multiple aggregate functions at the same time. + * + * We need to perform similar field number arithmetic when merging multiple intermediate + * aggregate buffers together in `merge()` (in this case, use `inputAggBufferOffset` when accessing + * the input buffer). + * + * Correct ImperativeAggregate evaluation depends on the correctness of `mutableAggBufferOffset` and + * `inputAggBufferOffset`, but not on the correctness of the attribute ids in `aggBufferAttributes` + * and `inputAggBufferAttributes`. + */ +abstract class ImperativeAggregate extends AggregateFunction with CodegenFallback { - /** Attributes of fields in bufferSchema. */ - def bufferAttributes: Seq[AttributeReference] + /** + * The offset of this function's first buffer value in the underlying shared mutable aggregation + * buffer. + * + * For example, we have two aggregate functions `avg(x)` and `avg(y)`, which share the same + * aggregation buffer. In this shared buffer, the position of the first buffer value of `avg(x)` + * will be 0 and the position of the first buffer value of `avg(y)` will be 2: + * + * avg(x) mutableAggBufferOffset = 0 + * | + * v + * +--------+--------+--------+--------+ + * | sum1 | count1 | sum2 | count2 | + * +--------+--------+--------+--------+ + * ^ + * | + * avg(y) mutableAggBufferOffset = 2 + * + */ + protected val mutableAggBufferOffset: Int + + /** + * Returns a copy of this ImperativeAggregate with an updated mutableAggBufferOffset. + * This new copy's attributes may have different ids than the original. + */ + def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate - /** Clones bufferAttributes. */ - def cloneBufferAttributes: Seq[Attribute] + /** + * The offset of this function's start buffer value in the underlying shared input aggregation + * buffer. An input aggregation buffer is used when we merge two aggregation buffers together in + * the `update()` function and is immutable (we merge an input aggregation buffer and a mutable + * aggregation buffer and then store the new buffer values to the mutable aggregation buffer). + * + * An input aggregation buffer may contain extra fields, such as grouping keys, at its start, so + * mutableAggBufferOffset and inputAggBufferOffset are often different. + * + * For example, say we have a grouping expression, `key`, and two aggregate functions, + * `avg(x)` and `avg(y)`. 
In the shared input aggregation buffer, the position of the first + * buffer value of `avg(x)` will be 1 and the position of the first buffer value of `avg(y)` + * will be 3 (position 0 is used for the value of `key`): + * + * avg(x) inputAggBufferOffset = 1 + * | + * v + * +--------+--------+--------+--------+--------+ + * | key | sum1 | count1 | sum2 | count2 | + * +--------+--------+--------+--------+--------+ + * ^ + * | + * avg(y) inputAggBufferOffset = 3 + * + */ + protected val inputAggBufferOffset: Int /** - * Initializes its aggregation buffer located in `buffer`. - * It will use bufferOffset to find the starting point of - * its buffer in the given `buffer` shared with other functions. + * Returns a copy of this ImperativeAggregate with an updated mutableAggBufferOffset. + * This new copy's attributes may have different ids than the original. */ - def initialize(buffer: MutableRow): Unit + def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate + + // Note: although all subclasses implement inputAggBufferAttributes by simply cloning + // aggBufferAttributes, that common clone code cannot be placed here in the abstract + // ImperativeAggregate class, since that will lead to initialization ordering issues. /** - * Updates its aggregation buffer located in `buffer` based on the given `input`. - * It will use bufferOffset to find the starting point of its buffer in the given `buffer` - * shared with other functions. + * Initializes the mutable aggregation buffer located in `mutableAggBuffer`. + * + * Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`. */ - def update(buffer: MutableRow, input: InternalRow): Unit + def initialize(mutableAggBuffer: MutableRow): Unit /** - * Updates its aggregation buffer located in `buffer1` by combining intermediate results - * in the current buffer and intermediate results from another buffer `buffer2`. - * It will use bufferOffset to find the starting point of its buffer in the given `buffer1` - * and `buffer2`. + * Updates its aggregation buffer, located in `mutableAggBuffer`, based on the given `inputRow`. + * + * Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`. */ - def merge(buffer1: MutableRow, buffer2: InternalRow): Unit + def update(mutableAggBuffer: MutableRow, inputRow: InternalRow): Unit - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = - throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") + /** + * Combines new intermediate results from the `inputAggBuffer` with the existing intermediate + * results in the `mutableAggBuffer.` + * + * Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`. + * Use `fieldNumber + inputAggBufferOffset` to access fields of `inputAggBuffer`. + */ + def merge(mutableAggBuffer: MutableRow, inputAggBuffer: InternalRow): Unit } /** - * A helper class for aggregate functions that can be implemented in terms of catalyst expressions. + * API for aggregation functions that are expressed in terms of Catalyst expressions. + * + * When implementing a new expression-based aggregate function, start by implementing + * `bufferAttributes`, defining attributes for the fields of the mutable aggregation buffer. You + * can then use these attributes when defining `updateExpressions`, `mergeExpressions`, and + * `evaluateExpressions`. 
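Returning to the ImperativeAggregate contract described above: every buffer access adds the function's own field number to mutableAggBufferOffset (or to inputAggBufferOffset when reading the input buffer in merge()). A hypothetical single-slot count written imperatively, assuming the usual InternalRow/MutableRow accessors (getLong/setLong) and that the operator evaluates the final result by calling eval on the buffer row, which this hunk does not show:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types._

    // Hypothetical imperative aggregate, mirroring the VarianceSamp/VariancePop boilerplate above.
    case class ImperativeCount(
        child: Expression,
        mutableAggBufferOffset: Int = 0,
        inputAggBufferOffset: Int = 0) extends ImperativeAggregate {

      override def children: Seq[Expression] = child :: Nil
      override def nullable: Boolean = false
      override def dataType: DataType = LongType
      override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)

      override val aggBufferAttributes = Seq(AttributeReference("count", LongType)())
      override val aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
      override val inputAggBufferAttributes = aggBufferAttributes.map(_.newInstance())

      override def withNewMutableAggBufferOffset(newOffset: Int): ImperativeAggregate =
        copy(mutableAggBufferOffset = newOffset)
      override def withNewInputAggBufferOffset(newOffset: Int): ImperativeAggregate =
        copy(inputAggBufferOffset = newOffset)

      // Field 0 of this function's buffer lives at mutableAggBufferOffset in the shared buffer.
      override def initialize(mutableAggBuffer: MutableRow): Unit =
        mutableAggBuffer.setLong(mutableAggBufferOffset, 0L)

      override def update(mutableAggBuffer: MutableRow, inputRow: InternalRow): Unit = {
        if (child.eval(inputRow) != null) {
          mutableAggBuffer.setLong(
            mutableAggBufferOffset, mutableAggBuffer.getLong(mutableAggBufferOffset) + 1L)
        }
      }

      // The same field sits at inputAggBufferOffset in the read-only input buffer.
      override def merge(mutableAggBuffer: MutableRow, inputAggBuffer: InternalRow): Unit = {
        mutableAggBuffer.setLong(
          mutableAggBufferOffset,
          mutableAggBuffer.getLong(mutableAggBufferOffset) +
            inputAggBuffer.getLong(inputAggBufferOffset))
      }

      // Assumption: final evaluation receives the buffer row (contract not shown in this hunk).
      override def eval(buffer: InternalRow): Any = buffer.getLong(mutableAggBufferOffset)
    }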
+ * + * Please note that children of an aggregate function can be unresolved (it will happen when + * we create this function in DataFrame API). So, if there is any fields in + * the implemented class that need to access fields of its children, please make + * those fields `lazy val`s. */ -abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable with Unevaluable { +abstract class DeclarativeAggregate + extends AggregateFunction + with Serializable + with Unevaluable { + /** + * Expressions for initializing empty aggregation buffers. + */ val initialValues: Seq[Expression] + + /** + * Expressions for updating the mutable aggregation buffer based on an input row. + */ val updateExpressions: Seq[Expression] + + /** + * A sequence of expressions for merging two aggregation buffers together. When defining these + * expressions, you can use the syntax `attributeName.left` and `attributeName.right` to refer + * to the attributes corresponding to each of the buffers being merged (this magic is enabled + * by the [[RichAttribute]] implicit class). + */ val mergeExpressions: Seq[Expression] + + /** + * An expression which returns the final value for this aggregate function. Its data type should + * match this expression's [[dataType]]. + */ val evaluateExpression: Expression - override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance()) + /** An expression-based aggregate's bufferSchema is derived from bufferAttributes. */ + final override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) + + final lazy val inputAggBufferAttributes: Seq[AttributeReference] = + aggBufferAttributes.map(_.newInstance()) /** * A helper class for representing an attribute used in merging two @@ -189,33 +328,13 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable w * we merge buffer values and then update bufferLeft. A [[RichAttribute]] * of an [[AttributeReference]] `a` has two functions `left` and `right`, * which represent `a` in `bufferLeft` and `bufferRight`, respectively. - * @param a */ implicit class RichAttribute(a: AttributeReference) { /** Represents this attribute at the mutable buffer side. */ def left: AttributeReference = a /** Represents this attribute at the input buffer side (the data value is read-only). */ - def right: AttributeReference = cloneBufferAttributes(bufferAttributes.indexOf(a)) - } - - /** An AlgebraicAggregate's bufferSchema is derived from bufferAttributes. 
*/ - override def bufferSchema: StructType = StructType.fromAttributes(bufferAttributes) - - override def initialize(buffer: MutableRow): Unit = { - throw new UnsupportedOperationException( - "AlgebraicAggregate's initialize should not be called directly") - } - - override final def update(buffer: MutableRow, input: InternalRow): Unit = { - throw new UnsupportedOperationException( - "AlgebraicAggregate's update should not be called directly") + def right: AttributeReference = inputAggBufferAttributes(aggBufferAttributes.indexOf(a)) } - - override final def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { - throw new UnsupportedOperationException( - "AlgebraicAggregate's merge should not be called directly") - } - } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala deleted file mode 100644 index 4a43318a9549..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala +++ /dev/null @@ -1,167 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.aggregate - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} -import org.apache.spark.sql.types.{StructType, MapType, ArrayType} - -/** - * Utility functions used by the query planner to convert our plan to new aggregation code path. - */ -object Utils { - // Right now, we do not support complex types in the grouping key schema. - private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = { - val hasComplexTypes = aggregate.groupingExpressions.map(_.dataType).exists { - case array: ArrayType => true - case map: MapType => true - case struct: StructType => true - case _ => false - } - - !hasComplexTypes - } - - private def doConvert(plan: LogicalPlan): Option[Aggregate] = plan match { - case p: Aggregate if supportsGroupingKeySchema(p) => - val converted = p.transformExpressionsDown { - case expressions.Average(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Average(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Count(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Count(child), - mode = aggregate.Complete, - isDistinct = false) - - // We do not support multiple COUNT DISTINCT columns for now. 
- case expressions.CountDistinct(children) if children.length == 1 => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Count(children.head), - mode = aggregate.Complete, - isDistinct = true) - - case expressions.First(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.First(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Last(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Last(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Max(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Max(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Min(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Min(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Sum(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Sum(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.SumDistinct(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Sum(child), - mode = aggregate.Complete, - isDistinct = true) - } - // Check if there is any expressions.AggregateExpression1 left. - // If so, we cannot convert this plan. - val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr => - // For every expressions, check if it contains AggregateExpression1. - expr.find { - case agg: expressions.AggregateExpression1 => true - case other => false - }.isDefined - } - - // Check if there are multiple distinct columns. - val aggregateExpressions = converted.aggregateExpressions.flatMap { expr => - expr.collect { - case agg: AggregateExpression2 => agg - } - }.toSet.toSeq - val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct) - val hasMultipleDistinctColumnSets = - if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { - true - } else { - false - } - - if (!hasAggregateExpression1 && !hasMultipleDistinctColumnSets) Some(converted) else None - - case other => None - } - - def checkInvalidAggregateFunction2(aggregate: Aggregate): Unit = { - // If the plan cannot be converted, we will do a final round check to see if the original - // logical.Aggregate contains both AggregateExpression1 and AggregateExpression2. If so, - // we need to throw an exception. - val aggregateFunction2s = aggregate.aggregateExpressions.flatMap { expr => - expr.collect { - case agg: AggregateExpression2 => agg.aggregateFunction - } - }.distinct - if (aggregateFunction2s.nonEmpty) { - // For functions implemented based on the new interface, prepare a list of function names. - val invalidFunctions = { - if (aggregateFunction2s.length > 1) { - s"${aggregateFunction2s.tail.map(_.nodeName).mkString(",")} " + - s"and ${aggregateFunction2s.head.nodeName} are" - } else { - s"${aggregateFunction2s.head.nodeName} is" - } - } - val errorMessage = - s"${invalidFunctions} implemented based on the new Aggregate Function " + - s"interface and it cannot be used with functions implemented based on " + - s"the old Aggregate Function interface." 
- throw new AnalysisException(errorMessage) - } - } - - def tryConvert(plan: LogicalPlan): Option[Aggregate] = plan match { - case p: Aggregate => - val converted = doConvert(p) - if (converted.isDefined) { - converted - } else { - checkInvalidAggregateFunction2(p) - None - } - case other => None - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala deleted file mode 100644 index 5d4b349b1597..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ /dev/null @@ -1,688 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import com.clearspring.analytics.stream.cardinality.HyperLogLog - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} -import org.apache.spark.sql.catalyst.util.TypeUtils -import org.apache.spark.sql.types._ -import org.apache.spark.util.collection.OpenHashSet - - -trait AggregateExpression extends Expression with Unevaluable - -trait AggregateExpression1 extends AggregateExpression { - - /** - * Aggregate expressions should not be foldable. - */ - override def foldable: Boolean = false - - /** - * Creates a new instance that can be used to compute this aggregate expression for a group - * of input rows/ - */ - def newInstance(): AggregateFunction1 -} - -/** - * Represents an aggregation that has been rewritten to be performed in two steps. - * - * @param finalEvaluation an aggregate expression that evaluates to same final result as the - * original aggregation. - * @param partialEvaluations A sequence of [[NamedExpression]]s that can be computed on partial - * data sets and are required to compute the `finalEvaluation`. - */ -case class SplitEvaluation( - finalEvaluation: Expression, - partialEvaluations: Seq[NamedExpression]) - -/** - * An [[AggregateExpression1]] that can be partially computed without seeing all relevant tuples. - * These partial evaluations can then be combined to compute the actual answer. - */ -trait PartialAggregate1 extends AggregateExpression1 { - - /** - * Returns a [[SplitEvaluation]] that computes this aggregation using partial aggregation. - */ - def asPartial: SplitEvaluation -} - -/** - * A specific implementation of an aggregate function. Used to wrap a generic - * [[AggregateExpression1]] with an algorithm that will be used to compute one specific result. 
- */ -abstract class AggregateFunction1 extends LeafExpression with Serializable { - - /** Base should return the generic aggregate expression that this function is computing */ - val base: AggregateExpression1 - - override def nullable: Boolean = base.nullable - override def dataType: DataType = base.dataType - - def update(input: InternalRow): Unit - - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - throw new UnsupportedOperationException( - "AggregateFunction1 should not be used for generated aggregates") - } -} - -case class Min(child: Expression) extends UnaryExpression with PartialAggregate1 { - - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - - override def asPartial: SplitEvaluation = { - val partialMin = Alias(Min(child), "PartialMin")() - SplitEvaluation(Min(partialMin.toAttribute), partialMin :: Nil) - } - - override def newInstance(): MinFunction = new MinFunction(child, this) - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForOrderingExpr(child.dataType, "function min") -} - -case class MinFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { - def this() = this(null, null) // Required for serialization. - - val currentMin: MutableLiteral = MutableLiteral(null, expr.dataType) - val cmp = GreaterThan(currentMin, expr) - - override def update(input: InternalRow): Unit = { - if (currentMin.value == null) { - currentMin.value = expr.eval(input) - } else if (cmp.eval(input) == true) { - currentMin.value = expr.eval(input) - } - } - - override def eval(input: InternalRow): Any = currentMin.value -} - -case class Max(child: Expression) extends UnaryExpression with PartialAggregate1 { - - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - - override def asPartial: SplitEvaluation = { - val partialMax = Alias(Max(child), "PartialMax")() - SplitEvaluation(Max(partialMax.toAttribute), partialMax :: Nil) - } - - override def newInstance(): MaxFunction = new MaxFunction(child, this) - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForOrderingExpr(child.dataType, "function max") -} - -case class MaxFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { - def this() = this(null, null) // Required for serialization. - - val currentMax: MutableLiteral = MutableLiteral(null, expr.dataType) - val cmp = LessThan(currentMax, expr) - - override def update(input: InternalRow): Unit = { - if (currentMax.value == null) { - currentMax.value = expr.eval(input) - } else if (cmp.eval(input) == true) { - currentMax.value = expr.eval(input) - } - } - - override def eval(input: InternalRow): Any = currentMax.value -} - -case class Count(child: Expression) extends UnaryExpression with PartialAggregate1 { - - override def nullable: Boolean = false - override def dataType: LongType.type = LongType - - override def asPartial: SplitEvaluation = { - val partialCount = Alias(Count(child), "PartialCount")() - SplitEvaluation(Coalesce(Seq(Sum(partialCount.toAttribute), Literal(0L))), partialCount :: Nil) - } - - override def newInstance(): CountFunction = new CountFunction(child, this) -} - -case class CountFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { - def this() = this(null, null) // Required for serialization. 
- - var count: Long = _ - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - count += 1L - } - } - - override def eval(input: InternalRow): Any = count -} - -case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate1 { - def this() = this(null) - - override def children: Seq[Expression] = expressions - - override def nullable: Boolean = false - override def dataType: DataType = LongType - override def toString: String = s"COUNT(DISTINCT ${expressions.mkString(",")})" - override def newInstance(): CountDistinctFunction = new CountDistinctFunction(expressions, this) - - override def asPartial: SplitEvaluation = { - val partialSet = Alias(CollectHashSet(expressions), "partialSets")() - SplitEvaluation( - CombineSetsAndCount(partialSet.toAttribute), - partialSet :: Nil) - } -} - -case class CountDistinctFunction( - @transient expr: Seq[Expression], - @transient base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null) // Required for serialization. - - val seen = new OpenHashSet[Any]() - - @transient - val distinctValue = new InterpretedProjection(expr) - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = distinctValue(input) - if (!evaluatedExpr.anyNull) { - seen.add(evaluatedExpr) - } - } - - override def eval(input: InternalRow): Any = seen.size.toLong -} - -case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression1 { - def this() = this(null) - - override def children: Seq[Expression] = expressions - override def nullable: Boolean = false - override def dataType: OpenHashSetUDT = new OpenHashSetUDT(expressions.head.dataType) - override def toString: String = s"AddToHashSet(${expressions.mkString(",")})" - override def newInstance(): CollectHashSetFunction = - new CollectHashSetFunction(expressions, this) -} - -case class CollectHashSetFunction( - @transient expr: Seq[Expression], - @transient base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null) // Required for serialization. - - val seen = new OpenHashSet[Any]() - - @transient - val distinctValue = new InterpretedProjection(expr) - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = distinctValue(input) - if (!evaluatedExpr.anyNull) { - seen.add(evaluatedExpr) - } - } - - override def eval(input: InternalRow): Any = { - seen - } -} - -case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression1 { - def this() = this(null) - - override def children: Seq[Expression] = inputSet :: Nil - override def nullable: Boolean = false - override def dataType: DataType = LongType - override def toString: String = s"CombineAndCount($inputSet)" - override def newInstance(): CombineSetsAndCountFunction = { - new CombineSetsAndCountFunction(inputSet, this) - } -} - -case class CombineSetsAndCountFunction( - @transient inputSet: Expression, - @transient base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null) // Required for serialization. 
- - val seen = new OpenHashSet[Any]() - - override def update(input: InternalRow): Unit = { - val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]] - val inputIterator = inputSetEval.iterator - while (inputIterator.hasNext) { - seen.add(inputIterator.next) - } - } - - override def eval(input: InternalRow): Any = seen.size.toLong -} - -/** The data type of ApproxCountDistinctPartition since its output is a HyperLogLog object. */ -private[sql] case object HyperLogLogUDT extends UserDefinedType[HyperLogLog] { - - override def sqlType: DataType = BinaryType - - /** Since we are using HyperLogLog internally, usually it will not be called. */ - override def serialize(obj: Any): Array[Byte] = - obj.asInstanceOf[HyperLogLog].getBytes - - - /** Since we are using HyperLogLog internally, usually it will not be called. */ - override def deserialize(datum: Any): HyperLogLog = - HyperLogLog.Builder.build(datum.asInstanceOf[Array[Byte]]) - - override def userClass: Class[HyperLogLog] = classOf[HyperLogLog] -} - -case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double) - extends UnaryExpression with AggregateExpression1 { - - override def nullable: Boolean = false - override def dataType: DataType = HyperLogLogUDT - override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)" - override def newInstance(): ApproxCountDistinctPartitionFunction = { - new ApproxCountDistinctPartitionFunction(child, this, relativeSD) - } -} - -case class ApproxCountDistinctPartitionFunction( - expr: Expression, - base: AggregateExpression1, - relativeSD: Double) - extends AggregateFunction1 { - def this() = this(null, null, 0) // Required for serialization. - - private val hyperLogLog = new HyperLogLog(relativeSD) - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - hyperLogLog.offer(evaluatedExpr) - } - } - - override def eval(input: InternalRow): Any = hyperLogLog -} - -case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double) - extends UnaryExpression with AggregateExpression1 { - - override def nullable: Boolean = false - override def dataType: LongType.type = LongType - override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)" - override def newInstance(): ApproxCountDistinctMergeFunction = { - new ApproxCountDistinctMergeFunction(child, this, relativeSD) - } -} - -case class ApproxCountDistinctMergeFunction( - expr: Expression, - base: AggregateExpression1, - relativeSD: Double) - extends AggregateFunction1 { - def this() = this(null, null, 0) // Required for serialization. 
- - private val hyperLogLog = new HyperLogLog(relativeSD) - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - hyperLogLog.addAll(evaluatedExpr.asInstanceOf[HyperLogLog]) - } - - override def eval(input: InternalRow): Any = hyperLogLog.cardinality() -} - -case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) - extends UnaryExpression with PartialAggregate1 { - - override def nullable: Boolean = false - override def dataType: LongType.type = LongType - override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)" - - override def asPartial: SplitEvaluation = { - val partialCount = - Alias(ApproxCountDistinctPartition(child, relativeSD), "PartialApproxCountDistinct")() - - SplitEvaluation( - ApproxCountDistinctMerge(partialCount.toAttribute, relativeSD), - partialCount :: Nil) - } - - override def newInstance(): CountDistinctFunction = new CountDistinctFunction(child :: Nil, this) -} - -case class Average(child: Expression) extends UnaryExpression with PartialAggregate1 { - - override def prettyName: String = "avg" - - override def nullable: Boolean = true - - override def dataType: DataType = child.dataType match { - case DecimalType.Fixed(precision, scale) => - // Add 4 digits after decimal point, like Hive - DecimalType.bounded(precision + 4, scale + 4) - case _ => - DoubleType - } - - override def asPartial: SplitEvaluation = { - child.dataType match { - case DecimalType.Fixed(precision, scale) => - val partialSum = Alias(Sum(child), "PartialSum")() - val partialCount = Alias(Count(child), "PartialCount")() - - // partialSum already increase the precision by 10 - val castedSum = Cast(Sum(partialSum.toAttribute), partialSum.dataType) - val castedCount = Cast(Sum(partialCount.toAttribute), partialSum.dataType) - SplitEvaluation( - Cast(Divide(castedSum, castedCount), dataType), - partialCount :: partialSum :: Nil) - - case _ => - val partialSum = Alias(Sum(child), "PartialSum")() - val partialCount = Alias(Count(child), "PartialCount")() - - val castedSum = Cast(Sum(partialSum.toAttribute), dataType) - val castedCount = Cast(Sum(partialCount.toAttribute), dataType) - SplitEvaluation( - Divide(castedSum, castedCount), - partialCount :: partialSum :: Nil) - } - } - - override def newInstance(): AverageFunction = new AverageFunction(child, this) - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "function average") -} - -case class AverageFunction(expr: Expression, base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null) // Required for serialization. 
- - private val calcType = - expr.dataType match { - case DecimalType.Fixed(precision, scale) => - DecimalType.bounded(precision + 10, scale) - case _ => - expr.dataType - } - - private val zero = Cast(Literal(0), calcType) - - private var count: Long = _ - private val sum = MutableLiteral(zero.eval(null), calcType) - - private def addFunction(value: Any) = Add(sum, - Cast(Literal.create(value, expr.dataType), calcType)) - - override def eval(input: InternalRow): Any = { - if (count == 0L) { - null - } else { - expr.dataType match { - case DecimalType.Fixed(precision, scale) => - val dt = DecimalType.bounded(precision + 14, scale + 4) - Cast(Divide(Cast(sum, dt), Cast(Literal(count), dt)), dataType).eval(null) - case _ => - Divide( - Cast(sum, dataType), - Cast(Literal(count), dataType)).eval(null) - } - } - } - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - count += 1 - sum.update(addFunction(evaluatedExpr), input) - } - } -} - -case class Sum(child: Expression) extends UnaryExpression with PartialAggregate1 { - - override def nullable: Boolean = true - - override def dataType: DataType = child.dataType match { - case DecimalType.Fixed(precision, scale) => - // Add 10 digits left of decimal point, like Hive - DecimalType.bounded(precision + 10, scale) - case _ => - child.dataType - } - - override def asPartial: SplitEvaluation = { - child.dataType match { - case DecimalType.Fixed(_, _) => - val partialSum = Alias(Sum(child), "PartialSum")() - SplitEvaluation( - Cast(Sum(partialSum.toAttribute), dataType), - partialSum :: Nil) - - case _ => - val partialSum = Alias(Sum(child), "PartialSum")() - SplitEvaluation( - Sum(partialSum.toAttribute), - partialSum :: Nil) - } - } - - override def newInstance(): SumFunction = new SumFunction(child, this) - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "function sum") -} - -case class SumFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { - def this() = this(null, null) // Required for serialization. 
- - private val calcType = - expr.dataType match { - case DecimalType.Fixed(precision, scale) => - DecimalType.bounded(precision + 10, scale) - case _ => - expr.dataType - } - - private val zero = Cast(Literal(0), calcType) - - private val sum = MutableLiteral(null, calcType) - - private val addFunction = Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum)) - - override def update(input: InternalRow): Unit = { - sum.update(addFunction, input) - } - - override def eval(input: InternalRow): Any = { - expr.dataType match { - case DecimalType.Fixed(_, _) => - Cast(sum, dataType).eval(null) - case _ => sum.eval(null) - } - } -} - -case class SumDistinct(child: Expression) extends UnaryExpression with PartialAggregate1 { - - def this() = this(null) - override def nullable: Boolean = true - override def dataType: DataType = child.dataType match { - case DecimalType.Fixed(precision, scale) => - // Add 10 digits left of decimal point, like Hive - DecimalType.bounded(precision + 10, scale) - case _ => - child.dataType - } - override def toString: String = s"SUM(DISTINCT $child)" - override def newInstance(): SumDistinctFunction = new SumDistinctFunction(child, this) - - override def asPartial: SplitEvaluation = { - val partialSet = Alias(CollectHashSet(child :: Nil), "partialSets")() - SplitEvaluation( - CombineSetsAndSum(partialSet.toAttribute, this), - partialSet :: Nil) - } - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "function sumDistinct") -} - -case class SumDistinctFunction(expr: Expression, base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null) // Required for serialization. - - private val seen = new scala.collection.mutable.HashSet[Any]() - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - seen += evaluatedExpr - } - } - - override def eval(input: InternalRow): Any = { - if (seen.size == 0) { - null - } else { - Cast(Literal( - seen.reduceLeft( - dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)), - dataType).eval(null) - } - } -} - -case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression1 { - def this() = this(null, null) - - override def children: Seq[Expression] = inputSet :: Nil - override def nullable: Boolean = true - override def dataType: DataType = base.dataType - override def toString: String = s"CombineAndSum($inputSet)" - override def newInstance(): CombineSetsAndSumFunction = { - new CombineSetsAndSumFunction(inputSet, this) - } -} - -case class CombineSetsAndSumFunction( - @transient inputSet: Expression, - @transient base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null) // Required for serialization. 
- - val seen = new OpenHashSet[Any]() - - override def update(input: InternalRow): Unit = { - val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]] - val inputIterator = inputSetEval.iterator - while (inputIterator.hasNext) { - seen.add(inputIterator.next()) - } - } - - override def eval(input: InternalRow): Any = { - val casted = seen.asInstanceOf[OpenHashSet[InternalRow]] - if (casted.size == 0) { - null - } else { - Cast(Literal( - casted.iterator.map(f => f.genericGet(0)).reduceLeft( - base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)), - base.dataType).eval(null) - } - } -} - -case class First(child: Expression) extends UnaryExpression with PartialAggregate1 { - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - override def toString: String = s"FIRST($child)" - - override def asPartial: SplitEvaluation = { - val partialFirst = Alias(First(child), "PartialFirst")() - SplitEvaluation( - First(partialFirst.toAttribute), - partialFirst :: Nil) - } - override def newInstance(): FirstFunction = new FirstFunction(child, this) -} - -case class FirstFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { - def this() = this(null, null) // Required for serialization. - - var result: Any = null - - override def update(input: InternalRow): Unit = { - if (result == null) { - result = expr.eval(input) - } - } - - override def eval(input: InternalRow): Any = result -} - -case class Last(child: Expression) extends UnaryExpression with PartialAggregate1 { - override def references: AttributeSet = child.references - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - override def toString: String = s"LAST($child)" - - override def asPartial: SplitEvaluation = { - val partialLast = Alias(Last(child), "PartialLast")() - SplitEvaluation( - Last(partialLast.toAttribute), - partialLast :: Nil) - } - override def newInstance(): LastFunction = new LastFunction(child, this) -} - -case class LastFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { - def this() = this(null, null) // Required for serialization. 
- - var result: Any = null - - override def update(input: InternalRow): Unit = { - result = input - } - - override def eval(input: InternalRow): Any = { - if (result != null) expr.eval(result.asInstanceOf[InternalRow]) else null - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 0891b5549471..61a17fd7db0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -42,7 +42,7 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp // for example, we could not write --9223372036854775808L in code s""" ${ctx.javaType(dt)} $originValue = (${ctx.javaType(dt)})($eval); - ${ev.primitive} = (${ctx.javaType(dt)})(-($originValue)); + ${ev.value} = (${ctx.javaType(dt)})(-($originValue)); """}) case dt: CalendarIntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()") } @@ -223,20 +223,20 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val isZero = if (dataType.isInstanceOf[DecimalType]) { - s"${eval2.primitive}.isZero()" + s"${eval2.value}.isZero()" } else { - s"${eval2.primitive} == 0" + s"${eval2.value} == 0" } val javaType = ctx.javaType(dataType) val divide = if (dataType.isInstanceOf[DecimalType]) { - s"${eval1.primitive}.$decimalMethod(${eval2.primitive})" + s"${eval1.value}.$decimalMethod(${eval2.value})" } else { - s"($javaType)(${eval1.primitive} $symbol ${eval2.primitive})" + s"($javaType)(${eval1.value} $symbol ${eval2.value})" } s""" ${eval2.code} boolean ${ev.isNull} = false; - $javaType ${ev.primitive} = ${ctx.defaultValue(javaType)}; + $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; if (${eval2.isNull} || $isZero) { ${ev.isNull} = true; } else { @@ -244,7 +244,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic if (${eval1.isNull}) { ${ev.isNull} = true; } else { - ${ev.primitive} = $divide; + ${ev.value} = $divide; } } """ @@ -285,20 +285,20 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val isZero = if (dataType.isInstanceOf[DecimalType]) { - s"${eval2.primitive}.isZero()" + s"${eval2.value}.isZero()" } else { - s"${eval2.primitive} == 0" + s"${eval2.value} == 0" } val javaType = ctx.javaType(dataType) val remainder = if (dataType.isInstanceOf[DecimalType]) { - s"${eval1.primitive}.$decimalMethod(${eval2.primitive})" + s"${eval1.value}.$decimalMethod(${eval2.value})" } else { - s"($javaType)(${eval1.primitive} $symbol ${eval2.primitive})" + s"($javaType)(${eval1.value} $symbol ${eval2.value})" } s""" ${eval2.code} boolean ${ev.isNull} = false; - $javaType ${ev.primitive} = ${ctx.defaultValue(javaType)}; + $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; if (${eval2.isNull} || $isZero) { ${ev.isNull} = true; } else { @@ -306,7 +306,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet if (${eval1.isNull}) { ${ev.isNull} = true; } else { - ${ev.primitive} = $remainder; + ${ev.value} = $remainder; } } """ @@ -320,7 +320,7 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { override def nullable: Boolean = left.nullable && right.nullable - private lazy val ordering = 
TypeUtils.getOrdering(dataType) + private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) override def eval(input: InternalRow): Any = { val input1 = left.eval(input) @@ -341,24 +341,24 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) - val compCode = ctx.genComp(dataType, eval1.primitive, eval2.primitive) + val compCode = ctx.genComp(dataType, eval1.value, eval2.value) eval1.code + eval2.code + s""" boolean ${ev.isNull} = false; - ${ctx.javaType(left.dataType)} ${ev.primitive} = + ${ctx.javaType(left.dataType)} ${ev.value} = ${ctx.defaultValue(left.dataType)}; if (${eval1.isNull}) { ${ev.isNull} = ${eval2.isNull}; - ${ev.primitive} = ${eval2.primitive}; + ${ev.value} = ${eval2.value}; } else if (${eval2.isNull}) { ${ev.isNull} = ${eval1.isNull}; - ${ev.primitive} = ${eval1.primitive}; + ${ev.value} = ${eval1.value}; } else { if ($compCode > 0) { - ${ev.primitive} = ${eval1.primitive}; + ${ev.value} = ${eval1.value}; } else { - ${ev.primitive} = ${eval2.primitive}; + ${ev.value} = ${eval2.value}; } } """ @@ -374,7 +374,7 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { override def nullable: Boolean = left.nullable && right.nullable - private lazy val ordering = TypeUtils.getOrdering(dataType) + private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) override def eval(input: InternalRow): Any = { val input1 = left.eval(input) @@ -395,24 +395,24 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) - val compCode = ctx.genComp(dataType, eval1.primitive, eval2.primitive) + val compCode = ctx.genComp(dataType, eval1.value, eval2.value) eval1.code + eval2.code + s""" boolean ${ev.isNull} = false; - ${ctx.javaType(left.dataType)} ${ev.primitive} = + ${ctx.javaType(left.dataType)} ${ev.value} = ${ctx.defaultValue(left.dataType)}; if (${eval1.isNull}) { ${ev.isNull} = ${eval2.isNull}; - ${ev.primitive} = ${eval2.primitive}; + ${ev.value} = ${eval2.value}; } else if (${eval2.isNull}) { ${ev.isNull} = ${eval1.isNull}; - ${ev.primitive} = ${eval1.primitive}; + ${ev.value} = ${eval1.value}; } else { if ($compCode < 0) { - ${ev.primitive} = ${eval1.primitive}; + ${ev.value} = ${eval1.value}; } else { - ${ev.primitive} = ${eval2.primitive}; + ${ev.value} = ${eval2.value}; } } """ @@ -451,9 +451,9 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { s""" ${ctx.javaType(dataType)} r = $eval1.remainder($eval2); if (r.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) { - ${ev.primitive} = (r.$decimalAdd($eval2)).remainder($eval2); + ${ev.value} = (r.$decimalAdd($eval2)).remainder($eval2); } else { - ${ev.primitive} = r; + ${ev.value} = r; } """ // byte and short are casted into int when add, minus, times or divide @@ -461,18 +461,18 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { s""" ${ctx.javaType(dataType)} r = (${ctx.javaType(dataType)})($eval1 % $eval2); if (r < 0) { - ${ev.primitive} = (${ctx.javaType(dataType)})((r + $eval2) % $eval2); + ${ev.value} = (${ctx.javaType(dataType)})((r + $eval2) % $eval2); } else { - ${ev.primitive} = r; + ${ev.value} = r; } """ case _ => s""" ${ctx.javaType(dataType)} r = $eval1 % $eval2; if (r < 0) 
{ - ${ev.primitive} = (r + $eval2) % $eval2; + ${ev.value} = (r + $eval2) % $eval2; } else { - ${ev.primitive} = r; + ${ev.value} = r; } """ } @@ -511,6 +511,6 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { private def pmod(a: Decimal, n: Decimal): Decimal = { val r = a % n - if (r.compare(Decimal(0)) < 0) {(r + n) % n} else r + if (r.compare(Decimal.ZERO) < 0) {(r + n) % n} else r } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala index c98182c96b16..9b8b6382d753 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala @@ -32,6 +32,7 @@ private class CodeFormatter { private var indentLevel = 0 private val indentSize = 2 private var indentString = "" + private var currentLine = 1 private def addLine(line: String): Unit = { val indentChange = @@ -44,11 +45,13 @@ private class CodeFormatter { } else { indentString } + code.append(f"/* ${currentLine}%03d */ ") code.append(thisLineIndent) code.append(line) code.append("\n") indentLevel = newIndentLevel indentString = " " * (indentSize * newIndentLevel) + currentLine += 1 } private def addLines(code: String): CodeFormatter = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 7b41c9a3f3b8..440c7d2fc115 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import scala.language.existentials import com.google.common.cache.{CacheBuilder, CacheLoader} @@ -26,14 +27,11 @@ import org.codehaus.janino.ClassBodyEvaluator import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.{MapData, ArrayData} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.types._ - - -// These classes are here to avoid issues with serialization and integration with quasiquotes. -class IntegerHashSet extends org.apache.spark.util.collection.OpenHashSet[Int] -class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long] +import org.apache.spark.util.Utils /** * Java source for evaluating an [[Expression]] given a [[InternalRow]] of input. @@ -41,10 +39,10 @@ class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long] * @param code The sequence of statements required to evaluate the expression. 
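// Editorial sketch (not part of the patch): Pmod differs from the plain % operator in that
// the result is non-negative for a positive divisor, which is what the generated branches
// and the Decimal helper above (now using Decimal.ZERO) implement.
object PmodSketch {
  def pmod(a: Int, n: Int): Int = {
    val r = a % n
    if (r < 0) (r + n) % n else r
  }

  def main(args: Array[String]): Unit = {
    assert(-7 % 3 == -1)      // plain remainder keeps the dividend's sign
    assert(pmod(-7, 3) == 2)  // pmod folds it into [0, n) for a positive divisor
    assert(pmod(7, 3) == 1)
  }
}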
* @param isNull A term that holds a boolean value representing whether the expression evaluated * to null. - * @param primitive A term for a possible primitive value of the result of the evaluation. Not - * valid if `isNull` is set to `true`. + * @param value A term for a (possibly primitive) value of the result of the evaluation. Not + * valid if `isNull` is set to `true`. */ -case class GeneratedExpressionCode(var code: String, var isNull: String, var primitive: String) +case class GeneratedExpressionCode(var code: String, var isNull: String, var value: String) /** * A context for codegen, which is used to bookkeeping the expressions those are not supported @@ -83,13 +81,37 @@ class CodeGenContext { /** * Holding all the functions those will be added into generated class. */ - val addedFuntions: mutable.Map[String, String] = + val addedFunctions: mutable.Map[String, String] = mutable.Map.empty[String, String] def addNewFunction(funcName: String, funcCode: String): Unit = { - addedFuntions += ((funcName, funcCode)) + addedFunctions += ((funcName, funcCode)) } + /** + * Holds expressions that are equivalent. Used to perform subexpression elimination + * during codegen. + * + * For expressions that appear more than once, generate additional code to prevent + * recomputing the value. + * + * For example, consider two expressions generated from this SQL statement: + * SELECT (col1 + col2), (col1 + col2) / col3. + * + * equivalentExpressions will match the tree containing `col1 + col2` and it will only + * be evaluated once. + */ + val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions + + // State used for subexpression elimination. + case class SubExprEliminationState(isNull: String, value: String) + + // For each expression that is participating in subexpression elimination, the state to use. + val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState] + + // The collection of sub-expression result resetting methods that need to be called on each row. + val subExprResetVariables = mutable.ArrayBuffer.empty[String] + final val JAVA_BOOLEAN = "boolean" final val JAVA_BYTE = "byte" final val JAVA_SHORT = "short" @@ -98,6 +120,9 @@ class CodeGenContext { final val JAVA_FLOAT = "float" final val JAVA_DOUBLE = "double" + /** The variable name of the input row in generated code. */ + final val INPUT_ROW = "i" + private val curId = new java.util.concurrent.atomic.AtomicInteger() /** @@ -111,21 +136,22 @@ class CodeGenContext { } /** - * Returns the code to access a value in `SpecializedGetters` for a given DataType. + * Returns the specialized code to access a value from `inputRow` at `ordinal`.
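// Editorial sketch (simplified model, not the Spark case class itself): each expression's
// codegen produces a (code, isNull, value) triple -- the rename from `primitive` to `value`
// above reflects that the term is not always a JVM primitive. Composition concatenates the
// `code` fragments and combines the `isNull`/`value` terms; the names below are illustrative.
final case class ExprCode(code: String, isNull: String, value: String)

object ExprCodeSketch {
  // A hypothetical fragment for adding two already-generated int-valued terms.
  def addInts(a: ExprCode, b: ExprCode, result: String): ExprCode = {
    val code =
      s"""${a.code}
         |${b.code}
         |boolean ${result}_isNull = ${a.isNull} || ${b.isNull};
         |int $result = ${result}_isNull ? -1 : (${a.value} + ${b.value});""".stripMargin
    ExprCode(code, s"${result}_isNull", result)
  }
}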
*/ - def getValue(getter: String, dataType: DataType, ordinal: String): String = { + def getValue(input: String, dataType: DataType, ordinal: String): String = { val jt = javaType(dataType) dataType match { - case _ if isPrimitiveType(jt) => s"$getter.get${primitiveTypeName(jt)}($ordinal)" - case t: DecimalType => s"$getter.getDecimal($ordinal, ${t.precision}, ${t.scale})" - case StringType => s"$getter.getUTF8String($ordinal)" - case BinaryType => s"$getter.getBinary($ordinal)" - case CalendarIntervalType => s"$getter.getInterval($ordinal)" - case t: StructType => s"$getter.getStruct($ordinal, ${t.size})" - case _: ArrayType => s"$getter.getArray($ordinal)" - case _: MapType => s"$getter.getMap($ordinal)" + case _ if isPrimitiveType(jt) => s"$input.get${primitiveTypeName(jt)}($ordinal)" + case t: DecimalType => s"$input.getDecimal($ordinal, ${t.precision}, ${t.scale})" + case StringType => s"$input.getUTF8String($ordinal)" + case BinaryType => s"$input.getBinary($ordinal)" + case CalendarIntervalType => s"$input.getInterval($ordinal)" + case t: StructType => s"$input.getStruct($ordinal, ${t.size})" + case _: ArrayType => s"$input.getArray($ordinal)" + case _: MapType => s"$input.getMap($ordinal)" case NullType => "null" - case _ => s"($jt)$getter.get($ordinal, null)" + case udt: UserDefinedType[_] => getValue(input, udt.sqlType, ordinal) + case _ => s"($jt)$input.get($ordinal, null)" } } @@ -139,6 +165,7 @@ class CodeGenContext { case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})" // The UTF8String may came from UnsafeRow, otherwise clone is cheap (re-use the bytes) case StringType => s"$row.update($ordinal, $value.clone())" + case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value) case _ => s"$row.update($ordinal, $value)" } } @@ -171,8 +198,9 @@ class CodeGenContext { case _: StructType => "InternalRow" case _: ArrayType => "ArrayData" case _: MapType => "MapData" - case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName - case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName + case udt: UserDefinedType[_] => javaType(udt.sqlType) + case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]" + case ObjectType(cls) => cls.getName case _ => "Object" } @@ -216,6 +244,7 @@ class CodeGenContext { case FloatType => s"(java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2" case DoubleType => s"(java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2" case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2" + case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2) case other => s"$c1.equals($c2)" } @@ -235,6 +264,49 @@ class CodeGenContext { case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)" case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)" case NullType => "0" + case array: ArrayType => + val elementType = array.elementType + val elementA = freshName("elementA") + val isNullA = freshName("isNullA") + val elementB = freshName("elementB") + val isNullB = freshName("isNullB") + val compareFunc = freshName("compareArray") + val minLength = freshName("minLength") + val funcCode: String = + s""" + public int $compareFunc(ArrayData a, ArrayData b) { + int lengthA = a.numElements(); + int lengthB = b.numElements(); + int $minLength = (lengthA > lengthB) ? 
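// Editorial sketch (not part of the patch): a trimmed-down model of the type-specialised
// dispatch that getValue performs above. Real Spark matches on DataType objects and emits
// Java source; plain strings stand in for both here.
object GetValueSketch {
  def getValue(input: String, dataType: String, ordinal: String): String = dataType match {
    case "int"    => s"$input.getInt($ordinal)"
    case "long"   => s"$input.getLong($ordinal)"
    case "string" => s"$input.getUTF8String($ordinal)"
    case "binary" => s"$input.getBinary($ordinal)"
    case _        => s"(Object) $input.get($ordinal, null)" // generic fallback
  }

  def main(args: Array[String]): Unit = {
    // e.g. reading column 0 of the generated input-row variable `i` as an int
    assert(getValue("i", "int", "0") == "i.getInt(0)")
  }
}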
lengthB : lengthA; + for (int i = 0; i < $minLength; i++) { + boolean $isNullA = a.isNullAt(i); + boolean $isNullB = b.isNullAt(i); + if ($isNullA && $isNullB) { + // Nothing + } else if ($isNullA) { + return -1; + } else if ($isNullB) { + return 1; + } else { + ${javaType(elementType)} $elementA = ${getValue("a", elementType, "i")}; + ${javaType(elementType)} $elementB = ${getValue("b", elementType, "i")}; + int comp = ${genComp(elementType, elementA, elementB)}; + if (comp != 0) { + return comp; + } + } + } + + if (lengthA < lengthB) { + return -1; + } else if (lengthA > lengthB) { + return 1; + } + return 0; + } + """ + addNewFunction(compareFunc, funcCode) + s"this.$compareFunc($c1, $c2)" case schema: StructType => val comparisons = GenerateOrdering.genComparisons(this, schema) val compareFunc = freshName("compareStruct") @@ -249,10 +321,23 @@ class CodeGenContext { addNewFunction(compareFunc, funcCode) s"this.$compareFunc($c1, $c2)" case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)" + case udt: UserDefinedType[_] => genComp(udt.sqlType, c1, c2) case _ => throw new IllegalArgumentException("cannot generate compare code for un-comparable type") } + /** + * Generates code for greater of two expressions. + * + * @param dataType data type of the expressions + * @param c1 name of the variable of expression 1's output + * @param c2 name of the variable of expression 2's output + */ + def genGreater(dataType: DataType, c1: String, c2: String): String = javaType(dataType) match { + case JAVA_BYTE | JAVA_SHORT | JAVA_INT | JAVA_LONG => s"$c1 > $c2" + case _ => s"(${genComp(dataType, c1, c2)}) > 0" + } + /** * List of java data types that have special accessors and setters in [[InternalRow]]. */ @@ -265,6 +350,114 @@ class CodeGenContext { def isPrimitiveType(jt: String): Boolean = primitiveTypes.contains(jt) def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt)) + + /** + * Splits the generated code of expressions into multiple functions, because function has + * 64kb code size limit in JVM + * + * @param row the variable name of row that is used by expressions + * @param expressions the codes to evaluate expressions. + */ + def splitExpressions(row: String, expressions: Seq[String]): String = { + val blocks = new ArrayBuffer[String]() + val blockBuilder = new StringBuilder() + for (code <- expressions) { + // We can't know how many byte code will be generated, so use the number of bytes as limit + if (blockBuilder.length > 64 * 1000) { + blocks.append(blockBuilder.toString()) + blockBuilder.clear() + } + blockBuilder.append(code) + } + blocks.append(blockBuilder.toString()) + + if (blocks.length == 1) { + // inline execution if only one block + blocks.head + } else { + val apply = freshName("apply") + val functions = blocks.zipWithIndex.map { case (body, i) => + val name = s"${apply}_$i" + val code = s""" + |private void $name(InternalRow $row) { + | $body + |} + """.stripMargin + addNewFunction(name, code) + name + } + + functions.map(name => s"$name($row);").mkString("\n") + } + } + + /** + * Checks and sets up the state and codegen for subexpression elimination. This finds the + * common subexpresses, generates the functions that evaluate those expressions and populates + * the mapping of common subexpressions to the generated functions. + */ + private def subexpressionElimination(expressions: Seq[Expression]) = { + // Add each expression tree and compute the common subexpressions. 
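// Editorial sketch (not part of the patch): the splitting strategy used by splitExpressions
// above -- accumulate snippets until the block of source gets large, then emit one call per
// block so that no single generated method hits the JVM's 64KB-per-method limit. Method and
// parameter names here are illustrative; the real generator also registers each block body
// via addNewFunction.
object SplitExpressionsSketch {
  import scala.collection.mutable.ArrayBuffer

  def split(row: String, snippets: Seq[String], limit: Int = 64 * 1000): String = {
    val blocks = ArrayBuffer.empty[String]
    val current = new StringBuilder
    for (code <- snippets) {
      if (current.length > limit) {   // source length as a rough proxy for bytecode size
        blocks += current.toString
        current.clear()
      }
      current.append(code)
    }
    blocks += current.toString

    if (blocks.length == 1) {
      blocks.head                     // small enough: inline it
    } else {
      blocks.indices.map(i => s"apply_$i($row);").mkString("\n")
    }
  }
}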
+ expressions.foreach(equivalentExpressions.addExprTree(_)) + + // Get all the exprs that appear at least twice and set up the state for subexpression + // elimination. + val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1) + commonExprs.foreach(e => { + val expr = e.head + val isNull = freshName("isNull") + val value = freshName("value") + val fnName = freshName("evalExpr") + + // Generate the code for this expression tree and wrap it in a function. + val code = expr.gen(this) + val fn = + s""" + |private void $fnName(InternalRow $INPUT_ROW) { + | ${code.code.trim} + | $isNull = ${code.isNull}; + | $value = ${code.value}; + |} + """.stripMargin + + addNewFunction(fnName, fn) + + // Add a state and a mapping of the common subexpressions that are associated with this + // state. Adding this expression to subExprEliminationExprs means it will call `fn` + // when it is code generated. This decision should be a cost-based one. + // + // The cost of doing subexpression elimination is: + // 1. Extra function call, although this is probably *good* as the JIT can decide to + // inline or not. + // 2. Extra branch to check isLoaded. This branch is likely to be predicted correctly + // very often. The reason it is not loaded is because of a prior branch. + // 3. Extra store into isLoaded. + // The benefit of doing subexpression elimination is: + // 1. Not re-running the expression logic. Even for a simple expression, this likely + // outweighs the 3 costs above. + // 2. Less code. + // Currently, we will do this only for non-leaf expression trees (i.e. expr trees with + // at least two nodes) as the cost of doing it is expected to be low. + addMutableState("boolean", isNull, s"$isNull = false;") + addMutableState(javaType(expr.dataType), value, + s"$value = ${defaultValue(expr.dataType)};") + + subExprResetVariables += s"$fnName($INPUT_ROW);" + val state = SubExprEliminationState(isNull, value) + e.foreach(subExprEliminationExprs.put(_, state)) + }) + } + + /** + * Generates code for expressions. If doSubexpressionElimination is true, subexpression + * elimination will be performed. Subexpression elimination assumes that the code for each + * expression will be combined in the `expressions` order.
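// Editorial sketch (not part of the patch): a runtime model of what the subexpression
// elimination state above buys. The repeated subtree `col1 + col2` from
// SELECT (col1 + col2), (col1 + col2) / col3 is evaluated once per row and reused,
// instead of being recomputed by every expression that contains it.
object SubexprEliminationSketch {
  import scala.collection.mutable

  def main(args: Array[String]): Unit = {
    val row = Map("col1" -> 3, "col2" -> 4, "col3" -> 7)
    var evaluations = 0
    val cache = mutable.Map.empty[String, Int]   // plays the role of the per-row isNull/value state

    def eval(expr: String): Int = cache.getOrElseUpdate(expr, {
      evaluations += 1
      expr match {
        case "col1 + col2" => row("col1") + row("col2")
        case name          => row(name)
      }
    })

    val projected = (eval("col1 + col2"), eval("col1 + col2") / eval("col3"))
    assert(projected == (7, 1))
    assert(evaluations == 2)   // "col1 + col2" computed once, "col3" once
  }
}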
+ */ + def generateExpressions(expressions: Seq[Expression], + doSubexpressionElimination: Boolean = false): Seq[GeneratedExpressionCode] = { + if (doSubexpressionElimination) subexpressionElimination(expressions) + expressions.map(e => e.gen(this)) + } } /** @@ -289,15 +482,15 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin protected def declareMutableStates(ctx: CodeGenContext): String = { ctx.mutableStates.map { case (javaType, variableName, _) => s"private $javaType $variableName;" - }.mkString + }.mkString("\n") } protected def initMutableStates(ctx: CodeGenContext): String = { - ctx.mutableStates.map(_._3).mkString + ctx.mutableStates.map(_._3).mkString("\n") } protected def declareAddedFunctions(ctx: CodeGenContext): String = { - ctx.addedFuntions.map { case (funcName, funcCode) => funcCode }.mkString + ctx.addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n").trim } /** @@ -327,9 +520,11 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin */ private[this] def doCompile(code: String): GeneratedClass = { val evaluator = new ClassBodyEvaluator() - evaluator.setParentClassLoader(getClass.getClassLoader) + evaluator.setParentClassLoader(Utils.getContextOrSparkClassLoader) + // Cannot be under package codegen, or fail with java.lang.InstantiationException + evaluator.setClassName("org.apache.spark.sql.catalyst.expressions.GeneratedClass") evaluator.setDefaultImports(Array( - classOf[PlatformDependent].getName, + classOf[Platform].getName, classOf[InternalRow].getName, classOf[UnsafeRow].getName, classOf[UTF8String].getName, @@ -338,14 +533,24 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin classOf[ArrayData].getName, classOf[UnsafeArrayData].getName, classOf[MapData].getName, - classOf[UnsafeMapData].getName + classOf[UnsafeMapData].getName, + classOf[MutableRow].getName )) evaluator.setExtendedClass(classOf[GeneratedClass]) + + def formatted = CodeFormatter.format(code) + + logDebug({ + // Only add extra debugging info to byte code when we are going to print the source code. 
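// Editorial sketch (not part of the patch): how the Janino entry point used by doCompile
// above can be driven, with Spark's GeneratedClass replaced by a stand-alone abstract base
// so the snippet is self-contained. The getClazz/newInstance step is an assumption about
// the surrounding code that is not shown in this hunk; the setter and cook calls mirror
// the diff.
import org.codehaus.janino.ClassBodyEvaluator

abstract class GeneratedGreeter { def greet(name: String): String }

object JaninoCompileSketch {
  def main(args: Array[String]): Unit = {
    val code =
      """
        |public String greet(String name) {
        |  return "hello, " + name;
        |}
      """.stripMargin

    val evaluator = new ClassBodyEvaluator()
    evaluator.setParentClassLoader(getClass.getClassLoader)
    evaluator.setExtendedClass(classOf[GeneratedGreeter])
    evaluator.cook("generated.java", code)   // the file name shows up in compile errors
    val greeter = evaluator.getClazz.newInstance().asInstanceOf[GeneratedGreeter]
    assert(greeter.greet("spark") == "hello, spark")
  }
}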
+ evaluator.setDebuggingInformation(true, true, false) + formatted + }) + try { - evaluator.cook(code) + evaluator.cook("generated.java", code) } catch { case e: Exception => - val msg = s"failed to compile: $e\n" + CodeFormatter.format(code) + val msg = s"failed to compile: $e\n$formatted" logError(msg, e) throw new Exception(msg, e) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index 3492d2c6189e..26fb143d1e45 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -33,12 +33,12 @@ trait CodegenFallback extends Expression { ctx.references += this val objectTerm = ctx.freshName("obj") s""" - /* expression: ${this} */ - Object $objectTerm = expressions[${ctx.references.size - 1}].eval(i); + /* expression: ${this.toCommentSafeString} */ + java.lang.Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW}); boolean ${ev.isNull} = $objectTerm == null; - ${ctx.javaType(this.dataType)} ${ev.primitive} = ${ctx.defaultValue(this.dataType)}; + ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)}; if (!${ev.isNull}) { - ${ev.primitive} = (${ctx.boxedType(this.dataType)}) $objectTerm; + ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm; } """ } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index e4a8fc24dac2..40189f087776 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import scala.collection.mutable.ArrayBuffer - import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp +import org.apache.spark.sql.types.DecimalType // MutableProjection is not accessible in Java abstract class BaseMutableProjection extends MutableProjection @@ -28,6 +27,8 @@ abstract class BaseMutableProjection extends MutableProjection /** * Generates byte code that produces a [[MutableRow]] object that can update itself based on a new * input [[InternalRow]] for a fixed set of [[Expression Expressions]]. + * It exposes a `target` method, which is used to set the row that will be updated. + * The internal [[MutableRow]] object created internally is used only when `target` is not used. 
*/ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => MutableProjection] { @@ -39,62 +40,60 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu protected def create(expressions: Seq[Expression]): (() => MutableProjection) = { val ctx = newCodeGenContext() - val projectionCode = expressions.zipWithIndex.map { + val projectionCodes = expressions.zipWithIndex.map { case (NoOp, _) => "" case (e, i) => val evaluationCode = e.gen(ctx) - evaluationCode.code + + val isNull = s"isNull_$i" + val value = s"value_$i" + ctx.addMutableState("boolean", isNull, s"this.$isNull = true;") + ctx.addMutableState(ctx.javaType(e.dataType), value, + s"this.$value = ${ctx.defaultValue(e.dataType)};") + s""" + ${evaluationCode.code} + this.$isNull = ${evaluationCode.isNull}; + this.$value = ${evaluationCode.value}; + """ + } + val updates = expressions.zipWithIndex.map { + case (NoOp, _) => "" + case (e, i) => + if (e.dataType.isInstanceOf[DecimalType]) { + // Can't call setNullAt on DecimalType, because we need to keep the offset + s""" + if (this.isNull_$i) { + ${ctx.setColumn("mutableRow", e.dataType, i, null)}; + } else { + ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; + } + """ + } else { s""" - if (${evaluationCode.isNull}) { + if (this.isNull_$i) { mutableRow.setNullAt($i); } else { - ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)}; + ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; } """ + } } - // collect projections into blocks as function has 64kb codesize limit in JVM - val projectionBlocks = new ArrayBuffer[String]() - val blockBuilder = new StringBuilder() - for (projection <- projectionCode) { - if (blockBuilder.length > 16 * 1000) { - projectionBlocks.append(blockBuilder.toString()) - blockBuilder.clear() - } - blockBuilder.append(projection) - } - projectionBlocks.append(blockBuilder.toString()) - - val (projectionFuns, projectionCalls) = { - // inline execution if codesize limit was not broken - if (projectionBlocks.length == 1) { - ("", projectionBlocks.head) - } else { - ( - projectionBlocks.zipWithIndex.map { case (body, i) => - s""" - |private void apply$i(InternalRow i) { - | $body - |} - """.stripMargin - }.mkString, - projectionBlocks.indices.map(i => s"apply$i(i);").mkString("\n") - ) - } - } + + val allProjections = ctx.splitExpressions(ctx.INPUT_ROW, projectionCodes) + val allUpdates = ctx.splitExpressions(ctx.INPUT_ROW, updates) val code = s""" - public Object generate($exprType[] expr) { - return new SpecificProjection(expr); + public java.lang.Object generate($exprType[] expr) { + return new SpecificMutableProjection(expr); } - class SpecificProjection extends ${classOf[BaseMutableProjection].getName} { + class SpecificMutableProjection extends ${classOf[BaseMutableProjection].getName} { private $exprType[] expressions; private $mutableRowType mutableRow; ${declareMutableStates(ctx)} ${declareAddedFunctions(ctx)} - public SpecificProjection($exprType[] expr) { + public SpecificMutableProjection($exprType[] expr) { expressions = expr; mutableRow = new $genericMutableRowType(${expressions.size}); ${initMutableStates(ctx)} @@ -110,12 +109,11 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu return (InternalRow) mutableRow; } - $projectionFuns - - public Object apply(Object _i) { - InternalRow i = (InternalRow) _i; - $projectionCalls - + public java.lang.Object apply(java.lang.Object _i) { + InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i; 
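// Editorial sketch (not part of the patch): the generated mutable projection now has the
// two-phase shape seen above -- evaluate every expression into per-column state
// (isNull_i / value_i), then copy that state into the target MutableRow. Keeping the phases
// separate also lets each half be split across methods by splitExpressions. The class and
// column layout below are illustrative, not Spark's.
object MutableProjectionSketch {
  final class SimpleMutableRow(n: Int) {
    private val values = new Array[Any](n)
    def update(i: Int, v: Any): Unit = values(i) = v
    def setNullAt(i: Int): Unit = values(i) = null
    def get(i: Int): Any = values(i)
  }

  def project(input: Map[String, Int], target: SimpleMutableRow): SimpleMutableRow = {
    // phase 1: evaluate into local state
    val value_0 = input("a") + input("b")
    val isNull_0 = false
    // phase 2: copy the results into the row
    if (isNull_0) target.setNullAt(0) else target.update(0, value_0)
    target
  }

  def main(args: Array[String]): Unit = {
    assert(project(Map("a" -> 1, "b" -> 2), new SimpleMutableRow(1)).get(0) == 3)
  }
}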
+ $allProjections + // copy all the results into MutableRow + $allUpdates return mutableRow; } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 42be394c3bf5..1af7c73cd4bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -74,21 +74,21 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR val isNullB = ctx.freshName("isNullB") val primitiveB = ctx.freshName("primitiveB") s""" - i = a; + ${ctx.INPUT_ROW} = a; boolean $isNullA; ${ctx.javaType(order.child.dataType)} $primitiveA; { ${eval.code} $isNullA = ${eval.isNull}; - $primitiveA = ${eval.primitive}; + $primitiveA = ${eval.value}; } - i = b; + ${ctx.INPUT_ROW} = b; boolean $isNullB; ${ctx.javaType(order.child.dataType)} $primitiveB; { ${eval.code} $isNullB = ${eval.isNull}; - $primitiveB = ${eval.primitive}; + $primitiveB = ${eval.value}; } if ($isNullA && $isNullB) { // Nothing @@ -126,9 +126,8 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR ${initMutableStates(ctx)} } - @Override public int compare(InternalRow a, InternalRow b) { - InternalRow i = null; // Holds current row being evaluated. + InternalRow ${ctx.INPUT_ROW} = null; // Holds current row being evaluated. $comparisons return 0; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index c7e718a52642..457b4f08424a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -55,10 +55,9 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool ${initMutableStates(ctx)} } - @Override - public boolean eval(InternalRow i) { + public boolean eval(InternalRow ${ctx.INPUT_ROW}) { ${eval.code} - return !${eval.isNull} && ${eval.primitive}; + return !${eval.isNull} && ${eval.value}; } }""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala deleted file mode 100644 index 1572b2b99ab6..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ /dev/null @@ -1,239 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.codegen - -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types._ - -/** - * Java can not access Projection (in package object) - */ -abstract class BaseProjection extends Projection {} - -/** - * Generates bytecode that produces a new [[InternalRow]] object based on a fixed set of input - * [[Expression Expressions]] and a given input [[InternalRow]]. The returned [[InternalRow]] - * object is custom generated based on the output types of the [[Expression]] to avoid boxing of - * primitive values. - */ -object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { - - protected def canonicalize(in: Seq[Expression]): Seq[Expression] = - in.map(ExpressionCanonicalizer.execute) - - protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = - in.map(BindReferences.bindReference(_, inputSchema)) - - // Make Mutablility optional... - protected def create(expressions: Seq[Expression]): Projection = { - val ctx = newCodeGenContext() - val columns = expressions.zipWithIndex.map { - case (e, i) => - s"private ${ctx.javaType(e.dataType)} c$i = ${ctx.defaultValue(e.dataType)};\n" - }.mkString("\n ") - - val initColumns = expressions.zipWithIndex.map { - case (e, i) => - val eval = e.gen(ctx) - s""" - { - // column$i - ${eval.code} - nullBits[$i] = ${eval.isNull}; - if (!${eval.isNull}) { - c$i = ${eval.primitive}; - } - } - """ - }.mkString("\n") - - val getCases = (0 until expressions.size).map { i => - s"case $i: return c$i;" - }.mkString("\n ") - - val updateCases = expressions.zipWithIndex.map { case (e, i) => - s"case $i: { c$i = (${ctx.boxedType(e.dataType)})value; return;}" - }.mkString("\n ") - - val specificAccessorFunctions = ctx.primitiveTypes.map { jt => - val cases = expressions.zipWithIndex.flatMap { - case (e, i) if ctx.javaType(e.dataType) == jt => - Some(s"case $i: return c$i;") - case _ => None - }.mkString("\n ") - if (cases.length > 0) { - val getter = "get" + ctx.primitiveTypeName(jt) - s""" - @Override - public $jt $getter(int i) { - if (isNullAt(i)) { - return ${ctx.defaultValue(jt)}; - } - switch (i) { - $cases - } - throw new IllegalArgumentException("Invalid index: " + i - + " in $getter"); - }""" - } else { - "" - } - }.filter(_.length > 0).mkString("\n") - - val specificMutatorFunctions = ctx.primitiveTypes.map { jt => - val cases = expressions.zipWithIndex.flatMap { - case (e, i) if ctx.javaType(e.dataType) == jt => - Some(s"case $i: { c$i = value; return; }") - case _ => None - }.mkString("\n ") - if (cases.length > 0) { - val setter = "set" + ctx.primitiveTypeName(jt) - s""" - @Override - public void $setter(int i, $jt value) { - nullBits[i] = false; - switch (i) { - $cases - } - throw new IllegalArgumentException("Invalid index: " + i + - " in $setter}"); - }""" - } else { - "" - } - }.filter(_.length > 0).mkString("\n") - - val hashValues = expressions.zipWithIndex.map { case (e, i) => - val col = s"c$i" - val nonNull = e.dataType match { - case BooleanType => s"$col ? 
0 : 1" - case ByteType | ShortType | IntegerType | DateType => s"$col" - case LongType | TimestampType => s"$col ^ ($col >>> 32)" - case FloatType => s"Float.floatToIntBits($col)" - case DoubleType => - s"(int)(Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32))" - case BinaryType => s"java.util.Arrays.hashCode($col)" - case _ => s"$col.hashCode()" - } - s"isNullAt($i) ? 0 : ($nonNull)" - } - - val hashUpdates: String = hashValues.map( v => - s""" - result *= 37; result += $v;""" - ).mkString("\n") - - val columnChecks = expressions.zipWithIndex.map { case (e, i) => - s""" - if (nullBits[$i] != row.nullBits[$i] || - (!nullBits[$i] && !(${ctx.genEqual(e.dataType, s"c$i", s"row.c$i")}))) { - return false; - } - """ - }.mkString("\n") - - val copyColumns = expressions.zipWithIndex.map { case (e, i) => - s"""if (!nullBits[$i]) arr[$i] = c$i;""" - }.mkString("\n ") - - val code = s""" - public SpecificProjection generate($exprType[] expr) { - return new SpecificProjection(expr); - } - - class SpecificProjection extends ${classOf[BaseProjection].getName} { - private $exprType[] expressions; - ${declareMutableStates(ctx)} - ${declareAddedFunctions(ctx)} - - public SpecificProjection($exprType[] expr) { - expressions = expr; - ${initMutableStates(ctx)} - } - - @Override - public Object apply(Object r) { - return new SpecificRow((InternalRow) r); - } - - final class SpecificRow extends ${classOf[MutableRow].getName} { - - $columns - - public SpecificRow(InternalRow i) { - $initColumns - } - - public int numFields() { return ${expressions.length};} - protected boolean[] nullBits = new boolean[${expressions.length}]; - public void setNullAt(int i) { nullBits[i] = true; } - public boolean isNullAt(int i) { return nullBits[i]; } - - public Object genericGet(int i) { - if (isNullAt(i)) return null; - switch (i) { - $getCases - } - return null; - } - public void update(int i, Object value) { - if (value == null) { - setNullAt(i); - return; - } - nullBits[i] = false; - switch (i) { - $updateCases - } - } - $specificAccessorFunctions - $specificMutatorFunctions - - @Override - public int hashCode() { - int result = 37; - $hashUpdates - return result; - } - - @Override - public boolean equals(Object other) { - if (other instanceof SpecificRow) { - SpecificRow row = (SpecificRow) other; - $columnChecks - return true; - } - return super.equals(other); - } - - @Override - public InternalRow copy() { - Object[] arr = new Object[${expressions.length}]; - ${copyColumns} - return new ${classOf[GenericInternalRow].getName}(arr); - } - } - } - """ - - logDebug(s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n" + - CodeFormatter.format(code)) - - compile(code).generate(ctx.references.toArray).asInstanceOf[Projection] - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index f06ffc5449e7..13634b69457a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -17,16 +17,19 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import scala.collection.mutable.ArrayBuffer - import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp -import 
org.apache.spark.sql.types.{StringType, StructType, DataType} +import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData} +import org.apache.spark.sql.types._ +/** + * Java can not access Projection (in package object) + */ +abstract class BaseProjection extends Projection {} /** - * Generates byte code that produces a [[MutableRow]] object that can update itself based on a new - * input [[InternalRow]] for a fixed set of [[Expression Expressions]]. + * Generates byte code that produces a [[MutableRow]] object (not an [[UnsafeRow]]) that can update + * itself based on a new input [[InternalRow]] for a fixed set of [[Expression Expressions]]. */ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] { @@ -36,83 +39,120 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = in.map(BindReferences.bindReference(_, inputSchema)) - private def genUpdater( + private def createCodeForStruct( ctx: CodeGenContext, - setter: String, - dataType: DataType, - ordinal: Int, - value: String): String = { - dataType match { - case struct: StructType => - val rowTerm = ctx.freshName("row") - val updates = struct.map(_.dataType).zipWithIndex.map { case (dt, i) => - val colTerm = ctx.freshName("col") - s""" - if ($value.isNullAt($i)) { - $rowTerm.setNullAt($i); - } else { - ${ctx.javaType(dt)} $colTerm = ${ctx.getValue(value, dt, s"$i")}; - ${genUpdater(ctx, rowTerm, dt, i, colTerm)}; - } - """ - }.mkString("\n") - s""" - $genericMutableRowType $rowTerm = new $genericMutableRowType(${struct.fields.length}); - $updates - $setter.update($ordinal, $rowTerm.copy()); - """ - case _ => - ctx.setColumn(setter, dataType, ordinal, value) + input: String, + schema: StructType): GeneratedExpressionCode = { + val tmp = ctx.freshName("tmp") + val output = ctx.freshName("safeRow") + val values = ctx.freshName("values") + // These expressions could be splitted into multiple functions + ctx.addMutableState("Object[]", values, s"this.$values = null;") + + val rowClass = classOf[GenericInternalRow].getName + + val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) => + val converter = convertToSafe(ctx, ctx.getValue(tmp, dt, i.toString), dt) + s""" + if (!$tmp.isNullAt($i)) { + ${converter.code} + $values[$i] = ${converter.value}; + } + """ } + val allFields = ctx.splitExpressions(tmp, fieldWriters) + val code = s""" + final InternalRow $tmp = $input; + this.$values = new Object[${schema.length}]; + $allFields + final InternalRow $output = new $rowClass($values); + """ + + GeneratedExpressionCode(code, "false", output) + } + + private def createCodeForArray( + ctx: CodeGenContext, + input: String, + elementType: DataType): GeneratedExpressionCode = { + val tmp = ctx.freshName("tmp") + val output = ctx.freshName("safeArray") + val values = ctx.freshName("values") + val numElements = ctx.freshName("numElements") + val index = ctx.freshName("index") + val arrayClass = classOf[GenericArrayData].getName + + val elementConverter = convertToSafe(ctx, ctx.getValue(tmp, elementType, index), elementType) + val code = s""" + final ArrayData $tmp = $input; + final int $numElements = $tmp.numElements(); + final Object[] $values = new Object[$numElements]; + for (int $index = 0; $index < $numElements; $index++) { + if (!$tmp.isNullAt($index)) { + ${elementConverter.code} + $values[$index] = ${elementConverter.value}; + } + } + final ArrayData $output = new 
$arrayClass($values); + """ + + GeneratedExpressionCode(code, "false", output) + } + + private def createCodeForMap( + ctx: CodeGenContext, + input: String, + keyType: DataType, + valueType: DataType): GeneratedExpressionCode = { + val tmp = ctx.freshName("tmp") + val output = ctx.freshName("safeMap") + val mapClass = classOf[ArrayBasedMapData].getName + + val keyConverter = createCodeForArray(ctx, s"$tmp.keyArray()", keyType) + val valueConverter = createCodeForArray(ctx, s"$tmp.valueArray()", valueType) + val code = s""" + final MapData $tmp = $input; + ${keyConverter.code} + ${valueConverter.code} + final MapData $output = new $mapClass(${keyConverter.value}, ${valueConverter.value}); + """ + + GeneratedExpressionCode(code, "false", output) + } + + private def convertToSafe( + ctx: CodeGenContext, + input: String, + dataType: DataType): GeneratedExpressionCode = dataType match { + case s: StructType => createCodeForStruct(ctx, input, s) + case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType) + case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType) + // UTF8String act as a pointer if it's inside UnsafeRow, so copy it to make it safe. + case StringType => GeneratedExpressionCode("", "false", s"$input.clone()") + case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType) + case _ => GeneratedExpressionCode("", "false", input) } protected def create(expressions: Seq[Expression]): Projection = { val ctx = newCodeGenContext() - val projectionCode = expressions.zipWithIndex.map { + val expressionCodes = expressions.zipWithIndex.map { case (NoOp, _) => "" case (e, i) => val evaluationCode = e.gen(ctx) + val converter = convertToSafe(ctx, evaluationCode.value, e.dataType) evaluationCode.code + s""" if (${evaluationCode.isNull}) { mutableRow.setNullAt($i); } else { - ${genUpdater(ctx, "mutableRow", e.dataType, i, evaluationCode.primitive)}; + ${converter.code} + ${ctx.setColumn("mutableRow", e.dataType, i, converter.value)}; } """ } - // collect projections into blocks as function has 64kb codesize limit in JVM - val projectionBlocks = new ArrayBuffer[String]() - val blockBuilder = new StringBuilder() - for (projection <- projectionCode) { - if (blockBuilder.length > 16 * 1000) { - projectionBlocks.append(blockBuilder.toString()) - blockBuilder.clear() - } - blockBuilder.append(projection) - } - projectionBlocks.append(blockBuilder.toString()) - - val (projectionFuns, projectionCalls) = { - // inline it if we have only one block - if (projectionBlocks.length == 1) { - ("", projectionBlocks.head) - } else { - ( - projectionBlocks.zipWithIndex.map { case (body, i) => - s""" - |private void apply$i(InternalRow i) { - | $body - |} - """.stripMargin - }.mkString, - projectionBlocks.indices.map(i => s"apply$i(i);").mkString("\n") - ) - } - } - + val allExpressions = ctx.splitExpressions(ctx.INPUT_ROW, expressionCodes) val code = s""" - public Object generate($exprType[] expr) { + public java.lang.Object generate($exprType[] expr) { return new SpecificSafeProjection(expr); } @@ -121,6 +161,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] private $exprType[] expressions; private $mutableRowType mutableRow; ${declareMutableStates(ctx)} + ${declareAddedFunctions(ctx)} public SpecificSafeProjection($exprType[] expr) { expressions = expr; @@ -128,12 +169,9 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] ${initMutableStates(ctx)} } - $projectionFuns - - public Object 
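// Editorial sketch (not part of the patch): a runtime analogue of the convertToSafe
// dispatch above -- structs, arrays and maps are rebuilt recursively into generic ("safe")
// containers, strings are copied so they no longer point into a shared buffer, and
// primitives pass through untouched. The toy type algebra below is illustrative.
object ConvertToSafeSketch {
  sealed trait DType
  case object IntT extends DType
  case object StrT extends DType
  final case class ArrT(element: DType) extends DType
  final case class StructT(fields: Seq[DType]) extends DType

  def toSafe(value: Any, dt: DType): Any = dt match {
    case StrT            => new String(value.asInstanceOf[String])         // copy, like StringType's clone()
    case ArrT(et)        => value.asInstanceOf[Seq[Any]].map(toSafe(_, et))
    case StructT(fields) => value.asInstanceOf[Seq[Any]].zip(fields).map { case (v, t) => toSafe(v, t) }
    case IntT            => value                                          // primitives are already safe
  }

  def main(args: Array[String]): Unit = {
    val safe = toSafe(Seq(1, Seq("a", "b")), StructT(Seq(IntT, ArrT(StrT))))
    assert(safe == Seq(1, Seq("a", "b")))
  }
}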
apply(Object _i) { - InternalRow i = (InternalRow) _i; - $projectionCalls - + public java.lang.Object apply(java.lang.Object _i) { + InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i; + $allExpressions return mutableRow; } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index fc3ecf545142..68005afb21d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.PlatformDependent /** * Generates a [[Projection]] that returns an [[UnsafeRow]]. @@ -32,533 +31,274 @@ import org.apache.spark.unsafe.PlatformDependent */ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] { - private val StringWriter = classOf[UnsafeRowWriters.UTF8StringWriter].getName - private val BinaryWriter = classOf[UnsafeRowWriters.BinaryWriter].getName - private val IntervalWriter = classOf[UnsafeRowWriters.IntervalWriter].getName - private val StructWriter = classOf[UnsafeRowWriters.StructWriter].getName - private val CompactDecimalWriter = classOf[UnsafeRowWriters.CompactDecimalWriter].getName - private val DecimalWriter = classOf[UnsafeRowWriters.DecimalWriter].getName - private val ArrayWriter = classOf[UnsafeRowWriters.ArrayWriter].getName - private val MapWriter = classOf[UnsafeRowWriters.MapWriter].getName - - private val PlatformDependent = classOf[PlatformDependent].getName - /** Returns true iff we support this data type. */ def canSupport(dataType: DataType): Boolean = dataType match { + case NullType => true case t: AtomicType => true case _: CalendarIntervalType => true case t: StructType => t.toSeq.forall(field => canSupport(field.dataType)) - case NullType => true case t: ArrayType if canSupport(t.elementType) => true case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true + case udt: UserDefinedType[_] => canSupport(udt.sqlType) case _ => false } - def genAdditionalSize(dt: DataType, ev: GeneratedExpressionCode): String = dt match { - case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => - s" + (${ev.isNull} ? 0 : $DecimalWriter.getSize(${ev.primitive}))" - case StringType => - s" + (${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive}))" - case BinaryType => - s" + (${ev.isNull} ? 0 : $BinaryWriter.getSize(${ev.primitive}))" - case CalendarIntervalType => - s" + (${ev.isNull} ? 0 : 16)" - case _: StructType => - s" + (${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive}))" - case _: ArrayType => - s" + (${ev.isNull} ? 0 : $ArrayWriter.getSize(${ev.primitive}))" - case _: MapType => - s" + (${ev.isNull} ? 0 : $MapWriter.getSize(${ev.primitive}))" - case _ => "" - } + private val rowWriterClass = classOf[UnsafeRowWriter].getName + private val arrayWriterClass = classOf[UnsafeArrayWriter].getName - def genFieldWriter( + // TODO: if the nullability of field is correct, we can use it to save null check. 
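// Editorial sketch (not part of the patch): the shape of the canSupport recursion above,
// over a toy type algebra. The notable additions in the diff are NullType and
// UserDefinedType, the latter supported whenever its underlying sqlType is.
object CanSupportSketch {
  sealed trait DType
  case object AtomicT extends DType
  case object NullT extends DType
  final case class ArrT(element: DType) extends DType
  final case class MapT(key: DType, value: DType) extends DType
  final case class StructT(fields: Seq[DType]) extends DType
  final case class UdtT(sqlType: DType) extends DType
  case object UnsupportedT extends DType

  def canSupport(dt: DType): Boolean = dt match {
    case NullT | AtomicT => true
    case StructT(fields) => fields.forall(canSupport)
    case ArrT(et)        => canSupport(et)
    case MapT(kt, vt)    => canSupport(kt) && canSupport(vt)
    case UdtT(sql)       => canSupport(sql)
    case _               => false
  }

  def main(args: Array[String]): Unit = {
    assert(canSupport(UdtT(StructT(Seq(AtomicT, ArrT(AtomicT))))))
    assert(!canSupport(MapT(AtomicT, UnsupportedT)))
  }
}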
+ private def writeStructToBuffer( ctx: CodeGenContext, - fieldType: DataType, - ev: GeneratedExpressionCode, - primitive: String, - index: Int, - cursor: String): String = fieldType match { - case _ if ctx.isPrimitiveType(fieldType) => - s"${ctx.setColumn(primitive, fieldType, index, ev.primitive)}" - case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => - s""" - // make sure Decimal object has the same scale as DecimalType - if (${ev.primitive}.changePrecision(${t.precision}, ${t.scale})) { - $CompactDecimalWriter.write($primitive, $index, $cursor, ${ev.primitive}); - } else { - $primitive.setNullAt($index); - } - """ - case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => - s""" - // make sure Decimal object has the same scale as DecimalType - if (${ev.primitive}.changePrecision(${t.precision}, ${t.scale})) { - $cursor += $DecimalWriter.write($primitive, $index, $cursor, ${ev.primitive}); - } else { - $primitive.setNullAt($index); - } - """ - case StringType => - s"$cursor += $StringWriter.write($primitive, $index, $cursor, ${ev.primitive})" - case BinaryType => - s"$cursor += $BinaryWriter.write($primitive, $index, $cursor, ${ev.primitive})" - case CalendarIntervalType => - s"$cursor += $IntervalWriter.write($primitive, $index, $cursor, ${ev.primitive})" - case _: StructType => - s"$cursor += $StructWriter.write($primitive, $index, $cursor, ${ev.primitive})" - case _: ArrayType => - s"$cursor += $ArrayWriter.write($primitive, $index, $cursor, ${ev.primitive})" - case _: MapType => - s"$cursor += $MapWriter.write($primitive, $index, $cursor, ${ev.primitive})" - case NullType => "" - case _ => - throw new UnsupportedOperationException(s"Not supported DataType: $fieldType") - } - - /** - * Generates the code to create an [[UnsafeRow]] object based on the input expressions. 
- * @param ctx context for code generation - * @param ev specifies the name of the variable for the output [[UnsafeRow]] object - * @param expressions input expressions - * @return generated code to put the expression output into an [[UnsafeRow]] - */ - def createCode(ctx: CodeGenContext, ev: GeneratedExpressionCode, expressions: Seq[Expression]) - : String = { - - val ret = ev.primitive - ctx.addMutableState("UnsafeRow", ret, s"$ret = new UnsafeRow();") - val buffer = ctx.freshName("buffer") - ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];") - val cursor = ctx.freshName("cursor") - val numBytes = ctx.freshName("numBytes") - - val exprs = expressions.map { e => e.dataType match { - case st: StructType => createCodeForStruct(ctx, e.gen(ctx), st) - case _ => e.gen(ctx) - }} - val allExprs = exprs.map(_.code).mkString("\n") - - val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length) - val additionalSize = expressions.zipWithIndex.map { - case (e, i) => genAdditionalSize(e.dataType, exprs(i)) - }.mkString("") - - val writers = expressions.zipWithIndex.map { case (e, i) => - val update = genFieldWriter(ctx, e.dataType, exprs(i), ret, i, cursor) - s"""if (${exprs(i).isNull}) { - $ret.setNullAt($i); - } else { - $update; - }""" - }.mkString("\n ") + input: String, + fieldTypes: Seq[DataType], + bufferHolder: String): String = { + val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => + val fieldName = ctx.freshName("fieldName") + val code = s"final ${ctx.javaType(dt)} $fieldName = ${ctx.getValue(input, dt, i.toString)};" + val isNull = s"$input.isNullAt($i)" + GeneratedExpressionCode(code, isNull, fieldName) + } s""" - $allExprs - int $numBytes = $fixedSize $additionalSize; - if ($numBytes > $buffer.length) { - $buffer = new byte[$numBytes]; + if ($input instanceof UnsafeRow) { + ${writeUnsafeData(ctx, s"((UnsafeRow) $input)", bufferHolder)} + } else { + ${writeExpressionsToBuffer(ctx, input, fieldEvals, fieldTypes, bufferHolder)} } - - $ret.pointTo( - $buffer, - $PlatformDependent.BYTE_ARRAY_OFFSET, - ${expressions.size}, - $numBytes); - int $cursor = $fixedSize; - - $writers - boolean ${ev.isNull} = false; - """ + """ } - /** - * Generates the Java code to convert a struct (backed by InternalRow) to UnsafeRow. - * - * This function also handles nested structs by recursively generating the code to do conversion. - * - * @param ctx code generation context - * @param input the input struct, identified by a [[GeneratedExpressionCode]] - * @param schema schema of the struct field - */ - // TODO: refactor createCode and this function to reduce code duplication. 
- private def createCodeForStruct( + private def writeExpressionsToBuffer( ctx: CodeGenContext, - input: GeneratedExpressionCode, - schema: StructType): GeneratedExpressionCode = { - - val isNull = input.isNull - val primitive = ctx.freshName("structConvert") - ctx.addMutableState("UnsafeRow", primitive, s"$primitive = new UnsafeRow();") - val buffer = ctx.freshName("buffer") - ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];") - val cursor = ctx.freshName("cursor") - - val exprs: Seq[GeneratedExpressionCode] = schema.map(_.dataType).zipWithIndex.map { - case (dt, i) => dt match { - case st: StructType => - val nestedStructEv = GeneratedExpressionCode( - code = "", - isNull = s"${input.primitive}.isNullAt($i)", - primitive = s"${ctx.getValue(input.primitive, dt, i.toString)}" - ) - createCodeForStruct(ctx, nestedStructEv, st) - case _ => - GeneratedExpressionCode( - code = "", - isNull = s"${input.primitive}.isNullAt($i)", - primitive = s"${ctx.getValue(input.primitive, dt, i.toString)}" - ) + row: String, + inputs: Seq[GeneratedExpressionCode], + inputTypes: Seq[DataType], + bufferHolder: String): String = { + val rowWriter = ctx.freshName("rowWriter") + ctx.addMutableState(rowWriterClass, rowWriter, s"this.$rowWriter = new $rowWriterClass();") + + val writeFields = inputs.zip(inputTypes).zipWithIndex.map { + case ((input, dataType), index) => + val dt = dataType match { + case udt: UserDefinedType[_] => udt.sqlType + case other => other } - } - val allExprs = exprs.map(_.code).mkString("\n") + val tmpCursor = ctx.freshName("tmpCursor") - val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length) - val additionalSize = schema.toSeq.map(_.dataType).zip(exprs).map { case (dt, ev) => - genAdditionalSize(dt, ev) - }.mkString("") + val setNull = dt match { + case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => + // Can't call setNullAt() for DecimalType with precision larger than 18. + s"$rowWriter.write($index, (Decimal) null, ${t.precision}, ${t.scale});" + case _ => s"$rowWriter.setNullAt($index);" + } - val writers = schema.toSeq.map(_.dataType).zip(exprs).zipWithIndex.map { case ((dt, ev), i) => - val update = genFieldWriter(ctx, dt, ev, primitive, i, cursor) - s""" - if (${exprs(i).isNull}) { - $primitive.setNullAt($i); + val writeField = dt match { + case t: StructType => + s""" + // Remember the current cursor so that we can calculate how many bytes are + // written later. + final int $tmpCursor = $bufferHolder.cursor; + ${writeStructToBuffer(ctx, input.value, t.map(_.dataType), bufferHolder)} + $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + """ + + case a @ ArrayType(et, _) => + s""" + // Remember the current cursor so that we can calculate how many bytes are + // written later. + final int $tmpCursor = $bufferHolder.cursor; + ${writeArrayToBuffer(ctx, input.value, et, bufferHolder)} + $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + $rowWriter.alignToWords($bufferHolder.cursor - $tmpCursor); + """ + + case m @ MapType(kt, vt, _) => + s""" + // Remember the current cursor so that we can calculate how many bytes are + // written later. 
+ final int $tmpCursor = $bufferHolder.cursor; + ${writeMapToBuffer(ctx, input.value, kt, vt, bufferHolder)} + $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + $rowWriter.alignToWords($bufferHolder.cursor - $tmpCursor); + """ + + case _ if ctx.isPrimitiveType(dt) => + s""" + $rowWriter.write($index, ${input.value}); + """ + + case t: DecimalType => + s"$rowWriter.write($index, ${input.value}, ${t.precision}, ${t.scale});" + + case NullType => "" + + case _ => s"$rowWriter.write($index, ${input.value});" + } + + s""" + ${input.code} + if (${input.isNull}) { + ${setNull.trim} } else { - $update; + ${writeField.trim} } """ - }.mkString("\n ") - - // Note that we add a shortcut here for performance: if the input is already an UnsafeRow, - // just copy the bytes directly into our buffer space without running any conversion. - // We also had to use a hack to introduce a "tmp" variable, to avoid the Java compiler from - // complaining that a GenericMutableRow (generated by expressions) cannot be cast to UnsafeRow. - val tmp = ctx.freshName("tmp") - val numBytes = ctx.freshName("numBytes") - val code = s""" - |${input.code} - |if (!${input.isNull}) { - | Object $tmp = (Object) ${input.primitive}; - | if ($tmp instanceof UnsafeRow) { - | $primitive = (UnsafeRow) $tmp; - | } else { - | $allExprs - | - | int $numBytes = $fixedSize $additionalSize; - | if ($numBytes > $buffer.length) { - | $buffer = new byte[$numBytes]; - | } - | - | $primitive.pointTo( - | $buffer, - | $PlatformDependent.BYTE_ARRAY_OFFSET, - | ${exprs.size}, - | $numBytes); - | int $cursor = $fixedSize; - | - | $writers - | } - |} - """.stripMargin - - GeneratedExpressionCode(code, isNull, primitive) - } - - /** - * Generates the Java code to convert a struct (backed by InternalRow) to UnsafeRow. - * - * @param ctx code generation context - * @param inputs could be the codes for expressions or input struct fields. 
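// Editorial sketch (not part of the patch): the row layout convention the writers above
// rely on -- each field has a fixed-width slot, and variable-length values are appended to
// a growing buffer while the slot records (offset, size). This is the role played by
// setOffsetAndSize and the cursor remembered in $tmpCursor; the toy holder below is
// illustrative, not Spark's BufferHolder.
object OffsetAndSizeSketch {
  import scala.collection.mutable.ArrayBuffer

  final class ToyBufferHolder {
    val buffer = new ArrayBuffer[Byte]()
    def cursor: Int = buffer.length
  }

  def writeVariableField(
      slots: Array[(Int, Int)],
      index: Int,
      bytes: Array[Byte],
      holder: ToyBufferHolder): Unit = {
    val start = holder.cursor                      // remember the cursor before writing
    holder.buffer ++= bytes                        // append the variable-length payload
    slots(index) = (start, holder.cursor - start)  // record offset and size in the fixed slot
  }

  def main(args: Array[String]): Unit = {
    val holder = new ToyBufferHolder
    val slots = new Array[(Int, Int)](2)
    writeVariableField(slots, 0, "spark".getBytes("UTF-8"), holder)
    writeVariableField(slots, 1, "sql".getBytes("UTF-8"), holder)
    assert(slots(0) == (0, 5) && slots(1) == (5, 3))
  }
}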
- * @param inputTypes types of the inputs - */ - private def createCodeForStruct2( - ctx: CodeGenContext, - inputs: Seq[GeneratedExpressionCode], - inputTypes: Seq[DataType]): GeneratedExpressionCode = { - - val output = ctx.freshName("convertedStruct") - ctx.addMutableState("UnsafeRow", output, s"$output = new UnsafeRow();") - val buffer = ctx.freshName("buffer") - ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];") - val numBytes = ctx.freshName("numBytes") - val cursor = ctx.freshName("cursor") - - val convertedFields = inputTypes.zip(inputs).map { case (dt, input) => - createConvertCode(ctx, input, dt) } - val fixedSize = 8 * inputTypes.length + UnsafeRow.calculateBitSetWidthInBytes(inputTypes.length) - val additionalSize = inputTypes.zip(convertedFields).map { case (dt, ev) => - genAdditionalSize(dt, ev) - }.mkString("") - - val fieldWriters = inputTypes.zip(convertedFields).zipWithIndex.map { case ((dt, ev), i) => - val update = genFieldWriter(ctx, dt, ev, output, i, cursor) - s""" - if (${ev.isNull}) { - $output.setNullAt($i); - } else { - $update; - } - """ - }.mkString("\n") - - val code = s""" - ${convertedFields.map(_.code).mkString("\n")} - - final int $numBytes = $fixedSize $additionalSize; - if ($numBytes > $buffer.length) { - $buffer = new byte[$numBytes]; - } - - $output.pointTo( - $buffer, - $PlatformDependent.BYTE_ARRAY_OFFSET, - ${inputTypes.length}, - $numBytes); - - int $cursor = $fixedSize; - - $fieldWriters - """ - GeneratedExpressionCode(code, "false", output) - } - - private def getWriter(dt: DataType) = dt match { - case StringType => classOf[UnsafeWriters.UTF8StringWriter].getName - case BinaryType => classOf[UnsafeWriters.BinaryWriter].getName - case CalendarIntervalType => classOf[UnsafeWriters.IntervalWriter].getName - case _: StructType => classOf[UnsafeWriters.StructWriter].getName - case _: ArrayType => classOf[UnsafeWriters.ArrayWriter].getName - case _: MapType => classOf[UnsafeWriters.MapWriter].getName - case _: DecimalType => classOf[UnsafeWriters.DecimalWriter].getName + s""" + $rowWriter.initialize($bufferHolder, ${inputs.length}); + ${ctx.splitExpressions(row, writeFields)} + """.trim } - private def createCodeForArray( + // TODO: if the nullability of array element is correct, we can use it to save null check. 
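
Editorial note on the new writer path above: once a field is non-null, writeExpressionsToBuffer treats primitives and compact decimals as values stored directly in the field's fixed 8-byte slot, while structs, arrays and maps are appended to the shared buffer and the slot records where they landed via `tmpCursor` and `setOffsetAndSize`. The sketch below is a minimal, self-contained model of that bookkeeping; it uses a plain java.nio.ByteBuffer instead of Spark's BufferHolder/UnsafeRowWriter, and the slot packing is an illustration of the idea rather than the exact UnsafeRow binary format.

    import java.nio.ByteBuffer
    import java.nio.charset.StandardCharsets

    // Editorial sketch: fixed 8-byte slot per field, variable-length payloads appended after the
    // slots, and a packed (offset, size) written back into the slot -- the role played by
    // `tmpCursor` and `setOffsetAndSize` in the generated code above.
    object OffsetAndSizeSketch {
      def main(args: Array[String]): Unit = {
        val fields: Seq[Any] = Seq(42L, "hello unsafe row")
        val fixedSize = 8 * fields.length
        val buf = ByteBuffer.allocate(1024)
        var cursor = fixedSize                      // variable-length region starts after the slots

        fields.zipWithIndex.foreach {
          case (v: Long, i) =>                      // "primitive" field: value lives in its slot
            buf.putLong(8 * i, v)
          case (s: String, i) =>                    // variable-length field
            val bytes = s.getBytes(StandardCharsets.UTF_8)
            val tmpCursor = cursor                  // remember where the payload starts
            bytes.foreach { b => buf.put(cursor, b); cursor += 1 }
            buf.putLong(8 * i, (tmpCursor.toLong << 32) | bytes.length.toLong)
          case _ => ()
        }

        val slot = buf.getLong(8)
        println(s"string field: offset=${slot >>> 32}, size=${slot & 0xFFFFFFFFL}")
        // prints: string field: offset=16, size=16
      }
    }

In the patch, array and map payloads additionally get `$rowWriter.alignToWords(...)` after they are written, padding the variable-length region to 8-byte boundaries; the sketch omits that step.
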
+ private def writeArrayToBuffer( ctx: CodeGenContext, - input: GeneratedExpressionCode, - elementType: DataType): GeneratedExpressionCode = { - val output = ctx.freshName("convertedArray") - ctx.addMutableState("UnsafeArrayData", output, s"$output = new UnsafeArrayData();") - val buffer = ctx.freshName("buffer") - ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];") - val outputIsNull = ctx.freshName("isNull") - val tmp = ctx.freshName("tmp") + input: String, + elementType: DataType, + bufferHolder: String): String = { + val arrayWriter = ctx.freshName("arrayWriter") + ctx.addMutableState(arrayWriterClass, arrayWriter, + s"this.$arrayWriter = new $arrayWriterClass();") val numElements = ctx.freshName("numElements") - val fixedSize = ctx.freshName("fixedSize") - val numBytes = ctx.freshName("numBytes") - val elements = ctx.freshName("elements") - val cursor = ctx.freshName("cursor") val index = ctx.freshName("index") + val element = ctx.freshName("element") - val element = GeneratedExpressionCode( - code = "", - isNull = s"$tmp.isNullAt($index)", - primitive = s"${ctx.getValue(tmp, elementType, index)}" - ) - val convertedElement: GeneratedExpressionCode = createConvertCode(ctx, element, elementType) - - // go through the input array to calculate how many bytes we need. - val calculateNumBytes = elementType match { - case _ if (ctx.isPrimitiveType(elementType)) => - // Should we do word align? - val elementSize = elementType.defaultSize - s""" - $numBytes += $elementSize * $numElements; - """ - case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => - s""" - $numBytes += 8 * $numElements; - """ - case _ => - val writer = getWriter(elementType) - val elementSize = s"$writer.getSize($elements[$index])" - val unsafeType = elementType match { - case _: StructType => "UnsafeRow" - case _: ArrayType => "UnsafeArrayData" - case _: MapType => "UnsafeMapData" - case _ => ctx.javaType(elementType) - } - val copy = elementType match { - // We reuse the buffer during conversion, need copy it before process next element. - case _: StructType | _: ArrayType | _: MapType => ".copy()" - case _ => "" - } + val et = elementType match { + case udt: UserDefinedType[_] => udt.sqlType + case other => other + } - s""" - final $unsafeType[] $elements = new $unsafeType[$numElements]; - for (int $index = 0; $index < $numElements; $index++) { - ${convertedElement.code} - if (!${convertedElement.isNull}) { - $elements[$index] = ${convertedElement.primitive}$copy; - $numBytes += $elementSize; - } - } - """ + val jt = ctx.javaType(et) + + val fixedElementSize = et match { + case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => 8 + case _ if ctx.isPrimitiveType(jt) => et.defaultSize + case _ => 0 } - val writeElement = elementType match { - case _ if (ctx.isPrimitiveType(elementType)) => - // Should we do word align? 
- val elementSize = elementType.defaultSize + val writeElement = et match { + case t: StructType => s""" - $PlatformDependent.UNSAFE.put${ctx.primitiveTypeName(elementType)}( - $buffer, - $PlatformDependent.BYTE_ARRAY_OFFSET + $cursor, - ${convertedElement.primitive}); - $cursor += $elementSize; + $arrayWriter.setOffset($index); + ${writeStructToBuffer(ctx, element, t.map(_.dataType), bufferHolder)} """ - case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => + + case a @ ArrayType(et, _) => s""" - $PlatformDependent.UNSAFE.putLong( - $buffer, - $PlatformDependent.BYTE_ARRAY_OFFSET + $cursor, - ${convertedElement.primitive}.toUnscaledLong()); - $cursor += 8; + $arrayWriter.setOffset($index); + ${writeArrayToBuffer(ctx, element, et, bufferHolder)} """ - case _ => - val writer = getWriter(elementType) + + case m @ MapType(kt, vt, _) => s""" - $cursor += $writer.write( - $buffer, - $PlatformDependent.BYTE_ARRAY_OFFSET + $cursor, - $elements[$index]); + $arrayWriter.setOffset($index); + ${writeMapToBuffer(ctx, element, kt, vt, bufferHolder)} """ - } - val checkNull = elementType match { - case _ if ctx.isPrimitiveType(elementType) => s"${convertedElement.isNull}" - case t: DecimalType => s"$elements[$index] == null" + - s" || !$elements[$index].changePrecision(${t.precision}, ${t.scale})" - case _ => s"$elements[$index] == null" - } + case t: DecimalType => + s"$arrayWriter.write($index, $element, ${t.precision}, ${t.scale});" - val code = s""" - ${input.code} - final boolean $outputIsNull = ${input.isNull}; - if (!$outputIsNull) { - final ArrayData $tmp = ${input.primitive}; - if ($tmp instanceof UnsafeArrayData) { - $output = (UnsafeArrayData) $tmp; - } else { - final int $numElements = $tmp.numElements(); - final int $fixedSize = 4 * $numElements; - int $numBytes = $fixedSize; - - $calculateNumBytes - - if ($numBytes > $buffer.length) { - $buffer = new byte[$numBytes]; - } + case NullType => "" - int $cursor = $fixedSize; - for (int $index = 0; $index < $numElements; $index++) { - if ($checkNull) { - // If element is null, write the negative value address into offset region. - $PlatformDependent.UNSAFE.putInt( - $buffer, - $PlatformDependent.BYTE_ARRAY_OFFSET + 4 * $index, - -$cursor); - } else { - $PlatformDependent.UNSAFE.putInt( - $buffer, - $PlatformDependent.BYTE_ARRAY_OFFSET + 4 * $index, - $cursor); - - $writeElement - } - } + case _ => s"$arrayWriter.write($index, $element);" + } - $output.pointTo( - $buffer, - $PlatformDependent.BYTE_ARRAY_OFFSET, - $numElements, - $numBytes); + s""" + if ($input instanceof UnsafeArrayData) { + ${writeUnsafeData(ctx, s"((UnsafeArrayData) $input)", bufferHolder)} + } else { + final int $numElements = $input.numElements(); + $arrayWriter.initialize($bufferHolder, $numElements, $fixedElementSize); + + for (int $index = 0; $index < $numElements; $index++) { + if ($input.isNullAt($index)) { + $arrayWriter.setNullAt($index); + } else { + final $jt $element = ${ctx.getValue(input, et, index)}; + $writeElement + } } } - """ - GeneratedExpressionCode(code, outputIsNull, output) + """ } - private def createCodeForMap( + // TODO: if the nullability of value element is correct, we can use it to save null check. 
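
For arrays, the writer above first decides how wide the per-element slots are: 8 bytes for decimals that still fit in a long (precision <= Decimal.MAX_LONG_DIGITS), the element type's default size for primitives, and 0 for everything else, in which case the element data is written out of line. The following stand-alone model mirrors that `fixedElementSize` match; the SimpleType hierarchy is an illustrative stand-in for Spark's DataType classes, not part of the patch.

    // Editorial sketch of the `fixedElementSize` selection above; 18 matches Decimal.MAX_LONG_DIGITS.
    sealed trait SimpleType { def defaultSize: Int }
    case object IntTypeLike extends SimpleType { val defaultSize = 4 }
    case object DoubleTypeLike extends SimpleType { val defaultSize = 8 }
    case class DecimalTypeLike(precision: Int) extends SimpleType { val defaultSize = 16 }
    case object StringTypeLike extends SimpleType { val defaultSize = 20 }

    object FixedElementSize {
      private val MaxLongDigits = 18

      def apply(t: SimpleType): Int = t match {
        case DecimalTypeLike(p) if p <= MaxLongDigits => 8             // compact decimal: one long per element
        case IntTypeLike | DoubleTypeLike             => t.defaultSize // primitive: stored inline
        case _                                        => 0             // variable-length: stored out of line
      }

      def main(args: Array[String]): Unit = {
        Seq(IntTypeLike, DecimalTypeLike(10), DecimalTypeLike(38), StringTypeLike)
          .foreach(t => println(s"$t -> fixed element size ${FixedElementSize(t)}"))
      }
    }

When the incoming value is already an UnsafeArrayData, the generated code skips the per-element loop entirely and copies the existing bytes through writeUnsafeData.
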
+ private def writeMapToBuffer( ctx: CodeGenContext, - input: GeneratedExpressionCode, + input: String, keyType: DataType, - valueType: DataType): GeneratedExpressionCode = { - val output = ctx.freshName("convertedMap") - val outputIsNull = ctx.freshName("isNull") - val tmp = ctx.freshName("tmp") - - val keyArray = GeneratedExpressionCode( - code = "", - isNull = "false", - primitive = s"$tmp.keyArray()" - ) - val valueArray = GeneratedExpressionCode( - code = "", - isNull = "false", - primitive = s"$tmp.valueArray()" - ) - val convertedKeys: GeneratedExpressionCode = createCodeForArray(ctx, keyArray, keyType) - val convertedValues: GeneratedExpressionCode = createCodeForArray(ctx, valueArray, valueType) + valueType: DataType, + bufferHolder: String): String = { + val keys = ctx.freshName("keys") + val values = ctx.freshName("values") + val tmpCursor = ctx.freshName("tmpCursor") - val code = s""" - ${input.code} - final boolean $outputIsNull = ${input.isNull}; - UnsafeMapData $output = null; - if (!$outputIsNull) { - final MapData $tmp = ${input.primitive}; - if ($tmp instanceof UnsafeMapData) { - $output = (UnsafeMapData) $tmp; - } else { - ${convertedKeys.code} - ${convertedValues.code} - $output = new UnsafeMapData(${convertedKeys.primitive}, ${convertedValues.primitive}); - } + + // Writes out unsafe map according to the format described in `UnsafeMapData`. + s""" + if ($input instanceof UnsafeMapData) { + ${writeUnsafeData(ctx, s"((UnsafeMapData) $input)", bufferHolder)} + } else { + final ArrayData $keys = $input.keyArray(); + final ArrayData $values = $input.valueArray(); + + // preserve 4 bytes to write the key array numBytes later. + $bufferHolder.grow(4); + $bufferHolder.cursor += 4; + + // Remember the current cursor so that we can write numBytes of key array later. + final int $tmpCursor = $bufferHolder.cursor; + + ${writeArrayToBuffer(ctx, keys, keyType, bufferHolder)} + // Write the numBytes of key array into the first 4 bytes. + Platform.putInt($bufferHolder.buffer, $tmpCursor - 4, $bufferHolder.cursor - $tmpCursor); + + ${writeArrayToBuffer(ctx, values, valueType, bufferHolder)} } - """ - GeneratedExpressionCode(code, outputIsNull, output) + """ } /** - * Generates the java code to convert a data to its unsafe version. + * If the input is already in unsafe format, we don't need to go through all elements/fields, + * we can directly write it. */ - private def createConvertCode( + private def writeUnsafeData(ctx: CodeGenContext, input: String, bufferHolder: String) = { + val sizeInBytes = ctx.freshName("sizeInBytes") + s""" + final int $sizeInBytes = $input.getSizeInBytes(); + // grow the global buffer before writing data. 
+ $bufferHolder.grow($sizeInBytes); + $input.writeToMemory($bufferHolder.buffer, $bufferHolder.cursor); + $bufferHolder.cursor += $sizeInBytes; + """ + } + + def createCode( ctx: CodeGenContext, - input: GeneratedExpressionCode, - dataType: DataType): GeneratedExpressionCode = dataType match { - case t: StructType => - val output = ctx.freshName("convertedStruct") - val outputIsNull = ctx.freshName("isNull") - val tmp = ctx.freshName("tmp") - val fieldTypes = t.fields.map(_.dataType) - val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => - val getFieldCode = ctx.getValue(tmp, dt, i.toString) - val fieldIsNull = s"$tmp.isNullAt($i)" - GeneratedExpressionCode("", fieldIsNull, getFieldCode) - } - val converter = createCodeForStruct2(ctx, fieldEvals, fieldTypes) - val code = s""" - ${input.code} - UnsafeRow $output = null; - final boolean $outputIsNull = ${input.isNull}; - if (!$outputIsNull) { - final InternalRow $tmp = ${input.primitive}; - if ($tmp instanceof UnsafeRow) { - $output = (UnsafeRow) $tmp; - } else { - ${converter.code} - $output = ${converter.primitive}; - } - } - """ - GeneratedExpressionCode(code, outputIsNull, output) + expressions: Seq[Expression], + useSubexprElimination: Boolean = false): GeneratedExpressionCode = { + val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination) + val exprTypes = expressions.map(_.dataType) + + val result = ctx.freshName("result") + ctx.addMutableState("UnsafeRow", result, s"this.$result = new UnsafeRow();") + val bufferHolder = ctx.freshName("bufferHolder") + val holderClass = classOf[BufferHolder].getName + ctx.addMutableState(holderClass, bufferHolder, s"this.$bufferHolder = new $holderClass();") - case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType) + // Reset the subexpression values for each row. 
+ val subexprReset = ctx.subExprResetVariables.mkString("\n") - case MapType(kt, vt, _) => createCodeForMap(ctx, input, kt, vt) + val code = + s""" + $bufferHolder.reset(); + $subexprReset + ${writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, bufferHolder)} - case _ => input + $result.pointTo($bufferHolder.buffer, ${expressions.length}, $bufferHolder.totalSize()); + """ + GeneratedExpressionCode(code, "false", result) } protected def canonicalize(in: Seq[Expression]): Seq[Expression] = @@ -567,14 +307,24 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = in.map(BindReferences.bindReference(_, inputSchema)) + def generate( + expressions: Seq[Expression], + subexpressionEliminationEnabled: Boolean): UnsafeProjection = { + create(canonicalize(expressions), subexpressionEliminationEnabled) + } + protected def create(expressions: Seq[Expression]): UnsafeProjection = { - val ctx = newCodeGenContext() + create(expressions, subexpressionEliminationEnabled = false) + } - val exprEvals = expressions.map(e => e.gen(ctx)) - val eval = createCodeForStruct2(ctx, exprEvals, expressions.map(_.dataType)) + private def create( + expressions: Seq[Expression], + subexpressionEliminationEnabled: Boolean): UnsafeProjection = { + val ctx = newCodeGenContext() + val eval = createCode(ctx, expressions, subexpressionEliminationEnabled) val code = s""" - public Object generate($exprType[] exprs) { + public java.lang.Object generate($exprType[] exprs) { return new SpecificUnsafeProjection(exprs); } @@ -583,6 +333,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro private $exprType[] expressions; ${declareMutableStates(ctx)} + ${declareAddedFunctions(ctx)} public SpecificUnsafeProjection($exprType[] expressions) { @@ -591,13 +342,13 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } // Scala.Function1 need this - public Object apply(Object row) { + public java.lang.Object apply(java.lang.Object row) { return apply((InternalRow) row); } - public UnsafeRow apply(InternalRow i) { - ${eval.code} - return ${eval.primitive}; + public UnsafeRow apply(InternalRow ${ctx.INPUT_ROW}) { + ${eval.code.trim} + return ${eval.value}; } } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala index 30b51dd83fa9..da602d9b4bce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, Attribute} import org.apache.spark.sql.types.StructType -import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.Platform abstract class UnsafeRowJoiner { @@ -52,9 +52,9 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U } def create(schema1: StructType, schema2: StructType): UnsafeRowJoiner = { - val offset = PlatformDependent.BYTE_ARRAY_OFFSET - val getLong = "PlatformDependent.UNSAFE.getLong" - val putLong = "PlatformDependent.UNSAFE.putLong" + val offset = Platform.BYTE_ARRAY_OFFSET + val getLong = 
"Platform.getLong" + val putLong = "Platform.putLong" val bitset1Words = (schema1.size + 63) / 64 val bitset2Words = (schema2.size + 63) / 64 @@ -96,7 +96,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U var cursor = offset + outputBitsetWords * 8 val copyFixedLengthRow1 = s""" |// Copy fixed length data for row1 - |PlatformDependent.copyMemory( + |Platform.copyMemory( | obj1, offset1 + ${bitset1Words * 8}, | buf, $cursor, | ${schema1.size * 8}); @@ -106,7 +106,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U // --------------------- copy fixed length portion from row 2 ----------------------- // val copyFixedLengthRow2 = s""" |// Copy fixed length data for row2 - |PlatformDependent.copyMemory( + |Platform.copyMemory( | obj2, offset2 + ${bitset2Words * 8}, | buf, $cursor, | ${schema2.size * 8}); @@ -118,7 +118,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U val copyVariableLengthRow1 = s""" |// Copy variable length data for row1 |long numBytesVariableRow1 = row1.getSizeInBytes() - $numBytesBitsetAndFixedRow1; - |PlatformDependent.copyMemory( + |Platform.copyMemory( | obj1, offset1 + ${(bitset1Words + schema1.size) * 8}, | buf, $cursor, | numBytesVariableRow1); @@ -129,7 +129,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U val copyVariableLengthRow2 = s""" |// Copy variable length data for row2 |long numBytesVariableRow2 = row2.getSizeInBytes() - $numBytesBitsetAndFixedRow2; - |PlatformDependent.copyMemory( + |Platform.copyMemory( | obj2, offset2 + ${(bitset2Words + schema2.size) * 8}, | buf, $cursor + numBytesVariableRow1, | numBytesVariableRow2); @@ -155,11 +155,11 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U |$putLong(buf, $cursor, $getLong(buf, $cursor) + ($shift << 32)); """.stripMargin } - }.mkString + }.mkString("\n") // ------------------------ Finally, put everything together --------------------------- // val code = s""" - |public Object generate($exprType[] exprs) { + |public java.lang.Object generate($exprType[] exprs) { | return new SpecificUnsafeRowJoiner(); |} | @@ -176,9 +176,9 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U | buf = new byte[sizeInBytes]; | } | - | final Object obj1 = row1.getBaseObject(); + | final java.lang.Object obj1 = row1.getBaseObject(); | final long offset1 = row1.getBaseOffset(); - | final Object obj2 = row2.getBaseObject(); + | final java.lang.Object obj2 = row2.getBaseObject(); | final long offset2 = row2.getBaseOffset(); | | $copyBitset diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala index 606fecbe06e4..41128fe389d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.rules import org.apache.spark.util.Utils @@ -40,10 +39,8 @@ package object codegen { } /** - * :: DeveloperApi :: * Dumps the bytecode from a class to the screen using javap. 
*/ - @DeveloperApi object DumpByteCode { import scala.sys.process._ val dumpDirectory = Utils.createTempDir() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 6ccb56578f79..741ad1f3efd8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Comparator import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, CodeGenContext, GeneratedExpressionCode} -import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, CodegenFallback, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.util.{MapData, GenericArrayData, ArrayData} import org.apache.spark.sql.types._ /** @@ -36,7 +36,7 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, c => s"${ev.primitive} = ($c).numElements();") + nullSafeCodeGen(ctx, ev, c => s"${ev.value} = ($c).numElements();") } } @@ -68,6 +68,8 @@ case class SortArray(base: Expression, ascendingOrder: Expression) private lazy val lt: Comparator[Any] = { val ordering = base.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] + case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] + case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] } new Comparator[Any]() { @@ -89,6 +91,8 @@ case class SortArray(base: Expression, ascendingOrder: Expression) private lazy val gt: Comparator[Any] = { val ordering = base.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] + case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] + case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] } new Comparator[Any]() { @@ -109,9 +113,80 @@ case class SortArray(base: Expression, ascendingOrder: Expression) override def nullSafeEval(array: Any, ascending: Any): Any = { val elementType = base.dataType.asInstanceOf[ArrayType].elementType val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType) - java.util.Arrays.sort(data, if (ascending.asInstanceOf[Boolean]) lt else gt) + if (elementType != NullType) { + java.util.Arrays.sort(data, if (ascending.asInstanceOf[Boolean]) lt else gt) + } new GenericArrayData(data.asInstanceOf[Array[Any]]) } override def prettyName: String = "sort_array" } + +/** + * Checks if the array (left) has the element (right) + */ +case class ArrayContains(left: Expression, right: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = BooleanType + + override def inputTypes: Seq[AbstractDataType] = right.dataType match { + case NullType => Seq() + case _ => left.dataType match { + case n @ ArrayType(element, _) => Seq(n, element) + case _ => Seq() + } + } + + override def checkInputDataTypes(): TypeCheckResult = { + if (right.dataType == NullType) { + TypeCheckResult.TypeCheckFailure("Null typed values 
cannot be used as arguments") + } else if (!left.dataType.isInstanceOf[ArrayType] + || left.dataType.asInstanceOf[ArrayType].elementType != right.dataType) { + TypeCheckResult.TypeCheckFailure( + "Arguments must be an array followed by a value of same type as the array members") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + + override def nullable: Boolean = { + left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull + } + + override def nullSafeEval(arr: Any, value: Any): Any = { + var hasNull = false + arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => + if (v == null) { + hasNull = true + } else if (v == value) { + return true + } + ) + if (hasNull) { + null + } else { + false + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (arr, value) => { + val i = ctx.freshName("i") + val getValue = ctx.getValue(arr, right.dataType, i) + s""" + for (int $i = 0; $i < $arr.numElements(); $i ++) { + if ($arr.isNullAt($i)) { + ${ev.isNull} = true; + } else if (${ctx.genEqual(right.dataType, value, getValue)}) { + ${ev.isNull} = false; + ${ev.value} = true; + break; + } + } + """ + }) + } + + override def prettyName: String = "array_contains" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index a145dfb4bbf0..72cc89c8be91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -21,7 +21,7 @@ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.catalyst.util.{GenericArrayData, TypeUtils} import org.apache.spark.sql.types._ /** @@ -48,21 +48,22 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val arrayClass = classOf[GenericArrayData].getName + val values = ctx.freshName("values") s""" final boolean ${ev.isNull} = false; - final Object[] values = new Object[${children.size}]; + final Object[] $values = new Object[${children.size}]; """ + children.zipWithIndex.map { case (e, i) => val eval = e.gen(ctx) eval.code + s""" if (${eval.isNull}) { - values[$i] = null; + $values[$i] = null; } else { - values[$i] = ${eval.primitive}; + $values[$i] = ${eval.value}; } """ }.mkString("\n") + - s"final ${ctx.javaType(dataType)} ${ev.primitive} = new $arrayClass(values);" + s"final ArrayData ${ev.value} = new $arrayClass($values);" } override def prettyName: String = "array" @@ -75,8 +76,6 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) - override lazy val resolved: Boolean = childrenResolved - override lazy val dataType: StructType = { val fields = children.zipWithIndex.map { case (child, idx) => child match { @@ -96,21 +95,23 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val rowClass = classOf[GenericMutableRow].getName + val rowClass = 
classOf[GenericInternalRow].getName + val values = ctx.freshName("values") s""" boolean ${ev.isNull} = false; - final $rowClass ${ev.primitive} = new $rowClass(${children.size}); + final Object[] $values = new Object[${children.size}]; """ + children.zipWithIndex.map { case (e, i) => val eval = e.gen(ctx) eval.code + s""" if (${eval.isNull}) { - ${ev.primitive}.update($i, null); + $values[$i] = null; } else { - ${ev.primitive}.update($i, ${eval.primitive}); + $values[$i] = ${eval.value}; } """ - }.mkString("\n") + }.mkString("\n") + + s"final InternalRow ${ev.value} = new $rowClass($values);" } override def prettyName: String = "struct" @@ -124,6 +125,14 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { */ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { + /** + * Returns Aliased [[Expression]]s that could be used to construct a flattened version of this + * StructType. + */ + def flatten: Seq[NamedExpression] = valExprs.zip(names).map { + case (v, n) => Alias(v, n.toString)() + } + private lazy val (nameExprs, valExprs) = children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip @@ -163,21 +172,23 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val rowClass = classOf[GenericMutableRow].getName + val rowClass = classOf[GenericInternalRow].getName + val values = ctx.freshName("values") s""" boolean ${ev.isNull} = false; - final $rowClass ${ev.primitive} = new $rowClass(${valExprs.size}); + final Object[] $values = new Object[${valExprs.size}]; """ + valExprs.zipWithIndex.map { case (e, i) => val eval = e.gen(ctx) eval.code + s""" if (${eval.isNull}) { - ${ev.primitive}.update($i, null); + $values[$i] = null; } else { - ${ev.primitive}.update($i, ${eval.primitive}); + $values[$i] = ${eval.value}; } """ - }.mkString("\n") + }.mkString("\n") + + s"final InternalRow ${ev.value} = new $rowClass($values);" } override def prettyName: String = "named_struct" @@ -208,10 +219,15 @@ case class CreateStructUnsafe(children: Seq[Expression]) extends Expression { override def nullable: Boolean = false - override def eval(input: InternalRow): Any = throw new UnsupportedOperationException + override def eval(input: InternalRow): Any = { + InternalRow(children.map(_.eval(input)): _*) + } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - GenerateUnsafeProjection.createCode(ctx, ev, children) + val eval = GenerateUnsafeProjection.createCode(ctx, children) + ev.isNull = eval.isNull + ev.value = eval.value + eval.code } override def prettyName: String = "struct_unsafe" @@ -243,10 +259,15 @@ case class CreateNamedStructUnsafe(children: Seq[Expression]) extends Expression override def nullable: Boolean = false - override def eval(input: InternalRow): Any = throw new UnsupportedOperationException + override def eval(input: InternalRow): Any = { + InternalRow(valExprs.map(_.eval(input)): _*) + } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - GenerateUnsafeProjection.createCode(ctx, ev, valExprs) + val eval = GenerateUnsafeProjection.createCode(ctx, valExprs) + ev.isNull = eval.isNull + ev.value = eval.value + eval.code } override def prettyName: String = "named_struct_unsafe" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 9927da21b052..58f6a7ec8a5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.catalyst.util.{MapData, GenericArrayData, ArrayData} import org.apache.spark.sql.types._ //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -50,7 +51,7 @@ object ExtractValue { case (StructType(fields), NonNullLiteral(v, StringType)) => val fieldName = v.toString val ordinal = findField(fields, fieldName, resolver) - GetStructField(child, fields(ordinal).copy(name = fieldName), ordinal) + GetStructField(child, ordinal, Some(fieldName)) case (ArrayType(StructType(fields), containsNull), NonNullLiteral(v, StringType)) => val fieldName = v.toString @@ -96,16 +97,21 @@ object ExtractValue { * Returns the value of fields in the Struct `child`. * * No need to do type checking since it is handled by [[ExtractValue]]. + * + * Note that we can pass in the field name directly to keep case preserving in `toString`. + * For example, when get field `yEAr` from ``, we should pass in `yEAr`. */ -case class GetStructField(child: Expression, field: StructField, ordinal: Int) +case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None) extends UnaryExpression { - override def dataType: DataType = field.dataType - override def nullable: Boolean = child.nullable || field.nullable - override def toString: String = s"$child.${field.name}" + private[sql] lazy val childSchema = child.dataType.asInstanceOf[StructType] + + override def dataType: DataType = childSchema(ordinal).dataType + override def nullable: Boolean = child.nullable || childSchema(ordinal).nullable + override def toString: String = s"$child.${name.getOrElse(childSchema(ordinal).name)}" protected override def nullSafeEval(input: Any): Any = - input.asInstanceOf[InternalRow].get(ordinal, field.dataType) + input.asInstanceOf[InternalRow].get(ordinal, childSchema(ordinal).dataType) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { nullSafeCodeGen(ctx, ev, eval => { @@ -113,7 +119,7 @@ case class GetStructField(child: Expression, field: StructField, ordinal: Int) if ($eval.isNullAt($ordinal)) { ${ev.isNull} = true; } else { - ${ev.primitive} = ${ctx.getValue(eval, dataType, ordinal.toString)}; + ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)}; } """ }) @@ -175,7 +181,7 @@ case class GetArrayStructFields( } } } - ${ev.primitive} = new $arrayClass(values); + ${ev.value} = new $arrayClass(values); """ }) } @@ -219,7 +225,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression) if (index >= $eval1.numElements() || index < 0) { ${ev.isNull} = true; } else { - ${ev.primitive} = ${ctx.getValue(eval1, dataType, "index")}; + ${ev.value} = ${ctx.getValue(eval1, dataType, "index")}; } """ }) @@ -295,7 +301,7 @@ case class GetMapValue(child: Expression, key: Expression) } if ($found) { - ${ev.primitive} = ${ctx.getValue(eval1 + ".valueArray()", dataType, index)}; + ${ev.value} = ${ctx.getValue(eval1 
+ ".valueArray()", dataType, index)}; } else { ${ev.isNull} = true; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala similarity index 87% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 961b1d861680..40b1eec63e55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.TypeUtils -import org.apache.spark.sql.types.{NullType, BooleanType, DataType} +import org.apache.spark.sql.types._ case class If(predicate: Expression, trueValue: Expression, falseValue: Expression) @@ -60,15 +60,15 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi s""" ${condEval.code} boolean ${ev.isNull} = false; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${condEval.isNull} && ${condEval.primitive}) { + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${condEval.isNull} && ${condEval.value}) { ${trueEval.code} ${ev.isNull} = ${trueEval.isNull}; - ${ev.primitive} = ${trueEval.primitive}; + ${ev.value} = ${trueEval.value}; } else { ${falseEval.code} ${ev.isNull} = ${falseEval.isNull}; - ${ev.primitive} = ${falseEval.primitive}; + ${ev.value} = ${falseEval.value}; } """ } @@ -166,11 +166,11 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { s""" if (!$got) { ${cond.code} - if (!${cond.isNull} && ${cond.primitive}) { + if (!${cond.isNull} && ${cond.value}) { $got = true; ${res.code} ${ev.isNull} = ${res.isNull}; - ${ev.primitive} = ${res.primitive}; + ${ev.value} = ${res.value}; } } """ @@ -182,7 +182,7 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { if (!$got) { ${res.code} ${ev.isNull} = ${res.isNull}; - ${ev.primitive} = ${res.primitive}; + ${ev.value} = ${res.value}; } """ } else { @@ -192,7 +192,7 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { s""" boolean $got = false; boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; $cases $other """ @@ -267,11 +267,11 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW s""" if (!$got) { ${cond.code} - if (!${cond.isNull} && ${ctx.genEqual(key.dataType, keyEval.primitive, cond.primitive)}) { + if (!${cond.isNull} && ${ctx.genEqual(key.dataType, keyEval.value, cond.value)}) { $got = true; ${res.code} ${ev.isNull} = ${res.isNull}; - ${ev.primitive} = ${res.primitive}; + ${ev.value} = ${res.value}; } } """ @@ -283,7 +283,7 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW if (!$got) { ${res.code} ${ev.isNull} = ${res.isNull}; - ${ev.primitive} = ${res.primitive}; + ${ev.value} = ${res.value}; } """ } else { @@ -293,7 +293,7 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW s""" boolean $got = 
false; boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; ${keyEval.code} if (!${keyEval.isNull}) { $cases @@ -319,7 +319,7 @@ case class Least(children: Seq[Expression]) extends Expression { override def nullable: Boolean = children.forall(_.nullable) override def foldable: Boolean = children.forall(_.foldable) - private lazy val ordering = TypeUtils.getOrdering(dataType) + private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) override def checkInputDataTypes(): TypeCheckResult = { if (children.length <= 1) { @@ -348,19 +348,22 @@ case class Least(children: Seq[Expression]) extends Expression { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val evalChildren = children.map(_.gen(ctx)) - def updateEval(i: Int): String = + val first = evalChildren(0) + val rest = evalChildren.drop(1) + def updateEval(eval: GeneratedExpressionCode): String = s""" - if (!${evalChildren(i).isNull} && (${ev.isNull} || - ${ctx.genComp(dataType, evalChildren(i).primitive, ev.primitive)} < 0)) { + ${eval.code} + if (!${eval.isNull} && (${ev.isNull} || + ${ctx.genGreater(dataType, ev.value, eval.value)})) { ${ev.isNull} = false; - ${ev.primitive} = ${evalChildren(i).primitive}; + ${ev.value} = ${eval.value}; } """ s""" - ${evalChildren.map(_.code).mkString("\n")} - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - ${children.indices.map(updateEval).mkString("\n")} + ${first.code} + boolean ${ev.isNull} = ${first.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${first.value}; + ${rest.map(updateEval).mkString("\n")} """ } } @@ -374,7 +377,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def nullable: Boolean = children.forall(_.nullable) override def foldable: Boolean = children.forall(_.foldable) - private lazy val ordering = TypeUtils.getOrdering(dataType) + private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) override def checkInputDataTypes(): TypeCheckResult = { if (children.length <= 1) { @@ -403,19 +406,23 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val evalChildren = children.map(_.gen(ctx)) - def updateEval(i: Int): String = + val first = evalChildren(0) + val rest = evalChildren.drop(1) + def updateEval(eval: GeneratedExpressionCode): String = s""" - if (!${evalChildren(i).isNull} && (${ev.isNull} || - ${ctx.genComp(dataType, evalChildren(i).primitive, ev.primitive)} > 0)) { + ${eval.code} + if (!${eval.isNull} && (${ev.isNull} || + ${ctx.genGreater(dataType, eval.value, ev.value)})) { ${ev.isNull} = false; - ${ev.primitive} = ${evalChildren(i).primitive}; + ${ev.value} = ${eval.value}; } """ s""" - ${evalChildren.map(_.code).mkString("\n")} - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - ${children.indices.map(updateEval).mkString("\n")} + ${first.code} + boolean ${ev.isNull} = ${first.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${first.value}; + ${rest.map(updateEval).mkString("\n")} """ } } + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala similarity index 91% rename from 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 32dc9b76821b..03c39f8404e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -82,7 +82,7 @@ case class DateAdd(startDate: Expression, days: Expression) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { nullSafeCodeGen(ctx, ev, (sd, d) => { - s"""${ev.primitive} = $sd + $d;""" + s"""${ev.value} = $sd + $d;""" }) } } @@ -105,7 +105,7 @@ case class DateSub(startDate: Expression, days: Expression) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { nullSafeCodeGen(ctx, ev, (sd, d) => { - s"""${ev.primitive} = $sd - $d;""" + s"""${ev.value} = $sd - $d;""" }) } } @@ -269,7 +269,7 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa """) s""" $c.setTimeInMillis($time * 1000L * 3600L * 24L); - ${ev.primitive} = $c.get($cal.WEEK_OF_YEAR); + ${ev.value} = $c.get($cal.WEEK_OF_YEAR); """ }) } @@ -299,7 +299,20 @@ case class DateFormatClass(left: Expression, right: Expression) extends BinaryEx } /** - * Converts time string with given pattern + * Converts time string with given pattern. + * Deterministic version of [[UnixTimestamp]], must have at least one parameter. + */ +case class ToUnixTimestamp(timeExp: Expression, format: Expression) extends UnixTime { + override def left: Expression = timeExp + override def right: Expression = format + + def this(time: Expression) = { + this(time, Literal("yyyy-MM-dd HH:mm:ss")) + } +} + +/** + * Converts time string with given pattern. * (see [http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html]) * to Unix time stamp (in seconds), returns null if fail. * Note that hive Language Manual says it returns 0 if fail, but in fact it returns null. @@ -308,9 +321,7 @@ case class DateFormatClass(left: Expression, right: Expression) extends BinaryEx * If the first parameter is a Date or Timestamp instead of String, we will ignore the * second parameter. 
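
ToUnixTimestamp above is the deterministic variant of UnixTimestamp; both share the UnixTime code path, which parses the string with SimpleDateFormat and divides the millisecond result down to seconds, producing null on any parse failure (see the generated code further below). A minimal plain-Scala sketch of that conversion, outside the Catalyst codegen path, with failures surfaced as Option instead of null:

    import java.text.SimpleDateFormat

    // Editorial sketch: string-to-unix-seconds conversion in the spirit of UnixTime's
    // evaluation; the default format literal matches the one used by the patch.
    object ToUnixSecondsSketch {
      def toUnixSeconds(time: String, format: String = "yyyy-MM-dd HH:mm:ss"): Option[Long] =
        try {
          Some(new SimpleDateFormat(format).parse(time).getTime / 1000L)
        } catch {
          case _: java.text.ParseException => None
        }

      def main(args: Array[String]): Unit = {
        println(toUnixSeconds("1970-01-01 00:00:10"))   // Some(10) when the JVM time zone is UTC
        println(toUnixSeconds("not a timestamp"))       // None
      }
    }
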
*/ -case class UnixTimestamp(timeExp: Expression, format: Expression) - extends BinaryExpression with ExpectsInputTypes { - +case class UnixTimestamp(timeExp: Expression, format: Expression) extends UnixTime { override def left: Expression = timeExp override def right: Expression = format @@ -321,6 +332,9 @@ case class UnixTimestamp(timeExp: Expression, format: Expression) def this() = { this(CurrentTimestamp()) } +} + +abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, DateType, TimestampType), StringType) @@ -347,7 +361,7 @@ case class UnixTimestamp(timeExp: Expression, format: Expression) null } case StringType => - val f = format.eval(input) + val f = right.eval(input) if (f == null) { null } else { @@ -368,19 +382,19 @@ case class UnixTimestamp(timeExp: Expression, format: Expression) if (fString == null) { s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; """ } else { val eval1 = left.gen(ctx) s""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { try { $sdf $formatter = new $sdf("$fString"); - ${ev.primitive} = - $formatter.parse(${eval1.primitive}.toString()).getTime() / 1000L; + ${ev.value} = + $formatter.parse(${eval1.value}.toString()).getTime() / 1000L; } catch (java.lang.Throwable e) { ${ev.isNull} = true; } @@ -392,7 +406,7 @@ case class UnixTimestamp(timeExp: Expression, format: Expression) nullSafeCodeGen(ctx, ev, (string, format) => { s""" try { - ${ev.primitive} = + ${ev.value} = (new $sdf($format.toString())).parse($string.toString()).getTime() / 1000L; } catch (java.lang.Throwable e) { ${ev.isNull} = true; @@ -404,9 +418,9 @@ case class UnixTimestamp(timeExp: Expression, format: Expression) s""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { - ${ev.primitive} = ${eval1.primitive} / 1000000L; + ${ev.value} = ${eval1.value} / 1000000L; } """ case DateType => @@ -415,9 +429,9 @@ case class UnixTimestamp(timeExp: Expression, format: Expression) s""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { - ${ev.primitive} = $dtu.daysToMillis(${eval1.primitive}) / 1000L; + ${ev.value} = $dtu.daysToMillis(${eval1.value}) / 1000L; } """ } @@ -477,18 +491,18 @@ case class FromUnixTime(sec: Expression, format: Expression) if (constFormat == null) { s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; """ } else { val t = left.gen(ctx) s""" ${t.code} boolean ${ev.isNull} = ${t.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { try { - ${ev.primitive} = UTF8String.fromString(new $sdf("${constFormat.toString}").format( - new java.util.Date(${t.primitive} * 1000L))); + ${ev.value} = 
UTF8String.fromString(new $sdf("${constFormat.toString}").format( + new java.util.Date(${t.value} * 1000L))); } catch (java.lang.Throwable e) { ${ev.isNull} = true; } @@ -499,7 +513,7 @@ case class FromUnixTime(sec: Expression, format: Expression) nullSafeCodeGen(ctx, ev, (seconds, f) => { s""" try { - ${ev.primitive} = UTF8String.fromString((new $sdf($f.toString())).format( + ${ev.value} = UTF8String.fromString((new $sdf($f.toString())).format( new java.util.Date($seconds * 1000L))); } catch (java.lang.Throwable e) { ${ev.isNull} = true; @@ -571,7 +585,7 @@ case class NextDay(startDate: Expression, dayOfWeek: Expression) } else { val dayOfWeekValue = DateTimeUtils.getDayOfWeekFromString(input) s""" - |${ev.primitive} = $dateTimeUtilClass.getNextDateForDayOfWeek($sd, $dayOfWeekValue); + |${ev.value} = $dateTimeUtilClass.getNextDateForDayOfWeek($sd, $dayOfWeekValue); """.stripMargin } } else { @@ -580,7 +594,7 @@ case class NextDay(startDate: Expression, dayOfWeek: Expression) |if ($dayOfWeekTerm == -1) { | ${ev.isNull} = true; |} else { - | ${ev.primitive} = $dateTimeUtilClass.getNextDateForDayOfWeek($sd, $dayOfWeekTerm); + | ${ev.value} = $dateTimeUtilClass.getNextDateForDayOfWeek($sd, $dayOfWeekTerm); |} """.stripMargin } @@ -640,7 +654,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression) if (tz == null) { s""" |boolean ${ev.isNull} = true; - |long ${ev.primitive} = 0; + |long ${ev.value} = 0; """.stripMargin } else { val tzTerm = ctx.freshName("tz") @@ -650,10 +664,10 @@ case class FromUTCTimestamp(left: Expression, right: Expression) s""" |${eval.code} |boolean ${ev.isNull} = ${eval.isNull}; - |long ${ev.primitive} = 0; + |long ${ev.value} = 0; |if (!${ev.isNull}) { - | ${ev.primitive} = ${eval.primitive} + - | ${tzTerm}.getOffset(${eval.primitive} / 1000) * 1000L; + | ${ev.value} = ${eval.value} + + | ${tzTerm}.getOffset(${eval.value} / 1000) * 1000L; |} """.stripMargin } @@ -765,7 +779,7 @@ case class ToUTCTimestamp(left: Expression, right: Expression) if (tz == null) { s""" |boolean ${ev.isNull} = true; - |long ${ev.primitive} = 0; + |long ${ev.value} = 0; """.stripMargin } else { val tzTerm = ctx.freshName("tz") @@ -775,10 +789,10 @@ case class ToUTCTimestamp(left: Expression, right: Expression) s""" |${eval.code} |boolean ${ev.isNull} = ${eval.isNull}; - |long ${ev.primitive} = 0; + |long ${ev.value} = 0; |if (!${ev.isNull}) { - | ${ev.primitive} = ${eval.primitive} - - | ${tzTerm}.getOffset(${eval.primitive} / 1000) * 1000L; + | ${ev.value} = ${eval.value} - + | ${tzTerm}.getOffset(${eval.value} / 1000) * 1000L; |} """.stripMargin } @@ -849,16 +863,16 @@ case class TruncDate(date: Expression, format: Expression) if (truncLevel == -1) { s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; """ } else { val d = date.gen(ctx) s""" ${d.code} boolean ${ev.isNull} = ${d.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { - ${ev.primitive} = $dtu.truncDate(${d.primitive}, $truncLevel); + ${ev.value} = $dtu.truncDate(${d.value}, $truncLevel); } """ } @@ -870,7 +884,7 @@ case class TruncDate(date: Expression, format: Expression) if ($form == -1) { ${ev.isNull} = true; } else { - ${ev.primitive} = $dtu.truncDate($dateVal, $form); + ${ev.value} = $dtu.truncDate($dateVal, $form); } """ }) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala similarity index 69% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index adb33e4c8d4a..78f6631e4647 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -55,8 +55,8 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { nullSafeCodeGen(ctx, ev, eval => { s""" - ${ev.primitive} = (new Decimal()).setOrNull($eval, $precision, $scale); - ${ev.isNull} = ${ev.primitive} == null; + ${ev.value} = (new Decimal()).setOrNull($eval, $precision, $scale); + ${ev.isNull} = ${ev.value} == null; """ }) } @@ -66,10 +66,44 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un * An expression used to wrap the children when promote the precision of DecimalType to avoid * promote multiple times. */ -case class ChangeDecimalPrecision(child: Expression) extends UnaryExpression { +case class PromotePrecision(child: Expression) extends UnaryExpression { override def dataType: DataType = child.dataType override def eval(input: InternalRow): Any = child.eval(input) override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx) override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = "" - override def prettyName: String = "change_decimal_precision" + override def prettyName: String = "promote_precision" +} + +/** + * Rounds the decimal to given scale and check whether the decimal can fit in provided precision + * or not, returns null if not. 
+ */ +case class CheckOverflow(child: Expression, dataType: DecimalType) extends UnaryExpression { + + override def nullable: Boolean = true + + override def nullSafeEval(input: Any): Any = { + val d = input.asInstanceOf[Decimal].clone() + if (d.changePrecision(dataType.precision, dataType.scale)) { + d + } else { + null + } + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, eval => { + val tmp = ctx.freshName("tmp") + s""" + | Decimal $tmp = $eval.clone(); + | if ($tmp.changePrecision(${dataType.precision}, ${dataType.scale})) { + | ${ev.value} = $tmp; + | } else { + | ${ev.isNull} = true; + | } + """.stripMargin + }) + } + + override def toString: String = s"CheckOverflow($child, $dataType)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index d474853355e5..894a0730d1c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -17,9 +17,8 @@ package org.apache.spark.sql.catalyst.expressions -import scala.collection.Map - import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.util.{MapData, ArrayData} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback @@ -54,7 +53,7 @@ trait Generator extends Expression { * The output element data types in structure of Seq[(DataType, Nullable)] * TODO we probably need to add more information like metadata etc. */ - def elementTypes: Seq[(DataType, Boolean)] + def elementTypes: Seq[(DataType, Boolean, String)] /** Should be implemented by child classes to perform specific Generators. */ override def eval(input: InternalRow): TraversableOnce[InternalRow] @@ -70,7 +69,7 @@ trait Generator extends Expression { * A generator that produces its output using the provided lambda function. 
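
CheckOverflow, added above, rounds a decimal to the target scale and yields null when the rounded value no longer fits the target precision. The same contract can be sketched with java.math.BigDecimal standing in for Spark's Decimal; the rounding mode and precision test below are assumptions of this illustration, not a quote of Decimal.changePrecision.

    import java.math.{BigDecimal => JBigDecimal, RoundingMode}

    // Editorial sketch of the CheckOverflow contract: round to `scale`, then return null if the
    // result needs more total digits than `precision` allows.
    object CheckOverflowSketch {
      def checkOverflow(d: JBigDecimal, precision: Int, scale: Int): JBigDecimal = {
        val rounded = d.setScale(scale, RoundingMode.HALF_UP)
        if (rounded.precision <= precision) rounded else null
      }

      def main(args: Array[String]): Unit = {
        println(checkOverflow(new JBigDecimal("123.456"), 5, 2))  // 123.46 -- fits DECIMAL(5, 2)
        println(checkOverflow(new JBigDecimal("12345.6"), 5, 2))  // null   -- needs 7 digits
      }
    }
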
*/ case class UserDefinedGenerator( - elementTypes: Seq[(DataType, Boolean)], + elementTypes: Seq[(DataType, Boolean, String)], function: Row => TraversableOnce[InternalRow], children: Seq[Expression]) extends Generator with CodegenFallback { @@ -113,9 +112,11 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit } } - override def elementTypes: Seq[(DataType, Boolean)] = child.dataType match { - case ArrayType(et, containsNull) => (et, containsNull) :: Nil - case MapType(kt, vt, valueContainsNull) => (kt, false) :: (vt, valueContainsNull) :: Nil + // hive-compatible default alias for explode function ("col" for array, "key", "value" for map) + override def elementTypes: Seq[(DataType, Boolean, String)] = child.dataType match { + case ArrayType(et, containsNull) => (et, containsNull, "col") :: Nil + case MapType(kt, vt, valueContainsNull) => + (kt, false, "key") :: (vt, valueContainsNull, "value") :: Nil } override def eval(input: InternalRow): TraversableOnce[InternalRow] = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala new file mode 100644 index 000000000000..4991b9cb54e5 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -0,0 +1,460 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import java.io.{StringWriter, ByteArrayOutputStream} + +import com.fasterxml.jackson.core._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.types.{StructField, StructType, StringType, DataType} +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.Utils + +import scala.util.parsing.combinator.RegexParsers + +private[this] sealed trait PathInstruction +private[this] object PathInstruction { + private[expressions] case object Subscript extends PathInstruction + private[expressions] case object Wildcard extends PathInstruction + private[expressions] case object Key extends PathInstruction + private[expressions] case class Index(index: Long) extends PathInstruction + private[expressions] case class Named(name: String) extends PathInstruction +} + +private[this] sealed trait WriteStyle +private[this] object WriteStyle { + private[expressions] case object RawStyle extends WriteStyle + private[expressions] case object QuotedStyle extends WriteStyle + private[expressions] case object FlattenStyle extends WriteStyle +} + +private[this] object JsonPathParser extends RegexParsers { + import PathInstruction._ + + def root: Parser[Char] = '$' + + def long: Parser[Long] = "\\d+".r ^? { + case x => x.toLong + } + + // parse `[*]` and `[123]` subscripts + def subscript: Parser[List[PathInstruction]] = + for { + operand <- '[' ~> ('*' ^^^ Wildcard | long ^^ Index) <~ ']' + } yield { + Subscript :: operand :: Nil + } + + // parse `.name` or `['name']` child expressions + def named: Parser[List[PathInstruction]] = + for { + name <- '.' ~> "[^\\.\\[]+".r | "[\\'" ~> "[^\\'\\?]+" <~ "\\']" + } yield { + Key :: Named(name) :: Nil + } + + // child wildcards: `..`, `.*` or `['*']` + def wildcard: Parser[List[PathInstruction]] = + (".*" | "['*']") ^^^ List(Wildcard) + + def node: Parser[List[PathInstruction]] = + wildcard | + named | + subscript + + val expression: Parser[List[PathInstruction]] = { + phrase(root ~> rep(node) ^^ (x => x.flatten)) + } + + def parse(str: String): Option[List[PathInstruction]] = { + this.parseAll(expression, str) match { + case Success(result, _) => + Some(result) + + case NoSuccess(msg, next) => + None + } + } +} + +private[this] object SharedFactory { + val jsonFactory = new JsonFactory() + + // Enabled for Hive compatibility + jsonFactory.enable(JsonParser.Feature.ALLOW_UNQUOTED_CONTROL_CHARS) +} + +/** + * Extracts json object from a json string based on json path specified, and returns json string + * of the extracted json object. It will return null if the input json string is invalid. 
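An illustrative note (not part of the patch) on what the parser above produces; the sample path and JSON document are hypothetical.

    // JsonPathParser.parse("$.store.book[0].title") is expected to yield
    //   Some(List(Key, Named("store"), Key, Named("book"), Subscript, Index(0), Key, Named("title")))
    // while a path without the leading '$' fails to parse and yields None, so the expression
    // returns null. End to end,
    //   get_json_object('{"store": {"book": [{"title": "Spark"}]}}', '$.store.book[0].title')
    // should return the unquoted string Spark, and null for malformed JSON input.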
+ */ +case class GetJsonObject(json: Expression, path: Expression) + extends BinaryExpression with ExpectsInputTypes with CodegenFallback { + + import SharedFactory._ + import PathInstruction._ + import WriteStyle._ + import com.fasterxml.jackson.core.JsonToken._ + + override def left: Expression = json + override def right: Expression = path + override def inputTypes: Seq[DataType] = Seq(StringType, StringType) + override def dataType: DataType = StringType + override def prettyName: String = "get_json_object" + + @transient private lazy val parsedPath = parsePath(path.eval().asInstanceOf[UTF8String]) + + override def eval(input: InternalRow): Any = { + val jsonStr = json.eval(input).asInstanceOf[UTF8String] + if (jsonStr == null) { + return null + } + + val parsed = if (path.foldable) { + parsedPath + } else { + parsePath(path.eval(input).asInstanceOf[UTF8String]) + } + + if (parsed.isDefined) { + try { + Utils.tryWithResource(jsonFactory.createParser(jsonStr.getBytes)) { parser => + val output = new ByteArrayOutputStream() + val matched = Utils.tryWithResource( + jsonFactory.createGenerator(output, JsonEncoding.UTF8)) { generator => + parser.nextToken() + evaluatePath(parser, generator, RawStyle, parsed.get) + } + if (matched) { + UTF8String.fromBytes(output.toByteArray) + } else { + null + } + } + } catch { + case _: JsonProcessingException => null + } + } else { + null + } + } + + private def parsePath(path: UTF8String): Option[List[PathInstruction]] = { + if (path != null) { + JsonPathParser.parse(path.toString) + } else { + None + } + } + + // advance to the desired array index, assumes to start at the START_ARRAY token + private def arrayIndex(p: JsonParser, f: () => Boolean): Long => Boolean = { + case _ if p.getCurrentToken == END_ARRAY => + // terminate, nothing has been written + false + + case 0 => + // we've reached the desired index + val dirty = f() + + while (p.nextToken() != END_ARRAY) { + // advance the token stream to the end of the array + p.skipChildren() + } + + dirty + + case i if i > 0 => + // skip this token and evaluate the next + p.skipChildren() + p.nextToken() + arrayIndex(p, f)(i - 1) + } + + /** + * Evaluate a list of JsonPath instructions, returning a bool that indicates if any leaf nodes + * have been written to the generator + */ + private def evaluatePath( + p: JsonParser, + g: JsonGenerator, + style: WriteStyle, + path: List[PathInstruction]): Boolean = { + (p.getCurrentToken, path) match { + case (VALUE_STRING, Nil) if style == RawStyle => + // there is no array wildcard or slice parent, emit this string without quotes + if (p.hasTextCharacters) { + g.writeRaw(p.getTextCharacters, p.getTextOffset, p.getTextLength) + } else { + g.writeRaw(p.getText) + } + true + + case (START_ARRAY, Nil) if style == FlattenStyle => + // flatten this array into the parent + var dirty = false + while (p.nextToken() != END_ARRAY) { + dirty |= evaluatePath(p, g, style, Nil) + } + dirty + + case (_, Nil) => + // general case: just copy the child tree verbatim + g.copyCurrentStructure(p) + true + + case (START_OBJECT, Key :: xs) => + var dirty = false + while (p.nextToken() != END_OBJECT) { + if (dirty) { + // once a match has been found we can skip other fields + p.skipChildren() + } else { + dirty = evaluatePath(p, g, style, xs) + } + } + dirty + + case (START_ARRAY, Subscript :: Wildcard :: Subscript :: Wildcard :: xs) => + // special handling for the non-structure preserving double wildcard behavior in Hive + var dirty = false + g.writeStartArray() + while (p.nextToken() 
!= END_ARRAY) { + dirty |= evaluatePath(p, g, FlattenStyle, xs) + } + g.writeEndArray() + dirty + + case (START_ARRAY, Subscript :: Wildcard :: xs) if style != QuotedStyle => + // retain Flatten, otherwise use Quoted... cannot use Raw within an array + val nextStyle = style match { + case RawStyle => QuotedStyle + case FlattenStyle => FlattenStyle + case QuotedStyle => throw new IllegalStateException() + } + + // temporarily buffer child matches, the emitted json will need to be + // modified slightly if there is only a single element written + val buffer = new StringWriter() + + var dirty = 0 + Utils.tryWithResource(jsonFactory.createGenerator(buffer)) { flattenGenerator => + flattenGenerator.writeStartArray() + + while (p.nextToken() != END_ARRAY) { + // track the number of array elements and only emit an outer array if + // we've written more than one element, this matches Hive's behavior + dirty += (if (evaluatePath(p, flattenGenerator, nextStyle, xs)) 1 else 0) + } + flattenGenerator.writeEndArray() + } + + val buf = buffer.getBuffer + if (dirty > 1) { + g.writeRawValue(buf.toString) + } else if (dirty == 1) { + // remove outer array tokens + g.writeRawValue(buf.substring(1, buf.length()-1)) + } // else do not write anything + + dirty > 0 + + case (START_ARRAY, Subscript :: Wildcard :: xs) => + var dirty = false + g.writeStartArray() + while (p.nextToken() != END_ARRAY) { + // wildcards can have multiple matches, continually update the dirty count + dirty |= evaluatePath(p, g, QuotedStyle, xs) + } + g.writeEndArray() + + dirty + + case (START_ARRAY, Subscript :: Index(idx) :: (xs@Subscript :: Wildcard :: _)) => + p.nextToken() + // we're going to have 1 or more results, switch to QuotedStyle + arrayIndex(p, () => evaluatePath(p, g, QuotedStyle, xs))(idx) + + case (START_ARRAY, Subscript :: Index(idx) :: xs) => + p.nextToken() + arrayIndex(p, () => evaluatePath(p, g, style, xs))(idx) + + case (FIELD_NAME, Named(name) :: xs) if p.getCurrentName == name => + // exact field match + if (p.nextToken() != JsonToken.VALUE_NULL) { + evaluatePath(p, g, style, xs) + } else { + false + } + + case (FIELD_NAME, Wildcard :: xs) => + // wildcard field match + p.nextToken() + evaluatePath(p, g, style, xs) + + case _ => + p.skipChildren() + false + } + } +} + +case class JsonTuple(children: Seq[Expression]) + extends Generator with CodegenFallback { + + import SharedFactory._ + + override def nullable: Boolean = { + // a row is always returned + false + } + + // if processing fails this shared value will be returned + @transient private lazy val nullRow: Seq[InternalRow] = + new GenericInternalRow(Array.ofDim[Any](fieldExpressions.length)) :: Nil + + // the json body is the first child + @transient private lazy val jsonExpr: Expression = children.head + + // the fields to query are the remaining children + @transient private lazy val fieldExpressions: Seq[Expression] = children.tail + + // eagerly evaluate any foldable the field names + @transient private lazy val foldableFieldNames: IndexedSeq[String] = { + fieldExpressions.map { + case expr if expr.foldable => expr.eval().asInstanceOf[UTF8String].toString + case _ => null + }.toIndexedSeq + } + + // and count the number of foldable fields, we'll use this later to optimize evaluation + @transient private lazy val constantFields: Int = foldableFieldNames.count(_ != null) + + override def elementTypes: Seq[(DataType, Boolean, String)] = fieldExpressions.zipWithIndex.map { + case (_, idx) => (StringType, true, s"c$idx") + } + + override def prettyName: 
String = "json_tuple" + + override def checkInputDataTypes(): TypeCheckResult = { + if (children.length < 2) { + TypeCheckResult.TypeCheckFailure(s"$prettyName requires at least two arguments") + } else if (children.forall(child => StringType.acceptsType(child.dataType))) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure(s"$prettyName requires that all arguments are strings") + } + } + + override def eval(input: InternalRow): TraversableOnce[InternalRow] = { + val json = jsonExpr.eval(input).asInstanceOf[UTF8String] + if (json == null) { + return nullRow + } + + try { + Utils.tryWithResource(jsonFactory.createParser(json.getBytes)) { + parser => parseRow(parser, input) + } + } catch { + case _: JsonProcessingException => + nullRow + } + } + + private def parseRow(parser: JsonParser, input: InternalRow): Seq[InternalRow] = { + // only objects are supported + if (parser.nextToken() != JsonToken.START_OBJECT) { + return nullRow + } + + // evaluate the field names as String rather than UTF8String to + // optimize lookups from the json token, which is also a String + val fieldNames = if (constantFields == fieldExpressions.length) { + // typically the user will provide the field names as foldable expressions + // so we can use the cached copy + foldableFieldNames + } else if (constantFields == 0) { + // none are foldable so all field names need to be evaluated from the input row + fieldExpressions.map(_.eval(input).asInstanceOf[UTF8String].toString) + } else { + // if there is a mix of constant and non-constant expressions + // prefer the cached copy when available + foldableFieldNames.zip(fieldExpressions).map { + case (null, expr) => expr.eval(input).asInstanceOf[UTF8String].toString + case (fieldName, _) => fieldName + } + } + + val row = Array.ofDim[Any](fieldNames.length) + + // start reading through the token stream, looking for any requested field names + while (parser.nextToken() != JsonToken.END_OBJECT) { + if (parser.getCurrentToken == JsonToken.FIELD_NAME) { + // check to see if this field is desired in the output + val idx = fieldNames.indexOf(parser.getCurrentName) + if (idx >= 0) { + // it is, copy the child tree to the correct location in the output row + val output = new ByteArrayOutputStream() + + // write the output directly to UTF8 encoded byte array + if (parser.nextToken() != JsonToken.VALUE_NULL) { + Utils.tryWithResource(jsonFactory.createGenerator(output, JsonEncoding.UTF8)) { + generator => copyCurrentStructure(generator, parser) + } + + row(idx) = UTF8String.fromBytes(output.toByteArray) + } + } + } + + // always skip children, it's cheap enough to do even if copyCurrentStructure was called + parser.skipChildren() + } + + new GenericInternalRow(row) :: Nil + } + + private def copyCurrentStructure(generator: JsonGenerator, parser: JsonParser): Unit = { + parser.getCurrentToken match { + // if the user requests a string field it needs to be returned without enclosing + // quotes which is accomplished via JsonGenerator.writeRaw instead of JsonGenerator.write + case JsonToken.VALUE_STRING if parser.hasTextCharacters => + // slight optimization to avoid allocating a String instance, though the characters + // still have to be decoded... 
Jackson doesn't have a way to access the raw bytes + generator.writeRaw(parser.getTextCharacters, parser.getTextOffset, parser.getTextLength) + + case JsonToken.VALUE_STRING => + // the normal String case, pass it through to the output without enclosing quotes + generator.writeRaw(parser.getText) + + case JsonToken.VALUE_NULL => + // a special case that needs to be handled outside of this method. + // if a requested field is null, the result must be null. the easiest + // way to achieve this is just by ignoring null tokens entirely + throw new IllegalStateException("Do not attempt to copy a null field") + + case _ => + // handle other types including objects, arrays, booleans and numbers + generator.copyCurrentStructure(parser) + } + } +} + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 34bad23802ba..68ec688c99f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -36,21 +35,55 @@ object Literal { case s: Short => Literal(s, ShortType) case s: String => Literal(UTF8String.fromString(s), StringType) case b: Boolean => Literal(b, BooleanType) - case d: BigDecimal => Literal(Decimal(d), DecimalType(d.precision, d.scale)) - case d: java.math.BigDecimal => Literal(Decimal(d), DecimalType(d.precision(), d.scale())) - case d: Decimal => Literal(d, DecimalType(d.precision, d.scale)) + case d: BigDecimal => Literal(Decimal(d), DecimalType(Math.max(d.precision, d.scale), d.scale)) + case d: java.math.BigDecimal => + Literal(Decimal(d), DecimalType(Math.max(d.precision, d.scale), d.scale())) + case d: Decimal => Literal(d, DecimalType(Math.max(d.precision, d.scale), d.scale)) case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType) case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType) case a: Array[Byte] => Literal(a, BinaryType) case i: CalendarInterval => Literal(i, CalendarIntervalType) case null => Literal(null, NullType) + case v: Literal => v case _ => throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v) } + /** + * Constructs a [[Literal]] of [[ObjectType]], for example when you need to pass an object + * into code generation. 
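A short sketch (not part of the patch) of why the precision is widened to Math.max(precision, scale) above; the sample value is hypothetical.

    // BigDecimal("0.01") reports precision = 1 and scale = 2, which would produce the invalid
    // type DecimalType(1, 2); the Math.max above widens it to a legal DecimalType(2, 2).
    val lit = Literal(BigDecimal("0.01"))
    assert(lit.dataType == DecimalType(2, 2))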
+ */ + def fromObject(obj: AnyRef): Literal = new Literal(obj, ObjectType(obj.getClass)) + def create(v: Any, dataType: DataType): Literal = { Literal(CatalystTypeConverters.convertToCatalyst(v), dataType) } + + /** + * Create a literal with default value for given DataType + */ + def default(dataType: DataType): Literal = dataType match { + case NullType => create(null, NullType) + case BooleanType => Literal(false) + case ByteType => Literal(0.toByte) + case ShortType => Literal(0.toShort) + case IntegerType => Literal(0) + case LongType => Literal(0L) + case FloatType => Literal(0.0f) + case DoubleType => Literal(0.0) + case dt: DecimalType => Literal(Decimal(0, dt.precision, dt.scale)) + case DateType => create(0, DateType) + case TimestampType => create(0L, TimestampType) + case StringType => Literal("") + case BinaryType => Literal("".getBytes) + case CalendarIntervalType => Literal(new CalendarInterval(0, 0)) + case arr: ArrayType => create(Array(), arr) + case map: MapType => create(Map(), map) + case struct: StructType => + create(InternalRow.fromSeq(struct.fields.map(f => default(f.dataType).value)), struct) + case other => + throw new RuntimeException(s"no default for type $dataType") + } } /** @@ -96,12 +129,12 @@ case class Literal protected (value: Any, dataType: DataType) // change the isNull and primitive to consts, to inline them if (value == null) { ev.isNull = "true" - s"final ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};" + s"final ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};" } else { dataType match { case BooleanType => ev.isNull = "false" - ev.primitive = value.toString + ev.value = value.toString "" case FloatType => val v = value.asInstanceOf[Float] @@ -109,7 +142,7 @@ case class Literal protected (value: Any, dataType: DataType) super.genCode(ctx, ev) } else { ev.isNull = "false" - ev.primitive = s"${value}f" + ev.value = s"${value}f" "" } case DoubleType => @@ -118,20 +151,20 @@ case class Literal protected (value: Any, dataType: DataType) super.genCode(ctx, ev) } else { ev.isNull = "false" - ev.primitive = s"${value}D" + ev.value = s"${value}D" "" } case ByteType | ShortType => ev.isNull = "false" - ev.primitive = s"(${ctx.javaType(dataType)})$value" + ev.value = s"(${ctx.javaType(dataType)})$value" "" case IntegerType | DateType => ev.isNull = "false" - ev.primitive = value.toString + ev.value = value.toString "" case TimestampType | LongType => ev.isNull = "false" - ev.primitive = s"${value}L" + ev.value = s"${value}L" "" // eval() version may be faster for non-primitive types case other => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala similarity index 86% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 15ceb9193a8c..28f616fbb9ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -52,10 +52,10 @@ abstract class LeafMathExpression(c: Double, name: String) * @param f The math function. 
* @param name The short name of the function */ -abstract class UnaryMathExpression(f: Double => Double, name: String) +abstract class UnaryMathExpression(val f: Double => Double, name: String) extends UnaryExpression with Serializable with ImplicitCastInputTypes { - override def inputTypes: Seq[DataType] = Seq(DoubleType) + override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType) override def dataType: DataType = DoubleType override def nullable: Boolean = true override def toString: String = s"$name($child)" @@ -89,7 +89,7 @@ abstract class UnaryLogExpression(f: Double => Double, name: String) if ($c <= $yAsymptote) { ${ev.isNull} = true; } else { - ${ev.primitive} = java.lang.Math.${funcName}($c); + ${ev.value} = java.lang.Math.${funcName}($c); } """ ) @@ -152,7 +152,31 @@ case class Atan(child: Expression) extends UnaryMathExpression(math.atan, "ATAN" case class Cbrt(child: Expression) extends UnaryMathExpression(math.cbrt, "CBRT") -case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL") +case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL") { + override def dataType: DataType = child.dataType match { + case dt @ DecimalType.Fixed(_, 0) => dt + case DecimalType.Fixed(precision, scale) => + DecimalType.bounded(precision - scale + 1, 0) + case _ => LongType + } + + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(DoubleType, DecimalType)) + + protected override def nullSafeEval(input: Any): Any = child.dataType match { + case DoubleType => f(input.asInstanceOf[Double]).toLong + case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].ceil + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + child.dataType match { + case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c") + case DecimalType.Fixed(precision, scale) => + defineCodeGen(ctx, ev, c => s"$c.ceil()") + case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))") + } + } +} case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS") @@ -182,8 +206,8 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre val numconv = NumberConverter.getClass.getName.stripSuffix("$") nullSafeCodeGen(ctx, ev, (num, from, to) => s""" - ${ev.primitive} = $numconv.convert($num.getBytes(), $from, $to); - if (${ev.primitive} == null) { + ${ev.value} = $numconv.convert($num.getBytes(), $from, $to); + if (${ev.value} == null) { ${ev.isNull} = true; } """ @@ -195,7 +219,31 @@ case class Exp(child: Expression) extends UnaryMathExpression(math.exp, "EXP") case class Expm1(child: Expression) extends UnaryMathExpression(math.expm1, "EXPM1") -case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLOOR") +case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLOOR") { + override def dataType: DataType = child.dataType match { + case dt @ DecimalType.Fixed(_, 0) => dt + case DecimalType.Fixed(precision, scale) => + DecimalType.bounded(precision - scale + 1, 0) + case _ => LongType + } + + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(DoubleType, DecimalType)) + + protected override def nullSafeEval(input: Any): Any = child.dataType match { + case DoubleType => f(input.asInstanceOf[Double]).toLong + case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].floor + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + child.dataType 
match { + case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c") + case DecimalType.Fixed(precision, scale) => + defineCodeGen(ctx, ev, c => s"$c.floor()") + case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))") + } + } +} object Factorial { @@ -252,7 +300,7 @@ case class Factorial(child: Expression) extends UnaryExpression with ImplicitCas if ($eval > 20 || $eval < 0) { ${ev.isNull} = true; } else { - ${ev.primitive} = + ${ev.value} = org.apache.spark.sql.catalyst.expressions.Factorial.factorial($eval); } """ @@ -270,7 +318,7 @@ case class Log2(child: Expression) if ($c <= $yAsymptote) { ${ev.isNull} = true; } else { - ${ev.primitive} = java.lang.Math.log($c) / java.lang.Math.log(2); + ${ev.value} = java.lang.Math.log($c) / java.lang.Math.log(2); } """ ) @@ -414,7 +462,7 @@ case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInput override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { nullSafeCodeGen(ctx, ev, (c) => { val hex = Hex.getClass.getName.stripSuffix("$") - s"${ev.primitive} = " + (child.dataType match { + s"${ev.value} = " + (child.dataType match { case StringType => s"""$hex.hex($c.getBytes());""" case _ => s"""$hex.hex($c);""" }) @@ -440,8 +488,8 @@ case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInp nullSafeCodeGen(ctx, ev, (c) => { val hex = Hex.getClass.getName.stripSuffix("$") s""" - ${ev.primitive} = $hex.unhex($c.getBytes()); - ${ev.isNull} = ${ev.primitive} == null; + ${ev.value} = $hex.unhex($c.getBytes()); + ${ev.isNull} = ${ev.value} == null; """ }) } @@ -587,7 +635,7 @@ case class Logarithm(left: Expression, right: Expression) if ($c2 <= 0.0) { ${ev.isNull} = true; } else { - ${ev.primitive} = java.lang.Math.log($c2); + ${ev.value} = java.lang.Math.log($c2); } """) } else { @@ -596,7 +644,7 @@ case class Logarithm(left: Expression, right: Expression) if ($c1 <= 0.0 || $c2 <= 0.0) { ${ev.isNull} = true; } else { - ${ev.primitive} = java.lang.Math.log($c2) / java.lang.Math.log($c1); + ${ev.value} = java.lang.Math.log($c2) / java.lang.Math.log($c1); } """) } @@ -709,74 +757,74 @@ case class Round(child: Expression, scale: Expression) val evaluationCode = child.dataType match { case _: DecimalType => s""" - if (${ce.primitive}.changePrecision(${ce.primitive}.precision(), ${_scale})) { - ${ev.primitive} = ${ce.primitive}; + if (${ce.value}.changePrecision(${ce.value}.precision(), ${_scale})) { + ${ev.value} = ${ce.value}; } else { ${ev.isNull} = true; }""" case ByteType => if (_scale < 0) { s""" - ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). + ${ev.value} = new java.math.BigDecimal(${ce.value}). setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).byteValue();""" } else { - s"${ev.primitive} = ${ce.primitive};" + s"${ev.value} = ${ce.value};" } case ShortType => if (_scale < 0) { s""" - ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). + ${ev.value} = new java.math.BigDecimal(${ce.value}). setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).shortValue();""" } else { - s"${ev.primitive} = ${ce.primitive};" + s"${ev.value} = ${ce.value};" } case IntegerType => if (_scale < 0) { s""" - ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). + ${ev.value} = new java.math.BigDecimal(${ce.value}). 
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).intValue();""" } else { - s"${ev.primitive} = ${ce.primitive};" + s"${ev.value} = ${ce.value};" } case LongType => if (_scale < 0) { s""" - ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). + ${ev.value} = new java.math.BigDecimal(${ce.value}). setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).longValue();""" } else { - s"${ev.primitive} = ${ce.primitive};" + s"${ev.value} = ${ce.value};" } case FloatType => // if child eval to NaN or Infinity, just return it. if (_scale == 0) { s""" - if (Float.isNaN(${ce.primitive}) || Float.isInfinite(${ce.primitive})){ - ${ev.primitive} = ${ce.primitive}; + if (Float.isNaN(${ce.value}) || Float.isInfinite(${ce.value})){ + ${ev.value} = ${ce.value}; } else { - ${ev.primitive} = Math.round(${ce.primitive}); + ${ev.value} = Math.round(${ce.value}); }""" } else { s""" - if (Float.isNaN(${ce.primitive}) || Float.isInfinite(${ce.primitive})){ - ${ev.primitive} = ${ce.primitive}; + if (Float.isNaN(${ce.value}) || Float.isInfinite(${ce.value})){ + ${ev.value} = ${ce.value}; } else { - ${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}). + ${ev.value} = java.math.BigDecimal.valueOf(${ce.value}). setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).floatValue(); }""" } case DoubleType => // if child eval to NaN or Infinity, just return it. if (_scale == 0) { s""" - if (Double.isNaN(${ce.primitive}) || Double.isInfinite(${ce.primitive})){ - ${ev.primitive} = ${ce.primitive}; + if (Double.isNaN(${ce.value}) || Double.isInfinite(${ce.value})){ + ${ev.value} = ${ce.value}; } else { - ${ev.primitive} = Math.round(${ce.primitive}); + ${ev.value} = Math.round(${ce.value}); }""" } else { s""" - if (Double.isNaN(${ce.primitive}) || Double.isInfinite(${ce.primitive})){ - ${ev.primitive} = ${ce.primitive}; + if (Double.isNaN(${ce.value}) || Double.isInfinite(${ce.value})){ + ${ev.value} = ${ce.value}; } else { - ${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}). + ${ev.value} = java.math.BigDecimal.valueOf(${ce.value}). 
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).doubleValue(); }""" } @@ -785,13 +833,13 @@ case class Round(child: Expression, scale: Expression) if (scaleV == null) { // if scale is null, no need to eval its child at all s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; """ } else { s""" ${ce.code} boolean ${ev.isNull} = ${ce.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { $evaluationCode } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 8d8d66ddeb34..0f6d02f2e00c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -92,18 +92,18 @@ case class Sha2(left: Expression, right: Expression) try { java.security.MessageDigest md = java.security.MessageDigest.getInstance("SHA-224"); md.update($eval1); - ${ev.primitive} = UTF8String.fromBytes(md.digest()); + ${ev.value} = UTF8String.fromBytes(md.digest()); } catch (java.security.NoSuchAlgorithmException e) { ${ev.isNull} = true; } } else if ($eval2 == 256 || $eval2 == 0) { - ${ev.primitive} = + ${ev.value} = UTF8String.fromString($digestUtils.sha256Hex($eval1)); } else if ($eval2 == 384) { - ${ev.primitive} = + ${ev.value} = UTF8String.fromString($digestUtils.sha384Hex($eval1)); } else if ($eval2 == 512) { - ${ev.primitive} = + ${ev.value} = UTF8String.fromString($digestUtils.sha512Hex($eval1)); } else { ${ev.isNull} = true; @@ -155,7 +155,7 @@ case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInp s""" $CRC32 checksum = new $CRC32(); checksum.update($value, 0, $value.length); - ${ev.primitive} = checksum.getValue(); + ${ev.value} = checksum.getValue(); """ }) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 6f173b52ad9b..26b6aca79971 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.util.UUID + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -24,16 +26,23 @@ import org.apache.spark.sql.types._ object NamedExpression { private val curId = new java.util.concurrent.atomic.AtomicLong() - def newExprId: ExprId = ExprId(curId.getAndIncrement()) + private[expressions] val jvmId = UUID.randomUUID() + def newExprId: ExprId = ExprId(curId.getAndIncrement(), jvmId) def unapply(expr: NamedExpression): Option[(String, DataType)] = Some(expr.name, expr.dataType) } /** - * A globally unique (within this JVM) id for a given named expression. + * A globally unique id for a given named expression. * Used to identify which attribute output by a relation is being * referenced in a subsequent computation. 
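A minimal sketch (not part of the patch) of the JVM-scoped identity introduced above; the second UUID is hypothetical and stands in for an id minted by another JVM.

    val local = NamedExpression.newExprId                         // ExprId(n, this JVM's uuid)
    val sameJvm = ExprId(local.id)                                // one-arg apply reuses this JVM's uuid
    val otherJvm = ExprId(local.id, java.util.UUID.randomUUID())  // same numeric id, different JVM
    assert(local == sameJvm && local != otherJvm)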
+ * + * The `id` field is unique within a given JVM, while the `uuid` is used to uniquely identify JVMs. */ -case class ExprId(id: Long) +case class ExprId(id: Long, jvmId: UUID) + +object ExprId { + def apply(id: Long): ExprId = ExprId(id, NamedExpression.jvmId) +} /** * An [[Expression]] that is named. @@ -185,7 +194,9 @@ case class AttributeReference( def sameRef(other: AttributeReference): Boolean = this.exprId == other.exprId override def equals(other: Any): Boolean = other match { - case ar: AttributeReference => name == ar.name && exprId == ar.exprId && dataType == ar.dataType + case ar: AttributeReference => + name == ar.name && dataType == ar.dataType && nullable == ar.nullable && + metadata == ar.metadata && exprId == ar.exprId && qualifiers == ar.qualifiers case _ => false } @@ -194,12 +205,19 @@ case class AttributeReference( case _ => false } + override def semanticHash(): Int = { + this.exprId.hashCode() + } + override def hashCode: Int = { // See http://stackoverflow.com/questions/113511/hash-code-implementation var h = 17 - h = h * 37 + exprId.hashCode() + h = h * 37 + name.hashCode() h = h * 37 + dataType.hashCode() + h = h * 37 + nullable.hashCode() h = h * 37 + metadata.hashCode() + h = h * 37 + exprId.hashCode() + h = h * 37 + qualifiers.hashCode() h } @@ -236,14 +254,27 @@ case class AttributeReference( } } + def withExprId(newExprId: ExprId): AttributeReference = { + if (exprId == newExprId) { + this + } else { + AttributeReference(name, dataType, nullable, metadata)(newExprId, qualifiers) + } + } + override def toString: String = s"$name#${exprId.id}$typeSuffix" + + // Since the expression id is not in the first constructor it is missing from the default + // tree string. + override def simpleString: String = s"$name#${exprId.id}: ${dataType.simpleString}" } /** * A place holder used when printing expressions without debugging information such as the * expression id or the unresolved indicator. 
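A hedged sketch (not part of the patch) of the distinction the changes above draw between full equality and exprId-based semantic identity; it assumes the existing withQualifiers copy helper on AttributeReference, which preserves the expression id.

    val a = AttributeReference("x", IntegerType, nullable = true)()
    val b = a.withQualifiers("t" :: Nil)   // same exprId, different qualifiers
    // equals now also compares qualifiers (and nullable/metadata), so a != b,
    // yet both share a semanticHash() because it is derived from the exprId alone.
    assert(a != b && a.semanticHash() == b.semanticHash())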
*/ -case class PrettyAttribute(name: String) extends Attribute with Unevaluable { +case class PrettyAttribute(name: String, dataType: DataType = NullType) + extends Attribute with Unevaluable { override def toString: String = name @@ -256,7 +287,6 @@ case class PrettyAttribute(name: String) extends Attribute with Unevaluable { override def qualifiers: Seq[String] = throw new UnsupportedOperationException override def exprId: ExprId = throw new UnsupportedOperationException override def nullable: Boolean = throw new UnsupportedOperationException - override def dataType: DataType = NullType } object VirtualColumn { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala similarity index 77% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index d58c4756938c..df4747d4e6f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -62,18 +62,22 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val first = children(0) + val rest = children.drop(1) + val firstEval = first.gen(ctx) s""" - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + ${firstEval.code} + boolean ${ev.isNull} = ${firstEval.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${firstEval.value}; """ + - children.map { e => + rest.map { e => val eval = e.gen(ctx) s""" if (${ev.isNull}) { ${eval.code} if (!${eval.isNull}) { ${ev.isNull} = false; - ${ev.primitive} = ${eval.primitive}; + ${ev.value} = ${eval.value}; } } """ @@ -111,8 +115,8 @@ case class IsNaN(child: Expression) extends UnaryExpression s""" ${eval.code} boolean ${ev.isNull} = false; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - ${ev.primitive} = !${eval.isNull} && Double.isNaN(${eval.primitive}); + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value}); """ } } @@ -152,18 +156,18 @@ case class NaNvl(left: Expression, right: Expression) s""" ${leftGen.code} boolean ${ev.isNull} = false; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (${leftGen.isNull}) { ${ev.isNull} = true; } else { - if (!Double.isNaN(${leftGen.primitive})) { - ${ev.primitive} = ${leftGen.primitive}; + if (!Double.isNaN(${leftGen.value})) { + ${ev.value} = ${leftGen.value}; } else { ${rightGen.code} if (${rightGen.isNull}) { ${ev.isNull} = true; } else { - ${ev.primitive} = ${rightGen.primitive}; + ${ev.value} = ${rightGen.value}; } } } @@ -186,7 +190,7 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval = child.gen(ctx) ev.isNull = "false" - ev.primitive = eval.isNull + ev.value = eval.isNull eval.code } } @@ -205,63 +209,19 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { override def genCode(ctx: CodeGenContext, ev: 
GeneratedExpressionCode): String = { val eval = child.gen(ctx) ev.isNull = "false" - ev.primitive = s"(!(${eval.isNull}))" + ev.value = s"(!(${eval.isNull}))" eval.code } } -/** - * A predicate that is evaluated to be true if there are at least `n` null values. - */ -case class AtLeastNNulls(n: Int, children: Seq[Expression]) extends Predicate { - override def nullable: Boolean = false - override def foldable: Boolean = children.forall(_.foldable) - override def toString: String = s"AtLeastNNulls($n, ${children.mkString(",")})" - - private[this] val childrenArray = children.toArray - - override def eval(input: InternalRow): Boolean = { - var numNulls = 0 - var i = 0 - while (i < childrenArray.length && numNulls < n) { - val evalC = childrenArray(i).eval(input) - if (evalC == null) { - numNulls += 1 - } - i += 1 - } - numNulls >= n - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val numNulls = ctx.freshName("numNulls") - val code = children.map { e => - val eval = e.gen(ctx) - s""" - if ($numNulls < $n) { - ${eval.code} - if (${eval.isNull}) { - $numNulls += 1; - } - } - """ - }.mkString("\n") - s""" - int $numNulls = 0; - $code - boolean ${ev.isNull} = false; - boolean ${ev.primitive} = $numNulls >= $n; - """ - } -} /** * A predicate that is evaluated to be true if there are at least `n` non-null and non-NaN values. */ -case class AtLeastNNonNullNans(n: Int, children: Seq[Expression]) extends Predicate { +case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate { override def nullable: Boolean = false override def foldable: Boolean = children.forall(_.foldable) - override def toString: String = s"AtLeastNNonNullNans($n, ${children.mkString(",")})" + override def toString: String = s"AtLeastNNulls(n, ${children.mkString(",")})" private[this] val childrenArray = children.toArray @@ -293,7 +253,7 @@ case class AtLeastNNonNullNans(n: Int, children: Seq[Expression]) extends Predic s""" if ($nonnull < $n) { ${eval.code} - if (!${eval.isNull} && !Double.isNaN(${eval.primitive})) { + if (!${eval.isNull} && !Double.isNaN(${eval.value})) { $nonnull += 1; } } @@ -313,7 +273,7 @@ case class AtLeastNNonNullNans(n: Int, children: Seq[Expression]) extends Predic int $nonnull = 0; $code boolean ${ev.isNull} = false; - boolean ${ev.primitive} = $nonnull >= $n; + boolean ${ev.value} = $nonnull >= $n; """ } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala new file mode 100644 index 000000000000..10ec75eca37f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -0,0 +1,628 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import scala.language.existentials +import scala.reflect.ClassTag + +import org.apache.spark.SparkConf +import org.apache.spark.serializer._ +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.types._ + +/** + * Invokes a static function, returning the result. By default, any of the arguments being null + * will result in returning null instead of calling the function. + * + * @param staticObject The target of the static call. This can either be the object itself + * (methods defined on scala objects), or the class object + * (static methods defined in java). + * @param dataType The expected return type of the function call + * @param functionName The name of the method to call. + * @param arguments An optional list of expressions to pass as arguments to the function. + * @param propagateNull When true, and any of the arguments is null, null will be returned instead + * of calling the function. + */ +case class StaticInvoke( + staticObject: Any, + dataType: DataType, + functionName: String, + arguments: Seq[Expression] = Nil, + propagateNull: Boolean = true) extends Expression { + + val objectName = staticObject match { + case c: Class[_] => c.getName + case other => other.getClass.getName.stripSuffix("$") + } + override def nullable: Boolean = true + override def children: Seq[Expression] = arguments + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val javaType = ctx.javaType(dataType) + val argGen = arguments.map(_.gen(ctx)) + val argString = argGen.map(_.value).mkString(", ") + + if (propagateNull) { + val objNullCheck = if (ctx.defaultValue(dataType) == "null") { + s"${ev.isNull} = ${ev.value} == null;" + } else { + "" + } + + val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})" + s""" + ${argGen.map(_.code).mkString("\n")} + + boolean ${ev.isNull} = !$argsNonNull; + $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; + + if ($argsNonNull) { + ${ev.value} = $objectName.$functionName($argString); + $objNullCheck + } + """ + } else { + s""" + ${argGen.map(_.code).mkString("\n")} + + $javaType ${ev.value} = $objectName.$functionName($argString); + final boolean ${ev.isNull} = ${ev.value} == null; + """ + } + } +} + +/** + * Calls the specified function on an object, optionally passing arguments. If the `targetObject` + * expression evaluates to null then null will be returned. + * + * In some cases, due to erasure, the schema may expect a primitive type when in fact the method + * is returning java.lang.Object. In this case, we will generate code that attempts to unbox the + * value automatically. + * + * @param targetObject An expression that will return the object to call the method on. + * @param functionName The name of the method to call. + * @param dataType The expected return type of the function. + * @param arguments An optional list of expressions, whos evaluation will be passed to the function. 
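A hedged sketch (not part of the patch) of how StaticInvoke above might be wired up; DateTimeUtils.fromJavaDate is the helper already used by Literal in this patch, and the literal argument here is purely illustrative.

    import org.apache.spark.sql.catalyst.util.DateTimeUtils

    // Generated code calls DateTimeUtils.fromJavaDate(<arg>); with propagateNull left at its
    // default of true, a null argument short-circuits to a null result.
    val toDays = StaticInvoke(
      DateTimeUtils,
      DateType,
      "fromJavaDate",
      Literal.fromObject(new java.sql.Date(0L)) :: Nil)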
+ */ +case class Invoke( + targetObject: Expression, + functionName: String, + dataType: DataType, + arguments: Seq[Expression] = Nil) extends Expression { + + override def nullable: Boolean = true + override def children: Seq[Expression] = arguments.+:(targetObject) + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + + lazy val method = targetObject.dataType match { + case ObjectType(cls) => + cls + .getMethods + .find(_.getName == functionName) + .getOrElse(sys.error(s"Couldn't find $functionName on $cls")) + .getReturnType + .getName + case _ => "" + } + + lazy val unboxer = (dataType, method) match { + case (IntegerType, "java.lang.Object") => (s: String) => + s"((java.lang.Integer)$s).intValue()" + case (LongType, "java.lang.Object") => (s: String) => + s"((java.lang.Long)$s).longValue()" + case (FloatType, "java.lang.Object") => (s: String) => + s"((java.lang.Float)$s).floatValue()" + case (ShortType, "java.lang.Object") => (s: String) => + s"((java.lang.Short)$s).shortValue()" + case (ByteType, "java.lang.Object") => (s: String) => + s"((java.lang.Byte)$s).byteValue()" + case (DoubleType, "java.lang.Object") => (s: String) => + s"((java.lang.Double)$s).doubleValue()" + case (BooleanType, "java.lang.Object") => (s: String) => + s"((java.lang.Boolean)$s).booleanValue()" + case _ => identity[String] _ + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val javaType = ctx.javaType(dataType) + val obj = targetObject.gen(ctx) + val argGen = arguments.map(_.gen(ctx)) + val argString = argGen.map(_.value).mkString(", ") + + // If the function can return null, we do an extra check to make sure our null bit is still set + // correctly. + val objNullCheck = if (ctx.defaultValue(dataType) == "null") { + s"${ev.isNull} = ${ev.value} == null;" + } else { + "" + } + + val value = unboxer(s"${obj.value}.$functionName($argString)") + + s""" + ${obj.code} + ${argGen.map(_.code).mkString("\n")} + + boolean ${ev.isNull} = ${obj.value} == null; + $javaType ${ev.value} = + ${ev.isNull} ? + ${ctx.defaultValue(dataType)} : ($javaType) $value; + $objNullCheck + """ + } +} + +object NewInstance { + def apply( + cls: Class[_], + arguments: Seq[Expression], + propagateNull: Boolean = false, + dataType: DataType): NewInstance = + new NewInstance(cls, arguments, propagateNull, dataType, None) +} + +/** + * Constructs a new instance of the given class, using the result of evaluating the specified + * expressions as arguments. + * + * @param cls The class to construct. + * @param arguments A list of expression to use as arguments to the constructor. + * @param propagateNull When true, if any of the arguments is null, then null will be returned + * instead of trying to construct the object. + * @param dataType The type of object being constructed, as a Spark SQL datatype. This allows you + * to manually specify the type when the object in question is a valid internal + * representation (i.e. ArrayData) instead of an object. + * @param outerPointer If the object being constructed is an inner class the outerPointer must + * for the containing class must be specified. 
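A hedged sketch (not part of the patch) that exercises the NewInstance helper defined just below together with Invoke; the Point class is hypothetical.

    case class Point(x: Int, y: Int)

    val newPoint = NewInstance(
      classOf[Point],
      Literal(1) :: Literal(2) :: Nil,
      dataType = ObjectType(classOf[Point]))
    // Reads x back via the accessor; Invoke handles unboxing when erasure hides the primitive type.
    val readX = Invoke(newPoint, "x", IntegerType)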
+ */ +case class NewInstance( + cls: Class[_], + arguments: Seq[Expression], + propagateNull: Boolean, + dataType: DataType, + outerPointer: Option[Literal]) extends Expression { + private val className = cls.getName + + override def nullable: Boolean = propagateNull + + override def children: Seq[Expression] = arguments + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val javaType = ctx.javaType(dataType) + val argGen = arguments.map(_.gen(ctx)) + val argString = argGen.map(_.value).mkString(", ") + + val outer = outerPointer.map(_.gen(ctx)) + + val setup = + s""" + ${argGen.map(_.code).mkString("\n")} + ${outer.map(_.code.mkString("")).getOrElse("")} + """.stripMargin + + val constructorCall = outer.map { gen => + s"""${gen.value}.new ${cls.getSimpleName}($argString)""" + }.getOrElse { + s"new $className($argString)" + } + + if (propagateNull) { + val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})" + + s""" + $setup + + boolean ${ev.isNull} = true; + $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; + if ($argsNonNull) { + ${ev.value} = $constructorCall; + ${ev.isNull} = false; + } + """ + } else { + s""" + $setup + + $javaType ${ev.value} = $constructorCall; + final boolean ${ev.isNull} = ${ev.value} == null; + """ + } + } +} + +/** + * Given an expression that returns on object of type `Option[_]`, this expression unwraps the + * option into the specified Spark SQL datatype. In the case of `None`, the nullbit is set instead. + * + * @param dataType The expected unwrapped option type. + * @param child An expression that returns an `Option` + */ +case class UnwrapOption( + dataType: DataType, + child: Expression) extends UnaryExpression with ExpectsInputTypes { + + override def nullable: Boolean = true + + override def inputTypes: Seq[AbstractDataType] = ObjectType :: Nil + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported") + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val javaType = ctx.javaType(dataType) + val inputObject = child.gen(ctx) + + s""" + ${inputObject.code} + + boolean ${ev.isNull} = ${inputObject.value} == null || ${inputObject.value}.isEmpty(); + $javaType ${ev.value} = + ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($javaType)${inputObject.value}.get(); + """ + } +} + +/** + * Converts the result of evaluating `child` into an option, checking both the isNull bit and + * (in the case of reference types) equality with null. + * @param child The expression to evaluate and wrap. + * @param optType The type of this option. + */ +case class WrapOption(child: Expression, optType: DataType) + extends UnaryExpression with ExpectsInputTypes { + + override def dataType: DataType = ObjectType(classOf[Option[_]]) + + override def nullable: Boolean = true + + override def inputTypes: Seq[AbstractDataType] = optType :: Nil + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported") + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val inputObject = child.gen(ctx) + + s""" + ${inputObject.code} + + boolean ${ev.isNull} = false; + scala.Option ${ev.value} = + ${inputObject.isNull} ? 
+ scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value}); + """ + } +} + +/** + * A place holder for the loop variable used in [[MapObjects]]. This should never be constructed + * manually, but will instead be passed into the provided lambda function. + */ +case class LambdaVariable(value: String, isNull: String, dataType: DataType) extends LeafExpression + with Unevaluable { + + override def nullable: Boolean = true + + override def gen(ctx: CodeGenContext): GeneratedExpressionCode = { + GeneratedExpressionCode(code = "", value = value, isNull = isNull) + } +} + +object MapObjects { + private val curId = new java.util.concurrent.atomic.AtomicInteger() + + def apply( + function: Expression => Expression, + inputData: Expression, + elementType: DataType): MapObjects = { + val loopValue = "MapObjects_loopValue" + curId.getAndIncrement() + val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement() + val loopVar = LambdaVariable(loopValue, loopIsNull, elementType) + MapObjects(loopVar, function(loopVar), inputData) + } +} + +/** + * Applies the given expression to every element of a collection of items, returning the result + * as an ArrayType. This is similar to a typical map operation, but where the lambda function + * is expressed using catalyst expressions. + * + * The following collection ObjectTypes are currently supported: + * Seq, Array, ArrayData, java.util.List + * + * @param loopVar A place holder that used as the loop variable when iterate the collection, and + * used as input for the `lambdaFunction`. It also carries the element type info. + * @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function + * to handle collection elements. + * @param inputData An expression that when evaluted returns a collection object. 
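A minimal sketch (not part of the patch) of driving the companion apply above with an ordinary catalyst lambda; the bound input reference is hypothetical.

    // Increments every element of an array column; the result type is ArrayType(IntegerType).
    val input = BoundReference(0, ArrayType(IntegerType), nullable = true)
    val incremented = MapObjects(element => Add(element, Literal(1)), input, IntegerType)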
+ */ +case class MapObjects( + loopVar: LambdaVariable, + lambdaFunction: Expression, + inputData: Expression) extends Expression { + + private def itemAccessorMethod(dataType: DataType): String => String = dataType match { + case NullType => + val nullTypeClassName = NullType.getClass.getName + ".MODULE$" + (i: String) => s".get($i, $nullTypeClassName)" + case IntegerType => (i: String) => s".getInt($i)" + case LongType => (i: String) => s".getLong($i)" + case FloatType => (i: String) => s".getFloat($i)" + case DoubleType => (i: String) => s".getDouble($i)" + case ByteType => (i: String) => s".getByte($i)" + case ShortType => (i: String) => s".getShort($i)" + case BooleanType => (i: String) => s".getBoolean($i)" + case StringType => (i: String) => s".getUTF8String($i)" + case s: StructType => (i: String) => s".getStruct($i, ${s.size})" + case a: ArrayType => (i: String) => s".getArray($i)" + case _: MapType => (i: String) => s".getMap($i)" + case udt: UserDefinedType[_] => itemAccessorMethod(udt.sqlType) + } + + private lazy val (lengthFunction, itemAccessor, primitiveElement) = inputData.dataType match { + case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) => + (".size()", (i: String) => s".apply($i)", false) + case ObjectType(cls) if cls.isArray => + (".length", (i: String) => s"[$i]", false) + case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => + (".size()", (i: String) => s".get($i)", false) + case ArrayType(t, _) => + val (sqlType, primitiveElement) = t match { + case m: MapType => (m, false) + case s: StructType => (s, false) + case s: StringType => (s, false) + case udt: UserDefinedType[_] => (udt.sqlType, false) + case o => (o, true) + } + (".numElements()", itemAccessorMethod(sqlType), primitiveElement) + } + + override def nullable: Boolean = true + + override def children: Seq[Expression] = lambdaFunction :: inputData :: Nil + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported") + + override def dataType: DataType = ArrayType(lambdaFunction.dataType) + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val javaType = ctx.javaType(dataType) + val elementJavaType = ctx.javaType(loopVar.dataType) + val genInputData = inputData.gen(ctx) + val genFunction = lambdaFunction.gen(ctx) + val dataLength = ctx.freshName("dataLength") + val convertedArray = ctx.freshName("convertedArray") + val loopIndex = ctx.freshName("loopIndex") + + val convertedType = ctx.boxedType(lambdaFunction.dataType) + + // Because of the way Java defines nested arrays, we have to handle the syntax specially. + // Specifically, we have to insert the [$dataLength] in between the type and any extra nested + // array declarations (i.e. new String[1][]). 
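// (Illustrative note, not part of the patch.) For instance, if the boxed element type is
// "java.lang.String[]", the branch below must render "new java.lang.String[$dataLength][]",
// i.e. the length slot goes after the raw component type and before the trailing empty
// brackets, rather than the invalid "new java.lang.String[][$dataLength]".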
+ val arrayConstructor = if (convertedType contains "[]") { + val rawType = convertedType.takeWhile(_ != '[') + val arrayPart = convertedType.reverse.takeWhile(c => c == '[' || c == ']').reverse + s"new $rawType[$dataLength]$arrayPart" + } else { + s"new $convertedType[$dataLength]" + } + + val loopNullCheck = if (primitiveElement) { + s"boolean ${loopVar.isNull} = ${genInputData.value}.isNullAt($loopIndex);" + } else { + s"boolean ${loopVar.isNull} = ${genInputData.isNull} || ${loopVar.value} == null;" + } + + s""" + ${genInputData.code} + + boolean ${ev.isNull} = ${genInputData.value} == null; + $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; + + if (!${ev.isNull}) { + $convertedType[] $convertedArray = null; + int $dataLength = ${genInputData.value}$lengthFunction; + $convertedArray = $arrayConstructor; + + int $loopIndex = 0; + while ($loopIndex < $dataLength) { + $elementJavaType ${loopVar.value} = + ($elementJavaType)${genInputData.value}${itemAccessor(loopIndex)}; + $loopNullCheck + + if (${loopVar.isNull}) { + $convertedArray[$loopIndex] = null; + } else { + ${genFunction.code} + $convertedArray[$loopIndex] = ${genFunction.value}; + } + + $loopIndex += 1; + } + + ${ev.isNull} = false; + ${ev.value} = new ${classOf[GenericArrayData].getName}($convertedArray); + } + """ + } +} + +/** + * Constructs a new external row, using the result of evaluating the specified expressions + * as content. + * + * @param children A list of expression to use as content of the external row. + */ +case class CreateExternalRow(children: Seq[Expression]) extends Expression { + override def dataType: DataType = ObjectType(classOf[Row]) + + override def nullable: Boolean = false + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported") + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val rowClass = classOf[GenericRow].getName + val values = ctx.freshName("values") + s""" + boolean ${ev.isNull} = false; + final Object[] $values = new Object[${children.size}]; + """ + + children.zipWithIndex.map { case (e, i) => + val eval = e.gen(ctx) + eval.code + s""" + if (${eval.isNull}) { + $values[$i] = null; + } else { + $values[$i] = ${eval.value}; + } + """ + }.mkString("\n") + + s"final ${classOf[Row].getName} ${ev.value} = new $rowClass($values);" + } +} + +/** + * Serializes an input object using a generic serializer (Kryo or Java). + * @param kryo if true, use Kryo. Otherwise, use Java. + */ +case class EncodeUsingSerializer(child: Expression, kryo: Boolean) extends UnaryExpression { + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported") + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + // Code to initialize the serializer. + val serializer = ctx.freshName("serializer") + val (serializerClass, serializerInstanceClass) = { + if (kryo) { + (classOf[KryoSerializer].getName, classOf[KryoSerializerInstance].getName) + } else { + (classOf[JavaSerializer].getName, classOf[JavaSerializerInstance].getName) + } + } + val sparkConf = s"new ${classOf[SparkConf].getName}()" + ctx.addMutableState( + serializerInstanceClass, + serializer, + s"$serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();") + + // Code to serialize. 
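// (Illustrative note, not part of the patch.) addMutableState registers the serializer as a
// field of the generated class, so the Kryo or Java SerializerInstance is constructed once,
// from a default SparkConf, and reused for every row rather than rebuilt per evaluation.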
+ val input = child.gen(ctx) + s""" + ${input.code} + final boolean ${ev.isNull} = ${input.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.value} = $serializer.serialize(${input.value}, null).array(); + } + """ + } + + override def dataType: DataType = BinaryType +} + +/** + * Serializes an input object using a generic serializer (Kryo or Java). Note that the ClassTag + * is not an implicit parameter because TreeNode cannot copy implicit parameters. + * @param kryo if true, use Kryo. Otherwise, use Java. + */ +case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: Boolean) + extends UnaryExpression { + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + // Code to initialize the serializer. + val serializer = ctx.freshName("serializer") + val (serializerClass, serializerInstanceClass) = { + if (kryo) { + (classOf[KryoSerializer].getName, classOf[KryoSerializerInstance].getName) + } else { + (classOf[JavaSerializer].getName, classOf[JavaSerializerInstance].getName) + } + } + val sparkConf = s"new ${classOf[SparkConf].getName}()" + ctx.addMutableState( + serializerInstanceClass, + serializer, + s"$serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();") + + // Code to serialize. + val input = child.gen(ctx) + s""" + ${input.code} + final boolean ${ev.isNull} = ${input.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.value} = (${ctx.javaType(dataType)}) + $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null); + } + """ + } + + override def dataType: DataType = ObjectType(tag.runtimeClass) +} + +/** + * Initialize a Java Bean instance by setting its field values via setters. + */ +case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Expression]) + extends Expression { + + override def nullable: Boolean = beanInstance.nullable + override def children: Seq[Expression] = beanInstance +: setters.values.toSeq + override def dataType: DataType = beanInstance.dataType + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val instanceGen = beanInstance.gen(ctx) + + val initialize = setters.map { + case (setterMethod, fieldValue) => + val fieldGen = fieldValue.gen(ctx) + s""" + ${fieldGen.code} + ${instanceGen.value}.$setterMethod(${fieldGen.value}); + """ + } + + ev.isNull = instanceGen.isNull + ev.value = instanceGen.value + + s""" + ${instanceGen.code} + if (!${instanceGen.isNull}) { + ${initialize.mkString("\n")} + } + """ + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/RowOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala similarity index 76% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/RowOrdering.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala index 873f5324c573..6112259fed61 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/RowOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.types._ /** * An interpreted row ordering comparator. 
*/ -class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] { +class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] { def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) = this(ordering.map(BindReferences.bindReference(_, inputSchema))) @@ -48,10 +48,14 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] { dt.ordering.asInstanceOf[Ordering[Any]].compare(left, right) case dt: AtomicType if order.direction == Descending => dt.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right) + case a: ArrayType if order.direction == Ascending => + a.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right) + case a: ArrayType if order.direction == Descending => + a.interpretedOrdering.asInstanceOf[Ordering[Any]].reverse.compare(left, right) case s: StructType if order.direction == Ascending => - s.ordering.asInstanceOf[Ordering[Any]].compare(left, right) + s.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right) case s: StructType if order.direction == Descending => - s.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right) + s.interpretedOrdering.asInstanceOf[Ordering[Any]].reverse.compare(left, right) case other => throw new IllegalArgumentException(s"Type $other does not support ordered operations") } @@ -65,6 +69,18 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] { } } +object InterpretedOrdering { + + /** + * Creates a [[InterpretedOrdering]] for the given schema, in natural ascending order. + */ + def forSchema(dataTypes: Seq[DataType]): InterpretedOrdering = { + new InterpretedOrdering(dataTypes.zipWithIndex.map { + case (dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending) + }) + } +} + object RowOrdering { /** @@ -74,6 +90,8 @@ object RowOrdering { case NullType => true case dt: AtomicType => true case struct: StructType => struct.fields.forall(f => isOrderable(f.dataType)) + case array: ArrayType => isOrderable(array.elementType) + case udt: UserDefinedType[_] => isOrderable(udt.sqlType) case _ => false } @@ -81,13 +99,4 @@ object RowOrdering { * Returns true iff outputs from the expressions can be ordered. */ def isOrderable(exprs: Seq[Expression]): Boolean = exprs.forall(e => isOrderable(e.dataType)) - - /** - * Creates a [[RowOrdering]] for the given schema, in natural ascending order. - */ - def forSchema(dataTypes: Seq[DataType]): RowOrdering = { - new RowOrdering(dataTypes.zipWithIndex.map { - case (dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending) - }) - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 30b7f8d3766a..f1fa13daa77e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.{StructField, StructType} /** * A set of classes that can be used to represent trees of relational expressions. A key goal of @@ -80,4 +81,15 @@ package object expressions { /** Uses the given row to store the output of the projection. */ def target(row: MutableRow): MutableProjection } + + + /** + * Helper functions for working with `Seq[Attribute]`. 
+ */ + implicit class AttributeSeq(attrs: Seq[Attribute]) { + /** Creates a StructType with a schema matching this `Seq[Attribute]`. */ + def toStructType: StructType = { + StructType(attrs.map(a => StructField(a.name, a.dataType, a.nullable))) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index b69bbabee7e8..304b438c84ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -17,8 +17,9 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -64,9 +65,18 @@ trait PredicateHelper { } } + // Substitute any known alias from a map. + protected def replaceAlias( + condition: Expression, + aliases: AttributeMap[Expression]): Expression = { + condition.transform { + case a: Attribute => aliases.getOrElse(a, a) + } + } + /** * Returns true if `expr` can be evaluated using only the output of `plan`. This method - * can be used to determine when is is acceptable to move expression evaluation within a query + * can be used to determine when it is acceptable to move expression evaluation within a query * plan. * * For example consider a join between two relations R(a, b) and S(c, d). @@ -97,31 +107,124 @@ case class Not(child: Expression) /** * Evaluates to `true` if `list` contains `value`. */ -case class In(value: Expression, list: Seq[Expression]) extends Predicate with CodegenFallback { +case class In(value: Expression, list: Seq[Expression]) extends Predicate + with ImplicitCastInputTypes { + + require(list != null, "list should not be null") + + override def inputTypes: Seq[AbstractDataType] = value.dataType +: list.map(_.dataType) + + override def checkInputDataTypes(): TypeCheckResult = { + if (list.exists(l => l.dataType != value.dataType)) { + TypeCheckResult.TypeCheckFailure( + "Arguments must be same type") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + override def children: Seq[Expression] = value +: list - override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN. 
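// Editor's sketch, not part of the patch: the SQL three-valued IN semantics that the rewritten
// eval on the following lines implements, modeled with Option (None standing in for SQL NULL).
def sqlIn(value: Option[Any], list: Seq[Option[Any]]): Option[Boolean] = value match {
  case None => None
  case Some(v) =>
    if (list.contains(Some(v))) Some(true)
    else if (list.contains(None)) None
    else Some(false)
}
// sqlIn(Some(1), Seq(Some(1), None))    == Some(true)   -- a match wins even with NULLs present
// sqlIn(Some(1), Seq(Some(2), None))    == None         -- no match but a NULL in the list: NULL
// sqlIn(Some(1), Seq(Some(2), Some(3))) == Some(false)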
+ override def nullable: Boolean = children.exists(_.nullable) + override def foldable: Boolean = children.forall(_.foldable) + override def toString: String = s"$value IN ${list.mkString("(", ",", ")")}" override def eval(input: InternalRow): Any = { val evaluatedValue = value.eval(input) - list.exists(e => e.eval(input) == evaluatedValue) + if (evaluatedValue == null) { + null + } else { + var hasNull = false + list.foreach { e => + val v = e.eval(input) + if (v == evaluatedValue) { + return true + } else if (v == null) { + hasNull = true + } + } + if (hasNull) { + null + } else { + false + } + } } -} + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val valueGen = value.gen(ctx) + val listGen = list.map(_.gen(ctx)) + val listCode = listGen.map(x => + s""" + if (!${ev.value}) { + ${x.code} + if (${x.isNull}) { + ${ev.isNull} = true; + } else if (${ctx.genEqual(value.dataType, valueGen.value, x.value)}) { + ${ev.isNull} = false; + ${ev.value} = true; + } + } + """).mkString("\n") + s""" + ${valueGen.code} + boolean ${ev.value} = false; + boolean ${ev.isNull} = ${valueGen.isNull}; + if (!${ev.isNull}) { + $listCode + } + """ + } +} /** * Optimized version of In clause, when all filter values of In clause are * static. */ -case class InSet(child: Expression, hset: Set[Any]) - extends UnaryExpression with Predicate with CodegenFallback { +case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with Predicate { + + require(hset != null, "hset could not be null") - override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN. override def toString: String = s"$child INSET ${hset.mkString("(", ",", ")")}" - override def eval(input: InternalRow): Any = { - hset.contains(child.eval(input)) + @transient private[this] lazy val hasNull: Boolean = hset.contains(null) + + override def nullable: Boolean = child.nullable || hasNull + + protected override def nullSafeEval(value: Any): Any = { + if (hset.contains(value)) { + true + } else if (hasNull) { + null + } else { + false + } + } + + def getHSet(): Set[Any] = hset + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val setName = classOf[Set[Any]].getName + val InSetName = classOf[InSet].getName + val childGen = child.gen(ctx) + ctx.references += this + val hsetTerm = ctx.freshName("hset") + val hasNullTerm = ctx.freshName("hasNull") + ctx.addMutableState(setName, hsetTerm, + s"$hsetTerm = (($InSetName)expressions[${ctx.references.size - 1}]).getHSet();") + ctx.addMutableState("boolean", hasNullTerm, s"$hasNullTerm = $hsetTerm.contains(null);") + s""" + ${childGen.code} + boolean ${ev.isNull} = ${childGen.isNull}; + boolean ${ev.value} = false; + if (!${ev.isNull}) { + ${ev.value} = $hsetTerm.contains(${childGen.value}); + if (!${ev.value} && $hasNullTerm) { + ${ev.isNull} = true; + } + } + """ } } @@ -157,14 +260,14 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with s""" ${eval1.code} boolean ${ev.isNull} = false; - boolean ${ev.primitive} = false; + boolean ${ev.value} = false; - if (!${eval1.isNull} && !${eval1.primitive}) { + if (!${eval1.isNull} && !${eval1.value}) { } else { ${eval2.code} - if (!${eval2.isNull} && !${eval2.primitive}) { + if (!${eval2.isNull} && !${eval2.value}) { } else if (!${eval1.isNull} && !${eval2.isNull}) { - ${ev.primitive} = true; + ${ev.value} = true; } else { ${ev.isNull} = true; } @@ -206,14 +309,14 @@ case class Or(left: Expression, right: Expression) 
extends BinaryOperator with P s""" ${eval1.code} boolean ${ev.isNull} = false; - boolean ${ev.primitive} = true; + boolean ${ev.value} = true; - if (!${eval1.isNull} && ${eval1.primitive}) { + if (!${eval1.isNull} && ${eval1.value}) { } else { ${eval2.code} - if (!${eval2.isNull} && ${eval2.primitive}) { + if (!${eval2.isNull} && ${eval2.value}) { } else if (!${eval1.isNull} && !${eval2.isNull}) { - ${ev.primitive} = false; + ${ev.value} = false; } else { ${ev.isNull} = true; } @@ -309,10 +412,10 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) - val equalCode = ctx.genEqual(left.dataType, eval1.primitive, eval2.primitive) + val equalCode = ctx.genEqual(left.dataType, eval1.value, eval2.value) ev.isNull = "false" eval1.code + eval2.code + s""" - boolean ${ev.primitive} = (${eval1.isNull} && ${eval2.isNull}) || + boolean ${ev.value} = (${eval1.isNull} && ${eval2.isNull}) || (!${eval1.isNull} && $equalCode); """ } @@ -325,7 +428,7 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso override def symbol: String = "<" - private lazy val ordering = TypeUtils.getOrdering(left.dataType) + private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lt(input1, input2) } @@ -337,7 +440,7 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo override def symbol: String = "<=" - private lazy val ordering = TypeUtils.getOrdering(left.dataType) + private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lteq(input1, input2) } @@ -349,7 +452,7 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar override def symbol: String = ">" - private lazy val ordering = TypeUtils.getOrdering(left.dataType) + private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gt(input1, input2) } @@ -361,7 +464,7 @@ case class GreaterThanOrEqual(left: Expression, right: Expression) extends Binar override def symbol: String = ">=" - private lazy val ordering = TypeUtils.getOrdering(left.dataType) + private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gteq(input1, input2) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala similarity index 95% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 62d3d204ca87..8bde8cb9fe87 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -69,7 +69,7 @@ case class Rand(seed: Long) extends RDG { s"$rngTerm = new $className(${seed}L + org.apache.spark.TaskContext.getPartitionId());") ev.isNull = "false" s""" - final ${ctx.javaType(dataType)} ${ev.primitive} = $rngTerm.nextDouble(); + final ${ctx.javaType(dataType)} ${ev.value} 
= $rngTerm.nextDouble(); """ } } @@ -92,7 +92,7 @@ case class Randn(seed: Long) extends RDG { s"$rngTerm = new $className(${seed}L + org.apache.spark.TaskContext.getPartitionId());") ev.isNull = "false" s""" - final ${ctx.javaType(dataType)} ${ev.primitive} = $rngTerm.nextGaussian(); + final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian(); """ } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala new file mode 100644 index 000000000000..adef6050c356 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -0,0 +1,346 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.util.regex.{MatchResult, Pattern} + +import org.apache.commons.lang3.StringEscapeUtils + +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + + +trait StringRegexExpression extends ImplicitCastInputTypes { + self: BinaryExpression => + + def escape(v: String): String + def matches(regex: Pattern, str: String): Boolean + + override def dataType: DataType = BooleanType + override def inputTypes: Seq[DataType] = Seq(StringType, StringType) + + // try cache the pattern for Literal + private lazy val cache: Pattern = right match { + case x @ Literal(value: String, StringType) => compile(value) + case _ => null + } + + protected def compile(str: String): Pattern = if (str == null) { + null + } else { + // Let it raise exception if couldn't compile the regex string + Pattern.compile(escape(str)) + } + + protected def pattern(str: String) = if (cache == null) compile(str) else cache + + protected override def nullSafeEval(input1: Any, input2: Any): Any = { + val regex = pattern(input2.asInstanceOf[UTF8String].toString) + if(regex == null) { + null + } else { + matches(regex, input1.asInstanceOf[UTF8String].toString) + } + } +} + + +/** + * Simple RegEx pattern matching function + */ +case class Like(left: Expression, right: Expression) + extends BinaryExpression with StringRegexExpression { + + override def escape(v: String): String = StringUtils.escapeLikeRegex(v) + + override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches() + + override def toString: String = s"$left LIKE $right" + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val patternClass = classOf[Pattern].getName + val escapeFunc = StringUtils.getClass.getName.stripSuffix("$") + 
".escapeLikeRegex" + val pattern = ctx.freshName("pattern") + + if (right.foldable) { + val rVal = right.eval() + if (rVal != null) { + val regexStr = + StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString())) + ctx.addMutableState(patternClass, pattern, + s"""$pattern = ${patternClass}.compile("$regexStr");""") + + // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. + val eval = left.gen(ctx) + s""" + ${eval.code} + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.value} = $pattern.matcher(${eval.value}.toString()).matches(); + } + """ + } else { + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + """ + } + } else { + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + s""" + String rightStr = ${eval2}.toString(); + ${patternClass} $pattern = ${patternClass}.compile($escapeFunc(rightStr)); + ${ev.value} = $pattern.matcher(${eval1}.toString()).matches(); + """ + }) + } + } +} + + +case class RLike(left: Expression, right: Expression) + extends BinaryExpression with StringRegexExpression { + + override def escape(v: String): String = v + override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) + override def toString: String = s"$left RLIKE $right" + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val patternClass = classOf[Pattern].getName + val pattern = ctx.freshName("pattern") + + if (right.foldable) { + val rVal = right.eval() + if (rVal != null) { + val regexStr = + StringEscapeUtils.escapeJava(rVal.asInstanceOf[UTF8String].toString()) + ctx.addMutableState(patternClass, pattern, + s"""$pattern = ${patternClass}.compile("$regexStr");""") + + // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. + val eval = left.gen(ctx) + s""" + ${eval.code} + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.value} = $pattern.matcher(${eval.value}.toString()).find(0); + } + """ + } else { + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + """ + } + } else { + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + s""" + String rightStr = ${eval2}.toString(); + ${patternClass} $pattern = ${patternClass}.compile(rightStr); + ${ev.value} = $pattern.matcher(${eval1}.toString()).find(0); + """ + }) + } + } +} + + +/** + * Splits str around pat (pattern is a regular expression). + */ +case class StringSplit(str: Expression, pattern: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def left: Expression = str + override def right: Expression = pattern + override def dataType: DataType = ArrayType(StringType) + override def inputTypes: Seq[DataType] = Seq(StringType, StringType) + + override def nullSafeEval(string: Any, regex: Any): Any = { + val strings = string.asInstanceOf[UTF8String].split(regex.asInstanceOf[UTF8String], -1) + new GenericArrayData(strings.asInstanceOf[Array[Any]]) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val arrayClass = classOf[GenericArrayData].getName + nullSafeCodeGen(ctx, ev, (str, pattern) => + // Array in java is covariant, so we don't need to cast UTF8String[] to Object[]. 
+ s"""${ev.value} = new $arrayClass($str.split($pattern, -1));""") + } + + override def prettyName: String = "split" +} + + +/** + * Replace all substrings of str that match regexp with rep. + * + * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status. + */ +case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + + // last regex in string, we will update the pattern iff regexp value changed. + @transient private var lastRegex: UTF8String = _ + // last regex pattern, we cache it for performance concern + @transient private var pattern: Pattern = _ + // last replacement string, we don't want to convert a UTF8String => java.langString every time. + @transient private var lastReplacement: String = _ + @transient private var lastReplacementInUTF8: UTF8String = _ + // result buffer write by Matcher + @transient private val result: StringBuffer = new StringBuffer + + override def nullSafeEval(s: Any, p: Any, r: Any): Any = { + if (!p.equals(lastRegex)) { + // regex value changed + lastRegex = p.asInstanceOf[UTF8String].clone() + pattern = Pattern.compile(lastRegex.toString) + } + if (!r.equals(lastReplacementInUTF8)) { + // replacement string changed + lastReplacementInUTF8 = r.asInstanceOf[UTF8String].clone() + lastReplacement = lastReplacementInUTF8.toString + } + val m = pattern.matcher(s.toString()) + result.delete(0, result.length()) + + while (m.find) { + m.appendReplacement(result, lastReplacement) + } + m.appendTail(result) + + UTF8String.fromString(result.toString) + } + + override def dataType: DataType = StringType + override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType) + override def children: Seq[Expression] = subject :: regexp :: rep :: Nil + override def prettyName: String = "regexp_replace" + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val termLastRegex = ctx.freshName("lastRegex") + val termPattern = ctx.freshName("pattern") + + val termLastReplacement = ctx.freshName("lastReplacement") + val termLastReplacementInUTF8 = ctx.freshName("lastReplacementInUTF8") + + val termResult = ctx.freshName("result") + + val classNamePattern = classOf[Pattern].getCanonicalName + val classNameStringBuffer = classOf[java.lang.StringBuffer].getCanonicalName + + ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;") + ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;") + ctx.addMutableState("String", termLastReplacement, s"${termLastReplacement} = null;") + ctx.addMutableState("UTF8String", + termLastReplacementInUTF8, s"${termLastReplacementInUTF8} = null;") + ctx.addMutableState(classNameStringBuffer, + termResult, s"${termResult} = new $classNameStringBuffer();") + + nullSafeCodeGen(ctx, ev, (subject, regexp, rep) => { + s""" + if (!$regexp.equals(${termLastRegex})) { + // regex value changed + ${termLastRegex} = $regexp.clone(); + ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); + } + if (!$rep.equals(${termLastReplacementInUTF8})) { + // replacement string changed + ${termLastReplacementInUTF8} = $rep.clone(); + ${termLastReplacement} = ${termLastReplacementInUTF8}.toString(); + } + ${termResult}.delete(0, ${termResult}.length()); + java.util.regex.Matcher m = ${termPattern}.matcher($subject.toString()); + + while (m.find()) { + m.appendReplacement(${termResult}, ${termLastReplacement}); + } + 
m.appendTail(${termResult}); + ${ev.value} = UTF8String.fromString(${termResult}.toString()); + ${ev.isNull} = false; + """ + }) + } +} + +/** + * Extract a specific(idx) group identified by a Java regex. + * + * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status. + */ +case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + def this(s: Expression, r: Expression) = this(s, r, Literal(1)) + + // last regex in string, we will update the pattern iff regexp value changed. + @transient private var lastRegex: UTF8String = _ + // last regex pattern, we cache it for performance concern + @transient private var pattern: Pattern = _ + + override def nullSafeEval(s: Any, p: Any, r: Any): Any = { + if (!p.equals(lastRegex)) { + // regex value changed + lastRegex = p.asInstanceOf[UTF8String].clone() + pattern = Pattern.compile(lastRegex.toString) + } + val m = pattern.matcher(s.toString) + if (m.find) { + val mr: MatchResult = m.toMatchResult + UTF8String.fromString(mr.group(r.asInstanceOf[Int])) + } else { + UTF8String.EMPTY_UTF8 + } + } + + override def dataType: DataType = StringType + override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType) + override def children: Seq[Expression] = subject :: regexp :: idx :: Nil + override def prettyName: String = "regexp_extract" + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val termLastRegex = ctx.freshName("lastRegex") + val termPattern = ctx.freshName("pattern") + val classNamePattern = classOf[Pattern].getCanonicalName + + ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;") + ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;") + + nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => { + s""" + if (!$regexp.equals(${termLastRegex})) { + // regex value changed + ${termLastRegex} = $regexp.clone(); + ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); + } + java.util.regex.Matcher m = + ${termPattern}.matcher($subject.toString()); + if (m.find()) { + java.util.regex.MatchResult mr = m.toMatchResult(); + ${ev.value} = UTF8String.fromString(mr.group($idx)); + ${ev.isNull} = false; + } else { + ${ev.value} = UTF8String.EMPTY_UTF8; + ${ev.isNull} = false; + }""" + }) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index d04434b953e4..cfc68fc00bea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -19,8 +19,143 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.types.{Decimal, DataType, StructType, AtomicType} -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.sql.catalyst.util.{MapData, ArrayData} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} + +/** + * An extended version of [[InternalRow]] that implements all special getters, toString + * and equals/hashCode by `genericGet`. 
+ */ +trait BaseGenericInternalRow extends InternalRow { + + protected def genericGet(ordinal: Int): Any + + // default implementation (slow) + private def getAs[T](ordinal: Int) = genericGet(ordinal).asInstanceOf[T] + override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null + override def get(ordinal: Int, dataType: DataType): AnyRef = getAs(ordinal) + override def getBoolean(ordinal: Int): Boolean = getAs(ordinal) + override def getByte(ordinal: Int): Byte = getAs(ordinal) + override def getShort(ordinal: Int): Short = getAs(ordinal) + override def getInt(ordinal: Int): Int = getAs(ordinal) + override def getLong(ordinal: Int): Long = getAs(ordinal) + override def getFloat(ordinal: Int): Float = getAs(ordinal) + override def getDouble(ordinal: Int): Double = getAs(ordinal) + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal) + override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal) + override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal) + override def getArray(ordinal: Int): ArrayData = getAs(ordinal) + override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal) + override def getMap(ordinal: Int): MapData = getAs(ordinal) + override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) + + override def anyNull: Boolean = { + val len = numFields + var i = 0 + while (i < len) { + if (isNullAt(i)) { return true } + i += 1 + } + false + } + + override def toString: String = { + if (numFields == 0) { + "[empty row]" + } else { + val sb = new StringBuilder + sb.append("[") + sb.append(genericGet(0)) + val len = numFields + var i = 1 + while (i < len) { + sb.append(",") + sb.append(genericGet(i)) + i += 1 + } + sb.append("]") + sb.toString() + } + } + + override def equals(o: Any): Boolean = { + if (!o.isInstanceOf[BaseGenericInternalRow]) { + return false + } + + val other = o.asInstanceOf[BaseGenericInternalRow] + if (other eq null) { + return false + } + + val len = numFields + if (len != other.numFields) { + return false + } + + var i = 0 + while (i < len) { + if (isNullAt(i) != other.isNullAt(i)) { + return false + } + if (!isNullAt(i)) { + val o1 = genericGet(i) + val o2 = other.genericGet(i) + o1 match { + case b1: Array[Byte] => + if (!o2.isInstanceOf[Array[Byte]] || + !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { + return false + } + case f1: Float if java.lang.Float.isNaN(f1) => + if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) { + return false + } + case d1: Double if java.lang.Double.isNaN(d1) => + if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) { + return false + } + case _ => if (o1 != o2) { + return false + } + } + } + i += 1 + } + true + } + + // Custom hashCode function that matches the efficient code generated version. 
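// Editor's sketch, not part of the patch: the equals above intentionally differs from plain
// Scala/Java comparisons in two ways -- NaN compares equal to NaN, and byte arrays compare by
// content. With the assumed rows below, both quirks are visible:
val r1 = new GenericInternalRow(Array[Any](Double.NaN, Array[Byte](1, 2)))
val r2 = new GenericInternalRow(Array[Any](Double.NaN, Array[Byte](1, 2)))
r1 == r2                                // true: row equality handles NaN and byte-array content
Double.NaN == Double.NaN                // false for plain primitives
Array[Byte](1, 2) == Array[Byte](1, 2)  // false: plain arrays use reference equality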
+ override def hashCode: Int = { + var result: Int = 37 + var i = 0 + val len = numFields + while (i < len) { + val update: Int = + if (isNullAt(i)) { + 0 + } else { + genericGet(i) match { + case b: Boolean => if (b) 0 else 1 + case b: Byte => b.toInt + case s: Short => s.toInt + case i: Int => i + case l: Long => (l ^ (l >>> 32)).toInt + case f: Float => java.lang.Float.floatToIntBits(f) + case d: Double => + val b = java.lang.Double.doubleToLongBits(d) + (b ^ (b >>> 32)).toInt + case a: Array[Byte] => java.util.Arrays.hashCode(a) + case other => other.hashCode() + } + } + result = 37 * result + update + i += 1 + } + result + } +} /** * An extended interface to [[InternalRow]] that allows the values for each column to be updated. @@ -39,6 +174,13 @@ abstract class MutableRow extends InternalRow { def setLong(i: Int, value: Long): Unit = { update(i, value) } def setFloat(i: Int, value: Float): Unit = { update(i, value) } def setDouble(i: Int, value: Double): Unit = { update(i, value) } + + /** + * Update the decimal column at `i`. + * + * Note: In order to support update decimal with precision > 18 in UnsafeRow, + * CAN NOT call setNullAt() for decimal column on UnsafeRow, call setDecimal(i, null, precision). + */ def setDecimal(i: Int, value: Decimal, precision: Int) { update(i, value) } } @@ -76,15 +218,15 @@ class GenericRowWithSchema(values: Array[Any], override val schema: StructType) * Note that, while the array is not copied, and thus could technically be mutated after creation, * this is not allowed. */ -class GenericInternalRow(protected[sql] val values: Array[Any]) extends InternalRow { +class GenericInternalRow(private[sql] val values: Array[Any]) extends BaseGenericInternalRow { /** No-arg constructor for serialization. */ protected def this() = this(null) def this(size: Int) = this(new Array[Any](size)) - override def genericGet(ordinal: Int): Any = values(ordinal) + override protected def genericGet(ordinal: Int) = values(ordinal) - override def toSeq: Seq[Any] = values + override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = values override def numFields: Int = values.length @@ -103,15 +245,15 @@ class GenericInternalRowWithSchema(values: Array[Any], val schema: StructType) def fieldIndex(name: String): Int = schema.fieldIndex(name) } -class GenericMutableRow(val values: Array[Any]) extends MutableRow { +class GenericMutableRow(values: Array[Any]) extends MutableRow with BaseGenericInternalRow { /** No-arg constructor for serialization. */ protected def this() = this(null) def this(size: Int) = this(new Array[Any](size)) - override def genericGet(ordinal: Int): Any = values(ordinal) + override protected def genericGet(ordinal: Int) = values(ordinal) - override def toSeq: Seq[Any] = values + override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = values override def numFields: Int = values.length diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala deleted file mode 100644 index 5b0fe8dfe2fc..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ /dev/null @@ -1,194 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.types._ -import org.apache.spark.util.collection.OpenHashSet - -/** The data type for expressions returning an OpenHashSet as the result. */ -private[sql] class OpenHashSetUDT( - val elementType: DataType) extends UserDefinedType[OpenHashSet[Any]] { - - override def sqlType: DataType = ArrayType(elementType) - - /** Since we are using OpenHashSet internally, usually it will not be called. */ - override def serialize(obj: Any): Seq[Any] = { - obj.asInstanceOf[OpenHashSet[Any]].iterator.toSeq - } - - /** Since we are using OpenHashSet internally, usually it will not be called. */ - override def deserialize(datum: Any): OpenHashSet[Any] = { - val iterator = datum.asInstanceOf[Seq[Any]].iterator - val set = new OpenHashSet[Any] - while(iterator.hasNext) { - set.add(iterator.next()) - } - - set - } - - override def userClass: Class[OpenHashSet[Any]] = classOf[OpenHashSet[Any]] - - private[spark] override def asNullable: OpenHashSetUDT = this -} - -/** - * Creates a new set of the specified type - */ -case class NewSet(elementType: DataType) extends LeafExpression with CodegenFallback { - - override def nullable: Boolean = false - - override def dataType: OpenHashSetUDT = new OpenHashSetUDT(elementType) - - override def eval(input: InternalRow): Any = { - new OpenHashSet[Any]() - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - elementType match { - case IntegerType | LongType => - ev.isNull = "false" - s""" - ${ctx.javaType(dataType)} ${ev.primitive} = new ${ctx.javaType(dataType)}(); - """ - case _ => super.genCode(ctx, ev) - } - } - - override def toString: String = s"new Set($dataType)" -} - -/** - * Adds an item to a set. - * For performance, this expression mutates its input during evaluation. - * Note: this expression is internal and created only by the GeneratedAggregate, - * we don't need to do type check for it. 
- */ -case class AddItemToSet(item: Expression, set: Expression) - extends Expression with CodegenFallback { - - override def children: Seq[Expression] = item :: set :: Nil - - override def nullable: Boolean = set.nullable - - override def dataType: DataType = set.dataType - - override def eval(input: InternalRow): Any = { - val itemEval = item.eval(input) - val setEval = set.eval(input).asInstanceOf[OpenHashSet[Any]] - - if (itemEval != null) { - if (setEval != null) { - setEval.add(itemEval) - setEval - } else { - null - } - } else { - setEval - } - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType - elementType match { - case IntegerType | LongType => - val itemEval = item.gen(ctx) - val setEval = set.gen(ctx) - val htype = ctx.javaType(dataType) - - ev.isNull = "false" - ev.primitive = setEval.primitive - itemEval.code + setEval.code + s""" - if (!${itemEval.isNull} && !${setEval.isNull}) { - (($htype)${setEval.primitive}).add(${itemEval.primitive}); - } - """ - case _ => super.genCode(ctx, ev) - } - } - - override def toString: String = s"$set += $item" -} - -/** - * Combines the elements of two sets. - * For performance, this expression mutates its left input set during evaluation. - * Note: this expression is internal and created only by the GeneratedAggregate, - * we don't need to do type check for it. - */ -case class CombineSets(left: Expression, right: Expression) - extends BinaryExpression with CodegenFallback { - - override def nullable: Boolean = left.nullable - override def dataType: DataType = left.dataType - - override def eval(input: InternalRow): Any = { - val leftEval = left.eval(input).asInstanceOf[OpenHashSet[Any]] - if(leftEval != null) { - val rightEval = right.eval(input).asInstanceOf[OpenHashSet[Any]] - if (rightEval != null) { - val iterator = rightEval.iterator - while(iterator.hasNext) { - val rightValue = iterator.next() - leftEval.add(rightValue) - } - } - leftEval - } else { - null - } - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType - elementType match { - case IntegerType | LongType => - val leftEval = left.gen(ctx) - val rightEval = right.gen(ctx) - val htype = ctx.javaType(dataType) - - ev.isNull = leftEval.isNull - ev.primitive = leftEval.primitive - leftEval.code + rightEval.code + s""" - if (!${leftEval.isNull} && !${rightEval.isNull}) { - ${leftEval.primitive}.union((${htype})${rightEval.primitive}); - } - """ - case _ => super.genCode(ctx, ev) - } - } -} - -/** - * Returns the number of elements in the input set. - * Note: this expression is internal and created only by the GeneratedAggregate, - * we don't need to do type check for it. 
- */ -case class CountSet(child: Expression) extends UnaryExpression with CodegenFallback { - - override def dataType: DataType = LongType - - protected override def nullSafeEval(input: Any): Any = - input.asInstanceOf[OpenHashSet[Any]].size.toLong - - override def toString: String = s"$child.count()" -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala similarity index 67% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 56225290cd6b..8770c4b76c2e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -18,17 +18,13 @@ package org.apache.spark.sql.catalyst.expressions import java.text.DecimalFormat -import java.util.Arrays -import java.util.Locale -import java.util.regex.{MatchResult, Pattern} - -import org.apache.commons.lang3.StringEscapeUtils +import java.util.{HashMap, Locale, Map => JMap} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.StringUtils +import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types.{ByteArray, UTF8String} //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines expressions for string operations. @@ -55,12 +51,12 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val evals = children.map(_.gen(ctx)) val inputs = evals.map { eval => - s"${eval.isNull} ? null : ${eval.primitive}" + s"${eval.isNull} ? null : ${eval.value}" }.mkString(", ") evals.map(_.code).mkString("\n") + s""" boolean ${ev.isNull} = false; - UTF8String ${ev.primitive} = UTF8String.concat($inputs); - if (${ev.primitive} == null) { + UTF8String ${ev.value} = UTF8String.concat($inputs); + if (${ev.value} == null) { ${ev.isNull} = true; } """ @@ -75,7 +71,7 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas * Returns null if the separator is null. Otherwise, concat_ws skips all null values. */ case class ConcatWs(children: Seq[Expression]) - extends Expression with ImplicitCastInputTypes with CodegenFallback { + extends Expression with ImplicitCastInputTypes { require(children.nonEmpty, s"$prettyName requires at least one argument.") @@ -109,157 +105,56 @@ case class ConcatWs(children: Seq[Expression]) val evals = children.map(_.gen(ctx)) val inputs = evals.map { eval => - s"${eval.isNull} ? (UTF8String) null : ${eval.primitive}" + s"${eval.isNull} ? (UTF8String) null : ${eval.value}" }.mkString(", ") evals.map(_.code).mkString("\n") + s""" - UTF8String ${ev.primitive} = UTF8String.concatWs($inputs); - boolean ${ev.isNull} = ${ev.primitive} == null; + UTF8String ${ev.value} = UTF8String.concatWs($inputs); + boolean ${ev.isNull} = ${ev.value} == null; """ } else { - // Contains a mix of strings and arrays. Fall back to interpreted mode for now. 
- super.genCode(ctx, ev) - } - } -} - - -trait StringRegexExpression extends ImplicitCastInputTypes { - self: BinaryExpression => - - def escape(v: String): String - def matches(regex: Pattern, str: String): Boolean - - override def dataType: DataType = BooleanType - override def inputTypes: Seq[DataType] = Seq(StringType, StringType) - - // try cache the pattern for Literal - private lazy val cache: Pattern = right match { - case x @ Literal(value: String, StringType) => compile(value) - case _ => null - } + val array = ctx.freshName("array") + val varargNum = ctx.freshName("varargNum") + val idxInVararg = ctx.freshName("idxInVararg") - protected def compile(str: String): Pattern = if (str == null) { - null - } else { - // Let it raise exception if couldn't compile the regex string - Pattern.compile(escape(str)) - } - - protected def pattern(str: String) = if (cache == null) compile(str) else cache - - protected override def nullSafeEval(input1: Any, input2: Any): Any = { - val regex = pattern(input2.asInstanceOf[UTF8String].toString()) - if(regex == null) { - null - } else { - matches(regex, input1.asInstanceOf[UTF8String].toString()) - } - } -} - -/** - * Simple RegEx pattern matching function - */ -case class Like(left: Expression, right: Expression) - extends BinaryExpression with StringRegexExpression with CodegenFallback { - - override def escape(v: String): String = StringUtils.escapeLikeRegex(v) - - override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches() - - override def toString: String = s"$left LIKE $right" - - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val patternClass = classOf[Pattern].getName - val escapeFunc = StringUtils.getClass.getName.stripSuffix("$") + ".escapeLikeRegex" - val pattern = ctx.freshName("pattern") - - if (right.foldable) { - val rVal = right.eval() - if (rVal != null) { - val regexStr = - StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString())) - ctx.addMutableState(patternClass, pattern, - s"""$pattern = ${patternClass}.compile("$regexStr");""") - - // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. - val eval = left.gen(ctx) - s""" - ${eval.code} - boolean ${ev.isNull} = ${eval.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${ev.primitive} = $pattern.matcher(${eval.primitive}.toString()).matches(); - } - """ - } else { - s""" - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - """ - } - } else { - nullSafeCodeGen(ctx, ev, (eval1, eval2) => { - s""" - String rightStr = ${eval2}.toString(); - ${patternClass} $pattern = ${patternClass}.compile($escapeFunc(rightStr)); - ${ev.primitive} = $pattern.matcher(${eval1}.toString()).matches(); - """ - }) - } - } -} - - -case class RLike(left: Expression, right: Expression) - extends BinaryExpression with StringRegexExpression with CodegenFallback { - - override def escape(v: String): String = v - override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) - override def toString: String = s"$left RLIKE $right" + val evals = children.map(_.gen(ctx)) + val (varargCount, varargBuild) = children.tail.zip(evals.tail).map { case (child, eval) => + child.dataType match { + case StringType => + ("", // we count all the StringType arguments num at once below. + s"$array[$idxInVararg ++] = ${eval.isNull} ? 
(UTF8String) null : ${eval.value};") + case _: ArrayType => + val size = ctx.freshName("n") + (s""" + if (!${eval.isNull}) { + $varargNum += ${eval.value}.numElements(); + } + """, + s""" + if (!${eval.isNull}) { + final int $size = ${eval.value}.numElements(); + for (int j = 0; j < $size; j ++) { + $array[$idxInVararg ++] = ${ctx.getValue(eval.value, StringType, "j")}; + } + } + """) + } + }.unzip - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val patternClass = classOf[Pattern].getName - val pattern = ctx.freshName("pattern") - - if (right.foldable) { - val rVal = right.eval() - if (rVal != null) { - val regexStr = - StringEscapeUtils.escapeJava(rVal.asInstanceOf[UTF8String].toString()) - ctx.addMutableState(patternClass, pattern, - s"""$pattern = ${patternClass}.compile("$regexStr");""") - - // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. - val eval = left.gen(ctx) - s""" - ${eval.code} - boolean ${ev.isNull} = ${eval.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${ev.primitive} = $pattern.matcher(${eval.primitive}.toString()).find(0); - } - """ - } else { - s""" - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - """ - } - } else { - nullSafeCodeGen(ctx, ev, (eval1, eval2) => { - s""" - String rightStr = ${eval2}.toString(); - ${patternClass} $pattern = ${patternClass}.compile(rightStr); - ${ev.primitive} = $pattern.matcher(${eval1}.toString()).find(0); - """ - }) + evals.map(_.code).mkString("\n") + + s""" + int $varargNum = ${children.count(_.dataType == StringType) - 1}; + int $idxInVararg = 0; + ${varargCount.mkString("\n")} + UTF8String[] $array = new UTF8String[$varargNum]; + ${varargBuild.mkString("\n")} + UTF8String ${ev.value} = UTF8String.concatWs(${evals.head.value}, $array); + boolean ${ev.isNull} = ${ev.value} == null; + """ } } } - trait String2StringExpression extends ImplicitCastInputTypes { self: UnaryExpression => @@ -304,7 +199,7 @@ case class Lower(child: Expression) extends UnaryExpression with String2StringEx } /** A base trait for functions that compare two strings, returning a boolean. */ -trait StringComparison extends ImplicitCastInputTypes { +trait StringPredicate extends Predicate with ImplicitCastInputTypes { self: BinaryExpression => def compare(l: UTF8String, r: UTF8String): Boolean @@ -321,7 +216,7 @@ trait StringComparison extends ImplicitCastInputTypes { * A function that returns true if the string `left` contains the string `right`. */ case class Contains(left: Expression, right: Expression) - extends BinaryExpression with Predicate with StringComparison { + extends BinaryExpression with StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).contains($c2)") @@ -332,7 +227,7 @@ case class Contains(left: Expression, right: Expression) * A function that returns true if the string `left` starts with the string `right`. 
*/ case class StartsWith(left: Expression, right: Expression) - extends BinaryExpression with Predicate with StringComparison { + extends BinaryExpression with StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).startsWith($c2)") @@ -343,13 +238,110 @@ case class StartsWith(left: Expression, right: Expression) * A function that returns true if the string `left` ends with the string `right`. */ case class EndsWith(left: Expression, right: Expression) - extends BinaryExpression with Predicate with StringComparison { + extends BinaryExpression with StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).endsWith($c2)") } } +object StringTranslate { + + def buildDict(matchingString: UTF8String, replaceString: UTF8String) + : JMap[Character, Character] = { + val matching = matchingString.toString() + val replace = replaceString.toString() + val dict = new HashMap[Character, Character]() + var i = 0 + while (i < matching.length()) { + val rep = if (i < replace.length()) replace.charAt(i) else '\u0000' + if (null == dict.get(matching.charAt(i))) { + dict.put(matching.charAt(i), rep) + } + i += 1 + } + dict + } +} + +/** + * A function translate any character in the `srcExpr` by a character in `replaceExpr`. + * The characters in `replaceExpr` is corresponding to the characters in `matchingExpr`. + * The translate will happen when any character in the string matching with the character + * in the `matchingExpr`. + */ +case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replaceExpr: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + + @transient private var lastMatching: UTF8String = _ + @transient private var lastReplace: UTF8String = _ + @transient private var dict: JMap[Character, Character] = _ + + override def nullSafeEval(srcEval: Any, matchingEval: Any, replaceEval: Any): Any = { + if (matchingEval != lastMatching || replaceEval != lastReplace) { + lastMatching = matchingEval.asInstanceOf[UTF8String].clone() + lastReplace = replaceEval.asInstanceOf[UTF8String].clone() + dict = StringTranslate.buildDict(lastMatching, lastReplace) + } + srcEval.asInstanceOf[UTF8String].translate(dict) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val termLastMatching = ctx.freshName("lastMatching") + val termLastReplace = ctx.freshName("lastReplace") + val termDict = ctx.freshName("dict") + val classNameDict = classOf[JMap[Character, Character]].getCanonicalName + + ctx.addMutableState("UTF8String", termLastMatching, s"${termLastMatching} = null;") + ctx.addMutableState("UTF8String", termLastReplace, s"${termLastReplace} = null;") + ctx.addMutableState(classNameDict, termDict, s"${termDict} = null;") + + nullSafeCodeGen(ctx, ev, (src, matching, replace) => { + val check = if (matchingExpr.foldable && replaceExpr.foldable) { + s"${termDict} == null" + } else { + s"!${matching}.equals(${termLastMatching}) || !${replace}.equals(${termLastReplace})" + } + s"""if ($check) { + // Not all of them is literal or matching or replace value changed + ${termLastMatching} = ${matching}.clone(); + ${termLastReplace} = ${replace}.clone(); + ${termDict} = 
org.apache.spark.sql.catalyst.expressions.StringTranslate + .buildDict(${termLastMatching}, ${termLastReplace}); + } + ${ev.value} = ${src}.translate(${termDict}); + """ + }) + } + + override def dataType: DataType = StringType + override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType) + override def children: Seq[Expression] = srcExpr :: matchingExpr :: replaceExpr :: Nil + override def prettyName: String = "translate" +} + +/** + * A function that returns the index (1-based) of the given string (left) in the comma- + * delimited list (right). Returns 0, if the string wasn't found or if the given + * string (left) contains a comma. + */ +case class FindInSet(left: Expression, right: Expression) extends BinaryExpression + with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) + + override protected def nullSafeEval(word: Any, set: Any): Any = + set.asInstanceOf[UTF8String].findInSet(word.asInstanceOf[UTF8String]) + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (word, set) => + s"${ev.value} = $set.findInSet($word);" + ) + } + + override def dataType: DataType = IntegerType +} + /** * A function that trim the spaces from both ends for the specified string. */ @@ -452,13 +444,14 @@ case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: * in given string after position pos. */ case class StringLocate(substr: Expression, str: Expression, start: Expression) - extends TernaryExpression with ImplicitCastInputTypes with CodegenFallback { + extends TernaryExpression with ImplicitCastInputTypes { def this(substr: Expression, str: Expression) = { this(substr, str, Literal(0)) } override def children: Seq[Expression] = substr :: str :: start :: Nil + override def nullable: Boolean = substr.nullable || str.nullable override def dataType: DataType = IntegerType override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) @@ -484,6 +477,31 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) } } + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val substrGen = substr.gen(ctx) + val strGen = str.gen(ctx) + val startGen = start.gen(ctx) + s""" + int ${ev.value} = 0; + boolean ${ev.isNull} = false; + ${startGen.code} + if (!${startGen.isNull}) { + ${substrGen.code} + if (!${substrGen.isNull}) { + ${strGen.code} + if (!${strGen.isNull}) { + ${ev.value} = ${strGen.value}.indexOf(${substrGen.value}, + ${startGen.value}) + 1; + } else { + ${ev.isNull} = true; + } + } else { + ${ev.isNull} = true; + } + } + """ + } + override def prettyName: String = "locate" } @@ -569,9 +587,9 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC if (ctx.boxedType(v._1) != ctx.javaType(v._1)) { // Java primitives get boxed in order to allow null values. s"(${v._2.isNull}) ? (${ctx.boxedType(v._1)}) null : " + - s"new ${ctx.boxedType(v._1)}(${v._2.primitive})" + s"new ${ctx.boxedType(v._1)}(${v._2.value})" } else { - s"(${v._2.isNull}) ? null : ${v._2.primitive}" + s"(${v._2.isNull}) ? 
null : ${v._2.value}" } s + "," + nullSafeString }) @@ -583,13 +601,13 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC s""" ${pattern.code} boolean ${ev.isNull} = ${pattern.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { ${argListCode.mkString} $stringBuffer $sb = new $stringBuffer(); $formatter $form = new $formatter($sb, ${classOf[Locale].getName}.US); - $form.format(${pattern.primitive}.toString() $argListString); - ${ev.primitive} = UTF8String.fromString($sb.toString()); + $form.format(${pattern.value}.toString() $argListString); + ${ev.value} = UTF8String.fromString($sb.toString()); } """ } @@ -665,66 +683,12 @@ case class StringSpace(child: Expression) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { nullSafeCodeGen(ctx, ev, (length) => - s"""${ev.primitive} = UTF8String.blankString(($length < 0) ? 0 : $length);""") + s"""${ev.value} = UTF8String.blankString(($length < 0) ? 0 : $length);""") } override def prettyName: String = "space" } -/** - * Splits str around pat (pattern is a regular expression). - */ -case class StringSplit(str: Expression, pattern: Expression) - extends BinaryExpression with ImplicitCastInputTypes { - - override def left: Expression = str - override def right: Expression = pattern - override def dataType: DataType = ArrayType(StringType) - override def inputTypes: Seq[DataType] = Seq(StringType, StringType) - - override def nullSafeEval(string: Any, regex: Any): Any = { - val strings = string.asInstanceOf[UTF8String].split(regex.asInstanceOf[UTF8String], -1) - new GenericArrayData(strings.asInstanceOf[Array[Any]]) - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val arrayClass = classOf[GenericArrayData].getName - nullSafeCodeGen(ctx, ev, (str, pattern) => - // Array in java is covariant, so we don't need to cast UTF8String[] to Object[]. - s"""${ev.primitive} = new $arrayClass($str.split($pattern, -1));""") - } - - override def prettyName: String = "split" -} - -object Substring { - def subStringBinarySQL(bytes: Array[Byte], pos: Int, len: Int): Array[Byte] = { - if (pos > bytes.length) { - return Array[Byte]() - } - - var start = if (pos > 0) { - pos - 1 - } else if (pos < 0) { - bytes.length + pos - } else { - 0 - } - - val end = if ((bytes.length - start) < len) { - bytes.length - } else { - start + len - } - - start = Math.max(start, 0) // underflow - if (start < end) { - Arrays.copyOfRange(bytes, start, end) - } else { - Array[Byte]() - } - } -} /** * A function that takes a substring of its first argument starting at a given position. * Defined for String and Binary types. 
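
For reference, the two helpers below are a dependency-free Scala sketch of the semantics the hunks above rely on: the 1-based, comma-delimited lookup documented for FindInSet, and the 1-based SQL substring over byte arrays that this patch moves from Substring.subStringBinarySQL into ByteArray.subStringSQL. The object name StringExprSketches is illustrative; these are stand-ins for reading along, not the Spark implementations.

import java.util.Arrays

object StringExprSketches {
  // find_in_set: 1-based position of `word` in the comma-delimited `set`,
  // 0 when it is absent or when `word` itself contains a comma.
  def findInSet(word: String, set: String): Int =
    if (word.contains(",")) 0
    else set.split(",", -1).indexOf(word) + 1

  // SQL substring on a byte array: `pos` is 1-based and may be negative
  // (counted from the end), mirroring the removed subStringBinarySQL helper.
  def subStringSQL(bytes: Array[Byte], pos: Int, len: Int): Array[Byte] = {
    if (pos > bytes.length) return Array.emptyByteArray
    var start = if (pos > 0) pos - 1 else if (pos < 0) bytes.length + pos else 0
    val end = if (bytes.length - start < len) bytes.length else start + len
    start = math.max(start, 0) // clamp underflow caused by a large negative pos
    if (start < end) Arrays.copyOfRange(bytes, start, end) else Array.emptyByteArray
  }

  def main(args: Array[String]): Unit = {
    println(findInSet("b", "a,b,c"))                           // 2
    println(findInSet("a,b", "a,b,c"))                         // 0
    println(new String(subStringSQL("spark".getBytes, 2, 3)))  // par
    println(new String(subStringSQL("spark".getBytes, -3, 2))) // ar
  }
}
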
@@ -747,18 +711,17 @@ case class Substring(str: Expression, pos: Expression, len: Expression) str.dataType match { case StringType => string.asInstanceOf[UTF8String] .substringSQL(pos.asInstanceOf[Int], len.asInstanceOf[Int]) - case BinaryType => Substring.subStringBinarySQL(string.asInstanceOf[Array[Byte]], + case BinaryType => ByteArray.subStringSQL(string.asInstanceOf[Array[Byte]], pos.asInstanceOf[Int], len.asInstanceOf[Int]) } } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val cls = classOf[Substring].getName defineCodeGen(ctx, ev, (string, pos, len) => { str.dataType match { case StringType => s"$string.substringSQL($pos, $len)" - case BinaryType => s"$cls.subStringBinarySQL($string, $pos, $len)" + case BinaryType => s"${classOf[ByteArray].getName}.subStringSQL($string, $pos, $len)" } }) } @@ -798,7 +761,7 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { nullSafeCodeGen(ctx, ev, (left, right) => - s"${ev.primitive} = $left.levenshteinDistance($right);") + s"${ev.value} = $left.levenshteinDistance($right);") } } @@ -841,9 +804,9 @@ case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInp s""" byte[] $bytes = $child.getBytes(); if ($bytes.length > 0) { - ${ev.primitive} = (int) $bytes[0]; + ${ev.value} = (int) $bytes[0]; } else { - ${ev.primitive} = 0; + ${ev.value} = 0; } """}) } @@ -865,7 +828,7 @@ case class Base64(child: Expression) extends UnaryExpression with ImplicitCastIn override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { nullSafeCodeGen(ctx, ev, (child) => { - s"""${ev.primitive} = UTF8String.fromBytes( + s"""${ev.value} = UTF8String.fromBytes( org.apache.commons.codec.binary.Base64.encodeBase64($child)); """}) } @@ -886,7 +849,7 @@ case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCast override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { nullSafeCodeGen(ctx, ev, (child) => { s""" - ${ev.primitive} = org.apache.commons.codec.binary.Base64.decodeBase64($child.toString()); + ${ev.value} = org.apache.commons.codec.binary.Base64.decodeBase64($child.toString()); """}) } } @@ -913,9 +876,9 @@ case class Decode(bin: Expression, charset: Expression) nullSafeCodeGen(ctx, ev, (bytes, charset) => s""" try { - ${ev.primitive} = UTF8String.fromString(new String($bytes, $charset.toString())); + ${ev.value} = UTF8String.fromString(new String($bytes, $charset.toString())); } catch (java.io.UnsupportedEncodingException e) { - org.apache.spark.unsafe.PlatformDependent.throwException(e); + org.apache.spark.unsafe.Platform.throwException(e); } """) } @@ -943,170 +906,13 @@ case class Encode(value: Expression, charset: Expression) nullSafeCodeGen(ctx, ev, (string, charset) => s""" try { - ${ev.primitive} = $string.toString().getBytes($charset.toString()); + ${ev.value} = $string.toString().getBytes($charset.toString()); } catch (java.io.UnsupportedEncodingException e) { - org.apache.spark.unsafe.PlatformDependent.throwException(e); + org.apache.spark.unsafe.Platform.throwException(e); }""") } } -/** - * Replace all substrings of str that match regexp with rep. - * - * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status. 
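
The Decode/Encode hunks above only rename `ev.primitive` to `ev.value` and switch to Platform.throwException; the generated Java still does a plain String/byte[] round trip through a named charset. The snippet below is a minimal standalone sketch of that round trip (using java.nio Charset, so charset errors surface as UnsupportedCharsetException rather than the UnsupportedEncodingException caught in the generated code); EncodeDecodeSketch is an illustrative name.

import java.nio.charset.Charset

object EncodeDecodeSketch {
  def encode(value: String, charset: String): Array[Byte] = value.getBytes(Charset.forName(charset))
  def decode(bytes: Array[Byte], charset: String): String = new String(bytes, Charset.forName(charset))

  def main(args: Array[String]): Unit = {
    val bytes = encode("héllo", "UTF-8")
    println(bytes.length)           // 6 -- 'é' takes two bytes in UTF-8
    println(decode(bytes, "UTF-8")) // héllo
  }
}
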
- */ -case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expression) - extends TernaryExpression with ImplicitCastInputTypes { - - // last regex in string, we will update the pattern iff regexp value changed. - @transient private var lastRegex: UTF8String = _ - // last regex pattern, we cache it for performance concern - @transient private var pattern: Pattern = _ - // last replacement string, we don't want to convert a UTF8String => java.langString every time. - @transient private var lastReplacement: String = _ - @transient private var lastReplacementInUTF8: UTF8String = _ - // result buffer write by Matcher - @transient private val result: StringBuffer = new StringBuffer - - override def nullSafeEval(s: Any, p: Any, r: Any): Any = { - if (!p.equals(lastRegex)) { - // regex value changed - lastRegex = p.asInstanceOf[UTF8String].clone() - pattern = Pattern.compile(lastRegex.toString) - } - if (!r.equals(lastReplacementInUTF8)) { - // replacement string changed - lastReplacementInUTF8 = r.asInstanceOf[UTF8String].clone() - lastReplacement = lastReplacementInUTF8.toString - } - val m = pattern.matcher(s.toString()) - result.delete(0, result.length()) - - while (m.find) { - m.appendReplacement(result, lastReplacement) - } - m.appendTail(result) - - UTF8String.fromString(result.toString) - } - - override def dataType: DataType = StringType - override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType) - override def children: Seq[Expression] = subject :: regexp :: rep :: Nil - override def prettyName: String = "regexp_replace" - - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val termLastRegex = ctx.freshName("lastRegex") - val termPattern = ctx.freshName("pattern") - - val termLastReplacement = ctx.freshName("lastReplacement") - val termLastReplacementInUTF8 = ctx.freshName("lastReplacementInUTF8") - - val termResult = ctx.freshName("result") - - val classNamePattern = classOf[Pattern].getCanonicalName - val classNameStringBuffer = classOf[java.lang.StringBuffer].getCanonicalName - - ctx.addMutableState("UTF8String", - termLastRegex, s"${termLastRegex} = null;") - ctx.addMutableState(classNamePattern, - termPattern, s"${termPattern} = null;") - ctx.addMutableState("String", - termLastReplacement, s"${termLastReplacement} = null;") - ctx.addMutableState("UTF8String", - termLastReplacementInUTF8, s"${termLastReplacementInUTF8} = null;") - ctx.addMutableState(classNameStringBuffer, - termResult, s"${termResult} = new $classNameStringBuffer();") - - nullSafeCodeGen(ctx, ev, (subject, regexp, rep) => { - s""" - if (!$regexp.equals(${termLastRegex})) { - // regex value changed - ${termLastRegex} = $regexp.clone(); - ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); - } - if (!$rep.equals(${termLastReplacementInUTF8})) { - // replacement string changed - ${termLastReplacementInUTF8} = $rep.clone(); - ${termLastReplacement} = ${termLastReplacementInUTF8}.toString(); - } - ${termResult}.delete(0, ${termResult}.length()); - java.util.regex.Matcher m = ${termPattern}.matcher($subject.toString()); - - while (m.find()) { - m.appendReplacement(${termResult}, ${termLastReplacement}); - } - m.appendTail(${termResult}); - ${ev.primitive} = UTF8String.fromString(${termResult}.toString()); - ${ev.isNull} = false; - """ - }) - } -} - -/** - * Extract a specific(idx) group identified by a Java regex. - * - * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status. 
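
The RegExpReplace body being removed above shows the caching trick that also makes the expression non-thread-safe: the compiled Pattern and the replacement string are rebuilt only when the corresponding input value changes. The class below is a minimal standalone sketch of that trick (CachedRegexReplace is an illustrative name, not a Spark class), kept deliberately close to the removed eval path.

import java.util.regex.Pattern

class CachedRegexReplace {
  private var lastRegex: String = _
  private var pattern: Pattern = _
  private val result = new StringBuffer // reused buffer, hence not thread-safe

  def replace(subject: String, regex: String, rep: String): String = {
    if (regex != lastRegex) {          // recompile only when the regex value changes
      lastRegex = regex
      pattern = Pattern.compile(regex)
    }
    result.delete(0, result.length())
    val m = pattern.matcher(subject)
    while (m.find()) m.appendReplacement(result, rep)
    m.appendTail(result)
    result.toString
  }
}

// new CachedRegexReplace().replace("100-200", "(\\d+)", "num") == "num-num"
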
- */ -case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression) - extends TernaryExpression with ImplicitCastInputTypes { - def this(s: Expression, r: Expression) = this(s, r, Literal(1)) - - // last regex in string, we will update the pattern iff regexp value changed. - @transient private var lastRegex: UTF8String = _ - // last regex pattern, we cache it for performance concern - @transient private var pattern: Pattern = _ - - override def nullSafeEval(s: Any, p: Any, r: Any): Any = { - if (!p.equals(lastRegex)) { - // regex value changed - lastRegex = p.asInstanceOf[UTF8String].clone() - pattern = Pattern.compile(lastRegex.toString) - } - val m = pattern.matcher(s.toString()) - if (m.find) { - val mr: MatchResult = m.toMatchResult - UTF8String.fromString(mr.group(r.asInstanceOf[Int])) - } else { - UTF8String.EMPTY_UTF8 - } - } - - override def dataType: DataType = StringType - override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType) - override def children: Seq[Expression] = subject :: regexp :: idx :: Nil - override def prettyName: String = "regexp_extract" - - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val termLastRegex = ctx.freshName("lastRegex") - val termPattern = ctx.freshName("pattern") - val classNamePattern = classOf[Pattern].getCanonicalName - - ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;") - ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;") - - nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => { - s""" - if (!$regexp.equals(${termLastRegex})) { - // regex value changed - ${termLastRegex} = $regexp.clone(); - ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); - } - java.util.regex.Matcher m = - ${termPattern}.matcher($subject.toString()); - if (m.find()) { - java.util.regex.MatchResult mr = m.toMatchResult(); - ${ev.primitive} = UTF8String.fromString(mr.group($idx)); - ${ev.isNull} = false; - } else { - ${ev.primitive} = UTF8String.EMPTY_UTF8; - ${ev.isNull} = false; - }""" - }) - } -} - /** * Formats the number X to a format like '#,###,###.##', rounded to D decimal places, * and returns the result as a string. 
If D is 0, the result has no decimal point or @@ -1208,10 +1014,10 @@ case class FormatNumber(x: Expression, d: Expression) $df $dFormat = new $df($pattern.toString()); $lastDValue = $d; $numberFormat.applyPattern($dFormat.toPattern()); - ${ev.primitive} = UTF8String.fromString($numberFormat.format(${typeHelper(num)})); } + ${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)})); } else { - ${ev.primitive} = null; + ${ev.value} = null; ${ev.isNull} = true; } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 09ec0e333aa4..06252ac4fc61 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.types.{DataType, NumericType} +import org.apache.spark.sql.catalyst.expressions.aggregate.{NoOp, DeclarativeAggregate} +import org.apache.spark.sql.types._ /** * The trait of the Window Specification (specified in the OVER clause or WINDOW clause) for @@ -71,9 +72,6 @@ case class WindowSpecDefinition( childrenResolved && checkInputDataTypes().isSuccess && frameSpecification.isInstanceOf[SpecifiedWindowFrame] - - override def toString: String = simpleString - override def nullable: Boolean = true override def foldable: Boolean = false override def dataType: DataType = throw new UnsupportedOperationException @@ -120,6 +118,19 @@ sealed trait FrameBoundary { def notFollows(other: FrameBoundary): Boolean } +/** + * Extractor for making working with frame boundaries easier. + */ +object FrameBoundary { + def apply(boundary: FrameBoundary): Option[Int] = unapply(boundary) + def unapply(boundary: FrameBoundary): Option[Int] = boundary match { + case CurrentRow => Some(0) + case ValuePreceding(offset) => Some(-offset) + case ValueFollowing(offset) => Some(offset) + case _ => None + } +} + /** UNBOUNDED PRECEDING boundary. */ case object UnboundedPreceding extends FrameBoundary { def notFollows(other: FrameBoundary): Boolean = other match { @@ -246,85 +257,281 @@ object SpecifiedWindowFrame { } } +case class UnresolvedWindowExpression( + child: Expression, + windowSpec: WindowSpecReference) extends UnaryExpression with Unevaluable { + + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def foldable: Boolean = throw new UnresolvedException(this, "foldable") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") + override lazy val resolved = false +} + +case class WindowExpression( + windowFunction: Expression, + windowSpec: WindowSpecDefinition) extends Expression with Unevaluable { + + override def children: Seq[Expression] = windowFunction :: windowSpec :: Nil + + override def dataType: DataType = windowFunction.dataType + override def foldable: Boolean = windowFunction.foldable + override def nullable: Boolean = windowFunction.nullable + + override def toString: String = s"$windowFunction $windowSpec" +} + /** - * Every window function needs to maintain a output buffer for its output. 
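
The FrameBoundary companion added above turns bounded frame boundaries into signed row offsets relative to the current row. The following is a dependency-free mirror of that mapping over a tiny stand-in ADT (FrameBoundarySketch and the Offset extractor are illustrative names; the catalyst traits are richer than this):

object FrameBoundarySketch {
  sealed trait Boundary
  case object UnboundedPreceding extends Boundary
  case object UnboundedFollowing extends Boundary
  case object CurrentRow extends Boundary
  case class ValuePreceding(offset: Int) extends Boundary
  case class ValueFollowing(offset: Int) extends Boundary

  // Same mapping as the FrameBoundary extractor in the hunk above.
  object Offset {
    def unapply(boundary: Boundary): Option[Int] = boundary match {
      case CurrentRow => Some(0)
      case ValuePreceding(offset) => Some(-offset)
      case ValueFollowing(offset) => Some(offset)
      case _ => None
    }
  }

  def describe(b: Boundary): String = b match {
    case Offset(o) => s"bounded at offset $o"
    case _         => "unbounded"
  }

  def main(args: Array[String]): Unit = {
    println(describe(ValuePreceding(3)))  // bounded at offset -3
    println(describe(CurrentRow))         // bounded at offset 0
    println(describe(UnboundedPreceding)) // unbounded
  }
}
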
- * It should expect that for a n-row window frame, it will be called n times - * to retrieve value corresponding with these n rows. + * A window function is a function that can only be evaluated in the context of a window operator. */ trait WindowFunction extends Expression { - def init(): Unit + /** Frame in which the window operator must be executed. */ + def frame: WindowFrame = UnspecifiedFrame +} - def reset(): Unit +/** + * An offset window function is a window function that returns the value of the input column offset + * by a number of rows within the partition. For instance: an OffsetWindowfunction for value x with + * offset -2, will get the value of x 2 rows back in the partition. + */ +abstract class OffsetWindowFunction + extends Expression with WindowFunction with Unevaluable with ImplicitCastInputTypes { + /** + * Input expression to evaluate against a row which a number of rows below or above (depending on + * the value and sign of the offset) the current row. + */ + val input: Expression - def prepareInputParameters(input: InternalRow): AnyRef + /** + * Default result value for the function when the input expression returns NULL. The default will + * evaluated against the current row instead of the offset row. + */ + val default: Expression - def update(input: AnyRef): Unit + /** + * (Foldable) expression that contains the number of rows between the current row and the row + * where the input expression is evaluated. + */ + val offset: Expression - def batchUpdate(inputs: Array[AnyRef]): Unit + /** + * Direction (above = 1/below = -1) of the number of rows between the current row and the row + * where the input expression is evaluated. + */ + val direction: SortDirection - def evaluate(): Unit + override def children: Seq[Expression] = Seq(input, offset, default) - def get(index: Int): Any + /* + * The result of an OffsetWindowFunction is dependent on the frame in which the + * OffsetWindowFunction is executed, the input expression and the default expression. Even when + * both the input and the default expression are foldable, the result is still not foldable due to + * the frame. + */ + override def foldable: Boolean = input.foldable && (default == null || default.foldable) + + override def nullable: Boolean = input.nullable && (default == null || default.nullable) - def newInstance(): WindowFunction + override lazy val frame = { + // This will be triggered by the Analyzer. 
+ val offsetValue = offset.eval() match { + case o: Int => o + case x => throw new AnalysisException( + s"Offset expression must be a foldable integer expression: $x") + } + val boundary = direction match { + case Ascending => ValueFollowing(offsetValue) + case Descending => ValuePreceding(offsetValue) + } + SpecifiedWindowFrame(RowFrame, boundary, boundary) + } + + override def dataType: DataType = input.dataType + + override def inputTypes: Seq[AbstractDataType] = + Seq(AnyDataType, IntegerType, TypeCollection(input.dataType, NullType)) + + override def toString: String = s"$prettyName($input, $offset, $default)" } -case class UnresolvedWindowFunction( - name: String, - children: Seq[Expression]) - extends Expression with WindowFunction with Unevaluable { +case class Lead(input: Expression, offset: Expression, default: Expression) + extends OffsetWindowFunction { - override def dataType: DataType = throw new UnresolvedException(this, "dataType") - override def foldable: Boolean = throw new UnresolvedException(this, "foldable") - override def nullable: Boolean = throw new UnresolvedException(this, "nullable") - override lazy val resolved = false + def this(input: Expression, offset: Expression) = this(input, offset, Literal(null)) - override def init(): Unit = throw new UnresolvedException(this, "init") - override def reset(): Unit = throw new UnresolvedException(this, "reset") - override def prepareInputParameters(input: InternalRow): AnyRef = - throw new UnresolvedException(this, "prepareInputParameters") - override def update(input: AnyRef): Unit = throw new UnresolvedException(this, "update") - override def batchUpdate(inputs: Array[AnyRef]): Unit = - throw new UnresolvedException(this, "batchUpdate") - override def evaluate(): Unit = throw new UnresolvedException(this, "evaluate") - override def get(index: Int): Any = throw new UnresolvedException(this, "get") + def this(input: Expression) = this(input, Literal(1)) - override def toString: String = s"'$name(${children.mkString(",")})" + def this() = this(Literal(null)) - override def newInstance(): WindowFunction = throw new UnresolvedException(this, "newInstance") + override val direction = Ascending } -case class UnresolvedWindowExpression( - child: UnresolvedWindowFunction, - windowSpec: WindowSpecReference) extends UnaryExpression with Unevaluable { +case class Lag(input: Expression, offset: Expression, default: Expression) + extends OffsetWindowFunction { - override def dataType: DataType = throw new UnresolvedException(this, "dataType") - override def foldable: Boolean = throw new UnresolvedException(this, "foldable") - override def nullable: Boolean = throw new UnresolvedException(this, "nullable") - override lazy val resolved = false + def this(input: Expression, offset: Expression) = this(input, offset, Literal(null)) + + def this(input: Expression) = this(input, Literal(1)) + + def this() = this(Literal(null)) + + override val direction = Descending } -case class WindowExpression( - windowFunction: WindowFunction, - windowSpec: WindowSpecDefinition) extends Expression with Unevaluable { +abstract class AggregateWindowFunction extends DeclarativeAggregate with WindowFunction { + self: Product => + override val frame = SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow) + override def dataType: DataType = IntegerType + override def nullable: Boolean = false + override def supportsPartial: Boolean = false + override lazy val mergeExpressions = + throw new UnsupportedOperationException("Window Functions do not 
support merging.") +} - override def children: Seq[Expression] = windowFunction :: windowSpec :: Nil +abstract class RowNumberLike extends AggregateWindowFunction { + override def children: Seq[Expression] = Nil + override def inputTypes: Seq[AbstractDataType] = Nil + protected val zero = Literal(0) + protected val one = Literal(1) + protected val rowNumber = AttributeReference("rowNumber", IntegerType, nullable = false)() + override val aggBufferAttributes: Seq[AttributeReference] = rowNumber :: Nil + override val initialValues: Seq[Expression] = zero :: Nil + override val updateExpressions: Seq[Expression] = Add(rowNumber, one) :: Nil +} - override def dataType: DataType = windowFunction.dataType - override def foldable: Boolean = windowFunction.foldable - override def nullable: Boolean = windowFunction.nullable +/** + * A [[SizeBasedWindowFunction]] needs the size of the current window for its calculation. + */ +trait SizeBasedWindowFunction extends AggregateWindowFunction { + protected def n: AttributeReference = SizeBasedWindowFunction.n +} - override def toString: String = s"$windowFunction $windowSpec" +object SizeBasedWindowFunction { + val n = AttributeReference("window__partition__size", IntegerType, nullable = false)() +} + +case class RowNumber() extends RowNumberLike { + override val evaluateExpression = rowNumber +} + +case class CumeDist() extends RowNumberLike with SizeBasedWindowFunction { + override def dataType: DataType = DoubleType + // The frame for CUME_DIST is Range based instead of Row based, because CUME_DIST must + // return the same value for equal values in the partition. + override val frame = SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow) + override val evaluateExpression = Divide(Cast(rowNumber, DoubleType), Cast(n, DoubleType)) +} + +case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindowFunction { + def this() = this(Literal(1)) + + // Validate buckets. Note that this could be relaxed, the bucket value only needs to constant + // for each partition. + buckets.eval() match { + case b: Int if b > 0 => // Ok + case x => throw new AnalysisException( + "Buckets expression must be a foldable positive integer expression: $x") + } + + private val bucket = AttributeReference("bucket", IntegerType, nullable = false)() + private val bucketThreshold = + AttributeReference("bucketThreshold", IntegerType, nullable = false)() + private val bucketSize = AttributeReference("bucketSize", IntegerType, nullable = false)() + private val bucketsWithPadding = + AttributeReference("bucketsWithPadding", IntegerType, nullable = false)() + private def bucketOverflow(e: Expression) = + If(GreaterThanOrEqual(rowNumber, bucketThreshold), e, zero) + + override val aggBufferAttributes = Seq( + rowNumber, + bucket, + bucketThreshold, + bucketSize, + bucketsWithPadding + ) + + override val initialValues = Seq( + zero, + zero, + zero, + Cast(Divide(n, buckets), IntegerType), + Cast(Remainder(n, buckets), IntegerType) + ) + + override val updateExpressions = Seq( + Add(rowNumber, one), + Add(bucket, bucketOverflow(one)), + Add(bucketThreshold, bucketOverflow( + Add(bucketSize, If(LessThan(bucket, bucketsWithPadding), one, zero)))), + NoOp, + NoOp + ) + + override val evaluateExpression = bucket } /** - * Extractor for making working with frame boundaries easier. + * A RankLike function is a WindowFunction that changes its value based on a change in the value of + * the order of the window in which is processed. 
For instance, when the value of 'x' changes in a + * window ordered by 'x' the rank function also changes. The size of the change of the rank function + * is (typically) not dependent on the size of the change in 'x'. */ -object FrameBoundaryExtractor { - def unapply(boundary: FrameBoundary): Option[Int] = boundary match { - case CurrentRow => Some(0) - case ValuePreceding(offset) => Some(-offset) - case ValueFollowing(offset) => Some(offset) - case _ => None +abstract class RankLike extends AggregateWindowFunction { + override def inputTypes: Seq[AbstractDataType] = children.map(_ => AnyDataType) + + /** Store the values of the window 'order' expressions. */ + protected val orderAttrs = children.map{ expr => + AttributeReference(expr.prettyString, expr.dataType)() } + + /** Predicate that detects if the order attributes have changed. */ + protected val orderEquals = children.zip(orderAttrs) + .map(EqualNullSafe.tupled) + .reduceOption(And) + .getOrElse(Literal(true)) + + protected val orderInit = children.map(e => Literal.create(null, e.dataType)) + protected val rank = AttributeReference("rank", IntegerType, nullable = false)() + protected val rowNumber = AttributeReference("rowNumber", IntegerType, nullable = false)() + protected val zero = Literal(0) + protected val one = Literal(1) + protected val increaseRowNumber = Add(rowNumber, one) + + /** + * Different RankLike implementations use different source expressions to update their rank value. + * Rank for instance uses the number of rows seen, whereas DenseRank uses the number of changes. + */ + protected def rankSource: Expression = rowNumber + + /** Increase the rank when the current rank == 0 or when the one of order attributes changes. */ + protected val increaseRank = If(And(orderEquals, Not(EqualTo(rank, zero))), rank, rankSource) + + override val aggBufferAttributes: Seq[AttributeReference] = rank +: rowNumber +: orderAttrs + override val initialValues = zero +: one +: orderInit + override val updateExpressions = increaseRank +: increaseRowNumber +: children + override val evaluateExpression: Expression = rank + + def withOrder(order: Seq[Expression]): RankLike +} + +case class Rank(children: Seq[Expression]) extends RankLike { + def this() = this(Nil) + override def withOrder(order: Seq[Expression]): Rank = Rank(order) +} + +case class DenseRank(children: Seq[Expression]) extends RankLike { + def this() = this(Nil) + override def withOrder(order: Seq[Expression]): DenseRank = DenseRank(order) + override protected def rankSource = Add(rank, one) + override val updateExpressions = increaseRank +: children + override val aggBufferAttributes = rank +: orderAttrs + override val initialValues = zero +: orderInit +} + +case class PercentRank(children: Seq[Expression]) extends RankLike with SizeBasedWindowFunction { + def this() = this(Nil) + override def withOrder(order: Seq[Expression]): PercentRank = PercentRank(order) + override def dataType: DataType = DoubleType + override val evaluateExpression = If(GreaterThan(n, one), + Divide(Cast(Subtract(rank, one), DoubleType), Cast(Subtract(n, one), DoubleType)), + Literal(0.0d)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index e4b6294dc7b8..f6088695a927 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -18,27 
+18,20 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.immutable.HashSet -import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries + +import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueries} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.Inner -import org.apache.spark.sql.catalyst.plans.FullOuter -import org.apache.spark.sql.catalyst.plans.LeftOuter -import org.apache.spark.sql.catalyst.plans.RightOuter -import org.apache.spark.sql.catalyst.plans.LeftSemi +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins +import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, LeftSemi, RightOuter} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.types._ abstract class Optimizer extends RuleExecutor[LogicalPlan] -class DefaultOptimizer extends Optimizer { - - /** - * Override to provide additional rules for the "Operator Optimizations" batch. - */ - val extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil - - lazy val batches = +object DefaultOptimizer extends Optimizer { + val batches = // SubQueries are only needed for analysis and can be removed before execution. Batch("Remove SubQueries", FixedPoint(100), EliminateSubQueries) :: @@ -47,27 +40,28 @@ class DefaultOptimizer extends Optimizer { RemoveLiteralFromGroupExpressions) :: Batch("Operator Optimizations", FixedPoint(100), // Operator push down - SetOperationPushDown :: - SamplePushDown :: - PushPredicateThroughJoin :: - PushPredicateThroughProject :: - PushPredicateThroughGenerate :: - ColumnPruning :: + SetOperationPushDown, + SamplePushDown, + ReorderJoin, + PushPredicateThroughJoin, + PushPredicateThroughProject, + PushPredicateThroughGenerate, + PushPredicateThroughAggregate, + ColumnPruning, // Operator combine - ProjectCollapsing :: - CombineFilters :: - CombineLimits :: + ProjectCollapsing, + CombineFilters, + CombineLimits, // Constant folding - NullPropagation :: - OptimizeIn :: - ConstantFolding :: - LikeSimplification :: - BooleanSimplification :: - RemovePositive :: - SimplifyFilters :: - SimplifyCasts :: - SimplifyCaseConversionExpressions :: - extendedOperatorOptimizationRules.toList : _*) :: + NullPropagation, + OptimizeIn, + ConstantFolding, + LikeSimplification, + BooleanSimplification, + RemoveDispensableExpressions, + SimplifyFilters, + SimplifyCasts, + SimplifyCaseConversionExpressions) :: Batch("Decimal Optimizations", FixedPoint(100), DecimalAggregates) :: Batch("LocalRelation", FixedPoint(100), @@ -80,10 +74,6 @@ class DefaultOptimizer extends Optimizer { object SamplePushDown extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - // Push down filter into sample - case Filter(condition, s @ Sample(lb, up, replace, seed, child)) => - Sample(lb, up, replace, seed, - Filter(condition, child)) // Push down projection into sample case Project(projectList, s @ Sample(lb, up, replace, seed, child)) => Sample(lb, up, replace, seed, @@ -92,9 +82,24 @@ object SamplePushDown extends Rule[LogicalPlan] { } /** - * Pushes operations to either side of a Union, Intersect or Except. + * Pushes certain operations to both sides of a Union, Intersect or Except operator. + * Operations that are safe to pushdown are listed as follows. 
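
The optimizer batches rewritten above all run under FixedPoint(100): each batch re-applies its rules to the plan until a full pass changes nothing or the iteration limit is hit. The sketch below shows that control flow in isolation; Plan and Rule here are simplified stand-ins (lists of node names), not catalyst's LogicalPlan and Rule, and FixedPointBatchSketch is an illustrative name.

object FixedPointBatchSketch {
  type Plan = List[String]
  type Rule = Plan => Plan

  def executeBatch(rules: Seq[Rule], plan: Plan, maxIterations: Int = 100): Plan = {
    var current = plan
    var previous: Plan = null
    var iterations = 0
    while (current != previous && iterations < maxIterations) {
      previous = current
      current = rules.foldLeft(current)((p, rule) => rule(p)) // one full pass over all rules
      iterations += 1
    }
    current
  }

  def main(args: Array[String]): Unit = {
    val dropNoOpProject: Rule = _.filterNot(_ == "NoOpProject")
    val combineAdjacentFilters: Rule = {
      case "Filter" :: "Filter" :: rest => "Filter" :: rest
      case other => other
    }
    val plan = List("Filter", "NoOpProject", "Filter", "Scan")
    println(executeBatch(Seq(dropNoOpProject, combineAdjacentFilters), plan))
    // List(Filter, Scan) -- the second rule only fires after the first pass, which is
    // why the batch runs to a fixed point instead of making a single pass.
  }
}
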
+ * Union: + * Right now, Union means UNION ALL, which does not de-duplicate rows. So, it is + * safe to pushdown Filters and Projections through it. Once we add UNION DISTINCT, + * we will not be able to pushdown Projections. + * + * Intersect: + * It is not safe to pushdown Projections through it because we need to get the + * intersect of rows by comparing the entire rows. It is fine to pushdown Filters + * with deterministic condition. + * + * Except: + * It is not safe to pushdown Projections through it because we need to get the + * intersect of rows by comparing the entire rows. It is fine to pushdown Filters + * with deterministic condition. */ -object SetOperationPushDown extends Rule[LogicalPlan] { +object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { /** * Maps Attributes from the left side to the corresponding Attribute on the right side. @@ -121,48 +126,65 @@ object SetOperationPushDown extends Rule[LogicalPlan] { result.asInstanceOf[A] } + /** + * Splits the condition expression into small conditions by `And`, and partition them by + * deterministic, and finally recombine them by `And`. It returns an expression containing + * all deterministic expressions (the first field of the returned Tuple2) and an expression + * containing all non-deterministic expressions (the second field of the returned Tuple2). + */ + private def partitionByDeterministic(condition: Expression): (Expression, Expression) = { + val andConditions = splitConjunctivePredicates(condition) + andConditions.partition(_.deterministic) match { + case (deterministic, nondeterministic) => + deterministic.reduceOption(And).getOrElse(Literal(true)) -> + nondeterministic.reduceOption(And).getOrElse(Literal(true)) + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { // Push down filter into union case Filter(condition, u @ Union(left, right)) => + val (deterministic, nondeterministic) = partitionByDeterministic(condition) val rewrites = buildRewrites(u) - Union( - Filter(condition, left), - Filter(pushToRight(condition, rewrites), right)) - - // Push down projection into union - case Project(projectList, u @ Union(left, right)) => - val rewrites = buildRewrites(u) - Union( - Project(projectList, left), - Project(projectList.map(pushToRight(_, rewrites)), right)) + Filter(nondeterministic, + Union( + Filter(deterministic, left), + Filter(pushToRight(deterministic, rewrites), right) + ) + ) + + // Push down deterministic projection through UNION ALL + case p @ Project(projectList, u @ Union(left, right)) => + if (projectList.forall(_.deterministic)) { + val rewrites = buildRewrites(u) + Union( + Project(projectList, left), + Project(projectList.map(pushToRight(_, rewrites)), right)) + } else { + p + } - // Push down filter into intersect + // Push down filter through INTERSECT case Filter(condition, i @ Intersect(left, right)) => + val (deterministic, nondeterministic) = partitionByDeterministic(condition) val rewrites = buildRewrites(i) - Intersect( - Filter(condition, left), - Filter(pushToRight(condition, rewrites), right)) - - // Push down projection into intersect - case Project(projectList, i @ Intersect(left, right)) => - val rewrites = buildRewrites(i) - Intersect( - Project(projectList, left), - Project(projectList.map(pushToRight(_, rewrites)), right)) - - // Push down filter into except + Filter(nondeterministic, + Intersect( + Filter(deterministic, left), + Filter(pushToRight(deterministic, rewrites), right) + ) + ) + + // Push down filter through EXCEPT case 
Filter(condition, e @ Except(left, right)) => + val (deterministic, nondeterministic) = partitionByDeterministic(condition) val rewrites = buildRewrites(e) - Except( - Filter(condition, left), - Filter(pushToRight(condition, rewrites), right)) - - // Push down projection into except - case Project(projectList, e @ Except(left, right)) => - val rewrites = buildRewrites(e) - Except( - Project(projectList, left), - Project(projectList.map(pushToRight(_, rewrites)), right)) + Filter(nondeterministic, + Except( + Filter(deterministic, left), + Filter(pushToRight(deterministic, rewrites), right) + ) + ) } } @@ -172,19 +194,35 @@ object SetOperationPushDown extends Rule[LogicalPlan] { * * - Inserting Projections beneath the following operators: * - Aggregate + * - Generate * - Project <- Join * - LeftSemiJoin */ object ColumnPruning extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case a @ Aggregate(_, _, e @ Expand(_, groupByExprs, _, child)) - if (child.outputSet -- AttributeSet(groupByExprs) -- a.references).nonEmpty => - a.copy(child = e.copy(child = prunedChild(child, AttributeSet(groupByExprs) ++ a.references))) + case a @ Aggregate(_, _, e @ Expand(_, _, child)) + if (child.outputSet -- e.references -- a.references).nonEmpty => + a.copy(child = e.copy(child = prunedChild(child, e.references ++ a.references))) // Eliminate attributes that are not needed to calculate the specified aggregates. case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => a.copy(child = Project(a.references.toSeq, child)) + // Eliminate attributes that are not needed to calculate the Generate. + case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty => + g.copy(child = Project(g.references.toSeq, g.child)) + + case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) => + p.copy(child = g.copy(join = false)) + + case p @ Project(projectList, g: Generate) if g.join => + val neededChildOutput = p.references -- g.generatorOutput ++ g.references + if (neededChildOutput == g.child.outputSet) { + p + } else { + Project(projectList, g.copy(child = Project(neededChildOutput.toSeq, g.child))) + } + case p @ Project(projectList, a @ Aggregate(groupingExpressions, aggregateExpressions, child)) if (a.outputSet -- p.references).nonEmpty => Project( @@ -219,28 +257,33 @@ object ColumnPruning extends Rule[LogicalPlan] { case Project(projectList, Limit(exp, child)) => Limit(exp, Project(projectList, child)) - // Push down project if possible when the child is sort - case p @ Project(projectList, s @ Sort(_, _, grandChild)) - if s.references.subsetOf(p.outputSet) => - s.copy(child = Project(projectList, grandChild)) + // Push down project if possible when the child is sort. + case p @ Project(projectList, s @ Sort(_, _, grandChild)) => + if (s.references.subsetOf(p.outputSet)) { + s.copy(child = Project(projectList, grandChild)) + } else { + val neededReferences = s.references ++ p.references + if (neededReferences == grandChild.outputSet) { + // No column we can prune, return the original plan. + p + } else { + // Do not use neededReferences.toSeq directly, should respect grandChild's output order. 
+ val newProjectList = grandChild.output.filter(neededReferences.contains) + p.copy(child = s.copy(child = Project(newProjectList, grandChild))) + } + } // Eliminate no-op Projects case Project(projectList, child) if child.output == projectList => child } /** Applies a projection only when the child is producing unnecessary attributes */ - private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) = { + private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) = if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) { - // We need to preserve the nullability of c's output. - // So, we first create a outputMap and if a reference is from the output of - // c, we use that output attribute from c. - val outputMap = AttributeMap(c.output.map(attr => (attr, attr))) - val projectList = allReferences.filter(outputMap.contains).map(outputMap).toSeq - Project(projectList, c) + Project(allReferences.filter(c.outputSet.contains).toSeq, c) } else { c } - } } /** @@ -273,8 +316,11 @@ object ProjectCollapsing extends Rule[LogicalPlan] { val substitutedProjection = projectList1.map(_.transform { case a: Attribute => aliasMap.getOrElse(a, a) }).asInstanceOf[Seq[NamedExpression]] - - Project(substitutedProjection, child) + // collapse 2 projects may introduce unnecessary Aliases, trim them here. + val cleanedProjection = substitutedProjection.map(p => + CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression] + ) + Project(cleanedProjection, child) } } } @@ -315,9 +361,15 @@ object LikeSimplification extends Rule[LogicalPlan] { * Null value propagation from bottom to top of the expression tree. */ object NullPropagation extends Rule[LogicalPlan] { + def nonNullLiteral(e: Expression): Boolean = e match { + case Literal(null, _) => false + case _ => true + } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { - case e @ Count(Literal(null, _)) => Cast(Literal(0L), e.dataType) + case e @ AggregateExpression(Count(exprs), _, _) if !exprs.exists(nonNullLiteral) => + Cast(Literal(0L), e.dataType) case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType) case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType) case e @ GetArrayItem(Literal(null, _), _) => Literal.create(null, e.dataType) @@ -329,14 +381,13 @@ object NullPropagation extends Rule[LogicalPlan] { Literal.create(null, e.dataType) case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r) case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l) - case e @ Count(expr) if !expr.nullable => Count(Literal(1)) + case e @ AggregateExpression(Count(exprs), mode, false) if !exprs.exists(_.nullable) => + // This rule should be only triggered when isDistinct field is false. + AggregateExpression(Count(Literal(1)), mode, isDistinct = false) // For Coalesce, remove null literals. 
case e @ Coalesce(children) => - val newChildren = children.filter { - case Literal(null, _) => false - case _ => true - } + val newChildren = children.filter(nonNullLiteral) if (newChildren.length == 0) { Literal.create(null, e.dataType) } else if (newChildren.length == 1) { @@ -366,11 +417,16 @@ object NullPropagation extends Rule[LogicalPlan] { case _ => e } - case e: StringComparison => e.children match { + case e: StringPredicate => e.children match { case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) case _ => e } + + // If the value expression is NULL then transform the In expression to + // Literal(null) + case In(Literal(null, _), list) => Literal.create(null, BooleanType) + } } } @@ -389,12 +445,6 @@ object ConstantFolding extends Rule[LogicalPlan] { // Fold expressions that are foldable. case e if e.foldable => Literal.create(e.eval(EmptyRow), e.dataType) - - // Fold "literal in (item1, item2, ..., literal, ...)" into true directly. - case In(Literal(v, _), list) if list.exists { - case Literal(candidate, _) if candidate == v => true - case _ => false - } => Literal.create(true, BooleanType) } } } @@ -406,7 +456,7 @@ object ConstantFolding extends Rule[LogicalPlan] { object OptimizeIn extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { - case In(v, list) if !list.exists(!_.isInstanceOf[Literal]) => + case In(v, list) if !list.exists(!_.isInstanceOf[Literal]) && list.size > 10 => val hSet = list.map(e => e.eval(EmptyRow)) InSet(v, HashSet() ++ hSet) } @@ -434,6 +484,11 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { case (_, Literal(false, BooleanType)) => Literal(false) // a && a => a case (l, r) if l fastEquals r => l + // a && (not(a) || b) => a && b + case (l, Or(l1, r)) if (Not(l) == l1) => And(l, r) + case (l, Or(r, l1)) if (Not(l) == l1) => And(l, r) + case (Or(l, l1), r) if (l1 == Not(r)) => And(l, r) + case (Or(l1, l), r) if (l1 == Not(r)) => And(l, r) // (a || b) && (a || c) => a || (b && c) case _ => // 1. Split left and right to get the disjunctive predicates, @@ -512,6 +567,10 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { case LessThan(l, r) => GreaterThanOrEqual(l, r) // not(l <= r) => l > r case LessThanOrEqual(l, r) => GreaterThan(l, r) + // not(l || r) => not(l) && not(r) + case Or(l, r) => And(Not(l), Not(r)) + // not(l && r) => not(l) or not(r) + case And(l, r) => Or(Not(l), Not(r)) // not(not(e)) => e case Not(e) => e case _ => not @@ -530,13 +589,6 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { */ object CombineFilters extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case Filter(Not(AtLeastNNulls(1, e1)), Filter(Not(AtLeastNNulls(1, e2)), grandChild)) => - // If we are combining two expressions Not(AtLeastNNulls(1, e1)) and - // Not(AtLeastNNulls(1, e2)) - // (this is used to make sure there is no null in the result of e1 and e2 and - // they are added by FilterNullsInJoinKey optimziation rule), we can - // just create a Not(AtLeastNNulls(1, (e1 ++ e2).distinct)). 
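
The new Not cases added to BooleanSimplification above are De Morgan's laws plus the existing double-negation removal. The following is a minimal sketch of that rewrite over a toy expression ADT (NotPushDownSketch, Atom, and simplifyNot are illustrative names, not catalyst types):

object NotPushDownSketch {
  sealed trait Expr
  case class Atom(name: String) extends Expr
  case class Not(child: Expr) extends Expr
  case class And(left: Expr, right: Expr) extends Expr
  case class Or(left: Expr, right: Expr) extends Expr

  def simplifyNot(e: Expr): Expr = e match {
    case Not(Or(l, r))  => And(simplifyNot(Not(l)), simplifyNot(Not(r))) // not(l || r) => not(l) && not(r)
    case Not(And(l, r)) => Or(simplifyNot(Not(l)), simplifyNot(Not(r)))  // not(l && r) => not(l) || not(r)
    case Not(Not(c))    => simplifyNot(c)                                // not(not(e)) => e
    case Not(c)         => Not(simplifyNot(c))
    case And(l, r)      => And(simplifyNot(l), simplifyNot(r))
    case Or(l, r)       => Or(simplifyNot(l), simplifyNot(r))
    case atom           => atom
  }

  def main(args: Array[String]): Unit = {
    println(simplifyNot(Not(Or(Atom("a"), Not(Atom("b"))))))
    // And(Not(Atom(a)), Atom(b))
  }
}
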
- Filter(Not(AtLeastNNulls(1, (e1 ++ e2).distinct)), grandChild) case ff @ Filter(fc, nf @ Filter(nc, grandChild)) => Filter(And(nc, fc), grandChild) } } @@ -589,20 +641,14 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] with PredicateHelpe filter } else { // Push down the small conditions without nondeterministic expressions. - val pushedCondition = deterministic.map(replaceAlias(_, aliasMap)).reduce(And) + val pushedCondition = + deterministic.map(replaceAlias(_, aliasMap)).reduce(And) Filter(nondeterministic.reduce(And), project.copy(child = Filter(pushedCondition, grandChild))) } } } - // Substitute any attributes that are produced by the child projection, so that we safely - // eliminate it. - private def replaceAlias(condition: Expression, sourceAliases: AttributeMap[Expression]) = { - condition.transform { - case a: Attribute => sourceAliases.getOrElse(a, a) - } - } } /** @@ -615,26 +661,108 @@ object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelp case filter @ Filter(condition, g: Generate) => // Predicates that reference attributes produced by the `Generate` operator cannot // be pushed below the operator. - val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { - conjunct => conjunct.references subsetOf g.child.outputSet + val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond => + cond.references subsetOf g.child.outputSet } if (pushDown.nonEmpty) { val pushDownPredicate = pushDown.reduce(And) - val withPushdown = Generate(g.generator, join = g.join, outer = g.outer, + val newGenerate = Generate(g.generator, join = g.join, outer = g.outer, g.qualifier, g.generatorOutput, Filter(pushDownPredicate, g.child)) - stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown) + if (stayUp.isEmpty) newGenerate else Filter(stayUp.reduce(And), newGenerate) + } else { + filter + } + } +} + +/** + * Push [[Filter]] operators through [[Aggregate]] operators, iff the filters reference only + * non-aggregate attributes (typically literals or grouping expressions). + */ +object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHelper { + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case filter @ Filter(condition, aggregate: Aggregate) => + // Find all the aliased expressions in the aggregate list that don't include any actual + // AggregateExpression, and create a map from the alias to the expression + val aliasMap = AttributeMap(aggregate.aggregateExpressions.collect { + case a: Alias if a.child.find(_.isInstanceOf[AggregateExpression]).isEmpty => + (a.toAttribute, a.child) + }) + + // For each filter, expand the alias and check if the filter can be evaluated using + // attributes produced by the aggregate operator's child operator. + val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond => + val replaced = replaceAlias(cond, aliasMap) + replaced.references.subsetOf(aggregate.child.outputSet) && replaced.deterministic + } + + if (pushDown.nonEmpty) { + val pushDownPredicate = pushDown.reduce(And) + val replaced = replaceAlias(pushDownPredicate, aliasMap) + val newAggregate = aggregate.copy(child = Filter(replaced, aggregate.child)) + // If there is no more filter to stay up, just eliminate the filter. + // Otherwise, create "Filter(stayUp) <- Aggregate <- Filter(pushDownPredicate)". 
+ if (stayUp.isEmpty) newAggregate else Filter(stayUp.reduce(And), newAggregate) } else { filter } } } +/** + * Reorder the joins and push all the conditions into join, so that the bottom ones have at least + * one condition. + * + * The order of joins will not be changed if all of them already have at least one condition. + */ +object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { + + /** + * Join a list of plans together and push down the conditions into them. + * + * The joined plan are picked from left to right, prefer those has at least one join condition. + * + * @param input a list of LogicalPlans to join. + * @param conditions a list of condition for join. + */ + def createOrderedJoin(input: Seq[LogicalPlan], conditions: Seq[Expression]): LogicalPlan = { + assert(input.size >= 2) + if (input.size == 2) { + Join(input(0), input(1), Inner, conditions.reduceLeftOption(And)) + } else { + val left :: rest = input.toList + // find out the first join that have at least one join condition + val conditionalJoin = rest.find { plan => + val refs = left.outputSet ++ plan.outputSet + conditions.filterNot(canEvaluate(_, left)).filterNot(canEvaluate(_, plan)) + .exists(_.references.subsetOf(refs)) + } + // pick the next one if no condition left + val right = conditionalJoin.getOrElse(rest.head) + + val joinedRefs = left.outputSet ++ right.outputSet + val (joinConditions, others) = conditions.partition(_.references.subsetOf(joinedRefs)) + val joined = Join(left, right, Inner, joinConditions.reduceLeftOption(And)) + + // should not have reference to same logical plan + createOrderedJoin(Seq(joined) ++ rest.filterNot(_ eq right), others) + } + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case j @ ExtractFiltersAndInnerJoins(input, conditions) + if input.size > 2 && conditions.nonEmpty => + createOrderedJoin(input, conditions) + } +} + /** * Pushes down [[Filter]] operators where the `condition` can be * evaluated using only the attributes of the left or right side of a join. Other * [[Filter]] conditions are moved into the `condition` of the [[Join]]. * - * And also Pushes down the join filter, where the `condition` can be evaluated using only the + * And also pushes down the join filter, where the `condition` can be evaluated using only the * attributes of the left or right side of sub query when applicable. * * Check https://cwiki.apache.org/confluence/display/Hive/OuterJoinBehavior for more details @@ -739,11 +867,12 @@ object SimplifyCasts extends Rule[LogicalPlan] { } /** - * Removes [[UnaryPositive]] identify function + * Removes nodes that are not necessary. 
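
ReorderJoin.createOrderedJoin above picks the next relation greedily so that every join it builds carries at least one condition when possible. The standalone sketch below reproduces that ordering decision with relations modelled as attribute sets and conditions as the attributes they reference; ReorderJoinSketch, Relation, and Condition are illustrative stand-ins for LogicalPlan and Expression.

object ReorderJoinSketch {
  case class Relation(name: String, output: Set[String])
  case class Condition(references: Set[String])

  def orderJoin(input: Seq[Relation], conditions: Seq[Condition]): Seq[Relation] = {
    require(input.size >= 2)
    var joined = Vector(input.head)
    var joinedRefs = input.head.output
    var remaining = input.tail.toVector
    var pending = conditions
    while (remaining.nonEmpty) {
      // Prefer a relation that lets at least one pending condition be evaluated.
      val next = remaining.find { r =>
        pending.exists(c => c.references.subsetOf(joinedRefs ++ r.output))
      }.getOrElse(remaining.head) // otherwise just take the next one
      joinedRefs ++= next.output
      pending = pending.filterNot(_.references.subsetOf(joinedRefs)) // pushed into this join
      joined :+= next
      remaining = remaining.filterNot(_ eq next)
    }
    joined
  }

  def main(args: Array[String]): Unit = {
    val a = Relation("a", Set("a.id"))
    val b = Relation("b", Set("b.id"))
    val c = Relation("c", Set("c.id"))
    val conds = Seq(Condition(Set("a.id", "c.id")), Condition(Set("b.id", "c.id")))
    println(orderJoin(Seq(a, b, c), conds).map(_.name)) // Vector(a, c, b)
  }
}
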
*/ -object RemovePositive extends Rule[LogicalPlan] { +object RemoveDispensableExpressions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case UnaryPositive(child) => child + case PromotePrecision(child) => child } } @@ -786,12 +915,15 @@ object DecimalAggregates extends Rule[LogicalPlan] { private val MAX_DOUBLE_DIGITS = 15 def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS => - MakeDecimal(Sum(UnscaledValue(e)), prec + 10, scale) + case AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), mode, isDistinct) + if prec + 10 <= MAX_LONG_DIGITS => + MakeDecimal(AggregateExpression(Sum(UnscaledValue(e)), mode, isDistinct), prec + 10, scale) - case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS => + case AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), mode, isDistinct) + if prec + 4 <= MAX_DOUBLE_DIGITS => + val newAggExpr = AggregateExpression(Average(UnscaledValue(e)), mode, isDistinct) Cast( - Divide(Average(UnscaledValue(e)), Literal.create(math.pow(10.0, scale), DoubleType)), + Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)), DecimalType(prec + 4, scale + 4)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala index 73a21884a471..56a3dd02f9ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala @@ -51,7 +51,7 @@ abstract class QueryPlanner[PhysicalPlan <: TreeNode[PhysicalPlan]] { * filled in automatically by the QueryPlanner using the other execution strategies that are * available. */ - protected def planLater(plan: LogicalPlan) = this.plan(plan).next() + protected def planLater(plan: LogicalPlan): PhysicalPlan = this.plan(plan).next() def plan(plan: LogicalPlan): Iterator[PhysicalPlan] = { // Obviously a lot to do here still... diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index b9ca712c1ee1..cd3f15cbe107 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -17,35 +17,11 @@ package org.apache.spark.sql.catalyst.planning -import scala.annotation.tailrec - import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -/** - * A pattern that matches any number of filter operations on top of another relational operator. - * Adjacent filter operators are collected and their conditions are broken up and returned as a - * sequence of conjunctive predicates. - * - * @return A tuple containing a sequence of conjunctive predicates that should be used to filter the - * output and a relational operator. 
- */ -object FilteredOperation extends PredicateHelper { - type ReturnType = (Seq[Expression], LogicalPlan) - - def unapply(plan: LogicalPlan): Option[ReturnType] = Some(collectFilters(Nil, plan)) - - @tailrec - private def collectFilters(filters: Seq[Expression], plan: LogicalPlan): ReturnType = plan match { - case Filter(condition, child) => - collectFilters(filters ++ splitConjunctivePredicates(condition), child) - case other => (filters, other) - } -} - /** * A pattern that matches any number of project or filter operations on top of another relational * operator. All filter operators are collected and their conditions are broken up and returned @@ -62,8 +38,9 @@ object PhysicalOperation extends PredicateHelper { } /** - * Collects projects and filters, in-lining/substituting aliases if necessary. Here are two - * examples for alias in-lining/substitution. Before: + * Collects all deterministic projects and filters, in-lining/substituting aliases if necessary. + * Here are two examples for alias in-lining/substitution. + * Before: * {{{ * SELECT c1 FROM (SELECT key AS c1 FROM t1) t2 WHERE c1 > 10 * SELECT c1 AS c2 FROM (SELECT key AS c1 FROM t1) t2 WHERE c1 > 10 @@ -74,15 +51,15 @@ object PhysicalOperation extends PredicateHelper { * SELECT key AS c2 FROM t1 WHERE key > 10 * }}} */ - def collectProjectsAndFilters(plan: LogicalPlan): + private def collectProjectsAndFilters(plan: LogicalPlan): (Option[Seq[NamedExpression]], Seq[Expression], LogicalPlan, Map[Attribute, Expression]) = plan match { - case Project(fields, child) => + case Project(fields, child) if fields.forall(_.deterministic) => val (_, filters, other, aliases) = collectProjectsAndFilters(child) val substitutedFields = fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]] (Some(substitutedFields), filters, other, collectAliases(substitutedFields)) - case Filter(condition, child) => + case Filter(condition, child) if condition.deterministic => val (fields, filters, other, aliases) = collectProjectsAndFilters(child) val substitutedCondition = substitute(aliases)(condition) (fields, filters ++ splitConjunctivePredicates(substitutedCondition), other, aliases) @@ -91,11 +68,11 @@ object PhysicalOperation extends PredicateHelper { (None, Nil, other, Map.empty) } - def collectAliases(fields: Seq[Expression]): Map[Attribute, Expression] = fields.collect { + private def collectAliases(fields: Seq[Expression]): Map[Attribute, Expression] = fields.collect { case a @ Alias(child, _) => a.toAttribute -> child }.toMap - def substitute(aliases: Map[Attribute, Expression])(expr: Expression): Expression = { + private def substitute(aliases: Map[Attribute, Expression])(expr: Expression): Expression = { expr.transform { case a @ Alias(ref: AttributeReference, name) => aliases.get(ref).map(Alias(_, name)(a.exprId, a.qualifiers)).getOrElse(a) @@ -106,82 +83,11 @@ object PhysicalOperation extends PredicateHelper { } } -/** - * Matches a logical aggregation that can be performed on distributed data in two steps. The first - * operates on the data in each partition performing partial aggregation for each group. The second - * occurs after the shuffle and completes the aggregation. - * - * This pattern will only match if all aggregate expressions can be computed partially and will - * return the rewritten aggregation expressions for both phases. - * - * The returned values for this match are as follows: - * - Grouping attributes for the final aggregation. - * - Aggregates for the final aggregation. 
- * - Grouping expressions for the partial aggregation. - * - Partial aggregate expressions. - * - Input to the aggregation. - */ -object PartialAggregation { - type ReturnType = - (Seq[Attribute], Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan) - - def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { - case logical.Aggregate(groupingExpressions, aggregateExpressions, child) => - // Collect all aggregate expressions. - val allAggregates = - aggregateExpressions.flatMap(_ collect { case a: AggregateExpression1 => a}) - // Collect all aggregate expressions that can be computed partially. - val partialAggregates = - aggregateExpressions.flatMap(_ collect { case p: PartialAggregate1 => p}) - - // Only do partial aggregation if supported by all aggregate expressions. - if (allAggregates.size == partialAggregates.size) { - // Create a map of expressions to their partial evaluations for all aggregate expressions. - val partialEvaluations: Map[TreeNodeRef, SplitEvaluation] = - partialAggregates.map(a => (new TreeNodeRef(a), a.asPartial)).toMap - - // We need to pass all grouping expressions though so the grouping can happen a second - // time. However some of them might be unnamed so we alias them allowing them to be - // referenced in the second aggregation. - val namedGroupingExpressions: Seq[(Expression, NamedExpression)] = - groupingExpressions.map { - case n: NamedExpression => (n, n) - case other => (other, Alias(other, "PartialGroup")()) - } - - // Replace aggregations with a new expression that computes the result from the already - // computed partial evaluations and grouping values. - val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformDown { - case e: Expression if partialEvaluations.contains(new TreeNodeRef(e)) => - partialEvaluations(new TreeNodeRef(e)).finalEvaluation - - case e: Expression => - namedGroupingExpressions.collectFirst { - case (expr, ne) if expr semanticEquals e => ne.toAttribute - }.getOrElse(e) - }).asInstanceOf[Seq[NamedExpression]] - - val partialComputation = namedGroupingExpressions.map(_._2) ++ - partialEvaluations.values.flatMap(_.partialEvaluations) - - val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) - - Some( - (namedGroupingAttributes, - rewrittenAggregateExpressions, - groupingExpressions, - partialComputation, - child)) - } else { - None - } - case _ => None - } -} - - /** * A pattern that finds joins with equality conditions that can be evaluated using equi-join. + * + * Null-safe equality will be transformed into equality as joining key (replace null with default + * value). */ object ExtractEquiJoinKeys extends Logging with PredicateHelper { /** (joinType, leftKeys, rightKeys, condition, leftChild, rightChild) */ @@ -193,21 +99,29 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { logDebug(s"Considering join on: $condition") // Find equi-join predicates that can be evaluated before the join, and thus can be used // as join keys. 
- val (joinPredicates, otherPredicates) = - condition.map(splitConjunctivePredicates).getOrElse(Nil).partition { - case EqualTo(l, r) if (canEvaluate(l, left) && canEvaluate(r, right)) || - (canEvaluate(l, right) && canEvaluate(r, left)) => true - case _ => false - } - - val joinKeys = joinPredicates.map { - case EqualTo(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => (l, r) - case EqualTo(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => (r, l) + val predicates = condition.map(splitConjunctivePredicates).getOrElse(Nil) + val joinKeys = predicates.flatMap { + case EqualTo(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => Some((l, r)) + case EqualTo(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => Some((r, l)) + // Replace null with default value for joining key, then those rows with null in it could + // be joined together + case EqualNullSafe(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => + Some((Coalesce(Seq(l, Literal.default(l.dataType))), + Coalesce(Seq(r, Literal.default(r.dataType))))) + case EqualNullSafe(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => + Some((Coalesce(Seq(r, Literal.default(r.dataType))), + Coalesce(Seq(l, Literal.default(l.dataType))))) + case other => None + } + val otherPredicates = predicates.filterNot { + case EqualTo(l, r) => + canEvaluate(l, left) && canEvaluate(r, right) || + canEvaluate(l, right) && canEvaluate(r, left) + case other => false } - val leftKeys = joinKeys.map(_._1) - val rightKeys = joinKeys.map(_._2) if (joinKeys.nonEmpty) { + val (leftKeys, rightKeys) = joinKeys.unzip logDebug(s"leftKeys:$leftKeys | rightKeys:$rightKeys") Some((joinType, leftKeys, rightKeys, otherPredicates.reduceOption(And), left, right)) } else { @@ -217,6 +131,45 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { } } +/** + * A pattern that collects the filter and inner joins. + * + * Filter + * | + * inner Join + * / \ ----> (Seq(plan0, plan1, plan2), conditions) + * Filter plan2 + * | + * inner join + * / \ + * plan0 plan1 + * + * Note: This pattern currently only works for left-deep trees. + */ +object ExtractFiltersAndInnerJoins extends PredicateHelper { + + // flatten all inner joins, which are next to each other + def flattenJoin(plan: LogicalPlan): (Seq[LogicalPlan], Seq[Expression]) = plan match { + case Join(left, right, Inner, cond) => + val (plans, conditions) = flattenJoin(left) + (plans ++ Seq(right), conditions ++ cond.toSeq) + + case Filter(filterCondition, j @ Join(left, right, Inner, joinCondition)) => + val (plans, conditions) = flattenJoin(j) + (plans, conditions ++ splitConjunctivePredicates(filterCondition)) + + case _ => (Seq(plan), Seq()) + } + + def unapply(plan: LogicalPlan): Option[(Seq[LogicalPlan], Seq[Expression])] = plan match { + case f @ Filter(filterCondition, j @ Join(_, _, Inner, _)) => + Some(flattenJoin(f)) + case j @ Join(_, _, Inner, _) => + Some(flattenJoin(j)) + case _ => None + } +} + /** * A pattern that collects all adjacent unions and returns their children as a Seq. 
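The `EqualNullSafe` handling above relies on the fact that `a <=> b` treats two nulls as equal, so wrapping each side in `Coalesce` with the type's default value lets rows whose keys are null hash to the same partition and join together. A small sketch of that semantics, with `Option[Int]` standing in for a nullable column and `0` standing in for `Literal.default(IntegerType)`:

```scala
object NullSafeJoinKeySketch {
  val defaultInt = 0 // stand-in for Literal.default(IntegerType)

  def coalesce(v: Option[Int], default: Int): Int = v.getOrElse(default)

  def main(args: Array[String]): Unit = {
    val left: Option[Int]  = None
    val right: Option[Int] = None

    // Plain SQL equality: NULL = NULL is not true, so these rows would not match.
    val equalTo = left.isDefined && right.isDefined && left.get == right.get

    // Null-safe equality: both-null compares as true, which the coalesce rewrite preserves
    // when the coalesced values are used as hash join keys.
    val equalNullSafe = coalesce(left, defaultInt) == coalesce(right, defaultInt)

    println(s"EqualTo: $equalTo, EqualNullSafe via coalesce: $equalNullSafe") // false, true
  }
}
```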
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index c610f70d3843..b9db7838db08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -92,7 +92,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy val newArgs = productIterator.map(recursiveTransform).toArray - if (changed) makeCopy(newArgs) else this + if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this } /** @@ -124,7 +124,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy val newArgs = productIterator.map(recursiveTransform).toArray - if (changed) makeCopy(newArgs) else this + if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this } /** Returns the result of running [[transformExpressions]] on this node @@ -137,13 +137,17 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy /** Returns all of the expressions present in this query plan operator. */ def expressions: Seq[Expression] = { + // Recursively find all expressions from a traversable. + def seqToExpressions(seq: Traversable[Any]): Traversable[Expression] = seq.flatMap { + case e: Expression => e :: Nil + case s: Traversable[_] => seqToExpressions(s) + case other => Nil + } + productIterator.flatMap { case e: Expression => e :: Nil case Some(e: Expression) => e :: Nil - case seq: Traversable[_] => seq.flatMap { - case e: Expression => e :: Nil - case other => Nil - } + case seq: Traversable[_] => seqToExpressions(seq) case other => Nil }.toSeq } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index bedeaf06adf1..8f8747e10593 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -22,11 +22,60 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode} abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { + private var _analyzed: Boolean = false + + /** + * Marks this plan as already analyzed. This should only be called by CheckAnalysis. + */ + private[catalyst] def setAnalyzed(): Unit = { _analyzed = true } + + /** + * Returns true if this node and its children have already been gone through analysis and + * verification. Note that this is only an optimization used to avoid analyzing trees that + * have already been analyzed, and can be reset by transformations. + */ + def analyzed: Boolean = _analyzed + + /** + * Returns a copy of this node where `rule` has been recursively applied first to all of its + * children and then itself (post-order). When `rule` does not apply to a given node, it is left + * unchanged. This function is similar to `transformUp`, but skips sub-trees that have already + * been marked as analyzed. 
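`resolveOperators` is essentially `transformUp` with a guard on the `analyzed` flag, so analysis rules stop revisiting subtrees that are already resolved. A simplified standalone model of that behavior (no origin tracking or `fastEquals` short-circuit), using a hypothetical `Node` class:

```scala
object ResolveOperatorsSketch {
  case class Node(name: String, children: Seq[Node] = Nil, analyzed: Boolean = false)

  // Post-order transform that leaves analyzed subtrees untouched.
  def resolveOperators(n: Node)(rule: PartialFunction[Node, Node]): Node =
    if (n.analyzed) n
    else {
      val afterChildren = n.copy(children = n.children.map(resolveOperators(_)(rule)))
      rule.applyOrElse(afterChildren, identity[Node])
    }

  def main(args: Array[String]): Unit = {
    val plan = Node("Project", Seq(Node("Filter", Seq(Node("Relation", analyzed = true)))))
    val rule: PartialFunction[Node, Node] = { case n => n.copy(name = n.name + "'") }
    // Project and Filter are renamed; the analyzed Relation subtree is skipped.
    println(resolveOperators(plan)(rule))
  }
}
```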
+ * + * @param rule the function use to transform this nodes children + */ + def resolveOperators(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { + if (!analyzed) { + val afterRuleOnChildren = transformChildren(rule, (t, r) => t.resolveOperators(r)) + if (this fastEquals afterRuleOnChildren) { + CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(this, identity[LogicalPlan]) + } + } else { + CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(afterRuleOnChildren, identity[LogicalPlan]) + } + } + } else { + this + } + } + + /** + * Recursively transforms the expressions of a tree, skipping nodes that have already + * been analyzed. + */ + def resolveExpressions(r: PartialFunction[Expression, Expression]): LogicalPlan = { + this resolveOperators { + case p => p.transformExpressions(r) + } + } + /** * Computes [[Statistics]] for this plan. The default implementation assumes the output * cardinality is the product of of all child plan's cardinality, i.e. applies in the case @@ -86,16 +135,25 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { /** Args that have cleaned such that differences in expression id should not affect equality */ protected lazy val cleanArgs: Seq[Any] = { val input = children.flatMap(_.output) + def cleanExpression(e: Expression) = e match { + case a: Alias => + // As the root of the expression, Alias will always take an arbitrary exprId, we need + // to erase that for equality testing. + val cleanedExprId = Alias(a.child, a.name)(ExprId(-1), a.qualifiers) + BindReferences.bindReference(cleanedExprId, input, allowFailures = true) + case other => BindReferences.bindReference(other, input, allowFailures = true) + } + productIterator.map { // Children are checked using sameResult above. case tn: TreeNode[_] if containsChild(tn) => null - case e: Expression => BindReferences.bindReference(e, input, allowFailures = true) + case e: Expression => cleanExpression(e) case s: Option[_] => s.map { - case e: Expression => BindReferences.bindReference(e, input, allowFailures = true) + case e: Expression => cleanExpression(e) case other => other } case s: Seq[_] => s.map { - case e: Expression => BindReferences.bindReference(e, input, allowFailures = true) + case e: Expression => cleanExpression(e) case other => other } case other => other @@ -130,47 +188,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { def resolveQuoted( name: String, resolver: Resolver): Option[NamedExpression] = { - resolve(parseAttributeName(name), output, resolver) - } - - /** - * Internal method, used to split attribute name by dot with backticks rule. - * Backticks must appear in pairs, and the quoted string must be a complete name part, - * which means `ab..c`e.f is not allowed. - * Escape character is not supported now, so we can't use backtick inside name part. - */ - private def parseAttributeName(name: String): Seq[String] = { - val e = new AnalysisException(s"syntax error in attribute name: $name") - val nameParts = scala.collection.mutable.ArrayBuffer.empty[String] - val tmp = scala.collection.mutable.ArrayBuffer.empty[Char] - var inBacktick = false - var i = 0 - while (i < name.length) { - val char = name(i) - if (inBacktick) { - if (char == '`') { - inBacktick = false - if (i + 1 < name.length && name(i + 1) != '.') throw e - } else { - tmp += char - } - } else { - if (char == '`') { - if (tmp.nonEmpty) throw e - inBacktick = true - } else if (char == '.') { - if (name(i - 1) == '.' 
|| i == name.length - 1) throw e - nameParts += tmp.mkString - tmp.clear() - } else { - tmp += char - } - } - i += 1 - } - if (inBacktick) throw e - nameParts += tmp.mkString - nameParts.toSeq + resolve(UnresolvedAttribute.parseAttributeName(name), output, resolver) } /** @@ -250,13 +268,13 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { // One match, but we also need to extract the requested nested field. case Seq((a, nestedFields)) => // The foldLeft adds ExtractValues for every remaining parts of the identifier, - // and wrap it with UnresolvedAlias which will be removed later. + // and aliased it with the last part of the name. // For example, consider "a.b.c", where "a" is resolved to an existing attribute. - // Then this will add ExtractValue("c", ExtractValue("b", a)), and wrap it as - // UnresolvedAlias(ExtractValue("c", ExtractValue("b", a))). + // Then this will add ExtractValue("c", ExtractValue("b", a)), and alias the final + // expression as "c". val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) => ExtractValue(expr, Literal(fieldName), resolver)) - Some(UnresolvedAlias(fieldExprs)) + Some(Alias(fieldExprs, nestedFields.last)()) // No matches. case Seq() => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 54b5f4977266..5665fd7e5f41 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.Encoder +import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.Utils +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.types._ -import org.apache.spark.util.collection.OpenHashSet +import scala.collection.mutable.ArrayBuffer case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = projectList.map(_.toAttribute) @@ -34,7 +36,7 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend }.nonEmpty ) - expressions.forall(_.resolved) && childrenResolved && !hasSpecialExpressions + !expressions.exists(!_.resolved) && childrenResolved && !hasSpecialExpressions } } @@ -86,46 +88,26 @@ case class Generate( } case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { - /** - * Indicates if `atLeastNNulls` is used to check if atLeastNNulls.children - * have at least one null value and atLeastNNulls.children are all attributes. - */ - private def isAtLeastOneNullOutputAttributes(atLeastNNulls: AtLeastNNulls): Boolean = { - val expressions = atLeastNNulls.children - val n = atLeastNNulls.n - if (n != 1) { - // AtLeastNNulls is not used to check if atLeastNNulls.children have - // at least one null value. - false - } else { - // AtLeastNNulls is used to check if atLeastNNulls.children have - // at least one null value. We need to make sure all atLeastNNulls.children - // are attributes. 
- expressions.forall(_.isInstanceOf[Attribute]) + override def output: Seq[Attribute] = child.output +} + +abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { + override def output: Seq[Attribute] = + left.output.zip(right.output).map { case (leftAttr, rightAttr) => + leftAttr.withNullability(leftAttr.nullable || rightAttr.nullable) } - } - override def output: Seq[Attribute] = condition match { - case Not(a: AtLeastNNulls) if isAtLeastOneNullOutputAttributes(a) => - // The condition is used to make sure that there is no null value in - // a.children. - val nonNullableAttributes = AttributeSet(a.children.asInstanceOf[Seq[Attribute]]) - child.output.map { - case attr if nonNullableAttributes.contains(attr) => - attr.withNullability(false) - case attr => attr - } - case _ => child.output - } + final override lazy val resolved: Boolean = + childrenResolved && + left.output.length == right.output.length && + left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } } -case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { - // TODO: These aren't really the same attributes as nullability etc might change. - override def output: Seq[Attribute] = left.output +private[sql] object SetOperation { + def unapply(p: SetOperation): Option[(LogicalPlan, LogicalPlan)] = Some((p.left, p.right)) +} - override lazy val resolved: Boolean = - childrenResolved && - left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } +case class Union(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { override def statistics: Statistics = { val sizeInBytes = left.statistics.sizeInBytes + right.statistics.sizeInBytes @@ -133,6 +115,13 @@ case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { } } +case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) + +case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { + /** We don't use right.output because those rows get excluded from the set. */ + override def output: Seq[Attribute] = left.output +} + case class Join( left: LogicalPlan, right: LogicalPlan, @@ -172,15 +161,6 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output } - -case class Except(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { - override def output: Seq[Attribute] = left.output - - override lazy val resolved: Boolean = - childrenResolved && - left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } -} - case class InsertIntoTable( table: LogicalPlan, partition: Map[String, Option[String]], @@ -190,7 +170,7 @@ case class InsertIntoTable( extends LogicalPlan { override def children: Seq[LogicalPlan] = child :: Nil - override def output: Seq[Attribute] = child.output + override def output: Seq[Attribute] = Seq.empty assert(overwrite || !ifNotExists) override lazy val resolved: Boolean = childrenResolved && child.output.zip(table.output).forall { @@ -218,7 +198,7 @@ case class WithWindowDefinition( } /** - * @param order The ordering expressions, should all be [[AttributeReference]] + * @param order The ordering expressions * @param global True means global sorting apply for entire data set, * False means sorting only apply within the partition. 
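The new `SetOperation.output` merges nullability column by column: a union attribute is nullable if the corresponding attribute on either side is nullable. A toy version of that rule:

```scala
object SetOperationOutputSketch {
  case class Attr(name: String, nullable: Boolean)

  // Mirrors SetOperation.output: positional zip of the two sides, OR-ing nullability.
  def unionOutput(left: Seq[Attr], right: Seq[Attr]): Seq[Attr] =
    left.zip(right).map { case (l, r) => l.copy(nullable = l.nullable || r.nullable) }

  def main(args: Array[String]): Unit = {
    val left  = Seq(Attr("id", nullable = false), Attr("name", nullable = false))
    val right = Seq(Attr("id", nullable = false), Attr("name", nullable = true))
    println(unionOutput(left, right)) // id stays non-nullable, name becomes nullable
  }
}
```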
* @param child Child logical plan @@ -228,11 +208,6 @@ case class Sort( global: Boolean, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output - - def hasNoEvaluation: Boolean = order.forall(_.child.isInstanceOf[AttributeReference]) - - override lazy val resolved: Boolean = - expressions.forall(_.resolved) && childrenResolved && hasNoEvaluation } case class Aggregate( @@ -247,11 +222,9 @@ case class Aggregate( }.nonEmpty ) - expressions.forall(_.resolved) && childrenResolved && !hasWindowExpressions + !expressions.exists(!_.resolved) && childrenResolved && !hasWindowExpressions } - lazy val newAggregation: Option[Aggregate] = Utils.tryConvert(this) - override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) } @@ -263,41 +236,25 @@ case class Window( child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = - (projectList ++ windowExpressions).map(_.toAttribute) + projectList ++ windowExpressions.map(_.toAttribute) } -/** - * Apply the all of the GroupExpressions to every input row, hence we will get - * multiple output rows for a input row. - * @param bitmasks The bitmask set represents the grouping sets - * @param groupByExprs The grouping by expressions - * @param child Child operator - */ -case class Expand( - bitmasks: Seq[Int], - groupByExprs: Seq[Expression], - gid: Attribute, - child: LogicalPlan) extends UnaryNode { - override def statistics: Statistics = { - val sizeInBytes = child.statistics.sizeInBytes * projections.length - Statistics(sizeInBytes = sizeInBytes) - } - - val projections: Seq[Seq[Expression]] = expand() - +private[sql] object Expand { /** - * Extract attribute set according to the grouping id + * Extract attribute set according to the grouping id. + * * @param bitmask bitmask to represent the selected of the attribute sequence * @param exprs the attributes in sequence * @return the attributes of non selected specified via bitmask (with the bit set to 1) */ - private def buildNonSelectExprSet(bitmask: Int, exprs: Seq[Expression]) - : OpenHashSet[Expression] = { - val set = new OpenHashSet[Expression](2) + private def buildNonSelectExprSet( + bitmask: Int, + exprs: Seq[Expression]): ArrayBuffer[Expression] = { + val set = new ArrayBuffer[Expression](2) var bit = exprs.length - 1 while (bit >= 0) { - if (((bitmask >> bit) & 1) == 0) set.add(exprs(bit)) + if (((bitmask >> bit) & 1) == 0) set += exprs(bit) bit -= 1 } @@ -305,19 +262,29 @@ case class Expand( } /** - * Create an array of Projections for the child projection, and replace the projections' - * expressions which equal GroupBy expressions with Literal(null), if those expressions - * are not set for this grouping set (according to the bit mask). + * Apply the all of the GroupExpressions to every input row, hence we will get + * multiple output rows for a input row. 
+ * + * @param bitmasks The bitmask set represents the grouping sets + * @param groupByExprs The grouping by expressions + * @param gid Attribute of the grouping id + * @param child Child operator */ - private[this] def expand(): Seq[Seq[Expression]] = { - val result = new scala.collection.mutable.ArrayBuffer[Seq[Expression]] - - bitmasks.foreach { bitmask => + def apply( + bitmasks: Seq[Int], + groupByExprs: Seq[Expression], + gid: Attribute, + child: LogicalPlan): Expand = { + // Create an array of Projections for the child projection, and replace the projections' + // expressions which equal GroupBy expressions with Literal(null), if those expressions + // are not set for this grouping set (according to the bit mask). + val projections = bitmasks.map { bitmask => // get the non selected grouping attributes according to the bit mask val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, groupByExprs) - val substitution = (child.output :+ gid).map(expr => expr transformDown { - case x: Expression if nonSelectedGroupExprSet.contains(x) => + (child.output :+ gid).map(expr => expr transformDown { + // TODO this causes a problem when a column is used both for grouping and aggregation. + case x: Expression if nonSelectedGroupExprSet.exists(_.semanticEquals(x)) => // if the input attribute in the Invalid Grouping Expression set of for this group // replace it with constant null Literal.create(null, expr.dataType) @@ -325,15 +292,32 @@ case class Expand( // replace the groupingId with concrete value (the bit mask) Literal.create(bitmask, IntegerType) }) - - result += substitution } - - result.toSeq + Expand(projections, child.output :+ gid, child) } +} - override def output: Seq[Attribute] = { - child.output :+ gid +/** + * Apply a number of projections to every input row, hence we will get multiple output rows for + * a input row. + * + * @param projections to apply + * @param output of all projections. + * @param child operator. + */ +case class Expand( + projections: Seq[Seq[Expression]], + output: Seq[Attribute], + child: LogicalPlan) extends UnaryNode { + + override def references: AttributeSet = + AttributeSet(projections.flatten.flatMap(_.references)) + + override def statistics: Statistics = { + // TODO shouldn't we factor in the size of the projection versus the size of the backing child + // row? + val sizeInBytes = child.statistics.sizeInBytes * projections.length + Statistics(sizeInBytes = sizeInBytes) } } @@ -344,6 +328,10 @@ trait GroupingAnalytics extends UnaryNode { override def output: Seq[Attribute] = aggregations.map(_.toAttribute) + // Needs to be unresolved before its translated to Aggregate + Expand because output attributes + // will change in analysis. 
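The projections built by `Expand.apply` can be illustrated without Catalyst types: for each grouping-set bitmask, grouping columns whose bit is cleared are replaced with null, and the bitmask itself becomes the grouping-id column. The bit-to-column mapping below follows `buildNonSelectExprSet` but is only an illustration:

```scala
object GroupingSetExpandSketch {
  // One output projection per bitmask: grouping column i is kept when bit i is set,
  // otherwise it is nulled out; the bitmask is appended as the grouping id.
  def projections(bitmasks: Seq[Int], groupBy: Seq[String]): Seq[Seq[Any]] =
    bitmasks.map { bitmask =>
      val cols: Seq[Any] = groupBy.zipWithIndex.map { case (col, i) =>
        if (((bitmask >> i) & 1) == 0) null else col
      }
      cols :+ bitmask
    }

  def main(args: Array[String]): Unit = {
    // For CUBE(a, b): 3 = (a, b), 1 = (a), 2 = (b), 0 = grand total
    projections(Seq(3, 1, 2, 0), Seq("a", "b")).foreach(println)
  }
}
```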
+ override lazy val resolved: Boolean = false + def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics } @@ -407,6 +395,20 @@ case class Rollup( this.copy(aggregations = aggs) } +case class Pivot( + groupByExprs: Seq[NamedExpression], + pivotColumn: Expression, + pivotValues: Seq[Literal], + aggregates: Seq[Expression], + child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = groupByExprs.map(_.toAttribute) ++ aggregates match { + case agg :: Nil => pivotValues.map(value => AttributeReference(value.toString, agg.dataType)()) + case _ => pivotValues.flatMap{ value => + aggregates.map(agg => AttributeReference(value + "_" + agg.prettyString, agg.dataType)()) + } + } +} + case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output @@ -449,7 +451,7 @@ case class Distinct(child: LogicalPlan) extends UnaryNode { } /** - * Return a new RDD that has exactly `numPartitions` partitions. Differs from + * Returns a new RDD that has exactly `numPartitions` partitions. Differs from * [[RepartitionByExpression]] as this method is called directly by DataFrame's, because the user * asked for `coalesce` or `repartition`. [[RepartitionByExpression]] is used when the consumer * of the output requires some specific ordering or distribution of the data. @@ -475,10 +477,119 @@ case object OneRowRelation extends LeafNode { override def statistics: Statistics = Statistics(sizeInBytes = 1) } -case class Intersect(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { - override def output: Seq[Attribute] = left.output +/** + * A relation produced by applying `func` to each partition of the `child`. tEncoder/uEncoder are + * used respectively to decode/encode from the JVM object representation expected by `func.` + */ +case class MapPartitions[T, U]( + func: Iterator[T] => Iterator[U], + tEncoder: ExpressionEncoder[T], + uEncoder: ExpressionEncoder[U], + output: Seq[Attribute], + child: LogicalPlan) extends UnaryNode { + override def missingInput: AttributeSet = AttributeSet.empty +} - override lazy val resolved: Boolean = - childrenResolved && - left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } +/** Factory for constructing new `AppendColumn` nodes. */ +object AppendColumns { + def apply[T, U : Encoder]( + func: T => U, + tEncoder: ExpressionEncoder[T], + child: LogicalPlan): AppendColumns[T, U] = { + val attrs = encoderFor[U].schema.toAttributes + new AppendColumns[T, U](func, tEncoder, encoderFor[U], attrs, child) + } +} + +/** + * A relation produced by applying `func` to each partition of the `child`, concatenating the + * resulting columns at the end of the input row. tEncoder/uEncoder are used respectively to + * decode/encode from the JVM object representation expected by `func.` + */ +case class AppendColumns[T, U]( + func: T => U, + tEncoder: ExpressionEncoder[T], + uEncoder: ExpressionEncoder[U], + newColumns: Seq[Attribute], + child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output ++ newColumns + override def missingInput: AttributeSet = super.missingInput -- newColumns +} + +/** Factory for constructing new `MapGroups` nodes. 
*/ +object MapGroups { + def apply[K, T, U : Encoder]( + func: (K, Iterator[T]) => TraversableOnce[U], + kEncoder: ExpressionEncoder[K], + tEncoder: ExpressionEncoder[T], + groupingAttributes: Seq[Attribute], + child: LogicalPlan): MapGroups[K, T, U] = { + new MapGroups( + func, + kEncoder, + tEncoder, + encoderFor[U], + groupingAttributes, + encoderFor[U].schema.toAttributes, + child) + } +} + +/** + * Applies func to each unique group in `child`, based on the evaluation of `groupingAttributes`. + * Func is invoked with an object representation of the grouping key an iterator containing the + * object representation of all the rows with that key. + */ +case class MapGroups[K, T, U]( + func: (K, Iterator[T]) => TraversableOnce[U], + kEncoder: ExpressionEncoder[K], + tEncoder: ExpressionEncoder[T], + uEncoder: ExpressionEncoder[U], + groupingAttributes: Seq[Attribute], + output: Seq[Attribute], + child: LogicalPlan) extends UnaryNode { + override def missingInput: AttributeSet = AttributeSet.empty +} + +/** Factory for constructing new `CoGroup` nodes. */ +object CoGroup { + def apply[Key, Left, Right, Result : Encoder]( + func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result], + keyEnc: ExpressionEncoder[Key], + leftEnc: ExpressionEncoder[Left], + rightEnc: ExpressionEncoder[Right], + leftGroup: Seq[Attribute], + rightGroup: Seq[Attribute], + left: LogicalPlan, + right: LogicalPlan): CoGroup[Key, Left, Right, Result] = { + CoGroup( + func, + keyEnc, + leftEnc, + rightEnc, + encoderFor[Result], + encoderFor[Result].schema.toAttributes, + leftGroup, + rightGroup, + left, + right) + } +} + +/** + * A relation produced by applying `func` to each grouping key and associated values from left and + * right children. + */ +case class CoGroup[Key, Left, Right, Result]( + func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result], + keyEnc: ExpressionEncoder[Key], + leftEnc: ExpressionEncoder[Left], + rightEnc: ExpressionEncoder[Right], + resultEnc: ExpressionEncoder[Result], + output: Seq[Attribute], + leftGroup: Seq[Attribute], + rightGroup: Seq[Attribute], + left: LogicalPlan, + right: LogicalPlan) extends BinaryNode { + override def missingInput: AttributeSet = AttributeSet.empty } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala index 1f76b03bcb0f..a5bdee1b854c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala @@ -31,10 +31,19 @@ case class SortPartitions(sortExpressions: Seq[SortOrder], child: LogicalPlan) extends RedistributeData /** - * This method repartitions data using [[Expression]]s, and receives information about the - * number of partitions during execution. Used when a specific ordering or distribution is - * expected by the consumer of the query result. Use [[Repartition]] for RDD-like + * This method repartitions data using [[Expression]]s into `numPartitions`, and receives + * information about the number of partitions during execution. Used when a specific ordering or + * distribution is expected by the consumer of the query result. Use [[Repartition]] for RDD-like * `coalesce` and `repartition`. + * If `numPartitions` is not specified, the number of partitions will be the number set by + * `spark.sql.shuffle.partitions`. 
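Setting encoders and physical planning aside, the contract of the `MapGroups` node is simply: call `func` once per distinct grouping key with an iterator over that group's rows. A plain-Scala model of that contract (not the Spark API):

```scala
object MapGroupsSketch {
  def mapGroups[K, T, U](rows: Seq[T])(key: T => K)(
      func: (K, Iterator[T]) => TraversableOnce[U]): Seq[U] =
    rows.groupBy(key).toSeq.flatMap { case (k, group) => func(k, group.iterator) }

  def main(args: Array[String]): Unit = {
    val words = Seq("spark", "sql", "scala", "catalyst")
    // Group by first letter and emit one summary row per group.
    val summary = mapGroups(words)(_.head) { (k, it) => Iterator(s"$k -> ${it.size}") }
    summary.sorted.foreach(println) // c -> 1, s -> 3
  }
}
```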
*/ -case class RepartitionByExpression(partitionExpressions: Seq[Expression], child: LogicalPlan) - extends RedistributeData +case class RepartitionByExpression( + partitionExpressions: Seq[Expression], + child: LogicalPlan, + numPartitions: Option[Int] = None) extends RedistributeData { + numPartitions match { + case Some(n) => require(n > 0, "numPartitions must be greater than 0.") + case None => // Ok + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index ec659ce789c2..f6fb31a2af59 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -75,6 +75,37 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution { def clustering: Set[Expression] = ordering.map(_.child).toSet } +/** + * Describes how an operator's output is split across partitions. The `compatibleWith`, + * `guarantees`, and `satisfies` methods describe relationships between child partitionings, + * target partitionings, and [[Distribution]]s. These relations are described more precisely in + * their individual method docs, but at a high level: + * + * - `satisfies` is a relationship between partitionings and distributions. + * - `compatibleWith` is relationships between an operator's child output partitionings. + * - `guarantees` is a relationship between a child's existing output partitioning and a target + * output partitioning. + * + * Diagrammatically: + * + * +--------------+ + * | Distribution | + * +--------------+ + * ^ + * | + * satisfies + * | + * +--------------+ +--------------+ + * | Child | | Target | + * +----| Partitioning |----guarantees--->| Partitioning | + * | +--------------+ +--------------+ + * | ^ + * | | + * | compatibleWith + * | | + * +------------+ + * + */ sealed trait Partitioning { /** Returns the number of partitions that the data is split across */ val numPartitions: Int @@ -90,9 +121,66 @@ sealed trait Partitioning { /** * Returns true iff we can say that the partitioning scheme of this [[Partitioning]] * guarantees the same partitioning scheme described by `other`. + * + * Compatibility of partitionings is only checked for operators that have multiple children + * and that require a specific child output [[Distribution]], such as joins. + * + * Intuitively, partitionings are compatible if they route the same partitioning key to the same + * partition. For instance, two hash partitionings are only compatible if they produce the same + * number of output partitionings and hash records according to the same hash function and + * same partitioning key schema. + * + * Put another way, two partitionings are compatible with each other if they satisfy all of the + * same distribution guarantees. */ - // TODO: Add an example once we have the `nullSafe` concept. - def guarantees(other: Partitioning): Boolean + def compatibleWith(other: Partitioning): Boolean + + /** + * Returns true iff we can say that the partitioning scheme of this [[Partitioning]] guarantees + * the same partitioning scheme described by `other`. If a `A.guarantees(B)`, then repartitioning + * the child's output according to `B` will be unnecessary. `guarantees` is used as a performance + * optimization to allow the exchange planner to avoid redundant repartitionings. 
By default, + * a partitioning only guarantees partitionings that are equal to itself (i.e. the same number + * of partitions, same strategy (range or hash), etc). + * + * In order to enable more aggressive optimization, this strict equality check can be relaxed. + * For example, say that the planner needs to repartition all of an operator's children so that + * they satisfy the [[AllTuples]] distribution. One way to do this is to repartition all children + * to have the [[SinglePartition]] partitioning. If one of the operator's children already happens + * to be hash-partitioned with a single partition then we do not need to re-shuffle this child; + * this repartitioning can be avoided if a single-partition [[HashPartitioning]] `guarantees` + * [[SinglePartition]]. + * + * The SinglePartition example given above is not particularly interesting; guarantees' real + * value occurs for more advanced partitioning strategies. SPARK-7871 will introduce a notion + * of null-safe partitionings, under which partitionings can specify whether rows whose + * partitioning keys contain null values will be grouped into the same partition or whether they + * will have an unknown / random distribution. If a partitioning does not require nulls to be + * clustered then a partitioning which _does_ cluster nulls will guarantee the null clustered + * partitioning. The converse is not true, however: a partitioning which clusters nulls cannot + * be guaranteed by one which does not cluster them. Thus, in general `guarantees` is not a + * symmetric relation. + * + * Another way to think about `guarantees`: if `A.guarantees(B)`, then any partitioning of rows + * produced by `A` could have also been produced by `B`. + */ + def guarantees(other: Partitioning): Boolean = this == other +} + +object Partitioning { + def allCompatible(partitionings: Seq[Partitioning]): Boolean = { + // Note: this assumes transitivity + partitionings.sliding(2).map { + case Seq(a) => true + case Seq(a, b) => + if (a.numPartitions != b.numPartitions) { + assert(!a.compatibleWith(b) && !b.compatibleWith(a)) + false + } else { + a.compatibleWith(b) && b.compatibleWith(a) + } + }.forall(_ == true) + } } case class UnknownPartitioning(numPartitions: Int) extends Partitioning { @@ -101,29 +189,35 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning { case _ => false } + override def compatibleWith(other: Partitioning): Boolean = false + override def guarantees(other: Partitioning): Boolean = false } -case object SinglePartition extends Partitioning { - val numPartitions = 1 - - override def satisfies(required: Distribution): Boolean = true - - override def guarantees(other: Partitioning): Boolean = other match { - case SinglePartition => true +/** + * Represents a partitioning where rows are distributed evenly across output partitions + * by starting from a random target partition number and distributing rows in a round-robin + * fashion. This partitioning is used when implementing the DataFrame.repartition() operator. 
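The `guarantees` doc above describes the intended relaxation with a concrete case: a single-partition `HashPartitioning` can guarantee `SinglePartition`, letting the exchange planner skip a shuffle. A toy model of that relationship, using hypothetical classes rather than the Catalyst ones:

```scala
object GuaranteesSketch {
  sealed trait Partitioning {
    def numPartitions: Int
    def guarantees(other: Partitioning): Boolean = this == other // strict default
  }
  case object SinglePartition extends Partitioning {
    val numPartitions = 1
    override def guarantees(other: Partitioning): Boolean = other.numPartitions == 1
  }
  case class HashPartitioning(keys: Seq[String], numPartitions: Int) extends Partitioning {
    // The relaxed rule described in the comment: one hash partition already satisfies
    // SinglePartition, so repartitioning would be redundant.
    override def guarantees(other: Partitioning): Boolean = other match {
      case SinglePartition     => numPartitions == 1
      case o: HashPartitioning => this == o
      case _                   => false
    }
  }

  def main(args: Array[String]): Unit = {
    println(HashPartitioning(Seq("key"), 1).guarantees(SinglePartition)) // true: skip the shuffle
    println(HashPartitioning(Seq("key"), 8).guarantees(SinglePartition)) // false: repartition first
  }
}
```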
+ */ +case class RoundRobinPartitioning(numPartitions: Int) extends Partitioning { + override def satisfies(required: Distribution): Boolean = required match { + case UnspecifiedDistribution => true case _ => false } + + override def compatibleWith(other: Partitioning): Boolean = false + + override def guarantees(other: Partitioning): Boolean = false } -case object BroadcastPartitioning extends Partitioning { +case object SinglePartition extends Partitioning { val numPartitions = 1 override def satisfies(required: Distribution): Boolean = true - override def guarantees(other: Partitioning): Boolean = other match { - case BroadcastPartitioning => true - case _ => false - } + override def compatibleWith(other: Partitioning): Boolean = other.numPartitions == 1 + + override def guarantees(other: Partitioning): Boolean = other.numPartitions == 1 } /** @@ -138,20 +232,23 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override def nullable: Boolean = false override def dataType: DataType = IntegerType - lazy val clusteringSet = expressions.toSet - override def satisfies(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true case ClusteredDistribution(requiredClustering) => - clusteringSet.subsetOf(requiredClustering.toSet) + expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) + case _ => false + } + + override def compatibleWith(other: Partitioning): Boolean = other match { + case o: HashPartitioning => this.semanticEquals(o) case _ => false } override def guarantees(other: Partitioning): Boolean = other match { - case o: HashPartitioning => - this.clusteringSet == o.clusteringSet && this.numPartitions == o.numPartitions + case o: HashPartitioning => this.semanticEquals(o) case _ => false } + } /** @@ -173,20 +270,23 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) override def nullable: Boolean = false override def dataType: DataType = IntegerType - private[this] lazy val clusteringSet = ordering.map(_.child).toSet - override def satisfies(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true case OrderedDistribution(requiredOrdering) => val minSize = Seq(requiredOrdering.size, ordering.size).min requiredOrdering.take(minSize) == ordering.take(minSize) case ClusteredDistribution(requiredClustering) => - clusteringSet.subsetOf(requiredClustering.toSet) + ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) + case _ => false + } + + override def compatibleWith(other: Partitioning): Boolean = other match { + case o: RangePartitioning => this.semanticEquals(o) case _ => false } override def guarantees(other: Partitioning): Boolean = other match { - case o: RangePartitioning => this == o + case o: RangePartitioning => this.semanticEquals(o) case _ => false } } @@ -228,6 +328,13 @@ case class PartitioningCollection(partitionings: Seq[Partitioning]) override def satisfies(required: Distribution): Boolean = partitionings.exists(_.satisfies(required)) + /** + * Returns true if any `partitioning` of this collection is compatible with + * the given [[Partitioning]]. + */ + override def compatibleWith(other: Partitioning): Boolean = + partitionings.exists(_.compatibleWith(other)) + /** * Returns true if any `partitioning` of this collection guarantees * the given [[Partitioning]]. 
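The `semanticEquals`-based `satisfies` check above boils down to: a hash partitioning satisfies a `ClusteredDistribution` when every hash expression appears in the required clustering, because rows that agree on the full clustering set are then guaranteed to land in the same partition. Sketch with string names standing in for expressions (plain equality approximating `semanticEquals`):

```scala
object SatisfiesSketch {
  case class ClusteredDistribution(clustering: Seq[String])
  case class HashPartitioning(expressions: Seq[String], numPartitions: Int) {
    def satisfies(required: ClusteredDistribution): Boolean =
      expressions.forall(e => required.clustering.exists(_ == e))
  }

  def main(args: Array[String]): Unit = {
    val required = ClusteredDistribution(Seq("a", "b"))
    println(HashPartitioning(Seq("a"), 200).satisfies(required))      // true: equal (a, b) rows co-located
    println(HashPartitioning(Seq("a", "c"), 200).satisfies(required)) // false: hashed on a column outside the clustering
  }
}
```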
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 3f9858b0c4a4..f80d2a93241d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -17,10 +17,30 @@ package org.apache.spark.sql.catalyst.rules +import scala.collection.JavaConverters._ + +import com.google.common.util.concurrent.AtomicLongMap + import org.apache.spark.Logging import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util.sideBySide +object RuleExecutor { + protected val timeMap = AtomicLongMap.create[String]() + + /** Resets statistics about time spent running specific rules */ + def resetTime(): Unit = timeMap.clear() + + /** Dump statistics about time spent running specific rules. */ + def dumpTimeSpent(): String = { + val map = timeMap.asMap().asScala + val maxSize = map.keys.map(_.toString.length).max + map.toSeq.sortBy(_._2).reverseMap { case (k, v) => + s"${k.padTo(maxSize, " ").mkString} $v" + }.mkString("\n") + } +} + abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { /** @@ -41,6 +61,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { /** Defines a sequence of rule batches, to be overridden by the implementation. */ protected val batches: Seq[Batch] + /** * Executes the batches of rules defined by the subclass. The batches are executed serially * using the defined execution strategy. Within each batch, rules are also executed serially. @@ -58,7 +79,11 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { while (continue) { curPlan = batch.rules.foldLeft(curPlan) { case (plan, rule) => + val startTime = System.nanoTime() val result = rule(plan) + val runTime = System.nanoTime() - startTime + RuleExecutor.timeMap.addAndGet(rule.ruleName, runTime) + if (!result.fastEquals(plan)) { logTrace( s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 122e9fc5ed77..d838d845d20f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -17,8 +17,10 @@ package org.apache.spark.sql.catalyst.trees +import scala.collection.Map + import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{StructType, DataType} /** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */ private class MutableInt(var i: Int) @@ -149,7 +151,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { /** * Returns a copy of this node where `f` has been applied to all the nodes children. */ - def mapChildren(f: BaseType => BaseType): this.type = { + def mapChildren(f: BaseType => BaseType): BaseType = { var changed = false val newArgs = productIterator.map { case arg: TreeNode[_] if containsChild(arg) => @@ -170,12 +172,13 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { * Returns a copy of this node with the children replaced. * TODO: Validate somewhere (in debug mode?) that children are ordered correctly. 
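The rule timing added to `RuleExecutor` follows a simple pattern: wrap each rule application in `System.nanoTime()` and accumulate per-rule totals in a Guava `AtomicLongMap`. A self-contained sketch of the same pattern (assumes Guava on the classpath, as Catalyst already does):

```scala
import com.google.common.util.concurrent.AtomicLongMap

object RuleTimingSketch {
  private val timeMap = AtomicLongMap.create[String]()

  // Time a single rule application and add the elapsed nanoseconds to its running total.
  def timed[T](ruleName: String)(body: => T): T = {
    val start = System.nanoTime()
    try body finally timeMap.addAndGet(ruleName, System.nanoTime() - start)
  }

  def main(args: Array[String]): Unit = {
    timed("ConstantFolding") { Thread.sleep(5) }
    timed("PushDownPredicate") { Thread.sleep(1) }

    import scala.collection.JavaConverters._
    timeMap.asMap().asScala.toSeq.sortBy(-_._2.longValue).foreach {
      case (rule, ns) => println(s"$rule: ${ns}ns")
    }
  }
}
```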
*/ - def withNewChildren(newChildren: Seq[BaseType]): this.type = { + def withNewChildren(newChildren: Seq[BaseType]): BaseType = { assert(newChildren.size == children.size, "Incorrect number of children") var changed = false val remainingNewChildren = newChildren.toBuffer val remainingOldChildren = children.toBuffer val newArgs = productIterator.map { + case s: StructType => s // Don't convert struct types to some other type of Seq[StructField] // Handle Seq[TreeNode] in TreeNode parameters. case s: Seq[_] => s.map { case arg: TreeNode[_] if containsChild(arg) => @@ -190,6 +193,19 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case nonChild: AnyRef => nonChild case null => null } + case m: Map[_, _] => m.mapValues { + case arg: TreeNode[_] if containsChild(arg) => + val newChild = remainingNewChildren.remove(0) + val oldChild = remainingOldChildren.remove(0) + if (newChild fastEquals oldChild) { + oldChild + } else { + changed = true + newChild + } + case nonChild: AnyRef => nonChild + case null => null + }.view.force // `mapValues` is lazy and we need to force it to materialize case arg: TreeNode[_] if containsChild(arg) => val newChild = remainingNewChildren.remove(0) val oldChild = remainingOldChildren.remove(0) @@ -229,9 +245,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { // Check if unchanged and then possibly return old copy to avoid gc churn. if (this fastEquals afterRule) { - transformChildrenDown(rule) + transformChildren(rule, (t, r) => t.transformDown(r)) } else { - afterRule.transformChildrenDown(rule) + afterRule.transformChildren(rule, (t, r) => t.transformDown(r)) } } @@ -240,11 +256,13 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { * this node. When `rule` does not apply to a given node it is left unchanged. 
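The `.view.force` calls above exist because `Map.mapValues` in the Scala collections of this era returns a lazy view: the mapping function is re-run on every read, and its side effects (such as flipping the `changed` flag) are deferred until the map is first accessed. A small demonstration of that behavior (as of Scala 2.10/2.11):

```scala
object MapValuesLazinessSketch {
  def main(args: Array[String]): Unit = {
    var evaluated = 0
    val m = Map("a" -> 1, "b" -> 2)

    // mapValues does not run the function here; it only wraps the map in a lazy view.
    val lazyView = m.mapValues { v => evaluated += 1; v * 10 }
    println(s"after mapValues:  evaluated = $evaluated")  // 0

    // Every read re-runs the function, which is why a mutable `changed` flag inside it
    // would misbehave.
    lazyView("a"); lazyView("a")
    println(s"after two reads:  evaluated = $evaluated")  // 2

    // view.force materializes the result exactly once per key.
    val materialized = m.mapValues { v => evaluated += 1; v * 10 }.view.force
    println(s"after view.force: evaluated = $evaluated, $materialized")
  }
}
```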
* @param rule the function used to transform this nodes children */ - def transformChildrenDown(rule: PartialFunction[BaseType, BaseType]): this.type = { + protected def transformChildren( + rule: PartialFunction[BaseType, BaseType], + nextOperation: (BaseType, PartialFunction[BaseType, BaseType]) => BaseType): BaseType = { var changed = false val newArgs = productIterator.map { case arg: TreeNode[_] if containsChild(arg) => - val newChild = arg.asInstanceOf[BaseType].transformDown(rule) + val newChild = nextOperation(arg.asInstanceOf[BaseType], rule) if (!(newChild fastEquals arg)) { changed = true newChild @@ -252,18 +270,28 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { arg } case Some(arg: TreeNode[_]) if containsChild(arg) => - val newChild = arg.asInstanceOf[BaseType].transformDown(rule) + val newChild = nextOperation(arg.asInstanceOf[BaseType], rule) if (!(newChild fastEquals arg)) { changed = true Some(newChild) } else { Some(arg) } - case m: Map[_, _] => m + case m: Map[_, _] => m.mapValues { + case arg: TreeNode[_] if containsChild(arg) => + val newChild = nextOperation(arg.asInstanceOf[BaseType], rule) + if (!(newChild fastEquals arg)) { + changed = true + newChild + } else { + arg + } + case other => other + }.view.force // `mapValues` is lazy and we need to force it to materialize case d: DataType => d // Avoid unpacking Structs case args: Traversable[_] => args.map { case arg: TreeNode[_] if containsChild(arg) => - val newChild = arg.asInstanceOf[BaseType].transformDown(rule) + val newChild = nextOperation(arg.asInstanceOf[BaseType], rule) if (!(newChild fastEquals arg)) { changed = true newChild @@ -285,7 +313,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { * @param rule the function use to transform this nodes children */ def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = { - val afterRuleOnChildren = transformChildrenUp(rule) + val afterRuleOnChildren = transformChildren(rule, (t, r) => t.transformUp(r)) if (this fastEquals afterRuleOnChildren) { CurrentOrigin.withOrigin(origin) { rule.applyOrElse(this, identity[BaseType]) @@ -297,44 +325,6 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } } - def transformChildrenUp(rule: PartialFunction[BaseType, BaseType]): this.type = { - var changed = false - val newArgs = productIterator.map { - case arg: TreeNode[_] if containsChild(arg) => - val newChild = arg.asInstanceOf[BaseType].transformUp(rule) - if (!(newChild fastEquals arg)) { - changed = true - newChild - } else { - arg - } - case Some(arg: TreeNode[_]) if containsChild(arg) => - val newChild = arg.asInstanceOf[BaseType].transformUp(rule) - if (!(newChild fastEquals arg)) { - changed = true - Some(newChild) - } else { - Some(arg) - } - case m: Map[_, _] => m - case d: DataType => d // Avoid unpacking Structs - case args: Traversable[_] => args.map { - case arg: TreeNode[_] if containsChild(arg) => - val newChild = arg.asInstanceOf[BaseType].transformUp(rule) - if (!(newChild fastEquals arg)) { - changed = true - newChild - } else { - arg - } - case other => other - } - case nonChild: AnyRef => nonChild - case null => null - }.toArray - if (changed) makeCopy(newArgs) else this - } - /** * Args to the constructor that should be copied, but not transformed. * These are appended to the transformed args automatically by makeCopy @@ -348,7 +338,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { * that are not present in the productIterator. 
* @param newArgs the new product arguments. */ - def makeCopy(newArgs: Array[AnyRef]): this.type = attachTree(this, "makeCopy") { + def makeCopy(newArgs: Array[AnyRef]): BaseType = attachTree(this, "makeCopy") { val ctors = getClass.getConstructors.filter(_.getParameterTypes.size != 0) if (ctors.isEmpty) { sys.error(s"No valid constructor for $nodeName") @@ -359,9 +349,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { CurrentOrigin.withOrigin(origin) { // Skip no-arg constructors that are just there for kryo. if (otherCopyArgs.isEmpty) { - defaultCtor.newInstance(newArgs: _*).asInstanceOf[this.type] + defaultCtor.newInstance(newArgs: _*).asInstanceOf[BaseType] } else { - defaultCtor.newInstance((newArgs ++ otherCopyArgs).toArray: _*).asInstanceOf[this.type] + defaultCtor.newInstance((newArgs ++ otherCopyArgs).toArray: _*).asInstanceOf[BaseType] } } } catch { @@ -373,6 +363,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { |Is otherCopyArgs specified correctly for $nodeName. |Exception message: ${e.getMessage} |ctor: $defaultCtor? + |types: ${newArgs.map(_.getClass).mkString(", ")} |args: ${newArgs.mkString(", ")} """.stripMargin) } @@ -389,7 +380,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { /** Returns a string representing the arguments to this node, minus any children */ def argString: String = productIterator.flatMap { case tn: TreeNode[_] if containsChild(tn) => Nil - case tn: TreeNode[_] if tn.toString contains "\n" => s"(${tn.simpleString})" :: Nil + case tn: TreeNode[_] => s"${tn.simpleString}" :: Nil case seq: Seq[BaseType] if seq.toSet.subsetOf(children.toSet) => Nil case seq: Seq[_] => seq.mkString("[", ",", "]") :: Nil case set: Set[_] => set.mkString("{", ",", "}") :: Nil @@ -402,7 +393,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { override def toString: String = treeString /** Returns a string representation of the nodes in this tree */ - def treeString: String = generateTreeString(0, new StringBuilder).toString + def treeString: String = generateTreeString(0, Nil, new StringBuilder).toString /** * Returns a string representation of the nodes in this tree, where each operator is numbered. @@ -428,12 +419,33 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } } - /** Appends the string represent of this node and its children to the given StringBuilder. */ - protected def generateTreeString(depth: Int, builder: StringBuilder): StringBuilder = { - builder.append(" " * depth) + /** + * Appends the string represent of this node and its children to the given StringBuilder. + * + * The `i`-th element in `lastChildren` indicates whether the ancestor of the current node at + * depth `i + 1` is the last child of its own parent node. The depth of the root node is 0, and + * `lastChildren` for the root node should be empty. 
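The `lastChildren` bookkeeping is what drives the new `:-` / `+-` tree layout, as the sketch below shows by reproducing the same shape for a toy plan (prefix widths are chosen here for alignment and are illustrative only):

```scala
object TreeStringSketch {
  case class Node(name: String, children: Seq[Node] = Nil) {
    def treeString: String = generate(0, Nil, new StringBuilder).toString

    // lastChildren(i) records whether the ancestor at depth i + 1 is its parent's last child,
    // so earlier columns print ":" continuations until that ancestor's subtree is finished.
    private def generate(depth: Int, lastChildren: Seq[Boolean], sb: StringBuilder): StringBuilder = {
      if (depth > 0) {
        lastChildren.init.foreach(isLast => sb.append(if (isLast) "   " else ":  "))
        sb.append(if (lastChildren.last) "+- " else ":- ")
      }
      sb.append(name).append("\n")
      if (children.nonEmpty) {
        children.init.foreach(_.generate(depth + 1, lastChildren :+ false, sb))
        children.last.generate(depth + 1, lastChildren :+ true, sb)
      }
      sb
    }
  }

  def main(args: Array[String]): Unit = {
    val plan = Node("Join", Seq(Node("Project", Seq(Node("ScanA"))), Node("ScanB")))
    print(plan.treeString)
    // Join
    // :- Project
    // :  +- ScanA
    // +- ScanB
  }
}
```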
+ */ + protected def generateTreeString( + depth: Int, lastChildren: Seq[Boolean], builder: StringBuilder): StringBuilder = { + if (depth > 0) { + lastChildren.init.foreach { isLast => + val prefixFragment = if (isLast) " " else ": " + builder.append(prefixFragment) + } + + val branch = if (lastChildren.last) "+- " else ":- " + builder.append(branch) + } + builder.append(simpleString) builder.append("\n") - children.foreach(_.generateTreeString(depth + 1, builder)) + + if (children.nonEmpty) { + children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder)) + children.last.generateTreeString(depth + 1, lastChildren :+ true, builder) + } + builder } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/AbstractScalaRowIterator.scala similarity index 96% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/AbstractScalaRowIterator.scala index 1090bdb5a4bd..6d35f140cf23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/AbstractScalaRowIterator.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.catalyst.util /** * Shim to allow us to implement [[scala.Iterator]] in Java. Scala 2.11+ has an AbstractIterator diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala similarity index 77% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala index f6fa021adee9..d85b72ed83de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.types +package org.apache.spark.sql.catalyst.util class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) extends MapData { require(keyArray.numElements() == valueArray.numElements()) @@ -42,12 +42,17 @@ class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) exte ArrayBasedMapData.toScalaMap(this).hashCode() } - override def toString(): String = { + override def toString: String = { s"keys: $keyArray, values: $valueArray" } } object ArrayBasedMapData { + def apply(map: Map[Any, Any]): ArrayBasedMapData = { + val array = map.toArray + ArrayBasedMapData(array.map(_._1), array.map(_._2)) + } + def apply(keys: Array[Any], values: Array[Any]): ArrayBasedMapData = { new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values)) } @@ -57,4 +62,17 @@ object ArrayBasedMapData { val values = map.valueArray.asInstanceOf[GenericArrayData].array keys.zip(values).toMap } + + def toScalaMap(keys: Array[Any], values: Array[Any]): Map[Any, Any] = { + keys.zip(values).toMap + } + + def toScalaMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = { + keys.zip(values).toMap + } + + def toJavaMap(keys: Array[Any], values: Array[Any]): java.util.Map[Any, Any] = { + import scala.collection.JavaConverters._ + keys.zip(values).toMap.asJava + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala similarity index 93% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala index 642c56f12ded..cad4a08b0d83 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala @@ -15,17 +15,20 @@ * limitations under the License. */ -package org.apache.spark.sql.types +package org.apache.spark.sql.catalyst.util import scala.reflect.ClassTag import org.apache.spark.sql.catalyst.expressions.SpecializedGetters +import org.apache.spark.sql.types.DataType abstract class ArrayData extends SpecializedGetters with Serializable { def numElements(): Int def copy(): ArrayData + def array: Array[Any] + def toBooleanArray(): Array[Boolean] = { val size = numElements() val values = new Array[Boolean](size) @@ -103,6 +106,9 @@ abstract class ArrayData extends SpecializedGetters with Serializable { values } + def toObjectArray(elementType: DataType): Array[AnyRef] = + toArray[AnyRef](elementType: DataType) + def toArray[T: ClassTag](elementType: DataType): Array[T] = { val size = numElements() val values = new Array[T](size) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeParser.scala similarity index 95% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeParser.scala index 6e081ea9237b..515c071c283b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeParser.scala @@ -15,13 +15,14 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.types +package org.apache.spark.sql.catalyst.util import scala.language.implicitConversions import scala.util.matching.Regex import scala.util.parsing.combinator.syntactical.StandardTokenParsers import org.apache.spark.sql.catalyst.SqlLexical +import org.apache.spark.sql.types._ /** * This is a data type parser that can be used to parse string representations of data types @@ -51,7 +52,8 @@ private[sql] trait DataTypeParser extends StandardTokenParsers { "(?i)decimal".r ^^^ DecimalType.USER_DEFAULT | "(?i)date".r ^^^ DateType | "(?i)timestamp".r ^^^ TimestampType | - varchar + varchar | + char protected lazy val fixedDecimalType: Parser[DataType] = ("(?i)decimal".r ~> "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ { @@ -59,6 +61,9 @@ private[sql] trait DataTypeParser extends StandardTokenParsers { DecimalType(precision.toInt, scale.toInt) } + protected lazy val char: Parser[DataType] = + "(?i)char".r ~> "(" ~> (numericLit <~ ")") ^^^ StringType + protected lazy val varchar: Parser[DataType] = "(?i)varchar".r ~> "(" ~> (numericLit <~ ")") ^^^ StringType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index f645eb5f7bb0..2b9388291948 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.util import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} import java.util.{TimeZone, Calendar} +import javax.xml.bind.DatatypeConverter import org.apache.spark.unsafe.types.UTF8String @@ -31,11 +32,18 @@ import org.apache.spark.unsafe.types.UTF8String * precision. */ object DateTimeUtils { + + // we use Int and Long internally to represent [[DateType]] and [[TimestampType]] + type SQLDate = Int + type SQLTimestamp = Long + // see http://stackoverflow.com/questions/466321/convert-unix-timestamp-to-julian - final val JULIAN_DAY_OF_EPOCH = 2440587 // and .5 + // it's 2440587.5, rounding up to compatible with Hive + final val JULIAN_DAY_OF_EPOCH = 2440588 final val SECONDS_PER_DAY = 60 * 60 * 24L final val MICROS_PER_SECOND = 1000L * 1000L final val NANOS_PER_SECOND = MICROS_PER_SECOND * 1000L + final val MICROS_PER_DAY = MICROS_PER_SECOND * SECONDS_PER_DAY final val MILLIS_PER_DAY = SECONDS_PER_DAY * 1000L @@ -72,7 +80,7 @@ object DateTimeUtils { } // we should use the exact day as Int, for example, (year, month, day) -> day - def millisToDays(millisUtc: Long): Int = { + def millisToDays(millisUtc: Long): SQLDate = { // SPARK-6785: use Math.floor so negative number of days (dates before 1970) // will correctly work as input for function toJavaDate(Int) val millisLocal = millisUtc + threadLocalLocalTimeZone.get().getOffset(millisUtc) @@ -80,16 +88,16 @@ object DateTimeUtils { } // reverse of millisToDays - def daysToMillis(days: Int): Long = { + def daysToMillis(days: SQLDate): Long = { val millisUtc = days.toLong * MILLIS_PER_DAY millisUtc - threadLocalLocalTimeZone.get().getOffset(millisUtc) } - def dateToString(days: Int): String = + def dateToString(days: SQLDate): String = threadLocalDateFormat.get.format(toJavaDate(days)) // Converts Timestamp to string according to Hive TimestampWritable convention. 
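As a quick sanity check on the SQLDate representation introduced above (a plain day count since 1970-01-01), the millisecond round trip behaves as follows, assuming the JVM default time zone is UTC:

    import org.apache.spark.sql.catalyst.util.DateTimeUtils

    val day: DateTimeUtils.SQLDate = DateTimeUtils.millisToDays(86400000L)  // 1, i.e. 1970-01-02
    DateTimeUtils.daysToMillis(day)                                         // 86400000L again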
- def timestampToString(us: Long): String = { + def timestampToString(us: SQLTimestamp): String = { val ts = toJavaTimestamp(us) val timestampString = ts.toString val formatted = threadLocalTimestampFormat.get.format(ts) @@ -102,51 +110,43 @@ object DateTimeUtils { } def stringToTime(s: String): java.util.Date = { - if (!s.contains('T')) { + val indexOfGMT = s.indexOf("GMT") + if (indexOfGMT != -1) { + // ISO8601 with a weird time zone specifier (2000-01-01T00:00GMT+01:00) + val s0 = s.substring(0, indexOfGMT) + val s1 = s.substring(indexOfGMT + 3) + // Mapped to 2000-01-01T00:00+01:00 + stringToTime(s0 + s1) + } else if (!s.contains('T')) { // JDBC escape string if (s.contains(' ')) { Timestamp.valueOf(s) } else { Date.valueOf(s) } - } else if (s.endsWith("Z")) { - // this is zero timezone of ISO8601 - stringToTime(s.substring(0, s.length - 1) + "GMT-00:00") - } else if (s.indexOf("GMT") == -1) { - // timezone with ISO8601 - val inset = "+00.00".length - val s0 = s.substring(0, s.length - inset) - val s1 = s.substring(s.length - inset, s.length) - if (s0.substring(s0.lastIndexOf(':')).contains('.')) { - stringToTime(s0 + "GMT" + s1) - } else { - stringToTime(s0 + ".0GMT" + s1) - } } else { - // ISO8601 with GMT insert - val ISO8601GMT: SimpleDateFormat = new SimpleDateFormat( "yyyy-MM-dd'T'HH:mm:ss.SSSz" ) - ISO8601GMT.parse(s) + DatatypeConverter.parseDateTime(s).getTime() } } /** * Returns the number of days since epoch from from java.sql.Date. */ - def fromJavaDate(date: Date): Int = { + def fromJavaDate(date: Date): SQLDate = { millisToDays(date.getTime) } /** * Returns a java.sql.Date from number of days since epoch. */ - def toJavaDate(daysSinceEpoch: Int): Date = { + def toJavaDate(daysSinceEpoch: SQLDate): Date = { new Date(daysToMillis(daysSinceEpoch)) } /** * Returns a java.sql.Timestamp from number of micros since epoch. */ - def toJavaTimestamp(us: Long): Timestamp = { + def toJavaTimestamp(us: SQLTimestamp): Timestamp = { // setNanos() will overwrite the millisecond part, so the milliseconds should be // cut off at seconds var seconds = us / MICROS_PER_SECOND @@ -164,7 +164,7 @@ object DateTimeUtils { /** * Returns the number of micros since epoch from java.sql.Timestamp. */ - def fromJavaTimestamp(t: Timestamp): Long = { + def fromJavaTimestamp(t: Timestamp): SQLTimestamp = { if (t != null) { t.getTime() * 1000L + (t.getNanos().toLong / 1000) % 1000L } else { @@ -176,21 +176,22 @@ object DateTimeUtils { * Returns the number of microseconds since epoch from Julian day * and nanoseconds in a day */ - def fromJulianDay(day: Int, nanoseconds: Long): Long = { + def fromJulianDay(day: Int, nanoseconds: Long): SQLTimestamp = { // use Long to avoid rounding errors - val seconds = (day - JULIAN_DAY_OF_EPOCH).toLong * SECONDS_PER_DAY - SECONDS_PER_DAY / 2 + val seconds = (day - JULIAN_DAY_OF_EPOCH).toLong * SECONDS_PER_DAY seconds * MICROS_PER_SECOND + nanoseconds / 1000L } /** * Returns Julian day and nanoseconds in a day from the number of microseconds + * + * Note: support timestamp since 4717 BC (without negative nanoseconds, compatible with Hive). 
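With JULIAN_DAY_OF_EPOCH rounded up to 2440588 and the half-day offset dropped from fromJulianDay, the epoch round-trips cleanly through toJulianDay (rewritten just below). A small sketch:

    import org.apache.spark.sql.catalyst.util.DateTimeUtils

    DateTimeUtils.toJulianDay(0L)             // (2440588, 0): midnight 1970-01-01
    DateTimeUtils.fromJulianDay(2440588, 0L)  // 0L microseconds since the epoch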
*/ - def toJulianDay(us: Long): (Int, Long) = { - val seconds = us / MICROS_PER_SECOND + SECONDS_PER_DAY / 2 - val day = seconds / SECONDS_PER_DAY + JULIAN_DAY_OF_EPOCH - val secondsInDay = seconds % SECONDS_PER_DAY - val nanos = (us % MICROS_PER_SECOND) * 1000L - (day.toInt, secondsInDay * NANOS_PER_SECOND + nanos) + def toJulianDay(us: SQLTimestamp): (Int, Long) = { + val julian_us = us + JULIAN_DAY_OF_EPOCH * MICROS_PER_DAY + val day = julian_us / MICROS_PER_DAY + val micros = julian_us % MICROS_PER_DAY + (day.toInt, micros * 1000L) } /** @@ -219,7 +220,7 @@ object DateTimeUtils { * `T[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]-[h]h:[m]m` * `T[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]+[h]h:[m]m` */ - def stringToTimestamp(s: UTF8String): Option[Long] = { + def stringToTimestamp(s: UTF8String): Option[SQLTimestamp] = { if (s == null) { return None } @@ -240,6 +241,10 @@ object DateTimeUtils { i += 3 } else if (i < 2) { if (b == '-') { + if (i == 0 && j != 4) { + // year should have exact four digits + return None + } segments(i) = currentSegmentValue currentSegmentValue = 0 i += 1 @@ -307,17 +312,26 @@ object DateTimeUtils { } segments(i) = currentSegmentValue + if (!justTime && i == 0 && j != 4) { + // year should have exact four digits + return None + } while (digitsMilli < 6) { segments(6) *= 10 digitsMilli += 1 } - if (!justTime && (segments(0) < 1000 || segments(0) > 9999 || segments(1) < 1 || + if (!justTime && (segments(0) < 0 || segments(0) > 9999 || segments(1) < 1 || segments(1) > 12 || segments(2) < 1 || segments(2) > 31)) { return None } + // Instead of return None, we truncate the fractional seconds to prevent inserting NULL + if (segments(6) > 999999) { + segments(6) = segments(6).toString.take(6).toInt + } + if (segments(3) < 0 || segments(3) > 23 || segments(4) < 0 || segments(4) > 59 || segments(5) < 0 || segments(5) > 59 || segments(6) < 0 || segments(6) > 999999 || segments(7) < 0 || segments(7) > 23 || segments(8) < 0 || segments(8) > 59) { @@ -355,7 +369,7 @@ object DateTimeUtils { * `yyyy-[m]m-[d]d *` * `yyyy-[m]m-[d]dT*` */ - def stringToDate(s: UTF8String): Option[Int] = { + def stringToDate(s: UTF8String): Option[SQLDate] = { if (s == null) { return None } @@ -367,6 +381,10 @@ object DateTimeUtils { while (j < bytes.length && (i < 3 && !(bytes(j) == ' ' || bytes(j) == 'T'))) { val b = bytes(j) if (i < 2 && b == '-') { + if (i == 0 && j != 4) { + // year should have exact four digits + return None + } segments(i) = currentSegmentValue currentSegmentValue = 0 i += 1 @@ -380,8 +398,12 @@ object DateTimeUtils { } j += 1 } + if (i == 0 && j != 4) { + // year should have exact four digits + return None + } segments(i) = currentSegmentValue - if (segments(0) < 1000 || segments(0) > 9999 || segments(1) < 1 || segments(1) > 12 || + if (segments(0) < 0 || segments(0) > 9999 || segments(1) < 1 || segments(1) > 12 || segments(2) < 1 || segments(2) > 31) { return None } @@ -391,29 +413,38 @@ object DateTimeUtils { Some((c.getTimeInMillis / MILLIS_PER_DAY).toInt) } + /** + * Returns the microseconds since year zero (-17999) from microseconds since epoch. + */ + private def absoluteMicroSecond(microsec: SQLTimestamp): SQLTimestamp = { + microsec + toYearZero * MICROS_PER_DAY + } + + private def localTimestamp(microsec: SQLTimestamp): SQLTimestamp = { + absoluteMicroSecond(microsec) + defaultTimeZone.getOffset(microsec / 1000) * 1000L + } + /** * Returns the hour value of a given timestamp value. The timestamp is expressed in microseconds. 
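The new four-digit-year guard changes what the string parsers accept. Illustrative calls (the exact day number assumes a UTC default time zone):

    import org.apache.spark.sql.catalyst.util.DateTimeUtils
    import org.apache.spark.unsafe.types.UTF8String

    DateTimeUtils.stringToDate(UTF8String.fromString("2015-08-01"))  // Some(16648), days since 1970-01-01
    DateTimeUtils.stringToDate(UTF8String.fromString("15-08-01"))    // None: the year must have four digits
    DateTimeUtils.stringToDate(UTF8String.fromString("0001-01-01"))  // Some(...): years below 1000 now parse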
*/ - def getHours(timestamp: Long): Int = { - val localTs = (timestamp / 1000) + defaultTimeZone.getOffset(timestamp / 1000) - ((localTs / 1000 / 3600) % 24).toInt + def getHours(microsec: SQLTimestamp): Int = { + ((localTimestamp(microsec) / MICROS_PER_SECOND / 3600) % 24).toInt } /** * Returns the minute value of a given timestamp value. The timestamp is expressed in * microseconds. */ - def getMinutes(timestamp: Long): Int = { - val localTs = (timestamp / 1000) + defaultTimeZone.getOffset(timestamp / 1000) - ((localTs / 1000 / 60) % 60).toInt + def getMinutes(microsec: SQLTimestamp): Int = { + ((localTimestamp(microsec) / MICROS_PER_SECOND / 60) % 60).toInt } /** * Returns the second value of a given timestamp value. The timestamp is expressed in * microseconds. */ - def getSeconds(timestamp: Long): Int = { - ((timestamp / 1000 / 1000) % 60).toInt + def getSeconds(microsec: SQLTimestamp): Int = { + ((localTimestamp(microsec) / MICROS_PER_SECOND) % 60).toInt } private[this] def isLeapYear(year: Int): Boolean = { @@ -447,7 +478,7 @@ object DateTimeUtils { * The calculation uses the fact that the period 1.1.2001 until 31.12.2400 is * equals to the period 1.1.1601 until 31.12.2000. */ - private[this] def getYearAndDayInYear(daysSince1970: Int): (Int, Int) = { + private[this] def getYearAndDayInYear(daysSince1970: SQLDate): (Int, Int) = { // add the difference (in days) between 1.1.1970 and the artificial year 0 (-17999) val daysNormalized = daysSince1970 + toYearZero val numOfQuarterCenturies = daysNormalized / daysIn400Years @@ -461,7 +492,7 @@ object DateTimeUtils { * Returns the 'day in year' value for the given date. The date is expressed in days * since 1.1.1970. */ - def getDayInYear(date: Int): Int = { + def getDayInYear(date: SQLDate): Int = { getYearAndDayInYear(date)._2 } @@ -469,7 +500,7 @@ object DateTimeUtils { * Returns the year value for the given date. The date is expressed in days * since 1.1.1970. */ - def getYear(date: Int): Int = { + def getYear(date: SQLDate): Int = { getYearAndDayInYear(date)._1 } @@ -477,7 +508,7 @@ object DateTimeUtils { * Returns the quarter for the given date. The date is expressed in days * since 1.1.1970. */ - def getQuarter(date: Int): Int = { + def getQuarter(date: SQLDate): Int = { var (year, dayInYear) = getYearAndDayInYear(date) if (isLeapYear(year)) { dayInYear = dayInYear - 1 @@ -493,11 +524,55 @@ object DateTimeUtils { } } + /** + * Split date (expressed in days since 1.1.1970) into four fields: + * year, month (Jan is Month 1), dayInMonth, daysToMonthEnd (0 if it's last day of month). 
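Since getHours, getMinutes and getSeconds now all go through localTimestamp, they report wall-clock fields in the default time zone. A sketch assuming that zone is UTC:

    import org.apache.spark.sql.catalyst.util.DateTimeUtils

    val micros = 1438432496L * 1000000L   // 2015-08-01 12:34:56 UTC, in microseconds since the epoch
    DateTimeUtils.getHours(micros)        // 12
    DateTimeUtils.getMinutes(micros)      // 34
    DateTimeUtils.getSeconds(micros)      // 56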
+ */ + def splitDate(date: SQLDate): (Int, Int, Int, Int) = { + var (year, dayInYear) = getYearAndDayInYear(date) + val isLeap = isLeapYear(year) + if (isLeap && dayInYear == 60) { + (year, 2, 29, 0) + } else { + if (isLeap && dayInYear > 60) dayInYear -= 1 + + if (dayInYear <= 181) { + if (dayInYear <= 31) { + (year, 1, dayInYear, 31 - dayInYear) + } else if (dayInYear <= 59) { + (year, 2, dayInYear - 31, if (isLeap) 60 - dayInYear else 59 - dayInYear) + } else if (dayInYear <= 90) { + (year, 3, dayInYear - 59, 90 - dayInYear) + } else if (dayInYear <= 120) { + (year, 4, dayInYear - 90, 120 - dayInYear) + } else if (dayInYear <= 151) { + (year, 5, dayInYear - 120, 151 - dayInYear) + } else { + (year, 6, dayInYear - 151, 181 - dayInYear) + } + } else { + if (dayInYear <= 212) { + (year, 7, dayInYear - 181, 212 - dayInYear) + } else if (dayInYear <= 243) { + (year, 8, dayInYear - 212, 243 - dayInYear) + } else if (dayInYear <= 273) { + (year, 9, dayInYear - 243, 273 - dayInYear) + } else if (dayInYear <= 304) { + (year, 10, dayInYear - 273, 304 - dayInYear) + } else if (dayInYear <= 334) { + (year, 11, dayInYear - 304, 334 - dayInYear) + } else { + (year, 12, dayInYear - 334, 365 - dayInYear) + } + } + } + } + /** * Returns the month value for the given date. The date is expressed in days * since 1.1.1970. January is month 1. */ - def getMonth(date: Int): Int = { + def getMonth(date: SQLDate): Int = { var (year, dayInYear) = getYearAndDayInYear(date) if (isLeapYear(year)) { if (dayInYear == 60) { @@ -538,7 +613,7 @@ object DateTimeUtils { * Returns the 'day of month' value for the given date. The date is expressed in days * since 1.1.1970. */ - def getDayOfMonth(date: Int): Int = { + def getDayOfMonth(date: SQLDate): Int = { var (year, dayInYear) = getYearAndDayInYear(date) if (isLeapYear(year)) { if (dayInYear == 60) { @@ -584,7 +659,7 @@ object DateTimeUtils { * Returns the date value for the first day of the given month. * The month is expressed in months since year zero (17999 BC), starting from 0. */ - private def firstDayOfMonth(absoluteMonth: Int): Int = { + private def firstDayOfMonth(absoluteMonth: Int): SQLDate = { val absoluteYear = absoluteMonth / 12 var monthInYear = absoluteMonth - absoluteYear * 12 var date = getDateFromYear(absoluteYear) @@ -602,7 +677,7 @@ object DateTimeUtils { * Returns the date value for January 1 of the given year. * The year is expressed in years since year zero (17999 BC), starting from 0. */ - private def getDateFromYear(absoluteYear: Int): Int = { + private def getDateFromYear(absoluteYear: Int): SQLDate = { val absoluteDays = (absoluteYear * 365 + absoluteYear / 400 - absoluteYear / 100 + absoluteYear / 4) absoluteDays - toYearZero @@ -612,16 +687,17 @@ object DateTimeUtils { * Add date and year-month interval. * Returns a date value, expressed in days since 1.1.1970. 
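splitDate folds what previously took separate getYear/getMonth/getDayOfMonth calls into one pass, and adds the days-to-month-end count. For example:

    import org.apache.spark.sql.catalyst.util.DateTimeUtils

    DateTimeUtils.splitDate(16648)  // (2015, 8, 1, 30): 2015-08-01 with 30 days left in August
    DateTimeUtils.splitDate(0)      // (1970, 1, 1, 30): the epoch day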
*/ - def dateAddMonths(days: Int, months: Int): Int = { - val absoluteMonth = (getYear(days) - YearZero) * 12 + getMonth(days) - 1 + months + def dateAddMonths(days: SQLDate, months: Int): SQLDate = { + val (year, monthInYear, dayOfMonth, daysToMonthEnd) = splitDate(days) + val absoluteMonth = (year - YearZero) * 12 + monthInYear - 1 + months val nonNegativeMonth = if (absoluteMonth >= 0) absoluteMonth else 0 val currentMonthInYear = nonNegativeMonth % 12 val currentYear = nonNegativeMonth / 12 + val leapDay = if (currentMonthInYear == 1 && isLeapYear(currentYear + YearZero)) 1 else 0 val lastDayOfMonth = monthDays(currentMonthInYear) + leapDay - val dayOfMonth = getDayOfMonth(days) - val currentDayInMonth = if (getDayOfMonth(days + 1) == 1 || dayOfMonth >= lastDayOfMonth) { + val currentDayInMonth = if (daysToMonthEnd == 0 || dayOfMonth >= lastDayOfMonth) { // last day of the month lastDayOfMonth } else { @@ -634,52 +710,12 @@ object DateTimeUtils { * Add timestamp and full interval. * Returns a timestamp value, expressed in microseconds since 1.1.1970 00:00:00. */ - def timestampAddInterval(start: Long, months: Int, microseconds: Long): Long = { + def timestampAddInterval(start: SQLTimestamp, months: Int, microseconds: Long): SQLTimestamp = { val days = millisToDays(start / 1000L) val newDays = dateAddMonths(days, months) daysToMillis(newDays) * 1000L + start - daysToMillis(days) * 1000L + microseconds } - /** - * Returns the last dayInMonth in the month it belongs to. The date is expressed - * in days since 1.1.1970. the return value starts from 1. - */ - private def getLastDayInMonthOfMonth(date: Int): Int = { - var (year, dayInYear) = getYearAndDayInYear(date) - if (isLeapYear(year)) { - if (dayInYear > 31 && dayInYear <= 60) { - return 29 - } else if (dayInYear > 60) { - dayInYear = dayInYear - 1 - } - } - if (dayInYear <= 31) { - 31 - } else if (dayInYear <= 59) { - 28 - } else if (dayInYear <= 90) { - 31 - } else if (dayInYear <= 120) { - 30 - } else if (dayInYear <= 151) { - 31 - } else if (dayInYear <= 181) { - 30 - } else if (dayInYear <= 212) { - 31 - } else if (dayInYear <= 243) { - 31 - } else if (dayInYear <= 273) { - 30 - } else if (dayInYear <= 304) { - 31 - } else if (dayInYear <= 334) { - 30 - } else { - 31 - } - } - /** * Returns number of months between time1 and time2. time1 and time2 are expressed in * microseconds since 1.1.1970. @@ -690,19 +726,18 @@ object DateTimeUtils { * Otherwise, the difference is calculated based on 31 days per month, and rounding to * 8 digits. 
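The daysToMonthEnd field returned by splitDate is what lets dateAddMonths clamp to month ends without the extra day-of-month lookups it used before. For instance:

    import org.apache.spark.sql.catalyst.util.DateTimeUtils

    // 16466 = 2015-01-31; adding one month clamps to the last day of February
    DateTimeUtils.dateAddMonths(16466, 1)   // 16494 = 2015-02-28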
*/ - def monthsBetween(time1: Long, time2: Long): Double = { + def monthsBetween(time1: SQLTimestamp, time2: SQLTimestamp): Double = { val millis1 = time1 / 1000L val millis2 = time2 / 1000L val date1 = millisToDays(millis1) val date2 = millisToDays(millis2) - // TODO(davies): get year, month, dayOfMonth from single function - val dayInMonth1 = getDayOfMonth(date1) - val dayInMonth2 = getDayOfMonth(date2) - val months1 = getYear(date1) * 12 + getMonth(date1) - val months2 = getYear(date2) * 12 + getMonth(date2) - - if (dayInMonth1 == dayInMonth2 || (dayInMonth1 == getLastDayInMonthOfMonth(date1) - && dayInMonth2 == getLastDayInMonthOfMonth(date2))) { + val (year1, monthInYear1, dayInMonth1, daysToMonthEnd1) = splitDate(date1) + val (year2, monthInYear2, dayInMonth2, daysToMonthEnd2) = splitDate(date2) + + val months1 = year1 * 12 + monthInYear1 + val months2 = year2 * 12 + monthInYear2 + + if (dayInMonth1 == dayInMonth2 || ((daysToMonthEnd1 == 0) && (daysToMonthEnd2 == 0))) { return (months1 - months2).toDouble } // milliseconds is enough for 8 digits precision on the right side @@ -736,7 +771,7 @@ object DateTimeUtils { * Returns the first date which is later than startDate and is of the given dayOfWeek. * dayOfWeek is an integer ranges in [0, 6], and 0 is Thu, 1 is Fri, etc,. */ - def getNextDateForDayOfWeek(startDate: Int, dayOfWeek: Int): Int = { + def getNextDateForDayOfWeek(startDate: SQLDate, dayOfWeek: Int): SQLDate = { startDate + 1 + ((dayOfWeek - 1 - startDate) % 7 + 7) % 7 } @@ -744,41 +779,9 @@ object DateTimeUtils { * Returns last day of the month for the given date. The date is expressed in days * since 1.1.1970. */ - def getLastDayOfMonth(date: Int): Int = { - var (year, dayInYear) = getYearAndDayInYear(date) - if (isLeapYear(year)) { - if (dayInYear > 31 && dayInYear <= 60) { - return date + (60 - dayInYear) - } else if (dayInYear > 60) { - dayInYear = dayInYear - 1 - } - } - val lastDayOfMonthInYear = if (dayInYear <= 31) { - 31 - } else if (dayInYear <= 59) { - 59 - } else if (dayInYear <= 90) { - 90 - } else if (dayInYear <= 120) { - 120 - } else if (dayInYear <= 151) { - 151 - } else if (dayInYear <= 181) { - 181 - } else if (dayInYear <= 212) { - 212 - } else if (dayInYear <= 243) { - 243 - } else if (dayInYear <= 273) { - 273 - } else if (dayInYear <= 304) { - 304 - } else if (dayInYear <= 334) { - 334 - } else { - 365 - } - date + (lastDayOfMonthInYear - dayInYear) + def getLastDayOfMonth(date: SQLDate): SQLDate = { + val (_, _, _, daysToMonthEnd) = splitDate(date) + date + daysToMonthEnd } private val TRUNC_TO_YEAR = 1 @@ -789,7 +792,7 @@ object DateTimeUtils { * Returns the trunc date from original date and trunc level. * Trunc level should be generated using `parseTruncLevel()`, should only be 1 or 2. */ - def truncDate(d: Int, level: Int): Int = { + def truncDate(d: SQLDate, level: Int): SQLDate = { if (level == TRUNC_TO_YEAR) { d - DateTimeUtils.getDayInYear(d) + 1 } else if (level == TRUNC_TO_MONTH) { @@ -820,7 +823,7 @@ object DateTimeUtils { * Returns a timestamp of given timezone from utc timestamp, with the same string * representation in their timezone. */ - def fromUTCTime(time: Long, timeZone: String): Long = { + def fromUTCTime(time: SQLTimestamp, timeZone: String): SQLTimestamp = { val tz = TimeZone.getTimeZone(timeZone) val offset = tz.getOffset(time / 1000L) time + offset * 1000L @@ -830,7 +833,7 @@ object DateTimeUtils { * Returns a utc timestamp from a given timestamp from a given timezone, with the same * string representation in their timezone. 
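Likewise, getLastDayOfMonth above reduces to a single splitDate call. A quick example:

    import org.apache.spark.sql.catalyst.util.DateTimeUtils

    DateTimeUtils.getLastDayOfMonth(16648)   // 16678 = 2015-08-31, given 16648 = 2015-08-01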
*/ - def toUTCTime(time: Long, timeZone: String): Long = { + def toUTCTime(time: SQLTimestamp, timeZone: String): SQLTimestamp = { val tz = TimeZone.getTimeZone(timeZone) val offset = tz.getOffset(time / 1000L) time - offset * 1000L diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala similarity index 56% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala index b314acdfe364..2b8cdc1e23ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala @@ -15,24 +15,50 @@ * limitations under the License. */ -package org.apache.spark.sql.types +package org.apache.spark.sql.catalyst.util -import scala.reflect.ClassTag +import scala.collection.JavaConverters._ -import org.apache.spark.sql.catalyst.expressions.GenericSpecializedGetters +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types.{DataType, Decimal} +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} -class GenericArrayData(private[sql] val array: Array[Any]) - extends ArrayData with GenericSpecializedGetters { +class GenericArrayData(val array: Array[Any]) extends ArrayData { - override def genericGet(ordinal: Int): Any = array(ordinal) + def this(seq: Seq[Any]) = this(seq.toArray) + def this(list: java.util.List[Any]) = this(list.asScala) - override def copy(): ArrayData = new GenericArrayData(array.clone()) + // TODO: This is boxing. We should specialize. + def this(primitiveArray: Array[Int]) = this(primitiveArray.toSeq) + def this(primitiveArray: Array[Long]) = this(primitiveArray.toSeq) + def this(primitiveArray: Array[Float]) = this(primitiveArray.toSeq) + def this(primitiveArray: Array[Double]) = this(primitiveArray.toSeq) + def this(primitiveArray: Array[Short]) = this(primitiveArray.toSeq) + def this(primitiveArray: Array[Byte]) = this(primitiveArray.toSeq) + def this(primitiveArray: Array[Boolean]) = this(primitiveArray.toSeq) - // todo: Array is invariant in scala, maybe use toSeq instead? 
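The extra GenericArrayData constructors above accept Scala seqs, Java lists and primitive arrays directly (with boxing, per the TODO). A minimal sketch:

    import org.apache.spark.sql.catalyst.util.GenericArrayData

    val data = new GenericArrayData(Array(1, 2, 3))  // Array[Int] goes through the primitive overload
    data.numElements()                               // 3
    data.getInt(0)                                   // 1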
- override def toArray[T: ClassTag](elementType: DataType): Array[T] = array.map(_.asInstanceOf[T]) + override def copy(): ArrayData = new GenericArrayData(array.clone()) override def numElements(): Int = array.length + private def getAs[T](ordinal: Int) = array(ordinal).asInstanceOf[T] + override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null + override def get(ordinal: Int, elementType: DataType): AnyRef = getAs(ordinal) + override def getBoolean(ordinal: Int): Boolean = getAs(ordinal) + override def getByte(ordinal: Int): Byte = getAs(ordinal) + override def getShort(ordinal: Int): Short = getAs(ordinal) + override def getInt(ordinal: Int): Int = getAs(ordinal) + override def getLong(ordinal: Int): Long = getAs(ordinal) + override def getFloat(ordinal: Int): Float = getAs(ordinal) + override def getDouble(ordinal: Int): Double = getAs(ordinal) + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal) + override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal) + override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal) + override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal) + override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) + override def getArray(ordinal: Int): ArrayData = getAs(ordinal) + override def getMap(ordinal: Int): MapData = getAs(ordinal) + override def toString(): String = array.mkString("[", ",", "]") override def equals(o: Any): Boolean = { @@ -56,8 +82,8 @@ class GenericArrayData(private[sql] val array: Array[Any]) return false } if (!isNullAt(i)) { - val o1 = genericGet(i) - val o2 = other.genericGet(i) + val o1 = array(i) + val o2 = other.array(i) o1 match { case b1: Array[Byte] => if (!o2.isInstanceOf[Array[Byte]] || @@ -91,7 +117,7 @@ class GenericArrayData(private[sql] val array: Array[Any]) if (isNullAt(i)) { 0 } else { - genericGet(i) match { + array(i) match { case b: Boolean => if (b) 0 else 1 case b: Byte => b.toInt case s: Short => s.toInt diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MapData.scala similarity index 93% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapData.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MapData.scala index f50969f0f0b7..40db6067adf7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MapData.scala @@ -15,7 +15,9 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.types +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.sql.types.DataType abstract class MapData extends Serializable { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala index 9ddfb3a0d375..c2eeb3c5650a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.util import java.util.regex.Pattern +import org.apache.spark.unsafe.types.UTF8String + object StringUtils { // replace the _ with .{1} exactly match 1 time of any character @@ -44,4 +46,10 @@ object StringUtils { v } } + + private[this] val trueStrings = Set("t", "true", "y", "yes", "1").map(UTF8String.fromString) + private[this] val falseStrings = Set("f", "false", "n", "no", "0").map(UTF8String.fromString) + + def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.toLowerCase) + def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.toLowerCase) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 0b41f92c6193..f603cbfb0cc2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -54,10 +54,11 @@ object TypeUtils { def getNumeric(t: DataType): Numeric[Any] = t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]] - def getOrdering(t: DataType): Ordering[Any] = { + def getInterpretedOrdering(t: DataType): Ordering[Any] = { t match { case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] - case s: StructType => s.ordering.asInstanceOf[Ordering[Any]] + case a: ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]] + case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index e0667c629486..a5ae8bb0e5eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -84,6 +84,7 @@ private[sql] object TypeCollection { * Types that can be ordered/compared. In the long run we should probably make this a trait * that can be mixed into each data type, and perhaps create an [[AbstractDataType]]. */ + // TODO: Should we consolidate this with RowOrdering.isOrderable? 
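The new StringUtils predicates do a case-insensitive lookup against the fixed true/false vocabularies above, e.g.:

    import org.apache.spark.sql.catalyst.util.StringUtils
    import org.apache.spark.unsafe.types.UTF8String

    StringUtils.isTrueString(UTF8String.fromString("YES"))    // true
    StringUtils.isFalseString(UTF8String.fromString("0"))     // true
    StringUtils.isTrueString(UTF8String.fromString("maybe"))  // false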
val Ordered = TypeCollection( BooleanType, ByteType, ShortType, IntegerType, LongType, @@ -126,7 +127,7 @@ protected[sql] object AnyDataType extends AbstractDataType { */ protected[sql] abstract class AtomicType extends DataType { private[sql] type InternalType - @transient private[sql] val tag: TypeTag[InternalType] + private[sql] val tag: TypeTag[InternalType] private[sql] val ordering: Ordering[InternalType] @transient private[sql] val classTag = ScalaReflectionLock.synchronized { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index 5094058164b2..a001eadcc61d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql.types +import org.apache.spark.sql.catalyst.util.ArrayData import org.json4s.JsonDSL._ import org.apache.spark.annotation.DeveloperApi +import scala.math.Ordering + object ArrayType extends AbstractDataType { /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */ @@ -75,6 +78,55 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT override def simpleString: String = s"array<${elementType.simpleString}>" - private[spark] override def asNullable: ArrayType = + override private[spark] def asNullable: ArrayType = ArrayType(elementType.asNullable, containsNull = true) + + override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = { + f(this) || elementType.existsRecursively(f) + } + + @transient + private[sql] lazy val interpretedOrdering: Ordering[ArrayData] = new Ordering[ArrayData] { + private[this] val elementOrdering: Ordering[Any] = elementType match { + case dt: AtomicType => dt.ordering.asInstanceOf[Ordering[Any]] + case a : ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]] + case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]] + case other => + throw new IllegalArgumentException(s"Type $other does not support ordered operations") + } + + def compare(x: ArrayData, y: ArrayData): Int = { + val leftArray = x + val rightArray = y + val minLength = scala.math.min(leftArray.numElements(), rightArray.numElements()) + var i = 0 + while (i < minLength) { + val isNullLeft = leftArray.isNullAt(i) + val isNullRight = rightArray.isNullAt(i) + if (isNullLeft && isNullRight) { + // Do nothing. + } else if (isNullLeft) { + return -1 + } else if (isNullRight) { + return 1 + } else { + val comp = + elementOrdering.compare( + leftArray.get(i, elementType), + rightArray.get(i, elementType)) + if (comp != 0) { + return comp + } + } + i += 1 + } + if (leftArray.numElements() < rightArray.numElements()) { + return -1 + } else if (leftArray.numElements() > rightArray.numElements()) { + return 1 + } else { + return 0 + } + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index f4428c2e8b20..4b54c31dcc27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -51,7 +51,9 @@ abstract class DataType extends AbstractDataType { def defaultSize: Int /** Name of the type used in JSON serialization. 
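The interpreted ordering for arrays compares element by element, sorts nulls first, and falls back to array length. A sketch (interpretedOrdering is private[sql], so this only compiles from within the sql package):

    import org.apache.spark.sql.catalyst.util.GenericArrayData
    import org.apache.spark.sql.types.{ArrayType, IntegerType}

    val ordering = ArrayType(IntegerType, containsNull = true).interpretedOrdering
    ordering.compare(new GenericArrayData(Array(1, 2)), new GenericArrayData(Array(1, 3)))     // < 0: 2 < 3
    ordering.compare(new GenericArrayData(Array(1, 2)), new GenericArrayData(Array(1, 2, 0)))  // < 0: shorter first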
*/ - def typeName: String = this.getClass.getSimpleName.stripSuffix("$").dropRight(4).toLowerCase + def typeName: String = { + this.getClass.getSimpleName.stripSuffix("$").stripSuffix("Type").stripSuffix("UDT").toLowerCase + } private[sql] def jsonValue: JValue = typeName @@ -77,6 +79,11 @@ abstract class DataType extends AbstractDataType { */ private[spark] def asNullable: DataType + /** + * Returns true if any `DataType` of this DataType tree satisfies the given function `f`. + */ + private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = f(this) + override private[sql] def defaultConcreteType: DataType = this override private[sql] def acceptsType(other: DataType): Boolean = sameType(other) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index c0155eeb450a..c7a1a2e7469e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.types +import java.math.{RoundingMode, MathContext} + import org.apache.spark.annotation.DeveloperApi /** @@ -28,7 +30,7 @@ import org.apache.spark.annotation.DeveloperApi * - Otherwise, the decimal value is longVal / (10 ** _scale) */ final class Decimal extends Ordered[Decimal] with Serializable { - import org.apache.spark.sql.types.Decimal.{BIG_DEC_ZERO, MAX_LONG_DIGITS, POW_10, ROUNDING_MODE} + import org.apache.spark.sql.types.Decimal._ private var decimalVal: BigDecimal = null private var longVal: Long = 0L @@ -86,7 +88,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (precision < 19) { return null // Requested precision is too low to represent this value } - this.decimalVal = BigDecimal(unscaled) + this.decimalVal = BigDecimal(unscaled, scale) this.longVal = 0L } else { val p = POW_10(math.min(precision, MAX_LONG_DIGITS)) @@ -105,8 +107,10 @@ final class Decimal extends Ordered[Decimal] with Serializable { * Set this Decimal to the given BigDecimal value, with a given precision and scale. */ def set(decimal: BigDecimal, precision: Int, scale: Int): Decimal = { - this.decimalVal = decimal.setScale(scale, ROUNDING_MODE) - require(decimalVal.precision <= precision, "Overflowed precision") + this.decimalVal = decimal.setScale(scale, ROUND_HALF_UP) + require( + decimalVal.precision <= precision, + s"Decimal precision ${decimalVal.precision} exceeds max precision $precision") this.longVal = 0L this._precision = precision this._scale = scale @@ -143,7 +147,13 @@ final class Decimal extends Ordered[Decimal] with Serializable { } } - def toJavaBigDecimal: java.math.BigDecimal = toBigDecimal.underlying() + def toJavaBigDecimal: java.math.BigDecimal = { + if (decimalVal.ne(null)) { + decimalVal.underlying() + } else { + java.math.BigDecimal.valueOf(longVal, _scale) + } + } def toUnscaledLong: Long = { if (decimalVal.ne(null)) { @@ -188,6 +198,16 @@ final class Decimal extends Ordered[Decimal] with Serializable { * @return true if successful, false if overflow would occur */ def changePrecision(precision: Int, scale: Int): Boolean = { + changePrecision(precision, scale, ROUND_HALF_UP) + } + + /** + * Update precision and scale while keeping our value the same, and return true if successful. 
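With the fixes above, a long-backed Decimal keeps its scale when converted to java.math.BigDecimal. For example:

    import org.apache.spark.sql.types.Decimal

    val d = Decimal(123456L, 10, 3)  // unscaled 123456 at scale 3, i.e. 123.456
    d.toJavaBigDecimal               // 123.456, via BigDecimal.valueOf(longVal, scale)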
+ * + * @return true if successful, false if overflow would occur + */ + private[sql] def changePrecision(precision: Int, scale: Int, + roundMode: BigDecimal.RoundingMode.Value): Boolean = { // fast path for UnsafeProjection if (precision == this.precision && scale == this.scale) { return true @@ -221,7 +241,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (decimalVal.ne(null)) { // We get here if either we started with a BigDecimal, or we switched to one because we would // have overflowed our Long; in either case we must rescale decimalVal to the new scale. - val newVal = decimalVal.setScale(scale, ROUNDING_MODE) + val newVal = decimalVal.setScale(scale, roundMode) if (newVal.precision > precision) { return false } @@ -261,33 +281,64 @@ final class Decimal extends Ordered[Decimal] with Serializable { def isZero: Boolean = if (decimalVal.ne(null)) decimalVal == BIG_DEC_ZERO else longVal == 0 - def + (that: Decimal): Decimal = Decimal(toBigDecimal + that.toBigDecimal) + def + (that: Decimal): Decimal = { + if (decimalVal.eq(null) && that.decimalVal.eq(null) && scale == that.scale) { + Decimal(longVal + that.longVal, Math.max(precision, that.precision), scale) + } else { + Decimal(toBigDecimal + that.toBigDecimal) + } + } - def - (that: Decimal): Decimal = Decimal(toBigDecimal - that.toBigDecimal) + def - (that: Decimal): Decimal = { + if (decimalVal.eq(null) && that.decimalVal.eq(null) && scale == that.scale) { + Decimal(longVal - that.longVal, Math.max(precision, that.precision), scale) + } else { + Decimal(toBigDecimal - that.toBigDecimal) + } + } - def * (that: Decimal): Decimal = Decimal(toBigDecimal * that.toBigDecimal) + // HiveTypeCoercion will take care of the precision, scale of result + def * (that: Decimal): Decimal = + Decimal(toJavaBigDecimal.multiply(that.toJavaBigDecimal, MATH_CONTEXT)) def / (that: Decimal): Decimal = - if (that.isZero) null else Decimal(toBigDecimal / that.toBigDecimal) + if (that.isZero) null else Decimal(toJavaBigDecimal.divide(that.toJavaBigDecimal, MATH_CONTEXT)) def % (that: Decimal): Decimal = - if (that.isZero) null else Decimal(toBigDecimal % that.toBigDecimal) + if (that.isZero) null + else Decimal(toJavaBigDecimal.remainder(that.toJavaBigDecimal, MATH_CONTEXT)) def remainder(that: Decimal): Decimal = this % that def unary_- : Decimal = { if (decimalVal.ne(null)) { - Decimal(-decimalVal) + Decimal(-decimalVal, precision, scale) } else { Decimal(-longVal, precision, scale) } } - def abs: Decimal = if (this.compare(Decimal(0)) < 0) this.unary_- else this + def abs: Decimal = if (this.compare(Decimal.ZERO) < 0) this.unary_- else this + + def floor: Decimal = if (scale == 0) this else { + val value = this.clone() + value.changePrecision( + DecimalType.bounded(precision - scale + 1, 0).precision, 0, ROUND_FLOOR) + value + } + + def ceil: Decimal = if (scale == 0) this else { + val value = this.clone() + value.changePrecision( + DecimalType.bounded(precision - scale + 1, 0).precision, 0, ROUND_CEILING) + value + } } object Decimal { - private val ROUNDING_MODE = BigDecimal.RoundingMode.HALF_UP + val ROUND_HALF_UP = BigDecimal.RoundingMode.HALF_UP + val ROUND_CEILING = BigDecimal.RoundingMode.CEILING + val ROUND_FLOOR = BigDecimal.RoundingMode.FLOOR /** Maximum number of decimal digits a Long can represent */ val MAX_LONG_DIGITS = 18 @@ -296,6 +347,11 @@ object Decimal { private val BIG_DEC_ZERO = BigDecimal(0) + private val MATH_CONTEXT = new MathContext(DecimalType.MAX_PRECISION, RoundingMode.HALF_UP) + + private[sql] val ZERO = 
Decimal(0) + private[sql] val ONE = Decimal(1) + def apply(value: Double): Decimal = new Decimal().set(value) def apply(value: Long): Decimal = new Decimal().set(value) @@ -309,6 +365,9 @@ object Decimal { def apply(value: BigDecimal, precision: Int, scale: Int): Decimal = new Decimal().set(value, precision, scale) + def apply(value: java.math.BigDecimal, precision: Int, scale: Int): Decimal = + new Decimal().set(value, precision, scale) + def apply(unscaled: Long, precision: Int, scale: Int): Decimal = new Decimal().set(unscaled, precision, scale) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 0cd352d0fa92..ce45245b9f6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -90,6 +90,18 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType { case _ => false } + /** + * Returns whether this DecimalType is tighter than `other`. If yes, it means `this` + * can be casted into `other` safely without losing any precision or range. + */ + private[sql] def isTighterThan(other: DataType): Boolean = other match { + case dt: DecimalType => + (precision - scale) <= (dt.precision - dt.scale) && scale <= dt.scale + case dt: IntegralType => + isTighterThan(DecimalType.forType(dt)) + case _ => false + } + /** * The default size of a value of the DecimalType is 4096 bytes. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala index ac34b642827c..00461e529ca0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -62,8 +62,12 @@ case class MapType( override def simpleString: String = s"map<${keyType.simpleString},${valueType.simpleString}>" - private[spark] override def asNullable: MapType = + override private[spark] def asNullable: MapType = MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true) + + override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = { + f(this) || keyType.existsRecursively(f) || valueType.existsRecursively(f) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala new file mode 100644 index 000000000000..fca0b799eb80 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
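The new floor and ceil on Decimal round through changePrecision with the matching rounding mode. Roughly (assuming the usual Decimal(Double) construction path):

    import org.apache.spark.sql.types.Decimal

    Decimal(3.7).floor   // 3
    Decimal(3.7).ceil    // 4
    Decimal(-3.2).ceil   // -3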
+ */ + +package org.apache.spark.sql.types + +import scala.language.existentials + +private[sql] object ObjectType extends AbstractDataType { + override private[sql] def defaultConcreteType: DataType = + throw new UnsupportedOperationException("null literals can't be casted to ObjectType") + + // No casting or comparison is supported. + override private[sql] def acceptsType(other: DataType): Boolean = false + + override private[sql] def simpleString: String = "Object" +} + +/** + * Represents a JVM object that is passing through Spark SQL expression evaluation. Note this + * is only used internally while converting into the internal format and is not intended for use + * outside of the execution engine. + */ +private[sql] case class ObjectType(cls: Class[_]) extends DataType { + override def defaultSize: Int = + throw new UnsupportedOperationException("No size estimation available for objects.") + + def asNullable: DataType = this +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 6928707f7bf6..9778df271ddd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -18,13 +18,13 @@ package org.apache.spark.sql.types import scala.collection.mutable.ArrayBuffer -import scala.math.max import org.json4s.JsonDSL._ import org.apache.spark.SparkException import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute, RowOrdering} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering} +import org.apache.spark.sql.catalyst.util.DataTypeParser /** @@ -292,7 +292,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru private[sql] def merge(that: StructType): StructType = StructType.merge(this, that).asInstanceOf[StructType] - private[spark] override def asNullable: StructType = { + override private[spark] def asNullable: StructType = { val newFields = fields.map { case StructField(name, dataType, nullable, metadata) => StructField(name, dataType.asNullable, nullable = true, metadata) @@ -301,7 +301,13 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru StructType(newFields) } - private[sql] val ordering = RowOrdering.forSchema(this.fields.map(_.dataType)) + override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = { + f(this) || fields.exists(field => field.dataType.existsRecursively(f)) + } + + @transient + private[sql] lazy val interpretedOrdering = + InterpretedOrdering.forSchema(this.fields.map(_.dataType)) } object StructType extends AbstractDataType { @@ -322,7 +328,8 @@ object StructType extends AbstractDataType { def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray) def apply(fields: java.util.List[StructField]): StructType = { - StructType(fields.toArray.asInstanceOf[Array[StructField]]) + import scala.collection.JavaConverters._ + StructType(fields.asScala) } protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType = @@ -367,10 +374,19 @@ object StructType extends AbstractDataType { StructType(newFields) case (DecimalType.Fixed(leftPrecision, leftScale), - DecimalType.Fixed(rightPrecision, rightScale)) => - DecimalType( - max(leftScale, rightScale) + max(leftPrecision - leftScale, rightPrecision - rightScale), - max(leftScale, 
rightScale)) + DecimalType.Fixed(rightPrecision, rightScale)) => + if ((leftPrecision == rightPrecision) && (leftScale == rightScale)) { + DecimalType(leftPrecision, leftScale) + } else if ((leftPrecision != rightPrecision) && (leftScale != rightScale)) { + throw new SparkException("Failed to merge Decimal Types with incompatible " + + s"precision $leftPrecision and $rightPrecision & scale $leftScale and $rightScale") + } else if (leftPrecision != rightPrecision) { + throw new SparkException("Failed to merge Decimal Types with incompatible " + + s"precision $leftPrecision and $rightPrecision") + } else { + throw new SparkException("Failed to merge Decimal Types with incompatible " + + s"scale $leftScale and $rightScale") + } case (leftUdt: UserDefinedType[_], rightUdt: UserDefinedType[_]) if leftUdt.userClass == rightUdt.userClass => leftUdt diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index 11e0c120f407..7614f055e9c0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -23,6 +23,8 @@ import java.math.MathContext import scala.util.Random +import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -84,6 +86,7 @@ object RandomDataGenerator { * random data generator is defined for that data type. The generated values will use an external * representation of the data type; for example, the random generator for [[DateType]] will return * instances of [[java.sql.Date]] and the generator for [[StructType]] will return a [[Row]]. + * For a [[UserDefinedType]] of a class X, an instance of class X is returned. * * @param dataType the type to generate values for * @param nullable whether null values should be generated @@ -105,8 +108,37 @@ object RandomDataGenerator { arr }) case BooleanType => Some(() => rand.nextBoolean()) - case DateType => Some(() => new java.sql.Date(rand.nextInt())) - case TimestampType => Some(() => new java.sql.Timestamp(rand.nextLong())) + case DateType => + val generator = + () => { + var milliseconds = rand.nextLong() % 253402329599999L + // -62135740800000L is the number of milliseconds before January 1, 1970, 00:00:00 GMT + // for "0001-01-01 00:00:00.000000". We need to find a + // number that is greater than or equal to this number as a valid timestamp value. + while (milliseconds < -62135740800000L) { + // 253402329599999L is the number of milliseconds since + // January 1, 1970, 00:00:00 GMT for "9999-12-31 23:59:59.999999". + milliseconds = rand.nextLong() % 253402329599999L + } + DateTimeUtils.toJavaDate((milliseconds / DateTimeUtils.MILLIS_PER_DAY).toInt) + } + Some(generator) + case TimestampType => + val generator = + () => { + var milliseconds = rand.nextLong() % 253402329599999L + // -62135740800000L is the number of milliseconds before January 1, 1970, 00:00:00 GMT + // for "0001-01-01 00:00:00.000000". We need to find a + // number that is greater than or equal to this number as a valid timestamp value. + while (milliseconds < -62135740800000L) { + // 253402329599999L is the number of milliseconds since + // January 1, 1970, 00:00:00 GMT for "9999-12-31 23:59:59.999999".
+ milliseconds = rand.nextLong() % 253402329599999L + } + // DateTimeUtils.toJavaTimestamp takes microsecond. + DateTimeUtils.toJavaTimestamp(milliseconds * 1000) + } + Some(generator) case CalendarIntervalType => Some(() => { val months = rand.nextInt(1000) val ns = rand.nextLong() @@ -116,7 +148,7 @@ object RandomDataGenerator { () => BigDecimal.apply( rand.nextLong() % math.pow(10, precision).toLong, scale, - new MathContext(precision))) + new MathContext(precision)).bigDecimal) case DoubleType => randomNumeric[Double]( rand, r => longBitsToDouble(r.nextLong()), Seq(Double.MinValue, Double.MinPositiveValue, Double.MaxValue, Double.PositiveInfinity, Double.NegativeInfinity, Double.NaN, 0.0)) @@ -134,7 +166,7 @@ object RandomDataGenerator { case NullType => Some(() => null) case ArrayType(elementType, containsNull) => { forType(elementType, nullable = containsNull, seed = Some(rand.nextLong())).map { - elementGenerator => () => Array.fill(rand.nextInt(MAX_ARR_SIZE))(elementGenerator()) + elementGenerator => () => Seq.fill(rand.nextInt(MAX_ARR_SIZE))(elementGenerator()) } } case MapType(keyType, valueType, valueContainsNull) => { @@ -159,6 +191,27 @@ object RandomDataGenerator { None } } + case udt: UserDefinedType[_] => { + val maybeSqlTypeGenerator = forType(udt.sqlType, nullable, seed) + // Because random data generator at here returns scala value, we need to + // convert it to catalyst value to call udt's deserialize. + val toCatalystType = CatalystTypeConverters.createToCatalystConverter(udt.sqlType) + + if (maybeSqlTypeGenerator.isDefined) { + val sqlTypeGenerator = maybeSqlTypeGenerator.get + val generator = () => { + val generatedScalaValue = sqlTypeGenerator.apply() + if (generatedScalaValue == null) { + null + } else { + udt.deserialize(toCatalystType(generatedScalaValue)) + } + } + Some(generator) + } else { + None + } + } case unsupportedType => None } // Handle nullability by wrapping the non-null value generator: diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala index 01ff84cb5605..5c22a7219254 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala @@ -29,8 +29,10 @@ class RowTest extends FunSpec with Matchers { StructField("col2", StringType) :: StructField("col3", IntegerType) :: Nil) val values = Array("value1", "value2", 1) + val valuesWithoutCol3 = Array[Any](null, "value2", null) val sampleRow: Row = new GenericRowWithSchema(values, schema) + val sampleRowWithoutCol3: Row = new GenericRowWithSchema(valuesWithoutCol3, schema) val noSchemaRow: Row = new GenericRow(values) describe("Row (without schema)") { @@ -68,6 +70,24 @@ class RowTest extends FunSpec with Matchers { ) sampleRow.getValuesMap(List("col1", "col2")) shouldBe expected } + + it("getValuesMap() retrieves null value on non AnyVal Type") { + val expected = Map( + "col1" -> null, + "col2" -> "value2" + ) + sampleRowWithoutCol3.getValuesMap[String](List("col1", "col2")) shouldBe expected + } + + it("getAs() on type extending AnyVal throws an exception when accessing field that is null") { + intercept[NullPointerException] { + sampleRowWithoutCol3.getInt(sampleRowWithoutCol3.fieldIndex("col3")) + } + } + + it("getAs() on type extending AnyVal does not throw exception when value is null"){ + sampleRowWithoutCol3.getAs[String](sampleRowWithoutCol3.fieldIndex("col1")) shouldBe null + } } describe("row equals") { diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala index df0f04563edc..03bb102c67fe 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala @@ -32,7 +32,9 @@ class CatalystTypeConvertersSuite extends SparkFunSuite { IntegerType, LongType, FloatType, - DoubleType) + DoubleType, + DecimalType.SYSTEM_DEFAULT, + DecimalType.USER_DEFAULT) test("null handling in rows") { val schema = StructType(simpleTypes.map(t => StructField(t.getClass.getName, t))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala new file mode 100644 index 000000000000..5b802ccc637d --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.{InterpretedMutableProjection, Literal} +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning} + +class PartitioningSuite extends SparkFunSuite { + test("HashPartitioning compatibility should be sensitive to expression ordering (SPARK-9785)") { + val expressions = Seq(Literal(2), Literal(3)) + // Consider two HashPartitionings that have the same _set_ of hash expressions but which are + // created with different orderings of those expressions: + val partitioningA = HashPartitioning(expressions, 100) + val partitioningB = HashPartitioning(expressions.reverse, 100) + // These partitionings are not considered equal: + assert(partitioningA != partitioningB) + // However, they both satisfy the same clustered distribution: + val distribution = ClusteredDistribution(expressions) + assert(partitioningA.satisfies(distribution)) + assert(partitioningB.satisfies(distribution)) + // These partitionings compute different hashcodes for the same input row: + def computeHashCode(partitioning: HashPartitioning): Int = { + val hashExprProj = new InterpretedMutableProjection(partitioning.expressions, Seq.empty) + hashExprProj.apply(InternalRow.empty).hashCode() + } + assert(computeHashCode(partitioningA) != computeHashCode(partitioningB)) + // Thus, these partitionings are incompatible: + assert(!partitioningA.compatibleWith(partitioningB)) + assert(!partitioningB.compatibleWith(partitioningA)) + assert(!partitioningA.guarantees(partitioningB)) + assert(!partitioningB.guarantees(partitioningA)) + + // Just to be sure that we haven't cheated by having these methods always return false, + // check that identical partitionings are still compatible with and guarantee each other: + assert(partitioningA === partitioningA) + assert(partitioningA.guarantees(partitioningA)) + assert(partitioningA.compatibleWith(partitioningA)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 3b848cfdf737..c2aace1ef238 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -186,74 +186,6 @@ class ScalaReflectionSuite extends SparkFunSuite { nullable = true)) } - test("get data type of a value") { - // BooleanType - assert(BooleanType === typeOfObject(true)) - assert(BooleanType === typeOfObject(false)) - - // BinaryType - assert(BinaryType === typeOfObject("string".getBytes)) - - // StringType - assert(StringType === typeOfObject("string")) - - // ByteType - assert(ByteType === typeOfObject(127.toByte)) - - // ShortType - assert(ShortType === typeOfObject(32767.toShort)) - - // IntegerType - assert(IntegerType === typeOfObject(2147483647)) - - // LongType - assert(LongType === typeOfObject(9223372036854775807L)) - - // FloatType - assert(FloatType === typeOfObject(3.4028235E38.toFloat)) - - // DoubleType - assert(DoubleType === typeOfObject(1.7976931348623157E308)) - - // DecimalType - assert(DecimalType.SYSTEM_DEFAULT === - typeOfObject(new java.math.BigDecimal("1.7976931348623157E318"))) - - // DateType - assert(DateType === typeOfObject(Date.valueOf("2014-07-25"))) - - // TimestampType - assert(TimestampType === typeOfObject(Timestamp.valueOf("2014-07-25 10:26:00"))) - - // 
NullType - assert(NullType === typeOfObject(null)) - - def typeOfObject1: PartialFunction[Any, DataType] = typeOfObject orElse { - case value: java.math.BigInteger => DecimalType.SYSTEM_DEFAULT - case value: java.math.BigDecimal => DecimalType.SYSTEM_DEFAULT - case _ => StringType - } - - assert(DecimalType.SYSTEM_DEFAULT === typeOfObject1( - new BigInteger("92233720368547758070"))) - assert(DecimalType.SYSTEM_DEFAULT === typeOfObject1( - new java.math.BigDecimal("1.7976931348623157E318"))) - assert(StringType === typeOfObject1(BigInt("92233720368547758070"))) - - def typeOfObject2: PartialFunction[Any, DataType] = typeOfObject orElse { - case value: java.math.BigInteger => DecimalType.SYSTEM_DEFAULT - } - - intercept[MatchError](typeOfObject2(BigInt("92233720368547758070"))) - - def typeOfObject3: PartialFunction[Any, DataType] = typeOfObject orElse { - case c: Seq[_] => ArrayType(typeOfObject3(c.head)) - } - - assert(ArrayType(IntegerType) === typeOfObject3(Seq(1, 2, 3))) - assert(ArrayType(ArrayType(IntegerType)) === typeOfObject3(Seq(Seq(1, 2, 3)))) - } - test("convert PrimitiveData to catalyst") { val data = PrimitiveData(1, 1, 1, 1, 1, 1, true) val convertedData = InternalRow(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true) @@ -280,4 +212,21 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(s.fields.map(_.dataType) === Seq(IntegerType, StringType, DoubleType)) } } + + test("get parameter type from a function object") { + val primitiveFunc = (i: Int, j: Long) => "x" + val primitiveTypes = getParameterTypes(primitiveFunc) + assert(primitiveTypes.forall(_.isPrimitive)) + assert(primitiveTypes === Seq(classOf[Int], classOf[Long])) + + val boxedFunc = (i: java.lang.Integer, j: java.lang.Long) => "x" + val boxedTypes = getParameterTypes(boxedFunc) + assert(boxedTypes.forall(!_.isPrimitive)) + assert(boxedTypes === Seq(classOf[java.lang.Integer], classOf[java.lang.Long])) + + val anyFunc = (i: Any, j: AnyRef) => "x" + val anyTypes = getParameterTypes(anyFunc) + assert(anyTypes.forall(!_.isPrimitive)) + assert(anyTypes === Seq(classOf[java.lang.Object], classOf[java.lang.Object])) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala index b93a3abc6ebd..9ff893b84775 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.catalyst -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.plans.logical.Command +import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias +import org.apache.spark.sql.catalyst.expressions.{Literal, GreaterThan, Not, Attribute} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project, LogicalPlan, Command} +import org.apache.spark.unsafe.types.CalendarInterval private[sql] case class TestCommand(cmd: String) extends LogicalPlan with Command { override def output: Seq[Attribute] = Seq.empty @@ -49,7 +50,7 @@ private[sql] class CaseInsensitiveTestParser extends AbstractSparkSQLParser { } } -class SqlParserSuite extends SparkFunSuite { +class SqlParserSuite extends PlanTest { test("test long keyword") { val parser = new 
SuperLongKeywordTestParser @@ -63,4 +64,87 @@ class SqlParserSuite extends SparkFunSuite { assert(TestCommand("NotRealCommand") === parser.parse("execute NotRealCommand")) assert(TestCommand("NotRealCommand") === parser.parse("exEcute NotRealCommand")) } + + test("test NOT operator with comparison operations") { + val parsed = SqlParser.parse("SELECT NOT TRUE > TRUE") + val expected = Project( + UnresolvedAlias( + Not( + GreaterThan(Literal(true), Literal(true))) + ) :: Nil, + OneRowRelation) + comparePlans(parsed, expected) + } + + test("support hive interval literal") { + def checkInterval(sql: String, result: CalendarInterval): Unit = { + val parsed = SqlParser.parse(sql) + val expected = Project( + UnresolvedAlias( + Literal(result) + ) :: Nil, + OneRowRelation) + comparePlans(parsed, expected) + } + + def checkYearMonth(lit: String): Unit = { + checkInterval( + s"SELECT INTERVAL '$lit' YEAR TO MONTH", + CalendarInterval.fromYearMonthString(lit)) + } + + def checkDayTime(lit: String): Unit = { + checkInterval( + s"SELECT INTERVAL '$lit' DAY TO SECOND", + CalendarInterval.fromDayTimeString(lit)) + } + + def checkSingleUnit(lit: String, unit: String): Unit = { + checkInterval( + s"SELECT INTERVAL '$lit' $unit", + CalendarInterval.fromSingleUnitString(unit, lit)) + } + + checkYearMonth("123-10") + checkYearMonth("496-0") + checkYearMonth("-2-3") + checkYearMonth("-123-0") + + checkDayTime("99 11:22:33.123456789") + checkDayTime("-99 11:22:33.123456789") + checkDayTime("10 9:8:7.123456789") + checkDayTime("1 0:0:0") + checkDayTime("-1 0:0:0") + checkDayTime("1 0:0:1") + + for (unit <- Seq("year", "month", "day", "hour", "minute", "second")) { + checkSingleUnit("7", unit) + checkSingleUnit("-7", unit) + checkSingleUnit("0", unit) + } + + checkSingleUnit("13.123456789", "second") + checkSingleUnit("-13.123456789", "second") + } + + test("support scientific notation") { + def assertRight(input: String, output: Double): Unit = { + val parsed = SqlParser.parse("SELECT " + input) + val expected = Project( + UnresolvedAlias( + Literal(output) + ) :: Nil, + OneRowRelation) + comparePlans(parsed, expected) + } + + assertRight("9.0e1", 90) + assertRight(".9e+2", 90) + assertRight("0.9e+2", 90) + assertRight("900e-1", 90) + assertRight("900.0E-1", 90) + assertRight("9.e+1", 90) + + intercept[RuntimeException](SqlParser.parse("SELECT .e3")) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 26935c6e3b24..12079992b5b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -17,18 +17,74 @@ package org.apache.spark.sql.catalyst.analysis -import org.scalatest.BeforeAndAfter - -import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Count, Sum, AggregateExpression} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.Inner -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.util.{MapData, ArrayBasedMapData, GenericArrayData, ArrayData} import org.apache.spark.sql.types._ 
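Before moving on to AnalysisErrorSuite: the interval-literal tests added to SqlParserSuite above lean on CalendarInterval's string parsers, so each expected plan is just a Project over a Literal wrapping the parsed interval. A minimal standalone sketch of what a year-month parse yields (illustrative only, not part of the patch; it assumes CalendarInterval's public `months` field):

import org.apache.spark.unsafe.types.CalendarInterval

// "123-10" means 123 years and 10 months, stored internally as a single month count.
val interval = CalendarInterval.fromYearMonthString("123-10")
assert(interval.months == 123 * 12 + 10)  // 1486 months, no microsecond component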
+import scala.beans.{BeanProperty, BeanInfo} + +@BeanInfo +private[sql] case class GroupableData(@BeanProperty data: Int) + +private[sql] class GroupableUDT extends UserDefinedType[GroupableData] { + + override def sqlType: DataType = IntegerType + + override def serialize(obj: Any): Int = { + obj match { + case groupableData: GroupableData => groupableData.data + } + } + + override def deserialize(datum: Any): GroupableData = { + datum match { + case data: Int => GroupableData(data) + } + } + + override def userClass: Class[GroupableData] = classOf[GroupableData] + + private[spark] override def asNullable: GroupableUDT = this +} + +@BeanInfo +private[sql] case class UngroupableData(@BeanProperty data: Map[Int, Int]) + +private[sql] class UngroupableUDT extends UserDefinedType[UngroupableData] { + + override def sqlType: DataType = MapType(IntegerType, IntegerType) + + override def serialize(obj: Any): MapData = { + obj match { + case groupableData: UngroupableData => + val keyArray = new GenericArrayData(groupableData.data.keys.toSeq) + val valueArray = new GenericArrayData(groupableData.data.values.toSeq) + new ArrayBasedMapData(keyArray, valueArray) + } + } + + override def deserialize(datum: Any): UngroupableData = { + datum match { + case data: MapData => + val keyArray = data.keyArray().array + val valueArray = data.valueArray().array + assert(keyArray.length == valueArray.length) + val mapData = keyArray.zip(valueArray).toMap.asInstanceOf[Map[Int, Int]] + UngroupableData(mapData) + } + } + + override def userClass: Class[UngroupableData] = classOf[UngroupableData] + + private[spark] override def asNullable: UngroupableUDT = this +} + case class TestFunction( children: Seq[Expression], inputTypes: Seq[AbstractDataType]) @@ -42,8 +98,8 @@ case class UnresolvedTestPlan() extends LeafNode { override def output: Seq[Attribute] = Nil } -class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { - import AnalysisSuite._ +class AnalysisErrorSuite extends AnalysisTest { + import TestRelations._ def errorTest( name: String, @@ -51,15 +107,7 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { errorMessages: Seq[String], caseSensitive: Boolean = true): Unit = { test(name) { - val error = intercept[AnalysisException] { - if (caseSensitive) { - caseSensitiveAnalyze(plan) - } else { - caseInsensitiveAnalyze(plan) - } - } - - errorMessages.foreach(m => assert(error.getMessage.toLowerCase.contains(m.toLowerCase))) + assertAnalysisError(plan, errorMessages, caseSensitive) } } @@ -69,34 +117,54 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { "single invalid type, single arg", testRelation.select(TestFunction(dateLit :: Nil, IntegerType :: Nil).as('a)), "cannot resolve" :: "testfunction" :: "argument 1" :: "requires int type" :: - "'null' is of date type" ::Nil) + "'null' is of date type" :: Nil) errorTest( "single invalid type, second arg", testRelation.select( TestFunction(dateLit :: dateLit :: Nil, DateType :: IntegerType :: Nil).as('a)), "cannot resolve" :: "testfunction" :: "argument 2" :: "requires int type" :: - "'null' is of date type" ::Nil) + "'null' is of date type" :: Nil) errorTest( "multiple invalid type", testRelation.select( TestFunction(dateLit :: dateLit :: Nil, IntegerType :: IntegerType :: Nil).as('a)), "cannot resolve" :: "testfunction" :: "argument 1" :: "argument 2" :: - "requires int type" :: "'null' is of date type" ::Nil) + "requires int type" :: "'null' is of date type" :: Nil) errorTest( - "unresolved window function", + 
"invalid window function", testRelation2.select( WindowExpression( - UnresolvedWindowFunction( - "lead", - UnresolvedAttribute("c") :: Nil), + Literal(0), WindowSpecDefinition( UnresolvedAttribute("a") :: Nil, SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, UnspecifiedFrame)).as('window)), - "lead" :: "window functions currently requires a HiveContext" :: Nil) + "not supported within a window function" :: Nil) + + errorTest( + "distinct window function", + testRelation2.select( + WindowExpression( + AggregateExpression(Count(UnresolvedAttribute("b")), Complete, isDistinct = true), + WindowSpecDefinition( + UnresolvedAttribute("a") :: Nil, + SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, + UnspecifiedFrame)).as('window)), + "Distinct window functions are not supported" :: Nil) + + errorTest( + "offset window function", + testRelation2.select( + WindowExpression( + new Lead(UnresolvedAttribute("b")), + WindowSpecDefinition( + UnresolvedAttribute("a") :: Nil, + SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, + SpecifiedWindowFrame(RangeFrame, ValueFollowing(1), ValueFollowing(2)))).as('window)), + "window frame" :: "must match the required frame" :: Nil) errorTest( "too many generators", @@ -115,8 +183,8 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { errorTest( "sorting by unsupported column types", - listRelation.orderBy('list.asc), - "sort" :: "type" :: "array" :: Nil) + mapRelation.orderBy('map.asc), + "sort" :: "type" :: "map" :: Nil) errorTest( "non-boolean filters", @@ -157,23 +225,44 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { UnresolvedTestPlan(), "unresolved" :: Nil) + errorTest( + "union with unequal number of columns", + testRelation.unionAll(testRelation2), + "union" :: "number of columns" :: testRelation2.output.length.toString :: + testRelation.output.length.toString :: Nil) + + errorTest( + "intersect with unequal number of columns", + testRelation.intersect(testRelation2), + "intersect" :: "number of columns" :: testRelation2.output.length.toString :: + testRelation.output.length.toString :: Nil) + + errorTest( + "except with unequal number of columns", + testRelation.except(testRelation2), + "except" :: "number of columns" :: testRelation2.output.length.toString :: + testRelation.output.length.toString :: Nil) + + errorTest( + "SPARK-9955: correct error message for aggregate", + // When parse SQL string, we will wrap aggregate expressions with UnresolvedAlias. + testRelation2.where('bad_column > 1).groupBy('a)(UnresolvedAlias(max('b))), + "cannot resolve 'bad_column'" :: Nil) test("SPARK-6452 regression test") { // CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s) + // Since we manually construct the logical plan at here and Sum only accetp + // LongType, DoubleType, and DecimalType. We use LongType as the type of a. 
val plan = Aggregate( Nil, - Alias(Sum(AttributeReference("a", IntegerType)(exprId = ExprId(1))), "b")() :: Nil, + Alias(sum(AttributeReference("a", LongType)(exprId = ExprId(1))), "b")() :: Nil, LocalRelation( - AttributeReference("a", IntegerType)(exprId = ExprId(2)))) + AttributeReference("a", LongType)(exprId = ExprId(2)))) assert(plan.resolved) - val message = intercept[AnalysisException] { - caseSensitiveAnalyze(plan) - }.getMessage - - assert(message.contains("resolved attribute(s) a#1 missing from a#2")) + assertAnalysisError(plan, "resolved attribute(s) a#1L missing from a#2L" :: Nil) } test("error test for self-join") { @@ -185,32 +274,66 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { assert(error.message.contains("Conflicting attributes")) } - test("aggregation can't work on binary and map types") { - val plan = - Aggregate( - AttributeReference("a", BinaryType)(exprId = ExprId(2)) :: Nil, - Alias(Sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil, - LocalRelation( - AttributeReference("a", BinaryType)(exprId = ExprId(2)), - AttributeReference("b", IntegerType)(exprId = ExprId(1)))) + test("check grouping expression data types") { + def checkDataType(dataType: DataType, shouldSuccess: Boolean): Unit = { + val plan = + Aggregate( + AttributeReference("a", dataType)(exprId = ExprId(2)) :: Nil, + Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil, + LocalRelation( + AttributeReference("a", dataType)(exprId = ExprId(2)), + AttributeReference("b", IntegerType)(exprId = ExprId(1)))) + + shouldSuccess match { + case true => + assertAnalysisSuccess(plan, true) + case false => + assertAnalysisError(plan, "expression a cannot be used as a grouping expression" :: Nil) + } + } - val error = intercept[AnalysisException] { - caseSensitiveAnalyze(plan) + val supportedDataTypes = Seq( + StringType, BinaryType, + NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), + DateType, TimestampType, + ArrayType(IntegerType), + new StructType() + .add("f1", FloatType, nullable = true) + .add("f2", StringType, nullable = true), + new StructType() + .add("f1", FloatType, nullable = true) + .add("f2", ArrayType(BooleanType, containsNull = true), nullable = true), + new GroupableUDT()) + supportedDataTypes.foreach { dataType => + checkDataType(dataType, shouldSuccess = true) } - assert(error.message.contains("binary type expression a cannot be used in grouping expression")) - val plan2 = + val unsupportedDataTypes = Seq( + MapType(StringType, LongType), + new StructType() + .add("f1", FloatType, nullable = true) + .add("f2", MapType(StringType, LongType), nullable = true), + new UngroupableUDT()) + unsupportedDataTypes.foreach { dataType => + checkDataType(dataType, shouldSuccess = false) + } + } + + test("we should fail analysis when we find nested aggregate functions") { + val plan = Aggregate( - AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)) :: Nil, - Alias(Sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil, + AttributeReference("a", IntegerType)(exprId = ExprId(2)) :: Nil, + Alias(sum(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1)))), "c")() :: Nil, LocalRelation( - AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)), + AttributeReference("a", IntegerType)(exprId = ExprId(2)), AttributeReference("b", IntegerType)(exprId = ExprId(1)))) - val error2 = 
intercept[AnalysisException] { - caseSensitiveAnalyze(plan2) - } - assert(error2.message.contains("map type expression a cannot be used in grouping expression")) + assertAnalysisError( + plan, + "It is not allowed to use an aggregate function in the argument of " + + "another aggregate function." :: Nil) } test("Join can't work on binary and map types") { @@ -226,10 +349,7 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { Some(EqualTo(AttributeReference("a", BinaryType)(exprId = ExprId(2)), AttributeReference("c", BinaryType)(exprId = ExprId(4))))) - val error = intercept[AnalysisException] { - caseSensitiveAnalyze(plan) - } - assert(error.message.contains("binary type expression a cannot be used in join conditions")) + assertAnalysisError(plan, "binary type expression a cannot be used in join conditions" :: Nil) val plan2 = Join( @@ -243,9 +363,6 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { Some(EqualTo(AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)), AttributeReference("c", MapType(IntegerType, StringType))(exprId = ExprId(4))))) - val error2 = intercept[AnalysisException] { - caseSensitiveAnalyze(plan2) - } - assert(error2.message.contains("map type expression a cannot be used in join conditions")) + assertAnalysisError(plan2, "map type expression a cannot be used in join conditions" :: Nil) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index a86cefe941e8..aeeca802d8bb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -17,68 +17,16 @@ package org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ -import org.apache.spark.sql.catalyst.SimpleCatalystConf -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.dsl.plans._ - -// todo: remove this and use AnalysisTest instead. 
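With the shared `AnalysisSuite` helper object deleted below, error-checking suites now extend `AnalysisTest` and use its `assertAnalysisSuccess`/`assertAnalysisError` helpers together with the relations collected in `TestRelations`. A rough sketch of the resulting test pattern (the suite and column names here are illustrative, not taken from the patch):

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._

class ExampleAnalysisSuite extends AnalysisTest {
  import TestRelations._

  test("unresolved column is reported with a useful message") {
    // The helper runs the analyzer plus checkAnalysis, then matches the AnalysisException
    // message case-insensitively against every expected fragment.
    assertAnalysisError(testRelation.select('nonExistentColumn), "cannot resolve" :: Nil)
  }
}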
-object AnalysisSuite { - val caseSensitiveConf = new SimpleCatalystConf(true) - val caseInsensitiveConf = new SimpleCatalystConf(false) - - val caseSensitiveCatalog = new SimpleCatalog(caseSensitiveConf) - val caseInsensitiveCatalog = new SimpleCatalog(caseInsensitiveConf) - - val caseSensitiveAnalyzer = - new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitiveConf) { - override val extendedResolutionRules = EliminateSubQueries :: Nil - } - val caseInsensitiveAnalyzer = - new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseInsensitiveConf) { - override val extendedResolutionRules = EliminateSubQueries :: Nil - } - - def caseSensitiveAnalyze(plan: LogicalPlan): Unit = - caseSensitiveAnalyzer.checkAnalysis(caseSensitiveAnalyzer.execute(plan)) - - def caseInsensitiveAnalyze(plan: LogicalPlan): Unit = - caseInsensitiveAnalyzer.checkAnalysis(caseInsensitiveAnalyzer.execute(plan)) - - val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)()) - val testRelation2 = LocalRelation( - AttributeReference("a", StringType)(), - AttributeReference("b", StringType)(), - AttributeReference("c", DoubleType)(), - AttributeReference("d", DecimalType(10, 2))(), - AttributeReference("e", ShortType)()) - - val nestedRelation = LocalRelation( - AttributeReference("top", StructType( - StructField("duplicateField", StringType) :: - StructField("duplicateField", StringType) :: - StructField("differentCase", StringType) :: - StructField("differentcase", StringType) :: Nil - ))()) - - val nestedRelation2 = LocalRelation( - AttributeReference("top", StructType( - StructField("aField", StringType) :: - StructField("bField", StringType) :: - StructField("cField", StringType) :: Nil - ))()) - - val listRelation = LocalRelation( - AttributeReference("list", ArrayType(IntegerType))()) - - caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) - caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) -} - class AnalysisSuite extends AnalysisTest { + import org.apache.spark.sql.catalyst.analysis.TestRelations._ test("union project *") { val plan = (1 to 100) @@ -98,7 +46,7 @@ class AnalysisSuite extends AnalysisTest { val explode = Explode(AttributeReference("a", IntegerType, nullable = true)()) assert(!Project(Seq(Alias(explode, "explode")()), testRelation).resolved) - assert(!Project(Seq(Alias(Count(Literal(1)), "count")()), testRelation).resolved) + assert(!Project(Seq(Alias(count(Literal(1)), "count")()), testRelation).resolved) } test("analyze project") { @@ -107,32 +55,39 @@ class AnalysisSuite extends AnalysisTest { Project(testRelation.output, testRelation)) checkAnalysis( - Project(Seq(UnresolvedAttribute("TbL.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))), + Project(Seq(UnresolvedAttribute("TbL.a")), + UnresolvedRelation(TableIdentifier("TaBlE"), Some("TbL"))), Project(testRelation.output, testRelation)) assertAnalysisError( - Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))), + Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation( + TableIdentifier("TaBlE"), Some("TbL"))), Seq("cannot resolve")) checkAnalysis( - Project(Seq(UnresolvedAttribute("TbL.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))), + Project(Seq(UnresolvedAttribute("TbL.a")), UnresolvedRelation( + TableIdentifier("TaBlE"), Some("TbL"))), Project(testRelation.output, testRelation), caseSensitive = false) checkAnalysis( - Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))), + 
Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation( + TableIdentifier("TaBlE"), Some("TbL"))), Project(testRelation.output, testRelation), caseSensitive = false) } test("resolve relations") { - assertAnalysisError(UnresolvedRelation(Seq("tAbLe"), None), Seq("Table Not Found: tAbLe")) + assertAnalysisError( + UnresolvedRelation(TableIdentifier("tAbLe"), None), Seq("Table not found: tAbLe")) - checkAnalysis(UnresolvedRelation(Seq("TaBlE"), None), testRelation) + checkAnalysis(UnresolvedRelation(TableIdentifier("TaBlE"), None), testRelation) - checkAnalysis(UnresolvedRelation(Seq("tAbLe"), None), testRelation, caseSensitive = false) + checkAnalysis( + UnresolvedRelation(TableIdentifier("tAbLe"), None), testRelation, caseSensitive = false) - checkAnalysis(UnresolvedRelation(Seq("TaBlE"), None), testRelation, caseSensitive = false) + checkAnalysis( + UnresolvedRelation(TableIdentifier("TaBlE"), None), testRelation, caseSensitive = false) } test("divide should be casted into fractional types") { @@ -149,7 +104,7 @@ class AnalysisSuite extends AnalysisTest { assert(pl(1).dataType == DoubleType) assert(pl(2).dataType == DoubleType) // StringType will be promoted into Decimal(38, 18) - assert(pl(3).dataType == DecimalType(38, 29)) + assert(pl(3).dataType == DecimalType(38, 22)) assert(pl(4).dataType == DoubleType) } @@ -165,39 +120,158 @@ class AnalysisSuite extends AnalysisTest { test("pull out nondeterministic expressions from Sort") { val plan = Sort(Seq(SortOrder(Rand(33), Ascending)), false, testRelation) - val analyzed = caseSensitiveAnalyzer.execute(plan) - analyzed.transform { - case s: Sort if s.expressions.exists(!_.deterministic) => - fail("nondeterministic expressions are not allowed in Sort") - } + val projected = Alias(Rand(33), "_nondeterministic")() + val expected = + Project(testRelation.output, + Sort(Seq(SortOrder(projected.toAttribute, Ascending)), false, + Project(testRelation.output :+ projected, testRelation))) + checkAnalysis(plan, expected) } - test("remove still-need-evaluate ordering expressions from sort") { + test("SPARK-9634: cleanup unnecessary Aliases in LogicalPlan") { + val a = testRelation.output.head + var plan = testRelation.select(((a + 1).as("a+1") + 2).as("col")) + var expected = testRelation.select((a + 1 + 2).as("col")) + checkAnalysis(plan, expected) + + plan = testRelation.groupBy(a.as("a1").as("a2"))((min(a).as("min_a") + 1).as("col")) + expected = testRelation.groupBy(a)((min(a) + 1).as("col")) + checkAnalysis(plan, expected) + + // CreateStruct is a special case that we should not trim Alias for it. + plan = testRelation.select(CreateStruct(Seq(a, (a + 1).as("a+1"))).as("col")) + checkAnalysis(plan, plan) + plan = testRelation.select(CreateStructUnsafe(Seq(a, (a + 1).as("a+1"))).as("col")) + checkAnalysis(plan, plan) + } + + test("SPARK-10534: resolve attribute references in order by clause") { val a = testRelation2.output(0) - val b = testRelation2.output(1) + val c = testRelation2.output(2) + + val plan = testRelation2.select('c).orderBy(Floor('a).asc) + val expected = testRelation2.select(c, a).orderBy(Floor(a.cast(DoubleType)).asc).select(c) - def makeOrder(e: Expression): SortOrder = SortOrder(e, Ascending) + checkAnalysis(plan, expected) + } - val noEvalOrdering = makeOrder(a) - val noEvalOrderingWithAlias = makeOrder(Alias(Alias(b, "name1")(), "name2")()) + test("SPARK-8654: invalid CAST in NULL IN(...) 
expression") { + val plan = Project(Alias(In(Literal(null), Seq(Literal(1), Literal(2))), "a")() :: Nil, + LocalRelation() + ) + assertAnalysisSuccess(plan) + } - val needEvalExpr = Coalesce(Seq(a, Literal("1"))) - val needEvalExpr2 = Coalesce(Seq(a, b)) - val needEvalOrdering = makeOrder(needEvalExpr) - val needEvalOrdering2 = makeOrder(needEvalExpr2) + test("SPARK-8654: different types in inlist but can be converted to a commmon type") { + val plan = Project(Alias(In(Literal(null), Seq(Literal(1), Literal(1.2345))), "a")() :: Nil, + LocalRelation() + ) + assertAnalysisSuccess(plan) + } - val plan = Sort( - Seq(noEvalOrdering, noEvalOrderingWithAlias, needEvalOrdering, needEvalOrdering2), - false, testRelation2) + test("SPARK-8654: check type compatibility error") { + val plan = Project(Alias(In(Literal(null), Seq(Literal(true), Literal(1))), "a")() :: Nil, + LocalRelation() + ) + assertAnalysisError(plan, Seq("data type mismatch: Arguments must be same type")) + } - val evaluatedOrdering = makeOrder(AttributeReference("_sortCondition", StringType)()) - val materializedExprs = Seq(needEvalExpr, needEvalExpr2).map(e => Alias(e, "_sortCondition")()) + test("SPARK-11725: correctly handle null inputs for ScalaUDF") { + val string = testRelation2.output(0) + val double = testRelation2.output(2) + val short = testRelation2.output(4) + val nullResult = Literal.create(null, StringType) + + def checkUDF(udf: Expression, transformed: Expression): Unit = { + checkAnalysis( + Project(Alias(udf, "")() :: Nil, testRelation2), + Project(Alias(transformed, "")() :: Nil, testRelation2) + ) + } - val expected = - Project(testRelation2.output, - Sort(Seq(makeOrder(a), makeOrder(b), evaluatedOrdering, evaluatedOrdering), false, - Project(testRelation2.output ++ materializedExprs, testRelation2))) + // non-primitive parameters do not need special null handling + val udf1 = ScalaUDF((s: String) => "x", StringType, string :: Nil) + val expected1 = udf1 + checkUDF(udf1, expected1) + + // only primitive parameter needs special null handling + val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil) + val expected2 = If(IsNull(double), nullResult, udf2) + checkUDF(udf2, expected2) + + // special null handling should apply to all primitive parameters + val udf3 = ScalaUDF((s: Short, d: Double) => "x", StringType, short :: double :: Nil) + val expected3 = If( + IsNull(short) || IsNull(double), + nullResult, + udf3) + checkUDF(udf3, expected3) + + // we can skip special null handling for primitive parameters that are not nullable + // TODO: this is disabled for now as we can not completely trust `nullable`. 
+ val udf4 = ScalaUDF( + (s: Short, d: Double) => "x", + StringType, + short :: double.withNullability(false) :: Nil) + val expected4 = If( + IsNull(short), + nullResult, + udf4) + // checkUDF(udf4, expected4) + } + test("SPARK-11863 mixture of aliases and real columns in order by clause - tpcds 19,55,71") { + val a = testRelation2.output(0) + val c = testRelation2.output(2) + val alias1 = a.as("a1") + val alias2 = c.as("a2") + val alias3 = count(a).as("a3") + + val plan = testRelation2 + .groupBy('a, 'c)('a.as("a1"), 'c.as("a2"), count('a).as("a3")) + .orderBy('a1.asc, 'c.asc) + + val expected = testRelation2 + .groupBy(a, c)(alias1, alias2, alias3) + .orderBy(alias1.toAttribute.asc, alias2.toAttribute.asc) + .select(alias1.toAttribute, alias2.toAttribute, alias3.toAttribute) checkAnalysis(plan, expected) } + + test("analyzer should replace current_timestamp with literals") { + val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()), + LocalRelation()) + + val min = System.currentTimeMillis() * 1000 + val plan = in.analyze.asInstanceOf[Project] + val max = (System.currentTimeMillis() + 1) * 1000 + + val lits = new scala.collection.mutable.ArrayBuffer[Long] + plan.transformAllExpressions { case e: Literal => + lits += e.value.asInstanceOf[Long] + e + } + assert(lits.size == 2) + assert(lits(0) >= min && lits(0) <= max) + assert(lits(1) >= min && lits(1) <= max) + assert(lits(0) == lits(1)) + } + + test("analyzer should replace current_date with literals") { + val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation()) + + val min = DateTimeUtils.millisToDays(System.currentTimeMillis()) + val plan = in.analyze.asInstanceOf[Project] + val max = DateTimeUtils.millisToDays(System.currentTimeMillis()) + + val lits = new scala.collection.mutable.ArrayBuffer[Int] + plan.transformAllExpressions { case e: Literal => + lits += e.value.asInstanceOf[Int] + e + } + assert(lits.size == 2) + assert(lits(0) >= min && lits(0) <= max) + assert(lits(1) >= min && lits(1) <= max) + assert(lits(0) == lits(1)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index fdb4f28950da..23861ed15da6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -18,39 +18,11 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.SimpleCatalystConf -import org.apache.spark.sql.types._ +import org.apache.spark.sql.catalyst.{TableIdentifier, SimpleCatalystConf} trait AnalysisTest extends PlanTest { - val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)()) - - val testRelation2 = LocalRelation( - AttributeReference("a", StringType)(), - AttributeReference("b", StringType)(), - AttributeReference("c", DoubleType)(), - AttributeReference("d", DecimalType(10, 2))(), - AttributeReference("e", ShortType)()) - - val nestedRelation = LocalRelation( - AttributeReference("top", StructType( - StructField("duplicateField", StringType) :: - StructField("duplicateField", StringType) :: - StructField("differentCase", StringType) :: - 
StructField("differentcase", StringType) :: Nil - ))()) - - val nestedRelation2 = LocalRelation( - AttributeReference("top", StructType( - StructField("aField", StringType) :: - StructField("bField", StringType) :: - StructField("cField", StringType) :: Nil - ))()) - - val listRelation = LocalRelation( - AttributeReference("list", ArrayType(IntegerType))()) val (caseSensitiveAnalyzer, caseInsensitiveAnalyzer) = { val caseSensitiveConf = new SimpleCatalystConf(true) @@ -59,8 +31,8 @@ trait AnalysisTest extends PlanTest { val caseSensitiveCatalog = new SimpleCatalog(caseSensitiveConf) val caseInsensitiveCatalog = new SimpleCatalog(caseInsensitiveConf) - caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) - caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) + caseSensitiveCatalog.registerTable(TableIdentifier("TaBlE"), TestRelations.testRelation) + caseInsensitiveCatalog.registerTable(TableIdentifier("TaBlE"), TestRelations.testRelation) new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitiveConf) { override val extendedResolutionRules = EliminateSubQueries :: Nil @@ -96,10 +68,11 @@ trait AnalysisTest extends PlanTest { expectedErrors: Seq[String], caseSensitive: Boolean = true): Unit = { val analyzer = getAnalyzer(caseSensitive) - // todo: make sure we throw AnalysisException during analysis - val e = intercept[Exception] { + val e = intercept[AnalysisException] { analyzer.checkAnalysis(analyzer.execute(inputPlan)) } - expectedErrors.forall(e.getMessage.contains) + assert(expectedErrors.map(_.toLowerCase).forall(e.getMessage.toLowerCase.contains), + s"Expected to throw Exception contains: ${expectedErrors.mkString(", ")}, " + + s"actually we get ${e.getMessage}") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index fc11627da6fd..fed591fd90a9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -21,9 +21,10 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.{Union, Project, LocalRelation} import org.apache.spark.sql.types._ -import org.apache.spark.sql.catalyst.SimpleCatalystConf +import org.apache.spark.sql.catalyst.{TableIdentifier, SimpleCatalystConf} class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { val conf = new SimpleCatalystConf(true) @@ -47,7 +48,7 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { val b: Expression = UnresolvedAttribute("b") before { - catalog.registerTable(Seq("table"), relation) + catalog.registerTable(TableIdentifier("table"), relation) } private def checkType(expression: Expression, expectedType: DataType): Unit = { @@ -136,10 +137,10 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { checkType(Multiply(i, u), DecimalType(38, 18)) checkType(Multiply(u, u), DecimalType(38, 36)) - checkType(Divide(u, d1), DecimalType(38, 21)) - checkType(Divide(u, d2), DecimalType(38, 24)) - checkType(Divide(u, i), DecimalType(38, 29)) - checkType(Divide(u, u), DecimalType(38, 38)) + checkType(Divide(u, d1), DecimalType(38, 18)) + checkType(Divide(u, d2), 
DecimalType(38, 19)) + checkType(Divide(u, i), DecimalType(38, 23)) + checkType(Divide(u, u), DecimalType(38, 18)) checkType(Remainder(d1, u), DecimalType(19, 18)) checkType(Remainder(d2, u), DecimalType(21, 18)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index c9bcc68f0203..915c585ec91f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -22,8 +22,9 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.types.{TypeCollection, StringType} +import org.apache.spark.sql.types.{LongType, TypeCollection, StringType} class ExpressionTypeCheckingSuite extends SparkFunSuite { @@ -31,7 +32,9 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { 'intField.int, 'stringField.string, 'booleanField.boolean, - 'complexField.array(StringType)) + 'decimalField.decimal(8, 0), + 'arrayField.array(StringType), + 'mapField.map(StringType, LongType)) def assertError(expr: Expression, errorMessage: String): Unit = { val e = intercept[AnalysisException] { @@ -89,9 +92,9 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertError(BitwiseOr('booleanField, 'booleanField), "requires integral type") assertError(BitwiseXor('booleanField, 'booleanField), "requires integral type") - assertError(MaxOf('complexField, 'complexField), + assertError(MaxOf('mapField, 'mapField), s"requires ${TypeCollection.Ordered.simpleString} type") - assertError(MinOf('complexField, 'complexField), + assertError(MinOf('mapField, 'mapField), s"requires ${TypeCollection.Ordered.simpleString} type") } @@ -108,20 +111,20 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertSuccess(EqualTo('intField, 'booleanField)) assertSuccess(EqualNullSafe('intField, 'booleanField)) - assertErrorForDifferingTypes(EqualTo('intField, 'complexField)) - assertErrorForDifferingTypes(EqualNullSafe('intField, 'complexField)) + assertErrorForDifferingTypes(EqualTo('intField, 'mapField)) + assertErrorForDifferingTypes(EqualNullSafe('intField, 'mapField)) assertErrorForDifferingTypes(LessThan('intField, 'booleanField)) assertErrorForDifferingTypes(LessThanOrEqual('intField, 'booleanField)) assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField)) assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField)) - assertError(LessThan('complexField, 'complexField), + assertError(LessThan('mapField, 'mapField), s"requires ${TypeCollection.Ordered.simpleString} type") - assertError(LessThanOrEqual('complexField, 'complexField), + assertError(LessThanOrEqual('mapField, 'mapField), s"requires ${TypeCollection.Ordered.simpleString} type") - assertError(GreaterThan('complexField, 'complexField), + assertError(GreaterThan('mapField, 'mapField), s"requires ${TypeCollection.Ordered.simpleString} type") - assertError(GreaterThanOrEqual('complexField, 'complexField), + assertError(GreaterThanOrEqual('mapField, 'mapField), s"requires ${TypeCollection.Ordered.simpleString} type") 
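// Annotation (not part of the patch): the old 'complexField in this fixture has been split into
// 'arrayField and 'mapField, with 'decimalField added alongside. The map column is now the
// stand-in for an un-orderable type in these comparison checks, since MapType defines no
// ordering, while the new 'arrayField and 'decimalField are exercised by the Min and
// Greatest/Least assertions added further down.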
assertError(If('intField, 'stringField, 'stringField), @@ -129,10 +132,10 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertErrorForDifferingTypes(If('booleanField, 'intField, 'booleanField)) assertError( - CaseWhen(Seq('booleanField, 'intField, 'booleanField, 'complexField)), + CaseWhen(Seq('booleanField, 'intField, 'booleanField, 'mapField)), "THEN and ELSE expressions should all be same type or coercible to a common type") assertError( - CaseKeyWhen('intField, Seq('intField, 'stringField, 'intField, 'complexField)), + CaseKeyWhen('intField, Seq('intField, 'stringField, 'intField, 'mapField)), "THEN and ELSE expressions should all be same type or coercible to a common type") assertError( CaseWhen(Seq('booleanField, 'intField, 'intField, 'intField)), @@ -140,15 +143,17 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { } test("check types for aggregates") { + // We use AggregateFunction directly at here because the error will be thrown from it + // instead of from AggregateExpression, which is the wrapper of an AggregateFunction. + // We will cast String to Double for sum and average assertSuccess(Sum('stringField)) - assertSuccess(SumDistinct('stringField)) assertSuccess(Average('stringField)) + assertSuccess(Min('arrayField)) - assertError(Min('complexField), "min does not support ordering on type") - assertError(Max('complexField), "max does not support ordering on type") + assertError(Min('mapField), "min does not support ordering on type") + assertError(Max('mapField), "max does not support ordering on type") assertError(Sum('booleanField), "function sum requires numeric type") - assertError(SumDistinct('booleanField), "function sumDistinct requires numeric type") assertError(Average('booleanField), "function average requires numeric type") } @@ -182,7 +187,16 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertError(Round('intField, 'intField), "Only foldable Expression is allowed") assertError(Round('intField, 'booleanField), "requires int type") - assertError(Round('intField, 'complexField), "requires int type") + assertError(Round('intField, 'mapField), "requires int type") assertError(Round('booleanField, 'intField), "requires numeric type") } + + test("check types for Greatest/Least") { + for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) { + assertError(operator(Seq('booleanField)), "requires at least 2 arguments") + assertError(operator(Seq('intField, 'stringField)), "should all have the same type") + assertError(operator(Seq('intField, 'decimalField)), "should all have the same type") + assertError(operator(Seq('mapField, 'mapField)), "does not support ordering") + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index cbdf453f600a..142915056f45 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -251,6 +251,29 @@ class HiveTypeCoercionSuite extends PlanTest { :: Nil)) } + test("greatest/least cast") { + for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) { + ruleTest(HiveTypeCoercion.FunctionArgumentConversion, + operator(Literal(1.0) + :: Literal(1) + :: Literal.create(1.0, FloatType) + :: Nil), + operator(Cast(Literal(1.0), DoubleType) + :: Cast(Literal(1), DoubleType) + 
:: Cast(Literal.create(1.0, FloatType), DoubleType) + :: Nil)) + ruleTest(HiveTypeCoercion.FunctionArgumentConversion, + operator(Literal(1L) + :: Literal(1) + :: Literal(new java.math.BigDecimal("1000000000000000000000")) + :: Nil), + operator(Cast(Literal(1L), DecimalType(22, 0)) + :: Cast(Literal(1), DecimalType(22, 0)) + :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0)) + :: Nil)) + } + } + test("nanvl casts") { ruleTest(HiveTypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0, FloatType), Literal.create(1.0, DoubleType)), @@ -285,6 +308,17 @@ class HiveTypeCoercionSuite extends PlanTest { CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))) ) + ruleTest(HiveTypeCoercion.CaseWhenCoercion, + CaseWhen(Seq(Literal(true), Literal(1.2), Literal.create(1, DecimalType(7, 2)))), + CaseWhen(Seq( + Literal(true), Literal(1.2), Cast(Literal.create(1, DecimalType(7, 2)), DoubleType))) + ) + ruleTest(HiveTypeCoercion.CaseWhenCoercion, + CaseWhen(Seq(Literal(true), Literal(100L), Literal.create(1, DecimalType(7, 2)))), + CaseWhen(Seq( + Literal(true), Cast(Literal(100L), DecimalType(22, 2)), + Cast(Literal.create(1, DecimalType(7, 2)), DecimalType(22, 2)))) + ) } test("type coercion simplification for equal to") { @@ -459,7 +493,8 @@ class HiveTypeCoercionSuite extends PlanTest { ) ruleTest(inConversion, In(Literal("a"), Seq(Literal(1), Literal("b"))), - In(Literal("a"), Seq(Cast(Literal(1), StringType), Cast(Literal("b"), StringType))) + In(Cast(Literal("a"), StringType), + Seq(Cast(Literal(1), StringType), Cast(Literal("b"), StringType))) ) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala new file mode 100644 index 000000000000..bc07b609a341 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.types._ + +object TestRelations { + val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)()) + + val testRelation2 = LocalRelation( + AttributeReference("a", StringType)(), + AttributeReference("b", StringType)(), + AttributeReference("c", DoubleType)(), + AttributeReference("d", DecimalType(10, 2))(), + AttributeReference("e", ShortType)()) + + val nestedRelation = LocalRelation( + AttributeReference("top", StructType( + StructField("duplicateField", StringType) :: + StructField("duplicateField", StringType) :: + StructField("differentCase", StringType) :: + StructField("differentcase", StringType) :: Nil + ))()) + + val nestedRelation2 = LocalRelation( + AttributeReference("top", StructType( + StructField("aField", StringType) :: + StructField("bField", StringType) :: + StructField("cField", StringType) :: Nil + ))()) + + val listRelation = LocalRelation( + AttributeReference("list", ArrayType(IntegerType))()) + + val mapRelation = LocalRelation( + AttributeReference("map", MapType(IntegerType, IntegerType))()) +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala new file mode 100644 index 000000000000..8c766ef82992 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import scala.reflect.ClassTag + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Encoders + +class NonEncodable(i: Int) + +case class ComplexNonEncodable1(name1: NonEncodable) + +case class ComplexNonEncodable2(name2: ComplexNonEncodable1) + +case class ComplexNonEncodable3(name3: Option[NonEncodable]) + +case class ComplexNonEncodable4(name4: Array[NonEncodable]) + +case class ComplexNonEncodable5(name5: Option[Array[NonEncodable]]) + +class EncoderErrorMessageSuite extends SparkFunSuite { + + // Note: we also test error messages for encoders for private classes in JavaDatasetSuite. + // That is done in Java because Scala cannot create truly private classes. 
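A brief note before the checks below: `Encoders.kryo` and `Encoders.javaSerialization` are intended for arbitrary reference types and are rejected for primitives, which is exactly what the first two tests pin down. A minimal usage sketch (the `Payload` class is illustrative, not from the patch):

import org.apache.spark.sql.{Encoder, Encoders}

class Payload(val value: Int) extends Serializable  // any ordinary reference type

val kryoEncoder: Encoder[Payload] = Encoders.kryo[Payload]               // supported
val javaEncoder: Encoder[Payload] = Encoders.javaSerialization[Payload]  // supported
// Encoders.kryo[Int] would throw UnsupportedOperationException, as asserted below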
+ + test("primitive types in encoders using Kryo serialization") { + intercept[UnsupportedOperationException] { Encoders.kryo[Int] } + intercept[UnsupportedOperationException] { Encoders.kryo[Long] } + intercept[UnsupportedOperationException] { Encoders.kryo[Char] } + } + + test("primitive types in encoders using Java serialization") { + intercept[UnsupportedOperationException] { Encoders.javaSerialization[Int] } + intercept[UnsupportedOperationException] { Encoders.javaSerialization[Long] } + intercept[UnsupportedOperationException] { Encoders.javaSerialization[Char] } + } + + test("nice error message for missing encoder") { + val errorMsg1 = + intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable1]).getMessage + assert(errorMsg1.contains( + s"""root class: "${clsName[ComplexNonEncodable1]}"""")) + assert(errorMsg1.contains( + s"""field (class: "${clsName[NonEncodable]}", name: "name1")""")) + + val errorMsg2 = + intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable2]).getMessage + assert(errorMsg2.contains( + s"""root class: "${clsName[ComplexNonEncodable2]}"""")) + assert(errorMsg2.contains( + s"""field (class: "${clsName[ComplexNonEncodable1]}", name: "name2")""")) + assert(errorMsg1.contains( + s"""field (class: "${clsName[NonEncodable]}", name: "name1")""")) + + val errorMsg3 = + intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable3]).getMessage + assert(errorMsg3.contains( + s"""root class: "${clsName[ComplexNonEncodable3]}"""")) + assert(errorMsg3.contains( + s"""field (class: "scala.Option", name: "name3")""")) + assert(errorMsg3.contains( + s"""option value class: "${clsName[NonEncodable]}"""")) + + val errorMsg4 = + intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable4]).getMessage + assert(errorMsg4.contains( + s"""root class: "${clsName[ComplexNonEncodable4]}"""")) + assert(errorMsg4.contains( + s"""field (class: "scala.Array", name: "name4")""")) + assert(errorMsg4.contains( + s"""array element class: "${clsName[NonEncodable]}"""")) + + val errorMsg5 = + intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable5]).getMessage + assert(errorMsg5.contains( + s"""root class: "${clsName[ComplexNonEncodable5]}"""")) + assert(errorMsg5.contains( + s"""field (class: "scala.Option", name: "name5")""")) + assert(errorMsg5.contains( + s"""option value class: "scala.Array"""")) + assert(errorMsg5.contains( + s"""array element class: "${clsName[NonEncodable]}"""")) + } + + private def clsName[T : ClassTag]: String = implicitly[ClassTag[T]].runtimeClass.getName +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala new file mode 100644 index 000000000000..815a03f7c1a8 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.types._ + +case class StringLongClass(a: String, b: Long) + +case class StringIntClass(a: String, b: Int) + +case class ComplexClass(a: Long, b: StringLongClass) + +class EncoderResolutionSuite extends PlanTest { + test("real type doesn't match encoder schema but they are compatible: product") { + val encoder = ExpressionEncoder[StringLongClass] + val cls = classOf[StringLongClass] + + { + val attrs = Seq('a.string, 'b.int) + val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression + val expected: Expression = NewInstance( + cls, + toExternalString('a.string) :: 'b.int.cast(LongType) :: Nil, + false, + ObjectType(cls)) + compareExpressions(fromRowExpr, expected) + } + + { + val attrs = Seq('a.int, 'b.long) + val fromRowExpr = encoder.resolve(attrs, null).fromRowExpression + val expected = NewInstance( + cls, + toExternalString('a.int.cast(StringType)) :: 'b.long :: Nil, + false, + ObjectType(cls)) + compareExpressions(fromRowExpr, expected) + } + } + + test("real type doesn't match encoder schema but they are compatible: nested product") { + val encoder = ExpressionEncoder[ComplexClass] + val innerCls = classOf[StringLongClass] + val cls = classOf[ComplexClass] + + val attrs = Seq('a.int, 'b.struct('a.int, 'b.long)) + val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression + val expected: Expression = NewInstance( + cls, + Seq( + 'a.int.cast(LongType), + If( + 'b.struct('a.int, 'b.long).isNull, + Literal.create(null, ObjectType(innerCls)), + NewInstance( + innerCls, + Seq( + toExternalString( + GetStructField('b.struct('a.int, 'b.long), 0, Some("a")).cast(StringType)), + GetStructField('b.struct('a.int, 'b.long), 1, Some("b"))), + false, + ObjectType(innerCls)) + )), + false, + ObjectType(cls)) + compareExpressions(fromRowExpr, expected) + } + + test("real type doesn't match encoder schema but they are compatible: tupled encoder") { + val encoder = ExpressionEncoder.tuple( + ExpressionEncoder[StringLongClass], + ExpressionEncoder[Long]) + val cls = classOf[StringLongClass] + + val attrs = Seq('a.struct('a.string, 'b.byte), 'b.int) + val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression + val expected: Expression = NewInstance( + classOf[Tuple2[_, _]], + Seq( + NewInstance( + cls, + Seq( + toExternalString(GetStructField('a.struct('a.string, 'b.byte), 0, Some("a"))), + GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType)), + false, + ObjectType(cls)), + 'b.int.cast(LongType)), + false, + ObjectType(classOf[Tuple2[_, _]])) + compareExpressions(fromRowExpr, expected) + } + + test("the real number of fields doesn't match encoder schema: tuple encoder") { + val encoder = ExpressionEncoder[(String, Long)] + + { + val attrs = Seq('a.string, 'b.long, 'c.int) + 
assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message == + "Try to map struct to Tuple2, " + + "but failed as the number of fields does not line up.\n" + + " - Input schema: struct\n" + + " - Target schema: struct<_1:string,_2:bigint>") + } + + { + val attrs = Seq('a.string) + assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message == + "Try to map struct to Tuple2, " + + "but failed as the number of fields does not line up.\n" + + " - Input schema: struct\n" + + " - Target schema: struct<_1:string,_2:bigint>") + } + } + + test("the real number of fields doesn't match encoder schema: nested tuple encoder") { + val encoder = ExpressionEncoder[(String, (Long, String))] + + { + val attrs = Seq('a.string, 'b.struct('x.long, 'y.string, 'z.int)) + assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message == + "Try to map struct to Tuple2, " + + "but failed as the number of fields does not line up.\n" + + " - Input schema: struct>\n" + + " - Target schema: struct<_1:string,_2:struct<_1:bigint,_2:string>>") + } + + { + val attrs = Seq('a.string, 'b.struct('x.long)) + assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message == + "Try to map struct to Tuple2, " + + "but failed as the number of fields does not line up.\n" + + " - Input schema: struct>\n" + + " - Target schema: struct<_1:string,_2:struct<_1:bigint,_2:string>>") + } + } + + private def toExternalString(e: Expression): Expression = { + Invoke(e, "toString", ObjectType(classOf[String]), Nil) + } + + test("throw exception if real type is not compatible with encoder schema") { + val msg1 = intercept[AnalysisException] { + ExpressionEncoder[StringIntClass].resolve(Seq('a.string, 'b.long), null) + }.message + assert(msg1 == + s""" + |Cannot up cast `b` from bigint to int as it may truncate + |The type path of the target object is: + |- field (class: "scala.Int", name: "b") + |- root class: "org.apache.spark.sql.catalyst.encoders.StringIntClass" + |You can either add an explicit cast to the input data or choose a higher precision type + """.stripMargin.trim + " of the field in the target object") + + val msg2 = intercept[AnalysisException] { + val structType = new StructType().add("a", StringType).add("b", DecimalType.SYSTEM_DEFAULT) + ExpressionEncoder[ComplexClass].resolve(Seq('a.long, 'b.struct(structType)), null) + }.message + assert(msg2 == + s""" + |Cannot up cast `b.b` from decimal(38,18) to bigint as it may truncate + |The type path of the target object is: + |- field (class: "scala.Long", name: "b") + |- field (class: "org.apache.spark.sql.catalyst.encoders.StringLongClass", name: "b") + |- root class: "org.apache.spark.sql.catalyst.encoders.ComplexClass" + |You can either add an explicit cast to the input data or choose a higher precision type + """.stripMargin.trim + " of the field in the target object") + } + + // test for leaf types + castSuccess[Int, Long] + castSuccess[java.sql.Date, java.sql.Timestamp] + castSuccess[Long, String] + castSuccess[Int, java.math.BigDecimal] + castSuccess[Long, java.math.BigDecimal] + + castFail[Long, Int] + castFail[java.sql.Timestamp, java.sql.Date] + castFail[java.math.BigDecimal, Double] + castFail[Double, java.math.BigDecimal] + castFail[java.math.BigDecimal, Int] + castFail[String, Long] + + + private def castSuccess[T: TypeTag, U: TypeTag]: Unit = { + val from = ExpressionEncoder[T] + val to = ExpressionEncoder[U] + val catalystType = from.schema.head.dataType.simpleString + test(s"cast from $catalystType to 
${implicitly[TypeTag[U]].tpe} should success") { + to.resolve(from.schema.toAttributes, null) + } + } + + private def castFail[T: TypeTag, U: TypeTag]: Unit = { + val from = ExpressionEncoder[T] + val to = ExpressionEncoder[U] + val catalystType = from.schema.head.dataType.simpleString + test(s"cast from $catalystType to ${implicitly[TypeTag[U]].tpe} should fail") { + intercept[AnalysisException](to.resolve(from.schema.toAttributes, null)) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala new file mode 100644 index 000000000000..7233e0f1b5ba --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -0,0 +1,347 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import java.sql.{Timestamp, Date} +import java.util.Arrays +import java.util.concurrent.ConcurrentMap +import scala.collection.mutable.ArrayBuffer +import scala.reflect.runtime.universe.TypeTag + +import com.google.common.collect.MapMaker + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Encoders +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} +import org.apache.spark.sql.types.{StructType, ArrayType} + +case class RepeatedStruct(s: Seq[PrimitiveData]) + +case class NestedArray(a: Array[Array[Int]]) { + override def equals(other: Any): Boolean = other match { + case NestedArray(otherArray) => + java.util.Arrays.deepEquals( + a.asInstanceOf[Array[AnyRef]], + otherArray.asInstanceOf[Array[AnyRef]]) + case _ => false + } +} + +case class BoxedData( + intField: java.lang.Integer, + longField: java.lang.Long, + doubleField: java.lang.Double, + floatField: java.lang.Float, + shortField: java.lang.Short, + byteField: java.lang.Byte, + booleanField: java.lang.Boolean) + +case class RepeatedData( + arrayField: Seq[Int], + arrayFieldContainsNull: Seq[java.lang.Integer], + mapField: scala.collection.Map[Int, Long], + mapFieldNull: scala.collection.Map[Int, java.lang.Long], + structField: PrimitiveData) + +case class SpecificCollection(l: List[Int]) + +/** For testing Kryo serialization based encoder. */ +class KryoSerializable(val value: Int) { + override def equals(other: Any): Boolean = { + this.value == other.asInstanceOf[KryoSerializable].value + } +} + +/** For testing Java serialization based encoder. 
*/ +class JavaSerializable(val value: Int) extends Serializable { + override def equals(other: Any): Boolean = { + this.value == other.asInstanceOf[JavaSerializable].value + } +} + +class ExpressionEncoderSuite extends SparkFunSuite { + implicit def encoder[T : TypeTag]: ExpressionEncoder[T] = ExpressionEncoder() + + // test flat encoders + encodeDecodeTest(false, "primitive boolean") + encodeDecodeTest(-3.toByte, "primitive byte") + encodeDecodeTest(-3.toShort, "primitive short") + encodeDecodeTest(-3, "primitive int") + encodeDecodeTest(-3L, "primitive long") + encodeDecodeTest(-3.7f, "primitive float") + encodeDecodeTest(-3.7, "primitive double") + + encodeDecodeTest(new java.lang.Boolean(false), "boxed boolean") + encodeDecodeTest(new java.lang.Byte(-3.toByte), "boxed byte") + encodeDecodeTest(new java.lang.Short(-3.toShort), "boxed short") + encodeDecodeTest(new java.lang.Integer(-3), "boxed int") + encodeDecodeTest(new java.lang.Long(-3L), "boxed long") + encodeDecodeTest(new java.lang.Float(-3.7f), "boxed float") + encodeDecodeTest(new java.lang.Double(-3.7), "boxed double") + + encodeDecodeTest(BigDecimal("32131413.211321313"), "scala decimal") + // encodeDecodeTest(new java.math.BigDecimal("231341.23123"), "java decimal") + + encodeDecodeTest("hello", "string") + encodeDecodeTest(Date.valueOf("2012-12-23"), "date") + encodeDecodeTest(Timestamp.valueOf("2016-01-29 10:00:00"), "timestamp") + encodeDecodeTest(Array[Byte](13, 21, -23), "binary") + + encodeDecodeTest(Seq(31, -123, 4), "seq of int") + encodeDecodeTest(Seq("abc", "xyz"), "seq of string") + encodeDecodeTest(Seq("abc", null, "xyz"), "seq of string with null") + encodeDecodeTest(Seq.empty[Int], "empty seq of int") + encodeDecodeTest(Seq.empty[String], "empty seq of string") + + encodeDecodeTest(Seq(Seq(31, -123), null, Seq(4, 67)), "seq of seq of int") + encodeDecodeTest(Seq(Seq("abc", "xyz"), Seq[String](null), null, Seq("1", null, "2")), + "seq of seq of string") + + encodeDecodeTest(Array(31, -123, 4), "array of int") + encodeDecodeTest(Array("abc", "xyz"), "array of string") + encodeDecodeTest(Array("a", null, "x"), "array of string with null") + encodeDecodeTest(Array.empty[Int], "empty array of int") + encodeDecodeTest(Array.empty[String], "empty array of string") + + encodeDecodeTest(Array(Array(31, -123), null, Array(4, 67)), "array of array of int") + encodeDecodeTest(Array(Array("abc", "xyz"), Array[String](null), null, Array("1", null, "2")), + "array of array of string") + + encodeDecodeTest(Map(1 -> "a", 2 -> "b"), "map") + encodeDecodeTest(Map(1 -> "a", 2 -> null), "map with null") + encodeDecodeTest(Map(1 -> Map("a" -> 1), 2 -> Map("b" -> 2)), "map of map") + + // Kryo encoders + encodeDecodeTest("hello", "kryo string")(encoderFor(Encoders.kryo[String])) + encodeDecodeTest(new KryoSerializable(15), "kryo object")( + encoderFor(Encoders.kryo[KryoSerializable])) + + // Java encoders + encodeDecodeTest("hello", "java string")(encoderFor(Encoders.javaSerialization[String])) + encodeDecodeTest(new JavaSerializable(15), "java object")( + encoderFor(Encoders.javaSerialization[JavaSerializable])) + + // test product encoders + private def productTest[T <: Product : ExpressionEncoder](input: T): Unit = { + encodeDecodeTest(input, input.getClass.getSimpleName) + } + + case class InnerClass(i: Int) + productTest(InnerClass(1)) + encodeDecodeTest(Array(InnerClass(1)), "array of inner class") + + productTest(PrimitiveData(1, 1, 1, 1, 1, 1, true)) + + productTest( + OptionalData(Some(2), Some(2), Some(2), Some(2), Some(2), 
Some(2), Some(true), + Some(PrimitiveData(1, 1, 1, 1, 1, 1, true)))) + + productTest(OptionalData(None, None, None, None, None, None, None, None)) + + productTest(BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true)) + + productTest(BoxedData(null, null, null, null, null, null, null)) + + productTest(RepeatedStruct(PrimitiveData(1, 1, 1, 1, 1, 1, true) :: Nil)) + + productTest((1, "test", PrimitiveData(1, 1, 1, 1, 1, 1, true))) + + productTest( + RepeatedData( + Seq(1, 2), + Seq(new Integer(1), null, new Integer(2)), + Map(1 -> 2L), + Map(1 -> null), + PrimitiveData(1, 1, 1, 1, 1, 1, true))) + + productTest(NestedArray(Array(Array(1, -2, 3), null, Array(4, 5, -6)))) + + productTest(("Seq[(String, String)]", + Seq(("a", "b")))) + productTest(("Seq[(Int, Int)]", + Seq((1, 2)))) + productTest(("Seq[(Long, Long)]", + Seq((1L, 2L)))) + productTest(("Seq[(Float, Float)]", + Seq((1.toFloat, 2.toFloat)))) + productTest(("Seq[(Double, Double)]", + Seq((1.toDouble, 2.toDouble)))) + productTest(("Seq[(Short, Short)]", + Seq((1.toShort, 2.toShort)))) + productTest(("Seq[(Byte, Byte)]", + Seq((1.toByte, 2.toByte)))) + productTest(("Seq[(Boolean, Boolean)]", + Seq((true, false)))) + + productTest(("ArrayBuffer[(String, String)]", + ArrayBuffer(("a", "b")))) + productTest(("ArrayBuffer[(Int, Int)]", + ArrayBuffer((1, 2)))) + productTest(("ArrayBuffer[(Long, Long)]", + ArrayBuffer((1L, 2L)))) + productTest(("ArrayBuffer[(Float, Float)]", + ArrayBuffer((1.toFloat, 2.toFloat)))) + productTest(("ArrayBuffer[(Double, Double)]", + ArrayBuffer((1.toDouble, 2.toDouble)))) + productTest(("ArrayBuffer[(Short, Short)]", + ArrayBuffer((1.toShort, 2.toShort)))) + productTest(("ArrayBuffer[(Byte, Byte)]", + ArrayBuffer((1.toByte, 2.toByte)))) + productTest(("ArrayBuffer[(Boolean, Boolean)]", + ArrayBuffer((true, false)))) + + productTest(("Seq[Seq[(Int, Int)]]", + Seq(Seq((1, 2))))) + + // test for ExpressionEncoder.tuple + encodeDecodeTest( + 1 -> 10L, + "tuple with 2 flat encoders")( + ExpressionEncoder.tuple(ExpressionEncoder[Int], ExpressionEncoder[Long])) + + encodeDecodeTest( + (PrimitiveData(1, 1, 1, 1, 1, 1, true), (3, 30L)), + "tuple with 2 product encoders")( + ExpressionEncoder.tuple(ExpressionEncoder[PrimitiveData], ExpressionEncoder[(Int, Long)])) + + encodeDecodeTest( + (PrimitiveData(1, 1, 1, 1, 1, 1, true), 3), + "tuple with flat encoder and product encoder")( + ExpressionEncoder.tuple(ExpressionEncoder[PrimitiveData], ExpressionEncoder[Int])) + + encodeDecodeTest( + (3, PrimitiveData(1, 1, 1, 1, 1, 1, true)), + "tuple with product encoder and flat encoder")( + ExpressionEncoder.tuple(ExpressionEncoder[Int], ExpressionEncoder[PrimitiveData])) + + encodeDecodeTest( + (1, (10, 100L)), + "nested tuple encoder") { + val intEnc = ExpressionEncoder[Int] + val longEnc = ExpressionEncoder[Long] + ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc)) + } + + test("nullable of encoder schema") { + def checkNullable[T: ExpressionEncoder](nullable: Boolean*): Unit = { + assert(implicitly[ExpressionEncoder[T]].schema.map(_.nullable) === nullable.toSeq) + } + + // test for flat encoders + checkNullable[Int](false) + checkNullable[Option[Int]](true) + checkNullable[java.lang.Integer](true) + checkNullable[String](true) + + // test for product encoders + checkNullable[(String, Int)](true, false) + checkNullable[(Int, java.lang.Long)](false, true) + + // test for nested product encoders + { + val schema = ExpressionEncoder[(Int, (String, Int))].schema + assert(schema(0).nullable === false) + 
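// The Int component of the outer tuple is non-nullable; the nested product is nullable as a whole, and within it the String field is nullable while the Int field is not. + 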
assert(schema(1).nullable === true) + assert(schema(1).dataType.asInstanceOf[StructType](0).nullable === true) + assert(schema(1).dataType.asInstanceOf[StructType](1).nullable === false) + } + + // test for tupled encoders + { + val schema = ExpressionEncoder.tuple( + ExpressionEncoder[Int], + ExpressionEncoder[(String, Int)]).schema + assert(schema(0).nullable === false) + assert(schema(1).nullable === true) + assert(schema(1).dataType.asInstanceOf[StructType](0).nullable === true) + assert(schema(1).dataType.asInstanceOf[StructType](1).nullable === false) + } + } + + private val outers: ConcurrentMap[String, AnyRef] = new MapMaker().weakValues().makeMap() + outers.put(getClass.getName, this) + private def encodeDecodeTest[T : ExpressionEncoder]( + input: T, + testName: String): Unit = { + test(s"encode/decode for $testName: $input") { + val encoder = implicitly[ExpressionEncoder[T]] + val row = encoder.toRow(input) + val schema = encoder.schema.toAttributes + val boundEncoder = encoder.resolve(schema, outers).bind(schema) + val convertedBack = try boundEncoder.fromRow(row) catch { + case e: Exception => + fail( + s"""Exception thrown while decoding + |Converted: $row + |Schema: ${schema.mkString(",")} + |${encoder.schema.treeString} + | + |Encoder: + |$boundEncoder + | + """.stripMargin, e) + } + + val isCorrect = (input, convertedBack) match { + case (b1: Array[Byte], b2: Array[Byte]) => Arrays.equals(b1, b2) + case (b1: Array[Int], b2: Array[Int]) => Arrays.equals(b1, b2) + case (b1: Array[Array[_]], b2: Array[Array[_]]) => + Arrays.deepEquals(b1.asInstanceOf[Array[AnyRef]], b2.asInstanceOf[Array[AnyRef]]) + case (b1: Array[_], b2: Array[_]) => + Arrays.equals(b1.asInstanceOf[Array[AnyRef]], b2.asInstanceOf[Array[AnyRef]]) + case _ => input == convertedBack + } + + if (!isCorrect) { + val types = convertedBack match { + case c: Product => + c.productIterator.filter(_ != null).map(_.getClass.getName).mkString(",") + case other => other.getClass.getName + } + + val encodedData = try { + row.toSeq(encoder.schema).zip(schema).map { + case (a: ArrayData, AttributeReference(_, ArrayType(et, _), _, _)) => + a.toArray[Any](et).toSeq + case (other, _) => + other + }.mkString("[", ",", "]") + } catch { + case e: Throwable => s"Failed to toSeq: $e" + } + + fail( + s"""Encoded/Decoded data does not match input data + | + |in: $input + |out: $convertedBack + |types: $types + | + |Encoded Data: $encodedData + |Schema: ${schema.mkString(",")} + |${encoder.schema.treeString} + | + |fromRow Expressions: + |${boundEncoder.fromRowExpression.treeString} + """.stripMargin) + } + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala new file mode 100644 index 000000000000..0ea51ece4bc5 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import scala.util.Random + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{RandomDataGenerator, Row} +import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +@SQLUserDefinedType(udt = classOf[ExamplePointUDT]) +class ExamplePoint(val x: Double, val y: Double) extends Serializable { + override def hashCode: Int = 41 * (41 + x.toInt) + y.toInt + override def equals(that: Any): Boolean = { + if (that.isInstanceOf[ExamplePoint]) { + val e = that.asInstanceOf[ExamplePoint] + (this.x == e.x || (this.x.isNaN && e.x.isNaN) || (this.x.isInfinity && e.x.isInfinity)) && + (this.y == e.y || (this.y.isNaN && e.y.isNaN) || (this.y.isInfinity && e.y.isInfinity)) + } else { + false + } + } +} + +/** + * User-defined type for [[ExamplePoint]]. + */ +class ExamplePointUDT extends UserDefinedType[ExamplePoint] { + + override def sqlType: DataType = ArrayType(DoubleType, false) + + override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT" + + override def serialize(obj: Any): GenericArrayData = { + obj match { + case p: ExamplePoint => + val output = new Array[Any](2) + output(0) = p.x + output(1) = p.y + new GenericArrayData(output) + } + } + + override def deserialize(datum: Any): ExamplePoint = { + datum match { + case values: ArrayData => + if (values.numElements() > 1) { + new ExamplePoint(values.getDouble(0), values.getDouble(1)) + } else { + val random = new Random() + new ExamplePoint(random.nextDouble(), random.nextDouble()) + } + } + } + + override def userClass: Class[ExamplePoint] = classOf[ExamplePoint] + + private[spark] override def asNullable: ExamplePointUDT = this +} + +class RowEncoderSuite extends SparkFunSuite { + + private val structOfString = new StructType().add("str", StringType) + private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false) + private val arrayOfString = ArrayType(StringType) + private val arrayOfNull = ArrayType(NullType) + private val mapOfString = MapType(StringType, StringType) + private val arrayOfUDT = ArrayType(new ExamplePointUDT, false) + + encodeDecodeTest( + new StructType() + .add("null", NullType) + .add("boolean", BooleanType) + .add("byte", ByteType) + .add("short", ShortType) + .add("int", IntegerType) + .add("long", LongType) + .add("float", FloatType) + .add("double", DoubleType) + .add("decimal", DecimalType.SYSTEM_DEFAULT) + .add("string", StringType) + .add("binary", BinaryType) + .add("date", DateType) + .add("timestamp", TimestampType) + .add("udt", new ExamplePointUDT, false)) + + encodeDecodeTest( + new StructType() + .add("arrayOfNull", arrayOfNull) + .add("arrayOfString", arrayOfString) + .add("arrayOfArrayOfString", ArrayType(arrayOfString)) + .add("arrayOfArrayOfInt", ArrayType(ArrayType(IntegerType))) + .add("arrayOfMap", ArrayType(mapOfString)) + .add("arrayOfStruct", ArrayType(structOfString))) + + encodeDecodeTest( + new StructType() + .add("mapOfIntAndString", MapType(IntegerType, StringType)) + 
.add("mapOfStringAndArray", MapType(StringType, arrayOfString)) + .add("mapOfArrayAndInt", MapType(arrayOfString, IntegerType)) + .add("mapOfArray", MapType(arrayOfString, arrayOfString)) + .add("mapOfStringAndStruct", MapType(StringType, structOfString)) + .add("mapOfStructAndString", MapType(structOfString, StringType)) + .add("mapOfStruct", MapType(structOfString, structOfString))) + + encodeDecodeTest( + new StructType() + .add("structOfString", structOfString) + .add("structOfStructOfString", new StructType().add("struct", structOfString)) + .add("structOfArray", new StructType().add("array", arrayOfString)) + .add("structOfMap", new StructType().add("map", mapOfString)) + .add("structOfArrayAndMap", + new StructType().add("array", arrayOfString).add("map", mapOfString)) + .add("structOfUDT", structOfUDT)) + + test(s"encode/decode: arrayOfUDT") { + val schema = new StructType() + .add("arrayOfUDT", arrayOfUDT) + + val encoder = RowEncoder(schema) + + val input: Row = Row(Seq(new ExamplePoint(0.1, 0.2), new ExamplePoint(0.3, 0.4))) + val row = encoder.toRow(input) + val convertedBack = encoder.fromRow(row) + assert(input.getSeq[ExamplePoint](0) == convertedBack.getSeq[ExamplePoint](0)) + } + + test(s"encode/decode: Product") { + val schema = new StructType() + .add("structAsProduct", + new StructType() + .add("int", IntegerType) + .add("string", StringType) + .add("double", DoubleType)) + + val encoder = RowEncoder(schema) + + val input: Row = Row((100, "test", 0.123)) + val row = encoder.toRow(input) + val convertedBack = encoder.fromRow(row) + assert(input.getStruct(0) == convertedBack.getStruct(0)) + } + + private def encodeDecodeTest(schema: StructType): Unit = { + test(s"encode/decode: ${schema.simpleString}") { + val encoder = RowEncoder(schema) + val inputGenerator = RandomDataGenerator.forType(schema, nullable = false).get + + var input: Row = null + try { + for (_ <- 1 to 5) { + input = inputGenerator.apply().asInstanceOf[Row] + val row = encoder.toRow(input) + val convertedBack = encoder.fromRow(row) + assert(input == convertedBack) + } + } catch { + case e: Exception => + fail( + s""" + |schema: ${schema.simpleString} + |input: ${input} + """.stripMargin, e) + } + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 0bae8fe2fd8a..72285c6a2419 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -23,6 +23,8 @@ import org.apache.spark.sql.types._ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { + import IntegralLiteralTestUtils._ + /** * Runs through the testFunc for all numeric data types. 
* @@ -47,6 +49,13 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Add(Literal.create(null, left.dataType), right), null) checkEvaluation(Add(left, Literal.create(null, right.dataType)), null) } + checkEvaluation(Add(positiveShortLit, negativeShortLit), -1.toShort) + checkEvaluation(Add(positiveIntLit, negativeIntLit), -1) + checkEvaluation(Add(positiveLongLit, negativeLongLit), -1L) + + DataTypeTestUtils.numericAndInterval.foreach { tpe => + checkConsistencyBetweenInterpretedAndCodegen(Add, tpe, tpe) + } } test("- (UnaryMinus)") { @@ -60,6 +69,16 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(UnaryMinus(Literal(Int.MinValue)), Int.MinValue) checkEvaluation(UnaryMinus(Literal(Short.MinValue)), Short.MinValue) checkEvaluation(UnaryMinus(Literal(Byte.MinValue)), Byte.MinValue) + checkEvaluation(UnaryMinus(positiveShortLit), (- positiveShort).toShort) + checkEvaluation(UnaryMinus(negativeShortLit), (- negativeShort).toShort) + checkEvaluation(UnaryMinus(positiveIntLit), - positiveInt) + checkEvaluation(UnaryMinus(negativeIntLit), - negativeInt) + checkEvaluation(UnaryMinus(positiveLongLit), - positiveLong) + checkEvaluation(UnaryMinus(negativeLongLit), - negativeLong) + + DataTypeTestUtils.numericAndInterval.foreach { tpe => + checkConsistencyBetweenInterpretedAndCodegen(UnaryMinus, tpe) + } } test("- (Minus)") { @@ -70,6 +89,14 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Subtract(Literal.create(null, left.dataType), right), null) checkEvaluation(Subtract(left, Literal.create(null, right.dataType)), null) } + checkEvaluation(Subtract(positiveShortLit, negativeShortLit), + (positiveShort - negativeShort).toShort) + checkEvaluation(Subtract(positiveIntLit, negativeIntLit), positiveInt - negativeInt) + checkEvaluation(Subtract(positiveLongLit, negativeLongLit), positiveLong - negativeLong) + + DataTypeTestUtils.numericAndInterval.foreach { tpe => + checkConsistencyBetweenInterpretedAndCodegen(Subtract, tpe, tpe) + } } test("* (Multiply)") { @@ -80,6 +107,14 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Multiply(Literal.create(null, left.dataType), right), null) checkEvaluation(Multiply(left, Literal.create(null, right.dataType)), null) } + checkEvaluation(Multiply(positiveShortLit, negativeShortLit), + (positiveShort * negativeShort).toShort) + checkEvaluation(Multiply(positiveIntLit, negativeIntLit), positiveInt * negativeInt) + checkEvaluation(Multiply(positiveLongLit, negativeLongLit), positiveLong * negativeLong) + + DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => + checkConsistencyBetweenInterpretedAndCodegen(Multiply, tpe, tpe) + } } test("/ (Divide) basic") { @@ -92,6 +127,10 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Divide(left, Literal.create(null, right.dataType)), null) checkEvaluation(Divide(left, Literal(convert(0))), null) // divide by zero } + + DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => + checkConsistencyBetweenInterpretedAndCodegen(Divide, tpe, tpe) + } } test("/ (Divide) for integral type") { @@ -99,6 +138,9 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Divide(Literal(1.toShort), Literal(2.toShort)), 0.toShort) checkEvaluation(Divide(Literal(1), Literal(2)), 0) checkEvaluation(Divide(Literal(1.toLong), Literal(2.toLong)), 0.toLong) + 
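// For the boundary literals below, the dividend is smaller in magnitude than the divisor, so integral division truncates toward zero and yields 0 for short, int and long. + 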
checkEvaluation(Divide(positiveShortLit, negativeShortLit), 0.toShort) + checkEvaluation(Divide(positiveIntLit, negativeIntLit), 0) + checkEvaluation(Divide(positiveLongLit, negativeLongLit), 0L) } test("/ (Divide) for floating point") { @@ -116,6 +158,18 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Remainder(left, Literal.create(null, right.dataType)), null) checkEvaluation(Remainder(left, Literal(convert(0))), null) // mod by 0 } + checkEvaluation(Remainder(positiveShortLit, positiveShortLit), 0.toShort) + checkEvaluation(Remainder(negativeShortLit, negativeShortLit), 0.toShort) + checkEvaluation(Remainder(positiveIntLit, positiveIntLit), 0) + checkEvaluation(Remainder(negativeIntLit, negativeIntLit), 0) + checkEvaluation(Remainder(positiveLongLit, positiveLongLit), 0L) + checkEvaluation(Remainder(negativeLongLit, negativeLongLit), 0L) + + // TODO: the following lines would fail the test due to inconsistency result of interpret + // and codegen for remainder between giant values, seems like a numeric stability issue + // DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => + // checkConsistencyBetweenInterpretedAndCodegen(Remainder, tpe, tpe) + // } } test("Abs") { @@ -127,6 +181,16 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Abs(Literal(convert(-1))), convert(1)) checkEvaluation(Abs(Literal.create(null, dataType)), null) } + checkEvaluation(Abs(positiveShortLit), positiveShort) + checkEvaluation(Abs(negativeShortLit), (- negativeShort).toShort) + checkEvaluation(Abs(positiveIntLit), positiveInt) + checkEvaluation(Abs(negativeIntLit), - negativeInt) + checkEvaluation(Abs(positiveLongLit), positiveLong) + checkEvaluation(Abs(negativeLongLit), - negativeLong) + + DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => + checkConsistencyBetweenInterpretedAndCodegen(Abs, tpe) + } } test("MaxOf basic") { @@ -138,6 +202,13 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MaxOf(Literal.create(null, small.dataType), large), convert(2)) checkEvaluation(MaxOf(large, Literal.create(null, small.dataType)), convert(2)) } + checkEvaluation(MaxOf(positiveShortLit, negativeShortLit), (positiveShort).toShort) + checkEvaluation(MaxOf(positiveIntLit, negativeIntLit), positiveInt) + checkEvaluation(MaxOf(positiveLongLit, negativeLongLit), positiveLong) + + DataTypeTestUtils.ordered.foreach { tpe => + checkConsistencyBetweenInterpretedAndCodegen(MaxOf, tpe, tpe) + } } test("MaxOf for atomic type") { @@ -156,6 +227,13 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MinOf(Literal.create(null, small.dataType), large), convert(2)) checkEvaluation(MinOf(small, Literal.create(null, small.dataType)), convert(1)) } + checkEvaluation(MinOf(positiveShortLit, negativeShortLit), (negativeShort).toShort) + checkEvaluation(MinOf(positiveIntLit, negativeIntLit), negativeInt) + checkEvaluation(MinOf(positiveLongLit, negativeLongLit), negativeLong) + + DataTypeTestUtils.ordered.foreach { tpe => + checkConsistencyBetweenInterpretedAndCodegen(MinOf, tpe, tpe) + } } test("MinOf for atomic type") { @@ -174,9 +252,16 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Pmod(left, Literal.create(null, right.dataType)), null) checkEvaluation(Remainder(left, Literal(convert(0))), null) // mod by 0 } - checkEvaluation(Pmod(-7, 3), 2) - checkEvaluation(Pmod(7.2D, 4.1D), 
3.1000000000000005) - checkEvaluation(Pmod(Decimal(0.7), Decimal(0.2)), Decimal(0.1)) - checkEvaluation(Pmod(2L, Long.MaxValue), 2L) + checkEvaluation(Pmod(Literal(-7), Literal(3)), 2) + checkEvaluation(Pmod(Literal(7.2D), Literal(4.1D)), 3.1000000000000005) + checkEvaluation(Pmod(Literal(Decimal(0.7)), Literal(Decimal(0.2))), Decimal(0.1)) + checkEvaluation(Pmod(Literal(2L), Literal(Long.MaxValue)), 2L) + checkEvaluation(Pmod(positiveShort, negativeShort), positiveShort.toShort) + checkEvaluation(Pmod(positiveInt, negativeInt), positiveInt) + checkEvaluation(Pmod(positiveLong, negativeLong), positiveLong) + } + + DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => + checkConsistencyBetweenInterpretedAndCodegen(MinOf, tpe, tpe) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala index fa30fbe52847..3a310c0e9a7a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala @@ -23,6 +23,8 @@ import org.apache.spark.sql.types._ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { + import IntegralLiteralTestUtils._ + test("BitwiseNOT") { def check(input: Any, expected: Any): Unit = { val expr = BitwiseNot(Literal(input)) @@ -37,6 +39,16 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { check(123456789123L, ~123456789123L) checkEvaluation(BitwiseNot(Literal.create(null, IntegerType)), null) + checkEvaluation(BitwiseNot(positiveShortLit), (~positiveShort).toShort) + checkEvaluation(BitwiseNot(negativeShortLit), (~negativeShort).toShort) + checkEvaluation(BitwiseNot(positiveIntLit), ~positiveInt) + checkEvaluation(BitwiseNot(negativeIntLit), ~negativeInt) + checkEvaluation(BitwiseNot(positiveLongLit), ~positiveLong) + checkEvaluation(BitwiseNot(negativeLongLit), ~negativeLong) + + DataTypeTestUtils.integralType.foreach { dt => + checkConsistencyBetweenInterpretedAndCodegen(BitwiseNot, dt) + } } test("BitwiseAnd") { @@ -56,6 +68,14 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(BitwiseAnd(nullLit, Literal(1)), null) checkEvaluation(BitwiseAnd(Literal(1), nullLit), null) checkEvaluation(BitwiseAnd(nullLit, nullLit), null) + checkEvaluation(BitwiseAnd(positiveShortLit, negativeShortLit), + (positiveShort & negativeShort).toShort) + checkEvaluation(BitwiseAnd(positiveIntLit, negativeIntLit), positiveInt & negativeInt) + checkEvaluation(BitwiseAnd(positiveLongLit, negativeLongLit), positiveLong & negativeLong) + + DataTypeTestUtils.integralType.foreach { dt => + checkConsistencyBetweenInterpretedAndCodegen(BitwiseAnd, dt, dt) + } } test("BitwiseOr") { @@ -75,6 +95,14 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(BitwiseOr(nullLit, Literal(1)), null) checkEvaluation(BitwiseOr(Literal(1), nullLit), null) checkEvaluation(BitwiseOr(nullLit, nullLit), null) + checkEvaluation(BitwiseOr(positiveShortLit, negativeShortLit), + (positiveShort | negativeShort).toShort) + checkEvaluation(BitwiseOr(positiveIntLit, negativeIntLit), positiveInt | negativeInt) + checkEvaluation(BitwiseOr(positiveLongLit, negativeLongLit), positiveLong | negativeLong) + + DataTypeTestUtils.integralType.foreach { dt => + 
checkConsistencyBetweenInterpretedAndCodegen(BitwiseOr, dt, dt) + } } test("BitwiseXor") { @@ -94,5 +122,13 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(BitwiseXor(nullLit, Literal(1)), null) checkEvaluation(BitwiseXor(Literal(1), nullLit), null) checkEvaluation(BitwiseXor(nullLit, nullLit), null) + checkEvaluation(BitwiseXor(positiveShortLit, negativeShortLit), + (positiveShort ^ negativeShort).toShort) + checkEvaluation(BitwiseXor(positiveIntLit, negativeIntLit), positiveInt ^ negativeInt) + checkEvaluation(BitwiseXor(positiveLongLit, negativeLongLit), positiveLong ^ negativeLong) + + DataTypeTestUtils.integralType.foreach { dt => + checkConsistencyBetweenInterpretedAndCodegen(BitwiseXor, dt, dt) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 1ad70733eae0..a98e16c25321 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -258,8 +258,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("cast from int 2") { checkEvaluation(cast(1, LongType), 1.toLong) - checkEvaluation(cast(cast(1000, TimestampType), LongType), 1.toLong) - checkEvaluation(cast(cast(-1200, TimestampType), LongType), -2.toLong) + checkEvaluation(cast(cast(1000, TimestampType), LongType), 1000.toLong) + checkEvaluation(cast(cast(-1200, TimestampType), LongType), -1200.toLong) checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123)) checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123)) @@ -348,14 +348,14 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( cast(cast(cast(cast(cast(cast("5", ByteType), TimestampType), DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType), - 0.toShort) + 5.toShort) checkEvaluation( cast(cast(cast(cast(cast(cast("5", TimestampType), ByteType), DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType), null) checkEvaluation(cast(cast(cast(cast(cast(cast("5", DecimalType.SYSTEM_DEFAULT), ByteType), TimestampType), LongType), StringType), ShortType), - 0.toShort) + 5.toShort) checkEvaluation(cast("23", DoubleType), 23d) checkEvaluation(cast("23", IntegerType), 23) @@ -479,10 +479,12 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(ts, LongType), 15.toLong) checkEvaluation(cast(ts, FloatType), 15.003f) checkEvaluation(cast(ts, DoubleType), 15.003) - checkEvaluation(cast(cast(tss, ShortType), TimestampType), DateTimeUtils.fromJavaTimestamp(ts)) + checkEvaluation(cast(cast(tss, ShortType), TimestampType), + DateTimeUtils.fromJavaTimestamp(ts) * 1000) checkEvaluation(cast(cast(tss, IntegerType), TimestampType), - DateTimeUtils.fromJavaTimestamp(ts)) - checkEvaluation(cast(cast(tss, LongType), TimestampType), DateTimeUtils.fromJavaTimestamp(ts)) + DateTimeUtils.fromJavaTimestamp(ts) * 1000) + checkEvaluation(cast(cast(tss, LongType), TimestampType), + DateTimeUtils.fromJavaTimestamp(ts) * 1000) checkEvaluation( cast(cast(millis.toFloat / 1000, TimestampType), FloatType), millis.toFloat / 1000) @@ -503,9 +505,9 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } test("cast from array") { - val array = Literal.create(Seq("123", "abc", "", null), + val array = Literal.create(Seq("123", "true", "f", null), 
ArrayType(StringType, containsNull = true)) - val array_notNull = Literal.create(Seq("123", "abc", ""), + val array_notNull = Literal.create(Seq("123", "true", "f"), ArrayType(StringType, containsNull = false)) checkNullCast(ArrayType(StringType), ArrayType(IntegerType)) @@ -522,7 +524,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { { val ret = cast(array, ArrayType(BooleanType, containsNull = true)) assert(ret.resolved === true) - checkEvaluation(ret, Seq(true, true, false, null)) + checkEvaluation(ret, Seq(null, true, false, null)) } { val ret = cast(array, ArrayType(BooleanType, containsNull = false)) @@ -541,12 +543,12 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { { val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = true)) assert(ret.resolved === true) - checkEvaluation(ret, Seq(true, true, false)) + checkEvaluation(ret, Seq(null, true, false)) } { val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = false)) assert(ret.resolved === true) - checkEvaluation(ret, Seq(true, true, false)) + checkEvaluation(ret, Seq(null, true, false)) } { @@ -557,10 +559,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("cast from map") { val map = Literal.create( - Map("a" -> "123", "b" -> "abc", "c" -> "", "d" -> null), + Map("a" -> "123", "b" -> "true", "c" -> "f", "d" -> null), MapType(StringType, StringType, valueContainsNull = true)) val map_notNull = Literal.create( - Map("a" -> "123", "b" -> "abc", "c" -> ""), + Map("a" -> "123", "b" -> "true", "c" -> "f"), MapType(StringType, StringType, valueContainsNull = false)) checkNullCast(MapType(StringType, IntegerType), MapType(StringType, StringType)) @@ -577,7 +579,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { { val ret = cast(map, MapType(StringType, BooleanType, valueContainsNull = true)) assert(ret.resolved === true) - checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false, "d" -> null)) + checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false, "d" -> null)) } { val ret = cast(map, MapType(StringType, BooleanType, valueContainsNull = false)) @@ -600,12 +602,12 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { { val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = true)) assert(ret.resolved === true) - checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false)) + checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false)) } { val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = false)) assert(ret.resolved === true) - checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false)) + checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false)) } { val ret = cast(map_notNull, MapType(IntegerType, StringType, valueContainsNull = true)) @@ -630,8 +632,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { val struct = Literal.create( InternalRow( UTF8String.fromString("123"), - UTF8String.fromString("abc"), - UTF8String.fromString(""), + UTF8String.fromString("true"), + UTF8String.fromString("f"), null), StructType(Seq( StructField("a", StringType, nullable = true), @@ -641,8 +643,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { val struct_notNull = Literal.create( InternalRow( UTF8String.fromString("123"), - UTF8String.fromString("abc"), - UTF8String.fromString("")), + UTF8String.fromString("true"), + UTF8String.fromString("f")), StructType(Seq( StructField("a", StringType, 
nullable = false), StructField("b", StringType, nullable = false), @@ -672,7 +674,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StructField("c", BooleanType, nullable = true), StructField("d", BooleanType, nullable = true)))) assert(ret.resolved === true) - checkEvaluation(ret, InternalRow(true, true, false, null)) + checkEvaluation(ret, InternalRow(null, true, false, null)) } { val ret = cast(struct, StructType(Seq( @@ -704,7 +706,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StructField("b", BooleanType, nullable = true), StructField("c", BooleanType, nullable = true)))) assert(ret.resolved === true) - checkEvaluation(ret, InternalRow(true, true, false)) + checkEvaluation(ret, InternalRow(null, true, false)) } { val ret = cast(struct_notNull, StructType(Seq( @@ -712,7 +714,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StructField("b", BooleanType, nullable = true), StructField("c", BooleanType, nullable = false)))) assert(ret.resolved === true) - checkEvaluation(ret, InternalRow(true, true, false)) + checkEvaluation(ret, InternalRow(null, true, false)) } { @@ -731,8 +733,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("complex casting") { val complex = Literal.create( Row( - Seq("123", "abc", ""), - Map("a" ->"123", "b" -> "abc", "c" -> ""), + Seq("123", "true", "f"), + Map("a" -> "123", "b" -> "true", "c" -> "f"), Row(0)), StructType(Seq( StructField("a", @@ -755,11 +757,11 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { assert(ret.resolved === true) checkEvaluation(ret, Row( Seq(123, null, null), - Map("a" -> true, "b" -> true, "c" -> false), + Map("a" -> null, "b" -> true, "c" -> false), Row(0L))) } - test("case between string and interval") { + test("cast between string and interval") { import org.apache.spark.unsafe.types.CalendarInterval checkEvaluation(Cast(Literal("interval -3 month 7 hours"), CalendarIntervalType), @@ -769,4 +771,23 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StringType), "interval 1 years 3 months -3 days") } + + test("cast string to boolean") { + checkCast("t", true) + checkCast("true", true) + checkCast("tRUe", true) + checkCast("y", true) + checkCast("yes", true) + checkCast("1", true) + + checkCast("f", false) + checkCast("false", false) + checkCast("FAlsE", false) + checkCast("n", false) + checkCast("no", false) + checkCast("0", false) + + checkEvaluation(cast("abc", BooleanType), null) + checkEvaluation(cast("", BooleanType), null) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index cc82f7c3f5a7..0c42e2fc7c5e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -17,11 +17,9 @@ package org.apache.spark.sql.catalyst.expressions -import scala.math._ - import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{Row, RandomDataGenerator} -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ @@ -40,7 +38,6 @@ class CodeGenerationSuite 
extends SparkFunSuite with ExpressionEvalHelper { val futures = (1 to 20).map { _ => future { GeneratePredicate.generate(EqualTo(Literal(1), Literal(1))) - GenerateProjection.generate(EqualTo(Literal(1), Literal(1)) :: Nil) GenerateMutableProjection.generate(EqualTo(Literal(1), Literal(1)) :: Nil) GenerateOrdering.generate(Add(Literal(1), Literal(1)).asc :: Nil) } @@ -49,45 +46,11 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { futures.foreach(Await.result(_, 10.seconds)) } - // Test GenerateOrdering for all common types. For each type, we construct random input rows that - // contain two columns of that type, then for pairs of randomly-generated rows we check that - // GenerateOrdering agrees with RowOrdering. - (DataTypeTestUtils.atomicTypes ++ Set(NullType)).foreach { dataType => - test(s"GenerateOrdering with $dataType") { - val rowOrdering = RowOrdering.forSchema(Seq(dataType, dataType)) - val genOrdering = GenerateOrdering.generate( - BoundReference(0, dataType, nullable = true).asc :: - BoundReference(1, dataType, nullable = true).asc :: Nil) - val rowType = StructType( - StructField("a", dataType, nullable = true) :: - StructField("b", dataType, nullable = true) :: Nil) - val maybeDataGenerator = RandomDataGenerator.forType(rowType, nullable = false) - assume(maybeDataGenerator.isDefined) - val randGenerator = maybeDataGenerator.get - val toCatalyst = CatalystTypeConverters.createToCatalystConverter(rowType) - for (_ <- 1 to 50) { - val a = toCatalyst(randGenerator()).asInstanceOf[InternalRow] - val b = toCatalyst(randGenerator()).asInstanceOf[InternalRow] - withClue(s"a = $a, b = $b") { - assert(genOrdering.compare(a, a) === 0) - assert(genOrdering.compare(b, b) === 0) - assert(rowOrdering.compare(a, a) === 0) - assert(rowOrdering.compare(b, b) === 0) - assert(signum(genOrdering.compare(a, b)) === -1 * signum(genOrdering.compare(b, a))) - assert(signum(rowOrdering.compare(a, b)) === -1 * signum(rowOrdering.compare(b, a))) - assert( - signum(rowOrdering.compare(a, b)) === signum(genOrdering.compare(a, b)), - "Generated and non-generated orderings should agree") - } - } - } - } - test("SPARK-8443: split wide projections into blocks due to JVM code size limit") { val length = 5000 val expressions = List.fill(length)(EqualTo(Literal(1), Literal(1))) val plan = GenerateMutableProjection.generate(expressions)() - val actual = plan(new GenericMutableRow(length)).toSeq + val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType)) val expected = Seq.fill(length)(true) if (!checkResult(actual, expected)) { @@ -134,4 +97,22 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { unsafeRow.getStruct(3, 1).getStruct(0, 2).setInt(1, 4) assert(internalRow === internalRow2) } + + test("*/ in the data") { + // When */ appears in a comment block (i.e. in /**/), code gen will break. + // So, in Expression and CodegenFallback, we escape */ to \*\/. + checkEvaluation( + EqualTo(BoundReference(0, StringType, false), Literal.create("*/", StringType)), + true, + InternalRow(UTF8String.fromString("*/"))) + } + + test("\\u in the data") { + // When \ u appears in a comment block (i.e. in /**/), code gen will break. + // So, in Expression and CodegenFallback, we escape \ u to \\u. 
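+ // A bare \ u sequence is treated as a unicode escape by the Java compiler even inside comments, so leaving it unescaped would break compilation of the generated code.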
+ checkEvaluation( + EqualTo(BoundReference(0, StringType, false), Literal.create("\\u", StringType)), + true, + InternalRow(UTF8String.fromString("\\u"))) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala index 2c7e85c446ec..1aae4678d627 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala @@ -49,6 +49,7 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) val a2 = Literal.create(Seq("b", "a"), ArrayType(StringType)) val a3 = Literal.create(Seq("b", null, "a"), ArrayType(StringType)) + val a4 = Literal.create(Seq(null, null), ArrayType(NullType)) checkEvaluation(new SortArray(a0), Seq(1, 2, 3)) checkEvaluation(new SortArray(a1), Seq[Integer]()) @@ -64,5 +65,32 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(new SortArray(a3, Literal(false)), Seq("b", "a", null)) checkEvaluation(Literal.create(null, ArrayType(StringType)), null) + checkEvaluation(new SortArray(a4), Seq(null, null)) + + val typeAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) + val arrayStruct = Literal.create(Seq(create_row(2), create_row(1)), typeAS) + + checkEvaluation(new SortArray(arrayStruct), Seq(create_row(1), create_row(2))) + } + + test("Array contains") { + val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) + val a2 = Literal.create(Seq(null), ArrayType(LongType)) + val a3 = Literal.create(null, ArrayType(StringType)) + + checkEvaluation(ArrayContains(a0, Literal(1)), true) + checkEvaluation(ArrayContains(a0, Literal(0)), false) + checkEvaluation(ArrayContains(a0, Literal.create(null, IntegerType)), null) + + checkEvaluation(ArrayContains(a1, Literal("")), true) + checkEvaluation(ArrayContains(a1, Literal("a")), null) + checkEvaluation(ArrayContains(a1, Literal.create(null, StringType)), null) + + checkEvaluation(ArrayContains(a2, Literal(1L)), null) + checkEvaluation(ArrayContains(a2, Literal.create(null, LongType)), null) + + checkEvaluation(ArrayContains(a3, Literal("")), null) + checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index e60990aeb423..9f1b19253e7c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -79,8 +79,8 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { def getStructField(expr: Expression, fieldName: String): GetStructField = { expr.dataType match { case StructType(fields) => - val field = fields.find(_.name == fieldName).get - GetStructField(expr, field, fields.indexOf(field)) + val index = fields.indexWhere(_.name == fieldName) + GetStructField(expr, index) } } @@ -165,7 +165,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { "b", create_row(Map("a" -> "b"))) 
checkEvaluation(quickResolve('c.array(StringType).at(0).getItem(1)), "b", create_row(Seq("a", "b"))) - checkEvaluation(quickResolve('c.struct(StructField("a", IntegerType)).at(0).getField("a")), + checkEvaluation(quickResolve('c.struct('a.int).at(0).getField("a")), 1, create_row(create_row(1))) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index d26bcdb2902a..0df673bb9fa0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -66,6 +66,10 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper testIf(_.toLong, TimestampType) testIf(_.toString, StringType) + + DataTypeTestUtils.propertyCheckSupported.foreach { dt => + checkConsistencyBetweenInterpretedAndCodegen(If, BooleanType, dt, dt) + } } test("case when") { @@ -176,6 +180,10 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper Literal(Timestamp.valueOf("2015-07-01 08:00:00")), Literal(Timestamp.valueOf("2015-07-01 10:00:00")))), Timestamp.valueOf("2015-07-01 08:00:00"), InternalRow.empty) + + DataTypeTestUtils.ordered.foreach { dt => + checkConsistencyBetweenInterpretedAndCodegen(Least, dt, 2) + } } test("function greatest") { @@ -218,6 +226,9 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper Literal(Timestamp.valueOf("2015-07-01 08:00:00")), Literal(Timestamp.valueOf("2015-07-01 10:00:00")))), Timestamp.valueOf("2015-07-01 10:00:00"), InternalRow.empty) - } + DataTypeTestUtils.ordered.foreach { dt => + checkConsistencyBetweenInterpretedAndCodegen(Greatest, dt, 2) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index e6e8790e9092..53c66d8a754e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -28,6 +28,8 @@ import org.apache.spark.unsafe.types.CalendarInterval class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { + import IntegralLiteralTestUtils._ + val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") val sdfDate = new SimpleDateFormat("yyyy-MM-dd") val d = new Date(sdf.parse("2015-04-08 13:10:15").getTime) @@ -58,6 +60,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } checkEvaluation(DayOfYear(Literal.create(null, DateType)), null) + checkConsistencyBetweenInterpretedAndCodegen(DayOfYear, DateType) } test("Year") { @@ -77,6 +80,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } + checkConsistencyBetweenInterpretedAndCodegen(Year, DateType) } test("Quarter") { @@ -96,6 +100,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } + checkConsistencyBetweenInterpretedAndCodegen(Quarter, DateType) } test("Month") { @@ -115,6 +120,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } + checkConsistencyBetweenInterpretedAndCodegen(Month, DateType) } test("Day / DayOfMonth") { @@ -133,6 +139,7 @@ class 
DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { c.get(Calendar.DAY_OF_MONTH)) } } + checkConsistencyBetweenInterpretedAndCodegen(DayOfMonth, DateType) } test("Seconds") { @@ -147,6 +154,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Second(Literal(new Timestamp(c.getTimeInMillis))), c.get(Calendar.SECOND)) } + checkConsistencyBetweenInterpretedAndCodegen(Second, TimestampType) } test("WeekOfYear") { @@ -155,6 +163,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(WeekOfYear(Cast(Literal(sdfDate.format(d)), DateType)), 15) checkEvaluation(WeekOfYear(Cast(Literal(ts), DateType)), 45) checkEvaluation(WeekOfYear(Cast(Literal("2011-05-06"), DateType)), 18) + checkConsistencyBetweenInterpretedAndCodegen(WeekOfYear, DateType) } test("DateFormat") { @@ -182,6 +191,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } + checkConsistencyBetweenInterpretedAndCodegen(Hour, TimestampType) } test("Minute") { @@ -198,6 +208,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { c.get(Calendar.MINUTE)) } } + checkConsistencyBetweenInterpretedAndCodegen(Minute, TimestampType) } test("date_add") { @@ -212,6 +223,11 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { null) checkEvaluation(DateAdd(Literal.create(null, DateType), Literal.create(null, IntegerType)), null) + checkEvaluation( + DateAdd(Literal(Date.valueOf("2016-02-28")), positiveIntLit), 49627) + checkEvaluation( + DateAdd(Literal(Date.valueOf("2016-02-28")), negativeIntLit), -15910) + checkConsistencyBetweenInterpretedAndCodegen(DateAdd, DateType, IntegerType) } test("date_sub") { @@ -226,6 +242,11 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { null) checkEvaluation(DateSub(Literal.create(null, DateType), Literal.create(null, IntegerType)), null) + checkEvaluation( + DateSub(Literal(Date.valueOf("2016-02-28")), positiveIntLit), -15909) + checkEvaluation( + DateSub(Literal(Date.valueOf("2016-02-28")), negativeIntLit), 49628) + checkConsistencyBetweenInterpretedAndCodegen(DateSub, DateType, IntegerType) } test("time_add") { @@ -244,6 +265,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( TimeAdd(Literal.create(null, TimestampType), Literal.create(null, CalendarIntervalType)), null) + checkConsistencyBetweenInterpretedAndCodegen(TimeAdd, TimestampType, CalendarIntervalType) } test("time_sub") { @@ -267,6 +289,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( TimeSub(Literal.create(null, TimestampType), Literal.create(null, CalendarIntervalType)), null) + checkConsistencyBetweenInterpretedAndCodegen(TimeSub, TimestampType, CalendarIntervalType) } test("add_months") { @@ -282,6 +305,11 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { null) checkEvaluation( AddMonths(Literal(Date.valueOf("2015-01-30")), Literal(Int.MinValue)), -7293498) + checkEvaluation( + AddMonths(Literal(Date.valueOf("2016-02-28")), positiveIntLit), 1014213) + checkEvaluation( + AddMonths(Literal(Date.valueOf("2016-02-28")), negativeIntLit), -980528) + checkConsistencyBetweenInterpretedAndCodegen(AddMonths, DateType, IntegerType) } test("months_between") { @@ -306,6 +334,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(MonthsBetween(t, tnull), null) 
checkEvaluation(MonthsBetween(tnull, t), null) checkEvaluation(MonthsBetween(tnull, tnull), null) + checkConsistencyBetweenInterpretedAndCodegen(MonthsBetween, TimestampType, TimestampType) } test("last_day") { @@ -323,6 +352,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(LastDay(Literal(Date.valueOf("2016-01-06"))), Date.valueOf("2016-01-31")) checkEvaluation(LastDay(Literal(Date.valueOf("2016-02-07"))), Date.valueOf("2016-02-29")) checkEvaluation(LastDay(Literal.create(null, DateType)), null) + checkConsistencyBetweenInterpretedAndCodegen(LastDay, DateType) } test("next_day") { @@ -356,6 +386,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { ToDate(Literal(Date.valueOf("2015-07-22"))), DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-22"))) checkEvaluation(ToDate(Literal.create(null, DateType)), null) + checkConsistencyBetweenInterpretedAndCodegen(ToDate, DateType) } test("function trunc") { @@ -434,6 +465,42 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { UnixTimestamp(Literal("2015-07-24"), Literal("not a valid format")), null) } + test("to_unix_timestamp") { + val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" + val sdf2 = new SimpleDateFormat(fmt2) + val fmt3 = "yy-MM-dd" + val sdf3 = new SimpleDateFormat(fmt3) + val date1 = Date.valueOf("2015-07-24") + checkEvaluation( + ToUnixTimestamp(Literal(sdf1.format(new Timestamp(0))), Literal("yyyy-MM-dd HH:mm:ss")), 0L) + checkEvaluation(ToUnixTimestamp( + Literal(sdf1.format(new Timestamp(1000000))), Literal("yyyy-MM-dd HH:mm:ss")), 1000L) + checkEvaluation( + ToUnixTimestamp(Literal(new Timestamp(1000000)), Literal("yyyy-MM-dd HH:mm:ss")), 1000L) + checkEvaluation( + ToUnixTimestamp(Literal(date1), Literal("yyyy-MM-dd HH:mm:ss")), + DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1)) / 1000L) + checkEvaluation( + ToUnixTimestamp(Literal(sdf2.format(new Timestamp(-1000000))), Literal(fmt2)), -1000L) + checkEvaluation(ToUnixTimestamp( + Literal(sdf3.format(Date.valueOf("2015-07-24"))), Literal(fmt3)), + DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-24"))) / 1000L) + val t1 = ToUnixTimestamp( + CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long] + val t2 = ToUnixTimestamp( + CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long] + assert(t2 - t1 <= 1) + checkEvaluation( + ToUnixTimestamp(Literal.create(null, DateType), Literal.create(null, StringType)), null) + checkEvaluation( + ToUnixTimestamp(Literal.create(null, DateType), Literal("yyyy-MM-dd HH:mm:ss")), null) + checkEvaluation(ToUnixTimestamp( + Literal(date1), Literal.create(null, StringType)), date1.getTime / 1000L) + checkEvaluation( + ToUnixTimestamp(Literal("2015-07-24"), Literal("not a valid format")), null) + } + test("datediff") { checkEvaluation( DateDiff(Literal(Date.valueOf("2015-07-24")), Literal(Date.valueOf("2015-07-21"))), 3) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala new file mode 100644 index 000000000000..511f0307901d --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license 
agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types.{LongType, DecimalType, Decimal} + + +class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("UnscaledValue") { + val d1 = Decimal("10.1") + checkEvaluation(UnscaledValue(Literal(d1)), 101L) + val d2 = Decimal(101, 3, 1) + checkEvaluation(UnscaledValue(Literal(d2)), 101L) + checkEvaluation(UnscaledValue(Literal.create(null, DecimalType(2, 1))), null) + } + + test("MakeDecimal") { + checkEvaluation(MakeDecimal(Literal(101L), 3, 1), Decimal("10.1")) + checkEvaluation(MakeDecimal(Literal.create(null, LongType), 3, 1), null) + } + + test("PromotePrecision") { + val d1 = Decimal("10.1") + checkEvaluation(PromotePrecision(Literal(d1)), d1) + val d2 = Decimal(101, 3, 1) + checkEvaluation(PromotePrecision(Literal(d2)), d2) + checkEvaluation(PromotePrecision(Literal.create(null, DecimalType(2, 1))), null) + } + + test("CheckOverflow") { + val d1 = Decimal("10.1") + checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 0)), Decimal("10")) + checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 1)), d1) + checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 2)), d1) + checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 3)), null) + + val d2 = Decimal(101, 3, 1) + checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 0)), Decimal("10")) + checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 1)), d2) + checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 2)), d2) + checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 3)), null) + + checkEvaluation(CheckOverflow(Literal.create(null, DecimalType(2, 1)), DecimalType(3, 2)), null) + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 3e5515129874..f869a96edb1c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -17,22 +17,23 @@ package org.apache.spark.sql.catalyst.expressions +import org.scalacheck.Gen import org.scalactic.TripleEqualsSupport.Spread +import org.scalatest.prop.GeneratorDrivenPropertyChecks import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} +import org.apache.spark.sql.types.DataType /** * A few helper functions for expression evaluation testing. 
Mix in this trait to use them. */ -trait ExpressionEvalHelper { +trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { self: SparkFunSuite => - protected val defaultOptimizer = new DefaultOptimizer - protected def create_row(values: Any*): InternalRow = { InternalRow.fromSeq(values.map(CatalystTypeConverters.convertToCatalyst)) } @@ -42,7 +43,6 @@ trait ExpressionEvalHelper { val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) checkEvaluationWithoutCodegen(expression, catalystValue, inputRow) checkEvaluationWithGeneratedMutableProjection(expression, catalystValue, inputRow) - checkEvaluationWithGeneratedProjection(expression, catalystValue, inputRow) if (GenerateUnsafeProjection.canSupport(expression.dataType)) { checkEvalutionWithUnsafeProjection(expression, catalystValue, inputRow) } @@ -119,42 +119,6 @@ trait ExpressionEvalHelper { } } - protected def checkEvaluationWithGeneratedProjection( - expression: Expression, - expected: Any, - inputRow: InternalRow = EmptyRow): Unit = { - - val plan = generateProject( - GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), - expression) - - val actual = plan(inputRow) - val expectedRow = InternalRow(expected) - - // We reimplement hashCode in generated `SpecificRow`, make sure it's consistent with our - // interpreted version. - if (actual.hashCode() != expectedRow.hashCode()) { - val ctx = new CodeGenContext - val evaluated = expression.gen(ctx) - fail( - s""" - |Mismatched hashCodes for values: $actual, $expectedRow - |Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()} - |Expressions: $expression - |Code: $evaluated - """.stripMargin) - } - - if (actual != expectedRow) { - val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - fail("Incorrect Evaluation in codegen mode: " + - s"$expression, actual: $actual, expected: $expectedRow$input") - } - if (actual.copy() != expectedRow) { - fail(s"Copy of generated Row is wrong: actual: ${actual.copy()}, expected: $expectedRow") - } - } - protected def checkEvalutionWithUnsafeProjection( expression: Expression, expected: Any, @@ -188,7 +152,7 @@ trait ExpressionEvalHelper { expected: Any, inputRow: InternalRow = EmptyRow): Unit = { val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation) - val optimizedPlan = defaultOptimizer.execute(plan) + val optimizedPlan = DefaultOptimizer.execute(plan) checkEvaluationWithoutCodegen(optimizedPlan.expressions.head, expected, inputRow) } @@ -201,7 +165,7 @@ trait ExpressionEvalHelper { checkEvaluationWithOptimization(expression, expected) var plan = generateProject( - GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), + GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)(), expression) var actual = plan(inputRow).get(0, expression.dataType) assert(checkResult(actual, expected)) @@ -213,4 +177,111 @@ trait ExpressionEvalHelper { plan(inputRow)).get(0, expression.dataType) assert(checkResult(actual, expected)) } + + /** + * Test evaluation results between Interpreted mode and Codegen mode, making sure we have + * consistent results regardless of the evaluation method we use. + * + * This method tests unary expressions by feeding them arbitrary literals of `dataType`.
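 + * + * For example (an illustrative usage sketch, mirroring the calls added in the suites below): + * {{{ + *   checkConsistencyBetweenInterpretedAndCodegen(Sin, DoubleType) + * }}}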
+ */ + def checkConsistencyBetweenInterpretedAndCodegen( + c: Expression => Expression, + dataType: DataType): Unit = { + forAll (LiteralGenerator.randomGen(dataType)) { (l: Literal) => + cmpInterpretWithCodegen(EmptyRow, c(l)) + } + } + + /** + * Test evaluation results between Interpreted mode and Codegen mode, making sure we have + * consistent results regardless of the evaluation method we use. + * + * This method tests binary expressions by feeding them arbitrary literals of `dataType1` + * and `dataType2`. + */ + def checkConsistencyBetweenInterpretedAndCodegen( + c: (Expression, Expression) => Expression, + dataType1: DataType, + dataType2: DataType): Unit = { + forAll ( + LiteralGenerator.randomGen(dataType1), + LiteralGenerator.randomGen(dataType2) + ) { (l1: Literal, l2: Literal) => + cmpInterpretWithCodegen(EmptyRow, c(l1, l2)) + } + } + + /** + * Test evaluation results between Interpreted mode and Codegen mode, making sure we have + * consistent results regardless of the evaluation method we use. + * + * This method tests ternary expressions by feeding them arbitrary literals of `dataType1`, + * `dataType2` and `dataType3`. + */ + def checkConsistencyBetweenInterpretedAndCodegen( + c: (Expression, Expression, Expression) => Expression, + dataType1: DataType, + dataType2: DataType, + dataType3: DataType): Unit = { + forAll ( + LiteralGenerator.randomGen(dataType1), + LiteralGenerator.randomGen(dataType2), + LiteralGenerator.randomGen(dataType3) + ) { (l1: Literal, l2: Literal, l3: Literal) => + cmpInterpretWithCodegen(EmptyRow, c(l1, l2, l3)) + } + } + + /** + * Test evaluation results between Interpreted mode and Codegen mode, making sure we have + * consistent results regardless of the evaluation method we use. + * + * This method tests expressions that take a Seq[Expression] as input by feeding them + * a Seq of arbitrary length containing arbitrary literals of `dataType`. + */ + def checkConsistencyBetweenInterpretedAndCodegen( + c: Seq[Expression] => Expression, + dataType: DataType, + minNumElements: Int = 0): Unit = { + forAll (Gen.listOf(LiteralGenerator.randomGen(dataType))) { (literals: Seq[Literal]) => + whenever(literals.size >= minNumElements) { + cmpInterpretWithCodegen(EmptyRow, c(literals)) + } + } + } + + private def cmpInterpretWithCodegen(inputRow: InternalRow, expr: Expression): Unit = { + val interpret = try { + evaluate(expr, inputRow) + } catch { + case e: Exception => fail(s"Exception evaluating $expr", e) + } + + val plan = generateProject( + GenerateMutableProjection.generate(Alias(expr, s"Optimized($expr)")() :: Nil)(), + expr) + val codegen = plan(inputRow).get(0, expr.dataType) + + if (!compareResults(interpret, codegen)) { + fail(s"Incorrect evaluation: $expr, interpret: $interpret, codegen: $codegen") + } + } + + /** + * Check the equality between the result of an expression and the expected value, handling + * Array[Byte] and Spread[Double].
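 + * + * For example (illustrative of the cases handled below): two NaN `Double` results compare + * as equal, and `Array[Byte]` results are compared element-wise rather than by reference.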
+ */ + private[this] def compareResults(result: Any, expected: Any): Boolean = { + (result, expected) match { + case (result: Array[Byte], expected: Array[Byte]) => + java.util.Arrays.equals(result, expected) + case (result: Double, expected: Spread[Double]) => + expected.isWithin(result) + case (result: Double, expected: Double) if result.isNaN && expected.isNaN => + true + case (result: Float, expected: Float) if result.isNaN && expected.isNaN => + true + case _ => result == expected + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntegralLiteralTestUtils.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntegralLiteralTestUtils.scala new file mode 100644 index 000000000000..2e5a121f4ec5 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntegralLiteralTestUtils.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +/** + * Utilities to make sure we pass the proper numeric ranges + */ +object IntegralLiteralTestUtils { + + val positiveShort: Short = (Byte.MaxValue + 1).toShort + val negativeShort: Short = (Byte.MinValue - 1).toShort + + val positiveShortLit: Literal = Literal(positiveShort) + val negativeShortLit: Literal = Literal(negativeShort) + + val positiveInt: Int = Short.MaxValue + 1 + val negativeInt: Int = Short.MinValue - 1 + + val positiveIntLit: Literal = Literal(positiveInt) + val negativeIntLit: Literal = Literal(negativeInt) + + val positiveLong: Long = Int.MaxValue + 1L + val negativeLong: Long = Int.MinValue - 1L + + val positiveLongLit: Literal = Literal(positiveLong) + val negativeLongLit: Literal = Literal(negativeLong) +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala new file mode 100644 index 000000000000..7b754091f471 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -0,0 +1,320 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.unsafe.types.UTF8String + +class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { + val json = + """ + |{"store":{"fruit":[{"weight":8,"type":"apple"},{"weight":9,"type":"pear"}], + |"basket":[[1,2,{"b":"y","a":"x"}],[3,4],[5,6]],"book":[{"author":"Nigel Rees", + |"title":"Sayings of the Century","category":"reference","price":8.95}, + |{"author":"Herman Melville","title":"Moby Dick","category":"fiction","price":8.99, + |"isbn":"0-553-21311-3"},{"author":"J. R. R. Tolkien","title":"The Lord of the Rings", + |"category":"fiction","reader":[{"age":25,"name":"bob"},{"age":26,"name":"jack"}], + |"price":22.99,"isbn":"0-395-19395-8"}],"bicycle":{"price":19.95,"color":"red"}}, + |"email":"amy@only_for_json_udf_test.net","owner":"amy","zip code":"94025", + |"fb:testid":"1234"} + |""".stripMargin + + test("$.store.bicycle") { + checkEvaluation( + GetJsonObject(Literal(json), Literal("$.store.bicycle")), + """{"price":19.95,"color":"red"}""") + } + + test("$.store.book") { + checkEvaluation( + GetJsonObject(Literal(json), Literal("$.store.book")), + """[{"author":"Nigel Rees","title":"Sayings of the Century","category":"reference", + |"price":8.95},{"author":"Herman Melville","title":"Moby Dick","category":"fiction", + |"price":8.99,"isbn":"0-553-21311-3"},{"author":"J. R. R. Tolkien","title": + |"The Lord of the Rings","category":"fiction","reader":[{"age":25,"name":"bob"}, + |{"age":26,"name":"jack"}],"price":22.99,"isbn":"0-395-19395-8"}] + |""".stripMargin.replace("\n", "")) + } + + test("$.store.book[0]") { + checkEvaluation( + GetJsonObject(Literal(json), Literal("$.store.book[0]")), + """{"author":"Nigel Rees","title":"Sayings of the Century", + |"category":"reference","price":8.95}""".stripMargin.replace("\n", "")) + } + + test("$.store.book[*]") { + checkEvaluation( + GetJsonObject(Literal(json), Literal("$.store.book[*]")), + """[{"author":"Nigel Rees","title":"Sayings of the Century","category":"reference", + |"price":8.95},{"author":"Herman Melville","title":"Moby Dick","category":"fiction", + |"price":8.99,"isbn":"0-553-21311-3"},{"author":"J. R. R. 
Tolkien","title": + |"The Lord of the Rings","category":"fiction","reader":[{"age":25,"name":"bob"}, + |{"age":26,"name":"jack"}],"price":22.99,"isbn":"0-395-19395-8"}] + |""".stripMargin.replace("\n", "")) + } + + test("$") { + checkEvaluation( + GetJsonObject(Literal(json), Literal("$")), + json.replace("\n", "")) + } + + test("$.store.book[0].category") { + checkEvaluation( + GetJsonObject(Literal(json), Literal("$.store.book[0].category")), + "reference") + } + + test("$.store.book[*].category") { + checkEvaluation( + GetJsonObject(Literal(json), Literal("$.store.book[*].category")), + """["reference","fiction","fiction"]""") + } + + test("$.store.book[*].isbn") { + checkEvaluation( + GetJsonObject(Literal(json), Literal("$.store.book[*].isbn")), + """["0-553-21311-3","0-395-19395-8"]""") + } + + test("$.store.book[*].reader") { + checkEvaluation( + GetJsonObject(Literal(json), Literal("$.store.book[*].reader")), + """[{"age":25,"name":"bob"},{"age":26,"name":"jack"}]""") + } + + test("$.store.basket[0][1]") { + checkEvaluation( + GetJsonObject(Literal(json), Literal("$.store.basket[0][1]")), + "2") + } + + test("$.store.basket[*]") { + checkEvaluation( + GetJsonObject(Literal(json), Literal("$.store.basket[*]")), + """[[1,2,{"b":"y","a":"x"}],[3,4],[5,6]]""") + } + + test("$.store.basket[*][0]") { + checkEvaluation( + GetJsonObject(Literal(json), Literal("$.store.basket[*][0]")), + "[1,3,5]") + } + + test("$.store.basket[0][*]") { + checkEvaluation( + GetJsonObject(Literal(json), Literal("$.store.basket[0][*]")), + """[1,2,{"b":"y","a":"x"}]""") + } + + test("$.store.basket[*][*]") { + checkEvaluation( + GetJsonObject(Literal(json), Literal("$.store.basket[*][*]")), + """[1,2,{"b":"y","a":"x"},3,4,5,6]""") + } + + test("$.store.basket[0][2].b") { + checkEvaluation( + GetJsonObject(Literal(json), Literal("$.store.basket[0][2].b")), + "y") + } + + test("$.store.basket[0][*].b") { + checkEvaluation( + GetJsonObject(Literal(json), Literal("$.store.basket[0][*].b")), + """["y"]""") + } + + test("$.zip code") { + checkEvaluation( + GetJsonObject(Literal(json), Literal("$.zip code")), + "94025") + } + + test("$.fb:testid") { + checkEvaluation( + GetJsonObject(Literal(json), Literal("$.fb:testid")), + "1234") + } + + test("preserve newlines") { + checkEvaluation( + GetJsonObject(Literal("""{"a":"b\nc"}"""), Literal("$.a")), + "b\nc") + } + + test("escape") { + checkEvaluation( + GetJsonObject(Literal("""{"a":"b\"c"}"""), Literal("$.a")), + "b\"c") + } + + test("$.non_exist_key") { + checkEvaluation( + GetJsonObject(Literal(json), Literal("$.non_exist_key")), + null) + } + + test("$..no_recursive") { + checkEvaluation( + GetJsonObject(Literal(json), Literal("$..no_recursive")), + null) + } + + test("$.store.book[10]") { + checkEvaluation( + GetJsonObject(Literal(json), Literal("$.store.book[10]")), + null) + } + + test("$.store.book[0].non_exist_key") { + checkEvaluation( + GetJsonObject(Literal(json), Literal("$.store.book[0].non_exist_key")), + null) + } + + test("$.store.basket[*].non_exist_key") { + checkEvaluation( + GetJsonObject(Literal(json), Literal("$.store.basket[*].non_exist_key")), + null) + } + + test("non foldable literal") { + checkEvaluation( + GetJsonObject(NonFoldableLiteral(json), NonFoldableLiteral("$.fb:testid")), + "1234") + } + + val jsonTupleQuery = Literal("f1") :: + Literal("f2") :: + Literal("f3") :: + Literal("f4") :: + Literal("f5") :: + Nil + + private def checkJsonTuple(jt: JsonTuple, expected: InternalRow): Unit = { + assert(jt.eval(null).toSeq.head === 
expected) + } + + test("json_tuple - hive key 1") { + checkJsonTuple( + JsonTuple( + Literal("""{"f1": "value1", "f2": "value2", "f3": 3, "f5": 5.23}""") :: + jsonTupleQuery), + InternalRow.fromSeq(Seq("value1", "value2", "3", null, "5.23").map(UTF8String.fromString))) + } + + test("json_tuple - hive key 2") { + checkJsonTuple( + JsonTuple( + Literal("""{"f1": "value12", "f3": "value3", "f2": 2, "f4": 4.01}""") :: + jsonTupleQuery), + InternalRow.fromSeq(Seq("value12", "2", "value3", "4.01", null).map(UTF8String.fromString))) + } + + test("json_tuple - hive key 2 (mix of foldable fields)") { + checkJsonTuple( + JsonTuple(Literal("""{"f1": "value12", "f3": "value3", "f2": 2, "f4": 4.01}""") :: + Literal("f1") :: + NonFoldableLiteral("f2") :: + NonFoldableLiteral("f3") :: + Literal("f4") :: + Literal("f5") :: + Nil), + InternalRow.fromSeq(Seq("value12", "2", "value3", "4.01", null).map(UTF8String.fromString))) + } + + test("json_tuple - hive key 3") { + checkJsonTuple( + JsonTuple( + Literal("""{"f1": "value13", "f4": "value44", "f3": "value33", "f2": 2, "f5": 5.01}""") :: + jsonTupleQuery), + InternalRow.fromSeq( + Seq("value13", "2", "value33", "value44", "5.01").map(UTF8String.fromString))) + } + + test("json_tuple - hive key 3 (nonfoldable json)") { + checkJsonTuple( + JsonTuple( + NonFoldableLiteral( + """{"f1": "value13", "f4": "value44", + | "f3": "value33", "f2": 2, "f5": 5.01}""".stripMargin) + :: jsonTupleQuery), + InternalRow.fromSeq( + Seq("value13", "2", "value33", "value44", "5.01").map(UTF8String.fromString))) + } + + test("json_tuple - hive key 3 (nonfoldable fields)") { + checkJsonTuple( + JsonTuple(Literal( + """{"f1": "value13", "f4": "value44", + | "f3": "value33", "f2": 2, "f5": 5.01}""".stripMargin) :: + NonFoldableLiteral("f1") :: + NonFoldableLiteral("f2") :: + NonFoldableLiteral("f3") :: + NonFoldableLiteral("f4") :: + NonFoldableLiteral("f5") :: + Nil), + InternalRow.fromSeq( + Seq("value13", "2", "value33", "value44", "5.01").map(UTF8String.fromString))) + } + + test("json_tuple - hive key 4 - null json") { + checkJsonTuple( + JsonTuple(Literal(null) :: jsonTupleQuery), + InternalRow.fromSeq(Seq(null, null, null, null, null))) + } + + test("json_tuple - hive key 5 - null and empty fields") { + checkJsonTuple( + JsonTuple(Literal("""{"f1": "", "f5": null}""") :: jsonTupleQuery), + InternalRow.fromSeq(Seq(UTF8String.fromString(""), null, null, null, null))) + } + + test("json_tuple - hive key 6 - invalid json (array)") { + checkJsonTuple( + JsonTuple(Literal("[invalid JSON string]") :: jsonTupleQuery), + InternalRow.fromSeq(Seq(null, null, null, null, null))) + } + + test("json_tuple - invalid json (object start only)") { + checkJsonTuple( + JsonTuple(Literal("{") :: jsonTupleQuery), + InternalRow.fromSeq(Seq(null, null, null, null, null))) + } + + test("json_tuple - invalid json (no object end)") { + checkJsonTuple( + JsonTuple(Literal("""{"foo": "bar"""") :: jsonTupleQuery), + InternalRow.fromSeq(Seq(null, null, null, null, null))) + } + + test("json_tuple - invalid json (invalid json)") { + checkJsonTuple( + JsonTuple(Literal("\\") :: jsonTupleQuery), + InternalRow.fromSeq(Seq(null, null, null, null, null))) + } + + test("json_tuple - preserve newlines") { + checkJsonTuple( + JsonTuple(Literal("{\"a\":\"b\nc\"}") :: Literal("a") :: Nil), + InternalRow.fromSeq(Seq(UTF8String.fromString("b\nc")))) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala index f6404d21611e..7b85286c4dc8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala @@ -18,7 +18,10 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -30,15 +33,38 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal.create(null, IntegerType), null) checkEvaluation(Literal.create(null, LongType), null) checkEvaluation(Literal.create(null, FloatType), null) - checkEvaluation(Literal.create(null, LongType), null) + checkEvaluation(Literal.create(null, DoubleType), null) checkEvaluation(Literal.create(null, StringType), null) checkEvaluation(Literal.create(null, BinaryType), null) checkEvaluation(Literal.create(null, DecimalType.USER_DEFAULT), null) + checkEvaluation(Literal.create(null, DateType), null) + checkEvaluation(Literal.create(null, TimestampType), null) + checkEvaluation(Literal.create(null, CalendarIntervalType), null) checkEvaluation(Literal.create(null, ArrayType(ByteType, true)), null) checkEvaluation(Literal.create(null, MapType(StringType, IntegerType)), null) checkEvaluation(Literal.create(null, StructType(Seq.empty)), null) } + test("default") { + checkEvaluation(Literal.default(BooleanType), false) + checkEvaluation(Literal.default(ByteType), 0.toByte) + checkEvaluation(Literal.default(ShortType), 0.toShort) + checkEvaluation(Literal.default(IntegerType), 0) + checkEvaluation(Literal.default(LongType), 0L) + checkEvaluation(Literal.default(FloatType), 0.0f) + checkEvaluation(Literal.default(DoubleType), 0.0) + checkEvaluation(Literal.default(StringType), "") + checkEvaluation(Literal.default(BinaryType), "".getBytes) + checkEvaluation(Literal.default(DecimalType.USER_DEFAULT), Decimal(0)) + checkEvaluation(Literal.default(DecimalType.SYSTEM_DEFAULT), Decimal(0)) + checkEvaluation(Literal.default(DateType), DateTimeUtils.toJavaDate(0)) + checkEvaluation(Literal.default(TimestampType), DateTimeUtils.toJavaTimestamp(0L)) + checkEvaluation(Literal.default(CalendarIntervalType), new CalendarInterval(0, 0L)) + checkEvaluation(Literal.default(ArrayType(StringType)), Array()) + checkEvaluation(Literal.default(MapType(IntegerType, StringType)), Map()) + checkEvaluation(Literal.default(StructType(StructField("a", StringType) :: Nil)), Row("")) + } + test("boolean literals") { checkEvaluation(Literal(true), true) checkEvaluation(Literal(false), false) @@ -83,12 +109,14 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { } test("decimal") { - List(0.0, 1.2, 1.1111, 5).foreach { d => + List(-0.0001, 0.0, 0.001, 1.2, 1.1111, 5).foreach { d => checkEvaluation(Literal(Decimal(d)), Decimal(d)) checkEvaluation(Literal(Decimal(d.toInt)), Decimal(d.toInt)) checkEvaluation(Literal(Decimal(d.toLong)), Decimal(d.toLong)) - checkEvaluation(Literal(Decimal((d * 1000L).toLong, 10, 1)), - Decimal((d * 1000L).toLong, 10, 1)) + checkEvaluation(Literal(Decimal((d * 1000L).toLong, 10, 3)), + Decimal((d * 1000L).toLong, 10, 3)) + 
checkEvaluation(Literal(BigDecimal(d.toString)), Decimal(d)) + checkEvaluation(Literal(new java.math.BigDecimal(d.toString)), Decimal(d)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala new file mode 100644 index 000000000000..d9c91415e249 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.sql.{Date, Timestamp} + +import org.scalacheck.{Arbitrary, Gen} +import org.scalatest.Matchers +import org.scalatest.prop.GeneratorDrivenPropertyChecks + +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval + +/** + * Property is a high-level specification of behavior that should hold for a range of data points. + * + * For example, while we are evaluating a deterministic expression for some input, we should always + * hold the property that the result never changes, regardless of how we get the result, + * via interpreted or codegen. + * + * In ScalaTest, properties are specified as functions and the data points used to check properties + * can be supplied by either tables or generators. + * + * Generator-driven property checks are performed via integration with ScalaCheck. 
+ * + * @example {{{ + * def toTest(i: Int): Boolean = if (i % 2 == 0) true else false + * + * import org.scalacheck.Gen + * + * test ("true if param is even") { + * val evenInts = for (n <- Gen.choose(-1000, 1000)) yield 2 * n + * forAll(evenInts) { (i: Int) => + * assert (toTest(i) === true) + * } + * } + * }}} + * + */ +object LiteralGenerator { + + lazy val byteLiteralGen: Gen[Literal] = + for { b <- Arbitrary.arbByte.arbitrary } yield Literal.create(b, ByteType) + + lazy val shortLiteralGen: Gen[Literal] = + for { s <- Arbitrary.arbShort.arbitrary } yield Literal.create(s, ShortType) + + lazy val integerLiteralGen: Gen[Literal] = + for { i <- Arbitrary.arbInt.arbitrary } yield Literal.create(i, IntegerType) + + lazy val longLiteralGen: Gen[Literal] = + for { l <- Arbitrary.arbLong.arbitrary } yield Literal.create(l, LongType) + + lazy val floatLiteralGen: Gen[Literal] = + for { + f <- Gen.chooseNum(Float.MinValue / 2, Float.MaxValue / 2, + Float.NaN, Float.PositiveInfinity, Float.NegativeInfinity) + } yield Literal.create(f, FloatType) + + lazy val doubleLiteralGen: Gen[Literal] = + for { + f <- Gen.chooseNum(Double.MinValue / 2, Double.MaxValue / 2, + Double.NaN, Double.PositiveInfinity, Double.NegativeInfinity) + } yield Literal.create(f, DoubleType) + + // TODO cache the generated data + def decimalLiteralGen(precision: Int, scale: Int): Gen[Literal] = { + assert(scale >= 0) + assert(precision >= scale) + Arbitrary.arbBigInt.arbitrary.map { s => + val a = (s % BigInt(10).pow(precision - scale)).toString() + val b = (s % BigInt(10).pow(scale)).abs.toString() + Literal.create( + Decimal(BigDecimal(s"$a.$b"), precision, scale), + DecimalType(precision, scale)) + } + } + + lazy val stringLiteralGen: Gen[Literal] = + for { s <- Arbitrary.arbString.arbitrary } yield Literal.create(s, StringType) + + lazy val binaryLiteralGen: Gen[Literal] = + for { ab <- Gen.listOf[Byte](Arbitrary.arbByte.arbitrary) } + yield Literal.create(ab.toArray, BinaryType) + + lazy val booleanLiteralGen: Gen[Literal] = + for { b <- Arbitrary.arbBool.arbitrary } yield Literal.create(b, BooleanType) + + lazy val dateLiteralGen: Gen[Literal] = + for { d <- Arbitrary.arbInt.arbitrary } yield Literal.create(new Date(d), DateType) + + lazy val timestampLiteralGen: Gen[Literal] = + for { t <- Arbitrary.arbLong.arbitrary } yield Literal.create(new Timestamp(t), TimestampType) + + lazy val calendarIntervalLiterGen: Gen[Literal] = + for { m <- Arbitrary.arbInt.arbitrary; s <- Arbitrary.arbLong.arbitrary} + yield Literal.create(new CalendarInterval(m, s), CalendarIntervalType) + + + // Sometimes, it would be quite expensive when unlimited value is used, + // for example, the `times` arguments for StringRepeat would hang the test 'forever' + // if it's tested against Int.MaxValue by ScalaCheck, therefore, use values from a limited + // range is more reasonable + lazy val limitedIntegerLiteralGen: Gen[Literal] = + for { i <- Gen.choose(-100, 100) } yield Literal.create(i, IntegerType) + + def randomGen(dt: DataType): Gen[Literal] = { + dt match { + case ByteType => byteLiteralGen + case ShortType => shortLiteralGen + case IntegerType => integerLiteralGen + case LongType => longLiteralGen + case DoubleType => doubleLiteralGen + case FloatType => floatLiteralGen + case DateType => dateLiteralGen + case TimestampType => timestampLiteralGen + case BooleanType => booleanLiteralGen + case StringType => stringLiteralGen + case BinaryType => binaryLiteralGen + case CalendarIntervalType => calendarIntervalLiterGen + case 
DecimalType.Fixed(precision, scale) => decimalLiteralGen(precision, scale) + case dt => throw new IllegalArgumentException(s"not supported type $dt") + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 649a5b44dc03..4ad65db0977c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -20,15 +20,17 @@ package org.apache.spark.sql.catalyst.expressions import com.google.common.math.LongMath import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection +import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} import org.apache.spark.sql.types._ - class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { + import IntegralLiteralTestUtils._ + /** * Used for testing leaf math expressions. * @@ -148,7 +150,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { expression: Expression, inputRow: InternalRow = EmptyRow): Unit = { val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation) - val optimizedPlan = defaultOptimizer.execute(plan) + val optimizedPlan = DefaultOptimizer.execute(plan) checkNaNWithoutCodegen(optimizedPlan.expressions.head, inputRow) } @@ -181,60 +183,84 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("sin") { testUnary(Sin, math.sin) + checkConsistencyBetweenInterpretedAndCodegen(Sin, DoubleType) } test("asin") { testUnary(Asin, math.asin, (-10 to 10).map(_ * 0.1)) testUnary(Asin, math.asin, (11 to 20).map(_ * 0.1), expectNaN = true) + checkConsistencyBetweenInterpretedAndCodegen(Asin, DoubleType) } test("sinh") { testUnary(Sinh, math.sinh) + checkConsistencyBetweenInterpretedAndCodegen(Sinh, DoubleType) } test("cos") { testUnary(Cos, math.cos) + checkConsistencyBetweenInterpretedAndCodegen(Cos, DoubleType) } test("acos") { testUnary(Acos, math.acos, (-10 to 10).map(_ * 0.1)) testUnary(Acos, math.acos, (11 to 20).map(_ * 0.1), expectNaN = true) + checkConsistencyBetweenInterpretedAndCodegen(Acos, DoubleType) } test("cosh") { testUnary(Cosh, math.cosh) + checkConsistencyBetweenInterpretedAndCodegen(Cosh, DoubleType) } test("tan") { testUnary(Tan, math.tan) + checkConsistencyBetweenInterpretedAndCodegen(Tan, DoubleType) } test("atan") { testUnary(Atan, math.atan) + checkConsistencyBetweenInterpretedAndCodegen(Atan, DoubleType) } test("tanh") { testUnary(Tanh, math.tanh) + checkConsistencyBetweenInterpretedAndCodegen(Tanh, DoubleType) } test("toDegrees") { testUnary(ToDegrees, math.toDegrees) + checkConsistencyBetweenInterpretedAndCodegen(ToDegrees, DoubleType) } test("toRadians") { testUnary(ToRadians, math.toRadians) + checkConsistencyBetweenInterpretedAndCodegen(ToRadians, DoubleType) } test("cbrt") { testUnary(Cbrt, math.cbrt) + checkConsistencyBetweenInterpretedAndCodegen(Cbrt, DoubleType) } test("ceil") { - testUnary(Ceil, math.ceil) + testUnary(Ceil,
(d: Double) => math.ceil(d).toLong) + checkConsistencyBetweenInterpretedAndCodegen(Ceil, DoubleType) + + testUnary(Ceil, (d: Decimal) => d.ceil, (-20 to 20).map(x => Decimal(x * 0.1))) + checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 3)) + checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 0)) + checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(5, 0)) } test("floor") { - testUnary(Floor, math.floor) + testUnary(Floor, (d: Double) => math.floor(d).toLong) + checkConsistencyBetweenInterpretedAndCodegen(Floor, DoubleType) + + testUnary(Floor, (d: Decimal) => d.floor, (-20 to 20).map(x => Decimal(x * 0.1))) + checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 3)) + checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 0)) + checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(5, 0)) } test("factorial") { @@ -244,37 +270,45 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal.create(null, IntegerType), null, create_row(null)) checkEvaluation(Factorial(Literal(20)), 2432902008176640000L, EmptyRow) checkEvaluation(Factorial(Literal(21)), null, EmptyRow) + checkConsistencyBetweenInterpretedAndCodegen(Factorial.apply _, IntegerType) } test("rint") { testUnary(Rint, math.rint) + checkConsistencyBetweenInterpretedAndCodegen(Rint, DoubleType) } test("exp") { testUnary(Exp, math.exp) + checkConsistencyBetweenInterpretedAndCodegen(Exp, DoubleType) } test("expm1") { testUnary(Expm1, math.expm1) + checkConsistencyBetweenInterpretedAndCodegen(Expm1, DoubleType) } test("signum") { testUnary[Double, Double](Signum, math.signum) + checkConsistencyBetweenInterpretedAndCodegen(Signum, DoubleType) } test("log") { testUnary(Log, math.log, (1 to 20).map(_ * 0.1)) testUnary(Log, math.log, (-5 to 0).map(_ * 0.1), expectNull = true) + checkConsistencyBetweenInterpretedAndCodegen(Log, DoubleType) } test("log10") { testUnary(Log10, math.log10, (1 to 20).map(_ * 0.1)) testUnary(Log10, math.log10, (-5 to 0).map(_ * 0.1), expectNull = true) + checkConsistencyBetweenInterpretedAndCodegen(Log10, DoubleType) } test("log1p") { testUnary(Log1p, math.log1p, (0 to 20).map(_ * 0.1)) testUnary(Log1p, math.log1p, (-10 to -1).map(_ * 1.0), expectNull = true) + checkConsistencyBetweenInterpretedAndCodegen(Log1p, DoubleType) } test("bin") { @@ -292,12 +326,18 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Bin(l3), java.lang.Long.toBinaryString(123), row) checkEvaluation(Bin(l4), java.lang.Long.toBinaryString(1234), row) checkEvaluation(Bin(l5), java.lang.Long.toBinaryString(-123), row) + + checkEvaluation(Bin(positiveLongLit), java.lang.Long.toBinaryString(positiveLong)) + checkEvaluation(Bin(negativeLongLit), java.lang.Long.toBinaryString(negativeLong)) + + checkConsistencyBetweenInterpretedAndCodegen(Bin, LongType) } test("log2") { def f: (Double) => Double = (x: Double) => math.log(x) / math.log(2) testUnary(Log2, f, (1 to 20).map(_ * 0.1)) testUnary(Log2, f, (-5 to 0).map(_ * 1.0), expectNull = true) + checkConsistencyBetweenInterpretedAndCodegen(Log2, DoubleType) } test("sqrt") { @@ -307,11 +347,13 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null)) checkNaN(Sqrt(Literal(-1.0)), EmptyRow) checkNaN(Sqrt(Literal(-1.5)), EmptyRow) + checkConsistencyBetweenInterpretedAndCodegen(Sqrt, DoubleType) } test("pow") { testBinary(Pow, math.pow, (-5 to 
5).map(v => (v * 1.0, v * 1.0))) testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNaN = true) + checkConsistencyBetweenInterpretedAndCodegen(Pow, DoubleType, DoubleType) } test("shift left") { @@ -323,6 +365,18 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(ShiftLeft(Literal(21.toLong), Literal(1)), 42.toLong) checkEvaluation(ShiftLeft(Literal(-21.toLong), Literal(1)), -42.toLong) + + checkEvaluation(ShiftLeft(positiveIntLit, positiveIntLit), positiveInt << positiveInt) + checkEvaluation(ShiftLeft(positiveIntLit, negativeIntLit), positiveInt << negativeInt) + checkEvaluation(ShiftLeft(negativeIntLit, positiveIntLit), negativeInt << positiveInt) + checkEvaluation(ShiftLeft(negativeIntLit, negativeIntLit), negativeInt << negativeInt) + checkEvaluation(ShiftLeft(positiveLongLit, positiveIntLit), positiveLong << positiveInt) + checkEvaluation(ShiftLeft(positiveLongLit, negativeIntLit), positiveLong << negativeInt) + checkEvaluation(ShiftLeft(negativeLongLit, positiveIntLit), negativeLong << positiveInt) + checkEvaluation(ShiftLeft(negativeLongLit, negativeIntLit), negativeLong << negativeInt) + + checkConsistencyBetweenInterpretedAndCodegen(ShiftLeft, IntegerType, IntegerType) + checkConsistencyBetweenInterpretedAndCodegen(ShiftLeft, LongType, IntegerType) } test("shift right") { @@ -334,6 +388,18 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(ShiftRight(Literal(42.toLong), Literal(1)), 21.toLong) checkEvaluation(ShiftRight(Literal(-42.toLong), Literal(1)), -21.toLong) + + checkEvaluation(ShiftRight(positiveIntLit, positiveIntLit), positiveInt >> positiveInt) + checkEvaluation(ShiftRight(positiveIntLit, negativeIntLit), positiveInt >> negativeInt) + checkEvaluation(ShiftRight(negativeIntLit, positiveIntLit), negativeInt >> positiveInt) + checkEvaluation(ShiftRight(negativeIntLit, negativeIntLit), negativeInt >> negativeInt) + checkEvaluation(ShiftRight(positiveLongLit, positiveIntLit), positiveLong >> positiveInt) + checkEvaluation(ShiftRight(positiveLongLit, negativeIntLit), positiveLong >> negativeInt) + checkEvaluation(ShiftRight(negativeLongLit, positiveIntLit), negativeLong >> positiveInt) + checkEvaluation(ShiftRight(negativeLongLit, negativeIntLit), negativeLong >> negativeInt) + + checkConsistencyBetweenInterpretedAndCodegen(ShiftRight, IntegerType, IntegerType) + checkConsistencyBetweenInterpretedAndCodegen(ShiftRight, LongType, IntegerType) } test("shift right unsigned") { @@ -345,6 +411,26 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(ShiftRightUnsigned(Literal(42.toLong), Literal(1)), 21.toLong) checkEvaluation(ShiftRightUnsigned(Literal(-42.toLong), Literal(1)), 9223372036854775787L) + + checkEvaluation(ShiftRightUnsigned(positiveIntLit, positiveIntLit), + positiveInt >>> positiveInt) + checkEvaluation(ShiftRightUnsigned(positiveIntLit, negativeIntLit), + positiveInt >>> negativeInt) + checkEvaluation(ShiftRightUnsigned(negativeIntLit, positiveIntLit), + negativeInt >>> positiveInt) + checkEvaluation(ShiftRightUnsigned(negativeIntLit, negativeIntLit), + negativeInt >>> negativeInt) + checkEvaluation(ShiftRightUnsigned(positiveLongLit, positiveIntLit), + positiveLong >>> positiveInt) + checkEvaluation(ShiftRightUnsigned(positiveLongLit, negativeIntLit), + positiveLong >>> negativeInt) + checkEvaluation(ShiftRightUnsigned(negativeLongLit, positiveIntLit), + negativeLong >>> positiveInt) + 
checkEvaluation(ShiftRightUnsigned(negativeLongLit, negativeIntLit), + negativeLong >>> negativeInt) + + checkConsistencyBetweenInterpretedAndCodegen(ShiftRightUnsigned, IntegerType, IntegerType) + checkConsistencyBetweenInterpretedAndCodegen(ShiftRightUnsigned, LongType, IntegerType) } test("hex") { @@ -359,6 +445,9 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { // Turn off scala style for non-ascii chars checkEvaluation(Hex(Literal("三重的".getBytes("UTF8"))), "E4B889E9878DE79A84") // scalastyle:on + Seq(LongType, BinaryType, StringType).foreach { dt => + checkConsistencyBetweenInterpretedAndCodegen(Hex.apply _, dt) + } } test("unhex") { @@ -372,16 +461,18 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { // Turn off scala style for non-ascii chars checkEvaluation(Unhex(Literal("E4B889E9878DE79A84")), "三重的".getBytes("UTF-8")) checkEvaluation(Unhex(Literal("三重的")), null) - // scalastyle:on + checkConsistencyBetweenInterpretedAndCodegen(Unhex, StringType) } test("hypot") { testBinary(Hypot, math.hypot) + checkConsistencyBetweenInterpretedAndCodegen(Hypot, DoubleType, DoubleType) } test("atan2") { testBinary(Atan2, math.atan2) + checkConsistencyBetweenInterpretedAndCodegen(Atan2, DoubleType, DoubleType) } test("binary log") { @@ -413,6 +504,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { Logarithm(Literal(1.0), Literal(-1.0)), null, create_row(null)) + checkConsistencyBetweenInterpretedAndCodegen(Logarithm, DoubleType, DoubleType) } test("round") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala index b524d0af14a6..75d17417e5a0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala @@ -29,6 +29,7 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Md5(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)), "6ac1e56bc78f031059be7be854522c4c") checkEvaluation(Md5(Literal.create(null, BinaryType)), null) + checkConsistencyBetweenInterpretedAndCodegen(Md5, BinaryType) } test("sha1") { @@ -37,6 +38,7 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { "5d211bad8f4ee70e16c7d343a838fc344a1ed961") checkEvaluation(Sha1(Literal.create(null, BinaryType)), null) checkEvaluation(Sha1(Literal("".getBytes)), "da39a3ee5e6b4b0d3255bfef95601890afd80709") + checkConsistencyBetweenInterpretedAndCodegen(Sha1, BinaryType) } test("sha2") { @@ -55,6 +57,6 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Crc32(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)), 2180413220L) checkEvaluation(Crc32(Literal.create(null, BinaryType)), null) + checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType) } - } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala index 31ecf4a9e810..118fd695fe2f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala @@ -26,8 +26,7 @@ import org.apache.spark.sql.types._ * A 
literal value that is not foldable. Used in expression codegen testing to test code path * that behave differently based on foldable values. */ -case class NonFoldableLiteral(value: Any, dataType: DataType) - extends LeafExpression with CodegenFallback { +case class NonFoldableLiteral(value: Any, dataType: DataType) extends LeafExpression { override def foldable: Boolean = false override def nullable: Boolean = true diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala index bf197124d8db..ace6c15dc841 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala @@ -77,7 +77,7 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } - test("AtLeastNNonNullNans") { + test("AtLeastNNonNulls") { val mix = Seq(Literal("x"), Literal.create(null, StringType), Literal.create(null, DoubleType), @@ -96,46 +96,11 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { Literal(Float.MaxValue), Literal(false)) - checkEvaluation(AtLeastNNonNullNans(0, mix), true, EmptyRow) - checkEvaluation(AtLeastNNonNullNans(2, mix), true, EmptyRow) - checkEvaluation(AtLeastNNonNullNans(3, mix), false, EmptyRow) - checkEvaluation(AtLeastNNonNullNans(0, nanOnly), true, EmptyRow) - checkEvaluation(AtLeastNNonNullNans(3, nanOnly), true, EmptyRow) - checkEvaluation(AtLeastNNonNullNans(4, nanOnly), false, EmptyRow) - checkEvaluation(AtLeastNNonNullNans(0, nullOnly), true, EmptyRow) - checkEvaluation(AtLeastNNonNullNans(3, nullOnly), true, EmptyRow) - checkEvaluation(AtLeastNNonNullNans(4, nullOnly), false, EmptyRow) - } - - test("AtLeastNNull") { - val mix = Seq(Literal("x"), - Literal.create(null, StringType), - Literal.create(null, DoubleType), - Literal(Double.NaN), - Literal(5f)) - - val nanOnly = Seq(Literal("x"), - Literal(10.0), - Literal(Float.NaN), - Literal(math.log(-2)), - Literal(Double.MaxValue)) - - val nullOnly = Seq(Literal("x"), - Literal.create(null, DoubleType), - Literal.create(null, DecimalType.USER_DEFAULT), - Literal(Float.MaxValue), - Literal(false)) - - checkEvaluation(AtLeastNNulls(0, mix), true, EmptyRow) - checkEvaluation(AtLeastNNulls(1, mix), true, EmptyRow) - checkEvaluation(AtLeastNNulls(2, mix), true, EmptyRow) - checkEvaluation(AtLeastNNulls(3, mix), false, EmptyRow) - checkEvaluation(AtLeastNNulls(0, nanOnly), true, EmptyRow) - checkEvaluation(AtLeastNNulls(1, nanOnly), false, EmptyRow) - checkEvaluation(AtLeastNNulls(2, nanOnly), false, EmptyRow) - checkEvaluation(AtLeastNNulls(0, nullOnly), true, EmptyRow) - checkEvaluation(AtLeastNNulls(1, nullOnly), true, EmptyRow) - checkEvaluation(AtLeastNNulls(2, nullOnly), true, EmptyRow) - checkEvaluation(AtLeastNNulls(3, nullOnly), false, EmptyRow) + checkEvaluation(AtLeastNNonNulls(2, mix), true, EmptyRow) + checkEvaluation(AtLeastNNonNulls(3, mix), false, EmptyRow) + checkEvaluation(AtLeastNNonNulls(3, nanOnly), true, EmptyRow) + checkEvaluation(AtLeastNNonNulls(4, nanOnly), false, EmptyRow) + checkEvaluation(AtLeastNNonNulls(3, nullOnly), true, EmptyRow) + checkEvaluation(AtLeastNNonNulls(4, nullOnly), false, EmptyRow) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala new file mode 100644 index 000000000000..7ad8657bde12 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import scala.math._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{Row, RandomDataGenerator} +import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering +import org.apache.spark.sql.types._ + +class OrderingSuite extends SparkFunSuite with ExpressionEvalHelper { + + def compareArrays(a: Seq[Any], b: Seq[Any], expected: Int): Unit = { + test(s"compare two arrays: a = $a, b = $b") { + val dataType = ArrayType(IntegerType) + val rowType = StructType(StructField("array", dataType, nullable = true) :: Nil) + val toCatalyst = CatalystTypeConverters.createToCatalystConverter(rowType) + val rowA = toCatalyst(Row(a)).asInstanceOf[InternalRow] + val rowB = toCatalyst(Row(b)).asInstanceOf[InternalRow] + Seq(Ascending, Descending).foreach { direction => + val sortOrder = direction match { + case Ascending => BoundReference(0, dataType, nullable = true).asc + case Descending => BoundReference(0, dataType, nullable = true).desc + } + val expectedCompareResult = direction match { + case Ascending => signum(expected) + case Descending => -1 * signum(expected) + } + val intOrdering = new InterpretedOrdering(sortOrder :: Nil) + val genOrdering = GenerateOrdering.generate(sortOrder :: Nil) + Seq(intOrdering, genOrdering).foreach { ordering => + assert(ordering.compare(rowA, rowA) === 0) + assert(ordering.compare(rowB, rowB) === 0) + assert(signum(ordering.compare(rowA, rowB)) === expectedCompareResult) + assert(signum(ordering.compare(rowB, rowA)) === -1 * expectedCompareResult) + } + } + } + } + + // Two arrays have the same size. + compareArrays(Seq[Any](), Seq[Any](), 0) + compareArrays(Seq[Any](1), Seq[Any](1), 0) + compareArrays(Seq[Any](1, 2), Seq[Any](1, 2), 0) + compareArrays(Seq[Any](1, 2, 2), Seq[Any](1, 2, 3), -1) + + // Two arrays have different sizes. + compareArrays(Seq[Any](), Seq[Any](1), -1) + compareArrays(Seq[Any](1, 2, 3), Seq[Any](1, 2, 3, 4), -1) + compareArrays(Seq[Any](1, 2, 3), Seq[Any](1, 2, 3, 2), -1) + compareArrays(Seq[Any](1, 2, 3), Seq[Any](1, 2, 2, 2), 1) + + // Arrays having nulls. 
+ compareArrays(Seq[Any](1, 2, 3), Seq[Any](1, 2, 3, null), -1) + compareArrays(Seq[Any](), Seq[Any](null), -1) + compareArrays(Seq[Any](null), Seq[Any](null), 0) + compareArrays(Seq[Any](null, null), Seq[Any](null, null), 0) + compareArrays(Seq[Any](null), Seq[Any](null, null), -1) + compareArrays(Seq[Any](null), Seq[Any](1), -1) + compareArrays(Seq[Any](null), Seq[Any](null, 1), -1) + compareArrays(Seq[Any](null, 1), Seq[Any](1, 1), -1) + compareArrays(Seq[Any](1, null, 1), Seq[Any](1, null, 1), 0) + compareArrays(Seq[Any](1, null, 1), Seq[Any](1, null, 2), -1) + + // Test GenerateOrdering for all common types. For each type, we construct random input rows that + // contain two columns of that type, then for pairs of randomly-generated rows we check that + // GenerateOrdering agrees with RowOrdering. + { + val structType = + new StructType() + .add("f1", FloatType, nullable = true) + .add("f2", ArrayType(BooleanType, containsNull = true), nullable = true) + val arrayOfStructType = ArrayType(structType) + val complexTypes = ArrayType(IntegerType) :: structType :: arrayOfStructType :: Nil + (DataTypeTestUtils.atomicTypes ++ complexTypes ++ Set(NullType)).foreach { dataType => + test(s"GenerateOrdering with $dataType") { + val rowOrdering = InterpretedOrdering.forSchema(Seq(dataType, dataType)) + val genOrdering = GenerateOrdering.generate( + BoundReference(0, dataType, nullable = true).asc :: + BoundReference(1, dataType, nullable = true).asc :: Nil) + val rowType = StructType( + StructField("a", dataType, nullable = true) :: + StructField("b", dataType, nullable = true) :: Nil) + val maybeDataGenerator = RandomDataGenerator.forType(rowType, nullable = false) + assume(maybeDataGenerator.isDefined) + val randGenerator = maybeDataGenerator.get + val toCatalyst = CatalystTypeConverters.createToCatalystConverter(rowType) + for (_ <- 1 to 50) { + val a = toCatalyst(randGenerator()).asInstanceOf[InternalRow] + val b = toCatalyst(randGenerator()).asInstanceOf[InternalRow] + withClue(s"a = $a, b = $b") { + assert(genOrdering.compare(a, a) === 0) + assert(genOrdering.compare(b, b) === 0) + assert(rowOrdering.compare(a, a) === 0) + assert(rowOrdering.compare(b, b) === 0) + assert(signum(genOrdering.compare(a, b)) === -1 * signum(genOrdering.compare(b, a))) + assert(signum(rowOrdering.compare(a, b)) === -1 * signum(rowOrdering.compare(b, a))) + assert( + signum(rowOrdering.compare(a, b)) === signum(genOrdering.compare(a, b)), + "Generated and non-generated orderings should agree") + } + } + } + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index d7eb13c50b13..03e7611fce8f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.immutable.HashSet import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.{Decimal, DoubleType, IntegerType, BooleanType} +import org.apache.spark.sql.RandomDataGenerator +import org.apache.spark.sql.types._ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -72,6 +72,16 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { notTrueTable.foreach { case (v, answer) => 
checkEvaluation(Not(Literal.create(v, BooleanType)), answer) } + checkConsistencyBetweenInterpretedAndCodegen(Not, BooleanType) + } + + test("AND, OR, EqualTo, EqualNullSafe consistency check") { + checkConsistencyBetweenInterpretedAndCodegen(And, BooleanType, BooleanType) + checkConsistencyBetweenInterpretedAndCodegen(Or, BooleanType, BooleanType) + DataTypeTestUtils.propertyCheckSupported.foreach { dt => + checkConsistencyBetweenInterpretedAndCodegen(EqualTo, dt, dt) + checkConsistencyBetweenInterpretedAndCodegen(EqualNullSafe, dt, dt) + } } booleanLogicTest("AND", And, @@ -108,6 +118,12 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { (null, null, null) :: Nil) test("IN") { + checkEvaluation(In(Literal.create(null, IntegerType), Seq(Literal(1), Literal(2))), null) + checkEvaluation(In(Literal.create(null, IntegerType), Seq(Literal.create(null, IntegerType))), + null) + checkEvaluation(In(Literal(1), Seq(Literal.create(null, IntegerType))), null) + checkEvaluation(In(Literal(1), Seq(Literal(1), Literal.create(null, IntegerType))), true) + checkEvaluation(In(Literal(2), Seq(Literal(1), Literal.create(null, IntegerType))), null) checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true) checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true) checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false) @@ -115,9 +131,38 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1), Literal(2)))), true) - checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"))), true) + val ns = Literal.create(null, StringType) + checkEvaluation(In(ns, Seq(Literal("1"), Literal("2"))), null) + checkEvaluation(In(ns, Seq(ns)), null) + checkEvaluation(In(Literal("a"), Seq(ns)), null) + checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"), ns)), true) checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true) checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false) + + val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType, + LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType) + primitiveTypes.map { t => + val dataGen = RandomDataGenerator.forType(t, nullable = true).get + val inputData = Seq.fill(10) { + val value = dataGen.apply() + value match { + case d: Double if d.isNaN => 0.0d + case f: Float if f.isNaN => 0.0f + case _ => value + } + } + val input = inputData.map(Literal.create(_, t)) + val expected = if (inputData(0) == null) { + null + } else if (inputData.slice(1, 10).contains(inputData(0))) { + true + } else if (inputData.slice(1, 10).contains(null)) { + null + } else { + false + } + checkEvaluation(In(input(0), input.slice(1, 10)), expected) + } } test("INSET") { @@ -130,10 +175,35 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(InSet(one, hS), true) checkEvaluation(InSet(two, hS), true) checkEvaluation(InSet(two, nS), true) - checkEvaluation(InSet(nl, nS), true) checkEvaluation(InSet(three, hS), false) - checkEvaluation(InSet(three, nS), false) - checkEvaluation(And(InSet(one, hS), InSet(two, hS)), true) + checkEvaluation(InSet(three, nS), null) + checkEvaluation(InSet(nl, hS), null) + checkEvaluation(InSet(nl, nS), null) + + val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType, + LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType) + 
primitiveTypes.map { t => + val dataGen = RandomDataGenerator.forType(t, nullable = true).get + val inputData = Seq.fill(10) { + val value = dataGen.apply() + value match { + case d: Double if d.isNaN => 0.0d + case f: Float if f.isNaN => 0.0f + case _ => value + } + } + val input = inputData.map(Literal(_)) + val expected = if (inputData(0) == null) { + null + } else if (inputData.slice(1, 10).contains(inputData(0))) { + true + } else if (inputData.slice(1, 10).contains(null)) { + null + } else { + false + } + checkEvaluation(InSet(input(0), inputData.slice(1, 10).toSet), expected) + } } private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d, false).map(Literal(_)) @@ -145,6 +215,15 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { private val equalValues2 = Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN, true).map(Literal(_)) + test("BinaryComparison consistency check") { + DataTypeTestUtils.ordered.foreach { dt => + checkConsistencyBetweenInterpretedAndCodegen(LessThan, dt, dt) + checkConsistencyBetweenInterpretedAndCodegen(LessThanOrEqual, dt, dt) + checkConsistencyBetweenInterpretedAndCodegen(GreaterThan, dt, dt) + checkConsistencyBetweenInterpretedAndCodegen(GreaterThanOrEqual, dt, dt) + } + } + test("BinaryComparison: lessThan") { for (i <- 0 until smallValues.length) { checkEvaluation(LessThan(smallValues(i), largeValues(i)), true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala index 4a644d136f09..b7a0d44fa7e5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala @@ -24,12 +24,12 @@ import org.apache.spark.SparkFunSuite class RandomSuite extends SparkFunSuite with ExpressionEvalHelper { test("random") { - checkDoubleEvaluation(Rand(30), 0.7363714192755834 +- 0.001) - checkDoubleEvaluation(Randn(30), 0.5181478766595276 +- 0.001) + checkDoubleEvaluation(Rand(30), 0.31429268272540556 +- 0.001) + checkDoubleEvaluation(Randn(30), -0.4798519469521663 +- 0.001) } test("SPARK-9127 codegen with long seed") { - checkDoubleEvaluation(Rand(5419823303878592871L), 0.4061913198963727 +- 0.001) - checkDoubleEvaluation(Randn(5419823303878592871L), -0.24417152005343168 +- 0.001) + checkDoubleEvaluation(Rand(5419823303878592871L), 0.2304755080444375 +- 0.001) + checkDoubleEvaluation(Randn(5419823303878592871L), -1.2824262718225607 +- 0.001) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 906be701beed..99e3b13ce8c9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -431,6 +431,20 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(SoundEx(Literal("!!")), "!!") } + test("translate") { + checkEvaluation( + StringTranslate(Literal("translate"), Literal("rnlt"), Literal("123")), "1a2s3ae") + checkEvaluation(StringTranslate(Literal("translate"), Literal(""), Literal("123")), "translate") + checkEvaluation(StringTranslate(Literal("translate"), Literal("rnlt"), Literal("")), "asae") + // 
test for multiple mapping + checkEvaluation(StringTranslate(Literal("abcd"), Literal("aba"), Literal("123")), "12cd") + checkEvaluation(StringTranslate(Literal("abcd"), Literal("aba"), Literal("12")), "12cd") + // scalastyle:off + // non ascii characters are not allowed in the source code, so we disable the scalastyle. + checkEvaluation(StringTranslate(Literal("花花世界"), Literal("花界"), Literal("ab")), "aa世b") + // scalastyle:on + } + test("TRIM/LTRIM/RTRIM") { val s = 'a.string.at(0) checkEvaluation(StringTrim(Literal(" aa ")), "aa", create_row(" abdef ")) @@ -659,7 +673,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Length(Literal.create(null, BinaryType)), null, create_row(bytes)) } - test("number format") { + test("format_number / FormatNumber") { checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Byte]), Literal(3)), "4.000") checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Short]), Literal(3)), "4.000") checkEvaluation(FormatNumber(Literal(4.0f), Literal(3)), "4.000") @@ -675,4 +689,14 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal(3)), null) checkEvaluation(FormatNumber(Literal.create(null, NullType), Literal(3)), null) } + + test("find in set") { + checkEvaluation( + FindInSet(Literal.create(null, StringType), Literal.create(null, StringType)), null) + checkEvaluation(FindInSet(Literal("ab"), Literal.create(null, StringType)), null) + checkEvaluation(FindInSet(Literal.create(null, StringType), Literal("abc,b,ab,c,def")), null) + checkEvaluation(FindInSet(Literal("ab"), Literal("abc,b,ab,c,def")), 3) + checkEvaluation(FindInSet(Literal("abf"), Literal("abc,b,ab,c,def")), 0) + checkEvaluation(FindInSet(Literal("ab,"), Literal("abc,b,ab,c,def")), 0) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala new file mode 100644 index 000000000000..a61297b2c039 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types.IntegerType + +class SubexpressionEliminationSuite extends SparkFunSuite { + test("Semantic equals and hash") { + val id = ExprId(1) + val a: AttributeReference = AttributeReference("name", IntegerType)() + val b1 = a.withName("name2").withExprId(id) + val b2 = a.withExprId(id) + val b3 = a.withQualifiers("qualifierName" :: Nil) + + assert(b1 != b2) + assert(a != b1) + assert(b1.semanticEquals(b2)) + assert(!b1.semanticEquals(a)) + assert(a.hashCode != b1.hashCode) + assert(b1.hashCode != b2.hashCode) + assert(b1.semanticHash() == b2.semanticHash()) + assert(a != b3) + assert(a.hashCode != b3.hashCode) + assert(a.semanticEquals(b3)) + } + + test("Expression Equivalence - basic") { + val equivalence = new EquivalentExpressions + assert(equivalence.getAllEquivalentExprs.isEmpty) + + val oneA = Literal(1) + val oneB = Literal(1) + val twoA = Literal(2) + var twoB = Literal(2) + + assert(equivalence.getEquivalentExprs(oneA).isEmpty) + assert(equivalence.getEquivalentExprs(twoA).isEmpty) + + // Add oneA and test if it is returned. Since it is a group of one, it does not. + assert(!equivalence.addExpr(oneA)) + assert(equivalence.getEquivalentExprs(oneA).size == 1) + assert(equivalence.getEquivalentExprs(twoA).isEmpty) + assert(equivalence.addExpr((oneA))) + assert(equivalence.getEquivalentExprs(oneA).size == 2) + + // Add B and make sure they can see each other. + assert(equivalence.addExpr(oneB)) + // Use exists and reference equality because of how equals is defined. + assert(equivalence.getEquivalentExprs(oneA).exists(_ eq oneB)) + assert(equivalence.getEquivalentExprs(oneA).exists(_ eq oneA)) + assert(equivalence.getEquivalentExprs(oneB).exists(_ eq oneA)) + assert(equivalence.getEquivalentExprs(oneB).exists(_ eq oneB)) + assert(equivalence.getEquivalentExprs(twoA).isEmpty) + assert(equivalence.getAllEquivalentExprs.size == 1) + assert(equivalence.getAllEquivalentExprs.head.size == 3) + assert(equivalence.getAllEquivalentExprs.head.contains(oneA)) + assert(equivalence.getAllEquivalentExprs.head.contains(oneB)) + + val add1 = Add(oneA, oneB) + val add2 = Add(oneA, oneB) + + equivalence.addExpr(add1) + equivalence.addExpr(add2) + + assert(equivalence.getAllEquivalentExprs.size == 2) + assert(equivalence.getEquivalentExprs(add2).exists(_ eq add1)) + assert(equivalence.getEquivalentExprs(add2).size == 2) + assert(equivalence.getEquivalentExprs(add1).exists(_ eq add2)) + } + + test("Expression Equivalence - Trees") { + val one = Literal(1) + val two = Literal(2) + + val add = Add(one, two) + val abs = Abs(add) + val add2 = Add(add, add) + + var equivalence = new EquivalentExpressions + equivalence.addExprTree(add, true) + equivalence.addExprTree(abs, true) + equivalence.addExprTree(add2, true) + + // Should only have one equivalence for `one + two` + assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).size == 1) + assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).head.size == 4) + + // Set up the expressions + // one * two, + // (one * two) * (one * two) + // sqrt( (one * two) * (one * two) ) + // (one * two) + sqrt( (one * two) * (one * two) ) + equivalence = new EquivalentExpressions + val mul = Multiply(one, two) + val mul2 = Multiply(mul, mul) + val sqrt = Sqrt(mul2) + val sum = Add(mul2, sqrt) + equivalence.addExprTree(mul, true) + equivalence.addExprTree(mul2, true) + equivalence.addExprTree(sqrt, true) + equivalence.addExprTree(sum, true) + + 
// (one * two), (one * two) * (one * two) and sqrt( (one * two) * (one * two) ) should be found + assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).size == 3) + assert(equivalence.getEquivalentExprs(mul).size == 3) + assert(equivalence.getEquivalentExprs(mul2).size == 3) + assert(equivalence.getEquivalentExprs(sqrt).size == 2) + assert(equivalence.getEquivalentExprs(sum).size == 1) + + // Some expressions inspired by TPCH-Q1 + // sum(l_quantity) as sum_qty, + // sum(l_extendedprice) as sum_base_price, + // sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, + // sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, + // avg(l_extendedprice) as avg_price, + // avg(l_discount) as avg_disc + equivalence = new EquivalentExpressions + val quantity = Literal(1) + val price = Literal(1.1) + val discount = Literal(.24) + val tax = Literal(0.1) + equivalence.addExprTree(quantity, false) + equivalence.addExprTree(price, false) + equivalence.addExprTree(Multiply(price, Subtract(Literal(1), discount)), false) + equivalence.addExprTree( + Multiply( + Multiply(price, Subtract(Literal(1), discount)), + Add(Literal(1), tax)), false) + equivalence.addExprTree(price, false) + equivalence.addExprTree(discount, false) + // quantity, price, discount and (price * (1 - discount)) + assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).size == 4) + } + + test("Expression equivalence - non deterministic") { + val sum = Add(Rand(0), Rand(0)) + val equivalence = new EquivalentExpressions + equivalence.addExpr(sum) + equivalence.addExpr(sum) + assert(equivalence.getAllEquivalentExprs.isEmpty) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 59491c5ba160..68545f33e546 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -18,13 +18,12 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} -import java.util.Arrays import org.scalatest.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.UTF8String @@ -43,7 +42,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { row.setInt(2, 2) val unsafeRow: UnsafeRow = converter.apply(row) - assert(converter.apply(row).getSizeInBytes === 8 + (3 * 8)) + assert(unsafeRow.getSizeInBytes === 8 + (3 * 8)) assert(unsafeRow.getLong(0) === 0) assert(unsafeRow.getLong(1) === 1) assert(unsafeRow.getInt(2) === 2) @@ -62,6 +61,13 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRowCopy.getLong(0) === 0) assert(unsafeRowCopy.getLong(1) === 1) assert(unsafeRowCopy.getInt(2) === 2) + + // Make sure the converter can be reused, i.e. we correctly reset all states. 
+ val unsafeRow2: UnsafeRow = converter.apply(row) + assert(unsafeRow2.getSizeInBytes === 8 + (3 * 8)) + assert(unsafeRow2.getLong(0) === 0) + assert(unsafeRow2.getLong(1) === 1) + assert(unsafeRow2.getInt(2) === 2) } test("basic conversion with primitive, string and binary types") { @@ -123,7 +129,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { DoubleType, StringType, BinaryType, - DecimalType.USER_DEFAULT + DecimalType.USER_DEFAULT, + DecimalType.SYSTEM_DEFAULT // ArrayType(IntegerType) ) val converter = UnsafeProjection.create(fieldTypes) @@ -151,6 +158,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(createdFromNull.getUTF8String(8) === null) assert(createdFromNull.getBinary(9) === null) assert(createdFromNull.getDecimal(10, 10, 0) === null) + assert(createdFromNull.getDecimal(11, 38, 18) === null) // assert(createdFromNull.get(11) === null) // If we have an UnsafeRow with columns that are initially non-null and we null out those @@ -169,11 +177,11 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { r.update(8, UTF8String.fromString("hello")) r.update(9, "world".getBytes) r.setDecimal(10, Decimal(10), 10) + r.setDecimal(11, Decimal(10.00, 38, 18), 38) // r.update(11, Array(11)) r } - // todo: we reuse the UnsafeRow in projection, so these tests are meaningless. val setToNullAfterCreation = converter.apply(rowWithNoNullColumns) assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0)) assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1)) @@ -187,13 +195,17 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(setToNullAfterCreation.getBinary(9) === rowWithNoNullColumns.getBinary(9)) assert(setToNullAfterCreation.getDecimal(10, 10, 0) === rowWithNoNullColumns.getDecimal(10, 10, 0)) - // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) + assert(setToNullAfterCreation.getDecimal(11, 38, 18) === + rowWithNoNullColumns.getDecimal(11, 38, 18)) for (i <- fieldTypes.indices) { - setToNullAfterCreation.setNullAt(i) + // Can't call setNullAt() on DecimalType; set it to a null Decimal via setDecimal instead + if (i == 11) { + setToNullAfterCreation.setDecimal(11, null, 38) + } else { + setToNullAfterCreation.setNullAt(i) + } } - // There are some garbage left in the var-length area - assert(Arrays.equals(createdFromNull.getBytes, setToNullAfterCreation.getBytes())) setToNullAfterCreation.setNullAt(0) setToNullAfterCreation.setBoolean(1, false) @@ -206,6 +218,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { // setToNullAfterCreation.update(8, UTF8String.fromString("hello")) // setToNullAfterCreation.update(9, "world".getBytes) setToNullAfterCreation.setDecimal(10, Decimal(10), 10) + setToNullAfterCreation.setDecimal(11, Decimal(10.00, 38, 18), 38) // setToNullAfterCreation.update(11, Array(11)) assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0)) @@ -220,6 +233,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { // assert(setToNullAfterCreation.get(9) === rowWithNoNullColumns.get(9)) assert(setToNullAfterCreation.getDecimal(10, 10, 0) === rowWithNoNullColumns.getDecimal(10, 10, 0)) + assert(setToNullAfterCreation.getDecimal(11, 38, 18) === + rowWithNoNullColumns.getDecimal(11, 38, 18)) // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) } @@ -238,107 +253,270 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(converter.apply(row1).getBytes ===
converter.apply(row2).getBytes) } + test("basic conversion with struct type") { + val fieldTypes: Array[DataType] = Array( + new StructType().add("i", IntegerType), + new StructType().add("nest", new StructType().add("l", LongType)) + ) + + val converter = UnsafeProjection.create(fieldTypes) + + val row = new GenericMutableRow(fieldTypes.length) + row.update(0, InternalRow(1)) + row.update(1, InternalRow(InternalRow(2L))) + + val unsafeRow: UnsafeRow = converter.apply(row) + assert(unsafeRow.numFields == 2) + + val row1 = unsafeRow.getStruct(0, 1) + assert(row1.getSizeInBytes == 8 + 1 * 8) + assert(row1.numFields == 1) + assert(row1.getInt(0) == 1) + + val row2 = unsafeRow.getStruct(1, 1) + assert(row2.numFields() == 1) + + val innerRow = row2.getStruct(0, 1) + + { + assert(innerRow.getSizeInBytes == 8 + 1 * 8) + assert(innerRow.numFields == 1) + assert(innerRow.getLong(0) == 2L) + } + + assert(row2.getSizeInBytes == 8 + 1 * 8 + innerRow.getSizeInBytes) + + assert(unsafeRow.getSizeInBytes == 8 + 2 * 8 + row1.getSizeInBytes + row2.getSizeInBytes) + } + + private def createArray(values: Any*): ArrayData = new GenericArrayData(values.toArray) + + private def createMap(keys: Any*)(values: Any*): MapData = { + assert(keys.length == values.length) + new ArrayBasedMapData(createArray(keys: _*), createArray(values: _*)) + } + + private def testArrayInt(array: UnsafeArrayData, values: Seq[Int]): Unit = { + assert(array.numElements == values.length) + assert(array.getSizeInBytes == 4 + (4 + 4) * values.length) + values.zipWithIndex.foreach { + case (value, index) => assert(array.getInt(index) == value) + } + } + + private def testMapInt(map: UnsafeMapData, keys: Seq[Int], values: Seq[Int]): Unit = { + assert(keys.length == values.length) + assert(map.numElements == keys.length) + + testArrayInt(map.keyArray, keys) + testArrayInt(map.valueArray, values) + + assert(map.getSizeInBytes == 4 + map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes) + } + test("basic conversion with array type") { val fieldTypes: Array[DataType] = Array( - ArrayType(LongType), - ArrayType(ArrayType(LongType)) + ArrayType(IntegerType), + ArrayType(ArrayType(IntegerType)) ) val converter = UnsafeProjection.create(fieldTypes) - val array1 = new GenericArrayData(Array[Any](1L, 2L)) - val array2 = new GenericArrayData(Array[Any](new GenericArrayData(Array[Any](3L, 4L)))) val row = new GenericMutableRow(fieldTypes.length) - row.update(0, array1) - row.update(1, array2) + row.update(0, createArray(1, 2)) + row.update(1, createArray(createArray(3, 4))) val unsafeRow: UnsafeRow = converter.apply(row) assert(unsafeRow.numFields() == 2) - val unsafeArray1 = unsafeRow.getArray(0).asInstanceOf[UnsafeArrayData] - assert(unsafeArray1.getSizeInBytes == 4 * 2 + 8 * 2) - assert(unsafeArray1.numElements() == 2) - assert(unsafeArray1.getLong(0) == 1L) - assert(unsafeArray1.getLong(1) == 2L) + val unsafeArray1 = unsafeRow.getArray(0) + testArrayInt(unsafeArray1, Seq(1, 2)) - val unsafeArray2 = unsafeRow.getArray(1).asInstanceOf[UnsafeArrayData] - assert(unsafeArray2.numElements() == 1) + val unsafeArray2 = unsafeRow.getArray(1) + assert(unsafeArray2.numElements == 1) - val nestedArray = unsafeArray2.getArray(0).asInstanceOf[UnsafeArrayData] - assert(nestedArray.getSizeInBytes == 4 * 2 + 8 * 2) - assert(nestedArray.numElements() == 2) - assert(nestedArray.getLong(0) == 3L) - assert(nestedArray.getLong(1) == 4L) + val nestedArray = unsafeArray2.getArray(0) + testArrayInt(nestedArray, Seq(3, 4)) assert(unsafeArray2.getSizeInBytes == 4 + 4 + 
nestedArray.getSizeInBytes) - val array1Size = roundedSize(4 + unsafeArray1.getSizeInBytes) - val array2Size = roundedSize(4 + unsafeArray2.getSizeInBytes) + val array1Size = roundedSize(unsafeArray1.getSizeInBytes) + val array2Size = roundedSize(unsafeArray2.getSizeInBytes) assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + array1Size + array2Size) } test("basic conversion with map type") { - def createArray(values: Any*): ArrayData = new GenericArrayData(values.toArray) + val fieldTypes: Array[DataType] = Array( + MapType(IntegerType, IntegerType), + MapType(IntegerType, MapType(IntegerType, IntegerType)) + ) + val converter = UnsafeProjection.create(fieldTypes) - def testIntLongMap(map: UnsafeMapData, keys: Array[Int], values: Array[Long]): Unit = { - val numElements = keys.length - assert(map.numElements() == numElements) + val map1 = createMap(1, 2)(3, 4) - val keyArray = map.keys - assert(keyArray.getSizeInBytes == 4 * numElements + 4 * numElements) - assert(keyArray.numElements() == numElements) - keys.zipWithIndex.foreach { case (key, i) => - assert(keyArray.getInt(i) == key) - } + val innerMap = createMap(5, 6)(7, 8) + val map2 = createMap(9)(innerMap) - val valueArray = map.values - assert(valueArray.getSizeInBytes == 4 * numElements + 8 * numElements) - assert(valueArray.numElements() == numElements) - values.zipWithIndex.foreach { case (value, i) => - assert(valueArray.getLong(i) == value) - } + val row = new GenericMutableRow(fieldTypes.length) + row.update(0, map1) + row.update(1, map2) + + val unsafeRow: UnsafeRow = converter.apply(row) + assert(unsafeRow.numFields == 2) + + val unsafeMap1 = unsafeRow.getMap(0) + testMapInt(unsafeMap1, Seq(1, 2), Seq(3, 4)) + + val unsafeMap2 = unsafeRow.getMap(1) + assert(unsafeMap2.numElements == 1) + + val keyArray = unsafeMap2.keyArray + testArrayInt(keyArray, Seq(9)) + + val valueArray = unsafeMap2.valueArray + + { + assert(valueArray.numElements == 1) + + val nestedMap = valueArray.getMap(0) + testMapInt(nestedMap, Seq(5, 6), Seq(7, 8)) + + assert(valueArray.getSizeInBytes == 4 + 4 + nestedMap.getSizeInBytes) + } + + assert(unsafeMap2.getSizeInBytes == 4 + keyArray.getSizeInBytes + valueArray.getSizeInBytes) - assert(map.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes) + val map1Size = roundedSize(unsafeMap1.getSizeInBytes) + val map2Size = roundedSize(unsafeMap2.getSizeInBytes) + assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + map1Size + map2Size) + } + + test("basic conversion with struct and array") { + val fieldTypes: Array[DataType] = Array( + new StructType().add("arr", ArrayType(IntegerType)), + ArrayType(new StructType().add("l", LongType)) + ) + val converter = UnsafeProjection.create(fieldTypes) + + val row = new GenericMutableRow(fieldTypes.length) + row.update(0, InternalRow(createArray(1))) + row.update(1, createArray(InternalRow(2L))) + + val unsafeRow: UnsafeRow = converter.apply(row) + assert(unsafeRow.numFields() == 2) + + val field1 = unsafeRow.getStruct(0, 1) + assert(field1.numFields == 1) + + val innerArray = field1.getArray(0) + testArrayInt(innerArray, Seq(1)) + + assert(field1.getSizeInBytes == 8 + 8 + roundedSize(innerArray.getSizeInBytes)) + + val field2 = unsafeRow.getArray(1) + assert(field2.numElements == 1) + + val innerStruct = field2.getStruct(0, 1) + + { + assert(innerStruct.numFields == 1) + assert(innerStruct.getSizeInBytes == 8 + 8) + assert(innerStruct.getLong(0) == 2L) } + assert(field2.getSizeInBytes == 4 + 4 + innerStruct.getSizeInBytes) + + assert(unsafeRow.getSizeInBytes == + 8 
+ 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes)) + } + + test("basic conversion with struct and map") { val fieldTypes: Array[DataType] = Array( - MapType(IntegerType, LongType), - MapType(IntegerType, MapType(IntegerType, LongType)) + new StructType().add("map", MapType(IntegerType, IntegerType)), + MapType(IntegerType, new StructType().add("l", LongType)) ) val converter = UnsafeProjection.create(fieldTypes) - val map1 = new ArrayBasedMapData(createArray(1, 2), createArray(3L, 4L)) + val row = new GenericMutableRow(fieldTypes.length) + row.update(0, InternalRow(createMap(1)(2))) + row.update(1, createMap(3)(InternalRow(4L))) + + val unsafeRow: UnsafeRow = converter.apply(row) + assert(unsafeRow.numFields() == 2) + + val field1 = unsafeRow.getStruct(0, 1) + assert(field1.numFields == 1) - val innerMap = new ArrayBasedMapData(createArray(5, 6), createArray(7L, 8L)) - val map2 = new ArrayBasedMapData(createArray(9), createArray(innerMap)) + val innerMap = field1.getMap(0) + testMapInt(innerMap, Seq(1), Seq(2)) + + assert(field1.getSizeInBytes == 8 + 8 + roundedSize(innerMap.getSizeInBytes)) + + val field2 = unsafeRow.getMap(1) + + val keyArray = field2.keyArray + testArrayInt(keyArray, Seq(3)) + + val valueArray = field2.valueArray + + { + assert(valueArray.numElements == 1) + + val innerStruct = valueArray.getStruct(0, 1) + assert(innerStruct.numFields == 1) + assert(innerStruct.getSizeInBytes == 8 + 8) + assert(innerStruct.getLong(0) == 4L) + + assert(valueArray.getSizeInBytes == 4 + 4 + innerStruct.getSizeInBytes) + } + + assert(field2.getSizeInBytes == 4 + keyArray.getSizeInBytes + valueArray.getSizeInBytes) + + assert(unsafeRow.getSizeInBytes == + 8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes)) + } + + test("basic conversion with array and map") { + val fieldTypes: Array[DataType] = Array( + ArrayType(MapType(IntegerType, IntegerType)), + MapType(IntegerType, ArrayType(IntegerType)) + ) + val converter = UnsafeProjection.create(fieldTypes) val row = new GenericMutableRow(fieldTypes.length) - row.update(0, map1) - row.update(1, map2) + row.update(0, createArray(createMap(1)(2))) + row.update(1, createMap(3)(createArray(4))) val unsafeRow: UnsafeRow = converter.apply(row) assert(unsafeRow.numFields() == 2) - val unsafeMap1 = unsafeRow.getMap(0).asInstanceOf[UnsafeMapData] - testIntLongMap(unsafeMap1, Array(1, 2), Array(3L, 4L)) + val field1 = unsafeRow.getArray(0) + assert(field1.numElements == 1) - val unsafeMap2 = unsafeRow.getMap(1).asInstanceOf[UnsafeMapData] - assert(unsafeMap2.numElements() == 1) + val innerMap = field1.getMap(0) + testMapInt(innerMap, Seq(1), Seq(2)) - val keyArray = unsafeMap2.keys - assert(keyArray.getSizeInBytes == 4 + 4) - assert(keyArray.numElements() == 1) - assert(keyArray.getInt(0) == 9) + assert(field1.getSizeInBytes == 4 + 4 + innerMap.getSizeInBytes) - val valueArray = unsafeMap2.values - assert(valueArray.numElements() == 1) - val nestedMap = valueArray.getMap(0).asInstanceOf[UnsafeMapData] - testIntLongMap(nestedMap, Array(5, 6), Array(7L, 8L)) - assert(valueArray.getSizeInBytes == 4 + 8 + nestedMap.getSizeInBytes) + val field2 = unsafeRow.getMap(1) + assert(field2.numElements == 1) - assert(unsafeMap2.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes) + val keyArray = field2.keyArray + testArrayInt(keyArray, Seq(3)) - val map1Size = roundedSize(8 + unsafeMap1.getSizeInBytes) - val map2Size = roundedSize(8 + unsafeMap2.getSizeInBytes) - assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + 
map1Size + map2Size) + val valueArray = field2.valueArray + + { + assert(valueArray.numElements == 1) + + val innerArray = valueArray.getArray(0) + testArrayInt(innerArray, Seq(4)) + + assert(valueArray.getSizeInBytes == 4 + (4 + innerArray.getSizeInBytes)) + } + + assert(field2.getSizeInBytes == 4 + keyArray.getSizeInBytes + valueArray.getSizeInBytes) + + assert(unsafeRow.getSizeInBytes == + 8 + 8 * 2 + roundedSize(field1.getSizeInBytes) + roundedSize(field2.getSizeInBytes)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala new file mode 100644 index 000000000000..0d329497758c --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import java.util.Random + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, MutableRow, BoundReference} +import org.apache.spark.sql.types.{DataType, IntegerType} + +import scala.collection.mutable +import org.scalatest.Assertions._ + +class HyperLogLogPlusPlusSuite extends SparkFunSuite { + + /** Create a HLL++ instance and an input and output buffer. */ + def createEstimator(rsd: Double, dt: DataType = IntegerType): + (HyperLogLogPlusPlus, MutableRow, MutableRow) = { + val input = new SpecificMutableRow(Seq(dt)) + val hll = new HyperLogLogPlusPlus(new BoundReference(0, dt, true), rsd) + val buffer = createBuffer(hll) + (hll, input, buffer) + } + + def createBuffer(hll: HyperLogLogPlusPlus): MutableRow = { + val buffer = new SpecificMutableRow(hll.aggBufferAttributes.map(_.dataType)) + hll.initialize(buffer) + buffer + } + + /** Evaluate the estimate. It should be within 3*SD's of the given true rsd. */ + def evaluateEstimate(hll: HyperLogLogPlusPlus, buffer: MutableRow, cardinality: Int): Unit = { + val estimate = hll.eval(buffer).asInstanceOf[Long].toDouble + val error = math.abs((estimate / cardinality.toDouble) - 1.0d) + assert(error < hll.trueRsd * 3.0d, "Error should be within 3 std. 
errors.") + } + + test("add nulls") { + val (hll, input, buffer) = createEstimator(0.05) + input.setNullAt(0) + hll.update(buffer, input) + hll.update(buffer, input) + val estimate = hll.eval(buffer).asInstanceOf[Long] + assert(estimate == 0L, "Nothing meaningful added; estimate should be 0.") + } + + def testCardinalityEstimates( + rsds: Seq[Double], + ns: Seq[Int], + f: Int => Int, + c: Int => Int): Unit = { + rsds.flatMap(rsd => ns.map(n => (rsd, n))).foreach { + case (rsd, n) => + val (hll, input, buffer) = createEstimator(rsd) + var i = 0 + while (i < n) { + input.setInt(0, f(i)) + hll.update(buffer, input) + i += 1 + } + val estimate = hll.eval(buffer).asInstanceOf[Long].toDouble + val cardinality = c(n) + val error = math.abs((estimate / cardinality.toDouble) - 1.0d) + assert(error < hll.trueRsd * 3.0d, "Error should be within 3 std. errors.") + } + } + + test("deterministic cardinality estimation") { + val repeats = 10 + testCardinalityEstimates( + Seq(0.1, 0.05, 0.025, 0.01), + Seq(100, 500, 1000, 5000, 10000, 50000, 100000, 500000, 1000000).map(_ * repeats), + i => i / repeats, + i => i / repeats) + } + + test("random cardinality estimation") { + val srng = new Random(323981238L) + val seen = mutable.HashSet.empty[Int] + val update = (i: Int) => { + val value = srng.nextInt() + seen += value + value + } + val eval = (n: Int) => { + val cardinality = seen.size + seen.clear() + cardinality + } + testCardinalityEstimates( + Seq(0.05, 0.01), + Seq(100, 10000, 500000), + update, + eval) + } + + // Test merging + test("merging HLL instances") { + val (hll, input, buffer1a) = createEstimator(0.05) + val buffer1b = createBuffer(hll) + val buffer2 = createBuffer(hll) + + // Create the + // Add the lower half + var i = 0 + while (i < 500000) { + input.setInt(0, i) + hll.update(buffer1a, input) + i += 1 + } + + // Add the upper half + i = 500000 + while (i < 1000000) { + input.setInt(0, i) + hll.update(buffer1b, input) + i += 1 + } + + // Merge the lower and upper halfs. + hll.merge(buffer1a, buffer1b) + + // Create the other buffer in reverse + i = 999999 + while (i >= 0) { + input.setInt(0, i) + hll.update(buffer2, input) + i -= 1 + } + + // Check if the buffers are equal. 
+ assert(buffer2 == buffer1a, "Buffers should be equal") + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala index 46daa3eb8bf8..9da1068e9ca1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala @@ -29,78 +29,68 @@ class CodeFormatterSuite extends SparkFunSuite { } testCase("basic example") { - """ - |class A { + """class A { |blahblah; - |} - """.stripMargin + |}""".stripMargin }{ """ - |class A { - | blahblah; - |} + |/* 001 */ class A { + |/* 002 */ blahblah; + |/* 003 */ } """.stripMargin } testCase("nested example") { - """ - |class A { + """class A { | if (c) { |duh; |} - |} - """.stripMargin + |}""".stripMargin } { """ - |class A { - | if (c) { - | duh; - | } - |} + |/* 001 */ class A { + |/* 002 */ if (c) { + |/* 003 */ duh; + |/* 004 */ } + |/* 005 */ } """.stripMargin } testCase("single line") { - """ - |class A { + """class A { | if (c) {duh;} - |} - """.stripMargin + |}""".stripMargin }{ """ - |class A { - | if (c) {duh;} - |} + |/* 001 */ class A { + |/* 002 */ if (c) {duh;} + |/* 003 */ } """.stripMargin } testCase("if else on the same line") { - """ - |class A { + """class A { | if (c) {duh;} else {boo;} - |} - """.stripMargin + |}""".stripMargin }{ """ - |class A { - | if (c) {duh;} else {boo;} - |} + |/* 001 */ class A { + |/* 002 */ if (c) {duh;} else {boo;} + |/* 003 */ } """.stripMargin } testCase("function calls") { - """ - |foo( + """foo( |a, |b, - |c) - """.stripMargin + |c)""".stripMargin }{ """ - |foo( - | a, - | b, - | c) + |/* 001 */ foo( + |/* 002 */ a, + |/* 003 */ b, + |/* 004 */ c) """.stripMargin } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala index 2d3f98dbbd3d..c9616cdb26c2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala @@ -34,12 +34,6 @@ class CodegenExpressionCachingSuite extends SparkFunSuite { assert(instance.apply(null).getBoolean(0) === false) } - test("GenerateProjection should initialize expressions") { - val expr = And(NondeterministicExpression(), NondeterministicExpression()) - val instance = GenerateProjection.generate(Seq(expr)) - assert(instance.apply(null).getBoolean(0) === false) - } - test("GenerateMutableProjection should initialize expressions") { val expr = And(NondeterministicExpression(), NondeterministicExpression()) val instance = GenerateMutableProjection.generate(Seq(expr))() @@ -64,18 +58,6 @@ class CodegenExpressionCachingSuite extends SparkFunSuite { assert(instance2.apply(null).getBoolean(0) === true) } - test("GenerateProjection should not share expression instances") { - val expr1 = MutableExpression() - val instance1 = GenerateProjection.generate(Seq(expr1)) - assert(instance1.apply(null).getBoolean(0) === false) - - val expr2 = MutableExpression() - expr2.mutableState = true - val instance2 = GenerateProjection.generate(Seq(expr2)) - assert(instance1.apply(null).getBoolean(0) === false) - 
assert(instance2.apply(null).getBoolean(0) === true) - } - test("GenerateMutableProjection should not share expression instances") { val expr1 = MutableExpression() val instance1 = GenerateMutableProjection.generate(Seq(expr1))() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala index aff1bee99faa..796d60032e1a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala @@ -22,7 +22,7 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.Platform /** * A test suite for the bitset portion of the row concatenation. @@ -96,7 +96,7 @@ class GenerateUnsafeRowJoinerBitsetSuite extends SparkFunSuite { // This way we can test the joiner when the input UnsafeRows are not the entire arrays. val offset = numFields * 8 val buf = new Array[Byte](sizeInBytes + offset) - row.pointTo(buf, PlatformDependent.BYTE_ARRAY_OFFSET + offset, numFields, sizeInBytes) + row.pointTo(buf, Platform.BYTE_ARRAY_OFFSET + offset, numFields, sizeInBytes) row } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala new file mode 100644 index 000000000000..1522ee34e43a --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * A test suite for generated projections + */ +class GeneratedProjectionSuite extends SparkFunSuite { + + test("generated projections on wider table") { + val N = 1000 + val wideRow1 = new GenericInternalRow((1 to N).toArray[Any]) + val schema1 = StructType((1 to N).map(i => StructField("", IntegerType))) + val wideRow2 = new GenericInternalRow( + (1 to N).map(i => UTF8String.fromString(i.toString)).toArray[Any]) + val schema2 = StructType((1 to N).map(i => StructField("", StringType))) + val joined = new JoinedRow(wideRow1, wideRow2) + val joinedSchema = StructType(schema1 ++ schema2) + val nested = new JoinedRow(InternalRow(joined, joined), joined) + val nestedSchema = StructType( + Seq(StructField("", joinedSchema), StructField("", joinedSchema)) ++ joinedSchema) + + // test generated UnsafeProjection + val unsafeProj = UnsafeProjection.create(nestedSchema) + val unsafe: UnsafeRow = unsafeProj(nested) + (0 until N).foreach { i => + val s = UTF8String.fromString((i + 1).toString) + assert(i + 1 === unsafe.getInt(i + 2)) + assert(s === unsafe.getUTF8String(i + 2 + N)) + assert(i + 1 === unsafe.getStruct(0, N * 2).getInt(i)) + assert(s === unsafe.getStruct(0, N * 2).getUTF8String(i + N)) + assert(i + 1 === unsafe.getStruct(1, N * 2).getInt(i)) + assert(s === unsafe.getStruct(1, N * 2).getUTF8String(i + N)) + } + + // test generated SafeProjection + val safeProj = FromUnsafeProjection(nestedSchema) + val result = safeProj(unsafe) + // Can't compare GenericInternalRow with JoinedRow directly + (0 until N).foreach { i => + val r = i + 1 + val s = UTF8String.fromString((i + 1).toString) + assert(r === result.getInt(i + 2)) + assert(s === result.getUTF8String(i + 2 + N)) + assert(r === result.getStruct(0, N * 2).getInt(i)) + assert(s === result.getStruct(0, N * 2).getUTF8String(i + N)) + assert(r === result.getStruct(1, N * 2).getInt(i)) + assert(s === result.getStruct(1, N * 2).getUTF8String(i + N)) + } + + // test generated MutableProjection + val exprs = nestedSchema.fields.zipWithIndex.map { case (f, i) => + BoundReference(i, f.dataType, true) + } + val mutableProj = GenerateMutableProjection.generate(exprs)() + val row1 = mutableProj(result) + assert(result === row1) + val row2 = mutableProj(result) + assert(result === row2) + } + + test("generated unsafe projection with array of binary") { + val row = InternalRow( + Array[Byte](1, 2), + new GenericArrayData(Array(Array[Byte](1, 2), null, Array[Byte](3, 4)))) + val fields = (BinaryType :: ArrayType(BinaryType) :: Nil).toArray[DataType] + + val unsafeProj = UnsafeProjection.create(fields) + val unsafeRow: UnsafeRow = unsafeProj(row) + assert(java.util.Arrays.equals(unsafeRow.getBinary(0), Array[Byte](1, 2))) + assert(java.util.Arrays.equals(unsafeRow.getArray(1).getBinary(0), Array[Byte](1, 2))) + assert(unsafeRow.getArray(1).isNullAt(1)) + assert(unsafeRow.getArray(1).getBinary(1) === null) + assert(java.util.Arrays.equals(unsafeRow.getArray(1).getBinary(2), Array[Byte](3, 4))) + + val safeProj = FromUnsafeProjection(fields) + val row2 = safeProj(unsafeRow) + assert(row2 === row) + } + + test("padding bytes should be zeroed out") { + val types = Seq(BooleanType, ByteType, ShortType, 
IntegerType, FloatType, BinaryType, + StringType) + val struct = StructType(types.map(StructField("", _, true))) + val fields = Array[DataType](StringType, struct) + val unsafeProj = UnsafeProjection.create(fields) + + val innerRow = InternalRow(false, 1.toByte, 2.toShort, 3, 4.0f, "".getBytes, + UTF8String.fromString("")) + val row1 = InternalRow(UTF8String.fromString(""), innerRow) + val unsafe1 = unsafeProj(row1).copy() + // create a Row with long String before the inner struct + val row2 = InternalRow(UTF8String.fromString("a_long_string").repeat(10), innerRow) + val unsafe2 = unsafeProj(row2).copy() + assert(unsafe1.getStruct(1, 7) === unsafe2.getStruct(1, 7)) + val unsafe3 = unsafeProj(row1).copy() + assert(unsafe1 === unsafe3) + assert(unsafe1.getStruct(1, 7) === unsafe3.getStruct(1, 7)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index d4916ea8d273..cde346e99eb1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.analysis.{AnalysisSuite, EliminateSubQueries} +import org.apache.spark.sql.catalyst.SimpleCatalystConf +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.PlanTest @@ -88,20 +89,40 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { ('a === 'b || 'b > 3 && 'a > 3 && 'a < 5)) } - private def caseInsensitiveAnalyse(plan: LogicalPlan) = - AnalysisSuite.caseInsensitiveAnalyzer.execute(plan) + test("a && (!a || b)") { + checkCondition(('a && (!('a) || 'b )), ('a && 'b)) + + checkCondition(('a && ('b || !('a) )), ('a && 'b)) + + checkCondition(((!('a) || 'b ) && 'a), ('b && 'a)) + + checkCondition((('b || !('a) ) && 'a), ('b && 'a)) + } + + test("!(a && b) , !(a || b)") { + checkCondition((!('a && 'b)), (!('a) || !('b))) + + checkCondition(!('a || 'b), (!('a) && !('b))) + } + + private val caseInsensitiveAnalyzer = + new Analyzer(EmptyCatalog, EmptyFunctionRegistry, new SimpleCatalystConf(false)) test("(a && b) || (a && c) => a && (b || c) when case insensitive") { - val plan = caseInsensitiveAnalyse(testRelation.where(('a > 2 && 'b > 3) || ('A > 2 && 'b < 5))) + val plan = caseInsensitiveAnalyzer.execute( + testRelation.where(('a > 2 && 'b > 3) || ('A > 2 && 'b < 5))) val actual = Optimize.execute(plan) - val expected = caseInsensitiveAnalyse(testRelation.where('a > 2 && ('b > 3 || 'b < 5))) + val expected = caseInsensitiveAnalyzer.execute( + testRelation.where('a > 2 && ('b > 3 || 'b < 5))) comparePlans(actual, expected) } test("(a || b) && (a || c) => a || (b && c) when case insensitive") { - val plan = caseInsensitiveAnalyse(testRelation.where(('a > 2 || 'b > 3) && ('A > 2 || 'b < 5))) + val plan = caseInsensitiveAnalyzer.execute( + testRelation.where(('a > 2 || 'b > 3) && ('A > 2 || 'b < 5))) val actual = Optimize.execute(plan) - val expected = caseInsensitiveAnalyse(testRelation.where('a > 2 || ('b > 3 && 'b < 5))) + val expected = caseInsensitiveAnalyzer.execute( + testRelation.where('a > 2 || ('b > 3 && 'b < 5))) comparePlans(actual, expected) } } diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala new file mode 100644 index 000000000000..9bf61ae09178 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions.Explode +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.types.StringType + +class ColumnPruningSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Batch("Column pruning", FixedPoint(100), + ColumnPruning) :: Nil + } + + test("Column pruning for Generate when Generate.join = false") { + val input = LocalRelation('a.int, 'b.array(StringType)) + + val query = input.generate(Explode('b), join = false).analyze + + val optimized = Optimize.execute(query) + + val correctAnswer = input.select('b).generate(Explode('b), join = false).analyze + + comparePlans(optimized, correctAnswer) + } + + test("Column pruning for Generate when Generate.join = true") { + val input = LocalRelation('a.int, 'b.int, 'c.array(StringType)) + + val query = + input + .generate(Explode('c), join = true, outputNames = "explode" :: Nil) + .select('a, 'explode) + .analyze + + val optimized = Optimize.execute(query) + + val correctAnswer = + input + .select('a, 'c) + .generate(Explode('c), join = true, outputNames = "explode" :: Nil) + .select('a, 'explode) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("Turn Generate.join to false if possible") { + val input = LocalRelation('b.array(StringType)) + + val query = + input + .generate(Explode('b), join = true, outputNames = "explode" :: Nil) + .select(('explode + 1).as("result")) + .analyze + + val optimized = Optimize.execute(query) + + val correctAnswer = + input + .generate(Explode('b), join = false, outputNames = "explode" :: Nil) + .select(('explode + 1).as("result")) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("Column pruning for Project on Sort") { + val input = LocalRelation('a.int, 'b.string, 'c.double) + + val query = input.orderBy('b.asc).select('a).analyze + val optimized = Optimize.execute(query) + + val correctAnswer = input.select('a, 'b).orderBy('b.asc).select('a).analyze + + comparePlans(optimized, correctAnswer) + } + + // todo: add more 
tests for column pruning +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index ec3b2f1edfa0..8aaefa84937c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -162,7 +162,7 @@ class ConstantFoldingSuite extends PlanTest { testRelation .select( Rand(5L) + Literal(1) as Symbol("c1"), - Sum('a) as Symbol("c2")) + sum('a) as Symbol("c2")) val optimized = Optimize.execute(originalQuery.analyze) @@ -170,7 +170,7 @@ class ConstantFoldingSuite extends PlanTest { testRelation .select( Rand(5L) + Literal(1.0) as Symbol("c1"), - Sum('a) as Symbol("c2")) + sum('a) as Symbol("c2")) .analyze comparePlans(optimized, correctAnswer) @@ -250,29 +250,14 @@ class ConstantFoldingSuite extends PlanTest { } test("Constant folding test: Fold In(v, list) into true or false") { - var originalQuery = + val originalQuery = testRelation .select('a) .where(In(Literal(1), Seq(Literal(1), Literal(2)))) - var optimized = Optimize.execute(originalQuery.analyze) - - var correctAnswer = - testRelation - .select('a) - .where(Literal(true)) - .analyze - - comparePlans(optimized, correctAnswer) - - originalQuery = - testRelation - .select('a) - .where(In(Literal(1), Seq(Literal(1), 'a.attr))) - - optimized = Optimize.execute(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) - correctAnswer = + val correctAnswer = testRelation .select('a) .where(Literal(true)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 0f1fde2fb0f6..fba4c5ca77d6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -40,6 +40,7 @@ class FilterPushdownSuite extends PlanTest { BooleanSimplification, PushPredicateThroughJoin, PushPredicateThroughGenerate, + PushPredicateThroughAggregate, ColumnPruning, ProjectCollapsing) :: Nil } @@ -67,7 +68,7 @@ class FilterPushdownSuite extends PlanTest { test("column pruning for group") { val originalQuery = testRelation - .groupBy('a)('a, Count('b)) + .groupBy('a)('a, count('b)) .select('a) val optimized = Optimize.execute(originalQuery.analyze) @@ -83,7 +84,7 @@ class FilterPushdownSuite extends PlanTest { test("column pruning for group with alias") { val originalQuery = testRelation - .groupBy('a)('a as 'c, Count('b)) + .groupBy('a)('a as 'c, count('b)) .select('c) val optimized = Optimize.execute(originalQuery.analyze) @@ -652,4 +653,101 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer.analyze) } + + test("aggregate: push down filter when filter on group by expression") { + val originalQuery = testRelation + .groupBy('a)('a, count('b) as 'c) + .select('a, 'c) + .where('a === 2) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = testRelation + .where('a === 2) + .groupBy('a)('a, count('b) as 'c) + .analyze + comparePlans(optimized, correctAnswer) + } + + test("aggregate: don't push down filter when filter not on group by expression") { + val originalQuery = testRelation + .select('a, 'b) + 
.groupBy('a)('a, count('b) as 'c) + .where('c === 2L) + + val optimized = Optimize.execute(originalQuery.analyze) + + comparePlans(optimized, originalQuery.analyze) + } + + test("aggregate: push down filters partially which are subset of group by expressions") { + val originalQuery = testRelation + .select('a, 'b) + .groupBy('a)('a, count('b) as 'c) + .where('c === 2L && 'a === 3) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = testRelation + .select('a, 'b) + .where('a === 3) + .groupBy('a)('a, count('b) as 'c) + .where('c === 2L) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("aggregate: push down filters with alias") { + val originalQuery = testRelation + .select('a, 'b) + .groupBy('a)(('a + 1) as 'aa, count('b) as 'c) + .where(('c === 2L || 'aa > 4) && 'aa < 3) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = testRelation + .select('a, 'b) + .where('a + 1 < 3) + .groupBy('a)(('a + 1) as 'aa, count('b) as 'c) + .where('c === 2L || 'aa > 4) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("aggregate: push down filters with literal") { + val originalQuery = testRelation + .select('a, 'b) + .groupBy('a)('a, count('b) as 'c, "s" as 'd) + .where('c === 2L && 'd === "s") + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = testRelation + .select('a, 'b) + .where("s" === "s") + .groupBy('a)('a, count('b) as 'c, "s" as 'd) + .where('c === 2L) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("aggregate: don't push down filters that are nondeterministic") { + val originalQuery = testRelation + .select('a, 'b) + .groupBy('a)('a + Rand(10) as 'aa, count('b) as 'c, Rand(11).as("rnd")) + .where('c === 2L && 'aa + Rand(10).as("rnd") === 3 && 'rnd === 5) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = testRelation + .select('a, 'b) + .groupBy('a)('a + Rand(10) as 'aa, count('b) as 'c, Rand(11).as("rnd")) + .where('c === 2L && 'aa + Rand(10).as("rnd") === 3 && 'rnd === 5) + .analyze + + comparePlans(optimized, correctAnswer) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala new file mode 100644 index 000000000000..9b1e16c72764 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + + +class JoinOrderSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", Once, + EliminateSubQueries) :: + Batch("Filter Pushdown", Once, + CombineFilters, + PushPredicateThroughProject, + BooleanSimplification, + ReorderJoin, + PushPredicateThroughJoin, + PushPredicateThroughGenerate, + PushPredicateThroughAggregate, + ColumnPruning, + ProjectCollapsing) :: Nil + + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val testRelation1 = LocalRelation('d.int) + + test("extract filters and joins") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + val z = testRelation.subquery('z) + + def testExtract(plan: LogicalPlan, expected: Option[(Seq[LogicalPlan], Seq[Expression])]) { + assert(ExtractFiltersAndInnerJoins.unapply(plan) === expected) + } + + testExtract(x, None) + testExtract(x.where("x.b".attr === 1), None) + testExtract(x.join(y), Some(Seq(x, y), Seq())) + testExtract(x.join(y, condition = Some("x.b".attr === "y.d".attr)), + Some(Seq(x, y), Seq("x.b".attr === "y.d".attr))) + testExtract(x.join(y).where("x.b".attr === "y.d".attr), + Some(Seq(x, y), Seq("x.b".attr === "y.d".attr))) + testExtract(x.join(y).join(z), Some(Seq(x, y, z), Seq())) + testExtract(x.join(y).where("x.b".attr === "y.d".attr).join(z), + Some(Seq(x, y, z), Seq("x.b".attr === "y.d".attr))) + testExtract(x.join(y).join(x.join(z)), Some(Seq(x, y, x.join(z)), Seq())) + testExtract(x.join(y).join(x.join(z)).where("x.b".attr === "y.d".attr), + Some(Seq(x, y, x.join(z)), Seq("x.b".attr === "y.d".attr))) + } + + test("reorder inner joins") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + val z = testRelation.subquery('z) + + val originalQuery = { + x.join(y).join(z) + .where(("x.b".attr === "z.b".attr) && ("y.d".attr === "z.a".attr)) + } + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + x.join(z, condition = Some("x.b".attr === "z.b".attr)) + .join(y, condition = Some("y.d".attr === "z.a".attr)) + .analyze + + comparePlans(optimized, analysis.EliminateSubQueries(correctAnswer)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 1d433275fed2..48cab01ac100 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -35,7 +35,8 @@ class OptimizeInSuite extends PlanTest { val batches = Batch("AnalysisNodes", Once, EliminateSubQueries) :: - Batch("ConstantFolding", Once, + Batch("ConstantFolding", FixedPoint(10), + NullPropagation, ConstantFolding, BooleanSimplification, OptimizeIn) :: Nil @@ -43,16 +44,26 @@ class OptimizeInSuite extends PlanTest { val testRelation = 
LocalRelation('a.int, 'b.int, 'c.int) - test("OptimizedIn test: In clause optimized to InSet") { + test("OptimizedIn test: In clause not optimized to InSet when less than 10 items") { val originalQuery = testRelation .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2)))) .analyze + val optimized = Optimize.execute(originalQuery.analyze) + comparePlans(optimized, originalQuery) + } + + test("OptimizedIn test: In clause optimized to InSet when more than 10 items") { + val originalQuery = + testRelation + .where(In(UnresolvedAttribute("a"), (1 to 11).map(Literal(_)))) + .analyze + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where(InSet(UnresolvedAttribute("a"), HashSet[Any]() + 1 + 2)) + .where(InSet(UnresolvedAttribute("a"), (1 to 11).toSet)) .analyze comparePlans(optimized, correctAnswer) @@ -72,4 +83,52 @@ class OptimizeInSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("OptimizedIn test: NULL IN (expr1, ..., exprN) gets transformed to Filter(null)") { + val originalQuery = + testRelation + .where(In(Literal.create(null, NullType), Seq(Literal(1), Literal(2)))) + .analyze + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .where(Literal.create(null, BooleanType)) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("OptimizedIn test: Inset optimization disabled as " + + "list expression contains attribute)") { + val originalQuery = + testRelation + .where(In(Literal.create(null, StringType), Seq(Literal(1), UnresolvedAttribute("b")))) + .analyze + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .where(Literal.create(null, BooleanType)) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("OptimizedIn test: Inset optimization disabled as " + + "list expression contains attribute - select)") { + val originalQuery = + testRelation + .select(In(Literal.create(null, StringType), + Seq(Literal(1), UnresolvedAttribute("b"))).as("a")).analyze + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .select(Literal.create(null, BooleanType).as("a")) + .analyze + + comparePlans(optimized, correctAnswer) + } + } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala index 49c979bc7d72..1595ad932742 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala @@ -30,7 +30,8 @@ class SetOperationPushDownSuite extends PlanTest { Batch("Subqueries", Once, EliminateSubQueries) :: Batch("Union Pushdown", Once, - SetOperationPushDown) :: Nil + SetOperationPushDown, + SimplifyFilters) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) @@ -60,23 +61,22 @@ class SetOperationPushDownSuite extends PlanTest { comparePlans(exceptOptimized, exceptCorrectAnswer) } - test("union/intersect/except: project to each side") { + test("union: project to each side") { val unionQuery = testUnion.select('a) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Union(testRelation.select('a), testRelation2.select('d)).analyze + comparePlans(unionOptimized, unionCorrectAnswer) + } + + test("SPARK-10539: Project should not 
be pushed down through Intersect or Except") { val intersectQuery = testIntersect.select('b, 'c) val exceptQuery = testExcept.select('a, 'b, 'c) - val unionOptimized = Optimize.execute(unionQuery.analyze) val intersectOptimized = Optimize.execute(intersectQuery.analyze) val exceptOptimized = Optimize.execute(exceptQuery.analyze) - val unionCorrectAnswer = - Union(testRelation.select('a), testRelation2.select('d)).analyze - val intersectCorrectAnswer = - Intersect(testRelation.select('b, 'c), testRelation2.select('e, 'f)).analyze - val exceptCorrectAnswer = - Except(testRelation.select('a, 'b, 'c), testRelation2.select('d, 'e, 'f)).analyze - - comparePlans(unionOptimized, unionCorrectAnswer) - comparePlans(intersectOptimized, intersectCorrectAnswer) - comparePlans(exceptOptimized, exceptCorrectAnswer) } + comparePlans(intersectOptimized, intersectQuery.analyze) + comparePlans(exceptOptimized, exceptQuery.analyze) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala new file mode 100644 index 000000000000..455a3810c719 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util._ + +/** + * This suite is used to test [[LogicalPlan]]'s `resolveOperators` and make sure it can correctly + * skips sub-trees that have already been marked as analyzed. 
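As a side note on the SPARK-10539 test above: projecting before a set operation is not equivalent to projecting after it, because Intersect/Except compare whole rows. A minimal, hypothetical illustration with plain Scala sets (not Catalyst plans):

// Rows are (a, b); EXCEPT behaves like set difference over whole rows.
val left  = Set((1, "a"), (1, "b"))
val right = Set((1, "b"))

// SELECT a FROM (left EXCEPT right): subtract first, then project.
val correct = (left diff right).map(_._1)               // Set(1)

// If the projection were pushed below the EXCEPT, both sides collapse to Set(1)
// and the difference becomes empty -- a different answer.
val pushedDown = left.map(_._1) diff right.map(_._1)    // Set()

assert(correct != pushedDown)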
+ */ +class LogicalPlanSuite extends SparkFunSuite { + private var invocationCount = 0 + private val function: PartialFunction[LogicalPlan, LogicalPlan] = { + case p: Project => + invocationCount += 1 + p + } + + private val testRelation = LocalRelation() + + test("resolveOperator runs on operators") { + invocationCount = 0 + val plan = Project(Nil, testRelation) + plan resolveOperators function + + assert(invocationCount === 1) + } + + test("resolveOperator runs on operators recursively") { + invocationCount = 0 + val plan = Project(Nil, Project(Nil, testRelation)) + plan resolveOperators function + + assert(invocationCount === 2) + } + + test("resolveOperator skips all ready resolved plans") { + invocationCount = 0 + val plan = Project(Nil, Project(Nil, testRelation)) + plan.foreach(_.setAnalyzed()) + plan resolveOperators function + + assert(invocationCount === 0) + } + + test("resolveOperator skips partially resolved plans") { + invocationCount = 0 + val plan1 = Project(Nil, testRelation) + val plan2 = Project(Nil, plan1) + plan1.foreach(_.setAnalyzed()) + plan2 resolveOperators function + + assert(invocationCount === 1) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 765c1e2dda99..2efee1fc5470 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -25,8 +25,7 @@ import org.apache.spark.sql.catalyst.util._ /** * Provides helper methods for comparing plans. */ -class PlanTest extends SparkFunSuite { - +abstract class PlanTest extends SparkFunSuite { /** * Since attribute references are given globally unique ids during analysis, * we must normalize them to check if two different queries are identical. 
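The LogicalPlanSuite hunk above depends on plans carrying an analyzed flag so that resolveOperators can skip subtrees that were already processed. A rough, self-contained sketch of that skip-if-analyzed pattern on a toy tree (assumed names, not Catalyst's TreeNode machinery):

// REPL-pasteable toy plan node with an `analyzed` flag.
case class Node(name: String, children: Seq[Node] = Nil, analyzed: Boolean = false) {
  // Apply `rule` bottom-up, but leave any subtree already marked analyzed untouched.
  def resolveOperators(rule: PartialFunction[Node, Node]): Node =
    if (analyzed) this
    else {
      val withNewChildren = copy(children = children.map(_.resolveOperators(rule)))
      rule.applyOrElse(withNewChildren, (n: Node) => n)
    }
}

var invocations = 0
val countingRule: PartialFunction[Node, Node] = { case n => invocations += 1; n }

Node("Project", Seq(Node("Project", Seq(Node("Relation"))))).resolveOperators(countingRule)
assert(invocations == 3)                                 // every node visited

invocations = 0
Node("Project", Seq(Node("Project", analyzed = true))).resolveOperators(countingRule)
assert(invocations == 1)                                 // the analyzed subtree is skipped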
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 8fff39906b34..965bdb1515e5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -38,6 +38,13 @@ case class ComplexPlan(exprs: Seq[Seq[Expression]]) override def output: Seq[Attribute] = Nil } +case class ExpressionInMap(map: Map[String, Expression]) extends Expression with Unevaluable { + override def children: Seq[Expression] = map.values.toSeq + override def nullable: Boolean = true + override def dataType: NullType = NullType + override lazy val resolved = true +} + class TreeNodeSuite extends SparkFunSuite { test("top node changed") { val after = Literal(1) transform { case Literal(1, _) => Literal(2) } @@ -236,4 +243,22 @@ class TreeNodeSuite extends SparkFunSuite { val expected = ComplexPlan(Seq(Seq(Literal("1")), Seq(Literal("2")))) assert(expected === actual) } + + test("expressions inside a map") { + val expression = ExpressionInMap(Map("1" -> Literal(1), "2" -> Literal(2))) + + { + val actual = expression.transform { + case Literal(i: Int, _) => Literal(i + 1) + } + val expected = ExpressionInMap(Map("1" -> Literal(2), "2" -> Literal(3))) + assert(actual === expected) + } + + { + val actual = expression.withNewChildren(Seq(Literal(2), Literal(3))) + val expected = ExpressionInMap(Map("1" -> Literal(2), "2" -> Literal(3))) + assert(actual === expected) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DataTypeParserSuite.scala similarity index 93% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DataTypeParserSuite.scala index 1ba290753ce4..bebf70896547 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DataTypeParserSuite.scala @@ -15,9 +15,10 @@ * limitations under the License. 
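The TreeNodeSuite addition above checks that transformations reach expressions held inside a Map-valued field. A small hedged sketch of the same idea on a toy expression type (illustrative only, not Catalyst's transform):

sealed trait Expr
case class Lit(i: Int) extends Expr
case class InMap(map: Map[String, Expr]) extends Expr

// Rewrite every expression reachable through the map's values.
def transform(e: Expr)(rule: PartialFunction[Expr, Expr]): Expr = e match {
  case InMap(m) => InMap(m.map { case (k, v) => k -> transform(v)(rule) })
  case other    => rule.applyOrElse(other, (x: Expr) => x)
}

val expr = InMap(Map("1" -> Lit(1), "2" -> Lit(2)))
assert(transform(expr) { case Lit(i) => Lit(i + 1) } ==
  InMap(Map("1" -> Lit(2), "2" -> Lit(3))))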
*/ -package org.apache.spark.sql.types +package org.apache.spark.sql.catalyst.util import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ class DataTypeParserSuite extends SparkFunSuite { @@ -48,7 +49,9 @@ class DataTypeParserSuite extends SparkFunSuite { checkDataType("DATE", DateType) checkDataType("timestamp", TimestampType) checkDataType("string", StringType) + checkDataType("ChaR(5)", StringType) checkDataType("varchAr(20)", StringType) + checkDataType("cHaR(27)", StringType) checkDataType("BINARY", BinaryType) checkDataType("array", ArrayType(DoubleType, true)) @@ -82,7 +85,8 @@ class DataTypeParserSuite extends SparkFunSuite { |struct< | struct:struct, | MAP:Map, - | arrAy:Array> + | arrAy:Array, + | anotherArray:Array> """.stripMargin, StructType( StructField("struct", @@ -90,7 +94,8 @@ class DataTypeParserSuite extends SparkFunSuite { StructField("deciMal", DecimalType.USER_DEFAULT, true) :: StructField("anotherDecimal", DecimalType(5, 2), true) :: Nil), true) :: StructField("MAP", MapType(TimestampType, StringType), true) :: - StructField("arrAy", ArrayType(DoubleType, true), true) :: Nil) + StructField("arrAy", ArrayType(DoubleType, true), true) :: + StructField("anotherArray", ArrayType(StringType, true), true) :: Nil) ) // A column name can be a reserved word in our DDL parser and SqlParser. checkDataType( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index d18fa4df1335..0ce5a2fb6950 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -49,13 +49,17 @@ class DateTimeUtilsSuite extends SparkFunSuite { test("us and julian day") { val (d, ns) = toJulianDay(0) assert(d === JULIAN_DAY_OF_EPOCH) - assert(ns === SECONDS_PER_DAY / 2 * NANOS_PER_SECOND) + assert(ns === 0) assert(fromJulianDay(d, ns) == 0L) - val t = new Timestamp(61394778610000L) // (2015, 6, 11, 10, 10, 10, 100) - val (d1, ns1) = toJulianDay(fromJavaTimestamp(t)) - val t2 = toJavaTimestamp(fromJulianDay(d1, ns1)) - assert(t.equals(t2)) + Seq(Timestamp.valueOf("2015-06-11 10:10:10.100"), + Timestamp.valueOf("2015-06-11 20:10:10.100"), + Timestamp.valueOf("1900-06-11 20:10:10.100")).foreach { t => + val (d, ns) = toJulianDay(fromJavaTimestamp(t)) + assert(ns > 0) + val t1 = toJavaTimestamp(fromJulianDay(d, ns)) + assert(t.equals(t1)) + } } test("SPARK-6785: java date conversion before and after epoch") { @@ -106,6 +110,10 @@ class DateTimeUtilsSuite extends SparkFunSuite { c.set(Calendar.MILLISECOND, 0) assert(stringToDate(UTF8String.fromString("2015")).get === millisToDays(c.getTimeInMillis)) + c.set(1, 0, 1, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + assert(stringToDate(UTF8String.fromString("0001")).get === + millisToDays(c.getTimeInMillis)) c = Calendar.getInstance() c.set(2015, 2, 1, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) @@ -130,6 +138,42 @@ class DateTimeUtilsSuite extends SparkFunSuite { assert(stringToDate(UTF8String.fromString("2015.03.18")).isEmpty) assert(stringToDate(UTF8String.fromString("20150318")).isEmpty) assert(stringToDate(UTF8String.fromString("2015-031-8")).isEmpty) + assert(stringToDate(UTF8String.fromString("02015-03-18")).isEmpty) + assert(stringToDate(UTF8String.fromString("015-03-18")).isEmpty) + assert(stringToDate(UTF8String.fromString("015")).isEmpty) + 
assert(stringToDate(UTF8String.fromString("02015")).isEmpty) + } + + test("string to time") { + // Tests with UTC. + val c = Calendar.getInstance(TimeZone.getTimeZone("UTC")) + c.set(Calendar.MILLISECOND, 0) + + c.set(1900, 0, 1, 0, 0, 0) + assert(stringToTime("1900-01-01T00:00:00GMT-00:00") === c.getTime()) + + c.set(2000, 11, 30, 10, 0, 0) + assert(stringToTime("2000-12-30T10:00:00Z") === c.getTime()) + + // Tests with set time zone. + c.setTimeZone(TimeZone.getTimeZone("GMT-04:00")) + c.set(Calendar.MILLISECOND, 0) + + c.set(1900, 0, 1, 0, 0, 0) + assert(stringToTime("1900-01-01T00:00:00-04:00") === c.getTime()) + + c.set(1900, 0, 1, 0, 0, 0) + assert(stringToTime("1900-01-01T00:00:00GMT-04:00") === c.getTime()) + + // Tests with local time zone. + c.setTimeZone(TimeZone.getDefault()) + c.set(Calendar.MILLISECOND, 0) + + c.set(2000, 11, 30, 0, 0, 0) + assert(stringToTime("2000-12-30") === new Date(c.getTimeInMillis())) + + c.set(2000, 11, 30, 10, 0, 0) + assert(stringToTime("2000-12-30 10:00:00") === new Timestamp(c.getTimeInMillis())) } test("string to timestamp") { @@ -138,9 +182,9 @@ class DateTimeUtilsSuite extends SparkFunSuite { c.set(Calendar.MILLISECOND, 0) assert(stringToTimestamp(UTF8String.fromString("1969-12-31 16:00:00")).get === c.getTimeInMillis * 1000) - c.set(2015, 0, 1, 0, 0, 0) + c.set(1, 0, 1, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) - assert(stringToTimestamp(UTF8String.fromString("2015")).get === + assert(stringToTimestamp(UTF8String.fromString("0001")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance() c.set(2015, 2, 1, 0, 0, 0) @@ -283,6 +327,7 @@ class DateTimeUtilsSuite extends SparkFunSuite { UTF8String.fromString("2011-05-06 07:08:09.1000")).get === c.getTimeInMillis * 1000) assert(stringToTimestamp(UTF8String.fromString("238")).isEmpty) + assert(stringToTimestamp(UTF8String.fromString("00238")).isEmpty) assert(stringToTimestamp(UTF8String.fromString("2015-03-18 123142")).isEmpty) assert(stringToTimestamp(UTF8String.fromString("2015-03-18T123123")).isEmpty) assert(stringToTimestamp(UTF8String.fromString("2015-03-18X")).isEmpty) @@ -290,12 +335,22 @@ class DateTimeUtilsSuite extends SparkFunSuite { assert(stringToTimestamp(UTF8String.fromString("2015.03.18")).isEmpty) assert(stringToTimestamp(UTF8String.fromString("20150318")).isEmpty) assert(stringToTimestamp(UTF8String.fromString("2015-031-8")).isEmpty) + assert(stringToTimestamp(UTF8String.fromString("02015-01-18")).isEmpty) + assert(stringToTimestamp(UTF8String.fromString("015-01-18")).isEmpty) assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03.17-20:0")).isEmpty) assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03.17-0:70")).isEmpty) assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03.17-1:0:0")).isEmpty) + + // Truncating the fractional seconds + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+00:00")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + assert(stringToTimestamp( + UTF8String.fromString("2015-03-18T12:03:17.123456789+0:00")).get === + c.getTimeInMillis * 1000 + 123456) } test("hours") { @@ -322,6 +377,19 @@ class DateTimeUtilsSuite extends SparkFunSuite { assert(getSeconds(c.getTimeInMillis * 1000) === 9) } + test("hours / minutes / seconds") { + Seq(Timestamp.valueOf("2015-06-11 10:12:35.789"), + Timestamp.valueOf("2015-06-11 20:13:40.789"), + Timestamp.valueOf("1900-06-11 12:14:50.789"), + Timestamp.valueOf("1700-02-28 12:14:50.123456")).foreach { t => + val us = fromJavaTimestamp(t) + assert(toJavaTimestamp(us) === 
t) + assert(getHours(us) === t.getHours) + assert(getMinutes(us) === t.getMinutes) + assert(getSeconds(us) === t.getSeconds) + } + } + test("get day in year") { val c = Calendar.getInstance() c.set(2015, 2, 18, 0, 0, 0) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 88b221cd81d7..706ecd29d135 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -170,6 +170,30 @@ class DataTypeSuite extends SparkFunSuite { } } + test("existsRecursively") { + val struct = StructType( + StructField("a", LongType) :: + StructField("b", FloatType) :: Nil) + assert(struct.existsRecursively(_.isInstanceOf[LongType])) + assert(struct.existsRecursively(_.isInstanceOf[StructType])) + assert(!struct.existsRecursively(_.isInstanceOf[IntegerType])) + + val mapType = MapType(struct, StringType) + assert(mapType.existsRecursively(_.isInstanceOf[LongType])) + assert(mapType.existsRecursively(_.isInstanceOf[StructType])) + assert(mapType.existsRecursively(_.isInstanceOf[StringType])) + assert(mapType.existsRecursively(_.isInstanceOf[MapType])) + assert(!mapType.existsRecursively(_.isInstanceOf[IntegerType])) + + val arrayType = ArrayType(mapType) + assert(arrayType.existsRecursively(_.isInstanceOf[LongType])) + assert(arrayType.existsRecursively(_.isInstanceOf[StructType])) + assert(arrayType.existsRecursively(_.isInstanceOf[StringType])) + assert(arrayType.existsRecursively(_.isInstanceOf[MapType])) + assert(arrayType.existsRecursively(_.isInstanceOf[ArrayType])) + assert(!arrayType.existsRecursively(_.isInstanceOf[IntegerType])) + } + def checkDataTypeJsonRepr(dataType: DataType): Unit = { test(s"JSON - $dataType") { assert(DataType.fromJson(dataType.json) === dataType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala index 417df006ab7c..ed2c641d63e2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala @@ -46,6 +46,25 @@ object DataTypeTestUtils { */ val numericTypes: Set[NumericType] = integralType ++ fractionalTypes + // TODO: remove this once we find out how to handle decimal properly in property check + val numericTypeWithoutDecimal: Set[DataType] = integralType ++ Set(DoubleType, FloatType) + + /** + * Instances of all [[NumericType]]s and [[CalendarIntervalType]] + */ + val numericAndInterval: Set[DataType] = numericTypeWithoutDecimal + CalendarIntervalType + + /** + * All the types that support ordering + */ + val ordered: Set[DataType] = + numericTypeWithoutDecimal + BooleanType + TimestampType + DateType + StringType + BinaryType + + /** + * All the types that we can use in a property check + */ + val propertyCheckSupported: Set[DataType] = ordered + /** * Instances of all [[AtomicType]]s. 
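The existsRecursively tests above probe whether a predicate holds for any type nested inside a DataType. A compact sketch of that recursion over a toy type ADT (assumed names, not Spark's DataType hierarchy):

sealed trait TType
case object TInt extends TType
case object TLong extends TType
case object TString extends TType
case class TArray(element: TType) extends TType
case class TMap(key: TType, value: TType) extends TType
case class TStruct(fields: Seq[TType]) extends TType

// True if `p` holds for this type or for any type nested inside it.
def existsRecursively(t: TType)(p: TType => Boolean): Boolean = t match {
  case _ if p(t)   => true
  case TArray(e)   => existsRecursively(e)(p)
  case TMap(k, v)  => existsRecursively(k)(p) || existsRecursively(v)(p)
  case TStruct(fs) => fs.exists(existsRecursively(_)(p))
  case _           => false
}

val nested = TArray(TMap(TStruct(Seq(TLong, TString)), TString))
assert(existsRecursively(nested)(_ == TLong))
assert(!existsRecursively(nested)(_ == TInt))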
*/ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala similarity index 89% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala index 1d297beb3868..50683947da22 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala @@ -15,10 +15,9 @@ * limitations under the License. */ -package org.apache.spark.sql.types.decimal +package org.apache.spark.sql.types import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.types.Decimal import org.scalatest.PrivateMethodTester import scala.language.postfixOps @@ -44,6 +43,7 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { checkDecimal(Decimal(170L, 4, 2), "1.70", 4, 2) checkDecimal(Decimal(17L, 24, 1), "1.7", 24, 1) checkDecimal(Decimal(1e17.toLong, 18, 0), 1e17.toLong.toString, 18, 0) + checkDecimal(Decimal(1000000000000000000L, 20, 2), "10000000000000000.00", 20, 2) checkDecimal(Decimal(Long.MaxValue), Long.MaxValue.toString, 20, 0) checkDecimal(Decimal(Long.MinValue), Long.MinValue.toString, 20, 0) intercept[IllegalArgumentException](Decimal(170L, 2, 1)) @@ -166,6 +166,27 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { assert(Decimal(100) % Decimal(0) === null) } + // regression test for SPARK-8359 + test("accurate precision after multiplication") { + val decimal = (Decimal(Long.MaxValue, 38, 0) * Decimal(Long.MaxValue, 38, 0)).toJavaBigDecimal + assert(decimal.unscaledValue.toString === "85070591730234615847396907784232501249") + } + + // regression test for SPARK-8677 + test("fix non-terminating decimal expansion problem") { + val decimal = Decimal(1.0, 10, 3) / Decimal(3.0, 10, 3) + // The difference between decimal should not be more than 0.001. 
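For context on the SPARK-8677 regression test above: with unlimited precision, java.math.BigDecimal refuses to divide values whose quotient has no terminating decimal expansion (such as 1/3), which is the failure mode the fix guards against. A standalone JVM-level illustration, not Spark's Decimal internals:

import java.math.{BigDecimal => JBigDecimal, MathContext, RoundingMode}

val one   = new JBigDecimal(1)
val three = new JBigDecimal(3)

// Unlimited precision: 1/3 cannot be represented exactly, so divide() throws.
val threw =
  try { one.divide(three); false }
  catch { case _: ArithmeticException => true }
assert(threw)

// Bounding the precision (or fixing a scale and rounding mode) makes the division well-defined.
val bounded = one.divide(three, new MathContext(38, RoundingMode.HALF_UP))
assert((bounded.doubleValue - 0.333).abs < 0.001)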
+ assert(decimal.toDouble - 0.333 < 0.001) + } + + // regression test for SPARK-8800 + test("fix loss of precision/scale when doing division operation") { + val a = Decimal(2) / Decimal(3) + assert(a.toDouble < 1.0 && a.toDouble > 0.6) + val b = Decimal(1) / Decimal(8) + assert(b.toDouble === 0.125) + } + test("set/setOrNull") { assert(new Decimal().set(10L, 10, 0).toUnscaledLong === 10L) assert(new Decimal().set(100L, 10, 0).toUnscaledLong === 100L) diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 349007789f63..06841b094562 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml @@ -60,6 +60,10 @@ test-jar test + + org.apache.spark + spark-test-tags_${scala.binary.version} + org.apache.parquet parquet-column @@ -73,11 +77,6 @@ jackson-databind ${fasterxml.jackson.version} - - junit - junit - test - org.scalacheck scalacheck_${scala.binary.version} @@ -92,13 +91,11 @@ mysql mysql-connector-java - 5.1.34 test org.postgresql postgresql - 9.3-1102-jdbc41 test @@ -111,6 +108,11 @@ mockito-core test + + org.apache.xbean + xbean-asm5-shaded + test + target/scala-${scala.binary.version}/classes diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index 9e2c9334a7be..a2f99d566d47 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -20,17 +20,16 @@ import java.io.IOException; import org.apache.spark.SparkEnv; -import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import org.apache.spark.unsafe.KVIterator; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.map.BytesToBytesMap; import org.apache.spark.unsafe.memory.MemoryLocation; -import org.apache.spark.unsafe.memory.TaskMemoryManager; /** * Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width. @@ -72,7 +71,7 @@ public final class UnsafeFixedWidthAggregationMap { */ public static boolean supportsAggregationBufferSchema(StructType schema) { for (StructField field: schema.fields()) { - if (!UnsafeRow.isFixedLength(field.dataType())) { + if (!UnsafeRow.isMutable(field.dataType())) { return false; } } @@ -86,8 +85,6 @@ public static boolean supportsAggregationBufferSchema(StructType schema) { * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion. * @param groupingKeySchema the schema of the grouping key, used for row conversion. * @param taskMemoryManager the memory manager used to allocate our Unsafe memory structures. - * @param shuffleMemoryManager the shuffle memory manager, for coordinating our memory usage with - * other tasks. * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing). * @param pageSizeBytes the data page size, in bytes; limits the maximum record size. 
* @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact) @@ -97,22 +94,19 @@ public UnsafeFixedWidthAggregationMap( StructType aggregationBufferSchema, StructType groupingKeySchema, TaskMemoryManager taskMemoryManager, - ShuffleMemoryManager shuffleMemoryManager, int initialCapacity, long pageSizeBytes, boolean enablePerfMetrics) { this.aggregationBufferSchema = aggregationBufferSchema; this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema); this.groupingKeySchema = groupingKeySchema; - this.map = new BytesToBytesMap( - taskMemoryManager, shuffleMemoryManager, initialCapacity, pageSizeBytes, enablePerfMetrics); + this.map = + new BytesToBytesMap(taskMemoryManager, initialCapacity, pageSizeBytes, enablePerfMetrics); this.enablePerfMetrics = enablePerfMetrics; // Initialize the buffer for aggregation value final UnsafeProjection valueProjection = UnsafeProjection.create(aggregationBufferSchema); this.emptyAggregationBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes(); - assert(this.emptyAggregationBuffer.length == aggregationBufferSchema.length() * 8 + - UnsafeRow.calculateBitSetWidthInBytes(aggregationBufferSchema.length())); } /** @@ -123,6 +117,10 @@ public UnsafeFixedWidthAggregationMap( public UnsafeRow getAggregationBuffer(InternalRow groupingKey) { final UnsafeRow unsafeGroupingKeyRow = this.groupingKeyProjection.apply(groupingKey); + return getAggregationBufferFromUnsafeRow(unsafeGroupingKeyRow); + } + + public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow unsafeGroupingKeyRow) { // Probe our map using the serialized key final BytesToBytesMap.Location loc = map.lookup( unsafeGroupingKeyRow.getBaseObject(), @@ -136,7 +134,7 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) { unsafeGroupingKeyRow.getBaseOffset(), unsafeGroupingKeyRow.getSizeInBytes(), emptyAggregationBuffer, - PlatformDependent.BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, emptyAggregationBuffer.length ); if (!putSucceeded) { @@ -156,14 +154,17 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) { } /** - * Returns an iterator over the keys and values in this map. + * Returns an iterator over the keys and values in this map. This uses destructive iterator of + * BytesToBytesMap. So it is illegal to call any other method on this map after `iterator()` has + * been called. * * For efficiency, each call returns the same object. */ public KVIterator iterator() { return new KVIterator() { - private final BytesToBytesMap.BytesToBytesMapIterator mapLocationIterator = map.iterator(); + private final BytesToBytesMap.MapIterator mapLocationIterator = + map.destructiveIterator(); private final UnsafeRow key = new UnsafeRow(); private final UnsafeRow value = new UnsafeRow(); @@ -208,6 +209,13 @@ public void close() { }; } + /** + * Return the peak memory used so far, in bytes. + */ + public long getPeakMemoryUsedBytes() { + return map.getPeakMemoryUsedBytes(); + } + /** * Free the memory associated with this map. This is idempotent and can be called multiple times. */ @@ -228,16 +236,13 @@ public void printPerfMetrics() { /** * Sorts the map's records in place, spill them to disk, and returns an [[UnsafeKVExternalSorter]] - * that can be used to insert more records to do external sorting. - * - * The only memory that is allocated is the address/prefix array, 16 bytes per record. * - * Note that this destroys the map, and as a result, the map cannot be used anymore after this. 
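getAggregationBufferFromUnsafeRow above follows a lookup-or-insert pattern: probe the map with the serialized grouping key and, on a miss, copy in the empty aggregation buffer that callers then update in place. A toy sketch of that contract with a plain mutable map (hypothetical buffer layout, nothing like BytesToBytesMap's memory management):

import scala.collection.mutable

// buffer(0) = count, buffer(1) = sum -- a stand-in for the fixed-width aggregation buffer.
val emptyBuffer = Array(0L, 0L)
val groups = mutable.HashMap.empty[Int, Array[Long]]

def getAggregationBuffer(groupingKey: Int): Array[Long] =
  groups.getOrElseUpdate(groupingKey, emptyBuffer.clone())  // miss: insert a copy of the empty buffer

// Aggregate count/sum per key by mutating the returned buffer in place.
for ((key, value) <- Seq(1 -> 10L, 2 -> 5L, 1 -> 7L)) {
  val buf = getAggregationBuffer(key)
  buf(0) += 1
  buf(1) += value
}
assert(groups(1).toSeq == Seq(2L, 17L) && groups(2).toSeq == Seq(1L, 5L))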
+ * Note that the map will be reset for inserting new records, and the returned sorter can NOT be used + * to insert records. */ public UnsafeKVExternalSorter destructAndCreateExternalSorter() throws IOException { - UnsafeKVExternalSorter sorter = new UnsafeKVExternalSorter( + return new UnsafeKVExternalSorter( groupingKeySchema, aggregationBufferSchema, - SparkEnv.get().blockManager(), map.getShuffleMemoryManager(), map.getPageSizeBytes(), map); - return sorter; + SparkEnv.get().blockManager(), map.getPageSizeBytes(), map); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index f6b017686306..8c9b9c85e37f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -17,24 +17,22 @@ package org.apache.spark.sql.execution; -import java.io.IOException; - import javax.annotation.Nullable; +import java.io.IOException; import com.google.common.annotations.VisibleForTesting; import org.apache.spark.TaskContext; -import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.catalyst.expressions.codegen.BaseOrdering; import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering; import org.apache.spark.sql.types.StructType; import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.KVIterator; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.map.BytesToBytesMap; import org.apache.spark.unsafe.memory.MemoryBlock; -import org.apache.spark.unsafe.memory.TaskMemoryManager; import org.apache.spark.util.collection.unsafe.sort.*; /** @@ -50,14 +48,19 @@ public final class UnsafeKVExternalSorter { private final UnsafeExternalRowSorter.PrefixComputer prefixComputer; private final UnsafeExternalSorter sorter; - public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema, - BlockManager blockManager, ShuffleMemoryManager shuffleMemoryManager, long pageSizeBytes) - throws IOException { - this(keySchema, valueSchema, blockManager, shuffleMemoryManager, pageSizeBytes, null); + public UnsafeKVExternalSorter( + StructType keySchema, + StructType valueSchema, + BlockManager blockManager, + long pageSizeBytes) throws IOException { + this(keySchema, valueSchema, blockManager, pageSizeBytes, null); } - public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema, - BlockManager blockManager, ShuffleMemoryManager shuffleMemoryManager, long pageSizeBytes, + public UnsafeKVExternalSorter( + StructType keySchema, + StructType valueSchema, + BlockManager blockManager, + long pageSizeBytes, @Nullable BytesToBytesMap map) throws IOException { this.keySchema = keySchema; this.valueSchema = valueSchema; @@ -73,7 +76,6 @@ public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema, if (map == null) { sorter = UnsafeExternalSorter.create( taskMemoryManager, - shuffleMemoryManager, blockManager, taskContext, recordComparator, @@ -81,12 +83,17 @@ public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema, /* initialSize */ 4096, pageSizeBytes); } else { - // Insert the records into the in-memory sorter. 
+ // During spilling, the array in map will not be used, so we can borrow that and use it + // as the underline array for in-memory sorter (it's always large enough). + // Since we will not grow the array, it's fine to pass `null` as consumer. final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter( - taskMemoryManager, recordComparator, prefixComparator, map.numElements()); + null, taskMemoryManager, recordComparator, prefixComparator, map.getArray()); + // We cannot use the destructive iterator here because we are reusing the existing memory + // pages in BytesToBytesMap to hold records during sorting. + // The only new memory we are allocating is the pointer/prefix array. + BytesToBytesMap.MapIterator iter = map.iterator(); final int numKeyFields = keySchema.size(); - BytesToBytesMap.BytesToBytesMapIterator iter = map.iterator(); UnsafeRow row = new UnsafeRow(); while (iter.hasNext()) { final BytesToBytesMap.Location loc = iter.next(); @@ -107,8 +114,7 @@ public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema, } sorter = UnsafeExternalSorter.createWithExistingInMemorySorter( - taskContext.taskMemoryManager(), - shuffleMemoryManager, + taskMemoryManager, blockManager, taskContext, new KVComparator(ordering, keySchema.length()), @@ -117,8 +123,9 @@ public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema, pageSizeBytes, inMemSorter); - sorter.spill(); - map.free(); + // reset the map, so we can re-use it to insert new records. the inMemSorter will not used + // anymore, so the underline array could be used by map again. + map.reset(); } } @@ -134,7 +141,20 @@ public void insertKV(UnsafeRow key, UnsafeRow value) throws IOException { value.getBaseObject(), value.getBaseOffset(), value.getSizeInBytes(), prefix); } - public KVIterator sortedIterator() throws IOException { + /** + * Merges another UnsafeKVExternalSorter into `this`, the other one will be emptied. + * + * @throws IOException + */ + public void merge(UnsafeKVExternalSorter other) throws IOException { + sorter.merge(other.sorter); + } + + /** + * Returns a sorted iterator. It is the caller's responsibility to call `cleanupResources()` + * after consuming this iterator. + */ + public KVSorterIterator sortedIterator() throws IOException { try { final UnsafeSorterIterator underlying = sorter.getSortedIterator(); if (!underlying.hasNext()) { @@ -142,64 +162,20 @@ public KVIterator sortedIterator() throws IOException { // here in order to prevent memory leaks. 
cleanupResources(); } - - return new KVIterator() { - private UnsafeRow key = new UnsafeRow(); - private UnsafeRow value = new UnsafeRow(); - private int numKeyFields = keySchema.size(); - private int numValueFields = valueSchema.size(); - - @Override - public boolean next() throws IOException { - try { - if (underlying.hasNext()) { - underlying.loadNext(); - - Object baseObj = underlying.getBaseObject(); - long recordOffset = underlying.getBaseOffset(); - int recordLen = underlying.getRecordLength(); - - // Note that recordLen = keyLen + valueLen + 4 bytes (for the keyLen itself) - int keyLen = PlatformDependent.UNSAFE.getInt(baseObj, recordOffset); - int valueLen = recordLen - keyLen - 4; - - key.pointTo(baseObj, recordOffset + 4, numKeyFields, keyLen); - value.pointTo(baseObj, recordOffset + 4 + keyLen, numValueFields, valueLen); - - return true; - } else { - key = null; - value = null; - cleanupResources(); - return false; - } - } catch (IOException e) { - cleanupResources(); - throw e; - } - } - - @Override - public UnsafeRow getKey() { - return key; - } - - @Override - public UnsafeRow getValue() { - return value; - } - - @Override - public void close() { - cleanupResources(); - } - }; + return new KVSorterIterator(underlying); } catch (IOException e) { cleanupResources(); throw e; } } + /** + * Return the peak memory used so far, in bytes. + */ + public long getPeakMemoryUsedBytes() { + return sorter.getPeakMemoryUsedBytes(); + } + /** * Marks the current page as no-more-space-available, and as a result, either allocate a * new page or spill when we see the next record. @@ -209,8 +185,11 @@ void closeCurrentPage() { sorter.closeCurrentPage(); } - private void cleanupResources() { - sorter.freeMemory(); + /** + * Frees this sorter's in-memory data structures and cleans up its spill files. 
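The key/value records this sorter iterates over are laid out as a 4-byte key length, the key bytes, then the value bytes, so recordLen = 4 + keyLen + valueLen and the value starts right after the key. A small sketch of packing and unpacking that framing with a ByteBuffer (illustrative only, not UnsafeRow's off-heap format):

import java.nio.ByteBuffer

def pack(key: Array[Byte], value: Array[Byte]): Array[Byte] = {
  val out = ByteBuffer.allocate(4 + key.length + value.length)
  out.putInt(key.length).put(key).put(value)                // keyLen prefix, then key, then value
  out.array()
}

def unpack(record: Array[Byte]): (Array[Byte], Array[Byte]) = {
  val in = ByteBuffer.wrap(record)
  val keyLen = in.getInt()                                  // first 4 bytes hold keyLen
  val key = new Array[Byte](keyLen)
  val value = new Array[Byte](record.length - 4 - keyLen)   // valueLen = recordLen - keyLen - 4
  in.get(key).get(value)
  (key, value)
}

val (k, v) = unpack(pack("key".getBytes("UTF-8"), "value".getBytes("UTF-8")))
assert(new String(k, "UTF-8") == "key" && new String(v, "UTF-8") == "value")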
+ */ + public void cleanupResources() { + sorter.cleanupResources(); } private static final class KVComparator extends RecordComparator { @@ -233,4 +212,60 @@ public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff return ordering.compare(row1, row2); } } + + public class KVSorterIterator extends KVIterator { + private UnsafeRow key = new UnsafeRow(); + private UnsafeRow value = new UnsafeRow(); + private final int numKeyFields = keySchema.size(); + private final int numValueFields = valueSchema.size(); + private final UnsafeSorterIterator underlying; + + private KVSorterIterator(UnsafeSorterIterator underlying) { + this.underlying = underlying; + } + + @Override + public boolean next() throws IOException { + try { + if (underlying.hasNext()) { + underlying.loadNext(); + + Object baseObj = underlying.getBaseObject(); + long recordOffset = underlying.getBaseOffset(); + int recordLen = underlying.getRecordLength(); + + // Note that recordLen = keyLen + valueLen + 4 bytes (for the keyLen itself) + int keyLen = Platform.getInt(baseObj, recordOffset); + int valueLen = recordLen - keyLen - 4; + key.pointTo(baseObj, recordOffset + 4, numKeyFields, keyLen); + value.pointTo(baseObj, recordOffset + 4 + keyLen, numValueFields, valueLen); + + return true; + } else { + key = null; + value = null; + cleanupResources(); + return false; + } + } catch (IOException e) { + cleanupResources(); + throw e; + } + } + + @Override + public UnsafeRow getKey() { + return key; + } + + @Override + public UnsafeRow getValue() { + return value; + } + + @Override + public void close() { + cleanupResources(); + } + }; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java new file mode 100644 index 000000000000..842dcb8c93dc --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + +package org.apache.spark.sql.execution.datasources.parquet; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.apache.parquet.filter2.compat.RowGroupFilter.filterRowGroups; +import static org.apache.parquet.format.converter.ParquetMetadataConverter.NO_FILTER; +import static org.apache.parquet.format.converter.ParquetMetadataConverter.range; +import static org.apache.parquet.hadoop.ParquetFileReader.readFooter; +import static org.apache.parquet.hadoop.ParquetInputFormat.getFilter; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.mapreduce.InputSplit; +import org.apache.hadoop.mapreduce.RecordReader; +import org.apache.hadoop.mapreduce.TaskAttemptContext; +import org.apache.parquet.bytes.BytesInput; +import org.apache.parquet.bytes.BytesUtils; +import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.column.values.ValuesReader; +import org.apache.parquet.column.values.rle.RunLengthBitPackingHybridDecoder; +import org.apache.parquet.filter2.compat.FilterCompat; +import org.apache.parquet.hadoop.BadConfigurationException; +import org.apache.parquet.hadoop.ParquetFileReader; +import org.apache.parquet.hadoop.ParquetInputFormat; +import org.apache.parquet.hadoop.ParquetInputSplit; +import org.apache.parquet.hadoop.api.InitContext; +import org.apache.parquet.hadoop.api.ReadSupport; +import org.apache.parquet.hadoop.metadata.BlockMetaData; +import org.apache.parquet.hadoop.metadata.ParquetMetadata; +import org.apache.parquet.hadoop.util.ConfigurationUtil; +import org.apache.parquet.schema.MessageType; + +/** + * Base class for custom RecordReaaders for Parquet that directly materialize to `T`. + * This class handles computing row groups, filtering on them, setting up the column readers, + * etc. + * This is heavily based on parquet-mr's RecordReader. + * TODO: move this to the parquet-mr project. There are performance benefits of doing it + * this way, albeit at a higher cost to implement. This base class is reusable. + */ +public abstract class SpecificParquetRecordReaderBase extends RecordReader { + protected Path file; + protected MessageType fileSchema; + protected MessageType requestedSchema; + protected ReadSupport readSupport; + + /** + * The total number of rows this RecordReader will eventually read. The sum of the + * rows of all the row groups. 
+ */ + protected long totalRowCount; + + protected ParquetFileReader reader; + + public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) + throws IOException, InterruptedException { + Configuration configuration = taskAttemptContext.getConfiguration(); + ParquetInputSplit split = (ParquetInputSplit)inputSplit; + this.file = split.getPath(); + long[] rowGroupOffsets = split.getRowGroupOffsets(); + + ParquetMetadata footer; + List blocks; + + // if task.side.metadata is set, rowGroupOffsets is null + if (rowGroupOffsets == null) { + // then we need to apply the predicate push down filter + footer = readFooter(configuration, file, range(split.getStart(), split.getEnd())); + MessageType fileSchema = footer.getFileMetaData().getSchema(); + FilterCompat.Filter filter = getFilter(configuration); + blocks = filterRowGroups(filter, footer.getBlocks(), fileSchema); + } else { + // otherwise we find the row groups that were selected on the client + footer = readFooter(configuration, file, NO_FILTER); + Set offsets = new HashSet<>(); + for (long offset : rowGroupOffsets) { + offsets.add(offset); + } + blocks = new ArrayList<>(); + for (BlockMetaData block : footer.getBlocks()) { + if (offsets.contains(block.getStartingPos())) { + blocks.add(block); + } + } + // verify we found them all + if (blocks.size() != rowGroupOffsets.length) { + long[] foundRowGroupOffsets = new long[footer.getBlocks().size()]; + for (int i = 0; i < foundRowGroupOffsets.length; i++) { + foundRowGroupOffsets[i] = footer.getBlocks().get(i).getStartingPos(); + } + // this should never happen. + // provide a good error message in case there's a bug + throw new IllegalStateException( + "All the offsets listed in the split should be found in the file." + + " expected: " + Arrays.toString(rowGroupOffsets) + + " found: " + blocks + + " out of: " + Arrays.toString(foundRowGroupOffsets) + + " in range " + split.getStart() + ", " + split.getEnd()); + } + } + MessageType fileSchema = footer.getFileMetaData().getSchema(); + Map fileMetadata = footer.getFileMetaData().getKeyValueMetaData(); + this.readSupport = getReadSupportInstance( + (Class>) getReadSupportClass(configuration)); + ReadSupport.ReadContext readContext = readSupport.init(new InitContext( + taskAttemptContext.getConfiguration(), toSetMultiMap(fileMetadata), fileSchema)); + this.requestedSchema = readContext.getRequestedSchema(); + this.fileSchema = fileSchema; + this.reader = new ParquetFileReader(configuration, file, blocks, requestedSchema.getColumns()); + for (BlockMetaData block : blocks) { + this.totalRowCount += block.getRowCount(); + } + } + + @Override + public Void getCurrentKey() throws IOException, InterruptedException { + return null; + } + + @Override + public void close() throws IOException { + if (reader != null) { + reader.close(); + reader = null; + } + } + + /** + * Utility classes to abstract over different way to read ints with different encodings. + * TODO: remove this layer of abstraction? 
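When rowGroupOffsets is non-null, initialize() above keeps only the row groups whose starting positions the client listed and fails loudly if any listed offset is missing from the footer. A condensed sketch of that selection over hypothetical block metadata:

// Stand-in for Parquet's BlockMetaData: just a starting byte offset and a row count.
case class BlockMeta(startingPos: Long, rowCount: Long)

def selectRowGroups(footerBlocks: Seq[BlockMeta], rowGroupOffsets: Array[Long]): Seq[BlockMeta] = {
  val wanted = rowGroupOffsets.toSet
  val selected = footerBlocks.filter(b => wanted.contains(b.startingPos))
  require(selected.size == rowGroupOffsets.length,
    s"All the offsets listed in the split should be found in the file. " +
      s"expected: ${wanted.toSeq.sorted} found: ${selected.map(_.startingPos)}")
  selected
}

val footer = Seq(BlockMeta(0L, 100), BlockMeta(4096L, 120), BlockMeta(9000L, 80))
val picked = selectRowGroups(footer, Array(4096L, 9000L))
assert(picked.map(_.rowCount).sum == 200)   // totalRowCount is accumulated over the selected groups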
+ */ + abstract static class IntIterator { + abstract int nextInt() throws IOException; + } + + protected static final class ValuesReaderIntIterator extends IntIterator { + ValuesReader delegate; + + public ValuesReaderIntIterator(ValuesReader delegate) { + this.delegate = delegate; + } + + @Override + int nextInt() throws IOException { + return delegate.readInteger(); + } + } + + protected static final class RLEIntIterator extends IntIterator { + RunLengthBitPackingHybridDecoder delegate; + + public RLEIntIterator(RunLengthBitPackingHybridDecoder delegate) { + this.delegate = delegate; + } + + @Override + int nextInt() throws IOException { + return delegate.readInt(); + } + } + + protected static final class NullIntIterator extends IntIterator { + @Override + int nextInt() throws IOException { return 0; } + } + + /** + * Creates a reader for definition and repetition levels, returning an optimized one if + * the levels are not needed. + */ + protected static IntIterator createRLEIterator(int maxLevel, BytesInput bytes, + ColumnDescriptor descriptor) throws IOException { + try { + if (maxLevel == 0) return new NullIntIterator(); + return new RLEIntIterator( + new RunLengthBitPackingHybridDecoder( + BytesUtils.getWidthFromMaxInt(maxLevel), + new ByteArrayInputStream(bytes.toByteArray()))); + } catch (IOException e) { + throw new IOException("could not read levels in page for col " + descriptor, e); + } + } + + private static Map> toSetMultiMap(Map map) { + Map> setMultiMap = new HashMap<>(); + for (Map.Entry entry : map.entrySet()) { + Set set = new HashSet<>(); + set.add(entry.getValue()); + setMultiMap.put(entry.getKey(), Collections.unmodifiableSet(set)); + } + return Collections.unmodifiableMap(setMultiMap); + } + + private static Class getReadSupportClass(Configuration configuration) { + return ConfigurationUtil.getClassFromConfig(configuration, + ParquetInputFormat.READ_SUPPORT_CLASS, ReadSupport.class); + } + + /** + * @param readSupportClass to instantiate + * @return the configured read support + */ + private static ReadSupport getReadSupportInstance( + Class> readSupportClass){ + try { + return readSupportClass.newInstance(); + } catch (InstantiationException e) { + throw new BadConfigurationException("could not instantiate read support class", e); + } catch (IllegalAccessException e) { + throw new BadConfigurationException("could not instantiate read support class", e); + } + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java new file mode 100644 index 000000000000..0cc4566c9cdd --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java @@ -0,0 +1,608 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.List; + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder; +import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.types.UTF8String; + +import static org.apache.parquet.column.ValuesType.DEFINITION_LEVEL; +import static org.apache.parquet.column.ValuesType.REPETITION_LEVEL; +import static org.apache.parquet.column.ValuesType.VALUES; + +import org.apache.hadoop.mapreduce.InputSplit; +import org.apache.hadoop.mapreduce.TaskAttemptContext; +import org.apache.parquet.Preconditions; +import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.column.Dictionary; +import org.apache.parquet.column.Encoding; +import org.apache.parquet.column.page.DataPage; +import org.apache.parquet.column.page.DataPageV1; +import org.apache.parquet.column.page.DataPageV2; +import org.apache.parquet.column.page.DictionaryPage; +import org.apache.parquet.column.page.PageReadStore; +import org.apache.parquet.column.page.PageReader; +import org.apache.parquet.column.values.ValuesReader; +import org.apache.parquet.io.api.Binary; +import org.apache.parquet.schema.OriginalType; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; + +/** + * A specialized RecordReader that reads into UnsafeRows directly using the Parquet column APIs. + * + * This is somewhat based on parquet-mr's ColumnReader. + * + * TODO: handle complex types, decimal requiring more than 8 bytes, INT96. Schema mismatch. + * All of these can be handled efficiently and easily with codegen. + */ +public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBase { + /** + * Batch of unsafe rows that we assemble and the current index we've returned. Everytime this + * batch is used up (batchIdx == numBatched), we populated the batch. + */ + private UnsafeRow[] rows = new UnsafeRow[64]; + private int batchIdx = 0; + private int numBatched = 0; + + /** + * Used to write variable length columns. Same length as `rows`. + */ + private UnsafeRowWriter[] rowWriters = null; + /** + * True if the row contains variable length fields. + */ + private boolean containsVarLenFields; + + /** + * The number of bytes in the fixed length portion of the row. + */ + private int fixedSizeBytes; + + /** + * For each request column, the reader to read this column. + * columnsReaders[i] populated the UnsafeRow's attribute at i. + */ + private ColumnReader[] columnReaders; + + /** + * The number of rows that have been returned. + */ + private long rowsReturned; + + /** + * The number of rows that have been reading, including the current in flight row group. + */ + private long totalCountLoadedSoFar = 0; + + /** + * For each column, the annotated original type. 
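The reader above assembles UnsafeRows into a fixed-size batch (rows, batchIdx, numBatched) and repopulates the batch whenever it has been used up. A stripped-down sketch of that batching pattern over an arbitrary source, with no Parquet involved:

// Serve elements one at a time, refilling a small batch from `source` on demand.
class BatchedReader[T](source: Iterator[T], batchSize: Int = 64) {
  private val batch = new Array[Any](batchSize)
  private var batchIdx = 0
  private var numBatched = 0

  private def loadBatch(): Boolean = {
    numBatched = 0
    while (numBatched < batchSize && source.hasNext) {
      batch(numBatched) = source.next()
      numBatched += 1
    }
    batchIdx = 0
    numBatched > 0
  }

  def nextKeyValue(): Boolean = {
    if (batchIdx >= numBatched && !loadBatch()) return false
    batchIdx += 1
    true
  }

  def getCurrentValue: T = batch(batchIdx - 1).asInstanceOf[T]
}

val reader = new BatchedReader((1 to 200).iterator, batchSize = 64)
var total = 0
while (reader.nextKeyValue()) total += reader.getCurrentValue
assert(total == (1 to 200).sum)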
+ */ + private OriginalType[] originalTypes; + + /** + * The default size for varlen columns. The row grows as necessary to accommodate the + * largest column. + */ + private static final int DEFAULT_VAR_LEN_SIZE = 32; + + /** + * Tries to initialize the reader for this split. Returns true if this reader supports reading + * this split and false otherwise. + */ + public boolean tryInitialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) { + try { + initialize(inputSplit, taskAttemptContext); + return true; + } catch (Exception e) { + return false; + } + } + + /** + * Implementation of RecordReader API. + */ + @Override + public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) + throws IOException, InterruptedException { + super.initialize(inputSplit, taskAttemptContext); + + /** + * Check that the requested schema is supported. + */ + if (requestedSchema.getFieldCount() == 0) { + // TODO: what does this mean? + throw new IOException("Empty request schema not supported."); + } + int numVarLenFields = 0; + originalTypes = new OriginalType[requestedSchema.getFieldCount()]; + for (int i = 0; i < requestedSchema.getFieldCount(); ++i) { + Type t = requestedSchema.getFields().get(i); + if (!t.isPrimitive() || t.isRepetition(Type.Repetition.REPEATED)) { + throw new IOException("Complex types not supported."); + } + PrimitiveType primitiveType = t.asPrimitiveType(); + + originalTypes[i] = t.getOriginalType(); + + // TODO: Be extremely cautious in what is supported. Expand this. + if (originalTypes[i] != null && originalTypes[i] != OriginalType.DECIMAL && + originalTypes[i] != OriginalType.UTF8 && originalTypes[i] != OriginalType.DATE) { + throw new IOException("Unsupported type: " + t); + } + if (originalTypes[i] == OriginalType.DECIMAL && + primitiveType.getDecimalMetadata().getPrecision() > + CatalystSchemaConverter.MAX_PRECISION_FOR_INT64()) { + throw new IOException("Decimal with high precision is not supported."); + } + if (primitiveType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.INT96) { + throw new IOException("Int96 not supported."); + } + ColumnDescriptor fd = fileSchema.getColumnDescription(requestedSchema.getPaths().get(i)); + if (!fd.equals(requestedSchema.getColumns().get(i))) { + throw new IOException("Schema evolution not supported."); + } + + if (primitiveType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.BINARY) { + ++numVarLenFields; + } + } + + /** + * Initialize rows and rowWriters. These objects are reused across all rows in the relation. 
+ */ + int rowByteSize = UnsafeRow.calculateBitSetWidthInBytes(requestedSchema.getFieldCount()); + rowByteSize += 8 * requestedSchema.getFieldCount(); + fixedSizeBytes = rowByteSize; + rowByteSize += numVarLenFields * DEFAULT_VAR_LEN_SIZE; + containsVarLenFields = numVarLenFields > 0; + rowWriters = new UnsafeRowWriter[rows.length]; + + for (int i = 0; i < rows.length; ++i) { + rows[i] = new UnsafeRow(); + rowWriters[i] = new UnsafeRowWriter(); + BufferHolder holder = new BufferHolder(rowByteSize); + rowWriters[i].initialize(rows[i], holder, requestedSchema.getFieldCount()); + rows[i].pointTo(holder.buffer, Platform.BYTE_ARRAY_OFFSET, requestedSchema.getFieldCount(), + holder.buffer.length); + } + } + + @Override + public boolean nextKeyValue() throws IOException, InterruptedException { + if (batchIdx >= numBatched) { + if (!loadBatch()) return false; + } + ++batchIdx; + return true; + } + + @Override + public UnsafeRow getCurrentValue() throws IOException, InterruptedException { + return rows[batchIdx - 1]; + } + + @Override + public float getProgress() throws IOException, InterruptedException { + return (float) rowsReturned / totalRowCount; + } + + /** + * Decodes a batch of values into `rows`. This function is the hot path. + */ + private boolean loadBatch() throws IOException { + // no more records left + if (rowsReturned >= totalRowCount) { return false; } + checkEndOfRowGroup(); + + int num = (int)Math.min(rows.length, totalCountLoadedSoFar - rowsReturned); + rowsReturned += num; + + if (containsVarLenFields) { + for (int i = 0; i < rowWriters.length; ++i) { + rowWriters[i].holder().resetTo(fixedSizeBytes); + } + } + + for (int i = 0; i < columnReaders.length; ++i) { + switch (columnReaders[i].descriptor.getType()) { + case BOOLEAN: + decodeBooleanBatch(i, num); + break; + case INT32: + if (originalTypes[i] == OriginalType.DECIMAL) { + decodeIntAsDecimalBatch(i, num); + } else { + decodeIntBatch(i, num); + } + break; + case INT64: + Preconditions.checkState(originalTypes[i] == null + || originalTypes[i] == OriginalType.DECIMAL, + "Unexpected original type: " + originalTypes[i]); + decodeLongBatch(i, num); + break; + case FLOAT: + decodeFloatBatch(i, num); + break; + case DOUBLE: + decodeDoubleBatch(i, num); + break; + case BINARY: + decodeBinaryBatch(i, num); + break; + case FIXED_LEN_BYTE_ARRAY: + Preconditions.checkState(originalTypes[i] == OriginalType.DECIMAL, + "Unexpected original type: " + originalTypes[i]); + decodeFixedLenArrayAsDecimalBatch(i, num); + break; + case INT96: + throw new IOException("Unsupported " + columnReaders[i].descriptor.getType()); + } + numBatched = num; + batchIdx = 0; + } + return true; + } + + private void decodeBooleanBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + rows[n].setBoolean(col, columnReaders[col].nextBoolean()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeIntBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + rows[n].setInt(col, columnReaders[col].nextInt()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeIntAsDecimalBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + // Since this is stored as an INT, it is always a compact decimal. Just set it as a long. 
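The rowByteSize computation in initialize() above budgets a null-tracking bitset, an 8-byte slot per field, and DEFAULT_VAR_LEN_SIZE of extra room per variable-length column. A small sketch of that arithmetic, assuming the bitset takes one 8-byte word per 64 fields (treat that formula as an assumption here):

val DefaultVarLenSize = 32  // initial budget per variable-length column; the row grows if needed

def bitSetWidthInBytes(numFields: Int): Int = ((numFields + 63) / 64) * 8

def initialRowByteSize(numFields: Int, numVarLenFields: Int): Int =
  bitSetWidthInBytes(numFields) +         // null bits
    8 * numFields +                       // one 8-byte slot per field (value, or offset + length)
    numVarLenFields * DefaultVarLenSize   // slack for variable-length payloads

// e.g. 5 columns, 2 of them BINARY/UTF8: 8 + 40 + 64 = 112 bytes up front.
assert(initialRowByteSize(numFields = 5, numVarLenFields = 2) == 112)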
+ rows[n].setLong(col, columnReaders[col].nextInt()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeLongBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + rows[n].setLong(col, columnReaders[col].nextLong()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeFloatBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + rows[n].setFloat(col, columnReaders[col].nextFloat()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeDoubleBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + rows[n].setDouble(col, columnReaders[col].nextDouble()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeBinaryBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + ByteBuffer bytes = columnReaders[col].nextBinary().toByteBuffer(); + int len = bytes.remaining(); + if (originalTypes[col] == OriginalType.UTF8) { + UTF8String str = + UTF8String.fromBytes(bytes.array(), bytes.arrayOffset() + bytes.position(), len); + rowWriters[n].write(col, str); + } else { + rowWriters[n].write(col, bytes.array(), bytes.arrayOffset() + bytes.position(), len); + } + rows[n].setNotNullAt(col); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeFixedLenArrayAsDecimalBatch(int col, int num) throws IOException { + PrimitiveType type = requestedSchema.getFields().get(col).asPrimitiveType(); + int precision = type.getDecimalMetadata().getPrecision(); + int scale = type.getDecimalMetadata().getScale(); + Preconditions.checkState(precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64(), + "Unsupported precision."); + + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + Binary v = columnReaders[col].nextBinary(); + // Constructs a `Decimal` with an unscaled `Long` value if possible. + long unscaled = CatalystRowConverter.binaryToUnscaledLong(v); + rows[n].setDecimal(col, Decimal.apply(unscaled, precision, scale), precision); + } else { + rows[n].setNullAt(col); + } + } + } + + /** + * + * Decoder to return values from a single column. + */ + private static final class ColumnReader { + /** + * Total number of values read. + */ + private long valuesRead; + + /** + * value that indicates the end of the current page. That is, + * if valuesRead == endOfPageValueCount, we are at the end of the page. + */ + private long endOfPageValueCount; + + /** + * The dictionary, if this column has dictionary encoding. + */ + private final Dictionary dictionary; + + /** + * If true, the current page is dictionary encoded. + */ + private boolean useDictionary; + + /** + * Maximum definition level for this column. + */ + private final int maxDefLevel; + + /** + * Repetition/Definition/Value readers. + */ + private IntIterator repetitionLevelColumn; + private IntIterator definitionLevelColumn; + private ValuesReader dataColumn; + + /** + * Total number of values in this column (in this row group). + */ + private final long totalValueCount; + + /** + * Total values in the current page. 
+ */ + private int pageValueCount; + + private final PageReader pageReader; + private final ColumnDescriptor descriptor; + + public ColumnReader(ColumnDescriptor descriptor, PageReader pageReader) + throws IOException { + this.descriptor = descriptor; + this.pageReader = pageReader; + this.maxDefLevel = descriptor.getMaxDefinitionLevel(); + + DictionaryPage dictionaryPage = pageReader.readDictionaryPage(); + if (dictionaryPage != null) { + try { + this.dictionary = dictionaryPage.getEncoding().initDictionary(descriptor, dictionaryPage); + this.useDictionary = true; + } catch (IOException e) { + throw new IOException("could not decode the dictionary for " + descriptor, e); + } + } else { + this.dictionary = null; + this.useDictionary = false; + } + this.totalValueCount = pageReader.getTotalValueCount(); + if (totalValueCount == 0) { + throw new IOException("totalValueCount == 0"); + } + } + + /** + * TODO: Hoist the useDictionary branch to decode*Batch and make the batch page aligned. + */ + public boolean nextBoolean() { + if (!useDictionary) { + return dataColumn.readBoolean(); + } else { + return dictionary.decodeToBoolean(dataColumn.readValueDictionaryId()); + } + } + + public int nextInt() { + if (!useDictionary) { + return dataColumn.readInteger(); + } else { + return dictionary.decodeToInt(dataColumn.readValueDictionaryId()); + } + } + + public long nextLong() { + if (!useDictionary) { + return dataColumn.readLong(); + } else { + return dictionary.decodeToLong(dataColumn.readValueDictionaryId()); + } + } + + public float nextFloat() { + if (!useDictionary) { + return dataColumn.readFloat(); + } else { + return dictionary.decodeToFloat(dataColumn.readValueDictionaryId()); + } + } + + public double nextDouble() { + if (!useDictionary) { + return dataColumn.readDouble(); + } else { + return dictionary.decodeToDouble(dataColumn.readValueDictionaryId()); + } + } + + public Binary nextBinary() { + if (!useDictionary) { + return dataColumn.readBytes(); + } else { + return dictionary.decodeToBinary(dataColumn.readValueDictionaryId()); + } + } + + /** + * Advances to the next value. Returns true if the value is non-null. + */ + private boolean next() throws IOException { + if (valuesRead >= endOfPageValueCount) { + if (valuesRead >= totalValueCount) { + // How do we get here? Throw end of stream exception? + return false; + } + readPage(); + } + ++valuesRead; + // TODO: Don't read for flat schemas + //repetitionLevel = repetitionLevelColumn.nextInt(); + return definitionLevelColumn.nextInt() == maxDefLevel; + } + + private void readPage() throws IOException { + DataPage page = pageReader.readPage(); + // TODO: Why is this a visitor? 
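+      // Note (descriptive comment, not behaviour): a DataPage is either a DataPageV1 or a
+      // DataPageV2, the two Parquet data page formats. The visitor below simply dispatches to
+      // readPageV1/readPageV2, which differ only in how the repetition/definition level readers
+      // are constructed before initDataReader() is called.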
+ page.accept(new DataPage.Visitor() { + @Override + public Void visit(DataPageV1 dataPageV1) { + try { + readPageV1(dataPageV1); + return null; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public Void visit(DataPageV2 dataPageV2) { + try { + readPageV2(dataPageV2); + return null; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + }); + } + + private void initDataReader(Encoding dataEncoding, byte[] bytes, int offset, int valueCount) + throws IOException { + this.pageValueCount = valueCount; + this.endOfPageValueCount = valuesRead + pageValueCount; + if (dataEncoding.usesDictionary()) { + if (dictionary == null) { + throw new IOException( + "could not read page in col " + descriptor + + " as the dictionary was missing for encoding " + dataEncoding); + } + this.dataColumn = dataEncoding.getDictionaryBasedValuesReader( + descriptor, VALUES, dictionary); + this.useDictionary = true; + } else { + this.dataColumn = dataEncoding.getValuesReader(descriptor, VALUES); + this.useDictionary = false; + } + + try { + dataColumn.initFromPage(pageValueCount, bytes, offset); + } catch (IOException e) { + throw new IOException("could not read page in col " + descriptor, e); + } + } + + private void readPageV1(DataPageV1 page) throws IOException { + ValuesReader rlReader = page.getRlEncoding().getValuesReader(descriptor, REPETITION_LEVEL); + ValuesReader dlReader = page.getDlEncoding().getValuesReader(descriptor, DEFINITION_LEVEL); + this.repetitionLevelColumn = new ValuesReaderIntIterator(rlReader); + this.definitionLevelColumn = new ValuesReaderIntIterator(dlReader); + try { + byte[] bytes = page.getBytes().toByteArray(); + rlReader.initFromPage(pageValueCount, bytes, 0); + int next = rlReader.getNextOffset(); + dlReader.initFromPage(pageValueCount, bytes, next); + next = dlReader.getNextOffset(); + initDataReader(page.getValueEncoding(), bytes, next, page.getValueCount()); + } catch (IOException e) { + throw new IOException("could not read page " + page + " in col " + descriptor, e); + } + } + + private void readPageV2(DataPageV2 page) throws IOException { + this.repetitionLevelColumn = createRLEIterator(descriptor.getMaxRepetitionLevel(), + page.getRepetitionLevels(), descriptor); + this.definitionLevelColumn = createRLEIterator(descriptor.getMaxDefinitionLevel(), + page.getDefinitionLevels(), descriptor); + try { + initDataReader(page.getDataEncoding(), page.getData().toByteArray(), 0, + page.getValueCount()); + } catch (IOException e) { + throw new IOException("could not read page " + page + " in col " + descriptor, e); + } + } + } + + private void checkEndOfRowGroup() throws IOException { + if (rowsReturned != totalCountLoadedSoFar) return; + PageReadStore pages = reader.readNextRowGroup(); + if (pages == null) { + throw new IOException("expecting more rows but reached last block. 
Read " + + rowsReturned + " out of " + totalRowCount); + } + List columns = requestedSchema.getColumns(); + columnReaders = new ColumnReader[columns.size()]; + for (int i = 0; i < columns.size(); ++i) { + columnReaders[i] = new ColumnReader(columns.get(i), pages.getPageReader(columns.get(i))); + } + totalCountLoadedSoFar += pages.getRowCount(); + } +} diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory b/sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory new file mode 100644 index 000000000000..507100be9096 --- /dev/null +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory @@ -0,0 +1 @@ +org.apache.spark.sql.execution.ui.SQLHistoryListenerFactory diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 000000000000..1ca2044057e5 --- /dev/null +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1,4 @@ +org.apache.spark.sql.execution.datasources.jdbc.DefaultSource +org.apache.spark.sql.execution.datasources.json.DefaultSource +org.apache.spark.sql.execution.datasources.parquet.DefaultSource +org.apache.spark.sql.execution.datasources.text.DefaultSource diff --git a/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css b/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css new file mode 100644 index 000000000000..ddd3a91dd8ef --- /dev/null +++ b/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#plan-viz-graph .label { + font-weight: normal; + text-shadow: none; +} + +#plan-viz-graph svg g.node rect { + fill: #C3EBFF; + stroke: #3EC0FF; + stroke-width: 1px; +} + +/* Hightlight the SparkPlan node name */ +#plan-viz-graph svg text :first-child { + font-weight: bold; +} + +#plan-viz-graph svg path { + stroke: #444; + stroke-width: 1.5px; +} diff --git a/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.js b/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.js new file mode 100644 index 000000000000..5161fcde669e --- /dev/null +++ b/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.js @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +var PlanVizConstants = { + svgMarginX: 16, + svgMarginY: 16 +}; + +function renderPlanViz() { + var svg = planVizContainer().append("svg"); + var metadata = d3.select("#plan-viz-metadata"); + var dot = metadata.select(".dot-file").text().trim(); + var graph = svg.append("g"); + + var g = graphlibDot.read(dot); + preprocessGraphLayout(g); + var renderer = new dagreD3.render(); + renderer(graph, g); + + // Round corners on rectangles + svg + .selectAll("rect") + .attr("rx", "5") + .attr("ry", "5"); + + var nodeSize = parseInt($("#plan-viz-metadata-size").text()); + for (var i = 0; i < nodeSize; i++) { + setupTooltipForSparkPlanNode(i); + } + + resizeSvg(svg) +} + +/* -------------------- * + * | Helper functions | * + * -------------------- */ + +function planVizContainer() { return d3.select("#plan-viz-graph"); } + +/* + * Set up the tooltip for a SparkPlan node using metadata. When the user moves the mouse on the + * node, it will display the details of this SparkPlan node in the right. + */ +function setupTooltipForSparkPlanNode(nodeId) { + var nodeTooltip = d3.select("#plan-meta-data-" + nodeId).text() + d3.select("svg g .node_" + nodeId) + .on('mouseover', function(d) { + var domNode = d3.select(this).node(); + $(domNode).tooltip({ + title: nodeTooltip, trigger: "manual", container: "body", placement: "right" + }); + $(domNode).tooltip("show"); + }) + .on('mouseout', function(d) { + var domNode = d3.select(this).node(); + $(domNode).tooltip("destroy"); + }) +} + +/* + * Helper function to pre-process the graph layout. + * This step is necessary for certain styles that affect the positioning + * and sizes of graph elements, e.g. padding, font style, shape. + */ +function preprocessGraphLayout(g) { + var nodes = g.nodes(); + for (var i = 0; i < nodes.length; i++) { + var node = g.node(nodes[i]); + node.padding = "5"; + } + // Curve the edges + var edges = g.edges(); + for (var j = 0; j < edges.length; j++) { + var edge = g.edge(edges[j]); + edge.lineInterpolate = "basis"; + } +} + +/* + * Helper function to size the SVG appropriately such that all elements are displayed. + * This assumes that all outermost elements are clusters (rectangles). 
+ */ +function resizeSvg(svg) { + var allClusters = svg.selectAll("g rect")[0]; + console.log(allClusters); + var startX = -PlanVizConstants.svgMarginX + + toFloat(d3.min(allClusters, function(e) { + console.log(e); + return getAbsolutePosition(d3.select(e)).x; + })); + var startY = -PlanVizConstants.svgMarginY + + toFloat(d3.min(allClusters, function(e) { + return getAbsolutePosition(d3.select(e)).y; + })); + var endX = PlanVizConstants.svgMarginX + + toFloat(d3.max(allClusters, function(e) { + var t = d3.select(e); + return getAbsolutePosition(t).x + toFloat(t.attr("width")); + })); + var endY = PlanVizConstants.svgMarginY + + toFloat(d3.max(allClusters, function(e) { + var t = d3.select(e); + return getAbsolutePosition(t).y + toFloat(t.attr("height")); + })); + var width = endX - startX; + var height = endY - startY; + svg.attr("viewBox", startX + " " + startY + " " + width + " " + height) + .attr("width", width) + .attr("height", height); +} + +/* Helper function to convert attributes to numeric values. */ +function toFloat(f) { + if (f) { + return parseFloat(f.toString().replace(/px$/, "")); + } else { + return f; + } +} + +/* + * Helper function to compute the absolute position of the specified element in our graph. + */ +function getAbsolutePosition(d3selection) { + if (d3selection.empty()) { + throw "Attempted to get absolute position of an empty selection."; + } + var obj = d3selection; + var _x = toFloat(obj.attr("x")) || 0; + var _y = toFloat(obj.attr("y")) || 0; + while (!obj.empty()) { + var transformText = obj.attr("transform"); + if (transformText) { + var translate = d3.transform(transformText).translate; + _x += toFloat(translate[0]); + _y += toFloat(translate[1]); + } + // Climb upwards to find how our parents are translated + obj = d3.select(obj.node().parentNode); + // Stop when we've reached the graph container itself + if (obj.node() == planVizContainer().node()) { + break; + } + } + return { x: _x, y: _y }; +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index b25dcbca82b9..297ef2299cb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -17,13 +17,17 @@ package org.apache.spark.sql +import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression + import scala.language.implicitConversions import org.apache.spark.annotation.Experimental import org.apache.spark.Logging import org.apache.spark.sql.functions.lit -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.DataTypeParser import org.apache.spark.sql.types._ @@ -36,15 +40,64 @@ private[sql] object Column { def unapply(col: Column): Option[Expression] = Some(col.expr) } +/** + * A [[Column]] where an [[Encoder]] has been given for the expected input and return type. + * To create a [[TypedColumn]], use the `as` function on a [[Column]]. + * + * @tparam T The input type expected for this expression. Can be `Any` if the expression is type + * checked by the analyzer instead of the compiler (i.e. `expr("sum(...)")`). + * @tparam U The output type of this column. 
+ * + * @since 1.6.0 + */ +class TypedColumn[-T, U]( + expr: Expression, + private[sql] val encoder: ExpressionEncoder[U]) + extends Column(expr) { + + /** + * Inserts the specific input type and schema into any expressions that are expected to operate + * on a decoded object. + */ + private[sql] def withInputType( + inputEncoder: ExpressionEncoder[_], + schema: Seq[Attribute]): TypedColumn[T, U] = { + val boundEncoder = inputEncoder.bind(schema).asInstanceOf[ExpressionEncoder[Any]] + new TypedColumn[T, U]( + expr transform { case ta: TypedAggregateExpression if ta.aEncoder.isEmpty => + ta.copy(aEncoder = Some(boundEncoder), children = schema) + }, + encoder) + } +} /** * :: Experimental :: - * A column in a [[DataFrame]]. + * A column that will be computed based on the data in a [[DataFrame]]. + * + * A new column is constructed based on the input columns present in a dataframe: + * + * {{{ + * df("columnName") // On a specific DataFrame. + * col("columnName") // A generic column no yet associcated with a DataFrame. + * col("columnName.field") // Extracting a struct field + * col("`a.column.with.dots`") // Escape `.` in column names. + * $"columnName" // Scala short hand for a named column. + * expr("a + 1") // A column that is constructed from a parsed SQL Expression. + * lit("abc") // A column that produces a literal (constant) value. + * }}} + * + * [[Column]] objects can be composed to form complex expressions: * - * @groupname java_expr_ops Java-specific expression operators. - * @groupname expr_ops Expression operators. - * @groupname df_ops DataFrame functions. - * @groupname Ungrouped Support functions for DataFrames. + * {{{ + * $"a" + 1 + * $"a" === $"b" + * }}} + * + * @groupname java_expr_ops Java-specific expression operators + * @groupname expr_ops Expression operators + * @groupname df_ops DataFrame functions + * @groupname Ungrouped Support functions for DataFrames * * @since 1.3.0 */ @@ -53,12 +106,34 @@ class Column(protected[sql] val expr: Expression) extends Logging { def this(name: String) = this(name match { case "*" => UnresolvedStar(None) - case _ if name.endsWith(".*") => UnresolvedStar(Some(name.substring(0, name.length - 2))) - case _ => UnresolvedAttribute(name) + case _ if name.endsWith(".*") => { + val parts = UnresolvedAttribute.parseAttributeName(name.substring(0, name.length - 2)) + UnresolvedStar(Some(parts)) + } + case _ => UnresolvedAttribute.quotedString(name) }) /** Creates a column based on the given expression. */ - implicit private def exprToColumn(newExpr: Expression): Column = new Column(newExpr) + private def withExpr(newExpr: Expression): Column = new Column(newExpr) + + /** + * Returns the expression for this column either with an existing or auto assigned name. + */ + private[sql] def named: NamedExpression = expr match { + // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we + // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to + // make it a NamedExpression. + case u: UnresolvedAttribute => UnresolvedAlias(u) + + case expr: NamedExpression => expr + + // Leave an unaliased generator with an empty list of names since the analyzer will generate + // the correct defaults after the nested expression's type has been resolved. 
+ case explode: Explode => MultiAlias(explode, Nil) + case jt: JsonTuple => MultiAlias(jt, Nil) + + case expr: Expression => Alias(expr, expr.prettyString)() + } override def toString: String = expr.prettyString @@ -69,19 +144,30 @@ class Column(protected[sql] val expr: Expression) extends Logging { override def hashCode: Int = this.expr.hashCode + /** + * Provides a type hint about the expected return value of this column. This information can + * be used by operations such as `select` on a [[Dataset]] to automatically convert the + * results into the correct JVM types. + * @since 1.6.0 + */ + def as[U : Encoder]: TypedColumn[Any, U] = new TypedColumn[Any, U](expr, encoderFor[U]) + /** * Extracts a value or values from a complex type. * The following types of extraction are supported: - * - Given an Array, an integer ordinal can be used to retrieve a single value. - * - Given a Map, a key of the correct type can be used to retrieve an individual value. - * - Given a Struct, a string fieldName can be used to extract that field. - * - Given an Array of Structs, a string fieldName can be used to extract filed - * of every struct in that array, and return an Array of fields + * + * - Given an Array, an integer ordinal can be used to retrieve a single value. + * - Given a Map, a key of the correct type can be used to retrieve an individual value. + * - Given a Struct, a string fieldName can be used to extract that field. + * - Given an Array of Structs, a string fieldName can be used to extract filed + * of every struct in that array, and return an Array of fields * * @group expr_ops * @since 1.4.0 */ - def apply(extraction: Any): Column = UnresolvedExtractValue(expr, lit(extraction).expr) + def apply(extraction: Any): Column = withExpr { + UnresolvedExtractValue(expr, lit(extraction).expr) + } /** * Unary minus, i.e. negate the expression. @@ -97,7 +183,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def unary_- : Column = UnaryMinus(expr) + def unary_- : Column = withExpr { UnaryMinus(expr) } /** * Inversion of boolean expression, i.e. NOT. @@ -113,7 +199,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def unary_! : Column = Not(expr) + def unary_! : Column = withExpr { Not(expr) } /** * Equality test. @@ -129,7 +215,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def === (other: Any): Column = { + def === (other: Any): Column = withExpr { val right = lit(other).expr if (this.expr == right) { logWarning( @@ -170,7 +256,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def !== (other: Any): Column = Not(EqualTo(expr, lit(other).expr)) + def !== (other: Any): Column = withExpr{ Not(EqualTo(expr, lit(other).expr)) } /** * Inequality test. @@ -187,7 +273,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group java_expr_ops * @since 1.3.0 */ - def notEqual(other: Any): Column = Not(EqualTo(expr, lit(other).expr)) + def notEqual(other: Any): Column = withExpr { Not(EqualTo(expr, lit(other).expr)) } /** * Greater than. @@ -203,7 +289,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def > (other: Any): Column = GreaterThan(expr, lit(other).expr) + def > (other: Any): Column = withExpr { GreaterThan(expr, lit(other).expr) } /** * Greater than. 
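As a usage sketch of the `as[U : Encoder]` type hint and the (behaviour-preserving) `withExpr` refactoring above — not part of the patch; the `sqlContext`, data and column names are illustrative assumptions:

 {{{
 import sqlContext.implicits._

 val df = Seq(("Alice", 29), ("Bob", 31)).toDF("name", "age")

 // Plain Column expressions; user-facing behaviour is unchanged by withExpr.
 df.filter($"age" > 21 && !($"name" === "Bob"))

 // `as[U : Encoder]` turns a Column into a TypedColumn, so a Dataset-aware
 // select can map results back to JVM objects.
 val ds = Seq(("Alice", 29), ("Bob", 31)).toDS()
 val names = ds.select($"_1".as[String])   // Dataset[String]
 }}}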
@@ -234,7 +320,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def < (other: Any): Column = LessThan(expr, lit(other).expr) + def < (other: Any): Column = withExpr { LessThan(expr, lit(other).expr) } /** * Less than. @@ -264,7 +350,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def <= (other: Any): Column = LessThanOrEqual(expr, lit(other).expr) + def <= (other: Any): Column = withExpr { LessThanOrEqual(expr, lit(other).expr) } /** * Less than or equal to. @@ -294,7 +380,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def >= (other: Any): Column = GreaterThanOrEqual(expr, lit(other).expr) + def >= (other: Any): Column = withExpr { GreaterThanOrEqual(expr, lit(other).expr) } /** * Greater than or equal to an expression. @@ -317,7 +403,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def <=> (other: Any): Column = EqualNullSafe(expr, lit(other).expr) + def <=> (other: Any): Column = withExpr { EqualNullSafe(expr, lit(other).expr) } /** * Equality test that is safe for null values. @@ -350,7 +436,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { */ def when(condition: Column, value: Any): Column = this.expr match { case CaseWhen(branches: Seq[Expression]) => - CaseWhen(branches ++ Seq(lit(condition).expr, lit(value).expr)) + withExpr { CaseWhen(branches ++ Seq(lit(condition).expr, lit(value).expr)) } case _ => throw new IllegalArgumentException( "when() can only be applied on a Column previously generated by when() function") @@ -380,7 +466,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { def otherwise(value: Any): Column = this.expr match { case CaseWhen(branches: Seq[Expression]) => if (branches.size % 2 == 0) { - CaseWhen(branches :+ lit(value).expr) + withExpr { CaseWhen(branches :+ lit(value).expr) } } else { throw new IllegalArgumentException( "otherwise() can only be applied once on a Column previously generated by when()") @@ -406,7 +492,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.5.0 */ - def isNaN: Column = IsNaN(expr) + def isNaN: Column = withExpr { IsNaN(expr) } /** * True if the current expression is null. @@ -414,7 +500,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def isNull: Column = IsNull(expr) + def isNull: Column = withExpr { IsNull(expr) } /** * True if the current expression is NOT null. @@ -422,7 +508,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def isNotNull: Column = IsNotNull(expr) + def isNotNull: Column = withExpr { IsNotNull(expr) } /** * Boolean OR. @@ -437,7 +523,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def || (other: Any): Column = Or(expr, lit(other).expr) + def || (other: Any): Column = withExpr { Or(expr, lit(other).expr) } /** * Boolean OR. @@ -467,7 +553,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def && (other: Any): Column = And(expr, lit(other).expr) + def && (other: Any): Column = withExpr { And(expr, lit(other).expr) } /** * Boolean AND. 
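For reference, a small sketch contrasting `===` with the null-safe `<=>` operator and showing `when`/`otherwise` as wrapped above — illustrative only, assuming `df` has nullable columns `a` and `b` and a numeric `age`:

 {{{
 import org.apache.spark.sql.functions._

 // Standard equality follows SQL three-valued logic: NULL === NULL yields NULL,
 // so those rows are dropped by filter().
 df.filter($"a" === $"b")

 // Null-safe equality treats two NULLs as equal and never returns NULL itself.
 df.filter($"a" <=> $"b")

 // Conditional expression built with when/otherwise.
 df.select(when($"age" >= 18, "adult").otherwise("minor").as("group"))
 }}}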
@@ -497,7 +583,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def + (other: Any): Column = Add(expr, lit(other).expr) + def + (other: Any): Column = withExpr { Add(expr, lit(other).expr) } /** * Sum of this expression and another expression. @@ -527,7 +613,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def - (other: Any): Column = Subtract(expr, lit(other).expr) + def - (other: Any): Column = withExpr { Subtract(expr, lit(other).expr) } /** * Subtraction. Subtract the other expression from this expression. @@ -557,7 +643,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def * (other: Any): Column = Multiply(expr, lit(other).expr) + def * (other: Any): Column = withExpr { Multiply(expr, lit(other).expr) } /** * Multiplication of this expression and another expression. @@ -587,7 +673,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def / (other: Any): Column = Divide(expr, lit(other).expr) + def / (other: Any): Column = withExpr { Divide(expr, lit(other).expr) } /** * Division this expression by another expression. @@ -610,7 +696,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def % (other: Any): Column = Remainder(expr, lit(other).expr) + def % (other: Any): Column = withExpr { Remainder(expr, lit(other).expr) } /** * Modulo (a.k.a. remainder) expression. @@ -626,9 +712,21 @@ class Column(protected[sql] val expr: Expression) extends Logging { * * @group expr_ops * @since 1.3.0 + * @deprecated As of 1.5.0. Use isin. This will be removed in Spark 2.0. + */ + @deprecated("use isin. This will be removed in Spark 2.0.", "1.5.0") + @scala.annotation.varargs + def in(list: Any*): Column = isin(list : _*) + + /** + * A boolean expression that is evaluated to true if the value of this expression is contained + * by the evaluated values of the arguments. + * + * @group expr_ops + * @since 1.5.0 */ @scala.annotation.varargs - def in(list: Any*): Column = In(expr, list.map(lit(_).expr)) + def isin(list: Any*): Column = withExpr { In(expr, list.map(lit(_).expr)) } /** * SQL like expression. @@ -636,7 +734,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def like(literal: String): Column = Like(expr, lit(literal).expr) + def like(literal: String): Column = withExpr { Like(expr, lit(literal).expr) } /** * SQL RLIKE expression (LIKE with Regex). @@ -644,7 +742,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def rlike(literal: String): Column = RLike(expr, lit(literal).expr) + def rlike(literal: String): Column = withExpr { RLike(expr, lit(literal).expr) } /** * An expression that gets an item at position `ordinal` out of an array, @@ -653,7 +751,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def getItem(key: Any): Column = UnresolvedExtractValue(expr, Literal(key)) + def getItem(key: Any): Column = withExpr { UnresolvedExtractValue(expr, Literal(key)) } /** * An expression that gets a field by name in a [[StructType]]. 
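Since `in` is deprecated in favour of `isin` in the hunk above, a short before/after sketch (column and value names are made up for illustration):

 {{{
 // Before (deprecated as of 1.5.0, removed in Spark 2.0):
 df.filter($"status".in("ACTIVE", "PENDING"))

 // After:
 df.filter($"status".isin("ACTIVE", "PENDING"))

 // The pattern-matching helpers touched above:
 df.filter($"name".like("Al%"))
 df.filter($"name".rlike("^Al.*"))
 }}}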
@@ -661,7 +759,9 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def getField(fieldName: String): Column = UnresolvedExtractValue(expr, Literal(fieldName)) + def getField(fieldName: String): Column = withExpr { + UnresolvedExtractValue(expr, Literal(fieldName)) + } /** * An expression that returns a substring. @@ -671,7 +771,9 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def substr(startPos: Column, len: Column): Column = Substring(expr, startPos.expr, len.expr) + def substr(startPos: Column, len: Column): Column = withExpr { + Substring(expr, startPos.expr, len.expr) + } /** * An expression that returns a substring. @@ -681,7 +783,9 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def substr(startPos: Int, len: Int): Column = Substring(expr, lit(startPos).expr, lit(len).expr) + def substr(startPos: Int, len: Int): Column = withExpr { + Substring(expr, lit(startPos).expr, lit(len).expr) + } /** * Contains the other element. @@ -689,7 +793,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def contains(other: Any): Column = Contains(expr, lit(other).expr) + def contains(other: Any): Column = withExpr { Contains(expr, lit(other).expr) } /** * String starts with. @@ -697,7 +801,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def startsWith(other: Column): Column = StartsWith(expr, lit(other).expr) + def startsWith(other: Column): Column = withExpr { StartsWith(expr, lit(other).expr) } /** * String starts with another string literal. @@ -713,7 +817,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def endsWith(other: Column): Column = EndsWith(expr, lit(other).expr) + def endsWith(other: Column): Column = withExpr { EndsWith(expr, lit(other).expr) } /** * String ends with another string literal. @@ -742,10 +846,18 @@ class Column(protected[sql] val expr: Expression) extends Logging { * df.select($"colA".as("colB")) * }}} * + * If the current column has metadata associated with it, this metadata will be propagated + * to the new column. If this not desired, use `as` with explicitly empty metadata. + * * @group expr_ops * @since 1.3.0 */ - def as(alias: String): Column = Alias(expr, alias)() + def as(alias: String): Column = withExpr { + expr match { + case ne: NamedExpression => Alias(expr, alias)(explicitMetadata = Some(ne.metadata)) + case other => Alias(other, alias)() + } + } /** * (Scala-specific) Assigns the given aliases to the results of a table generating function. @@ -757,7 +869,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def as(aliases: Seq[String]): Column = MultiAlias(expr, aliases) + def as(aliases: Seq[String]): Column = withExpr { MultiAlias(expr, aliases) } /** * Assigns the given aliases to the results of a table generating function. @@ -769,7 +881,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def as(aliases: Array[String]): Column = MultiAlias(expr, aliases) + def as(aliases: Array[String]): Column = withExpr { MultiAlias(expr, aliases) } /** * Gives the column an alias. 
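The `as(alias)` change above propagates any existing column metadata to the new alias. A hedged sketch of the intended effect (the metadata key and the `age` column are assumptions):

 {{{
 import org.apache.spark.sql.types.{Metadata, MetadataBuilder}

 val md = new MetadataBuilder().putString("comment", "age in years").build()
 val tagged = df.select($"age".as("age", md))   // attach metadata via as(alias, metadata)

 // With this change, a later rename keeps the metadata...
 tagged.select($"age".as("years"))

 // ...unless it is explicitly overridden with empty metadata.
 tagged.select($"age".as("years", Metadata.empty))
 }}}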
@@ -778,10 +890,18 @@ class Column(protected[sql] val expr: Expression) extends Logging { * df.select($"colA".as('colB)) * }}} * + * If the current column has metadata associated with it, this metadata will be propagated + * to the new column. If this not desired, use `as` with explicitly empty metadata. + * * @group expr_ops * @since 1.3.0 */ - def as(alias: Symbol): Column = Alias(expr, alias.name)() + def as(alias: Symbol): Column = withExpr { + expr match { + case ne: NamedExpression => Alias(expr, alias.name)(explicitMetadata = Some(ne.metadata)) + case other => Alias(other, alias.name)() + } + } /** * Gives the column an alias with metadata. @@ -793,7 +913,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def as(alias: String, metadata: Metadata): Column = { + def as(alias: String, metadata: Metadata): Column = withExpr { Alias(expr, alias)(explicitMetadata = Some(metadata)) } @@ -811,10 +931,12 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def cast(to: DataType): Column = expr match { - // Lift alias out of cast so we can support col.as("name").cast(IntegerType) - case Alias(childExpr, name) => Alias(Cast(childExpr, to), name)() - case _ => Cast(expr, to) + def cast(to: DataType): Column = withExpr { + expr match { + // keeps the name of expression if possible when do cast. + case ne: NamedExpression => UnresolvedAlias(Cast(expr, to)) + case _ => Cast(expr, to) + } } /** @@ -844,7 +966,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def desc: Column = SortOrder(expr, Descending) + def desc: Column = withExpr { SortOrder(expr, Descending) } /** * Returns an ordering used in sorting. @@ -859,7 +981,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def asc: Column = SortOrder(expr, Ascending) + def asc: Column = withExpr { SortOrder(expr, Ascending) } /** * Prints the expression to the console for debugging purpose. @@ -886,7 +1008,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def bitwiseOR(other: Any): Column = BitwiseOr(expr, lit(other).expr) + def bitwiseOR(other: Any): Column = withExpr { BitwiseOr(expr, lit(other).expr) } /** * Compute bitwise AND of this expression with another expression. @@ -897,7 +1019,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def bitwiseAND(other: Any): Column = BitwiseAnd(expr, lit(other).expr) + def bitwiseAND(other: Any): Column = withExpr { BitwiseAnd(expr, lit(other).expr) } /** * Compute bitwise XOR of this expression with another expression. @@ -908,7 +1030,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def bitwiseXOR(other: Any): Column = BitwiseXor(expr, lit(other).expr) + def bitwiseXOR(other: Any): Column = withExpr { BitwiseXor(expr, lit(other).expr) } /** * Define a windowing column. 
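The `cast` change above wraps named expressions in `UnresolvedAlias` so the analyzer can keep the original column name. A sketch of the intended effect (exact auto-generated names depend on the analyzer; `age` is an assumed column):

 {{{
 import org.apache.spark.sql.types.IntegerType

 // Previously the result column surfaced under a generated name such as "cast(age as int)";
 // with this change the original name "age" is preserved where possible.
 df.select($"age".cast(IntegerType))
 df.select($"age".cast("int"))                    // equivalent, string form

 // Explicit aliasing still behaves as before.
 df.select($"age".cast(IntegerType).as("age_int"))
 }}}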
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 3ea0f9ed3bdd..d74131231499 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -20,28 +20,27 @@ package org.apache.spark.sql import java.io.CharArrayWriter import java.util.Properties -import org.apache.spark.unsafe.types.UTF8String - import scala.language.implicitConversions import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag -import scala.util.control.NonFatal import com.fasterxml.jackson.core.JsonFactory import org.apache.commons.lang3.StringUtils import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD +import org.apache.spark.api.python.PythonRDD import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{Filter, _} -import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} -import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD} +import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, QueryExecution, Queryable, SQLExecution} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} -import org.apache.spark.sql.json.{JacksonGenerator, JSONRelation} +import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.sources.HadoopFsRelation import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel @@ -54,7 +53,6 @@ private[sql] object DataFrame { } } - /** * :: Experimental :: * A distributed collection of data organized into named columns. @@ -113,11 +111,14 @@ private[sql] object DataFrame { * @groupname action Actions * @since 1.3.0 */ -// TODO: Improve documentation. @Experimental class DataFrame private[sql]( - @transient val sqlContext: SQLContext, - @DeveloperApi @transient val queryExecution: SQLContext#QueryExecution) extends Serializable { + @transient override val sqlContext: SQLContext, + @DeveloperApi @transient override val queryExecution: QueryExecution) + extends Queryable with Serializable { + + // Note for Spark contributors: if adding or updating any action in `DataFrame`, please make sure + // you wrap it with `withNewExecutionId` if this actions doesn't call other action. /** * A constructor that automatically analyzes the logical plan. @@ -146,14 +147,6 @@ class DataFrame private[sql]( queryExecution.analyzed } - /** - * An implicit conversion function internal to this class for us to avoid doing - * "new DataFrame(...)" everywhere. 
- */ - @inline private implicit def logicalPlanToDataFrame(logicalPlan: LogicalPlan): DataFrame = { - new DataFrame(sqlContext, logicalPlan) - } - protected[sql] def resolve(colName: String): NamedExpression = { queryExecution.analyzed.resolveQuoted(colName, sqlContext.analyzer.resolver).getOrElse { throw new AnalysisException( @@ -168,17 +161,15 @@ class DataFrame private[sql]( } /** - * Internal API for Python + * Compose the string representing rows for output * @param _numRows Number of rows to show * @param truncate Whether truncate long strings and align cells right */ - private[sql] def showString(_numRows: Int, truncate: Boolean = true): String = { + override private[sql] def showString(_numRows: Int, truncate: Boolean = true): String = { val numRows = _numRows.max(0) - val sb = new StringBuilder val takeResult = take(numRows + 1) val hasMoreData = takeResult.length > numRows val data = takeResult.take(numRows) - val numCols = schema.fieldNames.length // For array values, replace Seq and Array with square brackets // For cells that are beyond 20 characters, replace it with the first 17 and "..." @@ -186,6 +177,7 @@ class DataFrame private[sql]( row.toSeq.map { cell => val str = cell match { case null => "null" + case binary: Array[Byte] => binary.map("%02X".format(_)).mkString("[", " ", "]") case array: Array[_] => array.mkString("[", ", ", "]") case seq: Seq[_] => seq.mkString("[", ", ", "]") case _ => cell.toString @@ -194,59 +186,7 @@ class DataFrame private[sql]( }: Seq[String] } - // Initialise the width of each column to a minimum value of '3' - val colWidths = Array.fill(numCols)(3) - - // Compute the width of each column - for (row <- rows) { - for ((cell, i) <- row.zipWithIndex) { - colWidths(i) = math.max(colWidths(i), cell.length) - } - } - - // Create SeparateLine - val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString() - - // column names - rows.head.zipWithIndex.map { case (cell, i) => - if (truncate) { - StringUtils.leftPad(cell, colWidths(i)) - } else { - StringUtils.rightPad(cell, colWidths(i)) - } - }.addString(sb, "|", "|", "|\n") - - sb.append(sep) - - // data - rows.tail.map { - _.zipWithIndex.map { case (cell, i) => - if (truncate) { - StringUtils.leftPad(cell.toString, colWidths(i)) - } else { - StringUtils.rightPad(cell.toString, colWidths(i)) - } - }.addString(sb, "|", "|", "|\n") - } - - sb.append(sep) - - // For Data that has more than "numRows" records - if (hasMoreData) { - val rowsString = if (numRows == 1) "row" else "rows" - sb.append(s"only showing top $numRows ${rowsString}\n") - } - - sb.toString() - } - - override def toString: String = { - try { - schema.map(f => s"${f.name}: ${f.dataType.simpleString}").mkString("[", ", ", "]") - } catch { - case NonFatal(e) => - s"Invalid tree; ${e.getMessage}:\n$queryExecution" - } + formatString ( rows, numRows, hasMoreData, truncate ) } /** @@ -258,6 +198,16 @@ class DataFrame private[sql]( // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. def toDF(): DataFrame = this + /** + * :: Experimental :: + * Converts this [[DataFrame]] to a strongly-typed [[Dataset]] containing objects of the + * specified type, `U`. + * @group basic + * @since 1.6.0 + */ + @Experimental + def as[U : Encoder]: Dataset[U] = new Dataset[U](sqlContext, logicalPlan) + /** * Returns a new [[DataFrame]] with columns renamed. This can be quite convenient in conversion * from a RDD of tuples into a [[DataFrame]] with meaningful names. 
For example: @@ -290,51 +240,49 @@ class DataFrame private[sql]( def schema: StructType = queryExecution.analyzed.schema /** - * Returns all column names and their data types as an array. + * Prints the schema to the console in a nice tree format. * @group basic * @since 1.3.0 */ - def dtypes: Array[(String, String)] = schema.fields.map { field => - (field.name, field.dataType.toString) - } + // scalastyle:off println + override def printSchema(): Unit = println(schema.treeString) + // scalastyle:on println /** - * Returns all column names as an array. + * Prints the plans (logical and physical) to the console for debugging purposes. * @group basic * @since 1.3.0 */ - def columns: Array[String] = schema.fields.map(_.name) + override def explain(extended: Boolean): Unit = { + val explain = ExplainCommand(queryExecution.logical, extended = extended) + sqlContext.executePlan(explain).executedPlan.executeCollect().foreach { + // scalastyle:off println + r => println(r.getString(0)) + // scalastyle:on println + } + } /** - * Prints the schema to the console in a nice tree format. - * @group basic + * Prints the physical plan to the console for debugging purposes. * @since 1.3.0 */ - // scalastyle:off println - def printSchema(): Unit = println(schema.treeString) - // scalastyle:on println + override def explain(): Unit = explain(extended = false) /** - * Prints the plans (logical and physical) to the console for debugging purposes. + * Returns all column names and their data types as an array. * @group basic * @since 1.3.0 */ - def explain(extended: Boolean): Unit = { - ExplainCommand( - queryExecution.logical, - extended = extended).queryExecution.executedPlan.executeCollect().map { - // scalastyle:off println - r => println(r.getString(0)) - // scalastyle:on println - } + def dtypes: Array[(String, String)] = schema.fields.map { field => + (field.name, field.dataType.toString) } /** - * Only prints the physical plan to the console for debugging purposes. + * Returns all column names as an array. * @group basic * @since 1.3.0 */ - def explain(): Unit = explain(extended = false) + def columns: Array[String] = schema.fields.map(_.name) /** * Returns true if the `collect` and `take` methods can be run locally @@ -360,7 +308,7 @@ class DataFrame private[sql]( * @group action * @since 1.3.0 */ - def show(numRows: Int): Unit = show(numRows, true) + def show(numRows: Int): Unit = show(numRows, truncate = true) /** * Displays the top 20 rows of [[DataFrame]] in a tabular form. Strings more than 20 characters @@ -435,7 +383,7 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.0 */ - def join(right: DataFrame): DataFrame = { + def join(right: DataFrame): DataFrame = withPlan { Join(logicalPlan, right.logicalPlan, joinType = Inner, None) } @@ -484,27 +432,66 @@ class DataFrame private[sql]( * @since 1.4.0 */ def join(right: DataFrame, usingColumns: Seq[String]): DataFrame = { + join(right, usingColumns, "inner") + } + + /** + * Equi-join with another [[DataFrame]] using the given columns. + * + * Different from other join functions, the join columns will only appear once in the output, + * i.e. similar to SQL's `JOIN USING` syntax. + * + * Note that if you perform a self-join using this function without aliasing the input + * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since + * there is no way to disambiguate which side of the join you would like to reference. + * + * @param right Right side of the join operation. 
+ * @param usingColumns Names of the columns to join on. This columns must exist on both sides. + * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`. + * @group dfops + * @since 1.6.0 + */ + def join(right: DataFrame, usingColumns: Seq[String], joinType: String): DataFrame = { // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right // by creating a new instance for one of the branch. val joined = sqlContext.executePlan( - Join(logicalPlan, right.logicalPlan, joinType = Inner, None)).analyzed.asInstanceOf[Join] + Join(logicalPlan, right.logicalPlan, JoinType(joinType), None)).analyzed.asInstanceOf[Join] - // Project only one of the join columns. - val joinedCols = usingColumns.map(col => joined.right.resolve(col)) val condition = usingColumns.map { col => - catalyst.expressions.EqualTo(joined.left.resolve(col), joined.right.resolve(col)) + catalyst.expressions.EqualTo( + withPlan(joined.left).resolve(col), + withPlan(joined.right).resolve(col)) }.reduceLeftOption[catalyst.expressions.BinaryExpression] { (cond, eqTo) => catalyst.expressions.And(cond, eqTo) } - Project( - joined.output.filterNot(joinedCols.contains(_)), - Join( - joined.left, - joined.right, - joinType = Inner, - condition) - ) + // Project only one of the join columns. + val joinedCols = JoinType(joinType) match { + case Inner | LeftOuter | LeftSemi => + usingColumns.map(col => withPlan(joined.left).resolve(col)) + case RightOuter => + usingColumns.map(col => withPlan(joined.right).resolve(col)) + case FullOuter => + usingColumns.map { col => + val leftCol = withPlan(joined.left).resolve(col) + val rightCol = withPlan(joined.right).resolve(col) + Alias(Coalesce(Seq(leftCol, rightCol)), col)() + } + } + // The nullability of output of joined could be different than original column, + // so we can only compare them by exprId + val joinRefs = condition.map(_.references.toSeq.map(_.exprId)).getOrElse(Nil) + val resultCols = joinedCols ++ joined.output.filterNot(e => joinRefs.contains(e.exprId)) + withPlan { + Project( + resultCols, + Join( + joined.left, + joined.right, + joinType = JoinType(joinType), + condition) + ) + } } /** @@ -551,19 +538,20 @@ class DataFrame private[sql]( // Trigger analysis so in the case of self-join, the analyzer will clone the plan. // After the cloning, left and right side will have distinct expression ids. - val plan = Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr)) + val plan = withPlan( + Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr))) .queryExecution.analyzed.asInstanceOf[Join] // If auto self join alias is disabled, return the plan. if (!sqlContext.conf.dataFrameSelfJoinAutoResolveAmbiguity) { - return plan + return withPlan(plan) } // If left/right have no output set intersection, return the plan. - val lanalyzed = this.logicalPlan.queryExecution.analyzed - val ranalyzed = right.logicalPlan.queryExecution.analyzed + val lanalyzed = withPlan(this.logicalPlan).queryExecution.analyzed + val ranalyzed = withPlan(right.logicalPlan).queryExecution.analyzed if (lanalyzed.outputSet.intersect(ranalyzed.outputSet).isEmpty) { - return plan + return withPlan(plan) } // Otherwise, find the trivially true predicates and automatically resolves them to both sides. 
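To make the new `join(right, usingColumns, joinType)` overload above concrete, a minimal sketch — not part of the patch; the DataFrames and columns are assumptions, and `import sqlContext.implicits._` is presumed in scope:

 {{{
 val customers = Seq((1, "Alice"), (3, "Carol")).toDF("customer_id", "name")
 val orders = Seq((1, "book"), (2, "pen")).toDF("customer_id", "item")

 // USING-style equi-join; "customer_id" appears only once in the output,
 // and the join type can now be specified alongside the using columns.
 val joined = customers.join(orders, Seq("customer_id"), "left_outer")
 }}}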
@@ -572,9 +560,40 @@ class DataFrame private[sql]( val cond = plan.condition.map { _.transform { case catalyst.expressions.EqualTo(a: AttributeReference, b: AttributeReference) if a.sameRef(b) => - catalyst.expressions.EqualTo(plan.left.resolve(a.name), plan.right.resolve(b.name)) + catalyst.expressions.EqualTo( + withPlan(plan.left).resolve(a.name), + withPlan(plan.right).resolve(b.name)) }} - plan.copy(condition = cond) + + withPlan { + plan.copy(condition = cond) + } + } + + /** + * Returns a new [[DataFrame]] with each partition sorted by the given expressions. + * + * This is the same operation as "SORT BY" in SQL (Hive QL). + * + * @group dfops + * @since 1.6.0 + */ + @scala.annotation.varargs + def sortWithinPartitions(sortCol: String, sortCols: String*): DataFrame = { + sortWithinPartitions((sortCol +: sortCols).map(Column(_)) : _*) + } + + /** + * Returns a new [[DataFrame]] with each partition sorted by the given expressions. + * + * This is the same operation as "SORT BY" in SQL (Hive QL). + * + * @group dfops + * @since 1.6.0 + */ + @scala.annotation.varargs + def sortWithinPartitions(sortExprs: Column*): DataFrame = { + sortInternal(global = false, sortExprs) } /** @@ -603,15 +622,7 @@ class DataFrame private[sql]( */ @scala.annotation.varargs def sort(sortExprs: Column*): DataFrame = { - val sortOrder: Seq[SortOrder] = sortExprs.map { col => - col.expr match { - case expr: SortOrder => - expr - case expr: Expression => - SortOrder(expr, Ascending) - } - } - Sort(sortOrder, global = true, logicalPlan) + sortInternal(global = true, sortExprs) } /** @@ -634,6 +645,7 @@ class DataFrame private[sql]( /** * Selects column based on the column name and return it as a [[Column]]. + * Note that the column name can also reference to a nested column like `a.b`. * @group dfops * @since 1.3.0 */ @@ -641,12 +653,13 @@ class DataFrame private[sql]( /** * Selects column based on the column name and return it as a [[Column]]. + * Note that the column name can also reference to a nested column like `a.b`. * @group dfops * @since 1.3.0 */ def col(colName: String): Column = colName match { case "*" => - Column(ResolvedStar(schema.fieldNames.map(resolve))) + Column(ResolvedStar(queryExecution.analyzed.output)) case _ => val expr = resolve(colName) Column(expr) @@ -657,7 +670,9 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.0 */ - def as(alias: String): DataFrame = Subquery(alias, logicalPlan) + def as(alias: String): DataFrame = withPlan { + Subquery(alias, logicalPlan) + } /** * (Scala-specific) Returns a new [[DataFrame]] with an alias set. @@ -666,6 +681,20 @@ class DataFrame private[sql]( */ def as(alias: Symbol): DataFrame = as(alias.name) + /** + * Returns a new [[DataFrame]] with an alias set. Same as `as`. + * @group dfops + * @since 1.6.0 + */ + def alias(alias: String): DataFrame = as(alias) + + /** + * (Scala-specific) Returns a new [[DataFrame]] with an alias set. Same as `as`. + * @group dfops + * @since 1.6.0 + */ + def alias(alias: Symbol): DataFrame = as(alias) + /** * Selects a set of column based expressions. * {{{ @@ -675,21 +704,8 @@ class DataFrame private[sql]( * @since 1.3.0 */ @scala.annotation.varargs - def select(cols: Column*): DataFrame = { - val namedExpressions = cols.map { - // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we - // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to - // make it a NamedExpression. 
- case Column(u: UnresolvedAttribute) => UnresolvedAlias(u) - case Column(expr: NamedExpression) => expr - // Leave an unaliased explode with an empty list of names since the analzyer will generate the - // correct defaults after the nested expression's type has been resolved. - case Column(explode: Explode) => MultiAlias(explode, Nil) - case Column(expr: Expression) => Alias(expr, expr.prettyString)() - } - // When user continuously call `select`, speed up analysis by collapsing `Project` - import org.apache.spark.sql.catalyst.optimizer.ProjectCollapsing - Project(namedExpressions.toSeq, ProjectCollapsing(logicalPlan)) + def select(cols: Column*): DataFrame = withPlan { + Project(cols.map(_.named), logicalPlan) } /** @@ -712,7 +728,9 @@ class DataFrame private[sql]( * SQL expressions. * * {{{ + * // The following are equivalent: * df.selectExpr("colA", "colB as newName", "abs(colC)") + * df.select(expr("colA"), expr("colB as newName"), expr("abs(colC)")) * }}} * @group dfops * @since 1.3.0 @@ -720,7 +738,7 @@ class DataFrame private[sql]( @scala.annotation.varargs def selectExpr(exprs: String*): DataFrame = { select(exprs.map { expr => - Column(new SqlParser().parseExpression(expr)) + Column(SqlParser.parseExpression(expr)) }: _*) } @@ -734,7 +752,9 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.0 */ - def filter(condition: Column): DataFrame = Filter(condition.expr, logicalPlan) + def filter(condition: Column): DataFrame = withPlan { + Filter(condition.expr, logicalPlan) + } /** * Filters rows using the given SQL expression. @@ -745,7 +765,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ def filter(conditionExpr: String): DataFrame = { - filter(Column(new SqlParser().parseExpression(conditionExpr))) + filter(Column(SqlParser.parseExpression(conditionExpr))) } /** @@ -769,7 +789,7 @@ class DataFrame private[sql]( * @since 1.5.0 */ def where(conditionExpr: String): DataFrame = { - filter(Column(new SqlParser().parseExpression(conditionExpr))) + filter(Column(SqlParser.parseExpression(conditionExpr))) } /** @@ -975,7 +995,9 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.0 */ - def limit(n: Int): DataFrame = Limit(Literal(n), logicalPlan) + def limit(n: Int): DataFrame = withPlan { + Limit(Literal(n), logicalPlan) + } /** * Returns a new [[DataFrame]] containing union of rows in this frame and another frame. @@ -983,7 +1005,9 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.0 */ - def unionAll(other: DataFrame): DataFrame = Union(logicalPlan, other.logicalPlan) + def unionAll(other: DataFrame): DataFrame = withPlan { + Union(logicalPlan, other.logicalPlan) + } /** * Returns a new [[DataFrame]] containing rows only in both this frame and another frame. @@ -991,7 +1015,9 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.0 */ - def intersect(other: DataFrame): DataFrame = Intersect(logicalPlan, other.logicalPlan) + def intersect(other: DataFrame): DataFrame = withPlan { + Intersect(logicalPlan, other.logicalPlan) + } /** * Returns a new [[DataFrame]] containing rows in this frame but not in another frame. @@ -999,7 +1025,9 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.0 */ - def except(other: DataFrame): DataFrame = Except(logicalPlan, other.logicalPlan) + def except(other: DataFrame): DataFrame = withPlan { + Except(logicalPlan, other.logicalPlan) + } /** * Returns a new [[DataFrame]] by sampling a fraction of rows. 
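A brief sketch of the equivalent selection and filtering forms touched above, plus the new `alias` and `sortWithinPartitions` helpers (column names are illustrative; `org.apache.spark.sql.functions._` is assumed imported):

 {{{
 import org.apache.spark.sql.functions._

 // Equivalent projections; the latter two parse SQL expressions.
 df.select($"colA", $"colB".as("newName"), abs($"colC"))
 df.select(expr("colA"), expr("colB as newName"), expr("abs(colC)"))
 df.selectExpr("colA", "colB as newName", "abs(colC)")

 // Predicates as Columns or as parsed SQL strings.
 df.filter($"age" > 21)
 df.where("age > 21")

 // `alias` is a synonym for `as`; `sortWithinPartitions` is a per-partition
 // sort (Hive QL's SORT BY), with no global ordering guarantee.
 df.alias("d").sortWithinPartitions("age")
 }}}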
@@ -1010,7 +1038,7 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.0 */ - def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = { + def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = withPlan { Sample(0.0, fraction, withReplacement, seed, logicalPlan) } @@ -1038,7 +1066,7 @@ class DataFrame private[sql]( val sum = weights.sum val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) normalizedCumWeights.sliding(2).map { x => - new DataFrame(sqlContext, Sample(x(0), x(1), false, seed, logicalPlan)) + new DataFrame(sqlContext, Sample(x(0), x(1), withReplacement = false, seed, logicalPlan)) }.toArray } @@ -1089,7 +1117,8 @@ class DataFrame private[sql]( def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame = { val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] - val elementTypes = schema.toAttributes.map { attr => (attr.dataType, attr.nullable) } + val elementTypes = schema.toAttributes.map { + attr => (attr.dataType, attr.nullable, attr.name) } val names = schema.toAttributes.map(_.name) val convert = CatalystTypeConverters.createToCatalystConverter(schema) @@ -1097,8 +1126,10 @@ class DataFrame private[sql]( f.andThen(_.map(convert(_).asInstanceOf[InternalRow])) val generator = UserDefinedGenerator(elementTypes, rowFunction, input.map(_.expr)) - Generate(generator, join = true, outer = false, - qualifier = None, names.map(UnresolvedAttribute(_)), logicalPlan) + withPlan { + Generate(generator, join = true, outer = false, + qualifier = None, generatorOutput = Nil, logicalPlan) + } } /** @@ -1117,8 +1148,7 @@ class DataFrame private[sql]( val dataType = ScalaReflection.schemaFor[B].dataType val attributes = AttributeReference(outputColumn, dataType)() :: Nil // TODO handle the metadata? - val elementTypes = attributes.map { attr => (attr.dataType, attr.nullable) } - val names = attributes.map(_.name) + val elementTypes = attributes.map { attr => (attr.dataType, attr.nullable, attr.name) } def rowFunction(row: Row): TraversableOnce[InternalRow] = { val convert = CatalystTypeConverters.createToCatalystConverter(dataType) @@ -1126,14 +1156,17 @@ class DataFrame private[sql]( } val generator = UserDefinedGenerator(elementTypes, rowFunction, apply(inputColumn).expr :: Nil) - Generate(generator, join = true, outer = false, - qualifier = None, names.map(UnresolvedAttribute(_)), logicalPlan) + withPlan { + Generate(generator, join = true, outer = false, + qualifier = None, generatorOutput = Nil, logicalPlan) + } } ///////////////////////////////////////////////////////////////////////////// /** - * Returns a new [[DataFrame]] by adding a column. + * Returns a new [[DataFrame]] by adding a column or replacing the existing column that has + * the same name. * @group dfops * @since 1.3.0 */ @@ -1151,6 +1184,23 @@ class DataFrame private[sql]( } } + /** + * Returns a new [[DataFrame]] by adding a column with metadata. + */ + private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame = { + val resolver = sqlContext.analyzer.resolver + val replaced = schema.exists(f => resolver(f.name, colName)) + if (replaced) { + val colNames = schema.map { field => + val name = field.name + if (resolver(name, colName)) col.as(colName, metadata) else Column(name) + } + select(colNames : _*) + } else { + select(Column("*"), col.as(colName, metadata)) + } + } + /** * Returns a new [[DataFrame]] with a column renamed. 
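Not part of the patch: a sketch of the clarified withColumn contract (add a column, or replace one that already has the same name), assuming a SQLContext named sqlContext.
{{{
import org.apache.spark.sql.functions.col

val people = sqlContext.range(0, 3).select(col("id"), (col("id") * 10).as("age"))

// Adds a brand-new column.
val withBonus = people.withColumn("bonus", col("age") + 1)

// Replaces the existing "age" column instead of appending a duplicate.
val rescaled = people.withColumn("age", col("age") / 10)
}}}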
* This is a no-op if schema doesn't contain existingName. @@ -1159,13 +1209,17 @@ class DataFrame private[sql]( */ def withColumnRenamed(existingName: String, newName: String): DataFrame = { val resolver = sqlContext.analyzer.resolver - val shouldRename = schema.exists(f => resolver(f.name, existingName)) + val output = queryExecution.analyzed.output + val shouldRename = output.exists(f => resolver(f.name, existingName)) if (shouldRename) { - val colNames = schema.map { field => - val name = field.name - if (resolver(name, existingName)) Column(name).as(newName) else Column(name) + val columns = output.map { col => + if (resolver(col.name, existingName)) { + Column(col).as(newName) + } else { + Column(col) + } } - select(colNames : _*) + select(columns : _*) } else { this } @@ -1178,16 +1232,24 @@ class DataFrame private[sql]( * @since 1.4.0 */ def drop(colName: String): DataFrame = { + drop(Seq(colName) : _*) + } + + /** + * Returns a new [[DataFrame]] with columns dropped. + * This is a no-op if schema doesn't contain column name(s). + * @group dfops + * @since 1.6.0 + */ + @scala.annotation.varargs + def drop(colNames: String*): DataFrame = { val resolver = sqlContext.analyzer.resolver - val shouldDrop = schema.exists(f => resolver(f.name, colName)) - if (shouldDrop) { - val colsAfterDrop = schema.filter { field => - val name = field.name - !resolver(name, colName) - }.map(f => Column(f.name)) - select(colsAfterDrop : _*) - } else { + val remainingCols = + schema.filter(f => colNames.forall(n => !resolver(f.name, n))).map(f => Column(f.name)) + if (remainingCols.size == this.schema.size) { this + } else { + this.select(remainingCols: _*) } } @@ -1200,9 +1262,14 @@ class DataFrame private[sql]( * @since 1.4.1 */ def drop(col: Column): DataFrame = { + val expression = col match { + case Column(u: UnresolvedAttribute) => + queryExecution.analyzed.resolveQuoted(u.name, sqlContext.analyzer.resolver).getOrElse(u) + case Column(expr: Expression) => expr + } val attrs = this.logicalPlan.output val colsAfterDrop = attrs.filter { attr => - attr != col.expr + attr != expression }.map(attr => Column(attr)) select(colsAfterDrop : _*) } @@ -1222,14 +1289,14 @@ class DataFrame private[sql]( * @group dfops * @since 1.4.0 */ - def dropDuplicates(colNames: Seq[String]): DataFrame = { + def dropDuplicates(colNames: Seq[String]): DataFrame = withPlan { val groupCols = colNames.map(resolve) val groupColExprIds = groupCols.map(_.exprId) val aggCols = logicalPlan.output.map { attr => if (groupColExprIds.contains(attr.exprId)) { attr } else { - Alias(First(attr), attr.name)() + Alias(new First(attr).toAggregateExpression(), attr.name)() } } Aggregate(groupCols, aggCols, logicalPlan) @@ -1268,19 +1335,15 @@ class DataFrame private[sql]( * @since 1.3.1 */ @scala.annotation.varargs - def describe(cols: String*): DataFrame = { - - // TODO: Add stddev as an expression, and remove it from here. - def stddevExpr(expr: Expression): Expression = - Sqrt(Subtract(Average(Multiply(expr, expr)), Multiply(Average(expr), Average(expr)))) + def describe(cols: String*): DataFrame = withPlan { // The list of summary statistics to compute, in the form of expressions. 
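Not part of the patch: a sketch of the new varargs drop added above, assuming a DataFrame df with columns a, b and c.
{{{
val onlyA = df.drop("b", "c")            // new in 1.6: drop several columns at once
val unchanged = df.drop("not_a_column")  // no-op: the name does not resolve
val viaColumn = df.drop(df("b"))         // Column-based variant
}}}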
val statistics = List[(String, Expression => Expression)]( - "count" -> Count, - "mean" -> Average, - "stddev" -> stddevExpr, - "min" -> Min, - "max" -> Max) + "count" -> ((child: Expression) => Count(child).toAggregateExpression()), + "mean" -> ((child: Expression) => Average(child).toAggregateExpression()), + "stddev" -> ((child: Expression) => StddevSamp(child).toAggregateExpression()), + "min" -> ((child: Expression) => Min(child).toAggregateExpression()), + "max" -> ((child: Expression) => Max(child).toAggregateExpression())) val outputCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList @@ -1311,7 +1374,9 @@ class DataFrame private[sql]( * @group action * @since 1.3.0 */ - def head(n: Int): Array[Row] = limit(n).collect() + def head(n: Int): Array[Row] = withCallback("head", limit(n)) { df => + df.collect(needCallback = false) + } /** * Returns the first row. @@ -1327,6 +1392,19 @@ class DataFrame private[sql]( */ def first(): Row = head() + /** + * Concise syntax for chaining custom transformations. + * {{{ + * def featurize(ds: DataFrame) = ... + * + * df + * .transform(featurize) + * .transform(...) + * }}} + * @since 1.6.0 + */ + def transform[U](t: DataFrame => DataFrame): DataFrame = t(this) + /** * Returns a new RDD by applying a function to all rows of this DataFrame. * @group rdd @@ -1356,52 +1434,127 @@ class DataFrame private[sql]( * @group rdd * @since 1.3.0 */ - def foreach(f: Row => Unit): Unit = rdd.foreach(f) + def foreach(f: Row => Unit): Unit = withNewExecutionId { + rdd.foreach(f) + } /** * Applies a function f to each partition of this [[DataFrame]]. * @group rdd * @since 1.3.0 */ - def foreachPartition(f: Iterator[Row] => Unit): Unit = rdd.foreachPartition(f) + def foreachPartition(f: Iterator[Row] => Unit): Unit = withNewExecutionId { + rdd.foreachPartition(f) + } /** * Returns the first `n` rows in the [[DataFrame]]. + * + * Running take requires moving data into the application's driver process, and doing so with + * a very large `n` can crash the driver process with OutOfMemoryError. + * * @group action * @since 1.3.0 */ def take(n: Int): Array[Row] = head(n) + /** + * Returns the first `n` rows in the [[DataFrame]] as a list. + * + * Running take requires moving data into the application's driver process, and doing so with + * a very large `n` can crash the driver process with OutOfMemoryError. + * + * @group action + * @since 1.6.0 + */ + def takeAsList(n: Int): java.util.List[Row] = java.util.Arrays.asList(take(n) : _*) + /** * Returns an array that contains all of [[Row]]s in this [[DataFrame]]. + * + * Running collect requires moving all the data into the application's driver process, and + * doing so on a very large dataset can crash the driver process with OutOfMemoryError. + * + * For Java API, use [[collectAsList]]. + * * @group action * @since 1.3.0 */ - def collect(): Array[Row] = queryExecution.executedPlan.executeCollect() + def collect(): Array[Row] = collect(needCallback = true) /** * Returns a Java list that contains all of [[Row]]s in this [[DataFrame]]. + * + * Running collect requires moving all the data into the application's driver process, and + * doing so on a very large dataset can crash the driver process with OutOfMemoryError. 
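Not part of the patch: a sketch of chaining with the new transform method, using two hypothetical pipeline stages and a DataFrame rawDf assumed to have a string column "text".
{{{
import org.apache.spark.sql.functions.{length, lower}

def lowercased(df: DataFrame): DataFrame = df.withColumn("text", lower(df("text")))
def withLength(df: DataFrame): DataFrame = df.withColumn("len", length(df("text")))

val result = rawDf
  .transform(lowercased)
  .transform(withLength)
}}}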
+ * * @group action * @since 1.3.0 */ - def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(rdd.collect() : _*) + def collectAsList(): java.util.List[Row] = withCallback("collectAsList", this) { _ => + withNewExecutionId { + java.util.Arrays.asList(rdd.collect() : _*) + } + } + + private def collect(needCallback: Boolean): Array[Row] = { + def execute(): Array[Row] = withNewExecutionId { + queryExecution.executedPlan.executeCollectPublic() + } + + if (needCallback) { + withCallback("collect", this)(_ => execute()) + } else { + execute() + } + } /** * Returns the number of rows in the [[DataFrame]]. * @group action * @since 1.3.0 */ - def count(): Long = groupBy().count().collect().head.getLong(0) + def count(): Long = withCallback("count", groupBy().count()) { df => + df.collect(needCallback = false).head.getLong(0) + } /** * Returns a new [[DataFrame]] that has exactly `numPartitions` partitions. - * @group rdd + * @group dfops * @since 1.3.0 */ - def repartition(numPartitions: Int): DataFrame = { + def repartition(numPartitions: Int): DataFrame = withPlan { Repartition(numPartitions, shuffle = true, logicalPlan) } + /** + * Returns a new [[DataFrame]] partitioned by the given partitioning expressions into + * `numPartitions`. The resulting DataFrame is hash partitioned. + * + * This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL). + * + * @group dfops + * @since 1.6.0 + */ + @scala.annotation.varargs + def repartition(numPartitions: Int, partitionExprs: Column*): DataFrame = withPlan { + RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, Some(numPartitions)) + } + + /** + * Returns a new [[DataFrame]] partitioned by the given partitioning expressions preserving + * the existing number of partitions. The resulting DataFrame is hash partitioned. + * + * This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL). + * + * @group dfops + * @since 1.6.0 + */ + @scala.annotation.varargs + def repartition(partitionExprs: Column*): DataFrame = withPlan { + RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions = None) + } + /** * Returns a new [[DataFrame]] that has exactly `numPartitions` partitions. * Similar to coalesce defined on an [[RDD]], this operation results in a narrow dependency, e.g. @@ -1410,7 +1563,7 @@ class DataFrame private[sql]( * @group rdd * @since 1.4.0 */ - def coalesce(numPartitions: Int): DataFrame = { + def coalesce(numPartitions: Int): DataFrame = withPlan { Repartition(numPartitions, shuffle = false, logicalPlan) } @@ -1423,6 +1576,7 @@ class DataFrame private[sql]( def distinct(): DataFrame = dropDuplicates() /** + * Persist this [[DataFrame]] with the default storage level (`MEMORY_AND_DISK`). * @group basic * @since 1.3.0 */ @@ -1432,12 +1586,17 @@ class DataFrame private[sql]( } /** + * Persist this [[DataFrame]] with the default storage level (`MEMORY_AND_DISK`). * @group basic * @since 1.3.0 */ def cache(): this.type = persist() /** + * Persist this [[DataFrame]] with the given storage level. + * @param newLevel One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`, + * `MEMORY_AND_DISK_SER`, `DISK_ONLY`, `MEMORY_ONLY_2`, + * `MEMORY_AND_DISK_2`, etc. * @group basic * @since 1.3.0 */ @@ -1447,6 +1606,8 @@ class DataFrame private[sql]( } /** + * Mark the [[DataFrame]] as non-persistent, and remove all blocks for it from memory and disk. + * @param blocking Whether to block until all blocks are deleted. 
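Not part of the patch: a sketch of the expression-based repartition added above (the DataFrame counterpart of DISTRIBUTE BY), assuming a DataFrame events with userId and eventTime columns.
{{{
import org.apache.spark.sql.functions.col

// DISTRIBUTE BY userId SORT BY eventTime, with an explicit partition count.
val clustered = events
  .repartition(200, col("userId"))
  .sortWithinPartitions(col("eventTime"))

// Same partitioning, keeping the current number of partitions.
val clustered2 = events.repartition(col("userId"))
}}}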
* @group basic * @since 1.3.0 */ @@ -1456,6 +1617,7 @@ class DataFrame private[sql]( } /** + * Mark the [[DataFrame]] as non-persistent, and remove all blocks for it from memory and disk. * @group basic * @since 1.3.0 */ @@ -1523,7 +1685,7 @@ class DataFrame private[sql]( */ def toJSON: RDD[String] = { val rowSchema = this.schema - this.mapPartitions { iter => + queryExecution.toRdd.mapPartitions { iter => val writer = new CharArrayWriter() // create the Generator without separator inserted between 2 records val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null) @@ -1554,10 +1716,10 @@ class DataFrame private[sql]( */ def inputFiles: Array[String] = { val files: Seq[String] = logicalPlan.collect { - case LogicalRelation(fsBasedRelation: HadoopFsRelation) => - fsBasedRelation.paths.toSeq - case LogicalRelation(jsonRelation: JSONRelation) => - jsonRelation.path.toSeq + case LogicalRelation(fsBasedRelation: FileRelation, _) => + fsBasedRelation.inputFiles + case fr: FileRelation => + fr.inputFiles }.flatten files.toSet.toArray } @@ -1575,6 +1737,12 @@ class DataFrame private[sql]( EvaluatePython.javaToPython(rdd) } + protected[sql] def collectToPython(): Int = { + withNewExecutionId { + PythonRDD.collectAndServe(javaToPython.rdd) + } + } + //////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////// // Deprecated methods @@ -1582,9 +1750,9 @@ class DataFrame private[sql]( //////////////////////////////////////////////////////////////////////////// /** - * @deprecated As of 1.3.0, replaced by `toDF()`. + * @deprecated As of 1.3.0, replaced by `toDF()`. This will be removed in Spark 2.0. */ - @deprecated("use toDF", "1.3.0") + @deprecated("Use toDF. This will be removed in Spark 2.0.", "1.3.0") def toSchemaRDD: DataFrame = this /** @@ -1594,9 +1762,9 @@ class DataFrame private[sql]( * given name; if you pass `false`, it will throw if the table already * exists. * @group output - * @deprecated As of 1.340, replaced by `write().jdbc()`. + * @deprecated As of 1.340, replaced by `write().jdbc()`. This will be removed in Spark 2.0. */ - @deprecated("Use write.jdbc()", "1.4.0") + @deprecated("Use write.jdbc(). This will be removed in Spark 2.0.", "1.4.0") def createJDBCTable(url: String, table: String, allowExisting: Boolean): Unit = { val w = if (allowExisting) write.mode(SaveMode.Overwrite) else write w.jdbc(url, table, new Properties) @@ -1613,11 +1781,11 @@ class DataFrame private[sql]( * the RDD in order via the simple statement * `INSERT INTO table VALUES (?, ?, ..., ?)` should not fail. * @group output - * @deprecated As of 1.4.0, replaced by `write().jdbc()`. + * @deprecated As of 1.4.0, replaced by `write().jdbc()`. This will be removed in Spark 2.0. */ - @deprecated("Use write.jdbc()", "1.4.0") + @deprecated("Use write.jdbc(). This will be removed in Spark 2.0.", "1.4.0") def insertIntoJDBC(url: String, table: String, overwrite: Boolean): Unit = { - val w = if (overwrite) write.mode(SaveMode.Overwrite) else write + val w = if (overwrite) write.mode(SaveMode.Overwrite) else write.mode(SaveMode.Append) w.jdbc(url, table, new Properties) } @@ -1626,9 +1794,9 @@ class DataFrame private[sql]( * Files that are written out using this method can be read back in as a [[DataFrame]] * using the `parquetFile` function in [[SQLContext]]. * @group output - * @deprecated As of 1.4.0, replaced by `write().parquet()`. + * @deprecated As of 1.4.0, replaced by `write().parquet()`. 
This will be removed in Spark 2.0. */ - @deprecated("Use write.parquet(path)", "1.4.0") + @deprecated("Use write.parquet(path). This will be removed in Spark 2.0.", "1.4.0") def saveAsParquetFile(path: String): Unit = { write.format("parquet").mode(SaveMode.ErrorIfExists).save(path) } @@ -1643,12 +1811,17 @@ class DataFrame private[sql]( * an RDD out to a parquet file, and then register that file as a table. This "table" can then * be the target of an `insertInto`. * - * Also note that while this function can persist the table metadata into Hive's metastore, - * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. + * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input + * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC + * and Parquet), the table is persisted in a Hive compatible format, which means other systems + * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL + * specific format. + * * @group output * @deprecated As of 1.4.0, replaced by `write().saveAsTable(tableName)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.saveAsTable(tableName)", "1.4.0") + @deprecated("Use write.saveAsTable(tableName). This will be removed in Spark 2.0.", "1.4.0") def saveAsTable(tableName: String): Unit = { write.mode(SaveMode.ErrorIfExists).saveAsTable(tableName) } @@ -1662,12 +1835,18 @@ class DataFrame private[sql]( * an RDD out to a parquet file, and then register that file as a table. This "table" can then * be the target of an `insertInto`. * - * Also note that while this function can persist the table metadata into Hive's metastore, - * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. + * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input + * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC + * and Parquet), the table is persisted in a Hive compatible format, which means other systems + * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL + * specific format. + * * @group output * @deprecated As of 1.4.0, replaced by `write().mode(mode).saveAsTable(tableName)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.mode(mode).saveAsTable(tableName)", "1.4.0") + @deprecated("Use write.mode(mode).saveAsTable(tableName). This will be removed in Spark 2.0.", + "1.4.0") def saveAsTable(tableName: String, mode: SaveMode): Unit = { write.mode(mode).saveAsTable(tableName) } @@ -1682,12 +1861,18 @@ class DataFrame private[sql]( * an RDD out to a parquet file, and then register that file as a table. This "table" can then * be the target of an `insertInto`. * - * Also note that while this function can persist the table metadata into Hive's metastore, - * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. + * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input + * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC + * and Parquet), the table is persisted in a Hive compatible format, which means other systems + * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL + * specific format. + * * @group output * @deprecated As of 1.4.0, replaced by `write().format(source).saveAsTable(tableName)`. + * This will be removed in Spark 2.0. 
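Not part of the patch: a sketch of the DataFrameWriter calls that the deprecation messages above point to; the paths, JDBC URL and table names are placeholders.
{{{
import java.util.Properties
import org.apache.spark.sql.SaveMode

val jdbcUrl = "jdbc:postgresql://host/db"  // placeholder

df.write.format("parquet").mode(SaveMode.ErrorIfExists).save("/tmp/people.parquet")
df.write.mode(SaveMode.Overwrite).saveAsTable("people")
df.write.mode(SaveMode.Append).jdbc(jdbcUrl, "people", new Properties)
}}}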
*/ - @deprecated("Use write.format(source).saveAsTable(tableName)", "1.4.0") + @deprecated("Use write.format(source).saveAsTable(tableName). This will be removed in Spark 2.0.", + "1.4.0") def saveAsTable(tableName: String, source: String): Unit = { write.format(source).saveAsTable(tableName) } @@ -1702,12 +1887,18 @@ class DataFrame private[sql]( * an RDD out to a parquet file, and then register that file as a table. This "table" can then * be the target of an `insertInto`. * - * Also note that while this function can persist the table metadata into Hive's metastore, - * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. + * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input + * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC + * and Parquet), the table is persisted in a Hive compatible format, which means other systems + * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL + * specific format. + * * @group output * @deprecated As of 1.4.0, replaced by `write().mode(mode).saveAsTable(tableName)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.format(source).mode(mode).saveAsTable(tableName)", "1.4.0") + @deprecated("Use write.format(source).mode(mode).saveAsTable(tableName). " + + "This will be removed in Spark 2.0.", "1.4.0") def saveAsTable(tableName: String, source: String, mode: SaveMode): Unit = { write.format(source).mode(mode).saveAsTable(tableName) } @@ -1721,14 +1912,19 @@ class DataFrame private[sql]( * an RDD out to a parquet file, and then register that file as a table. This "table" can then * be the target of an `insertInto`. * - * Also note that while this function can persist the table metadata into Hive's metastore, - * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. + * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input + * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC + * and Parquet), the table is persisted in a Hive compatible format, which means other systems + * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL + * specific format. + * * @group output * @deprecated As of 1.4.0, replaced by * `write().format(source).mode(mode).options(options).saveAsTable(tableName)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName)", - "1.4.0") + @deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName). " + + "This will be removed in Spark 2.0.", "1.4.0") def saveAsTable( tableName: String, source: String, @@ -1747,14 +1943,19 @@ class DataFrame private[sql]( * an RDD out to a parquet file, and then register that file as a table. This "table" can then * be the target of an `insertInto`. * - * Also note that while this function can persist the table metadata into Hive's metastore, - * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. + * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input + * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC + * and Parquet), the table is persisted in a Hive compatible format, which means other systems + * like Hive will be able to read this table. 
Otherwise, the table is persisted in a Spark SQL + * specific format. + * * @group output * @deprecated As of 1.4.0, replaced by * `write().format(source).mode(mode).options(options).saveAsTable(tableName)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName)", - "1.4.0") + @deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName). " + + "This will be removed in Spark 2.0.", "1.4.0") def saveAsTable( tableName: String, source: String, @@ -1768,9 +1969,9 @@ class DataFrame private[sql]( * using the default data source configured by spark.sql.sources.default and * [[SaveMode.ErrorIfExists]] as the save mode. * @group output - * @deprecated As of 1.4.0, replaced by `write().save(path)`. + * @deprecated As of 1.4.0, replaced by `write().save(path)`. This will be removed in Spark 2.0. */ - @deprecated("Use write.save(path)", "1.4.0") + @deprecated("Use write.save(path). This will be removed in Spark 2.0.", "1.4.0") def save(path: String): Unit = { write.save(path) } @@ -1780,8 +1981,9 @@ class DataFrame private[sql]( * using the default data source configured by spark.sql.sources.default. * @group output * @deprecated As of 1.4.0, replaced by `write().mode(mode).save(path)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.mode(mode).save(path)", "1.4.0") + @deprecated("Use write.mode(mode).save(path). This will be removed in Spark 2.0.", "1.4.0") def save(path: String, mode: SaveMode): Unit = { write.mode(mode).save(path) } @@ -1791,8 +1993,9 @@ class DataFrame private[sql]( * using [[SaveMode.ErrorIfExists]] as the save mode. * @group output * @deprecated As of 1.4.0, replaced by `write().format(source).save(path)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.format(source).save(path)", "1.4.0") + @deprecated("Use write.format(source).save(path). This will be removed in Spark 2.0.", "1.4.0") def save(path: String, source: String): Unit = { write.format(source).save(path) } @@ -1802,8 +2005,10 @@ class DataFrame private[sql]( * [[SaveMode]] specified by mode. * @group output * @deprecated As of 1.4.0, replaced by `write().format(source).mode(mode).save(path)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.format(source).mode(mode).save(path)", "1.4.0") + @deprecated("Use write.format(source).mode(mode).save(path). " + + "This will be removed in Spark 2.0.", "1.4.0") def save(path: String, source: String, mode: SaveMode): Unit = { write.format(source).mode(mode).save(path) } @@ -1814,8 +2019,10 @@ class DataFrame private[sql]( * @group output * @deprecated As of 1.4.0, replaced by * `write().format(source).mode(mode).options(options).save(path)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.format(source).mode(mode).options(options).save()", "1.4.0") + @deprecated("Use write.format(source).mode(mode).options(options).save(). " + + "This will be removed in Spark 2.0.", "1.4.0") def save( source: String, mode: SaveMode, @@ -1830,8 +2037,10 @@ class DataFrame private[sql]( * @group output * @deprecated As of 1.4.0, replaced by * `write().format(source).mode(mode).options(options).save(path)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.format(source).mode(mode).options(options).save()", "1.4.0") + @deprecated("Use write.format(source).mode(mode).options(options).save(). 
" + + "This will be removed in Spark 2.0.", "1.4.0") def save( source: String, mode: SaveMode, @@ -1839,14 +2048,15 @@ class DataFrame private[sql]( write.format(source).mode(mode).options(options).save() } - /** * Adds the rows from this RDD to the specified table, optionally overwriting the existing data. * @group output * @deprecated As of 1.4.0, replaced by * `write().mode(SaveMode.Append|SaveMode.Overwrite).saveAsTable(tableName)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.mode(SaveMode.Append|SaveMode.Overwrite).saveAsTable(tableName)", "1.4.0") + @deprecated("Use write.mode(SaveMode.Append|SaveMode.Overwrite).saveAsTable(tableName). " + + "This will be removed in Spark 2.0.", "1.4.0") def insertInto(tableName: String, overwrite: Boolean): Unit = { write.mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append).insertInto(tableName) } @@ -1857,8 +2067,10 @@ class DataFrame private[sql]( * @group output * @deprecated As of 1.4.0, replaced by * `write().mode(SaveMode.Append).saveAsTable(tableName)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.mode(SaveMode.Append).saveAsTable(tableName)", "1.4.0") + @deprecated("Use write.mode(SaveMode.Append).saveAsTable(tableName). " + + "This will be removed in Spark 2.0.", "1.4.0") def insertInto(tableName: String): Unit = { write.mode(SaveMode.Append).insertInto(tableName) } @@ -1869,4 +2081,52 @@ class DataFrame private[sql]( //////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////// + /** + * Wrap a DataFrame action to track all Spark jobs in the body so that we can connect them with + * an execution. + */ + private[sql] def withNewExecutionId[T](body: => T): T = { + SQLExecution.withNewExecutionId(sqlContext, queryExecution)(body) + } + + /** + * Wrap a DataFrame action to track the QueryExecution and time cost, then report to the + * user-registered callback functions. + */ + private def withCallback[T](name: String, df: DataFrame)(action: DataFrame => T) = { + try { + df.queryExecution.executedPlan.foreach { plan => + plan.metrics.valuesIterator.foreach(_.reset()) + } + val start = System.nanoTime() + val result = action(df) + val end = System.nanoTime() + sqlContext.listenerManager.onSuccess(name, df.queryExecution, end - start) + result + } catch { + case e: Exception => + sqlContext.listenerManager.onFailure(name, df.queryExecution, e) + throw e + } + } + + private def sortInternal(global: Boolean, sortExprs: Seq[Column]): DataFrame = { + val sortOrder: Seq[SortOrder] = sortExprs.map { col => + col.expr match { + case expr: SortOrder => + expr + case expr: Expression => + SortOrder(expr, Ascending) + } + } + withPlan { + Sort(sortOrder, global = global, logicalPlan) + } + } + + /** A convenient function to wrap a logical plan and produce a DataFrame. */ + @inline private def withPlan(logicalPlan: => LogicalPlan): DataFrame = { + new DataFrame(sqlContext, logicalPlan) + } + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala index 2f19ec040301..3b30337f1f87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala @@ -20,9 +20,14 @@ package org.apache.spark.sql /** * A container for a [[DataFrame]], used for implicit conversions. 
* + * To use this, import implicit conversions in SQL: + * {{{ + * import sqlContext.implicits._ + * }}} + * * @since 1.3.0 */ -private[sql] case class DataFrameHolder(df: DataFrame) { +case class DataFrameHolder private[sql](private val df: DataFrame) { // This is declared with parentheses to prevent the Scala compiler from treating // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index ea85f0657a72..f7be5f6b370a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import java.{lang => jl} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.expressions._ @@ -122,7 +122,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = { // Filtering condition: // only keep the row if it has at least `minNonNulls` non-null and non-NaN values. - val predicate = AtLeastNNonNullNans(minNonNulls, cols.map(name => df.resolve(name))) + val predicate = AtLeastNNonNulls(minNonNulls, cols.map(name => df.resolve(name))) df.filter(Column(predicate)) } @@ -198,7 +198,8 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * Returns a new [[DataFrame]] that replaces null values. * * The key of the map is the column name, and the value of the map is the replacement value. - * The value must be of the following type: `Integer`, `Long`, `Float`, `Double`, `String`. + * The value must be of the following type: + * `Integer`, `Long`, `Float`, `Double`, `String`, `Boolean`. * * For example, the following replaces null values in column "A" with string "unknown", and * null values in column "B" with numeric value 1.0. @@ -209,13 +210,13 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * * @since 1.3.1 */ - def fill(valueMap: java.util.Map[String, Any]): DataFrame = fill0(valueMap.toSeq) + def fill(valueMap: java.util.Map[String, Any]): DataFrame = fill0(valueMap.asScala.toSeq) /** * (Scala-specific) Returns a new [[DataFrame]] that replaces null values. * * The key of the map is the column name, and the value of the map is the replacement value. - * The value must be of the following type: `Int`, `Long`, `Float`, `Double`, `String`. + * The value must be of the following type: `Int`, `Long`, `Float`, `Double`, `String`, `Boolean`. * * For example, the following replaces null values in column "A" with string "unknown", and * null values in column "B" with numeric value 1.0. @@ -232,7 +233,8 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { /** * Replaces values matching keys in `replacement` map with the corresponding values. - * Key and value of `replacement` map must have the same type, and can only be doubles or strings. + * Key and value of `replacement` map must have the same type, and + * can only be doubles, strings or booleans. * If `col` is "*", then the replacement is applied on all string columns or numeric columns. 
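Not part of the patch: a sketch of the boolean support being added to the na functions, assuming a DataFrame df with a string column "name" and a boolean column "active".
{{{
// Booleans are now accepted alongside numerics and strings.
val filled = df.na.fill(Map("name" -> "unknown", "active" -> false))

// Replace all occurrences of false in "active" with true.
val replaced = df.na.replace("active", Map(false -> true))
}}}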
* * {{{ @@ -254,12 +256,13 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * @since 1.3.1 */ def replace[T](col: String, replacement: java.util.Map[T, T]): DataFrame = { - replace[T](col, replacement.toMap : Map[T, T]) + replace[T](col, replacement.asScala.toMap) } /** * Replaces values matching keys in `replacement` map with the corresponding values. - * Key and value of `replacement` map must have the same type, and can only be doubles or strings. + * Key and value of `replacement` map must have the same type, and + * can only be doubles, strings or booleans. * * {{{ * import com.google.common.collect.ImmutableMap; @@ -277,13 +280,15 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * @since 1.3.1 */ def replace[T](cols: Array[String], replacement: java.util.Map[T, T]): DataFrame = { - replace(cols.toSeq, replacement.toMap) + replace(cols.toSeq, replacement.asScala.toMap) } /** * (Scala-specific) Replaces values matching keys in `replacement` map. - * Key and value of `replacement` map must have the same type, and can only be doubles or strings. - * If `col` is "*", then the replacement is applied on all string columns or numeric columns. + * Key and value of `replacement` map must have the same type, and + * can only be doubles, strings or booleans. + * If `col` is "*", + * then the replacement is applied on all string columns , numeric columns or boolean columns. * * {{{ * // Replaces all occurrences of 1.0 with 2.0 in column "height". @@ -311,7 +316,8 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { /** * (Scala-specific) Replaces values matching keys in `replacement` map. - * Key and value of `replacement` map must have the same type, and can only be doubles or strings. + * Key and value of `replacement` map must have the same type, and + * can only be doubles , strings or booleans. * * {{{ * // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight". 
@@ -333,15 +339,17 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { return df } - // replacementMap is either Map[String, String] or Map[Double, Double] + // replacementMap is either Map[String, String] or Map[Double, Double] or Map[Boolean,Boolean] val replacementMap: Map[_, _] = replacement.head._2 match { case v: String => replacement + case v: Boolean => replacement case _ => replacement.map { case (k, v) => (convertToDouble(k), convertToDouble(v)) } } - // targetColumnType is either DoubleType or StringType + // targetColumnType is either DoubleType or StringType or BooleanType val targetColumnType = replacement.head._1 match { case _: jl.Double | _: jl.Float | _: jl.Integer | _: jl.Long => DoubleType + case _: jl.Boolean => BooleanType case _: String => StringType } @@ -367,7 +375,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { // Check data type replaceValue match { - case _: jl.Double | _: jl.Float | _: jl.Integer | _: jl.Long | _: String => + case _: jl.Double | _: jl.Float | _: jl.Integer | _: jl.Long | _: jl.Boolean | _: String => // This is good case _ => throw new IllegalArgumentException( s"Unsupported value type ${replaceValue.getClass.getName} ($replaceValue).") @@ -382,6 +390,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { case v: jl.Double => fillCol[Double](f, v) case v: jl.Long => fillCol[Double](f, v.toDouble) case v: jl.Integer => fillCol[Double](f, v.toDouble) + case v: jl.Boolean => fillCol[Boolean](f, v.booleanValue()) case v: String => fillCol[String](f, v) } }.getOrElse(df.col(f.name)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index eb09807f9d9c..c1a8f19313a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -19,18 +19,22 @@ package org.apache.spark.sql import java.util.Properties +import scala.collection.JavaConverters._ + import org.apache.hadoop.fs.Path +import org.apache.hadoop.util.StringUtils +import org.apache.spark.{Logging, Partition} import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.SqlParser +import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} +import org.apache.spark.sql.execution.datasources.json.JSONRelation +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} -import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} -import org.apache.spark.sql.json.JSONRelation -import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.types.StructType -import org.apache.spark.{Logging, Partition} /** * :: Experimental :: @@ -90,7 +94,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * @since 1.4.0 */ def options(options: java.util.Map[String, String]): DataFrameReader = { - this.options(scala.collection.JavaConversions.mapAsScalaMap(options)) + this.options(options.asScala) this } @@ -100,6 +104,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * * @since 1.4.0 */ + // TODO: Remove this one in Spark 2.0. 
def load(path: String): DataFrame = { option("path", path).load() } @@ -120,6 +125,17 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { DataFrame(sqlContext, LogicalRelation(resolved.relation)) } + /** + * Loads input in as a [[DataFrame]], for data sources that support multiple paths. + * Only works if the source is a HadoopFsRelationProvider. + * + * @since 1.6.0 + */ + @scala.annotation.varargs + def load(paths: String*): DataFrame = { + option("paths", paths.map(StringUtils.escapeString(_, '\\', ',')).mkString(",")).load() + } + /** * Construct a [[DataFrame]] representing the database table accessible via JDBC URL * url named table and connection properties. @@ -197,7 +213,13 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { table: String, parts: Array[Partition], connectionProperties: Properties): DataFrame = { - val relation = JDBCRelation(url, table, parts, connectionProperties)(sqlContext) + val props = new Properties() + extraOptions.foreach { case (key, value) => + props.put(key, value) + } + // connectionProperties should override settings in extraOptions + props.putAll(connectionProperties) + val relation = JDBCRelation(url, table, parts, props)(sqlContext) sqlContext.baseRelationToDataFrame(relation) } @@ -207,11 +229,39 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * This function goes through the input once to determine the input schema. If you know the * schema in advance, use the version that specifies the schema to avoid the extra scan. * - * @param path input path + * You can set the following JSON-specific options to deal with non-standard JSON files: + *
<li>`primitivesAsString` (default `false`): infers all primitive values as a string type</li>
+ * <li>`allowComments` (default `false`): ignores Java/C++ style comment in JSON records</li>
+ * <li>`allowUnquotedFieldNames` (default `false`): allows unquoted JSON field names</li>
+ * <li>`allowSingleQuotes` (default `true`): allows single quotes in addition to double quotes</li>
+ * <li>`allowNumericLeadingZeros` (default `false`): allows leading zeros in numbers (e.g. 00012)</li>
+ * * @since 1.4.0 */ + // TODO: Remove this one in Spark 2.0. def json(path: String): DataFrame = format("json").load(path) + /** + * Loads a JSON file (one object per line) and returns the result as a [[DataFrame]]. + * + * This function goes through the input once to determine the input schema. If you know the + * schema in advance, use the version that specifies the schema to avoid the extra scan. + * + * You can set the following JSON-specific options to deal with non-standard JSON files:
+ * <li>`primitivesAsString` (default `false`): infers all primitive values as a string type</li>
+ * <li>`allowComments` (default `false`): ignores Java/C++ style comment in JSON records</li>
+ * <li>`allowUnquotedFieldNames` (default `false`): allows unquoted JSON field names</li>
+ * <li>`allowSingleQuotes` (default `true`): allows single quotes in addition to double quotes</li>
+ * <li>`allowNumericLeadingZeros` (default `false`): allows leading zeros in numbers (e.g. 00012)</li>
      • + * + * @since 1.6.0 + */ + def json(paths: String*): DataFrame = format("json").load(paths : _*) + /** * Loads an `JavaRDD[String]` storing JSON objects (one object per record) and * returns the result as a [[DataFrame]]. @@ -235,9 +285,14 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * @since 1.4.0 */ def json(jsonRDD: RDD[String]): DataFrame = { - val samplingRatio = extraOptions.getOrElse("samplingRatio", "1.0").toDouble sqlContext.baseRelationToDataFrame( - new JSONRelation(() => jsonRDD, None, samplingRatio, userSpecifiedSchema)(sqlContext)) + new JSONRelation( + Some(jsonRDD), + maybeDataSchema = userSpecifiedSchema, + maybePartitionSpec = None, + userDefinedPartitionColumns = None, + parameters = extraOptions.toMap)(sqlContext) + ) } /** @@ -260,7 +315,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { sqlContext.baseRelationToDataFrame( new ParquetRelation( - globbedPaths.map(_.toString), None, None, extraOptions.toMap)(sqlContext)) + globbedPaths.map(_.toString), userSpecifiedSchema, None, extraOptions.toMap)(sqlContext)) } } @@ -279,9 +334,27 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * @since 1.4.0 */ def table(tableName: String): DataFrame = { - DataFrame(sqlContext, sqlContext.catalog.lookupRelation(Seq(tableName))) + DataFrame(sqlContext, + sqlContext.catalog.lookupRelation(SqlParser.parseTableIdentifier(tableName))) } + /** + * Loads a text file and returns a [[DataFrame]] with a single string column named "value". + * Each line in the text file is a new row in the resulting DataFrame. For example: + * {{{ + * // Scala: + * sqlContext.read.text("/path/to/spark/README.md") + * + * // Java: + * sqlContext.read().text("/path/to/spark/README.md") + * }}} + * + * @param paths input path + * @since 1.6.0 + */ + @scala.annotation.varargs + def text(paths: String*): DataFrame = format("text").load(paths : _*) + /////////////////////////////////////////////////////////////////////////////////////// // Builder pattern config options /////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 2e68e358f2f1..69c984717526 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -39,6 +39,13 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * @param col2 the name of the second column * @return the covariance of the two columns. * + * {{{ + * val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10)) + * .withColumn("rand2", rand(seed=27)) + * df.stat.cov("rand1", "rand2") + * res1: Double = 0.065... + * }}} + * * @since 1.4.0 */ def cov(col1: String, col2: String): Double = { @@ -54,6 +61,13 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * @param col2 the name of the column to calculate the correlation against * @return The Pearson Correlation Coefficient as a Double. * + * {{{ + * val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10)) + * .withColumn("rand2", rand(seed=27)) + * df.stat.corr("rand1", "rand2") + * res1: Double = 0.613... 
+ * }}} + * * @since 1.4.0 */ def corr(col1: String, col2: String, method: String): Double = { @@ -69,6 +83,13 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * @param col2 the name of the column to calculate the correlation against * @return The Pearson Correlation Coefficient as a Double. * + * {{{ + * val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10)) + * .withColumn("rand2", rand(seed=27)) + * df.stat.corr("rand1", "rand2", "pearson") + * res1: Double = 0.613... + * }}} + * * @since 1.4.0 */ def corr(col1: String, col2: String): Double = { @@ -92,6 +113,20 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * of the DataFrame. * @return A DataFrame containing for the contingency table. * + * {{{ + * val df = sqlContext.createDataFrame(Seq((1, 1), (1, 2), (2, 1), (2, 1), (2, 3), (3, 2), + * (3, 3))).toDF("key", "value") + * val ct = df.stat.crosstab("key", "value") + * ct.show() + * +---------+---+---+---+ + * |key_value| 1| 2| 3| + * +---------+---+---+---+ + * | 2| 2| 0| 1| + * | 1| 1| 1| 0| + * | 3| 0| 1| 1| + * +---------+---+---+---+ + * }}} + * * @since 1.4.0 */ def crosstab(col1: String, col2: String): DataFrame = { @@ -112,6 +147,32 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * than 1e-4. * @return A Local DataFrame with the Array of frequent items for each column. * + * {{{ + * val rows = Seq.tabulate(100) { i => + * if (i % 2 == 0) (1, -1.0) else (i, i * -1.0) + * } + * val df = sqlContext.createDataFrame(rows).toDF("a", "b") + * // find the items with a frequency greater than 0.4 (observed 40% of the time) for columns + * // "a" and "b" + * val freqSingles = df.stat.freqItems(Array("a", "b"), 0.4) + * freqSingles.show() + * +-----------+-------------+ + * |a_freqItems| b_freqItems| + * +-----------+-------------+ + * | [1, 99]|[-1.0, -99.0]| + * +-----------+-------------+ + * // find the pair of items with a frequency greater than 0.1 in columns "a" and "b" + * val pairDf = df.select(struct("a", "b").as("a-b")) + * val freqPairs = pairDf.stat.freqItems(Array("a-b"), 0.1) + * freqPairs.select(explode($"a-b_freqItems").as("freq_ab")).show() + * +----------+ + * | freq_ab| + * +----------+ + * | [1,-1.0]| + * | ... | + * +----------+ + * }}} + * * @since 1.4.0 */ def freqItems(cols: Array[String], support: Double): DataFrame = { @@ -147,6 +208,32 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * @param cols the names of the columns to search frequent items in. * @return A Local DataFrame with the Array of frequent items for each column. * + * {{{ + * val rows = Seq.tabulate(100) { i => + * if (i % 2 == 0) (1, -1.0) else (i, i * -1.0) + * } + * val df = sqlContext.createDataFrame(rows).toDF("a", "b") + * // find the items with a frequency greater than 0.4 (observed 40% of the time) for columns + * // "a" and "b" + * val freqSingles = df.stat.freqItems(Seq("a", "b"), 0.4) + * freqSingles.show() + * +-----------+-------------+ + * |a_freqItems| b_freqItems| + * +-----------+-------------+ + * | [1, 99]|[-1.0, -99.0]| + * +-----------+-------------+ + * // find the pair of items with a frequency greater than 0.1 in columns "a" and "b" + * val pairDf = df.select(struct("a", "b").as("a-b")) + * val freqPairs = pairDf.stat.freqItems(Seq("a-b"), 0.1) + * freqPairs.select(explode($"a-b_freqItems").as("freq_ab")).show() + * +----------+ + * | freq_ab| + * +----------+ + * | [1,-1.0]| + * | ... 
| + * +----------+ + * }}} + * * @since 1.4.0 */ def freqItems(cols: Seq[String], support: Double): DataFrame = { @@ -180,6 +267,20 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * @tparam T stratum type * @return a new [[DataFrame]] that represents the stratified sample * + * {{{ + * val df = sqlContext.createDataFrame(Seq((1, 1), (1, 2), (2, 1), (2, 1), (2, 3), (3, 2), + * (3, 3))).toDF("key", "value") + * val fractions = Map(1 -> 1.0, 3 -> 0.5) + * df.stat.sampleBy("key", fractions, 36L).show() + * +---+-----+ + * |key|value| + * +---+-----+ + * | 1| 1| + * | 1| 2| + * | 3| 2| + * +---+-----+ + * }}} + * * @since 1.5.0 */ def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DataFrame = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 7e3318cefe62..03867beb7822 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -19,12 +19,15 @@ package org.apache.spark.sql import java.util.Properties +import scala.collection.JavaConverters._ + import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation} +import org.apache.spark.sql.catalyst.plans.logical.{Project, InsertIntoTable} +import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, ResolvedDataSource} -import org.apache.spark.sql.jdbc.{JDBCWriteDetails, JdbcUtils} +import org.apache.spark.sql.sources.HadoopFsRelation /** @@ -108,7 +111,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def options(options: java.util.Map[String, String]): DataFrameWriter = { - this.options(scala.collection.JavaConversions.mapAsScalaMap(options)) + this.options(options.asScala) this } @@ -160,21 +163,42 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def insertInto(tableName: String): Unit = { - insertInto(new SqlParser().parseTableIdentifier(tableName)) + insertInto(SqlParser.parseTableIdentifier(tableName)) } private def insertInto(tableIdent: TableIdentifier): Unit = { - val partitions = partitioningColumns.map(_.map(col => col -> (None: Option[String])).toMap) + val partitions = normalizedParCols.map(_.map(col => col -> (None: Option[String])).toMap) val overwrite = mode == SaveMode.Overwrite + + // A partitioned relation's schema can be different from the input logicalPlan, since + // partition columns are all moved after data columns. We Project to adjust the ordering. + // TODO: this belongs to the analyzer. 
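Not part of the patch: a sketch of how the partition-column normalization above is exercised from the writer API; "events" is a placeholder table partitioned by (year, month).
{{{
import org.apache.spark.sql.SaveMode

df.write
  .partitionBy("year", "month")
  .mode(SaveMode.Append)
  .saveAsTable("events")

// insertInto relies on the reordering above: data columns first, partition columns last.
df.write.mode(SaveMode.Append).insertInto("events")
}}}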
+ val input = normalizedParCols.map { parCols => + val (inputPartCols, inputDataCols) = df.logicalPlan.output.partition { attr => + parCols.contains(attr.name) + } + Project(inputDataCols ++ inputPartCols, df.logicalPlan) + }.getOrElse(df.logicalPlan) + df.sqlContext.executePlan( InsertIntoTable( - UnresolvedRelation(tableIdent.toSeq), + UnresolvedRelation(tableIdent), partitions.getOrElse(Map.empty[String, Option[String]]), - df.logicalPlan, + input, overwrite, ifNotExists = false)).toRdd } + private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { parCols => + parCols.map { col => + df.logicalPlan.output + .map(_.name) + .find(df.sqlContext.analyzer.resolver(_, col)) + .getOrElse(throw new AnalysisException(s"Partition column $col not found in existing " + + s"columns (${df.logicalPlan.output.map(_.name).mkString(", ")})")) + } + } + /** * Saves the content of the [[DataFrame]] as the specified table. * @@ -185,14 +209,20 @@ final class DataFrameWriter private[sql](df: DataFrame) { * When `mode` is `Append`, the schema of the [[DataFrame]] need to be * the same as that of the existing table, and format or options will be ignored. * + * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input + * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC + * and Parquet), the table is persisted in a Hive compatible format, which means other systems + * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL + * specific format. + * * @since 1.4.0 */ def saveAsTable(tableName: String): Unit = { - saveAsTable(new SqlParser().parseTableIdentifier(tableName)) + saveAsTable(SqlParser.parseTableIdentifier(tableName)) } private def saveAsTable(tableIdent: TableIdentifier): Unit = { - val tableExists = df.sqlContext.catalog.tableExists(tableIdent.toSeq) + val tableExists = df.sqlContext.catalog.tableExists(tableIdent) (tableExists, mode) match { case (true, SaveMode.Ignore) => @@ -211,7 +241,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { case _ => val cmd = CreateTableUsingAsSelect( - tableIdent.unquotedString, + tableIdent, source, temporary = false, partitioningColumns.map(_.toArray).getOrElse(Array.empty[String]), @@ -235,12 +265,20 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @param connectionProperties JDBC database connection arguments, a list of arbitrary string * tag/value. Normally at least a "user" and "password" property * should be included. + * + * @since 1.4.0 */ def jdbc(url: String, table: String, connectionProperties: Properties): Unit = { - val conn = JdbcUtils.createConnection(url, connectionProperties) + val props = new Properties() + extraOptions.foreach { case (key, value) => + props.put(key, value) + } + // connectionProperties should override settings in extraOptions + props.putAll(connectionProperties) + val conn = JdbcUtils.createConnection(url, props) try { - var tableExists = JdbcUtils.tableExists(conn, table) + var tableExists = JdbcUtils.tableExists(conn, url, table) if (mode == SaveMode.Ignore && tableExists) { return @@ -257,15 +295,15 @@ final class DataFrameWriter private[sql](df: DataFrame) { // Create the table if the table didn't exist. 
if (!tableExists) { - val schema = JDBCWriteDetails.schemaString(df, url) + val schema = JdbcUtils.schemaString(df, url) val sql = s"CREATE TABLE $table ($schema)" - conn.prepareStatement(sql).executeUpdate() + conn.createStatement.executeUpdate(sql) } } finally { conn.close() } - JDBCWriteDetails.saveTable(df, url, table, connectionProperties) + JdbcUtils.saveTable(df, url, table, props) } /** @@ -302,6 +340,22 @@ final class DataFrameWriter private[sql](df: DataFrame) { */ def orc(path: String): Unit = format("orc").save(path) + /** + * Saves the content of the [[DataFrame]] in a text file at the specified path. + * The DataFrame must have only one column that is of string type. + * Each row becomes a new line in the output file. For example: + * {{{ + * // Scala: + * df.write.text("/path/to/output") + * + * // Java: + * df.write().text("/path/to/output") + * }}} + * + * @since 1.6.0 + */ + def text(path: String): Unit = format("text").save(path) + /////////////////////////////////////////////////////////////////////////////////////// // Builder pattern config options /////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala new file mode 100644 index 000000000000..79b4244ac0cd --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -0,0 +1,792 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.collection.JavaConverters._ + +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.function._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.encoders._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias +import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.{Queryable, QueryExecution} +import org.apache.spark.sql.types.StructType +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils + +/** + * :: Experimental :: + * A [[Dataset]] is a strongly typed collection of objects that can be transformed in parallel + * using functional or relational operations. + * + * A [[Dataset]] differs from an [[RDD]] in the following ways: + * - Internally, a [[Dataset]] is represented by a Catalyst logical plan and the data is stored + * in the encoded form. This representation allows for additional logical operations and + * enables many operations (sorting, shuffling, etc.) to be performed without deserializing to + * an object. 
+ * - The creation of a [[Dataset]] requires the presence of an explicit [[Encoder]] that can be + * used to serialize the object into a binary format. Encoders are also capable of mapping the + * schema of a given object to the Spark SQL type system. In contrast, RDDs rely on runtime + * reflection based serialization. Operations that change the type of object stored in the + * dataset also need an encoder for the new type. + * + * A [[Dataset]] can be thought of as a specialized DataFrame, where the elements map to a specific + * JVM object type, instead of to a generic [[Row]] container. A DataFrame can be transformed into + * specific Dataset by calling `df.as[ElementType]`. Similarly you can transform a strongly-typed + * [[Dataset]] to a generic DataFrame by calling `ds.toDF()`. + * + * COMPATIBILITY NOTE: Long term we plan to make [[DataFrame]] extend `Dataset[Row]`. However, + * making this change to the class hierarchy would break the function signatures for the existing + * functional operations (map, flatMap, etc). As such, this class should be considered a preview + * of the final API. Changes will be made to the interface after Spark 1.6. + * + * @since 1.6.0 + */ +@Experimental +class Dataset[T] private[sql]( + @transient override val sqlContext: SQLContext, + @transient override val queryExecution: QueryExecution, + tEncoder: Encoder[T]) extends Queryable with Serializable { + + /** + * An unresolved version of the internal encoder for the type of this [[Dataset]]. This one is + * marked implicit so that we can use it when constructing new [[Dataset]] objects that have the + * same object type (that will be possibly resolved to a different schema). + */ + private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder) + + /** The encoder for this [[Dataset]] that has been resolved to its output schema. */ + private[sql] val resolvedTEncoder: ExpressionEncoder[T] = + unresolvedTEncoder.resolve(logicalPlan.output, OuterScopes.outerScopes) + + /** + * The encoder where the expressions used to construct an object from an input row have been + * bound to the ordinals of this [[Dataset]]'s output schema. + */ + private[sql] val boundTEncoder = resolvedTEncoder.bind(logicalPlan.output) + + private implicit def classTag = resolvedTEncoder.clsTag + + private[sql] def this(sqlContext: SQLContext, plan: LogicalPlan)(implicit encoder: Encoder[T]) = + this(sqlContext, new QueryExecution(sqlContext, plan), encoder) + + /** + * Returns the schema of the encoded form of the objects in this [[Dataset]]. + * @since 1.6.0 + */ + override def schema: StructType = resolvedTEncoder.schema + + /** + * Prints the schema of the underlying [[Dataset]] to the console in a nice tree format. + * @since 1.6.0 + */ + override def printSchema(): Unit = toDF().printSchema() + + /** + * Prints the plans (logical and physical) to the console for debugging purposes. + * @since 1.6.0 + */ + override def explain(extended: Boolean): Unit = toDF().explain(extended) + + /** + * Prints the physical plan to the console for debugging purposes. + * @since 1.6.0 + */ + override def explain(): Unit = toDF().explain() + + /* ************* * + * Conversions * + * ************* */ + + /** + * Returns a new [[Dataset]] where each record has been mapped on to the specified type. 
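As a sketch of the conversion described above (the case class and column names are illustrative, and `sqlContext.implicits._` is assumed to be imported):

{{{
case class Person(name: String, age: Long)

import sqlContext.implicits._
val ds: Dataset[Person] = df.as[Person]   // fields matched to columns by name
val backToDF: DataFrame = ds.toDF()       // back to the untyped Row-based API
}}}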
The + * method used to map columns depend on the type of `U`: + * - When `U` is a class, fields for the class will be mapped to columns of the same name + * (case sensitivity is determined by `spark.sql.caseSensitive`) + * - When `U` is a tuple, the columns will be be mapped by ordinal (i.e. the first column will + * be assigned to `_1`). + * - When `U` is a primitive type (i.e. String, Int, etc). then the first column of the + * [[DataFrame]] will be used. + * + * If the schema of the [[DataFrame]] does not match the desired `U` type, you can use `select` + * along with `alias` or `as` to rearrange or rename as required. + * @since 1.6.0 + */ + def as[U : Encoder]: Dataset[U] = { + new Dataset(sqlContext, queryExecution, encoderFor[U]) + } + + /** + * Applies a logical alias to this [[Dataset]] that can be used to disambiguate columns that have + * the same name after two Datasets have been joined. + * @since 1.6.0 + */ + def as(alias: String): Dataset[T] = withPlan(Subquery(alias, _)) + + /** + * Converts this strongly typed collection of data to generic Dataframe. In contrast to the + * strongly typed objects that Dataset operations work on, a Dataframe returns generic [[Row]] + * objects that allow fields to be accessed by ordinal or name. + */ + // This is declared with parentheses to prevent the Scala compiler from treating + // `ds.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. + def toDF(): DataFrame = DataFrame(sqlContext, logicalPlan) + + /** + * Returns this [[Dataset]]. + * @since 1.6.0 + */ + // This is declared with parentheses to prevent the Scala compiler from treating + // `ds.toDS("1")` as invoking this toDF and then apply on the returned Dataset. + def toDS(): Dataset[T] = this + + /** + * Converts this [[Dataset]] to an [[RDD]]. + * @since 1.6.0 + */ + def rdd: RDD[T] = { + queryExecution.toRdd.mapPartitions { iter => + iter.map(boundTEncoder.fromRow) + } + } + + /** + * Returns the number of elements in the [[Dataset]]. + * @since 1.6.0 + */ + def count(): Long = toDF().count() + + /** + * Displays the content of this [[Dataset]] in a tabular form. Strings more than 20 characters + * will be truncated, and all cells will be aligned right. For example: + * {{{ + * year month AVG('Adj Close) MAX('Adj Close) + * 1980 12 0.503218 0.595103 + * 1981 01 0.523289 0.570307 + * 1982 02 0.436504 0.475256 + * 1983 03 0.410516 0.442194 + * 1984 04 0.450090 0.483521 + * }}} + * @param numRows Number of rows to show + * + * @since 1.6.0 + */ + def show(numRows: Int): Unit = show(numRows, truncate = true) + + /** + * Displays the top 20 rows of [[Dataset]] in a tabular form. Strings more than 20 characters + * will be truncated, and all cells will be aligned right. + * + * @since 1.6.0 + */ + def show(): Unit = show(20) + + /** + * Displays the top 20 rows of [[Dataset]] in a tabular form. + * + * @param truncate Whether truncate long strings. If true, strings more than 20 characters will + * be truncated and all cells will be aligned right + * + * @since 1.6.0 + */ + def show(truncate: Boolean): Unit = show(20, truncate) + + /** + * Displays the [[Dataset]] in a tabular form. For example: + * {{{ + * year month AVG('Adj Close) MAX('Adj Close) + * 1980 12 0.503218 0.595103 + * 1981 01 0.523289 0.570307 + * 1982 02 0.436504 0.475256 + * 1983 03 0.410516 0.442194 + * 1984 04 0.450090 0.483521 + * }}} + * @param numRows Number of rows to show + * @param truncate Whether truncate long strings. 
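A sketch of the ordinal-based tuple mapping and the `show(numRows, truncate)` variant described above (column names are illustrative; `$` comes from the implicits):

{{{
val pairs: Dataset[(String, Long)] = df.select($"name", $"age").as[(String, Long)]
pairs.show(5, truncate = false)   // print up to 5 rows without truncating long strings
}}}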
If true, strings more than 20 characters will + * be truncated and all cells will be aligned right + * + * @since 1.6.0 + */ + // scalastyle:off println + def show(numRows: Int, truncate: Boolean): Unit = println(showString(numRows, truncate)) + // scalastyle:on println + + /** + * Compose the string representing rows for output + * @param _numRows Number of rows to show + * @param truncate Whether truncate long strings and align cells right + */ + override private[sql] def showString(_numRows: Int, truncate: Boolean = true): String = { + val numRows = _numRows.max(0) + val takeResult = take(numRows + 1) + val hasMoreData = takeResult.length > numRows + val data = takeResult.take(numRows) + + // For array values, replace Seq and Array with square brackets + // For cells that are beyond 20 characters, replace it with the first 17 and "..." + val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: (data.map { + case r: Row => r + case tuple: Product => Row.fromTuple(tuple) + case o => Row(o) + } map { row => + row.toSeq.map { cell => + val str = cell match { + case null => "null" + case binary: Array[Byte] => binary.map("%02X".format(_)).mkString("[", " ", "]") + case array: Array[_] => array.mkString("[", ", ", "]") + case seq: Seq[_] => seq.mkString("[", ", ", "]") + case _ => cell.toString + } + if (truncate && str.length > 20) str.substring(0, 17) + "..." else str + }: Seq[String] + }) + + formatString ( rows, numRows, hasMoreData, truncate ) + } + + /** + * Returns a new [[Dataset]] that has exactly `numPartitions` partitions. + * @since 1.6.0 + */ + def repartition(numPartitions: Int): Dataset[T] = withPlan { + Repartition(numPartitions, shuffle = true, _) + } + + /** + * Returns a new [[Dataset]] that has exactly `numPartitions` partitions. + * Similar to coalesce defined on an [[RDD]], this operation results in a narrow dependency, e.g. + * if you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of + * the 100 new partitions will claim 10 of the current partitions. + * @since 1.6.0 + */ + def coalesce(numPartitions: Int): Dataset[T] = withPlan { + Repartition(numPartitions, shuffle = false, _) + } + + /* *********************** * + * Functional Operations * + * *********************** */ + + /** + * Concise syntax for chaining custom transformations. + * {{{ + * def featurize(ds: Dataset[T]) = ... + * + * dataset + * .transform(featurize) + * .transform(...) + * }}} + * @since 1.6.0 + */ + def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = t(this) + + /** + * (Scala-specific) + * Returns a new [[Dataset]] that only contains elements where `func` returns `true`. + * @since 1.6.0 + */ + def filter(func: T => Boolean): Dataset[T] = mapPartitions(_.filter(func)) + + /** + * (Java-specific) + * Returns a new [[Dataset]] that only contains elements where `func` returns `true`. + * @since 1.6.0 + */ + def filter(func: FilterFunction[T]): Dataset[T] = filter(t => func.call(t)) + + /** + * (Scala-specific) + * Returns a new [[Dataset]] that contains the result of applying `func` to each element. + * @since 1.6.0 + */ + def map[U : Encoder](func: T => U): Dataset[U] = mapPartitions(_.map(func)) + + /** + * (Java-specific) + * Returns a new [[Dataset]] that contains the result of applying `func` to each element. 
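A sketch of the functional operations introduced above (assumes `sqlContext.implicits._` is in scope):

{{{
val ints = Seq(1, 2, 3, 4).toDS()
val odds    = ints.filter(_ % 2 == 1)                  // Dataset[Int]
val doubled = ints.map(_ * 2)                          // encoder for Int resolved implicitly
val shifted = ints.mapPartitions(it => it.map(_ + 1))  // per-partition transformation
}}}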
+ * @since 1.6.0 + */ + def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = + map(t => func.call(t))(encoder) + + /** + * (Scala-specific) + * Returns a new [[Dataset]] that contains the result of applying `func` to each partition. + * @since 1.6.0 + */ + def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = { + new Dataset[U]( + sqlContext, + MapPartitions[T, U]( + func, + resolvedTEncoder, + encoderFor[U], + encoderFor[U].schema.toAttributes, + logicalPlan)) + } + + /** + * (Java-specific) + * Returns a new [[Dataset]] that contains the result of applying `func` to each partition. + * @since 1.6.0 + */ + def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = { + val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).iterator.asScala + mapPartitions(func)(encoder) + } + + /** + * (Scala-specific) + * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]], + * and then flattening the results. + * @since 1.6.0 + */ + def flatMap[U : Encoder](func: T => TraversableOnce[U]): Dataset[U] = + mapPartitions(_.flatMap(func)) + + /** + * (Java-specific) + * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]], + * and then flattening the results. + * @since 1.6.0 + */ + def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { + val func: (T) => Iterable[U] = x => f.call(x).asScala + flatMap(func)(encoder) + } + + /* ************** * + * Side effects * + * ************** */ + + /** + * (Scala-specific) + * Runs `func` on each element of this [[Dataset]]. + * @since 1.6.0 + */ + def foreach(func: T => Unit): Unit = rdd.foreach(func) + + /** + * (Java-specific) + * Runs `func` on each element of this [[Dataset]]. + * @since 1.6.0 + */ + def foreach(func: ForeachFunction[T]): Unit = foreach(func.call(_)) + + /** + * (Scala-specific) + * Runs `func` on each partition of this [[Dataset]]. + * @since 1.6.0 + */ + def foreachPartition(func: Iterator[T] => Unit): Unit = rdd.foreachPartition(func) + + /** + * (Java-specific) + * Runs `func` on each partition of this [[Dataset]]. + * @since 1.6.0 + */ + def foreachPartition(func: ForeachPartitionFunction[T]): Unit = + foreachPartition(it => func.call(it.asJava)) + + /* ************* * + * Aggregation * + * ************* */ + + /** + * (Scala-specific) + * Reduces the elements of this [[Dataset]] using the specified binary function. The given `func` + * must be commutative and associative or the result may be non-deterministic. + * @since 1.6.0 + */ + def reduce(func: (T, T) => T): T = rdd.reduce(func) + + /** + * (Java-specific) + * Reduces the elements of this Dataset using the specified binary function. The given `func` + * must be commutative and associative or the result may be non-deterministic. + * @since 1.6.0 + */ + def reduce(func: ReduceFunction[T]): T = reduce(func.call(_, _)) + + /** + * (Scala-specific) + * Returns a [[GroupedDataset]] where the data is grouped by the given key `func`. + * @since 1.6.0 + */ + def groupBy[K : Encoder](func: T => K): GroupedDataset[K, T] = { + val inputPlan = logicalPlan + val withGroupingKey = AppendColumns(func, resolvedTEncoder, inputPlan) + val executed = sqlContext.executePlan(withGroupingKey) + + new GroupedDataset( + encoderFor[K], + encoderFor[T], + executed, + inputPlan.output, + withGroupingKey.newColumns) + } + + /** + * Returns a [[GroupedDataset]] where the data is grouped by the given [[Column]] expressions. 
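A sketch of grouping by a key function and of `reduce`, as described above; the resulting [[GroupedDataset]] is covered further below:

{{{
val words = Seq("a", "b", "a", "c").toDS()
val byWord: GroupedDataset[String, String] = words.groupBy(w => w)

val total = Seq(1, 2, 3).toDS().reduce(_ + _)   // reduce needs a commutative, associative function
}}}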
+ * @since 1.6.0 + */ + @scala.annotation.varargs + def groupBy(cols: Column*): GroupedDataset[Row, T] = { + val withKeyColumns = logicalPlan.output ++ cols.map(_.expr).map(UnresolvedAlias) + val withKey = Project(withKeyColumns, logicalPlan) + val executed = sqlContext.executePlan(withKey) + + val dataAttributes = executed.analyzed.output.dropRight(cols.size) + val keyAttributes = executed.analyzed.output.takeRight(cols.size) + + new GroupedDataset( + RowEncoder(keyAttributes.toStructType), + encoderFor[T], + executed, + dataAttributes, + keyAttributes) + } + + /** + * (Java-specific) + * Returns a [[GroupedDataset]] where the data is grouped by the given key `func`. + * @since 1.6.0 + */ + def groupBy[K](func: MapFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] = + groupBy(func.call(_))(encoder) + + /* ****************** * + * Typed Relational * + * ****************** */ + + /** + * Returns a new [[DataFrame]] by selecting a set of column based expressions. + * {{{ + * df.select($"colA", $"colB" + 1) + * }}} + * @since 1.6.0 + */ + // Copied from Dataframe to make sure we don't have invalid overloads. + @scala.annotation.varargs + protected def select(cols: Column*): DataFrame = toDF().select(cols: _*) + + /** + * Returns a new [[Dataset]] by computing the given [[Column]] expression for each element. + * + * {{{ + * val ds = Seq(1, 2, 3).toDS() + * val newDS = ds.select(expr("value + 1").as[Int]) + * }}} + * @since 1.6.0 + */ + def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = { + new Dataset[U1]( + sqlContext, + Project( + c1.withInputType( + boundTEncoder, + logicalPlan.output).named :: Nil, + logicalPlan)) + } + + /** + * Internal helper function for building typed selects that return tuples. For simplicity and + * code reuse, we do this without the help of the type system and then use helper functions + * that cast appropriately for the user facing interface. + */ + protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { + val encoders = columns.map(_.encoder) + val namedColumns = + columns.map(_.withInputType(resolvedTEncoder, logicalPlan.output).named) + val execution = new QueryExecution(sqlContext, Project(namedColumns, logicalPlan)) + + new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders)) + } + + /** + * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. + * @since 1.6.0 + */ + def select[U1, U2](c1: TypedColumn[T, U1], c2: TypedColumn[T, U2]): Dataset[(U1, U2)] = + selectUntyped(c1, c2).asInstanceOf[Dataset[(U1, U2)]] + + /** + * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. + * @since 1.6.0 + */ + def select[U1, U2, U3]( + c1: TypedColumn[T, U1], + c2: TypedColumn[T, U2], + c3: TypedColumn[T, U3]): Dataset[(U1, U2, U3)] = + selectUntyped(c1, c2, c3).asInstanceOf[Dataset[(U1, U2, U3)]] + + /** + * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. + * @since 1.6.0 + */ + def select[U1, U2, U3, U4]( + c1: TypedColumn[T, U1], + c2: TypedColumn[T, U2], + c3: TypedColumn[T, U3], + c4: TypedColumn[T, U4]): Dataset[(U1, U2, U3, U4)] = + selectUntyped(c1, c2, c3, c4).asInstanceOf[Dataset[(U1, U2, U3, U4)]] + + /** + * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. 
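A sketch of the typed `select` variants above, using `expr(...).as[...]` to produce `TypedColumn`s (assumes `org.apache.spark.sql.functions.expr` and the implicits are imported; `value` is the default column name of a primitive Dataset):

{{{
val ds = Seq(1, 2, 3).toDS()
val plusOne: Dataset[Int]        = ds.select(expr("value + 1").as[Int])
val pair:    Dataset[(Int, Int)] = ds.select(expr("value").as[Int], expr("value * 2").as[Int])
}}}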
+ * @since 1.6.0 + */ + def select[U1, U2, U3, U4, U5]( + c1: TypedColumn[T, U1], + c2: TypedColumn[T, U2], + c3: TypedColumn[T, U3], + c4: TypedColumn[T, U4], + c5: TypedColumn[T, U5]): Dataset[(U1, U2, U3, U4, U5)] = + selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[Dataset[(U1, U2, U3, U4, U5)]] + + /** + * Returns a new [[Dataset]] by sampling a fraction of records. + * @since 1.6.0 + */ + def sample(withReplacement: Boolean, fraction: Double, seed: Long) : Dataset[T] = + withPlan(Sample(0.0, fraction, withReplacement, seed, _)) + + /** + * Returns a new [[Dataset]] by sampling a fraction of records, using a random seed. + * @since 1.6.0 + */ + def sample(withReplacement: Boolean, fraction: Double) : Dataset[T] = { + sample(withReplacement, fraction, Utils.random.nextLong) + } + + /* **************** * + * Set operations * + * **************** */ + + /** + * Returns a new [[Dataset]] that contains only the unique elements of this [[Dataset]]. + * + * Note that, equality checking is performed directly on the encoded representation of the data + * and thus is not affected by a custom `equals` function defined on `T`. + * @since 1.6.0 + */ + def distinct: Dataset[T] = withPlan(Distinct) + + /** + * Returns a new [[Dataset]] that contains only the elements of this [[Dataset]] that are also + * present in `other`. + * + * Note that, equality checking is performed directly on the encoded representation of the data + * and thus is not affected by a custom `equals` function defined on `T`. + * @since 1.6.0 + */ + def intersect(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Intersect) + + /** + * Returns a new [[Dataset]] that contains the elements of both this and the `other` [[Dataset]] + * combined. + * + * Note that, this function is not a typical set union operation, in that it does not eliminate + * duplicate items. As such, it is analogous to `UNION ALL` in SQL. + * @since 1.6.0 + */ + def union(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Union) + + /** + * Returns a new [[Dataset]] where any elements present in `other` have been removed. + * + * Note that, equality checking is performed directly on the encoded representation of the data + * and thus is not affected by a custom `equals` function defined on `T`. + * @since 1.6.0 + */ + def subtract(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Except) + + /* ****** * + * Joins * + * ****** */ + + /** + * Joins this [[Dataset]] returning a [[Tuple2]] for each pair where `condition` evaluates to + * true. + * + * This is similar to the relation `join` function with one important difference in the + * result schema. Since `joinWith` preserves objects present on either side of the join, the + * result schema is similarly nested into a tuple under the column names `_1` and `_2`. + * + * This type of join can be useful both for preserving type-safety with the original object + * types as well as working with relational data where either side of the join has column + * names in common. + * + * @param other Right side of the join. + * @param condition Join expression. + * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`. 
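A sketch of the set operations above; as noted, equality is checked on the encoded representation:

{{{
val left  = Seq(1, 2, 3).toDS()
val right = Seq(2, 3, 4).toDS()

left.distinct            // deduplicated elements
left.intersect(right)    // 2, 3
left.union(right)        // keeps duplicates, analogous to UNION ALL
left.subtract(right)     // 1
}}}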
+ * @since 1.6.0 + */ + def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = { + val left = this.logicalPlan + val right = other.logicalPlan + + val joined = sqlContext.executePlan(Join(left, right, joinType = + JoinType(joinType), Some(condition.expr))) + val leftOutput = joined.analyzed.output.take(left.output.length) + val rightOutput = joined.analyzed.output.takeRight(right.output.length) + + val leftData = this.unresolvedTEncoder match { + case e if e.flat => Alias(leftOutput.head, "_1")() + case _ => Alias(CreateStruct(leftOutput), "_1")() + } + val rightData = other.unresolvedTEncoder match { + case e if e.flat => Alias(rightOutput.head, "_2")() + case _ => Alias(CreateStruct(rightOutput), "_2")() + } + + implicit val tuple2Encoder: Encoder[(T, U)] = + ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder) + withPlan[(T, U)](other) { (left, right) => + Project( + leftData :: rightData :: Nil, + joined.analyzed) + } + } + + /** + * Using inner equi-join to join this [[Dataset]] returning a [[Tuple2]] for each pair + * where `condition` evaluates to true. + * + * @param other Right side of the join. + * @param condition Join expression. + * @since 1.6.0 + */ + def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = { + joinWith(other, condition, "inner") + } + + /* ************************** * + * Gather to Driver Actions * + * ************************** */ + + /** + * Returns the first element in this [[Dataset]]. + * @since 1.6.0 + */ + def first(): T = take(1).head + + /** + * Returns an array that contains all the elements in this [[Dataset]]. + * + * Running collect requires moving all the data into the application's driver process, and + * doing so on a very large [[Dataset]] can crash the driver process with OutOfMemoryError. + * + * For Java API, use [[collectAsList]]. + * @since 1.6.0 + */ + def collect(): Array[T] = { + // This is different from Dataset.rdd in that it collects Rows, and then runs the encoders + // to convert the rows into objects of type T. + queryExecution.toRdd.map(_.copy()).collect().map(boundTEncoder.fromRow) + } + + /** + * Returns an array that contains all the elements in this [[Dataset]]. + * + * Running collect requires moving all the data into the application's driver process, and + * doing so on a very large [[Dataset]] can crash the driver process with OutOfMemoryError. + * + * For Java API, use [[collectAsList]]. + * @since 1.6.0 + */ + def collectAsList(): java.util.List[T] = collect().toSeq.asJava + + /** + * Returns the first `num` elements of this [[Dataset]] as an array. + * + * Running take requires moving data into the application's driver process, and doing so with + * a very large `num` can crash the driver process with OutOfMemoryError. + * @since 1.6.0 + */ + def take(num: Int): Array[T] = withPlan(Limit(Literal(num), _)).collect() + + /** + * Returns the first `num` elements of this [[Dataset]] as an array. + * + * Running take requires moving data into the application's driver process, and doing so with + * a very large `num` can crash the driver process with OutOfMemoryError. + * @since 1.6.0 + */ + def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*) + + /** + * Persist this [[Dataset]] with the default storage level (`MEMORY_AND_DISK`). 
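A sketch of `joinWith`, using `as(alias)` from above to disambiguate the join columns; the classes and column names are illustrative:

{{{
case class Person(name: String, age: Long)
case class Role(name: String, role: String)

val people = Seq(Person("Ann", 30)).toDS().as("p")
val roles  = Seq(Role("Ann", "engineer")).toDS().as("r")

// Each result row is a pair of the original objects, nested under _1 and _2.
val joined: Dataset[(Person, Role)] =
  people.joinWith(roles, $"p.name" === $"r.name", "inner")
}}}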
+ * @since 1.6.0 + */ + def persist(): this.type = { + sqlContext.cacheManager.cacheQuery(this) + this + } + + /** + * Persist this [[Dataset]] with the default storage level (`MEMORY_AND_DISK`). + * @since 1.6.0 + */ + def cache(): this.type = persist() + + /** + * Persist this [[Dataset]] with the given storage level. + * @param newLevel One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`, + * `MEMORY_AND_DISK_SER`, `DISK_ONLY`, `MEMORY_ONLY_2`, + * `MEMORY_AND_DISK_2`, etc. + * @group basic + * @since 1.6.0 + */ + def persist(newLevel: StorageLevel): this.type = { + sqlContext.cacheManager.cacheQuery(this, None, newLevel) + this + } + + /** + * Mark the [[Dataset]] as non-persistent, and remove all blocks for it from memory and disk. + * @param blocking Whether to block until all blocks are deleted. + * @since 1.6.0 + */ + def unpersist(blocking: Boolean): this.type = { + sqlContext.cacheManager.tryUncacheQuery(this, blocking) + this + } + + /** + * Mark the [[Dataset]] as non-persistent, and remove all blocks for it from memory and disk. + * @since 1.6.0 + */ + def unpersist(): this.type = unpersist(blocking = false) + + /* ******************** * + * Internal Functions * + * ******************** */ + + private[sql] def logicalPlan: LogicalPlan = queryExecution.analyzed + + private[sql] def withPlan(f: LogicalPlan => LogicalPlan): Dataset[T] = + new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), tEncoder) + + private[sql] def withPlan[R : Encoder]( + other: Dataset[_])( + f: (LogicalPlan, LogicalPlan) => LogicalPlan): Dataset[R] = + new Dataset[R](sqlContext, f(logicalPlan, other.logicalPlan)) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala new file mode 100644 index 000000000000..08097e9f0208 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala @@ -0,0 +1,35 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +/** + * A container for a [[Dataset]], used for implicit conversions. + * + * To use this, import implicit conversions in SQL: + * {{{ + * import sqlContext.implicits._ + * }}} + * + * @since 1.6.0 + */ +case class DatasetHolder[T] private[sql](private val ds: Dataset[T]) { + + // This is declared with parentheses to prevent the Scala compiler from treating + // `rdd.toDS("1")` as invoking this toDS and then apply on the returned Dataset. 
+ def toDS(): Dataset[T] = ds +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 99d557b03a03..13341a88a6b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -17,58 +17,31 @@ package org.apache.spark.sql -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.language.implicitConversions import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, Star} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedAlias, UnresolvedAttribute, Star} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate} +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.logical.{Pivot, Rollup, Cube, Aggregate} import org.apache.spark.sql.types.NumericType -/** - * Companion object for GroupedData - */ -private[sql] object GroupedData { - def apply( - df: DataFrame, - groupingExprs: Seq[Expression], - groupType: GroupType): GroupedData = { - new GroupedData(df, groupingExprs, groupType: GroupType) - } - - /** - * The Grouping Type - */ - private[sql] trait GroupType - - /** - * To indicate it's the GroupBy - */ - private[sql] object GroupByType extends GroupType - - /** - * To indicate it's the CUBE - */ - private[sql] object CubeType extends GroupType - - /** - * To indicate it's the ROLLUP - */ - private[sql] object RollupType extends GroupType -} /** * :: Experimental :: * A set of methods for aggregations on a [[DataFrame]], created by [[DataFrame.groupBy]]. * + * The main method is the agg function, which has multiple variants. This class also contains + * convenience some first order statistics such as mean, sum for convenience. + * * @since 1.3.0 */ @Experimental class GroupedData protected[sql]( df: DataFrame, groupingExprs: Seq[Expression], - private val groupType: GroupedData.GroupType) { + groupType: GroupedData.GroupType) { private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = { val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) { @@ -77,14 +50,8 @@ class GroupedData protected[sql]( aggExprs } - val aliasedAgg = aggregates.map { - // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we - // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to - // make it a NamedExpression. 
- case u: UnresolvedAttribute => UnresolvedAlias(u) - case expr: NamedExpression => expr - case expr: Expression => Alias(expr, expr.prettyString)() - } + val aliasedAgg = aggregates.map(alias) + groupType match { case GroupedData.GroupByType => DataFrame( @@ -95,10 +62,23 @@ class GroupedData protected[sql]( case GroupedData.CubeType => DataFrame( df.sqlContext, Cube(groupingExprs, df.logicalPlan, aliasedAgg)) + case GroupedData.PivotType(pivotCol, values) => + val aliasedGrps = groupingExprs.map(alias) + DataFrame( + df.sqlContext, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan)) } } - private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => Expression) + // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we + // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to + // make it a NamedExpression. + private[this] def alias(expr: Expression): NamedExpression = expr match { + case u: UnresolvedAttribute => UnresolvedAlias(u) + case expr: NamedExpression => expr + case expr: Expression => Alias(expr, expr.prettyString)() + } + + private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => AggregateFunction) : DataFrame = { val columnExprs = if (colNames.isEmpty) { @@ -116,22 +96,28 @@ class GroupedData protected[sql]( namedExpr } } - toDF(columnExprs.map(f)) + toDF(columnExprs.map(expr => f(expr).toAggregateExpression())) } private[this] def strToExpr(expr: String): (Expression => Expression) = { - expr.toLowerCase match { - case "avg" | "average" | "mean" => Average - case "max" => Max - case "min" => Min - case "sum" => Sum - case "count" | "size" => - // Turn count(*) into count(1) - (inputExpr: Expression) => inputExpr match { - case s: Star => Count(Literal(1)) - case _ => Count(inputExpr) - } + val exprToFunc: (Expression => Expression) = { + (inputExpr: Expression) => expr.toLowerCase match { + // We special handle a few cases that have alias that are not in function registry. + case "avg" | "average" | "mean" => + UnresolvedFunction("avg", inputExpr :: Nil, isDistinct = false) + case "stddev" | "std" => + UnresolvedFunction("stddev", inputExpr :: Nil, isDistinct = false) + // Also special handle count because we need to take care count(*). + case "count" | "size" => + // Turn count(*) into count(1) + inputExpr match { + case s: Star => Count(Literal(1)).toAggregateExpression() + case _ => Count(inputExpr).toAggregateExpression() + } + case name => UnresolvedFunction(name, inputExpr :: Nil, isDistinct = false) + } } + (inputExpr: Expression) => exprToFunc(inputExpr) } /** @@ -188,7 +174,7 @@ class GroupedData protected[sql]( * @since 1.3.0 */ def agg(exprs: java.util.Map[String, String]): DataFrame = { - agg(exprs.toMap) + agg(exprs.asScala.toMap) } /** @@ -233,7 +219,7 @@ class GroupedData protected[sql]( * * @since 1.3.0 */ - def count(): DataFrame = toDF(Seq(Alias(Count(Literal(1)), "count")())) + def count(): DataFrame = toDF(Seq(Alias(Count(Literal(1)).toAggregateExpression(), "count")())) /** * Compute the average value for each numeric columns for each group. This is an alias for `avg`. @@ -294,4 +280,136 @@ class GroupedData protected[sql]( def sum(colNames: String*): DataFrame = { aggregateNumericColumns(colNames : _*)(Sum) } + + /** + * Pivots a column of the current [[DataFrame]] and perform the specified aggregation. 
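A sketch of the string-based aggregations handled by the alias resolution above, including the new `stddev`/`std` aliases (the DataFrame and column names are hypothetical):

{{{
df.groupBy("dept").agg(
  "salary" -> "avg",      // resolved through the "avg"/"average"/"mean" alias
  "salary" -> "stddev",   // resolved through the "stddev"/"std" alias
  "salary" -> "count")
}}}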
+ * There are two versions of pivot function: one that requires the caller to specify the list + * of distinct values to pivot on, and one that does not. The latter is more concise but less + * efficient, because Spark needs to first compute the list of distinct values internally. + * + * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column + * df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings") + * + * // Or without specifying column values (less efficient) + * df.groupBy("year").pivot("course").sum("earnings") + * }}} + * + * @param pivotColumn Name of the column to pivot. + * @since 1.6.0 + */ + def pivot(pivotColumn: String): GroupedData = { + // This is to prevent unintended OOM errors when the number of distinct values is large + val maxValues = df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES) + // Get the distinct values of the column and sort them so its consistent + val values = df.select(pivotColumn) + .distinct() + .sort(pivotColumn) // ensure that the output columns are in a consistent logical order + .map(_.get(0)) + .take(maxValues + 1) + .toSeq + + if (values.length > maxValues) { + throw new AnalysisException( + s"The pivot column $pivotColumn has more than $maxValues distinct values, " + + "this could indicate an error. " + + s"If this was intended, set ${SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key} " + + "to at least the number of distinct values of the pivot column.") + } + + pivot(pivotColumn, values) + } + + /** + * Pivots a column of the current [[DataFrame]] and perform the specified aggregation. + * There are two versions of pivot function: one that requires the caller to specify the list + * of distinct values to pivot on, and one that does not. The latter is more concise but less + * efficient, because Spark needs to first compute the list of distinct values internally. + * + * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column + * df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings") + * + * // Or without specifying column values (less efficient) + * df.groupBy("year").pivot("course").sum("earnings") + * }}} + * + * @param pivotColumn Name of the column to pivot. + * @param values List of values that will be translated to columns in the output DataFrame. + * @since 1.6.0 + */ + def pivot(pivotColumn: String, values: Seq[Any]): GroupedData = { + groupType match { + case GroupedData.GroupByType => + new GroupedData( + df, + groupingExprs, + GroupedData.PivotType(df.resolve(pivotColumn), values.map(Literal.apply))) + case _: GroupedData.PivotType => + throw new UnsupportedOperationException("repeated pivots are not supported") + case _ => + throw new UnsupportedOperationException("pivot is only supported after a groupBy") + } + } + + /** + * Pivots a column of the current [[DataFrame]] and perform the specified aggregation. + * There are two versions of pivot function: one that requires the caller to specify the list + * of distinct values to pivot on, and one that does not. The latter is more concise but less + * efficient, because Spark needs to first compute the list of distinct values internally. 
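A sketch of the two pivot variants and the safety limit enforced above (DataFrame and column names are hypothetical; `sum` and `avg` come from `org.apache.spark.sql.functions`):

{{{
// Providing the values avoids the extra job that computes the distinct pivot values.
df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")

// Without explicit values, the number of discovered values is capped by
// spark.sql.pivotMaxValues (default 10000).
sqlContext.setConf("spark.sql.pivotMaxValues", "500")
df.groupBy("year").pivot("course").agg(sum("earnings"), avg("earnings"))
}}}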
+ * + * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column + * df.groupBy("year").pivot("course", Arrays.asList("dotNET", "Java")).sum("earnings"); + * + * // Or without specifying column values (less efficient) + * df.groupBy("year").pivot("course").sum("earnings"); + * }}} + * + * @param pivotColumn Name of the column to pivot. + * @param values List of values that will be translated to columns in the output DataFrame. + * @since 1.6.0 + */ + def pivot(pivotColumn: String, values: java.util.List[Any]): GroupedData = { + pivot(pivotColumn, values.asScala) + } +} + + +/** + * Companion object for GroupedData. + */ +private[sql] object GroupedData { + + def apply( + df: DataFrame, + groupingExprs: Seq[Expression], + groupType: GroupType): GroupedData = { + new GroupedData(df, groupingExprs, groupType: GroupType) + } + + /** + * The Grouping Type + */ + private[sql] trait GroupType + + /** + * To indicate it's the GroupBy + */ + private[sql] object GroupByType extends GroupType + + /** + * To indicate it's the CUBE + */ + private[sql] object CubeType extends GroupType + + /** + * To indicate it's the ROLLUP + */ + private[sql] object RollupType extends GroupType + + /** + * To indicate it's the PIVOT + */ + private[sql] case class PivotType(pivotCol: Expression, values: Seq[Literal]) extends GroupType } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala new file mode 100644 index 000000000000..4bf0b256fcb4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -0,0 +1,335 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.collection.JavaConverters._ + +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.function._ +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, OuterScopes} +import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct, Attribute} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.expressions.Aggregator + +/** + * :: Experimental :: + * A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not + * construct a [[GroupedDataset]] directly, but should instead call `groupBy` on an existing + * [[Dataset]]. + * + * COMPATIBILITY NOTE: Long term we plan to make [[GroupedDataset)]] extend `GroupedData`. However, + * making this change to the class hierarchy would break some function signatures. As such, this + * class should be considered a preview of the final API. Changes will be made to the interface + * after Spark 1.6. 
+ * + * @since 1.6.0 + */ +@Experimental +class GroupedDataset[K, V] private[sql]( + kEncoder: Encoder[K], + vEncoder: Encoder[V], + val queryExecution: QueryExecution, + private val dataAttributes: Seq[Attribute], + private val groupingAttributes: Seq[Attribute]) extends Serializable { + + // Similar to [[Dataset]], we use unresolved encoders for later composition and resolved encoders + // when constructing new logical plans that will operate on the output of the current + // queryexecution. + + private implicit val unresolvedKEncoder = encoderFor(kEncoder) + private implicit val unresolvedVEncoder = encoderFor(vEncoder) + + private val resolvedKEncoder = + unresolvedKEncoder.resolve(groupingAttributes, OuterScopes.outerScopes) + private val resolvedVEncoder = + unresolvedVEncoder.resolve(dataAttributes, OuterScopes.outerScopes) + + private def logicalPlan = queryExecution.analyzed + private def sqlContext = queryExecution.sqlContext + + private def groupedData = + new GroupedData( + new DataFrame(sqlContext, logicalPlan), groupingAttributes, GroupedData.GroupByType) + + /** + * Returns a new [[GroupedDataset]] where the type of the key has been mapped to the specified + * type. The mapping of key columns to the type follows the same rules as `as` on [[Dataset]]. + * + * @since 1.6.0 + */ + def keyAs[L : Encoder]: GroupedDataset[L, V] = + new GroupedDataset( + encoderFor[L], + unresolvedVEncoder, + queryExecution, + dataAttributes, + groupingAttributes) + + /** + * Returns a [[Dataset]] that contains each unique key. + * + * @since 1.6.0 + */ + def keys: Dataset[K] = { + new Dataset[K]( + sqlContext, + Distinct( + Project(groupingAttributes, logicalPlan))) + } + + /** + * Applies the given function to each group of data. For each unique group, the function will + * be passed the group key and an iterator that contains all of the elements in the group. The + * function can return an iterator containing elements of an arbitrary type which will be returned + * as a new [[Dataset]]. + * + * This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. If an application intends to perform an aggregation over each + * key, it is best to use the reduce function or an [[Aggregator]]. + * + * Internally, the implementation will spill to disk if any given group is too large to fit into + * memory. However, users must take care to avoid materializing the whole iterator for a group + * (for example, by calling `toList`) unless they are sure that this is possible given the memory + * constraints of their cluster. + * + * @since 1.6.0 + */ + def flatMapGroups[U : Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = { + new Dataset[U]( + sqlContext, + MapGroups( + f, + resolvedKEncoder, + resolvedVEncoder, + groupingAttributes, + logicalPlan)) + } + + /** + * Applies the given function to each group of data. For each unique group, the function will + * be passed the group key and an iterator that contains all of the elements in the group. The + * function can return an iterator containing elements of an arbitrary type which will be returned + * as a new [[Dataset]]. + * + * This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. If an application intends to perform an aggregation over each + * key, it is best to use the reduce function or an [[Aggregator]]. 
+ * + * Internally, the implementation will spill to disk if any given group is too large to fit into + * memory. However, users must take care to avoid materializing the whole iterator for a group + * (for example, by calling `toList`) unless they are sure that this is possible given the memory + * constraints of their cluster. + * + * @since 1.6.0 + */ + def flatMapGroups[U](f: FlatMapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { + flatMapGroups((key, data) => f.call(key, data.asJava).asScala)(encoder) + } + + /** + * Applies the given function to each group of data. For each unique group, the function will + * be passed the group key and an iterator that contains all of the elements in the group. The + * function can return an element of arbitrary type which will be returned as a new [[Dataset]]. + * + * This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. If an application intends to perform an aggregation over each + * key, it is best to use the reduce function or an [[Aggregator]]. + * + * Internally, the implementation will spill to disk if any given group is too large to fit into + * memory. However, users must take care to avoid materializing the whole iterator for a group + * (for example, by calling `toList`) unless they are sure that this is possible given the memory + * constraints of their cluster. + * + * @since 1.6.0 + */ + def mapGroups[U : Encoder](f: (K, Iterator[V]) => U): Dataset[U] = { + val func = (key: K, it: Iterator[V]) => Iterator(f(key, it)) + flatMapGroups(func) + } + + /** + * Applies the given function to each group of data. For each unique group, the function will + * be passed the group key and an iterator that contains all of the elements in the group. The + * function can return an element of arbitrary type which will be returned as a new [[Dataset]]. + * + * This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. If an application intends to perform an aggregation over each + * key, it is best to use the reduce function or an [[Aggregator]]. + * + * Internally, the implementation will spill to disk if any given group is too large to fit into + * memory. However, users must take care to avoid materializing the whole iterator for a group + * (for example, by calling `toList`) unless they are sure that this is possible given the memory + * constraints of their cluster. + * + * @since 1.6.0 + */ + def mapGroups[U](f: MapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { + mapGroups((key, data) => f.call(key, data.asJava))(encoder) + } + + /** + * Reduces the elements of each group of data using the specified binary function. + * The given function must be commutative and associative or the result may be non-deterministic. + * + * @since 1.6.0 + */ + def reduce(f: (V, V) => V): Dataset[(K, V)] = { + val func = (key: K, it: Iterator[V]) => Iterator((key, it.reduce(f))) + + implicit val resultEncoder = ExpressionEncoder.tuple(unresolvedKEncoder, unresolvedVEncoder) + flatMapGroups(func) + } + + /** + * Reduces the elements of each group of data using the specified binary function. + * The given function must be commutative and associative or the result may be non-deterministic. + * + * @since 1.6.0 + */ + def reduce(f: ReduceFunction[V]): Dataset[(K, V)] = { + reduce(f.call _) + } + + // This is here to prevent us from adding overloads that would be ambiguous. 
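A sketch of the per-group operations above (implicits assumed in scope):

{{{
val ds      = Seq(("a", 1), ("a", 2), ("b", 10)).toDS()
val grouped = ds.groupBy(_._1)

val sizes: Dataset[(String, Int)]           = grouped.mapGroups((k, it) => (k, it.size))
val sums:  Dataset[(String, (String, Int))] = grouped.reduce((x, y) => (x._1, x._2 + y._2))
}}}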
+ @scala.annotation.varargs + private def agg(exprs: Column*): DataFrame = + groupedData.agg(withEncoder(exprs.head), exprs.tail.map(withEncoder): _*) + + private def withEncoder(c: Column): Column = c match { + case tc: TypedColumn[_, _] => + tc.withInputType(resolvedVEncoder.bind(dataAttributes), dataAttributes) + case _ => c + } + + /** + * Internal helper function for building typed aggregations that return tuples. For simplicity + * and code reuse, we do this without the help of the type system and then use helper functions + * that cast appropriately for the user facing interface. + * TODO: does not handle aggrecations that return nonflat results, + */ + protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { + val encoders = columns.map(_.encoder) + val namedColumns = + columns.map( + _.withInputType(resolvedVEncoder, dataAttributes).named) + val keyColumn = if (resolvedKEncoder.flat) { + assert(groupingAttributes.length == 1) + groupingAttributes.head + } else { + Alias(CreateStruct(groupingAttributes), "key")() + } + val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, logicalPlan) + val execution = new QueryExecution(sqlContext, aggregate) + + new Dataset( + sqlContext, + execution, + ExpressionEncoder.tuple(unresolvedKEncoder +: encoders)) + } + + /** + * Computes the given aggregation, returning a [[Dataset]] of tuples for each unique key + * and the result of computing this aggregation over all elements in the group. + * + * @since 1.6.0 + */ + def agg[U1](col1: TypedColumn[V, U1]): Dataset[(K, U1)] = + aggUntyped(col1).asInstanceOf[Dataset[(K, U1)]] + + /** + * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key + * and the result of computing these aggregations over all elements in the group. + * + * @since 1.6.0 + */ + def agg[U1, U2](col1: TypedColumn[V, U1], col2: TypedColumn[V, U2]): Dataset[(K, U1, U2)] = + aggUntyped(col1, col2).asInstanceOf[Dataset[(K, U1, U2)]] + + /** + * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key + * and the result of computing these aggregations over all elements in the group. + * + * @since 1.6.0 + */ + def agg[U1, U2, U3]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3]): Dataset[(K, U1, U2, U3)] = + aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, U1, U2, U3)]] + + /** + * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key + * and the result of computing these aggregations over all elements in the group. + * + * @since 1.6.0 + */ + def agg[U1, U2, U3, U4]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3], + col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)] = + aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, U1, U2, U3, U4)]] + + /** + * Returns a [[Dataset]] that contains a tuple with each key and the number of items present + * for that key. + * + * @since 1.6.0 + */ + def count(): Dataset[(K, Long)] = agg(functions.count("*").as(ExpressionEncoder[Long])) + + /** + * Applies the given function to each cogrouped data. For each unique group, the function will + * be passed the grouping key and 2 iterators containing all elements in the group from + * [[Dataset]] `this` and `other`. The function can return an iterator containing elements of an + * arbitrary type which will be returned as a new [[Dataset]]. 
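A sketch of `cogroup`, whose implementation follows below; the datasets and keys are illustrative:

{{{
val clicks    = Seq(("u1", "home"), ("u1", "search")).toDS().groupBy(_._1)
val purchases = Seq(("u1", 9.99)).toDS().groupBy(_._1)

// For each key, both groups are handed to the function together.
val summary: Dataset[(String, Int, Int)] =
  clicks.cogroup(purchases) { (key, cs, ps) =>
    Iterator((key, cs.size, ps.size))
  }
}}}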
+ * + * @since 1.6.0 + */ + def cogroup[U, R : Encoder]( + other: GroupedDataset[K, U])( + f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = { + new Dataset[R]( + sqlContext, + CoGroup( + f, + this.resolvedKEncoder, + this.resolvedVEncoder, + other.resolvedVEncoder, + this.groupingAttributes, + other.groupingAttributes, + this.logicalPlan, + other.logicalPlan)) + } + + /** + * Applies the given function to each cogrouped data. For each unique group, the function will + * be passed the grouping key and 2 iterators containing all elements in the group from + * [[Dataset]] `this` and `other`. The function can return an iterator containing elements of an + * arbitrary type which will be returned as a new [[Dataset]]. + * + * @since 1.6.0 + */ + def cogroup[U, R]( + other: GroupedDataset[K, U], + f: CoGroupFunction[K, V, U, R], + encoder: Encoder[R]): Dataset[R] = { + cogroup(other)((key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 387960c4b482..3d819262859f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import java.util.Properties import scala.collection.immutable -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.parquet.hadoop.ParquetOutputCommitter @@ -186,6 +186,16 @@ private[spark] object SQLConf { import SQLConfEntry._ + val ALLOW_MULTIPLE_CONTEXTS = booleanConf("spark.sql.allowMultipleContexts", + defaultValue = Some(true), + doc = "When set to true, creating multiple SQLContexts/HiveContexts is allowed." + + "When set to false, only one SQLContext/HiveContext is allowed to be created " + + "through the constructor (new SQLContexts/HiveContexts created through newSession " + + "method is allowed). Please note that this conf needs to be set in Spark Conf. 
Once" + + "a SQLContext/HiveContext has been created, changing the value of this conf will not" + + "have effect.", + isPublic = true) + val COMPRESS_CACHED = booleanConf("spark.sql.inMemoryColumnarStorage.compressed", defaultValue = Some(true), doc = "When set to true Spark SQL will automatically select a compression codec for each " + @@ -200,7 +210,7 @@ private[spark] object SQLConf { val IN_MEMORY_PARTITION_PRUNING = booleanConf("spark.sql.inMemoryColumnarStorage.partitionPruning", - defaultValue = Some(false), + defaultValue = Some(true), doc = "When true, enable partition pruning for in-memory columnar tables.", isPublic = false) @@ -223,14 +233,29 @@ private[spark] object SQLConf { defaultValue = Some(200), doc = "The default number of partitions to use when shuffling data for joins or aggregations.") - val CODEGEN_ENABLED = booleanConf("spark.sql.codegen", - defaultValue = Some(true), - doc = "When true, code will be dynamically generated at runtime for expression evaluation in" + - " a specific query.") + val SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE = + longConf("spark.sql.adaptive.shuffle.targetPostShuffleInputSize", + defaultValue = Some(64 * 1024 * 1024), + doc = "The target post-shuffle input size in bytes of a task.") + + val ADAPTIVE_EXECUTION_ENABLED = booleanConf("spark.sql.adaptive.enabled", + defaultValue = Some(false), + doc = "When true, enable adaptive query execution.") + + val SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS = + intConf("spark.sql.adaptive.minNumPostShufflePartitions", + defaultValue = Some(-1), + doc = "The advisory minimal number of post-shuffle partitions provided to " + + "ExchangeCoordinator. This setting is used in our test to make sure we " + + "have enough parallelism to expose issues that will not be exposed with a " + + "single partition. When the value is a non-positive value, this setting will" + + "not be provided to ExchangeCoordinator.", + isPublic = false) - val UNSAFE_ENABLED = booleanConf("spark.sql.unsafe.enabled", + val SUBEXPRESSION_ELIMINATION_ENABLED = booleanConf("spark.sql.subexpressionElimination.enabled", defaultValue = Some(true), - doc = "When true, use the new optimized Tungsten physical execution backend.") + doc = "When true, common subexpressions will be eliminated.", + isPublic = false) val DIALECT = stringConf( "spark.sql.dialect", @@ -283,12 +308,11 @@ private[spark] object SQLConf { defaultValue = Some(true), doc = "Enables Parquet filter push-down optimization when set to true.") - val PARQUET_FOLLOW_PARQUET_FORMAT_SPEC = booleanConf( - key = "spark.sql.parquet.followParquetFormatSpec", + val PARQUET_WRITE_LEGACY_FORMAT = booleanConf( + key = "spark.sql.parquet.writeLegacyFormat", defaultValue = Some(false), doc = "Whether to follow Parquet's format specification when converting Parquet schema to " + - "Spark SQL schema and vice versa.", - isPublic = false) + "Spark SQL schema and vice versa.") val PARQUET_OUTPUT_COMMITTER_CLASS = stringConf( key = "spark.sql.parquet.output.committer.class", @@ -297,15 +321,19 @@ private[spark] object SQLConf { "subclass of org.apache.hadoop.mapreduce.OutputCommitter. Typically, it's also a subclass " + "of org.apache.parquet.hadoop.ParquetOutputCommitter. NOTE: 1. Instead of SQLConf, this " + "option must be set in Hadoop Configuration. 2. This option overrides " + - "\"spark.sql.sources.outputCommitterClass\"." 
- ) + "\"spark.sql.sources.outputCommitterClass\".") + + val PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED = booleanConf( + key = "spark.sql.parquet.enableUnsafeRowRecordReader", + defaultValue = Some(true), + doc = "Enables using the custom ParquetUnsafeRowRecordReader.") val ORC_FILTER_PUSHDOWN_ENABLED = booleanConf("spark.sql.orc.filterPushdown", defaultValue = Some(false), doc = "When true, enable filter pushdown for ORC files.") val HIVE_VERIFY_PARTITION_PATH = booleanConf("spark.sql.hive.verifyPartitionPath", - defaultValue = Some(true), + defaultValue = Some(false), doc = "") val HIVE_METASTORE_PARTITION_PRUNING = booleanConf("spark.sql.hive.metastorePartitionPruning", @@ -313,6 +341,15 @@ private[spark] object SQLConf { doc = "When true, some predicates will be pushed down into the Hive metastore so that " + "unmatching partitions can be eliminated earlier.") + val NATIVE_VIEW = booleanConf("spark.sql.nativeView", + defaultValue = Some(false), + doc = "When true, CREATE VIEW will be handled by Spark SQL instead of Hive native commands. " + + "Note that this function is experimental and should ony be used when you are using " + + "non-hive-compatible tables written by Spark SQL. The SQL string used to create " + + "view should be fully qualified, i.e. use `tbl1`.`col1` instead of `*` whenever " + + "possible, or you may get wrong result.", + isPublic = false) + val COLUMN_NAME_OF_CORRUPT_RECORD = stringConf("spark.sql.columnNameOfCorruptRecord", defaultValue = Some("_corrupt_record"), doc = "") @@ -321,17 +358,6 @@ private[spark] object SQLConf { defaultValue = Some(5 * 60), doc = "Timeout in seconds for the broadcast wait time in broadcast joins.") - // Options that control which operators can be chosen by the query planner. These should be - // considered hints and may be ignored by future versions of Spark SQL. - val EXTERNAL_SORT = booleanConf("spark.sql.planner.externalSort", - defaultValue = Some(true), - doc = "When true, performs sorts spilling to disk as needed otherwise sort each partition in" + - " memory.") - - val SORTMERGE_JOIN = booleanConf("spark.sql.planner.sortMergeJoin", - defaultValue = Some(true), - doc = "When true, use sort merge join (as opposed to hash join) by default for large joins.") - // This is only used for the thriftserver val THRIFTSERVER_POOL = stringConf("spark.sql.thriftserver.scheduler.pool", doc = "Set a Fair Scheduler pool for a JDBC client session") @@ -359,17 +385,21 @@ private[spark] object SQLConf { "storing additional schema information in Hive's metastore.", isPublic = false) - // Whether to perform partition discovery when loading external data sources. Default to true. val PARTITION_DISCOVERY_ENABLED = booleanConf("spark.sql.sources.partitionDiscovery.enabled", defaultValue = Some(true), - doc = "When true, automtically discover data partitions.") + doc = "When true, automatically discover data partitions.") - // Whether to perform partition column type inference. Default to true. val PARTITION_COLUMN_TYPE_INFERENCE = booleanConf("spark.sql.sources.partitionColumnTypeInference.enabled", defaultValue = Some(true), doc = "When true, automatically infer the data types for partitioned columns.") + val PARTITION_MAX_FILES = + intConf("spark.sql.sources.maxConcurrentWrites", + defaultValue = Some(5), + doc = "The maximum number of concurrent files to open before falling back on sorting when " + + "writing out files using dynamic partitioning.") + // The output committer class used by HadoopFsRelation. 
The specified class needs to be a // subclass of org.apache.hadoop.mapreduce.OutputCommitter. // @@ -406,19 +436,27 @@ private[spark] object SQLConf { defaultValue = Some(true), isPublic = false) - val USE_SQL_AGGREGATE2 = booleanConf("spark.sql.useAggregate2", - defaultValue = Some(true), doc = "") - - val USE_SQL_SERIALIZER2 = booleanConf( - "spark.sql.useSerializer2", - defaultValue = Some(true), isPublic = false) + val DATAFRAME_PIVOT_MAX_VALUES = intConf( + "spark.sql.pivotMaxValues", + defaultValue = Some(10000), + doc = "When doing a pivot without specifying values for the pivot column this is the maximum " + + "number of (distinct) values that will be collected without error." + ) - val ADVANCED_SQL_OPTIMIZATION = booleanConf( - "spark.sql.advancedOptimization", - defaultValue = Some(true), isPublic = false) + val RUN_SQL_ON_FILES = booleanConf("spark.sql.runSQLOnFiles", + defaultValue = Some(true), + isPublic = false, + doc = "When true, we could use `datasource`.`path` as table in SQL query" + ) object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" + val EXTERNAL_SORT = "spark.sql.planner.externalSort" + val USE_SQL_AGGREGATE2 = "spark.sql.useAggregate2" + val TUNGSTEN_ENABLED = "spark.sql.tungsten.enabled" + val CODEGEN_ENABLED = "spark.sql.codegen" + val UNSAFE_ENABLED = "spark.sql.unsafe.enabled" + val SORTMERGE_JOIN = "spark.sql.planner.sortMergeJoin" } } @@ -431,7 +469,6 @@ private[spark] object SQLConf { * * SQLConf is thread-safe (internally synchronized, so safe to be used in multiple threads). */ - private[sql] class SQLConf extends Serializable with CatalystConf { import SQLConf._ @@ -466,6 +503,14 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS) + private[spark] def targetPostShuffleInputSize: Long = + getConf(SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE) + + private[spark] def adaptiveExecutionEnabled: Boolean = getConf(ADAPTIVE_EXECUTION_ENABLED) + + private[spark] def minNumPostShufflePartitions: Int = + getConf(SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS) + private[spark] def parquetFilterPushDown: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_ENABLED) private[spark] def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) @@ -474,21 +519,12 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def metastorePartitionPruning: Boolean = getConf(HIVE_METASTORE_PARTITION_PRUNING) - private[spark] def externalSortEnabled: Boolean = getConf(EXTERNAL_SORT) - - private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN) - - private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED) + private[spark] def nativeView: Boolean = getConf(NATIVE_VIEW) def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) - private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED) - - private[spark] def useSqlAggregate2: Boolean = getConf(USE_SQL_AGGREGATE2) - - private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2) - - private[spark] def advancedSqlOptimizations: Boolean = getConf(ADVANCED_SQL_OPTIMIZATION) + private[spark] def subexpressionEliminationEnabled: Boolean = + getConf(SUBEXPRESSION_ELIMINATION_ENABLED) private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD) @@ -499,7 +535,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def isParquetINT96AsTimestamp: Boolean = getConf(PARQUET_INT96_AS_TIMESTAMP) - 
private[spark] def followParquetFormatSpec: Boolean = getConf(PARQUET_FOLLOW_PARQUET_FORMAT_SPEC) + private[spark] def writeLegacyParquetFormat: Boolean = getConf(PARQUET_WRITE_LEGACY_FORMAT) private[spark] def inMemoryPartitionPruning: Boolean = getConf(IN_MEMORY_PARTITION_PRUNING) @@ -529,11 +565,13 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def dataFrameRetainGroupColumns: Boolean = getConf(DATAFRAME_RETAIN_GROUP_COLUMNS) + private[spark] def runSQLOnFile: Boolean = getConf(RUN_SQL_ON_FILES) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ def setConf(props: Properties): Unit = settings.synchronized { - props.foreach { case (k, v) => setConfString(k, v) } + props.asScala.foreach { case (k, v) => setConfString(k, v) } } /** Set the given Spark SQL configuration property using a `string` value. */ @@ -603,24 +641,25 @@ private[sql] class SQLConf extends Serializable with CatalystConf { * Return all the configuration properties that have been set (i.e. not the default). * This creates a new copy of the config properties in the form of a Map. */ - def getAllConfs: immutable.Map[String, String] = settings.synchronized { settings.toMap } + def getAllConfs: immutable.Map[String, String] = + settings.synchronized { settings.asScala.toMap } /** * Return all the configuration definitions that have been defined in [[SQLConf]]. Each * definition contains key, defaultValue and doc. */ def getAllDefinedConfs: Seq[(String, String, String)] = sqlConfEntries.synchronized { - sqlConfEntries.values.filter(_.isPublic).map { entry => + sqlConfEntries.values.asScala.filter(_.isPublic).map { entry => (entry.key, entry.defaultValueString, entry.doc) }.toSeq } private[spark] def unsetConf(key: String): Unit = { - settings -= key + settings.remove(key) } private[spark] def unsetConf(entry: SQLConfEntry[_]): Unit = { - settings -= entry.key + settings.remove(entry.key) } private[spark] def clear(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 31e2b508d485..db286ea8700b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -17,22 +17,22 @@ package org.apache.spark.sql -import java.beans.Introspector +import java.beans.{BeanInfo, Introspector} import java.util.Properties import java.util.concurrent.atomic.AtomicReference -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.immutable -import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal -import org.apache.spark.SparkContext import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDD +import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} import org.apache.spark.sql.SQLConf.SQLConfEntry import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} @@ -41,11 +41,13 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.{InternalRow, 
ParserDialect, _} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.optimizer.FilterNullsInJoinKey +import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab} import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.sql.util.ExecutionListenerManager +import org.apache.spark.sql.{execution => sparkexecution} import org.apache.spark.util.Utils +import org.apache.spark.{SparkContext, SparkException} /** * The entry point for working with structured data (rows and columns) in Spark. Allows the @@ -58,22 +60,64 @@ import org.apache.spark.util.Utils * @groupname specificdata Specific Data Sources * @groupname config Configuration * @groupname dataframes Custom DataFrame Creation - * @groupname Ungrouped Support functions for language integrated queries. + * @groupname Ungrouped Support functions for language integrated queries * * @since 1.0.0 */ -class SQLContext(@transient val sparkContext: SparkContext) - extends org.apache.spark.Logging - with Serializable { +class SQLContext private[sql]( + @transient val sparkContext: SparkContext, + @transient protected[sql] val cacheManager: CacheManager, + @transient private[sql] val listener: SQLListener, + val isRootContext: Boolean) + extends org.apache.spark.Logging with Serializable { self => + def this(sparkContext: SparkContext) = { + this(sparkContext, new CacheManager, SQLContext.createListenerAndUI(sparkContext), true) + } def this(sparkContext: JavaSparkContext) = this(sparkContext.sc) + // If spark.sql.allowMultipleContexts is true, we will throw an exception if a user + // wants to create a new root SQLContext (a SLQContext that is not created by newSession). + private val allowMultipleContexts = + sparkContext.conf.getBoolean( + SQLConf.ALLOW_MULTIPLE_CONTEXTS.key, + SQLConf.ALLOW_MULTIPLE_CONTEXTS.defaultValue.get) + + // Assert no root SQLContext is running when allowMultipleContexts is false. + { + if (!allowMultipleContexts && isRootContext) { + SQLContext.getInstantiatedContextOption() match { + case Some(rootSQLContext) => + val errMsg = "Only one SQLContext/HiveContext may be running in this JVM. " + + s"It is recommended to use SQLContext.getOrCreate to get the instantiated " + + s"SQLContext/HiveContext. To ignore this error, " + + s"set ${SQLConf.ALLOW_MULTIPLE_CONTEXTS.key} = true in SparkConf." + throw new SparkException(errMsg) + case None => // OK + } + } + } + + /** + * Returns a SQLContext as new session, with separated SQL configurations, temporary tables, + * registered functions, but sharing the same SparkContext, CacheManager, SQLListener and SQLTab. + * + * @since 1.6.0 + */ + def newSession(): SQLContext = { + new SQLContext( + sparkContext = sparkContext, + cacheManager = cacheManager, + listener = listener, + isRootContext = false) + } + /** * @return Spark SQL configuration */ - protected[sql] def conf = currentSession().conf + protected[sql] lazy val conf = new SQLConf /** * Set Spark SQL configuration properties. @@ -135,13 +179,14 @@ class SQLContext(@transient val sparkContext: SparkContext) */ def getAllConfs: immutable.Map[String, String] = conf.getAllConfs - // TODO how to handle the temp table per user session? 
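To illustrate the session API introduced here (a sketch only; `sc` is an assumed SparkContext): a context returned by newSession() keeps its own SQLConf, temporary tables and registered functions, while sharing the SparkContext, CacheManager and SQLListener with its root context.

    val rootContext = new org.apache.spark.sql.SQLContext(sc)   // root context (isRootContext = true)
    val session = rootContext.newSession()                       // isolated conf / temp tables / functions
    session.setConf("spark.sql.shuffle.partitions", "4")         // visible only inside this session
    assert(rootContext.getConf("spark.sql.shuffle.partitions") == "200")  // root keeps the default of 200
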
+ @transient + lazy val listenerManager: ExecutionListenerManager = new ExecutionListenerManager + @transient protected[sql] lazy val catalog: Catalog = new SimpleCatalog(conf) - // TODO how to handle the temp function per user session? @transient - protected[sql] lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin + protected[sql] lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin.copy() @transient protected[sql] lazy val analyzer: Analyzer = @@ -149,7 +194,7 @@ class SQLContext(@transient val sparkContext: SparkContext) override val extendedResolutionRules = ExtractPythonUDFs :: PreInsertCastAndRename :: - Nil + (if (conf.runSQLOnFile) new ResolveDataSource(self) :: Nil else Nil) override val extendedCheckRules = Seq( datasources.PreWriteCheck(catalog) @@ -157,9 +202,7 @@ class SQLContext(@transient val sparkContext: SparkContext) } @transient - protected[sql] lazy val optimizer: Optimizer = new DefaultOptimizer { - override val extendedOperatorOptimizationRules = FilterNullsInJoinKey(self) :: Nil - } + protected[sql] lazy val optimizer: Optimizer = DefaultOptimizer @transient protected[sql] val ddlParser = new DDLParser(sqlParser.parse(_)) @@ -187,17 +230,11 @@ class SQLContext(@transient val sparkContext: SparkContext) protected[sql] def parseSql(sql: String): LogicalPlan = ddlParser.parse(sql, false) - protected[sql] def executeSql(sql: String): this.QueryExecution = executePlan(parseSql(sql)) + protected[sql] def executeSql(sql: String): + org.apache.spark.sql.execution.QueryExecution = executePlan(parseSql(sql)) - protected[sql] def executePlan(plan: LogicalPlan) = new this.QueryExecution(plan) - - @transient - protected[sql] val tlSession = new ThreadLocal[SQLSession]() { - override def initialValue: SQLSession = defaultSession - } - - @transient - protected[sql] val defaultSession = createSession() + protected[sql] def executePlan(plan: LogicalPlan) = + new sparkexecution.QueryExecution(this, plan) protected[sql] def dialectClassName = if (conf.dialect == "sql") { classOf[DefaultParserDialect].getCanonicalName @@ -205,6 +242,13 @@ class SQLContext(@transient val sparkContext: SparkContext) conf.dialect } + /** + * Add a jar to SQLContext + */ + protected[sql] def addJar(path: String): Unit = { + sparkContext.addJar(path) + } + { // We extract spark sql settings from SparkContext's conf and put them to // Spark SQL's conf. @@ -224,14 +268,11 @@ class SQLContext(@transient val sparkContext: SparkContext) conf.setConf(properties) // After we have populated SQLConf, we call setConf to populate other confs in the subclass // (e.g. hiveconf in HiveContext). - properties.foreach { + properties.asScala.foreach { case (key, value) => setConf(key, value) } } - @transient - protected[sql] val cacheManager = new CacheManager(this) - /** * :: Experimental :: * A collection of methods that are considered experimental, but can be used to hook into @@ -288,29 +329,39 @@ class SQLContext(@transient val sparkContext: SparkContext) @transient val udf: UDFRegistration = new UDFRegistration(this) - @transient - val udaf: UDAFRegistration = new UDAFRegistration(this) - /** * Returns true if the table is currently cached in-memory. * @group cachemgmt * @since 1.3.0 */ - def isCached(tableName: String): Boolean = cacheManager.isCached(tableName) + def isCached(tableName: String): Boolean = { + cacheManager.lookupCachedData(table(tableName)).nonEmpty + } + + /** + * Returns true if the [[Queryable]] is currently cached in-memory. 
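The caching entry points in this hunk now delegate to cacheManager.cacheQuery and cacheManager.lookupCachedData, with unchanged observable behaviour; an illustrative sketch (the JSON path and the "people" table name are hypothetical):

    val people = sqlContext.read.json("/tmp/people.json")   // hypothetical input file
    people.registerTempTable("people")
    sqlContext.cacheTable("people")                          // now routed through cacheManager.cacheQuery
    assert(sqlContext.isCached("people"))
    sqlContext.uncacheTable("people")                        // cacheManager.uncacheQuery under the hood
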
+ * @group cachemgmt + * @since 1.3.0 + */ + private[sql] def isCached(qName: Queryable): Boolean = { + cacheManager.lookupCachedData(qName).nonEmpty + } /** * Caches the specified table in-memory. * @group cachemgmt * @since 1.3.0 */ - def cacheTable(tableName: String): Unit = cacheManager.cacheTable(tableName) + def cacheTable(tableName: String): Unit = { + cacheManager.cacheQuery(table(tableName), Some(tableName)) + } /** * Removes the specified table from the in-memory cache. * @group cachemgmt * @since 1.3.0 */ - def uncacheTable(tableName: String): Unit = cacheManager.uncacheTable(tableName) + def uncacheTable(tableName: String): Unit = cacheManager.uncacheQuery(table(tableName)) /** * Removes all cached tables from the in-memory cache. @@ -334,108 +385,32 @@ class SQLContext(@transient val sparkContext: SparkContext) * @since 1.3.0 */ @Experimental - object implicits extends Serializable { - // scalastyle:on + object implicits extends SQLImplicits with Serializable { + protected override def _sqlContext: SQLContext = self /** * Converts $"col name" into an [[Column]]. * @since 1.3.0 */ + // This must live here to preserve binary compatibility with Spark < 1.5. implicit class StringToColumn(val sc: StringContext) { def $(args: Any*): ColumnName = { - new ColumnName(sc.s(args : _*)) + new ColumnName(sc.s(args: _*)) } } - - /** - * An implicit conversion that turns a Scala `Symbol` into a [[Column]]. - * @since 1.3.0 - */ - implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name) - - /** - * Creates a DataFrame from an RDD of case classes or tuples. - * @since 1.3.0 - */ - implicit def rddToDataFrameHolder[A <: Product : TypeTag](rdd: RDD[A]): DataFrameHolder = { - DataFrameHolder(self.createDataFrame(rdd)) - } - - /** - * Creates a DataFrame from a local Seq of Product. - * @since 1.3.0 - */ - implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = - { - DataFrameHolder(self.createDataFrame(data)) - } - - // Do NOT add more implicit conversions. They are likely to break source compatibility by - // making existing implicit conversions ambiguous. In particular, RDD[Double] is dangerous - // because of [[DoubleRDDFunctions]]. - - /** - * Creates a single column DataFrame from an RDD[Int]. - * @since 1.3.0 - */ - implicit def intRddToDataFrameHolder(data: RDD[Int]): DataFrameHolder = { - val dataType = IntegerType - val rows = data.mapPartitions { iter => - val row = new SpecificMutableRow(dataType :: Nil) - iter.map { v => - row.setInt(0, v) - row: InternalRow - } - } - DataFrameHolder( - self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) - } - - /** - * Creates a single column DataFrame from an RDD[Long]. - * @since 1.3.0 - */ - implicit def longRddToDataFrameHolder(data: RDD[Long]): DataFrameHolder = { - val dataType = LongType - val rows = data.mapPartitions { iter => - val row = new SpecificMutableRow(dataType :: Nil) - iter.map { v => - row.setLong(0, v) - row: InternalRow - } - } - DataFrameHolder( - self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) - } - - /** - * Creates a single column DataFrame from an RDD[String]. 
- * @since 1.3.0 - */ - implicit def stringRddToDataFrameHolder(data: RDD[String]): DataFrameHolder = { - val dataType = StringType - val rows = data.mapPartitions { iter => - val row = new SpecificMutableRow(dataType :: Nil) - iter.map { v => - row.update(0, UTF8String.fromString(v)) - row: InternalRow - } - } - DataFrameHolder( - self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) - } } + // scalastyle:on /** * :: Experimental :: - * Creates a DataFrame from an RDD of case classes. + * Creates a DataFrame from an RDD of Product (e.g. case classes, tuples). * * @group dataframes * @since 1.3.0 */ @Experimental def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = { - SparkPlan.currentContext.set(self) + SQLContext.setActive(self) val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] val attributeSeq = schema.toAttributes val rowRDD = RDDConversions.productToRowRdd(rdd, schema.map(_.dataType)) @@ -451,7 +426,7 @@ class SQLContext(@transient val sparkContext: SparkContext) */ @Experimental def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = { - SparkPlan.currentContext.set(self) + SQLContext.setActive(self) val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] val attributeSeq = schema.toAttributes DataFrame(self, LocalRelation.fromProduct(attributeSeq, data)) @@ -522,6 +497,29 @@ class SQLContext(@transient val sparkContext: SparkContext) DataFrame(this, logicalPlan) } + + def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = { + val enc = encoderFor[T] + val attributes = enc.schema.toAttributes + val encoded = data.map(d => enc.toRow(d).copy()) + val plan = new LocalRelation(attributes, encoded) + + new Dataset[T](this, plan) + } + + def createDataset[T : Encoder](data: RDD[T]): Dataset[T] = { + val enc = encoderFor[T] + val attributes = enc.schema.toAttributes + val encoded = data.map(d => enc.toRow(d)) + val plan = LogicalRDD(attributes, encoded)(self) + + new Dataset[T](this, plan) + } + + def createDataset[T : Encoder](data: java.util.List[T]): Dataset[T] = { + createDataset(data.asScala) + } + /** * Creates a DataFrame from an RDD[Row]. User can specify whether the input rows should be * converted to Catalyst rows. @@ -548,6 +546,20 @@ class SQLContext(@transient val sparkContext: SparkContext) createDataFrame(rowRDD.rdd, schema) } + /** + * :: DeveloperApi :: + * Creates a [[DataFrame]] from an [[java.util.List]] containing [[Row]]s using the given schema. + * It is important to make sure that the structure of every [[Row]] of the provided List matches + * the provided schema. Otherwise, there will be runtime exception. + * + * @group dataframes + * @since 1.6.0 + */ + @DeveloperApi + def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = { + DataFrame(self, LocalRelation.fromExternalRows(schema.toAttributes, rows.asScala)) + } + /** * Applies a schema to an RDD of Java Beans. * @@ -557,21 +569,12 @@ class SQLContext(@transient val sparkContext: SparkContext) * @since 1.3.0 */ def createDataFrame(rdd: RDD[_], beanClass: Class[_]): DataFrame = { - val attributeSeq = getSchema(beanClass) + val attributeSeq: Seq[AttributeReference] = getSchema(beanClass) val className = beanClass.getName val rowRdd = rdd.mapPartitions { iter => // BeanInfo is not serializable so we must rediscover it remotely for each partition. 
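The createDataset overloads added above take a Seq, an RDD or a java.util.List together with an implicit Encoder; a short sketch of the Scala variants (illustrative only; the Person case class and `sc` are assumptions, and the encoders come from the implicits object):

    import sqlContext.implicits._                 // provides newProductEncoder and the primitive encoders

    case class Person(name: String, age: Int)     // hypothetical example class

    val fromSeq  = sqlContext.createDataset(Seq(Person("Ann", 30), Person("Bob", 25)))
    val fromRdd  = sqlContext.createDataset(sc.parallelize(Seq(Person("Cal", 41))))
    val fromInts = sqlContext.createDataset(Seq(1, 2, 3))   // uses newIntEncoder
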
val localBeanInfo = Introspector.getBeanInfo(Utils.classForName(className)) - val extractors = - localBeanInfo.getPropertyDescriptors.filterNot(_.getName == "class").map(_.getReadMethod) - val methodsToConverts = extractors.zip(attributeSeq).map { case (e, attr) => - (e, CatalystTypeConverters.createToCatalystConverter(attr.dataType)) - } - iter.map { row => - new GenericInternalRow( - methodsToConverts.map { case (e, convert) => convert(e.invoke(row)) }.toArray[Any] - ): InternalRow - } + SQLContext.beansToRows(iter, localBeanInfo, attributeSeq) } DataFrame(this, LogicalRDD(attributeSeq, rowRdd)(this)) } @@ -588,6 +591,23 @@ class SQLContext(@transient val sparkContext: SparkContext) createDataFrame(rdd.rdd, beanClass) } + /** + * Applies a schema to an List of Java Beans. + * + * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, + * SELECT * queries will return the columns in an undefined order. + * @group dataframes + * @since 1.6.0 + */ + def createDataFrame(data: java.util.List[_], beanClass: Class[_]): DataFrame = { + val attrSeq = getSchema(beanClass) + val className = beanClass.getName + val beanInfo = Introspector.getBeanInfo(beanClass) + val rows = SQLContext.beansToRows(data.asScala.iterator, beanInfo, attrSeq) + DataFrame(self, LocalRelation(attrSeq, rows.toSeq)) + } + + /** * :: Experimental :: * Returns a [[DataFrameReader]] that can be used to read data in as a [[DataFrame]]. @@ -645,7 +665,7 @@ class SQLContext(@transient val sparkContext: SparkContext) tableName: String, source: String, options: java.util.Map[String, String]): DataFrame = { - createExternalTable(tableName, source, options.toMap) + createExternalTable(tableName, source, options.asScala.toMap) } /** @@ -662,9 +682,10 @@ class SQLContext(@transient val sparkContext: SparkContext) tableName: String, source: String, options: Map[String, String]): DataFrame = { + val tableIdent = SqlParser.parseTableIdentifier(tableName) val cmd = CreateTableUsing( - tableName, + tableIdent, userSpecifiedSchema = None, source, temporary = false, @@ -672,7 +693,7 @@ class SQLContext(@transient val sparkContext: SparkContext) allowExisting = false, managedIfNoPath = false) executePlan(cmd).toRdd - table(tableName) + table(tableIdent) } /** @@ -689,7 +710,7 @@ class SQLContext(@transient val sparkContext: SparkContext) source: String, schema: StructType, options: java.util.Map[String, String]): DataFrame = { - createExternalTable(tableName, source, schema, options.toMap) + createExternalTable(tableName, source, schema, options.asScala.toMap) } /** @@ -707,9 +728,10 @@ class SQLContext(@transient val sparkContext: SparkContext) source: String, schema: StructType, options: Map[String, String]): DataFrame = { + val tableIdent = SqlParser.parseTableIdentifier(tableName) val cmd = CreateTableUsing( - tableName, + tableIdent, userSpecifiedSchema = Some(schema), source, temporary = false, @@ -717,7 +739,7 @@ class SQLContext(@transient val sparkContext: SparkContext) allowExisting = false, managedIfNoPath = false) executePlan(cmd).toRdd - table(tableName) + table(tableIdent) } /** @@ -725,7 +747,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * only during the lifetime of this instance of SQLContext. 
*/ private[sql] def registerDataFrameAsTable(df: DataFrame, tableName: String): Unit = { - catalog.registerTable(Seq(tableName), df.logicalPlan) + catalog.registerTable(TableIdentifier(tableName), df.logicalPlan) } /** @@ -739,7 +761,7 @@ class SQLContext(@transient val sparkContext: SparkContext) */ def dropTempTable(tableName: String): Unit = { cacheManager.tryUncacheQuery(table(tableName)) - catalog.unregisterTable(Seq(tableName)) + catalog.unregisterTable(TableIdentifier(tableName)) } /** @@ -802,8 +824,11 @@ class SQLContext(@transient val sparkContext: SparkContext) * @since 1.3.0 */ def table(tableName: String): DataFrame = { - val tableIdent = new SqlParser().parseTableIdentifier(tableName) - DataFrame(this, catalog.lookupRelation(tableIdent.toSeq)) + table(SqlParser.parseTableIdentifier(tableName)) + } + + private def table(tableIdent: TableIdentifier): DataFrame = { + DataFrame(this, catalog.lookupRelation(tableIdent)) } /** @@ -854,77 +879,11 @@ class SQLContext(@transient val sparkContext: SparkContext) }.toArray } - protected[sql] class SparkPlanner extends SparkStrategies { - val sparkContext: SparkContext = self.sparkContext - - val sqlContext: SQLContext = self - - def codegenEnabled: Boolean = self.conf.codegenEnabled - - def unsafeEnabled: Boolean = self.conf.unsafeEnabled - - def numPartitions: Int = self.conf.numShufflePartitions - - def strategies: Seq[Strategy] = - experimental.extraStrategies ++ ( - DataSourceStrategy :: - DDLStrategy :: - TakeOrderedAndProject :: - HashAggregation :: - Aggregation :: - LeftSemiJoin :: - HashJoin :: - InMemoryScans :: - BasicOperators :: - CartesianProduct :: - BroadcastNestedLoopJoin :: Nil) - - /** - * Used to build table scan operators where complex projection and filtering are done using - * separate physical operators. This function returns the given scan operator with Project and - * Filter nodes added only when needed. For example, a Project operator is only used when the - * final desired output requires complex expressions to be evaluated or when columns can be - * further eliminated out after filtering has been done. - * - * The `prunePushedDownFilters` parameter is used to remove those filters that can be optimized - * away by the filter pushdown optimization. - * - * The required attributes for both filtering and expression evaluation are passed to the - * provided `scanBuilder` function so that it can avoid unnecessary column materialization. - */ - def pruneFilterProject( - projectList: Seq[NamedExpression], - filterPredicates: Seq[Expression], - prunePushedDownFilters: Seq[Expression] => Seq[Expression], - scanBuilder: Seq[Attribute] => SparkPlan): SparkPlan = { - - val projectSet = AttributeSet(projectList.flatMap(_.references)) - val filterSet = AttributeSet(filterPredicates.flatMap(_.references)) - val filterCondition = - prunePushedDownFilters(filterPredicates).reduceLeftOption(catalyst.expressions.And) - - // Right now we still use a projection even if the only evaluation is applying an alias - // to a column. Since this is a no-op, it could be avoided. However, using this - // optimization with the current implementation would change the output schema. - // TODO: Decouple final output schema from expression evaluation so this copy can be - // avoided safely. 
- - if (AttributeSet(projectList.map(_.toAttribute)) == projectSet && - filterSet.subsetOf(projectSet)) { - // When it is possible to just use column pruning to get the right projection and - // when the columns of this projection are enough to evaluate all filter conditions, - // just do a scan followed by a filter, with no extra project. - val scan = scanBuilder(projectList.asInstanceOf[Seq[Attribute]]) - filterCondition.map(Filter(_, scan)).getOrElse(scan) - } else { - val scan = scanBuilder((projectSet ++ filterSet).toSeq) - Project(projectList, filterCondition.map(Filter(_, scan)).getOrElse(scan)) - } - } - } + @deprecated("use org.apache.spark.sql.SparkPlanner", "1.6.0") + protected[sql] class SparkPlanner extends sparkexecution.SparkPlanner(this) @transient - protected[sql] val planner = new SparkPlanner + protected[sql] val planner: sparkexecution.SparkPlanner = new sparkexecution.SparkPlanner(this) @transient protected[sql] lazy val emptyResult = sparkContext.parallelize(Seq.empty[InternalRow], 1) @@ -941,93 +900,9 @@ class SQLContext(@transient val sparkContext: SparkContext) ) } - protected[sql] def openSession(): SQLSession = { - detachSession() - val session = createSession() - tlSession.set(session) - - session - } - - protected[sql] def currentSession(): SQLSession = { - tlSession.get() - } - - protected[sql] def createSession(): SQLSession = { - new this.SQLSession() - } - - protected[sql] def detachSession(): Unit = { - tlSession.remove() - } - - protected[sql] def setSession(session: SQLSession): Unit = { - detachSession() - tlSession.set(session) - } - - protected[sql] class SQLSession { - // Note that this is a lazy val so we can override the default value in subclasses. - protected[sql] lazy val conf: SQLConf = new SQLConf - } - - /** - * :: DeveloperApi :: - * The primary workflow for executing relational queries using Spark. Designed to allow easy - * access to the intermediate phases of query execution for developers. - */ - @DeveloperApi - protected[sql] class QueryExecution(val logical: LogicalPlan) { - def assertAnalyzed(): Unit = analyzer.checkAnalysis(analyzed) - - lazy val analyzed: LogicalPlan = analyzer.execute(logical) - lazy val withCachedData: LogicalPlan = { - assertAnalyzed() - cacheManager.useCachedData(analyzed) - } - lazy val optimizedPlan: LogicalPlan = optimizer.execute(withCachedData) - - // TODO: Don't just pick the first one... - lazy val sparkPlan: SparkPlan = { - SparkPlan.currentContext.set(self) - planner.plan(optimizedPlan).next() - } - // executedPlan should not be used to initialize any SparkPlan. It should be - // only used for execution. - lazy val executedPlan: SparkPlan = prepareForExecution.execute(sparkPlan) - - /** Internal version of the RDD. Avoids copies and has no schema */ - lazy val toRdd: RDD[InternalRow] = executedPlan.execute() - - protected def stringOrError[A](f: => A): String = - try f.toString catch { case e: Throwable => e.toString } - - def simpleString: String = - s"""== Physical Plan == - |${stringOrError(executedPlan)} - """.stripMargin.trim - - override def toString: String = { - def output = - analyzed.output.map(o => s"${o.name}: ${o.dataType.simpleString}").mkString(", ") - - // TODO previously will output RDD details by run (${stringOrError(toRdd.toDebugString)}) - // however, the `toRdd` will cause the real execution, which is not what we want. - // We need to think about how to avoid the side effect. 
- s"""== Parsed Logical Plan == - |${stringOrError(logical)} - |== Analyzed Logical Plan == - |${stringOrError(output)} - |${stringOrError(analyzed)} - |== Optimized Logical Plan == - |${stringOrError(optimizedPlan)} - |== Physical Plan == - |${stringOrError(executedPlan)} - |Code Generation: ${stringOrError(executedPlan.codegenEnabled)} - |== RDD == - """.stripMargin.trim - } - } + @deprecated("use org.apache.spark.sql.QueryExecution", "1.6.0") + protected[sql] class QueryExecution(logical: LogicalPlan) + extends sparkexecution.QueryExecution(this, logical) /** * Parses the data type in our internal string representation. The data type string should @@ -1076,33 +951,33 @@ class SQLContext(@transient val sparkContext: SparkContext) //////////////////////////////////////////////////////////////////////////// /** - * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. This will be removed in Spark 2.0. */ - @deprecated("use createDataFrame", "1.3.0") + @deprecated("Use createDataFrame. This will be removed in Spark 2.0.", "1.3.0") def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = { createDataFrame(rowRDD, schema) } /** - * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. This will be removed in Spark 2.0. */ - @deprecated("use createDataFrame", "1.3.0") + @deprecated("Use createDataFrame. This will be removed in Spark 2.0.", "1.3.0") def applySchema(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = { createDataFrame(rowRDD, schema) } /** - * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. This will be removed in Spark 2.0. */ - @deprecated("use createDataFrame", "1.3.0") + @deprecated("Use createDataFrame. This will be removed in Spark 2.0.", "1.3.0") def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = { createDataFrame(rdd, beanClass) } /** - * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. This will be removed in Spark 2.0. */ - @deprecated("use createDataFrame", "1.3.0") + @deprecated("Use createDataFrame. This will be removed in Spark 2.0.", "1.3.0") def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = { createDataFrame(rdd, beanClass) } @@ -1112,9 +987,9 @@ class SQLContext(@transient val sparkContext: SparkContext) * [[DataFrame]] if no paths are passed in. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().parquet()`. + * @deprecated As of 1.4.0, replaced by `read().parquet()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.parquet()", "1.4.0") + @deprecated("Use read.parquet(). This will be removed in Spark 2.0.", "1.4.0") @scala.annotation.varargs def parquetFile(paths: String*): DataFrame = { if (paths.isEmpty) { @@ -1129,9 +1004,9 @@ class SQLContext(@transient val sparkContext: SparkContext) * It goes through the entire dataset once to determine the schema. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). 
This will be removed in Spark 2.0.", "1.4.0") def jsonFile(path: String): DataFrame = { read.json(path) } @@ -1141,18 +1016,18 @@ class SQLContext(@transient val sparkContext: SparkContext) * returning the result as a [[DataFrame]]. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonFile(path: String, schema: StructType): DataFrame = { read.schema(schema).json(path) } /** * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonFile(path: String, samplingRatio: Double): DataFrame = { read.option("samplingRatio", samplingRatio.toString).json(path) } @@ -1163,9 +1038,9 @@ class SQLContext(@transient val sparkContext: SparkContext) * It goes through the entire dataset once to determine the schema. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonRDD(json: RDD[String]): DataFrame = read.json(json) /** @@ -1174,9 +1049,9 @@ class SQLContext(@transient val sparkContext: SparkContext) * It goes through the entire dataset once to determine the schema. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonRDD(json: JavaRDD[String]): DataFrame = read.json(json) /** @@ -1184,9 +1059,9 @@ class SQLContext(@transient val sparkContext: SparkContext) * returning the result as a [[DataFrame]]. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonRDD(json: RDD[String], schema: StructType): DataFrame = { read.schema(schema).json(json) } @@ -1196,9 +1071,9 @@ class SQLContext(@transient val sparkContext: SparkContext) * schema, returning the result as a [[DataFrame]]. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonRDD(json: JavaRDD[String], schema: StructType): DataFrame = { read.schema(schema).json(json) } @@ -1208,9 +1083,9 @@ class SQLContext(@transient val sparkContext: SparkContext) * schema, returning the result as a [[DataFrame]]. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). 
This will be removed in Spark 2.0.", "1.4.0") def jsonRDD(json: RDD[String], samplingRatio: Double): DataFrame = { read.option("samplingRatio", samplingRatio.toString).json(json) } @@ -1220,9 +1095,9 @@ class SQLContext(@transient val sparkContext: SparkContext) * schema, returning the result as a [[DataFrame]]. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonRDD(json: JavaRDD[String], samplingRatio: Double): DataFrame = { read.option("samplingRatio", samplingRatio.toString).json(json) } @@ -1232,9 +1107,9 @@ class SQLContext(@transient val sparkContext: SparkContext) * using the default data source configured by spark.sql.sources.default. * * @group genericdata - * @deprecated As of 1.4.0, replaced by `read().load(path)`. + * @deprecated As of 1.4.0, replaced by `read().load(path)`. This will be removed in Spark 2.0. */ - @deprecated("Use read.load(path)", "1.4.0") + @deprecated("Use read.load(path). This will be removed in Spark 2.0.", "1.4.0") def load(path: String): DataFrame = { read.load(path) } @@ -1244,8 +1119,9 @@ class SQLContext(@transient val sparkContext: SparkContext) * * @group genericdata * @deprecated As of 1.4.0, replaced by `read().format(source).load(path)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use read.format(source).load(path)", "1.4.0") + @deprecated("Use read.format(source).load(path). This will be removed in Spark 2.0.", "1.4.0") def load(path: String, source: String): DataFrame = { read.format(source).load(path) } @@ -1256,8 +1132,10 @@ class SQLContext(@transient val sparkContext: SparkContext) * * @group genericdata * @deprecated As of 1.4.0, replaced by `read().format(source).options(options).load()`. + * This will be removed in Spark 2.0. */ - @deprecated("Use read.format(source).options(options).load()", "1.4.0") + @deprecated("Use read.format(source).options(options).load(). " + + "This will be removed in Spark 2.0.", "1.4.0") def load(source: String, options: java.util.Map[String, String]): DataFrame = { read.options(options).format(source).load() } @@ -1269,7 +1147,8 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group genericdata * @deprecated As of 1.4.0, replaced by `read().format(source).options(options).load()`. */ - @deprecated("Use read.format(source).options(options).load()", "1.4.0") + @deprecated("Use read.format(source).options(options).load(). " + + "This will be removed in Spark 2.0.", "1.4.0") def load(source: String, options: Map[String, String]): DataFrame = { read.options(options).format(source).load() } @@ -1282,7 +1161,8 @@ class SQLContext(@transient val sparkContext: SparkContext) * @deprecated As of 1.4.0, replaced by * `read().format(source).schema(schema).options(options).load()`. */ - @deprecated("Use read.format(source).schema(schema).options(options).load()", "1.4.0") + @deprecated("Use read.format(source).schema(schema).options(options).load(). " + + "This will be removed in Spark 2.0.", "1.4.0") def load(source: String, schema: StructType, options: java.util.Map[String, String]): DataFrame = { read.format(source).schema(schema).options(options).load() @@ -1296,7 +1176,8 @@ class SQLContext(@transient val sparkContext: SparkContext) * @deprecated As of 1.4.0, replaced by * `read().format(source).schema(schema).options(options).load()`. 
*/ - @deprecated("Use read.format(source).schema(schema).options(options).load()", "1.4.0") + @deprecated("Use read.format(source).schema(schema).options(options).load(). " + + "This will be removed in Spark 2.0.", "1.4.0") def load(source: String, schema: StructType, options: Map[String, String]): DataFrame = { read.format(source).schema(schema).options(options).load() } @@ -1306,9 +1187,9 @@ class SQLContext(@transient val sparkContext: SparkContext) * url named table. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().jdbc()`. + * @deprecated As of 1.4.0, replaced by `read().jdbc()`. This will be removed in Spark 2.0. */ - @deprecated("use read.jdbc()", "1.4.0") + @deprecated("Use read.jdbc(). This will be removed in Spark 2.0.", "1.4.0") def jdbc(url: String, table: String): DataFrame = { read.jdbc(url, table, new Properties) } @@ -1324,9 +1205,9 @@ class SQLContext(@transient val sparkContext: SparkContext) * @param numPartitions the number of partitions. the range `minValue`-`maxValue` will be split * evenly into this many partitions * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().jdbc()`. + * @deprecated As of 1.4.0, replaced by `read().jdbc()`. This will be removed in Spark 2.0. */ - @deprecated("use read.jdbc()", "1.4.0") + @deprecated("Use read.jdbc(). This will be removed in Spark 2.0.", "1.4.0") def jdbc( url: String, table: String, @@ -1344,9 +1225,9 @@ class SQLContext(@transient val sparkContext: SparkContext) * of the [[DataFrame]]. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().jdbc()`. + * @deprecated As of 1.4.0, replaced by `read().jdbc()`. This will be removed in Spark 2.0. */ - @deprecated("use read.jdbc()", "1.4.0") + @deprecated("Use read.jdbc(). This will be removed in Spark 2.0.", "1.4.0") def jdbc(url: String, table: String, theParts: Array[String]): DataFrame = { read.jdbc(url, table, theParts, new Properties) } @@ -1361,45 +1242,142 @@ class SQLContext(@transient val sparkContext: SparkContext) // Register a succesfully instantiatd context to the singleton. This should be at the end of // the class definition so that the singleton is updated only if there is no exception in the // construction of the instance. - SQLContext.setLastInstantiatedContext(self) + sparkContext.addSparkListener(new SparkListener { + override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { + SQLContext.clearInstantiatedContext() + SQLContext.clearSqlListener() + } + }) + + SQLContext.setInstantiatedContext(self) } /** * This SQLContext object contains utility functions to create a singleton SQLContext instance, - * or to get the last created SQLContext instance. + * or to get the created SQLContext instance. + * + * It also provides utility functions to support preference for threads in multiple sessions + * scenario, setActive could set a SQLContext for current thread, which will be returned by + * getOrCreate instead of the global one. */ object SQLContext { - private val INSTANTIATION_LOCK = new Object() + /** + * The active SQLContext for the current thread. + */ + private val activeContext: InheritableThreadLocal[SQLContext] = + new InheritableThreadLocal[SQLContext] /** - * Reference to the last created SQLContext. + * Reference to the created SQLContext. 
*/ - @transient private val lastInstantiatedContext = new AtomicReference[SQLContext]() + @transient private val instantiatedContext = new AtomicReference[SQLContext]() + + @transient private val sqlListener = new AtomicReference[SQLListener]() /** * Get the singleton SQLContext if it exists or create a new one using the given SparkContext. + * * This function can be used to create a singleton SQLContext object that can be shared across * the JVM. + * + * If there is an active SQLContext for current thread, it will be returned instead of the global + * one. + * + * @since 1.5.0 */ def getOrCreate(sparkContext: SparkContext): SQLContext = { - INSTANTIATION_LOCK.synchronized { - if (lastInstantiatedContext.get() == null) { + val ctx = activeContext.get() + if (ctx != null && !ctx.sparkContext.isStopped) { + return ctx + } + + synchronized { + val ctx = instantiatedContext.get() + if (ctx == null || ctx.sparkContext.isStopped) { new SQLContext(sparkContext) + } else { + ctx } } - lastInstantiatedContext.get() } - private[sql] def clearLastInstantiatedContext(): Unit = { - INSTANTIATION_LOCK.synchronized { - lastInstantiatedContext.set(null) + private[sql] def clearInstantiatedContext(): Unit = { + instantiatedContext.set(null) + } + + private[sql] def setInstantiatedContext(sqlContext: SQLContext): Unit = { + synchronized { + val ctx = instantiatedContext.get() + if (ctx == null || ctx.sparkContext.isStopped) { + instantiatedContext.set(sqlContext) + } } } - private[sql] def setLastInstantiatedContext(sqlContext: SQLContext): Unit = { - INSTANTIATION_LOCK.synchronized { - lastInstantiatedContext.set(sqlContext) + private[sql] def getInstantiatedContextOption(): Option[SQLContext] = { + Option(instantiatedContext.get()) + } + + private[sql] def clearSqlListener(): Unit = { + sqlListener.set(null) + } + + /** + * Changes the SQLContext that will be returned in this thread and its children when + * SQLContext.getOrCreate() is called. This can be used to ensure that a given thread receives + * a SQLContext with an isolated session, instead of the global (first created) context. + * + * @since 1.6.0 + */ + def setActive(sqlContext: SQLContext): Unit = { + activeContext.set(sqlContext) + } + + /** + * Clears the active SQLContext for current thread. Subsequent calls to getOrCreate will + * return the first created context instead of a thread-local override. + * + * @since 1.6.0 + */ + def clearActive(): Unit = { + activeContext.remove() + } + + private[sql] def getActive(): Option[SQLContext] = { + Option(activeContext.get()) + } + + /** + * Converts an iterator of Java Beans to InternalRow using the provided + * bean info & schema. This is not related to the singleton, but is a static + * method for internal use. + */ + private def beansToRows(data: Iterator[_], beanInfo: BeanInfo, attrs: Seq[AttributeReference]): + Iterator[InternalRow] = { + val extractors = + beanInfo.getPropertyDescriptors.filterNot(_.getName == "class").map(_.getReadMethod) + val methodsToConverts = extractors.zip(attrs).map { case (e, attr) => + (e, CatalystTypeConverters.createToCatalystConverter(attr.dataType)) + } + data.map{ element => + new GenericInternalRow( + methodsToConverts.map { case (e, convert) => convert(e.invoke(element)) }.toArray[Any] + ): InternalRow + } + } + + /** + * Create a SQLListener then add it into SparkContext, and create an SQLTab if there is SparkUI. 
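The companion-object changes above make getOrCreate thread-aware: a context registered with setActive for the current thread takes precedence over the globally instantiated one until clearActive is called. A minimal sketch (`sc` is an assumed SparkContext):

    import org.apache.spark.sql.SQLContext

    val shared  = SQLContext.getOrCreate(sc)      // global (first created) context
    val session = shared.newSession()
    SQLContext.setActive(session)                 // this thread (and its children) now prefer `session`
    assert(SQLContext.getOrCreate(sc) eq session)
    SQLContext.clearActive()                      // back to the global context
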
+ */ + private[sql] def createListenerAndUI(sc: SparkContext): SQLListener = { + if (sqlListener.get() == null) { + val listener = new SQLListener(sc.conf) + if (sqlListener.compareAndSet(null, listener)) { + sc.addSparkListener(listener) + sc.ui.foreach(new SQLTab(listener, _)) + } } + sqlListener.get() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala new file mode 100644 index 000000000000..6735d02954b8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.language.implicitConversions +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.types._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.types.StructField +import org.apache.spark.unsafe.types.UTF8String + +/** + * A collection of implicit methods for converting common Scala objects into [[DataFrame]]s. + * + * @since 1.6.0 + */ +abstract class SQLImplicits { + + protected def _sqlContext: SQLContext + + /** @since 1.6.0 */ + implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder() + + /** @since 1.6.0 */ + implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder() + + /** @since 1.6.0 */ + implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder() + + /** @since 1.6.0 */ + implicit def newDoubleEncoder: Encoder[Double] = ExpressionEncoder() + + /** @since 1.6.0 */ + implicit def newFloatEncoder: Encoder[Float] = ExpressionEncoder() + + /** @since 1.6.0 */ + implicit def newByteEncoder: Encoder[Byte] = ExpressionEncoder() + + /** @since 1.6.0 */ + implicit def newShortEncoder: Encoder[Short] = ExpressionEncoder() + /** @since 1.6.0 */ + + implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder() + + /** @since 1.6.0 */ + implicit def newStringEncoder: Encoder[String] = ExpressionEncoder() + + /** + * Creates a [[Dataset]] from an RDD. + * @since 1.6.0 + */ + implicit def rddToDatasetHolder[T : Encoder](rdd: RDD[T]): DatasetHolder[T] = { + DatasetHolder(_sqlContext.createDataset(rdd)) + } + + /** + * Creates a [[Dataset]] from a local Seq. + * @since 1.6.0 + */ + implicit def localSeqToDatasetHolder[T : Encoder](s: Seq[T]): DatasetHolder[T] = { + DatasetHolder(_sqlContext.createDataset(s)) + } + + /** + * An implicit conversion that turns a Scala `Symbol` into a [[Column]]. 
+ * @since 1.3.0 + */ + implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name) + + /** + * Creates a DataFrame from an RDD of Product (e.g. case classes, tuples). + * @since 1.3.0 + */ + implicit def rddToDataFrameHolder[A <: Product : TypeTag](rdd: RDD[A]): DataFrameHolder = { + DataFrameHolder(_sqlContext.createDataFrame(rdd)) + } + + /** + * Creates a DataFrame from a local Seq of Product. + * @since 1.3.0 + */ + implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = + { + DataFrameHolder(_sqlContext.createDataFrame(data)) + } + + // Do NOT add more implicit conversions for primitive types. + // They are likely to break source compatibility by making existing implicit conversions + // ambiguous. In particular, RDD[Double] is dangerous because of [[DoubleRDDFunctions]]. + + /** + * Creates a single column DataFrame from an RDD[Int]. + * @since 1.3.0 + */ + implicit def intRddToDataFrameHolder(data: RDD[Int]): DataFrameHolder = { + val dataType = IntegerType + val rows = data.mapPartitions { iter => + val row = new SpecificMutableRow(dataType :: Nil) + iter.map { v => + row.setInt(0, v) + row: InternalRow + } + } + DataFrameHolder( + _sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) + } + + /** + * Creates a single column DataFrame from an RDD[Long]. + * @since 1.3.0 + */ + implicit def longRddToDataFrameHolder(data: RDD[Long]): DataFrameHolder = { + val dataType = LongType + val rows = data.mapPartitions { iter => + val row = new SpecificMutableRow(dataType :: Nil) + iter.map { v => + row.setLong(0, v) + row: InternalRow + } + } + DataFrameHolder( + _sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) + } + + /** + * Creates a single column DataFrame from an RDD[String]. + * @since 1.3.0 + */ + implicit def stringRddToDataFrameHolder(data: RDD[String]): DataFrameHolder = { + val dataType = StringType + val rows = data.mapPartitions { iter => + val row = new SpecificMutableRow(dataType :: Nil) + iter.map { v => + row.update(0, UTF8String.fromString(v)) + row: InternalRow + } + } + DataFrameHolder( + _sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 7cd7421a518c..051694c0d43a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -26,6 +26,8 @@ import org.apache.spark.Logging import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} +import org.apache.spark.sql.execution.aggregate.ScalaUDAF +import org.apache.spark.sql.expressions.UserDefinedAggregateFunction import org.apache.spark.sql.types.DataType /** @@ -52,6 +54,21 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { functionRegistry.registerFunction(name, udf.builder) } + /** + * Register a user-defined aggregate function (UDAF). + * + * @param name the name of the UDAF. + * @param udaf the UDAF needs to be registered. + * @return the registered UDAF. 
+ */ + def register( + name: String, + udaf: UserDefinedAggregateFunction): UserDefinedAggregateFunction = { + def builder(children: Seq[Expression]) = ScalaUDAF(children, udaf) + functionRegistry.registerFunction(name, builder) + udaf + } + // scalastyle:off /* register 0-22 were generated by this script @@ -71,7 +88,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try($inputTypes).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) }""") } @@ -103,7 +120,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -116,7 +133,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -129,7 +146,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -142,7 +159,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -155,7 +172,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -168,7 +185,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -181,7 +198,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = 
Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -194,7 +211,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -207,7 +224,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -220,7 +237,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -233,7 +250,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -246,7 +263,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -259,7 +276,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -272,7 +289,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -285,7 +302,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -298,7 +315,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -311,7 +328,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -324,7 +341,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -337,7 +354,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: 
ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -350,7 +367,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -363,7 +380,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -376,7 +393,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: 
ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -389,7 +406,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: ScalaReflection.schemaFor[A22].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index 92861ab038f1..b3f134614c6b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -22,11 +22,15 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, Da import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.api.r.SerDe import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression, GenericRowWithSchema} import org.apache.spark.sql.types._ import org.apache.spark.sql.{Column, DataFrame, GroupedData, Row, SQLContext, SaveMode} +import scala.util.matching.Regex + private[r] object SQLUtils { + SerDe.registerSqlSerDe((readSqlObject, writeSqlObject)) + def createSQLContext(jsc: JavaSparkContext): SQLContext = { new SQLContext(jsc) } @@ -35,14 +39,15 @@ private[r] object SQLUtils { new JavaSparkContext(sqlCtx.sparkContext) } - def toSeq[T](arr: Array[T]): Seq[T] = { - arr.toSeq - } - def createStructType(fields : Seq[StructField]): StructType = { StructType(fields) } + // Support using regex in string interpolation + private[this] implicit class RegexContext(sc: StringContext) { + def r: Regex = new Regex(sc.parts.mkString, sc.parts.tail.map(_ => "x"): _*) + } + def getSQLDataType(dataType: String): DataType = { dataType 
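All of the generated register overloads above now thread inputTypes into the returned UserDefinedFunction. A short usage sketch (assuming a SQLContext named sqlContext, a DataFrame df with a string column "name", and a temp table "people", none of which are part of this diff):

// The same registered function is callable from SQL and, via the returned
// UserDefinedFunction, from the DataFrame API.
val strLen = sqlContext.udf.register("strLen", (s: String) => s.length)

sqlContext.sql("SELECT strLen(name) FROM people")   // SQL path goes through the registered builder
df.select(strLen(df("name")))                       // DataFrame path uses the captured input types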
match { case "byte" => org.apache.spark.sql.types.ByteType @@ -58,6 +63,27 @@ private[r] object SQLUtils { case "boolean" => org.apache.spark.sql.types.BooleanType case "timestamp" => org.apache.spark.sql.types.TimestampType case "date" => org.apache.spark.sql.types.DateType + case r"\Aarray<(.+)${elemType}>\Z" => + org.apache.spark.sql.types.ArrayType(getSQLDataType(elemType)) + case r"\Amap<(.+)${keyType},(.+)${valueType}>\Z" => + if (keyType != "string" && keyType != "character") { + throw new IllegalArgumentException("Key type of a map must be string or character") + } + org.apache.spark.sql.types.MapType(getSQLDataType(keyType), getSQLDataType(valueType)) + case r"\Astruct<(.+)${fieldsStr}>\Z" => + if (fieldsStr(fieldsStr.length - 1) == ',') { + throw new IllegalArgumentException(s"Invaid type $dataType") + } + val fields = fieldsStr.split(",") + val structFields = fields.map { field => + field match { + case r"\A(.+)${fieldName}:(.+)${fieldType}\Z" => + createStructField(fieldName, fieldType, true) + + case _ => throw new IllegalArgumentException(s"Invaid type $dataType") + } + } + createStructType(structFields) case _ => throw new IllegalArgumentException(s"Invaid type $dataType") } } @@ -98,46 +124,24 @@ private[r] object SQLUtils { val bos = new ByteArrayOutputStream() val dos = new DataOutputStream(bos) - SerDe.writeInt(dos, row.length) - (0 until row.length).map { idx => - val obj: Object = row(idx).asInstanceOf[Object] - SerDe.writeObject(dos, obj) - } + val cols = (0 until row.length).map(row(_).asInstanceOf[Object]).toArray + SerDe.writeObject(dos, cols) bos.toByteArray() } - def dfToCols(df: DataFrame): Array[Array[Byte]] = { - // localDF is Array[Row] - val localDF = df.collect() + def dfToCols(df: DataFrame): Array[Array[Any]] = { + val localDF: Array[Row] = df.collect() val numCols = df.columns.length - // dfCols is Array[Array[Any]] - val dfCols = convertRowsToColumns(localDF, numCols) - - dfCols.map { col => - colToRBytes(col) - } - } + val numRows = localDF.length - def convertRowsToColumns(localDF: Array[Row], numCols: Int): Array[Array[Any]] = { - (0 until numCols).map { colIdx => - localDF.map { row => - row(colIdx) + val colArray = new Array[Array[Any]](numCols) + for (colNo <- 0 until numCols) { + colArray(colNo) = new Array[Any](numRows) + for (rowNo <- 0 until numRows) { + colArray(colNo)(rowNo) = localDF(rowNo)(colNo) } - }.toArray - } - - def colToRBytes(col: Array[Any]): Array[Byte] = { - val numRows = col.length - val bos = new ByteArrayOutputStream() - val dos = new DataOutputStream(bos) - - SerDe.writeInt(dos, numRows) - - col.map { item => - val obj: Object = item.asInstanceOf[Object] - SerDe.writeObject(dos, obj) } - bos.toByteArray() + colArray } def saveMode(mode: String): SaveMode = { @@ -163,4 +167,27 @@ private[r] object SQLUtils { options: java.util.Map[String, String]): DataFrame = { sqlContext.read.format(source).schema(schema).options(options).load() } + + def readSqlObject(dis: DataInputStream, dataType: Char): Object = { + dataType match { + case 's' => + // Read StructType for DataFrame + val fields = SerDe.readList(dis).asInstanceOf[Array[Object]] + Row.fromSeq(fields) + case _ => null + } + } + + def writeSqlObject(dos: DataOutputStream, obj: Object): Boolean = { + obj match { + // Handle struct type in DataFrame + case v: GenericRowWithSchema => + dos.writeByte('s') + SerDe.writeObject(dos, v.schema.fieldNames) + SerDe.writeObject(dos, v.values) + true + case _ => + false + } + } } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala deleted file mode 100644 index 4c29a093218a..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala +++ /dev/null @@ -1,137 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.columnar - -import java.nio.{ByteBuffer, ByteOrder} - -import org.apache.spark.sql.catalyst.expressions.MutableRow -import org.apache.spark.sql.columnar.compression.CompressibleColumnAccessor -import org.apache.spark.sql.types._ - -/** - * An `Iterator` like trait used to extract values from columnar byte buffer. When a value is - * extracted from the buffer, instead of directly returning it, the value is set into some field of - * a [[MutableRow]]. In this way, boxing cost can be avoided by leveraging the setter methods - * for primitive values provided by [[MutableRow]]. - */ -private[sql] trait ColumnAccessor { - initialize() - - protected def initialize() - - def hasNext: Boolean - - def extractTo(row: MutableRow, ordinal: Int) - - protected def underlyingBuffer: ByteBuffer -} - -private[sql] abstract class BasicColumnAccessor[JvmType]( - protected val buffer: ByteBuffer, - protected val columnType: ColumnType[JvmType]) - extends ColumnAccessor { - - protected def initialize() {} - - override def hasNext: Boolean = buffer.hasRemaining - - override def extractTo(row: MutableRow, ordinal: Int): Unit = { - extractSingle(row, ordinal) - } - - def extractSingle(row: MutableRow, ordinal: Int): Unit = { - columnType.extract(buffer, row, ordinal) - } - - protected def underlyingBuffer = buffer -} - -private[sql] abstract class NativeColumnAccessor[T <: AtomicType]( - override protected val buffer: ByteBuffer, - override protected val columnType: NativeColumnType[T]) - extends BasicColumnAccessor(buffer, columnType) - with NullableColumnAccessor - with CompressibleColumnAccessor[T] - -private[sql] class BooleanColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, BOOLEAN) - -private[sql] class ByteColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, BYTE) - -private[sql] class ShortColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, SHORT) - -private[sql] class IntColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, INT) - -private[sql] class LongColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, LONG) - -private[sql] class FloatColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, FLOAT) - -private[sql] class DoubleColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, DOUBLE) - -private[sql] class 
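The ColumnAccessor hierarchy deleted above is being relocated, with the rest of the columnar code, under org.apache.spark.sql.execution.columnar (see the import change in CacheManager later in this diff). Its contract is small; a hedged sketch of how a caller drains one column, assuming a ByteBuffer buf written by the matching column builder for an IntegerType column:

// Sketch only: these types are private[sql] internals.
val accessor = ColumnAccessor(IntegerType, buf)
val row = new SpecificMutableRow(IntegerType :: Nil)
while (accessor.hasNext) {
  accessor.extractTo(row, 0)   // writes the next value into the mutable row without boxing
  // row.getInt(0) now holds the current value
}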
StringColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, STRING) - -private[sql] class BinaryColumnAccessor(buffer: ByteBuffer) - extends BasicColumnAccessor[Array[Byte]](buffer, BINARY) - with NullableColumnAccessor - -private[sql] class FixedDecimalColumnAccessor(buffer: ByteBuffer, precision: Int, scale: Int) - extends NativeColumnAccessor(buffer, FIXED_DECIMAL(precision, scale)) - -private[sql] class GenericColumnAccessor(buffer: ByteBuffer, dataType: DataType) - extends BasicColumnAccessor[Array[Byte]](buffer, GENERIC(dataType)) - with NullableColumnAccessor - -private[sql] class DateColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, DATE) - -private[sql] class TimestampColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, TIMESTAMP) - -private[sql] object ColumnAccessor { - def apply(dataType: DataType, buffer: ByteBuffer): ColumnAccessor = { - val dup = buffer.duplicate().order(ByteOrder.nativeOrder) - - // The first 4 bytes in the buffer indicate the column type. This field is not used now, - // because we always know the data type of the column ahead of time. - dup.getInt() - - dataType match { - case BooleanType => new BooleanColumnAccessor(dup) - case ByteType => new ByteColumnAccessor(dup) - case ShortType => new ShortColumnAccessor(dup) - case IntegerType => new IntColumnAccessor(dup) - case DateType => new DateColumnAccessor(dup) - case LongType => new LongColumnAccessor(dup) - case TimestampType => new TimestampColumnAccessor(dup) - case FloatType => new FloatColumnAccessor(dup) - case DoubleType => new DoubleColumnAccessor(dup) - case StringType => new StringColumnAccessor(dup) - case BinaryType => new BinaryColumnAccessor(dup) - case DecimalType.Fixed(precision, scale) if precision < 19 => - new FixedDecimalColumnAccessor(dup, precision, scale) - case other => new GenericColumnAccessor(dup, other) - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala deleted file mode 100644 index 531a8244d55d..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ /dev/null @@ -1,477 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.columnar - -import java.nio.ByteBuffer - -import scala.reflect.runtime.universe.TypeTag - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.MutableRow -import org.apache.spark.sql.execution.SparkSqlSerializer -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String - -/** - * An abstract class that represents type of a column. 
Used to append/extract Java objects into/from - * the underlying [[ByteBuffer]] of a column. - * - * @tparam JvmType Underlying Java type to represent the elements. - */ -private[sql] sealed abstract class ColumnType[JvmType] { - - // The catalyst data type of this column. - def dataType: DataType - - // A unique ID representing the type. - def typeId: Int - - // Default size in bytes for one element of type T (e.g. 4 for `Int`). - def defaultSize: Int - - /** - * Extracts a value out of the buffer at the buffer's current position. - */ - def extract(buffer: ByteBuffer): JvmType - - /** - * Extracts a value out of the buffer at the buffer's current position and stores in - * `row(ordinal)`. Subclasses should override this method to avoid boxing/unboxing costs whenever - * possible. - */ - def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { - setField(row, ordinal, extract(buffer)) - } - - /** - * Appends the given value v of type T into the given ByteBuffer. - */ - def append(v: JvmType, buffer: ByteBuffer): Unit - - /** - * Appends `row(ordinal)` of type T into the given ByteBuffer. Subclasses should override this - * method to avoid boxing/unboxing costs whenever possible. - */ - def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { - append(getField(row, ordinal), buffer) - } - - /** - * Returns the size of the value `row(ordinal)`. This is used to calculate the size of variable - * length types such as byte arrays and strings. - */ - def actualSize(row: InternalRow, ordinal: Int): Int = defaultSize - - /** - * Returns `row(ordinal)`. Subclasses should override this method to avoid boxing/unboxing costs - * whenever possible. - */ - def getField(row: InternalRow, ordinal: Int): JvmType - - /** - * Sets `row(ordinal)` to `field`. Subclasses should override this method to avoid boxing/unboxing - * costs whenever possible. - */ - def setField(row: MutableRow, ordinal: Int, value: JvmType): Unit - - /** - * Copies `from(fromOrdinal)` to `to(toOrdinal)`. Subclasses should override this method to avoid - * boxing/unboxing costs whenever possible. - */ - def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { - to.update(toOrdinal, from.get(fromOrdinal, dataType)) - } - - /** - * Creates a duplicated copy of the value. - */ - def clone(v: JvmType): JvmType = v - - override def toString: String = getClass.getSimpleName.stripSuffix("$") -} - -private[sql] abstract class NativeColumnType[T <: AtomicType]( - val dataType: T, - val typeId: Int, - val defaultSize: Int) - extends ColumnType[T#InternalType] { - - /** - * Scala TypeTag. Can be used to create primitive arrays and hash tables. 
- */ - def scalaTag: TypeTag[dataType.InternalType] = dataType.tag -} - -private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) { - override def append(v: Int, buffer: ByteBuffer): Unit = { - buffer.putInt(v) - } - - override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { - buffer.putInt(row.getInt(ordinal)) - } - - override def extract(buffer: ByteBuffer): Int = { - buffer.getInt() - } - - override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { - row.setInt(ordinal, buffer.getInt()) - } - - override def setField(row: MutableRow, ordinal: Int, value: Int): Unit = { - row.setInt(ordinal, value) - } - - override def getField(row: InternalRow, ordinal: Int): Int = row.getInt(ordinal) - - override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { - to.setInt(toOrdinal, from.getInt(fromOrdinal)) - } -} - -private[sql] object LONG extends NativeColumnType(LongType, 1, 8) { - override def append(v: Long, buffer: ByteBuffer): Unit = { - buffer.putLong(v) - } - - override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { - buffer.putLong(row.getLong(ordinal)) - } - - override def extract(buffer: ByteBuffer): Long = { - buffer.getLong() - } - - override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { - row.setLong(ordinal, buffer.getLong()) - } - - override def setField(row: MutableRow, ordinal: Int, value: Long): Unit = { - row.setLong(ordinal, value) - } - - override def getField(row: InternalRow, ordinal: Int): Long = row.getLong(ordinal) - - override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { - to.setLong(toOrdinal, from.getLong(fromOrdinal)) - } -} - -private[sql] object FLOAT extends NativeColumnType(FloatType, 2, 4) { - override def append(v: Float, buffer: ByteBuffer): Unit = { - buffer.putFloat(v) - } - - override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { - buffer.putFloat(row.getFloat(ordinal)) - } - - override def extract(buffer: ByteBuffer): Float = { - buffer.getFloat() - } - - override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { - row.setFloat(ordinal, buffer.getFloat()) - } - - override def setField(row: MutableRow, ordinal: Int, value: Float): Unit = { - row.setFloat(ordinal, value) - } - - override def getField(row: InternalRow, ordinal: Int): Float = row.getFloat(ordinal) - - override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { - to.setFloat(toOrdinal, from.getFloat(fromOrdinal)) - } -} - -private[sql] object DOUBLE extends NativeColumnType(DoubleType, 3, 8) { - override def append(v: Double, buffer: ByteBuffer): Unit = { - buffer.putDouble(v) - } - - override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { - buffer.putDouble(row.getDouble(ordinal)) - } - - override def extract(buffer: ByteBuffer): Double = { - buffer.getDouble() - } - - override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { - row.setDouble(ordinal, buffer.getDouble()) - } - - override def setField(row: MutableRow, ordinal: Int, value: Double): Unit = { - row.setDouble(ordinal, value) - } - - override def getField(row: InternalRow, ordinal: Int): Double = row.getDouble(ordinal) - - override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { - to.setDouble(toOrdinal, from.getDouble(fromOrdinal)) - } -} - -private[sql] object BOOLEAN extends 
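A small roundtrip sketch of the append/extract contract that INT, LONG and the other native column types above implement (the buffer size is arbitrary; these objects are private[sql] internals):

import java.nio.{ByteBuffer, ByteOrder}

val buf = ByteBuffer.allocate(64).order(ByteOrder.nativeOrder())
INT.append(42, buf)        // writes 4 bytes
LONG.append(7L, buf)       // writes 8 bytes
buf.flip()
assert(INT.extract(buf) == 42)
assert(LONG.extract(buf) == 7L)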
NativeColumnType(BooleanType, 4, 1) { - override def append(v: Boolean, buffer: ByteBuffer): Unit = { - buffer.put(if (v) 1: Byte else 0: Byte) - } - - override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { - buffer.put(if (row.getBoolean(ordinal)) 1: Byte else 0: Byte) - } - - override def extract(buffer: ByteBuffer): Boolean = buffer.get() == 1 - - override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { - row.setBoolean(ordinal, buffer.get() == 1) - } - - override def setField(row: MutableRow, ordinal: Int, value: Boolean): Unit = { - row.setBoolean(ordinal, value) - } - - override def getField(row: InternalRow, ordinal: Int): Boolean = row.getBoolean(ordinal) - - override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { - to.setBoolean(toOrdinal, from.getBoolean(fromOrdinal)) - } -} - -private[sql] object BYTE extends NativeColumnType(ByteType, 5, 1) { - override def append(v: Byte, buffer: ByteBuffer): Unit = { - buffer.put(v) - } - - override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { - buffer.put(row.getByte(ordinal)) - } - - override def extract(buffer: ByteBuffer): Byte = { - buffer.get() - } - - override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { - row.setByte(ordinal, buffer.get()) - } - - override def setField(row: MutableRow, ordinal: Int, value: Byte): Unit = { - row.setByte(ordinal, value) - } - - override def getField(row: InternalRow, ordinal: Int): Byte = row.getByte(ordinal) - - override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { - to.setByte(toOrdinal, from.getByte(fromOrdinal)) - } -} - -private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) { - override def append(v: Short, buffer: ByteBuffer): Unit = { - buffer.putShort(v) - } - - override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { - buffer.putShort(row.getShort(ordinal)) - } - - override def extract(buffer: ByteBuffer): Short = { - buffer.getShort() - } - - override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { - row.setShort(ordinal, buffer.getShort()) - } - - override def setField(row: MutableRow, ordinal: Int, value: Short): Unit = { - row.setShort(ordinal, value) - } - - override def getField(row: InternalRow, ordinal: Int): Short = row.getShort(ordinal) - - override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { - to.setShort(toOrdinal, from.getShort(fromOrdinal)) - } -} - -private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { - override def actualSize(row: InternalRow, ordinal: Int): Int = { - row.getUTF8String(ordinal).numBytes() + 4 - } - - override def append(v: UTF8String, buffer: ByteBuffer): Unit = { - val stringBytes = v.getBytes - buffer.putInt(stringBytes.length).put(stringBytes, 0, stringBytes.length) - } - - override def extract(buffer: ByteBuffer): UTF8String = { - val length = buffer.getInt() - val stringBytes = new Array[Byte](length) - buffer.get(stringBytes, 0, length) - UTF8String.fromBytes(stringBytes) - } - - override def setField(row: MutableRow, ordinal: Int, value: UTF8String): Unit = { - row.update(ordinal, value.clone()) - } - - override def getField(row: InternalRow, ordinal: Int): UTF8String = { - row.getUTF8String(ordinal) - } - - override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { - setField(to, toOrdinal, getField(from, 
fromOrdinal)) - } -} - -private[sql] object DATE extends NativeColumnType(DateType, 8, 4) { - override def extract(buffer: ByteBuffer): Int = { - buffer.getInt - } - - override def append(v: Int, buffer: ByteBuffer): Unit = { - buffer.putInt(v) - } - - override def getField(row: InternalRow, ordinal: Int): Int = { - row.getInt(ordinal) - } - - def setField(row: MutableRow, ordinal: Int, value: Int): Unit = { - row(ordinal) = value - } -} - -private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 8) { - override def extract(buffer: ByteBuffer): Long = { - buffer.getLong - } - - override def append(v: Long, buffer: ByteBuffer): Unit = { - buffer.putLong(v) - } - - override def getField(row: InternalRow, ordinal: Int): Long = { - row.getLong(ordinal) - } - - override def setField(row: MutableRow, ordinal: Int, value: Long): Unit = { - row(ordinal) = value - } -} - -private[sql] case class FIXED_DECIMAL(precision: Int, scale: Int) - extends NativeColumnType( - DecimalType(precision, scale), - 10, - FIXED_DECIMAL.defaultSize) { - - override def extract(buffer: ByteBuffer): Decimal = { - Decimal(buffer.getLong(), precision, scale) - } - - override def append(v: Decimal, buffer: ByteBuffer): Unit = { - buffer.putLong(v.toUnscaledLong) - } - - override def getField(row: InternalRow, ordinal: Int): Decimal = { - row.getDecimal(ordinal, precision, scale) - } - - override def setField(row: MutableRow, ordinal: Int, value: Decimal): Unit = { - row.setDecimal(ordinal, value, precision) - } - - override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { - setField(to, toOrdinal, getField(from, fromOrdinal)) - } -} - -private[sql] object FIXED_DECIMAL { - val defaultSize = 8 -} - -private[sql] sealed abstract class ByteArrayColumnType( - val typeId: Int, - val defaultSize: Int) - extends ColumnType[Array[Byte]] { - - override def actualSize(row: InternalRow, ordinal: Int): Int = { - getField(row, ordinal).length + 4 - } - - override def append(v: Array[Byte], buffer: ByteBuffer): Unit = { - buffer.putInt(v.length).put(v, 0, v.length) - } - - override def extract(buffer: ByteBuffer): Array[Byte] = { - val length = buffer.getInt() - val bytes = new Array[Byte](length) - buffer.get(bytes, 0, length) - bytes - } -} - -private[sql] object BINARY extends ByteArrayColumnType(11, 16) { - - def dataType: DataType = BooleanType - - override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = { - row.update(ordinal, value) - } - - override def getField(row: InternalRow, ordinal: Int): Array[Byte] = { - row.getBinary(ordinal) - } -} - -// Used to process generic objects (all types other than those listed above). Objects should be -// serialized first before appending to the column `ByteBuffer`, and is also extracted as serialized -// byte array. 
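As the FIXED_DECIMAL implementation above suggests, small decimals (precision below 19) are stored as their unscaled long value in 8 bytes. A hedged sketch; the Decimal construction is shown only for illustration:

import java.nio.{ByteBuffer, ByteOrder}
import org.apache.spark.sql.types.Decimal

val col = FIXED_DECIMAL(precision = 10, scale = 2)
val buf = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder())

col.append(Decimal(BigDecimal("123.45"), 10, 2), buf)  // stored as the unscaled long 12345
buf.flip()
col.extract(buf)                                       // Decimal 123.45 with precision 10, scale 2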
-private[sql] case class GENERIC(dataType: DataType) extends ByteArrayColumnType(12, 16) { - override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = { - row.update(ordinal, SparkSqlSerializer.deserialize[Any](value)) - } - - override def getField(row: InternalRow, ordinal: Int): Array[Byte] = { - SparkSqlSerializer.serialize(row.get(ordinal, dataType)) - } -} - -private[sql] object ColumnType { - def apply(dataType: DataType): ColumnType[_] = { - dataType match { - case BooleanType => BOOLEAN - case ByteType => BYTE - case ShortType => SHORT - case IntegerType => INT - case DateType => DATE - case LongType => LONG - case TimestampType => TIMESTAMP - case FloatType => FLOAT - case DoubleType => DOUBLE - case StringType => STRING - case BinaryType => BINARY - case DecimalType.Fixed(precision, scale) if precision < 19 => - FIXED_DECIMAL(precision, scale) - case other => GENERIC(other) - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala deleted file mode 100644 index e8c6a0f8f801..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ /dev/null @@ -1,197 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import java.util.HashMap - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical._ - -/** - * :: DeveloperApi :: - * Groups input data by `groupingExpressions` and computes the `aggregateExpressions` for each - * group. - * - * @param partial if true then aggregation is done partially on local data without shuffling to - * ensure all values where `groupingExpressions` are equal are present. - * @param groupingExpressions expressions that are evaluated to determine grouping. - * @param aggregateExpressions expressions that are computed for each group. - * @param child the input data source. - */ -@DeveloperApi -case class Aggregate( - partial: Boolean, - groupingExpressions: Seq[Expression], - aggregateExpressions: Seq[NamedExpression], - child: SparkPlan) - extends UnaryNode { - - override def requiredChildDistribution: List[Distribution] = { - if (partial) { - UnspecifiedDistribution :: Nil - } else { - if (groupingExpressions == Nil) { - AllTuples :: Nil - } else { - ClusteredDistribution(groupingExpressions) :: Nil - } - } - } - - override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) - - /** - * An aggregate that needs to be computed for each row in a group. 
- * - * @param unbound Unbound version of this aggregate, used for result substitution. - * @param aggregate A bound copy of this aggregate used to create a new aggregation buffer. - * @param resultAttribute An attribute used to refer to the result of this aggregate in the final - * output. - */ - case class ComputedAggregate( - unbound: AggregateExpression1, - aggregate: AggregateExpression1, - resultAttribute: AttributeReference) - - /** A list of aggregates that need to be computed for each group. */ - private[this] val computedAggregates = aggregateExpressions.flatMap { agg => - agg.collect { - case a: AggregateExpression1 => - ComputedAggregate( - a, - BindReferences.bindReference(a, child.output), - AttributeReference(s"aggResult:$a", a.dataType, a.nullable)()) - } - }.toArray - - /** The schema of the result of all aggregate evaluations */ - private[this] val computedSchema = computedAggregates.map(_.resultAttribute) - - /** Creates a new aggregate buffer for a group. */ - private[this] def newAggregateBuffer(): Array[AggregateFunction1] = { - val buffer = new Array[AggregateFunction1](computedAggregates.length) - var i = 0 - while (i < computedAggregates.length) { - buffer(i) = computedAggregates(i).aggregate.newInstance() - i += 1 - } - buffer - } - - /** Named attributes used to substitute grouping attributes into the final result. */ - private[this] val namedGroups = groupingExpressions.map { - case ne: NamedExpression => ne -> ne.toAttribute - case e => e -> Alias(e, s"groupingExpr:$e")().toAttribute - } - - /** - * A map of substitutions that are used to insert the aggregate expressions and grouping - * expression into the final result expression. - */ - private[this] val resultMap = - (computedAggregates.map { agg => agg.unbound -> agg.resultAttribute } ++ namedGroups).toMap - - /** - * Substituted version of aggregateExpressions expressions which are used to compute final - * output rows given a group and the result of all aggregate computations. 
- */ - private[this] val resultExpressions = aggregateExpressions.map { agg => - agg.transform { - case e: Expression if resultMap.contains(e) => resultMap(e) - } - } - - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { - if (groupingExpressions.isEmpty) { - child.execute().mapPartitions { iter => - val buffer = newAggregateBuffer() - var currentRow: InternalRow = null - while (iter.hasNext) { - currentRow = iter.next() - var i = 0 - while (i < buffer.length) { - buffer(i).update(currentRow) - i += 1 - } - } - val resultProjection = new InterpretedProjection(resultExpressions, computedSchema) - val aggregateResults = new GenericMutableRow(computedAggregates.length) - - var i = 0 - while (i < buffer.length) { - aggregateResults(i) = buffer(i).eval(EmptyRow) - i += 1 - } - - Iterator(resultProjection(aggregateResults)) - } - } else { - child.execute().mapPartitions { iter => - val hashTable = new HashMap[InternalRow, Array[AggregateFunction1]] - val groupingProjection = new InterpretedMutableProjection(groupingExpressions, child.output) - - var currentRow: InternalRow = null - while (iter.hasNext) { - currentRow = iter.next() - val currentGroup = groupingProjection(currentRow) - var currentBuffer = hashTable.get(currentGroup) - if (currentBuffer == null) { - currentBuffer = newAggregateBuffer() - hashTable.put(currentGroup.copy(), currentBuffer) - } - - var i = 0 - while (i < currentBuffer.length) { - currentBuffer(i).update(currentRow) - i += 1 - } - } - - new Iterator[InternalRow] { - private[this] val hashTableIter = hashTable.entrySet().iterator() - private[this] val aggregateResults = new GenericMutableRow(computedAggregates.length) - private[this] val resultProjection = - new InterpretedMutableProjection( - resultExpressions, computedSchema ++ namedGroups.map(_._2)) - private[this] val joinedRow = new JoinedRow - - override final def hasNext: Boolean = hashTableIter.hasNext - - override final def next(): InternalRow = { - val currentEntry = hashTableIter.next() - val currentGroup = currentEntry.getKey - val currentBuffer = currentEntry.getValue - - var i = 0 - while (i < currentBuffer.length) { - // Evaluating an aggregate buffer returns the result. No row is required since we - // already added all rows in the group using update. - aggregateResults(i) = currentBuffer(i).eval(EmptyRow) - i += 1 - } - resultProjection(joinedRow(aggregateResults, currentGroup)) - } - } - } - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index d3e5c378d037..50f6562815c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -21,8 +21,7 @@ import java.util.concurrent.locks.ReentrantReadWriteLock import org.apache.spark.Logging import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.columnar.InMemoryRelation -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.storage.StorageLevel import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK @@ -37,7 +36,7 @@ private[sql] case class CachedData(plan: LogicalPlan, cachedRepresentation: InMe * * Internal to Spark SQL. 
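The deleted Aggregate operator above boils down to the classic hash-aggregation loop: one mutable buffer per distinct grouping key, updated once per input row, then evaluated and joined back with its key. A generic, simplified model of that loop (the helper below is illustrative, not Spark API):

def hashAggregate[R, K, B](
    rows: Iterator[R],
    key: R => K,
    newBuffer: () => B,
    update: (B, R) => Unit): Iterator[(K, B)] = {
  val table = scala.collection.mutable.LinkedHashMap.empty[K, B]
  rows.foreach { row =>
    val buffer = table.getOrElseUpdate(key(row), newBuffer())
    update(buffer, row)   // same role as AggregateFunction1.update(currentRow) above
  }
  table.iterator          // same role as evaluating each buffer and projecting the result row
}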
*/ -private[sql] class CacheManager(sqlContext: SQLContext) extends Logging { +private[sql] class CacheManager extends Logging { @transient private val cachedData = new scala.collection.mutable.ArrayBuffer[CachedData] @@ -45,15 +44,6 @@ private[sql] class CacheManager(sqlContext: SQLContext) extends Logging { @transient private val cacheLock = new ReentrantReadWriteLock - /** Returns true if the table is currently cached in-memory. */ - def isCached(tableName: String): Boolean = lookupCachedData(sqlContext.table(tableName)).nonEmpty - - /** Caches the specified table in-memory. */ - def cacheTable(tableName: String): Unit = cacheQuery(sqlContext.table(tableName), Some(tableName)) - - /** Removes the specified table from the in-memory cache. */ - def uncacheTable(tableName: String): Unit = uncacheQuery(sqlContext.table(tableName)) - /** Acquires a read lock on the cache for the duration of `f`. */ private def readLock[A](f: => A): A = { val lock = cacheLock.readLock() @@ -84,18 +74,19 @@ private[sql] class CacheManager(sqlContext: SQLContext) extends Logging { } /** - * Caches the data produced by the logical representation of the given [[DataFrame]]. Unlike - * `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because recomputing - * the in-memory columnar representation of the underlying table is expensive. + * Caches the data produced by the logical representation of the given [[Queryable]]. + * Unlike `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because + * recomputing the in-memory columnar representation of the underlying table is expensive. */ private[sql] def cacheQuery( - query: DataFrame, + query: Queryable, tableName: Option[String] = None, storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock { val planToCache = query.queryExecution.analyzed if (lookupCachedData(planToCache).nonEmpty) { logWarning("Asked to cache already cached data.") } else { + val sqlContext = query.sqlContext cachedData += CachedData( planToCache, @@ -103,13 +94,13 @@ private[sql] class CacheManager(sqlContext: SQLContext) extends Logging { sqlContext.conf.useCompression, sqlContext.conf.columnBatchSize, storageLevel, - sqlContext.executePlan(query.logicalPlan).executedPlan, + sqlContext.executePlan(planToCache).executedPlan, tableName)) } } - /** Removes the data for the given [[DataFrame]] from the cache */ - private[sql] def uncacheQuery(query: DataFrame, blocking: Boolean = true): Unit = writeLock { + /** Removes the data for the given [[Queryable]] from the cache */ + private[sql] def uncacheQuery(query: Queryable, blocking: Boolean = true): Unit = writeLock { val planToCache = query.queryExecution.analyzed val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan)) require(dataIndex >= 0, s"Table $query is not cached.") @@ -117,9 +108,11 @@ private[sql] class CacheManager(sqlContext: SQLContext) extends Logging { cachedData.remove(dataIndex) } - /** Tries to remove the data for the given [[DataFrame]] from the cache if it's cached */ + /** Tries to remove the data for the given [[Queryable]] from the cache + * if it's cached + */ private[sql] def tryUncacheQuery( - query: DataFrame, + query: Queryable, blocking: Boolean = true): Boolean = writeLock { val planToCache = query.queryExecution.analyzed val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan)) @@ -131,12 +124,12 @@ private[sql] class CacheManager(sqlContext: SQLContext) extends Logging { found } - /** Optionally returns cached data for the given 
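The table-name convenience methods removed from CacheManager above no longer need a SQLContext handle; the equivalent user-facing entry points remain on SQLContext itself. An illustrative sketch (assumes a SQLContext named sqlContext and a registered temp table "events"):

val df = sqlContext.table("events")

sqlContext.cacheTable("events")    // registers the analyzed plan with the CacheManager (MEMORY_AND_DISK)
sqlContext.isCached("events")      // true: the plan is found in the cache
df.count()                         // first action materializes the in-memory columnar buffers
sqlContext.uncacheTable("events")  // drops the cached buffers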
[[DataFrame]] */ - private[sql] def lookupCachedData(query: DataFrame): Option[CachedData] = readLock { + /** Optionally returns cached data for the given [[Queryable]] */ + private[sql] def lookupCachedData(query: Queryable): Option[CachedData] = readLock { lookupCachedData(query.queryExecution.analyzed) } - /** Optionally returns cached data for the given LogicalPlan. */ + /** Optionally returns cached data for the given [[LogicalPlan]]. */ private[sql] def lookupCachedData(plan: LogicalPlan): Option[CachedData] = readLock { cachedData.find(cd => plan.sameResult(cd.plan)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala new file mode 100644 index 000000000000..663bc904f39c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder, Attribute} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering + +/** + * Iterates over [[GroupedIterator]]s and returns the cogrouped data, i.e. each record is a + * grouping key with its associated values from all [[GroupedIterator]]s. + * Note: we assume the output of each [[GroupedIterator]] is ordered by the grouping key. + */ +class CoGroupedIterator( + left: Iterator[(InternalRow, Iterator[InternalRow])], + right: Iterator[(InternalRow, Iterator[InternalRow])], + groupingSchema: Seq[Attribute]) + extends Iterator[(InternalRow, Iterator[InternalRow], Iterator[InternalRow])] { + + private val keyOrdering = + GenerateOrdering.generate(groupingSchema.map(SortOrder(_, Ascending)), groupingSchema) + + private var currentLeftData: (InternalRow, Iterator[InternalRow]) = _ + private var currentRightData: (InternalRow, Iterator[InternalRow]) = _ + + override def hasNext: Boolean = { + if (currentLeftData == null && left.hasNext) { + currentLeftData = left.next() + } + if (currentRightData == null && right.hasNext) { + currentRightData = right.next() + } + + currentLeftData != null || currentRightData != null + } + + override def next(): (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) = { + assert(hasNext) + + if (currentLeftData.eq(null)) { + // left is null, right is not null, consume the right data. + rightOnly() + } else if (currentRightData.eq(null)) { + // left is not null, right is null, consume the left data. + leftOnly() + } else if (currentLeftData._1 == currentRightData._1) { + // left and right have the same grouping key, consume both of them. 
+ val result = (currentLeftData._1, currentLeftData._2, currentRightData._2) + currentLeftData = null + currentRightData = null + result + } else { + val compare = keyOrdering.compare(currentLeftData._1, currentRightData._1) + assert(compare != 0) + if (compare < 0) { + // the grouping key of left is smaller, consume the left data. + leftOnly() + } else { + // the grouping key of right is smaller, consume the right data. + rightOnly() + } + } + } + + private def leftOnly(): (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) = { + val result = (currentLeftData._1, currentLeftData._2, Iterator.empty) + currentLeftData = null + result + } + + private def rightOnly(): (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) = { + val result = (currentRightData._1, Iterator.empty, currentRightData._2) + currentRightData = null + result + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 05b009d1935b..62cbc518e02a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -17,42 +17,58 @@ package org.apache.spark.sql.execution -import org.apache.spark.annotation.DeveloperApi +import java.util.Random + +import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.shuffle.sort.SortShuffleManager -import org.apache.spark.shuffle.unsafe.UnsafeShuffleManager import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.util.MutablePair -import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv} /** - * :: DeveloperApi :: * Performs a shuffle that will result in the desired `newPartitioning`. */ -@DeveloperApi -case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode { +case class Exchange( + var newPartitioning: Partitioning, + child: SparkPlan, + @transient coordinator: Option[ExchangeCoordinator]) extends UnaryNode { - override def outputPartitioning: Partitioning = newPartitioning + override def nodeName: String = { + val extraInfo = coordinator match { + case Some(exchangeCoordinator) if exchangeCoordinator.isEstimated => + s"(coordinator id: ${System.identityHashCode(coordinator)})" + case Some(exchangeCoordinator) if !exchangeCoordinator.isEstimated => + s"(coordinator id: ${System.identityHashCode(coordinator)})" + case None => "" + } - override def output: Seq[Attribute] = child.output + val simpleNodeName = if (tungstenMode) "TungstenExchange" else "Exchange" + s"$simpleNodeName$extraInfo" + } - override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows + /** + * Returns true iff we can support the data type, and we are not doing range partitioning. 
+ */ + private lazy val tungstenMode: Boolean = !newPartitioning.isInstanceOf[RangePartitioning] - override def canProcessSafeRows: Boolean = true + override def outputPartitioning: Partitioning = newPartitioning - override def canProcessUnsafeRows: Boolean = { - // Do not use the Unsafe path if we are using a RangePartitioning, since this may lead to - // an interpreted RowOrdering being applied to an UnsafeRow, which will lead to - // ClassCastExceptions at runtime. This check can be removed after SPARK-9054 is fixed. - !newPartitioning.isInstanceOf[RangePartitioning] - } + override def output: Seq[Attribute] = child.output + + // This setting is somewhat counterintuitive: + // If the schema works with UnsafeRow, then we tell the planner that we don't support safe row, + // so the planner inserts a converter to convert data into UnsafeRow if needed. + override def outputsUnsafeRows: Boolean = tungstenMode + override def canProcessSafeRows: Boolean = !tungstenMode + override def canProcessUnsafeRows: Boolean = tungstenMode /** * Determines whether records must be defensively copied before being sent to the shuffle. @@ -81,10 +97,8 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una // fewer partitions (like RangePartitioner, for example). val conf = child.sqlContext.sparkContext.conf val shuffleManager = SparkEnv.get.shuffleManager - val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager] || - shuffleManager.isInstanceOf[UnsafeShuffleManager] + val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager] val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) - val serializeMapOutputs = conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) if (sortBasedShuffleOn) { val bypassIsSupported = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager] if (bypassIsSupported && partitioner.numPartitions <= bypassMergeThreshold) { @@ -93,22 +107,18 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una // doesn't buffer deserialized records. // Note that we'll have to remove this case if we fix SPARK-6026 and remove this bypass. false - } else if (serializeMapOutputs && serializer.supportsRelocationOfSerializedObjects) { - // SPARK-4550 extended sort-based shuffle to serialize individual records prior to sorting - // them. This optimization is guarded by a feature-flag and is only applied in cases where - // shuffle dependency does not specify an aggregator or ordering and the record serializer - // has certain properties. If this optimization is enabled, we can safely avoid the copy. + } else if (serializer.supportsRelocationOfSerializedObjects) { + // SPARK-4550 and SPARK-7081 extended sort-based shuffle to serialize individual records + // prior to sorting them. This optimization is only applied in cases where shuffle + // dependency does not specify an aggregator or ordering and the record serializer has + // certain properties. If this optimization is enabled, we can safely avoid the copy. // // Exchange never configures its ShuffledRDDs with aggregators or key orderings, so we only // need to check whether the optimization is enabled and supported by our serializer. - // - // This optimization also applies to UnsafeShuffleManager (added in SPARK-7081). false } else { - // Spark's SortShuffleManager uses `ExternalSorter` to buffer records in memory. 
This code - // path is used both when SortShuffleManager is used and when UnsafeShuffleManager falls - // back to SortShuffleManager to perform a shuffle that the new fast path can't handle. In - // both cases, we must copy. + // Spark's SortShuffleManager uses `ExternalSorter` to buffer records in memory, so we must + // copy. true } } else if (shuffleManager.isInstanceOf[HashShuffleManager]) { @@ -123,40 +133,49 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una @transient private lazy val sparkConf = child.sqlContext.sparkContext.getConf private val serializer: Serializer = { - val rowDataTypes = child.output.map(_.dataType).toArray - // It is true when there is no field that needs to be write out. - // For now, we will not use SparkSqlSerializer2 when noField is true. - val noField = rowDataTypes == null || rowDataTypes.length == 0 - - val useSqlSerializer2 = - child.sqlContext.conf.useSqlSerializer2 && // SparkSqlSerializer2 is enabled. - SparkSqlSerializer2.support(rowDataTypes) && // The schema of row is supported. - !noField - - if (child.outputsUnsafeRows) { - logInfo("Using UnsafeRowSerializer.") + if (tungstenMode) { new UnsafeRowSerializer(child.output.size) - } else if (useSqlSerializer2) { - logInfo("Using SparkSqlSerializer2.") - new SparkSqlSerializer2(rowDataTypes) } else { - logInfo("Using SparkSqlSerializer.") new SparkSqlSerializer(sparkConf) } } - protected override def doExecute(): RDD[InternalRow] = attachTree(this , "execute") { + override protected def doPrepare(): Unit = { + // If an ExchangeCoordinator is needed, we register this Exchange operator + // to the coordinator when we do prepare. It is important to make sure + // we register this operator right before the execution instead of register it + // in the constructor because it is possible that we create new instances of + // Exchange operators when we transform the physical plan + // (then the ExchangeCoordinator will hold references of unneeded Exchanges). + // So, we should only call registerExchange just before we start to execute + // the plan. + coordinator match { + case Some(exchangeCoordinator) => exchangeCoordinator.registerExchange(this) + case None => + } + } + + /** + * Returns a [[ShuffleDependency]] that will partition rows of its child based on + * the partitioning scheme defined in `newPartitioning`. Those partitions of + * the returned ShuffleDependency will be the input of shuffle. + */ + private[sql] def prepareShuffleDependency(): ShuffleDependency[Int, InternalRow, InternalRow] = { val rdd = child.execute() val part: Partitioner = newPartitioning match { + case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions) case HashPartitioning(expressions, numPartitions) => new HashPartitioner(numPartitions) case RangePartitioning(sortingExpressions, numPartitions) => // Internally, RangePartitioner runs a job on the RDD that samples keys to compute // partition bounds. To get accurate samples, we need to copy the mutable keys. - val rddForSampling = rdd.mapPartitions { iter => + val rddForSampling = rdd.mapPartitionsInternal { iter => val mutablePair = new MutablePair[InternalRow, Null]() iter.map(row => mutablePair.update(row.copy(), null)) } - implicit val ordering = new RowOrdering(sortingExpressions, child.output) + // We need to use an interpreted ordering here because generated orderings cannot be + // serialized and this ordering needs to be created on the driver in order to be passed into + // Spark core code. 
+ implicit val ordering = new InterpretedOrdering(sortingExpressions, child.output) new RangePartitioner(numPartitions, rddForSampling, ascending = true) case SinglePartition => new Partitioner { @@ -166,26 +185,81 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una case _ => sys.error(s"Exchange not implemented for $newPartitioning") // TODO: Handle BroadcastPartitioning. } - def getPartitionKeyExtractor(): InternalRow => InternalRow = newPartitioning match { + def getPartitionKeyExtractor(): InternalRow => Any = newPartitioning match { + case RoundRobinPartitioning(numPartitions) => + // Distributes elements evenly across output partitions, starting from a random partition. + var position = new Random(TaskContext.get().partitionId()).nextInt(numPartitions) + (row: InternalRow) => { + // The HashPartitioner will handle the `mod` by the number of partitions + position += 1 + position + } case HashPartitioning(expressions, _) => newMutableProjection(expressions, child.output)() case RangePartitioning(_, _) | SinglePartition => identity case _ => sys.error(s"Exchange not implemented for $newPartitioning") } val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = { if (needToCopyObjectsBeforeShuffle(part, serializer)) { - rdd.mapPartitions { iter => + rdd.mapPartitionsInternal { iter => val getPartitionKey = getPartitionKeyExtractor() iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) } } } else { - rdd.mapPartitions { iter => + rdd.mapPartitionsInternal { iter => val getPartitionKey = getPartitionKeyExtractor() val mutablePair = new MutablePair[Int, InternalRow]() iter.map { row => mutablePair.update(part.getPartition(getPartitionKey(row)), row) } } } } - new ShuffledRowRDD(rddWithPartitionIds, serializer, part.numPartitions) + + // Now, we manually create a ShuffleDependency. Because pairs in rddWithPartitionIds + // are in the form of (partitionId, row) and every partitionId is in the expected range + // [0, part.numPartitions - 1]. The partitioner of this is a PartitionIdPassthrough. + val dependency = + new ShuffleDependency[Int, InternalRow, InternalRow]( + rddWithPartitionIds, + new PartitionIdPassthrough(part.numPartitions), + Some(serializer)) + + dependency + } + + /** + * Returns a [[ShuffledRowRDD]] that represents the post-shuffle dataset. + * This [[ShuffledRowRDD]] is created based on a given [[ShuffleDependency]] and an optional + * partition start indices array. If this optional array is defined, the returned + * [[ShuffledRowRDD]] will fetch pre-shuffle partitions based on indices of this array. + */ + private[sql] def preparePostShuffleRDD( + shuffleDependency: ShuffleDependency[Int, InternalRow, InternalRow], + specifiedPartitionStartIndices: Option[Array[Int]] = None): ShuffledRowRDD = { + // If an array of partition start indices is provided, we need to use this array + // to create the ShuffledRowRDD. Also, we need to update newPartitioning to + // update the number of post-shuffle partitions. 
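// Worked example (hypothetical numbers): with 5 pre-shuffle partitions and
// specifiedPartitionStartIndices = Some(Array(0, 3)), the returned ShuffledRowRDD has
//   post-shuffle partition 0 = pre-shuffle partitions 0, 1, 2
//   post-shuffle partition 1 = pre-shuffle partitions 3, 4
// and newPartitioning is replaced by UnknownPartitioning(2), because the merged partitions
// are no longer hash-partitioned by key modulo the new partition count.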
+ specifiedPartitionStartIndices.foreach { indices => + assert(newPartitioning.isInstanceOf[HashPartitioning]) + newPartitioning = UnknownPartitioning(indices.length) + } + new ShuffledRowRDD(shuffleDependency, specifiedPartitionStartIndices) + } + + protected override def doExecute(): RDD[InternalRow] = attachTree(this , "execute") { + coordinator match { + case Some(exchangeCoordinator) => + val shuffleRDD = exchangeCoordinator.postShuffleRDD(this) + assert(shuffleRDD.partitions.length == newPartitioning.numPartitions) + shuffleRDD + case None => + val shuffleDependency = prepareShuffleDependency() + preparePostShuffleRDD(shuffleDependency) + } + } +} + +object Exchange { + def apply(newPartitioning: Partitioning, child: SparkPlan): Exchange = { + Exchange(newPartitioning, child, coordinator = None: Option[ExchangeCoordinator]) } } @@ -194,66 +268,226 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una * of input data meets the * [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for * each operator by inserting [[Exchange]] Operators where required. Also ensure that the - * required input partition ordering requirements are met. + * input partition ordering requirements are met. */ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] { - // TODO: Determine the number of partitions. - def numPartitions: Int = sqlContext.conf.numShufflePartitions + private def defaultNumPreShufflePartitions: Int = sqlContext.conf.numShufflePartitions - def apply(plan: SparkPlan): SparkPlan = plan.transformUp { - case operator: SparkPlan => - // Adds Exchange or Sort operators as required - def addOperatorsIfNecessary( - partitioning: Partitioning, - rowOrdering: Seq[SortOrder], - child: SparkPlan): SparkPlan = { - - def addShuffleIfNecessary(child: SparkPlan): SparkPlan = { - if (!child.outputPartitioning.guarantees(partitioning)) { - Exchange(partitioning, child) - } else { - child - } + private def targetPostShuffleInputSize: Long = sqlContext.conf.targetPostShuffleInputSize + + private def adaptiveExecutionEnabled: Boolean = sqlContext.conf.adaptiveExecutionEnabled + + private def minNumPostShufflePartitions: Option[Int] = { + val minNumPostShufflePartitions = sqlContext.conf.minNumPostShufflePartitions + if (minNumPostShufflePartitions > 0) Some(minNumPostShufflePartitions) else None + } + + /** + * Given a required distribution, returns a partitioning that satisfies that distribution. + */ + private def createPartitioning( + requiredDistribution: Distribution, + numPartitions: Int): Partitioning = { + requiredDistribution match { + case AllTuples => SinglePartition + case ClusteredDistribution(clustering) => HashPartitioning(clustering, numPartitions) + case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions) + case dist => sys.error(s"Do not know how to satisfy distribution $dist") + } + } + + /** + * Adds [[ExchangeCoordinator]] to [[Exchange]]s if adaptive query execution is enabled + * and partitioning schemes of these [[Exchange]]s support [[ExchangeCoordinator]]. + */ + private def withExchangeCoordinator( + children: Seq[SparkPlan], + requiredChildDistributions: Seq[Distribution]): Seq[SparkPlan] = { + val supportsCoordinator = + if (children.exists(_.isInstanceOf[Exchange])) { + // Right now, ExchangeCoordinator only support HashPartitionings. 
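// Examples (hypothetical plans) of children that pass this check:
//   Exchange(HashPartitioning(a, 200), t1, _)                               -> supported
//   an operator whose outputPartitioning is a HashPartitioning              -> supported
//   an operator whose outputPartitioning is a PartitioningCollection made
//     up entirely of HashPartitionings                                      -> supported
// Any other case (e.g. an Exchange that performs RangePartitioning) makes
// supportsCoordinator false, so no coordinator is attached for this operator's children.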
+ children.forall { + case e @ Exchange(hash: HashPartitioning, _, _) => true + case child => + child.outputPartitioning match { + case hash: HashPartitioning => true + case collection: PartitioningCollection => + collection.partitionings.forall(_.isInstanceOf[HashPartitioning]) + case _ => false + } + } + } else { + // In this case, although we do not have Exchange operators, we may still need to + // shuffle data when we have more than one children because data generated by + // these children may not be partitioned in the same way. + // Please see the comment in withCoordinator for more details. + val supportsDistribution = + requiredChildDistributions.forall(_.isInstanceOf[ClusteredDistribution]) + children.length > 1 && supportsDistribution + } + + val withCoordinator = + if (adaptiveExecutionEnabled && supportsCoordinator) { + val coordinator = + new ExchangeCoordinator( + children.length, + targetPostShuffleInputSize, + minNumPostShufflePartitions) + children.zip(requiredChildDistributions).map { + case (e: Exchange, _) => + // This child is an Exchange, we need to add the coordinator. + e.copy(coordinator = Some(coordinator)) + case (child, distribution) => + // If this child is not an Exchange, we need to add an Exchange for now. + // Ideally, we can try to avoid this Exchange. However, when we reach here, + // there are at least two children operators (because if there is a single child + // and we can avoid Exchange, supportsCoordinator will be false and we + // will not reach here.). Although we can make two children have the same number of + // post-shuffle partitions. Their numbers of pre-shuffle partitions may be different. + // For example, let's say we have the following plan + // Join + // / \ + // Agg Exchange + // / \ + // Exchange t2 + // / + // t1 + // In this case, because a post-shuffle partition can include multiple pre-shuffle + // partitions, a HashPartitioning will not be strictly partitioned by the hashcodes + // after shuffle. So, even we can use the child Exchange operator of the Join to + // have a number of post-shuffle partitions that matches the number of partitions of + // Agg, we cannot say these two children are partitioned in the same way. + // Here is another case + // Join + // / \ + // Agg1 Agg2 + // / \ + // Exchange1 Exchange2 + // / \ + // t1 t2 + // In this case, two Aggs shuffle data with the same column of the join condition. + // After we use ExchangeCoordinator, these two Aggs may not be partitioned in the same + // way. Let's say that Agg1 and Agg2 both have 5 pre-shuffle partitions and 2 + // post-shuffle partitions. It is possible that Agg1 fetches those pre-shuffle + // partitions by using a partitionStartIndices [0, 3]. However, Agg2 may fetch its + // pre-shuffle partitions by using another partitionStartIndices [0, 4]. + // So, Agg1 and Agg2 are actually not co-partitioned. + // + // It will be great to introduce a new Partitioning to represent the post-shuffle + // partitions when one post-shuffle partition includes multiple pre-shuffle partitions. + val targetPartitioning = + createPartitioning(distribution, defaultNumPreShufflePartitions) + assert(targetPartitioning.isInstanceOf[HashPartitioning]) + Exchange(targetPartitioning, child, Some(coordinator)) } + } else { + // If we do not need ExchangeCoordinator, the original children are returned. 
+ children + } - def addSortIfNecessary(child: SparkPlan): SparkPlan = { + withCoordinator + } - if (rowOrdering.nonEmpty) { - // If child.outputOrdering is [a, b] and rowOrdering is [a], we do not need to sort. - val minSize = Seq(rowOrdering.size, child.outputOrdering.size).min - if (minSize == 0 || rowOrdering.take(minSize) != child.outputOrdering.take(minSize)) { - sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child) - } else { - child + private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = { + val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution + val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering + var children: Seq[SparkPlan] = operator.children + assert(requiredChildDistributions.length == children.length) + assert(requiredChildOrderings.length == children.length) + + // Ensure that the operator's children satisfy their output distribution requirements: + children = children.zip(requiredChildDistributions).map { case (child, distribution) => + if (child.outputPartitioning.satisfies(distribution)) { + child + } else { + Exchange(createPartitioning(distribution, defaultNumPreShufflePartitions), child) + } + } + + // If the operator has multiple children and specifies child output distributions (e.g. join), + // then the children's output partitionings must be compatible: + if (children.length > 1 + && requiredChildDistributions.toSet != Set(UnspecifiedDistribution) + && !Partitioning.allCompatible(children.map(_.outputPartitioning))) { + + // First check if the existing partitions of the children all match. This means they are + // partitioned by the same partitioning into the same number of partitions. In that case, + // don't try to make them match `defaultPartitions`, just use the existing partitioning. + val maxChildrenNumPartitions = children.map(_.outputPartitioning.numPartitions).max + val useExistingPartitioning = children.zip(requiredChildDistributions).forall { + case (child, distribution) => { + child.outputPartitioning.guarantees( + createPartitioning(distribution, maxChildrenNumPartitions)) + } + } + + children = if (useExistingPartitioning) { + // We do not need to shuffle any child's output. + children + } else { + // We need to shuffle at least one child's output. + // Now, we will determine the number of partitions that will be used by created + // partitioning schemes. + val numPartitions = { + // Let's see if we need to shuffle all child's outputs when we use + // maxChildrenNumPartitions. + val shufflesAllChildren = children.zip(requiredChildDistributions).forall { + case (child, distribution) => { + !child.outputPartitioning.guarantees( + createPartitioning(distribution, maxChildrenNumPartitions)) } - } else { - child } + // If we need to shuffle all children, we use defaultNumPreShufflePartitions as the + // number of partitions. Otherwise, we use maxChildrenNumPartitions. + if (shufflesAllChildren) defaultNumPreShufflePartitions else maxChildrenNumPartitions } - addSortIfNecessary(addShuffleIfNecessary(child)) + children.zip(requiredChildDistributions).map { + case (child, distribution) => { + val targetPartitioning = + createPartitioning(distribution, numPartitions) + if (child.outputPartitioning.guarantees(targetPartitioning)) { + child + } else { + child match { + // If child is an exchange, we replace it with + // a new one having targetPartitioning. 
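// For example (hypothetical plan): if the child is already
// Exchange(HashPartitioning(a, 5), t1, _) and targetPartitioning is HashPartitioning(a, 200),
// the child is re-planned as Exchange(HashPartitioning(a, 200), t1) instead of wrapping the
// existing shuffle in a second one.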
+ case Exchange(_, c, _) => Exchange(targetPartitioning, c) + case _ => Exchange(targetPartitioning, child) + } + } + } + } } + } - val requirements = - (operator.requiredChildDistribution, operator.requiredChildOrdering, operator.children) - - val fixedChildren = requirements.zipped.map { - case (AllTuples, rowOrdering, child) => - addOperatorsIfNecessary(SinglePartition, rowOrdering, child) - case (ClusteredDistribution(clustering), rowOrdering, child) => - addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child) - case (OrderedDistribution(ordering), rowOrdering, child) => - addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child) + // Now, we need to add ExchangeCoordinator if necessary. + // Actually, it is not a good idea to add ExchangeCoordinators while we are adding Exchanges. + // However, with the way that we plan the query, we do not have a place where we have a + // global picture of all shuffle dependencies of a post-shuffle stage. So, we add coordinator + // at here for now. + // Once we finish https://issues.apache.org/jira/browse/SPARK-10665, + // we can first add Exchanges and then add coordinator once we have a DAG of query fragments. + children = withExchangeCoordinator(children, requiredChildDistributions) - case (UnspecifiedDistribution, Seq(), child) => + // Now that we've performed any necessary shuffles, add sorts to guarantee output orderings: + children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) => + if (requiredOrdering.nonEmpty) { + // If child.outputOrdering is [a, b] and requiredOrdering is [a], we do not need to sort. + if (requiredOrdering != child.outputOrdering.take(requiredOrdering.length)) { + Sort(requiredOrdering, global = false, child = child) + } else { child - case (UnspecifiedDistribution, rowOrdering, child) => - sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child) - - case (dist, ordering, _) => - sys.error(s"Don't know how to ensure $dist with ordering $ordering") + } + } else { + child } + } - operator.withNewChildren(fixedChildren) + operator.withNewChildren(children) + } + + def apply(plan: SparkPlan): SparkPlan = plan.transformUp { + case operator: SparkPlan => ensureDistributionAndOrdering(operator) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala new file mode 100644 index 000000000000..827fdd278460 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala @@ -0,0 +1,271 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import java.util.{Map => JMap, HashMap => JHashMap} +import javax.annotation.concurrent.GuardedBy + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{Logging, SimpleFutureAction, ShuffleDependency, MapOutputStatistics} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow + +/** + * A coordinator used to determine how we shuffle data between stages generated by Spark SQL. + * Right now, the work of this coordinator is to determine the number of post-shuffle partitions + * for a stage that needs to fetch shuffle data from one or multiple stages. + * + * A coordinator is constructed with three parameters, `numExchanges`, + * `targetPostShuffleInputSize`, and `minNumPostShufflePartitions`. + * - `numExchanges` is used to indicate how many [[Exchange]]s will be registered to + * this coordinator. So, when we start to do any actual work, we have a way to make sure that + * we have the expected number of [[Exchange]]s. + * - `targetPostShuffleInputSize` is the targeted size of a post-shuffle partition's + * input data. With this parameter, we can estimate the number of post-shuffle partitions. + * This parameter is configured through + * `spark.sql.adaptive.shuffle.targetPostShuffleInputSize`. + * - `minNumPostShufflePartitions` is an optional parameter. If it is defined, this coordinator + * will try to make sure that there are at least `minNumPostShufflePartitions` post-shuffle + * partitions. + * + * The workflow of this coordinator is described as follows: + * - Before the execution of a [[SparkPlan]], for an [[Exchange]] operator, + * if an [[ExchangeCoordinator]] is assigned to it, it registers itself to this coordinator. + * This happens in the `doPrepare` method. + * - Once we start to execute a physical plan, an [[Exchange]] registered to this coordinator will + * call `postShuffleRDD` to get its corresponding post-shuffle [[ShuffledRowRDD]]. + * If this coordinator has made the decision on how to shuffle data, this [[Exchange]] will + * immediately get its corresponding post-shuffle [[ShuffledRowRDD]]. + * - If this coordinator has not made the decision on how to shuffle data, it will ask those + * registered [[Exchange]]s to submit their pre-shuffle stages. Then, based on the size + * statistics of pre-shuffle partitions, this coordinator will determine the number of + * post-shuffle partitions and pack multiple pre-shuffle partitions with continuous indices + * into a single post-shuffle partition whenever necessary. + * - Finally, this coordinator will create post-shuffle [[ShuffledRowRDD]]s for all registered + * [[Exchange]]s. So, when an [[Exchange]] calls `postShuffleRDD`, this coordinator can + * look up the corresponding [[RDD]]. + * + * The strategy used to determine the number of post-shuffle partitions is described as follows. + * To determine the number of post-shuffle partitions, we have a target input size for a + * post-shuffle partition. Once we have the size statistics of pre-shuffle partitions from stages + * corresponding to the registered [[Exchange]]s, we will do a pass over those statistics and + * pack pre-shuffle partitions with continuous indices into a single post-shuffle partition until + * the size of a post-shuffle partition is equal to or greater than the target size.
+ * For example, we have two stages with the following pre-shuffle partition size statistics: + * stage 1: [100 MB, 20 MB, 100 MB, 10MB, 30 MB] + * stage 2: [10 MB, 10 MB, 70 MB, 5 MB, 5 MB] + * assuming the target input size is 128 MB, we will have three post-shuffle partitions, + * which are: + * - post-shuffle partition 0: pre-shuffle partition 0 and 1 + * - post-shuffle partition 1: pre-shuffle partition 2 + * - post-shuffle partition 2: pre-shuffle partition 3 and 4 + */ +private[sql] class ExchangeCoordinator( + numExchanges: Int, + advisoryTargetPostShuffleInputSize: Long, + minNumPostShufflePartitions: Option[Int] = None) + extends Logging { + + // The registered Exchange operators. + private[this] val exchanges = ArrayBuffer[Exchange]() + + // This map is used to lookup the post-shuffle ShuffledRowRDD for an Exchange operator. + private[this] val postShuffleRDDs: JMap[Exchange, ShuffledRowRDD] = + new JHashMap[Exchange, ShuffledRowRDD](numExchanges) + + // A boolean that indicates if this coordinator has made decision on how to shuffle data. + // This variable will only be updated by doEstimationIfNecessary, which is protected by + // synchronized. + @volatile private[this] var estimated: Boolean = false + + /** + * Registers an [[Exchange]] operator to this coordinator. This method is only allowed to be + * called in the `doPrepare` method of an [[Exchange]] operator. + */ + @GuardedBy("this") + def registerExchange(exchange: Exchange): Unit = synchronized { + exchanges += exchange + } + + def isEstimated: Boolean = estimated + + /** + * Estimates partition start indices for post-shuffle partitions based on + * mapOutputStatistics provided by all pre-shuffle stages. + */ + private[sql] def estimatePartitionStartIndices( + mapOutputStatistics: Array[MapOutputStatistics]): Array[Int] = { + // If we have mapOutputStatistics.length < numExchange, it is because we do not submit + // a stage when the number of partitions of this dependency is 0. + assert(mapOutputStatistics.length <= numExchanges) + + // If minNumPostShufflePartitions is defined, it is possible that we need to use a + // value less than advisoryTargetPostShuffleInputSize as the target input size of + // a post shuffle task. + val targetPostShuffleInputSize = minNumPostShufflePartitions match { + case Some(numPartitions) => + val totalPostShuffleInputSize = mapOutputStatistics.map(_.bytesByPartitionId.sum).sum + // The max at here is to make sure that when we have an empty table, we + // only have a single post-shuffle partition. + // There is no particular reason that we pick 16. We just need a number to + // prevent maxPostShuffleInputSize from being set to 0. + val maxPostShuffleInputSize = + math.max(math.ceil(totalPostShuffleInputSize / numPartitions.toDouble).toLong, 16) + math.min(maxPostShuffleInputSize, advisoryTargetPostShuffleInputSize) + + case None => advisoryTargetPostShuffleInputSize + } + + logInfo( + s"advisoryTargetPostShuffleInputSize: $advisoryTargetPostShuffleInputSize, " + + s"targetPostShuffleInputSize $targetPostShuffleInputSize.") + + // Make sure we do get the same number of pre-shuffle partitions for those stages. + val distinctNumPreShufflePartitions = + mapOutputStatistics.map(stats => stats.bytesByPartitionId.length).distinct + // The reason that we are expecting a single value of the number of pre-shuffle partitions + // is that when we add Exchanges, we set the number of pre-shuffle partitions + // (i.e. 
map output partitions) using a static setting, which is the value of + // spark.sql.shuffle.partitions. Even if two input RDDs are having different + // number of partitions, they will have the same number of pre-shuffle partitions + // (i.e. map output partitions). + assert( + distinctNumPreShufflePartitions.length == 1, + "There should be only one distinct value of the number pre-shuffle partitions " + + "among registered Exchange operator.") + val numPreShufflePartitions = distinctNumPreShufflePartitions.head + + val partitionStartIndices = ArrayBuffer[Int]() + // The first element of partitionStartIndices is always 0. + partitionStartIndices += 0 + + var postShuffleInputSize = 0L + + var i = 0 + while (i < numPreShufflePartitions) { + // We calculate the total size of ith pre-shuffle partitions from all pre-shuffle stages. + // Then, we add the total size to postShuffleInputSize. + var j = 0 + while (j < mapOutputStatistics.length) { + postShuffleInputSize += mapOutputStatistics(j).bytesByPartitionId(i) + j += 1 + } + + // If the current postShuffleInputSize is equal or greater than the + // targetPostShuffleInputSize, We need to add a new element in partitionStartIndices. + if (postShuffleInputSize >= targetPostShuffleInputSize) { + if (i < numPreShufflePartitions - 1) { + // Next start index. + partitionStartIndices += i + 1 + } else { + // This is the last element. So, we do not need to append the next start index to + // partitionStartIndices. + } + // reset postShuffleInputSize. + postShuffleInputSize = 0L + } + + i += 1 + } + + partitionStartIndices.toArray + } + + @GuardedBy("this") + private def doEstimationIfNecessary(): Unit = synchronized { + // It is unlikely that this method will be called from multiple threads + // (when multiple threads trigger the execution of THIS physical) + // because in common use cases, we will create new physical plan after + // users apply operations (e.g. projection) to an existing DataFrame. + // However, if it happens, we have synchronized to make sure only one + // thread will trigger the job submission. + if (!estimated) { + // Make sure we have the expected number of registered Exchange operators. + assert(exchanges.length == numExchanges) + + val newPostShuffleRDDs = new JHashMap[Exchange, ShuffledRowRDD](numExchanges) + + // Submit all map stages + val shuffleDependencies = ArrayBuffer[ShuffleDependency[Int, InternalRow, InternalRow]]() + val submittedStageFutures = ArrayBuffer[SimpleFutureAction[MapOutputStatistics]]() + var i = 0 + while (i < numExchanges) { + val exchange = exchanges(i) + val shuffleDependency = exchange.prepareShuffleDependency() + shuffleDependencies += shuffleDependency + if (shuffleDependency.rdd.partitions.length != 0) { + // submitMapStage does not accept RDD with 0 partition. + // So, we will not submit this dependency. + submittedStageFutures += + exchange.sqlContext.sparkContext.submitMapStage(shuffleDependency) + } + i += 1 + } + + // Wait for the finishes of those submitted map stages. + val mapOutputStatistics = new Array[MapOutputStatistics](submittedStageFutures.length) + var j = 0 + while (j < submittedStageFutures.length) { + // This call is a blocking call. If the stage has not finished, we will wait at here. + mapOutputStatistics(j) = submittedStageFutures(j).get() + j += 1 + } + + // Now, we estimate partitionStartIndices. partitionStartIndices.length will be the + // number of post-shuffle partitions. 
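// A simplified, self-contained sketch of the packing rule used by estimatePartitionStartIndices
// (the object name, the hard-coded sizes, and the main method are illustrative only; the real
// method works on MapOutputStatistics and additionally caps the target size when
// minNumPostShufflePartitions is set):
object PartitionPackingSketch {
  /** bytesByPartitionId(stage)(partition) -> start indices of the post-shuffle partitions. */
  def pack(bytesByPartitionId: Seq[Array[Long]], targetSize: Long): Array[Int] = {
    val numPreShufflePartitions = bytesByPartitionId.head.length
    val startIndices = scala.collection.mutable.ArrayBuffer(0)
    var currentSize = 0L
    for (i <- 0 until numPreShufflePartitions) {
      // Total size of the i-th pre-shuffle partition across all registered shuffles.
      currentSize += bytesByPartitionId.map(_(i)).sum
      if (currentSize >= targetSize) {
        // Close the current post-shuffle partition; the next one starts at i + 1.
        if (i < numPreShufflePartitions - 1) startIndices += i + 1
        currentSize = 0L
      }
    }
    startIndices.toArray
  }

  def main(args: Array[String]): Unit = {
    val mb = 1L << 20
    val stage1 = Array(100 * mb, 20 * mb, 100 * mb, 10 * mb, 30 * mb)
    val stage2 = Array(10 * mb, 10 * mb, 70 * mb, 5 * mb, 5 * mb)
    // Prints [0, 2, 3]: three post-shuffle partitions, matching the example in the scaladoc.
    println(pack(Seq(stage1, stage2), targetSize = 128 * mb).mkString("[", ", ", "]"))
  }
}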
+ val partitionStartIndices = + if (mapOutputStatistics.length == 0) { + None + } else { + Some(estimatePartitionStartIndices(mapOutputStatistics)) + } + + var k = 0 + while (k < numExchanges) { + val exchange = exchanges(k) + val rdd = + exchange.preparePostShuffleRDD(shuffleDependencies(k), partitionStartIndices) + newPostShuffleRDDs.put(exchange, rdd) + + k += 1 + } + + // Finally, we set postShuffleRDDs and estimated. + assert(postShuffleRDDs.isEmpty) + assert(newPostShuffleRDDs.size() == numExchanges) + postShuffleRDDs.putAll(newPostShuffleRDDs) + estimated = true + } + } + + def postShuffleRDD(exchange: Exchange): ShuffledRowRDD = { + doEstimationIfNecessary() + + if (!postShuffleRDDs.containsKey(exchange)) { + throw new IllegalStateException( + s"The given $exchange is not registered in this coordinator.") + } + + postShuffleRDDs.get(exchange) + } + + override def toString: String = { + s"coordinator[target post-shuffle partition size: $advisoryTargetPostShuffleInputSize]" + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index da27a753a710..b8a43025882e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -17,19 +17,16 @@ package org.apache.spark.sql.execution -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} +import org.apache.spark.sql.sources.{HadoopFsRelation, BaseRelation} import org.apache.spark.sql.types.DataType import org.apache.spark.sql.{Row, SQLContext} -/** - * :: DeveloperApi :: - */ -@DeveloperApi + object RDDConversions { def productToRowRdd[A <: Product](data: RDD[A], outputTypes: Seq[DataType]): RDD[InternalRow] = { data.mapPartitions { iterator => @@ -77,6 +74,10 @@ private[sql] case class LogicalRDD( override def children: Seq[LogicalPlan] = Nil + override protected final def otherCopyArgs: Seq[AnyRef] = { + sqlContext :: Nil + } + override def newInstance(): LogicalRDD.this.type = LogicalRDD(output.map(_.newInstance()), rdd)(sqlContext).asInstanceOf[this.type] @@ -95,28 +96,32 @@ private[sql] case class LogicalRDD( /** Physical plan node for scanning data from an RDD. */ private[sql] case class PhysicalRDD( output: Seq[Attribute], - rdd: RDD[InternalRow]) extends LeafNode { - protected override def doExecute(): RDD[InternalRow] = rdd -} - -/** Logical plan node for scanning data from a local collection. 
*/ -private[sql] -case class LogicalLocalTable(output: Seq[Attribute], rows: Seq[InternalRow])(sqlContext: SQLContext) - extends LogicalPlan with MultiInstanceRelation { - - override def children: Seq[LogicalPlan] = Nil + rdd: RDD[InternalRow], + override val nodeName: String, + override val metadata: Map[String, String] = Map.empty, + override val outputsUnsafeRows: Boolean = false) + extends LeafNode { - override def newInstance(): this.type = - LogicalLocalTable(output.map(_.newInstance()), rows)(sqlContext).asInstanceOf[this.type] + protected override def doExecute(): RDD[InternalRow] = rdd - override def sameResult(plan: LogicalPlan): Boolean = plan match { - case LogicalRDD(_, otherRDD) => rows == rows - case _ => false + override def simpleString: String = { + val metadataEntries = for ((key, value) <- metadata.toSeq.sorted) yield s"$key: $value" + s"Scan $nodeName${output.mkString("[", ",", "]")}${metadataEntries.mkString(" ", ", ", "")}" } +} - @transient override lazy val statistics: Statistics = Statistics( - // TODO: Improve the statistics estimation. - // This is made small enough so it can be broadcasted. - sizeInBytes = sqlContext.conf.autoBroadcastJoinThreshold - 1 - ) +private[sql] object PhysicalRDD { + // Metadata keys + val INPUT_PATHS = "InputPaths" + val PUSHED_FILTERS = "PushedFilters" + + def createFromDataSource( + output: Seq[Attribute], + rdd: RDD[InternalRow], + relation: BaseRelation, + metadata: Map[String, String] = Map.empty): PhysicalRDD = { + // All HadoopFsRelations output UnsafeRows + val outputUnsafeRows = relation.isInstanceOf[HadoopFsRelation] + PhysicalRDD(output, rdd, relation.toString, metadata, outputUnsafeRows) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala index d90cae1c4c06..91530bd63798 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ @@ -32,7 +31,6 @@ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartit * @param output The output Schema * @param child Child operator */ -@DeveloperApi case class Expand( projections: Seq[Seq[Expression]], output: Seq[Attribute], @@ -43,14 +41,24 @@ case class Expand( // as UNKNOWN partitioning override def outputPartitioning: Partitioning = UnknownPartitioning(0) + override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true + + override def references: AttributeSet = + AttributeSet(projections.flatten.flatMap(_.references)) + + private[this] val projection = { + if (outputsUnsafeRows) { + (exprs: Seq[Expression]) => UnsafeProjection.create(exprs, child.output) + } else { + (exprs: Seq[Expression]) => newMutableProjection(exprs, child.output)() + } + } + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { child.execute().mapPartitions { iter => - // TODO Move out projection objects creation and transfer to - // workers via closure. However we can't assume the Projection - // is serializable because of the code gen, so we have to - // create the projections within each of the partition processing. 
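// For intuition, a minimal model of what Expand computes, using plain Scala collections
// (the column layout and the grouping-id column below are hypothetical): every input row is
// emitted once per projection list, which is how GROUPING SETS / ROLLUP / CUBE are planned.
object ExpandSketch {
  type Row = Seq[Any]

  /** Applies each projection to every input row, in row-major order like Expand's iterator. */
  def expand(input: Seq[Row], projections: Seq[Row => Row]): Seq[Row] =
    input.flatMap(row => projections.map(p => p(row)))

  def main(args: Array[String]): Unit = {
    val input: Seq[Row] = Seq(Seq("a", 1), Seq("b", 2))
    // Two "grouping sets": (col1, col2) and (col1), with a grouping id in the last column.
    val projections: Seq[Row => Row] = Seq(
      row => Seq(row(0), row(1), 0),
      row => Seq(row(0), null, 1))
    // Prints List(List(a, 1, 0), List(a, null, 1), List(b, 2, 0), List(b, null, 1)).
    println(expand(input, projections))
  }
}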
- val groups = projections.map(ee => newProjection(ee, child.output)).toArray - + val groups = projections.map(projection).toArray new Iterator[InternalRow] { private[this] var result: InternalRow = _ private[this] var idx = -1 // -1 means the initial state diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/FileRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/FileRelation.scala new file mode 100644 index 000000000000..7a2a9eed5807 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/FileRelation.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +/** + * An interface for relations that are backed by files. When a class implements this interface, + * the list of paths that it returns will be returned to a user who calls `inputPaths` on any + * DataFrame that queries this relation. + */ +private[sql] trait FileRelation { + /** Returns the list of files that will be read when scanning this relation. */ + def inputFiles: Array[String] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index c3c0dc441c92..54b8cb58285c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -35,7 +34,6 @@ private[execution] sealed case class LazyIterator(func: () => TraversableOnce[In } /** - * :: DeveloperApi :: * Applies a [[Generator]] to a stream of input rows, combining the * output of each into a new stream of rows. This operation is similar to a `flatMap` in functional * programming with one important additional feature, which allows the input rows to be joined with @@ -48,7 +46,6 @@ private[execution] sealed case class LazyIterator(func: () => TraversableOnce[In * @param output the output attributes of this node, which constructed in analysis phase, * and we can not change it, as the parent node bound with it already. 
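// A minimal model of the generate-with-join semantics implemented below, using plain Scala
// collections (the explode-like generator, the row layout, and the helper names are
// hypothetical): the generator maps one input row to zero or more rows; join = true prepends
// the input row to each generated row, and outer = true keeps input rows whose generator
// output is empty by joining them with a row of nulls.
object GenerateSketch {
  type Row = Seq[Any]

  def generate(
      input: Seq[Row],
      generator: Row => Seq[Row],
      join: Boolean,
      outer: Boolean,
      generatorArity: Int): Seq[Row] = {
    input.flatMap { row =>
      val generated = generator(row)
      if (join) {
        val padded =
          if (generated.isEmpty && outer) Seq(Seq.fill[Any](generatorArity)(null)) else generated
        padded.map(row ++ _)
      } else {
        generated
      }
    }
  }

  def main(args: Array[String]): Unit = {
    // Explode-like generator over the second column, which holds a list of ints.
    val gen: Row => Seq[Row] = row => row(1).asInstanceOf[Seq[Int]].map(x => Seq(x))
    val input: Seq[Row] = Seq(Seq("a", Seq(1, 2)), Seq("b", Seq.empty[Int]))
    // Prints List(List(a, List(1, 2), 1), List(a, List(1, 2), 2), List(b, List(), null)).
    println(generate(input, gen, join = true, outer = true, generatorArity = 1))
  }
}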
*/ -@DeveloperApi case class Generate( generator: Generator, join: Boolean, @@ -62,7 +59,7 @@ case class Generate( protected override def doExecute(): RDD[InternalRow] = { // boundGenerator.terminate() should be triggered after all of the rows in the partition if (join) { - child.execute().mapPartitions { iter => + child.execute().mapPartitionsInternal { iter => val generatorNullRow = InternalRow.fromSeq(Seq.fill[Any](generator.elementTypes.size)(null)) val joinedRow = new JoinedRow @@ -82,7 +79,7 @@ case class Generate( } } } else { - child.execute().mapPartitions { iter => + child.execute().mapPartitionsInternal { iter => iter.flatMap(row => boundGenerator.eval(row)) ++ LazyIterator(() => boundGenerator.terminate()) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala deleted file mode 100644 index cd87b8deba0c..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ /dev/null @@ -1,347 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import java.io.IOException - -import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.catalyst.trees._ -import org.apache.spark.sql.types._ - -case class AggregateEvaluation( - schema: Seq[Attribute], - initialValues: Seq[Expression], - update: Seq[Expression], - result: Expression) - -/** - * :: DeveloperApi :: - * Alternate version of aggregation that leverages projection and thus code generation. - * Aggregations are converted into a set of projections from a aggregation buffer tuple back onto - * itself. Currently only used for simple aggregations like SUM, COUNT, or AVERAGE are supported. - * - * @param partial if true then aggregation is done partially on local data without shuffling to - * ensure all values where `groupingExpressions` are equal are present. - * @param groupingExpressions expressions that are evaluated to determine grouping. - * @param aggregateExpressions expressions that are computed for each group. - * @param unsafeEnabled whether to allow Unsafe-based aggregation buffers to be used. - * @param child the input data source. 
- */ -@DeveloperApi -case class GeneratedAggregate( - partial: Boolean, - groupingExpressions: Seq[Expression], - aggregateExpressions: Seq[NamedExpression], - unsafeEnabled: Boolean, - child: SparkPlan) - extends UnaryNode { - - override def requiredChildDistribution: Seq[Distribution] = - if (partial) { - UnspecifiedDistribution :: Nil - } else { - if (groupingExpressions == Nil) { - AllTuples :: Nil - } else { - ClusteredDistribution(groupingExpressions) :: Nil - } - } - - override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) - - protected override def doExecute(): RDD[InternalRow] = { - val aggregatesToCompute = aggregateExpressions.flatMap { a => - a.collect { case agg: AggregateExpression1 => agg} - } - - // If you add any new function support, please add tests in org.apache.spark.sql.SQLQuerySuite - // (in test "aggregation with codegen"). - val computeFunctions = aggregatesToCompute.map { - case c @ Count(expr) => - // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its - // UnscaledValue will be null if and only if x is null; helps with Average on decimals - val toCount = expr match { - case UnscaledValue(e) => e - case _ => expr - } - val currentCount = AttributeReference("currentCount", LongType, nullable = false)() - val initialValue = Literal(0L) - val updateFunction = If(IsNotNull(toCount), Add(currentCount, Literal(1L)), currentCount) - val result = currentCount - - AggregateEvaluation(currentCount :: Nil, initialValue :: Nil, updateFunction :: Nil, result) - - case s @ Sum(expr) => - val calcType = - expr.dataType match { - case DecimalType.Fixed(p, s) => - DecimalType.bounded(p + 10, s) - case _ => - expr.dataType - } - - val currentSum = AttributeReference("currentSum", calcType, nullable = true)() - val initialValue = Literal.create(null, calcType) - - // Coalesce avoids double calculation... - // but really, common sub expression elimination would be better.... 
- val zero = Cast(Literal(0), calcType) - val updateFunction = Coalesce( - Add( - Coalesce(currentSum :: zero :: Nil), - Cast(expr, calcType) - ) :: currentSum :: Nil) - val result = - expr.dataType match { - case DecimalType.Fixed(_, _) => - Cast(currentSum, s.dataType) - case _ => currentSum - } - - AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result) - - case m @ Max(expr) => - val currentMax = AttributeReference("currentMax", expr.dataType, nullable = true)() - val initialValue = Literal.create(null, expr.dataType) - val updateMax = MaxOf(currentMax, expr) - - AggregateEvaluation( - currentMax :: Nil, - initialValue :: Nil, - updateMax :: Nil, - currentMax) - - case m @ Min(expr) => - val currentMin = AttributeReference("currentMin", expr.dataType, nullable = true)() - val initialValue = Literal.create(null, expr.dataType) - val updateMin = MinOf(currentMin, expr) - - AggregateEvaluation( - currentMin :: Nil, - initialValue :: Nil, - updateMin :: Nil, - currentMin) - - case CollectHashSet(Seq(expr)) => - val set = - AttributeReference("hashSet", new OpenHashSetUDT(expr.dataType), nullable = false)() - val initialValue = NewSet(expr.dataType) - val addToSet = AddItemToSet(expr, set) - - AggregateEvaluation( - set :: Nil, - initialValue :: Nil, - addToSet :: Nil, - set) - - case CombineSetsAndCount(inputSet) => - val elementType = inputSet.dataType.asInstanceOf[OpenHashSetUDT].elementType - val set = - AttributeReference("hashSet", new OpenHashSetUDT(elementType), nullable = false)() - val initialValue = NewSet(elementType) - val collectSets = CombineSets(set, inputSet) - - AggregateEvaluation( - set :: Nil, - initialValue :: Nil, - collectSets :: Nil, - CountSet(set)) - - case o => sys.error(s"$o can't be codegened.") - } - - val computationSchema = computeFunctions.flatMap(_.schema) - - val resultMap: Map[TreeNodeRef, Expression] = - aggregatesToCompute.zip(computeFunctions).map { - case (agg, func) => new TreeNodeRef(agg) -> func.result - }.toMap - - val namedGroups = groupingExpressions.zipWithIndex.map { - case (ne: NamedExpression, _) => (ne, ne.toAttribute) - case (e, i) => (e, Alias(e, s"GroupingExpr$i")().toAttribute) - } - - // The set of expressions that produce the final output given the aggregation buffer and the - // grouping expressions. - val resultExpressions = aggregateExpressions.map(_.transform { - case e: Expression if resultMap.contains(new TreeNodeRef(e)) => resultMap(new TreeNodeRef(e)) - case e: Expression => - namedGroups.collectFirst { - case (expr, attr) if expr semanticEquals e => attr - }.getOrElse(e) - }) - - val aggregationBufferSchema: StructType = StructType.fromAttributes(computationSchema) - - val groupKeySchema: StructType = { - val fields = groupingExpressions.zipWithIndex.map { case (expr, idx) => - // This is a dummy field name - StructField(idx.toString, expr.dataType, expr.nullable) - } - StructType(fields) - } - - val schemaSupportsUnsafe: Boolean = { - UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) && - UnsafeProjection.canSupport(groupKeySchema) - } - - child.execute().mapPartitions { iter => - // Builds a new custom class for holding the results of aggregation for a group. - val initialValues = computeFunctions.flatMap(_.initialValues) - val newAggregationBuffer = newProjection(initialValues, child.output) - log.info(s"Initial values: ${initialValues.mkString(",")}") - - // A projection that computes the group given an input tuple. 
- val groupProjection = newProjection(groupingExpressions, child.output) - log.info(s"Grouping Projection: ${groupingExpressions.mkString(",")}") - - // A projection that is used to update the aggregate values for a group given a new tuple. - // This projection should be targeted at the current values for the group and then applied - // to a joined row of the current values with the new input row. - val updateExpressions = computeFunctions.flatMap(_.update) - val updateSchema = computeFunctions.flatMap(_.schema) ++ child.output - val updateProjection = newMutableProjection(updateExpressions, updateSchema)() - log.info(s"Update Expressions: ${updateExpressions.mkString(",")}") - - // A projection that produces the final result, given a computation. - val resultProjectionBuilder = - newMutableProjection( - resultExpressions, - namedGroups.map(_._2) ++ computationSchema) - log.info(s"Result Projection: ${resultExpressions.mkString(",")}") - - val joinedRow = new JoinedRow - - if (!iter.hasNext) { - // This is an empty input, so return early so that we do not allocate data structures - // that won't be cleaned up (see SPARK-8357). - if (groupingExpressions.isEmpty) { - // This is a global aggregate, so return an empty aggregation buffer. - val resultProjection = resultProjectionBuilder() - Iterator(resultProjection(newAggregationBuffer(EmptyRow))) - } else { - // This is a grouped aggregate, so return an empty iterator. - Iterator[InternalRow]() - } - } else if (groupingExpressions.isEmpty) { - // TODO: Codegening anything other than the updateProjection is probably over kill. - val buffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow] - var currentRow: InternalRow = null - updateProjection.target(buffer) - - while (iter.hasNext) { - currentRow = iter.next() - updateProjection(joinedRow(buffer, currentRow)) - } - - val resultProjection = resultProjectionBuilder() - Iterator(resultProjection(buffer)) - - } else if (unsafeEnabled && schemaSupportsUnsafe) { - assert(iter.hasNext, "There should be at least one row for this path") - log.info("Using Unsafe-based aggregator") - val pageSizeBytes = SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m") - val aggregationMap = new UnsafeFixedWidthAggregationMap( - newAggregationBuffer(EmptyRow), - aggregationBufferSchema, - groupKeySchema, - TaskContext.get.taskMemoryManager(), - SparkEnv.get.shuffleMemoryManager, - 1024 * 16, // initial capacity - pageSizeBytes, - false // disable tracking of performance metrics - ) - - while (iter.hasNext) { - val currentRow: InternalRow = iter.next() - val groupKey: InternalRow = groupProjection(currentRow) - val aggregationBuffer = aggregationMap.getAggregationBuffer(groupKey) - if (aggregationBuffer == null) { - throw new IOException("Could not allocate memory to grow aggregation buffer") - } - updateProjection.target(aggregationBuffer)(joinedRow(aggregationBuffer, currentRow)) - } - - new Iterator[InternalRow] { - private[this] val mapIterator = aggregationMap.iterator() - private[this] val resultProjection = resultProjectionBuilder() - private[this] var _hasNext = mapIterator.next() - - def hasNext: Boolean = _hasNext - - def next(): InternalRow = { - if (_hasNext) { - val result = resultProjection(joinedRow(mapIterator.getKey, mapIterator.getValue)) - _hasNext = mapIterator.next() - if (_hasNext) { - result - } else { - // This is the last element in the iterator, so let's free the buffer. 
Before we do, - // though, we need to make a defensive copy of the result so that we don't return an - // object that might contain dangling pointers to the freed memory - val resultCopy = result.copy() - aggregationMap.free() - resultCopy - } - } else { - throw new java.util.NoSuchElementException - } - } - } - } else { - if (unsafeEnabled) { - log.info("Not using Unsafe-based aggregator because it is not supported for this schema") - } - val buffers = new java.util.HashMap[InternalRow, MutableRow]() - - var currentRow: InternalRow = null - while (iter.hasNext) { - currentRow = iter.next() - val currentGroup = groupProjection(currentRow) - var currentBuffer = buffers.get(currentGroup) - if (currentBuffer == null) { - currentBuffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow] - buffers.put(currentGroup, currentBuffer) - } - // Target the projection at the current aggregation buffer and then project the updated - // values. - updateProjection.target(currentBuffer)(joinedRow(currentBuffer, currentRow)) - } - - new Iterator[InternalRow] { - private[this] val resultIterator = buffers.entrySet.iterator() - private[this] val resultProjection = resultProjectionBuilder() - - def hasNext: Boolean = resultIterator.hasNext - - def next(): InternalRow = { - val currentGroup = resultIterator.next() - resultProjection(joinedRow(currentGroup.getKey, currentGroup.getValue)) - } - } - } - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala new file mode 100644 index 000000000000..6a8850129f1a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateOrdering} +import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder, Ascending, Expression} + +object GroupedIterator { + def apply( + input: Iterator[InternalRow], + keyExpressions: Seq[Expression], + inputSchema: Seq[Attribute]): Iterator[(InternalRow, Iterator[InternalRow])] = { + if (input.hasNext) { + new GroupedIterator(input.buffered, keyExpressions, inputSchema) + } else { + Iterator.empty + } + } +} + +/** + * Iterates over a presorted set of rows, chunking it up by the grouping expression. Each call to + * next will return a pair containing the current group and an iterator that will return all the + * elements of that group. Iterators for each group are lazily constructed by extracting rows + * from the input iterator. 
As such, full groups are never materialized by this class. + * + * Example input: + * {{{ + * Input: [a, 1], [b, 2], [b, 3] + * Grouping: x#1 + * InputSchema: x#1, y#2 + * }}} + * + * Result: + * {{{ + * First call to next(): ([a], Iterator([a, 1]) + * Second call to next(): ([b], Iterator([b, 2], [b, 3]) + * }}} + * + * Note, the class does not handle the case of an empty input for simplicity of implementation. + * Use the factory to construct a new instance. + * + * @param input An iterator of rows. This iterator must be ordered by the groupingExpressions or + * it is possible for the same group to appear more than once. + * @param groupingExpressions The set of expressions used to do grouping. The result of evaluating + * these expressions will be returned as the first part of each call + * to `next()`. + * @param inputSchema The schema of the rows in the `input` iterator. + */ +class GroupedIterator private( + input: BufferedIterator[InternalRow], + groupingExpressions: Seq[Expression], + inputSchema: Seq[Attribute]) + extends Iterator[(InternalRow, Iterator[InternalRow])] { + + /** Compares two input rows and returns 0 if they are in the same group. */ + val sortOrder = groupingExpressions.map(SortOrder(_, Ascending)) + val keyOrdering = GenerateOrdering.generate(sortOrder, inputSchema) + + /** Creates a row containing only the key for a given input row. */ + val keyProjection = GenerateUnsafeProjection.generate(groupingExpressions, inputSchema) + + /** + * Holds null or the row that will be returned on next call to `next()` in the inner iterator. + */ + var currentRow = input.next() + + /** Holds a copy of an input row that is in the current group. */ + var currentGroup = currentRow.copy() + + assert(keyOrdering.compare(currentGroup, currentRow) == 0) + var currentIterator = createGroupValuesIterator() + + /** + * Return true if we already have the next iterator or fetching a new iterator is successful. + * + * Note that, if we get the iterator by `next`, we should consume it before call `hasNext`, + * because we will consume the input data to skip to next group while fetching a new iterator, + * thus make the previous iterator empty. + */ + def hasNext: Boolean = currentIterator != null || fetchNextGroupIterator + + def next(): (InternalRow, Iterator[InternalRow]) = { + assert(hasNext) // Ensure we have fetched the next iterator. + val ret = (keyProjection(currentGroup), currentIterator) + currentIterator = null + ret + } + + private def fetchNextGroupIterator(): Boolean = { + assert(currentIterator == null) + + if (currentRow == null && input.hasNext) { + currentRow = input.next() + } + + if (currentRow == null) { + // These is no data left, return false. + false + } else { + // Skip to next group. + while (input.hasNext && keyOrdering.compare(currentGroup, currentRow) == 0) { + currentRow = input.next() + } + + if (keyOrdering.compare(currentGroup, currentRow) == 0) { + // We are in the last group, there is no more groups, return false. + false + } else { + // Now the `currentRow` is the first row of next group. 
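        // Note: the copy below is needed presumably because the upstream iterator may reuse and
        // mutate the same row object between calls to next(); keeping a defensive copy of the
        // first row of the group gives keyOrdering a stable snapshot to compare against.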
+ currentGroup = currentRow.copy() + currentIterator = createGroupValuesIterator() + true + } + } + } + + private def createGroupValuesIterator(): Iterator[InternalRow] = { + new Iterator[InternalRow] { + def hasNext: Boolean = currentRow != null || fetchNextRowInGroup() + + def next(): InternalRow = { + assert(hasNext) + val res = currentRow + currentRow = null + res + } + + private def fetchNextRowInGroup(): Boolean = { + assert(currentRow == null) + + if (input.hasNext) { + // The inner iterator should NOT consume the input into next group, here we use `head` to + // peek the next input, to see if we should continue to process it. + if (keyOrdering.compare(currentGroup, input.head) == 0) { + // Next input is in the current group. Continue the inner iterator. + currentRow = input.next() + true + } else { + // Next input is not in the right group. End this inner iterator. + false + } + } else { + // There is no more data, return false. + false + } + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala index 34e926e4582b..ba7f6287ac6c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute @@ -34,13 +33,11 @@ private[sql] case class LocalTableScan( protected override def doExecute(): RDD[InternalRow] = rdd - override def executeCollect(): Array[Row] = { - val converter = CatalystTypeConverters.createToScalaConverter(schema) - rows.map(converter(_).asInstanceOf[Row]).toArray + override def executeCollect(): Array[InternalRow] = { + rows.toArray } - override def executeTake(limit: Int): Array[Row] = { - val converter = CatalystTypeConverters.createToScalaConverter(schema) - rows.map(converter(_).asInstanceOf[Row]).take(limit).toArray + override def executeTake(limit: Int): Array[InternalRow] = { + rows.take(limit).toArray } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala new file mode 100644 index 000000000000..107570f9dbcc --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +/** + * The primary workflow for executing relational queries using Spark. Designed to allow easy + * access to the intermediate phases of query execution for developers. + * + * While this is not a public class, we should avoid changing the function names for the sake of + * changing them, because a lot of developers use the feature for debugging. + */ +class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) { + + def assertAnalyzed(): Unit = sqlContext.analyzer.checkAnalysis(analyzed) + + lazy val analyzed: LogicalPlan = sqlContext.analyzer.execute(logical) + + lazy val withCachedData: LogicalPlan = { + assertAnalyzed() + sqlContext.cacheManager.useCachedData(analyzed) + } + + lazy val optimizedPlan: LogicalPlan = sqlContext.optimizer.execute(withCachedData) + + lazy val sparkPlan: SparkPlan = { + SQLContext.setActive(sqlContext) + sqlContext.planner.plan(optimizedPlan).next() + } + + // executedPlan should not be used to initialize any SparkPlan. It should be + // only used for execution. + lazy val executedPlan: SparkPlan = sqlContext.prepareForExecution.execute(sparkPlan) + + /** Internal version of the RDD. Avoids copies and has no schema */ + lazy val toRdd: RDD[InternalRow] = executedPlan.execute() + + protected def stringOrError[A](f: => A): String = + try f.toString catch { case e: Throwable => e.toString } + + def simpleString: String = { + s"""== Physical Plan == + |${stringOrError(executedPlan)} + """.stripMargin.trim + } + + override def toString: String = { + def output = + analyzed.output.map(o => s"${o.name}: ${o.dataType.simpleString}").mkString(", ") + + s"""== Parsed Logical Plan == + |${stringOrError(logical)} + |== Analyzed Logical Plan == + |${stringOrError(output)} + |${stringOrError(analyzed)} + |== Optimized Logical Plan == + |${stringOrError(optimizedPlan)} + |== Physical Plan == + |${stringOrError(executedPlan)} + """.stripMargin.trim + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala new file mode 100644 index 000000000000..b397d42612cf --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import scala.util.control.NonFatal + +import org.apache.commons.lang3.StringUtils +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.types.StructType + +/** A trait that holds shared code between DataFrames and Datasets. */ +private[sql] trait Queryable { + def schema: StructType + def queryExecution: QueryExecution + def sqlContext: SQLContext + + override def toString: String = { + try { + schema.map(f => s"${f.name}: ${f.dataType.simpleString}").mkString("[", ", ", "]") + } catch { + case NonFatal(e) => + s"Invalid tree; ${e.getMessage}:\n$queryExecution" + } + } + + def printSchema(): Unit + + def explain(extended: Boolean): Unit + + def explain(): Unit + + private[sql] def showString(_numRows: Int, truncate: Boolean = true): String + + /** + * Format the string representing rows for output + * @param rows The rows to show + * @param numRows Number of rows to show + * @param hasMoreData Whether some rows are not shown due to the limit + * @param truncate Whether truncate long strings and align cells right + * + */ + private[sql] def formatString ( + rows: Seq[Seq[String]], + numRows: Int, + hasMoreData : Boolean, + truncate: Boolean = true): String = { + val sb = new StringBuilder + val numCols = schema.fieldNames.length + + // Initialise the width of each column to a minimum value of '3' + val colWidths = Array.fill(numCols)(3) + + // Compute the width of each column + for (row <- rows) { + for ((cell, i) <- row.zipWithIndex) { + colWidths(i) = math.max(colWidths(i), cell.length) + } + } + + // Create SeparateLine + val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString() + + // column names + rows.head.zipWithIndex.map { case (cell, i) => + if (truncate) { + StringUtils.leftPad(cell, colWidths(i)) + } else { + StringUtils.rightPad(cell, colWidths(i)) + } + }.addString(sb, "|", "|", "|\n") + + sb.append(sep) + + // data + rows.tail.map { + _.zipWithIndex.map { case (cell, i) => + if (truncate) { + StringUtils.leftPad(cell.toString, colWidths(i)) + } else { + StringUtils.rightPad(cell.toString, colWidths(i)) + } + }.addString(sb, "|", "|", "|\n") + } + + sb.append(sep) + + // For Data that has more than "numRows" records + if (hasMoreData) { + val rowsString = if (numRows == 1) "row" else "rows" + sb.append(s"only showing top $numRows $rowsString\n") + } + + sb.toString() + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala new file mode 100644 index 000000000000..7462dbc4eba3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import java.util.NoSuchElementException + +import org.apache.spark.sql.catalyst.InternalRow + +/** + * An internal iterator interface which presents a more restrictive API than + * [[scala.collection.Iterator]]. + * + * One major departure from the Scala iterator API is the fusing of the `hasNext()` and `next()` + * calls: Scala's iterator allows users to call `hasNext()` without immediately advancing the + * iterator to consume the next row, whereas RowIterator combines these calls into a single + * [[advanceNext()]] method. + */ +private[sql] abstract class RowIterator { + /** + * Advance this iterator by a single row. Returns `false` if this iterator has no more rows + * and `true` otherwise. If this returns `true`, then the new row can be retrieved by calling + * [[getRow]]. + */ + def advanceNext(): Boolean + + /** + * Retrieve the row from this iterator. This method is idempotent. It is illegal to call this + * method after [[advanceNext()]] has returned `false`. + */ + def getRow: InternalRow + + /** + * Convert this RowIterator into a [[scala.collection.Iterator]]. + */ + def toScala: Iterator[InternalRow] = new RowIteratorToScala(this) +} + +object RowIterator { + def fromScala(scalaIter: Iterator[InternalRow]): RowIterator = { + scalaIter match { + case wrappedRowIter: RowIteratorToScala => wrappedRowIter.rowIter + case _ => new RowIteratorFromScala(scalaIter) + } + } +} + +private final class RowIteratorToScala(val rowIter: RowIterator) extends Iterator[InternalRow] { + private [this] var hasNextWasCalled: Boolean = false + private [this] var _hasNext: Boolean = false + override def hasNext: Boolean = { + // Idempotency: + if (!hasNextWasCalled) { + _hasNext = rowIter.advanceNext() + hasNextWasCalled = true + } + _hasNext + } + override def next(): InternalRow = { + if (!hasNext) throw new NoSuchElementException + hasNextWasCalled = false + rowIter.getRow + } +} + +private final class RowIteratorFromScala(scalaIter: Iterator[InternalRow]) extends RowIterator { + private[this] var _next: InternalRow = null + override def advanceNext(): Boolean = { + if (scalaIter.hasNext) { + _next = scalaIter.next() + true + } else { + _next = null + false + } + } + override def getRow: InternalRow = _next + override def toScala: Iterator[InternalRow] = scalaIter +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala new file mode 100644 index 000000000000..34971986261c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import java.util.concurrent.atomic.AtomicLong + +import org.apache.spark.SparkContext +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionStart, + SparkListenerSQLExecutionEnd} +import org.apache.spark.util.Utils + +private[sql] object SQLExecution { + + val EXECUTION_ID_KEY = "spark.sql.execution.id" + + private val _nextExecutionId = new AtomicLong(0) + + private def nextExecutionId: Long = _nextExecutionId.getAndIncrement + + /** + * Wrap an action that will execute "queryExecution" to track all Spark jobs in the body so that + * we can connect them with an execution. + */ + def withNewExecutionId[T]( + sqlContext: SQLContext, queryExecution: QueryExecution)(body: => T): T = { + val sc = sqlContext.sparkContext + val oldExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY) + if (oldExecutionId == null) { + val executionId = SQLExecution.nextExecutionId + sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString) + val r = try { + val callSite = Utils.getCallSite() + sqlContext.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart( + executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, + SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) + try { + body + } finally { + sqlContext.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) + } + } finally { + sc.setLocalProperty(EXECUTION_ID_KEY, null) + } + r + } else { + // Don't support nested `withNewExecutionId`. This is an example of the nested + // `withNewExecutionId`: + // + // class DataFrame { + // def foo: T = withNewExecutionId { something.createNewDataFrame().collect() } + // } + // + // Note: `collect` will call withNewExecutionId + // In this case, only the "executedPlan" for "collect" will be executed. The "executedPlan" + // for the outer DataFrame won't be executed. So it's meaningless to create a new Execution + // for the outer DataFrame. Even if we track it, since its "executedPlan" doesn't run, + // all accumulator metrics will be 0. It will confuse people if we show them in Web UI. + // + // A real case is the `DataFrame.count` method. + throw new IllegalArgumentException(s"$EXECUTION_ID_KEY is already set") + } + } + + /** + * Wrap an action with a known executionId. When running a different action in a different + * thread from the original one, this method can be used to connect the Spark jobs in this action + * with the known executionId, e.g., `BroadcastHashJoin.broadcastFuture`. 
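// Illustrative usage (a sketch, not part of this patch): connecting work done on a separate
// thread back to the SQL execution that spawned it. `sc` and `executionId` are assumed to come
// from the original thread, e.g. executionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY).
//
//   import org.apache.spark.SparkContext
//   import scala.concurrent.Future
//   import scala.concurrent.ExecutionContext.Implicits.global
//
//   def countOnAnotherThread(sc: SparkContext, executionId: String): Future[Long] =
//     Future {
//       SQLExecution.withExecutionId(sc, executionId) {
//         sc.parallelize(1 to 100).count()  // jobs here are grouped under `executionId` in the UI
//       }
//     }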
+ */ + def withExecutionId[T](sc: SparkContext, executionId: String)(body: => T): T = { + val oldExecutionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + try { + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId) + body + } finally { + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala index 88f5b13c8f24..42891287a300 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala @@ -17,15 +17,23 @@ package org.apache.spark.sql.execution +import java.util.Arrays + import org.apache.spark._ import org.apache.spark.rdd.RDD -import org.apache.spark.serializer.Serializer import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.types.DataType -private class ShuffledRowRDDPartition(val idx: Int) extends Partition { - override val index: Int = idx - override def hashCode(): Int = idx +/** + * The [[Partition]] used by [[ShuffledRowRDD]]. A post-shuffle partition + * (identified by `postShufflePartitionIndex`) contains a range of pre-shuffle partitions + * (`startPreShufflePartitionIndex` to `endPreShufflePartitionIndex - 1`, inclusive). + */ +private final class ShuffledRowRDDPartition( + val postShufflePartitionIndex: Int, + val startPreShufflePartitionIndex: Int, + val endPreShufflePartitionIndex: Int) extends Partition { + override val index: Int = postShufflePartitionIndex + override def hashCode(): Int = postShufflePartitionIndex } /** @@ -36,45 +44,130 @@ private class PartitionIdPassthrough(override val numPartitions: Int) extends Pa override def getPartition(key: Any): Int = key.asInstanceOf[Int] } +/** + * A Partitioner that might group together one or more partitions from the parent. + * + * @param parent a parent partitioner + * @param partitionStartIndices indices of partitions in parent that should create new partitions + * in child (this should be an array of increasing partition IDs). For example, if we have a + * parent with 5 partitions, and partitionStartIndices is [0, 2, 4], we get three output + * partitions, corresponding to partition ranges [0, 1], [2, 3] and [4] of the parent partitioner. + */ +class CoalescedPartitioner(val parent: Partitioner, val partitionStartIndices: Array[Int]) + extends Partitioner { + + @transient private lazy val parentPartitionMapping: Array[Int] = { + val n = parent.numPartitions + val result = new Array[Int](n) + for (i <- 0 until partitionStartIndices.length) { + val start = partitionStartIndices(i) + val end = if (i < partitionStartIndices.length - 1) partitionStartIndices(i + 1) else n + for (j <- start until end) { + result(j) = i + } + } + result + } + + override def numPartitions: Int = partitionStartIndices.length + + override def getPartition(key: Any): Int = { + parentPartitionMapping(parent.getPartition(key)) + } + + override def equals(other: Any): Boolean = other match { + case c: CoalescedPartitioner => + c.parent == parent && Arrays.equals(c.partitionStartIndices, partitionStartIndices) + case _ => + false + } + + override def hashCode(): Int = 31 * parent.hashCode() + Arrays.hashCode(partitionStartIndices) +} + /** * This is a specialized version of [[org.apache.spark.rdd.ShuffledRDD]] that is optimized for * shuffling rows instead of Java key-value pairs. 
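// A concrete check of the coalescing scheme described above (illustrative, not part of this
// patch): a parent partitioner with 5 partitions and start indices [0, 2, 4] yields 3
// post-shuffle partitions covering pre-shuffle ranges [0, 1], [2, 3] and [4].
//
//   import org.apache.spark.HashPartitioner
//
//   val parent = new HashPartitioner(5)
//   val coalesced = new CoalescedPartitioner(parent, Array(0, 2, 4))
//   assert(coalesced.numPartitions == 3)
//   // Every key lands in the child partition that owns its parent partition:
//   // parent partitions 0, 1 -> child 0; 2, 3 -> child 1; 4 -> child 2.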
Note that something like this should eventually * be implemented in Spark core, but that is blocked by some more general refactorings to shuffle * interfaces / internals. * - * @param prev the RDD being shuffled. Elements of this RDD are (partitionId, Row) pairs. - * Partition ids should be in the range [0, numPartitions - 1]. - * @param serializer the serializer used during the shuffle. - * @param numPartitions the number of post-shuffle partitions. + * This RDD takes a [[ShuffleDependency]] (`dependency`), + * and a optional array of partition start indices as input arguments + * (`specifiedPartitionStartIndices`). + * + * The `dependency` has the parent RDD of this RDD, which represents the dataset before shuffle + * (i.e. map output). Elements of this RDD are (partitionId, Row) pairs. + * Partition ids should be in the range [0, numPartitions - 1]. + * `dependency.partitioner` is the original partitioner used to partition + * map output, and `dependency.partitioner.numPartitions` is the number of pre-shuffle partitions + * (i.e. the number of partitions of the map output). + * + * When `specifiedPartitionStartIndices` is defined, `specifiedPartitionStartIndices.length` + * will be the number of post-shuffle partitions. For this case, the `i`th post-shuffle + * partition includes `specifiedPartitionStartIndices[i]` to + * `specifiedPartitionStartIndices[i+1] - 1` (inclusive). + * + * When `specifiedPartitionStartIndices` is not defined, there will be + * `dependency.partitioner.numPartitions` post-shuffle partitions. For this case, + * a post-shuffle partition is created for every pre-shuffle partition. */ class ShuffledRowRDD( - @transient var prev: RDD[Product2[Int, InternalRow]], - serializer: Serializer, - numPartitions: Int) - extends RDD[InternalRow](prev.context, Nil) { + var dependency: ShuffleDependency[Int, InternalRow, InternalRow], + specifiedPartitionStartIndices: Option[Array[Int]] = None) + extends RDD[InternalRow](dependency.rdd.context, Nil) { - private val part: Partitioner = new PartitionIdPassthrough(numPartitions) + private[this] val numPreShufflePartitions = dependency.partitioner.numPartitions - override def getDependencies: Seq[Dependency[_]] = { - List(new ShuffleDependency[Int, InternalRow, InternalRow](prev, part, Some(serializer))) + private[this] val partitionStartIndices: Array[Int] = specifiedPartitionStartIndices match { + case Some(indices) => indices + case None => + // When specifiedPartitionStartIndices is not defined, every post-shuffle partition + // corresponds to a pre-shuffle partition. 
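        // (In this default case the CoalescedPartitioner built below degenerates to an identity
        // mapping: start indices 0, 1, ..., numPreShufflePartitions - 1, i.e. exactly one
        // post-shuffle partition per pre-shuffle partition.)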
+ (0 until numPreShufflePartitions).toArray } - override val partitioner = Some(part) + private[this] val part: Partitioner = + new CoalescedPartitioner(dependency.partitioner, partitionStartIndices) + + override def getDependencies: Seq[Dependency[_]] = List(dependency) + + override val partitioner: Option[Partitioner] = Some(part) override def getPartitions: Array[Partition] = { - Array.tabulate[Partition](part.numPartitions)(i => new ShuffledRowRDDPartition(i)) + assert(partitionStartIndices.length == part.numPartitions) + Array.tabulate[Partition](partitionStartIndices.length) { i => + val startIndex = partitionStartIndices(i) + val endIndex = + if (i < partitionStartIndices.length - 1) { + partitionStartIndices(i + 1) + } else { + numPreShufflePartitions + } + new ShuffledRowRDDPartition(i, startIndex, endIndex) + } + } + + override def getPreferredLocations(partition: Partition): Seq[String] = { + val tracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] + val dep = dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + tracker.getPreferredLocationsForShuffle(dep, partition.index) } override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { - val dep = dependencies.head.asInstanceOf[ShuffleDependency[Int, InternalRow, InternalRow]] - SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context) - .read() - .asInstanceOf[Iterator[Product2[Int, InternalRow]]] - .map(_._2) + val shuffledRowPartition = split.asInstanceOf[ShuffledRowRDDPartition] + // The range of pre-shuffle partitions that we are fetching at here is + // [startPreShufflePartitionIndex, endPreShufflePartitionIndex - 1]. + val reader = + SparkEnv.get.shuffleManager.getReader( + dependency.shuffleHandle, + shuffledRowPartition.startPreShufflePartitionIndex, + shuffledRowPartition.endPreShufflePartitionIndex, + context) + reader.read().asInstanceOf[Iterator[Product2[Int, InternalRow]]].map(_._2) } override def clearDependencies() { super.clearDependencies() - prev = null + dependency = null } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala new file mode 100644 index 000000000000..24207cb46fd2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.{InternalAccumulator, SparkEnv, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{Distribution, OrderedDistribution, UnspecifiedDistribution} +import org.apache.spark.sql.execution.metric.SQLMetrics + +/** + * Performs (external) sorting. + * + * @param global when true performs a global sort of all partitions by shuffling the data first + * if necessary. + * @param testSpillFrequency Method for configuring periodic spilling in unit tests. If set, will + * spill every `frequency` records. + */ +case class Sort( + sortOrder: Seq[SortOrder], + global: Boolean, + child: SparkPlan, + testSpillFrequency: Int = 0) + extends UnaryNode { + + override def outputsUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = false + + override def output: Seq[Attribute] = child.output + + override def outputOrdering: Seq[SortOrder] = sortOrder + + override def requiredChildDistribution: Seq[Distribution] = + if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil + + override private[sql] lazy val metrics = Map( + "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"), + "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size")) + + protected override def doExecute(): RDD[InternalRow] = { + val schema = child.schema + val childOutput = child.output + + val dataSize = longMetric("dataSize") + val spillSize = longMetric("spillSize") + + child.execute().mapPartitionsInternal { iter => + val ordering = newOrdering(sortOrder, childOutput) + + // The comparator for comparing prefix + val boundSortExpression = BindReferences.bindReference(sortOrder.head, childOutput) + val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression) + + // The generator for prefix + val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression))) + val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer { + override def computePrefix(row: InternalRow): Long = { + prefixProjection.apply(row).getLong(0) + } + } + + val pageSize = SparkEnv.get.memoryManager.pageSizeBytes + val sorter = new UnsafeExternalRowSorter( + schema, ordering, prefixComparator, prefixComputer, pageSize) + if (testSpillFrequency > 0) { + sorter.setTestSpillFrequency(testSpillFrequency) + } + + // Remember spill data size of this task before execute this operator so that we can + // figure out how many bytes we spilled for this operator. 
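      // (taskMetrics().memoryBytesSpilled accumulates across all operators in the task, so the
      // per-operator spill size reported below is a delta: the value observed after sorting minus
      // the snapshot taken here.)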
+ val spillSizeBefore = TaskContext.get().taskMetrics().memoryBytesSpilled + + val sortedIterator = sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]]) + + dataSize += sorter.getPeakMemoryUsage + spillSize += TaskContext.get().taskMetrics().memoryBytesSpilled - spillSizeBefore + + TaskContext.get().internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.getPeakMemoryUsage) + sortedIterator + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala index 49adf215379c..e17b50edc62d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala @@ -38,6 +38,8 @@ object SortPrefixUtils { sortOrder.dataType match { case StringType => if (sortOrder.isAscending) PrefixComparators.STRING else PrefixComparators.STRING_DESC + case BinaryType => + if (sortOrder.isAscending) PrefixComparators.BINARY else PrefixComparators.BINARY_DESC case BooleanType | ByteType | ShortType | IntegerType | LongType | DateType | TimestampType => if (sortOrder.isAscending) PrefixComparators.LONG else PrefixComparators.LONG_DESC case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 50c27def8ea5..ec98f8104134 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -17,28 +17,24 @@ package org.apache.spark.sql.execution +import java.util.concurrent.atomic.AtomicBoolean + import scala.collection.mutable.ArrayBuffer import org.apache.spark.Logging -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.{RDD, RDDOperationScope} -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical._ - -object SparkPlan { - protected[sql] val currentContext = new ThreadLocal[SQLContext]() -} +import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetric} +import org.apache.spark.sql.types.DataType /** - * :: DeveloperApi :: + * The base class for physical operators. */ -@DeveloperApi abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable { /** @@ -47,24 +43,46 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ * populated by the query planning infrastructure. */ @transient - protected[spark] final val sqlContext = SparkPlan.currentContext.get() + protected[spark] final val sqlContext = SQLContext.getActive().getOrElse(null) protected def sparkContext = sqlContext.sparkContext // sqlContext will be null when we are being deserialized on the slaves. In this instance - // the value of codegenEnabled will be set by the desserializer after the constructor has run. 
- val codegenEnabled: Boolean = if (sqlContext != null) { - sqlContext.conf.codegenEnabled + // the value of subexpressionEliminationEnabled will be set by the desserializer after the + // constructor has run. + val subexpressionEliminationEnabled: Boolean = if (sqlContext != null) { + sqlContext.conf.subexpressionEliminationEnabled } else { false } + /** + * Whether the "prepare" method is called. + */ + private val prepareCalled = new AtomicBoolean(false) + /** Overridden make copy also propogates sqlContext to copied plan. */ - override def makeCopy(newArgs: Array[AnyRef]): this.type = { - SparkPlan.currentContext.set(sqlContext) + override def makeCopy(newArgs: Array[AnyRef]): SparkPlan = { + SQLContext.setActive(sqlContext) super.makeCopy(newArgs) } + /** + * Return all metadata that describes more details of this SparkPlan. + */ + private[sql] def metadata: Map[String, String] = Map.empty + + /** + * Return all metrics containing metrics of this SparkPlan. + */ + private[sql] def metrics: Map[String, SQLMetric[_, _]] = Map.empty + + /** + * Return a LongSQLMetric according to the name. + */ + private[sql] def longMetric(name: String): LongSQLMetric = + metrics(name).asInstanceOf[LongSQLMetric] + // TODO: Move to `DistributedPlan` /** Specifies how data is partitioned across different nodes in the cluster. */ def outputPartitioning: Partitioning = UnknownPartitioning(0) // TODO: WRONG WIDTH! @@ -110,10 +128,31 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ "Operator will receive unsafe rows as input but cannot process unsafe rows") } RDDOperationScope.withScope(sparkContext, nodeName, false, true) { + prepare() doExecute() } } + /** + * Prepare a SparkPlan for execution. It's idempotent. + */ + final def prepare(): Unit = { + if (prepareCalled.compareAndSet(false, true)) { + doPrepare() + children.foreach(_.prepare()) + } + } + + /** + * Overridden by concrete implementations of SparkPlan. It is guaranteed to run before any + * `execute` of SparkPlan. This is helpful if we want to set up some state before executing the + * query, e.g., `BroadcastHashJoin` uses it to broadcast asynchronously. + * + * Note: the prepare method has already walked down the tree, so the implementation doesn't need + * to call children's prepare methods. + */ + protected def doPrepare(): Unit = {} + /** * Overridden by concrete implementations of SparkPlan. * Produces the result of the query as an RDD[InternalRow] @@ -123,11 +162,16 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** * Runs this query returning the result as an array. */ - def executeCollect(): Array[Row] = { - execute().mapPartitions { iter => - val converter = CatalystTypeConverters.createToScalaConverter(schema) - iter.map(converter(_).asInstanceOf[Row]) - }.collect() + def executeCollect(): Array[InternalRow] = { + execute().map(_.copy()).collect() + } + + /** + * Runs this query returning the result as an array, using external Row format. + */ + def executeCollectPublic(): Array[Row] = { + val converter = CatalystTypeConverters.createToScalaConverter(schema) + executeCollect().map(converter(_).asInstanceOf[Row]) } /** @@ -135,9 +179,9 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ * * This is modeled after RDD.take but never runs any job locally on the driver. 
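// Result-format summary for the collect/take methods above (illustrative, not part of this
// patch), assuming `plan` is any SparkPlan:
//
//   val internal: Array[InternalRow] = plan.executeCollect()        // internal rows, no conversion
//   val public:   Array[Row]         = plan.executeCollectPublic()  // converted via the schema
//   val firstTen: Array[InternalRow] = plan.executeTake(10)         // never runs a job on the driver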
*/ - def executeTake(n: Int): Array[Row] = { + def executeTake(n: Int): Array[InternalRow] = { if (n == 0) { - return new Array[Row](0) + return new Array[InternalRow](0) } val childRDD = execute().map(_.copy()) @@ -171,92 +215,65 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ partsScanned += numPartsToTry } - val converter = CatalystTypeConverters.createToScalaConverter(schema) - buf.toArray.map(converter(_).asInstanceOf[Row]) + buf.toArray } private[this] def isTesting: Boolean = sys.props.contains("spark.testing") - protected def newProjection( - expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = { - log.debug( - s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") - if (codegenEnabled) { - try { - GenerateProjection.generate(expressions, inputSchema) - } catch { - case e: Exception => - if (isTesting) { - throw e - } else { - log.error("Failed to generate projection, fallback to interpret", e) - new InterpretedProjection(expressions, inputSchema) - } - } - } else { - new InterpretedProjection(expressions, inputSchema) - } - } - protected def newMutableProjection( - expressions: Seq[Expression], - inputSchema: Seq[Attribute]): () => MutableProjection = { - log.debug( - s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") - if(codegenEnabled) { - try { - GenerateMutableProjection.generate(expressions, inputSchema) - } catch { - case e: Exception => - if (isTesting) { - throw e - } else { - log.error("Failed to generate mutable projection, fallback to interpreted", e) - () => new InterpretedMutableProjection(expressions, inputSchema) - } - } - } else { - () => new InterpretedMutableProjection(expressions, inputSchema) + expressions: Seq[Expression], inputSchema: Seq[Attribute]): () => MutableProjection = { + log.debug(s"Creating MutableProj: $expressions, inputSchema: $inputSchema") + try { + GenerateMutableProjection.generate(expressions, inputSchema) + } catch { + case e: Exception => + if (isTesting) { + throw e + } else { + log.error("Failed to generate mutable projection, fallback to interpreted", e) + () => new InterpretedMutableProjection(expressions, inputSchema) + } } } protected def newPredicate( expression: Expression, inputSchema: Seq[Attribute]): (InternalRow) => Boolean = { - if (codegenEnabled) { - try { - GeneratePredicate.generate(expression, inputSchema) - } catch { - case e: Exception => - if (isTesting) { - throw e - } else { - log.error("Failed to generate predicate, fallback to interpreted", e) - InterpretedPredicate.create(expression, inputSchema) - } - } - } else { - InterpretedPredicate.create(expression, inputSchema) + try { + GeneratePredicate.generate(expression, inputSchema) + } catch { + case e: Exception => + if (isTesting) { + throw e + } else { + log.error("Failed to generate predicate, fallback to interpreted", e) + InterpretedPredicate.create(expression, inputSchema) + } } } protected def newOrdering( - order: Seq[SortOrder], - inputSchema: Seq[Attribute]): Ordering[InternalRow] = { - if (codegenEnabled) { - try { - GenerateOrdering.generate(order, inputSchema) - } catch { - case e: Exception => - if (isTesting) { - throw e - } else { - log.error("Failed to generate ordering, fallback to interpreted", e) - new RowOrdering(order, inputSchema) - } - } - } else { - new RowOrdering(order, inputSchema) + order: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[InternalRow] = { + try { + GenerateOrdering.generate(order, 
inputSchema) + } catch { + case e: Exception => + if (isTesting) { + throw e + } else { + log.error("Failed to generate ordering, fallback to interpreted", e) + new InterpretedOrdering(order, inputSchema) + } + } + } + + /** + * Creates a row ordering for the given schema, in natural ascending order. + */ + protected def newNaturalAscendingOrdering(dataTypes: Seq[DataType]): Ordering[InternalRow] = { + val order: Seq[SortOrder] = dataTypes.zipWithIndex.map { + case (dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending) } + newOrdering(order, Seq.empty) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala new file mode 100644 index 000000000000..4f750ad13ab8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.execution.metric.SQLMetricInfo +import org.apache.spark.util.Utils + +/** + * :: DeveloperApi :: + * Stores information about a SQL SparkPlan. + */ +@DeveloperApi +class SparkPlanInfo( + val nodeName: String, + val simpleString: String, + val children: Seq[SparkPlanInfo], + val metadata: Map[String, String], + val metrics: Seq[SQLMetricInfo]) + +private[sql] object SparkPlanInfo { + + def fromSparkPlan(plan: SparkPlan): SparkPlanInfo = { + val metrics = plan.metrics.toSeq.map { case (key, metric) => + new SQLMetricInfo(metric.name.getOrElse(key), metric.id, + Utils.getFormattedClassName(metric.param)) + } + val children = plan.children.map(fromSparkPlan) + + new SparkPlanInfo(plan.nodeName, plan.simpleString, children, plan.metadata, metrics) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala new file mode 100644 index 000000000000..6e9a4df82824 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.SparkContext +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.datasources.DataSourceStrategy + +class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies { + val sparkContext: SparkContext = sqlContext.sparkContext + + def numPartitions: Int = sqlContext.conf.numShufflePartitions + + def strategies: Seq[Strategy] = + sqlContext.experimental.extraStrategies ++ ( + DataSourceStrategy :: + DDLStrategy :: + TakeOrderedAndProject :: + Aggregation :: + LeftSemiJoin :: + EquiJoinSelection :: + InMemoryScans :: + BasicOperators :: + BroadcastNestedLoop :: + CartesianProduct :: + DefaultJoin :: Nil) + + /** + * Used to build table scan operators where complex projection and filtering are done using + * separate physical operators. This function returns the given scan operator with Project and + * Filter nodes added only when needed. For example, a Project operator is only used when the + * final desired output requires complex expressions to be evaluated or when columns can be + * further eliminated out after filtering has been done. + * + * The `prunePushedDownFilters` parameter is used to remove those filters that can be optimized + * away by the filter pushdown optimization. + * + * The required attributes for both filtering and expression evaluation are passed to the + * provided `scanBuilder` function so that it can avoid unnecessary column materialization. + */ + def pruneFilterProject( + projectList: Seq[NamedExpression], + filterPredicates: Seq[Expression], + prunePushedDownFilters: Seq[Expression] => Seq[Expression], + scanBuilder: Seq[Attribute] => SparkPlan): SparkPlan = { + + val projectSet = AttributeSet(projectList.flatMap(_.references)) + val filterSet = AttributeSet(filterPredicates.flatMap(_.references)) + val filterCondition: Option[Expression] = + prunePushedDownFilters(filterPredicates).reduceLeftOption(catalyst.expressions.And) + + // Right now we still use a projection even if the only evaluation is applying an alias + // to a column. Since this is a no-op, it could be avoided. However, using this + // optimization with the current implementation would change the output schema. + // TODO: Decouple final output schema from expression evaluation so this copy can be + // avoided safely. + + if (AttributeSet(projectList.map(_.toAttribute)) == projectSet && + filterSet.subsetOf(projectSet)) { + // When it is possible to just use column pruning to get the right projection and + // when the columns of this projection are enough to evaluate all filter conditions, + // just do a scan followed by a filter, with no extra project. 
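      // (Illustration, not part of this patch: for a query like
      //    SELECT a, b FROM t WHERE a > 1
      // the project list is only column references, so this branch emits Scan(a, b) plus a Filter
      // with no Project. For SELECT a + b AS s FROM t WHERE a > 1 the projection evaluates an
      // expression, so the else-branch below wraps the filtered scan in an explicit Project.)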
+ val scan = scanBuilder(projectList.asInstanceOf[Seq[Attribute]]) + filterCondition.map(Filter(_, scan)).getOrElse(scan) + } else { + val scan = scanBuilder((projectSet ++ filterSet).toSeq) + Project(projectList, filterCondition.map(Filter(_, scan)).getOrElse(scan)) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala similarity index 89% rename from sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala index ea8fce6ca9cf..b3e8d0d84937 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala @@ -15,24 +15,23 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.execution import scala.util.parsing.combinator.RegexParsers import org.apache.spark.sql.catalyst.AbstractSparkSQLParser import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.catalyst.plans.logical.{DescribeFunction, LogicalPlan, ShowFunctions} -import org.apache.spark.sql.execution._ +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types.StringType - /** * The top level Spark SQL parser. This parser recognizes syntaxes that are available for all SQL * dialects supported by Spark SQL, and delegates all the other syntaxes to the `fallback` parser. * * @param fallback A function that parses an input string to a logical plan */ -private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLParser { +class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLParser { // A parser for the key-value part of the "SET [key = [value ]]" syntax private object SetCommandParser extends RegexParsers { @@ -100,14 +99,14 @@ private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends Abstr case _ ~ dbName => ShowTablesCommand(dbName) } | SHOW ~ FUNCTIONS ~> ((ident <~ ".").? ~ (ident | stringLit)).? ^^ { - case Some(f) => ShowFunctions(f._1, Some(f._2)) - case None => ShowFunctions(None, None) + case Some(f) => logical.ShowFunctions(f._1, Some(f._2)) + case None => logical.ShowFunctions(None, None) } ) private lazy val desc: Parser[LogicalPlan] = DESCRIBE ~ FUNCTION ~> EXTENDED.? 
~ (ident | stringLit) ^^ { - case isExtended ~ functionName => DescribeFunction(functionName, isExtended.isDefined) + case isExtended ~ functionName => logical.DescribeFunction(functionName, isExtended.isDefined) } private lazy val others: Parser[LogicalPlan] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index b19ad4f1c563..45a8e0324826 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -22,19 +22,17 @@ import java.util.{HashMap => JavaHashMap} import scala.reflect.ClassTag -import com.clearspring.analytics.stream.cardinality.HyperLogLog import com.esotericsoftware.kryo.io.{Input, Output} import com.esotericsoftware.kryo.{Kryo, Serializer} import com.twitter.chill.ResourcePool +import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.{KryoSerializer, SerializerInstance} -import org.apache.spark.sql.catalyst.expressions.GenericInternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{IntegerHashSet, LongHashSet} import org.apache.spark.sql.types.Decimal import org.apache.spark.util.MutablePair -import org.apache.spark.util.collection.OpenHashSet import org.apache.spark.{SparkConf, SparkEnv} + private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) { override def newKryo(): Kryo = { val kryo = super.newKryo() @@ -43,16 +41,9 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow]) kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericInternalRow]) kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericMutableRow]) - kryo.register(classOf[com.clearspring.analytics.stream.cardinality.HyperLogLog], - new HyperLogLogSerializer) kryo.register(classOf[java.math.BigDecimal], new JavaBigDecimalSerializer) kryo.register(classOf[BigDecimal], new ScalaBigDecimalSerializer) - // Specific hashsets must come first TODO: Move to core. 
- kryo.register(classOf[IntegerHashSet], new IntegerHashSetSerializer) - kryo.register(classOf[LongHashSet], new LongHashSetSerializer) - kryo.register(classOf[org.apache.spark.util.collection.OpenHashSet[_]], - new OpenHashSetSerializer) kryo.register(classOf[Decimal]) kryo.register(classOf[JavaHashMap[_, _]]) @@ -62,7 +53,7 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co } private[execution] class KryoResourcePool(size: Int) - extends ResourcePool[SerializerInstance](size) { + extends ResourcePool[SerializerInstance](size) { val ser: SparkSqlSerializer = { val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) @@ -86,7 +77,7 @@ private[sql] object SparkSqlSerializer { def serialize[T: ClassTag](o: T): Array[Byte] = acquireRelease { k => - k.serialize(o).array() + JavaUtils.bufferToArray(k.serialize(o)) } def deserialize[T: ClassTag](bytes: Array[Byte]): T = @@ -116,92 +107,3 @@ private[sql] class ScalaBigDecimalSerializer extends Serializer[BigDecimal] { new java.math.BigDecimal(input.readString()) } } - -private[sql] class HyperLogLogSerializer extends Serializer[HyperLogLog] { - def write(kryo: Kryo, output: Output, hyperLogLog: HyperLogLog) { - val bytes = hyperLogLog.getBytes() - output.writeInt(bytes.length) - output.writeBytes(bytes) - } - - def read(kryo: Kryo, input: Input, tpe: Class[HyperLogLog]): HyperLogLog = { - val length = input.readInt() - val bytes = input.readBytes(length) - HyperLogLog.Builder.build(bytes) - } -} - -private[sql] class OpenHashSetSerializer extends Serializer[OpenHashSet[_]] { - def write(kryo: Kryo, output: Output, hs: OpenHashSet[_]) { - val rowSerializer = kryo.getDefaultSerializer(classOf[Array[Any]]).asInstanceOf[Serializer[Any]] - output.writeInt(hs.size) - val iterator = hs.iterator - while(iterator.hasNext) { - val row = iterator.next() - rowSerializer.write(kryo, output, row.asInstanceOf[GenericInternalRow].values) - } - } - - def read(kryo: Kryo, input: Input, tpe: Class[OpenHashSet[_]]): OpenHashSet[_] = { - val rowSerializer = kryo.getDefaultSerializer(classOf[Array[Any]]).asInstanceOf[Serializer[Any]] - val numItems = input.readInt() - val set = new OpenHashSet[Any](numItems + 1) - var i = 0 - while (i < numItems) { - val row = - new GenericInternalRow(rowSerializer.read( - kryo, - input, - classOf[Array[Any]].asInstanceOf[Class[Any]]).asInstanceOf[Array[Any]]) - set.add(row) - i += 1 - } - set - } -} - -private[sql] class IntegerHashSetSerializer extends Serializer[IntegerHashSet] { - def write(kryo: Kryo, output: Output, hs: IntegerHashSet) { - output.writeInt(hs.size) - val iterator = hs.iterator - while(iterator.hasNext) { - val value: Int = iterator.next() - output.writeInt(value) - } - } - - def read(kryo: Kryo, input: Input, tpe: Class[IntegerHashSet]): IntegerHashSet = { - val numItems = input.readInt() - val set = new IntegerHashSet - var i = 0 - while (i < numItems) { - val value = input.readInt() - set.add(value) - i += 1 - } - set - } -} - -private[sql] class LongHashSetSerializer extends Serializer[LongHashSet] { - def write(kryo: Kryo, output: Output, hs: LongHashSet) { - output.writeInt(hs.size) - val iterator = hs.iterator - while(iterator.hasNext) { - val value = iterator.next() - output.writeLong(value) - } - } - - def read(kryo: Kryo, input: Input, tpe: Class[LongHashSet]): LongHashSet = { - val numItems = input.readInt() - val set = new LongHashSet - var i = 0 - while (i < numItems) { - val value = input.readLong() - set.add(value) - i += 1 - } - set - } -} diff 
--git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala deleted file mode 100644 index e5bbd0aaed0a..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala +++ /dev/null @@ -1,425 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import java.io._ -import java.math.{BigDecimal, BigInteger} -import java.nio.ByteBuffer - -import scala.reflect.ClassTag - -import org.apache.spark.Logging -import org.apache.spark.serializer._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow} -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String - -/** - * The serialization stream for [[SparkSqlSerializer2]]. It assumes that the object passed in - * its `writeObject` are [[Product2]]. The serialization functions for the key and value of the - * [[Product2]] are constructed based on their schemata. - * The benefit of this serialization stream is that compared with general-purpose serializers like - * Kryo and Java serializer, it can significantly reduce the size of serialized and has a lower - * allocation cost, which can benefit the shuffle operation. Right now, its main limitations are: - * 1. It does not support complex types, i.e. Map, Array, and Struct. - * 2. It assumes that the objects passed in are [[Product2]]. So, it cannot be used when - * [[org.apache.spark.util.collection.ExternalSorter]]'s merge sort operation is used because - * the objects passed in the serializer are not in the type of [[Product2]]. Also also see - * the comment of the `serializer` method in [[Exchange]] for more information on it. - */ -private[sql] class Serializer2SerializationStream( - rowSchema: Array[DataType], - out: OutputStream) - extends SerializationStream with Logging { - - private val rowOut = new DataOutputStream(new BufferedOutputStream(out)) - private val writeRowFunc = SparkSqlSerializer2.createSerializationFunction(rowSchema, rowOut) - - override def writeObject[T: ClassTag](t: T): SerializationStream = { - val kv = t.asInstanceOf[Product2[InternalRow, InternalRow]] - writeKey(kv._1) - writeValue(kv._2) - - this - } - - override def writeKey[T: ClassTag](t: T): SerializationStream = { - // No-op. - this - } - - override def writeValue[T: ClassTag](t: T): SerializationStream = { - writeRowFunc(t.asInstanceOf[InternalRow]) - this - } - - def flush(): Unit = { - rowOut.flush() - } - - def close(): Unit = { - rowOut.close() - } -} - -/** - * The corresponding deserialization stream for [[Serializer2SerializationStream]]. 
- */ -private[sql] class Serializer2DeserializationStream( - rowSchema: Array[DataType], - in: InputStream) - extends DeserializationStream with Logging { - - private val rowIn = new DataInputStream(new BufferedInputStream(in)) - - private def rowGenerator(schema: Array[DataType]): () => (MutableRow) = { - if (schema == null) { - () => null - } else { - // It is safe to reuse the mutable row. - val mutableRow = new SpecificMutableRow(schema) - () => mutableRow - } - } - - // Functions used to return rows for key and value. - private val getRow = rowGenerator(rowSchema) - // Functions used to read a serialized row from the InputStream and deserialize it. - private val readRowFunc = SparkSqlSerializer2.createDeserializationFunction(rowSchema, rowIn) - - override def readObject[T: ClassTag](): T = { - readValue() - } - - override def readKey[T: ClassTag](): T = { - null.asInstanceOf[T] // intentionally left blank. - } - - override def readValue[T: ClassTag](): T = { - readRowFunc(getRow()).asInstanceOf[T] - } - - override def close(): Unit = { - rowIn.close() - } -} - -private[sql] class SparkSqlSerializer2Instance( - rowSchema: Array[DataType]) - extends SerializerInstance { - - def serialize[T: ClassTag](t: T): ByteBuffer = - throw new UnsupportedOperationException("Not supported.") - - def deserialize[T: ClassTag](bytes: ByteBuffer): T = - throw new UnsupportedOperationException("Not supported.") - - def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = - throw new UnsupportedOperationException("Not supported.") - - def serializeStream(s: OutputStream): SerializationStream = { - new Serializer2SerializationStream(rowSchema, s) - } - - def deserializeStream(s: InputStream): DeserializationStream = { - new Serializer2DeserializationStream(rowSchema, s) - } -} - -/** - * SparkSqlSerializer2 is a special serializer that creates serialization function and - * deserialization function based on the schema of data. It assumes that values passed in - * are Rows. - */ -private[sql] class SparkSqlSerializer2(rowSchema: Array[DataType]) - extends Serializer - with Logging - with Serializable{ - - def newInstance(): SerializerInstance = new SparkSqlSerializer2Instance(rowSchema) - - override def supportsRelocationOfSerializedObjects: Boolean = { - // SparkSqlSerializer2 is stateless and writes no stream headers - true - } -} - -private[sql] object SparkSqlSerializer2 { - - final val NULL = 0 - final val NOT_NULL = 1 - - /** - * Check if rows with the given schema can be serialized with ShuffleSerializer. - * Right now, we do not support a schema having complex types or UDTs, or all data types - * of fields are NullTypes. - */ - def support(schema: Array[DataType]): Boolean = { - if (schema == null) return true - - var allNullTypes = true - var i = 0 - while (i < schema.length) { - schema(i) match { - case NullType => // Do nothing - case udt: UserDefinedType[_] => - allNullTypes = false - return false - case array: ArrayType => - allNullTypes = false - return false - case map: MapType => - allNullTypes = false - return false - case struct: StructType => - allNullTypes = false - return false - case _ => - allNullTypes = false - } - i += 1 - } - - // If types of fields are all NullTypes, we return false. - // Otherwise, we return true. - return !allNullTypes - } - - /** - * The util function to create the serialization function based on the given schema. 
- */ - def createSerializationFunction(schema: Array[DataType], out: DataOutputStream) - : InternalRow => Unit = { - (row: InternalRow) => - // If the schema is null, the returned function does nothing when it get called. - if (schema != null) { - var i = 0 - while (i < schema.length) { - schema(i) match { - // When we write values to the underlying stream, we also first write the null byte - // first. Then, if the value is not null, we write the contents out. - - case NullType => // Write nothing. - - case BooleanType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - out.writeBoolean(row.getBoolean(i)) - } - - case ByteType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - out.writeByte(row.getByte(i)) - } - - case ShortType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - out.writeShort(row.getShort(i)) - } - - case IntegerType | DateType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - out.writeInt(row.getInt(i)) - } - - case LongType | TimestampType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - out.writeLong(row.getLong(i)) - } - - case FloatType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - out.writeFloat(row.getFloat(i)) - } - - case DoubleType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - out.writeDouble(row.getDouble(i)) - } - - case StringType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - val bytes = row.getUTF8String(i).getBytes - out.writeInt(bytes.length) - out.write(bytes) - } - - case BinaryType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - val bytes = row.getBinary(i) - out.writeInt(bytes.length) - out.write(bytes) - } - - case decimal: DecimalType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - val value = row.getDecimal(i, decimal.precision, decimal.scale) - val javaBigDecimal = value.toJavaBigDecimal - // First, write out the unscaled value. - val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray - out.writeInt(bytes.length) - out.write(bytes) - // Then, write out the scale. - out.writeInt(javaBigDecimal.scale()) - } - } - i += 1 - } - } - } - - /** - * The util function to create the deserialization function based on the given schema. - */ - def createDeserializationFunction( - schema: Array[DataType], - in: DataInputStream): (MutableRow) => InternalRow = { - if (schema == null) { - (mutableRow: MutableRow) => null - } else { - (mutableRow: MutableRow) => { - var i = 0 - while (i < schema.length) { - schema(i) match { - // When we read values from the underlying stream, we also first read the null byte - // first. Then, if the value is not null, we update the field of the mutable row. - - case NullType => mutableRow.setNullAt(i) // Read nothing. 
- - case BooleanType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - mutableRow.setBoolean(i, in.readBoolean()) - } - - case ByteType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - mutableRow.setByte(i, in.readByte()) - } - - case ShortType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - mutableRow.setShort(i, in.readShort()) - } - - case IntegerType | DateType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - mutableRow.setInt(i, in.readInt()) - } - - case LongType | TimestampType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - mutableRow.setLong(i, in.readLong()) - } - - case FloatType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - mutableRow.setFloat(i, in.readFloat()) - } - - case DoubleType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - mutableRow.setDouble(i, in.readDouble()) - } - - case StringType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - val length = in.readInt() - val bytes = new Array[Byte](length) - in.readFully(bytes) - mutableRow.update(i, UTF8String.fromBytes(bytes)) - } - - case BinaryType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - val length = in.readInt() - val bytes = new Array[Byte](length) - in.readFully(bytes) - mutableRow.update(i, bytes) - } - - case decimal: DecimalType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - // First, read in the unscaled value. - val length = in.readInt() - val bytes = new Array[Byte](length) - in.readFully(bytes) - val unscaledVal = new BigInteger(bytes) - // Then, read the scale. - val scale = in.readInt() - // Finally, create the Decimal object and set it in the row. 
- mutableRow.update(i, Decimal(new BigDecimal(unscaledVal, scale))) - } - } - i += 1 - } - - mutableRow - } - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 952ba7d45c13..688555cf136e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -19,19 +19,18 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, Utils} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan} import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} +import org.apache.spark.sql.execution.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _} import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} -import org.apache.spark.sql.types._ -import org.apache.spark.sql.{SQLContext, Strategy, execution} +import org.apache.spark.sql.{Strategy, execution} private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { - self: SQLContext#SparkPlanner => + self: SparkPlanner => object LeftSemiJoin extends Strategy with PredicateHelper { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { @@ -63,19 +62,20 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } /** - * Uses the ExtractEquiJoinKeys pattern to find joins where at least some of the predicates can be - * evaluated by matching hash keys. + * Uses the [[ExtractEquiJoinKeys]] pattern to find joins where at least some of the predicates + * can be evaluated by matching join keys. * - * This strategy applies a simple optimization based on the estimates of the physical sizes of - * the two join sides. When planning a [[joins.BroadcastHashJoin]], if one side has an - * estimated physical size smaller than the user-settable threshold - * [[org.apache.spark.sql.SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]], the planner would mark it as the - * ''build'' relation and mark the other relation as the ''stream'' side. The build table will be - * ''broadcasted'' to all of the executors involved in the join, as a - * [[org.apache.spark.broadcast.Broadcast]] object. If both estimates exceed the threshold, they - * will instead be used to decide the build side in a [[joins.ShuffledHashJoin]]. + * Join implementations are chosen with the following precedence: + * + * - Broadcast: if one side of the join has an estimated physical size that is smaller than the + * user-configurable [[org.apache.spark.sql.SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold + * or if that side has an explicit broadcast hint (e.g. the user applied the + * [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame), then that side + * of the join will be broadcasted and the other side will be streamed, with no shuffling + * performed. 
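
For reference, a minimal user-level sketch of the two knobs this new doc comment describes: the size-based auto-broadcast threshold and the explicit `broadcast()` hint. This is not part of the patch; the local SparkContext setup and the DataFrames `large` and `small` are illustrative assumptions only.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.functions.broadcast

// Hypothetical setup: any SQLContext and two DataFrames will do.
val sc = new SparkContext("local[2]", "broadcast-hint-sketch")
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

val large = sc.parallelize(1 to 1000000).map(i => (i, s"v$i")).toDF("key", "value")
val small = sc.parallelize(1 to 100).map(i => (i, i * 2)).toDF("key", "doubled")

// Size-based threshold (bytes): a side whose estimated physical size falls below it
// is eligible for automatic broadcasting. 10 MB is just an illustrative value.
sqlContext.setConf("spark.sql.autoBroadcastJoinThreshold", (10 * 1024 * 1024).toString)

// Explicit hint: regardless of statistics, `small` becomes the broadcast side and
// `large` is streamed with no shuffle.
val joined = large.join(broadcast(small), "key")
joined.explain()
```
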
If both sides of the join are eligible to be broadcasted then the + * - Sort merge: if the matching join keys are sortable. */ - object HashJoin extends Strategy with PredicateHelper { + object EquiJoinSelection extends Strategy with PredicateHelper { private[this] def makeBroadcastHashJoin( leftKeys: Seq[Expression], @@ -83,125 +83,49 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { left: LogicalPlan, right: LogicalPlan, condition: Option[Expression], - side: joins.BuildSide) = { + side: joins.BuildSide): Seq[SparkPlan] = { val broadcastHashJoin = execution.joins.BroadcastHashJoin( leftKeys, rightKeys, side, planLater(left), planLater(right)) condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil } def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + + // --- Inner joins -------------------------------------------------------------------------- + case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight) case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) => makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft) - // If the sort merge join option is set, we want to use sort merge join prior to hashjoin - // for now let's support inner join first, then add outer join case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) - if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) => + if RowOrdering.isOrderable(leftKeys) => val mergeJoin = joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right)) condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil - case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) => - val buildSide = - if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { - joins.BuildRight - } else { - joins.BuildLeft - } - val hashJoin = joins.ShuffledHashJoin( - leftKeys, rightKeys, buildSide, planLater(left), planLater(right)) - condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil + // --- Outer joins -------------------------------------------------------------------------- case ExtractEquiJoinKeys( - LeftOuter, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => + LeftOuter, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => joins.BroadcastHashOuterJoin( leftKeys, rightKeys, LeftOuter, condition, planLater(left), planLater(right)) :: Nil case ExtractEquiJoinKeys( - RightOuter, leftKeys, rightKeys, condition, CanBroadcast(left), right) => + RightOuter, leftKeys, rightKeys, condition, CanBroadcast(left), right) => joins.BroadcastHashOuterJoin( leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) => - joins.ShuffledHashOuterJoin( + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) + if RowOrdering.isOrderable(leftKeys) => + joins.SortMergeOuterJoin( leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil - case _ => Nil - } - } - - object HashAggregation extends Strategy { - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - // Aggregations that can be performed in two phases, before and after the shuffle. - - // Cases where all aggregates can be codegened. 
- case PartialAggregation( - namedGroupingAttributes, - rewrittenAggregateExpressions, - groupingExpressions, - partialComputation, - child) - if canBeCodeGened( - allAggregates(partialComputation) ++ - allAggregates(rewrittenAggregateExpressions)) && - codegenEnabled && - !canBeConvertedToNewAggregation(plan) => - execution.GeneratedAggregate( - partial = false, - namedGroupingAttributes, - rewrittenAggregateExpressions, - unsafeEnabled, - execution.GeneratedAggregate( - partial = true, - groupingExpressions, - partialComputation, - unsafeEnabled, - planLater(child))) :: Nil - - // Cases where some aggregate can not be codegened - case PartialAggregation( - namedGroupingAttributes, - rewrittenAggregateExpressions, - groupingExpressions, - partialComputation, - child) if !canBeConvertedToNewAggregation(plan) => - execution.Aggregate( - partial = false, - namedGroupingAttributes, - rewrittenAggregateExpressions, - execution.Aggregate( - partial = true, - groupingExpressions, - partialComputation, - planLater(child))) :: Nil + // --- Cases where this strategy does not apply --------------------------------------------- case _ => Nil } - - def canBeConvertedToNewAggregation(plan: LogicalPlan): Boolean = plan match { - case a: logical.Aggregate => - if (sqlContext.conf.useSqlAggregate2 && sqlContext.conf.codegenEnabled) { - a.newAggregation.isDefined - } else { - Utils.checkInvalidAggregateFunction2(a) - false - } - case _ => false - } - - def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = aggs.forall { - case _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => true - // The generated set implementation is pretty limited ATM. - case CollectHashSet(exprs) if exprs.size == 1 && - Seq(IntegerType, LongType).contains(exprs.head.dataType) => true - case _ => false - } - - def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression1] = - exprs.flatMap(_.collect { case a: AggregateExpression1 => a }) } /** @@ -209,80 +133,122 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { */ object Aggregation extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case p: logical.Aggregate if sqlContext.conf.useSqlAggregate2 && - sqlContext.conf.codegenEnabled => - val converted = p.newAggregation - converted match { - case None => Nil // Cannot convert to new aggregation code path. - case Some(logical.Aggregate(groupingExpressions, resultExpressions, child)) => - // Extracts all distinct aggregate expressions from the resultExpressions. - val aggregateExpressions = resultExpressions.flatMap { expr => - expr.collect { - case agg: AggregateExpression2 => agg - } - }.toSet.toSeq - // For those distinct aggregate expressions, we create a map from the - // aggregate function to the corresponding attribute of the function. - val aggregateFunctionMap = aggregateExpressions.map { agg => - val aggregateFunction = agg.aggregateFunction - (aggregateFunction, agg.isDistinct) -> - Alias(aggregateFunction, aggregateFunction.toString)().toAttribute - }.toMap - - val (functionsWithDistinct, functionsWithoutDistinct) = - aggregateExpressions.partition(_.isDistinct) - if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { - // This is a sanity check. We should not reach here when we have multiple distinct - // column sets (aggregate.NewAggregation will not match). 
- sys.error( - "Multiple distinct column sets are not supported by the new aggregation" + - "code path.") - } + case logical.Aggregate(groupingExpressions, resultExpressions, child) => + // A single aggregate expression might appear multiple times in resultExpressions. + // In order to avoid evaluating an individual aggregate function multiple times, we'll + // build a set of the distinct aggregate expressions and build a function which can + // be used to re-write expressions so that they reference the single copy of the + // aggregate function which actually gets computed. + val aggregateExpressions = resultExpressions.flatMap { expr => + expr.collect { + case agg: AggregateExpression => agg + } + }.distinct + // For those distinct aggregate expressions, we create a map from the + // aggregate function to the corresponding attribute of the function. + val aggregateFunctionToAttribute = aggregateExpressions.map { agg => + val aggregateFunction = agg.aggregateFunction + val attribute = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute + (aggregateFunction, agg.isDistinct) -> attribute + }.toMap + + val (functionsWithDistinct, functionsWithoutDistinct) = + aggregateExpressions.partition(_.isDistinct) + if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { + // This is a sanity check. We should not reach here when we have multiple distinct + // column sets. Our MultipleDistinctRewriter should take care this case. + sys.error("You hit a query analyzer bug. Please report your query to " + + "Spark user mailing list.") + } - val aggregateOperator = - if (functionsWithDistinct.isEmpty) { - aggregate.Utils.planAggregateWithoutDistinct( - groupingExpressions, - aggregateExpressions, - aggregateFunctionMap, - resultExpressions, - planLater(child)) - } else { - aggregate.Utils.planAggregateWithOneDistinct( - groupingExpressions, - functionsWithDistinct, - functionsWithoutDistinct, - aggregateFunctionMap, - resultExpressions, - planLater(child)) - } - - aggregateOperator + val namedGroupingExpressions = groupingExpressions.map { + case ne: NamedExpression => ne -> ne + // If the expression is not a NamedExpressions, we add an alias. + // So, when we generate the result of the operator, the Aggregate Operator + // can directly get the Seq of attributes representing the grouping expressions. + case other => + val withAlias = Alias(other, other.toString)() + other -> withAlias } + val groupExpressionMap = namedGroupingExpressions.toMap + + // The original `resultExpressions` are a set of expressions which may reference + // aggregate expressions, grouping column values, and constants. When aggregate operator + // emits output rows, we will use `resultExpressions` to generate an output projection + // which takes the grouping columns and final aggregate result buffer as input. 
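
The deduplicate-then-rewrite step described in this comment can be illustrated with a small, self-contained toy in plain Scala (not Catalyst types): distinct "aggregate" sub-terms are collected once, each is assigned an output slot, and the result expressions are rewritten to reference that single computed copy. Note the real planner additionally uses `semanticEquals` for grouping expressions, whereas the toy relies on plain case-class equality.

```scala
sealed trait Expr
case class Agg(name: String, col: String) extends Expr   // stands in for e.g. sum(v)
case class Ref(slot: Int) extends Expr                    // attribute of a computed aggregate
case class Add(l: Expr, r: Expr) extends Expr
case class Lit(v: Int) extends Expr

def dedupAndRewrite(results: Seq[Expr]): (Seq[Agg], Seq[Expr]) = {
  // 1. Collect every aggregate occurrence, keeping only distinct ones.
  def collect(e: Expr): Seq[Agg] = e match {
    case a: Agg    => Seq(a)
    case Add(l, r) => collect(l) ++ collect(r)
    case _         => Nil
  }
  val distinctAggs = results.flatMap(collect).distinct
  val aggToSlot = distinctAggs.zipWithIndex.toMap

  // 2. Rewrite each result expression so aggregate occurrences point at the slot
  //    that will actually be computed by the aggregate operator.
  def replace(e: Expr): Expr = e match {
    case a: Agg    => Ref(aggToSlot(a))
    case Add(l, r) => Add(replace(l), replace(r))
    case other     => other
  }
  (distinctAggs, results.map(replace))
}

val (toCompute, rewritten) = dedupAndRewrite(Seq(
  Add(Agg("sum", "v"), Lit(1)),             // sum(v) + 1
  Add(Agg("sum", "v"), Agg("sum", "v"))))   // sum(v) + sum(v)
// toCompute == Seq(Agg("sum","v")): the aggregate is planned exactly once,
// while the rewritten results reference slot 0 three times.
```
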
+ // Thus, we must re-write the result expressions so that their attributes match up with + // the attributes of the final result projection's input row: + val rewrittenResultExpressions = resultExpressions.map { expr => + expr.transformDown { + case AggregateExpression(aggregateFunction, _, isDistinct) => + // The final aggregation buffer's attributes will be `finalAggregationAttributes`, + // so replace each aggregate expression by its corresponding attribute in the set: + aggregateFunctionToAttribute(aggregateFunction, isDistinct) + case expression => + // Since we're using `namedGroupingAttributes` to extract the grouping key + // columns, we need to replace grouping key expressions with their corresponding + // attributes. We do not rely on the equality check at here since attributes may + // differ cosmetically. Instead, we use semanticEquals. + groupExpressionMap.collectFirst { + case (expr, ne) if expr semanticEquals expression => ne.toAttribute + }.getOrElse(expression) + }.asInstanceOf[NamedExpression] + } + + val aggregateOperator = + if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) { + if (functionsWithDistinct.nonEmpty) { + sys.error("Distinct columns cannot exist in Aggregate operator containing " + + "aggregate functions which don't support partial aggregation.") + } else { + aggregate.Utils.planAggregateWithoutPartial( + namedGroupingExpressions.map(_._2), + aggregateExpressions, + aggregateFunctionToAttribute, + rewrittenResultExpressions, + planLater(child)) + } + } else if (functionsWithDistinct.isEmpty) { + aggregate.Utils.planAggregateWithoutDistinct( + namedGroupingExpressions.map(_._2), + aggregateExpressions, + aggregateFunctionToAttribute, + rewrittenResultExpressions, + planLater(child)) + } else { + aggregate.Utils.planAggregateWithOneDistinct( + namedGroupingExpressions.map(_._2), + functionsWithDistinct, + functionsWithoutDistinct, + aggregateFunctionToAttribute, + rewrittenResultExpressions, + planLater(child)) + } + + aggregateOperator case _ => Nil } } - - object BroadcastNestedLoopJoin extends Strategy { + object BroadcastNestedLoop extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.Join(left, right, joinType, condition) => - val buildSide = - if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { - joins.BuildRight - } else { - joins.BuildLeft - } - joins.BroadcastNestedLoopJoin( - planLater(left), planLater(right), buildSide, joinType, condition) :: Nil + case logical.Join( + CanBroadcast(left), right, joinType, condition) if joinType != LeftSemi => + execution.joins.BroadcastNestedLoopJoin( + planLater(left), planLater(right), joins.BuildLeft, joinType, condition) :: Nil + case logical.Join( + left, CanBroadcast(right), joinType, condition) if joinType != LeftSemi => + execution.joins.BroadcastNestedLoopJoin( + planLater(left), planLater(right), joins.BuildRight, joinType, condition) :: Nil case _ => Nil } } object CartesianProduct extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.Join(left, right, _, None) => + // TODO CartesianProduct doesn't support the Left Semi Join + case logical.Join(left, right, joinType, None) if joinType != LeftSemi => execution.joins.CartesianProduct(planLater(left), planLater(right)) :: Nil case logical.Join(left, right, Inner, Some(condition)) => execution.Filter(condition, @@ -291,6 +257,21 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + object DefaultJoin 
extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case logical.Join(left, right, joinType, condition) => + val buildSide = + if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { + joins.BuildRight + } else { + joins.BuildLeft + } + joins.BroadcastNestedLoopJoin( + planLater(left), planLater(right), buildSide, joinType, condition) :: Nil + case _ => Nil + } + } + protected lazy val singleRowRdd = sparkContext.parallelize(Seq(InternalRow()), 1) object TakeOrderedAndProject extends Strategy { @@ -321,62 +302,42 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object BasicOperators extends Strategy { def numPartitions: Int = self.numPartitions - /** - * Picks an appropriate sort operator. - * - * @param global when true performs a global sort of all partitions by shuffling the data first - * if necessary. - */ - def getSortOperator(sortExprs: Seq[SortOrder], global: Boolean, child: SparkPlan): SparkPlan = { - if (sqlContext.conf.unsafeEnabled && sqlContext.conf.codegenEnabled && - TungstenSort.supportsSchema(child.schema)) { - execution.TungstenSort(sortExprs, global, child) - } else if (sqlContext.conf.externalSortEnabled) { - execution.ExternalSort(sortExprs, global, child) - } else { - execution.Sort(sortExprs, global, child) - } - } - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case r: RunnableCommand => ExecutedCommand(r) :: Nil case logical.Distinct(child) => throw new IllegalStateException( "logical distinct operator should have been replaced by aggregate in the optimizer") + + case logical.MapPartitions(f, tEnc, uEnc, output, child) => + execution.MapPartitions(f, tEnc, uEnc, output, planLater(child)) :: Nil + case logical.AppendColumns(f, tEnc, uEnc, newCol, child) => + execution.AppendColumns(f, tEnc, uEnc, newCol, planLater(child)) :: Nil + case logical.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, child) => + execution.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, planLater(child)) :: Nil + case logical.CoGroup(f, kEnc, leftEnc, rightEnc, rEnc, output, + leftGroup, rightGroup, left, right) => + execution.CoGroup(f, kEnc, leftEnc, rightEnc, rEnc, output, leftGroup, rightGroup, + planLater(left), planLater(right)) :: Nil + case logical.Repartition(numPartitions, shuffle, child) => - execution.Repartition(numPartitions, shuffle, planLater(child)) :: Nil + if (shuffle) { + execution.Exchange(RoundRobinPartitioning(numPartitions), planLater(child)) :: Nil + } else { + execution.Coalesce(numPartitions, planLater(child)) :: Nil + } case logical.SortPartitions(sortExprs, child) => // This sort only sorts tuples within a partition. Its requiredDistribution will be // an UnspecifiedDistribution. - getSortOperator(sortExprs, global = false, planLater(child)) :: Nil + execution.Sort(sortExprs, global = false, child = planLater(child)) :: Nil case logical.Sort(sortExprs, global, child) => - getSortOperator(sortExprs, global, planLater(child)):: Nil + execution.Sort(sortExprs, global, planLater(child)) :: Nil case logical.Project(projectList, child) => - // If unsafe mode is enabled and we support these data types in Unsafe, use the - // Tungsten project. Otherwise, use the normal project. 
- if (sqlContext.conf.unsafeEnabled && - UnsafeProjection.canSupport(projectList) && UnsafeProjection.canSupport(child.schema)) { - execution.TungstenProject(projectList, planLater(child)) :: Nil - } else { - execution.Project(projectList, planLater(child)) :: Nil - } + execution.Project(projectList, planLater(child)) :: Nil case logical.Filter(condition, child) => execution.Filter(condition, planLater(child)) :: Nil - case e @ logical.Expand(_, _, _, child) => + case e @ logical.Expand(_, _, child) => execution.Expand(e.projections, e.output, planLater(child)) :: Nil - case a @ logical.Aggregate(group, agg, child) => { - val useNewAggregation = sqlContext.conf.useSqlAggregate2 && sqlContext.conf.codegenEnabled - if (useNewAggregation && a.newAggregation.isDefined) { - // If this logical.Aggregate can be planned to use new aggregation code path - // (i.e. it can be planned by the Strategy Aggregation), we will not use the old - // aggregation code path. - Nil - } else { - Utils.checkInvalidAggregateFunction2(a) - execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil - } - } case logical.Window(projectList, windowExprs, partitionSpec, orderSpec, child) => execution.Window( projectList, windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil @@ -396,35 +357,36 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Generate( generator, join = join, outer = outer, g.output, planLater(child)) :: Nil case logical.OneRowRelation => - execution.PhysicalRDD(Nil, singleRowRdd) :: Nil - case logical.RepartitionByExpression(expressions, child) => - execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil + execution.PhysicalRDD(Nil, singleRowRdd, "OneRowRelation") :: Nil + case logical.RepartitionByExpression(expressions, child, nPartitions) => + execution.Exchange(HashPartitioning( + expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil case e @ EvaluatePython(udf, child, _) => BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil - case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil - case BroadcastHint(child) => apply(child) + case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil + case BroadcastHint(child) => planLater(child) :: Nil case _ => Nil } } object DDLStrategy extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case CreateTableUsing(tableName, userSpecifiedSchema, provider, true, opts, false, _) => + case CreateTableUsing(tableIdent, userSpecifiedSchema, provider, true, opts, false, _) => ExecutedCommand( CreateTempTableUsing( - tableName, userSpecifiedSchema, provider, opts)) :: Nil + tableIdent, userSpecifiedSchema, provider, opts)) :: Nil case c: CreateTableUsing if !c.temporary => sys.error("Tables created with SQLContext must be TEMPORARY. 
Use a HiveContext instead.") case c: CreateTableUsing if c.temporary && c.allowExisting => sys.error("allowExisting should be set to false when creating a temporary table.") - case CreateTableUsingAsSelect(tableName, provider, true, partitionsCols, mode, opts, query) + case CreateTableUsingAsSelect(tableIdent, provider, true, partitionsCols, mode, opts, query) if partitionsCols.nonEmpty => sys.error("Cannot create temporary partitioned table.") - case CreateTableUsingAsSelect(tableName, provider, true, _, mode, opts, query) => + case CreateTableUsingAsSelect(tableIdent, provider, true, _, mode, opts, query) => val cmd = CreateTempTableUsingAsSelect( - tableName, provider, Array.empty[String], mode, opts, query) + tableIdent, provider, Array.empty[String], mode, opts, query) ExecutedCommand(cmd) :: Nil case c: CreateTableUsingAsSelect if !c.temporary => sys.error("Tables created with SQLContext must be TEMPORARY. Use a HiveContext instead.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala index 16498da080c8..7e981268de39 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import java.io.{DataInputStream, DataOutputStream, OutputStream, InputStream} +import java.io._ import java.nio.ByteBuffer import scala.reflect.ClassTag @@ -26,7 +26,7 @@ import com.google.common.io.ByteStreams import org.apache.spark.serializer.{SerializationStream, DeserializationStream, SerializerInstance, Serializer} import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.Platform /** * Serializer for serializing [[UnsafeRow]]s during shuffle. Since UnsafeRows are already stored as @@ -45,32 +45,27 @@ private[sql] class UnsafeRowSerializer(numFields: Int) extends Serializer with S } private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInstance { - - /** - * Marks the end of a stream written with [[serializeStream()]]. - */ - private[this] val EOF: Int = -1 - /** * Serializes a stream of UnsafeRows. Within the stream, each record consists of a record * length (stored as a 4-byte integer, written high byte first), followed by the record's bytes. - * The end of the stream is denoted by a record with the special length `EOF` (-1). */ override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream { private[this] var writeBuffer: Array[Byte] = new Array[Byte](4096) - private[this] val dOut: DataOutputStream = new DataOutputStream(out) + private[this] val dOut: DataOutputStream = + new DataOutputStream(new BufferedOutputStream(out)) override def writeValue[T: ClassTag](value: T): SerializationStream = { val row = value.asInstanceOf[UnsafeRow] + dOut.writeInt(row.getSizeInBytes) - row.writeToStream(out, writeBuffer) + row.writeToStream(dOut, writeBuffer) this } override def writeKey[T: ClassTag](key: T): SerializationStream = { // The key is only needed on the map side when computing partition ids. It does not need to // be shuffled. 
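
The record framing this serializer uses, a 4-byte length written high byte first followed by the record's bytes, with end-of-stream signalled by EOF, can be sketched with plain `java.io` streams. The payload strings below merely stand in for serialized rows; no Spark classes are involved.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

val payloads = Seq("row-1", "row-22", "row-333").map(_.getBytes("UTF-8"))

// Writer: DataOutputStream.writeInt emits the length high byte first, then the bytes follow.
val bytesOut = new ByteArrayOutputStream()
val out = new DataOutputStream(bytesOut)
payloads.foreach { p => out.writeInt(p.length); out.write(p) }
out.close()

// Reader: consume (length, bytes) pairs; running out of input marks the end of the
// stream, mirroring how the deserializer above treats EOFException as end-of-stream.
val in = new DataInputStream(new ByteArrayInputStream(bytesOut.toByteArray))
val decoded = (1 to payloads.size).map { _ =>
  val len = in.readInt()
  val buf = new Array[Byte](len)
  in.readFully(buf)
  new String(buf, "UTF-8")
}
// decoded == Vector("row-1", "row-22", "row-333")
```
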
- assert(key.isInstanceOf[Int]) + assert(null == key || key.isInstanceOf[Int]) this } @@ -90,33 +85,42 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst override def close(): Unit = { writeBuffer = null - dOut.writeInt(EOF) dOut.close() } } override def deserializeStream(in: InputStream): DeserializationStream = { new DeserializationStream { - private[this] val dIn: DataInputStream = new DataInputStream(in) + private[this] val dIn: DataInputStream = new DataInputStream(new BufferedInputStream(in)) // 1024 is a default buffer size; this buffer will grow to accommodate larger rows private[this] var rowBuffer: Array[Byte] = new Array[Byte](1024) private[this] var row: UnsafeRow = new UnsafeRow() private[this] var rowTuple: (Int, UnsafeRow) = (0, row) + private[this] val EOF: Int = -1 override def asKeyValueIterator: Iterator[(Int, UnsafeRow)] = { new Iterator[(Int, UnsafeRow)] { - private[this] var rowSize: Int = dIn.readInt() + private[this] def readSize(): Int = try { + dIn.readInt() + } catch { + case e: EOFException => + dIn.close() + EOF + } + + private[this] var rowSize: Int = readSize() override def hasNext: Boolean = rowSize != EOF override def next(): (Int, UnsafeRow) = { if (rowBuffer.length < rowSize) { rowBuffer = new Array[Byte](rowSize) } - ByteStreams.readFully(in, rowBuffer, 0, rowSize) - row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize) - rowSize = dIn.readInt() // read the next row's size + ByteStreams.readFully(dIn, rowBuffer, 0, rowSize) + row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, numFields, rowSize) + rowSize = readSize() if (rowSize == EOF) { // We are returning the last row in this stream + dIn.close() val _rowTuple = rowTuple // Null these out so that the byte array can be garbage collected once the entire // iterator has been consumed @@ -147,8 +151,8 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst if (rowBuffer.length < rowSize) { rowBuffer = new Array[Byte](rowSize) } - ByteStreams.readFully(in, rowBuffer, 0, rowSize) - row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize) + ByteStreams.readFully(dIn, rowBuffer, 0, rowSize) + row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, numFields, rowSize) row.asInstanceOf[T] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index fe9f2c702817..9852b6e7beeb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -17,19 +17,17 @@ package org.apache.spark.sql.execution -import java.util +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.types.IntegerType import org.apache.spark.rdd.RDD -import org.apache.spark.util.collection.CompactBuffer -import scala.collection.mutable /** - * :: DeveloperApi :: * This class calculates and outputs (windowed) aggregates over the rows in a single (sorted) * partition. The aggregates are calculated for each row in the group. Special processing * instructions, frames, are used to calculate these aggregates. 
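
For orientation, the frame shapes discussed in this class comment correspond roughly to the following user-level window specifications. The DataFrame `df` with columns `dept` and `salary` is an assumed, illustrative input; only the public `Window`/`functions` API is used.

```scala
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// Assumed to exist: a DataFrame df(dept: String, salary: Int).
val byDept = Window.partitionBy("dept").orderBy("salary")

// Moving frame: 1 PRECEDING AND 1 FOLLOWING.
val moving = avg("salary").over(byDept.rowsBetween(-1, 1))

// Growing frame: UNBOUNDED PRECEDING AND CURRENT ROW (a running total).
val growing = sum("salary").over(byDept.rowsBetween(Long.MinValue, 0))

// Shrinking frame: CURRENT ROW AND UNBOUNDED FOLLOWING.
val shrinking = sum("salary").over(byDept.rowsBetween(0, Long.MaxValue))

// Offset frame: a single row a fixed number of rows away from the current row (LAG).
val offset = lag("salary", 1).over(byDept)

val result = df.select(col("dept"), col("salary"), moving, growing, shrinking, offset)
```
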
Frames are processed in the order @@ -47,6 +45,8 @@ import scala.collection.mutable * - Moving frame: Every time we move to a new row to process, we remove some rows from the frame * and we add some rows to the frame. Examples are: * 1 PRECEDING AND CURRENT ROW and 1 FOLLOWING AND 2 FOLLOWING. + * - Offset frame: The frame consist of one row, which is an offset number of rows away from the + * current row. Only [[OffsetWindowFunction]]s can be processed in an offset frame. * * Different frame boundaries can be used in Growing, Shrinking and Moving frames. A frame * boundary can be either Row or Range based: @@ -76,7 +76,6 @@ import scala.collection.mutable * Entire Partition, Sliding, Growing & Shrinking. Boundary evaluation is also delegated to a pair * of specialized classes: [[RowBoundOrdering]] & [[RangeBoundOrdering]]. */ -@DeveloperApi case class Window( projectList: Seq[Attribute], windowExpression: Seq[NamedExpression], @@ -101,6 +100,8 @@ case class Window( override def outputOrdering: Seq[SortOrder] = child.outputOrdering + override def canProcessUnsafeRows: Boolean = true + /** * Create a bound ordering object for a given frame type and offset. A bound ordering object is * used to determine which input row lies within the frame boundaries of an output row. @@ -126,12 +127,10 @@ case class Window( // Create the projection which returns the current 'value'. val current = newMutableProjection(expr :: Nil, child.output)() // Flip the sign of the offset when processing the order is descending - val boundOffset = - if (sortExpr.direction == Descending) { - -offset - } else { - offset - } + val boundOffset = sortExpr.direction match { + case Descending => -offset + case Ascending => offset + } // Create the projection which returns the current 'value' modified by adding the offset. val boundExpr = Add(expr, Cast(Literal.create(boundOffset, IntegerType), expr.dataType)) val bound = newMutableProjection(boundExpr :: Nil, child.output)() @@ -143,54 +142,112 @@ case class Window( // Construct the ordering. This is used to compare the result of current value projection // to the result of bound value projection. This is done manually because we want to use // Code Generation (if it is enabled). - val (sortExprs, schema) = exprs.map { case e => - val ref = AttributeReference("ordExpr", e.dataType, e.nullable)() - (SortOrder(ref, e.direction), ref) - }.unzip - val ordering = newOrdering(sortExprs, schema) + val sortExprs = exprs.zipWithIndex.map { case (e, i) => + SortOrder(BoundReference(i, e.dataType, e.nullable), e.direction) + } + val ordering = newOrdering(sortExprs, Nil) RangeBoundOrdering(ordering, current, bound) case RowFrame => RowBoundOrdering(offset) } } /** - * Create a frame processor. - * - * This method uses Code Generation. It can only be used on the executor side. - * - * @param frame boundaries. - * @param functions to process in the frame. - * @param ordinal at which the processor starts writing to the output. - * @return a frame processor. + * Collection containing an entry for each window frame to process. Each entry contains a frames' + * WindowExpressions and factory function for the WindowFrameFunction. */ - private[this] def createFrameProcessor( - frame: WindowFrame, - functions: Array[WindowFunction], - ordinal: Int): WindowFunctionFrame = frame match { - // Growing Frame. 
- case SpecifiedWindowFrame(frameType, UnboundedPreceding, FrameBoundaryExtractor(high)) => - val uBoundOrdering = createBoundOrdering(frameType, high) - new UnboundedPrecedingWindowFunctionFrame(ordinal, functions, uBoundOrdering) - - // Shrinking Frame. - case SpecifiedWindowFrame(frameType, FrameBoundaryExtractor(low), UnboundedFollowing) => - val lBoundOrdering = createBoundOrdering(frameType, low) - new UnboundedFollowingWindowFunctionFrame(ordinal, functions, lBoundOrdering) - - // Moving Frame. - case SpecifiedWindowFrame(frameType, - FrameBoundaryExtractor(low), FrameBoundaryExtractor(high)) => - val lBoundOrdering = createBoundOrdering(frameType, low) - val uBoundOrdering = createBoundOrdering(frameType, high) - new SlidingWindowFunctionFrame(ordinal, functions, lBoundOrdering, uBoundOrdering) - - // Entire Partition Frame. - case SpecifiedWindowFrame(_, UnboundedPreceding, UnboundedFollowing) => - new UnboundedWindowFunctionFrame(ordinal, functions) - - // Error - case fr => - sys.error(s"Unsupported Frame $fr for functions: $functions") + private[this] lazy val windowFrameExpressionFactoryPairs = { + type FrameKey = (String, FrameType, Option[Int], Option[Int]) + type ExpressionBuffer = mutable.Buffer[Expression] + val framedFunctions = mutable.Map.empty[FrameKey, (ExpressionBuffer, ExpressionBuffer)] + + // Add a function and its function to the map for a given frame. + def collect(tpe: String, fr: SpecifiedWindowFrame, e: Expression, fn: Expression): Unit = { + val key = (tpe, fr.frameType, FrameBoundary(fr.frameStart), FrameBoundary(fr.frameEnd)) + val (es, fns) = framedFunctions.getOrElseUpdate( + key, (ArrayBuffer.empty[Expression], ArrayBuffer.empty[Expression])) + es.append(e) + fns.append(fn) + } + + // Collect all valid window functions and group them by their frame. + windowExpression.foreach { x => + x.foreach { + case e @ WindowExpression(function, spec) => + val frame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] + function match { + case AggregateExpression(f, _, _) => collect("AGGREGATE", frame, e, f) + case f: AggregateWindowFunction => collect("AGGREGATE", frame, e, f) + case f: OffsetWindowFunction => collect("OFFSET", frame, e, f) + case f => sys.error(s"Unsupported window function: $f") + } + case _ => + } + } + + // Map the groups to a (unbound) expression and frame factory pair. + var numExpressions = 0 + framedFunctions.toSeq.map { + case (key, (expressions, functionSeq)) => + val ordinal = numExpressions + val functions = functionSeq.toArray + + // Construct an aggregate processor if we need one. + def processor = AggregateProcessor(functions, ordinal, child.output, newMutableProjection) + + // Create the factory + val factory = key match { + // Offset Frame + case ("OFFSET", RowFrame, Some(offset), Some(h)) if offset == h => + target: MutableRow => + new OffsetWindowFunctionFrame( + target, + ordinal, + functions, + child.output, + newMutableProjection, + offset) + + // Growing Frame. + case ("AGGREGATE", frameType, None, Some(high)) => + target: MutableRow => { + new UnboundedPrecedingWindowFunctionFrame( + target, + processor, + createBoundOrdering(frameType, high)) + } + + // Shrinking Frame. + case ("AGGREGATE", frameType, Some(low), None) => + target: MutableRow => { + new UnboundedFollowingWindowFunctionFrame( + target, + processor, + createBoundOrdering(frameType, low)) + } + + // Moving Frame. 
+ case ("AGGREGATE", frameType, Some(low), Some(high)) => + target: MutableRow => { + new SlidingWindowFunctionFrame( + target, + processor, + createBoundOrdering(frameType, low), + createBoundOrdering(frameType, high)) + } + + // Entire Partition Frame. + case ("AGGREGATE", frameType, None, None) => + target: MutableRow => { + new UnboundedWindowFunctionFrame(target, processor) + } + } + + // Keep track of the number of expressions. This is a side-effect in a map... + numExpressions += expressions.size + + // Create the Frame Expression - Factory pair. + (expressions, factory) + } } /** @@ -203,55 +260,29 @@ case class Window( */ private[this] def createResultProjection( expressions: Seq[Expression]): MutableProjection = { - val unboundToAttr = expressions.map { - e => (e, AttributeReference("windowResult", e.dataType, e.nullable)()) + val references = expressions.zipWithIndex.map{ case (e, i) => + // Results of window expressions will be on the right side of child's output + BoundReference(child.output.size + i, e.dataType, e.nullable) } - val unboundToAttrMap = unboundToAttr.toMap - val patchedWindowExpression = windowExpression.map(_.transform(unboundToAttrMap)) + val unboundToRefMap = expressions.zip(references).toMap + val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap)) newMutableProjection( projectList ++ patchedWindowExpression, - child.output ++ unboundToAttr.map(_._2))() + child.output)() } protected override def doExecute(): RDD[InternalRow] = { - // Prepare processing. - // Group the window expression by their processing frame. - val windowExprs = windowExpression.flatMap { - _.collect { - case e: WindowExpression => e - } - } - - // Create Frame processor factories and order the unbound window expressions by the frame they - // are processed in; this is the order in which their results will be written to window - // function result buffer. - val framedWindowExprs = windowExprs.groupBy(_.windowSpec.frameSpecification) - val factories = Array.ofDim[() => WindowFunctionFrame](framedWindowExprs.size) - val unboundExpressions = mutable.Buffer.empty[Expression] - framedWindowExprs.zipWithIndex.foreach { - case ((frame, unboundFrameExpressions), index) => - // Track the ordinal. - val ordinal = unboundExpressions.size - - // Track the unbound expressions - unboundExpressions ++= unboundFrameExpressions - - // Bind the expressions. - val functions = unboundFrameExpressions.map { e => - BindReferences.bindReference(e.windowFunction, child.output) - }.toArray - - // Create the frame processor factory. - factories(index) = () => createFrameProcessor(frame, functions, ordinal) - } + // Unwrap the expressions and factories from the map. + val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1) + val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray // Start processing. child.execute().mapPartitions { stream => new Iterator[InternalRow] { // Get all relevant projections. - val result = createResultProjection(unboundExpressions) - val grouping = newProjection(partitionSpec, child.output) + val result = createResultProjection(expressions) + val grouping = UnsafeProjection.create(partitionSpec, child.output) // Manage the stream and the grouping. var nextRow: InternalRow = EmptyRow @@ -270,13 +301,15 @@ case class Window( fetchNextRow() // Manage the current partition. 
- var rows: CompactBuffer[InternalRow] = _ - val frames: Array[WindowFunctionFrame] = factories.map(_()) + val rows = ArrayBuffer.empty[InternalRow] + val windowFunctionResult = new SpecificMutableRow(expressions.map(_.dataType)) + val frames = factories.map(_(windowFunctionResult)) val numFrames = frames.length private[this] def fetchNextPartition() { // Collect all the rows in the current partition. - val currentGroup = nextGroup - rows = new CompactBuffer + // Before we start to fetch new input rows, make a copy of nextGroup. + val currentGroup = nextGroup.copy() + rows.clear() while (nextRowAvailable && nextGroup == currentGroup) { rows += nextRow.copy() fetchNextRow() @@ -300,7 +333,6 @@ case class Window( override final def hasNext: Boolean = rowIndex < rowsSize || nextRowAvailable val join = new JoinedRow - val windowFunctionResult = new GenericMutableRow(unboundExpressions.size) override final def next(): InternalRow = { // Load the next partition if we need to. if (rowIndex >= rowsSize && nextRowAvailable) { @@ -311,7 +343,7 @@ case class Window( // Get the results for the window frames. var i = 0 while (i < numFrames) { - frames(i).write(windowFunctionResult) + frames(i).write() i += 1 } @@ -358,140 +390,96 @@ private[execution] final case class RangeBoundOrdering( * A window function calculates the results of a number of window functions for a window frame. * Before use a frame must be prepared by passing it all the rows in the current partition. After * preparation the update method can be called to fill the output rows. - * - * TODO How to improve performance? A few thoughts: - * - Window functions are expensive due to its distribution and ordering requirements. - * Unfortunately it is up to the Spark engine to solve this. Improvements in the form of project - * Tungsten are on the way. - * - The window frame processing bit can be improved though. But before we start doing that we - * need to see how much of the time and resources are spent on partitioning and ordering, and - * how much time and resources are spent processing the partitions. There are a couple ways to - * improve on the current situation: - * - Reduce memory footprint by performing streaming calculations. This can only be done when - * there are no Unbound/Unbounded Following calculations present. - * - Use Tungsten style memory usage. - * - Use code generation in general, and use the approach to aggregation taken in the - * GeneratedAggregate class in specific. - * - * @param ordinal of the first column written by this frame. - * @param functions to calculate the row values with. */ -private[execution] abstract class WindowFunctionFrame( - ordinal: Int, - functions: Array[WindowFunction]) { - - // Make sure functions are initialized. - functions.foreach(_.init()) - - /** Number of columns the window function frame is managing */ - val numColumns = functions.length - - /** - * Create a fresh thread safe copy of the frame. - * - * @return the copied frame. - */ - def copy: WindowFunctionFrame - - /** - * Create new instances of the functions. - * - * @return an array containing copies of the current window functions. - */ - protected final def copyFunctions: Array[WindowFunction] = functions.map(_.newInstance()) - +private[execution] abstract class WindowFunctionFrame { /** * Prepare the frame for calculating the results for a partition. * * @param rows to calculate the frame results for. 
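
A toy rendering of the prepare/write contract in plain Scala (not Spark types): a frame is handed all rows of a partition once, then asked to emit one result per row, in order. The running-sum frame below is hypothetical and only illustrates the calling pattern, not the real aggregate machinery.

```scala
import scala.collection.mutable.ArrayBuffer

trait ToyFrame {
  def prepare(rows: ArrayBuffer[Int]): Unit   // called once per partition
  def write(): Unit                           // called once per output row, in order
}

// "UNBOUNDED PRECEDING AND CURRENT ROW" over integer rows, appending results to `target`.
final class RunningSumFrame(target: ArrayBuffer[Long]) extends ToyFrame {
  private var input: ArrayBuffer[Int] = _
  private var index = 0
  private var sum = 0L

  override def prepare(rows: ArrayBuffer[Int]): Unit = { input = rows; index = 0; sum = 0L }

  override def write(): Unit = {
    sum += input(index)   // grow the frame by the current row
    target += sum         // emit the aggregate for this row
    index += 1
  }
}

val rows = ArrayBuffer(3, 1, 4, 1, 5)
val out = ArrayBuffer.empty[Long]
val frame = new RunningSumFrame(out)
frame.prepare(rows)
rows.indices.foreach(_ => frame.write())
// out == ArrayBuffer(3, 4, 8, 9, 14)
```
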
*/ - def prepare(rows: CompactBuffer[InternalRow]): Unit + def prepare(rows: ArrayBuffer[InternalRow]): Unit /** - * Write the result for the current row to the given target row. - * - * @param target row to write the result for the current row to. + * Write the current results to the target row. */ - def write(target: GenericMutableRow): Unit + def write(): Unit +} - /** Reset the current window functions. */ - protected final def reset(): Unit = { - var i = 0 - while (i < numColumns) { - functions(i).reset() - i += 1 - } - } +/** + * The offset window frame calculates frames containing LEAD/LAG statements. + * + * @param target to write results to. + * @param expressions to shift a number of rows. + * @param inputSchema required for creating a projection. + * @param newMutableProjection function used to create the projection. + * @param offset by which rows get moved within a partition. + */ +private[execution] final class OffsetWindowFunctionFrame( + target: MutableRow, + ordinal: Int, + expressions: Array[Expression], + inputSchema: Seq[Attribute], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => () => MutableProjection, + offset: Int) extends WindowFunctionFrame { - /** Prepare an input row for processing. */ - protected final def prepare(input: InternalRow): Array[AnyRef] = { - val prepared = new Array[AnyRef](numColumns) - var i = 0 - while (i < numColumns) { - prepared(i) = functions(i).prepareInputParameters(input) - i += 1 - } - prepared - } + /** Rows of the partition currently being processed. */ + private[this] var input: ArrayBuffer[InternalRow] = null - /** Evaluate a prepared buffer (iterator). */ - protected final def evaluatePrepared(iterator: java.util.Iterator[Array[AnyRef]]): Unit = { - reset() - while (iterator.hasNext) { - val prepared = iterator.next() - var i = 0 - while (i < numColumns) { - functions(i).update(prepared(i)) - i += 1 - } - } - evaluate() - } + /** Index of the input row currently used for output. */ + private[this] var inputIndex = 0 - /** Evaluate a prepared buffer (array). */ - protected final def evaluatePrepared(prepared: Array[Array[AnyRef]], - fromIndex: Int, toIndex: Int): Unit = { - var i = 0 - while (i < numColumns) { - val function = functions(i) - function.reset() - var j = fromIndex - while (j < toIndex) { - function.update(prepared(j)(i)) - j += 1 - } - function.evaluate() - i += 1 - } - } + /** Index of the current output row. */ + private[this] var outputIndex = 0 - /** Update an array of window functions. */ - protected final def update(input: InternalRow): Unit = { - var i = 0 - while (i < numColumns) { - val aggregate = functions(i) - val preparedInput = aggregate.prepareInputParameters(input) - aggregate.update(preparedInput) - i += 1 + /** Row used when there is no valid input. */ + private[this] val emptyRow = new GenericInternalRow(inputSchema.size) + + /** Row used to combine the offset and the current row. */ + private[this] val join = new JoinedRow + + /** Create the projection. */ + private[this] val projection = { + // Collect the expressions and bind them. + val numInputAttributes = inputSchema.size + val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { + case e: OffsetWindowFunction => + val input = BindReferences.bindReference(e.input, inputSchema) + if (e.default == null || e.default.foldable && e.default.eval() == null) { + // Without default value. + input + } else { + // With default value. 
+ val default = BindReferences.bindReference(e.default, inputSchema).transform { + // Shift the input reference to its default version. + case BoundReference(o, dataType, nullable) => + BoundReference(o + numInputAttributes, dataType, nullable) + } + org.apache.spark.sql.catalyst.expressions.Coalesce(input :: default :: Nil) + } + case e => + BindReferences.bindReference(e, inputSchema) } + + // Create the projection. + newMutableProjection(boundExpressions, Nil)().target(target) } - /** Evaluate the window functions. */ - protected final def evaluate(): Unit = { - var i = 0 - while (i < numColumns) { - functions(i).evaluate() - i += 1 - } + override def prepare(rows: ArrayBuffer[InternalRow]): Unit = { + input = rows + inputIndex = offset + outputIndex = 0 } - /** Fill a target row with the current window function results. */ - protected final def fill(target: GenericMutableRow, rowIndex: Int): Unit = { - var i = 0 - while (i < numColumns) { - target.update(ordinal + i, functions(i).get(rowIndex)) - i += 1 + override def write(): Unit = { + val size = input.size + if (inputIndex >= 0 && inputIndex < size) { + join(input(inputIndex), input(outputIndex)) + } else { + join(emptyRow, input(outputIndex)) } + projection(join) + inputIndex += 1 + outputIndex += 1 } } @@ -499,19 +487,19 @@ private[execution] abstract class WindowFunctionFrame( * The sliding window frame calculates frames with the following SQL form: * ... BETWEEN 1 PRECEDING AND 1 FOLLOWING * - * @param ordinal of the first column written by this frame. - * @param functions to calculate the row values with. + * @param target to write results to. + * @param processor to calculate the row values with. * @param lbound comparator used to identify the lower bound of an output row. * @param ubound comparator used to identify the upper bound of an output row. */ private[execution] final class SlidingWindowFunctionFrame( - ordinal: Int, - functions: Array[WindowFunction], + target: MutableRow, + processor: AggregateProcessor, lbound: BoundOrdering, - ubound: BoundOrdering) extends WindowFunctionFrame(ordinal, functions) { + ubound: BoundOrdering) extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ - private[this] var input: CompactBuffer[InternalRow] = null + private[this] var input: ArrayBuffer[InternalRow] = null /** Index of the first input row with a value greater than the upper bound of the current * output row. */ @@ -521,30 +509,25 @@ private[execution] final class SlidingWindowFunctionFrame( * current output row. */ private[this] var inputLowIndex = 0 - /** Buffer used for storing prepared input for the window functions. */ - private[this] val buffer = new util.ArrayDeque[Array[AnyRef]] - /** Index of the row we are currently writing. */ private[this] var outputIndex = 0 /** Prepare the frame for calculating a new partition. Reset all variables. */ - override def prepare(rows: CompactBuffer[InternalRow]): Unit = { + override def prepare(rows: ArrayBuffer[InternalRow]): Unit = { input = rows inputHighIndex = 0 inputLowIndex = 0 outputIndex = 0 - buffer.clear() } /** Write the frame columns for the current row to the given target row. */ - override def write(target: GenericMutableRow): Unit = { + override def write(): Unit = { var bufferUpdated = outputIndex == 0 // Add all rows to the buffer for which the input row value is equal to or less than // the output row upper bound. 
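The projection built above rewrites each LEAD/LAG expression so that an out-of-partition reference falls back to its (bound) default, while the frame itself just walks an input index alongside the output index. A self-contained sketch of that indexing idea in plain Scala follows; the names are illustrative only and this is not the operator's actual API.

object LeadLagSketch {
  // Returns, for every row, the value `offset` positions away inside the partition,
  // or `default` when the shifted position falls outside the partition.
  def lead[T](partition: IndexedSeq[T], offset: Int, default: T): IndexedSeq[T] =
    partition.indices.map { outputIndex =>
      val inputIndex = outputIndex + offset
      if (inputIndex >= 0 && inputIndex < partition.length) partition(inputIndex)
      else default
    }

  def main(args: Array[String]): Unit = {
    // LEAD(value, 1, -1) over a single partition.
    println(lead(Vector(10, 20, 30), offset = 1, default = -1))  // Vector(20, 30, -1)
    // LAG is the same computation expressed as a negative offset.
    println(lead(Vector(10, 20, 30), offset = -1, default = -1)) // Vector(-1, 10, 20)
  }
}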
while (inputHighIndex < input.size && - ubound.compare(input, inputHighIndex, outputIndex) <= 0) { - buffer.offer(prepare(input(inputHighIndex))) + ubound.compare(input, inputHighIndex, outputIndex) <= 0) { inputHighIndex += 1 bufferUpdated = true } @@ -552,25 +535,21 @@ private[execution] final class SlidingWindowFunctionFrame( // Drop all rows from the buffer for which the input row value is smaller than // the output row lower bound. while (inputLowIndex < inputHighIndex && - lbound.compare(input, inputLowIndex, outputIndex) < 0) { - buffer.pop() + lbound.compare(input, inputLowIndex, outputIndex) < 0) { inputLowIndex += 1 bufferUpdated = true } // Only recalculate and update when the buffer changes. if (bufferUpdated) { - evaluatePrepared(buffer.iterator()) - fill(target, outputIndex) + processor.initialize(input.size) + processor.update(input, inputLowIndex, inputHighIndex) + processor.evaluate(target) } // Move to the next row. outputIndex += 1 } - - /** Copy the frame. */ - override def copy: SlidingWindowFunctionFrame = - new SlidingWindowFunctionFrame(ordinal, copyFunctions, lbound, ubound) } /** @@ -581,36 +560,25 @@ private[execution] final class SlidingWindowFunctionFrame( * Its results are the same for each and every row in the partition. This class can be seen as a * special case of a sliding window, but is optimized for the unbound case. * - * @param ordinal of the first column written by this frame. - * @param functions to calculate the row values with. + * @param target to write results to. + * @param processor to calculate the row values with. */ private[execution] final class UnboundedWindowFunctionFrame( - ordinal: Int, - functions: Array[WindowFunction]) extends WindowFunctionFrame(ordinal, functions) { - - /** Index of the row we are currently writing. */ - private[this] var outputIndex = 0 + target: MutableRow, + processor: AggregateProcessor) extends WindowFunctionFrame { /** Prepare the frame for calculating a new partition. Process all rows eagerly. */ - override def prepare(rows: CompactBuffer[InternalRow]): Unit = { - reset() - outputIndex = 0 - val iterator = rows.iterator - while (iterator.hasNext) { - update(iterator.next()) - } - evaluate() + override def prepare(rows: ArrayBuffer[InternalRow]): Unit = { + processor.initialize(rows.size) + processor.update(rows, 0, rows.size) } /** Write the frame columns for the current row to the given target row. */ - override def write(target: GenericMutableRow): Unit = { - fill(target, outputIndex) - outputIndex += 1 + override def write(): Unit = { + // Unfortunately we cannot assume that evaluation is deterministic. So we need to re-evaluate + // for each row. + processor.evaluate(target) } - - /** Copy the frame. */ - override def copy: UnboundedWindowFunctionFrame = - new UnboundedWindowFunctionFrame(ordinal, copyFunctions) } /** @@ -623,58 +591,53 @@ private[execution] final class UnboundedWindowFunctionFrame( * is not the case when there is no lower bound, given the additive nature of most aggregates * streaming updates and partial evaluation suffice and no buffering is needed. * - * @param ordinal of the first column written by this frame. - * @param functions to calculate the row values with. + * @param target to write results to. + * @param processor to calculate the row values with. * @param ubound comparator used to identify the upper bound of an output row. 
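The sliding frame above only ever moves its two bound indices forward and, whenever either bound moves, re-runs the aggregate processor over the rows currently between the lower and upper index. A minimal standalone sketch of the same two-pointer scheme, using a plain sum over Ints (hypothetical names, not Spark code):

object SlidingFrameSketch {
  // SUM(value) OVER (ROWS BETWEEN `preceding` PRECEDING AND `following` FOLLOWING)
  // for one partition. Both bounds advance monotonically, and the aggregate is
  // recomputed over the current frame [low, high) each time the frame changes.
  def slidingSum(partition: IndexedSeq[Int], preceding: Int, following: Int): IndexedSeq[Long] = {
    var low = 0
    var high = 0
    val out = Array.ofDim[Long](partition.length)
    for (outputIndex <- partition.indices) {
      // Advance the upper bound to include rows up to outputIndex + following.
      while (high < partition.length && high <= outputIndex + following) high += 1
      // Advance the lower bound past rows before outputIndex - preceding.
      while (low < high && low < outputIndex - preceding) low += 1
      // Recompute the aggregate for the current frame.
      out(outputIndex) = partition.slice(low, high).map(_.toLong).sum
    }
    out.toIndexedSeq
  }

  def main(args: Array[String]): Unit = {
    // ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING over 1,2,3,4 -> 3, 6, 9, 7
    println(slidingSum(Vector(1, 2, 3, 4), preceding = 1, following = 1))
  }
}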
*/ private[execution] final class UnboundedPrecedingWindowFunctionFrame( - ordinal: Int, - functions: Array[WindowFunction], - ubound: BoundOrdering) extends WindowFunctionFrame(ordinal, functions) { + target: MutableRow, + processor: AggregateProcessor, + ubound: BoundOrdering) extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ - private[this] var input: CompactBuffer[InternalRow] = null + private[this] var input: ArrayBuffer[InternalRow] = null /** Index of the first input row with a value greater than the upper bound of the current - * output row. */ + * output row. */ private[this] var inputIndex = 0 /** Index of the row we are currently writing. */ private[this] var outputIndex = 0 /** Prepare the frame for calculating a new partition. */ - override def prepare(rows: CompactBuffer[InternalRow]): Unit = { - reset() + override def prepare(rows: ArrayBuffer[InternalRow]): Unit = { input = rows inputIndex = 0 outputIndex = 0 + processor.initialize(input.size) } /** Write the frame columns for the current row to the given target row. */ - override def write(target: GenericMutableRow): Unit = { + override def write(): Unit = { var bufferUpdated = outputIndex == 0 // Add all rows to the aggregates for which the input row value is equal to or less than // the output row upper bound. while (inputIndex < input.size && ubound.compare(input, inputIndex, outputIndex) <= 0) { - update(input(inputIndex)) + processor.update(input(inputIndex)) inputIndex += 1 bufferUpdated = true } // Only recalculate and update when the buffer changes. if (bufferUpdated) { - evaluate() - fill(target, outputIndex) + processor.evaluate(target) } // Move to the next row. outputIndex += 1 } - - /** Copy the frame. */ - override def copy: UnboundedPrecedingWindowFunctionFrame = - new UnboundedPrecedingWindowFunctionFrame(ordinal, copyFunctions, ubound) } /** @@ -689,45 +652,34 @@ private[execution] final class UnboundedPrecedingWindowFunctionFrame( * buffer and must do full recalculation after each row. Reverse iteration would be possible, if * the communitativity of the used window functions can be guaranteed. * - * @param ordinal of the first column written by this frame. - * @param functions to calculate the row values with. + * @param target to write results to. + * @param processor to calculate the row values with. * @param lbound comparator used to identify the lower bound of an output row. */ private[execution] final class UnboundedFollowingWindowFunctionFrame( - ordinal: Int, - functions: Array[WindowFunction], - lbound: BoundOrdering) extends WindowFunctionFrame(ordinal, functions) { - - /** Buffer used for storing prepared input for the window functions. */ - private[this] var buffer: Array[Array[AnyRef]] = _ + target: MutableRow, + processor: AggregateProcessor, + lbound: BoundOrdering) extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ - private[this] var input: CompactBuffer[InternalRow] = null + private[this] var input: ArrayBuffer[InternalRow] = null /** Index of the first input row with a value equal to or greater than the lower bound of the - * current output row. */ + * current output row. */ private[this] var inputIndex = 0 /** Index of the row we are currently writing. */ private[this] var outputIndex = 0 /** Prepare the frame for calculating a new partition. 
*/ - override def prepare(rows: CompactBuffer[InternalRow]): Unit = { + override def prepare(rows: ArrayBuffer[InternalRow]): Unit = { input = rows inputIndex = 0 outputIndex = 0 - val size = input.size - buffer = Array.ofDim(size) - var i = 0 - while (i < size) { - buffer(i) = prepare(input(i)) - i += 1 - } - evaluatePrepared(buffer, 0, buffer.length) } /** Write the frame columns for the current row to the given target row. */ - override def write(target: GenericMutableRow): Unit = { + override def write(): Unit = { var bufferUpdated = outputIndex == 0 // Drop all rows from the buffer for which the input row value is smaller than @@ -739,15 +691,151 @@ private[execution] final class UnboundedFollowingWindowFunctionFrame( // Only recalculate and update when the buffer changes. if (bufferUpdated) { - evaluatePrepared(buffer, inputIndex, buffer.length) - fill(target, outputIndex) + processor.initialize(input.size) + processor.update(input, inputIndex, input.size) + processor.evaluate(target) } // Move to the next row. outputIndex += 1 } +} + +/** + * This class prepares and manages the processing of a number of [[AggregateFunction]]s within a + * single frame. The [[WindowFunctionFrame]] takes care of processing the frame in the correct way, + * this reduces the processing of a [[AggregateWindowFunction]] to processing the underlying + * [[AggregateFunction]]. All [[AggregateFunction]]s are processed in [[Complete]] mode. + * + * [[SizeBasedWindowFunction]]s are initialized in a slightly different way. These functions + * require the size of the partition processed, this value is exposed to them when the processor is + * constructed. + * + * Processing of distinct aggregates is currently not supported. + * + * The implementation is split into an object which takes care of construction, and a the actual + * processor class. + */ +private[execution] object AggregateProcessor { + def apply(functions: Array[Expression], + ordinal: Int, + inputAttributes: Seq[Attribute], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => () => MutableProjection): + AggregateProcessor = { + val aggBufferAttributes = mutable.Buffer.empty[AttributeReference] + val initialValues = mutable.Buffer.empty[Expression] + val updateExpressions = mutable.Buffer.empty[Expression] + val evaluateExpressions = mutable.Buffer.fill[Expression](ordinal)(NoOp) + val imperatives = mutable.Buffer.empty[ImperativeAggregate] + + // Check if there are any SizeBasedWindowFunctions. If there are, we add the partition size to + // the aggregation buffer. Note that the ordinal of the partition size value will always be 0. + val trackPartitionSize = functions.exists(_.isInstanceOf[SizeBasedWindowFunction]) + if (trackPartitionSize) { + aggBufferAttributes += SizeBasedWindowFunction.n + initialValues += NoOp + updateExpressions += NoOp + } + + // Add an AggregateFunction to the AggregateProcessor. 
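The two unbounded variants above differ in which bound is fixed: with UNBOUNDED PRECEDING a single streaming pass can keep updating one buffer, while UNBOUNDED FOLLOWING has to recompute over the remaining suffix each time its lower bound advances. A simplified sketch of both with plain sums (illustrative only, not the frames' real interface):

object UnboundedFrameSketch {
  // SUM(value) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW):
  // one forward pass that keeps updating a single running buffer.
  def runningSum(partition: IndexedSeq[Int]): IndexedSeq[Long] =
    partition.scanLeft(0L)(_ + _).tail

  // SUM(value) OVER (ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING):
  // the upper bound is pinned at the partition end, so each output is recomputed
  // from the remaining suffix; there is no incremental shortcut here without
  // stronger guarantees about the aggregate.
  def suffixSum(partition: IndexedSeq[Int]): IndexedSeq[Long] =
    partition.indices.map(i => partition.drop(i).map(_.toLong).sum)

  def main(args: Array[String]): Unit = {
    println(runningSum(Vector(1, 2, 3))) // Vector(1, 3, 6)
    println(suffixSum(Vector(1, 2, 3)))  // Vector(6, 5, 3)
  }
}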
+ functions.foreach { + case agg: DeclarativeAggregate => + aggBufferAttributes ++= agg.aggBufferAttributes + initialValues ++= agg.initialValues + updateExpressions ++= agg.updateExpressions + evaluateExpressions += agg.evaluateExpression + case agg: ImperativeAggregate => + val offset = aggBufferAttributes.size + val imperative = BindReferences.bindReference(agg + .withNewInputAggBufferOffset(offset) + .withNewMutableAggBufferOffset(offset), + inputAttributes) + imperatives += imperative + aggBufferAttributes ++= imperative.aggBufferAttributes + val noOps = Seq.fill(imperative.aggBufferAttributes.size)(NoOp) + initialValues ++= noOps + updateExpressions ++= noOps + evaluateExpressions += imperative + case other => + sys.error(s"Unsupported Aggregate Function: $other") + } + + // Create the projections. + val initialProjection = newMutableProjection( + initialValues, + Seq(SizeBasedWindowFunction.n))() + val updateProjection = newMutableProjection( + updateExpressions, + aggBufferAttributes ++ inputAttributes)() + val evaluateProjection = newMutableProjection( + evaluateExpressions, + aggBufferAttributes)() + + // Create the processor + new AggregateProcessor( + aggBufferAttributes.toArray, + initialProjection, + updateProjection, + evaluateProjection, + imperatives.toArray, + trackPartitionSize) + } +} + +/** + * This class manages the processing of a number of aggregate functions. See the documentation of + * the object for more information. + */ +private[execution] final class AggregateProcessor( + private[this] val bufferSchema: Array[AttributeReference], + private[this] val initialProjection: MutableProjection, + private[this] val updateProjection: MutableProjection, + private[this] val evaluateProjection: MutableProjection, + private[this] val imperatives: Array[ImperativeAggregate], + private[this] val trackPartitionSize: Boolean) { + + private[this] val join = new JoinedRow + private[this] val numImperatives = imperatives.length + private[this] val buffer = new SpecificMutableRow(bufferSchema.toSeq.map(_.dataType)) + initialProjection.target(buffer) + updateProjection.target(buffer) + + /** Create the initial state. */ + def initialize(size: Int): Unit = { + // Some initialization expressions are dependent on the partition size so we have to + // initialize the size before initializing all other fields, and we have to pass the buffer to + // the initialization projection. + if (trackPartitionSize) { + buffer.setInt(0, size) + } + initialProjection(buffer) + var i = 0 + while (i < numImperatives) { + imperatives(i).initialize(buffer) + i += 1 + } + } + + /** Update the buffer. */ + def update(input: InternalRow): Unit = { + updateProjection(join(buffer, input)) + var i = 0 + while (i < numImperatives) { + imperatives(i).update(buffer, input) + i += 1 + } + } + + /** Bulk update the given buffer. */ + def update(input: ArrayBuffer[InternalRow], begin: Int, end: Int): Unit = { + var i = begin + while (i < end) { + update(input(i)) + i += 1 + } + } - /** Copy the frame. */ - override def copy: UnboundedFollowingWindowFunctionFrame = - new UnboundedFollowingWindowFunctionFrame(ordinal, copyFunctions, lbound) + /** Evaluate buffer. 
*/ + def evaluate(target: MutableRow): Unit = + evaluateProjection.target(target)(buffer) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala deleted file mode 100644 index cf568dc04867..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala +++ /dev/null @@ -1,182 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.aggregate - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution} -import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan, UnaryNode} -import org.apache.spark.sql.types.StructType - -/** - * An Aggregate Operator used to evaluate [[AggregateFunction2]]. Based on the data types - * of the grouping expressions and aggregate functions, it determines if it uses - * sort-based aggregation and hybrid (hash-based with sort-based as the fallback) to - * process input rows. - */ -case class Aggregate( - requiredChildDistributionExpressions: Option[Seq[Expression]], - groupingExpressions: Seq[NamedExpression], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], - completeAggregateAttributes: Seq[Attribute], - initialInputBufferOffset: Int, - resultExpressions: Seq[NamedExpression], - child: SparkPlan) - extends UnaryNode { - - private[this] val allAggregateExpressions = - nonCompleteAggregateExpressions ++ completeAggregateExpressions - - private[this] val hasNonAlgebricAggregateFunctions = - !allAggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate]) - - // Use the hybrid iterator if (1) unsafe is enabled, (2) the schemata of - // grouping key and aggregation buffer is supported; and (3) all - // aggregate functions are algebraic. 
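The processor above owns one mutable buffer and exposes an initialize/update/evaluate cycle that the window frames drive. A toy analogue with an average kept as a (sum, count) pair, under the assumption of a single hard-coded function instead of generated projections:

object AggregateProcessorSketch {
  // Mirrors the initialize -> update(s) -> evaluate lifecycle over a mutable buffer.
  final class AvgProcessor {
    private var sum = 0.0
    private var count = 0L
    def initialize(): Unit = { sum = 0.0; count = 0L }
    def update(value: Double): Unit = { sum += value; count += 1 }
    // Bulk update over a slice of the partition, like the frame-driven updates above.
    def update(values: Seq[Double], begin: Int, end: Int): Unit = {
      var i = begin
      while (i < end) { update(values(i)); i += 1 }
    }
    def evaluate(): Double = if (count == 0) Double.NaN else sum / count
  }

  def main(args: Array[String]): Unit = {
    val proc = new AvgProcessor
    proc.initialize()
    proc.update(Vector(1.0, 2.0, 3.0, 4.0), begin = 1, end = 3) // rows 2.0 and 3.0
    println(proc.evaluate()) // 2.5
  }
}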
- private[this] val supportsHybridIterator: Boolean = { - val aggregationBufferSchema: StructType = - StructType.fromAttributes( - allAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes)) - val groupKeySchema: StructType = - StructType.fromAttributes(groupingExpressions.map(_.toAttribute)) - - val schemaSupportsUnsafe: Boolean = - UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) && - UnsafeProjection.canSupport(groupKeySchema) - - // TODO: Use the hybrid iterator for non-algebric aggregate functions. - sqlContext.conf.unsafeEnabled && schemaSupportsUnsafe && !hasNonAlgebricAggregateFunctions - } - - // We need to use sorted input if we have grouping expressions, and - // we cannot use the hybrid iterator or the hybrid is disabled. - private[this] val requiresSortedInput: Boolean = { - groupingExpressions.nonEmpty && !supportsHybridIterator - } - - override def canProcessUnsafeRows: Boolean = !hasNonAlgebricAggregateFunctions - - // If result expressions' data types are all fixed length, we generate unsafe rows - // (We have this requirement instead of check the result of UnsafeProjection.canSupport - // is because we use a mutable projection to generate the result). - override def outputsUnsafeRows: Boolean = { - // resultExpressions.map(_.dataType).forall(UnsafeRow.isFixedLength) - // TODO: Supports generating UnsafeRows. We can just re-enable the line above and fix - // any issue we get. - false - } - - override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) - - override def requiredChildDistribution: List[Distribution] = { - requiredChildDistributionExpressions match { - case Some(exprs) if exprs.length == 0 => AllTuples :: Nil - case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil - case None => UnspecifiedDistribution :: Nil - } - } - - override def requiredChildOrdering: Seq[Seq[SortOrder]] = { - if (requiresSortedInput) { - // TODO: We should not sort the input rows if they are just in reversed order. - groupingExpressions.map(SortOrder(_, Ascending)) :: Nil - } else { - Seq.fill(children.size)(Nil) - } - } - - override def outputOrdering: Seq[SortOrder] = { - if (requiresSortedInput) { - // It is possible that the child.outputOrdering starts with the required - // ordering expressions (e.g. we require [a] as the sort expression and the - // child's outputOrdering is [a, b]). We can only guarantee the output rows - // are sorted by values of groupingExpressions. - groupingExpressions.map(SortOrder(_, Ascending)) - } else { - Nil - } - } - - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { - child.execute().mapPartitions { iter => - // Because the constructor of an aggregation iterator will read at least the first row, - // we need to get the value of iter.hasNext first. - val hasInput = iter.hasNext - val useHybridIterator = - hasInput && - supportsHybridIterator && - groupingExpressions.nonEmpty - if (useHybridIterator) { - UnsafeHybridAggregationIterator.createFromInputIterator( - groupingExpressions, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, - initialInputBufferOffset, - resultExpressions, - newMutableProjection _, - child.output, - iter, - outputsUnsafeRows) - } else { - if (!hasInput && groupingExpressions.nonEmpty) { - // This is a grouped aggregate and the input iterator is empty, - // so return an empty iterator. 
- Iterator[InternalRow]() - } else { - val outputIter = SortBasedAggregationIterator.createFromInputIterator( - groupingExpressions, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, - initialInputBufferOffset, - resultExpressions, - newMutableProjection _ , - newProjection _, - child.output, - iter, - outputsUnsafeRows) - if (!hasInput && groupingExpressions.isEmpty) { - // There is no input and there is no grouping expressions. - // We need to output a single row as the output. - Iterator[InternalRow](outputIter.outputForEmptyGroupingKeyWithoutInput()) - } else { - outputIter - } - } - } - } - } - - override def simpleString: String = { - val iterator = if (supportsHybridIterator && groupingExpressions.nonEmpty) { - classOf[UnsafeHybridAggregationIterator].getSimpleName - } else { - classOf[SortBasedAggregationIterator].getSimpleName - } - - s"""NewAggregate with $iterator ${groupingExpressions} ${allAggregateExpressions}""" - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala index abca373b0c4f..0c74df0aa5fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -17,16 +17,15 @@ package org.apache.spark.sql.execution.aggregate +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.unsafe.KVIterator - -import scala.collection.mutable.ArrayBuffer /** - * The base class of [[SortBasedAggregationIterator]] and [[UnsafeHybridAggregationIterator]]. + * The base class of [[SortBasedAggregationIterator]] and [[TungstenAggregationIterator]]. * It mainly contains two parts: * 1. It initializes aggregate functions. * 2. It creates two functions, `processRow` and `generateOutput` based on [[AggregateMode]] of @@ -34,92 +33,97 @@ import scala.collection.mutable.ArrayBuffer * is used to generate result. */ abstract class AggregationIterator( - groupingKeyAttributes: Seq[Attribute], - valueAttributes: Seq[Attribute], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], - completeAggregateAttributes: Seq[Attribute], + groupingExpressions: Seq[NamedExpression], + inputAttributes: Seq[Attribute], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - outputsUnsafeRows: Boolean) - extends Iterator[InternalRow] with Logging { + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection)) + extends Iterator[UnsafeRow] with Logging { /////////////////////////////////////////////////////////////////////////// // Initializing functions. /////////////////////////////////////////////////////////////////////////// - // An Seq of all AggregateExpressions. 
- // It is important that all AggregateExpressions with the mode Partial, PartialMerge or Final - // are at the beginning of the allAggregateExpressions. - protected val allAggregateExpressions = - nonCompleteAggregateExpressions ++ completeAggregateExpressions - - require( - allAggregateExpressions.map(_.mode).distinct.length <= 2, - s"$allAggregateExpressions are not supported becuase they have more than 2 distinct modes.") - /** - * The distinct modes of AggregateExpressions. Right now, we can handle the following mode: - * - Partial-only: all AggregateExpressions have the mode of Partial; - * - PartialMerge-only: all AggregateExpressions have the mode of PartialMerge); - * - Final-only: all AggregateExpressions have the mode of Final; - * - Final-Complete: some AggregateExpressions have the mode of Final and - * others have the mode of Complete; - * - Complete-only: nonCompleteAggregateExpressions is empty and we have AggregateExpressions - * with mode Complete in completeAggregateExpressions; and - * - Grouping-only: there is no AggregateExpression. - */ - protected val aggregationMode: (Option[AggregateMode], Option[AggregateMode]) = - nonCompleteAggregateExpressions.map(_.mode).distinct.headOption -> - completeAggregateExpressions.map(_.mode).distinct.headOption + * The following combinations of AggregationMode are supported: + * - Partial + * - PartialMerge (for single distinct) + * - Partial and PartialMerge (for single distinct) + * - Final + * - Complete (for SortBasedAggregate with functions that does not support Partial) + * - Final and Complete (currently not used) + * + * TODO: AggregateMode should have only two modes: Update and Merge, AggregateExpression + * could have a flag to tell it's final or not. + */ + { + val modes = aggregateExpressions.map(_.mode).distinct.toSet + require(modes.size <= 2, + s"$aggregateExpressions are not supported because they have more than 2 distinct modes.") + require(modes.subsetOf(Set(Partial, PartialMerge)) || modes.subsetOf(Set(Final, Complete)), + s"$aggregateExpressions can't have Partial/PartialMerge and Final/Complete in the same time.") + } // Initialize all AggregateFunctions by binding references if necessary, // and set inputBufferOffset and mutableBufferOffset. 
- protected val allAggregateFunctions: Array[AggregateFunction2] = { + protected def initializeAggregateFunctions( + expressions: Seq[AggregateExpression], + startingInputBufferOffset: Int): Array[AggregateFunction] = { var mutableBufferOffset = 0 - var inputBufferOffset: Int = initialInputBufferOffset - val functions = new Array[AggregateFunction2](allAggregateExpressions.length) + var inputBufferOffset: Int = startingInputBufferOffset + val functions = new Array[AggregateFunction](expressions.length) var i = 0 - while (i < allAggregateExpressions.length) { - val func = allAggregateExpressions(i).aggregateFunction - val funcWithBoundReferences = allAggregateExpressions(i).mode match { - case Partial | Complete if !func.isInstanceOf[AlgebraicAggregate] => + while (i < expressions.length) { + val func = expressions(i).aggregateFunction + val funcWithBoundReferences: AggregateFunction = expressions(i).mode match { + case Partial | Complete if func.isInstanceOf[ImperativeAggregate] => // We need to create BoundReferences if the function is not an - // AlgebraicAggregate (it does not support code-gen) and the mode of + // expression-based aggregate function (it does not support code-gen) and the mode of // this function is Partial or Complete because we will call eval of this // function's children in the update method of this aggregate function. // Those eval calls require BoundReferences to work. - BindReferences.bindReference(func, valueAttributes) + BindReferences.bindReference(func, inputAttributes) case _ => // We only need to set inputBufferOffset for aggregate functions with mode // PartialMerge and Final. - func.withNewInputBufferOffset(inputBufferOffset) - inputBufferOffset += func.bufferSchema.length - func + val updatedFunc = func match { + case function: ImperativeAggregate => + function.withNewInputAggBufferOffset(inputBufferOffset) + case function => function + } + inputBufferOffset += func.aggBufferSchema.length + updatedFunc } - // Set mutableBufferOffset for this function. It is important that setting - // mutableBufferOffset happens after all potential bindReference operations - // because bindReference will create a new instance of the function. - funcWithBoundReferences.withNewMutableBufferOffset(mutableBufferOffset) - mutableBufferOffset += funcWithBoundReferences.bufferSchema.length - functions(i) = funcWithBoundReferences + val funcWithUpdatedAggBufferOffset = funcWithBoundReferences match { + case function: ImperativeAggregate => + // Set mutableBufferOffset for this function. It is important that setting + // mutableBufferOffset happens after all potential bindReference operations + // because bindReference will create a new instance of the function. + function.withNewMutableAggBufferOffset(mutableBufferOffset) + case function => function + } + mutableBufferOffset += funcWithUpdatedAggBufferOffset.aggBufferSchema.length + functions(i) = funcWithUpdatedAggBufferOffset i += 1 } functions } - // Positions of those non-algebraic aggregate functions in allAggregateFunctions. + protected val aggregateFunctions: Array[AggregateFunction] = + initializeAggregateFunctions(aggregateExpressions, initialInputBufferOffset) + + // Positions of those imperative aggregate functions in allAggregateFunctions. // For example, we have func1, func2, func3, func4 in aggregateFunctions, and - // func2 and func3 are non-algebraic aggregate functions. - // nonAlgebraicAggregateFunctionPositions will be [1, 2]. 
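The offset bookkeeping above gives every aggregate function a contiguous slice of one shared buffer, each slice starting where the previous function's slice ended. A small sketch of that layout calculation (hypothetical widths, plain Scala):

object BufferOffsetSketch {
  // Each function claims `bufferWidth` consecutive slots starting at its assigned offset.
  final case class FunctionLayout(name: String, bufferWidth: Int)

  def assignOffsets(functions: Seq[FunctionLayout]): Seq[(String, Int)] = {
    var offset = 0
    functions.map { f =>
      val assigned = (f.name, offset)
      offset += f.bufferWidth
      assigned
    }
  }

  def main(args: Array[String]): Unit = {
    // sum -> width 1, avg -> width 2 (sum, count), collect -> width 1
    val layout = assignOffsets(Seq(
      FunctionLayout("sum", 1), FunctionLayout("avg", 2), FunctionLayout("collect", 1)))
    println(layout) // List((sum,0), (avg,1), (collect,3))
  }
}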
- private[this] val allNonAlgebraicAggregateFunctionPositions: Array[Int] = { + // func2 and func3 are imperative aggregate functions. + // ImperativeAggregateFunctionPositions will be [1, 2]. + protected[this] val allImperativeAggregateFunctionPositions: Array[Int] = { val positions = new ArrayBuffer[Int]() var i = 0 - while (i < allAggregateFunctions.length) { - allAggregateFunctions(i) match { - case agg: AlgebraicAggregate => + while (i < aggregateFunctions.length) { + aggregateFunctions(i) match { + case agg: DeclarativeAggregate => case _ => positions += i } i += 1 @@ -127,364 +131,131 @@ abstract class AggregationIterator( positions.toArray } - // All AggregateFunctions functions with mode Partial, PartialMerge, or Final. - private[this] val nonCompleteAggregateFunctions: Array[AggregateFunction2] = - allAggregateFunctions.take(nonCompleteAggregateExpressions.length) - - // All non-algebraic aggregate functions with mode Partial, PartialMerge, or Final. - private[this] val nonCompleteNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = - nonCompleteAggregateFunctions.collect { - case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func - } - - // The projection used to initialize buffer values for all AlgebraicAggregates. - private[this] val algebraicInitialProjection = { - val initExpressions = allAggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.initialValues - case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + // The projection used to initialize buffer values for all expression-based aggregates. + protected[this] val expressionAggInitialProjection = { + val initExpressions = aggregateFunctions.flatMap { + case ae: DeclarativeAggregate => ae.initialValues + // For the positions corresponding to imperative aggregate functions, we'll use special + // no-op expressions which are ignored during projection code-generation. + case i: ImperativeAggregate => Seq.fill(i.aggBufferAttributes.length)(NoOp) } newMutableProjection(initExpressions, Nil)() } - // All non-Algebraic AggregateFunctions. - private[this] val allNonAlgebraicAggregateFunctions = - allNonAlgebraicAggregateFunctionPositions.map(allAggregateFunctions) - - /////////////////////////////////////////////////////////////////////////// - // Methods and fields used by sub-classes. - /////////////////////////////////////////////////////////////////////////// + // All imperative AggregateFunctions. + protected[this] val allImperativeAggregateFunctions: Array[ImperativeAggregate] = + allImperativeAggregateFunctionPositions + .map(aggregateFunctions) + .map(_.asInstanceOf[ImperativeAggregate]) // Initializing functions used to process a row. - protected val processRow: (MutableRow, InternalRow) => Unit = { - val rowToBeProcessed = new JoinedRow - val aggregationBufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes) - aggregationMode match { - // Partial-only - case (Some(Partial), None) => - val updateExpressions = nonCompleteAggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.updateExpressions - case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) - } - val algebraicUpdateProjection = - newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)() - - (currentBuffer: MutableRow, row: InternalRow) => { - algebraicUpdateProjection.target(currentBuffer) - // Process all algebraic aggregate functions. 
- algebraicUpdateProjection(rowToBeProcessed(currentBuffer, row)) - // Process all non-algebraic aggregate functions. - var i = 0 - while (i < nonCompleteNonAlgebraicAggregateFunctions.length) { - nonCompleteNonAlgebraicAggregateFunctions(i).update(currentBuffer, row) - i += 1 - } - } - - // PartialMerge-only or Final-only - case (Some(PartialMerge), None) | (Some(Final), None) => - val inputAggregationBufferSchema = if (initialInputBufferOffset == 0) { - // If initialInputBufferOffset, the input value does not contain - // grouping keys. - // This part is pretty hacky. - allAggregateFunctions.flatMap(_.cloneBufferAttributes).toSeq - } else { - groupingKeyAttributes ++ allAggregateFunctions.flatMap(_.cloneBufferAttributes) - } - // val inputAggregationBufferSchema = - // groupingKeyAttributes ++ - // allAggregateFunctions.flatMap(_.cloneBufferAttributes) - val mergeExpressions = nonCompleteAggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.mergeExpressions - case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) - } - // This projection is used to merge buffer values for all AlgebraicAggregates. - val algebraicMergeProjection = - newMutableProjection( - mergeExpressions, - aggregationBufferSchema ++ inputAggregationBufferSchema)() - - (currentBuffer: MutableRow, row: InternalRow) => { - // Process all algebraic aggregate functions. - algebraicMergeProjection.target(currentBuffer)(rowToBeProcessed(currentBuffer, row)) - // Process all non-algebraic aggregate functions. - var i = 0 - while (i < nonCompleteNonAlgebraicAggregateFunctions.length) { - nonCompleteNonAlgebraicAggregateFunctions(i).merge(currentBuffer, row) - i += 1 + protected def generateProcessRow( + expressions: Seq[AggregateExpression], + functions: Seq[AggregateFunction], + inputAttributes: Seq[Attribute]): (MutableRow, InternalRow) => Unit = { + val joinedRow = new JoinedRow + if (expressions.nonEmpty) { + val mergeExpressions = functions.zipWithIndex.flatMap { + case (ae: DeclarativeAggregate, i) => + expressions(i).mode match { + case Partial | Complete => ae.updateExpressions + case PartialMerge | Final => ae.mergeExpressions } - } - - // Final-Complete - case (Some(Final), Some(Complete)) => - val completeAggregateFunctions: Array[AggregateFunction2] = - allAggregateFunctions.takeRight(completeAggregateExpressions.length) - // All non-algebraic aggregate functions with mode Complete. - val completeNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = - completeAggregateFunctions.collect { - case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func - } - - // The first initialInputBufferOffset values of the input aggregation buffer is - // for grouping expressions and distinct columns. - val groupingAttributesAndDistinctColumns = valueAttributes.take(initialInputBufferOffset) - - val completeOffsetExpressions = - Seq.fill(completeAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp) - // We do not touch buffer values of aggregate functions with the Final mode. 
- val finalOffsetExpressions = - Seq.fill(nonCompleteAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp) - - val mergeInputSchema = - aggregationBufferSchema ++ - groupingAttributesAndDistinctColumns ++ - nonCompleteAggregateFunctions.flatMap(_.cloneBufferAttributes) - val mergeExpressions = - nonCompleteAggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.mergeExpressions - case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) - } ++ completeOffsetExpressions - val finalAlgebraicMergeProjection = - newMutableProjection(mergeExpressions, mergeInputSchema)() - - val updateExpressions = - finalOffsetExpressions ++ completeAggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.updateExpressions - case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) - } - val completeAlgebraicUpdateProjection = - newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)() - - (currentBuffer: MutableRow, row: InternalRow) => { - val input = rowToBeProcessed(currentBuffer, row) - // For all aggregate functions with mode Complete, update buffers. - completeAlgebraicUpdateProjection.target(currentBuffer)(input) - var i = 0 - while (i < completeNonAlgebraicAggregateFunctions.length) { - completeNonAlgebraicAggregateFunctions(i).update(currentBuffer, row) - i += 1 - } - - // For all aggregate functions with mode Final, merge buffers. - finalAlgebraicMergeProjection.target(currentBuffer)(input) - i = 0 - while (i < nonCompleteNonAlgebraicAggregateFunctions.length) { - nonCompleteNonAlgebraicAggregateFunctions(i).merge(currentBuffer, row) - i += 1 - } - } - - // Complete-only - case (None, Some(Complete)) => - val completeAggregateFunctions: Array[AggregateFunction2] = - allAggregateFunctions.takeRight(completeAggregateExpressions.length) - // All non-algebraic aggregate functions with mode Complete. - val completeNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = - completeAggregateFunctions.collect { - case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func - } - - val updateExpressions = - completeAggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.updateExpressions - case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) - } - val completeAlgebraicUpdateProjection = - newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)() - - (currentBuffer: MutableRow, row: InternalRow) => { - val input = rowToBeProcessed(currentBuffer, row) - // For all aggregate functions with mode Complete, update buffers. - completeAlgebraicUpdateProjection.target(currentBuffer)(input) - var i = 0 - while (i < completeNonAlgebraicAggregateFunctions.length) { - completeNonAlgebraicAggregateFunctions(i).update(currentBuffer, row) - i += 1 + case (agg: AggregateFunction, _) => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + } + val updateFunctions = functions.zipWithIndex.collect { + case (ae: ImperativeAggregate, i) => + expressions(i).mode match { + case Partial | Complete => + (buffer: MutableRow, row: InternalRow) => ae.update(buffer, row) + case PartialMerge | Final => + (buffer: MutableRow, row: InternalRow) => ae.merge(buffer, row) } + } + // This projection is used to merge buffer values for all expression-based aggregates. 
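The mode handling above boils down to: update a buffer from raw input rows in Partial/Complete, merge already-partial buffers in PartialMerge/Final, and evaluate only at the end. A compact sketch of that two-stage flow for an average (standalone, illustrative types, not the iterator's real projections):

object PartialMergeSketch {
  // AVG carried as a (sum, count) buffer through partial and final stages.
  final case class AvgBuffer(sum: Double, count: Long)

  // Partial/Complete: fold raw input values into the buffer.
  def updatePartial(buffer: AvgBuffer, value: Double): AvgBuffer =
    AvgBuffer(buffer.sum + value, buffer.count + 1)

  // PartialMerge/Final: combine two partial buffers.
  def mergeFinal(left: AvgBuffer, right: AvgBuffer): AvgBuffer =
    AvgBuffer(left.sum + right.sum, left.count + right.count)

  def evaluate(buffer: AvgBuffer): Double = buffer.sum / buffer.count

  def main(args: Array[String]): Unit = {
    // Two "partitions" each build a partial buffer, then the final stage merges them.
    val partial1 = Vector(1.0, 2.0).foldLeft(AvgBuffer(0, 0))(updatePartial)
    val partial2 = Vector(3.0, 4.0, 5.0).foldLeft(AvgBuffer(0, 0))(updatePartial)
    println(evaluate(mergeFinal(partial1, partial2))) // 3.0
  }
}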
+ val aggregationBufferSchema = functions.flatMap(_.aggBufferAttributes) + val updateProjection = + newMutableProjection(mergeExpressions, aggregationBufferSchema ++ inputAttributes)() + + (currentBuffer: MutableRow, row: InternalRow) => { + // Process all expression-based aggregate functions. + updateProjection.target(currentBuffer)(joinedRow(currentBuffer, row)) + // Process all imperative aggregate functions. + var i = 0 + while (i < updateFunctions.length) { + updateFunctions(i)(currentBuffer, row) + i += 1 } - + } + } else { // Grouping only. - case (None, None) => (currentBuffer: MutableRow, row: InternalRow) => {} - - case other => - sys.error( - s"Could not evaluate ${nonCompleteAggregateExpressions} because we do not " + - s"support evaluate modes $other in this iterator.") + (currentBuffer: MutableRow, row: InternalRow) => {} } } - // Initializing the function used to generate the output row. - protected val generateOutput: (InternalRow, MutableRow) => InternalRow = { - val rowToBeEvaluated = new JoinedRow - val safeOutoutRow = new GenericMutableRow(resultExpressions.length) - val mutableOutput = if (outputsUnsafeRows) { - UnsafeProjection.create(resultExpressions.map(_.dataType).toArray).apply(safeOutoutRow) - } else { - safeOutoutRow - } - - aggregationMode match { - // Partial-only or PartialMerge-only: every output row is basically the values of - // the grouping expressions and the corresponding aggregation buffer. - case (Some(Partial), None) | (Some(PartialMerge), None) => - // Because we cannot copy a joinedRow containing a UnsafeRow (UnsafeRow does not - // support generic getter), we create a mutable projection to output the - // JoinedRow(currentGroupingKey, currentBuffer) - val bufferSchema = nonCompleteAggregateFunctions.flatMap(_.bufferAttributes) - val resultProjection = - newMutableProjection( - groupingKeyAttributes ++ bufferSchema, - groupingKeyAttributes ++ bufferSchema)() - resultProjection.target(mutableOutput) + protected val processRow: (MutableRow, InternalRow) => Unit = + generateProcessRow(aggregateExpressions, aggregateFunctions, inputAttributes) - (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => { - resultProjection(rowToBeEvaluated(currentGroupingKey, currentBuffer)) - // rowToBeEvaluated(currentGroupingKey, currentBuffer) - } + protected val groupingProjection: UnsafeProjection = + UnsafeProjection.create(groupingExpressions, inputAttributes) + protected val groupingAttributes = groupingExpressions.map(_.toAttribute) - // Final-only, Complete-only and Final-Complete: every output row contains values representing - // resultExpressions. - case (Some(Final), None) | (Some(Final) | None, Some(Complete)) => - val bufferSchemata = - allAggregateFunctions.flatMap(_.bufferAttributes) - val evalExpressions = allAggregateFunctions.map { - case ae: AlgebraicAggregate => ae.evaluateExpression - case agg: AggregateFunction2 => NoOp - } - val algebraicEvalProjection = newMutableProjection(evalExpressions, bufferSchemata)() - val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes - // TODO: Use unsafe row. - val aggregateResult = new GenericMutableRow(aggregateResultSchema.length) - val resultProjection = - newMutableProjection( - resultExpressions, groupingKeyAttributes ++ aggregateResultSchema)() - resultProjection.target(mutableOutput) - - (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => { - // Generate results for all algebraic aggregate functions. 
- algebraicEvalProjection.target(aggregateResult)(currentBuffer) - // Generate results for all non-algebraic aggregate functions. - var i = 0 - while (i < allNonAlgebraicAggregateFunctions.length) { - aggregateResult.update( - allNonAlgebraicAggregateFunctionPositions(i), - allNonAlgebraicAggregateFunctions(i).eval(currentBuffer)) - i += 1 - } - resultProjection(rowToBeEvaluated(currentGroupingKey, aggregateResult)) + // Initializing the function used to generate the output row. + protected def generateResultProjection(): (UnsafeRow, MutableRow) => UnsafeRow = { + val joinedRow = new JoinedRow + val modes = aggregateExpressions.map(_.mode).distinct + val bufferAttributes = aggregateFunctions.flatMap(_.aggBufferAttributes) + if (modes.contains(Final) || modes.contains(Complete)) { + val evalExpressions = aggregateFunctions.map { + case ae: DeclarativeAggregate => ae.evaluateExpression + case agg: AggregateFunction => NoOp + } + val aggregateResult = new SpecificMutableRow(aggregateAttributes.map(_.dataType)) + val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferAttributes)() + expressionAggEvalProjection.target(aggregateResult) + + val resultProjection = + UnsafeProjection.create(resultExpressions, groupingAttributes ++ aggregateAttributes) + + (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => { + // Generate results for all expression-based aggregate functions. + expressionAggEvalProjection(currentBuffer) + // Generate results for all imperative aggregate functions. + var i = 0 + while (i < allImperativeAggregateFunctions.length) { + aggregateResult.update( + allImperativeAggregateFunctionPositions(i), + allImperativeAggregateFunctions(i).eval(currentBuffer)) + i += 1 } - + resultProjection(joinedRow(currentGroupingKey, aggregateResult)) + } + } else if (modes.contains(Partial) || modes.contains(PartialMerge)) { + val resultProjection = UnsafeProjection.create( + groupingAttributes ++ bufferAttributes, + groupingAttributes ++ bufferAttributes) + (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => { + resultProjection(joinedRow(currentGroupingKey, currentBuffer)) + } + } else { // Grouping-only: we only output values of grouping expressions. - case (None, None) => - val resultProjection = - newMutableProjection(resultExpressions, groupingKeyAttributes)() - resultProjection.target(mutableOutput) - - (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => { - resultProjection(currentGroupingKey) - } - - case other => - sys.error( - s"Could not evaluate ${nonCompleteAggregateExpressions} because we do not " + - s"support evaluate modes $other in this iterator.") + val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes) + (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => { + resultProjection(currentGroupingKey) + } } } + protected val generateOutput: (UnsafeRow, MutableRow) => UnsafeRow = + generateResultProjection() + /** Initializes buffer values for all aggregate functions. */ protected def initializeBuffer(buffer: MutableRow): Unit = { - algebraicInitialProjection.target(buffer)(EmptyRow) + expressionAggInitialProjection.target(buffer)(EmptyRow) var i = 0 - while (i < allNonAlgebraicAggregateFunctions.length) { - allNonAlgebraicAggregateFunctions(i).initialize(buffer) + while (i < allImperativeAggregateFunctions.length) { + allImperativeAggregateFunctions(i).initialize(buffer) i += 1 } } - - /** - * Creates a new aggregation buffer and initializes buffer values - * for all aggregate functions. 
- */ - protected def newBuffer: MutableRow -} - -object AggregationIterator { - def kvIterator( - groupingExpressions: Seq[NamedExpression], - newProjection: (Seq[Expression], Seq[Attribute]) => Projection, - inputAttributes: Seq[Attribute], - inputIter: Iterator[InternalRow]): KVIterator[InternalRow, InternalRow] = { - new KVIterator[InternalRow, InternalRow] { - private[this] val groupingKeyGenerator = newProjection(groupingExpressions, inputAttributes) - - private[this] var groupingKey: InternalRow = _ - - private[this] var value: InternalRow = _ - - override def next(): Boolean = { - if (inputIter.hasNext) { - // Read the next input row. - val inputRow = inputIter.next() - // Get groupingKey based on groupingExpressions. - groupingKey = groupingKeyGenerator(inputRow) - // The value is the inputRow. - value = inputRow - true - } else { - false - } - } - - override def getKey(): InternalRow = { - groupingKey - } - - override def getValue(): InternalRow = { - value - } - - override def close(): Unit = { - // Do nothing - } - } - } - - def unsafeKVIterator( - groupingExpressions: Seq[NamedExpression], - inputAttributes: Seq[Attribute], - inputIter: Iterator[InternalRow]): KVIterator[UnsafeRow, InternalRow] = { - new KVIterator[UnsafeRow, InternalRow] { - private[this] val groupingKeyGenerator = - UnsafeProjection.create(groupingExpressions, inputAttributes) - - private[this] var groupingKey: UnsafeRow = _ - - private[this] var value: InternalRow = _ - - override def next(): Boolean = { - if (inputIter.hasNext) { - // Read the next input row. - val inputRow = inputIter.next() - // Get groupingKey based on groupingExpressions. - groupingKey = groupingKeyGenerator.apply(inputRow) - // The value is the inputRow. - value = inputRow - true - } else { - false - } - } - - override def getKey(): UnsafeRow = { - groupingKey - } - - override def getValue(): InternalRow = { - value - } - - override def close(): Unit = { - // Do nothing - } - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala new file mode 100644 index 000000000000..c5470a6989de --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution} +import org.apache.spark.sql.execution.{SparkPlan, UnaryNode} +import org.apache.spark.sql.execution.metric.SQLMetrics + +case class SortBasedAggregate( + requiredChildDistributionExpressions: Option[Seq[Expression]], + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + child: SparkPlan) + extends UnaryNode { + + override private[sql] lazy val metrics = Map( + "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + + override def outputsUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = false + override def canProcessSafeRows: Boolean = true + + override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + + override def requiredChildDistribution: List[Distribution] = { + requiredChildDistributionExpressions match { + case Some(exprs) if exprs.length == 0 => AllTuples :: Nil + case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil + case None => UnspecifiedDistribution :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = { + groupingExpressions.map(SortOrder(_, Ascending)) :: Nil + } + + override def outputOrdering: Seq[SortOrder] = { + groupingExpressions.map(SortOrder(_, Ascending)) + } + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + val numInputRows = longMetric("numInputRows") + val numOutputRows = longMetric("numOutputRows") + child.execute().mapPartitionsInternal { iter => + // Because the constructor of an aggregation iterator will read at least the first row, + // we need to get the value of iter.hasNext first. + val hasInput = iter.hasNext + if (!hasInput && groupingExpressions.nonEmpty) { + // This is a grouped aggregate and the input iterator is empty, + // so return an empty iterator. + Iterator[UnsafeRow]() + } else { + val outputIter = new SortBasedAggregationIterator( + groupingExpressions, + child.output, + iter, + aggregateExpressions, + aggregateAttributes, + initialInputBufferOffset, + resultExpressions, + newMutableProjection, + numInputRows, + numOutputRows) + if (!hasInput && groupingExpressions.isEmpty) { + // There is no input and there is no grouping expressions. + // We need to output a single row as the output. 
+ numOutputRows += 1 + Iterator[UnsafeRow](outputIter.outputForEmptyGroupingKeyWithoutInput()) + } else { + outputIter + } + } + } + } + + override def simpleString: String = { + val allAggregateExpressions = aggregateExpressions + + val keyString = groupingExpressions.mkString("[", ",", "]") + val functionString = allAggregateExpressions.mkString("[", ",", "]") + val outputString = output.mkString("[", ",", "]") + s"SortBasedAggregate(key=$keyString, functions=$functionString, output=$outputString)" + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index 78bcee16c9d0..ac920aa8bc7f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -19,45 +19,43 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, AggregateFunction2} -import org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap -import org.apache.spark.sql.types.StructType -import org.apache.spark.unsafe.KVIterator +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction} +import org.apache.spark.sql.execution.metric.LongSQLMetric /** - * An iterator used to evaluate [[AggregateFunction2]]. It assumes the input rows have been - * sorted by values of [[groupingKeyAttributes]]. + * An iterator used to evaluate [[AggregateFunction]]. It assumes the input rows have been + * sorted by values of [[groupingExpressions]]. */ class SortBasedAggregationIterator( - groupingKeyAttributes: Seq[Attribute], + groupingExpressions: Seq[NamedExpression], valueAttributes: Seq[Attribute], - inputKVIterator: KVIterator[InternalRow, InternalRow], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], - completeAggregateAttributes: Seq[Attribute], + inputIterator: Iterator[InternalRow], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - outputsUnsafeRows: Boolean) + numInputRows: LongSQLMetric, + numOutputRows: LongSQLMetric) extends AggregationIterator( - groupingKeyAttributes, + groupingExpressions, valueAttributes, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, + aggregateExpressions, + aggregateAttributes, initialInputBufferOffset, resultExpressions, - newMutableProjection, - outputsUnsafeRows) { - - override protected def newBuffer: MutableRow = { - val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes) + newMutableProjection) { + + /** + * Creates a new aggregation buffer and initializes buffer values + * for all aggregate functions. 
+ */ + private def newBuffer: MutableRow = { + val bufferSchema = aggregateFunctions.flatMap(_.aggBufferAttributes) val bufferRowSize: Int = bufferSchema.length val genericMutableBuffer = new GenericMutableRow(bufferRowSize) - val useUnsafeBuffer = bufferSchema.map(_.dataType).forall(UnsafeRow.isFixedLength) + val useUnsafeBuffer = bufferSchema.map(_.dataType).forall(UnsafeRow.isMutable) val buffer = if (useUnsafeBuffer) { val unsafeProjection = @@ -75,10 +73,10 @@ class SortBasedAggregationIterator( /////////////////////////////////////////////////////////////////////////// // The partition key of the current partition. - private[this] var currentGroupingKey: InternalRow = _ + private[this] var currentGroupingKey: UnsafeRow = _ // The partition key of next partition. - private[this] var nextGroupingKey: InternalRow = _ + private[this] var nextGroupingKey: UnsafeRow = _ // The first row of next partition. private[this] var firstRowInNextGroup: InternalRow = _ @@ -89,6 +87,22 @@ class SortBasedAggregationIterator( // The aggregation buffer used by the sort-based aggregation. private[this] val sortBasedAggregationBuffer: MutableRow = newBuffer + protected def initialize(): Unit = { + if (inputIterator.hasNext) { + initializeBuffer(sortBasedAggregationBuffer) + val inputRow = inputIterator.next() + nextGroupingKey = groupingProjection(inputRow).copy() + firstRowInNextGroup = inputRow.copy() + numInputRows += 1 + sortedInputHasNewGroup = true + } else { + // This inputIter is empty. + sortedInputHasNewGroup = false + } + } + + initialize() + /** Processes rows in the current group. It will stop when it find a new group. */ protected def processCurrentSortedGroup(): Unit = { currentGroupingKey = nextGroupingKey @@ -100,17 +114,15 @@ class SortBasedAggregationIterator( // The search will stop when we see the next group or there is no // input row left in the iter. - var hasNext = inputKVIterator.next() - while (!findNextPartition && hasNext) { + while (!findNextPartition && inputIterator.hasNext) { // Get the grouping key. - val groupingKey = inputKVIterator.getKey - val currentRow = inputKVIterator.getValue + val currentRow = inputIterator.next() + val groupingKey = groupingProjection(currentRow) + numInputRows += 1 // Check if the current row belongs the current input row. if (currentGroupingKey == groupingKey) { processRow(sortBasedAggregationBuffer, currentRow) - - hasNext = inputKVIterator.next() } else { // We find a new group. findNextPartition = true @@ -131,7 +143,7 @@ class SortBasedAggregationIterator( override final def hasNext: Boolean = sortedInputHasNewGroup - override final def next(): InternalRow = { + override final def next(): UnsafeRow = { if (hasNext) { // Process the current group. processCurrentSortedGroup() @@ -139,7 +151,7 @@ class SortBasedAggregationIterator( val outputRow = generateOutput(currentGroupingKey, sortBasedAggregationBuffer) // Initialize buffer values for the next group. initializeBuffer(sortBasedAggregationBuffer) - + numOutputRows += 1 outputRow } else { // no more result @@ -147,90 +159,8 @@ class SortBasedAggregationIterator( } } - protected def initialize(): Unit = { - if (inputKVIterator.next()) { - initializeBuffer(sortBasedAggregationBuffer) - - nextGroupingKey = inputKVIterator.getKey().copy() - firstRowInNextGroup = inputKVIterator.getValue().copy() - - sortedInputHasNewGroup = true - } else { - // This inputIter is empty. 
- sortedInputHasNewGroup = false - } - } - - initialize() - - def outputForEmptyGroupingKeyWithoutInput(): InternalRow = { + def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = { initializeBuffer(sortBasedAggregationBuffer) - generateOutput(new GenericInternalRow(0), sortBasedAggregationBuffer) - } -} - -object SortBasedAggregationIterator { - // scalastyle:off - def createFromInputIterator( - groupingExprs: Seq[NamedExpression], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], - completeAggregateAttributes: Seq[Attribute], - initialInputBufferOffset: Int, - resultExpressions: Seq[NamedExpression], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - newProjection: (Seq[Expression], Seq[Attribute]) => Projection, - inputAttributes: Seq[Attribute], - inputIter: Iterator[InternalRow], - outputsUnsafeRows: Boolean): SortBasedAggregationIterator = { - val kvIterator = if (UnsafeProjection.canSupport(groupingExprs)) { - AggregationIterator.unsafeKVIterator( - groupingExprs, - inputAttributes, - inputIter).asInstanceOf[KVIterator[InternalRow, InternalRow]] - } else { - AggregationIterator.kvIterator(groupingExprs, newProjection, inputAttributes, inputIter) - } - - new SortBasedAggregationIterator( - groupingExprs.map(_.toAttribute), - inputAttributes, - kvIterator, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, - initialInputBufferOffset, - resultExpressions, - newMutableProjection, - outputsUnsafeRows) - } - - def createFromKVIterator( - groupingKeyAttributes: Seq[Attribute], - valueAttributes: Seq[Attribute], - inputKVIterator: KVIterator[InternalRow, InternalRow], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], - completeAggregateAttributes: Seq[Attribute], - initialInputBufferOffset: Int, - resultExpressions: Seq[NamedExpression], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - outputsUnsafeRows: Boolean): SortBasedAggregationIterator = { - new SortBasedAggregationIterator( - groupingKeyAttributes, - valueAttributes, - inputKVIterator, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, - initialInputBufferOffset, - resultExpressions, - newMutableProjection, - outputsUnsafeRows) + generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer) } - // scalastyle:on } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala new file mode 100644 index 000000000000..b8849c827048 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.{SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap} +import org.apache.spark.sql.types.StructType + +case class TungstenAggregate( + requiredChildDistributionExpressions: Option[Seq[Expression]], + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + child: SparkPlan) + extends UnaryNode { + + private[this] val aggregateBufferAttributes = { + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) + } + + require(TungstenAggregate.supportsAggregate(aggregateBufferAttributes)) + + override private[sql] lazy val metrics = Map( + "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"), + "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"), + "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size")) + + override def outputsUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true + + override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + + override def requiredChildDistribution: List[Distribution] = { + requiredChildDistributionExpressions match { + case Some(exprs) if exprs.length == 0 => AllTuples :: Nil + case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil + case None => UnspecifiedDistribution :: Nil + } + } + + // This is for testing. We force TungstenAggregationIterator to fall back to sort-based + // aggregation once it has processed a given number of input rows. + private val testFallbackStartsAt: Option[Int] = { + sqlContext.getConf("spark.sql.TungstenAggregate.testFallbackStartsAt", null) match { + case null | "" => None + case fallbackStartsAt => Some(fallbackStartsAt.toInt) + } + } + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + val numInputRows = longMetric("numInputRows") + val numOutputRows = longMetric("numOutputRows") + val dataSize = longMetric("dataSize") + val spillSize = longMetric("spillSize") + + child.execute().mapPartitions { iter => + + val hasInput = iter.hasNext + if (!hasInput && groupingExpressions.nonEmpty) { + // This is a grouped aggregate and the input iterator is empty, + // so return an empty iterator. 
+ Iterator.empty + } else { + val aggregationIterator = + new TungstenAggregationIterator( + groupingExpressions, + aggregateExpressions, + aggregateAttributes, + initialInputBufferOffset, + resultExpressions, + newMutableProjection, + child.output, + iter, + testFallbackStartsAt, + numInputRows, + numOutputRows, + dataSize, + spillSize) + if (!hasInput && groupingExpressions.isEmpty) { + numOutputRows += 1 + Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput()) + } else { + aggregationIterator + } + } + } + } + + override def simpleString: String = { + val allAggregateExpressions = aggregateExpressions + + testFallbackStartsAt match { + case None => + val keyString = groupingExpressions.mkString("[", ",", "]") + val functionString = allAggregateExpressions.mkString("[", ",", "]") + val outputString = output.mkString("[", ",", "]") + s"TungstenAggregate(key=$keyString, functions=$functionString, output=$outputString)" + case Some(fallbackStartsAt) => + s"TungstenAggregateWithControlledFallback $groupingExpressions " + + s"$allAggregateExpressions $resultExpressions fallbackStartsAt=$fallbackStartsAt" + } + } +} + +object TungstenAggregate { + def supportsAggregate(aggregateBufferAttributes: Seq[Attribute]): Boolean = { + val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes) + UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala new file mode 100644 index 000000000000..582fdbe54706 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -0,0 +1,454 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner +import org.apache.spark.sql.execution.metric.LongSQLMetric +import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, UnsafeKVExternalSorter} +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.KVIterator +import org.apache.spark.{InternalAccumulator, Logging, TaskContext} + +/** + * An iterator used to evaluate aggregate functions. It operates on [[UnsafeRow]]s. + * + * This iterator first uses hash-based aggregation to process input rows. It uses + * a hash map to store groups and their corresponding aggregation buffers. 
If we + * this map cannot allocate memory from memory manager, it spill the map into disk + * and create a new one. After processed all the input, then merge all the spills + * together using external sorter, and do sort-based aggregation. + * + * The process has the following step: + * - Step 0: Do hash-based aggregation. + * - Step 1: Sort all entries of the hash map based on values of grouping expressions and + * spill them to disk. + * - Step 2: Create a external sorter based on the spilled sorted map entries and reset the map. + * - Step 3: Get a sorted [[KVIterator]] from the external sorter. + * - Step 4: Repeat step 0 until no more input. + * - Step 5: Initialize sort-based aggregation on the sorted iterator. + * Then, this iterator works in the way of sort-based aggregation. + * + * The code of this class is organized as follows: + * - Part 1: Initializing aggregate functions. + * - Part 2: Methods and fields used by setting aggregation buffer values, + * processing input rows from inputIter, and generating output + * rows. + * - Part 3: Methods and fields used by hash-based aggregation. + * - Part 4: Methods and fields used when we switch to sort-based aggregation. + * - Part 5: Methods and fields used by sort-based aggregation. + * - Part 6: Loads input and process input rows. + * - Part 7: Public methods of this iterator. + * - Part 8: A utility function used to generate a result when there is no + * input and there is no grouping expression. + * + * @param groupingExpressions + * expressions for grouping keys + * @param aggregateExpressions + * [[AggregateExpression]] containing [[AggregateFunction]]s with mode [[Partial]], + * [[PartialMerge]], or [[Final]]. + * @param aggregateAttributes the attributes of the aggregateExpressions' + * outputs when they are stored in the final aggregation buffer. + * @param resultExpressions + * expressions for generating output rows. + * @param newMutableProjection + * the function used to create mutable projections. + * @param originalInputAttributes + * attributes of representing input rows from `inputIter`. + * @param inputIter + * the iterator containing input [[UnsafeRow]]s. + */ +class TungstenAggregationIterator( + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + originalInputAttributes: Seq[Attribute], + inputIter: Iterator[InternalRow], + testFallbackStartsAt: Option[Int], + numInputRows: LongSQLMetric, + numOutputRows: LongSQLMetric, + dataSize: LongSQLMetric, + spillSize: LongSQLMetric) + extends AggregationIterator( + groupingExpressions, + originalInputAttributes, + aggregateExpressions, + aggregateAttributes, + initialInputBufferOffset, + resultExpressions, + newMutableProjection) with Logging { + + /////////////////////////////////////////////////////////////////////////// + // Part 1: Initializing aggregate functions. + /////////////////////////////////////////////////////////////////////////// + + // Remember spill data size of this task before execute this operator so that we can + // figure out how many bytes we spilled for this operator. 
+ private val spillSizeBefore = TaskContext.get().taskMetrics().memoryBytesSpilled + + /////////////////////////////////////////////////////////////////////////// + // Part 2: Methods and fields used by setting aggregation buffer values, + // processing input rows from inputIter, and generating output + // rows. + /////////////////////////////////////////////////////////////////////////// + + // Creates a new aggregation buffer and initializes buffer values. + // This function should be only called at most two times (when we create the hash map, + // and when we create the re-used buffer for sort-based aggregation). + private def createNewAggregationBuffer(): UnsafeRow = { + val bufferSchema = aggregateFunctions.flatMap(_.aggBufferAttributes) + val buffer: UnsafeRow = UnsafeProjection.create(bufferSchema.map(_.dataType)) + .apply(new GenericMutableRow(bufferSchema.length)) + // Initialize declarative aggregates' buffer values + expressionAggInitialProjection.target(buffer)(EmptyRow) + // Initialize imperative aggregates' buffer values + aggregateFunctions.collect { case f: ImperativeAggregate => f }.foreach(_.initialize(buffer)) + buffer + } + + // Creates a function used to generate output rows. + override protected def generateResultProjection(): (UnsafeRow, MutableRow) => UnsafeRow = { + val modes = aggregateExpressions.map(_.mode).distinct + if (modes.nonEmpty && !modes.contains(Final) && !modes.contains(Complete)) { + // Fast path for partial aggregation, UnsafeRowJoiner is usually faster than projection + val groupingAttributes = groupingExpressions.map(_.toAttribute) + val bufferAttributes = aggregateFunctions.flatMap(_.aggBufferAttributes) + val groupingKeySchema = StructType.fromAttributes(groupingAttributes) + val bufferSchema = StructType.fromAttributes(bufferAttributes) + val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) + + (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => { + unsafeRowJoiner.join(currentGroupingKey, currentBuffer.asInstanceOf[UnsafeRow]) + } + } else { + super.generateResultProjection() + } + } + + // An aggregation buffer containing initial buffer values. It is used to + // initialize other aggregation buffers. + private[this] val initialAggregationBuffer: UnsafeRow = createNewAggregationBuffer() + + /////////////////////////////////////////////////////////////////////////// + // Part 3: Methods and fields used by hash-based aggregation. + /////////////////////////////////////////////////////////////////////////// + + // This is the hash map used for hash-based aggregation. It is backed by an + // UnsafeFixedWidthAggregationMap and it is used to store + // all groups and their corresponding aggregation buffers for hash-based aggregation. + private[this] val hashMap = new UnsafeFixedWidthAggregationMap( + initialAggregationBuffer, + StructType.fromAttributes(aggregateFunctions.flatMap(_.aggBufferAttributes)), + StructType.fromAttributes(groupingExpressions.map(_.toAttribute)), + TaskContext.get().taskMemoryManager(), + 1024 * 16, // initial capacity + TaskContext.get().taskMemoryManager().pageSizeBytes, + false // disable tracking of performance metrics + ) + + // The function used to read and process input rows. When processing input rows, + // it first uses hash-based aggregation by putting groups and their buffers in + // hashMap. If there is not enough memory, it will multiple hash-maps, spilling + // after each becomes full then using sort to merge these spills, finally do sort + // based aggregation. 
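// [Editor's illustration, not part of this patch] A minimal, self-contained sketch of the
// hash-then-sort fallback pattern that processInputs implements below. All names here
// (HashWithSortFallbackSketch, maxHashGroups, the (String, Long) row shape) are invented
// for illustration; the real code uses UnsafeFixedWidthAggregationMap and
// UnsafeKVExternalSorter, and a bounded map size stands in for running out of memory.
object HashWithSortFallbackSketch {
  import scala.collection.mutable

  /** Sums values per key, spilling sorted partial results whenever the hash map is "full". */
  def aggregate(input: Iterator[(String, Long)], maxHashGroups: Int): Seq[(String, Long)] = {
    val hashMap = mutable.HashMap.empty[String, Long]
    val spills = mutable.ArrayBuffer.empty[(String, Long)]

    input.foreach { case (key, value) =>
      if (hashMap.contains(key) || hashMap.size < maxHashGroups) {
        // Hash-based path: update the per-group buffer in place.
        hashMap(key) = hashMap.getOrElse(key, 0L) + value
      } else {
        // "Spill": emit the current partial aggregates in key order and start over with an
        // empty map, standing in for destructAndCreateExternalSorter() plus a fresh hash map.
        spills ++= hashMap.toSeq.sortBy(_._1)
        hashMap.clear()
        hashMap(key) = value
      }
    }
    spills ++= hashMap.toSeq.sortBy(_._1)

    // Sort-based path: once all partial results are sorted by key, consecutive entries with
    // the same key form one group and their partial buffers are merged, mirroring the
    // sort-based part of this iterator.
    val merged = mutable.LinkedHashMap.empty[String, Long]
    spills.sortBy(_._1).foreach { case (key, partial) =>
      merged(key) = merged.getOrElse(key, 0L) + partial
    }
    merged.toSeq
  }
}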
+ private def processInputs(fallbackStartsAt: Int): Unit = { + if (groupingExpressions.isEmpty) { + // If there is no grouping expressions, we can just reuse the same buffer over and over again. + // Note that it would be better to eliminate the hash map entirely in the future. + val groupingKey = groupingProjection.apply(null) + val buffer: UnsafeRow = hashMap.getAggregationBufferFromUnsafeRow(groupingKey) + while (inputIter.hasNext) { + val newInput = inputIter.next() + numInputRows += 1 + processRow(buffer, newInput) + } + } else { + var i = 0 + while (inputIter.hasNext) { + val newInput = inputIter.next() + numInputRows += 1 + val groupingKey = groupingProjection.apply(newInput) + var buffer: UnsafeRow = null + if (i < fallbackStartsAt) { + buffer = hashMap.getAggregationBufferFromUnsafeRow(groupingKey) + } + if (buffer == null) { + val sorter = hashMap.destructAndCreateExternalSorter() + if (externalSorter == null) { + externalSorter = sorter + } else { + externalSorter.merge(sorter) + } + i = 0 + buffer = hashMap.getAggregationBufferFromUnsafeRow(groupingKey) + if (buffer == null) { + // failed to allocate the first page + throw new OutOfMemoryError("No enough memory for aggregation") + } + } + processRow(buffer, newInput) + i += 1 + } + + if (externalSorter != null) { + val sorter = hashMap.destructAndCreateExternalSorter() + externalSorter.merge(sorter) + hashMap.free() + + switchToSortBasedAggregation() + } + } + } + + // The iterator created from hashMap. It is used to generate output rows when we + // are using hash-based aggregation. + private[this] var aggregationBufferMapIterator: KVIterator[UnsafeRow, UnsafeRow] = null + + // Indicates if aggregationBufferMapIterator still has key-value pairs. + private[this] var mapIteratorHasNext: Boolean = false + + /////////////////////////////////////////////////////////////////////////// + // Part 4: Methods and fields used when we switch to sort-based aggregation. + /////////////////////////////////////////////////////////////////////////// + + // This sorter is used for sort-based aggregation. It is initialized as soon as + // we switch from hash-based to sort-based aggregation. Otherwise, it is not used. + private[this] var externalSorter: UnsafeKVExternalSorter = null + + /** + * Switch to sort-based aggregation when the hash-based approach is unable to acquire memory. + */ + private def switchToSortBasedAggregation(): Unit = { + logInfo("falling back to sort based aggregation.") + + // Basically the value of the KVIterator returned by externalSorter + // will be just aggregation buffer, so we rewrite the aggregateExpressions to reflect it. + val newExpressions = aggregateExpressions.map { + case agg @ AggregateExpression(_, Partial, _) => + agg.copy(mode = PartialMerge) + case agg @ AggregateExpression(_, Complete, _) => + agg.copy(mode = Final) + case other => other + } + val newFunctions = initializeAggregateFunctions(newExpressions, 0) + val newInputAttributes = newFunctions.flatMap(_.inputAggBufferAttributes) + sortBasedProcessRow = generateProcessRow(newExpressions, newFunctions, newInputAttributes) + + // Step 5: Get the sorted iterator from the externalSorter. + sortedKVIterator = externalSorter.sortedIterator() + + // Step 6: Pre-load the first key-value pair from the sorted iterator to make + // hasNext idempotent. + sortedInputHasNewGroup = sortedKVIterator.next() + + // Copy the first key and value (aggregation buffer). 
+ if (sortedInputHasNewGroup) { + val key = sortedKVIterator.getKey + val value = sortedKVIterator.getValue + nextGroupingKey = key.copy() + currentGroupingKey = key.copy() + firstRowInNextGroup = value.copy() + } + + // Step 7: set sortBased to true. + sortBased = true + } + + /////////////////////////////////////////////////////////////////////////// + // Part 5: Methods and fields used by sort-based aggregation. + /////////////////////////////////////////////////////////////////////////// + + // Indicates if we are using sort-based aggregation. Because we first try to use + // hash-based aggregation, its initial value is false. + private[this] var sortBased: Boolean = false + + // The KVIterator containing input rows for the sort-based aggregation. It will be + // set in switchToSortBasedAggregation when we switch to sort-based aggregation. + private[this] var sortedKVIterator: UnsafeKVExternalSorter#KVSorterIterator = null + + // The grouping key of the current group. + private[this] var currentGroupingKey: UnsafeRow = null + + // The grouping key of next group. + private[this] var nextGroupingKey: UnsafeRow = null + + // The first row of next group. + private[this] var firstRowInNextGroup: UnsafeRow = null + + // Indicates if we has new group of rows from the sorted input iterator. + private[this] var sortedInputHasNewGroup: Boolean = false + + // The aggregation buffer used by the sort-based aggregation. + private[this] val sortBasedAggregationBuffer: UnsafeRow = createNewAggregationBuffer() + + // The function used to process rows in a group + private[this] var sortBasedProcessRow: (MutableRow, InternalRow) => Unit = null + + // Processes rows in the current group. It will stop when it find a new group. + private def processCurrentSortedGroup(): Unit = { + // First, we need to copy nextGroupingKey to currentGroupingKey. + currentGroupingKey.copyFrom(nextGroupingKey) + // Now, we will start to find all rows belonging to this group. + // We create a variable to track if we see the next group. + var findNextPartition = false + // firstRowInNextGroup is the first row of this group. We first process it. + sortBasedProcessRow(sortBasedAggregationBuffer, firstRowInNextGroup) + + // The search will stop when we see the next group or there is no + // input row left in the iter. + // Pre-load the first key-value pair to make the condition of the while loop + // has no action (we do not trigger loading a new key-value pair + // when we evaluate the condition). + var hasNext = sortedKVIterator.next() + while (!findNextPartition && hasNext) { + // Get the grouping key and value (aggregation buffer). + val groupingKey = sortedKVIterator.getKey + val inputAggregationBuffer = sortedKVIterator.getValue + + // Check if the current row belongs the current input row. + if (currentGroupingKey.equals(groupingKey)) { + sortBasedProcessRow(sortBasedAggregationBuffer, inputAggregationBuffer) + + hasNext = sortedKVIterator.next() + } else { + // We find a new group. + findNextPartition = true + // copyFrom will fail when + nextGroupingKey.copyFrom(groupingKey) + firstRowInNextGroup.copyFrom(inputAggregationBuffer) + } + } + // We have not seen a new group. It means that there is no new row in the input + // iter. The current group is the last group of the sortedKVIterator. 
+ if (!findNextPartition) { + sortedInputHasNewGroup = false + sortedKVIterator.close() + } + } + + /////////////////////////////////////////////////////////////////////////// + // Part 6: Loads input rows and setup aggregationBufferMapIterator if we + // have not switched to sort-based aggregation. + /////////////////////////////////////////////////////////////////////////// + + /** + * Start processing input rows. + */ + processInputs(testFallbackStartsAt.getOrElse(Int.MaxValue)) + + // If we did not switch to sort-based aggregation in processInputs, + // we pre-load the first key-value pair from the map (to make hasNext idempotent). + if (!sortBased) { + // First, set aggregationBufferMapIterator. + aggregationBufferMapIterator = hashMap.iterator() + // Pre-load the first key-value pair from the aggregationBufferMapIterator. + mapIteratorHasNext = aggregationBufferMapIterator.next() + // If the map is empty, we just free it. + if (!mapIteratorHasNext) { + hashMap.free() + } + } + + /////////////////////////////////////////////////////////////////////////// + // Part 7: Iterator's public methods. + /////////////////////////////////////////////////////////////////////////// + + override final def hasNext: Boolean = { + (sortBased && sortedInputHasNewGroup) || (!sortBased && mapIteratorHasNext) + } + + override final def next(): UnsafeRow = { + if (hasNext) { + val res = if (sortBased) { + // Process the current group. + processCurrentSortedGroup() + // Generate output row for the current group. + val outputRow = generateOutput(currentGroupingKey, sortBasedAggregationBuffer) + // Initialize buffer values for the next group. + sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer) + + outputRow + } else { + // We did not fall back to sort-based aggregation. + val result = + generateOutput( + aggregationBufferMapIterator.getKey, + aggregationBufferMapIterator.getValue) + + // Pre-load next key-value pair form aggregationBufferMapIterator to make hasNext + // idempotent. + mapIteratorHasNext = aggregationBufferMapIterator.next() + + if (!mapIteratorHasNext) { + // If there is no input from aggregationBufferMapIterator, we copy current result. + val resultCopy = result.copy() + // Then, we free the map. + hashMap.free() + + resultCopy + } else { + result + } + } + + // If this is the last record, update the task's peak memory usage. Since we destroy + // the map to create the sorter, their memory usages should not overlap, so it is safe + // to just use the max of the two. + if (!hasNext) { + val mapMemory = hashMap.getPeakMemoryUsedBytes + val sorterMemory = Option(externalSorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L) + val peakMemory = Math.max(mapMemory, sorterMemory) + dataSize += peakMemory + spillSize += TaskContext.get().taskMetrics().memoryBytesSpilled - spillSizeBefore + TaskContext.get().internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(peakMemory) + } + numOutputRows += 1 + res + } else { + // no more result + throw new NoSuchElementException + } + } + + /////////////////////////////////////////////////////////////////////////// + // Part 8: Utility functions + /////////////////////////////////////////////////////////////////////////// + + /** + * Generate a output row when there is no input and there is no grouping expression. + */ + def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = { + if (groupingExpressions.isEmpty) { + sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer) + // We create a output row and copy it. 
So, we can free the map. + val resultCopy = + generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer).copy() + hashMap.free() + resultCopy + } else { + throw new IllegalStateException( + "This method should not be called when groupingExpressions is not empty.") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala new file mode 100644 index 000000000000..a9719128a626 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import scala.language.existentials + +import org.apache.spark.Logging +import org.apache.spark.sql.Encoder +import org.apache.spark.sql.expressions.Aggregator +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.{OuterScopes, encoderFor, ExpressionEncoder} +import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +object TypedAggregateExpression { + def apply[A, B : Encoder, C : Encoder]( + aggregator: Aggregator[A, B, C]): TypedAggregateExpression = { + new TypedAggregateExpression( + aggregator.asInstanceOf[Aggregator[Any, Any, Any]], + None, + encoderFor[B].asInstanceOf[ExpressionEncoder[Any]], + encoderFor[C].asInstanceOf[ExpressionEncoder[Any]], + Nil, + 0, + 0) + } +} + +/** + * This class is a rough sketch of how to hook `Aggregator` into the Aggregation system. It has + * the following limitations: + * - It assumes the aggregator has a zero, `0`. + */ +case class TypedAggregateExpression( + aggregator: Aggregator[Any, Any, Any], + aEncoder: Option[ExpressionEncoder[Any]], // Should be bound. 
+ unresolvedBEncoder: ExpressionEncoder[Any], + cEncoder: ExpressionEncoder[Any], + children: Seq[Attribute], + mutableAggBufferOffset: Int, + inputAggBufferOffset: Int) + extends ImperativeAggregate with Logging { + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def nullable: Boolean = true + + override def dataType: DataType = if (cEncoder.flat) { + cEncoder.schema.head.dataType + } else { + cEncoder.schema + } + + override def deterministic: Boolean = true + + override lazy val resolved: Boolean = aEncoder.isDefined + + override lazy val inputTypes: Seq[DataType] = Nil + + override val aggBufferSchema: StructType = unresolvedBEncoder.schema + + override val aggBufferAttributes: Seq[AttributeReference] = aggBufferSchema.toAttributes + + val bEncoder = unresolvedBEncoder + .resolve(aggBufferAttributes, OuterScopes.outerScopes) + .bind(aggBufferAttributes) + + // Note: although this simply copies aggBufferAttributes, this common code can not be placed + // in the superclass because that will lead to initialization ordering issues. + override val inputAggBufferAttributes: Seq[AttributeReference] = + aggBufferAttributes.map(_.newInstance()) + + // We let the dataset do the binding for us. + lazy val boundA = aEncoder.get + + private def updateBuffer(buffer: MutableRow, value: InternalRow): Unit = { + var i = 0 + while (i < aggBufferAttributes.length) { + val offset = mutableAggBufferOffset + i + aggBufferSchema(i).dataType match { + case BooleanType => buffer.setBoolean(offset, value.getBoolean(i)) + case ByteType => buffer.setByte(offset, value.getByte(i)) + case ShortType => buffer.setShort(offset, value.getShort(i)) + case IntegerType => buffer.setInt(offset, value.getInt(i)) + case LongType => buffer.setLong(offset, value.getLong(i)) + case FloatType => buffer.setFloat(offset, value.getFloat(i)) + case DoubleType => buffer.setDouble(offset, value.getDouble(i)) + case other => buffer.update(offset, value.get(i, other)) + } + i += 1 + } + } + + override def initialize(buffer: MutableRow): Unit = { + val zero = bEncoder.toRow(aggregator.zero) + updateBuffer(buffer, zero) + } + + override def update(buffer: MutableRow, input: InternalRow): Unit = { + val inputA = boundA.fromRow(input) + val currentB = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer) + val merged = aggregator.reduce(currentB, inputA) + val returned = bEncoder.toRow(merged) + + updateBuffer(buffer, returned) + } + + override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { + val b1 = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer1) + val b2 = bEncoder.shift(inputAggBufferOffset).fromRow(buffer2) + val merged = aggregator.merge(b1, b2) + val returned = bEncoder.toRow(merged) + + updateBuffer(buffer1, returned) + } + + override def eval(buffer: InternalRow): Any = { + val b = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer) + val result = cEncoder.toRow(aggregator.finish(b)) + dataType match { + case _: StructType => result + case _ => result.get(0, dataType) + } + } + + override def toString: String = { + s"""${aggregator.getClass.getSimpleName}(${children.mkString(",")})""" + } + + override def nodeName: String = aggregator.getClass.getSimpleName +} diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala deleted file mode 100644 index 37d34eb7ccf0..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala +++ /dev/null @@ -1,398 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.aggregate - -import org.apache.spark.sql.execution.{UnsafeKeyValueSorter, UnsafeFixedWidthAggregationMap} -import org.apache.spark.unsafe.KVIterator -import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.types.StructType - -/** - * An iterator used to evaluate [[AggregateFunction2]]. - * It first tries to use in-memory hash-based aggregation. If we cannot allocate more - * space for the hash map, we spill the sorted map entries, free the map, and then - * switch to sort-based aggregation. - */ -class UnsafeHybridAggregationIterator( - groupingKeyAttributes: Seq[Attribute], - valueAttributes: Seq[Attribute], - inputKVIterator: KVIterator[UnsafeRow, InternalRow], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], - completeAggregateAttributes: Seq[Attribute], - initialInputBufferOffset: Int, - resultExpressions: Seq[NamedExpression], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - outputsUnsafeRows: Boolean) - extends AggregationIterator( - groupingKeyAttributes, - valueAttributes, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, - initialInputBufferOffset, - resultExpressions, - newMutableProjection, - outputsUnsafeRows) { - - require(groupingKeyAttributes.nonEmpty) - - /////////////////////////////////////////////////////////////////////////// - // Unsafe Aggregation buffers - /////////////////////////////////////////////////////////////////////////// - - // This is the Unsafe Aggregation Map used to store all buffers. 
- private[this] val buffers = new UnsafeFixedWidthAggregationMap( - newBuffer, - StructType.fromAttributes(allAggregateFunctions.flatMap(_.bufferAttributes)), - StructType.fromAttributes(groupingKeyAttributes), - TaskContext.get.taskMemoryManager(), - SparkEnv.get.shuffleMemoryManager, - 1024 * 16, // initial capacity - SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m"), - false // disable tracking of performance metrics - ) - - override protected def newBuffer: UnsafeRow = { - val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes) - val bufferRowSize: Int = bufferSchema.length - - val genericMutableBuffer = new GenericMutableRow(bufferRowSize) - val unsafeProjection = - UnsafeProjection.create(bufferSchema.map(_.dataType)) - val buffer = unsafeProjection.apply(genericMutableBuffer) - initializeBuffer(buffer) - buffer - } - - /////////////////////////////////////////////////////////////////////////// - // Methods and variables related to switching to sort-based aggregation - /////////////////////////////////////////////////////////////////////////// - private[this] var sortBased = false - - private[this] var sortBasedAggregationIterator: SortBasedAggregationIterator = _ - - // The value part of the input KV iterator is used to store original input values of - // aggregate functions, we need to convert them to aggregation buffers. - private def processOriginalInput( - firstKey: UnsafeRow, - firstValue: InternalRow): KVIterator[UnsafeRow, UnsafeRow] = { - new KVIterator[UnsafeRow, UnsafeRow] { - private[this] var isFirstRow = true - - private[this] var groupingKey: UnsafeRow = _ - - private[this] val buffer: UnsafeRow = newBuffer - - override def next(): Boolean = { - initializeBuffer(buffer) - if (isFirstRow) { - isFirstRow = false - groupingKey = firstKey - processRow(buffer, firstValue) - - true - } else if (inputKVIterator.next()) { - groupingKey = inputKVIterator.getKey() - val value = inputKVIterator.getValue() - processRow(buffer, value) - - true - } else { - false - } - } - - override def getKey(): UnsafeRow = { - groupingKey - } - - override def getValue(): UnsafeRow = { - buffer - } - - override def close(): Unit = { - // Do nothing. - } - } - } - - // The value of the input KV Iterator has the format of groupingExprs + aggregation buffer. - // We need to project the aggregation buffer out. - private def projectInputBufferToUnsafe( - firstKey: UnsafeRow, - firstValue: InternalRow): KVIterator[UnsafeRow, UnsafeRow] = { - new KVIterator[UnsafeRow, UnsafeRow] { - private[this] var isFirstRow = true - - private[this] var groupingKey: UnsafeRow = _ - - private[this] val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes) - - private[this] val value: UnsafeRow = { - val genericMutableRow = new GenericMutableRow(bufferSchema.length) - UnsafeProjection.create(bufferSchema.map(_.dataType)).apply(genericMutableRow) - } - - private[this] val projectInputBuffer = { - newMutableProjection(bufferSchema, valueAttributes)().target(value) - } - - override def next(): Boolean = { - if (isFirstRow) { - isFirstRow = false - groupingKey = firstKey - projectInputBuffer(firstValue) - - true - } else if (inputKVIterator.next()) { - groupingKey = inputKVIterator.getKey() - projectInputBuffer(inputKVIterator.getValue()) - - true - } else { - false - } - } - - override def getKey(): UnsafeRow = { - groupingKey - } - - override def getValue(): UnsafeRow = { - value - } - - override def close(): Unit = { - // Do nothing. 
- } - } - } - - /** - * We need to fall back to sort based aggregation because we do not have enough memory - * for our in-memory hash map (i.e. `buffers`). - */ - private def switchToSortBasedAggregation( - currentGroupingKey: UnsafeRow, - currentRow: InternalRow): Unit = { - logInfo("falling back to sort based aggregation.") - - // Step 1: Get the ExternalSorter containing entries of the map. - val externalSorter = buffers.destructAndCreateExternalSorter() - - // Step 2: Free the memory used by the map. - buffers.free() - - // Step 3: If we have aggregate function with mode Partial or Complete, - // we need to process them to get aggregation buffer. - // So, later in the sort-based aggregation iterator, we can do merge. - // If aggregate functions are with mode Final and PartialMerge, - // we just need to project the aggregation buffer from the input. - val needsProcess = aggregationMode match { - case (Some(Partial), None) => true - case (None, Some(Complete)) => true - case (Some(Final), Some(Complete)) => true - case _ => false - } - - val processedIterator = if (needsProcess) { - processOriginalInput(currentGroupingKey, currentRow) - } else { - // The input value's format is groupingExprs + buffer. - // We need to project the buffer part out. - projectInputBufferToUnsafe(currentGroupingKey, currentRow) - } - - // Step 4: Redirect processedIterator to externalSorter. - while (processedIterator.next()) { - externalSorter.insertKV(processedIterator.getKey(), processedIterator.getValue()) - } - - // Step 5: Get the sorted iterator from the externalSorter. - val sortedKVIterator: KVIterator[UnsafeRow, UnsafeRow] = externalSorter.sortedIterator() - - // Step 6: We now create a SortBasedAggregationIterator based on sortedKVIterator. - // For a aggregate function with mode Partial, its mode in the SortBasedAggregationIterator - // will be PartialMerge. For a aggregate function with mode Complete, - // its mode in the SortBasedAggregationIterator will be Final. - val newNonCompleteAggregateExpressions = allAggregateExpressions.map { - case AggregateExpression2(func, Partial, isDistinct) => - AggregateExpression2(func, PartialMerge, isDistinct) - case AggregateExpression2(func, Complete, isDistinct) => - AggregateExpression2(func, Final, isDistinct) - case other => other - } - val newNonCompleteAggregateAttributes = - nonCompleteAggregateAttributes ++ completeAggregateAttributes - - val newValueAttributes = - allAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes) - - sortBasedAggregationIterator = SortBasedAggregationIterator.createFromKVIterator( - groupingKeyAttributes = groupingKeyAttributes, - valueAttributes = newValueAttributes, - inputKVIterator = sortedKVIterator.asInstanceOf[KVIterator[InternalRow, InternalRow]], - nonCompleteAggregateExpressions = newNonCompleteAggregateExpressions, - nonCompleteAggregateAttributes = newNonCompleteAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, - initialInputBufferOffset = 0, - resultExpressions = resultExpressions, - newMutableProjection = newMutableProjection, - outputsUnsafeRows = outputsUnsafeRows) - } - - /////////////////////////////////////////////////////////////////////////// - // Methods used to initialize this iterator. - /////////////////////////////////////////////////////////////////////////// - - /** Starts to read input rows and falls back to sort-based aggregation if necessary. 
*/ - protected def initialize(): Unit = { - var hasNext = inputKVIterator.next() - while (!sortBased && hasNext) { - val groupingKey = inputKVIterator.getKey() - val currentRow = inputKVIterator.getValue() - val buffer = buffers.getAggregationBuffer(groupingKey) - if (buffer == null) { - // buffer == null means that we could not allocate more memory. - // Now, we need to spill the map and switch to sort-based aggregation. - switchToSortBasedAggregation(groupingKey, currentRow) - sortBased = true - } else { - processRow(buffer, currentRow) - hasNext = inputKVIterator.next() - } - } - } - - // This is the starting point of this iterator. - initialize() - - // Creates the iterator for the Hash Aggregation Map after we have populated - // contents of that map. - private[this] val aggregationBufferMapIterator = buffers.iterator() - - private[this] var _mapIteratorHasNext = false - - // Pre-load the first key-value pair from the map to make hasNext idempotent. - if (!sortBased) { - _mapIteratorHasNext = aggregationBufferMapIterator.next() - // If the map is empty, we just free it. - if (!_mapIteratorHasNext) { - buffers.free() - } - } - - /////////////////////////////////////////////////////////////////////////// - // Iterator's public methods - /////////////////////////////////////////////////////////////////////////// - - override final def hasNext: Boolean = { - (sortBased && sortBasedAggregationIterator.hasNext) || (!sortBased && _mapIteratorHasNext) - } - - - override final def next(): InternalRow = { - if (hasNext) { - if (sortBased) { - sortBasedAggregationIterator.next() - } else { - // We did not fall back to the sort-based aggregation. - val result = - generateOutput( - aggregationBufferMapIterator.getKey, - aggregationBufferMapIterator.getValue) - // Pre-load next key-value pair form aggregationBufferMapIterator. 
- _mapIteratorHasNext = aggregationBufferMapIterator.next() - - if (!_mapIteratorHasNext) { - val resultCopy = result.copy() - buffers.free() - resultCopy - } else { - result - } - } - } else { - // no more result - throw new NoSuchElementException - } - } -} - -object UnsafeHybridAggregationIterator { - // scalastyle:off - def createFromInputIterator( - groupingExprs: Seq[NamedExpression], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], - completeAggregateAttributes: Seq[Attribute], - initialInputBufferOffset: Int, - resultExpressions: Seq[NamedExpression], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - inputAttributes: Seq[Attribute], - inputIter: Iterator[InternalRow], - outputsUnsafeRows: Boolean): UnsafeHybridAggregationIterator = { - new UnsafeHybridAggregationIterator( - groupingExprs.map(_.toAttribute), - inputAttributes, - AggregationIterator.unsafeKVIterator(groupingExprs, inputAttributes, inputIter), - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, - initialInputBufferOffset, - resultExpressions, - newMutableProjection, - outputsUnsafeRows) - } - - def createFromKVIterator( - groupingKeyAttributes: Seq[Attribute], - valueAttributes: Seq[Attribute], - inputKVIterator: KVIterator[UnsafeRow, InternalRow], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], - completeAggregateAttributes: Seq[Attribute], - initialInputBufferOffset: Int, - resultExpressions: Seq[NamedExpression], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - outputsUnsafeRows: Boolean): UnsafeHybridAggregationIterator = { - new UnsafeHybridAggregationIterator( - groupingKeyAttributes, - valueAttributes, - inputKVIterator, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, - initialInputBufferOffset, - resultExpressions, - newMutableProjection, - outputsUnsafeRows) - } - // scalastyle:on -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index 5fafc916bfa0..c0d00104e8bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.catalyst.expressions.{MutableRow, InterpretedMutableProjection, AttributeReference, Expression} -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction2 +import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, AggregateFunction} import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.types._ @@ -40,6 +40,9 @@ sealed trait BufferSetterGetterUtils { var i = 0 while (i < getters.length) { getters(i) = dataTypes(i) match { + case NullType => + (row: InternalRow, ordinal: Int) => null + case BooleanType => (row: InternalRow, ordinal: Int) => if 
(row.isNullAt(ordinal)) null else row.getBoolean(ordinal) @@ -74,6 +77,14 @@ sealed trait BufferSetterGetterUtils { (row: InternalRow, ordinal: Int) => if (row.isNullAt(ordinal)) null else row.getDecimal(ordinal, precision, scale) + case DateType => + (row: InternalRow, ordinal: Int) => + if (row.isNullAt(ordinal)) null else row.getInt(ordinal) + + case TimestampType => + (row: InternalRow, ordinal: Int) => + if (row.isNullAt(ordinal)) null else row.getLong(ordinal) + case other => (row: InternalRow, ordinal: Int) => if (row.isNullAt(ordinal)) null else row.get(ordinal, other) @@ -92,6 +103,9 @@ sealed trait BufferSetterGetterUtils { var i = 0 while (i < setters.length) { setters(i) = dataTypes(i) match { + case NullType => + (row: MutableRow, ordinal: Int, value: Any) => row.setNullAt(ordinal) + case b: BooleanType => (row: MutableRow, ordinal: Int, value: Any) => if (value != null) { @@ -150,9 +164,23 @@ sealed trait BufferSetterGetterUtils { case dt: DecimalType => val precision = dt.precision + (row: MutableRow, ordinal: Int, value: Any) => + // To make it work with UnsafeRow, we cannot use setNullAt. + // Please see the comment of UnsafeRow's setDecimal. + row.setDecimal(ordinal, value.asInstanceOf[Decimal], precision) + + case DateType => (row: MutableRow, ordinal: Int, value: Any) => if (value != null) { - row.setDecimal(ordinal, value.asInstanceOf[Decimal], precision) + row.setInt(ordinal, value.asInstanceOf[Int]) + } else { + row.setNullAt(ordinal) + } + + case TimestampType => + (row: MutableRow, ordinal: Int, value: Any) => + if (value != null) { + row.setLong(ordinal, value.asInstanceOf[Long]) } else { row.setNullAt(ordinal) } @@ -205,6 +233,7 @@ private[sql] class MutableAggregationBufferImpl ( throw new IllegalArgumentException( s"Could not access ${i}th value in this buffer because it only has $length values.") } + toScalaConverters(i)(bufferValueGetters(i)(underlyingBuffer, offsets(i))) } @@ -289,34 +318,38 @@ private[sql] class InputAggregationBuffer private[sql] ( /** * The internal wrapper used to hook a [[UserDefinedAggregateFunction]] `udaf` in the * internal aggregation code path. 
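// [Editor's illustration, not part of this patch] The getter/setter hunks above follow a
// common pattern: build an array of type-specialized closures once, indexed by field
// position, so the per-row path does no type dispatch. A minimal sketch of that pattern,
// with an invented DataType ADT and Array[Any] standing in for InternalRow/MutableRow
// (null checks omitted for brevity):
object TypedAccessorSketch {
  sealed trait DataType
  case object IntType extends DataType
  case object LongType extends DataType
  case object StringType extends DataType

  /** One getter per field, chosen by data type when the array is built, not per row. */
  def createGetters(dataTypes: Seq[DataType]): Array[(Array[Any], Int) => Any] =
    dataTypes.map {
      case IntType    => (row: Array[Any], ordinal: Int) => row(ordinal).asInstanceOf[Int]
      case LongType   => (row: Array[Any], ordinal: Int) => row(ordinal).asInstanceOf[Long]
      case StringType => (row: Array[Any], ordinal: Int) => row(ordinal).asInstanceOf[String]
    }.toArray

  // Usage: getters(i)(row, i) reads field i with no pattern match on the hot path.
  def readAll(row: Array[Any], getters: Array[(Array[Any], Int) => Any]): Seq[Any] =
    getters.indices.map(i => getters(i)(row, i))
}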
- * @param children - * @param udaf */ private[sql] case class ScalaUDAF( children: Seq[Expression], - udaf: UserDefinedAggregateFunction) - extends AggregateFunction2 with Logging { + udaf: UserDefinedAggregateFunction, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) + extends ImperativeAggregate with Logging { + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) - require( - children.length == udaf.inputSchema.length, - s"$udaf only accepts ${udaf.inputSchema.length} arguments, " + - s"but ${children.length} are provided.") + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) override def nullable: Boolean = true - override def dataType: DataType = udaf.returnDataType + override def dataType: DataType = udaf.dataType override def deterministic: Boolean = udaf.deterministic override val inputTypes: Seq[DataType] = udaf.inputSchema.map(_.dataType) - override val bufferSchema: StructType = udaf.bufferSchema + override val aggBufferSchema: StructType = udaf.bufferSchema - override val bufferAttributes: Seq[AttributeReference] = bufferSchema.toAttributes + override val aggBufferAttributes: Seq[AttributeReference] = aggBufferSchema.toAttributes - override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance()) + // Note: although this simply copies aggBufferAttributes, this common code can not be placed + // in the superclass because that will lead to initialization ordering issues. + override val inputAggBufferAttributes: Seq[AttributeReference] = + aggBufferAttributes.map(_.newInstance()) - private[this] val childrenSchema: StructType = { + private[this] lazy val childrenSchema: StructType = { val inputFields = children.zipWithIndex.map { case (child, index) => StructField(s"input$index", child.dataType, child.nullable, Metadata.empty) @@ -337,67 +370,53 @@ private[sql] case class ScalaUDAF( } } - private[this] val inputToScalaConverters: Any => Any = + private[this] lazy val inputToScalaConverters: Any => Any = CatalystTypeConverters.createToScalaConverter(childrenSchema) - private[this] val bufferValuesToCatalystConverters: Array[Any => Any] = { - bufferSchema.fields.map { field => + private[this] lazy val bufferValuesToCatalystConverters: Array[Any => Any] = { + aggBufferSchema.fields.map { field => CatalystTypeConverters.createToCatalystConverter(field.dataType) } } - private[this] val bufferValuesToScalaConverters: Array[Any => Any] = { - bufferSchema.fields.map { field => + private[this] lazy val bufferValuesToScalaConverters: Array[Any => Any] = { + aggBufferSchema.fields.map { field => CatalystTypeConverters.createToScalaConverter(field.dataType) } } - // This buffer is only used at executor side. - private[this] var inputAggregateBuffer: InputAggregationBuffer = null + private[this] lazy val outputToCatalystConverter: Any => Any = { + CatalystTypeConverters.createToCatalystConverter(dataType) + } // This buffer is only used at executor side. - private[this] var mutableAggregateBuffer: MutableAggregationBufferImpl = null + private[this] lazy val inputAggregateBuffer: InputAggregationBuffer = { + new InputAggregationBuffer( + aggBufferSchema, + bufferValuesToCatalystConverters, + bufferValuesToScalaConverters, + inputAggBufferOffset, + null) + } // This buffer is only used at executor side. 
- private[this] var evalAggregateBuffer: InputAggregationBuffer = null - - /** - * Sets the inputBufferOffset to newInputBufferOffset and then create a new instance of - * `inputAggregateBuffer` based on this new inputBufferOffset. - */ - override def withNewInputBufferOffset(newInputBufferOffset: Int): Unit = { - super.withNewInputBufferOffset(newInputBufferOffset) - // inputBufferOffset has been updated. - inputAggregateBuffer = - new InputAggregationBuffer( - bufferSchema, - bufferValuesToCatalystConverters, - bufferValuesToScalaConverters, - inputBufferOffset, - null) + private[this] lazy val mutableAggregateBuffer: MutableAggregationBufferImpl = { + new MutableAggregationBufferImpl( + aggBufferSchema, + bufferValuesToCatalystConverters, + bufferValuesToScalaConverters, + mutableAggBufferOffset, + null) } - /** - * Sets the mutableBufferOffset to newMutableBufferOffset and then create a new instance of - * `mutableAggregateBuffer` and `evalAggregateBuffer` based on this new mutableBufferOffset. - */ - override def withNewMutableBufferOffset(newMutableBufferOffset: Int): Unit = { - super.withNewMutableBufferOffset(newMutableBufferOffset) - // mutableBufferOffset has been updated. - mutableAggregateBuffer = - new MutableAggregationBufferImpl( - bufferSchema, - bufferValuesToCatalystConverters, - bufferValuesToScalaConverters, - mutableBufferOffset, - null) - evalAggregateBuffer = - new InputAggregationBuffer( - bufferSchema, - bufferValuesToCatalystConverters, - bufferValuesToScalaConverters, - mutableBufferOffset, - null) + // This buffer is only used at executor side. + private[this] lazy val evalAggregateBuffer: InputAggregationBuffer = { + new InputAggregationBuffer( + aggBufferSchema, + bufferValuesToCatalystConverters, + bufferValuesToScalaConverters, + mutableAggBufferOffset, + null) } override def initialize(buffer: MutableRow): Unit = { @@ -424,7 +443,7 @@ private[sql] case class ScalaUDAF( override def eval(buffer: InternalRow): Any = { evalAggregateBuffer.underlyingInputBuffer = buffer - udaf.evaluate(evalAggregateBuffer) + outputToCatalystConverter(udaf.evaluate(evalAggregateBuffer)) } override def toString: String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala index 960be08f84d9..83379ae90f70 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -25,206 +25,240 @@ import org.apache.spark.sql.execution.SparkPlan * Utility functions used by the query planner to convert our plan to new aggregation code path. 
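[Editor's note] The planner utilities below split an aggregate into a Partial operator (run per input partition) and a Final operator that merges the partial buffers. As a rough illustration of why that split is correct, here is a self-contained sketch that computes a grouped sum in two phases over in-memory "partitions"; plain Scala collections stand in for `SparkPlan` output, and none of this is the patch's API.

```scala
object TwoPhaseAggregateSketch {
  // Phase 1 (Partial): each partition aggregates its own rows into (key -> partial sum) buffers.
  def partial(partition: Seq[(String, Long)]): Map[String, Long] =
    partition.groupBy(_._1).map { case (k, rows) => k -> rows.map(_._2).sum }

  // Phase 2 (Final): the partial buffers are brought together by key and merged.
  def finalMerge(partials: Seq[Map[String, Long]]): Map[String, Long] =
    partials.flatten.groupBy(_._1).map { case (k, kvs) => k -> kvs.map(_._2).sum }

  def main(args: Array[String]): Unit = {
    val partitions = Seq(
      Seq("a" -> 1L, "b" -> 2L, "a" -> 3L),
      Seq("b" -> 4L, "c" -> 5L))
    val result = finalMerge(partitions.map(partial))
    println(result)  // Map(a -> 4, b -> 6, c -> 5)
  }
}
```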
*/ object Utils { + + def planAggregateWithoutPartial( + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression], + aggregateFunctionToAttribute: Map[(AggregateFunction, Boolean), Attribute], + resultExpressions: Seq[NamedExpression], + child: SparkPlan): Seq[SparkPlan] = { + + val groupingAttributes = groupingExpressions.map(_.toAttribute) + val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete)) + val completeAggregateAttributes = completeAggregateExpressions.map { + expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) + } + + SortBasedAggregate( + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes, + aggregateExpressions = completeAggregateExpressions, + aggregateAttributes = completeAggregateAttributes, + initialInputBufferOffset = 0, + resultExpressions = resultExpressions, + child = child + ) :: Nil + } + + private def createAggregate( + requiredChildDistributionExpressions: Option[Seq[Expression]] = None, + groupingExpressions: Seq[NamedExpression] = Nil, + aggregateExpressions: Seq[AggregateExpression] = Nil, + aggregateAttributes: Seq[Attribute] = Nil, + initialInputBufferOffset: Int = 0, + resultExpressions: Seq[NamedExpression] = Nil, + child: SparkPlan): SparkPlan = { + val usesTungstenAggregate = TungstenAggregate.supportsAggregate( + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) + if (usesTungstenAggregate) { + TungstenAggregate( + requiredChildDistributionExpressions = requiredChildDistributionExpressions, + groupingExpressions = groupingExpressions, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = initialInputBufferOffset, + resultExpressions = resultExpressions, + child = child) + } else { + SortBasedAggregate( + requiredChildDistributionExpressions = requiredChildDistributionExpressions, + groupingExpressions = groupingExpressions, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = initialInputBufferOffset, + resultExpressions = resultExpressions, + child = child) + } + } + def planAggregateWithoutDistinct( - groupingExpressions: Seq[Expression], - aggregateExpressions: Seq[AggregateExpression2], - aggregateFunctionMap: Map[(AggregateFunction2, Boolean), Attribute], + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression], + aggregateFunctionToAttribute: Map[(AggregateFunction, Boolean), Attribute], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { + // Check if we can use TungstenAggregate. + // 1. Create an Aggregate Operator for partial aggregations. - val namedGroupingExpressions = groupingExpressions.map { - case ne: NamedExpression => ne -> ne - // If the expression is not a NamedExpressions, we add an alias. - // So, when we generate the result of the operator, the Aggregate Operator - // can directly get the Seq of attributes representing the grouping expressions. 
- case other => - val withAlias = Alias(other, other.toString)() - other -> withAlias - } - val groupExpressionMap = namedGroupingExpressions.toMap - val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) + + val groupingAttributes = groupingExpressions.map(_.toAttribute) val partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial)) - val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg => - agg.aggregateFunction.bufferAttributes - } - val partialAggregate = - Aggregate( - requiredChildDistributionExpressions = None: Option[Seq[Expression]], - groupingExpressions = namedGroupingExpressions.map(_._2), - nonCompleteAggregateExpressions = partialAggregateExpressions, - nonCompleteAggregateAttributes = partialAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, + val partialAggregateAttributes = + partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) + val partialResultExpressions = + groupingAttributes ++ + partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) + + val partialAggregate = createAggregate( + requiredChildDistributionExpressions = None, + groupingExpressions = groupingExpressions, + aggregateExpressions = partialAggregateExpressions, + aggregateAttributes = partialAggregateAttributes, initialInputBufferOffset = 0, - resultExpressions = namedGroupingAttributes ++ partialAggregateAttributes, + resultExpressions = partialResultExpressions, child = child) // 2. Create an Aggregate Operator for final aggregations. val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final)) - val finalAggregateAttributes = - finalAggregateExpressions.map { - expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct) - } - val rewrittenResultExpressions = resultExpressions.map { expr => - expr.transformDown { - case agg: AggregateExpression2 => - aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute - case expression => - // We do not rely on the equality check at here since attributes may - // different cosmetically. Instead, we use semanticEquals. 
- groupExpressionMap.collectFirst { - case (expr, ne) if expr semanticEquals expression => ne.toAttribute - }.getOrElse(expression) - }.asInstanceOf[NamedExpression] + // The attributes of the final aggregation buffer, which is presented as input to the result + // projection: + val finalAggregateAttributes = finalAggregateExpressions.map { + expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) } - val finalAggregate = - Aggregate( - requiredChildDistributionExpressions = Some(namedGroupingAttributes), - groupingExpressions = namedGroupingAttributes, - nonCompleteAggregateExpressions = finalAggregateExpressions, - nonCompleteAggregateAttributes = finalAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, - initialInputBufferOffset = namedGroupingAttributes.length, - resultExpressions = rewrittenResultExpressions, + + val finalAggregate = createAggregate( + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes, + aggregateExpressions = finalAggregateExpressions, + aggregateAttributes = finalAggregateAttributes, + initialInputBufferOffset = groupingExpressions.length, + resultExpressions = resultExpressions, child = partialAggregate) finalAggregate :: Nil } def planAggregateWithOneDistinct( - groupingExpressions: Seq[Expression], - functionsWithDistinct: Seq[AggregateExpression2], - functionsWithoutDistinct: Seq[AggregateExpression2], - aggregateFunctionMap: Map[(AggregateFunction2, Boolean), Attribute], + groupingExpressions: Seq[NamedExpression], + functionsWithDistinct: Seq[AggregateExpression], + functionsWithoutDistinct: Seq[AggregateExpression], + aggregateFunctionToAttribute: Map[(AggregateFunction, Boolean), Attribute], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { - // 1. Create an Aggregate Operator for partial aggregations. - // The grouping expressions are original groupingExpressions and - // distinct columns. For example, for avg(distinct value) ... group by key - // the grouping expressions of this Aggregate Operator will be [key, value]. - val namedGroupingExpressions = groupingExpressions.map { - case ne: NamedExpression => ne -> ne - // If the expression is not a NamedExpressions, we add an alias. - // So, when we generate the result of the operator, the Aggregate Operator - // can directly get the Seq of attributes representing the grouping expressions. - case other => - val withAlias = Alias(other, other.toString)() - other -> withAlias - } - val groupExpressionMap = namedGroupingExpressions.toMap - val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) - - // It is safe to call head at here since functionsWithDistinct has at least one - // AggregateExpression2. - val distinctColumnExpressions = - functionsWithDistinct.head.aggregateFunction.children - val namedDistinctColumnExpressions = distinctColumnExpressions.map { - case ne: NamedExpression => ne -> ne - case other => - val withAlias = Alias(other, other.toString)() - other -> withAlias + // functionsWithDistinct is guaranteed to be non-empty. Even though it may contain more than one + // DISTINCT aggregate function, all of those functions will have the same column expressions. + // For example, it would be valid for functionsWithDistinct to be + // [COUNT(DISTINCT foo), MAX(DISTINCT foo)], but [COUNT(DISTINCT bar), COUNT(DISTINCT foo)] is + // disallowed because those two distinct aggregates have different column expressions. 
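[Editor's note] The single-distinct planning below first groups on (grouping keys ++ distinct column), which de-duplicates the distinct column per group before the distinct function itself ever runs. A Spark-free sketch of the same idea for `COUNT(DISTINCT value) GROUP BY key`, assuming small in-memory inputs:

```scala
object DistinctViaRegroupingSketch {
  // Step 1: group by (key, value) -- after this, each (key, value) pair appears exactly once.
  // Step 2: group the de-duplicated pairs by key and count them.
  def countDistinct(rows: Seq[(String, Int)]): Map[String, Int] = {
    val deduplicated = rows.groupBy(identity).keys.toSeq   // distinct (key, value) pairs
    deduplicated.groupBy(_._1).map { case (k, pairs) => k -> pairs.size }
  }

  def main(args: Array[String]): Unit = {
    val rows = Seq("a" -> 1, "a" -> 1, "a" -> 2, "b" -> 3)
    println(countDistinct(rows))  // Map(a -> 2, b -> 1)
  }
}
```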
+ val distinctExpressions = functionsWithDistinct.head.aggregateFunction.children + val namedDistinctExpressions = distinctExpressions.map { + case ne: NamedExpression => ne + case other => Alias(other, other.toString)() } - val distinctColumnExpressionMap = namedDistinctColumnExpressions.toMap - val distinctColumnAttributes = namedDistinctColumnExpressions.map(_._2.toAttribute) + val distinctAttributes = namedDistinctExpressions.map(_.toAttribute) + val groupingAttributes = groupingExpressions.map(_.toAttribute) - val partialAggregateExpressions = functionsWithoutDistinct.map { - case AggregateExpression2(aggregateFunction, mode, _) => - AggregateExpression2(aggregateFunction, Partial, false) - } - val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg => - agg.aggregateFunction.bufferAttributes - } - val partialAggregateGroupingExpressions = - (namedGroupingExpressions ++ namedDistinctColumnExpressions).map(_._2) - val partialAggregateResult = - namedGroupingAttributes ++ distinctColumnAttributes ++ partialAggregateAttributes - val partialAggregate = - Aggregate( - requiredChildDistributionExpressions = None: Option[Seq[Expression]], - groupingExpressions = partialAggregateGroupingExpressions, - nonCompleteAggregateExpressions = partialAggregateExpressions, - nonCompleteAggregateAttributes = partialAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, - initialInputBufferOffset = 0, - resultExpressions = partialAggregateResult, + // 1. Create an Aggregate Operator for partial aggregations. + val partialAggregate: SparkPlan = { + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) + val aggregateAttributes = aggregateExpressions.map { + expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) + } + // We will group by the original grouping expression, plus an additional expression for the + // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping + // expressions will be [key, value]. + createAggregate( + groupingExpressions = groupingExpressions ++ namedDistinctExpressions, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + resultExpressions = groupingAttributes ++ distinctAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), child = child) + } // 2. Create an Aggregate Operator for partial merge aggregations. 
- val partialMergeAggregateExpressions = functionsWithoutDistinct.map { - case AggregateExpression2(aggregateFunction, mode, _) => - AggregateExpression2(aggregateFunction, PartialMerge, false) - } - val partialMergeAggregateAttributes = - partialMergeAggregateExpressions.flatMap { agg => - agg.aggregateFunction.bufferAttributes + val partialMergeAggregate: SparkPlan = { + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) + val aggregateAttributes = aggregateExpressions.map { + expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) } - val partialMergeAggregateResult = - namedGroupingAttributes ++ distinctColumnAttributes ++ partialMergeAggregateAttributes - val partialMergeAggregate = - Aggregate( - requiredChildDistributionExpressions = Some(namedGroupingAttributes), - groupingExpressions = namedGroupingAttributes ++ distinctColumnAttributes, - nonCompleteAggregateExpressions = partialMergeAggregateExpressions, - nonCompleteAggregateAttributes = partialMergeAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, - initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length, - resultExpressions = partialMergeAggregateResult, + createAggregate( + requiredChildDistributionExpressions = + Some(groupingAttributes ++ distinctAttributes), + groupingExpressions = groupingAttributes ++ distinctAttributes, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length, + resultExpressions = groupingAttributes ++ distinctAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), child = partialAggregate) - - // 3. Create an Aggregate Operator for partial merge aggregations. - val finalAggregateExpressions = functionsWithoutDistinct.map { - case AggregateExpression2(aggregateFunction, mode, _) => - AggregateExpression2(aggregateFunction, Final, false) } - val finalAggregateAttributes = - finalAggregateExpressions.map { - expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct) - } - val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map { + + // 3. Create an Aggregate operator for partial aggregation (for distinct) + val distinctColumnAttributeLookup = distinctExpressions.zip(distinctAttributes).toMap + val rewrittenDistinctFunctions = functionsWithDistinct.map { // Children of an AggregateFunction with DISTINCT keyword has already // been evaluated. At here, we need to replace original children // to AttributeReferences. - case agg @ AggregateExpression2(aggregateFunction, mode, isDistinct) => - val rewrittenAggregateFunction = aggregateFunction.transformDown { - case expr if distinctColumnExpressionMap.contains(expr) => - distinctColumnExpressionMap(expr).toAttribute - }.asInstanceOf[AggregateFunction2] - // We rewrite the aggregate function to a non-distinct aggregation because - // its input will have distinct arguments. 
- val rewrittenAggregateExpression = - AggregateExpression2(rewrittenAggregateFunction, Complete, false) - - val aggregateFunctionAttribute = aggregateFunctionMap(agg.aggregateFunction, isDistinct) - (rewrittenAggregateExpression -> aggregateFunctionAttribute) - }.unzip - - val rewrittenResultExpressions = resultExpressions.map { expr => - expr.transform { - case agg: AggregateExpression2 => - aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute - case expression => - // We do not rely on the equality check at here since attributes may - // different cosmetically. Instead, we use semanticEquals. - groupExpressionMap.collectFirst { - case (expr, ne) if expr semanticEquals expression => ne.toAttribute - }.getOrElse(expression) - }.asInstanceOf[NamedExpression] + case agg @ AggregateExpression(aggregateFunction, mode, true) => + aggregateFunction.transformDown(distinctColumnAttributeLookup) + .asInstanceOf[AggregateFunction] } - val finalAndCompleteAggregate = - Aggregate( - requiredChildDistributionExpressions = Some(namedGroupingAttributes), - groupingExpressions = namedGroupingAttributes, - nonCompleteAggregateExpressions = finalAggregateExpressions, - nonCompleteAggregateAttributes = finalAggregateAttributes, - completeAggregateExpressions = completeAggregateExpressions, - completeAggregateAttributes = completeAggregateAttributes, - initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length, - resultExpressions = rewrittenResultExpressions, + + val partialDistinctAggregate: SparkPlan = { + val mergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) + // The attributes of the final aggregation buffer, which is presented as input to the result + // projection: + val mergeAggregateAttributes = mergeAggregateExpressions.map { + expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) + } + val (distinctAggregateExpressions, distinctAggregateAttributes) = + rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) => + // We rewrite the aggregate function to a non-distinct aggregation because + // its input will have distinct arguments. + // We just keep the isDistinct setting to true, so when users look at the query plan, + // they still can see distinct aggregations. + val expr = AggregateExpression(func, Partial, isDistinct = true) + // Use original AggregationFunction to lookup attributes, which is used to build + // aggregateFunctionToAttribute + val attr = aggregateFunctionToAttribute(functionsWithDistinct(i).aggregateFunction, true) + (expr, attr) + }.unzip + + val partialAggregateResult = groupingAttributes ++ + mergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) ++ + distinctAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) + createAggregate( + groupingExpressions = groupingAttributes, + aggregateExpressions = mergeAggregateExpressions ++ distinctAggregateExpressions, + aggregateAttributes = mergeAggregateAttributes ++ distinctAggregateAttributes, + initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length, + resultExpressions = partialAggregateResult, child = partialMergeAggregate) + } + + // 4. Create an Aggregate Operator for the final aggregation. 
+ val finalAndCompleteAggregate: SparkPlan = { + val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) + // The attributes of the final aggregation buffer, which is presented as input to the result + // projection: + val finalAggregateAttributes = finalAggregateExpressions.map { + expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) + } + + val (distinctAggregateExpressions, distinctAggregateAttributes) = + rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) => + // We rewrite the aggregate function to a non-distinct aggregation because + // its input will have distinct arguments. + // We just keep the isDistinct setting to true, so when users look at the query plan, + // they still can see distinct aggregations. + val expr = AggregateExpression(func, Final, isDistinct = true) + // Use original AggregationFunction to lookup attributes, which is used to build + // aggregateFunctionToAttribute + val attr = aggregateFunctionToAttribute(functionsWithDistinct(i).aggregateFunction, true) + (expr, attr) + }.unzip + + createAggregate( + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes, + aggregateExpressions = finalAggregateExpressions ++ distinctAggregateExpressions, + aggregateAttributes = finalAggregateAttributes ++ distinctAggregateAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = resultExpressions, + child = partialDistinctAggregate) + } finalAndCompleteAggregate :: Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 5a1b000e8987..b3e4688557ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -17,43 +17,23 @@ package org.apache.spark.sql.execution -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.rdd.{RDD, ShuffledRDD} +import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD, ShuffledRDD} import org.apache.spark.shuffle.sort.SortShuffleManager -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.types.StructType -import org.apache.spark.util.collection.ExternalSorter -import org.apache.spark.util.collection.unsafe.sort.PrefixComparator -import org.apache.spark.util.{CompletionIterator, MutablePair} +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.util.MutablePair +import org.apache.spark.util.random.PoissonSampler import org.apache.spark.{HashPartitioner, SparkEnv} -/** - * :: DeveloperApi :: - */ -@DeveloperApi -case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { - override def output: Seq[Attribute] = projectList.map(_.toAttribute) - - @transient lazy val buildProjection = newMutableProjection(projectList, child.output) - - protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => - val reusableProjection = buildProjection() - 
iter.map(reusableProjection) - } - - override def outputOrdering: Seq[SortOrder] = child.outputOrdering -} +case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { -/** - * A variant of [[Project]] that returns [[UnsafeRow]]s. - */ -case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { + override private[sql] lazy val metrics = Map( + "numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows")) override def outputsUnsafeRows: Boolean = true override def canProcessUnsafeRows: Boolean = true @@ -61,28 +41,41 @@ case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan) override def output: Seq[Attribute] = projectList.map(_.toAttribute) - protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => - this.transformAllExpressions { - case CreateStruct(children) => CreateStructUnsafe(children) - case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) + protected override def doExecute(): RDD[InternalRow] = { + val numRows = longMetric("numRows") + child.execute().mapPartitionsInternal { iter => + val project = UnsafeProjection.create(projectList, child.output, + subexpressionEliminationEnabled) + iter.map { row => + numRows += 1 + project(row) + } } - val project = UnsafeProjection.create(projectList, child.output) - iter.map(project) } override def outputOrdering: Seq[SortOrder] = child.outputOrdering } -/** - * :: DeveloperApi :: - */ -@DeveloperApi case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output - protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => - iter.filter(newPredicate(condition, child.output)) + private[sql] override lazy val metrics = Map( + "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + + protected override def doExecute(): RDD[InternalRow] = { + val numInputRows = longMetric("numInputRows") + val numOutputRows = longMetric("numOutputRows") + child.execute().mapPartitionsInternal { iter => + val predicate = newPredicate(condition, child.output) + iter.filter { row => + numInputRows += 1 + val r = predicate(row) + if (r) numOutputRows += 1 + r + } + } } override def outputOrdering: Seq[SortOrder] = child.outputOrdering @@ -95,16 +88,15 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { } /** - * :: DeveloperApi :: * Sample the dataset. + * * @param lowerBound Lower-bound of the sampling probability (usually 0.0) * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled * will be ub - lb. * @param withReplacement Whether to sample with replacement. * @param seed the random seed - * @param child the QueryPlan + * @param child the SparkPlan */ -@DeveloperApi case class Sample( lowerBound: Double, upperBound: Double, @@ -115,24 +107,37 @@ case class Sample( { override def output: Seq[Attribute] = child.output - // TODO: How to pick seed? 
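[Editor's note] The rewritten `Filter` above updates its input/output row metrics from inside the lazy iterator, so counting happens as rows stream through rather than by materializing the partition. A minimal stand-alone sketch of that pattern, with plain counters standing in for `SQLMetrics`:

```scala
object CountingFilterSketch {
  // Wraps a predicate so that consuming the lazy iterator also updates the row counters.
  def filterWithCounts[T](iter: Iterator[T], predicate: T => Boolean): (Iterator[T], () => (Long, Long)) = {
    var numInput = 0L
    var numOutput = 0L
    val filtered = iter.filter { row =>
      numInput += 1
      val keep = predicate(row)
      if (keep) numOutput += 1
      keep
    }
    (filtered, () => (numInput, numOutput))
  }

  def main(args: Array[String]): Unit = {
    val (it, counts) = filterWithCounts(Iterator(1, 2, 3, 4, 5), (x: Int) => x % 2 == 0)
    println(it.toList)   // List(2, 4) -- forcing the iterator drives the counters
    println(counts())    // (5, 2)
  }
}
```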
+ override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true + protected override def doExecute(): RDD[InternalRow] = { if (withReplacement) { - child.execute().map(_.copy()).sample(withReplacement, upperBound - lowerBound, seed) + // Disable gap sampling since the gap sampling method buffers two rows internally, + // requiring us to copy the row, which is more expensive than the random number generator. + new PartitionwiseSampledRDD[InternalRow, InternalRow]( + child.execute(), + new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false), + preservesPartitioning = true, + seed) } else { - child.execute().map(_.copy()).randomSampleWithRange(lowerBound, upperBound, seed) + child.execute().randomSampleWithRange(lowerBound, upperBound, seed) } } } /** - * :: DeveloperApi :: + * Union two plans, without a distinct. This is UNION ALL in SQL. */ -@DeveloperApi case class Union(children: Seq[SparkPlan]) extends SparkPlan { - // TODO: attributes output by union should be distinct for nullability purposes - override def output: Seq[Attribute] = children.head.output - override def outputsUnsafeRows: Boolean = children.forall(_.outputsUnsafeRows) + override def output: Seq[Attribute] = { + children.tail.foldLeft(children.head.output) { case (currentOutput, child) => + currentOutput.zip(child.output).map { case (a1, a2) => + a1.withNullability(a1.nullable || a2.nullable) + } + } + } + override def outputsUnsafeRows: Boolean = children.exists(_.outputsUnsafeRows) override def canProcessUnsafeRows: Boolean = true override def canProcessSafeRows: Boolean = true protected override def doExecute(): RDD[InternalRow] = @@ -140,14 +145,12 @@ case class Union(children: Seq[SparkPlan]) extends SparkPlan { } /** - * :: DeveloperApi :: * Take the first limit elements. Note that the implementation is different depending on whether * this is a terminal operator or not. If it is terminal and is invoked using executeCollect, * this operator uses something similar to Spark's take method on the Spark driver. If it is not * terminal or is invoked using execute, we first take the limit on each partition, and then * repartition all the data to a single partition to compute the global limit. 
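[Editor's note] As the `Limit` comment above describes, the non-terminal case first takes `limit` rows in every partition and only then moves the survivors to a single partition to take the global limit. A self-contained sketch of that two-step strategy over in-memory partitions (not Spark API):

```scala
object TwoStepLimitSketch {
  // Step 1: local limit -- no partition ever contributes more than `n` rows.
  // Step 2: global limit -- the at most (numPartitions * n) survivors are collected
  //         into a single "partition" and truncated again.
  def limit[T](partitions: Seq[Seq[T]], n: Int): Seq[T] =
    partitions.map(_.take(n)).flatten.take(n)

  def main(args: Array[String]): Unit = {
    val partitions = Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6, 7, 8, 9))
    println(limit(partitions, 4))  // List(1, 2, 3, 4)
  }
}
```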
*/ -@DeveloperApi case class Limit(limit: Int, child: SparkPlan) extends UnaryNode { // TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan: @@ -159,15 +162,15 @@ case class Limit(limit: Int, child: SparkPlan) override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = SinglePartition - override def executeCollect(): Array[Row] = child.executeTake(limit) + override def executeCollect(): Array[InternalRow] = child.executeTake(limit) protected override def doExecute(): RDD[InternalRow] = { val rdd: RDD[_ <: Product2[Boolean, InternalRow]] = if (sortBasedShuffleOn) { - child.execute().mapPartitions { iter => + child.execute().mapPartitionsInternal { iter => iter.take(limit).map(row => (false, row.copy())) } } else { - child.execute().mapPartitions { iter => + child.execute().mapPartitionsInternal { iter => val mutablePair = new MutablePair[Boolean, InternalRow]() iter.take(limit).map(row => mutablePair.update(false, row)) } @@ -175,30 +178,33 @@ case class Limit(limit: Int, child: SparkPlan) val part = new HashPartitioner(1) val shuffled = new ShuffledRDD[Boolean, InternalRow, InternalRow](rdd, part) shuffled.setSerializer(new SparkSqlSerializer(child.sqlContext.sparkContext.getConf)) - shuffled.mapPartitions(_.take(limit).map(_._2)) + shuffled.mapPartitionsInternal(_.take(limit).map(_._2)) } } /** - * :: DeveloperApi :: * Take the first limit elements as defined by the sortOrder, and do projection if needed. * This is logically equivalent to having a [[Limit]] operator after a [[Sort]] operator, * or having a [[Project]] operator between them. * This could have been named TopK, but Spark's top operator does the opposite in ordering * so we name it TakeOrdered to avoid confusion. */ -@DeveloperApi case class TakeOrderedAndProject( limit: Int, sortOrder: Seq[SortOrder], projectList: Option[Seq[NamedExpression]], child: SparkPlan) extends UnaryNode { - override def output: Seq[Attribute] = child.output + override def output: Seq[Attribute] = { + val projectOutput = projectList.map(_.map(_.toAttribute)) + projectOutput.getOrElse(child.output) + } override def outputPartitioning: Partitioning = SinglePartition - private val ord: RowOrdering = new RowOrdering(sortOrder, child.output) + // We need to use an interpreted ordering here because generated orderings cannot be serialized + // and this ordering needs to be created on the driver in order to be passed into Spark core code. + private val ord: InterpretedOrdering = new InterpretedOrdering(sortOrder, child.output) // TODO: remove @transient after figure out how to clean closure at InsertIntoHiveTable. @transient private val projection = projectList.map(new InterpretedProjection(_, child.output)) @@ -208,9 +214,8 @@ case class TakeOrderedAndProject( projection.map(data.map(_)).getOrElse(data) } - override def executeCollect(): Array[Row] = { - val converter = CatalystTypeConverters.createToScalaConverter(schema) - collectData().map(converter(_).asInstanceOf[Row]) + override def executeCollect(): Array[InternalRow] = { + collectData() } // TODO: Terminal split should be implemented differently from non-terminal split. 
@@ -218,60 +223,202 @@ case class TakeOrderedAndProject( protected override def doExecute(): RDD[InternalRow] = sparkContext.makeRDD(collectData(), 1) override def outputOrdering: Seq[SortOrder] = sortOrder + + override def simpleString: String = { + val orderByString = sortOrder.mkString("[", ",", "]") + val outputString = output.mkString("[", ",", "]") + + s"TakeOrderedAndProject(limit=$limit, orderBy=$orderByString, output=$outputString)" + } } /** - * :: DeveloperApi :: * Return a new RDD that has exactly `numPartitions` partitions. + * Similar to coalesce defined on an [[RDD]], this operation results in a narrow dependency, e.g. + * if you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of + * the 100 new partitions will claim 10 of the current partitions. */ -@DeveloperApi -case class Repartition(numPartitions: Int, shuffle: Boolean, child: SparkPlan) - extends UnaryNode { +case class Coalesce(numPartitions: Int, child: SparkPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = { + if (numPartitions == 1) SinglePartition + else UnknownPartitioning(numPartitions) + } + protected override def doExecute(): RDD[InternalRow] = { - child.execute().map(_.copy()).coalesce(numPartitions, shuffle) + child.execute().coalesce(numPartitions, shuffle = false) } -} + override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true +} /** - * :: DeveloperApi :: * Returns a table with the elements from left that are not in right using * the built-in spark subtract function. */ -@DeveloperApi case class Except(left: SparkPlan, right: SparkPlan) extends BinaryNode { override def output: Seq[Attribute] = left.output protected override def doExecute(): RDD[InternalRow] = { left.execute().map(_.copy()).subtract(right.execute().map(_.copy())) } + + override def outputsUnsafeRows: Boolean = children.exists(_.outputsUnsafeRows) + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true } /** - * :: DeveloperApi :: * Returns the rows in left that also appear in right using the built in spark * intersection function. */ -@DeveloperApi case class Intersect(left: SparkPlan, right: SparkPlan) extends BinaryNode { override def output: Seq[Attribute] = children.head.output protected override def doExecute(): RDD[InternalRow] = { left.execute().map(_.copy()).intersection(right.execute().map(_.copy())) } + + override def outputsUnsafeRows: Boolean = children.exists(_.outputsUnsafeRows) + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true } /** - * :: DeveloperApi :: * A plan node that does nothing but lie about the output of its child. Used to spice a * (hopefully structurally equivalent) tree from a different optimization sequence into an already * resolved tree. */ -@DeveloperApi case class OutputFaker(output: Seq[Attribute], child: SparkPlan) extends SparkPlan { def children: Seq[SparkPlan] = child :: Nil protected override def doExecute(): RDD[InternalRow] = child.execute() } + +/** + * Applies the given function to each input row and encodes the result. 
+ */ +case class MapPartitions[T, U]( + func: Iterator[T] => Iterator[U], + tEncoder: ExpressionEncoder[T], + uEncoder: ExpressionEncoder[U], + output: Seq[Attribute], + child: SparkPlan) extends UnaryNode { + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitionsInternal { iter => + val tBoundEncoder = tEncoder.bind(child.output) + func(iter.map(tBoundEncoder.fromRow)).map(uEncoder.toRow) + } + } +} + +/** + * Applies the given function to each input row, appending the encoded result at the end of the row. + */ +case class AppendColumns[T, U]( + func: T => U, + tEncoder: ExpressionEncoder[T], + uEncoder: ExpressionEncoder[U], + newColumns: Seq[Attribute], + child: SparkPlan) extends UnaryNode { + + // We are using an unsafe combiner. + override def canProcessSafeRows: Boolean = false + override def canProcessUnsafeRows: Boolean = true + + override def output: Seq[Attribute] = child.output ++ newColumns + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitionsInternal { iter => + val tBoundEncoder = tEncoder.bind(child.output) + val combiner = GenerateUnsafeRowJoiner.create(tEncoder.schema, uEncoder.schema) + iter.map { row => + val newColumns = uEncoder.toRow(func(tBoundEncoder.fromRow(row))) + combiner.join(row.asInstanceOf[UnsafeRow], newColumns.asInstanceOf[UnsafeRow]): InternalRow + } + } + } +} + +/** + * Groups the input rows together and calls the function with each group and an iterator containing + * all elements in the group. The result of this function is encoded and flattened before + * being output. + */ +case class MapGroups[K, T, U]( + func: (K, Iterator[T]) => TraversableOnce[U], + kEncoder: ExpressionEncoder[K], + tEncoder: ExpressionEncoder[T], + uEncoder: ExpressionEncoder[U], + groupingAttributes: Seq[Attribute], + output: Seq[Attribute], + child: SparkPlan) extends UnaryNode { + + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(groupingAttributes) :: Nil + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(groupingAttributes.map(SortOrder(_, Ascending))) + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitionsInternal { iter => + val grouped = GroupedIterator(iter, groupingAttributes, child.output) + val groupKeyEncoder = kEncoder.bind(groupingAttributes) + val groupDataEncoder = tEncoder.bind(child.output) + + grouped.flatMap { case (key, rowIter) => + val result = func( + groupKeyEncoder.fromRow(key), + rowIter.map(groupDataEncoder.fromRow)) + result.map(uEncoder.toRow) + } + } + } +} + +/** + * Co-groups the data from left and right children, and calls the function with each group and 2 + * iterators containing all elements in the group from left and right side. + * The result of this function is encoded and flattened before being output. 
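[Editor's note] `MapGroups` above requires its child to be clustered and sorted on the grouping attributes, so each group arrives as one contiguous run of rows that can be streamed through the user function. A rough sketch of that streaming group-by over a pre-sorted collection; plain Scala stands in for `GroupedIterator` and the encoders.

```scala
object SortedGroupIterationSketch {
  // Assumes `rows` is already sorted by key, mirroring MapGroups' requiredChildOrdering.
  def mapGroups[K, T, U](rows: Seq[(K, T)])(f: (K, Iterator[T]) => TraversableOnce[U]): Seq[U] = {
    if (rows.isEmpty) Seq.empty
    else {
      val out = Seq.newBuilder[U]
      var currentKey = rows.head._1
      var currentGroup = List.newBuilder[T]
      for ((k, v) <- rows) {
        if (k != currentKey) {                                  // key changed: previous group is complete
          out ++= f(currentKey, currentGroup.result().iterator)
          currentKey = k
          currentGroup = List.newBuilder[T]
        }
        currentGroup += v
      }
      out ++= f(currentKey, currentGroup.result().iterator)     // flush the last group
      out.result()
    }
  }

  def main(args: Array[String]): Unit = {
    val rows = Seq("a" -> 1, "a" -> 2, "b" -> 3)
    println(mapGroups(rows) { (k, vs) => Seq(k -> vs.sum) })    // List((a,3), (b,3))
  }
}
```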
+ */ +case class CoGroup[Key, Left, Right, Result]( + func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result], + keyEnc: ExpressionEncoder[Key], + leftEnc: ExpressionEncoder[Left], + rightEnc: ExpressionEncoder[Right], + resultEnc: ExpressionEncoder[Result], + output: Seq[Attribute], + leftGroup: Seq[Attribute], + rightGroup: Seq[Attribute], + left: SparkPlan, + right: SparkPlan) extends BinaryNode { + + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + leftGroup.map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) :: Nil + + override protected def doExecute(): RDD[InternalRow] = { + left.execute().zipPartitions(right.execute()) { (leftData, rightData) => + val leftGrouped = GroupedIterator(leftData, leftGroup, left.output) + val rightGrouped = GroupedIterator(rightData, rightGroup, right.output) + val boundKeyEnc = keyEnc.bind(leftGroup) + val boundLeftEnc = leftEnc.bind(left.output) + val boundRightEnc = rightEnc.bind(right.output) + + new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap { + case (key, leftResult, rightResult) => + val result = func( + boundKeyEnc.fromRow(key), + leftResult.map(boundLeftEnc.fromRow), + rightResult.map(boundRightEnc.fromRow)) + result.map(resultEnc.toRow) + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala new file mode 100644 index 000000000000..fee36f602389 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar + +import java.nio.{ByteBuffer, ByteOrder} + +import org.apache.spark.sql.catalyst.expressions.{MutableRow, UnsafeArrayData, UnsafeMapData, UnsafeRow} +import org.apache.spark.sql.execution.columnar.compression.CompressibleColumnAccessor +import org.apache.spark.sql.types._ + +/** + * An `Iterator` like trait used to extract values from columnar byte buffer. When a value is + * extracted from the buffer, instead of directly returning it, the value is set into some field of + * a [[MutableRow]]. In this way, boxing cost can be avoided by leveraging the setter methods + * for primitive values provided by [[MutableRow]]. 
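[Editor's note] The `ColumnAccessor` scaladoc above explains that extracted values are written straight into a mutable row through type-specific setters rather than returned, so primitive columns never box. A simplified illustration with a `ByteBuffer` of ints and a toy mutable row; the toy types are assumptions for the sketch, not the patch's `MutableRow`.

```scala
import java.nio.{ByteBuffer, ByteOrder}

object ColumnAccessorSketch {
  // Toy stand-in for MutableRow: a primitive-typed setter avoids boxing the value into AnyRef.
  final class ToyMutableRow(size: Int) {
    private val ints = new Array[Int](size)
    def setInt(ordinal: Int, value: Int): Unit = ints(ordinal) = value
    def getInt(ordinal: Int): Int = ints(ordinal)
  }

  // Toy stand-in for an int column accessor: pulls the next value out of the buffer
  // and sets it directly into the target row at `ordinal`.
  final class ToyIntColumnAccessor(buffer: ByteBuffer) {
    def hasNext: Boolean = buffer.hasRemaining
    def extractTo(row: ToyMutableRow, ordinal: Int): Unit = row.setInt(ordinal, buffer.getInt())
  }

  def main(args: Array[String]): Unit = {
    val buf = ByteBuffer.allocate(12).order(ByteOrder.nativeOrder())
    buf.putInt(10).putInt(20).putInt(30)
    buf.flip()

    val accessor = new ToyIntColumnAccessor(buf)
    val row = new ToyMutableRow(1)
    while (accessor.hasNext) {
      accessor.extractTo(row, 0)
      println(row.getInt(0))   // 10, 20, 30 -- the single row slot is reused per value
    }
  }
}
```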
+ */ +private[columnar] trait ColumnAccessor { + initialize() + + protected def initialize() + + def hasNext: Boolean + + def extractTo(row: MutableRow, ordinal: Int) + + protected def underlyingBuffer: ByteBuffer +} + +private[columnar] abstract class BasicColumnAccessor[JvmType]( + protected val buffer: ByteBuffer, + protected val columnType: ColumnType[JvmType]) + extends ColumnAccessor { + + protected def initialize() {} + + override def hasNext: Boolean = buffer.hasRemaining + + override def extractTo(row: MutableRow, ordinal: Int): Unit = { + extractSingle(row, ordinal) + } + + def extractSingle(row: MutableRow, ordinal: Int): Unit = { + columnType.extract(buffer, row, ordinal) + } + + protected def underlyingBuffer = buffer +} + +private[columnar] class NullColumnAccessor(buffer: ByteBuffer) + extends BasicColumnAccessor[Any](buffer, NULL) + with NullableColumnAccessor + +private[columnar] abstract class NativeColumnAccessor[T <: AtomicType]( + override protected val buffer: ByteBuffer, + override protected val columnType: NativeColumnType[T]) + extends BasicColumnAccessor(buffer, columnType) + with NullableColumnAccessor + with CompressibleColumnAccessor[T] + +private[columnar] class BooleanColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, BOOLEAN) + +private[columnar] class ByteColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, BYTE) + +private[columnar] class ShortColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, SHORT) + +private[columnar] class IntColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, INT) + +private[columnar] class LongColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, LONG) + +private[columnar] class FloatColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, FLOAT) + +private[columnar] class DoubleColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, DOUBLE) + +private[columnar] class StringColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, STRING) + +private[columnar] class BinaryColumnAccessor(buffer: ByteBuffer) + extends BasicColumnAccessor[Array[Byte]](buffer, BINARY) + with NullableColumnAccessor + +private[columnar] class CompactDecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType) + extends NativeColumnAccessor(buffer, COMPACT_DECIMAL(dataType)) + +private[columnar] class DecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType) + extends BasicColumnAccessor[Decimal](buffer, LARGE_DECIMAL(dataType)) + with NullableColumnAccessor + +private[columnar] class StructColumnAccessor(buffer: ByteBuffer, dataType: StructType) + extends BasicColumnAccessor[UnsafeRow](buffer, STRUCT(dataType)) + with NullableColumnAccessor + +private[columnar] class ArrayColumnAccessor(buffer: ByteBuffer, dataType: ArrayType) + extends BasicColumnAccessor[UnsafeArrayData](buffer, ARRAY(dataType)) + with NullableColumnAccessor + +private[columnar] class MapColumnAccessor(buffer: ByteBuffer, dataType: MapType) + extends BasicColumnAccessor[UnsafeMapData](buffer, MAP(dataType)) + with NullableColumnAccessor + +private[columnar] object ColumnAccessor { + def apply(dataType: DataType, buffer: ByteBuffer): ColumnAccessor = { + val buf = buffer.order(ByteOrder.nativeOrder) + + dataType match { + case NullType => new NullColumnAccessor(buf) + case BooleanType => new BooleanColumnAccessor(buf) + case ByteType => new ByteColumnAccessor(buf) + case ShortType => new ShortColumnAccessor(buf) + case 
IntegerType | DateType => new IntColumnAccessor(buf) + case LongType | TimestampType => new LongColumnAccessor(buf) + case FloatType => new FloatColumnAccessor(buf) + case DoubleType => new DoubleColumnAccessor(buf) + case StringType => new StringColumnAccessor(buf) + case BinaryType => new BinaryColumnAccessor(buf) + case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => + new CompactDecimalColumnAccessor(buf, dt) + case dt: DecimalType => new DecimalColumnAccessor(buf, dt) + case struct: StructType => new StructColumnAccessor(buf, struct) + case array: ArrayType => new ArrayColumnAccessor(buf, array) + case map: MapType => new MapColumnAccessor(buf, map) + case udt: UserDefinedType[_] => ColumnAccessor(udt.sqlType, buffer) + case other => + throw new Exception(s"not support type: $other") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala similarity index 50% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala index 1620fc401ba6..7e26f19bb744 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala @@ -15,16 +15,16 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.columnar.ColumnBuilder._ -import org.apache.spark.sql.columnar.compression.{AllCompressionSchemes, CompressibleColumnBuilder} +import org.apache.spark.sql.execution.columnar.ColumnBuilder._ +import org.apache.spark.sql.execution.columnar.compression.{AllCompressionSchemes, CompressibleColumnBuilder} import org.apache.spark.sql.types._ -private[sql] trait ColumnBuilder { +private[columnar] trait ColumnBuilder { /** * Initializes with an approximate lower bound on the expected number of elements in this column. 
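[Editor's note] The `ColumnBuilder` changes below size the buffer from an approximate element count, grow it by doubling instead of in capacity/8 steps, and trim the excess at `build()`. A stand-alone sketch of the doubling growth step under those assumptions (toy code, not the patch's `ensureFreeSpace`):

```scala
import java.nio.{ByteBuffer, ByteOrder}

object ColumnBufferGrowthSketch {
  // Grows `orig` so that at least `size` more bytes fit, stepping by the current capacity
  // (i.e. roughly doubling) rather than by capacity / 8.
  def ensureFreeSpace(orig: ByteBuffer, size: Int): ByteBuffer = {
    if (orig.remaining() >= size) {
      orig
    } else {
      val newSize = orig.capacity() + math.max(size, orig.capacity())
      val pos = orig.position()
      ByteBuffer.allocate(newSize)
        .order(ByteOrder.nativeOrder())
        .put(orig.array(), 0, pos)
    }
  }

  def main(args: Array[String]): Unit = {
    var buf = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder())
    for (i <- 0 until 10) {
      buf = ensureFreeSpace(buf, 4)   // capacity doubles 8 -> 16 -> 32 -> 64 as needed
      buf.putInt(i)
    }
    println(buf.capacity())  // 64
    println(buf.position())  // 40 bytes = 10 ints
  }
}
```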
*/ @@ -46,7 +46,7 @@ private[sql] trait ColumnBuilder { def build(): ByteBuffer } -private[sql] class BasicColumnBuilder[JvmType]( +private[columnar] class BasicColumnBuilder[JvmType]( val columnStats: ColumnStats, val columnType: ColumnType[JvmType]) extends ColumnBuilder { @@ -63,9 +63,8 @@ private[sql] class BasicColumnBuilder[JvmType]( val size = if (initialSize == 0) DEFAULT_INITIAL_BUFFER_SIZE else initialSize this.columnName = columnName - // Reserves 4 bytes for column type ID - buffer = ByteBuffer.allocate(4 + size * columnType.defaultSize) - buffer.order(ByteOrder.nativeOrder()).putInt(columnType.typeId) + buffer = ByteBuffer.allocate(size * columnType.defaultSize) + buffer.order(ByteOrder.nativeOrder()) } override def appendFrom(row: InternalRow, ordinal: Int): Unit = { @@ -74,17 +73,28 @@ private[sql] class BasicColumnBuilder[JvmType]( } override def build(): ByteBuffer = { + if (buffer.capacity() > buffer.position() * 1.1) { + // trim the buffer + buffer = ByteBuffer + .allocate(buffer.position()) + .order(ByteOrder.nativeOrder()) + .put(buffer.array(), 0, buffer.position()) + } buffer.flip().asInstanceOf[ByteBuffer] } } -private[sql] abstract class ComplexColumnBuilder[JvmType]( +private[columnar] class NullColumnBuilder + extends BasicColumnBuilder[Any](new ObjectColumnStats(NullType), NULL) + with NullableColumnBuilder + +private[columnar] abstract class ComplexColumnBuilder[JvmType]( columnStats: ColumnStats, columnType: ColumnType[JvmType]) extends BasicColumnBuilder[JvmType](columnStats, columnType) with NullableColumnBuilder -private[sql] abstract class NativeColumnBuilder[T <: AtomicType]( +private[columnar] abstract class NativeColumnBuilder[T <: AtomicType]( override val columnStats: ColumnStats, override val columnType: NativeColumnType[T]) extends BasicColumnBuilder[T#InternalType](columnStats, columnType) @@ -92,42 +102,47 @@ private[sql] abstract class NativeColumnBuilder[T <: AtomicType]( with AllCompressionSchemes with CompressibleColumnBuilder[T] -private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(new BooleanColumnStats, BOOLEAN) +private[columnar] +class BooleanColumnBuilder extends NativeColumnBuilder(new BooleanColumnStats, BOOLEAN) + +private[columnar] +class ByteColumnBuilder extends NativeColumnBuilder(new ByteColumnStats, BYTE) -private[sql] class ByteColumnBuilder extends NativeColumnBuilder(new ByteColumnStats, BYTE) +private[columnar] class ShortColumnBuilder extends NativeColumnBuilder(new ShortColumnStats, SHORT) -private[sql] class ShortColumnBuilder extends NativeColumnBuilder(new ShortColumnStats, SHORT) +private[columnar] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT) -private[sql] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT) +private[columnar] class LongColumnBuilder extends NativeColumnBuilder(new LongColumnStats, LONG) -private[sql] class LongColumnBuilder extends NativeColumnBuilder(new LongColumnStats, LONG) +private[columnar] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColumnStats, FLOAT) -private[sql] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColumnStats, FLOAT) +private[columnar] +class DoubleColumnBuilder extends NativeColumnBuilder(new DoubleColumnStats, DOUBLE) -private[sql] class DoubleColumnBuilder extends NativeColumnBuilder(new DoubleColumnStats, DOUBLE) +private[columnar] +class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING) -private[sql] class StringColumnBuilder extends 
NativeColumnBuilder(new StringColumnStats, STRING) +private[columnar] +class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BINARY) -private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BINARY) +private[columnar] class CompactDecimalColumnBuilder(dataType: DecimalType) + extends NativeColumnBuilder(new DecimalColumnStats(dataType), COMPACT_DECIMAL(dataType)) -private[sql] class FixedDecimalColumnBuilder( - precision: Int, - scale: Int) - extends NativeColumnBuilder( - new FixedDecimalColumnStats(precision, scale), - FIXED_DECIMAL(precision, scale)) +private[columnar] class DecimalColumnBuilder(dataType: DecimalType) + extends ComplexColumnBuilder(new DecimalColumnStats(dataType), LARGE_DECIMAL(dataType)) -// TODO (lian) Add support for array, struct and map -private[sql] class GenericColumnBuilder(dataType: DataType) - extends ComplexColumnBuilder(new GenericColumnStats(dataType), GENERIC(dataType)) +private[columnar] class StructColumnBuilder(dataType: StructType) + extends ComplexColumnBuilder(new ObjectColumnStats(dataType), STRUCT(dataType)) -private[sql] class DateColumnBuilder extends NativeColumnBuilder(new DateColumnStats, DATE) +private[columnar] class ArrayColumnBuilder(dataType: ArrayType) + extends ComplexColumnBuilder(new ObjectColumnStats(dataType), ARRAY(dataType)) -private[sql] class TimestampColumnBuilder - extends NativeColumnBuilder(new TimestampColumnStats, TIMESTAMP) +private[columnar] class MapColumnBuilder(dataType: MapType) + extends ComplexColumnBuilder(new ObjectColumnStats(dataType), MAP(dataType)) -private[sql] object ColumnBuilder { - val DEFAULT_INITIAL_BUFFER_SIZE = 1024 * 1024 +private[columnar] object ColumnBuilder { + val DEFAULT_INITIAL_BUFFER_SIZE = 128 * 1024 + val MAX_BATCH_SIZE_IN_BYTE = 4 * 1024 * 1024L private[columnar] def ensureFreeSpace(orig: ByteBuffer, size: Int) = { if (orig.remaining >= size) { @@ -135,7 +150,7 @@ private[sql] object ColumnBuilder { } else { // grow in steps of initial size val capacity = orig.capacity() - val newSize = capacity + size.max(capacity / 8 + 1) + val newSize = capacity + size.max(capacity) val pos = orig.position() ByteBuffer @@ -151,20 +166,26 @@ private[sql] object ColumnBuilder { columnName: String = "", useCompression: Boolean = false): ColumnBuilder = { val builder: ColumnBuilder = dataType match { + case NullType => new NullColumnBuilder case BooleanType => new BooleanColumnBuilder case ByteType => new ByteColumnBuilder case ShortType => new ShortColumnBuilder - case IntegerType => new IntColumnBuilder - case DateType => new DateColumnBuilder - case LongType => new LongColumnBuilder - case TimestampType => new TimestampColumnBuilder + case IntegerType | DateType => new IntColumnBuilder + case LongType | TimestampType => new LongColumnBuilder case FloatType => new FloatColumnBuilder case DoubleType => new DoubleColumnBuilder case StringType => new StringColumnBuilder case BinaryType => new BinaryColumnBuilder - case DecimalType.Fixed(precision, scale) if precision < 19 => - new FixedDecimalColumnBuilder(precision, scale) - case other => new GenericColumnBuilder(other) + case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => + new CompactDecimalColumnBuilder(dt) + case dt: DecimalType => new DecimalColumnBuilder(dt) + case struct: StructType => new StructColumnBuilder(struct) + case array: ArrayType => new ArrayColumnBuilder(array) + case map: MapType => new MapColumnBuilder(map) + case udt: UserDefinedType[_] => + return 
apply(udt.sqlType, initialSize, columnName, useCompression) + case other => + throw new Exception(s"not suppported type: $other") } builder.initialize(initialSize, columnName, useCompression) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala similarity index 67% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala index af1a8ecca9b5..c52ee9ffd6d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala @@ -15,14 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -private[sql] class ColumnStatisticsSchema(a: Attribute) extends Serializable { +private[columnar] class ColumnStatisticsSchema(a: Attribute) extends Serializable { val upperBound = AttributeReference(a.name + ".upperBound", a.dataType, nullable = true)() val lowerBound = AttributeReference(a.name + ".lowerBound", a.dataType, nullable = true)() val nullCount = AttributeReference(a.name + ".nullCount", IntegerType, nullable = false)() @@ -32,7 +32,7 @@ private[sql] class ColumnStatisticsSchema(a: Attribute) extends Serializable { val schema = Seq(lowerBound, upperBound, nullCount, count, sizeInBytes) } -private[sql] class PartitionStatistics(tableSchema: Seq[Attribute]) extends Serializable { +private[columnar] class PartitionStatistics(tableSchema: Seq[Attribute]) extends Serializable { val (forAttribute, schema) = { val allStats = tableSchema.map(a => a -> new ColumnStatisticsSchema(a)) (AttributeMap(allStats), allStats.map(_._2.schema).foldLeft(Seq.empty[Attribute])(_ ++ _)) @@ -45,10 +45,10 @@ private[sql] class PartitionStatistics(tableSchema: Seq[Attribute]) extends Seri * NOTE: we intentionally avoid using `Ordering[T]` to compare values here because `Ordering[T]` * brings significant performance penalty. */ -private[sql] sealed trait ColumnStats extends Serializable { +private[columnar] sealed trait ColumnStats extends Serializable { protected var count = 0 protected var nullCount = 0 - protected var sizeInBytes = 0L + private[columnar] var sizeInBytes = 0L /** * Gathers statistics information from `row(ordinal)`. @@ -66,19 +66,20 @@ private[sql] sealed trait ColumnStats extends Serializable { * Column statistics represented as a single row, currently including closed lower bound, closed * upper bound and null count. */ - def collectedStatistics: InternalRow + def collectedStatistics: GenericInternalRow } /** * A no-op ColumnStats only used for testing purposes. 
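[Editor's note] Each concrete `ColumnStats` below tracks a closed lower bound, a closed upper bound, a null count, a row count, and a byte size for one column, and exposes them as a single stats row. A stand-alone sketch of the same bookkeeping for an integer column; the toy types and the per-null byte accounting are assumptions of the sketch, not the patch's `GenericInternalRow` behaviour.

```scala
object IntColumnStatsSketch {
  final class ToyIntColumnStats {
    private var count = 0
    private var nullCount = 0
    private var sizeInBytes = 0L
    private var lower = Int.MaxValue
    private var upper = Int.MinValue

    // Gathers statistics from one value of the column (None models a SQL NULL).
    def gatherStats(value: Option[Int]): Unit = {
      count += 1
      value match {
        case Some(v) =>
          if (v > upper) upper = v
          if (v < lower) lower = v
          sizeInBytes += 4            // fixed-width INT
        case None =>
          nullCount += 1
          sizeInBytes += 4            // this toy model still charges space for a null slot
      }
    }

    // (lowerBound, upperBound, nullCount, count, sizeInBytes), like collectedStatistics.
    def collectedStatistics: (Int, Int, Int, Int, Long) =
      (lower, upper, nullCount, count, sizeInBytes)
  }

  def main(args: Array[String]): Unit = {
    val stats = new ToyIntColumnStats
    Seq(Some(3), None, Some(7), Some(-1)).foreach(stats.gatherStats)
    println(stats.collectedStatistics)  // (-1, 7, 1, 4, 16)
  }
}
```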
*/ -private[sql] class NoopColumnStats extends ColumnStats { +private[columnar] class NoopColumnStats extends ColumnStats { override def gatherStats(row: InternalRow, ordinal: Int): Unit = super.gatherStats(row, ordinal) - override def collectedStatistics: InternalRow = InternalRow(null, null, nullCount, count, 0L) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](null, null, nullCount, count, 0L)) } -private[sql] class BooleanColumnStats extends ColumnStats { +private[columnar] class BooleanColumnStats extends ColumnStats { protected var upper = false protected var lower = true @@ -92,11 +93,11 @@ private[sql] class BooleanColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class ByteColumnStats extends ColumnStats { +private[columnar] class ByteColumnStats extends ColumnStats { protected var upper = Byte.MinValue protected var lower = Byte.MaxValue @@ -110,11 +111,11 @@ private[sql] class ByteColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class ShortColumnStats extends ColumnStats { +private[columnar] class ShortColumnStats extends ColumnStats { protected var upper = Short.MinValue protected var lower = Short.MaxValue @@ -128,11 +129,11 @@ private[sql] class ShortColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class IntColumnStats extends ColumnStats { +private[columnar] class IntColumnStats extends ColumnStats { protected var upper = Int.MinValue protected var lower = Int.MaxValue @@ -146,11 +147,11 @@ private[sql] class IntColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class LongColumnStats extends ColumnStats { +private[columnar] class LongColumnStats extends ColumnStats { protected var upper = Long.MinValue protected var lower = Long.MaxValue @@ -164,11 +165,11 @@ private[sql] class LongColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class FloatColumnStats extends ColumnStats { +private[columnar] class FloatColumnStats extends ColumnStats { protected var upper = Float.MinValue protected var lower = Float.MaxValue @@ -182,11 +183,11 @@ private[sql] class FloatColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new 
GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class DoubleColumnStats extends ColumnStats { +private[columnar] class DoubleColumnStats extends ColumnStats { protected var upper = Double.MinValue protected var lower = Double.MaxValue @@ -200,11 +201,11 @@ private[sql] class DoubleColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class StringColumnStats extends ColumnStats { +private[columnar] class StringColumnStats extends ColumnStats { protected var upper: UTF8String = null protected var lower: UTF8String = null @@ -212,17 +213,17 @@ private[sql] class StringColumnStats extends ColumnStats { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getUTF8String(ordinal) - if (upper == null || value.compareTo(upper) > 0) upper = value - if (lower == null || value.compareTo(lower) < 0) lower = value + if (upper == null || value.compareTo(upper) > 0) upper = value.clone() + if (lower == null || value.compareTo(lower) < 0) lower = value.clone() sizeInBytes += STRING.actualSize(row, ordinal) } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class BinaryColumnStats extends ColumnStats { +private[columnar] class BinaryColumnStats extends ColumnStats { override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { @@ -230,11 +231,13 @@ private[sql] class BinaryColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(null, null, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes)) } -private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends ColumnStats { +private[columnar] class DecimalColumnStats(precision: Int, scale: Int) extends ColumnStats { + def this(dt: DecimalType) = this(dt.precision, dt.scale) + protected var upper: Decimal = null protected var lower: Decimal = null @@ -244,16 +247,17 @@ private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends C val value = row.getDecimal(ordinal, precision, scale) if (upper == null || value.compareTo(upper) > 0) upper = value if (lower == null || value.compareTo(lower) < 0) lower = value - sizeInBytes += FIXED_DECIMAL.defaultSize + // TODO: this is not right for DecimalType with precision > 18 + sizeInBytes += 8 } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class GenericColumnStats(dataType: DataType) extends ColumnStats { - val columnType = GENERIC(dataType) +private[columnar] class ObjectColumnStats(dataType: DataType) extends ColumnStats { + val columnType = ColumnType(dataType) override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) @@ -262,10 +266,6 @@ private[sql] class 
GenericColumnStats(dataType: DataType) extends ColumnStats { } - override def collectedStatistics: InternalRow = - InternalRow(null, null, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes)) } - -private[sql] class DateColumnStats extends IntColumnStats - -private[sql] class TimestampColumnStats extends LongColumnStats diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala new file mode 100644 index 000000000000..c9f2329db4b6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala @@ -0,0 +1,689 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar + +import java.math.{BigDecimal, BigInteger} +import java.nio.ByteBuffer + +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.types.UTF8String + + +/** + * A helper class for fast reads of Int/Long/Float/Double from a ByteBuffer in native order. + * + * Note: There is not much difference between ByteBuffer.getByte/getShort and + * Unsafe.getByte/getShort, so we do not have helper methods for them. + * + * The unrolling (building the columnar cache) is already slow, so putLong/putDouble will not help much, + * and we do not have helper methods for them either. + * + * + * WARNING: This only works with HeapByteBuffer + */ +private[columnar] object ByteBufferHelper { + def getInt(buffer: ByteBuffer): Int = { + val pos = buffer.position() + buffer.position(pos + 4) + Platform.getInt(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos) + } + + def getLong(buffer: ByteBuffer): Long = { + val pos = buffer.position() + buffer.position(pos + 8) + Platform.getLong(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos) + } + + def getFloat(buffer: ByteBuffer): Float = { + val pos = buffer.position() + buffer.position(pos + 4) + Platform.getFloat(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos) + } + + def getDouble(buffer: ByteBuffer): Double = { + val pos = buffer.position() + buffer.position(pos + 8) + Platform.getDouble(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos) + } +} + +/** + * An abstract class that represents the type of a column. Used to append/extract Java objects into/from + * the underlying [[ByteBuffer]] of a column. + * + * @tparam JvmType Underlying Java type to represent the elements.
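To make the helper's contract concrete, here is a plain-JDK approximation of the same pattern: the value is read straight out of the backing array at the current position, and the position is then advanced by hand. The sketch deliberately avoids Platform/Unsafe, so it only illustrates why the fast path needs an array-backed (heap) buffer; HeapBufferReads is a hypothetical name, not a class from the patch.

import java.nio.{ByteBuffer, ByteOrder}

// Sketch of the ByteBufferHelper pattern without Unsafe: read the 4 bytes at
// the current position from the backing array, then advance the position.
object HeapBufferReads {
  def getInt(buffer: ByteBuffer): Int = {
    require(buffer.hasArray, "only works with array-backed (heap) buffers")
    val pos = buffer.position()
    buffer.position(pos + 4)   // consume the bytes we are about to read
    ByteBuffer.wrap(buffer.array(), buffer.arrayOffset() + pos, 4)
      .order(buffer.order())
      .getInt()
  }

  def main(args: Array[String]): Unit = {
    val buf = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder())
    buf.putInt(7).putInt(11)
    buf.flip()
    println(getInt(buf))   // 7
    println(getInt(buf))   // 11
  }
}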
+ */ +private[columnar] sealed abstract class ColumnType[JvmType] { + + // The catalyst data type of this column. + def dataType: DataType + + // Default size in bytes for one element of type T (e.g. 4 for `Int`). + def defaultSize: Int + + /** + * Extracts a value out of the buffer at the buffer's current position. + */ + def extract(buffer: ByteBuffer): JvmType + + /** + * Extracts a value out of the buffer at the buffer's current position and stores in + * `row(ordinal)`. Subclasses should override this method to avoid boxing/unboxing costs whenever + * possible. + */ + def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + setField(row, ordinal, extract(buffer)) + } + + /** + * Appends the given value v of type T into the given ByteBuffer. + */ + def append(v: JvmType, buffer: ByteBuffer): Unit + + /** + * Appends `row(ordinal)` of type T into the given ByteBuffer. Subclasses should override this + * method to avoid boxing/unboxing costs whenever possible. + */ + def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { + append(getField(row, ordinal), buffer) + } + + /** + * Returns the size of the value `row(ordinal)`. This is used to calculate the size of variable + * length types such as byte arrays and strings. + */ + def actualSize(row: InternalRow, ordinal: Int): Int = defaultSize + + /** + * Returns `row(ordinal)`. Subclasses should override this method to avoid boxing/unboxing costs + * whenever possible. + */ + def getField(row: InternalRow, ordinal: Int): JvmType + + /** + * Sets `row(ordinal)` to `field`. Subclasses should override this method to avoid boxing/unboxing + * costs whenever possible. + */ + def setField(row: MutableRow, ordinal: Int, value: JvmType): Unit + + /** + * Copies `from(fromOrdinal)` to `to(toOrdinal)`. Subclasses should override this method to avoid + * boxing/unboxing costs whenever possible. + */ + def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + setField(to, toOrdinal, getField(from, fromOrdinal)) + } + + /** + * Creates a duplicated copy of the value. + */ + def clone(v: JvmType): JvmType = v + + override def toString: String = getClass.getSimpleName.stripSuffix("$") +} + +private[columnar] object NULL extends ColumnType[Any] { + + override def dataType: DataType = NullType + override def defaultSize: Int = 0 + override def append(v: Any, buffer: ByteBuffer): Unit = {} + override def extract(buffer: ByteBuffer): Any = null + override def setField(row: MutableRow, ordinal: Int, value: Any): Unit = row.setNullAt(ordinal) + override def getField(row: InternalRow, ordinal: Int): Any = null +} + +private[columnar] abstract class NativeColumnType[T <: AtomicType]( + val dataType: T, + val defaultSize: Int) + extends ColumnType[T#InternalType] { + + /** + * Scala TypeTag. Can be used to create primitive arrays and hash tables. 
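The core contract of these column types is that append and extract are symmetric over a shared ByteBuffer. A minimal stand-in, with the InternalRow plumbing stripped away, could look like the sketch below; SimpleColumnType and SimpleLongType are illustrative names, not classes from the patch.

import java.nio.{ByteBuffer, ByteOrder}

// A stripped-down stand-in for a NativeColumnType: fixed-size values are
// appended at the buffer's write position and extracted back in order.
trait SimpleColumnType[T] {
  def defaultSize: Int
  def append(v: T, buffer: ByteBuffer): Unit
  def extract(buffer: ByteBuffer): T
}

object SimpleLongType extends SimpleColumnType[Long] {
  val defaultSize = 8
  def append(v: Long, buffer: ByteBuffer): Unit = buffer.putLong(v)
  def extract(buffer: ByteBuffer): Long = buffer.getLong()
}

object RoundTripDemo extends App {
  val values = Seq(1L, -5L, 42L)
  val buf = ByteBuffer.allocate(values.length * SimpleLongType.defaultSize)
    .order(ByteOrder.nativeOrder())
  values.foreach(SimpleLongType.append(_, buf))
  buf.flip()
  println(Seq.fill(values.length)(SimpleLongType.extract(buf)))  // List(1, -5, 42)
}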
+ */ + def scalaTag: TypeTag[dataType.InternalType] = dataType.tag +} + +private[columnar] object INT extends NativeColumnType(IntegerType, 4) { + override def append(v: Int, buffer: ByteBuffer): Unit = { + buffer.putInt(v) + } + + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.putInt(row.getInt(ordinal)) + } + + override def extract(buffer: ByteBuffer): Int = { + ByteBufferHelper.getInt(buffer) + } + + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + row.setInt(ordinal, ByteBufferHelper.getInt(buffer)) + } + + override def setField(row: MutableRow, ordinal: Int, value: Int): Unit = { + row.setInt(ordinal, value) + } + + override def getField(row: InternalRow, ordinal: Int): Int = row.getInt(ordinal) + + + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { + to.setInt(toOrdinal, from.getInt(fromOrdinal)) + } +} + +private[columnar] object LONG extends NativeColumnType(LongType, 8) { + override def append(v: Long, buffer: ByteBuffer): Unit = { + buffer.putLong(v) + } + + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.putLong(row.getLong(ordinal)) + } + + override def extract(buffer: ByteBuffer): Long = { + ByteBufferHelper.getLong(buffer) + } + + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + row.setLong(ordinal, ByteBufferHelper.getLong(buffer)) + } + + override def setField(row: MutableRow, ordinal: Int, value: Long): Unit = { + row.setLong(ordinal, value) + } + + override def getField(row: InternalRow, ordinal: Int): Long = row.getLong(ordinal) + + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { + to.setLong(toOrdinal, from.getLong(fromOrdinal)) + } +} + +private[columnar] object FLOAT extends NativeColumnType(FloatType, 4) { + override def append(v: Float, buffer: ByteBuffer): Unit = { + buffer.putFloat(v) + } + + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.putFloat(row.getFloat(ordinal)) + } + + override def extract(buffer: ByteBuffer): Float = { + ByteBufferHelper.getFloat(buffer) + } + + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + row.setFloat(ordinal, ByteBufferHelper.getFloat(buffer)) + } + + override def setField(row: MutableRow, ordinal: Int, value: Float): Unit = { + row.setFloat(ordinal, value) + } + + override def getField(row: InternalRow, ordinal: Int): Float = row.getFloat(ordinal) + + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { + to.setFloat(toOrdinal, from.getFloat(fromOrdinal)) + } +} + +private[columnar] object DOUBLE extends NativeColumnType(DoubleType, 8) { + override def append(v: Double, buffer: ByteBuffer): Unit = { + buffer.putDouble(v) + } + + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.putDouble(row.getDouble(ordinal)) + } + + override def extract(buffer: ByteBuffer): Double = { + ByteBufferHelper.getDouble(buffer) + } + + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + row.setDouble(ordinal, ByteBufferHelper.getDouble(buffer)) + } + + override def setField(row: MutableRow, ordinal: Int, value: Double): Unit = { + row.setDouble(ordinal, value) + } + + override def getField(row: InternalRow, ordinal: Int): Double = row.getDouble(ordinal) + + override def copyField(from: InternalRow, fromOrdinal: Int, 
to: MutableRow, toOrdinal: Int) { + to.setDouble(toOrdinal, from.getDouble(fromOrdinal)) + } +} + +private[columnar] object BOOLEAN extends NativeColumnType(BooleanType, 1) { + override def append(v: Boolean, buffer: ByteBuffer): Unit = { + buffer.put(if (v) 1: Byte else 0: Byte) + } + + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.put(if (row.getBoolean(ordinal)) 1: Byte else 0: Byte) + } + + override def extract(buffer: ByteBuffer): Boolean = buffer.get() == 1 + + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + row.setBoolean(ordinal, buffer.get() == 1) + } + + override def setField(row: MutableRow, ordinal: Int, value: Boolean): Unit = { + row.setBoolean(ordinal, value) + } + + override def getField(row: InternalRow, ordinal: Int): Boolean = row.getBoolean(ordinal) + + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { + to.setBoolean(toOrdinal, from.getBoolean(fromOrdinal)) + } +} + +private[columnar] object BYTE extends NativeColumnType(ByteType, 1) { + override def append(v: Byte, buffer: ByteBuffer): Unit = { + buffer.put(v) + } + + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.put(row.getByte(ordinal)) + } + + override def extract(buffer: ByteBuffer): Byte = { + buffer.get() + } + + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + row.setByte(ordinal, buffer.get()) + } + + override def setField(row: MutableRow, ordinal: Int, value: Byte): Unit = { + row.setByte(ordinal, value) + } + + override def getField(row: InternalRow, ordinal: Int): Byte = row.getByte(ordinal) + + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { + to.setByte(toOrdinal, from.getByte(fromOrdinal)) + } +} + +private[columnar] object SHORT extends NativeColumnType(ShortType, 2) { + override def append(v: Short, buffer: ByteBuffer): Unit = { + buffer.putShort(v) + } + + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.putShort(row.getShort(ordinal)) + } + + override def extract(buffer: ByteBuffer): Short = { + buffer.getShort() + } + + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + row.setShort(ordinal, buffer.getShort()) + } + + override def setField(row: MutableRow, ordinal: Int, value: Short): Unit = { + row.setShort(ordinal, value) + } + + override def getField(row: InternalRow, ordinal: Int): Short = row.getShort(ordinal) + + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { + to.setShort(toOrdinal, from.getShort(fromOrdinal)) + } +} + +/** + * A fast path to copy var-length bytes between ByteBuffer and UnsafeRow without creating wrapper + * objects. 
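The variable-length column types that follow (STRING, BINARY, LARGE_DECIMAL and the complex types) share a length-prefixed layout: a 4-byte length followed by the payload, which is why their actualSize is the payload size plus 4. Below is a self-contained sketch of that layout on plain JDK strings; it assumes nothing from Spark and is only an illustration.

import java.nio.{ByteBuffer, ByteOrder}
import java.nio.charset.StandardCharsets

// Length-prefixed layout used by the var-length column types:
// [ 4-byte length | payload bytes ], so actualSize = payload.length + 4.
object VarLenColumn {
  def append(s: String, buffer: ByteBuffer): Unit = {
    val bytes = s.getBytes(StandardCharsets.UTF_8)
    buffer.putInt(bytes.length).put(bytes)
  }

  def extract(buffer: ByteBuffer): String = {
    val length = buffer.getInt()
    val bytes = new Array[Byte](length)
    buffer.get(bytes)
    new String(bytes, StandardCharsets.UTF_8)
  }

  def main(args: Array[String]): Unit = {
    val buf = ByteBuffer.allocate(64).order(ByteOrder.nativeOrder())
    Seq("spark", "columnar").foreach(append(_, buf))
    buf.flip()
    println(extract(buf) + ", " + extract(buf))   // spark, columnar
  }
}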
+ */ +private[columnar] trait DirectCopyColumnType[JvmType] extends ColumnType[JvmType] { + + // copy the bytes from ByteBuffer to UnsafeRow + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + if (row.isInstanceOf[MutableUnsafeRow]) { + val numBytes = buffer.getInt + val cursor = buffer.position() + buffer.position(cursor + numBytes) + row.asInstanceOf[MutableUnsafeRow].writer.write(ordinal, buffer.array(), + buffer.arrayOffset() + cursor, numBytes) + } else { + setField(row, ordinal, extract(buffer)) + } + } + + // copy the bytes from UnsafeRow to ByteBuffer + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { + if (row.isInstanceOf[UnsafeRow]) { + row.asInstanceOf[UnsafeRow].writeFieldTo(ordinal, buffer) + } else { + super.append(row, ordinal, buffer) + } + } +} + +private[columnar] object STRING + extends NativeColumnType(StringType, 8) with DirectCopyColumnType[UTF8String] { + + override def actualSize(row: InternalRow, ordinal: Int): Int = { + row.getUTF8String(ordinal).numBytes() + 4 + } + + override def append(v: UTF8String, buffer: ByteBuffer): Unit = { + buffer.putInt(v.numBytes()) + v.writeTo(buffer) + } + + override def extract(buffer: ByteBuffer): UTF8String = { + val length = buffer.getInt() + val cursor = buffer.position() + buffer.position(cursor + length) + UTF8String.fromBytes(buffer.array(), buffer.arrayOffset() + cursor, length) + } + + override def setField(row: MutableRow, ordinal: Int, value: UTF8String): Unit = { + if (row.isInstanceOf[MutableUnsafeRow]) { + row.asInstanceOf[MutableUnsafeRow].writer.write(ordinal, value) + } else { + row.update(ordinal, value.clone()) + } + } + + override def getField(row: InternalRow, ordinal: Int): UTF8String = { + row.getUTF8String(ordinal) + } + + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { + setField(to, toOrdinal, getField(from, fromOrdinal)) + } + + override def clone(v: UTF8String): UTF8String = v.clone() +} + +private[columnar] case class COMPACT_DECIMAL(precision: Int, scale: Int) + extends NativeColumnType(DecimalType(precision, scale), 8) { + + override def extract(buffer: ByteBuffer): Decimal = { + Decimal(ByteBufferHelper.getLong(buffer), precision, scale) + } + + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + if (row.isInstanceOf[MutableUnsafeRow]) { + // copy it as Long + row.setLong(ordinal, ByteBufferHelper.getLong(buffer)) + } else { + setField(row, ordinal, extract(buffer)) + } + } + + override def append(v: Decimal, buffer: ByteBuffer): Unit = { + buffer.putLong(v.toUnscaledLong) + } + + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { + if (row.isInstanceOf[UnsafeRow]) { + // copy it as Long + buffer.putLong(row.getLong(ordinal)) + } else { + append(getField(row, ordinal), buffer) + } + } + + override def getField(row: InternalRow, ordinal: Int): Decimal = { + row.getDecimal(ordinal, precision, scale) + } + + override def setField(row: MutableRow, ordinal: Int, value: Decimal): Unit = { + row.setDecimal(ordinal, value, precision) + } + + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { + setField(to, toOrdinal, getField(from, fromOrdinal)) + } +} + +private[columnar] object COMPACT_DECIMAL { + def apply(dt: DecimalType): COMPACT_DECIMAL = { + COMPACT_DECIMAL(dt.precision, dt.scale) + } +} + +private[columnar] sealed abstract class ByteArrayColumnType[JvmType](val defaultSize: Int) + 
extends ColumnType[JvmType] with DirectCopyColumnType[JvmType] { + + def serialize(value: JvmType): Array[Byte] + def deserialize(bytes: Array[Byte]): JvmType + + override def append(v: JvmType, buffer: ByteBuffer): Unit = { + val bytes = serialize(v) + buffer.putInt(bytes.length).put(bytes, 0, bytes.length) + } + + override def extract(buffer: ByteBuffer): JvmType = { + val length = buffer.getInt() + val bytes = new Array[Byte](length) + buffer.get(bytes, 0, length) + deserialize(bytes) + } +} + +private[columnar] object BINARY extends ByteArrayColumnType[Array[Byte]](16) { + + def dataType: DataType = BinaryType + + override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = { + row.update(ordinal, value) + } + + override def getField(row: InternalRow, ordinal: Int): Array[Byte] = { + row.getBinary(ordinal) + } + + override def actualSize(row: InternalRow, ordinal: Int): Int = { + row.getBinary(ordinal).length + 4 + } + + def serialize(value: Array[Byte]): Array[Byte] = value + def deserialize(bytes: Array[Byte]): Array[Byte] = bytes +} + +private[columnar] case class LARGE_DECIMAL(precision: Int, scale: Int) + extends ByteArrayColumnType[Decimal](12) { + + override val dataType: DataType = DecimalType(precision, scale) + + override def getField(row: InternalRow, ordinal: Int): Decimal = { + row.getDecimal(ordinal, precision, scale) + } + + override def setField(row: MutableRow, ordinal: Int, value: Decimal): Unit = { + row.setDecimal(ordinal, value, precision) + } + + override def actualSize(row: InternalRow, ordinal: Int): Int = { + 4 + getField(row, ordinal).toJavaBigDecimal.unscaledValue().bitLength() / 8 + 1 + } + + override def serialize(value: Decimal): Array[Byte] = { + value.toJavaBigDecimal.unscaledValue().toByteArray + } + + override def deserialize(bytes: Array[Byte]): Decimal = { + val javaDecimal = new BigDecimal(new BigInteger(bytes), scale) + Decimal.apply(javaDecimal, precision, scale) + } +} + +private[columnar] object LARGE_DECIMAL { + def apply(dt: DecimalType): LARGE_DECIMAL = { + LARGE_DECIMAL(dt.precision, dt.scale) + } +} + +private[columnar] case class STRUCT(dataType: StructType) + extends ColumnType[UnsafeRow] with DirectCopyColumnType[UnsafeRow] { + + private val numOfFields: Int = dataType.fields.size + + override def defaultSize: Int = 20 + + override def setField(row: MutableRow, ordinal: Int, value: UnsafeRow): Unit = { + row.update(ordinal, value) + } + + override def getField(row: InternalRow, ordinal: Int): UnsafeRow = { + row.getStruct(ordinal, numOfFields).asInstanceOf[UnsafeRow] + } + + override def actualSize(row: InternalRow, ordinal: Int): Int = { + 4 + getField(row, ordinal).getSizeInBytes + } + + override def append(value: UnsafeRow, buffer: ByteBuffer): Unit = { + buffer.putInt(value.getSizeInBytes) + value.writeTo(buffer) + } + + override def extract(buffer: ByteBuffer): UnsafeRow = { + val sizeInBytes = ByteBufferHelper.getInt(buffer) + assert(buffer.hasArray) + val cursor = buffer.position() + buffer.position(cursor + sizeInBytes) + val unsafeRow = new UnsafeRow + unsafeRow.pointTo( + buffer.array(), + Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor, + numOfFields, + sizeInBytes) + unsafeRow + } + + override def clone(v: UnsafeRow): UnsafeRow = v.copy() +} + +private[columnar] case class ARRAY(dataType: ArrayType) + extends ColumnType[UnsafeArrayData] with DirectCopyColumnType[UnsafeArrayData] { + + override def defaultSize: Int = 16 + + override def setField(row: MutableRow, ordinal: Int, value: 
UnsafeArrayData): Unit = { + row.update(ordinal, value) + } + + override def getField(row: InternalRow, ordinal: Int): UnsafeArrayData = { + row.getArray(ordinal).asInstanceOf[UnsafeArrayData] + } + + override def actualSize(row: InternalRow, ordinal: Int): Int = { + val unsafeArray = getField(row, ordinal) + 4 + unsafeArray.getSizeInBytes + } + + override def append(value: UnsafeArrayData, buffer: ByteBuffer): Unit = { + buffer.putInt(value.getSizeInBytes) + value.writeTo(buffer) + } + + override def extract(buffer: ByteBuffer): UnsafeArrayData = { + val numBytes = buffer.getInt + assert(buffer.hasArray) + val cursor = buffer.position() + buffer.position(cursor + numBytes) + val array = new UnsafeArrayData + array.pointTo( + buffer.array(), + Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor, + numBytes) + array + } + + override def clone(v: UnsafeArrayData): UnsafeArrayData = v.copy() +} + +private[columnar] case class MAP(dataType: MapType) + extends ColumnType[UnsafeMapData] with DirectCopyColumnType[UnsafeMapData] { + + override def defaultSize: Int = 32 + + override def setField(row: MutableRow, ordinal: Int, value: UnsafeMapData): Unit = { + row.update(ordinal, value) + } + + override def getField(row: InternalRow, ordinal: Int): UnsafeMapData = { + row.getMap(ordinal).asInstanceOf[UnsafeMapData] + } + + override def actualSize(row: InternalRow, ordinal: Int): Int = { + val unsafeMap = getField(row, ordinal) + 4 + unsafeMap.getSizeInBytes + } + + override def append(value: UnsafeMapData, buffer: ByteBuffer): Unit = { + buffer.putInt(value.getSizeInBytes) + value.writeTo(buffer) + } + + override def extract(buffer: ByteBuffer): UnsafeMapData = { + val numBytes = buffer.getInt + val cursor = buffer.position() + buffer.position(cursor + numBytes) + val map = new UnsafeMapData + map.pointTo( + buffer.array(), + Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor, + numBytes) + map + } + + override def clone(v: UnsafeMapData): UnsafeMapData = v.copy() +} + +private[columnar] object ColumnType { + def apply(dataType: DataType): ColumnType[_] = { + dataType match { + case NullType => NULL + case BooleanType => BOOLEAN + case ByteType => BYTE + case ShortType => SHORT + case IntegerType | DateType => INT + case LongType | TimestampType => LONG + case FloatType => FLOAT + case DoubleType => DOUBLE + case StringType => STRING + case BinaryType => BINARY + case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => COMPACT_DECIMAL(dt) + case dt: DecimalType => LARGE_DECIMAL(dt) + case arr: ArrayType => ARRAY(arr) + case map: MapType => MAP(map) + case struct: StructType => STRUCT(struct) + case udt: UserDefinedType[_] => apply(udt.sqlType) + case other => + throw new Exception(s"Unsupported type: $other") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala new file mode 100644 index 000000000000..eaafc96e4d2e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{UnsafeRowWriter, CodeFormatter, CodeGenerator} +import org.apache.spark.sql.types._ + +/** + * An Iterator to walk through the InternalRows from a CachedBatch + */ +abstract class ColumnarIterator extends Iterator[InternalRow] { + def initialize(input: Iterator[CachedBatch], columnTypes: Array[DataType], + columnIndexes: Array[Int]): Unit +} + +/** + * A helper class to update the fields of an UnsafeRow, used by ColumnAccessor + * + * WARNING: These setters MUST be called in increasing order of ordinals. + */ +class MutableUnsafeRow(val writer: UnsafeRowWriter) extends GenericMutableRow(null) { + + override def isNullAt(i: Int): Boolean = writer.isNullAt(i) + override def setNullAt(i: Int): Unit = writer.setNullAt(i) + + override def setBoolean(i: Int, v: Boolean): Unit = writer.write(i, v) + override def setByte(i: Int, v: Byte): Unit = writer.write(i, v) + override def setShort(i: Int, v: Short): Unit = writer.write(i, v) + override def setInt(i: Int, v: Int): Unit = writer.write(i, v) + override def setLong(i: Int, v: Long): Unit = writer.write(i, v) + override def setFloat(i: Int, v: Float): Unit = writer.write(i, v) + override def setDouble(i: Int, v: Double): Unit = writer.write(i, v) + + // the writer will be used directly to avoid creating wrapper objects + override def setDecimal(i: Int, v: Decimal, precision: Int): Unit = + throw new UnsupportedOperationException + override def update(i: Int, v: Any): Unit = throw new UnsupportedOperationException + + // all other methods inherited from GenericMutableRow are not needed +} + +/** + * Generates bytecode for a [[ColumnarIterator]] for the columnar cache.
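For intuition about what the generated iterator does, a hand-written equivalent with roughly the same control flow might look like the sketch below. Batch, IntAccessor and SimpleColumnarIterator are hypothetical simplifications standing in for CachedBatch, the ColumnAccessor classes and the generated SpecificColumnarIterator; the real generator emits Java source that is compiled at runtime and writes into an UnsafeRow rather than an object array.

// Hypothetical, simplified stand-ins for CachedBatch and a column accessor.
case class Batch(numRows: Int, columns: Array[Array[Int]])

class IntAccessor(column: Array[Int]) {
  private var pos = 0
  def extractTo(row: Array[Any], ordinal: Int): Unit = { row(ordinal) = column(pos); pos += 1 }
}

// Walks batches, builds one accessor per requested column, and yields rows.
class SimpleColumnarIterator(input: Iterator[Batch], columnIndexes: Array[Int])
  extends Iterator[Array[Any]] {

  private var accessors: Array[IntAccessor] = Array.empty
  private var currentRow = 0
  private var numRowsInBatch = 0

  override def hasNext: Boolean = {
    if (currentRow < numRowsInBatch) return true
    if (!input.hasNext) return false
    val batch = input.next()
    currentRow = 0
    numRowsInBatch = batch.numRows
    accessors = columnIndexes.map(i => new IntAccessor(batch.columns(i)))
    hasNext
  }

  override def next(): Array[Any] = {
    currentRow += 1
    val row = new Array[Any](accessors.length)
    var i = 0
    while (i < accessors.length) { accessors(i).extractTo(row, i); i += 1 }
    row
  }
}

object SimpleColumnarIteratorDemo {
  def main(args: Array[String]): Unit = {
    val batches = Iterator(
      Batch(2, Array(Array(1, 2), Array(10, 20))),
      Batch(1, Array(Array(3), Array(30))))
    // Project only the second column (index 1), as requestedColumnIndices would.
    val it = new SimpleColumnarIterator(batches, Array(1))
    it.foreach(r => println(r.toSeq))   // List(10), List(20), List(30)
  }
}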
+ */ +object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarIterator] with Logging { + + protected def canonicalize(in: Seq[DataType]): Seq[DataType] = in + protected def bind(in: Seq[DataType], inputSchema: Seq[Attribute]): Seq[DataType] = in + + protected def create(columnTypes: Seq[DataType]): ColumnarIterator = { + val ctx = newCodeGenContext() + val numFields = columnTypes.size + val (initializeAccessors, extractors) = columnTypes.zipWithIndex.map { case (dt, index) => + val accessorName = ctx.freshName("accessor") + val accessorCls = dt match { + case NullType => classOf[NullColumnAccessor].getName + case BooleanType => classOf[BooleanColumnAccessor].getName + case ByteType => classOf[ByteColumnAccessor].getName + case ShortType => classOf[ShortColumnAccessor].getName + case IntegerType | DateType => classOf[IntColumnAccessor].getName + case LongType | TimestampType => classOf[LongColumnAccessor].getName + case FloatType => classOf[FloatColumnAccessor].getName + case DoubleType => classOf[DoubleColumnAccessor].getName + case StringType => classOf[StringColumnAccessor].getName + case BinaryType => classOf[BinaryColumnAccessor].getName + case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => + classOf[CompactDecimalColumnAccessor].getName + case dt: DecimalType => classOf[DecimalColumnAccessor].getName + case struct: StructType => classOf[StructColumnAccessor].getName + case array: ArrayType => classOf[ArrayColumnAccessor].getName + case t: MapType => classOf[MapColumnAccessor].getName + } + ctx.addMutableState(accessorCls, accessorName, s"$accessorName = null;") + + val createCode = dt match { + case t if ctx.isPrimitiveType(dt) => + s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));" + case NullType | StringType | BinaryType => + s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));" + case other => + s"""$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder), + (${dt.getClass.getName}) columnTypes[$index]);""" + } + + val extract = s"$accessorName.extractTo(mutableRow, $index);" + val patch = dt match { + case DecimalType.Fixed(p, s) if p > Decimal.MAX_LONG_DIGITS => + // For large Decimal, it should have 16 bytes for future update even it's null now. 
+ s""" + if (mutableRow.isNullAt($index)) { + rowWriter.write($index, (Decimal) null, $p, $s); + } + """ + case other => "" + } + (createCode, extract + patch) + }.unzip + + val code = s""" + import java.nio.ByteBuffer; + import java.nio.ByteOrder; + import scala.collection.Iterator; + import org.apache.spark.sql.types.DataType; + import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder; + import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; + import org.apache.spark.sql.execution.columnar.MutableUnsafeRow; + + public SpecificColumnarIterator generate($exprType[] expr) { + return new SpecificColumnarIterator(); + } + + class SpecificColumnarIterator extends ${classOf[ColumnarIterator].getName} { + + private ByteOrder nativeOrder = null; + private byte[][] buffers = null; + private UnsafeRow unsafeRow = new UnsafeRow(); + private BufferHolder bufferHolder = new BufferHolder(); + private UnsafeRowWriter rowWriter = new UnsafeRowWriter(); + private MutableUnsafeRow mutableRow = null; + + private int currentRow = 0; + private int numRowsInBatch = 0; + + private scala.collection.Iterator input = null; + private DataType[] columnTypes = null; + private int[] columnIndexes = null; + + ${declareMutableStates(ctx)} + + public SpecificColumnarIterator() { + this.nativeOrder = ByteOrder.nativeOrder(); + this.buffers = new byte[${columnTypes.length}][]; + this.mutableRow = new MutableUnsafeRow(rowWriter); + + ${initMutableStates(ctx)} + } + + public void initialize(Iterator input, DataType[] columnTypes, int[] columnIndexes) { + this.input = input; + this.columnTypes = columnTypes; + this.columnIndexes = columnIndexes; + } + + public boolean hasNext() { + if (currentRow < numRowsInBatch) { + return true; + } + if (!input.hasNext()) { + return false; + } + + ${classOf[CachedBatch].getName} batch = (${classOf[CachedBatch].getName}) input.next(); + currentRow = 0; + numRowsInBatch = batch.numRows(); + for (int i = 0; i < columnIndexes.length; i ++) { + buffers[i] = batch.buffers()[columnIndexes[i]]; + } + ${initializeAccessors.mkString("\n")} + + return hasNext(); + } + + public InternalRow next() { + currentRow += 1; + bufferHolder.reset(); + rowWriter.initialize(bufferHolder, $numFields); + ${extractors.mkString("\n")} + unsafeRow.pointTo(bufferHolder.buffer, $numFields, bufferHolder.totalSize()); + return unsafeRow; + } + }""" + + logDebug(s"Generated ColumnarIterator: ${CodeFormatter.format(code)}") + + compile(code).generate(ctx.references.toArray).asInstanceOf[ColumnarIterator] + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala similarity index 75% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala index 5d5b0697d701..3c5a8cb2aa93 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala @@ -15,19 +15,20 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.columnar - -import java.nio.ByteBuffer +package org.apache.spark.sql.execution.columnar import scala.collection.mutable.ArrayBuffer +import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} -import org.apache.spark.sql.execution.{LeafNode, SparkPlan} +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.execution.{ConvertToUnsafe, LeafNode, SparkPlan} +import org.apache.spark.sql.types.UserDefinedType import org.apache.spark.storage.StorageLevel import org.apache.spark.{Accumulable, Accumulator, Accumulators} @@ -38,20 +39,30 @@ private[sql] object InMemoryRelation { storageLevel: StorageLevel, child: SparkPlan, tableName: Option[String]): InMemoryRelation = - new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)() + new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, + if (child.outputsUnsafeRows) child else ConvertToUnsafe(child), + tableName)() } -private[sql] case class CachedBatch(buffers: Array[Array[Byte]], stats: InternalRow) +/** + * CachedBatch is a cached batch of rows. + * + * @param numRows The total number of rows in this batch + * @param buffers The buffers for serialized columns + * @param stats The stat of columns + */ +private[columnar] +case class CachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow) private[sql] case class InMemoryRelation( output: Seq[Attribute], useCompression: Boolean, batchSize: Int, storageLevel: StorageLevel, - child: SparkPlan, + @transient child: SparkPlan, tableName: Option[String])( - private var _cachedColumnBuffers: RDD[CachedBatch] = null, - private var _statistics: Statistics = null, + @transient private var _cachedColumnBuffers: RDD[CachedBatch] = null, + @transient private var _statistics: Statistics = null, private var _batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = null) extends LogicalPlan with MultiInstanceRelation { @@ -62,7 +73,7 @@ private[sql] case class InMemoryRelation( _batchStats } - val partitionStatistics = new PartitionStatistics(output) + @transient val partitionStatistics = new PartitionStatistics(output) private def computeSizeInBytes = { val sizeOfRow: Expression = @@ -116,17 +127,17 @@ private[sql] case class InMemoryRelation( private def buildBuffers(): Unit = { val output = child.output - val cached = child.execute().mapPartitions { rowIterator => + val cached = child.execute().mapPartitionsInternal { rowIterator => new Iterator[CachedBatch] { def next(): CachedBatch = { val columnBuilders = output.map { attribute => - val columnType = ColumnType(attribute.dataType) - val initialBufferSize = columnType.defaultSize * batchSize - ColumnBuilder(attribute.dataType, initialBufferSize, attribute.name, useCompression) + ColumnBuilder(attribute.dataType, batchSize, attribute.name, useCompression) }.toArray var rowCount = 0 - while (rowIterator.hasNext && rowCount < batchSize) { + var totalSize = 0L + while (rowIterator.hasNext && rowCount < batchSize + && totalSize < ColumnBuilder.MAX_BATCH_SIZE_IN_BYTE) { val row = rowIterator.next() // Added for SPARK-6082. 
This assertion can be useful for scenarios when something @@ -140,18 +151,22 @@ private[sql] case class InMemoryRelation( s"\nRow content: $row") var i = 0 + totalSize = 0 while (i < row.numFields) { columnBuilders(i).appendFrom(row, i) + totalSize += columnBuilders(i).columnStats.sizeInBytes i += 1 } rowCount += 1 } val stats = InternalRow.fromSeq(columnBuilders.map(_.columnStats.collectedStatistics) - .flatMap(_.toSeq)) + .flatMap(_.values)) batchStats += stats - CachedBatch(columnBuilders.map(_.build().array()), stats) + CachedBatch(rowCount, columnBuilders.map { builder => + JavaUtils.bufferToArray(builder.build()) + }, stats) } def hasNext: Boolean = rowIterator.hasNext @@ -198,16 +213,24 @@ private[sql] case class InMemoryRelation( private[sql] case class InMemoryColumnarTableScan( attributes: Seq[Attribute], predicates: Seq[Expression], - relation: InMemoryRelation) + @transient relation: InMemoryRelation) extends LeafNode { override def output: Seq[Attribute] = attributes + // The cached version does not change the outputPartitioning of the original SparkPlan. + override def outputPartitioning: Partitioning = relation.child.outputPartitioning + + // The cached version does not change the outputOrdering of the original SparkPlan. + override def outputOrdering: Seq[SortOrder] = relation.child.outputOrdering + + override def outputsUnsafeRows: Boolean = true + private def statsFor(a: Attribute) = relation.partitionStatistics.forAttribute(a) // Returned filter predicate should return false iff it is impossible for the input expression // to evaluate to `true' based on statistics collected about this partition batch. - val buildFilter: PartialFunction[Expression, Expression] = { + @transient val buildFilter: PartialFunction[Expression, Expression] = { case And(lhs: Expression, rhs: Expression) if buildFilter.isDefinedAt(lhs) || buildFilter.isDefinedAt(rhs) => (buildFilter.lift(lhs) ++ buildFilter.lift(rhs)).reduce(_ && _) @@ -270,70 +293,34 @@ private[sql] case class InMemoryColumnarTableScan( readBatches.setValue(0) } - relation.cachedColumnBuffers.mapPartitions { cachedBatchIterator => + // Using these variables here to avoid serialization of entire objects (if referenced directly) + // within the map Partitions closure. + val schema = relation.partitionStatistics.schema + val schemaIndex = schema.zipWithIndex + val relOutput = relation.output + val buffers = relation.cachedColumnBuffers + + buffers.mapPartitionsInternal { cachedBatchIterator => val partitionFilter = newPredicate( partitionFilters.reduceOption(And).getOrElse(Literal(true)), - relation.partitionStatistics.schema) - - // Find the ordinals and data types of the requested columns. If none are requested, use the - // narrowest (the field with minimum default element size). - val (requestedColumnIndices, requestedColumnDataTypes) = if (attributes.isEmpty) { - val (narrowestOrdinal, narrowestDataType) = - relation.output.zipWithIndex.map { case (a, ordinal) => - ordinal -> a.dataType - } minBy { case (_, dataType) => - ColumnType(dataType).defaultSize - } - Seq(narrowestOrdinal) -> Seq(narrowestDataType) - } else { + schema) + + // Find the ordinals and data types of the requested columns. 
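The partition batch pruning applied just below hinges on the per-batch statistics gathered at build time: a batch can be skipped when its [lowerBound, upperBound] interval makes the predicate impossible to satisfy. A stripped-down illustration of that idea, outside of Catalyst expressions and with hypothetical SimplePredicate/BatchStats types, follows; the real code rewrites the predicate with buildFilter and evaluates it via newPredicate against each batch's statistics row.

// Minimal illustration of stats-based batch pruning for an integer column.
sealed trait SimplePredicate
case class EqualTo(value: Int) extends SimplePredicate
case class GreaterThan(value: Int) extends SimplePredicate
case class LessThan(value: Int) extends SimplePredicate

case class BatchStats(lower: Int, upper: Int)

object BatchPruning {
  // Returns true when the batch may contain matching rows and must be scanned.
  def mightMatch(stats: BatchStats, p: SimplePredicate): Boolean = p match {
    case EqualTo(v)     => stats.lower <= v && v <= stats.upper
    case GreaterThan(v) => stats.upper > v
    case LessThan(v)    => stats.lower < v
  }

  def main(args: Array[String]): Unit = {
    val batch = BatchStats(lower = 100, upper = 200)
    println(mightMatch(batch, EqualTo(50)))      // false -> batch is skipped
    println(mightMatch(batch, GreaterThan(150))) // true  -> batch is scanned
  }
}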
+ val (requestedColumnIndices, requestedColumnDataTypes) = attributes.map { a => - relation.output.indexWhere(_.exprId == a.exprId) -> a.dataType + relOutput.indexWhere(_.exprId == a.exprId) -> a.dataType }.unzip - } - - val nextRow = new SpecificMutableRow(requestedColumnDataTypes) - - def cachedBatchesToRows(cacheBatches: Iterator[CachedBatch]): Iterator[InternalRow] = { - val rows = cacheBatches.flatMap { cachedBatch => - // Build column accessors - val columnAccessors = requestedColumnIndices.map { batchColumnIndex => - ColumnAccessor( - relation.output(batchColumnIndex).dataType, - ByteBuffer.wrap(cachedBatch.buffers(batchColumnIndex))) - } - - // Extract rows via column accessors - new Iterator[InternalRow] { - private[this] val rowLen = nextRow.numFields - override def next(): InternalRow = { - var i = 0 - while (i < rowLen) { - columnAccessors(i).extractTo(nextRow, i) - i += 1 - } - if (attributes.isEmpty) InternalRow.empty else nextRow - } - - override def hasNext: Boolean = columnAccessors(0).hasNext - } - } - - if (rows.hasNext && enableAccumulators) { - readPartitions += 1 - } - - rows - } // Do partition batch pruning if enabled val cachedBatchesToScan = if (inMemoryPartitionPruningEnabled) { cachedBatchIterator.filter { cachedBatch => if (!partitionFilter(cachedBatch.stats)) { - def statsString: String = relation.partitionStatistics.schema - .zip(cachedBatch.stats.toSeq) - .map { case (a, s) => s"${a.name}: $s" } - .mkString(", ") + def statsString: String = schemaIndex.map { + case (a, i) => + val value = cachedBatch.stats.get(i, a.dataType) + s"${a.name}: $value" + }.mkString(", ") logInfo(s"Skipping partition based on stats $statsString") false } else { @@ -347,7 +334,16 @@ private[sql] case class InMemoryColumnarTableScan( cachedBatchIterator } - cachedBatchesToRows(cachedBatchesToScan) + val columnTypes = requestedColumnDataTypes.map { + case udt: UserDefinedType[_] => udt.sqlType + case other => other + }.toArray + val columnarIterator = GenerateColumnAccessor.generate(columnTypes) + columnarIterator.initialize(cachedBatchesToScan, columnTypes, requestedColumnIndices.toArray) + if (enableAccumulators && columnarIterator.hasNext) { + readPartitions += 1 + } + columnarIterator } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala similarity index 84% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala index 4d35650d4b1e..8d99546924de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala @@ -15,13 +15,13 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import java.nio.{ByteOrder, ByteBuffer} import org.apache.spark.sql.catalyst.expressions.MutableRow -private[sql] trait NullableColumnAccessor extends ColumnAccessor { +private[columnar] trait NullableColumnAccessor extends ColumnAccessor { private var nullsBuffer: ByteBuffer = _ private var nullCount: Int = _ private var seenNulls: Int = 0 @@ -31,8 +31,8 @@ private[sql] trait NullableColumnAccessor extends ColumnAccessor { abstract override protected def initialize(): Unit = { nullsBuffer = underlyingBuffer.duplicate().order(ByteOrder.nativeOrder()) - nullCount = nullsBuffer.getInt() - nextNullIndex = if (nullCount > 0) nullsBuffer.getInt() else -1 + nullCount = ByteBufferHelper.getInt(nullsBuffer) + nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else -1 pos = 0 underlyingBuffer.position(underlyingBuffer.position + 4 + nullCount * 4) @@ -44,7 +44,7 @@ private[sql] trait NullableColumnAccessor extends ColumnAccessor { seenNulls += 1 if (seenNulls < nullCount) { - nextNullIndex = nullsBuffer.getInt() + nextNullIndex = ByteBufferHelper.getInt(nullsBuffer) } row.setNullAt(ordinal) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilder.scala similarity index 79% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilder.scala index ba47bc783f31..3a1931bfb5c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilder.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import java.nio.{ByteBuffer, ByteOrder} @@ -25,17 +25,16 @@ import org.apache.spark.sql.catalyst.InternalRow * A stackable trait used for building byte buffer for a column containing null values. Memory * layout of the final byte buffer is: * {{{ - * .----------------------- Column type ID (4 bytes) - * | .------------------- Null count N (4 bytes) - * | | .--------------- Null positions (4 x N bytes, empty if null count is zero) - * | | | .--------- Non-null elements - * V V V V - * +---+---+-----+---------+ - * | | | ... | ... ... | - * +---+---+-----+---------+ + * .------------------- Null count N (4 bytes) + * | .--------------- Null positions (4 x N bytes, empty if null count is zero) + * | | .--------- Non-null elements + * V V V + * +---+-----+---------+ + * | | ... | ... ... 
| + * +---+-----+---------+ * }}} */ -private[sql] trait NullableColumnBuilder extends ColumnBuilder { +private[columnar] trait NullableColumnBuilder extends ColumnBuilder { protected var nulls: ByteBuffer = _ protected var nullCount: Int = _ private var pos: Int = _ @@ -66,16 +65,14 @@ private[sql] trait NullableColumnBuilder extends ColumnBuilder { abstract override def build(): ByteBuffer = { val nonNulls = super.build() - val typeId = nonNulls.getInt() val nullDataLen = nulls.position() nulls.limit(nullDataLen) nulls.rewind() val buffer = ByteBuffer - .allocate(4 + 4 + nullDataLen + nonNulls.remaining()) + .allocate(4 + nullDataLen + nonNulls.remaining()) .order(ByteOrder.nativeOrder()) - .putInt(typeId) .putInt(nullCount) .put(nulls) .put(nonNulls) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala similarity index 84% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala index cb205defbb1a..6579b5068e65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala @@ -15,13 +15,13 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import org.apache.spark.sql.catalyst.expressions.MutableRow -import org.apache.spark.sql.columnar.{ColumnAccessor, NativeColumnAccessor} +import org.apache.spark.sql.execution.columnar.{ColumnAccessor, NativeColumnAccessor} import org.apache.spark.sql.types.AtomicType -private[sql] trait CompressibleColumnAccessor[T <: AtomicType] extends ColumnAccessor { +private[columnar] trait CompressibleColumnAccessor[T <: AtomicType] extends ColumnAccessor { this: NativeColumnAccessor[T] => private var decoder: Decoder[T] = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnBuilder.scala similarity index 77% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnBuilder.scala index 39b21ddb47ba..b0e216feb559 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnBuilder.scala @@ -15,33 +15,32 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.columnar.{ColumnBuilder, NativeColumnBuilder} +import org.apache.spark.sql.execution.columnar.{ColumnBuilder, NativeColumnBuilder} import org.apache.spark.sql.types.AtomicType /** * A stackable trait that builds optionally compressed byte buffer for a column. 
Memory layout of * the final byte buffer is: * {{{ - * .--------------------------- Column type ID (4 bytes) - * | .----------------------- Null count N (4 bytes) - * | | .------------------- Null positions (4 x N bytes, empty if null count is zero) - * | | | .------------- Compression scheme ID (4 bytes) - * | | | | .--------- Compressed non-null elements - * V V V V V - * +---+---+-----+---+---------+ - * | | | ... | | ... ... | - * +---+---+-----+---+---------+ - * \-----------/ \-----------/ - * header body + * .----------------------- Null count N (4 bytes) + * | .------------------- Null positions (4 x N bytes, empty if null count is zero) + * | | .------------- Compression scheme ID (4 bytes) + * | | | .--------- Compressed non-null elements + * V V V V + * +---+-----+---+---------+ + * | | ... | | ... ... | + * +---+-----+---+---------+ + * \-------/ \-----------/ + * header body * }}} */ -private[sql] trait CompressibleColumnBuilder[T <: AtomicType] +private[columnar] trait CompressibleColumnBuilder[T <: AtomicType] extends ColumnBuilder with Logging { this: NativeColumnBuilder[T] with WithCompressionSchemes => @@ -83,14 +82,13 @@ private[sql] trait CompressibleColumnBuilder[T <: AtomicType] override def build(): ByteBuffer = { val nonNullBuffer = buildNonNulls() - val typeId = nonNullBuffer.getInt() val encoder: Encoder[T] = { val candidate = compressionEncoders.minBy(_.compressionRatio) if (isWorthCompressing(candidate)) candidate else PassThrough.encoder(columnType) } - // Header = column type ID + null count + null positions - val headerSize = 4 + 4 + nulls.limit() + // Header = null count + null positions + val headerSize = 4 + nulls.limit() val compressedSize = if (encoder.compressedSize == 0) { nonNullBuffer.remaining() } else { @@ -102,7 +100,6 @@ private[sql] trait CompressibleColumnBuilder[T <: AtomicType] .allocate(headerSize + 4 + compressedSize) .order(ByteOrder.nativeOrder) // Write the header - .putInt(typeId) .putInt(nullCount) .put(nulls) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala similarity index 80% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala index b1ef9b2ef784..920381f9c63d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala @@ -15,15 +15,15 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.MutableRow -import org.apache.spark.sql.columnar.{ColumnType, NativeColumnType} +import org.apache.spark.sql.execution.columnar.{ColumnType, NativeColumnType} import org.apache.spark.sql.types.AtomicType -private[sql] trait Encoder[T <: AtomicType] { +private[columnar] trait Encoder[T <: AtomicType] { def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = {} def compressedSize: Int @@ -37,13 +37,13 @@ private[sql] trait Encoder[T <: AtomicType] { def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer } -private[sql] trait Decoder[T <: AtomicType] { +private[columnar] trait Decoder[T <: AtomicType] { def next(row: MutableRow, ordinal: Int): Unit def hasNext: Boolean } -private[sql] trait CompressionScheme { +private[columnar] trait CompressionScheme { def typeId: Int def supports(columnType: ColumnType[_]): Boolean @@ -53,15 +53,15 @@ private[sql] trait CompressionScheme { def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T] } -private[sql] trait WithCompressionSchemes { +private[columnar] trait WithCompressionSchemes { def schemes: Seq[CompressionScheme] } -private[sql] trait AllCompressionSchemes extends WithCompressionSchemes { +private[columnar] trait AllCompressionSchemes extends WithCompressionSchemes { override val schemes: Seq[CompressionScheme] = CompressionScheme.all } -private[sql] object CompressionScheme { +private[columnar] object CompressionScheme { val all: Seq[CompressionScheme] = Seq(PassThrough, RunLengthEncoding, DictionaryEncoding, BooleanBitSet, IntDelta, LongDelta) @@ -74,8 +74,8 @@ private[sql] object CompressionScheme { def columnHeaderSize(columnBuffer: ByteBuffer): Int = { val header = columnBuffer.duplicate().order(ByteOrder.nativeOrder) - val nullCount = header.getInt(4) - // Column type ID + null count + null positions - 4 + 4 + 4 * nullCount + val nullCount = header.getInt() + // null count + null positions + 4 + 4 * nullCount } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala similarity index 92% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala index c91d960a0932..941f03b745a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala @@ -15,21 +15,19 @@ * limitations under the License. 
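The reader-side arithmetic that goes with the updated `columnHeaderSize` can be sketched on its own (names here are hypothetical): with the type ID removed, the null count now sits at offset 0, the header occupies 4 + 4 * nullCount bytes, and the compression scheme ID is the first int after it.

{{{
import java.nio.{ByteBuffer, ByteOrder}

// Hypothetical reader for the new header layout: no leading type ID.
object ColumnHeaderReadSketch {
  def readHeader(columnBuffer: ByteBuffer): (Array[Int], Int) = {
    val buf = columnBuffer.duplicate().order(ByteOrder.nativeOrder())
    val nullCount = buf.getInt()                             // null count at offset 0
    val nullPositions = Array.fill(nullCount)(buf.getInt())  // 4 * N null positions
    val compressionSchemeId = buf.getInt()                   // first int after the header
    (nullPositions, compressionSchemeId)
  }
}
}}}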
*/ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import java.nio.ByteBuffer import scala.collection.mutable -import scala.reflect.ClassTag -import scala.reflect.runtime.universe.runtimeMirror + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow} -import org.apache.spark.sql.columnar._ +import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils -private[sql] case object PassThrough extends CompressionScheme { +private[columnar] case object PassThrough extends CompressionScheme { override val typeId = 0 override def supports(columnType: ColumnType[_]): Boolean = true @@ -66,7 +64,7 @@ private[sql] case object PassThrough extends CompressionScheme { } } -private[sql] case object RunLengthEncoding extends CompressionScheme { +private[columnar] case object RunLengthEncoding extends CompressionScheme { override val typeId = 1 override def encoder[T <: AtomicType](columnType: NativeColumnType[T]): Encoder[T] = { @@ -161,7 +159,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { override def next(row: MutableRow, ordinal: Int): Unit = { if (valueCount == run) { currentValue = columnType.extract(buffer) - run = buffer.getInt() + run = ByteBufferHelper.getInt(buffer) valueCount = 1 } else { valueCount += 1 @@ -174,7 +172,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { } } -private[sql] case object DictionaryEncoding extends CompressionScheme { +private[columnar] case object DictionaryEncoding extends CompressionScheme { override val typeId = 2 // 32K unique values allowed @@ -270,27 +268,20 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { class Decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) extends compression.Decoder[T] { - private val dictionary = { - // TODO Can we clean up this mess? Maybe move this to `DataType`? 
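The run-length format consumed by the decoder above can be restated in isolation. The sketch below is a simplified, int-only version (the real decoder goes through `ColumnType.extract` and `ByteBufferHelper`): each run is stored as a value followed by its 4-byte run length, and the value is repeated that many times on the way out.

{{{
import java.nio.ByteBuffer

import scala.collection.mutable.ArrayBuffer

// Simplified, int-only illustration of run-length decoding: [value][run length] pairs.
object RunLengthDecodeSketch {
  def decodeInts(buffer: ByteBuffer): Seq[Int] = {
    val out = ArrayBuffer.empty[Int]
    while (buffer.hasRemaining) {
      val value = buffer.getInt()
      val run = buffer.getInt()
      var i = 0
      while (i < run) { out += value; i += 1 }
    }
    out
  }
}
}}}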
- implicit val classTag = { - val mirror = runtimeMirror(Utils.getSparkClassLoader) - ClassTag[T#InternalType](mirror.runtimeClass(columnType.scalaTag.tpe)) - } - - Array.fill(buffer.getInt()) { - columnType.extract(buffer) - } + private val dictionary: Array[Any] = { + val elementNum = ByteBufferHelper.getInt(buffer) + Array.fill[Any](elementNum)(columnType.extract(buffer).asInstanceOf[Any]) } override def next(row: MutableRow, ordinal: Int): Unit = { - columnType.setField(row, ordinal, dictionary(buffer.getShort())) + columnType.setField(row, ordinal, dictionary(buffer.getShort()).asInstanceOf[T#InternalType]) } override def hasNext: Boolean = buffer.hasRemaining } } -private[sql] case object BooleanBitSet extends CompressionScheme { +private[columnar] case object BooleanBitSet extends CompressionScheme { override val typeId = 3 val BITS_PER_LONG = 64 @@ -359,7 +350,7 @@ private[sql] case object BooleanBitSet extends CompressionScheme { } class Decoder(buffer: ByteBuffer) extends compression.Decoder[BooleanType.type] { - private val count = buffer.getInt() + private val count = ByteBufferHelper.getInt(buffer) private var currentWord = 0: Long @@ -370,7 +361,7 @@ private[sql] case object BooleanBitSet extends CompressionScheme { visited += 1 if (bit == 0) { - currentWord = buffer.getLong() + currentWord = ByteBufferHelper.getLong(buffer) } row.setBoolean(ordinal, ((currentWord >> bit) & 1) != 0) @@ -380,7 +371,7 @@ private[sql] case object BooleanBitSet extends CompressionScheme { } } -private[sql] case object IntDelta extends CompressionScheme { +private[columnar] case object IntDelta extends CompressionScheme { override def typeId: Int = 4 override def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) @@ -454,13 +445,13 @@ private[sql] case object IntDelta extends CompressionScheme { override def next(row: MutableRow, ordinal: Int): Unit = { val delta = buffer.get() - prev = if (delta > Byte.MinValue) prev + delta else buffer.getInt() + prev = if (delta > Byte.MinValue) prev + delta else ByteBufferHelper.getInt(buffer) row.setInt(ordinal, prev) } } } -private[sql] case object LongDelta extends CompressionScheme { +private[columnar] case object LongDelta extends CompressionScheme { override def typeId: Int = 5 override def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) @@ -534,7 +525,7 @@ private[sql] case object LongDelta extends CompressionScheme { override def next(row: MutableRow, ordinal: Int): Unit = { val delta = buffer.get() - prev = if (delta > Byte.MinValue) prev + delta else buffer.getLong() + prev = if (delta > Byte.MinValue) prev + delta else ByteBufferHelper.getLong(buffer) row.setLong(ordinal, prev) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 6b83025d5a15..24a79f289aa8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -20,11 +20,10 @@ package org.apache.spark.sql.execution import java.util.NoSuchElementException import org.apache.spark.Logging -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.expressions.{ExpressionDescription, Expression, Attribute, AttributeReference} +import 
org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types._ @@ -54,27 +53,27 @@ private[sql] case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan * The `execute()` method of all the physical command classes should reference `sideEffectResult` * so that the command can be executed eagerly right after the command query is created. */ - protected[sql] lazy val sideEffectResult: Seq[Row] = cmd.run(sqlContext) + protected[sql] lazy val sideEffectResult: Seq[InternalRow] = { + val converter = CatalystTypeConverters.createToCatalystConverter(schema) + cmd.run(sqlContext).map(converter(_).asInstanceOf[InternalRow]) + } override def output: Seq[Attribute] = cmd.output override def children: Seq[SparkPlan] = Nil - override def executeCollect(): Array[Row] = sideEffectResult.toArray + override def executeCollect(): Array[InternalRow] = sideEffectResult.toArray - override def executeTake(limit: Int): Array[Row] = sideEffectResult.take(limit).toArray + override def executeTake(limit: Int): Array[InternalRow] = sideEffectResult.take(limit).toArray protected override def doExecute(): RDD[InternalRow] = { - val convert = CatalystTypeConverters.createToCatalystConverter(schema) - val converted = sideEffectResult.map(convert(_).asInstanceOf[InternalRow]) - sqlContext.sparkContext.parallelize(converted, 1) + sqlContext.sparkContext.parallelize(sideEffectResult, 1) } + + override def argString: String = cmd.toString } -/** - * :: DeveloperApi :: - */ -@DeveloperApi + case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableCommand with Logging { private def keyValueOutput: Seq[Attribute] = { @@ -103,6 +102,63 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm } (keyValueOutput, runFunc) + case Some((SQLConf.Deprecated.EXTERNAL_SORT, Some(value))) => + val runFunc = (sqlContext: SQLContext) => { + logWarning( + s"Property ${SQLConf.Deprecated.EXTERNAL_SORT} is deprecated and will be ignored. " + + s"External sort will continue to be used.") + Seq(Row(SQLConf.Deprecated.EXTERNAL_SORT, "true")) + } + (keyValueOutput, runFunc) + + case Some((SQLConf.Deprecated.USE_SQL_AGGREGATE2, Some(value))) => + val runFunc = (sqlContext: SQLContext) => { + logWarning( + s"Property ${SQLConf.Deprecated.USE_SQL_AGGREGATE2} is deprecated and " + + s"will be ignored. ${SQLConf.Deprecated.USE_SQL_AGGREGATE2} will " + + s"continue to be true.") + Seq(Row(SQLConf.Deprecated.USE_SQL_AGGREGATE2, "true")) + } + (keyValueOutput, runFunc) + + case Some((SQLConf.Deprecated.TUNGSTEN_ENABLED, Some(value))) => + val runFunc = (sqlContext: SQLContext) => { + logWarning( + s"Property ${SQLConf.Deprecated.TUNGSTEN_ENABLED} is deprecated and " + + s"will be ignored. Tungsten will continue to be used.") + Seq(Row(SQLConf.Deprecated.TUNGSTEN_ENABLED, "true")) + } + (keyValueOutput, runFunc) + + case Some((SQLConf.Deprecated.CODEGEN_ENABLED, Some(value))) => + val runFunc = (sqlContext: SQLContext) => { + logWarning( + s"Property ${SQLConf.Deprecated.CODEGEN_ENABLED} is deprecated and " + + s"will be ignored. 
Codegen will continue to be used.") + Seq(Row(SQLConf.Deprecated.CODEGEN_ENABLED, "true")) + } + (keyValueOutput, runFunc) + + case Some((SQLConf.Deprecated.UNSAFE_ENABLED, Some(value))) => + val runFunc = (sqlContext: SQLContext) => { + logWarning( + s"Property ${SQLConf.Deprecated.UNSAFE_ENABLED} is deprecated and " + + s"will be ignored. Unsafe mode will continue to be used.") + Seq(Row(SQLConf.Deprecated.UNSAFE_ENABLED, "true")) + } + (keyValueOutput, runFunc) + + (keyValueOutput, runFunc) + + case Some((SQLConf.Deprecated.SORTMERGE_JOIN, Some(value))) => + val runFunc = (sqlContext: SQLContext) => { + logWarning( + s"Property ${SQLConf.Deprecated.SORTMERGE_JOIN} is deprecated and " + + s"will be ignored. Sort merge join will continue to be used.") + Seq(Row(SQLConf.Deprecated.SORTMERGE_JOIN, "true")) + } + (keyValueOutput, runFunc) + // Configures a single property. case Some((key, Some(value))) => val runFunc = (sqlContext: SQLContext) => { @@ -148,7 +204,11 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm val runFunc = (sqlContext: SQLContext) => { val value = try { - sqlContext.getConf(key) + if (key == SQLConf.DIALECT.key) { + sqlContext.conf.dialect + } else { + sqlContext.getConf(key) + } } catch { case _: NoSuchElementException => "" } @@ -168,10 +228,7 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm * * Note that this command takes in a logical plan, runs the optimizer on the logical plan * (but do NOT actually execute it). - * - * :: DeveloperApi :: */ -@DeveloperApi case class ExplainCommand( logicalPlan: LogicalPlan, override val output: Seq[Attribute] = @@ -191,10 +248,7 @@ case class ExplainCommand( } } -/** - * :: DeveloperApi :: - */ -@DeveloperApi + case class CacheTableCommand( tableName: String, plan: Option[LogicalPlan], @@ -219,10 +273,6 @@ case class CacheTableCommand( } -/** - * :: DeveloperApi :: - */ -@DeveloperApi case class UncacheTableCommand(tableName: String) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { @@ -234,10 +284,8 @@ case class UncacheTableCommand(tableName: String) extends RunnableCommand { } /** - * :: DeveloperApi :: * Clear all cached data from the in-memory cache. */ -@DeveloperApi case object ClearCacheCommand extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { @@ -248,10 +296,7 @@ case object ClearCacheCommand extends RunnableCommand { override def output: Seq[Attribute] = Seq.empty } -/** - * :: DeveloperApi :: - */ -@DeveloperApi + case class DescribeCommand( child: SparkPlan, override val output: Seq[Attribute], @@ -274,9 +319,7 @@ case class DescribeCommand( * {{{ * SHOW TABLES [IN databaseName] * }}} - * :: DeveloperApi :: */ -@DeveloperApi case class ShowTablesCommand(databaseName: Option[String]) extends RunnableCommand { // The result of SHOW TABLES has two columns, tableName and isTemporary. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala new file mode 100644 index 000000000000..f22508b21090 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala @@ -0,0 +1,186 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. 
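In behavioural terms, every one of the deprecated keys handled above becomes a warn-and-ignore no-op: `SET` on such a key logs the warning and echoes the value as "true" instead of changing anything. A hedged usage sketch follows; the literal key string `spark.sql.tungsten.enabled` is an assumption about what backs `SQLConf.Deprecated.TUNGSTEN_ENABLED`, not something stated in this patch.

{{{
import org.apache.spark.sql.SQLContext

// Hedged illustration only: setting a deprecated key is now a warn-and-ignore no-op.
object DeprecatedSetExample {
  def demo(sqlContext: SQLContext): Unit = {
    // Logs a deprecation warning and returns a single (key, "true") row,
    // regardless of the value supplied on the right-hand side.
    sqlContext.sql("SET spark.sql.tungsten.enabled=false").show()
  }
}
}}}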
+* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.datasources + +import scala.language.implicitConversions +import scala.util.matching.Regex + +import org.apache.spark.Logging +import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.catalyst.{TableIdentifier, AbstractSparkSQLParser} +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util.DataTypeParser +import org.apache.spark.sql.types._ + + +/** + * A parser for foreign DDL commands. + */ +class DDLParser(parseQuery: String => LogicalPlan) + extends AbstractSparkSQLParser with DataTypeParser with Logging { + + def parse(input: String, exceptionOnError: Boolean): LogicalPlan = { + try { + parse(input) + } catch { + case ddlException: DDLException => throw ddlException + case _ if !exceptionOnError => parseQuery(input) + case x: Throwable => throw x + } + } + + // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` + // properties via reflection the class in runtime for constructing the SqlLexical object + protected val CREATE = Keyword("CREATE") + protected val TEMPORARY = Keyword("TEMPORARY") + protected val TABLE = Keyword("TABLE") + protected val IF = Keyword("IF") + protected val NOT = Keyword("NOT") + protected val EXISTS = Keyword("EXISTS") + protected val USING = Keyword("USING") + protected val OPTIONS = Keyword("OPTIONS") + protected val DESCRIBE = Keyword("DESCRIBE") + protected val EXTENDED = Keyword("EXTENDED") + protected val AS = Keyword("AS") + protected val COMMENT = Keyword("COMMENT") + protected val REFRESH = Keyword("REFRESH") + + protected lazy val ddl: Parser[LogicalPlan] = createTable | describeTable | refreshTable + + protected def start: Parser[LogicalPlan] = ddl + + /** + * `CREATE [TEMPORARY] TABLE [IF NOT EXISTS] avroTable + * USING org.apache.spark.sql.avro + * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` + * or + * `CREATE [TEMPORARY] TABLE [IF NOT EXISTS] avroTable(intField int, stringField string...) + * USING org.apache.spark.sql.avro + * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` + * or + * `CREATE [TEMPORARY] TABLE [IF NOT EXISTS] avroTable + * USING org.apache.spark.sql.avro + * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` + * AS SELECT ... + */ + protected lazy val createTable: Parser[LogicalPlan] = { + // TODO: Support database.table. + (CREATE ~> TEMPORARY.? <~ TABLE) ~ (IF ~> NOT <~ EXISTS).? ~ tableIdentifier ~ + tableCols.? ~ (USING ~> className) ~ (OPTIONS ~> options).? ~ (AS ~> restInput).? 
^^ { + case temp ~ allowExisting ~ tableIdent ~ columns ~ provider ~ opts ~ query => + if (temp.isDefined && allowExisting.isDefined) { + throw new DDLException( + "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.") + } + + val options = opts.getOrElse(Map.empty[String, String]) + if (query.isDefined) { + if (columns.isDefined) { + throw new DDLException( + "a CREATE TABLE AS SELECT statement does not allow column definitions.") + } + // When IF NOT EXISTS clause appears in the query, the save mode will be ignore. + val mode = if (allowExisting.isDefined) { + SaveMode.Ignore + } else if (temp.isDefined) { + SaveMode.Overwrite + } else { + SaveMode.ErrorIfExists + } + + val queryPlan = parseQuery(query.get) + CreateTableUsingAsSelect(tableIdent, + provider, + temp.isDefined, + Array.empty[String], + mode, + options, + queryPlan) + } else { + val userSpecifiedSchema = columns.flatMap(fields => Some(StructType(fields))) + CreateTableUsing( + tableIdent, + userSpecifiedSchema, + provider, + temp.isDefined, + options, + allowExisting.isDefined, + managedIfNoPath = false) + } + } + } + + // This is the same as tableIdentifier in SqlParser. + protected lazy val tableIdentifier: Parser[TableIdentifier] = + (ident <~ ".").? ~ ident ^^ { + case maybeDbName ~ tableName => TableIdentifier(tableName, maybeDbName) + } + + protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")" + + /* + * describe [extended] table avroTable + * This will display all columns of table `avroTable` includes column_name,column_type,comment + */ + protected lazy val describeTable: Parser[LogicalPlan] = + (DESCRIBE ~> opt(EXTENDED)) ~ tableIdentifier ^^ { + case e ~ tableIdent => + DescribeCommand(UnresolvedRelation(tableIdent, None), e.isDefined) + } + + protected lazy val refreshTable: Parser[LogicalPlan] = + REFRESH ~> TABLE ~> tableIdentifier ^^ { + case tableIndet => + RefreshTable(tableIndet) + } + + protected lazy val options: Parser[Map[String, String]] = + "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap } + + protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")} + + override implicit def regexToParser(regex: Regex): Parser[String] = acceptMatch( + s"identifier matching regex $regex", { + case lexical.Identifier(str) if regex.unapplySeq(str).isDefined => str + case lexical.Keyword(str) if regex.unapplySeq(str).isDefined => str + } + ) + + protected lazy val optionPart: Parser[String] = "[_a-zA-Z][_a-zA-Z0-9]*".r ^^ { + case name => name + } + + protected lazy val optionName: Parser[String] = repsep(optionPart, ".") ^^ { + case parts => parts.mkString(".") + } + + protected lazy val pair: Parser[(String, String)] = + optionName ~ stringLit ^^ { case k ~ v => (k, v) } + + protected lazy val column: Parser[StructField] = + ident ~ dataType ~ (COMMENT ~> stringLit).? 
^^ { case columnName ~ typ ~ cm => + val meta = cm match { + case Some(comment) => + new MetadataBuilder().putString(COMMENT.str.toLowerCase, comment).build() + case None => Metadata.empty + } + + StructField(columnName, typ, nullable = true, meta) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 6b91e51ca52f..8a15a51d825e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -17,40 +17,46 @@ package org.apache.spark.sql.execution.datasources -import org.apache.spark.{Logging, TaskContext} +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD} -import org.apache.spark.sql.catalyst.{InternalRow, expressions} +import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, expressions} +import org.apache.spark.sql.execution.PhysicalRDD.{INPUT_PATHS, PUSHED_FILTERS} +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.sql.{SaveMode, Strategy, execution, sources, _} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.{Logging, TaskContext} /** * A Strategy for planning scans over data sources defined using the sources API. 
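Before the planning strategy below, here is a hedged usage sketch of the DDL parser defined above. The avro data source and path are the illustrative ones from its scaladoc, and the `parseQuery` fallback supplied here is a deliberately dummy stand-in (in `SQLContext` it would be the real SQL parser).

{{{
import org.apache.spark.sql.execution.datasources.DDLParser

// Hypothetical wiring of the parser; the fallback query parser simply fails in this sketch.
object DDLParserUsageSketch {
  val ddlParser = new DDLParser(
    parseQuery = sql => sys.error(s"no fallback SQL parser in this sketch: $sql"))

  // Handled by the `createTable` production above, producing a CreateTableUsing plan.
  val createPlan = ddlParser.parse(
    """CREATE TEMPORARY TABLE avroTable
      |USING org.apache.spark.sql.avro
      |OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")""".stripMargin,
    exceptionOnError = true)

  // Handled by the `describeTable` production, producing a DescribeCommand.
  val describePlan = ddlParser.parse("DESCRIBE EXTENDED avroTable", exceptionOnError = true)
}
}}}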
*/ private[sql] object DataSourceStrategy extends Strategy with Logging { def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match { - case PhysicalOperation(projects, filters, l @ LogicalRelation(t: CatalystScan)) => + case PhysicalOperation(projects, filters, l @ LogicalRelation(t: CatalystScan, _)) => pruneFilterProjectRaw( l, projects, filters, - (a, f) => toCatalystRDD(l, a, t.buildScan(a, f))) :: Nil + (requestedColumns, allPredicates, _) => + toCatalystRDD(l, requestedColumns, t.buildScan(requestedColumns, allPredicates))) :: Nil - case PhysicalOperation(projects, filters, l @ LogicalRelation(t: PrunedFilteredScan)) => + case PhysicalOperation(projects, filters, l @ LogicalRelation(t: PrunedFilteredScan, _)) => pruneFilterProject( l, projects, filters, (a, f) => toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray, f))) :: Nil - case PhysicalOperation(projects, filters, l @ LogicalRelation(t: PrunedScan)) => + case PhysicalOperation(projects, filters, l @ LogicalRelation(t: PrunedScan, _)) => pruneFilterProject( l, projects, @@ -58,9 +64,22 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { (a, _) => toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray))) :: Nil // Scanning partitioned HadoopFsRelation - case PhysicalOperation(projects, filters, l @ LogicalRelation(t: HadoopFsRelation)) + case PhysicalOperation(projects, filters, l @ LogicalRelation(t: HadoopFsRelation, _)) if t.partitionSpec.partitionColumns.nonEmpty => - val selectedPartitions = prunePartitions(filters, t.partitionSpec).toArray + // We divide the filter expressions into 3 parts + val partitionColumns = AttributeSet( + t.partitionColumns.map(c => l.output.find(_.name == c.name).get)) + + // Only pruning the partition keys + val partitionFilters = filters.filter(_.references.subsetOf(partitionColumns)) + + // Only pushes down predicates that do not reference partition keys. + val pushedFilters = filters.filter(_.references.intersect(partitionColumns).isEmpty) + + // Predicates with both partition keys and attributes + val combineFilters = filters.toSet -- partitionFilters.toSet -- pushedFilters.toSet + + val selectedPartitions = prunePartitions(partitionFilters, t.partitionSpec).toArray logInfo { val total = t.partitionSpec.partitions.length @@ -69,24 +88,19 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { s"Selected $selected partitions out of $total, pruned $percentPruned% partitions." } - // Only pushes down predicates that do not reference partition columns. - val pushedFilters = { - val partitionColumnNames = t.partitionSpec.partitionColumns.map(_.name).toSet - filters.filter { f => - val referencedColumnNames = f.references.map(_.name).toSet - referencedColumnNames.intersect(partitionColumnNames).isEmpty - } - } - - buildPartitionedTableScan( + val scan = buildPartitionedTableScan( l, projects, pushedFilters, t.partitionSpec.partitionColumns, - selectedPartitions) :: Nil + selectedPartitions) + + combineFilters + .reduceLeftOption(expressions.And) + .map(execution.Filter(_, scan)).getOrElse(scan) :: Nil // Scanning non-partitioned HadoopFsRelation - case PhysicalOperation(projects, filters, l @ LogicalRelation(t: HadoopFsRelation)) => + case PhysicalOperation(projects, filters, l @ LogicalRelation(t: HadoopFsRelation, _)) => // See buildPartitionedTableScan for the reason that we need to create a shard // broadcast HadoopConf. 
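Stated on its own, the three-way split introduced above for partitioned relations looks like the standalone restatement below (illustrative code, not part of the patch): predicates over partition keys only drive partition pruning, predicates that avoid partition keys are pushed to the data source, and whatever mixes both kinds of columns is left for an explicit Filter on top of the scan.

{{{
import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Expression}

// Illustrative-only restatement of the filter split for a partitioned HadoopFsRelation.
object PartitionFilterSplitSketch {
  def split(
      filters: Seq[Expression],
      partitionColumns: AttributeSet): (Seq[Expression], Seq[Expression], Seq[Expression]) = {
    // References only partition keys: usable for partition pruning.
    val partitionFilters = filters.filter(_.references.subsetOf(partitionColumns))
    // References no partition keys at all: safe to push down to the data source.
    val pushedFilters = filters.filter(_.references.intersect(partitionColumns).isEmpty)
    // Mixes both kinds of columns: evaluated by a Filter node above the scan.
    val combineFilters = (filters.toSet -- partitionFilters -- pushedFilters).toSeq
    (partitionFilters, pushedFilters, combineFilters)
  }
}
}}}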
val sharedHadoopConf = SparkHadoopUtil.get.conf @@ -96,18 +110,18 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { l, projects, filters, - (a, f) => - toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray, f, t.paths, confBroadcast))) :: Nil + (a, f) => t.buildInternalScan(a.map(_.name).toArray, f, t.paths, confBroadcast)) :: Nil - case l @ LogicalRelation(t: TableScan) => - execution.PhysicalRDD(l.output, toCatalystRDD(l, t.buildScan())) :: Nil + case l @ LogicalRelation(baseRelation: TableScan, _) => + execution.PhysicalRDD.createFromDataSource( + l.output, toCatalystRDD(l, baseRelation.buildScan()), baseRelation) :: Nil - case i @ logical.InsertIntoTable( - l @ LogicalRelation(t: InsertableRelation), part, query, overwrite, false) if part.isEmpty => + case i @ logical.InsertIntoTable(l @ LogicalRelation(t: InsertableRelation, _), + part, query, overwrite, false) if part.isEmpty => execution.ExecutedCommand(InsertIntoDataSource(l, query, overwrite)) :: Nil case i @ logical.InsertIntoTable( - l @ LogicalRelation(t: HadoopFsRelation), part, query, overwrite, false) => + l @ LogicalRelation(t: HadoopFsRelation, _), part, query, overwrite, false) => val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Append execution.ExecutedCommand(InsertIntoHadoopFsRelation(t, query, mode)) :: Nil @@ -119,7 +133,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { projections: Seq[NamedExpression], filters: Seq[Expression], partitionColumns: StructType, - partitions: Array[Partition]) = { + partitions: Array[Partition]): SparkPlan = { val relation = logicalRelation.relation.asInstanceOf[HadoopFsRelation] // Because we are creating one RDD per partition, we need to have a shared HadoopConf. @@ -127,92 +141,84 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { val sharedHadoopConf = SparkHadoopUtil.get.conf val confBroadcast = relation.sqlContext.sparkContext.broadcast(new SerializableConfiguration(sharedHadoopConf)) + val partitionColumnNames = partitionColumns.fieldNames.toSet + + // Now, we create a scan builder, which will be used by pruneFilterProject. This scan builder + // will union all partitions and attach partition values if needed. + val scanBuilder = { + (requiredColumns: Seq[Attribute], filters: Array[Filter]) => { + val requiredDataColumns = + requiredColumns.filterNot(c => partitionColumnNames.contains(c.name)) + + // Builds RDD[Row]s for each selected partition. + val perPartitionRows = partitions.map { case Partition(partitionValues, dir) => + // Don't scan any partition columns to save I/O. Here we are being optimistic and + // assuming partition columns data stored in data files are always consistent with those + // partition values encoded in partition directory paths. + val dataRows = relation.buildInternalScan( + requiredDataColumns.map(_.name).toArray, filters, Array(dir), confBroadcast) + + // Merges data values with partition values. + mergeWithPartitionValues( + requiredColumns, + requiredDataColumns, + partitionColumns, + partitionValues, + dataRows) + } - // Builds RDD[Row]s for each selected partition. - val perPartitionRows = partitions.map { case Partition(partitionValues, dir) => - // The table scan operator (PhysicalRDD) which retrieves required columns from data files. - // Notice that the schema of data files, represented by `relation.dataSchema`, may contain - // some partition column(s). 
- val scan = - pruneFilterProject( - logicalRelation, - projections, - filters, - (columns: Seq[Attribute], filters) => { - val partitionColNames = partitionColumns.fieldNames - - // Don't scan any partition columns to save I/O. Here we are being optimistic and - // assuming partition columns data stored in data files are always consistent with those - // partition values encoded in partition directory paths. - val needed = columns.filterNot(a => partitionColNames.contains(a.name)) - val dataRows = - relation.buildScan(needed.map(_.name).toArray, filters, Array(dir), confBroadcast) - - // Merges data values with partition values. - mergeWithPartitionValues( - relation.schema, - columns.map(_.name).toArray, - partitionColNames, - partitionValues, - toCatalystRDD(logicalRelation, needed, dataRows)) - }) - - scan.execute() - } + val unionedRows = + if (perPartitionRows.length == 0) { + relation.sqlContext.emptyResult + } else { + new UnionRDD(relation.sqlContext.sparkContext, perPartitionRows) + } - val unionedRows = - if (perPartitionRows.length == 0) { - relation.sqlContext.emptyResult - } else { - new UnionRDD(relation.sqlContext.sparkContext, perPartitionRows) + unionedRows } + } - execution.PhysicalRDD(projections.map(_.toAttribute), unionedRows) + // Create the scan operator. If needed, add Filter and/or Project on top of the scan. + // The added Filter/Project is on top of the unioned RDD. We do not want to create + // one Filter/Project for every partition. + val sparkPlan = pruneFilterProject( + logicalRelation, + projections, + filters, + scanBuilder) + + sparkPlan } - // TODO: refactor this thing. It is very complicated because it does projection internally. - // We should just put a project on top of this. private def mergeWithPartitionValues( - schema: StructType, - requiredColumns: Array[String], - partitionColumns: Array[String], + requiredColumns: Seq[Attribute], + dataColumns: Seq[Attribute], + partitionColumnSchema: StructType, partitionValues: InternalRow, dataRows: RDD[InternalRow]): RDD[InternalRow] = { - val nonPartitionColumns = requiredColumns.filterNot(partitionColumns.contains) - // If output columns contain any partition column(s), we need to merge scanned data // columns and requested partition columns to form the final result. - if (!requiredColumns.sameElements(nonPartitionColumns)) { - val mergers = requiredColumns.zipWithIndex.map { case (name, index) => - // To see whether the `index`-th column is a partition column... - val i = partitionColumns.indexOf(name) - if (i != -1) { - // If yes, gets column value from partition values. - (mutableRow: MutableRow, dataRow: InternalRow, ordinal: Int) => { - mutableRow(ordinal) = partitionValues.genericGet(i) - } - } else { - // Otherwise, inherits the value from scanned data. - val i = nonPartitionColumns.indexOf(name) - (mutableRow: MutableRow, dataRow: InternalRow, ordinal: Int) => { - mutableRow(ordinal) = dataRow.genericGet(i) - } + if (requiredColumns != dataColumns) { + // Builds `AttributeReference`s for all partition columns so that we can use them to project + // required partition columns. Note that if a partition column appears in `requiredColumns`, + // we should use the `AttributeReference` in `requiredColumns`. 
+ val partitionColumns = { + val requiredColumnMap = requiredColumns.map(a => a.name -> a).toMap + partitionColumnSchema.toAttributes.map { a => + requiredColumnMap.getOrElse(a.name, a) } } - // Since we know for sure that this closure is serializable, we can avoid the overhead - // of cleaning a closure for each RDD by creating our own MapPartitionsRDD. Functionally - // this is equivalent to calling `dataRows.mapPartitions(mapPartitionsFunc)` (SPARK-7718). val mapPartitionsFunc = (_: TaskContext, _: Int, iterator: Iterator[InternalRow]) => { - val dataTypes = requiredColumns.map(schema(_).dataType) - val mutableRow = new SpecificMutableRow(dataTypes) - iterator.map { dataRow => - var i = 0 - while (i < mutableRow.numFields) { - mergers(i)(mutableRow, dataRow, i) - i += 1 - } - mutableRow.asInstanceOf[InternalRow] + // Note that we can't use an `UnsafeRowJoiner` to replace the following `JoinedRow` and + // `UnsafeProjection`. Because the projection may also adjust column order. + val mutableJoinedRow = new JoinedRow() + val unsafePartitionValues = UnsafeProjection.create(partitionColumnSchema)(partitionValues) + val unsafeProjection = + UnsafeProjection.create(requiredColumns, dataColumns ++ partitionColumns) + + iterator.map { unsafeDataRow => + unsafeProjection(mutableJoinedRow(unsafeDataRow, unsafePartitionValues)) } } @@ -222,7 +228,6 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { Utils.withDummyCallSite(dataRows.sparkContext) { new MapPartitionsRDD(dataRows, mapPartitionsFunc, preservesPartitioning = false) } - } else { dataRows } @@ -265,45 +270,99 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { relation, projects, filterPredicates, - (requestedColumns, pushedFilters) => { - scanBuilder(requestedColumns, selectFilters(pushedFilters).toArray) + (requestedColumns, _, pushedFilters) => { + scanBuilder(requestedColumns, pushedFilters.toArray) }) } - // Based on Catalyst expressions. + // Based on Catalyst expressions. The `scanBuilder` function accepts three arguments: + // + // 1. A `Seq[Attribute]`, containing all required column attributes. Used to handle relation + // traits that support column pruning (e.g. `PrunedScan` and `PrunedFilteredScan`). + // + // 2. A `Seq[Expression]`, containing all gathered Catalyst filter expressions, only used for + // `CatalystScan`. + // + // 3. A `Seq[Filter]`, containing all data source `Filter`s that are converted from (possibly a + // subset of) Catalyst filter expressions and can be handled by `relation`. Used to handle + // relation traits (`CatalystScan` excluded) that support filter push-down (e.g. + // `PrunedFilteredScan` and `HadoopFsRelation`). + // + // Note that 2 and 3 shouldn't be used together. 
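The JoinedRow/UnsafeProjection merge shown above can also be read as a standalone step. The sketch below is simplified (the patch additionally converts the partition values to an UnsafeRow first, and the names here are made up): each scanned data row is joined with the constant partition values, and a single projection both attaches those values and restores the requested column order.

{{{
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow, UnsafeProjection}

// Simplified sketch of merging scanned data rows with (constant) partition values.
object MergePartitionValuesSketch {
  def merge(
      requiredColumns: Seq[Attribute],   // output order requested by the query
      dataColumns: Seq[Attribute],       // columns actually read from data files
      partitionColumns: Seq[Attribute],  // columns encoded in the directory path
      partitionValues: InternalRow,
      dataRows: Iterator[InternalRow]): Iterator[InternalRow] = {
    val joinedRow = new JoinedRow()
    // One projection both appends partition values and restores the requested column order.
    val projection = UnsafeProjection.create(requiredColumns, dataColumns ++ partitionColumns)
    dataRows.map(row => projection(joinedRow(row, partitionValues)))
  }
}
}}}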
protected def pruneFilterProjectRaw( - relation: LogicalRelation, - projects: Seq[NamedExpression], - filterPredicates: Seq[Expression], - scanBuilder: (Seq[Attribute], Seq[Expression]) => RDD[InternalRow]) = { + relation: LogicalRelation, + projects: Seq[NamedExpression], + filterPredicates: Seq[Expression], + scanBuilder: (Seq[Attribute], Seq[Expression], Seq[Filter]) => RDD[InternalRow]) = { val projectSet = AttributeSet(projects.flatMap(_.references)) val filterSet = AttributeSet(filterPredicates.flatMap(_.references)) - val filterCondition = filterPredicates.reduceLeftOption(expressions.And) - val pushedFilters = filterPredicates.map { _ transform { + val candidatePredicates = filterPredicates.map { _ transform { case a: AttributeReference => relation.attributeMap(a) // Match original case of attributes. }} + val (unhandledPredicates, pushedFilters) = + selectFilters(relation.relation, candidatePredicates) + + // A set of column attributes that are only referenced by pushed down filters. We can eliminate + // them from requested columns. + val handledSet = { + val handledPredicates = filterPredicates.filterNot(unhandledPredicates.contains) + val unhandledSet = AttributeSet(unhandledPredicates.flatMap(_.references)) + AttributeSet(handledPredicates.flatMap(_.references)) -- + (projectSet ++ unhandledSet).map(relation.attributeMap) + } + + // Combines all Catalyst filter `Expression`s that are either not convertible to data source + // `Filter`s or cannot be handled by `relation`. + val filterCondition = unhandledPredicates.reduceLeftOption(expressions.And) + + val metadata: Map[String, String] = { + val pairs = ArrayBuffer.empty[(String, String)] + + if (pushedFilters.nonEmpty) { + pairs += (PUSHED_FILTERS -> pushedFilters.mkString("[", ", ", "]")) + } + + relation.relation match { + case r: HadoopFsRelation => pairs += INPUT_PATHS -> r.paths.mkString(", ") + case _ => + } + + pairs.toMap + } + if (projects.map(_.toAttribute) == projects && projectSet.size == projects.size && filterSet.subsetOf(projectSet)) { // When it is possible to just use column pruning to get the right projection and // when the columns of this projection are enough to evaluate all filter conditions, // just do a scan followed by a filter, with no extra project. - val requestedColumns = - projects.asInstanceOf[Seq[Attribute]] // Safe due to if above. - .map(relation.attributeMap) // Match original case of attributes. - - val scan = execution.PhysicalRDD(projects.map(_.toAttribute), - scanBuilder(requestedColumns, pushedFilters)) + val requestedColumns = projects + // Safe due to if above. + .asInstanceOf[Seq[Attribute]] + // Match original case of attributes. + .map(relation.attributeMap) + // Don't request columns that are only referenced by pushed filters. + .filterNot(handledSet.contains) + + val scan = execution.PhysicalRDD.createFromDataSource( + projects.map(_.toAttribute), + scanBuilder(requestedColumns, candidatePredicates, pushedFilters), + relation.relation, metadata) filterCondition.map(execution.Filter(_, scan)).getOrElse(scan) } else { - val requestedColumns = (projectSet ++ filterSet).map(relation.attributeMap).toSeq - - val scan = execution.PhysicalRDD(requestedColumns, - scanBuilder(requestedColumns, pushedFilters)) - execution.Project(projects, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)) + // Don't request columns that are only referenced by pushed filters. 
+ val requestedColumns = + (projectSet ++ filterSet -- handledSet).map(relation.attributeMap).toSeq + + val scan = execution.PhysicalRDD.createFromDataSource( + requestedColumns, + scanBuilder(requestedColumns, candidatePredicates, pushedFilters), + relation.relation, metadata) + execution.Project( + projects, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)) } } @@ -329,38 +388,53 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { } /** - * Selects Catalyst predicate [[Expression]]s which are convertible into data source [[Filter]]s, - * and convert them. + * Tries to translate a Catalyst [[Expression]] into data source [[Filter]]. + * + * @return a `Some[Filter]` if the input [[Expression]] is convertible, otherwise a `None`. */ - protected[sql] def selectFilters(filters: Seq[Expression]) = { - def translate(predicate: Expression): Option[Filter] = predicate match { - case expressions.EqualTo(a: Attribute, Literal(v, _)) => - Some(sources.EqualTo(a.name, v)) - case expressions.EqualTo(Literal(v, _), a: Attribute) => - Some(sources.EqualTo(a.name, v)) - - case expressions.GreaterThan(a: Attribute, Literal(v, _)) => - Some(sources.GreaterThan(a.name, v)) - case expressions.GreaterThan(Literal(v, _), a: Attribute) => - Some(sources.LessThan(a.name, v)) - - case expressions.LessThan(a: Attribute, Literal(v, _)) => - Some(sources.LessThan(a.name, v)) - case expressions.LessThan(Literal(v, _), a: Attribute) => - Some(sources.GreaterThan(a.name, v)) - - case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) => - Some(sources.GreaterThanOrEqual(a.name, v)) - case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) => - Some(sources.LessThanOrEqual(a.name, v)) - - case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) => - Some(sources.LessThanOrEqual(a.name, v)) - case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) => - Some(sources.GreaterThanOrEqual(a.name, v)) + protected[sql] def translateFilter(predicate: Expression): Option[Filter] = { + predicate match { + case expressions.EqualTo(a: Attribute, Literal(v, t)) => + Some(sources.EqualTo(a.name, convertToScala(v, t))) + case expressions.EqualTo(Literal(v, t), a: Attribute) => + Some(sources.EqualTo(a.name, convertToScala(v, t))) + + case expressions.EqualNullSafe(a: Attribute, Literal(v, t)) => + Some(sources.EqualNullSafe(a.name, convertToScala(v, t))) + case expressions.EqualNullSafe(Literal(v, t), a: Attribute) => + Some(sources.EqualNullSafe(a.name, convertToScala(v, t))) + + case expressions.GreaterThan(a: Attribute, Literal(v, t)) => + Some(sources.GreaterThan(a.name, convertToScala(v, t))) + case expressions.GreaterThan(Literal(v, t), a: Attribute) => + Some(sources.LessThan(a.name, convertToScala(v, t))) + + case expressions.LessThan(a: Attribute, Literal(v, t)) => + Some(sources.LessThan(a.name, convertToScala(v, t))) + case expressions.LessThan(Literal(v, t), a: Attribute) => + Some(sources.GreaterThan(a.name, convertToScala(v, t))) + + case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, t)) => + Some(sources.GreaterThanOrEqual(a.name, convertToScala(v, t))) + case expressions.GreaterThanOrEqual(Literal(v, t), a: Attribute) => + Some(sources.LessThanOrEqual(a.name, convertToScala(v, t))) + + case expressions.LessThanOrEqual(a: Attribute, Literal(v, t)) => + Some(sources.LessThanOrEqual(a.name, convertToScala(v, t))) + case expressions.LessThanOrEqual(Literal(v, t), a: Attribute) => + Some(sources.GreaterThanOrEqual(a.name, convertToScala(v, t))) case 
expressions.InSet(a: Attribute, set) => - Some(sources.In(a.name, set.toArray)) + val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType) + Some(sources.In(a.name, set.toArray.map(toScala))) + + // Because we only convert In to InSet in Optimizer when there are more than certain + // items. So it is possible we still get an In expression here that needs to be pushed + // down. + case expressions.In(a: Attribute, list) if !list.exists(!_.isInstanceOf[Literal]) => + val hSet = list.map(e => e.eval(EmptyRow)) + val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType) + Some(sources.In(a.name, hSet.toArray.map(toScala))) case expressions.IsNull(a: Attribute) => Some(sources.IsNull(a.name)) @@ -368,16 +442,16 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { Some(sources.IsNotNull(a.name)) case expressions.And(left, right) => - (translate(left) ++ translate(right)).reduceOption(sources.And) + (translateFilter(left) ++ translateFilter(right)).reduceOption(sources.And) case expressions.Or(left, right) => for { - leftFilter <- translate(left) - rightFilter <- translate(right) + leftFilter <- translateFilter(left) + rightFilter <- translateFilter(right) } yield sources.Or(leftFilter, rightFilter) case expressions.Not(child) => - translate(child).map(sources.Not) + translateFilter(child).map(sources.Not) case expressions.StartsWith(a: Attribute, Literal(v: UTF8String, StringType)) => Some(sources.StringStartsWith(a.name, v.toString)) @@ -390,7 +464,59 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { case _ => None } + } + + /** + * Selects Catalyst predicate [[Expression]]s which are convertible into data source [[Filter]]s + * and can be handled by `relation`. + * + * @return A pair of `Seq[Expression]` and `Seq[Filter]`. The first element contains all Catalyst + * predicate [[Expression]]s that are either not convertible or cannot be handled by + * `relation`. The second element contains all converted data source [[Filter]]s that + * will be pushed down to the data source. + */ + protected[sql] def selectFilters( + relation: BaseRelation, + predicates: Seq[Expression]): (Seq[Expression], Seq[Filter]) = { + + // For conciseness, all Catalyst filter expressions of type `expressions.Expression` below are + // called `predicate`s, while all data source filters of type `sources.Filter` are simply called + // `filter`s. + + val translated: Seq[(Expression, Filter)] = + for { + predicate <- predicates + filter <- translateFilter(predicate) + } yield predicate -> filter + + // A map from original Catalyst expressions to corresponding translated data source filters. + val translatedMap: Map[Expression, Filter] = translated.toMap + + // Catalyst predicate expressions that cannot be translated to data source filters. + val unrecognizedPredicates = predicates.filterNot(translatedMap.contains) + + // Data source filters that cannot be handled by `relation`. The semantic of a unhandled filter + // at here is that a data source may not be able to apply this filter to every row + // of the underlying dataset. + val unhandledFilters = relation.unhandledFilters(translatedMap.values.toArray).toSet + + val (unhandled, handled) = translated.partition { + case (predicate, filter) => + unhandledFilters.contains(filter) + } + + // Catalyst predicate expressions that can be translated to data source filters, but cannot be + // handled by `relation`. 
+ val (unhandledPredicates, _) = unhandled.unzip + + // Translated data source filters that can be handled by `relation` + val (_, handledFilters) = handled.unzip + + // translated contains all filters that have been converted to the public Filter interface. + // We should always push them to the data source no matter whether the data source can apply + // a filter to every row or not. + val (_, translatedFilters) = translated.unzip - filters.flatMap(translate) + (unrecognizedPredicates ++ unhandledPredicates, translatedFilters) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala new file mode 100644 index 000000000000..3b7dc2e8d021 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.RunnableCommand +import org.apache.spark.sql.sources.InsertableRelation + + +/** + * Inserts the results of `query` in to a relation that extends [[InsertableRelation]]. + */ +private[sql] case class InsertIntoDataSource( + logicalRelation: LogicalRelation, + query: LogicalPlan, + overwrite: Boolean) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val relation = logicalRelation.relation.asInstanceOf[InsertableRelation] + val data = DataFrame(sqlContext, query) + // Apply the schema of the existing table to the new data. + val df = sqlContext.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema) + relation.insert(df, overwrite) + + // Invalidate the cache. + sqlContext.cacheManager.invalidateCache(logicalRelation) + + Seq.empty[Row] + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala new file mode 100644 index 000000000000..735d52f80886 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.io.IOException + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat +import org.apache.spark._ +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.{RunnableCommand, SQLExecution} +import org.apache.spark.sql.sources._ +import org.apache.spark.util.Utils + + +/** + * A command for writing data to a [[HadoopFsRelation]]. Supports both overwriting and appending. + * Writing to dynamic partitions is also supported. Each [[InsertIntoHadoopFsRelation]] issues a + * single write job, and owns a UUID that identifies this job. Each concrete implementation of + * [[HadoopFsRelation]] should use this UUID together with task id to generate unique file path for + * each task output file. This UUID is passed to executor side via a property named + * `spark.sql.sources.writeJobUUID`. + * + * Different writer containers, [[DefaultWriterContainer]] and [[DynamicPartitionWriterContainer]] + * are used to write to normal tables and tables with dynamic partitions. + * + * Basic work flow of this command is: + * + * 1. Driver side setup, including output committer initialization and data source specific + * preparation work for the write job to be issued. + * 2. Issues a write job consists of one or more executor side tasks, each of which writes all + * rows within an RDD partition. + * 3. If no exception is thrown in a task, commits that task, otherwise aborts that task; If any + * exception is thrown during task commitment, also aborts that task. + * 4. If all tasks are committed, commit the job, otherwise aborts the job; If any exception is + * thrown during job commitment, also aborts the job. 
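One part of this flow that is easy to restate in isolation is the decision of whether to write at all, driven by the save mode and whether the output path already exists (implemented in the body below). A hedged sketch, with the error case folded into the return value for brevity; names are illustrative.

{{{
import org.apache.spark.sql.SaveMode

// Standalone sketch of the (mode, pathExists) decision.
// ErrorIfExists on an existing path actually throws an AnalysisException in the real code.
object SaveModeDecisionSketch {
  def shouldInsert(mode: SaveMode, pathExists: Boolean): Either[String, Boolean] =
    (mode, pathExists) match {
      case (SaveMode.ErrorIfExists, true) => Left("path already exists")
      case (SaveMode.Overwrite, true)     => Right(true)   // existing output is deleted first
      case (SaveMode.Append, _) | (SaveMode.Overwrite, false) | (SaveMode.ErrorIfExists, false) =>
        Right(true)
      case (SaveMode.Ignore, exists)      => Right(!exists) // skip silently if output exists
    }
}
}}}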
+ */ +private[sql] case class InsertIntoHadoopFsRelation( + @transient relation: HadoopFsRelation, + @transient query: LogicalPlan, + mode: SaveMode) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + require( + relation.paths.length == 1, + s"Cannot write to multiple destinations: ${relation.paths.mkString(",")}") + + val hadoopConf = sqlContext.sparkContext.hadoopConfiguration + val outputPath = new Path(relation.paths.head) + val fs = outputPath.getFileSystem(hadoopConf) + val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + + val pathExists = fs.exists(qualifiedOutputPath) + val doInsertion = (mode, pathExists) match { + case (SaveMode.ErrorIfExists, true) => + throw new AnalysisException(s"path $qualifiedOutputPath already exists.") + case (SaveMode.Overwrite, true) => + Utils.tryOrIOException { + if (!fs.delete(qualifiedOutputPath, true /* recursively */)) { + throw new IOException(s"Unable to clear output " + + s"directory $qualifiedOutputPath prior to writing to it") + } + } + true + case (SaveMode.Append, _) | (SaveMode.Overwrite, _) | (SaveMode.ErrorIfExists, false) => + true + case (SaveMode.Ignore, exists) => + !exists + case (s, exists) => + throw new IllegalStateException(s"unsupported save mode $s ($exists)") + } + // If we are appending data to an existing dir. + val isAppend = pathExists && (mode == SaveMode.Append) + + if (doInsertion) { + val job = new Job(hadoopConf) + job.setOutputKeyClass(classOf[Void]) + job.setOutputValueClass(classOf[InternalRow]) + FileOutputFormat.setOutputPath(job, qualifiedOutputPath) + + // A partitioned relation schema's can be different from the input logicalPlan, since + // partition columns are all moved after data column. We Project to adjust the ordering. + // TODO: this belongs in the analyzer. + val project = Project( + relation.schema.map(field => UnresolvedAttribute.quoted(field.name)), query) + val queryExecution = DataFrame(sqlContext, project).queryExecution + + SQLExecution.withNewExecutionId(sqlContext, queryExecution) { + val df = sqlContext.internalCreateDataFrame(queryExecution.toRdd, relation.schema) + val partitionColumns = relation.partitionColumns.fieldNames + + // Some pre-flight checks. + require( + df.schema == relation.schema, + s"""DataFrame must have the same schema as the relation to which is inserted. + |DataFrame schema: ${df.schema} + |Relation schema: ${relation.schema} + """.stripMargin) + val partitionColumnsInSpec = relation.partitionColumns.fieldNames + require( + partitionColumnsInSpec.sameElements(partitionColumns), + s"""Partition columns mismatch. + |Expected: ${partitionColumnsInSpec.mkString(", ")} + |Actual: ${partitionColumns.mkString(", ")} + """.stripMargin) + + val writerContainer = if (partitionColumns.isEmpty) { + new DefaultWriterContainer(relation, job, isAppend) + } else { + val output = df.queryExecution.executedPlan.output + val (partitionOutput, dataOutput) = + output.partition(a => partitionColumns.contains(a.name)) + + new DynamicPartitionWriterContainer( + relation, + job, + partitionOutput, + dataOutput, + output, + PartitioningUtils.DEFAULT_PARTITION_NAME, + sqlContext.conf.getConf(SQLConf.PARTITION_MAX_FILES), + isAppend) + } + + // This call shouldn't be put into the `try` block below because it only initializes and + // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. 
+ writerContainer.driverSideSetup() + + try { + sqlContext.sparkContext.runJob(df.queryExecution.toRdd, writerContainer.writeRows _) + writerContainer.commitJob() + relation.refresh() + } catch { case cause: Throwable => + logError("Aborting job.", cause) + writerContainer.abortJob() + throw new SparkException("Job aborted.", cause) + } + } + } else { + logInfo("Skipping insertion into a relation that already exists.") + } + + Seq.empty[Row] + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index a7123dc845fa..219dae88e515 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -17,23 +17,40 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.sources.BaseRelation /** * Used to link a [[BaseRelation]] in to a logical query plan. + * + * Note that sometimes we need to use `LogicalRelation` to replace an existing leaf node without + * changing the output attributes' IDs. The `expectedOutputAttributes` parameter is used for + * this purpose. See https://issues.apache.org/jira/browse/SPARK-10741 for more details. */ -private[sql] case class LogicalRelation(relation: BaseRelation) - extends LeafNode - with MultiInstanceRelation { +case class LogicalRelation( + relation: BaseRelation, + expectedOutputAttributes: Option[Seq[Attribute]] = None) + extends LeafNode with MultiInstanceRelation { - override val output: Seq[AttributeReference] = relation.schema.toAttributes + override val output: Seq[AttributeReference] = { + val attrs = relation.schema.toAttributes + expectedOutputAttributes.map { expectedAttrs => + assert(expectedAttrs.length == attrs.length) + attrs.zip(expectedAttrs).map { + // We should respect the attribute names provided by base relation and only use the + // exprId in `expectedOutputAttributes`. + // The reason is that, some relations(like parquet) will reconcile attribute names to + // workaround case insensitivity issue. + case (attr, expected) => attr.withExprId(expected.exprId) + } + }.getOrElse(attrs) + } // Logical Relations are distinct if they have different output for the sake of transformations. 
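The exprId-preserving trick described in the LogicalRelation scaladoc above can be illustrated in miniature (attribute names and types below are made up): keep the attribute exactly as the base relation reports it, but reference it under the expression ID the surrounding plan already expects.

{{{
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.IntegerType

// Minimal illustration of reusing an expected exprId while keeping the relation's own attribute.
object ExprIdReuseSketch {
  val fromRelation = AttributeReference("ID", IntegerType)()  // name as the relation reports it
  val expected     = AttributeReference("id", IntegerType)()  // attribute already in the plan

  // Same name/type/nullability as `fromRelation`, but referenced under the old exprId.
  val output = fromRelation.withExprId(expected.exprId)
}
}}}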
override def equals(other: Any): Boolean = other match { - case l @ LogicalRelation(otherRelation) => relation == otherRelation && output == l.output - case _ => false + case l @ LogicalRelation(otherRelation, _) => relation == otherRelation && output == l.output + case _ => false } override def hashCode: Int = { @@ -41,10 +58,15 @@ private[sql] case class LogicalRelation(relation: BaseRelation) } override def sameResult(otherPlan: LogicalPlan): Boolean = otherPlan match { - case LogicalRelation(otherRelation) => relation == otherRelation + case LogicalRelation(otherRelation, _) => relation == otherRelation case _ => false } + // When comparing two LogicalRelations from within LogicalPlan.sameResult, we only need + // LogicalRelation.cleanArgs to return Seq(relation), since expectedOutputAttribute's + // expId can be different but the relation is still the same. + override lazy val cleanArgs: Seq[Any] = Seq(relation) + @transient override lazy val statistics: Statistics = Statistics( sizeInBytes = BigInt(relation.sizeInBytes) ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 66dfcc308cec..81962f8d6378 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -26,6 +26,7 @@ import scala.util.Try import org.apache.hadoop.fs.Path import org.apache.hadoop.util.Shell +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.types._ @@ -74,11 +75,16 @@ private[sql] object PartitioningUtils { private[sql] def parsePartitions( paths: Seq[Path], defaultPartitionName: String, - typeInference: Boolean): PartitionSpec = { + typeInference: Boolean, + basePaths: Set[Path]): PartitionSpec = { // First, we need to parse every partition's path and see if we can find partition values. - val pathsWithPartitionValues = paths.flatMap { path => - parsePartition(path, defaultPartitionName, typeInference).map(path -> _) - } + val (partitionValues, optDiscoveredBasePaths) = paths.map { path => + parsePartition(path, defaultPartitionName, typeInference, basePaths) + }.unzip + + // We create pairs of (path -> path's partition value) here + // If the corresponding partition value is None, the pair will be skiped + val pathsWithPartitionValues = paths.zip(partitionValues).flatMap(x => x._2.map(x._1 -> _)) if (pathsWithPartitionValues.isEmpty) { // This dataset is not partitioned. @@ -86,6 +92,26 @@ private[sql] object PartitioningUtils { } else { // This dataset is partitioned. We need to check whether all partitions have the same // partition columns and resolve potential type conflicts. + + // Check if there is conflicting directory structure. + // For the paths such as: + // var paths = Seq( + // "hdfs://host:9000/invalidPath", + // "hdfs://host:9000/path/a=10/b=20", + // "hdfs://host:9000/path/a=10.5/b=hello") + // It will be recognised as conflicting directory structure: + // "hdfs://host:9000/invalidPath" + // "hdfs://host:9000/path" + val disvoeredBasePaths = optDiscoveredBasePaths.flatMap(x => x) + assert( + disvoeredBasePaths.distinct.size == 1, + "Conflicting directory structures detected. 
Suspicious paths:\b" + + disvoeredBasePaths.distinct.mkString("\n\t", "\n\t", "\n\n") + + "If provided paths are partition directories, please set " + + "\"basePath\" in the options of the data source to specify the " + + "root directory of the table. If there are multiple root directories, " + + "please load them separately and then union them.") + val resolvedPartitionValues = resolvePartitions(pathsWithPartitionValues) // Creates the StructType which represents the partition columns. @@ -109,12 +135,12 @@ private[sql] object PartitioningUtils { } /** - * Parses a single partition, returns column names and values of each partition column. For - * example, given: + * Parses a single partition, returns column names and values of each partition column, also + * the path when we stop partition discovery. For example, given: * {{{ * path = hdfs://:/path/to/partition/a=42/b=hello/c=3.14 * }}} - * it returns: + * it returns the partition: * {{{ * PartitionValues( * Seq("a", "b", "c"), @@ -123,34 +149,63 @@ private[sql] object PartitioningUtils { * Literal.create("hello", StringType), * Literal.create(3.14, FloatType))) * }}} + * and the path when we stop the discovery is: + * {{{ + * hdfs://:/path/to/partition + * }}} */ private[sql] def parsePartition( path: Path, defaultPartitionName: String, - typeInference: Boolean): Option[PartitionValues] = { + typeInference: Boolean, + basePaths: Set[Path]): (Option[PartitionValues], Option[Path]) = { val columns = ArrayBuffer.empty[(String, Literal)] // Old Hadoop versions don't have `Path.isRoot` var finished = path.getParent == null - var chopped = path + // currentPath is the current path that we will use to parse partition column value. + var currentPath: Path = path while (!finished) { // Sometimes (e.g., when speculative task is enabled), temporary directories may be left - // uncleaned. Here we simply ignore them. - if (chopped.getName.toLowerCase == "_temporary") { - return None + // uncleaned. Here we simply ignore them. + if (currentPath.getName.toLowerCase == "_temporary") { + return (None, None) } - val maybeColumn = parsePartitionColumn(chopped.getName, defaultPartitionName, typeInference) - maybeColumn.foreach(columns += _) - chopped = chopped.getParent - finished = maybeColumn.isEmpty || chopped.getParent == null + if (basePaths.contains(currentPath)) { + // If the currentPath is one of base paths. We should stop. + finished = true + } else { + // Let's say currentPath is a path of "/table/a=1/", currentPath.getName will give us a=1. + // Once we get the string, we try to parse it and find the partition column and value. + val maybeColumn = + parsePartitionColumn(currentPath.getName, defaultPartitionName, typeInference) + maybeColumn.foreach(columns += _) + + // Now, we determine if we should stop. + // When we hit any of the following cases, we will stop: + // - In this iteration, we could not parse the value of partition column and value, + // i.e. maybeColumn is None, and columns is not empty. At here we check if columns is + // empty to handle cases like /table/a=1/_temporary/something (we need to find a=1 in + // this case). + // - After we get the new currentPath, this new currentPath represent the top level dir + // i.e. currentPath.getParent == null. For the example of "/table/a=1/", + // the top level dir is "/table". + finished = + (maybeColumn.isEmpty && !columns.isEmpty) || currentPath.getParent == null + + if (!finished) { + // For the above example, currentPath will be "/table/". 
+ currentPath = currentPath.getParent + } + } } if (columns.isEmpty) { - None + (None, Some(path)) } else { val (columnNames, values) = columns.reverse.unzip - Some(PartitionValues(columnNames, values)) + (Some(PartitionValues(columnNames, values)), Some(currentPath)) } } @@ -270,6 +325,19 @@ private[sql] object PartitioningUtils { private val upCastingOrder: Seq[DataType] = Seq(NullType, IntegerType, LongType, FloatType, DoubleType, StringType) + def validatePartitionColumnDataTypes( + schema: StructType, + partitionColumns: Array[String], + caseSensitive: Boolean): Unit = { + + ResolvedDataSource.partitionColumnsSchema(schema, partitionColumns, caseSensitive).foreach { + field => field.dataType match { + case _: AtomicType => // OK + case _ => throw new AnalysisException(s"Cannot use ${field.dataType} for partition column") + } + } + } + /** * Given a collection of [[Literal]]s, resolves possible type conflicts by up-casting "lower" * types. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala new file mode 100644 index 000000000000..e02ee6cd6b90 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala @@ -0,0 +1,263 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.datasources + +import java.util.ServiceLoader + +import scala.collection.JavaConverters._ +import scala.language.{existentials, implicitConversions} +import scala.util.{Success, Failure, Try} + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.util.StringUtils + +import org.apache.spark.Logging +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.{DataFrame, SaveMode, AnalysisException, SQLContext} +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types.{CalendarIntervalType, StructType} +import org.apache.spark.util.Utils + + +case class ResolvedDataSource(provider: Class[_], relation: BaseRelation) + + +object ResolvedDataSource extends Logging { + + /** A map to maintain backward compatibility in case we move data sources around. 
*/ + private val backwardCompatibilityMap = Map( + "org.apache.spark.sql.jdbc" -> classOf[jdbc.DefaultSource].getCanonicalName, + "org.apache.spark.sql.jdbc.DefaultSource" -> classOf[jdbc.DefaultSource].getCanonicalName, + "org.apache.spark.sql.json" -> classOf[json.DefaultSource].getCanonicalName, + "org.apache.spark.sql.json.DefaultSource" -> classOf[json.DefaultSource].getCanonicalName, + "org.apache.spark.sql.parquet" -> classOf[parquet.DefaultSource].getCanonicalName, + "org.apache.spark.sql.parquet.DefaultSource" -> classOf[parquet.DefaultSource].getCanonicalName + ) + + /** Given a provider name, look up the data source class definition. */ + def lookupDataSource(provider0: String): Class[_] = { + val provider = backwardCompatibilityMap.getOrElse(provider0, provider0) + val provider2 = s"$provider.DefaultSource" + val loader = Utils.getContextOrSparkClassLoader + val serviceLoader = ServiceLoader.load(classOf[DataSourceRegister], loader) + + serviceLoader.asScala.filter(_.shortName().equalsIgnoreCase(provider)).toList match { + // the provider format did not match any given registered aliases + case Nil => + Try(loader.loadClass(provider)).orElse(Try(loader.loadClass(provider2))) match { + case Success(dataSource) => + // Found the data source using fully qualified path + dataSource + case Failure(error) => + if (provider.startsWith("org.apache.spark.sql.hive.orc")) { + throw new ClassNotFoundException( + "The ORC data source must be used with Hive support enabled.", error) + } else { + if (provider == "avro" || provider == "com.databricks.spark.avro") { + throw new ClassNotFoundException( + s"Failed to find data source: $provider. Please use Spark package " + + "http://spark-packages.org/package/databricks/spark-avro", + error) + } else { + throw new ClassNotFoundException( + s"Failed to find data source: $provider. Please find packages at " + + "http://spark-packages.org", + error) + } + } + } + case head :: Nil => + // there is exactly one registered alias + head.getClass + case sources => + // There are multiple registered aliases for the input + sys.error(s"Multiple sources found for $provider " + + s"(${sources.map(_.getClass.getName).mkString(", ")}), " + + "please specify the fully qualified class name.") + } + } + + /** Create a [[ResolvedDataSource]] for reading data in. 
*/ + def apply( + sqlContext: SQLContext, + userSpecifiedSchema: Option[StructType], + partitionColumns: Array[String], + provider: String, + options: Map[String, String]): ResolvedDataSource = { + val clazz: Class[_] = lookupDataSource(provider) + def className: String = clazz.getCanonicalName + val relation = userSpecifiedSchema match { + case Some(schema: StructType) => clazz.newInstance() match { + case dataSource: SchemaRelationProvider => + val caseInsensitiveOptions = new CaseInsensitiveMap(options) + if (caseInsensitiveOptions.contains("paths")) { + throw new AnalysisException(s"$className does not support paths option.") + } + dataSource.createRelation(sqlContext, caseInsensitiveOptions, schema) + case dataSource: HadoopFsRelationProvider => + val maybePartitionsSchema = if (partitionColumns.isEmpty) { + None + } else { + Some(partitionColumnsSchema( + schema, partitionColumns, sqlContext.conf.caseSensitiveAnalysis)) + } + + val caseInsensitiveOptions = new CaseInsensitiveMap(options) + val paths = { + if (caseInsensitiveOptions.contains("paths") && + caseInsensitiveOptions.contains("path")) { + throw new AnalysisException(s"Both path and paths options are present.") + } + caseInsensitiveOptions.get("paths") + .map(_.split("(? + val hdfsPath = new Path(pathString) + val fs = hdfsPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + SparkHadoopUtil.get.globPathIfNecessary(qualified).map(_.toString) + } + } + + val dataSchema = + StructType(schema.filterNot(f => partitionColumns.contains(f.name))).asNullable + + dataSource.createRelation( + sqlContext, + paths, + Some(dataSchema), + maybePartitionsSchema, + caseInsensitiveOptions) + case dataSource: org.apache.spark.sql.sources.RelationProvider => + throw new AnalysisException(s"$className does not allow user-specified schemas.") + case _ => + throw new AnalysisException(s"$className is not a RelationProvider.") + } + + case None => clazz.newInstance() match { + case dataSource: RelationProvider => + val caseInsensitiveOptions = new CaseInsensitiveMap(options) + if (caseInsensitiveOptions.contains("paths")) { + throw new AnalysisException(s"$className does not support paths option.") + } + dataSource.createRelation(sqlContext, caseInsensitiveOptions) + case dataSource: HadoopFsRelationProvider => + val caseInsensitiveOptions = new CaseInsensitiveMap(options) + val paths = { + if (caseInsensitiveOptions.contains("paths") && + caseInsensitiveOptions.contains("path")) { + throw new AnalysisException(s"Both path and paths options are present.") + } + caseInsensitiveOptions.get("paths") + .map(_.split("(? 
+ val hdfsPath = new Path(pathString) + val fs = hdfsPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + SparkHadoopUtil.get.globPathIfNecessary(qualified).map(_.toString) + } + } + dataSource.createRelation(sqlContext, paths, None, None, caseInsensitiveOptions) + case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider => + throw new AnalysisException( + s"A schema needs to be specified when using $className.") + case _ => + throw new AnalysisException( + s"$className is neither a RelationProvider nor a FSBasedRelationProvider.") + } + } + new ResolvedDataSource(clazz, relation) + } + + def partitionColumnsSchema( + schema: StructType, + partitionColumns: Array[String], + caseSensitive: Boolean): StructType = { + val equality = columnNameEquality(caseSensitive) + StructType(partitionColumns.map { col => + schema.find(f => equality(f.name, col)).getOrElse { + throw new RuntimeException(s"Partition column $col not found in schema $schema") + } + }).asNullable + } + + private def columnNameEquality(caseSensitive: Boolean): (String, String) => Boolean = { + if (caseSensitive) { + org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution + } else { + org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution + } + } + + /** Create a [[ResolvedDataSource]] for saving the content of the given DataFrame. */ + def apply( + sqlContext: SQLContext, + provider: String, + partitionColumns: Array[String], + mode: SaveMode, + options: Map[String, String], + data: DataFrame): ResolvedDataSource = { + if (data.schema.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) { + throw new AnalysisException("Cannot save interval data type into external storage.") + } + val clazz: Class[_] = lookupDataSource(provider) + val relation = clazz.newInstance() match { + case dataSource: CreatableRelationProvider => + dataSource.createRelation(sqlContext, mode, options, data) + case dataSource: HadoopFsRelationProvider => + // Don't glob path for the write path. The contracts here are: + // 1. Only one output path can be specified on the write path; + // 2. Output path must be a legal HDFS style file system path; + // 3. It's OK that the output path doesn't exist yet; + val caseInsensitiveOptions = new CaseInsensitiveMap(options) + val outputPath = { + val path = new Path(caseInsensitiveOptions("path")) + val fs = path.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + path.makeQualified(fs.getUri, fs.getWorkingDirectory) + } + + val caseSensitive = sqlContext.conf.caseSensitiveAnalysis + PartitioningUtils.validatePartitionColumnDataTypes( + data.schema, partitionColumns, caseSensitive) + + val equality = columnNameEquality(caseSensitive) + val dataSchema = StructType( + data.schema.filterNot(f => partitionColumns.exists(equality(_, f.name)))) + val r = dataSource.createRelation( + sqlContext, + Array(outputPath.toString), + Some(dataSchema.asNullable), + Some(partitionColumnsSchema(data.schema, partitionColumns, caseSensitive)), + caseInsensitiveOptions) + + // For partitioned relation r, r.schema's column ordering can be different from the column + // ordering of data.logicalPlan (partition columns are all moved after data column). This + // will be adjusted within InsertIntoHadoopFsRelation. 
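For orientation, this write-path apply is what a partitioned DataFrameWriter save is resolved through before being planned as InsertIntoHadoopFsRelation. A hedged end-user sketch, not part of the patch; the output path and column names below are placeholders:

    import org.apache.spark.sql.{SQLContext, SaveMode}

    // Illustrative only: a partitioned Parquet write that is resolved through
    // ResolvedDataSource.apply(sqlContext, provider, partitionColumns, mode, options, data).
    def writePartitioned(sqlContext: SQLContext): Unit = {
      val df = sqlContext.range(0, 100).selectExpr("id", "id % 10 AS bucket")
      df.write
        .format("parquet")
        .partitionBy("bucket")
        .mode(SaveMode.Append)
        .save("/tmp/example_table")
    }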
+ sqlContext.executePlan( + InsertIntoHadoopFsRelation( + r, + data.logicalPlan, + mode)).toRdd + r + case _ => + sys.error(s"${clazz.getCanonicalName} does not allow create table as select.") + } + ResolvedDataSource(clazz, relation) + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala similarity index 68% rename from core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala index 35e44cb59c1b..eea780cbaa7e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala @@ -26,22 +26,21 @@ import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit} -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil -import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.{Partition => SparkPartition, _} -import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD +import org.apache.spark.sql.{SQLConf, SQLContext} +import org.apache.spark.sql.execution.datasources.parquet.UnsafeRowParquetRecordReader import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager} +import org.apache.spark.{Partition => SparkPartition, _} private[spark] class SqlNewHadoopPartition( rddId: Int, val index: Int, - @transient rawSplit: InputSplit with Writable) + rawSplit: InputSplit with Writable) extends SparkPartition { val serializableHadoopSplit = new SerializableWritable(rawSplit) @@ -60,18 +59,16 @@ private[spark] class SqlNewHadoopPartition( * and the executor side to the shared Hadoop Configuration. * * Note: This is RDD is basically a cloned version of [[org.apache.spark.rdd.NewHadoopRDD]] with - * changes based on [[org.apache.spark.rdd.HadoopRDD]]. In future, this functionality will be - * folded into core. + * changes based on [[org.apache.spark.rdd.HadoopRDD]]. 
*/ -private[spark] class SqlNewHadoopRDD[K, V]( - @transient sc : SparkContext, +private[spark] class SqlNewHadoopRDD[V: ClassTag]( + sqlContext: SQLContext, broadcastedConf: Broadcast[SerializableConfiguration], - @transient initDriverSideJobFuncOpt: Option[Job => Unit], + @transient private val initDriverSideJobFuncOpt: Option[Job => Unit], initLocalJobFuncOpt: Option[Job => Unit], - inputFormatClass: Class[_ <: InputFormat[K, V]], - keyClass: Class[K], + inputFormatClass: Class[_ <: InputFormat[Void, V]], valueClass: Class[V]) - extends RDD[(K, V)](sc, Nil) + extends RDD[V](sqlContext.sparkContext, Nil) with SparkHadoopMapReduceUtil with Logging { @@ -90,7 +87,7 @@ private[spark] class SqlNewHadoopRDD[K, V]( if (isDriverSide) { initDriverSideJobFuncOpt.map(f => f(job)) } - job.getConfiguration + SparkHadoopUtil.get.getConfigurationFromJobContext(job) } private val jobTrackerId: String = { @@ -100,6 +97,11 @@ private[spark] class SqlNewHadoopRDD[K, V]( @transient protected val jobId = new JobID(jobTrackerId, id) + // If true, enable using the custom RecordReader for parquet. This only works for + // a subset of the types (no complex types). + protected val enableUnsafeRowParquetReader: Boolean = + sqlContext.getConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key).toBoolean + override def getPartitions: Array[SparkPartition] = { val conf = getConf(isDriverSide = true) val inputFormat = inputFormatClass.newInstance @@ -119,9 +121,9 @@ private[spark] class SqlNewHadoopRDD[K, V]( } override def compute( - theSplit: SparkPartition, - context: TaskContext): InterruptibleIterator[(K, V)] = { - val iter = new Iterator[(K, V)] { + theSplit: SparkPartition, + context: TaskContext): Iterator[V] = { + val iter = new Iterator[V] { val split = theSplit.asInstanceOf[SqlNewHadoopPartition] logInfo("Input split: " + split.serializableHadoopSplit) val conf = getConf(isDriverSide = false) @@ -131,8 +133,8 @@ private[spark] class SqlNewHadoopRDD[K, V]( // Sets the thread local variable for the file's name split.serializableHadoopSplit.value match { - case fs: FileSplit => SqlNewHadoopRDD.setInputFileName(fs.getPath.toString) - case _ => SqlNewHadoopRDD.unsetInputFileName() + case fs: FileSplit => SqlNewHadoopRDDState.setInputFileName(fs.getPath.toString) + case _ => SqlNewHadoopRDDState.unsetInputFileName() } // Find a function that will return the FileSystem bytes read by this thread. Do this before @@ -146,25 +148,48 @@ private[spark] class SqlNewHadoopRDD[K, V]( } inputMetrics.setBytesReadCallback(bytesReadCallback) - val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) - val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) val format = inputFormatClass.newInstance format match { case configurable: Configurable => configurable.setConf(conf) case _ => } - private var reader = format.createRecordReader( - split.serializableHadoopSplit.value, hadoopAttemptContext) - reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) + val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) + private[this] var reader: RecordReader[Void, V] = null + + /** + * If the format is ParquetInputFormat, try to create the optimized RecordReader. If this + * fails (for example, unsupported schema), try with the normal reader. + * TODO: plumb this through a different way? 
+ */ + if (enableUnsafeRowParquetReader && + format.getClass.getName == "org.apache.parquet.hadoop.ParquetInputFormat") { + val parquetReader: UnsafeRowParquetRecordReader = new UnsafeRowParquetRecordReader() + if (!parquetReader.tryInitialize( + split.serializableHadoopSplit.value, hadoopAttemptContext)) { + parquetReader.close() + } else { + reader = parquetReader.asInstanceOf[RecordReader[Void, V]] + } + } + + if (reader == null) { + reader = format.createRecordReader( + split.serializableHadoopSplit.value, hadoopAttemptContext) + reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + } // Register an on-task-completion callback to close the input stream. context.addTaskCompletionListener(context => close()) - var havePair = false - var finished = false - var recordsSinceMetricsUpdate = 0 + + private[this] var havePair = false + private[this] var finished = false override def hasNext: Boolean = { + if (context.isInterrupted) { + throw new TaskKilledException + } if (!finished && !havePair) { finished = !reader.nextKeyValue if (finished) { @@ -178,7 +203,7 @@ private[spark] class SqlNewHadoopRDD[K, V]( !finished } - override def next(): (K, V) = { + override def next(): V = { if (!hasNext) { throw new java.util.NoSuchElementException("End of stream") } @@ -186,49 +211,43 @@ private[spark] class SqlNewHadoopRDD[K, V]( if (!finished) { inputMetrics.incRecordsRead(1) } - (reader.getCurrentKey, reader.getCurrentValue) + reader.getCurrentValue } private def close() { - try { - if (reader != null) { + if (reader != null) { + SqlNewHadoopRDDState.unsetInputFileName() + // Close the reader and release it. Note: it's very important that we don't close the + // reader more than once, since that exposes us to MAPREDUCE-5918 when running against + // Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic + // corruption issues when reading compressed input. + try { reader.close() - reader = null - - SqlNewHadoopRDD.unsetInputFileName() - - if (bytesReadCallback.isDefined) { - inputMetrics.updateBytesRead() - } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || - split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { - // If we can't get the bytes read from the FS stats, fall back to the split size, - // which may be inaccurate. - try { - inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) - } catch { - case e: java.io.IOException => - logWarning("Unable to get input size to set InputMetrics for task", e) + } catch { + case e: Exception => + if (!ShutdownHookManager.inShutdown()) { + logWarning("Exception in RecordReader.close()", e) } - } + } finally { + reader = null } - } catch { - case e: Exception => { - if (!Utils.inShutdown()) { - logWarning("Exception in RecordReader.close()", e) + if (bytesReadCallback.isDefined) { + inputMetrics.updateBytesRead() + } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || + split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { + // If we can't get the bytes read from the FS stats, fall back to the split size, + // which may be inaccurate. + try { + inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) + } catch { + case e: java.io.IOException => + logWarning("Unable to get input size to set InputMetrics for task", e) } } } } } - new InterruptibleIterator(context, iter) - } - - /** Maps over a partition, providing the InputSplit that was used as the base of the partition. 
*/ - @DeveloperApi - def mapPartitionsWithInputSplit[U: ClassTag]( - f: (InputSplit, Iterator[(K, V)]) => Iterator[U], - preservesPartitioning: Boolean = false): RDD[U] = { - new NewHadoopMapPartitionsWithSplitRDD(this, f, preservesPartitioning) + iter } override def getPreferredLocations(hsplit: SparkPartition): Seq[String] = { @@ -256,23 +275,6 @@ private[spark] class SqlNewHadoopRDD[K, V]( } super.persist(storageLevel) } -} - -private[spark] object SqlNewHadoopRDD { - - /** - * The thread variable for the name of the current file being read. This is used by - * the InputFileName function in Spark SQL. - */ - private[this] val inputFileName: ThreadLocal[UTF8String] = new ThreadLocal[UTF8String] { - override protected def initialValue(): UTF8String = UTF8String.fromString("") - } - - def getInputFileName(): UTF8String = inputFileName.get() - - private[spark] def setInputFileName(file: String) = inputFileName.set(UTF8String.fromString(file)) - - private[spark] def unsetInputFileName(): Unit = inputFileName.remove() /** * Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala new file mode 100644 index 000000000000..ad5536725889 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -0,0 +1,455 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources + +import java.util.{Date, UUID} + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter} +import org.apache.spark._ +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.mapred.SparkHadoopMapRedUtil +import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.UnsafeKVExternalSorter +import org.apache.spark.sql.sources.{HadoopFsRelation, OutputWriter, OutputWriterFactory} +import org.apache.spark.sql.types.{StructType, StringType} +import org.apache.spark.util.SerializableConfiguration + + +private[sql] abstract class BaseWriterContainer( + @transient val relation: HadoopFsRelation, + @transient private val job: Job, + isAppend: Boolean) + extends SparkHadoopMapReduceUtil + with Logging + with Serializable { + + protected val dataSchema = relation.dataSchema + + protected val serializableConf = + new SerializableConfiguration(SparkHadoopUtil.get.getConfigurationFromJobContext(job)) + + // This UUID is used to avoid output file name collision between different appending write jobs. + // These jobs may belong to different SparkContext instances. Concrete data source implementations + // may use this UUID to generate unique file names (e.g., `part-r--.parquet`). + // The reason why this ID is used to identify a job rather than a single task output file is + // that, speculative tasks must generate the same output file name as the original task. + private val uniqueWriteJobId = UUID.randomUUID() + + // This is only used on driver side. + @transient private val jobContext: JobContext = job + + private val speculationEnabled: Boolean = + relation.sqlContext.sparkContext.conf.getBoolean("spark.speculation", defaultValue = false) + + // The following fields are initialized and used on both driver and executor side. + @transient protected var outputCommitter: OutputCommitter = _ + @transient private var jobId: JobID = _ + @transient private var taskId: TaskID = _ + @transient private var taskAttemptId: TaskAttemptID = _ + @transient protected var taskAttemptContext: TaskAttemptContext = _ + + protected val outputPath: String = { + assert( + relation.paths.length == 1, + s"Cannot write to multiple destinations: ${relation.paths.mkString(",")}") + relation.paths.head + } + + protected var outputWriterFactory: OutputWriterFactory = _ + + private var outputFormatClass: Class[_ <: OutputFormat[_, _]] = _ + + def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit + + def driverSideSetup(): Unit = { + setupIDs(0, 0, 0) + setupConf() + + // This UUID is sent to executor side together with the serialized `Configuration` object within + // the `Job` instance. `OutputWriters` on the executor side should use this UUID to generate + // unique task output files. + SparkHadoopUtil.get.getConfigurationFromJobContext(job). + set("spark.sql.sources.writeJobUUID", uniqueWriteJobId.toString) + + // Order of the following two lines is important. For Hadoop 1, TaskAttemptContext constructor + // clones the Configuration object passed in. If we initialize the TaskAttemptContext first, + // configurations made in prepareJobForWrite(job) are not populated into the TaskAttemptContext. 
+ // + // Also, the `prepareJobForWrite` call must happen before initializing output format and output + // committer, since their initialization involve the job configuration, which can be potentially + // decorated in `prepareJobForWrite`. + outputWriterFactory = relation.prepareJobForWrite(job) + taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId) + + outputFormatClass = job.getOutputFormatClass + outputCommitter = newOutputCommitter(taskAttemptContext) + outputCommitter.setupJob(jobContext) + } + + def executorSideSetup(taskContext: TaskContext): Unit = { + setupIDs(taskContext.stageId(), taskContext.partitionId(), taskContext.attemptNumber()) + setupConf() + taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId) + outputCommitter = newOutputCommitter(taskAttemptContext) + outputCommitter.setupTask(taskAttemptContext) + } + + protected def getWorkPath: String = { + outputCommitter match { + // FileOutputCommitter writes to a temporary location returned by `getWorkPath`. + case f: MapReduceFileOutputCommitter => f.getWorkPath.toString + case _ => outputPath + } + } + + protected def newOutputWriter(path: String): OutputWriter = { + try { + outputWriterFactory.newInstance(path, dataSchema, taskAttemptContext) + } catch { + case e: org.apache.hadoop.fs.FileAlreadyExistsException => + if (outputCommitter.isInstanceOf[parquet.DirectParquetOutputCommitter]) { + // Spark-11382: DirectParquetOutputCommitter is not idempotent, meaning on retry + // attempts, the task will fail because the output file is created from a prior attempt. + // This often means the most visible error to the user is misleading. Augment the error + // to tell the user to look for the actual error. + throw new SparkException("The output file already exists but this could be due to a " + + "failure from an earlier attempt. Look through the earlier logs or stage page for " + + "the first error.\n File exists error: " + e) + } + throw e + } + } + + private def newOutputCommitter(context: TaskAttemptContext): OutputCommitter = { + val defaultOutputCommitter = outputFormatClass.newInstance().getOutputCommitter(context) + + if (isAppend) { + // If we are appending data to an existing dir, we will only use the output committer + // associated with the file output format since it is not safe to use a custom + // committer for appending. For example, in S3, direct parquet output committer may + // leave partial data in the destination dir when the the appending job fails. + // + // See SPARK-8578 for more details + logInfo( + s"Using default output committer ${defaultOutputCommitter.getClass.getCanonicalName} " + + "for appending.") + defaultOutputCommitter + } else if (speculationEnabled) { + // When speculation is enabled, it's not safe to use customized output committer classes, + // especially direct output committers (e.g. `DirectParquetOutputCommitter`). + // + // See SPARK-9899 for more details. 
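The committer selection built up in newOutputCommitter boils down to three cases. A decision-only sketch, not part of the patch, that leaves out the reflective instantiation; configuredClass stands in for the value of the SQLConf output committer setting:

    // Which committer class to use for this write job.
    def chooseCommitter(
        isAppend: Boolean,
        speculationEnabled: Boolean,
        configuredClass: Option[String],
        defaultClass: String): String = {
      if (isAppend) defaultClass                  // appending: custom committers are unsafe
      else if (speculationEnabled) defaultClass   // speculation: direct committers are unsafe
      else configuredClass.getOrElse(defaultClass)
    }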
+ logInfo( + s"Using default output committer ${defaultOutputCommitter.getClass.getCanonicalName} " + + "because spark.speculation is configured to be true.") + defaultOutputCommitter + } else { + val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val committerClass = configuration.getClass( + SQLConf.OUTPUT_COMMITTER_CLASS.key, null, classOf[OutputCommitter]) + + Option(committerClass).map { clazz => + logInfo(s"Using user defined output committer class ${clazz.getCanonicalName}") + + // Every output format based on org.apache.hadoop.mapreduce.lib.output.OutputFormat + // has an associated output committer. To override this output committer, + // we will first try to use the output committer set in SQLConf.OUTPUT_COMMITTER_CLASS. + // If a data source needs to override the output committer, it needs to set the + // output committer in prepareForWrite method. + if (classOf[MapReduceFileOutputCommitter].isAssignableFrom(clazz)) { + // The specified output committer is a FileOutputCommitter. + // So, we will use the FileOutputCommitter-specified constructor. + val ctor = clazz.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext]) + ctor.newInstance(new Path(outputPath), context) + } else { + // The specified output committer is just a OutputCommitter. + // So, we will use the no-argument constructor. + val ctor = clazz.getDeclaredConstructor() + ctor.newInstance() + } + }.getOrElse { + // If output committer class is not set, we will use the one associated with the + // file output format. + logInfo( + s"Using output committer class ${defaultOutputCommitter.getClass.getCanonicalName}") + defaultOutputCommitter + } + } + } + + private def setupIDs(jobId: Int, splitId: Int, attemptId: Int): Unit = { + this.jobId = SparkHadoopWriter.createJobID(new Date, jobId) + this.taskId = new TaskID(this.jobId, true, splitId) + // scalastyle:off jobcontext + this.taskAttemptId = new TaskAttemptID(taskId, attemptId) + // scalastyle:on jobcontext + } + + private def setupConf(): Unit = { + serializableConf.value.set("mapred.job.id", jobId.toString) + serializableConf.value.set("mapred.tip.id", taskAttemptId.getTaskID.toString) + serializableConf.value.set("mapred.task.id", taskAttemptId.toString) + serializableConf.value.setBoolean("mapred.task.is.map", true) + serializableConf.value.setInt("mapred.task.partition", 0) + } + + def commitTask(): Unit = { + SparkHadoopMapRedUtil.commitTask(outputCommitter, taskAttemptContext, jobId.getId, taskId.getId) + } + + def abortTask(): Unit = { + if (outputCommitter != null) { + outputCommitter.abortTask(taskAttemptContext) + } + logError(s"Task attempt $taskAttemptId aborted.") + } + + def commitJob(): Unit = { + outputCommitter.commitJob(jobContext) + logInfo(s"Job $jobId committed.") + } + + def abortJob(): Unit = { + if (outputCommitter != null) { + outputCommitter.abortJob(jobContext, JobStatus.State.FAILED) + } + logError(s"Job $jobId aborted.") + } +} + +/** + * A writer that writes all of the rows in a partition to a single file. 
+ */ +private[sql] class DefaultWriterContainer( + relation: HadoopFsRelation, + job: Job, + isAppend: Boolean) + extends BaseWriterContainer(relation, job, isAppend) { + + def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { + executorSideSetup(taskContext) + val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(taskAttemptContext) + configuration.set("spark.sql.sources.output.path", outputPath) + val writer = newOutputWriter(getWorkPath) + writer.initConverter(dataSchema) + + var writerClosed = false + + // If anything below fails, we should abort the task. + try { + while (iterator.hasNext) { + val internalRow = iterator.next() + writer.writeInternal(internalRow) + } + + commitTask() + } catch { + case cause: Throwable => + logError("Aborting task.", cause) + abortTask() + throw new SparkException("Task failed while writing rows.", cause) + } + + def commitTask(): Unit = { + try { + assert(writer != null, "OutputWriter instance should have been initialized") + if (!writerClosed) { + writer.close() + writerClosed = true + } + super.commitTask() + } catch { + case cause: Throwable => + // This exception will be handled in `InsertIntoHadoopFsRelation.insert$writeRows`, and + // will cause `abortTask()` to be invoked. + throw new RuntimeException("Failed to commit task", cause) + } + } + + def abortTask(): Unit = { + try { + if (!writerClosed) { + writer.close() + writerClosed = true + } + } finally { + super.abortTask() + } + } + } +} + +/** + * A writer that dynamically opens files based on the given partition columns. Internally this is + * done by maintaining a HashMap of open files until `maxFiles` is reached. If this occurs, the + * writer externally sorts the remaining rows and then writes out them out one file at a time. + */ +private[sql] class DynamicPartitionWriterContainer( + relation: HadoopFsRelation, + job: Job, + partitionColumns: Seq[Attribute], + dataColumns: Seq[Attribute], + inputSchema: Seq[Attribute], + defaultPartitionName: String, + maxOpenFiles: Int, + isAppend: Boolean) + extends BaseWriterContainer(relation, job, isAppend) { + + def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { + val outputWriters = new java.util.HashMap[InternalRow, OutputWriter] + executorSideSetup(taskContext) + + var outputWritersCleared = false + + // Returns the partition key given an input row + val getPartitionKey = UnsafeProjection.create(partitionColumns, inputSchema) + // Returns the data columns to be written given an input row + val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema) + + // Expressions that given a partition key build a string like: col1=val/col2=val/... + val partitionStringExpression = partitionColumns.zipWithIndex.flatMap { case (c, i) => + val escaped = + ScalaUDF( + PartitioningUtils.escapePathName _, StringType, Seq(Cast(c, StringType)), Seq(StringType)) + val str = If(IsNull(c), Literal(defaultPartitionName), escaped) + val partitionName = Literal(c.name + "=") :: str :: Nil + if (i == 0) partitionName else Literal(Path.SEPARATOR) :: partitionName + } + + // Returns the partition path given a partition key. + val getPartitionString = + UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionColumns) + + // If anything below fails, we should abort the task. + try { + // This will be filled in if we have to fall back on sorting. 
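The Catalyst expression assembled above only renders a row's partition key as col1=v1/col2=v2/...; the same rendering in plain Scala, as a sketch rather than the patch's implementation, with a deliberately simplified escape step standing in for PartitioningUtils.escapePathName:

    // Renders one "name=value" fragment per partition column and joins them with '/'.
    // A null value becomes the default partition name, mirroring If(IsNull(c), ...) above.
    def partitionPath(
        columns: Seq[String],
        values: Seq[Any],
        defaultPartitionName: String): String = {
      def escape(s: String): String = // stand-in: the real escaping covers more characters
        s.replace("/", "%2F").replace("=", "%3D")
      columns.zip(values).map { case (name, value) =>
        val v = if (value == null) defaultPartitionName else escape(value.toString)
        s"$name=$v"
      }.mkString("/")
    }

    // partitionPath(Seq("year", "month"), Seq(2015, null), "__HIVE_DEFAULT_PARTITION__")
    //   == "year=2015/month=__HIVE_DEFAULT_PARTITION__"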
+ var sorter: UnsafeKVExternalSorter = null + while (iterator.hasNext && sorter == null) { + val inputRow = iterator.next() + val currentKey = getPartitionKey(inputRow) + var currentWriter = outputWriters.get(currentKey) + + if (currentWriter == null) { + if (outputWriters.size < maxOpenFiles) { + currentWriter = newOutputWriter(currentKey) + outputWriters.put(currentKey.copy(), currentWriter) + currentWriter.writeInternal(getOutputRow(inputRow)) + } else { + logInfo(s"Maximum partitions reached, falling back on sorting.") + sorter = new UnsafeKVExternalSorter( + StructType.fromAttributes(partitionColumns), + StructType.fromAttributes(dataColumns), + SparkEnv.get.blockManager, + TaskContext.get().taskMemoryManager().pageSizeBytes) + sorter.insertKV(currentKey, getOutputRow(inputRow)) + } + } else { + currentWriter.writeInternal(getOutputRow(inputRow)) + } + } + + // If the sorter is not null that means that we reached the maxFiles above and need to finish + // using external sort. + if (sorter != null) { + while (iterator.hasNext) { + val currentRow = iterator.next() + sorter.insertKV(getPartitionKey(currentRow), getOutputRow(currentRow)) + } + + logInfo(s"Sorting complete. Writing out partition files one at a time.") + + val sortedIterator = sorter.sortedIterator() + var currentKey: InternalRow = null + var currentWriter: OutputWriter = null + try { + while (sortedIterator.next()) { + if (currentKey != sortedIterator.getKey) { + if (currentWriter != null) { + currentWriter.close() + } + currentKey = sortedIterator.getKey.copy() + logDebug(s"Writing partition: $currentKey") + + // Either use an existing file from before, or open a new one. + currentWriter = outputWriters.remove(currentKey) + if (currentWriter == null) { + currentWriter = newOutputWriter(currentKey) + } + } + + currentWriter.writeInternal(sortedIterator.getValue) + } + } finally { + if (currentWriter != null) { currentWriter.close() } + } + } + + commitTask() + } catch { + case cause: Throwable => + logError("Aborting task.", cause) + abortTask() + throw new SparkException("Task failed while writing rows.", cause) + } + + /** Open and returns a new OutputWriter given a partition key. 
*/ + def newOutputWriter(key: InternalRow): OutputWriter = { + val partitionPath = getPartitionString(key).getString(0) + val path = new Path(getWorkPath, partitionPath) + val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(taskAttemptContext) + configuration.set( + "spark.sql.sources.output.path", new Path(outputPath, partitionPath).toString) + val newWriter = super.newOutputWriter(path.toString) + newWriter.initConverter(dataSchema) + newWriter + } + + def clearOutputWriters(): Unit = { + if (!outputWritersCleared) { + outputWriters.asScala.values.foreach(_.close()) + outputWriters.clear() + outputWritersCleared = true + } + } + + def commitTask(): Unit = { + try { + clearOutputWriters() + super.commitTask() + } catch { + case cause: Throwable => + throw new RuntimeException("Failed to commit task", cause) + } + } + + def abortTask(): Unit = { + try { + clearOutputWriters() + } finally { + super.abortTask() + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala deleted file mode 100644 index d551f386eee6..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala +++ /dev/null @@ -1,599 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.datasources - -import java.util.{Date, UUID} - -import scala.collection.JavaConversions.asScalaIterator - -import org.apache.hadoop.fs.Path -import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter, FileOutputFormat} -import org.apache.spark._ -import org.apache.spark.mapred.SparkHadoopMapRedUtil -import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil -import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.execution.RunnableCommand -import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.StringType -import org.apache.spark.util.SerializableConfiguration - - -private[sql] case class InsertIntoDataSource( - logicalRelation: LogicalRelation, - query: LogicalPlan, - overwrite: Boolean) - extends RunnableCommand { - - override def run(sqlContext: SQLContext): Seq[Row] = { - val relation = logicalRelation.relation.asInstanceOf[InsertableRelation] - val data = DataFrame(sqlContext, query) - // Apply the schema of the existing table to the new data. - val df = sqlContext.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema) - relation.insert(df, overwrite) - - // Invalidate the cache. - sqlContext.cacheManager.invalidateCache(logicalRelation) - - Seq.empty[Row] - } -} - -/** - * A command for writing data to a [[HadoopFsRelation]]. Supports both overwriting and appending. - * Writing to dynamic partitions is also supported. Each [[InsertIntoHadoopFsRelation]] issues a - * single write job, and owns a UUID that identifies this job. Each concrete implementation of - * [[HadoopFsRelation]] should use this UUID together with task id to generate unique file path for - * each task output file. This UUID is passed to executor side via a property named - * `spark.sql.sources.writeJobUUID`. - * - * Different writer containers, [[DefaultWriterContainer]] and [[DynamicPartitionWriterContainer]] - * are used to write to normal tables and tables with dynamic partitions. - * - * Basic work flow of this command is: - * - * 1. Driver side setup, including output committer initialization and data source specific - * preparation work for the write job to be issued. - * 2. Issues a write job consists of one or more executor side tasks, each of which writes all - * rows within an RDD partition. - * 3. If no exception is thrown in a task, commits that task, otherwise aborts that task; If any - * exception is thrown during task commitment, also aborts that task. - * 4. If all tasks are committed, commit the job, otherwise aborts the job; If any exception is - * thrown during job commitment, also aborts the job. 
- */ -private[sql] case class InsertIntoHadoopFsRelation( - @transient relation: HadoopFsRelation, - @transient query: LogicalPlan, - mode: SaveMode) - extends RunnableCommand { - - override def run(sqlContext: SQLContext): Seq[Row] = { - require( - relation.paths.length == 1, - s"Cannot write to multiple destinations: ${relation.paths.mkString(",")}") - - val hadoopConf = sqlContext.sparkContext.hadoopConfiguration - val outputPath = new Path(relation.paths.head) - val fs = outputPath.getFileSystem(hadoopConf) - val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - - val pathExists = fs.exists(qualifiedOutputPath) - val doInsertion = (mode, pathExists) match { - case (SaveMode.ErrorIfExists, true) => - throw new AnalysisException(s"path $qualifiedOutputPath already exists.") - case (SaveMode.Overwrite, true) => - fs.delete(qualifiedOutputPath, true) - true - case (SaveMode.Append, _) | (SaveMode.Overwrite, _) | (SaveMode.ErrorIfExists, false) => - true - case (SaveMode.Ignore, exists) => - !exists - case (s, exists) => - throw new IllegalStateException(s"unsupported save mode $s ($exists)") - } - // If we are appending data to an existing dir. - val isAppend = pathExists && (mode == SaveMode.Append) - - if (doInsertion) { - val job = new Job(hadoopConf) - job.setOutputKeyClass(classOf[Void]) - job.setOutputValueClass(classOf[InternalRow]) - FileOutputFormat.setOutputPath(job, qualifiedOutputPath) - - // We create a DataFrame by applying the schema of relation to the data to make sure. - // We are writing data based on the expected schema, - val df = { - // For partitioned relation r, r.schema's column ordering can be different from the column - // ordering of data.logicalPlan (partition columns are all moved after data column). We - // need a Project to adjust the ordering, so that inside InsertIntoHadoopFsRelation, we can - // safely apply the schema of r.schema to the data. - val project = Project( - relation.schema.map(field => new UnresolvedAttribute(Seq(field.name))), query) - - sqlContext.internalCreateDataFrame( - DataFrame(sqlContext, project).queryExecution.toRdd, relation.schema) - } - - val partitionColumns = relation.partitionColumns.fieldNames - if (partitionColumns.isEmpty) { - insert(new DefaultWriterContainer(relation, job, isAppend), df) - } else { - val writerContainer = new DynamicPartitionWriterContainer( - relation, job, partitionColumns, PartitioningUtils.DEFAULT_PARTITION_NAME, isAppend) - insertWithDynamicPartitions(sqlContext, writerContainer, df, partitionColumns) - } - } - - Seq.empty[Row] - } - - /** - * Inserts the content of the [[DataFrame]] into a table without any partitioning columns. - */ - private def insert(writerContainer: BaseWriterContainer, df: DataFrame): Unit = { - // Uses local vals for serialization - val needsConversion = relation.needConversion - val dataSchema = relation.dataSchema - - // This call shouldn't be put into the `try` block below because it only initializes and - // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. 
- writerContainer.driverSideSetup() - - try { - df.sqlContext.sparkContext.runJob(df.queryExecution.toRdd, writeRows _) - writerContainer.commitJob() - relation.refresh() - } catch { case cause: Throwable => - logError("Aborting job.", cause) - writerContainer.abortJob() - throw new SparkException("Job aborted.", cause) - } - - def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { - // If anything below fails, we should abort the task. - try { - writerContainer.executorSideSetup(taskContext) - - if (needsConversion) { - val converter = CatalystTypeConverters.createToScalaConverter(dataSchema) - .asInstanceOf[InternalRow => Row] - while (iterator.hasNext) { - val internalRow = iterator.next() - writerContainer.outputWriterForRow(internalRow).write(converter(internalRow)) - } - } else { - while (iterator.hasNext) { - val internalRow = iterator.next() - writerContainer.outputWriterForRow(internalRow) - .asInstanceOf[OutputWriterInternal].writeInternal(internalRow) - } - } - - writerContainer.commitTask() - } catch { case cause: Throwable => - logError("Aborting task.", cause) - writerContainer.abortTask() - throw new SparkException("Task failed while writing rows.", cause) - } - } - } - - /** - * Inserts the content of the [[DataFrame]] into a table with partitioning columns. - */ - private def insertWithDynamicPartitions( - sqlContext: SQLContext, - writerContainer: BaseWriterContainer, - df: DataFrame, - partitionColumns: Array[String]): Unit = { - // Uses a local val for serialization - val needsConversion = relation.needConversion - val dataSchema = relation.dataSchema - - require( - df.schema == relation.schema, - s"""DataFrame must have the same schema as the relation to which is inserted. - |DataFrame schema: ${df.schema} - |Relation schema: ${relation.schema} - """.stripMargin) - - val partitionColumnsInSpec = relation.partitionColumns.fieldNames - require( - partitionColumnsInSpec.sameElements(partitionColumns), - s"""Partition columns mismatch. - |Expected: ${partitionColumnsInSpec.mkString(", ")} - |Actual: ${partitionColumns.mkString(", ")} - """.stripMargin) - - val output = df.queryExecution.executedPlan.output - val (partitionOutput, dataOutput) = output.partition(a => partitionColumns.contains(a.name)) - val codegenEnabled = df.sqlContext.conf.codegenEnabled - - // This call shouldn't be put into the `try` block below because it only initializes and - // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. - writerContainer.driverSideSetup() - - try { - df.sqlContext.sparkContext.runJob(df.queryExecution.toRdd, writeRows _) - writerContainer.commitJob() - relation.refresh() - } catch { case cause: Throwable => - logError("Aborting job.", cause) - writerContainer.abortJob() - throw new SparkException("Job aborted.", cause) - } - - def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { - // If anything below fails, we should abort the task. - try { - writerContainer.executorSideSetup(taskContext) - - // Projects all partition columns and casts them to strings to build partition directories. 
- val partitionCasts = partitionOutput.map(Cast(_, StringType)) - val partitionProj = newProjection(codegenEnabled, partitionCasts, output) - val dataProj = newProjection(codegenEnabled, dataOutput, output) - - if (needsConversion) { - val converter = CatalystTypeConverters.createToScalaConverter(dataSchema) - .asInstanceOf[InternalRow => Row] - while (iterator.hasNext) { - val internalRow = iterator.next() - val partitionPart = partitionProj(internalRow) - val dataPart = converter(dataProj(internalRow)) - writerContainer.outputWriterForRow(partitionPart).write(dataPart) - } - } else { - while (iterator.hasNext) { - val internalRow = iterator.next() - val partitionPart = partitionProj(internalRow) - val dataPart = dataProj(internalRow) - writerContainer.outputWriterForRow(partitionPart) - .asInstanceOf[OutputWriterInternal].writeInternal(dataPart) - } - } - - writerContainer.commitTask() - } catch { case cause: Throwable => - logError("Aborting task.", cause) - writerContainer.abortTask() - throw new SparkException("Task failed while writing rows.", cause) - } - } - } - - // This is copied from SparkPlan, probably should move this to a more general place. - private def newProjection( - codegenEnabled: Boolean, - expressions: Seq[Expression], - inputSchema: Seq[Attribute]): Projection = { - log.debug( - s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") - if (codegenEnabled) { - - try { - GenerateProjection.generate(expressions, inputSchema) - } catch { - case e: Exception => - if (sys.props.contains("spark.testing")) { - throw e - } else { - log.error("failed to generate projection, fallback to interpreted", e) - new InterpretedProjection(expressions, inputSchema) - } - } - } else { - new InterpretedProjection(expressions, inputSchema) - } - } -} - -private[sql] abstract class BaseWriterContainer( - @transient val relation: HadoopFsRelation, - @transient job: Job, - isAppend: Boolean) - extends SparkHadoopMapReduceUtil - with Logging - with Serializable { - - protected val serializableConf = new SerializableConfiguration(job.getConfiguration) - - // This UUID is used to avoid output file name collision between different appending write jobs. - // These jobs may belong to different SparkContext instances. Concrete data source implementations - // may use this UUID to generate unique file names (e.g., `part-r--.parquet`). - // The reason why this ID is used to identify a job rather than a single task output file is - // that, speculative tasks must generate the same output file name as the original task. - private val uniqueWriteJobId = UUID.randomUUID() - - // This is only used on driver side. - @transient private val jobContext: JobContext = job - - // The following fields are initialized and used on both driver and executor side. 
- @transient protected var outputCommitter: OutputCommitter = _ - @transient private var jobId: JobID = _ - @transient private var taskId: TaskID = _ - @transient private var taskAttemptId: TaskAttemptID = _ - @transient protected var taskAttemptContext: TaskAttemptContext = _ - - protected val outputPath: String = { - assert( - relation.paths.length == 1, - s"Cannot write to multiple destinations: ${relation.paths.mkString(",")}") - relation.paths.head - } - - protected val dataSchema = relation.dataSchema - - protected var outputWriterFactory: OutputWriterFactory = _ - - private var outputFormatClass: Class[_ <: OutputFormat[_, _]] = _ - - def driverSideSetup(): Unit = { - setupIDs(0, 0, 0) - setupConf() - - // This UUID is sent to executor side together with the serialized `Configuration` object within - // the `Job` instance. `OutputWriters` on the executor side should use this UUID to generate - // unique task output files. - job.getConfiguration.set("spark.sql.sources.writeJobUUID", uniqueWriteJobId.toString) - - // Order of the following two lines is important. For Hadoop 1, TaskAttemptContext constructor - // clones the Configuration object passed in. If we initialize the TaskAttemptContext first, - // configurations made in prepareJobForWrite(job) are not populated into the TaskAttemptContext. - // - // Also, the `prepareJobForWrite` call must happen before initializing output format and output - // committer, since their initialization involve the job configuration, which can be potentially - // decorated in `prepareJobForWrite`. - outputWriterFactory = relation.prepareJobForWrite(job) - taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId) - - outputFormatClass = job.getOutputFormatClass - outputCommitter = newOutputCommitter(taskAttemptContext) - outputCommitter.setupJob(jobContext) - } - - def executorSideSetup(taskContext: TaskContext): Unit = { - setupIDs(taskContext.stageId(), taskContext.partitionId(), taskContext.attemptNumber()) - setupConf() - taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId) - outputCommitter = newOutputCommitter(taskAttemptContext) - outputCommitter.setupTask(taskAttemptContext) - initWriters() - } - - protected def getWorkPath: String = { - outputCommitter match { - // FileOutputCommitter writes to a temporary location returned by `getWorkPath`. - case f: MapReduceFileOutputCommitter => f.getWorkPath.toString - case _ => outputPath - } - } - - private def newOutputCommitter(context: TaskAttemptContext): OutputCommitter = { - val defaultOutputCommitter = outputFormatClass.newInstance().getOutputCommitter(context) - - if (isAppend) { - // If we are appending data to an existing dir, we will only use the output committer - // associated with the file output format since it is not safe to use a custom - // committer for appending. For example, in S3, direct parquet output committer may - // leave partial data in the destination dir when the the appending job fails. - logInfo( - s"Using output committer class ${defaultOutputCommitter.getClass.getCanonicalName} " + - "for appending.") - defaultOutputCommitter - } else { - val committerClass = context.getConfiguration.getClass( - SQLConf.OUTPUT_COMMITTER_CLASS.key, null, classOf[OutputCommitter]) - - Option(committerClass).map { clazz => - logInfo(s"Using user defined output committer class ${clazz.getCanonicalName}") - - // Every output format based on org.apache.hadoop.mapreduce.lib.output.OutputFormat - // has an associated output committer. 
To override this output committer, - // we will first try to use the output committer set in SQLConf.OUTPUT_COMMITTER_CLASS. - // If a data source needs to override the output committer, it needs to set the - // output committer in prepareForWrite method. - if (classOf[MapReduceFileOutputCommitter].isAssignableFrom(clazz)) { - // The specified output committer is a FileOutputCommitter. - // So, we will use the FileOutputCommitter-specified constructor. - val ctor = clazz.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext]) - ctor.newInstance(new Path(outputPath), context) - } else { - // The specified output committer is just a OutputCommitter. - // So, we will use the no-argument constructor. - val ctor = clazz.getDeclaredConstructor() - ctor.newInstance() - } - }.getOrElse { - // If output committer class is not set, we will use the one associated with the - // file output format. - logInfo( - s"Using output committer class ${defaultOutputCommitter.getClass.getCanonicalName}") - defaultOutputCommitter - } - } - } - - private def setupIDs(jobId: Int, splitId: Int, attemptId: Int): Unit = { - this.jobId = SparkHadoopWriter.createJobID(new Date, jobId) - this.taskId = new TaskID(this.jobId, true, splitId) - this.taskAttemptId = new TaskAttemptID(taskId, attemptId) - } - - private def setupConf(): Unit = { - serializableConf.value.set("mapred.job.id", jobId.toString) - serializableConf.value.set("mapred.tip.id", taskAttemptId.getTaskID.toString) - serializableConf.value.set("mapred.task.id", taskAttemptId.toString) - serializableConf.value.setBoolean("mapred.task.is.map", true) - serializableConf.value.setInt("mapred.task.partition", 0) - } - - // Called on executor side when writing rows - def outputWriterForRow(row: InternalRow): OutputWriter - - protected def initWriters(): Unit - - def commitTask(): Unit = { - SparkHadoopMapRedUtil.commitTask( - outputCommitter, taskAttemptContext, jobId.getId, taskId.getId, taskAttemptId.getId) - } - - def abortTask(): Unit = { - if (outputCommitter != null) { - outputCommitter.abortTask(taskAttemptContext) - } - logError(s"Task attempt $taskAttemptId aborted.") - } - - def commitJob(): Unit = { - outputCommitter.commitJob(jobContext) - logInfo(s"Job $jobId committed.") - } - - def abortJob(): Unit = { - if (outputCommitter != null) { - outputCommitter.abortJob(jobContext, JobStatus.State.FAILED) - } - logError(s"Job $jobId aborted.") - } -} - -private[sql] class DefaultWriterContainer( - @transient relation: HadoopFsRelation, - @transient job: Job, - isAppend: Boolean) - extends BaseWriterContainer(relation, job, isAppend) { - - @transient private var writer: OutputWriter = _ - - override protected def initWriters(): Unit = { - taskAttemptContext.getConfiguration.set("spark.sql.sources.output.path", outputPath) - writer = outputWriterFactory.newInstance(getWorkPath, dataSchema, taskAttemptContext) - } - - override def outputWriterForRow(row: InternalRow): OutputWriter = writer - - override def commitTask(): Unit = { - try { - assert(writer != null, "OutputWriter instance should have been initialized") - writer.close() - super.commitTask() - } catch { case cause: Throwable => - // This exception will be handled in `InsertIntoHadoopFsRelation.insert$writeRows`, and will - // cause `abortTask()` to be invoked. 
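// A generic sketch of the constructor-selection pattern used by
// `newOutputCommitter` above: prefer a constructor taking the output location
// (standing in for the FileOutputCommitter-style (Path, TaskAttemptContext)
// constructor) and fall back to the no-argument constructor otherwise. The
// Hadoop types are replaced with a plain String so the snippet stays
// self-contained; only the reflection pattern is the point here.
object ReflectiveFactory {
  def newInstance[T](clazz: Class[_ <: T], outputPath: String): T = {
    val pathCtor = clazz.getConstructors.find { ctor =>
      ctor.getParameterTypes.toSeq == Seq(classOf[String])
    }
    pathCtor match {
      case Some(ctor) => ctor.newInstance(outputPath).asInstanceOf[T]
      case None       => clazz.getDeclaredConstructor().newInstance()
    }
  }
}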
- throw new RuntimeException("Failed to commit task", cause) - } - } - - override def abortTask(): Unit = { - try { - // It's possible that the task fails before `writer` gets initialized - if (writer != null) { - writer.close() - } - } finally { - super.abortTask() - } - } -} - -private[sql] class DynamicPartitionWriterContainer( - @transient relation: HadoopFsRelation, - @transient job: Job, - partitionColumns: Array[String], - defaultPartitionName: String, - isAppend: Boolean) - extends BaseWriterContainer(relation, job, isAppend) { - - // All output writers are created on executor side. - @transient protected var outputWriters: java.util.HashMap[String, OutputWriter] = _ - - override protected def initWriters(): Unit = { - outputWriters = new java.util.HashMap[String, OutputWriter] - } - - // The `row` argument is supposed to only contain partition column values which have been casted - // to strings. - override def outputWriterForRow(row: InternalRow): OutputWriter = { - val partitionPath = { - val partitionPathBuilder = new StringBuilder - var i = 0 - - while (i < partitionColumns.length) { - val col = partitionColumns(i) - val partitionValueString = { - val string = row.getUTF8String(i) - if (string.eq(null)) { - defaultPartitionName - } else { - PartitioningUtils.escapePathName(string.toString) - } - } - - if (i > 0) { - partitionPathBuilder.append(Path.SEPARATOR_CHAR) - } - - partitionPathBuilder.append(s"$col=$partitionValueString") - i += 1 - } - - partitionPathBuilder.toString() - } - - val writer = outputWriters.get(partitionPath) - if (writer.eq(null)) { - val path = new Path(getWorkPath, partitionPath) - taskAttemptContext.getConfiguration.set( - "spark.sql.sources.output.path", new Path(outputPath, partitionPath).toString) - val newWriter = outputWriterFactory.newInstance(path.toString, dataSchema, taskAttemptContext) - outputWriters.put(partitionPath, newWriter) - newWriter - } else { - writer - } - } - - private def clearOutputWriters(): Unit = { - if (!outputWriters.isEmpty) { - asScalaIterator(outputWriters.values().iterator()).foreach(_.close()) - outputWriters.clear() - } - } - - override def commitTask(): Unit = { - try { - clearOutputWriters() - super.commitTask() - } catch { case cause: Throwable => - throw new RuntimeException("Failed to commit task", cause) - } - } - - override def abortTask(): Unit = { - try { - clearOutputWriters() - } finally { - super.abortTask() - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index 0cdb407ad57b..e7deeff13dc4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -17,340 +17,12 @@ package org.apache.spark.sql.execution.datasources -import scala.language.{existentials, implicitConversions} -import scala.util.matching.Regex - -import org.apache.hadoop.fs.Path - -import org.apache.spark.Logging -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, TableIdentifier} import org.apache.spark.sql.execution.RunnableCommand -import org.apache.spark.sql.sources._ import 
org.apache.spark.sql.types._ -import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SQLContext, SaveMode} -import org.apache.spark.util.Utils - -/** - * A parser for foreign DDL commands. - */ -private[sql] class DDLParser( - parseQuery: String => LogicalPlan) - extends AbstractSparkSQLParser with DataTypeParser with Logging { - - def parse(input: String, exceptionOnError: Boolean): LogicalPlan = { - try { - parse(input) - } catch { - case ddlException: DDLException => throw ddlException - case _ if !exceptionOnError => parseQuery(input) - case x: Throwable => throw x - } - } - - // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` - // properties via reflection the class in runtime for constructing the SqlLexical object - protected val CREATE = Keyword("CREATE") - protected val TEMPORARY = Keyword("TEMPORARY") - protected val TABLE = Keyword("TABLE") - protected val IF = Keyword("IF") - protected val NOT = Keyword("NOT") - protected val EXISTS = Keyword("EXISTS") - protected val USING = Keyword("USING") - protected val OPTIONS = Keyword("OPTIONS") - protected val DESCRIBE = Keyword("DESCRIBE") - protected val EXTENDED = Keyword("EXTENDED") - protected val AS = Keyword("AS") - protected val COMMENT = Keyword("COMMENT") - protected val REFRESH = Keyword("REFRESH") - - protected lazy val ddl: Parser[LogicalPlan] = createTable | describeTable | refreshTable - - protected def start: Parser[LogicalPlan] = ddl - - /** - * `CREATE [TEMPORARY] TABLE avroTable [IF NOT EXISTS] - * USING org.apache.spark.sql.avro - * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` - * or - * `CREATE [TEMPORARY] TABLE avroTable(intField int, stringField string...) [IF NOT EXISTS] - * USING org.apache.spark.sql.avro - * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` - * or - * `CREATE [TEMPORARY] TABLE avroTable [IF NOT EXISTS] - * USING org.apache.spark.sql.avro - * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` - * AS SELECT ... - */ - protected lazy val createTable: Parser[LogicalPlan] = - // TODO: Support database.table. - (CREATE ~> TEMPORARY.? <~ TABLE) ~ (IF ~> NOT <~ EXISTS).? ~ ident ~ - tableCols.? ~ (USING ~> className) ~ (OPTIONS ~> options).? ~ (AS ~> restInput).? ^^ { - case temp ~ allowExisting ~ tableName ~ columns ~ provider ~ opts ~ query => - if (temp.isDefined && allowExisting.isDefined) { - throw new DDLException( - "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.") - } - - val options = opts.getOrElse(Map.empty[String, String]) - if (query.isDefined) { - if (columns.isDefined) { - throw new DDLException( - "a CREATE TABLE AS SELECT statement does not allow column definitions.") - } - // When IF NOT EXISTS clause appears in the query, the save mode will be ignore. 
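// Two statements accepted by the `createTable` rule above, shown as plain
// strings; the table names and paths are illustrative only. Running them would
// look like `sqlContext.sql(DdlExamples.createTemporary)`, assuming a SQLContext
// is in scope (not shown here).
object DdlExamples {
  // Schema is inferred by the data source; note that IF NOT EXISTS cannot be
  // combined with TEMPORARY (the parser raises a DDLException).
  val createTemporary: String =
    """CREATE TEMPORARY TABLE episodes
      |USING org.apache.spark.sql.avro
      |OPTIONS (path "src/test/resources/episodes.avro")""".stripMargin

  // CREATE TABLE ... USING ... AS SELECT: column definitions are not allowed,
  // the schema comes from the query instead.
  val createAsSelect: String =
    """CREATE TABLE filtered
      |USING parquet
      |OPTIONS (path "/tmp/filtered")
      |AS SELECT * FROM episodes WHERE title IS NOT NULL""".stripMargin
}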
- val mode = if (allowExisting.isDefined) { - SaveMode.Ignore - } else if (temp.isDefined) { - SaveMode.Overwrite - } else { - SaveMode.ErrorIfExists - } - - val queryPlan = parseQuery(query.get) - CreateTableUsingAsSelect(tableName, - provider, - temp.isDefined, - Array.empty[String], - mode, - options, - queryPlan) - } else { - val userSpecifiedSchema = columns.flatMap(fields => Some(StructType(fields))) - CreateTableUsing( - tableName, - userSpecifiedSchema, - provider, - temp.isDefined, - options, - allowExisting.isDefined, - managedIfNoPath = false) - } - } - - protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")" - - /* - * describe [extended] table avroTable - * This will display all columns of table `avroTable` includes column_name,column_type,comment - */ - protected lazy val describeTable: Parser[LogicalPlan] = - (DESCRIBE ~> opt(EXTENDED)) ~ (ident <~ ".").? ~ ident ^^ { - case e ~ db ~ tbl => - val tblIdentifier = db match { - case Some(dbName) => - Seq(dbName, tbl) - case None => - Seq(tbl) - } - DescribeCommand(UnresolvedRelation(tblIdentifier, None), e.isDefined) - } - - protected lazy val refreshTable: Parser[LogicalPlan] = - REFRESH ~> TABLE ~> (ident <~ ".").? ~ ident ^^ { - case maybeDatabaseName ~ tableName => - RefreshTable(TableIdentifier(tableName, maybeDatabaseName)) - } - - protected lazy val options: Parser[Map[String, String]] = - "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap } - - protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")} - - override implicit def regexToParser(regex: Regex): Parser[String] = acceptMatch( - s"identifier matching regex $regex", { - case lexical.Identifier(str) if regex.unapplySeq(str).isDefined => str - case lexical.Keyword(str) if regex.unapplySeq(str).isDefined => str - } - ) - - protected lazy val optionPart: Parser[String] = "[_a-zA-Z][_a-zA-Z0-9]*".r ^^ { - case name => name - } - - protected lazy val optionName: Parser[String] = repsep(optionPart, ".") ^^ { - case parts => parts.mkString(".") - } - - protected lazy val pair: Parser[(String, String)] = - optionName ~ stringLit ^^ { case k ~ v => (k, v) } - - protected lazy val column: Parser[StructField] = - ident ~ dataType ~ (COMMENT ~> stringLit).? ^^ { case columnName ~ typ ~ cm => - val meta = cm match { - case Some(comment) => - new MetadataBuilder().putString(COMMENT.str.toLowerCase, comment).build() - case None => Metadata.empty - } - - StructField(columnName, typ, nullable = true, meta) - } -} - -private[sql] object ResolvedDataSource { - - private val builtinSources = Map( - "jdbc" -> "org.apache.spark.sql.jdbc.DefaultSource", - "json" -> "org.apache.spark.sql.json.DefaultSource", - "parquet" -> "org.apache.spark.sql.parquet.DefaultSource", - "orc" -> "org.apache.spark.sql.hive.orc.DefaultSource" - ) - - /** Given a provider name, look up the data source class definition. 
*/ - def lookupDataSource(provider: String): Class[_] = { - val loader = Utils.getContextOrSparkClassLoader - - if (builtinSources.contains(provider)) { - return loader.loadClass(builtinSources(provider)) - } - - try { - loader.loadClass(provider) - } catch { - case cnf: java.lang.ClassNotFoundException => - try { - loader.loadClass(provider + ".DefaultSource") - } catch { - case cnf: java.lang.ClassNotFoundException => - if (provider.startsWith("org.apache.spark.sql.hive.orc")) { - sys.error("The ORC data source must be used with Hive support enabled.") - } else { - sys.error(s"Failed to load class for data source: $provider") - } - } - } - } - - /** Create a [[ResolvedDataSource]] for reading data in. */ - def apply( - sqlContext: SQLContext, - userSpecifiedSchema: Option[StructType], - partitionColumns: Array[String], - provider: String, - options: Map[String, String]): ResolvedDataSource = { - val clazz: Class[_] = lookupDataSource(provider) - def className: String = clazz.getCanonicalName - val relation = userSpecifiedSchema match { - case Some(schema: StructType) => clazz.newInstance() match { - case dataSource: SchemaRelationProvider => - dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options), schema) - case dataSource: HadoopFsRelationProvider => - val maybePartitionsSchema = if (partitionColumns.isEmpty) { - None - } else { - Some(partitionColumnsSchema(schema, partitionColumns)) - } - - val caseInsensitiveOptions = new CaseInsensitiveMap(options) - val paths = { - val patternPath = new Path(caseInsensitiveOptions("path")) - val fs = patternPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - val qualifiedPattern = patternPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - SparkHadoopUtil.get.globPathIfNecessary(qualifiedPattern).map(_.toString).toArray - } - - val dataSchema = - StructType(schema.filterNot(f => partitionColumns.contains(f.name))).asNullable - - dataSource.createRelation( - sqlContext, - paths, - Some(dataSchema), - maybePartitionsSchema, - caseInsensitiveOptions) - case dataSource: org.apache.spark.sql.sources.RelationProvider => - throw new AnalysisException(s"$className does not allow user-specified schemas.") - case _ => - throw new AnalysisException(s"$className is not a RelationProvider.") - } - - case None => clazz.newInstance() match { - case dataSource: RelationProvider => - dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options)) - case dataSource: HadoopFsRelationProvider => - val caseInsensitiveOptions = new CaseInsensitiveMap(options) - val paths = { - val patternPath = new Path(caseInsensitiveOptions("path")) - val fs = patternPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - val qualifiedPattern = patternPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - SparkHadoopUtil.get.globPathIfNecessary(qualifiedPattern).map(_.toString).toArray - } - dataSource.createRelation(sqlContext, paths, None, None, caseInsensitiveOptions) - case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider => - throw new AnalysisException( - s"A schema needs to be specified when using $className.") - case _ => - throw new AnalysisException( - s"$className is neither a RelationProvider nor a FSBasedRelationProvider.") - } - } - new ResolvedDataSource(clazz, relation) - } - - private def partitionColumnsSchema( - schema: StructType, - partitionColumns: Array[String]): StructType = { - StructType(partitionColumns.map { col => - schema.find(_.name == col).getOrElse { - throw new 
RuntimeException(s"Partition column $col not found in schema $schema") - } - }).asNullable - } - - /** Create a [[ResolvedDataSource]] for saving the content of the given [[DataFrame]]. */ - def apply( - sqlContext: SQLContext, - provider: String, - partitionColumns: Array[String], - mode: SaveMode, - options: Map[String, String], - data: DataFrame): ResolvedDataSource = { - if (data.schema.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) { - throw new AnalysisException("Cannot save interval data type into external storage.") - } - val clazz: Class[_] = lookupDataSource(provider) - val relation = clazz.newInstance() match { - case dataSource: CreatableRelationProvider => - dataSource.createRelation(sqlContext, mode, options, data) - case dataSource: HadoopFsRelationProvider => - // Don't glob path for the write path. The contracts here are: - // 1. Only one output path can be specified on the write path; - // 2. Output path must be a legal HDFS style file system path; - // 3. It's OK that the output path doesn't exist yet; - val caseInsensitiveOptions = new CaseInsensitiveMap(options) - val outputPath = { - val path = new Path(caseInsensitiveOptions("path")) - val fs = path.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - path.makeQualified(fs.getUri, fs.getWorkingDirectory) - } - val dataSchema = StructType(data.schema.filterNot(f => partitionColumns.contains(f.name))) - val r = dataSource.createRelation( - sqlContext, - Array(outputPath.toString), - Some(dataSchema.asNullable), - Some(partitionColumnsSchema(data.schema, partitionColumns)), - caseInsensitiveOptions) - - // For partitioned relation r, r.schema's column ordering can be different from the column - // ordering of data.logicalPlan (partition columns are all moved after data column). This - // will be adjusted within InsertIntoHadoopFsRelation. - sqlContext.executePlan( - InsertIntoHadoopFsRelation( - r, - data.logicalPlan, - mode)).toRdd - r - case _ => - sys.error(s"${clazz.getCanonicalName} does not allow create table as select.") - } - new ResolvedDataSource(clazz, relation) - } -} - -private[sql] case class ResolvedDataSource(provider: Class[_], relation: BaseRelation) +import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode} /** * Returned for the "DESCRIBE [EXTENDED] [dbName.]tableName" command. @@ -358,11 +30,12 @@ private[sql] case class ResolvedDataSource(provider: Class[_], relation: BaseRel * @param isExtended True if "DESCRIBE EXTENDED" is used. Otherwise, false. * It is effective only when the table is a Hive table. */ -private[sql] case class DescribeCommand( +case class DescribeCommand( table: LogicalPlan, isExtended: Boolean) extends LogicalPlan with Command { override def children: Seq[LogicalPlan] = Seq.empty + override val output: Seq[Attribute] = Seq( // Column names are based on Hive. AttributeReference("col_name", StringType, nullable = false, @@ -370,7 +43,8 @@ private[sql] case class DescribeCommand( AttributeReference("data_type", StringType, nullable = false, new MetadataBuilder().putString("comment", "data type of the column").build())(), AttributeReference("comment", StringType, nullable = false, - new MetadataBuilder().putString("comment", "comment of the column").build())()) + new MetadataBuilder().putString("comment", "comment of the column").build())() + ) } /** @@ -378,8 +52,8 @@ private[sql] case class DescribeCommand( * @param allowExisting If it is true, we will do nothing when the table already exists. 
* If it is false, an exception will be thrown */ -private[sql] case class CreateTableUsing( - tableName: String, +case class CreateTableUsing( + tableIdent: TableIdentifier, userSpecifiedSchema: Option[StructType], provider: String, temporary: Boolean, @@ -397,8 +71,8 @@ private[sql] case class CreateTableUsing( * can analyze the logical plan that will be used to populate the table. * So, [[PreWriteCheck]] can detect cases that are not allowed. */ -private[sql] case class CreateTableUsingAsSelect( - tableName: String, +case class CreateTableUsingAsSelect( + tableIdent: TableIdentifier, provider: String, temporary: Boolean, partitionColumns: Array[String], @@ -406,12 +80,10 @@ private[sql] case class CreateTableUsingAsSelect( options: Map[String, String], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = Seq.empty[Attribute] - // TODO: Override resolved after we support databaseName. - // override lazy val resolved = databaseName != None && childrenResolved } -private[sql] case class CreateTempTableUsing( - tableName: String, +case class CreateTempTableUsing( + tableIdent: TableIdentifier, userSpecifiedSchema: Option[StructType], provider: String, options: Map[String, String]) extends RunnableCommand { @@ -419,14 +91,16 @@ private[sql] case class CreateTempTableUsing( def run(sqlContext: SQLContext): Seq[Row] = { val resolved = ResolvedDataSource( sqlContext, userSpecifiedSchema, Array.empty[String], provider, options) - sqlContext.registerDataFrameAsTable( - DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName) + sqlContext.catalog.registerTable( + tableIdent, + DataFrame(sqlContext, LogicalRelation(resolved.relation)).logicalPlan) + Seq.empty[Row] } } -private[sql] case class CreateTempTableUsingAsSelect( - tableName: String, +case class CreateTempTableUsingAsSelect( + tableIdent: TableIdentifier, provider: String, partitionColumns: Array[String], mode: SaveMode, @@ -436,14 +110,15 @@ private[sql] case class CreateTempTableUsingAsSelect( override def run(sqlContext: SQLContext): Seq[Row] = { val df = DataFrame(sqlContext, query) val resolved = ResolvedDataSource(sqlContext, provider, partitionColumns, mode, options, df) - sqlContext.registerDataFrameAsTable( - DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName) + sqlContext.catalog.registerTable( + tableIdent, + DataFrame(sqlContext, LogicalRelation(resolved.relation)).logicalPlan) Seq.empty[Row] } } -private[sql] case class RefreshTable(tableIdent: TableIdentifier) +case class RefreshTable(tableIdent: TableIdentifier) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { @@ -452,7 +127,7 @@ private[sql] case class RefreshTable(tableIdent: TableIdentifier) // If this table is cached as a InMemoryColumnarRelation, drop the original // cached version and make the new version cached lazily. - val logicalPlan = sqlContext.catalog.lookupRelation(tableIdent.toSeq) + val logicalPlan = sqlContext.catalog.lookupRelation(tableIdent) // Use lookupCachedData directly since RefreshTable also takes databaseName. 
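// A condensed sketch of the provider-resolution order implemented by
// `ResolvedDataSource.lookupDataSource` above: a builtin short name is mapped to
// its class first, then the provider name itself is tried, then
// "<provider>.DefaultSource". The builtin map here is abbreviated and the error
// handling simplified; see the removed code above for the exact behaviour.
object ProviderResolution {
  private val builtin = Map(
    "jdbc"    -> "org.apache.spark.sql.jdbc.DefaultSource",
    "json"    -> "org.apache.spark.sql.json.DefaultSource",
    "parquet" -> "org.apache.spark.sql.parquet.DefaultSource")

  def lookup(provider: String, loader: ClassLoader): Class[_] = {
    val candidates = builtin.get(provider).toSeq ++ Seq(provider, provider + ".DefaultSource")
    candidates.iterator
      .map { name =>
        try Some(loader.loadClass(name))
        catch { case _: ClassNotFoundException => None }
      }
      .collectFirst { case Some(cls) => cls }
      .getOrElse(sys.error(s"Failed to load class for data source: $provider"))
  }
}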
val isCached = sqlContext.cacheManager.lookupCachedData(logicalPlan).nonEmpty if (isCached) { @@ -472,7 +147,7 @@ private[sql] case class RefreshTable(tableIdent: TableIdentifier) /** * Builds a map in which keys are case insensitive */ -protected[sql] class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String] +class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String] with Serializable { val baseMap = map.map(kv => kv.copy(_1 = kv._1.toLowerCase)) @@ -490,4 +165,4 @@ protected[sql] class CaseInsensitiveMap(map: Map[String, String]) extends Map[St /** * The exception thrown from the DDL parser. */ -protected[sql] class DDLException(message: String) extends Exception(message) +class DDLException(message: String) extends RuntimeException(message) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala new file mode 100644 index 000000000000..f522303be94a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.jdbc + +import java.util.Properties + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.sources.{BaseRelation, RelationProvider, DataSourceRegister} + +class DefaultSource extends RelationProvider with DataSourceRegister { + + override def shortName(): String = "jdbc" + + /** Returns a new base relation with the given parameters. 
*/ + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + val url = parameters.getOrElse("url", sys.error("Option 'url' not specified")) + val driver = parameters.getOrElse("driver", null) + val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified")) + val partitionColumn = parameters.getOrElse("partitionColumn", null) + val lowerBound = parameters.getOrElse("lowerBound", null) + val upperBound = parameters.getOrElse("upperBound", null) + val numPartitions = parameters.getOrElse("numPartitions", null) + + if (driver != null) DriverRegistry.register(driver) + + if (partitionColumn != null + && (lowerBound == null || upperBound == null || numPartitions == null)) { + sys.error("Partitioning incompletely specified") + } + + val partitionInfo = if (partitionColumn == null) { + null + } else { + JDBCPartitioningInfo( + partitionColumn, + lowerBound.toLong, + upperBound.toLong, + numPartitions.toInt) + } + val parts = JDBCRelation.columnPartition(partitionInfo) + val properties = new Properties() // Additional properties that we will pass to getConnection + parameters.foreach(kv => properties.setProperty(kv._1, kv._2)) + JDBCRelation(url, table, parts, properties)(sqlContext) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala new file mode 100644 index 000000000000..7ccd61ed469e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.jdbc + +import java.sql.{Driver, DriverManager} + +import scala.collection.mutable + +import org.apache.spark.Logging +import org.apache.spark.util.Utils + +/** + * java.sql.DriverManager is always loaded by bootstrap classloader, + * so it can't load JDBC drivers accessible by Spark ClassLoader. + * + * To solve the problem, drivers from user-supplied jars are wrapped into thin wrapper. 
+ */ +object DriverRegistry extends Logging { + + private val wrapperMap: mutable.Map[String, DriverWrapper] = mutable.Map.empty + + def register(className: String): Unit = { + val cls = Utils.getContextOrSparkClassLoader.loadClass(className) + if (cls.getClassLoader == null) { + logTrace(s"$className has been loaded with bootstrap ClassLoader, wrapper is not required") + } else if (wrapperMap.get(className).isDefined) { + logTrace(s"Wrapper for $className already exists") + } else { + synchronized { + if (wrapperMap.get(className).isEmpty) { + val wrapper = new DriverWrapper(cls.newInstance().asInstanceOf[Driver]) + DriverManager.registerDriver(wrapper) + wrapperMap(className) = wrapper + logTrace(s"Wrapper for $className registered") + } + } + } + } + + def getDriverClassName(url: String): String = DriverManager.getDriver(url) match { + case wrapper: DriverWrapper => wrapper.wrapped.getClass.getCanonicalName + case driver => driver.getClass.getCanonicalName + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverWrapper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverWrapper.scala new file mode 100644 index 000000000000..18263fe227d0 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverWrapper.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.jdbc + +import java.sql.{Connection, Driver, DriverPropertyInfo, SQLFeatureNotSupportedException} +import java.util.Properties + +/** + * A wrapper for a JDBC Driver to work around SPARK-6913. + * + * The problem is in `java.sql.DriverManager` class that can't access drivers loaded by + * Spark ClassLoader. 
+ */ +class DriverWrapper(val wrapped: Driver) extends Driver { + override def acceptsURL(url: String): Boolean = wrapped.acceptsURL(url) + + override def jdbcCompliant(): Boolean = wrapped.jdbcCompliant() + + override def getPropertyInfo(url: String, info: Properties): Array[DriverPropertyInfo] = { + wrapped.getPropertyInfo(url, info) + } + + override def getMinorVersion: Int = wrapped.getMinorVersion + + def getParentLogger: java.util.logging.Logger = { + throw new SQLFeatureNotSupportedException( + s"${this.getClass.getName}.getParentLogger is not yet implemented.") + } + + override def connect(url: String, info: Properties): Connection = wrapped.connect(url, info) + + override def getMajorVersion: Int = wrapped.getMajorVersion +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala similarity index 81% rename from sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 3cf70db6b7b0..2d38562e0901 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -15,17 +15,20 @@ * limitations under the License. */ -package org.apache.spark.sql.jdbc +package org.apache.spark.sql.execution.datasources.jdbc -import java.sql.{Connection, DriverManager, ResultSet, ResultSetMetaData, SQLException} +import java.sql.{Connection, Date, DriverManager, ResultSet, ResultSetMetaData, SQLException, Timestamp} import java.util.Properties +import scala.util.control.NonFatal + import org.apache.commons.lang3.StringUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{GenericArrayData, DateTimeUtils} +import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -117,7 +120,7 @@ private[sql] object JDBCRDD extends Logging { */ def resolveTable(url: String, table: String, properties: Properties): StructType = { val dialect = JdbcDialects.get(url) - val conn: Connection = DriverManager.getConnection(url, properties) + val conn: Connection = getConnector(properties.getProperty("driver"), url, properties)() try { val rs = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0").executeQuery() try { @@ -170,7 +173,8 @@ private[sql] object JDBCRDD extends Logging { * getConnector is run on the driver code, while the function it returns * is run on the executor. * - * @param driver - The class name of the JDBC driver for the given url. + * @param driver - The class name of the JDBC driver for the given url, or null if the class name + * is not necessary. * @param url - The JDBC url to connect to. * * @return A function that loads the driver and connects to the url. 
@@ -180,9 +184,8 @@ private[sql] object JDBCRDD extends Logging { try { if (driver != null) DriverRegistry.register(driver) } catch { - case e: ClassNotFoundException => { - logWarning(s"Couldn't find class $driver", e); - } + case e: ClassNotFoundException => + logWarning(s"Couldn't find class $driver", e) } DriverManager.getConnection(url, properties) } @@ -223,6 +226,7 @@ private[sql] object JDBCRDD extends Logging { quotedColumns, filters, parts, + url, properties) } } @@ -240,6 +244,7 @@ private[sql] class JDBCRDD( columns: Array[String], filters: Array[Filter], partitions: Array[Partition], + url: String, properties: Properties) extends RDD[InternalRow](sc, Nil) { @@ -261,7 +266,9 @@ private[sql] class JDBCRDD( * Converts value to SQL expression. */ private def compileValue(value: Any): Any = value match { - case stringValue: UTF8String => s"'${escapeSql(stringValue.toString)}'" + case stringValue: String => s"'${escapeSql(stringValue)}'" + case timestampValue: Timestamp => "'" + timestampValue + "'" + case dateValue: Date => "'" + dateValue + "'" case _ => value } @@ -274,10 +281,13 @@ private[sql] class JDBCRDD( */ private def compileFilter(f: Filter): String = f match { case EqualTo(attr, value) => s"$attr = ${compileValue(value)}" + case Not(EqualTo(attr, value)) => s"$attr != ${compileValue(value)}" case LessThan(attr, value) => s"$attr < ${compileValue(value)}" case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}" case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}" case GreaterThanOrEqual(attr, value) => s"$attr >= ${compileValue(value)}" + case IsNull(attr) => s"$attr IS NULL" + case IsNotNull(attr) => s"$attr IS NOT NULL" case _ => null } @@ -323,30 +333,32 @@ private[sql] class JDBCRDD( case object StringConversion extends JDBCConversion case object TimestampConversion extends JDBCConversion case object BinaryConversion extends JDBCConversion + case class ArrayConversion(elementConversion: JDBCConversion) extends JDBCConversion /** * Maps a StructType to a type tag list. 
*/ - def getConversions(schema: StructType): Array[JDBCConversion] = { - schema.fields.map(sf => sf.dataType match { - case BooleanType => BooleanConversion - case DateType => DateConversion - case DecimalType.Fixed(p, s) => DecimalConversion(p, s) - case DoubleType => DoubleConversion - case FloatType => FloatConversion - case IntegerType => IntegerConversion - case LongType => - if (sf.metadata.contains("binarylong")) BinaryLongConversion else LongConversion - case StringType => StringConversion - case TimestampType => TimestampConversion - case BinaryType => BinaryConversion - case _ => throw new IllegalArgumentException(s"Unsupported field $sf") - }).toArray + def getConversions(schema: StructType): Array[JDBCConversion] = + schema.fields.map(sf => getConversions(sf.dataType, sf.metadata)) + + private def getConversions(dt: DataType, metadata: Metadata): JDBCConversion = dt match { + case BooleanType => BooleanConversion + case DateType => DateConversion + case DecimalType.Fixed(p, s) => DecimalConversion(p, s) + case DoubleType => DoubleConversion + case FloatType => FloatConversion + case IntegerType => IntegerConversion + case LongType => if (metadata.contains("binarylong")) BinaryLongConversion else LongConversion + case StringType => StringConversion + case TimestampType => TimestampConversion + case BinaryType => BinaryConversion + case ArrayType(et, _) => ArrayConversion(getConversions(et, metadata)) + case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}") } - /** * Runs the SQL query against the JDBC driver. + * */ override def compute(thePart: Partition, context: TaskContext): Iterator[InternalRow] = new Iterator[InternalRow] { @@ -358,6 +370,9 @@ private[sql] class JDBCRDD( context.addTaskCompletionListener{ context => close() } val part = thePart.asInstanceOf[JDBCPartition] val conn = getConnection() + val dialect = JdbcDialects.get(url) + import scala.collection.JavaConverters._ + dialect.beforeFetch(conn, properties.asScala.toMap) // H2's JDBC driver does not support the setSchema() method. We pass a // fully-qualified table name in the SELECT statement. 
I don't know how to @@ -368,7 +383,7 @@ private[sql] class JDBCRDD( val sqlText = s"SELECT $columnList FROM $fqTable $myWhereClause" val stmt = conn.prepareStatement(sqlText, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) - val fetchSize = properties.getProperty("fetchSize", "0").toInt + val fetchSize = properties.getProperty("fetchsize", "0").toInt stmt.setFetchSize(fetchSize) val rs = stmt.executeQuery() @@ -419,16 +434,44 @@ private[sql] class JDBCRDD( mutableRow.update(i, null) } case BinaryConversion => mutableRow.update(i, rs.getBytes(pos)) - case BinaryLongConversion => { + case BinaryLongConversion => val bytes = rs.getBytes(pos) var ans = 0L var j = 0 while (j < bytes.size) { ans = 256 * ans + (255 & bytes(j)) - j = j + 1; + j = j + 1 } mutableRow.setLong(i, ans) - } + case ArrayConversion(elementConversion) => + val array = rs.getArray(pos).getArray + if (array != null) { + val data = elementConversion match { + case TimestampConversion => + array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp => + nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp) + } + case StringConversion => + array.asInstanceOf[Array[java.lang.String]] + .map(UTF8String.fromString) + case DateConversion => + array.asInstanceOf[Array[java.sql.Date]].map { date => + nullSafeConvert(date, DateTimeUtils.fromJavaDate) + } + case DecimalConversion(p, s) => + array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal => + nullSafeConvert[java.math.BigDecimal](decimal, d => Decimal(d, p, s)) + } + case BinaryLongConversion => + throw new IllegalArgumentException(s"Unsupported array element conversion $i") + case _: ArrayConversion => + throw new IllegalArgumentException("Nested arrays unsupported") + case _ => array.asInstanceOf[Array[Any]] + } + mutableRow.update(i, new GenericArrayData(data)) + } else { + mutableRow.update(i, null) + } } if (rs.wasNull) mutableRow.setNullAt(i) i = i + 1 @@ -458,12 +501,20 @@ private[sql] class JDBCRDD( } try { if (null != conn) { + if (!conn.isClosed && !conn.getAutoCommit) { + try { + conn.commit() + } catch { + case NonFatal(e) => logWarning("Exception committing transaction", e) + } + } conn.close() } logInfo("closed connection") } catch { case e: Exception => logWarning("Exception closing connection", e) } + closed = true } override def hasNext: Boolean = { @@ -487,4 +538,12 @@ private[sql] class JDBCRDD( nextValue } } + + private def nullSafeConvert[T](input: T, f: T => Any): Any = { + if (input == null) { + null + } else { + f(input) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala similarity index 72% rename from sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index 41d0ecb4bbfb..f9300dc2cb52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.jdbc +package org.apache.spark.sql.execution.datasources.jdbc import java.util.Properties @@ -77,42 +77,6 @@ private[sql] object JDBCRelation { } } -private[sql] class DefaultSource extends RelationProvider { - /** Returns a new base relation with the given parameters. 
*/ - override def createRelation( - sqlContext: SQLContext, - parameters: Map[String, String]): BaseRelation = { - val url = parameters.getOrElse("url", sys.error("Option 'url' not specified")) - val driver = parameters.getOrElse("driver", null) - val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified")) - val partitionColumn = parameters.getOrElse("partitionColumn", null) - val lowerBound = parameters.getOrElse("lowerBound", null) - val upperBound = parameters.getOrElse("upperBound", null) - val numPartitions = parameters.getOrElse("numPartitions", null) - - if (driver != null) DriverRegistry.register(driver) - - if (partitionColumn != null - && (lowerBound == null || upperBound == null || numPartitions == null)) { - sys.error("Partitioning incompletely specified") - } - - val partitionInfo = if (partitionColumn == null) { - null - } else { - JDBCPartitioningInfo( - partitionColumn, - lowerBound.toLong, - upperBound.toLong, - numPartitions.toInt) - } - val parts = JDBCRelation.columnPartition(partitionInfo) - val properties = new Properties() // Additional properties that we will pass to getConnection - parameters.foreach(kv => properties.setProperty(kv._1, kv._2)) - JDBCRelation(url, table, parts, properties)(sqlContext) - } -} - private[sql] case class JDBCRelation( url: String, table: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala new file mode 100644 index 000000000000..252f1cfd5d9c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -0,0 +1,252 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.jdbc + +import java.sql.{Connection, PreparedStatement} +import java.util.Properties + +import scala.util.Try +import scala.util.control.NonFatal + +import org.apache.spark.Logging +import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcType, JdbcDialects} +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{DataFrame, Row} + +/** + * Util functions for JDBC tables. + */ +object JdbcUtils extends Logging { + + /** + * Establishes a JDBC connection. + */ + def createConnection(url: String, connectionProperties: Properties): Connection = { + JDBCRDD.getConnector(connectionProperties.getProperty("driver"), url, connectionProperties)() + } + + /** + * Returns true if the table already exists in the JDBC database. 
+ */ + def tableExists(conn: Connection, url: String, table: String): Boolean = { + val dialect = JdbcDialects.get(url) + + // Somewhat hacky, but there isn't a good way to identify whether a table exists for all + // SQL database systems using JDBC meta data calls, considering "table" could also include + // the database name. Query used to find table exists can be overriden by the dialects. + Try(conn.prepareStatement(dialect.getTableExistsQuery(table)).executeQuery()).isSuccess + } + + /** + * Drops a table from the JDBC database. + */ + def dropTable(conn: Connection, table: String): Unit = { + conn.createStatement.executeUpdate(s"DROP TABLE $table") + } + + /** + * Returns a PreparedStatement that inserts a row into table via conn. + */ + def insertStatement(conn: Connection, table: String, rddSchema: StructType): PreparedStatement = { + val sql = new StringBuilder(s"INSERT INTO $table VALUES (") + var fieldsLeft = rddSchema.fields.length + while (fieldsLeft > 0) { + sql.append("?") + if (fieldsLeft > 1) sql.append(", ") else sql.append(")") + fieldsLeft = fieldsLeft - 1 + } + conn.prepareStatement(sql.toString()) + } + + /** + * Retrieve standard jdbc types. + * @param dt The datatype (e.g. [[org.apache.spark.sql.types.StringType]]) + * @return The default JdbcType for this DataType + */ + def getCommonJDBCType(dt: DataType): Option[JdbcType] = { + dt match { + case IntegerType => Option(JdbcType("INTEGER", java.sql.Types.INTEGER)) + case LongType => Option(JdbcType("BIGINT", java.sql.Types.BIGINT)) + case DoubleType => Option(JdbcType("DOUBLE PRECISION", java.sql.Types.DOUBLE)) + case FloatType => Option(JdbcType("REAL", java.sql.Types.FLOAT)) + case ShortType => Option(JdbcType("INTEGER", java.sql.Types.SMALLINT)) + case ByteType => Option(JdbcType("BYTE", java.sql.Types.TINYINT)) + case BooleanType => Option(JdbcType("BIT(1)", java.sql.Types.BIT)) + case StringType => Option(JdbcType("TEXT", java.sql.Types.CLOB)) + case BinaryType => Option(JdbcType("BLOB", java.sql.Types.BLOB)) + case TimestampType => Option(JdbcType("TIMESTAMP", java.sql.Types.TIMESTAMP)) + case DateType => Option(JdbcType("DATE", java.sql.Types.DATE)) + case t: DecimalType => Option( + JdbcType(s"DECIMAL(${t.precision},${t.scale})", java.sql.Types.DECIMAL)) + case _ => None + } + } + + private def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = { + dialect.getJDBCType(dt).orElse(getCommonJDBCType(dt)).getOrElse( + throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.simpleString}")) + } + + /** + * Saves a partition of a DataFrame to the JDBC database. This is done in + * a single database transaction in order to avoid repeatedly inserting + * data as much as possible. + * + * It is still theoretically possible for rows in a DataFrame to be + * inserted into the database more than once if a stage somehow fails after + * the commit occurs but before the stage can return successfully. + * + * This is not a closure inside saveTable() because apparently cosmetic + * implementation changes elsewhere might easily render such a closure + * non-Serializable. Instead, we explicitly close over all variables that + * are used. 
+ */ + def savePartition( + getConnection: () => Connection, + table: String, + iterator: Iterator[Row], + rddSchema: StructType, + nullTypes: Array[Int], + batchSize: Int, + dialect: JdbcDialect): Iterator[Byte] = { + val conn = getConnection() + var committed = false + val supportsTransactions = try { + conn.getMetaData().supportsDataManipulationTransactionsOnly() || + conn.getMetaData().supportsDataDefinitionAndDataManipulationTransactions() + } catch { + case NonFatal(e) => + logWarning("Exception while detecting transaction support", e) + true + } + + try { + if (supportsTransactions) { + conn.setAutoCommit(false) // Everything in the same db transaction. + } + val stmt = insertStatement(conn, table, rddSchema) + try { + var rowCount = 0 + while (iterator.hasNext) { + val row = iterator.next() + val numFields = rddSchema.fields.length + var i = 0 + while (i < numFields) { + if (row.isNullAt(i)) { + stmt.setNull(i + 1, nullTypes(i)) + } else { + rddSchema.fields(i).dataType match { + case IntegerType => stmt.setInt(i + 1, row.getInt(i)) + case LongType => stmt.setLong(i + 1, row.getLong(i)) + case DoubleType => stmt.setDouble(i + 1, row.getDouble(i)) + case FloatType => stmt.setFloat(i + 1, row.getFloat(i)) + case ShortType => stmt.setInt(i + 1, row.getShort(i)) + case ByteType => stmt.setInt(i + 1, row.getByte(i)) + case BooleanType => stmt.setBoolean(i + 1, row.getBoolean(i)) + case StringType => stmt.setString(i + 1, row.getString(i)) + case BinaryType => stmt.setBytes(i + 1, row.getAs[Array[Byte]](i)) + case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i)) + case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i)) + case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i)) + case ArrayType(et, _) => + val array = conn.createArrayOf( + getJdbcType(et, dialect).databaseTypeDefinition.toLowerCase, + row.getSeq[AnyRef](i).toArray) + stmt.setArray(i + 1, array) + case _ => throw new IllegalArgumentException( + s"Can't translate non-null value for field $i") + } + } + i = i + 1 + } + stmt.addBatch() + rowCount += 1 + if (rowCount % batchSize == 0) { + stmt.executeBatch() + rowCount = 0 + } + } + if (rowCount > 0) { + stmt.executeBatch() + } + } finally { + stmt.close() + } + if (supportsTransactions) { + conn.commit() + } + committed = true + } finally { + if (!committed) { + // The stage must fail. We got here through an exception path, so + // let the exception through unless rollback() or close() want to + // tell the user about another problem. + if (supportsTransactions) { + conn.rollback() + } + conn.close() + } else { + // The stage must succeed. We cannot propagate any exception close() might throw. + try { + conn.close() + } catch { + case e: Exception => logWarning("Transaction succeeded, but closing failed", e) + } + } + } + Array[Byte]().iterator + } + + /** + * Compute the schema string for this RDD. + */ + def schemaString(df: DataFrame, url: String): String = { + val sb = new StringBuilder() + val dialect = JdbcDialects.get(url) + df.schema.fields foreach { field => { + val name = field.name + val typ: String = getJdbcType(field.dataType, dialect).databaseTypeDefinition + val nullable = if (field.nullable) "" else "NOT NULL" + sb.append(s", $name $typ $nullable") + }} + if (sb.length < 2) "" else sb.substring(2) + } + + /** + * Saves the RDD to the database in a single transaction. 
+ */ + def saveTable( + df: DataFrame, + url: String, + table: String, + properties: Properties = new Properties()) { + val dialect = JdbcDialects.get(url) + val nullTypes: Array[Int] = df.schema.fields.map { field => + getJdbcType(field.dataType, dialect).jdbcNullType + } + + val rddSchema = df.schema + val driver: String = DriverRegistry.getDriverClassName(url) + val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties) + val batchSize = properties.getProperty("batchsize", "1000").toInt + df.foreachPartition { iterator => + savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect) + } + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala similarity index 60% rename from sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala index 04ab5e221788..59ba4ae2cba0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala @@ -15,47 +15,56 @@ * limitations under the License. */ -package org.apache.spark.sql.json +package org.apache.spark.sql.execution.datasources.json import com.fasterxml.jackson.core._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion -import org.apache.spark.sql.json.JacksonUtils.nextUntil +import org.apache.spark.sql.execution.datasources.json.JacksonUtils.nextUntil import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils + + +private[json] object InferSchema { -private[sql] object InferSchema { /** * Infer the type of a collection of json records in three stages: * 1. Infer the type of each record * 2. Merge types by choosing the lowest type necessary to cover equal keys * 3. 
Replace any remaining null fields with string, the top type */ - def apply( + def infer( json: RDD[String], - samplingRatio: Double = 1.0, - columnNameOfCorruptRecords: String): StructType = { - require(samplingRatio > 0, s"samplingRatio ($samplingRatio) should be greater than 0") - val schemaData = if (samplingRatio > 0.99) { + columnNameOfCorruptRecords: String, + configOptions: JSONOptions): StructType = { + require(configOptions.samplingRatio > 0, + s"samplingRatio (${configOptions.samplingRatio}) should be greater than 0") + val schemaData = if (configOptions.samplingRatio > 0.99) { json } else { - json.sample(withReplacement = false, samplingRatio, 1) + json.sample(withReplacement = false, configOptions.samplingRatio, 1) } // perform schema inference on each row and merge afterwards val rootType = schemaData.mapPartitions { iter => val factory = new JsonFactory() + configOptions.setJacksonOptions(factory) iter.map { row => try { - val parser = factory.createParser(row) - parser.nextToken() - inferField(parser) + Utils.tryWithResource(factory.createParser(row)) { parser => + parser.nextToken() + inferField(parser, configOptions) + } } catch { case _: JsonParseException => StructType(Seq(StructField(columnNameOfCorruptRecords, StringType))) } } - }.treeAggregate[DataType](StructType(Seq()))(compatibleRootType, compatibleRootType) + }.treeAggregate[DataType]( + StructType(Seq()))( + compatibleRootType(columnNameOfCorruptRecords), + compatibleRootType(columnNameOfCorruptRecords)) canonicalizeType(rootType) match { case Some(st: StructType) => st @@ -68,14 +77,14 @@ private[sql] object InferSchema { /** * Infer the type of a json document from the parser's token stream */ - private def inferField(parser: JsonParser): DataType = { + private def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = { import com.fasterxml.jackson.core.JsonToken._ parser.getCurrentToken match { case null | VALUE_NULL => NullType case FIELD_NAME => parser.nextToken() - inferField(parser) + inferField(parser, configOptions) case VALUE_STRING if parser.getTextLength < 1 => // Zero length strings and nulls have special handling to deal @@ -90,7 +99,10 @@ private[sql] object InferSchema { case START_OBJECT => val builder = Seq.newBuilder[StructField] while (nextUntil(parser, END_OBJECT)) { - builder += StructField(parser.getCurrentName, inferField(parser), nullable = true) + builder += StructField( + parser.getCurrentName, + inferField(parser, configOptions), + nullable = true) } StructType(builder.result().sortBy(_.name)) @@ -101,11 +113,16 @@ private[sql] object InferSchema { // the type as we pass through all JSON objects. var elementType: DataType = NullType while (nextUntil(parser, END_ARRAY)) { - elementType = compatibleType(elementType, inferField(parser)) + elementType = compatibleType( + elementType, inferField(parser, configOptions)) } ArrayType(elementType) + case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT) if configOptions.primitivesAsString => StringType + + case (VALUE_TRUE | VALUE_FALSE) if configOptions.primitivesAsString => StringType + case VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT => import JsonParser.NumberType._ parser.getNumberType match { @@ -113,8 +130,12 @@ private[sql] object InferSchema { case INT | LONG => LongType // Since we do not have a data type backed by BigInteger, // when we see a Java BigInteger, we use DecimalType. 
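Editorial sketch: the infer method above maps every record to a DataType and then folds the per-record results together (via treeAggregate with compatibleRootType/compatibleType). As a rough illustration of that fold only, not the shipped code, here is a minimal stand-in; the object and method names below are made up, and the real compatibleType also handles structs, arrays and decimals.

  import org.apache.spark.sql.types._

  object TypeMergeSketch {
    // Simplified stand-in for InferSchema.compatibleType: widen two per-record types.
    def widen(t1: DataType, t2: DataType): DataType = (t1, t2) match {
      case (NullType, t) => t
      case (t, NullType) => t
      case (a, b) if a == b => a
      case (LongType, DoubleType) | (DoubleType, LongType) => DoubleType
      case _ => StringType // fall back to the "top" type, as step 3 above describes
    }

    // Per-record (and per-partition) types are folded the same way treeAggregate folds them.
    def merge(perRecordTypes: Seq[DataType]): DataType =
      perRecordTypes.foldLeft(NullType: DataType)(widen)
  }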
- case BIG_INTEGER | BIG_DECIMAL => DecimalType.SYSTEM_DEFAULT - case FLOAT | DOUBLE => DoubleType + case BIG_INTEGER | BIG_DECIMAL => + val v = parser.getDecimalValue + DecimalType(v.precision(), v.scale()) + case FLOAT | DOUBLE => + // TODO(davies): Should we use decimal if possible? + DoubleType } case VALUE_TRUE | VALUE_FALSE => BooleanType @@ -152,28 +173,63 @@ private[sql] object InferSchema { case other => Some(other) } + private def withCorruptField( + struct: StructType, + columnNameOfCorruptRecords: String): StructType = { + if (!struct.fieldNames.contains(columnNameOfCorruptRecords)) { + // If this given struct does not have a column used for corrupt records, + // add this field. + struct.add(columnNameOfCorruptRecords, StringType, nullable = true) + } else { + // Otherwise, just return this struct. + struct + } + } + /** * Remove top-level ArrayType wrappers and merge the remaining schemas */ - private def compatibleRootType: (DataType, DataType) => DataType = { - case (ArrayType(ty1, _), ty2) => compatibleRootType(ty1, ty2) - case (ty1, ArrayType(ty2, _)) => compatibleRootType(ty1, ty2) + private def compatibleRootType( + columnNameOfCorruptRecords: String): (DataType, DataType) => DataType = { + // Since we support array of json objects at the top level, + // we need to check the element type and find the root level data type. + case (ArrayType(ty1, _), ty2) => compatibleRootType(columnNameOfCorruptRecords)(ty1, ty2) + case (ty1, ArrayType(ty2, _)) => compatibleRootType(columnNameOfCorruptRecords)(ty1, ty2) + // If we see any other data type at the root level, we get records that cannot be + // parsed. So, we use the struct as the data type and add the corrupt field to the schema. + case (struct: StructType, NullType) => struct + case (NullType, struct: StructType) => struct + case (struct: StructType, o) if !o.isInstanceOf[StructType] => + withCorruptField(struct, columnNameOfCorruptRecords) + case (o, struct: StructType) if !o.isInstanceOf[StructType] => + withCorruptField(struct, columnNameOfCorruptRecords) + // If we get anything else, we call compatibleType. + // Usually, when we reach here, ty1 and ty2 are two StructTypes. case (ty1, ty2) => compatibleType(ty1, ty2) } /** * Returns the most general data type for two given data types. */ - private[json] def compatibleType(t1: DataType, t2: DataType): DataType = { + def compatibleType(t1: DataType, t2: DataType): DataType = { HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2).getOrElse { // t1 or t2 is a StructType, ArrayType, or an unexpected type. (t1, t2) match { // Double support larger range than fixed decimal, DecimalType.Maximum should be enough // in most case, also have better precision. 
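Editorial sketch: the DecimalType branch of compatibleType that follows keeps the widest scale and the widest integer-digit range, and falls back to DoubleType once the combined precision would exceed 38. A hedged restatement with a worked example (the helper name is illustrative):

  import org.apache.spark.sql.types.{DataType, DecimalType, DoubleType}

  // scale = max of the scales, integer digits = max of (precision - scale);
  // anything past 38 total digits degrades to Double, as in the case below.
  def mergeDecimalTypes(t1: DecimalType, t2: DecimalType): DataType = {
    val scale = math.max(t1.scale, t2.scale)
    val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale)
    if (range + scale > 38) DoubleType else DecimalType(range + scale, scale)
  }

  // mergeDecimalTypes(DecimalType(5, 2), DecimalType(10, 5)) == DecimalType(10, 5)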
case (DoubleType, t: DecimalType) => - if (t == DecimalType.SYSTEM_DEFAULT) t else DoubleType + DoubleType case (t: DecimalType, DoubleType) => - if (t == DecimalType.SYSTEM_DEFAULT) t else DoubleType + DoubleType + case (t1: DecimalType, t2: DecimalType) => + val scale = math.max(t1.scale, t2.scale) + val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale) + if (range + scale > 38) { + // DecimalType can't support precision > 38 + DoubleType + } else { + DecimalType(range + scale, scale) + } case (StructType(fields1), StructType(fields2)) => val newFields = (fields1 ++ fields2).groupBy(field => field.name).map { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala new file mode 100644 index 000000000000..c132ead20e7d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.json + +import com.fasterxml.jackson.core.{JsonParser, JsonFactory} + +/** + * Options for the JSON data source. + * + * Most of these map directly to Jackson's internal options, specified in [[JsonParser.Feature]]. + */ +case class JSONOptions( + samplingRatio: Double = 1.0, + primitivesAsString: Boolean = false, + allowComments: Boolean = false, + allowUnquotedFieldNames: Boolean = false, + allowSingleQuotes: Boolean = true, + allowNumericLeadingZeros: Boolean = false, + allowNonNumericNumbers: Boolean = false) { + + /** Sets config options on a Jackson [[JsonFactory]]. 
*/ + def setJacksonOptions(factory: JsonFactory): Unit = { + factory.configure(JsonParser.Feature.ALLOW_COMMENTS, allowComments) + factory.configure(JsonParser.Feature.ALLOW_UNQUOTED_FIELD_NAMES, allowUnquotedFieldNames) + factory.configure(JsonParser.Feature.ALLOW_SINGLE_QUOTES, allowSingleQuotes) + factory.configure(JsonParser.Feature.ALLOW_NUMERIC_LEADING_ZEROS, allowNumericLeadingZeros) + factory.configure(JsonParser.Feature.ALLOW_NON_NUMERIC_NUMBERS, allowNonNumericNumbers) + } +} + + +object JSONOptions { + def createFromConfigMap(parameters: Map[String, String]): JSONOptions = JSONOptions( + samplingRatio = + parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0), + primitivesAsString = + parameters.get("primitivesAsString").map(_.toBoolean).getOrElse(false), + allowComments = + parameters.get("allowComments").map(_.toBoolean).getOrElse(false), + allowUnquotedFieldNames = + parameters.get("allowUnquotedFieldNames").map(_.toBoolean).getOrElse(false), + allowSingleQuotes = + parameters.get("allowSingleQuotes").map(_.toBoolean).getOrElse(true), + allowNumericLeadingZeros = + parameters.get("allowNumericLeadingZeros").map(_.toBoolean).getOrElse(false), + allowNonNumericNumbers = + parameters.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(true) + ) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala new file mode 100644 index 000000000000..3e61ba35bea8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.json + +import java.io.CharArrayWriter + +import com.fasterxml.jackson.core.JsonFactory +import com.google.common.base.Objects +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.io.{LongWritable, NullWritable, Text} +import org.apache.hadoop.mapred.{JobConf, TextInputFormat} +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat +import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat +import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} + +import org.apache.spark.Logging +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.mapred.SparkHadoopMapRedUtil +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.execution.datasources.PartitionSpec +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{AnalysisException, Row, SQLContext} +import org.apache.spark.util.SerializableConfiguration + + +class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { + + override def shortName(): String = "json" + + override def createRelation( + sqlContext: SQLContext, + paths: Array[String], + dataSchema: Option[StructType], + partitionColumns: Option[StructType], + parameters: Map[String, String]): HadoopFsRelation = { + + new JSONRelation( + inputRDD = None, + maybeDataSchema = dataSchema, + maybePartitionSpec = None, + userDefinedPartitionColumns = partitionColumns, + paths = paths, + parameters = parameters)(sqlContext) + } +} + +private[sql] class JSONRelation( + val inputRDD: Option[RDD[String]], + val maybeDataSchema: Option[StructType], + val maybePartitionSpec: Option[PartitionSpec], + override val userDefinedPartitionColumns: Option[StructType], + override val paths: Array[String] = Array.empty[String], + parameters: Map[String, String] = Map.empty[String, String]) + (@transient val sqlContext: SQLContext) + extends HadoopFsRelation(maybePartitionSpec, parameters) { + + val options: JSONOptions = JSONOptions.createFromConfigMap(parameters) + + /** Constraints to be imposed on schema to be stored. 
*/ + private def checkConstraints(schema: StructType): Unit = { + if (schema.fieldNames.length != schema.fieldNames.distinct.length) { + val duplicateColumns = schema.fieldNames.groupBy(identity).collect { + case (x, ys) if ys.length > 1 => "\"" + x + "\"" + }.mkString(", ") + throw new AnalysisException(s"Duplicate column(s) : $duplicateColumns found, " + + s"cannot save to JSON format") + } + } + + override val needConversion: Boolean = false + + private def createBaseRdd(inputPaths: Array[FileStatus]): RDD[String] = { + val job = new Job(sqlContext.sparkContext.hadoopConfiguration) + val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + + val paths = inputPaths.map(_.getPath) + + if (paths.nonEmpty) { + FileInputFormat.setInputPaths(job, paths: _*) + } + + sqlContext.sparkContext.hadoopRDD( + conf.asInstanceOf[JobConf], + classOf[TextInputFormat], + classOf[LongWritable], + classOf[Text]).map(_._2.toString) // get the text line + } + + override lazy val dataSchema: StructType = { + val jsonSchema = maybeDataSchema.getOrElse { + val files = cachedLeafStatuses().filterNot { status => + val name = status.getPath.getName + name.startsWith("_") || name.startsWith(".") + }.toArray + InferSchema.infer( + inputRDD.getOrElse(createBaseRdd(files)), + sqlContext.conf.columnNameOfCorruptRecord, + options) + } + checkConstraints(jsonSchema) + + jsonSchema + } + + override private[sql] def buildInternalScan( + requiredColumns: Array[String], + filters: Array[Filter], + inputPaths: Array[FileStatus], + broadcastedConf: Broadcast[SerializableConfiguration]): RDD[InternalRow] = { + val requiredDataSchema = StructType(requiredColumns.map(dataSchema(_))) + val rows = JacksonParser.parse( + inputRDD.getOrElse(createBaseRdd(inputPaths)), + requiredDataSchema, + sqlContext.conf.columnNameOfCorruptRecord, + options) + + rows.mapPartitions { iterator => + val unsafeProjection = UnsafeProjection.create(requiredDataSchema) + iterator.map(unsafeProjection) + } + } + + override def equals(other: Any): Boolean = other match { + case that: JSONRelation => + ((inputRDD, that.inputRDD) match { + case (Some(thizRdd), Some(thatRdd)) => thizRdd eq thatRdd + case (None, None) => true + case _ => false + }) && paths.toSet == that.paths.toSet && + dataSchema == that.dataSchema && + schema == that.schema + case _ => false + } + + override def hashCode(): Int = { + Objects.hashCode( + inputRDD, + paths.toSet, + dataSchema, + schema, + partitionColumns) + } + + override def prepareJobForWrite(job: Job): OutputWriterFactory = { + new OutputWriterFactory { + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new JsonOutputWriter(path, dataSchema, context) + } + } + } +} + +private[json] class JsonOutputWriter( + path: String, + dataSchema: StructType, + context: TaskAttemptContext) + extends OutputWriter with SparkHadoopMapRedUtil with Logging { + + private[this] val writer = new CharArrayWriter() + // create the Generator without separator inserted between 2 records + private[this] val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null) + private[this] val result = new Text() + + private val recordWriter: RecordWriter[NullWritable, Text] = { + new TextOutputFormat[NullWritable, Text]() { + override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { + val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val uniqueWriteJobId = 
configuration.get("spark.sql.sources.writeJobUUID") + val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) + val split = taskAttemptId.getTaskID.getId + new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") + } + }.getRecordWriter(context) + } + + override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") + + override protected[sql] def writeInternal(row: InternalRow): Unit = { + JacksonGenerator(dataSchema, gen)(row) + gen.flush() + + result.set(writer.toString) + writer.reset() + + recordWriter.write(NullWritable.get(), result) + } + + override def close(): Unit = { + gen.close() + recordWriter.close(context) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala similarity index 50% rename from sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala index 1e6b1198d245..3f34520afe6b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala @@ -15,7 +15,10 @@ * limitations under the License. */ -package org.apache.spark.sql.json +package org.apache.spark.sql.execution.datasources.json + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.{MapData, ArrayData, DateTimeUtils} import scala.collection.Map @@ -25,51 +28,65 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.types._ private[sql] object JacksonGenerator { - /** Transforms a single Row to JSON using Jackson - * - * @param rowSchema the schema object used for conversion - * @param gen a JsonGenerator object - * @param row The row to convert - */ - def apply(rowSchema: StructType, gen: JsonGenerator)(row: Row): Unit = { + /** Transforms a single InternalRow to JSON using Jackson + * + * TODO: make the code shared with the other apply method. 
+ * + * @param rowSchema the schema object used for conversion + * @param gen a JsonGenerator object + * @param row The row to convert + */ + def apply(rowSchema: StructType, gen: JsonGenerator)(row: InternalRow): Unit = { def valWriter: (DataType, Any) => Unit = { case (_, null) | (NullType, _) => gen.writeNull() - case (StringType, v: String) => gen.writeString(v) - case (TimestampType, v: java.sql.Timestamp) => gen.writeString(v.toString) + case (StringType, v) => gen.writeString(v.toString) + case (TimestampType, v: Long) => gen.writeString(DateTimeUtils.toJavaTimestamp(v).toString) case (IntegerType, v: Int) => gen.writeNumber(v) case (ShortType, v: Short) => gen.writeNumber(v) case (FloatType, v: Float) => gen.writeNumber(v) case (DoubleType, v: Double) => gen.writeNumber(v) case (LongType, v: Long) => gen.writeNumber(v) - case (DecimalType(), v: java.math.BigDecimal) => gen.writeNumber(v) + case (DecimalType(), v: Decimal) => gen.writeNumber(v.toJavaBigDecimal) case (ByteType, v: Byte) => gen.writeNumber(v.toInt) case (BinaryType, v: Array[Byte]) => gen.writeBinary(v) case (BooleanType, v: Boolean) => gen.writeBoolean(v) - case (DateType, v) => gen.writeString(v.toString) - case (udt: UserDefinedType[_], v) => valWriter(udt.sqlType, udt.serialize(v)) + case (DateType, v: Int) => gen.writeString(DateTimeUtils.toJavaDate(v).toString) + // For UDT values, they should be in the SQL type's corresponding value type. + // We should not see values in the user-defined class at here. + // For example, VectorUDT's SQL type is an array of double. So, we should expect that v is + // an ArrayData at here, instead of a Vector. + case (udt: UserDefinedType[_], v) => valWriter(udt.sqlType, v) - case (ArrayType(ty, _), v: Seq[_]) => + case (ArrayType(ty, _), v: ArrayData) => gen.writeStartArray() - v.foreach(valWriter(ty, _)) + v.foreach(ty, (_, value) => valWriter(ty, value)) gen.writeEndArray() - case (MapType(kv, vv, _), v: Map[_, _]) => + case (MapType(kt, vt, _), v: MapData) => gen.writeStartObject() - v.foreach { p => - gen.writeFieldName(p._1.toString) - valWriter(vv, p._2) - } + v.foreach(kt, vt, { (k, v) => + gen.writeFieldName(k.toString) + valWriter(vt, v) + }) gen.writeEndObject() - case (StructType(ty), v: Row) => + case (StructType(ty), v: InternalRow) => gen.writeStartObject() - ty.zip(v.toSeq).foreach { - case (_, null) => - case (field, v) => + var i = 0 + while (i < ty.length) { + val field = ty(i) + val value = v.get(i, field.dataType) + if (value != null) { gen.writeFieldName(field.name) - valWriter(field.dataType, v) + valWriter(field.dataType, value) + } + i += 1 } gen.writeEndObject() + + case (dt, v) => + sys.error( + s"Failed to convert value $v (class of ${v.getClass}}) with the type of $dt to JSON.") } valWriter(rowSchema, row) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala similarity index 58% rename from sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala index bf0448ee9645..55a1c24e9e00 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala @@ -15,34 +15,41 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.json +package org.apache.spark.sql.execution.datasources.json import java.io.ByteArrayOutputStream +import scala.collection.mutable.ArrayBuffer import com.fasterxml.jackson.core._ -import scala.collection.mutable.ArrayBuffer - import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.json.JacksonUtils.nextUntil +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.datasources.json.JacksonUtils.nextUntil import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.Utils + +private[json] class SparkSQLJsonProcessingException(msg: String) extends RuntimeException(msg) -private[sql] object JacksonParser { - def apply( - json: RDD[String], +object JacksonParser { + + def parse( + input: RDD[String], schema: StructType, - columnNameOfCorruptRecords: String): RDD[InternalRow] = { - parseJson(json, schema, columnNameOfCorruptRecords) + columnNameOfCorruptRecords: String, + configOptions: JSONOptions): RDD[InternalRow] = { + + input.mapPartitions { iter => + parseJson(iter, schema, columnNameOfCorruptRecords, configOptions) + } } /** * Parse the current token (and related children) according to a desired schema */ - private[sql] def convertField( + def convertField( factory: JsonFactory, parser: JsonParser, schema: DataType): Any = { @@ -62,10 +69,23 @@ private[sql] object JacksonParser { // guard the non string type null + case (VALUE_STRING, BinaryType) => + parser.getBinaryValue + case (VALUE_STRING, DateType) => - DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(parser.getText).getTime) + val stringValue = parser.getText + if (stringValue.contains("-")) { + // The format of this string will probably be "yyyy-mm-dd". + DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(parser.getText).getTime) + } else { + // In Spark 1.5.0, we store the data as number of days since epoch in string. + // So, we just convert it to Int. + stringValue.toInt + } case (VALUE_STRING, TimestampType) => + // This one will lose microseconds parts. + // See https://issues.apache.org/jira/browse/SPARK-10681. DateTimeUtils.stringToTime(parser.getText).getTime * 1000L case (VALUE_NUMBER_INT, TimestampType) => @@ -73,20 +93,47 @@ private[sql] object JacksonParser { case (_, StringType) => val writer = new ByteArrayOutputStream() - val generator = factory.createGenerator(writer, JsonEncoding.UTF8) - generator.copyCurrentStructure(parser) - generator.close() + Utils.tryWithResource(factory.createGenerator(writer, JsonEncoding.UTF8)) { + generator => generator.copyCurrentStructure(parser) + } UTF8String.fromBytes(writer.toByteArray) case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, FloatType) => parser.getFloatValue + case (VALUE_STRING, FloatType) => + // Special case handling for NaN and Infinity. 
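Editorial sketch: the string-to-FloatType and string-to-DoubleType branches below accept a handful of spellings for NaN and infinity and reject everything else. A simplified, self-contained version of the same idea, not the shipped code; note that plain Float.parseFloat only understands "NaN" and "Infinity", so the special values are mapped explicitly here.

  // Accepted spellings mirror the branch below; values are mapped explicitly
  // instead of relying on String#toFloat for the non-standard forms.
  def parseSpecialFloat(text: String): Float = text.toLowerCase match {
    case "nan"                 => Float.NaN
    case "infinity" | "inf"    => Float.PositiveInfinity
    case "-infinity" | "-inf"  => Float.NegativeInfinity
    case other                 => other.toFloat // ordinary numeric strings
  }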
+ val value = parser.getText + val lowerCaseValue = value.toLowerCase() + if (lowerCaseValue.equals("nan") || + lowerCaseValue.equals("infinity") || + lowerCaseValue.equals("-infinity") || + lowerCaseValue.equals("inf") || + lowerCaseValue.equals("-inf")) { + value.toFloat + } else { + throw new SparkSQLJsonProcessingException(s"Cannot parse $value as FloatType.") + } + case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, DoubleType) => parser.getDoubleValue - case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, DecimalType()) => - // TODO: add fixed precision and scale handling - Decimal(parser.getDecimalValue) + case (VALUE_STRING, DoubleType) => + // Special case handling for NaN and Infinity. + val value = parser.getText + val lowerCaseValue = value.toLowerCase() + if (lowerCaseValue.equals("nan") || + lowerCaseValue.equals("infinity") || + lowerCaseValue.equals("-infinity") || + lowerCaseValue.equals("inf") || + lowerCaseValue.equals("-inf")) { + value.toDouble + } else { + throw new SparkSQLJsonProcessingException(s"Cannot parse $value as DoubleType.") + } + + case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, dt: DecimalType) => + Decimal(parser.getDecimalValue, dt.precision, dt.scale) case (VALUE_NUMBER_INT, ByteType) => parser.getByteValue @@ -126,7 +173,14 @@ private[sql] object JacksonParser { convertMap(factory, parser, kt) case (_, udt: UserDefinedType[_]) => - udt.deserialize(convertField(factory, parser, udt.sqlType)) + convertField(factory, parser, udt.sqlType) + + case (token, dataType) => + // We cannot parse this token based on the given data type. So, we throw a + // SparkSQLJsonProcessingException and this exception will be caught by + // parseJson method. + throw new SparkSQLJsonProcessingException( + s"Failed to parse a value for data type $dataType (current token: $token).") } } @@ -182,9 +236,10 @@ private[sql] object JacksonParser { } private def parseJson( - json: RDD[String], + input: Iterator[String], schema: StructType, - columnNameOfCorruptRecords: String): RDD[InternalRow] = { + columnNameOfCorruptRecords: String, + configOptions: JSONOptions): Iterator[InternalRow] = { def failedRecord(record: String): Seq[InternalRow] = { // create a row even if no corrupt record column is present @@ -197,31 +252,35 @@ private[sql] object JacksonParser { Seq(row) } - json.mapPartitions { iter => - val factory = new JsonFactory() + val factory = new JsonFactory() + configOptions.setJacksonOptions(factory) - iter.flatMap { record => + input.flatMap { record => + if (record.trim.isEmpty) { + Nil + } else { try { - val parser = factory.createParser(record) - parser.nextToken() - - convertField(factory, parser, schema) match { - case null => failedRecord(record) - case row: InternalRow => row :: Nil - case array: ArrayData => - if (array.numElements() == 0) { - Nil - } else { - array.toArray[InternalRow](schema) - } - case _ => - sys.error( - s"Failed to parse record $record. 
Please make sure that each line of the file " + - "(or each string in the RDD) is a valid JSON object or an array of JSON objects.") + Utils.tryWithResource(factory.createParser(record)) { parser => + parser.nextToken() + + convertField(factory, parser, schema) match { + case null => failedRecord(record) + case row: InternalRow => row :: Nil + case array: ArrayData => + if (array.numElements() == 0) { + Nil + } else { + array.toArray[InternalRow](schema) + } + case _ => + failedRecord(record) + } } } catch { case _: JsonProcessingException => failedRecord(record) + case _: SparkSQLJsonProcessingException => + failedRecord(record) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonUtils.scala similarity index 95% rename from sql/core/src/main/scala/org/apache/spark/sql/json/JacksonUtils.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonUtils.scala index fde96852ce68..005546f37dda 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonUtils.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.json +package org.apache.spark.sql.execution.datasources.json import com.fasterxml.jackson.core.{JsonParser, JsonToken} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala new file mode 100644 index 000000000000..a958373eb769 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala @@ -0,0 +1,305 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.util.{Map => JMap} + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.conf.Configuration +import org.apache.parquet.hadoop.api.ReadSupport.ReadContext +import org.apache.parquet.hadoop.api.{InitContext, ReadSupport} +import org.apache.parquet.io.api.RecordMaterializer +import org.apache.parquet.schema.Type.Repetition +import org.apache.parquet.schema._ + +import org.apache.spark.Logging +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types._ + +/** + * A Parquet [[ReadSupport]] implementation for reading Parquet records as Catalyst + * [[InternalRow]]s. + * + * The API interface of [[ReadSupport]] is a little bit over complicated because of historical + * reasons. 
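Editorial sketch: on the JacksonParser.parseJson path above, records that fail to parse are not dropped and do not fail the job; they come back as rows whose raw text lands in the configured corrupt-record column. A hedged usage sketch, assuming a Spark 1.6-era sc and sqlContext in scope and the default column name _corrupt_record:

  val malformed = sc.parallelize(Seq(
    """{"a": 1}""",
    """{"a": """))   // second line is intentionally broken JSON

  val df = sqlContext.read.json(malformed)
  // The inferred schema gains a "_corrupt_record" string column; the broken line
  // appears there with nulls for the regular fields.
  df.select("_corrupt_record", "a").show()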
In older versions of parquet-mr (say 1.6.0rc3 and prior), [[ReadSupport]] need to be + * instantiated and initialized twice on both driver side and executor side. The [[init()]] method + * is for driver side initialization, while [[prepareForRead()]] is for executor side. However, + * starting from parquet-mr 1.6.0, it's no longer the case, and [[ReadSupport]] is only instantiated + * and initialized on executor side. So, theoretically, now it's totally fine to combine these two + * methods into a single initialization method. The only reason (I could think of) to still have + * them here is for parquet-mr API backwards-compatibility. + * + * Due to this reason, we no longer rely on [[ReadContext]] to pass requested schema from [[init()]] + * to [[prepareForRead()]], but use a private `var` for simplicity. + */ +private[parquet] class CatalystReadSupport extends ReadSupport[InternalRow] with Logging { + private var catalystRequestedSchema: StructType = _ + + /** + * Called on executor side before [[prepareForRead()]] and instantiating actual Parquet record + * readers. Responsible for figuring out Parquet requested schema used for column pruning. + */ + override def init(context: InitContext): ReadContext = { + catalystRequestedSchema = { + // scalastyle:off jobcontext + val conf = context.getConfiguration + // scalastyle:on jobcontext + val schemaString = conf.get(CatalystReadSupport.SPARK_ROW_REQUESTED_SCHEMA) + assert(schemaString != null, "Parquet requested schema not set.") + StructType.fromString(schemaString) + } + + val parquetRequestedSchema = + CatalystReadSupport.clipParquetSchema(context.getFileSchema, catalystRequestedSchema) + + new ReadContext(parquetRequestedSchema, Map.empty[String, String].asJava) + } + + /** + * Called on executor side after [[init()]], before instantiating actual Parquet record readers. + * Responsible for instantiating [[RecordMaterializer]], which is used for converting Parquet + * records to Catalyst [[InternalRow]]s. + */ + override def prepareForRead( + conf: Configuration, + keyValueMetaData: JMap[String, String], + fileSchema: MessageType, + readContext: ReadContext): RecordMaterializer[InternalRow] = { + log.debug(s"Preparing for read Parquet file with message type: $fileSchema") + val parquetRequestedSchema = readContext.getRequestedSchema + + logInfo { + s"""Going to read the following fields from the Parquet file: + | + |Parquet form: + |$parquetRequestedSchema + |Catalyst form: + |$catalystRequestedSchema + """.stripMargin + } + + new CatalystRecordMaterializer( + parquetRequestedSchema, + CatalystReadSupport.expandUDT(catalystRequestedSchema)) + } +} + +private[parquet] object CatalystReadSupport { + val SPARK_ROW_REQUESTED_SCHEMA = "org.apache.spark.sql.parquet.row.requested_schema" + + val SPARK_METADATA_KEY = "org.apache.spark.sql.parquet.row.metadata" + + /** + * Tailors `parquetSchema` according to `catalystSchema` by removing column paths don't exist + * in `catalystSchema`, and adding those only exist in `catalystSchema`. 
+ */ + def clipParquetSchema(parquetSchema: MessageType, catalystSchema: StructType): MessageType = { + val clippedParquetFields = clipParquetGroupFields(parquetSchema.asGroupType(), catalystSchema) + Types + .buildMessage() + .addFields(clippedParquetFields: _*) + .named(CatalystSchemaConverter.SPARK_PARQUET_SCHEMA_NAME) + } + + private def clipParquetType(parquetType: Type, catalystType: DataType): Type = { + catalystType match { + case t: ArrayType if !isPrimitiveCatalystType(t.elementType) => + // Only clips array types with nested type as element type. + clipParquetListType(parquetType.asGroupType(), t.elementType) + + case t: MapType + if !isPrimitiveCatalystType(t.keyType) || + !isPrimitiveCatalystType(t.valueType) => + // Only clips map types with nested key type or value type + clipParquetMapType(parquetType.asGroupType(), t.keyType, t.valueType) + + case t: StructType => + clipParquetGroup(parquetType.asGroupType(), t) + + case _ => + // UDTs and primitive types are not clipped. For UDTs, a clipped version might not be able + // to be mapped to desired user-space types. So UDTs shouldn't participate schema merging. + parquetType + } + } + + /** + * Whether a Catalyst [[DataType]] is primitive. Primitive [[DataType]] is not equivalent to + * [[AtomicType]]. For example, [[CalendarIntervalType]] is primitive, but it's not an + * [[AtomicType]]. + */ + private def isPrimitiveCatalystType(dataType: DataType): Boolean = { + dataType match { + case _: ArrayType | _: MapType | _: StructType => false + case _ => true + } + } + + /** + * Clips a Parquet [[GroupType]] which corresponds to a Catalyst [[ArrayType]]. The element type + * of the [[ArrayType]] should also be a nested type, namely an [[ArrayType]], a [[MapType]], or a + * [[StructType]]. + */ + private def clipParquetListType(parquetList: GroupType, elementType: DataType): Type = { + // Precondition of this method, should only be called for lists with nested element types. + assert(!isPrimitiveCatalystType(elementType)) + + // Unannotated repeated group should be interpreted as required list of required element, so + // list element type is just the group itself. Clip it. + if (parquetList.getOriginalType == null && parquetList.isRepetition(Repetition.REPEATED)) { + clipParquetType(parquetList, elementType) + } else { + assert( + parquetList.getOriginalType == OriginalType.LIST, + "Invalid Parquet schema. " + + "Original type of annotated Parquet lists must be LIST: " + + parquetList.toString) + + assert( + parquetList.getFieldCount == 1 && parquetList.getType(0).isRepetition(Repetition.REPEATED), + "Invalid Parquet schema. " + + "LIST-annotated group should only have exactly one repeated field: " + + parquetList) + + // Precondition of this method, should only be called for lists with nested element types. + assert(!parquetList.getType(0).isPrimitive) + + val repeatedGroup = parquetList.getType(0).asGroupType() + + // If the repeated field is a group with multiple fields, or the repeated field is a group + // with one field and is named either "array" or uses the LIST-annotated group's name with + // "_tuple" appended then the repeated type is the element type and elements are required. + // Build a new LIST-annotated group with clipped `repeatedGroup` as element type and the + // only field. 
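Editorial sketch: the net effect of clipParquetGroupFields further below is easiest to see at the Catalyst level: requested fields that exist in the file keep the file's type, and requested fields that are missing get appended so they can still be read back (as nulls). A toy analogue over StructType only, illustrative rather than the Parquet-level code:

  import org.apache.spark.sql.types.StructType

  // Keep requested columns the file has, append the ones it lacks.
  def clipStruct(fileSchema: StructType, requestedSchema: StructType): StructType = {
    val fileFields = fileSchema.fields.map(f => f.name -> f).toMap
    StructType(requestedSchema.fields.map(f => fileFields.getOrElse(f.name, f)))
  }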
+ if ( + repeatedGroup.getFieldCount > 1 || + repeatedGroup.getName == "array" || + repeatedGroup.getName == parquetList.getName + "_tuple" + ) { + Types + .buildGroup(parquetList.getRepetition) + .as(OriginalType.LIST) + .addField(clipParquetType(repeatedGroup, elementType)) + .named(parquetList.getName) + } else { + // Otherwise, the repeated field's type is the element type with the repeated field's + // repetition. + Types + .buildGroup(parquetList.getRepetition) + .as(OriginalType.LIST) + .addField( + Types + .repeatedGroup() + .addField(clipParquetType(repeatedGroup.getType(0), elementType)) + .named(repeatedGroup.getName)) + .named(parquetList.getName) + } + } + } + + /** + * Clips a Parquet [[GroupType]] which corresponds to a Catalyst [[MapType]]. Either key type or + * value type of the [[MapType]] must be a nested type, namely an [[ArrayType]], a [[MapType]], or + * a [[StructType]]. + */ + private def clipParquetMapType( + parquetMap: GroupType, keyType: DataType, valueType: DataType): GroupType = { + // Precondition of this method, only handles maps with nested key types or value types. + assert(!isPrimitiveCatalystType(keyType) || !isPrimitiveCatalystType(valueType)) + + val repeatedGroup = parquetMap.getType(0).asGroupType() + val parquetKeyType = repeatedGroup.getType(0) + val parquetValueType = repeatedGroup.getType(1) + + val clippedRepeatedGroup = + Types + .repeatedGroup() + .as(repeatedGroup.getOriginalType) + .addField(clipParquetType(parquetKeyType, keyType)) + .addField(clipParquetType(parquetValueType, valueType)) + .named(repeatedGroup.getName) + + Types + .buildGroup(parquetMap.getRepetition) + .as(parquetMap.getOriginalType) + .addField(clippedRepeatedGroup) + .named(parquetMap.getName) + } + + /** + * Clips a Parquet [[GroupType]] which corresponds to a Catalyst [[StructType]]. + * + * @return A clipped [[GroupType]], which has at least one field. + * @note Parquet doesn't allow creating empty [[GroupType]] instances except for empty + * [[MessageType]]. Because it's legal to construct an empty requested schema for column + * pruning. + */ + private def clipParquetGroup(parquetRecord: GroupType, structType: StructType): GroupType = { + val clippedParquetFields = clipParquetGroupFields(parquetRecord, structType) + Types + .buildGroup(parquetRecord.getRepetition) + .as(parquetRecord.getOriginalType) + .addFields(clippedParquetFields: _*) + .named(parquetRecord.getName) + } + + /** + * Clips a Parquet [[GroupType]] which corresponds to a Catalyst [[StructType]]. + * + * @return A list of clipped [[GroupType]] fields, which can be empty. 
+ */ + private def clipParquetGroupFields( + parquetRecord: GroupType, structType: StructType): Seq[Type] = { + val parquetFieldMap = parquetRecord.getFields.asScala.map(f => f.getName -> f).toMap + val toParquet = new CatalystSchemaConverter(writeLegacyParquetFormat = false) + structType.map { f => + parquetFieldMap + .get(f.name) + .map(clipParquetType(_, f.dataType)) + .getOrElse(toParquet.convertField(f)) + } + } + + def expandUDT(schema: StructType): StructType = { + def expand(dataType: DataType): DataType = { + dataType match { + case t: ArrayType => + t.copy(elementType = expand(t.elementType)) + + case t: MapType => + t.copy( + keyType = expand(t.keyType), + valueType = expand(t.valueType)) + + case t: StructType => + val expandedFields = t.fields.map(f => f.copy(dataType = expand(f.dataType))) + t.copy(fields = expandedFields) + + case t: UserDefinedType[_] => + t.sqlType + + case t => + t + } + } + + expand(schema).asInstanceOf[StructType] + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRecordMaterializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRecordMaterializer.scala similarity index 95% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRecordMaterializer.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRecordMaterializer.scala index 84f1dccfeb78..eeead9f5d88a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRecordMaterializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRecordMaterializer.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import org.apache.parquet.io.api.{GroupConverter, RecordMaterializer} import org.apache.parquet.schema.MessageType @@ -35,7 +35,7 @@ private[parquet] class CatalystRecordMaterializer( private val rootConverter = new CatalystRowConverter(parquetSchema, catalystSchema, NoopUpdater) - override def getCurrentRecord: InternalRow = rootConverter.currentRow + override def getCurrentRecord: InternalRow = rootConverter.currentRecord override def getRootConverter: GroupConverter = rootConverter } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala similarity index 51% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index 6938b071065c..8851bc23cd05 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -15,23 +15,24 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.math.{BigDecimal, BigInteger} import java.nio.ByteOrder -import scala.collection.JavaConversions._ -import scala.collection.mutable +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import org.apache.parquet.column.Dictionary import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter} -import org.apache.parquet.schema.Type.Repetition -import org.apache.parquet.schema.{GroupType, PrimitiveType, Type} +import org.apache.parquet.schema.OriginalType.{INT_32, LIST, UTF8} +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.{DOUBLE, INT32, INT64, BINARY, FIXED_LEN_BYTE_ARRAY} +import org.apache.parquet.schema.{GroupType, MessageType, PrimitiveType, Type} +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -42,6 +43,12 @@ import org.apache.spark.unsafe.types.UTF8String * values to an [[ArrayBuffer]]. */ private[parquet] trait ParentContainerUpdater { + /** Called before a record field is being converted */ + def start(): Unit = () + + /** Called after a record field is being converted */ + def end(): Unit = () + def set(value: Any): Unit = () def setBoolean(value: Boolean): Unit = set(value) def setByte(value: Byte): Unit = set(value) @@ -55,22 +62,91 @@ private[parquet] trait ParentContainerUpdater { /** A no-op updater used for root converter (who doesn't have a parent). */ private[parquet] object NoopUpdater extends ParentContainerUpdater +private[parquet] trait HasParentContainerUpdater { + def updater: ParentContainerUpdater +} + +/** + * A convenient converter class for Parquet group types with an [[HasParentContainerUpdater]]. + */ +private[parquet] abstract class CatalystGroupConverter(val updater: ParentContainerUpdater) + extends GroupConverter with HasParentContainerUpdater + +/** + * Parquet converter for Parquet primitive types. Note that not all Spark SQL atomic types + * are handled by this converter. Parquet primitive types are only a subset of those of Spark + * SQL. For example, BYTE, SHORT, and INT in Spark SQL are all covered by INT32 in Parquet. + */ +private[parquet] class CatalystPrimitiveConverter(val updater: ParentContainerUpdater) + extends PrimitiveConverter with HasParentContainerUpdater { + + override def addBoolean(value: Boolean): Unit = updater.setBoolean(value) + override def addInt(value: Int): Unit = updater.setInt(value) + override def addLong(value: Long): Unit = updater.setLong(value) + override def addFloat(value: Float): Unit = updater.setFloat(value) + override def addDouble(value: Double): Unit = updater.setDouble(value) + override def addBinary(value: Binary): Unit = updater.set(value.getBytes) +} + /** - * A [[CatalystRowConverter]] is used to convert Parquet "structs" into Spark SQL [[InternalRow]]s. - * Since any Parquet record is also a struct, this converter can also be used as root converter. + * A [[CatalystRowConverter]] is used to convert Parquet records into Catalyst [[InternalRow]]s. + * Since Catalyst `StructType` is also a Parquet record, this converter can be used as root + * converter. 
Take the following Parquet type as an example: + * {{{ + * message root { + * required int32 f1; + * optional group f2 { + * required double f21; + * optional binary f22 (utf8); + * } + * } + * }}} + * 5 converters will be created: + * + * - a root [[CatalystRowConverter]] for [[MessageType]] `root`, which contains: + * - a [[CatalystPrimitiveConverter]] for required [[INT_32]] field `f1`, and + * - a nested [[CatalystRowConverter]] for optional [[GroupType]] `f2`, which contains: + * - a [[CatalystPrimitiveConverter]] for required [[DOUBLE]] field `f21`, and + * - a [[CatalystStringConverter]] for optional [[UTF8]] string field `f22` * * When used as a root converter, [[NoopUpdater]] should be used since root converters don't have * any "parent" container. * * @param parquetType Parquet schema of Parquet records - * @param catalystType Spark SQL schema that corresponds to the Parquet record type + * @param catalystType Spark SQL schema that corresponds to the Parquet record type. User-defined + * types should have been expanded. * @param updater An updater which propagates converted field values to the parent container */ private[parquet] class CatalystRowConverter( parquetType: GroupType, catalystType: StructType, updater: ParentContainerUpdater) - extends GroupConverter { + extends CatalystGroupConverter(updater) with Logging { + + assert( + parquetType.getFieldCount == catalystType.length, + s"""Field counts of the Parquet schema and the Catalyst schema don't match: + | + |Parquet schema: + |$parquetType + |Catalyst schema: + |${catalystType.prettyJson} + """.stripMargin) + + assert( + !catalystType.existsRecursively(_.isInstanceOf[UserDefinedType[_]]), + s"""User-defined types in Catalyst schema should have already been expanded: + |${catalystType.prettyJson} + """.stripMargin) + + logDebug( + s"""Building row converter for the following schema: + | + |Parquet form: + |$parquetType + |Catalyst form: + |${catalystType.prettyJson} + """.stripMargin) /** * Updater used together with field converters within a [[CatalystRowConverter]]. It propagates @@ -87,16 +163,18 @@ private[parquet] class CatalystRowConverter( override def setFloat(value: Float): Unit = row.setFloat(ordinal, value) } + private val currentRow = new SpecificMutableRow(catalystType.map(_.dataType)) + + private val unsafeProjection = UnsafeProjection.create(catalystType) + /** - * Represents the converted row object once an entire Parquet record is converted. - * - * @todo Uses [[UnsafeRow]] for better performance. + * The [[UnsafeRow]] converted from an entire Parquet record. */ - val currentRow = new SpecificMutableRow(catalystType.map(_.dataType)) + def currentRecord: UnsafeRow = unsafeProjection(currentRow) // Converters for each field. 
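Editorial sketch: the RowUpdater above is one instance of the pattern used throughout this converter: each child converter hands its converted value to its parent through a small updater hook, so the same converter code can target a row cell, an array buffer, or a map entry. A stripped-down illustration, with made-up names:

  // Minimal version of the parent-container-updater idea.
  trait ContainerUpdater { def set(value: Any): Unit }

  // Writes a converted value into a fixed cell of a row-like container.
  final class CellUpdater(cells: Array[Any], ordinal: Int) extends ContainerUpdater {
    override def set(value: Any): Unit = cells(ordinal) = value
  }

  // Appends converted values, the way repeated/array converters accumulate elements.
  final class BufferUpdater(buffer: scala.collection.mutable.ArrayBuffer[Any]) extends ContainerUpdater {
    override def set(value: Any): Unit = buffer += value
  }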
- private val fieldConverters: Array[Converter] = { - parquetType.getFields.zip(catalystType).zipWithIndex.map { + private val fieldConverters: Array[Converter with HasParentContainerUpdater] = { + parquetType.getFields.asScala.zip(catalystType).zipWithIndex.map { case ((parquetFieldType, catalystField), ordinal) => // Converted field value should be set to the `ordinal`-th cell of `currentRow` newConverter(parquetFieldType, catalystField.dataType, new RowUpdater(currentRow, ordinal)) @@ -105,11 +183,19 @@ private[parquet] class CatalystRowConverter( override def getConverter(fieldIndex: Int): Converter = fieldConverters(fieldIndex) - override def end(): Unit = updater.set(currentRow) + override def end(): Unit = { + var i = 0 + while (i < currentRow.numFields) { + fieldConverters(i).updater.end() + i += 1 + } + updater.set(currentRow) + } override def start(): Unit = { var i = 0 while (i < currentRow.numFields) { + fieldConverters(i).updater.start() currentRow.setNullAt(i) i += 1 } @@ -122,33 +208,50 @@ private[parquet] class CatalystRowConverter( private def newConverter( parquetType: Type, catalystType: DataType, - updater: ParentContainerUpdater): Converter = { + updater: ParentContainerUpdater): Converter with HasParentContainerUpdater = { catalystType match { case BooleanType | IntegerType | LongType | FloatType | DoubleType | BinaryType => new CatalystPrimitiveConverter(updater) case ByteType => - new PrimitiveConverter { + new CatalystPrimitiveConverter(updater) { override def addInt(value: Int): Unit = updater.setByte(value.asInstanceOf[ByteType#InternalType]) } case ShortType => - new PrimitiveConverter { + new CatalystPrimitiveConverter(updater) { override def addInt(value: Int): Unit = updater.setShort(value.asInstanceOf[ShortType#InternalType]) } + // For INT32 backed decimals + case t: DecimalType if parquetType.asPrimitiveType().getPrimitiveTypeName == INT32 => + new CatalystIntDictionaryAwareDecimalConverter(t.precision, t.scale, updater) + + // For INT64 backed decimals + case t: DecimalType if parquetType.asPrimitiveType().getPrimitiveTypeName == INT64 => + new CatalystLongDictionaryAwareDecimalConverter(t.precision, t.scale, updater) + + // For BINARY and FIXED_LEN_BYTE_ARRAY backed decimals + case t: DecimalType + if parquetType.asPrimitiveType().getPrimitiveTypeName == FIXED_LEN_BYTE_ARRAY || + parquetType.asPrimitiveType().getPrimitiveTypeName == BINARY => + new CatalystBinaryDictionaryAwareDecimalConverter(t.precision, t.scale, updater) + case t: DecimalType => - new CatalystDecimalConverter(t, updater) + throw new RuntimeException( + s"Unable to create Parquet converter for decimal type ${t.json} whose Parquet type is " + + s"$parquetType. Parquet DECIMAL type can only be backed by INT32, INT64, " + + "FIXED_LEN_BYTE_ARRAY, or BINARY.") case StringType => new CatalystStringConverter(updater) case TimestampType => // TODO Implements `TIMESTAMP_MICROS` once parquet-mr has that. - new PrimitiveConverter { + new CatalystPrimitiveConverter(updater) { // Converts nanosecond timestamps stored as INT96 override def addBinary(value: Binary): Unit = { assert( @@ -164,13 +267,23 @@ private[parquet] class CatalystRowConverter( } case DateType => - new PrimitiveConverter { + new CatalystPrimitiveConverter(updater) { override def addInt(value: Int): Unit = { // DateType is not specialized in `SpecificMutableRow`, have to box it here. 
updater.set(value.asInstanceOf[DateType#InternalType]) } } + // A repeated field that is neither contained by a `LIST`- or `MAP`-annotated group nor + // annotated by `LIST` or `MAP` should be interpreted as a required list of required + // elements where the element type is the type of the field. + case t: ArrayType if parquetType.getOriginalType != LIST => + if (parquetType.isPrimitive) { + new RepeatedPrimitiveConverter(parquetType, t.elementType, updater) + } else { + new RepeatedGroupConverter(parquetType, t.elementType, updater) + } + case t: ArrayType => new CatalystArrayConverter(parquetType.asGroupType(), t, updater) @@ -182,40 +295,18 @@ private[parquet] class CatalystRowConverter( override def set(value: Any): Unit = updater.set(value.asInstanceOf[InternalRow].copy()) }) - case t: UserDefinedType[_] => - val catalystTypeForUDT = t.sqlType - val nullable = parquetType.isRepetition(Repetition.OPTIONAL) - val field = StructField("udt", catalystTypeForUDT, nullable) - val parquetTypeForUDT = new CatalystSchemaConverter().convertField(field) - newConverter(parquetTypeForUDT, catalystTypeForUDT, updater) - - case _ => + case t => throw new RuntimeException( - s"Unable to create Parquet converter for data type ${catalystType.json}") + s"Unable to create Parquet converter for data type ${t.json} " + + s"whose Parquet type is $parquetType") } } - /** - * Parquet converter for Parquet primitive types. Note that not all Spark SQL atomic types - * are handled by this converter. Parquet primitive types are only a subset of those of Spark - * SQL. For example, BYTE, SHORT, and INT in Spark SQL are all covered by INT32 in Parquet. - */ - private final class CatalystPrimitiveConverter(updater: ParentContainerUpdater) - extends PrimitiveConverter { - - override def addBoolean(value: Boolean): Unit = updater.setBoolean(value) - override def addInt(value: Int): Unit = updater.setInt(value) - override def addLong(value: Long): Unit = updater.setLong(value) - override def addFloat(value: Float): Unit = updater.setFloat(value) - override def addDouble(value: Double): Unit = updater.setDouble(value) - override def addBinary(value: Binary): Unit = updater.set(value.getBytes) - } - /** * Parquet converter for strings. A dictionary is used to minimize string decoding cost. */ private final class CatalystStringConverter(updater: ParentContainerUpdater) - extends PrimitiveConverter { + extends CatalystPrimitiveConverter(updater) { private var expandedDictionary: Array[UTF8String] = null @@ -232,17 +323,30 @@ private[parquet] class CatalystRowConverter( } override def addBinary(value: Binary): Unit = { - updater.set(UTF8String.fromBytes(value.getBytes)) + // The underlying `ByteBuffer` implementation is guaranteed to be `HeapByteBuffer`, so here we + // are using `Binary.toByteBuffer.array()` to steal the underlying byte array without copying + // it. + val buffer = value.toByteBuffer + val offset = buffer.arrayOffset() + buffer.position() + val numBytes = buffer.remaining() + updater.set(UTF8String.fromBytes(buffer.array(), offset, numBytes)) } } /** * Parquet converter for fixed-precision decimals. 
*/ - private final class CatalystDecimalConverter( - decimalType: DecimalType, - updater: ParentContainerUpdater) - extends PrimitiveConverter { + private abstract class CatalystDecimalConverter( + precision: Int, scale: Int, updater: ParentContainerUpdater) + extends CatalystPrimitiveConverter(updater) { + + protected var expandedDictionary: Array[Decimal] = _ + + override def hasDictionarySupport: Boolean = true + + override def addValueFromDictionary(dictionaryId: Int): Unit = { + updater.set(expandedDictionary(dictionaryId)) + } // Converts decimals stored as INT32 override def addInt(value: Int): Unit = { @@ -251,35 +355,59 @@ private[parquet] class CatalystRowConverter( // Converts decimals stored as INT64 override def addLong(value: Long): Unit = { - updater.set(Decimal(value, decimalType.precision, decimalType.scale)) + updater.set(decimalFromLong(value)) } // Converts decimals stored as either FIXED_LENGTH_BYTE_ARRAY or BINARY override def addBinary(value: Binary): Unit = { - updater.set(toDecimal(value)) + updater.set(decimalFromBinary(value)) } - private def toDecimal(value: Binary): Decimal = { - val precision = decimalType.precision - val scale = decimalType.scale - val bytes = value.getBytes + protected def decimalFromLong(value: Long): Decimal = { + Decimal(value, precision, scale) + } - if (precision <= 8) { + protected def decimalFromBinary(value: Binary): Decimal = { + if (precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64) { // Constructs a `Decimal` with an unscaled `Long` value if possible. - var unscaled = 0L - var i = 0 - - while (i < bytes.length) { - unscaled = (unscaled << 8) | (bytes(i) & 0xff) - i += 1 - } - - val bits = 8 * bytes.length - unscaled = (unscaled << (64 - bits)) >> (64 - bits) + val unscaled = CatalystRowConverter.binaryToUnscaledLong(value) Decimal(unscaled, precision, scale) } else { // Otherwise, resorts to an unscaled `BigInteger` instead. 
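Editorial sketch: binaryToUnscaledLong (defined at the bottom of this file) takes over from the inline loop removed above: it reads the big-endian two's-complement bytes as an unscaled Long, then sign-extends from the original byte width. A standalone restatement of that arithmetic with a small worked example:

  // Big-endian two's-complement bytes -> unscaled Long (same arithmetic as the removed loop).
  def bytesToUnscaledLong(bytes: Array[Byte]): Long = {
    var unscaled = 0L
    var i = 0
    while (i < bytes.length) {
      unscaled = (unscaled << 8) | (bytes(i) & 0xff)
      i += 1
    }
    val bits = 8 * bytes.length
    (unscaled << (64 - bits)) >> (64 - bits) // sign-extend from the original width
  }

  // bytesToUnscaledLong(Array(0xff.toByte, 0x38.toByte)) == -200L, so with scale 2 the
  // resulting Decimal would be -2.00.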
- Decimal(new BigDecimal(new BigInteger(bytes), scale), precision, scale) + Decimal(new BigDecimal(new BigInteger(value.getBytes), scale), precision, scale) + } + } + } + + private class CatalystIntDictionaryAwareDecimalConverter( + precision: Int, scale: Int, updater: ParentContainerUpdater) + extends CatalystDecimalConverter(precision, scale, updater) { + + override def setDictionary(dictionary: Dictionary): Unit = { + this.expandedDictionary = Array.tabulate(dictionary.getMaxId + 1) { id => + decimalFromLong(dictionary.decodeToInt(id).toLong) + } + } + } + + private class CatalystLongDictionaryAwareDecimalConverter( + precision: Int, scale: Int, updater: ParentContainerUpdater) + extends CatalystDecimalConverter(precision, scale, updater) { + + override def setDictionary(dictionary: Dictionary): Unit = { + this.expandedDictionary = Array.tabulate(dictionary.getMaxId + 1) { id => + decimalFromLong(dictionary.decodeToLong(id)) + } + } + } + + private class CatalystBinaryDictionaryAwareDecimalConverter( + precision: Int, scale: Int, updater: ParentContainerUpdater) + extends CatalystDecimalConverter(precision, scale, updater) { + + override def setDictionary(dictionary: Dictionary): Unit = { + this.expandedDictionary = Array.tabulate(dictionary.getMaxId + 1) { id => + decimalFromBinary(dictionary.decodeToBinary(id)) } } } @@ -306,15 +434,16 @@ private[parquet] class CatalystRowConverter( parquetSchema: GroupType, catalystSchema: ArrayType, updater: ParentContainerUpdater) - extends GroupConverter { + extends CatalystGroupConverter(updater) { private var currentArray: ArrayBuffer[Any] = _ private val elementConverter: Converter = { val repeatedType = parquetSchema.getType(0) val elementType = catalystSchema.elementType + val parentName = parquetSchema.getName - if (isElementType(repeatedType, elementType)) { + if (isElementType(repeatedType, elementType, parentName)) { newConverter(repeatedType, elementType, new ParentContainerUpdater { override def set(value: Any): Unit = currentArray += value }) @@ -351,10 +480,13 @@ private[parquet] class CatalystRowConverter( * @see https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#backward-compatibility-rules */ // scalastyle:on - private def isElementType(parquetRepeatedType: Type, catalystElementType: DataType): Boolean = { + private def isElementType( + parquetRepeatedType: Type, catalystElementType: DataType, parentName: String): Boolean = { (parquetRepeatedType, catalystElementType) match { case (t: PrimitiveType, _) => true case (t: GroupType, _) if t.getFieldCount > 1 => true + case (t: GroupType, _) if t.getFieldCount == 1 && t.getName == "array" => true + case (t: GroupType, _) if t.getFieldCount == 1 && t.getName == parentName + "_tuple" => true case (t: GroupType, StructType(Array(f))) if f.name == t.getFieldName(0) => true case _ => false } @@ -383,7 +515,7 @@ private[parquet] class CatalystRowConverter( parquetType: GroupType, catalystType: MapType, updater: ParentContainerUpdater) - extends GroupConverter { + extends CatalystGroupConverter(updater) { private var currentKeys: ArrayBuffer[Any] = _ private var currentValues: ArrayBuffer[Any] = _ @@ -446,4 +578,85 @@ private[parquet] class CatalystRowConverter( } } } + + private trait RepeatedConverter { + private var currentArray: ArrayBuffer[Any] = _ + + protected def newArrayUpdater(updater: ParentContainerUpdater) = new ParentContainerUpdater { + override def start(): Unit = currentArray = ArrayBuffer.empty[Any] + override def end(): Unit = updater.set(new 
GenericArrayData(currentArray.toArray)) + override def set(value: Any): Unit = currentArray += value + } + } + + /** + * A primitive converter for converting unannotated repeated primitive values to required arrays + * of required primitives values. + */ + private final class RepeatedPrimitiveConverter( + parquetType: Type, + catalystType: DataType, + parentUpdater: ParentContainerUpdater) + extends PrimitiveConverter with RepeatedConverter with HasParentContainerUpdater { + + val updater: ParentContainerUpdater = newArrayUpdater(parentUpdater) + + private val elementConverter: PrimitiveConverter = + newConverter(parquetType, catalystType, updater).asPrimitiveConverter() + + override def addBoolean(value: Boolean): Unit = elementConverter.addBoolean(value) + override def addInt(value: Int): Unit = elementConverter.addInt(value) + override def addLong(value: Long): Unit = elementConverter.addLong(value) + override def addFloat(value: Float): Unit = elementConverter.addFloat(value) + override def addDouble(value: Double): Unit = elementConverter.addDouble(value) + override def addBinary(value: Binary): Unit = elementConverter.addBinary(value) + + override def setDictionary(dict: Dictionary): Unit = elementConverter.setDictionary(dict) + override def hasDictionarySupport: Boolean = elementConverter.hasDictionarySupport + override def addValueFromDictionary(id: Int): Unit = elementConverter.addValueFromDictionary(id) + } + + /** + * A group converter for converting unannotated repeated group values to required arrays of + * required struct values. + */ + private final class RepeatedGroupConverter( + parquetType: Type, + catalystType: DataType, + parentUpdater: ParentContainerUpdater) + extends GroupConverter with HasParentContainerUpdater with RepeatedConverter { + + val updater: ParentContainerUpdater = newArrayUpdater(parentUpdater) + + private val elementConverter: GroupConverter = + newConverter(parquetType, catalystType, updater).asGroupConverter() + + override def getConverter(field: Int): Converter = elementConverter.getConverter(field) + override def end(): Unit = elementConverter.end() + override def start(): Unit = elementConverter.start() + } +} + +private[parquet] object CatalystRowConverter { + def binaryToUnscaledLong(binary: Binary): Long = { + // The underlying `ByteBuffer` implementation is guaranteed to be `HeapByteBuffer`, so here + // we are using `Binary.toByteBuffer.array()` to steal the underlying byte array without + // copying it. 
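The `binaryToUnscaledLong` helper whose body follows in this hunk folds the big-endian two's-complement bytes of a `Binary` into a `Long` and then sign-extends it. A standalone sketch of the same decoding on a plain `Array[Byte]` (illustrative name, no Parquet types):

```scala
// Standalone sketch of the big-endian two's-complement decoding performed by
// binaryToUnscaledLong, shown on a plain Array[Byte] for clarity.
def bytesToUnscaledLong(bytes: Array[Byte]): Long = {
  var unscaled = 0L
  var i = 0
  while (i < bytes.length) {
    unscaled = (unscaled << 8) | (bytes(i) & 0xff)  // accumulate big-endian bytes
    i += 1
  }
  // Sign-extend: shift the highest input bit up to bit 63, then arithmetic-shift back.
  val bits = 8 * bytes.length
  (unscaled << (64 - bits)) >> (64 - bits)
}

assert(bytesToUnscaledLong(Array(0x01, 0x00).map(_.toByte)) == 256L)
assert(bytesToUnscaledLong(Array(0xff, 0x7f).map(_.toByte)) == -129L)
```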
+ val buffer = binary.toByteBuffer + val bytes = buffer.array() + val start = buffer.arrayOffset() + buffer.position() + val end = buffer.arrayOffset() + buffer.limit() + + var unscaled = 0L + var i = start + + while (i < end) { + unscaled = (unscaled << 8) | (bytes(i) & 0xff) + i += 1 + } + + val bits = 8 * (end - start) + unscaled = (unscaled << (64 - bits)) >> (64 - bits) + unscaled + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala similarity index 77% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala index d43ca95b4eea..5f9f9083098a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala @@ -15,9 +15,9 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration import org.apache.parquet.schema.OriginalType._ @@ -25,6 +25,7 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ import org.apache.parquet.schema.Type.Repetition._ import org.apache.parquet.schema._ +import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.{MAX_PRECISION_FOR_INT32, MAX_PRECISION_FOR_INT64, maxPrecisionForBytes} import org.apache.spark.sql.types._ import org.apache.spark.sql.{AnalysisException, SQLConf} @@ -40,49 +41,31 @@ import org.apache.spark.sql.{AnalysisException, SQLConf} * @constructor * @param assumeBinaryIsString Whether unannotated BINARY fields should be assumed to be Spark SQL * [[StringType]] fields when converting Parquet a [[MessageType]] to Spark SQL - * [[StructType]]. + * [[StructType]]. This argument only affects Parquet read path. * @param assumeInt96IsTimestamp Whether unannotated INT96 fields should be assumed to be Spark SQL * [[TimestampType]] fields when converting Parquet a [[MessageType]] to Spark SQL * [[StructType]]. Note that Spark SQL [[TimestampType]] is similar to Hive timestamp, which * has optional nanosecond precision, but different from `TIME_MILLS` and `TIMESTAMP_MILLIS` - * described in Parquet format spec. - * @param followParquetFormatSpec Whether to generate standard DECIMAL, LIST, and MAP structure when - * converting Spark SQL [[StructType]] to Parquet [[MessageType]]. For Spark 1.4.x and - * prior versions, Spark SQL only supports decimals with a max precision of 18 digits, and - * uses non-standard LIST and MAP structure. Note that the current Parquet format spec is - * backwards-compatible with these settings. If this argument is set to `false`, we fallback - * to old style non-standard behaviors. + * described in Parquet format spec. This argument only affects Parquet read path. + * @param writeLegacyParquetFormat Whether to use legacy Parquet format compatible with Spark 1.4 + * and prior versions when converting a Catalyst [[StructType]] to a Parquet [[MessageType]]. + * When set to false, use standard format defined in parquet-format spec. This argument only + * affects Parquet write path. 
*/ private[parquet] class CatalystSchemaConverter( - private val assumeBinaryIsString: Boolean, - private val assumeInt96IsTimestamp: Boolean, - private val followParquetFormatSpec: Boolean) { - - // Only used when constructing converter for converting Spark SQL schema to Parquet schema, in - // which case `assumeInt96IsTimestamp` and `assumeBinaryIsString` are irrelevant. - def this() = this( - assumeBinaryIsString = SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get, - assumeInt96IsTimestamp = SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get, - followParquetFormatSpec = SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.defaultValue.get) + assumeBinaryIsString: Boolean = SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get, + assumeInt96IsTimestamp: Boolean = SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get, + writeLegacyParquetFormat: Boolean = SQLConf.PARQUET_WRITE_LEGACY_FORMAT.defaultValue.get) { def this(conf: SQLConf) = this( assumeBinaryIsString = conf.isParquetBinaryAsString, assumeInt96IsTimestamp = conf.isParquetINT96AsTimestamp, - followParquetFormatSpec = conf.followParquetFormatSpec) + writeLegacyParquetFormat = conf.writeLegacyParquetFormat) def this(conf: Configuration) = this( - assumeBinaryIsString = - conf.getBoolean( - SQLConf.PARQUET_BINARY_AS_STRING.key, - SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get), - assumeInt96IsTimestamp = - conf.getBoolean( - SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, - SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get), - followParquetFormatSpec = - conf.getBoolean( - SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key, - SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.defaultValue.get)) + assumeBinaryIsString = conf.get(SQLConf.PARQUET_BINARY_AS_STRING.key).toBoolean, + assumeInt96IsTimestamp = conf.get(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key).toBoolean, + writeLegacyParquetFormat = conf.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key).toBoolean) /** * Converts Parquet [[MessageType]] `parquetSchema` to a Spark SQL [[StructType]]. @@ -90,7 +73,7 @@ private[parquet] class CatalystSchemaConverter( def convert(parquetSchema: MessageType): StructType = convert(parquetSchema.asGroupType()) private def convert(parquetSchema: GroupType): StructType = { - val fields = parquetSchema.getFields.map { field => + val fields = parquetSchema.getFields.asScala.map { field => field.getRepetition match { case OPTIONAL => StructField(field.getName, convertField(field), nullable = true) @@ -99,8 +82,11 @@ private[parquet] class CatalystSchemaConverter( StructField(field.getName, convertField(field), nullable = false) case REPEATED => - throw new AnalysisException( - s"REPEATED not supported outside LIST or MAP. Type: $field") + // A repeated field that is neither contained by a `LIST`- or `MAP`-annotated group nor + // annotated by `LIST` or `MAP` should be interpreted as a required list of required + // elements where the element type is the type of the field. 
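Under the new rule, an unannotated repeated field such as `repeated int32 values;` becomes a required list of required elements; the hunk continues below with the corresponding `ArrayType` construction. A small sketch of the resulting Catalyst field shape (field name is illustrative):

```scala
import org.apache.spark.sql.types._

// Sketch: the Catalyst shape produced for an unannotated `repeated int32 values;`
// field under the new rule -- a required list of required elements.
val valuesField = StructField(
  "values",
  ArrayType(IntegerType, containsNull = false),  // required elements
  nullable = false)                              // required list

println(valuesField)
```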
+ val arrayType = ArrayType(convertField(field), containsNull = false) + StructField(field.getName, arrayType, nullable = false) } } @@ -122,6 +108,9 @@ private[parquet] class CatalystSchemaConverter( def typeString = if (originalType == null) s"$typeName" else s"$typeName ($originalType)" + def typeNotSupported() = + throw new AnalysisException(s"Parquet type not supported: $typeString") + def typeNotImplemented() = throw new AnalysisException(s"Parquet type not yet supported: $typeString") @@ -135,7 +124,7 @@ private[parquet] class CatalystSchemaConverter( val precision = field.getDecimalMetadata.getPrecision val scale = field.getDecimalMetadata.getScale - CatalystSchemaConverter.analysisRequire( + CatalystSchemaConverter.checkConversionRequirement( maxPrecision == -1 || 1 <= precision && precision <= maxPrecision, s"Invalid decimal precision: $typeName cannot store $precision digits (max $maxPrecision)") @@ -155,7 +144,10 @@ private[parquet] class CatalystSchemaConverter( case INT_16 => ShortType case INT_32 | null => IntegerType case DATE => DateType - case DECIMAL => makeDecimalType(maxPrecisionForBytes(4)) + case DECIMAL => makeDecimalType(MAX_PRECISION_FOR_INT32) + case UINT_8 => typeNotSupported() + case UINT_16 => typeNotSupported() + case UINT_32 => typeNotSupported() case TIME_MILLIS => typeNotImplemented() case _ => illegalType() } @@ -163,13 +155,14 @@ private[parquet] class CatalystSchemaConverter( case INT64 => originalType match { case INT_64 | null => LongType - case DECIMAL => makeDecimalType(maxPrecisionForBytes(8)) + case DECIMAL => makeDecimalType(MAX_PRECISION_FOR_INT64) + case UINT_64 => typeNotSupported() case TIMESTAMP_MILLIS => typeNotImplemented() case _ => illegalType() } case INT96 => - CatalystSchemaConverter.analysisRequire( + CatalystSchemaConverter.checkConversionRequirement( assumeInt96IsTimestamp, "INT96 is not supported unless it's interpreted as timestamp. 
" + s"Please try to set ${SQLConf.PARQUET_INT96_AS_TIMESTAMP.key} to true.") @@ -177,9 +170,10 @@ private[parquet] class CatalystSchemaConverter( case BINARY => originalType match { - case UTF8 | ENUM => StringType + case UTF8 | ENUM | JSON => StringType case null if assumeBinaryIsString => StringType case null => BinaryType + case BSON => BinaryType case DECIMAL => makeDecimalType() case _ => illegalType() } @@ -211,11 +205,11 @@ private[parquet] class CatalystSchemaConverter( // // See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#lists case LIST => - CatalystSchemaConverter.analysisRequire( + CatalystSchemaConverter.checkConversionRequirement( field.getFieldCount == 1, s"Invalid list type $field") val repeatedType = field.getType(0) - CatalystSchemaConverter.analysisRequire( + CatalystSchemaConverter.checkConversionRequirement( repeatedType.isRepetition(REPEATED), s"Invalid list type $field") if (isElementType(repeatedType, field.getName)) { @@ -231,17 +225,17 @@ private[parquet] class CatalystSchemaConverter( // See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#backward-compatibility-rules-1 // scalastyle:on case MAP | MAP_KEY_VALUE => - CatalystSchemaConverter.analysisRequire( + CatalystSchemaConverter.checkConversionRequirement( field.getFieldCount == 1 && !field.getType(0).isPrimitive, s"Invalid map type: $field") val keyValueType = field.getType(0).asGroupType() - CatalystSchemaConverter.analysisRequire( + CatalystSchemaConverter.checkConversionRequirement( keyValueType.isRepetition(REPEATED) && keyValueType.getFieldCount == 2, s"Invalid map type: $field") val keyType = keyValueType.getType(0) - CatalystSchemaConverter.analysisRequire( + CatalystSchemaConverter.checkConversionRequirement( keyType.isPrimitive, s"Map key type is expected to be a primitive type, but found: $keyType") @@ -313,7 +307,10 @@ private[parquet] class CatalystSchemaConverter( * Converts a Spark SQL [[StructType]] to a Parquet [[MessageType]]. */ def convert(catalystSchema: StructType): MessageType = { - Types.buildMessage().addFields(catalystSchema.map(convertField): _*).named("root") + Types + .buildMessage() + .addFields(catalystSchema.map(convertField): _*) + .named(CatalystSchemaConverter.SPARK_PARQUET_SCHEMA_NAME) } /** @@ -361,10 +358,10 @@ private[parquet] class CatalystSchemaConverter( // NOTE: Spark SQL TimestampType is NOT a well defined type in Parquet format spec. // // As stated in PARQUET-323, Parquet `INT96` was originally introduced to represent nanosecond - // timestamp in Impala for some historical reasons, it's not recommended to be used for any - // other types and will probably be deprecated in future Parquet format spec. That's the - // reason why Parquet format spec only defines `TIMESTAMP_MILLIS` and `TIMESTAMP_MICROS` which - // are both logical types annotating `INT64`. + // timestamp in Impala for some historical reasons. It's not recommended to be used for any + // other types and will probably be deprecated in some future version of parquet-format spec. + // That's the reason why parquet-format spec only defines `TIMESTAMP_MILLIS` and + // `TIMESTAMP_MICROS` which are both logical types annotating `INT64`. // // Originally, Spark SQL uses the same nanosecond timestamp type as Impala and Hive. 
Starting // from Spark 1.5.0, we resort to a timestamp type with 100 ns precision so that we can store @@ -375,22 +372,22 @@ private[parquet] class CatalystSchemaConverter( // currently not implemented yet because parquet-mr 1.7.0 (the version we're currently using) // hasn't implemented `TIMESTAMP_MICROS` yet. // - // TODO Implements `TIMESTAMP_MICROS` once parquet-mr has that. + // TODO Converts `TIMESTAMP_MICROS` once parquet-mr implements that. case TimestampType => Types.primitive(INT96, repetition).named(field.name) case BinaryType => Types.primitive(BINARY, repetition).named(field.name) - // ===================================== - // Decimals (for Spark version <= 1.4.x) - // ===================================== + // ====================== + // Decimals (legacy mode) + // ====================== // Spark 1.4.x and prior versions only support decimals with a maximum precision of 18 and // always store decimals in fixed-length byte arrays. To keep compatibility with these older // versions, here we convert decimals with all precisions to `FIXED_LEN_BYTE_ARRAY` annotated // by `DECIMAL`. - case DecimalType.Fixed(precision, scale) if !followParquetFormatSpec => + case DecimalType.Fixed(precision, scale) if writeLegacyParquetFormat => Types .primitive(FIXED_LEN_BYTE_ARRAY, repetition) .as(DECIMAL) @@ -399,13 +396,13 @@ private[parquet] class CatalystSchemaConverter( .length(CatalystSchemaConverter.minBytesForPrecision(precision)) .named(field.name) - // ===================================== - // Decimals (follow Parquet format spec) - // ===================================== + // ======================== + // Decimals (standard mode) + // ======================== // Uses INT32 for 1 <= precision <= 9 case DecimalType.Fixed(precision, scale) - if precision <= maxPrecisionForBytes(4) && followParquetFormatSpec => + if precision <= MAX_PRECISION_FOR_INT32 && !writeLegacyParquetFormat => Types .primitive(INT32, repetition) .as(DECIMAL) @@ -415,7 +412,7 @@ private[parquet] class CatalystSchemaConverter( // Uses INT64 for 1 <= precision <= 18 case DecimalType.Fixed(precision, scale) - if precision <= maxPrecisionForBytes(8) && followParquetFormatSpec => + if precision <= MAX_PRECISION_FOR_INT64 && !writeLegacyParquetFormat => Types .primitive(INT64, repetition) .as(DECIMAL) @@ -424,7 +421,7 @@ private[parquet] class CatalystSchemaConverter( .named(field.name) // Uses FIXED_LEN_BYTE_ARRAY for all other precisions - case DecimalType.Fixed(precision, scale) if followParquetFormatSpec => + case DecimalType.Fixed(precision, scale) if !writeLegacyParquetFormat => Types .primitive(FIXED_LEN_BYTE_ARRAY, repetition) .as(DECIMAL) @@ -433,17 +430,18 @@ private[parquet] class CatalystSchemaConverter( .length(CatalystSchemaConverter.minBytesForPrecision(precision)) .named(field.name) - // =================================================== - // ArrayType and MapType (for Spark versions <= 1.4.x) - // =================================================== + // =================================== + // ArrayType and MapType (legacy mode) + // =================================== - // Spark 1.4.x and prior versions convert ArrayType with nullable elements into a 3-level - // LIST structure. This behavior mimics parquet-hive (1.6.0rc3). Note that this case is - // covered by the backwards-compatibility rules implemented in `isElementType()`. 
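The decimal cases in this hunk pick the Parquet physical type from the precision and the `writeLegacyParquetFormat` flag. A plain-function sketch of that selection (names are illustrative, not Spark API):

```scala
// Sketch of the physical-type selection the schema converter applies to decimals.
def parquetPhysicalTypeForDecimal(precision: Int, writeLegacyFormat: Boolean): String =
  if (writeLegacyFormat) "FIXED_LEN_BYTE_ARRAY"  // legacy mode: always fixed-length binary
  else if (precision <= 9) "INT32"               // standard mode, 1 <= precision <= 9
  else if (precision <= 18) "INT64"              // standard mode, 10 <= precision <= 18
  else "FIXED_LEN_BYTE_ARRAY"                    // standard mode, larger precisions

assert(parquetPhysicalTypeForDecimal(9, writeLegacyFormat = false) == "INT32")
assert(parquetPhysicalTypeForDecimal(18, writeLegacyFormat = true) == "FIXED_LEN_BYTE_ARRAY")
```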
- case ArrayType(elementType, nullable @ true) if !followParquetFormatSpec => + // Spark 1.4.x and prior versions convert `ArrayType` with nullable elements into a 3-level + // `LIST` structure. This behavior is somewhat a hybrid of parquet-hive and parquet-avro + // (1.6.0rc3): the 3-level structure is similar to parquet-hive while the 3rd level element + // field name "array" is borrowed from parquet-avro. + case ArrayType(elementType, nullable @ true) if writeLegacyParquetFormat => // group (LIST) { // optional group bag { - // repeated element; + // repeated array; // } // } ConversionPatterns.listType( @@ -452,13 +450,13 @@ private[parquet] class CatalystSchemaConverter( Types .buildGroup(REPEATED) // "array_element" is the name chosen by parquet-hive (1.7.0 and prior version) - .addField(convertField(StructField("array_element", elementType, nullable))) - .named(CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME)) + .addField(convertField(StructField("array", elementType, nullable))) + .named("bag")) // Spark 1.4.x and prior versions convert ArrayType with non-nullable elements into a 2-level // LIST structure. This behavior mimics parquet-avro (1.6.0rc3). Note that this case is // covered by the backwards-compatibility rules implemented in `isElementType()`. - case ArrayType(elementType, nullable @ false) if !followParquetFormatSpec => + case ArrayType(elementType, nullable @ false) if writeLegacyParquetFormat => // group (LIST) { // repeated element; // } @@ -470,7 +468,7 @@ private[parquet] class CatalystSchemaConverter( // Spark 1.4.x and prior versions convert MapType into a 3-level group annotated by // MAP_KEY_VALUE. This is covered by `convertGroupField(field: GroupType): DataType`. - case MapType(keyType, valueType, valueContainsNull) if !followParquetFormatSpec => + case MapType(keyType, valueType, valueContainsNull) if writeLegacyParquetFormat => // group (MAP) { // repeated group map (MAP_KEY_VALUE) { // required key; @@ -483,11 +481,11 @@ private[parquet] class CatalystSchemaConverter( convertField(StructField("key", keyType, nullable = false)), convertField(StructField("value", valueType, valueContainsNull))) - // ================================================== - // ArrayType and MapType (follow Parquet format spec) - // ================================================== + // ===================================== + // ArrayType and MapType (standard mode) + // ===================================== - case ArrayType(elementType, containsNull) if followParquetFormatSpec => + case ArrayType(elementType, containsNull) if !writeLegacyParquetFormat => // group (LIST) { // repeated group list { // element; @@ -534,21 +532,14 @@ private[parquet] class CatalystSchemaConverter( throw new AnalysisException(s"Unsupported data type $field.dataType") } } - - // Max precision of a decimal value stored in `numBytes` bytes - private def maxPrecisionForBytes(numBytes: Int): Int = { - Math.round( // convert double to long - Math.floor(Math.log10( // number of base-10 digits - Math.pow(2, 8 * numBytes - 1) - 1))) // max value stored in numBytes - .asInstanceOf[Int] - } } - private[parquet] object CatalystSchemaConverter { + val SPARK_PARQUET_SCHEMA_NAME = "spark_schema" + def checkFieldName(name: String): Unit = { // ,;{}()\n\t= and space are special characters in Parquet schema - analysisRequire( + checkConversionRequirement( !name.matches(".*[ ,;{}()\n\t=].*"), s"""Attribute name "$name" contains invalid character(s) among " ,;{}()\\n\\t=". |Please use alias to rename it. 
@@ -560,7 +551,7 @@ private[parquet] object CatalystSchemaConverter { schema } - def analysisRequire(f: => Boolean, message: String): Unit = { + def checkConversionRequirement(f: => Boolean, message: String): Unit = { if (!f) { throw new AnalysisException(message) } @@ -574,14 +565,18 @@ private[parquet] object CatalystSchemaConverter { numBytes } - private val MIN_BYTES_FOR_PRECISION = Array.tabulate[Int](39)(computeMinBytesForPrecision) - // Returns the minimum number of bytes needed to store a decimal with a given `precision`. - def minBytesForPrecision(precision : Int) : Int = { - if (precision < MIN_BYTES_FOR_PRECISION.length) { - MIN_BYTES_FOR_PRECISION(precision) - } else { - computeMinBytesForPrecision(precision) - } + val minBytesForPrecision = Array.tabulate[Int](39)(computeMinBytesForPrecision) + + val MAX_PRECISION_FOR_INT32 = maxPrecisionForBytes(4) /* 9 */ + + val MAX_PRECISION_FOR_INT64 = maxPrecisionForBytes(8) /* 18 */ + + // Max precision of a decimal value stored in `numBytes` bytes + def maxPrecisionForBytes(numBytes: Int): Int = { + Math.round( // convert double to long + Math.floor(Math.log10( // number of base-10 digits + Math.pow(2, 8 * numBytes - 1) - 1))) // max value stored in numBytes + .asInstanceOf[Int] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala new file mode 100644 index 000000000000..6862dea5e6c3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala @@ -0,0 +1,436 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.nio.{ByteBuffer, ByteOrder} +import java.util + +import scala.collection.JavaConverters.mapAsJavaMapConverter + +import org.apache.hadoop.conf.Configuration +import org.apache.parquet.column.ParquetProperties +import org.apache.parquet.hadoop.ParquetOutputFormat +import org.apache.parquet.hadoop.api.WriteSupport +import org.apache.parquet.hadoop.api.WriteSupport.WriteContext +import org.apache.parquet.io.api.{Binary, RecordConsumer} + +import org.apache.spark.Logging +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.{MAX_PRECISION_FOR_INT32, MAX_PRECISION_FOR_INT64, minBytesForPrecision} +import org.apache.spark.sql.types._ + +/** + * A Parquet [[WriteSupport]] implementation that writes Catalyst [[InternalRow]]s as Parquet + * messages. 
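`MAX_PRECISION_FOR_INT32` and `MAX_PRECISION_FOR_INT64` come from the `maxPrecisionForBytes` bound above: the largest decimal precision that always fits in an n-byte signed integer. A sketch of the same calculation (using `toInt` instead of the round/`asInstanceOf` sequence):

```scala
// Largest number of base-10 digits that always fits in an n-byte signed integer.
def maxPrecisionForBytes(numBytes: Int): Int =
  Math.floor(Math.log10(Math.pow(2, 8 * numBytes - 1) - 1)).toInt

assert(maxPrecisionForBytes(4) == 9)   // INT32 can hold any 9-digit decimal
assert(maxPrecisionForBytes(8) == 18)  // INT64 can hold any 18-digit decimal
```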
This class can write Parquet data in two modes: + * + * - Standard mode: Parquet data are written in standard format defined in parquet-format spec. + * - Legacy mode: Parquet data are written in legacy format compatible with Spark 1.4 and prior. + * + * This behavior can be controlled by SQL option `spark.sql.parquet.writeLegacyFormat`. The value + * of this option is propagated to this class by the `init()` method and its Hadoop configuration + * argument. + */ +private[parquet] class CatalystWriteSupport extends WriteSupport[InternalRow] with Logging { + // A `ValueWriter` is responsible for writing a field of an `InternalRow` to the record consumer. + // Here we are using `SpecializedGetters` rather than `InternalRow` so that we can directly access + // data in `ArrayData` without the help of `SpecificMutableRow`. + private type ValueWriter = (SpecializedGetters, Int) => Unit + + // Schema of the `InternalRow`s to be written + private var schema: StructType = _ + + // `ValueWriter`s for all fields of the schema + private var rootFieldWriters: Seq[ValueWriter] = _ + + // The Parquet `RecordConsumer` to which all `InternalRow`s are written + private var recordConsumer: RecordConsumer = _ + + // Whether to write data in legacy Parquet format compatible with Spark 1.4 and prior versions + private var writeLegacyParquetFormat: Boolean = _ + + // Reusable byte array used to write timestamps as Parquet INT96 values + private val timestampBuffer = new Array[Byte](12) + + // Reusable byte array used to write decimal values + private val decimalBuffer = new Array[Byte](minBytesForPrecision(DecimalType.MAX_PRECISION)) + + override def init(configuration: Configuration): WriteContext = { + val schemaString = configuration.get(CatalystWriteSupport.SPARK_ROW_SCHEMA) + this.schema = StructType.fromString(schemaString) + this.writeLegacyParquetFormat = { + // `SQLConf.PARQUET_WRITE_LEGACY_FORMAT` should always be explicitly set in ParquetRelation + assert(configuration.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key) != null) + configuration.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key).toBoolean + } + this.rootFieldWriters = schema.map(_.dataType).map(makeWriter) + + val messageType = new CatalystSchemaConverter(configuration).convert(schema) + val metadata = Map(CatalystReadSupport.SPARK_METADATA_KEY -> schemaString).asJava + + logInfo( + s"""Initialized Parquet WriteSupport with Catalyst schema: + |${schema.prettyJson} + |and corresponding Parquet message type: + |$messageType + """.stripMargin) + + new WriteContext(messageType, metadata) + } + + override def prepareForWrite(recordConsumer: RecordConsumer): Unit = { + this.recordConsumer = recordConsumer + } + + override def write(row: InternalRow): Unit = { + consumeMessage { + writeFields(row, schema, rootFieldWriters) + } + } + + private def writeFields( + row: InternalRow, schema: StructType, fieldWriters: Seq[ValueWriter]): Unit = { + var i = 0 + while (i < row.numFields) { + if (!row.isNullAt(i)) { + consumeField(schema(i).name, i) { + fieldWriters(i).apply(row, i) + } + } + i += 1 + } + } + + private def makeWriter(dataType: DataType): ValueWriter = { + dataType match { + case BooleanType => + (row: SpecializedGetters, ordinal: Int) => + recordConsumer.addBoolean(row.getBoolean(ordinal)) + + case ByteType => + (row: SpecializedGetters, ordinal: Int) => + recordConsumer.addInteger(row.getByte(ordinal)) + + case ShortType => + (row: SpecializedGetters, ordinal: Int) => + recordConsumer.addInteger(row.getShort(ordinal)) + + case IntegerType | 
DateType => + (row: SpecializedGetters, ordinal: Int) => + recordConsumer.addInteger(row.getInt(ordinal)) + + case LongType => + (row: SpecializedGetters, ordinal: Int) => + recordConsumer.addLong(row.getLong(ordinal)) + + case FloatType => + (row: SpecializedGetters, ordinal: Int) => + recordConsumer.addFloat(row.getFloat(ordinal)) + + case DoubleType => + (row: SpecializedGetters, ordinal: Int) => + recordConsumer.addDouble(row.getDouble(ordinal)) + + case StringType => + (row: SpecializedGetters, ordinal: Int) => + recordConsumer.addBinary(Binary.fromByteArray(row.getUTF8String(ordinal).getBytes)) + + case TimestampType => + (row: SpecializedGetters, ordinal: Int) => { + // TODO Writes `TimestampType` values as `TIMESTAMP_MICROS` once parquet-mr implements it + // Currently we only support timestamps stored as INT96, which is compatible with Hive + // and Impala. However, INT96 is to be deprecated. We plan to support `TIMESTAMP_MICROS` + // defined in the parquet-format spec. But up until writing, the most recent parquet-mr + // version (1.8.1) hasn't implemented it yet. + + // NOTE: Starting from Spark 1.5, Spark SQL `TimestampType` only has microsecond + // precision. Nanosecond parts of timestamp values read from INT96 are simply stripped. + val (julianDay, timeOfDayNanos) = DateTimeUtils.toJulianDay(row.getLong(ordinal)) + val buf = ByteBuffer.wrap(timestampBuffer) + buf.order(ByteOrder.LITTLE_ENDIAN).putLong(timeOfDayNanos).putInt(julianDay) + recordConsumer.addBinary(Binary.fromByteArray(timestampBuffer)) + } + + case BinaryType => + (row: SpecializedGetters, ordinal: Int) => + recordConsumer.addBinary(Binary.fromByteArray(row.getBinary(ordinal))) + + case DecimalType.Fixed(precision, scale) => + makeDecimalWriter(precision, scale) + + case t: StructType => + val fieldWriters = t.map(_.dataType).map(makeWriter) + (row: SpecializedGetters, ordinal: Int) => + consumeGroup { + writeFields(row.getStruct(ordinal, t.length), t, fieldWriters) + } + + case t: ArrayType => makeArrayWriter(t) + + case t: MapType => makeMapWriter(t) + + case t: UserDefinedType[_] => makeWriter(t.sqlType) + + // TODO Adds IntervalType support + case _ => sys.error(s"Unsupported data type $dataType.") + } + } + + private def makeDecimalWriter(precision: Int, scale: Int): ValueWriter = { + assert( + precision <= DecimalType.MAX_PRECISION, + s"Decimal precision $precision exceeds max precision ${DecimalType.MAX_PRECISION}") + + val numBytes = minBytesForPrecision(precision) + + val int32Writer = + (row: SpecializedGetters, ordinal: Int) => { + val unscaledLong = row.getDecimal(ordinal, precision, scale).toUnscaledLong + recordConsumer.addInteger(unscaledLong.toInt) + } + + val int64Writer = + (row: SpecializedGetters, ordinal: Int) => { + val unscaledLong = row.getDecimal(ordinal, precision, scale).toUnscaledLong + recordConsumer.addLong(unscaledLong) + } + + val binaryWriterUsingUnscaledLong = + (row: SpecializedGetters, ordinal: Int) => { + // When the precision is low enough (<= 18) to squeeze the decimal value into a `Long`, we + // can build a fixed-length byte array with length `numBytes` using the unscaled `Long` + // value and the `decimalBuffer` for better performance. 
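As the comment just above notes, a decimal with precision at most 18 can be written as a fixed-length byte array built directly from its unscaled `Long`; the writer body follows in the hunk. A standalone sketch of that big-endian encoding (allocating a fresh array instead of reusing `decimalBuffer`):

```scala
// Standalone sketch of the "unscaled Long to fixed-length big-endian bytes" encoding
// described above; the actual writer reuses a shared decimalBuffer, here we allocate.
def unscaledLongToBytes(unscaled: Long, numBytes: Int): Array[Byte] = {
  val out = new Array[Byte](numBytes)
  var i = 0
  var shift = 8 * (numBytes - 1)
  while (i < numBytes) {
    out(i) = (unscaled >> shift).toByte  // most significant byte first
    i += 1
    shift -= 8
  }
  out
}

// Round-trips with the big-endian decoding sketched earlier for CatalystRowConverter.
assert(unscaledLongToBytes(-129L, 2).toSeq == Seq(0xff.toByte, 0x7f.toByte))
```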
+ val unscaled = row.getDecimal(ordinal, precision, scale).toUnscaledLong + var i = 0 + var shift = 8 * (numBytes - 1) + + while (i < numBytes) { + decimalBuffer(i) = (unscaled >> shift).toByte + i += 1 + shift -= 8 + } + + recordConsumer.addBinary(Binary.fromByteArray(decimalBuffer, 0, numBytes)) + } + + val binaryWriterUsingUnscaledBytes = + (row: SpecializedGetters, ordinal: Int) => { + val decimal = row.getDecimal(ordinal, precision, scale) + val bytes = decimal.toJavaBigDecimal.unscaledValue().toByteArray + val fixedLengthBytes = if (bytes.length == numBytes) { + // If the length of the underlying byte array of the unscaled `BigInteger` happens to be + // `numBytes`, just reuse it, so that we don't bother copying it to `decimalBuffer`. + bytes + } else { + // Otherwise, the length must be less than `numBytes`. In this case we copy contents of + // the underlying bytes with padding sign bytes to `decimalBuffer` to form the result + // fixed-length byte array. + val signByte = if (bytes.head < 0) -1: Byte else 0: Byte + util.Arrays.fill(decimalBuffer, 0, numBytes - bytes.length, signByte) + System.arraycopy(bytes, 0, decimalBuffer, numBytes - bytes.length, bytes.length) + decimalBuffer + } + + recordConsumer.addBinary(Binary.fromByteArray(fixedLengthBytes, 0, numBytes)) + } + + writeLegacyParquetFormat match { + // Standard mode, 1 <= precision <= 9, writes as INT32 + case false if precision <= MAX_PRECISION_FOR_INT32 => int32Writer + + // Standard mode, 10 <= precision <= 18, writes as INT64 + case false if precision <= MAX_PRECISION_FOR_INT64 => int64Writer + + // Legacy mode, 1 <= precision <= 18, writes as FIXED_LEN_BYTE_ARRAY + case true if precision <= MAX_PRECISION_FOR_INT64 => binaryWriterUsingUnscaledLong + + // Either standard or legacy mode, 19 <= precision <= 38, writes as FIXED_LEN_BYTE_ARRAY + case _ => binaryWriterUsingUnscaledBytes + } + } + + def makeArrayWriter(arrayType: ArrayType): ValueWriter = { + val elementWriter = makeWriter(arrayType.elementType) + + def threeLevelArrayWriter(repeatedGroupName: String, elementFieldName: String): ValueWriter = + (row: SpecializedGetters, ordinal: Int) => { + val array = row.getArray(ordinal) + consumeGroup { + // Only creates the repeated field if the array is non-empty. + if (array.numElements() > 0) { + consumeField(repeatedGroupName, 0) { + var i = 0 + while (i < array.numElements()) { + consumeGroup { + // Only creates the element field if the current array element is not null. + if (!array.isNullAt(i)) { + consumeField(elementFieldName, 0) { + elementWriter.apply(array, i) + } + } + } + i += 1 + } + } + } + } + } + + def twoLevelArrayWriter(repeatedFieldName: String): ValueWriter = + (row: SpecializedGetters, ordinal: Int) => { + val array = row.getArray(ordinal) + consumeGroup { + // Only creates the repeated field if the array is non-empty. 
+ if (array.numElements() > 0) { + consumeField(repeatedFieldName, 0) { + var i = 0 + while (i < array.numElements()) { + elementWriter.apply(array, i) + i += 1 + } + } + } + } + } + + (writeLegacyParquetFormat, arrayType.containsNull) match { + case (legacyMode @ false, _) => + // Standard mode: + // + // group (LIST) { + // repeated group list { + // ^~~~ repeatedGroupName + // element; + // ^~~~~~~ elementFieldName + // } + // } + threeLevelArrayWriter(repeatedGroupName = "list", elementFieldName = "element") + + case (legacyMode @ true, nullableElements @ true) => + // Legacy mode, with nullable elements: + // + // group (LIST) { + // optional group bag { + // ^~~ repeatedGroupName + // repeated array; + // ^~~~~ elementFieldName + // } + // } + threeLevelArrayWriter(repeatedGroupName = "bag", elementFieldName = "array") + + case (legacyMode @ true, nullableElements @ false) => + // Legacy mode, with non-nullable elements: + // + // group (LIST) { + // repeated array; + // ^~~~~ repeatedFieldName + // } + twoLevelArrayWriter(repeatedFieldName = "array") + } + } + + private def makeMapWriter(mapType: MapType): ValueWriter = { + val keyWriter = makeWriter(mapType.keyType) + val valueWriter = makeWriter(mapType.valueType) + val repeatedGroupName = if (writeLegacyParquetFormat) { + // Legacy mode: + // + // group (MAP) { + // repeated group map (MAP_KEY_VALUE) { + // ^~~ repeatedGroupName + // required key; + // value; + // } + // } + "map" + } else { + // Standard mode: + // + // group (MAP) { + // repeated group key_value { + // ^~~~~~~~~ repeatedGroupName + // required key; + // value; + // } + // } + "key_value" + } + + (row: SpecializedGetters, ordinal: Int) => { + val map = row.getMap(ordinal) + val keyArray = map.keyArray() + val valueArray = map.valueArray() + + consumeGroup { + // Only creates the repeated field if the map is non-empty. 
+ if (map.numElements() > 0) { + consumeField(repeatedGroupName, 0) { + var i = 0 + while (i < map.numElements()) { + consumeGroup { + consumeField("key", 0) { + keyWriter.apply(keyArray, i) + } + + // Only creates the "value" field if the value if non-empty + if (!map.valueArray().isNullAt(i)) { + consumeField("value", 1) { + valueWriter.apply(valueArray, i) + } + } + } + i += 1 + } + } + } + } + } + } + + private def consumeMessage(f: => Unit): Unit = { + recordConsumer.startMessage() + f + recordConsumer.endMessage() + } + + private def consumeGroup(f: => Unit): Unit = { + recordConsumer.startGroup() + f + recordConsumer.endGroup() + } + + private def consumeField(field: String, index: Int)(f: => Unit): Unit = { + recordConsumer.startField(field, index) + f + recordConsumer.endField(field, index) + } +} + +private[parquet] object CatalystWriteSupport { + val SPARK_ROW_SCHEMA: String = "org.apache.spark.sql.parquet.row.attributes" + + def setSchema(schema: StructType, configuration: Configuration): Unit = { + schema.map(_.name).foreach(CatalystSchemaConverter.checkFieldName) + configuration.set(SPARK_ROW_SCHEMA, schema.json) + configuration.setIfUnset( + ParquetOutputFormat.WRITER_VERSION, + ParquetProperties.WriterVersion.PARQUET_1_0.toString) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala similarity index 91% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala index 1551afd7b7bf..1a4e99ff10af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -39,9 +39,10 @@ import org.apache.parquet.hadoop.{ParquetFileReader, ParquetFileWriter, ParquetO * * NEVER use [[DirectParquetOutputCommitter]] when appending data, because currently there's * no safe way undo a failed appending job (that's why both `abortTask()` and `abortJob()` are - * left * empty). + * left empty). 
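`CatalystWriteSupport.setSchema` above stores the Catalyst schema as JSON in the Hadoop `Configuration` under `SPARK_ROW_SCHEMA`, and `init()` parses it back. A sketch of that round trip, using the public `DataType.fromJson` in place of the private `StructType.fromString` (the key string is the one defined in the companion object above; the rest is illustrative):

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.spark.sql.types._

// Sketch: the writer-side schema travels through the Hadoop Configuration as JSON.
val schema = StructType(Seq(StructField("id", LongType), StructField("name", StringType)))

val conf = new Configuration()
conf.set("org.apache.spark.sql.parquet.row.attributes", schema.json)

val restored = DataType.fromJson(conf.get("org.apache.spark.sql.parquet.row.attributes"))
assert(restored == schema)
```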
*/ -private[parquet] class DirectParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) +private[datasources] class DirectParquetOutputCommitter( + outputPath: Path, context: TaskAttemptContext) extends ParquetOutputCommitter(outputPath, context) { val LOG = Log.getLog(classOf[ParquetOutputCommitter]) @@ -53,7 +54,11 @@ private[parquet] class DirectParquetOutputCommitter(outputPath: Path, context: T override def setupTask(taskContext: TaskAttemptContext): Unit = {} override def commitJob(jobContext: JobContext) { - val configuration = ContextUtil.getConfiguration(jobContext) + val configuration = { + // scalastyle:off jobcontext + ContextUtil.getConfiguration(jobContext) + // scalastyle:on jobcontext + } val fileSystem = outputPath.getFileSystem(configuration) if (configuration.getBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, true)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala similarity index 55% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index d57b789f5c1c..07714329370a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -15,35 +15,20 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.io.Serializable -import java.nio.ByteBuffer -import com.google.common.io.BaseEncoding -import org.apache.hadoop.conf.Configuration -import org.apache.parquet.filter2.compat.FilterCompat -import org.apache.parquet.filter2.compat.FilterCompat._ import org.apache.parquet.filter2.predicate.FilterApi._ -import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate, Statistics} -import org.apache.parquet.filter2.predicate.UserDefinedPredicate +import org.apache.parquet.filter2.predicate._ import org.apache.parquet.io.api.Binary +import org.apache.parquet.schema.OriginalType +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName -import org.apache.spark.SparkEnv -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.sources import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String private[sql] object ParquetFilters { - val PARQUET_FILTER_DATA = "org.apache.spark.sql.parquet.row.filter" - - def createRecordFilter(filterExpressions: Seq[Expression]): Option[Filter] = { - filterExpressions.flatMap { filter => - createFilter(filter) - }.reduceOption(FilterApi.and).map(FilterCompat.get) - } - case class SetInFilter[T <: Comparable[T]]( valueSet: Set[T]) extends UserDefinedPredicate[T] with Serializable { @@ -68,15 +53,18 @@ private[sql] object ParquetFilters { case DoubleType => (n: String, v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) + // See https://issues.apache.org/jira/browse/SPARK-11153 + /* // Binary.fromString and Binary.fromByteArray don't accept null values case StringType => (n: String, v: Any) => FilterApi.eq( binaryColumn(n), - Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull) + Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[String].getBytes("utf-8"))).orNull) case BinaryType => (n: String, v: Any) => FilterApi.eq( 
binaryColumn(n), Option(v).map(b => Binary.fromByteArray(v.asInstanceOf[Array[Byte]])).orNull) + */ } private val makeNotEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { @@ -90,14 +78,18 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) case DoubleType => (n: String, v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) + + // See https://issues.apache.org/jira/browse/SPARK-11153 + /* case StringType => (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), - Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull) + Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[String].getBytes("utf-8"))).orNull) case BinaryType => (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), Option(v).map(b => Binary.fromByteArray(v.asInstanceOf[Array[Byte]])).orNull) + */ } private val makeLt: PartialFunction[DataType, (String, Any) => FilterPredicate] = { @@ -109,12 +101,17 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.lt(floatColumn(n), v.asInstanceOf[java.lang.Float]) case DoubleType => (n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) + + // See https://issues.apache.org/jira/browse/SPARK-11153 + /* case StringType => (n: String, v: Any) => - FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) + FilterApi.lt(binaryColumn(n), + Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8"))) case BinaryType => (n: String, v: Any) => FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) + */ } private val makeLtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { @@ -126,12 +123,17 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.ltEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) case DoubleType => (n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) + + // See https://issues.apache.org/jira/browse/SPARK-11153 + /* case StringType => (n: String, v: Any) => - FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) + FilterApi.ltEq(binaryColumn(n), + Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8"))) case BinaryType => (n: String, v: Any) => FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) + */ } private val makeGt: PartialFunction[DataType, (String, Any) => FilterPredicate] = { @@ -143,12 +145,17 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.gt(floatColumn(n), v.asInstanceOf[java.lang.Float]) case DoubleType => (n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) + + // See https://issues.apache.org/jira/browse/SPARK-11153 + /* case StringType => (n: String, v: Any) => - FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) + FilterApi.gt(binaryColumn(n), + Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8"))) case BinaryType => (n: String, v: Any) => FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) + */ } private val makeGtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { @@ -160,12 +167,17 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.gtEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) case DoubleType => (n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), 
v.asInstanceOf[java.lang.Double]) + + // See https://issues.apache.org/jira/browse/SPARK-11153 + /* case StringType => (n: String, v: Any) => - FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) + FilterApi.gtEq(binaryColumn(n), + Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8"))) case BinaryType => (n: String, v: Any) => FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) + */ } private val makeInSet: PartialFunction[DataType, (String, Set[Any]) => FilterPredicate] = { @@ -181,14 +193,18 @@ private[sql] object ParquetFilters { case DoubleType => (n: String, v: Set[Any]) => FilterApi.userDefined(doubleColumn(n), SetInFilter(v.asInstanceOf[Set[java.lang.Double]])) + + // See https://issues.apache.org/jira/browse/SPARK-11153 + /* case StringType => (n: String, v: Set[Any]) => FilterApi.userDefined(binaryColumn(n), - SetInFilter(v.map(e => Binary.fromByteArray(e.asInstanceOf[UTF8String].getBytes)))) + SetInFilter(v.map(s => Binary.fromByteArray(s.asInstanceOf[String].getBytes("utf-8"))))) case BinaryType => (n: String, v: Set[Any]) => FilterApi.userDefined(binaryColumn(n), SetInFilter(v.map(e => Binary.fromByteArray(e.asInstanceOf[Array[Byte]])))) + */ } /** @@ -197,11 +213,23 @@ private[sql] object ParquetFilters { def createFilter(schema: StructType, predicate: sources.Filter): Option[FilterPredicate] = { val dataTypeOf = schema.map(f => f.name -> f.dataType).toMap + relaxParquetValidTypeMap + // NOTE: // // For any comparison operator `cmp`, both `a cmp NULL` and `NULL cmp a` evaluate to `NULL`, // which can be casted to `false` implicitly. Please refer to the `eval` method of these // operators and the `SimplifyFilters` rule for details. + + // Hyukjin: + // I added [[EqualNullSafe]] with [[org.apache.parquet.filter2.predicate.Operators.Eq]]. + // So, it performs equality comparison identically when given [[sources.Filter]] is [[EqualTo]]. + // The reason why I did this is, that the actual Parquet filter checks null-safe equality + // comparison. + // So I added this and maybe [[EqualTo]] should be changed. It still seems fine though, because + // physical planning does not set `NULL` to [[EqualTo]] but changes it to [[IsNull]] and etc. + // Probably I missed something and obviously this should be changed. + predicate match { case sources.IsNull(name) => makeEq.lift(dataTypeOf(name)).map(_(name, null)) @@ -213,6 +241,11 @@ private[sql] object ParquetFilters { case sources.Not(sources.EqualTo(name, value)) => makeNotEq.lift(dataTypeOf(name)).map(_(name, value)) + case sources.EqualNullSafe(name, value) => + makeEq.lift(dataTypeOf(name)).map(_(name, value)) + case sources.Not(sources.EqualNullSafe(name, value)) => + makeNotEq.lift(dataTypeOf(name)).map(_(name, value)) + case sources.LessThan(name, value) => makeLt.lift(dataTypeOf(name)).map(_(name, value)) case sources.LessThanOrEqual(name, value) => @@ -239,122 +272,34 @@ private[sql] object ParquetFilters { } } - /** - * Converts Catalyst predicate expressions to Parquet filter predicates. - * - * @todo This can be removed once we get rid of the old Parquet support. - */ - def createFilter(predicate: Expression): Option[FilterPredicate] = { - // NOTE: - // - // For any comparison operator `cmp`, both `a cmp NULL` and `NULL cmp a` evaluate to `NULL`, - // which can be casted to `false` implicitly. Please refer to the `eval` method of these - // operators and the `SimplifyFilters` rule for details. 
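`createFilter` above translates data-source `Filter`s into Parquet `FilterPredicate`s, and per the note in the hunk, `EqualNullSafe` maps onto Parquet's `eq` because that comparison is already null-safe. A hedged sketch of the same kind of mapping for a single IntegerType column (column name and helper are illustrative, not part of the patch):

```scala
import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate}
import org.apache.spark.sql.sources

// Sketch: turning a data-source predicate on one INT32 column into a Parquet predicate.
def pushDownAgeFilter(filter: sources.Filter): Option[FilterPredicate] = filter match {
  case sources.GreaterThan("age", value: Integer) =>
    Some(FilterApi.gt(FilterApi.intColumn("age"), value))
  case sources.EqualNullSafe("age", value: Integer) =>
    // Per the note above, Parquet's eq already performs a null-safe comparison.
    Some(FilterApi.eq(FilterApi.intColumn("age"), value))
  case _ => None
}

println(pushDownAgeFilter(sources.GreaterThan("age", 18: Integer)))
```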
- predicate match { - case IsNull(NamedExpression(name, dataType)) => - makeEq.lift(dataType).map(_(name, null)) - case IsNotNull(NamedExpression(name, dataType)) => - makeNotEq.lift(dataType).map(_(name, null)) - - case EqualTo(NamedExpression(name, _), NonNullLiteral(value, dataType)) => - makeEq.lift(dataType).map(_(name, value)) - case EqualTo(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => - makeEq.lift(dataType).map(_(name, value)) - case EqualTo(NonNullLiteral(value, dataType), NamedExpression(name, _)) => - makeEq.lift(dataType).map(_(name, value)) - case EqualTo(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => - makeEq.lift(dataType).map(_(name, value)) - - case Not(EqualTo(NamedExpression(name, _), NonNullLiteral(value, dataType))) => - makeNotEq.lift(dataType).map(_(name, value)) - case Not(EqualTo(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _))) => - makeNotEq.lift(dataType).map(_(name, value)) - case Not(EqualTo(NonNullLiteral(value, dataType), NamedExpression(name, _))) => - makeNotEq.lift(dataType).map(_(name, value)) - case Not(EqualTo(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType))) => - makeNotEq.lift(dataType).map(_(name, value)) - - case LessThan(NamedExpression(name, _), NonNullLiteral(value, dataType)) => - makeLt.lift(dataType).map(_(name, value)) - case LessThan(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => - makeLt.lift(dataType).map(_(name, value)) - case LessThan(NonNullLiteral(value, dataType), NamedExpression(name, _)) => - makeGt.lift(dataType).map(_(name, value)) - case LessThan(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => - makeGt.lift(dataType).map(_(name, value)) - - case LessThanOrEqual(NamedExpression(name, _), NonNullLiteral(value, dataType)) => - makeLtEq.lift(dataType).map(_(name, value)) - case LessThanOrEqual(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => - makeLtEq.lift(dataType).map(_(name, value)) - case LessThanOrEqual(NonNullLiteral(value, dataType), NamedExpression(name, _)) => - makeGtEq.lift(dataType).map(_(name, value)) - case LessThanOrEqual(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => - makeGtEq.lift(dataType).map(_(name, value)) - - case GreaterThan(NamedExpression(name, _), NonNullLiteral(value, dataType)) => - makeGt.lift(dataType).map(_(name, value)) - case GreaterThan(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => - makeGt.lift(dataType).map(_(name, value)) - case GreaterThan(NonNullLiteral(value, dataType), NamedExpression(name, _)) => - makeLt.lift(dataType).map(_(name, value)) - case GreaterThan(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => - makeLt.lift(dataType).map(_(name, value)) - - case GreaterThanOrEqual(NamedExpression(name, _), NonNullLiteral(value, dataType)) => - makeGtEq.lift(dataType).map(_(name, value)) - case GreaterThanOrEqual(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => - makeGtEq.lift(dataType).map(_(name, value)) - case GreaterThanOrEqual(NonNullLiteral(value, dataType), NamedExpression(name, _)) => - makeLtEq.lift(dataType).map(_(name, value)) - case GreaterThanOrEqual(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => - makeLtEq.lift(dataType).map(_(name, value)) - - case And(lhs, rhs) => - (createFilter(lhs) ++ createFilter(rhs)).reduceOption(FilterApi.and) - - case Or(lhs, rhs) => - for { - lhsFilter <- createFilter(lhs) - rhsFilter 
<- createFilter(rhs) - } yield FilterApi.or(lhsFilter, rhsFilter) - - case Not(pred) => - createFilter(pred).map(FilterApi.not) - - case InSet(NamedExpression(name, dataType), valueSet) => - makeInSet.lift(dataType).map(_(name, valueSet)) - - case _ => None - } - } - - /** - * Note: Inside the Hadoop API we only have access to `Configuration`, not to - * [[org.apache.spark.SparkContext]], so we cannot use broadcasts to convey - * the actual filter predicate. - */ - def serializeFilterExpressions(filters: Seq[Expression], conf: Configuration): Unit = { - if (filters.nonEmpty) { - val serialized: Array[Byte] = - SparkEnv.get.closureSerializer.newInstance().serialize(filters).array() - val encoded: String = BaseEncoding.base64().encode(serialized) - conf.set(PARQUET_FILTER_DATA, encoded) - } - } - - /** - * Note: Inside the Hadoop API we only have access to `Configuration`, not to - * [[org.apache.spark.SparkContext]], so we cannot use broadcasts to convey - * the actual filter predicate. - */ - def deserializeFilterExpressions(conf: Configuration): Seq[Expression] = { - val data = conf.get(PARQUET_FILTER_DATA) - if (data != null) { - val decoded: Array[Byte] = BaseEncoding.base64().decode(data) - SparkEnv.get.closureSerializer.newInstance().deserialize(ByteBuffer.wrap(decoded)) - } else { - Seq() - } + // !! HACK ALERT !! + // + // This lazy val is a workaround for PARQUET-201, and should be removed once we upgrade to + // parquet-mr 1.8.1 or higher versions. + // + // In Parquet, not all types of columns can be used for filter push-down optimization. The set + // of valid column types is controlled by `ValidTypeMap`. Unfortunately, in parquet-mr 1.7.0 and + // prior versions, the limitation is too strict, and doesn't allow `BINARY (ENUM)` columns to be + // pushed down. + // + // This restriction is problematic for Spark SQL, because Spark SQL doesn't have a type that maps + // to Parquet original type `ENUM` directly, and always converts `ENUM` to `StringType`. Thus, + // a predicate involving a `ENUM` field can be pushed-down as a string column, which is perfectly + // legal except that it fails the `ValidTypeMap` check. + // + // Here we add `BINARY (ENUM)` into `ValidTypeMap` lazily via reflection to workaround this issue. + private lazy val relaxParquetValidTypeMap: Unit = { + val constructor = Class + .forName(classOf[ValidTypeMap].getCanonicalName + "$FullTypeDescriptor") + .getDeclaredConstructor(classOf[PrimitiveTypeName], classOf[OriginalType]) + + constructor.setAccessible(true) + val enumTypeDescriptor = constructor + .newInstance(PrimitiveTypeName.BINARY, OriginalType.ENUM) + .asInstanceOf[AnyRef] + + val addMethod = classOf[ValidTypeMap].getDeclaredMethods.find(_.getName == "add").get + addMethod.setAccessible(true) + addMethod.invoke(null, classOf[Binary], enumTypeDescriptor) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala similarity index 76% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index b4337a48dbd8..1af2a394f399 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -15,41 +15,46 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.net.URI -import java.util.logging.{Level, Logger => JLogger} +import java.util.logging.{Logger => JLogger} import java.util.{List => JList} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.{Failure, Try} import com.google.common.base.Objects +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.parquet.filter2.predicate.FilterApi +import org.apache.parquet.hadoop._ import org.apache.parquet.hadoop.metadata.CompressionCodecName import org.apache.parquet.hadoop.util.ContextUtil -import org.apache.parquet.hadoop.{ParquetOutputCommitter, ParquetRecordReader, _} import org.apache.parquet.schema.MessageType -import org.apache.parquet.{Log => ParquetLog} +import org.apache.parquet.{Log => ApacheParquetLog} +import org.slf4j.bridge.SLF4JBridgeHandler -import org.apache.spark.{Logging, Partition => SparkPartition, SparkException} import org.apache.spark.broadcast.Broadcast -import org.apache.spark.rdd.{SqlNewHadoopPartition, SqlNewHadoopRDD, RDD} -import org.apache.spark.rdd.RDD._ +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.rdd.{RDD, SqlNewHadoopPartition, SqlNewHadoopRDD} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.PartitionSpec import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.{Logging, Partition => SparkPartition, SparkException} + +private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { + + override def shortName(): String = "parquet" -private[sql] class DefaultSource extends HadoopFsRelationProvider { override def createRelation( sqlContext: SQLContext, paths: Array[String], @@ -62,7 +67,7 @@ private[sql] class DefaultSource extends HadoopFsRelationProvider { // NOTE: This class is instantiated and used on executor side only, no need to be serializable. private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext) - extends OutputWriterInternal { + extends OutputWriter { private val recordWriter: RecordWriter[Void, InternalRow] = { val outputFormat = { @@ -77,8 +82,10 @@ private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext // `FileOutputCommitter.getWorkPath()`, which points to the base directory of all // partitions in the case of dynamic partitioning. 
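Because DefaultSource now also extends DataSourceRegister (earlier in this hunk) with shortName "parquet", the relation can be addressed by that short name. A usage sketch, with made-up paths and an existing SQLContext assumed:

    // Both calls resolve to the same provider; the short name is what the
    // DataSourceRegister service registration makes possible.
    val byShortName = sqlContext.read.format("parquet").load("/path/to/data")
    val byClassName = sqlContext.read
      .format("org.apache.spark.sql.execution.datasources.parquet.DefaultSource")
      .load("/path/to/data")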
override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val uniqueWriteJobId = context.getConfiguration.get("spark.sql.sources.writeJobUUID") - val split = context.getTaskAttemptID.getTaskID.getId + val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") + val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) + val split = taskAttemptId.getTaskID.getId new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") } } @@ -87,7 +94,9 @@ private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext outputFormat.getRecordWriter(context) } - override def writeInternal(row: InternalRow): Unit = recordWriter.write(null, row) + override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") + + override protected[sql] def writeInternal(row: InternalRow): Unit = recordWriter.write(null, row) override def close(): Unit = recordWriter.close(context) } @@ -100,7 +109,7 @@ private[sql] class ParquetRelation( override val userDefinedPartitionColumns: Option[StructType], parameters: Map[String, String])( val sqlContext: SQLContext) - extends HadoopFsRelation(maybePartitionSpec) + extends HadoopFsRelation(maybePartitionSpec, parameters) with Logging { private[sql] def this( @@ -137,6 +146,12 @@ private[sql] class ParquetRelation( meta } + override def toString: String = { + parameters.get(ParquetRelation.METASTORE_TABLE_NAME).map { tableName => + s"${getClass.getSimpleName}: $tableName" + }.getOrElse(super.toString) + } + override def equals(other: Any): Boolean = other match { case that: ParquetRelation => val schemaEquality = if (shouldMergeSchemas) { @@ -202,7 +217,18 @@ private[sql] class ParquetRelation( override def sizeInBytes: Long = metadataCache.dataStatuses.map(_.getLen).sum override def prepareJobForWrite(job: Job): OutputWriterFactory = { - val conf = ContextUtil.getConfiguration(job) + val conf = { + // scalastyle:off jobcontext + ContextUtil.getConfiguration(job) + // scalastyle:on jobcontext + } + + // SPARK-9849 DirectParquetOutputCommitter qualified name should be backward compatible + val committerClassName = conf.get(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key) + if (committerClassName == "org.apache.spark.sql.parquet.DirectParquetOutputCommitter") { + conf.set(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key, + classOf[DirectParquetOutputCommitter].getCanonicalName) + } val committerClass = conf.getClass( @@ -228,18 +254,22 @@ private[sql] class ParquetRelation( // bundled with `ParquetOutputFormat[Row]`. job.setOutputFormatClass(classOf[ParquetOutputFormat[Row]]) - // TODO There's no need to use two kinds of WriteSupport - // We should unify them. `SpecificMutableRow` can process both atomic (primitive) types and - // complex types. - val writeSupportClass = - if (dataSchema.map(_.dataType).forall(ParquetTypesConverter.isPrimitiveType)) { - classOf[MutableRowWriteSupport] - } else { - classOf[RowWriteSupport] - } + ParquetOutputFormat.setWriteSupportClass(job, classOf[CatalystWriteSupport]) + CatalystWriteSupport.setSchema(dataSchema, conf) + + // Sets flags for `CatalystSchemaConverter` (which converts Catalyst schema to Parquet schema) + // and `CatalystWriteSupport` (writing actual rows to Parquet files). 
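The getDefaultWorkFile override earlier in this hunk encodes the task split id and the write job UUID into each output file name. A small illustration of the resulting pattern, with invented values:

    val split = 3
    val uniqueWriteJobId = "1c9b2a7e-example-uuid"
    val extension = ".gz.parquet"
    f"part-r-$split%05d-$uniqueWriteJobId$extension"
    // => "part-r-00003-1c9b2a7e-example-uuid.gz.parquet"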
+ conf.set( + SQLConf.PARQUET_BINARY_AS_STRING.key, + sqlContext.conf.isParquetBinaryAsString.toString) + + conf.set( + SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, + sqlContext.conf.isParquetINT96AsTimestamp.toString) - ParquetOutputFormat.setWriteSupportClass(job, writeSupportClass) - RowWriteSupport.setSchema(dataSchema.toAttributes, conf) + conf.set( + SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key, + sqlContext.conf.writeLegacyParquetFormat.toString) // Sets compression scheme conf.set( @@ -258,16 +288,24 @@ private[sql] class ParquetRelation( } } - override def buildScan( + override def buildInternalScan( requiredColumns: Array[String], filters: Array[Filter], inputFiles: Array[FileStatus], - broadcastedConf: Broadcast[SerializableConfiguration]): RDD[Row] = { + broadcastedConf: Broadcast[SerializableConfiguration]): RDD[InternalRow] = { val useMetadataCache = sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA) val parquetFilterPushDown = sqlContext.conf.parquetFilterPushDown val assumeBinaryIsString = sqlContext.conf.isParquetBinaryAsString val assumeInt96IsTimestamp = sqlContext.conf.isParquetINT96AsTimestamp - val followParquetFormatSpec = sqlContext.conf.followParquetFormatSpec + + // When merging schemas is enabled and the column of the given filter does not exist, + // Parquet emits an exception which is an issue of Parquet (PARQUET-389). + val safeParquetFilterPushDown = !shouldMergeSchemas && parquetFilterPushDown + + // Parquet row group size. We will use this value as the value for + // mapreduce.input.fileinputformat.split.minsize and mapred.min.split.size if the value + // of these flags are smaller than the parquet row group size. + val parquetBlockSize = ParquetOutputFormat.getLongBlockSize(broadcastedConf.value.value) // Create the function to set variable Parquet confs at both driver and executor side. val initLocalJobFuncOpt = @@ -275,23 +313,23 @@ private[sql] class ParquetRelation( requiredColumns, filters, dataSchema, + parquetBlockSize, useMetadataCache, - parquetFilterPushDown, + safeParquetFilterPushDown, assumeBinaryIsString, - assumeInt96IsTimestamp, - followParquetFormatSpec) _ + assumeInt96IsTimestamp) _ // Create the function to set input paths at the driver side. 
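The three flags written into the Hadoop Configuration above mirror session-level SQL options. A hedged usage sketch (the key names are assumed to match the SQLConf entries referenced above; the values shown are the usual defaults):

    sqlContext.setConf("spark.sql.parquet.binaryAsString", "false")
    sqlContext.setConf("spark.sql.parquet.int96AsTimestamp", "true")
    sqlContext.setConf("spark.sql.parquet.writeLegacyFormat", "false")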
- val setInputPaths = ParquetRelation.initializeDriverSideJobFunc(inputFiles) _ + val setInputPaths = + ParquetRelation.initializeDriverSideJobFunc(inputFiles, parquetBlockSize) _ Utils.withDummyCallSite(sqlContext.sparkContext) { new SqlNewHadoopRDD( - sc = sqlContext.sparkContext, + sqlContext = sqlContext, broadcastedConf = broadcastedConf, initDriverSideJobFuncOpt = Some(setInputPaths), initLocalJobFuncOpt = Some(initLocalJobFuncOpt), inputFormatClass = classOf[ParquetInputFormat[InternalRow]], - keyClass = classOf[Void], valueClass = classOf[InternalRow]) { val cacheMetadata = useMetadataCache @@ -317,7 +355,7 @@ private[sql] class ParquetRelation( override def getPartitions: Array[SparkPartition] = { val inputFormat = new ParquetInputFormat[InternalRow] { override def listStatus(jobContext: JobContext): JList[FileStatus] = { - if (cacheMetadata) cachedStatuses else super.listStatus(jobContext) + if (cacheMetadata) cachedStatuses.asJava else super.listStatus(jobContext) } } @@ -325,10 +363,11 @@ private[sql] class ParquetRelation( val rawSplits = inputFormat.getSplits(jobContext) Array.tabulate[SparkPartition](rawSplits.size) { i => - new SqlNewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) + new SqlNewHadoopPartition( + id, i, rawSplits.get(i).asInstanceOf[InputSplit with Writable]) } } - }.values.asInstanceOf[RDD[Row]] // type erasure hack to pass RDD[InternalRow] as RDD[Row] + } } } @@ -350,7 +389,7 @@ private[sql] class ParquetRelation( var schema: StructType = _ // Cached leaves - var cachedLeaves: Set[FileStatus] = null + var cachedLeaves: mutable.LinkedHashSet[FileStatus] = null /** * Refreshes `FileStatus`es, footers, partition spec, and table schema. @@ -363,13 +402,13 @@ private[sql] class ParquetRelation( !cachedLeaves.equals(currentLeafStatuses) if (leafStatusesChanged) { - cachedLeaves = currentLeafStatuses.toIterator.toSet + cachedLeaves = currentLeafStatuses // Lists `FileStatus`es of all leaf nodes (files) under all base directories. val leaves = currentLeafStatuses.filter { f => isSummaryFile(f.getPath) || !(f.getPath.getName.startsWith("_") || f.getPath.getName.startsWith(".")) - }.toArray + }.toArray.sortBy(_.getPath.toString) dataStatuses = leaves.filterNot(f => isSummaryFile(f.getPath)) metadataStatuses = @@ -432,13 +471,30 @@ private[sql] class ParquetRelation( // You should enable this configuration only if you are very sure that for the parquet // part-files to read there are corresponding summary files containing correct schema. + // As filed in SPARK-11500, the order of files to touch is a matter, which might affect + // the ordering of the output columns. There are several things to mention here. + // + // 1. If mergeRespectSummaries config is false, then it merges schemas by reducing from + // the first part-file so that the columns of the lexicographically first file show + // first. + // + // 2. If mergeRespectSummaries config is true, then there should be, at least, + // "_metadata"s for all given files, so that we can ensure the columns of + // the lexicographically first file show first. + // + // 3. If shouldMergeSchemas is false, but when multiple files are given, there is + // no guarantee of the output order, since there might not be a summary file for the + // lexicographically first file, which ends up putting ahead the columns of + // the other files. However, this should be okay since not enabling + // shouldMergeSchemas means (assumes) all the files have the same schemas. 
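The ordering guarantees spelled out in the comment above come from left-folding schemas over lexicographically sorted files. A toy sketch (deliberately not the package-private StructType.merge) of why the first file's columns end up first:

    case class Field(name: String, dataType: String)
    case class Schema(fields: Seq[Field]) {
      // Keep existing fields in place; append only fields not seen yet.
      def merge(other: Schema): Schema =
        Schema(fields ++ other.fields.filterNot(f => fields.exists(_.name == f.name)))
    }

    val first = Schema(Seq(Field("a", "int")))      // lexicographically first part-file
    val second = Schema(Seq(Field("b", "string")))  // a later part-file
    Seq(first, second).reduceLeft(_ merge _)        // Schema with fields (a, b), not (b, a)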
+ val needMerged: Seq[FileStatus] = if (mergeRespectSummaries) { Seq() } else { dataStatuses } - (metadataStatuses ++ commonMetadataStatuses ++ needMerged).toSeq + needMerged ++ metadataStatuses ++ commonMetadataStatuses } else { // Tries any "_common_metadata" first. Parquet files written by old versions or Parquet // don't have this. @@ -471,17 +527,44 @@ private[sql] object ParquetRelation extends Logging { // internally. private[sql] val METASTORE_SCHEMA = "metastoreSchema" + // If a ParquetRelation is converted from a Hive metastore table, this option is set to the + // original Hive table name. + private[sql] val METASTORE_TABLE_NAME = "metastoreTableName" + + /** + * If parquet's block size (row group size) setting is larger than the min split size, + * we use parquet's block size setting as the min split size. Otherwise, we will create + * tasks processing nothing (because a split does not cover the starting point of a + * parquet block). See https://issues.apache.org/jira/browse/SPARK-10143 for more information. + */ + private def overrideMinSplitSize(parquetBlockSize: Long, conf: Configuration): Unit = { + val minSplitSize = + math.max( + conf.getLong("mapred.min.split.size", 0L), + conf.getLong("mapreduce.input.fileinputformat.split.minsize", 0L)) + if (parquetBlockSize > minSplitSize) { + val message = + s"Parquet's block size (row group size) is larger than " + + s"mapred.min.split.size/mapreduce.input.fileinputformat.split.minsize. Setting " + + s"mapred.min.split.size and mapreduce.input.fileinputformat.split.minsize to " + + s"$parquetBlockSize." + logDebug(message) + conf.set("mapred.min.split.size", parquetBlockSize.toString) + conf.set("mapreduce.input.fileinputformat.split.minsize", parquetBlockSize.toString) + } + } + /** This closure sets various Parquet configurations at both driver side and executor side. */ private[parquet] def initializeLocalJobFunc( requiredColumns: Array[String], filters: Array[Filter], dataSchema: StructType, + parquetBlockSize: Long, useMetadataCache: Boolean, parquetFilterPushDown: Boolean, assumeBinaryIsString: Boolean, - assumeInt96IsTimestamp: Boolean, - followParquetFormatSpec: Boolean)(job: Job): Unit = { - val conf = job.getConfiguration + assumeInt96IsTimestamp: Boolean)(job: Job): Unit = { + val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) conf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[CatalystReadSupport].getName) // Try to push down filters when filter push-down is enabled. @@ -501,26 +584,30 @@ private[sql] object ParquetRelation extends Logging { }) conf.set( - RowWriteSupport.SPARK_ROW_SCHEMA, + CatalystWriteSupport.SPARK_ROW_SCHEMA, CatalystSchemaConverter.checkFieldNames(dataSchema).json) // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata conf.setBoolean(SQLConf.PARQUET_CACHE_METADATA.key, useMetadataCache) - // Sets flags for Parquet schema conversion + // Sets flags for `CatalystSchemaConverter` conf.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING.key, assumeBinaryIsString) conf.setBoolean(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, assumeInt96IsTimestamp) - conf.setBoolean(SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key, followParquetFormatSpec) + + overrideMinSplitSize(parquetBlockSize, conf) } /** This closure sets input paths at the driver side. 
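overrideMinSplitSize above enforces a simple rule: the minimum split size must be at least one Parquet row group, otherwise some splits contain no row-group start and become empty tasks (SPARK-10143). The arithmetic in isolation, as an assumed helper rather than the actual method:

    def effectiveMinSplitSize(parquetBlockSize: Long, configuredMin: Long): Long =
      math.max(parquetBlockSize, configuredMin)

    effectiveMinSplitSize(128L << 20, 64L << 20)   // 134217728: raised to the 128 MB row group
    effectiveMinSplitSize(128L << 20, 256L << 20)  // 268435456: the larger configured minimum wins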
*/ private[parquet] def initializeDriverSideJobFunc( - inputFiles: Array[FileStatus])(job: Job): Unit = { + inputFiles: Array[FileStatus], + parquetBlockSize: Long)(job: Job): Unit = { // We side the input paths at the driver side. logInfo(s"Reading Parquet file(s) from ${inputFiles.map(_.getPath).mkString(", ")}") if (inputFiles.nonEmpty) { FileInputFormat.setInputPaths(job, inputFiles.map(_.getPath): _*) } + + overrideMinSplitSize(parquetBlockSize, SparkHadoopUtil.get.getConfigurationFromJobContext(job)) } private[parquet] def readSchema( @@ -530,7 +617,7 @@ private[sql] object ParquetRelation extends Logging { val converter = new CatalystSchemaConverter( sqlContext.conf.isParquetBinaryAsString, sqlContext.conf.isParquetBinaryAsString, - sqlContext.conf.followParquetFormatSpec) + sqlContext.conf.writeLegacyParquetFormat) converter.convert(schema) } @@ -540,7 +627,7 @@ private[sql] object ParquetRelation extends Logging { val metadata = footer.getParquetMetadata.getFileMetaData val serializedSchema = metadata .getKeyValueMetaData - .toMap + .asScala.toMap .get(CatalystReadSupport.SPARK_METADATA_KEY) if (serializedSchema.isEmpty) { // Falls back to Parquet schema if no Spark SQL schema found. @@ -664,10 +751,10 @@ private[sql] object ParquetRelation extends Logging { filesToTouch: Seq[FileStatus], sqlContext: SQLContext): Option[StructType] = { val assumeBinaryIsString = sqlContext.conf.isParquetBinaryAsString val assumeInt96IsTimestamp = sqlContext.conf.isParquetINT96AsTimestamp - val followParquetFormatSpec = sqlContext.conf.followParquetFormatSpec + val writeLegacyParquetFormat = sqlContext.conf.writeLegacyParquetFormat val serializedConf = new SerializableConfiguration(sqlContext.sparkContext.hadoopConfiguration) - // HACK ALERT: + // !! HACK ALERT !! // // Parquet requires `FileStatus`es to read footers. Here we try to send cached `FileStatus`es // to executor side to avoid fetching them again. However, `FileStatus` is not `Serializable` @@ -697,21 +784,21 @@ private[sql] object ParquetRelation extends Logging { // Reads footers in multi-threaded manner within each task val footers = ParquetFileReader.readAllFootersInParallel( - serializedConf.value, fakeFileStatuses, skipRowGroups) + serializedConf.value, fakeFileStatuses.asJava, skipRowGroups).asScala // Converter used to convert Parquet `MessageType` to Spark SQL `StructType` val converter = new CatalystSchemaConverter( assumeBinaryIsString = assumeBinaryIsString, assumeInt96IsTimestamp = assumeInt96IsTimestamp, - followParquetFormatSpec = followParquetFormatSpec) + writeLegacyParquetFormat = writeLegacyParquetFormat) footers.map { footer => ParquetRelation.readSchemaFromFooter(footer, converter) - }.reduceOption(_ merge _).iterator + }.reduceLeftOption(_ merge _).iterator }.collect() - partiallyMergedSchemas.reduceOption(_ merge _) + partiallyMergedSchemas.reduceLeftOption(_ merge _) } /** @@ -724,7 +811,7 @@ private[sql] object ParquetRelation extends Logging { val fileMetaData = footer.getParquetMetadata.getFileMetaData fileMetaData .getKeyValueMetaData - .toMap + .asScala.toMap .get(CatalystReadSupport.SPARK_METADATA_KEY) .flatMap(deserializeSchemaString) .getOrElse(converter.convert(fileMetaData.getSchema)) @@ -748,45 +835,46 @@ private[sql] object ParquetRelation extends Logging { }.toOption } - def enableLogForwarding() { - // Note: the org.apache.parquet.Log class has a static initializer that - // sets the java.util.logging Logger for "org.apache.parquet". 
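readSchemaFromFooter above prefers a Spark SQL schema serialized into the Parquet footer's key-value metadata and only falls back to converting the Parquet schema. The round trip it relies on is plain StructType JSON serialization; a minimal sketch:

    import org.apache.spark.sql.types._

    val schema = new StructType().add("id", LongType).add("name", StringType)
    val serialized = schema.json                       // what lands in the footer metadata
    val restored = DataType.fromJson(serialized).asInstanceOf[StructType]
    assert(restored == schema)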
This - // checks first to see if there's any handlers already set - // and if not it creates them. If this method executes prior - // to that class being loaded then: - // 1) there's no handlers installed so there's none to - // remove. But when it IS finally loaded the desired affect - // of removing them is circumvented. - // 2) The parquet.Log static initializer calls setUseParentHandlers(false) - // undoing the attempt to override the logging here. - // - // Therefore we need to force the class to be loaded. - // This should really be resolved by Parquet. - Utils.classForName(classOf[ParquetLog].getName) - - // Note: Logger.getLogger("parquet") has a default logger - // that appends to Console which needs to be cleared. - val parquetLogger = JLogger.getLogger(classOf[ParquetLog].getPackage.getName) - parquetLogger.getHandlers.foreach(parquetLogger.removeHandler) - parquetLogger.setUseParentHandlers(true) - - // Disables a WARN log message in ParquetOutputCommitter. We first ensure that - // ParquetOutputCommitter is loaded and the static LOG field gets initialized. - // See https://issues.apache.org/jira/browse/SPARK-5968 for details - Utils.classForName(classOf[ParquetOutputCommitter].getName) - JLogger.getLogger(classOf[ParquetOutputCommitter].getName).setLevel(Level.OFF) - - // Similar as above, disables a unnecessary WARN log message in ParquetRecordReader. - // See https://issues.apache.org/jira/browse/PARQUET-220 for details - Utils.classForName(classOf[ParquetRecordReader[_]].getName) - JLogger.getLogger(classOf[ParquetRecordReader[_]].getName).setLevel(Level.OFF) + // JUL loggers must be held by a strong reference, otherwise they may get destroyed by GC. + // However, the root JUL logger used by Parquet isn't properly referenced. Here we keep + // references to loggers in both parquet-mr <= 1.6 and >= 1.7 + val apacheParquetLogger: JLogger = JLogger.getLogger(classOf[ApacheParquetLog].getPackage.getName) + val parquetLogger: JLogger = JLogger.getLogger("parquet") + + // Parquet initializes its own JUL logger in a static block which always prints to stdout. Here + // we redirect the JUL logger via SLF4J JUL bridge handler. + val redirectParquetLogsViaSLF4J: Unit = { + def redirect(logger: JLogger): Unit = { + logger.getHandlers.foreach(logger.removeHandler) + logger.setUseParentHandlers(false) + logger.addHandler(new SLF4JBridgeHandler) + } + + // For parquet-mr 1.7.0 and above versions, which are under `org.apache.parquet` namespace. + // scalastyle:off classforname + Class.forName(classOf[ApacheParquetLog].getName) + // scalastyle:on classforname + redirect(JLogger.getLogger(classOf[ApacheParquetLog].getPackage.getName)) + + // For parquet-mr 1.6.0 and lower versions bundled with Hive, which are under `parquet` + // namespace. + try { + // scalastyle:off classforname + Class.forName("parquet.Log") + // scalastyle:on classforname + redirect(JLogger.getLogger("parquet")) + } catch { case _: Throwable => + // SPARK-9974: com.twitter:parquet-hadoop-bundle:1.6.0 is not packaged into the assembly jar + // when Spark is built with SBT. So `parquet.Log` may not be found. This try/catch block + // should be removed after this issue is fixed. 
+ } } // The parquet compression short names val shortParquetCompressionCodecNames = Map( - "NONE" -> CompressionCodecName.UNCOMPRESSED, + "NONE" -> CompressionCodecName.UNCOMPRESSED, "UNCOMPRESSED" -> CompressionCodecName.UNCOMPRESSED, - "SNAPPY" -> CompressionCodecName.SNAPPY, - "GZIP" -> CompressionCodecName.GZIP, - "LZO" -> CompressionCodecName.LZO) + "SNAPPY" -> CompressionCodecName.SNAPPY, + "GZIP" -> CompressionCodecName.GZIP, + "LZO" -> CompressionCodecName.LZO) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 11bb49b8d83d..1a8e7ab202dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -17,13 +17,37 @@ package org.apache.spark.sql.execution.datasources -import org.apache.spark.sql.{AnalysisException, SaveMode} -import org.apache.spark.sql.catalyst.analysis.{Catalog, EliminateSubQueries} +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.sources.{BaseRelation, HadoopFsRelation, InsertableRelation} +import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode} + +/** + * Try to replaces [[UnresolvedRelation]]s with [[ResolvedDataSource]]. + */ +private[sql] class ResolveDataSource(sqlContext: SQLContext) extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case u: UnresolvedRelation if u.tableIdentifier.database.isDefined => + try { + val resolved = ResolvedDataSource( + sqlContext, + userSpecifiedSchema = None, + partitionColumns = Array(), + provider = u.tableIdentifier.database.get, + options = Map("path" -> u.tableIdentifier.table)) + val plan = LogicalRelation(resolved.relation) + u.alias.map(a => Subquery(u.alias.get, plan)).getOrElse(plan) + } catch { + case e: ClassNotFoundException => u + case e: Exception => + // the provider is valid, but failed to create a logical plan + u.failAnalysis(e.getMessage) + } + } +} /** * A rule to do pre-insert data type casting and field renaming. Before we insert into @@ -37,7 +61,7 @@ private[sql] object PreInsertCastAndRename extends Rule[LogicalPlan] { // We are inserting into an InsertableRelation or HadoopFsRelation. case i @ InsertIntoTable( - l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation), _, child, _, _) => { + l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation, _), _, child, _, _) => { // First, make sure the data to be inserted have the same number of fields with the // schema of the relation. if (l.output.size != child.output.size) { @@ -84,14 +108,14 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => def apply(plan: LogicalPlan): Unit = { plan.foreach { case i @ logical.InsertIntoTable( - l @ LogicalRelation(t: InsertableRelation), partition, query, overwrite, ifNotExists) => + l @ LogicalRelation(t: InsertableRelation, _), partition, query, overwrite, ifNotExists) => // Right now, we do not support insert into a data source table with partition specs. 
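The new ResolveDataSource rule treats the "database" part of an otherwise unresolved table identifier as a data source provider name, which is what lets a query address files directly. A hedged usage sketch (path is made up):

    // The provider name before the dot ("parquet") is resolved by ResolveDataSource;
    // the back-quoted part is handed to the source as its "path" option.
    val df = sqlContext.sql("SELECT * FROM parquet.`/path/to/some/parquet/files`")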
if (partition.nonEmpty) { failAnalysis(s"Insert into a partition is not allowed because $l is not partitioned.") } else { // Get all input data source relations of the query. val srcRelations = query.collect { - case LogicalRelation(src: BaseRelation) => src + case LogicalRelation(src: BaseRelation, _) => src } if (srcRelations.contains(t)) { failAnalysis( @@ -101,7 +125,8 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => } } - case logical.InsertIntoTable(LogicalRelation(r: HadoopFsRelation), part, _, _, _) => + case logical.InsertIntoTable( + LogicalRelation(r: HadoopFsRelation, _), part, query, overwrite, _) => // We need to make sure the partition columns specified by users do match partition // columns of the relation. val existingPartitionColumns = r.partitionColumns.fieldNames.toSet @@ -115,6 +140,20 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => // OK } + PartitioningUtils.validatePartitionColumnDataTypes( + r.schema, part.keySet.toArray, catalog.conf.caseSensitiveAnalysis) + + // Get all input data source relations of the query. + val srcRelations = query.collect { + case LogicalRelation(src: BaseRelation, _) => src + } + if (srcRelations.contains(r)) { + failAnalysis( + "Cannot insert overwrite into table that is also being read from.") + } else { + // OK + } + case logical.InsertIntoTable(l: LogicalRelation, _, _, _, _) => // The relation in l is not an InsertableRelation. failAnalysis(s"$l does not allow insertion.") @@ -126,22 +165,22 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => // OK } - case CreateTableUsingAsSelect(tableName, _, _, _, SaveMode.Overwrite, _, query) => + case CreateTableUsingAsSelect(tableIdent, _, _, partitionColumns, mode, _, query) => // When the SaveMode is Overwrite, we need to check if the table is an input table of // the query. If so, we will throw an AnalysisException to let users know it is not allowed. - if (catalog.tableExists(Seq(tableName))) { + if (mode == SaveMode.Overwrite && catalog.tableExists(tableIdent)) { // Need to remove SubQuery operator. - EliminateSubQueries(catalog.lookupRelation(Seq(tableName))) match { + EliminateSubQueries(catalog.lookupRelation(tableIdent)) match { // Only do the check if the table is a data source table // (the relation is a BaseRelation). - case l @ LogicalRelation(dest: BaseRelation) => + case l @ LogicalRelation(dest: BaseRelation, _) => // Get all input data source relations of the query. 
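PreWriteCheck now also rejects an INSERT OVERWRITE whose source query reads the target HadoopFsRelation. A hedged illustration (table name and path are invented, and the exact failure path is assumed from the check above):

    sqlContext.read.parquet("/data/events").registerTempTable("events")
    // Expected to fail analysis with
    // "Cannot insert overwrite into table that is also being read from.":
    // sqlContext.sql("INSERT OVERWRITE TABLE events SELECT * FROM events")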
val srcRelations = query.collect { - case LogicalRelation(src: BaseRelation) => src + case LogicalRelation(src: BaseRelation, _) => src } if (srcRelations.contains(dest)) { failAnalysis( - s"Cannot overwrite table $tableName that is also being read from.") + s"Cannot overwrite table $tableIdent that is also being read from.") } else { // OK } @@ -152,6 +191,9 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => // OK } + PartitioningUtils.validatePartitionColumnDataTypes( + query.schema, partitionColumns, catalog.conf.caseSensitiveAnalysis) + case _ => // OK } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala new file mode 100644 index 000000000000..4a1cbe4c38fa --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.text + +import com.google.common.base.Objects +import org.apache.hadoop.fs.{Path, FileStatus} +import org.apache.hadoop.io.{NullWritable, Text, LongWritable} +import org.apache.hadoop.mapred.{TextInputFormat, JobConf} +import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat +import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext, Job} +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat + +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.mapred.SparkHadoopMapRedUtil +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.expressions.codegen.{UnsafeRowWriter, BufferHolder} +import org.apache.spark.sql.{AnalysisException, Row, SQLContext} +import org.apache.spark.sql.execution.datasources.PartitionSpec +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.util.SerializableConfiguration + +/** + * A data source for reading text files. 
+ */ +class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { + + override def createRelation( + sqlContext: SQLContext, + paths: Array[String], + dataSchema: Option[StructType], + partitionColumns: Option[StructType], + parameters: Map[String, String]): HadoopFsRelation = { + dataSchema.foreach(verifySchema) + new TextRelation(None, partitionColumns, paths)(sqlContext) + } + + override def shortName(): String = "text" + + private def verifySchema(schema: StructType): Unit = { + if (schema.size != 1) { + throw new AnalysisException( + s"Text data source supports only a single column, and you have ${schema.size} columns.") + } + val tpe = schema(0).dataType + if (tpe != StringType) { + throw new AnalysisException( + s"Text data source supports only a string column, but you have ${tpe.simpleString}.") + } + } +} + +private[sql] class TextRelation( + val maybePartitionSpec: Option[PartitionSpec], + override val userDefinedPartitionColumns: Option[StructType], + override val paths: Array[String] = Array.empty[String], + parameters: Map[String, String] = Map.empty[String, String]) + (@transient val sqlContext: SQLContext) + extends HadoopFsRelation(maybePartitionSpec, parameters) { + + /** Data schema is always a single column, named "value". */ + override def dataSchema: StructType = new StructType().add("value", StringType) + + /** This is an internal data source that outputs internal row format. */ + override val needConversion: Boolean = false + + + override private[sql] def buildInternalScan( + requiredColumns: Array[String], + filters: Array[Filter], + inputPaths: Array[FileStatus], + broadcastedConf: Broadcast[SerializableConfiguration]): RDD[InternalRow] = { + val job = new Job(sqlContext.sparkContext.hadoopConfiguration) + val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val paths = inputPaths.map(_.getPath).sortBy(_.toUri) + + if (paths.nonEmpty) { + FileInputFormat.setInputPaths(job, paths: _*) + } + + sqlContext.sparkContext.hadoopRDD( + conf.asInstanceOf[JobConf], classOf[TextInputFormat], classOf[LongWritable], classOf[Text]) + .mapPartitions { iter => + val bufferHolder = new BufferHolder + val unsafeRowWriter = new UnsafeRowWriter + val unsafeRow = new UnsafeRow + + iter.map { case (_, line) => + // Writes to an UnsafeRow directly + bufferHolder.reset() + unsafeRowWriter.initialize(bufferHolder, 1) + unsafeRowWriter.write(0, line.getBytes, 0, line.getLength) + unsafeRow.pointTo(bufferHolder.buffer, 1, bufferHolder.totalSize()) + unsafeRow + } + } + } + + /** Write path. 
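A usage sketch for the new text data source defined above (paths are invented): each input line becomes a row with a single string column named "value", and writes require exactly that one-string-column schema.

    val lines = sqlContext.read.format("text").load("/path/to/input")
    lines.printSchema()                              // root |-- value: string (nullable = true)
    lines.filter(lines("value").contains("ERROR"))
      .write.format("text").save("/path/to/errors")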
*/ + override def prepareJobForWrite(job: Job): OutputWriterFactory = { + new OutputWriterFactory { + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new TextOutputWriter(path, dataSchema, context) + } + } + } + + override def equals(other: Any): Boolean = other match { + case that: TextRelation => + paths.toSet == that.paths.toSet && partitionColumns == that.partitionColumns + case _ => false + } + + override def hashCode(): Int = { + Objects.hashCode(paths.toSet, partitionColumns) + } +} + +class TextOutputWriter(path: String, dataSchema: StructType, context: TaskAttemptContext) + extends OutputWriter + with SparkHadoopMapRedUtil { + + private[this] val buffer = new Text() + + private val recordWriter: RecordWriter[NullWritable, Text] = { + new TextOutputFormat[NullWritable, Text]() { + override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { + val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") + val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) + val split = taskAttemptId.getTaskID.getId + new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") + } + }.getRecordWriter(context) + } + + override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") + + override protected[sql] def writeInternal(row: InternalRow): Unit = { + val utf8string = row.getUTF8String(0) + buffer.set(utf8string.getBytes) + recordWriter.write(NullWritable.get(), buffer) + } + + override def close(): Unit = { + recordWriter.close(context) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index c37007f1eece..74892e4e13fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -17,21 +17,16 @@ package org.apache.spark.sql.execution -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.unsafe.types.UTF8String - import scala.collection.mutable.HashSet -import org.apache.spark.{AccumulatorParam, Accumulator, Logging} -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.trees.TreeNodeRef -import org.apache.spark.sql.types._ +import org.apache.spark.{Accumulator, AccumulatorParam, Logging} /** - * :: DeveloperApi :: * Contains methods for debugging query execution. * * Usage: @@ -53,10 +48,8 @@ package object debug { } /** - * :: DeveloperApi :: * Augments [[DataFrame]]s with debug methods. 
*/ - @DeveloperApi implicit class DebugQuery(query: DataFrame) extends Logging { def debug(): Unit = { val plan = query.queryExecution.executedPlan @@ -72,23 +65,6 @@ package object debug { case _ => } } - - def typeCheck(): Unit = { - val plan = query.queryExecution.executedPlan - val visited = new collection.mutable.HashSet[TreeNodeRef]() - val debugPlan = plan transform { - case s: SparkPlan if !visited.contains(new TreeNodeRef(s)) => - visited += new TreeNodeRef(s) - TypeCheck(s) - } - try { - logDebug(s"Results returned: ${debugPlan.execute().count()}") - } catch { - case e: Exception => - def unwrap(e: Throwable): Throwable = if (e.getCause == null) e else unwrap(e.getCause) - logDebug(s"Deepest Error: ${unwrap(e)}") - } - } } private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode { @@ -148,76 +124,4 @@ package object debug { } } } - - /** - * Helper functions for checking that runtime types match a given schema. - */ - private[sql] object TypeCheck { - def typeCheck(data: Any, schema: DataType): Unit = (data, schema) match { - case (null, _) => - - case (row: InternalRow, StructType(fields)) => - row.toSeq.zip(fields.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) } - case (a: ArrayData, ArrayType(elemType, _)) => - a.foreach(elemType, (_, e) => { - typeCheck(e, elemType) - }) - case (m: MapData, MapType(keyType, valueType, _)) => - m.keyArray().foreach(keyType, (_, e) => { - typeCheck(e, keyType) - }) - m.valueArray().foreach(valueType, (_, e) => { - typeCheck(e, valueType) - }) - - case (_: Long, LongType) => - case (_: Int, IntegerType) => - case (_: UTF8String, StringType) => - case (_: Float, FloatType) => - case (_: Byte, ByteType) => - case (_: Short, ShortType) => - case (_: Boolean, BooleanType) => - case (_: Double, DoubleType) => - case (_: Int, DateType) => - case (_: Long, TimestampType) => - case (v, udt: UserDefinedType[_]) => typeCheck(v, udt.sqlType) - - case (d, t) => sys.error(s"Invalid data found: got $d (${d.getClass}) expected $t") - } - } - - /** - * Augments [[DataFrame]]s with debug methods. - */ - private[sql] case class TypeCheck(child: SparkPlan) extends SparkPlan { - import TypeCheck._ - - override def nodeName: String = "" - - /* Only required when defining this class in a REPL. 
- override def makeCopy(args: Array[Object]): this.type = - TypeCheck(args(0).asInstanceOf[SparkPlan]).asInstanceOf[this.type] - */ - - def output: Seq[Attribute] = child.output - - def children: List[SparkPlan] = child :: Nil - - protected override def doExecute(): RDD[InternalRow] = { - child.execute().map { row => - try typeCheck(row, child.schema) catch { - case e: Exception => - sys.error( - s""" - |ERROR WHEN TYPE CHECKING QUERY - |============================== - |$e - |======== BAD TREE ============ - |$child - """.stripMargin) - } - row - } - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 624efc1b1d73..1d381e2eaef3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -20,22 +20,21 @@ package org.apache.spark.sql.execution.joins import scala.concurrent._ import scala.concurrent.duration._ -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.{BinaryNode, SQLExecution, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.ThreadUtils +import org.apache.spark.{InternalAccumulator, TaskContext} /** - * :: DeveloperApi :: * Performs an inner hash join of two child relations. When the output RDD of this operator is * being constructed, a Spark job is asynchronously started to calculate the values for the * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed * relation is not shuffled. */ -@DeveloperApi case class BroadcastHashJoin( leftKeys: Seq[Expression], rightKeys: Seq[Expression], @@ -44,6 +43,11 @@ case class BroadcastHashJoin( right: SparkPlan) extends BinaryNode with HashJoin { + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + val timeout: Duration = { val timeoutValue = sqlContext.conf.broadcastTimeout if (timeoutValue < 0) { @@ -58,25 +62,65 @@ case class BroadcastHashJoin( override def requiredChildDistribution: Seq[Distribution] = UnspecifiedDistribution :: UnspecifiedDistribution :: Nil + // Use lazy so that we won't do broadcast when calling explain but still cache the broadcast value + // for the same query. 
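A toy sketch of the lazy-future pattern the comment above describes and the diff below introduces: explain() never forces the lazy val, so no job starts, while doPrepare() and doExecute() share one cached Future. Names here are stand-ins, not the real operator API.

    import scala.concurrent.{ExecutionContext, Future}
    import ExecutionContext.Implicits.global

    class ToyBroadcastNode {
      // Started at most once, and only when prepare/execute actually need it.
      @transient private lazy val broadcastFuture: Future[Long] =
        Future { (1L to 1000000L).sum }   // stand-in for collecting + broadcasting the build side

      def prepare(): Unit = broadcastFuture          // kick the work off early
      def execute(): Future[Long] = broadcastFuture  // reuse the same cached future
      def explain(): String = "ToyBroadcastNode"     // never touches the lazy val
    }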
@transient - private val broadcastFuture = future { - // Note that we use .execute().collect() because we don't want to convert data to Scala types - val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect() - val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.size) - sparkContext.broadcast(hashed) - }(BroadcastHashJoin.broadcastHashJoinExecutionContext) + private lazy val broadcastFuture = { + val numBuildRows = buildSide match { + case BuildLeft => longMetric("numLeftRows") + case BuildRight => longMetric("numRightRows") + } + + // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here. + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + future { + // This will run in another thread. Set the execution id so that we can connect these jobs + // with the correct execution. + SQLExecution.withExecutionId(sparkContext, executionId) { + // Note that we use .execute().collect() because we don't want to convert data to Scala + // types + val input: Array[InternalRow] = buildPlan.execute().map { row => + numBuildRows += 1 + row.copy() + }.collect() + // The following line doesn't run in a job so we cannot track the metric value. However, we + // have already tracked it in the above lines. So here we can use + // `SQLMetrics.nullLongMetric` to ignore it. + val hashed = HashedRelation( + input.iterator, SQLMetrics.nullLongMetric, buildSideKeyGenerator, input.size) + sparkContext.broadcast(hashed) + } + }(BroadcastHashJoin.broadcastHashJoinExecutionContext) + } + + protected override def doPrepare(): Unit = { + broadcastFuture + } protected override def doExecute(): RDD[InternalRow] = { + val numStreamedRows = buildSide match { + case BuildLeft => longMetric("numRightRows") + case BuildRight => longMetric("numLeftRows") + } + val numOutputRows = longMetric("numOutputRows") + val broadcastRelation = Await.result(broadcastFuture, timeout) streamedPlan.execute().mapPartitions { streamedIter => - hashJoin(streamedIter, broadcastRelation.value) + val hashedRelation = broadcastRelation.value + hashedRelation match { + case unsafe: UnsafeHashedRelation => + TaskContext.get().internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize) + case _ => + } + hashJoin(streamedIter, numStreamedRows, hashedRelation, numOutputRows) } } } object BroadcastHashJoin { - private val broadcastHashJoinExecutionContext = ExecutionContext.fromExecutorService( + private[joins] val broadcastHashJoinExecutionContext = ExecutionContext.fromExecutorService( ThreadUtils.newDaemonCachedThreadPool("broadcast-hash-join", 128)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index 309716a0efcc..ab81bd7b3fc0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -20,23 +20,21 @@ package org.apache.spark.sql.execution.joins import scala.concurrent._ import scala.concurrent.duration._ -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, Distribution, UnspecifiedDistribution} +import 
org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} -import org.apache.spark.util.ThreadUtils +import org.apache.spark.sql.execution.{BinaryNode, SQLExecution, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.{InternalAccumulator, TaskContext} /** - * :: DeveloperApi :: * Performs a outer hash join for two child relations. When the output RDD of this operator is * being constructed, a Spark job is asynchronously started to calculate the values for the * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed * relation is not shuffled. */ -@DeveloperApi case class BroadcastHashOuterJoin( leftKeys: Seq[Expression], rightKeys: Seq[Expression], @@ -45,6 +43,11 @@ case class BroadcastHashOuterJoin( left: SparkPlan, right: SparkPlan) extends BinaryNode with HashOuterJoin { + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + val timeout = { val timeoutValue = sqlContext.conf.broadcastTimeout if (timeoutValue < 0) { @@ -59,15 +62,54 @@ case class BroadcastHashOuterJoin( override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning + // Use lazy so that we won't do broadcast when calling explain but still cache the broadcast value + // for the same query. @transient - private val broadcastFuture = future { - // Note that we use .execute().collect() because we don't want to convert data to Scala types - val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect() - val hashed = HashedRelation(input.iterator, buildKeyGenerator, input.size) - sparkContext.broadcast(hashed) - }(BroadcastHashOuterJoin.broadcastHashOuterJoinExecutionContext) + private lazy val broadcastFuture = { + val numBuildRows = joinType match { + case RightOuter => longMetric("numLeftRows") + case LeftOuter => longMetric("numRightRows") + case x => + throw new IllegalArgumentException( + s"HashOuterJoin should not take $x as the JoinType") + } + + // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here. + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + future { + // This will run in another thread. Set the execution id so that we can connect these jobs + // with the correct execution. + SQLExecution.withExecutionId(sparkContext, executionId) { + // Note that we use .execute().collect() because we don't want to convert data to Scala + // types + val input: Array[InternalRow] = buildPlan.execute().map { row => + numBuildRows += 1 + row.copy() + }.collect() + // The following line doesn't run in a job so we cannot track the metric value. However, we + // have already tracked it in the above lines. So here we can use + // `SQLMetrics.nullLongMetric` to ignore it. 
+ val hashed = HashedRelation( + input.iterator, SQLMetrics.nullLongMetric, buildKeyGenerator, input.size) + sparkContext.broadcast(hashed) + } + }(BroadcastHashJoin.broadcastHashJoinExecutionContext) + } + + protected override def doPrepare(): Unit = { + broadcastFuture + } override def doExecute(): RDD[InternalRow] = { + val numStreamedRows = joinType match { + case RightOuter => longMetric("numRightRows") + case LeftOuter => longMetric("numLeftRows") + case x => + throw new IllegalArgumentException( + s"HashOuterJoin should not take $x as the JoinType") + } + val numOutputRows = longMetric("numOutputRows") + val broadcastRelation = Await.result(broadcastFuture, timeout) streamedPlan.execute().mapPartitions { streamedIter => @@ -75,19 +117,29 @@ case class BroadcastHashOuterJoin( val hashTable = broadcastRelation.value val keyGenerator = streamedKeyGenerator + hashTable match { + case unsafe: UnsafeHashedRelation => + TaskContext.get().internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize) + case _ => + } + + val resultProj = resultProjection joinType match { case LeftOuter => streamedIter.flatMap(currentRow => { + numStreamedRows += 1 val rowKey = keyGenerator(currentRow) joinedRow.withLeft(currentRow) - leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey)) + leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj, numOutputRows) }) case RightOuter => streamedIter.flatMap(currentRow => { + numStreamedRows += 1 val rowKey = keyGenerator(currentRow) joinedRow.withRight(currentRow) - rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow) + rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj, numOutputRows) }) case x => @@ -97,9 +149,3 @@ case class BroadcastHashOuterJoin( } } } - -object BroadcastHashOuterJoin { - - private val broadcastHashOuterJoinExecutionContext = ExecutionContext.fromExecutorService( - ThreadUtils.newDaemonCachedThreadPool("broadcast-hash-outer-join", 128)) -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala index a60593911f94..004407b2e692 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -17,18 +17,17 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.{InternalAccumulator, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics /** - * :: DeveloperApi :: * Build the right table's join keys into a HashSet, and iteratively go through the left * table, to find the if join keys are in the Hash set. 
*/ -@DeveloperApi case class BroadcastLeftSemiJoinHash( leftKeys: Seq[Expression], rightKeys: Seq[Expression], @@ -36,22 +35,42 @@ case class BroadcastLeftSemiJoinHash( right: SparkPlan, condition: Option[Expression]) extends BinaryNode with HashSemiJoin { + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + protected override def doExecute(): RDD[InternalRow] = { - val input = right.execute().map(_.copy()).collect() + val numLeftRows = longMetric("numLeftRows") + val numRightRows = longMetric("numRightRows") + val numOutputRows = longMetric("numOutputRows") + + val input = right.execute().map { row => + numRightRows += 1 + row.copy() + }.collect() if (condition.isEmpty) { - val hashSet = buildKeyHashSet(input.toIterator) + val hashSet = buildKeyHashSet(input.toIterator, SQLMetrics.nullLongMetric) val broadcastedRelation = sparkContext.broadcast(hashSet) - left.execute().mapPartitions { streamIter => - hashSemiJoin(streamIter, broadcastedRelation.value) + left.execute().mapPartitionsInternal { streamIter => + hashSemiJoin(streamIter, numLeftRows, broadcastedRelation.value, numOutputRows) } } else { - val hashRelation = HashedRelation(input.toIterator, rightKeyGenerator, input.size) + val hashRelation = + HashedRelation(input.toIterator, SQLMetrics.nullLongMetric, rightKeyGenerator, input.size) val broadcastedRelation = sparkContext.broadcast(hashRelation) - left.execute().mapPartitions { streamIter => - hashSemiJoin(streamIter, broadcastedRelation.value) + left.execute().mapPartitionsInternal { streamIter => + val hashedRelation = broadcastedRelation.value + hashedRelation match { + case unsafe: UnsafeHashedRelation => + TaskContext.get().internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize) + case _ => + } + hashSemiJoin(streamIter, numLeftRows, hashedRelation, numOutputRows) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 83b726a8e289..aab177b2e842 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -17,19 +17,16 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} -import org.apache.spark.util.collection.CompactBuffer +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.util.collection.{BitSet, CompactBuffer} + -/** - * :: DeveloperApi :: - */ -@DeveloperApi case class BroadcastNestedLoopJoin( left: SparkPlan, right: SparkPlan, @@ -38,6 +35,11 @@ case class BroadcastNestedLoopJoin( condition: Option[Expression]) extends BinaryNode { // TODO: Override requiredChildDistribution. 
+ override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + /** BuildRight means the right relation <=> the broadcast relation. */ private val (streamed, broadcast) = buildSide match { case BuildRight => (left, right) @@ -47,7 +49,7 @@ case class BroadcastNestedLoopJoin( override def outputsUnsafeRows: Boolean = left.outputsUnsafeRows || right.outputsUnsafeRows override def canProcessUnsafeRows: Boolean = true - @transient private[this] lazy val resultProjection: InternalRow => InternalRow = { + private[this] def genResultProjection: InternalRow => InternalRow = { if (outputsUnsafeRows) { UnsafeProjection.create(schema) } else { @@ -65,8 +67,12 @@ case class BroadcastNestedLoopJoin( left.output.map(_.withNullability(true)) ++ right.output case FullOuter => left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case _ => + case Inner => + // TODO we can avoid breaking the lineage, since we union an empty RDD for Inner Join case left.output ++ right.output + case x => // TODO support the Left Semi Join + throw new IllegalArgumentException( + s"BroadcastNestedLoopJoin should not take $x as the JoinType") } } @@ -74,36 +80,44 @@ case class BroadcastNestedLoopJoin( newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) protected override def doExecute(): RDD[InternalRow] = { + val (numStreamedRows, numBuildRows) = buildSide match { + case BuildRight => (longMetric("numLeftRows"), longMetric("numRightRows")) + case BuildLeft => (longMetric("numRightRows"), longMetric("numLeftRows")) + } + val numOutputRows = longMetric("numOutputRows") + val broadcastedRelation = - sparkContext.broadcast(broadcast.execute().map(_.copy()) - .collect().toIndexedSeq) + sparkContext.broadcast(broadcast.execute().map { row => + numBuildRows += 1 + row.copy() + }.collect().toIndexedSeq) /** All rows that either match both-way, or rows from streamed joined with nulls. */ val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter => val matchedRows = new CompactBuffer[InternalRow] - // TODO: Use Spark's BitSet. 
- val includedBroadcastTuples = - new scala.collection.mutable.BitSet(broadcastedRelation.value.size) + val includedBroadcastTuples = new BitSet(broadcastedRelation.value.size) val joinedRow = new JoinedRow val leftNulls = new GenericMutableRow(left.output.size) val rightNulls = new GenericMutableRow(right.output.size) + val resultProj = genResultProjection streamedIter.foreach { streamedRow => var i = 0 var streamRowMatched = false + numStreamedRows += 1 while (i < broadcastedRelation.value.size) { val broadcastedRow = broadcastedRelation.value(i) buildSide match { case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) => - matchedRows += resultProjection(joinedRow(streamedRow, broadcastedRow)).copy() + matchedRows += resultProj(joinedRow(streamedRow, broadcastedRow)).copy() streamRowMatched = true - includedBroadcastTuples += i + includedBroadcastTuples.set(i) case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) => - matchedRows += resultProjection(joinedRow(broadcastedRow, streamedRow)).copy() + matchedRows += resultProj(joinedRow(broadcastedRow, streamedRow)).copy() streamRowMatched = true - includedBroadcastTuples += i + includedBroadcastTuples.set(i) case _ => } i += 1 @@ -111,9 +125,9 @@ case class BroadcastNestedLoopJoin( (streamRowMatched, joinType, buildSide) match { case (false, LeftOuter | FullOuter, BuildRight) => - matchedRows += resultProjection(joinedRow(streamedRow, rightNulls)).copy() + matchedRows += resultProj(joinedRow(streamedRow, rightNulls)).copy() case (false, RightOuter | FullOuter, BuildLeft) => - matchedRows += resultProjection(joinedRow(leftNulls, streamedRow)).copy() + matchedRows += resultProj(joinedRow(leftNulls, streamedRow)).copy() case _ => } } @@ -122,11 +136,13 @@ case class BroadcastNestedLoopJoin( val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2) val allIncludedBroadcastTuples = includedBroadcastTuples.fold( - new scala.collection.mutable.BitSet(broadcastedRelation.value.size) - )(_ ++ _) + new BitSet(broadcastedRelation.value.size) + )(_ | _) val leftNulls = new GenericMutableRow(left.output.size) val rightNulls = new GenericMutableRow(right.output.size) + val resultProj = genResultProjection + /** Rows from broadcasted joined with nulls. */ val broadcastRowsWithNulls: Seq[InternalRow] = { val buf: CompactBuffer[InternalRow] = new CompactBuffer() @@ -137,8 +153,8 @@ case class BroadcastNestedLoopJoin( val joinedRow = new JoinedRow joinedRow.withLeft(leftNulls) while (i < rel.length) { - if (!allIncludedBroadcastTuples.contains(i)) { - buf += resultProjection(joinedRow.withRight(rel(i))).copy() + if (!allIncludedBroadcastTuples.get(i)) { + buf += resultProj(joinedRow.withRight(rel(i))).copy() } i += 1 } @@ -146,8 +162,8 @@ case class BroadcastNestedLoopJoin( val joinedRow = new JoinedRow joinedRow.withRight(rightNulls) while (i < rel.length) { - if (!allIncludedBroadcastTuples.contains(i)) { - buf += resultProjection(joinedRow.withLeft(rel(i))).copy() + if (!allIncludedBroadcastTuples.get(i)) { + buf += resultProj(joinedRow.withLeft(rel(i))).copy() } i += 1 } @@ -158,6 +174,12 @@ case class BroadcastNestedLoopJoin( // TODO: Breaks lineage. sparkContext.union( - matchesOrStreamedRowsWithNulls.flatMap(_._1), sparkContext.makeRDD(broadcastRowsWithNulls)) + matchesOrStreamedRowsWithNulls.flatMap(_._1), + sparkContext.makeRDD(broadcastRowsWithNulls) + ).map { row => + // `broadcastRowsWithNulls` doesn't run in a job so that we have to track numOutputRows here. 
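The switch above from scala.collection.mutable.BitSet to Spark's util.collection.BitSet changes the calls (set/get and | instead of +=/contains and ++) but not the idea: mark every broadcast-row index that matched at least one streamed row, OR the per-partition bitsets together, then emit null-padded rows only for indices that were never set. A small illustration of that flow using java.util.BitSet (an assumption made for the sketch; the patch itself uses Spark's BitSet):

import java.util.BitSet

object MatchedBroadcastRowsSketch {
  def main(args: Array[String]): Unit = {
    val broadcastRows = Vector("r0", "r1", "r2", "r3")

    // Pretend these came from two streamed partitions; each one marks the
    // broadcast indices it managed to match.
    val fromPartition0 = new BitSet(broadcastRows.size)
    fromPartition0.set(1)
    val fromPartition1 = new BitSet(broadcastRows.size)
    fromPartition1.set(3)

    // Merge the per-partition bitsets, mirroring includedBroadcastTuples.fold(...)(_ | _).
    val allIncluded = new BitSet(broadcastRows.size)
    allIncluded.or(fromPartition0)
    allIncluded.or(fromPartition1)

    // Broadcast rows that never matched are the ones emitted with nulls on the
    // streamed side (only relevant for the outer join types).
    var i = 0
    while (i < broadcastRows.size) {
      if (!allIncluded.get(i)) println(s"(null, ${broadcastRows(i)})")
      i += 1
    }
  }
}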
+ numOutputRows += 1 + row + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala index 261b4724159f..fa2bc7672131 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala @@ -17,26 +17,101 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.rdd.RDD +import org.apache.spark._ +import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.util.CompletionIterator +import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter + /** - * :: DeveloperApi :: - */ -@DeveloperApi + * An optimized CartesianRDD for UnsafeRow, which caches the rows from the second child RDD. + * This is much faster than rebuilding the right partition for every row in the left RDD, and it + * also materializes the right RDD (in case the right RDD is nondeterministic). + */ +private[spark] +class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numFieldsOfRight: Int) + extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) { + + override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = { + // We will not sort the rows, so prefixComparator and recordComparator are null.
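Before the sorter-based implementation that follows, it is worth noting what UnsafeCartesianRDD buys: the right-hand partition is materialized once (into an UnsafeExternalSorter, so it can spill) and then replayed for every left row, instead of recomputing the right RDD per left row. A toy, memory-only version of that shape, with a plain Vector standing in for the sorter:

object CachedCartesianSketch {
  // Materialize the right-hand iterator once, then pair every left row with a
  // fresh replay of the cached rows (the real compute() caches into an
  // UnsafeExternalSorter so the cache can spill to disk and is re-read via createIter()).
  def cartesian[A, B](left: Iterator[A], right: Iterator[B]): Iterator[(A, B)] = {
    val cachedRight = right.toVector            // one pass over the right side
    for (x <- left; y <- cachedRight.iterator) yield (x, y)
  }

  def main(args: Array[String]): Unit = {
    val pairs = cartesian(Iterator(1, 2, 3), Iterator("a", "b"))
    pairs.foreach(println)                      // (1,a), (1,b), (2,a), ...
  }
}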
+ val sorter = UnsafeExternalSorter.create( + context.taskMemoryManager(), + SparkEnv.get.blockManager, + context, + null, + null, + 1024, + SparkEnv.get.memoryManager.pageSizeBytes) + + val partition = split.asInstanceOf[CartesianPartition] + for (y <- rdd2.iterator(partition.s2, context)) { + sorter.insertRecord(y.getBaseObject, y.getBaseOffset, y.getSizeInBytes, 0) + } + + // Create an iterator from sorter and wrapper it as Iterator[UnsafeRow] + def createIter(): Iterator[UnsafeRow] = { + val iter = sorter.getIterator + val unsafeRow = new UnsafeRow + new Iterator[UnsafeRow] { + override def hasNext: Boolean = { + iter.hasNext + } + override def next(): UnsafeRow = { + iter.loadNext() + unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, numFieldsOfRight, + iter.getRecordLength) + unsafeRow + } + } + } + + val resultIter = + for (x <- rdd1.iterator(partition.s1, context); + y <- createIter()) yield (x, y) + CompletionIterator[(UnsafeRow, UnsafeRow), Iterator[(UnsafeRow, UnsafeRow)]]( + resultIter, sorter.cleanupResources) + } +} + + case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { override def output: Seq[Attribute] = left.output ++ right.output + override def canProcessSafeRows: Boolean = false + override def canProcessUnsafeRows: Boolean = true + override def outputsUnsafeRows: Boolean = true + + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + protected override def doExecute(): RDD[InternalRow] = { - val leftResults = left.execute().map(_.copy()) - val rightResults = right.execute().map(_.copy()) + val numLeftRows = longMetric("numLeftRows") + val numRightRows = longMetric("numRightRows") + val numOutputRows = longMetric("numOutputRows") + + val leftResults = left.execute().map { row => + numLeftRows += 1 + row.asInstanceOf[UnsafeRow] + } + val rightResults = right.execute().map { row => + numRightRows += 1 + row.asInstanceOf[UnsafeRow] + } - leftResults.cartesian(rightResults).mapPartitions { iter => - val joinedRow = new JoinedRow - iter.map(r => joinedRow(r._1, r._2)) + val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size) + pair.mapPartitionsInternal { iter => + val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema) + iter.map { r => + numOutputRows += 1 + joiner.join(r._1, r._2) + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 6b3d1652923f..fb961d97c3c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.LongSQLMetric trait HashJoin { @@ -43,32 +44,21 @@ trait HashJoin { override def output: Seq[Attribute] = left.output ++ right.output - protected[this] def isUnsafeMode: Boolean = { - (self.codegenEnabled && UnsafeProjection.canSupport(buildKeys) - && UnsafeProjection.canSupport(self.schema)) - } - - override def outputsUnsafeRows: Boolean = 
isUnsafeMode - override def canProcessUnsafeRows: Boolean = isUnsafeMode - override def canProcessSafeRows: Boolean = !isUnsafeMode + override def outputsUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = false - @transient protected lazy val buildSideKeyGenerator: Projection = - if (isUnsafeMode) { - UnsafeProjection.create(buildKeys, buildPlan.output) - } else { - newMutableProjection(buildKeys, buildPlan.output)() - } + protected def buildSideKeyGenerator: Projection = + UnsafeProjection.create(buildKeys, buildPlan.output) - @transient protected lazy val streamSideKeyGenerator: Projection = - if (isUnsafeMode) { - UnsafeProjection.create(streamedKeys, streamedPlan.output) - } else { - newMutableProjection(streamedKeys, streamedPlan.output)() - } + protected def streamSideKeyGenerator: Projection = + UnsafeProjection.create(streamedKeys, streamedPlan.output) protected def hashJoin( streamIter: Iterator[InternalRow], - hashedRelation: HashedRelation): Iterator[InternalRow] = + numStreamRows: LongSQLMetric, + hashedRelation: HashedRelation, + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { new Iterator[InternalRow] { private[this] var currentStreamedRow: InternalRow = _ @@ -77,13 +67,8 @@ trait HashJoin { // Mutable per row objects. private[this] val joinRow = new JoinedRow - private[this] val resultProjection: (InternalRow) => InternalRow = { - if (isUnsafeMode) { - UnsafeProjection.create(self.schema) - } else { - identity[InternalRow] - } - } + private[this] val resultProjection: (InternalRow) => InternalRow = + UnsafeProjection.create(self.schema) private[this] val joinKeys = streamSideKeyGenerator @@ -97,6 +82,7 @@ trait HashJoin { case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) } currentMatchPosition += 1 + numOutputRows += 1 resultProjection(ret) } @@ -112,6 +98,7 @@ trait HashJoin { while (currentHashMatches == null && streamIter.hasNext) { currentStreamedRow = streamIter.next() + numStreamRows += 1 val key = joinKeys(currentStreamedRow) if (!key.anyNull) { currentHashMatches = hashedRelation.get(key) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index a323aea4ea2c..ed626fef56af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -17,16 +17,14 @@ package org.apache.spark.sql.execution.joins -import java.util.{HashMap => JavaHashMap} - -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.LongSQLMetric import org.apache.spark.util.collection.CompactBuffer -@DeveloperApi + trait HashOuterJoin { self: SparkPlan => @@ -66,38 +64,18 @@ trait HashOuterJoin { s"HashOuterJoin should not take $x as the JoinType") } - protected[this] def isUnsafeMode: Boolean = { - (self.codegenEnabled && joinType != FullOuter - && UnsafeProjection.canSupport(buildKeys) - && UnsafeProjection.canSupport(self.schema)) - } + override def outputsUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = false - override def 
outputsUnsafeRows: Boolean = isUnsafeMode - override def canProcessUnsafeRows: Boolean = isUnsafeMode - override def canProcessSafeRows: Boolean = !isUnsafeMode + protected def buildKeyGenerator: Projection = + UnsafeProjection.create(buildKeys, buildPlan.output) - @transient protected lazy val buildKeyGenerator: Projection = - if (isUnsafeMode) { - UnsafeProjection.create(buildKeys, buildPlan.output) - } else { - newMutableProjection(buildKeys, buildPlan.output)() - } + protected[this] def streamedKeyGenerator: Projection = + UnsafeProjection.create(streamedKeys, streamedPlan.output) - @transient protected[this] lazy val streamedKeyGenerator: Projection = { - if (isUnsafeMode) { - UnsafeProjection.create(streamedKeys, streamedPlan.output) - } else { - newProjection(streamedKeys, streamedPlan.output) - } - } - - @transient private[this] lazy val resultProjection: InternalRow => InternalRow = { - if (isUnsafeMode) { - UnsafeProjection.create(self.schema) - } else { - identity[InternalRow] - } - } + protected[this] def resultProjection: InternalRow => InternalRow = + UnsafeProjection.create(self.schema) @transient private[this] lazy val DUMMY_LIST = CompactBuffer[InternalRow](null) @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]() @@ -113,23 +91,30 @@ trait HashOuterJoin { protected[this] def leftOuterIterator( key: InternalRow, joinedRow: JoinedRow, - rightIter: Iterable[InternalRow]): Iterator[InternalRow] = { + rightIter: Iterable[InternalRow], + resultProjection: InternalRow => InternalRow, + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { val ret: Iterable[InternalRow] = { if (!key.anyNull) { val temp = if (rightIter != null) { rightIter.collect { - case r if boundCondition(joinedRow.withRight(r)) => resultProjection(joinedRow).copy() + case r if boundCondition(joinedRow.withRight(r)) => { + numOutputRows += 1 + resultProjection(joinedRow).copy() + } } } else { List.empty } if (temp.isEmpty) { - resultProjection(joinedRow.withRight(rightNullRow)).copy :: Nil + numOutputRows += 1 + resultProjection(joinedRow.withRight(rightNullRow)) :: Nil } else { temp } } else { - resultProjection(joinedRow.withRight(rightNullRow)).copy :: Nil + numOutputRows += 1 + resultProjection(joinedRow.withRight(rightNullRow)) :: Nil } } ret.iterator @@ -138,32 +123,42 @@ trait HashOuterJoin { protected[this] def rightOuterIterator( key: InternalRow, leftIter: Iterable[InternalRow], - joinedRow: JoinedRow): Iterator[InternalRow] = { + joinedRow: JoinedRow, + resultProjection: InternalRow => InternalRow, + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { val ret: Iterable[InternalRow] = { if (!key.anyNull) { val temp = if (leftIter != null) { leftIter.collect { - case l if boundCondition(joinedRow.withLeft(l)) => + case l if boundCondition(joinedRow.withLeft(l)) => { + numOutputRows += 1 resultProjection(joinedRow).copy() + } } } else { List.empty } if (temp.isEmpty) { - resultProjection(joinedRow.withLeft(leftNullRow)).copy :: Nil + numOutputRows += 1 + resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil } else { temp } } else { - resultProjection(joinedRow.withLeft(leftNullRow)).copy :: Nil + numOutputRows += 1 + resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil } } ret.iterator } protected[this] def fullOuterIterator( - key: InternalRow, leftIter: Iterable[InternalRow], rightIter: Iterable[InternalRow], - joinedRow: JoinedRow): Iterator[InternalRow] = { + key: InternalRow, + leftIter: Iterable[InternalRow], + rightIter: Iterable[InternalRow], + 
joinedRow: JoinedRow, + resultProjection: InternalRow => InternalRow, + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { if (!key.anyNull) { // Store the positions of records in right, if one of its associated row satisfy // the join condition. @@ -176,10 +171,11 @@ trait HashOuterJoin { // append them directly case (r, idx) if boundCondition(joinedRow.withRight(r)) => + numOutputRows += 1 matched = true // if the row satisfy the join condition, add its index into the matched set rightMatchedSet.add(idx) - joinedRow.copy() + resultProjection(joinedRow) } ++ DUMMY_LIST.filter(_ => !matched).map( _ => { // 2. For those unmatched records in left, append additional records with empty right. @@ -188,7 +184,8 @@ trait HashOuterJoin { // as we don't know whether we need to append it until finish iterating all // of the records in right side. // If we didn't get any proper row, then append a single row with empty right. - joinedRow.withRight(rightNullRow).copy() + numOutputRows += 1 + resultProjection(joinedRow.withRight(rightNullRow)) }) } ++ rightIter.zipWithIndex.collect { // 3. For those unmatched records in right, append additional records with empty left. @@ -196,13 +193,16 @@ trait HashOuterJoin { // Re-visiting the records in right, and append additional row with empty left, if its not // in the matched set. case (r, idx) if !rightMatchedSet.contains(idx) => - joinedRow(leftNullRow, r).copy() + numOutputRows += 1 + resultProjection(joinedRow(leftNullRow, r)) } } else { leftIter.iterator.map[InternalRow] { l => - joinedRow(l, rightNullRow).copy() + numOutputRows += 1 + resultProjection(joinedRow(l, rightNullRow)) } ++ rightIter.iterator.map[InternalRow] { r => - joinedRow(leftNullRow, r).copy() + numOutputRows += 1 + resultProjection(joinedRow(leftNullRow, r)) } } } @@ -210,10 +210,12 @@ trait HashOuterJoin { // This is only used by FullOuter protected[this] def buildHashTable( iter: Iterator[InternalRow], - keyGenerator: Projection): JavaHashMap[InternalRow, CompactBuffer[InternalRow]] = { - val hashTable = new JavaHashMap[InternalRow, CompactBuffer[InternalRow]]() + numIterRows: LongSQLMetric, + keyGenerator: Projection): java.util.HashMap[InternalRow, CompactBuffer[InternalRow]] = { + val hashTable = new java.util.HashMap[InternalRow, CompactBuffer[InternalRow]]() while (iter.hasNext) { val currentRow = iter.next() + numIterRows += 1 val rowKey = keyGenerator(currentRow) var existingMatchList = hashTable.get(rowKey) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala index 97fde8f975bf..f23a1830e91c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.LongSQLMetric trait HashSemiJoin { @@ -32,42 +33,28 @@ trait HashSemiJoin { override def output: Seq[Attribute] = left.output - protected[this] def supportUnsafe: Boolean = { - (self.codegenEnabled && UnsafeProjection.canSupport(leftKeys) - && UnsafeProjection.canSupport(rightKeys) - && UnsafeProjection.canSupport(left.schema) - && UnsafeProjection.canSupport(right.schema)) - } + override def outputsUnsafeRows: Boolean = true + override 
def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = false - override def outputsUnsafeRows: Boolean = supportUnsafe - override def canProcessUnsafeRows: Boolean = supportUnsafe - override def canProcessSafeRows: Boolean = !supportUnsafe + protected def leftKeyGenerator: Projection = + UnsafeProjection.create(leftKeys, left.output) - @transient protected lazy val leftKeyGenerator: Projection = - if (supportUnsafe) { - UnsafeProjection.create(leftKeys, left.output) - } else { - newMutableProjection(leftKeys, left.output)() - } - - @transient protected lazy val rightKeyGenerator: Projection = - if (supportUnsafe) { - UnsafeProjection.create(rightKeys, right.output) - } else { - newMutableProjection(rightKeys, right.output)() - } + protected def rightKeyGenerator: Projection = + UnsafeProjection.create(rightKeys, right.output) @transient private lazy val boundCondition = newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) - protected def buildKeyHashSet(buildIter: Iterator[InternalRow]): java.util.Set[InternalRow] = { + protected def buildKeyHashSet( + buildIter: Iterator[InternalRow], numBuildRows: LongSQLMetric): java.util.Set[InternalRow] = { val hashSet = new java.util.HashSet[InternalRow]() - var currentRow: InternalRow = null // Create a Hash set of buildKeys val rightKey = rightKeyGenerator while (buildIter.hasNext) { - currentRow = buildIter.next() + val currentRow = buildIter.next() + numBuildRows += 1 val rowKey = rightKey(currentRow) if (!rowKey.anyNull) { val keyExists = hashSet.contains(rowKey) @@ -76,30 +63,41 @@ trait HashSemiJoin { } } } + hashSet } protected def hashSemiJoin( streamIter: Iterator[InternalRow], - hashSet: java.util.Set[InternalRow]): Iterator[InternalRow] = { + numStreamRows: LongSQLMetric, + hashSet: java.util.Set[InternalRow], + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { val joinKeys = leftKeyGenerator streamIter.filter(current => { + numStreamRows += 1 val key = joinKeys(current) - !key.anyNull && hashSet.contains(key) + val r = !key.anyNull && hashSet.contains(key) + if (r) numOutputRows += 1 + r }) } protected def hashSemiJoin( streamIter: Iterator[InternalRow], - hashedRelation: HashedRelation): Iterator[InternalRow] = { + numStreamRows: LongSQLMetric, + hashedRelation: HashedRelation, + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { val joinKeys = leftKeyGenerator val joinedRow = new JoinedRow streamIter.filter { current => + numStreamRows += 1 val key = joinKeys(current) lazy val rowBuffer = hashedRelation.get(key) - !key.anyNull && rowBuffer != null && rowBuffer.exists { + val r = !key.anyNull && rowBuffer != null && rowBuffer.exists { (row: InternalRow) => boundCondition(joinedRow(current, row)) } + if (r) numOutputRows += 1 + r } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index cc8bbfd2f894..8c7099ab5a34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -17,27 +17,29 @@ package org.apache.spark.sql.execution.joins -import java.io.{IOException, Externalizable, ObjectInput, ObjectOutput} +import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} import java.nio.ByteOrder import java.util.{HashMap => JavaHashMap} -import org.apache.spark.shuffle.ShuffleMemoryManager -import 
org.apache.spark.{SparkConf, SparkEnv, TaskContext} +import org.apache.spark.memory.{TaskMemoryManager, StaticMemoryManager} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkSqlSerializer -import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.sql.execution.local.LocalNode +import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} +import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.map.BytesToBytesMap -import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} -import org.apache.spark.util.Utils +import org.apache.spark.unsafe.memory.MemoryLocation +import org.apache.spark.util.{SizeEstimator, KnownSizeEstimation, Utils} import org.apache.spark.util.collection.CompactBuffer +import org.apache.spark.{SparkConf, SparkEnv} /** * Interface for a hashed relation by some key. Use [[HashedRelation.apply]] to create a concrete * object. */ -private[joins] sealed trait HashedRelation { +private[execution] sealed trait HashedRelation { def get(key: InternalRow): Seq[InternalRow] // This is a helper method to implement Externalizable, and is used by @@ -65,7 +67,8 @@ private[joins] final class GeneralHashedRelation( private var hashTable: JavaHashMap[InternalRow, CompactBuffer[InternalRow]]) extends HashedRelation with Externalizable { - private def this() = this(null) // Needed for serialization + // Needed for serialization (it is public to make Java serialization work) + def this() = this(null) override def get(key: InternalRow): Seq[InternalRow] = hashTable.get(key) @@ -87,7 +90,8 @@ private[joins] final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalRow, InternalRow]) extends HashedRelation with Externalizable { - private def this() = this(null) // Needed for serialization + // Needed for serialization (it is public to make Java serialization work) + def this() = this(null) override def get(key: InternalRow): Seq[InternalRow] = { val v = hashTable.get(key) @@ -108,15 +112,21 @@ final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalR // TODO(rxin): a version of [[HashedRelation]] backed by arrays for consecutive integer keys. -private[joins] object HashedRelation { +private[execution] object HashedRelation { + + def apply(localNode: LocalNode, keyGenerator: Projection): HashedRelation = { + apply(localNode.asIterator, SQLMetrics.nullLongMetric, keyGenerator) + } def apply( input: Iterator[InternalRow], + numInputRows: LongSQLMetric, keyGenerator: Projection, sizeEstimate: Int = 64): HashedRelation = { if (keyGenerator.isInstanceOf[UnsafeProjection]) { - return UnsafeHashedRelation(input, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate) + return UnsafeHashedRelation( + input, numInputRows, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate) } // TODO: Use Spark's HashMap implementation. 
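The hunk that follows only adds the numInputRows counter; the build loop itself keeps its old shape: walk the build-side iterator, compute the join key for each row, skip rows whose key contains a null, and append the row to a per-key buffer. A simplified, Spark-free sketch of that build step (Option keys stand in for rowKey.anyNull, and a plain HashMap/ArrayBuffer replaces JavaHashMap/CompactBuffer):

import scala.collection.mutable

object HashedRelationSketch {
  // Build a key -> matching-rows map, skipping rows whose join key is null,
  // and count the consumed input rows as a side effect.
  def buildHashRelation[K, R](
      input: Iterator[R],
      key: R => Option[K]): (mutable.HashMap[K, mutable.ArrayBuffer[R]], Long) = {
    val table = mutable.HashMap.empty[K, mutable.ArrayBuffer[R]]
    var numInputRows = 0L
    while (input.hasNext) {
      val row = input.next()
      numInputRows += 1
      key(row) match {
        case Some(k) => table.getOrElseUpdate(k, mutable.ArrayBuffer.empty[R]) += row
        case None    => // rows with null join keys can never match, so drop them here
      }
    }
    (table, numInputRows)
  }

  def main(args: Array[String]): Unit = {
    val rows = Iterator(("x", 1), ("y", 2), (null.asInstanceOf[String], 3), ("x", 4))
    val (table, n) = buildHashRelation[String, (String, Int)](rows, r => Option(r._1))
    println(s"consumed $n rows: $table")   // x -> [(x,1), (x,4)], y -> [(y,2)]
  }
}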
@@ -130,6 +140,7 @@ private[joins] object HashedRelation { // Create a mapping of buildKeys -> rows while (input.hasNext) { currentRow = input.next() + numInputRows += 1 val rowKey = keyGenerator(currentRow) if (!rowKey.anyNull) { val existingMatchList = hashTable.get(rowKey) @@ -178,20 +189,51 @@ private[joins] object HashedRelation { */ private[joins] final class UnsafeHashedRelation( private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]]) - extends HashedRelation with Externalizable { + extends HashedRelation + with KnownSizeEstimation + with Externalizable { private[joins] def this() = this(null) // Needed for serialization // Use BytesToBytesMap in executor for better performance (it's created when deserialization) + // This is used in broadcast joins and distributed mode only @transient private[this] var binaryMap: BytesToBytesMap = _ + /** + * Return the size of the unsafe map on the executors. + * + * For broadcast joins, this hashed relation is bigger on the driver because it is + * represented as a Java hash map there. While serializing the map to the executors, + * however, we rehash the contents in a binary map to reduce the memory footprint on + * the executors. + * + * For non-broadcast joins or in local mode, return 0. + */ + def getUnsafeSize: Long = { + if (binaryMap != null) { + binaryMap.getTotalMemoryConsumption + } else { + 0 + } + } + + override def estimatedSize: Long = { + if (binaryMap != null) { + binaryMap.getTotalMemoryConsumption + } else { + SizeEstimator.estimate(hashTable) + } + } + override def get(key: InternalRow): Seq[InternalRow] = { val unsafeKey = key.asInstanceOf[UnsafeRow] if (binaryMap != null) { // Used in Broadcast join - val loc = binaryMap.lookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset, - unsafeKey.getSizeInBytes) + val map = binaryMap // avoid the compiler error + val loc = new map.Location // this could be allocated in stack + binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset, + unsafeKey.getSizeInBytes, loc) if (loc.isDefined) { val buffer = CompactBuffer[UnsafeRow]() @@ -199,8 +241,8 @@ private[joins] final class UnsafeHashedRelation( var offset = loc.getValueAddress.getBaseOffset val last = loc.getValueAddress.getBaseOffset + loc.getValueLength while (offset < last) { - val numFields = PlatformDependent.UNSAFE.getInt(base, offset) - val sizeInBytes = PlatformDependent.UNSAFE.getInt(base, offset + 4) + val numFields = Platform.getInt(base, offset) + val sizeInBytes = Platform.getInt(base, offset + 4) offset += 8 val row = new UnsafeRow @@ -214,46 +256,73 @@ private[joins] final class UnsafeHashedRelation( } } else { - // Use the JavaHashMap in Local mode or ShuffleHashJoin + // Use the Java HashMap in local mode or for non-broadcast joins (e.g. 
ShuffleHashJoin) hashTable.get(unsafeKey) } } override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { - out.writeInt(hashTable.size()) - - val iter = hashTable.entrySet().iterator() - while (iter.hasNext) { - val entry = iter.next() - val key = entry.getKey - val values = entry.getValue - - // write all the values as single byte array - var totalSize = 0L - var i = 0 - while (i < values.length) { - totalSize += values(i).getSizeInBytes + 4 + 4 - i += 1 + if (binaryMap != null) { + // This could happen when a cached broadcast object need to be dumped into disk to free memory + out.writeInt(binaryMap.numElements()) + + var buffer = new Array[Byte](64) + def write(addr: MemoryLocation, length: Int): Unit = { + if (buffer.length < length) { + buffer = new Array[Byte](length) + } + Platform.copyMemory(addr.getBaseObject, addr.getBaseOffset, + buffer, Platform.BYTE_ARRAY_OFFSET, length) + out.write(buffer, 0, length) } - assert(totalSize < Integer.MAX_VALUE, "values are too big") - - // [key size] [values size] [key bytes] [values bytes] - out.writeInt(key.getSizeInBytes) - out.writeInt(totalSize.toInt) - out.write(key.getBytes) - i = 0 - while (i < values.length) { - // [num of fields] [num of bytes] [row bytes] - // write the integer in native order, so they can be read by UNSAFE.getInt() - if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) { - out.writeInt(values(i).numFields()) - out.writeInt(values(i).getSizeInBytes) - } else { - out.writeInt(Integer.reverseBytes(values(i).numFields())) - out.writeInt(Integer.reverseBytes(values(i).getSizeInBytes)) + + val iter = binaryMap.iterator() + while (iter.hasNext) { + val loc = iter.next() + // [key size] [values size] [key bytes] [values bytes] + out.writeInt(loc.getKeyLength) + out.writeInt(loc.getValueLength) + write(loc.getKeyAddress, loc.getKeyLength) + write(loc.getValueAddress, loc.getValueLength) + } + + } else { + assert(hashTable != null) + out.writeInt(hashTable.size()) + + val iter = hashTable.entrySet().iterator() + while (iter.hasNext) { + val entry = iter.next() + val key = entry.getKey + val values = entry.getValue + + // write all the values as single byte array + var totalSize = 0L + var i = 0 + while (i < values.length) { + totalSize += values(i).getSizeInBytes + 4 + 4 + i += 1 + } + assert(totalSize < Integer.MAX_VALUE, "values are too big") + + // [key size] [values size] [key bytes] [values bytes] + out.writeInt(key.getSizeInBytes) + out.writeInt(totalSize.toInt) + out.write(key.getBytes) + i = 0 + while (i < values.length) { + // [num of fields] [num of bytes] [row bytes] + // write the integer in native order, so they can be read by UNSAFE.getInt() + if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) { + out.writeInt(values(i).numFields()) + out.writeInt(values(i).getSizeInBytes) + } else { + out.writeInt(Integer.reverseBytes(values(i).numFields())) + out.writeInt(Integer.reverseBytes(values(i).getSizeInBytes)) + } + out.write(values(i).getBytes) + i += 1 } - out.write(values(i).getBytes) - i += 1 } } } @@ -261,24 +330,25 @@ private[joins] final class UnsafeHashedRelation( override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { val nKeys = in.readInt() // This is used in Broadcast, shared by multiple tasks, so we use on-heap memory - val taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) - - // Dummy shuffle memory manager which always grants all memory allocation requests. 
- // We use this because it doesn't make sense count shared broadcast variables' memory usage - // towards individual tasks' quotas. In the future, we should devise a better way of handling - // this. - val shuffleMemoryManager = new ShuffleMemoryManager(new SparkConf()) { - override def tryToAcquire(numBytes: Long): Long = numBytes - override def release(numBytes: Long): Unit = {} - } - - val pageSizeBytes = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) - .getSizeAsBytes("spark.buffer.pageSize", "64m") + // TODO(josh): This needs to be revisited before we merge this patch; making this change now + // so that tests compile: + val taskMemoryManager = new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.memory.offHeap.enabled", "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + + val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes) + .getOrElse(new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "16m")) + + // TODO(josh): We won't need this dummy memory manager after future refactorings; revisit + // during code review binaryMap = new BytesToBytesMap( taskMemoryManager, - shuffleMemoryManager, - nKeys * 2, // reduce hash collision + (nKeys * 1.5 + 1).toInt, // reduce hash collision pageSizeBytes) var i = 0 @@ -287,20 +357,21 @@ private[joins] final class UnsafeHashedRelation( while (i < nKeys) { val keySize = in.readInt() val valuesSize = in.readInt() - if (keySize > keyBuffer.size) { + if (keySize > keyBuffer.length) { keyBuffer = new Array[Byte](keySize) } in.readFully(keyBuffer, 0, keySize) - if (valuesSize > valuesBuffer.size) { + if (valuesSize > valuesBuffer.length) { valuesBuffer = new Array[Byte](valuesSize) } in.readFully(valuesBuffer, 0, valuesSize) // put it into binary map - val loc = binaryMap.lookup(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize) + val loc = binaryMap.lookup(keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize) assert(!loc.isDefined, "Duplicated key found!") - val putSuceeded = loc.putNewKey(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize, - valuesBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, valuesSize) + val putSuceeded = loc.putNewKey( + keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize, + valuesBuffer, Platform.BYTE_ARRAY_OFFSET, valuesSize) if (!putSuceeded) { throw new IOException("Could not allocate memory to grow BytesToBytesMap") } @@ -313,14 +384,17 @@ private[joins] object UnsafeHashedRelation { def apply( input: Iterator[InternalRow], + numInputRows: LongSQLMetric, keyGenerator: UnsafeProjection, sizeEstimate: Int): HashedRelation = { + // Use a Java hash table here because unsafe maps expect fixed size records val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate) // Create a mapping of buildKeys -> rows while (input.hasNext) { val unsafeRow = input.next().asInstanceOf[UnsafeRow] + numInputRows += 1 val rowKey = keyGenerator(unsafeRow) if (!rowKey.anyNull) { val existingMatchList = hashTable.get(rowKey) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala index 4443455ef11f..efa7b49410ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala @@ -17,24 +17,27 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD 
import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics /** - * :: DeveloperApi :: * Using BroadcastNestedLoopJoin to calculate left semi join result when there's no join keys * for hash join. */ -@DeveloperApi case class LeftSemiJoinBNL( streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression]) extends BinaryNode { // TODO: Override requiredChildDistribution. + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def outputPartitioning: Partitioning = streamed.outputPartitioning override def output: Seq[Attribute] = left.output @@ -52,13 +55,21 @@ case class LeftSemiJoinBNL( newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) protected override def doExecute(): RDD[InternalRow] = { + val numLeftRows = longMetric("numLeftRows") + val numRightRows = longMetric("numRightRows") + val numOutputRows = longMetric("numOutputRows") + val broadcastedRelation = - sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) + sparkContext.broadcast(broadcast.execute().map { row => + numRightRows += 1 + row.copy() + }.collect().toIndexedSeq) streamed.execute().mapPartitions { streamedIter => val joinedRow = new JoinedRow streamedIter.filter(streamedRow => { + numLeftRows += 1 var i = 0 var matched = false @@ -69,6 +80,9 @@ case class LeftSemiJoinBNL( } i += 1 } + if (matched) { + numOutputRows += 1 + } matched }) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala index 68ccd34d8ed9..bf3b05be981f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala @@ -17,19 +17,17 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, Distribution, ClusteredDistribution} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics /** - * :: DeveloperApi :: * Build the right table's join keys into a HashSet, and iteratively go through the left * table, to find the if join keys are in the Hash set. 
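That description is the whole algorithm for the semi join: build a set of right-side join keys, then keep each left row whose key appears in the set, emitting every left row at most once. A compact sketch under those assumptions:

object LeftSemiJoinSketch {
  // Build a set of right-side join keys, then keep each left row whose key is
  // in the set; each left row is emitted at most once, as a semi join requires.
  def leftSemiJoin[K, A](left: Seq[(K, A)], rightKeys: Iterator[K]): Seq[(K, A)] = {
    val keySet = rightKeys.toSet
    left.filter { case (k, _) => keySet.contains(k) }
  }

  def main(args: Array[String]): Unit = {
    val left = Seq(1 -> "a", 2 -> "b", 3 -> "c")
    val right = Iterator(2, 2, 4)
    println(leftSemiJoin(left, right))  // List((2,b))
  }
}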
*/ -@DeveloperApi case class LeftSemiJoinHash( leftKeys: Seq[Expression], rightKeys: Seq[Expression], @@ -37,19 +35,28 @@ case class LeftSemiJoinHash( right: SparkPlan, condition: Option[Expression]) extends BinaryNode with HashSemiJoin { + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def outputPartitioning: Partitioning = left.outputPartitioning override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil protected override def doExecute(): RDD[InternalRow] = { + val numLeftRows = longMetric("numLeftRows") + val numRightRows = longMetric("numRightRows") + val numOutputRows = longMetric("numOutputRows") + right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) => if (condition.isEmpty) { - val hashSet = buildKeyHashSet(buildIter) - hashSemiJoin(streamIter, hashSet) + val hashSet = buildKeyHashSet(buildIter, numRightRows) + hashSemiJoin(streamIter, numLeftRows, hashSet, numOutputRows) } else { - val hashRelation = HashedRelation(buildIter, rightKeyGenerator) - hashSemiJoin(streamIter, hashRelation) + val hashRelation = HashedRelation(buildIter, numRightRows, rightKeyGenerator) + hashSemiJoin(streamIter, numLeftRows, hashRelation, numOutputRows) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala deleted file mode 100644 index fc6efe87bceb..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.joins - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} - -/** - * :: DeveloperApi :: - * Performs an inner hash join of two child relations by first shuffling the data using the join - * keys. 
- */ -@DeveloperApi -case class ShuffledHashJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - buildSide: BuildSide, - left: SparkPlan, - right: SparkPlan) - extends BinaryNode with HashJoin { - - override def outputPartitioning: Partitioning = - PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) - - override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - - protected override def doExecute(): RDD[InternalRow] = { - buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => - val hashed = HashedRelation(buildIter, buildSideKeyGenerator) - hashJoin(streamIter, hashed) - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala deleted file mode 100644 index eee8ad800f98..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.joins - -import scala.collection.JavaConversions._ - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} - -/** - * :: DeveloperApi :: - * Performs a hash based outer join for two child relations by shuffling the data using - * the join keys. This operator requires loading the associated partition in both side into memory. 
- */ -@DeveloperApi -case class ShuffledHashOuterJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - joinType: JoinType, - condition: Option[Expression], - left: SparkPlan, - right: SparkPlan) extends BinaryNode with HashOuterJoin { - - override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - - override def outputPartitioning: Partitioning = joinType match { - case LeftOuter => left.outputPartitioning - case RightOuter => right.outputPartitioning - case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) - case x => - throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") - } - - protected override def doExecute(): RDD[InternalRow] = { - val joinedRow = new JoinedRow() - left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => - // TODO this probably can be replaced by external sort (sort merged join?) - joinType match { - case LeftOuter => - val hashed = HashedRelation(rightIter, buildKeyGenerator) - val keyGenerator = streamedKeyGenerator - leftIter.flatMap( currentRow => { - val rowKey = keyGenerator(currentRow) - joinedRow.withLeft(currentRow) - leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey)) - }) - - case RightOuter => - val hashed = HashedRelation(leftIter, buildKeyGenerator) - val keyGenerator = streamedKeyGenerator - rightIter.flatMap ( currentRow => { - val rowKey = keyGenerator(currentRow) - joinedRow.withRight(currentRow) - rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow) - }) - - case FullOuter => - // TODO(davies): use UnsafeRow - val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) - val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) - (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => - fullOuterIterator(key, - leftHashTable.getOrElse(key, EMPTY_LIST), - rightHashTable.getOrElse(key, EMPTY_LIST), - joinedRow) - } - - case x => - throw new IllegalArgumentException( - s"ShuffledHashOuterJoin should not take $x as the JoinType") - } - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 41be78afd37e..4bf7b521c77d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -17,27 +17,29 @@ package org.apache.spark.sql.execution.joins -import java.util.NoSuchElementException +import scala.collection.mutable.ArrayBuffer -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} -import org.apache.spark.util.collection.CompactBuffer +import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan} +import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} /** - * :: DeveloperApi :: * Performs an sort merge join of two child relations. 
*/ -@DeveloperApi case class SortMergeJoin( leftKeys: Seq[Expression], rightKeys: Seq[Expression], left: SparkPlan, right: SparkPlan) extends BinaryNode { + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def output: Seq[Attribute] = left.output ++ right.output override def outputPartitioning: Partitioning = @@ -46,124 +48,275 @@ case class SortMergeJoin( override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - // this is to manually construct an ordering that can be used to compare keys from both sides - private val keyOrdering: RowOrdering = RowOrdering.forSchema(leftKeys.map(_.dataType)) - override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys) override def requiredChildOrdering: Seq[Seq[SortOrder]] = requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil - @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output) - @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output) + override def outputsUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = false - private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = + private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = { + // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`. keys.map(SortOrder(_, Ascending)) + } protected override def doExecute(): RDD[InternalRow] = { - val leftResults = left.execute().map(_.copy()) - val rightResults = right.execute().map(_.copy()) + val numLeftRows = longMetric("numLeftRows") + val numRightRows = longMetric("numRightRows") + val numOutputRows = longMetric("numOutputRows") - leftResults.zipPartitions(rightResults) { (leftIter, rightIter) => - new Iterator[InternalRow] { - // Mutable per row objects. + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => + new RowIterator { + // The projection used to extract keys from input rows of the left child. + private[this] val leftKeyGenerator = UnsafeProjection.create(leftKeys, left.output) + + // The projection used to extract keys from input rows of the right child. + private[this] val rightKeyGenerator = UnsafeProjection.create(rightKeys, right.output) + + // An ordering that can be used to compare keys from both sides. 
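The rewritten SortMergeJoin below delegates the merging to SortMergeJoinScanner: advance whichever side currently has the smaller key, buffer the run of equal-keyed rows from the buffered (right) side once the keys compare equal, and pair that buffer with each matching streamed (left) row. A compact illustration of that inner-join scan over two pre-sorted sequences, using plain tuples instead of InternalRows:

object SortMergeInnerJoinSketch {
  // Both inputs must already be sorted ascending by key, which is what the
  // operator's requiredChildOrdering guarantees in the real plan.
  def innerJoin[K: Ordering, A, B](left: Seq[(K, A)], right: Seq[(K, B)]): Seq[(K, A, B)] = {
    val ord = implicitly[Ordering[K]]
    val out = Seq.newBuilder[(K, A, B)]
    var i = 0
    var j = 0
    while (i < left.length && j < right.length) {
      val c = ord.compare(left(i)._1, right(j)._1)
      if (c < 0) i += 1
      else if (c > 0) j += 1
      else {
        // Buffer the run of right rows sharing this key, then pair it with every
        // left row that has the same key (mirrors getBufferedMatches).
        val key = left(i)._1
        val end = right.indexWhere(r => ord.compare(r._1, key) != 0, j) match {
          case -1 => right.length
          case e  => e
        }
        val matches = right.slice(j, end)
        while (i < left.length && ord.compare(left(i)._1, key) == 0) {
          matches.foreach(m => out += ((key, left(i)._2, m._2)))
          i += 1
        }
        j = end
      }
    }
    out.result()
  }

  def main(args: Array[String]): Unit = {
    val l = Seq(1 -> "a", 2 -> "b", 2 -> "c", 4 -> "d")
    val r = Seq(2 -> "x", 2 -> "y", 3 -> "z", 4 -> "w")
    innerJoin(l, r).foreach(println)  // (2,b,x), (2,b,y), (2,c,x), (2,c,y), (4,d,w)
  }
}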
+ private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) + private[this] var currentLeftRow: InternalRow = _ + private[this] var currentRightMatches: ArrayBuffer[InternalRow] = _ + private[this] var currentMatchIdx: Int = -1 + private[this] val smjScanner = new SortMergeJoinScanner( + leftKeyGenerator, + rightKeyGenerator, + keyOrdering, + RowIterator.fromScala(leftIter), + numLeftRows, + RowIterator.fromScala(rightIter), + numRightRows + ) private[this] val joinRow = new JoinedRow - private[this] var leftElement: InternalRow = _ - private[this] var rightElement: InternalRow = _ - private[this] var leftKey: InternalRow = _ - private[this] var rightKey: InternalRow = _ - private[this] var rightMatches: CompactBuffer[InternalRow] = _ - private[this] var rightPosition: Int = -1 - private[this] var stop: Boolean = false - private[this] var matchKey: InternalRow = _ - - // initialize iterator - initialize() - - override final def hasNext: Boolean = nextMatchingPair() - - override final def next(): InternalRow = { - if (hasNext) { - // we are using the buffered right rows and run down left iterator - val joinedRow = joinRow(leftElement, rightMatches(rightPosition)) - rightPosition += 1 - if (rightPosition >= rightMatches.size) { - rightPosition = 0 - fetchLeft() - if (leftElement == null || keyOrdering.compare(leftKey, matchKey) != 0) { - stop = false - rightMatches = null - } + private[this] val resultProjection: (InternalRow) => InternalRow = + UnsafeProjection.create(schema) + + override def advanceNext(): Boolean = { + if (currentMatchIdx == -1 || currentMatchIdx == currentRightMatches.length) { + if (smjScanner.findNextInnerJoinRows()) { + currentRightMatches = smjScanner.getBufferedMatches + currentLeftRow = smjScanner.getStreamedRow + currentMatchIdx = 0 + } else { + currentRightMatches = null + currentLeftRow = null + currentMatchIdx = -1 } - joinedRow - } else { - // no more result - throw new NoSuchElementException } - } - - private def fetchLeft() = { - if (leftIter.hasNext) { - leftElement = leftIter.next() - leftKey = leftKeyGenerator(leftElement) + if (currentLeftRow != null) { + joinRow(currentLeftRow, currentRightMatches(currentMatchIdx)) + currentMatchIdx += 1 + numOutputRows += 1 + true } else { - leftElement = null + false } } - private def fetchRight() = { - if (rightIter.hasNext) { - rightElement = rightIter.next() - rightKey = rightKeyGenerator(rightElement) - } else { - rightElement = null - } - } + override def getRow: InternalRow = resultProjection(joinRow) + }.toScala + } + } +} + +/** + * Helper class that is used to implement [[SortMergeJoin]] and [[SortMergeOuterJoin]]. + * + * To perform an inner (outer) join, users of this class call [[findNextInnerJoinRows()]] + * ([[findNextOuterJoinRows()]]), which returns `true` if a result has been produced and `false` + * otherwise. If a result has been produced, then the caller may call [[getStreamedRow]] to return + * the matching row from the streamed input and may call [[getBufferedMatches]] to return the + * sequence of matching rows from the buffered input (in the case of an outer join, this will return + * an empty sequence if there are no matches from the buffered input). For efficiency, both of these + * methods return mutable objects which are re-used across calls to the `findNext*JoinRows()` + * methods. + * + * @param streamedKeyGenerator a projection that produces join keys from the streamed input. 
+ * @param bufferedKeyGenerator a projection that produces join keys from the buffered input. + * @param keyOrdering an ordering which can be used to compare join keys. + * @param streamedIter an input whose rows will be streamed. + * @param bufferedIter an input whose rows will be buffered to construct sequences of rows that + * have the same join key. + */ +private[joins] class SortMergeJoinScanner( + streamedKeyGenerator: Projection, + bufferedKeyGenerator: Projection, + keyOrdering: Ordering[InternalRow], + streamedIter: RowIterator, + numStreamedRows: LongSQLMetric, + bufferedIter: RowIterator, + numBufferedRows: LongSQLMetric) { + private[this] var streamedRow: InternalRow = _ + private[this] var streamedRowKey: InternalRow = _ + private[this] var bufferedRow: InternalRow = _ + // Note: this is guaranteed to never have any null columns: + private[this] var bufferedRowKey: InternalRow = _ + /** + * The join key for the rows buffered in `bufferedMatches`, or null if `bufferedMatches` is empty + */ + private[this] var matchJoinKey: InternalRow = _ + /** Buffered rows from the buffered side of the join. This is empty if there are no matches. */ + private[this] val bufferedMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow] + + // Initialization (note: do _not_ want to advance streamed here). + advancedBufferedToRowWithNullFreeJoinKey() - private def initialize() = { - fetchLeft() - fetchRight() + // --- Public methods --------------------------------------------------------------------------- + + def getStreamedRow: InternalRow = streamedRow + + def getBufferedMatches: ArrayBuffer[InternalRow] = bufferedMatches + + /** + * Advances both input iterators, stopping when we have found rows with matching join keys. + * @return true if matching rows have been found and false otherwise. If this returns true, then + * [[getStreamedRow]] and [[getBufferedMatches]] can be called to construct the join + * results. + */ + final def findNextInnerJoinRows(): Boolean = { + while (advancedStreamed() && streamedRowKey.anyNull) { + // Advance the streamed side of the join until we find the next row whose join key contains + // no nulls or we hit the end of the streamed iterator. + } + if (streamedRow == null) { + // We have consumed the entire streamed iterator, so there can be no more matches. + matchJoinKey = null + bufferedMatches.clear() + false + } else if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, matchJoinKey) == 0) { + // The new streamed row has the same join key as the previous row, so return the same matches. + true + } else if (bufferedRow == null) { + // The streamed row's join key does not match the current batch of buffered rows and there are + // no more rows to read from the buffered iterator, so there can be no more matches. + matchJoinKey = null + bufferedMatches.clear() + false + } else { + // Advance both the streamed and buffered iterators to find the next pair of matching rows. + var comp = keyOrdering.compare(streamedRowKey, bufferedRowKey) + do { + if (streamedRowKey.anyNull) { + advancedStreamed() + } else { + assert(!bufferedRowKey.anyNull) + comp = keyOrdering.compare(streamedRowKey, bufferedRowKey) + if (comp > 0) advancedBufferedToRowWithNullFreeJoinKey() + else if (comp < 0) advancedStreamed() } + } while (streamedRow != null && bufferedRow != null && comp != 0) + if (streamedRow == null || bufferedRow == null) { + // We have either hit the end of one of the iterators, so there can be no more matches. 
+ matchJoinKey = null + bufferedMatches.clear() + false + } else { + // The streamed row's join key matches the current buffered row's join, so walk through the + // buffered iterator to buffer the rest of the matching rows. + assert(comp == 0) + bufferMatchingRows() + true + } + } + } - /** - * Searches the right iterator for the next rows that have matches in left side, and store - * them in a buffer. - * - * @return true if the search is successful, and false if the right iterator runs out of - * tuples. - */ - private def nextMatchingPair(): Boolean = { - if (!stop && rightElement != null) { - // run both side to get the first match pair - while (!stop && leftElement != null && rightElement != null) { - val comparing = keyOrdering.compare(leftKey, rightKey) - // for inner join, we need to filter those null keys - stop = comparing == 0 && !leftKey.anyNull - if (comparing > 0 || rightKey.anyNull) { - fetchRight() - } else if (comparing < 0 || leftKey.anyNull) { - fetchLeft() - } - } - rightMatches = new CompactBuffer[InternalRow]() - if (stop) { - stop = false - // iterate the right side to buffer all rows that matches - // as the records should be ordered, exit when we meet the first that not match - while (!stop && rightElement != null) { - rightMatches += rightElement - fetchRight() - stop = keyOrdering.compare(leftKey, rightKey) != 0 - } - if (rightMatches.size > 0) { - rightPosition = 0 - matchKey = leftKey - } - } + /** + * Advances the streamed input iterator and buffers all rows from the buffered input that + * have matching keys. + * @return true if the streamed iterator returned a row, false otherwise. If this returns true, + * then [[getStreamedRow]] and [[getBufferedMatches]] can be called to produce the outer + * join results. + */ + final def findNextOuterJoinRows(): Boolean = { + if (!advancedStreamed()) { + // We have consumed the entire streamed iterator, so there can be no more matches. + matchJoinKey = null + bufferedMatches.clear() + false + } else { + if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, matchJoinKey) == 0) { + // Matches the current group, so do nothing. + } else { + // The streamed row does not match the current group. + matchJoinKey = null + bufferedMatches.clear() + if (bufferedRow != null && !streamedRowKey.anyNull) { + // The buffered iterator could still contain matching rows, so we'll need to walk through + // it until we either find matches or pass where they would be found. + var comp = 1 + do { + comp = keyOrdering.compare(streamedRowKey, bufferedRowKey) + } while (comp > 0 && advancedBufferedToRowWithNullFreeJoinKey()) + if (comp == 0) { + // We have found matches, so buffer them (this updates matchJoinKey) + bufferMatchingRows() + } else { + // We have overshot the position where the row would be found, hence no matches. } - rightMatches != null && rightMatches.size > 0 } } + // If there is a streamed input then we always return true + true } } + + // --- Private methods -------------------------------------------------------------------------- + + /** + * Advance the streamed iterator and compute the new row's join key. + * @return true if the streamed iterator returned a row and false otherwise. 
+ */ + private def advancedStreamed(): Boolean = { + if (streamedIter.advanceNext()) { + streamedRow = streamedIter.getRow + streamedRowKey = streamedKeyGenerator(streamedRow) + numStreamedRows += 1 + true + } else { + streamedRow = null + streamedRowKey = null + false + } + } + + /** + * Advance the buffered iterator until we find a row with join key that does not contain nulls. + * @return true if the buffered iterator returned a row and false otherwise. + */ + private def advancedBufferedToRowWithNullFreeJoinKey(): Boolean = { + var foundRow: Boolean = false + while (!foundRow && bufferedIter.advanceNext()) { + bufferedRow = bufferedIter.getRow + bufferedRowKey = bufferedKeyGenerator(bufferedRow) + numBufferedRows += 1 + foundRow = !bufferedRowKey.anyNull + } + if (!foundRow) { + bufferedRow = null + bufferedRowKey = null + false + } else { + true + } + } + + /** + * Called when the streamed and buffered join keys match in order to buffer the matching rows. + */ + private def bufferMatchingRows(): Unit = { + assert(streamedRowKey != null) + assert(!streamedRowKey.anyNull) + assert(bufferedRowKey != null) + assert(!bufferedRowKey.anyNull) + assert(keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0) + // This join key may have been produced by a mutable projection, so we need to make a copy: + matchJoinKey = streamedRowKey.copy() + bufferedMatches.clear() + do { + bufferedMatches += bufferedRow.copy() // need to copy mutable rows before buffering them + advancedBufferedToRowWithNullFreeJoinKey() + } while (bufferedRow != null && keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala new file mode 100644 index 000000000000..efaa69c1d322 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -0,0 +1,482 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan} +import org.apache.spark.util.collection.BitSet + +/** + * Performs an sort merge outer join of two child relations. 
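// An illustrative usage of the SortMergeJoinScanner contract documented above: drive
// findNextInnerJoinRows() in a loop and read the re-used mutable results through
// getStreamedRow / getBufferedMatches. The parameters of this helper (key generators,
// ordering, iterators, metrics, and the emit callback) are assumptions of the sketch,
// not part of this class's API beyond what its constructor already requires.
def consumeInnerMatches(
    leftKeyGen: Projection,
    rightKeyGen: Projection,
    ordering: Ordering[InternalRow],
    leftIter: Iterator[InternalRow],
    numLeftRows: LongSQLMetric,
    rightIter: Iterator[InternalRow],
    numRightRows: LongSQLMetric)(emit: InternalRow => Unit): Unit = {
  val scanner = new SortMergeJoinScanner(
    leftKeyGen, rightKeyGen, ordering,
    RowIterator.fromScala(leftIter), numLeftRows,
    RowIterator.fromScala(rightIter), numRightRows)
  val joined = new JoinedRow
  while (scanner.findNextInnerJoinRows()) {
    val streamed = scanner.getStreamedRow          // mutable row, re-used across calls
    val matches = scanner.getBufferedMatches       // buffer is re-used too; copy rows if retained
    var i = 0
    while (i < matches.length) {
      emit(joined(streamed, matches(i)))
      i += 1
    }
  }
}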
+ */ +case class SortMergeOuterJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode { + + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + + override def output: Seq[Attribute] = { + joinType match { + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + (left.output ++ right.output).map(_.withNullability(true)) + case x => + throw new IllegalArgumentException( + s"${getClass.getSimpleName} should not take $x as the JoinType") + } + } + + override def outputPartitioning: Partitioning = joinType match { + // For left and right outer joins, the output is partitioned by the streamed input's join keys. + case LeftOuter => left.outputPartitioning + case RightOuter => right.outputPartitioning + case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) + case x => + throw new IllegalArgumentException( + s"${getClass.getSimpleName} should not take $x as the JoinType") + } + + override def outputOrdering: Seq[SortOrder] = joinType match { + // For left and right outer joins, the output is ordered by the streamed input's join keys. + case LeftOuter => requiredOrders(leftKeys) + case RightOuter => requiredOrders(rightKeys) + // there are null rows in both streams, so there is no order + case FullOuter => Nil + case x => throw new IllegalArgumentException( + s"SortMergeOuterJoin should not take $x as the JoinType") + } + + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil + + private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = { + // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`. + keys.map(SortOrder(_, Ascending)) + } + + override def outputsUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = false + + private def createLeftKeyGenerator(): Projection = + UnsafeProjection.create(leftKeys, left.output) + + private def createRightKeyGenerator(): Projection = + UnsafeProjection.create(rightKeys, right.output) + + override def doExecute(): RDD[InternalRow] = { + val numLeftRows = longMetric("numLeftRows") + val numRightRows = longMetric("numRightRows") + val numOutputRows = longMetric("numOutputRows") + + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => + // An ordering that can be used to compare keys from both sides. 
+ val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) + val boundCondition: (InternalRow) => Boolean = { + condition.map { cond => + newPredicate(cond, left.output ++ right.output) + }.getOrElse { + (r: InternalRow) => true + } + } + val resultProj: InternalRow => InternalRow = UnsafeProjection.create(schema) + + joinType match { + case LeftOuter => + val smjScanner = new SortMergeJoinScanner( + streamedKeyGenerator = createLeftKeyGenerator(), + bufferedKeyGenerator = createRightKeyGenerator(), + keyOrdering, + streamedIter = RowIterator.fromScala(leftIter), + numLeftRows, + bufferedIter = RowIterator.fromScala(rightIter), + numRightRows + ) + val rightNullRow = new GenericInternalRow(right.output.length) + new LeftOuterIterator( + smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows).toScala + + case RightOuter => + val smjScanner = new SortMergeJoinScanner( + streamedKeyGenerator = createRightKeyGenerator(), + bufferedKeyGenerator = createLeftKeyGenerator(), + keyOrdering, + streamedIter = RowIterator.fromScala(rightIter), + numRightRows, + bufferedIter = RowIterator.fromScala(leftIter), + numLeftRows + ) + val leftNullRow = new GenericInternalRow(left.output.length) + new RightOuterIterator( + smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows).toScala + + case FullOuter => + val leftNullRow = new GenericInternalRow(left.output.length) + val rightNullRow = new GenericInternalRow(right.output.length) + val smjScanner = new SortMergeFullOuterJoinScanner( + leftKeyGenerator = createLeftKeyGenerator(), + rightKeyGenerator = createRightKeyGenerator(), + keyOrdering, + leftIter = RowIterator.fromScala(leftIter), + numLeftRows, + rightIter = RowIterator.fromScala(rightIter), + numRightRows, + boundCondition, + leftNullRow, + rightNullRow) + + new FullOuterIterator( + smjScanner, + resultProj, + numOutputRows).toScala + + case x => + throw new IllegalArgumentException( + s"SortMergeOuterJoin should not take $x as the JoinType") + } + } + } +} + +/** + * An iterator for outputting rows in left outer join. + */ +private class LeftOuterIterator( + smjScanner: SortMergeJoinScanner, + rightNullRow: InternalRow, + boundCondition: InternalRow => Boolean, + resultProj: InternalRow => InternalRow, + numOutputRows: LongSQLMetric) + extends OneSideOuterIterator( + smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows) { + + protected override def setStreamSideOutput(row: InternalRow): Unit = joinedRow.withLeft(row) + protected override def setBufferedSideOutput(row: InternalRow): Unit = joinedRow.withRight(row) +} + +/** + * An iterator for outputting rows in right outer join. + */ +private class RightOuterIterator( + smjScanner: SortMergeJoinScanner, + leftNullRow: InternalRow, + boundCondition: InternalRow => Boolean, + resultProj: InternalRow => InternalRow, + numOutputRows: LongSQLMetric) + extends OneSideOuterIterator( + smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows) { + + protected override def setStreamSideOutput(row: InternalRow): Unit = joinedRow.withRight(row) + protected override def setBufferedSideOutput(row: InternalRow): Unit = joinedRow.withLeft(row) +} + +/** + * An abstract iterator for sharing code between [[LeftOuterIterator]] and [[RightOuterIterator]]. + * + * Each [[OneSideOuterIterator]] has a streamed side and a buffered side. Each row on the + * streamed side will output 0 or many rows, one for each matching row on the buffered side. 
+ * If there are no matches, then the buffered side of the joined output will be a null row. + * + * In left outer join, the left is the streamed side and the right is the buffered side. + * In right outer join, the right is the streamed side and the left is the buffered side. + * + * @param smjScanner a scanner that streams rows and buffers any matching rows + * @param bufferedSideNullRow the default row to return when a streamed row has no matches + * @param boundCondition an additional filter condition for buffered rows + * @param resultProj how the output should be projected + * @param numOutputRows an accumulator metric for the number of rows output + */ +private abstract class OneSideOuterIterator( + smjScanner: SortMergeJoinScanner, + bufferedSideNullRow: InternalRow, + boundCondition: InternalRow => Boolean, + resultProj: InternalRow => InternalRow, + numOutputRows: LongSQLMetric) extends RowIterator { + + // A row to store the joined result, reused many times + protected[this] val joinedRow: JoinedRow = new JoinedRow() + + // Index of the buffered rows, reset to 0 whenever we advance to a new streamed row + private[this] var bufferIndex: Int = 0 + + // This iterator is initialized lazily so there should be no matches initially + assert(smjScanner.getBufferedMatches.length == 0) + + // Set output methods to be overridden by subclasses + protected def setStreamSideOutput(row: InternalRow): Unit + protected def setBufferedSideOutput(row: InternalRow): Unit + + /** + * Advance to the next row on the stream side and populate the buffer with matches. + * @return whether there are more rows in the stream to consume. + */ + private def advanceStream(): Boolean = { + bufferIndex = 0 + if (smjScanner.findNextOuterJoinRows()) { + setStreamSideOutput(smjScanner.getStreamedRow) + if (smjScanner.getBufferedMatches.isEmpty) { + // There are no matching rows in the buffer, so return the null row + setBufferedSideOutput(bufferedSideNullRow) + } else { + // Find the next row in the buffer that satisfied the bound condition + if (!advanceBufferUntilBoundConditionSatisfied()) { + setBufferedSideOutput(bufferedSideNullRow) + } + } + true + } else { + // Stream has been exhausted + false + } + } + + /** + * Advance to the next row in the buffer that satisfies the bound condition. + * @return whether there is such a row in the current buffer. 
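// An illustrative reduction of the rule described in the OneSideOuterIterator scaladoc above:
// each streamed row yields one output per buffered match that passes the extra condition, and
// falls back to a single null-padded output when nothing passes. Plain Scala with assumed
// placeholder types; the operator itself does this incrementally over re-used mutable rows.
def outerOutputs[A, B](streamed: A, matches: Seq[B], cond: (A, B) => Boolean, nullRow: B): Seq[(A, B)] = {
  val kept = matches.filter(cond(streamed, _)).map(b => (streamed, b))
  if (kept.nonEmpty) kept else Seq((streamed, nullRow))
}
// outerOutputs("a", Seq(1, 2, 3), (_: String, b: Int) => b % 2 == 1, nullRow = -1)
//   returns Seq(("a", 1), ("a", 3))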
+ */ + private def advanceBufferUntilBoundConditionSatisfied(): Boolean = { + var foundMatch: Boolean = false + while (!foundMatch && bufferIndex < smjScanner.getBufferedMatches.length) { + setBufferedSideOutput(smjScanner.getBufferedMatches(bufferIndex)) + foundMatch = boundCondition(joinedRow) + bufferIndex += 1 + } + foundMatch + } + + override def advanceNext(): Boolean = { + val r = advanceBufferUntilBoundConditionSatisfied() || advanceStream() + if (r) numOutputRows += 1 + r + } + + override def getRow: InternalRow = resultProj(joinedRow) +} + +private class SortMergeFullOuterJoinScanner( + leftKeyGenerator: Projection, + rightKeyGenerator: Projection, + keyOrdering: Ordering[InternalRow], + leftIter: RowIterator, + numLeftRows: LongSQLMetric, + rightIter: RowIterator, + numRightRows: LongSQLMetric, + boundCondition: InternalRow => Boolean, + leftNullRow: InternalRow, + rightNullRow: InternalRow) { + private[this] val joinedRow: JoinedRow = new JoinedRow() + private[this] var leftRow: InternalRow = _ + private[this] var leftRowKey: InternalRow = _ + private[this] var rightRow: InternalRow = _ + private[this] var rightRowKey: InternalRow = _ + + private[this] var leftIndex: Int = 0 + private[this] var rightIndex: Int = 0 + private[this] val leftMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow] + private[this] val rightMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow] + private[this] var leftMatched: BitSet = new BitSet(1) + private[this] var rightMatched: BitSet = new BitSet(1) + + advancedLeft() + advancedRight() + + // --- Private methods -------------------------------------------------------------------------- + + /** + * Advance the left iterator and compute the new row's join key. + * @return true if the left iterator returned a row and false otherwise. + */ + private def advancedLeft(): Boolean = { + if (leftIter.advanceNext()) { + leftRow = leftIter.getRow + leftRowKey = leftKeyGenerator(leftRow) + numLeftRows += 1 + true + } else { + leftRow = null + leftRowKey = null + false + } + } + + /** + * Advance the right iterator and compute the new row's join key. + * @return true if the right iterator returned a row and false otherwise. + */ + private def advancedRight(): Boolean = { + if (rightIter.advanceNext()) { + rightRow = rightIter.getRow + rightRowKey = rightKeyGenerator(rightRow) + numRightRows += 1 + true + } else { + rightRow = null + rightRowKey = null + false + } + } + + /** + * Populate the left and right buffers with rows matching the provided key. + * This consumes rows from both iterators until their keys are different from the matching key. + */ + private def findMatchingRows(matchingKey: InternalRow): Unit = { + leftMatches.clear() + rightMatches.clear() + leftIndex = 0 + rightIndex = 0 + + while (leftRowKey != null && keyOrdering.compare(leftRowKey, matchingKey) == 0) { + leftMatches += leftRow.copy() + advancedLeft() + } + while (rightRowKey != null && keyOrdering.compare(rightRowKey, matchingKey) == 0) { + rightMatches += rightRow.copy() + advancedRight() + } + + if (leftMatches.size <= leftMatched.capacity) { + leftMatched.clear() + } else { + leftMatched = new BitSet(leftMatches.size) + } + if (rightMatches.size <= rightMatched.capacity) { + rightMatched.clear() + } else { + rightMatched = new BitSet(rightMatches.size) + } + } + + /** + * Scan the left and right buffers for the next valid match. + * + * Note: this method mutates `joinedRow` to point to the latest matching rows in the buffers. 
+ * If a left row has no valid matches on the right, or a right row has no valid matches on the + * left, then the row is joined with the null row and the result is considered a valid match. + * + * @return true if a valid match is found, false otherwise. + */ + private def scanNextInBuffered(): Boolean = { + while (leftIndex < leftMatches.size) { + while (rightIndex < rightMatches.size) { + joinedRow(leftMatches(leftIndex), rightMatches(rightIndex)) + if (boundCondition(joinedRow)) { + leftMatched.set(leftIndex) + rightMatched.set(rightIndex) + rightIndex += 1 + return true + } + rightIndex += 1 + } + rightIndex = 0 + if (!leftMatched.get(leftIndex)) { + // the left row has never matched any right row, join it with null row + joinedRow(leftMatches(leftIndex), rightNullRow) + leftIndex += 1 + return true + } + leftIndex += 1 + } + + while (rightIndex < rightMatches.size) { + if (!rightMatched.get(rightIndex)) { + // the right row has never matched any left row, join it with null row + joinedRow(leftNullRow, rightMatches(rightIndex)) + rightIndex += 1 + return true + } + rightIndex += 1 + } + + // There are no more valid matches in the left and right buffers + false + } + + // --- Public methods -------------------------------------------------------------------------- + + def getJoinedRow(): JoinedRow = joinedRow + + def advanceNext(): Boolean = { + // If we already buffered some matching rows, use them directly + if (leftIndex <= leftMatches.size || rightIndex <= rightMatches.size) { + if (scanNextInBuffered()) { + return true + } + } + + if (leftRow != null && (leftRowKey.anyNull || rightRow == null)) { + joinedRow(leftRow.copy(), rightNullRow) + advancedLeft() + true + } else if (rightRow != null && (rightRowKey.anyNull || leftRow == null)) { + joinedRow(leftNullRow, rightRow.copy()) + advancedRight() + true + } else if (leftRow != null && rightRow != null) { + // Both rows are present and neither have null values, + // so we populate the buffers with rows matching the next key + val comp = keyOrdering.compare(leftRowKey, rightRowKey) + if (comp <= 0) { + findMatchingRows(leftRowKey.copy()) + } else { + findMatchingRows(rightRowKey.copy()) + } + scanNextInBuffered() + true + } else { + // Both iterators have been consumed + false + } + } +} + +private class FullOuterIterator( + smjScanner: SortMergeFullOuterJoinScanner, + resultProj: InternalRow => InternalRow, + numRows: LongSQLMetric + ) extends RowIterator { + private[this] val joinedRow: JoinedRow = smjScanner.getJoinedRow() + + override def advanceNext(): Boolean = { + val r = smjScanner.advanceNext() + if (r) numRows += 1 + r + } + + override def getRow: InternalRow = resultProj(joinedRow) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala index 7f2ab1765b28..134376628ae7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala @@ -17,21 +17,15 @@ package org.apache.spark.sql.execution -import org.apache.spark.annotation.DeveloperApi - /** - * :: DeveloperApi :: * Physical execution operators for join operations. 
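// An illustrative, plain-Scala reduction of what the SortMergeFullOuterJoinScanner above
// computes (ignoring the extra filter condition): walk two key-sorted inputs, emit the cross
// product of each group of equal keys, and pad the missing side with None where a key only
// appears on one side. Names and types are assumptions for the sketch, not Spark API.
def fullOuter[K: Ordering, A, B](left: Seq[(K, A)], right: Seq[(K, B)]): Seq[(K, Option[A], Option[B])] = {
  val ord = Ordering[K]
  val out = Seq.newBuilder[(K, Option[A], Option[B])]
  var (l, r) = (left, right)
  while (l.nonEmpty || r.nonEmpty) {
    if (r.isEmpty || (l.nonEmpty && ord.lt(l.head._1, r.head._1))) {
      out += ((l.head._1, Some(l.head._2), None)); l = l.tail        // key only on the left
    } else if (l.isEmpty || ord.lt(r.head._1, l.head._1)) {
      out += ((r.head._1, None, Some(r.head._2))); r = r.tail        // key only on the right
    } else {
      val k = l.head._1                                              // equal keys: buffer both groups
      val (lg, lRest) = l.span(p => ord.equiv(p._1, k))
      val (rg, rRest) = r.span(p => ord.equiv(p._1, k))
      for (a <- lg; b <- rg) out += ((k, Some(a._2), Some(b._2)))
      l = lRest; r = rRest
    }
  }
  out.result()
}
// fullOuter(Seq(1 -> "a", 2 -> "b"), Seq(2 -> "x", 3 -> "y")) returns
//   Seq((1, Some("a"), None), (2, Some("b"), Some("x")), (3, None, Some("y")))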
*/ package object joins { - @DeveloperApi sealed abstract class BuildSide - @DeveloperApi case object BuildRight extends BuildSide - @DeveloperApi case object BuildLeft extends BuildSide } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala new file mode 100644 index 000000000000..3dcef9409564 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala @@ -0,0 +1,72 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.joins.{HashedRelation, BuildLeft, BuildRight, BuildSide} + +/** + * A [[HashJoinNode]] that builds the [[HashedRelation]] according to the value of + * `buildSide`. The actual work of this node is defined in [[HashJoinNode]]. + */ +case class BinaryHashJoinNode( + conf: SQLConf, + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + buildSide: BuildSide, + left: LocalNode, + right: LocalNode) + extends BinaryLocalNode(conf) with HashJoinNode { + + protected override val (streamedNode, streamedKeys) = buildSide match { + case BuildLeft => (right, rightKeys) + case BuildRight => (left, leftKeys) + } + + private val (buildNode, buildKeys) = buildSide match { + case BuildLeft => (left, leftKeys) + case BuildRight => (right, rightKeys) + } + + override def output: Seq[Attribute] = left.output ++ right.output + + private def buildSideKeyGenerator: Projection = { + // We are expecting the data types of buildKeys and streamedKeys are the same. + assert(buildKeys.map(_.dataType) == streamedKeys.map(_.dataType)) + UnsafeProjection.create(buildKeys, buildNode.output) + } + + protected override def doOpen(): Unit = { + buildNode.open() + val hashedRelation = HashedRelation(buildNode, buildSideKeyGenerator) + // We have built the HashedRelation. So, close buildNode. + buildNode.close() + + streamedNode.open() + // Set the HashedRelation used by the HashJoinNode. + withHashedRelation(hashedRelation) + } + + override def close(): Unit = { + // Please note that we do not need to call the close method of our buildNode because + // it has been called in this.open. 
+ streamedNode.close() + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BroadcastHashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BroadcastHashJoinNode.scala new file mode 100644 index 000000000000..cd1c86516ec5 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BroadcastHashJoinNode.scala @@ -0,0 +1,59 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide, HashedRelation} + +/** + * A [[HashJoinNode]] for broadcast join. It takes a streamedNode and a broadcast + * [[HashedRelation]]. The actual work of this node is defined in [[HashJoinNode]]. + */ +case class BroadcastHashJoinNode( + conf: SQLConf, + streamedKeys: Seq[Expression], + streamedNode: LocalNode, + buildSide: BuildSide, + buildOutput: Seq[Attribute], + hashedRelation: Broadcast[HashedRelation]) + extends UnaryLocalNode(conf) with HashJoinNode { + + override val child = streamedNode + + // Because we do not pass in the buildNode, we take the output of buildNode to + // create the inputSet properly. + override def inputSet: AttributeSet = AttributeSet(child.output ++ buildOutput) + + override def output: Seq[Attribute] = buildSide match { + case BuildRight => streamedNode.output ++ buildOutput + case BuildLeft => buildOutput ++ streamedNode.output + } + + protected override def doOpen(): Unit = { + streamedNode.open() + // Set the HashedRelation used by the HashJoinNode. + withHashedRelation(hashedRelation.value) + } + + override def close(): Unit = { + streamedNode.close() + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToSafeNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToSafeNode.scala new file mode 100644 index 000000000000..b31c5a863832 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToSafeNode.scala @@ -0,0 +1,40 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. 
You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, FromUnsafeProjection, Projection} + +case class ConvertToSafeNode(conf: SQLConf, child: LocalNode) extends UnaryLocalNode(conf) { + + override def output: Seq[Attribute] = child.output + + private[this] var convertToSafe: Projection = _ + + override def open(): Unit = { + child.open() + convertToSafe = FromUnsafeProjection(child.schema) + } + + override def next(): Boolean = child.next() + + override def fetch(): InternalRow = convertToSafe(child.fetch()) + + override def close(): Unit = child.close() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToUnsafeNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToUnsafeNode.scala new file mode 100644 index 000000000000..de2f4e661ab4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToUnsafeNode.scala @@ -0,0 +1,40 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Projection, UnsafeProjection} + +case class ConvertToUnsafeNode(conf: SQLConf, child: LocalNode) extends UnaryLocalNode(conf) { + + override def output: Seq[Attribute] = child.output + + private[this] var convertToUnsafe: Projection = _ + + override def open(): Unit = { + child.open() + convertToUnsafe = UnsafeProjection.create(child.schema) + } + + override def next(): Boolean = child.next() + + override def fetch(): InternalRow = convertToUnsafe(child.fetch()) + + override def close(): Unit = child.close() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala new file mode 100644 index 000000000000..85111bd6d1c9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala @@ -0,0 +1,60 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. 
See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ + +case class ExpandNode( + conf: SQLConf, + projections: Seq[Seq[Expression]], + output: Seq[Attribute], + child: LocalNode) extends UnaryLocalNode(conf) { + + assert(projections.size > 0) + + private[this] var result: InternalRow = _ + private[this] var idx: Int = _ + private[this] var input: InternalRow = _ + private[this] var groups: Array[Projection] = _ + + override def open(): Unit = { + child.open() + groups = projections.map(ee => newMutableProjection(ee, child.output)()).toArray + idx = groups.length + } + + override def next(): Boolean = { + if (idx >= groups.length) { + if (child.next()) { + input = child.fetch() + idx = 0 + } else { + return false + } + } + result = groups(idx)(input) + idx += 1 + true + } + + override def fetch(): InternalRow = result + + override def close(): Unit = child.close() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala new file mode 100644 index 000000000000..dd1113b6726c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala @@ -0,0 +1,49 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
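// A compact sketch of what ExpandNode above computes: each input row is emitted once per
// projection, so an input of n rows produces n * projections.size output rows (the shape that
// GROUPING SETS style plans typically rely on). Plain Scala stand-ins for rows and projections.
def expand[R, O](input: Iterator[R], projections: Seq[R => O]): Iterator[O] =
  input.flatMap(row => projections.iterator.map(_(row)))
// expand(Iterator(1, 2), Seq[Int => (Int, String)](i => (i, "a"), i => (i, "b"))).toList
//   returns List((1, "a"), (1, "b"), (2, "a"), (2, "b"))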
+*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate + + +case class FilterNode(conf: SQLConf, condition: Expression, child: LocalNode) + extends UnaryLocalNode(conf) { + + private[this] var predicate: (InternalRow) => Boolean = _ + + override def output: Seq[Attribute] = child.output + + override def open(): Unit = { + child.open() + predicate = GeneratePredicate.generate(condition, child.output) + } + + override def next(): Boolean = { + var found = false + while (!found && child.next()) { + found = predicate.apply(child.fetch()) + } + found + } + + override def fetch(): InternalRow = child.fetch() + + override def close(): Unit = child.close() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala new file mode 100644 index 000000000000..fd7948ffa9a9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala @@ -0,0 +1,111 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.joins._ + +/** + * An abstract node for sharing common functionality among different implementations of + * inner hash equi-join, notably [[BinaryHashJoinNode]] and [[BroadcastHashJoinNode]]. + * + * Much of this code is similar to [[org.apache.spark.sql.execution.joins.HashJoin]]. + */ +trait HashJoinNode { + + self: LocalNode => + + protected def streamedKeys: Seq[Expression] + protected def streamedNode: LocalNode + protected def buildSide: BuildSide + + private[this] var currentStreamedRow: InternalRow = _ + private[this] var currentHashMatches: Seq[InternalRow] = _ + private[this] var currentMatchPosition: Int = -1 + + private[this] var joinRow: JoinedRow = _ + private[this] var resultProjection: (InternalRow) => InternalRow = _ + + private[this] var hashed: HashedRelation = _ + private[this] var joinKeys: Projection = _ + + private def streamSideKeyGenerator: Projection = + UnsafeProjection.create(streamedKeys, streamedNode.output) + + /** + * Sets the HashedRelation used by this node. This method needs to be called before + * the first `next` gets called. + */ + protected def withHashedRelation(hashedRelation: HashedRelation): Unit = { + hashed = hashedRelation + } + + /** + * Custom open implementation to be overridden by subclasses.
+ */ + protected def doOpen(): Unit + + override def open(): Unit = { + doOpen() + joinRow = new JoinedRow + resultProjection = UnsafeProjection.create(schema) + joinKeys = streamSideKeyGenerator + } + + override def next(): Boolean = { + currentMatchPosition += 1 + if (currentHashMatches == null || currentMatchPosition >= currentHashMatches.size) { + fetchNextMatch() + } else { + true + } + } + + /** + * Populate `currentHashMatches` with build-side rows matching the next streamed row. + * @return whether matches are found such that subsequent calls to `fetch` are valid. + */ + private def fetchNextMatch(): Boolean = { + currentHashMatches = null + currentMatchPosition = -1 + + while (currentHashMatches == null && streamedNode.next()) { + currentStreamedRow = streamedNode.fetch() + val key = joinKeys(currentStreamedRow) + if (!key.anyNull) { + currentHashMatches = hashed.get(key) + } + } + + if (currentHashMatches == null) { + false + } else { + currentMatchPosition = 0 + true + } + } + + override def fetch(): InternalRow = { + val ret = buildSide match { + case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition)) + case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) + } + resultProjection(ret) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala new file mode 100644 index 000000000000..740d485f8d9e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala @@ -0,0 +1,63 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
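// A plain-Scala sketch of the probe pattern HashJoinNode above implements: hash the build side
// into a multi-map once, then stream the other side and emit one joined pair per match,
// skipping rows whose join key is null. Names and types here are illustrative placeholders.
def hashInnerJoin[K, A, B](build: Seq[(K, B)], streamed: Iterator[(K, A)]): Iterator[(A, B)] = {
  val hashed: Map[K, Seq[B]] = build.groupBy(_._1).map { case (k, rows) => k -> rows.map(_._2) }
  streamed.flatMap { case (k, a) =>
    if (k == null) Iterator.empty                              // null join keys never match
    else hashed.getOrElse(k, Seq.empty).iterator.map(b => (a, b))
  }
}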
+*/ + +package org.apache.spark.sql.execution.local + +import scala.collection.mutable + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute + +case class IntersectNode(conf: SQLConf, left: LocalNode, right: LocalNode) + extends BinaryLocalNode(conf) { + + override def output: Seq[Attribute] = left.output + + private[this] var leftRows: mutable.HashSet[InternalRow] = _ + + private[this] var currentRow: InternalRow = _ + + override def open(): Unit = { + left.open() + leftRows = mutable.HashSet[InternalRow]() + while (left.next()) { + leftRows += left.fetch().copy() + } + left.close() + right.open() + } + + override def next(): Boolean = { + currentRow = null + while (currentRow == null && right.next()) { + currentRow = right.fetch() + if (!leftRows.contains(currentRow)) { + currentRow = null + } + } + currentRow != null + } + + override def fetch(): InternalRow = currentRow + + override def close(): Unit = { + left.close() + right.close() + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala new file mode 100644 index 000000000000..401b10a5ed30 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala @@ -0,0 +1,46 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute + + +case class LimitNode(conf: SQLConf, limit: Int, child: LocalNode) extends UnaryLocalNode(conf) { + + private[this] var count = 0 + + override def output: Seq[Attribute] = child.output + + override def open(): Unit = child.open() + + override def close(): Unit = child.close() + + override def fetch(): InternalRow = child.fetch() + + override def next(): Boolean = { + if (count < limit) { + count += 1 + child.next() + } else { + false + } + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala new file mode 100644 index 000000000000..6a882c9234df --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -0,0 +1,190 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. 
+* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import scala.util.control.NonFatal + +import org.apache.spark.Logging +import org.apache.spark.sql.{SQLConf, Row} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.types.StructType + +/** + * A local physical operator, in the form of an iterator. + * + * Before consuming the iterator, open function must be called. + * After consuming the iterator, close function must be called. + */ +abstract class LocalNode(conf: SQLConf) extends QueryPlan[LocalNode] with Logging { + + private[this] lazy val isTesting: Boolean = sys.props.contains("spark.testing") + + /** + * Called before open(). Prepare can be used to reserve memory needed. It must NOT consume + * any input data. + * + * Implementations of this must also call the `prepare()` function of its children. + */ + def prepare(): Unit = children.foreach(_.prepare()) + + /** + * Initializes the iterator state. Must be called before calling `next()`. + * + * Implementations of this must also call the `open()` function of its children. + */ + def open(): Unit + + /** + * Advances the iterator to the next tuple. Returns true if there is at least one more tuple. + */ + def next(): Boolean + + /** + * Returns the current tuple. + */ + def fetch(): InternalRow + + /** + * Closes the iterator and releases all resources. It should be idempotent. + * + * Implementations of this must also call the `close()` function of its children. + */ + def close(): Unit + + /** Specifies whether this operator outputs UnsafeRows */ + def outputsUnsafeRows: Boolean = false + + /** Specifies whether this operator is capable of processing UnsafeRows */ + def canProcessUnsafeRows: Boolean = false + + /** + * Specifies whether this operator is capable of processing Java-object-based Rows (i.e. rows + * that are not UnsafeRows). + */ + def canProcessSafeRows: Boolean = true + + /** + * Returns the content through the [[Iterator]] interface. + */ + final def asIterator: Iterator[InternalRow] = new LocalNodeIterator(this) + + /** + * Returns the content of the iterator from the beginning to the end in the form of a Scala Seq. 
+ */ + final def collect(): Seq[Row] = { + val converter = CatalystTypeConverters.createToScalaConverter(StructType.fromAttributes(output)) + val result = new scala.collection.mutable.ArrayBuffer[Row] + open() + try { + while (next()) { + result += converter.apply(fetch()).asInstanceOf[Row] + } + } finally { + close() + } + result + } + + protected def newMutableProjection( + expressions: Seq[Expression], + inputSchema: Seq[Attribute]): () => MutableProjection = { + log.debug( + s"Creating MutableProj: $expressions, inputSchema: $inputSchema") + try { + GenerateMutableProjection.generate(expressions, inputSchema) + } catch { + case NonFatal(e) => + if (isTesting) { + throw e + } else { + log.error("Failed to generate mutable projection, fallback to interpreted", e) + () => new InterpretedMutableProjection(expressions, inputSchema) + } + } + } + + protected def newPredicate( + expression: Expression, + inputSchema: Seq[Attribute]): (InternalRow) => Boolean = { + try { + GeneratePredicate.generate(expression, inputSchema) + } catch { + case NonFatal(e) => + if (isTesting) { + throw e + } else { + log.error("Failed to generate predicate, fallback to interpreted", e) + InterpretedPredicate.create(expression, inputSchema) + } + } + } +} + + +abstract class LeafLocalNode(conf: SQLConf) extends LocalNode(conf) { + override def children: Seq[LocalNode] = Seq.empty +} + + +abstract class UnaryLocalNode(conf: SQLConf) extends LocalNode(conf) { + + def child: LocalNode + + override def children: Seq[LocalNode] = Seq(child) +} + +abstract class BinaryLocalNode(conf: SQLConf) extends LocalNode(conf) { + + def left: LocalNode + + def right: LocalNode + + override def children: Seq[LocalNode] = Seq(left, right) +} + +/** + * An thin wrapper around a [[LocalNode]] that provides an `Iterator` interface. + */ +private class LocalNodeIterator(localNode: LocalNode) extends Iterator[InternalRow] { + private var nextRow: InternalRow = _ + + override def hasNext: Boolean = { + if (nextRow == null) { + val res = localNode.next() + if (res) { + nextRow = localNode.fetch() + } + res + } else { + true + } + } + + override def next(): InternalRow = { + if (hasNext) { + val res = nextRow + nextRow = null + res + } else { + throw new NoSuchElementException + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala new file mode 100644 index 000000000000..7321fc66b4dd --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
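// A sketch of the open()/next()/fetch()/close() protocol documented on LocalNode above; this is
// essentially what LocalNodeIterator and collect() already do. `handleRow` is an assumed row
// consumer, and the copy() guards against operators that re-use a mutable output row.
def drain(node: LocalNode)(handleRow: InternalRow => Unit): Unit = {
  node.prepare()
  node.open()
  try {
    while (node.next()) {
      handleRow(node.fetch().copy())
    }
  } finally {
    node.close()   // close() is expected to be idempotent
  }
}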
+ */ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.{FullOuter, RightOuter, LeftOuter, JoinType} +import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} +import org.apache.spark.util.collection.{BitSet, CompactBuffer} + +case class NestedLoopJoinNode( + conf: SQLConf, + left: LocalNode, + right: LocalNode, + buildSide: BuildSide, + joinType: JoinType, + condition: Option[Expression]) extends BinaryLocalNode(conf) { + + override def output: Seq[Attribute] = { + joinType match { + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) + case x => + throw new IllegalArgumentException( + s"NestedLoopJoin should not take $x as the JoinType") + } + } + + private[this] def genResultProjection: InternalRow => InternalRow = { + if (outputsUnsafeRows) { + UnsafeProjection.create(schema) + } else { + identity[InternalRow] + } + } + + private[this] var currentRow: InternalRow = _ + + private[this] var iterator: Iterator[InternalRow] = _ + + override def open(): Unit = { + val (streamed, build) = buildSide match { + case BuildRight => (left, right) + case BuildLeft => (right, left) + } + build.open() + val buildRelation = new CompactBuffer[InternalRow] + while (build.next()) { + buildRelation += build.fetch().copy() + } + build.close() + + val boundCondition = + newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) + + val leftNulls = new GenericMutableRow(left.output.size) + val rightNulls = new GenericMutableRow(right.output.size) + val joinedRow = new JoinedRow + val matchedBuildTuples = new BitSet(buildRelation.size) + val resultProj = genResultProjection + streamed.open() + + // streamedRowMatches also contains null rows if using outer join + val streamedRowMatches: Iterator[InternalRow] = streamed.asIterator.flatMap { streamedRow => + val matchedRows = new CompactBuffer[InternalRow] + + var i = 0 + var streamRowMatched = false + + // Scan the build relation to look for matches for each streamed row + while (i < buildRelation.size) { + val buildRow = buildRelation(i) + buildSide match { + case BuildRight => joinedRow(streamedRow, buildRow) + case BuildLeft => joinedRow(buildRow, streamedRow) + } + if (boundCondition(joinedRow)) { + matchedRows += resultProj(joinedRow).copy() + streamRowMatched = true + matchedBuildTuples.set(i) + } + i += 1 + } + + // If this row had no matches and we're using outer join, join it with the null rows + if (!streamRowMatched) { + (joinType, buildSide) match { + case (LeftOuter | FullOuter, BuildRight) => + matchedRows += resultProj(joinedRow(streamedRow, rightNulls)).copy() + case (RightOuter | FullOuter, BuildLeft) => + matchedRows += resultProj(joinedRow(leftNulls, streamedRow)).copy() + case _ => + } + } + + matchedRows.iterator + } + + // If we're using outer join, find rows on the build side that didn't match anything + // and join them with the null row + lazy val unmatchedBuildRows: Iterator[InternalRow] = { + var i = 0 + buildRelation.filter { row => + val r = !matchedBuildTuples.get(i) + i += 1 + r + }.iterator + } + iterator = (joinType, buildSide) match { + case (RightOuter | FullOuter, BuildRight) => + 
streamedRowMatches ++ + unmatchedBuildRows.map { buildRow => resultProj(joinedRow(leftNulls, buildRow)) } + case (LeftOuter | FullOuter, BuildLeft) => + streamedRowMatches ++ + unmatchedBuildRows.map { buildRow => resultProj(joinedRow(buildRow, rightNulls)) } + case _ => streamedRowMatches + } + } + + override def next(): Boolean = { + if (iterator.hasNext) { + currentRow = iterator.next() + true + } else { + false + } + } + + override def fetch(): InternalRow = currentRow + + override def close(): Unit = { + left.close() + right.close() + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala new file mode 100644 index 000000000000..11529d6dd9b8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala @@ -0,0 +1,44 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, Attribute, NamedExpression} + + +case class ProjectNode(conf: SQLConf, projectList: Seq[NamedExpression], child: LocalNode) + extends UnaryLocalNode(conf) { + + private[this] var project: UnsafeProjection = _ + + override def output: Seq[Attribute] = projectList.map(_.toAttribute) + + override def open(): Unit = { + project = UnsafeProjection.create(projectList, child.output) + child.open() + } + + override def next(): Boolean = child.next() + + override def fetch(): InternalRow = { + project.apply(child.fetch()) + } + + override def close(): Unit = child.close() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala new file mode 100644 index 000000000000..793700803f21 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
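// An illustrative, plain-Scala version of the bookkeeping NestedLoopJoinNode above does for a
// full outer join with the build side on the right: remember which build rows ever matched,
// pad streamed rows that found nothing, then append the never-matched build rows padded on the
// other side (None stands in for the null rows; the real node uses a BitSet over a CompactBuffer).
def nestedLoopFullOuter[A, B](
    streamed: Seq[A],
    build: IndexedSeq[B],
    cond: (A, B) => Boolean): Seq[(Option[A], Option[B])] = {
  val matchedBuild = Array.fill(build.length)(false)
  val streamSide = streamed.flatMap { s =>
    val hits = build.indices.collect {
      case i if cond(s, build(i)) =>
        matchedBuild(i) = true
        (Some(s), Some(build(i)))
    }
    if (hits.nonEmpty) hits else Seq((Some(s), None))   // streamed row with no match
  }
  streamSide ++ build.indices.collect { case i if !matchedBuild(i) => (None, Some(build(i))) }
}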
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler} + + +/** + * Sample the dataset. + * + * @param conf the SQLConf + * @param lowerBound Lower-bound of the sampling probability (usually 0.0) + * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled + * will be ub - lb. + * @param withReplacement Whether to sample with replacement. + * @param seed the random seed + * @param child the LocalNode + */ +case class SampleNode( + conf: SQLConf, + lowerBound: Double, + upperBound: Double, + withReplacement: Boolean, + seed: Long, + child: LocalNode) extends UnaryLocalNode(conf) { + + override def output: Seq[Attribute] = child.output + + private[this] var iterator: Iterator[InternalRow] = _ + + private[this] var currentRow: InternalRow = _ + + override def open(): Unit = { + child.open() + val sampler = + if (withReplacement) { + // Disable gap sampling since the gap sampling method buffers two rows internally, + // requiring us to copy the row, which is more expensive than the random number generator. + new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false) + } else { + new BernoulliCellSampler[InternalRow](lowerBound, upperBound) + } + sampler.setSeed(seed) + iterator = sampler.sample(child.asIterator) + } + + override def next(): Boolean = { + if (iterator.hasNext) { + currentRow = iterator.next() + true + } else { + false + } + } + + override def fetch(): InternalRow = currentRow + + override def close(): Unit = child.close() + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala new file mode 100644 index 000000000000..b8467f6ae58e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala @@ -0,0 +1,51 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute + +/** + * An operator that scans some local data collection in the form of Scala Seq. 
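
When sampling without replacement, the SampleNode above delegates to BernoulliCellSampler: a row is kept when a uniform draw lands in [lowerBound, upperBound), so the expected fraction kept is upperBound - lowerBound. A rough standalone illustration of that idea, using scala.util.Random rather than Spark's sampler classes:

import scala.util.Random

// Keeps each element independently when a uniform draw falls inside [lb, ub).
// Illustrative only; production code should use org.apache.spark.util.random.BernoulliCellSampler.
def bernoulliCellSample[T](rows: Iterator[T], lb: Double, ub: Double, seed: Long): Iterator[T] = {
  val rng = new Random(seed)
  rows.filter { _ =>
    val x = rng.nextDouble()
    x >= lb && x < ub
  }
}
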
+ */ +case class SeqScanNode(conf: SQLConf, output: Seq[Attribute], data: Seq[InternalRow]) + extends LeafLocalNode(conf) { + + private[this] var iterator: Iterator[InternalRow] = _ + private[this] var currentRow: InternalRow = _ + + override def open(): Unit = { + iterator = data.iterator + } + + override def next(): Boolean = { + if (iterator.hasNext) { + currentRow = iterator.next() + true + } else { + false + } + } + + override def fetch(): InternalRow = currentRow + + override def close(): Unit = { + // Do nothing + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala new file mode 100644 index 000000000000..ae672fbca8d8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.util.BoundedPriorityQueue + +case class TakeOrderedAndProjectNode( + conf: SQLConf, + limit: Int, + sortOrder: Seq[SortOrder], + projectList: Option[Seq[NamedExpression]], + child: LocalNode) extends UnaryLocalNode(conf) { + + private[this] var projection: Option[Projection] = _ + private[this] var ord: InterpretedOrdering = _ + private[this] var iterator: Iterator[InternalRow] = _ + private[this] var currentRow: InternalRow = _ + + override def output: Seq[Attribute] = { + val projectOutput = projectList.map(_.map(_.toAttribute)) + projectOutput.getOrElse(child.output) + } + + override def open(): Unit = { + child.open() + projection = projectList.map(new InterpretedProjection(_, child.output)) + ord = new InterpretedOrdering(sortOrder, child.output) + // Priority keeps the largest elements, so let's reverse the ordering. + val queue = new BoundedPriorityQueue[InternalRow](limit)(ord.reverse) + while (child.next()) { + queue += child.fetch() + } + // Close it eagerly since we don't need it. 
+ child.close() + iterator = queue.toArray.sorted(ord).iterator + } + + override def next(): Boolean = { + if (iterator.hasNext) { + val _currentRow = iterator.next() + currentRow = projection match { + case Some(p) => p(_currentRow) + case None => _currentRow + } + true + } else { + false + } + } + + override def fetch(): InternalRow = currentRow + + override def close(): Unit = child.close() + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala new file mode 100644 index 000000000000..0f2b8303e737 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala @@ -0,0 +1,73 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute + +case class UnionNode(conf: SQLConf, children: Seq[LocalNode]) extends LocalNode(conf) { + + override def output: Seq[Attribute] = children.head.output + + private[this] var currentChild: LocalNode = _ + + private[this] var nextChildIndex: Int = _ + + override def open(): Unit = { + currentChild = children.head + currentChild.open() + nextChildIndex = 1 + } + + private def advanceToNextChild(): Boolean = { + var found = false + var exit = false + while (!exit && !found) { + if (currentChild != null) { + currentChild.close() + } + if (nextChildIndex >= children.size) { + found = false + exit = true + } else { + currentChild = children(nextChildIndex) + nextChildIndex += 1 + currentChild.open() + found = currentChild.next() + } + } + found + } + + override def close(): Unit = { + if (currentChild != null) { + currentChild.close() + } + } + + override def fetch(): InternalRow = currentChild.fetch() + + override def next(): Boolean = { + if (currentChild.next()) { + true + } else { + advanceToNextChild() + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala new file mode 100644 index 000000000000..2708219ad348 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
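
The TakeOrderedAndProjectNode above keeps only the top `limit` rows by pushing every input row through a bounded priority queue ordered by the reverse of the sort order, so the head is always the current worst candidate. A self-contained sketch of the same top-K idea, using scala.collection.mutable.PriorityQueue in place of Spark's BoundedPriorityQueue:

import scala.collection.mutable

def takeOrdered[T](input: Iterator[T], limit: Int)(implicit ord: Ordering[T]): Seq[T] = {
  // PriorityQueue keeps the largest element (per `ord`) at the head, so the head is the
  // first candidate to evict when a smaller element arrives.
  val queue = mutable.PriorityQueue.empty[T](ord)
  input.foreach { x =>
    if (queue.size < limit) {
      queue.enqueue(x)
    } else if (limit > 0 && ord.lt(x, queue.head)) {
      queue.dequeue()
      queue.enqueue(x)
    }
  }
  queue.dequeueAll.reverse  // dequeueAll yields largest-first; reverse gives ascending order
}

// e.g. takeOrdered(Iterator(5, 1, 4, 2, 3), 3) == Seq(1, 2, 3)
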
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.metric + +import org.apache.spark.annotation.DeveloperApi + +/** + * :: DeveloperApi :: + * Stores information about a SQL Metric. + */ +@DeveloperApi +class SQLMetricInfo( + val name: String, + val accumulatorId: Long, + val metricParam: String) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala new file mode 100644 index 000000000000..6c0f6f8a52dc --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -0,0 +1,168 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.metric + +import org.apache.spark.util.Utils +import org.apache.spark.{Accumulable, AccumulableParam, SparkContext} + +/** + * Create a layer for specialized metric. We cannot add `@specialized` to + * `Accumulable/AccumulableParam` because it will break Java source compatibility. + * + * An implementation of SQLMetric should override `+=` and `add` to avoid boxing. + */ +private[sql] abstract class SQLMetric[R <: SQLMetricValue[T], T]( + name: String, val param: SQLMetricParam[R, T]) + extends Accumulable[R, T](param.zero, param, Some(name), true) { + + def reset(): Unit = { + this.value = param.zero + } +} + +/** + * Create a layer for specialized metric. We cannot add `@specialized` to + * `Accumulable/AccumulableParam` because it will break Java source compatibility. + */ +private[sql] trait SQLMetricParam[R <: SQLMetricValue[T], T] extends AccumulableParam[R, T] { + + /** + * A function that defines how we aggregate the final accumulator results among all tasks, + * and represent it in string for a SQL physical operator. + */ + val stringValue: Seq[T] => String + + def zero: R +} + +/** + * Create a layer for specialized metric. We cannot add `@specialized` to + * `Accumulable/AccumulableParam` because it will break Java source compatibility. 
+ */ +private[sql] trait SQLMetricValue[T] extends Serializable { + + def value: T + + override def toString: String = value.toString +} + +/** + * A wrapper of Long to avoid boxing and unboxing when using Accumulator + */ +private[sql] class LongSQLMetricValue(private var _value : Long) extends SQLMetricValue[Long] { + + def add(incr: Long): LongSQLMetricValue = { + _value += incr + this + } + + // Although there is a boxing here, it's fine because it's only called in SQLListener + override def value: Long = _value +} + +/** + * A specialized long Accumulable to avoid boxing and unboxing when using Accumulator's + * `+=` and `add`. + */ +private[sql] class LongSQLMetric private[metric](name: String, param: LongSQLMetricParam) + extends SQLMetric[LongSQLMetricValue, Long](name, param) { + + override def +=(term: Long): Unit = { + localValue.add(term) + } + + override def add(term: Long): Unit = { + localValue.add(term) + } +} + +private class LongSQLMetricParam(val stringValue: Seq[Long] => String, initialValue: Long) + extends SQLMetricParam[LongSQLMetricValue, Long] { + + override def addAccumulator(r: LongSQLMetricValue, t: Long): LongSQLMetricValue = r.add(t) + + override def addInPlace(r1: LongSQLMetricValue, r2: LongSQLMetricValue): LongSQLMetricValue = + r1.add(r2.value) + + override def zero(initialValue: LongSQLMetricValue): LongSQLMetricValue = zero + + override def zero: LongSQLMetricValue = new LongSQLMetricValue(initialValue) +} + +private object LongSQLMetricParam extends LongSQLMetricParam(_.sum.toString, 0L) + +private object StaticsLongSQLMetricParam extends LongSQLMetricParam( + (values: Seq[Long]) => { + // This is a workaround for SPARK-11013. + // We use -1 as initial value of the accumulator, if the accumulator is valid, we will update + // it at the end of task and the value will be at least 0. + val validValues = values.filter(_ >= 0) + val Seq(sum, min, med, max) = { + val metric = if (validValues.length == 0) { + Seq.fill(4)(0L) + } else { + val sorted = validValues.sorted + Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) + } + metric.map(Utils.bytesToString) + } + s"\n$sum ($min, $med, $max)" + }, -1L) + +private[sql] object SQLMetrics { + + private def createLongMetric( + sc: SparkContext, + name: String, + param: LongSQLMetricParam): LongSQLMetric = { + val acc = new LongSQLMetric(name, param) + sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc)) + acc + } + + def createLongMetric(sc: SparkContext, name: String): LongSQLMetric = { + createLongMetric(sc, name, LongSQLMetricParam) + } + + /** + * Create a metric to report the size information (including total, min, med, max) like data size, + * spill size, etc. 
+ */ + def createSizeMetric(sc: SparkContext, name: String): LongSQLMetric = { + // The final result of this metric in physical operator UI may looks like: + // data size total (min, med, max): + // 100GB (100MB, 1GB, 10GB) + createLongMetric(sc, s"$name total (min, med, max)", StaticsLongSQLMetricParam) + } + + def getMetricParam(metricParamName: String): SQLMetricParam[SQLMetricValue[Any], Any] = { + val longSQLMetricParam = Utils.getFormattedClassName(LongSQLMetricParam) + val staticsSQLMetricParam = Utils.getFormattedClassName(StaticsLongSQLMetricParam) + val metricParam = metricParamName match { + case `longSQLMetricParam` => LongSQLMetricParam + case `staticsSQLMetricParam` => StaticsLongSQLMetricParam + } + metricParam.asInstanceOf[SQLMetricParam[SQLMetricValue[Any], Any]] + } + + /** + * A metric that its value will be ignored. Use this one when we need a metric parameter but don't + * care about the value. + */ + val nullLongMetric = new LongSQLMetric("null", LongSQLMetricParam) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala index 66237f8f1314..c912734bba9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala @@ -18,12 +18,8 @@ package org.apache.spark.sql /** - * :: DeveloperApi :: - * An execution engine for relational query plans that runs on top Spark and returns RDDs. - * - * Note that the operators in this package are created automatically by a query planner using a - * [[SQLContext]] and are not intended to be used directly by end users of Spark SQL. They are - * documented here in order to make it easier for others to understand the performance - * characteristics of query plans that are generated by Spark SQL. + * The physical execution component of Spark SQL. Note that this is a private package. + * All classes in catalyst are considered an internal API to Spark SQL and are subject + * to change between minor releases. 
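
The size metric created above is rendered as "total (min, med, max)" from the per-task accumulator values. A standalone sketch of that aggregation, mirroring StaticsLongSQLMetricParam: -1 is the sentinel for tasks that never updated the accumulator and is filtered out first (the real code additionally formats each number with Utils.bytesToString).

def summarizeSizeMetric(values: Seq[Long]): String = {
  // -1 marks accumulators that were never updated (see the SPARK-11013 note above); drop them.
  val valid = values.filter(_ >= 0)
  val Seq(sum, min, med, max) =
    if (valid.isEmpty) {
      Seq.fill(4)(0L)
    } else {
      val sorted = valid.sorted
      Seq(sorted.sum, sorted.head, sorted(sorted.length / 2), sorted.last)
    }
  s"$sum ($min, $med, $max)"
}
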
*/ package object execution diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala similarity index 79% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala index aade2e769ccd..defcec95fb55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala @@ -20,23 +20,23 @@ package org.apache.spark.sql.execution import java.io.OutputStream import java.util.{List => JList, Map => JMap} -import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ import net.razorvine.pickle._ -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.api.python.{PythonBroadcast, PythonRDD, SerDeUtil} +import org.apache.spark.{Logging => SparkLogging, TaskContext, Accumulator} +import org.apache.spark.api.python.{PythonRunner, PythonBroadcast, PythonRDD, SerDeUtil} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.util.{MapData, GenericArrayData, ArrayBasedMapData, ArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.{Accumulator, Logging => SparkLogging} /** * A serialized version of a Python lambda function. Suitable for use in a [[PythonRDD]]. @@ -66,7 +66,7 @@ private[spark] case class PythonUDF( * multiple child operators. */ private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { // Skip EvaluatePython nodes. case plan: EvaluatePython => plan @@ -119,6 +119,17 @@ object EvaluatePython { def apply(udf: PythonUDF, child: LogicalPlan): EvaluatePython = new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)()) + def takeAndServe(df: DataFrame, n: Int): Int = { + registerPicklers() + df.withNewExecutionId { + val iter = new SerDeUtil.AutoBatchedPickler( + df.queryExecution.executedPlan.executeTake(n).iterator.map { row => + EvaluatePython.toJava(row, df.schema) + }) + PythonRDD.serveIterator(iter, s"serve-DataFrame") + } + } + /** * Helper for converting from Catalyst type to java type suitable for Pyrolite. 
*/ @@ -182,7 +193,7 @@ object EvaluatePython { case (c: Double, DoubleType) => c - case (c: java.math.BigDecimal, dt: DecimalType) => Decimal(c) + case (c: java.math.BigDecimal, dt: DecimalType) => Decimal(c, dt.precision, dt.scale) case (c: Int, DateType) => c @@ -197,14 +208,15 @@ object EvaluatePython { case (c, BinaryType) if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c case (c: java.util.List[_], ArrayType(elementType, _)) => - new GenericArrayData(c.map { e => fromJava(e, elementType)}.toArray) + new GenericArrayData(c.asScala.map { e => fromJava(e, elementType)}.toArray) case (c, ArrayType(elementType, _)) if c.getClass.isArray => new GenericArrayData(c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType))) case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => - val keys = c.keysIterator.map(fromJava(_, keyType)).toArray - val values = c.valuesIterator.map(fromJava(_, valueType)).toArray + val keyValues = c.asScala.toSeq + val keys = keyValues.map(kv => fromJava(kv._1, keyType)).toArray + val values = keyValues.map(kv => fromJava(kv._2, valueType)).toArray ArrayBasedMapData(keys, values) case (c, StructType(fields)) if c.getClass.isArray => @@ -310,10 +322,8 @@ object EvaluatePython { } /** - * :: DeveloperApi :: * Evaluates a [[PythonUDF]], appending the result to the end of the input tuple. */ -@DeveloperApi case class EvaluatePython( udf: PythonUDF, child: LogicalPlan, @@ -327,62 +337,76 @@ case class EvaluatePython( } /** - * :: DeveloperApi :: * Uses PythonRDD to evaluate a [[PythonUDF]], one partition of tuples at a time. - * The input data is zipped with the result of the udf evaluation. + * + * Python evaluation works by sending the necessary (projected) input data via a socket to an + * external Python process, and combine the result from the Python process with the original row. + * + * For each row we send to Python, we also put it in a queue. For each output row from Python, + * we drain the queue to find the original input row. Note that if the Python process is way too + * slow, this could lead to the queue growing unbounded and eventually run out of memory. */ -@DeveloperApi case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan) extends SparkPlan { def children: Seq[SparkPlan] = child :: Nil + override def outputsUnsafeRows: Boolean = false + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true + protected override def doExecute(): RDD[InternalRow] = { - val childResults = child.execute().map(_.copy()) + val inputRDD = child.execute().map(_.copy()) + val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) + val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) - val parent = childResults.mapPartitions { iter => + inputRDD.mapPartitions { iter => EvaluatePython.registerPicklers() // register pickler for Row + + // The queue used to buffer input rows so we can drain it to + // combine input with output from Python. + val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]() + val pickle = new Pickler val currentRow = newMutableProjection(udf.children, child.output)() val fields = udf.children.map(_.dataType) val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray) - iter.grouped(100).map { inputRows => + + // Input iterator to Python: input rows are grouped so we send them in batches to Python. + // For each row, add it to the queue. 
+ val inputIterator = iter.grouped(100).map { inputRows => val toBePickled = inputRows.map { row => + queue.add(row) EvaluatePython.toJava(currentRow(row), schema) }.toArray pickle.dumps(toBePickled) } - } - val pyRDD = new PythonRDD( - parent, - udf.command, - udf.envVars, - udf.pythonIncludes, - false, - udf.pythonExec, - udf.pythonVer, - udf.broadcastVars, - udf.accumulator - ).mapPartitions { iter => - val pickle = new Unpickler - iter.flatMap { pickedResult => - val unpickledBatch = pickle.loads(pickedResult) - unpickledBatch.asInstanceOf[java.util.ArrayList[Any]] - } - }.mapPartitions { iter => + val context = TaskContext.get() + + // Output iterator for results from Python. + val outputIterator = new PythonRunner( + udf.command, + udf.envVars, + udf.pythonIncludes, + udf.pythonExec, + udf.pythonVer, + udf.broadcastVars, + udf.accumulator, + bufferSize, + reuseWorker + ).compute(inputIterator, context.partitionId(), context) + + val unpickle = new Unpickler val row = new GenericMutableRow(1) - iter.map { result => - row(0) = EvaluatePython.fromJava(result, udf.dataType) - row: InternalRow - } - } + val joined = new JoinedRow - childResults.zip(pyRDD).mapPartitions { iter => - val joinedRow = new JoinedRow() - iter.map { - case (row, udfResult) => - joinedRow(row, udfResult) + outputIterator.flatMap { pickedResult => + val unpickledBatch = unpickle.loads(pickedResult) + unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala + }.map { result => + row(0) = EvaluatePython.fromJava(result, udf.dataType) + joined(queue.poll(), row) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala index 29f3beb3cb3c..5f8fc2de8b46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala @@ -17,22 +17,20 @@ package org.apache.spark.sql.execution -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule /** - * :: DeveloperApi :: * Converts Java-object-based rows into [[UnsafeRow]]s. */ -@DeveloperApi case class ConvertToUnsafe(child: SparkPlan) extends UnaryNode { - require(UnsafeProjection.canSupport(child.schema), s"Cannot convert ${child.schema} to Unsafe") - override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = child.outputPartitioning + override def outputOrdering: Seq[SortOrder] = child.outputOrdering override def outputsUnsafeRows: Boolean = true override def canProcessUnsafeRows: Boolean = false override def canProcessSafeRows: Boolean = true @@ -45,12 +43,12 @@ case class ConvertToUnsafe(child: SparkPlan) extends UnaryNode { } /** - * :: DeveloperApi :: * Converts [[UnsafeRow]]s back into Java-object-based rows. 
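
The rewritten BatchPythonEvaluation above pairs each result coming back from the Python worker with its original input row by remembering inputs in a FIFO queue as they are sent. A simplified standalone sketch of that pairing pattern; `evalBatch` is an assumed stand-in for the round trip through PythonRunner, and the sketch assumes exactly one result per input row, returned in send order.

import java.util.concurrent.ConcurrentLinkedQueue

def evaluateExternally[Row, Res](
    input: Iterator[Row],
    batchSize: Int)(evalBatch: Seq[Row] => Seq[Res]): Iterator[(Row, Res)] = {
  val queue = new ConcurrentLinkedQueue[Row]()
  input.grouped(batchSize).flatMap { batch =>
    batch.foreach(queue.add)  // remember originals in send order before shipping the batch
    evalBatch(batch)          // one result per input, in the same order
  }.map { result =>
    (queue.poll(), result)    // drain the queue to recover the matching input row
  }
}
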
*/ -@DeveloperApi case class ConvertToSafe(child: SparkPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = child.outputPartitioning + override def outputOrdering: Seq[SortOrder] = child.outputOrdering override def outputsUnsafeRows: Boolean = false override def canProcessUnsafeRows: Boolean = true override def canProcessSafeRows: Boolean = false @@ -97,18 +95,10 @@ private[sql] object EnsureRowFormats extends Rule[SparkPlan] { case operator: SparkPlan if handlesBothSafeAndUnsafeRows(operator) => if (operator.children.map(_.outputsUnsafeRows).toSet.size != 1) { // If this operator's children produce both unsafe and safe rows, - // convert everything unsafe rows if all the schema of them are support by UnsafeRow - if (operator.children.forall(c => UnsafeProjection.canSupport(c.schema))) { - operator.withNewChildren { - operator.children.map { - c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c - } - } - } else { - operator.withNewChildren { - operator.children.map { - c => if (c.outputsUnsafeRows) ConvertToSafe(c) else c - } + // convert everything unsafe rows. + operator.withNewChildren { + operator.children.map { + c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c } } } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala deleted file mode 100644 index 92cf328c76cb..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala +++ /dev/null @@ -1,153 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, OrderedDistribution, Distribution} -import org.apache.spark.sql.types.StructType -import org.apache.spark.util.CompletionIterator -import org.apache.spark.util.collection.ExternalSorter - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// This file defines various sort operators. -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -/** - * Performs a sort on-heap. - * @param global when true performs a global sort of all partitions by shuffling the data first - * if necessary. 
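
As modified above, EnsureRowFormats resolves a mixed set of child row formats by always converting the remaining "safe" children to unsafe rows. A toy sketch of that normalization decision over a simplified plan type; Plan, Leaf and Convert are stand-ins for SparkPlan, its children and ConvertToUnsafe, not the real API.

sealed trait Plan { def children: Seq[Plan]; def outputsUnsafe: Boolean }
case class Leaf(outputsUnsafe: Boolean) extends Plan { def children: Seq[Plan] = Nil }
case class Convert(child: Plan) extends Plan {
  def children: Seq[Plan] = Seq(child)
  def outputsUnsafe: Boolean = true  // the converter always emits unsafe rows
}

// If the children disagree on row format, wrap every safe child so all inputs are unsafe.
def normalizeChildren(children: Seq[Plan]): Seq[Plan] = {
  if (children.map(_.outputsUnsafe).toSet.size > 1) {
    children.map(c => if (!c.outputsUnsafe) Convert(c) else c)
  } else {
    children
  }
}

// e.g. normalizeChildren(Seq(Leaf(true), Leaf(false))) == Seq(Leaf(true), Convert(Leaf(false)))
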
- */ -case class Sort( - sortOrder: Seq[SortOrder], - global: Boolean, - child: SparkPlan) - extends UnaryNode { - override def requiredChildDistribution: Seq[Distribution] = - if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { - child.execute().mapPartitions( { iterator => - val ordering = newOrdering(sortOrder, child.output) - iterator.map(_.copy()).toArray.sorted(ordering).iterator - }, preservesPartitioning = true) - } - - override def output: Seq[Attribute] = child.output - - override def outputOrdering: Seq[SortOrder] = sortOrder -} - -/** - * Performs a sort, spilling to disk as needed. - * @param global when true performs a global sort of all partitions by shuffling the data first - * if necessary. - */ -case class ExternalSort( - sortOrder: Seq[SortOrder], - global: Boolean, - child: SparkPlan) - extends UnaryNode { - - override def requiredChildDistribution: Seq[Distribution] = - if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { - child.execute().mapPartitions( { iterator => - val ordering = newOrdering(sortOrder, child.output) - val sorter = new ExternalSorter[InternalRow, Null, InternalRow](ordering = Some(ordering)) - sorter.insertAll(iterator.map(r => (r.copy(), null))) - val baseIterator = sorter.iterator.map(_._1) - // TODO(marmbrus): The complex type signature below thwarts inference for no reason. - CompletionIterator[InternalRow, Iterator[InternalRow]](baseIterator, sorter.stop()) - }, preservesPartitioning = true) - } - - override def output: Seq[Attribute] = child.output - - override def outputOrdering: Seq[SortOrder] = sortOrder -} - -/** - * Optimized version of [[ExternalSort]] that operates on binary data (implemented as part of - * Project Tungsten). - * - * @param global when true performs a global sort of all partitions by shuffling the data first - * if necessary. - * @param testSpillFrequency Method for configuring periodic spilling in unit tests. If set, will - * spill every `frequency` records. 
- */ -case class TungstenSort( - sortOrder: Seq[SortOrder], - global: Boolean, - child: SparkPlan, - testSpillFrequency: Int = 0) - extends UnaryNode { - - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = false - - override def output: Seq[Attribute] = child.output - - override def outputOrdering: Seq[SortOrder] = sortOrder - - override def requiredChildDistribution: Seq[Distribution] = - if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - - protected override def doExecute(): RDD[InternalRow] = { - val schema = child.schema - val childOutput = child.output - val pageSize = sparkContext.conf.getSizeAsBytes("spark.buffer.pageSize", "64m") - child.execute().mapPartitions({ iter => - val ordering = newOrdering(sortOrder, childOutput) - - // The comparator for comparing prefix - val boundSortExpression = BindReferences.bindReference(sortOrder.head, childOutput) - val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression) - - // The generator for prefix - val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression))) - val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer { - override def computePrefix(row: InternalRow): Long = { - prefixProjection.apply(row).getLong(0) - } - } - - val sorter = new UnsafeExternalRowSorter( - schema, ordering, prefixComparator, prefixComputer, pageSize) - if (testSpillFrequency > 0) { - sorter.setTestSpillFrequency(testSpillFrequency) - } - sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]]) - }, preservesPartitioning = true) - } - -} - -object TungstenSort { - /** - * Return true if UnsafeExternalSort can sort rows with the given schema, false otherwise. - */ - def supportsSchema(schema: StructType): Boolean = { - UnsafeExternalRowSorter.supportsSchema(schema) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala index 9329148aa233..db463029aedf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -20,17 +20,15 @@ package org.apache.spark.sql.execution.stat import scala.collection.mutable.{Map => MutableMap} import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types._ -import org.apache.spark.sql.{Column, DataFrame} +import org.apache.spark.sql.{Row, Column, DataFrame} private[sql] object FrequentItems extends Logging { /** A helper class wrapping `MutableMap[Any, Long]` for simplicity. */ private class FreqItemCounter(size: Int) extends Serializable { val baseMap: MutableMap[Any, Long] = MutableMap.empty[Any, Long] - /** * Add a new example to the counts if it exists, otherwise deduct the count * from existing items. @@ -42,9 +40,15 @@ private[sql] object FrequentItems extends Logging { if (baseMap.size < size) { baseMap += key -> count } else { - // TODO: Make this more efficient... A flatMap? 
- baseMap.retain((k, v) => v > count) - baseMap.transform((k, v) => v - count) + val minCount = baseMap.values.min + val remainder = count - minCount + if (remainder >= 0) { + baseMap += key -> count // something will get kicked out, so we can add this + baseMap.retain((k, v) => v > minCount) + baseMap.transform((k, v) => v - minCount) + } else { + baseMap.transform((k, v) => v - count) + } } } this @@ -90,12 +94,12 @@ private[sql] object FrequentItems extends Logging { (name, originalSchema.fields(index).dataType) }.toArray - val freqItems = df.select(cols.map(Column(_)) : _*).queryExecution.toRdd.aggregate(countMaps)( + val freqItems = df.select(cols.map(Column(_)) : _*).rdd.aggregate(countMaps)( seqOp = (counts, row) => { var i = 0 while (i < numCols) { val thisMap = counts(i) - val key = row.get(i, colInfo(i)._2) + val key = row.get(i) thisMap.add(key, 1L) i += 1 } @@ -110,13 +114,13 @@ private[sql] object FrequentItems extends Logging { baseCounts } ) - val justItems = freqItems.map(m => m.baseMap.keys.toArray).map(new GenericArrayData(_)) - val resultRow = InternalRow(justItems : _*) + val justItems = freqItems.map(m => m.baseMap.keys.toArray) + val resultRow = Row(justItems : _*) // append frequent Items to the column name for easy debugging val outputCols = colInfo.map { v => StructField(v._1 + "_freqItems", ArrayType(v._2, false)) } val schema = StructType(outputCols).toAttributes - new DataFrame(df.sqlContext, LocalRelation(schema, Seq(resultRow))) + new DataFrame(df.sqlContext, LocalRelation.fromExternalRows(schema, Seq(resultRow))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala new file mode 100644 index 000000000000..49646a99d68c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
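
The corrected FreqItemCounter.add above implements a Misra-Gries style summary: at most `size` keys are tracked, and when an untracked key arrives while the map is full, it either displaces the weakest tracked keys or every tracked count is decayed. A standalone sketch of the same update rule on a plain mutable.Map:

import scala.collection.mutable

// Tracks at most `size` keys; mirrors the updated FreqItemCounter.add logic above.
class FreqCounter[K](size: Int) {
  val counts: mutable.Map[K, Long] = mutable.Map.empty[K, Long]

  def add(key: K, count: Long = 1L): this.type = {
    if (counts.contains(key)) {
      counts(key) += count
    } else if (counts.size < size) {
      counts(key) = count
    } else {
      val minCount = counts.values.min
      if (count >= minCount) {
        // The new key is at least as heavy as the weakest tracked key: admit it, evict every
        // key sitting at the minimum (the new key too if counts tie), and subtract the
        // minimum from the survivors.
        counts(key) = count
        counts.retain((_, v) => v > minCount)
        counts.transform((_, v) => v - minCount)
      } else {
        // The new key is lighter than everything tracked: just decay the existing counts.
        counts.transform((_, v) => v - count)
      }
    }
    this
  }
}
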
+ */ + +package org.apache.spark.sql.execution.ui + +import javax.servlet.http.HttpServletRequest + +import scala.collection.mutable +import scala.xml.Node + +import org.apache.commons.lang3.StringEscapeUtils + +import org.apache.spark.Logging +import org.apache.spark.ui.{UIUtils, WebUIPage} + +private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with Logging { + + private val listener = parent.listener + + override def render(request: HttpServletRequest): Seq[Node] = { + val currentTime = System.currentTimeMillis() + val content = listener.synchronized { + val _content = mutable.ListBuffer[Node]() + if (listener.getRunningExecutions.nonEmpty) { + _content ++= + new RunningExecutionTable( + parent, "Running Queries", currentTime, + listener.getRunningExecutions.sortBy(_.submissionTime).reverse).toNodeSeq + } + if (listener.getCompletedExecutions.nonEmpty) { + _content ++= + new CompletedExecutionTable( + parent, "Completed Queries", currentTime, + listener.getCompletedExecutions.sortBy(_.submissionTime).reverse).toNodeSeq + } + if (listener.getFailedExecutions.nonEmpty) { + _content ++= + new FailedExecutionTable( + parent, "Failed Queries", currentTime, + listener.getFailedExecutions.sortBy(_.submissionTime).reverse).toNodeSeq + } + _content + } + UIUtils.headerSparkPage("SQL", content, parent, Some(5000)) + } +} + +private[ui] abstract class ExecutionTable( + parent: SQLTab, + tableId: String, + tableName: String, + currentTime: Long, + executionUIDatas: Seq[SQLExecutionUIData], + showRunningJobs: Boolean, + showSucceededJobs: Boolean, + showFailedJobs: Boolean) { + + protected def baseHeader: Seq[String] = Seq( + "ID", + "Description", + "Submitted", + "Duration") + + protected def header: Seq[String] + + protected def row(currentTime: Long, executionUIData: SQLExecutionUIData): Seq[Node] = { + val submissionTime = executionUIData.submissionTime + val duration = executionUIData.completionTime.getOrElse(currentTime) - submissionTime + + val runningJobs = executionUIData.runningJobs.map { jobId => + {jobId.toString}
        + } + val succeededJobs = executionUIData.succeededJobs.sorted.map { jobId => + {jobId.toString}
        + } + val failedJobs = executionUIData.failedJobs.sorted.map { jobId => + {jobId.toString}
        + } + + + {executionUIData.executionId.toString} + + + {descriptionCell(executionUIData)} + + + {UIUtils.formatDate(submissionTime)} + + + {UIUtils.formatDuration(duration)} + + {if (showRunningJobs) { + + {runningJobs} + + }} + {if (showSucceededJobs) { + + {succeededJobs} + + }} + {if (showFailedJobs) { + + {failedJobs} + + }} + {detailCell(executionUIData.physicalPlanDescription)} + + } + + private def descriptionCell(execution: SQLExecutionUIData): Seq[Node] = { + val details = if (execution.details.nonEmpty) { + + +details + ++ + + } else { + Nil + } + + val desc = { + {execution.description} + } + +
        {desc} {details}
        + } + + private def detailCell(physicalPlan: String): Seq[Node] = { + val isMultiline = physicalPlan.indexOf('\n') >= 0 + val summary = StringEscapeUtils.escapeHtml4( + if (isMultiline) { + physicalPlan.substring(0, physicalPlan.indexOf('\n')) + } else { + physicalPlan + }) + val details = if (isMultiline) { + // scalastyle:off + + +details + ++ + + // scalastyle:on + } else { + "" + } + {summary}{details} + } + + def toNodeSeq: Seq[Node] = { +
        +

        {tableName}

        + {UIUtils.listingTable[SQLExecutionUIData]( + header, row(currentTime, _), executionUIDatas, id = Some(tableId))} +
        + } + + private def jobURL(jobId: Long): String = + "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(parent.basePath), jobId) + + private def executionURL(executionID: Long): String = + s"${UIUtils.prependBaseUri(parent.basePath)}/${parent.prefix}/execution?id=$executionID" +} + +private[ui] class RunningExecutionTable( + parent: SQLTab, + tableName: String, + currentTime: Long, + executionUIDatas: Seq[SQLExecutionUIData]) + extends ExecutionTable( + parent, + "running-execution-table", + tableName, + currentTime, + executionUIDatas, + showRunningJobs = true, + showSucceededJobs = true, + showFailedJobs = true) { + + override protected def header: Seq[String] = + baseHeader ++ Seq("Running Jobs", "Succeeded Jobs", "Failed Jobs", "Detail") +} + +private[ui] class CompletedExecutionTable( + parent: SQLTab, + tableName: String, + currentTime: Long, + executionUIDatas: Seq[SQLExecutionUIData]) + extends ExecutionTable( + parent, + "completed-execution-table", + tableName, + currentTime, + executionUIDatas, + showRunningJobs = false, + showSucceededJobs = true, + showFailedJobs = false) { + + override protected def header: Seq[String] = baseHeader ++ Seq("Jobs", "Detail") +} + +private[ui] class FailedExecutionTable( + parent: SQLTab, + tableName: String, + currentTime: Long, + executionUIDatas: Seq[SQLExecutionUIData]) + extends ExecutionTable( + parent, + "failed-execution-table", + tableName, + currentTime, + executionUIDatas, + showRunningJobs = false, + showSucceededJobs = true, + showFailedJobs = true) { + + override protected def header: Seq[String] = + baseHeader ++ Seq("Succeeded Jobs", "Failed Jobs", "Detail") +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala new file mode 100644 index 000000000000..c74ad4040699 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.ui + +import javax.servlet.http.HttpServletRequest + +import scala.xml.Node + +import org.apache.spark.Logging +import org.apache.spark.ui.{UIUtils, WebUIPage} + +private[sql] class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging { + + private val listener = parent.listener + + override def render(request: HttpServletRequest): Seq[Node] = listener.synchronized { + val parameterExecutionId = request.getParameter("id") + require(parameterExecutionId != null && parameterExecutionId.nonEmpty, + "Missing execution id parameter") + + val executionId = parameterExecutionId.toLong + val content = listener.getExecution(executionId).map { executionUIData => + val currentTime = System.currentTimeMillis() + val duration = + executionUIData.completionTime.getOrElse(currentTime) - executionUIData.submissionTime + + val summary = +
        +
          +
        • + Submitted Time: {UIUtils.formatDate(executionUIData.submissionTime)} +
        • +
        • + Duration: {UIUtils.formatDuration(duration)} +
        • + {if (executionUIData.runningJobs.nonEmpty) { +
        • + Running Jobs: + {executionUIData.runningJobs.sorted.map { jobId => + {jobId.toString}  + }} +
        • + }} + {if (executionUIData.succeededJobs.nonEmpty) { +
        • + Succeeded Jobs: + {executionUIData.succeededJobs.sorted.map { jobId => + {jobId.toString}  + }} +
        • + }} + {if (executionUIData.failedJobs.nonEmpty) { +
        • + Failed Jobs: + {executionUIData.failedJobs.sorted.map { jobId => + {jobId.toString}  + }} +
        • + }} +
        +
        + + val metrics = listener.getExecutionMetrics(executionId) + + summary ++ + planVisualization(metrics, executionUIData.physicalPlanGraph) ++ + physicalPlanDescription(executionUIData.physicalPlanDescription) + }.getOrElse { +
        No information to display for Plan {executionId}
        + } + + UIUtils.headerSparkPage(s"Details for Query $executionId", content, parent, Some(5000)) + } + + + private def planVisualizationResources: Seq[Node] = { + // scalastyle:off + + + + + + // scalastyle:on + } + + private def planVisualization(metrics: Map[Long, String], graph: SparkPlanGraph): Seq[Node] = { + val metadata = graph.nodes.flatMap { node => + val nodeId = s"plan-meta-data-${node.id}" +
        {node.desc}
        + } + +
        +
        + + {planVisualizationResources} + +
        + } + + private def jobURL(jobId: Long): String = + "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(parent.basePath), jobId) + + private def physicalPlanDescription(physicalPlanDescription: String): Seq[Node] = { +
        + + + Details + +
        + + +
        + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala new file mode 100644 index 000000000000..e19a1e3e5851 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -0,0 +1,408 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.ui + +import scala.collection.mutable + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.scheduler._ +import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.execution.SparkPlanInfo +import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetricValue, SQLMetricParam} +import org.apache.spark.{JobExecutionStatus, Logging, SparkConf} +import org.apache.spark.ui.SparkUI + +@DeveloperApi +case class SparkListenerSQLExecutionStart( + executionId: Long, + description: String, + details: String, + physicalPlanDescription: String, + sparkPlanInfo: SparkPlanInfo, + time: Long) + extends SparkListenerEvent + +@DeveloperApi +case class SparkListenerSQLExecutionEnd(executionId: Long, time: Long) + extends SparkListenerEvent + +private[sql] class SQLHistoryListenerFactory extends SparkHistoryListenerFactory { + + override def createListeners(conf: SparkConf, sparkUI: SparkUI): Seq[SparkListener] = { + List(new SQLHistoryListener(conf, sparkUI)) + } +} + +private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Logging { + + private val retainedExecutions = conf.getInt("spark.sql.ui.retainedExecutions", 1000) + + private val activeExecutions = mutable.HashMap[Long, SQLExecutionUIData]() + + // Old data in the following fields must be removed in "trimExecutionsIfNecessary". + // If adding new fields, make sure "trimExecutionsIfNecessary" can clean up old data + private val _executionIdToData = mutable.HashMap[Long, SQLExecutionUIData]() + + /** + * Maintain the relation between job id and execution id so that we can get the execution id in + * the "onJobEnd" method. 
+ */ + private val _jobIdToExecutionId = mutable.HashMap[Long, Long]() + + private val _stageIdToStageMetrics = mutable.HashMap[Long, SQLStageMetrics]() + + private val failedExecutions = mutable.ListBuffer[SQLExecutionUIData]() + + private val completedExecutions = mutable.ListBuffer[SQLExecutionUIData]() + + def executionIdToData: Map[Long, SQLExecutionUIData] = synchronized { + _executionIdToData.toMap + } + + def jobIdToExecutionId: Map[Long, Long] = synchronized { + _jobIdToExecutionId.toMap + } + + def stageIdToStageMetrics: Map[Long, SQLStageMetrics] = synchronized { + _stageIdToStageMetrics.toMap + } + + private def trimExecutionsIfNecessary( + executions: mutable.ListBuffer[SQLExecutionUIData]): Unit = { + if (executions.size > retainedExecutions) { + val toRemove = math.max(retainedExecutions / 10, 1) + executions.take(toRemove).foreach { execution => + for (executionUIData <- _executionIdToData.remove(execution.executionId)) { + for (jobId <- executionUIData.jobs.keys) { + _jobIdToExecutionId.remove(jobId) + } + for (stageId <- executionUIData.stages) { + _stageIdToStageMetrics.remove(stageId) + } + } + } + executions.trimStart(toRemove) + } + } + + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + val executionIdString = jobStart.properties.getProperty(SQLExecution.EXECUTION_ID_KEY) + if (executionIdString == null) { + // This is not a job created by SQL + return + } + val executionId = executionIdString.toLong + val jobId = jobStart.jobId + val stageIds = jobStart.stageIds + + synchronized { + activeExecutions.get(executionId).foreach { executionUIData => + executionUIData.jobs(jobId) = JobExecutionStatus.RUNNING + executionUIData.stages ++= stageIds + stageIds.foreach(stageId => + _stageIdToStageMetrics(stageId) = new SQLStageMetrics(stageAttemptId = 0)) + _jobIdToExecutionId(jobId) = executionId + } + } + } + + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = synchronized { + val jobId = jobEnd.jobId + for (executionId <- _jobIdToExecutionId.get(jobId); + executionUIData <- _executionIdToData.get(executionId)) { + jobEnd.jobResult match { + case JobSucceeded => executionUIData.jobs(jobId) = JobExecutionStatus.SUCCEEDED + case JobFailed(_) => executionUIData.jobs(jobId) = JobExecutionStatus.FAILED + } + if (executionUIData.completionTime.nonEmpty && !executionUIData.hasRunningJobs) { + // We are the last job of this execution, so mark the execution as finished. Note that + // `onExecutionEnd` also does this, but currently that can be called before `onJobEnd` + // since these are called on different threads. + markExecutionFinished(executionId) + } + } + } + + override def onExecutorMetricsUpdate( + executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = synchronized { + for ((taskId, stageId, stageAttemptID, metrics) <- executorMetricsUpdate.taskMetrics) { + updateTaskAccumulatorValues(taskId, stageId, stageAttemptID, metrics.accumulatorUpdates(), + finishTask = false) + } + } + + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = synchronized { + val stageId = stageSubmitted.stageInfo.stageId + val stageAttemptId = stageSubmitted.stageInfo.attemptId + // Always override metrics for old stage attempt + if (_stageIdToStageMetrics.contains(stageId)) { + _stageIdToStageMetrics(stageId) = new SQLStageMetrics(stageAttemptId) + } else { + // If a stage belongs to some SQL execution, its stageId will be put in "onJobStart". 
+ // Since "_stageIdToStageMetrics" doesn't contain it, it must not belong to any SQL execution. + // So we can ignore it. Otherwise, this may lead to memory leaks (SPARK-11126). + } + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { + updateTaskAccumulatorValues( + taskEnd.taskInfo.taskId, + taskEnd.stageId, + taskEnd.stageAttemptId, + taskEnd.taskMetrics.accumulatorUpdates(), + finishTask = true) + } + + /** + * Update the accumulator values of a task with the latest metrics for this task. This is called + * every time we receive an executor heartbeat or when a task finishes. + */ + protected def updateTaskAccumulatorValues( + taskId: Long, + stageId: Int, + stageAttemptID: Int, + accumulatorUpdates: Map[Long, Any], + finishTask: Boolean): Unit = { + + _stageIdToStageMetrics.get(stageId) match { + case Some(stageMetrics) => + if (stageAttemptID < stageMetrics.stageAttemptId) { + // A task of an old stage attempt. Because a new stage is submitted, we can ignore it. + } else if (stageAttemptID > stageMetrics.stageAttemptId) { + logWarning(s"A task should not have a higher stageAttemptID ($stageAttemptID) then " + + s"what we have seen (${stageMetrics.stageAttemptId})") + } else { + // TODO We don't know the attemptId. Currently, what we can do is overriding the + // accumulator updates. However, if there are two same task are running, such as + // speculation, the accumulator updates will be overriding by different task attempts, + // the results will be weird. + stageMetrics.taskIdToMetricUpdates.get(taskId) match { + case Some(taskMetrics) => + if (finishTask) { + taskMetrics.finished = true + taskMetrics.accumulatorUpdates = accumulatorUpdates + } else if (!taskMetrics.finished) { + taskMetrics.accumulatorUpdates = accumulatorUpdates + } else { + // If a task is finished, we should not override with accumulator updates from + // heartbeat reports + } + case None => + // TODO Now just set attemptId to 0. Should fix here when we can get the attempt + // id from SparkListenerExecutorMetricsUpdate + stageMetrics.taskIdToMetricUpdates(taskId) = new SQLTaskMetrics( + attemptId = 0, finished = finishTask, accumulatorUpdates) + } + } + case None => + // This execution and its stage have been dropped + } + } + + override def onOtherEvent(event: SparkListenerEvent): Unit = event match { + case SparkListenerSQLExecutionStart(executionId, description, details, + physicalPlanDescription, sparkPlanInfo, time) => + val physicalPlanGraph = SparkPlanGraph(sparkPlanInfo) + val sqlPlanMetrics = physicalPlanGraph.nodes.flatMap { node => + node.metrics.map(metric => metric.accumulatorId -> metric) + } + val executionUIData = new SQLExecutionUIData( + executionId, + description, + details, + physicalPlanDescription, + physicalPlanGraph, + sqlPlanMetrics.toMap, + time) + synchronized { + activeExecutions(executionId) = executionUIData + _executionIdToData(executionId) = executionUIData + } + case SparkListenerSQLExecutionEnd(executionId, time) => synchronized { + _executionIdToData.get(executionId).foreach { executionUIData => + executionUIData.completionTime = Some(time) + if (!executionUIData.hasRunningJobs) { + // onExecutionEnd happens after all "onJobEnd"s + // So we should update the execution lists. + markExecutionFinished(executionId) + } else { + // There are some running jobs, onExecutionEnd happens before some "onJobEnd"s. + // Then we don't if the execution is successful, so let the last onJobEnd updates the + // execution lists. 
+ } + } + } + case _ => // Ignore + } + + private def markExecutionFinished(executionId: Long): Unit = { + activeExecutions.remove(executionId).foreach { executionUIData => + if (executionUIData.isFailed) { + failedExecutions += executionUIData + trimExecutionsIfNecessary(failedExecutions) + } else { + completedExecutions += executionUIData + trimExecutionsIfNecessary(completedExecutions) + } + } + } + + def getRunningExecutions: Seq[SQLExecutionUIData] = synchronized { + activeExecutions.values.toSeq + } + + def getFailedExecutions: Seq[SQLExecutionUIData] = synchronized { + failedExecutions + } + + def getCompletedExecutions: Seq[SQLExecutionUIData] = synchronized { + completedExecutions + } + + def getExecution(executionId: Long): Option[SQLExecutionUIData] = synchronized { + _executionIdToData.get(executionId) + } + + /** + * Get all accumulator updates from all tasks which belong to this execution and merge them. + */ + def getExecutionMetrics(executionId: Long): Map[Long, String] = synchronized { + _executionIdToData.get(executionId) match { + case Some(executionUIData) => + val accumulatorUpdates = { + for (stageId <- executionUIData.stages; + stageMetrics <- _stageIdToStageMetrics.get(stageId).toIterable; + taskMetrics <- stageMetrics.taskIdToMetricUpdates.values; + accumulatorUpdate <- taskMetrics.accumulatorUpdates.toSeq) yield { + accumulatorUpdate + } + }.filter { case (id, _) => executionUIData.accumulatorMetrics.contains(id) } + mergeAccumulatorUpdates(accumulatorUpdates, accumulatorId => + executionUIData.accumulatorMetrics(accumulatorId).metricParam) + case None => + // This execution has been dropped + Map.empty + } + } + + private def mergeAccumulatorUpdates( + accumulatorUpdates: Seq[(Long, Any)], + paramFunc: Long => SQLMetricParam[SQLMetricValue[Any], Any]): Map[Long, String] = { + accumulatorUpdates.groupBy(_._1).map { case (accumulatorId, values) => + val param = paramFunc(accumulatorId) + (accumulatorId, + param.stringValue(values.map(_._2.asInstanceOf[SQLMetricValue[Any]].value))) + } + } + +} + +private[spark] class SQLHistoryListener(conf: SparkConf, sparkUI: SparkUI) + extends SQLListener(conf) { + + private var sqlTabAttached = false + + override def onExecutorMetricsUpdate( + executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = synchronized { + // Do nothing + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { + updateTaskAccumulatorValues( + taskEnd.taskInfo.taskId, + taskEnd.stageId, + taskEnd.stageAttemptId, + taskEnd.taskInfo.accumulables.map { acc => + (acc.id, new LongSQLMetricValue(acc.update.getOrElse("0").toLong)) + }.toMap, + finishTask = true) + } + + override def onOtherEvent(event: SparkListenerEvent): Unit = event match { + case _: SparkListenerSQLExecutionStart => + if (!sqlTabAttached) { + new SQLTab(this, sparkUI) + sqlTabAttached = true + } + super.onOtherEvent(event) + case _ => super.onOtherEvent(event) + } +} + +/** + * Represent all necessary data for an execution that will be used in Web UI. 
+ */ +private[ui] class SQLExecutionUIData( + val executionId: Long, + val description: String, + val details: String, + val physicalPlanDescription: String, + val physicalPlanGraph: SparkPlanGraph, + val accumulatorMetrics: Map[Long, SQLPlanMetric], + val submissionTime: Long, + var completionTime: Option[Long] = None, + val jobs: mutable.HashMap[Long, JobExecutionStatus] = mutable.HashMap.empty, + val stages: mutable.ArrayBuffer[Int] = mutable.ArrayBuffer()) { + + /** + * Return whether there are running jobs in this execution. + */ + def hasRunningJobs: Boolean = jobs.values.exists(_ == JobExecutionStatus.RUNNING) + + /** + * Return whether there are any failed jobs in this execution. + */ + def isFailed: Boolean = jobs.values.exists(_ == JobExecutionStatus.FAILED) + + def runningJobs: Seq[Long] = + jobs.filter { case (_, status) => status == JobExecutionStatus.RUNNING }.keys.toSeq + + def succeededJobs: Seq[Long] = + jobs.filter { case (_, status) => status == JobExecutionStatus.SUCCEEDED }.keys.toSeq + + def failedJobs: Seq[Long] = + jobs.filter { case (_, status) => status == JobExecutionStatus.FAILED }.keys.toSeq +} + +/** + * Represent a metric in a SQLPlan. + * + * Because we cannot revert our changes for an "Accumulator", we need to maintain accumulator + * updates for each task. So that if a task is retried, we can simply override the old updates with + * the new updates of the new attempt task. Since we cannot add them to accumulator, we need to use + * "AccumulatorParam" to get the aggregation value. + */ +private[ui] case class SQLPlanMetric( + name: String, + accumulatorId: Long, + metricParam: SQLMetricParam[SQLMetricValue[Any], Any]) + +/** + * Store all accumulatorUpdates for all tasks in a Spark stage. + */ +private[ui] class SQLStageMetrics( + val stageAttemptId: Long, + val taskIdToMetricUpdates: mutable.HashMap[Long, SQLTaskMetrics] = mutable.HashMap.empty) + +/** + * Store all accumulatorUpdates for a Spark task. + */ +private[ui] class SQLTaskMetrics( + val attemptId: Long, // TODO not used yet + var finished: Boolean, + var accumulatorUpdates: Map[Long, Any]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala similarity index 58% rename from sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala index 0d4e30f29255..4f50b2ecdc8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala @@ -15,22 +15,23 @@ * limitations under the License. 
*/ -package org.apache.spark.sql +package org.apache.spark.sql.execution.ui import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.expressions.{Expression} -import org.apache.spark.sql.execution.aggregate.ScalaUDAF -import org.apache.spark.sql.expressions.UserDefinedAggregateFunction +import org.apache.spark.ui.{SparkUI, SparkUITab} -class UDAFRegistration private[sql] (sqlContext: SQLContext) extends Logging { +private[sql] class SQLTab(val listener: SQLListener, sparkUI: SparkUI) + extends SparkUITab(sparkUI, "SQL") with Logging { - private val functionRegistry = sqlContext.functionRegistry + val parent = sparkUI - def register( - name: String, - func: UserDefinedAggregateFunction): UserDefinedAggregateFunction = { - def builder(children: Seq[Expression]) = ScalaUDAF(children, func) - functionRegistry.registerFunction(name, builder) - func - } + attachPage(new AllExecutionsPage(this)) + attachPage(new ExecutionPage(this)) + parent.attachTab(this) + + parent.addStaticHandler(SQLTab.STATIC_RESOURCE_DIR, "/static/sql") +} + +private[sql] object SQLTab { + private val STATIC_RESOURCE_DIR = "org/apache/spark/sql/execution/ui/static" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala new file mode 100644 index 000000000000..3a6eff939982 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.ui + +import java.util.concurrent.atomic.AtomicLong + +import scala.collection.mutable + +import org.apache.spark.sql.execution.SparkPlanInfo +import org.apache.spark.sql.execution.metric.SQLMetrics + +/** + * A graph used for storing information of an executionPlan of DataFrame. + * + * Each graph is defined with a set of nodes and a set of edges. Each node represents a node in the + * SparkPlan tree, and each edge represents a parent-child relationship between two nodes. + */ +private[ui] case class SparkPlanGraph( + nodes: Seq[SparkPlanGraphNode], edges: Seq[SparkPlanGraphEdge]) { + + def makeDotFile(metrics: Map[Long, String]): String = { + val dotFile = new StringBuilder + dotFile.append("digraph G {\n") + nodes.foreach(node => dotFile.append(node.makeDotNode(metrics) + "\n")) + edges.foreach(edge => dotFile.append(edge.makeDotEdge + "\n")) + dotFile.append("}") + dotFile.toString() + } +} + +private[sql] object SparkPlanGraph { + + /** + * Build a SparkPlanGraph from the root of a SparkPlan tree. 
+ */ + def apply(planInfo: SparkPlanInfo): SparkPlanGraph = { + val nodeIdGenerator = new AtomicLong(0) + val nodes = mutable.ArrayBuffer[SparkPlanGraphNode]() + val edges = mutable.ArrayBuffer[SparkPlanGraphEdge]() + buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges) + new SparkPlanGraph(nodes, edges) + } + + private def buildSparkPlanGraphNode( + planInfo: SparkPlanInfo, + nodeIdGenerator: AtomicLong, + nodes: mutable.ArrayBuffer[SparkPlanGraphNode], + edges: mutable.ArrayBuffer[SparkPlanGraphEdge]): SparkPlanGraphNode = { + val metrics = planInfo.metrics.map { metric => + SQLPlanMetric(metric.name, metric.accumulatorId, + SQLMetrics.getMetricParam(metric.metricParam)) + } + val node = SparkPlanGraphNode( + nodeIdGenerator.getAndIncrement(), planInfo.nodeName, + planInfo.simpleString, planInfo.metadata, metrics) + + nodes += node + val childrenNodes = planInfo.children.map( + child => buildSparkPlanGraphNode(child, nodeIdGenerator, nodes, edges)) + for (child <- childrenNodes) { + edges += SparkPlanGraphEdge(child.id, node.id) + } + node + } +} + +/** + * Represent a node in the SparkPlan tree, along with its metrics. + * + * @param id generated by "SparkPlanGraph". There is no duplicate id in a graph + * @param name the name of this SparkPlan node + * @param metrics metrics that this SparkPlan node will track + */ +private[ui] case class SparkPlanGraphNode( + id: Long, + name: String, + desc: String, + metadata: Map[String, String], + metrics: Seq[SQLPlanMetric]) { + + def makeDotNode(metricsValue: Map[Long, String]): String = { + val builder = new mutable.StringBuilder(name) + + val values = for { + metric <- metrics + value <- metricsValue.get(metric.accumulatorId) + } yield { + metric.name + ": " + value + } + + if (values.nonEmpty) { + // If there are metrics, display each entry in a separate line. We should use an escaped + // "\n" here to follow the dot syntax. + // + // Note: whitespace between two "\n"s is to create an empty line between the name of + // SparkPlan and metrics. If removing it, it won't display the empty line in UI. + builder ++= "\\n \\n" + builder ++= values.mkString("\\n") + } + + s""" $id [label="${builder.toString()}"];""" + } +} + +/** + * Represent an edge in the SparkPlan tree. `fromId` is the parent node id, and `toId` is the child + * node id. + */ +private[ui] case class SparkPlanGraphEdge(fromId: Long, toId: Long) { + + def makeDotEdge: String = s""" $fromId->$toId;\n""" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala new file mode 100644 index 000000000000..65117d582475 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.expressions + +import org.apache.spark.sql.catalyst.encoders.encoderFor +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete} +import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression +import org.apache.spark.sql.{DataFrame, Dataset, Encoder, TypedColumn} + +/** + * A base class for user-defined aggregations, which can be used in [[DataFrame]] and [[Dataset]] + * operations to take all of the elements of a group and reduce them to a single value. + * + * For example, the following aggregator extracts an `int` from a specific class and adds them up: + * {{{ + * case class Data(i: Int) + * + * val customSummer = new Aggregator[Data, Int, Int] { + * def zero: Int = 0 + * def reduce(b: Int, a: Data): Int = b + a.i + * def merge(b1: Int, b2: Int): Int = b1 + b2 + * def finish(r: Int): Int = r + * }.toColumn() + * + * val ds: Dataset[Data] = ... + * val aggregated = ds.select(customSummer) + * }}} + * + * Based loosely on Aggregator from Algebird: https://github.com/twitter/algebird + * + * @tparam I The input type for the aggregation. + * @tparam B The type of the intermediate value of the reduction. + * @tparam O The type of the final output result. + * + * @since 1.6.0 + */ +abstract class Aggregator[-I, B, O] extends Serializable { + + /** + * A zero value for this aggregation. Should satisfy the property that any b + zero = b. + * @since 1.6.0 + */ + def zero: B + + /** + * Combine two values to produce a new value. For performance, the function may modify `b` and + * return it instead of constructing new object for b. + * @since 1.6.0 + */ + def reduce(b: B, a: I): B + + /** + * Merge two intermediate values. + * @since 1.6.0 + */ + def merge(b1: B, b2: B): B + + /** + * Transform the output of the reduction. + * @since 1.6.0 + */ + def finish(reduction: B): O + + /** + * Returns this `Aggregator` as a [[TypedColumn]] that can be used in [[Dataset]] or [[DataFrame]] + * operations. + * @since 1.6.0 + */ + def toColumn( + implicit bEncoder: Encoder[B], + cEncoder: Encoder[O]): TypedColumn[I, O] = { + val expr = + new AggregateExpression( + TypedAggregateExpression(this), + Complete, + false) + + new TypedColumn[I, O](expr, encoderFor[O]) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index c3d224629702..9397fb84105a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.expressions import org.apache.spark.annotation.Experimental import org.apache.spark.sql.{Column, catalyst} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ /** @@ -139,37 +140,7 @@ class WindowSpec private[sql]( * Converts this [[WindowSpec]] into a [[Column]] with an aggregate expression. 
*/ private[sql] def withAggregate(aggregate: Column): Column = { - val windowExpr = aggregate.expr match { - case Average(child) => WindowExpression( - UnresolvedWindowFunction("avg", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Sum(child) => WindowExpression( - UnresolvedWindowFunction("sum", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Count(child) => WindowExpression( - UnresolvedWindowFunction("count", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case First(child) => WindowExpression( - // TODO this is a hack for Hive UDAF first_value - UnresolvedWindowFunction("first_value", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Last(child) => WindowExpression( - // TODO this is a hack for Hive UDAF last_value - UnresolvedWindowFunction("last_value", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Min(child) => WindowExpression( - UnresolvedWindowFunction("min", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Max(child) => WindowExpression( - UnresolvedWindowFunction("max", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case wf: WindowFunction => WindowExpression( - wf, - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case x => - throw new UnsupportedOperationException(s"$x is not supported in window operation.") - } - new Column(windowExpr) + val spec = WindowSpecDefinition(partitionSpec, orderSpec, frame) + new Column(WindowExpression(aggregate.expr, spec)) } - } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala index 278dd438fab4..11dbf391cff9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala @@ -17,13 +17,15 @@ package org.apache.spark.sql.expressions -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression} +import org.apache.spark.sql.execution.aggregate.ScalaUDAF +import org.apache.spark.sql.{Column, Row} import org.apache.spark.sql.types._ import org.apache.spark.annotation.Experimental /** * :: Experimental :: - * The abstract class for implementing user-defined aggregate functions. + * The base class for implementing user-defined aggregate functions (UDAF). */ @Experimental abstract class UserDefinedAggregateFunction extends Serializable { @@ -64,22 +66,35 @@ abstract class UserDefinedAggregateFunction extends Serializable { /** * The [[DataType]] of the returned value of this [[UserDefinedAggregateFunction]]. */ - def returnDataType: DataType + def dataType: DataType - /** Indicates if this function is deterministic. */ + /** + * Returns true iff this function is deterministic, i.e. given the same input, + * always return the same output. + */ def deterministic: Boolean /** - * Initializes the given aggregation buffer. Initial values set by this method should satisfy - * the condition that when merging two buffers with initial values, the new buffer - * still store initial values. + * Initializes the given aggregation buffer, i.e. the zero value of the aggregation buffer. + * + * The contract should be that applying the merge function on two initial buffers should just + * return the initial buffer itself, i.e. + * `merge(initialBuffer, initialBuffer)` should equal `initialBuffer`. 
*/ def initialize(buffer: MutableAggregationBuffer): Unit - /** Updates the given aggregation buffer `buffer` with new input data from `input`. */ + /** + * Updates the given aggregation buffer `buffer` with new input data from `input`. + * + * This is called once per input row. + */ def update(buffer: MutableAggregationBuffer, input: Row): Unit - /** Merges two aggregation buffers and stores the updated buffer values back to `buffer1`. */ + /** + * Merges two aggregation buffers and stores the updated buffer values back to `buffer1`. + * + * This is called when we merge two partially aggregated data together. + */ def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit /** @@ -87,14 +102,43 @@ abstract class UserDefinedAggregateFunction extends Serializable { * aggregation buffer. */ def evaluate(buffer: Row): Any + + /** + * Creates a [[Column]] for this UDAF using given [[Column]]s as input arguments. + */ + @scala.annotation.varargs + def apply(exprs: Column*): Column = { + val aggregateExpression = + AggregateExpression( + ScalaUDAF(exprs.map(_.expr), this), + Complete, + isDistinct = false) + Column(aggregateExpression) + } + + /** + * Creates a [[Column]] for this UDAF using the distinct values of the given + * [[Column]]s as input arguments. + */ + @scala.annotation.varargs + def distinct(exprs: Column*): Column = { + val aggregateExpression = + AggregateExpression( + ScalaUDAF(exprs.map(_.expr), this), + Complete, + isDistinct = true) + Column(aggregateExpression) + } } /** * :: Experimental :: * A [[Row]] representing an mutable aggregation buffer. + * + * This is not meant to be extended outside of Spark. */ @Experimental -trait MutableAggregationBuffer extends Row { +abstract class MutableAggregationBuffer extends Row { /** Update the ith value of this buffer. */ def update(i: Int, value: Any): Unit diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 35958299076c..65733dcf83e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -24,11 +24,33 @@ import scala.util.Try import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection} import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint import org.apache.spark.sql.types._ import org.apache.spark.util.Utils +/** + * Ensures that java functions signatures for methods that now return a [[TypedColumn]] still have + * legacy equivalents in bytecode. This compatibility is done by forcing the compiler to generate + * "bridge" methods due to the use of covariant return types. + * + * {{{ + * // In LegacyFunctions: + * public abstract org.apache.spark.sql.Column avg(java.lang.String); + * + * // In functions: + * public static org.apache.spark.sql.TypedColumn avg(...); + * }}} + * + * This allows us to use the same functions both in typed [[Dataset]] operations and untyped + * [[DataFrame]] operations when the return type for a given function is statically known. + */ +private[sql] abstract class LegacyFunctions { + def count(columnName: String): Column +} + /** * :: Experimental :: * Functions available for [[DataFrame]]. 
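The `LegacyFunctions` note above relies on ordinary JVM bridge methods for covariant return types rather than anything Spark-specific. A minimal, self-contained sketch of the same mechanism (hypothetical `LegacyBuilder`/`TypedBuilder` names, not Spark code):

  abstract class LegacyBuilder {
    // Old, wider signature that previously-compiled bytecode links against.
    def build(name: String): AnyRef
  }

  object TypedBuilder extends LegacyBuilder {
    // Covariant override: String <: AnyRef, so the compiler also emits a synthetic
    // bridge method `build(String): AnyRef` delegating to this one, and callers
    // compiled against LegacyBuilder keep linking without recompilation.
    override def build(name: String): String = s"built:$name"
  }

This is the same reason `object functions extends LegacyFunctions` can declare `count(columnName: String)` with the narrower `TypedColumn[Any, Long]` return type while the bytecode still exposes a `Column`-returning method.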
@@ -43,15 +65,21 @@ import org.apache.spark.util.Utils * @groupname window_funcs Window functions * @groupname string_funcs String functions * @groupname collection_funcs Collection functions - * @groupname Ungrouped Support functions for DataFrames. + * @groupname Ungrouped Support functions for DataFrames * @since 1.3.0 */ @Experimental // scalastyle:off -object functions { +object functions extends LegacyFunctions { // scalastyle:on - private[this] implicit def toColumn(expr: Expression): Column = Column(expr) + private def withExpr(expr: Expression): Column = Column(expr) + + private def withAggregateFunction( + func: AggregateFunction, + isDistinct: Boolean = false): Column = { + Column(func.toAggregateExpression(isDistinct)) + } /** * Returns a [[Column]] based on the given column name. @@ -128,7 +156,9 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def approxCountDistinct(e: Column): Column = ApproxCountDistinct(e.expr) + def approxCountDistinct(e: Column): Column = withAggregateFunction { + HyperLogLogPlusPlus(e.expr) + } /** * Aggregate function: returns the approximate number of distinct items in a group. @@ -144,7 +174,9 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def approxCountDistinct(e: Column, rsd: Double): Column = ApproxCountDistinct(e.expr, rsd) + def approxCountDistinct(e: Column, rsd: Double): Column = withAggregateFunction { + HyperLogLogPlusPlus(e.expr, rsd, 0, 0) + } /** * Aggregate function: returns the approximate number of distinct items in a group. @@ -162,7 +194,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def avg(e: Column): Column = Average(e.expr) + def avg(e: Column): Column = withAggregateFunction { Average(e.expr) } /** * Aggregate function: returns the average of the values in a group. @@ -172,16 +204,78 @@ object functions { */ def avg(columnName: String): Column = avg(Column(columnName)) + /** + * Aggregate function: returns a list of objects with duplicates. + * + * For now this is an alias for the collect_list Hive UDAF. + * + * @group agg_funcs + * @since 1.6.0 + */ + def collect_list(e: Column): Column = callUDF("collect_list", e) + + /** + * Aggregate function: returns a list of objects with duplicates. + * + * For now this is an alias for the collect_list Hive UDAF. + * + * @group agg_funcs + * @since 1.6.0 + */ + def collect_list(columnName: String): Column = collect_list(Column(columnName)) + + /** + * Aggregate function: returns a set of objects with duplicate elements eliminated. + * + * For now this is an alias for the collect_set Hive UDAF. + * + * @group agg_funcs + * @since 1.6.0 + */ + def collect_set(e: Column): Column = callUDF("collect_set", e) + + /** + * Aggregate function: returns a set of objects with duplicate elements eliminated. + * + * For now this is an alias for the collect_set Hive UDAF. + * + * @group agg_funcs + * @since 1.6.0 + */ + def collect_set(columnName: String): Column = collect_set(Column(columnName)) + + /** + * Aggregate function: returns the Pearson Correlation Coefficient for two columns. + * + * @group agg_funcs + * @since 1.6.0 + */ + def corr(column1: Column, column2: Column): Column = withAggregateFunction { + Corr(column1.expr, column2.expr) + } + + /** + * Aggregate function: returns the Pearson Correlation Coefficient for two columns. 
+ * + * @group agg_funcs + * @since 1.6.0 + */ + def corr(columnName1: String, columnName2: String): Column = { + corr(Column(columnName1), Column(columnName2)) + } + /** * Aggregate function: returns the number of items in a group. * * @group agg_funcs * @since 1.3.0 */ - def count(e: Column): Column = e.expr match { - // Turn count(*) into count(1) - case s: Star => Count(Literal(1)) - case _ => Count(e.expr) + def count(e: Column): Column = withAggregateFunction { + e.expr match { + // Turn count(*) into count(1) + case s: Star => Count(Literal(1)) + case _ => Count(e.expr) + } } /** @@ -190,7 +284,8 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def count(columnName: String): Column = count(Column(columnName)) + def count(columnName: String): TypedColumn[Any, Long] = + count(Column(columnName)).as(ExpressionEncoder[Long]) /** * Aggregate function: returns the number of distinct items in a group. @@ -199,8 +294,9 @@ object functions { * @since 1.3.0 */ @scala.annotation.varargs - def countDistinct(expr: Column, exprs: Column*): Column = - CountDistinct((expr +: exprs).map(_.expr)) + def countDistinct(expr: Column, exprs: Column*): Column = { + withAggregateFunction(Count.apply((expr +: exprs).map(_.expr)), isDistinct = true) + } /** * Aggregate function: returns the number of distinct items in a group. @@ -218,7 +314,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def first(e: Column): Column = First(e.expr) + def first(e: Column): Column = withAggregateFunction { new First(e.expr) } /** * Aggregate function: returns the first value of a column in a group. @@ -228,13 +324,29 @@ object functions { */ def first(columnName: String): Column = first(Column(columnName)) + /** + * Aggregate function: returns the kurtosis of the values in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def kurtosis(e: Column): Column = withAggregateFunction { Kurtosis(e.expr) } + + /** + * Aggregate function: returns the kurtosis of the values in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def kurtosis(columnName: String): Column = kurtosis(Column(columnName)) + /** * Aggregate function: returns the last value in a group. * * @group agg_funcs * @since 1.3.0 */ - def last(e: Column): Column = Last(e.expr) + def last(e: Column): Column = withAggregateFunction { new Last(e.expr) } /** * Aggregate function: returns the last value of the column in a group. @@ -250,7 +362,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def max(e: Column): Column = Max(e.expr) + def max(e: Column): Column = withAggregateFunction { Max(e.expr) } /** * Aggregate function: returns the maximum value of the column in a group. @@ -284,7 +396,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def min(e: Column): Column = Min(e.expr) + def min(e: Column): Column = withAggregateFunction { Min(e.expr) } /** * Aggregate function: returns the minimum value of the column in a group. @@ -294,13 +406,81 @@ object functions { */ def min(columnName: String): Column = min(Column(columnName)) + /** + * Aggregate function: returns the skewness of the values in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def skewness(e: Column): Column = withAggregateFunction { Skewness(e.expr) } + + /** + * Aggregate function: returns the skewness of the values in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def skewness(columnName: String): Column = skewness(Column(columnName)) + + /** + * Aggregate function: alias for [[stddev_samp]]. 
+ * + * @group agg_funcs + * @since 1.6.0 + */ + def stddev(e: Column): Column = withAggregateFunction { StddevSamp(e.expr) } + + /** + * Aggregate function: alias for [[stddev_samp]]. + * + * @group agg_funcs + * @since 1.6.0 + */ + def stddev(columnName: String): Column = stddev(Column(columnName)) + + /** + * Aggregate function: returns the sample standard deviation of + * the expression in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def stddev_samp(e: Column): Column = withAggregateFunction { StddevSamp(e.expr) } + + /** + * Aggregate function: returns the sample standard deviation of + * the expression in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def stddev_samp(columnName: String): Column = stddev_samp(Column(columnName)) + + /** + * Aggregate function: returns the population standard deviation of + * the expression in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def stddev_pop(e: Column): Column = withAggregateFunction { StddevPop(e.expr) } + + /** + * Aggregate function: returns the population standard deviation of + * the expression in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def stddev_pop(columnName: String): Column = stddev_pop(Column(columnName)) + /** * Aggregate function: returns the sum of all values in the expression. * * @group agg_funcs * @since 1.3.0 */ - def sum(e: Column): Column = Sum(e.expr) + def sum(e: Column): Column = withAggregateFunction { Sum(e.expr) } /** * Aggregate function: returns the sum of all values in the given column. @@ -316,7 +496,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def sumDistinct(e: Column): Column = SumDistinct(e.expr) + def sumDistinct(e: Column): Column = withAggregateFunction(Sum(e.expr), isDistinct = true) /** * Aggregate function: returns the sum of distinct values in the expression. @@ -326,10 +506,65 @@ object functions { */ def sumDistinct(columnName: String): Column = sumDistinct(Column(columnName)) + /** + * Aggregate function: alias for [[var_samp]]. + * + * @group agg_funcs + * @since 1.6.0 + */ + def variance(e: Column): Column = withAggregateFunction { VarianceSamp(e.expr) } + + /** + * Aggregate function: alias for [[var_samp]]. + * + * @group agg_funcs + * @since 1.6.0 + */ + def variance(columnName: String): Column = variance(Column(columnName)) + + /** + * Aggregate function: returns the unbiased variance of the values in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def var_samp(e: Column): Column = withAggregateFunction { VarianceSamp(e.expr) } + + /** + * Aggregate function: returns the unbiased variance of the values in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def var_samp(columnName: String): Column = var_samp(Column(columnName)) + + /** + * Aggregate function: returns the population variance of the values in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def var_pop(e: Column): Column = withAggregateFunction { VariancePop(e.expr) } + + /** + * Aggregate function: returns the population variance of the values in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def var_pop(columnName: String): Column = var_pop(Column(columnName)) + ////////////////////////////////////////////////////////////////////////////////////////////// // Window functions ////////////////////////////////////////////////////////////////////////////////////////////// + /** + * @group window_funcs + * @deprecated As of 1.6.0, replaced by `cume_dist`. This will be removed in Spark 2.0. + */ + @deprecated("Use cume_dist. 
This will be removed in Spark 2.0.", "1.6.0") + def cumeDist(): Column = cume_dist() + /** * Window function: returns the cumulative distribution of values within a window partition, * i.e. the fraction of rows that are below the current row. @@ -339,15 +574,17 @@ object functions { * cumeDist(x) = number of values before (and including) x / N * }}} * - * - * This is equivalent to the CUME_DIST function in SQL. - * * @group window_funcs - * @since 1.4.0 + * @since 1.6.0 */ - def cumeDist(): Column = { - UnresolvedWindowFunction("cume_dist", Nil) - } + def cume_dist(): Column = withExpr { new CumeDist } + + /** + * @group window_funcs + * @deprecated As of 1.6.0, replaced by `dense_rank`. This will be removed in Spark 2.0. + */ + @deprecated("Use dense_rank. This will be removed in Spark 2.0.", "1.6.0") + def denseRank(): Column = dense_rank() /** * Window function: returns the rank of rows within a window partition, without any gaps. @@ -357,14 +594,10 @@ object functions { * and had three people tie for second place, you would say that all three were in second * place and that the next person came in third. * - * This is equivalent to the DENSE_RANK function in SQL. - * * @group window_funcs - * @since 1.4.0 + * @since 1.6.0 */ - def denseRank(): Column = { - UnresolvedWindowFunction("dense_rank", Nil) - } + def dense_rank(): Column = withExpr { new DenseRank } /** * Window function: returns the value that is `offset` rows before the current row, and @@ -376,9 +609,7 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def lag(e: Column, offset: Int): Column = { - lag(e, offset, null) - } + def lag(e: Column, offset: Int): Column = lag(e, offset, null) /** * Window function: returns the value that is `offset` rows before the current row, and @@ -390,9 +621,7 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def lag(columnName: String, offset: Int): Column = { - lag(columnName, offset, null) - } + def lag(columnName: String, offset: Int): Column = lag(columnName, offset, null) /** * Window function: returns the value that is `offset` rows before the current row, and @@ -418,8 +647,8 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def lag(e: Column, offset: Int, defaultValue: Any): Column = { - UnresolvedWindowFunction("lag", e.expr :: Literal(offset) :: Literal(defaultValue) :: Nil) + def lag(e: Column, offset: Int, defaultValue: Any): Column = withExpr { + Lag(e.expr, Literal(offset), Literal(defaultValue)) } /** @@ -432,9 +661,7 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def lead(columnName: String, offset: Int): Column = { - lead(columnName, offset, null) - } + def lead(columnName: String, offset: Int): Column = { lead(columnName, offset, null) } /** * Window function: returns the value that is `offset` rows after the current row, and @@ -446,9 +673,7 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def lead(e: Column, offset: Int): Column = { - lead(e, offset, null) - } + def lead(e: Column, offset: Int): Column = { lead(e, offset, null) } /** * Window function: returns the value that is `offset` rows after the current row, and @@ -474,8 +699,8 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def lead(e: Column, offset: Int, defaultValue: Any): Column = { - UnresolvedWindowFunction("lead", e.expr :: Literal(offset) :: Literal(defaultValue) :: Nil) + def lead(e: Column, offset: Int, defaultValue: Any): Column = withExpr { + Lead(e.expr, Literal(offset), Literal(defaultValue)) } /** @@ -488,9 
+713,14 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def ntile(n: Int): Column = { - UnresolvedWindowFunction("ntile", lit(n).expr :: Nil) - } + def ntile(n: Int): Column = withExpr { new NTile(Literal(n)) } + + /** + * @group window_funcs + * @deprecated As of 1.6.0, replaced by `percent_rank`. This will be removed in Spark 2.0. + */ + @deprecated("Use percent_rank. This will be removed in Spark 2.0.", "1.6.0") + def percentRank(): Column = percent_rank() /** * Window function: returns the relative rank (i.e. percentile) of rows within a window partition. @@ -503,11 +733,9 @@ object functions { * This is equivalent to the PERCENT_RANK function in SQL. * * @group window_funcs - * @since 1.4.0 + * @since 1.6.0 */ - def percentRank(): Column = { - UnresolvedWindowFunction("percent_rank", Nil) - } + def percent_rank(): Column = withExpr { new PercentRank } /** * Window function: returns the rank of rows within a window partition. @@ -522,21 +750,22 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def rank(): Column = { - UnresolvedWindowFunction("rank", Nil) - } + def rank(): Column = withExpr { new Rank } + + /** + * @group window_funcs + * @deprecated As of 1.6.0, replaced by `row_number`. This will be removed in Spark 2.0. + */ + @deprecated("Use row_number. This will be removed in Spark 2.0.", "1.6.0") + def rowNumber(): Column = row_number() /** * Window function: returns a sequential number starting at 1 within a window partition. * - * This is equivalent to the ROW_NUMBER function in SQL. - * * @group window_funcs - * @since 1.4.0 + * @since 1.6.0 */ - def rowNumber(): Column = { - UnresolvedWindowFunction("row_number", Nil) - } + def row_number(): Column = withExpr { RowNumber() } ////////////////////////////////////////////////////////////////////////////////////////////// // Non-aggregate functions @@ -548,7 +777,7 @@ object functions { * @group normal_funcs * @since 1.3.0 */ - def abs(e: Column): Column = Abs(e.expr) + def abs(e: Column): Column = withExpr { Abs(e.expr) } /** * Creates a new array column. The input columns must all have the same data type. @@ -557,7 +786,7 @@ object functions { * @since 1.4.0 */ @scala.annotation.varargs - def array(cols: Column*): Column = CreateArray(cols.map(_.expr)) + def array(cols: Column*): Column = withExpr { CreateArray(cols.map(_.expr)) } /** * Creates a new array column. The input columns must all have the same data type. @@ -565,6 +794,7 @@ object functions { * @group normal_funcs * @since 1.4.0 */ + @scala.annotation.varargs def array(colName: String, colNames: String*): Column = { array((colName +: colNames).map(col) : _*) } @@ -595,22 +825,45 @@ object functions { * @since 1.3.0 */ @scala.annotation.varargs - def coalesce(e: Column*): Column = Coalesce(e.map(_.expr)) + def coalesce(e: Column*): Column = withExpr { Coalesce(e.map(_.expr)) } + + /** + * @group normal_funcs + * @deprecated As of 1.6.0, replaced by `input_file_name`. This will be removed in Spark 2.0. + */ + @deprecated("Use input_file_name. This will be removed in Spark 2.0.", "1.6.0") + def inputFileName(): Column = input_file_name() /** * Creates a string column for the file name of the current Spark task. * * @group normal_funcs + * @since 1.6.0 + */ + def input_file_name(): Column = withExpr { InputFileName() } + + /** + * @group normal_funcs + * @deprecated As of 1.6.0, replaced by `isnan`. This will be removed in Spark 2.0. */ - def inputFileName(): Column = InputFileName() + @deprecated("Use isnan. 
This will be removed in Spark 2.0.", "1.6.0") + def isNaN(e: Column): Column = isnan(e) /** * Return true iff the column is NaN. * * @group normal_funcs - * @since 1.5.0 + * @since 1.6.0 + */ + def isnan(e: Column): Column = withExpr { IsNaN(e.expr) } + + /** + * Return true iff the column is null. + * + * @group normal_funcs + * @since 1.6.0 */ - def isNaN(e: Column): Column = IsNaN(e.expr) + def isnull(e: Column): Column = withExpr { IsNull(e.expr) } /** * A column expression that generates monotonically increasing 64-bit integers. @@ -627,7 +880,24 @@ * @group normal_funcs * @since 1.4.0 */ - def monotonicallyIncreasingId(): Column = MonotonicallyIncreasingID() + def monotonicallyIncreasingId(): Column = monotonically_increasing_id() + + /** + * A column expression that generates monotonically increasing 64-bit integers. + * + * The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. + * The current implementation puts the partition ID in the upper 31 bits, and the record number + * within each partition in the lower 33 bits. The assumption is that the data frame has + * less than 1 billion partitions, and each partition has less than 8 billion records. + * + * As an example, consider a [[DataFrame]] with two partitions, each with 3 records. + * This expression would return the following IDs: + * 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. + * + * @group normal_funcs + * @since 1.6.0 + */ + def monotonically_increasing_id(): Column = withExpr { MonotonicallyIncreasingID() } /** * Returns col1 if it is not NaN, or col2 if col1 is NaN. @@ -637,7 +907,7 @@ * @group normal_funcs * @since 1.5.0 */ - def nanvl(col1: Column, col2: Column): Column = NaNvl(col1.expr, col2.expr) + def nanvl(col1: Column, col2: Column): Column = withExpr { NaNvl(col1.expr, col2.expr) } /** * Unary minus, i.e. negate the expression. @@ -676,7 +946,7 @@ * @group normal_funcs * @since 1.4.0 */ - def rand(seed: Long): Column = Rand(seed) + def rand(seed: Long): Column = withExpr { Rand(seed) } /** * Generate a random column with i.i.d. samples from U[0.0, 1.0]. @@ -692,7 +962,7 @@ * @group normal_funcs * @since 1.4.0 */ - def randn(seed: Long): Column = Randn(seed) + def randn(seed: Long): Column = withExpr { Randn(seed) } /** * Generate a column with i.i.d. samples from the standard normal distribution. @@ -702,15 +972,23 @@ */ def randn(): Column = randn(Utils.random.nextLong) + /** + * @group normal_funcs + * @since 1.4.0 + * @deprecated As of 1.6.0, replaced by `spark_partition_id`. This will be removed in Spark 2.0. + */ + @deprecated("Use spark_partition_id. This will be removed in Spark 2.0.", "1.6.0") + def sparkPartitionId(): Column = spark_partition_id() + /** * Partition ID of the Spark task. * * Note that this is indeterministic because it depends on data partitioning and task scheduling. * * @group normal_funcs - * @since 1.4.0 + * @since 1.6.0 */ - def sparkPartitionId(): Column = SparkPartitionID() + def spark_partition_id(): Column = withExpr { SparkPartitionID() } /** * Computes the square root of the specified float value. * @group math_funcs * @since 1.3.0 */ - def sqrt(e: Column): Column = Sqrt(e.expr) + def sqrt(e: Column): Column = withExpr { Sqrt(e.expr) } /** * Computes the square root of the specified float value. 
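The bit layout that the `monotonically_increasing_id` scaladoc describes can be checked with a few lines of plain Scala. A rough sketch of that layout only (an illustration, not the actual expression implementation):

  // Partition ID in the upper 31 bits, per-partition record number in the lower 33 bits.
  def sketchId(partitionId: Int, recordNumber: Long): Long =
    (partitionId.toLong << 33) | recordNumber

  sketchId(0, 0L) // 0
  sketchId(0, 2L) // 2
  sketchId(1, 0L) // 8589934592 == 1L << 33
  sketchId(1, 2L) // 8589934594

which reproduces the example IDs listed above for a DataFrame with two partitions of three records each.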
@@ -739,9 +1017,7 @@ object functions { * @since 1.4.0 */ @scala.annotation.varargs - def struct(cols: Column*): Column = { - CreateStruct(cols.map(_.expr)) - } + def struct(cols: Column*): Column = withExpr { CreateStruct(cols.map(_.expr)) } /** * Creates a new struct column that composes multiple input columns. @@ -749,6 +1025,7 @@ object functions { * @group normal_funcs * @since 1.4.0 */ + @scala.annotation.varargs def struct(colName: String, colNames: String*): Column = { struct((colName +: colNames).map(col) : _*) } @@ -774,7 +1051,7 @@ object functions { * @group normal_funcs * @since 1.4.0 */ - def when(condition: Column, value: Any): Column = { + def when(condition: Column, value: Any): Column = withExpr { CaseWhen(Seq(condition.expr, lit(value).expr)) } @@ -784,7 +1061,7 @@ object functions { * @group normal_funcs * @since 1.4.0 */ - def bitwiseNOT(e: Column): Column = BitwiseNot(e.expr) + def bitwiseNOT(e: Column): Column = withExpr { BitwiseNot(e.expr) } /** * Parses the expression string into the column that it represents, similar to @@ -796,7 +1073,7 @@ object functions { * * @group normal_funcs */ - def expr(expr: String): Column = Column(new SqlParser().parseExpression(expr)) + def expr(expr: String): Column = Column(SqlParser.parseExpression(expr)) ////////////////////////////////////////////////////////////////////////////////////////////// // Math Functions @@ -809,7 +1086,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def acos(e: Column): Column = Acos(e.expr) + def acos(e: Column): Column = withExpr { Acos(e.expr) } /** * Computes the cosine inverse of the given column; the returned angle is in the range @@ -827,7 +1104,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def asin(e: Column): Column = Asin(e.expr) + def asin(e: Column): Column = withExpr { Asin(e.expr) } /** * Computes the sine inverse of the given column; the returned angle is in the range @@ -844,7 +1121,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def atan(e: Column): Column = Atan(e.expr) + def atan(e: Column): Column = withExpr { Atan(e.expr) } /** * Computes the tangent inverse of the given column. 
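`when` above is now built through `withExpr`, while `expr` parses a SQL expression string into a `Column`. A small usage sketch combining the two (hypothetical DataFrame `df` with an `age` column):

  val labelled = df.select(
    when(expr("age >= 18"), "adult").otherwise("minor").as("age_group"))

`when(...)` returns a `Column`, `otherwise(...)` supplies the default branch, and the string passed to `expr(...)` goes through `SqlParser.parseExpression` as shown above.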
@@ -861,7 +1138,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def atan2(l: Column, r: Column): Column = Atan2(l.expr, r.expr) + def atan2(l: Column, r: Column): Column = withExpr { Atan2(l.expr, r.expr) } /** * Returns the angle theta from the conversion of rectangular coordinates (x, y) to @@ -898,7 +1175,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def atan2(l: Column, r: Double): Column = atan2(l, lit(r).expr) + def atan2(l: Column, r: Double): Column = atan2(l, lit(r)) /** * Returns the angle theta from the conversion of rectangular coordinates (x, y) to @@ -916,7 +1193,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def atan2(l: Double, r: Column): Column = atan2(lit(l).expr, r) + def atan2(l: Double, r: Column): Column = atan2(lit(l), r) /** * Returns the angle theta from the conversion of rectangular coordinates (x, y) to @@ -934,7 +1211,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def bin(e: Column): Column = Bin(e.expr) + def bin(e: Column): Column = withExpr { Bin(e.expr) } /** * An expression that returns the string representation of the binary value of the given long @@ -951,7 +1228,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def cbrt(e: Column): Column = Cbrt(e.expr) + def cbrt(e: Column): Column = withExpr { Cbrt(e.expr) } /** * Computes the cube-root of the given column. @@ -967,7 +1244,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def ceil(e: Column): Column = Ceil(e.expr) + def ceil(e: Column): Column = withExpr { Ceil(e.expr) } /** * Computes the ceiling of the given column. @@ -983,8 +1260,9 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def conv(num: Column, fromBase: Int, toBase: Int): Column = + def conv(num: Column, fromBase: Int, toBase: Int): Column = withExpr { Conv(num.expr, lit(fromBase).expr, lit(toBase).expr) + } /** * Computes the cosine of the given value. @@ -992,7 +1270,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def cos(e: Column): Column = Cos(e.expr) + def cos(e: Column): Column = withExpr { Cos(e.expr) } /** * Computes the cosine of the given column. @@ -1008,7 +1286,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def cosh(e: Column): Column = Cosh(e.expr) + def cosh(e: Column): Column = withExpr { Cosh(e.expr) } /** * Computes the hyperbolic cosine of the given column. @@ -1024,7 +1302,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def exp(e: Column): Column = Exp(e.expr) + def exp(e: Column): Column = withExpr { Exp(e.expr) } /** * Computes the exponential of the given column. @@ -1040,7 +1318,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def expm1(e: Column): Column = Expm1(e.expr) + def expm1(e: Column): Column = withExpr { Expm1(e.expr) } /** * Computes the exponential of the given column. @@ -1056,15 +1334,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def factorial(e: Column): Column = Factorial(e.expr) - - /** - * Computes the factorial of the given column. - * - * @group math_funcs - * @since 1.5.0 - */ - def factorial(columnName: String): Column = factorial(Column(columnName)) + def factorial(e: Column): Column = withExpr { Factorial(e.expr) } /** * Computes the floor of the given value. @@ -1072,7 +1342,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def floor(e: Column): Column = Floor(e.expr) + def floor(e: Column): Column = withExpr { Floor(e.expr) } /** * Computes the floor of the given column. 
@@ -1090,7 +1360,7 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def greatest(exprs: Column*): Column = { + def greatest(exprs: Column*): Column = withExpr { require(exprs.length > 1, "greatest requires at least 2 arguments.") Greatest(exprs.map(_.expr)) } @@ -1108,12 +1378,12 @@ object functions { } /** - * Computes hex value of the given column. - * - * @group math_funcs - * @since 1.5.0 - */ - def hex(column: Column): Column = Hex(column.expr) + * Computes hex value of the given column. + * + * @group math_funcs + * @since 1.5.0 + */ + def hex(column: Column): Column = withExpr { Hex(column.expr) } /** * Inverse of hex. Interprets each pair of characters as a hexadecimal number @@ -1122,7 +1392,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def unhex(column: Column): Column = Unhex(column.expr) + def unhex(column: Column): Column = withExpr { Unhex(column.expr) } /** * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. @@ -1130,7 +1400,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def hypot(l: Column, r: Column): Column = Hypot(l.expr, r.expr) + def hypot(l: Column, r: Column): Column = withExpr { Hypot(l.expr, r.expr) } /** * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. @@ -1163,7 +1433,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def hypot(l: Column, r: Double): Column = hypot(l, lit(r).expr) + def hypot(l: Column, r: Double): Column = hypot(l, lit(r)) /** * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. @@ -1179,7 +1449,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def hypot(l: Double, r: Column): Column = hypot(lit(l).expr, r) + def hypot(l: Double, r: Column): Column = hypot(lit(l), r) /** * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. @@ -1197,7 +1467,7 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def least(exprs: Column*): Column = { + def least(exprs: Column*): Column = withExpr { require(exprs.length > 1, "least requires at least 2 arguments.") Least(exprs.map(_.expr)) } @@ -1220,7 +1490,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def log(e: Column): Column = Log(e.expr) + def log(e: Column): Column = withExpr { Log(e.expr) } /** * Computes the natural logarithm of the given column. @@ -1236,7 +1506,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def log(base: Double, a: Column): Column = Logarithm(lit(base).expr, a.expr) + def log(base: Double, a: Column): Column = withExpr { Logarithm(lit(base).expr, a.expr) } /** * Returns the first argument-base logarithm of the second argument. @@ -1252,7 +1522,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def log10(e: Column): Column = Log10(e.expr) + def log10(e: Column): Column = withExpr { Log10(e.expr) } /** * Computes the logarithm of the given value in base 10. @@ -1268,7 +1538,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def log1p(e: Column): Column = Log1p(e.expr) + def log1p(e: Column): Column = withExpr { Log1p(e.expr) } /** * Computes the natural logarithm of the given column plus one. @@ -1284,7 +1554,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def log2(expr: Column): Column = Log2(expr.expr) + def log2(expr: Column): Column = withExpr { Log2(expr.expr) } /** * Computes the logarithm of the given value in base 2. 
@@ -1300,7 +1570,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def pow(l: Column, r: Column): Column = Pow(l.expr, r.expr) + def pow(l: Column, r: Column): Column = withExpr { Pow(l.expr, r.expr) } /** * Returns the value of the first argument raised to the power of the second argument. @@ -1332,7 +1602,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def pow(l: Column, r: Double): Column = pow(l, lit(r).expr) + def pow(l: Column, r: Double): Column = pow(l, lit(r)) /** * Returns the value of the first argument raised to the power of the second argument. @@ -1348,7 +1618,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def pow(l: Double, r: Column): Column = pow(lit(l).expr, r) + def pow(l: Double, r: Column): Column = pow(lit(l), r) /** * Returns the value of the first argument raised to the power of the second argument. @@ -1364,16 +1634,9 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def pmod(dividend: Column, divisor: Column): Column = Pmod(dividend.expr, divisor.expr) - - /** - * Returns the positive value of dividend mod divisor. - * - * @group math_funcs - * @since 1.5.0 - */ - def pmod(dividendColName: String, divisorColName: String): Column = - pmod(Column(dividendColName), Column(divisorColName)) + def pmod(dividend: Column, divisor: Column): Column = withExpr { + Pmod(dividend.expr, divisor.expr) + } /** * Returns the double value that is closest in value to the argument and @@ -1382,7 +1645,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def rint(e: Column): Column = Rint(e.expr) + def rint(e: Column): Column = withExpr { Rint(e.expr) } /** * Returns the double value that is closest in value to the argument and @@ -1399,15 +1662,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def round(e: Column): Column = round(e.expr, 0) - - /** - * Returns the value of the given column rounded to 0 decimal places. - * - * @group math_funcs - * @since 1.5.0 - */ - def round(columnName: String): Column = round(Column(columnName), 0) + def round(e: Column): Column = round(e, 0) /** * Round the value of `e` to `scale` decimal places if `scale` >= 0 @@ -1416,25 +1671,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def round(e: Column, scale: Int): Column = Round(e.expr, Literal(scale)) - - /** - * Round the value of the given column to `scale` decimal places if `scale` >= 0 - * or at integral part when `scale` < 0. - * - * @group math_funcs - * @since 1.5.0 - */ - def round(columnName: String, scale: Int): Column = round(Column(columnName), scale) - - /** - * Shift the the given value numBits left. If the given value is a long value, this function - * will return a long value else it will return an integer value. - * - * @group math_funcs - * @since 1.5.0 - */ - def shiftLeft(e: Column, numBits: Int): Column = ShiftLeft(e.expr, lit(numBits).expr) + def round(e: Column, scale: Int): Column = withExpr { Round(e.expr, Literal(scale)) } /** * Shift the the given value numBits left. If the given value is a long value, this function @@ -1443,8 +1680,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def shiftLeft(columnName: String, numBits: Int): Column = - shiftLeft(Column(columnName), numBits) + def shiftLeft(e: Column, numBits: Int): Column = withExpr { ShiftLeft(e.expr, lit(numBits).expr) } /** * Shift the the given value numBits right. 
If the given value is a long value, it will return @@ -1453,17 +1689,9 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def shiftRight(e: Column, numBits: Int): Column = ShiftRight(e.expr, lit(numBits).expr) - - /** - * Unsigned shift the the given value numBits right. If the given value is a long value, - * it will return a long value else it will return an integer value. - * - * @group math_funcs - * @since 1.5.0 - */ - def shiftRightUnsigned(columnName: String, numBits: Int): Column = - shiftRightUnsigned(Column(columnName), numBits) + def shiftRight(e: Column, numBits: Int): Column = withExpr { + ShiftRight(e.expr, lit(numBits).expr) + } /** * Unsigned shift the the given value numBits right. If the given value is a long value, @@ -1472,18 +1700,9 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def shiftRightUnsigned(e: Column, numBits: Int): Column = + def shiftRightUnsigned(e: Column, numBits: Int): Column = withExpr { ShiftRightUnsigned(e.expr, lit(numBits).expr) - - /** - * Shift the the given value numBits right. If the given value is a long value, it will return - * a long value else it will return an integer value. - * - * @group math_funcs - * @since 1.5.0 - */ - def shiftRight(columnName: String, numBits: Int): Column = - shiftRight(Column(columnName), numBits) + } /** * Computes the signum of the given value. @@ -1491,7 +1710,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def signum(e: Column): Column = Signum(e.expr) + def signum(e: Column): Column = withExpr { Signum(e.expr) } /** * Computes the signum of the given column. @@ -1507,7 +1726,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def sin(e: Column): Column = Sin(e.expr) + def sin(e: Column): Column = withExpr { Sin(e.expr) } /** * Computes the sine of the given column. @@ -1523,7 +1742,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def sinh(e: Column): Column = Sinh(e.expr) + def sinh(e: Column): Column = withExpr { Sinh(e.expr) } /** * Computes the hyperbolic sine of the given column. @@ -1539,7 +1758,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def tan(e: Column): Column = Tan(e.expr) + def tan(e: Column): Column = withExpr { Tan(e.expr) } /** * Computes the tangent of the given column. @@ -1555,7 +1774,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def tanh(e: Column): Column = Tanh(e.expr) + def tanh(e: Column): Column = withExpr { Tanh(e.expr) } /** * Computes the hyperbolic tangent of the given column. @@ -1571,7 +1790,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def toDegrees(e: Column): Column = ToDegrees(e.expr) + def toDegrees(e: Column): Column = withExpr { ToDegrees(e.expr) } /** * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. @@ -1587,7 +1806,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def toRadians(e: Column): Column = ToRadians(e.expr) + def toRadians(e: Column): Column = withExpr { ToRadians(e.expr) } /** * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. 
@@ -1608,7 +1827,7 @@ object functions { * @group misc_funcs * @since 1.5.0 */ - def md5(e: Column): Column = Md5(e.expr) + def md5(e: Column): Column = withExpr { Md5(e.expr) } /** * Calculates the SHA-1 digest of a binary column and returns the value @@ -1617,7 +1836,7 @@ object functions { * @group misc_funcs * @since 1.5.0 */ - def sha1(e: Column): Column = Sha1(e.expr) + def sha1(e: Column): Column = withExpr { Sha1(e.expr) } /** * Calculates the SHA-2 family of hash functions of a binary column and @@ -1632,7 +1851,7 @@ object functions { def sha2(e: Column, numBits: Int): Column = { require(Seq(0, 224, 256, 384, 512).contains(numBits), s"numBits $numBits is not in the permitted values (0, 224, 256, 384, 512)") - Sha2(e.expr, lit(numBits).expr) + withExpr { Sha2(e.expr, lit(numBits).expr) } } /** @@ -1642,7 +1861,7 @@ object functions { * @group misc_funcs * @since 1.5.0 */ - def crc32(e: Column): Column = Crc32(e.expr) + def crc32(e: Column): Column = withExpr { Crc32(e.expr) } ////////////////////////////////////////////////////////////////////////////////////////////// // String functions @@ -1655,7 +1874,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def ascii(e: Column): Column = Ascii(e.expr) + def ascii(e: Column): Column = withExpr { Ascii(e.expr) } /** * Computes the BASE64 encoding of a binary column and returns it as a string column. @@ -1664,7 +1883,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def base64(e: Column): Column = Base64(e.expr) + def base64(e: Column): Column = withExpr { Base64(e.expr) } /** * Concatenates multiple input string columns together into a single string column. @@ -1673,7 +1892,7 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def concat(exprs: Column*): Column = Concat(exprs.map(_.expr)) + def concat(exprs: Column*): Column = withExpr { Concat(exprs.map(_.expr)) } /** * Concatenates multiple input string columns together into a single string column, @@ -1683,7 +1902,7 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def concat_ws(sep: String, exprs: Column*): Column = { + def concat_ws(sep: String, exprs: Column*): Column = withExpr { ConcatWs(Literal.create(sep, StringType) +: exprs.map(_.expr)) } @@ -1695,7 +1914,9 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def decode(value: Column, charset: String): Column = Decode(value.expr, lit(charset).expr) + def decode(value: Column, charset: String): Column = withExpr { + Decode(value.expr, lit(charset).expr) + } /** * Computes the first argument into a binary from a string using the provided character set @@ -1705,7 +1926,9 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def encode(value: Column, charset: String): Column = Encode(value.expr, lit(charset).expr) + def encode(value: Column, charset: String): Column = withExpr { + Encode(value.expr, lit(charset).expr) + } /** * Formats numeric column x to a format like '#,###,###.##', rounded to d decimal places, @@ -1717,7 +1940,9 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def format_number(x: Column, d: Int): Column = FormatNumber(x.expr, lit(d).expr) + def format_number(x: Column, d: Int): Column = withExpr { + FormatNumber(x.expr, lit(d).expr) + } /** * Formats the arguments in printf-style and returns the result as a string column. 
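A hedged usage sketch for the hash and string hunks above (not part of the patch): public behaviour is unchanged, and sha2 still validates numBits eagerly via require before any expression is built. The DataFrame and column names are invented for illustration.

import org.apache.spark.sql.functions._
import sqlContext.implicits._

val people = sqlContext.createDataFrame(Seq(("Ada", "Lovelace", 1815.0)))
  .toDF("first", "last", "born")

people.select(
  concat_ws(" ", $"first", $"last").as("full_name"),
  format_number($"born", 0).as("born_fmt"),
  sha2($"last".cast("binary"), 256).as("last_sha256")
).show(false)

// An unsupported bit length fails at call time, before any job is launched:
// sha2($"last".cast("binary"), 123)  => IllegalArgumentException from require(...)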
@@ -1726,7 +1951,7 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def format_string(format: String, arguments: Column*): Column = { + def format_string(format: String, arguments: Column*): Column = withExpr { FormatString((lit(format) +: arguments).map(_.expr): _*) } @@ -1739,7 +1964,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def initcap(e: Column): Column = InitCap(e.expr) + def initcap(e: Column): Column = withExpr { InitCap(e.expr) } /** * Locate the position of the first occurrence of substr column in the given string. @@ -1751,7 +1976,9 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def instr(str: Column, substring: String): Column = StringInstr(str.expr, lit(substring).expr) + def instr(str: Column, substring: String): Column = withExpr { + StringInstr(str.expr, lit(substring).expr) + } /** * Computes the length of a given string or binary column. @@ -1759,7 +1986,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def length(e: Column): Column = Length(e.expr) + def length(e: Column): Column = withExpr { Length(e.expr) } /** * Converts a string column to lower case. @@ -1767,14 +1994,14 @@ object functions { * @group string_funcs * @since 1.3.0 */ - def lower(e: Column): Column = Lower(e.expr) + def lower(e: Column): Column = withExpr { Lower(e.expr) } /** * Computes the Levenshtein distance of the two given string columns. * @group string_funcs * @since 1.5.0 */ - def levenshtein(l: Column, r: Column): Column = Levenshtein(l.expr, r.expr) + def levenshtein(l: Column, r: Column): Column = withExpr { Levenshtein(l.expr, r.expr) } /** * Locate the position of the first occurrence of substr. @@ -1784,18 +2011,10 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def locate(substr: String, str: Column): Column = { + def locate(substr: String, str: Column): Column = withExpr { new StringLocate(lit(substr).expr, str.expr) } - /** - * Trim the spaces from left end for the specified string value. - * - * @group string_funcs - * @since 1.5.0 - */ - def ltrim(e: Column): Column = StringTrimLeft(e.expr) - /** * Locate the position of the first occurrence of substr in a string column, after position pos. * @@ -1805,7 +2024,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def locate(substr: String, str: Column, pos: Int): Column = { + def locate(substr: String, str: Column, pos: Int): Column = withExpr { StringLocate(lit(substr).expr, str.expr, lit(pos).expr) } @@ -1815,17 +2034,25 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def lpad(str: Column, len: Int, pad: String): Column = { + def lpad(str: Column, len: Int, pad: String): Column = withExpr { StringLPad(str.expr, lit(len).expr, lit(pad).expr) } + /** + * Trim the spaces from left end for the specified string value. + * + * @group string_funcs + * @since 1.5.0 + */ + def ltrim(e: Column): Column = withExpr {StringTrimLeft(e.expr) } + /** * Extract a specific(idx) group identified by a java regex, from the specified string column. 
* * @group string_funcs * @since 1.5.0 */ - def regexp_extract(e: Column, exp: String, groupIdx: Int): Column = { + def regexp_extract(e: Column, exp: String, groupIdx: Int): Column = withExpr { RegExpExtract(e.expr, lit(exp).expr, lit(groupIdx).expr) } @@ -1835,7 +2062,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def regexp_replace(e: Column, pattern: String, replacement: String): Column = { + def regexp_replace(e: Column, pattern: String, replacement: String): Column = withExpr { RegExpReplace(e.expr, lit(pattern).expr, lit(replacement).expr) } @@ -1846,7 +2073,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def unbase64(e: Column): Column = UnBase64(e.expr) + def unbase64(e: Column): Column = withExpr { UnBase64(e.expr) } /** * Right-padded with pad to a length of len. @@ -1854,7 +2081,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def rpad(str: Column, len: Int, pad: String): Column = { + def rpad(str: Column, len: Int, pad: String): Column = withExpr { StringRPad(str.expr, lit(len).expr, lit(pad).expr) } @@ -1864,7 +2091,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def repeat(str: Column, n: Int): Column = { + def repeat(str: Column, n: Int): Column = withExpr { StringRepeat(str.expr, lit(n).expr) } @@ -1874,9 +2101,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def reverse(str: Column): Column = { - StringReverse(str.expr) - } + def reverse(str: Column): Column = withExpr { StringReverse(str.expr) } /** * Trim the spaces from right end for the specified string value. @@ -1884,7 +2109,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def rtrim(e: Column): Column = StringTrimRight(e.expr) + def rtrim(e: Column): Column = withExpr { StringTrimRight(e.expr) } /** * * Return the soundex code for the specified expression. @@ -1892,7 +2117,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def soundex(e: Column): Column = SoundEx(e.expr) + def soundex(e: Column): Column = withExpr { SoundEx(e.expr) } /** * Splits str around pattern (pattern is a regular expression). @@ -1901,7 +2126,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def split(str: Column, pattern: String): Column = { + def split(str: Column, pattern: String): Column = withExpr { StringSplit(str.expr, lit(pattern).expr) } @@ -1913,8 +2138,9 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def substring(str: Column, pos: Int, len: Int): Column = + def substring(str: Column, pos: Int, len: Int): Column = withExpr { Substring(str.expr, lit(pos).expr, lit(len).expr) + } /** * Returns the substring from string str before count occurrences of the delimiter delim. @@ -1924,8 +2150,22 @@ object functions { * * @group string_funcs */ - def substring_index(str: Column, delim: String, count: Int): Column = + def substring_index(str: Column, delim: String, count: Int): Column = withExpr { SubstringIndex(str.expr, lit(delim).expr, lit(count).expr) + } + + /** + * Translate any character in the src by a character in replaceString. + * The characters in replaceString is corresponding to the characters in matchingString. + * The translate will happen when any character in the string matching with the character + * in the matchingString. 
+ * + * @group string_funcs + * @since 1.5.0 + */ + def translate(src: Column, matchingString: String, replaceString: String): Column = withExpr { + StringTranslate(src.expr, lit(matchingString).expr, lit(replaceString).expr) + } /** * Trim the spaces from both ends for the specified string column. @@ -1933,7 +2173,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def trim(e: Column): Column = StringTrim(e.expr) + def trim(e: Column): Column = withExpr { StringTrim(e.expr) } /** * Converts a string column to upper case. @@ -1941,7 +2181,7 @@ object functions { * @group string_funcs * @since 1.3.0 */ - def upper(e: Column): Column = Upper(e.expr) + def upper(e: Column): Column = withExpr { Upper(e.expr) } ////////////////////////////////////////////////////////////////////////////////////////////// // DateTime functions @@ -1953,8 +2193,9 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def add_months(startDate: Column, numMonths: Int): Column = + def add_months(startDate: Column, numMonths: Int): Column = withExpr { AddMonths(startDate.expr, Literal(numMonths)) + } /** * Returns the current date as a date column. @@ -1962,7 +2203,7 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def current_date(): Column = CurrentDate() + def current_date(): Column = withExpr { CurrentDate() } /** * Returns the current timestamp as a timestamp column. @@ -1970,7 +2211,7 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def current_timestamp(): Column = CurrentTimestamp() + def current_timestamp(): Column = withExpr { CurrentTimestamp() } /** * Converts a date/timestamp/string to a value of string in the format specified by the date @@ -1985,71 +2226,72 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def date_format(dateExpr: Column, format: String): Column = + def date_format(dateExpr: Column, format: String): Column = withExpr { DateFormatClass(dateExpr.expr, Literal(format)) + } /** * Returns the date that is `days` days after `start` * @group datetime_funcs * @since 1.5.0 */ - def date_add(start: Column, days: Int): Column = DateAdd(start.expr, Literal(days)) + def date_add(start: Column, days: Int): Column = withExpr { DateAdd(start.expr, Literal(days)) } /** * Returns the date that is `days` days before `start` * @group datetime_funcs * @since 1.5.0 */ - def date_sub(start: Column, days: Int): Column = DateSub(start.expr, Literal(days)) + def date_sub(start: Column, days: Int): Column = withExpr { DateSub(start.expr, Literal(days)) } /** * Returns the number of days from `start` to `end`. * @group datetime_funcs * @since 1.5.0 */ - def datediff(end: Column, start: Column): Column = DateDiff(end.expr, start.expr) + def datediff(end: Column, start: Column): Column = withExpr { DateDiff(end.expr, start.expr) } /** * Extracts the year as an integer from a given date/timestamp/string. * @group datetime_funcs * @since 1.5.0 */ - def year(e: Column): Column = Year(e.expr) + def year(e: Column): Column = withExpr { Year(e.expr) } /** * Extracts the quarter as an integer from a given date/timestamp/string. * @group datetime_funcs * @since 1.5.0 */ - def quarter(e: Column): Column = Quarter(e.expr) + def quarter(e: Column): Column = withExpr { Quarter(e.expr) } /** * Extracts the month as an integer from a given date/timestamp/string. 
* @group datetime_funcs * @since 1.5.0 */ - def month(e: Column): Column = Month(e.expr) + def month(e: Column): Column = withExpr { Month(e.expr) } /** * Extracts the day of the month as an integer from a given date/timestamp/string. * @group datetime_funcs * @since 1.5.0 */ - def dayofmonth(e: Column): Column = DayOfMonth(e.expr) + def dayofmonth(e: Column): Column = withExpr { DayOfMonth(e.expr) } /** * Extracts the day of the year as an integer from a given date/timestamp/string. * @group datetime_funcs * @since 1.5.0 */ - def dayofyear(e: Column): Column = DayOfYear(e.expr) + def dayofyear(e: Column): Column = withExpr { DayOfYear(e.expr) } /** * Extracts the hours as an integer from a given date/timestamp/string. * @group datetime_funcs * @since 1.5.0 */ - def hour(e: Column): Column = Hour(e.expr) + def hour(e: Column): Column = withExpr { Hour(e.expr) } /** * Given a date column, returns the last day of the month which the given date belongs to. @@ -2059,21 +2301,23 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def last_day(e: Column): Column = LastDay(e.expr) + def last_day(e: Column): Column = withExpr { LastDay(e.expr) } /** * Extracts the minutes as an integer from a given date/timestamp/string. * @group datetime_funcs * @since 1.5.0 */ - def minute(e: Column): Column = Minute(e.expr) + def minute(e: Column): Column = withExpr { Minute(e.expr) } /* * Returns number of months between dates `date1` and `date2`. * @group datetime_funcs * @since 1.5.0 */ - def months_between(date1: Column, date2: Column): Column = MonthsBetween(date1.expr, date2.expr) + def months_between(date1: Column, date2: Column): Column = withExpr { + MonthsBetween(date1.expr, date2.expr) + } /** * Given a date column, returns the first date which is later than the value of the date column @@ -2088,21 +2332,23 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def next_day(date: Column, dayOfWeek: String): Column = NextDay(date.expr, lit(dayOfWeek).expr) + def next_day(date: Column, dayOfWeek: String): Column = withExpr { + NextDay(date.expr, lit(dayOfWeek).expr) + } /** * Extracts the seconds as an integer from a given date/timestamp/string. * @group datetime_funcs * @since 1.5.0 */ - def second(e: Column): Column = Second(e.expr) + def second(e: Column): Column = withExpr { Second(e.expr) } /** * Extracts the week number as an integer from a given date/timestamp/string. * @group datetime_funcs * @since 1.5.0 */ - def weekofyear(e: Column): Column = WeekOfYear(e.expr) + def weekofyear(e: Column): Column = withExpr { WeekOfYear(e.expr) } /** * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string @@ -2111,7 +2357,9 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def from_unixtime(ut: Column): Column = FromUnixTime(ut.expr, Literal("yyyy-MM-dd HH:mm:ss")) + def from_unixtime(ut: Column): Column = withExpr { + FromUnixTime(ut.expr, Literal("yyyy-MM-dd HH:mm:ss")) + } /** * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string @@ -2120,14 +2368,18 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def from_unixtime(ut: Column, f: String): Column = FromUnixTime(ut.expr, Literal(f)) + def from_unixtime(ut: Column, f: String): Column = withExpr { + FromUnixTime(ut.expr, Literal(f)) + } /** * Gets current Unix timestamp in seconds. 
* @group datetime_funcs * @since 1.5.0 */ - def unix_timestamp(): Column = UnixTimestamp(CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")) + def unix_timestamp(): Column = withExpr { + UnixTimestamp(CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")) + } /** * Converts time string in format yyyy-MM-dd HH:mm:ss to Unix timestamp (in seconds), @@ -2135,7 +2387,9 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def unix_timestamp(s: Column): Column = UnixTimestamp(s.expr, Literal("yyyy-MM-dd HH:mm:ss")) + def unix_timestamp(s: Column): Column = withExpr { + UnixTimestamp(s.expr, Literal("yyyy-MM-dd HH:mm:ss")) + } /** * Convert time string with given pattern @@ -2144,7 +2398,7 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def unix_timestamp(s: Column, p: String): Column = UnixTimestamp(s.expr, Literal(p)) + def unix_timestamp(s: Column, p: String): Column = withExpr {UnixTimestamp(s.expr, Literal(p)) } /** * Converts the column into DateType. @@ -2152,7 +2406,7 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def to_date(e: Column): Column = ToDate(e.expr) + def to_date(e: Column): Column = withExpr { ToDate(e.expr) } /** * Returns date truncated to the unit specified by the format. @@ -2163,34 +2417,71 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def trunc(date: Column, format: String): Column = TruncDate(date.expr, Literal(format)) + def trunc(date: Column, format: String): Column = withExpr { + TruncDate(date.expr, Literal(format)) + } /** * Assumes given timestamp is UTC and converts to given timezone. * @group datetime_funcs * @since 1.5.0 */ - def from_utc_timestamp(ts: Column, tz: String): Column = - FromUTCTimestamp(ts.expr, Literal(tz).expr) + def from_utc_timestamp(ts: Column, tz: String): Column = withExpr { + FromUTCTimestamp(ts.expr, Literal(tz)) + } /** * Assumes given timestamp is in given timezone and converts to UTC. * @group datetime_funcs * @since 1.5.0 */ - def to_utc_timestamp(ts: Column, tz: String): Column = ToUTCTimestamp(ts.expr, Literal(tz).expr) + def to_utc_timestamp(ts: Column, tz: String): Column = withExpr { + ToUTCTimestamp(ts.expr, Literal(tz)) + } ////////////////////////////////////////////////////////////////////////////////////////////// // Collection functions ////////////////////////////////////////////////////////////////////////////////////////////// + /** + * Returns true if the array contain the value + * @group collection_funcs + * @since 1.5.0 + */ + def array_contains(column: Column, value: Any): Column = withExpr { + ArrayContains(column.expr, Literal(value)) + } + /** * Creates a new row for each element in the given array or map column. * * @group collection_funcs * @since 1.3.0 */ - def explode(e: Column): Column = Explode(e.expr) + def explode(e: Column): Column = withExpr { Explode(e.expr) } + + /** + * Extracts json object from a json string based on json path specified, and returns json string + * of the extracted json object. It will return null if the input json string is invalid. + * + * @group collection_funcs + * @since 1.6.0 + */ + def get_json_object(e: Column, path: String): Column = withExpr { + GetJsonObject(e.expr, lit(path).expr) + } + + /** + * Creates a new row for a json column according to the given field names. 
+ * + * @group collection_funcs + * @since 1.6.0 + */ + @scala.annotation.varargs + def json_tuple(json: Column, fields: String*): Column = withExpr { + require(fields.nonEmpty, "at least 1 field name should be given.") + JsonTuple(json.expr +: fields.map(Literal.apply)) + } /** * Returns length of array or map. @@ -2198,7 +2489,7 @@ object functions { * @group collection_funcs * @since 1.5.0 */ - def size(e: Column): Column = Size(e.expr) + def size(e: Column): Column = withExpr { Size(e.expr) } /** * Sorts the input array for the given column in ascending order, @@ -2216,7 +2507,7 @@ object functions { * @group collection_funcs * @since 1.5.0 */ - def sort_array(e: Column, asc: Boolean): Column = SortArray(e.expr, lit(asc).expr) + def sort_array(e: Column, asc: Boolean): Column = withExpr { SortArray(e.expr, lit(asc).expr) } ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// @@ -2256,11 +2547,10 @@ object functions { * @deprecated As of 1.5.0, since it's redundant with udf() */ @deprecated("Use udf", "1.5.0") - def callUDF(f: Function$x[$fTypes], returnType: DataType${if (args.length > 0) ", " + args else ""}): Column = { + def callUDF(f: Function$x[$fTypes], returnType: DataType${if (args.length > 0) ", " + args else ""}): Column = withExpr { ScalaUDF(f, returnType, Seq($argsInUDF)) }""") } - } */ /** * Defines a user-defined function of 0 arguments as user-defined function (UDF). @@ -2395,147 +2685,157 @@ object functions { } ////////////////////////////////////////////////////////////////////////////////////////////////// - /** - * Call a Scala function of 0 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ - @deprecated("Use udf", "1.5.0") - def callUDF(f: Function0[_], returnType: DataType): Column = { + * Call a Scala function of 0 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. + */ + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") + def callUDF(f: Function0[_], returnType: DataType): Column = withExpr { ScalaUDF(f, returnType, Seq()) } /** - * Call a Scala function of 1 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ - @deprecated("Use udf", "1.5.0") - def callUDF(f: Function1[_, _], returnType: DataType, arg1: Column): Column = { + * Call a Scala function of 1 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. + */ + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") + def callUDF(f: Function1[_, _], returnType: DataType, arg1: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr)) } /** - * Call a Scala function of 2 arguments as user-defined function (UDF). This requires - * you to specify the return data type. 
- * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ - @deprecated("Use udf", "1.5.0") - def callUDF(f: Function2[_, _, _], returnType: DataType, arg1: Column, arg2: Column): Column = { + * Call a Scala function of 2 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. + */ + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") + def callUDF(f: Function2[_, _, _], returnType: DataType, arg1: Column, arg2: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr)) } /** - * Call a Scala function of 3 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ - @deprecated("Use udf", "1.5.0") - def callUDF(f: Function3[_, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column): Column = { + * Call a Scala function of 3 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. + */ + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") + def callUDF(f: Function3[_, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr)) } /** - * Call a Scala function of 4 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ - @deprecated("Use udf", "1.5.0") - def callUDF(f: Function4[_, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = { + * Call a Scala function of 4 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. + */ + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") + def callUDF(f: Function4[_, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr)) } /** - * Call a Scala function of 5 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ - @deprecated("Use udf", "1.5.0") - def callUDF(f: Function5[_, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = { + * Call a Scala function of 5 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. + */ + @deprecated("Use udf. 
This will be removed in Spark 2.0.", "1.5.0") + def callUDF(f: Function5[_, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr)) } /** - * Call a Scala function of 6 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ - @deprecated("Use udf", "1.5.0") - def callUDF(f: Function6[_, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = { + * Call a Scala function of 6 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. + */ + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") + def callUDF(f: Function6[_, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr)) } /** - * Call a Scala function of 7 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ - @deprecated("Use udf", "1.5.0") - def callUDF(f: Function7[_, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = { + * Call a Scala function of 7 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. + */ + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") + def callUDF(f: Function7[_, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr)) } /** - * Call a Scala function of 8 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ - @deprecated("Use udf", "1.5.0") - def callUDF(f: Function8[_, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = { + * Call a Scala function of 8 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. + */ + @deprecated("Use udf. 
This will be removed in Spark 2.0.", "1.5.0") + def callUDF(f: Function8[_, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr)) } /** - * Call a Scala function of 9 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ - @deprecated("Use udf", "1.5.0") - def callUDF(f: Function9[_, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = { + * Call a Scala function of 9 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf(). + * This will be removed in Spark 2.0. + */ + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") + def callUDF(f: Function9[_, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr)) } /** - * Call a Scala function of 10 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ - @deprecated("Use udf", "1.5.0") - def callUDF(f: Function10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = { + * Call a Scala function of 10 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf(). + * This will be removed in Spark 2.0. + */ + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") + def callUDF(f: Function10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr)) } @@ -2556,7 +2856,8 @@ object functions { * @group udf_funcs * @since 1.5.0 */ - def callUDF(udfName: String, cols: Column*): Column = { + @scala.annotation.varargs + def callUDF(udfName: String, cols: Column*): Column = withExpr { UnresolvedFunction(udfName, cols.map(_.expr), isDistinct = false) } @@ -2574,10 +2875,11 @@ object functions { * * @group udf_funcs * @since 1.4.0 - * @deprecated As of 1.5.0, since it was not coherent to have two functions callUdf and callUDF + * @deprecated As of 1.5.0, since it was not coherent to have two functions callUdf and callUDF. + * This will be removed in Spark 2.0. 
*/ - @deprecated("Use callUDF", "1.5.0") - def callUdf(udfName: String, cols: Column*): Column = { + @deprecated("Use callUDF. This will be removed in Spark 2.0.", "1.5.0") + def callUdf(udfName: String, cols: Column*): Column = withExpr { // Note: we avoid using closures here because on file systems that are case-insensitive, the // compiled class file for the closure here will conflict with the one in callUDF (upper case). val exprs = new Array[Expression](cols.size) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala new file mode 100644 index 000000000000..467d8d62d1b7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import org.apache.spark.sql.types.{DataType, MetadataBuilder} + +/** + * AggregatedDialect can unify multiple dialects into one virtual Dialect. + * Dialects are tried in order, and the first dialect that does not return a + * neutral element will will. + * + * @param dialects List of dialects. + */ +private class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect { + + require(dialects.nonEmpty) + + override def canHandle(url : String): Boolean = + dialects.map(_.canHandle(url)).reduce(_ && _) + + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + dialects.flatMap(_.getCatalystType(sqlType, typeName, size, md)).headOption + } + + override def getJDBCType(dt: DataType): Option[JdbcType] = { + dialects.flatMap(_.getJDBCType(dt)).headOption + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala new file mode 100644 index 000000000000..b1cb0e55026b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import org.apache.spark.sql.types.{BooleanType, StringType, DataType} + + +private object DB2Dialect extends JdbcDialect { + + override def canHandle(url: String): Boolean = url.startsWith("jdbc:db2") + + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { + case StringType => Option(JdbcType("CLOB", java.sql.Types.CLOB)) + case BooleanType => Option(JdbcType("CHAR(1)", java.sql.Types.CHAR)) + case _ => None + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala new file mode 100644 index 000000000000..84f68e779c38 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.sql.Types + +import org.apache.spark.sql.types._ + + +private object DerbyDialect extends JdbcDialect { + + override def canHandle(url: String): Boolean = url.startsWith("jdbc:derby") + + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + if (sqlType == Types.REAL) Option(FloatType) else None + } + + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { + case StringType => Option(JdbcType("CLOB", java.sql.Types.CLOB)) + case ByteType => Option(JdbcType("SMALLINT", java.sql.Types.SMALLINT)) + case ShortType => Option(JdbcType("SMALLINT", java.sql.Types.SMALLINT)) + case BooleanType => Option(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN)) + // 31 is the maximum precision and 5 is the default scale for a Derby DECIMAL + case t: DecimalType if t.precision > 31 => + Option(JdbcType("DECIMAL(31,5)", java.sql.Types.DECIMAL)) + case _ => None + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 8849fc2f1f0e..13db141f27db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.jdbc -import java.sql.Types +import java.sql.Connection import org.apache.spark.sql.types._ import org.apache.spark.annotation.DeveloperApi @@ -53,7 +53,7 @@ case class JdbcType(databaseTypeDefinition : String, jdbcNullType : Int) * for the given Catalyst type. */ @DeveloperApi -abstract class JdbcDialect { +abstract class JdbcDialect extends Serializable { /** * Check if this dialect instance can handle a certain jdbc url. * @param url the jdbc url. 
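A hedged usage sketch for the new dialect objects above (not part of the patch): JdbcDialects.get picks a dialect via canHandle, which for these built-ins is a URL-prefix check, so an ordinary JDBC read against a DB2 URL now routes through DB2Dialect without any user configuration. Host, credentials and table names below are hypothetical.

import java.util.Properties

val props = new Properties()
props.setProperty("user", "dbuser")        // hypothetical credentials
props.setProperty("password", "secret")

// The "jdbc:db2" prefix selects DB2Dialect for type mapping.
val employees = sqlContext.read.jdbc("jdbc:db2://db2host:50000/SAMPLE", "EMPLOYEES", props)

// On write, DB2Dialect maps StringType to CLOB and BooleanType to CHAR(1).
employees.write.jdbc("jdbc:db2://db2host:50000/SAMPLE", "EMPLOYEES_COPY", props)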
@@ -88,6 +88,26 @@ abstract class JdbcDialect { def quoteIdentifier(colName: String): String = { s""""$colName"""" } + + /** + * Get the SQL query that should be used to find if the given table exists. Dialects can + * override this method to return a query that works best in a particular database. + * @param table The name of the table. + * @return The SQL query to use for checking the table. + */ + def getTableExistsQuery(table: String): String = { + s"SELECT * FROM $table WHERE 1=0" + } + + /** + * Override connection specific properties to run before a select is made. This is in place to + * allow dialects that need special treatment to optimize behavior. + * @param connection The connection object + * @param properties The connection properties. This is passed through from the relation. + */ + def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = { + } + } /** @@ -104,11 +124,10 @@ abstract class JdbcDialect { @DeveloperApi object JdbcDialects { - private var dialects = List[JdbcDialect]() - /** * Register a dialect for use on all new matching jdbc [[org.apache.spark.sql.DataFrame]]. * Readding an existing dialect will cause a move-to-front. + * * @param dialect The new dialect. */ def registerDialect(dialect: JdbcDialect) : Unit = { @@ -117,14 +136,21 @@ object JdbcDialects { /** * Unregister a dialect. Does nothing if the dialect is not registered. + * * @param dialect The jdbc dialect. */ def unregisterDialect(dialect : JdbcDialect) : Unit = { dialects = dialects.filterNot(_ == dialect) } + private[this] var dialects = List[JdbcDialect]() + registerDialect(MySQLDialect) registerDialect(PostgresDialect) + registerDialect(DB2Dialect) + registerDialect(MsSqlServerDialect) + registerDialect(DerbyDialect) + registerDialect(OracleDialect) /** * Fetch the JdbcDialect class corresponding to a given database url. @@ -140,85 +166,8 @@ object JdbcDialects { } /** - * :: DeveloperApi :: - * AggregatedDialect can unify multiple dialects into one virtual Dialect. - * Dialects are tried in order, and the first dialect that does not return a - * neutral element will will. - * @param dialects List of dialects. - */ -@DeveloperApi -class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect { - - require(dialects.nonEmpty) - - override def canHandle(url : String): Boolean = - dialects.map(_.canHandle(url)).reduce(_ && _) - - override def getCatalystType( - sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { - dialects.flatMap(_.getCatalystType(sqlType, typeName, size, md)).headOption - } - - override def getJDBCType(dt: DataType): Option[JdbcType] = { - dialects.flatMap(_.getJDBCType(dt)).headOption - } -} - -/** - * :: DeveloperApi :: * NOOP dialect object, always returning the neutral element. */ -@DeveloperApi -case object NoopDialect extends JdbcDialect { +private object NoopDialect extends JdbcDialect { override def canHandle(url : String): Boolean = true } - -/** - * :: DeveloperApi :: - * Default postgres dialect, mapping bit/cidr/inet on read and string/binary/boolean on write. 
- */ -@DeveloperApi -case object PostgresDialect extends JdbcDialect { - override def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql") - override def getCatalystType( - sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { - if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) { - Some(BinaryType) - } else if (sqlType == Types.OTHER && typeName.equals("cidr")) { - Some(StringType) - } else if (sqlType == Types.OTHER && typeName.equals("inet")) { - Some(StringType) - } else None - } - - override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { - case StringType => Some(JdbcType("TEXT", java.sql.Types.CHAR)) - case BinaryType => Some(JdbcType("BYTEA", java.sql.Types.BINARY)) - case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN)) - case _ => None - } -} - -/** - * :: DeveloperApi :: - * Default mysql dialect to read bit/bitsets correctly. - */ -@DeveloperApi -case object MySQLDialect extends JdbcDialect { - override def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql") - override def getCatalystType( - sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { - if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) { - // This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as - // byte arrays instead of longs. - md.putLong("binarylong", 1) - Some(LongType) - } else if (sqlType == Types.BIT && typeName.equals("TINYINT")) { - Some(BooleanType) - } else None - } - - override def quoteIdentifier(colName: String): String = { - s"`$colName`" - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcUtils.scala deleted file mode 100644 index cc918c237192..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcUtils.scala +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.jdbc - -import java.sql.{Connection, DriverManager} -import java.util.Properties - -import scala.util.Try - -/** - * Util functions for JDBC tables. - */ -private[sql] object JdbcUtils { - - /** - * Establishes a JDBC connection. - */ - def createConnection(url: String, connectionProperties: Properties): Connection = { - DriverManager.getConnection(url, connectionProperties) - } - - /** - * Returns true if the table already exists in the JDBC database. - */ - def tableExists(conn: Connection, table: String): Boolean = { - // Somewhat hacky, but there isn't a good way to identify whether a table exists for all - // SQL database systems, considering "table" could also include the database name. 
- Try(conn.prepareStatement(s"SELECT 1 FROM $table LIMIT 1").executeQuery().next()).isSuccess - } - - /** - * Drops a table from the JDBC database. - */ - def dropTable(conn: Connection, table: String): Unit = { - conn.prepareStatement(s"DROP TABLE $table").executeUpdate() - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala new file mode 100644 index 000000000000..3eb722b070d5 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import org.apache.spark.sql.types._ + + +private object MsSqlServerDialect extends JdbcDialect { + + override def canHandle(url: String): Boolean = url.startsWith("jdbc:sqlserver") + + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + if (typeName.contains("datetimeoffset")) { + // String is recommend by Microsoft SQL Server for datetimeoffset types in non-MS clients + Option(StringType) + } else { + None + } + } + + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { + case TimestampType => Some(JdbcType("DATETIME", java.sql.Types.TIMESTAMP)) + case _ => None + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala new file mode 100644 index 000000000000..da413ed1f08b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.jdbc + +import java.sql.Types + +import org.apache.spark.sql.types.{BooleanType, LongType, DataType, MetadataBuilder} + + +private case object MySQLDialect extends JdbcDialect { + + override def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql") + + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) { + // This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as + // byte arrays instead of longs. + md.putLong("binarylong", 1) + Option(LongType) + } else if (sqlType == Types.BIT && typeName.equals("TINYINT")) { + Option(BooleanType) + } else None + } + + override def quoteIdentifier(colName: String): String = { + s"`$colName`" + } + + override def getTableExistsQuery(table: String): String = { + s"SELECT 1 FROM $table LIMIT 1" + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala new file mode 100644 index 000000000000..4165c382689f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.sql.Types + +import org.apache.spark.sql.types._ + + +private case object OracleDialect extends JdbcDialect { + + override def canHandle(url: String): Boolean = url.startsWith("jdbc:oracle") + + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + // Handle NUMBER fields that have no precision/scale in special way + // because JDBC ResultSetMetaData converts this to 0 procision and -127 scale + // For more details, please see + // https://github.com/apache/spark/pull/8780#issuecomment-145598968 + // and + // https://github.com/apache/spark/pull/8780#issuecomment-144541760 + if (sqlType == Types.NUMERIC && size == 0) { + // This is sub-optimal as we have to pick a precision/scale in advance whereas the data + // in Oracle is allowed to have different precision/scale for each value. + Option(DecimalType(DecimalType.MAX_PRECISION, 10)) + } else { + None + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala new file mode 100644 index 000000000000..3cf80f576e92 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.sql.{Connection, Types} + +import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils +import org.apache.spark.sql.types._ + + +private object PostgresDialect extends JdbcDialect { + + override def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql") + + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) { + Some(BinaryType) + } else if (sqlType == Types.OTHER) { + toCatalystType(typeName).filter(_ == StringType) + } else if (sqlType == Types.ARRAY && typeName.length > 1 && typeName(0) == '_') { + toCatalystType(typeName.drop(1)).map(ArrayType(_)) + } else None + } + + // TODO: support more type names. + private def toCatalystType(typeName: String): Option[DataType] = typeName match { + case "bool" => Some(BooleanType) + case "bit" => Some(BinaryType) + case "int2" => Some(ShortType) + case "int4" => Some(IntegerType) + case "int8" | "oid" => Some(LongType) + case "float4" => Some(FloatType) + case "money" | "float8" => Some(DoubleType) + case "text" | "varchar" | "char" | "cidr" | "inet" | "json" | "jsonb" | "uuid" => + Some(StringType) + case "bytea" => Some(BinaryType) + case "timestamp" | "timestamptz" | "time" | "timetz" => Some(TimestampType) + case "date" => Some(DateType) + case "numeric" => Some(DecimalType.SYSTEM_DEFAULT) + case _ => None + } + + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { + case StringType => Some(JdbcType("TEXT", Types.CHAR)) + case BinaryType => Some(JdbcType("BYTEA", Types.BINARY)) + case BooleanType => Some(JdbcType("BOOLEAN", Types.BOOLEAN)) + case ArrayType(et, _) if et.isInstanceOf[AtomicType] => + getJDBCType(et).map(_.databaseTypeDefinition) + .orElse(JdbcUtils.getCommonJDBCType(et).map(_.databaseTypeDefinition)) + .map(typeName => JdbcType(s"$typeName[]", java.sql.Types.ARRAY)) + case _ => None + } + + override def getTableExistsQuery(table: String): String = { + s"SELECT 1 FROM $table LIMIT 1" + } + + override def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = { + super.beforeFetch(connection, properties) + + // According to the postgres jdbc documentation we need to be in autocommit=false if we actually + // want to have fetchsize be non 0 (all the rows). This allows us to not have to cache all the + // rows inside the driver when fetching. 
+ // + // See: https://jdbc.postgresql.org/documentation/head/query.html#query-with-cursor + // + if (properties.getOrElse("fetchsize", "0").toInt > 0) { + connection.setAutoCommit(false) + } + + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala deleted file mode 100644 index 035e0510080f..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala +++ /dev/null @@ -1,250 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import java.sql.{Connection, Driver, DriverManager, DriverPropertyInfo, PreparedStatement, SQLFeatureNotSupportedException} -import java.util.Properties - -import scala.collection.mutable - -import org.apache.spark.Logging -import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils - -package object jdbc { - private[sql] object JDBCWriteDetails extends Logging { - /** - * Returns a PreparedStatement that inserts a row into table via conn. - */ - def insertStatement(conn: Connection, table: String, rddSchema: StructType): - PreparedStatement = { - val sql = new StringBuilder(s"INSERT INTO $table VALUES (") - var fieldsLeft = rddSchema.fields.length - while (fieldsLeft > 0) { - sql.append("?") - if (fieldsLeft > 1) sql.append(", ") else sql.append(")") - fieldsLeft = fieldsLeft - 1 - } - conn.prepareStatement(sql.toString) - } - - /** - * Saves a partition of a DataFrame to the JDBC database. This is done in - * a single database transaction in order to avoid repeatedly inserting - * data as much as possible. - * - * It is still theoretically possible for rows in a DataFrame to be - * inserted into the database more than once if a stage somehow fails after - * the commit occurs but before the stage can return successfully. - * - * This is not a closure inside saveTable() because apparently cosmetic - * implementation changes elsewhere might easily render such a closure - * non-Serializable. Instead, we explicitly close over all variables that - * are used. - */ - def savePartition( - getConnection: () => Connection, - table: String, - iterator: Iterator[Row], - rddSchema: StructType, - nullTypes: Array[Int]): Iterator[Byte] = { - val conn = getConnection() - var committed = false - try { - conn.setAutoCommit(false) // Everything in the same db transaction. 
- val stmt = insertStatement(conn, table, rddSchema) - try { - while (iterator.hasNext) { - val row = iterator.next() - val numFields = rddSchema.fields.length - var i = 0 - while (i < numFields) { - if (row.isNullAt(i)) { - stmt.setNull(i + 1, nullTypes(i)) - } else { - rddSchema.fields(i).dataType match { - case IntegerType => stmt.setInt(i + 1, row.getInt(i)) - case LongType => stmt.setLong(i + 1, row.getLong(i)) - case DoubleType => stmt.setDouble(i + 1, row.getDouble(i)) - case FloatType => stmt.setFloat(i + 1, row.getFloat(i)) - case ShortType => stmt.setInt(i + 1, row.getShort(i)) - case ByteType => stmt.setInt(i + 1, row.getByte(i)) - case BooleanType => stmt.setBoolean(i + 1, row.getBoolean(i)) - case StringType => stmt.setString(i + 1, row.getString(i)) - case BinaryType => stmt.setBytes(i + 1, row.getAs[Array[Byte]](i)) - case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i)) - case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i)) - case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i)) - case _ => throw new IllegalArgumentException( - s"Can't translate non-null value for field $i") - } - } - i = i + 1 - } - stmt.executeUpdate() - } - } finally { - stmt.close() - } - conn.commit() - committed = true - } finally { - if (!committed) { - // The stage must fail. We got here through an exception path, so - // let the exception through unless rollback() or close() want to - // tell the user about another problem. - conn.rollback() - conn.close() - } else { - // The stage must succeed. We cannot propagate any exception close() might throw. - try { - conn.close() - } catch { - case e: Exception => logWarning("Transaction succeeded, but closing failed", e) - } - } - } - Array[Byte]().iterator - } - - /** - * Compute the schema string for this RDD. - */ - def schemaString(df: DataFrame, url: String): String = { - val sb = new StringBuilder() - val dialect = JdbcDialects.get(url) - df.schema.fields foreach { field => { - val name = field.name - val typ: String = - dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse( - field.dataType match { - case IntegerType => "INTEGER" - case LongType => "BIGINT" - case DoubleType => "DOUBLE PRECISION" - case FloatType => "REAL" - case ShortType => "INTEGER" - case ByteType => "BYTE" - case BooleanType => "BIT(1)" - case StringType => "TEXT" - case BinaryType => "BLOB" - case TimestampType => "TIMESTAMP" - case DateType => "DATE" - case t: DecimalType => s"DECIMAL(${t.precision}},${t.scale}})" - case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC") - }) - val nullable = if (field.nullable) "" else "NOT NULL" - sb.append(s", $name $typ $nullable") - }} - if (sb.length < 2) "" else sb.substring(2) - } - - /** - * Saves the RDD to the database in a single transaction. 
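Outside Spark, the transaction pattern used by savePartition above boils down to the following plain-JDBC sketch (connection string, table, and rows are hypothetical): every insert for a partition runs inside one transaction that is committed only after the whole iterator has been consumed, and rolled back otherwise.

import java.sql.DriverManager

val conn = DriverManager.getConnection(
  "jdbc:postgresql://localhost:5432/mydb", "spark", "secret")
var committed = false
try {
  conn.setAutoCommit(false)                        // one transaction per partition
  val stmt = conn.prepareStatement("INSERT INTO people VALUES (?, ?)")
  try {
    Seq(("alice", 30), ("bob", 25)).foreach { case (name, age) =>
      stmt.setString(1, name)
      stmt.setInt(2, age)
      stmt.executeUpdate()
    }
  } finally {
    stmt.close()
  }
  conn.commit()                                    // rows become visible only here
  committed = true
} finally {
  if (!committed) conn.rollback()                  // failure path: undo partial writes
  conn.close()
}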
- */ - def saveTable( - df: DataFrame, - url: String, - table: String, - properties: Properties = new Properties()) { - val dialect = JdbcDialects.get(url) - val nullTypes: Array[Int] = df.schema.fields.map { field => - dialect.getJDBCType(field.dataType).map(_.jdbcNullType).getOrElse( - field.dataType match { - case IntegerType => java.sql.Types.INTEGER - case LongType => java.sql.Types.BIGINT - case DoubleType => java.sql.Types.DOUBLE - case FloatType => java.sql.Types.REAL - case ShortType => java.sql.Types.INTEGER - case ByteType => java.sql.Types.INTEGER - case BooleanType => java.sql.Types.BIT - case StringType => java.sql.Types.CLOB - case BinaryType => java.sql.Types.BLOB - case TimestampType => java.sql.Types.TIMESTAMP - case DateType => java.sql.Types.DATE - case t: DecimalType => java.sql.Types.DECIMAL - case _ => throw new IllegalArgumentException( - s"Can't translate null value for field $field") - }) - } - - val rddSchema = df.schema - val driver: String = DriverRegistry.getDriverClassName(url) - val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties) - df.foreachPartition { iterator => - JDBCWriteDetails.savePartition(getConnection, table, iterator, rddSchema, nullTypes) - } - } - - } - - private [sql] class DriverWrapper(val wrapped: Driver) extends Driver { - override def acceptsURL(url: String): Boolean = wrapped.acceptsURL(url) - - override def jdbcCompliant(): Boolean = wrapped.jdbcCompliant() - - override def getPropertyInfo(url: String, info: Properties): Array[DriverPropertyInfo] = { - wrapped.getPropertyInfo(url, info) - } - - override def getMinorVersion: Int = wrapped.getMinorVersion - - def getParentLogger: java.util.logging.Logger = - throw new SQLFeatureNotSupportedException( - s"${this.getClass().getName}.getParentLogger is not yet implemented.") - - override def connect(url: String, info: Properties): Connection = wrapped.connect(url, info) - - override def getMajorVersion: Int = wrapped.getMajorVersion - } - - /** - * java.sql.DriverManager is always loaded by bootstrap classloader, - * so it can't load JDBC drivers accessible by Spark ClassLoader. - * - * To solve the problem, drivers from user-supplied jars are wrapped - * into thin wrapper. 
- */ - private [sql] object DriverRegistry extends Logging { - - private val wrapperMap: mutable.Map[String, DriverWrapper] = mutable.Map.empty - - def register(className: String): Unit = { - val cls = Utils.getContextOrSparkClassLoader.loadClass(className) - if (cls.getClassLoader == null) { - logTrace(s"$className has been loaded with bootstrap ClassLoader, wrapper is not required") - } else if (wrapperMap.get(className).isDefined) { - logTrace(s"Wrapper for $className already exists") - } else { - synchronized { - if (wrapperMap.get(className).isEmpty) { - val wrapper = new DriverWrapper(cls.newInstance().asInstanceOf[Driver]) - DriverManager.registerDriver(wrapper) - wrapperMap(className) = wrapper - logTrace(s"Wrapper for $className registered") - } - } - } - } - - def getDriverClassName(url: String): String = DriverManager.getDriver(url) match { - case wrapper: DriverWrapper => wrapper.wrapped.getClass.getCanonicalName - case driver => driver.getClass.getCanonicalName - } - } - -} // package object jdbc diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala deleted file mode 100644 index 562b058414d0..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ /dev/null @@ -1,224 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.json - -import java.io.IOException - -import org.apache.hadoop.fs.{FileSystem, Path} - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} -import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode} - - -private[sql] class DefaultSource - extends RelationProvider - with SchemaRelationProvider - with CreatableRelationProvider { - - private def checkPath(parameters: Map[String, String]): String = { - parameters.getOrElse("path", sys.error("'path' must be specified for json data.")) - } - - /** Constraints to be imposed on dataframe to be stored. */ - private def checkConstraints(data: DataFrame): Unit = { - if (data.schema.fieldNames.length != data.schema.fieldNames.distinct.length) { - val duplicateColumns = data.schema.fieldNames.groupBy(identity).collect { - case (x, ys) if ys.length > 1 => "\"" + x + "\"" - }.mkString(", ") - throw new AnalysisException(s"Duplicate column(s) : $duplicateColumns found, " + - s"cannot save to JSON format") - } - } - - /** Returns a new base relation with the parameters. 
*/ - override def createRelation( - sqlContext: SQLContext, - parameters: Map[String, String]): BaseRelation = { - val path = checkPath(parameters) - val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) - - new JSONRelation(path, samplingRatio, None, sqlContext) - } - - /** Returns a new base relation with the given schema and parameters. */ - override def createRelation( - sqlContext: SQLContext, - parameters: Map[String, String], - schema: StructType): BaseRelation = { - val path = checkPath(parameters) - val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) - - new JSONRelation(path, samplingRatio, Some(schema), sqlContext) - } - - override def createRelation( - sqlContext: SQLContext, - mode: SaveMode, - parameters: Map[String, String], - data: DataFrame): BaseRelation = { - // check if dataframe satisfies the constraints - // before moving forward - checkConstraints(data) - - val path = checkPath(parameters) - val filesystemPath = new Path(path) - val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - val doSave = if (fs.exists(filesystemPath)) { - mode match { - case SaveMode.Append => - sys.error(s"Append mode is not supported by ${this.getClass.getCanonicalName}") - case SaveMode.Overwrite => { - JSONRelation.delete(filesystemPath, fs) - true - } - case SaveMode.ErrorIfExists => - sys.error(s"path $path already exists.") - case SaveMode.Ignore => false - } - } else { - true - } - if (doSave) { - // Only save data when the save mode is not ignore. - data.toJSON.saveAsTextFile(path) - } - - createRelation(sqlContext, parameters, data.schema) - } -} - -private[sql] class JSONRelation( - // baseRDD is not immutable with respect to INSERT OVERWRITE - // and so it must be recreated at least as often as the - // underlying inputs are modified. To be safe, a function is - // used instead of a regular RDD value to ensure a fresh RDD is - // recreated for each and every operation. - baseRDD: () => RDD[String], - val path: Option[String], - val samplingRatio: Double, - userSpecifiedSchema: Option[StructType])( - @transient val sqlContext: SQLContext) - extends BaseRelation - with TableScan - with InsertableRelation - with CatalystScan { - - def this( - path: String, - samplingRatio: Double, - userSpecifiedSchema: Option[StructType], - sqlContext: SQLContext) = - this( - () => sqlContext.sparkContext.textFile(path), - Some(path), - samplingRatio, - userSpecifiedSchema)(sqlContext) - - /** Constraints to be imposed on dataframe to be stored. 
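Seen from the writer API, the SaveMode handling in the createRelation overload above behaves roughly as in this sketch (`df` is any DataFrame; the output path is hypothetical):

import org.apache.spark.sql.SaveMode

// Overwrite: delete the existing directory, then write the rows as JSON text files.
df.write.format("json").mode(SaveMode.Overwrite).save("/tmp/people_json")

// Ignore: silently skip the write when the path already exists.
df.write.format("json").mode(SaveMode.Ignore).save("/tmp/people_json")

// ErrorIfExists fails for an existing path, and Append is rejected by this relation.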
*/ - private def checkConstraints(data: DataFrame): Unit = { - if (data.schema.fieldNames.length != data.schema.fieldNames.distinct.length) { - val duplicateColumns = data.schema.fieldNames.groupBy(identity).collect { - case (x, ys) if ys.length > 1 => "\"" + x + "\"" - }.mkString(", ") - throw new AnalysisException(s"Duplicate column(s) : $duplicateColumns found, " + - s"cannot save to JSON format") - } - } - - override val needConversion: Boolean = false - - override lazy val schema = userSpecifiedSchema.getOrElse { - InferSchema( - baseRDD(), - samplingRatio, - sqlContext.conf.columnNameOfCorruptRecord) - } - - override def buildScan(): RDD[Row] = { - // Rely on type erasure hack to pass RDD[InternalRow] back as RDD[Row] - JacksonParser( - baseRDD(), - schema, - sqlContext.conf.columnNameOfCorruptRecord).asInstanceOf[RDD[Row]] - } - - override def buildScan(requiredColumns: Seq[Attribute], filters: Seq[Expression]): RDD[Row] = { - // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] - JacksonParser( - baseRDD(), - StructType.fromAttributes(requiredColumns), - sqlContext.conf.columnNameOfCorruptRecord).asInstanceOf[RDD[Row]] - } - - override def insert(data: DataFrame, overwrite: Boolean): Unit = { - // check if dataframe satisfies constraints - // before moving forward - checkConstraints(data) - - val filesystemPath = path match { - case Some(p) => new Path(p) - case None => - throw new IOException(s"Cannot INSERT into table with no path defined") - } - - val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - - if (overwrite) { - if (fs.exists(filesystemPath)) { - JSONRelation.delete(filesystemPath, fs) - } - // Write the data. - data.toJSON.saveAsTextFile(filesystemPath.toString) - // Right now, we assume that the schema is not changed. We will not update the schema. - // schema = data.schema - } else { - // TODO: Support INSERT INTO - sys.error("JSON table only support INSERT OVERWRITE for now.") - } - } - - override def hashCode(): Int = 41 * (41 + path.hashCode) + schema.hashCode() - - override def equals(other: Any): Boolean = other match { - case that: JSONRelation => - (this.path == that.path) && this.schema.sameType(that.schema) - case _ => false - } -} - -private object JSONRelation { - - /** Delete the specified directory to overwrite it with new JSON data. */ - def delete(dir: Path, fs: FileSystem): Unit = { - var success: Boolean = false - val failMessage = s"Unable to clear output directory $dir prior to writing to JSON table" - try { - success = fs.delete(dir, true /* recursive */) - } catch { - case e: IOException => - throw new IOException(s"$failMessage\n${e.toString}") - } - if (!success) { - throw new IOException(failMessage) - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala b/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala deleted file mode 100644 index 5a4dde575696..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala +++ /dev/null @@ -1,160 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.optimizer - -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys -import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter, LeftSemi} -import org.apache.spark.sql.catalyst.plans.logical.{Project, Filter, Join, LogicalPlan} -import org.apache.spark.sql.catalyst.rules.Rule - -/** - * An optimization rule used to insert Filters to filter out rows whose equal join keys - * have at least one null values. For this kind of rows, they will not contribute to - * the join results of equal joins because a null does not equal another null. We can - * filter them out before shuffling join input rows. For example, we have two tables - * - * table1(key String, value Int) - * "str1"|1 - * null |2 - * - * table2(key String, value Int) - * "str1"|3 - * null |4 - * - * For a inner equal join, the result will be - * "str1"|1|"str1"|3 - * - * those two rows having null as the value of key will not contribute to the result. - * So, we can filter them out early. - * - * This optimization rule can be disabled by setting spark.sql.advancedOptimization to false. - * - */ -case class FilterNullsInJoinKey( - sqlContext: SQLContext) - extends Rule[LogicalPlan] { - - /** - * Checks if we need to add a Filter operator. We will add a Filter when - * there is any attribute in `keys` whose corresponding attribute of `keys` - * in `plan.output` is still nullable (`nullable` field is `true`). - */ - private def needsFilter(keys: Seq[Expression], plan: LogicalPlan): Boolean = { - val keyAttributeSet = AttributeSet(keys.filter(_.isInstanceOf[Attribute])) - plan.output.filter(keyAttributeSet.contains).exists(_.nullable) - } - - /** - * Adds a Filter operator to make sure that every attribute in `keys` is non-nullable. - */ - private def addFilterIfNecessary( - keys: Seq[Expression], - child: LogicalPlan): LogicalPlan = { - // We get all attributes from keys. - val attributes = keys.filter(_.isInstanceOf[Attribute]) - - // Then, we create a Filter to make sure these attributes are non-nullable. - val filter = - if (attributes.nonEmpty) { - Filter(Not(AtLeastNNulls(1, attributes)), child) - } else { - child - } - - filter - } - - /** - * We reconstruct the join condition. - */ - private def reconstructJoinCondition( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - otherPredicate: Option[Expression]): Expression = { - // First, we rewrite the equal condition part. When we extract those keys, - // we use splitConjunctivePredicates. So, it is safe to use .reduce(And). - val rewrittenEqualJoinCondition = leftKeys.zip(rightKeys).map { - case (l, r) => EqualTo(l, r) - }.reduce(And) - - // Then, we add otherPredicate. When we extract those equal condition part, - // we use splitConjunctivePredicates. So, it is safe to use - // And(rewrittenEqualJoinCondition, c). 
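At the DataFrame level, the inner-join case of this rule is equivalent to pre-filtering NULL keys on both inputs, roughly as follows (`t1` and `t2` are hypothetical DataFrames with a `key` column):

// Original inner equi-join: t1.join(t2, t1("key") === t2("key"))
// A NULL key can never match, so both sides may be filtered before the shuffle.
val left  = t1.filter(t1("key").isNotNull)
val right = t2.filter(t2("key").isNotNull)
val rewritten = left.join(right, left("key") === right("key"))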
- val rewrittenJoinCondition = otherPredicate - .map(c => And(rewrittenEqualJoinCondition, c)) - .getOrElse(rewrittenEqualJoinCondition) - - rewrittenJoinCondition - } - - def apply(plan: LogicalPlan): LogicalPlan = { - if (!sqlContext.conf.advancedSqlOptimizations) { - plan - } else { - plan transform { - case join: Join => join match { - // For a inner join having equal join condition part, we can add filters - // to both sides of the join operator. - case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) - if needsFilter(leftKeys, left) || needsFilter(rightKeys, right) => - val withLeftFilter = addFilterIfNecessary(leftKeys, left) - val withRightFilter = addFilterIfNecessary(rightKeys, right) - val rewrittenJoinCondition = - reconstructJoinCondition(leftKeys, rightKeys, condition) - - Join(withLeftFilter, withRightFilter, Inner, Some(rewrittenJoinCondition)) - - // For a left outer join having equal join condition part, we can add a filter - // to the right side of the join operator. - case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, left, right) - if needsFilter(rightKeys, right) => - val withRightFilter = addFilterIfNecessary(rightKeys, right) - val rewrittenJoinCondition = - reconstructJoinCondition(leftKeys, rightKeys, condition) - - Join(left, withRightFilter, LeftOuter, Some(rewrittenJoinCondition)) - - // For a right outer join having equal join condition part, we can add a filter - // to the left side of the join operator. - case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, left, right) - if needsFilter(leftKeys, left) => - val withLeftFilter = addFilterIfNecessary(leftKeys, left) - val rewrittenJoinCondition = - reconstructJoinCondition(leftKeys, rightKeys, condition) - - Join(withLeftFilter, right, RightOuter, Some(rewrittenJoinCondition)) - - // For a left semi join having equal join condition part, we can add filters - // to both sides of the join operator. - case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) - if needsFilter(leftKeys, left) || needsFilter(rightKeys, right) => - val withLeftFilter = addFilterIfNecessary(leftKeys, left) - val withRightFilter = addFilterIfNecessary(rightKeys, right) - val rewrittenJoinCondition = - reconstructJoinCondition(leftKeys, rightKeys, condition) - - Join(withLeftFilter, withRightFilter, LeftSemi, Some(rewrittenJoinCondition)) - - case other => other - } - } - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystReadSupport.scala deleted file mode 100644 index 975fec101d9c..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystReadSupport.scala +++ /dev/null @@ -1,153 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parquet - -import java.util.{Map => JMap} - -import scala.collection.JavaConversions.{iterableAsScalaIterable, mapAsJavaMap, mapAsScalaMap} - -import org.apache.hadoop.conf.Configuration -import org.apache.parquet.hadoop.api.ReadSupport.ReadContext -import org.apache.parquet.hadoop.api.{InitContext, ReadSupport} -import org.apache.parquet.io.api.RecordMaterializer -import org.apache.parquet.schema.MessageType - -import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.types.StructType - -private[parquet] class CatalystReadSupport extends ReadSupport[InternalRow] with Logging { - override def prepareForRead( - conf: Configuration, - keyValueMetaData: JMap[String, String], - fileSchema: MessageType, - readContext: ReadContext): RecordMaterializer[InternalRow] = { - log.debug(s"Preparing for read Parquet file with message type: $fileSchema") - - val toCatalyst = new CatalystSchemaConverter(conf) - val parquetRequestedSchema = readContext.getRequestedSchema - - val catalystRequestedSchema = - Option(readContext.getReadSupportMetadata).map(_.toMap).flatMap { metadata => - metadata - // First tries to read requested schema, which may result from projections - .get(CatalystReadSupport.SPARK_ROW_REQUESTED_SCHEMA) - // If not available, tries to read Catalyst schema from file metadata. It's only - // available if the target file is written by Spark SQL. - .orElse(metadata.get(CatalystReadSupport.SPARK_METADATA_KEY)) - }.map(StructType.fromString).getOrElse { - logDebug("Catalyst schema not available, falling back to Parquet schema") - toCatalyst.convert(parquetRequestedSchema) - } - - logDebug(s"Catalyst schema used to read Parquet files: $catalystRequestedSchema") - new CatalystRecordMaterializer(parquetRequestedSchema, catalystRequestedSchema) - } - - override def init(context: InitContext): ReadContext = { - val conf = context.getConfiguration - - // If the target file was written by Spark SQL, we should be able to find a serialized Catalyst - // schema of this file from its the metadata. - val maybeRowSchema = Option(conf.get(RowWriteSupport.SPARK_ROW_SCHEMA)) - - // Optional schema of requested columns, in the form of a string serialized from a Catalyst - // `StructType` containing all requested columns. - val maybeRequestedSchema = Option(conf.get(CatalystReadSupport.SPARK_ROW_REQUESTED_SCHEMA)) - - // Below we construct a Parquet schema containing all requested columns. This schema tells - // Parquet which columns to read. - // - // If `maybeRequestedSchema` is defined, we assemble an equivalent Parquet schema. Otherwise, - // we have to fallback to the full file schema which contains all columns in the file. - // Obviously this may waste IO bandwidth since it may read more columns than requested. - // - // Two things to note: - // - // 1. It's possible that some requested columns don't exist in the target Parquet file. For - // example, in the case of schema merging, the globally merged schema may contain extra - // columns gathered from other Parquet files. These columns will be simply filled with nulls - // when actually reading the target Parquet file. - // - // 2. 
When `maybeRequestedSchema` is available, we can't simply convert the Catalyst schema to - // Parquet schema using `CatalystSchemaConverter`, because the mapping is not unique due to - // non-standard behaviors of some Parquet libraries/tools. For example, a Parquet file - // containing a single integer array field `f1` may have the following legacy 2-level - // structure: - // - // message root { - // optional group f1 (LIST) { - // required INT32 element; - // } - // } - // - // while `CatalystSchemaConverter` may generate a standard 3-level structure: - // - // message root { - // optional group f1 (LIST) { - // repeated group list { - // required INT32 element; - // } - // } - // } - // - // Apparently, we can't use the 2nd schema to read the target Parquet file as they have - // different physical structures. - val parquetRequestedSchema = - maybeRequestedSchema.fold(context.getFileSchema) { schemaString => - val toParquet = new CatalystSchemaConverter(conf) - val fileSchema = context.getFileSchema.asGroupType() - val fileFieldNames = fileSchema.getFields.map(_.getName).toSet - - StructType - // Deserializes the Catalyst schema of requested columns - .fromString(schemaString) - .map { field => - if (fileFieldNames.contains(field.name)) { - // If the field exists in the target Parquet file, extracts the field type from the - // full file schema and makes a single-field Parquet schema - new MessageType("root", fileSchema.getType(field.name)) - } else { - // Otherwise, just resorts to `CatalystSchemaConverter` - toParquet.convert(StructType(Array(field))) - } - } - // Merges all single-field Parquet schemas to form a complete schema for all requested - // columns. Note that it's possible that no columns are requested at all (e.g., count - // some partition column of a partitioned Parquet table). That's why `fold` is used here - // and always fallback to an empty Parquet schema. - .fold(new MessageType("root")) { - _ union _ - } - } - - val metadata = - Map.empty[String, String] ++ - maybeRequestedSchema.map(CatalystReadSupport.SPARK_ROW_REQUESTED_SCHEMA -> _) ++ - maybeRowSchema.map(RowWriteSupport.SPARK_ROW_SCHEMA -> _) - - logInfo(s"Going to read Parquet file with these requested columns: $parquetRequestedSchema") - new ReadContext(parquetRequestedSchema, metadata) - } -} - -private[parquet] object CatalystReadSupport { - val SPARK_ROW_REQUESTED_SCHEMA = "org.apache.spark.sql.parquet.row.requested_schema" - - val SPARK_METADATA_KEY = "org.apache.spark.sql.parquet.row.metadata" -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala deleted file mode 100644 index 6ed3580af072..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parquet - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.types.{MapData, ArrayData} - -// TODO Removes this while fixing SPARK-8848 -private[sql] object CatalystConverter { - // This is mostly Parquet convention (see, e.g., `ConversionPatterns`). - // Note that "array" for the array elements is chosen by ParquetAvro. - // Using a different value will result in Parquet silently dropping columns. - val ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME = "bag" - val ARRAY_ELEMENTS_SCHEMA_NAME = "array" - - val MAP_KEY_SCHEMA_NAME = "key" - val MAP_VALUE_SCHEMA_NAME = "value" - val MAP_SCHEMA_NAME = "map" - - // TODO: consider using Array[T] for arrays to avoid boxing of primitive types - type ArrayScalaType = ArrayData - type StructScalaType = InternalRow - type MapScalaType = MapData -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala deleted file mode 100644 index 9cd0250f9c51..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ /dev/null @@ -1,322 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parquet - -import java.math.BigInteger -import java.nio.{ByteBuffer, ByteOrder} -import java.util.{HashMap => JHashMap} - -import org.apache.hadoop.conf.Configuration -import org.apache.parquet.column.ParquetProperties -import org.apache.parquet.hadoop.ParquetOutputFormat -import org.apache.parquet.hadoop.api.WriteSupport -import org.apache.parquet.io.api._ - -import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String - -/** - * A `parquet.hadoop.api.WriteSupport` for Row objects. 
- */ -private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Logging { - - private[parquet] var writer: RecordConsumer = null - private[parquet] var attributes: Array[Attribute] = null - - override def init(configuration: Configuration): WriteSupport.WriteContext = { - val origAttributesStr: String = configuration.get(RowWriteSupport.SPARK_ROW_SCHEMA) - val metadata = new JHashMap[String, String]() - metadata.put(CatalystReadSupport.SPARK_METADATA_KEY, origAttributesStr) - - if (attributes == null) { - attributes = ParquetTypesConverter.convertFromString(origAttributesStr).toArray - } - - log.debug(s"write support initialized for requested schema $attributes") - ParquetRelation.enableLogForwarding() - new WriteSupport.WriteContext(ParquetTypesConverter.convertFromAttributes(attributes), metadata) - } - - override def prepareForWrite(recordConsumer: RecordConsumer): Unit = { - writer = recordConsumer - log.debug(s"preparing for write with schema $attributes") - } - - override def write(record: InternalRow): Unit = { - val attributesSize = attributes.size - if (attributesSize > record.numFields) { - throw new IndexOutOfBoundsException("Trying to write more fields than contained in row " + - s"($attributesSize > ${record.numFields})") - } - - var index = 0 - writer.startMessage() - while(index < attributesSize) { - // null values indicate optional fields but we do not check currently - if (!record.isNullAt(index)) { - writer.startField(attributes(index).name, index) - writeValue(attributes(index).dataType, record.get(index, attributes(index).dataType)) - writer.endField(attributes(index).name, index) - } - index = index + 1 - } - writer.endMessage() - } - - private[parquet] def writeValue(schema: DataType, value: Any): Unit = { - if (value != null) { - schema match { - case t: UserDefinedType[_] => writeValue(t.sqlType, value) - case t @ ArrayType(_, _) => writeArray( - t, - value.asInstanceOf[CatalystConverter.ArrayScalaType]) - case t @ MapType(_, _, _) => writeMap( - t, - value.asInstanceOf[CatalystConverter.MapScalaType]) - case t @ StructType(_) => writeStruct( - t, - value.asInstanceOf[CatalystConverter.StructScalaType]) - case _ => writePrimitive(schema.asInstanceOf[AtomicType], value) - } - } - } - - private[parquet] def writePrimitive(schema: DataType, value: Any): Unit = { - if (value != null) { - schema match { - case BooleanType => writer.addBoolean(value.asInstanceOf[Boolean]) - case ByteType => writer.addInteger(value.asInstanceOf[Byte]) - case ShortType => writer.addInteger(value.asInstanceOf[Short]) - case IntegerType | DateType => writer.addInteger(value.asInstanceOf[Int]) - case LongType => writer.addLong(value.asInstanceOf[Long]) - case TimestampType => writeTimestamp(value.asInstanceOf[Long]) - case FloatType => writer.addFloat(value.asInstanceOf[Float]) - case DoubleType => writer.addDouble(value.asInstanceOf[Double]) - case StringType => writer.addBinary( - Binary.fromByteArray(value.asInstanceOf[UTF8String].getBytes)) - case BinaryType => writer.addBinary( - Binary.fromByteArray(value.asInstanceOf[Array[Byte]])) - case DecimalType.Fixed(precision, _) => - writeDecimal(value.asInstanceOf[Decimal], precision) - case _ => sys.error(s"Do not know how to writer $schema to consumer") - } - } - } - - private[parquet] def writeStruct( - schema: StructType, - struct: CatalystConverter.StructScalaType): Unit = { - if (struct != null) { - val fields = schema.fields.toArray - writer.startGroup() - var i = 0 - while(i < fields.length) { - if 
(!struct.isNullAt(i)) { - writer.startField(fields(i).name, i) - writeValue(fields(i).dataType, struct.get(i, fields(i).dataType)) - writer.endField(fields(i).name, i) - } - i = i + 1 - } - writer.endGroup() - } - } - - private[parquet] def writeArray( - schema: ArrayType, - array: CatalystConverter.ArrayScalaType): Unit = { - val elementType = schema.elementType - writer.startGroup() - if (array.numElements() > 0) { - if (schema.containsNull) { - writer.startField(CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME, 0) - var i = 0 - while (i < array.numElements()) { - writer.startGroup() - if (!array.isNullAt(i)) { - writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) - writeValue(elementType, array.get(i, elementType)) - writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) - } - writer.endGroup() - i = i + 1 - } - writer.endField(CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME, 0) - } else { - writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) - var i = 0 - while (i < array.numElements()) { - writeValue(elementType, array.get(i, elementType)) - i = i + 1 - } - writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) - } - } - writer.endGroup() - } - - private[parquet] def writeMap( - schema: MapType, - map: CatalystConverter.MapScalaType): Unit = { - writer.startGroup() - val length = map.numElements() - if (length > 0) { - writer.startField(CatalystConverter.MAP_SCHEMA_NAME, 0) - map.foreach(schema.keyType, schema.valueType, (key, value) => { - writer.startGroup() - writer.startField(CatalystConverter.MAP_KEY_SCHEMA_NAME, 0) - writeValue(schema.keyType, key) - writer.endField(CatalystConverter.MAP_KEY_SCHEMA_NAME, 0) - if (value != null) { - writer.startField(CatalystConverter.MAP_VALUE_SCHEMA_NAME, 1) - writeValue(schema.valueType, value) - writer.endField(CatalystConverter.MAP_VALUE_SCHEMA_NAME, 1) - } - writer.endGroup() - }) - writer.endField(CatalystConverter.MAP_SCHEMA_NAME, 0) - } - writer.endGroup() - } - - // Scratch array used to write decimals as fixed-length byte array - private[this] var reusableDecimalBytes = new Array[Byte](16) - - private[parquet] def writeDecimal(decimal: Decimal, precision: Int): Unit = { - val numBytes = CatalystSchemaConverter.minBytesForPrecision(precision) - - def longToBinary(unscaled: Long): Binary = { - var i = 0 - var shift = 8 * (numBytes - 1) - while (i < numBytes) { - reusableDecimalBytes(i) = (unscaled >> shift).toByte - i += 1 - shift -= 8 - } - Binary.fromByteArray(reusableDecimalBytes, 0, numBytes) - } - - def bigIntegerToBinary(unscaled: BigInteger): Binary = { - unscaled.toByteArray match { - case bytes if bytes.length == numBytes => - Binary.fromByteArray(bytes) - - case bytes if bytes.length <= reusableDecimalBytes.length => - val signedByte = (if (bytes.head < 0) -1 else 0).toByte - java.util.Arrays.fill(reusableDecimalBytes, 0, numBytes - bytes.length, signedByte) - System.arraycopy(bytes, 0, reusableDecimalBytes, numBytes - bytes.length, bytes.length) - Binary.fromByteArray(reusableDecimalBytes, 0, numBytes) - - case bytes => - reusableDecimalBytes = new Array[Byte](bytes.length) - bigIntegerToBinary(unscaled) - } - } - - val binary = if (numBytes <= 8) { - longToBinary(decimal.toUnscaledLong) - } else { - bigIntegerToBinary(decimal.toJavaBigDecimal.unscaledValue()) - } - - writer.addBinary(binary) - } - - // array used to write Timestamp as Int96 (fixed-length binary) - private[this] val int96buf = new Array[Byte](12) - - private[parquet] def writeTimestamp(ts: Long): 
Unit = { - val (julianDay, timeOfDayNanos) = DateTimeUtils.toJulianDay(ts) - val buf = ByteBuffer.wrap(int96buf) - buf.order(ByteOrder.LITTLE_ENDIAN) - buf.putLong(timeOfDayNanos) - buf.putInt(julianDay) - writer.addBinary(Binary.fromByteArray(int96buf)) - } -} - -// Optimized for non-nested rows -private[parquet] class MutableRowWriteSupport extends RowWriteSupport { - override def write(record: InternalRow): Unit = { - val attributesSize = attributes.size - if (attributesSize > record.numFields) { - throw new IndexOutOfBoundsException("Trying to write more fields than contained in row " + - s"($attributesSize > ${record.numFields})") - } - - var index = 0 - writer.startMessage() - while(index < attributesSize) { - // null values indicate optional fields but we do not check currently - if (!record.isNullAt(index) && !record.isNullAt(index)) { - writer.startField(attributes(index).name, index) - consumeType(attributes(index).dataType, record, index) - writer.endField(attributes(index).name, index) - } - index = index + 1 - } - writer.endMessage() - } - - private def consumeType( - ctype: DataType, - record: InternalRow, - index: Int): Unit = { - ctype match { - case BooleanType => writer.addBoolean(record.getBoolean(index)) - case ByteType => writer.addInteger(record.getByte(index)) - case ShortType => writer.addInteger(record.getShort(index)) - case IntegerType | DateType => writer.addInteger(record.getInt(index)) - case LongType => writer.addLong(record.getLong(index)) - case TimestampType => writeTimestamp(record.getLong(index)) - case FloatType => writer.addFloat(record.getFloat(index)) - case DoubleType => writer.addDouble(record.getDouble(index)) - case StringType => - writer.addBinary(Binary.fromByteArray(record.getUTF8String(index).getBytes)) - case BinaryType => - writer.addBinary(Binary.fromByteArray(record.getBinary(index))) - case DecimalType.Fixed(precision, scale) => - writeDecimal(record.getDecimal(index, precision, scale), precision) - case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer") - } - } -} - -private[parquet] object RowWriteSupport { - val SPARK_ROW_SCHEMA: String = "org.apache.spark.sql.parquet.row.attributes" - - def getSchema(configuration: Configuration): Seq[Attribute] = { - val schemaString = configuration.get(RowWriteSupport.SPARK_ROW_SCHEMA) - if (schemaString == null) { - throw new RuntimeException("Missing schema!") - } - ParquetTypesConverter.convertFromString(schemaString) - } - - def setSchema(schema: Seq[Attribute], configuration: Configuration) { - val encoded = ParquetTypesConverter.convertToString(schema) - configuration.set(SPARK_ROW_SCHEMA, encoded) - configuration.set( - ParquetOutputFormat.WRITER_VERSION, - ParquetProperties.WriterVersion.PARQUET_1_0.toString) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala deleted file mode 100644 index 3854f5bd39fb..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ /dev/null @@ -1,159 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parquet - -import java.io.IOException - -import scala.collection.JavaConversions._ -import scala.util.Try - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.mapreduce.Job -import org.apache.parquet.format.converter.ParquetMetadataConverter -import org.apache.parquet.hadoop.metadata.{FileMetaData, ParquetMetadata} -import org.apache.parquet.hadoop.util.ContextUtil -import org.apache.parquet.hadoop.{Footer, ParquetFileReader, ParquetFileWriter} -import org.apache.parquet.schema.MessageType - -import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.types._ - - -private[parquet] object ParquetTypesConverter extends Logging { - def isPrimitiveType(ctype: DataType): Boolean = ctype match { - case _: NumericType | BooleanType | DateType | TimestampType | StringType | BinaryType => true - case _ => false - } - - /** - * Compute the FIXED_LEN_BYTE_ARRAY length needed to represent a given DECIMAL precision. - */ - private[parquet] val BYTES_FOR_PRECISION = Array.tabulate[Int](38) { precision => - var length = 1 - while (math.pow(2.0, 8 * length - 1) < math.pow(10.0, precision)) { - length += 1 - } - length - } - - def convertFromAttributes(attributes: Seq[Attribute]): MessageType = { - val converter = new CatalystSchemaConverter() - converter.convert(StructType.fromAttributes(attributes)) - } - - def convertFromString(string: String): Seq[Attribute] = { - Try(DataType.fromJson(string)).getOrElse(DataType.fromCaseClassString(string)) match { - case s: StructType => s.toAttributes - case other => sys.error(s"Can convert $string to row") - } - } - - def convertToString(schema: Seq[Attribute]): String = { - schema.map(_.name).foreach(CatalystSchemaConverter.checkFieldName) - StructType.fromAttributes(schema).json - } - - def writeMetaData(attributes: Seq[Attribute], origPath: Path, conf: Configuration): Unit = { - if (origPath == null) { - throw new IllegalArgumentException("Unable to write Parquet metadata: path is null") - } - val fs = origPath.getFileSystem(conf) - if (fs == null) { - throw new IllegalArgumentException( - s"Unable to write Parquet metadata: path $origPath is incorrectly formatted") - } - val path = origPath.makeQualified(fs) - if (fs.exists(path) && !fs.getFileStatus(path).isDir) { - throw new IllegalArgumentException(s"Expected to write to directory $path but found file") - } - val metadataPath = new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE) - if (fs.exists(metadataPath)) { - try { - fs.delete(metadataPath, true) - } catch { - case e: IOException => - throw new IOException(s"Unable to delete previous PARQUET_METADATA_FILE at $metadataPath") - } - } - val extraMetadata = new java.util.HashMap[String, String]() - extraMetadata.put( - CatalystReadSupport.SPARK_METADATA_KEY, - ParquetTypesConverter.convertToString(attributes)) - // TODO: add extra data, e.g., table name, date, etc.? 
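The BYTES_FOR_PRECISION table above is plain arithmetic: the smallest byte length n with 2^(8n-1) >= 10^precision. A standalone sketch with a few representative values:

// Smallest n such that a signed n-byte integer can hold any unscaled
// decimal of the given precision (2^(8n - 1) >= 10^precision).
def minBytesForPrecision(precision: Int): Int = {
  var length = 1
  while (math.pow(2.0, 8 * length - 1) < math.pow(10.0, precision)) length += 1
  length
}

minBytesForPrecision(9)    // 4 bytes
minBytesForPrecision(18)   // 8 bytes
minBytesForPrecision(38)   // 16 bytes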
- - val parquetSchema: MessageType = ParquetTypesConverter.convertFromAttributes(attributes) - val metaData: FileMetaData = new FileMetaData( - parquetSchema, - extraMetadata, - "Spark") - - ParquetRelation.enableLogForwarding() - ParquetFileWriter.writeMetadataFile( - conf, - path, - new Footer(path, new ParquetMetadata(metaData, Nil)) :: Nil) - } - - /** - * Try to read Parquet metadata at the given Path. We first see if there is a summary file - * in the parent directory. If so, this is used. Else we read the actual footer at the given - * location. - * @param origPath The path at which we expect one (or more) Parquet files. - * @param configuration The Hadoop configuration to use. - * @return The `ParquetMetadata` containing among other things the schema. - */ - def readMetaData(origPath: Path, configuration: Option[Configuration]): ParquetMetadata = { - if (origPath == null) { - throw new IllegalArgumentException("Unable to read Parquet metadata: path is null") - } - val job = new Job() - val conf = configuration.getOrElse(ContextUtil.getConfiguration(job)) - val fs: FileSystem = origPath.getFileSystem(conf) - if (fs == null) { - throw new IllegalArgumentException(s"Incorrectly formatted Parquet metadata path $origPath") - } - val path = origPath.makeQualified(fs) - - val children = - fs - .globStatus(path) - .flatMap { status => if (status.isDir) fs.listStatus(status.getPath) else List(status) } - .filterNot { status => - val name = status.getPath.getName - (name(0) == '.' || name(0) == '_') && name != ParquetFileWriter.PARQUET_METADATA_FILE - } - - ParquetRelation.enableLogForwarding() - - // NOTE (lian): Parquet "_metadata" file can be very slow if the file consists of lots of row - // groups. Since Parquet schema is replicated among all row groups, we only need to touch a - // single row group to read schema related metadata. Notice that we are making assumptions that - // all data in a single Parquet file have the same schema, which is normally true. - children - // Try any non-"_metadata" file first... - .find(_.getPath.getName != ParquetFileWriter.PARQUET_METADATA_FILE) - // ... and fallback to "_metadata" if no such file exists (which implies the Parquet file is - // empty, thus normally the "_metadata" file is expected to be fairly small). - .orElse(children.find(_.getPath.getName == ParquetFileWriter.PARQUET_METADATA_FILE)) - .map(ParquetFileReader.readFooter(conf, _, ParquetMetadataConverter.NO_FILTER)) - .getOrElse( - throw new IllegalArgumentException(s"Could not find Parquet metadata at path $path")) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala index 4d942e4f9287..3780cbbcc963 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -36,6 +36,15 @@ abstract class Filter */ case class EqualTo(attribute: String, value: Any) extends Filter +/** + * Performs equality comparison, similar to [[EqualTo]]. However, this differs from [[EqualTo]] + * in that it returns `true` (rather than NULL) if both inputs are NULL, and `false` + * (rather than NULL) if one of the input is NULL and the other is not NULL. + * + * @since 1.5.0 + */ +case class EqualNullSafe(attribute: String, value: Any) extends Filter + /** * A filter that evaluates to `true` iff the attribute evaluates to a value * greater than `value`. 
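The null-safe semantics of EqualNullSafe can be summarised with a small truth-table sketch (the helper below is hypothetical, not part of the API); on the DataFrame side this corresponds to the Column `<=>` operator.

// EqualNullSafe versus EqualTo:
//   both sides NULL        -> true   (EqualTo yields NULL)
//   exactly one side NULL  -> false  (EqualTo yields NULL)
//   neither side NULL      -> ordinary equality
def nullSafeEq(a: Any, b: Any): Boolean = (a, b) match {
  case (null, null) => true
  case (null, _) | (_, null) => false
  case (x, y) => x == y
}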
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 7126145ddc01..fc8ce6901dfc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -21,7 +21,8 @@ import scala.collection.mutable import scala.util.Try import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.fs.{PathFilter, FileStatus, FileSystem, Path} +import org.apache.hadoop.mapred.{JobConf, FileInputFormat} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.{Logging, SparkContext} @@ -31,12 +32,38 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection -import org.apache.spark.sql.execution.RDDConversions +import org.apache.spark.sql.execution.{FileRelation, RDDConversions} import org.apache.spark.sql.execution.datasources.{PartitioningUtils, PartitionSpec, Partition} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.sql._ import org.apache.spark.util.SerializableConfiguration +/** + * ::DeveloperApi:: + * Data sources should implement this trait so that they can register an alias to their data source. + * This allows users to give the data source alias as the format type over the fully qualified + * class name. + * + * A new instance of this class will be instantiated each time a DDL call is made. + * + * @since 1.5.0 + */ +@DeveloperApi +trait DataSourceRegister { + + /** + * The string that represents the format that this data source provider uses. This is + * overridden by children to provide a nice alias for the data source. For example: + * + * {{{ + * override def format(): String = "parquet" + * }}} + * + * @since 1.5.0 + */ + def shortName(): String +} + /** * ::DeveloperApi:: * Implemented by objects that produce relations for a specific kind of data source. When @@ -48,7 +75,7 @@ import org.apache.spark.util.SerializableConfiguration * less verbose invocation. For example, 'org.apache.spark.sql.json' would resolve to the * data source 'org.apache.spark.sql.json.DefaultSource' * - * A new instance of this class with be instantiated each time a DDL call is made. + * A new instance of this class will be instantiated each time a DDL call is made. * * @since 1.3.0 */ @@ -74,7 +101,7 @@ trait RelationProvider { * less verbose invocation. For example, 'org.apache.spark.sql.json' would resolve to the * data source 'org.apache.spark.sql.json.DefaultSource' * - * A new instance of this class with be instantiated each time a DDL call is made. + * A new instance of this class will be instantiated each time a DDL call is made. * * The difference between a [[RelationProvider]] and a [[SchemaRelationProvider]] is that * users need to provide a schema when using a [[SchemaRelationProvider]]. @@ -109,7 +136,7 @@ trait SchemaRelationProvider { * less verbose invocation. For example, 'org.apache.spark.sql.json' would resolve to the * data source 'org.apache.spark.sql.json.DefaultSource' * - * A new instance of this class with be instantiated each time a DDL call is made. + * A new instance of this class will be instantiated each time a DDL call is made. 
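A sketch of how a data source might opt in to DataSourceRegister (class name and alias are hypothetical; note the alias comes from `shortName()`). Providers are typically made discoverable through a `META-INF/services/org.apache.spark.sql.sources.DataSourceRegister` entry, after which users can write `.format("myformat")` instead of the fully qualified class name.

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider}

class DefaultSource extends RelationProvider with DataSourceRegister {

  // Short alias usable as sqlContext.read.format("myformat")
  override def shortName(): String = "myformat"

  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = {
    // A real provider would build its BaseRelation from `parameters` here.
    throw new UnsupportedOperationException("sketch only")
  }
}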
* * The difference between a [[RelationProvider]] and a [[HadoopFsRelationProvider]] is * that users need to provide a schema and a (possibly empty) list of partition columns when @@ -169,7 +196,7 @@ trait CreatableRelationProvider { * implementation should inherit from one of the descendant `Scan` classes, which define various * abstract methods for execution. * - * BaseRelations must also define a equality function that only returns true when the two + * BaseRelations must also define an equality function that only returns true when the two * instances will return the same data. This equality function is used when determining when * it is safe to substitute cached results for a given relation. * @@ -182,7 +209,7 @@ abstract class BaseRelation { /** * Returns an estimated size of this relation in bytes. This information is used by the planner - * to decided when it is safe to broadcast a relation and can be overridden by sources that + * to decide when it is safe to broadcast a relation and can be overridden by sources that * know the size ahead of time. By default, the system will assume that tables are too * large to broadcast. This method will be called multiple times during query planning * and thus should not perform expensive operations for each invocation. @@ -207,6 +234,17 @@ abstract class BaseRelation { * @since 1.4.0 */ def needConversion: Boolean = true + + /** + * Returns the list of [[Filter]]s that this datasource may not be able to handle. + * These returned [[Filter]]s will be evaluated by Spark SQL after data is output by a scan. + * By default, this function will return all filters, as it is always safe to + * double evaluate a [[Filter]]. However, specific implementations can override this function to + * avoid double filtering when they are capable of processing a filter internally. + * + * @since 1.6.0 + */ + def unhandledFilters(filters: Array[Filter]): Array[Filter] = filters } /** @@ -342,23 +380,22 @@ abstract class OutputWriter { * @since 1.4.0 */ def close(): Unit -} -/** - * This is an internal, private version of [[OutputWriter]] with an writeInternal method that - * accepts an [[InternalRow]] rather than an [[Row]]. Data sources that return this must have - * the conversion flag set to false. - */ -private[sql] abstract class OutputWriterInternal extends OutputWriter { + private var converter: InternalRow => Row = _ - override def write(row: Row): Unit = throw new UnsupportedOperationException + protected[sql] def initConverter(dataSchema: StructType) = { + converter = + CatalystTypeConverters.createToScalaConverter(dataSchema).asInstanceOf[InternalRow => Row] + } - def writeInternal(row: InternalRow): Unit + protected[sql] def writeInternal(row: InternalRow): Unit = { + write(converter(row)) + } } /** * ::Experimental:: - * A [[BaseRelation]] that provides much of the common code required for formats that store their + * A [[BaseRelation]] that provides much of the common code required for relations that store their * data to an HDFS compatible filesystem. 
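A minimal sketch of the unhandledFilters hook introduced above, for a hypothetical relation whose own scan reliably applies EqualTo filters and therefore only needs Spark SQL to re-evaluate the rest:

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, EqualTo, Filter}
import org.apache.spark.sql.types.StructType

class EqualToAwareRelation(
    val sqlContext: SQLContext,
    val schema: StructType) extends BaseRelation {

  // EqualTo is handled inside the scan, so only the remaining filters are
  // reported back for Spark SQL to evaluate again.
  override def unhandledFilters(filters: Array[Filter]): Array[Filter] =
    filters.filterNot(_.isInstanceOf[EqualTo])
}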
* * For the read path, similar to [[PrunedFilteredScan]], it can eliminate unneeded columns and @@ -380,25 +417,30 @@ private[sql] abstract class OutputWriterInternal extends OutputWriter { * @since 1.4.0 */ @Experimental -abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[PartitionSpec]) - extends BaseRelation with Logging { +abstract class HadoopFsRelation private[sql]( + maybePartitionSpec: Option[PartitionSpec], + parameters: Map[String, String]) + extends BaseRelation with FileRelation with Logging { - logInfo("Constructing HadoopFsRelation") + override def toString: String = getClass.getSimpleName - def this() = this(None) + def this() = this(None, Map.empty[String, String]) - private val hadoopConf = new Configuration(sqlContext.sparkContext.hadoopConfiguration) + def this(parameters: Map[String, String]) = this(None, parameters) - private val codegenEnabled = sqlContext.conf.codegenEnabled + private[sql] def this(maybePartitionSpec: Option[PartitionSpec]) = + this(maybePartitionSpec, Map.empty[String, String]) + + private val hadoopConf = new Configuration(sqlContext.sparkContext.hadoopConfiguration) private var _partitionSpec: PartitionSpec = _ private class FileStatusCache { - var leafFiles = mutable.Map.empty[Path, FileStatus] + var leafFiles = mutable.LinkedHashMap.empty[Path, FileStatus] var leafDirToChildrenFiles = mutable.Map.empty[Path, Array[FileStatus]] - private def listLeafFiles(paths: Array[String]): Set[FileStatus] = { + private def listLeafFiles(paths: Array[String]): mutable.LinkedHashSet[FileStatus] = { if (paths.length >= sqlContext.conf.parallelPartitionDiscoveryThreshold) { HadoopFsRelation.listLeafFilesInParallel(paths, hadoopConf, sqlContext.sparkContext) } else { @@ -406,9 +448,15 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio val hdfsPath = new Path(path) val fs = hdfsPath.getFileSystem(hadoopConf) val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - logInfo(s"Listing $qualified on driver") - Try(fs.listStatus(qualified)).getOrElse(Array.empty) + // Dummy jobconf to get to the pathFilter defined in configuration + val jobConf = new JobConf(hadoopConf, this.getClass()) + val pathFilter = FileInputFormat.getInputPathFilter(jobConf) + if (pathFilter != null) { + Try(fs.listStatus(qualified, pathFilter)).getOrElse(Array.empty) + } else { + Try(fs.listStatus(qualified)).getOrElse(Array.empty) + } }.filterNot { status => val name = status.getPath.getName name.toLowerCase == "_temporary" || name.startsWith(".") @@ -416,10 +464,11 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio val (dirs, files) = statuses.partition(_.isDir) + // It uses [[LinkedHashSet]] since the order of files can affect the results. 
(SPARK-11500) if (dirs.isEmpty) { - files.toSet + mutable.LinkedHashSet(files: _*) } else { - files.toSet ++ listLeafFiles(dirs.map(_.getPath.toString)) + mutable.LinkedHashSet(files: _*) ++ listLeafFiles(dirs.map(_.getPath.toString)) } } } @@ -430,7 +479,7 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio leafFiles.clear() leafDirToChildrenFiles.clear() - leafFiles ++= files.map(f => f.getPath -> f).toMap + leafFiles ++= files.map(f => f.getPath -> f) leafDirToChildrenFiles ++= files.toArray.groupBy(_.getPath.getParent) } } @@ -441,8 +490,8 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio cache } - protected def cachedLeafStatuses(): Set[FileStatus] = { - fileStatusCache.leafFiles.values.toSet + protected def cachedLeafStatuses(): mutable.LinkedHashSet[FileStatus] = { + mutable.LinkedHashSet(fileStatusCache.leafFiles.values.toArray: _*) } final private[sql] def partitionSpec: PartitionSpec = { @@ -461,8 +510,8 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio val spec = discoverPartitions() val partitionColumnTypes = spec.partitionColumns.map(_.dataType) val castedPartitions = spec.partitions.map { case p @ Partition(values, path) => - val literals = values.toSeq.zip(partitionColumnTypes).map { - case (value, dataType) => Literal.create(value, dataType) + val literals = partitionColumnTypes.zipWithIndex.map { case (dt, i) => + Literal.create(values.get(i, dt), dt) } val castedValues = partitionSchema.zip(literals).map { case (field, literal) => Cast(literal, field.dataType).eval() @@ -484,13 +533,41 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio } /** - * Base paths of this relation. For partitioned relations, it should be either root directories + * Paths of this relation. For partitioned relations, it should be root directories * of all partition directories. * * @since 1.4.0 */ def paths: Array[String] + /** + * Contains a set of paths that are considered as the base dirs of the input datasets. + * The partitioning discovery logic will make sure it will stop when it reaches any + * base path. By default, the paths of the dataset provided by users will be base paths. + * For example, if a user uses `sqlContext.read.parquet("/path/something=true/")`, the base path + * will be `/path/something=true/`, and the returned DataFrame will not contain a column of + * `something`. If users want to override the basePath. They can set `basePath` in the options + * to pass the new base path to the data source. + * For the above example, if the user-provided base path is `/path/`, the returned + * DataFrame will have the column of `something`. + */ + private def basePaths: Set[Path] = { + val userDefinedBasePath = parameters.get("basePath").map(basePath => Set(new Path(basePath))) + userDefinedBasePath.getOrElse { + // If the user does not provide basePath, we will just use paths. + val pathSet = paths.toSet + pathSet.map(p => new Path(p)) + }.map { hdfsPath => + // Make the path qualified (consistent with listLeafFiles and listLeafFilesInParallel). + val fs = hdfsPath.getFileSystem(hadoopConf) + hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + } + } + + override def inputFiles: Array[String] = cachedLeafStatuses().map(_.getPath.toString).toArray + + override def sizeInBytes: Long = cachedLeafStatuses().map(_.getLen).sum + /** * Partition columns. Can be either defined by [[userDefinedPartitionColumns]] or automatically * discovered. 
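To make the `basePath` behaviour described above concrete, here is a small usage sketch; the directory layout is made up, and only the `basePath` option itself comes from this patch:

```scala
// Reading the partition directory directly: discovery stops at the given path,
// so the result has no `something` column.
val narrow = sqlContext.read.parquet("/data/events/something=true")

// Passing the table root as basePath: discovery walks from /data/events,
// so the result gains a `something` partition column.
val wide = sqlContext.read
  .option("basePath", "/data/events")
  .parquet("/data/events/something=true")
```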
Note that they should always be nullable. @@ -515,11 +592,38 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio } private def discoverPartitions(): PartitionSpec = { - val typeInference = sqlContext.conf.partitionColumnTypeInferenceEnabled() // We use leaf dirs containing data files to discover the schema. val leafDirs = fileStatusCache.leafDirToChildrenFiles.keys.toSeq - PartitioningUtils.parsePartitions(leafDirs, PartitioningUtils.DEFAULT_PARTITION_NAME, - typeInference) + userDefinedPartitionColumns match { + case Some(userProvidedSchema) if userProvidedSchema.nonEmpty => + val spec = PartitioningUtils.parsePartitions( + leafDirs, + PartitioningUtils.DEFAULT_PARTITION_NAME, + typeInference = false, + basePaths = basePaths) + + // Without auto inference, all values in the `row` should be null or of StringType, + // so we need to cast them into the data type that the user specified. + def castPartitionValuesToUserSchema(row: InternalRow) = { + InternalRow((0 until row.numFields).map { i => + Cast( + Literal.create(row.getUTF8String(i), StringType), + userProvidedSchema.fields(i).dataType).eval() + }: _*) + } + + PartitionSpec(userProvidedSchema, spec.partitions.map { part => + part.copy(values = castPartitionValuesToUserSchema(part.values)) + }) + + case _ => + // user did not provide a partitioning schema + PartitioningUtils.parsePartitions( + leafDirs, + PartitioningUtils.DEFAULT_PARTITION_NAME, + typeInference = sqlContext.conf.partitionColumnTypeInferenceEnabled(), + basePaths = basePaths) + } } /** @@ -535,11 +639,11 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio }) } - private[sql] final def buildScan( + final private[sql] def buildInternalScan( requiredColumns: Array[String], filters: Array[Filter], inputPaths: Array[String], - broadcastedConf: Broadcast[SerializableConfiguration]): RDD[Row] = { + broadcastedConf: Broadcast[SerializableConfiguration]): RDD[InternalRow] = { val inputStatuses = inputPaths.flatMap { input => val path = new Path(input) @@ -554,7 +658,7 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio } } - buildScan(requiredColumns, filters, inputStatuses, broadcastedConf) + buildInternalScan(requiredColumns, filters, inputStatuses, broadcastedConf) } /** @@ -601,7 +705,6 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio def buildScan(requiredColumns: Array[String], inputFiles: Array[FileStatus]): RDD[Row] = { // Yeah, to workaround serialization... val dataSchema = this.dataSchema - val codegenEnabled = this.codegenEnabled val needConversion = this.needConversion val requiredOutput = requiredColumns.map { col => @@ -618,11 +721,8 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio } converted.mapPartitions { rows => - val buildProjection = if (codegenEnabled) { + val buildProjection = GenerateMutableProjection.generate(requiredOutput, dataSchema.toAttributes) - } else { - () => new InterpretedMutableProjection(requiredOutput, dataSchema.toAttributes) - } val projectedRows = { val mutableProjection = buildProjection() @@ -690,6 +790,44 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio buildScan(requiredColumns, filters, inputFiles) } + /** + * For a non-partitioned relation, this method builds an `RDD[InternalRow]` containing all rows + * within this relation. 
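The rewritten `discoverPartitions` above only infers partition value types when no user schema is given, and automatic inference is gated by a SQL conf. A small sketch of the user-visible effect (the warehouse path is hypothetical; the key is assumed to be the one backing `partitionColumnTypeInferenceEnabled()`, i.e. `spark.sql.sources.partitionColumnTypeInference.enabled`):

```scala
// With inference enabled (default), /warehouse/t/year=2015/... yields an integer `year` column.
val inferred = sqlContext.read.parquet("/warehouse/t")

// With inference disabled, the same layout keeps `year` as a string partition column.
sqlContext.setConf("spark.sql.sources.partitionColumnTypeInference.enabled", "false")
val asStrings = sqlContext.read.parquet("/warehouse/t")
```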
For partitioned relations, this method is called for each selected + * partition, and builds an `RDD[InternalRow]` containing all rows within that single partition. + * + * Note: + * + * 1. Rows contained in the returned `RDD[InternalRow]` are assumed to be `UnsafeRow`s. + * 2. This interface is subject to change in future. + * + * @param requiredColumns Required columns. + * @param filters Candidate filters to be pushed down. The actual filter should be the conjunction + * of all `filters`. The pushed down filters are currently purely an optimization as they + * will all be evaluated again. This means it is safe to use them with methods that produce + * false positives such as filtering partitions based on a bloom filter. + * @param inputFiles For a non-partitioned relation, it contains paths of all data files in the + * relation. For a partitioned relation, it contains paths of all data files in a single + * selected partition. + * @param broadcastedConf A shared broadcast Hadoop Configuration, which can be used to reduce the + * overhead of broadcasting the Configuration for every Hadoop RDD. + */ + private[sql] def buildInternalScan( + requiredColumns: Array[String], + filters: Array[Filter], + inputFiles: Array[FileStatus], + broadcastedConf: Broadcast[SerializableConfiguration]): RDD[InternalRow] = { + val requiredSchema = StructType(requiredColumns.map(dataSchema.apply)) + val internalRows = { + val externalRows = buildScan(requiredColumns, filters, inputFiles, broadcastedConf) + execution.RDDConversions.rowToRowRdd(externalRows, requiredSchema.map(_.dataType)) + } + + internalRows.mapPartitions { iterator => + val unsafeProjection = UnsafeProjection.create(requiredSchema) + iterator.map(unsafeProjection) + } + } + /** * Prepares a write job and returns an [[OutputWriterFactory]]. Client side job preparation can * be put here. 
For example, user defined output committer can be configured here @@ -716,8 +854,16 @@ private[sql] object HadoopFsRelation extends Logging { if (name == "_temporary" || name.startsWith(".")) { Array.empty } else { - val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) - files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) + // Dummy jobconf to get to the pathFilter defined in configuration + val jobConf = new JobConf(fs.getConf, this.getClass()) + val pathFilter = FileInputFormat.getInputPathFilter(jobConf) + if (pathFilter != null) { + val (dirs, files) = fs.listStatus(status.getPath, pathFilter).partition(_.isDir) + files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) + } else { + val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) + files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) + } } } @@ -737,7 +883,7 @@ private[sql] object HadoopFsRelation extends Logging { def listLeafFilesInParallel( paths: Array[String], hadoopConf: Configuration, - sparkContext: SparkContext): Set[FileStatus] = { + sparkContext: SparkContext): mutable.LinkedHashSet[FileStatus] = { logInfo(s"Listing leaf files and directories in parallel under: ${paths.mkString(", ")}") val serializableConfiguration = new SerializableConfiguration(hadoopConf) @@ -757,9 +903,10 @@ private[sql] object HadoopFsRelation extends Logging { status.getAccessTime) }.collect() - fakeStatuses.map { f => + val hadoopFakeStatuses = fakeStatuses.map { f => new FileStatus( f.length, f.isDir, f.blockReplication, f.blockSize, f.modificationTime, new Path(f.path)) - }.toSet + } + mutable.LinkedHashSet(hadoopFakeStatuses: _*) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala index 2fdd798b44bb..8d4854b698ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql.test -import java.util - -import scala.collection.JavaConverters._ +import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData} import org.apache.spark.sql.types._ /** @@ -39,22 +37,20 @@ private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] { override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT" - override def serialize(obj: Any): Seq[Double] = { + override def serialize(obj: Any): GenericArrayData = { obj match { case p: ExamplePoint => - Seq(p.x, p.y) + val output = new Array[Any](2) + output(0) = p.x + output(1) = p.y + new GenericArrayData(output) } } override def deserialize(datum: Any): ExamplePoint = { datum match { - case values: Seq[_] => - val xy = values.asInstanceOf[Seq[Double]] - assert(xy.length == 2) - new ExamplePoint(xy(0), xy(1)) - case values: util.ArrayList[_] => - val xy = values.asInstanceOf[util.ArrayList[Double]].asScala - new ExamplePoint(xy(0), xy(1)) + case values: ArrayData => + new ExamplePoint(values.getDouble(0), values.getDouble(1)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala deleted file mode 100644 index b3a4231da91c..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.test - -import scala.language.implicitConversions - -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan - -/** A SQLContext that can be used for local testing. */ -class LocalSQLContext - extends SQLContext( - new SparkContext("local[2]", "TestSQLContext", new SparkConf() - .set("spark.sql.testkey", "true") - // SPARK-8910 - .set("spark.ui.enabled", "false"))) { - - override protected[sql] def createSession(): SQLSession = { - new this.SQLSession() - } - - protected[sql] class SQLSession extends super.SQLSession { - protected[sql] override lazy val conf: SQLConf = new SQLConf { - /** Fewer partitions to speed up testing. */ - override def numShufflePartitions: Int = this.getConf(SQLConf.SHUFFLE_PARTITIONS, 5) - } - } - - /** - * Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier way to - * construct [[DataFrame]] directly out of local data without relying on implicits. - */ - protected[sql] implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { - DataFrame(this, plan) - } - -} - -object TestSQLContext extends LocalSQLContext - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala new file mode 100644 index 000000000000..ac432e2baa3c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.util + +import java.util.concurrent.locks.ReentrantReadWriteLock +import scala.collection.mutable.ListBuffer +import scala.util.control.NonFatal + +import org.apache.spark.Logging +import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.sql.execution.QueryExecution + + +/** + * :: Experimental :: + * The interface of query execution listener that can be used to analyze execution metrics. + * + * Note that implementations should guarantee thread-safety as they can be invoked by + * multiple different threads. + */ +@Experimental +trait QueryExecutionListener { + + /** + * A callback function that will be called when a query executed successfully. + * Note that this can be invoked by multiple different threads. + * + * @param funcName name of the action that triggered this query. + * @param qe the QueryExecution object that carries detail information like logical plan, + * physical plan, etc. + * @param durationNs the execution time for this query in nanoseconds. + */ + @DeveloperApi + def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit + + /** + * A callback function that will be called when a query execution failed. + * Note that this can be invoked by multiple different threads. + * + * @param funcName the name of the action that triggered this query. + * @param qe the QueryExecution object that carries detail information like logical plan, + * physical plan, etc. + * @param exception the exception that failed this query. + */ + @DeveloperApi + def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit +} + + +/** + * :: Experimental :: + * + * Manager for [[QueryExecutionListener]]. See [[org.apache.spark.sql.SQLContext.listenerManager]]. + */ +@Experimental +class ExecutionListenerManager private[sql] () extends Logging { + + /** + * Registers the specified [[QueryExecutionListener]]. + */ + @DeveloperApi + def register(listener: QueryExecutionListener): Unit = writeLock { + listeners += listener + } + + /** + * Unregisters the specified [[QueryExecutionListener]]. + */ + @DeveloperApi + def unregister(listener: QueryExecutionListener): Unit = writeLock { + listeners -= listener + } + + /** + * Removes all the registered [[QueryExecutionListener]]. + */ + @DeveloperApi + def clear(): Unit = writeLock { + listeners.clear() + } + + private[sql] def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { + readLock { + withErrorHandling { listener => + listener.onSuccess(funcName, qe, duration) + } + } + } + + private[sql] def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { + readLock { + withErrorHandling { listener => + listener.onFailure(funcName, qe, exception) + } + } + } + + private[this] val listeners = ListBuffer.empty[QueryExecutionListener] + + /** A lock to prevent updating the list of listeners while we are traversing through them. */ + private[this] val lock = new ReentrantReadWriteLock() + + private def withErrorHandling(f: QueryExecutionListener => Unit): Unit = { + for (listener <- listeners) { + try { + f(listener) + } catch { + case NonFatal(e) => logWarning("Error executing query execution listener", e) + } + } + } + + /** Acquires a read lock on the cache for the duration of `f`. */ + private def readLock[A](f: => A): A = { + val rl = lock.readLock() + rl.lock() + try f finally { + rl.unlock() + } + } + + /** Acquires a write lock on the cache for the duration of `f`. 
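Since `ExecutionListenerManager` is exposed through `SQLContext.listenerManager`, a listener can be attached as sketched below (minimal example; the timing and logging logic is illustrative only and assumes an existing `sqlContext`):

```scala
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

// Logs how long each successful action took, and logs failures with the action name.
val timingListener = new QueryExecutionListener {
  override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
    println(s"$funcName finished in ${durationNs / 1e6} ms")
  }
  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
    println(s"$funcName failed: ${exception.getMessage}")
  }
}

sqlContext.listenerManager.register(timingListener)
// ... run queries; the callbacks fire on actions such as collect() or count() ...
sqlContext.listenerManager.unregister(timingListener)
```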
*/ + private def writeLock[A](f: => A): A = { + val wl = lock.writeLock() + wl.lock() + try f finally { + wl.unlock() + } + } +} diff --git a/sql/core/src/test/README.md b/sql/core/src/test/README.md index 3dd9861b4896..421c2ea4f7ae 100644 --- a/sql/core/src/test/README.md +++ b/sql/core/src/test/README.md @@ -6,23 +6,19 @@ The following directories and files are used for Parquet compatibility tests: . ├── README.md # This file ├── avro -│   ├── parquet-compat.avdl # Testing Avro IDL -│   └── parquet-compat.avpr # !! NO TOUCH !! Protocol file generated from parquet-compat.avdl +│   ├── *.avdl # Testing Avro IDL(s) +│   └── *.avpr # !! NO TOUCH !! Protocol files generated from Avro IDL(s) ├── gen-java # !! NO TOUCH !! Generated Java code ├── scripts -│   └── gen-code.sh # Script used to generate Java code for Thrift and Avro +│   ├── gen-avro.sh # Script used to generate Java code for Avro +│   └── gen-thrift.sh # Script used to generate Java code for Thrift └── thrift - └── parquet-compat.thrift # Testing Thrift schema + └── *.thrift # Testing Thrift schema(s) ``` -Generated Java code are used in the following test suites: - -- `org.apache.spark.sql.parquet.ParquetAvroCompatibilitySuite` -- `org.apache.spark.sql.parquet.ParquetThriftCompatibilitySuite` - To avoid code generation during build time, Java code generated from testing Thrift schema and Avro IDL are also checked in. -When updating the testing Thrift schema and Avro IDL, please run `gen-code.sh` to update all the generated Java code. +When updating the testing Thrift schema and Avro IDL, please run `gen-avro.sh` and `gen-thrift.sh` accordingly to update generated Java code. ## Prerequisites diff --git a/sql/core/src/test/avro/parquet-compat.avdl b/sql/core/src/test/avro/parquet-compat.avdl index 24729f6143e6..c5eb5b5164cf 100644 --- a/sql/core/src/test/avro/parquet-compat.avdl +++ b/sql/core/src/test/avro/parquet-compat.avdl @@ -16,14 +16,25 @@ */ // This is a test protocol for testing parquet-avro compatibility. 
-@namespace("org.apache.spark.sql.parquet.test.avro") +@namespace("org.apache.spark.sql.execution.datasources.parquet.test.avro") protocol CompatibilityTest { + enum Suit { + SPADES, + HEARTS, + DIAMONDS, + CLUBS + } + + record ParquetEnum { + Suit suit; + } + record Nested { array nested_ints_column; string nested_string_column; } - record ParquetAvroCompat { + record AvroPrimitives { boolean bool_column; int int_column; long long_column; @@ -31,7 +42,9 @@ protocol CompatibilityTest { double double_column; bytes binary_column; string string_column; + } + record AvroOptionalPrimitives { union { null, boolean } maybe_bool_column; union { null, int } maybe_int_column; union { null, long } maybe_long_column; @@ -39,7 +52,22 @@ protocol CompatibilityTest { union { null, double } maybe_double_column; union { null, bytes } maybe_binary_column; union { null, string } maybe_string_column; + } + record AvroNonNullableArrays { + array strings_column; + union { null, array } maybe_ints_column; + } + + record AvroArrayOfArray { + array> int_arrays_column; + } + + record AvroMapOfArray { + map> string_to_ints_column; + } + + record ParquetAvroCompat { array strings_column; map string_to_int_column; map> complex_column; diff --git a/sql/core/src/test/avro/parquet-compat.avpr b/sql/core/src/test/avro/parquet-compat.avpr index a83b7c990dd2..9ad315b74fb4 100644 --- a/sql/core/src/test/avro/parquet-compat.avpr +++ b/sql/core/src/test/avro/parquet-compat.avpr @@ -1,7 +1,18 @@ { "protocol" : "CompatibilityTest", - "namespace" : "org.apache.spark.sql.parquet.test.avro", + "namespace" : "org.apache.spark.sql.execution.datasources.parquet.test.avro", "types" : [ { + "type" : "enum", + "name" : "Suit", + "symbols" : [ "SPADES", "HEARTS", "DIAMONDS", "CLUBS" ] + }, { + "type" : "record", + "name" : "ParquetEnum", + "fields" : [ { + "name" : "suit", + "type" : "Suit" + } ] + }, { "type" : "record", "name" : "Nested", "fields" : [ { @@ -16,7 +27,7 @@ } ] }, { "type" : "record", - "name" : "ParquetAvroCompat", + "name" : "AvroPrimitives", "fields" : [ { "name" : "bool_column", "type" : "boolean" @@ -38,7 +49,11 @@ }, { "name" : "string_column", "type" : "string" - }, { + } ] + }, { + "type" : "record", + "name" : "AvroOptionalPrimitives", + "fields" : [ { "name" : "maybe_bool_column", "type" : [ "null", "boolean" ] }, { @@ -59,7 +74,53 @@ }, { "name" : "maybe_string_column", "type" : [ "null", "string" ] + } ] + }, { + "type" : "record", + "name" : "AvroNonNullableArrays", + "fields" : [ { + "name" : "strings_column", + "type" : { + "type" : "array", + "items" : "string" + } }, { + "name" : "maybe_ints_column", + "type" : [ "null", { + "type" : "array", + "items" : "int" + } ] + } ] + }, { + "type" : "record", + "name" : "AvroArrayOfArray", + "fields" : [ { + "name" : "int_arrays_column", + "type" : { + "type" : "array", + "items" : { + "type" : "array", + "items" : "int" + } + } + } ] + }, { + "type" : "record", + "name" : "AvroMapOfArray", + "fields" : [ { + "name" : "string_to_ints_column", + "type" : { + "type" : "map", + "values" : { + "type" : "array", + "items" : "int" + } + } + } ] + }, { + "type" : "record", + "name" : "ParquetAvroCompat", + "fields" : [ { "name" : "strings_column", "type" : { "type" : "array", diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroArrayOfArray.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroArrayOfArray.java new file mode 100644 index 000000000000..ee327827903e --- /dev/null 
+++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroArrayOfArray.java @@ -0,0 +1,142 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class AvroArrayOfArray extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"AvroArrayOfArray\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"int_arrays_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"array\",\"items\":\"int\"}}}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public java.util.List> int_arrays_column; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public AvroArrayOfArray() {} + + /** + * All-args constructor. + */ + public AvroArrayOfArray(java.util.List> int_arrays_column) { + this.int_arrays_column = int_arrays_column; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return int_arrays_column; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: int_arrays_column = (java.util.List>)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'int_arrays_column' field. + */ + public java.util.List> getIntArraysColumn() { + return int_arrays_column; + } + + /** + * Sets the value of the 'int_arrays_column' field. + * @param value the value to set. + */ + public void setIntArraysColumn(java.util.List> value) { + this.int_arrays_column = value; + } + + /** Creates a new AvroArrayOfArray RecordBuilder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder(); + } + + /** Creates a new AvroArrayOfArray RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder(other); + } + + /** Creates a new AvroArrayOfArray RecordBuilder by copying an existing AvroArrayOfArray instance */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder(other); + } + + /** + * RecordBuilder for AvroArrayOfArray instances. 
+ */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private java.util.List> int_arrays_column; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder other) { + super(other); + if (isValidValue(fields()[0], other.int_arrays_column)) { + this.int_arrays_column = data().deepCopy(fields()[0].schema(), other.int_arrays_column); + fieldSetFlags()[0] = true; + } + } + + /** Creates a Builder by copying an existing AvroArrayOfArray instance */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.SCHEMA$); + if (isValidValue(fields()[0], other.int_arrays_column)) { + this.int_arrays_column = data().deepCopy(fields()[0].schema(), other.int_arrays_column); + fieldSetFlags()[0] = true; + } + } + + /** Gets the value of the 'int_arrays_column' field */ + public java.util.List> getIntArraysColumn() { + return int_arrays_column; + } + + /** Sets the value of the 'int_arrays_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder setIntArraysColumn(java.util.List> value) { + validate(fields()[0], value); + this.int_arrays_column = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'int_arrays_column' field has been set */ + public boolean hasIntArraysColumn() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'int_arrays_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder clearIntArraysColumn() { + int_arrays_column = null; + fieldSetFlags()[0] = false; + return this; + } + + @Override + public AvroArrayOfArray build() { + try { + AvroArrayOfArray record = new AvroArrayOfArray(); + record.int_arrays_column = fieldSetFlags()[0] ? 
this.int_arrays_column : (java.util.List>) defaultValue(fields()[0]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroMapOfArray.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroMapOfArray.java new file mode 100644 index 000000000000..727f6a7bf733 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroMapOfArray.java @@ -0,0 +1,142 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class AvroMapOfArray extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"AvroMapOfArray\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"string_to_ints_column\",\"type\":{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":\"int\"},\"avro.java.string\":\"String\"}}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public java.util.Map> string_to_ints_column; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public AvroMapOfArray() {} + + /** + * All-args constructor. + */ + public AvroMapOfArray(java.util.Map> string_to_ints_column) { + this.string_to_ints_column = string_to_ints_column; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return string_to_ints_column; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: string_to_ints_column = (java.util.Map>)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'string_to_ints_column' field. + */ + public java.util.Map> getStringToIntsColumn() { + return string_to_ints_column; + } + + /** + * Sets the value of the 'string_to_ints_column' field. + * @param value the value to set. 
+ */ + public void setStringToIntsColumn(java.util.Map> value) { + this.string_to_ints_column = value; + } + + /** Creates a new AvroMapOfArray RecordBuilder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder(); + } + + /** Creates a new AvroMapOfArray RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder(other); + } + + /** Creates a new AvroMapOfArray RecordBuilder by copying an existing AvroMapOfArray instance */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder(other); + } + + /** + * RecordBuilder for AvroMapOfArray instances. + */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private java.util.Map> string_to_ints_column; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder other) { + super(other); + if (isValidValue(fields()[0], other.string_to_ints_column)) { + this.string_to_ints_column = data().deepCopy(fields()[0].schema(), other.string_to_ints_column); + fieldSetFlags()[0] = true; + } + } + + /** Creates a Builder by copying an existing AvroMapOfArray instance */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.SCHEMA$); + if (isValidValue(fields()[0], other.string_to_ints_column)) { + this.string_to_ints_column = data().deepCopy(fields()[0].schema(), other.string_to_ints_column); + fieldSetFlags()[0] = true; + } + } + + /** Gets the value of the 'string_to_ints_column' field */ + public java.util.Map> getStringToIntsColumn() { + return string_to_ints_column; + } + + /** Sets the value of the 'string_to_ints_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder setStringToIntsColumn(java.util.Map> value) { + validate(fields()[0], value); + this.string_to_ints_column = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'string_to_ints_column' field has been set */ + public boolean hasStringToIntsColumn() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'string_to_ints_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder clearStringToIntsColumn() { + string_to_ints_column = null; + fieldSetFlags()[0] = false; + return this; + } + + @Override + public AvroMapOfArray build() { + try { + AvroMapOfArray record = new AvroMapOfArray(); + record.string_to_ints_column = fieldSetFlags()[0] ? 
this.string_to_ints_column : (java.util.Map>) defaultValue(fields()[0]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroNonNullableArrays.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroNonNullableArrays.java new file mode 100644 index 000000000000..934793f42f9c --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroNonNullableArrays.java @@ -0,0 +1,196 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class AvroNonNullableArrays extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"AvroNonNullableArrays\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"maybe_ints_column\",\"type\":[\"null\",{\"type\":\"array\",\"items\":\"int\"}]}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public java.util.List strings_column; + @Deprecated public java.util.List maybe_ints_column; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public AvroNonNullableArrays() {} + + /** + * All-args constructor. + */ + public AvroNonNullableArrays(java.util.List strings_column, java.util.List maybe_ints_column) { + this.strings_column = strings_column; + this.maybe_ints_column = maybe_ints_column; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return strings_column; + case 1: return maybe_ints_column; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: strings_column = (java.util.List)value$; break; + case 1: maybe_ints_column = (java.util.List)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'strings_column' field. + */ + public java.util.List getStringsColumn() { + return strings_column; + } + + /** + * Sets the value of the 'strings_column' field. + * @param value the value to set. + */ + public void setStringsColumn(java.util.List value) { + this.strings_column = value; + } + + /** + * Gets the value of the 'maybe_ints_column' field. + */ + public java.util.List getMaybeIntsColumn() { + return maybe_ints_column; + } + + /** + * Sets the value of the 'maybe_ints_column' field. + * @param value the value to set. 
+ */ + public void setMaybeIntsColumn(java.util.List value) { + this.maybe_ints_column = value; + } + + /** Creates a new AvroNonNullableArrays RecordBuilder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder(); + } + + /** Creates a new AvroNonNullableArrays RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder(other); + } + + /** Creates a new AvroNonNullableArrays RecordBuilder by copying an existing AvroNonNullableArrays instance */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder(other); + } + + /** + * RecordBuilder for AvroNonNullableArrays instances. + */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private java.util.List strings_column; + private java.util.List maybe_ints_column; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder other) { + super(other); + if (isValidValue(fields()[0], other.strings_column)) { + this.strings_column = data().deepCopy(fields()[0].schema(), other.strings_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.maybe_ints_column)) { + this.maybe_ints_column = data().deepCopy(fields()[1].schema(), other.maybe_ints_column); + fieldSetFlags()[1] = true; + } + } + + /** Creates a Builder by copying an existing AvroNonNullableArrays instance */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.SCHEMA$); + if (isValidValue(fields()[0], other.strings_column)) { + this.strings_column = data().deepCopy(fields()[0].schema(), other.strings_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.maybe_ints_column)) { + this.maybe_ints_column = data().deepCopy(fields()[1].schema(), other.maybe_ints_column); + fieldSetFlags()[1] = true; + } + } + + /** Gets the value of the 'strings_column' field */ + public java.util.List getStringsColumn() { + return strings_column; + } + + /** Sets the value of the 'strings_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder setStringsColumn(java.util.List value) { + validate(fields()[0], value); + this.strings_column = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'strings_column' field has been set */ + public boolean hasStringsColumn() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'strings_column' 
field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder clearStringsColumn() { + strings_column = null; + fieldSetFlags()[0] = false; + return this; + } + + /** Gets the value of the 'maybe_ints_column' field */ + public java.util.List getMaybeIntsColumn() { + return maybe_ints_column; + } + + /** Sets the value of the 'maybe_ints_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder setMaybeIntsColumn(java.util.List value) { + validate(fields()[1], value); + this.maybe_ints_column = value; + fieldSetFlags()[1] = true; + return this; + } + + /** Checks whether the 'maybe_ints_column' field has been set */ + public boolean hasMaybeIntsColumn() { + return fieldSetFlags()[1]; + } + + /** Clears the value of the 'maybe_ints_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder clearMaybeIntsColumn() { + maybe_ints_column = null; + fieldSetFlags()[1] = false; + return this; + } + + @Override + public AvroNonNullableArrays build() { + try { + AvroNonNullableArrays record = new AvroNonNullableArrays(); + record.strings_column = fieldSetFlags()[0] ? this.strings_column : (java.util.List) defaultValue(fields()[0]); + record.maybe_ints_column = fieldSetFlags()[1] ? this.maybe_ints_column : (java.util.List) defaultValue(fields()[1]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroOptionalPrimitives.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroOptionalPrimitives.java new file mode 100644 index 000000000000..e4d1ead8dd15 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroOptionalPrimitives.java @@ -0,0 +1,466 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class AvroOptionalPrimitives extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"AvroOptionalPrimitives\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"maybe_bool_column\",\"type\":[\"null\",\"boolean\"]},{\"name\":\"maybe_int_column\",\"type\":[\"null\",\"int\"]},{\"name\":\"maybe_long_column\",\"type\":[\"null\",\"long\"]},{\"name\":\"maybe_float_column\",\"type\":[\"null\",\"float\"]},{\"name\":\"maybe_double_column\",\"type\":[\"null\",\"double\"]},{\"name\":\"maybe_binary_column\",\"type\":[\"null\",\"bytes\"]},{\"name\":\"maybe_string_column\",\"type\":[\"null\",{\"type\":\"string\",\"avro.java.string\":\"String\"}]}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public java.lang.Boolean maybe_bool_column; + @Deprecated public java.lang.Integer maybe_int_column; + @Deprecated public java.lang.Long maybe_long_column; + @Deprecated public java.lang.Float maybe_float_column; + @Deprecated public java.lang.Double maybe_double_column; + @Deprecated public java.nio.ByteBuffer maybe_binary_column; + @Deprecated public java.lang.String 
maybe_string_column; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public AvroOptionalPrimitives() {} + + /** + * All-args constructor. + */ + public AvroOptionalPrimitives(java.lang.Boolean maybe_bool_column, java.lang.Integer maybe_int_column, java.lang.Long maybe_long_column, java.lang.Float maybe_float_column, java.lang.Double maybe_double_column, java.nio.ByteBuffer maybe_binary_column, java.lang.String maybe_string_column) { + this.maybe_bool_column = maybe_bool_column; + this.maybe_int_column = maybe_int_column; + this.maybe_long_column = maybe_long_column; + this.maybe_float_column = maybe_float_column; + this.maybe_double_column = maybe_double_column; + this.maybe_binary_column = maybe_binary_column; + this.maybe_string_column = maybe_string_column; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return maybe_bool_column; + case 1: return maybe_int_column; + case 2: return maybe_long_column; + case 3: return maybe_float_column; + case 4: return maybe_double_column; + case 5: return maybe_binary_column; + case 6: return maybe_string_column; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: maybe_bool_column = (java.lang.Boolean)value$; break; + case 1: maybe_int_column = (java.lang.Integer)value$; break; + case 2: maybe_long_column = (java.lang.Long)value$; break; + case 3: maybe_float_column = (java.lang.Float)value$; break; + case 4: maybe_double_column = (java.lang.Double)value$; break; + case 5: maybe_binary_column = (java.nio.ByteBuffer)value$; break; + case 6: maybe_string_column = (java.lang.String)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'maybe_bool_column' field. + */ + public java.lang.Boolean getMaybeBoolColumn() { + return maybe_bool_column; + } + + /** + * Sets the value of the 'maybe_bool_column' field. + * @param value the value to set. + */ + public void setMaybeBoolColumn(java.lang.Boolean value) { + this.maybe_bool_column = value; + } + + /** + * Gets the value of the 'maybe_int_column' field. + */ + public java.lang.Integer getMaybeIntColumn() { + return maybe_int_column; + } + + /** + * Sets the value of the 'maybe_int_column' field. + * @param value the value to set. + */ + public void setMaybeIntColumn(java.lang.Integer value) { + this.maybe_int_column = value; + } + + /** + * Gets the value of the 'maybe_long_column' field. + */ + public java.lang.Long getMaybeLongColumn() { + return maybe_long_column; + } + + /** + * Sets the value of the 'maybe_long_column' field. + * @param value the value to set. + */ + public void setMaybeLongColumn(java.lang.Long value) { + this.maybe_long_column = value; + } + + /** + * Gets the value of the 'maybe_float_column' field. + */ + public java.lang.Float getMaybeFloatColumn() { + return maybe_float_column; + } + + /** + * Sets the value of the 'maybe_float_column' field. + * @param value the value to set. 
+ */ + public void setMaybeFloatColumn(java.lang.Float value) { + this.maybe_float_column = value; + } + + /** + * Gets the value of the 'maybe_double_column' field. + */ + public java.lang.Double getMaybeDoubleColumn() { + return maybe_double_column; + } + + /** + * Sets the value of the 'maybe_double_column' field. + * @param value the value to set. + */ + public void setMaybeDoubleColumn(java.lang.Double value) { + this.maybe_double_column = value; + } + + /** + * Gets the value of the 'maybe_binary_column' field. + */ + public java.nio.ByteBuffer getMaybeBinaryColumn() { + return maybe_binary_column; + } + + /** + * Sets the value of the 'maybe_binary_column' field. + * @param value the value to set. + */ + public void setMaybeBinaryColumn(java.nio.ByteBuffer value) { + this.maybe_binary_column = value; + } + + /** + * Gets the value of the 'maybe_string_column' field. + */ + public java.lang.String getMaybeStringColumn() { + return maybe_string_column; + } + + /** + * Sets the value of the 'maybe_string_column' field. + * @param value the value to set. + */ + public void setMaybeStringColumn(java.lang.String value) { + this.maybe_string_column = value; + } + + /** Creates a new AvroOptionalPrimitives RecordBuilder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder(); + } + + /** Creates a new AvroOptionalPrimitives RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder(other); + } + + /** Creates a new AvroOptionalPrimitives RecordBuilder by copying an existing AvroOptionalPrimitives instance */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder(other); + } + + /** + * RecordBuilder for AvroOptionalPrimitives instances. 
+ */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private java.lang.Boolean maybe_bool_column; + private java.lang.Integer maybe_int_column; + private java.lang.Long maybe_long_column; + private java.lang.Float maybe_float_column; + private java.lang.Double maybe_double_column; + private java.nio.ByteBuffer maybe_binary_column; + private java.lang.String maybe_string_column; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder other) { + super(other); + if (isValidValue(fields()[0], other.maybe_bool_column)) { + this.maybe_bool_column = data().deepCopy(fields()[0].schema(), other.maybe_bool_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.maybe_int_column)) { + this.maybe_int_column = data().deepCopy(fields()[1].schema(), other.maybe_int_column); + fieldSetFlags()[1] = true; + } + if (isValidValue(fields()[2], other.maybe_long_column)) { + this.maybe_long_column = data().deepCopy(fields()[2].schema(), other.maybe_long_column); + fieldSetFlags()[2] = true; + } + if (isValidValue(fields()[3], other.maybe_float_column)) { + this.maybe_float_column = data().deepCopy(fields()[3].schema(), other.maybe_float_column); + fieldSetFlags()[3] = true; + } + if (isValidValue(fields()[4], other.maybe_double_column)) { + this.maybe_double_column = data().deepCopy(fields()[4].schema(), other.maybe_double_column); + fieldSetFlags()[4] = true; + } + if (isValidValue(fields()[5], other.maybe_binary_column)) { + this.maybe_binary_column = data().deepCopy(fields()[5].schema(), other.maybe_binary_column); + fieldSetFlags()[5] = true; + } + if (isValidValue(fields()[6], other.maybe_string_column)) { + this.maybe_string_column = data().deepCopy(fields()[6].schema(), other.maybe_string_column); + fieldSetFlags()[6] = true; + } + } + + /** Creates a Builder by copying an existing AvroOptionalPrimitives instance */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.SCHEMA$); + if (isValidValue(fields()[0], other.maybe_bool_column)) { + this.maybe_bool_column = data().deepCopy(fields()[0].schema(), other.maybe_bool_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.maybe_int_column)) { + this.maybe_int_column = data().deepCopy(fields()[1].schema(), other.maybe_int_column); + fieldSetFlags()[1] = true; + } + if (isValidValue(fields()[2], other.maybe_long_column)) { + this.maybe_long_column = data().deepCopy(fields()[2].schema(), other.maybe_long_column); + fieldSetFlags()[2] = true; + } + if (isValidValue(fields()[3], other.maybe_float_column)) { + this.maybe_float_column = data().deepCopy(fields()[3].schema(), other.maybe_float_column); + fieldSetFlags()[3] = true; + } + if (isValidValue(fields()[4], other.maybe_double_column)) { + this.maybe_double_column = data().deepCopy(fields()[4].schema(), other.maybe_double_column); + fieldSetFlags()[4] = true; + } + if (isValidValue(fields()[5], other.maybe_binary_column)) { + this.maybe_binary_column = data().deepCopy(fields()[5].schema(), other.maybe_binary_column); + fieldSetFlags()[5] = true; + } 
+ if (isValidValue(fields()[6], other.maybe_string_column)) { + this.maybe_string_column = data().deepCopy(fields()[6].schema(), other.maybe_string_column); + fieldSetFlags()[6] = true; + } + } + + /** Gets the value of the 'maybe_bool_column' field */ + public java.lang.Boolean getMaybeBoolColumn() { + return maybe_bool_column; + } + + /** Sets the value of the 'maybe_bool_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder setMaybeBoolColumn(java.lang.Boolean value) { + validate(fields()[0], value); + this.maybe_bool_column = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'maybe_bool_column' field has been set */ + public boolean hasMaybeBoolColumn() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'maybe_bool_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder clearMaybeBoolColumn() { + maybe_bool_column = null; + fieldSetFlags()[0] = false; + return this; + } + + /** Gets the value of the 'maybe_int_column' field */ + public java.lang.Integer getMaybeIntColumn() { + return maybe_int_column; + } + + /** Sets the value of the 'maybe_int_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder setMaybeIntColumn(java.lang.Integer value) { + validate(fields()[1], value); + this.maybe_int_column = value; + fieldSetFlags()[1] = true; + return this; + } + + /** Checks whether the 'maybe_int_column' field has been set */ + public boolean hasMaybeIntColumn() { + return fieldSetFlags()[1]; + } + + /** Clears the value of the 'maybe_int_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder clearMaybeIntColumn() { + maybe_int_column = null; + fieldSetFlags()[1] = false; + return this; + } + + /** Gets the value of the 'maybe_long_column' field */ + public java.lang.Long getMaybeLongColumn() { + return maybe_long_column; + } + + /** Sets the value of the 'maybe_long_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder setMaybeLongColumn(java.lang.Long value) { + validate(fields()[2], value); + this.maybe_long_column = value; + fieldSetFlags()[2] = true; + return this; + } + + /** Checks whether the 'maybe_long_column' field has been set */ + public boolean hasMaybeLongColumn() { + return fieldSetFlags()[2]; + } + + /** Clears the value of the 'maybe_long_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder clearMaybeLongColumn() { + maybe_long_column = null; + fieldSetFlags()[2] = false; + return this; + } + + /** Gets the value of the 'maybe_float_column' field */ + public java.lang.Float getMaybeFloatColumn() { + return maybe_float_column; + } + + /** Sets the value of the 'maybe_float_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder setMaybeFloatColumn(java.lang.Float value) { + validate(fields()[3], value); + this.maybe_float_column = value; + fieldSetFlags()[3] = true; + return this; + } + + /** Checks whether the 'maybe_float_column' field has been set */ + public boolean hasMaybeFloatColumn() { + return fieldSetFlags()[3]; + } + + /** Clears the value of the 'maybe_float_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder 
clearMaybeFloatColumn() { + maybe_float_column = null; + fieldSetFlags()[3] = false; + return this; + } + + /** Gets the value of the 'maybe_double_column' field */ + public java.lang.Double getMaybeDoubleColumn() { + return maybe_double_column; + } + + /** Sets the value of the 'maybe_double_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder setMaybeDoubleColumn(java.lang.Double value) { + validate(fields()[4], value); + this.maybe_double_column = value; + fieldSetFlags()[4] = true; + return this; + } + + /** Checks whether the 'maybe_double_column' field has been set */ + public boolean hasMaybeDoubleColumn() { + return fieldSetFlags()[4]; + } + + /** Clears the value of the 'maybe_double_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder clearMaybeDoubleColumn() { + maybe_double_column = null; + fieldSetFlags()[4] = false; + return this; + } + + /** Gets the value of the 'maybe_binary_column' field */ + public java.nio.ByteBuffer getMaybeBinaryColumn() { + return maybe_binary_column; + } + + /** Sets the value of the 'maybe_binary_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder setMaybeBinaryColumn(java.nio.ByteBuffer value) { + validate(fields()[5], value); + this.maybe_binary_column = value; + fieldSetFlags()[5] = true; + return this; + } + + /** Checks whether the 'maybe_binary_column' field has been set */ + public boolean hasMaybeBinaryColumn() { + return fieldSetFlags()[5]; + } + + /** Clears the value of the 'maybe_binary_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder clearMaybeBinaryColumn() { + maybe_binary_column = null; + fieldSetFlags()[5] = false; + return this; + } + + /** Gets the value of the 'maybe_string_column' field */ + public java.lang.String getMaybeStringColumn() { + return maybe_string_column; + } + + /** Sets the value of the 'maybe_string_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder setMaybeStringColumn(java.lang.String value) { + validate(fields()[6], value); + this.maybe_string_column = value; + fieldSetFlags()[6] = true; + return this; + } + + /** Checks whether the 'maybe_string_column' field has been set */ + public boolean hasMaybeStringColumn() { + return fieldSetFlags()[6]; + } + + /** Clears the value of the 'maybe_string_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder clearMaybeStringColumn() { + maybe_string_column = null; + fieldSetFlags()[6] = false; + return this; + } + + @Override + public AvroOptionalPrimitives build() { + try { + AvroOptionalPrimitives record = new AvroOptionalPrimitives(); + record.maybe_bool_column = fieldSetFlags()[0] ? this.maybe_bool_column : (java.lang.Boolean) defaultValue(fields()[0]); + record.maybe_int_column = fieldSetFlags()[1] ? this.maybe_int_column : (java.lang.Integer) defaultValue(fields()[1]); + record.maybe_long_column = fieldSetFlags()[2] ? this.maybe_long_column : (java.lang.Long) defaultValue(fields()[2]); + record.maybe_float_column = fieldSetFlags()[3] ? this.maybe_float_column : (java.lang.Float) defaultValue(fields()[3]); + record.maybe_double_column = fieldSetFlags()[4] ? 
this.maybe_double_column : (java.lang.Double) defaultValue(fields()[4]); + record.maybe_binary_column = fieldSetFlags()[5] ? this.maybe_binary_column : (java.nio.ByteBuffer) defaultValue(fields()[5]); + record.maybe_string_column = fieldSetFlags()[6] ? this.maybe_string_column : (java.lang.String) defaultValue(fields()[6]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroPrimitives.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroPrimitives.java new file mode 100644 index 000000000000..1c2afed16781 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroPrimitives.java @@ -0,0 +1,461 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class AvroPrimitives extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"AvroPrimitives\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"bool_column\",\"type\":\"boolean\"},{\"name\":\"int_column\",\"type\":\"int\"},{\"name\":\"long_column\",\"type\":\"long\"},{\"name\":\"float_column\",\"type\":\"float\"},{\"name\":\"double_column\",\"type\":\"double\"},{\"name\":\"binary_column\",\"type\":\"bytes\"},{\"name\":\"string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public boolean bool_column; + @Deprecated public int int_column; + @Deprecated public long long_column; + @Deprecated public float float_column; + @Deprecated public double double_column; + @Deprecated public java.nio.ByteBuffer binary_column; + @Deprecated public java.lang.String string_column; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public AvroPrimitives() {} + + /** + * All-args constructor. + */ + public AvroPrimitives(java.lang.Boolean bool_column, java.lang.Integer int_column, java.lang.Long long_column, java.lang.Float float_column, java.lang.Double double_column, java.nio.ByteBuffer binary_column, java.lang.String string_column) { + this.bool_column = bool_column; + this.int_column = int_column; + this.long_column = long_column; + this.float_column = float_column; + this.double_column = double_column; + this.binary_column = binary_column; + this.string_column = string_column; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return bool_column; + case 1: return int_column; + case 2: return long_column; + case 3: return float_column; + case 4: return double_column; + case 5: return binary_column; + case 6: return string_column; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. 
+ @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: bool_column = (java.lang.Boolean)value$; break; + case 1: int_column = (java.lang.Integer)value$; break; + case 2: long_column = (java.lang.Long)value$; break; + case 3: float_column = (java.lang.Float)value$; break; + case 4: double_column = (java.lang.Double)value$; break; + case 5: binary_column = (java.nio.ByteBuffer)value$; break; + case 6: string_column = (java.lang.String)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'bool_column' field. + */ + public java.lang.Boolean getBoolColumn() { + return bool_column; + } + + /** + * Sets the value of the 'bool_column' field. + * @param value the value to set. + */ + public void setBoolColumn(java.lang.Boolean value) { + this.bool_column = value; + } + + /** + * Gets the value of the 'int_column' field. + */ + public java.lang.Integer getIntColumn() { + return int_column; + } + + /** + * Sets the value of the 'int_column' field. + * @param value the value to set. + */ + public void setIntColumn(java.lang.Integer value) { + this.int_column = value; + } + + /** + * Gets the value of the 'long_column' field. + */ + public java.lang.Long getLongColumn() { + return long_column; + } + + /** + * Sets the value of the 'long_column' field. + * @param value the value to set. + */ + public void setLongColumn(java.lang.Long value) { + this.long_column = value; + } + + /** + * Gets the value of the 'float_column' field. + */ + public java.lang.Float getFloatColumn() { + return float_column; + } + + /** + * Sets the value of the 'float_column' field. + * @param value the value to set. + */ + public void setFloatColumn(java.lang.Float value) { + this.float_column = value; + } + + /** + * Gets the value of the 'double_column' field. + */ + public java.lang.Double getDoubleColumn() { + return double_column; + } + + /** + * Sets the value of the 'double_column' field. + * @param value the value to set. + */ + public void setDoubleColumn(java.lang.Double value) { + this.double_column = value; + } + + /** + * Gets the value of the 'binary_column' field. + */ + public java.nio.ByteBuffer getBinaryColumn() { + return binary_column; + } + + /** + * Sets the value of the 'binary_column' field. + * @param value the value to set. + */ + public void setBinaryColumn(java.nio.ByteBuffer value) { + this.binary_column = value; + } + + /** + * Gets the value of the 'string_column' field. + */ + public java.lang.String getStringColumn() { + return string_column; + } + + /** + * Sets the value of the 'string_column' field. + * @param value the value to set. 
+ */ + public void setStringColumn(java.lang.String value) { + this.string_column = value; + } + + /** Creates a new AvroPrimitives RecordBuilder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder(); + } + + /** Creates a new AvroPrimitives RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder(other); + } + + /** Creates a new AvroPrimitives RecordBuilder by copying an existing AvroPrimitives instance */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder(other); + } + + /** + * RecordBuilder for AvroPrimitives instances. + */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private boolean bool_column; + private int int_column; + private long long_column; + private float float_column; + private double double_column; + private java.nio.ByteBuffer binary_column; + private java.lang.String string_column; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder other) { + super(other); + if (isValidValue(fields()[0], other.bool_column)) { + this.bool_column = data().deepCopy(fields()[0].schema(), other.bool_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.int_column)) { + this.int_column = data().deepCopy(fields()[1].schema(), other.int_column); + fieldSetFlags()[1] = true; + } + if (isValidValue(fields()[2], other.long_column)) { + this.long_column = data().deepCopy(fields()[2].schema(), other.long_column); + fieldSetFlags()[2] = true; + } + if (isValidValue(fields()[3], other.float_column)) { + this.float_column = data().deepCopy(fields()[3].schema(), other.float_column); + fieldSetFlags()[3] = true; + } + if (isValidValue(fields()[4], other.double_column)) { + this.double_column = data().deepCopy(fields()[4].schema(), other.double_column); + fieldSetFlags()[4] = true; + } + if (isValidValue(fields()[5], other.binary_column)) { + this.binary_column = data().deepCopy(fields()[5].schema(), other.binary_column); + fieldSetFlags()[5] = true; + } + if (isValidValue(fields()[6], other.string_column)) { + this.string_column = data().deepCopy(fields()[6].schema(), other.string_column); + fieldSetFlags()[6] = true; + } + } + + /** Creates a Builder by copying an existing AvroPrimitives instance */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.SCHEMA$); + if (isValidValue(fields()[0], other.bool_column)) { + this.bool_column = data().deepCopy(fields()[0].schema(), other.bool_column); + 
fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.int_column)) { + this.int_column = data().deepCopy(fields()[1].schema(), other.int_column); + fieldSetFlags()[1] = true; + } + if (isValidValue(fields()[2], other.long_column)) { + this.long_column = data().deepCopy(fields()[2].schema(), other.long_column); + fieldSetFlags()[2] = true; + } + if (isValidValue(fields()[3], other.float_column)) { + this.float_column = data().deepCopy(fields()[3].schema(), other.float_column); + fieldSetFlags()[3] = true; + } + if (isValidValue(fields()[4], other.double_column)) { + this.double_column = data().deepCopy(fields()[4].schema(), other.double_column); + fieldSetFlags()[4] = true; + } + if (isValidValue(fields()[5], other.binary_column)) { + this.binary_column = data().deepCopy(fields()[5].schema(), other.binary_column); + fieldSetFlags()[5] = true; + } + if (isValidValue(fields()[6], other.string_column)) { + this.string_column = data().deepCopy(fields()[6].schema(), other.string_column); + fieldSetFlags()[6] = true; + } + } + + /** Gets the value of the 'bool_column' field */ + public java.lang.Boolean getBoolColumn() { + return bool_column; + } + + /** Sets the value of the 'bool_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder setBoolColumn(boolean value) { + validate(fields()[0], value); + this.bool_column = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'bool_column' field has been set */ + public boolean hasBoolColumn() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'bool_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder clearBoolColumn() { + fieldSetFlags()[0] = false; + return this; + } + + /** Gets the value of the 'int_column' field */ + public java.lang.Integer getIntColumn() { + return int_column; + } + + /** Sets the value of the 'int_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder setIntColumn(int value) { + validate(fields()[1], value); + this.int_column = value; + fieldSetFlags()[1] = true; + return this; + } + + /** Checks whether the 'int_column' field has been set */ + public boolean hasIntColumn() { + return fieldSetFlags()[1]; + } + + /** Clears the value of the 'int_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder clearIntColumn() { + fieldSetFlags()[1] = false; + return this; + } + + /** Gets the value of the 'long_column' field */ + public java.lang.Long getLongColumn() { + return long_column; + } + + /** Sets the value of the 'long_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder setLongColumn(long value) { + validate(fields()[2], value); + this.long_column = value; + fieldSetFlags()[2] = true; + return this; + } + + /** Checks whether the 'long_column' field has been set */ + public boolean hasLongColumn() { + return fieldSetFlags()[2]; + } + + /** Clears the value of the 'long_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder clearLongColumn() { + fieldSetFlags()[2] = false; + return this; + } + + /** Gets the value of the 'float_column' field */ + public java.lang.Float getFloatColumn() { + return float_column; + } + + /** Sets the value of the 'float_column' field */ + public 
org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder setFloatColumn(float value) { + validate(fields()[3], value); + this.float_column = value; + fieldSetFlags()[3] = true; + return this; + } + + /** Checks whether the 'float_column' field has been set */ + public boolean hasFloatColumn() { + return fieldSetFlags()[3]; + } + + /** Clears the value of the 'float_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder clearFloatColumn() { + fieldSetFlags()[3] = false; + return this; + } + + /** Gets the value of the 'double_column' field */ + public java.lang.Double getDoubleColumn() { + return double_column; + } + + /** Sets the value of the 'double_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder setDoubleColumn(double value) { + validate(fields()[4], value); + this.double_column = value; + fieldSetFlags()[4] = true; + return this; + } + + /** Checks whether the 'double_column' field has been set */ + public boolean hasDoubleColumn() { + return fieldSetFlags()[4]; + } + + /** Clears the value of the 'double_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder clearDoubleColumn() { + fieldSetFlags()[4] = false; + return this; + } + + /** Gets the value of the 'binary_column' field */ + public java.nio.ByteBuffer getBinaryColumn() { + return binary_column; + } + + /** Sets the value of the 'binary_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder setBinaryColumn(java.nio.ByteBuffer value) { + validate(fields()[5], value); + this.binary_column = value; + fieldSetFlags()[5] = true; + return this; + } + + /** Checks whether the 'binary_column' field has been set */ + public boolean hasBinaryColumn() { + return fieldSetFlags()[5]; + } + + /** Clears the value of the 'binary_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder clearBinaryColumn() { + binary_column = null; + fieldSetFlags()[5] = false; + return this; + } + + /** Gets the value of the 'string_column' field */ + public java.lang.String getStringColumn() { + return string_column; + } + + /** Sets the value of the 'string_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder setStringColumn(java.lang.String value) { + validate(fields()[6], value); + this.string_column = value; + fieldSetFlags()[6] = true; + return this; + } + + /** Checks whether the 'string_column' field has been set */ + public boolean hasStringColumn() { + return fieldSetFlags()[6]; + } + + /** Clears the value of the 'string_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder clearStringColumn() { + string_column = null; + fieldSetFlags()[6] = false; + return this; + } + + @Override + public AvroPrimitives build() { + try { + AvroPrimitives record = new AvroPrimitives(); + record.bool_column = fieldSetFlags()[0] ? this.bool_column : (java.lang.Boolean) defaultValue(fields()[0]); + record.int_column = fieldSetFlags()[1] ? this.int_column : (java.lang.Integer) defaultValue(fields()[1]); + record.long_column = fieldSetFlags()[2] ? this.long_column : (java.lang.Long) defaultValue(fields()[2]); + record.float_column = fieldSetFlags()[3] ? 
this.float_column : (java.lang.Float) defaultValue(fields()[3]); + record.double_column = fieldSetFlags()[4] ? this.double_column : (java.lang.Double) defaultValue(fields()[4]); + record.binary_column = fieldSetFlags()[5] ? this.binary_column : (java.nio.ByteBuffer) defaultValue(fields()[5]); + record.string_column = fieldSetFlags()[6] ? this.string_column : (java.lang.String) defaultValue(fields()[6]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/CompatibilityTest.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/CompatibilityTest.java new file mode 100644 index 000000000000..28fdc1dfb911 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/CompatibilityTest.java @@ -0,0 +1,17 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; + +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public interface CompatibilityTest { + public static final org.apache.avro.Protocol PROTOCOL = org.apache.avro.Protocol.parse("{\"protocol\":\"CompatibilityTest\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"types\":[{\"type\":\"enum\",\"name\":\"Suit\",\"symbols\":[\"SPADES\",\"HEARTS\",\"DIAMONDS\",\"CLUBS\"]},{\"type\":\"record\",\"name\":\"ParquetEnum\",\"fields\":[{\"name\":\"suit\",\"type\":\"Suit\"}]},{\"type\":\"record\",\"name\":\"Nested\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]},{\"type\":\"record\",\"name\":\"AvroPrimitives\",\"fields\":[{\"name\":\"bool_column\",\"type\":\"boolean\"},{\"name\":\"int_column\",\"type\":\"int\"},{\"name\":\"long_column\",\"type\":\"long\"},{\"name\":\"float_column\",\"type\":\"float\"},{\"name\":\"double_column\",\"type\":\"double\"},{\"name\":\"binary_column\",\"type\":\"bytes\"},{\"name\":\"string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]},{\"type\":\"record\",\"name\":\"AvroOptionalPrimitives\",\"fields\":[{\"name\":\"maybe_bool_column\",\"type\":[\"null\",\"boolean\"]},{\"name\":\"maybe_int_column\",\"type\":[\"null\",\"int\"]},{\"name\":\"maybe_long_column\",\"type\":[\"null\",\"long\"]},{\"name\":\"maybe_float_column\",\"type\":[\"null\",\"float\"]},{\"name\":\"maybe_double_column\",\"type\":[\"null\",\"double\"]},{\"name\":\"maybe_binary_column\",\"type\":[\"null\",\"bytes\"]},{\"name\":\"maybe_string_column\",\"type\":[\"null\",{\"type\":\"string\",\"avro.java.string\":\"String\"}]}]},{\"type\":\"record\",\"name\":\"AvroNonNullableArrays\",\"fields\":[{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"maybe_ints_column\",\"type\":[\"null\",{\"type\":\"array\",\"items\":\"int\"}]}]},{\"type\":\"record\",\"name\":\"AvroArrayOfArray\",\"fields\":[{\"name\":\"int_arrays_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"array\",\"items\":\"int\"}}}]},{\"type\":\"record\",\"name\":\"AvroMapOfArray\",\"fields\":[{\"name\":\"string_to_ints_column\",\"type\":{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":\"int\"},\"avro.java.string\":\"String\"}}]},{\"type\":\"record\",\"name\":\"Parque
tAvroCompat\",\"fields\":[{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"string_to_int_column\",\"type\":{\"type\":\"map\",\"values\":\"int\",\"avro.java.string\":\"String\"}},{\"name\":\"complex_column\",\"type\":{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":\"Nested\"},\"avro.java.string\":\"String\"}}]}],\"messages\":{}}"); + + @SuppressWarnings("all") + public interface Callback extends CompatibilityTest { + public static final org.apache.avro.Protocol PROTOCOL = org.apache.spark.sql.execution.datasources.parquet.test.avro.CompatibilityTest.PROTOCOL; + } +} \ No newline at end of file diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/Nested.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java similarity index 75% rename from sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/Nested.java rename to sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java index 051f1ee90386..a7bf4841919c 100644 --- a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/Nested.java +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java @@ -3,11 +3,11 @@ * * DO NOT EDIT DIRECTLY */ -package org.apache.spark.sql.parquet.test.avro; +package org.apache.spark.sql.execution.datasources.parquet.test.avro; @SuppressWarnings("all") @org.apache.avro.specific.AvroGenerated public class Nested extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { - public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"Nested\",\"namespace\":\"org.apache.spark.sql.parquet.test.avro\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]}"); + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"Nested\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]}"); public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } @Deprecated public java.util.List nested_ints_column; @Deprecated public java.lang.String nested_string_column; @@ -77,18 +77,18 @@ public void setNestedStringColumn(java.lang.String value) { } /** Creates a new Nested RecordBuilder */ - public static org.apache.spark.sql.parquet.test.avro.Nested.Builder newBuilder() { - return new org.apache.spark.sql.parquet.test.avro.Nested.Builder(); + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder(); } /** Creates a new Nested RecordBuilder by copying an existing Builder */ - public static org.apache.spark.sql.parquet.test.avro.Nested.Builder newBuilder(org.apache.spark.sql.parquet.test.avro.Nested.Builder other) { - return new org.apache.spark.sql.parquet.test.avro.Nested.Builder(other); + public static 
org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder(other); } /** Creates a new Nested RecordBuilder by copying an existing Nested instance */ - public static org.apache.spark.sql.parquet.test.avro.Nested.Builder newBuilder(org.apache.spark.sql.parquet.test.avro.Nested other) { - return new org.apache.spark.sql.parquet.test.avro.Nested.Builder(other); + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder(other); } /** @@ -102,11 +102,11 @@ public static class Builder extends org.apache.avro.specific.SpecificRecordBuild /** Creates a new Builder */ private Builder() { - super(org.apache.spark.sql.parquet.test.avro.Nested.SCHEMA$); + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.SCHEMA$); } /** Creates a Builder by copying an existing Builder */ - private Builder(org.apache.spark.sql.parquet.test.avro.Nested.Builder other) { + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder other) { super(other); if (isValidValue(fields()[0], other.nested_ints_column)) { this.nested_ints_column = data().deepCopy(fields()[0].schema(), other.nested_ints_column); @@ -119,8 +119,8 @@ private Builder(org.apache.spark.sql.parquet.test.avro.Nested.Builder other) { } /** Creates a Builder by copying an existing Nested instance */ - private Builder(org.apache.spark.sql.parquet.test.avro.Nested other) { - super(org.apache.spark.sql.parquet.test.avro.Nested.SCHEMA$); + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.SCHEMA$); if (isValidValue(fields()[0], other.nested_ints_column)) { this.nested_ints_column = data().deepCopy(fields()[0].schema(), other.nested_ints_column); fieldSetFlags()[0] = true; @@ -137,7 +137,7 @@ public java.util.List getNestedIntsColumn() { } /** Sets the value of the 'nested_ints_column' field */ - public org.apache.spark.sql.parquet.test.avro.Nested.Builder setNestedIntsColumn(java.util.List value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder setNestedIntsColumn(java.util.List value) { validate(fields()[0], value); this.nested_ints_column = value; fieldSetFlags()[0] = true; @@ -150,7 +150,7 @@ public boolean hasNestedIntsColumn() { } /** Clears the value of the 'nested_ints_column' field */ - public org.apache.spark.sql.parquet.test.avro.Nested.Builder clearNestedIntsColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder clearNestedIntsColumn() { nested_ints_column = null; fieldSetFlags()[0] = false; return this; @@ -162,7 +162,7 @@ public java.lang.String getNestedStringColumn() { } /** Sets the value of the 'nested_string_column' field */ - public org.apache.spark.sql.parquet.test.avro.Nested.Builder setNestedStringColumn(java.lang.String value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder setNestedStringColumn(java.lang.String value) { validate(fields()[1], value); this.nested_string_column = value; fieldSetFlags()[1] = true; @@ -175,7 +175,7 @@ public boolean 
hasNestedStringColumn() { } /** Clears the value of the 'nested_string_column' field */ - public org.apache.spark.sql.parquet.test.avro.Nested.Builder clearNestedStringColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder clearNestedStringColumn() { nested_string_column = null; fieldSetFlags()[1] = false; return this; diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java new file mode 100644 index 000000000000..ef12d193f916 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java @@ -0,0 +1,250 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class ParquetAvroCompat extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"ParquetAvroCompat\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"string_to_int_column\",\"type\":{\"type\":\"map\",\"values\":\"int\",\"avro.java.string\":\"String\"}},{\"name\":\"complex_column\",\"type\":{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":{\"type\":\"record\",\"name\":\"Nested\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]}},\"avro.java.string\":\"String\"}}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public java.util.List strings_column; + @Deprecated public java.util.Map string_to_int_column; + @Deprecated public java.util.Map> complex_column; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public ParquetAvroCompat() {} + + /** + * All-args constructor. + */ + public ParquetAvroCompat(java.util.List strings_column, java.util.Map string_to_int_column, java.util.Map> complex_column) { + this.strings_column = strings_column; + this.string_to_int_column = string_to_int_column; + this.complex_column = complex_column; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return strings_column; + case 1: return string_to_int_column; + case 2: return complex_column; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. 
+ @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: strings_column = (java.util.List)value$; break; + case 1: string_to_int_column = (java.util.Map)value$; break; + case 2: complex_column = (java.util.Map>)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'strings_column' field. + */ + public java.util.List getStringsColumn() { + return strings_column; + } + + /** + * Sets the value of the 'strings_column' field. + * @param value the value to set. + */ + public void setStringsColumn(java.util.List value) { + this.strings_column = value; + } + + /** + * Gets the value of the 'string_to_int_column' field. + */ + public java.util.Map getStringToIntColumn() { + return string_to_int_column; + } + + /** + * Sets the value of the 'string_to_int_column' field. + * @param value the value to set. + */ + public void setStringToIntColumn(java.util.Map value) { + this.string_to_int_column = value; + } + + /** + * Gets the value of the 'complex_column' field. + */ + public java.util.Map> getComplexColumn() { + return complex_column; + } + + /** + * Sets the value of the 'complex_column' field. + * @param value the value to set. + */ + public void setComplexColumn(java.util.Map> value) { + this.complex_column = value; + } + + /** Creates a new ParquetAvroCompat RecordBuilder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder(); + } + + /** Creates a new ParquetAvroCompat RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder(other); + } + + /** Creates a new ParquetAvroCompat RecordBuilder by copying an existing ParquetAvroCompat instance */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder(other); + } + + /** + * RecordBuilder for ParquetAvroCompat instances. 
+ */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private java.util.List strings_column; + private java.util.Map string_to_int_column; + private java.util.Map> complex_column; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder other) { + super(other); + if (isValidValue(fields()[0], other.strings_column)) { + this.strings_column = data().deepCopy(fields()[0].schema(), other.strings_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.string_to_int_column)) { + this.string_to_int_column = data().deepCopy(fields()[1].schema(), other.string_to_int_column); + fieldSetFlags()[1] = true; + } + if (isValidValue(fields()[2], other.complex_column)) { + this.complex_column = data().deepCopy(fields()[2].schema(), other.complex_column); + fieldSetFlags()[2] = true; + } + } + + /** Creates a Builder by copying an existing ParquetAvroCompat instance */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.SCHEMA$); + if (isValidValue(fields()[0], other.strings_column)) { + this.strings_column = data().deepCopy(fields()[0].schema(), other.strings_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.string_to_int_column)) { + this.string_to_int_column = data().deepCopy(fields()[1].schema(), other.string_to_int_column); + fieldSetFlags()[1] = true; + } + if (isValidValue(fields()[2], other.complex_column)) { + this.complex_column = data().deepCopy(fields()[2].schema(), other.complex_column); + fieldSetFlags()[2] = true; + } + } + + /** Gets the value of the 'strings_column' field */ + public java.util.List getStringsColumn() { + return strings_column; + } + + /** Sets the value of the 'strings_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setStringsColumn(java.util.List value) { + validate(fields()[0], value); + this.strings_column = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'strings_column' field has been set */ + public boolean hasStringsColumn() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'strings_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearStringsColumn() { + strings_column = null; + fieldSetFlags()[0] = false; + return this; + } + + /** Gets the value of the 'string_to_int_column' field */ + public java.util.Map getStringToIntColumn() { + return string_to_int_column; + } + + /** Sets the value of the 'string_to_int_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setStringToIntColumn(java.util.Map value) { + validate(fields()[1], value); + this.string_to_int_column = value; + fieldSetFlags()[1] = true; + return this; + } + + /** Checks whether the 'string_to_int_column' field has been set */ + public boolean hasStringToIntColumn() { + return fieldSetFlags()[1]; + } + + /** Clears the value of the 'string_to_int_column' field */ + public 
org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearStringToIntColumn() { + string_to_int_column = null; + fieldSetFlags()[1] = false; + return this; + } + + /** Gets the value of the 'complex_column' field */ + public java.util.Map> getComplexColumn() { + return complex_column; + } + + /** Sets the value of the 'complex_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setComplexColumn(java.util.Map> value) { + validate(fields()[2], value); + this.complex_column = value; + fieldSetFlags()[2] = true; + return this; + } + + /** Checks whether the 'complex_column' field has been set */ + public boolean hasComplexColumn() { + return fieldSetFlags()[2]; + } + + /** Clears the value of the 'complex_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearComplexColumn() { + complex_column = null; + fieldSetFlags()[2] = false; + return this; + } + + @Override + public ParquetAvroCompat build() { + try { + ParquetAvroCompat record = new ParquetAvroCompat(); + record.strings_column = fieldSetFlags()[0] ? this.strings_column : (java.util.List) defaultValue(fields()[0]); + record.string_to_int_column = fieldSetFlags()[1] ? this.string_to_int_column : (java.util.Map) defaultValue(fields()[1]); + record.complex_column = fieldSetFlags()[2] ? this.complex_column : (java.util.Map>) defaultValue(fields()[2]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetEnum.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetEnum.java new file mode 100644 index 000000000000..05fefe4cee75 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetEnum.java @@ -0,0 +1,142 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class ParquetEnum extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"ParquetEnum\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"suit\",\"type\":{\"type\":\"enum\",\"name\":\"Suit\",\"symbols\":[\"SPADES\",\"HEARTS\",\"DIAMONDS\",\"CLUBS\"]}}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit suit; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public ParquetEnum() {} + + /** + * All-args constructor. + */ + public ParquetEnum(org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit suit) { + this.suit = suit; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. 
+ public java.lang.Object get(int field$) { + switch (field$) { + case 0: return suit; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: suit = (org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'suit' field. + */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit getSuit() { + return suit; + } + + /** + * Sets the value of the 'suit' field. + * @param value the value to set. + */ + public void setSuit(org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit value) { + this.suit = value; + } + + /** Creates a new ParquetEnum RecordBuilder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder(); + } + + /** Creates a new ParquetEnum RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder(other); + } + + /** Creates a new ParquetEnum RecordBuilder by copying an existing ParquetEnum instance */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder(other); + } + + /** + * RecordBuilder for ParquetEnum instances. 
+ */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit suit; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder other) { + super(other); + if (isValidValue(fields()[0], other.suit)) { + this.suit = data().deepCopy(fields()[0].schema(), other.suit); + fieldSetFlags()[0] = true; + } + } + + /** Creates a Builder by copying an existing ParquetEnum instance */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.SCHEMA$); + if (isValidValue(fields()[0], other.suit)) { + this.suit = data().deepCopy(fields()[0].schema(), other.suit); + fieldSetFlags()[0] = true; + } + } + + /** Gets the value of the 'suit' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit getSuit() { + return suit; + } + + /** Sets the value of the 'suit' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder setSuit(org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit value) { + validate(fields()[0], value); + this.suit = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'suit' field has been set */ + public boolean hasSuit() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'suit' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder clearSuit() { + suit = null; + fieldSetFlags()[0] = false; + return this; + } + + @Override + public ParquetEnum build() { + try { + ParquetEnum record = new ParquetEnum(); + record.suit = fieldSetFlags()[0] ? 
this.suit : (org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit) defaultValue(fields()[0]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Suit.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Suit.java new file mode 100644 index 000000000000..00711a0c2a26 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Suit.java @@ -0,0 +1,13 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public enum Suit { + SPADES, HEARTS, DIAMONDS, CLUBS ; + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"enum\",\"name\":\"Suit\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"symbols\":[\"SPADES\",\"HEARTS\",\"DIAMONDS\",\"CLUBS\"]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/CompatibilityTest.java b/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/CompatibilityTest.java deleted file mode 100644 index daec65a5bbe5..000000000000 --- a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/CompatibilityTest.java +++ /dev/null @@ -1,17 +0,0 @@ -/** - * Autogenerated by Avro - * - * DO NOT EDIT DIRECTLY - */ -package org.apache.spark.sql.parquet.test.avro; - -@SuppressWarnings("all") -@org.apache.avro.specific.AvroGenerated -public interface CompatibilityTest { - public static final org.apache.avro.Protocol PROTOCOL = 
org.apache.avro.Protocol.parse("{\"protocol\":\"CompatibilityTest\",\"namespace\":\"org.apache.spark.sql.parquet.test.avro\",\"types\":[{\"type\":\"record\",\"name\":\"Nested\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]},{\"type\":\"record\",\"name\":\"ParquetAvroCompat\",\"fields\":[{\"name\":\"bool_column\",\"type\":\"boolean\"},{\"name\":\"int_column\",\"type\":\"int\"},{\"name\":\"long_column\",\"type\":\"long\"},{\"name\":\"float_column\",\"type\":\"float\"},{\"name\":\"double_column\",\"type\":\"double\"},{\"name\":\"binary_column\",\"type\":\"bytes\"},{\"name\":\"string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}},{\"name\":\"maybe_bool_column\",\"type\":[\"null\",\"boolean\"]},{\"name\":\"maybe_int_column\",\"type\":[\"null\",\"int\"]},{\"name\":\"maybe_long_column\",\"type\":[\"null\",\"long\"]},{\"name\":\"maybe_float_column\",\"type\":[\"null\",\"float\"]},{\"name\":\"maybe_double_column\",\"type\":[\"null\",\"double\"]},{\"name\":\"maybe_binary_column\",\"type\":[\"null\",\"bytes\"]},{\"name\":\"maybe_string_column\",\"type\":[\"null\",{\"type\":\"string\",\"avro.java.string\":\"String\"}]},{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"string_to_int_column\",\"type\":{\"type\":\"map\",\"values\":\"int\",\"avro.java.string\":\"String\"}},{\"name\":\"complex_column\",\"type\":{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":\"Nested\"},\"avro.java.string\":\"String\"}}]}],\"messages\":{}}"); - - @SuppressWarnings("all") - public interface Callback extends CompatibilityTest { - public static final org.apache.avro.Protocol PROTOCOL = org.apache.spark.sql.parquet.test.avro.CompatibilityTest.PROTOCOL; - } -} \ No newline at end of file diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/ParquetAvroCompat.java b/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/ParquetAvroCompat.java deleted file mode 100644 index 354c9d73cca3..000000000000 --- a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/ParquetAvroCompat.java +++ /dev/null @@ -1,1001 +0,0 @@ -/** - * Autogenerated by Avro - * - * DO NOT EDIT DIRECTLY - */ -package org.apache.spark.sql.parquet.test.avro; -@SuppressWarnings("all") -@org.apache.avro.specific.AvroGenerated -public class ParquetAvroCompat extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { - public static final org.apache.avro.Schema SCHEMA$ = new 
org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"ParquetAvroCompat\",\"namespace\":\"org.apache.spark.sql.parquet.test.avro\",\"fields\":[{\"name\":\"bool_column\",\"type\":\"boolean\"},{\"name\":\"int_column\",\"type\":\"int\"},{\"name\":\"long_column\",\"type\":\"long\"},{\"name\":\"float_column\",\"type\":\"float\"},{\"name\":\"double_column\",\"type\":\"double\"},{\"name\":\"binary_column\",\"type\":\"bytes\"},{\"name\":\"string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}},{\"name\":\"maybe_bool_column\",\"type\":[\"null\",\"boolean\"]},{\"name\":\"maybe_int_column\",\"type\":[\"null\",\"int\"]},{\"name\":\"maybe_long_column\",\"type\":[\"null\",\"long\"]},{\"name\":\"maybe_float_column\",\"type\":[\"null\",\"float\"]},{\"name\":\"maybe_double_column\",\"type\":[\"null\",\"double\"]},{\"name\":\"maybe_binary_column\",\"type\":[\"null\",\"bytes\"]},{\"name\":\"maybe_string_column\",\"type\":[\"null\",{\"type\":\"string\",\"avro.java.string\":\"String\"}]},{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"string_to_int_column\",\"type\":{\"type\":\"map\",\"values\":\"int\",\"avro.java.string\":\"String\"}},{\"name\":\"complex_column\",\"type\":{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":{\"type\":\"record\",\"name\":\"Nested\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]}},\"avro.java.string\":\"String\"}}]}"); - public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } - @Deprecated public boolean bool_column; - @Deprecated public int int_column; - @Deprecated public long long_column; - @Deprecated public float float_column; - @Deprecated public double double_column; - @Deprecated public java.nio.ByteBuffer binary_column; - @Deprecated public java.lang.String string_column; - @Deprecated public java.lang.Boolean maybe_bool_column; - @Deprecated public java.lang.Integer maybe_int_column; - @Deprecated public java.lang.Long maybe_long_column; - @Deprecated public java.lang.Float maybe_float_column; - @Deprecated public java.lang.Double maybe_double_column; - @Deprecated public java.nio.ByteBuffer maybe_binary_column; - @Deprecated public java.lang.String maybe_string_column; - @Deprecated public java.util.List strings_column; - @Deprecated public java.util.Map string_to_int_column; - @Deprecated public java.util.Map> complex_column; - - /** - * Default constructor. Note that this does not initialize fields - * to their default values from the schema. If that is desired then - * one should use newBuilder(). - */ - public ParquetAvroCompat() {} - - /** - * All-args constructor. 
- */ - public ParquetAvroCompat(java.lang.Boolean bool_column, java.lang.Integer int_column, java.lang.Long long_column, java.lang.Float float_column, java.lang.Double double_column, java.nio.ByteBuffer binary_column, java.lang.String string_column, java.lang.Boolean maybe_bool_column, java.lang.Integer maybe_int_column, java.lang.Long maybe_long_column, java.lang.Float maybe_float_column, java.lang.Double maybe_double_column, java.nio.ByteBuffer maybe_binary_column, java.lang.String maybe_string_column, java.util.List strings_column, java.util.Map string_to_int_column, java.util.Map> complex_column) { - this.bool_column = bool_column; - this.int_column = int_column; - this.long_column = long_column; - this.float_column = float_column; - this.double_column = double_column; - this.binary_column = binary_column; - this.string_column = string_column; - this.maybe_bool_column = maybe_bool_column; - this.maybe_int_column = maybe_int_column; - this.maybe_long_column = maybe_long_column; - this.maybe_float_column = maybe_float_column; - this.maybe_double_column = maybe_double_column; - this.maybe_binary_column = maybe_binary_column; - this.maybe_string_column = maybe_string_column; - this.strings_column = strings_column; - this.string_to_int_column = string_to_int_column; - this.complex_column = complex_column; - } - - public org.apache.avro.Schema getSchema() { return SCHEMA$; } - // Used by DatumWriter. Applications should not call. - public java.lang.Object get(int field$) { - switch (field$) { - case 0: return bool_column; - case 1: return int_column; - case 2: return long_column; - case 3: return float_column; - case 4: return double_column; - case 5: return binary_column; - case 6: return string_column; - case 7: return maybe_bool_column; - case 8: return maybe_int_column; - case 9: return maybe_long_column; - case 10: return maybe_float_column; - case 11: return maybe_double_column; - case 12: return maybe_binary_column; - case 13: return maybe_string_column; - case 14: return strings_column; - case 15: return string_to_int_column; - case 16: return complex_column; - default: throw new org.apache.avro.AvroRuntimeException("Bad index"); - } - } - // Used by DatumReader. Applications should not call. - @SuppressWarnings(value="unchecked") - public void put(int field$, java.lang.Object value$) { - switch (field$) { - case 0: bool_column = (java.lang.Boolean)value$; break; - case 1: int_column = (java.lang.Integer)value$; break; - case 2: long_column = (java.lang.Long)value$; break; - case 3: float_column = (java.lang.Float)value$; break; - case 4: double_column = (java.lang.Double)value$; break; - case 5: binary_column = (java.nio.ByteBuffer)value$; break; - case 6: string_column = (java.lang.String)value$; break; - case 7: maybe_bool_column = (java.lang.Boolean)value$; break; - case 8: maybe_int_column = (java.lang.Integer)value$; break; - case 9: maybe_long_column = (java.lang.Long)value$; break; - case 10: maybe_float_column = (java.lang.Float)value$; break; - case 11: maybe_double_column = (java.lang.Double)value$; break; - case 12: maybe_binary_column = (java.nio.ByteBuffer)value$; break; - case 13: maybe_string_column = (java.lang.String)value$; break; - case 14: strings_column = (java.util.List)value$; break; - case 15: string_to_int_column = (java.util.Map)value$; break; - case 16: complex_column = (java.util.Map>)value$; break; - default: throw new org.apache.avro.AvroRuntimeException("Bad index"); - } - } - - /** - * Gets the value of the 'bool_column' field. 
- */ - public java.lang.Boolean getBoolColumn() { - return bool_column; - } - - /** - * Sets the value of the 'bool_column' field. - * @param value the value to set. - */ - public void setBoolColumn(java.lang.Boolean value) { - this.bool_column = value; - } - - /** - * Gets the value of the 'int_column' field. - */ - public java.lang.Integer getIntColumn() { - return int_column; - } - - /** - * Sets the value of the 'int_column' field. - * @param value the value to set. - */ - public void setIntColumn(java.lang.Integer value) { - this.int_column = value; - } - - /** - * Gets the value of the 'long_column' field. - */ - public java.lang.Long getLongColumn() { - return long_column; - } - - /** - * Sets the value of the 'long_column' field. - * @param value the value to set. - */ - public void setLongColumn(java.lang.Long value) { - this.long_column = value; - } - - /** - * Gets the value of the 'float_column' field. - */ - public java.lang.Float getFloatColumn() { - return float_column; - } - - /** - * Sets the value of the 'float_column' field. - * @param value the value to set. - */ - public void setFloatColumn(java.lang.Float value) { - this.float_column = value; - } - - /** - * Gets the value of the 'double_column' field. - */ - public java.lang.Double getDoubleColumn() { - return double_column; - } - - /** - * Sets the value of the 'double_column' field. - * @param value the value to set. - */ - public void setDoubleColumn(java.lang.Double value) { - this.double_column = value; - } - - /** - * Gets the value of the 'binary_column' field. - */ - public java.nio.ByteBuffer getBinaryColumn() { - return binary_column; - } - - /** - * Sets the value of the 'binary_column' field. - * @param value the value to set. - */ - public void setBinaryColumn(java.nio.ByteBuffer value) { - this.binary_column = value; - } - - /** - * Gets the value of the 'string_column' field. - */ - public java.lang.String getStringColumn() { - return string_column; - } - - /** - * Sets the value of the 'string_column' field. - * @param value the value to set. - */ - public void setStringColumn(java.lang.String value) { - this.string_column = value; - } - - /** - * Gets the value of the 'maybe_bool_column' field. - */ - public java.lang.Boolean getMaybeBoolColumn() { - return maybe_bool_column; - } - - /** - * Sets the value of the 'maybe_bool_column' field. - * @param value the value to set. - */ - public void setMaybeBoolColumn(java.lang.Boolean value) { - this.maybe_bool_column = value; - } - - /** - * Gets the value of the 'maybe_int_column' field. - */ - public java.lang.Integer getMaybeIntColumn() { - return maybe_int_column; - } - - /** - * Sets the value of the 'maybe_int_column' field. - * @param value the value to set. - */ - public void setMaybeIntColumn(java.lang.Integer value) { - this.maybe_int_column = value; - } - - /** - * Gets the value of the 'maybe_long_column' field. - */ - public java.lang.Long getMaybeLongColumn() { - return maybe_long_column; - } - - /** - * Sets the value of the 'maybe_long_column' field. - * @param value the value to set. - */ - public void setMaybeLongColumn(java.lang.Long value) { - this.maybe_long_column = value; - } - - /** - * Gets the value of the 'maybe_float_column' field. - */ - public java.lang.Float getMaybeFloatColumn() { - return maybe_float_column; - } - - /** - * Sets the value of the 'maybe_float_column' field. - * @param value the value to set. 
- */ - public void setMaybeFloatColumn(java.lang.Float value) { - this.maybe_float_column = value; - } - - /** - * Gets the value of the 'maybe_double_column' field. - */ - public java.lang.Double getMaybeDoubleColumn() { - return maybe_double_column; - } - - /** - * Sets the value of the 'maybe_double_column' field. - * @param value the value to set. - */ - public void setMaybeDoubleColumn(java.lang.Double value) { - this.maybe_double_column = value; - } - - /** - * Gets the value of the 'maybe_binary_column' field. - */ - public java.nio.ByteBuffer getMaybeBinaryColumn() { - return maybe_binary_column; - } - - /** - * Sets the value of the 'maybe_binary_column' field. - * @param value the value to set. - */ - public void setMaybeBinaryColumn(java.nio.ByteBuffer value) { - this.maybe_binary_column = value; - } - - /** - * Gets the value of the 'maybe_string_column' field. - */ - public java.lang.String getMaybeStringColumn() { - return maybe_string_column; - } - - /** - * Sets the value of the 'maybe_string_column' field. - * @param value the value to set. - */ - public void setMaybeStringColumn(java.lang.String value) { - this.maybe_string_column = value; - } - - /** - * Gets the value of the 'strings_column' field. - */ - public java.util.List getStringsColumn() { - return strings_column; - } - - /** - * Sets the value of the 'strings_column' field. - * @param value the value to set. - */ - public void setStringsColumn(java.util.List value) { - this.strings_column = value; - } - - /** - * Gets the value of the 'string_to_int_column' field. - */ - public java.util.Map getStringToIntColumn() { - return string_to_int_column; - } - - /** - * Sets the value of the 'string_to_int_column' field. - * @param value the value to set. - */ - public void setStringToIntColumn(java.util.Map value) { - this.string_to_int_column = value; - } - - /** - * Gets the value of the 'complex_column' field. - */ - public java.util.Map> getComplexColumn() { - return complex_column; - } - - /** - * Sets the value of the 'complex_column' field. - * @param value the value to set. - */ - public void setComplexColumn(java.util.Map> value) { - this.complex_column = value; - } - - /** Creates a new ParquetAvroCompat RecordBuilder */ - public static org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder newBuilder() { - return new org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder(); - } - - /** Creates a new ParquetAvroCompat RecordBuilder by copying an existing Builder */ - public static org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder newBuilder(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder other) { - return new org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder(other); - } - - /** Creates a new ParquetAvroCompat RecordBuilder by copying an existing ParquetAvroCompat instance */ - public static org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder newBuilder(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat other) { - return new org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder(other); - } - - /** - * RecordBuilder for ParquetAvroCompat instances. 
- */ - public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase - implements org.apache.avro.data.RecordBuilder { - - private boolean bool_column; - private int int_column; - private long long_column; - private float float_column; - private double double_column; - private java.nio.ByteBuffer binary_column; - private java.lang.String string_column; - private java.lang.Boolean maybe_bool_column; - private java.lang.Integer maybe_int_column; - private java.lang.Long maybe_long_column; - private java.lang.Float maybe_float_column; - private java.lang.Double maybe_double_column; - private java.nio.ByteBuffer maybe_binary_column; - private java.lang.String maybe_string_column; - private java.util.List strings_column; - private java.util.Map string_to_int_column; - private java.util.Map> complex_column; - - /** Creates a new Builder */ - private Builder() { - super(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.SCHEMA$); - } - - /** Creates a Builder by copying an existing Builder */ - private Builder(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder other) { - super(other); - if (isValidValue(fields()[0], other.bool_column)) { - this.bool_column = data().deepCopy(fields()[0].schema(), other.bool_column); - fieldSetFlags()[0] = true; - } - if (isValidValue(fields()[1], other.int_column)) { - this.int_column = data().deepCopy(fields()[1].schema(), other.int_column); - fieldSetFlags()[1] = true; - } - if (isValidValue(fields()[2], other.long_column)) { - this.long_column = data().deepCopy(fields()[2].schema(), other.long_column); - fieldSetFlags()[2] = true; - } - if (isValidValue(fields()[3], other.float_column)) { - this.float_column = data().deepCopy(fields()[3].schema(), other.float_column); - fieldSetFlags()[3] = true; - } - if (isValidValue(fields()[4], other.double_column)) { - this.double_column = data().deepCopy(fields()[4].schema(), other.double_column); - fieldSetFlags()[4] = true; - } - if (isValidValue(fields()[5], other.binary_column)) { - this.binary_column = data().deepCopy(fields()[5].schema(), other.binary_column); - fieldSetFlags()[5] = true; - } - if (isValidValue(fields()[6], other.string_column)) { - this.string_column = data().deepCopy(fields()[6].schema(), other.string_column); - fieldSetFlags()[6] = true; - } - if (isValidValue(fields()[7], other.maybe_bool_column)) { - this.maybe_bool_column = data().deepCopy(fields()[7].schema(), other.maybe_bool_column); - fieldSetFlags()[7] = true; - } - if (isValidValue(fields()[8], other.maybe_int_column)) { - this.maybe_int_column = data().deepCopy(fields()[8].schema(), other.maybe_int_column); - fieldSetFlags()[8] = true; - } - if (isValidValue(fields()[9], other.maybe_long_column)) { - this.maybe_long_column = data().deepCopy(fields()[9].schema(), other.maybe_long_column); - fieldSetFlags()[9] = true; - } - if (isValidValue(fields()[10], other.maybe_float_column)) { - this.maybe_float_column = data().deepCopy(fields()[10].schema(), other.maybe_float_column); - fieldSetFlags()[10] = true; - } - if (isValidValue(fields()[11], other.maybe_double_column)) { - this.maybe_double_column = data().deepCopy(fields()[11].schema(), other.maybe_double_column); - fieldSetFlags()[11] = true; - } - if (isValidValue(fields()[12], other.maybe_binary_column)) { - this.maybe_binary_column = data().deepCopy(fields()[12].schema(), other.maybe_binary_column); - fieldSetFlags()[12] = true; - } - if (isValidValue(fields()[13], other.maybe_string_column)) { - this.maybe_string_column = 
data().deepCopy(fields()[13].schema(), other.maybe_string_column); - fieldSetFlags()[13] = true; - } - if (isValidValue(fields()[14], other.strings_column)) { - this.strings_column = data().deepCopy(fields()[14].schema(), other.strings_column); - fieldSetFlags()[14] = true; - } - if (isValidValue(fields()[15], other.string_to_int_column)) { - this.string_to_int_column = data().deepCopy(fields()[15].schema(), other.string_to_int_column); - fieldSetFlags()[15] = true; - } - if (isValidValue(fields()[16], other.complex_column)) { - this.complex_column = data().deepCopy(fields()[16].schema(), other.complex_column); - fieldSetFlags()[16] = true; - } - } - - /** Creates a Builder by copying an existing ParquetAvroCompat instance */ - private Builder(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat other) { - super(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.SCHEMA$); - if (isValidValue(fields()[0], other.bool_column)) { - this.bool_column = data().deepCopy(fields()[0].schema(), other.bool_column); - fieldSetFlags()[0] = true; - } - if (isValidValue(fields()[1], other.int_column)) { - this.int_column = data().deepCopy(fields()[1].schema(), other.int_column); - fieldSetFlags()[1] = true; - } - if (isValidValue(fields()[2], other.long_column)) { - this.long_column = data().deepCopy(fields()[2].schema(), other.long_column); - fieldSetFlags()[2] = true; - } - if (isValidValue(fields()[3], other.float_column)) { - this.float_column = data().deepCopy(fields()[3].schema(), other.float_column); - fieldSetFlags()[3] = true; - } - if (isValidValue(fields()[4], other.double_column)) { - this.double_column = data().deepCopy(fields()[4].schema(), other.double_column); - fieldSetFlags()[4] = true; - } - if (isValidValue(fields()[5], other.binary_column)) { - this.binary_column = data().deepCopy(fields()[5].schema(), other.binary_column); - fieldSetFlags()[5] = true; - } - if (isValidValue(fields()[6], other.string_column)) { - this.string_column = data().deepCopy(fields()[6].schema(), other.string_column); - fieldSetFlags()[6] = true; - } - if (isValidValue(fields()[7], other.maybe_bool_column)) { - this.maybe_bool_column = data().deepCopy(fields()[7].schema(), other.maybe_bool_column); - fieldSetFlags()[7] = true; - } - if (isValidValue(fields()[8], other.maybe_int_column)) { - this.maybe_int_column = data().deepCopy(fields()[8].schema(), other.maybe_int_column); - fieldSetFlags()[8] = true; - } - if (isValidValue(fields()[9], other.maybe_long_column)) { - this.maybe_long_column = data().deepCopy(fields()[9].schema(), other.maybe_long_column); - fieldSetFlags()[9] = true; - } - if (isValidValue(fields()[10], other.maybe_float_column)) { - this.maybe_float_column = data().deepCopy(fields()[10].schema(), other.maybe_float_column); - fieldSetFlags()[10] = true; - } - if (isValidValue(fields()[11], other.maybe_double_column)) { - this.maybe_double_column = data().deepCopy(fields()[11].schema(), other.maybe_double_column); - fieldSetFlags()[11] = true; - } - if (isValidValue(fields()[12], other.maybe_binary_column)) { - this.maybe_binary_column = data().deepCopy(fields()[12].schema(), other.maybe_binary_column); - fieldSetFlags()[12] = true; - } - if (isValidValue(fields()[13], other.maybe_string_column)) { - this.maybe_string_column = data().deepCopy(fields()[13].schema(), other.maybe_string_column); - fieldSetFlags()[13] = true; - } - if (isValidValue(fields()[14], other.strings_column)) { - this.strings_column = data().deepCopy(fields()[14].schema(), other.strings_column); - 
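// For context: a rough usage sketch of this generated Avro record (field values are
// invented; the accessor and builder method names are the generated ones shown in
// this class):
//
//   ParquetAvroCompat record = new ParquetAvroCompat();
//   record.setBoolColumn(true);
//   record.setIntColumn(1);
//   record.setStringColumn("a");
//   record.setStringsColumn(java.util.Arrays.asList("a", "b"));
//
//   // The Builder variant additionally tracks which fields have been explicitly set:
//   ParquetAvroCompat.Builder builder = ParquetAvroCompat.newBuilder()
//       .setBoolColumn(true)
//       .setIntColumn(1);
//   builder.hasBoolColumn();   // true once the field has been assigned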
fieldSetFlags()[14] = true; - } - if (isValidValue(fields()[15], other.string_to_int_column)) { - this.string_to_int_column = data().deepCopy(fields()[15].schema(), other.string_to_int_column); - fieldSetFlags()[15] = true; - } - if (isValidValue(fields()[16], other.complex_column)) { - this.complex_column = data().deepCopy(fields()[16].schema(), other.complex_column); - fieldSetFlags()[16] = true; - } - } - - /** Gets the value of the 'bool_column' field */ - public java.lang.Boolean getBoolColumn() { - return bool_column; - } - - /** Sets the value of the 'bool_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setBoolColumn(boolean value) { - validate(fields()[0], value); - this.bool_column = value; - fieldSetFlags()[0] = true; - return this; - } - - /** Checks whether the 'bool_column' field has been set */ - public boolean hasBoolColumn() { - return fieldSetFlags()[0]; - } - - /** Clears the value of the 'bool_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearBoolColumn() { - fieldSetFlags()[0] = false; - return this; - } - - /** Gets the value of the 'int_column' field */ - public java.lang.Integer getIntColumn() { - return int_column; - } - - /** Sets the value of the 'int_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setIntColumn(int value) { - validate(fields()[1], value); - this.int_column = value; - fieldSetFlags()[1] = true; - return this; - } - - /** Checks whether the 'int_column' field has been set */ - public boolean hasIntColumn() { - return fieldSetFlags()[1]; - } - - /** Clears the value of the 'int_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearIntColumn() { - fieldSetFlags()[1] = false; - return this; - } - - /** Gets the value of the 'long_column' field */ - public java.lang.Long getLongColumn() { - return long_column; - } - - /** Sets the value of the 'long_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setLongColumn(long value) { - validate(fields()[2], value); - this.long_column = value; - fieldSetFlags()[2] = true; - return this; - } - - /** Checks whether the 'long_column' field has been set */ - public boolean hasLongColumn() { - return fieldSetFlags()[2]; - } - - /** Clears the value of the 'long_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearLongColumn() { - fieldSetFlags()[2] = false; - return this; - } - - /** Gets the value of the 'float_column' field */ - public java.lang.Float getFloatColumn() { - return float_column; - } - - /** Sets the value of the 'float_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setFloatColumn(float value) { - validate(fields()[3], value); - this.float_column = value; - fieldSetFlags()[3] = true; - return this; - } - - /** Checks whether the 'float_column' field has been set */ - public boolean hasFloatColumn() { - return fieldSetFlags()[3]; - } - - /** Clears the value of the 'float_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearFloatColumn() { - fieldSetFlags()[3] = false; - return this; - } - - /** Gets the value of the 'double_column' field */ - public java.lang.Double getDoubleColumn() { - return double_column; - } - - /** Sets the value of the 'double_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setDoubleColumn(double 
value) { - validate(fields()[4], value); - this.double_column = value; - fieldSetFlags()[4] = true; - return this; - } - - /** Checks whether the 'double_column' field has been set */ - public boolean hasDoubleColumn() { - return fieldSetFlags()[4]; - } - - /** Clears the value of the 'double_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearDoubleColumn() { - fieldSetFlags()[4] = false; - return this; - } - - /** Gets the value of the 'binary_column' field */ - public java.nio.ByteBuffer getBinaryColumn() { - return binary_column; - } - - /** Sets the value of the 'binary_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setBinaryColumn(java.nio.ByteBuffer value) { - validate(fields()[5], value); - this.binary_column = value; - fieldSetFlags()[5] = true; - return this; - } - - /** Checks whether the 'binary_column' field has been set */ - public boolean hasBinaryColumn() { - return fieldSetFlags()[5]; - } - - /** Clears the value of the 'binary_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearBinaryColumn() { - binary_column = null; - fieldSetFlags()[5] = false; - return this; - } - - /** Gets the value of the 'string_column' field */ - public java.lang.String getStringColumn() { - return string_column; - } - - /** Sets the value of the 'string_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setStringColumn(java.lang.String value) { - validate(fields()[6], value); - this.string_column = value; - fieldSetFlags()[6] = true; - return this; - } - - /** Checks whether the 'string_column' field has been set */ - public boolean hasStringColumn() { - return fieldSetFlags()[6]; - } - - /** Clears the value of the 'string_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearStringColumn() { - string_column = null; - fieldSetFlags()[6] = false; - return this; - } - - /** Gets the value of the 'maybe_bool_column' field */ - public java.lang.Boolean getMaybeBoolColumn() { - return maybe_bool_column; - } - - /** Sets the value of the 'maybe_bool_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeBoolColumn(java.lang.Boolean value) { - validate(fields()[7], value); - this.maybe_bool_column = value; - fieldSetFlags()[7] = true; - return this; - } - - /** Checks whether the 'maybe_bool_column' field has been set */ - public boolean hasMaybeBoolColumn() { - return fieldSetFlags()[7]; - } - - /** Clears the value of the 'maybe_bool_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeBoolColumn() { - maybe_bool_column = null; - fieldSetFlags()[7] = false; - return this; - } - - /** Gets the value of the 'maybe_int_column' field */ - public java.lang.Integer getMaybeIntColumn() { - return maybe_int_column; - } - - /** Sets the value of the 'maybe_int_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeIntColumn(java.lang.Integer value) { - validate(fields()[8], value); - this.maybe_int_column = value; - fieldSetFlags()[8] = true; - return this; - } - - /** Checks whether the 'maybe_int_column' field has been set */ - public boolean hasMaybeIntColumn() { - return fieldSetFlags()[8]; - } - - /** Clears the value of the 'maybe_int_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeIntColumn() { - 
maybe_int_column = null; - fieldSetFlags()[8] = false; - return this; - } - - /** Gets the value of the 'maybe_long_column' field */ - public java.lang.Long getMaybeLongColumn() { - return maybe_long_column; - } - - /** Sets the value of the 'maybe_long_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeLongColumn(java.lang.Long value) { - validate(fields()[9], value); - this.maybe_long_column = value; - fieldSetFlags()[9] = true; - return this; - } - - /** Checks whether the 'maybe_long_column' field has been set */ - public boolean hasMaybeLongColumn() { - return fieldSetFlags()[9]; - } - - /** Clears the value of the 'maybe_long_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeLongColumn() { - maybe_long_column = null; - fieldSetFlags()[9] = false; - return this; - } - - /** Gets the value of the 'maybe_float_column' field */ - public java.lang.Float getMaybeFloatColumn() { - return maybe_float_column; - } - - /** Sets the value of the 'maybe_float_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeFloatColumn(java.lang.Float value) { - validate(fields()[10], value); - this.maybe_float_column = value; - fieldSetFlags()[10] = true; - return this; - } - - /** Checks whether the 'maybe_float_column' field has been set */ - public boolean hasMaybeFloatColumn() { - return fieldSetFlags()[10]; - } - - /** Clears the value of the 'maybe_float_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeFloatColumn() { - maybe_float_column = null; - fieldSetFlags()[10] = false; - return this; - } - - /** Gets the value of the 'maybe_double_column' field */ - public java.lang.Double getMaybeDoubleColumn() { - return maybe_double_column; - } - - /** Sets the value of the 'maybe_double_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeDoubleColumn(java.lang.Double value) { - validate(fields()[11], value); - this.maybe_double_column = value; - fieldSetFlags()[11] = true; - return this; - } - - /** Checks whether the 'maybe_double_column' field has been set */ - public boolean hasMaybeDoubleColumn() { - return fieldSetFlags()[11]; - } - - /** Clears the value of the 'maybe_double_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeDoubleColumn() { - maybe_double_column = null; - fieldSetFlags()[11] = false; - return this; - } - - /** Gets the value of the 'maybe_binary_column' field */ - public java.nio.ByteBuffer getMaybeBinaryColumn() { - return maybe_binary_column; - } - - /** Sets the value of the 'maybe_binary_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeBinaryColumn(java.nio.ByteBuffer value) { - validate(fields()[12], value); - this.maybe_binary_column = value; - fieldSetFlags()[12] = true; - return this; - } - - /** Checks whether the 'maybe_binary_column' field has been set */ - public boolean hasMaybeBinaryColumn() { - return fieldSetFlags()[12]; - } - - /** Clears the value of the 'maybe_binary_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeBinaryColumn() { - maybe_binary_column = null; - fieldSetFlags()[12] = false; - return this; - } - - /** Gets the value of the 'maybe_string_column' field */ - public java.lang.String getMaybeStringColumn() { - return maybe_string_column; - } - - /** Sets the value 
of the 'maybe_string_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeStringColumn(java.lang.String value) { - validate(fields()[13], value); - this.maybe_string_column = value; - fieldSetFlags()[13] = true; - return this; - } - - /** Checks whether the 'maybe_string_column' field has been set */ - public boolean hasMaybeStringColumn() { - return fieldSetFlags()[13]; - } - - /** Clears the value of the 'maybe_string_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeStringColumn() { - maybe_string_column = null; - fieldSetFlags()[13] = false; - return this; - } - - /** Gets the value of the 'strings_column' field */ - public java.util.List getStringsColumn() { - return strings_column; - } - - /** Sets the value of the 'strings_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setStringsColumn(java.util.List value) { - validate(fields()[14], value); - this.strings_column = value; - fieldSetFlags()[14] = true; - return this; - } - - /** Checks whether the 'strings_column' field has been set */ - public boolean hasStringsColumn() { - return fieldSetFlags()[14]; - } - - /** Clears the value of the 'strings_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearStringsColumn() { - strings_column = null; - fieldSetFlags()[14] = false; - return this; - } - - /** Gets the value of the 'string_to_int_column' field */ - public java.util.Map getStringToIntColumn() { - return string_to_int_column; - } - - /** Sets the value of the 'string_to_int_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setStringToIntColumn(java.util.Map value) { - validate(fields()[15], value); - this.string_to_int_column = value; - fieldSetFlags()[15] = true; - return this; - } - - /** Checks whether the 'string_to_int_column' field has been set */ - public boolean hasStringToIntColumn() { - return fieldSetFlags()[15]; - } - - /** Clears the value of the 'string_to_int_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearStringToIntColumn() { - string_to_int_column = null; - fieldSetFlags()[15] = false; - return this; - } - - /** Gets the value of the 'complex_column' field */ - public java.util.Map> getComplexColumn() { - return complex_column; - } - - /** Sets the value of the 'complex_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setComplexColumn(java.util.Map> value) { - validate(fields()[16], value); - this.complex_column = value; - fieldSetFlags()[16] = true; - return this; - } - - /** Checks whether the 'complex_column' field has been set */ - public boolean hasComplexColumn() { - return fieldSetFlags()[16]; - } - - /** Clears the value of the 'complex_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearComplexColumn() { - complex_column = null; - fieldSetFlags()[16] = false; - return this; - } - - @Override - public ParquetAvroCompat build() { - try { - ParquetAvroCompat record = new ParquetAvroCompat(); - record.bool_column = fieldSetFlags()[0] ? this.bool_column : (java.lang.Boolean) defaultValue(fields()[0]); - record.int_column = fieldSetFlags()[1] ? this.int_column : (java.lang.Integer) defaultValue(fields()[1]); - record.long_column = fieldSetFlags()[2] ? this.long_column : (java.lang.Long) defaultValue(fields()[2]); - record.float_column = fieldSetFlags()[3] ? 
this.float_column : (java.lang.Float) defaultValue(fields()[3]); - record.double_column = fieldSetFlags()[4] ? this.double_column : (java.lang.Double) defaultValue(fields()[4]); - record.binary_column = fieldSetFlags()[5] ? this.binary_column : (java.nio.ByteBuffer) defaultValue(fields()[5]); - record.string_column = fieldSetFlags()[6] ? this.string_column : (java.lang.String) defaultValue(fields()[6]); - record.maybe_bool_column = fieldSetFlags()[7] ? this.maybe_bool_column : (java.lang.Boolean) defaultValue(fields()[7]); - record.maybe_int_column = fieldSetFlags()[8] ? this.maybe_int_column : (java.lang.Integer) defaultValue(fields()[8]); - record.maybe_long_column = fieldSetFlags()[9] ? this.maybe_long_column : (java.lang.Long) defaultValue(fields()[9]); - record.maybe_float_column = fieldSetFlags()[10] ? this.maybe_float_column : (java.lang.Float) defaultValue(fields()[10]); - record.maybe_double_column = fieldSetFlags()[11] ? this.maybe_double_column : (java.lang.Double) defaultValue(fields()[11]); - record.maybe_binary_column = fieldSetFlags()[12] ? this.maybe_binary_column : (java.nio.ByteBuffer) defaultValue(fields()[12]); - record.maybe_string_column = fieldSetFlags()[13] ? this.maybe_string_column : (java.lang.String) defaultValue(fields()[13]); - record.strings_column = fieldSetFlags()[14] ? this.strings_column : (java.util.List) defaultValue(fields()[14]); - record.string_to_int_column = fieldSetFlags()[15] ? this.string_to_int_column : (java.util.Map) defaultValue(fields()[15]); - record.complex_column = fieldSetFlags()[16] ? this.complex_column : (java.util.Map>) defaultValue(fields()[16]); - return record; - } catch (Exception e) { - throw new org.apache.avro.AvroRuntimeException(e); - } - } - } -} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java index cb84e78d628c..7b50aad4ad49 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java @@ -18,6 +18,7 @@ package test.org.apache.spark.sql; import java.io.Serializable; +import java.math.BigDecimal; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -27,6 +28,7 @@ import org.junit.Before; import org.junit.Test; +import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; @@ -34,7 +36,6 @@ import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.test.TestSQLContext$; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; @@ -48,14 +49,16 @@ public class JavaApplySchemaSuite implements Serializable { @Before public void setUp() { - sqlContext = TestSQLContext$.MODULE$; - javaCtx = new JavaSparkContext(sqlContext.sparkContext()); + SparkContext context = new SparkContext("local[*]", "testing"); + javaCtx = new JavaSparkContext(context); + sqlContext = new SQLContext(context); } @After public void tearDown() { - javaCtx = null; + sqlContext.sparkContext().stop(); sqlContext = null; + javaCtx = null; } public static class Person implements Serializable { @@ -81,7 +84,7 @@ public void setAge(int age) { @Test public void applySchema() { - List personList = new ArrayList(2); + List personList = new 
ArrayList<>(2); Person person1 = new Person(); person1.setName("Michael"); person1.setAge(29); @@ -93,12 +96,13 @@ public void applySchema() { JavaRDD rowRDD = javaCtx.parallelize(personList).map( new Function() { + @Override public Row call(Person person) throws Exception { return RowFactory.create(person.getName(), person.getAge()); } }); - List fields = new ArrayList(2); + List fields = new ArrayList<>(2); fields.add(DataTypes.createStructField("name", DataTypes.StringType, false)); fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); @@ -116,7 +120,7 @@ public Row call(Person person) throws Exception { @Test public void dataFrameRDDOperations() { - List personList = new ArrayList(2); + List personList = new ArrayList<>(2); Person person1 = new Person(); person1.setName("Michael"); person1.setAge(29); @@ -127,27 +131,28 @@ public void dataFrameRDDOperations() { personList.add(person2); JavaRDD rowRDD = javaCtx.parallelize(personList).map( - new Function() { - public Row call(Person person) throws Exception { - return RowFactory.create(person.getName(), person.getAge()); - } - }); - - List fields = new ArrayList(2); - fields.add(DataTypes.createStructField("name", DataTypes.StringType, false)); + new Function() { + @Override + public Row call(Person person) { + return RowFactory.create(person.getName(), person.getAge()); + } + }); + + List fields = new ArrayList<>(2); + fields.add(DataTypes.createStructField("", DataTypes.StringType, false)); fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); DataFrame df = sqlContext.applySchema(rowRDD, schema); df.registerTempTable("people"); List actual = sqlContext.sql("SELECT * FROM people").toJavaRDD().map(new Function() { - + @Override public String call(Row row) { - return row.getString(0) + "_" + row.get(1).toString(); + return row.getString(0) + "_" + row.get(1); } }).collect(); - List expected = new ArrayList(2); + List expected = new ArrayList<>(2); expected.add("Michael_29"); expected.add("Yin_28"); @@ -163,8 +168,8 @@ public void applySchemaToJSON() { "{\"string\":\"this is another simple string.\", \"integer\":11, \"long\":21474836469, " + "\"bigInteger\":92233720368547758069, \"double\":1.7976931348623157E305, " + "\"boolean\":false, \"null\":null}")); - List fields = new ArrayList(7); - fields.add(DataTypes.createStructField("bigInteger", DataTypes.createDecimalType(38, 18), + List fields = new ArrayList<>(7); + fields.add(DataTypes.createStructField("bigInteger", DataTypes.createDecimalType(20, 0), true)); fields.add(DataTypes.createStructField("boolean", DataTypes.BooleanType, true)); fields.add(DataTypes.createStructField("double", DataTypes.DoubleType, true)); @@ -173,10 +178,10 @@ public void applySchemaToJSON() { fields.add(DataTypes.createStructField("null", DataTypes.StringType, true)); fields.add(DataTypes.createStructField("string", DataTypes.StringType, true)); StructType expectedSchema = DataTypes.createStructType(fields); - List expectedResult = new ArrayList(2); + List expectedResult = new ArrayList<>(2); expectedResult.add( RowFactory.create( - new java.math.BigDecimal("92233720368547758070"), + new BigDecimal("92233720368547758070"), true, 1.7976931348623157E308, 10, @@ -185,7 +190,7 @@ public void applySchemaToJSON() { "this is a simple string.")); expectedResult.add( RowFactory.create( - new java.math.BigDecimal("92233720368547758069"), + new 
BigDecimal("92233720368547758069"), false, 1.7976931348623157E305, 11, diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 2c669bb59a0b..8e0b2dbca4a9 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -17,50 +17,60 @@ package test.org.apache.spark.sql; +import java.io.Serializable; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.ArrayList; + +import scala.collection.JavaConverters; +import scala.collection.Seq; + import com.google.common.collect.ImmutableMap; import com.google.common.primitives.Ints; +import org.junit.*; +import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.*; +import static org.apache.spark.sql.functions.*; import org.apache.spark.sql.test.TestSQLContext; -import org.apache.spark.sql.test.TestSQLContext$; import org.apache.spark.sql.types.*; -import org.junit.*; - -import scala.collection.JavaConversions; -import scala.collection.Seq; - -import java.io.Serializable; -import java.util.Arrays; -import java.util.Comparator; -import java.util.List; -import java.util.Map; - -import static org.apache.spark.sql.functions.*; +import static org.apache.spark.sql.types.DataTypes.*; public class JavaDataFrameSuite { private transient JavaSparkContext jsc; - private transient SQLContext context; + private transient TestSQLContext context; @Before public void setUp() { // Trigger static initializer of TestData - TestData$.MODULE$.testData(); - jsc = new JavaSparkContext(TestSQLContext.sparkContext()); - context = TestSQLContext$.MODULE$; + SparkContext sc = new SparkContext("local[*]", "testing"); + jsc = new JavaSparkContext(sc); + context = new TestSQLContext(sc); + context.loadTestData(); } @After public void tearDown() { - jsc = null; + context.sparkContext().stop(); context = null; + jsc = null; } @Test public void testExecution() { DataFrame df = context.table("testData").filter("key = 1"); - Assert.assertEquals(df.select("key").collect()[0].get(0), 1); + Assert.assertEquals(1, df.select("key").collect()[0].get(0)); + } + + @Test + public void testCollectAndTake() { + DataFrame df = context.table("testData").filter("key = 1 or key = 2 or key = 3"); + Assert.assertEquals(3, df.select("key").collectAsList().size()); + Assert.assertEquals(2, df.select("key").takeAsList(2).size()); } /** @@ -95,7 +105,7 @@ public void testVarargMethods() { df.groupBy().agg(countDistinct("key", "value")); df.groupBy().agg(countDistinct(col("key"), col("value"))); df.select(coalesce(col("key"))); - + // Varargs with mathfunctions DataFrame df2 = context.table("testData2"); df2.select(exp("a"), exp("b")); @@ -118,7 +128,7 @@ public void testShow() { public static class Bean implements Serializable { private double a = 0.0; - private Integer[] b = new Integer[]{0, 1}; + private Integer[] b = { 0, 1 }; private Map c = ImmutableMap.of("hello", new int[] { 1, 2 }); private List d = Arrays.asList("floppy", "disk"); @@ -139,11 +149,7 @@ public List getD() { } } - @Test - public void testCreateDataFrameFromJavaBeans() { - Bean bean = new Bean(); - JavaRDD rdd = jsc.parallelize(Arrays.asList(bean)); - DataFrame df = context.createDataFrame(rdd, Bean.class); + void validateDataFrameWithBeans(Bean bean, 
DataFrame df) { StructType schema = df.schema(); Assert.assertEquals(new StructField("a", DoubleType$.MODULE$, false, Metadata.empty()), schema.apply("a")); @@ -160,17 +166,18 @@ public void testCreateDataFrameFromJavaBeans() { schema.apply("d")); Row first = df.select("a", "b", "c", "d").first(); Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0); - // Now Java lists and maps are converetd to Scala Seq's and Map's. Once we get a Seq below, + // Now Java lists and maps are converted to Scala Seq's and Map's. Once we get a Seq below, // verify that it has the expected length, and contains expected elements. Seq result = first.getAs(1); Assert.assertEquals(bean.getB().length, result.length()); for (int i = 0; i < result.length(); i++) { Assert.assertEquals(bean.getB()[i], result.apply(i)); } + @SuppressWarnings("unchecked") Seq outputBuffer = (Seq) first.getJavaMap(2).get("hello"); Assert.assertArrayEquals( bean.getC().get("hello"), - Ints.toArray(JavaConversions.seqAsJavaList(outputBuffer))); + Ints.toArray(JavaConverters.seqAsJavaListConverter(outputBuffer).asJava())); Seq d = first.getAs(3); Assert.assertEquals(bean.getD().size(), d.length()); for (int i = 0; i < d.length(); i++) { @@ -178,7 +185,45 @@ public void testCreateDataFrameFromJavaBeans() { } } - private static Comparator CrosstabRowComparator = new Comparator() { + @Test + public void testCreateDataFrameFromLocalJavaBeans() { + Bean bean = new Bean(); + List data = Arrays.asList(bean); + DataFrame df = context.createDataFrame(data, Bean.class); + validateDataFrameWithBeans(bean, df); + } + + @Test + public void testCreateDataFrameFromJavaBeans() { + Bean bean = new Bean(); + JavaRDD rdd = jsc.parallelize(Arrays.asList(bean)); + DataFrame df = context.createDataFrame(rdd, Bean.class); + validateDataFrameWithBeans(bean, df); + } + + @Test + public void testCreateDataFromFromList() { + StructType schema = createStructType(Arrays.asList(createStructField("i", IntegerType, true))); + List rows = Arrays.asList(RowFactory.create(0)); + DataFrame df = context.createDataFrame(rows, schema); + Row[] result = df.collect(); + Assert.assertEquals(1, result.length); + } + + @Test + public void testCreateStructTypeFromList(){ + List fields1 = new ArrayList<>(); + fields1.add(new StructField("id", DataTypes.StringType, true, Metadata.empty())); + StructType schema1 = StructType$.MODULE$.apply(fields1); + Assert.assertEquals(0, schema1.fieldIndex("id")); + + List fields2 = Arrays.asList(new StructField("id", DataTypes.StringType, true, Metadata.empty())); + StructType schema2 = StructType$.MODULE$.apply(fields2); + Assert.assertEquals(0, schema2.fieldIndex("id")); + } + + private static final Comparator crosstabRowComparator = new Comparator() { + @Override public int compare(Row row1, Row row2) { String item1 = row1.getString(0); String item2 = row2.getString(0); @@ -191,24 +236,24 @@ public void testCrosstab() { DataFrame df = context.table("testData2"); DataFrame crosstab = df.stat().crosstab("a", "b"); String[] columnNames = crosstab.schema().fieldNames(); - Assert.assertEquals(columnNames[0], "a_b"); - Assert.assertEquals(columnNames[1], "1"); - Assert.assertEquals(columnNames[2], "2"); + Assert.assertEquals("a_b", columnNames[0]); + Assert.assertEquals("1", columnNames[1]); + Assert.assertEquals("2", columnNames[2]); Row[] rows = crosstab.collect(); - Arrays.sort(rows, CrosstabRowComparator); + Arrays.sort(rows, crosstabRowComparator); Integer count = 1; for (Row row : rows) { Assert.assertEquals(row.get(0).toString(), 
count.toString()); - Assert.assertEquals(row.getLong(1), 1L); - Assert.assertEquals(row.getLong(2), 1L); + Assert.assertEquals(1L, row.getLong(1)); + Assert.assertEquals(1L, row.getLong(2)); count++; } } - + @Test public void testFrequentItems() { DataFrame df = context.table("testData2"); - String[] cols = new String[]{"a"}; + String[] cols = {"a"}; DataFrame results = df.stat().freqItems(cols, 0.2); Assert.assertTrue(results.collect()[0].getSeq(0).contains(1)); } @@ -217,22 +262,63 @@ public void testFrequentItems() { public void testCorrelation() { DataFrame df = context.table("testData2"); Double pearsonCorr = df.stat().corr("a", "b", "pearson"); - Assert.assertTrue(Math.abs(pearsonCorr) < 1e-6); + Assert.assertTrue(Math.abs(pearsonCorr) < 1.0e-6); } @Test public void testCovariance() { DataFrame df = context.table("testData2"); Double result = df.stat().cov("a", "b"); - Assert.assertTrue(Math.abs(result) < 1e-6); + Assert.assertTrue(Math.abs(result) < 1.0e-6); } @Test public void testSampleBy() { - DataFrame df = context.range(0, 100).select(col("id").mod(3).as("key")); + DataFrame df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key")); DataFrame sampled = df.stat().sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L); Row[] actual = sampled.groupBy("key").count().orderBy("key").collect(); - Row[] expected = new Row[] {RowFactory.create(0, 5), RowFactory.create(1, 8)}; - Assert.assertArrayEquals(expected, actual); + Assert.assertEquals(0, actual[0].getLong(0)); + Assert.assertTrue(0 <= actual[0].getLong(1) && actual[0].getLong(1) <= 8); + Assert.assertEquals(1, actual[1].getLong(0)); + Assert.assertTrue(2 <= actual[1].getLong(1) && actual[1].getLong(1) <= 13); + } + + @Test + public void pivot() { + DataFrame df = context.table("courseSales"); + Row[] actual = df.groupBy("year") + .pivot("course", Arrays.asList("dotNET", "Java")) + .agg(sum("earnings")).orderBy("year").collect(); + + Assert.assertEquals(2012, actual[0].getInt(0)); + Assert.assertEquals(15000.0, actual[0].getDouble(1), 0.01); + Assert.assertEquals(20000.0, actual[0].getDouble(2), 0.01); + + Assert.assertEquals(2013, actual[1].getInt(0)); + Assert.assertEquals(48000.0, actual[1].getDouble(1), 0.01); + Assert.assertEquals(30000.0, actual[1].getDouble(2), 0.01); + } + + public void testGenericLoad() { + DataFrame df1 = context.read().format("text").load( + Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString()); + Assert.assertEquals(4L, df1.count()); + + DataFrame df2 = context.read().format("text").load( + Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString(), + Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString()); + Assert.assertEquals(5L, df2.count()); + } + + @Test + public void testTextLoad() { + DataFrame df1 = context.read().text( + Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString()); + Assert.assertEquals(4L, df1.count()); + + DataFrame df2 = context.read().text( + Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString(), + Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString()); + Assert.assertEquals(5L, df2.count()); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java new file mode 100644 index 000000000000..383a2d0badb5 --- /dev/null +++ 
b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -0,0 +1,692 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql; + +import java.io.Serializable; +import java.math.BigDecimal; +import java.sql.Date; +import java.sql.Timestamp; +import java.util.*; + +import scala.Tuple2; +import scala.Tuple3; +import scala.Tuple4; +import scala.Tuple5; + +import org.junit.*; + +import org.apache.spark.Accumulator; +import org.apache.spark.SparkContext; +import org.apache.spark.api.java.function.*; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.*; +import org.apache.spark.sql.expressions.Aggregator; +import org.apache.spark.sql.test.TestSQLContext; +import org.apache.spark.sql.catalyst.encoders.OuterScopes; +import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.types.StructType; + +import static org.apache.spark.sql.functions.*; +import static org.apache.spark.sql.types.DataTypes.*; + +public class JavaDatasetSuite implements Serializable { + private transient JavaSparkContext jsc; + private transient TestSQLContext context; + + @Before + public void setUp() { + // Trigger static initializer of TestData + SparkContext sc = new SparkContext("local[*]", "testing"); + jsc = new JavaSparkContext(sc); + context = new TestSQLContext(sc); + context.loadTestData(); + } + + @After + public void tearDown() { + context.sparkContext().stop(); + context = null; + jsc = null; + } + + private Tuple2 tuple2(T1 t1, T2 t2) { + return new Tuple2(t1, t2); + } + + @Test + public void testCollect() { + List data = Arrays.asList("hello", "world"); + Dataset ds = context.createDataset(data, Encoders.STRING()); + List collected = ds.collectAsList(); + Assert.assertEquals(Arrays.asList("hello", "world"), collected); + } + + @Test + public void testTake() { + List data = Arrays.asList("hello", "world"); + Dataset ds = context.createDataset(data, Encoders.STRING()); + List collected = ds.takeAsList(1); + Assert.assertEquals(Arrays.asList("hello"), collected); + } + + @Test + public void testCommonOperation() { + List data = Arrays.asList("hello", "world"); + Dataset ds = context.createDataset(data, Encoders.STRING()); + Assert.assertEquals("hello", ds.first()); + + Dataset filtered = ds.filter(new FilterFunction() { + @Override + public boolean call(String v) throws Exception { + return v.startsWith("h"); + } + }); + Assert.assertEquals(Arrays.asList("hello"), filtered.collectAsList()); + + + Dataset mapped = ds.map(new MapFunction() { + @Override + public Integer call(String v) throws Exception { + return v.length(); + } + }, Encoders.INT()); + Assert.assertEquals(Arrays.asList(5, 5), mapped.collectAsList()); + + Dataset parMapped = ds.mapPartitions(new 
MapPartitionsFunction() { + @Override + public Iterable call(Iterator it) throws Exception { + List ls = new LinkedList(); + while (it.hasNext()) { + ls.add(it.next().toUpperCase()); + } + return ls; + } + }, Encoders.STRING()); + Assert.assertEquals(Arrays.asList("HELLO", "WORLD"), parMapped.collectAsList()); + + Dataset flatMapped = ds.flatMap(new FlatMapFunction() { + @Override + public Iterable call(String s) throws Exception { + List ls = new LinkedList(); + for (char c : s.toCharArray()) { + ls.add(String.valueOf(c)); + } + return ls; + } + }, Encoders.STRING()); + Assert.assertEquals( + Arrays.asList("h", "e", "l", "l", "o", "w", "o", "r", "l", "d"), + flatMapped.collectAsList()); + } + + @Test + public void testForeach() { + final Accumulator accum = jsc.accumulator(0); + List data = Arrays.asList("a", "b", "c"); + Dataset ds = context.createDataset(data, Encoders.STRING()); + + ds.foreach(new ForeachFunction() { + @Override + public void call(String s) throws Exception { + accum.add(1); + } + }); + Assert.assertEquals(3, accum.value().intValue()); + } + + @Test + public void testReduce() { + List data = Arrays.asList(1, 2, 3); + Dataset ds = context.createDataset(data, Encoders.INT()); + + int reduced = ds.reduce(new ReduceFunction() { + @Override + public Integer call(Integer v1, Integer v2) throws Exception { + return v1 + v2; + } + }); + Assert.assertEquals(6, reduced); + } + + @Test + public void testGroupBy() { + List data = Arrays.asList("a", "foo", "bar"); + Dataset ds = context.createDataset(data, Encoders.STRING()); + GroupedDataset grouped = ds.groupBy(new MapFunction() { + @Override + public Integer call(String v) throws Exception { + return v.length(); + } + }, Encoders.INT()); + + Dataset mapped = grouped.mapGroups(new MapGroupsFunction() { + @Override + public String call(Integer key, Iterator values) throws Exception { + StringBuilder sb = new StringBuilder(key.toString()); + while (values.hasNext()) { + sb.append(values.next()); + } + return sb.toString(); + } + }, Encoders.STRING()); + + Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList()); + + Dataset flatMapped = grouped.flatMapGroups( + new FlatMapGroupsFunction() { + @Override + public Iterable call(Integer key, Iterator values) throws Exception { + StringBuilder sb = new StringBuilder(key.toString()); + while (values.hasNext()) { + sb.append(values.next()); + } + return Collections.singletonList(sb.toString()); + } + }, + Encoders.STRING()); + + Assert.assertEquals(Arrays.asList("1a", "3foobar"), flatMapped.collectAsList()); + + Dataset> reduced = grouped.reduce(new ReduceFunction() { + @Override + public String call(String v1, String v2) throws Exception { + return v1 + v2; + } + }); + + Assert.assertEquals( + Arrays.asList(tuple2(1, "a"), tuple2(3, "foobar")), + reduced.collectAsList()); + + List data2 = Arrays.asList(2, 6, 10); + Dataset ds2 = context.createDataset(data2, Encoders.INT()); + GroupedDataset grouped2 = ds2.groupBy(new MapFunction() { + @Override + public Integer call(Integer v) throws Exception { + return v / 2; + } + }, Encoders.INT()); + + Dataset cogrouped = grouped.cogroup( + grouped2, + new CoGroupFunction() { + @Override + public Iterable call( + Integer key, + Iterator left, + Iterator right) throws Exception { + StringBuilder sb = new StringBuilder(key.toString()); + while (left.hasNext()) { + sb.append(left.next()); + } + sb.append("#"); + while (right.hasNext()) { + sb.append(right.next()); + } + return Collections.singletonList(sb.toString()); + } + }, + 
Encoders.STRING()); + + Assert.assertEquals(Arrays.asList("1a#2", "3foobar#6", "5#10"), cogrouped.collectAsList()); + } + + @Test + public void testGroupByColumn() { + List data = Arrays.asList("a", "foo", "bar"); + Dataset ds = context.createDataset(data, Encoders.STRING()); + GroupedDataset grouped = + ds.groupBy(length(col("value"))).keyAs(Encoders.INT()); + + Dataset mapped = grouped.mapGroups( + new MapGroupsFunction() { + @Override + public String call(Integer key, Iterator data) throws Exception { + StringBuilder sb = new StringBuilder(key.toString()); + while (data.hasNext()) { + sb.append(data.next()); + } + return sb.toString(); + } + }, + Encoders.STRING()); + + Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList()); + } + + @Test + public void testSelect() { + List data = Arrays.asList(2, 6); + Dataset ds = context.createDataset(data, Encoders.INT()); + + Dataset> selected = ds.select( + expr("value + 1"), + col("value").cast("string")).as(Encoders.tuple(Encoders.INT(), Encoders.STRING())); + + Assert.assertEquals( + Arrays.asList(tuple2(3, "2"), tuple2(7, "6")), + selected.collectAsList()); + } + + @Test + public void testSetOperation() { + List data = Arrays.asList("abc", "abc", "xyz"); + Dataset ds = context.createDataset(data, Encoders.STRING()); + + Assert.assertEquals( + Arrays.asList("abc", "xyz"), + sort(ds.distinct().collectAsList().toArray(new String[0]))); + + List data2 = Arrays.asList("xyz", "foo", "foo"); + Dataset ds2 = context.createDataset(data2, Encoders.STRING()); + + Dataset intersected = ds.intersect(ds2); + Assert.assertEquals(Arrays.asList("xyz"), intersected.collectAsList()); + + Dataset unioned = ds.union(ds2); + Assert.assertEquals( + Arrays.asList("abc", "abc", "foo", "foo", "xyz", "xyz"), + sort(unioned.collectAsList().toArray(new String[0]))); + + Dataset subtracted = ds.subtract(ds2); + Assert.assertEquals(Arrays.asList("abc", "abc"), subtracted.collectAsList()); + } + + private > List sort(T[] data) { + Arrays.sort(data); + return Arrays.asList(data); + } + + @Test + public void testJoin() { + List data = Arrays.asList(1, 2, 3); + Dataset ds = context.createDataset(data, Encoders.INT()).as("a"); + List data2 = Arrays.asList(2, 3, 4); + Dataset ds2 = context.createDataset(data2, Encoders.INT()).as("b"); + + Dataset> joined = + ds.joinWith(ds2, col("a.value").equalTo(col("b.value"))); + Assert.assertEquals( + Arrays.asList(tuple2(2, 2), tuple2(3, 3)), + joined.collectAsList()); + } + + @Test + public void testTupleEncoder() { + Encoder> encoder2 = Encoders.tuple(Encoders.INT(), Encoders.STRING()); + List> data2 = Arrays.asList(tuple2(1, "a"), tuple2(2, "b")); + Dataset> ds2 = context.createDataset(data2, encoder2); + Assert.assertEquals(data2, ds2.collectAsList()); + + Encoder> encoder3 = + Encoders.tuple(Encoders.INT(), Encoders.LONG(), Encoders.STRING()); + List> data3 = + Arrays.asList(new Tuple3(1, 2L, "a")); + Dataset> ds3 = context.createDataset(data3, encoder3); + Assert.assertEquals(data3, ds3.collectAsList()); + + Encoder> encoder4 = + Encoders.tuple(Encoders.INT(), Encoders.STRING(), Encoders.LONG(), Encoders.STRING()); + List> data4 = + Arrays.asList(new Tuple4(1, "b", 2L, "a")); + Dataset> ds4 = context.createDataset(data4, encoder4); + Assert.assertEquals(data4, ds4.collectAsList()); + + Encoder> encoder5 = + Encoders.tuple(Encoders.INT(), Encoders.STRING(), Encoders.LONG(), Encoders.STRING(), + Encoders.BOOLEAN()); + List> data5 = + Arrays.asList(new Tuple5(1, "b", 2L, "a", true)); + Dataset> ds5 = + 
context.createDataset(data5, encoder5); + Assert.assertEquals(data5, ds5.collectAsList()); + } + + @Test + public void testNestedTupleEncoder() { + // test ((int, string), string) + Encoder, String>> encoder = + Encoders.tuple(Encoders.tuple(Encoders.INT(), Encoders.STRING()), Encoders.STRING()); + List, String>> data = + Arrays.asList(tuple2(tuple2(1, "a"), "a"), tuple2(tuple2(2, "b"), "b")); + Dataset, String>> ds = context.createDataset(data, encoder); + Assert.assertEquals(data, ds.collectAsList()); + + // test (int, (string, string, long)) + Encoder>> encoder2 = + Encoders.tuple(Encoders.INT(), + Encoders.tuple(Encoders.STRING(), Encoders.STRING(), Encoders.LONG())); + List>> data2 = + Arrays.asList(tuple2(1, new Tuple3("a", "b", 3L))); + Dataset>> ds2 = + context.createDataset(data2, encoder2); + Assert.assertEquals(data2, ds2.collectAsList()); + + // test (int, ((string, long), string)) + Encoder, String>>> encoder3 = + Encoders.tuple(Encoders.INT(), + Encoders.tuple(Encoders.tuple(Encoders.STRING(), Encoders.LONG()), Encoders.STRING())); + List, String>>> data3 = + Arrays.asList(tuple2(1, tuple2(tuple2("a", 2L), "b"))); + Dataset, String>>> ds3 = + context.createDataset(data3, encoder3); + Assert.assertEquals(data3, ds3.collectAsList()); + } + + @Test + public void testPrimitiveEncoder() { + Encoder> encoder = + Encoders.tuple(Encoders.DOUBLE(), Encoders.DECIMAL(), Encoders.DATE(), Encoders.TIMESTAMP(), + Encoders.FLOAT()); + List> data = + Arrays.asList(new Tuple5( + 1.7976931348623157E308, new BigDecimal("0.922337203685477589"), + Date.valueOf("1970-01-01"), new Timestamp(System.currentTimeMillis()), Float.MAX_VALUE)); + Dataset> ds = + context.createDataset(data, encoder); + Assert.assertEquals(data, ds.collectAsList()); + } + + @Test + public void testTypedAggregation() { + Encoder> encoder = Encoders.tuple(Encoders.STRING(), Encoders.INT()); + List> data = + Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3)); + Dataset> ds = context.createDataset(data, encoder); + + GroupedDataset> grouped = ds.groupBy( + new MapFunction, String>() { + @Override + public String call(Tuple2 value) throws Exception { + return value._1(); + } + }, + Encoders.STRING()); + + Dataset> agged = + grouped.agg(new IntSumOf().toColumn(Encoders.INT(), Encoders.INT())); + Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList()); + + Dataset> agged2 = grouped.agg( + new IntSumOf().toColumn(Encoders.INT(), Encoders.INT())) + .as(Encoders.tuple(Encoders.STRING(), Encoders.INT())); + Assert.assertEquals( + Arrays.asList( + new Tuple2<>("a", 3), + new Tuple2<>("b", 3)), + agged2.collectAsList()); + } + + static class IntSumOf extends Aggregator, Integer, Integer> { + + @Override + public Integer zero() { + return 0; + } + + @Override + public Integer reduce(Integer l, Tuple2 t) { + return l + t._2(); + } + + @Override + public Integer merge(Integer b1, Integer b2) { + return b1 + b2; + } + + @Override + public Integer finish(Integer reduction) { + return reduction; + } + } + + public static class KryoSerializable { + String value; + + KryoSerializable(String value) { + this.value = value; + } + + @Override + public boolean equals(Object other) { + return this.value.equals(((KryoSerializable) other).value); + } + + @Override + public int hashCode() { + return this.value.hashCode(); + } + } + + public static class JavaSerializable implements Serializable { + String value; + + JavaSerializable(String value) { + this.value = value; + } + + @Override + public boolean 
equals(Object other) { + return this.value.equals(((JavaSerializable) other).value); + } + + @Override + public int hashCode() { + return this.value.hashCode(); + } + } + + @Test + public void testKryoEncoder() { + Encoder encoder = Encoders.kryo(KryoSerializable.class); + List data = Arrays.asList( + new KryoSerializable("hello"), new KryoSerializable("world")); + Dataset ds = context.createDataset(data, encoder); + Assert.assertEquals(data, ds.collectAsList()); + } + + @Test + public void testJavaEncoder() { + Encoder encoder = Encoders.javaSerialization(JavaSerializable.class); + List data = Arrays.asList( + new JavaSerializable("hello"), new JavaSerializable("world")); + Dataset ds = context.createDataset(data, encoder); + Assert.assertEquals(data, ds.collectAsList()); + } + + /** + * For testing error messages when creating an encoder on a private class. This is done + * here since we cannot create truly private classes in Scala. + */ + private static class PrivateClassTest { } + + @Test(expected = UnsupportedOperationException.class) + public void testJavaEncoderErrorMessageForPrivateClass() { + Encoders.javaSerialization(PrivateClassTest.class); + } + + @Test(expected = UnsupportedOperationException.class) + public void testKryoEncoderErrorMessageForPrivateClass() { + Encoders.kryo(PrivateClassTest.class); + } + + public class SimpleJavaBean implements Serializable { + private boolean a; + private int b; + private byte[] c; + private String[] d; + private List e; + private List f; + + public boolean isA() { + return a; + } + + public void setA(boolean a) { + this.a = a; + } + + public int getB() { + return b; + } + + public void setB(int b) { + this.b = b; + } + + public byte[] getC() { + return c; + } + + public void setC(byte[] c) { + this.c = c; + } + + public String[] getD() { + return d; + } + + public void setD(String[] d) { + this.d = d; + } + + public List getE() { + return e; + } + + public void setE(List e) { + this.e = e; + } + + public List getF() { + return f; + } + + public void setF(List f) { + this.f = f; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + SimpleJavaBean that = (SimpleJavaBean) o; + + if (a != that.a) return false; + if (b != that.b) return false; + if (!Arrays.equals(c, that.c)) return false; + if (!Arrays.equals(d, that.d)) return false; + if (!e.equals(that.e)) return false; + return f.equals(that.f); + } + + @Override + public int hashCode() { + int result = (a ? 
1 : 0); + result = 31 * result + b; + result = 31 * result + Arrays.hashCode(c); + result = 31 * result + Arrays.hashCode(d); + result = 31 * result + e.hashCode(); + result = 31 * result + f.hashCode(); + return result; + } + } + + public class NestedJavaBean implements Serializable { + private SimpleJavaBean a; + + public SimpleJavaBean getA() { + return a; + } + + public void setA(SimpleJavaBean a) { + this.a = a; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + NestedJavaBean that = (NestedJavaBean) o; + + return a.equals(that.a); + } + + @Override + public int hashCode() { + return a.hashCode(); + } + } + + @Test + public void testJavaBeanEncoder() { + OuterScopes.addOuterScope(this); + SimpleJavaBean obj1 = new SimpleJavaBean(); + obj1.setA(true); + obj1.setB(3); + obj1.setC(new byte[]{1, 2}); + obj1.setD(new String[]{"hello", null}); + obj1.setE(Arrays.asList("a", "b")); + obj1.setF(Arrays.asList(100L, null, 200L)); + SimpleJavaBean obj2 = new SimpleJavaBean(); + obj2.setA(false); + obj2.setB(30); + obj2.setC(new byte[]{3, 4}); + obj2.setD(new String[]{null, "world"}); + obj2.setE(Arrays.asList("x", "y")); + obj2.setF(Arrays.asList(300L, null, 400L)); + + List data = Arrays.asList(obj1, obj2); + Dataset ds = context.createDataset(data, Encoders.bean(SimpleJavaBean.class)); + Assert.assertEquals(data, ds.collectAsList()); + + NestedJavaBean obj3 = new NestedJavaBean(); + obj3.setA(obj1); + + List data2 = Arrays.asList(obj3); + Dataset ds2 = context.createDataset(data2, Encoders.bean(NestedJavaBean.class)); + Assert.assertEquals(data2, ds2.collectAsList()); + + Row row1 = new GenericRow(new Object[]{ + true, + 3, + new byte[]{1, 2}, + new String[]{"hello", null}, + Arrays.asList("a", "b"), + Arrays.asList(100L, null, 200L)}); + Row row2 = new GenericRow(new Object[]{ + false, + 30, + new byte[]{3, 4}, + new String[]{null, "world"}, + Arrays.asList("x", "y"), + Arrays.asList(300L, null, 400L)}); + StructType schema = new StructType() + .add("a", BooleanType, false) + .add("b", IntegerType, false) + .add("c", BinaryType) + .add("d", createArrayType(StringType)) + .add("e", createArrayType(StringType)) + .add("f", createArrayType(LongType)); + Dataset ds3 = context.createDataFrame(Arrays.asList(row1, row2), schema) + .as(Encoders.bean(SimpleJavaBean.class)); + Assert.assertEquals(data, ds3.collectAsList()); + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java index 4ce1d1dddb26..3ab4db2a035d 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java @@ -18,6 +18,7 @@ package test.org.apache.spark.sql; import java.math.BigDecimal; +import java.nio.charset.StandardCharsets; import java.sql.Date; import java.sql.Timestamp; import java.util.Arrays; @@ -52,12 +53,12 @@ public void setUp() { shortValue = (short)32767; intValue = 2147483647; longValue = 9223372036854775807L; - floatValue = (float)3.4028235E38; + floatValue = 3.4028235E38f; doubleValue = 1.7976931348623157E308; decimalValue = new BigDecimal("1.7976931348623157E328"); booleanValue = true; stringValue = "this is a string"; - binaryValue = stringValue.getBytes(); + binaryValue = stringValue.getBytes(StandardCharsets.UTF_8); dateValue = Date.valueOf("2014-06-30"); timestampValue = Timestamp.valueOf("2014-06-30 09:20:00.0"); } @@ -123,8 
+124,8 @@ public void constructSimpleRow() { Assert.assertEquals(binaryValue, simpleRow.get(16)); Assert.assertEquals(dateValue, simpleRow.get(17)); Assert.assertEquals(timestampValue, simpleRow.get(18)); - Assert.assertEquals(true, simpleRow.isNullAt(19)); - Assert.assertEquals(null, simpleRow.get(19)); + Assert.assertTrue(simpleRow.isNullAt(19)); + Assert.assertNull(simpleRow.get(19)); } @Test @@ -134,7 +135,7 @@ public void constructComplexRow() { stringValue + " (1)", stringValue + " (2)", stringValue + "(3)"); // Simple map - Map simpleMap = new HashMap(); + Map simpleMap = new HashMap<>(); simpleMap.put(stringValue + " (1)", longValue); simpleMap.put(stringValue + " (2)", longValue - 1); simpleMap.put(stringValue + " (3)", longValue - 2); @@ -149,7 +150,7 @@ public void constructComplexRow() { List arrayOfRows = Arrays.asList(simpleStruct); // Complex map - Map, Row> complexMap = new HashMap, Row>(); + Map, Row> complexMap = new HashMap<>(); complexMap.put(arrayOfRows, simpleStruct); // Complex struct @@ -167,7 +168,7 @@ public void constructComplexRow() { Assert.assertEquals(arrayOfMaps, complexStruct.get(3)); Assert.assertEquals(arrayOfRows, complexStruct.get(4)); Assert.assertEquals(complexMap, complexStruct.get(5)); - Assert.assertEquals(null, complexStruct.get(6)); + Assert.assertNull(complexStruct.get(6)); // A very complex row Row complexRow = RowFactory.create(arrayOfMaps, arrayOfRows, complexMap, complexStruct); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java index 79d92734ff37..4a78dca7fea6 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java @@ -20,15 +20,16 @@ import java.io.Serializable; import org.junit.After; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import org.apache.spark.SparkContext; import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.api.java.UDF1; import org.apache.spark.sql.api.java.UDF2; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.test.TestSQLContext$; import org.apache.spark.sql.types.DataTypes; // The test suite itself is Serializable so that anonymous Function implementations can be @@ -40,12 +41,16 @@ public class JavaUDFSuite implements Serializable { @Before public void setUp() { - sqlContext = TestSQLContext$.MODULE$; - sc = new JavaSparkContext(sqlContext.sparkContext()); + SparkContext _sc = new SparkContext("local[*]", "testing"); + sqlContext = new SQLContext(_sc); + sc = new JavaSparkContext(_sc); } @After public void tearDown() { + sqlContext.sparkContext().stop(); + sqlContext = null; + sc = null; } @SuppressWarnings("unchecked") @@ -57,13 +62,13 @@ public void udf1Test() { sqlContext.udf().register("stringLengthTest", new UDF1() { @Override - public Integer call(String str) throws Exception { + public Integer call(String str) { return str.length(); } }, DataTypes.IntegerType); Row result = sqlContext.sql("SELECT stringLengthTest('test')").head(); - assert(result.getInt(0) == 4); + Assert.assertEquals(4, result.getInt(0)); } @SuppressWarnings("unchecked") @@ -77,12 +82,12 @@ public void udf2Test() { sqlContext.udf().register("stringLengthTest", new UDF2() { @Override - public Integer call(String str1, String str2) throws Exception { + public Integer call(String str1, String str2) { return str1.length() + str2.length(); } }, 
DataTypes.IntegerType); Row result = sqlContext.sql("SELECT stringLengthTest('test', 'test2')").head(); - assert(result.getInt(0) == 9); + Assert.assertEquals(9, result.getInt(0)); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java index 2706e01bd28a..9e241f20987c 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java @@ -21,13 +21,14 @@ import java.io.IOException; import java.util.*; +import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.test.TestSQLContext$; import org.apache.spark.sql.*; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; @@ -43,7 +44,7 @@ public class JavaSaveLoadSuite { File path; DataFrame df; - private void checkAnswer(DataFrame actual, List expected) { + private static void checkAnswer(DataFrame actual, List expected) { String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); if (errorMessage != null) { Assert.fail(errorMessage); @@ -52,8 +53,9 @@ private void checkAnswer(DataFrame actual, List expected) { @Before public void setUp() throws IOException { - sqlContext = TestSQLContext$.MODULE$; - sc = new JavaSparkContext(sqlContext.sparkContext()); + SparkContext _sc = new SparkContext("local[*]", "testing"); + sqlContext = new SQLContext(_sc); + sc = new JavaSparkContext(_sc); originalDefaultSource = sqlContext.conf().defaultDataSourceName(); path = @@ -62,7 +64,7 @@ public void setUp() throws IOException { path.delete(); } - List jsonObjects = new ArrayList(10); + List jsonObjects = new ArrayList<>(10); for (int i = 0; i < 10; i++) { jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}"); } @@ -71,9 +73,16 @@ public void setUp() throws IOException { df.registerTempTable("jsonTable"); } + @After + public void tearDown() { + sqlContext.sparkContext().stop(); + sqlContext = null; + sc = null; + } + @Test public void saveAndLoad() { - Map options = new HashMap(); + Map options = new HashMap<>(); options.put("path", path.toString()); df.write().mode(SaveMode.ErrorIfExists).format("json").options(options).save(); DataFrame loadedDF = sqlContext.read().format("json").options(options).load(); @@ -82,11 +91,11 @@ public void saveAndLoad() { @Test public void saveAndLoadWithSchema() { - Map options = new HashMap(); + Map options = new HashMap<>(); options.put("path", path.toString()); df.write().format("json").mode(SaveMode.ErrorIfExists).options(options).save(); - List fields = new ArrayList(); + List fields = new ArrayList<>(); fields.add(DataTypes.createStructField("b", DataTypes.StringType, true)); StructType schema = DataTypes.createStructType(fields); DataFrame loadedDF = sqlContext.read().format("json").schema(schema).options(options).load(); diff --git a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 000000000000..cfd7889b4ac2 --- /dev/null +++ b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1,3 @@ 
+org.apache.spark.sql.sources.FakeSourceOne +org.apache.spark.sql.sources.FakeSourceTwo +org.apache.spark.sql.sources.FakeSourceThree diff --git a/sql/core/src/test/resources/dec-in-fixed-len.parquet b/sql/core/src/test/resources/dec-in-fixed-len.parquet new file mode 100644 index 000000000000..6ad37d563951 Binary files /dev/null and b/sql/core/src/test/resources/dec-in-fixed-len.parquet differ diff --git a/sql/core/src/test/resources/dec-in-i32.parquet b/sql/core/src/test/resources/dec-in-i32.parquet new file mode 100755 index 000000000000..bb5d4af8dd36 Binary files /dev/null and b/sql/core/src/test/resources/dec-in-i32.parquet differ diff --git a/sql/core/src/test/resources/dec-in-i64.parquet b/sql/core/src/test/resources/dec-in-i64.parquet new file mode 100755 index 000000000000..e07c4a0ad984 Binary files /dev/null and b/sql/core/src/test/resources/dec-in-i64.parquet differ diff --git a/sql/core/src/test/resources/nested-array-struct.parquet b/sql/core/src/test/resources/nested-array-struct.parquet new file mode 100644 index 000000000000..41a43fa35d39 Binary files /dev/null and b/sql/core/src/test/resources/nested-array-struct.parquet differ diff --git a/sql/core/src/test/resources/old-repeated-int.parquet b/sql/core/src/test/resources/old-repeated-int.parquet new file mode 100644 index 000000000000..520922f73ebb Binary files /dev/null and b/sql/core/src/test/resources/old-repeated-int.parquet differ diff --git a/sql/core/src/test/resources/old-repeated-message.parquet b/sql/core/src/test/resources/old-repeated-message.parquet new file mode 100644 index 000000000000..548db9916277 Binary files /dev/null and b/sql/core/src/test/resources/old-repeated-message.parquet differ diff --git a/sql/core/src/test/resources/old-repeated.parquet b/sql/core/src/test/resources/old-repeated.parquet new file mode 100644 index 000000000000..213f1a90291b Binary files /dev/null and b/sql/core/src/test/resources/old-repeated.parquet differ diff --git a/sql/core/src/test/resources/parquet-thrift-compat.snappy.parquet b/sql/core/src/test/resources/parquet-thrift-compat.snappy.parquet old mode 100755 new mode 100644 diff --git a/sql/core/src/test/resources/proto-repeated-string.parquet b/sql/core/src/test/resources/proto-repeated-string.parquet new file mode 100644 index 000000000000..8a7eea601d01 Binary files /dev/null and b/sql/core/src/test/resources/proto-repeated-string.parquet differ diff --git a/sql/core/src/test/resources/proto-repeated-struct.parquet b/sql/core/src/test/resources/proto-repeated-struct.parquet new file mode 100644 index 000000000000..c29eee35c350 Binary files /dev/null and b/sql/core/src/test/resources/proto-repeated-struct.parquet differ diff --git a/sql/core/src/test/resources/proto-struct-with-array-many.parquet b/sql/core/src/test/resources/proto-struct-with-array-many.parquet new file mode 100644 index 000000000000..ff9809675fc0 Binary files /dev/null and b/sql/core/src/test/resources/proto-struct-with-array-many.parquet differ diff --git a/sql/core/src/test/resources/proto-struct-with-array.parquet b/sql/core/src/test/resources/proto-struct-with-array.parquet new file mode 100644 index 000000000000..325a8370ad20 Binary files /dev/null and b/sql/core/src/test/resources/proto-struct-with-array.parquet differ diff --git a/sql/core/src/test/resources/text-suite.txt b/sql/core/src/test/resources/text-suite.txt new file mode 100644 index 000000000000..e8fd967197fe --- /dev/null +++ b/sql/core/src/test/resources/text-suite.txt @@ -0,0 +1,4 @@ +This is a test file for the text data source 
+1+1 +数据砖头 +"doh" diff --git a/sql/core/src/test/resources/text-suite2.txt b/sql/core/src/test/resources/text-suite2.txt new file mode 100644 index 000000000000..f9d498c80493 --- /dev/null +++ b/sql/core/src/test/resources/text-suite2.txt @@ -0,0 +1 @@ +This is another file for testing multi path loading. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index eb3e91332206..d86df4cfb9b4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -17,27 +17,28 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException +import org.apache.spark.sql.execution.Exchange +import org.apache.spark.sql.execution.PhysicalRDD + import scala.concurrent.duration._ -import scala.language.{implicitConversions, postfixOps} +import scala.language.postfixOps import org.scalatest.concurrent.Eventually._ import org.apache.spark.Accumulators -import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.columnar._ +import org.apache.spark.sql.execution.columnar._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.{SQLTestUtils, SharedSQLContext} import org.apache.spark.storage.{StorageLevel, RDDBlockId} -case class BigData(s: String) - -class CachedTableSuite extends QueryTest { - TestData // Load test tables. +private case class BigData(s: String) - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - import ctx.sql +class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext { + import testImplicits._ def rddIdOf(tableName: String): Int = { - val executedPlan = ctx.table(tableName).queryExecution.executedPlan + val executedPlan = sqlContext.table(tableName).queryExecution.executedPlan executedPlan.collect { case InMemoryColumnarTableScan(_, _, relation) => relation.cachedColumnBuffers.id @@ -47,47 +48,66 @@ class CachedTableSuite extends QueryTest { } def isMaterialized(rddId: Int): Boolean = { - ctx.sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty + sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty + } + + test("withColumn doesn't invalidate cached dataframe") { + var evalCount = 0 + val myUDF = udf((x: String) => { evalCount += 1; "result" }) + val df = Seq(("test", 1)).toDF("s", "i").select(myUDF($"s")) + df.cache() + + df.collect() + assert(evalCount === 1) + + df.collect() + assert(evalCount === 1) + + val df2 = df.withColumn("newColumn", lit(1)) + df2.collect() + + // We should not reevaluate the cached dataframe + assert(evalCount === 1) } test("cache temp table") { testData.select('key).registerTempTable("tempTable") assertCached(sql("SELECT COUNT(*) FROM tempTable"), 0) - ctx.cacheTable("tempTable") + sqlContext.cacheTable("tempTable") assertCached(sql("SELECT COUNT(*) FROM tempTable")) - ctx.uncacheTable("tempTable") + sqlContext.uncacheTable("tempTable") } test("unpersist an uncached table will not raise exception") { - assert(None == ctx.cacheManager.lookupCachedData(testData)) + assert(None == sqlContext.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = true) - assert(None == ctx.cacheManager.lookupCachedData(testData)) + assert(None == sqlContext.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = false) - assert(None == ctx.cacheManager.lookupCachedData(testData)) + assert(None == 
sqlContext.cacheManager.lookupCachedData(testData)) testData.persist() - assert(None != ctx.cacheManager.lookupCachedData(testData)) + assert(None != sqlContext.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = true) - assert(None == ctx.cacheManager.lookupCachedData(testData)) + assert(None == sqlContext.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = false) - assert(None == ctx.cacheManager.lookupCachedData(testData)) + assert(None == sqlContext.cacheManager.lookupCachedData(testData)) } test("cache table as select") { sql("CACHE TABLE tempTable AS SELECT key FROM testData") assertCached(sql("SELECT COUNT(*) FROM tempTable")) - ctx.uncacheTable("tempTable") + sqlContext.uncacheTable("tempTable") } test("uncaching temp table") { testData.select('key).registerTempTable("tempTable1") testData.select('key).registerTempTable("tempTable2") - ctx.cacheTable("tempTable1") + sqlContext.cacheTable("tempTable1") assertCached(sql("SELECT COUNT(*) FROM tempTable1")) assertCached(sql("SELECT COUNT(*) FROM tempTable2")) // Is this valid? - ctx.uncacheTable("tempTable2") + sqlContext.uncacheTable("tempTable2") // Should this be cached? assertCached(sql("SELECT COUNT(*) FROM tempTable1"), 0) @@ -95,103 +115,103 @@ class CachedTableSuite extends QueryTest { test("too big for memory") { val data = "*" * 1000 - ctx.sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF() + sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF() .registerTempTable("bigData") - ctx.table("bigData").persist(StorageLevel.MEMORY_AND_DISK) - assert(ctx.table("bigData").count() === 200000L) - ctx.table("bigData").unpersist(blocking = true) + sqlContext.table("bigData").persist(StorageLevel.MEMORY_AND_DISK) + assert(sqlContext.table("bigData").count() === 200000L) + sqlContext.table("bigData").unpersist(blocking = true) } test("calling .cache() should use in-memory columnar caching") { - ctx.table("testData").cache() - assertCached(ctx.table("testData")) - ctx.table("testData").unpersist(blocking = true) + sqlContext.table("testData").cache() + assertCached(sqlContext.table("testData")) + sqlContext.table("testData").unpersist(blocking = true) } test("calling .unpersist() should drop in-memory columnar cache") { - ctx.table("testData").cache() - ctx.table("testData").count() - ctx.table("testData").unpersist(blocking = true) - assertCached(ctx.table("testData"), 0) + sqlContext.table("testData").cache() + sqlContext.table("testData").count() + sqlContext.table("testData").unpersist(blocking = true) + assertCached(sqlContext.table("testData"), 0) } test("isCached") { - ctx.cacheTable("testData") + sqlContext.cacheTable("testData") - assertCached(ctx.table("testData")) - assert(ctx.table("testData").queryExecution.withCachedData match { + assertCached(sqlContext.table("testData")) + assert(sqlContext.table("testData").queryExecution.withCachedData match { case _: InMemoryRelation => true case _ => false }) - ctx.uncacheTable("testData") - assert(!ctx.isCached("testData")) - assert(ctx.table("testData").queryExecution.withCachedData match { + sqlContext.uncacheTable("testData") + assert(!sqlContext.isCached("testData")) + assert(sqlContext.table("testData").queryExecution.withCachedData match { case _: InMemoryRelation => false case _ => true }) } test("SPARK-1669: cacheTable should be idempotent") { - assume(!ctx.table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) + assume(!sqlContext.table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) - 
ctx.cacheTable("testData") - assertCached(ctx.table("testData")) + sqlContext.cacheTable("testData") + assertCached(sqlContext.table("testData")) assertResult(1, "InMemoryRelation not found, testData should have been cached") { - ctx.table("testData").queryExecution.withCachedData.collect { + sqlContext.table("testData").queryExecution.withCachedData.collect { case r: InMemoryRelation => r }.size } - ctx.cacheTable("testData") + sqlContext.cacheTable("testData") assertResult(0, "Double InMemoryRelations found, cacheTable() is not idempotent") { - ctx.table("testData").queryExecution.withCachedData.collect { + sqlContext.table("testData").queryExecution.withCachedData.collect { case r @ InMemoryRelation(_, _, _, _, _: InMemoryColumnarTableScan, _) => r }.size } - ctx.uncacheTable("testData") + sqlContext.uncacheTable("testData") } test("read from cached table and uncache") { - ctx.cacheTable("testData") - checkAnswer(ctx.table("testData"), testData.collect().toSeq) - assertCached(ctx.table("testData")) + sqlContext.cacheTable("testData") + checkAnswer(sqlContext.table("testData"), testData.collect().toSeq) + assertCached(sqlContext.table("testData")) - ctx.uncacheTable("testData") - checkAnswer(ctx.table("testData"), testData.collect().toSeq) - assertCached(ctx.table("testData"), 0) + sqlContext.uncacheTable("testData") + checkAnswer(sqlContext.table("testData"), testData.collect().toSeq) + assertCached(sqlContext.table("testData"), 0) } test("correct error on uncache of non-cached table") { intercept[IllegalArgumentException] { - ctx.uncacheTable("testData") + sqlContext.uncacheTable("testData") } } test("SELECT star from cached table") { sql("SELECT * FROM testData").registerTempTable("selectStar") - ctx.cacheTable("selectStar") + sqlContext.cacheTable("selectStar") checkAnswer( sql("SELECT * FROM selectStar WHERE key = 1"), Seq(Row(1, "1"))) - ctx.uncacheTable("selectStar") + sqlContext.uncacheTable("selectStar") } test("Self-join cached") { val unCachedAnswer = sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key").collect() - ctx.cacheTable("testData") + sqlContext.cacheTable("testData") checkAnswer( sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key"), unCachedAnswer.toSeq) - ctx.uncacheTable("testData") + sqlContext.uncacheTable("testData") } test("'CACHE TABLE' and 'UNCACHE TABLE' SQL statement") { sql("CACHE TABLE testData") - assertCached(ctx.table("testData")) + assertCached(sqlContext.table("testData")) val rddId = rddIdOf("testData") assert( @@ -199,7 +219,7 @@ class CachedTableSuite extends QueryTest { "Eagerly cached in-memory table should have already been materialized") sql("UNCACHE TABLE testData") - assert(!ctx.isCached("testData"), "Table 'testData' should not be cached") + assert(!sqlContext.isCached("testData"), "Table 'testData' should not be cached") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") @@ -208,14 +228,14 @@ class CachedTableSuite extends QueryTest { test("CACHE TABLE tableName AS SELECT * FROM anotherTable") { sql("CACHE TABLE testCacheTable AS SELECT * FROM testData") - assertCached(ctx.table("testCacheTable")) + assertCached(sqlContext.table("testCacheTable")) val rddId = rddIdOf("testCacheTable") assert( isMaterialized(rddId), "Eagerly cached in-memory table should have already been materialized") - ctx.uncacheTable("testCacheTable") + sqlContext.uncacheTable("testCacheTable") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached 
in-memory table should have been unpersisted") } @@ -223,14 +243,14 @@ class CachedTableSuite extends QueryTest { test("CACHE TABLE tableName AS SELECT ...") { sql("CACHE TABLE testCacheTable AS SELECT key FROM testData LIMIT 10") - assertCached(ctx.table("testCacheTable")) + assertCached(sqlContext.table("testCacheTable")) val rddId = rddIdOf("testCacheTable") assert( isMaterialized(rddId), "Eagerly cached in-memory table should have already been materialized") - ctx.uncacheTable("testCacheTable") + sqlContext.uncacheTable("testCacheTable") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -238,7 +258,7 @@ class CachedTableSuite extends QueryTest { test("CACHE LAZY TABLE tableName") { sql("CACHE LAZY TABLE testData") - assertCached(ctx.table("testData")) + assertCached(sqlContext.table("testData")) val rddId = rddIdOf("testData") assert( @@ -250,7 +270,7 @@ class CachedTableSuite extends QueryTest { isMaterialized(rddId), "Lazily cached in-memory table should have been materialized") - ctx.uncacheTable("testData") + sqlContext.uncacheTable("testData") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -258,59 +278,55 @@ class CachedTableSuite extends QueryTest { test("InMemoryRelation statistics") { sql("CACHE TABLE testData") - ctx.table("testData").queryExecution.withCachedData.collect { + sqlContext.table("testData").queryExecution.withCachedData.collect { case cached: InMemoryRelation => - val actualSizeInBytes = (1 to 100).map(i => INT.defaultSize + i.toString.length + 4).sum + val actualSizeInBytes = (1 to 100).map(i => 4 + i.toString.length + 4).sum assert(cached.statistics.sizeInBytes === actualSizeInBytes) } } test("Drops temporary table") { testData.select('key).registerTempTable("t1") - ctx.table("t1") - ctx.dropTempTable("t1") - assert(intercept[RuntimeException](ctx.table("t1")).getMessage.startsWith("Table Not Found")) + sqlContext.table("t1") + sqlContext.dropTempTable("t1") + intercept[NoSuchTableException](sqlContext.table("t1")) } test("Drops cached temporary table") { testData.select('key).registerTempTable("t1") testData.select('key).registerTempTable("t2") - ctx.cacheTable("t1") + sqlContext.cacheTable("t1") - assert(ctx.isCached("t1")) - assert(ctx.isCached("t2")) + assert(sqlContext.isCached("t1")) + assert(sqlContext.isCached("t2")) - ctx.dropTempTable("t1") - assert(intercept[RuntimeException](ctx.table("t1")).getMessage.startsWith("Table Not Found")) - assert(!ctx.isCached("t2")) + sqlContext.dropTempTable("t1") + intercept[NoSuchTableException](sqlContext.table("t1")) + assert(!sqlContext.isCached("t2")) } test("Clear all cache") { sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") - ctx.cacheTable("t1") - ctx.cacheTable("t2") - ctx.clearCache() - assert(ctx.cacheManager.isEmpty) + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") + sqlContext.clearCache() + assert(sqlContext.cacheManager.isEmpty) sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") - ctx.cacheTable("t1") - ctx.cacheTable("t2") + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") sql("Clear CACHE") - assert(ctx.cacheManager.isEmpty) + assert(sqlContext.cacheManager.isEmpty) } test("Clear accumulators when uncacheTable to prevent memory leaking") { sql("SELECT key FROM testData 
LIMIT 10").registerTempTable("t1") sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") - Accumulators.synchronized { - val accsSize = Accumulators.originals.size - ctx.cacheTable("t1") - ctx.cacheTable("t2") - assert((accsSize + 2) == Accumulators.originals.size) - } + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") sql("SELECT * FROM t1").count() sql("SELECT * FROM t2").count() @@ -319,9 +335,175 @@ class CachedTableSuite extends QueryTest { Accumulators.synchronized { val accsSize = Accumulators.originals.size - ctx.uncacheTable("t1") - ctx.uncacheTable("t2") + sqlContext.uncacheTable("t1") + sqlContext.uncacheTable("t2") assert((accsSize - 2) == Accumulators.originals.size) } } + + test("SPARK-10327 Cache Table is not working while subquery has alias in its project list") { + sparkContext.parallelize((1, 1) :: (2, 2) :: Nil) + .toDF("key", "value").selectExpr("key", "value", "key+1").registerTempTable("abc") + sqlContext.cacheTable("abc") + + val sparkPlan = sql( + """select a.key, b.key, c.key from + |abc a join abc b on a.key=b.key + |join abc c on a.key=c.key""".stripMargin).queryExecution.sparkPlan + + assert(sparkPlan.collect { case e: InMemoryColumnarTableScan => e }.size === 3) + assert(sparkPlan.collect { case e: PhysicalRDD => e }.size === 0) + } + + /** + * Verifies that the plan for `df` contains `expected` number of Exchange operators. + */ + private def verifyNumExchanges(df: DataFrame, expected: Int): Unit = { + assert(df.queryExecution.executedPlan.collect { case e: Exchange => e }.size == expected) + } + + test("A cached table preserves the partitioning and ordering of its cached SparkPlan") { + val table3x = testData.unionAll(testData).unionAll(testData) + table3x.registerTempTable("testData3x") + + sql("SELECT key, value FROM testData3x ORDER BY key").registerTempTable("orderedTable") + sqlContext.cacheTable("orderedTable") + assertCached(sqlContext.table("orderedTable")) + // Should not have an exchange as the query is already sorted on the group by key. + verifyNumExchanges(sql("SELECT key, count(*) FROM orderedTable GROUP BY key"), 0) + checkAnswer( + sql("SELECT key, count(*) FROM orderedTable GROUP BY key ORDER BY key"), + sql("SELECT key, count(*) FROM testData3x GROUP BY key ORDER BY key").collect()) + sqlContext.uncacheTable("orderedTable") + sqlContext.dropTempTable("orderedTable") + + // Set up two tables distributed in the same way. Try this with the data distributed into + // different number of partitions. + for (numPartitions <- 1 until 10 by 4) { + withTempTable("t1", "t2") { + testData.repartition(numPartitions, $"key").registerTempTable("t1") + testData2.repartition(numPartitions, $"a").registerTempTable("t2") + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") + + // Joining them should result in no exchanges. + verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 0) + checkAnswer(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), + sql("SELECT * FROM testData t1 JOIN testData2 t2 ON t1.key = t2.a")) + + // Grouping on the partition key should result in no exchanges + verifyNumExchanges(sql("SELECT count(*) FROM t1 GROUP BY key"), 0) + checkAnswer(sql("SELECT count(*) FROM t1 GROUP BY key"), + sql("SELECT count(*) FROM testData GROUP BY key")) + + sqlContext.uncacheTable("t1") + sqlContext.uncacheTable("t2") + } + } + + // Distribute the tables into non-matching number of partitions. Need to shuffle one side. 
+ withTempTable("t1", "t2") { + testData.repartition(6, $"key").registerTempTable("t1") + testData2.repartition(3, $"a").registerTempTable("t2") + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") + + val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a") + verifyNumExchanges(query, 1) + assert(query.queryExecution.executedPlan.outputPartitioning.numPartitions === 6) + checkAnswer( + query, + testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b")) + sqlContext.uncacheTable("t1") + sqlContext.uncacheTable("t2") + } + + // One side of join is not partitioned in the desired way. Need to shuffle one side. + withTempTable("t1", "t2") { + testData.repartition(6, $"value").registerTempTable("t1") + testData2.repartition(6, $"a").registerTempTable("t2") + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") + + val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a") + verifyNumExchanges(query, 1) + assert(query.queryExecution.executedPlan.outputPartitioning.numPartitions === 6) + checkAnswer( + query, + testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b")) + sqlContext.uncacheTable("t1") + sqlContext.uncacheTable("t2") + } + + withTempTable("t1", "t2") { + testData.repartition(6, $"value").registerTempTable("t1") + testData2.repartition(12, $"a").registerTempTable("t2") + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") + + val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a") + verifyNumExchanges(query, 1) + assert(query.queryExecution.executedPlan.outputPartitioning.numPartitions === 12) + checkAnswer( + query, + testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b")) + sqlContext.uncacheTable("t1") + sqlContext.uncacheTable("t2") + } + + // One side of join is not partitioned in the desired way. Since the number of partitions of + // the side that has already partitioned is smaller than the side that is not partitioned, + // we shuffle both side. + withTempTable("t1", "t2") { + testData.repartition(6, $"value").registerTempTable("t1") + testData2.repartition(3, $"a").registerTempTable("t2") + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") + + val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a") + verifyNumExchanges(query, 2) + checkAnswer( + query, + testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b")) + sqlContext.uncacheTable("t1") + sqlContext.uncacheTable("t2") + } + + // repartition's column ordering is different from group by column ordering. + // But they use the same set of columns. + withTempTable("t1") { + testData.repartition(6, $"value", $"key").registerTempTable("t1") + sqlContext.cacheTable("t1") + + val query = sql("SELECT value, key from t1 group by key, value") + verifyNumExchanges(query, 0) + checkAnswer( + query, + testData.distinct().select($"value", $"key")) + sqlContext.uncacheTable("t1") + } + + // repartition's column ordering is different from join condition's column ordering. + // We will still shuffle because hashcodes of a row depend on the column ordering. + // If we do not shuffle, we may actually partition two tables in totally two different way. + // See PartitioningSuite for more details. 
+ withTempTable("t1", "t2") { + val df1 = testData + df1.repartition(6, $"value", $"key").registerTempTable("t1") + val df2 = testData2.select($"a", $"b".cast("string")) + df2.repartition(6, $"a", $"b").registerTempTable("t2") + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") + + val query = + sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a and t1.value = t2.b") + verifyNumExchanges(query, 1) + assert(query.queryExecution.executedPlan.outputPartitioning.numPartitions === 6) + checkAnswer( + query, + df1.join(df2, $"key" === $"a" && $"value" === $"b").select($"key", $"value", $"a", $"b")) + sqlContext.uncacheTable("t1") + sqlContext.uncacheTable("t2") + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 35ca0b4c7cc2..38c0eb589f96 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -17,20 +17,93 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.expressions.NamedExpression import org.scalatest.Matchers._ -import org.apache.spark.sql.execution.{Project, TungstenProject} +import org.apache.spark.sql.execution.Project import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.sql.test.SQLTestUtils -class ColumnExpressionSuite extends QueryTest with SQLTestUtils { - import org.apache.spark.sql.TestData._ +class ColumnExpressionSuite extends QueryTest with SharedSQLContext { + import testImplicits._ - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ + private lazy val booleanData = { + sqlContext.createDataFrame(sparkContext.parallelize( + Row(false, false) :: + Row(false, true) :: + Row(true, false) :: + Row(true, true) :: Nil), + StructType(Seq(StructField("a", BooleanType), StructField("b", BooleanType)))) + } + + test("column names with space") { + val df = Seq((1, "a")).toDF("name with space", "name.with.dot") + + checkAnswer( + df.select(df("name with space")), + Row(1) :: Nil) + + checkAnswer( + df.select($"name with space"), + Row(1) :: Nil) + + checkAnswer( + df.select(col("name with space")), + Row(1) :: Nil) + + checkAnswer( + df.select("name with space"), + Row(1) :: Nil) + + checkAnswer( + df.select(expr("`name with space`")), + Row(1) :: Nil) + } + + test("column names with dot") { + val df = Seq((1, "a")).toDF("name with space", "name.with.dot").as("a") + + checkAnswer( + df.select(df("`name.with.dot`")), + Row("a") :: Nil) + + checkAnswer( + df.select($"`name.with.dot`"), + Row("a") :: Nil) - override def sqlContext(): SQLContext = ctx + checkAnswer( + df.select(col("`name.with.dot`")), + Row("a") :: Nil) + + checkAnswer( + df.select("`name.with.dot`"), + Row("a") :: Nil) + + checkAnswer( + df.select(expr("`name.with.dot`")), + Row("a") :: Nil) + + checkAnswer( + df.select(df("a.`name.with.dot`")), + Row("a") :: Nil) + + checkAnswer( + df.select($"a.`name.with.dot`"), + Row("a") :: Nil) + + checkAnswer( + df.select(col("a.`name.with.dot`")), + Row("a") :: Nil) + + checkAnswer( + df.select("a.`name.with.dot`"), + Row("a") :: Nil) + + checkAnswer( + df.select(expr("a.`name.with.dot`")), + Row("a") :: Nil) + } test("alias") { val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") @@ -38,6 +111,14 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { 
assert(df.select(df("a").alias("b")).columns.head === "b") } + test("as propagates metadata") { + val metadata = new MetadataBuilder + metadata.putString("key", "value") + val origCol = $"a".as("b", metadata.build()) + val newCol = origCol.as("c") + assert(newCol.expr.asInstanceOf[NamedExpression].metadata.getString("key") === "value") + } + test("single explode") { val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") checkAnswer( @@ -190,7 +271,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { nullStrings.collect().toSeq.filter(r => r.getString(1) eq null)) checkAnswer( - ctx.sql("select isnull(null), isnull(1)"), + sql("select isnull(null), isnull(1)"), Row(true, false)) } @@ -200,12 +281,12 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { nullStrings.collect().toSeq.filter(r => r.getString(1) ne null)) checkAnswer( - ctx.sql("select isnotnull(null), isnotnull('a')"), + sql("select isnotnull(null), isnotnull('a')"), Row(false, true)) } test("isNaN") { - val testData = ctx.createDataFrame(ctx.sparkContext.parallelize( + val testData = sqlContext.createDataFrame(sparkContext.parallelize( Row(Double.NaN, Float.NaN) :: Row(math.log(-1), math.log(-3).toFloat) :: Row(null, null) :: @@ -221,12 +302,12 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { Row(true, true) :: Row(true, true) :: Row(false, false) :: Row(false, false) :: Nil) checkAnswer( - ctx.sql("select isnan(15), isnan('invalid')"), + sql("select isnan(15), isnan('invalid')"), Row(false, false)) } test("nanvl") { - val testData = ctx.createDataFrame(ctx.sparkContext.parallelize( + val testData = sqlContext.createDataFrame(sparkContext.parallelize( Row(null, 3.0, Double.NaN, Double.PositiveInfinity, 1.0f, 4) :: Nil), StructType(Seq(StructField("a", DoubleType), StructField("b", DoubleType), StructField("c", DoubleType), StructField("d", DoubleType), @@ -241,7 +322,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { ) testData.registerTempTable("t") checkAnswer( - ctx.sql( + sql( "select nanvl(a, 5), nanvl(b, 10), nanvl(10, b), nanvl(c, null), nanvl(d, 10), " + " nanvl(b, e), nanvl(e, f) from t"), Row(null, 3.0, 10.0, null, Double.PositiveInfinity, 3.0, 1.0) @@ -269,7 +350,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { } test("!==") { - val nullData = ctx.createDataFrame(ctx.sparkContext.parallelize( + val nullData = sqlContext.createDataFrame(sparkContext.parallelize( Row(1, 1) :: Row(1, 2) :: Row(1, null) :: @@ -287,6 +368,17 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { checkAnswer( nullData.filter($"a" <=> $"b"), Row(1, 1) :: Row(null, null) :: Nil) + + val nullData2 = sqlContext.createDataFrame(sparkContext.parallelize( + Row("abc") :: + Row(null) :: + Row("xyz") :: Nil), + StructType(Seq(StructField("a", StringType, true)))) + + checkAnswer( + nullData2.filter($"a" <=> null), + Row(null) :: Nil) + } test(">") { @@ -330,7 +422,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { } test("between") { - val testData = ctx.sparkContext.parallelize( + val testData = sparkContext.parallelize( (0, 1, 2) :: (1, 2, 3) :: (2, 1, 0) :: @@ -345,26 +437,25 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { test("in") { val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") - checkAnswer(df.filter($"a".in(1, 2)), + checkAnswer(df.filter($"a".isin(1, 2)), df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".in(3, 2)), + 
checkAnswer(df.filter($"a".isin(3, 2)), df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".in(3, 1)), + checkAnswer(df.filter($"a".isin(3, 1)), df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) - checkAnswer(df.filter($"b".in("y", "x")), + checkAnswer(df.filter($"b".isin("y", "x")), df.collect().toSeq.filter(r => r.getString(1) == "y" || r.getString(1) == "x")) - checkAnswer(df.filter($"b".in("z", "x")), + checkAnswer(df.filter($"b".isin("z", "x")), df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "x")) - checkAnswer(df.filter($"b".in("z", "y")), + checkAnswer(df.filter($"b".isin("z", "y")), df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y")) - } - val booleanData = ctx.createDataFrame(ctx.sparkContext.parallelize( - Row(false, false) :: - Row(false, true) :: - Row(true, false) :: - Row(true, true) :: Nil), - StructType(Seq(StructField("a", BooleanType), StructField("b", BooleanType)))) + val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") + + intercept[AnalysisException] { + df2.filter($"a".isin($"b")) + } + } test("&&") { checkAnswer( @@ -449,7 +540,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { ) checkAnswer( - ctx.sql("SELECT upper('aB'), ucase('cDe')"), + sql("SELECT upper('aB'), ucase('cDe')"), Row("AB", "CDE")) } @@ -470,24 +561,28 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { ) checkAnswer( - ctx.sql("SELECT lower('aB'), lcase('cDe')"), + sql("SELECT lower('aB'), lcase('cDe')"), Row("ab", "cde")) } test("monotonicallyIncreasingId") { // Make sure we have 2 partitions, each with 2 records. - val df = ctx.sparkContext.parallelize(Seq[Int](), 2).mapPartitions { _ => + val df = sparkContext.parallelize(Seq[Int](), 2).mapPartitions { _ => Iterator(Tuple1(1), Tuple1(2)) }.toDF("a") checkAnswer( df.select(monotonicallyIncreasingId()), Row(0L) :: Row(1L) :: Row((1L << 33) + 0L) :: Row((1L << 33) + 1L) :: Nil ) + checkAnswer( + df.select(expr("monotonically_increasing_id()")), + Row(0L) :: Row(1L) :: Row((1L << 33) + 0L) :: Row((1L << 33) + 1L) :: Nil + ) } test("sparkPartitionId") { // Make sure we have 2 partitions, each with 2 records. 
- val df = ctx.sparkContext.parallelize(Seq[Int](), 2).mapPartitions { _ => + val df = sparkContext.parallelize(Seq[Int](), 2).mapPartitions { _ => Iterator(Tuple1(1), Tuple1(2)) }.toDF("a") checkAnswer( @@ -498,7 +593,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { test("InputFileName") { withTempPath { dir => - val data = sqlContext.sparkContext.parallelize(0 to 10).toDF("id") + val data = sparkContext.parallelize(0 to 10).toDF("id") data.write.parquet(dir.getCanonicalPath) val answer = sqlContext.read.parquet(dir.getCanonicalPath).select(inputFileName()) .head.getString(0) @@ -508,12 +603,6 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { } } - test("lift alias out of cast") { - compareExpressions( - col("1234").as("name").cast("int").expr, - col("1234").cast("int").as("name").expr) - } - test("columns can be compared") { assert('key.desc == 'key.desc) assert('key.desc != 'key.asc) @@ -541,8 +630,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { def checkNumProjects(df: DataFrame, expectedNumProjects: Int): Unit = { val projects = df.queryExecution.executedPlan.collect { - case project: Project => project - case tungstenProject: TungstenProject => tungstenProject + case tungstenProject: Project => tungstenProject } assert(projects.size === expectedNumProjects) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index f9cff7440a76..b1004bc5bc29 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.sql -import org.apache.spark.sql.TestData._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{BinaryType, DecimalType} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.DecimalType +case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Double) -class DataFrameAggregateSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("groupBy") { checkAnswer( @@ -62,18 +61,89 @@ class DataFrameAggregateSuite extends QueryTest { ) } + test("rollup") { + checkAnswer( + courseSales.rollup("course", "year").sum("earnings"), + Row("Java", 2012, 20000.0) :: + Row("Java", 2013, 30000.0) :: + Row("Java", null, 50000.0) :: + Row("dotNET", 2012, 15000.0) :: + Row("dotNET", 2013, 48000.0) :: + Row("dotNET", null, 63000.0) :: + Row(null, null, 113000.0) :: Nil + ) + } + + test("cube") { + checkAnswer( + courseSales.cube("course", "year").sum("earnings"), + Row("Java", 2012, 20000.0) :: + Row("Java", 2013, 30000.0) :: + Row("Java", null, 50000.0) :: + Row("dotNET", 2012, 15000.0) :: + Row("dotNET", 2013, 48000.0) :: + Row("dotNET", null, 63000.0) :: + Row(null, 2012, 35000.0) :: + Row(null, 2013, 78000.0) :: + Row(null, null, 113000.0) :: Nil + ) + + val df0 = sqlContext.sparkContext.parallelize(Seq( + Fact(20151123, 18, 35, "room1", 18.6), + Fact(20151123, 18, 35, "room2", 22.4), + Fact(20151123, 18, 36, "room1", 17.4), + Fact(20151123, 18, 36, "room2", 25.6))).toDF() + + val cube0 = df0.cube("date", "hour", "minute", "room_name").agg(Map("temp" -> "avg")) + assert(cube0.where("date IS NULL").count > 0) + } + + test("rollup overlapping 
columns") { + checkAnswer( + testData2.rollup($"a" + $"b" as "foo", $"b" as "bar").agg(sum($"a" - $"b") as "foo"), + Row(2, 1, 0) :: Row(3, 2, -1) :: Row(3, 1, 1) :: Row(4, 2, 0) :: Row(4, 1, 2) :: Row(5, 2, 1) + :: Row(2, null, 0) :: Row(3, null, 0) :: Row(4, null, 2) :: Row(5, null, 1) + :: Row(null, null, 3) :: Nil + ) + + checkAnswer( + testData2.rollup("a", "b").agg(sum("b")), + Row(1, 1, 1) :: Row(1, 2, 2) :: Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 1) :: Row(3, 2, 2) + :: Row(1, null, 3) :: Row(2, null, 3) :: Row(3, null, 3) + :: Row(null, null, 9) :: Nil + ) + } + + test("cube overlapping columns") { + checkAnswer( + testData2.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")), + Row(2, 1, 0) :: Row(3, 2, -1) :: Row(3, 1, 1) :: Row(4, 2, 0) :: Row(4, 1, 2) :: Row(5, 2, 1) + :: Row(2, null, 0) :: Row(3, null, 0) :: Row(4, null, 2) :: Row(5, null, 1) + :: Row(null, 1, 3) :: Row(null, 2, 0) + :: Row(null, null, 3) :: Nil + ) + + checkAnswer( + testData2.cube("a", "b").agg(sum("b")), + Row(1, 1, 1) :: Row(1, 2, 2) :: Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 1) :: Row(3, 2, 2) + :: Row(1, null, 3) :: Row(2, null, 3) :: Row(3, null, 3) + :: Row(null, 1, 3) :: Row(null, 2, 6) + :: Row(null, null, 9) :: Nil + ) + } + test("spark.sql.retainGroupColumns config") { checkAnswer( testData2.groupBy("a").agg(sum($"b")), Seq(Row(1, 3), Row(2, 3), Row(3, 3)) ) - ctx.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, false) + sqlContext.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, false) checkAnswer( testData2.groupBy("a").agg(sum($"b")), Seq(Row(3), Row(3), Row(3)) ) - ctx.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, true) + sqlContext.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, true) } test("agg without groups") { @@ -85,13 +155,8 @@ class DataFrameAggregateSuite extends QueryTest { test("average") { checkAnswer( - testData2.agg(avg('a)), - Row(2.0)) - - // Also check mean - checkAnswer( - testData2.agg(mean('a)), - Row(2.0)) + testData2.agg(avg('a), mean('a)), + Row(2.0, 2.0)) checkAnswer( testData2.agg(avg('a), sumDistinct('a)), // non-partial @@ -100,6 +165,7 @@ class DataFrameAggregateSuite extends QueryTest { checkAnswer( decimalData.agg(avg('a)), Row(new java.math.BigDecimal(2.0))) + checkAnswer( decimalData.agg(avg('a), sumDistinct('a)), // non-partial Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil) @@ -168,15 +234,55 @@ class DataFrameAggregateSuite extends QueryTest { ) } + test("multiple column distinct count") { + val df1 = Seq( + ("a", "b", "c"), + ("a", "b", "c"), + ("a", "b", "d"), + ("x", "y", "z"), + ("x", "q", null.asInstanceOf[String])) + .toDF("key1", "key2", "key3") + + checkAnswer( + df1.agg(countDistinct('key1, 'key2)), + Row(3) + ) + + checkAnswer( + df1.agg(countDistinct('key1, 'key2, 'key3)), + Row(3) + ) + + checkAnswer( + df1.groupBy('key1).agg(countDistinct('key2, 'key3)), + Seq(Row("a", 2), Row("x", 1)) + ) + } + test("zero count") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") - assert(emptyTableData.count() === 0) - checkAnswer( emptyTableData.agg(count('a), sumDistinct('a)), // non-partial Row(0, null)) } + test("stddev") { + val testData2ADev = math.sqrt(4.0 / 5.0) + checkAnswer( + testData2.agg(stddev('a), stddev_pop('a), stddev_samp('a)), + Row(testData2ADev, math.sqrt(4 / 6.0), testData2ADev)) + checkAnswer( + testData2.agg(stddev("a"), stddev_pop("a"), stddev_samp("a")), + Row(testData2ADev, math.sqrt(4 / 6.0), testData2ADev)) + } + + test("zero stddev") { + val emptyTableData = Seq.empty[(Int, 
Int)].toDF("a", "b") + checkAnswer( + emptyTableData.agg(stddev('a), stddev_pop('a), stddev_samp('a)), + Row(null, null, null)) + } + test("zero sum") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") checkAnswer( @@ -190,4 +296,62 @@ class DataFrameAggregateSuite extends QueryTest { emptyTableData.agg(sumDistinct('a)), Row(null)) } + + test("moments") { + val absTol = 1e-8 + + val sparkVariance = testData2.agg(variance('a)) + checkAggregatesWithTol(sparkVariance, Row(4.0 / 5.0), absTol) + + val sparkVariancePop = testData2.agg(var_pop('a)) + checkAggregatesWithTol(sparkVariancePop, Row(4.0 / 6.0), absTol) + + val sparkVarianceSamp = testData2.agg(var_samp('a)) + checkAggregatesWithTol(sparkVarianceSamp, Row(4.0 / 5.0), absTol) + + val sparkSkewness = testData2.agg(skewness('a)) + checkAggregatesWithTol(sparkSkewness, Row(0.0), absTol) + + val sparkKurtosis = testData2.agg(kurtosis('a)) + checkAggregatesWithTol(sparkKurtosis, Row(-1.5), absTol) + } + + test("zero moments") { + val input = Seq((1, 2)).toDF("a", "b") + checkAnswer( + input.agg(stddev('a), stddev_samp('a), stddev_pop('a), variance('a), + var_samp('a), var_pop('a), skewness('a), kurtosis('a)), + Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN, 0.0, + Double.NaN, Double.NaN)) + + checkAnswer( + input.agg( + expr("stddev(a)"), + expr("stddev_samp(a)"), + expr("stddev_pop(a)"), + expr("variance(a)"), + expr("var_samp(a)"), + expr("var_pop(a)"), + expr("skewness(a)"), + expr("kurtosis(a)")), + Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN, 0.0, + Double.NaN, Double.NaN)) + } + + test("null moments") { + val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") + + checkAnswer( + emptyTableData.agg(variance('a), var_samp('a), var_pop('a), skewness('a), kurtosis('a)), + Row(null, null, null, null, null)) + + checkAnswer( + emptyTableData.agg( + expr("variance(a)"), + expr("var_samp(a)"), + expr("var_pop(a)"), + expr("skewness(a)"), + expr("kurtosis(a)")), + Row(null, null, null, null, null)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala new file mode 100644 index 000000000000..09f7b507670c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext + +/** + * A test suite to test DataFrame/SQL functionalities with complex types (i.e. array, struct, map). 
+ */ +class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("UDF on struct") { + val f = udf((a: String) => a) + val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") + df.select(struct($"a").as("s")).select(f($"s.a")).collect() + } + + test("UDF on named_struct") { + val f = udf((a: String) => a) + val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") + df.selectExpr("named_struct('a', a) s").select(f($"s.a")).collect() + } + + test("UDF on array") { + val f = udf((a: String) => a) + val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") + df.select(array($"a").as("s")).select(f(expr("s[0]"))).collect() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 431dcf7382f1..aff9efe4b2b1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -17,17 +17,15 @@ package org.apache.spark.sql -import org.apache.spark.sql.TestData._ import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ /** * Test suite for functions in [[org.apache.spark.sql.functions]]. */ -class DataFrameFunctionsSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("array with column name") { val df = Seq((0, 1)).toDF("a", "b") @@ -119,11 +117,11 @@ class DataFrameFunctionsSuite extends QueryTest { test("constant functions") { checkAnswer( - ctx.sql("SELECT E()"), + sql("SELECT E()"), Row(scala.math.E) ) checkAnswer( - ctx.sql("SELECT PI()"), + sql("SELECT PI()"), Row(scala.math.Pi) ) } @@ -153,7 +151,7 @@ class DataFrameFunctionsSuite extends QueryTest { test("nvl function") { checkAnswer( - ctx.sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"), + sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"), Row("x", "y", null)) } @@ -208,13 +206,21 @@ class DataFrameFunctionsSuite extends QueryTest { Row(2743272264L, 2180413220L)) } + test("string function find_in_set") { + val df = Seq(("abc,b,ab,c,def", "abc,b,ab,c,def")).toDF("a", "b") + + checkAnswer( + df.selectExpr("find_in_set('ab', a)", "find_in_set('x', b)"), + Row(3, 0)) + } + test("conditional function: least") { checkAnswer( testData2.select(least(lit(-1), lit(0), col("a"), col("b"))).limit(1), Row(-1) ) checkAnswer( - ctx.sql("SELECT least(a, 2) as l from testData2 order by l"), + sql("SELECT least(a, 2) as l from testData2 order by l"), Seq(Row(1), Row(1), Row(2), Row(2), Row(2), Row(2)) ) } @@ -225,7 +231,7 @@ class DataFrameFunctionsSuite extends QueryTest { Row(3) ) checkAnswer( - ctx.sql("SELECT greatest(a, 2) as g from testData2 order by g"), + sql("SELECT greatest(a, 2) as g from testData2 order by g"), Seq(Row(2), Row(2), Row(2), Row(2), Row(3), Row(3)) ) } @@ -302,10 +308,14 @@ class DataFrameFunctionsSuite extends QueryTest { Row(null, null)) ) - val df2 = Seq((Array[Array[Int]](Array(2)), "x")).toDF("a", "b") - assert(intercept[AnalysisException] { - df2.selectExpr("sort_array(a)").collect() - }.getMessage().contains("does not support sorting array of type array")) + val df2 = Seq((Array[Array[Int]](Array(2), Array(1), Array(2, 4), null), "x")).toDF("a", "b") + checkAnswer( + 
df2.selectExpr("sort_array(a, true)", "sort_array(a, false)"), + Seq( + Row( + Seq[Seq[Int]](null, Seq(1), Seq(2), Seq(2, 4)), + Seq[Seq[Int]](Seq(2, 4), Seq(2), Seq(1), null))) + ) val df3 = Seq(("xxx", "x")).toDF("a", "b") assert(intercept[AnalysisException] { @@ -315,9 +325,9 @@ class DataFrameFunctionsSuite extends QueryTest { test("array size function") { val df = Seq( - (Array[Int](1, 2), "x"), - (Array[Int](), "y"), - (Array[Int](1, 2, 3), "z") + (Seq[Int](1, 2), "x"), + (Seq[Int](), "y"), + (Seq[Int](1, 2, 3), "z") ).toDF("a", "b") checkAnswer( df.select(size($"a")), @@ -344,4 +354,41 @@ class DataFrameFunctionsSuite extends QueryTest { Seq(Row(2), Row(0), Row(3)) ) } + + test("array contains function") { + val df = Seq( + (Seq[Int](1, 2), "x"), + (Seq[Int](), "x") + ).toDF("a", "b") + + // Simple test cases + checkAnswer( + df.select(array_contains(df("a"), 1)), + Seq(Row(true), Row(false)) + ) + checkAnswer( + df.selectExpr("array_contains(a, 1)"), + Seq(Row(true), Row(false)) + ) + + // In hive, this errors because null has no type information + intercept[AnalysisException] { + df.select(array_contains(df("a"), null)) + } + intercept[AnalysisException] { + df.selectExpr("array_contains(a, null)") + } + intercept[AnalysisException] { + df.selectExpr("array_contains(null, 1)") + } + + checkAnswer( + df.selectExpr("array_contains(array(array(1), null)[0], 1)"), + Seq(Row(true), Row(true)) + ) + checkAnswer( + df.selectExpr("array_contains(array(1, null), array(1, null)[0])"), + Seq(Row(true), Row(true)) + ) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala index fbb30706a494..094efbaeadcd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala @@ -17,14 +17,14 @@ package org.apache.spark.sql -class DataFrameImplicitsSuite extends QueryTest { +import org.apache.spark.sql.test.SharedSQLContext - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class DataFrameImplicitsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("RDD of tuples") { checkAnswer( - ctx.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"), + sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"), (1 to 10).map(i => Row(i, i.toString))) } @@ -36,19 +36,19 @@ class DataFrameImplicitsSuite extends QueryTest { test("RDD[Int]") { checkAnswer( - ctx.sparkContext.parallelize(1 to 10).toDF("intCol"), + sparkContext.parallelize(1 to 10).toDF("intCol"), (1 to 10).map(i => Row(i))) } test("RDD[Long]") { checkAnswer( - ctx.sparkContext.parallelize(1L to 10L).toDF("longCol"), + sparkContext.parallelize(1L to 10L).toDF("longCol"), (1L to 10L).map(i => Row(i))) } test("RDD[String]") { checkAnswer( - ctx.sparkContext.parallelize(1 to 10).map(_.toString).toDF("stringCol"), + sparkContext.parallelize(1 to 10).map(_.toString).toDF("stringCol"), (1 to 10).map(i => Row(i.toString))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index e1c6c706242d..39a65413bd59 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -17,14 +17,12 @@ package org.apache.spark.sql 
-import org.apache.spark.sql.TestData._ import org.apache.spark.sql.execution.joins.BroadcastHashJoin import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext -class DataFrameJoinSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class DataFrameJoinSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("join - join using") { val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") @@ -44,6 +42,31 @@ class DataFrameJoinSuite extends QueryTest { Row(1, 2, "1", "2") :: Row(2, 3, "2", "3") :: Row(3, 4, "3", "4") :: Nil) } + test("join - join using multiple columns and specifying join type") { + val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str") + val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str") + + checkAnswer( + df.join(df2, Seq("int", "str"), "inner"), + Row(1, "1", 2, 3) :: Nil) + + checkAnswer( + df.join(df2, Seq("int", "str"), "left"), + Row(1, "1", 2, 3) :: Row(3, "3", 4, null) :: Nil) + + checkAnswer( + df.join(df2, Seq("int", "str"), "right"), + Row(1, "1", 2, 3) :: Row(5, "5", null, 6) :: Nil) + + checkAnswer( + df.join(df2, Seq("int", "str"), "outer"), + Row(1, "1", 2, 3) :: Row(3, "3", 4, null) :: Row(5, "5", null, 6) :: Nil) + + checkAnswer( + df.join(df2, Seq("int", "str"), "left_semi"), + Row(1, "1", 2) :: Nil) + } + test("join - join using self join") { val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") @@ -59,7 +82,7 @@ class DataFrameJoinSuite extends QueryTest { checkAnswer( df1.join(df2, $"df1.key" === $"df2.key"), - ctx.sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key") + sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key") .collect().toSeq) } @@ -109,5 +132,12 @@ class DataFrameJoinSuite extends QueryTest { // planner should not crash without a join broadcast(df1).queryExecution.executedPlan + + // SPARK-12275: no physical plan for BroadcastHint in some condition + withTempPath { path => + df1.write.parquet(path.getCanonicalPath) + val pf1 = sqlContext.read.parquet(path.getCanonicalPath) + assert(df1.join(broadcast(pf1)).count() === 4) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index dbe3b44ee2c7..e34875471f09 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -17,13 +17,13 @@ package org.apache.spark.sql -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ +import org.apache.spark.sql.test.SharedSQLContext -class DataFrameNaFunctionsSuite extends QueryTest { - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ def createDF(): DataFrame = { Seq[(String, java.lang.Integer, java.lang.Double)]( @@ -141,24 +141,26 @@ class DataFrameNaFunctionsSuite extends QueryTest { } test("fill with map") { - val df = Seq[(String, String, java.lang.Long, java.lang.Double)]( - (null, null, null, null)).toDF("a", "b", "c", "d") + val df = Seq[(String, String, java.lang.Long, java.lang.Double, java.lang.Boolean)]( + (null, null, null, null, null)).toDF("a", "b", "c", "d", "e") checkAnswer( df.na.fill(Map( "a" -> "test", "c" -> 1, - "d" -> 2.2 + 
"d" -> 2.2, + "e" -> false )), - Row("test", null, 1, 2.2)) + Row("test", null, 1, 2.2, false)) // Test Java version checkAnswer( - df.na.fill(mapAsJavaMap(Map( + df.na.fill(Map( "a" -> "test", "c" -> 1, - "d" -> 2.2 - ))), - Row("test", null, 1, 2.2)) + "d" -> 2.2, + "e" -> false + ).asJava), + Row("test", null, 1, 2.2, false)) } test("replace") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala new file mode 100644 index 000000000000..bc1a336ea4fd --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext + +class DataFramePivotSuite extends QueryTest with SharedSQLContext{ + import testImplicits._ + + test("pivot courses with literals") { + checkAnswer( + courseSales.groupBy("year").pivot("course", Seq("dotNET", "Java")) + .agg(sum($"earnings")), + Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil + ) + } + + test("pivot year with literals") { + checkAnswer( + courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).agg(sum($"earnings")), + Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil + ) + } + + test("pivot courses with literals and multiple aggregations") { + checkAnswer( + courseSales.groupBy($"year") + .pivot("course", Seq("dotNET", "Java")) + .agg(sum($"earnings"), avg($"earnings")), + Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) :: + Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil + ) + } + + test("pivot year with string values (cast)") { + checkAnswer( + courseSales.groupBy("course").pivot("year", Seq("2012", "2013")).sum("earnings"), + Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil + ) + } + + test("pivot year with int values") { + checkAnswer( + courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).sum("earnings"), + Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil + ) + } + + test("pivot courses with no values") { + // Note Java comes before dotNet in sorted order + checkAnswer( + courseSales.groupBy("year").pivot("course").agg(sum($"earnings")), + Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil + ) + } + + test("pivot year with no values") { + checkAnswer( + courseSales.groupBy("course").pivot("year").agg(sum($"earnings")), + Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil + ) + } + + test("pivot max values enforced") { + sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, 1) + intercept[AnalysisException]( + 
courseSales.groupBy("year").pivot("course") + ) + sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, + SQLConf.DATAFRAME_PIVOT_MAX_VALUES.defaultValue.get) + } + + test("pivot with UnresolvedFunction") { + checkAnswer( + courseSales.groupBy("year").pivot("course", Seq("dotNET", "Java")) + .agg("earnings" -> "sum"), + Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 07a675e64f52..b15af42caa3a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -19,17 +19,49 @@ package org.apache.spark.sql import java.util.Random -import org.scalatest.Matchers._ - import org.apache.spark.sql.functions.col +import org.apache.spark.sql.test.SharedSQLContext -class DataFrameStatSuite extends QueryTest { - - private val sqlCtx = org.apache.spark.sql.test.TestSQLContext - import sqlCtx.implicits._ +class DataFrameStatSuite extends QueryTest with SharedSQLContext { + import testImplicits._ private def toLetter(i: Int): String = (i + 97).toChar.toString + test("sample with replacement") { + val n = 100 + val data = sparkContext.parallelize(1 to n, 2).toDF("id") + checkAnswer( + data.sample(withReplacement = true, 0.05, seed = 13), + Seq(5, 10, 52, 73).map(Row(_)) + ) + } + + test("sample without replacement") { + val n = 100 + val data = sparkContext.parallelize(1 to n, 2).toDF("id") + checkAnswer( + data.sample(withReplacement = false, 0.05, seed = 13), + Seq(3, 17, 27, 58, 62).map(Row(_)) + ) + } + + test("randomSplit") { + val n = 600 + val data = sparkContext.parallelize(1 to n, 2).toDF("id") + for (seed <- 1 to 5) { + val splits = data.randomSplit(Array[Double](1, 2, 3), seed) + assert(splits.length == 3, "wrong number of splits") + + assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList == + data.collect().toList, "incomplete or wrong split") + + val s = splits.map(_.count()) + assert(math.abs(s(0) - 100) < 50) // std = 9.13 + assert(math.abs(s(1) - 200) < 50) // std = 11.55 + assert(math.abs(s(2) - 300) < 50) // std = 12.25 + } + } + test("pearson correlation") { val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c") val corr1 = df.stat.corr("a", "b", "pearson") @@ -123,19 +155,37 @@ class DataFrameStatSuite extends QueryTest { val results = df.stat.freqItems(Array("numbers", "letters"), 0.1) val items = results.collect().head - items.getSeq[Int](0) should contain (1) - items.getSeq[String](1) should contain (toLetter(1)) + assert(items.getSeq[Int](0).contains(1)) + assert(items.getSeq[String](1).contains(toLetter(1))) val singleColResults = df.stat.freqItems(Array("negDoubles"), 0.1) val items2 = singleColResults.collect().head - items2.getSeq[Double](0) should contain (-1.0) + assert(items2.getSeq[Double](0).contains(-1.0)) + } + + test("Frequent Items 2") { + val rows = sparkContext.parallelize(Seq.empty[Int], 4) + // this is a regression test, where when merging partitions, we omitted values with higher + // counts than those that existed in the map when the map was full. 
This test should also fail + // if anything like SPARK-9614 is observed once again + val df = rows.mapPartitionsWithIndex { (idx, iter) => + if (idx == 3) { // must come from one of the later merges, therefore higher partition index + Iterator("3", "3", "3", "3", "3") + } else { + Iterator("0", "1", "2", "3", "4") + } + }.toDF("a") + val results = df.stat.freqItems(Array("a"), 0.25) + val items = results.collect().head.getSeq[String](0) + assert(items.contains("3")) + assert(items.length === 1) } test("sampleBy") { - val df = sqlCtx.range(0, 100).select((col("id") % 3).as("key")) + val df = sqlContext.range(0, 100).select((col("id") % 3).as("key")) val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L) checkAnswer( sampled.groupBy("key").count().orderBy("key"), - Seq(Row(0, 5), Row(1, 8))) + Seq(Row(0, 6), Row(1, 11))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index aef940a52667..0644bdaaa35c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -22,19 +22,19 @@ import java.io.File import scala.language.postfixOps import scala.util.Random +import org.scalatest.Matchers._ + +import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation -import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.Exchange +import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.sql.functions._ -import org.apache.spark.sql.json.JSONRelation -import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.test.SQLTestData.TestData2 +import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext} import org.apache.spark.sql.types._ -import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SQLTestUtils} - -class DataFrameSuite extends QueryTest with SQLTestUtils { - import org.apache.spark.sql.TestData._ - lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.implicits._ +class DataFrameSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("analysis error should be eagerly reported") { // Eager analysis. 
@@ -109,6 +109,13 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { assert(testData.head(2).head.schema === testData.schema) } + test("dataframe alias") { + val df = Seq(Tuple1(1)).toDF("c").as("t") + val dfAlias = df.alias("t2") + df.col("t.c") + dfAlias.col("t2.c") + } + test("simple explode") { val df = Seq(Tuple1("a b c"), Tuple1("d e")).toDF("words") @@ -134,6 +141,21 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { ) } + test("SPARK-8930: explode should fail with a meaningful message if it takes a star") { + val df = Seq(("1", "1,2"), ("2", "4"), ("3", "7,8,9")).toDF("prefix", "csv") + val e = intercept[AnalysisException] { + df.explode($"*") { case Row(prefix: String, csv: String) => + csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq + }.queryExecution.assertAnalyzed() + } + assert(e.getMessage.contains( + "Cannot explode *, explode can only be applied on a specific column.")) + + df.explode('prefix, 'csv) { case Row(prefix: String, csv: String) => + csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq + }.queryExecution.assertAnalyzed() + } + test("explode alias and star") { val df = Seq((Array("a"), 1)).toDF("a", "b") @@ -155,9 +177,14 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { } test("filterExpr") { - checkAnswer( - testData.filter("key > 90"), - testData.collect().filter(_.getInt(0) > 90).toSeq) + val res = testData.collect().filter(_.getInt(0) > 90).toSeq + checkAnswer(testData.filter("key > 90"), res) + checkAnswer(testData.filter("key > 9.0e1"), res) + checkAnswer(testData.filter("key > .9e+2"), res) + checkAnswer(testData.filter("key > 0.9e+2"), res) + checkAnswer(testData.filter("key > 900e-1"), res) + checkAnswer(testData.filter("key > 900.0E-1"), res) + checkAnswer(testData.filter("key > 9.e+1"), res) } test("filterExpr using where") { @@ -336,7 +363,7 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { } test("replace column using withColumn") { - val df2 = sqlContext.sparkContext.parallelize(Array(1, 2, 3)).toDF("x") + val df2 = sparkContext.parallelize(Array(1, 2, 3)).toDF("x") val df3 = df2.withColumn("x", df2("x") + 1) checkAnswer( df3.select("x"), @@ -351,6 +378,13 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { assert(df.schema.map(_.name) === Seq("value")) } + test("drop columns using drop") { + val src = Seq((0, 2, 3)).toDF("a", "b", "c") + val df = src.drop("a", "b") + checkAnswer(df, Row(3)) + assert(df.schema.map(_.name) === Seq("c")) + } + test("drop unknown column (no-op)") { val df = testData.drop("random") checkAnswer( @@ -377,13 +411,13 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { assert(df.schema.map(_.name) === Seq("key", "value")) } - test("drop unknown column with same name (no-op) with column reference") { + test("drop unknown column with same name with column reference") { val col = Column("key") val df = testData.drop(col) checkAnswer( df, - testData.collect().toSeq) - assert(df.schema.map(_.name) === Seq("key", "value")) + testData.collect().map(x => Row(x.getString(1))).toSeq) + assert(df.schema.map(_.name) === Seq("value")) } test("drop column after join with duplicate columns using column reference") { @@ -415,23 +449,6 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol")) } - test("randomSplit") { - val n = 600 - val data = sqlContext.sparkContext.parallelize(1 to n, 2).toDF("id") - for (seed <- 1 to 5) { - val splits = data.randomSplit(Array[Double](1, 2, 3), seed) - 
assert(splits.length == 3, "wrong number of splits") - - assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList == - data.collect().toList, "incomplete or wrong split") - - val s = splits.map(_.count()) - assert(math.abs(s(0) - 100) < 50) // std = 9.13 - assert(math.abs(s(1) - 200) < 50) // std = 11.55 - assert(math.abs(s(2) - 300) < 50) // std = 12.25 - } - } - test("describe") { val describeTestData = Seq( ("Bob", 16, 176), @@ -442,7 +459,7 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { val describeResult = Seq( Row("count", "4", "4"), Row("mean", "33.0", "178.0"), - Row("stddev", "16.583123951777", "10.0"), + Row("stddev", "19.148542155126762", "11.547005383792516"), Row("min", "16", "164"), Row("max", "60", "192")) @@ -487,20 +504,23 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { } test("inputFiles") { - val fakeRelation1 = new ParquetRelation(Array("/my/path", "/my/other/path"), - Some(testData.schema), None, Map.empty)(sqlContext) - val df1 = DataFrame(sqlContext, LogicalRelation(fakeRelation1)) - assert(df1.inputFiles.toSet == fakeRelation1.paths.toSet) + withTempDir { dir => + val df = Seq((1, 22)).toDF("a", "b") - val fakeRelation2 = new JSONRelation("/json/path", 1, Some(testData.schema), sqlContext) - val df2 = DataFrame(sqlContext, LogicalRelation(fakeRelation2)) - assert(df2.inputFiles.toSet == fakeRelation2.path.toSet) + val parquetDir = new File(dir, "parquet").getCanonicalPath + df.write.parquet(parquetDir) + val parquetDF = sqlContext.read.parquet(parquetDir) + assert(parquetDF.inputFiles.nonEmpty) - val unionDF = df1.unionAll(df2) - assert(unionDF.inputFiles.toSet == fakeRelation1.paths.toSet ++ fakeRelation2.path) + val jsonDir = new File(dir, "json").getCanonicalPath + df.write.json(jsonDir) + val jsonDF = sqlContext.read.json(jsonDir) + assert(parquetDF.inputFiles.nonEmpty) - val filtered = df1.filter("false").unionAll(df2.intersect(df2)) - assert(filtered.inputFiles.toSet == fakeRelation1.paths.toSet ++ fakeRelation2.path) + val unioned = jsonDF.unionAll(parquetDF).inputFiles.sorted + val allFiles = (jsonDF.inputFiles ++ parquetDF.inputFiles).toSet.toArray.sorted + assert(unioned === allFiles) + } } ignore("show") { @@ -511,7 +531,7 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { test("showString: truncate = [true, false]") { val longString = Array.fill(21)("1").mkString - val df = sqlContext.sparkContext.parallelize(Seq("1", longString)).toDF() + val df = sparkContext.parallelize(Seq("1", longString)).toDF() val expectedAnswerForFalse = """+---------------------+ ||_1 | |+---------------------+ @@ -565,6 +585,21 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { assert(df.showString(10) === expectedAnswer) } + test("showString: binary") { + val df = Seq( + ("12".getBytes, "ABC.".getBytes), + ("34".getBytes, "12346".getBytes) + ).toDF() + val expectedAnswer = """+-------+----------------+ + || _1| _2| + |+-------+----------------+ + ||[31 32]| [41 42 43 2E]| + ||[33 34]|[31 32 33 34 36]| + |+-------+----------------+ + |""".stripMargin + assert(df.showString(10) === expectedAnswer) + } + test("showString: minimum column width") { val df = Seq( (1, 1), @@ -601,18 +636,14 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { } test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") { - val rowRDD = sqlContext.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) + val rowRDD = sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) val schema = 
StructType(Array(StructField("point", new ExamplePointUDT(), false))) val df = sqlContext.createDataFrame(rowRDD, schema) df.rdd.collect() } test("SPARK-6899: type should match when using codegen") { - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { - checkAnswer( - decimalData.agg(avg('a)), - Row(new java.math.BigDecimal(2.0))) - } + checkAnswer(decimalData.agg(avg('a)), Row(new java.math.BigDecimal(2.0))) } test("SPARK-7133: Implement struct, array, and map field accessor") { @@ -624,14 +655,14 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { } test("SPARK-7551: support backticks for DataFrame attribute resolution") { - val df = sqlContext.read.json(sqlContext.sparkContext.makeRDD( + val df = sqlContext.read.json(sparkContext.makeRDD( """{"a.b": {"c": {"d..e": {"f": 1}}}}""" :: Nil)) checkAnswer( df.select(df("`a.b`.c.`d..e`.`f`")), Row(1) ) - val df2 = sqlContext.read.json(sqlContext.sparkContext.makeRDD( + val df2 = sqlContext.read.json(sparkContext.makeRDD( """{"a b": {"c": {"d e": {"f": 1}}}}""" :: Nil)) checkAnswer( df2.select(df2("`a b`.c.d e.f")), @@ -651,7 +682,7 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { } test("SPARK-7324 dropDuplicates") { - val testData = sqlContext.sparkContext.parallelize( + val testData = sparkContext.parallelize( (2, 1, 2) :: (1, 1, 1) :: (1, 2, 1) :: (2, 1, 2) :: (2, 2, 2) :: (2, 2, 1) :: @@ -685,18 +716,6 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { Seq(Row(2, 1, 2), Row(1, 1, 1))) } - test("SPARK-7276: Project collapse for continuous select") { - var df = testData - for (i <- 1 to 5) { - df = df.select($"*") - } - - import org.apache.spark.sql.catalyst.plans.logical.Project - // make sure df have at most two Projects - val p = df.logicalPlan.asInstanceOf[Project].child.asInstanceOf[Project] - assert(!p.child.isInstanceOf[Project]) - } - test("SPARK-7150 range api") { // numSlice is greater than length val res1 = sqlContext.range(0, 10, 1, 15).select("id") @@ -843,31 +862,16 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { } test("SPARK-8608: call `show` on local DataFrame with random columns should return same value") { - // Make sure we can pass this test for both codegen mode and interpreted mode. - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { - val df = testData.select(rand(33)) - assert(df.showString(5) == df.showString(5)) - } - - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { - val df = testData.select(rand(33)) - assert(df.showString(5) == df.showString(5)) - } + val df = testData.select(rand(33)) + assert(df.showString(5) == df.showString(5)) // We will reuse the same Expression object for LocalRelation. - val df = (1 to 10).map(Tuple1.apply).toDF().select(rand(33)) - assert(df.showString(5) == df.showString(5)) + val df1 = (1 to 10).map(Tuple1.apply).toDF().select(rand(33)) + assert(df1.showString(5) == df1.showString(5)) } test("SPARK-8609: local DataFrame with random columns should return same value after sort") { - // Make sure we can pass this test for both codegen mode and interpreted mode. - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { - checkAnswer(testData.sort(rand(33)), testData.sort(rand(33))) - } - - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { - checkAnswer(testData.sort(rand(33)), testData.sort(rand(33))) - } + checkAnswer(testData.sort(rand(33)), testData.sort(rand(33))) // We will reuse the same Expression object for LocalRelation. 
val df = (1 to 10).map(Tuple1.apply).toDF() @@ -884,4 +888,286 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { val actual = df.sort(rand(seed)).collect().map(_.getInt(0)) assert(expected === actual) } + + test("SPARK-9323: DataFrame.orderBy should support nested column name") { + val df = sqlContext.read.json(sparkContext.makeRDD( + """{"a": {"b": 1}}""" :: Nil)) + checkAnswer(df.orderBy("a.b"), Row(Row(1))) + } + + test("SPARK-9950: correctly analyze grouping/aggregating on struct fields") { + val df = Seq(("x", (1, 1)), ("y", (2, 2))).toDF("a", "b") + checkAnswer(df.groupBy("b._1").agg(sum("b._2")), Row(1, 1) :: Row(2, 2) :: Nil) + } + + test("SPARK-10093: Avoid transformations on executors") { + val df = Seq((1, 1)).toDF("a", "b") + df.where($"a" === 1) + .select($"a", $"b", struct($"b")) + .orderBy("a") + .select(struct($"b")) + .collect() + } + + test("SPARK-10185: Read multiple Hadoop Filesystem paths and paths with a comma in it") { + withTempDir { dir => + val df1 = Seq((1, 22)).toDF("a", "b") + val dir1 = new File(dir, "dir,1").getCanonicalPath + df1.write.format("json").save(dir1) + + val df2 = Seq((2, 23)).toDF("a", "b") + val dir2 = new File(dir, "dir2").getCanonicalPath + df2.write.format("json").save(dir2) + + checkAnswer(sqlContext.read.format("json").load(dir1, dir2), + Row(1, 22) :: Row(2, 23) :: Nil) + + checkAnswer(sqlContext.read.format("json").load(dir1), + Row(1, 22) :: Nil) + } + } + + test("SPARK-10034: Sort on Aggregate with aggregation expression named 'aggOrdering'") { + val df = Seq(1 -> 2).toDF("i", "j") + val query = df.groupBy('i) + .agg(max('j).as("aggOrdering")) + .orderBy(sum('j)) + checkAnswer(query, Row(1, 2)) + } + + test("SPARK-10316: respect non-deterministic expressions in PhysicalOperation") { + val input = sqlContext.read.json(sqlContext.sparkContext.makeRDD( + (1 to 10).map(i => s"""{"id": $i}"""))) + + val df = input.select($"id", rand(0).as('r)) + df.as("a").join(df.filter($"r" < 0.5).as("b"), $"a.id" === $"b.id").collect().foreach { row => + assert(row.getDouble(1) - row.getDouble(3) === 0.0 +- 0.001) + } + } + + test("SPARK-10539: Project should not be pushed down through Intersect or Except") { + val df1 = (1 to 100).map(Tuple1.apply).toDF("i") + val df2 = (1 to 30).map(Tuple1.apply).toDF("i") + val intersect = df1.intersect(df2) + val except = df1.except(df2) + assert(intersect.count() === 30) + assert(except.count() === 70) + } + + test("SPARK-10740: handle nondeterministic expressions correctly for set operations") { + val df1 = (1 to 20).map(Tuple1.apply).toDF("i") + val df2 = (1 to 10).map(Tuple1.apply).toDF("i") + + // When generating expected results at here, we need to follow the implementation of + // Rand expression. 
+ def expected(df: DataFrame): Seq[Row] = { + df.rdd.collectPartitions().zipWithIndex.flatMap { + case (data, index) => + val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index) + data.filter(_.getInt(0) < rng.nextDouble() * 10) + } + } + + val union = df1.unionAll(df2) + checkAnswer( + union.filter('i < rand(7) * 10), + expected(union) + ) + checkAnswer( + union.select(rand(7)), + union.rdd.collectPartitions().zipWithIndex.flatMap { + case (data, index) => + val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index) + data.map(_ => rng.nextDouble()).map(i => Row(i)) + } + ) + + val intersect = df1.intersect(df2) + checkAnswer( + intersect.filter('i < rand(7) * 10), + expected(intersect) + ) + + val except = df1.except(df2) + checkAnswer( + except.filter('i < rand(7) * 10), + expected(except) + ) + } + + test("SPARK-10743: keep the name of expression if possible when do cast") { + val df = (1 to 10).map(Tuple1.apply).toDF("i").as("src") + assert(df.select($"src.i".cast(StringType)).columns.head === "i") + } + + test("SPARK-11301: fix case sensitivity for filter on partitioned columns") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + withTempPath { path => + Seq(2012 -> "a").toDF("year", "val").write.partitionBy("year").parquet(path.getAbsolutePath) + val df = sqlContext.read.parquet(path.getAbsolutePath) + checkAnswer(df.filter($"yEAr" > 2000).select($"val"), Row("a")) + } + } + } + + /** + * Verifies that there is no Exchange between the Aggregations for `df` + */ + private def verifyNonExchangingAgg(df: DataFrame) = { + var atFirstAgg: Boolean = false + df.queryExecution.executedPlan.foreach { + case agg: TungstenAggregate => { + atFirstAgg = !atFirstAgg + } + case _ => { + if (atFirstAgg) { + fail("Should not have operators between the two aggregations") + } + } + } + } + + /** + * Verifies that there is an Exchange between the Aggregations for `df` + */ + private def verifyExchangingAgg(df: DataFrame) = { + var atFirstAgg: Boolean = false + df.queryExecution.executedPlan.foreach { + case agg: TungstenAggregate => { + if (atFirstAgg) { + fail("Should not have back to back Aggregates") + } + atFirstAgg = true + } + case e: Exchange => atFirstAgg = false + case _ => + } + } + + test("distributeBy and localSort") { + val original = testData.repartition(1) + assert(original.rdd.partitions.length == 1) + val df = original.repartition(5, $"key") + assert(df.rdd.partitions.length == 5) + checkAnswer(original.select(), df.select()) + + val df2 = original.repartition(10, $"key") + assert(df2.rdd.partitions.length == 10) + checkAnswer(original.select(), df2.select()) + + // Group by the column we are distributed by. This should generate a plan with no exchange + // between the aggregates + val df3 = testData.repartition($"key").groupBy("key").count() + verifyNonExchangingAgg(df3) + verifyNonExchangingAgg(testData.repartition($"key", $"value") + .groupBy("key", "value").count()) + + // Grouping by just the first distributeBy expr, need to exchange. + verifyExchangingAgg(testData.repartition($"key", $"value") + .groupBy("key").count()) + + val data = sqlContext.sparkContext.parallelize( + (1 to 100).map(i => TestData2(i % 10, i))).toDF() + + // Distribute and order by. + val df4 = data.repartition($"a").sortWithinPartitions($"b".desc) + // Walk each partition and verify that it is sorted descending and does not contain all + // the values. 
+ df4.rdd.foreachPartition { p => + var previousValue: Int = -1 + var allSequential: Boolean = true + p.foreach { r => + val v: Int = r.getInt(1) + if (previousValue != -1) { + if (previousValue < v) throw new SparkException("Partition is not ordered.") + if (v + 1 != previousValue) allSequential = false + } + previousValue = v + } + if (allSequential) throw new SparkException("Partition should not be globally ordered") + } + + // Distribute and order by with multiple order bys + val df5 = data.repartition(2, $"a").sortWithinPartitions($"b".asc, $"a".asc) + // Walk each partition and verify that it is sorted ascending + df5.rdd.foreachPartition { p => + var previousValue: Int = -1 + var allSequential: Boolean = true + p.foreach { r => + val v: Int = r.getInt(1) + if (previousValue != -1) { + if (previousValue > v) throw new SparkException("Partition is not ordered.") + if (v - 1 != previousValue) allSequential = false + } + previousValue = v + } + if (allSequential) throw new SparkException("Partition should not be all sequential") + } + + // Distribute into one partition and order by. This partition should contain all the values. + val df6 = data.repartition(1, $"a").sortWithinPartitions("b") + // Walk each partition and verify that it is sorted ascending and not globally sorted. + df6.rdd.foreachPartition { p => + var previousValue: Int = -1 + var allSequential: Boolean = true + p.foreach { r => + val v: Int = r.getInt(1) + if (previousValue != -1) { + if (previousValue > v) throw new SparkException("Partition is not ordered.") + if (v - 1 != previousValue) allSequential = false + } + previousValue = v + } + if (!allSequential) throw new SparkException("Partition should contain all sequential values") + } + } + + test("fix case sensitivity of partition by") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + withTempPath { path => + val p = path.getAbsolutePath + Seq(2012 -> "a").toDF("year", "val").write.partitionBy("yEAr").parquet(p) + checkAnswer(sqlContext.read.parquet(p).select("YeaR"), Row(2012)) + } + } + } + + // This test case is to verify a bug when making a new instance of LogicalRDD. 
+ test("SPARK-11633: LogicalRDD throws TreeNode Exception: Failed to Copy Node") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + val rdd = sparkContext.makeRDD(Seq(Row(1, 3), Row(2, 1))) + val df = sqlContext.createDataFrame( + rdd, + new StructType().add("f1", IntegerType).add("f2", IntegerType), + needsConversion = false).select($"F1", $"f2".as("f2")) + val df1 = df.as("a") + val df2 = df.as("b") + checkAnswer(df1.join(df2, $"a.f2" === $"b.f2"), Row(1, 3, 1, 3) :: Row(2, 1, 2, 1) :: Nil) + } + } + + test("SPARK-10656: completely support special chars") { + val df = Seq(1 -> "a").toDF("i_$.a", "d^'a.") + checkAnswer(df.select(df("*")), Row(1, "a")) + checkAnswer(df.withColumnRenamed("d^'a.", "a"), Row(1, "a")) + } + + test("SPARK-11725: correctly handle null inputs for ScalaUDF") { + val df = sparkContext.parallelize(Seq( + new java.lang.Integer(22) -> "John", + null.asInstanceOf[java.lang.Integer] -> "Lucy")).toDF("age", "name") + + // passing null into the UDF that could handle it + val boxedUDF = udf[java.lang.Integer, java.lang.Integer] { + (i: java.lang.Integer) => if (i == null) -10 else null + } + checkAnswer(df.select(boxedUDF($"age")), Row(null) :: Row(-10) :: Nil) + + sqlContext.udf.register("boxedUDF", + (i: java.lang.Integer) => (if (i == null) -10 else null): java.lang.Integer) + checkAnswer(sql("select boxedUDF(null), boxedUDF(-1)"), Row(-10, null) :: Nil) + + val primitiveUDF = udf((i: Int) => i * 2) + checkAnswer(df.select(primitiveUDF($"age")), Row(44) :: Row(null) :: Nil) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala index bf8ef9a97bc6..68e99d6a6b81 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ /** @@ -27,58 +27,50 @@ import org.apache.spark.sql.types._ * This is here for now so I can make sure Tungsten project is tested without refactoring existing * end-to-end test infra. In the long run this should just go away. 
*/ -class DataFrameTungstenSuite extends QueryTest with SQLTestUtils { - - override lazy val sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.implicits._ +class DataFrameTungstenSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("test simple types") { - withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { - val df = sqlContext.sparkContext.parallelize(Seq((1, 2))).toDF("a", "b") - assert(df.select(struct("a", "b")).first().getStruct(0) === Row(1, 2)) - } + val df = sparkContext.parallelize(Seq((1, 2))).toDF("a", "b") + assert(df.select(struct("a", "b")).first().getStruct(0) === Row(1, 2)) } test("test struct type") { - withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { - val struct = Row(1, 2L, 3.0F, 3.0) - val data = sqlContext.sparkContext.parallelize(Seq(Row(1, struct))) + val struct = Row(1, 2L, 3.0F, 3.0) + val data = sparkContext.parallelize(Seq(Row(1, struct))) - val schema = new StructType() - .add("a", IntegerType) - .add("b", - new StructType() - .add("b1", IntegerType) - .add("b2", LongType) - .add("b3", FloatType) - .add("b4", DoubleType)) + val schema = new StructType() + .add("a", IntegerType) + .add("b", + new StructType() + .add("b1", IntegerType) + .add("b2", LongType) + .add("b3", FloatType) + .add("b4", DoubleType)) - val df = sqlContext.createDataFrame(data, schema) - assert(df.select("b").first() === Row(struct)) - } + val df = sqlContext.createDataFrame(data, schema) + assert(df.select("b").first() === Row(struct)) } test("test nested struct type") { - withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { - val innerStruct = Row(1, "abcd") - val outerStruct = Row(1, 2L, 3.0F, 3.0, innerStruct, "efg") - val data = sqlContext.sparkContext.parallelize(Seq(Row(1, outerStruct))) + val innerStruct = Row(1, "abcd") + val outerStruct = Row(1, 2L, 3.0F, 3.0, innerStruct, "efg") + val data = sparkContext.parallelize(Seq(Row(1, outerStruct))) - val schema = new StructType() - .add("a", IntegerType) - .add("b", - new StructType() - .add("b1", IntegerType) - .add("b2", LongType) - .add("b3", FloatType) - .add("b4", DoubleType) - .add("b5", new StructType() - .add("b5a", IntegerType) - .add("b5b", StringType)) - .add("b6", StringType)) + val schema = new StructType() + .add("a", IntegerType) + .add("b", + new StructType() + .add("b1", IntegerType) + .add("b2", LongType) + .add("b3", FloatType) + .add("b4", DoubleType) + .add("b5", new StructType() + .add("b5a", IntegerType) + .add("b5b", StringType)) + .add("b6", StringType)) - val df = sqlContext.createDataFrame(data, schema) - assert(df.select("b").first() === Row(outerStruct)) - } + val df = sqlContext.createDataFrame(data, schema) + assert(df.select("b").first() === Row(outerStruct)) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala similarity index 57% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala index c177cbdd991c..b50d7604e0ec 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala @@ -15,15 +15,15 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.hive +package org.apache.spark.sql -import org.apache.spark.sql.{Row, QueryTest} -import org.apache.spark.sql.expressions.Window +import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.types.{DataType, LongType, StructType} -class HiveDataFrameWindowSuite extends QueryTest { +class DataFrameWindowSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("reuse window partitionBy") { val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") @@ -54,10 +54,7 @@ class HiveDataFrameWindowSuite extends QueryTest { checkAnswer( df.select( lead("value", 1).over(Window.partitionBy($"key").orderBy($"value"))), - sql( - """SELECT - | lead(value) OVER (PARTITION BY key ORDER BY value) - | FROM window_table""".stripMargin).collect()) + Row("1") :: Row(null) :: Row("2") :: Row(null) :: Nil) } test("lag") { @@ -67,10 +64,7 @@ class HiveDataFrameWindowSuite extends QueryTest { checkAnswer( df.select( lag("value", 1).over(Window.partitionBy($"key").orderBy($"value"))), - sql( - """SELECT - | lag(value) OVER (PARTITION BY key ORDER BY value) - | FROM window_table""".stripMargin).collect()) + Row(null) :: Row("1") :: Row(null) :: Row("2") :: Nil) } test("lead with default value") { @@ -80,10 +74,7 @@ class HiveDataFrameWindowSuite extends QueryTest { checkAnswer( df.select( lead("value", 2, "n/a").over(Window.partitionBy("key").orderBy("value"))), - sql( - """SELECT - | lead(value, 2, "n/a") OVER (PARTITION BY key ORDER BY value) - | FROM window_table""".stripMargin).collect()) + Seq(Row("1"), Row("1"), Row("n/a"), Row("n/a"), Row("2"), Row("n/a"), Row("n/a"))) } test("lag with default value") { @@ -93,10 +84,7 @@ class HiveDataFrameWindowSuite extends QueryTest { checkAnswer( df.select( lag("value", 2, "n/a").over(Window.partitionBy($"key").orderBy($"value"))), - sql( - """SELECT - | lag(value, 2, "n/a") OVER (PARTITION BY key ORDER BY value) - | FROM window_table""".stripMargin).collect()) + Seq(Row("n/a"), Row("n/a"), Row("1"), Row("1"), Row("n/a"), Row("n/a"), Row("2"))) } test("rank functions in unspecific window") { @@ -111,78 +99,52 @@ class HiveDataFrameWindowSuite extends QueryTest { count("key").over(Window.partitionBy("value").orderBy("key")), sum("key").over(Window.partitionBy("value").orderBy("key")), ntile(2).over(Window.partitionBy("value").orderBy("key")), - rowNumber().over(Window.partitionBy("value").orderBy("key")), - denseRank().over(Window.partitionBy("value").orderBy("key")), + row_number().over(Window.partitionBy("value").orderBy("key")), + dense_rank().over(Window.partitionBy("value").orderBy("key")), rank().over(Window.partitionBy("value").orderBy("key")), - cumeDist().over(Window.partitionBy("value").orderBy("key")), - percentRank().over(Window.partitionBy("value").orderBy("key"))), - sql( - s"""SELECT - |key, - |max(key) over (partition by value order by key), - |min(key) over (partition by value order by key), - |avg(key) over (partition by value order by key), - |count(key) over (partition by value order by key), - |sum(key) over (partition by value order by key), - |ntile(2) over (partition by value order by key), - |row_number() over (partition by value order by key), - |dense_rank() over (partition by value order by key), - |rank() 
over (partition by value order by key), - |cume_dist() over (partition by value order by key), - |percent_rank() over (partition by value order by key) - |FROM window_table""".stripMargin).collect()) + cume_dist().over(Window.partitionBy("value").orderBy("key")), + percent_rank().over(Window.partitionBy("value").orderBy("key"))), + Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d, 0.0d) :: + Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d / 3.0d, 0.0d) :: + Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 1, 2, 2, 2, 1.0d, 0.5d) :: + Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 2, 3, 2, 2, 1.0d, 0.5d) :: Nil) } test("aggregation and rows between") { - val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + val df = Seq((1, "1"), (2, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") df.registerTempTable("window_table") checkAnswer( df.select( avg("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 2))), - sql( - """SELECT - | avg(key) OVER - | (PARTITION BY value ORDER BY key ROWS BETWEEN 1 preceding and 2 following) - | FROM window_table""".stripMargin).collect()) + Seq(Row(4.0d / 3.0d), Row(4.0d / 3.0d), Row(3.0d / 2.0d), Row(2.0d), Row(2.0d))) } - test("aggregation and range betweens") { - val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + test("aggregation and range between") { + val df = Seq((1, "1"), (1, "1"), (3, "1"), (2, "2"), (2, "1"), (2, "2")).toDF("key", "value") df.registerTempTable("window_table") checkAnswer( df.select( avg("key").over(Window.partitionBy($"value").orderBy($"key").rangeBetween(-1, 1))), - sql( - """SELECT - | avg(key) OVER - | (PARTITION BY value ORDER BY key RANGE BETWEEN 1 preceding and 1 following) - | FROM window_table""".stripMargin).collect()) + Seq(Row(4.0d / 3.0d), Row(4.0d / 3.0d), Row(7.0d / 4.0d), Row(5.0d / 2.0d), + Row(2.0d), Row(2.0d))) } - test("aggregation and rows betweens with unbounded") { + test("aggregation and rows between with unbounded") { val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value") df.registerTempTable("window_table") checkAnswer( df.select( $"key", - last("value").over( + last("key").over( Window.partitionBy($"value").orderBy($"key").rowsBetween(0, Long.MaxValue)), - last("value").over( + last("key").over( Window.partitionBy($"value").orderBy($"key").rowsBetween(Long.MinValue, 0)), - last("value").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 3))), - sql( - """SELECT - | key, - | last_value(value) OVER - | (PARTITION BY value ORDER BY key ROWS between current row and unbounded following), - | last_value(value) OVER - | (PARTITION BY value ORDER BY key ROWS between unbounded preceding and current row), - | last_value(value) OVER - | (PARTITION BY value ORDER BY key ROWS between 1 preceding and 3 following) - | FROM window_table""".stripMargin).collect()) + last("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 1))), + Seq(Row(1, 1, 1, 1), Row(2, 3, 2, 3), Row(3, 3, 3, 3), Row(1, 4, 1, 2), Row(2, 4, 2, 4), + Row(4, 4, 4, 4))) } - test("aggregation and range betweens with unbounded") { + test("aggregation and range between with unbounded") { val df = Seq((5, "1"), (5, "2"), (4, "2"), (6, "2"), (3, "1"), (2, "2")).toDF("key", "value") df.registerTempTable("window_table") checkAnswer( @@ -199,18 +161,12 @@ class HiveDataFrameWindowSuite extends QueryTest { avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(-1, 0)) .as("avg_key3") ), - sql( - """SELECT - | key, - | last_value(value) OVER - | 
(PARTITION BY value ORDER BY key RANGE BETWEEN 2 preceding and 1 preceding) == "2", - | avg(key) OVER - | (PARTITION BY value ORDER BY key RANGE BETWEEN unbounded preceding and 1 following), - | avg(key) OVER - | (PARTITION BY value ORDER BY key RANGE BETWEEN current row and unbounded following), - | avg(key) OVER - | (PARTITION BY value ORDER BY key RANGE BETWEEN 1 preceding and current row) - | FROM window_table""".stripMargin).collect()) + Seq(Row(3, null, 3.0d, 4.0d, 3.0d), + Row(5, false, 4.0d, 5.0d, 5.0d), + Row(2, null, 2.0d, 17.0d / 4.0d, 2.0d), + Row(4, true, 11.0d / 3.0d, 5.0d, 4.0d), + Row(5, true, 17.0d / 4.0d, 11.0d / 2.0d, 4.5d), + Row(6, true, 17.0d / 4.0d, 6.0d, 11.0d / 2.0d))) } test("reverse sliding range frame") { @@ -253,6 +209,87 @@ class HiveDataFrameWindowSuite extends QueryTest { sum($"value").over(window.rangeBetween(1, Long.MaxValue))), Row(1, 13, null) :: Row(2, 13, 2) :: Row(4, 7, 9) :: Row(3, 11, 6) :: Row(2, 13, 2) :: Row(1, 13, null) :: Nil) + } + + test("statistical functions") { + val df = Seq(("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2)). + toDF("key", "value") + val window = Window.partitionBy($"key") + checkAnswer( + df.select( + $"key", + var_pop($"value").over(window), + var_samp($"value").over(window), + approxCountDistinct($"value").over(window)), + Seq.fill(4)(Row("a", 1.0d / 4.0d, 1.0d / 3.0d, 2)) + ++ Seq.fill(3)(Row("b", 2.0d / 3.0d, 1.0d, 3))) + } + + test("window function with aggregates") { + val df = Seq(("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2)). + toDF("key", "value") + val window = Window.orderBy() + checkAnswer( + df.groupBy($"key") + .agg( + sum($"value"), + sum(sum($"value")).over(window) - sum($"value")), + Seq(Row("a", 6, 9), Row("b", 9, 6))) + } + + test("window function with udaf") { + val udaf = new UserDefinedAggregateFunction { + def inputSchema: StructType = new StructType() + .add("a", LongType) + .add("b", LongType) + + def bufferSchema: StructType = new StructType() + .add("product", LongType) + + def dataType: DataType = LongType + + def deterministic: Boolean = true + + def initialize(buffer: MutableAggregationBuffer): Unit = { + buffer(0) = 0L + } + def update(buffer: MutableAggregationBuffer, input: Row): Unit = { + if (!(input.isNullAt(0) || input.isNullAt(1))) { + buffer(0) = buffer.getLong(0) + input.getLong(0) * input.getLong(1) + } + } + + def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { + buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0) + } + + def evaluate(buffer: Row): Any = + buffer.getLong(0) + } + val df = Seq( + ("a", 1, 1), + ("a", 1, 5), + ("a", 2, 10), + ("a", 2, -1), + ("b", 4, 7), + ("b", 3, 8), + ("b", 2, 4)) + .toDF("key", "a", "b") + val window = Window.partitionBy($"key").orderBy($"a").rangeBetween(Long.MinValue, 0L) + checkAnswer( + df.select( + $"key", + $"a", + $"b", + udaf($"a", $"b").over(window)), + Seq( + Row("a", 1, 1, 6), + Row("a", 1, 5, 6), + Row("a", 2, 10, 24), + Row("a", 2, -1, 24), + Row("b", 4, 7, 60), + Row("b", 3, 8, 32), + Row("b", 2, 4, 8))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala new file mode 100644 index 000000000000..c6d2bf07b280 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + + +import scala.language.postfixOps + +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.expressions.Aggregator + +/** An `Aggregator` that adds up any numeric type returned by the given function. */ +class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] { + val numeric = implicitly[Numeric[N]] + + override def zero: N = numeric.zero + + override def reduce(b: N, a: I): N = numeric.plus(b, f(a)) + + override def merge(b1: N, b2: N): N = numeric.plus(b1, b2) + + override def finish(reduction: N): N = reduction +} + +object TypedAverage extends Aggregator[(String, Int), (Long, Long), Double] { + override def zero: (Long, Long) = (0, 0) + + override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = { + (countAndSum._1 + 1, countAndSum._2 + input._2) + } + + override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = { + (b1._1 + b2._1, b1._2 + b2._2) + } + + override def finish(countAndSum: (Long, Long)): Double = countAndSum._2 / countAndSum._1 +} + +object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] { + + override def zero: (Long, Long) = (0, 0) + + override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = { + (countAndSum._1 + 1, countAndSum._2 + input._2) + } + + override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = { + (b1._1 + b2._1, b1._2 + b2._2) + } + + override def finish(reduction: (Long, Long)): (Long, Long) = reduction +} + +case class AggData(a: Int, b: String) +object ClassInputAgg extends Aggregator[AggData, Int, Int] { + /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */ + override def zero: Int = 0 + + /** + * Combine two values to produce a new value. For performance, the function may modify `b` and + * return it instead of constructing new object for b. + */ + override def reduce(b: Int, a: AggData): Int = b + a.a + + /** + * Transform the output of the reduction. + */ + override def finish(reduction: Int): Int = reduction + + /** + * Merge two intermediate values + */ + override def merge(b1: Int, b2: Int): Int = b1 + b2 +} + +object ComplexBufferAgg extends Aggregator[AggData, (Int, AggData), Int] { + /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */ + override def zero: (Int, AggData) = 0 -> AggData(0, "0") + + /** + * Combine two values to produce a new value. For performance, the function may modify `b` and + * return it instead of constructing new object for b. + */ + override def reduce(b: (Int, AggData), a: AggData): (Int, AggData) = (b._1 + 1, a) + + /** + * Transform the output of the reduction. 
+ */ + override def finish(reduction: (Int, AggData)): Int = reduction._1 + + /** + * Merge two intermediate values + */ + override def merge(b1: (Int, AggData), b2: (Int, AggData)): (Int, AggData) = + (b1._1 + b2._1, b1._2) +} + +class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { + + import testImplicits._ + + def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = + new SumOf(f).toColumn + + test("typed aggregation: TypedAggregator") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkAnswer( + ds.groupBy(_._1).agg(sum(_._2)), + ("a", 30), ("b", 3), ("c", 1)) + } + + test("typed aggregation: TypedAggregator, expr, expr") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkAnswer( + ds.groupBy(_._1).agg( + sum(_._2), + expr("sum(_2)").as[Long], + count("*")), + ("a", 30, 30L, 2L), ("b", 3, 3L, 2L), ("c", 1, 1L, 1L)) + } + + test("typed aggregation: complex case") { + val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS() + + checkAnswer( + ds.groupBy(_._1).agg( + expr("avg(_2)").as[Double], + TypedAverage.toColumn), + ("a", 2.0, 2.0), ("b", 3.0, 3.0)) + } + + test("typed aggregation: complex result type") { + val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS() + + checkAnswer( + ds.groupBy(_._1).agg( + expr("avg(_2)").as[Double], + ComplexResultAgg.toColumn), + ("a", 2.0, (2L, 4L)), ("b", 3.0, (1L, 3L))) + } + + test("typed aggregation: in project list") { + val ds = Seq(1, 3, 2, 5).toDS() + + checkAnswer( + ds.select(sum((i: Int) => i)), + 11) + checkAnswer( + ds.select(sum((i: Int) => i), sum((i: Int) => i * 2)), + 11 -> 22) + } + + test("typed aggregation: class input") { + val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS() + + checkAnswer( + ds.select(ClassInputAgg.toColumn), + 3) + } + + test("typed aggregation: class input with reordering") { + val ds = sql("SELECT 'one' AS b, 1 as a").as[AggData] + + checkAnswer( + ds.select(ClassInputAgg.toColumn), + 1) + + checkAnswer( + ds.select(expr("avg(a)").as[Double], ClassInputAgg.toColumn), + (1.0, 1)) + + checkAnswer( + ds.groupBy(_.b).agg(ClassInputAgg.toColumn), + ("one", 1)) + } + + test("typed aggregation: complex input") { + val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS() + + checkAnswer( + ds.select(ComplexBufferAgg.toColumn), + 2 + ) + + checkAnswer( + ds.select(expr("avg(a)").as[Double], ComplexBufferAgg.toColumn), + (1.5, 2)) + + checkAnswer( + ds.groupBy(_.b).agg(ComplexBufferAgg.toColumn), + ("one", 1), ("two", 1)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala new file mode 100644 index 000000000000..3a283a4e1f61 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.language.postfixOps + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext + + +class DatasetCacheSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("persist and unpersist") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS().select(expr("_2 + 1").as[Int]) + val cached = ds.cache() + // count triggers the caching action. It should not throw. + cached.count() + // Make sure, the Dataset is indeed cached. + assertCached(cached) + // Check result. + checkAnswer( + cached, + 2, 3, 4) + // Drop the cache. + cached.unpersist() + assert(!sqlContext.isCached(cached), "The Dataset should not be cached.") + } + + test("persist and then rebind right encoder when join 2 datasets") { + val ds1 = Seq("1", "2").toDS().as("a") + val ds2 = Seq(2, 3).toDS().as("b") + + ds1.persist() + assertCached(ds1) + ds2.persist() + assertCached(ds2) + + val joined = ds1.joinWith(ds2, $"a.value" === $"b.value") + checkAnswer(joined, ("2", 2)) + assertCached(joined, 2) + + ds1.unpersist() + assert(!sqlContext.isCached(ds1), "The Dataset ds1 should not be cached.") + ds2.unpersist() + assert(!sqlContext.isCached(ds2), "The Dataset ds2 should not be cached.") + } + + test("persist and then groupBy columns asKey, map") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + val grouped = ds.groupBy($"_1").keyAs[String] + val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) } + agged.persist() + + checkAnswer( + agged.filter(_._1 == "b"), + ("b", 3)) + assertCached(agged.filter(_._1 == "b")) + + ds.unpersist() + assert(!sqlContext.isCached(ds), "The Dataset ds should not be cached.") + agged.unpersist() + assert(!sqlContext.isCached(agged), "The Dataset agged should not be cached.") + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala new file mode 100644 index 000000000000..f75d0961823c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import scala.language.postfixOps + +import org.apache.spark.sql.test.SharedSQLContext + +case class IntClass(value: Int) + +class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("toDS") { + val data = Seq(1, 2, 3, 4, 5, 6) + checkAnswer( + data.toDS(), + data: _*) + } + + test("as case class / collect") { + val ds = Seq(1, 2, 3).toDS().as[IntClass] + checkAnswer( + ds, + IntClass(1), IntClass(2), IntClass(3)) + + assert(ds.collect().head == IntClass(1)) + } + + test("map") { + val ds = Seq(1, 2, 3).toDS() + checkAnswer( + ds.map(_ + 1), + 2, 3, 4) + } + + test("filter") { + val ds = Seq(1, 2, 3, 4).toDS() + checkAnswer( + ds.filter(_ % 2 == 0), + 2, 4) + } + + test("foreach") { + val ds = Seq(1, 2, 3).toDS() + val acc = sparkContext.accumulator(0) + ds.foreach(acc += _) + assert(acc.value == 6) + } + + test("foreachPartition") { + val ds = Seq(1, 2, 3).toDS() + val acc = sparkContext.accumulator(0) + ds.foreachPartition(_.foreach(acc +=)) + assert(acc.value == 6) + } + + test("reduce") { + val ds = Seq(1, 2, 3).toDS() + assert(ds.reduce(_ + _) == 6) + } + + test("groupBy function, keys") { + val ds = Seq(1, 2, 3, 4, 5).toDS() + val grouped = ds.groupBy(_ % 2) + checkAnswer( + grouped.keys, + 0, 1) + } + + test("groupBy function, map") { + val ds = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11).toDS() + val grouped = ds.groupBy(_ % 2) + val agged = grouped.mapGroups { case (g, iter) => + val name = if (g == 0) "even" else "odd" + (name, iter.size) + } + + checkAnswer( + agged, + ("even", 5), ("odd", 6)) + } + + test("groupBy function, flatMap") { + val ds = Seq("a", "b", "c", "xyz", "hello").toDS() + val grouped = ds.groupBy(_.length) + val agged = grouped.flatMapGroups { case (g, iter) => Iterator(g.toString, iter.mkString) } + + checkAnswer( + agged, + "1", "abc", "3", "xyz", "5", "hello") + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala new file mode 100644 index 000000000000..f1b6b98dc160 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -0,0 +1,550 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import java.io.{ObjectInput, ObjectOutput, Externalizable} + +import scala.language.postfixOps + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext + + +class DatasetSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("toDS") { + val data = Seq(("a", 1) , ("b", 2), ("c", 3)) + checkAnswer( + data.toDS(), + data: _*) + } + + test("toDS with RDD") { + val ds = sparkContext.makeRDD(Seq("a", "b", "c"), 3).toDS() + checkAnswer( + ds.mapPartitions(_ => Iterator(1)), + 1, 1, 1) + } + + test("collect, first, and take should use encoders for serialization") { + val item = NonSerializableCaseClass("abcd") + val ds = Seq(item).toDS() + assert(ds.collect().head == item) + assert(ds.collectAsList().get(0) == item) + assert(ds.first() == item) + assert(ds.take(1).head == item) + assert(ds.takeAsList(1).get(0) == item) + } + + test("coalesce, repartition") { + val data = (1 to 100).map(i => ClassData(i.toString, i)) + val ds = data.toDS() + + assert(ds.repartition(10).rdd.partitions.length == 10) + checkAnswer( + ds.repartition(10), + data: _*) + + assert(ds.coalesce(1).rdd.partitions.length == 1) + checkAnswer( + ds.coalesce(1), + data: _*) + } + + test("as tuple") { + val data = Seq(("a", 1), ("b", 2)).toDF("a", "b") + checkAnswer( + data.as[(String, Int)], + ("a", 1), ("b", 2)) + } + + test("as case class / collect") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDF("a", "b").as[ClassData] + checkAnswer( + ds, + ClassData("a", 1), ClassData("b", 2), ClassData("c", 3)) + assert(ds.collect().head == ClassData("a", 1)) + } + + test("as case class - reordered fields by name") { + val ds = Seq((1, "a"), (2, "b"), (3, "c")).toDF("b", "a").as[ClassData] + assert(ds.collect() === Array(ClassData("a", 1), ClassData("b", 2), ClassData("c", 3))) + } + + test("as case class - take") { + val ds = Seq((1, "a"), (2, "b"), (3, "c")).toDF("b", "a").as[ClassData] + assert(ds.take(2) === Array(ClassData("a", 1), ClassData("b", 2))) + } + + test("map") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + checkAnswer( + ds.map(v => (v._1, v._2 + 1)), + ("a", 2), ("b", 3), ("c", 4)) + } + + test("map and group by with class data") { + // We inject a group by here to make sure this test case is future proof + // when we implement better pipelining and local execution mode. 
+ val ds: Dataset[(ClassData, Long)] = Seq(ClassData("one", 1), ClassData("two", 2)).toDS() + .map(c => ClassData(c.a, c.b + 1)) + .groupBy(p => p).count() + + checkAnswer( + ds, + (ClassData("one", 2), 1L), (ClassData("two", 3), 1L)) + } + + test("select") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + checkAnswer( + ds.select(expr("_2 + 1").as[Int]), + 2, 3, 4) + } + + test("select 2") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + checkAnswer( + ds.select( + expr("_1").as[String], + expr("_2").as[Int]) : Dataset[(String, Int)], + ("a", 1), ("b", 2), ("c", 3)) + } + + test("select 2, primitive and tuple") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + checkAnswer( + ds.select( + expr("_1").as[String], + expr("struct(_2, _2)").as[(Int, Int)]), + ("a", (1, 1)), ("b", (2, 2)), ("c", (3, 3))) + } + + test("select 2, primitive and class") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + checkAnswer( + ds.select( + expr("_1").as[String], + expr("named_struct('a', _1, 'b', _2)").as[ClassData]), + ("a", ClassData("a", 1)), ("b", ClassData("b", 2)), ("c", ClassData("c", 3))) + } + + test("select 2, primitive and class, fields reordered") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + checkDecoding( + ds.select( + expr("_1").as[String], + expr("named_struct('b', _2, 'a', _1)").as[ClassData]), + ("a", ClassData("a", 1)), ("b", ClassData("b", 2)), ("c", ClassData("c", 3))) + } + + test("filter") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + checkAnswer( + ds.filter(_._1 == "b"), + ("b", 2)) + } + + test("foreach") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + val acc = sparkContext.accumulator(0) + ds.foreach(v => acc += v._2) + assert(acc.value == 6) + } + + test("foreachPartition") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + val acc = sparkContext.accumulator(0) + ds.foreachPartition(_.foreach(v => acc += v._2)) + assert(acc.value == 6) + } + + test("reduce") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + assert(ds.reduce((a, b) => ("sum", a._2 + b._2)) == ("sum", 6)) + } + + test("joinWith, flat schema") { + val ds1 = Seq(1, 2, 3).toDS().as("a") + val ds2 = Seq(1, 2).toDS().as("b") + + checkAnswer( + ds1.joinWith(ds2, $"a.value" === $"b.value", "inner"), + (1, 1), (2, 2)) + } + + test("joinWith, expression condition, outer join") { + val nullInteger = null.asInstanceOf[Integer] + val nullString = null.asInstanceOf[String] + val ds1 = Seq(ClassNullableData("a", 1), + ClassNullableData("c", 3)).toDS() + val ds2 = Seq(("a", new Integer(1)), + ("b", new Integer(2))).toDS() + + checkAnswer( + ds1.joinWith(ds2, $"_1" === $"a", "outer"), + (ClassNullableData("a", 1), ("a", new Integer(1))), + (ClassNullableData("c", 3), (nullString, nullInteger)), + (ClassNullableData(nullString, nullInteger), ("b", new Integer(2)))) + } + + test("joinWith tuple with primitive, expression") { + val ds1 = Seq(1, 1, 2).toDS() + val ds2 = Seq(("a", 1), ("b", 2)).toDS() + + checkAnswer( + ds1.joinWith(ds2, $"value" === $"_2"), + (1, ("a", 1)), (1, ("a", 1)), (2, ("b", 2))) + } + + test("joinWith class with primitive, toDF") { + val ds1 = Seq(1, 1, 2).toDS() + val ds2 = Seq(ClassData("a", 1), ClassData("b", 2)).toDS() + + checkAnswer( + ds1.joinWith(ds2, $"value" === $"b").toDF().select($"_1", $"_2.a", $"_2.b"), + Row(1, "a", 1) :: Row(1, "a", 1) :: Row(2, "b", 2) :: Nil) + } + + test("multi-level joinWith") { + val ds1 = Seq(("a", 1), ("b", 2)).toDS().as("a") + val ds2 = Seq(("a", 1), ("b", 2)).toDS().as("b") + val ds3 = Seq(("a", 
1), ("b", 2)).toDS().as("c") + + checkAnswer( + ds1.joinWith(ds2, $"a._2" === $"b._2").as("ab").joinWith(ds3, $"ab._1._2" === $"c._2"), + ((("a", 1), ("a", 1)), ("a", 1)), + ((("b", 2), ("b", 2)), ("b", 2))) + } + + test("groupBy function, keys") { + val ds = Seq(("a", 1), ("b", 1)).toDS() + val grouped = ds.groupBy(v => (1, v._2)) + checkAnswer( + grouped.keys, + (1, 1)) + } + + test("groupBy function, map") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + val grouped = ds.groupBy(v => (v._1, "word")) + val agged = grouped.mapGroups { case (g, iter) => (g._1, iter.map(_._2).sum) } + + checkAnswer( + agged, + ("a", 30), ("b", 3), ("c", 1)) + } + + test("groupBy function, flatMap") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + val grouped = ds.groupBy(v => (v._1, "word")) + val agged = grouped.flatMapGroups { case (g, iter) => + Iterator(g._1, iter.map(_._2).sum.toString) + } + + checkAnswer( + agged, + "a", "30", "b", "3", "c", "1") + } + + test("groupBy function, reduce") { + val ds = Seq("abc", "xyz", "hello").toDS() + val agged = ds.groupBy(_.length).reduce(_ + _) + + checkAnswer( + agged, + 3 -> "abcxyz", 5 -> "hello") + } + + test("groupBy single field class, count") { + val ds = Seq("abc", "xyz", "hello").toDS() + val count = ds.groupBy(s => Tuple1(s.length)).count() + + checkAnswer( + count, + (Tuple1(3), 2L), (Tuple1(5), 1L) + ) + } + + test("groupBy columns, map") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + val grouped = ds.groupBy($"_1") + val agged = grouped.mapGroups { case (g, iter) => (g.getString(0), iter.map(_._2).sum) } + + checkAnswer( + agged, + ("a", 30), ("b", 3), ("c", 1)) + } + + test("groupBy columns, count") { + val ds = Seq("a" -> 1, "b" -> 1, "a" -> 2).toDS() + val count = ds.groupBy($"_1").count() + + checkAnswer( + count, + (Row("a"), 2L), (Row("b"), 1L)) + } + + test("groupBy columns asKey, map") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + val grouped = ds.groupBy($"_1").keyAs[String] + val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) } + + checkAnswer( + agged, + ("a", 30), ("b", 3), ("c", 1)) + } + + test("groupBy columns asKey tuple, map") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + val grouped = ds.groupBy($"_1", lit(1)).keyAs[(String, Int)] + val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) } + + checkAnswer( + agged, + (("a", 1), 30), (("b", 1), 3), (("c", 1), 1)) + } + + test("groupBy columns asKey class, map") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + val grouped = ds.groupBy($"_1".as("a"), lit(1).as("b")).keyAs[ClassData] + val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) } + + checkAnswer( + agged, + (ClassData("a", 1), 30), (ClassData("b", 1), 3), (ClassData("c", 1), 1)) + } + + test("typed aggregation: expr") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkAnswer( + ds.groupBy(_._1).agg(sum("_2").as[Long]), + ("a", 30L), ("b", 3L), ("c", 1L)) + } + + test("typed aggregation: expr, expr") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkAnswer( + ds.groupBy(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long]), + ("a", 30L, 32L), ("b", 3L, 5L), ("c", 1L, 2L)) + } + + test("typed aggregation: expr, expr, expr") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkAnswer( + 
ds.groupBy(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long], count("*")), + ("a", 30L, 32L, 2L), ("b", 3L, 5L, 2L), ("c", 1L, 2L, 1L)) + } + + test("typed aggregation: expr, expr, expr, expr") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkAnswer( + ds.groupBy(_._1).agg( + sum("_2").as[Long], + sum($"_2" + 1).as[Long], + count("*").as[Long], + avg("_2").as[Double]), + ("a", 30L, 32L, 2L, 15.0), ("b", 3L, 5L, 2L, 1.5), ("c", 1L, 2L, 1L, 1.0)) + } + + test("cogroup") { + val ds1 = Seq(1 -> "a", 3 -> "abc", 5 -> "hello", 3 -> "foo").toDS() + val ds2 = Seq(2 -> "q", 3 -> "w", 5 -> "e", 5 -> "r").toDS() + val cogrouped = ds1.groupBy(_._1).cogroup(ds2.groupBy(_._1)) { case (key, data1, data2) => + Iterator(key -> (data1.map(_._2).mkString + "#" + data2.map(_._2).mkString)) + } + + checkAnswer( + cogrouped, + 1 -> "a#", 2 -> "#q", 3 -> "abcfoo#w", 5 -> "hello#er") + } + + test("cogroup with complex data") { + val ds1 = Seq(1 -> ClassData("a", 1), 2 -> ClassData("b", 2)).toDS() + val ds2 = Seq(2 -> ClassData("c", 3), 3 -> ClassData("d", 4)).toDS() + val cogrouped = ds1.groupBy(_._1).cogroup(ds2.groupBy(_._1)) { case (key, data1, data2) => + Iterator(key -> (data1.map(_._2.a).mkString + data2.map(_._2.a).mkString)) + } + + checkAnswer( + cogrouped, + 1 -> "a", 2 -> "bc", 3 -> "d") + } + + test("sample with replacement") { + val n = 100 + val data = sparkContext.parallelize(1 to n, 2).toDS() + checkAnswer( + data.sample(withReplacement = true, 0.05, seed = 13), + 5, 10, 52, 73) + } + + test("sample without replacement") { + val n = 100 + val data = sparkContext.parallelize(1 to n, 2).toDS() + checkAnswer( + data.sample(withReplacement = false, 0.05, seed = 13), + 3, 17, 27, 58, 62) + } + + test("SPARK-11436: we should rebind right encoder when join 2 datasets") { + val ds1 = Seq("1", "2").toDS().as("a") + val ds2 = Seq(2, 3).toDS().as("b") + + val joined = ds1.joinWith(ds2, $"a.value" === $"b.value") + checkAnswer(joined, ("2", 2)) + } + + test("self join") { + val ds = Seq("1", "2").toDS().as("a") + val joined = ds.joinWith(ds, lit(true)) + checkAnswer(joined, ("1", "1"), ("1", "2"), ("2", "1"), ("2", "2")) + } + + test("toString") { + val ds = Seq((1, 2)).toDS() + assert(ds.toString == "[_1: int, _2: int]") + } + + test("showString: Kryo encoder") { + implicit val kryoEncoder = Encoders.kryo[KryoData] + val ds = Seq(KryoData(1), KryoData(2)).toDS() + + val expectedAnswer = """+-----------+ + || value| + |+-----------+ + ||KryoData(1)| + ||KryoData(2)| + |+-----------+ + |""".stripMargin + assert(ds.showString(10) === expectedAnswer) + } + + test("Kryo encoder") { + implicit val kryoEncoder = Encoders.kryo[KryoData] + val ds = Seq(KryoData(1), KryoData(2)).toDS() + + assert(ds.groupBy(p => p).count().collect().toSeq == + Seq((KryoData(1), 1L), (KryoData(2), 1L))) + } + + test("Kryo encoder self join") { + implicit val kryoEncoder = Encoders.kryo[KryoData] + val ds = Seq(KryoData(1), KryoData(2)).toDS() + assert(ds.joinWith(ds, lit(true)).collect().toSet == + Set( + (KryoData(1), KryoData(1)), + (KryoData(1), KryoData(2)), + (KryoData(2), KryoData(1)), + (KryoData(2), KryoData(2)))) + } + + test("Java encoder") { + implicit val kryoEncoder = Encoders.javaSerialization[JavaData] + val ds = Seq(JavaData(1), JavaData(2)).toDS() + + assert(ds.groupBy(p => p).count().collect().toSeq == + Seq((JavaData(1), 1L), (JavaData(2), 1L))) + } + + test("Java encoder self join") { + implicit val kryoEncoder = Encoders.javaSerialization[JavaData] + val ds = Seq(JavaData(1), 
JavaData(2)).toDS() + assert(ds.joinWith(ds, lit(true)).collect().toSet == + Set( + (JavaData(1), JavaData(1)), + (JavaData(1), JavaData(2)), + (JavaData(2), JavaData(1)), + (JavaData(2), JavaData(2)))) + } + + test("SPARK-11894: Incorrect results are returned when using null") { + val nullInt = null.asInstanceOf[java.lang.Integer] + val ds1 = Seq((nullInt, "1"), (new java.lang.Integer(22), "2")).toDS() + val ds2 = Seq((nullInt, "1"), (new java.lang.Integer(22), "2")).toDS() + + checkAnswer( + ds1.joinWith(ds2, lit(true)), + ((nullInt, "1"), (nullInt, "1")), + ((new java.lang.Integer(22), "2"), (nullInt, "1")), + ((nullInt, "1"), (new java.lang.Integer(22), "2")), + ((new java.lang.Integer(22), "2"), (new java.lang.Integer(22), "2"))) + } + + test("change encoder with compatible schema") { + val ds = Seq(2 -> 2.toByte, 3 -> 3.toByte).toDF("a", "b").as[ClassData] + assert(ds.collect().toSeq == Seq(ClassData("2", 2), ClassData("3", 3))) + } + + test("verify mismatching field names fail with a good error") { + val ds = Seq(ClassData("a", 1)).toDS() + val e = intercept[AnalysisException] { + ds.as[ClassData2].collect() + } + assert(e.getMessage.contains("cannot resolve 'c' given input columns a, b"), e.getMessage) + } +} + +case class ClassData(a: String, b: Int) +case class ClassData2(c: String, d: Int) +case class ClassNullableData(a: String, b: Integer) + +/** + * A class used to test serialization using encoders. This class throws exceptions when using + * Java serialization -- so the only way it can be "serialized" is through our encoders. + */ +case class NonSerializableCaseClass(value: String) extends Externalizable { + override def readExternal(in: ObjectInput): Unit = { + throw new UnsupportedOperationException + } + + override def writeExternal(out: ObjectOutput): Unit = { + throw new UnsupportedOperationException + } +} + +/** Used to test Kryo encoder. */ +class KryoData(val a: Int) { + override def equals(other: Any): Boolean = { + a == other.asInstanceOf[KryoData].a + } + override def hashCode: Int = a + override def toString: String = s"KryoData($a)" +} + +object KryoData { + def apply(a: Int): KryoData = new KryoData(a) +} + +/** Used to test Java encoder. 
*/ +class JavaData(val a: Int) extends Serializable { + override def equals(other: Any): Boolean = { + a == other.asInstanceOf[JavaData].a + } + override def hashCode: Int = a + override def toString: String = s"JavaData($a)" +} + +object JavaData { + def apply(a: Int): JavaData = new JavaData(a) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 17897caf952a..a61c3aa48a73 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -22,32 +22,37 @@ import java.text.SimpleDateFormat import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.unsafe.types.CalendarInterval -class DateFunctionsSuite extends QueryTest { - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - - import ctx.implicits._ +class DateFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("function current_date") { val df1 = Seq((1, 2), (3, 1)).toDF("a", "b") val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis()) val d1 = DateTimeUtils.fromJavaDate(df1.select(current_date()).collect().head.getDate(0)) val d2 = DateTimeUtils.fromJavaDate( - ctx.sql("""SELECT CURRENT_DATE()""").collect().head.getDate(0)) + sql("""SELECT CURRENT_DATE()""").collect().head.getDate(0)) val d3 = DateTimeUtils.millisToDays(System.currentTimeMillis()) assert(d0 <= d1 && d1 <= d2 && d2 <= d3 && d3 - d0 <= 1) } - // This is a bad test. SPARK-9196 will fix it and re-enable it. - ignore("function current_timestamp") { + test("function current_timestamp and now") { val df1 = Seq((1, 2), (3, 1)).toDF("a", "b") checkAnswer(df1.select(countDistinct(current_timestamp())), Row(1)) + // Execution in one query should return the same value - checkAnswer(ctx.sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""), - Row(true)) - assert(math.abs(ctx.sql("""SELECT CURRENT_TIMESTAMP()""").collect().head.getTimestamp( - 0).getTime - System.currentTimeMillis()) < 5000) + checkAnswer(sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""), Row(true)) + + // Current timestamp should return the current timestamp ... 
+ val before = System.currentTimeMillis + val got = sql("SELECT CURRENT_TIMESTAMP()").collect().head.getTimestamp(0).getTime + val after = System.currentTimeMillis + assert(got >= before && got <= after) + + // Now alias + checkAnswer(sql("""SELECT CURRENT_TIMESTAMP() = NOW()"""), Row(true)) } val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") @@ -443,6 +448,30 @@ class DateFunctionsSuite extends QueryTest { Row(date1.getTime / 1000L), Row(date2.getTime / 1000L))) checkAnswer(df.selectExpr(s"unix_timestamp(s, '$fmt')"), Seq( Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + + val now = sql("select unix_timestamp()").collect().head.getLong(0) + checkAnswer(sql(s"select cast ($now as timestamp)"), Row(new java.util.Date(now * 1000))) + } + + test("to_unix_timestamp") { + val date1 = Date.valueOf("2015-07-24") + val date2 = Date.valueOf("2015-07-25") + val ts1 = Timestamp.valueOf("2015-07-24 10:00:00.3") + val ts2 = Timestamp.valueOf("2015-07-25 02:02:02.2") + val s1 = "2015/07/24 10:00:00.5" + val s2 = "2015/07/25 02:02:02.6" + val ss1 = "2015-07-24 10:00:00" + val ss2 = "2015-07-25 02:02:02" + val fmt = "yyyy/MM/dd HH:mm:ss.S" + val df = Seq((date1, ts1, s1, ss1), (date2, ts2, s2, ss2)).toDF("d", "ts", "s", "ss") + checkAnswer(df.selectExpr("to_unix_timestamp(ts)"), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + checkAnswer(df.selectExpr("to_unix_timestamp(ss)"), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + checkAnswer(df.selectExpr(s"to_unix_timestamp(d, '$fmt')"), Seq( + Row(date1.getTime / 1000L), Row(date2.getTime / 1000L))) + checkAnswer(df.selectExpr(s"to_unix_timestamp(s, '$fmt')"), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) } test("datediff") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala new file mode 100644 index 000000000000..78a98798eff6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package test.org.apache.spark.sql + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Literal, GenericInternalRow, Attribute} +import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.{Row, Strategy, QueryTest} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.unsafe.types.UTF8String + +case class FastOperator(output: Seq[Attribute]) extends SparkPlan { + + override protected def doExecute(): RDD[InternalRow] = { + val str = Literal("so fast").value + val row = new GenericInternalRow(Array[Any](str)) + sparkContext.parallelize(Seq(row)) + } + + override def children: Seq[SparkPlan] = Nil +} + +object TestStrategy extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case Project(Seq(attr), _) if attr.name == "a" => + FastOperator(attr.toAttribute :: Nil) :: Nil + case _ => Nil + } +} + +class ExtraStrategiesSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("insert an extraStrategy") { + try { + sqlContext.experimental.extraStrategies = TestStrategy :: Nil + + val df = sparkContext.parallelize(Seq(("so slow", 1))).toDF("a", "b") + checkAnswer( + df.select("a"), + Row("so fast")) + + checkAnswer( + df.select("a", "b"), + Row("so slow", 1)) + } finally { + sqlContext.experimental.extraStrategies = Nil + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 5bef1d896603..9a3c262e9485 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -17,36 +17,33 @@ package org.apache.spark.sql -import org.scalatest.BeforeAndAfterEach - -import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.joins._ -import org.apache.spark.sql.types.BinaryType +import org.apache.spark.sql.test.SharedSQLContext + +class JoinSuite extends QueryTest with SharedSQLContext { + import testImplicits._ -class JoinSuite extends QueryTest with BeforeAndAfterEach { - // Ensures tables are loaded. 
- TestData + setupTestData() - lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - import ctx.logicalPlanToSparkQuery + def statisticSizeInByte(df: DataFrame): BigInt = { + df.queryExecution.optimizedPlan.statistics.sizeInBytes + } test("equi-join is hash-join") { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan - val planned = ctx.planner.HashJoin(join) + val planned = sqlContext.planner.EquiJoinSelection(join) assert(planned.size === 1) } def assertJoin(sqlString: String, c: Class[_]): Any = { - val df = ctx.sql(sqlString) + val df = sql(sqlString) val physical = df.queryExecution.sparkPlan val operators = physical.collect { - case j: ShuffledHashJoin => j - case j: ShuffledHashOuterJoin => j case j: LeftSemiJoinHash => j case j: BroadcastHashJoin => j case j: BroadcastHashOuterJoin => j @@ -55,6 +52,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { case j: BroadcastNestedLoopJoin => j case j: BroadcastLeftSemiJoinHash => j case j: SortMergeJoin => j + case j: SortMergeOuterJoin => j } assert(operators.size === 1) @@ -64,9 +62,8 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("join operator selection") { - ctx.cacheManager.clearCache() + sqlContext.cacheManager.clearCache() - val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]), @@ -83,13 +80,13 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]), - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeOuterJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", - classOf[ShuffledHashOuterJoin]), + classOf[SortMergeOuterJoin]), ("SELECT * FROM testData right join testData2 ON key = a and key = 2", - classOf[ShuffledHashOuterJoin]), + classOf[SortMergeOuterJoin]), ("SELECT * FROM testData full outer join testData2 ON key = a", - classOf[ShuffledHashOuterJoin]), + classOf[SortMergeOuterJoin]), ("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)", classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)", @@ -97,90 +94,47 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)", classOf[BroadcastNestedLoopJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - try { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) - Seq( - ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } finally { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) - } } - test("SortMergeJoin shouldn't work on unsortable columns") { - val SORTMERGEJOIN_ENABLED: Boolean = 
ctx.conf.sortMergeJoinEnabled - try { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) - Seq( - ("SELECT * FROM arrayData JOIN complexData ON data = a", classOf[ShuffledHashJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } finally { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) - } - } +// ignore("SortMergeJoin shouldn't work on unsortable columns") { +// Seq( +// ("SELECT * FROM arrayData JOIN complexData ON data = a", classOf[ShuffledHashJoin]) +// ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } +// } test("broadcasted hash join operator selection") { - ctx.cacheManager.clearCache() - ctx.sql("CACHE TABLE testData") - - val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled + sqlContext.cacheManager.clearCache() + sql("CACHE TABLE testData") Seq( - ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a and key = 2", classOf[BroadcastHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a", + classOf[BroadcastHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a and key = 2", + classOf[BroadcastHashJoin]), ("SELECT * FROM testData join testData2 ON key = a where key = 2", classOf[BroadcastHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - try { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) - Seq( - ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a and key = 2", - classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a where key = 2", - classOf[BroadcastHashJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } finally { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) - } - - ctx.sql("UNCACHE TABLE testData") + sql("UNCACHE TABLE testData") } test("broadcasted hash outer join operator selection") { - ctx.cacheManager.clearCache() - ctx.sql("CACHE TABLE testData") - - val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled + sqlContext.cacheManager.clearCache() + sql("CACHE TABLE testData") Seq( - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", + classOf[SortMergeOuterJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", classOf[BroadcastHashOuterJoin]), ("SELECT * FROM testData right join testData2 ON key = a and key = 2", classOf[BroadcastHashOuterJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - try { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) - Seq( - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), - ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", - classOf[BroadcastHashOuterJoin]), - ("SELECT * FROM testData right join testData2 ON key = a and key = 2", - classOf[BroadcastHashOuterJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } finally { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) - } - - ctx.sql("UNCACHE TABLE testData") + sql("UNCACHE TABLE testData") } test("multiple-key equi-join is hash-join") { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan - val planned = ctx.planner.HashJoin(join) + val planned = 
sqlContext.planner.EquiJoinSelection(join) assert(planned.size === 1) } @@ -285,21 +239,22 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { // Make sure we are choosing left.outputPartitioning as the // outputPartitioning for the outer join operator. checkAnswer( - ctx.sql( + sql( """ - |SELECT l.N, count(*) - |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) - |GROUP BY l.N - """.stripMargin), - Row(1, 1) :: - Row(2, 1) :: - Row(3, 1) :: - Row(4, 1) :: - Row(5, 1) :: - Row(6, 1) :: Nil) + |SELECT l.N, count(*) + |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) + |GROUP BY l.N + """.stripMargin), + Row(1, 1) :: + Row(2, 1) :: + Row(3, 1) :: + Row(4, 1) :: + Row(5, 1) :: + Row(6, 1) :: Nil) checkAnswer( - ctx.sql( + sql( """ |SELECT r.a, count(*) |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) @@ -345,22 +300,24 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { // Make sure we are choosing right.outputPartitioning as the // outputPartitioning for the outer join operator. checkAnswer( - ctx.sql( + sql( """ |SELECT l.a, count(*) |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) |GROUP BY l.a """.stripMargin), - Row(null, 6)) + Row(null, 6)) checkAnswer( - ctx.sql( + sql( """ |SELECT r.N, count(*) |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) |GROUP BY r.N """.stripMargin), - Row(1, 1) :: + Row(1, 1) :: Row(2, 1) :: Row(3, 1) :: Row(4, 1) :: @@ -372,8 +329,8 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { upperCaseData.where('N <= 4).registerTempTable("left") upperCaseData.where('N >= 3).registerTempTable("right") - val left = UnresolvedRelation(Seq("left"), None) - val right = UnresolvedRelation(Seq("right"), None) + val left = UnresolvedRelation(TableIdentifier("left"), None) + val right = UnresolvedRelation(TableIdentifier("right"), None) checkAnswer( left.join(right, $"left.N" === $"right.N", "full"), @@ -404,24 +361,27 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, null, 5, "E") :: Row(null, null, 6, "F") :: Nil) - // Make sure we are UnknownPartitioning as the outputPartitioning for the outer join operator. + // Make sure we are UnknownPartitioning as the outputPartitioning for the outer join + // operator. checkAnswer( - ctx.sql( + sql( """ - |SELECT l.a, count(*) - |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) - |GROUP BY l.a - """.stripMargin), + |SELECT l.a, count(*) + |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) + |GROUP BY l.a + """.stripMargin), Row(null, 10)) checkAnswer( - ctx.sql( + sql( """ |SELECT r.N, count(*) |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) |GROUP BY r.N """.stripMargin), - Row(1, 1) :: + Row(1, 1) :: Row(2, 1) :: Row(3, 1) :: Row(4, 1) :: @@ -430,13 +390,14 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { checkAnswer( - ctx.sql( + sql( """ |SELECT l.N, count(*) |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) |GROUP BY l.N """.stripMargin), - Row(1, 1) :: + Row(1, 1) :: Row(2, 1) :: Row(3, 1) :: Row(4, 1) :: @@ -445,42 +406,130 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { checkAnswer( - ctx.sql( + sql( """ - |SELECT r.a, count(*) - |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) - |GROUP BY r.a - """.stripMargin), + |SELECT r.a, count(*) + |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) + |GROUP BY r.a + """.stripMargin),
Row(null, 10)) } test("broadcasted left semi join operator selection") { - ctx.cacheManager.clearCache() - ctx.sql("CACHE TABLE testData") - val tmp = ctx.conf.autoBroadcastJoinThreshold + sqlContext.cacheManager.clearCache() + sql("CACHE TABLE testData") - ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=1000000000") - Seq( - ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", - classOf[BroadcastLeftSemiJoinHash]) - ).foreach { - case (query, joinClass) => assertJoin(query, joinClass) + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") { + Seq( + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", + classOf[BroadcastLeftSemiJoinHash]) + ).foreach { + case (query, joinClass) => assertJoin(query, joinClass) + } } - ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1") + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + Seq( + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]) + ).foreach { + case (query, joinClass) => assertJoin(query, joinClass) + } + } - Seq( - ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]) - ).foreach { - case (query, joinClass) => assertJoin(query, joinClass) + sql("UNCACHE TABLE testData") + } + + test("cross join with broadcast") { + sql("CACHE TABLE testData") + + val sizeInByteOfTestData = statisticSizeInByte(sqlContext.table("testData")) + + // we set the threshold to be greater than the statistic of the cached table testData + withSQLConf( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> (sizeInByteOfTestData + 1).toString()) { + + assert(statisticSizeInByte(sqlContext.table("testData2")) > + sqlContext.conf.autoBroadcastJoinThreshold) + + assert(statisticSizeInByte(sqlContext.table("testData")) < + sqlContext.conf.autoBroadcastJoinThreshold) + + Seq( + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", + classOf[LeftSemiJoinHash]), + ("SELECT * FROM testData LEFT SEMI JOIN testData2", + classOf[LeftSemiJoinBNL]), + ("SELECT * FROM testData JOIN testData2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData JOIN testData2 WHERE key = 2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData LEFT JOIN testData2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData FULL OUTER JOIN testData2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 WHERE key = 2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 WHERE key = 2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData JOIN testData2 WHERE key > a", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData left JOIN testData2 WHERE (key * a != key + a)", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData right JOIN testData2 WHERE (key * a != key + a)", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData full JOIN testData2 WHERE (key * a != key + a)", + classOf[BroadcastNestedLoopJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + + checkAnswer( + sql( + """ + SELECT x.value, y.a, y.b FROM testData x JOIN testData2 y WHERE x.key = 2 + """.stripMargin), + Row("2", 1, 1) ::
+ Row("2", 1, 2) :: + Row("2", 2, 1) :: + Row("2", 2, 2) :: + Row("2", 3, 1) :: + Row("2", 3, 2) :: Nil) + + checkAnswer( + sql( + """ + SELECT x.value, y.a, y.b FROM testData x JOIN testData2 y WHERE x.key < y.a + """.stripMargin), + Row("1", 2, 1) :: + Row("1", 2, 2) :: + Row("1", 3, 1) :: + Row("1", 3, 2) :: + Row("2", 3, 1) :: + Row("2", 3, 2) :: Nil) + + checkAnswer( + sql( + """ + SELECT x.value, y.a, y.b FROM testData x JOIN testData2 y ON x.key < y.a + """.stripMargin), + Row("1", 2, 1) :: + Row("1", 2, 2) :: + Row("1", 3, 1) :: + Row("1", 3, 2) :: + Row("2", 3, 1) :: + Row("2", 3, 2) :: Nil) } - ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, tmp) - ctx.sql("UNCACHE TABLE testData") + sql("UNCACHE TABLE testData") } test("left semi join") { - val df = ctx.sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") + val df = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") checkAnswer(df, Row(1, 1) :: Row(1, 2) :: @@ -488,6 +537,5 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala new file mode 100644 index 000000000000..1f384edf321b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.sql.test.SharedSQLContext + +class JsonFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("function get_json_object") { + val df: DataFrame = Seq(("""{"name": "alice", "age": 5}""", "")).toDF("a", "b") + checkAnswer( + df.selectExpr("get_json_object(a, '$.name')", "get_json_object(a, '$.age')"), + Row("alice", "5")) + } + + + val tuples: Seq[(String, String)] = + ("1", """{"f1": "value1", "f2": "value2", "f3": 3, "f5": 5.23}""") :: + ("2", """{"f1": "value12", "f3": "value3", "f2": 2, "f4": 4.01}""") :: + ("3", """{"f1": "value13", "f4": "value44", "f3": "value33", "f2": 2, "f5": 5.01}""") :: + ("4", null) :: + ("5", """{"f1": "", "f5": null}""") :: + ("6", "[invalid JSON string]") :: + Nil + + test("function get_json_object - null") { + val df: DataFrame = tuples.toDF("key", "jstring") + val expected = + Row("1", "value1", "value2", "3", null, "5.23") :: + Row("2", "value12", "2", "value3", "4.01", null) :: + Row("3", "value13", "2", "value33", "value44", "5.01") :: + Row("4", null, null, null, null, null) :: + Row("5", "", null, null, null, null) :: + Row("6", null, null, null, null, null) :: + Nil + + checkAnswer( + df.select($"key", functions.get_json_object($"jstring", "$.f1"), + functions.get_json_object($"jstring", "$.f2"), + functions.get_json_object($"jstring", "$.f3"), + functions.get_json_object($"jstring", "$.f4"), + functions.get_json_object($"jstring", "$.f5")), + expected) + } + + test("json_tuple select") { + val df: DataFrame = tuples.toDF("key", "jstring") + val expected = + Row("1", "value1", "value2", "3", null, "5.23") :: + Row("2", "value12", "2", "value3", "4.01", null) :: + Row("3", "value13", "2", "value33", "value44", "5.01") :: + Row("4", null, null, null, null, null) :: + Row("5", "", null, null, null, null) :: + Row("6", null, null, null, null, null) :: + Nil + + checkAnswer( + df.select($"key", functions.json_tuple($"jstring", "f1", "f2", "f3", "f4", "f5")), + expected) + } + + test("json_tuple filter and group") { + val df: DataFrame = tuples.toDF("key", "jstring") + val expr = df + .select(functions.json_tuple($"jstring", "f1", "f2")) + .where($"c0".isNotNull) + .groupBy($"c1") + .count() + + val expected = Row(null, 1) :: + Row("2", 2) :: + Row("value2", 1) :: + Nil + + checkAnswer(expr, expected) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala index 2089660c52bf..5688f46e5e3d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfter +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType} +import org.apache.spark.sql.catalyst.TableIdentifier -class ListTablesSuite extends QueryTest with BeforeAndAfter { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContext { + import testImplicits._ private lazy val df = (1 to 10).map(i => (i, s"str$i")).toDF("key", "value") @@ -33,33 +33,33 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter { } after { - ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) + 
sqlContext.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable")) } test("get all tables") { checkAnswer( - ctx.tables().filter("tableName = 'ListTablesSuiteTable'"), + sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) checkAnswer( - ctx.sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"), + sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) - ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) - assert(ctx.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) + sqlContext.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable")) + assert(sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) } test("getting all Tables with a database name has no impact on returned table names") { checkAnswer( - ctx.tables("DB").filter("tableName = 'ListTablesSuiteTable'"), + sqlContext.tables("DB").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) checkAnswer( - ctx.sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"), + sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) - ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) - assert(ctx.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) + sqlContext.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable")) + assert(sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) } test("query the returned DataFrame of tables") { @@ -67,20 +67,20 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter { StructField("tableName", StringType, false) :: StructField("isTemporary", BooleanType, false) :: Nil) - Seq(ctx.tables(), ctx.sql("SHOW TABLes")).foreach { + Seq(sqlContext.tables(), sql("SHOW TABLes")).foreach { case tableDF => assert(expectedSchema === tableDF.schema) tableDF.registerTempTable("tables") checkAnswer( - ctx.sql( + sql( "SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"), Row(true, "ListTablesSuiteTable") ) checkAnswer( - ctx.tables().filter("tableName = 'tables'").select("tableName", "isTemporary"), + sqlContext.tables().filter("tableName = 'tables'").select("tableName", "isTemporary"), Row("tables", true)) - ctx.dropTempTable("tables") + sqlContext.dropTempTable("tables") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 8cf2ef5957d8..58f982c2bc93 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -19,18 +19,16 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ import org.apache.spark.sql.functions.{log => logarithm} +import org.apache.spark.sql.test.SharedSQLContext private object MathExpressionsTestData { case class DoubleData(a: java.lang.Double, b: java.lang.Double) case class NullDoubles(a: java.lang.Double) } -class MathExpressionsSuite extends QueryTest { - +class MathExpressionsSuite extends QueryTest with SharedSQLContext { import MathExpressionsTestData._ - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ + import testImplicits._ private lazy val doubleData = (1 to 10).map(i => DoubleData(i * 0.2 - 1, i * -0.2 + 1)).toDF() @@ -39,9 +37,11 @@ class MathExpressionsSuite extends 
QueryTest { private lazy val nullDoubles = Seq(NullDoubles(1.0), NullDoubles(2.0), NullDoubles(3.0), NullDoubles(null)).toDF() - private def testOneToOneMathFunction[@specialized(Int, Long, Float, Double) T]( + private def testOneToOneMathFunction[ + @specialized(Int, Long, Float, Double) T, + @specialized(Int, Long, Float, Double) U]( c: Column => Column, - f: T => T): Unit = { + f: T => U): Unit = { checkAnswer( doubleData.select(c('a)), (1 to 10).map(n => Row(f((n * 0.2 - 1).asInstanceOf[T]))) @@ -149,7 +149,7 @@ class MathExpressionsSuite extends QueryTest { test("toDegrees") { testOneToOneMathFunction(toDegrees, math.toDegrees) checkAnswer( - ctx.sql("SELECT degrees(0), degrees(1), degrees(1.5)"), + sql("SELECT degrees(0), degrees(1), degrees(1.5)"), Seq((1, 2)).toDF().select(toDegrees(lit(0)), toDegrees(lit(1)), toDegrees(lit(1.5))) ) } @@ -157,7 +157,7 @@ class MathExpressionsSuite extends QueryTest { test("toRadians") { testOneToOneMathFunction(toRadians, math.toRadians) checkAnswer( - ctx.sql("SELECT radians(0), radians(1), radians(1.5)"), + sql("SELECT radians(0), radians(1), radians(1.5)"), Seq((1, 2)).toDF().select(toRadians(lit(0)), toRadians(lit(1)), toRadians(lit(1.5))) ) } @@ -167,10 +167,10 @@ class MathExpressionsSuite extends QueryTest { } test("ceil and ceiling") { - testOneToOneMathFunction(ceil, math.ceil) + testOneToOneMathFunction(ceil, (d: Double) => math.ceil(d).toLong) checkAnswer( - ctx.sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"), - Row(0.0, 1.0, 2.0)) + sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"), + Row(0L, 1L, 2L)) } test("conv") { @@ -186,7 +186,7 @@ class MathExpressionsSuite extends QueryTest { } test("floor") { - testOneToOneMathFunction(floor, math.floor) + testOneToOneMathFunction(floor, (d: Double) => math.floor(d).toLong) } test("factorial") { @@ -214,7 +214,7 @@ class MathExpressionsSuite extends QueryTest { val pi = 3.1415 checkAnswer( - ctx.sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " + + sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " + s"round($pi, 0), round($pi, 1), round($pi, 2), round($pi, 3)"), Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3), BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142"))) @@ -230,10 +230,10 @@ class MathExpressionsSuite extends QueryTest { } test("signum / sign") { - testOneToOneMathFunction[Double](signum, math.signum) + testOneToOneMathFunction[Double, Double](signum, math.signum) checkAnswer( - ctx.sql("SELECT sign(10), signum(-11)"), + sql("SELECT sign(10), signum(-11)"), Row(1, -1)) } @@ -241,7 +241,7 @@ class MathExpressionsSuite extends QueryTest { testTwoToOneMathFunction(pow, pow, math.pow) checkAnswer( - ctx.sql("SELECT pow(1, 2), power(2, 1)"), + sql("SELECT pow(1, 2), power(2, 1)"), Seq((1, 2)).toDF().select(pow(lit(1), lit(2)), pow(lit(2), lit(1))) ) } @@ -280,7 +280,7 @@ class MathExpressionsSuite extends QueryTest { test("log / ln") { testOneToOneNonNegativeMathFunction(org.apache.spark.sql.functions.log, math.log) checkAnswer( - ctx.sql("SELECT ln(0), ln(1), ln(1.5)"), + sql("SELECT ln(0), ln(1), ln(1.5)"), Seq((1, 2)).toDF().select(logarithm(lit(0)), logarithm(lit(1)), logarithm(lit(1.5))) ) } @@ -375,7 +375,7 @@ class MathExpressionsSuite extends QueryTest { df.select(log2("b") + log2("a")), Row(1)) - checkAnswer(ctx.sql("SELECT LOG2(8), LOG2(null)"), Row(3, null)) + checkAnswer(sql("SELECT LOG2(8), LOG2(null)"), Row(3, null)) } test("sqrt") { @@ -384,13 +384,13 @@ class MathExpressionsSuite extends QueryTest { 
df.select(sqrt("a"), sqrt("b")), Row(1.0, 2.0)) - checkAnswer(ctx.sql("SELECT SQRT(4.0), SQRT(null)"), Row(2.0, null)) + checkAnswer(sql("SELECT SQRT(4.0), SQRT(null)"), Row(2.0, null)) checkAnswer(df.selectExpr("sqrt(a)", "sqrt(b)", "sqrt(null)"), Row(1.0, 2.0, null)) } test("negative") { checkAnswer( - ctx.sql("SELECT negative(1), negative(0), negative(-1)"), + sql("SELECT negative(1), negative(0), negative(-1)"), Row(-1, 0, 1)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala new file mode 100644 index 000000000000..162c0b56c6e1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala @@ -0,0 +1,98 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import org.apache.spark._ +import org.scalatest.BeforeAndAfterAll + +class MultiSQLContextsSuite extends SparkFunSuite with BeforeAndAfterAll { + + private var originalActiveSQLContext: Option[SQLContext] = _ + private var originalInstantiatedSQLContext: Option[SQLContext] = _ + private var sparkConf: SparkConf = _ + + override protected def beforeAll(): Unit = { + originalActiveSQLContext = SQLContext.getActive() + originalInstantiatedSQLContext = SQLContext.getInstantiatedContextOption() + + SQLContext.clearActive() + SQLContext.clearInstantiatedContext() + sparkConf = + new SparkConf(false) + .setMaster("local[*]") + .setAppName("test") + .set("spark.ui.enabled", "false") + .set("spark.driver.allowMultipleContexts", "true") + } + + override protected def afterAll(): Unit = { + // Set these states back. + originalActiveSQLContext.foreach(ctx => SQLContext.setActive(ctx)) + originalInstantiatedSQLContext.foreach(ctx => SQLContext.setInstantiatedContext(ctx)) + } + + def testNewSession(rootSQLContext: SQLContext): Unit = { + // Make sure we can successfully create new Session. + rootSQLContext.newSession() + + // Reset the state. It is always safe to clear the active context. + SQLContext.clearActive() + } + + def testCreatingNewSQLContext(allowsMultipleContexts: Boolean): Unit = { + val conf = + sparkConf + .clone + .set(SQLConf.ALLOW_MULTIPLE_CONTEXTS.key, allowsMultipleContexts.toString) + val sparkContext = new SparkContext(conf) + + try { + if (allowsMultipleContexts) { + new SQLContext(sparkContext) + SQLContext.clearActive() + } else { + // If allowsMultipleContexts is false, make sure we can get the error. 
+ val message = intercept[SparkException] { + new SQLContext(sparkContext) + }.getMessage + assert(message.contains("Only one SQLContext/HiveContext may be running")) + } + } finally { + sparkContext.stop() + } + } + + test("test the flag to disallow creating multiple root SQLContext") { + Seq(false, true).foreach { allowMultipleSQLContexts => + val conf = + sparkConf + .clone + .set(SQLConf.ALLOW_MULTIPLE_CONTEXTS.key, allowMultipleSQLContexts.toString) + val sc = new SparkContext(conf) + try { + val rootSQLContext = new SQLContext(sc) + testNewSession(rootSQLContext) + testNewSession(rootSQLContext) + testCreatingNewSQLContext(allowMultipleSQLContexts) + } finally { + sc.stop() + SQLContext.clearInstantiatedContext() + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 98ba3c99283a..bc22fb8b7bdb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -19,13 +19,16 @@ package org.apache.spark.sql import java.util.{Locale, TimeZone} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.columnar.InMemoryRelation +import org.apache.spark.sql.execution.columnar.InMemoryRelation +import org.apache.spark.sql.execution.Queryable -class QueryTest extends PlanTest { +abstract class QueryTest extends PlanTest { + + protected def sqlContext: SQLContext // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) @@ -51,36 +54,119 @@ class QueryTest extends PlanTest { } } + /** + * Evaluates a dataset to make sure that the result of calling collect matches the given + * expected answer. + * - Special handling is done based on whether the query plan should be expected to return + * the results in sorted order. + * - This function also checks to make sure that the schema for serializing the expected answer + * matches that produced by the dataset (i.e. does manual construction of object match + * the constructed encoder for cases like joins, etc). Note that this means that it will fail + * for cases where reordering is done on fields. For such tests, user `checkDecoding` instead + * which performs a subset of the checks done by this function. 
+ */ + protected def checkAnswer[T]( + ds: Dataset[T], + expectedAnswer: T*): Unit = { + checkAnswer( + ds.toDF(), + sqlContext.createDataset(expectedAnswer)(ds.unresolvedTEncoder).toDF().collect().toSeq) + + checkDecoding(ds, expectedAnswer: _*) + } + + protected def checkDecoding[T]( + ds: => Dataset[T], + expectedAnswer: T*): Unit = { + val decoded = try ds.collect().toSet catch { + case e: Exception => + fail( + s""" + |Exception collecting dataset as objects + |${ds.resolvedTEncoder} + |${ds.resolvedTEncoder.fromRowExpression.treeString} + |${ds.queryExecution} + """.stripMargin, e) + } + + if (decoded != expectedAnswer.toSet) { + val expected = expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted + val actual = decoded.toSet.toSeq.map((a: Any) => a.toString).sorted + + val comparision = sideBySide("expected" +: expected, "spark" +: actual).mkString("\n") + fail( + s"""Decoded objects do not match expected objects: + |$comparision + |${ds.resolvedTEncoder.fromRowExpression.treeString} + """.stripMargin) + } + } + /** * Runs the plan and makes sure the answer matches the expected result. * @param df the [[DataFrame]] to be executed * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. */ - protected def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Unit = { - QueryTest.checkAnswer(df, expectedAnswer) match { + protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit = { + val analyzedDF = try df catch { + case ae: AnalysisException => + val currentValue = sqlContext.conf.dataFrameEagerAnalysis + sqlContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, false) + val partiallyAnalzyedPlan = df.queryExecution.analyzed + sqlContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, currentValue) + fail( + s""" + |Failed to analyze query: $ae + |$partiallyAnalzyedPlan + | + |${stackTraceToString(ae)} + |""".stripMargin) + } + + QueryTest.checkAnswer(analyzedDF, expectedAnswer) match { case Some(errorMessage) => fail(errorMessage) case None => } } - protected def checkAnswer(df: DataFrame, expectedAnswer: Row): Unit = { + protected def checkAnswer(df: => DataFrame, expectedAnswer: Row): Unit = { checkAnswer(df, Seq(expectedAnswer)) } - protected def checkAnswer(df: DataFrame, expectedAnswer: DataFrame): Unit = { + protected def checkAnswer(df: => DataFrame, expectedAnswer: DataFrame): Unit = { checkAnswer(df, expectedAnswer.collect()) } - def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext) { - test(sqlString) { - checkAnswer(sqlContext.sql(sqlString), expectedAnswer) + /** + * Runs the plan and makes sure the answer is within absTol of the expected result. + * @param dataFrame the [[DataFrame]] to be executed + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + * @param absTol the absolute tolerance between actual and expected answers. 
+ */ + protected def checkAggregatesWithTol(dataFrame: DataFrame, + expectedAnswer: Seq[Row], + absTol: Double): Unit = { + // TODO: catch exceptions in data frame execution + val actualAnswer = dataFrame.collect() + require(actualAnswer.length == expectedAnswer.length, + s"actual num rows ${actualAnswer.length} != expected num of rows ${expectedAnswer.length}") + + actualAnswer.zip(expectedAnswer).foreach { + case (actualRow, expectedRow) => + QueryTest.checkAggregatesWithTol(actualRow, expectedRow, absTol) } } + protected def checkAggregatesWithTol(dataFrame: DataFrame, + expectedAnswer: Row, + absTol: Double): Unit = { + checkAggregatesWithTol(dataFrame, Seq(expectedAnswer), absTol) + } + /** - * Asserts that a given [[DataFrame]] will be executed using the given number of cached results. + * Asserts that a given [[Queryable]] will be executed using the given number of cached results. */ - def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = { + def assertCached(query: Queryable, numCachedTables: Int = 1): Unit = { val planWithCaching = query.queryExecution.withCachedData val cachedData = planWithCaching collect { case cached: InMemoryRelation => cached @@ -104,19 +190,26 @@ object QueryTest { */ def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Option[String] = { val isSorted = df.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty + + // We need to call prepareRow recursively to handle schemas with struct types. + def prepareRow(row: Row): Row = { + Row.fromSeq(row.toSeq.map { + case null => null + case d: java.math.BigDecimal => BigDecimal(d) + // Convert array to Seq for easy equality check. + case b: Array[_] => b.toSeq + case r: Row => prepareRow(r) + case o => o + }) + } + def prepareAnswer(answer: Seq[Row]): Seq[Row] = { // Converts data to types that we can do equality comparison using Scala collections. // For BigDecimal type, the Scala type has a better definition of equality test (similar to // Java's java.math.BigDecimal.compareTo). // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for // equality test. - val converted: Seq[Row] = answer.map { s => - Row.fromSeq(s.toSeq.map { - case d: java.math.BigDecimal => BigDecimal(d) - case b: Array[Byte] => b.toSeq - case o => o - }) - } + val converted: Seq[Row] = answer.map(prepareRow) if (!isSorted) converted.sortBy(_.toString()) else converted } val sparkAnswer = try df.collect().toSeq catch { @@ -150,8 +243,30 @@ object QueryTest { return None } + /** + * Runs the plan and makes sure the answer is within absTol of the expected result. + * @param actualAnswer the actual result in a [[Row]]. + * @param expectedAnswer the expected result in a[[Row]]. + * @param absTol the absolute tolerance between actual and expected answers. + */ + protected def checkAggregatesWithTol(actualAnswer: Row, expectedAnswer: Row, absTol: Double) = { + require(actualAnswer.length == expectedAnswer.length, + s"actual answer length ${actualAnswer.length} != " + + s"expected answer length ${expectedAnswer.length}") + + // TODO: support other numeric types besides Double + // TODO: support struct types? 
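[editor's note] The tolerance-based helpers above compare floating-point aggregates against an absolute error bound instead of exact equality. A minimal usage sketch, where the temp table name `nums` and the tolerance value are hypothetical and not part of this patch:

    // Assuming a registered temp table `nums` with a double column `value`
    // whose average is approximately 2.0:
    checkAggregatesWithTol(
      sql("SELECT AVG(value) FROM nums"),
      Row(2.0),
      absTol = 1e-6)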
+ actualAnswer.toSeq.zip(expectedAnswer.toSeq).foreach { + case (actual: Double, expected: Double) => + assert(math.abs(actual - expected) < absTol, + s"actual answer $actual not within $absTol of correct answer $expected") + case (actual, expected) => + assert(actual == expected, s"$actual did not equal $expected") + } + } + def checkAnswer(df: DataFrame, expectedAnswer: java.util.List[Row]): String = { - checkAnswer(df, expectedAnswer.toSeq) match { + checkAnswer(df, expectedAnswer.asScala) match { case Some(errorMessage) => errorMessage case None => null } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index 8a679c7865d6..3ba14d7602a6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -20,13 +20,12 @@ package org.apache.spark.sql import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -class RowSuite extends SparkFunSuite { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class RowSuite extends SparkFunSuite with SharedSQLContext { + import testImplicits._ test("create row") { val expected = new GenericMutableRow(4) @@ -58,7 +57,7 @@ class RowSuite extends SparkFunSuite { test("serialize w/ kryo") { val row = Seq((1, Seq(1), Map(1 -> 1), BigDecimal(1))).toDF().first() - val serializer = new SparkSqlSerializer(ctx.sparkContext.getConf) + val serializer = new SparkSqlSerializer(sparkContext.getConf) val instance = serializer.newInstance() val ser = instance.serialize(row) val de = instance.deserialize(ser).asInstanceOf[Row] @@ -86,4 +85,13 @@ class RowSuite extends SparkFunSuite { val r2 = Row(Double.NaN) assert(r1 === r2) } + + test("equals and hashCode") { + val r1 = Row("Hello") + val r2 = Row("Hello") + assert(r1 === r2) + assert(r1.hashCode() === r2.hashCode()) + val r3 = Row("World") + assert(r3.hashCode() != r1.hashCode()) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index 75791e9d53c2..3d2bd236ceea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -17,71 +17,79 @@ package org.apache.spark.sql +import org.apache.spark.sql.test.{TestSQLContext, SharedSQLContext} -class SQLConfSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext +class SQLConfSuite extends QueryTest with SharedSQLContext { private val testKey = "test.key.0" private val testVal = "test.val.0" test("propagate from spark conf") { // We create a new context here to avoid order dependence with other tests that might call // clear(). 
- val newContext = new SQLContext(ctx.sparkContext) + val newContext = new SQLContext(sparkContext) assert(newContext.getConf("spark.sql.testkey", "false") === "true") } test("programmatic ways of basic setting and getting") { - ctx.conf.clear() - assert(ctx.getAllConfs.size === 0) - - ctx.setConf(testKey, testVal) - assert(ctx.getConf(testKey) === testVal) - assert(ctx.getConf(testKey, testVal + "_") === testVal) - assert(ctx.getAllConfs.contains(testKey)) + // Set a conf first. + sqlContext.setConf(testKey, testVal) + // Clear the conf. + sqlContext.conf.clear() + // After clear, only overrideConfs used by unit test should be in the SQLConf. + assert(sqlContext.getAllConfs === TestSQLContext.overrideConfs) + + sqlContext.setConf(testKey, testVal) + assert(sqlContext.getConf(testKey) === testVal) + assert(sqlContext.getConf(testKey, testVal + "_") === testVal) + assert(sqlContext.getAllConfs.contains(testKey)) // Tests SQLConf as accessed from a SQLContext is mutable after // the latter is initialized, unlike SparkConf inside a SparkContext. - assert(ctx.getConf(testKey) == testVal) - assert(ctx.getConf(testKey, testVal + "_") === testVal) - assert(ctx.getAllConfs.contains(testKey)) + assert(sqlContext.getConf(testKey) === testVal) + assert(sqlContext.getConf(testKey, testVal + "_") === testVal) + assert(sqlContext.getAllConfs.contains(testKey)) - ctx.conf.clear() + sqlContext.conf.clear() } test("parse SQL set commands") { - ctx.conf.clear() - ctx.sql(s"set $testKey=$testVal") - assert(ctx.getConf(testKey, testVal + "_") === testVal) - assert(ctx.getConf(testKey, testVal + "_") === testVal) + sqlContext.conf.clear() + sql(s"set $testKey=$testVal") + assert(sqlContext.getConf(testKey, testVal + "_") === testVal) + assert(sqlContext.getConf(testKey, testVal + "_") === testVal) - ctx.sql("set some.property=20") - assert(ctx.getConf("some.property", "0") === "20") - ctx.sql("set some.property = 40") - assert(ctx.getConf("some.property", "0") === "40") + sql("set some.property=20") + assert(sqlContext.getConf("some.property", "0") === "20") + sql("set some.property = 40") + assert(sqlContext.getConf("some.property", "0") === "40") val key = "spark.sql.key" val vs = "val0,val_1,val2.3,my_table" - ctx.sql(s"set $key=$vs") - assert(ctx.getConf(key, "0") === vs) + sql(s"set $key=$vs") + assert(sqlContext.getConf(key, "0") === vs) - ctx.sql(s"set $key=") - assert(ctx.getConf(key, "0") === "") + sql(s"set $key=") + assert(sqlContext.getConf(key, "0") === "") - ctx.conf.clear() + sqlContext.conf.clear() } test("deprecated property") { - ctx.conf.clear() - ctx.sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") - assert(ctx.conf.numShufflePartitions === 10) + sqlContext.conf.clear() + val original = sqlContext.conf.numShufflePartitions + try{ + sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") + assert(sqlContext.conf.numShufflePartitions === 10) + } finally { + sql(s"set ${SQLConf.SHUFFLE_PARTITIONS}=$original") + } } test("invalid conf value") { - ctx.conf.clear() + sqlContext.conf.clear() val e = intercept[IllegalArgumentException] { - ctx.sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10") + sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10") } assert(e.getMessage === s"${SQLConf.CASE_SENSITIVE.key} should be boolean, but was 10") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala index c8d8796568a4..1994dacfc4df 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala @@ -17,32 +17,52 @@ package org.apache.spark.sql -import org.scalatest.BeforeAndAfterAll +import org.apache.spark.{SharedSparkContext, SparkFunSuite} -import org.apache.spark.SparkFunSuite - -class SQLContextSuite extends SparkFunSuite with BeforeAndAfterAll { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - - override def afterAll(): Unit = { - SQLContext.setLastInstantiatedContext(ctx) - } +class SQLContextSuite extends SparkFunSuite with SharedSparkContext{ test("getOrCreate instantiates SQLContext") { - SQLContext.clearLastInstantiatedContext() - val sqlContext = SQLContext.getOrCreate(ctx.sparkContext) + val sqlContext = SQLContext.getOrCreate(sc) assert(sqlContext != null, "SQLContext.getOrCreate returned null") - assert(SQLContext.getOrCreate(ctx.sparkContext).eq(sqlContext), + assert(SQLContext.getOrCreate(sc).eq(sqlContext), "SQLContext created by SQLContext.getOrCreate not returned by SQLContext.getOrCreate") } - test("getOrCreate gets last explicitly instantiated SQLContext") { - SQLContext.clearLastInstantiatedContext() - val sqlContext = new SQLContext(ctx.sparkContext) - assert(SQLContext.getOrCreate(ctx.sparkContext) != null, - "SQLContext.getOrCreate after explicitly created SQLContext returned null") - assert(SQLContext.getOrCreate(ctx.sparkContext).eq(sqlContext), + test("getOrCreate return the original SQLContext") { + val sqlContext = SQLContext.getOrCreate(sc) + val newSession = sqlContext.newSession() + assert(SQLContext.getOrCreate(sc).eq(sqlContext), "SQLContext.getOrCreate after explicitly created SQLContext did not return the context") + SQLContext.setActive(newSession) + assert(SQLContext.getOrCreate(sc).eq(newSession), + "SQLContext.getOrCreate after explicitly setActive() did not return the active context") + } + + test("Sessions of SQLContext") { + val sqlContext = SQLContext.getOrCreate(sc) + val session1 = sqlContext.newSession() + val session2 = sqlContext.newSession() + + // all have the default configurations + val key = SQLConf.SHUFFLE_PARTITIONS.key + assert(session1.getConf(key) === session2.getConf(key)) + session1.setConf(key, "1") + session2.setConf(key, "2") + assert(session1.getConf(key) === "1") + assert(session2.getConf(key) === "2") + + // temporary table should not be shared + val df = session1.range(10) + df.registerTempTable("test1") + assert(session1.tableNames().contains("test1")) + assert(!session2.tableNames().contains("test1")) + + // UDF should not be shared + def myadd(a: Int, b: Int): Int = a + b + session1.udf.register[Int, Int, Int]("myadd", myadd) + session1.sql("select myadd(1, 2)").explain() + intercept[AnalysisException] { + session2.sql("select myadd(1, 2)").explain() + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index bbadc202a4f0..bb82b562aaaa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -17,30 +17,27 @@ package org.apache.spark.sql +import java.math.MathContext import java.sql.Timestamp -import org.scalatest.BeforeAndAfterAll - -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry +import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.catalyst.DefaultParserDialect +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.errors.DialectException import 
org.apache.spark.sql.execution.aggregate -import org.apache.spark.sql.execution.GeneratedAggregate +import org.apache.spark.sql.execution.joins.{CartesianProduct, SortMergeJoin} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SQLTestData._ +import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext} import org.apache.spark.sql.types._ /** A SQL Dialect for testing purpose, and it can not be nested type */ class MyDialect extends DefaultParserDialect -class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { - // Make sure the tables are loaded. - TestData +class SQLQuerySuite extends QueryTest with SharedSQLContext { + import testImplicits._ - val sqlContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.implicits._ - import sqlContext.sql + setupTestData() test("having clause") { Seq(("one", 1), ("two", 2), ("three", 3), ("one", 5)).toDF("k", "v").registerTempTable("hav") @@ -60,7 +57,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("show functions") { - checkAnswer(sql("SHOW functions"), FunctionRegistry.builtin.listFunction().sorted.map(Row(_))) + checkAnswer(sql("SHOW functions"), + FunctionRegistry.builtin.listFunction().sorted.map(Row(_))) } test("describe functions") { @@ -150,14 +148,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SQL Dialect Switching to a new SQL parser") { - val newContext = new SQLContext(sqlContext.sparkContext) + val newContext = new SQLContext(sparkContext) newContext.setConf("spark.sql.dialect", classOf[MyDialect].getCanonicalName()) assert(newContext.getSQLDialect().getClass === classOf[MyDialect]) assert(newContext.sql("SELECT 1").collect() === Array(Row(1))) } test("SQL Dialect Switch to an invalid parser with alias") { - val newContext = new SQLContext(sqlContext.sparkContext) + val newContext = new SQLContext(sparkContext) newContext.sql("SET spark.sql.dialect=MyTestClass") intercept[DialectException] { newContext.sql("SELECT 1") @@ -178,7 +176,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { val df = Seq(Tuple1(1), Tuple1(2), Tuple1(3)).toDF("index") // we except the id is materialized once - val idUDF = udf(() => UUID.randomUUID().toString) + val idUDF = org.apache.spark.sql.functions.udf(() => UUID.randomUUID().toString) val dfWithId = df.withColumn("id", idUDF()) // Make a new DataFrame (actually the same reference to the old one) @@ -199,7 +197,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("grouping on nested fields") { - sqlContext.read.json(sqlContext.sparkContext.parallelize( + sqlContext.read.json(sparkContext.parallelize( """{"nested": {"attribute": 1}, "value": 2}""" :: Nil)) .registerTempTable("rows") @@ -218,7 +216,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { test("SPARK-6201 IN type conversion") { sqlContext.read.json( - sqlContext.sparkContext.parallelize( + sparkContext.parallelize( Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}"))) .registerTempTable("d") @@ -227,66 +225,48 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Seq(Row("1"), Row("2"))) } + test("SPARK-11226 Skip empty line in json file") { + sqlContext.read.json( + sparkContext.parallelize( + Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}", ""))) + 
.registerTempTable("d") + + checkAnswer( + sql("select count(1) from d"), + Seq(Row(3))) + } + test("SPARK-8828 sum should return null if all input values are null") { - withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "true") { - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { - checkAnswer( - sql("select sum(a), avg(a) from allNulls"), - Seq(Row(null, null)) - ) - } - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { - checkAnswer( - sql("select sum(a), avg(a) from allNulls"), - Seq(Row(null, null)) - ) - } - } - withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") { - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { - checkAnswer( - sql("select sum(a), avg(a) from allNulls"), - Seq(Row(null, null)) - ) - } - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { - checkAnswer( - sql("select sum(a), avg(a) from allNulls"), - Seq(Row(null, null)) - ) - } + checkAnswer( + sql("select sum(a), avg(a) from allNulls"), + Seq(Row(null, null)) + ) + } + + private def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = { + val df = sql(sqlText) + // First, check if we have GeneratedAggregate. + val hasGeneratedAgg = df.queryExecution.executedPlan + .collect { case _: aggregate.TungstenAggregate => true } + .nonEmpty + if (!hasGeneratedAgg) { + fail( + s""" + |Codegen is enabled, but query $sqlText does not have TungstenAggregate in the plan. + |${df.queryExecution.simpleString} + """.stripMargin) } + // Then, check results. + checkAnswer(df, expectedResults) } test("aggregation with codegen") { - val originalValue = sqlContext.conf.codegenEnabled - sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) // Prepare a table that we can group some rows. sqlContext.table("testData") .unionAll(sqlContext.table("testData")) .unionAll(sqlContext.table("testData")) .registerTempTable("testData3x") - def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = { - val df = sql(sqlText) - // First, check if we have GeneratedAggregate. - var hasGeneratedAgg = false - df.queryExecution.executedPlan.foreach { - case generatedAgg: GeneratedAggregate => hasGeneratedAgg = true - case newAggregate: aggregate.Aggregate => hasGeneratedAgg = true - case _ => - } - if (!hasGeneratedAgg) { - fail( - s""" - |Codegen is enabled, but query $sqlText does not have GeneratedAggregate in the plan. - |${df.queryExecution.simpleString} - """.stripMargin) - } - // Then, check results. - checkAnswer(df, expectedResults) - } - try { // Just to group rows. 
testCodeGen( @@ -358,7 +338,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Row(null, null, 0) :: Nil) } finally { sqlContext.dropTempTable("testData3x") - sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue) } } @@ -494,35 +473,29 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("literal in agg grouping expressions") { - def literalInAggTest(): Unit = { - checkAnswer( - sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), - Seq(Row(1, 2), Row(2, 2), Row(3, 2))) - checkAnswer( - sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"), - Seq(Row(1, 2), Row(2, 2), Row(3, 2))) - - checkAnswer( - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"), - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) - checkAnswer( - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"), - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) - checkAnswer( - sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"), - sql("SELECT 1, 2, sum(b) FROM testData2")) - } + checkAnswer( + sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), + Seq(Row(1, 2), Row(2, 2), Row(3, 2))) + checkAnswer( + sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"), + Seq(Row(1, 2), Row(2, 2), Row(3, 2))) - literalInAggTest() - withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") { - literalInAggTest() - } + checkAnswer( + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"), + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) + checkAnswer( + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"), + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) + checkAnswer( + sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"), + sql("SELECT 1, 2, sum(b) FROM testData2")) } test("aggregates with nulls") { checkAnswer( - sql("SELECT MIN(a), MAX(a), AVG(a), SUM(a), COUNT(a) FROM nullInts"), - Row(1, 3, 2, 6, 3) + sql("SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a)," + + "AVG(a), VARIANCE(a), STDDEV(a), SUM(a), COUNT(a) FROM nullInts"), + Row(0, -1.5, 1, 3, 2, 1.0, 1, 6, 3) ) } @@ -580,30 +553,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { mapData.collect().sortBy(_.data(1)).reverse.map(Row.fromTuple).toSeq) } - test("sorting") { - withSQLConf(SQLConf.EXTERNAL_SORT.key -> "false") { - sortTest() - } - } - test("external sorting") { - withSQLConf(SQLConf.EXTERNAL_SORT.key -> "true") { - sortTest() - } - } - - test("SPARK-6927 sorting with codegen on") { - withSQLConf(SQLConf.EXTERNAL_SORT.key -> "false", - SQLConf.CODEGEN_ENABLED.key -> "true") { - sortTest() - } - } - - test("SPARK-6927 external sorting with codegen on") { - withSQLConf(SQLConf.EXTERNAL_SORT.key -> "true", - SQLConf.CODEGEN_ENABLED.key -> "true") { - sortTest() - } + sortTest() } test("limit") { @@ -715,9 +666,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { checkAnswer( sql( - """ - |SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3 - """.stripMargin), + "SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3"), Row(2, 1, 2, 2, 1)) } @@ -840,6 +789,19 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Row(null, null, 6, "F") :: Nil) } + test("SPARK-11111 null-safe join should not use cartesian product") { + val df = sql("select count(*) from testData a join testData b on (a.key <=> b.key)") + val cp = df.queryExecution.executedPlan.collect { + case cp: CartesianProduct => cp 
+ } + assert(cp.isEmpty, "should not use CartesianProduct for null-safe join") + val smj = df.queryExecution.executedPlan.collect { + case smj: SortMergeJoin => smj + } + assert(smj.size > 0, "should use SortMergeJoin") + checkAnswer(df, Row(100) :: Nil) + } + test("SPARK-3349 partitioning after limit") { sql("SELECT DISTINCT n FROM lowerCaseData ORDER BY n DESC") .limit(2) @@ -999,21 +961,30 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { val nonexistentKey = "nonexistent" // "set" itself returns all config variables currently specified in SQLConf. - assert(sql("SET").collect().size == 0) + assert(sql("SET").collect().size === TestSQLContext.overrideConfs.size) + sql("SET").collect().foreach { row => + val key = row.getString(0) + val value = row.getString(1) + assert( + TestSQLContext.overrideConfs.contains(key), + s"$key should exist in SQLConf.") + assert( + TestSQLContext.overrideConfs(key) === value, + s"The value of $key should be ${TestSQLContext.overrideConfs(key)} instead of $value.") + } + val overrideConfs = sql("SET").collect() // "set key=val" sql(s"SET $testKey=$testVal") checkAnswer( sql("SET"), - Row(testKey, testVal) + overrideConfs ++ Seq(Row(testKey, testVal)) ) sql(s"SET ${testKey + testKey}=${testVal + testVal}") checkAnswer( sql("set"), - Seq( - Row(testKey, testVal), - Row(testKey + testKey, testVal + testVal)) + overrideConfs ++ Seq(Row(testKey, testVal), Row(testKey + testKey, testVal + testVal)) ) // "set key" @@ -1164,7 +1135,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { validateMetadata(sql("SELECT * FROM personWithMeta")) validateMetadata(sql("SELECT id, name FROM personWithMeta")) validateMetadata(sql("SELECT * FROM personWithMeta JOIN salary ON id = personId")) - validateMetadata(sql("SELECT name, salary FROM personWithMeta JOIN salary ON id = personId")) + validateMetadata(sql( + "SELECT name, salary FROM personWithMeta JOIN salary ON id = personId")) } test("SPARK-3371 Renaming a function expression with group by gives error") { @@ -1349,7 +1321,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SPARK-3483 Special chars in column names") { - val data = sqlContext.sparkContext.parallelize( + val data = sparkContext.parallelize( Seq("""{"key?number1": "value1", "key.number2": "value2"}""")) sqlContext.read.json(data).registerTempTable("records") sql("SELECT `key?number1`, `key.number2` FROM records") @@ -1392,13 +1364,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SPARK-4322 Grouping field with struct field as sub expression") { - sqlContext.read.json(sqlContext.sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)) + sqlContext.read.json(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)) .registerTempTable("data") checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), Row(1)) sqlContext.dropTempTable("data") sqlContext.read.json( - sqlContext.sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data") + sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data") checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), Row(2)) sqlContext.dropTempTable("data") } @@ -1419,10 +1391,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { test("Supporting relational operator '<=>' in Spark SQL") { val nullCheckData1 = TestData(1, "1") :: TestData(2, null) :: Nil - val rdd1 = 
sqlContext.sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i))) + val rdd1 = sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i))) rdd1.toDF().registerTempTable("nulldata1") val nullCheckData2 = TestData(1, "1") :: TestData(2, null) :: Nil - val rdd2 = sqlContext.sparkContext.parallelize((0 to 1).map(i => nullCheckData2(i))) + val rdd2 = sparkContext.parallelize((0 to 1).map(i => nullCheckData2(i))) rdd2.toDF().registerTempTable("nulldata2") checkAnswer(sql("SELECT nulldata1.key FROM nulldata1 join " + "nulldata2 on nulldata1.value <=> nulldata2.value"), @@ -1431,7 +1403,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { test("Multi-column COUNT(DISTINCT ...)") { val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil - val rdd = sqlContext.sparkContext.parallelize((0 to 1).map(i => data(i))) + val rdd = sparkContext.parallelize((0 to 1).map(i => data(i))) rdd.toDF().registerTempTable("distinctData") checkAnswer(sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), Row(2)) } @@ -1439,14 +1411,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { test("SPARK-4699 case sensitivity SQL query") { sqlContext.setConf(SQLConf.CASE_SENSITIVE, false) val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil - val rdd = sqlContext.sparkContext.parallelize((0 to 1).map(i => data(i))) + val rdd = sparkContext.parallelize((0 to 1).map(i => data(i))) rdd.toDF().registerTempTable("testTable1") checkAnswer(sql("SELECT VALUE FROM TESTTABLE1 where KEY = 1"), Row("val_1")) sqlContext.setConf(SQLConf.CASE_SENSITIVE, true) } test("SPARK-6145: ORDER BY test for nested fields") { - sqlContext.read.json(sqlContext.sparkContext.makeRDD( + sqlContext.read.json(sparkContext.makeRDD( """{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)) .registerTempTable("nestedOrder") @@ -1459,14 +1431,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SPARK-6145: special cases") { - sqlContext.read.json(sqlContext.sparkContext.makeRDD( + sqlContext.read.json(sparkContext.makeRDD( """{"a": {"b": [1]}, "b": [{"a": 1}], "_c0": {"a": 1}}""" :: Nil)).registerTempTable("t") checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY _c0.a"), Row(1)) checkAnswer(sql("SELECT b[0].a FROM t ORDER BY _c0.a"), Row(1)) } test("SPARK-6898: complete support for special chars in column names") { - sqlContext.read.json(sqlContext.sparkContext.makeRDD( + sqlContext.read.json(sparkContext.makeRDD( """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil)) .registerTempTable("t") @@ -1497,6 +1469,16 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { """.stripMargin), Row(3) :: Row(7) :: Row(11) :: Row(15) :: Nil) + checkAnswer( + sql( + """ + |SELECT sum(b) + |FROM orderByData + |GROUP BY a + |ORDER BY sum(b), max(b) + """.stripMargin), + Row(3) :: Row(7) :: Row(11) :: Row(15) :: Nil) + checkAnswer( sql( """ @@ -1516,6 +1498,26 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { |ORDER BY sum(b) + 1 """.stripMargin), Row("4", 3) :: Row("1", 7) :: Row("3", 11) :: Row("2", 15) :: Nil) + + checkAnswer( + sql( + """ + |SELECT count(*) + |FROM orderByData + |GROUP BY a + |ORDER BY count(*) + """.stripMargin), + Row(2) :: Row(2) :: Row(2) :: Row(2) :: Nil) + + checkAnswer( + sql( + """ + |SELECT a + |FROM orderByData + |GROUP BY a + |ORDER BY a, count(*), sum(b) + """.stripMargin), + Row("1") :: Row("2") :: Row("3") :: Row("4") :: 
Nil) } test("SPARK-7952: fix the equality check between boolean and numeric types") { @@ -1540,7 +1542,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { test("SPARK-7067: order by queries for complex ExtractValue chain") { withTempTable("t") { - sqlContext.read.json(sqlContext.sparkContext.makeRDD( + sqlContext.read.json(sparkContext.makeRDD( """{"a": {"b": [{"c": 1}]}, "b": [{"d": 1}]}""" :: Nil)).registerTempTable("t") checkAnswer(sql("SELECT a.b FROM t ORDER BY b[0].d"), Row(Seq(Row(1)))) } @@ -1604,4 +1606,426 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { checkAnswer(df.select(-df("i")), Row(new CalendarInterval(-(12 * 3 - 3), -(7L * MICROS_PER_WEEK + 123)))) } + + test("aggregation with codegen updates peak execution memory") { + AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "aggregation with codegen") { + testCodeGen( + "SELECT key, count(value) FROM testData GROUP BY key", + (1 to 100).map(i => Row(i, 1))) + } + } + + test("decimal precision with multiply/division") { + checkAnswer(sql("select 10.3 * 3.0"), Row(BigDecimal("30.90"))) + checkAnswer(sql("select 10.3000 * 3.0"), Row(BigDecimal("30.90000"))) + checkAnswer(sql("select 10.30000 * 30.0"), Row(BigDecimal("309.000000"))) + checkAnswer(sql("select 10.300000000000000000 * 3.000000000000000000"), + Row(BigDecimal("30.900000000000000000000000000000000000", new MathContext(38)))) + checkAnswer(sql("select 10.300000000000000000 * 3.0000000000000000000"), + Row(null)) + + checkAnswer(sql("select 10.3 / 3.0"), Row(BigDecimal("3.433333"))) + checkAnswer(sql("select 10.3000 / 3.0"), Row(BigDecimal("3.4333333"))) + checkAnswer(sql("select 10.30000 / 30.0"), Row(BigDecimal("0.343333333"))) + checkAnswer(sql("select 10.300000000000000000 / 3.00000000000000000"), + Row(BigDecimal("3.433333333333333333333333333", new MathContext(38)))) + checkAnswer(sql("select 10.3000000000000000000 / 3.00000000000000000"), + Row(BigDecimal("3.4333333333333333333333333333", new MathContext(38)))) + } + + test("SPARK-10215 Div of Decimal returns null") { + val d = Decimal(1.12321) + val df = Seq((d, 1)).toDF("a", "b") + + checkAnswer( + df.selectExpr("b * a / b"), + Seq(Row(d.toBigDecimal))) + checkAnswer( + df.selectExpr("b * a / b / b"), + Seq(Row(d.toBigDecimal))) + checkAnswer( + df.selectExpr("b * a + b"), + Seq(Row(BigDecimal(2.12321)))) + checkAnswer( + df.selectExpr("b * a - b"), + Seq(Row(BigDecimal(0.12321)))) + checkAnswer( + df.selectExpr("b * a * b"), + Seq(Row(d.toBigDecimal))) + } + + test("precision smaller than scale") { + checkAnswer(sql("select 10.00"), Row(BigDecimal("10.00"))) + checkAnswer(sql("select 1.00"), Row(BigDecimal("1.00"))) + checkAnswer(sql("select 0.10"), Row(BigDecimal("0.10"))) + checkAnswer(sql("select 0.01"), Row(BigDecimal("0.01"))) + checkAnswer(sql("select 0.001"), Row(BigDecimal("0.001"))) + checkAnswer(sql("select -0.01"), Row(BigDecimal("-0.01"))) + checkAnswer(sql("select -0.001"), Row(BigDecimal("-0.001"))) + } + + test("external sorting updates peak execution memory") { + AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external sort") { + sortTest() + } + } + + test("SPARK-9511: error with table starting with number") { + withTempTable("1one") { + sparkContext.parallelize(1 to 10).map(i => (i, i.toString)) + .toDF("num", "str") + .registerTempTable("1one") + checkAnswer(sql("select count(num) from 1one"), Row(10)) + } + } + + test("specifying database name for a temporary table is not allowed") { + withTempPath 
{ dir => + val path = dir.getCanonicalPath + val df = + sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str") + df + .write + .format("parquet") + .save(path) + + val message = intercept[AnalysisException] { + sqlContext.sql( + s""" + |CREATE TEMPORARY TABLE db.t + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + }.getMessage + assert(message.contains("Specifying database name or other qualifiers are not allowed")) + + // If you use backticks to quote the name of a temporary table having dot in it. + sqlContext.sql( + s""" + |CREATE TEMPORARY TABLE `db.t` + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + checkAnswer(sqlContext.table("`db.t`"), df) + } + } + + test("SPARK-10130 type coercion for IF should have children resolved first") { + withTempTable("src") { + Seq((1, 1), (-1, 1)).toDF("key", "value").registerTempTable("src") + checkAnswer( + sql("SELECT IF(a > 0, a, 0) FROM (SELECT key a FROM src) temp"), Seq(Row(1), Row(0))) + } + } + + test("SPARK-10389: order by non-attribute grouping expression on Aggregate") { + withTempTable("src") { + Seq((1, 1), (-1, 1)).toDF("key", "value").registerTempTable("src") + checkAnswer(sql("SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY key + 1"), + Seq(Row(1), Row(1))) + checkAnswer(sql("SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY (key + 1) * 2"), + Seq(Row(1), Row(1))) + } + } + + test("run sql directly on files") { + val df = sqlContext.range(100) + withTempPath(f => { + df.write.json(f.getCanonicalPath) + checkAnswer(sql(s"select id from json.`${f.getCanonicalPath}`"), + df) + checkAnswer(sql(s"select id from `org.apache.spark.sql.json`.`${f.getCanonicalPath}`"), + df) + checkAnswer(sql(s"select a.id from json.`${f.getCanonicalPath}` as a"), + df) + }) + + val e1 = intercept[AnalysisException] { + sql("select * from in_valid_table") + } + assert(e1.message.contains("Table not found")) + + val e2 = intercept[AnalysisException] { + sql("select * from no_db.no_table") + } + assert(e2.message.contains("Table not found")) + + val e3 = intercept[AnalysisException] { + sql("select * from json.invalid_file") + } + assert(e3.message.contains("No input paths specified")) + } + + test("SortMergeJoin returns wrong results when using UnsafeRows") { + // This test is for the fix of https://issues.apache.org/jira/browse/SPARK-10737. + // This bug will be triggered when Tungsten is enabled and there are multiple + // SortMergeJoin operators executed in the same task. 
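[editor's note] For context on the SPARK-10737 regression test that follows: lowering the auto-broadcast threshold is how these tests force sort-merge joins. A rough sketch of verifying the chosen join operator, reusing this suite's imports and assuming two joinable DataFrames `df1` and `df2` with a common column "i" (a sketch, not part of the patch):

    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1") {
      val joined = df1.join(df2, "i")
      // Inspect the physical plan for the operator actually chosen.
      val usesSortMergeJoin = joined.queryExecution.executedPlan.collect {
        case smj: SortMergeJoin => smj
      }.nonEmpty
      assert(usesSortMergeJoin, "expected SortMergeJoin once broadcast joins are disabled")
    }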
+ val confs = SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1" :: Nil + withSQLConf(confs: _*) { + val df1 = (1 to 50).map(i => (s"str_$i", i)).toDF("i", "j") + val df2 = + df1 + .join(df1.select(df1("i")), "i") + .select(df1("i"), df1("j")) + + val df3 = df2.withColumnRenamed("i", "i1").withColumnRenamed("j", "j1") + val df4 = + df2 + .join(df3, df2("i") === df3("i1")) + .withColumn("diff", $"j" - $"j1") + .select(df2("i"), df2("j"), $"diff") + + checkAnswer( + df4, + df1.withColumn("diff", lit(0))) + } + } + + test("SPARK-11032: resolve having correctly") { + withTempTable("src") { + Seq(1 -> "a").toDF("i", "j").registerTempTable("src") + checkAnswer( + sql("SELECT MIN(t.i) FROM (SELECT * FROM src WHERE i > 0) t HAVING(COUNT(1) > 0)"), + Row(1)) + } + } + + test("SPARK-11303: filter should not be pushed down into sample") { + val df = sqlContext.range(100) + List(true, false).foreach { withReplacement => + val sampled = df.sample(withReplacement, 0.1, 1) + val sampledOdd = sampled.filter("id % 2 != 0") + val sampledEven = sampled.filter("id % 2 = 0") + assert(sampled.count() == sampledOdd.count() + sampledEven.count()) + } + } + + test("Struct Star Expansion") { + val structDf = testData2.select("a", "b").as("record") + + checkAnswer( + structDf.select($"record.a", $"record.b"), + Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) + + checkAnswer( + structDf.select($"record.*"), + Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) + + checkAnswer( + structDf.select($"record.*", $"record.*"), + Row(1, 1, 1, 1) :: Row(1, 2, 1, 2) :: Row(2, 1, 2, 1) :: Row(2, 2, 2, 2) :: + Row(3, 1, 3, 1) :: Row(3, 2, 3, 2) :: Nil) + + checkAnswer( + sql("select struct(a, b) as r1, struct(b, a) as r2 from testData2").select($"r1.*", $"r2.*"), + Row(1, 1, 1, 1) :: Row(1, 2, 2, 1) :: Row(2, 1, 1, 2) :: Row(2, 2, 2, 2) :: + Row(3, 1, 1, 3) :: Row(3, 2, 2, 3) :: Nil) + + // Try with a registered table. + sql("select struct(a, b) as record from testData2").registerTempTable("structTable") + checkAnswer( + sql("SELECT record.* FROM structTable"), + Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) + + checkAnswer(sql( + """ + | SELECT min(struct(record.*)) FROM + | (select struct(a,b) as record from testData2) tmp + """.stripMargin), + Row(Row(1, 1)) :: Nil) + + // Try with an alias on the select list + checkAnswer(sql( + """ + | SELECT max(struct(record.*)) as r FROM + | (select struct(a,b) as record from testData2) tmp + """.stripMargin).select($"r.*"), + Row(3, 2) :: Nil) + + // With GROUP BY + checkAnswer(sql( + """ + | SELECT min(struct(record.*)) FROM + | (select a as a, struct(a,b) as record from testData2) tmp + | GROUP BY a + """.stripMargin), + Row(Row(1, 1)) :: Row(Row(2, 1)) :: Row(Row(3, 1)) :: Nil) + + // With GROUP BY and alias + checkAnswer(sql( + """ + | SELECT max(struct(record.*)) as r FROM + | (select a as a, struct(a,b) as record from testData2) tmp + | GROUP BY a + """.stripMargin).select($"r.*"), + Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil) + + // With GROUP BY and alias and additional fields in the struct + checkAnswer(sql( + """ + | SELECT max(struct(a, record.*, b)) as r FROM + | (select a as a, b as b, struct(a,b) as record from testData2) tmp + | GROUP BY a + """.stripMargin).select($"r.*"), + Row(1, 1, 2, 2) :: Row(2, 2, 2, 2) :: Row(3, 3, 2, 2) :: Nil) + + // Create a data set that contains nested structs. 
+ val nestedStructData = sql( + """ + | SELECT struct(r1, r2) as record FROM + | (SELECT struct(a, b) as r1, struct(b, a) as r2 FROM testData2) tmp + """.stripMargin) + + checkAnswer(nestedStructData.select($"record.*"), + Row(Row(1, 1), Row(1, 1)) :: Row(Row(1, 2), Row(2, 1)) :: Row(Row(2, 1), Row(1, 2)) :: + Row(Row(2, 2), Row(2, 2)) :: Row(Row(3, 1), Row(1, 3)) :: Row(Row(3, 2), Row(2, 3)) :: Nil) + checkAnswer(nestedStructData.select($"record.r1"), + Row(Row(1, 1)) :: Row(Row(1, 2)) :: Row(Row(2, 1)) :: Row(Row(2, 2)) :: + Row(Row(3, 1)) :: Row(Row(3, 2)) :: Nil) + checkAnswer( + nestedStructData.select($"record.r1.*"), + Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) + + // Try with a registered table + withTempTable("nestedStructTable") { + nestedStructData.registerTempTable("nestedStructTable") + checkAnswer( + sql("SELECT record.* FROM nestedStructTable"), + nestedStructData.select($"record.*")) + checkAnswer( + sql("SELECT record.r1 FROM nestedStructTable"), + nestedStructData.select($"record.r1")) + checkAnswer( + sql("SELECT record.r1.* FROM nestedStructTable"), + nestedStructData.select($"record.r1.*")) + + // Try resolving something not there. + assert(intercept[AnalysisException](sql("SELECT abc.* FROM nestedStructTable")) + .getMessage.contains("cannot resolve")) + } + + // Create paths with unusual characters + val specialCharacterPath = sql( + """ + | SELECT struct(`col$.a_`, `a.b.c.`) as `r&&b.c` FROM + | (SELECT struct(a, b) as `col$.a_`, struct(b, a) as `a.b.c.` FROM testData2) tmp + """.stripMargin) + withTempTable("specialCharacterTable") { + specialCharacterPath.registerTempTable("specialCharacterTable") + checkAnswer( + specialCharacterPath.select($"`r&&b.c`.*"), + nestedStructData.select($"record.*")) + checkAnswer( + sql("SELECT `r&&b.c`.`col$.a_` FROM specialCharacterTable"), + nestedStructData.select($"record.r1")) + checkAnswer( + sql("SELECT `r&&b.c`.`a.b.c.` FROM specialCharacterTable"), + nestedStructData.select($"record.r2")) + checkAnswer( + sql("SELECT `r&&b.c`.`col$.a_`.* FROM specialCharacterTable"), + nestedStructData.select($"record.r1.*")) + } + + // Try star expanding a scalar. This should fail. + assert(intercept[AnalysisException](sql("select a.* from testData2")).getMessage.contains( + "Can only star expand struct data types.")) + } + + test("Struct Star Expansion - Name conflict") { + // Create a data set that contains a naming conflict + val nameConflict = sql("SELECT struct(a, b) as nameConflict, a as a FROM testData2") + withTempTable("nameConflict") { + nameConflict.registerTempTable("nameConflict") + // Unqualified should resolve to table. + checkAnswer(sql("SELECT nameConflict.* FROM nameConflict"), + Row(Row(1, 1), 1) :: Row(Row(1, 2), 1) :: Row(Row(2, 1), 2) :: Row(Row(2, 2), 2) :: + Row(Row(3, 1), 3) :: Row(Row(3, 2), 3) :: Nil) + // Qualify the struct type with the table name. + checkAnswer(sql("SELECT nameConflict.nameConflict.* FROM nameConflict"), + Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) + } + } + + test("Common subexpression elimination") { + // select from a table to prevent constant folding. 
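[editor's note] On the "prevent constant folding" comment above: if the common-subexpression-elimination test selected literals instead of reading from testData2, the optimizer would fold the expressions before execution and there would be nothing left for the UDF call counter to measure. A quick way to see this, as a sketch that is not part of the patch:

    // The optimized plan for a literal-only query already contains the folded result,
    // so no expression evaluation happens at runtime.
    sql("SELECT 1 + 1 AS two").explain(true)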
+ val df = sql("SELECT a, b from testData2 limit 1") + checkAnswer(df, Row(1, 1)) + + checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2)) + checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3)) + + // This does not work because the expressions get grouped like (a + a) + 1 + checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3)) + checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3)) + + // Identity udf that tracks the number of times it is called. + val countAcc = sparkContext.accumulator(0, "CallCount") + sqlContext.udf.register("testUdf", (x: Int) => { + countAcc.++=(1) + x + }) + + // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value + // is correct. + def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = { + countAcc.setValue(0) + checkAnswer(df, expectedResult) + assert(countAcc.value == expectedCount) + } + + verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1) + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) + verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1) + verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2) + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1) + + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2) + + // Would be nice if semantic equals for `+` understood commutative + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2) + + // Try disabling it via configuration. + sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false") + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2) + sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true") + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) + } + + test("SPARK-10707: nullability should be correctly propagated through set operations (1)") { + // This test produced an incorrect result of 1 before the SPARK-10707 fix because of the + // NullPropagation rule: COUNT(v) got replaced with COUNT(1) because the output column of + // UNION was incorrectly considered non-nullable: + checkAnswer( + sql("""SELECT count(v) FROM ( + | SELECT v FROM ( + | SELECT 'foo' AS v UNION ALL + | SELECT NULL AS v + | ) my_union WHERE isnull(v) + |) my_subview""".stripMargin), + Seq(Row(0))) + } + + test("SPARK-10707: nullability should be correctly propagated through set operations (2)") { + // This test uses RAND() to stop column pruning for Union and checks the resulting isnull + // value. This would produce an incorrect result before the fix in SPARK-10707 because the "v" + // column of the union was considered non-nullable. 
+ checkAnswer( + sql( + """ + |SELECT a FROM ( + | SELECT ISNULL(v) AS a, RAND() FROM ( + | SELECT 'foo' AS v UNION ALL SELECT null AS v + | ) my_union + |) my_view + """.stripMargin), + Row(false) :: Row(true) :: Nil) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index ab6d3dd96d27..295f02f9a7b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.SharedSQLContext case class ReflectData( stringField: String, @@ -71,17 +72,15 @@ case class ComplexReflectData( mapFieldContainsNull: Map[Int, Option[Long]], dataField: Data) -class ScalaReflectionRelationSuite extends SparkFunSuite { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class ScalaReflectionRelationSuite extends SparkFunSuite with SharedSQLContext { + import testImplicits._ test("query case class RDD") { val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3)) Seq(data).toDF().registerTempTable("reflectData") - assert(ctx.sql("SELECT * FROM reflectData").collect().head === + assert(sql("SELECT * FROM reflectData").collect().head === Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3))) @@ -91,7 +90,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite { val data = NullReflectData(null, null, null, null, null, null, null) Seq(data).toDF().registerTempTable("reflectNullData") - assert(ctx.sql("SELECT * FROM reflectNullData").collect().head === + assert(sql("SELECT * FROM reflectNullData").collect().head === Row.fromSeq(Seq.fill(7)(null))) } @@ -99,7 +98,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite { val data = OptionalReflectData(None, None, None, None, None, None, None) Seq(data).toDF().registerTempTable("reflectOptionalData") - assert(ctx.sql("SELECT * FROM reflectOptionalData").collect().head === + assert(sql("SELECT * FROM reflectOptionalData").collect().head === Row.fromSeq(Seq.fill(7)(null))) } @@ -107,7 +106,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite { test("query binary data") { Seq(ReflectBinary(Array[Byte](1))).toDF().registerTempTable("reflectBinary") - val result = ctx.sql("SELECT data FROM reflectBinary") + val result = sql("SELECT data FROM reflectBinary") .collect().head(0).asInstanceOf[Array[Byte]] assert(result.toSeq === Seq[Byte](1)) } @@ -126,7 +125,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite { Nested(None, "abc"))) Seq(data).toDF().registerTempTable("reflectComplexData") - assert(ctx.sql("SELECT * FROM reflectComplexData").collect().head === + assert(sql("SELECT * FROM reflectComplexData").collect().head === Row( Seq(1, 2, 3), Seq(1, 2, null), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala index e55c9e460b79..ddab91862964 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala @@ -19,13 +19,12 @@ package org.apache.spark.sql import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.sql.test.SharedSQLContext -class SerializationSuite extends SparkFunSuite { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext +class SerializationSuite extends SparkFunSuite with SharedSQLContext { test("[SPARK-5235] SQLContext should be serializable") { - val sqlContext = new SQLContext(ctx.sparkContext) - new JavaSerializer(new SparkConf()).newInstance().serialize(sqlContext) + val _sqlContext = new SQLContext(sparkContext) + new JavaSerializer(new SparkConf()).newInstance().serialize(_sqlContext) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index ab5da6ee79f1..e2090b0a83ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -18,13 +18,11 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.Decimal +import org.apache.spark.sql.test.SharedSQLContext -class StringFunctionsSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class StringFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("string concat") { val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c") @@ -128,6 +126,12 @@ class StringFunctionsSuite extends QueryTest { // scalastyle:on } + test("string translate") { + val df = Seq(("translate", "")).toDF("a", "b") + checkAnswer(df.select(translate($"a", "rnlt", "123")), Row("1a2s3ae")) + checkAnswer(df.selectExpr("""translate(a, "rnlt", "")"""), Row("asae")) + } + test("string trim functions") { val df = Seq((" example ", "")).toDF("a", "b") @@ -263,9 +267,7 @@ class StringFunctionsSuite extends QueryTest { Row(3, 4)) intercept[AnalysisException] { - checkAnswer( - df.selectExpr("length(c)"), // int type of the argument is unacceptable - Row("5.0000")) + df.selectExpr("length(c)") // int type of the argument is unacceptable } } @@ -279,63 +281,46 @@ class StringFunctionsSuite extends QueryTest { } test("number format function") { - val tuple = - ("aa", 1.asInstanceOf[Byte], 2.asInstanceOf[Short], - 3.13223f, 4, 5L, 6.48173d, Decimal(7.128381)) - val df = - Seq(tuple) - .toDF( - "a", // string "aa" - "b", // byte 1 - "c", // short 2 - "d", // float 3.13223f - "e", // integer 4 - "f", // long 5L - "g", // double 6.48173d - "h") // decimal 7.128381 - - checkAnswer( - df.select(format_number($"f", 4)), + val df = sqlContext.range(1) + + checkAnswer( + df.select(format_number(lit(5L), 4)), Row("5.0000")) checkAnswer( - df.selectExpr("format_number(b, e)"), // convert the 1st argument to integer + df.select(format_number(lit(1.toByte), 4)), // convert the 1st argument to integer Row("1.0000")) checkAnswer( - df.selectExpr("format_number(c, e)"), // convert the 1st argument to integer + df.select(format_number(lit(2.toShort), 4)), // convert the 1st argument to integer Row("2.0000")) checkAnswer( - df.selectExpr("format_number(d, e)"), // convert the 1st argument to double + df.select(format_number(lit(3.1322.toFloat), 4)), // convert the 1st argument to double Row("3.1322")) checkAnswer( - df.selectExpr("format_number(e, e)"), 
// not convert anything + df.select(format_number(lit(4), 4)), // not convert anything Row("4.0000")) checkAnswer( - df.selectExpr("format_number(f, e)"), // not convert anything + df.select(format_number(lit(5L), 4)), // not convert anything Row("5.0000")) checkAnswer( - df.selectExpr("format_number(g, e)"), // not convert anything + df.select(format_number(lit(6.48173), 4)), // not convert anything Row("6.4817")) checkAnswer( - df.selectExpr("format_number(h, e)"), // not convert anything + df.select(format_number(lit(BigDecimal(7.128381)), 4)), // not convert anything Row("7.1284")) intercept[AnalysisException] { - checkAnswer( - df.selectExpr("format_number(a, e)"), // string type of the 1st argument is unacceptable - Row("5.0000")) + df.select(format_number(lit("aa"), 4)) // string type of the 1st argument is unacceptable } intercept[AnalysisException] { - checkAnswer( - df.selectExpr("format_number(e, g)"), // decimal type of the 2nd argument is unacceptable - Row("5.0000")) + df.selectExpr("format_number(4, 6.48173)") // non-integral type 2nd argument is unacceptable } // for testing the mutable state of the expression in code gen. @@ -343,9 +328,9 @@ class StringFunctionsSuite extends QueryTest { // it will still use the interpretProjection if projection follows by a LocalRelation, // hence we add a filter operator. // See the optimizer rule `ConvertToLocalRelation` - val df2 = Seq((5L, 4), (4L, 3), (3L, 2)).toDF("a", "b") + val df2 = Seq((5L, 4), (4L, 3), (4L, 3), (4L, 3), (3L, 2)).toDF("a", "b") checkAnswer( df2.filter("b>0").selectExpr("format_number(a, b)"), - Row("5.0000") :: Row("4.000") :: Row("3.00") :: Nil) + Row("5.0000") :: Row("4.000") :: Row("4.000") :: Row("4.000") :: Row("3.00") :: Nil) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala deleted file mode 100644 index bd9729c431f3..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ /dev/null @@ -1,197 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql - -import org.apache.spark.sql.test.TestSQLContext.implicits._ -import org.apache.spark.sql.test._ - - -case class TestData(key: Int, value: String) - -object TestData { - val testData = TestSQLContext.sparkContext.parallelize( - (1 to 100).map(i => TestData(i, i.toString))).toDF() - testData.registerTempTable("testData") - - val negativeData = TestSQLContext.sparkContext.parallelize( - (1 to 100).map(i => TestData(-i, (-i).toString))).toDF() - negativeData.registerTempTable("negativeData") - - case class LargeAndSmallInts(a: Int, b: Int) - val largeAndSmallInts = - TestSQLContext.sparkContext.parallelize( - LargeAndSmallInts(2147483644, 1) :: - LargeAndSmallInts(1, 2) :: - LargeAndSmallInts(2147483645, 1) :: - LargeAndSmallInts(2, 2) :: - LargeAndSmallInts(2147483646, 1) :: - LargeAndSmallInts(3, 2) :: Nil).toDF() - largeAndSmallInts.registerTempTable("largeAndSmallInts") - - case class TestData2(a: Int, b: Int) - val testData2 = - TestSQLContext.sparkContext.parallelize( - TestData2(1, 1) :: - TestData2(1, 2) :: - TestData2(2, 1) :: - TestData2(2, 2) :: - TestData2(3, 1) :: - TestData2(3, 2) :: Nil, 2).toDF() - testData2.registerTempTable("testData2") - - case class DecimalData(a: BigDecimal, b: BigDecimal) - - val decimalData = - TestSQLContext.sparkContext.parallelize( - DecimalData(1, 1) :: - DecimalData(1, 2) :: - DecimalData(2, 1) :: - DecimalData(2, 2) :: - DecimalData(3, 1) :: - DecimalData(3, 2) :: Nil).toDF() - decimalData.registerTempTable("decimalData") - - case class BinaryData(a: Array[Byte], b: Int) - val binaryData = - TestSQLContext.sparkContext.parallelize( - BinaryData("12".getBytes(), 1) :: - BinaryData("22".getBytes(), 5) :: - BinaryData("122".getBytes(), 3) :: - BinaryData("121".getBytes(), 2) :: - BinaryData("123".getBytes(), 4) :: Nil).toDF() - binaryData.registerTempTable("binaryData") - - case class TestData3(a: Int, b: Option[Int]) - val testData3 = - TestSQLContext.sparkContext.parallelize( - TestData3(1, None) :: - TestData3(2, Some(2)) :: Nil).toDF() - testData3.registerTempTable("testData3") - - case class UpperCaseData(N: Int, L: String) - val upperCaseData = - TestSQLContext.sparkContext.parallelize( - UpperCaseData(1, "A") :: - UpperCaseData(2, "B") :: - UpperCaseData(3, "C") :: - UpperCaseData(4, "D") :: - UpperCaseData(5, "E") :: - UpperCaseData(6, "F") :: Nil).toDF() - upperCaseData.registerTempTable("upperCaseData") - - case class LowerCaseData(n: Int, l: String) - val lowerCaseData = - TestSQLContext.sparkContext.parallelize( - LowerCaseData(1, "a") :: - LowerCaseData(2, "b") :: - LowerCaseData(3, "c") :: - LowerCaseData(4, "d") :: Nil).toDF() - lowerCaseData.registerTempTable("lowerCaseData") - - case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]]) - val arrayData = - TestSQLContext.sparkContext.parallelize( - ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) :: - ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil) - arrayData.toDF().registerTempTable("arrayData") - - case class MapData(data: scala.collection.Map[Int, String]) - val mapData = - TestSQLContext.sparkContext.parallelize( - MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) :: - MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) :: - MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) :: - MapData(Map(1 -> "a4", 2 -> "b4")) :: - MapData(Map(1 -> "a5")) :: Nil) - mapData.toDF().registerTempTable("mapData") - - case class StringData(s: String) - val repeatedData = - 
TestSQLContext.sparkContext.parallelize(List.fill(2)(StringData("test"))) - repeatedData.toDF().registerTempTable("repeatedData") - - val nullableRepeatedData = - TestSQLContext.sparkContext.parallelize( - List.fill(2)(StringData(null)) ++ - List.fill(2)(StringData("test"))) - nullableRepeatedData.toDF().registerTempTable("nullableRepeatedData") - - case class NullInts(a: Integer) - val nullInts = - TestSQLContext.sparkContext.parallelize( - NullInts(1) :: - NullInts(2) :: - NullInts(3) :: - NullInts(null) :: Nil - ).toDF() - nullInts.registerTempTable("nullInts") - - val allNulls = - TestSQLContext.sparkContext.parallelize( - NullInts(null) :: - NullInts(null) :: - NullInts(null) :: - NullInts(null) :: Nil).toDF() - allNulls.registerTempTable("allNulls") - - case class NullStrings(n: Int, s: String) - val nullStrings = - TestSQLContext.sparkContext.parallelize( - NullStrings(1, "abc") :: - NullStrings(2, "ABC") :: - NullStrings(3, null) :: Nil).toDF() - nullStrings.registerTempTable("nullStrings") - - case class TableName(tableName: String) - TestSQLContext - .sparkContext - .parallelize(TableName("test") :: Nil) - .toDF() - .registerTempTable("tableName") - - val unparsedStrings = - TestSQLContext.sparkContext.parallelize( - "1, A1, true, null" :: - "2, B2, false, null" :: - "3, C3, true, null" :: - "4, D4, true, 2147483644" :: Nil) - - case class IntField(i: Int) - // An RDD with 4 elements and 8 partitions - val withEmptyParts = TestSQLContext.sparkContext.parallelize((1 to 4).map(IntField), 8) - withEmptyParts.toDF().registerTempTable("withEmptyParts") - - case class Person(id: Int, name: String, age: Int) - case class Salary(personId: Int, salary: Double) - val person = TestSQLContext.sparkContext.parallelize( - Person(0, "mike", 30) :: - Person(1, "jim", 20) :: Nil).toDF() - person.registerTempTable("person") - val salary = TestSQLContext.sparkContext.parallelize( - Salary(0, 2000.0) :: - Salary(1, 1000.0) :: Nil).toDF() - salary.registerTempTable("salary") - - case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean) - val complexData = - TestSQLContext.sparkContext.parallelize( - ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1, 1, 1), true) - :: ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2, 2, 2), false) - :: Nil).toDF() - complexData.registerTempTable("complexData") -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 183dc3407b3a..fd736718af12 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,19 +17,16 @@ package org.apache.spark.sql -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestData._ -case class FunctionResult(f1: String, f2: String) +private case class FunctionResult(f1: String, f2: String) -class UDFSuite extends QueryTest with SQLTestUtils { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - - override def sqlContext(): SQLContext = ctx +class UDFSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("built-in fixed arity expressions") { - val df = ctx.emptyDataFrame + val df = sqlContext.emptyDataFrame df.selectExpr("rand()", "randn()", "rand(5)", "randn(50)") } @@ -57,24 +54,24 @@ class UDFSuite extends QueryTest with SQLTestUtils { test("SPARK-8003 spark_partition_id") { val df = Seq((1, "Tearing 
down the walls that divide us")).toDF("id", "saying") df.registerTempTable("tmp_table") - checkAnswer(ctx.sql("select spark_partition_id() from tmp_table").toDF(), Row(0)) - ctx.dropTempTable("tmp_table") + checkAnswer(sql("select spark_partition_id() from tmp_table").toDF(), Row(0)) + sqlContext.dropTempTable("tmp_table") } test("SPARK-8005 input_file_name") { withTempPath { dir => - val data = ctx.sparkContext.parallelize(0 to 10, 2).toDF("id") + val data = sparkContext.parallelize(0 to 10, 2).toDF("id") data.write.parquet(dir.getCanonicalPath) - ctx.read.parquet(dir.getCanonicalPath).registerTempTable("test_table") - val answer = ctx.sql("select input_file_name() from test_table").head().getString(0) + sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("test_table") + val answer = sql("select input_file_name() from test_table").head().getString(0) assert(answer.contains(dir.getCanonicalPath)) - assert(ctx.sql("select input_file_name() from test_table").distinct().collect().length >= 2) - ctx.dropTempTable("test_table") + assert(sql("select input_file_name() from test_table").distinct().collect().length >= 2) + sqlContext.dropTempTable("test_table") } } test("error reporting for incorrect number of arguments") { - val df = ctx.emptyDataFrame + val df = sqlContext.emptyDataFrame val e = intercept[AnalysisException] { df.selectExpr("substr('abcd', 2, 3, 4)") } @@ -82,7 +79,7 @@ class UDFSuite extends QueryTest with SQLTestUtils { } test("error reporting for undefined functions") { - val df = ctx.emptyDataFrame + val df = sqlContext.emptyDataFrame val e = intercept[AnalysisException] { df.selectExpr("a_function_that_does_not_exist()") } @@ -90,41 +87,41 @@ class UDFSuite extends QueryTest with SQLTestUtils { } test("Simple UDF") { - ctx.udf.register("strLenScala", (_: String).length) - assert(ctx.sql("SELECT strLenScala('test')").head().getInt(0) === 4) + sqlContext.udf.register("strLenScala", (_: String).length) + assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4) } test("ZeroArgument UDF") { - ctx.udf.register("random0", () => { Math.random()}) - assert(ctx.sql("SELECT random0()").head().getDouble(0) >= 0.0) + sqlContext.udf.register("random0", () => { Math.random()}) + assert(sql("SELECT random0()").head().getDouble(0) >= 0.0) } test("TwoArgument UDF") { - ctx.udf.register("strLenScala", (_: String).length + (_: Int)) - assert(ctx.sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) + sqlContext.udf.register("strLenScala", (_: String).length + (_: Int)) + assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) } test("UDF in a WHERE") { - ctx.udf.register("oneArgFilter", (n: Int) => { n > 80 }) + sqlContext.udf.register("oneArgFilter", (n: Int) => { n > 80 }) - val df = ctx.sparkContext.parallelize( + val df = sparkContext.parallelize( (1 to 100).map(i => TestData(i, i.toString))).toDF() df.registerTempTable("integerData") val result = - ctx.sql("SELECT * FROM integerData WHERE oneArgFilter(key)") + sql("SELECT * FROM integerData WHERE oneArgFilter(key)") assert(result.count() === 20) } test("UDF in a HAVING") { - ctx.udf.register("havingFilter", (n: Long) => { n > 5 }) + sqlContext.udf.register("havingFilter", (n: Long) => { n > 5 }) val df = Seq(("red", 1), ("red", 2), ("blue", 10), ("green", 100), ("green", 200)).toDF("g", "v") df.registerTempTable("groupData") val result = - ctx.sql( + sql( """ | SELECT g, SUM(v) as s | FROM groupData @@ -136,14 +133,14 @@ class UDFSuite extends QueryTest with SQLTestUtils { } test("UDF in a GROUP BY") { 
- ctx.udf.register("groupFunction", (n: Int) => { n > 10 }) + sqlContext.udf.register("groupFunction", (n: Int) => { n > 10 }) val df = Seq(("red", 1), ("red", 2), ("blue", 10), ("green", 100), ("green", 200)).toDF("g", "v") df.registerTempTable("groupData") val result = - ctx.sql( + sql( """ | SELECT SUM(v) | FROM groupData @@ -153,17 +150,17 @@ class UDFSuite extends QueryTest with SQLTestUtils { } test("UDFs everywhere") { - ctx.udf.register("groupFunction", (n: Int) => { n > 10 }) - ctx.udf.register("havingFilter", (n: Long) => { n > 2000 }) - ctx.udf.register("whereFilter", (n: Int) => { n < 150 }) - ctx.udf.register("timesHundred", (n: Long) => { n * 100 }) + sqlContext.udf.register("groupFunction", (n: Int) => { n > 10 }) + sqlContext.udf.register("havingFilter", (n: Long) => { n > 2000 }) + sqlContext.udf.register("whereFilter", (n: Int) => { n < 150 }) + sqlContext.udf.register("timesHundred", (n: Long) => { n * 100 }) val df = Seq(("red", 1), ("red", 2), ("blue", 10), ("green", 100), ("green", 200)).toDF("g", "v") df.registerTempTable("groupData") val result = - ctx.sql( + sql( """ | SELECT timesHundred(SUM(v)) as v100 | FROM groupData @@ -175,23 +172,79 @@ class UDFSuite extends QueryTest with SQLTestUtils { } test("struct UDF") { - ctx.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) + sqlContext.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) val result = - ctx.sql("SELECT returnStruct('test', 'test2') as ret") + sql("SELECT returnStruct('test', 'test2') as ret") .select($"ret.f1").head().getString(0) assert(result === "test") } test("udf that is transformed") { - ctx.udf.register("makeStruct", (x: Int, y: Int) => (x, y)) + sqlContext.udf.register("makeStruct", (x: Int, y: Int) => (x, y)) // 1 + 1 is constant folded causing a transformation. - assert(ctx.sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2)) + assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2)) } test("type coercion for udf inputs") { - ctx.udf.register("intExpected", (x: Int) => x) + sqlContext.udf.register("intExpected", (x: Int) => x) // pass a decimal to intExpected. 
- assert(ctx.sql("SELECT intExpected(1.0)").head().getInt(0) === 1) + assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1) + } + + test("udf in different types") { + sqlContext.udf.register("testDataFunc", (n: Int, s: String) => { (n, s) }) + sqlContext.udf.register("decimalDataFunc", + (a: java.math.BigDecimal, b: java.math.BigDecimal) => { (a, b) }) + sqlContext.udf.register("binaryDataFunc", (a: Array[Byte], b: Int) => { (a, b) }) + sqlContext.udf.register("arrayDataFunc", + (data: Seq[Int], nestedData: Seq[Seq[Int]]) => { (data, nestedData) }) + sqlContext.udf.register("mapDataFunc", + (data: scala.collection.Map[Int, String]) => { data }) + sqlContext.udf.register("complexDataFunc", + (m: Map[String, Int], a: Seq[Int], b: Boolean) => { (m, a, b) } ) + + checkAnswer( + sql("SELECT tmp.t.* FROM (SELECT testDataFunc(key, value) AS t from testData) tmp").toDF(), + testData) + checkAnswer( + sql(""" + | SELECT tmp.t.* FROM + | (SELECT decimalDataFunc(a, b) AS t FROM decimalData) tmp + """.stripMargin).toDF(), decimalData) + checkAnswer( + sql(""" + | SELECT tmp.t.* FROM + | (SELECT binaryDataFunc(a, b) AS t FROM binaryData) tmp + """.stripMargin).toDF(), binaryData) + checkAnswer( + sql(""" + | SELECT tmp.t.* FROM + | (SELECT arrayDataFunc(data, nestedData) AS t FROM arrayData) tmp + """.stripMargin).toDF(), arrayData.toDF()) + checkAnswer( + sql(""" + | SELECT mapDataFunc(data) AS t FROM mapData + """.stripMargin).toDF(), mapData.toDF()) + checkAnswer( + sql(""" + | SELECT tmp.t.* FROM + | (SELECT complexDataFunc(m, a, b) AS t FROM complexData) tmp + """.stripMargin).toDF(), complexData.select("m", "a", "b")) + } + + test("SPARK-11716 UDFRegistration does not include the input data type in returned UDF") { + val myUDF = sqlContext.udf.register("testDataFunc", (n: Int, s: String) => { (n, s.toInt) }) + + // Without the fix, this will fail because we fail to cast data type of b to string + // because myUDF does not know its input data type. With the fix, this query should not + // fail. 
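// Illustrative sketch, not part of the patch: the point of SPARK-11716 (the assertions
// continue just below) is that udf.register both registers the function for SQL and returns a
// UserDefinedFunction that remembers its input types, so the analyzer can insert the needed
// cast (the Int column b to String here) when the UDF is applied through the DataFrame API.
// "toPair" is a hypothetical name; testData2 and $"..." come from the suite's fixtures and
// testImplicits.
val toPair = sqlContext.udf.register("toPair", (n: Int, s: String) => (n, s.toInt))
val viaColumns = testData2.select(toPair($"a", $"b").as("t"))           // needs the cast of b
val viaSql = sqlContext.sql("SELECT toPair(a, b) AS t FROM testData2")  // resolved by name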
+ checkAnswer( + testData2.select(myUDF($"a", $"b").as("t")), + testData2.selectExpr("struct(a, b)")) + + checkAnswer( + sql("SELECT tmp.t.* FROM (SELECT testDataFunc(a, b) AS t from testData2) tmp").toDF(), + testData2) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala index c5faaa663e74..00f1526576cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -19,29 +19,67 @@ package org.apache.spark.sql import java.io.ByteArrayOutputStream -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.memory.MemoryAllocator import org.apache.spark.unsafe.types.UTF8String class UnsafeRowSuite extends SparkFunSuite { + + test("UnsafeRow Java serialization") { + // serializing an UnsafeRow pointing to a large buffer should only serialize the relevant data + val data = new Array[Byte](1024) + val row = new UnsafeRow + row.pointTo(data, 1, 16) + row.setLong(0, 19285) + + val ser = new JavaSerializer(new SparkConf).newInstance() + val row1 = ser.deserialize[UnsafeRow](ser.serialize(row)) + assert(row1.getLong(0) == 19285) + assert(row1.getBaseObject().asInstanceOf[Array[Byte]].length == 16) + } + + test("UnsafeRow Kryo serialization") { + // serializing an UnsafeRow pointing to a large buffer should only serialize the relevant data + val data = new Array[Byte](1024) + val row = new UnsafeRow + row.pointTo(data, 1, 16) + row.setLong(0, 19285) + + val ser = new KryoSerializer(new SparkConf).newInstance() + val row1 = ser.deserialize[UnsafeRow](ser.serialize(row)) + assert(row1.getLong(0) == 19285) + assert(row1.getBaseObject().asInstanceOf[Array[Byte]].length == 16) + } + + test("bitset width calculation") { + assert(UnsafeRow.calculateBitSetWidthInBytes(0) === 0) + assert(UnsafeRow.calculateBitSetWidthInBytes(1) === 8) + assert(UnsafeRow.calculateBitSetWidthInBytes(32) === 8) + assert(UnsafeRow.calculateBitSetWidthInBytes(64) === 8) + assert(UnsafeRow.calculateBitSetWidthInBytes(65) === 16) + assert(UnsafeRow.calculateBitSetWidthInBytes(128) === 16) + } + test("writeToStream") { val row = InternalRow.apply(UTF8String.fromString("hello"), UTF8String.fromString("world"), 123) val arrayBackedUnsafeRow: UnsafeRow = UnsafeProjection.create(Array[DataType](StringType, StringType, IntegerType)).apply(row) assert(arrayBackedUnsafeRow.getBaseObject.isInstanceOf[Array[Byte]]) - val bytesFromArrayBackedRow: Array[Byte] = { + val (bytesFromArrayBackedRow, field0StringFromArrayBackedRow): (Array[Byte], String) = { val baos = new ByteArrayOutputStream() arrayBackedUnsafeRow.writeToStream(baos, null) - baos.toByteArray + (baos.toByteArray, arrayBackedUnsafeRow.getString(0)) } - val bytesFromOffheapRow: Array[Byte] = { + val (bytesFromOffheapRow, field0StringFromOffheapRow): (Array[Byte], String) = { val offheapRowPage = MemoryAllocator.UNSAFE.allocate(arrayBackedUnsafeRow.getSizeInBytes) try { - 
PlatformDependent.copyMemory( + Platform.copyMemory( arrayBackedUnsafeRow.getBaseObject, arrayBackedUnsafeRow.getBaseOffset, offheapRowPage.getBaseObject, @@ -59,13 +97,14 @@ class UnsafeRowSuite extends SparkFunSuite { val baos = new ByteArrayOutputStream() val writeBuffer = new Array[Byte](1024) offheapUnsafeRow.writeToStream(baos, writeBuffer) - baos.toByteArray + (baos.toByteArray, offheapUnsafeRow.getString(0)) } finally { MemoryAllocator.UNSAFE.free(offheapRowPage) } } assert(bytesFromArrayBackedRow === bytesFromOffheapRow) + assert(field0StringFromArrayBackedRow === field0StringFromOffheapRow) } test("calling getDouble() and getFloat() on null columns") { @@ -120,4 +159,11 @@ class UnsafeRowSuite extends SparkFunSuite { assert(emptyRow.getInt(0) === unsafeRow.getInt(0)) assert(emptyRow.getUTF8String(1) === unsafeRow.getUTF8String(1)) } + + test("calling hashCode on unsafe array returned by getArray(ordinal)") { + val row = InternalRow.apply(new GenericArrayData(Array(1L))) + val unsafeRow = UnsafeProjection.create(Array[DataType](ArrayType(LongType))).apply(row) + // Makes sure hashCode on unsafe array won't crash + unsafeRow.getArray(0).hashCode() + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index f29935224e5b..f602f2fb89ca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -17,15 +17,16 @@ package org.apache.spark.sql -import scala.beans.{BeanInfo, BeanProperty} +import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData} -import com.clearspring.analytics.stream.cardinality.HyperLogLog +import scala.beans.{BeanInfo, BeanProperty} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT} +import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.execution.datasources.parquet.ParquetTest import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils import org.apache.spark.util.collection.OpenHashSet @@ -66,10 +67,8 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { private[spark] override def asNullable: MyDenseVectorUDT = this } -class UserDefinedTypeSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest { + import testImplicits._ private lazy val pointsRDD = Seq( MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))), @@ -91,24 +90,35 @@ class UserDefinedTypeSuite extends QueryTest { } test("UDTs and UDFs") { - ctx.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) + sqlContext.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) pointsRDD.registerTempTable("points") checkAnswer( - ctx.sql("SELECT testType(features) from points"), + sql("SELECT testType(features) from points"), Seq(Row(true), Row(true))) } - - test("UDTs with Parquet") { - val tempDir = Utils.createTempDir() - tempDir.delete() - pointsRDD.write.parquet(tempDir.getCanonicalPath) + testStandardAndLegacyModes("UDTs with Parquet") { + withTempPath { dir => + val path = dir.getCanonicalPath + pointsRDD.write.parquet(path) + checkAnswer( + 
sqlContext.read.parquet(path), + Seq( + Row(1.0, new MyDenseVector(Array(0.1, 1.0))), + Row(0.0, new MyDenseVector(Array(0.2, 2.0))))) + } } - test("Repartition UDTs with Parquet") { - val tempDir = Utils.createTempDir() - tempDir.delete() - pointsRDD.repartition(1).write.parquet(tempDir.getCanonicalPath) + testStandardAndLegacyModes("Repartition UDTs with Parquet") { + withTempPath { dir => + val path = dir.getCanonicalPath + pointsRDD.repartition(1).write.parquet(path) + checkAnswer( + sqlContext.read.parquet(path), + Seq( + Row(1.0, new MyDenseVector(Array(0.1, 1.0))), + Row(0.0, new MyDenseVector(Array(0.2, 2.0))))) + } } // Tests to make sure that all operators correctly convert types on the way out. @@ -120,22 +130,38 @@ class UserDefinedTypeSuite extends QueryTest { df.orderBy('int).limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0) } - test("HyperLogLogUDT") { - val hyperLogLogUDT = HyperLogLogUDT - val hyperLogLog = new HyperLogLog(0.4) - (1 to 10).foreach(i => hyperLogLog.offer(Row(i))) + test("UDTs with JSON") { + val data = Seq( + "{\"id\":1,\"vec\":[1.1,2.2,3.3,4.4]}", + "{\"id\":2,\"vec\":[2.25,4.5,8.75]}" + ) + val schema = StructType(Seq( + StructField("id", IntegerType, false), + StructField("vec", new MyDenseVectorUDT, false) + )) + + val stringRDD = sparkContext.parallelize(data) + val jsonRDD = sqlContext.read.schema(schema).json(stringRDD) + checkAnswer( + jsonRDD, + Row(1, new MyDenseVector(Array(1.1, 2.2, 3.3, 4.4))) :: + Row(2, new MyDenseVector(Array(2.25, 4.5, 8.75))) :: + Nil + ) + } - val actual = hyperLogLogUDT.deserialize(hyperLogLogUDT.serialize(hyperLogLog)) - assert(actual.cardinality() === hyperLogLog.cardinality()) - assert(java.util.Arrays.equals(actual.getBytes, hyperLogLog.getBytes)) + test("SPARK-10472 UserDefinedType.typeName") { + assert(IntegerType.typeName === "integer") + assert(new MyDenseVectorUDT().typeName === "mydensevector") } - test("OpenHashSetUDT") { - val openHashSetUDT = new OpenHashSetUDT(IntegerType) - val set = new OpenHashSet[Int] - (1 to 10).foreach(i => set.add(i)) + test("Catalyst type converter null handling for UDTs") { + val udt = new MyDenseVectorUDT() + val toScalaConverter = CatalystTypeConverters.createToScalaConverter(udt) + assert(toScalaConverter(null) === null) + + val toCatalystConverter = CatalystTypeConverters.createToCatalystConverter(udt) + assert(toCatalystConverter(null) === null) - val actual = openHashSetUDT.deserialize(openHashSetUDT.serialize(set)) - assert(actual.iterator.toSet === set.iterator.toSet) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/r/SQLUtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/r/SQLUtilsSuite.scala new file mode 100644 index 000000000000..f54e23e3aa6c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/api/r/SQLUtilsSuite.scala @@ -0,0 +1,38 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. 
You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.api.r + +import org.apache.spark.sql.test.SharedSQLContext + +class SQLUtilsSuite extends SharedSQLContext { + + import testImplicits._ + + test("dfToCols should collect and transpose a data frame") { + val df = Seq( + (1, 2, 3), + (4, 5, 6) + ).toDF + assert(SQLUtils.dfToCols(df) === Array( + Array(1, 4), + Array(2, 5), + Array(3, 6) + )) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala deleted file mode 100644 index 8f024690efd0..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ /dev/null @@ -1,291 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
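// Illustrative sketch, not part of the patch: dfToCols (checked in the SQLUtilsSuite test
// above) collects a DataFrame and transposes the row-major result into one array per column,
// the column-wise layout consumed on the SparkR side. The transposition the assertion pins
// down, as a plain standalone function:
import org.apache.spark.sql.Row
def transposeRows(rows: Array[Row]): Array[Array[Any]] = {
  if (rows.isEmpty) Array.empty
  else Array.tabulate(rows.head.length) { col => rows.map(_.get(col)) }
}
// e.g. transposeRows(Array(Row(1, 2, 3), Row(4, 5, 6)))
// gives Array(Array(1, 4), Array(2, 5), Array(3, 6)), matching the expected dfToCols output.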
- */ - -package org.apache.spark.sql.columnar - -import java.nio.ByteBuffer - -import com.esotericsoftware.kryo.io.{Input, Output} -import com.esotericsoftware.kryo.{Kryo, Serializer} - -import org.apache.spark.{Logging, SparkConf, SparkFunSuite} -import org.apache.spark.serializer.KryoRegistrator -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.columnar.ColumnarTestUtils._ -import org.apache.spark.sql.execution.SparkSqlSerializer -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String - - -class ColumnTypeSuite extends SparkFunSuite with Logging { - private val DEFAULT_BUFFER_SIZE = 512 - private val MAP_GENERIC = GENERIC(MapType(IntegerType, StringType)) - - test("defaultSize") { - val checks = Map( - BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4, DATE -> 4, - LONG -> 8, TIMESTAMP -> 8, FLOAT -> 4, DOUBLE -> 8, - STRING -> 8, BINARY -> 16, FIXED_DECIMAL(15, 10) -> 8, - MAP_GENERIC -> 16) - - checks.foreach { case (columnType, expectedSize) => - assertResult(expectedSize, s"Wrong defaultSize for $columnType") { - columnType.defaultSize - } - } - } - - test("actualSize") { - def checkActualSize[JvmType]( - columnType: ColumnType[JvmType], - value: JvmType, - expected: Int): Unit = { - - assertResult(expected, s"Wrong actualSize for $columnType") { - val row = new GenericMutableRow(1) - columnType.setField(row, 0, value) - columnType.actualSize(row, 0) - } - } - - checkActualSize(BOOLEAN, true, 1) - checkActualSize(BYTE, Byte.MaxValue, 1) - checkActualSize(SHORT, Short.MaxValue, 2) - checkActualSize(INT, Int.MaxValue, 4) - checkActualSize(DATE, Int.MaxValue, 4) - checkActualSize(LONG, Long.MaxValue, 8) - checkActualSize(TIMESTAMP, Long.MaxValue, 8) - checkActualSize(FLOAT, Float.MaxValue, 4) - checkActualSize(DOUBLE, Double.MaxValue, 8) - checkActualSize(STRING, UTF8String.fromString("hello"), 4 + "hello".getBytes("utf-8").length) - checkActualSize(BINARY, Array.fill[Byte](4)(0.toByte), 4 + 4) - checkActualSize(FIXED_DECIMAL(15, 10), Decimal(0, 15, 10), 8) - - val generic = Map(1 -> "a") - checkActualSize(MAP_GENERIC, SparkSqlSerializer.serialize(generic), 4 + 8) - } - - testNativeColumnType(BOOLEAN)( - (buffer: ByteBuffer, v: Boolean) => { - buffer.put((if (v) 1 else 0).toByte) - }, - (buffer: ByteBuffer) => { - buffer.get() == 1 - }) - - testNativeColumnType(BYTE)(_.put(_), _.get) - - testNativeColumnType(SHORT)(_.putShort(_), _.getShort) - - testNativeColumnType(INT)(_.putInt(_), _.getInt) - - testNativeColumnType(DATE)(_.putInt(_), _.getInt) - - testNativeColumnType(LONG)(_.putLong(_), _.getLong) - - testNativeColumnType(TIMESTAMP)(_.putLong(_), _.getLong) - - testNativeColumnType(FLOAT)(_.putFloat(_), _.getFloat) - - testNativeColumnType(DOUBLE)(_.putDouble(_), _.getDouble) - - testNativeColumnType(FIXED_DECIMAL(15, 10))( - (buffer: ByteBuffer, decimal: Decimal) => { - buffer.putLong(decimal.toUnscaledLong) - }, - (buffer: ByteBuffer) => { - Decimal(buffer.getLong(), 15, 10) - }) - - - testNativeColumnType(STRING)( - (buffer: ByteBuffer, string: UTF8String) => { - val bytes = string.getBytes - buffer.putInt(bytes.length) - buffer.put(bytes) - }, - (buffer: ByteBuffer) => { - val length = buffer.getInt() - val bytes = new Array[Byte](length) - buffer.get(bytes) - UTF8String.fromBytes(bytes) - }) - - testColumnType[Array[Byte]]( - BINARY, - (buffer: ByteBuffer, bytes: Array[Byte]) => { - buffer.putInt(bytes.length).put(bytes) - }, - (buffer: ByteBuffer) => { - val length = buffer.getInt() - val bytes = new 
Array[Byte](length) - buffer.get(bytes, 0, length) - bytes - }) - - test("GENERIC") { - val buffer = ByteBuffer.allocate(512) - val obj = Map(1 -> "spark", 2 -> "sql") - val serializedObj = SparkSqlSerializer.serialize(obj) - - MAP_GENERIC.append(SparkSqlSerializer.serialize(obj), buffer) - buffer.rewind() - - val length = buffer.getInt() - assert(length === serializedObj.length) - - assertResult(obj, "Deserialized object didn't equal to the original object") { - val bytes = new Array[Byte](length) - buffer.get(bytes, 0, length) - SparkSqlSerializer.deserialize(bytes) - } - - buffer.rewind() - buffer.putInt(serializedObj.length).put(serializedObj) - - assertResult(obj, "Deserialized object didn't equal to the original object") { - buffer.rewind() - SparkSqlSerializer.deserialize(MAP_GENERIC.extract(buffer)) - } - } - - test("CUSTOM") { - val conf = new SparkConf() - conf.set("spark.kryo.registrator", "org.apache.spark.sql.columnar.Registrator") - val serializer = new SparkSqlSerializer(conf).newInstance() - - val buffer = ByteBuffer.allocate(512) - val obj = CustomClass(Int.MaxValue, Long.MaxValue) - val serializedObj = serializer.serialize(obj).array() - - MAP_GENERIC.append(serializer.serialize(obj).array(), buffer) - buffer.rewind() - - val length = buffer.getInt - assert(length === serializedObj.length) - assert(13 == length) // id (1) + int (4) + long (8) - - val genericSerializedObj = SparkSqlSerializer.serialize(obj) - assert(length != genericSerializedObj.length) - assert(length < genericSerializedObj.length) - - assertResult(obj, "Custom deserialized object didn't equal the original object") { - val bytes = new Array[Byte](length) - buffer.get(bytes, 0, length) - serializer.deserialize(ByteBuffer.wrap(bytes)) - } - - buffer.rewind() - buffer.putInt(serializedObj.length).put(serializedObj) - - assertResult(obj, "Custom deserialized object didn't equal the original object") { - buffer.rewind() - serializer.deserialize(ByteBuffer.wrap(MAP_GENERIC.extract(buffer))) - } - } - - def testNativeColumnType[T <: AtomicType]( - columnType: NativeColumnType[T]) - (putter: (ByteBuffer, T#InternalType) => Unit, - getter: (ByteBuffer) => T#InternalType): Unit = { - - testColumnType[T#InternalType](columnType, putter, getter) - } - - def testColumnType[JvmType]( - columnType: ColumnType[JvmType], - putter: (ByteBuffer, JvmType) => Unit, - getter: (ByteBuffer) => JvmType): Unit = { - - val buffer = ByteBuffer.allocate(DEFAULT_BUFFER_SIZE) - val seq = (0 until 4).map(_ => makeRandomValue(columnType)) - - test(s"$columnType.extract") { - buffer.rewind() - seq.foreach(putter(buffer, _)) - - buffer.rewind() - seq.foreach { expected => - logInfo("buffer = " + buffer + ", expected = " + expected) - val extracted = columnType.extract(buffer) - assert( - expected === extracted, - "Extracted value didn't equal to the original one. 
" + - hexDump(expected) + " != " + hexDump(extracted) + - ", buffer = " + dumpBuffer(buffer.duplicate().rewind().asInstanceOf[ByteBuffer])) - } - } - - test(s"$columnType.append") { - buffer.rewind() - seq.foreach(columnType.append(_, buffer)) - - buffer.rewind() - seq.foreach { expected => - assert( - expected === getter(buffer), - "Extracted value didn't equal to the original one") - } - } - } - - private def hexDump(value: Any): String = { - value.toString.map(ch => Integer.toHexString(ch & 0xffff)).mkString(" ") - } - - private def dumpBuffer(buff: ByteBuffer): Any = { - val sb = new StringBuilder() - while (buff.hasRemaining) { - val b = buff.get() - sb.append(Integer.toHexString(b & 0xff)).append(' ') - } - if (sb.nonEmpty) sb.setLength(sb.length - 1) - sb.toString() - } - - test("column type for decimal types with different precision") { - (1 to 18).foreach { i => - assertResult(FIXED_DECIMAL(i, 0)) { - ColumnType(DecimalType(i, 0)) - } - } - - assertResult(GENERIC(DecimalType(19, 0))) { - ColumnType(DecimalType(19, 0)) - } - } -} - -private[columnar] final case class CustomClass(a: Int, b: Long) - -private[columnar] object CustomerSerializer extends Serializer[CustomClass] { - override def write(kryo: Kryo, output: Output, t: CustomClass) { - output.writeInt(t.a) - output.writeLong(t.b) - } - override def read(kryo: Kryo, input: Input, aClass: Class[CustomClass]): CustomClass = { - val a = input.readInt() - val b = input.readLong() - CustomClass(a, b) - } -} - -private[columnar] final class Registrator extends KryoRegistrator { - override def registerClasses(kryo: Kryo) { - kryo.register(classOf[CustomClass], CustomerSerializer) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala deleted file mode 100644 index 20def6bef0c1..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution - -import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.test.TestSQLContext - -class AggregateSuite extends SparkPlanTest { - - test("SPARK-8357 unsafe aggregation path should not leak memory with empty input") { - val codegenDefault = TestSQLContext.getConf(SQLConf.CODEGEN_ENABLED) - val unsafeDefault = TestSQLContext.getConf(SQLConf.UNSAFE_ENABLED) - try { - TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, true) - TestSQLContext.setConf(SQLConf.UNSAFE_ENABLED, true) - val df = Seq.empty[(Int, Int)].toDF("a", "b") - checkAnswer( - df, - GeneratedAggregate( - partial = true, - Seq(df.col("b").expr), - Seq(Alias(Count(df.col("a").expr), "cnt")()), - unsafeEnabled = true, - _: SparkPlan), - Seq.empty - ) - } finally { - TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, codegenDefault) - TestSQLContext.setConf(SQLConf.UNSAFE_ENABLED, unsafeDefault) - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala new file mode 100644 index 000000000000..4ff96e6574ca --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.ExpressionEvalHelper + +class CoGroupedIteratorSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("basic") { + val leftInput = Seq(create_row(1, "a"), create_row(1, "b"), create_row(2, "c")).iterator + val rightInput = Seq(create_row(1, 2L), create_row(2, 3L), create_row(3, 4L)).iterator + val leftGrouped = GroupedIterator(leftInput, Seq('i.int.at(0)), Seq('i.int, 's.string)) + val rightGrouped = GroupedIterator(rightInput, Seq('i.int.at(0)), Seq('i.int, 'l.long)) + val cogrouped = new CoGroupedIterator(leftGrouped, rightGrouped, Seq('i.int)) + + val result = cogrouped.map { + case (key, leftData, rightData) => + assert(key.numFields == 1) + (key.getInt(0), leftData.toSeq, rightData.toSeq) + }.toSeq + assert(result == + (1, + Seq(create_row(1, "a"), create_row(1, "b")), + Seq(create_row(1, 2L))) :: + (2, + Seq(create_row(2, "c")), + Seq(create_row(2, 3L))) :: + (3, + Seq.empty, + Seq(create_row(3, 4L))) :: + Nil + ) + } + + test("SPARK-11393: respect the fact that GroupedIterator.hasNext is not idempotent") { + val leftInput = Seq(create_row(2, "a")).iterator + val rightInput = Seq(create_row(1, 2L)).iterator + val leftGrouped = GroupedIterator(leftInput, Seq('i.int.at(0)), Seq('i.int, 's.string)) + val rightGrouped = GroupedIterator(rightInput, Seq('i.int.at(0)), Seq('i.int, 'l.long)) + val cogrouped = new CoGroupedIterator(leftGrouped, rightGrouped, Seq('i.int)) + + val result = cogrouped.map { + case (key, leftData, rightData) => + assert(key.numFields == 1) + (key.getInt(0), leftData.toSeq, rightData.toSeq) + }.toSeq + + assert(result == + (1, + Seq.empty, + Seq(create_row(1, 2L))) :: + (2, + Seq(create_row(2, "a")), + Seq.empty) :: + Nil + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala new file mode 100644 index 000000000000..180050bdac00 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -0,0 +1,479 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql._ +import org.apache.spark.{SparkFunSuite, SparkContext, SparkConf, MapOutputStatistics} + +class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { + + private var originalActiveSQLContext: Option[SQLContext] = _ + private var originalInstantiatedSQLContext: Option[SQLContext] = _ + + override protected def beforeAll(): Unit = { + originalActiveSQLContext = SQLContext.getActive() + originalInstantiatedSQLContext = SQLContext.getInstantiatedContextOption() + + SQLContext.clearActive() + SQLContext.clearInstantiatedContext() + } + + override protected def afterAll(): Unit = { + // Set these states back. + originalActiveSQLContext.foreach(ctx => SQLContext.setActive(ctx)) + originalInstantiatedSQLContext.foreach(ctx => SQLContext.setInstantiatedContext(ctx)) + } + + private def checkEstimation( + coordinator: ExchangeCoordinator, + bytesByPartitionIdArray: Array[Array[Long]], + expectedPartitionStartIndices: Array[Int]): Unit = { + val mapOutputStatistics = bytesByPartitionIdArray.zipWithIndex.map { + case (bytesByPartitionId, index) => + new MapOutputStatistics(index, bytesByPartitionId) + } + val estimatedPartitionStartIndices = + coordinator.estimatePartitionStartIndices(mapOutputStatistics) + assert(estimatedPartitionStartIndices === expectedPartitionStartIndices) + } + + test("test estimatePartitionStartIndices - 1 Exchange") { + val coordinator = new ExchangeCoordinator(1, 100L) + + { + // All bytes per partition are 0. + val bytesByPartitionId = Array[Long](0, 0, 0, 0, 0) + val expectedPartitionStartIndices = Array[Int](0) + checkEstimation(coordinator, Array(bytesByPartitionId), expectedPartitionStartIndices) + } + + { + // Some bytes per partition are 0 and total size is less than the target size. + // 1 post-shuffle partition is needed. + val bytesByPartitionId = Array[Long](10, 0, 20, 0, 0) + val expectedPartitionStartIndices = Array[Int](0) + checkEstimation(coordinator, Array(bytesByPartitionId), expectedPartitionStartIndices) + } + + { + // 2 post-shuffle partitions are needed. + val bytesByPartitionId = Array[Long](10, 0, 90, 20, 0) + val expectedPartitionStartIndices = Array[Int](0, 3) + checkEstimation(coordinator, Array(bytesByPartitionId), expectedPartitionStartIndices) + } + + { + // There are a few large pre-shuffle partitions. + val bytesByPartitionId = Array[Long](110, 10, 100, 110, 0) + val expectedPartitionStartIndices = Array[Int](0, 1, 3, 4) + checkEstimation(coordinator, Array(bytesByPartitionId), expectedPartitionStartIndices) + } + + { + // All pre-shuffle partitions are larger than the targeted size. + val bytesByPartitionId = Array[Long](100, 110, 100, 110, 110) + val expectedPartitionStartIndices = Array[Int](0, 1, 2, 3, 4) + checkEstimation(coordinator, Array(bytesByPartitionId), expectedPartitionStartIndices) + } + + { + // The last pre-shuffle partition is in a single post-shuffle partition. + val bytesByPartitionId = Array[Long](30, 30, 0, 40, 110) + val expectedPartitionStartIndices = Array[Int](0, 4) + checkEstimation(coordinator, Array(bytesByPartitionId), expectedPartitionStartIndices) + } + } + + test("test estimatePartitionStartIndices - 2 Exchanges") { + val coordinator = new ExchangeCoordinator(2, 100L) + + { + // If there are multiple values of the number of pre-shuffle partitions, + // we should see an assertion error. 
+ val bytesByPartitionId1 = Array[Long](0, 0, 0, 0, 0) + val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0, 0) + val mapOutputStatistics = + Array( + new MapOutputStatistics(0, bytesByPartitionId1), + new MapOutputStatistics(1, bytesByPartitionId2)) + intercept[AssertionError](coordinator.estimatePartitionStartIndices(mapOutputStatistics)) + } + + { + // All bytes per partition are 0. + val bytesByPartitionId1 = Array[Long](0, 0, 0, 0, 0) + val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0) + val expectedPartitionStartIndices = Array[Int](0) + checkEstimation( + coordinator, + Array(bytesByPartitionId1, bytesByPartitionId2), + expectedPartitionStartIndices) + } + + { + // Some bytes per partition are 0. + // 1 post-shuffle partition is needed. + val bytesByPartitionId1 = Array[Long](0, 10, 0, 20, 0) + val bytesByPartitionId2 = Array[Long](30, 0, 20, 0, 20) + val expectedPartitionStartIndices = Array[Int](0) + checkEstimation( + coordinator, + Array(bytesByPartitionId1, bytesByPartitionId2), + expectedPartitionStartIndices) + } + + { + // 2 post-shuffle partition are needed. + val bytesByPartitionId1 = Array[Long](0, 10, 0, 20, 0) + val bytesByPartitionId2 = Array[Long](30, 0, 70, 0, 30) + val expectedPartitionStartIndices = Array[Int](0, 3) + checkEstimation( + coordinator, + Array(bytesByPartitionId1, bytesByPartitionId2), + expectedPartitionStartIndices) + } + + { + // 2 post-shuffle partition are needed. + val bytesByPartitionId1 = Array[Long](0, 99, 0, 20, 0) + val bytesByPartitionId2 = Array[Long](30, 0, 70, 0, 30) + val expectedPartitionStartIndices = Array[Int](0, 2) + checkEstimation( + coordinator, + Array(bytesByPartitionId1, bytesByPartitionId2), + expectedPartitionStartIndices) + } + + { + // 2 post-shuffle partition are needed. + val bytesByPartitionId1 = Array[Long](0, 100, 0, 30, 0) + val bytesByPartitionId2 = Array[Long](30, 0, 70, 0, 30) + val expectedPartitionStartIndices = Array[Int](0, 2, 4) + checkEstimation( + coordinator, + Array(bytesByPartitionId1, bytesByPartitionId2), + expectedPartitionStartIndices) + } + + { + // There are a few large pre-shuffle partitions. + val bytesByPartitionId1 = Array[Long](0, 100, 40, 30, 0) + val bytesByPartitionId2 = Array[Long](30, 0, 60, 0, 110) + val expectedPartitionStartIndices = Array[Int](0, 2, 3) + checkEstimation( + coordinator, + Array(bytesByPartitionId1, bytesByPartitionId2), + expectedPartitionStartIndices) + } + + { + // All pairs of pre-shuffle partitions are larger than the targeted size. + val bytesByPartitionId1 = Array[Long](100, 100, 40, 30, 0) + val bytesByPartitionId2 = Array[Long](30, 0, 60, 70, 110) + val expectedPartitionStartIndices = Array[Int](0, 1, 2, 3, 4) + checkEstimation( + coordinator, + Array(bytesByPartitionId1, bytesByPartitionId2), + expectedPartitionStartIndices) + } + } + + test("test estimatePartitionStartIndices and enforce minimal number of reducers") { + val coordinator = new ExchangeCoordinator(2, 100L, Some(2)) + + { + // The minimal number of post-shuffle partitions is not enforced because + // the size of data is 0. + val bytesByPartitionId1 = Array[Long](0, 0, 0, 0, 0) + val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0) + val expectedPartitionStartIndices = Array[Int](0) + checkEstimation( + coordinator, + Array(bytesByPartitionId1, bytesByPartitionId2), + expectedPartitionStartIndices) + } + + { + // The minimal number of post-shuffle partitions is enforced. 
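// Illustrative sketch, not part of the patch: a self-contained re-derivation of the coalescing
// rule that the estimation cases above (and the remaining ones just below) imply. For each
// pre-shuffle partition index, the bytes reported by every exchange are summed into a running
// total; once the total reaches the target size, the current post-shuffle partition is closed
// and the next one starts at index i + 1 (never past the last index). This only mirrors the
// tested behaviour and ignores the minimum-partitions adjustment exercised by the surrounding
// test; it is not the ExchangeCoordinator implementation itself.
def estimateStartIndices(bytesByPartitionId: Seq[Array[Long]], targetSize: Long): Array[Int] = {
  val numPartitions = bytesByPartitionId.head.length
  val starts = scala.collection.mutable.ArrayBuffer(0)
  var running = 0L
  for (i <- 0 until numPartitions) {
    running += bytesByPartitionId.map(_(i)).sum
    if (running >= targetSize) {
      if (i < numPartitions - 1) starts += (i + 1)
      running = 0L
    }
  }
  starts.toArray
}
// For example, estimateStartIndices(Seq(Array[Long](110, 10, 100, 110, 0)), 100L)
// yields Array(0, 1, 3, 4), matching the corresponding single-exchange case above.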
+ val bytesByPartitionId1 = Array[Long](10, 5, 5, 0, 20) + val bytesByPartitionId2 = Array[Long](5, 10, 0, 10, 5) + val expectedPartitionStartIndices = Array[Int](0, 3) + checkEstimation( + coordinator, + Array(bytesByPartitionId1, bytesByPartitionId2), + expectedPartitionStartIndices) + } + + { + // The number of post-shuffle partitions is determined by the coordinator. + val bytesByPartitionId1 = Array[Long](10, 50, 20, 80, 20) + val bytesByPartitionId2 = Array[Long](40, 10, 0, 10, 30) + val expectedPartitionStartIndices = Array[Int](0, 2, 4) + checkEstimation( + coordinator, + Array(bytesByPartitionId1, bytesByPartitionId2), + expectedPartitionStartIndices) + } + } + + /////////////////////////////////////////////////////////////////////////// + // Query tests + /////////////////////////////////////////////////////////////////////////// + + val numInputPartitions: Int = 10 + + def checkAnswer(actual: => DataFrame, expectedAnswer: Seq[Row]): Unit = { + QueryTest.checkAnswer(actual, expectedAnswer) match { + case Some(errorMessage) => fail(errorMessage) + case None => + } + } + + def withSQLContext( + f: SQLContext => Unit, + targetNumPostShufflePartitions: Int, + minNumPostShufflePartitions: Option[Int]): Unit = { + val sparkConf = + new SparkConf(false) + .setMaster("local[*]") + .setAppName("test") + .set("spark.ui.enabled", "false") + .set("spark.driver.allowMultipleContexts", "true") + .set(SQLConf.SHUFFLE_PARTITIONS.key, "5") + .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true") + .set( + SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, + targetNumPostShufflePartitions.toString) + minNumPostShufflePartitions match { + case Some(numPartitions) => + sparkConf.set(SQLConf.SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS.key, numPartitions.toString) + case None => + sparkConf.set(SQLConf.SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS.key, "-1") + } + val sparkContext = new SparkContext(sparkConf) + val sqlContext = new TestSQLContext(sparkContext) + try f(sqlContext) finally sparkContext.stop() + } + + Seq(Some(3), None).foreach { minNumPostShufflePartitions => + val testNameNote = minNumPostShufflePartitions match { + case Some(numPartitions) => "(minNumPostShufflePartitions: 3)" + case None => "" + } + + test(s"determining the number of reducers: aggregate operator$testNameNote") { + val test = { sqlContext: SQLContext => + val df = + sqlContext + .range(0, 1000, 1, numInputPartitions) + .selectExpr("id % 20 as key", "id as value") + val agg = df.groupBy("key").count + + // Check the answer first. + checkAnswer( + agg, + sqlContext.range(0, 20).selectExpr("id", "50 as cnt").collect()) + + // Then, let's look at the number of post-shuffle partitions estimated + // by the ExchangeCoordinator. 
+ val exchanges = agg.queryExecution.executedPlan.collect { + case e: Exchange => e + } + assert(exchanges.length === 1) + minNumPostShufflePartitions match { + case Some(numPartitions) => + exchanges.foreach { + case e: Exchange => + assert(e.coordinator.isDefined) + assert(e.outputPartitioning.numPartitions === 3) + case o => + } + + case None => + exchanges.foreach { + case e: Exchange => + assert(e.coordinator.isDefined) + assert(e.outputPartitioning.numPartitions === 2) + case o => + } + } + } + + withSQLContext(test, 1536, minNumPostShufflePartitions) + } + + test(s"determining the number of reducers: join operator$testNameNote") { + val test = { sqlContext: SQLContext => + val df1 = + sqlContext + .range(0, 1000, 1, numInputPartitions) + .selectExpr("id % 500 as key1", "id as value1") + val df2 = + sqlContext + .range(0, 1000, 1, numInputPartitions) + .selectExpr("id % 500 as key2", "id as value2") + + val join = df1.join(df2, col("key1") === col("key2")).select(col("key1"), col("value2")) + + // Check the answer first. + val expectedAnswer = + sqlContext + .range(0, 1000) + .selectExpr("id % 500 as key", "id as value") + .unionAll(sqlContext.range(0, 1000).selectExpr("id % 500 as key", "id as value")) + checkAnswer( + join, + expectedAnswer.collect()) + + // Then, let's look at the number of post-shuffle partitions estimated + // by the ExchangeCoordinator. + val exchanges = join.queryExecution.executedPlan.collect { + case e: Exchange => e + } + assert(exchanges.length === 2) + minNumPostShufflePartitions match { + case Some(numPartitions) => + exchanges.foreach { + case e: Exchange => + assert(e.coordinator.isDefined) + assert(e.outputPartitioning.numPartitions === 3) + case o => + } + + case None => + exchanges.foreach { + case e: Exchange => + assert(e.coordinator.isDefined) + assert(e.outputPartitioning.numPartitions === 2) + case o => + } + } + } + + withSQLContext(test, 16384, minNumPostShufflePartitions) + } + + test(s"determining the number of reducers: complex query 1$testNameNote") { + val test = { sqlContext: SQLContext => + val df1 = + sqlContext + .range(0, 1000, 1, numInputPartitions) + .selectExpr("id % 500 as key1", "id as value1") + .groupBy("key1") + .count + .toDF("key1", "cnt1") + val df2 = + sqlContext + .range(0, 1000, 1, numInputPartitions) + .selectExpr("id % 500 as key2", "id as value2") + .groupBy("key2") + .count + .toDF("key2", "cnt2") + + val join = df1.join(df2, col("key1") === col("key2")).select(col("key1"), col("cnt2")) + + // Check the answer first. + val expectedAnswer = + sqlContext + .range(0, 500) + .selectExpr("id", "2 as cnt") + checkAnswer( + join, + expectedAnswer.collect()) + + // Then, let's look at the number of post-shuffle partitions estimated + // by the ExchangeCoordinator. 
+ val exchanges = join.queryExecution.executedPlan.collect { + case e: Exchange => e + } + assert(exchanges.length === 4) + minNumPostShufflePartitions match { + case Some(numPartitions) => + exchanges.foreach { + case e: Exchange => + assert(e.coordinator.isDefined) + assert(e.outputPartitioning.numPartitions === 3) + case o => + } + + case None => + assert(exchanges.forall(_.coordinator.isDefined)) + assert(exchanges.map(_.outputPartitioning.numPartitions).toSeq.toSet === Set(1, 2)) + } + } + + withSQLContext(test, 6144, minNumPostShufflePartitions) + } + + test(s"determining the number of reducers: complex query 2$testNameNote") { + val test = { sqlContext: SQLContext => + val df1 = + sqlContext + .range(0, 1000, 1, numInputPartitions) + .selectExpr("id % 500 as key1", "id as value1") + .groupBy("key1") + .count + .toDF("key1", "cnt1") + val df2 = + sqlContext + .range(0, 1000, 1, numInputPartitions) + .selectExpr("id % 500 as key2", "id as value2") + + val join = + df1 + .join(df2, col("key1") === col("key2")) + .select(col("key1"), col("cnt1"), col("value2")) + + // Check the answer first. + val expectedAnswer = + sqlContext + .range(0, 1000) + .selectExpr("id % 500 as key", "2 as cnt", "id as value") + checkAnswer( + join, + expectedAnswer.collect()) + + // Then, let's look at the number of post-shuffle partitions estimated + // by the ExchangeCoordinator. + val exchanges = join.queryExecution.executedPlan.collect { + case e: Exchange => e + } + assert(exchanges.length === 3) + minNumPostShufflePartitions match { + case Some(numPartitions) => + exchanges.foreach { + case e: Exchange => + assert(e.coordinator.isDefined) + assert(e.outputPartitioning.numPartitions === 3) + case o => + } + + case None => + assert(exchanges.forall(_.coordinator.isDefined)) + assert(exchanges.map(_.outputPartitioning.numPartitions).toSeq.toSet === Set(2, 3)) + } + } + + withSQLContext(test, 6144, minNumPostShufflePartitions) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index 79e903c2bbd4..911d12e93e50 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -19,8 +19,11 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.plans.physical.SinglePartition +import org.apache.spark.sql.test.SharedSQLContext + +class ExchangeSuite extends SparkPlanTest with SharedSQLContext { + import testImplicits.localSeqToDataFrameHolder -class ExchangeSuite extends SparkPlanTest { test("shuffling UnsafeRows in exchange") { val input = (1 to 1000).map(Tuple1.apply) checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExpandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExpandSuite.scala new file mode 100644 index 000000000000..faef76d52ae7 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExpandSuite.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, Alias, Literal} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.IntegerType + +class ExpandSuite extends SparkPlanTest with SharedSQLContext { + import testImplicits.localSeqToDataFrameHolder + + private def testExpand(f: SparkPlan => SparkPlan): Unit = { + val input = (1 to 1000).map(Tuple1.apply) + val projections = Seq.tabulate(2) { i => + Alias(BoundReference(0, IntegerType, false), "id")() :: Alias(Literal(i), "gid")() :: Nil + } + val attributes = projections.head.map(_.toAttribute) + checkAnswer( + input.toDF(), + plan => Expand(projections, attributes, f(plan)), + input.flatMap(i => Seq.tabulate(2)(j => Row(i._1, j))) + ) + } + + test("inheriting child row type") { + val exprs = AttributeReference("a", IntegerType, false)() :: Nil + val plan = Expand(Seq(exprs), exprs, ConvertToUnsafe(LocalTableScan(exprs, Seq.empty))) + assert(plan.outputsUnsafeRows, "Expand should inherit the created row type from its child.") + } + + test("expanding UnsafeRows") { + testExpand(ConvertToUnsafe) + } + + test("expanding SafeRows") { + testExpand(identity) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala new file mode 100644 index 000000000000..e7a08481cfa8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.types.{LongType, StringType, IntegerType, StructType} + +class GroupedIteratorSuite extends SparkFunSuite { + + test("basic") { + val schema = new StructType().add("i", IntegerType).add("s", StringType) + val encoder = RowEncoder(schema) + val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c")) + val grouped = GroupedIterator(input.iterator.map(encoder.toRow), + Seq('i.int.at(0)), schema.toAttributes) + + val result = grouped.map { + case (key, data) => + assert(key.numFields == 1) + key.getInt(0) -> data.map(encoder.fromRow).toSeq + }.toSeq + + assert(result == + 1 -> Seq(input(0), input(1)) :: + 2 -> Seq(input(2)) :: Nil) + } + + test("group by 2 columns") { + val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType) + val encoder = RowEncoder(schema) + + val input = Seq( + Row(1, 2L, "a"), + Row(1, 2L, "b"), + Row(1, 3L, "c"), + Row(2, 1L, "d"), + Row(3, 2L, "e")) + + val grouped = GroupedIterator(input.iterator.map(encoder.toRow), + Seq('i.int.at(0), 'l.long.at(1)), schema.toAttributes) + + val result = grouped.map { + case (key, data) => + assert(key.numFields == 2) + (key.getInt(0), key.getLong(1), data.map(encoder.fromRow).toSeq) + }.toSeq + + assert(result == + (1, 2L, Seq(input(0), input(1))) :: + (1, 3L, Seq(input(2))) :: + (2, 1L, Seq(input(3))) :: + (3, 2L, Seq(input(4))) :: Nil) + } + + test("do nothing to the value iterator") { + val schema = new StructType().add("i", IntegerType).add("s", StringType) + val encoder = RowEncoder(schema) + val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c")) + val grouped = GroupedIterator(input.iterator.map(encoder.toRow), + Seq('i.int.at(0)), schema.toAttributes) + + assert(grouped.length == 2) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 18b0e54dc7c5..2fb439f50117 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -17,39 +17,43 @@ package org.apache.spark.sql.execution -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.TestData._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{execution, Row, SQLConf} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.joins.{SortMergeJoin, BroadcastHashJoin} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext} -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ -import org.apache.spark.sql.test.TestSQLContext.planner._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.sql.{SQLContext, Row, SQLConf, execution} -class PlannerSuite extends SparkFunSuite with SQLTestUtils { +class PlannerSuite extends 
SharedSQLContext { + import testImplicits._ - override def sqlContext: SQLContext = TestSQLContext + setupTestData() private def testPartialAggregationPlan(query: LogicalPlan): Unit = { - val plannedOption = HashAggregation(query).headOption.orElse(Aggregation(query).headOption) + val planner = sqlContext.planner + import planner._ + val plannedOption = Aggregation(query).headOption val planned = plannedOption.getOrElse( fail(s"Could query play aggregation query $query. Is it an aggregation query?")) val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n } - // For the new aggregation code path, there will be three aggregate operator for + // For the new aggregation code path, there will be four aggregate operator for // distinct aggregations. assert( - aggregations.size == 2 || aggregations.size == 3, + aggregations.size == 2 || aggregations.size == 4, s"The plan of query $query does not have partial aggregations.") } test("unions are collapsed") { + val planner = sqlContext.planner + import planner._ val query = testData.unionAll(testData).unionAll(testData).logicalPlan val planned = BasicOperators(query).head val logicalUnions = query collect { case u: logical.Union => u } @@ -76,33 +80,30 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils { } test("sizeInBytes estimation of limit operator for broadcast hash join optimization") { - def checkPlan(fieldTypes: Seq[DataType], newThreshold: Int): Unit = { - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold) - val fields = fieldTypes.zipWithIndex.map { - case (dataType, index) => StructField(s"c${index}", dataType, true) - } :+ StructField("key", IntegerType, true) - val schema = StructType(fields) - val row = Row.fromSeq(Seq.fill(fields.size)(null)) - val rowRDD = org.apache.spark.sql.test.TestSQLContext.sparkContext.parallelize(row :: Nil) - createDataFrame(rowRDD, schema).registerTempTable("testLimit") - - val planned = sql( - """ - |SELECT l.a, l.b - |FROM testData2 l JOIN (SELECT * FROM testLimit LIMIT 1) r ON (l.a = r.key) - """.stripMargin).queryExecution.executedPlan - - val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } - val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join } - - assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") - assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") - - dropTempTable("testLimit") + def checkPlan(fieldTypes: Seq[DataType]): Unit = { + withTempTable("testLimit") { + val fields = fieldTypes.zipWithIndex.map { + case (dataType, index) => StructField(s"c${index}", dataType, true) + } :+ StructField("key", IntegerType, true) + val schema = StructType(fields) + val row = Row.fromSeq(Seq.fill(fields.size)(null)) + val rowRDD = sparkContext.parallelize(row :: Nil) + sqlContext.createDataFrame(rowRDD, schema).registerTempTable("testLimit") + + val planned = sql( + """ + |SELECT l.a, l.b + |FROM testData2 l JOIN (SELECT * FROM testLimit LIMIT 1) r ON (l.a = r.key) + """.stripMargin).queryExecution.executedPlan + + val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } + val sortMergeJoins = planned.collect { case join: SortMergeJoin => join } + + assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") + assert(sortMergeJoins.isEmpty, "Should not use sort merge join") + } } - val origThreshold = conf.autoBroadcastJoinThreshold - val simpleTypes = NullType :: BooleanType :: @@ -119,7 +120,9 @@ class PlannerSuite extends 
SparkFunSuite with SQLTestUtils { StringType :: BinaryType :: Nil - checkPlan(simpleTypes, newThreshold = 16434) + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "16434") { + checkPlan(simpleTypes) + } val complexTypes = ArrayType(DoubleType, true) :: @@ -131,35 +134,64 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils { StructField("b", ArrayType(DoubleType), nullable = false), StructField("c", DoubleType, nullable = false))) :: Nil - checkPlan(complexTypes, newThreshold = 901617) - - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold) + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "901617") { + checkPlan(complexTypes) + } } test("InMemoryRelation statistics propagation") { - val origThreshold = conf.autoBroadcastJoinThreshold - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920) + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "81920") { + withTempTable("tiny") { + testData.limit(3).registerTempTable("tiny") + sql("CACHE TABLE tiny") + + val a = testData.as("a") + val b = sqlContext.table("tiny").as("b") + val planned = a.join(b, $"a.key" === $"b.key").queryExecution.executedPlan - testData.limit(3).registerTempTable("tiny") - sql("CACHE TABLE tiny") + val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } + val sortMergeJoins = planned.collect { case join: SortMergeJoin => join } - val a = testData.as("a") - val b = table("tiny").as("b") - val planned = a.join(b, $"a.key" === $"b.key").queryExecution.executedPlan + assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") + assert(sortMergeJoins.isEmpty, "Should not use sort merge join") - val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } - val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join } + sqlContext.clearCache() + } + } + } - assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") - assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") + test("SPARK-11390 explain should print PushedFilters of PhysicalRDD") { + withTempPath { file => + val path = file.getCanonicalPath + testData.write.parquet(path) + val df = sqlContext.read.parquet(path) + sqlContext.registerDataFrameAsTable(df, "testPushed") - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold) + withTempTable("testPushed") { + val exp = sql("select * from testPushed where key = 15").queryExecution.executedPlan + assert(exp.toString.contains("PushedFilters: [EqualTo(key,15)]")) + } + } } test("efficient limit -> project -> sort") { - val query = testData.sort('key).select('value).limit(2).logicalPlan - val planned = planner.TakeOrderedAndProject(query) - assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject]) + { + val query = + testData.select('key, 'value).sort('key).limit(2).logicalPlan + val planned = sqlContext.planner.TakeOrderedAndProject(query) + assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject]) + assert(planned.head.output === testData.select('key, 'value).logicalPlan.output) + } + + { + // We need to make sure TakeOrderedAndProject's output is correct when we push a project + // into it. 
+ val query = + testData.select('key, 'value).sort('key).select('value, 'key).limit(2).logicalPlan + val planned = sqlContext.planner.TakeOrderedAndProject(query) + assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject]) + assert(planned.head.output === testData.select('value, 'key).logicalPlan.output) + } } test("PartitioningCollection") { @@ -202,4 +234,200 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils { } } } + + // --- Unit tests of EnsureRequirements --------------------------------------------------------- + + // When it comes to testing whether EnsureRequirements properly ensures distribution requirements, + // there are two dimensions that need to be considered: are the child partitionings compatible and + // do they satisfy the distribution requirements? As a result, we need at least four test cases. + + private def assertDistributionRequirementsAreSatisfied(outputPlan: SparkPlan): Unit = { + if (outputPlan.children.length > 1 + && outputPlan.requiredChildDistribution.toSet != Set(UnspecifiedDistribution)) { + val childPartitionings = outputPlan.children.map(_.outputPartitioning) + if (!Partitioning.allCompatible(childPartitionings)) { + fail(s"Partitionings are not compatible: $childPartitionings") + } + } + outputPlan.children.zip(outputPlan.requiredChildDistribution).foreach { + case (child, requiredDist) => + assert(child.outputPartitioning.satisfies(requiredDist), + s"$child output partitioning does not satisfy $requiredDist:\n$outputPlan") + } + } + + test("EnsureRequirements with incompatible child partitionings which satisfy distribution") { + // Consider an operator that requires inputs that are clustered by two expressions (e.g. + // sort merge join where there are multiple columns in the equi-join condition) + val clusteringA = Literal(1) :: Nil + val clusteringB = Literal(2) :: Nil + val distribution = ClusteredDistribution(clusteringA ++ clusteringB) + // Say that the left and right inputs are each partitioned by _one_ of the two join columns: + val leftPartitioning = HashPartitioning(clusteringA, 1) + val rightPartitioning = HashPartitioning(clusteringB, 1) + // Individually, each input's partitioning satisfies the clustering distribution: + assert(leftPartitioning.satisfies(distribution)) + assert(rightPartitioning.satisfies(distribution)) + // However, these partitionings are not compatible with each other, so we still need to + // repartition both inputs prior to performing the join: + assert(!leftPartitioning.compatibleWith(rightPartitioning)) + assert(!rightPartitioning.compatibleWith(leftPartitioning)) + val inputPlan = DummySparkPlan( + children = Seq( + DummySparkPlan(outputPartitioning = leftPartitioning), + DummySparkPlan(outputPartitioning = rightPartitioning) + ), + requiredChildDistribution = Seq(distribution, distribution), + requiredChildOrdering = Seq(Seq.empty, Seq.empty) + ) + val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (outputPlan.collect { case e: Exchange => true }.isEmpty) { + fail(s"Exchange should have been added:\n$outputPlan") + } + } + + test("EnsureRequirements with child partitionings with different numbers of output partitions") { + // This is similar to the previous test, except it checks that partitionings are not compatible + // unless they produce the same number of partitions.
+ val clustering = Literal(1) :: Nil + val distribution = ClusteredDistribution(clustering) + val inputPlan = DummySparkPlan( + children = Seq( + DummySparkPlan(outputPartitioning = HashPartitioning(clustering, 1)), + DummySparkPlan(outputPartitioning = HashPartitioning(clustering, 2)) + ), + requiredChildDistribution = Seq(distribution, distribution), + requiredChildOrdering = Seq(Seq.empty, Seq.empty) + ) + val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + } + + test("EnsureRequirements with compatible child partitionings that do not satisfy distribution") { + val distribution = ClusteredDistribution(Literal(1) :: Nil) + // The left and right inputs have compatible partitionings but they do not satisfy the + // distribution because they are clustered on different columns. Thus, we need to shuffle. + val childPartitioning = HashPartitioning(Literal(2) :: Nil, 1) + assert(!childPartitioning.satisfies(distribution)) + val inputPlan = DummySparkPlan( + children = Seq( + DummySparkPlan(outputPartitioning = childPartitioning), + DummySparkPlan(outputPartitioning = childPartitioning) + ), + requiredChildDistribution = Seq(distribution, distribution), + requiredChildOrdering = Seq(Seq.empty, Seq.empty) + ) + val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (outputPlan.collect { case e: Exchange => true }.isEmpty) { + fail(s"Exchange should have been added:\n$outputPlan") + } + } + + test("EnsureRequirements with compatible child partitionings that satisfy distribution") { + // In this case, all requirements are satisfied and no exchange should be added. + val distribution = ClusteredDistribution(Literal(1) :: Nil) + val childPartitioning = HashPartitioning(Literal(1) :: Nil, 5) + assert(childPartitioning.satisfies(distribution)) + val inputPlan = DummySparkPlan( + children = Seq( + DummySparkPlan(outputPartitioning = childPartitioning), + DummySparkPlan(outputPartitioning = childPartitioning) + ), + requiredChildDistribution = Seq(distribution, distribution), + requiredChildOrdering = Seq(Seq.empty, Seq.empty) + ) + val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (outputPlan.collect { case e: Exchange => true }.nonEmpty) { + fail(s"Exchange should not have been added:\n$outputPlan") + } + } + + // This is a regression test for SPARK-9703 + test("EnsureRequirements should not repartition if only ordering requirement is unsatisfied") { + // Consider an operator that imposes both output distribution and ordering requirements on its + // children, such as sort merge join. If the distribution requirements are satisfied but + // the output ordering requirements are unsatisfied, then the planner should only add sorts and + // should not need to add additional shuffles / exchanges.
+ val outputOrdering = Seq(SortOrder(Literal(1), Ascending)) + val distribution = ClusteredDistribution(Literal(1) :: Nil) + val inputPlan = DummySparkPlan( + children = Seq( + DummySparkPlan(outputPartitioning = SinglePartition), + DummySparkPlan(outputPartitioning = SinglePartition) + ), + requiredChildDistribution = Seq(distribution, distribution), + requiredChildOrdering = Seq(outputOrdering, outputOrdering) + ) + val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (outputPlan.collect { case e: Exchange => true }.nonEmpty) { + fail(s"No Exchanges should have been added:\n$outputPlan") + } + } + + test("EnsureRequirements adds sort when there is no existing ordering") { + val orderingA = SortOrder(Literal(1), Ascending) + val orderingB = SortOrder(Literal(2), Ascending) + assert(orderingA != orderingB) + val inputPlan = DummySparkPlan( + children = DummySparkPlan(outputOrdering = Seq.empty) :: Nil, + requiredChildOrdering = Seq(Seq(orderingB)), + requiredChildDistribution = Seq(UnspecifiedDistribution) + ) + val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (outputPlan.collect { case s: Sort => true }.isEmpty) { + fail(s"Sort should have been added:\n$outputPlan") + } + } + + test("EnsureRequirements skips sort when required ordering is prefix of existing ordering") { + val orderingA = SortOrder(Literal(1), Ascending) + val orderingB = SortOrder(Literal(2), Ascending) + assert(orderingA != orderingB) + val inputPlan = DummySparkPlan( + children = DummySparkPlan(outputOrdering = Seq(orderingA, orderingB)) :: Nil, + requiredChildOrdering = Seq(Seq(orderingA)), + requiredChildDistribution = Seq(UnspecifiedDistribution) + ) + val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (outputPlan.collect { case s: Sort => true }.nonEmpty) { + fail(s"No sorts should have been added:\n$outputPlan") + } + } + + // This is a regression test for SPARK-11135 + test("EnsureRequirements adds sort when required ordering isn't a prefix of existing ordering") { + val orderingA = SortOrder(Literal(1), Ascending) + val orderingB = SortOrder(Literal(2), Ascending) + assert(orderingA != orderingB) + val inputPlan = DummySparkPlan( + children = DummySparkPlan(outputOrdering = Seq(orderingA)) :: Nil, + requiredChildOrdering = Seq(Seq(orderingA, orderingB)), + requiredChildDistribution = Seq(UnspecifiedDistribution) + ) + val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (outputPlan.collect { case s: Sort => true }.isEmpty) { + fail(s"Sort should have been added:\n$outputPlan") + } + } + + // --------------------------------------------------------------------------------------------- +} + +// Used for unit-testing EnsureRequirements +private case class DummySparkPlan( + override val children: Seq[SparkPlan] = Nil, + override val outputOrdering: Seq[SortOrder] = Nil, + override val outputPartitioning: Partitioning = UnknownPartitioning(0), + override val requiredChildDistribution: Seq[Distribution] = Nil, + override val requiredChildOrdering: Seq[Seq[SortOrder]] = Nil + ) extends SparkPlan { + override protected def doExecute(): RDD[InternalRow] = throw new NotImplementedError + override def output: Seq[Attribute] = Seq.empty } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala new file mode 100644 index 000000000000..9575d26fd123 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.{InternalAccumulator, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder} +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.util.CompletionIterator +import org.apache.spark.util.collection.ExternalSorter + + +/** + * A reference sort implementation used to compare against our normal sort. + */ +case class ReferenceSort( + sortOrder: Seq[SortOrder], + global: Boolean, + child: SparkPlan) + extends UnaryNode { + + override def requiredChildDistribution: Seq[Distribution] = + if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { + child.execute().mapPartitions( { iterator => + val ordering = newOrdering(sortOrder, child.output) + val sorter = new ExternalSorter[InternalRow, Null, InternalRow]( + TaskContext.get(), ordering = Some(ordering)) + sorter.insertAll(iterator.map(r => (r.copy(), null))) + val baseIterator = sorter.iterator.map(_._1) + val context = TaskContext.get() + context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) + context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) + context.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes) + CompletionIterator[InternalRow, Iterator[InternalRow]](baseIterator, sorter.stop()) + }, preservesPartitioning = true) + } + + override def output: Seq[Attribute] = child.output + + override def outputOrdering: Seq[SortOrder] = sortOrder +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index 707cd9c6d939..2328899bb2f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -17,42 +17,82 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.expressions.{Literal, IsNull} -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{SQLContext, Row} +import 
org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute, Literal, IsNull} +import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{ArrayType, StringType} +import org.apache.spark.unsafe.types.UTF8String -class RowFormatConvertersSuite extends SparkPlanTest { +class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext { private def getConverters(plan: SparkPlan): Seq[SparkPlan] = plan.collect { case c: ConvertToUnsafe => c case c: ConvertToSafe => c } - private val outputsSafe = ExternalSort(Nil, false, PhysicalRDD(Seq.empty, null)) + private val outputsSafe = ReferenceSort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) assert(!outputsSafe.outputsUnsafeRows) - private val outputsUnsafe = TungstenSort(Nil, false, PhysicalRDD(Seq.empty, null)) + private val outputsUnsafe = Sort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) assert(outputsUnsafe.outputsUnsafeRows) test("planner should insert unsafe->safe conversions when required") { val plan = Limit(10, outputsUnsafe) - val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) assert(preparedPlan.children.head.isInstanceOf[ConvertToSafe]) } test("filter can process unsafe rows") { val plan = Filter(IsNull(IsNull(Literal(1))), outputsUnsafe) - val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) assert(getConverters(preparedPlan).size === 1) assert(preparedPlan.outputsUnsafeRows) } test("filter can process safe rows") { val plan = Filter(IsNull(IsNull(Literal(1))), outputsSafe) - val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) assert(getConverters(preparedPlan).isEmpty) assert(!preparedPlan.outputsUnsafeRows) } + test("coalesce can process unsafe rows") { + val plan = Coalesce(1, outputsUnsafe) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) + assert(getConverters(preparedPlan).size === 1) + assert(preparedPlan.outputsUnsafeRows) + } + + test("except can process unsafe rows") { + val plan = Except(outputsUnsafe, outputsUnsafe) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) + assert(getConverters(preparedPlan).size === 2) + assert(preparedPlan.outputsUnsafeRows) + } + + test("except requires all of its input rows' formats to agree") { + val plan = Except(outputsSafe, outputsUnsafe) + assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) + assert(preparedPlan.outputsUnsafeRows) + } + + test("intersect can process unsafe rows") { + val plan = Intersect(outputsUnsafe, outputsUnsafe) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) + assert(getConverters(preparedPlan).size === 2) + assert(preparedPlan.outputsUnsafeRows) + } + + test("intersect requires all of its input rows' formats to agree") { + val plan = Intersect(outputsSafe, outputsUnsafe) + assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) + assert(preparedPlan.outputsUnsafeRows) + } + test("execute() fails an assertion if inputs rows are of different formats") { val e = intercept[AssertionError] { Union(Seq(outputsSafe, outputsUnsafe)).execute() @@ -63,28 +103,62 @@ class 
RowFormatConvertersSuite extends SparkPlanTest { test("union requires all of its input rows' formats to agree") { val plan = Union(Seq(outputsSafe, outputsUnsafe)) assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows) - val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) assert(preparedPlan.outputsUnsafeRows) } test("union can process safe rows") { val plan = Union(Seq(outputsSafe, outputsSafe)) - val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) assert(!preparedPlan.outputsUnsafeRows) } test("union can process unsafe rows") { val plan = Union(Seq(outputsUnsafe, outputsUnsafe)) - val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) assert(preparedPlan.outputsUnsafeRows) } test("round trip with ConvertToUnsafe and ConvertToSafe") { val input = Seq(("hello", 1), ("world", 2)) checkAnswer( - TestSQLContext.createDataFrame(input), + sqlContext.createDataFrame(input), plan => ConvertToSafe(ConvertToUnsafe(plan)), input.map(Row.fromTuple) ) } + + test("SPARK-9683: copy UTF8String when convert unsafe array/map to safe") { + SQLContext.setActive(sqlContext) + val schema = ArrayType(StringType) + val rows = (1 to 100).map { i => + InternalRow(new GenericArrayData(Array[Any](UTF8String.fromString(i.toString)))) + } + val relation = LocalTableScan(Seq(AttributeReference("t", schema)()), rows) + + val plan = + DummyPlan( + ConvertToSafe( + ConvertToUnsafe(relation))) + assert(plan.execute().collect().map(_.getUTF8String(0).toString) === (1 to 100).map(_.toString)) + } +} + +case class DummyPlan(child: SparkPlan) extends UnaryNode { + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitions { iter => + // This `DummyPlan` is in safe mode, so we don't need to do copy even we hold some + // values gotten from the incoming rows. + // we cache all strings here to make sure we have deep copied UTF8String inside incoming + // safe InternalRow. + val strings = new scala.collection.mutable.ArrayBuffer[UTF8String] + iter.foreach { row => + strings += row.getArray(0).getUTF8String(0) + } + strings.map(InternalRow(_)).iterator + } + } + + override def output: Seq[Attribute] = Seq(AttributeReference("a", StringType)()) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala new file mode 100644 index 000000000000..63639681ef80 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.util.Properties + +import scala.collection.parallel.CompositeThrowable + +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.sql.SQLContext + +class SQLExecutionSuite extends SparkFunSuite { + + test("concurrent query execution (SPARK-10548)") { + // Try to reproduce the issue with the old SparkContext + val conf = new SparkConf() + .setMaster("local[*]") + .setAppName("test") + val badSparkContext = new BadSparkContext(conf) + try { + testConcurrentQueryExecution(badSparkContext) + fail("unable to reproduce SPARK-10548") + } catch { + case e: IllegalArgumentException => + assert(e.getMessage.contains(SQLExecution.EXECUTION_ID_KEY)) + } finally { + badSparkContext.stop() + } + + // Verify that the issue is fixed with the latest SparkContext + val goodSparkContext = new SparkContext(conf) + try { + testConcurrentQueryExecution(goodSparkContext) + } finally { + goodSparkContext.stop() + } + } + + /** + * Trigger SPARK-10548 by mocking a parent and its child thread executing queries concurrently. + */ + private def testConcurrentQueryExecution(sc: SparkContext): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Initialize local properties. This is necessary for the test to pass. + sc.getLocalProperties + + // Set up a thread that runs executes a simple SQL query. + // Before starting the thread, mutate the execution ID in the parent. + // The child thread should not see the effect of this change. + var throwable: Option[Throwable] = None + val child = new Thread { + override def run(): Unit = { + try { + sc.parallelize(1 to 100).map { i => (i, i) }.toDF("a", "b").collect() + } catch { + case t: Throwable => + throwable = Some(t) + } + + } + } + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, "anything") + child.start() + child.join() + + // The throwable is thrown from the child thread so it doesn't have a helpful stack trace + throwable.foreach { t => + t.setStackTrace(t.getStackTrace ++ Thread.currentThread.getStackTrace) + throw t + } + } + +} + +/** + * A bad [[SparkContext]] that does not clone the inheritable thread local properties + * when passing them to children threads. + */ +private class BadSparkContext(conf: SparkConf) extends SparkContext(conf) { + protected[spark] override val localProperties = new InheritableThreadLocal[Properties] { + override protected def childValue(parent: Properties): Properties = new Properties(parent) + override protected def initialValue(): Properties = new Properties() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala index a2c10fdaf6cd..e5d34be4c65e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala @@ -17,13 +17,22 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.Row +import scala.util.Random + +import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{RandomDataGenerator, Row} + -class SortSuite extends SparkPlanTest { +/** + * Test sorting. 
Many of the test cases generate random data and compares the sorted result with one + * sorted by a reference implementation ([[ReferenceSort]]). + */ +class SortSuite extends SparkPlanTest with SharedSQLContext { + import testImplicits.localSeqToDataFrameHolder - // This test was originally added as an example of how to use [[SparkPlanTest]]; - // it's not designed to be a comprehensive test of ExternalSort. test("basic sorting using ExternalSort") { val input = Seq( @@ -34,14 +43,66 @@ class SortSuite extends SparkPlanTest { checkAnswer( input.toDF("a", "b", "c"), - ExternalSort('a.asc :: 'b.asc :: Nil, global = true, _: SparkPlan), + (child: SparkPlan) => Sort('a.asc :: 'b.asc :: Nil, global = true, child = child), input.sortBy(t => (t._1, t._2)).map(Row.fromTuple), sortAnswers = false) checkAnswer( input.toDF("a", "b", "c"), - ExternalSort('b.asc :: 'a.asc :: Nil, global = true, _: SparkPlan), + (child: SparkPlan) => Sort('b.asc :: 'a.asc :: Nil, global = true, child = child), input.sortBy(t => (t._2, t._1)).map(Row.fromTuple), sortAnswers = false) } + + test("sort followed by limit") { + checkThatPlansAgree( + (1 to 100).map(v => Tuple1(v)).toDF("a"), + (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child = child)), + (child: SparkPlan) => Limit(10, ReferenceSort('a.asc :: Nil, global = true, child)), + sortAnswers = false + ) + } + + test("sorting does not crash for large inputs") { + val sortOrder = 'a.asc :: Nil + val stringLength = 1024 * 1024 * 2 + checkThatPlansAgree( + Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1), + Sort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1), + ReferenceSort(sortOrder, global = true, _: SparkPlan), + sortAnswers = false + ) + } + + test("sorting updates peak execution memory") { + AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "unsafe external sort") { + checkThatPlansAgree( + (1 to 100).map(v => Tuple1(v)).toDF("a"), + (child: SparkPlan) => Sort('a.asc :: Nil, global = true, child = child), + (child: SparkPlan) => ReferenceSort('a.asc :: Nil, global = true, child), + sortAnswers = false) + } + } + + // Test sorting on different data types + for ( + dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType); + nullable <- Seq(true, false); + sortOrder <- Seq('a.asc :: Nil, 'a.desc :: Nil); + randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable) + ) { + test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") { + val inputData = Seq.fill(1000)(randomDataGenerator()) + val inputDf = sqlContext.createDataFrame( + sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), + StructType(StructField("a", dataType, nullable = true) :: Nil) + ) + checkThatPlansAgree( + inputDf, + p => ConvertToSafe(Sort(sortOrder, global = true, p: SparkPlan, testSpillFrequency = 23)), + ReferenceSort(sortOrder, global = true, _: SparkPlan), + sortAnswers = false + ) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index f46855edfe0d..8549a6a0f664 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -17,30 +17,21 @@ package org.apache.spark.sql.execution -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.util._ 
-import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.{SQLContext, DataFrame, DataFrameHolder, Row} - import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row, SQLContext} +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.test.SQLTestUtils + /** * Base class for writing tests for individual physical operators. For an example of how this * class's test helper methods can be used, see [[SortSuite]]. */ -class SparkPlanTest extends SparkFunSuite { - - protected def sqlContext: SQLContext = TestSQLContext - - /** - * Creates a DataFrame from a local Seq of Product. - */ - implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = { - sqlContext.implicits.localSeqToDataFrameHolder(data) - } +private[sql] abstract class SparkPlanTest extends SparkFunSuite { + protected def sqlContext: SQLContext /** * Runs the plan and makes sure the answer matches the expected result. @@ -186,7 +177,7 @@ object SparkPlanTest { return Some(errorMessage) } - compareAnswers(actualAnswer, expectedAnswer, sortAnswers).map { errorMessage => + SQLTestUtils.compareAnswers(actualAnswer, expectedAnswer, sortAnswers).map { errorMessage => s""" | Results do not match. | Actual result Spark plan: @@ -231,7 +222,7 @@ object SparkPlanTest { return Some(errorMessage) } - compareAnswers(sparkAnswer, expectedAnswer, sortAnswers).map { errorMessage => + SQLTestUtils.compareAnswers(sparkAnswer, expectedAnswer, sortAnswers).map { errorMessage => s""" | Results do not match for Spark plan: | $outputPlan @@ -240,46 +231,6 @@ object SparkPlanTest { } } - private def compareAnswers( - sparkAnswer: Seq[Row], - expectedAnswer: Seq[Row], - sort: Boolean): Option[String] = { - def prepareAnswer(answer: Seq[Row]): Seq[Row] = { - // Converts data to types that we can do equality comparison using Scala collections. - // For BigDecimal type, the Scala type has a better definition of equality test (similar to - // Java's java.math.BigDecimal.compareTo). - // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for - // equality test. - // This function is copied from Catalyst's QueryTest - val converted: Seq[Row] = answer.map { s => - Row.fromSeq(s.toSeq.map { - case d: java.math.BigDecimal => BigDecimal(d) - case b: Array[Byte] => b.toSeq - case o => o - }) - } - if (sort) { - converted.sortBy(_.toString()) - } else { - converted - } - } - if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { - val errorMessage = - s""" - | == Results == - | ${sideBySide( - s"== Expected Answer - ${expectedAnswer.size} ==" +: - prepareAnswer(expectedAnswer).map(_.toString()), - s"== Actual Answer - ${sparkAnswer.size} ==" +: - prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")} - """.stripMargin - Some(errorMessage) - } else { - None - } - } - private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = { // A very simple resolver to make writing tests easier. In contrast to the real resolver // this is always case sensitive and does not try to handle scoping or complex type resolution. 
@@ -287,14 +238,14 @@ object SparkPlanTest { outputPlan transform { case plan: SparkPlan => val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap - plan.transformExpressions { + plan transformExpressions { case UnresolvedAttribute(Seq(u)) => inputMap.getOrElse(u, sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) } } ) - resolvedPlan.executeCollect().toSeq + resolvedPlan.executeCollectPublic().toSeq } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala deleted file mode 100644 index 7978ed57a937..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala +++ /dev/null @@ -1,221 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import java.sql.{Timestamp, Date} - -import org.apache.spark.sql.test.TestSQLContext -import org.scalatest.BeforeAndAfterAll - -import org.apache.spark.rdd.ShuffledRDD -import org.apache.spark.serializer.Serializer -import org.apache.spark.{ShuffleDependency, SparkFunSuite} -import org.apache.spark.sql.types._ -import org.apache.spark.sql.Row -import org.apache.spark.sql.{MyDenseVectorUDT, QueryTest} - -class SparkSqlSerializer2DataTypeSuite extends SparkFunSuite { - // Make sure that we will not use serializer2 for unsupported data types. - def checkSupported(dataType: DataType, isSupported: Boolean): Unit = { - val testName = - s"${if (dataType == null) null else dataType.toString} is " + - s"${if (isSupported) "supported" else "unsupported"}" - - test(testName) { - assert(SparkSqlSerializer2.support(Array(dataType)) === isSupported) - } - } - - checkSupported(null, isSupported = true) - checkSupported(BooleanType, isSupported = true) - checkSupported(ByteType, isSupported = true) - checkSupported(ShortType, isSupported = true) - checkSupported(IntegerType, isSupported = true) - checkSupported(LongType, isSupported = true) - checkSupported(FloatType, isSupported = true) - checkSupported(DoubleType, isSupported = true) - checkSupported(DateType, isSupported = true) - checkSupported(TimestampType, isSupported = true) - checkSupported(StringType, isSupported = true) - checkSupported(BinaryType, isSupported = true) - checkSupported(DecimalType(10, 5), isSupported = true) - checkSupported(DecimalType.SYSTEM_DEFAULT, isSupported = true) - - // If NullType is the only data type in the schema, we do not support it. - checkSupported(NullType, isSupported = false) - // For now, ArrayType, MapType, and StructType are not supported. 
- checkSupported(ArrayType(DoubleType, true), isSupported = false) - checkSupported(ArrayType(StringType, false), isSupported = false) - checkSupported(MapType(IntegerType, StringType, true), isSupported = false) - checkSupported(MapType(IntegerType, ArrayType(DoubleType), false), isSupported = false) - checkSupported(StructType(StructField("a", IntegerType, true) :: Nil), isSupported = false) - // UDTs are not supported right now. - checkSupported(new MyDenseVectorUDT, isSupported = false) -} - -abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll { - var allColumns: String = _ - val serializerClass: Class[Serializer] = - classOf[SparkSqlSerializer2].asInstanceOf[Class[Serializer]] - var numShufflePartitions: Int = _ - var useSerializer2: Boolean = _ - - protected lazy val ctx = TestSQLContext - - override def beforeAll(): Unit = { - numShufflePartitions = ctx.conf.numShufflePartitions - useSerializer2 = ctx.conf.useSqlSerializer2 - - ctx.sql("set spark.sql.useSerializer2=true") - - val supportedTypes = - Seq(StringType, BinaryType, NullType, BooleanType, - ByteType, ShortType, IntegerType, LongType, - FloatType, DoubleType, DecimalType.SYSTEM_DEFAULT, DecimalType(6, 5), - DateType, TimestampType) - - val fields = supportedTypes.zipWithIndex.map { case (dataType, index) => - StructField(s"col$index", dataType, true) - } - allColumns = fields.map(_.name).mkString(",") - val schema = StructType(fields) - - // Create a RDD with all data types supported by SparkSqlSerializer2. - val rdd = - ctx.sparkContext.parallelize((1 to 1000), 10).map { i => - Row( - s"str${i}: test serializer2.", - s"binary${i}: test serializer2.".getBytes("UTF-8"), - null, - i % 2 == 0, - i.toByte, - i.toShort, - i, - Long.MaxValue - i.toLong, - (i + 0.25).toFloat, - (i + 0.75), - BigDecimal(Long.MaxValue.toString + ".12345"), - new java.math.BigDecimal(s"${i % 9 + 1}" + ".23456"), - new Date(i), - new Timestamp(i)) - } - - ctx.createDataFrame(rdd, schema).registerTempTable("shuffle") - - super.beforeAll() - } - - override def afterAll(): Unit = { - ctx.dropTempTable("shuffle") - ctx.sql(s"set spark.sql.shuffle.partitions=$numShufflePartitions") - ctx.sql(s"set spark.sql.useSerializer2=$useSerializer2") - super.afterAll() - } - - def checkSerializer[T <: Serializer]( - executedPlan: SparkPlan, - expectedSerializerClass: Class[T]): Unit = { - executedPlan.foreach { - case exchange: Exchange => - val shuffledRDD = exchange.execute() - val dependency = shuffledRDD.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] - val serializerNotSetMessage = - s"Expected $expectedSerializerClass as the serializer of Exchange. " + - s"However, the serializer was not set." - val serializer = dependency.serializer.getOrElse(fail(serializerNotSetMessage)) - val isExpectedSerializer = - serializer.getClass == expectedSerializerClass || - serializer.getClass == classOf[UnsafeRowSerializer] - val wrongSerializerErrorMessage = - s"Expected ${expectedSerializerClass.getCanonicalName} or " + - s"${classOf[UnsafeRowSerializer].getCanonicalName}. But " + - s"${serializer.getClass.getCanonicalName} is used." - assert(isExpectedSerializer, wrongSerializerErrorMessage) - case _ => // Ignore other nodes. 
- } - } - - test("key schema and value schema are not nulls") { - val df = ctx.sql(s"SELECT DISTINCT ${allColumns} FROM shuffle") - checkSerializer(df.queryExecution.executedPlan, serializerClass) - checkAnswer( - df, - ctx.table("shuffle").collect()) - } - - test("key schema is null") { - val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",") - val df = ctx.sql(s"SELECT $aggregations FROM shuffle") - checkSerializer(df.queryExecution.executedPlan, serializerClass) - checkAnswer( - df, - Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000)) - } - - test("value schema is null") { - val df = ctx.sql(s"SELECT col0 FROM shuffle ORDER BY col0") - checkSerializer(df.queryExecution.executedPlan, serializerClass) - assert(df.map(r => r.getString(0)).collect().toSeq === - ctx.table("shuffle").select("col0").map(r => r.getString(0)).collect().sorted.toSeq) - } - - test("no map output field") { - val df = ctx.sql(s"SELECT 1 + 1 FROM shuffle") - checkSerializer(df.queryExecution.executedPlan, classOf[SparkSqlSerializer]) - } - - test("types of fields are all NullTypes") { - // Test range partitioning code path. - val nulls = ctx.sql(s"SELECT null as a, null as b, null as c") - val df = nulls.unionAll(nulls).sort("a") - checkSerializer(df.queryExecution.executedPlan, classOf[SparkSqlSerializer]) - checkAnswer( - df, - Row(null, null, null) :: Row(null, null, null) :: Nil) - - // Test hash partitioning code path. - val oneRow = ctx.sql(s"SELECT DISTINCT null, null, null FROM shuffle") - checkSerializer(oneRow.queryExecution.executedPlan, classOf[SparkSqlSerializer]) - checkAnswer( - oneRow, - Row(null, null, null)) - } -} - -/** Tests SparkSqlSerializer2 with sort based shuffle without sort merge. */ -class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite { - override def beforeAll(): Unit = { - super.beforeAll() - // Sort merge will not be triggered. - val bypassMergeThreshold = - ctx.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) - ctx.sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold-1}") - } -} - -/** For now, we will use SparkSqlSerializer for sort based shuffle with sort merge. */ -class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite { - - override def beforeAll(): Unit = { - super.beforeAll() - // To trigger the sort merge. - val bypassMergeThreshold = - ctx.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) - ctx.sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold + 1}") - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala deleted file mode 100644 index 53de2d0f0771..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import org.apache.spark.shuffle.ShuffleMemoryManager - -/** - * A [[ShuffleMemoryManager]] that can be controlled to run out of memory. - */ -class TestShuffleMemoryManager extends ShuffleMemoryManager(Long.MaxValue) { - private var oom = false - - override def tryToAcquire(numBytes: Long): Long = { - if (oom) { - oom = false - 0 - } else { - // Uncomment the following to trace memory allocations. - // println(s"tryToAcquire $numBytes in " + - // Thread.currentThread().getStackTrace.mkString("", "\n -", "")) - val acquired = super.tryToAcquire(numBytes) - acquired - } - } - - override def release(numBytes: Long): Unit = { - // Uncomment the following to trace memory releases. - // println(s"release $numBytes in " + - // Thread.currentThread().getStackTrace.mkString("", "\n -", "")) - super.release(numBytes) - } - - def markAsOutOfMemory(): Unit = { - oom = true - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala deleted file mode 100644 index c7949848513c..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import scala.util.Random - -import org.scalatest.BeforeAndAfterAll - -import org.apache.spark.sql.{RandomDataGenerator, Row, SQLConf} -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.types._ - -/** - * A test suite that generates randomized data to test the [[TungstenSort]] operator. 
- */ -class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll { - - override def beforeAll(): Unit = { - TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true) - } - - override def afterAll(): Unit = { - TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get) - } - - test("sort followed by limit") { - checkThatPlansAgree( - (1 to 100).map(v => Tuple1(v)).toDF("a"), - (child: SparkPlan) => Limit(10, TungstenSort('a.asc :: Nil, true, child)), - (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)), - sortAnswers = false - ) - } - - test("sorting does not crash for large inputs") { - val sortOrder = 'a.asc :: Nil - val stringLength = 1024 * 1024 * 2 - checkThatPlansAgree( - Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1), - TungstenSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1), - Sort(sortOrder, global = true, _: SparkPlan), - sortAnswers = false - ) - } - - // Test sorting on different data types - for ( - dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType); - nullable <- Seq(true, false); - sortOrder <- Seq('a.asc :: Nil, 'a.desc :: Nil); - randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable) - ) { - test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") { - val inputData = Seq.fill(1000)(randomDataGenerator()) - val inputDf = TestSQLContext.createDataFrame( - TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), - StructType(StructField("a", dataType, nullable = true) :: Nil) - ) - assert(TungstenSort.supportsSchema(inputDf.schema)) - checkThatPlansAgree( - inputDf, - plan => ConvertToSafe( - TungstenSort(sortOrder, global = true, plan: SparkPlan, testSpillFrequency = 23)), - Sort(sortOrder, global = true, _: SparkPlan), - sortAnswers = false - ) - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 7c591f6143b9..5a8406789ab8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -23,12 +23,12 @@ import scala.util.{Try, Random} import org.scalatest.Matchers -import org.apache.spark.sql.catalyst.expressions.UnsafeProjection -import org.apache.spark.{TaskContextImpl, TaskContext, SparkFunSuite} +import org.apache.spark.{SparkConf, TaskContextImpl, TaskContext, SparkFunSuite} +import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} import org.apache.spark.unsafe.types.UTF8String /** @@ -36,7 +36,10 @@ import org.apache.spark.unsafe.types.UTF8String * * Use [[testWithMemoryLeakDetection]] rather than [[test]] to construct test cases. 
*/ -class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers { +class UnsafeFixedWidthAggregationMapSuite + extends SparkFunSuite + with Matchers + with SharedSQLContext { import UnsafeFixedWidthAggregationMap._ @@ -45,23 +48,22 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers { private def emptyAggregationBuffer: InternalRow = InternalRow(0) private val PAGE_SIZE_BYTES: Long = 1L << 26; // 64 megabytes + private var memoryManager: TestMemoryManager = null private var taskMemoryManager: TaskMemoryManager = null - private var shuffleMemoryManager: TestShuffleMemoryManager = null def testWithMemoryLeakDetection(name: String)(f: => Unit) { def cleanup(): Unit = { if (taskMemoryManager != null) { - val leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask() assert(taskMemoryManager.cleanUpAllAllocatedMemory() === 0) - assert(leakedShuffleMemory === 0) taskMemoryManager = null } TaskContext.unset() } test(name) { - taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) - shuffleMemoryManager = new TestShuffleMemoryManager + val conf = new SparkConf().set("spark.memory.offHeap.enabled", "false") + memoryManager = new TestMemoryManager(conf) + taskMemoryManager = new TaskMemoryManager(memoryManager, 0) TaskContext.setTaskContext(new TaskContextImpl( stageId = 0, @@ -69,7 +71,8 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers { taskAttemptId = Random.nextInt(10000), attemptNumber = 0, taskMemoryManager = taskMemoryManager, - metricsSystem = null)) + metricsSystem = null, + internalAccumulators = Seq.empty)) try { f @@ -92,7 +95,7 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers { testWithMemoryLeakDetection("supported schemas") { assert(supportsAggregationBufferSchema( StructType(StructField("x", DecimalType.USER_DEFAULT) :: Nil))) - assert(!supportsAggregationBufferSchema( + assert(supportsAggregationBufferSchema( StructType(StructField("x", DecimalType.SYSTEM_DEFAULT) :: Nil))) assert(!supportsAggregationBufferSchema(StructType(StructField("x", StringType) :: Nil))) assert( @@ -105,7 +108,6 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers { aggBufferSchema, groupKeySchema, taskMemoryManager, - shuffleMemoryManager, 1024, // initial capacity, PAGE_SIZE_BYTES, false // disable perf metrics @@ -120,7 +122,6 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers { aggBufferSchema, groupKeySchema, taskMemoryManager, - shuffleMemoryManager, 1024, // initial capacity PAGE_SIZE_BYTES, false // disable perf metrics @@ -148,7 +149,6 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers { aggBufferSchema, groupKeySchema, taskMemoryManager, - shuffleMemoryManager, 128, // initial capacity PAGE_SIZE_BYTES, false // disable perf metrics @@ -170,18 +170,11 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers { } testWithMemoryLeakDetection("test external sorting") { - // Calling this make sure we have block manager and everything else setup. - TestSQLContext - - // Memory consumption in the beginning of the task. 
- val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask() - val map = new UnsafeFixedWidthAggregationMap( emptyAggregationBuffer, aggBufferSchema, groupKeySchema, taskMemoryManager, - shuffleMemoryManager, 128, // initial capacity PAGE_SIZE_BYTES, false // disable perf metrics @@ -193,41 +186,158 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers { buf.setInt(0, keyString.length) assert(buf != null) } - - // Convert the map into a sorter val sorter = map.destructAndCreateExternalSorter() - withClue(s"destructAndCreateExternalSorter should release memory used by the map") { - // 4096 * 16 is the initial size allocated for the pointer/prefix array in the in-mem sorter. - assert(shuffleMemoryManager.getMemoryConsumptionForThisTask() === - initialMemoryConsumption + 4096 * 16) + // Add more keys to the sorter and make sure the results come out sorted. + val additionalKeys = randomStrings(1024) + additionalKeys.zipWithIndex.foreach { case (str, i) => + val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(str))) + buf.setInt(0, str.length) + + if ((i % 100) == 0) { + val sorter2 = map.destructAndCreateExternalSorter() + sorter.merge(sorter2) + } } + val sorter2 = map.destructAndCreateExternalSorter() + sorter.merge(sorter2) + + val out = new scala.collection.mutable.ArrayBuffer[String] + val iter = sorter.sortedIterator() + while (iter.next()) { + // At here, we also test if copy is correct. + val key = iter.getKey.copy() + val value = iter.getValue.copy() + assert(key.getString(0).length === value.getInt(0)) + out += key.getString(0) + } + + assert(out === (keys ++ additionalKeys).sorted) + map.free() + } + + testWithMemoryLeakDetection("test external sorting with an empty map") { + + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, + taskMemoryManager, + 128, // initial capacity + PAGE_SIZE_BYTES, + false // disable perf metrics + ) + val sorter = map.destructAndCreateExternalSorter() // Add more keys to the sorter and make sure the results come out sorted. val additionalKeys = randomStrings(1024) - val keyConverter = UnsafeProjection.create(groupKeySchema) - val valueConverter = UnsafeProjection.create(aggBufferSchema) - additionalKeys.zipWithIndex.foreach { case (str, i) => - val k = InternalRow(UTF8String.fromString(str)) - val v = InternalRow(str.length) - sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v)) + val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(str))) + buf.setInt(0, str.length) if ((i % 100) == 0) { - shuffleMemoryManager.markAsOutOfMemory() - sorter.closeCurrentPage() + val sorter2 = map.destructAndCreateExternalSorter() + sorter.merge(sorter2) } } + val sorter2 = map.destructAndCreateExternalSorter() + sorter.merge(sorter2) val out = new scala.collection.mutable.ArrayBuffer[String] val iter = sorter.sortedIterator() while (iter.next()) { - assert(iter.getKey.getString(0).length === iter.getValue.getInt(0)) - out += iter.getKey.getString(0) + // At here, we also test if copy is correct. 
+ val key = iter.getKey.copy() + val value = iter.getValue.copy() + assert(key.getString(0).length === value.getInt(0)) + out += key.getString(0) } - assert(out === (keys ++ additionalKeys).sorted) + assert(out === additionalKeys.sorted) + map.free() + } + + testWithMemoryLeakDetection("test external sorting with empty records") { + + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + StructType(Nil), + StructType(Nil), + taskMemoryManager, + 128, // initial capacity + PAGE_SIZE_BYTES, + false // disable perf metrics + ) + (1 to 10).foreach { i => + val buf = map.getAggregationBuffer(UnsafeRow.createFromByteArray(0, 0)) + assert(buf != null) + } + // Convert the map into a sorter. Right now, it contains one record. + val sorter = map.destructAndCreateExternalSorter() + + // Add more keys to the sorter and make sure the results come out sorted. + (1 to 4096).foreach { i => + map.getAggregationBufferFromUnsafeRow(UnsafeRow.createFromByteArray(0, 0)) + + if ((i % 100) == 0) { + val sorter2 = map.destructAndCreateExternalSorter() + sorter.merge(sorter2) + } + } + val sorter2 = map.destructAndCreateExternalSorter() + sorter.merge(sorter2) + + var count = 0 + val iter = sorter.sortedIterator() + while (iter.next()) { + // At here, we also test if copy is correct. + iter.getKey.copy() + iter.getValue.copy() + count += 1 + } + + // 1 record per map, spilled 42 times. + assert(count === 42) map.free() } + + testWithMemoryLeakDetection("convert to external sorter under memory pressure (SPARK-10474)") { + val pageSize = 4096 + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, + taskMemoryManager, + 128, // initial capacity + pageSize, + false // disable perf metrics + ) + + val rand = new Random(42) + for (i <- 1 to 100) { + val str = rand.nextString(1024) + val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(str))) + buf.setInt(0, str.length) + } + // Simulate running out of space + memoryManager.limit(0) + val str = rand.nextString(1024) + val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(str))) + assert(buf == null) + + // Convert the map into a sorter. This used to fail before the fix for SPARK-10474 + // because we would try to acquire space for the in-memory sorter pointer array before + // actually releasing the pages despite having spilled all of them. 
+ var sorter: UnsafeKVExternalSorter = null + try { + sorter = map.destructAndCreateExternalSorter() + map.free() + } finally { + if (sorter != null) { + sorter.cleanupResources() + } + } + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 0282b25b9dd5..29027a664b4b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -20,18 +20,17 @@ package org.apache.spark.sql.execution import scala.util.Random import org.apache.spark._ -import org.apache.spark.sql.RandomDataGenerator +import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} +import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, RowOrdering, UnsafeProjection} -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeRow, UnsafeProjection} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} /** * Test suite for [[UnsafeKVExternalSorter]], with randomly generated test data. */ -class UnsafeKVExternalSorterSuite extends SparkFunSuite { - +class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { private val keyTypes = Seq(IntegerType, FloatType, DoubleType, StringType) private val valueTypes = Seq(IntegerType, FloatType, DoubleType, StringType) @@ -46,6 +45,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite { testKVSorter(keySchema, valueSchema, spill = i > 3) } + /** * Create a test case using randomly generated data for the given key and value schema. * @@ -60,95 +60,144 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite { * If spill is set to true, the sorter will spill probabilistically roughly every 100 records. */ private def testKVSorter(keySchema: StructType, valueSchema: StructType, spill: Boolean): Unit = { + // Create the data converters + val kExternalConverter = CatalystTypeConverters.createToCatalystConverter(keySchema) + val vExternalConverter = CatalystTypeConverters.createToCatalystConverter(valueSchema) + val kConverter = UnsafeProjection.create(keySchema) + val vConverter = UnsafeProjection.create(valueSchema) + + val keyDataGen = RandomDataGenerator.forType(keySchema, nullable = false).get + val valueDataGen = RandomDataGenerator.forType(valueSchema, nullable = false).get + + val inputData = Seq.fill(1024) { + val k = kConverter(kExternalConverter.apply(keyDataGen.apply()).asInstanceOf[InternalRow]) + val v = vConverter(vExternalConverter.apply(valueDataGen.apply()).asInstanceOf[InternalRow]) + (k.asInstanceOf[InternalRow].copy(), v.asInstanceOf[InternalRow].copy()) + } val keySchemaStr = keySchema.map(_.dataType.simpleString).mkString("[", ",", "]") val valueSchemaStr = valueSchema.map(_.dataType.simpleString).mkString("[", ",", "]") test(s"kv sorting key schema $keySchemaStr and value schema $valueSchemaStr") { - // Calling this make sure we have block manager and everything else setup. 
- TestSQLContext - - val taskMemMgr = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) - val shuffleMemMgr = new TestShuffleMemoryManager - TaskContext.setTaskContext(new TaskContextImpl( - stageId = 0, - partitionId = 0, - taskAttemptId = 98456, - attemptNumber = 0, - taskMemoryManager = taskMemMgr, - metricsSystem = null)) - - // Create the data converters - val kExternalConverter = CatalystTypeConverters.createToCatalystConverter(keySchema) - val vExternalConverter = CatalystTypeConverters.createToCatalystConverter(valueSchema) - val kConverter = UnsafeProjection.create(keySchema) - val vConverter = UnsafeProjection.create(valueSchema) - - val keyDataGen = RandomDataGenerator.forType(keySchema, nullable = false).get - val valueDataGen = RandomDataGenerator.forType(valueSchema, nullable = false).get - - val input = Seq.fill(1024) { - val k = kConverter(kExternalConverter.apply(keyDataGen.apply()).asInstanceOf[InternalRow]) - val v = vConverter(vExternalConverter.apply(valueDataGen.apply()).asInstanceOf[InternalRow]) - (k.asInstanceOf[InternalRow].copy(), v.asInstanceOf[InternalRow].copy()) - } - - val sorter = new UnsafeKVExternalSorter( - keySchema, valueSchema, SparkEnv.get.blockManager, shuffleMemMgr, 16 * 1024 * 1024) - - // Insert generated keys and values into the sorter - input.foreach { case (k, v) => - sorter.insertKV(k.asInstanceOf[UnsafeRow], v.asInstanceOf[UnsafeRow]) - // 1% chance we will spill - if (rand.nextDouble() < 0.01 && spill) { - shuffleMemMgr.markAsOutOfMemory() - sorter.closeCurrentPage() - } - } + testKVSorter( + keySchema, + valueSchema, + inputData, + pageSize = 16 * 1024 * 1024, + spill + ) + } + } - // Collect the sorted output - val out = new scala.collection.mutable.ArrayBuffer[(InternalRow, InternalRow)] - val iter = sorter.sortedIterator() - while (iter.next()) { - out += Tuple2(iter.getKey.copy(), iter.getValue.copy()) + /** + * Create a test case using the given input data for the given key and value schema. + * + * The approach works as follows: + * + * - Create input by randomly generating data based on the given schema + * - Run [[UnsafeKVExternalSorter]] on the input data + * - Collect the output from the sorter, and make sure the keys are sorted in ascending order + * - Sort the input by both key and value, and sort the sorter output also by both key and value. + * Compare the sorted input and sorted output together to make sure all the key/values match. + * + * If spill is set to true, the sorter will spill probabilistically roughly every 100 records. 
+ */ + private def testKVSorter( + keySchema: StructType, + valueSchema: StructType, + inputData: Seq[(InternalRow, InternalRow)], + pageSize: Long, + spill: Boolean): Unit = { + val memoryManager = + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")) + val taskMemMgr = new TaskMemoryManager(memoryManager, 0) + TaskContext.setTaskContext(new TaskContextImpl( + stageId = 0, + partitionId = 0, + taskAttemptId = 98456, + attemptNumber = 0, + taskMemoryManager = taskMemMgr, + metricsSystem = null, + internalAccumulators = Seq.empty)) + + val sorter = new UnsafeKVExternalSorter( + keySchema, valueSchema, SparkEnv.get.blockManager, pageSize) + + // Insert the keys and values into the sorter + inputData.foreach { case (k, v) => + sorter.insertKV(k.asInstanceOf[UnsafeRow], v.asInstanceOf[UnsafeRow]) + // 1% chance we will spill + if (rand.nextDouble() < 0.01 && spill) { + memoryManager.markExecutionAsOutOfMemoryOnce() + sorter.closeCurrentPage() } + } - val keyOrdering = RowOrdering.forSchema(keySchema.map(_.dataType)) - val valueOrdering = RowOrdering.forSchema(valueSchema.map(_.dataType)) - val kvOrdering = new Ordering[(InternalRow, InternalRow)] { - override def compare(x: (InternalRow, InternalRow), y: (InternalRow, InternalRow)): Int = { - keyOrdering.compare(x._1, y._1) match { - case 0 => valueOrdering.compare(x._2, y._2) - case cmp => cmp - } + // Collect the sorted output + val out = new scala.collection.mutable.ArrayBuffer[(InternalRow, InternalRow)] + val iter = sorter.sortedIterator() + while (iter.next()) { + out += Tuple2(iter.getKey.copy(), iter.getValue.copy()) + } + sorter.cleanupResources() + + val keyOrdering = InterpretedOrdering.forSchema(keySchema.map(_.dataType)) + val valueOrdering = InterpretedOrdering.forSchema(valueSchema.map(_.dataType)) + val kvOrdering = new Ordering[(InternalRow, InternalRow)] { + override def compare(x: (InternalRow, InternalRow), y: (InternalRow, InternalRow)): Int = { + keyOrdering.compare(x._1, y._1) match { + case 0 => valueOrdering.compare(x._2, y._2) + case cmp => cmp } } + } - // Testing to make sure output from the sorter is sorted by key - var prevK: InternalRow = null - out.zipWithIndex.foreach { case ((k, v), i) => - if (prevK != null) { - assert(keyOrdering.compare(prevK, k) <= 0, - s""" - |key is not in sorted order: - |previous key: $prevK - |current key : $k - """.stripMargin) - } - prevK = k + // Testing to make sure output from the sorter is sorted by key + var prevK: InternalRow = null + out.zipWithIndex.foreach { case ((k, v), i) => + if (prevK != null) { + assert(keyOrdering.compare(prevK, k) <= 0, + s""" + |key is not in sorted order: + |previous key: $prevK + |current key : $k + """.stripMargin) } + prevK = k + } - // Testing to make sure the key/value in output matches input - assert(out.sorted(kvOrdering) === input.sorted(kvOrdering)) + // Testing to make sure the key/value in output matches input + assert(out.sorted(kvOrdering) === inputData.sorted(kvOrdering)) - // Make sure there is no memory leak - val leakedUnsafeMemory: Long = taskMemMgr.cleanUpAllAllocatedMemory - if (shuffleMemMgr != null) { - val leakedShuffleMemory: Long = shuffleMemMgr.getMemoryConsumptionForThisTask() - assert(0L === leakedShuffleMemory) - } - assert(0 === leakedUnsafeMemory) - TaskContext.unset() + // Make sure there is no memory leak + assert(0 === taskMemMgr.cleanUpAllAllocatedMemory) + TaskContext.unset() + } + + test("kv sorting with records that exceed page size") { + val pageSize = 128 + + val schema = 
StructType(StructField("b", BinaryType) :: Nil) + val externalConverter = CatalystTypeConverters.createToCatalystConverter(schema) + val converter = UnsafeProjection.create(schema) + + val rand = new Random() + val inputData = Seq.fill(1024) { + val kBytes = new Array[Byte](rand.nextInt(pageSize)) + val vBytes = new Array[Byte](rand.nextInt(pageSize)) + rand.nextBytes(kBytes) + rand.nextBytes(vBytes) + val k = converter(externalConverter.apply(Row(kBytes)).asInstanceOf[InternalRow]) + val v = converter(externalConverter.apply(Row(vBytes)).asInstanceOf[InternalRow]) + (k.asInstanceOf[InternalRow].copy(), v.asInstanceOf[InternalRow].copy()) } + + testKVSorter( + schema, + schema, + inputData, + pageSize, + spill = true + ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index 40b47ae18d64..09e258299de5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -17,20 +17,44 @@ package org.apache.spark.sql.execution -import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.io.{File, ByteArrayInputStream, ByteArrayOutputStream} -import org.apache.spark.SparkFunSuite +import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.memory.TaskMemoryManager +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.ShuffleBlockId +import org.apache.spark.util.collection.ExternalSorter +import org.apache.spark.util.Utils import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} import org.apache.spark.sql.types._ +import org.apache.spark._ -class UnsafeRowSerializerSuite extends SparkFunSuite { + +/** + * used to test close InputStream in UnsafeRowSerializer + */ +class ClosableByteArrayInputStream(buf: Array[Byte]) extends ByteArrayInputStream(buf) { + var closed: Boolean = false + override def close(): Unit = { + closed = true + super.close() + } +} + +class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { private def toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = { - val internalRow = CatalystTypeConverters.convertToCatalyst(row).asInstanceOf[InternalRow] + val converter = unsafeRowConverter(schema) + converter(row) + } + + private def unsafeRowConverter(schema: Array[DataType]): Row => UnsafeRow = { val converter = UnsafeProjection.create(schema) - converter.apply(internalRow) + (row: Row) => { + converter(CatalystTypeConverters.convertToCatalyst(row).asInstanceOf[InternalRow]) + } } test("toUnsafeRow() test helper method") { @@ -52,8 +76,8 @@ class UnsafeRowSerializerSuite extends SparkFunSuite { serializerStream.writeValue(unsafeRow) } serializerStream.close() - val deserializerIter = serializer.deserializeStream( - new ByteArrayInputStream(baos.toByteArray)).asKeyValueIterator + val input = new ClosableByteArrayInputStream(baos.toByteArray) + val deserializerIter = serializer.deserializeStream(input).asKeyValueIterator for (expectedRow <- unsafeRows) { val actualRow = deserializerIter.next().asInstanceOf[(Integer, UnsafeRow)]._2 assert(expectedRow.getSizeInBytes === actualRow.getSizeInBytes) @@ -61,5 +85,79 @@ class UnsafeRowSerializerSuite extends SparkFunSuite { assert(expectedRow.getInt(1) === actualRow.getInt(1)) } 
assert(!deserializerIter.hasNext) + assert(input.closed) + } + + test("close empty input stream") { + val input = new ClosableByteArrayInputStream(Array.empty) + val serializer = new UnsafeRowSerializer(numFields = 2).newInstance() + val deserializerIter = serializer.deserializeStream(input).asKeyValueIterator + assert(!deserializerIter.hasNext) + assert(input.closed) + } + + test("SPARK-10466: external sorter spilling with unsafe row serializer") { + var sc: SparkContext = null + var outputFile: File = null + val oldEnv = SparkEnv.get // save the old SparkEnv, as it will be overwritten + Utils.tryWithSafeFinally { + val conf = new SparkConf() + .set("spark.shuffle.spill.initialMemoryThreshold", "1") + .set("spark.shuffle.sort.bypassMergeThreshold", "0") + .set("spark.testing.memory", "80000") + + sc = new SparkContext("local", "test", conf) + outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "") + // prepare data + val converter = unsafeRowConverter(Array(IntegerType)) + val data = (1 to 10000).iterator.map { i => + (i, converter(Row(i))) + } + val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0) + val taskContext = new TaskContextImpl( + 0, 0, 0, 0, taskMemoryManager, null, InternalAccumulator.create(sc)) + + val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow]( + taskContext, + partitioner = Some(new HashPartitioner(10)), + serializer = Some(new UnsafeRowSerializer(numFields = 1))) + + // Ensure we spilled something and have to merge them later + assert(sorter.numSpills === 0) + sorter.insertAll(data) + assert(sorter.numSpills > 0) + + // Merging spilled files should not throw assertion error + taskContext.taskMetrics.shuffleWriteMetrics = Some(new ShuffleWriteMetrics) + sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), outputFile) + } { + // Clean up + if (sc != null) { + sc.stop() + } + + // restore the spark env + SparkEnv.set(oldEnv) + + if (outputFile != null) { + outputFile.delete() + } + } + } + + test("SPARK-10403: unsafe row serializer with SortShuffleManager") { + val conf = new SparkConf().set("spark.shuffle.manager", "sort") + sc = new SparkContext("local", "test", conf) + val row = Row("Hello", 123) + val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType)) + val rowsRDD = sc.parallelize(Seq((0, unsafeRow), (1, unsafeRow), (0, unsafeRow))) + .asInstanceOf[RDD[Product2[Int, InternalRow]]] + val dependency = + new ShuffleDependency[Int, InternalRow, InternalRow]( + rowsRDD, + new PartitionIdPassthrough(2), + Some(new UnsafeRowSerializer(2))) + val shuffled = new ShuffledRowRDD(dependency) + shuffled.count() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala similarity index 58% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala index 66014ddca059..b2d04f7c5a6e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala @@ -15,43 +15,42 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.types._ class ColumnStatsSuite extends SparkFunSuite { - testColumnStats(classOf[BooleanColumnStats], BOOLEAN, InternalRow(true, false, 0)) - testColumnStats(classOf[ByteColumnStats], BYTE, InternalRow(Byte.MaxValue, Byte.MinValue, 0)) - testColumnStats(classOf[ShortColumnStats], SHORT, InternalRow(Short.MaxValue, Short.MinValue, 0)) - testColumnStats(classOf[IntColumnStats], INT, InternalRow(Int.MaxValue, Int.MinValue, 0)) - testColumnStats(classOf[DateColumnStats], DATE, InternalRow(Int.MaxValue, Int.MinValue, 0)) - testColumnStats(classOf[LongColumnStats], LONG, InternalRow(Long.MaxValue, Long.MinValue, 0)) - testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, - InternalRow(Long.MaxValue, Long.MinValue, 0)) - testColumnStats(classOf[FloatColumnStats], FLOAT, InternalRow(Float.MaxValue, Float.MinValue, 0)) + testColumnStats(classOf[BooleanColumnStats], BOOLEAN, createRow(true, false, 0)) + testColumnStats(classOf[ByteColumnStats], BYTE, createRow(Byte.MaxValue, Byte.MinValue, 0)) + testColumnStats(classOf[ShortColumnStats], SHORT, createRow(Short.MaxValue, Short.MinValue, 0)) + testColumnStats(classOf[IntColumnStats], INT, createRow(Int.MaxValue, Int.MinValue, 0)) + testColumnStats(classOf[LongColumnStats], LONG, createRow(Long.MaxValue, Long.MinValue, 0)) + testColumnStats(classOf[FloatColumnStats], FLOAT, createRow(Float.MaxValue, Float.MinValue, 0)) testColumnStats(classOf[DoubleColumnStats], DOUBLE, - InternalRow(Double.MaxValue, Double.MinValue, 0)) - testColumnStats(classOf[StringColumnStats], STRING, InternalRow(null, null, 0)) - testDecimalColumnStats(InternalRow(null, null, 0)) + createRow(Double.MaxValue, Double.MinValue, 0)) + testColumnStats(classOf[StringColumnStats], STRING, createRow(null, null, 0)) + testDecimalColumnStats(createRow(null, null, 0)) + + def createRow(values: Any*): GenericInternalRow = new GenericInternalRow(values.toArray) def testColumnStats[T <: AtomicType, U <: ColumnStats]( columnStatsClass: Class[U], columnType: NativeColumnType[T], - initialStatistics: InternalRow): Unit = { + initialStatistics: GenericInternalRow): Unit = { val columnStatsName = columnStatsClass.getSimpleName test(s"$columnStatsName: empty") { val columnStats = columnStatsClass.newInstance() - columnStats.collectedStatistics.toSeq.zip(initialStatistics.toSeq).foreach { + columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach { case (actual, expected) => assert(actual === expected) } } test(s"$columnStatsName: non-empty") { - import org.apache.spark.sql.columnar.ColumnarTestUtils._ + import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ val columnStats = columnStatsClass.newInstance() val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) @@ -61,11 +60,11 @@ class ColumnStatsSuite extends SparkFunSuite { val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] val stats = columnStats.collectedStatistics - assertResult(values.min(ordering), "Wrong lower bound")(stats.genericGet(0)) - assertResult(values.max(ordering), "Wrong upper bound")(stats.genericGet(1)) - assertResult(10, "Wrong null count")(stats.genericGet(2)) - assertResult(20, "Wrong row count")(stats.genericGet(3)) - assertResult(stats.genericGet(4), "Wrong 
size in bytes") { + assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0)) + assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1)) + assertResult(10, "Wrong null count")(stats.values(2)) + assertResult(20, "Wrong row count")(stats.values(3)) + assertResult(stats.values(4), "Wrong size in bytes") { rows.map { row => if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) }.sum @@ -73,22 +72,23 @@ class ColumnStatsSuite extends SparkFunSuite { } } - def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats](initialStatistics: InternalRow) { + def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats]( + initialStatistics: GenericInternalRow): Unit = { - val columnStatsName = classOf[FixedDecimalColumnStats].getSimpleName - val columnType = FIXED_DECIMAL(15, 10) + val columnStatsName = classOf[DecimalColumnStats].getSimpleName + val columnType = COMPACT_DECIMAL(15, 10) test(s"$columnStatsName: empty") { - val columnStats = new FixedDecimalColumnStats(15, 10) - columnStats.collectedStatistics.toSeq.zip(initialStatistics.toSeq).foreach { + val columnStats = new DecimalColumnStats(15, 10) + columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach { case (actual, expected) => assert(actual === expected) } } test(s"$columnStatsName: non-empty") { - import org.apache.spark.sql.columnar.ColumnarTestUtils._ + import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ - val columnStats = new FixedDecimalColumnStats(15, 10) + val columnStats = new DecimalColumnStats(15, 10) val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) rows.foreach(columnStats.gatherStats(_, 0)) @@ -96,11 +96,11 @@ class ColumnStatsSuite extends SparkFunSuite { val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] val stats = columnStats.collectedStatistics - assertResult(values.min(ordering), "Wrong lower bound")(stats.genericGet(0)) - assertResult(values.max(ordering), "Wrong upper bound")(stats.genericGet(1)) - assertResult(10, "Wrong null count")(stats.genericGet(2)) - assertResult(20, "Wrong row count")(stats.genericGet(3)) - assertResult(stats.genericGet(4), "Wrong size in bytes") { + assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0)) + assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1)) + assertResult(10, "Wrong null count")(stats.values(2)) + assertResult(20, "Wrong row count")(stats.values(3)) + assertResult(stats.values(4), "Wrong size in bytes") { rows.map { row => if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) }.sum diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala new file mode 100644 index 000000000000..706ff1f99850 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar + +import java.nio.{ByteOrder, ByteBuffer} + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, GenericMutableRow} +import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.types._ +import org.apache.spark.{Logging, SparkFunSuite} + + +class ColumnTypeSuite extends SparkFunSuite with Logging { + private val DEFAULT_BUFFER_SIZE = 512 + private val MAP_TYPE = MAP(MapType(IntegerType, StringType)) + private val ARRAY_TYPE = ARRAY(ArrayType(IntegerType)) + private val STRUCT_TYPE = STRUCT(StructType(StructField("a", StringType) :: Nil)) + + test("defaultSize") { + val checks = Map( + NULL -> 0, BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4, LONG -> 8, + FLOAT -> 4, DOUBLE -> 8, COMPACT_DECIMAL(15, 10) -> 8, LARGE_DECIMAL(20, 10) -> 12, + STRING -> 8, BINARY -> 16, STRUCT_TYPE -> 20, ARRAY_TYPE -> 16, MAP_TYPE -> 32) + + checks.foreach { case (columnType, expectedSize) => + assertResult(expectedSize, s"Wrong defaultSize for $columnType") { + columnType.defaultSize + } + } + } + + test("actualSize") { + def checkActualSize( + columnType: ColumnType[_], + value: Any, + expected: Int): Unit = { + + assertResult(expected, s"Wrong actualSize for $columnType") { + val row = new GenericMutableRow(1) + row.update(0, CatalystTypeConverters.convertToCatalyst(value)) + val proj = UnsafeProjection.create(Array[DataType](columnType.dataType)) + columnType.actualSize(proj(row), 0) + } + } + + checkActualSize(NULL, null, 0) + checkActualSize(BOOLEAN, true, 1) + checkActualSize(BYTE, Byte.MaxValue, 1) + checkActualSize(SHORT, Short.MaxValue, 2) + checkActualSize(INT, Int.MaxValue, 4) + checkActualSize(LONG, Long.MaxValue, 8) + checkActualSize(FLOAT, Float.MaxValue, 4) + checkActualSize(DOUBLE, Double.MaxValue, 8) + checkActualSize(STRING, "hello", 4 + "hello".getBytes("utf-8").length) + checkActualSize(BINARY, Array.fill[Byte](4)(0.toByte), 4 + 4) + checkActualSize(COMPACT_DECIMAL(15, 10), Decimal(0, 15, 10), 8) + checkActualSize(LARGE_DECIMAL(20, 10), Decimal(0, 20, 10), 5) + checkActualSize(ARRAY_TYPE, Array[Any](1), 16) + checkActualSize(MAP_TYPE, Map(1 -> "a"), 29) + checkActualSize(STRUCT_TYPE, Row("hello"), 28) + } + + testNativeColumnType(BOOLEAN) + testNativeColumnType(BYTE) + testNativeColumnType(SHORT) + testNativeColumnType(INT) + testNativeColumnType(LONG) + testNativeColumnType(FLOAT) + testNativeColumnType(DOUBLE) + testNativeColumnType(COMPACT_DECIMAL(15, 10)) + testNativeColumnType(STRING) + + testColumnType(NULL) + testColumnType(BINARY) + testColumnType(LARGE_DECIMAL(20, 10)) + testColumnType(STRUCT_TYPE) + testColumnType(ARRAY_TYPE) + testColumnType(MAP_TYPE) + + def testNativeColumnType[T <: AtomicType](columnType: NativeColumnType[T]): Unit = { + testColumnType[T#InternalType](columnType) + } + + def testColumnType[JvmType](columnType: ColumnType[JvmType]): Unit = { + + val buffer = ByteBuffer.allocate(DEFAULT_BUFFER_SIZE).order(ByteOrder.nativeOrder()) + val proj = 
UnsafeProjection.create(Array[DataType](columnType.dataType)) + val converter = CatalystTypeConverters.createToScalaConverter(columnType.dataType) + val seq = (0 until 4).map(_ => proj(makeRandomRow(columnType)).copy()) + + test(s"$columnType append/extract") { + buffer.rewind() + seq.foreach(columnType.append(_, 0, buffer)) + + buffer.rewind() + seq.foreach { row => + logInfo("buffer = " + buffer + ", expected = " + row) + val expected = converter(row.get(0, columnType.dataType)) + val extracted = converter(columnType.extract(buffer)) + assert(expected === extracted, + s"Extracted value didn't equal to the original one. $expected != $extracted, buffer =" + + dumpBuffer(buffer.duplicate().rewind().asInstanceOf[ByteBuffer])) + } + } + } + + private def dumpBuffer(buff: ByteBuffer): Any = { + val sb = new StringBuilder() + while (buff.hasRemaining) { + val b = buff.get() + sb.append(Integer.toHexString(b & 0xff)).append(' ') + } + if (sb.nonEmpty) sb.setLength(sb.length - 1) + sb.toString() + } + + test("column type for decimal types with different precision") { + (1 to 18).foreach { i => + assertResult(COMPACT_DECIMAL(i, 0)) { + ColumnType(DecimalType(i, 0)) + } + } + + assertResult(LARGE_DECIMAL(19, 0)) { + ColumnType(DecimalType(19, 0)) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala similarity index 79% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala index 79bb7d072feb..9cae65ef6f5d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala @@ -15,13 +15,15 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import scala.collection.immutable.HashSet import scala.util.Random + import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.types.{DataType, Decimal, AtomicType} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, GenericMutableRow} +import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData} +import org.apache.spark.sql.types.{AtomicType, Decimal} import org.apache.spark.unsafe.types.UTF8String object ColumnarTestUtils { @@ -39,21 +41,25 @@ object ColumnarTestUtils { } (columnType match { + case NULL => null case BOOLEAN => Random.nextBoolean() case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort case INT => Random.nextInt() - case DATE => Random.nextInt() case LONG => Random.nextLong() - case TIMESTAMP => Random.nextLong() case FLOAT => Random.nextFloat() case DOUBLE => Random.nextDouble() case STRING => UTF8String.fromString(Random.nextString(Random.nextInt(32))) case BINARY => randomBytes(Random.nextInt(32)) - case FIXED_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale) - case _ => - // Using a random one-element map instead of an arbitrary object - Map(Random.nextInt() -> Random.nextString(Random.nextInt(32))) + case COMPACT_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale) + case LARGE_DECIMAL(precision, scale) => Decimal(Random.nextLong(), precision, scale) + case STRUCT(_) => + new GenericInternalRow(Array[Any](UTF8String.fromString(Random.nextString(10)))) + case ARRAY(_) => + new GenericArrayData(Array[Any](Random.nextInt(), Random.nextInt())) + case MAP(_) => + ArrayBasedMapData( + Map(Random.nextInt() -> UTF8String.fromString(Random.nextString(Random.nextInt(32))))) }).asInstanceOf[JvmType] } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala similarity index 67% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 037e2048a863..25afed25c897 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -15,25 +15,23 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import java.sql.{Date, Timestamp} -import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestData._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.{QueryTest, Row, TestData} import org.apache.spark.storage.StorageLevel.MEMORY_ONLY -class InMemoryColumnarQuerySuite extends QueryTest { - // Make sure the tables are loaded. 
- TestData +class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { + import testImplicits._ - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - import ctx.{logicalPlanToSparkQuery, sql} + setupTestData() test("simple columnar query") { - val plan = ctx.executePlan(testData.logicalPlan).executedPlan + val plan = sqlContext.executePlan(testData.logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().toSeq) @@ -41,16 +39,16 @@ class InMemoryColumnarQuerySuite extends QueryTest { test("default size avoids broadcast") { // TODO: Improve this test when we have better statistics - ctx.sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)) + sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)) .toDF().registerTempTable("sizeTst") - ctx.cacheTable("sizeTst") + sqlContext.cacheTable("sizeTst") assert( - ctx.table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes > - ctx.conf.autoBroadcastJoinThreshold) + sqlContext.table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes > + sqlContext.conf.autoBroadcastJoinThreshold) } test("projection") { - val plan = ctx.executePlan(testData.select('value, 'key).logicalPlan).executedPlan + val plan = sqlContext.executePlan(testData.select('value, 'key).logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().map { @@ -59,7 +57,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { } test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") { - val plan = ctx.executePlan(testData.logicalPlan).executedPlan + val plan = sqlContext.executePlan(testData.logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().toSeq) @@ -71,7 +69,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT * FROM repeatedData"), repeatedData.collect().toSeq.map(Row.fromTuple)) - ctx.cacheTable("repeatedData") + sqlContext.cacheTable("repeatedData") checkAnswer( sql("SELECT * FROM repeatedData"), @@ -83,7 +81,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT * FROM nullableRepeatedData"), nullableRepeatedData.collect().toSeq.map(Row.fromTuple)) - ctx.cacheTable("nullableRepeatedData") + sqlContext.cacheTable("nullableRepeatedData") checkAnswer( sql("SELECT * FROM nullableRepeatedData"), @@ -98,7 +96,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT time FROM timestamps"), timestamps.collect().toSeq) - ctx.cacheTable("timestamps") + sqlContext.cacheTable("timestamps") checkAnswer( sql("SELECT time FROM timestamps"), @@ -110,7 +108,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT * FROM withEmptyParts"), withEmptyParts.collect().toSeq.map(Row.fromTuple)) - ctx.cacheTable("withEmptyParts") + sqlContext.cacheTable("withEmptyParts") checkAnswer( sql("SELECT * FROM withEmptyParts"), @@ -148,7 +146,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { val dataTypes = Seq(StringType, BinaryType, NullType, BooleanType, ByteType, ShortType, IntegerType, LongType, - FloatType, DoubleType, DecimalType.SYSTEM_DEFAULT, DecimalType(6, 5), + FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), DateType, TimestampType, ArrayType(IntegerType), MapType(StringType, LongType), struct) val fields = dataTypes.zipWithIndex.map { case 
(dataType, index) => @@ -159,7 +157,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { // Create a RDD for the schema val rdd = - ctx.sparkContext.parallelize((1 to 100), 10).map { i => + sparkContext.parallelize((1 to 10000), 10).map { i => Row( s"str${i}: test cache.", s"binary${i}: test cache.".getBytes("UTF-8"), @@ -174,23 +172,51 @@ class InMemoryColumnarQuerySuite extends QueryTest { BigDecimal(Long.MaxValue.toString + ".12345"), new java.math.BigDecimal(s"${i % 9 + 1}" + ".23456"), new Date(i), - new Timestamp(i), - (1 to i).toSeq, - (0 to i).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap, + new Timestamp(i * 1000000L), + (i to i + 10).toSeq, + (i to i + 10).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap, Row((i - 0.25).toFloat, Seq(true, false, null))) } - ctx.createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types") + sqlContext.createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types") // Cache the table. sql("cache table InMemoryCache_different_data_types") // Make sure the table is indeed cached. - val tableScan = ctx.table("InMemoryCache_different_data_types").queryExecution.executedPlan + sqlContext.table("InMemoryCache_different_data_types").queryExecution.executedPlan assert( - ctx.isCached("InMemoryCache_different_data_types"), + sqlContext.isCached("InMemoryCache_different_data_types"), "InMemoryCache_different_data_types should be cached.") // Issue a query and check the results. checkAnswer( sql(s"SELECT DISTINCT ${allColumns} FROM InMemoryCache_different_data_types"), - ctx.table("InMemoryCache_different_data_types").collect()) - ctx.dropTempTable("InMemoryCache_different_data_types") + sqlContext.table("InMemoryCache_different_data_types").collect()) + sqlContext.dropTempTable("InMemoryCache_different_data_types") + } + + test("SPARK-10422: String column in InMemoryColumnarCache needs to override clone method") { + val df = sqlContext.range(1, 100).selectExpr("id % 10 as id") + .rdd.map(id => Tuple1(s"str_$id")).toDF("i") + val cached = df.cache() + // count triggers the caching action. It should not throw. + cached.count() + + // Make sure, the DataFrame is indeed cached. + assert(sqlContext.cacheManager.lookupCachedData(cached).nonEmpty) + + // Check result. + checkAnswer( + cached, + sqlContext.range(1, 100).selectExpr("id % 10 as id") + .rdd.map(id => Tuple1(s"str_$id")).toDF("i") + ) + + // Drop the cache. + cached.unpersist() + } + + test("SPARK-10859: Predicates pushed to InMemoryColumnarTableScan are not evaluated correctly") { + val data = sqlContext.range(10).selectExpr("id", "cast(id as string) as s") + data.cache() + assert(data.count() === 10) + assert(data.filter($"s" === "3").count() === 1) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala similarity index 71% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala index f4f6c7649bfa..35dc9a276cef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala @@ -15,13 +15,14 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import java.nio.ByteBuffer import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.types.{StringType, ArrayType, DataType} +import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, GenericMutableRow} +import org.apache.spark.sql.types._ class TestNullableColumnAccessor[JvmType]( buffer: ByteBuffer, @@ -32,18 +33,18 @@ class TestNullableColumnAccessor[JvmType]( object TestNullableColumnAccessor { def apply[JvmType](buffer: ByteBuffer, columnType: ColumnType[JvmType]) : TestNullableColumnAccessor[JvmType] = { - // Skips the column type ID - buffer.getInt() new TestNullableColumnAccessor(buffer, columnType) } } class NullableColumnAccessorSuite extends SparkFunSuite { - import ColumnarTestUtils._ + import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ Seq( - BOOLEAN, BYTE, SHORT, INT, DATE, LONG, TIMESTAMP, FLOAT, DOUBLE, - STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC(ArrayType(StringType))) + NULL, BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE, + STRING, BINARY, COMPACT_DECIMAL(15, 10), LARGE_DECIMAL(20, 10), + STRUCT(StructType(StructField("a", StringType) :: Nil)), + ARRAY(ArrayType(IntegerType)), MAP(MapType(IntegerType, StringType))) .foreach { testNullableColumnAccessor(_) } @@ -63,19 +64,22 @@ class NullableColumnAccessorSuite extends SparkFunSuite { test(s"Nullable $typeName column accessor: access null values") { val builder = TestNullableColumnBuilder(columnType) val randomRow = makeRandomRow(columnType) + val proj = UnsafeProjection.create(Array[DataType](columnType.dataType)) (0 until 4).foreach { _ => - builder.appendFrom(randomRow, 0) - builder.appendFrom(nullRow, 0) + builder.appendFrom(proj(randomRow), 0) + builder.appendFrom(proj(nullRow), 0) } val accessor = TestNullableColumnAccessor(builder.build(), columnType) val row = new GenericMutableRow(1) + val converter = CatalystTypeConverters.createToScalaConverter(columnType.dataType) (0 until 4).foreach { _ => assert(accessor.hasNext) accessor.extractTo(row, 0) - assert(row.get(0, columnType.dataType) === randomRow.get(0, columnType.dataType)) + assert(converter(row.get(0, columnType.dataType)) + === converter(randomRow.get(0, columnType.dataType))) assert(accessor.hasNext) accessor.extractTo(row, 0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala similarity index 72% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala index 241d09ea205e..93be3e16a5ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala @@ -15,10 +15,11 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.execution.SparkSqlSerializer +import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, GenericMutableRow} import org.apache.spark.sql.types._ class TestNullableColumnBuilder[JvmType](columnType: ColumnType[JvmType]) @@ -35,11 +36,13 @@ object TestNullableColumnBuilder { } class NullableColumnBuilderSuite extends SparkFunSuite { - import ColumnarTestUtils._ + import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ Seq( - BOOLEAN, BYTE, SHORT, INT, DATE, LONG, TIMESTAMP, FLOAT, DOUBLE, - STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC(ArrayType(StringType))) + BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE, + STRING, BINARY, COMPACT_DECIMAL(15, 10), LARGE_DECIMAL(20, 10), + STRUCT(StructType(StructField("a", StringType) :: Nil)), + ARRAY(ArrayType(IntegerType)), MAP(MapType(IntegerType, StringType))) .foreach { testNullableColumnBuilder(_) } @@ -48,12 +51,14 @@ class NullableColumnBuilderSuite extends SparkFunSuite { columnType: ColumnType[JvmType]): Unit = { val typeName = columnType.getClass.getSimpleName.stripSuffix("$") + val dataType = columnType.dataType + val proj = UnsafeProjection.create(Array[DataType](dataType)) + val converter = CatalystTypeConverters.createToScalaConverter(dataType) test(s"$typeName column builder: empty column") { val columnBuilder = TestNullableColumnBuilder(columnType) val buffer = columnBuilder.build() - assertResult(columnType.typeId, "Wrong column type ID")(buffer.getInt()) assertResult(0, "Wrong null count")(buffer.getInt()) assert(!buffer.hasRemaining) } @@ -63,12 +68,11 @@ class NullableColumnBuilderSuite extends SparkFunSuite { val randomRow = makeRandomRow(columnType) (0 until 4).foreach { _ => - columnBuilder.appendFrom(randomRow, 0) + columnBuilder.appendFrom(proj(randomRow), 0) } val buffer = columnBuilder.build() - assertResult(columnType.typeId, "Wrong column type ID")(buffer.getInt()) assertResult(0, "Wrong null count")(buffer.getInt()) } @@ -78,27 +82,22 @@ class NullableColumnBuilderSuite extends SparkFunSuite { val nullRow = makeNullRow(1) (0 until 4).foreach { _ => - columnBuilder.appendFrom(randomRow, 0) - columnBuilder.appendFrom(nullRow, 0) + columnBuilder.appendFrom(proj(randomRow), 0) + columnBuilder.appendFrom(proj(nullRow), 0) } val buffer = columnBuilder.build() - assertResult(columnType.typeId, "Wrong column type ID")(buffer.getInt()) assertResult(4, "Wrong null count")(buffer.getInt()) // For null positions (1 to 7 by 2).foreach(assertResult(_, "Wrong null position")(buffer.getInt())) // For non-null values + val actual = new GenericMutableRow(new Array[Any](1)) (0 until 4).foreach { _ => - val actual = if (columnType.isInstanceOf[GENERIC]) { - SparkSqlSerializer.deserialize[Any](columnType.extract(buffer).asInstanceOf[Array[Byte]]) - } else { - columnType.extract(buffer) - } - - assert(actual === randomRow.get(0, columnType.dataType), + columnType.extract(buffer, actual, 0) + assert(converter(actual.get(0, dataType)) === converter(randomRow.get(0, dataType)), "Extracted value didn't equal to the original one") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala similarity index 80% rename from 
sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala index 2c0879927a12..d762f7bfe914 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala @@ -15,48 +15,45 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar - -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} +package org.apache.spark.sql.execution.columnar import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestData._ -class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class PartitionBatchPruningSuite extends SparkFunSuite with SharedSQLContext { + import testImplicits._ - private lazy val originalColumnBatchSize = ctx.conf.columnBatchSize - private lazy val originalInMemoryPartitionPruning = ctx.conf.inMemoryPartitionPruning + private lazy val originalColumnBatchSize = sqlContext.conf.columnBatchSize + private lazy val originalInMemoryPartitionPruning = sqlContext.conf.inMemoryPartitionPruning override protected def beforeAll(): Unit = { + super.beforeAll() // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch - ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, 10) + sqlContext.setConf(SQLConf.COLUMN_BATCH_SIZE, 10) - val pruningData = ctx.sparkContext.makeRDD((1 to 100).map { key => + val pruningData = sparkContext.makeRDD((1 to 100).map { key => val string = if (((key - 1) / 10) % 2 == 0) null else key.toString TestData(key, string) }, 5).toDF() pruningData.registerTempTable("pruningData") // Enable in-memory partition pruning - ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) + sqlContext.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) // Enable in-memory table scan accumulators - ctx.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true") + sqlContext.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true") + sqlContext.cacheTable("pruningData") } override protected def afterAll(): Unit = { - ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) - ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) - } - - before { - ctx.cacheTable("pruningData") - } - - after { - ctx.uncacheTable("pruningData") + try { + sqlContext.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) + sqlContext.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) + sqlContext.uncacheTable("pruningData") + } finally { + super.afterAll() + } } // Comparisons @@ -110,7 +107,7 @@ class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll wi expectedQueryResult: => Seq[Int]): Unit = { test(query) { - val df = ctx.sql(query) + val df = sql(query) val queryExecution = df.queryExecution assertResult(expectedQueryResult.toArray, s"Wrong query result: $queryExecution") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala similarity index 94% rename from 
sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala index 9a2948c59ba4..ccbddef0fad3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala @@ -15,13 +15,13 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.columnar.ColumnarTestUtils._ -import org.apache.spark.sql.columnar.{BOOLEAN, NoopColumnStats} +import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.columnar.{BOOLEAN, NoopColumnStats} class BooleanBitSetSuite extends SparkFunSuite { import BooleanBitSet._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala index acfab6586c0d..830ca0294e1b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala @@ -15,14 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import java.nio.ByteBuffer import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.columnar._ -import org.apache.spark.sql.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.columnar._ +import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.AtomicType class DictionaryEncodingSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala index 2111e9fbe62c..988a577a7b4d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala @@ -15,12 +15,12 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.columnar._ -import org.apache.spark.sql.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.columnar._ +import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.IntegralType class IntegralDeltaSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala similarity index 93% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala index 67ec08f594a4..95642e93ae9f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala @@ -15,12 +15,12 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.columnar._ -import org.apache.spark.sql.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.columnar._ +import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.AtomicType class RunLengthEncodingSuite extends SparkFunSuite { @@ -100,11 +100,11 @@ class RunLengthEncodingSuite extends SparkFunSuite { } test(s"$RunLengthEncoding with $typeName: simple case") { - skeleton(2, Seq(0 -> 2, 1 ->2)) + skeleton(2, Seq(0 -> 2, 1 -> 2)) } test(s"$RunLengthEncoding with $typeName: run length == 1") { - skeleton(2, Seq(0 -> 1, 1 ->1)) + skeleton(2, Seq(0 -> 1, 1 -> 1)) } test(s"$RunLengthEncoding with $typeName: single long run") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/TestCompressibleColumnBuilder.scala similarity index 93% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/TestCompressibleColumnBuilder.scala index 5268dfe0aa03..5e078f251375 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/TestCompressibleColumnBuilder.scala @@ -15,9 +15,9 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression -import org.apache.spark.sql.columnar._ +import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.types.AtomicType class TestCompressibleColumnBuilder[T <: AtomicType]( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala new file mode 100644 index 000000000000..4cc0a3a9585d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.json + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.test.SharedSQLContext + +/** + * Test cases for various [[JSONOptions]]. + */ +class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext { + + test("allowComments off") { + val str = """{'name': /* hello */ 'Reynold Xin'}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.json(rdd) + + assert(df.schema.head.name == "_corrupt_record") + } + + test("allowComments on") { + val str = """{'name': /* hello */ 'Reynold Xin'}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.option("allowComments", "true").json(rdd) + + assert(df.schema.head.name == "name") + assert(df.first().getString(0) == "Reynold Xin") + } + + test("allowSingleQuotes off") { + val str = """{'name': 'Reynold Xin'}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.option("allowSingleQuotes", "false").json(rdd) + + assert(df.schema.head.name == "_corrupt_record") + } + + test("allowSingleQuotes on") { + val str = """{'name': 'Reynold Xin'}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.json(rdd) + + assert(df.schema.head.name == "name") + assert(df.first().getString(0) == "Reynold Xin") + } + + test("allowUnquotedFieldNames off") { + val str = """{name: 'Reynold Xin'}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.json(rdd) + + assert(df.schema.head.name == "_corrupt_record") + } + + test("allowUnquotedFieldNames on") { + val str = """{name: 'Reynold Xin'}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.option("allowUnquotedFieldNames", "true").json(rdd) + + assert(df.schema.head.name == "name") + assert(df.first().getString(0) == "Reynold Xin") + } + + test("allowNumericLeadingZeros off") { + val str = """{"age": 0018}""" + val rdd = 
sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.json(rdd) + + assert(df.schema.head.name == "_corrupt_record") + } + + test("allowNumericLeadingZeros on") { + val str = """{"age": 0018}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.option("allowNumericLeadingZeros", "true").json(rdd) + + assert(df.schema.head.name == "age") + assert(df.first().getLong(0) == 18) + } + + // The following two tests are not really working - need to look into Jackson's + // JsonParser.Feature.ALLOW_NON_NUMERIC_NUMBERS. + ignore("allowNonNumericNumbers off") { + val str = """{"age": NaN}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.json(rdd) + + assert(df.schema.head.name == "_corrupt_record") + } + + ignore("allowNonNumericNumbers on") { + val str = """{"age": NaN}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.option("allowNonNumericNumbers", "true").json(rdd) + + assert(df.schema.head.name == "age") + assert(df.first().getDouble(0).isNaN) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala similarity index 61% rename from sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index f19f22fca7d5..baa258ad2615 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -15,27 +15,33 @@ * limitations under the License. */ -package org.apache.spark.sql.json +package org.apache.spark.sql.execution.datasources.json -import java.io.StringWriter +import java.io.{File, StringWriter} import java.sql.{Date, Timestamp} +import scala.collection.JavaConverters._ import com.fasterxml.jackson.core.JsonFactory +import org.apache.commons.io.FileUtils +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{Path, PathFilter} import org.scalactic.Tolerance._ -import org.apache.spark.sql.{QueryTest, Row, SQLConf} -import org.apache.spark.sql.TestData._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.json.InferSchema.compatibleType +import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} +import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class JsonSuite extends QueryTest with TestJsonData { +class TestFileFilter extends PathFilter { + override def accept(path: Path): Boolean = path.getParent.getName != "p=2" +} - protected lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.sql - import ctx.implicits._ +class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { + import testImplicits._ test("Type promotion") { def checkTypePromotion(expected: Any, actual: Any) { @@ -49,13 +55,15 @@ class JsonSuite extends QueryTest with TestJsonData { val factory = new JsonFactory() def enforceCorrectType(value: Any, dataType: DataType): Any = { val writer = new StringWriter() - val generator = factory.createGenerator(writer) - generator.writeObject(value) - 
generator.flush() + Utils.tryWithResource(factory.createGenerator(writer)) { generator => + generator.writeObject(value) + generator.flush() + } - val parser = factory.createParser(writer.toString) - parser.nextToken() - JacksonParser.convertField(factory, parser, dataType) + Utils.tryWithResource(factory.createParser(writer.toString)) { parser => + parser.nextToken() + JacksonParser.convertField(factory, parser, dataType) + } } val intNumber: Int = 2147483647 @@ -73,8 +81,6 @@ class JsonSuite extends QueryTest with TestJsonData { val doubleNumber: Double = 1.7976931348623157E308d checkTypePromotion(doubleNumber.toDouble, enforceCorrectType(doubleNumber, DoubleType)) - checkTypePromotion( - Decimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType.SYSTEM_DEFAULT)) checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(intNumber)), enforceCorrectType(intNumber, TimestampType)) @@ -150,7 +156,7 @@ class JsonSuite extends QueryTest with TestJsonData { // DoubleType checkDataType(DoubleType, DoubleType, DoubleType) - checkDataType(DoubleType, DecimalType.SYSTEM_DEFAULT, DecimalType.SYSTEM_DEFAULT) + checkDataType(DoubleType, DecimalType.SYSTEM_DEFAULT, DoubleType) checkDataType(DoubleType, StringType, StringType) checkDataType(DoubleType, ArrayType(IntegerType), StringType) checkDataType(DoubleType, StructType(Nil), StringType) @@ -219,7 +225,7 @@ class JsonSuite extends QueryTest with TestJsonData { } test("Complex field and type inferring with null in sampling") { - val jsonDF = ctx.read.json(jsonNullStruct) + val jsonDF = sqlContext.read.json(jsonNullStruct) val expectedSchema = StructType( StructField("headers", StructType( StructField("Charset", StringType, true) :: @@ -238,10 +244,10 @@ class JsonSuite extends QueryTest with TestJsonData { } test("Primitive field and type inferring") { - val jsonDF = ctx.read.json(primitiveFieldAndType) + val jsonDF = sqlContext.read.json(primitiveFieldAndType) val expectedSchema = StructType( - StructField("bigInteger", DecimalType.SYSTEM_DEFAULT, true) :: + StructField("bigInteger", DecimalType(20, 0), true) :: StructField("boolean", BooleanType, true) :: StructField("double", DoubleType, true) :: StructField("integer", LongType, true) :: @@ -266,12 +272,12 @@ class JsonSuite extends QueryTest with TestJsonData { } test("Complex field and type inferring") { - val jsonDF = ctx.read.json(complexFieldAndType1) + val jsonDF = sqlContext.read.json(complexFieldAndType1) val expectedSchema = StructType( StructField("arrayOfArray1", ArrayType(ArrayType(StringType, true), true), true) :: StructField("arrayOfArray2", ArrayType(ArrayType(DoubleType, true), true), true) :: - StructField("arrayOfBigInteger", ArrayType(DecimalType.SYSTEM_DEFAULT, true), true) :: + StructField("arrayOfBigInteger", ArrayType(DecimalType(21, 0), true), true) :: StructField("arrayOfBoolean", ArrayType(BooleanType, true), true) :: StructField("arrayOfDouble", ArrayType(DoubleType, true), true) :: StructField("arrayOfInteger", ArrayType(LongType, true), true) :: @@ -285,7 +291,7 @@ class JsonSuite extends QueryTest with TestJsonData { StructField("field3", StringType, true) :: Nil), true), true) :: StructField("struct", StructType( StructField("field1", BooleanType, true) :: - StructField("field2", DecimalType.SYSTEM_DEFAULT, true) :: Nil), true) :: + StructField("field2", DecimalType(20, 0), true) :: Nil), true) :: StructField("structWithArrayFields", StructType( StructField("field1", ArrayType(LongType, true), true) :: StructField("field2", ArrayType(StringType, 
true), true) :: Nil), true) :: Nil) @@ -365,7 +371,7 @@ class JsonSuite extends QueryTest with TestJsonData { } test("GetField operation on complex data type") { - val jsonDF = ctx.read.json(complexFieldAndType1) + val jsonDF = sqlContext.read.json(complexFieldAndType1) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -381,12 +387,12 @@ class JsonSuite extends QueryTest with TestJsonData { } test("Type conflict in primitive field values") { - val jsonDF = ctx.read.json(primitiveFieldValueTypeConflict) + val jsonDF = sqlContext.read.json(primitiveFieldValueTypeConflict) val expectedSchema = StructType( StructField("num_bool", StringType, true) :: StructField("num_num_1", LongType, true) :: - StructField("num_num_2", DecimalType.SYSTEM_DEFAULT, true) :: + StructField("num_num_2", DoubleType, true) :: StructField("num_num_3", DoubleType, true) :: StructField("num_str", StringType, true) :: StructField("str_bool", StringType, true) :: Nil) @@ -398,11 +404,9 @@ class JsonSuite extends QueryTest with TestJsonData { checkAnswer( sql("select * from jsonTable"), Row("true", 11L, null, 1.1, "13.1", "str1") :: - Row("12", null, new java.math.BigDecimal("21474836470.9"), null, null, "true") :: - Row("false", 21474836470L, - new java.math.BigDecimal("92233720368547758070"), 100, "str1", "false") :: - Row(null, 21474836570L, - new java.math.BigDecimal("1.1"), 21474836470L, "92233720368547758070", null) :: Nil + Row("12", null, 21474836470.9, null, null, "true") :: + Row("false", 21474836470L, 92233720368547758070d, 100, "str1", "false") :: + Row(null, 21474836570L, 1.1, 21474836470L, "92233720368547758070", null) :: Nil ) // Number and Boolean conflict: resolve the type as number in this query. @@ -425,8 +429,8 @@ class JsonSuite extends QueryTest with TestJsonData { // Widening to DecimalType checkAnswer( sql("select num_num_2 + 1.3 from jsonTable where num_num_2 > 1.1"), - Row(BigDecimal("21474836472.2")) :: - Row(BigDecimal("92233720368547758071.3")) :: Nil + Row(21474836472.2) :: + Row(92233720368547758071.3) :: Nil ) // Widening to Double @@ -455,7 +459,7 @@ class JsonSuite extends QueryTest with TestJsonData { } ignore("Type conflict in primitive field values (Ignored)") { - val jsonDF = ctx.read.json(primitiveFieldValueTypeConflict) + val jsonDF = sqlContext.read.json(primitiveFieldValueTypeConflict) jsonDF.registerTempTable("jsonTable") // Right now, the analyzer does not promote strings in a boolean expression. 
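The schema assertions updated in this hunk reflect one rule change: when JSON records disagree on a numeric type (integral, fractional, or larger than a long), inference now merges to DoubleType rather than a system-default DecimalType. A minimal sketch of the behaviour, assuming a Spark 1.6-era `sqlContext` is in scope:

```scala
// Records that disagree on the width of field "n" (assumes `sqlContext` in scope).
val conflicting = sqlContext.sparkContext.parallelize(
  """{"n": 11}""" ::                        // fits in a long
  """{"n": 21474836470.9}""" ::             // fractional
  """{"n": 92233720368547758070}""" :: Nil) // larger than Long.MaxValue
val df = sqlContext.read.json(conflicting)
// The merged type for "n" is DoubleType, so all three values come back as doubles.
df.printSchema()
df.collect().foreach(println)
```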
@@ -508,7 +512,7 @@ class JsonSuite extends QueryTest with TestJsonData { } test("Type conflict in complex field values") { - val jsonDF = ctx.read.json(complexFieldValueTypeConflict) + val jsonDF = sqlContext.read.json(complexFieldValueTypeConflict) val expectedSchema = StructType( StructField("array", ArrayType(LongType, true), true) :: @@ -532,7 +536,7 @@ class JsonSuite extends QueryTest with TestJsonData { } test("Type conflict in array elements") { - val jsonDF = ctx.read.json(arrayElementTypeConflict) + val jsonDF = sqlContext.read.json(arrayElementTypeConflict) val expectedSchema = StructType( StructField("array1", ArrayType(StringType, true), true) :: @@ -560,7 +564,7 @@ class JsonSuite extends QueryTest with TestJsonData { } test("Handling missing fields") { - val jsonDF = ctx.read.json(missingFields) + val jsonDF = sqlContext.read.json(missingFields) val expectedSchema = StructType( StructField("a", BooleanType, true) :: @@ -578,10 +582,10 @@ class JsonSuite extends QueryTest with TestJsonData { test("jsonFile should be based on JSONRelation") { val dir = Utils.createTempDir() dir.delete() - val path = dir.getCanonicalPath - ctx.sparkContext.parallelize(1 to 100) + val path = dir.getCanonicalFile.toURI.toString + sparkContext.parallelize(1 to 100) .map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path) - val jsonDF = ctx.read.option("samplingRatio", "0.49").json(path) + val jsonDF = sqlContext.read.option("samplingRatio", "0.49").json(path) val analyzed = jsonDF.queryExecution.analyzed assert( @@ -591,16 +595,17 @@ class JsonSuite extends QueryTest with TestJsonData { assert( relation.isInstanceOf[JSONRelation], "The DataFrame returned by jsonFile should be based on JSONRelation.") - assert(relation.asInstanceOf[JSONRelation].path === Some(path)) - assert(relation.asInstanceOf[JSONRelation].samplingRatio === (0.49 +- 0.001)) + assert(relation.asInstanceOf[JSONRelation].paths === Array(path)) + assert(relation.asInstanceOf[JSONRelation].options.samplingRatio === (0.49 +- 0.001)) val schema = StructType(StructField("a", LongType, true) :: Nil) val logicalRelation = - ctx.read.schema(schema).json(path).queryExecution.analyzed.asInstanceOf[LogicalRelation] + sqlContext.read.schema(schema).json(path) + .queryExecution.analyzed.asInstanceOf[LogicalRelation] val relationWithSchema = logicalRelation.relation.asInstanceOf[JSONRelation] - assert(relationWithSchema.path === Some(path)) + assert(relationWithSchema.paths === Array(path)) assert(relationWithSchema.schema === schema) - assert(relationWithSchema.samplingRatio > 0.99) + assert(relationWithSchema.options.samplingRatio > 0.99) } test("Loading a JSON dataset from a text file") { @@ -608,10 +613,10 @@ class JsonSuite extends QueryTest with TestJsonData { dir.delete() val path = dir.getCanonicalPath primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) - val jsonDF = ctx.read.json(path) + val jsonDF = sqlContext.read.json(path) val expectedSchema = StructType( - StructField("bigInteger", DecimalType.SYSTEM_DEFAULT, true) :: + StructField("bigInteger", DecimalType(20, 0), true) :: StructField("boolean", BooleanType, true) :: StructField("double", DoubleType, true) :: StructField("integer", LongType, true) :: @@ -635,6 +640,136 @@ class JsonSuite extends QueryTest with TestJsonData { ) } + test("Loading a JSON dataset primitivesAsString returns schema with primitive types as strings") { + val dir = Utils.createTempDir() + dir.delete() + val path = dir.getCanonicalPath + 
primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) + val jsonDF = sqlContext.read.option("primitivesAsString", "true").json(path) + + val expectedSchema = StructType( + StructField("bigInteger", StringType, true) :: + StructField("boolean", StringType, true) :: + StructField("double", StringType, true) :: + StructField("integer", StringType, true) :: + StructField("long", StringType, true) :: + StructField("null", StringType, true) :: + StructField("string", StringType, true) :: Nil) + + assert(expectedSchema === jsonDF.schema) + + jsonDF.registerTempTable("jsonTable") + + checkAnswer( + sql("select * from jsonTable"), + Row("92233720368547758070", + "true", + "1.7976931348623157E308", + "10", + "21474836470", + null, + "this is a simple string.") + ) + } + + test("Loading a JSON dataset primitivesAsString returns complex fields as strings") { + val jsonDF = sqlContext.read.option("primitivesAsString", "true").json(complexFieldAndType1) + + val expectedSchema = StructType( + StructField("arrayOfArray1", ArrayType(ArrayType(StringType, true), true), true) :: + StructField("arrayOfArray2", ArrayType(ArrayType(StringType, true), true), true) :: + StructField("arrayOfBigInteger", ArrayType(StringType, true), true) :: + StructField("arrayOfBoolean", ArrayType(StringType, true), true) :: + StructField("arrayOfDouble", ArrayType(StringType, true), true) :: + StructField("arrayOfInteger", ArrayType(StringType, true), true) :: + StructField("arrayOfLong", ArrayType(StringType, true), true) :: + StructField("arrayOfNull", ArrayType(StringType, true), true) :: + StructField("arrayOfString", ArrayType(StringType, true), true) :: + StructField("arrayOfStruct", ArrayType( + StructType( + StructField("field1", StringType, true) :: + StructField("field2", StringType, true) :: + StructField("field3", StringType, true) :: Nil), true), true) :: + StructField("struct", StructType( + StructField("field1", StringType, true) :: + StructField("field2", StringType, true) :: Nil), true) :: + StructField("structWithArrayFields", StructType( + StructField("field1", ArrayType(StringType, true), true) :: + StructField("field2", ArrayType(StringType, true), true) :: Nil), true) :: Nil) + + assert(expectedSchema === jsonDF.schema) + + jsonDF.registerTempTable("jsonTable") + + // Access elements of a primitive array. + checkAnswer( + sql("select arrayOfString[0], arrayOfString[1], arrayOfString[2] from jsonTable"), + Row("str1", "str2", null) + ) + + // Access an array of null values. + checkAnswer( + sql("select arrayOfNull from jsonTable"), + Row(Seq(null, null, null, null)) + ) + + // Access elements of a BigInteger array (we use DecimalType internally). + checkAnswer( + sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] from jsonTable"), + Row("922337203685477580700", "-922337203685477580800", null) + ) + + // Access elements of an array of arrays. + checkAnswer( + sql("select arrayOfArray1[0], arrayOfArray1[1] from jsonTable"), + Row(Seq("1", "2", "3"), Seq("str1", "str2")) + ) + + // Access elements of an array of arrays. + checkAnswer( + sql("select arrayOfArray2[0], arrayOfArray2[1] from jsonTable"), + Row(Seq("1", "2", "3"), Seq("1.1", "2.1", "3.1")) + ) + + // Access elements of an array inside a filed with the type of ArrayType(ArrayType). + checkAnswer( + sql("select arrayOfArray1[1][1], arrayOfArray2[1][1] from jsonTable"), + Row("str2", "2.1") + ) + + // Access elements of an array of structs. 
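The two `primitivesAsString` tests above exercise a single read option: with it enabled, every inferred leaf type, including numbers and booleans nested inside arrays and structs, is reported as StringType. A short sketch of the option, assuming `sqlContext` and a hypothetical sample RDD of JSON strings:

```scala
// With primitivesAsString enabled, all leaf fields are inferred as strings
// (assumes `sqlContext`; `records` is a hypothetical sample).
val records = sqlContext.sparkContext.parallelize(
  """{"id": 1, "score": 0.75, "tags": ["a", "b"], "active": true}""" :: Nil)
val asStrings = sqlContext.read.option("primitivesAsString", "true").json(records)
asStrings.printSchema() // id: string, score: string, tags: array<string>, active: string
asStrings.show()
```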
+ checkAnswer( + sql("select arrayOfStruct[0], arrayOfStruct[1], arrayOfStruct[2], arrayOfStruct[3] " + + "from jsonTable"), + Row( + Row("true", "str1", null), + Row("false", null, null), + Row(null, null, null), + null) + ) + + // Access a struct and fields inside of it. + checkAnswer( + sql("select struct, struct.field1, struct.field2 from jsonTable"), + Row( + Row("true", "92233720368547758070"), + "true", + "92233720368547758070") :: Nil + ) + + // Access an array field of a struct. + checkAnswer( + sql("select structWithArrayFields.field1, structWithArrayFields.field2 from jsonTable"), + Row(Seq("4", "5", "6"), Seq("str1", "str2")) + ) + + // Access elements of an array field of a struct. + checkAnswer( + sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] from jsonTable"), + Row("5", null) + ) + } + test("Loading a JSON dataset from a text file with SQL") { val dir = Utils.createTempDir() dir.delete() @@ -677,7 +812,7 @@ class JsonSuite extends QueryTest with TestJsonData { StructField("null", StringType, true) :: StructField("string", StringType, true) :: Nil) - val jsonDF1 = ctx.read.schema(schema).json(path) + val jsonDF1 = sqlContext.read.schema(schema).json(path) assert(schema === jsonDF1.schema) @@ -694,7 +829,7 @@ class JsonSuite extends QueryTest with TestJsonData { "this is a simple string.") ) - val jsonDF2 = ctx.read.schema(schema).json(primitiveFieldAndType) + val jsonDF2 = sqlContext.read.schema(schema).json(primitiveFieldAndType) assert(schema === jsonDF2.schema) @@ -715,7 +850,7 @@ class JsonSuite extends QueryTest with TestJsonData { test("Applying schemas with MapType") { val schemaWithSimpleMap = StructType( StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) - val jsonWithSimpleMap = ctx.read.schema(schemaWithSimpleMap).json(mapType1) + val jsonWithSimpleMap = sqlContext.read.schema(schemaWithSimpleMap).json(mapType1) jsonWithSimpleMap.registerTempTable("jsonWithSimpleMap") @@ -743,7 +878,7 @@ class JsonSuite extends QueryTest with TestJsonData { val schemaWithComplexMap = StructType( StructField("map", MapType(StringType, innerStruct, true), false) :: Nil) - val jsonWithComplexMap = ctx.read.schema(schemaWithComplexMap).json(mapType2) + val jsonWithComplexMap = sqlContext.read.schema(schemaWithComplexMap).json(mapType2) jsonWithComplexMap.registerTempTable("jsonWithComplexMap") @@ -769,7 +904,7 @@ class JsonSuite extends QueryTest with TestJsonData { } test("SPARK-2096 Correctly parse dot notations") { - val jsonDF = ctx.read.json(complexFieldAndType2) + val jsonDF = sqlContext.read.json(complexFieldAndType2) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -787,7 +922,7 @@ class JsonSuite extends QueryTest with TestJsonData { } test("SPARK-3390 Complex arrays") { - val jsonDF = ctx.read.json(complexFieldAndType2) + val jsonDF = sqlContext.read.json(complexFieldAndType2) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -810,7 +945,7 @@ class JsonSuite extends QueryTest with TestJsonData { } test("SPARK-3308 Read top level JSON arrays") { - val jsonDF = ctx.read.json(jsonArray) + val jsonDF = sqlContext.read.json(jsonArray) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -828,64 +963,60 @@ class JsonSuite extends QueryTest with TestJsonData { test("Corrupt records") { // Test if we can query corrupt records. 
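The corrupt-records test starting here relies on `spark.sql.columnNameOfCorruptRecord`: any input line that fails to parse is kept verbatim in the configured column instead of aborting the read. A hedged sketch of the same mechanism outside the test harness, assuming a `sqlContext`:

```scala
// Malformed records land in the configured corrupt-record column (assumes `sqlContext`).
sqlContext.setConf("spark.sql.columnNameOfCorruptRecord", "_unparsed")
val mixed = sqlContext.sparkContext.parallelize(
  """{"a": 1, "b": 2}""" ::    // well-formed
  """{"a": 1, b: 2}""" :: Nil) // unquoted field name: rejected by the default parser
val df = sqlContext.read.json(mixed)
df.registerTempTable("mixedRecords")
// The malformed line is preserved as raw text in _unparsed; the valid line parses normally.
sqlContext.sql("SELECT _unparsed FROM mixedRecords WHERE _unparsed IS NOT NULL").show()
```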
- val oldColumnNameOfCorruptRecord = ctx.conf.columnNameOfCorruptRecord - ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") - - val jsonDF = ctx.read.json(corruptRecords) - jsonDF.registerTempTable("jsonTable") - - val schema = StructType( - StructField("_unparsed", StringType, true) :: - StructField("a", StringType, true) :: - StructField("b", StringType, true) :: - StructField("c", StringType, true) :: Nil) - - assert(schema === jsonDF.schema) - - // In HiveContext, backticks should be used to access columns starting with a underscore. - checkAnswer( - sql( - """ - |SELECT a, b, c, _unparsed - |FROM jsonTable - """.stripMargin), - Row(null, null, null, "{") :: - Row(null, null, null, "") :: - Row(null, null, null, """{"a":1, b:2}""") :: - Row(null, null, null, """{"a":{, b:3}""") :: - Row("str_a_4", "str_b_4", "str_c_4", null) :: - Row(null, null, null, "]") :: Nil - ) - - checkAnswer( - sql( - """ - |SELECT a, b, c - |FROM jsonTable - |WHERE _unparsed IS NULL - """.stripMargin), - Row("str_a_4", "str_b_4", "str_c_4") - ) - - checkAnswer( - sql( - """ - |SELECT _unparsed - |FROM jsonTable - |WHERE _unparsed IS NOT NULL - """.stripMargin), - Row("{") :: - Row("") :: - Row("""{"a":1, b:2}""") :: - Row("""{"a":{, b:3}""") :: - Row("]") :: Nil - ) - - ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) + withSQLConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD.key -> "_unparsed") { + withTempTable("jsonTable") { + val jsonDF = sqlContext.read.json(corruptRecords) + jsonDF.registerTempTable("jsonTable") + val schema = StructType( + StructField("_unparsed", StringType, true) :: + StructField("a", StringType, true) :: + StructField("b", StringType, true) :: + StructField("c", StringType, true) :: Nil) + + assert(schema === jsonDF.schema) + + // In HiveContext, backticks should be used to access columns starting with a underscore. 
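The rewrite in this hunk swaps manual save-and-restore of the corrupt-record setting for the scoped `withSQLConf` and `withTempTable` helpers, so the configuration and the temporary table are cleaned up even if an assertion fails. The general shape of such a helper, as a sketch rather than the actual SQLTestUtils implementation:

```scala
// Scoped-config helper in the spirit of withSQLConf (sketch only; the real helper
// lives in SQLTestUtils and also handles keys that were never set before).
def withScopedConf[T](sqlContext: org.apache.spark.sql.SQLContext,
                      key: String, value: String)(body: => T): T = {
  val previous = sqlContext.getAllConfs.get(key) // remember the old value, if any
  sqlContext.setConf(key, value)
  try body finally {
    previous.foreach(old => sqlContext.setConf(key, old))
  }
}
```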
+ checkAnswer( + sql( + """ + |SELECT a, b, c, _unparsed + |FROM jsonTable + """.stripMargin), + Row(null, null, null, "{") :: + Row(null, null, null, """{"a":1, b:2}""") :: + Row(null, null, null, """{"a":{, b:3}""") :: + Row("str_a_4", "str_b_4", "str_c_4", null) :: + Row(null, null, null, "]") :: Nil + ) + + checkAnswer( + sql( + """ + |SELECT a, b, c + |FROM jsonTable + |WHERE _unparsed IS NULL + """.stripMargin), + Row("str_a_4", "str_b_4", "str_c_4") + ) + + checkAnswer( + sql( + """ + |SELECT _unparsed + |FROM jsonTable + |WHERE _unparsed IS NOT NULL + """.stripMargin), + Row("{") :: + Row("""{"a":1, b:2}""") :: + Row("""{"a":{, b:3}""") :: + Row("]") :: Nil + ) + } + } } test("SPARK-4068: nulls in arrays") { - val jsonDF = ctx.read.json(nullsInArrays) + val jsonDF = sqlContext.read.json(nullsInArrays) jsonDF.registerTempTable("jsonTable") val schema = StructType( @@ -931,7 +1062,7 @@ class JsonSuite extends QueryTest with TestJsonData { Row(values(0).toInt, values(1), values(2).toBoolean, r.split(",").toList, v5) } - val df1 = ctx.createDataFrame(rowRDD1, schema1) + val df1 = sqlContext.createDataFrame(rowRDD1, schema1) df1.registerTempTable("applySchema1") val df2 = df1.toDF val result = df2.toJSON.collect() @@ -954,7 +1085,7 @@ class JsonSuite extends QueryTest with TestJsonData { Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) } - val df3 = ctx.createDataFrame(rowRDD2, schema2) + val df3 = sqlContext.createDataFrame(rowRDD2, schema2) df3.registerTempTable("applySchema2") val df4 = df3.toDF val result2 = df4.toJSON.collect() @@ -962,11 +1093,11 @@ class JsonSuite extends QueryTest with TestJsonData { assert(result2(1) === "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}") assert(result2(3) === "{\"f1\":{\"f11\":4,\"f12\":true},\"f2\":{\"D4\":2147483644}}") - val jsonDF = ctx.read.json(primitiveFieldAndType) - val primTable = ctx.read.json(jsonDF.toJSON) - primTable.registerTempTable("primativeTable") + val jsonDF = sqlContext.read.json(primitiveFieldAndType) + val primTable = sqlContext.read.json(jsonDF.toJSON) + primTable.registerTempTable("primitiveTable") checkAnswer( - sql("select * from primativeTable"), + sql("select * from primitiveTable"), Row(new java.math.BigDecimal("92233720368547758070"), true, 1.7976931348623157E308, @@ -975,8 +1106,8 @@ class JsonSuite extends QueryTest with TestJsonData { "this is a simple string.") ) - val complexJsonDF = ctx.read.json(complexFieldAndType1) - val compTable = ctx.read.json(complexJsonDF.toJSON) + val complexJsonDF = sqlContext.read.json(complexFieldAndType1) + val compTable = sqlContext.read.json(complexJsonDF.toJSON) compTable.registerTempTable("complexTable") // Access elements of a primitive array. 
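The toJSON tests in this region are built around a simple round trip: `DataFrame.toJSON` renders each row as a JSON object string, and reading those strings back through the JSON source reproduces an equivalent table. A compact sketch, assuming `sqlContext` and its implicits:

```scala
// DataFrame -> JSON strings -> DataFrame round trip (assumes `sqlContext`).
import sqlContext.implicits._
val original = sqlContext.sparkContext
  .parallelize(Seq((1, "a"), (2, "b")))
  .toDF("id", "name")
val roundTripped = sqlContext.read.json(original.toJSON) // toJSON yields an RDD[String]
roundTripped.registerTempTable("roundTripped")
sqlContext.sql("SELECT id, name FROM roundTripped ORDER BY id").show()
```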
checkAnswer( @@ -1040,26 +1171,36 @@ class JsonSuite extends QueryTest with TestJsonData { } test("JSONRelation equality test") { - val context = org.apache.spark.sql.test.TestSQLContext + val relation0 = new JSONRelation( + Some(empty), + Some(StructType(StructField("a", IntegerType, true) :: Nil)), + None, + None)(sqlContext) + val logicalRelation0 = LogicalRelation(relation0) val relation1 = new JSONRelation( - "path", - 1.0, + Some(singleRow), Some(StructType(StructField("a", IntegerType, true) :: Nil)), - context) + None, + None)(sqlContext) val logicalRelation1 = LogicalRelation(relation1) val relation2 = new JSONRelation( - "path", - 0.5, + Some(singleRow), Some(StructType(StructField("a", IntegerType, true) :: Nil)), - context) + None, + None, + parameters = Map("samplingRatio" -> "0.5"))(sqlContext) val logicalRelation2 = LogicalRelation(relation2) val relation3 = new JSONRelation( - "path", - 1.0, - Some(StructType(StructField("b", StringType, true) :: Nil)), - context) + Some(singleRow), + Some(StructType(StructField("b", IntegerType, true) :: Nil)), + None, + None)(sqlContext) val logicalRelation3 = LogicalRelation(relation3) + assert(relation0 !== relation1) + assert(!logicalRelation0.sameResult(logicalRelation1), + s"$logicalRelation0 and $logicalRelation1 should be considered not having the same result.") + assert(relation1 === relation2) assert(logicalRelation1.sameResult(logicalRelation2), s"$logicalRelation1 and $logicalRelation2 should be considered having the same result.") @@ -1071,38 +1212,256 @@ class JsonSuite extends QueryTest with TestJsonData { assert(relation2 !== relation3) assert(!logicalRelation2.sameResult(logicalRelation3), s"$logicalRelation2 and $logicalRelation3 should be considered not having the same result.") + + withTempPath(dir => { + val path = dir.getCanonicalFile.toURI.toString + sparkContext.parallelize(1 to 100) + .map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path) + + val d1 = ResolvedDataSource( + sqlContext, + userSpecifiedSchema = None, + partitionColumns = Array.empty[String], + provider = classOf[DefaultSource].getCanonicalName, + options = Map("path" -> path)) + + val d2 = ResolvedDataSource( + sqlContext, + userSpecifiedSchema = None, + partitionColumns = Array.empty[String], + provider = classOf[DefaultSource].getCanonicalName, + options = Map("path" -> path)) + assert(d1 === d2) + }) } test("SPARK-6245 JsonRDD.inferSchema on empty RDD") { // This is really a test that it doesn't throw an exception - val emptySchema = InferSchema(empty, 1.0, "") + val emptySchema = InferSchema.infer(empty, "", JSONOptions()) assert(StructType(Seq()) === emptySchema) } test("SPARK-7565 MapType in JsonRDD") { - val oldColumnNameOfCorruptRecord = ctx.conf.columnNameOfCorruptRecord - ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") - - val schemaWithSimpleMap = StructType( - StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) - try { - val temp = Utils.createTempDir().getPath - - val df = ctx.read.schema(schemaWithSimpleMap).json(mapType1) - df.write.mode("overwrite").parquet(temp) - // order of MapType is not defined - assert(ctx.read.parquet(temp).count() == 5) - - val df2 = ctx.read.json(corruptRecords) - df2.write.mode("overwrite").parquet(temp) - checkAnswer(ctx.read.parquet(temp), df2.collect()) - } finally { - ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) + withSQLConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD.key -> "_unparsed") { + withTempDir { dir => + val 
schemaWithSimpleMap = StructType( + StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) + val df = sqlContext.read.schema(schemaWithSimpleMap).json(mapType1) + + val path = dir.getAbsolutePath + df.write.mode("overwrite").parquet(path) + // order of MapType is not defined + assert(sqlContext.read.parquet(path).count() == 5) + + val df2 = sqlContext.read.json(corruptRecords) + df2.write.mode("overwrite").parquet(path) + checkAnswer(sqlContext.read.parquet(path), df2.collect()) + } } } test("SPARK-8093 Erase empty structs") { - val emptySchema = InferSchema(emptyRecords, 1.0, "") + val emptySchema = InferSchema.infer(emptyRecords, "", JSONOptions()) assert(StructType(Seq()) === emptySchema) } + + test("JSON with Partition") { + def makePartition(rdd: RDD[String], parent: File, partName: String, partValue: Any): File = { + val p = new File(parent, s"$partName=${partValue.toString}") + rdd.saveAsTextFile(p.getCanonicalPath) + p + } + + withTempPath(root => { + val d1 = new File(root, "d1=1") + // root/dt=1/col1=abc + val p1_col1 = makePartition( + sparkContext.parallelize(2 to 5).map(i => s"""{"a": 1, "b": "str$i"}"""), + d1, + "col1", + "abc") + + // root/dt=1/col1=abd + val p2 = makePartition( + sparkContext.parallelize(6 to 10).map(i => s"""{"a": 1, "b": "str$i"}"""), + d1, + "col1", + "abd") + + sqlContext.read.json(root.getAbsolutePath).registerTempTable("test_myjson_with_part") + checkAnswer(sql( + "SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abc'"), Row(4)) + checkAnswer(sql( + "SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abd'"), Row(5)) + checkAnswer(sql( + "SELECT count(a) FROM test_myjson_with_part where d1 = 1"), Row(9)) + }) + } + + test("backward compatibility") { + // This test we make sure our JSON support can read JSON data generated by previous version + // of Spark generated through toJSON method and JSON data source. + // The data is generated by the following program. + // Here are a few notes: + // - Spark 1.5.0 cannot save timestamp data. So, we manually added timestamp field (col13) + // in the JSON object. + // - For Spark before 1.5.1, we do not generate UDTs. So, we manually added the UDT value to + // JSON objects generated by those Spark versions (col17). + // - If the type is NullType, we do not write data out. + + // Create the schema. 
+ val struct = + StructType( + StructField("f1", FloatType, true) :: + StructField("f2", ArrayType(BooleanType), true) :: Nil) + + val dataTypes = + Seq( + StringType, BinaryType, NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), + DateType, TimestampType, + ArrayType(IntegerType), MapType(StringType, LongType), struct, + new MyDenseVectorUDT()) + val fields = dataTypes.zipWithIndex.map { case (dataType, index) => + StructField(s"col$index", dataType, nullable = true) + } + val schema = StructType(fields) + + val constantValues = + Seq( + "a string in binary".getBytes("UTF-8"), + null, + true, + 1.toByte, + 2.toShort, + 3, + Long.MaxValue, + 0.25.toFloat, + 0.75, + new java.math.BigDecimal(s"1234.23456"), + new java.math.BigDecimal(s"1.23456"), + java.sql.Date.valueOf("2015-01-01"), + java.sql.Timestamp.valueOf("2015-01-01 23:50:59.123"), + Seq(2, 3, 4), + Map("a string" -> 2000L), + Row(4.75.toFloat, Seq(false, true)), + new MyDenseVector(Array(0.25, 2.25, 4.25))) + val data = + Row.fromSeq(Seq("Spark " + sqlContext.sparkContext.version) ++ constantValues) :: Nil + + // Data generated by previous versions. + // scalastyle:off + val existingJSONData = + """{"col0":"Spark 1.2.2","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.3.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.3.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.4.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.4.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.5.0","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 
1.5.0","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"16436","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: Nil + // scalastyle:on + + // Generate data for the current version. + val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(data, 1), schema) + withTempPath { path => + df.write.format("json").mode("overwrite").save(path.getCanonicalPath) + + // df.toJSON will convert internal rows to external rows first and then generate + // JSON objects. While, df.write.format("json") will write internal rows directly. + val allJSON = + existingJSONData ++ + df.toJSON.collect() ++ + sparkContext.textFile(path.getCanonicalPath).collect() + + Utils.deleteRecursively(path) + sparkContext.parallelize(allJSON, 1).saveAsTextFile(path.getCanonicalPath) + + // Read data back with the schema specified. + val col0Values = + Seq( + "Spark 1.2.2", + "Spark 1.3.1", + "Spark 1.3.1", + "Spark 1.4.1", + "Spark 1.4.1", + "Spark 1.5.0", + "Spark 1.5.0", + "Spark " + sqlContext.sparkContext.version, + "Spark " + sqlContext.sparkContext.version) + val expectedResult = col0Values.map { v => + Row.fromSeq(Seq(v) ++ constantValues) + } + checkAnswer( + sqlContext.read.format("json").schema(schema).load(path.getCanonicalPath), + expectedResult + ) + } + } + + test("SPARK-11544 test pathfilter") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df = sqlContext.range(2) + df.write.json(path + "/p=1") + df.write.json(path + "/p=2") + assert(sqlContext.read.json(path).count() === 4) + + val clonedConf = new Configuration(hadoopConfiguration) + try { + // Setting it twice as the name of the propery has changed between hadoop versions. + hadoopConfiguration.setClass( + "mapred.input.pathFilter.class", + classOf[TestFileFilter], + classOf[PathFilter]) + hadoopConfiguration.setClass( + "mapreduce.input.pathFilter.class", + classOf[TestFileFilter], + classOf[PathFilter]) + assert(sqlContext.read.json(path).count() === 2) + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) + } + } + } + + test("SPARK-12057 additional corrupt records do not throw exceptions") { + // Test if we can query corrupt records. + withSQLConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD.key -> "_unparsed") { + withTempTable("jsonTable") { + val schema = StructType( + StructField("_unparsed", StringType, true) :: + StructField("dummy", StringType, true) :: Nil) + + { + // We need to make sure we can infer the schema. + val jsonDF = sqlContext.read.json(additionalCorruptRecords) + assert(jsonDF.schema === schema) + } + + { + val jsonDF = sqlContext.read.schema(schema).json(additionalCorruptRecords) + jsonDF.registerTempTable("jsonTable") + + // In HiveContext, backticks should be used to access columns starting with a underscore. 
+ checkAnswer( + sql( + """ + |SELECT dummy, _unparsed + |FROM jsonTable + """.stripMargin), + Row("test", null) :: + Row(null, """[1,2,3]""") :: + Row(null, """":"test", "a":1}""") :: + Row(null, """42""") :: + Row(null, """ ","ian":"test"}""") :: Nil + ) + } + } + } + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala similarity index 85% rename from sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala index eb62066ac643..cb61f7eeca0d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala @@ -15,17 +15,16 @@ * limitations under the License. */ -package org.apache.spark.sql.json +package org.apache.spark.sql.execution.datasources.json import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext -trait TestJsonData { - - protected def ctx: SQLContext +private[json] trait TestJsonData { + protected def sqlContext: SQLContext def primitiveFieldAndType: RDD[String] = - ctx.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"string":"this is a simple string.", "integer":10, "long":21474836470, @@ -36,7 +35,7 @@ trait TestJsonData { }""" :: Nil) def primitiveFieldValueTypeConflict: RDD[String] = - ctx.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"num_num_1":11, "num_num_2":null, "num_num_3": 1.1, "num_bool":true, "num_str":13.1, "str_bool":"str1"}""" :: """{"num_num_1":null, "num_num_2":21474836470.9, "num_num_3": null, @@ -47,14 +46,14 @@ trait TestJsonData { "num_bool":null, "num_str":92233720368547758070, "str_bool":null}""" :: Nil) def jsonNullStruct: RDD[String] = - ctx.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"nullstr":"","ip":"27.31.100.29","headers":{"Host":"1.abc.com","Charset":"UTF-8"}}""" :: """{"nullstr":"","ip":"27.31.100.29","headers":{}}""" :: """{"nullstr":"","ip":"27.31.100.29","headers":""}""" :: """{"nullstr":null,"ip":"27.31.100.29","headers":null}""" :: Nil) def complexFieldValueTypeConflict: RDD[String] = - ctx.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"num_struct":11, "str_array":[1, 2, 3], "array":[], "struct_array":[], "struct": {}}""" :: """{"num_struct":{"field":false}, "str_array":null, @@ -65,14 +64,14 @@ trait TestJsonData { "array":[7], "struct_array":{"field": true}, "struct": {"field": "str"}}""" :: Nil) def arrayElementTypeConflict: RDD[String] = - ctx.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"array1": [1, 1.1, true, null, [], {}, [2,3,4], {"field":"str"}], "array2": [{"field":214748364700}, {"field":1}]}""" :: """{"array3": [{"field":"str"}, {"field":1}]}""" :: """{"array3": [1, 2, 3]}""" :: Nil) def missingFields: RDD[String] = - ctx.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"a":true}""" :: """{"b":21474836470}""" :: """{"c":[33, 44]}""" :: @@ -80,7 +79,7 @@ trait TestJsonData { """{"e":"str"}""" :: Nil) def complexFieldAndType1: RDD[String] = - ctx.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"struct":{"field1": true, "field2": 92233720368547758070}, "structWithArrayFields":{"field1":[4, 5, 6], "field2":["str1", "str2"]}, "arrayOfString":["str1", "str2"], @@ -96,7 +95,7 @@ trait TestJsonData { }""" :: Nil) def 
complexFieldAndType2: RDD[String] = - ctx.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}], "complexArrayOfStruct": [ { @@ -150,7 +149,7 @@ trait TestJsonData { }""" :: Nil) def mapType1: RDD[String] = - ctx.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"map": {"a": 1}}""" :: """{"map": {"b": 2}}""" :: """{"map": {"c": 3}}""" :: @@ -158,7 +157,7 @@ trait TestJsonData { """{"map": {"e": null}}""" :: Nil) def mapType2: RDD[String] = - ctx.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"map": {"a": {"field1": [1, 2, 3, null]}}}""" :: """{"map": {"b": {"field2": 2}}}""" :: """{"map": {"c": {"field1": [], "field2": 4}}}""" :: @@ -167,21 +166,21 @@ trait TestJsonData { """{"map": {"f": {"field1": null}}}""" :: Nil) def nullsInArrays: RDD[String] = - ctx.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"field1":[[null], [[["Test"]]]]}""" :: """{"field2":[null, [{"Test":1}]]}""" :: """{"field3":[[null], [{"Test":"2"}]]}""" :: """{"field4":[[null, [1,2,3]]]}""" :: Nil) def jsonArray: RDD[String] = - ctx.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """[{"a":"str_a_1"}]""" :: """[{"a":"str_a_2"}, {"b":"str_b_3"}]""" :: """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: """[]""" :: Nil) def corruptRecords: RDD[String] = - ctx.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{""" :: """""" :: """{"a":1, b:2}""" :: @@ -189,8 +188,16 @@ trait TestJsonData { """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: """]""" :: Nil) + def additionalCorruptRecords: RDD[String] = + sqlContext.sparkContext.parallelize( + """{"dummy":"test"}""" :: + """[1,2,3]""" :: + """":"test", "a":1}""" :: + """42""" :: + """ ","ian":"test"}""" :: Nil) + def emptyRecords: RDD[String] = - ctx.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{""" :: """""" :: """{"a": {}}""" :: @@ -198,5 +205,7 @@ trait TestJsonData { """{"b": [{"c": {}}]}""" :: """]""" :: Nil) - def empty: RDD[String] = ctx.sparkContext.parallelize(Seq[String]()) + lazy val singleRow: RDD[String] = sqlContext.sparkContext.parallelize("""{"a":123}""" :: Nil) + + def empty: RDD[String] = sqlContext.sparkContext.parallelize(Seq[String]()) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala new file mode 100644 index 000000000000..36b929ee1f40 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.io.File +import java.nio.ByteBuffer +import java.util.{List => JList, Map => JMap} + +import scala.collection.JavaConverters._ + +import org.apache.avro.Schema +import org.apache.avro.generic.IndexedRecord +import org.apache.hadoop.fs.Path +import org.apache.parquet.avro.AvroParquetWriter + +import org.apache.spark.sql.Row +import org.apache.spark.sql.execution.datasources.parquet.test.avro._ +import org.apache.spark.sql.test.SharedSQLContext + +class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext { + private def withWriter[T <: IndexedRecord] + (path: String, schema: Schema) + (f: AvroParquetWriter[T] => Unit): Unit = { + logInfo( + s"""Writing Avro records with the following Avro schema into Parquet file: + | + |${schema.toString(true)} + """.stripMargin) + + val writer = new AvroParquetWriter[T](new Path(path), schema) + try f(writer) finally writer.close() + } + + test("required primitives") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withWriter[AvroPrimitives](path, AvroPrimitives.getClassSchema) { writer => + (0 until 10).foreach { i => + writer.write( + AvroPrimitives.newBuilder() + .setBoolColumn(i % 2 == 0) + .setIntColumn(i) + .setLongColumn(i.toLong * 10) + .setFloatColumn(i.toFloat + 0.1f) + .setDoubleColumn(i.toDouble + 0.2d) + .setBinaryColumn(ByteBuffer.wrap(s"val_$i".getBytes("UTF-8"))) + .setStringColumn(s"val_$i") + .build()) + } + } + + logParquetSchema(path) + + checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + Row( + i % 2 == 0, + i, + i.toLong * 10, + i.toFloat + 0.1f, + i.toDouble + 0.2d, + s"val_$i".getBytes("UTF-8"), + s"val_$i") + }) + } + } + + test("optional primitives") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withWriter[AvroOptionalPrimitives](path, AvroOptionalPrimitives.getClassSchema) { writer => + (0 until 10).foreach { i => + val record = if (i % 3 == 0) { + AvroOptionalPrimitives.newBuilder() + .setMaybeBoolColumn(null) + .setMaybeIntColumn(null) + .setMaybeLongColumn(null) + .setMaybeFloatColumn(null) + .setMaybeDoubleColumn(null) + .setMaybeBinaryColumn(null) + .setMaybeStringColumn(null) + .build() + } else { + AvroOptionalPrimitives.newBuilder() + .setMaybeBoolColumn(i % 2 == 0) + .setMaybeIntColumn(i) + .setMaybeLongColumn(i.toLong * 10) + .setMaybeFloatColumn(i.toFloat + 0.1f) + .setMaybeDoubleColumn(i.toDouble + 0.2d) + .setMaybeBinaryColumn(ByteBuffer.wrap(s"val_$i".getBytes("UTF-8"))) + .setMaybeStringColumn(s"val_$i") + .build() + } + + writer.write(record) + } + } + + logParquetSchema(path) + + checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + if (i % 3 == 0) { + Row.apply(Seq.fill(7)(null): _*) + } else { + Row( + i % 2 == 0, + i, + i.toLong * 10, + i.toFloat + 0.1f, + i.toDouble + 0.2d, + s"val_$i".getBytes("UTF-8"), + s"val_$i") + } + }) + } + } + + test("non-nullable arrays") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withWriter[AvroNonNullableArrays](path, AvroNonNullableArrays.getClassSchema) { writer => + (0 until 10).foreach { i => + val record = { + val builder = + AvroNonNullableArrays.newBuilder() + .setStringsColumn(Seq.tabulate(3)(i => s"val_$i").asJava) + + if (i % 3 == 0) { + builder.setMaybeIntsColumn(null).build() + } else { + builder.setMaybeIntsColumn(Seq.tabulate(3)(Int.box).asJava).build() + } + } + 
+ writer.write(record) + } + } + + logParquetSchema(path) + + checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + Row( + Seq.tabulate(3)(i => s"val_$i"), + if (i % 3 == 0) null else Seq.tabulate(3)(identity)) + }) + } + } + + ignore("nullable arrays (parquet-avro 1.7.0 does not properly support this)") { + // TODO Complete this test case after upgrading to parquet-mr 1.8+ + } + + test("SPARK-10136 array of primitive array") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withWriter[AvroArrayOfArray](path, AvroArrayOfArray.getClassSchema) { writer => + (0 until 10).foreach { i => + writer.write(AvroArrayOfArray.newBuilder() + .setIntArraysColumn( + Seq.tabulate(3, 3)((i, j) => i * 3 + j: Integer).map(_.asJava).asJava) + .build()) + } + } + + logParquetSchema(path) + + checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + Row(Seq.tabulate(3, 3)((i, j) => i * 3 + j)) + }) + } + } + + test("map of primitive array") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withWriter[AvroMapOfArray](path, AvroMapOfArray.getClassSchema) { writer => + (0 until 10).foreach { i => + writer.write(AvroMapOfArray.newBuilder() + .setStringToIntsColumn( + Seq.tabulate(3) { i => + i.toString -> Seq.tabulate(3)(j => i + j: Integer).asJava + }.toMap.asJava) + .build()) + } + } + + logParquetSchema(path) + + checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + Row(Seq.tabulate(3)(i => i.toString -> Seq.tabulate(3)(j => i + j)).toMap) + }) + } + } + + test("various complex types") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withWriter[ParquetAvroCompat](path, ParquetAvroCompat.getClassSchema) { writer => + (0 until 10).foreach(i => writer.write(makeParquetAvroCompat(i))) + } + + logParquetSchema(path) + + checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + Row( + Seq.tabulate(3)(n => s"arr_${i + n}"), + Seq.tabulate(3)(n => n.toString -> (i + n: Integer)).toMap, + Seq.tabulate(3) { n => + (i + n).toString -> Seq.tabulate(3) { m => + Row(Seq.tabulate(3)(j => i + j + m), s"val_${i + m}") + } + }.toMap) + }) + } + } + + def makeParquetAvroCompat(i: Int): ParquetAvroCompat = { + def makeComplexColumn(i: Int): JMap[String, JList[Nested]] = { + Seq.tabulate(3) { n => + (i + n).toString -> Seq.tabulate(3) { m => + Nested + .newBuilder() + .setNestedIntsColumn(Seq.tabulate(3)(j => i + j + m: Integer).asJava) + .setNestedStringColumn(s"val_${i + m}") + .build() + }.asJava + }.toMap.asJava + } + + ParquetAvroCompat + .newBuilder() + .setStringsColumn(Seq.tabulate(3)(n => s"arr_${i + n}").asJava) + .setStringToIntColumn(Seq.tabulate(3)(n => n.toString -> (i + n: Integer)).toMap.asJava) + .setComplexColumn(makeComplexColumn(i)) + .build() + } + + test("SPARK-9407 Push down predicates involving Parquet ENUM columns") { + import testImplicits._ + + withTempPath { dir => + val path = dir.getCanonicalPath + + withWriter[ParquetEnum](path, ParquetEnum.getClassSchema) { writer => + (0 until 4).foreach { i => + writer.write(ParquetEnum.newBuilder().setSuit(Suit.values.apply(i)).build()) + } + } + + checkAnswer(sqlContext.read.parquet(path).filter('suit === "SPADES"), Row("SPADES")) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala new file mode 100644 index 000000000000..0835bd123049 --- /dev/null +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import scala.collection.JavaConverters.{collectionAsScalaIterableConverter, mapAsJavaMapConverter, seqAsJavaListConverter} + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{Path, PathFilter} +import org.apache.parquet.hadoop.api.WriteSupport +import org.apache.parquet.hadoop.api.WriteSupport.WriteContext +import org.apache.parquet.hadoop.{ParquetFileReader, ParquetWriter} +import org.apache.parquet.io.api.RecordConsumer +import org.apache.parquet.schema.{MessageType, MessageTypeParser} + +import org.apache.spark.sql.QueryTest + +/** + * Helper class for testing Parquet compatibility. + */ +private[sql] abstract class ParquetCompatibilityTest extends QueryTest with ParquetTest { + protected def readParquetSchema(path: String): MessageType = { + readParquetSchema(path, { path => !path.getName.startsWith("_") }) + } + + protected def readParquetSchema(path: String, pathFilter: Path => Boolean): MessageType = { + val fsPath = new Path(path) + val fs = fsPath.getFileSystem(hadoopConfiguration) + val parquetFiles = fs.listStatus(fsPath, new PathFilter { + override def accept(path: Path): Boolean = pathFilter(path) + }).toSeq.asJava + + val footers = + ParquetFileReader.readAllFootersInParallel(hadoopConfiguration, parquetFiles, true) + footers.asScala.head.getParquetMetadata.getFileMetaData.getSchema + } + + protected def logParquetSchema(path: String): Unit = { + logInfo( + s"""Schema of the Parquet file written by parquet-avro: + |${readParquetSchema(path)} + """.stripMargin) + } +} + +private[sql] object ParquetCompatibilityTest { + implicit class RecordConsumerDSL(consumer: RecordConsumer) { + def message(f: => Unit): Unit = { + consumer.startMessage() + f + consumer.endMessage() + } + + def group(f: => Unit): Unit = { + consumer.startGroup() + f + consumer.endGroup() + } + + def field(name: String, index: Int)(f: => Unit): Unit = { + consumer.startField(name, index) + f + consumer.endField(name, index) + } + } + + /** + * A testing Parquet [[WriteSupport]] implementation used to write manually constructed Parquet + * records with arbitrary structures. 
+ */ + private class DirectWriteSupport(schema: MessageType, metadata: Map[String, String]) + extends WriteSupport[RecordConsumer => Unit] { + + private var recordConsumer: RecordConsumer = _ + + override def init(configuration: Configuration): WriteContext = { + new WriteContext(schema, metadata.asJava) + } + + override def write(recordWriter: RecordConsumer => Unit): Unit = { + recordWriter.apply(recordConsumer) + } + + override def prepareForWrite(recordConsumer: RecordConsumer): Unit = { + this.recordConsumer = recordConsumer + } + } + + /** + * Writes arbitrary messages conforming to a given `schema` to a Parquet file located by `path`. + * Records are produced by `recordWriters`. + */ + def writeDirect(path: String, schema: String, recordWriters: (RecordConsumer => Unit)*): Unit = { + writeDirect(path, schema, Map.empty[String, String], recordWriters: _*) + } + + /** + * Writes arbitrary messages conforming to a given `schema` to a Parquet file located by `path` + * with given user-defined key-value `metadata`. Records are produced by `recordWriters`. + */ + def writeDirect( + path: String, + schema: String, + metadata: Map[String, String], + recordWriters: (RecordConsumer => Unit)*): Unit = { + val messageType = MessageTypeParser.parseMessageType(schema) + val writeSupport = new DirectWriteSupport(messageType, metadata) + val parquetWriter = new ParquetWriter[RecordConsumer => Unit](new Path(path), writeSupport) + try recordWriters.foreach(parquetWriter.write) finally parquetWriter.close() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala similarity index 68% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index b6a7c4fbddbd..6178e37d2a58 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -15,17 +15,17 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import org.apache.parquet.filter2.predicate.Operators._ import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators} +import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.types._ -import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf} +import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, LogicalRelation} +import org.apache.spark.sql.test.SharedSQLContext /** * A test suite that tests Parquet filter2 API based filter pushdown optimization. @@ -39,8 +39,7 @@ import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf} * 2. `Tuple1(Option(x))` is used together with `AnyVal` types like `Int` to ensure the inferred * data type is nullable. 
*/ -class ParquetFilterSuite extends QueryTest with ParquetTest { - lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext +class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContext { private def checkFilterPredicate( df: DataFrame, @@ -51,25 +50,33 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { val output = predicate.collect { case a: Attribute => a }.distinct withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { - val query = df - .select(output.map(e => Column(e)): _*) - .where(Column(predicate)) - - val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { - case PhysicalOperation(_, filters, LogicalRelation(_: ParquetRelation)) => filters - }.flatten.reduceOption(_ && _) - - assert(maybeAnalyzedPredicate.isDefined) - maybeAnalyzedPredicate.foreach { pred => - val maybeFilter = ParquetFilters.createFilter(pred) - assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pred") - maybeFilter.foreach { f => - // Doesn't bother checking type parameters here (e.g. `Eq[Integer]`) - assert(f.getClass === filterClass) + withSQLConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key -> "false") { + val query = df + .select(output.map(e => Column(e)): _*) + .where(Column(predicate)) + + var maybeRelation: Option[ParquetRelation] = None + val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { + case PhysicalOperation(_, filters, LogicalRelation(relation: ParquetRelation, _)) => + maybeRelation = Some(relation) + filters + }.flatten.reduceLeftOption(_ && _) + assert(maybeAnalyzedPredicate.isDefined, "No filter is analyzed from the given query") + + val (_, selectedFilters) = + DataSourceStrategy.selectFilters(maybeRelation.get, maybeAnalyzedPredicate.toSeq) + assert(selectedFilters.nonEmpty, "No filter is pushed down") + + selectedFilters.foreach { pred => + val maybeFilter = ParquetFilters.createFilter(df.schema, pred) + assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pred") + maybeFilter.foreach { f => + // Doesn't bother checking type parameters here (e.g. 
`Eq[Integer]`) + assert(f.getClass === filterClass) + } } + checker(stripSparkFilter(query), expected) } - - checker(query, expected) } } @@ -109,43 +116,18 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], Seq(Row(true), Row(false))) checkFilterPredicate('_1 === true, classOf[Eq[_]], true) + checkFilterPredicate('_1 <=> true, classOf[Eq[_]], true) checkFilterPredicate('_1 !== true, classOf[NotEq[_]], false) } } - test("filter pushdown - short") { - withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toShort)))) { implicit df => - checkFilterPredicate(Cast('_1, IntegerType) === 1, classOf[Eq[_]], 1) - checkFilterPredicate( - Cast('_1, IntegerType) !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - - checkFilterPredicate(Cast('_1, IntegerType) < 2, classOf[Lt[_]], 1) - checkFilterPredicate(Cast('_1, IntegerType) > 3, classOf[Gt[_]], 4) - checkFilterPredicate(Cast('_1, IntegerType) <= 1, classOf[LtEq[_]], 1) - checkFilterPredicate(Cast('_1, IntegerType) >= 4, classOf[GtEq[_]], 4) - - checkFilterPredicate(Literal(1) === Cast('_1, IntegerType), classOf[Eq[_]], 1) - checkFilterPredicate(Literal(2) > Cast('_1, IntegerType), classOf[Lt[_]], 1) - checkFilterPredicate(Literal(3) < Cast('_1, IntegerType), classOf[Gt[_]], 4) - checkFilterPredicate(Literal(1) >= Cast('_1, IntegerType), classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= Cast('_1, IntegerType), classOf[GtEq[_]], 4) - - checkFilterPredicate(!(Cast('_1, IntegerType) < 4), classOf[GtEq[_]], 4) - checkFilterPredicate( - Cast('_1, IntegerType) > 2 && Cast('_1, IntegerType) < 4, classOf[Operators.And], 3) - checkFilterPredicate( - Cast('_1, IntegerType) < 2 || Cast('_1, IntegerType) > 3, - classOf[Operators.Or], - Seq(Row(1), Row(4))) - } - } - test("filter pushdown - integer") { withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 <=> 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) @@ -154,13 +136,13 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(1) <=> '_1, classOf[Eq[_]], 1) checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } } @@ -171,6 +153,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 <=> 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) @@ -179,13 +162,13 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) 
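[Editor's aside] These assertions pair a Catalyst predicate with the parquet-mr filter class it should compile to. As a stand-alone sketch of the user-facing side of the same feature, the snippet below writes a tiny Parquet file, enables pushdown, and inspects the plan; the path, column name, and expected plan output are illustrative assumptions rather than part of the suite.

    import org.apache.spark.sql.SQLContext

    def inspectPushdown(sqlContext: SQLContext, path: String): Unit = {
      import sqlContext.implicits._

      // Write a small single-column dataset so the example is self-contained.
      sqlContext.range(100).select('id as 'value).write.mode("overwrite").parquet(path)

      sqlContext.setConf("spark.sql.parquet.filterPushdown", "true")
      val filtered = sqlContext.read.parquet(path).filter('value > 90)

      // When the predicate is handed to the Parquet reader, the physical plan printed here
      // should list it among the pushed filters (e.g. GreaterThan(value,90)).
      filtered.explain(true)
      assert(filtered.count() == 9)
    }
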
checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(1) <=> '_1, classOf[Eq[_]], 1) checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } } @@ -196,6 +179,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 <=> 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) @@ -204,13 +188,13 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(1) <=> '_1, classOf[Eq[_]], 1) checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } } @@ -221,6 +205,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 <=> 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) @@ -229,24 +214,26 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(1) <=> '_1, classOf[Eq[_]], 1) checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } } - test("filter pushdown - string") { + // See https://issues.apache.org/jira/browse/SPARK-11153 + ignore("filter pushdown - string") { withParquetDataFrame((1 to 4).map(i => Tuple1(i.toString))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate( '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.toString))) checkFilterPredicate('_1 === "1", classOf[Eq[_]], "1") + checkFilterPredicate('_1 <=> "1", classOf[Eq[_]], "1") checkFilterPredicate( '_1 !== "1", classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.toString))) @@ -256,24 +243,26 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { checkFilterPredicate('_1 >= "4", classOf[GtEq[_]], "4") checkFilterPredicate(Literal("1") === '_1, 
classOf[Eq[_]], "1") + checkFilterPredicate(Literal("1") <=> '_1, classOf[Eq[_]], "1") checkFilterPredicate(Literal("2") > '_1, classOf[Lt[_]], "1") checkFilterPredicate(Literal("3") < '_1, classOf[Gt[_]], "4") checkFilterPredicate(Literal("1") >= '_1, classOf[LtEq[_]], "1") checkFilterPredicate(Literal("4") <= '_1, classOf[GtEq[_]], "4") checkFilterPredicate(!('_1 < "4"), classOf[GtEq[_]], "4") - checkFilterPredicate('_1 > "2" && '_1 < "4", classOf[Operators.And], "3") checkFilterPredicate('_1 < "2" || '_1 > "3", classOf[Operators.Or], Seq(Row("1"), Row("4"))) } } - test("filter pushdown - binary") { + // See https://issues.apache.org/jira/browse/SPARK-11153 + ignore("filter pushdown - binary") { implicit class IntToBinary(int: Int) { def b: Array[Byte] = int.toString.getBytes("UTF-8") } withParquetDataFrame((1 to 4).map(i => Tuple1(i.b))) { implicit df => checkBinaryFilterPredicate('_1 === 1.b, classOf[Eq[_]], 1.b) + checkBinaryFilterPredicate('_1 <=> 1.b, classOf[Eq[_]], 1.b) checkBinaryFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkBinaryFilterPredicate( @@ -288,20 +277,20 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { checkBinaryFilterPredicate('_1 >= 4.b, classOf[GtEq[_]], 4.b) checkBinaryFilterPredicate(Literal(1.b) === '_1, classOf[Eq[_]], 1.b) + checkBinaryFilterPredicate(Literal(1.b) <=> '_1, classOf[Eq[_]], 1.b) checkBinaryFilterPredicate(Literal(2.b) > '_1, classOf[Lt[_]], 1.b) checkBinaryFilterPredicate(Literal(3.b) < '_1, classOf[Gt[_]], 4.b) checkBinaryFilterPredicate(Literal(1.b) >= '_1, classOf[LtEq[_]], 1.b) checkBinaryFilterPredicate(Literal(4.b) <= '_1, classOf[GtEq[_]], 4.b) checkBinaryFilterPredicate(!('_1 < 4.b), classOf[GtEq[_]], 4.b) - checkBinaryFilterPredicate('_1 > 2.b && '_1 < 4.b, classOf[Operators.And], 3.b) checkBinaryFilterPredicate( '_1 < 2.b || '_1 > 3.b, classOf[Operators.Or], Seq(Row(1.b), Row(4.b))) } } test("SPARK-6554: don't push down predicates which reference partition columns") { - import sqlContext.implicits._ + import testImplicits._ withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { withTempPath { dir => @@ -311,9 +300,66 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { // If the "part = 1" filter gets pushed down, this query will throw an exception since // "part" is not a valid column in the actual Parquet file checkAnswer( - sqlContext.read.parquet(path).filter("part = 1"), + sqlContext.read.parquet(dir.getCanonicalPath).filter("part = 1"), (1 to 3).map(i => Row(i, i.toString, 1))) } } } + + test("SPARK-10829: Filter combine partition key and attribute doesn't work in DataSource scan") { + import testImplicits._ + + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { + withTempPath { dir => + val path = s"${dir.getCanonicalPath}/part=1" + (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path) + + // If the "part = 1" filter gets pushed down, this query will throw an exception since + // "part" is not a valid column in the actual Parquet file + checkAnswer( + sqlContext.read.parquet(dir.getCanonicalPath).filter("a > 0 and (part = 0 or a > 1)"), + (2 to 3).map(i => Row(i, i.toString, 1))) + } + } + } + + test("SPARK-11103: Filter applied on merged Parquet schema with new column fails") { + import testImplicits._ + + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true", + SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true") { + withTempPath { dir => + val pathOne = s"${dir.getCanonicalPath}/table1" + (1 to 3).map(i => (i, 
i.toString)).toDF("a", "b").write.parquet(pathOne) + val pathTwo = s"${dir.getCanonicalPath}/table2" + (1 to 3).map(i => (i, i.toString)).toDF("c", "b").write.parquet(pathTwo) + + // If the "c = 1" filter gets pushed down, this query will throw an exception which + // Parquet emits. This is a Parquet issue (PARQUET-389). + checkAnswer( + sqlContext.read.parquet(pathOne, pathTwo).filter("c = 1").selectExpr("c", "b", "a"), + (1 to 1).map(i => Row(i, i.toString, null))) + } + } + } + + // The unsafe row RecordReader does not support row by row filtering so run it with it disabled. + test("SPARK-11661 Still pushdown filters returned by unhandledFilters") { + import testImplicits._ + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { + withSQLConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key -> "false") { + withTempPath { dir => + val path = s"${dir.getCanonicalPath}/part=1" + (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path) + val df = sqlContext.read.parquet(path).filter("a = 2") + + // The result should be single row. + // When a filter is pushed to Parquet, Parquet can apply it to every row. + // So, we can check the number of rows returned from the Parquet + // to make sure our filter pushdown work. + assert(stripSparkFilter(df).count == 1) + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala similarity index 53% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index b415da5b8c13..0c5d4887ed79 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -15,9 +15,11 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet -import scala.collection.JavaConversions._ +import org.apache.parquet.column.{Encoding, ParquetProperties} + +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag @@ -26,10 +28,10 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} import org.apache.parquet.example.data.simple.SimpleGroup import org.apache.parquet.example.data.{Group, GroupWriter} +import org.apache.parquet.hadoop._ import org.apache.parquet.hadoop.api.WriteSupport import org.apache.parquet.hadoop.api.WriteSupport.WriteContext -import org.apache.parquet.hadoop.metadata.{CompressionCodecName, FileMetaData, ParquetMetadata} -import org.apache.parquet.hadoop.{Footer, ParquetFileWriter, ParquetOutputCommitter, ParquetWriter} +import org.apache.parquet.hadoop.metadata.CompressionCodecName import org.apache.parquet.io.api.RecordConsumer import org.apache.parquet.schema.{MessageType, MessageTypeParser} @@ -37,6 +39,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ // Write support class for nested groups: ParquetWriter initializes GroupWriteSupport @@ -62,9 +65,8 @@ private[parquet] class TestGroupWriteSupport(schema: MessageType) extends WriteS /** * A test suite that tests basic Parquet I/O. */ -class ParquetIOSuite extends QueryTest with ParquetTest { - lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.implicits._ +class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { + import testImplicits._ /** * Writes `data` to a Parquet file, reads it back and check file contents. 
@@ -89,6 +91,33 @@ class ParquetIOSuite extends QueryTest with ParquetTest { } } + test("SPARK-11694 Parquet logical types are not being tested properly") { + val parquetSchema = MessageTypeParser.parseMessageType( + """message root { + | required int32 a(INT_8); + | required int32 b(INT_16); + | required int32 c(DATE); + | required int32 d(DECIMAL(1,0)); + | required int64 e(DECIMAL(10,0)); + | required binary f(UTF8); + | required binary g(ENUM); + | required binary h(DECIMAL(32,0)); + | required fixed_len_byte_array(32) i(DECIMAL(32,0)); + |} + """.stripMargin) + + val expectedSparkTypes = Seq(ByteType, ShortType, DateType, DecimalType(1, 0), + DecimalType(10, 0), StringType, StringType, DecimalType(32, 0), DecimalType(32, 0)) + + withTempPath { location => + val path = new Path(location.getCanonicalPath) + val conf = sparkContext.hadoopConfiguration + writeMetadata(parquetSchema, path, conf) + val sparkTypes = sqlContext.read.parquet(path.toString).schema.map(_.dataType) + assert(sparkTypes === expectedSparkTypes) + } + } + test("string") { val data = (1 to 4).map(i => Tuple1(i.toString)) // Property spark.sql.parquet.binaryAsString shouldn't affect Parquet files written by Spark SQL @@ -97,16 +126,18 @@ class ParquetIOSuite extends QueryTest with ParquetTest { withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "true")(checkParquetFile(data)) } - test("fixed-length decimals") { - def makeDecimalRDD(decimal: DecimalType): DataFrame = - sqlContext.sparkContext - .parallelize(0 to 1000) - .map(i => Tuple1(i / 100.0)) - .toDF() - // Parquet doesn't allow column names with spaces, have to add an alias here - .select($"_1" cast decimal as "dec") + testStandardAndLegacyModes("fixed-length decimals") { + def makeDecimalRDD(decimal: DecimalType): DataFrame = { + sqlContext + .range(1000) + // Parquet doesn't allow column names with spaces, have to add an alias here. + // Minus 500 here so that negative decimals are also tested. 
+ .select((('id - 500) / 100.0) cast decimal as 'dec) + .coalesce(1) + } - for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17), (19, 0), (38, 37))) { + val combinations = Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17), (19, 0), (38, 37)) + for ((precision, scale) <- combinations) { withTempPath { dir => val data = makeDecimalRDD(DecimalType(precision, scale)) data.write.parquet(dir.getCanonicalPath) @@ -117,7 +148,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest { test("date type") { def makeDateRDD(): DataFrame = - sqlContext.sparkContext + sparkContext .parallelize(0 to 1000) .map(i => Tuple1(DateTimeUtils.toJavaDate(i))) .toDF() @@ -130,22 +161,22 @@ class ParquetIOSuite extends QueryTest with ParquetTest { } } - test("map") { + testStandardAndLegacyModes("map") { val data = (1 to 4).map(i => Tuple1(Map(i -> s"val_$i"))) checkParquetFile(data) } - test("array") { + testStandardAndLegacyModes("array") { val data = (1 to 4).map(i => Tuple1(Seq(i, i + 1))) checkParquetFile(data) } - test("array and double") { + testStandardAndLegacyModes("array and double") { val data = (1 to 4).map(i => (i.toDouble, Seq(i.toDouble, (i + 1).toDouble))) checkParquetFile(data) } - test("struct") { + testStandardAndLegacyModes("struct") { val data = (1 to 4).map(i => Tuple1((i, s"val_$i"))) withParquetDataFrame(data) { df => // Structs are converted to `Row`s @@ -155,7 +186,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest { } } - test("nested struct with array of array as field") { + testStandardAndLegacyModes("nested struct with array of array as field") { val data = (1 to 4).map(i => Tuple1((i, Seq(Seq(s"val_$i"))))) withParquetDataFrame(data) { df => // Structs are converted to `Row`s @@ -165,7 +196,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest { } } - test("nested map with struct as value type") { + testStandardAndLegacyModes("nested map with struct as value type") { val data = (1 to 4).map(i => Tuple1(Map(i -> (i, s"val_$i")))) withParquetDataFrame(data) { df => checkAnswer(df, data.map { case Tuple1(m) => @@ -202,16 +233,52 @@ class ParquetIOSuite extends QueryTest with ParquetTest { } } + test("SPARK-10113 Support for unsigned Parquet logical types") { + val parquetSchema = MessageTypeParser.parseMessageType( + """message root { + | required int32 c(UINT_32); + |} + """.stripMargin) + + withTempPath { location => + val path = new Path(location.getCanonicalPath) + val conf = sparkContext.hadoopConfiguration + writeMetadata(parquetSchema, path, conf) + val errorMessage = intercept[Throwable] { + sqlContext.read.parquet(path.toString).printSchema() + }.toString + assert(errorMessage.contains("Parquet type not supported")) + } + } + + test("SPARK-11692 Support for Parquet logical types, JSON and BSON (embedded types)") { + val parquetSchema = MessageTypeParser.parseMessageType( + """message root { + | required binary a(JSON); + | required binary b(BSON); + |} + """.stripMargin) + + val expectedSparkTypes = Seq(StringType, BinaryType) + + withTempPath { location => + val path = new Path(location.getCanonicalPath) + val conf = sparkContext.hadoopConfiguration + writeMetadata(parquetSchema, path, conf) + val sparkTypes = sqlContext.read.parquet(path.toString).schema.map(_.dataType) + assert(sparkTypes === expectedSparkTypes) + } + } + test("compression codec") { - def compressionCodecFor(path: String): String = { - val codecs = ParquetTypesConverter - .readMetaData(new Path(path), Some(configuration)) - .getBlocks - .flatMap(_.getColumns) - 
.map(_.getCodec.name()) - .distinct - - assert(codecs.size === 1) + def compressionCodecFor(path: String, codecName: String): String = { + val codecs = for { + footer <- readAllFootersWithoutSummaryFiles(new Path(path), hadoopConfiguration) + block <- footer.getParquetMetadata.getBlocks.asScala + column <- block.getColumns.asScala + } yield column.getCodec.name() + + assert(codecs.distinct === Seq(codecName)) codecs.head } @@ -221,7 +288,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest { withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> codec.name()) { withParquetFile(data) { path => assertResult(sqlContext.conf.parquetCompressionCodec.toUpperCase) { - compressionCodecFor(path) + compressionCodecFor(path, codec.name()) } } } @@ -276,16 +343,15 @@ class ParquetIOSuite extends QueryTest with ParquetTest { test("write metadata") { withTempPath { file => val path = new Path(file.toURI.toString) - val fs = FileSystem.getLocal(configuration) - val attributes = ScalaReflection.attributesFor[(Int, String)] - ParquetTypesConverter.writeMetaData(attributes, path, configuration) + val fs = FileSystem.getLocal(hadoopConfiguration) + val schema = StructType.fromAttributes(ScalaReflection.attributesFor[(Int, String)]) + writeMetadata(schema, path, hadoopConfiguration) assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_COMMON_METADATA_FILE))) assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE))) - val metaData = ParquetTypesConverter.readMetaData(path, Some(configuration)) - val actualSchema = metaData.getFileMetaData.getSchema - val expectedSchema = ParquetTypesConverter.convertFromAttributes(attributes) + val expectedSchema = new CatalystSchemaConverter().convert(schema) + val actualSchema = readFooter(path, hadoopConfiguration).getFileMetaData.getSchema actualSchema.checkContains(expectedSchema) expectedSchema.checkContains(actualSchema) @@ -349,13 +415,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest { withTempPath { location => val extraMetadata = Map(CatalystReadSupport.SPARK_METADATA_KEY -> sparkSchema.toString) - val fileMetadata = new FileMetaData(parquetSchema, extraMetadata, "Spark") val path = new Path(location.getCanonicalPath) - - ParquetFileWriter.writeMetadataFile( - sqlContext.sparkContext.hadoopConfiguration, - path, - new Footer(path, new ParquetMetadata(fileMetadata, Nil)) :: Nil) + val conf = sparkContext.hadoopConfiguration + writeMetadata(parquetSchema, path, conf, extraMetadata) assertResult(sqlContext.read.parquet(path.toString).schema) { StructType( @@ -367,12 +429,36 @@ class ParquetIOSuite extends QueryTest with ParquetTest { } test("SPARK-6352 DirectParquetOutputCommitter") { - val clonedConf = new Configuration(configuration) + val clonedConf = new Configuration(hadoopConfiguration) + + // Write to a parquet file and let it fail. + // _temporary should be missing if direct output committer works. 
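[Editor's aside] The "compression codec" test above now derives the codec from the Parquet footers instead of ParquetTypesConverter. For reference, the writer side is driven entirely by spark.sql.parquet.compression.codec; a minimal hedged sketch follows, with the output path being an illustrative assumption.

    import org.apache.spark.sql.SQLContext

    def writeWithCodec(sqlContext: SQLContext, path: String): Unit = {
      // Accepted values in this Spark line include "uncompressed", "snappy", "gzip" and "lzo".
      sqlContext.setConf("spark.sql.parquet.compression.codec", "gzip")

      // Every Parquet file written while the setting is active uses the chosen codec.
      sqlContext.range(1000).write.mode("overwrite").parquet(path)

      // Reading back needs no codec configuration; it is recorded in the file metadata.
      assert(sqlContext.read.parquet(path).count() == 1000)
    }
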
+ try { + hadoopConfiguration.set("spark.sql.parquet.output.committer.class", + classOf[DirectParquetOutputCommitter].getCanonicalName) + sqlContext.udf.register("div0", (x: Int) => x / 0) + withTempPath { dir => + intercept[org.apache.spark.SparkException] { + sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath) + } + val path = new Path(dir.getCanonicalPath, "_temporary") + val fs = path.getFileSystem(hadoopConfiguration) + assert(!fs.exists(path)) + } + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) + } + } + + test("SPARK-9849 DirectParquetOutputCommitter qualified name should be backward compatible") { + val clonedConf = new Configuration(hadoopConfiguration) // Write to a parquet file and let it fail. // _temporary should be missing if direct output committer works. try { - configuration.set("spark.sql.parquet.output.committer.class", + hadoopConfiguration.set("spark.sql.parquet.output.committer.class", "org.apache.spark.sql.parquet.DirectParquetOutputCommitter") sqlContext.udf.register("div0", (x: Int) => x / 0) withTempPath { dir => @@ -380,26 +466,27 @@ class ParquetIOSuite extends QueryTest with ParquetTest { sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath) } val path = new Path(dir.getCanonicalPath, "_temporary") - val fs = path.getFileSystem(configuration) + val fs = path.getFileSystem(hadoopConfiguration) assert(!fs.exists(path)) } } finally { // Hadoop 1 doesn't have `Configuration.unset` - configuration.clear() - clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) } } - test("SPARK-8121: spark.sql.parquet.output.committer.class shouldn't be overriden") { + + test("SPARK-8121: spark.sql.parquet.output.committer.class shouldn't be overridden") { withTempPath { dir => - val clonedConf = new Configuration(configuration) + val clonedConf = new Configuration(hadoopConfiguration) - configuration.set( + hadoopConfiguration.set( SQLConf.OUTPUT_COMMITTER_CLASS.key, classOf[ParquetOutputCommitter].getCanonicalName) - configuration.set( + hadoopConfiguration.set( "spark.sql.parquet.output.committer.class", - classOf[BogusParquetOutputCommitter].getCanonicalName) + classOf[JobCommitFailureParquetOutputCommitter].getCanonicalName) try { val message = intercept[SparkException] { @@ -408,8 +495,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest { assert(message === "Intentional exception for testing purposes") } finally { // Hadoop 1 doesn't have `Configuration.unset` - configuration.clear() - clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) } } } @@ -425,12 +512,126 @@ class ParquetIOSuite extends QueryTest with ParquetTest { }.toString assert(errorMessage.contains("UnknownHostException")) } + + test("SPARK-7837 Do not close output writer twice when commitTask() fails") { + val clonedConf = new Configuration(hadoopConfiguration) + + // Using a output committer that always fail when committing a task, so that both + // `commitTask()` and `abortTask()` are invoked. 
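[Editor's aside] The committer tests work by pointing spark.sql.parquet.output.committer.class at committers that fail on purpose. A hedged sketch of the same hook with a benign committer is shown below; LoggingParquetOutputCommitter is a hypothetical class written only to show the required shape (a subclass of ParquetOutputCommitter with the (Path, TaskAttemptContext) constructor).

    import org.apache.hadoop.fs.Path
    import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext}
    import org.apache.parquet.hadoop.ParquetOutputCommitter
    import org.apache.spark.sql.SQLContext

    // A benign committer: it only delegates and logs, unlike the failing committers used here.
    class LoggingParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext)
      extends ParquetOutputCommitter(outputPath, context) {

      override def commitJob(jobContext: JobContext): Unit = {
        super.commitJob(jobContext)
        println(s"Committed Parquet job under $outputPath")
      }
    }

    object CommitterRegistration {
      // The property is read from the Hadoop configuration when a Parquet write is planned.
      def install(sqlContext: SQLContext): Unit = {
        sqlContext.sparkContext.hadoopConfiguration.set(
          "spark.sql.parquet.output.committer.class",
          classOf[LoggingParquetOutputCommitter].getCanonicalName)
      }
    }
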
+ hadoopConfiguration.set( + "spark.sql.parquet.output.committer.class", + classOf[TaskCommitFailureParquetOutputCommitter].getCanonicalName) + + try { + // Before fixing SPARK-7837, the following code results in an NPE because both + // `commitTask()` and `abortTask()` try to close output writers. + + withTempPath { dir => + val m1 = intercept[SparkException] { + sqlContext.range(1).coalesce(1).write.parquet(dir.getCanonicalPath) + }.getCause.getMessage + assert(m1.contains("Intentional exception for testing purposes")) + } + + withTempPath { dir => + val m2 = intercept[SparkException] { + val df = sqlContext.range(1).select('id as 'a, 'id as 'b).coalesce(1) + df.write.partitionBy("a").parquet(dir.getCanonicalPath) + }.getCause.getMessage + assert(m2.contains("Intentional exception for testing purposes")) + } + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) + } + } + + test("SPARK-11044 Parquet writer version fixed as version1 ") { + // For dictionary encoding, Parquet changes the encoding types according to its writer + // version. So, this test checks one of the encoding types in order to ensure that + // the file is written with writer version2. + withTempPath { dir => + val clonedConf = new Configuration(hadoopConfiguration) + try { + // Write a Parquet file with writer version2. + hadoopConfiguration.set(ParquetOutputFormat.WRITER_VERSION, + ParquetProperties.WriterVersion.PARQUET_2_0.toString) + + // By default, dictionary encoding is enabled from Parquet 1.2.0 but + // it is enabled just in case. + hadoopConfiguration.setBoolean(ParquetOutputFormat.ENABLE_DICTIONARY, true) + val path = s"${dir.getCanonicalPath}/part-r-0.parquet" + sqlContext.range(1 << 16).selectExpr("(id % 4) AS i") + .coalesce(1).write.mode("overwrite").parquet(path) + + val blockMetadata = readFooter(new Path(path), hadoopConfiguration).getBlocks.asScala.head + val columnChunkMetadata = blockMetadata.getColumns.asScala.head + + // If the file is written with version2, this should include + // Encoding.RLE_DICTIONARY type. For version1, it is Encoding.PLAIN_DICTIONARY + assert(columnChunkMetadata.getEncodings.contains(Encoding.RLE_DICTIONARY)) + } finally { + // Manually clear the hadoop configuration for other tests. + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) + } + } + } + + test("null and non-null strings") { + // Create a dataset where the first values are NULL and then some non-null values. The + // number of non-nulls needs to be bigger than the ParquetReader batch size. 
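[Editor's aside] The SPARK-11044 test drives the Parquet writer version through the Hadoop configuration and then checks the column-chunk encodings in the footer. A trimmed-down sketch of just the writing side, distilled from the test's own calls (only the output path is an illustrative assumption):

    import org.apache.parquet.column.ParquetProperties
    import org.apache.parquet.hadoop.ParquetOutputFormat
    import org.apache.spark.sql.SQLContext

    def writeWithWriterV2(sqlContext: SQLContext, path: String): Unit = {
      val hadoopConf = sqlContext.sparkContext.hadoopConfiguration

      // Ask parquet-mr for version-2 data pages and keep dictionary encoding on.
      hadoopConf.set(ParquetOutputFormat.WRITER_VERSION,
        ParquetProperties.WriterVersion.PARQUET_2_0.toString)
      hadoopConf.setBoolean(ParquetOutputFormat.ENABLE_DICTIONARY, true)

      // A low-cardinality column is a good candidate for dictionary encoding.
      sqlContext.range(1 << 16).selectExpr("(id % 4) AS i")
        .coalesce(1).write.mode("overwrite").parquet(path)
    }
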
+ val data = sqlContext.range(200).map { i => + if (i.getLong(0) < 150) Row(None) + else Row("a") + } + val df = sqlContext.createDataFrame(data, StructType(StructField("col", StringType) :: Nil)) + assert(df.agg("col" -> "count").collect().head.getLong(0) == 50) + + withTempPath { dir => + val path = s"${dir.getCanonicalPath}/data" + df.write.parquet(path) + + val df2 = sqlContext.read.parquet(path) + assert(df2.agg("col" -> "count").collect().head.getLong(0) == 50) + } + } + + test("read dictionary encoded decimals written as INT32") { + checkAnswer( + // Decimal column in this file is encoded using plain dictionary + readResourceParquetFile("dec-in-i32.parquet"), + sqlContext.range(1 << 4).select('id % 10 cast DecimalType(5, 2) as 'i32_dec)) + } + + test("read dictionary encoded decimals written as INT64") { + checkAnswer( + // Decimal column in this file is encoded using plain dictionary + readResourceParquetFile("dec-in-i64.parquet"), + sqlContext.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'i64_dec)) + } + + test("read dictionary encoded decimals written as FIXED_LEN_BYTE_ARRAY") { + checkAnswer( + // Decimal column in this file is encoded using plain dictionary + readResourceParquetFile("dec-in-fixed-len.parquet"), + sqlContext.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'fixed_len_dec)) + } } -class BogusParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) +class JobCommitFailureParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) extends ParquetOutputCommitter(outputPath, context) { override def commitJob(jobContext: JobContext): Unit = { sys.error("Intentional exception for testing purposes") } } + +class TaskCommitFailureParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) + extends ParquetOutputCommitter(outputPath, context) { + + override def commitTask(context: TaskAttemptContext): Unit = { + sys.error("Intentional exception for testing purposes") + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala new file mode 100644 index 000000000000..83b65fb419ed --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.io.File + +import org.apache.spark.sql.Row +import org.apache.spark.sql.test.SharedSQLContext + +class ParquetInteroperabilitySuite extends ParquetCompatibilityTest with SharedSQLContext { + test("parquet files with different physical schemas but share the same logical schema") { + import ParquetCompatibilityTest._ + + // This test case writes two Parquet files, both representing the following Catalyst schema + // + // StructType( + // StructField( + // "f", + // ArrayType(IntegerType, containsNull = false), + // nullable = false)) + // + // The first Parquet file comes with parquet-avro style 2-level LIST-annotated group, while the + // other one comes with parquet-protobuf style 1-level unannotated primitive field. + withTempDir { dir => + val avroStylePath = new File(dir, "avro-style").getCanonicalPath + val protobufStylePath = new File(dir, "protobuf-style").getCanonicalPath + + val avroStyleSchema = + """message avro_style { + | required group f (LIST) { + | repeated int32 array; + | } + |} + """.stripMargin + + writeDirect(avroStylePath, avroStyleSchema, { rc => + rc.message { + rc.field("f", 0) { + rc.group { + rc.field("array", 0) { + rc.addInteger(0) + rc.addInteger(1) + } + } + } + } + }) + + logParquetSchema(avroStylePath) + + val protobufStyleSchema = + """message protobuf_style { + | repeated int32 f; + |} + """.stripMargin + + writeDirect(protobufStylePath, protobufStyleSchema, { rc => + rc.message { + rc.field("f", 0) { + rc.addInteger(2) + rc.addInteger(3) + } + } + }) + + logParquetSchema(protobufStylePath) + + checkAnswer( + sqlContext.read.parquet(dir.getCanonicalPath), + Seq( + Row(Seq(0, 1)), + Row(Seq(2, 3)))) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala similarity index 78% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 2eef10189f11..71e9034d9779 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.io.File import java.math.BigInteger @@ -26,13 +26,13 @@ import scala.collection.mutable.ArrayBuffer import com.google.common.io.Files import org.apache.hadoop.fs.Path +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.execution.datasources.{LogicalRelation, PartitionSpec, Partition, PartitioningUtils} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.sql._ import org.apache.spark.unsafe.types.UTF8String -import PartitioningUtils._ // The data where the partitioning key exists only in the directory structure. 
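[Editor's aside] The partition discovery suite that follows exercises the key=value directory convention. As a quick illustration of the layout these tests parse, a hedged sketch (the column names and base path are illustrative):

    import org.apache.spark.sql.SQLContext

    def partitionedRoundTrip(sqlContext: SQLContext, basePath: String): Unit = {
      import sqlContext.implicits._

      val df = (1 to 6).map(i => (i, i.toString, i % 2)).toDF("intField", "stringField", "pi")

      // partitionBy produces one sub-directory per distinct value, e.g. <basePath>/pi=0 and /pi=1.
      df.write.mode("overwrite").partitionBy("pi").parquet(basePath)

      // Reading the root rediscovers "pi" as a partition column and infers its type from the
      // directory names (IntegerType here, unless type inference is disabled).
      val readBack = sqlContext.read.parquet(basePath)
      assert(readBack.columns.toSet == Set("intField", "stringField", "pi"))
      assert(readBack.count() == 6)
    }
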
case class ParquetData(intField: Int, stringField: String) @@ -40,11 +40,9 @@ case class ParquetData(intField: Int, stringField: String) // The data that also includes the partitioning key case class ParquetDataWithKey(intField: Int, pi: Int, stringField: String, ps: String) -class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { - - override lazy val sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.implicits._ - import sqlContext.sql +class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with SharedSQLContext { + import PartitioningUtils._ + import testImplicits._ val defaultPartitionName = "__HIVE_DEFAULT_PARTITION__" @@ -60,14 +58,101 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { check(defaultPartitionName, Literal.create(null, NullType)) } + test("parse invalid partitioned directories") { + // Invalid + var paths = Seq( + "hdfs://host:9000/invalidPath", + "hdfs://host:9000/path/a=10/b=20", + "hdfs://host:9000/path/a=10.5/b=hello") + + var exception = intercept[AssertionError] { + parsePartitions(paths.map(new Path(_)), defaultPartitionName, true, Set.empty[Path]) + } + assert(exception.getMessage().contains("Conflicting directory structures detected")) + + // Valid + paths = Seq( + "hdfs://host:9000/path/_temporary", + "hdfs://host:9000/path/a=10/b=20", + "hdfs://host:9000/path/_temporary/path") + + parsePartitions( + paths.map(new Path(_)), + defaultPartitionName, + true, + Set(new Path("hdfs://host:9000/path/"))) + + // Valid + paths = Seq( + "hdfs://host:9000/path/something=true/table/", + "hdfs://host:9000/path/something=true/table/_temporary", + "hdfs://host:9000/path/something=true/table/a=10/b=20", + "hdfs://host:9000/path/something=true/table/_temporary/path") + + parsePartitions( + paths.map(new Path(_)), + defaultPartitionName, + true, + Set(new Path("hdfs://host:9000/path/something=true/table"))) + + // Valid + paths = Seq( + "hdfs://host:9000/path/table=true/", + "hdfs://host:9000/path/table=true/_temporary", + "hdfs://host:9000/path/table=true/a=10/b=20", + "hdfs://host:9000/path/table=true/_temporary/path") + + parsePartitions( + paths.map(new Path(_)), + defaultPartitionName, + true, + Set(new Path("hdfs://host:9000/path/table=true"))) + + // Invalid + paths = Seq( + "hdfs://host:9000/path/_temporary", + "hdfs://host:9000/path/a=10/b=20", + "hdfs://host:9000/path/path1") + + exception = intercept[AssertionError] { + parsePartitions( + paths.map(new Path(_)), + defaultPartitionName, + true, + Set(new Path("hdfs://host:9000/path/"))) + } + assert(exception.getMessage().contains("Conflicting directory structures detected")) + + // Invalid + // Conflicting directory structure: + // "hdfs://host:9000/tmp/tables/partitionedTable" + // "hdfs://host:9000/tmp/tables/nonPartitionedTable1" + // "hdfs://host:9000/tmp/tables/nonPartitionedTable2" + paths = Seq( + "hdfs://host:9000/tmp/tables/partitionedTable", + "hdfs://host:9000/tmp/tables/partitionedTable/p=1/", + "hdfs://host:9000/tmp/tables/nonPartitionedTable1", + "hdfs://host:9000/tmp/tables/nonPartitionedTable2") + + exception = intercept[AssertionError] { + parsePartitions( + paths.map(new Path(_)), + defaultPartitionName, + true, + Set(new Path("hdfs://host:9000/tmp/tables/"))) + } + assert(exception.getMessage().contains("Conflicting directory structures detected")) + } + test("parse partition") { def check(path: String, expected: Option[PartitionValues]): Unit = { - assert(expected === parsePartition(new Path(path), 
defaultPartitionName, true)) + val actual = parsePartition(new Path(path), defaultPartitionName, true, Set.empty[Path])._1 + assert(expected === actual) } def checkThrows[T <: Throwable: Manifest](path: String, expected: String): Unit = { val message = intercept[T] { - parsePartition(new Path(path), defaultPartitionName, true).get + parsePartition(new Path(path), defaultPartitionName, true, Set.empty[Path]) }.getMessage assert(message.contains(expected)) @@ -106,8 +191,17 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { } test("parse partitions") { - def check(paths: Seq[String], spec: PartitionSpec): Unit = { - assert(parsePartitions(paths.map(new Path(_)), defaultPartitionName, true) === spec) + def check( + paths: Seq[String], + spec: PartitionSpec, + rootPaths: Set[Path] = Set.empty[Path]): Unit = { + val actualSpec = + parsePartitions( + paths.map(new Path(_)), + defaultPartitionName, + true, + rootPaths) + assert(actualSpec === spec) } check(Seq( @@ -186,7 +280,9 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { test("parse partitions with type inference disabled") { def check(paths: Seq[String], spec: PartitionSpec): Unit = { - assert(parsePartitions(paths.map(new Path(_)), defaultPartitionName, false) === spec) + val actualSpec = + parsePartitions(paths.map(new Path(_)), defaultPartitionName, false, Set.empty[Path]) + assert(actualSpec === spec) } check(Seq( @@ -467,7 +563,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { (1 to 10).map(i => (i, i.toString)).toDF("a", "b").write.parquet(dir.getCanonicalPath) val queryExecution = sqlContext.read.parquet(dir.getCanonicalPath).queryExecution queryExecution.analyzed.collectFirst { - case LogicalRelation(relation: ParquetRelation) => + case LogicalRelation(relation: ParquetRelation, _) => assert(relation.partitionSpec === PartitionSpec.emptySpec) }.getOrElse { fail(s"Expecting a ParquetRelation2, but got:\n$queryExecution") @@ -519,7 +615,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { } val schema = StructType(partitionColumns :+ StructField(s"i", StringType)) - val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(row :: Nil), schema) + val df = sqlContext.createDataFrame(sparkContext.parallelize(row :: Nil), schema) withTempPath { dir => df.write.format("parquet").partitionBy(partitionColumns.map(_.name): _*).save(dir.toString) @@ -544,6 +640,70 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { } } + test("SPARK-11678: Partition discovery stops at the root path of the dataset") { + withTempPath { dir => + val tablePath = new File(dir, "key=value") + val df = (1 to 3).map(i => (i, i, i, i)).toDF("a", "b", "c", "d") + + df.write + .format("parquet") + .partitionBy("b", "c", "d") + .save(tablePath.getCanonicalPath) + + Files.touch(new File(s"${tablePath.getCanonicalPath}/", "_SUCCESS")) + Files.createParentDirs(new File(s"${dir.getCanonicalPath}/b=1/c=1/.foo/bar")) + + checkAnswer(sqlContext.read.format("parquet").load(tablePath.getCanonicalPath), df) + } + + withTempPath { dir => + val path = new File(dir, "key=value") + val tablePath = new File(path, "table") + + val df = (1 to 3).map(i => (i, i, i, i)).toDF("a", "b", "c", "d") + + df.write + .format("parquet") + .partitionBy("b", "c", "d") + .save(tablePath.getCanonicalPath) + + Files.touch(new File(s"${tablePath.getCanonicalPath}/", "_SUCCESS")) + Files.createParentDirs(new File(s"${dir.getCanonicalPath}/b=1/c=1/.foo/bar")) + 
+ checkAnswer(sqlContext.read.format("parquet").load(tablePath.getCanonicalPath), df) + } + } + + test("use basePath to specify the root dir of a partitioned table.") { + withTempPath { dir => + val tablePath = new File(dir, "table") + val df = (1 to 3).map(i => (i, i, i, i)).toDF("a", "b", "c", "d") + + df.write + .format("parquet") + .partitionBy("b", "c", "d") + .save(tablePath.getCanonicalPath) + + val twoPartitionsDF = + sqlContext + .read + .option("basePath", tablePath.getCanonicalPath) + .parquet( + s"${tablePath.getCanonicalPath}/b=1", + s"${tablePath.getCanonicalPath}/b=2") + + checkAnswer(twoPartitionsDF, df.filter("b != 3")) + + intercept[AssertionError] { + sqlContext + .read + .parquet( + s"${tablePath.getCanonicalPath}/b=1", + s"${tablePath.getCanonicalPath}/b=2") + } + } + } + test("listConflictingPartitionColumns") { def makeExpectedMessage(colNameLists: Seq[String], paths: Seq[String]): String = { val conflictingColNameLists = colNameLists.zipWithIndex.map { case (list, index) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala new file mode 100644 index 000000000000..98333e58cada --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.spark.sql.Row +import org.apache.spark.sql.test.SharedSQLContext + +class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext { + test("unannotated array of primitive type") { + checkAnswer(readResourceParquetFile("old-repeated-int.parquet"), Row(Seq(1, 2, 3))) + } + + test("unannotated array of struct") { + checkAnswer( + readResourceParquetFile("old-repeated-message.parquet"), + Row( + Seq( + Row("First inner", null, null), + Row(null, "Second inner", null), + Row(null, null, "Third inner")))) + + checkAnswer( + readResourceParquetFile("proto-repeated-struct.parquet"), + Row( + Seq( + Row("0 - 1", "0 - 2", "0 - 3"), + Row("1 - 1", "1 - 2", "1 - 3")))) + + checkAnswer( + readResourceParquetFile("proto-struct-with-array-many.parquet"), + Seq( + Row( + Seq( + Row("0 - 0 - 1", "0 - 0 - 2", "0 - 0 - 3"), + Row("0 - 1 - 1", "0 - 1 - 2", "0 - 1 - 3"))), + Row( + Seq( + Row("1 - 0 - 1", "1 - 0 - 2", "1 - 0 - 3"), + Row("1 - 1 - 1", "1 - 1 - 2", "1 - 1 - 3"))), + Row( + Seq( + Row("2 - 0 - 1", "2 - 0 - 2", "2 - 0 - 3"), + Row("2 - 1 - 1", "2 - 1 - 2", "2 - 1 - 3"))))) + } + + test("struct with unannotated array") { + checkAnswer( + readResourceParquetFile("proto-struct-with-array.parquet"), + Row(10, 9, Seq.empty, null, Row(9), Seq(Row(9), Row(10)))) + } + + test("unannotated array of struct with unannotated array") { + checkAnswer( + readResourceParquetFile("nested-array-struct.parquet"), + Seq( + Row(2, Seq(Row(1, Seq(Row(3))))), + Row(5, Seq(Row(4, Seq(Row(6))))), + Row(8, Seq(Row(7, Seq(Row(9))))))) + } + + test("unannotated array of string") { + checkAnswer( + readResourceParquetFile("proto-repeated-string.parquet"), + Seq( + Row(Seq("hello", "world")), + Row(Seq("good", "bye")), + Row(Seq("one", "two", "three")))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala new file mode 100644 index 000000000000..f777e973052d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -0,0 +1,609 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.io.File + +import org.apache.hadoop.fs.Path + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow +import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} +import org.apache.spark.sql.execution.datasources.parquet.TestingUDT.{NestedStruct, NestedStructUDT} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils + +/** + * A test suite that tests various Parquet queries. + */ +class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext { + import testImplicits._ + + test("simple select queries") { + withParquetTable((0 until 10).map(i => (i, i.toString)), "t") { + checkAnswer(sql("SELECT _1 FROM t where t._1 > 5"), (6 until 10).map(Row.apply(_))) + checkAnswer(sql("SELECT _1 FROM t as tmp where tmp._1 < 5"), (0 until 5).map(Row.apply(_))) + } + } + + test("appending") { + val data = (0 until 10).map(i => (i, i.toString)) + sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + withParquetTable(data, "t") { + sql("INSERT INTO TABLE t SELECT * FROM tmp") + checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple)) + } + sqlContext.catalog.unregisterTable(TableIdentifier("tmp")) + } + + test("overwriting") { + val data = (0 until 10).map(i => (i, i.toString)) + sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + withParquetTable(data, "t") { + sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") + checkAnswer(sqlContext.table("t"), data.map(Row.fromTuple)) + } + sqlContext.catalog.unregisterTable(TableIdentifier("tmp")) + } + + test("self-join") { + // 4 rows, cells of column 1 of row 2 and row 4 are null + val data = (1 to 4).map { i => + val maybeInt = if (i % 2 == 0) None else Some(i) + (maybeInt, i.toString) + } + + withParquetTable(data, "t") { + val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x._1 = y._1") + val queryOutput = selfJoin.queryExecution.analyzed.output + + assertResult(4, "Field count mismatches")(queryOutput.size) + assertResult(2, "Duplicated expression ID in query plan:\n $selfJoin") { + queryOutput.filter(_.name == "_1").map(_.exprId).size + } + + checkAnswer(selfJoin, List(Row(1, "1", 1, "1"), Row(3, "3", 3, "3"))) + } + } + + test("nested data - struct with array field") { + val data = (1 to 10).map(i => Tuple1((i, Seq("val_$i")))) + withParquetTable(data, "t") { + checkAnswer(sql("SELECT _1._2[0] FROM t"), data.map { + case Tuple1((_, Seq(string))) => Row(string) + }) + } + } + + test("nested data - array of struct") { + val data = (1 to 10).map(i => Tuple1(Seq(i -> "val_$i"))) + withParquetTable(data, "t") { + checkAnswer(sql("SELECT _1[0]._2 FROM t"), data.map { + case Tuple1(Seq((_, string))) => Row(string) + }) + } + } + + test("SPARK-1913 regression: columns only referenced by pushed down filters should remain") { + withParquetTable((1 to 10).map(Tuple1.apply), "t") { + checkAnswer(sql("SELECT _1 FROM t WHERE _1 < 10"), (1 to 9).map(Row.apply(_))) + } + } + + test("SPARK-5309 strings stored using dictionary compression in parquet") { + withParquetTable((0 until 1000).map(i => ("same", "run_" + i /100, 1)), "t") { + + checkAnswer(sql("SELECT _1, _2, SUM(_3) FROM t GROUP BY _1, _2"), + (0 until 10).map(i => Row("same", "run_" + i, 100))) + + checkAnswer(sql("SELECT _1, _2, SUM(_3) FROM t WHERE _2 = 'run_5' GROUP BY _1, _2"), + List(Row("same", "run_5", 100))) + } + } + + 
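+  // The queries above go through `withParquetTable` from `ParquetTest`. As a rough, hedged
+  // sketch of what that helper amounts to (the name `writeThenRegister` is illustrative and not
+  // part of the suite; `sqlContext` and `withTempPath` come from the mixed-in test traits):
+  // write the rows as real Parquet files, register the result as a temporary table, run the
+  // body, then drop the table.
+  private def writeThenRegister[T <: Product : scala.reflect.runtime.universe.TypeTag](
+      data: Seq[T], tableName: String)(body: => Unit): Unit = {
+    withTempPath { dir =>
+      // Round-trip through Parquet so the SQL below reads actual Parquet files.
+      sqlContext.createDataFrame(data).write.parquet(dir.getCanonicalPath)
+      sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable(tableName)
+      try body finally sqlContext.dropTempTable(tableName)
+    }
+  }
+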
test("SPARK-6917 DecimalType should work with non-native types") { + val data = (1 to 10).map(i => Row(Decimal(i, 18, 0), new java.sql.Timestamp(i))) + val schema = StructType(List(StructField("d", DecimalType(18, 0), false), + StructField("time", TimestampType, false)).toArray) + withTempPath { file => + val df = sqlContext.createDataFrame(sparkContext.parallelize(data), schema) + df.write.parquet(file.getCanonicalPath) + val df2 = sqlContext.read.parquet(file.getCanonicalPath) + checkAnswer(df2, df.collect().toSeq) + } + } + + test("Enabling/disabling merging partfiles when merging parquet schema") { + def testSchemaMerging(expectedColumnNumber: Int): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) + // delete summary files, so if we don't merge part-files, one column will not be included. + Utils.deleteRecursively(new File(basePath + "/foo=1/_metadata")) + Utils.deleteRecursively(new File(basePath + "/foo=1/_common_metadata")) + assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber) + } + } + + withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true", + SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES.key -> "true") { + testSchemaMerging(2) + } + + withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true", + SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES.key -> "false") { + testSchemaMerging(3) + } + } + + test("Enabling/disabling schema merging") { + def testSchemaMerging(expectedColumnNumber: Int): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) + assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber) + } + } + + withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true") { + testSchemaMerging(3) + } + + withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "false") { + testSchemaMerging(2) + } + } + + test("SPARK-8990 DataFrameReader.parquet() should respect user specified options") { + withTempPath { dir => + val basePath = dir.getCanonicalPath + sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=a").toString) + + // Disables the global SQL option for schema merging + withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "false") { + assertResult(2) { + // Disables schema merging via data source option + sqlContext.read.option("mergeSchema", "false").parquet(basePath).columns.length + } + + assertResult(3) { + // Enables schema merging via data source option + sqlContext.read.option("mergeSchema", "true").parquet(basePath).columns.length + } + } + } + } + + test("SPARK-9119 Decimal should be correctly written into parquet") { + withTempPath { dir => + val basePath = dir.getCanonicalPath + val schema = StructType(Array(StructField("name", DecimalType(10, 5), false))) + val rowRDD = sparkContext.parallelize(Array(Row(Decimal("67123.45")))) + val df = sqlContext.createDataFrame(rowRDD, schema) + df.write.parquet(basePath) + + val decimal = sqlContext.read.parquet(basePath).first().getDecimal(0) + assert(Decimal("67123.45") === Decimal(decimal)) + } + } + + test("SPARK-10005 Schema merging for nested struct") { + 
withTempPath { dir => + val path = dir.getCanonicalPath + + def append(df: DataFrame): Unit = { + df.write.mode(SaveMode.Append).parquet(path) + } + + // Note that both the following two DataFrames contain a single struct column with multiple + // nested fields. + append((1 to 2).map(i => Tuple1((i, i))).toDF()) + append((1 to 2).map(i => Tuple1((i, i, i))).toDF()) + + withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "true") { + checkAnswer( + sqlContext.read.option("mergeSchema", "true").parquet(path), + Seq( + Row(Row(1, 1, null)), + Row(Row(2, 2, null)), + Row(Row(1, 1, 1)), + Row(Row(2, 2, 2)))) + } + } + } + + test("SPARK-10301 requested schema clipping - same schema") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = sqlContext.range(1).selectExpr("NAMED_STRUCT('a', id, 'b', id + 1) AS s").coalesce(1) + df.write.parquet(path) + + val userDefinedSchema = + new StructType() + .add( + "s", + new StructType() + .add("a", LongType, nullable = true) + .add("b", LongType, nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Row(Row(0L, 1L))) + } + } + + test("SPARK-11997 parquet with null partition values") { + withTempPath { dir => + val path = dir.getCanonicalPath + sqlContext.range(1, 3) + .selectExpr("if(id % 2 = 0, null, id) AS n", "id") + .write.partitionBy("n").parquet(path) + + checkAnswer( + sqlContext.read.parquet(path).filter("n is null"), + Row(2, null)) + } + } + + // This test case is ignored because of parquet-mr bug PARQUET-370 + ignore("SPARK-10301 requested schema clipping - schemas with disjoint sets of fields") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = sqlContext.range(1).selectExpr("NAMED_STRUCT('a', id, 'b', id + 1) AS s").coalesce(1) + df.write.parquet(path) + + val userDefinedSchema = + new StructType() + .add( + "s", + new StructType() + .add("c", LongType, nullable = true) + .add("d", LongType, nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Row(Row(null, null))) + } + } + + test("SPARK-10301 requested schema clipping - requested schema contains physical schema") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = sqlContext.range(1).selectExpr("NAMED_STRUCT('a', id, 'b', id + 1) AS s").coalesce(1) + df.write.parquet(path) + + val userDefinedSchema = + new StructType() + .add( + "s", + new StructType() + .add("a", LongType, nullable = true) + .add("b", LongType, nullable = true) + .add("c", LongType, nullable = true) + .add("d", LongType, nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Row(Row(0L, 1L, null, null))) + } + + withTempPath { dir => + val path = dir.getCanonicalPath + val df = sqlContext.range(1).selectExpr("NAMED_STRUCT('a', id, 'd', id + 3) AS s").coalesce(1) + df.write.parquet(path) + + val userDefinedSchema = + new StructType() + .add( + "s", + new StructType() + .add("a", LongType, nullable = true) + .add("b", LongType, nullable = true) + .add("c", LongType, nullable = true) + .add("d", LongType, nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Row(Row(0L, null, null, 3L))) + } + } + + test("SPARK-10301 requested schema clipping - physical schema contains requested schema") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = sqlContext + .range(1) + .selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', 
id + 2, 'd', id + 3) AS s") + .coalesce(1) + + df.write.parquet(path) + + val userDefinedSchema = + new StructType() + .add( + "s", + new StructType() + .add("a", LongType, nullable = true) + .add("b", LongType, nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Row(Row(0L, 1L))) + } + + withTempPath { dir => + val path = dir.getCanonicalPath + val df = sqlContext + .range(1) + .selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2, 'd', id + 3) AS s") + .coalesce(1) + + df.write.parquet(path) + + val userDefinedSchema = + new StructType() + .add( + "s", + new StructType() + .add("a", LongType, nullable = true) + .add("d", LongType, nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Row(Row(0L, 3L))) + } + } + + test("SPARK-10301 requested schema clipping - schemas overlap but don't contain each other") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = sqlContext + .range(1) + .selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2) AS s") + .coalesce(1) + + df.write.parquet(path) + + val userDefinedSchema = + new StructType() + .add( + "s", + new StructType() + .add("b", LongType, nullable = true) + .add("c", LongType, nullable = true) + .add("d", LongType, nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Row(Row(1L, 2L, null))) + } + } + + test("SPARK-10301 requested schema clipping - deeply nested struct") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df = sqlContext + .range(1) + .selectExpr("NAMED_STRUCT('a', ARRAY(NAMED_STRUCT('b', id, 'c', id))) AS s") + .coalesce(1) + + df.write.parquet(path) + + val userDefinedSchema = new StructType() + .add("s", + new StructType() + .add( + "a", + ArrayType( + new StructType() + .add("b", LongType, nullable = true) + .add("d", StringType, nullable = true), + containsNull = true), + nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Row(Row(Seq(Row(0, null))))) + } + } + + test("SPARK-10301 requested schema clipping - out of order") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df1 = sqlContext + .range(1) + .selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2) AS s") + .coalesce(1) + + val df2 = sqlContext + .range(1, 2) + .selectExpr("NAMED_STRUCT('c', id + 2, 'b', id + 1, 'd', id + 3) AS s") + .coalesce(1) + + df1.write.parquet(path) + df2.write.mode(SaveMode.Append).parquet(path) + + val userDefinedSchema = new StructType() + .add("s", + new StructType() + .add("a", LongType, nullable = true) + .add("b", LongType, nullable = true) + .add("d", LongType, nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Seq( + Row(Row(0, 1, null)), + Row(Row(null, 2, 4)))) + } + } + + test("SPARK-10301 requested schema clipping - schema merging") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df1 = sqlContext + .range(1) + .selectExpr("NAMED_STRUCT('a', id, 'c', id + 2) AS s") + .coalesce(1) + + val df2 = sqlContext + .range(1, 2) + .selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2) AS s") + .coalesce(1) + + df1.write.mode(SaveMode.Append).parquet(path) + df2.write.mode(SaveMode.Append).parquet(path) + + checkAnswer( + sqlContext + .read + .option("mergeSchema", "true") + .parquet(path) + .selectExpr("s.a", "s.b", "s.c"), + Seq( + 
Row(0, null, 2), + Row(1, 2, 3))) + } + } + + testStandardAndLegacyModes("SPARK-10301 requested schema clipping - UDT") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df = sqlContext + .range(1) + .selectExpr( + """NAMED_STRUCT( + | 'f0', CAST(id AS STRING), + | 'f1', NAMED_STRUCT( + | 'a', CAST(id + 1 AS INT), + | 'b', CAST(id + 2 AS LONG), + | 'c', CAST(id + 3.5 AS DOUBLE) + | ) + |) AS s + """.stripMargin) + .coalesce(1) + + df.write.mode(SaveMode.Append).parquet(path) + + val userDefinedSchema = + new StructType() + .add( + "s", + new StructType() + .add("f1", new NestedStructUDT, nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Row(Row(NestedStruct(1, 2L, 3.5D)))) + } + } + + test("expand UDT in StructType") { + val schema = new StructType().add("n", new NestedStructUDT, nullable = true) + val expected = new StructType().add("n", new NestedStructUDT().sqlType, nullable = true) + assert(CatalystReadSupport.expandUDT(schema) === expected) + } + + test("expand UDT in ArrayType") { + val schema = new StructType().add( + "n", + ArrayType( + elementType = new NestedStructUDT, + containsNull = false), + nullable = true) + + val expected = new StructType().add( + "n", + ArrayType( + elementType = new NestedStructUDT().sqlType, + containsNull = false), + nullable = true) + + assert(CatalystReadSupport.expandUDT(schema) === expected) + } + + test("expand UDT in MapType") { + val schema = new StructType().add( + "n", + MapType( + keyType = IntegerType, + valueType = new NestedStructUDT, + valueContainsNull = false), + nullable = true) + + val expected = new StructType().add( + "n", + MapType( + keyType = IntegerType, + valueType = new NestedStructUDT().sqlType, + valueContainsNull = false), + nullable = true) + + assert(CatalystReadSupport.expandUDT(schema) === expected) + } +} + +object TestingUDT { + @SQLUserDefinedType(udt = classOf[NestedStructUDT]) + case class NestedStruct(a: Integer, b: Long, c: Double) + + class NestedStructUDT extends UserDefinedType[NestedStruct] { + override def sqlType: DataType = + new StructType() + .add("a", IntegerType, nullable = true) + .add("b", LongType, nullable = false) + .add("c", DoubleType, nullable = false) + + override def serialize(obj: Any): Any = { + val row = new SpecificMutableRow(sqlType.asInstanceOf[StructType].map(_.dataType)) + obj match { + case n: NestedStruct => + row.setInt(0, n.a) + row.setLong(1, n.b) + row.setDouble(2, n.c) + } + } + + override def userClass: Class[NestedStruct] = classOf[NestedStruct] + + override def deserialize(datum: Any): NestedStruct = { + datum match { + case row: InternalRow => + NestedStruct(row.getInt(0), row.getLong(1), row.getDouble(2)) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala similarity index 53% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 4a0b3b60f419..60fa81b1ab81 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -15,20 +15,18 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag import org.apache.parquet.schema.MessageTypeParser -import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -abstract class ParquetSchemaTest extends SparkFunSuite with ParquetTest { - val sqlContext = TestSQLContext +abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { /** * Checks whether the reflected Parquet message type for product type `T` conforms `messageType`. @@ -36,32 +34,29 @@ abstract class ParquetSchemaTest extends SparkFunSuite with ParquetTest { protected def testSchemaInference[T <: Product: ClassTag: TypeTag]( testName: String, messageType: String, - binaryAsString: Boolean = true, - int96AsTimestamp: Boolean = true, - followParquetFormatSpec: Boolean = false, - isThriftDerived: Boolean = false): Unit = { + binaryAsString: Boolean, + int96AsTimestamp: Boolean, + writeLegacyParquetFormat: Boolean): Unit = { testSchema( testName, StructType.fromAttributes(ScalaReflection.attributesFor[T]), messageType, binaryAsString, int96AsTimestamp, - followParquetFormatSpec, - isThriftDerived) + writeLegacyParquetFormat) } protected def testParquetToCatalyst( testName: String, sqlSchema: StructType, parquetSchema: String, - binaryAsString: Boolean = true, - int96AsTimestamp: Boolean = true, - followParquetFormatSpec: Boolean = false, - isThriftDerived: Boolean = false): Unit = { + binaryAsString: Boolean, + int96AsTimestamp: Boolean, + writeLegacyParquetFormat: Boolean): Unit = { val converter = new CatalystSchemaConverter( assumeBinaryIsString = binaryAsString, assumeInt96IsTimestamp = int96AsTimestamp, - followParquetFormatSpec = followParquetFormatSpec) + writeLegacyParquetFormat = writeLegacyParquetFormat) test(s"sql <= parquet: $testName") { val actual = converter.convert(MessageTypeParser.parseMessageType(parquetSchema)) @@ -79,14 +74,13 @@ abstract class ParquetSchemaTest extends SparkFunSuite with ParquetTest { testName: String, sqlSchema: StructType, parquetSchema: String, - binaryAsString: Boolean = true, - int96AsTimestamp: Boolean = true, - followParquetFormatSpec: Boolean = false, - isThriftDerived: Boolean = false): Unit = { + binaryAsString: Boolean, + int96AsTimestamp: Boolean, + writeLegacyParquetFormat: Boolean): Unit = { val converter = new CatalystSchemaConverter( assumeBinaryIsString = binaryAsString, assumeInt96IsTimestamp = int96AsTimestamp, - followParquetFormatSpec = followParquetFormatSpec) + writeLegacyParquetFormat = writeLegacyParquetFormat) test(s"sql => parquet: $testName") { val actual = converter.convert(sqlSchema) @@ -100,10 +94,9 @@ abstract class ParquetSchemaTest extends SparkFunSuite with ParquetTest { testName: String, sqlSchema: StructType, parquetSchema: String, - binaryAsString: Boolean = true, - int96AsTimestamp: Boolean = true, - followParquetFormatSpec: Boolean = false, - isThriftDerived: Boolean = false): Unit = { + binaryAsString: Boolean, + int96AsTimestamp: Boolean, + writeLegacyParquetFormat: Boolean): Unit = { testCatalystToParquet( testName, @@ -111,8 +104,7 @@ abstract class ParquetSchemaTest extends SparkFunSuite with ParquetTest { parquetSchema, binaryAsString, int96AsTimestamp, - followParquetFormatSpec, - isThriftDerived) + writeLegacyParquetFormat) 
testParquetToCatalyst( testName, @@ -120,8 +112,7 @@ abstract class ParquetSchemaTest extends SparkFunSuite with ParquetTest { parquetSchema, binaryAsString, int96AsTimestamp, - followParquetFormatSpec, - isThriftDerived) + writeLegacyParquetFormat) } } @@ -138,7 +129,9 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | optional binary _6; |} """.stripMargin, - binaryAsString = false) + binaryAsString = false, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testSchemaInference[(Byte, Short, Int, Long, java.sql.Date)]( "logical integral types", @@ -150,7 +143,10 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | required int64 _4 (INT_64); | optional int32 _5 (DATE); |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testSchemaInference[Tuple1[String]]( "string", @@ -159,7 +155,9 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | optional binary _1 (UTF8); |} """.stripMargin, - binaryAsString = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testSchemaInference[Tuple1[String]]( "binary enum as string", @@ -167,7 +165,10 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { |message root { | optional binary _1 (ENUM); |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testSchemaInference[Tuple1[Seq[Int]]]( "non-nullable array - non-standard", @@ -177,7 +178,10 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | repeated int32 array; | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testSchemaInference[Tuple1[Seq[Int]]]( "non-nullable array - standard", @@ -190,7 +194,9 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | } |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testSchemaInference[Tuple1[Seq[Integer]]]( "nullable array - non-standard", @@ -198,11 +204,14 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { |message root { | optional group _1 (LIST) { | repeated group bag { - | optional int32 array_element; + | optional int32 array; | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testSchemaInference[Tuple1[Seq[Integer]]]( "nullable array - standard", @@ -215,7 +224,9 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | } |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testSchemaInference[Tuple1[Map[Int, String]]]( "map - standard", @@ -229,7 +240,9 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | } |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testSchemaInference[Tuple1[Map[Int, String]]]( "map - non-standard", @@ -242,7 +255,10 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testSchemaInference[Tuple1[Pair[Int, String]]]( "struct", @@ -254,7 +270,9 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | } |} """.stripMargin, - 
followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testSchemaInference[Tuple1[Map[Int, (String, Seq[(Int, Double)])]]]( "deeply nested type - non-standard", @@ -267,7 +285,7 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | optional binary _1 (UTF8); | optional group _2 (LIST) { | repeated group bag { - | optional group array_element { + | optional group array { | required int32 _1; | required double _2; | } @@ -277,7 +295,10 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testSchemaInference[Tuple1[Map[Int, (String, Seq[(Int, Double)])]]]( "deeply nested type - standard", @@ -301,7 +322,9 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | } |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testSchemaInference[(Option[Int], Map[Int, Option[Double]])]( "optional types", @@ -316,36 +339,9 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | } |} """.stripMargin, - followParquetFormatSpec = true) - - // Parquet files generated by parquet-thrift are already handled by the schema converter, but - // let's leave this test here until both read path and write path are all updated. - ignore("thrift generated parquet schema") { - // Test for SPARK-4520 -- ensure that thrift generated parquet schema is generated - // as expected from attributes - testSchemaInference[( - Array[Byte], Array[Byte], Array[Byte], Seq[Int], Map[Array[Byte], Seq[Int]])]( - "thrift generated parquet schema", - """ - |message root { - | optional binary _1 (UTF8); - | optional binary _2 (UTF8); - | optional binary _3 (UTF8); - | optional group _4 (LIST) { - | repeated int32 _4_tuple; - | } - | optional group _5 (MAP) { - | repeated group map (MAP_KEY_VALUE) { - | required binary key (UTF8); - | optional group value (LIST) { - | repeated int32 value_tuple; - | } - | } - | } - |} - """.stripMargin, - isThriftDerived = true) - } + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) } class ParquetSchemaSuite extends ParquetSchemaTest { @@ -361,8 +357,8 @@ class ParquetSchemaSuite extends ParquetSchemaTest { val jsonString = """{"type":"struct","fields":[{"name":"c1","type":"integer","nullable":false,"metadata":{}},{"name":"c2","type":"binary","nullable":true,"metadata":{}}]}""" // scalastyle:on - val fromCaseClassString = ParquetTypesConverter.convertFromString(caseClassString) - val fromJson = ParquetTypesConverter.convertFromString(jsonString) + val fromCaseClassString = StructType.fromString(caseClassString) + val fromJson = StructType.fromString(jsonString) (fromCaseClassString, fromJson).zipped.foreach { (a, b) => assert(a.name == b.name) @@ -471,7 +467,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testParquetToCatalyst( "Backwards-compatibility: LIST with nullable element type - 2", @@ -487,7 +486,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testParquetToCatalyst( "Backwards-compatibility: LIST with non-nullable element type - 1 - standard", 
@@ -500,7 +502,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testParquetToCatalyst( "Backwards-compatibility: LIST with non-nullable element type - 2", @@ -513,7 +518,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testParquetToCatalyst( "Backwards-compatibility: LIST with non-nullable element type - 3", @@ -524,7 +532,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | repeated int32 element; | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testParquetToCatalyst( "Backwards-compatibility: LIST with non-nullable element type - 4", @@ -545,7 +556,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testParquetToCatalyst( "Backwards-compatibility: LIST with non-nullable element type - 5 - parquet-avro style", @@ -564,7 +578,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testParquetToCatalyst( "Backwards-compatibility: LIST with non-nullable element type - 6 - parquet-thrift style", @@ -583,7 +600,46 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type 7 - " + + "parquet-protobuf primitive lists", + new StructType() + .add("f1", ArrayType(IntegerType, containsNull = false), nullable = false), + """message root { + | repeated int32 f1; + |} + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type 8 - " + + "parquet-protobuf non-primitive lists", + { + val elementType = + new StructType() + .add("c1", StringType, nullable = true) + .add("c2", IntegerType, nullable = false) + + new StructType() + .add("f1", ArrayType(elementType, containsNull = false), nullable = false) + }, + """message root { + | repeated group f1 { + | optional binary c1 (UTF8); + | required int32 c2; + | } + |} + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) // ======================================================= // Tests for converting Catalyst ArrayType to Parquet LIST @@ -604,7 +660,9 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testCatalystToParquet( "Backwards-compatibility: LIST with nullable element type - 2 - prior to 1.4.x", @@ -616,11 +674,14 @@ class ParquetSchemaSuite extends ParquetSchemaTest { """message root { | optional group f1 (LIST) { | repeated group bag { - | optional int32 array_element; + | optional int32 array; | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) 
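+  // A standalone sketch of the contrast the surrounding tests pin down: converting the same
+  // Catalyst array type with `writeLegacyParquetFormat` flipped yields either the standard
+  // 3-level `list`/`element` layout or the legacy 2-level `bag`/`array` layout. The value names
+  // below are illustrative; the converter arguments mirror those used by the test helpers.
+  private val nullableIntListSchema =
+    new StructType().add("f1", ArrayType(IntegerType, containsNull = true), nullable = true)
+
+  private val standardListLayout = new CatalystSchemaConverter(
+    assumeBinaryIsString = true,
+    assumeInt96IsTimestamp = true,
+    writeLegacyParquetFormat = false).convert(nullableIntListSchema)
+
+  private val legacyListLayout = new CatalystSchemaConverter(
+    assumeBinaryIsString = true,
+    assumeInt96IsTimestamp = true,
+    writeLegacyParquetFormat = true).convert(nullableIntListSchema)
+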
testCatalystToParquet( "Backwards-compatibility: LIST with non-nullable element type - 1 - standard", @@ -637,7 +698,9 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testCatalystToParquet( "Backwards-compatibility: LIST with non-nullable element type - 2 - prior to 1.4.x", @@ -651,7 +714,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | repeated int32 array; | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) // ==================================================== // Tests for converting Parquet Map to Catalyst MapType @@ -672,7 +738,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testParquetToCatalyst( "Backwards-compatibility: MAP with non-nullable value type - 2", @@ -689,7 +758,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testParquetToCatalyst( "Backwards-compatibility: MAP with non-nullable value type - 3 - prior to 1.4.x", @@ -706,7 +778,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testParquetToCatalyst( "Backwards-compatibility: MAP with nullable value type - 1 - standard", @@ -723,7 +798,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testParquetToCatalyst( "Backwards-compatibility: MAP with nullable value type - 2", @@ -740,7 +818,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testParquetToCatalyst( "Backwards-compatibility: MAP with nullable value type - 3 - parquet-avro style", @@ -757,7 +838,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) // ==================================================== // Tests for converting Catalyst MapType to Parquet Map @@ -779,7 +863,9 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testCatalystToParquet( "Backwards-compatibility: MAP with non-nullable value type - 2 - prior to 1.4.x", @@ -796,7 +882,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testCatalystToParquet( "Backwards-compatibility: MAP with nullable value type - 1 - standard", @@ -814,7 +903,9 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testCatalystToParquet( "Backwards-compatibility: MAP with nullable value type - 3 - 
prior to 1.4.x", @@ -831,7 +922,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) // ================================= // Tests for conversion for decimals @@ -844,7 +938,9 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | optional int32 f1 (DECIMAL(1, 0)); |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testSchema( "DECIMAL(8, 3) - standard", @@ -853,7 +949,9 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | optional int32 f1 (DECIMAL(8, 3)); |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testSchema( "DECIMAL(9, 3) - standard", @@ -862,7 +960,9 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | optional int32 f1 (DECIMAL(9, 3)); |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testSchema( "DECIMAL(18, 3) - standard", @@ -871,7 +971,9 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | optional int64 f1 (DECIMAL(18, 3)); |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testSchema( "DECIMAL(19, 3) - standard", @@ -880,7 +982,9 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | optional fixed_len_byte_array(9) f1 (DECIMAL(19, 3)); |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testSchema( "DECIMAL(1, 0) - prior to 1.4.x", @@ -888,7 +992,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { """message root { | optional fixed_len_byte_array(1) f1 (DECIMAL(1, 0)); |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testSchema( "DECIMAL(8, 3) - prior to 1.4.x", @@ -896,7 +1003,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { """message root { | optional fixed_len_byte_array(4) f1 (DECIMAL(8, 3)); |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testSchema( "DECIMAL(9, 3) - prior to 1.4.x", @@ -904,7 +1014,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { """message root { | optional fixed_len_byte_array(5) f1 (DECIMAL(9, 3)); |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testSchema( "DECIMAL(18, 3) - prior to 1.4.x", @@ -912,5 +1025,535 @@ class ParquetSchemaSuite extends ParquetSchemaTest { """message root { | optional fixed_len_byte_array(8) f1 (DECIMAL(18, 3)); |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) + + private def testSchemaClipping( + testName: String, + parquetSchema: String, + catalystSchema: StructType, + expectedSchema: String): Unit = { + test(s"Clipping - $testName") { + val expected = MessageTypeParser.parseMessageType(expectedSchema) + val actual = CatalystReadSupport.clipParquetSchema( + MessageTypeParser.parseMessageType(parquetSchema), catalystSchema) + + try { + expected.checkContains(actual) + actual.checkContains(expected) + } catch { case cause: Throwable => + fail( + s"""Expected clipped 
schema: + |$expected + |Actual clipped schema: + |$actual + """.stripMargin, + cause) + } + } + } + + testSchemaClipping( + "simple nested struct", + + parquetSchema = + """message root { + | required group f0 { + | optional int32 f00; + | optional int32 f01; + | } + |} + """.stripMargin, + + catalystSchema = { + val f0Type = new StructType().add("f00", IntegerType, nullable = true) + new StructType() + .add("f0", f0Type, nullable = false) + .add("f1", IntegerType, nullable = true) + }, + + expectedSchema = + """message root { + | required group f0 { + | optional int32 f00; + | } + | optional int32 f1; + |} + """.stripMargin) + + testSchemaClipping( + "parquet-protobuf style array", + + parquetSchema = + """message root { + | required group f0 { + | repeated binary f00 (UTF8); + | repeated group f01 { + | optional int32 f010; + | optional double f011; + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val f00Type = ArrayType(StringType, containsNull = false) + val f01Type = ArrayType( + new StructType() + .add("f011", DoubleType, nullable = true), + containsNull = false) + + val f0Type = new StructType() + .add("f00", f00Type, nullable = false) + .add("f01", f01Type, nullable = false) + val f1Type = ArrayType(IntegerType, containsNull = true) + + new StructType() + .add("f0", f0Type, nullable = false) + .add("f1", f1Type, nullable = true) + }, + + expectedSchema = + """message root { + | required group f0 { + | repeated binary f00 (UTF8); + | repeated group f01 { + | optional double f011; + | } + | } + | + | optional group f1 (LIST) { + | repeated group list { + | optional int32 element; + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "parquet-thrift style array", + + parquetSchema = + """message root { + | required group f0 { + | optional group f00 (LIST) { + | repeated binary f00_tuple (UTF8); + | } + | + | optional group f01 (LIST) { + | repeated group f01_tuple { + | optional int32 f010; + | optional double f011; + | } + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val f01ElementType = new StructType() + .add("f011", DoubleType, nullable = true) + .add("f012", LongType, nullable = true) + + val f0Type = new StructType() + .add("f00", ArrayType(StringType, containsNull = false), nullable = true) + .add("f01", ArrayType(f01ElementType, containsNull = false), nullable = true) + + new StructType().add("f0", f0Type, nullable = false) + }, + + expectedSchema = + """message root { + | required group f0 { + | optional group f00 (LIST) { + | repeated binary f00_tuple (UTF8); + | } + | + | optional group f01 (LIST) { + | repeated group f01_tuple { + | optional double f011; + | optional int64 f012; + | } + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "parquet-avro style array", + + parquetSchema = + """message root { + | required group f0 { + | optional group f00 (LIST) { + | repeated binary array (UTF8); + | } + | + | optional group f01 (LIST) { + | repeated group array { + | optional int32 f010; + | optional double f011; + | } + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val f01ElementType = new StructType() + .add("f011", DoubleType, nullable = true) + .add("f012", LongType, nullable = true) + + val f0Type = new StructType() + .add("f00", ArrayType(StringType, containsNull = false), nullable = true) + .add("f01", ArrayType(f01ElementType, containsNull = false), nullable = true) + + new StructType().add("f0", f0Type, nullable = false) + }, + + expectedSchema = + """message root { + | required group f0 { + | optional group f00 
(LIST) { + | repeated binary array (UTF8); + | } + | + | optional group f01 (LIST) { + | repeated group array { + | optional double f011; + | optional int64 f012; + | } + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "parquet-hive style array", + + parquetSchema = + """message root { + | optional group f0 { + | optional group f00 (LIST) { + | repeated group bag { + | optional binary array_element; + | } + | } + | + | optional group f01 (LIST) { + | repeated group bag { + | optional group array_element { + | optional int32 f010; + | optional double f011; + | } + | } + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val f01ElementType = new StructType() + .add("f011", DoubleType, nullable = true) + .add("f012", LongType, nullable = true) + + val f0Type = new StructType() + .add("f00", ArrayType(StringType, containsNull = true), nullable = true) + .add("f01", ArrayType(f01ElementType, containsNull = true), nullable = true) + + new StructType().add("f0", f0Type, nullable = true) + }, + + expectedSchema = + """message root { + | optional group f0 { + | optional group f00 (LIST) { + | repeated group bag { + | optional binary array_element; + | } + | } + | + | optional group f01 (LIST) { + | repeated group bag { + | optional group array_element { + | optional double f011; + | optional int64 f012; + | } + | } + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "2-level list of required struct", + + parquetSchema = + s"""message root { + | required group f0 { + | required group f00 (LIST) { + | repeated group element { + | required int32 f000; + | optional int64 f001; + | } + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val f00ElementType = + new StructType() + .add("f001", LongType, nullable = true) + .add("f002", DoubleType, nullable = false) + + val f00Type = ArrayType(f00ElementType, containsNull = false) + val f0Type = new StructType().add("f00", f00Type, nullable = false) + + new StructType().add("f0", f0Type, nullable = false) + }, + + expectedSchema = + s"""message root { + | required group f0 { + | required group f00 (LIST) { + | repeated group element { + | optional int64 f001; + | required double f002; + | } + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "standard array", + + parquetSchema = + """message root { + | required group f0 { + | optional group f00 (LIST) { + | repeated group list { + | required binary element (UTF8); + | } + | } + | + | optional group f01 (LIST) { + | repeated group list { + | required group element { + | optional int32 f010; + | optional double f011; + | } + | } + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val f01ElementType = new StructType() + .add("f011", DoubleType, nullable = true) + .add("f012", LongType, nullable = true) + + val f0Type = new StructType() + .add("f00", ArrayType(StringType, containsNull = false), nullable = true) + .add("f01", ArrayType(f01ElementType, containsNull = false), nullable = true) + + new StructType().add("f0", f0Type, nullable = false) + }, + + expectedSchema = + """message root { + | required group f0 { + | optional group f00 (LIST) { + | repeated group list { + | required binary element (UTF8); + | } + | } + | + | optional group f01 (LIST) { + | repeated group list { + | required group element { + | optional double f011; + | optional int64 f012; + | } + | } + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "empty requested schema", + + parquetSchema = + """message root { + | required group f0 { + | required int32 f00; + | required 
int64 f01; + | } + |} + """.stripMargin, + + catalystSchema = new StructType(), + + expectedSchema = "message root {}") + + testSchemaClipping( + "disjoint field sets", + + parquetSchema = + """message root { + | required group f0 { + | required int32 f00; + | required int64 f01; + | } + |} + """.stripMargin, + + catalystSchema = + new StructType() + .add( + "f0", + new StructType() + .add("f02", FloatType, nullable = true) + .add("f03", DoubleType, nullable = true), + nullable = true), + + expectedSchema = + """message root { + | required group f0 { + | optional float f02; + | optional double f03; + | } + |} + """.stripMargin) + + testSchemaClipping( + "parquet-avro style map", + + parquetSchema = + """message root { + | required group f0 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | required group value { + | required int32 value_f0; + | required int64 value_f1; + | } + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val valueType = + new StructType() + .add("value_f1", LongType, nullable = false) + .add("value_f2", DoubleType, nullable = false) + + val f0Type = MapType(IntegerType, valueType, valueContainsNull = false) + + new StructType().add("f0", f0Type, nullable = false) + }, + + expectedSchema = + """message root { + | required group f0 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | required group value { + | required int64 value_f1; + | required double value_f2; + | } + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "standard map", + + parquetSchema = + """message root { + | required group f0 (MAP) { + | repeated group key_value { + | required int32 key; + | required group value { + | required int32 value_f0; + | required int64 value_f1; + | } + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val valueType = + new StructType() + .add("value_f1", LongType, nullable = false) + .add("value_f2", DoubleType, nullable = false) + + val f0Type = MapType(IntegerType, valueType, valueContainsNull = false) + + new StructType().add("f0", f0Type, nullable = false) + }, + + expectedSchema = + """message root { + | required group f0 (MAP) { + | repeated group key_value { + | required int32 key; + | required group value { + | required int64 value_f1; + | required double value_f2; + | } + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "standard map with complex key", + + parquetSchema = + """message root { + | required group f0 (MAP) { + | repeated group key_value { + | required group key { + | required int32 value_f0; + | required int64 value_f1; + | } + | required int32 value; + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val keyType = + new StructType() + .add("value_f1", LongType, nullable = false) + .add("value_f2", DoubleType, nullable = false) + + val f0Type = MapType(keyType, IntegerType, valueContainsNull = false) + + new StructType().add("f0", f0Type, nullable = false) + }, + + expectedSchema = + """message root { + | required group f0 (MAP) { + | repeated group key_value { + | required group key { + | required int64 value_f1; + | required double value_f2; + | } + | required int32 value; + | } + | } + |} + """.stripMargin) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala similarity index 53% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala rename to 
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala index 64e94056f209..fdd7697c91f5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -15,16 +15,25 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.io.File +import org.apache.parquet.schema.MessageType + +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.SparkFunSuite +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.parquet.format.converter.ParquetMetadataConverter +import org.apache.parquet.hadoop.metadata.{BlockMetaData, FileMetaData, ParquetMetadata} +import org.apache.parquet.hadoop.{Footer, ParquetFileReader, ParquetFileWriter} + import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.{DataFrame, SaveMode} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{DataFrame, SQLConf, SaveMode} /** * A helper trait that provides convenient facilities for Parquet testing. @@ -33,7 +42,8 @@ import org.apache.spark.sql.{DataFrame, SaveMode} * convenient to use tuples rather than special case classes when writing test cases/suites. * Especially, `Tuple1.apply` can be used to easily wrap a single type/value. */ -private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite => +private[sql] trait ParquetTest extends SQLTestUtils { + /** * Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f` * returns. @@ -97,4 +107,58 @@ private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite => assert(partDir.mkdirs(), s"Couldn't create directory $partDir") partDir } + + protected def writeMetadata( + schema: StructType, path: Path, configuration: Configuration): Unit = { + val parquetSchema = new CatalystSchemaConverter().convert(schema) + val extraMetadata = Map(CatalystReadSupport.SPARK_METADATA_KEY -> schema.json).asJava + val createdBy = s"Apache Spark ${org.apache.spark.SPARK_VERSION}" + val fileMetadata = new FileMetaData(parquetSchema, extraMetadata, createdBy) + val parquetMetadata = new ParquetMetadata(fileMetadata, Seq.empty[BlockMetaData].asJava) + val footer = new Footer(path, parquetMetadata) + ParquetFileWriter.writeMetadataFile(configuration, path, Seq(footer).asJava) + } + + /** + * This is an overloaded version of `writeMetadata` above to allow writing customized + * Parquet schema. 
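+ *
+ * A minimal usage sketch (hypothetical; the schema string, `dir`, and `hadoopConfiguration`
+ * below are illustrative assumptions, not part of this patch):
+ * {{{
+ *   val parquetSchema = MessageTypeParser.parseMessageType(
+ *     "message root { optional int32 id; }")
+ *   writeMetadata(parquetSchema, new Path(dir.getCanonicalPath), hadoopConfiguration)
+ * }}}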
+ */ + protected def writeMetadata( + parquetSchema: MessageType, path: Path, configuration: Configuration, + extraMetadata: Map[String, String] = Map.empty[String, String]): Unit = { + val extraMetadataAsJava = extraMetadata.asJava + val createdBy = s"Apache Spark ${org.apache.spark.SPARK_VERSION}" + val fileMetadata = new FileMetaData(parquetSchema, extraMetadataAsJava, createdBy) + val parquetMetadata = new ParquetMetadata(fileMetadata, Seq.empty[BlockMetaData].asJava) + val footer = new Footer(path, parquetMetadata) + ParquetFileWriter.writeMetadataFile(configuration, path, Seq(footer).asJava) + } + + protected def readAllFootersWithoutSummaryFiles( + path: Path, configuration: Configuration): Seq[Footer] = { + val fs = path.getFileSystem(configuration) + ParquetFileReader.readAllFootersInParallel(configuration, fs.getFileStatus(path)).asScala.toSeq + } + + protected def readFooter(path: Path, configuration: Configuration): ParquetMetadata = { + ParquetFileReader.readFooter( + configuration, + new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE), + ParquetMetadataConverter.NO_FILTER) + } + + protected def testStandardAndLegacyModes(testName: String)(f: => Unit): Unit = { + test(s"Standard mode - $testName") { + withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "false") { f } + } + + test(s"Legacy mode - $testName") { + withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "true") { f } + } + } + + protected def readResourceParquetFile(name: String): DataFrame = { + val url = Thread.currentThread().getContextClassLoader.getResource(name) + sqlContext.read.parquet(url.toString) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala new file mode 100644 index 000000000000..88a3d878f97f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.spark.sql.Row +import org.apache.spark.sql.test.SharedSQLContext + +class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext { + import ParquetCompatibilityTest._ + + private val parquetFilePath = + Thread.currentThread().getContextClassLoader.getResource("parquet-thrift-compat.snappy.parquet") + + test("Read Parquet file generated by parquet-thrift") { + logInfo( + s"""Schema of the Parquet file written by parquet-thrift: + |${readParquetSchema(parquetFilePath.toString)} + """.stripMargin) + + checkAnswer(sqlContext.read.parquet(parquetFilePath.toString), (0 until 10).map { i => + val suits = Array("SPADES", "HEARTS", "DIAMONDS", "CLUBS") + + val nonNullablePrimitiveValues = Seq( + i % 2 == 0, + i.toByte, + (i + 1).toShort, + i + 2, + i.toLong * 10, + i.toDouble + 0.2d, + // Thrift `BINARY` values are actually unencoded `STRING` values, and thus are always + // treated as `BINARY (UTF8)` in parquet-thrift, since parquet-thrift always assume + // Thrift `STRING`s are encoded using UTF-8. + s"val_$i", + s"val_$i", + // Thrift ENUM values are converted to Parquet binaries containing UTF-8 strings + suits(i % 4)) + + val nullablePrimitiveValues = if (i % 3 == 0) { + Seq.fill(nonNullablePrimitiveValues.length)(null) + } else { + nonNullablePrimitiveValues + } + + val complexValues = Seq( + Seq.tabulate(3)(n => s"arr_${i + n}"), + // Thrift `SET`s are converted to Parquet `LIST`s + Seq(i), + Seq.tabulate(3)(n => (i + n: Integer) -> s"val_${i + n}").toMap, + Seq.tabulate(3) { n => + (i + n) -> Seq.tabulate(3) { m => + Row(Seq.tabulate(3)(j => i + j + m), s"val_${i + m}") + } + }.toMap) + + Row(nonNullablePrimitiveValues ++ nullablePrimitiveValues ++ complexValues: _*) + }) + } + + test("SPARK-10136 list of primitive list") { + withTempPath { dir => + val path = dir.getCanonicalPath + + // This Parquet schema is translated from the following Thrift schema: + // + // struct ListOfPrimitiveList { + // 1: list> f; + // } + val schema = + s"""message ListOfPrimitiveList { + | required group f (LIST) { + | repeated group f_tuple (LIST) { + | repeated int32 f_tuple_tuple; + | } + | } + |} + """.stripMargin + + writeDirect(path, schema, { rc => + rc.message { + rc.field("f", 0) { + rc.group { + rc.field("f_tuple", 0) { + rc.group { + rc.field("f_tuple_tuple", 0) { + rc.addInteger(0) + rc.addInteger(1) + } + } + + rc.group { + rc.field("f_tuple_tuple", 0) { + rc.addInteger(2) + rc.addInteger(3) + } + } + } + } + } + } + }, { rc => + rc.message { + rc.field("f", 0) { + rc.group { + rc.field("f_tuple", 0) { + rc.group { + rc.field("f_tuple_tuple", 0) { + rc.addInteger(4) + rc.addInteger(5) + } + } + + rc.group { + rc.field("f_tuple_tuple", 0) { + rc.addInteger(6) + rc.addInteger(7) + } + } + } + } + } + } + }) + + logParquetSchema(path) + + checkAnswer( + sqlContext.read.parquet(path), + Seq( + Row(Seq(Seq(0, 1), Seq(2, 3))), + Row(Seq(Seq(4, 5), Seq(6, 7))))) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala new file mode 100644 index 000000000000..914e516613f9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.text + +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} +import org.apache.spark.util.Utils + + +class TextSuite extends QueryTest with SharedSQLContext { + + test("reading text file") { + verifyFrame(sqlContext.read.format("text").load(testFile)) + } + + test("SQLContext.read.text() API") { + verifyFrame(sqlContext.read.text(testFile)) + } + + test("writing") { + val df = sqlContext.read.text(testFile) + + val tempFile = Utils.createTempDir() + tempFile.delete() + df.write.text(tempFile.getCanonicalPath) + verifyFrame(sqlContext.read.text(tempFile.getCanonicalPath)) + + Utils.deleteRecursively(tempFile) + } + + test("error handling for invalid schema") { + val tempFile = Utils.createTempDir() + tempFile.delete() + + val df = sqlContext.range(2) + intercept[AnalysisException] { + df.write.text(tempFile.getCanonicalPath) + } + + intercept[AnalysisException] { + sqlContext.range(2).select(df("id"), df("id") + 1).write.text(tempFile.getCanonicalPath) + } + } + + private def testFile: String = { + Thread.currentThread().getContextClassLoader.getResource("text-suite.txt").toString + } + + /** Verifies data and schema. */ + private def verifyFrame(df: DataFrame): Unit = { + // schema + assert(df.schema == new StructType().add("value", StringType)) + + // verify content + val data = df.collect() + assert(data(0) == Row("This is a test file for the text data source")) + assert(data(1) == Row("1+1")) + // non ascii characters are not allowed in the code, so we disable the scalastyle here. 
+ // scalastyle:off + assert(data(2) == Row("数据砖头")) + // scalastyle:on + assert(data(3) == Row("\"doh\"")) + assert(data.length == 4) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index 8ec3985e0036..22189477d277 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -18,15 +18,11 @@ package org.apache.spark.sql.execution.debug import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.test.SharedSQLContext + +class DebuggingSuite extends SparkFunSuite with SharedSQLContext { -class DebuggingSuite extends SparkFunSuite { test("DataFrame.debug()") { testData.debug() } - - test("DataFrame.typeCheck()") { - testData.typeCheck() - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala new file mode 100644 index 000000000000..5b2998c3c76d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -0,0 +1,83 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.joins + +import scala.reflect.ClassTag + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.{SQLConf, SQLContext, QueryTest} + +/** + * Test various broadcast join operators. + * + * Tests in this suite we need to run Spark in local-cluster mode. In particular, the use of + * unsafe map in [[org.apache.spark.sql.execution.joins.UnsafeHashedRelation]] is not triggered + * without serializing the hashed relation, which does not happen in local mode. + */ +class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { + protected var sqlContext: SQLContext = null + + /** + * Create a new [[SQLContext]] running in local-cluster mode with unsafe and codegen enabled. + */ + override def beforeAll(): Unit = { + super.beforeAll() + val conf = new SparkConf() + .setMaster("local-cluster[2,1,1024]") + .setAppName("testing") + val sc = new SparkContext(conf) + sqlContext = new SQLContext(sc) + } + + override def afterAll(): Unit = { + sqlContext.sparkContext.stop() + sqlContext = null + } + + /** + * Test whether the specified broadcast join updates the peak execution memory accumulator. 
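+ *
+ * For example, the tests below call
+ * {{{
+ *   testBroadcastJoin[BroadcastHashJoin]("unsafe broadcast hash join", "inner")
+ * }}}
+ * which joins two small DataFrames with a `broadcast` hint, asserts that exactly one
+ * [[BroadcastHashJoin]] operator appears in the executed plan, and then collects the result
+ * so that the peak execution memory accumulator can be checked.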
+ */ + private def testBroadcastJoin[T: ClassTag](name: String, joinType: String): Unit = { + AccumulatorSuite.verifyPeakExecutionMemorySet(sqlContext.sparkContext, name) { + val df1 = sqlContext.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") + val df2 = sqlContext.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") + // Comparison at the end is for broadcast left semi join + val joinExpression = df1("key") === df2("key") && df1("value") > df2("value") + val df3 = df1.join(broadcast(df2), joinExpression, joinType) + val plan = df3.queryExecution.executedPlan + assert(plan.collect { case p: T => p }.size === 1) + plan.executeCollect() + } + } + + test("unsafe broadcast hash join updates peak execution memory") { + testBroadcastJoin[BroadcastHashJoin]("unsafe broadcast hash join", "inner") + } + + test("unsafe broadcast hash outer join updates peak execution memory") { + testBroadcastJoin[BroadcastHashOuterJoin]("unsafe broadcast hash outer join", "left_outer") + } + + test("unsafe broadcast left semi join updates peak execution memory") { + testBroadcastJoin[BroadcastLeftSemiJoinHash]("unsafe broadcast left semi join", "leftsemi") + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 8b1a9b21a96b..e5fd9e277fc6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -22,11 +22,13 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.collection.CompactBuffer -class HashedRelationSuite extends SparkFunSuite { +class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { // Key is simply the record itself private val keyProjection = new Projection { @@ -35,7 +37,8 @@ class HashedRelationSuite extends SparkFunSuite { test("GeneralHashedRelation") { val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) - val hashed = HashedRelation(data.iterator, keyProjection) + val numDataRows = SQLMetrics.createLongMetric(sparkContext, "data") + val hashed = HashedRelation(data.iterator, numDataRows, keyProjection) assert(hashed.isInstanceOf[GeneralHashedRelation]) assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0))) @@ -45,11 +48,13 @@ class HashedRelationSuite extends SparkFunSuite { val data2 = CompactBuffer[InternalRow](data(2)) data2 += data(2) assert(hashed.get(data(2)) === data2) + assert(numDataRows.value.value === data.length) } test("UniqueKeyHashedRelation") { val data = Array(InternalRow(0), InternalRow(1), InternalRow(2)) - val hashed = HashedRelation(data.iterator, keyProjection) + val numDataRows = SQLMetrics.createLongMetric(sparkContext, "data") + val hashed = HashedRelation(data.iterator, numDataRows, keyProjection) assert(hashed.isInstanceOf[UniqueKeyHashedRelation]) assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0))) @@ -62,17 +67,19 @@ class HashedRelationSuite extends SparkFunSuite { assert(uniqHashed.getValue(data(1)) === data(1)) 
assert(uniqHashed.getValue(data(2)) === data(2)) assert(uniqHashed.getValue(InternalRow(10)) === null) + assert(numDataRows.value.value === data.length) } test("UnsafeHashedRelation") { val schema = StructType(StructField("a", IntegerType, true) :: Nil) val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) + val numDataRows = SQLMetrics.createLongMetric(sparkContext, "data") val toUnsafe = UnsafeProjection.create(schema) val unsafeData = data.map(toUnsafe(_).copy()).toArray val buildKey = Seq(BoundReference(0, IntegerType, false)) val keyGenerator = UnsafeProjection.create(buildKey) - val hashed = UnsafeHashedRelation(unsafeData.iterator, keyGenerator, 1) + val hashed = UnsafeHashedRelation(unsafeData.iterator, numDataRows, keyGenerator, 1) assert(hashed.isInstanceOf[UnsafeHashedRelation]) assert(hashed.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0))) @@ -94,5 +101,37 @@ class HashedRelationSuite extends SparkFunSuite { assert(hashed2.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1))) assert(hashed2.get(toUnsafe(InternalRow(10))) === null) assert(hashed2.get(unsafeData(2)) === data2) + assert(numDataRows.value.value === data.length) + + val os2 = new ByteArrayOutputStream() + val out2 = new ObjectOutputStream(os2) + hashed2.asInstanceOf[UnsafeHashedRelation].writeExternal(out2) + out2.flush() + // This depends on that the order of items in BytesToBytesMap.iterator() is exactly the same + // as they are inserted + assert(java.util.Arrays.equals(os2.toByteArray, os.toByteArray)) + } + + test("test serialization empty hash map") { + val os = new ByteArrayOutputStream() + val out = new ObjectOutputStream(os) + val hashed = new UnsafeHashedRelation( + new java.util.HashMap[UnsafeRow, CompactBuffer[UnsafeRow]]) + hashed.writeExternal(out) + out.flush() + val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray)) + val hashed2 = new UnsafeHashedRelation() + hashed2.readExternal(in) + + val schema = StructType(StructField("a", IntegerType, true) :: Nil) + val toUnsafe = UnsafeProjection.create(schema) + val row = toUnsafe(InternalRow(0)) + assert(hashed2.get(row) === null) + + val os2 = new ByteArrayOutputStream() + val out2 = new ObjectOutputStream(os2) + hashed2.writeExternal(out2) + out2.flush() + assert(java.util.Arrays.equals(os2.toByteArray, os.toByteArray)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala new file mode 100644 index 000000000000..2ec17146476f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.sql.{DataFrame, execution, Row, SQLConf} +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{IntegerType, StringType, StructType} + +class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { + import testImplicits.localSeqToDataFrameHolder + + private lazy val myUpperCaseData = sqlContext.createDataFrame( + sparkContext.parallelize(Seq( + Row(1, "A"), + Row(2, "B"), + Row(3, "C"), + Row(4, "D"), + Row(5, "E"), + Row(6, "F"), + Row(null, "G") + )), new StructType().add("N", IntegerType).add("L", StringType)) + + private lazy val myLowerCaseData = sqlContext.createDataFrame( + sparkContext.parallelize(Seq( + Row(1, "a"), + Row(2, "b"), + Row(3, "c"), + Row(4, "d"), + Row(null, "e") + )), new StructType().add("n", IntegerType).add("l", StringType)) + + private lazy val myTestData1 = Seq( + (1, 1), + (1, 2), + (2, 1), + (2, 2), + (3, 1), + (3, 2) + ).toDF("a", "b") + + private lazy val myTestData2 = Seq( + (1, 1), + (1, 2), + (2, 1), + (2, 2), + (3, 1), + (3, 2) + ).toDF("a", "b") + + // Note: the input dataframes and expression must be evaluated lazily because + // the SQLContext should be used only within a test to keep SQL tests stable + private def testInnerJoin( + testName: String, + leftRows: => DataFrame, + rightRows: => DataFrame, + condition: () => Expression, + expectedAnswer: Seq[Product]): Unit = { + + def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition())) + ExtractEquiJoinKeys.unapply(join) + } + + def makeBroadcastHashJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + boundCondition: Option[Expression], + leftPlan: SparkPlan, + rightPlan: SparkPlan, + side: BuildSide) = { + val broadcastHashJoin = + execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, leftPlan, rightPlan) + boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) + } + + def makeSortMergeJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + boundCondition: Option[Expression], + leftPlan: SparkPlan, + rightPlan: SparkPlan) = { + val sortMergeJoin = + execution.joins.SortMergeJoin(leftKeys, rightKeys, leftPlan, rightPlan) + val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin) + EnsureRequirements(sqlContext).apply(filteredJoin) + } + + test(s"$testName using BroadcastHashJoin (build=left)") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeBroadcastHashJoin( + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + test(s"$testName using BroadcastHashJoin (build=right)") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeBroadcastHashJoin( + leftKeys, rightKeys, boundCondition, 
leftPlan, rightPlan, joins.BuildRight), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + test(s"$testName using SortMergeJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeSortMergeJoin(leftKeys, rightKeys, boundCondition, leftPlan, rightPlan), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + } + + testInnerJoin( + "inner join, one match per row", + myUpperCaseData, + myLowerCaseData, + () => (myUpperCaseData.col("N") === myLowerCaseData.col("n")).expr, + Seq( + (1, "A", 1, "a"), + (2, "B", 2, "b"), + (3, "C", 3, "c"), + (4, "D", 4, "d") + ) + ) + + { + lazy val left = myTestData1.where("a = 1") + lazy val right = myTestData2.where("a = 1") + testInnerJoin( + "inner join, multiple matches", + left, + right, + () => (left.col("a") === right.col("a")).expr, + Seq( + (1, 1, 1, 1), + (1, 1, 1, 2), + (1, 2, 1, 1), + (1, 2, 1, 2) + ) + ) + } + + { + lazy val left = myTestData1.where("a = 1") + lazy val right = myTestData2.where("a = 2") + testInnerJoin( + "inner join, no matches", + left, + right, + () => (left.col("a") === right.col("a")).expr, + Seq.empty + ) + } + + { + lazy val left = Seq((1, Some(0)), (2, None)).toDF("a", "b") + lazy val right = Seq((1, Some(0)), (2, None)).toDF("a", "b") + testInnerJoin( + "inner join, null safe", + left, + right, + () => (left.col("b") <=> right.col("b")).expr, + Seq( + (1, 0, 1, 0), + (2, null, 2, null) + ) + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 2c27da596bc4..9c80714a9af4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -1,89 +1,202 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.joins - -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{Expression, LessThan} -import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, RightOuter} -import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} - -class OuterJoinSuite extends SparkPlanTest { - - val left = Seq( - (1, 2.0), - (2, 1.0), - (3, 3.0) - ).toDF("a", "b") - - val right = Seq( - (2, 3.0), - (3, 2.0), - (4, 1.0) - ).toDF("c", "d") - - val leftKeys: List[Expression] = 'a :: Nil - val rightKeys: List[Expression] = 'c :: Nil - val condition = Some(LessThan('b, 'd)) - - test("shuffled hash outer join") { - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - ShuffledHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right), - Seq( - (1, 2.0, null, null), - (2, 1.0, 2, 3.0), - (3, 3.0, null, null) - ).map(Row.fromTuple)) - - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - ShuffledHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right), - Seq( - (2, 1.0, 2, 3.0), - (null, null, 3, 2.0), - (null, null, 4, 1.0) - ).map(Row.fromTuple)) - - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - ShuffledHashOuterJoin(leftKeys, rightKeys, FullOuter, condition, left, right), - Seq( - (1, 2.0, null, null), - (2, 1.0, 2, 3.0), - (3, 3.0, null, null), - (null, null, 3, 2.0), - (null, null, 4, 1.0) - ).map(Row.fromTuple)) - } - - test("broadcast hash outer join") { - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - BroadcastHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right), - Seq( - (1, 2.0, null, null), - (2, 1.0, 2, 3.0), - (3, 3.0, null, null) - ).map(Row.fromTuple)) - - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - BroadcastHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right), - Seq( - (2, 1.0, 2, 3.0), - (null, null, 3, 2.0), - (null, null, 4, 1.0) - ).map(Row.fromTuple)) - } -} +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.sql.{DataFrame, Row, SQLConf} +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan} +import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{IntegerType, DoubleType, StructType} + +class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { + + private lazy val left = sqlContext.createDataFrame( + sparkContext.parallelize(Seq( + Row(1, 2.0), + Row(2, 100.0), + Row(2, 1.0), // This row is duplicated to ensure that we will have multiple buffered matches + Row(2, 1.0), + Row(3, 3.0), + Row(5, 1.0), + Row(6, 6.0), + Row(null, null) + )), new StructType().add("a", IntegerType).add("b", DoubleType)) + + private lazy val right = sqlContext.createDataFrame( + sparkContext.parallelize(Seq( + Row(0, 0.0), + Row(2, 3.0), // This row is duplicated to ensure that we will have multiple buffered matches + Row(2, -1.0), + Row(2, -1.0), + Row(2, 3.0), + Row(3, 2.0), + Row(4, 1.0), + Row(5, 3.0), + Row(7, 7.0), + Row(null, null) + )), new StructType().add("c", IntegerType).add("d", DoubleType)) + + private lazy val condition = { + And((left.col("a") === right.col("c")).expr, + LessThan(left.col("b").expr, right.col("d").expr)) + } + + // Note: the input dataframes and expression must be evaluated lazily because + // the SQLContext should be used only within a test to keep SQL tests stable + private def testOuterJoin( + testName: String, + leftRows: => DataFrame, + rightRows: => DataFrame, + joinType: JoinType, + condition: => Expression, + expectedAnswer: Seq[Product]): Unit = { + + def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) + ExtractEquiJoinKeys.unapply(join) + } + + if (joinType != FullOuter) { + test(s"$testName using BroadcastHashOuterJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + } + + test(s"$testName using SortMergeOuterJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(sqlContext).apply( + SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + } + + // --- Basic outer joins ------------------------------------------------------------------------ + + testOuterJoin( + "basic left outer join", + left, + right, + LeftOuter, + condition, + Seq( + (null, null, null, null), + (1, 2.0, null, null), + (2, 100.0, null, null), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (3, 3.0, null, null), + (5, 1.0, 5, 3.0), + (6, 6.0, null, null) + ) + ) + + testOuterJoin( + "basic right outer join", + left, + right, + RightOuter, + condition, + Seq( + (null, null, 
null, null), + (null, null, 0, 0.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (null, null, 2, -1.0), + (null, null, 2, -1.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (null, null, 3, 2.0), + (null, null, 4, 1.0), + (5, 1.0, 5, 3.0), + (null, null, 7, 7.0) + ) + ) + + testOuterJoin( + "basic full outer join", + left, + right, + FullOuter, + condition, + Seq( + (1, 2.0, null, null), + (null, null, 2, -1.0), + (null, null, 2, -1.0), + (2, 100.0, null, null), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (3, 3.0, null, null), + (5, 1.0, 5, 3.0), + (6, 6.0, null, null), + (null, null, 0, 0.0), + (null, null, 3, 2.0), + (null, null, 4, 1.0), + (null, null, 7, 7.0), + (null, null, null, null), + (null, null, null, null) + ) + ) + + // --- Both inputs empty ------------------------------------------------------------------------ + + testOuterJoin( + "left outer join with both inputs empty", + left.filter("false"), + right.filter("false"), + LeftOuter, + condition, + Seq.empty + ) + + testOuterJoin( + "right outer join with both inputs empty", + left.filter("false"), + right.filter("false"), + RightOuter, + condition, + Seq.empty + ) + + testOuterJoin( + "full outer join with both inputs empty", + left.filter("false"), + right.filter("false"), + FullOuter, + condition, + Seq.empty + ) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala index 927e85a7db3d..3afd762942bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala @@ -17,58 +17,100 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{LessThan, Expression} -import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} +import org.apache.spark.sql.{SQLConf, DataFrame, Row} +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.catalyst.expressions.{And, LessThan, Expression} +import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} +class SemiJoinSuite extends SparkPlanTest with SharedSQLContext { -class SemiJoinSuite extends SparkPlanTest{ - val left = Seq( - (1, 2.0), - (1, 2.0), - (2, 1.0), - (2, 1.0), - (3, 3.0) - ).toDF("a", "b") + private lazy val left = sqlContext.createDataFrame( + sparkContext.parallelize(Seq( + Row(1, 2.0), + Row(1, 2.0), + Row(2, 1.0), + Row(2, 1.0), + Row(3, 3.0), + Row(null, null), + Row(null, 5.0), + Row(6, null) + )), new StructType().add("a", IntegerType).add("b", DoubleType)) - val right = Seq( - (2, 3.0), - (2, 3.0), - (3, 2.0), - (4, 1.0) - ).toDF("c", "d") + private lazy val right = sqlContext.createDataFrame( + sparkContext.parallelize(Seq( + Row(2, 3.0), + Row(2, 3.0), + Row(3, 2.0), + Row(4, 1.0), + Row(null, null), + Row(null, 5.0), + Row(6, null) + )), new StructType().add("c", IntegerType).add("d", DoubleType)) - val leftKeys: List[Expression] = 'a :: Nil - val rightKeys: List[Expression] = 'c :: Nil - val condition = Some(LessThan('b, 'd)) - - test("left semi join hash") { - 
checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - LeftSemiJoinHash(leftKeys, rightKeys, left, right, condition), - Seq( - (2, 1.0), - (2, 1.0) - ).map(Row.fromTuple)) + private lazy val condition = { + And((left.col("a") === right.col("c")).expr, + LessThan(left.col("b").expr, right.col("d").expr)) } - test("left semi join BNL") { - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - LeftSemiJoinBNL(left, right, condition), - Seq( - (1, 2.0), - (1, 2.0), - (2, 1.0), - (2, 1.0) - ).map(Row.fromTuple)) - } + // Note: the input dataframes and expression must be evaluated lazily because + // the SQLContext should be used only within a test to keep SQL tests stable + private def testLeftSemiJoin( + testName: String, + leftRows: => DataFrame, + rightRows: => DataFrame, + condition: => Expression, + expectedAnswer: Seq[Product]): Unit = { + + def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) + ExtractEquiJoinKeys.unapply(join) + } - test("broadcast left semi join hash") { - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, condition), - Seq( - (2, 1.0), - (2, 1.0) - ).map(Row.fromTuple)) + test(s"$testName using LeftSemiJoinHash") { + extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(left.sqlContext).apply( + LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + test(s"$testName using BroadcastLeftSemiJoinHash") { + extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + test(s"$testName using LeftSemiJoinBNL") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + LeftSemiJoinBNL(left, right, Some(condition)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } } + + testLeftSemiJoin( + "basic test", + left, + right, + condition, + Seq( + (2, 1.0), + (2, 1.0) + ) + ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala new file mode 100644 index 000000000000..efc3227dd60d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala @@ -0,0 +1,68 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. 
You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation + +/** + * A dummy [[LocalNode]] that just returns rows from a [[LocalRelation]]. + */ +private[local] case class DummyNode( + output: Seq[Attribute], + relation: LocalRelation, + conf: SQLConf) + extends LocalNode(conf) { + + import DummyNode._ + + private var index: Int = CLOSED + private val input: Seq[InternalRow] = relation.data + + def this(output: Seq[Attribute], data: Seq[Product], conf: SQLConf = new SQLConf) { + this(output, LocalRelation.fromProduct(output, data), conf) + } + + def isOpen: Boolean = index != CLOSED + + override def children: Seq[LocalNode] = Seq.empty + + override def open(): Unit = { + index = -1 + } + + override def next(): Boolean = { + index += 1 + index < input.size + } + + override def fetch(): InternalRow = { + assert(index >= 0 && index < input.size) + input(index) + } + + override def close(): Unit = { + index = CLOSED + } +} + +private object DummyNode { + val CLOSED: Int = Int.MinValue +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala new file mode 100644 index 000000000000..bbd94d8da2d1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala @@ -0,0 +1,49 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.catalyst.dsl.expressions._ + + +class ExpandNodeSuite extends LocalNodeTest { + + private def testExpand(inputData: Array[(Int, Int)] = Array.empty): Unit = { + val inputNode = new DummyNode(kvIntAttributes, inputData) + val projections = Seq(Seq('k + 'v, 'k - 'v), Seq('k * 'v, 'k / 'v)) + val expandNode = new ExpandNode(conf, projections, inputNode.output, inputNode) + val resolvedNode = resolveExpressions(expandNode) + val expectedOutput = { + val firstHalf = inputData.map { case (k, v) => (k + v, k - v) } + val secondHalf = inputData.map { case (k, v) => (k * v, k / v) } + firstHalf ++ secondHalf + } + val actualOutput = resolvedNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput.toSet === expectedOutput.toSet) + } + + test("empty") { + testExpand() + } + + test("basic") { + testExpand((1 to 100).map { i => (i, i * 1000) }.toArray) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala new file mode 100644 index 000000000000..4eadce646d37 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala @@ -0,0 +1,45 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.catalyst.dsl.expressions._ + + +class FilterNodeSuite extends LocalNodeTest { + + private def testFilter(inputData: Array[(Int, Int)] = Array.empty): Unit = { + val cond = 'k % 2 === 0 + val inputNode = new DummyNode(kvIntAttributes, inputData) + val filterNode = new FilterNode(conf, cond, inputNode) + val resolvedNode = resolveExpressions(filterNode) + val expectedOutput = inputData.filter { case (k, _) => k % 2 == 0 } + val actualOutput = resolvedNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) + } + + test("empty") { + testFilter() + } + + test("basic") { + testFilter((1 to 100).map { i => (i, i) }.toArray) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala new file mode 100644 index 000000000000..c30327185e16 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala @@ -0,0 +1,141 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. 
+* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.mockito.Mockito.{mock, when} + +import org.apache.spark.broadcast.TorrentBroadcast +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.{InterpretedMutableProjection, UnsafeProjection, Expression} +import org.apache.spark.sql.execution.joins.{HashedRelation, BuildLeft, BuildRight, BuildSide} + +class HashJoinNodeSuite extends LocalNodeTest { + + // Test all combinations of the two dimensions: with/out unsafe and build sides + private val buildSides = Seq(BuildLeft, BuildRight) + buildSides.foreach { buildSide => + testJoin(buildSide) + } + + /** + * Builds a [[HashedRelation]] based on a resolved `buildKeys` + * and a resolved `buildNode`. + */ + private def buildHashedRelation( + conf: SQLConf, + buildKeys: Seq[Expression], + buildNode: LocalNode): HashedRelation = { + + val buildSideKeyGenerator = UnsafeProjection.create(buildKeys, buildNode.output) + buildNode.prepare() + buildNode.open() + val hashedRelation = HashedRelation(buildNode, buildSideKeyGenerator) + buildNode.close() + + hashedRelation + } + + /** + * Test inner hash join with varying degrees of matches. + */ + private def testJoin(buildSide: BuildSide): Unit = { + val testNamePrefix = buildSide + val someData = (1 to 100).map { i => (i, "burger" + i) }.toArray + val conf = new SQLConf + + // Actual test body + def runTest(leftInput: Array[(Int, String)], rightInput: Array[(Int, String)]): Unit = { + val rightInputMap = rightInput.toMap + val leftNode = new DummyNode(joinNameAttributes, leftInput) + val rightNode = new DummyNode(joinNicknameAttributes, rightInput) + val makeBinaryHashJoinNode = (node1: LocalNode, node2: LocalNode) => { + val binaryHashJoinNode = + BinaryHashJoinNode(conf, Seq('id1), Seq('id2), buildSide, node1, node2) + resolveExpressions(binaryHashJoinNode) + } + val makeBroadcastJoinNode = (node1: LocalNode, node2: LocalNode) => { + val leftKeys = Seq('id1.attr) + val rightKeys = Seq('id2.attr) + // Figure out the build side and stream side. + val (buildNode, buildKeys, streamedNode, streamedKeys) = buildSide match { + case BuildLeft => (node1, leftKeys, node2, rightKeys) + case BuildRight => (node2, rightKeys, node1, leftKeys) + } + // Resolve the expressions of the build side and then create a HashedRelation. 
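+ // The relation is then wrapped in a mocked TorrentBroadcast whose `.value` returns it,
+ // so BroadcastHashJoinNode can consume it without a real SparkContext broadcast.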
+ val resolvedBuildNode = resolveExpressions(buildNode) + val resolvedBuildKeys = resolveExpressions(buildKeys, resolvedBuildNode) + val hashedRelation = buildHashedRelation(conf, resolvedBuildKeys, resolvedBuildNode) + val broadcastHashedRelation = mock(classOf[TorrentBroadcast[HashedRelation]]) + when(broadcastHashedRelation.value).thenReturn(hashedRelation) + + val hashJoinNode = + BroadcastHashJoinNode( + conf, + streamedKeys, + streamedNode, + buildSide, + resolvedBuildNode.output, + broadcastHashedRelation) + resolveExpressions(hashJoinNode) + } + + val expectedOutput = leftInput + .filter { case (k, _) => rightInputMap.contains(k) } + .map { case (k, v) => (k, v, k, rightInputMap(k)) } + + Seq(makeBinaryHashJoinNode, makeBroadcastJoinNode).foreach { makeNode => + val makeUnsafeNode = wrapForUnsafe(makeNode) + val hashJoinNode = makeUnsafeNode(leftNode, rightNode) + + val actualOutput = hashJoinNode.collect().map { row => + // (id, name, id, nickname) + (row.getInt(0), row.getString(1), row.getInt(2), row.getString(3)) + } + assert(actualOutput === expectedOutput) + } + } + + test(s"$testNamePrefix: empty") { + runTest(Array.empty, Array.empty) + runTest(someData, Array.empty) + runTest(Array.empty, someData) + } + + test(s"$testNamePrefix: no matches") { + val someIrrelevantData = (10000 to 100100).map { i => (i, "piper" + i) }.toArray + runTest(someData, Array.empty) + runTest(Array.empty, someData) + runTest(someData, someIrrelevantData) + runTest(someIrrelevantData, someData) + } + + test(s"$testNamePrefix: partial matches") { + val someOtherData = (50 to 150).map { i => (i, "finnegan" + i) }.toArray + runTest(someData, someOtherData) + runTest(someOtherData, someData) + } + + test(s"$testNamePrefix: full matches") { + val someSuperRelevantData = someData.map { case (k, v) => (k, "cooper" + v) }.toArray + runTest(someData, someSuperRelevantData) + runTest(someSuperRelevantData, someData) + } + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala new file mode 100644 index 000000000000..c0ad2021b204 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala @@ -0,0 +1,37 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.local + + +class IntersectNodeSuite extends LocalNodeTest { + + test("basic") { + val n = 100 + val leftData = (1 to n).filter { i => i % 2 == 0 }.map { i => (i, i) }.toArray + val rightData = (1 to n).filter { i => i % 3 == 0 }.map { i => (i, i) }.toArray + val leftNode = new DummyNode(kvIntAttributes, leftData) + val rightNode = new DummyNode(kvIntAttributes, rightData) + val intersectNode = new IntersectNode(conf, leftNode, rightNode) + val expectedOutput = leftData.intersect(rightData) + val actualOutput = intersectNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala new file mode 100644 index 000000000000..fb790636a368 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala @@ -0,0 +1,41 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + + +class LimitNodeSuite extends LocalNodeTest { + + private def testLimit(inputData: Array[(Int, Int)] = Array.empty, limit: Int = 10): Unit = { + val inputNode = new DummyNode(kvIntAttributes, inputData) + val limitNode = new LimitNode(conf, limit, inputNode) + val expectedOutput = inputData.take(limit) + val actualOutput = limitNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) + } + + test("empty") { + testLimit() + } + + test("basic") { + testLimit((1 to 100).map { i => (i, i) }.toArray, 20) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala new file mode 100644 index 000000000000..0d1ed99eec6c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala @@ -0,0 +1,73 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + + +class LocalNodeSuite extends LocalNodeTest { + private val data = (1 to 100).map { i => (i, i) }.toArray + + test("basic open, next, fetch, close") { + val node = new DummyNode(kvIntAttributes, data) + assert(!node.isOpen) + node.open() + assert(node.isOpen) + data.foreach { case (k, v) => + assert(node.next()) + // fetch should be idempotent + val fetched = node.fetch() + assert(node.fetch() === fetched) + assert(node.fetch() === fetched) + assert(node.fetch().numFields === 2) + assert(node.fetch().getInt(0) === k) + assert(node.fetch().getInt(1) === v) + } + assert(!node.next()) + node.close() + assert(!node.isOpen) + } + + test("asIterator") { + val node = new DummyNode(kvIntAttributes, data) + val iter = node.asIterator + node.open() + data.foreach { case (k, v) => + // hasNext should be idempotent + assert(iter.hasNext) + assert(iter.hasNext) + val item = iter.next() + assert(item.numFields === 2) + assert(item.getInt(0) === k) + assert(item.getInt(1) === v) + } + intercept[NoSuchElementException] { + iter.next() + } + node.close() + } + + test("collect") { + val node = new DummyNode(kvIntAttributes, data) + node.open() + val collected = node.collect() + assert(collected.size === data.size) + assert(collected.forall(_.size === 2)) + assert(collected.map { case row => (row.getInt(0), row.getInt(0)) } === data) + node.close() + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala new file mode 100644 index 000000000000..615c41709361 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala @@ -0,0 +1,88 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
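For readers unfamiliar with the LocalNode iteration contract that LocalNodeSuite exercises above, a minimal consumer sketch follows (outside the patch; it assumes, as the tests do, that fetch() returns the current InternalRow and stays stable until the next call to next()):

    import org.apache.spark.sql.catalyst.InternalRow

    def drainKeyValuePairs(node: LocalNode): Seq[(Int, Int)] = {
      val buffer = Seq.newBuilder[(Int, Int)]
      node.open()
      try {
        while (node.next()) {
          val row: InternalRow = node.fetch()  // idempotent between calls to next()
          buffer += ((row.getInt(0), row.getInt(1)))
        }
      } finally {
        node.close()
      }
      buffer.result()
    }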
+*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.{Expression, AttributeReference} +import org.apache.spark.sql.types.{IntegerType, StringType} + + +class LocalNodeTest extends SparkFunSuite { + + protected val conf: SQLConf = new SQLConf + protected val kvIntAttributes = Seq( + AttributeReference("k", IntegerType)(), + AttributeReference("v", IntegerType)()) + protected val joinNameAttributes = Seq( + AttributeReference("id1", IntegerType)(), + AttributeReference("name", StringType)()) + protected val joinNicknameAttributes = Seq( + AttributeReference("id2", IntegerType)(), + AttributeReference("nickname", StringType)()) + + /** + * Wrap a function processing two [[LocalNode]]s such that: + * (1) all input rows are automatically converted to unsafe rows + * (2) all output rows are automatically converted back to safe rows + */ + protected def wrapForUnsafe( + f: (LocalNode, LocalNode) => LocalNode): (LocalNode, LocalNode) => LocalNode = { + (left: LocalNode, right: LocalNode) => { + val _left = ConvertToUnsafeNode(conf, left) + val _right = ConvertToUnsafeNode(conf, right) + val r = f(_left, _right) + ConvertToSafeNode(conf, r) + } + } + + /** + * Recursively resolve all expressions in a [[LocalNode]] using the node's attributes. + */ + protected def resolveExpressions(outputNode: LocalNode): LocalNode = { + outputNode transform { + case node: LocalNode => + val inputMap = node.output.map { a => (a.name, a) }.toMap + node transformExpressions { + case UnresolvedAttribute(Seq(u)) => + inputMap.getOrElse(u, + sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) + } + } + } + + /** + * Resolve all expressions in `expressions` based on the `output` of `localNode`. + * It assumes that all expressions in the `localNode` are resolved. + */ + protected def resolveExpressions( + expressions: Seq[Expression], + localNode: LocalNode): Seq[Expression] = { + require(localNode.expressions.forall(_.resolved)) + val inputMap = localNode.output.map { a => (a.name, a) }.toMap + expressions.map { expression => + expression.transformUp { + case UnresolvedAttribute(Seq(u)) => + inputMap.getOrElse(u, + sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) + } + } + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala new file mode 100644 index 000000000000..45df2ea6552d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala @@ -0,0 +1,142 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
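As a usage illustration of the wrapForUnsafe helper defined above, the sketch below (outside the patch, written as if from a suite extending LocalNodeTest) wraps a UnionNode factory so that both children are fed unsafe rows and the result is converted back to safe rows before collection:

    val makeUnion = (left: LocalNode, right: LocalNode) => new UnionNode(conf, Seq(left, right))
    val makeUnsafeUnion = wrapForUnsafe(makeUnion)
    val unioned = makeUnsafeUnion(
      new DummyNode(kvIntAttributes, (1 to 3).map { i => (i, i) }.toArray),
      new DummyNode(kvIntAttributes, (4 to 6).map { i => (i, i) }.toArray))
    // collect() returns safe rows again thanks to the outer ConvertToSafeNode.
    val rows = unioned.collect().map { row => (row.getInt(0), row.getInt(1)) }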
+* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} + + +class NestedLoopJoinNodeSuite extends LocalNodeTest { + + // Test all combinations of the three dimensions: with/out unsafe, build sides, and join types + private val buildSides = Seq(BuildLeft, BuildRight) + private val joinTypes = Seq(LeftOuter, RightOuter, FullOuter) + buildSides.foreach { buildSide => + joinTypes.foreach { joinType => + testJoin(buildSide, joinType) + } + } + + /** + * Test outer nested loop joins with varying degrees of matches. + */ + private def testJoin(buildSide: BuildSide, joinType: JoinType): Unit = { + val testNamePrefix = s"$buildSide / $joinType" + val someData = (1 to 100).map { i => (i, "burger" + i) }.toArray + val conf = new SQLConf + + // Actual test body + def runTest( + joinType: JoinType, + leftInput: Array[(Int, String)], + rightInput: Array[(Int, String)]): Unit = { + val leftNode = new DummyNode(joinNameAttributes, leftInput) + val rightNode = new DummyNode(joinNicknameAttributes, rightInput) + val cond = 'id1 === 'id2 + val makeNode = (node1: LocalNode, node2: LocalNode) => { + resolveExpressions( + new NestedLoopJoinNode(conf, node1, node2, buildSide, joinType, Some(cond))) + } + val makeUnsafeNode = wrapForUnsafe(makeNode) + val hashJoinNode = makeUnsafeNode(leftNode, rightNode) + val expectedOutput = generateExpectedOutput(leftInput, rightInput, joinType) + val actualOutput = hashJoinNode.collect().map { row => + // ( + // id, name, + // id, nickname + // ) + ( + Option(row.get(0)).map(_.asInstanceOf[Int]), Option(row.getString(1)), + Option(row.get(2)).map(_.asInstanceOf[Int]), Option(row.getString(3)) + ) + } + assert(actualOutput.toSet === expectedOutput.toSet) + } + + test(s"$testNamePrefix: empty") { + runTest(joinType, Array.empty, Array.empty) + } + + test(s"$testNamePrefix: no matches") { + val someIrrelevantData = (10000 to 10100).map { i => (i, "piper" + i) }.toArray + runTest(joinType, someData, Array.empty) + runTest(joinType, Array.empty, someData) + runTest(joinType, someData, someIrrelevantData) + runTest(joinType, someIrrelevantData, someData) + } + + test(s"$testNamePrefix: partial matches") { + val someOtherData = (50 to 150).map { i => (i, "finnegan" + i) }.toArray + runTest(joinType, someData, someOtherData) + runTest(joinType, someOtherData, someData) + } + + test(s"$testNamePrefix: full matches") { + val someSuperRelevantData = someData.map { case (k, v) => (k, "cooper" + v) } + runTest(joinType, someData, someSuperRelevantData) + runTest(joinType, someSuperRelevantData, someData) + } + } + + /** + * Helper method to generate the expected output of a test based on the join type. 
+ */ + private def generateExpectedOutput( + leftInput: Array[(Int, String)], + rightInput: Array[(Int, String)], + joinType: JoinType): Array[(Option[Int], Option[String], Option[Int], Option[String])] = { + joinType match { + case LeftOuter => + val rightInputMap = rightInput.toMap + leftInput.map { case (k, v) => + val rightKey = rightInputMap.get(k).map { _ => k } + val rightValue = rightInputMap.get(k) + (Some(k), Some(v), rightKey, rightValue) + } + + case RightOuter => + val leftInputMap = leftInput.toMap + rightInput.map { case (k, v) => + val leftKey = leftInputMap.get(k).map { _ => k } + val leftValue = leftInputMap.get(k) + (leftKey, leftValue, Some(k), Some(v)) + } + + case FullOuter => + val leftInputMap = leftInput.toMap + val rightInputMap = rightInput.toMap + val leftOutput = leftInput.map { case (k, v) => + val rightKey = rightInputMap.get(k).map { _ => k } + val rightValue = rightInputMap.get(k) + (Some(k), Some(v), rightKey, rightValue) + } + val rightOutput = rightInput.map { case (k, v) => + val leftKey = leftInputMap.get(k).map { _ => k } + val leftValue = leftInputMap.get(k) + (leftKey, leftValue, Some(k), Some(v)) + } + (leftOutput ++ rightOutput).distinct + + case other => + throw new IllegalArgumentException(s"Join type $other is not applicable") + } + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala new file mode 100644 index 000000000000..02ecb23d34b2 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala @@ -0,0 +1,49 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
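To make the generateExpectedOutput logic above easier to follow, here is a small worked example (plain Scala, hypothetical inputs) of the LeftOuter branch: every left row is kept, and the right-hand columns are None when the key has no match:

    val leftInput = Array((1, "a"), (2, "b"))
    val rightInputMap = Array((2, "x"), (3, "y")).toMap
    val leftOuterExpected = leftInput.map { case (k, v) =>
      (Some(k), Some(v), rightInputMap.get(k).map(_ => k), rightInputMap.get(k))
    }
    // Array((Some(1), Some("a"), None, None), (Some(2), Some("b"), Some(2), Some("x")))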
+*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, NamedExpression} +import org.apache.spark.sql.types.{IntegerType, StringType} + + +class ProjectNodeSuite extends LocalNodeTest { + private val pieAttributes = Seq( + AttributeReference("id", IntegerType)(), + AttributeReference("age", IntegerType)(), + AttributeReference("name", StringType)()) + + private def testProject(inputData: Array[(Int, Int, String)] = Array.empty): Unit = { + val inputNode = new DummyNode(pieAttributes, inputData) + val columns = Seq[NamedExpression](inputNode.output(0), inputNode.output(2)) + val projectNode = new ProjectNode(conf, columns, inputNode) + val expectedOutput = inputData.map { case (id, age, name) => (id, name) } + val actualOutput = projectNode.collect().map { case row => + (row.getInt(0), row.getString(1)) + } + assert(actualOutput === expectedOutput) + } + + test("empty") { + testProject() + } + + test("basic") { + testProject((1 to 100).map { i => (i, i + 1, "pie" + i) }.toArray) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala new file mode 100644 index 000000000000..a3e83bbd5145 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
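A further hypothetical ProjectNode example in the same style as the suite above, selecting the age and name columns instead of id and name:

    val inputNode = new DummyNode(pieAttributes, Array((1, 30, "apple"), (2, 40, "banana")))
    val ageAndName = Seq[NamedExpression](inputNode.output(1), inputNode.output(2))
    val projectNode = new ProjectNode(conf, ageAndName, inputNode)
    val projected = projectNode.collect().map { row => (row.getInt(0), row.getString(1)) }
    // projected == Seq((30, "apple"), (40, "banana"))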
+ */ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler} + + +class SampleNodeSuite extends LocalNodeTest { + + private def testSample(withReplacement: Boolean): Unit = { + val seed = 0L + val lowerb = 0.0 + val upperb = 0.3 + val maybeOut = if (withReplacement) "" else "out" + test(s"with$maybeOut replacement") { + val inputData = (1 to 1000).map { i => (i, i) }.toArray + val inputNode = new DummyNode(kvIntAttributes, inputData) + val sampleNode = new SampleNode(conf, lowerb, upperb, withReplacement, seed, inputNode) + val sampler = + if (withReplacement) { + new PoissonSampler[(Int, Int)](upperb - lowerb, useGapSamplingIfPossible = false) + } else { + new BernoulliCellSampler[(Int, Int)](lowerb, upperb) + } + sampler.setSeed(seed) + val expectedOutput = sampler.sample(inputData.iterator).toArray + val actualOutput = sampleNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) + } + } + + testSample(withReplacement = true) + testSample(withReplacement = false) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala new file mode 100644 index 000000000000..42ebc7bfcaad --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
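The expected output in SampleNodeSuite above relies on the samplers being deterministic for a fixed seed; a minimal sketch (outside the patch) of that assumption in isolation:

    import org.apache.spark.util.random.BernoulliCellSampler

    val first = new BernoulliCellSampler[Int](0.0, 0.3)
    val second = new BernoulliCellSampler[Int](0.0, 0.3)
    first.setSeed(0L)
    second.setSeed(0L)
    // Same bounds and same seed produce identical samples, which is what lets the
    // test precompute its expected output with an independent sampler instance.
    assert(first.sample((1 to 1000).iterator).toSeq == second.sample((1 to 1000).iterator).toSeq)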
+ */ + +package org.apache.spark.sql.execution.local + +import scala.util.Random + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.SortOrder + + +class TakeOrderedAndProjectNodeSuite extends LocalNodeTest { + + private def testTakeOrderedAndProject(desc: Boolean): Unit = { + val limit = 10 + val ascOrDesc = if (desc) "desc" else "asc" + test(ascOrDesc) { + val inputData = Random.shuffle((1 to 100).toList).map { i => (i, i) }.toArray + val inputNode = new DummyNode(kvIntAttributes, inputData) + val firstColumn = inputNode.output(0) + val sortDirection = if (desc) Descending else Ascending + val sortOrder = SortOrder(firstColumn, sortDirection) + val takeOrderAndProjectNode = new TakeOrderedAndProjectNode( + conf, limit, Seq(sortOrder), Some(Seq(firstColumn)), inputNode) + val expectedOutput = inputData + .map { case (k, _) => k } + .sortBy { k => k * (if (desc) -1 else 1) } + .take(limit) + val actualOutput = takeOrderAndProjectNode.collect().map { row => row.getInt(0) } + assert(actualOutput === expectedOutput) + } + } + + testTakeOrderedAndProject(desc = false) + testTakeOrderedAndProject(desc = true) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala new file mode 100644 index 000000000000..666b0235c061 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala @@ -0,0 +1,55 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.local + + +class UnionNodeSuite extends LocalNodeTest { + + private def testUnion(inputData: Seq[Array[(Int, Int)]]): Unit = { + val inputNodes = inputData.map { data => + new DummyNode(kvIntAttributes, data) + } + val unionNode = new UnionNode(conf, inputNodes) + val expectedOutput = inputData.flatten + val actualOutput = unionNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) + } + + test("empty") { + testUnion(Seq(Array.empty)) + testUnion(Seq(Array.empty, Array.empty)) + } + + test("self") { + val data = (1 to 100).map { i => (i, i) }.toArray + testUnion(Seq(data)) + testUnion(Seq(data, data)) + testUnion(Seq(data, data, data)) + } + + test("basic") { + val zero = Array.empty[(Int, Int)] + val one = (1 to 100).map { i => (i, i) }.toArray + val two = (50 to 150).map { i => (i, i) }.toArray + val three = (800 to 900).map { i => (i, i) }.toArray + testUnion(Seq(zero, one, two, three)) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala new file mode 100644 index 000000000000..4339f7260dcb --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -0,0 +1,420 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.metric + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} + +import scala.collection.mutable + +import org.apache.xbean.asm5._ +import org.apache.xbean.asm5.Opcodes._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.SparkPlanInfo +import org.apache.spark.sql.execution.ui.SparkPlanGraph +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.Utils + + +class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { + import testImplicits._ + + test("LongSQLMetric should not box Long") { + val l = SQLMetrics.createLongMetric(sparkContext, "long") + val f = () => { + l += 1L + l.add(1L) + } + val cl = BoxingFinder.getClassReader(f.getClass) + val boxingFinder = new BoxingFinder() + cl.accept(boxingFinder, 0) + assert(boxingFinder.boxingInvokes.isEmpty, s"Found boxing: ${boxingFinder.boxingInvokes}") + } + + test("Normal accumulator should do boxing") { + // We need this test to make sure BoxingFinder works. 
+ val l = sparkContext.accumulator(0L) + val f = () => { l += 1L } + val cl = BoxingFinder.getClassReader(f.getClass) + val boxingFinder = new BoxingFinder() + cl.accept(boxingFinder, 0) + assert(boxingFinder.boxingInvokes.nonEmpty, "Found find boxing in this test") + } + + /** + * Call `df.collect()` and verify if the collected metrics are same as "expectedMetrics". + * + * @param df `DataFrame` to run + * @param expectedNumOfJobs number of jobs that will run + * @param expectedMetrics the expected metrics. The format is + * `nodeId -> (operatorName, metric name -> metric value)`. + */ + private def testSparkPlanMetrics( + df: DataFrame, + expectedNumOfJobs: Int, + expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = { + val previousExecutionIds = sqlContext.listener.executionIdToData.keySet + df.collect() + sparkContext.listenerBus.waitUntilEmpty(10000) + val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) + assert(executionIds.size === 1) + val executionId = executionIds.head + val jobs = sqlContext.listener.getExecution(executionId).get.jobs + // Use "<=" because there is a race condition that we may miss some jobs + // TODO Change it to "=" once we fix the race condition that missing the JobStarted event. + assert(jobs.size <= expectedNumOfJobs) + if (jobs.size == expectedNumOfJobs) { + // If we can track all jobs, check the metric values + val metricValues = sqlContext.listener.getExecutionMetrics(executionId) + val actualMetrics = SparkPlanGraph(SparkPlanInfo.fromSparkPlan( + df.queryExecution.executedPlan)).nodes.filter { node => + expectedMetrics.contains(node.id) + }.map { node => + val nodeMetrics = node.metrics.map { metric => + val metricValue = metricValues(metric.accumulatorId) + (metric.name, metricValue) + }.toMap + (node.id, node.name -> nodeMetrics) + }.toMap + + assert(expectedMetrics.keySet === actualMetrics.keySet) + for (nodeId <- expectedMetrics.keySet) { + val (expectedNodeName, expectedMetricsMap) = expectedMetrics(nodeId) + val (actualNodeName, actualMetricsMap) = actualMetrics(nodeId) + assert(expectedNodeName === actualNodeName) + for (metricName <- expectedMetricsMap.keySet) { + assert(expectedMetricsMap(metricName).toString === actualMetricsMap(metricName)) + } + } + } else { + // TODO Remove this "else" once we fix the race condition that missing the JobStarted event. + // Since we cannot track all jobs, the metric values could be wrong and we should not check + // them. + logWarning("Due to a race condition, we miss some jobs and cannot verify the metric values") + } + } + + test("Project metrics") { + // Assume the execution plan is + // PhysicalRDD(nodeId = 1) -> Project(nodeId = 0) + val df = person.select('name) + testSparkPlanMetrics(df, 1, Map( + 0L -> ("Project", Map( + "number of rows" -> 2L))) + ) + } + + test("Filter metrics") { + // Assume the execution plan is + // PhysicalRDD(nodeId = 1) -> Filter(nodeId = 0) + val df = person.filter('age < 25) + testSparkPlanMetrics(df, 1, Map( + 0L -> ("Filter", Map( + "number of input rows" -> 2L, + "number of output rows" -> 1L))) + ) + } + + test("TungstenAggregate metrics") { + // Assume the execution plan is + // ... 
-> TungstenAggregate(nodeId = 2) -> Exchange(nodeId = 1) + // -> TungstenAggregate(nodeId = 0) + val df = testData2.groupBy().count() // 2 partitions + testSparkPlanMetrics(df, 1, Map( + 2L -> ("TungstenAggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 2L)), + 0L -> ("TungstenAggregate", Map( + "number of input rows" -> 2L, + "number of output rows" -> 1L))) + ) + + // 2 partitions and each partition contains 2 keys + val df2 = testData2.groupBy('a).count() + testSparkPlanMetrics(df2, 1, Map( + 2L -> ("TungstenAggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 4L)), + 0L -> ("TungstenAggregate", Map( + "number of input rows" -> 4L, + "number of output rows" -> 3L))) + ) + } + + test("SortMergeJoin metrics") { + // Because SortMergeJoin may skip different rows if the number of partitions is different, this + // test should use the deterministic number of partitions. + val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.registerTempTable("testDataForJoin") + withTempTable("testDataForJoin") { + // Assume the execution plan is + // ... -> SortMergeJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = sqlContext.sql( + "SELECT * FROM testData2 JOIN testDataForJoin ON testData2.a = testDataForJoin.a") + testSparkPlanMetrics(df, 1, Map( + 1L -> ("SortMergeJoin", Map( + // It's 4 because we only read 3 rows in the first partition and 1 row in the second one + "number of left rows" -> 4L, + "number of right rows" -> 2L, + "number of output rows" -> 4L))) + ) + } + } + + test("SortMergeOuterJoin metrics") { + // Because SortMergeOuterJoin may skip different rows if the number of partitions is different, + // this test should use the deterministic number of partitions. + val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.registerTempTable("testDataForJoin") + withTempTable("testDataForJoin") { + // Assume the execution plan is + // ... -> SortMergeOuterJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = sqlContext.sql( + "SELECT * FROM testData2 left JOIN testDataForJoin ON testData2.a = testDataForJoin.a") + testSparkPlanMetrics(df, 1, Map( + 1L -> ("SortMergeOuterJoin", Map( + // It's 4 because we only read 3 rows in the first partition and 1 row in the second one + "number of left rows" -> 6L, + "number of right rows" -> 2L, + "number of output rows" -> 8L))) + ) + + val df2 = sqlContext.sql( + "SELECT * FROM testDataForJoin right JOIN testData2 ON testData2.a = testDataForJoin.a") + testSparkPlanMetrics(df2, 1, Map( + 1L -> ("SortMergeOuterJoin", Map( + // It's 4 because we only read 3 rows in the first partition and 1 row in the second one + "number of left rows" -> 2L, + "number of right rows" -> 6L, + "number of output rows" -> 8L))) + ) + } + } + + test("BroadcastHashJoin metrics") { + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key", "value") + // Assume the execution plan is + // ... 
-> BroadcastHashJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = df1.join(broadcast(df2), "key") + testSparkPlanMetrics(df, 2, Map( + 1L -> ("BroadcastHashJoin", Map( + "number of left rows" -> 2L, + "number of right rows" -> 4L, + "number of output rows" -> 2L))) + ) + } + + test("BroadcastHashOuterJoin metrics") { + val df1 = Seq((1, "a"), (1, "b"), (4, "c")).toDF("key", "value") + val df2 = Seq((1, "a"), (1, "b"), (2, "c"), (3, "d")).toDF("key2", "value") + // Assume the execution plan is + // ... -> BroadcastHashOuterJoin(nodeId = 0) + val df = df1.join(broadcast(df2), $"key" === $"key2", "left_outer") + testSparkPlanMetrics(df, 2, Map( + 0L -> ("BroadcastHashOuterJoin", Map( + "number of left rows" -> 3L, + "number of right rows" -> 4L, + "number of output rows" -> 5L))) + ) + + val df3 = df1.join(broadcast(df2), $"key" === $"key2", "right_outer") + testSparkPlanMetrics(df3, 2, Map( + 0L -> ("BroadcastHashOuterJoin", Map( + "number of left rows" -> 3L, + "number of right rows" -> 4L, + "number of output rows" -> 6L))) + ) + } + + test("BroadcastNestedLoopJoin metrics") { + val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.registerTempTable("testDataForJoin") + withTempTable("testDataForJoin") { + // Assume the execution plan is + // ... -> BroadcastNestedLoopJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = sqlContext.sql( + "SELECT * FROM testData2 left JOIN testDataForJoin ON " + + "testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a") + testSparkPlanMetrics(df, 3, Map( + 1L -> ("BroadcastNestedLoopJoin", Map( + "number of left rows" -> 12L, // left needs to be scanned twice + "number of right rows" -> 2L, + "number of output rows" -> 12L))) + ) + } + } + + test("BroadcastLeftSemiJoinHash metrics") { + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value") + // Assume the execution plan is + // ... -> BroadcastLeftSemiJoinHash(nodeId = 0) + val df = df1.join(broadcast(df2), $"key" === $"key2", "leftsemi") + testSparkPlanMetrics(df, 2, Map( + 0L -> ("BroadcastLeftSemiJoinHash", Map( + "number of left rows" -> 2L, + "number of right rows" -> 4L, + "number of output rows" -> 2L))) + ) + } + + test("LeftSemiJoinHash metrics") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value") + // Assume the execution plan is + // ... -> LeftSemiJoinHash(nodeId = 0) + val df = df1.join(df2, $"key" === $"key2", "leftsemi") + testSparkPlanMetrics(df, 1, Map( + 0L -> ("LeftSemiJoinHash", Map( + "number of left rows" -> 2L, + "number of right rows" -> 4L, + "number of output rows" -> 2L))) + ) + } + } + + test("LeftSemiJoinBNL metrics") { + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value") + // Assume the execution plan is + // ... 
-> LeftSemiJoinBNL(nodeId = 0) + val df = df1.join(df2, $"key" < $"key2", "leftsemi") + testSparkPlanMetrics(df, 2, Map( + 0L -> ("LeftSemiJoinBNL", Map( + "number of left rows" -> 2L, + "number of right rows" -> 4L, + "number of output rows" -> 2L))) + ) + } + + test("CartesianProduct metrics") { + val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.registerTempTable("testDataForJoin") + withTempTable("testDataForJoin") { + // Assume the execution plan is + // ... -> CartesianProduct(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = sqlContext.sql( + "SELECT * FROM testData2 JOIN testDataForJoin") + testSparkPlanMetrics(df, 1, Map( + 1L -> ("CartesianProduct", Map( + "number of left rows" -> 12L, // left needs to be scanned twice + "number of right rows" -> 4L, // right is read twice + "number of output rows" -> 12L))) + ) + } + } + + test("save metrics") { + withTempPath { file => + val previousExecutionIds = sqlContext.listener.executionIdToData.keySet + // Assume the execution plan is + // PhysicalRDD(nodeId = 0) + person.select('name).write.format("json").save(file.getAbsolutePath) + sparkContext.listenerBus.waitUntilEmpty(10000) + val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) + assert(executionIds.size === 1) + val executionId = executionIds.head + val jobs = sqlContext.listener.getExecution(executionId).get.jobs + // Use "<=" because there is a race condition that we may miss some jobs + // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. + assert(jobs.size <= 1) + val metricValues = sqlContext.listener.getExecutionMetrics(executionId) + // Because "save" will create a new DataFrame internally, we cannot get the real metric id. + // However, we still can check the value. + assert(metricValues.values.toSeq === Seq("2")) + } + } + +} + +private case class MethodIdentifier[T](cls: Class[T], name: String, desc: String) + +/** + * If `method` is null, search all methods of this class recursively to find if they do some boxing. + * If `method` is specified, only search this method of the class to speed up the searching. + * + * This method will skip the methods in `visitedMethods` to avoid potential infinite cycles. + */ +private class BoxingFinder( + method: MethodIdentifier[_] = null, + val boxingInvokes: mutable.Set[String] = mutable.Set.empty, + visitedMethods: mutable.Set[MethodIdentifier[_]] = mutable.Set.empty) + extends ClassVisitor(ASM5) { + + private val primitiveBoxingClassName = + Set("java/lang/Long", + "java/lang/Double", + "java/lang/Integer", + "java/lang/Float", + "java/lang/Short", + "java/lang/Character", + "java/lang/Byte", + "java/lang/Boolean") + + override def visitMethod( + access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): + MethodVisitor = { + if (method != null && (method.name != name || method.desc != desc)) { + // If method is specified, skip other methods. 
+ return new MethodVisitor(ASM5) {} + } + + new MethodVisitor(ASM5) { + override def visitMethodInsn( + op: Int, owner: String, name: String, desc: String, itf: Boolean) { + if (op == INVOKESPECIAL && name == "<init>" || op == INVOKESTATIC && name == "valueOf") { + if (primitiveBoxingClassName.contains(owner)) { + // Find boxing methods, e.g., new java.lang.Long(l) or java.lang.Long.valueOf(l) + boxingInvokes.add(s"$owner.$name") + } + } else { + // scalastyle:off classforname + val classOfMethodOwner = Class.forName(owner.replace('/', '.'), false, + Thread.currentThread.getContextClassLoader) + // scalastyle:on classforname + val m = MethodIdentifier(classOfMethodOwner, name, desc) + if (!visitedMethods.contains(m)) { + // Keep track of visited methods to avoid potential infinite cycles + visitedMethods += m + val cl = BoxingFinder.getClassReader(classOfMethodOwner) + visitedMethods += m + cl.accept(new BoxingFinder(m, boxingInvokes, visitedMethods), 0) + } + } + } + } + } +} + +private object BoxingFinder { + + def getClassReader(cls: Class[_]): ClassReader = { + val className = cls.getName.replaceFirst("^.*\\.", "") + ".class" + val resourceStream = cls.getResourceAsStream(className) + val baos = new ByteArrayOutputStream(128) + // Copy data over, before delegating to ClassReader - + // else we can run out of open file handles. + Utils.copyStream(resourceStream, baos, true) + new ClassReader(new ByteArrayInputStream(baos.toByteArray)) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala new file mode 100644 index 000000000000..12a4e1356fed --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -0,0 +1,374 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
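For orientation, the driving pattern for BoxingFinder, mirroring the two SQLMetricsSuite tests above and shown here outside the patch, is to read the closure's bytecode, visit it, and then inspect boxingInvokes:

    val closure = () => {
      var sum = 0L
      sum += 1L  // a local primitive; whether this boxes depends on how the closure is compiled
    }
    val classReader = BoxingFinder.getClassReader(closure.getClass)
    val finder = new BoxingFinder()
    classReader.accept(finder, 0)
    // A non-empty set means calls such as java.lang.Long.valueOf or new java.lang.Long were found.
    println(finder.boxingInvokes)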
+ */ + +package org.apache.spark.sql.execution.ui + +import java.util.Properties + +import org.apache.spark.{SparkException, SparkContext, SparkConf, SparkFunSuite} +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.scheduler._ +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.execution.{SparkPlanInfo, SQLExecution} +import org.apache.spark.sql.execution.metric.LongSQLMetricValue +import org.apache.spark.sql.test.SharedSQLContext + +class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { + import testImplicits._ + + private def createTestDataFrame: DataFrame = { + Seq( + (1, 1), + (2, 2) + ).toDF().filter("_1 > 1") + } + + private def createProperties(executionId: Long): Properties = { + val properties = new Properties() + properties.setProperty(SQLExecution.EXECUTION_ID_KEY, executionId.toString) + properties + } + + private def createStageInfo(stageId: Int, attemptId: Int): StageInfo = new StageInfo( + stageId = stageId, + attemptId = attemptId, + // The following fields are not used in tests + name = "", + numTasks = 0, + rddInfos = Nil, + parentIds = Nil, + details = "" + ) + + private def createTaskInfo(taskId: Int, attemptNumber: Int): TaskInfo = new TaskInfo( + taskId = taskId, + attemptNumber = attemptNumber, + // The following fields are not used in tests + index = 0, + launchTime = 0, + executorId = "", + host = "", + taskLocality = null, + speculative = false + ) + + private def createTaskMetrics(accumulatorUpdates: Map[Long, Long]): TaskMetrics = { + val metrics = new TaskMetrics + metrics.setAccumulatorsUpdater(() => accumulatorUpdates.mapValues(new LongSQLMetricValue(_))) + metrics.updateAccumulators() + metrics + } + + test("basic") { + def checkAnswer(actual: Map[Long, String], expected: Map[Long, Long]): Unit = { + assert(actual === expected.mapValues(_.toString)) + } + + val listener = new SQLListener(sqlContext.sparkContext.conf) + val executionId = 0 + val df = createTestDataFrame + val accumulatorIds = + SparkPlanGraph(SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan)) + .nodes.flatMap(_.metrics.map(_.accumulatorId)) + // Assume all accumulators are long + var accumulatorValue = 0L + val accumulatorUpdates = accumulatorIds.map { id => + accumulatorValue += 1L + (id, accumulatorValue) + }.toMap + + listener.onOtherEvent(SparkListenerSQLExecutionStart( + executionId, + "test", + "test", + df.queryExecution.toString, + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + System.currentTimeMillis())) + + val executionUIData = listener.executionIdToData(0) + + listener.onJobStart(SparkListenerJobStart( + jobId = 0, + time = System.currentTimeMillis(), + stageInfos = Seq( + createStageInfo(0, 0), + createStageInfo(1, 0) + ), + createProperties(executionId))) + listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(0, 0))) + + assert(listener.getExecutionMetrics(0).isEmpty) + + listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( + // (task id, stage id, stage attempt, metrics) + (0L, 0, 0, createTaskMetrics(accumulatorUpdates)), + (1L, 0, 0, createTaskMetrics(accumulatorUpdates)) + ))) + + checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2)) + + listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( + // (task id, stage id, stage attempt, metrics) + (0L, 0, 0, createTaskMetrics(accumulatorUpdates)), + (1L, 0, 0, createTaskMetrics(accumulatorUpdates.mapValues(_ * 2))) + ))) + + 
checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 3)) + + // Retrying a stage should reset the metrics + listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(0, 1))) + + listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( + // (task id, stage id, stage attempt, metrics) + (0L, 0, 1, createTaskMetrics(accumulatorUpdates)), + (1L, 0, 1, createTaskMetrics(accumulatorUpdates)) + ))) + + checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2)) + + // Ignore the task end for the first attempt + listener.onTaskEnd(SparkListenerTaskEnd( + stageId = 0, + stageAttemptId = 0, + taskType = "", + reason = null, + createTaskInfo(0, 0), + createTaskMetrics(accumulatorUpdates.mapValues(_ * 100)))) + + checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2)) + + // Finish two tasks + listener.onTaskEnd(SparkListenerTaskEnd( + stageId = 0, + stageAttemptId = 1, + taskType = "", + reason = null, + createTaskInfo(0, 0), + createTaskMetrics(accumulatorUpdates.mapValues(_ * 2)))) + listener.onTaskEnd(SparkListenerTaskEnd( + stageId = 0, + stageAttemptId = 1, + taskType = "", + reason = null, + createTaskInfo(1, 0), + createTaskMetrics(accumulatorUpdates.mapValues(_ * 3)))) + + checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 5)) + + // Summit a new stage + listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(1, 0))) + + listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( + // (task id, stage id, stage attempt, metrics) + (0L, 1, 0, createTaskMetrics(accumulatorUpdates)), + (1L, 1, 0, createTaskMetrics(accumulatorUpdates)) + ))) + + checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 7)) + + // Finish two tasks + listener.onTaskEnd(SparkListenerTaskEnd( + stageId = 1, + stageAttemptId = 0, + taskType = "", + reason = null, + createTaskInfo(0, 0), + createTaskMetrics(accumulatorUpdates.mapValues(_ * 3)))) + listener.onTaskEnd(SparkListenerTaskEnd( + stageId = 1, + stageAttemptId = 0, + taskType = "", + reason = null, + createTaskInfo(1, 0), + createTaskMetrics(accumulatorUpdates.mapValues(_ * 3)))) + + checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 11)) + + assert(executionUIData.runningJobs === Seq(0)) + assert(executionUIData.succeededJobs.isEmpty) + assert(executionUIData.failedJobs.isEmpty) + + listener.onJobEnd(SparkListenerJobEnd( + jobId = 0, + time = System.currentTimeMillis(), + JobSucceeded + )) + listener.onOtherEvent(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) + + assert(executionUIData.runningJobs.isEmpty) + assert(executionUIData.succeededJobs === Seq(0)) + assert(executionUIData.failedJobs.isEmpty) + + checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 11)) + } + + test("onExecutionEnd happens before onJobEnd(JobSucceeded)") { + val listener = new SQLListener(sqlContext.sparkContext.conf) + val executionId = 0 + val df = createTestDataFrame + listener.onOtherEvent(SparkListenerSQLExecutionStart( + executionId, + "test", + "test", + df.queryExecution.toString, + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + System.currentTimeMillis())) + listener.onJobStart(SparkListenerJobStart( + jobId = 0, + time = System.currentTimeMillis(), + stageInfos = Nil, + createProperties(executionId))) + listener.onOtherEvent(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) + 
listener.onJobEnd(SparkListenerJobEnd( + jobId = 0, + time = System.currentTimeMillis(), + JobSucceeded + )) + + val executionUIData = listener.executionIdToData(0) + assert(executionUIData.runningJobs.isEmpty) + assert(executionUIData.succeededJobs === Seq(0)) + assert(executionUIData.failedJobs.isEmpty) + } + + test("onExecutionEnd happens before multiple onJobEnd(JobSucceeded)s") { + val listener = new SQLListener(sqlContext.sparkContext.conf) + val executionId = 0 + val df = createTestDataFrame + listener.onOtherEvent(SparkListenerSQLExecutionStart( + executionId, + "test", + "test", + df.queryExecution.toString, + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + System.currentTimeMillis())) + listener.onJobStart(SparkListenerJobStart( + jobId = 0, + time = System.currentTimeMillis(), + stageInfos = Nil, + createProperties(executionId))) + listener.onJobEnd(SparkListenerJobEnd( + jobId = 0, + time = System.currentTimeMillis(), + JobSucceeded + )) + + listener.onJobStart(SparkListenerJobStart( + jobId = 1, + time = System.currentTimeMillis(), + stageInfos = Nil, + createProperties(executionId))) + listener.onOtherEvent(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) + listener.onJobEnd(SparkListenerJobEnd( + jobId = 1, + time = System.currentTimeMillis(), + JobSucceeded + )) + + val executionUIData = listener.executionIdToData(0) + assert(executionUIData.runningJobs.isEmpty) + assert(executionUIData.succeededJobs.sorted === Seq(0, 1)) + assert(executionUIData.failedJobs.isEmpty) + } + + test("onExecutionEnd happens before onJobEnd(JobFailed)") { + val listener = new SQLListener(sqlContext.sparkContext.conf) + val executionId = 0 + val df = createTestDataFrame + listener.onOtherEvent(SparkListenerSQLExecutionStart( + executionId, + "test", + "test", + df.queryExecution.toString, + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + System.currentTimeMillis())) + listener.onJobStart(SparkListenerJobStart( + jobId = 0, + time = System.currentTimeMillis(), + stageInfos = Seq.empty, + createProperties(executionId))) + listener.onOtherEvent(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) + listener.onJobEnd(SparkListenerJobEnd( + jobId = 0, + time = System.currentTimeMillis(), + JobFailed(new RuntimeException("Oops")) + )) + + val executionUIData = listener.executionIdToData(0) + assert(executionUIData.runningJobs.isEmpty) + assert(executionUIData.succeededJobs.isEmpty) + assert(executionUIData.failedJobs === Seq(0)) + } + + test("SPARK-11126: no memory leak when running non SQL jobs") { + val previousStageNumber = sqlContext.listener.stageIdToStageMetrics.size + sqlContext.sparkContext.parallelize(1 to 10).foreach(i => ()) + sqlContext.sparkContext.listenerBus.waitUntilEmpty(10000) + // listener should ignore the non SQL stage + assert(sqlContext.listener.stageIdToStageMetrics.size == previousStageNumber) + + sqlContext.sparkContext.parallelize(1 to 10).toDF().foreach(i => ()) + sqlContext.sparkContext.listenerBus.waitUntilEmpty(10000) + // listener should save the SQL stage + assert(sqlContext.listener.stageIdToStageMetrics.size == previousStageNumber + 1) + } + +} + +class SQLListenerMemoryLeakSuite extends SparkFunSuite { + + test("no memory leak") { + val conf = new SparkConf() + .setMaster("local") + .setAppName("test") + .set("spark.task.maxFailures", "1") // Don't retry the tasks to run this test quickly + .set("spark.sql.ui.retainedExecutions", "50") // Set it to 50 to run this test quickly + val sc = new 
SparkContext(conf) + try { + SQLContext.clearSqlListener() + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + // Run 100 successful executions and 100 failed executions. + // Each execution only has one job and one stage. + for (i <- 0 until 100) { + val df = Seq( + (1, 1), + (2, 2) + ).toDF() + df.collect() + try { + df.foreach(_ => throw new RuntimeException("Oops")) + } catch { + case e: SparkException => // This is expected for a failed job + } + } + sc.listenerBus.waitUntilEmpty(10000) + assert(sqlContext.listener.getCompletedExecutions.size <= 50) + assert(sqlContext.listener.getFailedExecutions.size <= 50) + // 50 for successful executions and 50 for failed executions + assert(sqlContext.listener.executionIdToData.size <= 100) + assert(sqlContext.listener.jobIdToExecutionId.size <= 100) + assert(sqlContext.listener.stageIdToStageMetrics.size <= 100) + } finally { + sc.stop() + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 42f2449afb0f..2b91f62c2fa2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -25,10 +25,13 @@ import org.h2.jdbc.JdbcSQLException import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class JDBCSuite extends SparkFunSuite with BeforeAndAfter { +class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext { + import testImplicits._ + val url = "jdbc:h2:mem:testdb0" val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass" var conn: java.sql.Connection = null @@ -42,10 +45,6 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { Some(StringType) } - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - import ctx.sql - before { Utils.classForName("org.h2.Driver") // Extra properties that will be specified for our database. 
We need these to test @@ -177,12 +176,14 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { } test("SELECT * WHERE (simple predicates)") { - assert(sql("SELECT * FROM foobar WHERE THEID < 1").collect().size === 0) - assert(sql("SELECT * FROM foobar WHERE THEID != 2").collect().size === 2) - assert(sql("SELECT * FROM foobar WHERE THEID = 1").collect().size === 1) - assert(sql("SELECT * FROM foobar WHERE NAME = 'fred'").collect().size === 1) - assert(sql("SELECT * FROM foobar WHERE NAME > 'fred'").collect().size === 2) - assert(sql("SELECT * FROM foobar WHERE NAME != 'fred'").collect().size === 2) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID < 1")).collect().size == 0) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID != 2")).collect().size == 2) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID = 1")).collect().size == 1) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME = 'fred'")).collect().size == 1) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME > 'fred'")).collect().size == 2) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME != 'fred'")).collect().size == 2) + assert(stripSparkFilter(sql("SELECT * FROM nulltypes WHERE A IS NULL")).collect().size == 1) + assert(stripSparkFilter(sql("SELECT * FROM nulltypes WHERE A IS NOT NULL")).collect().size == 0) } test("SELECT * WHERE (quoted strings)") { @@ -256,26 +257,26 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { } test("Basic API") { - assert(ctx.read.jdbc( + assert(sqlContext.read.jdbc( urlWithUserAndPass, "TEST.PEOPLE", new Properties).collect().length === 3) } test("Basic API with FetchSize") { val properties = new Properties properties.setProperty("fetchSize", "2") - assert(ctx.read.jdbc( + assert(sqlContext.read.jdbc( urlWithUserAndPass, "TEST.PEOPLE", properties).collect().length === 3) } test("Partitioning via JDBCPartitioningInfo API") { assert( - ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties) + sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties) .collect().length === 3) } test("Partitioning via list-of-where-clauses API") { val parts = Array[String]("THEID < 2", "THEID >= 2") - assert(ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties) + assert(sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties) .collect().length === 3) } @@ -331,9 +332,9 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { } test("test DATE types") { - val rows = ctx.read.jdbc( + val rows = sqlContext.read.jdbc( urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() - val cachedRows = ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + val cachedRows = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) .cache().collect() assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) assert(rows(1).getAs[java.sql.Date](1) === null) @@ -341,8 +342,8 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { } test("test DATE types in cache") { - val rows = ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() - ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + val rows = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() + sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) .cache().registerTempTable("mycached_date") val cachedRows = sql("select * from 
mycached_date").collect() assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) @@ -350,7 +351,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { } test("test types for null value") { - val rows = ctx.read.jdbc( + val rows = sqlContext.read.jdbc( urlWithUserAndPass, "TEST.NULLTYPES", new Properties).collect() assert((0 to 14).forall(i => rows(0).isNullAt(i))) } @@ -397,7 +398,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { test("Remap types via JdbcDialects") { JdbcDialects.registerDialect(testH2Dialect) - val df = ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties) + val df = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties) assert(df.schema.filter(_.dataType != org.apache.spark.sql.types.StringType).isEmpty) val rows = df.collect() assert(rows(0).get(0).isInstanceOf[String]) @@ -408,18 +409,24 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { test("Default jdbc dialect registration") { assert(JdbcDialects.get("jdbc:mysql://127.0.0.1/db") == MySQLDialect) assert(JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") == PostgresDialect) + assert(JdbcDialects.get("jdbc:db2://127.0.0.1/db") == DB2Dialect) + assert(JdbcDialects.get("jdbc:sqlserver://127.0.0.1/db") == MsSqlServerDialect) + assert(JdbcDialects.get("jdbc:derby:db") == DerbyDialect) assert(JdbcDialects.get("test.invalid") == NoopDialect) } test("quote column names by jdbc dialect") { val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") + val Derby = JdbcDialects.get("jdbc:derby:db") val columns = Seq("abc", "key") val MySQLColumns = columns.map(MySQL.quoteIdentifier(_)) val PostgresColumns = columns.map(Postgres.quoteIdentifier(_)) + val DerbyColumns = columns.map(Derby.quoteIdentifier(_)) assert(MySQLColumns === Seq("`abc`", "`key`")) assert(PostgresColumns === Seq(""""abc"""", """"key"""")) + assert(DerbyColumns === Seq(""""abc"""", """"key"""")) } test("Dialect unregister") { @@ -445,4 +452,49 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { assert(agg.getCatalystType(1, "", 1, null) === Some(StringType)) } + test("DB2Dialect type mapping") { + val db2Dialect = JdbcDialects.get("jdbc:db2://127.0.0.1/db") + assert(db2Dialect.getJDBCType(StringType).map(_.databaseTypeDefinition).get == "CLOB") + assert(db2Dialect.getJDBCType(BooleanType).map(_.databaseTypeDefinition).get == "CHAR(1)") + } + + test("PostgresDialect type mapping") { + val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") + assert(Postgres.getCatalystType(java.sql.Types.OTHER, "json", 1, null) === Some(StringType)) + assert(Postgres.getCatalystType(java.sql.Types.OTHER, "jsonb", 1, null) === Some(StringType)) + } + + test("DerbyDialect jdbc type mapping") { + val derbyDialect = JdbcDialects.get("jdbc:derby:db") + assert(derbyDialect.getJDBCType(StringType).map(_.databaseTypeDefinition).get == "CLOB") + assert(derbyDialect.getJDBCType(ByteType).map(_.databaseTypeDefinition).get == "SMALLINT") + assert(derbyDialect.getJDBCType(BooleanType).map(_.databaseTypeDefinition).get == "BOOLEAN") + } + + test("table exists query by jdbc dialect") { + val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") + val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") + val db2 = JdbcDialects.get("jdbc:db2://127.0.0.1/db") + val h2 = JdbcDialects.get(url) + val derby = JdbcDialects.get("jdbc:derby:db") + val table = "weblogs" + val defaultQuery = s"SELECT * FROM $table WHERE 
1=0" + val limitQuery = s"SELECT 1 FROM $table LIMIT 1" + assert(MySQL.getTableExistsQuery(table) == limitQuery) + assert(Postgres.getTableExistsQuery(table) == limitQuery) + assert(db2.getTableExistsQuery(table) == defaultQuery) + assert(h2.getTableExistsQuery(table) == defaultQuery) + assert(derby.getTableExistsQuery(table) == defaultQuery) + } + + test("Test DataFrame.where for Date and Timestamp") { + // Regression test for bug SPARK-11788 + val timestamp = java.sql.Timestamp.valueOf("2001-02-20 11:22:33.543543"); + val date = java.sql.Date.valueOf("1995-01-01") + val jdbcDf = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + val rows = jdbcDf.where($"B" > date && $"C" > timestamp).collect() + assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) + assert(rows(0).getAs[java.sql.Timestamp](2) + === java.sql.Timestamp.valueOf("2002-02-20 11:22:33.543543")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index 84b52ca2c733..e23ee6693133 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -22,12 +22,13 @@ import java.util.Properties import org.scalatest.BeforeAndAfter -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{SaveMode, Row} +import org.apache.spark.sql.{Row, SaveMode} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { +class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { + val url = "jdbc:h2:mem:testdb2" var conn: java.sql.Connection = null val url1 = "jdbc:h2:mem:testdb3" @@ -37,10 +38,6 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { properties.setProperty("password", "testPass") properties.setProperty("rowId", "false") - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - import ctx.sql - before { Utils.classForName("org.h2.Driver") conn = DriverManager.getConnection(url) @@ -58,14 +55,14 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { "create table test.people1 (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate() conn1.commit() - ctx.sql( + sql( s""" |CREATE TEMPORARY TABLE PEOPLE |USING org.apache.spark.sql.jdbc |OPTIONS (url '$url1', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) - ctx.sql( + sql( s""" |CREATE TEMPORARY TABLE PEOPLE1 |USING org.apache.spark.sql.jdbc @@ -78,8 +75,6 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { conn1.close() } - private lazy val sc = ctx.sparkContext - private lazy val arr2x2 = Array[Row](Row.apply("dave", 42), Row.apply("mary", 222)) private lazy val arr1x2 = Array[Row](Row.apply("fred", 3)) private lazy val schema2 = StructType( @@ -93,49 +88,50 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { StructField("seq", IntegerType) :: Nil) test("Basic CREATE") { - val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) + val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) df.write.jdbc(url, "TEST.BASICCREATETEST", new Properties) - assert(2 === ctx.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count) - assert(2 === ctx.read.jdbc(url, "TEST.BASICCREATETEST", new 
Properties).collect()(0).length) + assert(2 === sqlContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count) + assert( + 2 === sqlContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length) } test("CREATE with overwrite") { - val df = ctx.createDataFrame(sc.parallelize(arr2x3), schema3) - val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2) + val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x3), schema3) + val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr1x2), schema2) df.write.jdbc(url1, "TEST.DROPTEST", properties) - assert(2 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).count) - assert(3 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) + assert(2 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).count) + assert(3 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.DROPTEST", properties) - assert(1 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).count) - assert(2 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) + assert(1 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).count) + assert(2 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) } test("CREATE then INSERT to append") { - val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) - val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2) + val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr1x2), schema2) df.write.jdbc(url, "TEST.APPENDTEST", new Properties) df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties) - assert(3 === ctx.read.jdbc(url, "TEST.APPENDTEST", new Properties).count) - assert(2 === ctx.read.jdbc(url, "TEST.APPENDTEST", new Properties).collect()(0).length) + assert(3 === sqlContext.read.jdbc(url, "TEST.APPENDTEST", new Properties).count) + assert(2 === sqlContext.read.jdbc(url, "TEST.APPENDTEST", new Properties).collect()(0).length) } test("CREATE then INSERT to truncate") { - val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) - val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2) + val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr1x2), schema2) df.write.jdbc(url1, "TEST.TRUNCATETEST", properties) df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.TRUNCATETEST", properties) - assert(1 === ctx.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count) - assert(2 === ctx.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) + assert(1 === sqlContext.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count) + assert(2 === sqlContext.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) } test("Incompatible INSERT to append") { - val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) - val df2 = ctx.createDataFrame(sc.parallelize(arr2x3), schema3) + val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr2x3), schema3) df.write.jdbc(url, "TEST.INCOMPATIBLETEST", new Properties) intercept[org.apache.spark.SparkException] { @@ -144,15 +140,15 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { } test("INSERT to JDBC Datasource") { - ctx.sql("INSERT INTO TABLE PEOPLE1 SELECT 
* FROM PEOPLE") - assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count) - assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) + sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") + assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count) + assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } test("INSERT to JDBC Datasource with overwrite") { - ctx.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") - ctx.sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE") - assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count) - assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) + sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") + sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE") + assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count) + assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala deleted file mode 100644 index f98e4acafbf2..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala +++ /dev/null @@ -1,236 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.optimizer - -import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{Not, AtLeastNNulls} -import org.apache.spark.sql.catalyst.optimizer._ -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan} -import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.test.TestSQLContext - -/** This is the test suite for FilterNullsInJoinKey optimization rule. */ -class FilterNullsInJoinKeySuite extends PlanTest { - - // We add predicate pushdown rules at here to make sure we do not - // create redundant Filter operators. Also, because the attribute ordering of - // the Project operator added by ColumnPruning may be not deterministic - // (the ordering may depend on the testing environment), - // we first construct the plan with expected Filter operators and then - // run the optimizer to add the the Project for column pruning. 
- object Optimize extends RuleExecutor[LogicalPlan] { - val batches = - Batch("Subqueries", Once, - EliminateSubQueries) :: - Batch("Operator Optimizations", FixedPoint(100), - FilterNullsInJoinKey(TestSQLContext), // This is the rule we test in this suite. - CombineFilters, - PushPredicateThroughProject, - BooleanSimplification, - PushPredicateThroughJoin, - PushPredicateThroughGenerate, - ColumnPruning, - ProjectCollapsing) :: Nil - } - - val leftRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.int) - - val rightRelation = LocalRelation('e.int, 'f.int, 'g.int, 'h.int) - - test("inner join") { - val joinCondition = - ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) - - val joinedPlan = - leftRelation - .join(rightRelation, Inner, Some(joinCondition)) - .select('a, 'f, 'd, 'h) - - val optimized = Optimize.execute(joinedPlan.analyze) - - // For an inner join, FilterNullsInJoinKey add filter to both side. - val correctLeft = - leftRelation - .where(!(AtLeastNNulls(1, 'a.expr :: Nil))) - - val correctRight = - rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil))) - - val correctAnswer = - correctLeft - .join(correctRight, Inner, Some(joinCondition)) - .select('a, 'f, 'd, 'h) - - comparePlans(optimized, Optimize.execute(correctAnswer.analyze)) - } - - test("make sure we do not keep adding filters") { - val thirdRelation = LocalRelation('i.int, 'j.int, 'k.int, 'l.int) - val joinedPlan = - leftRelation - .join(rightRelation, Inner, Some('a === 'e)) - .join(thirdRelation, Inner, Some('b === 'i && 'a === 'j)) - - val optimized = Optimize.execute(joinedPlan.analyze) - val conditions = optimized.collect { - case Filter(condition @ Not(AtLeastNNulls(1, exprs)), _) => exprs - } - - // Make sure that we have three Not(AtLeastNNulls(1, exprs)) for those three tables. - assert(conditions.length === 3) - - // Make sure attribtues are indeed a, b, e, i, and j. - assert( - conditions.flatMap(exprs => exprs).toSet === - joinedPlan.select('a, 'b, 'e, 'i, 'j).analyze.output.toSet) - } - - test("inner join (partially optimized)") { - val joinCondition = - ('a + 2 === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) - - val joinedPlan = - leftRelation - .join(rightRelation, Inner, Some(joinCondition)) - .select('a, 'f, 'd, 'h) - - val optimized = Optimize.execute(joinedPlan.analyze) - - // We cannot extract attribute from the left join key. - val correctRight = - rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil))) - - val correctAnswer = - leftRelation - .join(correctRight, Inner, Some(joinCondition)) - .select('a, 'f, 'd, 'h) - - comparePlans(optimized, Optimize.execute(correctAnswer.analyze)) - } - - test("inner join (not optimized)") { - val nonOptimizedJoinConditions = - Some('c - 100 + 'd === 'g + 1 - 'h) :: - Some('d > 'h || 'c === 'g) :: - Some('d + 'g + 'c > 'd - 'h) :: Nil - - nonOptimizedJoinConditions.foreach { joinCondition => - val joinedPlan = - leftRelation - .join(rightRelation.select('f, 'g, 'h), Inner, joinCondition) - .select('a, 'c, 'f, 'd, 'h, 'g) - - val optimized = Optimize.execute(joinedPlan.analyze) - - comparePlans(optimized, Optimize.execute(joinedPlan.analyze)) - } - } - - test("left outer join") { - val joinCondition = - ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) - - val joinedPlan = - leftRelation - .join(rightRelation, LeftOuter, Some(joinCondition)) - .select('a, 'f, 'd, 'h) - - val optimized = Optimize.execute(joinedPlan.analyze) - - // For a left outer join, FilterNullsInJoinKey add filter to the right side. 
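[Reviewer note] The FilterNullsInJoinKeySuite being deleted here follows the standard Catalyst optimizer-rule test pattern: build a throwaway RuleExecutor with only the rules under test, construct plans with the Catalyst DSL, and compare the optimized plan against a hand-built expected plan. A stripped-down sketch of that pattern, using stock rules such as CombineFilters rather than the removed FilterNullsInJoinKey rule, is:

  import org.apache.spark.sql.catalyst.dsl.expressions._
  import org.apache.spark.sql.catalyst.dsl.plans._
  import org.apache.spark.sql.catalyst.optimizer.{BooleanSimplification, CombineFilters}
  import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
  import org.apache.spark.sql.catalyst.rules.RuleExecutor

  // A throwaway optimizer containing only the rules under test.
  object TestOptimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Operator Optimizations", FixedPoint(100),
        CombineFilters,
        BooleanSimplification) :: Nil
  }

  val relation = LocalRelation('a.int, 'b.int)
  val query = relation.where('a > 1).where('b > 2).analyze
  val optimized = TestOptimize.execute(query)
  // Inside a PlanTest subclass one would then assert:
  // comparePlans(optimized, relation.where('a > 1 && 'b > 2).analyze)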
- val correctRight = - rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil))) - - val correctAnswer = - leftRelation - .join(correctRight, LeftOuter, Some(joinCondition)) - .select('a, 'f, 'd, 'h) - - comparePlans(optimized, Optimize.execute(correctAnswer.analyze)) - } - - test("right outer join") { - val joinCondition = - ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) - - val joinedPlan = - leftRelation - .join(rightRelation, RightOuter, Some(joinCondition)) - .select('a, 'f, 'd, 'h) - - val optimized = Optimize.execute(joinedPlan.analyze) - - // For a right outer join, FilterNullsInJoinKey add filter to the left side. - val correctLeft = - leftRelation - .where(!(AtLeastNNulls(1, 'a.expr :: Nil))) - - val correctAnswer = - correctLeft - .join(rightRelation, RightOuter, Some(joinCondition)) - .select('a, 'f, 'd, 'h) - - - comparePlans(optimized, Optimize.execute(correctAnswer.analyze)) - } - - test("full outer join") { - val joinCondition = - ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) - - val joinedPlan = - leftRelation - .join(rightRelation, FullOuter, Some(joinCondition)) - .select('a, 'f, 'd, 'h) - - // FilterNullsInJoinKey does not fire for a full outer join. - val optimized = Optimize.execute(joinedPlan.analyze) - - comparePlans(optimized, Optimize.execute(joinedPlan.analyze)) - } - - test("left semi join") { - val joinCondition = - ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) - - val joinedPlan = - leftRelation - .join(rightRelation, LeftSemi, Some(joinCondition)) - .select('a, 'd) - - val optimized = Optimize.execute(joinedPlan.analyze) - - // For a left semi join, FilterNullsInJoinKey add filter to both side. - val correctLeft = - leftRelation - .where(!(AtLeastNNulls(1, 'a.expr :: Nil))) - - val correctRight = - rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil))) - - val correctAnswer = - correctLeft - .join(correctRight, LeftSemi, Some(joinCondition)) - .select('a, 'd) - - comparePlans(optimized, Optimize.execute(correctAnswer.analyze)) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetAvroCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetAvroCompatibilitySuite.scala deleted file mode 100644 index bfa427349ff6..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetAvroCompatibilitySuite.scala +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.parquet - -import java.nio.ByteBuffer -import java.util.{List => JList, Map => JMap} - -import scala.collection.JavaConversions._ - -import org.apache.hadoop.fs.Path -import org.apache.parquet.avro.AvroParquetWriter - -import org.apache.spark.sql.parquet.test.avro.{Nested, ParquetAvroCompat} -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.{Row, SQLContext} - -class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest { - import ParquetCompatibilityTest._ - - override val sqlContext: SQLContext = TestSQLContext - - override protected def beforeAll(): Unit = { - super.beforeAll() - - val writer = - new AvroParquetWriter[ParquetAvroCompat]( - new Path(parquetStore.getCanonicalPath), - ParquetAvroCompat.getClassSchema) - - (0 until 10).foreach(i => writer.write(makeParquetAvroCompat(i))) - writer.close() - } - - test("Read Parquet file generated by parquet-avro") { - logInfo( - s"""Schema of the Parquet file written by parquet-avro: - |${readParquetSchema(parquetStore.getCanonicalPath)} - """.stripMargin) - - checkAnswer(sqlContext.read.parquet(parquetStore.getCanonicalPath), (0 until 10).map { i => - def nullable[T <: AnyRef]: ( => T) => T = makeNullable[T](i) - - Row( - i % 2 == 0, - i, - i.toLong * 10, - i.toFloat + 0.1f, - i.toDouble + 0.2d, - s"val_$i".getBytes, - s"val_$i", - - nullable(i % 2 == 0: java.lang.Boolean), - nullable(i: Integer), - nullable(i.toLong: java.lang.Long), - nullable(i.toFloat + 0.1f: java.lang.Float), - nullable(i.toDouble + 0.2d: java.lang.Double), - nullable(s"val_$i".getBytes), - nullable(s"val_$i"), - - Seq.tabulate(3)(n => s"arr_${i + n}"), - Seq.tabulate(3)(n => n.toString -> (i + n: Integer)).toMap, - Seq.tabulate(3) { n => - (i + n).toString -> Seq.tabulate(3) { m => - Row(Seq.tabulate(3)(j => i + j + m), s"val_${i + m}") - } - }.toMap) - }) - } - - def makeParquetAvroCompat(i: Int): ParquetAvroCompat = { - def nullable[T <: AnyRef] = makeNullable[T](i) _ - - def makeComplexColumn(i: Int): JMap[String, JList[Nested]] = { - mapAsJavaMap(Seq.tabulate(3) { n => - (i + n).toString -> seqAsJavaList(Seq.tabulate(3) { m => - Nested - .newBuilder() - .setNestedIntsColumn(seqAsJavaList(Seq.tabulate(3)(j => i + j + m))) - .setNestedStringColumn(s"val_${i + m}") - .build() - }) - }.toMap) - } - - ParquetAvroCompat - .newBuilder() - .setBoolColumn(i % 2 == 0) - .setIntColumn(i) - .setLongColumn(i.toLong * 10) - .setFloatColumn(i.toFloat + 0.1f) - .setDoubleColumn(i.toDouble + 0.2d) - .setBinaryColumn(ByteBuffer.wrap(s"val_$i".getBytes)) - .setStringColumn(s"val_$i") - - .setMaybeBoolColumn(nullable(i % 2 == 0: java.lang.Boolean)) - .setMaybeIntColumn(nullable(i: Integer)) - .setMaybeLongColumn(nullable(i.toLong: java.lang.Long)) - .setMaybeFloatColumn(nullable(i.toFloat + 0.1f: java.lang.Float)) - .setMaybeDoubleColumn(nullable(i.toDouble + 0.2d: java.lang.Double)) - .setMaybeBinaryColumn(nullable(ByteBuffer.wrap(s"val_$i".getBytes))) - .setMaybeStringColumn(nullable(s"val_$i")) - - .setStringsColumn(Seq.tabulate(3)(n => s"arr_${i + n}")) - .setStringToIntColumn( - mapAsJavaMap(Seq.tabulate(3)(n => n.toString -> (i + n: Integer)).toMap)) - .setComplexColumn(makeComplexColumn(i)) - - .build() - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetCompatibilityTest.scala deleted file mode 100644 index b4cdfd9e98f6..000000000000 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetCompatibilityTest.scala +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parquet -import java.io.File - -import scala.collection.JavaConversions._ - -import org.apache.hadoop.fs.Path -import org.apache.parquet.hadoop.ParquetFileReader -import org.apache.parquet.schema.MessageType -import org.scalatest.BeforeAndAfterAll - -import org.apache.spark.sql.QueryTest -import org.apache.spark.util.Utils - -abstract class ParquetCompatibilityTest extends QueryTest with ParquetTest with BeforeAndAfterAll { - protected var parquetStore: File = _ - - override protected def beforeAll(): Unit = { - parquetStore = Utils.createTempDir(namePrefix = "parquet-compat_") - parquetStore.delete() - } - - override protected def afterAll(): Unit = { - Utils.deleteRecursively(parquetStore) - } - - def readParquetSchema(path: String): MessageType = { - val fsPath = new Path(path) - val fs = fsPath.getFileSystem(configuration) - val parquetFiles = fs.listStatus(fsPath).toSeq.filterNot(_.getPath.getName.startsWith("_")) - val footers = ParquetFileReader.readAllFootersInParallel(configuration, parquetFiles, true) - footers.head.getParquetMetadata.getFileMetaData.getSchema - } -} - -object ParquetCompatibilityTest { - def makeNullable[T <: AnyRef](i: Int)(f: => T): T = { - if (i % 3 == 0) null.asInstanceOf[T] else f - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala deleted file mode 100644 index a95f70f2bba6..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ /dev/null @@ -1,192 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.parquet - -import java.io.File - -import org.apache.hadoop.fs.Path - -import org.apache.spark.sql.types._ -import org.apache.spark.sql.{QueryTest, Row, SQLConf} -import org.apache.spark.util.Utils - -/** - * A test suite that tests various Parquet queries. - */ -class ParquetQuerySuite extends QueryTest with ParquetTest { - lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.sql - - test("simple select queries") { - withParquetTable((0 until 10).map(i => (i, i.toString)), "t") { - checkAnswer(sql("SELECT _1 FROM t where t._1 > 5"), (6 until 10).map(Row.apply(_))) - checkAnswer(sql("SELECT _1 FROM t as tmp where tmp._1 < 5"), (0 until 5).map(Row.apply(_))) - } - } - - test("appending") { - val data = (0 until 10).map(i => (i, i.toString)) - sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") - withParquetTable(data, "t") { - sql("INSERT INTO TABLE t SELECT * FROM tmp") - checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple)) - } - sqlContext.catalog.unregisterTable(Seq("tmp")) - } - - test("overwriting") { - val data = (0 until 10).map(i => (i, i.toString)) - sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") - withParquetTable(data, "t") { - sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") - checkAnswer(sqlContext.table("t"), data.map(Row.fromTuple)) - } - sqlContext.catalog.unregisterTable(Seq("tmp")) - } - - test("self-join") { - // 4 rows, cells of column 1 of row 2 and row 4 are null - val data = (1 to 4).map { i => - val maybeInt = if (i % 2 == 0) None else Some(i) - (maybeInt, i.toString) - } - - withParquetTable(data, "t") { - val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x._1 = y._1") - val queryOutput = selfJoin.queryExecution.analyzed.output - - assertResult(4, "Field count mismatches")(queryOutput.size) - assertResult(2, "Duplicated expression ID in query plan:\n $selfJoin") { - queryOutput.filter(_.name == "_1").map(_.exprId).size - } - - checkAnswer(selfJoin, List(Row(1, "1", 1, "1"), Row(3, "3", 3, "3"))) - } - } - - test("nested data - struct with array field") { - val data = (1 to 10).map(i => Tuple1((i, Seq("val_$i")))) - withParquetTable(data, "t") { - checkAnswer(sql("SELECT _1._2[0] FROM t"), data.map { - case Tuple1((_, Seq(string))) => Row(string) - }) - } - } - - test("nested data - array of struct") { - val data = (1 to 10).map(i => Tuple1(Seq(i -> "val_$i"))) - withParquetTable(data, "t") { - checkAnswer(sql("SELECT _1[0]._2 FROM t"), data.map { - case Tuple1(Seq((_, string))) => Row(string) - }) - } - } - - test("SPARK-1913 regression: columns only referenced by pushed down filters should remain") { - withParquetTable((1 to 10).map(Tuple1.apply), "t") { - checkAnswer(sql("SELECT _1 FROM t WHERE _1 < 10"), (1 to 9).map(Row.apply(_))) - } - } - - test("SPARK-5309 strings stored using dictionary compression in parquet") { - withParquetTable((0 until 1000).map(i => ("same", "run_" + i /100, 1)), "t") { - - checkAnswer(sql("SELECT _1, _2, SUM(_3) FROM t GROUP BY _1, _2"), - (0 until 10).map(i => Row("same", "run_" + i, 100))) - - checkAnswer(sql("SELECT _1, _2, SUM(_3) FROM t WHERE _2 = 'run_5' GROUP BY _1, _2"), - List(Row("same", "run_5", 100))) - } - } - - test("SPARK-6917 DecimalType should work with non-native types") { - val data = (1 to 10).map(i => Row(Decimal(i, 18, 0), new java.sql.Timestamp(i))) - val schema = StructType(List(StructField("d", DecimalType(18, 0), false), - StructField("time", TimestampType, false)).toArray) 
- withTempPath { file => - val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(data), schema) - df.write.parquet(file.getCanonicalPath) - val df2 = sqlContext.read.parquet(file.getCanonicalPath) - checkAnswer(df2, df.collect().toSeq) - } - } - - test("Enabling/disabling merging partfiles when merging parquet schema") { - def testSchemaMerging(expectedColumnNumber: Int): Unit = { - withTempDir { dir => - val basePath = dir.getCanonicalPath - sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) - sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) - // delete summary files, so if we don't merge part-files, one column will not be included. - Utils.deleteRecursively(new File(basePath + "/foo=1/_metadata")) - Utils.deleteRecursively(new File(basePath + "/foo=1/_common_metadata")) - assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber) - } - } - - withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true", - SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES.key -> "true") { - testSchemaMerging(2) - } - - withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true", - SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES.key -> "false") { - testSchemaMerging(3) - } - } - - test("Enabling/disabling schema merging") { - def testSchemaMerging(expectedColumnNumber: Int): Unit = { - withTempDir { dir => - val basePath = dir.getCanonicalPath - sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) - sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) - assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber) - } - } - - withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true") { - testSchemaMerging(3) - } - - withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "false") { - testSchemaMerging(2) - } - } - - test("SPARK-8990 DataFrameReader.parquet() should respect user specified options") { - withTempPath { dir => - val basePath = dir.getCanonicalPath - sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) - sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=a").toString) - - // Disables the global SQL option for schema merging - withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "false") { - assertResult(2) { - // Disables schema merging via data source option - sqlContext.read.option("mergeSchema", "false").parquet(basePath).columns.length - } - - assertResult(3) { - // Enables schema merging via data source option - sqlContext.read.option("mergeSchema", "true").parquet(basePath).columns.length - } - } - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetThriftCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetThriftCompatibilitySuite.scala deleted file mode 100644 index 1c532d78790d..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetThriftCompatibilitySuite.scala +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parquet - -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.{Row, SQLContext} - -class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest { - import ParquetCompatibilityTest._ - - override val sqlContext: SQLContext = TestSQLContext - - private val parquetFilePath = - Thread.currentThread().getContextClassLoader.getResource("parquet-thrift-compat.snappy.parquet") - - test("Read Parquet file generated by parquet-thrift") { - logInfo( - s"""Schema of the Parquet file written by parquet-thrift: - |${readParquetSchema(parquetFilePath.toString)} - """.stripMargin) - - checkAnswer(sqlContext.read.parquet(parquetFilePath.toString), (0 until 10).map { i => - def nullable[T <: AnyRef]: ( => T) => T = makeNullable[T](i) - - val suits = Array("SPADES", "HEARTS", "DIAMONDS", "CLUBS") - - Row( - i % 2 == 0, - i.toByte, - (i + 1).toShort, - i + 2, - i.toLong * 10, - i.toDouble + 0.2d, - // Thrift `BINARY` values are actually unencoded `STRING` values, and thus are always - // treated as `BINARY (UTF8)` in parquet-thrift, since parquet-thrift always assume - // Thrift `STRING`s are encoded using UTF-8. - s"val_$i", - s"val_$i", - // Thrift ENUM values are converted to Parquet binaries containing UTF-8 strings - suits(i % 4), - - nullable(i % 2 == 0: java.lang.Boolean), - nullable(i.toByte: java.lang.Byte), - nullable((i + 1).toShort: java.lang.Short), - nullable(i + 2: Integer), - nullable((i * 10).toLong: java.lang.Long), - nullable(i.toDouble + 0.2d: java.lang.Double), - nullable(s"val_$i"), - nullable(s"val_$i"), - nullable(suits(i % 4)), - - Seq.tabulate(3)(n => s"arr_${i + n}"), - // Thrift `SET`s are converted to Parquet `LIST`s - Seq(i), - Seq.tabulate(3)(n => (i + n: Integer) -> s"val_${i + n}").toMap, - Seq.tabulate(3) { n => - (i + n) -> Seq.tabulate(3) { m => - Row(Seq.tabulate(3)(j => i + j + m), s"val_${i + m}") - } - }.toMap) - }) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index 1907e643c85d..6fc9febe4970 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -19,28 +19,30 @@ package org.apache.spark.sql.sources import java.io.{File, IOException} -import org.scalatest.BeforeAndAfterAll +import org.scalatest.BeforeAndAfter import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.DDLException +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils -class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { - - import caseInsensitiveContext.sql - - private lazy val sparkContext = caseInsensitiveContext.sparkContext - - var path: File = null +class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter { + protected override lazy val sql = caseInsensitiveContext.sql _ + private var path: File = null 
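[Reviewer note] The CreateTableAsSelectSuite hunks below replace "USING org.apache.spark.sql.json.DefaultSource" with the "json" short name, and the new DDLSourceLoadSuite covers alias resolution through DataSourceRegister. A minimal provider exposing such an alias could look like the sketch below (the "myformat" name and the single-column schema are illustrative, not part of the patch); for discovery, the class name also has to be listed in META-INF/services/org.apache.spark.sql.sources.DataSourceRegister.

  import org.apache.spark.sql.SQLContext
  import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider}
  import org.apache.spark.sql.types.{StringType, StructField, StructType}

  class MyFormatProvider extends RelationProvider with DataSourceRegister {
    // Callers can then write `USING myformat` in DDL or .format("myformat") in code.
    override def shortName(): String = "myformat"

    override def createRelation(ctx: SQLContext, parameters: Map[String, String]): BaseRelation =
      new BaseRelation {
        override def sqlContext: SQLContext = ctx
        override def schema: StructType =
          StructType(Seq(StructField("value", StringType, nullable = false)))
      }
  }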
override def beforeAll(): Unit = { + super.beforeAll() path = Utils.createTempDir() val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) caseInsensitiveContext.read.json(rdd).registerTempTable("jt") } override def afterAll(): Unit = { - caseInsensitiveContext.dropTempTable("jt") + try { + caseInsensitiveContext.dropTempTable("jt") + } finally { + super.afterAll() + } } after { @@ -51,7 +53,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -75,7 +77,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -92,7 +94,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -107,7 +109,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE IF NOT EXISTS jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -122,7 +124,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -139,7 +141,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -158,7 +160,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE IF NOT EXISTS jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -175,7 +177,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable (a int, b string) - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -188,7 +190,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -199,7 +201,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala new file mode 100644 index 000000000000..853707c036c9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala @@ -0,0 +1,89 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. 
+* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.sources + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{StringType, StructField, StructType} + + +// please note that the META-INF/services had to be modified for the test directory for this to work +class DDLSourceLoadSuite extends DataSourceTest with SharedSQLContext { + + test("data sources with the same name") { + intercept[RuntimeException] { + caseInsensitiveContext.read.format("Fluet da Bomb").load() + } + } + + test("load data source from format alias") { + caseInsensitiveContext.read.format("gathering quorum").load().schema == + StructType(Seq(StructField("stringType", StringType, nullable = false))) + } + + test("specify full classname with duplicate formats") { + caseInsensitiveContext.read.format("org.apache.spark.sql.sources.FakeSourceOne") + .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false))) + } + + test("should fail to load ORC without HiveContext") { + intercept[ClassNotFoundException] { + caseInsensitiveContext.read.format("orc").load() + } + } +} + + +class FakeSourceOne extends RelationProvider with DataSourceRegister { + + def shortName(): String = "Fluet da Bomb" + + override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = + new BaseRelation { + override def sqlContext: SQLContext = cont + + override def schema: StructType = + StructType(Seq(StructField("stringType", StringType, nullable = false))) + } +} + +class FakeSourceTwo extends RelationProvider with DataSourceRegister { + + def shortName(): String = "Fluet da Bomb" + + override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = + new BaseRelation { + override def sqlContext: SQLContext = cont + + override def schema: StructType = + StructType(Seq(StructField("stringType", StringType, nullable = false))) + } +} + +class FakeSourceThree extends RelationProvider with DataSourceRegister { + + def shortName(): String = "gathering quorum" + + override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = + new BaseRelation { + override def sqlContext: SQLContext = cont + + override def schema: StructType = + StructType(Seq(StructField("stringType", StringType, nullable = false))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala index 84855ce45e91..5f8514e1a241 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.sources import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import 
org.apache.spark.unsafe.types.UTF8String @@ -68,10 +69,12 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo } } -class DDLTestSuite extends DataSourceTest { +class DDLTestSuite extends DataSourceTest with SharedSQLContext { + protected override lazy val sql = caseInsensitiveContext.sql _ - before { - caseInsensitiveContext.sql( + override def beforeAll(): Unit = { + super.beforeAll() + sql( """ |CREATE TEMPORARY TABLE ddlPeople |USING org.apache.spark.sql.sources.DDLScanSource @@ -105,7 +108,7 @@ class DDLTestSuite extends DataSourceTest { )) test("SPARK-7686 DescribeCommand should have correct physical plan output attributes") { - val attributes = caseInsensitiveContext.sql("describe ddlPeople") + val attributes = sql("describe ddlPeople") .queryExecution.executedPlan.output assert(attributes.map(_.name) === Seq("col_name", "data_type", "comment")) assert(attributes.map(_.dataType).toSet === Set(StringType)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala index 00cc7d5ea580..af04079ec895 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -17,18 +17,21 @@ package org.apache.spark.sql.sources -import org.scalatest.BeforeAndAfter - import org.apache.spark.sql._ -import org.apache.spark.sql.test.TestSQLContext +private[sql] abstract class DataSourceTest extends QueryTest { -abstract class DataSourceTest extends QueryTest with BeforeAndAfter { // We want to test some edge cases. - protected implicit lazy val caseInsensitiveContext = { - val ctx = new SQLContext(TestSQLContext.sparkContext) + protected lazy val caseInsensitiveContext: SQLContext = { + val ctx = new SQLContext(sqlContext.sparkContext) ctx.setConf(SQLConf.CASE_SENSITIVE, false) ctx } + protected def sqlTest(sqlString: String, expectedAnswer: Seq[Row]) { + test(sqlString) { + checkAnswer(caseInsensitiveContext.sql(sqlString), expectedAnswer) + } + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index 81b3a0f0c5b3..398b8a1a661c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -21,8 +21,11 @@ import scala.language.existentials import org.apache.spark.rdd.RDD import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions.PredicateHelper +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ - +import org.apache.spark.unsafe.types.UTF8String class FilteredScanSource extends RelationProvider { override def createRelation( @@ -42,20 +45,44 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL StructField("b", IntegerType, nullable = false) :: StructField("c", StringType, nullable = false) :: Nil) + override def unhandledFilters(filters: Array[Filter]): Array[Filter] = { + def unhandled(filter: Filter): Boolean = { + filter match { + case EqualTo(col, v) => col == "b" + case EqualNullSafe(col, v) => col == "b" + case LessThan(col, v: Int) => col == "b" + case LessThanOrEqual(col, v: Int) => col == "b" + case GreaterThan(col, v: Int) => col == "b" + case GreaterThanOrEqual(col, v: 
Int) => col == "b" + case In(col, values) => col == "b" + case IsNull(col) => col == "b" + case IsNotNull(col) => col == "b" + case Not(pred) => unhandled(pred) + case And(left, right) => unhandled(left) || unhandled(right) + case Or(left, right) => unhandled(left) || unhandled(right) + case _ => false + } + } + + filters.filter(unhandled) + } + override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { val rowBuilders = requiredColumns.map { case "a" => (i: Int) => Seq(i) case "b" => (i: Int) => Seq(i * 2) case "c" => (i: Int) => val c = (i - 1 + 'a').toChar.toString - Seq(c * 5 + c.toUpperCase() * 5) + Seq(c * 5 + c.toUpperCase * 5) } FiltersPushed.list = filters + ColumnsRequired.set = requiredColumns.toSet // Predicate test on integer column def translateFilterOnA(filter: Filter): Int => Boolean = filter match { case EqualTo("a", v) => (a: Int) => a == v + case EqualNullSafe("a", v) => (a: Int) => a == v case LessThan("a", v: Int) => (a: Int) => a < v case LessThanOrEqual("a", v: Int) => (a: Int) => a <= v case GreaterThan("a", v: Int) => (a: Int) => a > v @@ -76,13 +103,15 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL case StringStartsWith("c", v) => _.startsWith(v) case StringEndsWith("c", v) => _.endsWith(v) case StringContains("c", v) => _.contains(v) + case EqualTo("c", v: String) => _.equals(v) + case EqualTo("c", v: UTF8String) => sys.error("UTF8String should not appear in filters") + case In("c", values) => (s: String) => values.map(_.asInstanceOf[String]).toSet.contains(s) case _ => (c: String) => true } def eval(a: Int) = { - val c = (a - 1 + 'a').toChar.toString * 5 + (a - 1 + 'a').toChar.toString.toUpperCase() * 5 - !filters.map(translateFilterOnA(_)(a)).contains(false) && - !filters.map(translateFilterOnC(_)(c)).contains(false) + val c = (a - 1 + 'a').toChar.toString * 5 + (a - 1 + 'a').toChar.toString.toUpperCase * 5 + filters.forall(translateFilterOnA(_)(a)) && filters.forall(translateFilterOnC(_)(c)) } sqlContext.sparkContext.parallelize(from to to).filter(eval).map(i => @@ -95,11 +124,16 @@ object FiltersPushed { var list: Seq[Filter] = Nil } -class FilteredScanSuite extends DataSourceTest { +// Used together with `SimpleFilteredScan` to check pushed columns. 
+object ColumnsRequired { + var set: Set[String] = Set.empty +} - import caseInsensitiveContext.sql +class FilteredScanSuite extends DataSourceTest with SharedSQLContext with PredicateHelper { + protected override lazy val sql = caseInsensitiveContext.sql _ - before { + override def beforeAll(): Unit = { + super.beforeAll() sql( """ |CREATE TEMPORARY TABLE oneToTenFiltered @@ -114,7 +148,7 @@ class FilteredScanSuite extends DataSourceTest { sqlTest( "SELECT * FROM oneToTenFiltered", (1 to 10).map(i => Row(i, i * 2, (i - 1 + 'a').toChar.toString * 5 - + (i - 1 + 'a').toChar.toString.toUpperCase() * 5)).toSeq) + + (i - 1 + 'a').toChar.toString.toUpperCase * 5)).toSeq) sqlTest( "SELECT a, b FROM oneToTenFiltered", @@ -196,46 +230,80 @@ class FilteredScanSuite extends DataSourceTest { "SELECT a, b, c FROM oneToTenFiltered WHERE c like '%eE%'", Seq(Row(5, 5 * 2, "e" * 5 + "E" * 5))) - testPushDown("SELECT * FROM oneToTenFiltered WHERE A = 1", 1) - testPushDown("SELECT a FROM oneToTenFiltered WHERE A = 1", 1) - testPushDown("SELECT b FROM oneToTenFiltered WHERE A = 1", 1) - testPushDown("SELECT a, b FROM oneToTenFiltered WHERE A = 1", 1) - testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 1", 1) - testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 = a", 1) - - testPushDown("SELECT * FROM oneToTenFiltered WHERE a > 1", 9) - testPushDown("SELECT * FROM oneToTenFiltered WHERE a >= 2", 9) - - testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 < a", 9) - testPushDown("SELECT * FROM oneToTenFiltered WHERE 2 <= a", 9) - - testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 > a", 0) - testPushDown("SELECT * FROM oneToTenFiltered WHERE 2 >= a", 2) - - testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 1", 0) - testPushDown("SELECT * FROM oneToTenFiltered WHERE a <= 2", 2) - - testPushDown("SELECT * FROM oneToTenFiltered WHERE a > 1 AND a < 10", 8) - - testPushDown("SELECT * FROM oneToTenFiltered WHERE a IN (1,3,5)", 3) - - testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 20", 0) - testPushDown("SELECT * FROM oneToTenFiltered WHERE b = 1", 10) - - testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 5 AND a > 1", 3) - testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 3 OR a > 8", 4) - testPushDown("SELECT * FROM oneToTenFiltered WHERE NOT (a < 6)", 5) - - testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like 'c%'", 1) - testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like 'C%'", 0) - - testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%D'", 1) - testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%d'", 0) - - testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%eE%'", 1) - testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%Ee%'", 0) + testPushDown("SELECT * FROM oneToTenFiltered WHERE A = 1", 1, Set("a", "b", "c")) + testPushDown("SELECT a FROM oneToTenFiltered WHERE A = 1", 1, Set("a")) + testPushDown("SELECT b FROM oneToTenFiltered WHERE A = 1", 1, Set("b")) + testPushDown("SELECT a, b FROM oneToTenFiltered WHERE A = 1", 1, Set("a", "b")) + testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 1", 1, Set("a", "b", "c")) + testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 = a", 1, Set("a", "b", "c")) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE a > 1", 9, Set("a", "b", "c")) + testPushDown("SELECT * FROM oneToTenFiltered WHERE a >= 2", 9, Set("a", "b", "c")) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 < a", 9, Set("a", "b", "c")) + testPushDown("SELECT * FROM oneToTenFiltered 
WHERE 2 <= a", 9, Set("a", "b", "c")) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 > a", 0, Set("a", "b", "c")) + testPushDown("SELECT * FROM oneToTenFiltered WHERE 2 >= a", 2, Set("a", "b", "c")) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 1", 0, Set("a", "b", "c")) + testPushDown("SELECT * FROM oneToTenFiltered WHERE a <= 2", 2, Set("a", "b", "c")) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE a > 1 AND a < 10", 8, Set("a", "b", "c")) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE a IN (1,3,5)", 3, Set("a", "b", "c")) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 20", 0, Set("a", "b", "c")) + testPushDown( + "SELECT * FROM oneToTenFiltered WHERE b = 1", + 10, + Set("a", "b", "c"), + Set(EqualTo("b", 1))) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 5 AND a > 1", 3, Set("a", "b", "c")) + testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 3 OR a > 8", 4, Set("a", "b", "c")) + testPushDown("SELECT * FROM oneToTenFiltered WHERE NOT (a < 6)", 5, Set("a", "b", "c")) + + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like 'c%'", 1, Set("a", "b", "c")) + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like 'C%'", 0, Set("a", "b", "c")) + + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%D'", 1, Set("a", "b", "c")) + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%d'", 0, Set("a", "b", "c")) + + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%eE%'", 1, Set("a", "b", "c")) + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%Ee%'", 0, Set("a", "b", "c")) + + testPushDown("SELECT c FROM oneToTenFiltered WHERE c = 'aaaaaAAAAA'", 1, Set("c")) + testPushDown("SELECT c FROM oneToTenFiltered WHERE c IN ('aaaaaAAAAA', 'foo')", 1, Set("c")) + + // Filters referencing multiple columns are not convertible, all referenced columns must be + // required. + testPushDown("SELECT c FROM oneToTenFiltered WHERE A + b > 9", 10, Set("a", "b", "c")) + + // A query with an inconvertible filter, an unhandled filter, and a handled filter. 
+ testPushDown( + """SELECT a + | FROM oneToTenFiltered + | WHERE a + b > 9 + | AND b < 16 + | AND c IN ('bbbbbBBBBB', 'cccccCCCCC', 'dddddDDDDD', 'foo') + """.stripMargin.split("\n").map(_.trim).mkString(" "), + 3, + Set("a", "b"), + Set(LessThan("b", 16))) + + def testPushDown( + sqlString: String, + expectedCount: Int, + requiredColumnNames: Set[String]): Unit = { + testPushDown(sqlString, expectedCount, requiredColumnNames, Set.empty[Filter]) + } - def testPushDown(sqlString: String, expectedCount: Int): Unit = { + def testPushDown( + sqlString: String, + expectedCount: Int, + requiredColumnNames: Set[String], + expectedUnhandledFilters: Set[Filter]): Unit = { test(s"PushDown Returns $expectedCount: $sqlString") { val queryExecution = sql(sqlString).queryExecution val rawPlan = queryExecution.executedPlan.collect { @@ -245,6 +313,15 @@ class FilteredScanSuite extends DataSourceTest { case _ => fail(s"More than one PhysicalRDD found\n$queryExecution") } val rawCount = rawPlan.execute().count() + assert(ColumnsRequired.set === requiredColumnNames) + + val table = caseInsensitiveContext.table("oneToTenFiltered") + val relation = table.queryExecution.logical.collectFirst { + case LogicalRelation(r, _) => r + }.get + + assert( + relation.unhandledFilters(FiltersPushed.list.toArray).toSet === expectedUnhandledFilters) if (rawCount != expectedCount) { fail( @@ -255,4 +332,3 @@ class FilteredScanSuite extends DataSourceTest { } } } - diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 0b7c46c482c8..5b70d258d6ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -19,22 +19,18 @@ package org.apache.spark.sql.sources import java.io.File -import org.scalatest.BeforeAndAfterAll - -import org.apache.spark.sql.{SaveMode, AnalysisException, Row} +import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils -class InsertSuite extends DataSourceTest with BeforeAndAfterAll { - - import caseInsensitiveContext.sql - - private lazy val sparkContext = caseInsensitiveContext.sparkContext - - var path: File = null +class InsertSuite extends DataSourceTest with SharedSQLContext { + protected override lazy val sql = caseInsensitiveContext.sql _ + private var path: File = null - override def beforeAll: Unit = { + override def beforeAll(): Unit = { + super.beforeAll() path = Utils.createTempDir() - val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) + val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""")) caseInsensitiveContext.read.json(rdd).registerTempTable("jt") sql( s""" @@ -46,10 +42,14 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { """.stripMargin) } - override def afterAll: Unit = { - caseInsensitiveContext.dropTempTable("jsonTable") - caseInsensitiveContext.dropTempTable("jt") - Utils.deleteRecursively(path) + override def afterAll(): Unit = { + try { + caseInsensitiveContext.dropTempTable("jsonTable") + caseInsensitiveContext.dropTempTable("jt") + Utils.deleteRecursively(path) + } finally { + super.afterAll() + } } test("Simple INSERT OVERWRITE a JSONRelation") { @@ -110,7 +110,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { ) // Writing the table to less part files. 
- val rdd1 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""), 5) + val rdd1 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}"""), 5) caseInsensitiveContext.read.json(rdd1).registerTempTable("jt1") sql( s""" @@ -122,7 +122,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { ) // Writing the table to more part files. - val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""), 10) + val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}"""), 10) caseInsensitiveContext.read.json(rdd2).registerTempTable("jt2") sql( s""" @@ -146,27 +146,23 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { caseInsensitiveContext.dropTempTable("jt2") } - test("INSERT INTO not supported for JSONRelation for now") { - intercept[RuntimeException]{ - sql( - s""" - |INSERT INTO TABLE jsonTable SELECT a, b FROM jt - """.stripMargin) - } - } - - test("save directly to the path of a JSON table") { - caseInsensitiveContext.table("jt").selectExpr("a * 5 as a", "b") - .write.mode(SaveMode.Overwrite).json(path.toString) + test("INSERT INTO JSONRelation for now") { + sql( + s""" + |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt + """.stripMargin) checkAnswer( sql("SELECT a, b FROM jsonTable"), - (1 to 10).map(i => Row(i * 5, s"str$i")) + sql("SELECT a, b FROM jt").collect() ) - caseInsensitiveContext.table("jt").write.mode(SaveMode.Overwrite).json(path.toString) + sql( + s""" + |INSERT INTO TABLE jsonTable SELECT a, b FROM jt + """.stripMargin) checkAnswer( sql("SELECT a, b FROM jsonTable"), - (1 to 10).map(i => Row(i, s"str$i")) + sql("SELECT a, b FROM jt UNION ALL SELECT a, b FROM jt").collect() ) } @@ -183,6 +179,11 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { } test("Caching") { + // write something to the jsonTable + sql( + s""" + |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt + """.stripMargin) // Cached Query Execution caseInsensitiveContext.cacheTable("jsonTable") assertCached(sql("SELECT * FROM jsonTable")) @@ -205,9 +206,10 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { sql("SELECT a * 2 FROM jsonTable"), (1 to 10).map(i => Row(i * 2)).toSeq) - assertCached(sql("SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), 2) - checkAnswer( - sql("SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), + assertCached(sql( + "SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), 2) + checkAnswer(sql( + "SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), (2 to 10).map(i => Row(i, i - 1)).toSeq) // Insert overwrite and keep the same schema. @@ -217,14 +219,15 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { """.stripMargin) // jsonTable should be recached. assertCached(sql("SELECT * FROM jsonTable")) - // The cached data is the new data. - checkAnswer( - sql("SELECT a, b FROM jsonTable"), - sql("SELECT a * 2, b FROM jt").collect()) - - // Verify uncaching - caseInsensitiveContext.uncacheTable("jsonTable") - assertCached(sql("SELECT * FROM jsonTable"), 0) + // TODO we need to invalidate the cached data in InsertIntoHadoopFsRelation +// // The cached data is the new data. 
+// checkAnswer( +// sql("SELECT a, b FROM jsonTable"), +// sql("SELECT a * 2, b FROM jt").collect()) +// +// // Verify uncaching +// caseInsensitiveContext.uncacheTable("jsonTable") +// assertCached(sql("SELECT * FROM jsonTable"), 0) } test("it's not allowed to insert into a relation that is not an InsertableRelation") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala new file mode 100644 index 000000000000..3eaa817f9c0b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import org.apache.spark.sql.{Row, QueryTest} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.Utils + +class PartitionedWriteSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("write many partitions") { + val path = Utils.createTempDir() + path.delete() + + val df = sqlContext.range(100).select($"id", lit(1).as("data")) + df.write.partitionBy("id").save(path.getCanonicalPath) + + checkAnswer( + sqlContext.read.load(path.getCanonicalPath), + (0 to 99).map(Row(1, _)).toSeq) + + Utils.deleteRecursively(path) + } + + test("write many partitions with repeats") { + val path = Utils.createTempDir() + path.delete() + + val base = sqlContext.range(100) + val df = base.unionAll(base).select($"id", lit(1).as("data")) + df.write.partitionBy("id").save(path.getCanonicalPath) + + checkAnswer( + sqlContext.read.load(path.getCanonicalPath), + (0 to 99).map(Row(1, _)).toSeq ++ (0 to 99).map(Row(1, _)).toSeq) + + Utils.deleteRecursively(path) + } + + test("partitioned columns should appear at the end of schema") { + withTempPath { f => + val path = f.getAbsolutePath + Seq(1 -> "a").toDF("i", "j").write.partitionBy("i").parquet(path) + assert(sqlContext.read.parquet(path).schema.map(_.name) == Seq("j", "i")) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala index 0d5183444af7..a89c5f8007e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -21,6 +21,7 @@ import scala.language.existentials import org.apache.spark.rdd.RDD import org.apache.spark.sql._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ class PrunedScanSource extends RelationProvider { @@ -51,10 +52,12 @@ case class SimplePrunedScan(from: Int, to: Int)(@transient val 
sqlContext: SQLCo } } -class PrunedScanSuite extends DataSourceTest { +class PrunedScanSuite extends DataSourceTest with SharedSQLContext { + protected override lazy val sql = caseInsensitiveContext.sql _ - before { - caseInsensitiveContext.sql( + override def beforeAll(): Unit = { + super.beforeAll() + sql( """ |CREATE TEMPORARY TABLE oneToTenPruned |USING org.apache.spark.sql.sources.PrunedScanSource @@ -114,7 +117,7 @@ class PrunedScanSuite extends DataSourceTest { def testPruning(sqlString: String, expectedColumns: String*): Unit = { test(s"Columns output ${expectedColumns.mkString(",")}: $sqlString") { - val queryExecution = caseInsensitiveContext.sql(sqlString).queryExecution + val queryExecution = sql(sqlString).queryExecution val rawPlan = queryExecution.executedPlan.collect { case p: execution.PhysicalRDD => p } match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala index 3cbf5467b253..cb6e5179b31f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala @@ -22,14 +22,56 @@ import org.apache.spark.sql.execution.datasources.ResolvedDataSource class ResolvedDataSourceSuite extends SparkFunSuite { - test("builtin sources") { - assert(ResolvedDataSource.lookupDataSource("jdbc") === - classOf[org.apache.spark.sql.jdbc.DefaultSource]) + test("jdbc") { + assert( + ResolvedDataSource.lookupDataSource("jdbc") === + classOf[org.apache.spark.sql.execution.datasources.jdbc.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.execution.datasources.jdbc") === + classOf[org.apache.spark.sql.execution.datasources.jdbc.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.jdbc") === + classOf[org.apache.spark.sql.execution.datasources.jdbc.DefaultSource]) + } + + test("json") { + assert( + ResolvedDataSource.lookupDataSource("json") === + classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.execution.datasources.json") === + classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.json") === + classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource]) + } + + test("parquet") { + assert( + ResolvedDataSource.lookupDataSource("parquet") === + classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.execution.datasources.parquet") === + classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.parquet") === + classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource]) + } + + test("error message for unknown data sources") { + val error1 = intercept[ClassNotFoundException] { + ResolvedDataSource.lookupDataSource("avro") + } + assert(error1.getMessage.contains("spark-packages")) - assert(ResolvedDataSource.lookupDataSource("json") === - classOf[org.apache.spark.sql.json.DefaultSource]) + val error2 = intercept[ClassNotFoundException] { + ResolvedDataSource.lookupDataSource("com.databricks.spark.avro") + } + assert(error2.getMessage.contains("spark-packages")) - 
assert(ResolvedDataSource.lookupDataSource("parquet") === - classOf[org.apache.spark.sql.parquet.DefaultSource]) + val error3 = intercept[ClassNotFoundException] { + ResolvedDataSource.lookupDataSource("asfdwefasdfasdf") + } + assert(error3.getMessage.contains("spark-packages")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index b032515a9d28..10d261368993 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -19,25 +19,21 @@ package org.apache.spark.sql.sources import java.io.File -import org.scalatest.BeforeAndAfterAll +import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.{SaveMode, SQLConf, DataFrame} +import org.apache.spark.sql.{AnalysisException, SaveMode, SQLConf, DataFrame} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { - - import caseInsensitiveContext.sql - - private lazy val sparkContext = caseInsensitiveContext.sparkContext - - var originalDefaultSource: String = null - - var path: File = null - - var df: DataFrame = null +class SaveLoadSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter { + protected override lazy val sql = caseInsensitiveContext.sql _ + private var originalDefaultSource: String = null + private var path: File = null + private var df: DataFrame = null override def beforeAll(): Unit = { + super.beforeAll() originalDefaultSource = caseInsensitiveContext.conf.defaultDataSourceName path = Utils.createTempDir() @@ -49,27 +45,32 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { } override def afterAll(): Unit = { - caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) + try { + caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) + } finally { + super.afterAll() + } } after { - caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) Utils.deleteRecursively(path) } - def checkLoad(): Unit = { + def checkLoad(expectedDF: DataFrame = df, tbl: String = "jsonTable"): Unit = { caseInsensitiveContext.conf.setConf( SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") - checkAnswer(caseInsensitiveContext.read.load(path.toString), df.collect()) + checkAnswer(caseInsensitiveContext.read.load(path.toString), expectedDF.collect()) // Test if we can pick up the data source name passed in load. 
caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") - checkAnswer(caseInsensitiveContext.read.format("json").load(path.toString), df.collect()) - checkAnswer(caseInsensitiveContext.read.format("json").load(path.toString), df.collect()) + checkAnswer(caseInsensitiveContext.read.format("json").load(path.toString), + expectedDF.collect()) + checkAnswer(caseInsensitiveContext.read.format("json").load(path.toString), + expectedDF.collect()) val schema = StructType(StructField("b", StringType, true) :: Nil) checkAnswer( caseInsensitiveContext.read.format("json").schema(schema).load(path.toString), - sql("SELECT b FROM jsonTable").collect()) + sql(s"SELECT b FROM $tbl").collect()) } test("save with path and load") { @@ -102,7 +103,7 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { test("save and save again") { df.write.json(path.toString) - var message = intercept[RuntimeException] { + val message = intercept[AnalysisException] { df.write.json(path.toString) }.getMessage @@ -118,12 +119,11 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { df.write.mode(SaveMode.Overwrite).json(path.toString) checkLoad() - message = intercept[RuntimeException] { - df.write.mode(SaveMode.Append).json(path.toString) - }.getMessage + // verify the append mode + df.write.mode(SaveMode.Append).json(path.toString) + val df2 = df.unionAll(df) + df2.registerTempTable("jsonTable2") - assert( - message.contains("Append mode is not supported"), - "We should complain that 'Append mode is not supported' for JSON source.") + checkLoad(df2, "jsonTable2") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index cfb03ff485b7..26c1ff520406 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql.sources +import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} import org.apache.spark.rdd.RDD import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String class DefaultSource extends SimpleScanSource @@ -73,7 +72,7 @@ case class AllDataTypesScan( sqlContext.sparkContext.parallelize(from to to).map { i => Row( s"str_$i", - s"str_$i".getBytes(), + s"str_$i".getBytes(StandardCharsets.UTF_8), i % 2 == 0, i.toByte, i.toShort, @@ -83,22 +82,23 @@ case class AllDataTypesScan( i.toDouble, new java.math.BigDecimal(i), new java.math.BigDecimal(i), - new Date(1970, 1, 1), + Date.valueOf("1970-01-01"), new Timestamp(20000 + i), s"varchar_$i", + s"char_$i", Seq(i, i + 1), Seq(Map(s"str_$i" -> Row(i.toLong))), Map(i -> i.toString), Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)), Row(i, i.toString), Row(Seq(s"str_$i", s"str_${i + 1}"), - Row(Seq(new Date(1970, 1, i + 1))))) + Row(Seq(Date.valueOf(s"1970-01-${i + 1}"))))) } } } -class TableScanSuite extends DataSourceTest { - import caseInsensitiveContext.sql +class TableScanSuite extends DataSourceTest with SharedSQLContext { + protected override lazy val sql = caseInsensitiveContext.sql _ private lazy val tableWithSchemaExpected = (1 to 10).map { i => Row( @@ -113,18 +113,20 @@ class TableScanSuite extends DataSourceTest 
{ i.toDouble, new java.math.BigDecimal(i), new java.math.BigDecimal(i), - new Date(1970, 1, 1), + Date.valueOf("1970-01-01"), new Timestamp(20000 + i), s"varchar_$i", + s"char_$i", Seq(i, i + 1), Seq(Map(s"str_$i" -> Row(i.toLong))), Map(i -> i.toString), Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)), Row(i, i.toString), - Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date(1970, 1, i + 1))))) + Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(Date.valueOf(s"1970-01-${i + 1}"))))) }.toSeq - before { + override def beforeAll(): Unit = { + super.beforeAll() sql( """ |CREATE TEMPORARY TABLE oneToTen @@ -154,6 +156,7 @@ class TableScanSuite extends DataSourceTest { |dateField dAte, |timestampField tiMestamp, |varcharField varchaR(12), + |charField ChaR(18), |arrayFieldSimple Array, |arrayFieldComplex Array>>, |mapFieldSimple MAP, @@ -207,6 +210,7 @@ class TableScanSuite extends DataSourceTest { StructField("dateField", DateType, true) :: StructField("timestampField", TimestampType, true) :: StructField("varcharField", StringType, true) :: + StructField("charField", StringType, true) :: StructField("arrayFieldSimple", ArrayType(IntegerType), true) :: StructField("arrayFieldComplex", ArrayType( @@ -248,6 +252,7 @@ class TableScanSuite extends DataSourceTest { | dateField, | timestampField, | varcharField, + | charField, | arrayFieldSimple, | arrayFieldComplex, | mapFieldSimple, @@ -280,7 +285,7 @@ class TableScanSuite extends DataSourceTest { sqlTest( "SELECT structFieldComplex.Value.`value_(2)` FROM tableWithSchema", - (1 to 10).map(i => Row(Seq(new Date(1970, 1, i + 1)))).toSeq) + (1 to 10).map(i => Row(Seq(Date.valueOf(s"1970-01-${i + 1}")))).toSeq) test("Caching") { // Cached Query Execution @@ -305,9 +310,10 @@ class TableScanSuite extends DataSourceTest { sql("SELECT i * 2 FROM oneToTen"), (1 to 10).map(i => Row(i * 2)).toSeq) - assertCached(sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), 2) - checkAnswer( - sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), + assertCached(sql( + "SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), 2) + checkAnswer(sql( + "SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), (2 to 10).map(i => Row(i, i - 1)).toSeq) // Verify uncaching diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/ProcessTestUtils.scala similarity index 55% rename from core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala rename to sql/core/src/test/scala/org/apache/spark/sql/test/ProcessTestUtils.scala index b3b281ff465f..152c9c8459de 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/ProcessTestUtils.scala @@ -15,22 +15,23 @@ * limitations under the License. 
 */

-package org.apache.spark.network.nio
+package org.apache.spark.sql.test

-private[nio] case class ConnectionId(connectionManagerId: ConnectionManagerId, uniqId: Int) {
-  override def toString: String = {
-    connectionManagerId.host + "_" + connectionManagerId.port + "_" + uniqId
-  }
-}
+import java.io.{IOException, InputStream}
+
+import scala.sys.process.BasicIO

-private[nio] object ConnectionId {
+object ProcessTestUtils {
+  class ProcessOutputCapturer(stream: InputStream, capture: String => Unit) extends Thread {
+    this.setDaemon(true)

-  def createConnectionIdFromString(connectionIdString: String): ConnectionId = {
-    val res = connectionIdString.split("_").map(_.trim())
-    if (res.size != 3) {
-      throw new Exception("Error converting ConnectionId string: " + connectionIdString +
-        " to a ConnectionId Object")
+    override def run(): Unit = {
+      try {
+        BasicIO.processFully(capture)(stream)
+      } catch { case _: IOException =>
+        // Ignores the IOException thrown when the process terminates, which closes the input
+        // stream abruptly.
+      }
     }
-    new ConnectionId(new ConnectionManagerId(res(0), res(1).toInt), res(2).toInt)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
new file mode 100644
index 000000000000..83c63e04f344
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -0,0 +1,311 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.test
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, SQLContext, SQLImplicits}
+
+/**
+ * A collection of sample data used in SQL tests.
+ */
+private[sql] trait SQLTestData { self =>
+  protected def sqlContext: SQLContext
+
+  // Helper object to import SQL implicits without a concrete SQLContext
+  private object internalImplicits extends SQLImplicits {
+    protected override def _sqlContext: SQLContext = self.sqlContext
+  }
+
+  import internalImplicits._
+  import SQLTestData._
+
+  // Note: all test data should be lazy because the SQLContext is not set up yet.
+ + protected lazy val emptyTestData: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + Seq.empty[Int].map(i => TestData(i, i.toString))).toDF() + df.registerTempTable("emptyTestData") + df + } + + protected lazy val testData: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + (1 to 100).map(i => TestData(i, i.toString))).toDF() + df.registerTempTable("testData") + df + } + + protected lazy val testData2: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + TestData2(1, 1) :: + TestData2(1, 2) :: + TestData2(2, 1) :: + TestData2(2, 2) :: + TestData2(3, 1) :: + TestData2(3, 2) :: Nil, 2).toDF() + df.registerTempTable("testData2") + df + } + + protected lazy val testData3: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + TestData3(1, None) :: + TestData3(2, Some(2)) :: Nil).toDF() + df.registerTempTable("testData3") + df + } + + protected lazy val negativeData: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + (1 to 100).map(i => TestData(-i, (-i).toString))).toDF() + df.registerTempTable("negativeData") + df + } + + protected lazy val largeAndSmallInts: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + LargeAndSmallInts(2147483644, 1) :: + LargeAndSmallInts(1, 2) :: + LargeAndSmallInts(2147483645, 1) :: + LargeAndSmallInts(2, 2) :: + LargeAndSmallInts(2147483646, 1) :: + LargeAndSmallInts(3, 2) :: Nil).toDF() + df.registerTempTable("largeAndSmallInts") + df + } + + protected lazy val decimalData: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + DecimalData(1, 1) :: + DecimalData(1, 2) :: + DecimalData(2, 1) :: + DecimalData(2, 2) :: + DecimalData(3, 1) :: + DecimalData(3, 2) :: Nil).toDF() + df.registerTempTable("decimalData") + df + } + + protected lazy val binaryData: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + BinaryData("12".getBytes, 1) :: + BinaryData("22".getBytes, 5) :: + BinaryData("122".getBytes, 3) :: + BinaryData("121".getBytes, 2) :: + BinaryData("123".getBytes, 4) :: Nil).toDF() + df.registerTempTable("binaryData") + df + } + + protected lazy val upperCaseData: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + UpperCaseData(1, "A") :: + UpperCaseData(2, "B") :: + UpperCaseData(3, "C") :: + UpperCaseData(4, "D") :: + UpperCaseData(5, "E") :: + UpperCaseData(6, "F") :: Nil).toDF() + df.registerTempTable("upperCaseData") + df + } + + protected lazy val lowerCaseData: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + LowerCaseData(1, "a") :: + LowerCaseData(2, "b") :: + LowerCaseData(3, "c") :: + LowerCaseData(4, "d") :: Nil).toDF() + df.registerTempTable("lowerCaseData") + df + } + + protected lazy val arrayData: RDD[ArrayData] = { + val rdd = sqlContext.sparkContext.parallelize( + ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) :: + ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil) + rdd.toDF().registerTempTable("arrayData") + rdd + } + + protected lazy val mapData: RDD[MapData] = { + val rdd = sqlContext.sparkContext.parallelize( + MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) :: + MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) :: + MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) :: + MapData(Map(1 -> "a4", 2 -> "b4")) :: + MapData(Map(1 -> "a5")) :: Nil) + rdd.toDF().registerTempTable("mapData") + rdd + } + + protected lazy val repeatedData: RDD[StringData] = { + val rdd = sqlContext.sparkContext.parallelize(List.fill(2)(StringData("test"))) + rdd.toDF().registerTempTable("repeatedData") + rdd + } + + 
protected lazy val nullableRepeatedData: RDD[StringData] = { + val rdd = sqlContext.sparkContext.parallelize( + List.fill(2)(StringData(null)) ++ + List.fill(2)(StringData("test"))) + rdd.toDF().registerTempTable("nullableRepeatedData") + rdd + } + + protected lazy val nullInts: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + NullInts(1) :: + NullInts(2) :: + NullInts(3) :: + NullInts(null) :: Nil).toDF() + df.registerTempTable("nullInts") + df + } + + protected lazy val allNulls: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + NullInts(null) :: + NullInts(null) :: + NullInts(null) :: + NullInts(null) :: Nil).toDF() + df.registerTempTable("allNulls") + df + } + + protected lazy val nullStrings: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + NullStrings(1, "abc") :: + NullStrings(2, "ABC") :: + NullStrings(3, null) :: Nil).toDF() + df.registerTempTable("nullStrings") + df + } + + protected lazy val tableName: DataFrame = { + val df = sqlContext.sparkContext.parallelize(TableName("test") :: Nil).toDF() + df.registerTempTable("tableName") + df + } + + protected lazy val unparsedStrings: RDD[String] = { + sqlContext.sparkContext.parallelize( + "1, A1, true, null" :: + "2, B2, false, null" :: + "3, C3, true, null" :: + "4, D4, true, 2147483644" :: Nil) + } + + // An RDD with 4 elements and 8 partitions + protected lazy val withEmptyParts: RDD[IntField] = { + val rdd = sqlContext.sparkContext.parallelize((1 to 4).map(IntField), 8) + rdd.toDF().registerTempTable("withEmptyParts") + rdd + } + + protected lazy val person: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + Person(0, "mike", 30) :: + Person(1, "jim", 20) :: Nil).toDF() + df.registerTempTable("person") + df + } + + protected lazy val salary: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + Salary(0, 2000.0) :: + Salary(1, 1000.0) :: Nil).toDF() + df.registerTempTable("salary") + df + } + + protected lazy val complexData: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1, 1, 1), true) :: + ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2, 2, 2), false) :: + Nil).toDF() + df.registerTempTable("complexData") + df + } + + protected lazy val courseSales: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + CourseSales("dotNET", 2012, 10000) :: + CourseSales("Java", 2012, 20000) :: + CourseSales("dotNET", 2012, 5000) :: + CourseSales("dotNET", 2013, 48000) :: + CourseSales("Java", 2013, 30000) :: Nil).toDF() + df.registerTempTable("courseSales") + df + } + + /** + * Initialize all test data such that all temp tables are properly registered. + */ + def loadTestData(): Unit = { + assert(sqlContext != null, "attempted to initialize test data before SQLContext.") + emptyTestData + testData + testData2 + testData3 + negativeData + largeAndSmallInts + decimalData + binaryData + upperCaseData + lowerCaseData + arrayData + mapData + repeatedData + nullableRepeatedData + nullInts + allNulls + nullStrings + tableName + unparsedStrings + withEmptyParts + person + salary + complexData + courseSales + } +} + +/** + * Case classes used in test data. 
+ */ +private[sql] object SQLTestData { + case class TestData(key: Int, value: String) + case class TestData2(a: Int, b: Int) + case class TestData3(a: Int, b: Option[Int]) + case class LargeAndSmallInts(a: Int, b: Int) + case class DecimalData(a: BigDecimal, b: BigDecimal) + case class BinaryData(a: Array[Byte], b: Int) + case class UpperCaseData(N: Int, L: String) + case class LowerCaseData(n: Int, l: String) + case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]]) + case class MapData(data: scala.collection.Map[Int, String]) + case class StringData(s: String) + case class IntField(i: Int) + case class NullInts(a: Integer) + case class NullStrings(n: Int, s: String) + case class TableName(tableName: String) + case class Person(id: Int, name: String, age: Int) + case class Salary(personId: Int, salary: Double) + case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean) + case class CourseSales(course: String, year: Int, earnings: Double) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 4c11acdab9ec..e87da1527c4d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -21,15 +21,79 @@ import java.io.File import java.util.UUID import scala.util.Try +import scala.language.implicitConversions + +import org.apache.hadoop.conf.Configuration +import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.util.Utils -trait SQLTestUtils { this: SparkFunSuite => - def sqlContext: SQLContext +/** + * Helper trait that should be extended by all SQL test suites. + * + * This allows subclasses to plugin a custom [[SQLContext]]. It comes with test data + * prepared in advance as well as all implicit conversions used extensively by dataframes. + * To use implicit methods, import `testImplicits._` instead of through the [[SQLContext]]. + * + * Subclasses should *not* create [[SQLContext]]s in the test suite constructor, which is + * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. + */ +private[sql] trait SQLTestUtils + extends SparkFunSuite + with BeforeAndAfterAll + with SQLTestData { self => + + protected def sparkContext = sqlContext.sparkContext + + // Whether to materialize all test data before the first test is run + private var loadTestDataBeforeTests = false + + // Shorthand for running a query using our SQLContext + protected lazy val sql = sqlContext.sql _ + + /** + * A helper object for importing SQL implicits. + * + * Note that the alternative of importing `sqlContext.implicits._` is not possible here. + * This is because we create the [[SQLContext]] immediately before the first test is run, + * but the implicits import is needed in the constructor. + */ + protected object testImplicits extends SQLImplicits { + protected override def _sqlContext: SQLContext = self.sqlContext + + // This must live here to preserve binary compatibility with Spark < 1.5. + implicit class StringToColumn(val sc: StringContext) { + def $(args: Any*): ColumnName = { + new ColumnName(sc.s(args: _*)) + } + } + } + + /** + * Materialize the test data immediately after the [[SQLContext]] is set up. 
+ * This is necessary if the data is accessed by name but not through direct reference. + */ + protected def setupTestData(): Unit = { + loadTestDataBeforeTests = true + } + + protected override def beforeAll(): Unit = { + super.beforeAll() + if (loadTestDataBeforeTests) { + loadTestData() + } + } - protected def configuration = sqlContext.sparkContext.hadoopConfiguration + /** + * The Hadoop configuration used by the active [[SQLContext]]. + */ + protected def hadoopConfiguration: Configuration = { + sparkContext.hadoopConfiguration + } /** * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL @@ -114,4 +178,70 @@ trait SQLTestUtils { this: SparkFunSuite => sqlContext.sql(s"USE $db") try f finally sqlContext.sql(s"USE default") } + + /** + * Strip Spark-side filtering in order to check if a datasource filters rows correctly. + */ + protected def stripSparkFilter(df: DataFrame): DataFrame = { + val schema = df.schema + val childRDD = df + .queryExecution + .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] + .child + .execute() + .map(row => Row.fromSeq(row.toSeq(schema))) + + sqlContext.createDataFrame(childRDD, schema) + } + + /** + * Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier + * way to construct [[DataFrame]] directly out of local data without relying on implicits. + */ + protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { + DataFrame(sqlContext, plan) + } +} + +private[sql] object SQLTestUtils { + + def compareAnswers( + sparkAnswer: Seq[Row], + expectedAnswer: Seq[Row], + sort: Boolean): Option[String] = { + def prepareAnswer(answer: Seq[Row]): Seq[Row] = { + // Converts data to types that we can do equality comparison using Scala collections. + // For BigDecimal type, the Scala type has a better definition of equality test (similar to + // Java's java.math.BigDecimal.compareTo). + // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for + // equality test. + // This function is copied from Catalyst's QueryTest + val converted: Seq[Row] = answer.map { s => + Row.fromSeq(s.toSeq.map { + case d: java.math.BigDecimal => BigDecimal(d) + case b: Array[Byte] => b.toSeq + case o => o + }) + } + if (sort) { + converted.sortBy(_.toString()) + } else { + converted + } + } + if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { + val errorMessage = + s""" + | == Results == + | ${sideBySide( + s"== Expected Answer - ${expectedAnswer.size} ==" +: + prepareAnswer(expectedAnswer).map(_.toString()), + s"== Actual Answer - ${sparkAnswer.size} ==" +: + prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")} + """.stripMargin + Some(errorMessage) + } else { + None + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala new file mode 100644 index 000000000000..e7b376548787 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import org.apache.spark.sql.SQLContext + + +/** + * Helper trait for SQL test suites where all tests share a single [[TestSQLContext]]. + */ +trait SharedSQLContext extends SQLTestUtils { + + /** + * The [[TestSQLContext]] to use for all tests in this suite. + * + * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local + * mode with the default test configurations. + */ + private var _ctx: TestSQLContext = null + + /** + * The [[TestSQLContext]] to use for all tests in this suite. + */ + protected def sqlContext: SQLContext = _ctx + + /** + * Initialize the [[TestSQLContext]]. + */ + protected override def beforeAll(): Unit = { + SQLContext.clearSqlListener() + if (_ctx == null) { + _ctx = new TestSQLContext + } + // Ensure we have initialized the context before calling parent code + super.beforeAll() + } + + /** + * Stop the underlying [[org.apache.spark.SparkContext]], if any. + */ + protected override def afterAll(): Unit = { + try { + if (_ctx != null) { + _ctx.sparkContext.stop() + _ctx = null + } + } finally { + super.afterAll() + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala new file mode 100644 index 000000000000..c89a1516503e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.{SQLConf, SQLContext} + + +/** + * A special [[SQLContext]] prepared for testing. 
+ */ +private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { self => + + def this() { + this(new SparkContext("local[2]", "test-sql-context", + new SparkConf().set("spark.sql.testkey", "true"))) + } + + protected[sql] override lazy val conf: SQLConf = new SQLConf { + + clear() + + override def clear(): Unit = { + super.clear() + + // Make sure we start with the default test configs even after clear + TestSQLContext.overrideConfs.map { + case (key, value) => setConfString(key, value) + } + } + } + + // Needed for Java tests + def loadTestData(): Unit = { + testData.loadTestData() + } + + private object testData extends SQLTestData { + protected override def sqlContext: SQLContext = self + } +} + +private[sql] object TestSQLContext { + + /** + * A map used to store all confs that need to be overridden in sql/core unit tests. + */ + val overrideConfs: Map[String, String] = + Map( + // Fewer shuffle partitions to speed up testing. + SQLConf.SHUFFLE_PARTITIONS.key -> "5") +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala new file mode 100644 index 000000000000..b46b0d2f6040 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.util + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark._ +import org.apache.spark.sql.{functions, QueryTest} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project} +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.test.SharedSQLContext + +class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + import functions._ + + test("execute callback functions when a DataFrame action finished successfully") { + val metrics = ArrayBuffer.empty[(String, QueryExecution, Long)] + val listener = new QueryExecutionListener { + // Only test successful case here, so no need to implement `onFailure` + override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {} + + override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { + metrics += ((funcName, qe, duration)) + } + } + sqlContext.listenerManager.register(listener) + + val df = Seq(1 -> "a").toDF("i", "j") + df.select("i").collect() + df.filter($"i" > 0).count() + + assert(metrics.length == 2) + + assert(metrics(0)._1 == "collect") + assert(metrics(0)._2.analyzed.isInstanceOf[Project]) + assert(metrics(0)._3 > 0) + + assert(metrics(1)._1 == "count") + assert(metrics(1)._2.analyzed.isInstanceOf[Aggregate]) + assert(metrics(1)._3 > 0) + + sqlContext.listenerManager.unregister(listener) + } + + test("execute callback functions when a DataFrame action failed") { + val metrics = ArrayBuffer.empty[(String, QueryExecution, Exception)] + val listener = new QueryExecutionListener { + override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { + metrics += ((funcName, qe, exception)) + } + + // Only test failed case here, so no need to implement `onSuccess` + override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {} + } + sqlContext.listenerManager.register(listener) + + val errorUdf = udf[Int, Int] { _ => throw new RuntimeException("udf error") } + val df = sparkContext.makeRDD(Seq(1 -> "a")).toDF("i", "j") + + // Ignore the log when we are expecting an exception. + sparkContext.setLogLevel("FATAL") + val e = intercept[SparkException](df.select(errorUdf($"i")).collect()) + + assert(metrics.length == 1) + assert(metrics(0)._1 == "collect") + assert(metrics(0)._2.analyzed.isInstanceOf[Project]) + assert(metrics(0)._3.getMessage == e.getMessage) + + sqlContext.listenerManager.unregister(listener) + } + + test("get numRows metrics by callback") { + val metrics = ArrayBuffer.empty[Long] + val listener = new QueryExecutionListener { + // Only test successful case here, so no need to implement `onFailure` + override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {} + + override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { + metrics += qe.executedPlan.longMetric("numInputRows").value.value + } + } + sqlContext.listenerManager.register(listener) + + val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count() + df.collect() + df.collect() + Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect() + + assert(metrics.length == 3) + assert(metrics(0) == 1) + assert(metrics(1) == 1) + assert(metrics(2) == 2) + + sqlContext.listenerManager.unregister(listener) + } + + // TODO: Currently some LongSQLMetric use -1 as initial value, so if the accumulator is never + // updated, we can filter it out later. 
However, when we aggregate(sum) accumulator values at + // driver side for SQL physical operators, these -1 values will make our result smaller. + // A easy fix is to create a new SQLMetric(including new MetricValue, MetricParam, etc.), but we + // can do it later because the impact is just too small (1048576 tasks for 1 MB). + ignore("get size metrics by callback") { + val metrics = ArrayBuffer.empty[Long] + val listener = new QueryExecutionListener { + // Only test successful case here, so no need to implement `onFailure` + override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {} + + override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { + metrics += qe.executedPlan.longMetric("dataSize").value.value + val bottomAgg = qe.executedPlan.children(0).children(0) + metrics += bottomAgg.longMetric("dataSize").value.value + } + } + sqlContext.listenerManager.register(listener) + + val sparkListener = new SaveInfoListener + sqlContext.sparkContext.addSparkListener(sparkListener) + + val df = (1 to 100).map(i => i -> i.toString).toDF("i", "j") + df.groupBy("i").count().collect() + + def getPeakExecutionMemory(stageId: Int): Long = { + val peakMemoryAccumulator = sparkListener.getCompletedStageInfos(stageId).accumulables + .filter(_._2.name == InternalAccumulator.PEAK_EXECUTION_MEMORY) + + assert(peakMemoryAccumulator.size == 1) + peakMemoryAccumulator.head._2.value.toLong + } + + assert(sparkListener.getCompletedStageInfos.length == 2) + val bottomAggDataSize = getPeakExecutionMemory(0) + val topAggDataSize = getPeakExecutionMemory(1) + + // For this simple case, the peakExecutionMemory of a stage should be the data size of the + // aggregate operator, as we only have one memory consuming operator per stage. + assert(metrics.length == 2) + assert(metrics(0) == topAggDataSize) + assert(metrics(1) == bottomAggDataSize) + + sqlContext.listenerManager.unregister(listener) + } +} diff --git a/sql/core/src/test/scripts/gen-code.sh b/sql/core/src/test/scripts/gen-avro.sh similarity index 76% rename from sql/core/src/test/scripts/gen-code.sh rename to sql/core/src/test/scripts/gen-avro.sh index 5d8d8ad08555..48174b287fd7 100755 --- a/sql/core/src/test/scripts/gen-code.sh +++ b/sql/core/src/test/scripts/gen-avro.sh @@ -22,10 +22,9 @@ cd - rm -rf $BASEDIR/gen-java mkdir -p $BASEDIR/gen-java -thrift\ - --gen java\ - -out $BASEDIR/gen-java\ - $BASEDIR/thrift/parquet-compat.thrift - -avro-tools idl $BASEDIR/avro/parquet-compat.avdl > $BASEDIR/avro/parquet-compat.avpr -avro-tools compile -string protocol $BASEDIR/avro/parquet-compat.avpr $BASEDIR/gen-java +for input in `ls $BASEDIR/avro/*.avdl`; do + filename=$(basename "$input") + filename="${filename%.*}" + avro-tools idl $input> $BASEDIR/avro/${filename}.avpr + avro-tools compile -string protocol $BASEDIR/avro/${filename}.avpr $BASEDIR/gen-java +done diff --git a/sql/core/src/test/scripts/gen-thrift.sh b/sql/core/src/test/scripts/gen-thrift.sh new file mode 100755 index 000000000000..ada432c68ab9 --- /dev/null +++ b/sql/core/src/test/scripts/gen-thrift.sh @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +cd $(dirname $0)/.. +BASEDIR=`pwd` +cd - + +rm -rf $BASEDIR/gen-java +mkdir -p $BASEDIR/gen-java + +for input in `ls $BASEDIR/thrift/*.thrift`; do + thrift --gen java -out $BASEDIR/gen-java $input +done diff --git a/sql/core/src/test/thrift/parquet-compat.thrift b/sql/core/src/test/thrift/parquet-compat.thrift index fa5ed8c62306..98bf778aec5d 100644 --- a/sql/core/src/test/thrift/parquet-compat.thrift +++ b/sql/core/src/test/thrift/parquet-compat.thrift @@ -15,7 +15,7 @@ * limitations under the License. */ -namespace java org.apache.spark.sql.parquet.test.thrift +namespace java org.apache.spark.sql.execution.datasources.parquet.test.thrift enum Suit { SPADES, diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 73e6ccdb1eaf..b5b2143292a6 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml @@ -60,21 +60,42 @@ ${hive.group} hive-jdbc + + ${hive.group} + hive-service + ${hive.group} hive-beeline + + com.sun.jersey + jersey-core + + + com.sun.jersey + jersey-json + + + com.sun.jersey + jersey-server + org.seleniumhq.selenium selenium-java test - - - io.netty - netty - - + + + org.apache.spark + spark-sql_${scala.binary.version} + test-jar + ${project.version} + test + + + org.apache.spark + spark-test-tags_${scala.binary.version} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/hive/service/server/HiveServerServerOptionsProcessor.scala b/sql/hive-thriftserver/src/main/scala/org/apache/hive/service/server/HiveServerServerOptionsProcessor.scala new file mode 100644 index 000000000000..2228f651e238 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/hive/service/server/HiveServerServerOptionsProcessor.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.hive.service.server + +import org.apache.hive.service.server.HiveServer2.{StartOptionExecutor, ServerOptionsProcessor} + +/** + * Class to upgrade a package-private class to public, and + * implement a `process()` operation consistent with + * the behavior of older Hive versions + * @param serverName name of the hive server + */ +private[apache] class HiveServerServerOptionsProcessor(serverName: String) + extends ServerOptionsProcessor(serverName) { + + def process(args: Array[String]): Boolean = { + // A parse failure automatically triggers a system exit + val response = super.parse(args) + val executor = response.getServerOptionsExecutor() + // return true if the parsed option was to start the service + executor.isInstanceOf[StartOptionExecutor] + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index b7db80d93f85..a4fd0c3ce970 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.hive.thriftserver +import java.util.Locale +import java.util.concurrent.atomic.AtomicBoolean + import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -24,7 +27,7 @@ import org.apache.commons.logging.LogFactory import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hive.service.cli.thrift.{ThriftBinaryCLIService, ThriftHttpCLIService} -import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor} +import org.apache.hive.service.server.{HiveServerServerOptionsProcessor, HiveServer2} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd, SparkListenerJobStart} @@ -32,7 +35,7 @@ import org.apache.spark.sql.SQLConf import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.thriftserver.ui.ThriftServerTab -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} import org.apache.spark.{Logging, SparkContext} @@ -52,7 +55,6 @@ object HiveThriftServer2 extends Logging { @DeveloperApi def startWithContext(sqlContext: HiveContext): Unit = { val server = new HiveThriftServer2(sqlContext) - sqlContext.setConf("spark.sql.hive.version", HiveContext.hiveExecutionVersion) server.init(sqlContext.hiveconf) server.start() listener = new HiveThriftServer2Listener(server, sqlContext.conf) @@ -65,7 +67,7 @@ object HiveThriftServer2 extends Logging { } def main(args: Array[String]) { - val optionsProcessor = new ServerOptionsProcessor("HiveThriftServer2") + val optionsProcessor = new HiveServerServerOptionsProcessor("HiveThriftServer2") if (!optionsProcessor.process(args)) { System.exit(-1) } @@ -73,7 +75,7 @@ object HiveThriftServer2 extends Logging { logInfo("Starting SparkContext") SparkSQLEnv.init() - Utils.addShutdownHook { () => + ShutdownHookManager.addShutdownHook { () => SparkSQLEnv.stop() uiTab.foreach(_.detach()) } @@ -90,6 +92,12 @@ object HiveThriftServer2 extends Logging { } else { None } + // If application was killed before HiveThriftServer2 start successfully then SparkSubmit + // process can not exit, so check whether if 
SparkContext was stopped. + if (SparkSQLEnv.sparkContext.stopped.get()) { + logError("SparkContext has stopped even if HiveServer2 has started, so exit") + System.exit(-1) + } } catch { case e: Exception => logError("Error starting HiveThriftServer2", e) @@ -149,16 +157,26 @@ object HiveThriftServer2 extends Logging { override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { server.stop() } - var onlineSessionNum: Int = 0 - val sessionList = new mutable.LinkedHashMap[String, SessionInfo] - val executionList = new mutable.LinkedHashMap[String, ExecutionInfo] - val retainedStatements = - conf.getConf(SQLConf.THRIFTSERVER_UI_STATEMENT_LIMIT) - val retainedSessions = - conf.getConf(SQLConf.THRIFTSERVER_UI_SESSION_LIMIT) - var totalRunning = 0 - - override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + private var onlineSessionNum: Int = 0 + private val sessionList = new mutable.LinkedHashMap[String, SessionInfo] + private val executionList = new mutable.LinkedHashMap[String, ExecutionInfo] + private val retainedStatements = conf.getConf(SQLConf.THRIFTSERVER_UI_STATEMENT_LIMIT) + private val retainedSessions = conf.getConf(SQLConf.THRIFTSERVER_UI_SESSION_LIMIT) + private var totalRunning = 0 + + def getOnlineSessionNum: Int = synchronized { onlineSessionNum } + + def getTotalRunning: Int = synchronized { totalRunning } + + def getSessionList: Seq[SessionInfo] = synchronized { sessionList.values.toSeq } + + def getSession(sessionId: String): Option[SessionInfo] = synchronized { + sessionList.get(sessionId) + } + + def getExecutionList: Seq[ExecutionInfo] = synchronized { executionList.values.toSeq } + + override def onJobStart(jobStart: SparkListenerJobStart): Unit = synchronized { for { props <- Option(jobStart.properties) groupId <- Option(props.getProperty(SparkContext.SPARK_JOB_GROUP_ID)) @@ -170,13 +188,15 @@ object HiveThriftServer2 extends Logging { } def onSessionCreated(ip: String, sessionId: String, userName: String = "UNKNOWN"): Unit = { - val info = new SessionInfo(sessionId, System.currentTimeMillis, ip, userName) - sessionList.put(sessionId, info) - onlineSessionNum += 1 - trimSessionIfNecessary() + synchronized { + val info = new SessionInfo(sessionId, System.currentTimeMillis, ip, userName) + sessionList.put(sessionId, info) + onlineSessionNum += 1 + trimSessionIfNecessary() + } } - def onSessionClosed(sessionId: String): Unit = { + def onSessionClosed(sessionId: String): Unit = synchronized { sessionList(sessionId).finishTimestamp = System.currentTimeMillis onlineSessionNum -= 1 trimSessionIfNecessary() @@ -187,7 +207,7 @@ object HiveThriftServer2 extends Logging { sessionId: String, statement: String, groupId: String, - userName: String = "UNKNOWN"): Unit = { + userName: String = "UNKNOWN"): Unit = synchronized { val info = new ExecutionInfo(statement, sessionId, System.currentTimeMillis, userName) info.state = ExecutionState.STARTED executionList.put(id, info) @@ -197,27 +217,29 @@ object HiveThriftServer2 extends Logging { totalRunning += 1 } - def onStatementParsed(id: String, executionPlan: String): Unit = { + def onStatementParsed(id: String, executionPlan: String): Unit = synchronized { executionList(id).executePlan = executionPlan executionList(id).state = ExecutionState.COMPILED } def onStatementError(id: String, errorMessage: String, errorTrace: String): Unit = { - executionList(id).finishTimestamp = System.currentTimeMillis - executionList(id).detail = errorMessage - executionList(id).state = ExecutionState.FAILED - totalRunning 
-= 1 - trimExecutionIfNecessary() + synchronized { + executionList(id).finishTimestamp = System.currentTimeMillis + executionList(id).detail = errorMessage + executionList(id).state = ExecutionState.FAILED + totalRunning -= 1 + trimExecutionIfNecessary() + } } - def onStatementFinish(id: String): Unit = { + def onStatementFinish(id: String): Unit = synchronized { executionList(id).finishTimestamp = System.currentTimeMillis executionList(id).state = ExecutionState.FINISHED totalRunning -= 1 trimExecutionIfNecessary() } - private def trimExecutionIfNecessary() = synchronized { + private def trimExecutionIfNecessary() = { if (executionList.size > retainedStatements) { val toRemove = math.max(retainedStatements / 10, 1) executionList.filter(_._2.finishTimestamp != 0).take(toRemove).foreach { s => @@ -226,7 +248,7 @@ object HiveThriftServer2 extends Logging { } } - private def trimSessionIfNecessary() = synchronized { + private def trimSessionIfNecessary() = { if (sessionList.size > retainedSessions) { val toRemove = math.max(retainedSessions / 10, 1) sessionList.filter(_._2.finishTimestamp != 0).take(toRemove).foreach { s => @@ -241,9 +263,12 @@ object HiveThriftServer2 extends Logging { private[hive] class HiveThriftServer2(hiveContext: HiveContext) extends HiveServer2 with ReflectedCompositeService { + // state is tracked internally so that the server only attempts to shut down if it successfully + // started, and then once only. + private val started = new AtomicBoolean(false) override def init(hiveConf: HiveConf) { - val sparkSqlCliService = new SparkSQLCLIService(hiveContext) + val sparkSqlCliService = new SparkSQLCLIService(this, hiveContext) setSuperField(this, "cliService", sparkSqlCliService) addService(sparkSqlCliService) @@ -259,8 +284,19 @@ private[hive] class HiveThriftServer2(hiveContext: HiveContext) } private def isHTTPTransportMode(hiveConf: HiveConf): Boolean = { - val transportMode: String = hiveConf.getVar(ConfVars.HIVE_SERVER2_TRANSPORT_MODE) - transportMode.equalsIgnoreCase("http") + val transportMode = hiveConf.getVar(ConfVars.HIVE_SERVER2_TRANSPORT_MODE) + transportMode.toLowerCase(Locale.ENGLISH).equals("http") } + + override def start(): Unit = { + super.start() + started.set(true) + } + + override def stop(): Unit = { + if (started.getAndSet(false)) { + super.stop() + } + } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index e8758887ff3a..e022ee86a763 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -20,20 +20,15 @@ package org.apache.spark.sql.hive.thriftserver import java.security.PrivilegedExceptionAction import java.sql.{Date, Timestamp} import java.util.concurrent.RejectedExecutionException -import java.util.{Map => JMap, UUID} +import java.util.{Arrays, UUID, Map => JMap} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, Map => SMap} import scala.util.control.NonFatal -import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.api.FieldSchema +import org.apache.hadoop.hive.shims.Utils import org.apache.hive.service.cli._ -import org.apache.hadoop.hive.ql.metadata.Hive 
-import org.apache.hadoop.hive.ql.metadata.HiveException -import org.apache.hadoop.hive.ql.session.SessionState -import org.apache.hadoop.hive.shims.ShimLoader -import org.apache.hadoop.security.UserGroupInformation import org.apache.hive.service.cli.operation.ExecuteStatementOperation import org.apache.hive.service.cli.session.HiveSession @@ -41,7 +36,7 @@ import org.apache.spark.Logging import org.apache.spark.sql.execution.SetCommand import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} import org.apache.spark.sql.types._ -import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLConf} +import org.apache.spark.sql.{DataFrame, SQLConf, Row => SparkRow} private[hive] class SparkExecuteStatementOperation( @@ -58,6 +53,18 @@ private[hive] class SparkExecuteStatementOperation( private var dataTypes: Array[DataType] = _ private var statementId: String = _ + private lazy val resultSchema: TableSchema = { + if (result == null || result.queryExecution.analyzed.output.size == 0) { + new TableSchema(Arrays.asList(new FieldSchema("Result", "string", ""))) + } else { + logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}") + val schema = result.queryExecution.analyzed.output.map { attr => + new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") + } + new TableSchema(schema.asJava) + } + } + def close(): Unit = { // RDDs will be cleaned automatically upon garbage collection. hiveContext.sparkContext.clearJobGroup() @@ -125,56 +132,36 @@ private[hive] class SparkExecuteStatementOperation( } } - def getResultSetSchema: TableSchema = { - if (result == null || result.queryExecution.analyzed.output.size == 0) { - new TableSchema(new FieldSchema("Result", "string", "") :: Nil) - } else { - logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}") - val schema = result.queryExecution.analyzed.output.map { attr => - new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") - } - new TableSchema(schema) - } - } + def getResultSetSchema: TableSchema = resultSchema - override def run(): Unit = { + override def runInternal(): Unit = { setState(OperationState.PENDING) setHasResultSet(true) // avoid no resultset for async run if (!runInBackground) { - runInternal() + execute() } else { - val parentSessionState = SessionState.get() - val hiveConf = getConfigForOperation() - val sparkServiceUGI = ShimLoader.getHadoopShims.getUGIForConf(hiveConf) - val sessionHive = getCurrentHive() - val currentSqlSession = hiveContext.currentSession + val sparkServiceUGI = Utils.getUGI() // Runnable impl to call runInternal asynchronously, // from a different thread val backgroundOperation = new Runnable() { override def run(): Unit = { - val doAsAction = new PrivilegedExceptionAction[Object]() { - override def run(): Object = { - - // User information is part of the metastore client member in Hive - hiveContext.setSession(currentSqlSession) - Hive.set(sessionHive) - SessionState.setCurrentSessionState(parentSessionState) + val doAsAction = new PrivilegedExceptionAction[Unit]() { + override def run(): Unit = { try { - runInternal() + execute() } catch { case e: HiveSQLException => setOperationException(e) log.error("Error running hive query: ", e) } - return null } } try { - ShimLoader.getHadoopShims().doAs(sparkServiceUGI, doAsAction) + sparkServiceUGI.doAs(doAsAction) } catch { case e: Exception => setOperationException(new HiveSQLException(e)) @@ -186,7 +173,7 @@ private[hive] class SparkExecuteStatementOperation( try { // This submit blocks 
if no background threads are available to run this operation val backgroundHandle = - getParentSession().getSessionManager().submitBackgroundOperation(backgroundOperation) + parentSession.getSessionManager().submitBackgroundOperation(backgroundOperation) setBackgroundHandle(backgroundHandle) } catch { case rejected: RejectedExecutionException => @@ -201,10 +188,15 @@ private[hive] class SparkExecuteStatementOperation( } } - private def runInternal(): Unit = { + private def execute(): Unit = { statementId = UUID.randomUUID().toString logInfo(s"Running query '$statement' with $statementId") setState(OperationState.RUNNING) + // Always use the latest class loader provided by executionHive's state. + val executionHiveClassLoader = + hiveContext.executionHive.state.getConf.getClassLoader + Thread.currentThread().setContextClassLoader(executionHiveClassLoader) + HiveThriftServer2.listener.onStatementStart( statementId, parentSession.getSessionHandle.getSessionId.toString, @@ -274,43 +266,4 @@ private[hive] class SparkExecuteStatementOperation( } } } - - /** - * If there are query specific settings to overlay, then create a copy of config - * There are two cases we need to clone the session config that's being passed to hive driver - * 1. Async query - - * If the client changes a config setting, that shouldn't reflect in the execution - * already underway - * 2. confOverlay - - * The query specific settings should only be applied to the query config and not session - * @return new configuration - * @throws HiveSQLException - */ - private def getConfigForOperation(): HiveConf = { - var sqlOperationConf = getParentSession().getHiveConf() - if (!getConfOverlay().isEmpty() || runInBackground) { - // clone the partent session config for this query - sqlOperationConf = new HiveConf(sqlOperationConf) - - // apply overlay query specific settings, if any - getConfOverlay().foreach { case (k, v) => - try { - sqlOperationConf.verifyAndSet(k, v) - } catch { - case e: IllegalArgumentException => - throw new HiveSQLException("Error applying statement specific settings", e) - } - } - } - return sqlOperationConf - } - - private def getCurrentHive(): Hive = { - try { - return Hive.get() - } catch { - case e: HiveException => - throw new HiveSQLException("Failed to get current Hive object", e); - } - } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index f66a17b20915..8e7aa75bc3b2 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql.hive.thriftserver -import scala.collection.JavaConversions._ - import java.io._ -import java.util.{ArrayList => JArrayList} +import java.util.{ArrayList => JArrayList, Locale} + +import org.apache.spark.sql.AnalysisException + +import scala.collection.JavaConverters._ -import jline.{ConsoleReader, History} +import jline.console.ConsoleReader +import jline.console.history.FileHistory import org.apache.commons.lang3.StringUtils import org.apache.commons.logging.LogFactory @@ -38,8 +41,12 @@ import org.apache.thrift.transport.TSocket import org.apache.spark.Logging import org.apache.spark.sql.hive.HiveContext -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} +/** + * 
This code doesn't support remote connections in Hive 1.2+, as the underlying CliDriver + * has dropped its support. + */ private[hive] object SparkSQLCLIDriver extends Logging { private var prompt = "spark-sql" private var continuedPrompt = "".padTo(prompt.length, ' ') @@ -76,7 +83,7 @@ private[hive] object SparkSQLCLIDriver extends Logging { val cliConf = new HiveConf(classOf[SessionState]) // Override the location of the metastore since this is only used for local execution. - HiveContext.newTemporaryConfiguration().foreach { + HiveContext.newTemporaryConfiguration(useInMemoryDerby = false).foreach { case (key, value) => cliConf.set(key, value) } val sessionState = new CliSessionState(cliConf) @@ -96,9 +103,9 @@ private[hive] object SparkSQLCLIDriver extends Logging { // Set all properties specified via command line. val conf: HiveConf = sessionState.getConf - sessionState.cmdProperties.entrySet().foreach { item => - val key = item.getKey.asInstanceOf[String] - val value = item.getValue.asInstanceOf[String] + sessionState.cmdProperties.entrySet().asScala.foreach { item => + val key = item.getKey.toString + val value = item.getValue.toString // We do not propagate metastore options to the execution copy of hive. if (key != "javax.jdo.option.ConnectionURL") { conf.set(key, value) @@ -109,18 +116,11 @@ private[hive] object SparkSQLCLIDriver extends Logging { SessionState.start(sessionState) // Clean up after we exit - Utils.addShutdownHook { () => SparkSQLEnv.stop() } + ShutdownHookManager.addShutdownHook { () => SparkSQLEnv.stop() } + val remoteMode = isRemoteMode(sessionState) // "-h" option has been passed, so connect to Hive thrift server. - if (sessionState.getHost != null) { - sessionState.connect() - if (sessionState.isRemoteMode) { - prompt = s"[${sessionState.getHost}:${sessionState.getPort}]" + prompt - continuedPrompt = "".padTo(prompt.length, ' ') - } - } - - if (!sessionState.isRemoteMode) { + if (!remoteMode) { // Hadoop-20 and above - we need to augment classpath using hiveconf // components. // See also: code in ExecDriver.java @@ -131,6 +131,9 @@ private[hive] object SparkSQLCLIDriver extends Logging { } conf.setClassLoader(loader) Thread.currentThread().setContextClassLoader(loader) + } else { + // Hive 1.2 + not supported in CLI + throw new RuntimeException("Remote operations not supported") } val cli = new SparkSQLCLIDriver @@ -170,15 +173,16 @@ private[hive] object SparkSQLCLIDriver extends Logging { val reader = new ConsoleReader() reader.setBellEnabled(false) + reader.setExpandEvents(false) // reader.setDebug(new PrintWriter(new FileWriter("writer.debug", true))) - CliDriver.getCommandCompletor.foreach((e) => reader.addCompletor(e)) + CliDriver.getCommandCompleter.foreach((e) => reader.addCompleter(e)) val historyDirectory = System.getProperty("user.home") try { if (new File(historyDirectory).exists()) { val historyFile = historyDirectory + File.separator + ".hivehistory" - reader.setHistory(new History(new File(historyFile))) + reader.setHistory(new FileHistory(new File(historyFile))) } else { logWarning("WARNING: Directory for Hive history file: " + historyDirectory + " does not exist. 
History will not be available during this session.") @@ -190,10 +194,28 @@ private[hive] object SparkSQLCLIDriver extends Logging { logWarning(e.getMessage) } + // add shutdown hook to flush the history to history file + ShutdownHookManager.addShutdownHook { () => + reader.getHistory match { + case h: FileHistory => + try { + h.flush() + } catch { + case e: IOException => + logWarning("WARNING: Failed to write command history file: " + e.getMessage) + } + case _ => + } + } + + // TODO: missing +/* val clientTransportTSocketField = classOf[CliSessionState].getDeclaredField("transport") clientTransportTSocketField.setAccessible(true) transport = clientTransportTSocketField.get(sessionState).asInstanceOf[TSocket] +*/ + transport = null var ret = 0 var prefix = "" @@ -230,6 +252,13 @@ private[hive] object SparkSQLCLIDriver extends Logging { System.exit(ret) } + + + def isRemoteMode(state: CliSessionState): Boolean = { + // sessionState.isRemoteMode + state.isHiveServerQuery + } + } private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { @@ -239,25 +268,33 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { private val console = new SessionState.LogHelper(LOG) + private val isRemoteMode = { + SparkSQLCLIDriver.isRemoteMode(sessionState) + } + private val conf: Configuration = if (sessionState != null) sessionState.getConf else new Configuration() // Force initializing SparkSQLEnv. This is put here but not object SparkSQLCliDriver // because the Hive unit tests do not go through the main() code path. - if (!sessionState.isRemoteMode) { + if (!isRemoteMode) { SparkSQLEnv.init() + } else { + // Hive 1.2 + not supported in CLI + throw new RuntimeException("Remote operations not supported") } override def processCmd(cmd: String): Int = { val cmd_trimmed: String = cmd.trim() + val cmd_lower = cmd_trimmed.toLowerCase(Locale.ENGLISH) val tokens: Array[String] = cmd_trimmed.split("\\s+") val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim() - if (cmd_trimmed.toLowerCase.equals("quit") || - cmd_trimmed.toLowerCase.equals("exit") || - tokens(0).equalsIgnoreCase("source") || + if (cmd_lower.equals("quit") || + cmd_lower.equals("exit") || + tokens(0).toLowerCase(Locale.ENGLISH).equals("source") || cmd_trimmed.startsWith("!") || tokens(0).toLowerCase.equals("list") || - sessionState.isRemoteMode) { + isRemoteMode) { val start = System.currentTimeMillis() super.processCmd(cmd) val end = System.currentTimeMillis() @@ -267,7 +304,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { } else { var ret = 0 val hconf = conf.asInstanceOf[HiveConf] - val proc: CommandProcessor = CommandProcessorFactory.get(Array(tokens(0)), hconf) + val proc: CommandProcessor = CommandProcessorFactory.get(tokens, hconf) if (proc != null) { // scalastyle:off println @@ -277,6 +314,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { driver.init() val out = sessionState.out + val err = sessionState.err val start: Long = System.currentTimeMillis() if (sessionState.getIsVerbose) { out.println(cmd) @@ -287,7 +325,12 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { ret = rc.getResponseCode if (ret != 0) { - console.printError(rc.getErrorMessage()) + // For analysis exception, only the error is printed out to the console. 
+ rc.getException() match { + case e : AnalysisException => + err.println(s"""Error in query: ${e.getMessage}""") + case _ => err.println(rc.getErrorMessage()) + } driver.close() return ret } @@ -296,15 +339,15 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { if (HiveConf.getBoolVar(conf, HiveConf.ConfVars.HIVE_CLI_PRINT_HEADER)) { // Print the column names. - Option(driver.getSchema.getFieldSchemas).map { fields => - out.println(fields.map(_.getName).mkString("\t")) + Option(driver.getSchema.getFieldSchemas).foreach { fields => + out.println(fields.asScala.map(_.getName).mkString("\t")) } } var counter = 0 try { while (!out.checkError() && driver.getResults(res)) { - res.foreach{ l => + res.asScala.foreach { l => counter += 1 out.println(l) } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala index 41f647d5f8c5..5ad8c54f296d 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala @@ -21,36 +21,37 @@ import java.io.IOException import java.util.{List => JList} import javax.security.auth.login.LoginException +import scala.collection.JavaConverters._ + import org.apache.commons.logging.Log import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.shims.ShimLoader +import org.apache.hadoop.hive.shims.Utils import org.apache.hadoop.security.UserGroupInformation import org.apache.hive.service.Service.STATE import org.apache.hive.service.auth.HiveAuthFactory import org.apache.hive.service.cli._ +import org.apache.hive.service.server.HiveServer2 import org.apache.hive.service.{AbstractService, Service, ServiceException} import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ -import scala.collection.JavaConversions._ - -private[hive] class SparkSQLCLIService(hiveContext: HiveContext) - extends CLIService +private[hive] class SparkSQLCLIService(hiveServer: HiveServer2, hiveContext: HiveContext) + extends CLIService(hiveServer) with ReflectedCompositeService { override def init(hiveConf: HiveConf) { setSuperField(this, "hiveConf", hiveConf) - val sparkSqlSessionManager = new SparkSQLSessionManager(hiveContext) + val sparkSqlSessionManager = new SparkSQLSessionManager(hiveServer, hiveContext) setSuperField(this, "sessionManager", sparkSqlSessionManager) addService(sparkSqlSessionManager) var sparkServiceUGI: UserGroupInformation = null - if (ShimLoader.getHadoopShims.isSecurityEnabled) { + if (UserGroupInformation.isSecurityEnabled) { try { HiveAuthFactory.loginFromKeytab(hiveConf) - sparkServiceUGI = ShimLoader.getHadoopShims.getUGIForConf(hiveConf) + sparkServiceUGI = Utils.getUGI() setSuperField(this, "serviceUGI", sparkServiceUGI) } catch { case e @ (_: IOException | _: LoginException) => @@ -75,7 +76,7 @@ private[thriftserver] trait ReflectedCompositeService { this: AbstractService => def initCompositeService(hiveConf: HiveConf) { // Emulating `CompositeService.init(hiveConf)` val serviceList = getAncestorField[JList[Service]](this, 2, "serviceList") - serviceList.foreach(_.init(hiveConf)) + serviceList.asScala.foreach(_.init(hiveConf)) // Emulating `AbstractService.init(hiveConf)` invoke(classOf[AbstractService], this, "ensureCurrentState", classOf[STATE] -> STATE.NOTINITED) 
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala index 77272aecf283..f1ec7238520a 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -17,7 +17,11 @@ package org.apache.spark.sql.hive.thriftserver -import java.util.{ArrayList => JArrayList, List => JList} +import java.util.{Arrays, ArrayList => JArrayList, List => JList} +import org.apache.log4j.LogManager +import org.apache.spark.sql.AnalysisException + +import scala.collection.JavaConverters._ import org.apache.commons.lang3.exception.ExceptionUtils import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema} @@ -27,8 +31,6 @@ import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse import org.apache.spark.Logging import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} -import scala.collection.JavaConversions._ - private[hive] class SparkSQLDriver( val context: HiveContext = SparkSQLEnv.hiveContext) extends Driver @@ -43,14 +45,14 @@ private[hive] class SparkSQLDriver( private def getResultSetSchema(query: context.QueryExecution): Schema = { val analyzed = query.analyzed logDebug(s"Result Schema: ${analyzed.output}") - if (analyzed.output.size == 0) { - new Schema(new FieldSchema("Response code", "string", "") :: Nil, null) + if (analyzed.output.isEmpty) { + new Schema(Arrays.asList(new FieldSchema("Response code", "string", "")), null) } else { val fieldSchemas = analyzed.output.map { attr => new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") } - new Schema(fieldSchemas, null) + new Schema(fieldSchemas.asJava, null) } } @@ -63,9 +65,12 @@ private[hive] class SparkSQLDriver( tableSchema = getResultSetSchema(execution) new CommandProcessorResponse(0) } catch { - case cause: Throwable => - logError(s"Failed in [$command]", cause) - new CommandProcessorResponse(1, ExceptionUtils.getStackTrace(cause), null) + case ae: AnalysisException => + logDebug(s"Failed in [$command]", ae) + new CommandProcessorResponse(1, ExceptionUtils.getStackTrace(ae), null, ae) + case cause: Throwable => + logError(s"Failed in [$command]", cause) + new CommandProcessorResponse(1, ExceptionUtils.getStackTrace(cause), null, cause) } } @@ -79,7 +84,7 @@ private[hive] class SparkSQLDriver( if (hiveResponse == null) { false } else { - res.asInstanceOf[JArrayList[String]].addAll(hiveResponse) + res.asInstanceOf[JArrayList[String]].addAll(hiveResponse.asJava) hiveResponse = null true } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 1d41c4613182..bacf6cc458fd 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.hive.thriftserver import java.io.PrintStream -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.scheduler.StatsReportListener import org.apache.spark.sql.hive.HiveContext @@ -64,7 +64,7 @@ private[hive] object SparkSQLEnv extends Logging { 
hiveContext.setConf("spark.sql.hive.version", HiveContext.hiveExecutionVersion) if (log.isDebugEnabled) { - hiveContext.hiveconf.getAllProperties.toSeq.sorted.foreach { case (k, v) => + hiveContext.hiveconf.getAllProperties.asScala.toSeq.sorted.foreach { case (k, v) => logDebug(s"HiveConf var: $k=$v") } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala index 2d5ee6800228..de4e9c62b57a 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -25,21 +25,27 @@ import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hive.service.cli.SessionHandle import org.apache.hive.service.cli.session.SessionManager import org.apache.hive.service.cli.thrift.TProtocolVersion +import org.apache.hive.service.server.HiveServer2 import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager -private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) - extends SessionManager +private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, hiveContext: HiveContext) + extends SessionManager(hiveServer) with ReflectedCompositeService { - private lazy val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext) + private lazy val sparkSqlOperationManager = new SparkSQLOperationManager() override def init(hiveConf: HiveConf) { setSuperField(this, "hiveConf", hiveConf) + // Create operation log root directory, if operation logging is enabled + if (hiveConf.getBoolVar(ConfVars.HIVE_SERVER2_LOGGING_OPERATION_ENABLED)) { + invoke(classOf[SessionManager], this, "initOperationLogRootDir") + } + val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS) setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize)) getAncestorField[Log](this, 3, "LOG").info( @@ -55,15 +61,23 @@ private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) protocol: TProtocolVersion, username: String, passwd: String, + ipAddress: String, sessionConf: java.util.Map[String, String], withImpersonation: Boolean, delegationToken: String): SessionHandle = { - hiveContext.openSession() - val sessionHandle = super.openSession( - protocol, username, passwd, sessionConf, withImpersonation, delegationToken) + val sessionHandle = + super.openSession(protocol, username, passwd, ipAddress, sessionConf, withImpersonation, + delegationToken) val session = super.getSession(sessionHandle) HiveThriftServer2.listener.onSessionCreated( session.getIpAddress, sessionHandle.getSessionId.toString, session.getUsername) + val ctx = if (hiveContext.hiveThriftServerSingleSession) { + hiveContext + } else { + hiveContext.newSession() + } + ctx.setConf("spark.sql.hive.version", HiveContext.hiveExecutionVersion) + sparkSqlOperationManager.sessionToContexts += sessionHandle -> ctx sessionHandle } @@ -71,7 +85,6 @@ private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) HiveThriftServer2.listener.onSessionClosed(sessionHandle.getSessionId.toString) super.closeSession(sessionHandle) sparkSqlOperationManager.sessionToActivePool -= sessionHandle - - hiveContext.detachSession() + 
sparkSqlOperationManager.sessionToContexts.remove(sessionHandle) } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index c8031ed0f343..476651a559d2 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -30,20 +30,21 @@ import org.apache.spark.sql.hive.thriftserver.{SparkExecuteStatementOperation, R /** * Executes queries using Spark SQL, and maintains a list of handles to active queries. */ -private[thriftserver] class SparkSQLOperationManager(hiveContext: HiveContext) +private[thriftserver] class SparkSQLOperationManager() extends OperationManager with Logging { val handleToOperation = ReflectionUtils .getSuperField[JMap[OperationHandle, Operation]](this, "handleToOperation") val sessionToActivePool = Map[SessionHandle, String]() + val sessionToContexts = Map[SessionHandle, HiveContext]() override def newExecuteStatementOperation( parentSession: HiveSession, statement: String, confOverlay: JMap[String, String], async: Boolean): ExecuteStatementOperation = synchronized { - + val hiveContext = sessionToContexts(parentSession.getSessionHandle) val runInBackground = async && hiveContext.hiveThriftServerAsync val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay, runInBackground)(hiveContext, sessionToActivePool) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala index 10c83d8b27a2..e990bd06011f 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala @@ -39,14 +39,16 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" /** Render the page */ def render(request: HttpServletRequest): Seq[Node] = { val content = - generateBasicStats() ++ -
        <br/> ++
-        <h4>
-        {listener.onlineSessionNum} session(s) are online,
-        running {listener.totalRunning} SQL statement(s)
-        </h4> ++
-        generateSessionStatsTable() ++
-        generateSQLStatsTable()
+      listener.synchronized { // make sure all parts in this page are consistent
+        generateBasicStats() ++
+        <br/> ++
+        <h4>
+        {listener.getOnlineSessionNum} session(s) are online,
+        running {listener.getTotalRunning} SQL statement(s)
+        </h4>
        ++ + generateSessionStatsTable() ++ + generateSQLStatsTable() + } UIUtils.headerSparkPage("JDBC/ODBC Server", content, parent, Some(5000)) } @@ -65,11 +67,11 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" /** Generate stats of batch statements of the thrift server program */ private def generateSQLStatsTable(): Seq[Node] = { - val numStatement = listener.executionList.size + val numStatement = listener.getExecutionList.size val table = if (numStatement > 0) { val headerRow = Seq("User", "JobID", "GroupID", "Start Time", "Finish Time", "Duration", "Statement", "State", "Detail") - val dataRows = listener.executionList.values + val dataRows = listener.getExecutionList def generateDataRow(info: ExecutionInfo): Seq[Node] = { val jobLink = info.jobId.map { id: String => @@ -136,15 +138,15 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" /** Generate stats of batch sessions of the thrift server program */ private def generateSessionStatsTable(): Seq[Node] = { - val numBatches = listener.sessionList.size + val sessionList = listener.getSessionList + val numBatches = sessionList.size val table = if (numBatches > 0) { - val dataRows = - listener.sessionList.values + val dataRows = sessionList val headerRow = Seq("User", "IP", "Session ID", "Start Time", "Finish Time", "Duration", "Total Execute") def generateDataRow(session: SessionInfo): Seq[Node] = { - val sessionLink = "%s/sql/session?id=%s" - .format(UIUtils.prependBaseUri(parent.basePath), session.sessionId) + val sessionLink = "%s/%s/session?id=%s" + .format(UIUtils.prependBaseUri(parent.basePath), parent.prefix, session.sessionId) {session.userName} {session.ip} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala index 3b01afa603ce..af16cb31df18 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala @@ -40,21 +40,22 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) def render(request: HttpServletRequest): Seq[Node] = { val parameterId = request.getParameter("id") require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") - val sessionStat = listener.sessionList.find(stat => { - stat._1 == parameterId - }).getOrElse(null) - require(sessionStat != null, "Invalid sessionID[" + parameterId + "]") val content = - generateBasicStats() ++ -
        <br/> ++
-        <h4>
-        User {sessionStat._2.userName},
-        IP {sessionStat._2.ip},
-        Session created at {formatDate(sessionStat._2.startTimestamp)},
-        Total run {sessionStat._2.totalExecution} SQL
-        </h4> ++
-        generateSQLStatsTable(sessionStat._2.sessionId)
+      listener.synchronized { // make sure all parts in this page are consistent
+        val sessionStat = listener.getSession(parameterId).getOrElse(null)
+        require(sessionStat != null, "Invalid sessionID[" + parameterId + "]")
+
+        generateBasicStats() ++
+        <br/> ++
+        <h4>
+        User {sessionStat.userName},
+        IP {sessionStat.ip},
+        Session created at {formatDate(sessionStat.startTimestamp)},
+        Total run {sessionStat.totalExecution} SQL
+        </h4>
        ++ + generateSQLStatsTable(sessionStat.sessionId) + } UIUtils.headerSparkPage("JDBC/ODBC Session", content, parent, Some(5000)) } @@ -73,13 +74,13 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) /** Generate stats of batch statements of the thrift server program */ private def generateSQLStatsTable(sessionID: String): Seq[Node] = { - val executionList = listener.executionList - .filter(_._2.sessionId == sessionID) + val executionList = listener.getExecutionList + .filter(_.sessionId == sessionID) val numStatement = executionList.size val table = if (numStatement > 0) { val headerRow = Seq("User", "JobID", "GroupID", "Start Time", "Finish Time", "Duration", "Statement", "State", "Detail") - val dataRows = executionList.values.toSeq.sortBy(_.startTimestamp).reverse + val dataRows = executionList.sortBy(_.startTimestamp).reverse def generateDataRow(info: ExecutionInfo): Seq[Node] = { val jobLink = info.jobId.map { id: String => @@ -146,10 +147,11 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) /** Generate stats of batch sessions of the thrift server program */ private def generateSessionStatsTable(): Seq[Node] = { - val numBatches = listener.sessionList.size + val sessionList = listener.getSessionList + val numBatches = sessionList.size val table = if (numBatches > 0) { val dataRows = - listener.sessionList.values.toSeq.sortBy(_.startTimestamp).reverse.map ( session => + sessionList.sortBy(_.startTimestamp).reverse.map ( session => Seq( session.userName, session.ip, diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala index 94fd8a6bb60b..4eabeaa6735e 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala @@ -27,9 +27,9 @@ import org.apache.spark.{SparkContext, Logging, SparkException} * This assumes the given SparkContext has enabled its SparkUI. 
*/ private[thriftserver] class ThriftServerTab(sparkContext: SparkContext) - extends SparkUITab(getSparkUI(sparkContext), "sql") with Logging { + extends SparkUITab(getSparkUI(sparkContext), "sqlserver") with Logging { - override val name = "SQL" + override val name = "JDBC/ODBC Server" val parent = getSparkUI(sparkContext) val listener = HiveThriftServer2.listener diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index df80d04b4080..fcf039916913 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -18,86 +18,127 @@ package org.apache.spark.sql.hive.thriftserver import java.io._ +import java.sql.Timestamp +import java.util.Date import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.concurrent.{Await, Promise} -import scala.sys.process.{Process, ProcessLogger} +import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.scalatest.BeforeAndAfter +import org.scalatest.BeforeAndAfterAll -import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SparkFunSuite} /** * A test suite for the `spark-sql` CLI tool. Note that all test cases share the same temporary * Hive metastore and warehouse. */ -class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { +class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { val warehousePath = Utils.createTempDir() val metastorePath = Utils.createTempDir() + val scratchDirPath = Utils.createTempDir() - before { - warehousePath.delete() - metastorePath.delete() + override def beforeAll(): Unit = { + super.beforeAll() + warehousePath.delete() + metastorePath.delete() + scratchDirPath.delete() } - after { + override def afterAll(): Unit = { + try { warehousePath.delete() metastorePath.delete() + scratchDirPath.delete() + } finally { + super.afterAll() + } } + /** + * Run a CLI operation and expect all the queries and expected answers to be returned. + * @param timeout maximum time for the commands to complete + * @param extraArgs any extra arguments + * @param errorResponses a sequence of strings whose presence in the stdout of the forked process + * is taken as an immediate error condition. That is: if a line containing + * with one of these strings is found, fail the test immediately. + * The default value is `Seq("Error:")` + * + * @param queriesAndExpectedAnswers one or more tupes of query + answer + */ def runCliWithin( timeout: FiniteDuration, - extraArgs: Seq[String] = Seq.empty)( + extraArgs: Seq[String] = Seq.empty, + errorResponses: Seq[String] = Seq("Error:"))( queriesAndExpectedAnswers: (String, String)*): Unit = { val (queries, expectedAnswers) = queriesAndExpectedAnswers.unzip - val cliScript = "../../bin/spark-sql".split("/").mkString(File.separator) + // Explicitly adds ENTER for each statement to make sure they are actually entered into the CLI. 
+ val queriesString = queries.map(_ + "\n").mkString val command = { + val cliScript = "../../bin/spark-sql".split("/").mkString(File.separator) val jdbcUrl = s"jdbc:derby:;databaseName=$metastorePath;create=true" s"""$cliScript | --master local + | --driver-java-options -Dderby.system.durability=test + | --conf spark.ui.enabled=false | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$jdbcUrl | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath + | --hiveconf ${ConfVars.SCRATCHDIR}=$scratchDirPath """.stripMargin.split("\\s+").toSeq ++ extraArgs } var next = 0 val foundAllExpectedAnswers = Promise.apply[Unit]() - // Explicitly adds ENTER for each statement to make sure they are actually entered into the CLI. - val queryStream = new ByteArrayInputStream(queries.map(_ + "\n").mkString.getBytes) val buffer = new ArrayBuffer[String]() val lock = new Object def captureOutput(source: String)(line: String): Unit = lock.synchronized { - buffer += s"$source> $line" + // This test suite sometimes gets extremely slow out of unknown reason on Jenkins. Here we + // add a timestamp to provide more diagnosis information. + buffer += s"${new Timestamp(new Date().getTime)} - $source> $line" + // If we haven't found all expected answers and another expected answer comes up... - if (next < expectedAnswers.size && line.startsWith(expectedAnswers(next))) { + if (next < expectedAnswers.size && line.contains(expectedAnswers(next))) { next += 1 // If all expected answers have been found... if (next == expectedAnswers.size) { foundAllExpectedAnswers.trySuccess(()) } + } else { + errorResponses.foreach { r => + if (line.contains(r)) { + foundAllExpectedAnswers.tryFailure( + new RuntimeException(s"Failed with error line '$line'")) + } + } } } - // Searching expected output line from both stdout and stderr of the CLI process - val process = (Process(command, None) #< queryStream).run( - ProcessLogger(captureOutput("stdout"), captureOutput("stderr"))) + val process = new ProcessBuilder(command: _*).start() + + val stdinWriter = new OutputStreamWriter(process.getOutputStream) + stdinWriter.write(queriesString) + stdinWriter.flush() + stdinWriter.close() + + new ProcessOutputCapturer(process.getInputStream, captureOutput("stdout")).start() + new ProcessOutputCapturer(process.getErrorStream, captureOutput("stderr")).start() try { Await.result(foundAllExpectedAnswers.future, timeout) } catch { case cause: Throwable => - logError( + val message = s""" |======================= |CliSuite failure output |======================= |Spark SQL CLI command line: ${command.mkString(" ")} - | + |Exception: $cause |Executed query $next "${queries(next)}", |But failed to capture expected output "${expectedAnswers(next)}" within $timeout. 
| @@ -105,8 +146,9 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { |=========================== |End CliSuite failure output |=========================== - """.stripMargin, cause) - throw cause + """.stripMargin + logError(message, cause) + fail(message, cause) } finally { process.destroy() } @@ -124,7 +166,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE hive_test;" -> "OK", "CACHE TABLE hive_test;" - -> "Time taken: ", + -> "", "SELECT COUNT(*) FROM hive_test;" -> "5", "DROP TABLE hive_test;" @@ -145,7 +187,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { "CREATE TABLE hive_test(key INT, val STRING);" -> "OK", "SHOW TABLES;" - -> "Time taken: " + -> "hive_test" ) runCliWithin(2.minute, Seq("--database", "hive_test_db", "-e", "SHOW TABLES;"))( @@ -175,7 +217,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE sourceTable;" -> "OK", "INSERT INTO TABLE t1 SELECT key, val FROM sourceTable;" - -> "Time taken:", + -> "", "SELECT count(key) FROM t1;" -> "5", "DROP TABLE t1;" @@ -184,4 +226,12 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { -> "OK" ) } + + test("SPARK-11188 Analysis error reporting") { + runCliWithin(timeout = 2.minute, + errorResponses = Seq("AnalysisException"))( + "select * from nonexistent_table;" + -> "Error in query: Table not found: nonexistent_table;" + ) + } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 39b31523e07c..139d8e897ba1 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -19,14 +19,14 @@ package org.apache.spark.sql.hive.thriftserver import java.io.File import java.net.URL -import java.nio.charset.StandardCharsets import java.sql.{Date, DriverManager, SQLException, Statement} +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ import scala.concurrent.{Await, Promise, future} -import scala.concurrent.ExecutionContext.Implicits.global -import scala.sys.process.{Process, ProcessLogger} +import scala.io.Source import scala.util.{Random, Try} import com.google.common.base.Charsets.UTF_8 @@ -41,9 +41,10 @@ import org.apache.thrift.protocol.TBinaryProtocol import org.apache.thrift.transport.TSocket import org.scalatest.BeforeAndAfterAll -import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SparkFunSuite} object TestData { def getTestDataFilePath(name: String): URL = { @@ -205,6 +206,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { import org.apache.spark.sql.SQLConf var defaultV1: String = null var defaultV2: String = null + var data: ArrayBuffer[Int] = null withMultipleConnectionJdbcStatement( // create table @@ -214,10 +216,16 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { "DROP TABLE IF EXISTS test_map", "CREATE 
TABLE test_map(key INT, value STRING)", s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_map", - "CACHE TABLE test_table AS SELECT key FROM test_map ORDER BY key DESC") + "CACHE TABLE test_table AS SELECT key FROM test_map ORDER BY key DESC", + "CREATE DATABASE db1") queries.foreach(statement.execute) + val plan = statement.executeQuery("explain select * from test_table") + plan.next() + plan.next() + assert(plan.getString(1).contains("InMemoryColumnarTableScan")) + val rs1 = statement.executeQuery("SELECT key FROM test_table ORDER BY KEY DESC") val buf1 = new collection.mutable.ArrayBuffer[Int]() while (rs1.next()) { @@ -233,6 +241,8 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { rs2.close() assert(buf1 === buf2) + + data = buf1 }, // first session, we get the default value of the session status @@ -289,56 +299,51 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { rs2.close() }, - // accessing the cached data in another session + // try to access the cached data in another session { statement => - val rs1 = statement.executeQuery("SELECT key FROM test_table ORDER BY KEY DESC") - val buf1 = new collection.mutable.ArrayBuffer[Int]() - while (rs1.next()) { - buf1 += rs1.getInt(1) + // Cached temporary table can't be accessed by other sessions + intercept[SQLException] { + statement.executeQuery("SELECT key FROM test_table ORDER BY KEY DESC") } - rs1.close() - val rs2 = statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC") - val buf2 = new collection.mutable.ArrayBuffer[Int]() - while (rs2.next()) { - buf2 += rs2.getInt(1) + val plan = statement.executeQuery("explain select key from test_map ORDER BY key DESC") + plan.next() + plan.next() + assert(plan.getString(1).contains("InMemoryColumnarTableScan")) + + val rs = statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC") + val buf = new collection.mutable.ArrayBuffer[Int]() + while (rs.next()) { + buf += rs.getInt(1) } - rs2.close() + rs.close() + assert(buf === data) + }, - assert(buf1 === buf2) - statement.executeQuery("UNCACHE TABLE test_table") + // switch another database + { statement => + statement.execute("USE db1") - // TODO need to figure out how to determine if the data loaded from cache - val rs3 = statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC") - val buf3 = new collection.mutable.ArrayBuffer[Int]() - while (rs3.next()) { - buf3 += rs3.getInt(1) + // there is no test_map table in db1 + intercept[SQLException] { + statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC") } - rs3.close() - assert(buf1 === buf3) + statement.execute("CREATE TABLE test_map2(key INT, value STRING)") }, - // accessing the uncached table + // access default database { statement => - // TODO need to figure out how to determine if the data loaded from cache - val rs1 = statement.executeQuery("SELECT key FROM test_table ORDER BY KEY DESC") - val buf1 = new collection.mutable.ArrayBuffer[Int]() - while (rs1.next()) { - buf1 += rs1.getInt(1) - } - rs1.close() - - val rs2 = statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC") - val buf2 = new collection.mutable.ArrayBuffer[Int]() - while (rs2.next()) { - buf2 += rs2.getInt(1) + // current database should still be `default` + intercept[SQLException] { + statement.executeQuery("SELECT key FROM test_map2") } - rs2.close() - assert(buf1 === buf2) + statement.execute("USE db1") + // access test_map2 + statement.executeQuery("SELECT key from test_map2") } ) } @@ -378,6 +383,184 @@ 
class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { rs2.close() } } + + test("test add jar") { + withMultipleConnectionJdbcStatement( + { + statement => + val jarFile = + "../hive/src/test/resources/hive-hcatalog-core-0.13.1.jar" + .split("/") + .mkString(File.separator) + + statement.executeQuery(s"ADD JAR $jarFile") + }, + + { + statement => + val queries = Seq( + "DROP TABLE IF EXISTS smallKV", + "CREATE TABLE smallKV(key INT, val STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE smallKV", + "DROP TABLE IF EXISTS addJar", + """CREATE TABLE addJar(key string) + |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe' + """.stripMargin) + + queries.foreach(statement.execute) + + statement.executeQuery( + """ + |INSERT INTO TABLE addJar SELECT 'k1' as key FROM smallKV limit 1 + """.stripMargin) + + val actualResult = + statement.executeQuery("SELECT key FROM addJar") + val actualResultBuffer = new collection.mutable.ArrayBuffer[String]() + while (actualResult.next()) { + actualResultBuffer += actualResult.getString(1) + } + actualResult.close() + + val expectedResult = + statement.executeQuery("SELECT 'k1'") + val expectedResultBuffer = new collection.mutable.ArrayBuffer[String]() + while (expectedResult.next()) { + expectedResultBuffer += expectedResult.getString(1) + } + expectedResult.close() + + assert(expectedResultBuffer === actualResultBuffer) + + statement.executeQuery("DROP TABLE IF EXISTS addJar") + statement.executeQuery("DROP TABLE IF EXISTS smallKV") + } + ) + } + + test("Checks Hive version via SET -v") { + withJdbcStatement { statement => + val resultSet = statement.executeQuery("SET -v") + + val conf = mutable.Map.empty[String, String] + while (resultSet.next()) { + conf += resultSet.getString(1) -> resultSet.getString(2) + } + + assert(conf.get("spark.sql.hive.version") === Some("1.2.1")) + } + } + + test("Checks Hive version via SET") { + withJdbcStatement { statement => + val resultSet = statement.executeQuery("SET") + + val conf = mutable.Map.empty[String, String] + while (resultSet.next()) { + conf += resultSet.getString(1) -> resultSet.getString(2) + } + + assert(conf.get("spark.sql.hive.version") === Some("1.2.1")) + } + } + + test("SPARK-11595 ADD JAR with input path having URL scheme") { + withJdbcStatement { statement => + val jarPath = "../hive/src/test/resources/TestUDTF.jar" + val jarURL = s"file://${System.getProperty("user.dir")}/$jarPath" + + Seq( + s"ADD JAR $jarURL", + s"""CREATE TEMPORARY FUNCTION udtf_count2 + |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' + """.stripMargin + ).foreach(statement.execute) + + val rs1 = statement.executeQuery("DESCRIBE FUNCTION udtf_count2") + + assert(rs1.next()) + assert(rs1.getString(1) === "Function: udtf_count2") + + assert(rs1.next()) + assertResult("Class: org.apache.spark.sql.hive.execution.GenericUDTFCount2") { + rs1.getString(1) + } + + assert(rs1.next()) + assert(rs1.getString(1) === "Usage: To be added.") + + val dataPath = "../hive/src/test/resources/data/files/kv1.txt" + + Seq( + s"CREATE TABLE test_udtf(key INT, value STRING)", + s"LOAD DATA LOCAL INPATH '$dataPath' OVERWRITE INTO TABLE test_udtf" + ).foreach(statement.execute) + + val rs2 = statement.executeQuery( + "SELECT key, cc FROM test_udtf LATERAL VIEW udtf_count2(value) dd AS cc") + + assert(rs2.next()) + assert(rs2.getInt(1) === 97) + assert(rs2.getInt(2) === 500) + + assert(rs2.next()) + assert(rs2.getInt(1) === 97) + assert(rs2.getInt(2) === 500) + } + } + + test("SPARK-11043 check 
operation log root directory") { + val expectedLine = + "Operation log root directory is created: " + operationLogPath.getAbsoluteFile + assert(Source.fromFile(logPath).getLines().exists(_.contains(expectedLine))) + } +} + +class SingleSessionSuite extends HiveThriftJdbcTest { + override def mode: ServerMode.Value = ServerMode.binary + + override protected def extraConf: Seq[String] = + "--conf spark.sql.hive.thriftServer.singleSession=true" :: Nil + + test("test single session") { + withMultipleConnectionJdbcStatement( + { statement => + val jarPath = "../hive/src/test/resources/TestUDTF.jar" + val jarURL = s"file://${System.getProperty("user.dir")}/$jarPath" + + // Configurations and temporary functions added in this session should be visible to all + // the other sessions. + Seq( + "SET foo=bar", + s"ADD JAR $jarURL", + s"""CREATE TEMPORARY FUNCTION udtf_count2 + |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' + """.stripMargin + ).foreach(statement.execute) + }, + + { statement => + val rs1 = statement.executeQuery("SET foo") + + assert(rs1.next()) + assert(rs1.getString(1) === "foo") + assert(rs1.getString(2) === "bar") + + val rs2 = statement.executeQuery("DESCRIBE FUNCTION udtf_count2") + + assert(rs2.next()) + assert(rs2.getString(1) === "Function: udtf_count2") + + assert(rs2.next()) + assertResult("Class: org.apache.spark.sql.hive.execution.GenericUDTFCount2") { + rs2.getString(1) + } + + assert(rs2.next()) + assert(rs2.getString(1) === "Usage: To be added.") + } + ) + } } class HiveThriftHttpServerSuite extends HiveThriftJdbcTest { @@ -466,10 +649,13 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl protected def metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true" private val pidDir: File = Utils.createTempDir("thriftserver-pid") - private var logPath: File = _ + protected var logPath: File = _ + protected var operationLogPath: File = _ private var logTailingProcess: Process = _ private var diagnosisBuffer: ArrayBuffer[String] = ArrayBuffer.empty[String] + protected def extraConf: Seq[String] = Nil + protected def serverStartCommand(port: Int) = { val portConf = if (mode == ServerMode.binary) { ConfVars.HIVE_SERVER2_THRIFT_PORT @@ -483,7 +669,7 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl val tempLog4jConf = Utils.createTempDir().getCanonicalPath Files.write( - """log4j.rootCategory=INFO, console + """log4j.rootCategory=DEBUG, console |log4j.appender.console=org.apache.log4j.ConsoleAppender |log4j.appender.console.target=System.err |log4j.appender.console.layout=org.apache.log4j.PatternLayout @@ -492,7 +678,7 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl new File(s"$tempLog4jConf/log4j.properties"), UTF_8) - tempLog4jConf + File.pathSeparator + sys.props("java.class.path") + tempLog4jConf } s"""$startScript @@ -501,18 +687,36 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost | --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=$mode + | --hiveconf ${ConfVars.HIVE_SERVER2_LOGGING_OPERATION_LOG_LOCATION}=$operationLogPath | --hiveconf $portConf=$port | --driver-class-path $driverClassPath | --driver-java-options -Dlog4j.debug | --conf spark.ui.enabled=false + | ${extraConf.mkString("\n")} """.stripMargin.split("\\s+").toSeq } + /** + * String to scan for when looking for the the 
thrift binary endpoint running. + * This can change across Hive versions. + */ + val THRIFT_BINARY_SERVICE_LIVE = "Starting ThriftBinaryCLIService on port" + + /** + * String to scan for when looking for the the thrift HTTP endpoint running. + * This can change across Hive versions. + */ + val THRIFT_HTTP_SERVICE_LIVE = "Started ThriftHttpCLIService in http" + + val SERVER_STARTUP_TIMEOUT = 3.minutes + private def startThriftServer(port: Int, attempt: Int) = { warehousePath = Utils.createTempDir() warehousePath.delete() metastorePath = Utils.createTempDir() metastorePath.delete() + operationLogPath = Utils.createTempDir() + operationLogPath.delete() logPath = null logTailingProcess = null @@ -528,45 +732,59 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl logInfo(s"Trying to start HiveThriftServer2: port=$port, mode=$mode, attempt=$attempt") - val env = Seq( - // Disables SPARK_TESTING to exclude log4j.properties in test directories. - "SPARK_TESTING" -> "0", - // Points SPARK_PID_DIR to SPARK_HOME, otherwise only 1 Thrift server instance can be started - // at a time, which is not Jenkins friendly. - "SPARK_PID_DIR" -> pidDir.getCanonicalPath) - - logPath = Process(command, None, env: _*).lines.collectFirst { - case line if line.contains(LOG_FILE_MARK) => new File(line.drop(LOG_FILE_MARK.length)) - }.getOrElse { - throw new RuntimeException("Failed to find HiveThriftServer2 log file.") + logPath = { + val lines = Utils.executeAndGetOutput( + command = command, + extraEnvironment = Map( + // Disables SPARK_TESTING to exclude log4j.properties in test directories. + "SPARK_TESTING" -> "0", + // Points SPARK_PID_DIR to SPARK_HOME, otherwise only 1 Thrift server instance can be + // started at a time, which is not Jenkins friendly. + "SPARK_PID_DIR" -> pidDir.getCanonicalPath), + redirectStderr = true) + + lines.split("\n").collectFirst { + case line if line.contains(LOG_FILE_MARK) => new File(line.drop(LOG_FILE_MARK.length)) + }.getOrElse { + throw new RuntimeException("Failed to find HiveThriftServer2 log file.") + } } val serverStarted = Promise[Unit]() // Ensures that the following "tail" command won't fail. logPath.createNewFile() - logTailingProcess = + val successLines = Seq(THRIFT_BINARY_SERVICE_LIVE, THRIFT_HTTP_SERVICE_LIVE) + + logTailingProcess = { + val command = s"/usr/bin/env tail -n +0 -f ${logPath.getCanonicalPath}".split(" ") // Using "-n +0" to make sure all lines in the log file are checked. - Process(s"/usr/bin/env tail -n +0 -f ${logPath.getCanonicalPath}").run(ProcessLogger( - (line: String) => { - diagnosisBuffer += line + val builder = new ProcessBuilder(command: _*) + val captureOutput = (line: String) => diagnosisBuffer.synchronized { + diagnosisBuffer += line - if (line.contains("ThriftBinaryCLIService listening on") || - line.contains("Started ThriftHttpCLIService in http")) { + successLines.foreach { r => + if (line.contains(r)) { serverStarted.trySuccess(()) - } else if (line.contains("HiveServer2 is stopped")) { - // This log line appears when the server fails to start and terminates gracefully (e.g. - // because of port contention). 
- serverStarted.tryFailure(new RuntimeException("Failed to start HiveThriftServer2")) } - })) + } + } + + val process = builder.start() - Await.result(serverStarted.future, 2.minute) + new ProcessOutputCapturer(process.getInputStream, captureOutput).start() + new ProcessOutputCapturer(process.getErrorStream, captureOutput).start() + process + } + + Await.result(serverStarted.future, SERVER_STARTUP_TIMEOUT) } private def stopThriftServer(): Unit = { // The `spark-daemon.sh' script uses kill, which is not synchronous, have to wait for a while. - Process(stopScript, None, "SPARK_PID_DIR" -> pidDir.getCanonicalPath).run().exitValue() + Utils.executeAndGetOutput( + command = Seq(stopScript), + extraEnvironment = Map("SPARK_PID_DIR" -> pidDir.getCanonicalPath)) Thread.sleep(3.seconds.toMillis) warehousePath.delete() @@ -575,6 +793,9 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl metastorePath.delete() metastorePath = null + operationLogPath.delete() + operationLogPath = null + Option(logPath).foreach(_.delete()) logPath = null diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala index 806240e6de45..bf431cd6b026 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala @@ -27,7 +27,6 @@ import org.scalatest.concurrent.Eventually._ import org.scalatest.selenium.WebBrowser import org.scalatest.time.SpanSugar._ -import org.apache.spark.sql.hive.HiveContext import org.apache.spark.ui.SparkUICssErrorHandler class UISeleniumSuite @@ -36,7 +35,6 @@ class UISeleniumSuite implicit var webDriver: WebDriver = _ var server: HiveThriftServer2 = _ - var hc: HiveContext = _ val uiPort = 20000 + Random.nextInt(10000) override def mode: ServerMode.Value = ServerMode.binary diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HashJoinCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HashJoinCompatibilitySuite.scala deleted file mode 100644 index 1a5ba20404c4..000000000000 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HashJoinCompatibilitySuite.scala +++ /dev/null @@ -1,169 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.execution - -import java.io.File - -import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.hive.test.TestHive - -/** - * Runs the test cases that are included in the hive distribution with hash joins. 
- */ -class HashJoinCompatibilitySuite extends HiveCompatibilitySuite { - override def beforeAll() { - super.beforeAll() - TestHive.setConf(SQLConf.SORTMERGE_JOIN, false) - } - - override def afterAll() { - TestHive.setConf(SQLConf.SORTMERGE_JOIN, true) - super.afterAll() - } - - override def whiteList = Seq( - "auto_join0", - "auto_join1", - "auto_join10", - "auto_join11", - "auto_join12", - "auto_join13", - "auto_join14", - "auto_join14_hadoop20", - "auto_join15", - "auto_join17", - "auto_join18", - "auto_join19", - "auto_join2", - "auto_join20", - "auto_join21", - "auto_join22", - "auto_join23", - "auto_join24", - "auto_join25", - "auto_join26", - "auto_join27", - "auto_join28", - "auto_join3", - "auto_join30", - "auto_join31", - "auto_join32", - "auto_join4", - "auto_join5", - "auto_join6", - "auto_join7", - "auto_join8", - "auto_join9", - "auto_join_filters", - "auto_join_nulls", - "auto_join_reordering_values", - "auto_smb_mapjoin_14", - "auto_sortmerge_join_1", - "auto_sortmerge_join_10", - "auto_sortmerge_join_11", - "auto_sortmerge_join_12", - "auto_sortmerge_join_13", - "auto_sortmerge_join_14", - "auto_sortmerge_join_15", - "auto_sortmerge_join_16", - "auto_sortmerge_join_2", - "auto_sortmerge_join_3", - "auto_sortmerge_join_4", - "auto_sortmerge_join_5", - "auto_sortmerge_join_6", - "auto_sortmerge_join_7", - "auto_sortmerge_join_8", - "auto_sortmerge_join_9", - "correlationoptimizer1", - "correlationoptimizer10", - "correlationoptimizer11", - "correlationoptimizer13", - "correlationoptimizer14", - "correlationoptimizer15", - "correlationoptimizer2", - "correlationoptimizer3", - "correlationoptimizer4", - "correlationoptimizer6", - "correlationoptimizer7", - "correlationoptimizer8", - "correlationoptimizer9", - "join0", - "join1", - "join10", - "join11", - "join12", - "join13", - "join14", - "join14_hadoop20", - "join15", - "join16", - "join17", - "join18", - "join19", - "join2", - "join20", - "join21", - "join22", - "join23", - "join24", - "join25", - "join26", - "join27", - "join28", - "join29", - "join3", - "join30", - "join31", - "join32", - "join32_lessSize", - "join33", - "join34", - "join35", - "join36", - "join37", - "join38", - "join39", - "join4", - "join40", - "join41", - "join5", - "join6", - "join7", - "join8", - "join9", - "join_1to1", - "join_array", - "join_casesensitive", - "join_empty", - "join_filters", - "join_hive_626", - "join_map_ppr", - "join_nulls", - "join_nullsafe", - "join_rc", - "join_reorder2", - "join_reorder3", - "join_reorder4", - "join_star" - ) - - // Only run those query tests in the realWhileList (do not try other ignored query files). 
- override def testCases: Seq[(String, File)] = super.testCases.filter { - case (name, _) => realWhiteList.contains(name) - } -} diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 53d5b22b527b..2d0d7b8af358 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -20,14 +20,17 @@ package org.apache.spark.sql.hive.execution import java.io.File import java.util.{Locale, TimeZone} +import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.scalatest.BeforeAndAfter import org.apache.spark.sql.SQLConf import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.tags.ExtendedHiveTest /** * Runs the test cases that are included in the hive distribution. */ +@ExtendedHiveTest class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // TODO: bundle in jar files... get from classpath private lazy val hiveQueryDir = TestHive.getHiveFile( @@ -50,6 +53,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, 5) // Enable in-memory partition pruning for testing purposes TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) + RuleExecutor.resetTime() } override def afterAll() { @@ -58,6 +62,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { Locale.setDefault(originalLocale) TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) + + // For debugging dump some statistics about how much time was spent in various optimizer rules. + logWarning(RuleExecutor.dumpTimeSpent()) } /** A list of tests deemed out of scope currently and thus completely disregarded. */ @@ -266,8 +273,42 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // Hive returns string from UTC formatted timestamp, spark returns timestamp type "date_udf", + // Can't compare the result that have newline in it + "udf_get_json_object", + // Unlike Hive, we do support log base in (0, 1.0], therefore disable this - "udf7" + "udf7", + + // Trivial changes to DDL output + "compute_stats_empty_table", + "compute_stats_long", + "create_view_translate", + "show_create_table_serde", + "show_tblproperties", + + // Odd changes to output + "merge4", + + // Thift is broken... 
+ "inputddl8", + + // Hive changed ordering of ddl: + "varchar_union1", + + // Parser changes in Hive 1.2 + "input25", + "input26", + + // Uses invalid table name + "innerjoin", + + // classpath problems + "compute_stats.*", + "udf_bitmap_.*", + + // The difference between the double numbers generated by Hive and Spark + // can be ignored (e.g., 0.6633880657639323 and 0.6633880657639322) + "udaf_corr" ) /** @@ -430,7 +471,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "escape_orderby1", "escape_sortby1", "explain_rearrange", - "fetch_aggregation", "fileformat_mix", "fileformat_sequencefile", "fileformat_text", @@ -621,6 +661,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "join_star", "lateral_view", "lateral_view_cp", + "lateral_view_noalias", "lateral_view_ppd", "leftsemijoin", "leftsemijoin_mr", @@ -647,6 +688,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "load_file_with_space_in_the_name", "loadpart1", "louter_join_ppr", + "macro", "mapjoin_distinct", "mapjoin_filter_on_outerjoin", "mapjoin_mapjoin", @@ -820,7 +862,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "type_cast_1", "type_widening", "udaf_collect_set", - "udaf_corr", "udaf_covar_pop", "udaf_covar_samp", "udaf_histogram_numeric", diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala index 92bb9e6d73af..98bbdf0653c2 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala @@ -454,6 +454,9 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte |window w1 as (distribute by p_mfgr sort by p_name rows between 2 preceding and 2 following) """.stripMargin, reset = false) + /* Disabled because: + - Spark uses a different default stddev. + - Tiny numerical differences in stddev results. createQueryTest("windowing.q -- 15. testExpressions", s""" |select p_mfgr,p_name, p_size, @@ -472,7 +475,7 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte |window w1 as (distribute by p_mfgr sort by p_mfgr, p_name | rows between 2 preceding and 2 following) """.stripMargin, reset = false) - + */ createQueryTest("windowing.q -- 16. testMultipleWindows", s""" |select p_mfgr,p_name, p_size, @@ -530,6 +533,9 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte // when running this test suite under Java 7 and 8. // We change the original sql query a little bit for making the test suite passed // under different JDK + /* Disabled because: + - Spark uses a different default stddev. + - Tiny numerical differences in stddev results. createQueryTest("windowing.q -- 20. testSTATs", """ |select p_mfgr,p_name, p_size, sdev, sdev_pop, uniq_data, var, cor, covarp @@ -547,7 +553,7 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte |) t lateral view explode(uniq_size) d as uniq_data |order by p_mfgr,p_name, p_size, sdev, sdev_pop, uniq_data, var, cor, covarp """.stripMargin, reset = false) - + */ createQueryTest("windowing.q -- 21. 
testDISTs", """ |select p_mfgr,p_name, p_size, diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index b00f320318be..d96f3e2b9f62 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml @@ -36,6 +36,11 @@ + + + com.twitter + parquet-hadoop-bundle + org.apache.spark spark-core_${scala.binary.version} @@ -54,31 +59,45 @@ ${project.version} - ${hive.group} - hive-metastore + org.apache.spark + spark-test-tags_${scala.binary.version} + - org.codehaus.jackson - jackson-mapper-asl + ${hive.group} + hive-exec + ${hive.group} - hive-serde + hive-metastore + org.apache.avro @@ -92,13 +111,57 @@ ${avro.mapred.classifier} - org.scalacheck - scalacheck_${scala.binary.version} - test + commons-httpclient + commons-httpclient + + + org.apache.calcite + calcite-avatica - junit - junit + org.apache.calcite + calcite-core + + + org.apache.httpcomponents + httpclient + + + org.codehaus.jackson + jackson-mapper-asl + + + + commons-codec + commons-codec + + + joda-time + joda-time + + + org.jodd + jodd-core + + + com.google.code.findbugs + jsr305 + + + org.datanucleus + datanucleus-core + + + org.apache.thrift + libthrift + + + org.apache.thrift + libfb303 + + + org.scalacheck + scalacheck_${scala.binary.version} test diff --git a/sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 000000000000..4a774fbf1fdf --- /dev/null +++ b/sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1 @@ +org.apache.spark.sql.hive.orc.DefaultSource diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 110f51a30586..0eeb62ca2cb3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -20,8 +20,10 @@ package org.apache.spark.sql.hive import java.io.File import java.net.{URL, URLClassLoader} import java.sql.Timestamp +import java.util.concurrent.TimeUnit +import java.util.regex.Pattern -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.HashMap import scala.language.implicitConversions @@ -29,34 +31,53 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.common.`type`.HiveDecimal import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.metadata.Table import org.apache.hadoop.hive.ql.parse.VariableSubstitution -import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable} +import org.apache.hadoop.util.VersionInfo -import org.apache.spark.Logging -import org.apache.spark.SparkContext -import org.apache.spark.annotation.Experimental -import org.apache.spark.sql._ +import org.apache.spark.api.java.JavaSparkContext import org.apache.spark.sql.SQLConf.SQLConfEntry import org.apache.spark.sql.SQLConf.SQLConfEntry._ -import org.apache.spark.sql.catalyst.{TableIdentifier, ParserDialect} +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import 
org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUDFs, SetCommand} -import org.apache.spark.sql.execution.datasources.{PreWriteCheck, PreInsertCastAndRename, DataSourceStrategy} +import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, SqlParser} +import org.apache.spark.sql.execution.datasources.{ResolveDataSource, DataSourceStrategy, PreInsertCastAndRename, PreWriteCheck} +import org.apache.spark.sql.execution.ui.SQLListener +import org.apache.spark.sql.execution.{CacheManager, ExecutedCommand, ExtractPythonUDFs, SetCommand} import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SparkContext} /** * This is the HiveQL Dialect, this dialect is strongly bind with HiveContext */ -private[hive] class HiveQLDialect extends ParserDialect { +private[hive] class HiveQLDialect(sqlContext: HiveContext) extends ParserDialect { override def parse(sqlText: String): LogicalPlan = { - HiveQl.parseSql(sqlText) + sqlContext.executionHive.withHiveState { + HiveQl.parseSql(sqlText) + } + } +} + +/** + * Returns the current database of metadataHive. + */ +private[hive] case class CurrentDatabase(ctx: HiveContext) + extends LeafExpression with CodegenFallback { + override def dataType: DataType = StringType + override def foldable: Boolean = true + override def nullable: Boolean = false + override def eval(input: InternalRow): Any = { + UTF8String.fromString(ctx.metadataHive.currentDatabase) } } @@ -66,13 +87,40 @@ private[hive] class HiveQLDialect extends ParserDialect { * * @since 1.0.0 */ -class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { +class HiveContext private[hive]( + sc: SparkContext, + cacheManager: CacheManager, + listener: SQLListener, + @transient private val execHive: ClientWrapper, + @transient private val metaHive: ClientInterface, + isRootContext: Boolean) + extends SQLContext(sc, cacheManager, listener, isRootContext) with Logging { self => - import HiveContext._ + def this(sc: SparkContext) = { + this(sc, new CacheManager, SQLContext.createListenerAndUI(sc), null, null, true) + } + def this(sc: JavaSparkContext) = this(sc.sc) + + import org.apache.spark.sql.hive.HiveContext._ logDebug("create HiveContext") + /** + * Returns a new HiveContext as new session, which will have separated SQLConf, UDF/UDAF, + * temporary tables and SessionState, but sharing the same CacheManager, IsolatedClientLoader + * and Hive client (both of execution and metadata) with existing HiveContext. + */ + override def newSession(): HiveContext = { + new HiveContext( + sc = sc, + cacheManager = cacheManager, + listener = listener, + execHive = executionHive.newSession(), + metaHive = metadataHive.newSession(), + isRootContext = false) + } + /** * When true, enables an experimental feature where metastore tables that use the parquet SerDe * are automatically converted to use the Spark SQL parquet table scan, instead of the Hive @@ -108,8 +156,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { * this does not necessarily need to be the same version of Hive that is used internally by * Spark SQL for execution. 
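// Hedged usage sketch of the newSession() API introduced above: per its scaladoc, the child
// context gets its own SQLConf, UDF/UDAF registry, temporary tables and SessionState, while
// sharing the CacheManager and the execution/metadata Hive clients with the root HiveContext.
// `sc` is an assumed SparkContext and the config key is only an example.
import org.apache.spark.SparkContext
import org.apache.spark.sql.hive.HiveContext

def isolatedSession(sc: SparkContext): HiveContext = {
  val root = new HiveContext(sc)
  val session = root.newSession()
  // Settings made here land in the new session's SQLConf, not in the root context's.
  session.setConf("spark.sql.shuffle.partitions", "4")
  session
}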
*/ - protected[hive] def hiveMetastoreVersion: String = - getConf(HIVE_METASTORE_VERSION, hiveExecutionVersion) + protected[hive] def hiveMetastoreVersion: String = getConf(HIVE_METASTORE_VERSION) /** * The location of the jars that should be used to instantiate the HiveMetastoreClient. This @@ -144,6 +191,9 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { */ protected[hive] def hiveThriftServerAsync: Boolean = getConf(HIVE_THRIFT_SERVER_ASYNC) + protected[hive] def hiveThriftServerSingleSession: Boolean = + sc.conf.get("spark.sql.hive.thriftServer.singleSession", "false").toBoolean + @transient protected[sql] lazy val substitutor = new VariableSubstitution() @@ -155,14 +205,28 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { * for storing persistent metadata, and only point to a dummy metastore in a temporary directory. */ @transient - protected[hive] lazy val executionHive: ClientWrapper = { + protected[hive] lazy val executionHive: ClientWrapper = if (execHive != null) { + execHive + } else { logInfo(s"Initializing execution hive, version $hiveExecutionVersion") - new ClientWrapper( + val loader = new IsolatedClientLoader( version = IsolatedClientLoader.hiveVersion(hiveExecutionVersion), - config = newTemporaryConfiguration(), - initClassLoader = Utils.getContextOrSparkClassLoader) + execJars = Seq(), + config = newTemporaryConfiguration(useInMemoryDerby = true), + isolationOn = false, + baseClassLoader = Utils.getContextOrSparkClassLoader) + loader.createClient().asInstanceOf[ClientWrapper] } - SessionState.setCurrentSessionState(executionHive.state) + + /** + * Overrides default Hive configurations to avoid breaking changes to Spark SQL users. + * - allow SQL11 keywords to be used as identifiers + */ + private[sql] def defaultOverrides() = { + setConf(ConfVars.HIVE_SUPPORT_SQL11_RESERVED_KEYWORDS.varname, "false") + } + + defaultOverrides() /** * The copy of the Hive client that is used to retrieve metadata from the Hive MetaStore. @@ -170,14 +234,20 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { * in the hive-site.xml file. */ @transient - protected[hive] lazy val metadataHive: ClientInterface = { + protected[hive] lazy val metadataHive: ClientInterface = if (metaHive != null) { + metaHive + } else { val metaVersion = IsolatedClientLoader.hiveVersion(hiveMetastoreVersion) // We instantiate a HiveConf here to read in the hive-site.xml file and then pass the options // into the isolated client loader val metadataConf = new HiveConf() + + val defaultWarehouseLocation = metadataConf.get("hive.metastore.warehouse.dir") + logInfo("default warehouse location is " + defaultWarehouseLocation) + // `configure` goes second to override other settings. - val allConfig = metadataConf.iterator.map(e => e.getKey -> e.getValue).toMap ++ configure + val allConfig = metadataConf.asScala.map(e => e.getKey -> e.getValue).toMap ++ configure val isolatedLoader = if (hiveMetastoreJars == "builtin") { if (hiveExecutionVersion != hiveMetastoreVersion) { @@ -185,7 +255,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { "Builtin jars can only be used when hive execution version == hive metastore version. " + s"Execution: ${hiveExecutionVersion} != Metastore: ${hiveMetastoreVersion}. 
" + "Specify a vaild path to the correct hive jars using $HIVE_METASTORE_JARS " + - s"or change $HIVE_METASTORE_VERSION to $hiveExecutionVersion.") + s"or change ${HIVE_METASTORE_VERSION.key} to $hiveExecutionVersion.") } // We recursively find all jars in the class loader chain, @@ -218,7 +288,12 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { // TODO: Support for loading the jars from an already downloaded location. logInfo( s"Initializing HiveMetastoreConnection version $hiveMetastoreVersion using maven.") - IsolatedClientLoader.forVersion(hiveMetastoreVersion, allConfig) + IsolatedClientLoader.forVersion( + hiveMetastoreVersion = hiveMetastoreVersion, + hadoopVersion = VersionInfo.getVersion, + config = allConfig, + barrierPrefixes = hiveMetastoreBarrierPrefixes, + sharedPrefixes = hiveMetastoreSharedPrefixes) } else { // Convert to files and expand any directories. val jars = @@ -239,7 +314,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { .map(_.toURI.toURL) logInfo( - s"Initializing HiveMetastoreConnection version $hiveMetastoreVersion using $jars") + s"Initializing HiveMetastoreConnection version $hiveMetastoreVersion " + + s"using ${jars.mkString(":")}") new IsolatedClientLoader( version = metaVersion, execJars = jars.toSeq, @@ -248,7 +324,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { barrierPrefixes = hiveMetastoreBarrierPrefixes, sharedPrefixes = hiveMetastoreSharedPrefixes) } - isolatedLoader.client + isolatedLoader.createClient() } protected[sql] override def parseSql(sql: String): LogicalPlan = { @@ -267,12 +343,13 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { * @since 1.3.0 */ def refreshTable(tableName: String): Unit = { - val tableIdent = TableIdentifier(tableName).withDatabase(catalog.client.currentDatabase) + val tableIdent = SqlParser.parseTableIdentifier(tableName) catalog.refreshTable(tableIdent) } protected[hive] def invalidateTable(tableName: String): Unit = { - catalog.invalidateTable(catalog.client.currentDatabase, tableName) + val tableIdent = SqlParser.parseTableIdentifier(tableName) + catalog.invalidateTable(tableIdent) } /** @@ -284,9 +361,9 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { * * @since 1.2.0 */ - @Experimental def analyze(tableName: String) { - val relation = EliminateSubQueries(catalog.lookupRelation(Seq(tableName))) + val tableIdent = SqlParser.parseTableIdentifier(tableName) + val relation = EliminateSubQueries(catalog.lookupRelation(tableIdent)) relation match { case relation: MetastoreRelation => @@ -298,10 +375,21 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { // Can we use fs.getContentSummary in future? // Seems fs.getContentSummary returns wrong table size on Jenkins. So we use // countFileSize to count the table size. 
+ val stagingDir = metadataHive.getConf(HiveConf.ConfVars.STAGINGDIR.varname, + HiveConf.ConfVars.STAGINGDIR.defaultStrVal) + def calculateTableSize(fs: FileSystem, path: Path): Long = { val fileStatus = fs.getFileStatus(path) val size = if (fileStatus.isDir) { - fs.listStatus(path).map(status => calculateTableSize(fs, status.getPath)).sum + fs.listStatus(path) + .map { status => + if (!status.getPath().getName().startsWith(stagingDir)) { + calculateTableSize(fs, status.getPath) + } else { + 0L + } + } + .sum } else { fileStatus.getLen } @@ -347,8 +435,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { } } - protected[hive] def hiveconf = tlSession.get().asInstanceOf[this.SQLSession].hiveconf - override def setConf(key: String, value: String): Unit = { super.setConf(key, value) executionHive.runSqlHive(s"SET $key=$value") @@ -365,7 +451,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { setConf(entry.key, entry.stringConverter(value)) } - /* A catalyst metadata catalog that points to the Hive Metastore. */ + /* A catalyst metadata catalog that points to the Hive Metastore. */ @transient override protected[sql] lazy val catalog = new HiveMetastoreCatalog(metadataHive, this) with OverrideCatalog @@ -373,7 +459,13 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { // Note that HiveUDFs will be overridden by functions registered in this context. @transient override protected[sql] lazy val functionRegistry: FunctionRegistry = - new HiveFunctionRegistry(FunctionRegistry.builtin) + new HiveFunctionRegistry(FunctionRegistry.builtin.copy(), this.executionHive) + + // The Hive UDF current_database() is foldable, will be evaluated by optimizer, but the optimizer + // can't access the SessionState of metadataHive. + functionRegistry.registerFunction( + "current_database", + (expressions: Seq[Expression]) => new CurrentDatabase(this)) /* An analyzer that uses the Hive metastore. */ @transient @@ -384,55 +476,94 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { catalog.CreateTables :: catalog.PreInsertionCasts :: ExtractPythonUDFs :: - ResolveHiveWindowFunction :: PreInsertCastAndRename :: - Nil + (if (conf.runSQLOnFile) new ResolveDataSource(self) :: Nil else Nil) override val extendedCheckRules = Seq( PreWriteCheck(catalog) ) } - override protected[sql] def createSession(): SQLSession = { - new this.SQLSession() - } - /** Overridden by child classes that need to set configuration before the client init. */ - protected def configure(): Map[String, String] = Map.empty - - protected[hive] class SQLSession extends super.SQLSession { - protected[sql] override lazy val conf: SQLConf = new SQLConf { - override def dialect: String = getConf(SQLConf.DIALECT, "hiveql") - override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) - } + protected def configure(): Map[String, String] = { + // Hive 0.14.0 introduces timeout operations in HiveConf, and changes default values of a bunch + // of time `ConfVar`s by adding time suffixes (`s`, `ms`, and `d` etc.). This breaks backwards- + // compatibility when users are trying to connecting to a Hive metastore of lower version, + // because these options are expected to be integral values in lower versions of Hive. + // + // Here we enumerate all time `ConfVar`s and convert their values to numeric strings according + // to their output time units. 
+ Seq( + ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY -> TimeUnit.SECONDS, + ConfVars.METASTORE_CLIENT_SOCKET_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.METASTORE_CLIENT_SOCKET_LIFETIME -> TimeUnit.SECONDS, + ConfVars.HMSHANDLERINTERVAL -> TimeUnit.MILLISECONDS, + ConfVars.METASTORE_EVENT_DB_LISTENER_TTL -> TimeUnit.SECONDS, + ConfVars.METASTORE_EVENT_CLEAN_FREQ -> TimeUnit.SECONDS, + ConfVars.METASTORE_EVENT_EXPIRY_DURATION -> TimeUnit.SECONDS, + ConfVars.METASTORE_AGGREGATE_STATS_CACHE_TTL -> TimeUnit.SECONDS, + ConfVars.METASTORE_AGGREGATE_STATS_CACHE_MAX_WRITER_WAIT -> TimeUnit.MILLISECONDS, + ConfVars.METASTORE_AGGREGATE_STATS_CACHE_MAX_READER_WAIT -> TimeUnit.MILLISECONDS, + ConfVars.HIVES_AUTO_PROGRESS_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_LOG_INCREMENTAL_PLAN_PROGRESS_INTERVAL -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_STATS_JDBC_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_STATS_RETRIES_WAIT -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_LOCK_SLEEP_BETWEEN_RETRIES -> TimeUnit.SECONDS, + ConfVars.HIVE_ZOOKEEPER_SESSION_TIMEOUT -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_ZOOKEEPER_CONNECTION_BASESLEEPTIME -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_TXN_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_COMPACTOR_WORKER_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_COMPACTOR_CHECK_INTERVAL -> TimeUnit.SECONDS, + ConfVars.HIVE_COMPACTOR_CLEANER_RUN_INTERVAL -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_THRIFT_HTTP_MAX_IDLE_TIME -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_THRIFT_HTTP_WORKER_KEEPALIVE_TIME -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_THRIFT_HTTP_COOKIE_MAX_AGE -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_THRIFT_LOGIN_BEBACKOFF_SLOT_LENGTH -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_THRIFT_LOGIN_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_THRIFT_WORKER_KEEPALIVE_TIME -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_ASYNC_EXEC_SHUTDOWN_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_ASYNC_EXEC_KEEPALIVE_TIME -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_LONG_POLLING_TIMEOUT -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_SESSION_CHECK_INTERVAL -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_IDLE_SESSION_TIMEOUT -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_IDLE_OPERATION_TIMEOUT -> TimeUnit.MILLISECONDS, + ConfVars.SERVER_READ_SOCKET_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_LOCALIZE_RESOURCE_WAIT_INTERVAL -> TimeUnit.MILLISECONDS, + ConfVars.SPARK_CLIENT_FUTURE_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.SPARK_JOB_MONITOR_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.SPARK_RPC_CLIENT_CONNECT_TIMEOUT -> TimeUnit.MILLISECONDS, + ConfVars.SPARK_RPC_CLIENT_HANDSHAKE_TIMEOUT -> TimeUnit.MILLISECONDS + ).map { case (confVar, unit) => + confVar.varname -> hiveconf.getTimeVar(confVar, unit).toString + }.toMap + } - /** - * SQLConf and HiveConf contracts: - * - * 1. reuse existing started SessionState if any - * 2. when the Hive session is first initialized, params in HiveConf will get picked up by the - * SQLConf. Additionally, any properties set by set() or a SET command inside sql() will be - * set in the SQLConf *as well as* in the HiveConf. - */ - protected[hive] lazy val sessionState: SessionState = { - var state = SessionState.get() - if (state == null) { - state = new SessionState(new HiveConf(classOf[SessionState])) - SessionState.start(state) - } - state - } + /** + * SQLConf and HiveConf contracts: + * + * 1. create a new SessionState for each HiveContext + * 2. 
when the Hive session is first initialized, params in HiveConf will get picked up by the + * SQLConf. Additionally, any properties set by set() or a SET command inside sql() will be + * set in the SQLConf *as well as* in the HiveConf. + */ + @transient + protected[hive] lazy val hiveconf: HiveConf = { + val c = executionHive.conf + setConf(c.getAllProperties) + c + } - protected[hive] lazy val hiveconf: HiveConf = { - setConf(sessionState.getConf.getAllProperties) - sessionState.getConf - } + protected[sql] override lazy val conf: SQLConf = new SQLConf { + override def dialect: String = getConf(SQLConf.DIALECT, "hiveql") + override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) } - override protected[sql] def dialectClassName = if (conf.dialect == "hiveql") { - classOf[HiveQLDialect].getCanonicalName - } else { - super.dialectClassName + protected[sql] override def getSQLDialect(): ParserDialect = { + if (conf.dialect == "hiveql") { + new HiveQLDialect(this) + } else { + super.getSQLDialect() + } } @transient @@ -449,20 +580,24 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { HiveTableScans, DataSinks, Scripts, - HashAggregation, Aggregation, LeftSemiJoin, - HashJoin, + EquiJoinSelection, BasicOperators, + BroadcastNestedLoop, CartesianProduct, - BroadcastNestedLoopJoin + DefaultJoin ) } + private def functionOrMacroDDLPattern(command: String) = Pattern.compile( + ".*(create|drop)\\s+(temporary\\s+)?(function|macro).+", Pattern.DOTALL).matcher(command) + protected[hive] def runSqlHive(sql: String): Seq[String] = { - if (sql.toLowerCase.contains("create temporary function")) { + val command = sql.trim.toLowerCase + if (functionOrMacroDDLPattern(command).matches()) { executionHive.runSqlHive(sql) - } else if (sql.trim.toLowerCase.startsWith("set")) { + } else if (command.startsWith("set")) { metadataHive.runSqlHive(sql) executionHive.runSqlHive(sql) } else { @@ -475,7 +610,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { /** Extends QueryExecution with hive specific features. */ protected[sql] class QueryExecution(logicalPlan: LogicalPlan) - extends super.QueryExecution(logicalPlan) { + extends org.apache.spark.sql.execution.QueryExecution(this, logicalPlan) { /** * Returns the result as a hive compatible sequence of strings. For native commands, the @@ -493,10 +628,10 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { .mkString("\t") } case command: ExecutedCommand => - command.executeCollect().map(_(0).toString) + command.executeCollect().map(_.getString(0)) case other => - val result: Seq[Seq[Any]] = other.executeCollect().map(_.toSeq).toSeq + val result: Seq[Seq[Any]] = other.executeCollectPublic().map(_.toSeq).toSeq // We need the types so we can output struct field names val types = analyzed.output.map(_.dataType) // Reformat to match hive tab delimited output. @@ -510,24 +645,45 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { case _ => super.simpleString } } + + protected[sql] override def addJar(path: String): Unit = { + // Add jar to Hive and classloader + executionHive.addJar(path) + metadataHive.addJar(path) + Thread.currentThread().setContextClassLoader(executionHive.clientLoader.classLoader) + super.addJar(path) + } } private[hive] object HiveContext { /** The version of hive used internally by Spark SQL. 
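// Quick illustration of the functionOrMacroDDLPattern routing above: CREATE/DROP
// [TEMPORARY] FUNCTION or MACRO statements are sent to the execution Hive client, and SET
// statements to both clients. The helper below only re-applies the same regex to sample
// strings; it is not part of HiveContext.
import java.util.regex.Pattern

def isFunctionOrMacroDDL(sql: String): Boolean =
  Pattern.compile(".*(create|drop)\\s+(temporary\\s+)?(function|macro).+", Pattern.DOTALL)
    .matcher(sql.trim.toLowerCase)
    .matches()

// isFunctionOrMacroDDL("CREATE TEMPORARY FUNCTION myudf AS 'com.example.MyUDF'")  // true
// isFunctionOrMacroDDL("SET hive.exec.dynamic.partition=true")                    // false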
*/ - val hiveExecutionVersion: String = "0.13.1" + val hiveExecutionVersion: String = "1.2.1" + + val HIVE_METASTORE_VERSION = stringConf("spark.sql.hive.metastore.version", + defaultValue = Some(hiveExecutionVersion), + doc = "Version of the Hive metastore. Available options are " + + s"0.12.0 through $hiveExecutionVersion.") + + val HIVE_EXECUTION_VERSION = stringConf( + key = "spark.sql.hive.version", + defaultValue = Some(hiveExecutionVersion), + doc = "Version of Hive used internally by Spark SQL.") - val HIVE_METASTORE_VERSION: String = "spark.sql.hive.metastore.version" val HIVE_METASTORE_JARS = stringConf("spark.sql.hive.metastore.jars", defaultValue = Some("builtin"), - doc = "Location of the jars that should be used to instantiate the HiveMetastoreClient. This" + - " property can be one of three options: " + - "1. \"builtin\" Use Hive 0.13.1, which is bundled with the Spark assembly jar when " + - "-Phive is enabled. When this option is chosen, " + - "spark.sql.hive.metastore.version must be either 0.13.1 or not defined. " + - "2. \"maven\" Use Hive jars of specified version downloaded from Maven repositories." + - "3. A classpath in the standard format for both Hive and Hadoop.") - + doc = s""" + | Location of the jars that should be used to instantiate the HiveMetastoreClient. + | This property can be one of three options: " + | 1. "builtin" + | Use Hive ${hiveExecutionVersion}, which is bundled with the Spark assembly jar when + | -Phive is enabled. When this option is chosen, + | spark.sql.hive.metastore.version must be either + | ${hiveExecutionVersion} or not defined. + | 2. "maven" + | Use Hive jars of specified version downloaded from Maven repositories. + | 3. A classpath in the standard format for both Hive and Hadoop. + """.stripMargin) val CONVERT_METASTORE_PARQUET = booleanConf("spark.sql.hive.convertMetastoreParquet", defaultValue = Some(true), doc = "When set to false, Spark SQL will use the Hive SerDe for parquet tables instead of " + @@ -564,21 +720,40 @@ private[hive] object HiveContext { doc = "TODO") /** Constructs a configuration for hive, where the metastore is located in a temp directory. */ - def newTemporaryConfiguration(): Map[String, String] = { + def newTemporaryConfiguration(useInMemoryDerby: Boolean): Map[String, String] = { + val withInMemoryMode = if (useInMemoryDerby) "memory:" else "" + val tempDir = Utils.createTempDir() - val localMetastore = new File(tempDir, "metastore").getAbsolutePath + val localMetastore = new File(tempDir, "metastore") val propMap: HashMap[String, String] = HashMap() // We have to mask all properties in hive-site.xml that relates to metastore data source // as we used a local metastore here. 
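// A hedged example of the knobs documented above for talking to a Hive metastore whose version
// differs from the built-in 1.2.1 execution Hive: pin spark.sql.hive.metastore.version and point
// spark.sql.hive.metastore.jars at matching client jars. The version string and classpath below
// are placeholders, not recommendations.
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.sql.hive.metastore.version", "0.13.1")   // any supported version, 0.12.0 and up
  .set("spark.sql.hive.metastore.jars", "maven")       // or "builtin", or an explicit classpath
// .set("spark.sql.hive.metastore.jars", "/opt/hive/lib/*:/opt/hadoop/lib/*")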
HiveConf.ConfVars.values().foreach { confvar => - if (confvar.varname.contains("datanucleus") || confvar.varname.contains("jdo")) { - propMap.put(confvar.varname, confvar.defaultVal) + if (confvar.varname.contains("datanucleus") || confvar.varname.contains("jdo") + || confvar.varname.contains("hive.metastore.rawstore.impl")) { + propMap.put(confvar.varname, confvar.getDefaultExpr()) } } - propMap.put("javax.jdo.option.ConnectionURL", - s"jdbc:derby:;databaseName=$localMetastore;create=true") + propMap.put(HiveConf.ConfVars.METASTOREWAREHOUSE.varname, localMetastore.toURI.toString) + propMap.put(HiveConf.ConfVars.METASTORECONNECTURLKEY.varname, + s"jdbc:derby:${withInMemoryMode};databaseName=${localMetastore.getAbsolutePath};create=true") propMap.put("datanucleus.rdbms.datastoreAdapterClassName", "org.datanucleus.store.rdbms.adapter.DerbyAdapter") + + // SPARK-11783: When "hive.metastore.uris" is set, the metastore connection mode will be + // remote (https://cwiki.apache.org/confluence/display/Hive/AdminManual+MetastoreAdmin + // mentions that "If hive.metastore.uris is empty local mode is assumed, remote otherwise"). + // Remote means that the metastore server is running in its own process. + // When the mode is remote, configurations like "javax.jdo.option.ConnectionURL" will not be + // used (because they are used by remote metastore server that talks to the database). + // Because execution Hive should always connects to a embedded derby metastore. + // We have to remove the value of hive.metastore.uris. So, the execution Hive client connects + // to the actual embedded derby metastore instead of the remote metastore. + // You can search HiveConf.ConfVars.METASTOREURIS in the code of HiveConf (in Hive's repo). + // Then, you will find that the local metastore mode is only set to true when + // hive.metastore.uris is not set. + propMap.put(ConfVars.METASTOREURIS.varname, "") + propMap.toMap } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 39d798d072ae..95b57d6ad124 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.hive -import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar} +import scala.collection.JavaConverters._ + +import org.apache.hadoop.hive.common.`type`.{HiveChar, HiveDecimal, HiveVarchar} import org.apache.hadoop.hive.serde2.objectinspector.primitive._ import org.apache.hadoop.hive.serde2.objectinspector.{StructField => HiveStructField, _} import org.apache.hadoop.hive.serde2.typeinfo.{DecimalTypeInfo, TypeInfoFactory} @@ -26,14 +28,11 @@ import org.apache.hadoop.{io => hadoopIo} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ import org.apache.spark.sql.{AnalysisException, types} import org.apache.spark.unsafe.types.UTF8String -/* Implicit conversions */ -import scala.collection.JavaConversions._ - /** * 1. 
The Underlying data type in catalyst and in Hive * In catalyst: @@ -51,8 +50,8 @@ import scala.collection.JavaConversions._ * java.sql.Date * java.sql.Timestamp * Complex Types => - * Map: [[org.apache.spark.sql.types.MapData]] - * List: [[org.apache.spark.sql.types.ArrayData]] + * Map: [[MapData]] + * List: [[ArrayData]] * Struct: [[org.apache.spark.sql.catalyst.InternalRow]] * Union: NOT SUPPORTED YET * The Complex types plays as a container, which can hold arbitrary data types. @@ -62,6 +61,7 @@ import scala.collection.JavaConversions._ * Primitive Type * Java Boxed Primitives: * org.apache.hadoop.hive.common.type.HiveVarchar + * org.apache.hadoop.hive.common.type.HiveChar * java.lang.String * java.lang.Integer * java.lang.Boolean @@ -76,6 +76,7 @@ import scala.collection.JavaConversions._ * java.sql.Timestamp * Writables: * org.apache.hadoop.hive.serde2.io.HiveVarcharWritable + * org.apache.hadoop.hive.serde2.io.HiveCharWritable * org.apache.hadoop.io.Text * org.apache.hadoop.io.IntWritable * org.apache.hadoop.hive.serde2.io.DoubleWritable @@ -94,7 +95,8 @@ import scala.collection.JavaConversions._ * Struct: Object[] / java.util.List / java POJO * Union: class StandardUnion { byte tag; Object object } * - * NOTICE: HiveVarchar is not supported by catalyst, it will be simply considered as String type. + * NOTICE: HiveVarchar/HiveChar is not supported by catalyst, it will be simply considered as + * String type. * * * 2. Hive ObjectInspector is a group of flexible APIs to inspect value in different data @@ -138,6 +140,7 @@ import scala.collection.JavaConversions._ * Primitive Object Inspectors: * WritableConstantStringObjectInspector * WritableConstantHiveVarcharObjectInspector + * WritableConstantHiveCharObjectInspector * WritableConstantHiveDecimalObjectInspector * WritableConstantTimestampObjectInspector * WritableConstantIntObjectInspector @@ -260,6 +263,8 @@ private[hive] trait HiveInspectors { UTF8String.fromString(poi.getWritableConstantValue.toString) case poi: WritableConstantHiveVarcharObjectInspector => UTF8String.fromString(poi.getWritableConstantValue.getHiveVarchar.getValue) + case poi: WritableConstantHiveCharObjectInspector => + UTF8String.fromString(poi.getWritableConstantValue.getHiveChar.getValue) case poi: WritableConstantHiveDecimalObjectInspector => HiveShim.toCatalystDecimal( PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector, @@ -290,13 +295,13 @@ private[hive] trait HiveInspectors { DateTimeUtils.fromJavaDate(poi.getWritableConstantValue.get()) case mi: StandardConstantMapObjectInspector => // take the value from the map inspector object, rather than the input data - val map = mi.getWritableConstantValue - val keys = map.keysIterator.map(unwrap(_, mi.getMapKeyObjectInspector)).toArray - val values = map.valuesIterator.map(unwrap(_, mi.getMapValueObjectInspector)).toArray + val keyValues = mi.getWritableConstantValue.asScala.toSeq + val keys = keyValues.map(kv => unwrap(kv._1, mi.getMapKeyObjectInspector)).toArray + val values = keyValues.map(kv => unwrap(kv._2, mi.getMapValueObjectInspector)).toArray ArrayBasedMapData(keys, values) case li: StandardConstantListObjectInspector => // take the value from the list inspector object, rather than the input data - val values = li.getWritableConstantValue + val values = li.getWritableConstantValue.asScala .map(unwrap(_, li.getListElementObjectInspector)) .toArray new GenericArrayData(values) @@ -304,11 +309,15 @@ private[hive] trait HiveInspectors { case _ if data == null => null case poi: VoidObjectInspector 
=> null // always be null for void object inspector case pi: PrimitiveObjectInspector => pi match { - // We think HiveVarchar is also a String + // We think HiveVarchar/HiveChar is also a String case hvoi: HiveVarcharObjectInspector if hvoi.preferWritable() => UTF8String.fromString(hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue) case hvoi: HiveVarcharObjectInspector => UTF8String.fromString(hvoi.getPrimitiveJavaObject(data).getValue) + case hvoi: HiveCharObjectInspector if hvoi.preferWritable() => + UTF8String.fromString(hvoi.getPrimitiveWritableObject(data).getHiveChar.getValue) + case hvoi: HiveCharObjectInspector => + UTF8String.fromString(hvoi.getPrimitiveJavaObject(data).getValue) case x: StringObjectInspector if x.preferWritable() => UTF8String.fromString(x.getPrimitiveWritableObject(data).toString) case x: StringObjectInspector => @@ -342,7 +351,7 @@ private[hive] trait HiveInspectors { case li: ListObjectInspector => Option(li.getList(data)) .map { l => - val values = l.map(unwrap(_, li.getListElementObjectInspector)).toArray + val values = l.asScala.map(unwrap(_, li.getListElementObjectInspector)).toArray new GenericArrayData(values) } .orNull @@ -351,15 +360,16 @@ private[hive] trait HiveInspectors { if (map == null) { null } else { - val keys = map.keysIterator.map(unwrap(_, mi.getMapKeyObjectInspector)).toArray - val values = map.valuesIterator.map(unwrap(_, mi.getMapValueObjectInspector)).toArray + val keyValues = map.asScala.toSeq + val keys = keyValues.map(kv => unwrap(kv._1, mi.getMapKeyObjectInspector)).toArray + val values = keyValues.map(kv => unwrap(kv._2, mi.getMapValueObjectInspector)).toArray ArrayBasedMapData(keys, values) } // currently, hive doesn't provide the ConstantStructObjectInspector case si: StructObjectInspector => val allRefs = si.getAllStructFieldRefs - InternalRow.fromSeq( - allRefs.map(r => unwrap(si.getStructFieldData(data, r), r.getFieldObjectInspector))) + InternalRow.fromSeq(allRefs.asScala.map( + r => unwrap(si.getStructFieldData(data, r), r.getFieldObjectInspector))) } @@ -370,28 +380,58 @@ private[hive] trait HiveInspectors { protected def wrapperFor(oi: ObjectInspector, dataType: DataType): Any => Any = oi match { case _: JavaHiveVarcharObjectInspector => (o: Any) => - val s = o.asInstanceOf[UTF8String].toString - new HiveVarchar(s, s.size) + if (o != null) { + val s = o.asInstanceOf[UTF8String].toString + new HiveVarchar(s, s.size) + } else { + null + } + + case _: JavaHiveCharObjectInspector => + (o: Any) => + if (o != null) { + val s = o.asInstanceOf[UTF8String].toString + new HiveChar(s, s.size) + } else { + null + } case _: JavaHiveDecimalObjectInspector => - (o: Any) => HiveDecimal.create(o.asInstanceOf[Decimal].toJavaBigDecimal) + (o: Any) => + if (o != null) { + HiveDecimal.create(o.asInstanceOf[Decimal].toJavaBigDecimal) + } else { + null + } case _: JavaDateObjectInspector => - (o: Any) => DateTimeUtils.toJavaDate(o.asInstanceOf[Int]) + (o: Any) => + if (o != null) { + DateTimeUtils.toJavaDate(o.asInstanceOf[Int]) + } else { + null + } case _: JavaTimestampObjectInspector => - (o: Any) => DateTimeUtils.toJavaTimestamp(o.asInstanceOf[Long]) + (o: Any) => + if (o != null) { + DateTimeUtils.toJavaTimestamp(o.asInstanceOf[Long]) + } else { + null + } case soi: StandardStructObjectInspector => val schema = dataType.asInstanceOf[StructType] - val wrappers = soi.getAllStructFieldRefs.zip(schema.fields).map { case (ref, field) => - wrapperFor(ref.getFieldObjectInspector, field.dataType) + val wrappers = 
soi.getAllStructFieldRefs.asScala.zip(schema.fields).map { + case (ref, field) => wrapperFor(ref.getFieldObjectInspector, field.dataType) } (o: Any) => { if (o != null) { val struct = soi.create() - (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[InternalRow].toSeq).zipped.foreach { - (field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data)) + val row = o.asInstanceOf[InternalRow] + soi.getAllStructFieldRefs.asScala.zip(wrappers).zipWithIndex.foreach { + case ((field, wrapper), i) => + soi.setStructFieldData(struct, field, wrapper(row.get(i, schema(i).dataType))) } struct } else { @@ -516,7 +556,7 @@ private[hive] trait HiveInspectors { // 1. create the pojo (most likely) object val result = x.create() var i = 0 - while (i < fieldRefs.length) { + while (i < fieldRefs.size) { // 2. set the property for the pojo val tpe = structType(i).dataType x.setStructFieldData( @@ -531,9 +571,9 @@ private[hive] trait HiveInspectors { val fieldRefs = x.getAllStructFieldRefs val structType = dataType.asInstanceOf[StructType] val row = a.asInstanceOf[InternalRow] - val result = new java.util.ArrayList[AnyRef](fieldRefs.length) + val result = new java.util.ArrayList[AnyRef](fieldRefs.size) var i = 0 - while (i < fieldRefs.length) { + while (i < fieldRefs.size) { val tpe = structType(i).dataType result.add(wrap(row.get(i, tpe), fieldRefs.get(i).getFieldObjectInspector, tpe)) i += 1 @@ -691,10 +731,10 @@ private[hive] trait HiveInspectors { def inspectorToDataType(inspector: ObjectInspector): DataType = inspector match { case s: StructObjectInspector => - StructType(s.getAllStructFieldRefs.map(f => { + StructType(s.getAllStructFieldRefs.asScala.map(f => types.StructField( f.getFieldName, inspectorToDataType(f.getFieldObjectInspector), nullable = true) - })) + )) case l: ListObjectInspector => ArrayType(inspectorToDataType(l.getListElementObjectInspector)) case m: MapObjectInspector => MapType( @@ -702,6 +742,10 @@ private[hive] trait HiveInspectors { inspectorToDataType(m.getMapValueObjectInspector)) case _: WritableStringObjectInspector => StringType case _: JavaStringObjectInspector => StringType + case _: WritableHiveVarcharObjectInspector => StringType + case _: JavaHiveVarcharObjectInspector => StringType + case _: WritableHiveCharObjectInspector => StringType + case _: JavaHiveCharObjectInspector => StringType case _: WritableIntObjectInspector => IntegerType case _: JavaIntObjectInspector => IntegerType case _: WritableDoubleObjectInspector => DoubleType diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index a8c9b4fa71b9..08b291e08823 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql.hive -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ +import scala.collection.mutable import com.google.common.base.Objects import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.common.StatsSetupConst +import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.Warehouse import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.hadoop.hive.ql.metadata._ @@ -30,19 +32,70 @@ import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.spark.Logging 
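// The wrapperFor cases above all repeat one shape: convert the Catalyst value only when it is
// non-null, otherwise hand null through to Hive. A generic sketch of that pattern; nullSafe is
// an illustrative helper, not part of HiveInspectors.
def nullSafe(convert: Any => Any): Any => Any =
  (o: Any) => if (o != null) convert(o) else null

// For example, a date wrapper in the same shape as the JavaDateObjectInspector case above:
// val dateWrapper = nullSafe(o => DateTimeUtils.toJavaDate(o.asInstanceOf[Int]))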
import org.apache.spark.sql.catalyst.analysis.{Catalog, MultiInstanceRelation, OverrideCatalog} +import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.catalyst.{InternalRow, SqlParser, TableIdentifier} -import org.apache.spark.sql.execution.datasources +import org.apache.spark.sql.catalyst.util.DataTypeParser +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource} +import org.apache.spark.sql.execution.{FileRelation, datasources} import org.apache.spark.sql.hive.client._ -import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.hive.execution.HiveNativeCommand +import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode} +private[hive] case class HiveSerDe( + inputFormat: Option[String] = None, + outputFormat: Option[String] = None, + serde: Option[String] = None) + +private[hive] object HiveSerDe { + /** + * Get the Hive SerDe information from the data source abbreviation string or classname. + * + * @param source Currently the source abbreviation can be one of the following: + * SequenceFile, RCFile, ORC, PARQUET, and case insensitive. + * @param hiveConf Hive Conf + * @return HiveSerDe associated with the specified source + */ + def sourceToSerDe(source: String, hiveConf: HiveConf): Option[HiveSerDe] = { + val serdeMap = Map( + "sequencefile" -> + HiveSerDe( + inputFormat = Option("org.apache.hadoop.mapred.SequenceFileInputFormat"), + outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat")), + + "rcfile" -> + HiveSerDe( + inputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileOutputFormat"), + serde = Option(hiveConf.getVar(HiveConf.ConfVars.HIVEDEFAULTRCFILESERDE))), + + "orc" -> + HiveSerDe( + inputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat"), + serde = Option("org.apache.hadoop.hive.ql.io.orc.OrcSerde")), + + "parquet" -> + HiveSerDe( + inputFormat = Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat"), + serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe"))) + + val key = source.toLowerCase match { + case s if s.startsWith("org.apache.spark.sql.parquet") => "parquet" + case s if s.startsWith("org.apache.spark.sql.orc") => "orc" + case s => s + } + + serdeMap.get(key) + } +} + private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: HiveContext) extends Catalog with Logging { @@ -51,10 +104,19 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive /** Usages should lock on `this`. */ protected[hive] lazy val hiveWarehouse = new Warehouse(hive.hiveconf) - // TODO: Use this everywhere instead of tuples or databaseName, tableName,. 
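// Hypothetical calls into the HiveSerDe.sourceToSerDe helper defined above (it is private[hive],
// so this only compiles from inside that package): both the short abbreviation and a data source
// class name starting with org.apache.spark.sql.parquet resolve to the Parquet SerDe triple,
// while unknown sources yield None. "csv" and the Foo class name are made-up inputs.
import org.apache.hadoop.hive.conf.HiveConf

val hiveConf = new HiveConf()
val parquet = HiveSerDe.sourceToSerDe("PARQUET", hiveConf)                              // Some(HiveSerDe(...))
val alsoParquet = HiveSerDe.sourceToSerDe("org.apache.spark.sql.parquet.Foo", hiveConf) // same triple
val unknown = HiveSerDe.sourceToSerDe("csv", hiveConf)                                  // None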
/** A fully qualified identifier for a table (i.e., database.tableName) */ - case class QualifiedTableName(database: String, name: String) { - def toLowerCase: QualifiedTableName = QualifiedTableName(database.toLowerCase, name.toLowerCase) + case class QualifiedTableName(database: String, name: String) + + private def getQualifiedTableName(tableIdent: TableIdentifier) = { + QualifiedTableName( + tableIdent.database.getOrElse(client.currentDatabase).toLowerCase, + tableIdent.table.toLowerCase) + } + + private def getQualifiedTableName(hiveTable: HiveTable) = { + QualifiedTableName( + hiveTable.specifiedDatabase.getOrElse(client.currentDatabase).toLowerCase, + hiveTable.name.toLowerCase) } /** A cache of Spark SQL data source tables that have been accessed. */ @@ -81,6 +143,21 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive } } + def partColsFromParts: Option[Seq[String]] = { + table.properties.get("spark.sql.sources.schema.numPartCols").map { numPartCols => + (0 until numPartCols.toInt).map { index => + val partCol = table.properties.get(s"spark.sql.sources.schema.partCol.$index").orNull + if (partCol == null) { + throw new AnalysisException( + "Could not read partitioned columns from the metastore because it is corrupted " + + s"(missing part $index of the it, $numPartCols parts are expected).") + } + + partCol + } + } + } + // Originally, we used spark.sql.sources.schema to store the schema of a data source table. // After SPARK-6024, we removed this flag. // Although we are not using spark.sql.sources.schema any more, we need to still support. @@ -93,7 +170,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // We only need names at here since userSpecifiedSchema we loaded from the metastore // contains partition columns. We can always get datatypes of partitioning columns // from userSpecifiedSchema. - val partitionColumns = table.partitionColumns.map(_.name) + val partitionColumns = partColsFromParts.getOrElse(Nil) // It does not appear that the ql client for the metastore has a way to enumerate all the // SerDe properties directly... @@ -123,56 +200,31 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // it is better at here to invalidate the cache to avoid confusing waring logs from the // cache loader (e.g. cannot find data source provider, which is only defined for // data source table.). - invalidateTable(tableIdent.database.getOrElse(client.currentDatabase), tableIdent.table) + invalidateTable(tableIdent) } - def invalidateTable(databaseName: String, tableName: String): Unit = { - cachedDataSourceTables.invalidate(QualifiedTableName(databaseName, tableName).toLowerCase) + def invalidateTable(tableIdent: TableIdentifier): Unit = { + cachedDataSourceTables.invalidate(getQualifiedTableName(tableIdent)) } - val caseSensitive: Boolean = false - - /** - * Creates a data source table (a table created with USING clause) in Hive's metastore. - * Returns true when the table has been created. Otherwise, false. 
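// The numPartCols/partCol.$index properties above follow the same layout used for the table
// schema itself: a value too long (or too structured) for a single metastore property is split
// into numbered properties plus a count, and reassembled on read. A generic sketch of that
// layout; the "example.*"-style prefix is illustrative, the real keys live under
// spark.sql.sources.schema.*.
def writeParts(prefix: String, json: String, threshold: Int): Map[String, String] = {
  val parts = json.grouped(threshold).toSeq
  val indexed = parts.zipWithIndex.map { case (part, i) => s"$prefix.part.$i" -> part }
  (indexed :+ (s"$prefix.numParts" -> parts.size.toString)).toMap
}

def readParts(prefix: String, props: Map[String, String]): Option[String] =
  props.get(s"$prefix.numParts").map { n =>
    (0 until n.toInt).map(i => props(s"$prefix.part.$i")).mkString
  }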
- */ def createDataSourceTable( - tableName: String, - userSpecifiedSchema: Option[StructType], - partitionColumns: Array[String], - provider: String, - options: Map[String, String], - isExternal: Boolean): Unit = { - createDataSourceTable( - new SqlParser().parseTableIdentifier(tableName), - userSpecifiedSchema, - partitionColumns, - provider, - options, - isExternal) - } - - private def createDataSourceTable( tableIdent: TableIdentifier, userSpecifiedSchema: Option[StructType], partitionColumns: Array[String], provider: String, options: Map[String, String], isExternal: Boolean): Unit = { - val (dbName, tblName) = { - val database = tableIdent.database.getOrElse(client.currentDatabase) - processDatabaseAndTableName(database, tableIdent.table) - } + val QualifiedTableName(dbName, tblName) = getQualifiedTableName(tableIdent) - val tableProperties = new scala.collection.mutable.HashMap[String, String] + val tableProperties = new mutable.HashMap[String, String] tableProperties.put("spark.sql.sources.provider", provider) // Saves optional user specified schema. Serialized JSON schema string may be too long to be // stored into a single metastore SerDe property. In this case, we split the JSON string and // store each part as a separate SerDe property. - if (userSpecifiedSchema.isDefined) { + userSpecifiedSchema.foreach { schema => val threshold = conf.schemaStringLengthThreshold - val schemaJsonString = userSpecifiedSchema.get.json + val schemaJsonString = schema.json // Split the JSON string. val parts = schemaJsonString.grouped(threshold).toSeq tableProperties.put("spark.sql.sources.schema.numParts", parts.size.toString) @@ -181,25 +233,21 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive } } - val metastorePartitionColumns = userSpecifiedSchema.map { schema => - val fields = partitionColumns.map(col => schema(col)) - fields.map { field => - HiveColumn( - name = field.name, - hiveType = HiveMetastoreTypes.toMetastoreType(field.dataType), - comment = "") - }.toSeq - }.getOrElse { - if (partitionColumns.length > 0) { - // The table does not have a specified schema, which means that the schema will be inferred - // when we load the table. So, we are not expecting partition columns and we will discover - // partitions when we load the table. However, if there are specified partition columns, - // we simplily ignore them and provide a warning message.. - logWarning( - s"The schema and partitions of table $tableIdent will be inferred when it is loaded. " + - s"Specified partition columns (${partitionColumns.mkString(",")}) will be ignored.") + if (userSpecifiedSchema.isDefined && partitionColumns.length > 0) { + tableProperties.put("spark.sql.sources.schema.numPartCols", partitionColumns.length.toString) + partitionColumns.zipWithIndex.foreach { case (partCol, index) => + tableProperties.put(s"spark.sql.sources.schema.partCol.$index", partCol) } - Seq.empty[HiveColumn] + } + + if (userSpecifiedSchema.isEmpty && partitionColumns.length > 0) { + // The table does not have a specified schema, which means that the schema will be inferred + // when we load the table. So, we are not expecting partition columns and we will discover + // partitions when we load the table. However, if there are specified partition columns, + // we simply ignore them and provide a warning message. + logWarning( + s"The schema and partitions of table $tableIdent will be inferred when it is loaded. 
" + + s"Specified partition columns (${partitionColumns.mkString(",")}) will be ignored.") } val tableType = if (isExternal) { @@ -210,53 +258,137 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive ManagedTable } - client.createTable( + val maybeSerDe = HiveSerDe.sourceToSerDe(provider, hive.hiveconf) + val dataSource = ResolvedDataSource( + hive, userSpecifiedSchema, partitionColumns, provider, options) + + def newSparkSQLSpecificMetastoreTable(): HiveTable = { HiveTable( specifiedDatabase = Option(dbName), name = tblName, - schema = Seq.empty, - partitionColumns = metastorePartitionColumns, + schema = Nil, + partitionColumns = Nil, tableType = tableType, properties = tableProperties.toMap, - serdeProperties = options)) + serdeProperties = options) + } + + def newHiveCompatibleMetastoreTable(relation: HadoopFsRelation, serde: HiveSerDe): HiveTable = { + def schemaToHiveColumn(schema: StructType): Seq[HiveColumn] = { + schema.map { field => + HiveColumn( + name = field.name, + hiveType = HiveMetastoreTypes.toMetastoreType(field.dataType), + comment = "") + } + } + + assert(partitionColumns.isEmpty) + assert(relation.partitionColumns.isEmpty) + + HiveTable( + specifiedDatabase = Option(dbName), + name = tblName, + schema = schemaToHiveColumn(relation.schema), + partitionColumns = Nil, + tableType = tableType, + properties = tableProperties.toMap, + serdeProperties = options, + location = Some(relation.paths.head), + viewText = None, // TODO We need to place the SQL string here. + inputFormat = serde.inputFormat, + outputFormat = serde.outputFormat, + serde = serde.serde) + } + + // TODO: Support persisting partitioned data source relations in Hive compatible format + val qualifiedTableName = tableIdent.quotedString + val (hiveCompatibleTable, logMessage) = (maybeSerDe, dataSource.relation) match { + case (Some(serde), relation: HadoopFsRelation) + if relation.paths.length == 1 && relation.partitionColumns.isEmpty => + val hiveTable = newHiveCompatibleMetastoreTable(relation, serde) + val message = + s"Persisting data source relation $qualifiedTableName with a single input path " + + s"into Hive metastore in Hive compatible format. Input path: ${relation.paths.head}." + (Some(hiveTable), message) + + case (Some(serde), relation: HadoopFsRelation) if relation.partitionColumns.nonEmpty => + val message = + s"Persisting partitioned data source relation $qualifiedTableName into " + + "Hive metastore in Spark SQL specific format, which is NOT compatible with Hive. " + + "Input path(s): " + relation.paths.mkString("\n", "\n", "") + (None, message) + + case (Some(serde), relation: HadoopFsRelation) => + val message = + s"Persisting data source relation $qualifiedTableName with multiple input paths into " + + "Hive metastore in Spark SQL specific format, which is NOT compatible with Hive. " + + s"Input paths: " + relation.paths.mkString("\n", "\n", "") + (None, message) + + case (Some(serde), _) => + val message = + s"Data source relation $qualifiedTableName is not a " + + s"${classOf[HadoopFsRelation].getSimpleName}. Persisting it into Hive metastore " + + "in Spark SQL specific format, which is NOT compatible with Hive." + (None, message) + + case _ => + val message = + s"Couldn't find corresponding Hive SerDe for data source provider $provider. " + + s"Persisting data source relation $qualifiedTableName into Hive metastore in " + + s"Spark SQL specific format, which is NOT compatible with Hive." 
+ (None, message) + } + + (hiveCompatibleTable, logMessage) match { + case (Some(table), message) => + // We first try to save the metadata of the table in a Hive compatible way. + // If Hive throws an error, we fall back to save its metadata in the Spark SQL + // specific way. + try { + logInfo(message) + client.createTable(table) + } catch { + case throwable: Throwable => + val warningMessage = + s"Could not persist $qualifiedTableName in a Hive compatible way. Persisting " + + s"it into Hive metastore in Spark SQL specific format." + logWarning(warningMessage, throwable) + val sparkSqlSpecificTable = newSparkSQLSpecificMetastoreTable() + client.createTable(sparkSqlSpecificTable) + } + + case (None, message) => + logWarning(message) + val hiveTable = newSparkSQLSpecificMetastoreTable() + client.createTable(hiveTable) + } } - def hiveDefaultTableFilePath(tableName: String): String = { + def hiveDefaultTableFilePath(tableIdent: TableIdentifier): String = { // Code based on: hiveWarehouse.getTablePath(currentDatabase, tableName) - new Path( - new Path(client.getDatabase(client.currentDatabase).location), - tableName.toLowerCase).toString + val QualifiedTableName(dbName, tblName) = getQualifiedTableName(tableIdent) + new Path(new Path(client.getDatabase(dbName).location), tblName).toString } - def tableExists(tableIdentifier: Seq[String]): Boolean = { - val tableIdent = processTableIdentifier(tableIdentifier) - val databaseName = - tableIdent - .lift(tableIdent.size - 2) - .getOrElse(client.currentDatabase) - val tblName = tableIdent.last - client.getTableOption(databaseName, tblName).isDefined + override def tableExists(tableIdent: TableIdentifier): Boolean = { + val QualifiedTableName(dbName, tblName) = getQualifiedTableName(tableIdent) + client.getTableOption(dbName, tblName).isDefined } - def lookupRelation( - tableIdentifier: Seq[String], + override def lookupRelation( + tableIdent: TableIdentifier, alias: Option[String]): LogicalPlan = { - val tableIdent = processTableIdentifier(tableIdentifier) - val databaseName = tableIdent.lift(tableIdent.size - 2).getOrElse( - client.currentDatabase) - val tblName = tableIdent.last - val table = client.getTable(databaseName, tblName) + val qualifiedTableName = getQualifiedTableName(tableIdent) + val table = client.getTable(qualifiedTableName.database, qualifiedTableName.name) if (table.properties.get("spark.sql.sources.provider").isDefined) { - val dataSourceTable = - cachedDataSourceTables(QualifiedTableName(databaseName, tblName).toLowerCase) + val dataSourceTable = cachedDataSourceTables(qualifiedTableName) + val tableWithQualifiers = Subquery(qualifiedTableName.name, dataSourceTable) // Then, if alias is specified, wrap the table with a Subquery using the alias. // Otherwise, wrap the table with a Subquery using the table name. 
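// Illustrative sketch, not from the patched sources: the hunk above first tries to persist the
// table metadata in a Hive compatible layout and, when the metastore client throws, falls back
// to the Spark SQL specific layout. A minimal model of that decision; persistCompatible and
// persistSparkSpecific are hypothetical stand-ins for the two client.createTable calls above.
def persistWithFallback(
    qualifiedTableName: String,
    hiveCompatible: Option[String],      // Some(logMessage) when a Hive compatible table exists
    persistCompatible: () => Unit,
    persistSparkSpecific: () => Unit,
    log: String => Unit): Unit = hiveCompatible match {
  case Some(message) =>
    log(message)
    try {
      persistCompatible()
    } catch {
      case _: Throwable =>
        // Same fallback as above: keep the table, just in the Spark SQL specific format.
        log(s"Could not persist $qualifiedTableName in a Hive compatible way. " +
          "Persisting it into Hive metastore in Spark SQL specific format.")
        persistSparkSpecific()
    }
  case None =>
    persistSparkSpecific()
}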
- val withAlias = - alias.map(a => Subquery(a, dataSourceTable)).getOrElse( - Subquery(tableIdent.last, dataSourceTable)) - - withAlias + alias.map(a => Subquery(a, tableWithQualifiers)).getOrElse(tableWithQualifiers) } else if (table.tableType == VirtualView) { val viewText = table.viewText.getOrElse(sys.error("Invalid view without text.")) alias match { @@ -266,7 +398,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive case Some(aliasText) => Subquery(aliasText, HiveQl.createPlan(viewText)) } } else { - MetastoreRelation(databaseName, tblName, alias)(table)(hive) + MetastoreRelation(qualifiedTableName.database, qualifiedTableName.name, alias)(table)(hive) } } @@ -279,7 +411,12 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // evil case insensitivity issue, which is reconciled within `ParquetRelation`. val parquetOptions = Map( ParquetRelation.METASTORE_SCHEMA -> metastoreSchema.json, - ParquetRelation.MERGE_SCHEMA -> mergeSchema.toString) + ParquetRelation.MERGE_SCHEMA -> mergeSchema.toString, + ParquetRelation.METASTORE_TABLE_NAME -> TableIdentifier( + metastoreRelation.tableName, + Some(metastoreRelation.databaseName) + ).unquotedString + ) val tableIdentifier = QualifiedTableName(metastoreRelation.databaseName, metastoreRelation.tableName) @@ -290,7 +427,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive partitionSpecInMetastore: Option[PartitionSpec]): Option[LogicalRelation] = { cachedDataSourceTables.getIfPresent(tableIdentifier) match { case null => None // Cache miss - case logical @ LogicalRelation(parquetRelation: ParquetRelation) => + case logical @ LogicalRelation(parquetRelation: ParquetRelation, _) => // If we have the same paths, same schema, and same partition spec, // we will use the cached Parquet Relation. val useCached = @@ -324,7 +461,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // are empty. val partitions = metastoreRelation.getHiveQlPartitions().map { p => val location = p.getLocation - val values = InternalRow.fromSeq(p.getValues.zip(partitionColumnDataTypes).map { + val values = InternalRow.fromSeq(p.getValues.asScala.zip(partitionColumnDataTypes).map { case (rawValue, dataType) => Cast(Literal(rawValue), dataType).eval(null) }) ParquetPartition(values, location) @@ -356,7 +493,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive parquetRelation } - result.newInstance() + result.copy(expectedOutputAttributes = Some(metastoreRelation.output)) } override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { @@ -365,91 +502,38 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive client.listTables(db).map(tableName => (tableName, false)) } - protected def processDatabaseAndTableName( - databaseName: Option[String], - tableName: String): (Option[String], String) = { - if (!caseSensitive) { - (databaseName.map(_.toLowerCase), tableName.toLowerCase) - } else { - (databaseName, tableName) - } - } - - protected def processDatabaseAndTableName( - databaseName: String, - tableName: String): (String, String) = { - if (!caseSensitive) { - (databaseName.toLowerCase, tableName.toLowerCase) - } else { - (databaseName, tableName) - } - } - /** * When scanning or writing to non-partitioned Metastore Parquet tables, convert them to Parquet * data source relations for better performance. 
*/ object ParquetConversions extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = { - if (!plan.resolved) { + if (!plan.resolved || plan.analyzed) { return plan } - // Collects all `MetastoreRelation`s which should be replaced - val toBeReplaced = plan.collect { - // Write path - case InsertIntoTable(relation: MetastoreRelation, _, _, _, _) - // Inserting into partitioned table is not supported in Parquet data source (yet). - if !relation.hiveQlTable.isPartitioned && - hive.convertMetastoreParquet && - relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => - val parquetRelation = convertToParquetRelation(relation) - val attributedRewrites = relation.output.zip(parquetRelation.output) - (relation, parquetRelation, attributedRewrites) - + plan transformUp { // Write path - case InsertIntoHiveTable(relation: MetastoreRelation, _, _, _, _) - // Inserting into partitioned table is not supported in Parquet data source (yet). - if !relation.hiveQlTable.isPartitioned && - hive.convertMetastoreParquet && - relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => - val parquetRelation = convertToParquetRelation(relation) - val attributedRewrites = relation.output.zip(parquetRelation.output) - (relation, parquetRelation, attributedRewrites) - - // Read path - case p @ PhysicalOperation(_, _, relation: MetastoreRelation) - if hive.convertMetastoreParquet && - relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => - val parquetRelation = convertToParquetRelation(relation) - val attributedRewrites = relation.output.zip(parquetRelation.output) - (relation, parquetRelation, attributedRewrites) - } - - val relationMap = toBeReplaced.map(r => (r._1, r._2)).toMap - val attributedRewrites = AttributeMap(toBeReplaced.map(_._3).fold(Nil)(_ ++: _)) - - // Replaces all `MetastoreRelation`s with corresponding `ParquetRelation2`s, and fixes - // attribute IDs referenced in other nodes. - plan.transformUp { - case r: MetastoreRelation if relationMap.contains(r) => - val parquetRelation = relationMap(r) - val alias = r.alias.getOrElse(r.tableName) - Subquery(alias, parquetRelation) - case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists) - if relationMap.contains(r) => - val parquetRelation = relationMap(r) + // Inserting into partitioned table is not supported in Parquet data source (yet). + if !r.hiveQlTable.isPartitioned && hive.convertMetastoreParquet && + r.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => + val parquetRelation = convertToParquetRelation(r) InsertIntoTable(parquetRelation, partition, child, overwrite, ifNotExists) + // Write path case InsertIntoHiveTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists) - if relationMap.contains(r) => - val parquetRelation = relationMap(r) + // Inserting into partitioned table is not supported in Parquet data source (yet). 
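// Illustrative sketch, not from the patched sources: the rewrite above replaces the old
// "collect the relations, then patch attribute IDs" approach with a single bottom-up
// transformUp pass whose cases carry their own guards. A toy tree (hypothetical classes,
// not the Catalyst nodes) showing the same shape of rule:
sealed trait ToyPlan {
  def children: Seq[ToyPlan]
  def withChildren(newChildren: Seq[ToyPlan]): ToyPlan
  def transformUp(rule: PartialFunction[ToyPlan, ToyPlan]): ToyPlan = {
    val afterChildren = withChildren(children.map(_.transformUp(rule)))
    rule.applyOrElse(afterChildren, identity[ToyPlan])
  }
}
case class ToyRelation(name: String, parquetBacked: Boolean) extends ToyPlan {
  def children: Seq[ToyPlan] = Nil
  def withChildren(newChildren: Seq[ToyPlan]): ToyPlan = this
}
case class ToyInsert(target: ToyPlan, child: ToyPlan) extends ToyPlan {
  def children: Seq[ToyPlan] = Seq(target, child)
  def withChildren(newChildren: Seq[ToyPlan]): ToyPlan = ToyInsert(newChildren(0), newChildren(1))
}
case class ConvertedRelation(name: String) extends ToyPlan {
  def children: Seq[ToyPlan] = Nil
  def withChildren(newChildren: Seq[ToyPlan]): ToyPlan = this
}
// A guarded case, analogous to "convert only non-partitioned, Parquet-backed relations":
val convertParquet: PartialFunction[ToyPlan, ToyPlan] = {
  case ToyRelation(name, parquetBacked) if parquetBacked => ConvertedRelation(name)
}
// ToyInsert(ToyRelation("t", true), ToyRelation("s", false)).transformUp(convertParquet)
// => ToyInsert(ConvertedRelation("t"), ToyRelation("s", false))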
+ if !r.hiveQlTable.isPartitioned && hive.convertMetastoreParquet && + r.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => + val parquetRelation = convertToParquetRelation(r) InsertIntoTable(parquetRelation, partition, child, overwrite, ifNotExists) - case other => other.transformExpressions { - case a: Attribute if a.resolved => attributedRewrites.getOrElse(a, a) - } + // Read path + case relation: MetastoreRelation if hive.convertMetastoreParquet && + relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => + val parquetRelation = convertToParquetRelation(relation) + Subquery(relation.alias.getOrElse(relation.tableName), parquetRelation) } } } @@ -463,8 +547,29 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // Wait until children are resolved. case p: LogicalPlan if !p.childrenResolved => p case p: LogicalPlan if p.resolved => p + + case CreateViewAsSelect(table, child, allowExisting, replace, sql) => + if (conf.nativeView) { + if (allowExisting && replace) { + throw new AnalysisException( + "It is not allowed to define a view with both IF NOT EXISTS and OR REPLACE.") + } + + val QualifiedTableName(dbName, tblName) = getQualifiedTableName(table) + + execution.CreateViewAsSelect( + table.copy( + specifiedDatabase = Some(dbName), + name = tblName), + child.output, + allowExisting, + replace) + } else { + HiveNativeCommand(sql) + } + case p @ CreateTableAsSelect(table, child, allowExisting) => - val schema = if (table.schema.size > 0) { + val schema = if (table.schema.nonEmpty) { table.schema } else { child.output.map { @@ -487,8 +592,8 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val mode = if (allowExisting) SaveMode.Ignore else SaveMode.ErrorIfExists CreateTableUsingAsSelect( - desc.name, - hive.conf.defaultDataSourceName, + TableIdentifier(desc.name), + conf.defaultDataSourceName, temporary = false, Array.empty[String], mode, @@ -504,9 +609,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive table } - val (dbName, tblName) = - processDatabaseAndTableName( - desc.specifiedDatabase.getOrElse(client.currentDatabase), desc.name) + val QualifiedTableName(dbName, tblName) = getQualifiedTableName(table) execution.CreateTableAsSelect( desc.copy( @@ -564,7 +667,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive * UNIMPLEMENTED: It needs to be decided how we will persist in-memory tables to the metastore. * For now, if this functionality is desired mix in the in-memory [[OverrideCatalog]]. */ - override def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = { + override def registerTable(tableIdent: TableIdentifier, plan: LogicalPlan): Unit = { throw new UnsupportedOperationException } @@ -572,7 +675,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive * UNIMPLEMENTED: It needs to be decided how we will persist in-memory tables to the metastore. * For now, if this functionality is desired mix in the in-memory [[OverrideCatalog]]. 
*/ - override def unregisterTable(tableIdentifier: Seq[String]): Unit = { + override def unregisterTable(tableIdent: TableIdentifier): Unit = { throw new UnsupportedOperationException } @@ -593,7 +696,7 @@ private[hive] case class InsertIntoHiveTable( extends LogicalPlan { override def children: Seq[LogicalPlan] = child :: Nil - override def output: Seq[Attribute] = child.output + override def output: Seq[Attribute] = Seq.empty val numDynamicPartitions = partition.values.count(_.isEmpty) @@ -609,8 +712,8 @@ private[hive] case class InsertIntoHiveTable( private[hive] case class MetastoreRelation (databaseName: String, tableName: String, alias: Option[String]) (val table: HiveTable) - (@transient sqlContext: SQLContext) - extends LeafNode with MultiInstanceRelation { + (@transient private val sqlContext: SQLContext) + extends LeafNode with MultiInstanceRelation with FileRelation { override def equals(other: Any): Boolean = other match { case relation: MetastoreRelation => @@ -640,20 +743,21 @@ private[hive] case class MetastoreRelation val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor() tTable.setSd(sd) - sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment))) + sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)).asJava) tTable.setPartitionKeys( - table.partitionColumns.map(c => new FieldSchema(c.name, c.hiveType, c.comment))) + table.partitionColumns.map(c => new FieldSchema(c.name, c.hiveType, c.comment)).asJava) table.location.foreach(sd.setLocation) table.inputFormat.foreach(sd.setInputFormat) table.outputFormat.foreach(sd.setOutputFormat) val serdeInfo = new org.apache.hadoop.hive.metastore.api.SerDeInfo - sd.setSerdeInfo(serdeInfo) table.serde.foreach(serdeInfo.setSerializationLib) + sd.setSerdeInfo(serdeInfo) + val serdeParameters = new java.util.HashMap[String, String]() - serdeInfo.setParameters(serdeParameters) table.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } + serdeInfo.setParameters(serdeParameters) new Table(tTable) } @@ -693,11 +797,11 @@ private[hive] case class MetastoreRelation val tPartition = new org.apache.hadoop.hive.metastore.api.Partition tPartition.setDbName(databaseName) tPartition.setTableName(tableName) - tPartition.setValues(p.values) + tPartition.setValues(p.values.asJava) val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor() tPartition.setSd(sd) - sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment))) + sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)).asJava) sd.setLocation(p.storage.location) sd.setInputFormat(p.storage.inputFormat) @@ -705,12 +809,13 @@ private[hive] case class MetastoreRelation val serdeInfo = new org.apache.hadoop.hive.metastore.api.SerDeInfo sd.setSerdeInfo(serdeInfo) + // maps and lists should be set only after all elements are ready (see HIVE-7975) serdeInfo.setSerializationLib(p.storage.serde) val serdeParameters = new java.util.HashMap[String, String]() - serdeInfo.setParameters(serdeParameters) table.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } p.storage.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } + serdeInfo.setParameters(serdeParameters) new Partition(hiveQlTable, tPartition) } @@ -758,6 +863,18 @@ private[hive] case class MetastoreRelation /** An attribute map for determining the ordinal for non-partition columns. 
*/ val columnOrdinals = AttributeMap(attributes.zipWithIndex) + override def inputFiles: Array[String] = { + val partLocations = table.getPartitions(Nil).map(_.storage.location).toArray + if (partLocations.nonEmpty) { + partLocations + } else { + Array( + table.location.getOrElse( + sys.error(s"Could not get the location of ${table.qualifiedName}."))) + } + } + + override def newInstance(): MetastoreRelation = { MetastoreRelation(databaseName, tableName, alias)(table)(sqlContext) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index e6df64d2642b..da41b659e3fc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -18,36 +18,40 @@ package org.apache.spark.sql.hive import java.sql.Date +import java.util.Locale + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.serde.serdeConstants -import org.apache.hadoop.hive.ql.{ErrorMsg, Context} -import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, FunctionInfo} +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry} import org.apache.hadoop.hive.ql.lib.Node import org.apache.hadoop.hive.ql.parse._ import org.apache.hadoop.hive.ql.plan.PlanUtils import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.ql.{Context, ErrorMsg} +import org.apache.hadoop.hive.serde.serdeConstants +import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.spark.Logging -import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.{AnalysisException, catalyst} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.{logical, _} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.CurrentOrigin +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.ExplainCommand import org.apache.spark.sql.execution.datasources.DescribeCommand import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.hive.client._ -import org.apache.spark.sql.hive.execution.{HiveNativeCommand, DropTable, AnalyzeTable, HiveScriptIOSchema} +import org.apache.spark.sql.hive.execution.{AnalyzeTable, DropTable, HiveNativeCommand, HiveScriptIOSchema} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.random.RandomSampler -/* Implicit conversions */ -import scala.collection.JavaConversions._ -import scala.collection.mutable.ArrayBuffer - /** * Used when we need to start parsing the AST before deciding that we are going to pass the command * back for Hive to execute natively. 
Will be replaced with a native command that contains the @@ -73,6 +77,16 @@ private[hive] case class CreateTableAsSelect( childrenResolved } +private[hive] case class CreateViewAsSelect( + tableDesc: HiveTable, + child: LogicalPlan, + allowExisting: Boolean, + replace: Boolean, + sql: String) extends UnaryNode with Command { + override def output: Seq[Attribute] = Seq.empty[Attribute] + override lazy val resolved: Boolean = false +} + /** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */ private[hive] object HiveQl extends Logging { protected val nativeCommands = Seq( @@ -80,6 +94,7 @@ private[hive] object HiveQl extends Logging { "TOK_ALTERDATABASE_PROPERTIES", "TOK_ALTERINDEX_PROPERTIES", "TOK_ALTERINDEX_REBUILD", + "TOK_ALTERTABLE", "TOK_ALTERTABLE_ADDCOLS", "TOK_ALTERTABLE_ADDPARTS", "TOK_ALTERTABLE_ALTERPARTS", @@ -103,8 +118,8 @@ private[hive] object HiveQl extends Logging { "TOK_CREATEDATABASE", "TOK_CREATEFUNCTION", "TOK_CREATEINDEX", + "TOK_CREATEMACRO", "TOK_CREATEROLE", - "TOK_CREATEVIEW", "TOK_DESCDATABASE", "TOK_DESCFUNCTION", @@ -112,6 +127,7 @@ private[hive] object HiveQl extends Logging { "TOK_DROPDATABASE", "TOK_DROPFUNCTION", "TOK_DROPINDEX", + "TOK_DROPMACRO", "TOK_DROPROLE", "TOK_DROPTABLE_PROPERTIES", "TOK_DROPVIEW", @@ -196,7 +212,7 @@ private[hive] object HiveQl extends Logging { * Returns a scala.Seq equivalent to [s] or Nil if [s] is null. */ private def nilIfEmpty[A](s: java.util.List[A]): Seq[A] = - Option(s).map(_.toSeq).getOrElse(Nil) + Option(s).map(_.asScala).getOrElse(Nil) /** * Returns this ASTNode with the text changed to `newText`. @@ -211,7 +227,7 @@ private[hive] object HiveQl extends Logging { */ def withChildren(newChildren: Seq[ASTNode]): ASTNode = { (1 to n.getChildCount).foreach(_ => n.deleteChild(0)) - n.addChildren(newChildren) + n.addChildren(newChildren.asJava) n } @@ -248,22 +264,31 @@ private[hive] object HiveQl extends Logging { * Otherwise, there will be Null pointer exception, * when retrieving properties form HiveConf. */ - val hContext = new Context(hiveConf) - val node = ParseUtils.findRootNonNullToken((new ParseDriver).parse(sql, hContext)) + val hContext = createContext() + val node = getAst(sql, hContext) hContext.clear() node } + private def createContext(): Context = new Context(hiveConf) + + private def getAst(sql: String, context: Context) = + ParseUtils.findRootNonNullToken((new ParseDriver).parse(sql, context)) + /** * Returns the HiveConf */ - private[this] def hiveConf(): HiveConf = { - val ss = SessionState.get() // SessionState is lazy initializaion, it can be null here + private[this] def hiveConf: HiveConf = { + var ss = SessionState.get() + // SessionState is lazy initialization, it can be null here if (ss == null) { - new HiveConf() - } else { - ss.getConf + val original = Thread.currentThread().getContextClassLoader + val conf = new HiveConf(classOf[SessionState]) + conf.setClassLoader(original) + ss = new SessionState(conf) + SessionState.start(ss) } + ss.getConf } /** Returns a LogicalPlan for a given HiveQL string. */ @@ -274,15 +299,18 @@ private[hive] object HiveQl extends Logging { /** Creates LogicalPlan for a given HiveQL string. 
*/ def createPlan(sql: String): LogicalPlan = { try { - val tree = getAst(sql) - if (nativeCommands contains tree.getText) { + val context = createContext() + val tree = getAst(sql, context) + val plan = if (nativeCommands contains tree.getText) { HiveNativeCommand(sql) } else { - nodeToPlan(tree) match { + nodeToPlan(tree, context) match { case NativePlaceholder => HiveNativeCommand(sql) case other => other } } + context.clear() + plan } catch { case pe: org.apache.hadoop.hive.ql.parse.ParseException => pe.getMessage match { @@ -317,26 +345,37 @@ private[hive] object HiveQl extends Logging { assert(tree.asInstanceOf[ASTNode].getText == "TOK_CREATETABLE", "Only CREATE TABLE supported.") val tableOps = tree.getChildren val colList = - tableOps + tableOps.asScala .find(_.asInstanceOf[ASTNode].getText == "TOK_TABCOLLIST") .getOrElse(sys.error("No columnList!")).getChildren - colList.map(nodeToAttribute) + colList.asScala.map(nodeToAttribute) } /** Extractor for matching Hive's AST Tokens. */ + private[hive] case class Token(name: String, children: Seq[ASTNode]) extends Node { + def getName(): String = name + def getChildren(): java.util.List[Node] = { + val col = new java.util.ArrayList[Node](children.size) + children.foreach(col.add(_)) + col + } + } object Token { /** @return matches of the form (tokenName, children). */ def unapply(t: Any): Option[(String, Seq[ASTNode])] = t match { case t: ASTNode => CurrentOrigin.setPosition(t.getLine, t.getCharPositionInLine) Some((t.getText, - Option(t.getChildren).map(_.toList).getOrElse(Nil).asInstanceOf[Seq[ASTNode]])) + Option(t.getChildren).map(_.asScala.toList).getOrElse(Nil).asInstanceOf[Seq[ASTNode]])) + case t: Token => Some((t.name, t.children)) case _ => None } } - protected def getClauses(clauseNames: Seq[String], nodeList: Seq[ASTNode]): Seq[Option[Node]] = { + protected def getClauses( + clauseNames: Seq[String], + nodeList: Seq[ASTNode]): Seq[Option[ASTNode]] = { var remainingNodes = nodeList val clauses = clauseNames.map { clauseName => val (matches, nonMatches) = remainingNodes.partition(_.getText.toUpperCase == clauseName) @@ -416,20 +455,12 @@ private[hive] object HiveQl extends Logging { throw new NotImplementedError(s"No parse rules for StructField:\n ${dumpTree(a).toString} ") } - protected def extractDbNameTableName(tableNameParts: Node): (Option[String], String) = { - val (db, tableName) = - tableNameParts.getChildren.map { case Token(part, Nil) => cleanIdentifier(part) } match { - case Seq(tableOnly) => (None, tableOnly) - case Seq(databaseName, table) => (Some(databaseName), table) - } - - (db, tableName) - } - - protected def extractTableIdent(tableNameParts: Node): Seq[String] = { - tableNameParts.getChildren.map { case Token(part, Nil) => cleanIdentifier(part) } match { - case Seq(tableOnly) => Seq(tableOnly) - case Seq(databaseName, table) => Seq(databaseName, table) + protected def extractTableIdent(tableNameParts: Node): TableIdentifier = { + tableNameParts.getChildren.asScala.map { + case Token(part, Nil) => cleanIdentifier(part) + } match { + case Seq(tableOnly) => TableIdentifier(tableOnly) + case Seq(databaseName, table) => TableIdentifier(table, Some(databaseName)) case other => sys.error("Hive only supports tables names like 'tableName' " + s"or 'databaseName.tableName', found '$other'") } @@ -479,7 +510,43 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } } - protected def nodeToPlan(node: Node): LogicalPlan = node match { + private def createView( + view: ASTNode, + context: 
Context, + viewNameParts: ASTNode, + query: ASTNode, + schema: Seq[HiveColumn], + properties: Map[String, String], + allowExist: Boolean, + replace: Boolean): CreateViewAsSelect = { + val TableIdentifier(viewName, dbName) = extractTableIdent(viewNameParts) + + val originalText = context.getTokenRewriteStream + .toString(query.getTokenStartIndex, query.getTokenStopIndex) + + val tableDesc = HiveTable( + specifiedDatabase = dbName, + name = viewName, + schema = schema, + partitionColumns = Seq.empty[HiveColumn], + properties = properties, + serdeProperties = Map[String, String](), + tableType = VirtualView, + location = None, + inputFormat = None, + outputFormat = None, + serde = None, + viewText = Some(originalText)) + + // We need to keep the original SQL string so that if `spark.sql.nativeView` is + // false, we can fall back to use hive native command later. + // We can remove this when parser is configurable(can access SQLConf) in the future. + val sql = context.getTokenRewriteStream + .toString(view.getTokenStartIndex, view.getTokenStopIndex) + CreateViewAsSelect(tableDesc, nodeToPlan(query, context), allowExist, replace, sql) + } + + protected def nodeToPlan(node: ASTNode, context: Context): LogicalPlan = node match { // Special drop table that also uncaches. case Token("TOK_DROPTABLE", Token("TOK_TABNAME", tableNameParts) :: @@ -511,14 +578,14 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val Some(crtTbl) :: _ :: extended :: Nil = getClauses(Seq("TOK_CREATETABLE", "FORMATTED", "EXTENDED"), explainArgs) ExplainCommand( - nodeToPlan(crtTbl), + nodeToPlan(crtTbl, context), extended = extended.isDefined) case Token("TOK_EXPLAIN", explainArgs) => // Ignore FORMATTED if present. val Some(query) :: _ :: extended :: Nil = getClauses(Seq("TOK_QUERY", "FORMATTED", "EXTENDED"), explainArgs) ExplainCommand( - nodeToPlan(query), + nodeToPlan(query, context), extended = extended.isDefined) case Token("TOK_DESCTABLE", describeArgs) => @@ -545,7 +612,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case tableName => // It is describing a table with the format like "describe table". DescribeCommand( - UnresolvedRelation(Seq(tableName.getText), None), isExtended = extended.isDefined) + UnresolvedRelation(TableIdentifier(tableName.getText), None), + isExtended = extended.isDefined) } } // All other cases. @@ -553,6 +621,73 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } } + case view @ Token("TOK_ALTERVIEW", children) => + val Some(viewNameParts) :: maybeQuery :: ignores = + getClauses(Seq( + "TOK_TABNAME", + "TOK_QUERY", + "TOK_ALTERVIEW_ADDPARTS", + "TOK_ALTERVIEW_DROPPARTS", + "TOK_ALTERVIEW_PROPERTIES", + "TOK_ALTERVIEW_RENAME"), children) + + // if ALTER VIEW doesn't have query part, let hive to handle it. + maybeQuery.map { query => + createView(view, context, viewNameParts, query, Nil, Map(), false, true) + }.getOrElse(NativePlaceholder) + + case view @ Token("TOK_CREATEVIEW", children) + if children.collect { case t @ Token("TOK_QUERY", _) => t }.nonEmpty => + val Seq( + Some(viewNameParts), + Some(query), + maybeComment, + replace, + allowExisting, + maybeProperties, + maybeColumns, + maybePartCols + ) = getClauses(Seq( + "TOK_TABNAME", + "TOK_QUERY", + "TOK_TABLECOMMENT", + "TOK_ORREPLACE", + "TOK_IFNOTEXISTS", + "TOK_TABLEPROPERTIES", + "TOK_TABCOLNAME", + "TOK_VIEWPARTCOLS"), children) + + // If the view is partitioned, we let hive handle it. 
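// Illustrative sketch, not from the patched sources: the ALTER VIEW / CREATE VIEW cases above
// lean on getClauses (defined earlier in this file) to pull named clauses out of the child node
// list in a fixed order. A simplified model with strings standing in for ASTNode; the trailing
// "unhandled clauses" check is assumed rather than shown in this hunk:
def getClausesSimplified(clauseNames: Seq[String], nodes: Seq[String]): Seq[Option[String]] = {
  var remainingNodes = nodes
  val clauses = clauseNames.map { clauseName =>
    val (matches, nonMatches) = remainingNodes.partition(_.toUpperCase == clauseName)
    remainingNodes = nonMatches ++ matches.drop(1)
    matches.headOption
  }
  require(remainingNodes.isEmpty, s"Unhandled clauses: ${remainingNodes.mkString(", ")}")
  clauses
}
// getClausesSimplified(Seq("TOK_TABNAME", "TOK_QUERY", "TOK_IFNOTEXISTS"),
//   Seq("TOK_TABNAME", "TOK_QUERY"))
// => Seq(Some("TOK_TABNAME"), Some("TOK_QUERY"), None)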
+ if (maybePartCols.isDefined) { + NativePlaceholder + } else { + val schema = maybeColumns.map { cols => + BaseSemanticAnalyzer.getColumns(cols, true).asScala.map { field => + // We can't specify column types when create view, so fill it with null first, and + // update it after the schema has been resolved later. + HiveColumn(field.getName, null, field.getComment) + } + }.getOrElse(Seq.empty[HiveColumn]) + + val properties = scala.collection.mutable.Map.empty[String, String] + + maybeProperties.foreach { + case Token("TOK_TABLEPROPERTIES", list :: Nil) => + properties ++= getProperties(list) + } + + maybeComment.foreach { + case Token("TOK_TABLECOMMENT", child :: Nil) => + val comment = BaseSemanticAnalyzer.unescapeSQLString(child.getText) + if (comment ne null) { + properties += ("comment" -> comment) + } + } + + createView(view, context, viewNameParts, query, schema, properties.toMap, + allowExisting.isDefined, replace.isDefined) + } + case Token("TOK_CREATETABLE", children) if children.collect { case t @ Token("TOK_QUERY", _) => t }.nonEmpty => // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL @@ -577,23 +712,18 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C "TOK_TABLESKEWED", // Skewed by "TOK_TABLEROWFORMAT", "TOK_TABLESERIALIZER", - "TOK_FILEFORMAT_GENERIC", // For file formats not natively supported by Hive. - "TOK_TBLSEQUENCEFILE", // Stored as SequenceFile - "TOK_TBLTEXTFILE", // Stored as TextFile - "TOK_TBLRCFILE", // Stored as RCFile - "TOK_TBLORCFILE", // Stored as ORC File - "TOK_TBLPARQUETFILE", // Stored as PARQUET + "TOK_FILEFORMAT_GENERIC", "TOK_TABLEFILEFORMAT", // User-provided InputFormat and OutputFormat "TOK_STORAGEHANDLER", // Storage handler "TOK_TABLELOCATION", "TOK_TABLEPROPERTIES"), children) - val (db, tableName) = extractDbNameTableName(tableNameParts) + val TableIdentifier(tblName, dbName) = extractTableIdent(tableNameParts) // TODO add bucket support var tableDesc: HiveTable = HiveTable( - specifiedDatabase = db, - name = tableName, + specifiedDatabase = dbName, + name = tblName, schema = Seq.empty[HiveColumn], partitionColumns = Seq.empty[HiveColumn], properties = Map[String, String](), @@ -605,45 +735,25 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C serde = None, viewText = None) - // default storage type abbriviation (e.g. RCFile, ORC, PARQUET etc.) + // default storage type abbreviation (e.g. RCFile, ORC, PARQUET etc.) 
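// Illustrative sketch, not from the patched sources: the lines below replace a hand-written
// if/else chain over the default storage type with a lookup (HiveSerDe.sourceToSerDe) plus a
// plain-text fallback. A toy version of that mapping, limited to class names that appear in
// this patch:
case class ToySerDe(inputFormat: String, outputFormat: String, serde: Option[String])

def defaultSerDeFor(storageFormat: String): ToySerDe = storageFormat.toLowerCase match {
  case "orc" => ToySerDe(
    "org.apache.hadoop.hive.ql.io.orc.OrcInputFormat",
    "org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat",
    Some("org.apache.hadoop.hive.ql.io.orc.OrcSerde"))
  case "parquet" => ToySerDe(
    "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat",
    "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat",
    Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe"))
  case _ => ToySerDe(   // fallback, mirroring the getOrElse branch in the new code
    "org.apache.hadoop.mapred.TextInputFormat",
    "org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat",
    None)
}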
val defaultStorageType = hiveConf.getVar(HiveConf.ConfVars.HIVEDEFAULTFILEFORMAT) - // handle the default format for the storage type abbriviation - tableDesc = if ("SequenceFile".equalsIgnoreCase(defaultStorageType)) { - tableDesc.copy( - inputFormat = Option("org.apache.hadoop.mapred.SequenceFileInputFormat"), - outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat")) - } else if ("RCFile".equalsIgnoreCase(defaultStorageType)) { - tableDesc.copy( - inputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileOutputFormat"), - serde = Option(hiveConf.getVar(HiveConf.ConfVars.HIVEDEFAULTRCFILESERDE))) - } else if ("ORC".equalsIgnoreCase(defaultStorageType)) { - tableDesc.copy( - inputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat"), - serde = Option("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) - } else if ("PARQUET".equalsIgnoreCase(defaultStorageType)) { - tableDesc.copy( - inputFormat = - Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"), - outputFormat = - Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat"), - serde = - Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) - } else { - tableDesc.copy( - inputFormat = - Option("org.apache.hadoop.mapred.TextInputFormat"), - outputFormat = - Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) - } + // handle the default format for the storage type abbreviation + val hiveSerDe = HiveSerDe.sourceToSerDe(defaultStorageType, hiveConf).getOrElse { + HiveSerDe( + inputFormat = Option("org.apache.hadoop.mapred.TextInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) + } + + hiveSerDe.inputFormat.foreach(f => tableDesc = tableDesc.copy(inputFormat = Some(f))) + hiveSerDe.outputFormat.foreach(f => tableDesc = tableDesc.copy(outputFormat = Some(f))) + hiveSerDe.serde.foreach(f => tableDesc = tableDesc.copy(serde = Some(f))) children.collect { case list @ Token("TOK_TABCOLLIST", _) => val cols = BaseSemanticAnalyzer.getColumns(list, true) if (cols != null) { tableDesc = tableDesc.copy( - schema = cols.map { field => + schema = cols.asScala.map { field => HiveColumn(field.getName, field.getType, field.getComment) }) } @@ -655,7 +765,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val cols = BaseSemanticAnalyzer.getColumns(list(0), false) if (cols != null) { tableDesc = tableDesc.copy( - partitionColumns = cols.map { field => + partitionColumns = cols.asScala.map { field => HiveColumn(field.getName, field.getType, field.getComment) }) } @@ -691,7 +801,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case _ => assert(false) } tableDesc = tableDesc.copy( - serdeProperties = tableDesc.serdeProperties ++ serdeParams) + serdeProperties = tableDesc.serdeProperties ++ serdeParams.asScala) case Token("TOK_TABLELOCATION", child :: Nil) => var location = BaseSemanticAnalyzer.unescapeSQLString(child.getText) location = EximUtil.relativeToAbsolutePath(hiveConf, location) @@ -703,39 +813,66 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val serdeParams = new java.util.HashMap[String, String]() BaseSemanticAnalyzer.readProps( (child.getChild(1).getChild(0)).asInstanceOf[ASTNode], serdeParams) - tableDesc = tableDesc.copy(serdeProperties = 
tableDesc.serdeProperties ++ serdeParams) + tableDesc = tableDesc.copy( + serdeProperties = tableDesc.serdeProperties ++ serdeParams.asScala) } case Token("TOK_FILEFORMAT_GENERIC", child :: Nil) => - throw new SemanticException( - "Unrecognized file format in STORED AS clause:${child.getText}") + child.getText().toLowerCase(Locale.ENGLISH) match { + case "orc" => + tableDesc = tableDesc.copy( + inputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) + if (tableDesc.serde.isEmpty) { + tableDesc = tableDesc.copy( + serde = Option("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) + } - case Token("TOK_TBLRCFILE", Nil) => - tableDesc = tableDesc.copy( - inputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) - if (tableDesc.serde.isEmpty) { - tableDesc = tableDesc.copy( - serde = Option("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) - } + case "parquet" => + tableDesc = tableDesc.copy( + inputFormat = + Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"), + outputFormat = + Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) + if (tableDesc.serde.isEmpty) { + tableDesc = tableDesc.copy( + serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) + } - case Token("TOK_TBLORCFILE", Nil) => - tableDesc = tableDesc.copy( - inputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) - if (tableDesc.serde.isEmpty) { - tableDesc = tableDesc.copy( - serde = Option("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) - } + case "rcfile" => + tableDesc = tableDesc.copy( + inputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) + if (tableDesc.serde.isEmpty) { + tableDesc = tableDesc.copy( + serde = Option("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) + } - case Token("TOK_TBLPARQUETFILE", Nil) => - tableDesc = tableDesc.copy( - inputFormat = - Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"), - outputFormat = - Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) - if (tableDesc.serde.isEmpty) { - tableDesc = tableDesc.copy( - serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) + case "textfile" => + tableDesc = tableDesc.copy( + inputFormat = + Option("org.apache.hadoop.mapred.TextInputFormat"), + outputFormat = + Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) + + case "sequencefile" => + tableDesc = tableDesc.copy( + inputFormat = Option("org.apache.hadoop.mapred.SequenceFileInputFormat"), + outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat")) + + case "avro" => + tableDesc = tableDesc.copy( + inputFormat = + Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat"), + outputFormat = + Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat")) + if (tableDesc.serde.isEmpty) { + tableDesc = tableDesc.copy( + serde = Option("org.apache.hadoop.hive.serde2.avro.AvroSerDe")) + } + + case _ => + throw new SemanticException( + s"Unrecognized file format in STORED AS clause: ${child.getText}") } case Token("TOK_TABLESERIALIZER", @@ -751,7 +888,7 @@ 
https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("TOK_TABLEPROPERTIES", list :: Nil) => tableDesc = tableDesc.copy(properties = tableDesc.properties ++ getProperties(list)) - case list @ Token("TOK_TABLEFILEFORMAT", _) => + case list @ Token("TOK_TABLEFILEFORMAT", children) => tableDesc = tableDesc.copy( inputFormat = Option(BaseSemanticAnalyzer.unescapeSQLString(list.getChild(0).getText)), @@ -762,7 +899,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case _ => // Unsupport features } - CreateTableAsSelect(tableDesc, nodeToPlan(query), allowExisting != None) + CreateTableAsSelect(tableDesc, nodeToPlan(query, context), allowExisting != None) // If its not a "CTAS" like above then take it as a native command case Token("TOK_CREATETABLE", _) => NativePlaceholder @@ -781,7 +918,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C insertClauses.last match { case Token("TOK_CTE", cteClauses) => val cteRelations = cteClauses.map(node => { - val relation = nodeToRelation(node).asInstanceOf[Subquery] + val relation = nodeToRelation(node, context).asInstanceOf[Subquery] (relation.alias, relation) }).toMap (Some(args.head), insertClauses.init, Some(cteRelations)) @@ -835,12 +972,12 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } val relations = fromClause match { - case Some(f) => nodeToRelation(f) + case Some(f) => nodeToRelation(f, context) case None => OneRowRelation } val withWhere = whereClause.map { whereNode => - val Seq(whereExpr) = whereNode.getChildren.toSeq + val Seq(whereExpr) = whereNode.getChildren.asScala Filter(nodeToExpr(whereExpr), relations) }.getOrElse(relations) @@ -849,7 +986,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // Script transformations are expressed as a select clause with a single expression of type // TOK_TRANSFORM - val transformation = select.getChildren.head match { + val transformation = select.getChildren.iterator().next() match { case Token("TOK_SELEXPR", Token("TOK_TRANSFORM", Token("TOK_EXPLIST", inputExprs) :: @@ -873,38 +1010,72 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C AttributeReference("value", StringType)()), true) } - def matchSerDe(clause: Seq[ASTNode]) - : (Seq[(String, String)], Option[String], Seq[(String, String)]) = clause match { + type SerDeInfo = ( + Seq[(String, String)], // Input row format information + Option[String], // Optional input SerDe class + Seq[(String, String)], // Input SerDe properties + Boolean // Whether to use default record reader/writer + ) + + def matchSerDe(clause: Seq[ASTNode]): SerDeInfo = clause match { case Token("TOK_SERDEPROPS", propsClause) :: Nil => val rowFormat = propsClause.map { case Token(name, Token(value, Nil) :: Nil) => (name, value) } - (rowFormat, None, Nil) + (rowFormat, None, Nil, false) case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Nil) :: Nil => - (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), Nil) + (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), Nil, false) case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Token("TOK_TABLEPROPERTIES", Token("TOK_TABLEPROPLIST", propsClause) :: Nil) :: Nil) :: Nil => val serdeProps = propsClause.map { case Token("TOK_TABLEPROPERTY", Token(name, Nil) :: Token(value, Nil) :: Nil) => - (name, value) + (BaseSemanticAnalyzer.unescapeSQLString(name), + 
BaseSemanticAnalyzer.unescapeSQLString(value)) } - (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), serdeProps) - case Nil => (Nil, None, Nil) + // SPARK-10310: Special cases LazySimpleSerDe + // TODO Fully supports user-defined record reader/writer classes + val unescapedSerDeClass = BaseSemanticAnalyzer.unescapeSQLString(serdeClass) + val useDefaultRecordReaderWriter = + unescapedSerDeClass == classOf[LazySimpleSerDe].getCanonicalName + (Nil, Some(unescapedSerDeClass), serdeProps, useDefaultRecordReaderWriter) + + case Nil => + // Uses default TextRecordReader/TextRecordWriter, sets field delimiter here + val serdeProps = Seq(serdeConstants.FIELD_DELIM -> "\t") + (Nil, Option(hiveConf.getVar(ConfVars.HIVESCRIPTSERDE)), serdeProps, true) } - val (inRowFormat, inSerdeClass, inSerdeProps) = matchSerDe(inputSerdeClause) - val (outRowFormat, outSerdeClass, outSerdeProps) = matchSerDe(outputSerdeClause) + val (inRowFormat, inSerdeClass, inSerdeProps, useDefaultRecordReader) = + matchSerDe(inputSerdeClause) + + val (outRowFormat, outSerdeClass, outSerdeProps, useDefaultRecordWriter) = + matchSerDe(outputSerdeClause) val unescapedScript = BaseSemanticAnalyzer.unescapeSQLString(script) + // TODO Adds support for user-defined record reader/writer classes + val recordReaderClass = if (useDefaultRecordReader) { + Option(hiveConf.getVar(ConfVars.HIVESCRIPTRECORDREADER)) + } else { + None + } + + val recordWriterClass = if (useDefaultRecordWriter) { + Option(hiveConf.getVar(ConfVars.HIVESCRIPTRECORDWRITER)) + } else { + None + } + val schema = HiveScriptIOSchema( inRowFormat, outRowFormat, inSerdeClass, outSerdeClass, - inSerdeProps, outSerdeProps, schemaLess) + inSerdeProps, outSerdeProps, + recordReaderClass, recordWriterClass, + schemaLess) Some( logical.ScriptTransformation( @@ -917,10 +1088,10 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val withLateralView = lateralViewClause.map { lv => val Token("TOK_SELECT", - Token("TOK_SELEXPR", clauses) :: Nil) = lv.getChildren.head + Token("TOK_SELEXPR", clauses) :: Nil) = lv.getChildren.iterator().next() - val alias = - getClause("TOK_TABALIAS", clauses).getChildren.head.asInstanceOf[ASTNode].getText + val alias = getClause("TOK_TABALIAS", clauses).getChildren.iterator().next() + .asInstanceOf[ASTNode].getText val (generator, attributes) = nodesToGenerator(clauses) Generate( @@ -936,7 +1107,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // (if there is a group by) or a script transformation. val withProject: LogicalPlan = transformation.getOrElse { val selectExpressions = - select.getChildren.flatMap(selExprNodeToExpr).map(UnresolvedAlias(_)).toSeq + select.getChildren.asScala.flatMap(selExprNodeToExpr).map(UnresolvedAlias) Seq( groupByClause.map(e => e match { case Token("TOK_GROUPBY", children) => @@ -965,7 +1136,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // Handle HAVING clause. val withHaving = havingClause.map { h => - val havingExpr = h.getChildren.toSeq match { case Seq(hexpr) => nodeToExpr(hexpr) } + val havingExpr = h.getChildren.asScala match { case Seq(hexpr) => nodeToExpr(hexpr) } // Note that we added a cast to boolean. If the expression itself is already boolean, // the optimizer will get rid of the unnecessary cast. 
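// Illustrative sketch, not from the patched sources: the SPARK-10310 special case above keeps
// Hive's default record reader/writer only when the script transform's SerDe is LazySimpleSerDe
// (or when no SERDE clause is given at all). A minimal model of that decision:
val lazySimpleSerDeClass = "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"

def useDefaultRecordReaderWriter(explicitSerDe: Option[String]): Boolean = explicitSerDe match {
  case Some(serdeClass) => serdeClass == lazySimpleSerDeClass
  case None => true   // no SERDE clause: default TextRecordReader/TextRecordWriter apply
}
// useDefaultRecordReaderWriter(None)                        => true
// useDefaultRecordReaderWriter(Some(lazySimpleSerDeClass))  => true
// useDefaultRecordReaderWriter(Some("some.custom.SerDe"))   => false  (hypothetical class name)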
Filter(Cast(havingExpr, BooleanType), withProject) @@ -975,32 +1146,42 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val withDistinct = if (selectDistinctClause.isDefined) Distinct(withHaving) else withHaving - // Handle ORDER BY, SORT BY, DISTRIBETU BY, and CLUSTER BY clause. + // Handle ORDER BY, SORT BY, DISTRIBUTE BY, and CLUSTER BY clause. val withSort = (orderByClause, sortByClause, distributeByClause, clusterByClause) match { case (Some(totalOrdering), None, None, None) => - Sort(totalOrdering.getChildren.map(nodeToSortOrder), true, withDistinct) + Sort(totalOrdering.getChildren.asScala.map(nodeToSortOrder), true, withDistinct) case (None, Some(perPartitionOrdering), None, None) => - Sort(perPartitionOrdering.getChildren.map(nodeToSortOrder), false, withDistinct) + Sort( + perPartitionOrdering.getChildren.asScala.map(nodeToSortOrder), + false, withDistinct) case (None, None, Some(partitionExprs), None) => - RepartitionByExpression(partitionExprs.getChildren.map(nodeToExpr), withDistinct) + RepartitionByExpression( + partitionExprs.getChildren.asScala.map(nodeToExpr), withDistinct) case (None, Some(perPartitionOrdering), Some(partitionExprs), None) => - Sort(perPartitionOrdering.getChildren.map(nodeToSortOrder), false, - RepartitionByExpression(partitionExprs.getChildren.map(nodeToExpr), withDistinct)) + Sort( + perPartitionOrdering.getChildren.asScala.map(nodeToSortOrder), false, + RepartitionByExpression( + partitionExprs.getChildren.asScala.map(nodeToExpr), + withDistinct)) case (None, None, None, Some(clusterExprs)) => - Sort(clusterExprs.getChildren.map(nodeToExpr).map(SortOrder(_, Ascending)), false, - RepartitionByExpression(clusterExprs.getChildren.map(nodeToExpr), withDistinct)) + Sort( + clusterExprs.getChildren.asScala.map(nodeToExpr).map(SortOrder(_, Ascending)), + false, + RepartitionByExpression( + clusterExprs.getChildren.asScala.map(nodeToExpr), + withDistinct)) case (None, None, None, None) => withDistinct case _ => sys.error("Unsupported set of ordering / distribution clauses.") } val withLimit = - limitClause.map(l => nodeToExpr(l.getChildren.head)) + limitClause.map(l => nodeToExpr(l.getChildren.iterator().next())) .map(Limit(_, withSort)) .getOrElse(withSort) // Collect all window specifications defined in the WINDOW clause. 
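// Illustrative sketch, not from the patched sources: the ordering match above only accepts
// certain combinations of ORDER BY, SORT BY, DISTRIBUTE BY and CLUSTER BY, with CLUSTER BY
// behaving like DISTRIBUTE BY plus an ascending per-partition sort. A toy summary of the plan
// shapes produced (strings instead of the Catalyst operators):
def describeOrdering(
    orderBy: Boolean, sortBy: Boolean, distributeBy: Boolean, clusterBy: Boolean): String =
  (orderBy, sortBy, distributeBy, clusterBy) match {
    case (true, false, false, false) => "global Sort"
    case (false, true, false, false) => "per-partition Sort"
    case (false, false, true, false) => "RepartitionByExpression"
    case (false, true, true, false) => "per-partition Sort over RepartitionByExpression"
    case (false, false, false, true) => "ascending per-partition Sort over RepartitionByExpression"
    case (false, false, false, false) => "no ordering"
    case _ => sys.error("Unsupported set of ordering / distribution clauses.")
  }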
- val windowDefinitions = windowClause.map(_.getChildren.toSeq.collect { + val windowDefinitions = windowClause.map(_.getChildren.asScala.collect { case Token("TOK_WINDOWDEF", Token(windowName, Nil) :: Token("TOK_WINDOWSPEC", spec) :: Nil) => windowName -> nodesToWindowSpecification(spec) @@ -1037,24 +1218,27 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // return With plan if there is CTE cteRelations.map(With(query, _)).getOrElse(query) - case Token("TOK_UNION", left :: right :: Nil) => Union(nodeToPlan(left), nodeToPlan(right)) + // HIVE-9039 renamed TOK_UNION => TOK_UNIONALL while adding TOK_UNIONDISTINCT + case Token("TOK_UNIONALL", left :: right :: Nil) => + Union(nodeToPlan(left, context), nodeToPlan(right, context)) case a: ASTNode => - throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") + throw new NotImplementedError(s"No parse rules for $node:\n ${dumpTree(a).toString} ") } val allJoinTokens = "(TOK_.*JOIN)".r val laterViewToken = "TOK_LATERAL_VIEW(.*)".r - def nodeToRelation(node: Node): LogicalPlan = node match { + def nodeToRelation(node: Node, context: Context): LogicalPlan = node match { case Token("TOK_SUBQUERY", query :: Token(alias, Nil) :: Nil) => - Subquery(cleanIdentifier(alias), nodeToPlan(query)) + Subquery(cleanIdentifier(alias), nodeToPlan(query, context)) case Token(laterViewToken(isOuter), selectClause :: relationClause :: Nil) => val Token("TOK_SELECT", Token("TOK_SELEXPR", clauses) :: Nil) = selectClause - val alias = getClause("TOK_TABALIAS", clauses).getChildren.head.asInstanceOf[ASTNode].getText + val alias = getClause("TOK_TABALIAS", clauses).getChildren.iterator().next() + .asInstanceOf[ASTNode].getText val (generator, attributes) = nodesToGenerator(clauses) Generate( @@ -1063,7 +1247,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C outer = isOuter.nonEmpty, Some(alias.toLowerCase), attributes.map(UnresolvedAttribute(_)), - nodeToRelation(relationClause)) + nodeToRelation(relationClause, context)) /* All relations, possibly with aliases or sampling clauses. 
*/ case Token("TOK_TABREF", clauses) => @@ -1082,13 +1266,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C nonAliasClauses) } - val tableIdent = - tableNameParts.getChildren.map{ case Token(part, Nil) => cleanIdentifier(part)} match { - case Seq(tableOnly) => Seq(tableOnly) - case Seq(databaseName, table) => Seq(databaseName, table) - case other => sys.error("Hive only supports tables names like 'tableName' " + - s"or 'databaseName.tableName', found '$other'") - } + val tableIdent = extractTableIdent(tableNameParts) val alias = aliasClause.map { case Token(a, Nil) => cleanIdentifier(a) } val relation = UnresolvedRelation(tableIdent, alias) @@ -1129,8 +1307,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C }.map(_._2) val isPreserved = tableOrdinals.map(i => (i - 1 < 0) || joinArgs(i - 1).getText == "PRESERVE") - val tables = tableOrdinals.map(i => nodeToRelation(joinArgs(i))) - val joinExpressions = tableOrdinals.map(i => joinArgs(i + 1).getChildren.map(nodeToExpr)) + val tables = tableOrdinals.map(i => nodeToRelation(joinArgs(i), context)) + val joinExpressions = + tableOrdinals.map(i => joinArgs(i + 1).getChildren.asScala.map(nodeToExpr)) val joinConditions = joinExpressions.sliding(2).map { case Seq(c1, c2) => @@ -1155,7 +1334,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C joinType = joinType.remove(joinType.length - 1)) } - val groups = (0 until joinExpressions.head.size).map(i => Coalesce(joinExpressions.map(_(i)))) + val groups = joinExpressions.head.indices.map(i => Coalesce(joinExpressions.map(_(i)))) // Unique join is not really the same as an outer join so we must group together results where // the joinExpressions are the same, taking the First of each value is only okay because the @@ -1183,8 +1362,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case "TOK_FULLOUTERJOIN" => FullOuter case "TOK_LEFTSEMIJOIN" => LeftSemi } - Join(nodeToRelation(relation1), - nodeToRelation(relation2), + Join(nodeToRelation(relation1, context), + nodeToRelation(relation2, context), joinType, other.headOption.map(nodeToExpr)) @@ -1220,7 +1399,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val tableIdent = extractTableIdent(tableNameParts) - val partitionKeys = partitionClause.map(_.getChildren.map { + val partitionKeys = partitionClause.map(_.getChildren.asScala.map { // Parse partitions. We also make keys case insensitive. case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) => cleanIdentifier(key.toLowerCase) -> Some(PlanUtils.stripQuotes(value)) @@ -1240,7 +1419,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val tableIdent = extractTableIdent(tableNameParts) - val partitionKeys = partitionClause.map(_.getChildren.map { + val partitionKeys = partitionClause.map(_.getChildren.asScala.map { // Parse partitions. We also make keys case insensitive. 
case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) => cleanIdentifier(key.toLowerCase) -> Some(PlanUtils.stripQuotes(value)) @@ -1251,7 +1430,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C InsertIntoTable(UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite, true) case a: ASTNode => - throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") + throw new NotImplementedError(s"No parse rules for ${a.getName}:" + + s"\n ${dumpTree(a).toString} ") } protected def selExprNodeToExpr(node: Node): Option[Expression] = node match { @@ -1274,7 +1454,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("TOK_HINTLIST", _) => None case a: ASTNode => - throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") + throw new NotImplementedError(s"No parse rules for ${a.getName }:" + + s"\n ${dumpTree(a).toString } ") } protected val escapedIdentifier = "`([^`]+)`".r @@ -1334,12 +1515,13 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // The format of dbName.tableName.* cannot be parsed by HiveParser. TOK_TABNAME will only // has a single child which is tableName. case Token("TOK_ALLCOLREF", Token("TOK_TABNAME", Token(name, Nil) :: Nil) :: Nil) => - UnresolvedStar(Some(name)) + UnresolvedStar(Some(UnresolvedAttribute.parseAttributeName(name))) /* Aggregate Functions */ - case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) => Count(Literal(1)) - case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) => CountDistinct(args.map(nodeToExpr)) - case Token("TOK_FUNCTIONDI", Token(SUM(), Nil) :: arg :: Nil) => SumDistinct(nodeToExpr(arg)) + case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) => + Count(args.map(nodeToExpr)).toAggregateExpression(isDistinct = true) + case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) => + Count(Literal(1)).toAggregateExpression() /* Casts */ case Token("TOK_FUNCTION", Token("TOK_STRING", Nil) :: arg :: Nil) => @@ -1444,17 +1626,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C UnresolvedExtractValue(nodeToExpr(child), nodeToExpr(ordinal)) /* Window Functions */ - case Token("TOK_FUNCTION", Token(name, Nil) +: args :+ Token("TOK_WINDOWSPEC", spec)) => - val function = UnresolvedWindowFunction(name, args.map(nodeToExpr)) - nodesToWindowSpecification(spec) match { - case reference: WindowSpecReference => - UnresolvedWindowExpression(function, reference) - case definition: WindowSpecDefinition => - WindowExpression(function, definition) - } - case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: Token("TOK_WINDOWSPEC", spec) :: Nil) => - // Safe to use Literal(1)? 
- val function = UnresolvedWindowFunction(name, Literal(1) :: Nil) + case Token(name, args :+ Token("TOK_WINDOWSPEC", spec)) => + val function = nodeToExpr(Token(name, args)) nodesToWindowSpecification(spec) match { case reference: WindowSpecReference => UnresolvedWindowExpression(function, reference) @@ -1520,6 +1693,30 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case ast: ASTNode if ast.getType == HiveParser.TOK_CHARSETLITERAL => Literal(BaseSemanticAnalyzer.charSetString(ast.getChild(0).getText, ast.getChild(1).getText)) + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_YEAR_MONTH_LITERAL => + Literal(CalendarInterval.fromYearMonthString(ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_DAY_TIME_LITERAL => + Literal(CalendarInterval.fromDayTimeString(ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_YEAR_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("year", ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_MONTH_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("month", ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_DAY_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("day", ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_HOUR_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("hour", ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_MINUTE_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("minute", ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_SECOND_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("second", ast.getText)) + case a: ASTNode => throw new NotImplementedError( s"""No parse rules for ASTNode type: ${a.getType}, text: ${a.getText} : @@ -1555,18 +1752,18 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val (partitionByClause :: orderByClause :: sortByClause :: clusterByClause :: Nil) = getClauses( Seq("TOK_DISTRIBUTEBY", "TOK_ORDERBY", "TOK_SORTBY", "TOK_CLUSTERBY"), - partitionAndOrdering.getChildren.toSeq.asInstanceOf[Seq[ASTNode]]) + partitionAndOrdering.getChildren.asScala.asInstanceOf[Seq[ASTNode]]) (partitionByClause, orderByClause.orElse(sortByClause), clusterByClause) match { case (Some(partitionByExpr), Some(orderByExpr), None) => - (partitionByExpr.getChildren.map(nodeToExpr), - orderByExpr.getChildren.map(nodeToSortOrder)) + (partitionByExpr.getChildren.asScala.map(nodeToExpr), + orderByExpr.getChildren.asScala.map(nodeToSortOrder)) case (Some(partitionByExpr), None, None) => - (partitionByExpr.getChildren.map(nodeToExpr), Nil) + (partitionByExpr.getChildren.asScala.map(nodeToExpr), Nil) case (None, Some(orderByExpr), None) => - (Nil, orderByExpr.getChildren.map(nodeToSortOrder)) + (Nil, orderByExpr.getChildren.asScala.map(nodeToSortOrder)) case (None, None, Some(clusterByExpr)) => - val expressions = clusterByExpr.getChildren.map(nodeToExpr) + val expressions = clusterByExpr.getChildren.asScala.map(nodeToExpr) (expressions, expressions.map(SortOrder(_, Ascending))) case _ => throw new NotImplementedError( @@ -1604,7 +1801,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } rowFrame.orElse(rangeFrame).map { frame => - frame.getChildren.toList match { + frame.getChildren.asScala.toList match { case precedingNode :: followingNode :: Nil => SpecifiedWindowFrame( frameType, @@ -1624,6 
+1821,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } val explode = "(?i)explode".r + val jsonTuple = "(?i)json_tuple".r def nodesToGenerator(nodes: Seq[Node]): (Generator, Seq[String]) = { val function = nodes.head @@ -1636,6 +1834,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("TOK_FUNCTION", Token(explode(), Nil) :: child :: Nil) => (Explode(nodeToExpr(child)), attributes) + case Token("TOK_FUNCTION", Token(jsonTuple(), Nil) :: children) => + (JsonTuple(children.map(nodeToExpr)), attributes) + case Token("TOK_FUNCTION", Token(functionName, Nil) :: children) => val functionInfo: FunctionInfo = Option(FunctionRegistry.getFunctionInfo(functionName.toLowerCase)).getOrElse( @@ -1666,7 +1867,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case other => sys.error(s"Non ASTNode encountered: $other") } - Option(node.getChildren).map(_.toList).getOrElse(Nil).foreach(dumpTree(_, builder, indent + 1)) + Option(node.getChildren).map(_.asScala).getOrElse(Nil).foreach(dumpTree(_, builder, indent + 1)) builder } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala index a357bb39ca7f..f0697613cff3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala @@ -20,8 +20,9 @@ package org.apache.spark.sql.hive import java.io.{InputStream, OutputStream} import java.rmi.server.UID -/* Implicit conversions */ -import scala.collection.JavaConversions._ +import org.apache.avro.Schema + +import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.ClassTag @@ -33,7 +34,7 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.ql.exec.{UDF, Utilities} import org.apache.hadoop.hive.ql.plan.{FileSinkDesc, TableDesc} import org.apache.hadoop.hive.serde2.ColumnProjectionUtils -import org.apache.hadoop.hive.serde2.avro.AvroGenericRecordWritable +import org.apache.hadoop.hive.serde2.avro.{AvroGenericRecordWritable, AvroSerdeUtils} import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveDecimalObjectInspector import org.apache.hadoop.io.Writable @@ -71,7 +72,7 @@ private[hive] object HiveShim { */ def appendReadColumns(conf: Configuration, ids: Seq[Integer], names: Seq[String]) { if (ids != null && ids.nonEmpty) { - ColumnProjectionUtils.appendReadColumns(conf, ids) + ColumnProjectionUtils.appendReadColumns(conf, ids.asJava) } if (names != null && names.nonEmpty) { appendReadColumnNames(conf, names) @@ -82,10 +83,19 @@ private[hive] object HiveShim { * Bug introduced in hive-0.13. AvroGenericRecordWritable has a member recordReaderID that * is needed to initialize before serialization. */ - def prepareWritable(w: Writable): Writable = { + def prepareWritable(w: Writable, serDeProps: Seq[(String, String)]): Writable = { w match { case w: AvroGenericRecordWritable => w.setRecordReaderID(new UID()) + // In Hive 1.1, the record's schema may need to be initialized manually or a NPE will + // be thrown. 
+ if (w.getFileSchema() == null) { + serDeProps + .find(_._1 == AvroSerdeUtils.AvroTableProperties.SCHEMA_LITERAL.getPropName()) + .foreach { kv => + w.setFileSchema(new Schema.Parser().parse(kv._2)) + } + } case _ => } w @@ -107,9 +117,10 @@ private[hive] object HiveShim { * Detail discussion can be found at https://github.com/apache/spark/pull/3640 * * @param functionClassName UDF class name + * @param instance optional UDF instance which contains additional information (for macro) */ - private[hive] case class HiveFunctionWrapper(var functionClassName: String) - extends java.io.Externalizable { + private[hive] case class HiveFunctionWrapper(var functionClassName: String, + private var instance: AnyRef = null) extends java.io.Externalizable { // for Serialization def this() = this(null) @@ -144,8 +155,6 @@ private[hive] object HiveShim { serializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), function, out) } - private var instance: AnyRef = null - def writeExternal(out: java.io.ObjectOutput) { // output the function name out.writeUTF(functionClassName) @@ -174,7 +183,7 @@ private[hive] object HiveShim { // read the function in bytes val functionInBytesLength = in.readInt() val functionInBytes = new Array[Byte](functionInBytesLength) - in.read(functionInBytes, 0, functionInBytesLength) + in.readFully(functionInBytes) // deserialize the function object via Hive Utilities instance = deserializePlan[AnyRef](new java.io.ByteArrayInputStream(functionInBytes), diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index cd6cd322c94e..d38ad9127327 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -83,14 +83,16 @@ private[hive] trait HiveStrategies { object HiveDDLStrategy extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case CreateTableUsing( - tableName, userSpecifiedSchema, provider, false, opts, allowExisting, managedIfNoPath) => - ExecutedCommand( + tableIdent, userSpecifiedSchema, provider, false, opts, allowExisting, managedIfNoPath) => + val cmd = CreateMetastoreDataSource( - tableName, userSpecifiedSchema, provider, opts, allowExisting, managedIfNoPath)) :: Nil + tableIdent, userSpecifiedSchema, provider, opts, allowExisting, managedIfNoPath) + ExecutedCommand(cmd) :: Nil - case CreateTableUsingAsSelect(tableName, provider, false, partitionCols, mode, opts, query) => + case CreateTableUsingAsSelect( + tableIdent, provider, false, partitionCols, mode, opts, query) => val cmd = - CreateMetastoreDataSourceAsSelect(tableName, provider, partitionCols, mode, opts, query) + CreateMetastoreDataSourceAsSelect(tableIdent, provider, partitionCols, mode, opts, query) ExecutedCommand(cmd) :: Nil case _ => Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index dc355690852b..70ee02823eeb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql.hive +import java.util + import org.apache.hadoop.fs.{Path, PathFilter} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.api.hive_metastoreConstants._ import org.apache.hadoop.hive.ql.exec.Utilities -import 
org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable} -import org.apache.hadoop.hive.ql.plan.{PlanUtils, TableDesc} +import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable, Hive, HiveUtils, HiveStorageHandler} +import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.hadoop.hive.serde2.Deserializer import org.apache.hadoop.hive.serde2.objectinspector.primitive._ import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, StructObjectInspector} @@ -54,10 +56,10 @@ private[hive] sealed trait TableReader { */ private[hive] class HadoopTableReader( - @transient attributes: Seq[Attribute], - @transient relation: MetastoreRelation, - @transient sc: HiveContext, - @transient hiveExtraConf: HiveConf) + @transient private val attributes: Seq[Attribute], + @transient private val relation: MetastoreRelation, + @transient private val sc: HiveContext, + hiveExtraConf: HiveConf) extends TableReader with Logging { // Hadoop honors "mapred.map.tasks" as hint, but will ignore when mapred.job.tracker is "local". @@ -287,6 +289,29 @@ class HadoopTableReader( } } +private[hive] object HiveTableUtil { + + // copied from PlanUtils.configureJobPropertiesForStorageHandler(tableDesc) + // that calls Hive.get() which tries to access metastore, but it's not valid in runtime + // it would be fixed in next version of hive but till then, we should use this instead + def configureJobPropertiesForStorageHandler( + tableDesc: TableDesc, jobConf: JobConf, input: Boolean) { + val property = tableDesc.getProperties.getProperty(META_TABLE_STORAGE) + val storageHandler = HiveUtils.getStorageHandler(jobConf, property) + if (storageHandler != null) { + val jobProperties = new util.LinkedHashMap[String, String] + if (input) { + storageHandler.configureInputJobProperties(tableDesc, jobProperties) + } else { + storageHandler.configureOutputJobProperties(tableDesc, jobProperties) + } + if (!jobProperties.isEmpty) { + tableDesc.setJobProperties(jobProperties) + } + } + } +} + private[hive] object HadoopTableReader extends HiveInspectors with Logging { /** * Curried. 
After given an argument for 'path', the resulting JobConf => Unit closure is used to @@ -295,7 +320,7 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { def initializeLocalJobConfFunc(path: String, tableDesc: TableDesc)(jobConf: JobConf) { FileInputFormat.setInputPaths(jobConf, Seq[Path](new Path(path)): _*) if (tableDesc != null) { - PlanUtils.configureInputJobPropertiesForStorageHandler(tableDesc) + HiveTableUtil.configureJobPropertiesForStorageHandler(tableDesc, jobConf, true) Utilities.copyTableJobPropertiesToConf(tableDesc, jobConf) } val bufferSize = System.getProperty("spark.buffer.size", "65536") @@ -357,6 +382,9 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { case oi: HiveVarcharObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => row.update(ordinal, UTF8String.fromString(oi.getPrimitiveJavaObject(value).getValue)) + case oi: HiveCharObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => + row.update(ordinal, UTF8String.fromString(oi.getPrimitiveJavaObject(value).getValue)) case oi: HiveDecimalObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => row.update(ordinal, HiveShim.toCatalystDecimal(oi, value)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala index d834b4e83e04..9d9a55edd731 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala @@ -19,13 +19,12 @@ package org.apache.spark.sql.hive.client import java.io.PrintStream import java.util.{Map => JMap} +import javax.annotation.Nullable import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchTableException} import org.apache.spark.sql.catalyst.expressions.Expression -private[hive] case class HiveDatabase( - name: String, - location: String) +private[hive] case class HiveDatabase(name: String, location: String) private[hive] abstract class TableType { val name: String } private[hive] case object ExternalTable extends TableType { override val name = "EXTERNAL_TABLE" } @@ -45,7 +44,7 @@ private[hive] case class HivePartition( values: Seq[String], storage: HiveStorageDescriptor) -private[hive] case class HiveColumn(name: String, hiveType: String, comment: String) +private[hive] case class HiveColumn(name: String, @Nullable hiveType: String, comment: String) private[hive] case class HiveTable( specifiedDatabase: Option[String], name: String, @@ -87,6 +86,13 @@ private[hive] case class HiveTable( * shared classes. */ private[hive] trait ClientInterface { + + /** Returns the Hive Version of this client. */ + def version: HiveVersion + + /** Returns the configuration for the given key in the current session. */ + def getConf(key: String, defaultValue: String): String + /** * Runs a HiveQL command using Hive, returning the results as a list of strings. Each row will * result in one string. @@ -119,6 +125,12 @@ private[hive] trait ClientInterface { /** Returns the metadata for the specified table or None if it doens't exist. */ def getTableOption(dbName: String, tableName: String): Option[HiveTable] + /** Creates a view with the given metadata. */ + def createView(view: HiveTable): Unit + + /** Updates the given view with new metadata. */ + def alertView(view: HiveTable): Unit + /** Creates a table with the given metadata. 
*/ def createTable(table: HiveTable): Unit @@ -166,6 +178,15 @@ private[hive] trait ClientInterface { holdDDLTime: Boolean, listBucketingEnabled: Boolean): Unit + /** Add a jar into class loader */ + def addJar(path: String): Unit + + /** Return a ClientInterface as new session, that will share the class loader and Hive client */ + def newSession(): ClientInterface + + /** Run a function within Hive state (SessionState, HiveConf, Hive client and class loader) */ + def withHiveState[A](f: => A): A + /** Used for testing only. Removes all metadata from this instance of Hive. */ def reset(): Unit } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index 6e0912da5862..598ccdeee4ad 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -19,9 +19,8 @@ package org.apache.spark.sql.hive.client import java.io.{File, PrintStream} import java.util.{Map => JMap} -import javax.annotation.concurrent.GuardedBy -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.language.reflectiveCalls import org.apache.hadoop.fs.Path @@ -32,13 +31,15 @@ import org.apache.hadoop.hive.ql.metadata.Hive import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.ql.{Driver, metadata} +import org.apache.hadoop.hive.shims.{HadoopShims, ShimLoader} +import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.util.VersionInfo -import org.apache.spark.Logging +import org.apache.spark.{SparkConf, SparkException, Logging} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.util.{CircularBuffer, Utils} - /** * A class that wraps the HiveClient and converts its responses to externally visible classes. * Note that this class is typically loaded with an internal classloader for each instantiation, @@ -57,12 +58,81 @@ import org.apache.spark.util.{CircularBuffer, Utils} * this ClientWrapper. */ private[hive] class ClientWrapper( - version: HiveVersion, + override val version: HiveVersion, config: Map[String, String], - initClassLoader: ClassLoader) + initClassLoader: ClassLoader, + val clientLoader: IsolatedClientLoader) extends ClientInterface with Logging { + overrideHadoopShims() + + // !! HACK ALERT !! + // + // Internally, Hive `ShimLoader` tries to load different versions of Hadoop shims by checking + // major version number gathered from Hadoop jar files: + // + // - For major version number 1, load `Hadoop20SShims`, where "20S" stands for Hadoop 0.20 with + // security. + // - For major version number 2, load `Hadoop23Shims`, where "23" stands for Hadoop 0.23. + // + // However, APIs in Hadoop 2.0.x and 2.1.x versions were in flux due to historical reasons. It + // turns out that Hadoop 2.0.x versions should also be used together with `Hadoop20SShims`, but + // `Hadoop23Shims` is chosen because the major version number here is 2. + // + // To fix this issue, we try to inspect Hadoop version via `org.apache.hadoop.utils.VersionInfo` + // and load `Hadoop20SShims` for Hadoop 1.x and 2.0.x versions. 
If Hadoop version information is + // not available, we decide whether to override the shims or not by checking for existence of a + // probe method which doesn't exist in Hadoop 1.x or 2.0.x versions. + private def overrideHadoopShims(): Unit = { + val hadoopVersion = VersionInfo.getVersion + val VersionPattern = """(\d+)\.(\d+).*""".r + + hadoopVersion match { + case null => + logError("Failed to inspect Hadoop version") + + // Using "Path.getPathWithoutSchemeAndAuthority" as the probe method. + val probeMethod = "getPathWithoutSchemeAndAuthority" + if (!classOf[Path].getDeclaredMethods.exists(_.getName == probeMethod)) { + logInfo( + s"Method ${classOf[Path].getCanonicalName}.$probeMethod not found, " + + s"we are probably using Hadoop 1.x or 2.0.x") + loadHadoop20SShims() + } + + case VersionPattern(majorVersion, minorVersion) => + logInfo(s"Inspected Hadoop version: $hadoopVersion") + + // Loads Hadoop20SShims for 1.x and 2.0.x versions + val (major, minor) = (majorVersion.toInt, minorVersion.toInt) + if (major < 2 || (major == 2 && minor == 0)) { + loadHadoop20SShims() + } + } + + // Logs the actual loaded Hadoop shims class + val loadedShimsClassName = ShimLoader.getHadoopShims.getClass.getCanonicalName + logInfo(s"Loaded $loadedShimsClassName for Hadoop version $hadoopVersion") + } + + private def loadHadoop20SShims(): Unit = { + val hadoop20SShimsClassName = "org.apache.hadoop.hive.shims.Hadoop20SShims" + logInfo(s"Loading Hadoop shims $hadoop20SShimsClassName") + + try { + val shimsField = classOf[ShimLoader].getDeclaredField("hadoopShims") + // scalastyle:off classforname + val shimsClass = Class.forName(hadoop20SShimsClassName) + // scalastyle:on classforname + val shims = classOf[HadoopShims].cast(shimsClass.newInstance()) + shimsField.setAccessible(true) + shimsField.set(null, shims) + } catch { case cause: Throwable => + throw new RuntimeException(s"Failed to load $hadoop20SShimsClassName", cause) + } + } + // Circular buffer to hold what hive prints to STDOUT and ERR. Only printed when failures occur. private val outputBuffer = new CircularBuffer() @@ -80,32 +150,51 @@ private[hive] class ClientWrapper( val original = Thread.currentThread().getContextClassLoader // Switch to the initClassLoader. 
Thread.currentThread().setContextClassLoader(initClassLoader) + + // Set up kerberos credentials for UserGroupInformation.loginUser within + // current class loader + // Instead of using the spark conf of the current spark context, a new + // instance of SparkConf is needed for the original value of spark.yarn.keytab + // and spark.yarn.principal set in SparkSubmit, as yarn.Client resets the + // keytab configuration for the link name in distributed cache + val sparkConf = new SparkConf + if (sparkConf.contains("spark.yarn.principal") && sparkConf.contains("spark.yarn.keytab")) { + val principalName = sparkConf.get("spark.yarn.principal") + val keytabFileName = sparkConf.get("spark.yarn.keytab") + if (!new File(keytabFileName).exists()) { + throw new SparkException(s"Keytab file: ${keytabFileName}" + + " specified in spark.yarn.keytab does not exist") + } else { + logInfo("Attempting to login to Kerberos" + + s" using principal: ${principalName} and keytab: ${keytabFileName}") + UserGroupInformation.loginUserFromKeytab(principalName, keytabFileName) + } + } + val ret = try { - val oldState = SessionState.get() - if (oldState == null) { - val initialConf = new HiveConf(classOf[SessionState]) - // HiveConf is a Hadoop Configuration, which has a field of classLoader and - // the initial value will be the current thread's context class loader - // (i.e. initClassLoader at here). - // We call initialConf.setClassLoader(initClassLoader) at here to make - // this action explicit. - initialConf.setClassLoader(initClassLoader) - config.foreach { case (k, v) => - if (k.toLowerCase.contains("password")) { - logDebug(s"Hive Config: $k=xxx") - } else { - logDebug(s"Hive Config: $k=$v") - } - initialConf.set(k, v) + val initialConf = new HiveConf(classOf[SessionState]) + // HiveConf is a Hadoop Configuration, which has a field of classLoader and + // the initial value will be the current thread's context class loader + // (i.e. initClassLoader at here). + // We call initialConf.setClassLoader(initClassLoader) at here to make + // this action explicit. + initialConf.setClassLoader(initClassLoader) + config.foreach { case (k, v) => + if (k.toLowerCase.contains("password")) { + logDebug(s"Hive Config: $k=xxx") + } else { + logDebug(s"Hive Config: $k=$v") } - val newState = new SessionState(initialConf) - SessionState.start(newState) - newState.out = new PrintStream(outputBuffer, true, "UTF-8") - newState.err = new PrintStream(outputBuffer, true, "UTF-8") - newState - } else { - oldState + initialConf.set(k, v) + } + val state = new SessionState(initialConf) + if (clientLoader.cachedHive != null) { + Hive.set(clientLoader.cachedHive.asInstanceOf[Hive]) } + SessionState.start(state) + state.out = new PrintStream(outputBuffer, true, "UTF-8") + state.err = new PrintStream(outputBuffer, true, "UTF-8") + state } finally { Thread.currentThread().setContextClassLoader(original) } @@ -115,10 +204,9 @@ private[hive] class ClientWrapper( /** Returns the configuration for the current session. */ def conf: HiveConf = SessionState.get().getConf - // TODO: should be a def?s - // When we create this val client, the HiveConf of it (conf) is the one associated with state. - @GuardedBy("this") - private var client = Hive.get(conf) + override def getConf(key: String, defaultValue: String): String = { + conf.get(key, defaultValue) + } // We use hive's conf for compatibility. 
private val retryLimit = conf.getIntVar(HiveConf.ConfVars.METASTORETHRIFTFAILURERETRIES) @@ -127,7 +215,7 @@ private[hive] class ClientWrapper( /** * Runs `f` with multiple retries in case the hive metastore is temporarily unreachable. */ - private def retryLocked[A](f: => A): A = synchronized { + private def retryLocked[A](f: => A): A = clientLoader.synchronized { // Hive sometimes retries internally, so set a deadline to avoid compounding delays. val deadline = System.nanoTime + (retryLimit * retryDelayMillis * 1e6).toLong var numTries = 0 @@ -142,13 +230,8 @@ private[hive] class ClientWrapper( logWarning( "HiveClientWrapper got thrift exception, destroying client and retrying " + s"(${retryLimit - numTries} tries remaining)", e) + clientLoader.cachedHive = null Thread.sleep(retryDelayMillis) - try { - client = Hive.get(state.getConf, true) - } catch { - case e: Exception if causedByThrift(e) => - logWarning("Failed to refresh hive client, will retry.", e) - } } } while (numTries <= retryLimit && System.nanoTime < deadline) if (System.nanoTime > deadline) { @@ -169,13 +252,26 @@ private[hive] class ClientWrapper( false } + def client: Hive = { + if (clientLoader.cachedHive != null) { + clientLoader.cachedHive.asInstanceOf[Hive] + } else { + val c = Hive.get(conf) + clientLoader.cachedHive = c + c + } + } + /** * Runs `f` with ThreadLocal session state and classloaders configured for this version of hive. */ - private def withHiveState[A](f: => A): A = retryLocked { + def withHiveState[A](f: => A): A = retryLocked { val original = Thread.currentThread().getContextClassLoader // Set the thread local metastore client to the client associated with this ClientWrapper. Hive.set(client) + // The classloader in clientLoader could be changed after addJar, always use the latest + // classloader + state.getConf.setClassLoader(clientLoader.classLoader) // setCurrentSessionState will use the classLoader associated // with the HiveConf in `state` to override the context class loader of the current // thread. 
@@ -232,10 +328,11 @@ private[hive] class ClientWrapper( HiveTable( name = h.getTableName, specifiedDatabase = Option(h.getDbName), - schema = h.getCols.map(f => HiveColumn(f.getName, f.getType, f.getComment)), - partitionColumns = h.getPartCols.map(f => HiveColumn(f.getName, f.getType, f.getComment)), - properties = h.getParameters.toMap, - serdeProperties = h.getTTable.getSd.getSerdeInfo.getParameters.toMap, + schema = h.getCols.asScala.map(f => HiveColumn(f.getName, f.getType, f.getComment)), + partitionColumns = h.getPartCols.asScala.map(f => + HiveColumn(f.getName, f.getType, f.getComment)), + properties = h.getParameters.asScala.toMap, + serdeProperties = h.getTTable.getSd.getSerdeInfo.getParameters.asScala.toMap, tableType = h.getTableType match { case HTableType.MANAGED_TABLE => ManagedTable case HTableType.EXTERNAL_TABLE => ExternalTable @@ -261,9 +358,9 @@ private[hive] class ClientWrapper( private def toQlTable(table: HiveTable): metadata.Table = { val qlTable = new metadata.Table(table.database, table.name) - qlTable.setFields(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment))) + qlTable.setFields(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)).asJava) qlTable.setPartCols( - table.partitionColumns.map(c => new FieldSchema(c.name, c.hiveType, c.comment))) + table.partitionColumns.map(c => new FieldSchema(c.name, c.hiveType, c.comment)).asJava) table.properties.foreach { case (k, v) => qlTable.setProperty(k, v) } table.serdeProperties.foreach { case (k, v) => qlTable.setSerdeParam(k, v) } @@ -280,6 +377,37 @@ private[hive] class ClientWrapper( qlTable } + private def toViewTable(view: HiveTable): metadata.Table = { + // TODO: this is duplicated with `toQlTable` except the table type stuff. + val tbl = new metadata.Table(view.database, view.name) + tbl.setTableType(HTableType.VIRTUAL_VIEW) + tbl.setSerializationLib(null) + tbl.clearSerDeInfo() + + // TODO: we will save the same SQL string to original and expanded text, which is different + // from Hive. 
+ tbl.setViewOriginalText(view.viewText.get) + tbl.setViewExpandedText(view.viewText.get) + + tbl.setFields(view.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)).asJava) + view.properties.foreach { case (k, v) => tbl.setProperty(k, v) } + + // set owner + tbl.setOwner(conf.getUser) + // set create time + tbl.setCreateTime((System.currentTimeMillis() / 1000).asInstanceOf[Int]) + + tbl + } + + override def createView(view: HiveTable): Unit = withHiveState { + client.createTable(toViewTable(view)) + } + + override def alertView(view: HiveTable): Unit = withHiveState { + client.alterTable(view.qualifiedName, toViewTable(view)) + } + override def createTable(table: HiveTable): Unit = withHiveState { val qlTable = toQlTable(table) client.createTable(qlTable) @@ -293,13 +421,13 @@ private[hive] class ClientWrapper( private def toHivePartition(partition: metadata.Partition): HivePartition = { val apiPartition = partition.getTPartition HivePartition( - values = Option(apiPartition.getValues).map(_.toSeq).getOrElse(Seq.empty), + values = Option(apiPartition.getValues).map(_.asScala).getOrElse(Seq.empty), storage = HiveStorageDescriptor( location = apiPartition.getSd.getLocation, inputFormat = apiPartition.getSd.getInputFormat, outputFormat = apiPartition.getSd.getOutputFormat, serde = apiPartition.getSd.getSerdeInfo.getSerializationLib, - serdeProperties = apiPartition.getSd.getSerdeInfo.getParameters.toMap)) + serdeProperties = apiPartition.getSd.getSerdeInfo.getParameters.asScala.toMap)) } override def getPartitionOption( @@ -324,7 +452,7 @@ private[hive] class ClientWrapper( } override def listTables(dbName: String): Seq[String] = withHiveState { - client.getAllTables(dbName) + client.getAllTables(dbName).asScala } /** @@ -440,18 +568,35 @@ private[hive] class ClientWrapper( listBucketingEnabled) } + def addJar(path: String): Unit = { + val uri = new Path(path).toUri + val jarURL = if (uri.getScheme == null) { + // `path` is a local file path without a URL scheme + new File(path).toURI.toURL + } else { + // `path` is a URL with a scheme + uri.toURL + } + clientLoader.addJar(jarURL) + runSqlHive(s"ADD JAR $path") + } + + def newSession(): ClientWrapper = { + clientLoader.createClient().asInstanceOf[ClientWrapper] + } + def reset(): Unit = withHiveState { - client.getAllTables("default").foreach { t => + client.getAllTables("default").asScala.foreach { t => logDebug(s"Deleting table $t") val table = client.getTable("default", t) - client.getIndexes("default", t, 255).foreach { index => + client.getIndexes("default", t, 255).asScala.foreach { index => shim.dropIndex(client, "default", t, index.getIndexName) } if (!table.isIndexTable) { client.dropTable("default", t) } } - client.getAllDatabases.filterNot(_ == "default").foreach { db => + client.getAllDatabases.asScala.filterNot(_ == "default").foreach { db => logDebug(s"Dropping Database: $db") client.dropDatabase(db, true, false, true) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 956997e5f9dc..346840079b85 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -23,9 +23,9 @@ import java.net.URI import java.util.{ArrayList => JArrayList, List => JList, Map => JMap, Set => JSet} import java.util.concurrent.TimeUnit -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ -import 
org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} @@ -201,7 +201,7 @@ private[client] class Shim_v0_12 extends Shim with Logging { setDataLocationMethod.invoke(table, new URI(loc)) override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] = - getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq + getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].asScala.toSeq override def getPartitionsByFilter( hive: Hive, @@ -220,7 +220,7 @@ private[client] class Shim_v0_12 extends Shim with Logging { override def getDriverResults(driver: Driver): Seq[String] = { val res = new JArrayList[String]() getDriverResultsMethod.invoke(driver, res) - res.toSeq + res.asScala } override def getMetastoreClientConnectRetryDelayMillis(conf: HiveConf): Long = { @@ -310,7 +310,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { setDataLocationMethod.invoke(table, new Path(loc)) override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] = - getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq + getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].asScala.toSeq /** * Converts catalyst expression to the format that Hive's getPartitionsByFilter() expects, i.e. @@ -320,8 +320,9 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { */ def convertFilters(table: Table, filters: Seq[Expression]): String = { // hive varchar is treated as catalyst string, but hive varchar can't be pushed down. - val varcharKeys = table.getPartitionKeys - .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME)) + val varcharKeys = table.getPartitionKeys.asScala + .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME) || + col.getType.startsWith(serdeConstants.CHAR_TYPE_NAME)) .map(col => col.getName).toSet filters.collect { @@ -354,7 +355,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { getPartitionsByFilterMethod.invoke(hive, table, filter).asInstanceOf[JArrayList[Partition]] } - partitions.toSeq + partitions.asScala.toSeq } override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor = @@ -363,7 +364,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { override def getDriverResults(driver: Driver): Seq[String] = { val res = new JArrayList[Object]() getDriverResultsMethod.invoke(driver, res) - res.map { r => + res.asScala.map { r => r match { case s: String => s case a: Array[Object] => a(0).asInstanceOf[String] @@ -429,7 +430,7 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { isSkewedStoreAsSubdir: Boolean): Unit = { loadPartitionMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, holdDDLTime: JBoolean, inheritTableSpecs: JBoolean, isSkewedStoreAsSubdir: JBoolean, - JBoolean.TRUE, JBoolean.FALSE) + isSrcLocal(loadPath, hive.getConf()): JBoolean, JBoolean.FALSE) } override def loadTable( @@ -439,7 +440,7 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { replace: Boolean, holdDDLTime: Boolean): Unit = { loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, holdDDLTime: JBoolean, - JBoolean.TRUE, JBoolean.FALSE, JBoolean.FALSE) + isSrcLocal(loadPath, hive.getConf()): JBoolean, JBoolean.FALSE, JBoolean.FALSE) } override def loadDynamicPartitions( @@ -461,6 +462,13 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { 
HiveConf.ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY, TimeUnit.MILLISECONDS).asInstanceOf[Long] } + + protected def isSrcLocal(path: Path, conf: HiveConf): Boolean = { + val localFs = FileSystem.getLocal(conf) + val pathFs = FileSystem.get(path.toUri(), conf) + localFs.getUri() == pathFs.getUri() + } + } private[client] class Shim_v1_0 extends Shim_v0_14 { @@ -512,7 +520,7 @@ private[client] class Shim_v1_2 extends Shim_v1_1 { listBucketingEnabled: Boolean): Unit = { loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, numDP: JInteger, holdDDLTime: JBoolean, listBucketingEnabled: JBoolean, JBoolean.FALSE, - 0: JLong) + 0L: JLong) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 97fb98199991..010051d255fd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -29,24 +29,58 @@ import org.apache.commons.io.{FileUtils, IOUtils} import org.apache.spark.Logging import org.apache.spark.deploy.SparkSubmitUtils -import org.apache.spark.util.Utils - import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.util.{MutableURLClassLoader, Utils} /** Factory for `IsolatedClientLoader` with specific versions of hive. */ -private[hive] object IsolatedClientLoader { +private[hive] object IsolatedClientLoader extends Logging { /** * Creates isolated Hive client loaders by downloading the requested version from maven. */ def forVersion( - version: String, + hiveMetastoreVersion: String, + hadoopVersion: String, config: Map[String, String] = Map.empty, - ivyPath: Option[String] = None): IsolatedClientLoader = synchronized { - val resolvedVersion = hiveVersion(version) - val files = resolvedVersions.getOrElseUpdate(resolvedVersion, - downloadVersion(resolvedVersion, ivyPath)) - new IsolatedClientLoader(hiveVersion(version), files, config) + ivyPath: Option[String] = None, + sharedPrefixes: Seq[String] = Seq.empty, + barrierPrefixes: Seq[String] = Seq.empty): IsolatedClientLoader = synchronized { + val resolvedVersion = hiveVersion(hiveMetastoreVersion) + // We will first try to share Hadoop classes. If we cannot resolve the Hadoop artifact + // with the given version, we will use Hadoop 2.4.0 and then will not share Hadoop classes. + var sharesHadoopClasses = true + val files = if (resolvedVersions.contains((resolvedVersion, hadoopVersion))) { + resolvedVersions((resolvedVersion, hadoopVersion)) + } else { + val (downloadedFiles, actualHadoopVersion) = + try { + (downloadVersion(resolvedVersion, hadoopVersion, ivyPath), hadoopVersion) + } catch { + case e: RuntimeException if e.getMessage.contains("hadoop") => + // If the error message contains hadoop, it is probably because the hadoop + // version cannot be resolved (e.g. it is a vendor specific version like + // 2.0.0-cdh4.1.1). If it is the case, we will try just + // "org.apache.hadoop:hadoop-client:2.4.0". "org.apache.hadoop:hadoop-client:2.4.0" + // is used just because we used to hard code it as the hadoop artifact to download. + logWarning(s"Failed to resolve Hadoop artifacts for the version ${hadoopVersion}. " + + s"We will change the hadoop version from ${hadoopVersion} to 2.4.0 and try again. " + + "Hadoop classes will not be shared between Spark and Hive metastore client. 
" + + "It is recommended to set jars used by Hive metastore client through " + + "spark.sql.hive.metastore.jars in the production environment.") + sharesHadoopClasses = false + (downloadVersion(resolvedVersion, "2.4.0", ivyPath), "2.4.0") + } + resolvedVersions.put((resolvedVersion, actualHadoopVersion), downloadedFiles) + resolvedVersions((resolvedVersion, actualHadoopVersion)) + } + + new IsolatedClientLoader( + version = hiveVersion(hiveMetastoreVersion), + execJars = files, + config = config, + sharesHadoopClasses = sharesHadoopClasses, + sharedPrefixes = sharedPrefixes, + barrierPrefixes = barrierPrefixes) } def hiveVersion(version: String): HiveVersion = version match { @@ -55,15 +89,18 @@ private[hive] object IsolatedClientLoader { case "14" | "0.14" | "0.14.0" => hive.v14 case "1.0" | "1.0.0" => hive.v1_0 case "1.1" | "1.1.0" => hive.v1_1 - case "1.2" | "1.2.0" => hive.v1_2 + case "1.2" | "1.2.0" | "1.2.1" => hive.v1_2 } - private def downloadVersion(version: HiveVersion, ivyPath: Option[String]): Seq[URL] = { + private def downloadVersion( + version: HiveVersion, + hadoopVersion: String, + ivyPath: Option[String]): Seq[URL] = { val hiveArtifacts = version.extraDeps ++ Seq("hive-metastore", "hive-exec", "hive-common", "hive-serde") .map(a => s"org.apache.hive:$a:${version.fullVersion}") ++ Seq("com.google.guava:guava:14.0.1", - "org.apache.hadoop:hadoop-client:2.4.0") + s"org.apache.hadoop:hadoop-client:$hadoopVersion") val classpath = quietly { SparkSubmitUtils.resolveMavenCoordinates( @@ -77,10 +114,13 @@ private[hive] object IsolatedClientLoader { // TODO: Remove copy logic. val tempDir = Utils.createTempDir(namePrefix = s"hive-${version}") allFiles.foreach(f => FileUtils.copyFileToDirectory(f, tempDir)) - tempDir.listFiles().map(_.toURL) + tempDir.listFiles().map(_.toURI.toURL) } - private def resolvedVersions = new scala.collection.mutable.HashMap[HiveVersion, Seq[URL]] + // A map from a given pair of HiveVersion and Hadoop version to jar files. + // It is only used by forVersion. + private val resolvedVersions = + new scala.collection.mutable.HashMap[(HiveVersion, String), Seq[URL]] } /** @@ -100,6 +140,7 @@ private[hive] object IsolatedClientLoader { * @param config A set of options that will be added to the HiveConf of the constructed client. * @param isolationOn When true, custom versions of barrier classes will be constructed. Must be * true unless loading the version of hive that is on Sparks classloader. + * @param sharesHadoopClasses When true, we will share Hadoop classes between Spark and * @param rootClassLoader The system root classloader. Must not know about Hive classes. * @param baseClassLoader The spark classloader that is used to load shared classes. */ @@ -108,6 +149,7 @@ private[hive] class IsolatedClientLoader( val execJars: Seq[URL] = Seq.empty, val config: Map[String, String] = Map.empty, val isolationOn: Boolean = true, + val sharesHadoopClasses: Boolean = true, val rootClassLoader: ClassLoader = ClassLoader.getSystemClassLoader.getParent.getParent, val baseClassLoader: ClassLoader = Thread.currentThread().getContextClassLoader, val sharedPrefixes: Seq[String] = Seq.empty, @@ -120,15 +162,20 @@ private[hive] class IsolatedClientLoader( /** All jars used by the hive specific classloader. 
*/ protected def allJars = execJars.toArray - protected def isSharedClass(name: String): Boolean = + protected def isSharedClass(name: String): Boolean = { + val isHadoopClass = + name.startsWith("org.apache.hadoop.") && !name.startsWith("org.apache.hadoop.hive.") + name.contains("slf4j") || name.contains("log4j") || name.startsWith("org.apache.spark.") || + (sharesHadoopClasses && isHadoopClass) || name.startsWith("scala.") || (name.startsWith("com.google") && !name.startsWith("com.google.cloud")) || name.startsWith("java.lang.") || name.startsWith("java.net") || sharedPrefixes.exists(name.startsWith) + } /** True if `name` refers to a spark class that must see specific version of Hive. */ protected def isBarrierClass(name: String): Boolean = @@ -139,54 +186,98 @@ private[hive] class IsolatedClientLoader( protected def classToPath(name: String): String = name.replaceAll("\\.", "/") + ".class" - /** The classloader that is used to load an isolated version of Hive. */ - protected val classLoader: ClassLoader = new URLClassLoader(allJars, rootClassLoader) { - override def loadClass(name: String, resolve: Boolean): Class[_] = { - val loaded = findLoadedClass(name) - if (loaded == null) doLoadClass(name, resolve) else loaded - } - - def doLoadClass(name: String, resolve: Boolean): Class[_] = { - val classFileName = name.replaceAll("\\.", "/") + ".class" - if (isBarrierClass(name) && isolationOn) { - // For barrier classes, we construct a new copy of the class. - val bytes = IOUtils.toByteArray(baseClassLoader.getResourceAsStream(classFileName)) - logDebug(s"custom defining: $name - ${util.Arrays.hashCode(bytes)}") - defineClass(name, bytes, 0, bytes.length) - } else if (!isSharedClass(name)) { - logDebug(s"hive class: $name - ${getResource(classToPath(name))}") - super.loadClass(name, resolve) + /** + * The classloader that is used to load an isolated version of Hive. + * This classloader is a special URLClassLoader that exposes the addURL method. + * So, when we add jar, we can add this new jar directly through the addURL method + * instead of stacking a new URLClassLoader on top of it. + */ + private[hive] val classLoader: MutableURLClassLoader = { + val isolatedClassLoader = + if (isolationOn) { + new URLClassLoader(allJars, rootClassLoader) { + override def loadClass(name: String, resolve: Boolean): Class[_] = { + val loaded = findLoadedClass(name) + if (loaded == null) doLoadClass(name, resolve) else loaded + } + def doLoadClass(name: String, resolve: Boolean): Class[_] = { + val classFileName = name.replaceAll("\\.", "/") + ".class" + if (isBarrierClass(name)) { + // For barrier classes, we construct a new copy of the class. + val bytes = IOUtils.toByteArray(baseClassLoader.getResourceAsStream(classFileName)) + logDebug(s"custom defining: $name - ${util.Arrays.hashCode(bytes)}") + defineClass(name, bytes, 0, bytes.length) + } else if (!isSharedClass(name)) { + logDebug(s"hive class: $name - ${getResource(classToPath(name))}") + super.loadClass(name, resolve) + } else { + // For shared classes, we delegate to baseClassLoader. + logDebug(s"shared class: $name") + baseClassLoader.loadClass(name) + } + } + } } else { - // For shared classes, we delegate to baseClassLoader. - logDebug(s"shared class: $name") - baseClassLoader.loadClass(name) + baseClassLoader } - } + // Right now, we create a URLClassLoader that gives preference to isolatedClassLoader + // over its own URLs when it loads classes and resources. 
+ // We may want to use ChildFirstURLClassLoader based on + // the configuration of spark.executor.userClassPathFirst, which gives preference + // to its own URLs over the parent class loader (see Executor's createClassLoader method). + new NonClosableMutableURLClassLoader(isolatedClassLoader) } - // Pre-reflective instantiation setup. - logDebug("Initializing the logger to avoid disaster...") - Thread.currentThread.setContextClassLoader(classLoader) + private[hive] def addJar(path: URL): Unit = synchronized { + classLoader.addURL(path) + } /** The isolated client interface to Hive. */ - val client: ClientInterface = try { - classLoader - .loadClass(classOf[ClientWrapper].getName) - .getConstructors.head - .newInstance(version, config, classLoader) - .asInstanceOf[ClientInterface] - } catch { - case e: InvocationTargetException => - if (e.getCause().isInstanceOf[NoClassDefFoundError]) { - val cnf = e.getCause().asInstanceOf[NoClassDefFoundError] - throw new ClassNotFoundException( - s"$cnf when creating Hive client using classpath: ${execJars.mkString(", ")}\n" + - "Please make sure that jars for your version of hive and hadoop are included in the " + - s"paths passed to ${HiveContext.HIVE_METASTORE_JARS}.") - } else { - throw e - } - } finally { - Thread.currentThread.setContextClassLoader(baseClassLoader) + private[hive] def createClient(): ClientInterface = { + if (!isolationOn) { + return new ClientWrapper(version, config, baseClassLoader, this) + } + // Pre-reflective instantiation setup. + logDebug("Initializing the logger to avoid disaster...") + val origLoader = Thread.currentThread().getContextClassLoader + Thread.currentThread.setContextClassLoader(classLoader) + + try { + classLoader + .loadClass(classOf[ClientWrapper].getName) + .getConstructors.head + .newInstance(version, config, classLoader, this) + .asInstanceOf[ClientInterface] + } catch { + case e: InvocationTargetException => + if (e.getCause().isInstanceOf[NoClassDefFoundError]) { + val cnf = e.getCause().asInstanceOf[NoClassDefFoundError] + throw new ClassNotFoundException( + s"$cnf when creating Hive client using classpath: ${execJars.mkString(", ")}\n" + + "Please make sure that jars for your version of hive and hadoop are included in the " + + s"paths passed to ${HiveContext.HIVE_METASTORE_JARS}.") + } else { + throw e + } + } finally { + Thread.currentThread.setContextClassLoader(origLoader) + } } + + /** + * The place holder for shared Hive client for all the HiveContext sessions (they share an + * IsolatedClientLoader). + */ + private[hive] var cachedHive: Any = null +} + +/** + * URL class loader that exposes the `addURL` and `getURLs` methods in URLClassLoader. + * This class loader cannot be closed (its `close` method is a no-op). 
+ */ +private[sql] class NonClosableMutableURLClassLoader( + parent: ClassLoader) + extends MutableURLClassLoader(Array.empty, parent) { + + override def close(): Unit = {} } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index b48082fe4b36..b1b8439efa01 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -25,7 +25,7 @@ package object client { val exclusions: Seq[String] = Nil) // scalastyle:off - private[client] object hive { + private[hive] object hive { case object v12 extends HiveVersion("0.12.0") case object v13 extends HiveVersion("0.13.1") @@ -56,7 +56,7 @@ package object client { "net.hydromatic:linq4j", "net.hydromatic:quidem")) - case object v1_2 extends HiveVersion("1.2.0", + case object v1_2 extends HiveVersion("1.2.1", exclusions = Seq("eigenbase:eigenbase-properties", "org.apache.curator:*", "org.pentaho:pentaho-aggdesigner-algorithm", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala index 84358cb73c9e..e72a60b42e65 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.hive.execution +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan} import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.hive.client.{HiveColumn, HiveTable} @@ -37,8 +38,9 @@ case class CreateTableAsSelect( allowExisting: Boolean) extends RunnableCommand { - def database: String = tableDesc.database - def tableName: String = tableDesc.name + val tableIdentifier = TableIdentifier(tableDesc.name, Some(tableDesc.database)) + + override def children: Seq[LogicalPlan] = Seq(query) override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] @@ -70,18 +72,18 @@ case class CreateTableAsSelect( hiveContext.catalog.client.createTable(withSchema) // Get the Metastore Relation - hiveContext.catalog.lookupRelation(Seq(database, tableName), None) match { + hiveContext.catalog.lookupRelation(tableIdentifier, None) match { case r: MetastoreRelation => r } } // TODO ideally, we should get the output data ready first and then // add the relation into catalog, just in case of failure occurs while data // processing. 
- if (hiveContext.catalog.tableExists(Seq(database, tableName))) { + if (hiveContext.catalog.tableExists(tableIdentifier)) { if (allowExisting) { // table already exists, will do nothing, to keep consistent with Hive } else { - throw new AnalysisException(s"$database.$tableName already exists.") + throw new AnalysisException(s"$tableIdentifier already exists.") } } else { hiveContext.executePlan(InsertIntoTable(metastoreRelation, Map(), query, true, false)).toRdd @@ -91,6 +93,6 @@ case class CreateTableAsSelect( } override def argString: String = { - s"[Database:$database, TableName: $tableName, InsertIntoHiveTable]\n" + query.toString + s"[Database:${tableDesc.database}}, TableName: ${tableDesc.name}, InsertIntoHiveTable]" } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala new file mode 100644 index 000000000000..2c81115ee4fe --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.hive.{HiveMetastoreTypes, HiveContext} +import org.apache.spark.sql.{AnalysisException, Row, SQLContext} +import org.apache.spark.sql.execution.RunnableCommand +import org.apache.spark.sql.hive.client.{HiveColumn, HiveTable} + +/** + * Create Hive view on non-hive-compatible tables by specifying schema ourselves instead of + * depending on Hive meta-store. + */ +// TODO: Note that this class can NOT canonicalize the view SQL string entirely, which is different +// from Hive and may not work for some cases like create view on self join. +private[hive] case class CreateViewAsSelect( + tableDesc: HiveTable, + childSchema: Seq[Attribute], + allowExisting: Boolean, + orReplace: Boolean) extends RunnableCommand { + + assert(tableDesc.schema == Nil || tableDesc.schema.length == childSchema.length) + assert(tableDesc.viewText.isDefined) + + val tableIdentifier = TableIdentifier(tableDesc.name, Some(tableDesc.database)) + + override def run(sqlContext: SQLContext): Seq[Row] = { + val hiveContext = sqlContext.asInstanceOf[HiveContext] + + if (hiveContext.catalog.tableExists(tableIdentifier)) { + if (allowExisting) { + // view already exists, will do nothing, to keep consistent with Hive + } else if (orReplace) { + hiveContext.catalog.client.alertView(prepareTable()) + } else { + throw new AnalysisException(s"View $tableIdentifier already exists. 
" + + "If you want to update the view definition, please use ALTER VIEW AS or " + + "CREATE OR REPLACE VIEW AS") + } + } else { + hiveContext.catalog.client.createView(prepareTable()) + } + + Seq.empty[Row] + } + + private def prepareTable(): HiveTable = { + // setup column types according to the schema of child. + val schema = if (tableDesc.schema == Nil) { + childSchema.map { attr => + HiveColumn(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), null) + } + } else { + childSchema.zip(tableDesc.schema).map { case (attr, col) => + HiveColumn(col.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), col.comment) + } + } + + val columnNames = childSchema.map(f => verbose(f.name)) + + // When user specified column names for view, we should create a project to do the renaming. + // When no column name specified, we still need to create a project to declare the columns + // we need, to make us more robust to top level `*`s. + val projectList = if (tableDesc.schema == Nil) { + columnNames.mkString(", ") + } else { + columnNames.zip(tableDesc.schema.map(f => verbose(f.name))).map { + case (name, alias) => s"$name AS $alias" + }.mkString(", ") + } + + val viewName = verbose(tableDesc.name) + + val expandedText = s"SELECT $projectList FROM (${tableDesc.viewText.get}) $viewName" + + tableDesc.copy(schema = schema, viewText = Some(expandedText)) + } + + // escape backtick with double-backtick in column name and wrap it with backtick. + private def verbose(name: String) = s"`${name.replaceAll("`", "``")}`" +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala index 5f0ed5393d19..441b6b6033e1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive.execution -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.hadoop.hive.metastore.api.FieldSchema @@ -39,8 +39,8 @@ case class DescribeHiveTableCommand( // Trying to mimic the format of Hive's output. But not exactly the same. 
var results: Seq[(String, String, String)] = Nil - val columns: Seq[FieldSchema] = table.hiveQlTable.getCols - val partitionColumns: Seq[FieldSchema] = table.hiveQlTable.getPartCols + val columns: Seq[FieldSchema] = table.hiveQlTable.getCols.asScala + val partitionColumns: Seq[FieldSchema] = table.hiveQlTable.getPartCols.asScala results ++= columns.map(field => (field.getName, field.getType, field.getComment)) if (partitionColumns.nonEmpty) { val partColumnInfo = @@ -48,7 +48,7 @@ case class DescribeHiveTableCommand( results ++= partColumnInfo ++ Seq(("# Partition Information", "", "")) ++ - Seq((s"# ${output.get(0).name}", output.get(1).name, output.get(2).name)) ++ + Seq((s"# ${output(0).name}", output(1).name, output(2).name)) ++ partColumnInfo } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index ba7eb15a1c0c..806d2b9b0b7d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive.execution -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition} @@ -98,7 +98,7 @@ case class HiveTableScan( .asInstanceOf[StructObjectInspector] val columnTypeNames = structOI - .getAllStructFieldRefs + .getAllStructFieldRefs.asScala .map(_.getFieldObjectInspector) .map(TypeInfoUtils.getTypeInfoFromObjectInspector(_).getTypeName) .mkString(",") @@ -118,9 +118,8 @@ case class HiveTableScan( case None => partitions case Some(shouldKeep) => partitions.filter { part => val dataTypes = relation.partitionKeys.map(_.dataType) - val castedValues = for ((value, dataType) <- part.getValues.zip(dataTypes)) yield { - castFromString(value, dataType) - } + val castedValues = part.getValues.asScala.zip(dataTypes) + .map { case (value, dataType) => castFromString(value, dataType) } // Only partitioned values are needed here, since the predicate has already been bound to // partition key attribute references. 
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 40a6a3215668..f936cf565b2b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -19,15 +19,16 @@ package org.apache.spark.sql.hive.execution import java.util +import scala.collection.JavaConverters._ + import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hadoop.hive.metastore.MetaStoreUtils import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.hadoop.hive.ql.{Context, ErrorMsg} import org.apache.hadoop.hive.serde2.Serializer import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption import org.apache.hadoop.hive.serde2.objectinspector._ -import org.apache.hadoop.mapred.{FileOutputFormat, JobConf} +import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf} import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row @@ -38,8 +39,6 @@ import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.hive._ import org.apache.spark.sql.types.DataType import org.apache.spark.{SparkException, TaskContext} - -import scala.collection.JavaConversions._ import org.apache.spark.util.SerializableJobConf private[hive] @@ -61,9 +60,9 @@ case class InsertIntoHiveTable( serializer } - def output: Seq[Attribute] = child.output + def output: Seq[Attribute] = Seq.empty - def saveAsHiveFile( + private def saveAsHiveFile( rdd: RDD[InternalRow], valueClass: Class[_], fileSinkConf: FileSinkDesc, @@ -94,7 +93,8 @@ case class InsertIntoHiveTable( ObjectInspectorCopyOption.JAVA) .asInstanceOf[StructObjectInspector] - val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray + val fieldOIs = standardOI.getAllStructFieldRefs.asScala + .map(_.getFieldObjectInspector).toArray val dataTypes: Array[DataType] = child.output.map(_.dataType).toArray val wrappers = fieldOIs.zip(dataTypes).map { case (f, dt) => wrapperFor(f, dt)} val outputData = new Array[Any](fieldOIs.length) @@ -124,12 +124,12 @@ case class InsertIntoHiveTable( * * Note: this is run once and then kept to avoid double insertions. */ - protected[sql] lazy val sideEffectResult: Seq[Row] = { + protected[sql] lazy val sideEffectResult: Seq[InternalRow] = { // Have to pass the TableDesc object to RDD.mapPartitions and then instantiate new serializer // instances within the closure, since Serializer is not serializable while TableDesc is. val tableDesc = table.tableDesc val tableLocation = table.hiveQlTable.getDataLocation - val tmpLocation = hiveContext.getExternalTmpPath(tableLocation.toUri) + val tmpLocation = hiveContext.getExternalTmpPath(tableLocation) val fileSinkConf = new FileSinkDesc(tmpLocation.toString, tableDesc, false) val isCompressed = sc.hiveconf.getBoolean( ConfVars.COMPRESSRESULT.varname, ConfVars.COMPRESSRESULT.defaultBoolVal) @@ -178,6 +178,19 @@ case class InsertIntoHiveTable( val jobConf = new JobConf(sc.hiveconf) val jobConfSer = new SerializableJobConf(jobConf) + // When speculation is on and output committer class name contains "Direct", we should warn + // users that they may lose data if they are using a direct output committer.
+ val speculationEnabled = sqlContext.sparkContext.conf.getBoolean("spark.speculation", false) + val outputCommitterClass = jobConf.get("mapred.output.committer.class", "") + if (speculationEnabled && outputCommitterClass.contains("Direct")) { + val warningMessage = + s"$outputCommitterClass may be an output committer that writes data directly to " + + "the final location. Because speculation is enabled, this output committer may " + + "cause data loss (see the case in SPARK-10063). If possible, please use a output " + + "committer that does not have this behavior (e.g. FileOutputCommitter)." + logWarning(warningMessage) + } + val writerContainer = if (numDynamicPartitions > 0) { val dynamicPartColNames = partitionColumnNames.takeRight(numDynamicPartitions) new SparkHiveDynamicPartitionWriterContainer(jobConf, fileSinkConf, dynamicPartColNames) @@ -198,7 +211,7 @@ case class InsertIntoHiveTable( // loadPartition call orders directories created on the iteration order of the this map val orderedPartitionSpec = new util.LinkedHashMap[String, String]() - table.hiveQlTable.getPartCols().foreach { entry => + table.hiveQlTable.getPartCols.asScala.foreach { entry => orderedPartitionSpec.put(entry.getName, partitionSpec.get(entry.getName).getOrElse("")) } @@ -226,7 +239,7 @@ case class InsertIntoHiveTable( val oldPart = catalog.client.getPartitionOption( catalog.client.getTable(table.databaseName, table.tableName), - partitionSpec) + partitionSpec.asJava) if (oldPart.isEmpty || !ifNotExists) { catalog.client.loadPartition( @@ -254,10 +267,10 @@ case class InsertIntoHiveTable( // however for now we return an empty list to simplify compatibility checks with hive, which // does not return anything for insert operations. // TODO: implement hive compatibility as rules. 
- Seq.empty[Row] + Seq.empty[InternalRow] } - override def executeCollect(): Array[Row] = sideEffectResult.toArray + override def executeCollect(): Array[InternalRow] = sideEffectResult.toArray protected override def doExecute(): RDD[InternalRow] = { sqlContext.sparkContext.parallelize(sideEffectResult.asInstanceOf[Seq[InternalRow]], 1) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 7e3342cc84c0..b30117f0de99 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -21,24 +21,26 @@ import java.io._ import java.util.Properties import javax.annotation.Nullable -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.util.control.NonFatal +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.ql.exec.{RecordReader, RecordWriter} import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.AbstractSerDe import org.apache.hadoop.hive.serde2.objectinspector._ +import org.apache.hadoop.io.Writable -import org.apache.spark.{TaskContext, Logging} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.hive.{HiveContext, HiveInspectors} import org.apache.spark.sql.types.DataType -import org.apache.spark.util.{CircularBuffer, RedirectThread, Utils} +import org.apache.spark.util.{CircularBuffer, RedirectThread, SerializableConfiguration, Utils} +import org.apache.spark.{Logging, TaskContext} /** * Transforms the input by forking and running the specified script. @@ -53,20 +55,23 @@ case class ScriptTransformation( script: String, output: Seq[Attribute], child: SparkPlan, - ioschema: HiveScriptIOSchema)(@transient sc: HiveContext) + ioschema: HiveScriptIOSchema)(@transient private val sc: HiveContext) extends UnaryNode { override def otherCopyArgs: Seq[HiveContext] = sc :: Nil + private val serializedHiveConf = new SerializableConfiguration(sc.hiveconf) + protected override def doExecute(): RDD[InternalRow] = { def processIterator(inputIterator: Iterator[InternalRow]): Iterator[InternalRow] = { val cmd = List("/bin/bash", "-c", script) - val builder = new ProcessBuilder(cmd) + val builder = new ProcessBuilder(cmd.asJava) val proc = builder.start() val inputStream = proc.getInputStream val outputStream = proc.getOutputStream val errorStream = proc.getErrorStream + val localHiveConf = serializedHiveConf.value // In order to avoid deadlocks, we need to consume the error output of the child process. // To avoid issues caused by large error output, we use a circular buffer to limit the amount @@ -88,6 +93,7 @@ case class ScriptTransformation( // external process. That process's output will be read by this current thread. 
val writerThread = new ScriptTransformationWriterThread( inputIterator, + input.map(_.dataType), outputProjection, inputSerde, inputSoi, @@ -95,7 +101,8 @@ case class ScriptTransformation( outputStream, proc, stderrBuffer, - TaskContext.get() + TaskContext.get(), + localHiveConf ) // This nullability is a performance optimization in order to avoid an Option.foreach() call @@ -106,9 +113,19 @@ case class ScriptTransformation( val reader = new BufferedReader(new InputStreamReader(inputStream)) val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors { - var cacheRow: InternalRow = null var curLine: String = null - var eof: Boolean = false + val scriptOutputStream = new DataInputStream(inputStream) + + @Nullable val scriptOutputReader = + ioschema.recordReader(scriptOutputStream, localHiveConf).orNull + + var scriptOutputWritable: Writable = null + val reusedWritableObject: Writable = if (null != outputSerde) { + outputSerde.getSerializedClass().newInstance + } else { + null + } + val mutableRow = new SpecificMutableRow(output.map(_.dataType)) override def hasNext: Boolean = { if (outputSerde == null) { @@ -125,45 +142,30 @@ case class ScriptTransformation( } else { true } - } else { - if (eof) { - if (writerThread.exception.isDefined) { - throw writerThread.exception.get - } - false - } else { - true - } - } - } - - def deserialize(): InternalRow = { - if (cacheRow != null) return cacheRow - - val mutableRow = new SpecificMutableRow(output.map(_.dataType)) - try { - val dataInputStream = new DataInputStream(inputStream) - val writable = outputSerde.getSerializedClass().newInstance - writable.readFields(dataInputStream) + } else if (scriptOutputWritable == null) { + scriptOutputWritable = reusedWritableObject - val raw = outputSerde.deserialize(writable) - val dataList = outputSoi.getStructFieldsDataAsList(raw) - val fieldList = outputSoi.getAllStructFieldRefs() - - var i = 0 - dataList.foreach( element => { - if (element == null) { - mutableRow.setNullAt(i) + if (scriptOutputReader != null) { + if (scriptOutputReader.next(scriptOutputWritable) <= 0) { + writerThread.exception.foreach(throw _) + false } else { - mutableRow(i) = unwrap(element, fieldList(i).getFieldObjectInspector) + true } - i += 1 - }) - mutableRow - } catch { - case e: EOFException => - eof = true - null + } else { + try { + scriptOutputWritable.readFields(scriptOutputStream) + true + } catch { + case _: EOFException => + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + false + } + } + } else { + true } } @@ -171,7 +173,6 @@ case class ScriptTransformation( if (!hasNext) { throw new NoSuchElementException } - if (outputSerde == null) { val prevLine = curLine curLine = reader.readLine() @@ -185,12 +186,20 @@ case class ScriptTransformation( .map(CatalystTypeConverters.convertToCatalyst)) } } else { - val ret = deserialize() - if (!eof) { - cacheRow = null - cacheRow = deserialize() + val raw = outputSerde.deserialize(scriptOutputWritable) + scriptOutputWritable = null + val dataList = outputSoi.getStructFieldsDataAsList(raw) + val fieldList = outputSoi.getAllStructFieldRefs() + var i = 0 + while (i < dataList.size()) { + if (dataList.get(i) == null) { + mutableRow.setNullAt(i) + } else { + mutableRow(i) = unwrap(dataList.get(i), fieldList.get(i).getFieldObjectInspector) + } + i += 1 } - ret + mutableRow } } } @@ -213,6 +222,7 @@ case class ScriptTransformation( private class ScriptTransformationWriterThread( iter: Iterator[InternalRow], + inputSchema: 
Seq[DataType], outputProjection: Projection, @Nullable inputSerde: AbstractSerDe, @Nullable inputSoi: ObjectInspector, @@ -220,7 +230,8 @@ private class ScriptTransformationWriterThread( outputStream: OutputStream, proc: Process, stderrBuffer: CircularBuffer, - taskContext: TaskContext + taskContext: TaskContext, + conf: Configuration ) extends Thread("Thread-ScriptTransformation-Feed") with Logging { setDaemon(true) @@ -234,20 +245,39 @@ private class ScriptTransformationWriterThread( TaskContext.setTaskContext(taskContext) val dataOutputStream = new DataOutputStream(outputStream) + @Nullable val scriptInputWriter = ioschema.recordWriter(dataOutputStream, conf).orNull // We can't use Utils.tryWithSafeFinally here because we also need a `catch` block, so // let's use a variable to record whether the `finally` block was hit due to an exception var threwException: Boolean = true + val len = inputSchema.length try { iter.map(outputProjection).foreach { row => if (inputSerde == null) { - val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"), - ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8") - outputStream.write(data) + val data = if (len == 0) { + ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES") + } else { + val sb = new StringBuilder + sb.append(row.get(0, inputSchema(0))) + var i = 1 + while (i < len) { + sb.append(ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD")) + sb.append(row.get(i, inputSchema(i))) + i += 1 + } + sb.append(ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")) + sb.toString() + } + outputStream.write(data.getBytes("utf-8")) } else { val writable = inputSerde.serialize( row.asInstanceOf[GenericInternalRow].values, inputSoi) - prepareWritable(writable).write(dataOutputStream) + + if (scriptInputWriter != null) { + scriptInputWriter.write(writable) + } else { + prepareWritable(writable, ioschema.outputSerdeProps).write(dataOutputStream) + } } } outputStream.close() @@ -287,6 +317,8 @@ case class HiveScriptIOSchema ( outputSerdeClass: Option[String], inputSerdeProps: Seq[(String, String)], outputSerdeProps: Seq[(String, String)], + recordReaderClass: Option[String], + recordWriterClass: Option[String], schemaLess: Boolean) extends ScriptInputOutputSchema with HiveInspectors { private val defaultFormat = Map( @@ -304,7 +336,7 @@ case class HiveScriptIOSchema ( val serde = initSerDe(serdeClass, columns, columnTypes, inputSerdeProps) val fieldObjectInspectors = columnTypes.map(toInspector) val objectInspector = ObjectInspectorFactory - .getStandardStructObjectInspector(columns, fieldObjectInspectors) + .getStandardStructObjectInspector(columns.asJava, fieldObjectInspectors.asJava) .asInstanceOf[ObjectInspector] (serde, objectInspector) } @@ -320,18 +352,8 @@ case class HiveScriptIOSchema ( } private def parseAttrs(attrs: Seq[Expression]): (Seq[String], Seq[DataType]) = { - val columns = attrs.map { - case aref: AttributeReference => aref.name - case e: NamedExpression => e.name - case _ => null - } - - val columnTypes = attrs.map { - case aref: AttributeReference => aref.dataType - case e: NamedExpression => e.dataType - case _ => null - } - + val columns = attrs.zipWithIndex.map(e => s"${e._1.prettyName}_${e._2}") + val columnTypes = attrs.map(_.dataType) (columns, columnTypes) } @@ -345,15 +367,33 @@ case class HiveScriptIOSchema ( val columnTypesNames = columnTypes.map(_.toTypeInfo.getTypeName()).mkString(",") - var propsMap = serdeProps.map(kv => { - (kv._1.split("'")(1), kv._2.split("'")(1)) - }).toMap + 
(serdeConstants.LIST_COLUMNS -> columns.mkString(",")) + var propsMap = serdeProps.toMap + (serdeConstants.LIST_COLUMNS -> columns.mkString(",")) propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames) val properties = new Properties() - properties.putAll(propsMap) + properties.putAll(propsMap.asJava) serde.initialize(null, properties) serde } + + def recordReader( + inputStream: InputStream, + conf: Configuration): Option[RecordReader] = { + recordReaderClass.map { klass => + val instance = Utils.classForName(klass).newInstance().asInstanceOf[RecordReader] + val props = new Properties() + props.putAll(outputSerdeProps.toMap.asJava) + instance.initialize(inputStream, conf, props) + instance + } + } + + def recordWriter(outputStream: OutputStream, conf: Configuration): Option[RecordWriter] = { + recordWriterClass.map { klass => + val instance = Utils.classForName(klass).newInstance().asInstanceOf[RecordWriter] + instance.initialize(outputStream, conf) + instance + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index a47f9a4feb21..94210a5394f9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -17,17 +17,19 @@ package org.apache.spark.sql.hive.execution +import org.apache.hadoop.hive.metastore.MetaStoreUtils + import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.RunnableCommand -import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation} +import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils /** * Analyzes the given table in the current database to generate statistics, which will be @@ -69,7 +71,7 @@ case class DropTable( } hiveContext.invalidateTable(tableName) hiveContext.runSqlHive(s"DROP TABLE $ifExistsClause$tableName") - hiveContext.catalog.unregisterTable(Seq(tableName)) + hiveContext.catalog.unregisterTable(TableIdentifier(tableName)) Seq.empty[Row] } } @@ -84,26 +86,7 @@ case class AddJar(path: String) extends RunnableCommand { } override def run(sqlContext: SQLContext): Seq[Row] = { - val hiveContext = sqlContext.asInstanceOf[HiveContext] - val currentClassLoader = Utils.getContextOrSparkClassLoader - - // Add jar to current context - val jarURL = new java.io.File(path).toURL - val newClassLoader = new java.net.URLClassLoader(Array(jarURL), currentClassLoader) - Thread.currentThread.setContextClassLoader(newClassLoader) - // We need to explicitly set the class loader associated with the conf in executionHive's - // state because this class loader will be used as the context class loader of the current - // thread to execute any Hive command. 
- // We cannot use `org.apache.hadoop.hive.ql.metadata.Hive.get().getConf()` because Hive.get() - // returns the value of a thread local variable and its HiveConf may not be the HiveConf - // associated with `executionHive.state` (for example, HiveContext is created in one thread - // and then add jar is called from another thread). - hiveContext.executionHive.state.getConf.setClassLoader(newClassLoader) - // Add jar to isolated hive (metadataHive) class loader. - hiveContext.runSqlHive(s"ADD JAR $path") - - // Add jar to executors - hiveContext.sparkContext.addJar(path) + sqlContext.addJar(path) Seq(Row(0)) } @@ -122,7 +105,7 @@ case class AddFile(path: String) extends RunnableCommand { private[hive] case class CreateMetastoreDataSource( - tableName: String, + tableIdent: TableIdentifier, userSpecifiedSchema: Option[StructType], provider: String, options: Map[String, String], @@ -130,9 +113,24 @@ case class CreateMetastoreDataSource( managedIfNoPath: Boolean) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { + // Since we are saving metadata to metastore, we need to check if metastore supports + // the table name and database name we have for this query. MetaStoreUtils.validateName + // is the method used by Hive to check if a table name or a database name is valid for + // the metastore. + if (!MetaStoreUtils.validateName(tableIdent.table)) { + throw new AnalysisException(s"Table name ${tableIdent.table} is not a valid name for " + + s"metastore. Metastore only accepts table name containing characters, numbers and _.") + } + if (tableIdent.database.isDefined && !MetaStoreUtils.validateName(tableIdent.database.get)) { + throw new AnalysisException(s"Database name ${tableIdent.database.get} is not a valid name " + + s"for metastore. Metastore only accepts database name containing " + + s"characters, numbers and _.") + } + + val tableName = tableIdent.unquotedString val hiveContext = sqlContext.asInstanceOf[HiveContext] - if (hiveContext.catalog.tableExists(tableName :: Nil)) { + if (hiveContext.catalog.tableExists(tableIdent)) { if (allowExisting) { return Seq.empty[Row] } else { @@ -144,13 +142,13 @@ case class CreateMetastoreDataSource( val optionsWithPath = if (!options.contains("path") && managedIfNoPath) { isExternal = false - options + ("path" -> hiveContext.catalog.hiveDefaultTableFilePath(tableName)) + options + ("path" -> hiveContext.catalog.hiveDefaultTableFilePath(tableIdent)) } else { options } hiveContext.catalog.createDataSourceTable( - tableName, + tableIdent, userSpecifiedSchema, Array.empty[String], provider, @@ -163,7 +161,7 @@ case class CreateMetastoreDataSource( private[hive] case class CreateMetastoreDataSourceAsSelect( - tableName: String, + tableIdent: TableIdentifier, provider: String, partitionColumns: Array[String], mode: SaveMode, @@ -171,19 +169,34 @@ case class CreateMetastoreDataSourceAsSelect( query: LogicalPlan) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { + // Since we are saving metadata to metastore, we need to check if metastore supports + // the table name and database name we have for this query. MetaStoreUtils.validateName + // is the method used by Hive to check if a table name or a database name is valid for + // the metastore. + if (!MetaStoreUtils.validateName(tableIdent.table)) { + throw new AnalysisException(s"Table name ${tableIdent.table} is not a valid name for " + + s"metastore. 
Metastore only accepts table name containing characters, numbers and _.") + } + if (tableIdent.database.isDefined && !MetaStoreUtils.validateName(tableIdent.database.get)) { + throw new AnalysisException(s"Database name ${tableIdent.database.get} is not a valid name " + + s"for metastore. Metastore only accepts database name containing " + + s"characters, numbers and _.") + } + + val tableName = tableIdent.unquotedString val hiveContext = sqlContext.asInstanceOf[HiveContext] var createMetastoreTable = false var isExternal = true val optionsWithPath = if (!options.contains("path")) { isExternal = false - options + ("path" -> hiveContext.catalog.hiveDefaultTableFilePath(tableName)) + options + ("path" -> hiveContext.catalog.hiveDefaultTableFilePath(tableIdent)) } else { options } var existingSchema = None: Option[StructType] - if (sqlContext.catalog.tableExists(Seq(tableName))) { + if (sqlContext.catalog.tableExists(tableIdent)) { // Check if we need to throw an exception or just return. mode match { case SaveMode.ErrorIfExists => @@ -200,8 +213,8 @@ case class CreateMetastoreDataSourceAsSelect( val resolved = ResolvedDataSource( sqlContext, Some(query.schema.asNullable), partitionColumns, provider, optionsWithPath) val createdRelation = LogicalRelation(resolved.relation) - EliminateSubQueries(sqlContext.table(tableName).logicalPlan) match { - case l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation) => + EliminateSubQueries(sqlContext.catalog.lookupRelation(tableIdent)) match { + case l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation, _) => if (l.relation != createdRelation.relation) { val errorDescription = s"Cannot append to table $tableName because the resolved relation does not " + @@ -249,7 +262,7 @@ case class CreateMetastoreDataSourceAsSelect( // the schema of df). It is important since the nullability may be changed by the relation // provider (for example, see org.apache.spark.sql.parquet.DefaultSource). hiveContext.catalog.createDataSourceTable( - tableName, + tableIdent, Some(resolved.relation.schema), partitionColumns, provider, @@ -258,7 +271,7 @@ case class CreateMetastoreDataSourceAsSelect( } // Refresh the cache of the table in the catalog. 
- hiveContext.refreshTable(tableName) + hiveContext.catalog.refreshTable(tableIdent) Seq.empty[Row] } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index abe5c6900313..a1787fc92d6d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive import scala.collection.mutable.ArrayBuffer -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.util.Try import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ConstantObjectInspector} @@ -37,41 +37,72 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.hive.HiveShim._ +import org.apache.spark.sql.hive.client.ClientWrapper import org.apache.spark.sql.types._ -private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry) +private[hive] class HiveFunctionRegistry( + underlying: analysis.FunctionRegistry, + executionHive: ClientWrapper) extends analysis.FunctionRegistry with HiveInspectors { - def getFunctionInfo(name: String): FunctionInfo = FunctionRegistry.getFunctionInfo(name) + def getFunctionInfo(name: String): FunctionInfo = { + // Hive Registry need current database to lookup function + // TODO: the current database of executionHive should be consistent with metadataHive + executionHive.withHiveState { + FunctionRegistry.getFunctionInfo(name) + } + } override def lookupFunction(name: String, children: Seq[Expression]): Expression = { Try(underlying.lookupFunction(name, children)).getOrElse { // We only look it up to see if it exists, but do not include it in the HiveUDF since it is // not always serializable. 
val functionInfo: FunctionInfo = - Option(FunctionRegistry.getFunctionInfo(name.toLowerCase)).getOrElse( + Option(getFunctionInfo(name.toLowerCase)).getOrElse( throw new AnalysisException(s"undefined function $name")) val functionClassName = functionInfo.getFunctionClass.getName - if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveSimpleUDF(new HiveFunctionWrapper(functionClassName), children) - } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUDF(new HiveFunctionWrapper(functionClassName), children) - } else if ( - classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUDAF(new HiveFunctionWrapper(functionClassName), children) - } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveUDAF(new HiveFunctionWrapper(functionClassName), children) - } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUDTF(new HiveFunctionWrapper(functionClassName), children) - } else { - sys.error(s"No handler for udf ${functionInfo.getFunctionClass}") + // When we instantiate hive UDF wrapper class, we may throw exception if the input expressions + // don't satisfy the hive UDF, such as type mismatch, input number mismatch, etc. Here we + // catch the exception and throw AnalysisException instead. + try { + if (classOf[GenericUDFMacro].isAssignableFrom(functionInfo.getFunctionClass)) { + HiveGenericUDF( + new HiveFunctionWrapper(functionClassName, functionInfo.getGenericUDF), children) + } else if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) { + HiveSimpleUDF(new HiveFunctionWrapper(functionClassName), children) + } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { + HiveGenericUDF(new HiveFunctionWrapper(functionClassName), children) + } else if ( + classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) { + HiveUDAFFunction(new HiveFunctionWrapper(functionClassName), children) + } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) { + HiveUDAFFunction( + new HiveFunctionWrapper(functionClassName), children, isUDAFBridgeRequired = true) + } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { + val udtf = HiveGenericUDTF(new HiveFunctionWrapper(functionClassName), children) + udtf.elementTypes // Force it to check input data types. + udtf + } else { + throw new AnalysisException(s"No handler for udf ${functionInfo.getFunctionClass}") + } + } catch { + case analysisException: AnalysisException => + // If the exception is an AnalysisException, just throw it. + throw analysisException + case throwable: Throwable => + // If there is any other error, we throw an AnalysisException. + val errorMessage = s"No handler for Hive udf ${functionInfo.getFunctionClass} " + + s"because: ${throwable.getMessage}." + throw new AnalysisException(errorMessage) } } } @@ -81,15 +112,14 @@ private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry) /* List all of the registered function names. */ override def listFunction(): Seq[String] = { - val a = FunctionRegistry.getFunctionNames ++ underlying.listFunction() - a.toList.sorted + (FunctionRegistry.getFunctionNames.asScala ++ underlying.listFunction()).toList.sorted } /* Get the class of the registered function by specified name. 
*/ override def lookupFunction(name: String): Option[ExpressionInfo] = { underlying.lookupFunction(name).orElse( Try { - val info = FunctionRegistry.getFunctionInfo(name) + val info = getFunctionInfo(name) val annotation = info.getFunctionClass.getAnnotation(classOf[Description]) if (annotation != null) { Some(new ExpressionInfo( @@ -98,7 +128,11 @@ private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry) annotation.value(), annotation.extended())) } else { - None + Some(new ExpressionInfo( + info.getFunctionClass.getCanonicalName, + name, + null, + null)) } }.getOrElse(None)) } @@ -116,7 +150,7 @@ private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, childre @transient private lazy val method = - function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo)) + function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo).asJava) @transient private lazy val arguments = children.map(toInspector).toArray @@ -133,8 +167,7 @@ private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, childre @transient private lazy val conversionHelper = new ConversionHelper(method, arguments) - @transient - lazy val dataType = javaClassToDataType(method.getReturnType) + override val dataType = javaClassToDataType(method.getReturnType) @transient lazy val returnInspector = ObjectInspectorFactory.getReflectionObjectInspector( @@ -205,7 +238,7 @@ private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, childr new DeferredObjectAdapter(inspect, child.dataType) }.toArray[DeferredObject] - lazy val dataType: DataType = inspectorToDataType(returnInspector) + override val dataType: DataType = inspectorToDataType(returnInspector) override def eval(input: InternalRow): Any = { returnInspector // Make sure initialized. @@ -227,286 +260,6 @@ private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, childr } } -/** - * Resolves [[UnresolvedWindowFunction]] to [[HiveWindowFunction]]. - */ -private[spark] object ResolveHiveWindowFunction extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case p: LogicalPlan if !p.childrenResolved => p - - // We are resolving WindowExpressions at here. When we get here, we have already - // replaced those WindowSpecReferences. - case p: LogicalPlan => - p transformExpressions { - case WindowExpression( - UnresolvedWindowFunction(name, children), - windowSpec: WindowSpecDefinition) => - // First, let's find the window function info. - val windowFunctionInfo: WindowFunctionInfo = - Option(FunctionRegistry.getWindowFunctionInfo(name.toLowerCase)).getOrElse( - throw new AnalysisException(s"Couldn't find window function $name")) - - // Get the class of this function. - // In Hive 0.12, there is no windowFunctionInfo.getFunctionClass. So, we use - // windowFunctionInfo.getfInfo().getFunctionClass for both Hive 0.13 and Hive 0.13.1. - val functionClass = windowFunctionInfo.getfInfo().getFunctionClass - val newChildren = - // Rank(), DENSE_RANK(), CUME_DIST(), and PERCENT_RANK() do not take explicit - // input parameters and requires implicit parameters, which - // are expressions in Order By clause. - if (classOf[GenericUDAFRank].isAssignableFrom(functionClass)) { - if (children.nonEmpty) { - throw new AnalysisException(s"$name does not take input parameters.") - } - windowSpec.orderSpec.map(_.child) - } else { - children - } - - // If the class is UDAF, we need to use UDAFBridge. 
- val isUDAFBridgeRequired = - if (classOf[UDAF].isAssignableFrom(functionClass)) { - true - } else { - false - } - - // Create the HiveWindowFunction. For the meaning of isPivotResult, see the doc of - // HiveWindowFunction. - val windowFunction = - HiveWindowFunction( - new HiveFunctionWrapper(functionClass.getName), - windowFunctionInfo.isPivotResult, - isUDAFBridgeRequired, - newChildren) - - // Second, check if the specified window function can accept window definition. - windowSpec.frameSpecification match { - case frame: SpecifiedWindowFrame if !windowFunctionInfo.isSupportsWindow => - // This Hive window function does not support user-speficied window frame. - throw new AnalysisException( - s"Window function $name does not take a frame specification.") - case frame: SpecifiedWindowFrame if windowFunctionInfo.isSupportsWindow && - windowFunctionInfo.isPivotResult => - // These two should not be true at the same time when a window frame is defined. - // If so, throw an exception. - throw new AnalysisException(s"Could not handle Hive window function $name because " + - s"it supports both a user specified window frame and pivot result.") - case _ => // OK - } - // Resolve those UnspecifiedWindowFrame because the physical Window operator still needs - // a window frame specification to work. - val newWindowSpec = windowSpec.frameSpecification match { - case UnspecifiedFrame => - val newWindowFrame = - SpecifiedWindowFrame.defaultWindowFrame( - windowSpec.orderSpec.nonEmpty, - windowFunctionInfo.isSupportsWindow) - WindowSpecDefinition(windowSpec.partitionSpec, windowSpec.orderSpec, newWindowFrame) - case _ => windowSpec - } - - // Finally, we create a WindowExpression with the resolved window function and - // specified window spec. - WindowExpression(windowFunction, newWindowSpec) - } - } -} - -/** - * A [[WindowFunction]] implementation wrapping Hive's window function. - * @param funcWrapper The wrapper for the Hive Window Function. - * @param pivotResult If it is true, the Hive function will return a list of values representing - * the values of the added columns. Otherwise, a single value is returned for - * current row. - * @param isUDAFBridgeRequired If it is true, the function returned by functionWrapper's - * createFunction is UDAF, we need to use GenericUDAFBridge to wrap - * it as a GenericUDAFResolver2. - * @param children Input parameters. - */ -private[hive] case class HiveWindowFunction( - funcWrapper: HiveFunctionWrapper, - pivotResult: Boolean, - isUDAFBridgeRequired: Boolean, - children: Seq[Expression]) extends WindowFunction - with HiveInspectors with Unevaluable { - - // Hive window functions are based on GenericUDAFResolver2. - type UDFType = GenericUDAFResolver2 - - @transient - protected lazy val resolver: GenericUDAFResolver2 = - if (isUDAFBridgeRequired) { - new GenericUDAFBridge(funcWrapper.createFunction[UDAF]()) - } else { - funcWrapper.createFunction[GenericUDAFResolver2]() - } - - @transient - protected lazy val inputInspectors = children.map(toInspector).toArray - - // The GenericUDAFEvaluator used to evaluate the window function. - @transient - protected lazy val evaluator: GenericUDAFEvaluator = { - val parameterInfo = new SimpleGenericUDAFParameterInfo(inputInspectors, false, false) - resolver.getEvaluator(parameterInfo) - } - - // The object inspector of values returned from the Hive window function. 
- @transient - protected lazy val returnInspector = { - evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, inputInspectors) - } - - override def dataType: DataType = - if (!pivotResult) { - inspectorToDataType(returnInspector) - } else { - // If pivotResult is true, we should take the element type out as the data type of this - // function. - inspectorToDataType(returnInspector) match { - case ArrayType(dt, _) => dt - case _ => - sys.error( - s"error resolve the data type of window function ${funcWrapper.functionClassName}") - } - } - - override def nullable: Boolean = true - - @transient - lazy val inputProjection = new InterpretedProjection(children) - - @transient - private var hiveEvaluatorBuffer: AggregationBuffer = _ - // Output buffer. - private var outputBuffer: Any = _ - - @transient - private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray - - override def init(): Unit = { - evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, inputInspectors) - } - - // Reset the hiveEvaluatorBuffer and outputPosition - override def reset(): Unit = { - // We create a new aggregation buffer to workaround the bug in GenericUDAFRowNumber. - // Basically, GenericUDAFRowNumberEvaluator.reset calls RowNumberBuffer.init. - // However, RowNumberBuffer.init does not really reset this buffer. - hiveEvaluatorBuffer = evaluator.getNewAggregationBuffer - evaluator.reset(hiveEvaluatorBuffer) - } - - override def prepareInputParameters(input: InternalRow): AnyRef = { - wrap( - inputProjection(input), - inputInspectors, - new Array[AnyRef](children.length), - inputDataTypes) - } - - // Add input parameters for a single row. - override def update(input: AnyRef): Unit = { - evaluator.iterate(hiveEvaluatorBuffer, input.asInstanceOf[Array[AnyRef]]) - } - - override def batchUpdate(inputs: Array[AnyRef]): Unit = { - var i = 0 - while (i < inputs.length) { - evaluator.iterate(hiveEvaluatorBuffer, inputs(i).asInstanceOf[Array[AnyRef]]) - i += 1 - } - } - - override def evaluate(): Unit = { - outputBuffer = unwrap(evaluator.evaluate(hiveEvaluatorBuffer), returnInspector) - } - - override def get(index: Int): Any = { - if (!pivotResult) { - // if pivotResult is false, we will get a single value for all rows in the frame. - outputBuffer - } else { - // if pivotResult is true, we will get a ArrayData having the same size with the size - // of the window frame. At here, we will return the result at the position of - // index in the output buffer. 
- outputBuffer.asInstanceOf[ArrayData].get(index, dataType) - } - } - - override def toString: String = { - s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" - } - - override def newInstance(): WindowFunction = - new HiveWindowFunction(funcWrapper, pivotResult, isUDAFBridgeRequired, children) -} - -private[hive] case class HiveGenericUDAF( - funcWrapper: HiveFunctionWrapper, - children: Seq[Expression]) extends AggregateExpression1 - with HiveInspectors { - - type UDFType = AbstractGenericUDAFResolver - - @transient - protected lazy val resolver: AbstractGenericUDAFResolver = funcWrapper.createFunction() - - @transient - protected lazy val objectInspector = { - val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false) - resolver.getEvaluator(parameterInfo) - .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray) - } - - @transient - protected lazy val inspectors = children.map(toInspector) - - def dataType: DataType = inspectorToDataType(objectInspector) - - def nullable: Boolean = true - - override def toString: String = { - s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" - } - - def newInstance(): HiveUDAFFunction = new HiveUDAFFunction(funcWrapper, children, this) -} - -/** It is used as a wrapper for the hive functions which uses UDAF interface */ -private[hive] case class HiveUDAF( - funcWrapper: HiveFunctionWrapper, - children: Seq[Expression]) extends AggregateExpression1 - with HiveInspectors { - - type UDFType = UDAF - - @transient - protected lazy val resolver: AbstractGenericUDAFResolver = - new GenericUDAFBridge(funcWrapper.createFunction()) - - @transient - protected lazy val objectInspector = { - val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false) - resolver.getEvaluator(parameterInfo) - .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray) - } - - @transient - protected lazy val inspectors = children.map(toInspector) - - def dataType: DataType = inspectorToDataType(objectInspector) - - def nullable: Boolean = true - - override def toString: String = { - s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" - } - - def newInstance(): HiveUDAFFunction = new HiveUDAFFunction(funcWrapper, children, this, true) -} - /** * Converts a Hive Generic User Defined Table Generating Function (UDTF) to a * [[Generator]]. Note that the semantics of Generators do not allow @@ -542,8 +295,8 @@ private[hive] case class HiveGenericUDTF( @transient protected lazy val collector = new UDTFCollector - lazy val elementTypes = outputInspector.getAllStructFieldRefs.map { - field => (inspectorToDataType(field.getFieldObjectInspector), true) + override lazy val elementTypes = outputInspector.getAllStructFieldRefs.asScala.map { + field => (inspectorToDataType(field.getFieldObjectInspector), true, field.getFieldName) } @transient @@ -586,49 +339,94 @@ private[hive] case class HiveGenericUDTF( } } +/** + * Currently we don't support partial aggregation for queries using Hive UDAF, which may hurt + * performance a lot. 
+ */ private[hive] case class HiveUDAFFunction( funcWrapper: HiveFunctionWrapper, - exprs: Seq[Expression], - base: AggregateExpression1, - isUDAFBridgeRequired: Boolean = false) - extends AggregateFunction1 - with HiveInspectors { + children: Seq[Expression], + isUDAFBridgeRequired: Boolean = false, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) + extends ImperativeAggregate with HiveInspectors { + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) - def this() = this(null, null, null) + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) - private val resolver = + @transient + private lazy val resolver = if (isUDAFBridgeRequired) { new GenericUDAFBridge(funcWrapper.createFunction[UDAF]()) } else { funcWrapper.createFunction[AbstractGenericUDAFResolver]() } - private val inspectors = exprs.map(toInspector).toArray + @transient + private lazy val inspectors = children.map(toInspector).toArray - private val function = { + @transient + private lazy val functionAndInspector = { val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors, false, false) - resolver.getEvaluator(parameterInfo) + val f = resolver.getEvaluator(parameterInfo) + f -> f.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors) } - private val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors) + @transient + private lazy val function = functionAndInspector._1 + + @transient + private lazy val returnInspector = functionAndInspector._2 - private val buffer = - function.getNewAggregationBuffer + @transient + private[this] var buffer: GenericUDAFEvaluator.AggregationBuffer = _ override def eval(input: InternalRow): Any = unwrap(function.evaluate(buffer), returnInspector) @transient - val inputProjection = new InterpretedProjection(exprs) + private lazy val inputProjection = new InterpretedProjection(children) @transient - protected lazy val cached = new Array[AnyRef](exprs.length) + private lazy val cached = new Array[AnyRef](children.length) @transient - private lazy val inputDataTypes: Array[DataType] = exprs.map(_.dataType).toArray + private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray - def update(input: InternalRow): Unit = { + // Hive UDAF has its own buffer, so we don't need to occupy a slot in the aggregation + // buffer for it. + override def aggBufferSchema: StructType = StructType(Nil) + + override def update(_buffer: MutableRow, input: InternalRow): Unit = { val inputs = inputProjection(input) function.iterate(buffer, wrap(inputs, inspectors, cached, inputDataTypes)) } + + override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { + throw new UnsupportedOperationException( + "Hive UDAF doesn't support partial aggregate") + } + + override def initialize(_buffer: MutableRow): Unit = { + buffer = function.getNewAggregationBuffer + } + + override val aggBufferAttributes: Seq[AttributeReference] = Nil + + // Note: although this simply copies aggBufferAttributes, this common code can not be placed + // in the superclass because that will lead to initialization ordering issues. + override val inputAggBufferAttributes: Seq[AttributeReference] = Nil + + // We rely on Hive to check the input data types, so use `AnyDataType` here to bypass our + // catalyst type checking framework. 
+ override def inputTypes: Seq[AbstractDataType] = children.map(_ => AnyDataType) + + override def nullable: Boolean = true + + override def supportsPartial: Boolean = false + + override val dataType: DataType = inspectorToDataType(returnInspector) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index 8850e060d2a7..93c016b6c6c7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -26,13 +26,12 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.exec.{FileSinkOperator, Utilities} import org.apache.hadoop.hive.ql.io.{HiveFileFormatUtils, HiveOutputFormat} -import org.apache.hadoop.hive.ql.plan.{PlanUtils, TableDesc} +import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred._ import org.apache.hadoop.hive.common.FileUtils import org.apache.spark.mapred.SparkHadoopMapRedUtil -import org.apache.spark.sql.Row import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -45,7 +44,7 @@ import org.apache.spark.util.SerializableJobConf * It is based on [[SparkHadoopWriter]]. */ private[hive] class SparkHiveWriterContainer( - @transient jobConf: JobConf, + jobConf: JobConf, fileSinkConf: FileSinkDesc) extends Logging with SparkHadoopMapRedUtil @@ -56,7 +55,7 @@ private[hive] class SparkHiveWriterContainer( // Add table properties from storage handler to jobConf, so any custom storage // handler settings can be set to jobConf if (tableDesc != null) { - PlanUtils.configureOutputJobPropertiesForStorageHandler(tableDesc) + HiveTableUtil.configureJobPropertiesForStorageHandler(tableDesc, jobConf, false) Utilities.copyTableJobPropertiesToConf(tableDesc, jobConf) } protected val conf = new SerializableJobConf(jobConf) @@ -122,7 +121,7 @@ private[hive] class SparkHiveWriterContainer( } protected def commit() { - SparkHadoopMapRedUtil.commitTask(committer, taskContext, jobID, splitID, attemptID) + SparkHadoopMapRedUtil.commitTask(committer, taskContext, jobID, splitID) } private def setIDs(jobId: Int, splitId: Int, attemptId: Int) { @@ -163,7 +162,7 @@ private[spark] object SparkHiveDynamicPartitionWriterContainer { } private[spark] class SparkHiveDynamicPartitionWriterContainer( - @transient jobConf: JobConf, + jobConf: JobConf, fileSinkConf: FileSinkDesc, dynamicPartColNames: Array[String]) extends SparkHiveWriterContainer(jobConf, fileSinkConf) { @@ -171,7 +170,7 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( import SparkHiveDynamicPartitionWriterContainer._ private val defaultPartName = jobConf.get( - ConfVars.DEFAULTPARTITIONNAME.varname, ConfVars.DEFAULTPARTITIONNAME.defaultVal) + ConfVars.DEFAULTPARTITIONNAME.varname, ConfVars.DEFAULTPARTITIONNAME.defaultStrVal) @transient private var writers: mutable.HashMap[String, FileSinkOperator.RecordWriter] = _ @@ -194,10 +193,10 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( // Better solution is to add a step similar to what Hive FileSinkOperator.jobCloseOp does: // calling something like Utilities.mvFileToFinalPath to cleanup the output directory and then // load it with loadDynamicPartitions/loadPartition/loadTable. 
- val oldMarker = jobConf.getBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, true) - jobConf.setBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, false) + val oldMarker = conf.value.getBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, true) + conf.value.setBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, false) super.commitJob() - jobConf.setBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, oldMarker) + conf.value.setBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, oldMarker) } override def getLocalFileWriter(row: InternalRow, schema: StructType) @@ -211,18 +210,18 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( } } - val dynamicPartPath = dynamicPartColNames - .zip(row.toSeq.takeRight(dynamicPartColNames.length)) - .map { case (col, rawVal) => - val string = if (rawVal == null) null else convertToHiveRawString(col, rawVal) - val colString = - if (string == null || string.isEmpty) { - defaultPartName - } else { - FileUtils.escapePathName(string, defaultPartName) - } - s"/$col=$colString" - }.mkString + val nonDynamicPartLen = row.numFields - dynamicPartColNames.length + val dynamicPartPath = dynamicPartColNames.zipWithIndex.map { case (colName, i) => + val rawVal = row.get(nonDynamicPartLen + i, schema(colName).dataType) + val string = if (rawVal == null) null else convertToHiveRawString(colName, rawVal) + val colString = + if (string == null || string.isEmpty) { + defaultPartName + } else { + FileUtils.escapePathName(string, defaultPartName) + } + s"/$colName=$colString" + }.mkString def newWriter(): FileSinkOperator.RecordWriter = { val newFileSinkDesc = new FileSinkDesc( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala index ddd5d24717ad..27193f54d3a9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive.orc import org.apache.hadoop.hive.common.`type`.{HiveChar, HiveDecimal, HiveVarchar} -import org.apache.hadoop.hive.ql.io.sarg.SearchArgument +import org.apache.hadoop.hive.ql.io.sarg.{SearchArgumentFactory, SearchArgument} import org.apache.hadoop.hive.ql.io.sarg.SearchArgument.Builder import org.apache.hadoop.hive.serde2.io.DateWritable @@ -31,15 +31,17 @@ import org.apache.spark.sql.sources._ * and cannot be used anymore. */ private[orc] object OrcFilters extends Logging { - def createFilter(expr: Array[Filter]): Option[SearchArgument] = { - expr.reduceOption(And).flatMap { conjunction => - val builder = SearchArgument.FACTORY.newBuilder() - buildSearchArgument(conjunction, builder).map(_.build()) - } + def createFilter(filters: Array[Filter]): Option[SearchArgument] = { + for { + // Combines all filters with `And`s to produce a single conjunction predicate + conjunction <- filters.reduceOption(And) + // Then tries to build a single ORC `SearchArgument` for the conjunction predicate + builder <- buildSearchArgument(conjunction, SearchArgumentFactory.newBuilder()) + } yield builder.build() } private def buildSearchArgument(expression: Filter, builder: Builder): Option[Builder] = { - def newBuilder = SearchArgument.FACTORY.newBuilder() + def newBuilder = SearchArgumentFactory.newBuilder() def isSearchableLiteral(value: Any): Boolean = value match { // These are types recognized by the `SearchArgumentImpl.BuilderImpl.boxLiteral()` method. 
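The rewritten createFilter above first folds all pushed-down filters into one conjunction with reduceOption(And) and then chains the two Option-producing steps in a for-comprehension, so the whole push-down is silently skipped (None) as soon as any predicate cannot be converted. A hedged, Hive-free sketch of that Option-chaining shape; the Filter/EqualTo/And/IsNull stand-ins and the string-building "builder" below are simplified placeholders, not the actual Spark or Hive types:

    object OptionChainingSketch {
      // Simplified stand-ins for the org.apache.spark.sql.sources filter classes.
      sealed trait Filter
      case class EqualTo(attr: String, value: Any) extends Filter
      case class And(left: Filter, right: Filter) extends Filter
      case class IsNull(attr: String) extends Filter

      // Stand-in for buildSearchArgument: converting a filter may fail (None),
      // and a failure anywhere poisons the whole conjunction.
      def build(filter: Filter, acc: List[String]): Option[List[String]] = filter match {
        case EqualTo(attr, value) => Some(s"$attr = $value" :: acc)
        case And(left, right)     => build(left, acc).flatMap(lhs => build(right, lhs))
        case _                    => None // unsupported predicate, e.g. IsNull here
      }

      def createFilter(filters: Array[Filter]): Option[String] =
        for {
          // Combine all filters into a single conjunction, as the patch does.
          conjunction <- filters.reduceOption(And)
          // Then try to convert the whole conjunction in one go.
          parts <- build(conjunction, Nil)
        } yield parts.reverse.mkString(" AND ")

      def main(args: Array[String]): Unit = {
        println(createFilter(Array[Filter](EqualTo("a", 1), EqualTo("b", 2)))) // Some(a = 1 AND b = 2)
        println(createFilter(Array[Filter](EqualTo("a", 1), IsNull("b"))))     // None: one leg unsupported
        println(createFilter(Array.empty[Filter]))                             // None: nothing to push down
      }
    }

Compared with the old reduceOption(...).flatMap(...) nesting, the for-comprehension keeps the two failure points (no filters at all, or an unconvertible predicate) on separate, commented lines while producing the same Option-valued result.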
@@ -102,41 +104,32 @@ private[orc] object OrcFilters extends Logging { negate <- buildSearchArgument(child, builder.startNot()) } yield negate.end() - case EqualTo(attribute, value) => - Option(value) - .filter(isSearchableLiteral) - .map(builder.equals(attribute, _)) + case EqualTo(attribute, value) if isSearchableLiteral(value) => + Some(builder.startAnd().equals(attribute, value).end()) + + case EqualNullSafe(attribute, value) if isSearchableLiteral(value) => + Some(builder.startAnd().nullSafeEquals(attribute, value).end()) - case LessThan(attribute, value) => - Option(value) - .filter(isSearchableLiteral) - .map(builder.lessThan(attribute, _)) + case LessThan(attribute, value) if isSearchableLiteral(value) => + Some(builder.startAnd().lessThan(attribute, value).end()) - case LessThanOrEqual(attribute, value) => - Option(value) - .filter(isSearchableLiteral) - .map(builder.lessThanEquals(attribute, _)) + case LessThanOrEqual(attribute, value) if isSearchableLiteral(value) => + Some(builder.startAnd().lessThanEquals(attribute, value).end()) - case GreaterThan(attribute, value) => - Option(value) - .filter(isSearchableLiteral) - .map(builder.startNot().lessThanEquals(attribute, _).end()) + case GreaterThan(attribute, value) if isSearchableLiteral(value) => + Some(builder.startNot().lessThanEquals(attribute, value).end()) - case GreaterThanOrEqual(attribute, value) => - Option(value) - .filter(isSearchableLiteral) - .map(builder.startNot().lessThan(attribute, _).end()) + case GreaterThanOrEqual(attribute, value) if isSearchableLiteral(value) => + Some(builder.startNot().lessThan(attribute, value).end()) case IsNull(attribute) => - Some(builder.isNull(attribute)) + Some(builder.startAnd().isNull(attribute).end()) case IsNotNull(attribute) => Some(builder.startNot().isNull(attribute).end()) - case In(attribute, values) => - Option(values) - .filter(_.forall(isSearchableLiteral)) - .map(builder.in(attribute, _)) + case In(attribute, values) if values.forall(isSearchableLiteral) => + Some(builder.startAnd().in(attribute, values.map(_.asInstanceOf[AnyRef]): _*).end()) case _ => None } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 6fa599734892..1136670b7a0e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -23,16 +23,17 @@ import com.google.common.base.Objects import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hadoop.hive.ql.io.orc.{OrcInputFormat, OrcOutputFormat, OrcSerde, OrcSplit} -import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector -import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils +import org.apache.hadoop.hive.ql.io.orc.{OrcInputFormat, OrcOutputFormat, OrcSerde, OrcSplit, OrcStruct} +import org.apache.hadoop.hive.serde2.objectinspector.SettableStructObjectInspector +import org.apache.hadoop.hive.serde2.typeinfo.{StructTypeInfo, TypeInfoUtils} import org.apache.hadoop.io.{NullWritable, Writable} import org.apache.hadoop.mapred.{InputFormat => MapRedInputFormat, JobConf, OutputFormat => MapRedOutputFormat, RecordWriter, Reporter} import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.Logging -import 
org.apache.spark.annotation.DeveloperApi +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.{HadoopRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow @@ -44,11 +45,11 @@ import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.util.SerializableConfiguration -/* Implicit conversions */ -import scala.collection.JavaConversions._ +private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { -private[sql] class DefaultSource extends HadoopFsRelationProvider { - def createRelation( + override def shortName(): String = "orc" + + override def createRelation( sqlContext: SQLContext, paths: Array[String], dataSchema: Option[StructType], @@ -66,7 +67,7 @@ private[orc] class OrcOutputWriter( path: String, dataSchema: StructType, context: TaskAttemptContext) - extends OutputWriterInternal with SparkHadoopMapRedUtil with HiveInspectors { + extends OutputWriter with SparkHadoopMapRedUtil with HiveInspectors { private val serializer = { val table = new Properties() @@ -76,7 +77,8 @@ private[orc] class OrcOutputWriter( }.mkString(":")) val serde = new OrcSerde - serde.initialize(context.getConfiguration, table) + val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + serde.initialize(configuration, table) serde } @@ -86,20 +88,10 @@ private[orc] class OrcOutputWriter( TypeInfoUtils.getTypeInfoFromTypeString( HiveMetastoreTypes.toMetastoreType(dataSchema)) - TypeInfoUtils - .getStandardJavaObjectInspectorFromTypeInfo(typeInfo) - .asInstanceOf[StructObjectInspector] + OrcStruct.createObjectInspector(typeInfo.asInstanceOf[StructTypeInfo]) + .asInstanceOf[SettableStructObjectInspector] } - // Used to hold temporary `Writable` fields of the next row to be written. - private val reusableOutputBuffer = new Array[Any](dataSchema.length) - - // Used to convert Catalyst values into Hadoop `Writable`s. - private val wrappers = structOI.getAllStructFieldRefs.zip(dataSchema.fields.map(_.dataType)) - .map { case (ref, dt) => - wrapperFor(ref.getFieldObjectInspector, dt) - }.toArray - // `OrcRecordWriter.close()` creates an empty file if no rows are written at all. We use this // flag to decide whether `OrcRecordWriter.close()` needs to be called. 
private var recordWriterInstantiated = false @@ -107,9 +99,10 @@ private[orc] class OrcOutputWriter( private lazy val recordWriter: RecordWriter[NullWritable, Writable] = { recordWriterInstantiated = true - val conf = context.getConfiguration + val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(context) val uniqueWriteJobId = conf.get("spark.sql.sources.writeJobUUID") - val partition = context.getTaskAttemptID.getTaskID.getId + val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) + val partition = taskAttemptId.getTaskID.getId val filename = f"part-r-$partition%05d-$uniqueWriteJobId.orc" new OrcOutputFormat().getRecordWriter( @@ -120,16 +113,34 @@ private[orc] class OrcOutputWriter( ).asInstanceOf[RecordWriter[NullWritable, Writable]] } - override def writeInternal(row: InternalRow): Unit = { + override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") + + private def wrapOrcStruct( + struct: OrcStruct, + oi: SettableStructObjectInspector, + row: InternalRow): Unit = { + val fieldRefs = oi.getAllStructFieldRefs var i = 0 - while (i < row.numFields) { - reusableOutputBuffer(i) = wrappers(i)(row.get(i, dataSchema(i).dataType)) + while (i < fieldRefs.size) { + oi.setStructFieldData( + struct, + fieldRefs.get(i), + wrap( + row.get(i, dataSchema(i).dataType), + fieldRefs.get(i).getFieldObjectInspector, + dataSchema(i).dataType)) i += 1 } + } + + val cachedOrcStruct = structOI.create().asInstanceOf[OrcStruct] + + override protected[sql] def writeInternal(row: InternalRow): Unit = { + wrapOrcStruct(cachedOrcStruct, structOI, row) recordWriter.write( NullWritable.get(), - serializer.serialize(reusableOutputBuffer, structOI)) + serializer.serialize(cachedOrcStruct, structOI)) } override def close(): Unit = { @@ -139,7 +150,6 @@ private[orc] class OrcOutputWriter( } } -@DeveloperApi private[sql] class OrcRelation( override val paths: Array[String], maybeDataSchema: Option[StructType], @@ -147,7 +157,7 @@ private[sql] class OrcRelation( override val userDefinedPartitionColumns: Option[StructType], parameters: Map[String, String])( @transient val sqlContext: SQLContext) - extends HadoopFsRelation(maybePartitionSpec) + extends HadoopFsRelation(maybePartitionSpec, parameters) with Logging { private[sql] def this( @@ -188,16 +198,17 @@ private[sql] class OrcRelation( partitionColumns) } - override def buildScan( + override private[sql] def buildInternalScan( requiredColumns: Array[String], filters: Array[Filter], - inputPaths: Array[FileStatus]): RDD[Row] = { + inputPaths: Array[FileStatus], + broadcastedConf: Broadcast[SerializableConfiguration]): RDD[InternalRow] = { val output = StructType(requiredColumns.map(dataSchema(_))).toAttributes - OrcTableScan(output, this, filters, inputPaths).execute().asInstanceOf[RDD[Row]] + OrcTableScan(output, this, filters, inputPaths).execute() } override def prepareJobForWrite(job: Job): OutputWriterFactory = { - job.getConfiguration match { + SparkHadoopUtil.get.getConfigurationFromJobContext(job) match { case conf: JobConf => conf.setOutputFormat(classOf[OrcOutputFormat]) case conf => @@ -242,18 +253,19 @@ private[orc] case class OrcTableScan( path: String, conf: Configuration, iterator: Iterator[Writable], - nonPartitionKeyAttrs: Seq[(Attribute, Int)], - mutableRow: MutableRow): Iterator[InternalRow] = { + nonPartitionKeyAttrs: Seq[Attribute]): Iterator[InternalRow] = { val deserializer = new OrcSerde val maybeStructOI = OrcFileOperator.getObjectInspector(path, Some(conf)) + val 
mutableRow = new SpecificMutableRow(attributes.map(_.dataType)) + val unsafeProjection = UnsafeProjection.create(StructType.fromAttributes(nonPartitionKeyAttrs)) // SPARK-8501: ORC writes an empty schema ("struct<>") to an ORC file if the file contains zero // rows, and thus couldn't give a proper ObjectInspector. In this case we just return an empty // partition since we know that this file is empty. maybeStructOI.map { soi => - val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map { + val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.zipWithIndex.map { case (attr, ordinal) => - soi.getStructFieldRef(attr.name.toLowerCase) -> ordinal + soi.getStructFieldRef(attr.name) -> ordinal }.unzip val unwrappers = fieldRefs.map(unwrapperFor) // Map each tuple to a row object @@ -269,7 +281,7 @@ private[orc] case class OrcTableScan( } i += 1 } - mutableRow: InternalRow + unsafeProjection(mutableRow) } }.getOrElse { Iterator.empty @@ -278,7 +290,7 @@ private[orc] case class OrcTableScan( def execute(): RDD[InternalRow] = { val job = new Job(sqlContext.sparkContext.hadoopConfiguration) - val conf = job.getConfiguration + val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) // Tries to push down filters if ORC filter push-down is enabled if (sqlContext.conf.orcFilterPushDown) { @@ -291,9 +303,11 @@ private[orc] case class OrcTableScan( // Sets requested columns addColumnIds(attributes, relation, conf) - if (inputPaths.nonEmpty) { - FileInputFormat.setInputPaths(job, inputPaths.map(_.getPath): _*) + if (inputPaths.isEmpty) { + // the input path probably be pruned, return an empty RDD. + return sqlContext.sparkContext.emptyRDD[InternalRow] } + FileInputFormat.setInputPaths(job, inputPaths.map(_.getPath): _*) val inputFormatClass = classOf[OrcInputFormat] @@ -309,13 +323,8 @@ private[orc] case class OrcTableScan( val wrappedConf = new SerializableConfiguration(conf) rdd.mapPartitionsWithInputSplit { case (split: OrcSplit, iterator) => - val mutableRow = new SpecificMutableRow(attributes.map(_.dataType)) - fillObject( - split.getPath.toString, - wrappedConf.value, - iterator.map(_._2), - attributes.zipWithIndex, - mutableRow) + val writableIterator = iterator.map(_._2) + fillObject(split.getPath.toString, wrappedConf.value, writableIterator, attributes) } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 7bbdef90cd6b..97792549bb7a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -20,32 +20,24 @@ package org.apache.spark.sql.hive.test import java.io.File import java.util.{Set => JavaSet} -import org.apache.hadoop.hive.conf.HiveConf +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.language.implicitConversions + import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.exec.FunctionRegistry -import org.apache.hadoop.hive.ql.io.avro.{AvroContainerInputFormat, AvroContainerOutputFormat} -import org.apache.hadoop.hive.ql.metadata.Table -import org.apache.hadoop.hive.ql.parse.VariableSubstitution import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe -import org.apache.hadoop.hive.serde2.avro.AvroSerDe -import org.apache.spark.sql.catalyst.CatalystConf +import org.apache.spark.sql.{SQLContext, SQLConf} import org.apache.spark.sql.catalyst.analysis._ 
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.CacheTableCommand import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.execution.HiveNativeCommand -import org.apache.spark.sql.SQLConf -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} import org.apache.spark.{SparkConf, SparkContext} -import scala.collection.mutable -import scala.language.implicitConversions - -/* Implicit conversions */ -import scala.collection.JavaConversions._ - // SPARK-3729: Test key required to check for initialization errors with config. object TestHive extends TestHiveContext( @@ -56,10 +48,14 @@ object TestHive .set("spark.sql.test", "") .set("spark.sql.hive.metastore.barrierPrefixes", "org.apache.spark.sql.hive.execution.PairSerDe") - .set("spark.buffer.pageSize", "4m") // SPARK-8910 .set("spark.ui.enabled", "false"))) +trait TestHiveSingleton { + protected val sqlContext: SQLContext = TestHive + protected val hiveContext: TestHiveContext = TestHive +} + /** * A locally running test instance of Spark's Hive execution engine. * @@ -83,15 +79,25 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { hiveconf.set("hive.plan.serialization.format", "javaXML") - lazy val warehousePath = Utils.createTempDir() + lazy val warehousePath = Utils.createTempDir(namePrefix = "warehouse-") + + lazy val scratchDirPath = { + val dir = Utils.createTempDir(namePrefix = "scratch-") + dir.delete() + dir + } - private lazy val temporaryConfig = newTemporaryConfiguration() + private lazy val temporaryConfig = newTemporaryConfiguration(useInMemoryDerby = false) /** Sets up the system initially or after a RESET command */ - protected override def configure(): Map[String, String] = - temporaryConfig ++ Map( - ConfVars.METASTOREWAREHOUSE.varname -> warehousePath.toString, - ConfVars.METASTORE_INTEGER_JDO_PUSHDOWN.varname -> "true") + protected override def configure(): Map[String, String] = { + super.configure() ++ temporaryConfig ++ Map( + ConfVars.METASTOREWAREHOUSE.varname -> warehousePath.toURI.toString, + ConfVars.METASTORE_INTEGER_JDO_PUSHDOWN.varname -> "true", + ConfVars.SCRATCHDIR.varname -> scratchDirPath.toURI.toString, + ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY.varname -> "1" + ) + } val testTempDir = Utils.createTempDir() @@ -110,18 +116,19 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { override def executePlan(plan: LogicalPlan): this.QueryExecution = new this.QueryExecution(plan) - override protected[sql] def createSession(): SQLSession = { - new this.SQLSession() - } + protected[sql] override lazy val conf: SQLConf = new SQLConf { + // The super.getConf(SQLConf.DIALECT) is "sql" by default, we need to set it as "hiveql" + override def dialect: String = super.getConf(SQLConf.DIALECT, "hiveql") + override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) - protected[hive] class SQLSession extends super.SQLSession { - /** Fewer partitions to speed up testing. */ - protected[sql] override lazy val conf: SQLConf = new SQLConf { - override def numShufflePartitions: Int = getConf(SQLConf.SHUFFLE_PARTITIONS, 5) - // TODO as in unit test, conf.clear() probably be called, all of the value will be cleared. 
- // The super.getConf(SQLConf.DIALECT) is "sql" by default, we need to set it as "hiveql" - override def dialect: String = super.getConf(SQLConf.DIALECT, "hiveql") - override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) + clear() + + override def clear(): Unit = { + super.clear() + + TestHiveContext.overrideConfs.map { + case (key, value) => setConfString(key, value) + } } } @@ -149,7 +156,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { val hiveFilesTemp = File.createTempFile("catalystHiveFiles", "") hiveFilesTemp.delete() hiveFilesTemp.mkdir() - Utils.registerShutdownDeleteDir(hiveFilesTemp) + ShutdownHookManager.registerShutdownDeleteDir(hiveFilesTemp) val inRepoTests = if (System.getProperty("user.dir").endsWith("sql" + File.separator + "hive")) { new File("src" + File.separator + "test" + File.separator + "resources" + File.separator) @@ -184,7 +191,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { // Make sure any test tables referenced are loaded. val referencedTables = describedTables ++ - logical.collect { case UnresolvedRelation(tableIdent, _) => tableIdent.last } + logical.collect { case UnresolvedRelation(tableIdent, _) => tableIdent.table } val referencedTestTables = referencedTables.filter(testTables.contains) logDebug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") referencedTestTables.foreach(loadTestTable) @@ -244,7 +251,6 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { }), TestTable("src_thrift", () => { import org.apache.hadoop.hive.serde2.thrift.ThriftDeserializer - import org.apache.hadoop.hive.serde2.thrift.test.Complex import org.apache.hadoop.mapred.{SequenceFileInputFormat, SequenceFileOutputFormat} import org.apache.thrift.protocol.TBinaryProtocol @@ -253,7 +259,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { |CREATE TABLE src_thrift(fake INT) |ROW FORMAT SERDE '${classOf[ThriftDeserializer].getName}' |WITH SERDEPROPERTIES( - | 'serialization.class'='${classOf[Complex].getName}', + | 'serialization.class'='org.apache.spark.sql.hive.test.Complex', | 'serialization.format'='${classOf[TBinaryProtocol].getName}' |) |STORED AS @@ -272,10 +278,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { "INSERT OVERWRITE TABLE serdeins SELECT * FROM src".cmd), TestTable("episodes", s"""CREATE TABLE episodes (title STRING, air_date STRING, doctor INT) - |ROW FORMAT SERDE '${classOf[AvroSerDe].getCanonicalName}' - |STORED AS - |INPUTFORMAT '${classOf[AvroContainerInputFormat].getCanonicalName}' - |OUTPUTFORMAT '${classOf[AvroContainerOutputFormat].getCanonicalName}' + |STORED AS avro |TBLPROPERTIES ( | 'avro.schema.literal'='{ | "type": "record", @@ -308,10 +311,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { TestTable("episodes_part", s"""CREATE TABLE episodes_part (title STRING, air_date STRING, doctor INT) |PARTITIONED BY (doctor_pt INT) - |ROW FORMAT SERDE '${classOf[AvroSerDe].getCanonicalName}' - |STORED AS - |INPUTFORMAT '${classOf[AvroContainerInputFormat].getCanonicalName}' - |OUTPUTFORMAT '${classOf[AvroContainerOutputFormat].getCanonicalName}' + |STORED AS avro |TBLPROPERTIES ( | 'avro.schema.literal'='{ | "type": "record", @@ -369,7 +369,11 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { INSERT OVERWRITE TABLE episodes_part PARTITION (doctor_pt=1) SELECT title, air_date, doctor FROM episodes """.cmd - ) + ), + TestTable("src_json", + s"""CREATE TABLE src_json (json 
STRING) STORED AS TEXTFILE + """.stripMargin.cmd, + s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/json.txt")}' INTO TABLE src_json".cmd) ) hiveQTestUtilTables.foreach(registerTestTable) @@ -405,7 +409,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { def reset() { try { // HACK: Hive is too noisy by default. - org.apache.log4j.LogManager.getCurrentLoggers.foreach { log => + org.apache.log4j.LogManager.getCurrentLoggers.asScala.foreach { log => log.asInstanceOf[org.apache.log4j.Logger].setLevel(org.apache.log4j.Level.WARN) } @@ -415,9 +419,8 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { catalog.client.reset() catalog.unregisterAllTables() - FunctionRegistry.getFunctionNames.filterNot(originalUDFs.contains(_)).foreach { udfName => - FunctionRegistry.unregisterTemporaryUDF(udfName) - } + FunctionRegistry.getFunctionNames.asScala.filterNot(originalUDFs.contains(_)). + foreach { udfName => FunctionRegistry.unregisterTemporaryUDF(udfName) } // Some tests corrupt this value on purpose, which breaks the RESET call below. hiveconf.set("fs.default.name", new File(".").toURI.toString) @@ -437,18 +440,24 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { case (k, v) => metadataHive.runSqlHive(s"SET $k=$v") } + defaultOverrides() runSqlHive("USE default") - - // Just loading src makes a lot of tests pass. This is because some tests do something like - // drop an index on src at the beginning. Since we just pass DDL to hive this bypasses our - // Analyzer and thus the test table auto-loading mechanism. - // Remove after we handle more DDL operations natively. - loadTestTable("src") - loadTestTable("srcpart") } catch { case e: Exception => logError("FATAL ERROR: Failed to reset TestDB state.", e) } } } + +private[hive] object TestHiveContext { + + /** + * A map used to store all confs that need to be overridden in sql/hive unit tests. + */ + val overrideConfs: Map[String, String] = + Map( + // Fewer shuffle partitions to speed up testing. + SQLConf.SHUFFLE_PARTITIONS.key -> "5" + ) +} diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java similarity index 64% rename from sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java rename to sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java index 741a3cd31c60..b4bf9eef8fca 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package test.org.apache.spark.sql.hive; +package org.apache.spark.sql.hive; import java.io.IOException; import java.util.ArrayList; @@ -29,8 +29,10 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.*; import org.apache.spark.sql.expressions.Window; -import org.apache.spark.sql.hive.HiveContext; +import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; +import static org.apache.spark.sql.functions.*; import org.apache.spark.sql.hive.test.TestHive$; +import org.apache.spark.sql.hive.aggregate.MyDoubleSum; public class JavaDataFrameSuite { private transient JavaSparkContext sc; @@ -38,7 +40,7 @@ public class JavaDataFrameSuite { DataFrame df; - private void checkAnswer(DataFrame actual, List expected) { + private static void checkAnswer(DataFrame actual, List expected) { String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); if (errorMessage != null) { Assert.fail(errorMessage); @@ -50,11 +52,11 @@ public void setUp() throws IOException { hc = TestHive$.MODULE$; sc = new JavaSparkContext(hc.sparkContext()); - List jsonObjects = new ArrayList(10); + List jsonObjects = new ArrayList<>(10); for (int i = 0; i < 10; i++) { jsonObjects.add("{\"key\":" + i + ", \"value\":\"str" + i + "\"}"); } - df = hc.jsonRDD(sc.parallelize(jsonObjects)); + df = hc.read().json(sc.parallelize(jsonObjects)); df.registerTempTable("window_table"); } @@ -69,7 +71,7 @@ public void tearDown() throws IOException { @Test public void saveTableAndQueryIt() { checkAnswer( - df.select(functions.avg("key").over( + df.select(avg("key").over( Window.partitionBy("value").orderBy("key").rowsBetween(-1, 1))), hc.sql("SELECT avg(key) " + "OVER (PARTITION BY value " + @@ -77,4 +79,26 @@ public void saveTableAndQueryIt() { " ROWS BETWEEN 1 preceding and 1 following) " + "FROM window_table").collectAsList()); } + + @Test + public void testUDAF() { + DataFrame df = hc.range(0, 100).unionAll(hc.range(0, 100)).select(col("id").as("value")); + UserDefinedAggregateFunction udaf = new MyDoubleSum(); + UserDefinedAggregateFunction registeredUDAF = hc.udf().register("mydoublesum", udaf); + // Create Columns for the UDAF. For now, callUDF does not take an argument to specific if + // we want to use distinct aggregation. + DataFrame aggregatedDF = + df.groupBy() + .agg( + udaf.distinct(col("value")), + udaf.apply(col("value")), + registeredUDAF.apply(col("value")), + callUDF("mydoublesum", col("value"))); + + List expectedResult = new ArrayList<>(); + expectedResult.add(RowFactory.create(4950.0, 9900.0, 9900.0, 9900.0)); + checkAnswer( + aggregatedDF, + expectedResult); + } } diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java similarity index 88% rename from sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java rename to sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java index 64d1ce92931e..8c4af1b8eaf4 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package test.org.apache.spark.sql.hive; +package org.apache.spark.sql.hive; import java.io.File; import java.io.IOException; @@ -26,7 +26,6 @@ import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; -import org.apache.spark.sql.SaveMode; import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -37,11 +36,12 @@ import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.QueryTest$; import org.apache.spark.sql.Row; -import org.apache.spark.sql.hive.HiveContext; import org.apache.spark.sql.hive.test.TestHive$; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.catalyst.TableIdentifier; import org.apache.spark.util.Utils; public class JavaMetastoreDataSourcesSuite { @@ -54,7 +54,7 @@ public class JavaMetastoreDataSourcesSuite { FileSystem fs; DataFrame df; - private void checkAnswer(DataFrame actual, List expected) { + private static void checkAnswer(DataFrame actual, List expected) { String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); if (errorMessage != null) { Assert.fail(errorMessage); @@ -72,13 +72,14 @@ public void setUp() throws IOException { if (path.exists()) { path.delete(); } - hiveManagedPath = new Path(sqlContext.catalog().hiveDefaultTableFilePath("javaSavedTable")); + hiveManagedPath = new Path(sqlContext.catalog().hiveDefaultTableFilePath( + new TableIdentifier("javaSavedTable"))); fs = hiveManagedPath.getFileSystem(sc.hadoopConfiguration()); if (fs.exists(hiveManagedPath)){ fs.delete(hiveManagedPath, true); } - List jsonObjects = new ArrayList(10); + List jsonObjects = new ArrayList<>(10); for (int i = 0; i < 10; i++) { jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}"); } @@ -90,13 +91,15 @@ public void setUp() throws IOException { @After public void tearDown() throws IOException { // Clean up tables. 
- sqlContext.sql("DROP TABLE IF EXISTS javaSavedTable"); - sqlContext.sql("DROP TABLE IF EXISTS externalTable"); + if (sqlContext != null) { + sqlContext.sql("DROP TABLE IF EXISTS javaSavedTable"); + sqlContext.sql("DROP TABLE IF EXISTS externalTable"); + } } @Test public void saveExternalTableAndQueryIt() { - Map options = new HashMap(); + Map options = new HashMap<>(); options.put("path", path.toString()); df.write() .format("org.apache.spark.sql.json") @@ -119,7 +122,7 @@ public void saveExternalTableAndQueryIt() { @Test public void saveExternalTableWithSchemaAndQueryIt() { - Map options = new HashMap(); + Map options = new HashMap<>(); options.put("path", path.toString()); df.write() .format("org.apache.spark.sql.json") @@ -131,7 +134,7 @@ public void saveExternalTableWithSchemaAndQueryIt() { sqlContext.sql("SELECT * FROM javaSavedTable"), df.collectAsList()); - List fields = new ArrayList(); + List fields = new ArrayList<>(); fields.add(DataTypes.createStructField("b", DataTypes.StringType, true)); StructType schema = DataTypes.createStructType(fields); DataFrame loadedDF = @@ -147,7 +150,7 @@ public void saveExternalTableWithSchemaAndQueryIt() { @Test public void saveTableAndQueryIt() { - Map options = new HashMap(); + Map options = new HashMap<>(); df.write() .format("org.apache.spark.sql.json") .mode(SaveMode.Append) diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java similarity index 98% rename from sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java rename to sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java index a2247e3da155..5a167edd8959 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package test.org.apache.spark.sql.hive.aggregate; +package org.apache.spark.sql.hive.aggregate; import java.util.ArrayList; import java.util.List; @@ -65,7 +65,7 @@ public MyDoubleAvg() { return _bufferSchema; } - @Override public DataType returnDataType() { + @Override public DataType dataType() { return _returnDataType; } diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java similarity index 97% rename from sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java rename to sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java index da29e24d267d..c3b7768e71bf 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package test.org.apache.spark.sql.hive.aggregate; +package org.apache.spark.sql.hive.aggregate; import java.util.ArrayList; import java.util.List; @@ -60,7 +60,7 @@ public MyDoubleSum() { return _bufferSchema; } - @Override public DataType returnDataType() { + @Override public DataType dataType() { return _returnDataType; } diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFIntegerToString.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java similarity index 100% rename from sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFIntegerToString.java rename to sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFListListInt.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java similarity index 100% rename from sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFListListInt.java rename to sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFListString.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java similarity index 100% rename from sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFListString.java rename to sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFStringString.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java similarity index 100% rename from sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFStringString.java rename to sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFTwoListList.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java similarity index 100% rename from sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFTwoListList.java rename to sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java new file mode 100644 index 000000000000..4ef1f276d1bb --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java @@ -0,0 +1,1173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.hive.test; + +import org.apache.commons.lang.builder.HashCodeBuilder; +import org.apache.thrift.scheme.IScheme; +import org.apache.thrift.scheme.SchemeFactory; +import org.apache.thrift.scheme.StandardScheme; + +import org.apache.hadoop.hive.serde2.thrift.test.IntString; +import org.apache.thrift.scheme.TupleScheme; +import org.apache.thrift.protocol.TTupleProtocol; +import org.apache.thrift.EncodingUtils; +import java.util.List; +import java.util.ArrayList; +import java.util.Map; +import java.util.HashMap; +import java.util.EnumMap; +import java.util.EnumSet; +import java.util.Collections; +import java.util.BitSet; + +/** + * This is a fork of Hive 0.13's org/apache/hadoop/hive/serde2/thrift/test/Complex.java, which + * does not contain union fields that are not supported by Spark SQL. + */ + +@SuppressWarnings({"ALL", "unchecked"}) +public class Complex implements org.apache.thrift.TBase, java.io.Serializable, Cloneable { + private static final org.apache.thrift.protocol.TStruct STRUCT_DESC = new org.apache.thrift.protocol.TStruct("Complex"); + + private static final org.apache.thrift.protocol.TField AINT_FIELD_DESC = new org.apache.thrift.protocol.TField("aint", org.apache.thrift.protocol.TType.I32, (short)1); + private static final org.apache.thrift.protocol.TField A_STRING_FIELD_DESC = new org.apache.thrift.protocol.TField("aString", org.apache.thrift.protocol.TType.STRING, (short)2); + private static final org.apache.thrift.protocol.TField LINT_FIELD_DESC = new org.apache.thrift.protocol.TField("lint", org.apache.thrift.protocol.TType.LIST, (short)3); + private static final org.apache.thrift.protocol.TField L_STRING_FIELD_DESC = new org.apache.thrift.protocol.TField("lString", org.apache.thrift.protocol.TType.LIST, (short)4); + private static final org.apache.thrift.protocol.TField LINT_STRING_FIELD_DESC = new org.apache.thrift.protocol.TField("lintString", org.apache.thrift.protocol.TType.LIST, (short)5); + private static final org.apache.thrift.protocol.TField M_STRING_STRING_FIELD_DESC = new org.apache.thrift.protocol.TField("mStringString", org.apache.thrift.protocol.TType.MAP, (short)6); + + private static final Map, SchemeFactory> schemes = new HashMap, SchemeFactory>(); + static { + schemes.put(StandardScheme.class, new ComplexStandardSchemeFactory()); + schemes.put(TupleScheme.class, new ComplexTupleSchemeFactory()); + } + + private int aint; // required + private String aString; // required + private List lint; // required + private List lString; // required + private List lintString; // required + private Map mStringString; // required + + /** The set of fields this struct contains, along with convenience methods for finding and manipulating them. */ + public enum _Fields implements org.apache.thrift.TFieldIdEnum { + AINT((short)1, "aint"), + A_STRING((short)2, "aString"), + LINT((short)3, "lint"), + L_STRING((short)4, "lString"), + LINT_STRING((short)5, "lintString"), + M_STRING_STRING((short)6, "mStringString"); + + private static final Map byName = new HashMap(); + + static { + for (_Fields field : EnumSet.allOf(_Fields.class)) { + byName.put(field.getFieldName(), field); + } + } + + /** + * Find the _Fields constant that matches fieldId, or null if its not found. 
+ */ + public static _Fields findByThriftId(int fieldId) { + switch(fieldId) { + case 1: // AINT + return AINT; + case 2: // A_STRING + return A_STRING; + case 3: // LINT + return LINT; + case 4: // L_STRING + return L_STRING; + case 5: // LINT_STRING + return LINT_STRING; + case 6: // M_STRING_STRING + return M_STRING_STRING; + default: + return null; + } + } + + /** + * Find the _Fields constant that matches fieldId, throwing an exception + * if it is not found. + */ + public static _Fields findByThriftIdOrThrow(int fieldId) { + _Fields fields = findByThriftId(fieldId); + if (fields == null) throw new IllegalArgumentException("Field " + fieldId + " doesn't exist!"); + return fields; + } + + /** + * Find the _Fields constant that matches name, or null if its not found. + */ + public static _Fields findByName(String name) { + return byName.get(name); + } + + private final short _thriftId; + private final String _fieldName; + + _Fields(short thriftId, String fieldName) { + _thriftId = thriftId; + _fieldName = fieldName; + } + + public short getThriftFieldId() { + return _thriftId; + } + + public String getFieldName() { + return _fieldName; + } + } + + // isset id assignments + private static final int __AINT_ISSET_ID = 0; + private byte __isset_bitfield = 0; + public static final Map<_Fields, org.apache.thrift.meta_data.FieldMetaData> metaDataMap; + static { + Map<_Fields, org.apache.thrift.meta_data.FieldMetaData> tmpMap = new EnumMap<_Fields, org.apache.thrift.meta_data.FieldMetaData>(_Fields.class); + tmpMap.put(_Fields.AINT, new org.apache.thrift.meta_data.FieldMetaData("aint", org.apache.thrift.TFieldRequirementType.DEFAULT, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.I32))); + tmpMap.put(_Fields.A_STRING, new org.apache.thrift.meta_data.FieldMetaData("aString", org.apache.thrift.TFieldRequirementType.DEFAULT, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.STRING))); + tmpMap.put(_Fields.LINT, new org.apache.thrift.meta_data.FieldMetaData("lint", org.apache.thrift.TFieldRequirementType.DEFAULT, + new org.apache.thrift.meta_data.ListMetaData(org.apache.thrift.protocol.TType.LIST, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.I32)))); + tmpMap.put(_Fields.L_STRING, new org.apache.thrift.meta_data.FieldMetaData("lString", org.apache.thrift.TFieldRequirementType.DEFAULT, + new org.apache.thrift.meta_data.ListMetaData(org.apache.thrift.protocol.TType.LIST, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.STRING)))); + tmpMap.put(_Fields.LINT_STRING, new org.apache.thrift.meta_data.FieldMetaData("lintString", org.apache.thrift.TFieldRequirementType.DEFAULT, + new org.apache.thrift.meta_data.ListMetaData(org.apache.thrift.protocol.TType.LIST, + new org.apache.thrift.meta_data.StructMetaData(org.apache.thrift.protocol.TType.STRUCT, IntString.class)))); + tmpMap.put(_Fields.M_STRING_STRING, new org.apache.thrift.meta_data.FieldMetaData("mStringString", org.apache.thrift.TFieldRequirementType.DEFAULT, + new org.apache.thrift.meta_data.MapMetaData(org.apache.thrift.protocol.TType.MAP, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.STRING), + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.STRING)))); + metaDataMap = Collections.unmodifiableMap(tmpMap); + org.apache.thrift.meta_data.FieldMetaData.addStructMetaDataMap(Complex.class, metaDataMap); + } + + public 
Complex() { + } + + public Complex( + int aint, + String aString, + List lint, + List lString, + List lintString, + Map mStringString) + { + this(); + this.aint = aint; + setAintIsSet(true); + this.aString = aString; + this.lint = lint; + this.lString = lString; + this.lintString = lintString; + this.mStringString = mStringString; + } + + /** + * Performs a deep copy on other. + */ + public Complex(Complex other) { + __isset_bitfield = other.__isset_bitfield; + this.aint = other.aint; + if (other.isSetAString()) { + this.aString = other.aString; + } + if (other.isSetLint()) { + List __this__lint = new ArrayList(); + for (Integer other_element : other.lint) { + __this__lint.add(other_element); + } + this.lint = __this__lint; + } + if (other.isSetLString()) { + List __this__lString = new ArrayList(); + for (String other_element : other.lString) { + __this__lString.add(other_element); + } + this.lString = __this__lString; + } + if (other.isSetLintString()) { + List __this__lintString = new ArrayList(); + for (IntString other_element : other.lintString) { + __this__lintString.add(new IntString(other_element)); + } + this.lintString = __this__lintString; + } + if (other.isSetMStringString()) { + Map __this__mStringString = new HashMap(); + for (Map.Entry other_element : other.mStringString.entrySet()) { + + String other_element_key = other_element.getKey(); + String other_element_value = other_element.getValue(); + + String __this__mStringString_copy_key = other_element_key; + + String __this__mStringString_copy_value = other_element_value; + + __this__mStringString.put(__this__mStringString_copy_key, __this__mStringString_copy_value); + } + this.mStringString = __this__mStringString; + } + } + + public Complex deepCopy() { + return new Complex(this); + } + + @Override + public void clear() { + setAintIsSet(false); + this.aint = 0; + this.aString = null; + this.lint = null; + this.lString = null; + this.lintString = null; + this.mStringString = null; + } + + public int getAint() { + return this.aint; + } + + public void setAint(int aint) { + this.aint = aint; + setAintIsSet(true); + } + + public void unsetAint() { + __isset_bitfield = EncodingUtils.clearBit(__isset_bitfield, __AINT_ISSET_ID); + } + + /** Returns true if field aint is set (has been assigned a value) and false otherwise */ + public boolean isSetAint() { + return EncodingUtils.testBit(__isset_bitfield, __AINT_ISSET_ID); + } + + public void setAintIsSet(boolean value) { + __isset_bitfield = EncodingUtils.setBit(__isset_bitfield, __AINT_ISSET_ID, value); + } + + public String getAString() { + return this.aString; + } + + public void setAString(String aString) { + this.aString = aString; + } + + public void unsetAString() { + this.aString = null; + } + + /** Returns true if field aString is set (has been assigned a value) and false otherwise */ + public boolean isSetAString() { + return this.aString != null; + } + + public void setAStringIsSet(boolean value) { + if (!value) { + this.aString = null; + } + } + + public int getLintSize() { + return (this.lint == null) ? 0 : this.lint.size(); + } + + public java.util.Iterator getLintIterator() { + return (this.lint == null) ? 
null : this.lint.iterator(); + } + + public void addToLint(int elem) { + if (this.lint == null) { + this.lint = new ArrayList<>(); + } + this.lint.add(elem); + } + + public List getLint() { + return this.lint; + } + + public void setLint(List lint) { + this.lint = lint; + } + + public void unsetLint() { + this.lint = null; + } + + /** Returns true if field lint is set (has been assigned a value) and false otherwise */ + public boolean isSetLint() { + return this.lint != null; + } + + public void setLintIsSet(boolean value) { + if (!value) { + this.lint = null; + } + } + + public int getLStringSize() { + return (this.lString == null) ? 0 : this.lString.size(); + } + + public java.util.Iterator getLStringIterator() { + return (this.lString == null) ? null : this.lString.iterator(); + } + + public void addToLString(String elem) { + if (this.lString == null) { + this.lString = new ArrayList(); + } + this.lString.add(elem); + } + + public List getLString() { + return this.lString; + } + + public void setLString(List lString) { + this.lString = lString; + } + + public void unsetLString() { + this.lString = null; + } + + /** Returns true if field lString is set (has been assigned a value) and false otherwise */ + public boolean isSetLString() { + return this.lString != null; + } + + public void setLStringIsSet(boolean value) { + if (!value) { + this.lString = null; + } + } + + public int getLintStringSize() { + return (this.lintString == null) ? 0 : this.lintString.size(); + } + + public java.util.Iterator getLintStringIterator() { + return (this.lintString == null) ? null : this.lintString.iterator(); + } + + public void addToLintString(IntString elem) { + if (this.lintString == null) { + this.lintString = new ArrayList<>(); + } + this.lintString.add(elem); + } + + public List getLintString() { + return this.lintString; + } + + public void setLintString(List lintString) { + this.lintString = lintString; + } + + public void unsetLintString() { + this.lintString = null; + } + + /** Returns true if field lintString is set (has been assigned a value) and false otherwise */ + public boolean isSetLintString() { + return this.lintString != null; + } + + public void setLintStringIsSet(boolean value) { + if (!value) { + this.lintString = null; + } + } + + public int getMStringStringSize() { + return (this.mStringString == null) ? 
0 : this.mStringString.size(); + } + + public void putToMStringString(String key, String val) { + if (this.mStringString == null) { + this.mStringString = new HashMap(); + } + this.mStringString.put(key, val); + } + + public Map getMStringString() { + return this.mStringString; + } + + public void setMStringString(Map mStringString) { + this.mStringString = mStringString; + } + + public void unsetMStringString() { + this.mStringString = null; + } + + /** Returns true if field mStringString is set (has been assigned a value) and false otherwise */ + public boolean isSetMStringString() { + return this.mStringString != null; + } + + public void setMStringStringIsSet(boolean value) { + if (!value) { + this.mStringString = null; + } + } + + public void setFieldValue(_Fields field, Object value) { + switch (field) { + case AINT: + if (value == null) { + unsetAint(); + } else { + setAint((Integer)value); + } + break; + + case A_STRING: + if (value == null) { + unsetAString(); + } else { + setAString((String)value); + } + break; + + case LINT: + if (value == null) { + unsetLint(); + } else { + setLint((List)value); + } + break; + + case L_STRING: + if (value == null) { + unsetLString(); + } else { + setLString((List)value); + } + break; + + case LINT_STRING: + if (value == null) { + unsetLintString(); + } else { + setLintString((List)value); + } + break; + + case M_STRING_STRING: + if (value == null) { + unsetMStringString(); + } else { + setMStringString((Map)value); + } + break; + + default: + } + } + + public Object getFieldValue(_Fields field) { + switch (field) { + case AINT: + return Integer.valueOf(getAint()); + + case A_STRING: + return getAString(); + + case LINT: + return getLint(); + + case L_STRING: + return getLString(); + + case LINT_STRING: + return getLintString(); + + case M_STRING_STRING: + return getMStringString(); + + default: + } + throw new IllegalStateException(); + } + + /** Returns true if field corresponding to fieldID is set (has been assigned a value) and false otherwise */ + public boolean isSet(_Fields field) { + if (field == null) { + throw new IllegalArgumentException(); + } + + switch (field) { + case AINT: + return isSetAint(); + case A_STRING: + return isSetAString(); + case LINT: + return isSetLint(); + case L_STRING: + return isSetLString(); + case LINT_STRING: + return isSetLintString(); + case M_STRING_STRING: + return isSetMStringString(); + default: + } + throw new IllegalStateException(); + } + + @Override + public boolean equals(Object that) { + if (that == null) { + return false; + } + if (that instanceof Complex) { + return this.equals((Complex)that); + } + return false; + } + + public boolean equals(Complex that) { + if (that == null) { + return false; + } + + boolean this_present_aint = true; + boolean that_present_aint = true; + if (this_present_aint || that_present_aint) { + if (!(this_present_aint && that_present_aint)) { + return false; + } + if (this.aint != that.aint) { + return false; + } + } + + boolean this_present_aString = true && this.isSetAString(); + boolean that_present_aString = true && that.isSetAString(); + if (this_present_aString || that_present_aString) { + if (!(this_present_aString && that_present_aString)) { + return false; + } + if (!this.aString.equals(that.aString)) { + return false; + } + } + + boolean this_present_lint = true && this.isSetLint(); + boolean that_present_lint = true && that.isSetLint(); + if (this_present_lint || that_present_lint) { + if (!(this_present_lint && that_present_lint)) { + return false; + } + 
if (!this.lint.equals(that.lint)) { + return false; + } + } + + boolean this_present_lString = true && this.isSetLString(); + boolean that_present_lString = true && that.isSetLString(); + if (this_present_lString || that_present_lString) { + if (!(this_present_lString && that_present_lString)) { + return false; + } + if (!this.lString.equals(that.lString)) { + return false; + } + } + + boolean this_present_lintString = true && this.isSetLintString(); + boolean that_present_lintString = true && that.isSetLintString(); + if (this_present_lintString || that_present_lintString) { + if (!(this_present_lintString && that_present_lintString)) { + return false; + } + if (!this.lintString.equals(that.lintString)) { + return false; + } + } + + boolean this_present_mStringString = true && this.isSetMStringString(); + boolean that_present_mStringString = true && that.isSetMStringString(); + if (this_present_mStringString || that_present_mStringString) { + if (!(this_present_mStringString && that_present_mStringString)) { + return false; + } + if (!this.mStringString.equals(that.mStringString)) { + return false; + } + } + + return true; + } + + @Override + public int hashCode() { + HashCodeBuilder builder = new HashCodeBuilder(); + + boolean present_aint = true; + builder.append(present_aint); + if (present_aint) { + builder.append(aint); + } + + boolean present_aString = true && (isSetAString()); + builder.append(present_aString); + if (present_aString) { + builder.append(aString); + } + + boolean present_lint = true && (isSetLint()); + builder.append(present_lint); + if (present_lint) { + builder.append(lint); + } + + boolean present_lString = true && (isSetLString()); + builder.append(present_lString); + if (present_lString) { + builder.append(lString); + } + + boolean present_lintString = true && (isSetLintString()); + builder.append(present_lintString); + if (present_lintString) { + builder.append(lintString); + } + + boolean present_mStringString = true && (isSetMStringString()); + builder.append(present_mStringString); + if (present_mStringString) { + builder.append(mStringString); + } + + return builder.toHashCode(); + } + + public int compareTo(Complex other) { + if (!getClass().equals(other.getClass())) { + return getClass().getName().compareTo(other.getClass().getName()); + } + + int lastComparison = 0; + Complex typedOther = (Complex)other; + + lastComparison = Boolean.valueOf(isSetAint()).compareTo(typedOther.isSetAint()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetAint()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.aint, typedOther.aint); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetAString()).compareTo(typedOther.isSetAString()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetAString()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.aString, typedOther.aString); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetLint()).compareTo(typedOther.isSetLint()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetLint()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.lint, typedOther.lint); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetLString()).compareTo(typedOther.isSetLString()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetLString()) { + lastComparison = 
org.apache.thrift.TBaseHelper.compareTo(this.lString, typedOther.lString); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetLintString()).compareTo(typedOther.isSetLintString()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetLintString()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.lintString, typedOther.lintString); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetMStringString()).compareTo(typedOther.isSetMStringString()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetMStringString()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.mStringString, typedOther.mStringString); + if (lastComparison != 0) { + return lastComparison; + } + } + return 0; + } + + public _Fields fieldForId(int fieldId) { + return _Fields.findByThriftId(fieldId); + } + + public void read(org.apache.thrift.protocol.TProtocol iprot) throws org.apache.thrift.TException { + schemes.get(iprot.getScheme()).getScheme().read(iprot, this); + } + + public void write(org.apache.thrift.protocol.TProtocol oprot) throws org.apache.thrift.TException { + schemes.get(oprot.getScheme()).getScheme().write(oprot, this); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder("Complex("); + boolean first = true; + + sb.append("aint:"); + sb.append(this.aint); + first = false; + if (!first) { + sb.append(", "); + } + sb.append("aString:"); + if (this.aString == null) { + sb.append("null"); + } else { + sb.append(this.aString); + } + first = false; + if (!first) { + sb.append(", "); + } + sb.append("lint:"); + if (this.lint == null) { + sb.append("null"); + } else { + sb.append(this.lint); + } + first = false; + if (!first) { + sb.append(", "); + } + sb.append("lString:"); + if (this.lString == null) { + sb.append("null"); + } else { + sb.append(this.lString); + } + first = false; + if (!first) { + sb.append(", "); + } + sb.append("lintString:"); + if (this.lintString == null) { + sb.append("null"); + } else { + sb.append(this.lintString); + } + first = false; + if (!first) { + sb.append(", "); + } + sb.append("mStringString:"); + if (this.mStringString == null) { + sb.append("null"); + } else { + sb.append(this.mStringString); + } + first = false; + sb.append(")"); + return sb.toString(); + } + + public void validate() throws org.apache.thrift.TException { + // check for required fields + // check for sub-struct validity + } + + private void writeObject(java.io.ObjectOutputStream out) throws java.io.IOException { + try { + write(new org.apache.thrift.protocol.TCompactProtocol(new org.apache.thrift.transport.TIOStreamTransport(out))); + } catch (org.apache.thrift.TException te) { + throw new java.io.IOException(te); + } + } + + private void readObject(java.io.ObjectInputStream in) throws java.io.IOException, ClassNotFoundException { + try { + // it doesn't seem like you should have to do this, but java serialization is wacky, and doesn't call the default constructor. 
+ __isset_bitfield = 0; + read(new org.apache.thrift.protocol.TCompactProtocol(new org.apache.thrift.transport.TIOStreamTransport(in))); + } catch (org.apache.thrift.TException te) { + throw new java.io.IOException(te); + } + } + + private static class ComplexStandardSchemeFactory implements SchemeFactory { + public ComplexStandardScheme getScheme() { + return new ComplexStandardScheme(); + } + } + + private static class ComplexStandardScheme extends StandardScheme { + + public void read(org.apache.thrift.protocol.TProtocol iprot, Complex struct) throws org.apache.thrift.TException { + org.apache.thrift.protocol.TField schemeField; + iprot.readStructBegin(); + while (true) + { + schemeField = iprot.readFieldBegin(); + if (schemeField.type == org.apache.thrift.protocol.TType.STOP) { + break; + } + switch (schemeField.id) { + case 1: // AINT + if (schemeField.type == org.apache.thrift.protocol.TType.I32) { + struct.aint = iprot.readI32(); + struct.setAintIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 2: // A_STRING + if (schemeField.type == org.apache.thrift.protocol.TType.STRING) { + struct.aString = iprot.readString(); + struct.setAStringIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 3: // LINT + if (schemeField.type == org.apache.thrift.protocol.TType.LIST) { + { + org.apache.thrift.protocol.TList _list0 = iprot.readListBegin(); + struct.lint = new ArrayList(_list0.size); + for (int _i1 = 0; _i1 < _list0.size; ++_i1) + { + int _elem2; // required + _elem2 = iprot.readI32(); + struct.lint.add(_elem2); + } + iprot.readListEnd(); + } + struct.setLintIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 4: // L_STRING + if (schemeField.type == org.apache.thrift.protocol.TType.LIST) { + { + org.apache.thrift.protocol.TList _list3 = iprot.readListBegin(); + struct.lString = new ArrayList(_list3.size); + for (int _i4 = 0; _i4 < _list3.size; ++_i4) + { + String _elem5; // required + _elem5 = iprot.readString(); + struct.lString.add(_elem5); + } + iprot.readListEnd(); + } + struct.setLStringIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 5: // LINT_STRING + if (schemeField.type == org.apache.thrift.protocol.TType.LIST) { + { + org.apache.thrift.protocol.TList _list6 = iprot.readListBegin(); + struct.lintString = new ArrayList(_list6.size); + for (int _i7 = 0; _i7 < _list6.size; ++_i7) + { + IntString _elem8; // required + _elem8 = new IntString(); + _elem8.read(iprot); + struct.lintString.add(_elem8); + } + iprot.readListEnd(); + } + struct.setLintStringIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 6: // M_STRING_STRING + if (schemeField.type == org.apache.thrift.protocol.TType.MAP) { + { + org.apache.thrift.protocol.TMap _map9 = iprot.readMapBegin(); + struct.mStringString = new HashMap(2*_map9.size); + for (int _i10 = 0; _i10 < _map9.size; ++_i10) + { + String _key11; // required + String _val12; // required + _key11 = iprot.readString(); + _val12 = iprot.readString(); + struct.mStringString.put(_key11, _val12); + } + iprot.readMapEnd(); + } + struct.setMStringStringIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + default: + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, 
schemeField.type); + } + iprot.readFieldEnd(); + } + iprot.readStructEnd(); + struct.validate(); + } + + public void write(org.apache.thrift.protocol.TProtocol oprot, Complex struct) throws org.apache.thrift.TException { + struct.validate(); + + oprot.writeStructBegin(STRUCT_DESC); + oprot.writeFieldBegin(AINT_FIELD_DESC); + oprot.writeI32(struct.aint); + oprot.writeFieldEnd(); + if (struct.aString != null) { + oprot.writeFieldBegin(A_STRING_FIELD_DESC); + oprot.writeString(struct.aString); + oprot.writeFieldEnd(); + } + if (struct.lint != null) { + oprot.writeFieldBegin(LINT_FIELD_DESC); + { + oprot.writeListBegin(new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.I32, struct.lint.size())); + for (int _iter13 : struct.lint) + { + oprot.writeI32(_iter13); + } + oprot.writeListEnd(); + } + oprot.writeFieldEnd(); + } + if (struct.lString != null) { + oprot.writeFieldBegin(L_STRING_FIELD_DESC); + { + oprot.writeListBegin(new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.STRING, struct.lString.size())); + for (String _iter14 : struct.lString) + { + oprot.writeString(_iter14); + } + oprot.writeListEnd(); + } + oprot.writeFieldEnd(); + } + if (struct.lintString != null) { + oprot.writeFieldBegin(LINT_STRING_FIELD_DESC); + { + oprot.writeListBegin(new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.STRUCT, struct.lintString.size())); + for (IntString _iter15 : struct.lintString) + { + _iter15.write(oprot); + } + oprot.writeListEnd(); + } + oprot.writeFieldEnd(); + } + if (struct.mStringString != null) { + oprot.writeFieldBegin(M_STRING_STRING_FIELD_DESC); + { + oprot.writeMapBegin(new org.apache.thrift.protocol.TMap(org.apache.thrift.protocol.TType.STRING, org.apache.thrift.protocol.TType.STRING, struct.mStringString.size())); + for (Map.Entry _iter16 : struct.mStringString.entrySet()) + { + oprot.writeString(_iter16.getKey()); + oprot.writeString(_iter16.getValue()); + } + oprot.writeMapEnd(); + } + oprot.writeFieldEnd(); + } + oprot.writeFieldStop(); + oprot.writeStructEnd(); + } + + } + + private static class ComplexTupleSchemeFactory implements SchemeFactory { + public ComplexTupleScheme getScheme() { + return new ComplexTupleScheme(); + } + } + + private static class ComplexTupleScheme extends TupleScheme { + + @Override + public void write(org.apache.thrift.protocol.TProtocol prot, Complex struct) throws org.apache.thrift.TException { + TTupleProtocol oprot = (TTupleProtocol) prot; + BitSet optionals = new BitSet(); + if (struct.isSetAint()) { + optionals.set(0); + } + if (struct.isSetAString()) { + optionals.set(1); + } + if (struct.isSetLint()) { + optionals.set(2); + } + if (struct.isSetLString()) { + optionals.set(3); + } + if (struct.isSetLintString()) { + optionals.set(4); + } + if (struct.isSetMStringString()) { + optionals.set(5); + } + oprot.writeBitSet(optionals, 6); + if (struct.isSetAint()) { + oprot.writeI32(struct.aint); + } + if (struct.isSetAString()) { + oprot.writeString(struct.aString); + } + if (struct.isSetLint()) { + { + oprot.writeI32(struct.lint.size()); + for (int _iter17 : struct.lint) + { + oprot.writeI32(_iter17); + } + } + } + if (struct.isSetLString()) { + { + oprot.writeI32(struct.lString.size()); + for (String _iter18 : struct.lString) + { + oprot.writeString(_iter18); + } + } + } + if (struct.isSetLintString()) { + { + oprot.writeI32(struct.lintString.size()); + for (IntString _iter19 : struct.lintString) + { + _iter19.write(oprot); + } + } + } + if (struct.isSetMStringString()) { + { + 
oprot.writeI32(struct.mStringString.size()); + for (Map.Entry _iter20 : struct.mStringString.entrySet()) + { + oprot.writeString(_iter20.getKey()); + oprot.writeString(_iter20.getValue()); + } + } + } + } + + @Override + public void read(org.apache.thrift.protocol.TProtocol prot, Complex struct) throws org.apache.thrift.TException { + TTupleProtocol iprot = (TTupleProtocol) prot; + BitSet incoming = iprot.readBitSet(6); + if (incoming.get(0)) { + struct.aint = iprot.readI32(); + struct.setAintIsSet(true); + } + if (incoming.get(1)) { + struct.aString = iprot.readString(); + struct.setAStringIsSet(true); + } + if (incoming.get(2)) { + { + org.apache.thrift.protocol.TList _list21 = new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.I32, iprot.readI32()); + struct.lint = new ArrayList(_list21.size); + for (int _i22 = 0; _i22 < _list21.size; ++_i22) + { + int _elem23; // required + _elem23 = iprot.readI32(); + struct.lint.add(_elem23); + } + } + struct.setLintIsSet(true); + } + if (incoming.get(3)) { + { + org.apache.thrift.protocol.TList _list24 = new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.STRING, iprot.readI32()); + struct.lString = new ArrayList(_list24.size); + for (int _i25 = 0; _i25 < _list24.size; ++_i25) + { + String _elem26; // required + _elem26 = iprot.readString(); + struct.lString.add(_elem26); + } + } + struct.setLStringIsSet(true); + } + if (incoming.get(4)) { + { + org.apache.thrift.protocol.TList _list27 = new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.STRUCT, iprot.readI32()); + struct.lintString = new ArrayList(_list27.size); + for (int _i28 = 0; _i28 < _list27.size; ++_i28) + { + IntString _elem29; // required + _elem29 = new IntString(); + _elem29.read(iprot); + struct.lintString.add(_elem29); + } + } + struct.setLintStringIsSet(true); + } + if (incoming.get(5)) { + { + org.apache.thrift.protocol.TMap _map30 = new org.apache.thrift.protocol.TMap(org.apache.thrift.protocol.TType.STRING, org.apache.thrift.protocol.TType.STRING, iprot.readI32()); + struct.mStringString = new HashMap(2*_map30.size); + for (int _i31 = 0; _i31 < _map30.size; ++_i31) + { + String _key32; // required + String _val33; // required + _key32 = iprot.readString(); + _val33 = iprot.readString(); + struct.mStringString.put(_key32, _val33); + } + } + struct.setMStringStringIsSet(true); + } + } + } + +} + diff --git a/sql/hive/src/test/resources/data/scripts/test_transform.py b/sql/hive/src/test/resources/data/scripts/test_transform.py new file mode 100755 index 000000000000..ac6d11d8b919 --- /dev/null +++ b/sql/hive/src/test/resources/data/scripts/test_transform.py @@ -0,0 +1,6 @@ +import sys + +delim = sys.argv[1] + +for row in sys.stdin: + print(delim.join([w + '#' for w in row[:-1].split(delim)])) diff --git a/sql/hive/src/test/resources/golden/timestamp cast #3-0-76ee270337f664b36cacfc6528ac109 b/sql/hive/src/test/resources/golden/! operator-0-ee7f6a60a9792041b85b18cda56429bf similarity index 100% rename from sql/hive/src/test/resources/golden/timestamp cast #3-0-76ee270337f664b36cacfc6528ac109 rename to sql/hive/src/test/resources/golden/! 
operator-0-ee7f6a60a9792041b85b18cda56429bf diff --git a/sql/hive/src/test/resources/golden/Column pruning - non-trivial top project with aliases - query test-0-515e406ffb23f6fd0d8cd34c2b25fbe6 b/sql/hive/src/test/resources/golden/Column pruning - non-trivial top project with aliases - query test-0-515e406ffb23f6fd0d8cd34c2b25fbe6 new file mode 100644 index 000000000000..9a276bc794c0 --- /dev/null +++ b/sql/hive/src/test/resources/golden/Column pruning - non-trivial top project with aliases - query test-0-515e406ffb23f6fd0d8cd34c2b25fbe6 @@ -0,0 +1,3 @@ +476 +172 +622 diff --git a/sql/hive/src/test/resources/golden/Partition pruning - non-partitioned, non-trivial project - query test-0-eabbebd5c1d127b1605bfec52d7b7f3f b/sql/hive/src/test/resources/golden/Partition pruning - non-partitioned, non-trivial project - query test-0-eabbebd5c1d127b1605bfec52d7b7f3f new file mode 100644 index 000000000000..444039e75fba --- /dev/null +++ b/sql/hive/src/test/resources/golden/Partition pruning - non-partitioned, non-trivial project - query test-0-eabbebd5c1d127b1605bfec52d7b7f3f @@ -0,0 +1,500 @@ +476 +172 +622 +54 +330 +818 +510 +556 +196 +968 +530 +386 +802 +300 +546 +448 +738 +132 +256 +426 +292 +812 +858 +748 +304 +938 +290 +990 +74 +654 +562 +554 +418 +30 +164 +806 +332 +834 +860 +504 +584 +438 +574 +306 +386 +676 +892 +918 +788 +474 +964 +348 +826 +988 +414 +398 +932 +416 +348 +798 +792 +494 +834 +978 +324 +754 +794 +618 +730 +532 +878 +684 +734 +650 +334 +390 +950 +34 +226 +310 +406 +678 +0 +910 +256 +622 +632 +114 +604 +410 +298 +876 +690 +258 +340 +40 +978 +314 +756 +442 +184 +222 +94 +144 +8 +560 +70 +854 +554 +416 +712 +798 +338 +764 +996 +250 +772 +874 +938 +384 +572 +374 +352 +108 +918 +102 +276 +206 +478 +426 +432 +860 +556 +352 +578 +442 +130 +636 +664 +622 +550 +274 +482 +166 +666 +360 +568 +24 +460 +362 +134 +520 +808 +768 +978 +706 +746 +544 +276 +434 +168 +696 +932 +116 +16 +822 +460 +416 +696 +48 +926 +862 +358 +344 +84 +258 +316 +238 +992 +0 +644 +394 +936 +786 +908 +200 +596 +398 +382 +836 +192 +52 +330 +654 +460 +410 +240 +262 +102 +808 +86 +872 +312 +938 +936 +616 +190 +392 +576 +962 +914 +196 +564 +394 +374 +636 +636 +818 +940 +274 +738 +632 +338 +826 +170 +154 +0 +980 +174 +728 +358 +236 +268 +790 +564 +276 +476 +838 +30 +236 +144 +180 +614 +38 +870 +20 +554 +546 +612 +448 +618 +778 +654 +484 +738 +784 +544 +662 +802 +484 +904 +354 +452 +10 +994 +804 +792 +634 +790 +116 +70 +672 +190 +22 +336 +68 +458 +466 +286 +944 +644 +996 +320 +390 +84 +642 +860 +238 +978 +916 +156 +152 +82 +446 +984 +298 +898 +436 +456 +276 +906 +60 +418 +128 +936 +152 +148 +684 +138 +460 +66 +736 +206 +592 +226 +432 +734 +688 +334 +548 +438 +478 +970 +232 +446 +512 +526 +140 +974 +960 +802 +576 +382 +10 +488 +876 +256 +934 +864 +404 +632 +458 +938 +926 +560 +4 +70 +566 +662 +470 +160 +88 +386 +642 +670 +208 +932 +732 +350 +806 +966 +106 +210 +514 +812 +818 +380 +812 +802 +228 +516 +180 +406 +524 +696 +848 +24 +792 +402 +434 +328 +862 +908 +956 +596 +250 +862 +328 +848 +374 +764 +10 +140 +794 +960 +582 +48 +702 +510 +208 +140 +326 +876 +238 +828 +400 +982 +474 +878 +720 +496 +958 +610 +834 +398 +888 +240 +858 +338 +886 +646 +650 +554 +460 +956 +356 +936 +620 +634 +666 +986 +920 +414 +498 +530 +960 +166 +272 +706 +344 +428 +924 +466 +812 +266 +350 +378 +908 +750 +802 +842 +814 +768 +512 +52 +268 +134 +768 +758 +36 +924 +984 +200 +596 +18 +682 +996 +292 +916 +724 +372 +570 +696 +334 +36 +546 +366 +562 +688 +194 +938 +630 +168 +56 +74 +896 +304 +696 +614 +388 +828 +954 +444 +252 +180 +338 +806 +800 +400 
+194 diff --git a/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #1-0-abfc0b99ee357f71639f6162345fe8e b/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #1-0-abfc0b99ee357f71639f6162345fe8e new file mode 100644 index 000000000000..0bb9399af0c4 --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #1-0-abfc0b99ee357f71639f6162345fe8e @@ -0,0 +1,20 @@ +302 0 +302 0 +302 0 +305 0 +305 0 +305 0 +306 0 +306 0 +306 0 +307 0 +307 0 +307 0 +307 0 +307 0 +307 0 +308 0 +308 0 +308 0 +309 0 +309 0 diff --git a/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #2-0-8412a39ee57885ccb0aaf848db8ef1dd b/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #2-0-8412a39ee57885ccb0aaf848db8ef1dd new file mode 100644 index 000000000000..4e455ed25511 --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #2-0-8412a39ee57885ccb0aaf848db8ef1dd @@ -0,0 +1,20 @@ +302 0 +302 0 +302 0 +305 0 +305 0 +305 0 +305 2 +305 4 +306 0 +306 0 +306 0 +306 2 +306 4 +306 5 +306 5 +306 5 +307 0 +307 0 +307 0 +307 0 diff --git a/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #3-0-e8a0427dbde35eea6011144443e5ffb4 b/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #3-0-e8a0427dbde35eea6011144443e5ffb4 new file mode 100644 index 000000000000..4e455ed25511 --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #3-0-e8a0427dbde35eea6011144443e5ffb4 @@ -0,0 +1,20 @@ +302 0 +302 0 +302 0 +305 0 +305 0 +305 0 +305 2 +305 4 +306 0 +306 0 +306 0 +306 2 +306 4 +306 5 +306 5 +306 5 +307 0 +307 0 +307 0 +307 0 diff --git a/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #4-0-45f8602d257655322b7d18cad09f6a0f b/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #4-0-45f8602d257655322b7d18cad09f6a0f new file mode 100644 index 000000000000..4e455ed25511 --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #4-0-45f8602d257655322b7d18cad09f6a0f @@ -0,0 +1,20 @@ +302 0 +302 0 +302 0 +305 0 +305 0 +305 0 +305 2 +305 4 +306 0 +306 0 +306 0 +306 2 +306 4 +306 5 +306 5 +306 5 +307 0 +307 0 +307 0 +307 0 diff --git a/sql/hive/src/test/resources/golden/SPARK-9034 Reflect field names defined in GenericUDTF #1-0-ff502d8c06f4b32f57aa45057b7fab0e b/sql/hive/src/test/resources/golden/SPARK-9034 Reflect field names defined in GenericUDTF #1-0-ff502d8c06f4b32f57aa45057b7fab0e new file mode 100644 index 000000000000..1cf253f92c05 --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-9034 Reflect field names defined in GenericUDTF #1-0-ff502d8c06f4b32f57aa45057b7fab0e @@ -0,0 +1 @@ +238 diff --git a/sql/hive/src/test/resources/golden/SPARK-9034 Reflect field names defined in GenericUDTF #2-0-d6d0def30a7fad5f90fd835361820c30 b/sql/hive/src/test/resources/golden/SPARK-9034 Reflect field names defined in GenericUDTF #2-0-d6d0def30a7fad5f90fd835361820c30 new file mode 100644 index 
000000000000..60878ffb7706 --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-9034 Reflect field names defined in GenericUDTF #2-0-d6d0def30a7fad5f90fd835361820c30 @@ -0,0 +1 @@ +238 val_238 diff --git a/sql/hive/src/test/resources/golden/constant null testing-0-9a02bc7de09bcabcbd4c91f54a814c20 b/sql/hive/src/test/resources/golden/constant null testing-0-237a6af90a857da1efcbe98f6bbbf9d6 similarity index 52% rename from sql/hive/src/test/resources/golden/constant null testing-0-9a02bc7de09bcabcbd4c91f54a814c20 rename to sql/hive/src/test/resources/golden/constant null testing-0-237a6af90a857da1efcbe98f6bbbf9d6 index 7c41615f8c18..a01c2622c68e 100644 --- a/sql/hive/src/test/resources/golden/constant null testing-0-9a02bc7de09bcabcbd4c91f54a814c20 +++ b/sql/hive/src/test/resources/golden/constant null testing-0-237a6af90a857da1efcbe98f6bbbf9d6 @@ -1 +1 @@ -1 NULL 1 NULL 1.0 NULL true NULL 1 NULL 1.0 NULL 1 NULL 1 NULL 1 NULL 1970-01-01 NULL 1969-12-31 16:00:00.001 NULL 1 NULL +1 NULL 1 NULL 1.0 NULL true NULL 1 NULL 1.0 NULL 1 NULL 1 NULL 1 NULL 1970-01-01 NULL NULL 1 NULL diff --git a/sql/hive/src/test/resources/golden/convert_enum_to_string-1-db089ff46f9826c7883198adacdfad59 b/sql/hive/src/test/resources/golden/convert_enum_to_string-1-db089ff46f9826c7883198adacdfad59 index d35bf9093ca9..2383bef94097 100644 --- a/sql/hive/src/test/resources/golden/convert_enum_to_string-1-db089ff46f9826c7883198adacdfad59 +++ b/sql/hive/src/test/resources/golden/convert_enum_to_string-1-db089ff46f9826c7883198adacdfad59 @@ -15,9 +15,9 @@ my_enum_structlist_map map from deserializer my_structlist array>> from deserializer my_enumlist array from deserializer -my_stringset struct<> from deserializer -my_enumset struct<> from deserializer -my_structset struct<> from deserializer +my_stringset array from deserializer +my_enumset array from deserializer +my_structset array>> from deserializer optionals struct<> from deserializer b string diff --git a/sql/hive/src/test/resources/golden/get_json_object #1-0-f01b340b5662c45bb5f1e3b7c6900e1f b/sql/hive/src/test/resources/golden/get_json_object #1-0-f01b340b5662c45bb5f1e3b7c6900e1f new file mode 100644 index 000000000000..1dcda4315a14 --- /dev/null +++ b/sql/hive/src/test/resources/golden/get_json_object #1-0-f01b340b5662c45bb5f1e3b7c6900e1f @@ -0,0 +1 @@ +{"store":{"fruit":[{"weight":8,"type":"apple"},{"weight":9,"type":"pear"}],"basket":[[1,2,{"b":"y","a":"x"}],[3,4],[5,6]],"book":[{"author":"Nigel Rees","title":"Sayings of the Century","category":"reference","price":8.95},{"author":"Herman Melville","title":"Moby Dick","category":"fiction","price":8.99,"isbn":"0-553-21311-3"},{"author":"J. R. R. 
Tolkien","title":"The Lord of the Rings","category":"fiction","reader":[{"age":25,"name":"bob"},{"age":26,"name":"jack"}],"price":22.99,"isbn":"0-395-19395-8"}],"bicycle":{"price":19.95,"color":"red"}},"email":"amy@only_for_json_udf_test.net","owner":"amy","zip code":"94025","fb:testid":"1234"} diff --git a/sql/hive/src/test/resources/golden/get_json_object #10-0-f3f47d06d7c51d493d68112b0bd6c1fc b/sql/hive/src/test/resources/golden/get_json_object #10-0-f3f47d06d7c51d493d68112b0bd6c1fc new file mode 100644 index 000000000000..81c545efebe5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/get_json_object #10-0-f3f47d06d7c51d493d68112b0bd6c1fc @@ -0,0 +1 @@ +1234 diff --git a/sql/hive/src/test/resources/golden/get_json_object #2-0-e84c2f8136919830fd665a278e4158a b/sql/hive/src/test/resources/golden/get_json_object #2-0-e84c2f8136919830fd665a278e4158a new file mode 100644 index 000000000000..99127db9e311 --- /dev/null +++ b/sql/hive/src/test/resources/golden/get_json_object #2-0-e84c2f8136919830fd665a278e4158a @@ -0,0 +1 @@ +amy {"fruit":[{"weight":8,"type":"apple"},{"weight":9,"type":"pear"}],"basket":[[1,2,{"b":"y","a":"x"}],[3,4],[5,6]],"book":[{"author":"Nigel Rees","title":"Sayings of the Century","category":"reference","price":8.95},{"author":"Herman Melville","title":"Moby Dick","category":"fiction","price":8.99,"isbn":"0-553-21311-3"},{"author":"J. R. R. Tolkien","title":"The Lord of the Rings","category":"fiction","reader":[{"age":25,"name":"bob"},{"age":26,"name":"jack"}],"price":22.99,"isbn":"0-395-19395-8"}],"bicycle":{"price":19.95,"color":"red"}} diff --git a/sql/hive/src/test/resources/golden/get_json_object #3-0-bf140c65c31f8d892ec23e41e16e58bb b/sql/hive/src/test/resources/golden/get_json_object #3-0-bf140c65c31f8d892ec23e41e16e58bb new file mode 100644 index 000000000000..0bc03998296a --- /dev/null +++ b/sql/hive/src/test/resources/golden/get_json_object #3-0-bf140c65c31f8d892ec23e41e16e58bb @@ -0,0 +1 @@ +{"price":19.95,"color":"red"} [{"author":"Nigel Rees","title":"Sayings of the Century","category":"reference","price":8.95},{"author":"Herman Melville","title":"Moby Dick","category":"fiction","price":8.99,"isbn":"0-553-21311-3"},{"author":"J. R. R. Tolkien","title":"The Lord of the Rings","category":"fiction","reader":[{"age":25,"name":"bob"},{"age":26,"name":"jack"}],"price":22.99,"isbn":"0-395-19395-8"}] diff --git a/sql/hive/src/test/resources/golden/get_json_object #4-0-f0bd902edc1990c9a6c65a6bb672c4d5 b/sql/hive/src/test/resources/golden/get_json_object #4-0-f0bd902edc1990c9a6c65a6bb672c4d5 new file mode 100644 index 000000000000..4f7e09bd3fa7 --- /dev/null +++ b/sql/hive/src/test/resources/golden/get_json_object #4-0-f0bd902edc1990c9a6c65a6bb672c4d5 @@ -0,0 +1 @@ +{"author":"Nigel Rees","title":"Sayings of the Century","category":"reference","price":8.95} [{"author":"Nigel Rees","title":"Sayings of the Century","category":"reference","price":8.95},{"author":"Herman Melville","title":"Moby Dick","category":"fiction","price":8.99,"isbn":"0-553-21311-3"},{"author":"J. R. R. 
Tolkien","title":"The Lord of the Rings","category":"fiction","reader":[{"age":25,"name":"bob"},{"age":26,"name":"jack"}],"price":22.99,"isbn":"0-395-19395-8"}] diff --git a/sql/hive/src/test/resources/golden/get_json_object #5-0-3c09f4316a1533049aee8af749cdcab b/sql/hive/src/test/resources/golden/get_json_object #5-0-3c09f4316a1533049aee8af749cdcab new file mode 100644 index 000000000000..b2d212a597d9 --- /dev/null +++ b/sql/hive/src/test/resources/golden/get_json_object #5-0-3c09f4316a1533049aee8af749cdcab @@ -0,0 +1 @@ +reference ["reference","fiction","fiction"] ["0-553-21311-3","0-395-19395-8"] [{"age":25,"name":"bob"},{"age":26,"name":"jack"}] diff --git a/sql/hive/src/test/resources/golden/get_json_object #6-0-8334d1ddbe0f41fc7b80d4e6b45409da b/sql/hive/src/test/resources/golden/get_json_object #6-0-8334d1ddbe0f41fc7b80d4e6b45409da new file mode 100644 index 000000000000..21d88629fcdb --- /dev/null +++ b/sql/hive/src/test/resources/golden/get_json_object #6-0-8334d1ddbe0f41fc7b80d4e6b45409da @@ -0,0 +1 @@ +25 [25,26] diff --git a/sql/hive/src/test/resources/golden/get_json_object #7-0-40d7dff94b26a2e3f4ab71baee3d3ce0 b/sql/hive/src/test/resources/golden/get_json_object #7-0-40d7dff94b26a2e3f4ab71baee3d3ce0 new file mode 100644 index 000000000000..e60721e1dd24 --- /dev/null +++ b/sql/hive/src/test/resources/golden/get_json_object #7-0-40d7dff94b26a2e3f4ab71baee3d3ce0 @@ -0,0 +1 @@ +2 [[1,2,{"b":"y","a":"x"}],[3,4],[5,6]] 1 [1,2,{"b":"y","a":"x"}] [1,2,{"b":"y","a":"x"},3,4,5,6] y ["y"] diff --git a/sql/hive/src/test/resources/golden/get_json_object #8-0-180b4b6fdb26011fec05a7ca99fd9844 b/sql/hive/src/test/resources/golden/get_json_object #8-0-180b4b6fdb26011fec05a7ca99fd9844 new file mode 100644 index 000000000000..356fcdf7139b --- /dev/null +++ b/sql/hive/src/test/resources/golden/get_json_object #8-0-180b4b6fdb26011fec05a7ca99fd9844 @@ -0,0 +1 @@ +NULL NULL NULL NULL NULL NULL diff --git a/sql/hive/src/test/resources/golden/get_json_object #9-0-47c451a969d856f008f4d6b3d378d94b b/sql/hive/src/test/resources/golden/get_json_object #9-0-47c451a969d856f008f4d6b3d378d94b new file mode 100644 index 000000000000..ef4a39675ed6 --- /dev/null +++ b/sql/hive/src/test/resources/golden/get_json_object #9-0-47c451a969d856f008f4d6b3d378d94b @@ -0,0 +1 @@ +94025 diff --git a/sql/hive/src/test/resources/golden/lateral_view_noalias-0-50131c0ba7b7a6b65c789a5a8497bada b/sql/hive/src/test/resources/golden/lateral_view_noalias-0-50131c0ba7b7a6b65c789a5a8497bada new file mode 100644 index 000000000000..573541ac9702 --- /dev/null +++ b/sql/hive/src/test/resources/golden/lateral_view_noalias-0-50131c0ba7b7a6b65c789a5a8497bada @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/union3-0-6a8a35102de1b0b88c6721a704eb174d b/sql/hive/src/test/resources/golden/lateral_view_noalias-1-72509f06e1f7c5d5ccc292f775f8eea7 similarity index 100% rename from sql/hive/src/test/resources/golden/union3-0-6a8a35102de1b0b88c6721a704eb174d rename to sql/hive/src/test/resources/golden/lateral_view_noalias-1-72509f06e1f7c5d5ccc292f775f8eea7 diff --git a/sql/hive/src/test/resources/golden/lateral_view_noalias-2-6d5806dd1d2511911a5de1e205523f42 b/sql/hive/src/test/resources/golden/lateral_view_noalias-2-6d5806dd1d2511911a5de1e205523f42 new file mode 100644 index 000000000000..0da0d93886e0 --- /dev/null +++ b/sql/hive/src/test/resources/golden/lateral_view_noalias-2-6d5806dd1d2511911a5de1e205523f42 @@ -0,0 +1,2 @@ +key1 100 +key2 200 diff --git 
a/sql/hive/src/test/resources/golden/union3-2-2a1dcd937f117f1955a169592b96d5f9 b/sql/hive/src/test/resources/golden/lateral_view_noalias-3-155b3cc2f5054725a9c2acca3c38c00a similarity index 100% rename from sql/hive/src/test/resources/golden/union3-2-2a1dcd937f117f1955a169592b96d5f9 rename to sql/hive/src/test/resources/golden/lateral_view_noalias-3-155b3cc2f5054725a9c2acca3c38c00a diff --git a/sql/hive/src/test/resources/golden/lateral_view_noalias-4-3b7045ace234af8e5e86d8ac23ccee56 b/sql/hive/src/test/resources/golden/lateral_view_noalias-4-3b7045ace234af8e5e86d8ac23ccee56 new file mode 100644 index 000000000000..0da0d93886e0 --- /dev/null +++ b/sql/hive/src/test/resources/golden/lateral_view_noalias-4-3b7045ace234af8e5e86d8ac23ccee56 @@ -0,0 +1,2 @@ +key1 100 +key2 200 diff --git a/sql/hive/src/test/resources/golden/lateral_view_noalias-5-e1eca4e08216897d090259d4fd1e3fe b/sql/hive/src/test/resources/golden/lateral_view_noalias-5-e1eca4e08216897d090259d4fd1e3fe new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/lateral_view_noalias-6-16d227442dd775615c6ecfceedc6c612 b/sql/hive/src/test/resources/golden/lateral_view_noalias-6-16d227442dd775615c6ecfceedc6c612 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/lateral_view_noalias-7-66cb5ab20690dd85b2ed95bbfb9481d3 b/sql/hive/src/test/resources/golden/lateral_view_noalias-7-66cb5ab20690dd85b2ed95bbfb9481d3 new file mode 100644 index 000000000000..4ba46bbda5b0 --- /dev/null +++ b/sql/hive/src/test/resources/golden/lateral_view_noalias-7-66cb5ab20690dd85b2ed95bbfb9481d3 @@ -0,0 +1,2 @@ +key1 100 key1 100 +key2 200 key2 200 diff --git a/sql/hive/src/test/resources/golden/macro-0-50131c0ba7b7a6b65c789a5a8497bada b/sql/hive/src/test/resources/golden/macro-0-50131c0ba7b7a6b65c789a5a8497bada new file mode 100644 index 000000000000..573541ac9702 --- /dev/null +++ b/sql/hive/src/test/resources/golden/macro-0-50131c0ba7b7a6b65c789a5a8497bada @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/macro-1-5ff5e8795c13303db5d3ea88e1e918b6 b/sql/hive/src/test/resources/golden/macro-1-5ff5e8795c13303db5d3ea88e1e918b6 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/macro-10-45148a37f6ee9cf498dc7308cbd81a1c b/sql/hive/src/test/resources/golden/macro-10-45148a37f6ee9cf498dc7308cbd81a1c new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/macro-11-f55b8684c77f6eefc2618ba79e5e0587 b/sql/hive/src/test/resources/golden/macro-11-f55b8684c77f6eefc2618ba79e5e0587 new file mode 100644 index 000000000000..573541ac9702 --- /dev/null +++ b/sql/hive/src/test/resources/golden/macro-11-f55b8684c77f6eefc2618ba79e5e0587 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/macro-12-62b999122975c2a5de8e49fee089c041 b/sql/hive/src/test/resources/golden/macro-12-62b999122975c2a5de8e49fee089c041 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/macro-13-87e53d2b4c84098e662779e8f0a59084 b/sql/hive/src/test/resources/golden/macro-13-87e53d2b4c84098e662779e8f0a59084 new file mode 100644 index 000000000000..d00491fd7e5b --- /dev/null +++ b/sql/hive/src/test/resources/golden/macro-13-87e53d2b4c84098e662779e8f0a59084 @@ -0,0 +1 @@ +1 diff --git a/sql/hive/src/test/resources/golden/macro-14-3a31df84432674ad410f44b137e32c2d b/sql/hive/src/test/resources/golden/macro-14-3a31df84432674ad410f44b137e32c2d new file 
mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/macro-15-56966c45104c0d9bc407e79538c2c029 b/sql/hive/src/test/resources/golden/macro-15-56966c45104c0d9bc407e79538c2c029 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/macro-16-56966c45104c0d9bc407e79538c2c029 b/sql/hive/src/test/resources/golden/macro-16-56966c45104c0d9bc407e79538c2c029 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/macro-17-b3864f1d19fdb88b3b74f6d74a0ba548 b/sql/hive/src/test/resources/golden/macro-17-b3864f1d19fdb88b3b74f6d74a0ba548 new file mode 100644 index 000000000000..f599e28b8ab0 --- /dev/null +++ b/sql/hive/src/test/resources/golden/macro-17-b3864f1d19fdb88b3b74f6d74a0ba548 @@ -0,0 +1 @@ +10 diff --git a/sql/hive/src/test/resources/golden/macro-18-bddb2fe17cd4d850c4462b7eb2b9bc2a b/sql/hive/src/test/resources/golden/macro-18-bddb2fe17cd4d850c4462b7eb2b9bc2a new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/macro-19-e3c828c372607b8bf7be00a99359b662 b/sql/hive/src/test/resources/golden/macro-19-e3c828c372607b8bf7be00a99359b662 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/macro-2-fde44c7854a9897acb4c2f78f24c8eec b/sql/hive/src/test/resources/golden/macro-2-fde44c7854a9897acb4c2f78f24c8eec new file mode 100644 index 000000000000..b49805ff631c --- /dev/null +++ b/sql/hive/src/test/resources/golden/macro-2-fde44c7854a9897acb4c2f78f24c8eec @@ -0,0 +1 @@ +0.8807970779778823 diff --git a/sql/hive/src/test/resources/golden/macro-20-cb252a243d59809930a4ff371cbfa292 b/sql/hive/src/test/resources/golden/macro-20-cb252a243d59809930a4ff371cbfa292 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/macro-21-cb252a243d59809930a4ff371cbfa292 b/sql/hive/src/test/resources/golden/macro-21-cb252a243d59809930a4ff371cbfa292 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/macro-3-ddc4cb920b0a68e06551cd34ae4e29ff b/sql/hive/src/test/resources/golden/macro-3-ddc4cb920b0a68e06551cd34ae4e29ff new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/macro-4-86292bbb7f147393c38bca051768dbda b/sql/hive/src/test/resources/golden/macro-4-86292bbb7f147393c38bca051768dbda new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/macro-5-ca270bff813e5ab18a6a799016693aa8 b/sql/hive/src/test/resources/golden/macro-5-ca270bff813e5ab18a6a799016693aa8 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/macro-6-8976be22af3aba0cc4905e014b4e24fe b/sql/hive/src/test/resources/golden/macro-6-8976be22af3aba0cc4905e014b4e24fe new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/macro-7-decde0a59183a393e580941c633d3c5c b/sql/hive/src/test/resources/golden/macro-7-decde0a59183a393e580941c633d3c5c new file mode 100644 index 000000000000..0cfbf08886fc --- /dev/null +++ b/sql/hive/src/test/resources/golden/macro-7-decde0a59183a393e580941c633d3c5c @@ -0,0 +1 @@ +2 diff --git a/sql/hive/src/test/resources/golden/macro-8-3d25ffda9ab348f3e39ad967fc0e5020 b/sql/hive/src/test/resources/golden/macro-8-3d25ffda9ab348f3e39ad967fc0e5020 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git 
a/sql/hive/src/test/resources/golden/macro-9-db5f5172704da1e6dd5d59c136b83e7e b/sql/hive/src/test/resources/golden/macro-9-db5f5172704da1e6dd5d59c136b83e7e new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/parenthesis_star_by-5-6888c7f7894910538d82eefa23443189 b/sql/hive/src/test/resources/golden/parenthesis_star_by-5-41d474f5e6d7c61c36f74b4bec4e9e44 similarity index 100% rename from sql/hive/src/test/resources/golden/parenthesis_star_by-5-6888c7f7894910538d82eefa23443189 rename to sql/hive/src/test/resources/golden/parenthesis_star_by-5-41d474f5e6d7c61c36f74b4bec4e9e44 diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-3-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_alter-3-2a91d52719cf4552ebeb867204552a26 index 501bb6ab32f2..7bb2c0ab4398 100644 --- a/sql/hive/src/test/resources/golden/show_create_table_alter-3-2a91d52719cf4552ebeb867204552a26 +++ b/sql/hive/src/test/resources/golden/show_create_table_alter-3-2a91d52719cf4552ebeb867204552a26 @@ -1,4 +1,4 @@ -CREATE TABLE `tmp_showcrt1`( +CREATE TABLE `tmp_showcrt1`( `key` smallint, `value` float) COMMENT 'temporary table' diff --git a/sql/hive/src/test/resources/golden/show_create_table_db_table-4-b585371b624cbab2616a49f553a870a0 b/sql/hive/src/test/resources/golden/show_create_table_db_table-4-b585371b624cbab2616a49f553a870a0 index 90f8415a1c6b..3cc1a57ee3a4 100644 --- a/sql/hive/src/test/resources/golden/show_create_table_db_table-4-b585371b624cbab2616a49f553a870a0 +++ b/sql/hive/src/test/resources/golden/show_create_table_db_table-4-b585371b624cbab2616a49f553a870a0 @@ -1,4 +1,4 @@ -CREATE TABLE `tmp_feng.tmp_showcrt`( +CREATE TABLE `tmp_feng.tmp_showcrt`( `key` string, `value` int) ROW FORMAT SERDE diff --git a/sql/hive/src/test/resources/golden/show_create_table_delimited-1-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_delimited-1-2a91d52719cf4552ebeb867204552a26 index 4ee22e523031..b51c71a71f91 100644 --- a/sql/hive/src/test/resources/golden/show_create_table_delimited-1-2a91d52719cf4552ebeb867204552a26 +++ b/sql/hive/src/test/resources/golden/show_create_table_delimited-1-2a91d52719cf4552ebeb867204552a26 @@ -1,4 +1,4 @@ -CREATE TABLE `tmp_showcrt1`( +CREATE TABLE `tmp_showcrt1`( `key` int, `value` string, `newvalue` bigint) diff --git a/sql/hive/src/test/resources/golden/show_create_table_serde-1-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_serde-1-2a91d52719cf4552ebeb867204552a26 index 6fda2570b53f..29189e1d860a 100644 --- a/sql/hive/src/test/resources/golden/show_create_table_serde-1-2a91d52719cf4552ebeb867204552a26 +++ b/sql/hive/src/test/resources/golden/show_create_table_serde-1-2a91d52719cf4552ebeb867204552a26 @@ -1,4 +1,4 @@ -CREATE TABLE `tmp_showcrt1`( +CREATE TABLE `tmp_showcrt1`( `key` int, `value` string, `newvalue` bigint) diff --git a/sql/hive/src/test/resources/golden/show_functions-0-45a7762c39f1b0f26f076220e2764043 b/sql/hive/src/test/resources/golden/show_functions-0-45a7762c39f1b0f26f076220e2764043 index 3049cd6243ad..1b283db3e774 100644 --- a/sql/hive/src/test/resources/golden/show_functions-0-45a7762c39f1b0f26f076220e2764043 +++ b/sql/hive/src/test/resources/golden/show_functions-0-45a7762c39f1b0f26f076220e2764043 @@ -17,6 +17,7 @@ ^ abs acos +add_months and array array_contains @@ -29,6 +30,7 @@ base64 between bin case +cbrt ceil ceiling coalesce @@ -47,7 +49,11 @@ covar_samp create_union cume_dist current_database 
+current_date +current_timestamp +current_user date_add +date_format date_sub datediff day @@ -65,6 +71,7 @@ ewah_bitmap_empty ewah_bitmap_or exp explode +factorial field find_in_set first_value @@ -73,6 +80,7 @@ format_number from_unixtime from_utc_timestamp get_json_object +greatest hash hex histogram_numeric @@ -81,6 +89,7 @@ if in in_file index +initcap inline instr isnotnull @@ -88,10 +97,13 @@ isnull java_method json_tuple lag +last_day last_value lcase lead +least length +levenshtein like ln locate @@ -109,11 +121,15 @@ max min minute month +months_between named_struct negative +next_day ngrams noop +noopstreaming noopwithmap +noopwithmapstreaming not ntile nvl @@ -147,10 +163,14 @@ rpad rtrim second sentences +shiftleft +shiftright +shiftrightunsigned sign sin size sort_array +soundex space split sqrt @@ -170,6 +190,7 @@ to_unix_timestamp to_utc_timestamp translate trim +trunc ucase unbase64 unhex diff --git a/sql/hive/src/test/resources/golden/show_tblproperties-1-be4adb893c7f946ebd76a648ce3cc1ae b/sql/hive/src/test/resources/golden/show_tblproperties-1-be4adb893c7f946ebd76a648ce3cc1ae index 0f6cc6f44f1f..fdf701f96280 100644 --- a/sql/hive/src/test/resources/golden/show_tblproperties-1-be4adb893c7f946ebd76a648ce3cc1ae +++ b/sql/hive/src/test/resources/golden/show_tblproperties-1-be4adb893c7f946ebd76a648ce3cc1ae @@ -1 +1 @@ -Table tmpfoo does not have property: bar +Table default.tmpfoo does not have property: bar diff --git a/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 b/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 deleted file mode 100644 index 84a31a5a6970..000000000000 --- a/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 +++ /dev/null @@ -1 +0,0 @@ --0.001 diff --git a/sql/hive/src/test/resources/golden/timestamp cast #7-0-1d70654217035f8ce5f64344f4c5a80f b/sql/hive/src/test/resources/golden/timestamp cast #7-0-1d70654217035f8ce5f64344f4c5a80f deleted file mode 100644 index 3fbedf693b51..000000000000 --- a/sql/hive/src/test/resources/golden/timestamp cast #7-0-1d70654217035f8ce5f64344f4c5a80f +++ /dev/null @@ -1 +0,0 @@ --2 diff --git a/sql/hive/src/test/resources/golden/udf_date_add-1-efb60fcbd6d78ad35257fb1ec39ace2 b/sql/hive/src/test/resources/golden/udf_date_add-1-efb60fcbd6d78ad35257fb1ec39ace2 index 3c91e138d7bd..d8ec084f0b2b 100644 --- a/sql/hive/src/test/resources/golden/udf_date_add-1-efb60fcbd6d78ad35257fb1ec39ace2 +++ b/sql/hive/src/test/resources/golden/udf_date_add-1-efb60fcbd6d78ad35257fb1ec39ace2 @@ -1,5 +1,5 @@ date_add(start_date, num_days) - Returns the date that is num_days after start_date. start_date is a string in the format 'yyyy-MM-dd HH:mm:ss' or 'yyyy-MM-dd'. num_days is a number. The time part of start_date is ignored. Example: - > SELECT date_add('2009-30-07', 1) FROM src LIMIT 1; - '2009-31-07' + > SELECT date_add('2009-07-30', 1) FROM src LIMIT 1; + '2009-07-31' diff --git a/sql/hive/src/test/resources/golden/udf_date_sub-1-7efeb74367835ade71e5e42b22f8ced4 b/sql/hive/src/test/resources/golden/udf_date_sub-1-7efeb74367835ade71e5e42b22f8ced4 index 29d663f35c58..169c50003625 100644 --- a/sql/hive/src/test/resources/golden/udf_date_sub-1-7efeb74367835ade71e5e42b22f8ced4 +++ b/sql/hive/src/test/resources/golden/udf_date_sub-1-7efeb74367835ade71e5e42b22f8ced4 @@ -1,5 +1,5 @@ date_sub(start_date, num_days) - Returns the date that is num_days before start_date. 
start_date is a string in the format 'yyyy-MM-dd HH:mm:ss' or 'yyyy-MM-dd'. num_days is a number. The time part of start_date is ignored. Example: - > SELECT date_sub('2009-30-07', 1) FROM src LIMIT 1; - '2009-29-07' + > SELECT date_sub('2009-07-30', 1) FROM src LIMIT 1; + '2009-07-29' diff --git a/sql/hive/src/test/resources/golden/udf_datediff-1-34ae7a68b13c2bc9a89f61acf2edd4c5 b/sql/hive/src/test/resources/golden/udf_datediff-1-34ae7a68b13c2bc9a89f61acf2edd4c5 index 7ccaee7ad3bd..42197f7ad3e5 100644 --- a/sql/hive/src/test/resources/golden/udf_datediff-1-34ae7a68b13c2bc9a89f61acf2edd4c5 +++ b/sql/hive/src/test/resources/golden/udf_datediff-1-34ae7a68b13c2bc9a89f61acf2edd4c5 @@ -1,5 +1,5 @@ datediff(date1, date2) - Returns the number of days between date1 and date2 date1 and date2 are strings in the format 'yyyy-MM-dd HH:mm:ss' or 'yyyy-MM-dd'. The time parts are ignored.If date1 is earlier than date2, the result is negative. Example: - > SELECT datediff('2009-30-07', '2009-31-07') FROM src LIMIT 1; + > SELECT datediff('2009-07-30', '2009-07-31') FROM src LIMIT 1; 1 diff --git a/sql/hive/src/test/resources/golden/udf_day-0-c4c503756384ff1220222d84fd25e756 b/sql/hive/src/test/resources/golden/udf_day-0-c4c503756384ff1220222d84fd25e756 index d4017178b4e6..09703d10eab7 100644 --- a/sql/hive/src/test/resources/golden/udf_day-0-c4c503756384ff1220222d84fd25e756 +++ b/sql/hive/src/test/resources/golden/udf_day-0-c4c503756384ff1220222d84fd25e756 @@ -1 +1 @@ -day(date) - Returns the date of the month of date +day(param) - Returns the day of the month of date/timestamp, or day component of interval diff --git a/sql/hive/src/test/resources/golden/udf_day-1-87168babe1110fe4c38269843414ca4 b/sql/hive/src/test/resources/golden/udf_day-1-87168babe1110fe4c38269843414ca4 index 6135aafa5086..7c0ec1dc3be5 100644 --- a/sql/hive/src/test/resources/golden/udf_day-1-87168babe1110fe4c38269843414ca4 +++ b/sql/hive/src/test/resources/golden/udf_day-1-87168babe1110fe4c38269843414ca4 @@ -1,6 +1,9 @@ -day(date) - Returns the date of the month of date +day(param) - Returns the day of the month of date/timestamp, or day component of interval Synonyms: dayofmonth -date is a string in the format of 'yyyy-MM-dd HH:mm:ss' or 'yyyy-MM-dd'. -Example: - > SELECT day('2009-30-07', 1) FROM src LIMIT 1; +param can be one of: +1. A string in the format of 'yyyy-MM-dd HH:mm:ss' or 'yyyy-MM-dd'. +2. A date value +3. A timestamp value +4. 
A day-time interval valueExample: + > SELECT day('2009-07-30') FROM src LIMIT 1; 30 diff --git a/sql/hive/src/test/resources/golden/udf_dayofmonth-0-7b2caf942528656555cf19c261a18502 b/sql/hive/src/test/resources/golden/udf_dayofmonth-0-7b2caf942528656555cf19c261a18502 index 47a7018d9d5a..c37eb0ec2e96 100644 --- a/sql/hive/src/test/resources/golden/udf_dayofmonth-0-7b2caf942528656555cf19c261a18502 +++ b/sql/hive/src/test/resources/golden/udf_dayofmonth-0-7b2caf942528656555cf19c261a18502 @@ -1 +1 @@ -dayofmonth(date) - Returns the date of the month of date +dayofmonth(param) - Returns the day of the month of date/timestamp, or day component of interval diff --git a/sql/hive/src/test/resources/golden/udf_dayofmonth-1-ca24d07102ad264d79ff30c64a73a7e8 b/sql/hive/src/test/resources/golden/udf_dayofmonth-1-ca24d07102ad264d79ff30c64a73a7e8 index d9490e20a3b6..9e931f649914 100644 --- a/sql/hive/src/test/resources/golden/udf_dayofmonth-1-ca24d07102ad264d79ff30c64a73a7e8 +++ b/sql/hive/src/test/resources/golden/udf_dayofmonth-1-ca24d07102ad264d79ff30c64a73a7e8 @@ -1,6 +1,9 @@ -dayofmonth(date) - Returns the date of the month of date +dayofmonth(param) - Returns the day of the month of date/timestamp, or day component of interval Synonyms: day -date is a string in the format of 'yyyy-MM-dd HH:mm:ss' or 'yyyy-MM-dd'. -Example: - > SELECT dayofmonth('2009-30-07', 1) FROM src LIMIT 1; +param can be one of: +1. A string in the format of 'yyyy-MM-dd HH:mm:ss' or 'yyyy-MM-dd'. +2. A date value +3. A timestamp value +4. A day-time interval valueExample: + > SELECT dayofmonth('2009-07-30') FROM src LIMIT 1; 30 diff --git a/sql/hive/src/test/resources/golden/udf_if-0-b7ffa85b5785cccef2af1b285348cc2c b/sql/hive/src/test/resources/golden/udf_if-0-b7ffa85b5785cccef2af1b285348cc2c index 2cf0d9d61882..ce583fe81ff6 100644 --- a/sql/hive/src/test/resources/golden/udf_if-0-b7ffa85b5785cccef2af1b285348cc2c +++ b/sql/hive/src/test/resources/golden/udf_if-0-b7ffa85b5785cccef2af1b285348cc2c @@ -1 +1 @@ -There is no documentation for function 'if' +IF(expr1,expr2,expr3) - If expr1 is TRUE (expr1 <> 0 and expr1 <> NULL) then IF() returns expr2; otherwise it returns expr3. IF() returns a numeric or string value, depending on the context in which it is used. diff --git a/sql/hive/src/test/resources/golden/udf_if-1-30cf7f51f92b5684e556deff3032d49a b/sql/hive/src/test/resources/golden/udf_if-1-30cf7f51f92b5684e556deff3032d49a index 2cf0d9d61882..ce583fe81ff6 100644 --- a/sql/hive/src/test/resources/golden/udf_if-1-30cf7f51f92b5684e556deff3032d49a +++ b/sql/hive/src/test/resources/golden/udf_if-1-30cf7f51f92b5684e556deff3032d49a @@ -1 +1 @@ -There is no documentation for function 'if' +IF(expr1,expr2,expr3) - If expr1 is TRUE (expr1 <> 0 and expr1 <> NULL) then IF() returns expr2; otherwise it returns expr3. IF() returns a numeric or string value, depending on the context in which it is used. diff --git a/sql/hive/src/test/resources/golden/udf_if-1-b7ffa85b5785cccef2af1b285348cc2c b/sql/hive/src/test/resources/golden/udf_if-1-b7ffa85b5785cccef2af1b285348cc2c index 2cf0d9d61882..ce583fe81ff6 100644 --- a/sql/hive/src/test/resources/golden/udf_if-1-b7ffa85b5785cccef2af1b285348cc2c +++ b/sql/hive/src/test/resources/golden/udf_if-1-b7ffa85b5785cccef2af1b285348cc2c @@ -1 +1 @@ -There is no documentation for function 'if' +IF(expr1,expr2,expr3) - If expr1 is TRUE (expr1 <> 0 and expr1 <> NULL) then IF() returns expr2; otherwise it returns expr3. IF() returns a numeric or string value, depending on the context in which it is used. 
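(Note, not part of the patch: the udf_day, udf_dayofmonth and udf_if golden files above record corrected Hive documentation strings — day()/dayofmonth() take a single date/timestamp argument, the date literals now use yyyy-MM-dd order, and IF(expr1, expr2, expr3) selects between two branches. A minimal spark-shell style sketch of those corrected examples is shown below; it assumes a HiveContext bound to `hiveContext` and the standard `src` test table used throughout these suites, and is illustrative only.)

```scala
// Illustrative sketch only; mirrors the corrected examples recorded in the
// golden files above. Assumes `hiveContext: HiveContext` and the `src` table.
hiveContext.sql("SELECT date_add('2009-07-30', 1) FROM src LIMIT 1").show()  // 2009-07-31
hiveContext.sql("SELECT date_sub('2009-07-30', 1) FROM src LIMIT 1").show()  // 2009-07-29
hiveContext.sql(
  "SELECT day('2009-07-30'), dayofmonth('2009-07-30') FROM src LIMIT 1").show()  // 30, 30
hiveContext.sql("SELECT IF(1 = 1, 'yes', 'no') FROM src LIMIT 1").show()     // yes
```

(The surrounding golden answer files follow the same convention as the rest of this patch: each file name encodes the test name plus a query hash, and holds the expected output for that query.)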
diff --git a/sql/hive/src/test/resources/golden/udf_if-2-30cf7f51f92b5684e556deff3032d49a b/sql/hive/src/test/resources/golden/udf_if-2-30cf7f51f92b5684e556deff3032d49a index 2cf0d9d61882..ce583fe81ff6 100644 --- a/sql/hive/src/test/resources/golden/udf_if-2-30cf7f51f92b5684e556deff3032d49a +++ b/sql/hive/src/test/resources/golden/udf_if-2-30cf7f51f92b5684e556deff3032d49a @@ -1 +1 @@ -There is no documentation for function 'if' +IF(expr1,expr2,expr3) - If expr1 is TRUE (expr1 <> 0 and expr1 <> NULL) then IF() returns expr2; otherwise it returns expr3. IF() returns a numeric or string value, depending on the context in which it is used. diff --git a/sql/hive/src/test/resources/golden/udf_minute-0-9a38997c1f41f4afe00faa0abc471aee b/sql/hive/src/test/resources/golden/udf_minute-0-9a38997c1f41f4afe00faa0abc471aee index 231e4f382566..06650592f8d3 100644 --- a/sql/hive/src/test/resources/golden/udf_minute-0-9a38997c1f41f4afe00faa0abc471aee +++ b/sql/hive/src/test/resources/golden/udf_minute-0-9a38997c1f41f4afe00faa0abc471aee @@ -1 +1 @@ -minute(date) - Returns the minute of date +minute(param) - Returns the minute component of the string/timestamp/interval diff --git a/sql/hive/src/test/resources/golden/udf_minute-1-16995573ac4f4a1b047ad6ee88699e48 b/sql/hive/src/test/resources/golden/udf_minute-1-16995573ac4f4a1b047ad6ee88699e48 index ea842ea174ae..08ddc19b84d8 100644 --- a/sql/hive/src/test/resources/golden/udf_minute-1-16995573ac4f4a1b047ad6ee88699e48 +++ b/sql/hive/src/test/resources/golden/udf_minute-1-16995573ac4f4a1b047ad6ee88699e48 @@ -1,6 +1,8 @@ -minute(date) - Returns the minute of date -date is a string in the format of 'yyyy-MM-dd HH:mm:ss' or 'HH:mm:ss'. -Example: +minute(param) - Returns the minute component of the string/timestamp/interval +param can be one of: +1. A string in the format of 'yyyy-MM-dd HH:mm:ss' or 'HH:mm:ss'. +2. A timestamp value +3. A day-time interval valueExample: > SELECT minute('2009-07-30 12:58:59') FROM src LIMIT 1; 58 > SELECT minute('12:58:59') FROM src LIMIT 1; diff --git a/sql/hive/src/test/resources/golden/udf_month-0-9a38997c1f41f4afe00faa0abc471aee b/sql/hive/src/test/resources/golden/udf_month-0-9a38997c1f41f4afe00faa0abc471aee index 231e4f382566..06650592f8d3 100644 --- a/sql/hive/src/test/resources/golden/udf_month-0-9a38997c1f41f4afe00faa0abc471aee +++ b/sql/hive/src/test/resources/golden/udf_month-0-9a38997c1f41f4afe00faa0abc471aee @@ -1 +1 @@ -minute(date) - Returns the minute of date +minute(param) - Returns the minute component of the string/timestamp/interval diff --git a/sql/hive/src/test/resources/golden/udf_month-1-16995573ac4f4a1b047ad6ee88699e48 b/sql/hive/src/test/resources/golden/udf_month-1-16995573ac4f4a1b047ad6ee88699e48 index ea842ea174ae..08ddc19b84d8 100644 --- a/sql/hive/src/test/resources/golden/udf_month-1-16995573ac4f4a1b047ad6ee88699e48 +++ b/sql/hive/src/test/resources/golden/udf_month-1-16995573ac4f4a1b047ad6ee88699e48 @@ -1,6 +1,8 @@ -minute(date) - Returns the minute of date -date is a string in the format of 'yyyy-MM-dd HH:mm:ss' or 'HH:mm:ss'. -Example: +minute(param) - Returns the minute component of the string/timestamp/interval +param can be one of: +1. A string in the format of 'yyyy-MM-dd HH:mm:ss' or 'HH:mm:ss'. +2. A timestamp value +3. 
A day-time interval valueExample: > SELECT minute('2009-07-30 12:58:59') FROM src LIMIT 1; 58 > SELECT minute('12:58:59') FROM src LIMIT 1; diff --git a/sql/hive/src/test/resources/golden/udf_std-1-6759bde0e50a3607b7c3fd5a93cbd027 b/sql/hive/src/test/resources/golden/udf_std-1-6759bde0e50a3607b7c3fd5a93cbd027 index d54ebfbd6fb1..a529b107ff21 100644 --- a/sql/hive/src/test/resources/golden/udf_std-1-6759bde0e50a3607b7c3fd5a93cbd027 +++ b/sql/hive/src/test/resources/golden/udf_std-1-6759bde0e50a3607b7c3fd5a93cbd027 @@ -1,2 +1,2 @@ std(x) - Returns the standard deviation of a set of numbers -Synonyms: stddev_pop, stddev +Synonyms: stddev, stddev_pop diff --git a/sql/hive/src/test/resources/golden/udf_stddev-1-18e1d598820013453fad45852e1a303d b/sql/hive/src/test/resources/golden/udf_stddev-1-18e1d598820013453fad45852e1a303d index 5f674788180e..ac3176a38254 100644 --- a/sql/hive/src/test/resources/golden/udf_stddev-1-18e1d598820013453fad45852e1a303d +++ b/sql/hive/src/test/resources/golden/udf_stddev-1-18e1d598820013453fad45852e1a303d @@ -1,2 +1,2 @@ stddev(x) - Returns the standard deviation of a set of numbers -Synonyms: stddev_pop, std +Synonyms: std, stddev_pop diff --git a/sql/hive/src/test/resources/golden/union3-0-99620f72f0282904846a596ca5b3e46c b/sql/hive/src/test/resources/golden/union3-0-99620f72f0282904846a596ca5b3e46c new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/union3-2-90ca96ea59fd45cf0af8c020ae77c908 b/sql/hive/src/test/resources/golden/union3-2-90ca96ea59fd45cf0af8c020ae77c908 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/union3-3-8fc63f8edb2969a63cd4485f1867ba97 b/sql/hive/src/test/resources/golden/union3-3-72b149ccaef751bcfe55d5ca37cb5fd7 similarity index 100% rename from sql/hive/src/test/resources/golden/union3-3-8fc63f8edb2969a63cd4485f1867ba97 rename to sql/hive/src/test/resources/golden/union3-3-72b149ccaef751bcfe55d5ca37cb5fd7 diff --git a/sql/hive/src/test/resources/log4j.properties b/sql/hive/src/test/resources/log4j.properties index 92eaf1f2795b..fea3404769d9 100644 --- a/sql/hive/src/test/resources/log4j.properties +++ b/sql/hive/src/test/resources/log4j.properties @@ -48,9 +48,14 @@ log4j.logger.hive.log=OFF log4j.additivity.parquet.hadoop.ParquetRecordReader=false log4j.logger.parquet.hadoop.ParquetRecordReader=OFF +log4j.additivity.org.apache.parquet.hadoop.ParquetRecordReader=false +log4j.logger.org.apache.parquet.hadoop.ParquetRecordReader=OFF + +log4j.additivity.org.apache.parquet.hadoop.ParquetOutputCommitter=false +log4j.logger.org.apache.parquet.hadoop.ParquetOutputCommitter=OFF + log4j.additivity.hive.ql.metadata.Hive=false log4j.logger.hive.ql.metadata.Hive=OFF log4j.additivity.org.apache.hadoop.hive.ql.io.RCFile=false log4j.logger.org.apache.hadoop.hive.ql.io.RCFile=ERROR - diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/parenthesis_star_by.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/parenthesis_star_by.q index 9e036c1a91d3..e911fbf2d2c5 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/parenthesis_star_by.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/parenthesis_star_by.q @@ -5,6 +5,6 @@ SELECT * FROM (SELECT key, value FROM src DISTRIBUTE BY key, value)t ORDER BY ke SELECT key, value FROM src CLUSTER BY (key, value); -SELECT key, value FROM src ORDER BY (key ASC, value ASC); +SELECT key, value FROM src ORDER BY key ASC, value ASC; SELECT 
key, value FROM src SORT BY (key, value); SELECT * FROM (SELECT key, value FROM src DISTRIBUTE BY (key, value))t ORDER BY key, value; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/union3.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/union3.q index b26a2e2799f7..a989800cbf85 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/union3.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/union3.q @@ -1,42 +1,41 @@ +-- SORT_QUERY_RESULTS explain SELECT * FROM ( SELECT 1 AS id FROM (SELECT * FROM src LIMIT 1) s1 - CLUSTER BY id UNION ALL SELECT 2 AS id FROM (SELECT * FROM src LIMIT 1) s1 - CLUSTER BY id UNION ALL SELECT 3 AS id FROM (SELECT * FROM src LIMIT 1) s2 UNION ALL SELECT 4 AS id FROM (SELECT * FROM src LIMIT 1) s2 + CLUSTER BY id ) a; CREATE TABLE union_out (id int); -insert overwrite table union_out +insert overwrite table union_out SELECT * FROM ( SELECT 1 AS id FROM (SELECT * FROM src LIMIT 1) s1 - CLUSTER BY id UNION ALL SELECT 2 AS id FROM (SELECT * FROM src LIMIT 1) s1 - CLUSTER BY id UNION ALL SELECT 3 AS id FROM (SELECT * FROM src LIMIT 1) s2 UNION ALL SELECT 4 AS id FROM (SELECT * FROM src LIMIT 1) s2 + CLUSTER BY id ) a; -select * from union_out cluster by id; +select * from union_out; diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 39d315aaeab5..99478e82d419 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -19,14 +19,15 @@ package org.apache.spark.sql.hive import java.io.File -import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.{SaveMode, AnalysisException, DataFrame, QueryTest} +import org.apache.spark.sql.execution.columnar.InMemoryColumnarTableScan +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode} import org.apache.spark.storage.RDDBlockId import org.apache.spark.util.Utils -class CachedTableSuite extends QueryTest { +class CachedTableSuite extends QueryTest with TestHiveSingleton { + import hiveContext._ def rddIdOf(tableName: String): Int = { val executedPlan = table(tableName).queryExecution.executedPlan @@ -95,18 +96,18 @@ class CachedTableSuite extends QueryTest { test("correct error on uncache of non-cached table") { intercept[IllegalArgumentException] { - TestHive.uncacheTable("src") + hiveContext.uncacheTable("src") } } test("'CACHE TABLE' and 'UNCACHE TABLE' HiveQL statement") { - TestHive.sql("CACHE TABLE src") + sql("CACHE TABLE src") assertCached(table("src")) - assert(TestHive.isCached("src"), "Table 'src' should be cached") + assert(hiveContext.isCached("src"), "Table 'src' should be cached") - TestHive.sql("UNCACHE TABLE src") + sql("UNCACHE TABLE src") assertCached(table("src"), 0) - assert(!TestHive.isCached("src"), "Table 'src' should not be cached") + assert(!hiveContext.isCached("src"), "Table 'src' should not be cached") } test("CACHE TABLE tableName AS SELECT * FROM anotherTable") { @@ -203,4 +204,14 @@ class CachedTableSuite extends QueryTest { sql("DROP TABLE refreshTable") Utils.deleteRecursively(tempPath) } + + 
test("SPARK-11246 cache parquet table") { + sql("CREATE TABLE cachedTable STORED AS PARQUET AS SELECT 1") + + cacheTable("cachedTable") + val sparkPlan = sql("SELECT * FROM cachedTable").queryExecution.sparkPlan + assert(sparkPlan.collect { case e: InMemoryColumnarTableScan => e }.size === 1) + + sql("DROP TABLE cachedTable") + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ClasspathDependenciesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ClasspathDependenciesSuite.scala new file mode 100644 index 000000000000..34b2edb44b03 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ClasspathDependenciesSuite.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.net.URL + +import org.apache.spark.SparkFunSuite + +/** + * Verify that some classes load and that others are not found on the classpath. + * + * + * This is used to detect classpath and shading conflict, especially between + * Spark's required Kryo version and that which can be found in some Hive versions. + */ +class ClasspathDependenciesSuite extends SparkFunSuite { + private val classloader = this.getClass.getClassLoader + + private def assertLoads(classname: String): Unit = { + val resourceURL: URL = Option(findResource(classname)).getOrElse { + fail(s"Class $classname not found as ${resourceName(classname)}") + } + + logInfo(s"Class $classname at $resourceURL") + classloader.loadClass(classname) + } + + private def assertLoads(classes: String*): Unit = { + classes.foreach(assertLoads) + } + + private def findResource(classname: String): URL = { + val resource = resourceName(classname) + classloader.getResource(resource) + } + + private def resourceName(classname: String): String = { + classname.replace(".", "/") + ".class" + } + + private def assertClassNotFound(classname: String): Unit = { + Option(findResource(classname)).foreach { resourceURL => + fail(s"Class $classname found at $resourceURL") + } + + intercept[ClassNotFoundException] { + classloader.loadClass(classname) + } + } + + private def assertClassNotFound(classes: String*): Unit = { + classes.foreach(assertClassNotFound) + } + + private val KRYO = "com.esotericsoftware.kryo.Kryo" + + private val SPARK_HIVE = "org.apache.hive." + private val SPARK_SHADED = "org.spark-project.hive.shaded." 
+ + test("shaded Protobuf") { + assertLoads(SPARK_SHADED + "com.google.protobuf.ServiceException") + } + + test("hive-common") { + assertLoads("org.apache.hadoop.hive.conf.HiveConf") + } + + test("hive-exec") { + assertLoads("org.apache.hadoop.hive.ql.CommandNeedRetryException") + } + + private val STD_INSTANTIATOR = "org.objenesis.strategy.StdInstantiatorStrategy" + + test("unshaded kryo") { + assertLoads(KRYO, STD_INSTANTIATOR) + } + + test("Forbidden Dependencies") { + assertClassNotFound( + SPARK_HIVE + KRYO, + SPARK_SHADED + KRYO, + "org.apache.hive." + KRYO, + "com.esotericsoftware.shaded." + STD_INSTANTIATOR, + SPARK_HIVE + "com.esotericsoftware.shaded." + STD_INSTANTIATOR, + "org.apache.hive.com.esotericsoftware.shaded." + STD_INSTANTIATOR + ) + } + + test("parquet-hadoop-bundle") { + assertLoads( + "parquet.hadoop.ParquetOutputFormat", + "parquet.hadoop.ParquetInputFormat" + ) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala index 30f5313d2b81..cf737836939f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala @@ -22,12 +22,12 @@ import scala.util.Try import org.scalatest.BeforeAndAfter import org.apache.spark.sql.catalyst.util.quietly -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.{AnalysisException, QueryTest} -class ErrorPositionSuite extends QueryTest with BeforeAndAfter { +class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter { + import hiveContext.implicits._ before { Seq((1, 1, 1)).toDF("a", "a", "b").registerTempTable("dupAttributes") @@ -122,7 +122,7 @@ class ErrorPositionSuite extends QueryTest with BeforeAndAfter { test(name) { val error = intercept[AnalysisException] { - quietly(sql(query)) + quietly(hiveContext.sql(query)) } assert(!error.getMessage.contains("Seq(")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala index fb10f8583da9..9864acf76526 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala @@ -17,26 +17,27 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.{DataFrame, QueryTest} +import org.apache.spark.sql.{DataFrame, QueryTest, Row} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.scalatest.BeforeAndAfterAll // TODO ideally we should put the test suite into the package `sql`, as // `hive` package is optional in compiling, however, `SQLContext.sql` doesn't // support the `cube` or `rollup` yet. 
-class HiveDataFrameAnalyticsSuite extends QueryTest with BeforeAndAfterAll { +class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with BeforeAndAfterAll { + import hiveContext.implicits._ + import hiveContext.sql + private var testData: DataFrame = _ override def beforeAll() { - testData = Seq((1, 2), (2, 4)).toDF("a", "b") - TestHive.registerDataFrameAsTable(testData, "mytable") + testData = Seq((1, 2), (2, 2), (3, 4)).toDF("a", "b") + hiveContext.registerDataFrameAsTable(testData, "mytable") } override def afterAll(): Unit = { - TestHive.dropTempTable("mytable") + hiveContext.dropTempTable("mytable") } test("rollup") { @@ -51,6 +52,17 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with BeforeAndAfterAll { ) } + test("collect functions") { + checkAnswer( + testData.select(collect_list($"a"), collect_list($"b")), + Seq(Row(Seq(1, 2, 3), Seq(2, 2, 4))) + ) + checkAnswer( + testData.select(collect_set($"a"), collect_set($"b")), + Seq(Row(Seq(1, 2, 3), Seq(2, 4))) + ) + } + test("cube") { checkAnswer( testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala index 52e782768cb7..f621367eb553 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.{Row, QueryTest} -import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.hive.test.TestHiveSingleton - -class HiveDataFrameJoinSuite extends QueryTest { +class HiveDataFrameJoinSuite extends QueryTest with TestHiveSingleton { + import hiveContext.implicits._ // We should move this into SQL package if we make case sensitivity configurable in SQL. test("join - self join auto resolve ambiguity with case insensitivity") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameSuite.scala new file mode 100644 index 000000000000..7fdc5d71937f --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameSuite.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.QueryTest + +class HiveDataFrameSuite extends QueryTest with TestHiveSingleton { + test("table name with schema") { + // regression test for SPARK-11778 + hiveContext.sql("create schema usrdb") + hiveContext.sql("create table usrdb.test(c int)") + hiveContext.read.table("usrdb.test") + hiveContext.sql("drop table usrdb.test") + hiveContext.sql("drop schema usrdb") + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 99e95fb92130..8bb9058cd74e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -30,6 +30,7 @@ import org.apache.hadoop.io.LongWritable import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.util.{MapData, GenericArrayData, ArrayBasedMapData} import org.apache.spark.sql.types._ import org.apache.spark.sql.Row @@ -133,8 +134,8 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { } } - def checkValues(row1: Seq[Any], row2: InternalRow): Unit = { - row1.zip(row2.toSeq).foreach { case (r1, r2) => + def checkValues(row1: Seq[Any], row2: InternalRow, row2Schema: StructType): Unit = { + row1.zip(row2.toSeq(row2Schema)).foreach { case (r1, r2) => checkValue(r1, r2) } } @@ -211,8 +212,10 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { case (t, idx) => StructField(s"c_$idx", t) }) val inspector = toInspector(dt) - checkValues(row, - unwrap(wrap(InternalRow.fromSeq(row), inspector, dt), inspector).asInstanceOf[InternalRow]) + checkValues( + row, + unwrap(wrap(InternalRow.fromSeq(row), inspector, dt), inspector).asInstanceOf[InternalRow], + dt) checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt))) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 983c013bcf86..d63f3d399652 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -17,31 +17,148 @@ package org.apache.spark.sql.hive -import org.apache.spark.{Logging, SparkFunSuite} -import org.apache.spark.sql.hive.test.TestHive +import java.io.File -import org.apache.spark.sql.test.ExamplePointUDT -import org.apache.spark.sql.types.StructType +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.hive.client.{ExternalTable, ManagedTable} +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.{ExamplePointUDT, SQLTestUtils} +import org.apache.spark.sql.types.{DecimalType, StringType, StructType} +import org.apache.spark.sql.{SQLConf, QueryTest, Row, SaveMode} -class HiveMetastoreCatalogSuite extends SparkFunSuite with Logging { +class HiveMetastoreCatalogSuite extends SparkFunSuite with TestHiveSingleton { + import hiveContext.implicits._ test("struct field should accept underscore in sub-column name") { - val metastr = "struct" - - val datatype = HiveMetastoreTypes.toDataType(metastr) - assert(datatype.isInstanceOf[StructType]) + val hiveTypeStr = "struct" + val dateType = 
HiveMetastoreTypes.toDataType(hiveTypeStr) + assert(dateType.isInstanceOf[StructType]) } test("udt to metastore type conversion") { val udt = new ExamplePointUDT - assert(HiveMetastoreTypes.toMetastoreType(udt) === - HiveMetastoreTypes.toMetastoreType(udt.sqlType)) + assertResult(HiveMetastoreTypes.toMetastoreType(udt.sqlType)) { + HiveMetastoreTypes.toMetastoreType(udt) + } } test("duplicated metastore relations") { - import TestHive.implicits._ - val df = TestHive.sql("SELECT * FROM src") + val df = hiveContext.sql("SELECT * FROM src") logInfo(df.queryExecution.toString) df.as('a).join(df.as('b), $"a.key" === $"b.key") } } + +class DataSourceWithHiveMetastoreCatalogSuite + extends QueryTest with SQLTestUtils with TestHiveSingleton { + import hiveContext._ + import testImplicits._ + + private val testDF = range(1, 3).select( + ('id + 0.1) cast DecimalType(10, 3) as 'd1, + 'id cast StringType as 'd2 + ).coalesce(1) + + Seq( + "parquet" -> ( + "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", + "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat", + "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe" + ), + + "orc" -> ( + "org.apache.hadoop.hive.ql.io.orc.OrcInputFormat", + "org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat", + "org.apache.hadoop.hive.ql.io.orc.OrcSerde" + ) + ).foreach { case (provider, (inputFormat, outputFormat, serde)) => + test(s"Persist non-partitioned $provider relation into metastore as managed table") { + withTable("t") { + withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "true") { + testDF + .write + .mode(SaveMode.Overwrite) + .format(provider) + .saveAsTable("t") + } + + val hiveTable = catalog.client.getTable("default", "t") + assert(hiveTable.inputFormat === Some(inputFormat)) + assert(hiveTable.outputFormat === Some(outputFormat)) + assert(hiveTable.serde === Some(serde)) + + assert(!hiveTable.isPartitioned) + assert(hiveTable.tableType === ManagedTable) + + val columns = hiveTable.schema + assert(columns.map(_.name) === Seq("d1", "d2")) + assert(columns.map(_.hiveType) === Seq("decimal(10,3)", "string")) + + checkAnswer(table("t"), testDF) + assert(runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) + } + } + + test(s"Persist non-partitioned $provider relation into metastore as external table") { + withTempPath { dir => + withTable("t") { + val path = dir.getCanonicalFile + + withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "true") { + testDF + .write + .mode(SaveMode.Overwrite) + .format(provider) + .option("path", path.toString) + .saveAsTable("t") + } + + val hiveTable = catalog.client.getTable("default", "t") + assert(hiveTable.inputFormat === Some(inputFormat)) + assert(hiveTable.outputFormat === Some(outputFormat)) + assert(hiveTable.serde === Some(serde)) + + assert(hiveTable.tableType === ExternalTable) + assert(hiveTable.location.get === path.toURI.toString.stripSuffix(File.separator)) + + val columns = hiveTable.schema + assert(columns.map(_.name) === Seq("d1", "d2")) + assert(columns.map(_.hiveType) === Seq("decimal(10,3)", "string")) + + checkAnswer(table("t"), testDF) + assert(runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) + } + } + } + + test(s"Persist non-partitioned $provider relation into metastore as managed table using CTAS") { + withTempPath { dir => + withTable("t") { + val path = dir.getCanonicalPath + + sql( + s"""CREATE TABLE t USING $provider + |OPTIONS (path '$path') + |AS SELECT 1 AS d1, "val_1" AS d2 + """.stripMargin) + + val hiveTable = 
catalog.client.getTable("default", "t") + assert(hiveTable.inputFormat === Some(inputFormat)) + assert(hiveTable.outputFormat === Some(outputFormat)) + assert(hiveTable.serde === Some(serde)) + + assert(hiveTable.isPartitioned === false) + assert(hiveTable.tableType === ExternalTable) + assert(hiveTable.partitionColumns.length === 0) + + val columns = hiveTable.schema + assert(columns.map(_.name) === Seq("d1", "d2")) + assert(columns.map(_.hiveType) === Seq("int", "string")) + + checkAnswer(table("t"), Row(1, "val_1")) + assert(runSqlHive("SELECT * FROM t") === Seq("1\tval_1")) + } + } + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala index a45c2d957278..5596ec6882ea 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala @@ -17,16 +17,13 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.parquet.ParquetTest +import org.apache.spark.sql.execution.datasources.parquet.ParquetTest +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.{QueryTest, Row} case class Cases(lower: String, UPPER: String) -class HiveParquetSuite extends QueryTest with ParquetTest { - val sqlContext = TestHive - - import sqlContext._ +class HiveParquetSuite extends QueryTest with ParquetTest with TestHiveSingleton { test("Case insensitive attribute names") { withParquetTable((1 to 4).map(i => Cases(i.toString, i.toString)), "cases") { @@ -54,7 +51,7 @@ class HiveParquetSuite extends QueryTest with ParquetTest { test("Converting Hive to Parquet Table via saveAsParquetFile") { withTempPath { dir => sql("SELECT * FROM src").write.parquet(dir.getCanonicalPath) - read.parquet(dir.getCanonicalPath).registerTempTable("p") + hiveContext.read.parquet(dir.getCanonicalPath).registerTempTable("p") withTempTable("p") { checkAnswer( sql("SELECT * FROM src ORDER BY key"), @@ -67,7 +64,7 @@ class HiveParquetSuite extends QueryTest with ParquetTest { withParquetTable((1 to 10).map(i => (i, s"val_$i")), "t") { withTempPath { file => sql("SELECT * FROM t LIMIT 1").write.parquet(file.getCanonicalPath) - read.parquet(file.getCanonicalPath).registerTempTable("p") + hiveContext.read.parquet(file.getCanonicalPath).registerTempTable("p") withTempTable("p") { // let's do three overwrites for good measure sql("INSERT OVERWRITE TABLE p SELECT * FROM t") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala index f765395e148a..a330362b4e1d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala @@ -17,22 +17,17 @@ package org.apache.spark.sql.hive -import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde.serdeConstants +import org.apache.spark.sql.catalyst.expressions.JsonTuple +import org.apache.spark.sql.catalyst.plans.logical.Generate +import org.scalatest.BeforeAndAfterAll + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.hive.client.{ManagedTable, HiveColumn, ExternalTable, HiveTable} -import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.hive.client.{ExternalTable, HiveColumn, 
HiveTable, ManagedTable} class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { - override def beforeAll() { - if (SessionState.get() == null) { - SessionState.start(new HiveConf()) - } - } - private def extractTableDesc(sql: String): (HiveTable, Boolean) = { HiveQl.createPlan(sql).collect { case CreateTableAsSelect(desc, child, allowExisting) => (desc, allowExisting) @@ -175,4 +170,30 @@ class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { assert(desc.serde == Option("org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe")) assert(desc.properties == Map(("tbl_p1" -> "p11"), ("tbl_p2" -> "p22"))) } + + test("Invalid interval term should throw AnalysisException") { + def assertError(sql: String, errorMessage: String): Unit = { + val e = intercept[AnalysisException] { + HiveQl.parseSql(sql) + } + assert(e.getMessage.contains(errorMessage)) + } + assertError("select interval '42-32' year to month", + "month 32 outside range [0, 11]") + assertError("select interval '5 49:12:15' day to second", + "hour 49 outside range [0, 23]") + assertError("select interval '.1111111111' second", + "nanosecond 1111111111 outside range") + } + + test("use native json_tuple instead of hive's UDTF in LATERAL VIEW") { + val plan = HiveQl.parseSql( + """ + |SELECT * + |FROM (SELECT '{"f1": "value1", "f2": 12}' json) test + |LATERAL VIEW json_tuple(json, 'f1', 'f2') jt AS a, b + """.stripMargin) + + assert(plan.children.head.asInstanceOf[Generate].generator.isInstanceOf[JsonTuple]) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 72b35959a491..53185fd7751e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -18,15 +18,23 @@ package org.apache.spark.sql.hive import java.io.File +import java.sql.Timestamp +import java.util.Date -import scala.sys.process.{ProcessLogger, Process} +import scala.collection.mutable.ArrayBuffer + +import org.scalatest.{BeforeAndAfterEach, Matchers} +import org.scalatest.concurrent.Timeouts +import org.scalatest.exceptions.TestFailedDueToTimeoutException +import org.scalatest.time.SpanSugar._ import org.apache.spark._ +import org.apache.spark.sql.{SQLContext, QueryTest} +import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} +import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer +import org.apache.spark.sql.types.DecimalType import org.apache.spark.util.{ResetSystemProperties, Utils} -import org.scalatest.Matchers -import org.scalatest.concurrent.Timeouts -import org.scalatest.time.SpanSugar._ /** * This suite tests spark-submit with applications using HiveContext. 
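The interval tests above only exercise the failing literals; values inside the ranges quoted in the error messages are expected to parse cleanly. A small hedged sketch for contrast (literals are illustrative, assuming the same HiveQl parser):

    // Months must fall in [0, 11] and hours in [0, 23], so these should parse without error.
    HiveQl.parseSql("select interval '42-11' year to month")
    HiveQl.parseSql("select interval '5 23:12:15' day to second")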
@@ -34,12 +42,14 @@ import org.scalatest.time.SpanSugar._ class HiveSparkSubmitSuite extends SparkFunSuite with Matchers + with BeforeAndAfterEach with ResetSystemProperties with Timeouts { // TODO: rewrite these or mark them as slow tests to be run sparingly - def beforeAll() { + override def beforeEach() { + super.beforeEach() System.setProperty("spark.testing", "true") } @@ -47,13 +57,16 @@ class HiveSparkSubmitSuite val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) val jar2 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassB")) - val jar3 = TestHive.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath() - val jar4 = TestHive.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath() + val jar3 = TestHive.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath + val jar4 = TestHive.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath val jarsString = Seq(jar1, jar2, jar3, jar4).map(j => j.toString).mkString(",") val args = Seq( "--class", SparkSubmitClassLoaderTest.getClass.getName.stripSuffix("$"), "--name", "SparkSubmitClassLoaderTest", "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--driver-java-options", "-Dderby.system.durability=test", "--jars", jarsString, unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB") runSparkSubmit(args) @@ -65,6 +78,11 @@ class HiveSparkSubmitSuite "--class", SparkSQLConfTest.getClass.getName.stripSuffix("$"), "--name", "SparkSQLConfTest", "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--conf", "spark.sql.hive.metastore.version=0.12", + "--conf", "spark.sql.hive.metastore.jars=maven", + "--driver-java-options", "-Dderby.system.durability=test", unusedJar.toString) runSparkSubmit(args) } @@ -76,7 +94,38 @@ class HiveSparkSubmitSuite // the HiveContext code mistakenly overrides the class loader that contains user classes. // For more detail, see sql/hive/src/test/resources/regression-test-SPARK-8489/*scala. 
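Every spark-submit based test in this suite now passes the same hardening flags; a minimal sketch of the shared pattern, with an illustrative main class name:

    val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
    val args = Seq(
      "--class", "MyRegressionTest",                             // illustrative main object
      "--name", "MyRegressionTest",
      "--master", "local-cluster[2,1,1024]",
      "--conf", "spark.ui.enabled=false",                        // no web UI needed in tests
      "--conf", "spark.master.rest.enabled=false",
      "--driver-java-options", "-Dderby.system.durability=test", // relax Derby durability for speed
      unusedJar.toString)
    runSparkSubmit(args)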
val testJar = "sql/hive/src/test/resources/regression-test-SPARK-8489/test.jar" - val args = Seq("--class", "Main", testJar) + val args = Seq( + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--driver-java-options", "-Dderby.system.durability=test", + "--class", "Main", + testJar) + runSparkSubmit(args) + } + + ignore("SPARK-9757 Persist Parquet relation with decimal column") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val args = Seq( + "--class", SPARK_9757.getClass.getName.stripSuffix("$"), + "--name", "SparkSQLConfTest", + "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--driver-java-options", "-Dderby.system.durability=test", + unusedJar.toString) + runSparkSubmit(args) + } + + test("SPARK-11009 fix wrong result of Window function in cluster mode") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val args = Seq( + "--class", SPARK_11009.getClass.getName.stripSuffix("$"), + "--name", "SparkSQLConfTest", + "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--driver-java-options", "-Dderby.system.durability=test", + unusedJar.toString) runSparkSubmit(args) } @@ -84,23 +133,55 @@ class HiveSparkSubmitSuite // This is copied from org.apache.spark.deploy.SparkSubmitSuite private def runSparkSubmit(args: Seq[String]): Unit = { val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) - val process = Process( - Seq("./bin/spark-submit") ++ args, - new File(sparkHome), - "SPARK_TESTING" -> "1", - "SPARK_HOME" -> sparkHome - ).run(ProcessLogger( + val history = ArrayBuffer.empty[String] + val commands = Seq("./bin/spark-submit") ++ args + val commandLine = commands.mkString("'", "' '", "'") + + val builder = new ProcessBuilder(commands: _*).directory(new File(sparkHome)) + val env = builder.environment() + env.put("SPARK_TESTING", "1") + env.put("SPARK_HOME", sparkHome) + + def captureOutput(source: String)(line: String): Unit = { + // This test suite has some weird behaviors when executed on Jenkins: + // + // 1. Sometimes it gets extremely slow out of unknown reason on Jenkins. Here we add a + // timestamp to provide more diagnosis information. + // 2. Log lines are not correctly redirected to unit-tests.log as expected, so here we print + // them out for debugging purposes. + val logLine = s"${new Timestamp(new Date().getTime)} - $source> $line" // scalastyle:off println - (line: String) => { println(s"out> $line") }, - (line: String) => { println(s"err> $line") } + println(logLine) // scalastyle:on println - )) + history += logLine + } + + val process = builder.start() + new ProcessOutputCapturer(process.getInputStream, captureOutput("stdout")).start() + new ProcessOutputCapturer(process.getErrorStream, captureOutput("stderr")).start() try { - val exitCode = failAfter(180 seconds) { process.exitValue() } + val exitCode = failAfter(300.seconds) { process.waitFor() } if (exitCode != 0) { - fail(s"Process returned with exit code $exitCode. See the log4j logs for more detail.") + // include logs in output. Note that logging is async and may not have completed + // at the time this exception is raised + Thread.sleep(1000) + val historyLog = history.mkString("\n") + fail { + s"""spark-submit returned with exit code $exitCode. 
+ |Command line: $commandLine + | + |$historyLog + """.stripMargin + } } + } catch { + case to: TestFailedDueToTimeoutException => + val historyLog = history.mkString("\n") + fail(s"Timeout of $commandLine" + + s" See the log4j logs for more detail." + + s"\n$historyLog", to) + case t: Throwable => throw t } finally { // Ensure we still kill the process in case it timed out process.destroy() @@ -114,6 +195,7 @@ object SparkSubmitClassLoaderTest extends Logging { def main(args: Array[String]) { Utils.configTestLog4j("INFO") val conf = new SparkConf() + conf.set("spark.ui.enabled", "false") val sc = new SparkContext(conf) val hiveContext = new TestHiveContext(sc) val df = hiveContext.createDataFrame((1 to 100).map(i => (i, i))).toDF("i", "j") @@ -186,7 +268,7 @@ object SparkSQLConfTest extends Logging { // before spark.sql.hive.metastore.jars get set, we will see the following exception: // Exception in thread "main" java.lang.IllegalArgumentException: Builtin jars can only // be used when hive execution version == hive metastore version. - // Execution: 0.13.1 != Metastore: 0.12. Specify a vaild path to the correct hive jars + // Execution: 0.13.1 != Metastore: 0.12. Specify a valid path to the correct hive jars // using $HIVE_METASTORE_JARS or change spark.sql.hive.metastore.version to 0.13.1. val conf = new SparkConf() { override def getAll: Array[(String, String)] = { @@ -205,6 +287,7 @@ object SparkSQLConfTest extends Logging { // For this simple test, we do not really clone this object. override def clone: SparkConf = this } + conf.set("spark.ui.enabled", "false") val sc = new SparkContext(conf) val hiveContext = new TestHiveContext(sc) // Run a simple command to make sure all lazy vals in hiveContext get instantiated. @@ -212,3 +295,80 @@ object SparkSQLConfTest extends Logging { sc.stop() } } + +object SPARK_9757 extends QueryTest { + import org.apache.spark.sql.functions._ + + protected var sqlContext: SQLContext = _ + + def main(args: Array[String]): Unit = { + Utils.configTestLog4j("INFO") + + val sparkContext = new SparkContext( + new SparkConf() + .set("spark.sql.hive.metastore.version", "0.13.1") + .set("spark.sql.hive.metastore.jars", "maven") + .set("spark.ui.enabled", "false")) + + val hiveContext = new TestHiveContext(sparkContext) + sqlContext = hiveContext + import hiveContext.implicits._ + + val dir = Utils.createTempDir() + dir.delete() + + try { + { + val df = + hiveContext + .range(10) + .select(('id + 0.1) cast DecimalType(10, 3) as 'dec) + df.write.option("path", dir.getCanonicalPath).mode("overwrite").saveAsTable("t") + checkAnswer(hiveContext.table("t"), df) + } + + { + val df = + hiveContext + .range(10) + .select(callUDF("struct", ('id + 0.2) cast DecimalType(10, 3)) as 'dec_struct) + df.write.option("path", dir.getCanonicalPath).mode("overwrite").saveAsTable("t") + checkAnswer(hiveContext.table("t"), df) + } + } finally { + dir.delete() + hiveContext.sql("DROP TABLE t") + sparkContext.stop() + } + } +} + +object SPARK_11009 extends QueryTest { + import org.apache.spark.sql.functions._ + + protected var sqlContext: SQLContext = _ + + def main(args: Array[String]): Unit = { + Utils.configTestLog4j("INFO") + + val sparkContext = new SparkContext( + new SparkConf() + .set("spark.ui.enabled", "false") + .set("spark.sql.shuffle.partitions", "100")) + + val hiveContext = new TestHiveContext(sparkContext) + sqlContext = hiveContext + + try { + val df = sqlContext.range(1 << 20) + val df2 = df.select((df("id") % 1000).alias("A"), (df("id") / 1000).alias("B")) + val ws = 
Window.partitionBy(df2("A")).orderBy(df2("B")) + val df3 = df2.select(df2("A"), df2("B"), rowNumber().over(ws).alias("rn")).filter("rn < 0") + if (df3.rdd.count() != 0) { + throw new Exception("df3 should have 0 output row.") + } + } finally { + sparkContext.stop() + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 508695919e9a..81ee9ba71beb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -19,32 +19,30 @@ package org.apache.spark.sql.hive import java.io.File +import org.apache.hadoop.hive.conf.HiveConf import org.scalatest.BeforeAndAfter import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.{QueryTest, _} -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -/* Implicits */ -import org.apache.spark.sql.hive.test.TestHive._ - case class TestData(key: Int, value: String) case class ThreeCloumntable(key: Int, value: String, key1: String) -class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { - import org.apache.spark.sql.hive.test.TestHive.implicits._ - +class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter { + import hiveContext.implicits._ + import hiveContext.sql - val testData = TestHive.sparkContext.parallelize( + val testData = hiveContext.sparkContext.parallelize( (1 to 100).map(i => TestData(i, i.toString))).toDF() before { // Since every we are doing tests for DDL statements, // it is better to reset before every test. - TestHive.reset() + hiveContext.reset() // Register the testData, which will be used in every test. 
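SPARK_11009 above relies on window row numbers being 1-based; a hedged sketch of the same check against a DataFrame df2 with columns A and B (1.5-era API, where the function is still named rowNumber):

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.rowNumber

    val ws = Window.partitionBy(df2("A")).orderBy(df2("B"))
    // A correct row number is never negative, so this filter must stay empty;
    // any surviving row points at the cluster-mode regression SPARK-11009 guards against.
    val bad = df2.select(df2("A"), df2("B"), rowNumber().over(ws).alias("rn")).filter("rn < 0")
    assert(bad.rdd.count() === 0)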
testData.registerTempTable("testData") } @@ -83,9 +81,9 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { test("Double create fails when allowExisting = false") { sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") - val message = intercept[QueryExecutionException] { + intercept[QueryExecutionException] { sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") - }.getMessage + } } test("Double create does not fail when allowExisting = true") { @@ -95,9 +93,9 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { test("SPARK-4052: scala.collection.Map as value type of MapType") { val schema = StructType(StructField("m", MapType(StringType, StringType), true) :: Nil) - val rowRDD = TestHive.sparkContext.parallelize( + val rowRDD = hiveContext.sparkContext.parallelize( (1 to 100).map(i => Row(scala.collection.mutable.HashMap(s"key$i" -> s"value$i")))) - val df = TestHive.createDataFrame(rowRDD, schema) + val df = hiveContext.createDataFrame(rowRDD, schema) df.registerTempTable("tableWithMapValue") sql("CREATE TABLE hiveTableWithMapValue(m MAP )") sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue") @@ -113,6 +111,8 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { test("SPARK-4203:random partition directory order") { sql("CREATE TABLE tmp_table (key int, value string)") val tmpDir = Utils.createTempDir() + val stagingDir = new HiveConf().getVar(HiveConf.ConfVars.STAGINGDIR) + sql( s""" |CREATE TABLE table_with_partition(c1 string) @@ -145,7 +145,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { """.stripMargin) def listFolders(path: File, acc: List[String]): List[List[String]] = { val dir = path.listFiles() - val folders = dir.filter(_.isDirectory).toList + val folders = dir.filter { e => e.isDirectory && !e.getName().startsWith(stagingDir) }.toList if (folders.isEmpty) { List(acc.reverse) } else { @@ -158,7 +158,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=1"::Nil , "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=4"::Nil ) - assert(listFolders(tmpDir, List()).sortBy(_.toString()) == expected.sortBy(_.toString)) + assert(listFolders(tmpDir, List()).sortBy(_.toString()) === expected.sortBy(_.toString)) sql("DROP TABLE table_with_partition") sql("DROP TABLE tmp_table") } @@ -166,8 +166,8 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { test("Insert ArrayType.containsNull == false") { val schema = StructType(Seq( StructField("a", ArrayType(StringType, containsNull = false)))) - val rowRDD = TestHive.sparkContext.parallelize((1 to 100).map(i => Row(Seq(s"value$i")))) - val df = TestHive.createDataFrame(rowRDD, schema) + val rowRDD = hiveContext.sparkContext.parallelize((1 to 100).map(i => Row(Seq(s"value$i")))) + val df = hiveContext.createDataFrame(rowRDD, schema) df.registerTempTable("tableWithArrayValue") sql("CREATE TABLE hiveTableWithArrayValue(a Array )") sql("INSERT OVERWRITE TABLE hiveTableWithArrayValue SELECT a FROM tableWithArrayValue") @@ -182,9 +182,9 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { test("Insert MapType.valueContainsNull == false") { val schema = StructType(Seq( StructField("m", MapType(StringType, StringType, valueContainsNull = false)))) - val rowRDD = TestHive.sparkContext.parallelize( + val rowRDD = hiveContext.sparkContext.parallelize( (1 to 100).map(i => Row(Map(s"key$i" -> s"value$i")))) - val df = 
TestHive.createDataFrame(rowRDD, schema) + val df = hiveContext.createDataFrame(rowRDD, schema) df.registerTempTable("tableWithMapValue") sql("CREATE TABLE hiveTableWithMapValue(m Map )") sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue") @@ -199,9 +199,9 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { test("Insert StructType.fields.exists(_.nullable == false)") { val schema = StructType(Seq( StructField("s", StructType(Seq(StructField("f", StringType, nullable = false)))))) - val rowRDD = TestHive.sparkContext.parallelize( + val rowRDD = hiveContext.sparkContext.parallelize( (1 to 100).map(i => Row(Row(s"value$i")))) - val df = TestHive.createDataFrame(rowRDD, schema) + val df = hiveContext.createDataFrame(rowRDD, schema) df.registerTempTable("tableWithStructValue") sql("CREATE TABLE hiveTableWithStructValue(s Struct )") sql("INSERT OVERWRITE TABLE hiveTableWithStructValue SELECT s FROM tableWithStructValue") @@ -214,11 +214,11 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { } test("SPARK-5498:partition schema does not match table schema") { - val testData = TestHive.sparkContext.parallelize( + val testData = hiveContext.sparkContext.parallelize( (1 to 10).map(i => TestData(i, i.toString))).toDF() testData.registerTempTable("testData") - val testDatawithNull = TestHive.sparkContext.parallelize( + val testDatawithNull = hiveContext.sparkContext.parallelize( (1 to 10).map(i => ThreeCloumntable(i, i.toString, null))).toDF() val tmpDir = Utils.createTempDir() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala index 1c15997ea8e6..183aca29cf98 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala @@ -19,30 +19,27 @@ package org.apache.spark.sql.hive import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.QueryTest import org.apache.spark.sql.Row -class ListTablesSuite extends QueryTest with BeforeAndAfterAll { +class ListTablesSuite extends QueryTest with TestHiveSingleton with BeforeAndAfterAll { + import hiveContext._ + import hiveContext.implicits._ - import org.apache.spark.sql.hive.test.TestHive.implicits._ - - val df = - sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDF("key", "value") + val df = sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDF("key", "value") override def beforeAll(): Unit = { // The catalog in HiveContext is a case insensitive one. 
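The registration calls just below move from Seq[String] name parts to TableIdentifier; a minimal sketch of the migrated catalog API (table name illustrative, default database assumed):

    import org.apache.spark.sql.catalyst.TableIdentifier

    val ident = TableIdentifier("listtablessuitetable")  // database component left unset
    hiveContext.catalog.registerTable(ident, df.logicalPlan)
    // ... assertions against sqlContext.tables() go here ...
    hiveContext.catalog.unregisterTable(ident)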
- catalog.registerTable(Seq("ListTablesSuiteTable"), df.logicalPlan) - catalog.registerTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable"), df.logicalPlan) + catalog.registerTable(TableIdentifier("ListTablesSuiteTable"), df.logicalPlan) sql("CREATE TABLE HiveListTablesSuiteTable (key int, value string)") sql("CREATE DATABASE IF NOT EXISTS ListTablesSuiteDB") sql("CREATE TABLE ListTablesSuiteDB.HiveInDBListTablesSuiteTable (key int, value string)") } override def afterAll(): Unit = { - catalog.unregisterTable(Seq("ListTablesSuiteTable")) - catalog.unregisterTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable")) + catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable")) sql("DROP TABLE IF EXISTS HiveListTablesSuiteTable") sql("DROP TABLE IF EXISTS ListTablesSuiteDB.HiveInDBListTablesSuiteTable") sql("DROP DATABASE IF EXISTS ListTablesSuiteDB") @@ -55,7 +52,6 @@ class ListTablesSuite extends QueryTest with BeforeAndAfterAll { checkAnswer( allTables.filter("tableName = 'listtablessuitetable'"), Row("listtablessuitetable", true)) - assert(allTables.filter("tableName = 'indblisttablessuitetable'").count() === 0) checkAnswer( allTables.filter("tableName = 'hivelisttablessuitetable'"), Row("hivelisttablessuitetable", false)) @@ -69,9 +65,6 @@ class ListTablesSuite extends QueryTest with BeforeAndAfterAll { checkAnswer( allTables.filter("tableName = 'listtablessuitetable'"), Row("listtablessuitetable", true)) - checkAnswer( - allTables.filter("tableName = 'indblisttablessuitetable'"), - Row("indblisttablessuitetable", true)) assert(allTables.filter("tableName = 'hivelisttablessuitetable'").count() === 0) checkAnswer( allTables.filter("tableName = 'hiveindblisttablessuitetable'"), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 4fdf774ead75..f74eb1500b98 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -17,32 +17,28 @@ package org.apache.spark.sql.hive -import java.io.File +import java.io.{IOException, File} import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.fs.Path -import org.apache.hadoop.mapred.InvalidInputException -import org.scalatest.BeforeAndAfterAll -import org.apache.spark.Logging import org.apache.spark.sql._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.hive.client.{HiveTable, ManagedTable} -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.util.Utils /** * Tests for persisting tables created though the data sources API into the metastore. 
*/ -class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll - with Logging { - override val sqlContext = TestHive +class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + import hiveContext._ + import hiveContext.implicits._ var jsonFilePath: String = _ @@ -372,7 +368,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA |) """.stripMargin) - val expectedPath = catalog.hiveDefaultTableFilePath("ctasJsonTable") + val expectedPath = catalog.hiveDefaultTableFilePath(TableIdentifier("ctasJsonTable")) val filesystemPath = new Path(expectedPath) val fs = filesystemPath.getFileSystem(sparkContext.hadoopConfiguration) if (fs.exists(filesystemPath)) fs.delete(filesystemPath, true) @@ -463,24 +459,21 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA checkAnswer(sql("SELECT * FROM savedJsonTable"), df) - // Right now, we cannot append to an existing JSON table. - intercept[RuntimeException] { - df.write.mode(SaveMode.Append).saveAsTable("savedJsonTable") - } - // We can overwrite it. df.write.mode(SaveMode.Overwrite).saveAsTable("savedJsonTable") checkAnswer(sql("SELECT * FROM savedJsonTable"), df) // When the save mode is Ignore, we will do nothing when the table already exists. df.select("b").write.mode(SaveMode.Ignore).saveAsTable("savedJsonTable") - assert(df.schema === table("savedJsonTable").schema) + // TODO in ResolvedDataSource, will convert the schema into nullable = true + // hence the df.schema is not exactly the same as table("savedJsonTable").schema + // assert(df.schema === table("savedJsonTable").schema) checkAnswer(sql("SELECT * FROM savedJsonTable"), df) // Drop table will also delete the data. sql("DROP TABLE savedJsonTable") - intercept[InvalidInputException] { - read.json(catalog.hiveDefaultTableFilePath("savedJsonTable")) + intercept[IOException] { + read.json(catalog.hiveDefaultTableFilePath(TableIdentifier("savedJsonTable"))) } } @@ -555,7 +548,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA "org.apache.spark.sql.json", schema, Map.empty[String, String]) - }.getMessage.contains("'path' must be specified for json data."), + }.getMessage.contains("key not found: path"), "We should complain that path is not specified.") } } @@ -578,7 +571,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA Row(3) :: Row(4) :: Nil) table("test_parquet_ctas").queryExecution.optimizedPlan match { - case LogicalRelation(p: ParquetRelation) => // OK + case LogicalRelation(p: ParquetRelation, _) => // OK case _ => fail(s"test_parquet_ctas should have be converted to ${classOf[ParquetRelation]}") } @@ -711,7 +704,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA // Manually create a metastore data source table. 
catalog.createDataSourceTable( - tableName = "wide_schema", + tableIdent = TableIdentifier("wide_schema"), userSpecifiedSchema = Some(schema), partitionColumns = Array.empty[String], provider = "json", @@ -741,7 +734,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA "EXTERNAL" -> "FALSE"), tableType = ManagedTable, serdeProperties = Map( - "path" -> catalog.hiveDefaultTableFilePath(tableName))) + "path" -> catalog.hiveDefaultTableFilePath(TableIdentifier(tableName)))) catalog.client.createTable(hiveTable) @@ -760,10 +753,15 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA invalidateTable(tableName) val metastoreTable = catalog.client.getTable("default", tableName) val expectedPartitionColumns = StructType(df.schema("d") :: df.schema("b") :: Nil) + + val numPartCols = metastoreTable.properties("spark.sql.sources.schema.numPartCols").toInt + assert(numPartCols == 2) + val actualPartitionColumns = StructType( - metastoreTable.partitionColumns.map(c => - StructField(c.name, HiveMetastoreTypes.toDataType(c.hiveType)))) + (0 until numPartCols).map { index => + df.schema(metastoreTable.properties(s"spark.sql.sources.schema.partCol.$index")) + }) // Make sure partition columns are correctly stored in metastore. assert( expectedPartitionColumns.sameType(actualPartitionColumns), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala index 73852f13ad20..f16c257ab5ab 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala @@ -17,16 +17,19 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.{QueryTest, SQLContext, SaveMode} +import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode} -class MultiDatabaseSuite extends QueryTest with SQLTestUtils { - override val sqlContext: SQLContext = TestHive +class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + private lazy val df = sqlContext.range(10).coalesce(1) - import sqlContext.sql + private def checkTablePath(dbName: String, tableName: String): Unit = { + val metastoreTable = hiveContext.catalog.client.getTable(dbName, tableName) + val expectedPath = hiveContext.catalog.client.getDatabase(dbName).location + "/" + tableName - private val df = sqlContext.range(10).coalesce(1) + assert(metastoreTable.serdeProperties("path") === expectedPath) + } test(s"saveAsTable() to non-default database - with USE - Overwrite") { withTempDatabase { db => @@ -38,6 +41,8 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils { assert(sqlContext.tableNames(db).contains("t")) checkAnswer(sqlContext.table(s"$db.t"), df) + + checkTablePath(db, "t") } } @@ -46,6 +51,58 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils { df.write.mode(SaveMode.Overwrite).saveAsTable(s"$db.t") assert(sqlContext.tableNames(db).contains("t")) checkAnswer(sqlContext.table(s"$db.t"), df) + + checkTablePath(db, "t") + } + } + + test(s"createExternalTable() to non-default database - with USE") { + withTempDatabase { db => + activateDatabase(db) { + withTempPath { dir => + val path = dir.getCanonicalPath + df.write.format("parquet").mode(SaveMode.Overwrite).save(path) + + 
sqlContext.createExternalTable("t", path, "parquet") + assert(sqlContext.tableNames(db).contains("t")) + checkAnswer(sqlContext.table("t"), df) + + sql( + s""" + |CREATE TABLE t1 + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + assert(sqlContext.tableNames(db).contains("t1")) + checkAnswer(sqlContext.table("t1"), df) + } + } + } + } + + test(s"createExternalTable() to non-default database - without USE") { + withTempDatabase { db => + withTempPath { dir => + val path = dir.getCanonicalPath + df.write.format("parquet").mode(SaveMode.Overwrite).save(path) + sqlContext.createExternalTable(s"$db.t", path, "parquet") + + assert(sqlContext.tableNames(db).contains("t")) + checkAnswer(sqlContext.table(s"$db.t"), df) + + sql( + s""" + |CREATE TABLE $db.t1 + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + assert(sqlContext.tableNames(db).contains("t1")) + checkAnswer(sqlContext.table(s"$db.t1"), df) + } } } @@ -60,6 +117,8 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils { assert(sqlContext.tableNames(db).contains("t")) checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df)) + + checkTablePath(db, "t") } } @@ -69,6 +128,8 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils { df.write.mode(SaveMode.Append).saveAsTable(s"$db.t") assert(sqlContext.tableNames(db).contains("t")) checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df)) + + checkTablePath(db, "t") } } @@ -131,7 +192,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils { } } - test("Refreshes a table in a non-default database") { + test("Refreshes a table in a non-default database - with USE") { import org.apache.spark.sql.functions.lit withTempDatabase { db => @@ -152,8 +213,94 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils { sql("ALTER TABLE t ADD PARTITION (p=1)") sql("REFRESH TABLE t") checkAnswer(sqlContext.table("t"), df.withColumn("p", lit(1))) + + df.write.parquet(s"$path/p=2") + sql("ALTER TABLE t ADD PARTITION (p=2)") + hiveContext.refreshTable("t") + checkAnswer( + sqlContext.table("t"), + df.withColumn("p", lit(1)).unionAll(df.withColumn("p", lit(2)))) } } } } + + test("Refreshes a table in a non-default database - without USE") { + import org.apache.spark.sql.functions.lit + + withTempDatabase { db => + withTempPath { dir => + val path = dir.getCanonicalPath + + sql( + s"""CREATE EXTERNAL TABLE $db.t (id BIGINT) + |PARTITIONED BY (p INT) + |STORED AS PARQUET + |LOCATION '$path' + """.stripMargin) + + checkAnswer(sqlContext.table(s"$db.t"), sqlContext.emptyDataFrame) + + df.write.parquet(s"$path/p=1") + sql(s"ALTER TABLE $db.t ADD PARTITION (p=1)") + sql(s"REFRESH TABLE $db.t") + checkAnswer(sqlContext.table(s"$db.t"), df.withColumn("p", lit(1))) + + df.write.parquet(s"$path/p=2") + sql(s"ALTER TABLE $db.t ADD PARTITION (p=2)") + hiveContext.refreshTable(s"$db.t") + checkAnswer( + sqlContext.table(s"$db.t"), + df.withColumn("p", lit(1)).unionAll(df.withColumn("p", lit(2)))) + } + } + } + + test("invalid database name and table names") { + { + val message = intercept[AnalysisException] { + df.write.format("parquet").saveAsTable("`d:b`.`t:a`") + }.getMessage + assert(message.contains("is not a valid name for metastore")) + } + + { + val message = intercept[AnalysisException] { + df.write.format("parquet").saveAsTable("`d:b`.`table`") + }.getMessage + assert(message.contains("is not a valid name for metastore")) + } + + withTempPath { dir => + val path = dir.getCanonicalPath + + { + val message = intercept[AnalysisException] { + 
sql( + s""" + |CREATE TABLE `d:b`.`t:a` (a int) + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + }.getMessage + assert(message.contains("is not a valid name for metastore")) + } + + { + val message = intercept[AnalysisException] { + sql( + s""" + |CREATE TABLE `d:b`.`table` (a int) + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + }.getMessage + assert(message.contains("is not a valid name for metastore")) + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala index bb5f1febe9ad..49aab85cf1aa 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala @@ -17,76 +17,123 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.parquet.ParquetCompatibilityTest -import org.apache.spark.sql.{Row, SQLConf, SQLContext} - -class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest { - import ParquetCompatibilityTest.makeNullable - - override val sqlContext: SQLContext = TestHive - - override protected def beforeAll(): Unit = { - super.beforeAll() - - withSQLConf(HiveContext.CONVERT_METASTORE_PARQUET.key -> "false") { - withTempTable("data") { - sqlContext.sql( - s"""CREATE TABLE parquet_compat( - | bool_column BOOLEAN, - | byte_column TINYINT, - | short_column SMALLINT, - | int_column INT, - | long_column BIGINT, - | float_column FLOAT, - | double_column DOUBLE, - | - | strings_column ARRAY, - | int_to_string_column MAP - |) - |STORED AS PARQUET - |LOCATION '${parquetStore.getCanonicalPath}' - """.stripMargin) - - val schema = sqlContext.table("parquet_compat").schema - val rowRDD = sqlContext.sparkContext.parallelize(makeRows).coalesce(1) - sqlContext.createDataFrame(rowRDD, schema).registerTempTable("data") - sqlContext.sql("INSERT INTO TABLE parquet_compat SELECT * FROM data") +import java.sql.Timestamp + +import org.apache.hadoop.hive.conf.HiveConf + +import org.apache.spark.sql.execution.datasources.parquet.ParquetCompatibilityTest +import org.apache.spark.sql.{Row, SQLConf} +import org.apache.spark.sql.hive.test.TestHiveSingleton + +class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest with TestHiveSingleton { + /** + * Set the staging directory (and hence path to ignore Parquet files under) + * to that set by [[HiveConf.ConfVars.STAGINGDIR]]. + */ + private val stagingDir = new HiveConf().getVar(HiveConf.ConfVars.STAGINGDIR) + + override protected def logParquetSchema(path: String): Unit = { + val schema = readParquetSchema(path, { path => + !path.getName.startsWith("_") && !path.getName.startsWith(stagingDir) + }) + + logInfo( + s"""Schema of the Parquet file written by parquet-avro: + |$schema + """.stripMargin) + } + + private def testParquetHiveCompatibility(row: Row, hiveTypes: String*): Unit = { + withTable("parquet_compat") { + withTempPath { dir => + val path = dir.getCanonicalPath + + // Hive columns are always nullable, so here we append a all-null row. + val rows = row :: Row(Seq.fill(row.length)(null): _*) :: Nil + + // Don't convert Hive metastore Parquet tables to let Hive write those Parquet files. 
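The suites in this hunk mix SQLTestUtils into TestHiveSingleton-based tests, which is what provides the scoped helpers used throughout; a minimal sketch of how they nest (table name, path option and config key are illustrative):

    withTempPath { dir =>                    // scratch directory, removed afterwards
      withTable("t") {                       // DROP TABLE t afterwards
        withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "true") {  // conf restored afterwards
          testDF.write
            .mode(SaveMode.Overwrite)
            .format("parquet")
            .option("path", dir.getCanonicalPath)
            .saveAsTable("t")
        }
        checkAnswer(table("t"), testDF)
      }
    }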
+ withSQLConf(HiveContext.CONVERT_METASTORE_PARQUET.key -> "false") { + withTempTable("data") { + val fields = hiveTypes.zipWithIndex.map { case (typ, index) => s" col_$index $typ" } + + val ddl = + s"""CREATE TABLE parquet_compat( + |${fields.mkString(",\n")} + |) + |STORED AS PARQUET + |LOCATION '$path' + """.stripMargin + + logInfo( + s"""Creating testing Parquet table with the following DDL: + |$ddl + """.stripMargin) + + sqlContext.sql(ddl) + + val schema = sqlContext.table("parquet_compat").schema + val rowRDD = sqlContext.sparkContext.parallelize(rows).coalesce(1) + sqlContext.createDataFrame(rowRDD, schema).registerTempTable("data") + sqlContext.sql("INSERT INTO TABLE parquet_compat SELECT * FROM data") + } + } + + logParquetSchema(path) + + // Unfortunately parquet-hive doesn't add `UTF8` annotation to BINARY when writing strings. + // Have to assume all BINARY values are strings here. + withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "true") { + checkAnswer(sqlContext.read.parquet(path), rows) + } } } } - override protected def afterAll(): Unit = { - sqlContext.sql("DROP TABLE parquet_compat") + test("simple primitives") { + testParquetHiveCompatibility( + Row(true, 1.toByte, 2.toShort, 3, 4.toLong, 5.1f, 6.1d, "foo"), + "BOOLEAN", "TINYINT", "SMALLINT", "INT", "BIGINT", "FLOAT", "DOUBLE", "STRING") } - test("Read Parquet file generated by parquet-hive") { - logInfo( - s"""Schema of the Parquet file written by parquet-hive: - |${readParquetSchema(parquetStore.getCanonicalPath)} - """.stripMargin) + test("SPARK-10177 timestamp") { + testParquetHiveCompatibility(Row(Timestamp.valueOf("2015-08-24 00:31:00")), "TIMESTAMP") + } - // Unfortunately parquet-hive doesn't add `UTF8` annotation to BINARY when writing strings. - // Have to assume all BINARY values are strings here. 
- withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "true") { - checkAnswer(sqlContext.read.parquet(parquetStore.getCanonicalPath), makeRows) - } + test("array") { + testParquetHiveCompatibility( + Row( + Seq[Integer](1: Integer, null, 2: Integer, null), + Seq[String]("foo", null, "bar", null), + Seq[Seq[Integer]]( + Seq[Integer](1: Integer, null), + Seq[Integer](2: Integer, null))), + "ARRAY", + "ARRAY", + "ARRAY>") } - def makeRows: Seq[Row] = { - (0 until 10).map { i => - def nullable[T <: AnyRef]: ( => T) => T = makeNullable[T](i) + test("map") { + testParquetHiveCompatibility( + Row( + Map[Integer, String]( + (1: Integer) -> "foo", + (2: Integer) -> null)), + "MAP") + } + // HIVE-11625: Parquet map entries with null keys are dropped by Hive + ignore("map entries with null keys") { + testParquetHiveCompatibility( Row( - nullable(i % 2 == 0: java.lang.Boolean), - nullable(i.toByte: java.lang.Byte), - nullable((i + 1).toShort: java.lang.Short), - nullable(i + 2: Integer), - nullable(i.toLong * 10: java.lang.Long), - nullable(i.toFloat + 0.1f: java.lang.Float), - nullable(i.toDouble + 0.2d: java.lang.Double), - nullable(Seq.tabulate(3)(n => s"arr_${i + n}")), - nullable(Seq.tabulate(3)(n => (i + n: Integer) -> s"val_${i + n}").toMap)) - } + Map[Integer, String]( + null.asInstanceOf[Integer] -> "bar", + null.asInstanceOf[Integer] -> null)), + "MAP") + } + + test("struct") { + testParquetHiveCompatibility( + Row(Row(1, Seq("foo", "bar", null))), + "STRUCT>") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala index 017bc2adc103..f542a5a02508 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala @@ -19,49 +19,49 @@ package org.apache.spark.sql.hive import com.google.common.io.Files -import org.apache.spark.sql.{QueryTest, _} import org.apache.spark.util.Utils +import org.apache.spark.sql.{QueryTest, _} +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils +class QueryPartitionSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + import hiveContext.implicits._ -class QueryPartitionSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.hive.test.TestHive - import ctx.implicits._ - import ctx.sql - - test("SPARK-5068: query data when path doesn't exist"){ - val testData = ctx.sparkContext.parallelize( - (1 to 10).map(i => TestData(i, i.toString))).toDF() - testData.registerTempTable("testData") + test("SPARK-5068: query data when path doesn't exist") { + withSQLConf((SQLConf.HIVE_VERIFY_PARTITION_PATH.key, "true")) { + val testData = sparkContext.parallelize( + (1 to 10).map(i => TestData(i, i.toString))).toDF() + testData.registerTempTable("testData") - val tmpDir = Files.createTempDir() - // create the table for test - sql(s"CREATE TABLE table_with_partition(key int,value string) " + - s"PARTITIONED by (ds string) location '${tmpDir.toURI.toString}' ") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') " + - "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='2') " + - "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='3') " + - "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='4') " + - "SELECT key,value FROM 
testData") + val tmpDir = Files.createTempDir() + // create the table for test + sql(s"CREATE TABLE table_with_partition(key int,value string) " + + s"PARTITIONED by (ds string) location '${tmpDir.toURI.toString}' ") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='2') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='3') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='4') " + + "SELECT key,value FROM testData") - // test for the exist path - checkAnswer(sql("select key,value from table_with_partition"), - testData.toDF.collect ++ testData.toDF.collect - ++ testData.toDF.collect ++ testData.toDF.collect) + // test for the exist path + checkAnswer(sql("select key,value from table_with_partition"), + testData.toDF.collect ++ testData.toDF.collect + ++ testData.toDF.collect ++ testData.toDF.collect) - // delete the path of one partition - tmpDir.listFiles - .find { f => f.isDirectory && f.getName().startsWith("ds=") } - .foreach { f => Utils.deleteRecursively(f) } + // delete the path of one partition + tmpDir.listFiles + .find { f => f.isDirectory && f.getName().startsWith("ds=") } + .foreach { f => Utils.deleteRecursively(f) } - // test for after delete the path - checkAnswer(sql("select key,value from table_with_partition"), - testData.toDF.collect ++ testData.toDF.collect ++ testData.toDF.collect) + // test for after delete the path + checkAnswer(sql("select key,value from table_with_partition"), + testData.toDF.collect ++ testData.toDF.collect ++ testData.toDF.collect) - sql("DROP TABLE table_with_partition") - sql("DROP TABLE createAndInsertTest") + sql("DROP TABLE table_with_partition") + sql("DROP TABLE createAndInsertTest") + } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index bc72b0172a46..f775f1e95587 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -17,24 +17,16 @@ package org.apache.spark.sql.hive -import org.scalatest.BeforeAndAfterAll - import scala.reflect.ClassTag import org.apache.spark.sql.{Row, SQLConf, QueryTest} +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.execution._ +import org.apache.spark.sql.hive.test.TestHiveSingleton -class StatisticsSuite extends QueryTest with BeforeAndAfterAll { - - private lazy val ctx: HiveContext = { - val ctx = org.apache.spark.sql.hive.test.TestHive - ctx.reset() - ctx.cacheTables = false - ctx - } - - import ctx.sql +class StatisticsSuite extends QueryTest with TestHiveSingleton { + import hiveContext.sql test("parse analyze commands") { def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) { @@ -77,7 +69,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { test("analyze MetastoreRelations") { def queryTotalSize(tableName: String): BigInt = - ctx.catalog.lookupRelation(Seq(tableName)).statistics.sizeInBytes + hiveContext.catalog.lookupRelation(TableIdentifier(tableName)).statistics.sizeInBytes // Non-partitioned table sql("CREATE TABLE analyzeTable (key STRING, value STRING)").collect() @@ -111,7 +103,7 @@ class StatisticsSuite extends QueryTest with 
BeforeAndAfterAll { |SELECT * FROM src """.stripMargin).collect() - assert(queryTotalSize("analyzeTable_part") === ctx.conf.defaultSizeInBytes) + assert(queryTotalSize("analyzeTable_part") === hiveContext.conf.defaultSizeInBytes) sql("ANALYZE TABLE analyzeTable_part COMPUTE STATISTICS noscan") @@ -122,9 +114,9 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { // Try to analyze a temp table sql("""SELECT * FROM src""").registerTempTable("tempTable") intercept[UnsupportedOperationException] { - ctx.analyze("tempTable") + hiveContext.analyze("tempTable") } - ctx.catalog.unregisterTable(Seq("tempTable")) + hiveContext.catalog.unregisterTable(TableIdentifier("tempTable")) } test("estimates the size of a test MetastoreRelation") { @@ -152,8 +144,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { val sizes = df.queryExecution.analyzed.collect { case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.statistics.sizeInBytes } - assert(sizes.size === 2 && sizes(0) <= ctx.conf.autoBroadcastJoinThreshold - && sizes(1) <= ctx.conf.autoBroadcastJoinThreshold, + assert(sizes.size === 2 && sizes(0) <= hiveContext.conf.autoBroadcastJoinThreshold + && sizes(1) <= hiveContext.conf.autoBroadcastJoinThreshold, s"query should contain two relations, each of which has size smaller than autoConvertSize") // Using `sparkPlan` because for relevant patterns in HashJoin to be @@ -164,8 +156,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { checkAnswer(df, expectedAnswer) // check correctness of output - ctx.conf.settings.synchronized { - val tmp = ctx.conf.autoBroadcastJoinThreshold + hiveContext.conf.settings.synchronized { + val tmp = hiveContext.conf.autoBroadcastJoinThreshold sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1""") df = sql(query) @@ -174,7 +166,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { val shj = df.queryExecution.sparkPlan.collect { case j: SortMergeJoin => j } assert(shj.size === 1, - "ShuffledHashJoin should be planned when BroadcastHashJoin is turned off") + "SortMergeJoin should be planned when BroadcastHashJoin is turned off") sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=$tmp""") } @@ -208,8 +200,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { .isAssignableFrom(r.getClass) => r.statistics.sizeInBytes } - assert(sizes.size === 2 && sizes(1) <= ctx.conf.autoBroadcastJoinThreshold - && sizes(0) <= ctx.conf.autoBroadcastJoinThreshold, + assert(sizes.size === 2 && sizes(1) <= hiveContext.conf.autoBroadcastJoinThreshold + && sizes(0) <= hiveContext.conf.autoBroadcastJoinThreshold, s"query should contain two relations, each of which has size smaller than autoConvertSize") // Using `sparkPlan` because for relevant patterns in HashJoin to be @@ -222,8 +214,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { checkAnswer(df, answer) // check correctness of output - ctx.conf.settings.synchronized { - val tmp = ctx.conf.autoBroadcastJoinThreshold + hiveContext.conf.settings.synchronized { + val tmp = hiveContext.conf.autoBroadcastJoinThreshold sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1") df = sql(leftSemiJoinQuery) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala index 9b3ede43ee2d..3ab457681119 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala @@ -17,21 
+17,19 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.{Row, QueryTest} +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.hive.test.TestHiveSingleton case class FunctionResult(f1: String, f2: String) -class UDFSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.hive.test.TestHive - import ctx.implicits._ +class UDFSuite extends QueryTest with TestHiveSingleton { test("UDF case insensitive") { - ctx.udf.register("random0", () => { Math.random() }) - ctx.udf.register("RANDOM1", () => { Math.random() }) - ctx.udf.register("strlenScala", (_: String).length + (_: Int)) - assert(ctx.sql("SELECT RANDOM0() FROM src LIMIT 1").head().getDouble(0) >= 0.0) - assert(ctx.sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0) - assert(ctx.sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5) + hiveContext.udf.register("random0", () => { Math.random() }) + hiveContext.udf.register("RANDOM1", () => { Math.random() }) + hiveContext.udf.register("strlenScala", (_: String).length + (_: Int)) + assert(hiveContext.sql("SELECT RANDOM0() FROM src LIMIT 1").head().getDouble(0) >= 0.0) + assert(hiveContext.sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0) + assert(hiveContext.sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala index 0efcf80bd4ea..5e7b93d45710 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive.client -import scala.collection.JavaConversions._ +import java.util.Collections import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.hadoop.hive.serde.serdeConstants @@ -38,7 +38,7 @@ class FiltersSuite extends SparkFunSuite with Logging { private val varCharCol = new FieldSchema() varCharCol.setName("varchar") varCharCol.setType(serdeConstants.VARCHAR_TYPE_NAME) - testTable.setPartCols(varCharCol :: Nil) + testTable.setPartCols(Collections.singletonList(varCharCol)) filterTest("string filter", (a("stringcol", StringType) > Literal("test")) :: Nil, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 3eb127e23d48..502b240f3650 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -19,10 +19,14 @@ package org.apache.spark.sql.hive.client import java.io.File +import org.apache.hadoop.util.VersionInfo + +import org.apache.spark.sql.hive.HiveContext import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.catalyst.expressions.{NamedExpression, Literal, AttributeReference, EqualTo} import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.types.IntegerType +import org.apache.spark.tags.ExtendedHiveTest import org.apache.spark.util.Utils /** @@ -31,12 +35,15 @@ import org.apache.spark.util.Utils * sure that reflective calls are not throwing NoSuchMethod error, but the actually functionality * is not fully tested. 
*/ +@ExtendedHiveTest class VersionsSuite extends SparkFunSuite with Logging { - // Do not use a temp path here to speed up subsequent executions of the unit test during - // development. - private val ivyPath = Some( - new File(sys.props("java.io.tmpdir"), "hive-ivy-cache").getAbsolutePath()) + // In order to speed up test execution during development or in Jenkins, you can specify the path + // of an existing Ivy cache: + private val ivyPath: Option[String] = { + sys.env.get("SPARK_VERSIONS_SUITE_IVY_PATH").orElse( + Some(new File(sys.props("java.io.tmpdir"), "hive-ivy-cache").getAbsolutePath)) + } private def buildConf() = { lazy val warehousePath = Utils.createTempDir() @@ -48,7 +55,11 @@ class VersionsSuite extends SparkFunSuite with Logging { } test("success sanity check") { - val badClient = IsolatedClientLoader.forVersion("13", buildConf(), ivyPath).client + val badClient = IsolatedClientLoader.forVersion( + hiveMetastoreVersion = HiveContext.hiveExecutionVersion, + hadoopVersion = VersionInfo.getVersion, + config = buildConf(), + ivyPath = ivyPath).createClient() val db = new HiveDatabase("default", "") badClient.createDatabase(db) } @@ -78,7 +89,11 @@ class VersionsSuite extends SparkFunSuite with Logging { ignore("failure sanity check") { val e = intercept[Throwable] { val badClient = quietly { - IsolatedClientLoader.forVersion("13", buildConf(), ivyPath).client + IsolatedClientLoader.forVersion( + hiveMetastoreVersion = "13", + hadoopVersion = VersionInfo.getVersion, + config = buildConf(), + ivyPath = ivyPath).createClient() } } assert(getNestedMessages(e) contains "Unknown column 'A0.OWNER_NAME' in 'field list'") @@ -91,7 +106,13 @@ class VersionsSuite extends SparkFunSuite with Logging { versions.foreach { version => test(s"$version: create client") { client = null - client = IsolatedClientLoader.forVersion(version, buildConf(), ivyPath).client + System.gc() // Hack to avoid SEGV on some JVM versions. 
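For readability, the reworked loader call as a single snippet; everything here comes from the hunks above (buildConf() and ivyPath are the suite-local helpers, and the suite lives in the org.apache.spark.sql.hive.client package, so IsolatedClientLoader needs no import):

import org.apache.hadoop.util.VersionInfo
import org.apache.spark.sql.hive.HiveContext

// The loader is now keyed on both the metastore version and the Hadoop version,
// and the client is obtained via createClient() instead of the old .client field.
val client = IsolatedClientLoader.forVersion(
  hiveMetastoreVersion = HiveContext.hiveExecutionVersion,
  hadoopVersion = VersionInfo.getVersion,
  config = buildConf(),
  ivyPath = ivyPath).createClient()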
+ client = + IsolatedClientLoader.forVersion( + hiveMetastoreVersion = version, + hadoopVersion = VersionInfo.getVersion, + config = buildConf(), + ivyPath = ivyPath).createClient() } test(s"$version: createDatabase") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 6f0db27775e4..5550198c02fb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -17,24 +17,115 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql.execution.aggregate -import org.apache.spark.sql.hive.test.TestHive +import scala.collection.JavaConverters._ + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} -import org.apache.spark.sql.{SQLConf, AnalysisException, QueryTest, Row} -import org.scalatest.BeforeAndAfterAll -import test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} +import org.apache.spark.sql.types._ -abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll { +class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFunction { - override val sqlContext = TestHive - import sqlContext.implicits._ + def inputSchema: StructType = schema - var originalUseAggregate2: Boolean = _ + def bufferSchema: StructType = schema + + def dataType: DataType = schema + + def deterministic: Boolean = true + + def initialize(buffer: MutableAggregationBuffer): Unit = { + (0 until schema.length).foreach { i => + buffer.update(i, null) + } + } + + def update(buffer: MutableAggregationBuffer, input: Row): Unit = { + if (!input.isNullAt(0) && input.getInt(0) == 50) { + (0 until schema.length).foreach { i => + buffer.update(i, input.get(i)) + } + } + } + + def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { + if (!buffer2.isNullAt(0) && buffer2.getInt(0) == 50) { + (0 until schema.length).foreach { i => + buffer1.update(i, buffer2.get(i)) + } + } + } + + def evaluate(buffer: Row): Any = { + Row.fromSeq(buffer.toSeq) + } +} + +class ScalaAggregateFunctionWithoutInputSchema extends UserDefinedAggregateFunction { + + def inputSchema: StructType = StructType(Nil) + + def bufferSchema: StructType = StructType(StructField("value", LongType) :: Nil) + + def dataType: DataType = LongType + + def deterministic: Boolean = true + + def initialize(buffer: MutableAggregationBuffer): Unit = { + buffer.update(0, 0L) + } + + def update(buffer: MutableAggregationBuffer, input: Row): Unit = { + buffer.update(0, input.getAs[Seq[Row]](0).map(_.getAs[Int]("v")).sum + buffer.getLong(0)) + } + + def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { + buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0)) + } + + def evaluate(buffer: Row): Any = { + buffer.getLong(0) + } +} + +class LongProductSum extends UserDefinedAggregateFunction { + def inputSchema: StructType = new StructType() + .add("a", LongType) + 
.add("b", LongType) + + def bufferSchema: StructType = new StructType() + .add("product", LongType) + + def dataType: DataType = LongType + + def deterministic: Boolean = true + + def initialize(buffer: MutableAggregationBuffer): Unit = { + buffer(0) = 0L + } + + def update(buffer: MutableAggregationBuffer, input: Row): Unit = { + if (!(input.isNullAt(0) || input.isNullAt(1))) { + buffer(0) = buffer.getLong(0) + input.getLong(0) * input.getLong(1) + } + } + + def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { + buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0) + } + + def evaluate(buffer: Row): Any = + buffer.getLong(0) +} + +abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + import testImplicits._ override def beforeAll(): Unit = { - originalUseAggregate2 = sqlContext.conf.useSqlAggregate2 - sqlContext.setConf(SQLConf.USE_SQL_AGGREGATE2.key, "true") val data1 = Seq[(Integer, Integer)]( (1, 10), (null, -60), @@ -67,21 +158,38 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be (3, null, null)).toDF("key", "value1", "value2") data2.write.saveAsTable("agg2") + val data3 = Seq[(Seq[Integer], Integer, Integer)]( + (Seq[Integer](1, 1), 10, -10), + (Seq[Integer](null), -60, 60), + (Seq[Integer](1, 1), 30, -30), + (Seq[Integer](1), 30, 30), + (Seq[Integer](2), 1, 1), + (null, -10, 10), + (Seq[Integer](2, 3), -1, null), + (Seq[Integer](2, 3), 1, 1), + (Seq[Integer](2, 3, 4), null, 1), + (Seq[Integer](null), 100, -10), + (Seq[Integer](3), null, 3), + (null, null, null), + (Seq[Integer](3), null, null)).toDF("key", "value1", "value2") + data3.write.saveAsTable("agg3") + val emptyDF = sqlContext.createDataFrame( - sqlContext.sparkContext.emptyRDD[Row], + sparkContext.emptyRDD[Row], StructType(StructField("key", StringType) :: StructField("value", IntegerType) :: Nil)) emptyDF.registerTempTable("emptyTable") // Register UDAFs - sqlContext.udaf.register("mydoublesum", new MyDoubleSum) - sqlContext.udaf.register("mydoubleavg", new MyDoubleAvg) + sqlContext.udf.register("mydoublesum", new MyDoubleSum) + sqlContext.udf.register("mydoubleavg", new MyDoubleAvg) + sqlContext.udf.register("longProductSum", new LongProductSum) } override def afterAll(): Unit = { sqlContext.sql("DROP TABLE IF EXISTS agg1") sqlContext.sql("DROP TABLE IF EXISTS agg2") + sqlContext.sql("DROP TABLE IF EXISTS agg3") sqlContext.dropTempTable("emptyTable") - sqlContext.setConf(SQLConf.USE_SQL_AGGREGATE2.key, originalUseAggregate2.toString) } test("empty table") { @@ -141,6 +249,22 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be Nil) } + test("null literal") { + checkAnswer( + sqlContext.sql( + """ + |SELECT + | AVG(null), + | COUNT(null), + | FIRST(null), + | LAST(null), + | MAX(null), + | MIN(null), + | SUM(null) + """.stripMargin), + Row(null, 0, null, null, null, null, null) :: Nil) + } + test("only do grouping") { checkAnswer( sqlContext.sql( @@ -185,6 +309,41 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be Row(100, null) :: Row(null, 3) :: Row(null, null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT DISTINCT key + |FROM agg3 + """.stripMargin), + Row(Seq[Integer](1, 1)) :: + Row(Seq[Integer](null)) :: + Row(Seq[Integer](1)) :: + Row(Seq[Integer](2)) :: + Row(null) :: + Row(Seq[Integer](2, 3)) :: + Row(Seq[Integer](2, 3, 4)) :: + Row(Seq[Integer](3)) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT value1, key + |FROM agg3 + |GROUP BY 
value1, key + """.stripMargin), + Row(10, Seq[Integer](1, 1)) :: + Row(-60, Seq[Integer](null)) :: + Row(30, Seq[Integer](1, 1)) :: + Row(30, Seq[Integer](1)) :: + Row(1, Seq[Integer](2)) :: + Row(-10, null) :: + Row(-1, Seq[Integer](2, 3)) :: + Row(1, Seq[Integer](2, 3)) :: + Row(null, Seq[Integer](2, 3, 4)) :: + Row(100, Seq[Integer](null)) :: + Row(null, Seq[Integer](3)) :: + Row(null, null) :: Nil) } test("case in-sensitive resolution") { @@ -242,6 +401,15 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be """.stripMargin), Row(1, 20.0) :: Row(2, -0.5) :: Row(3, null) :: Row(null, 10.0) :: Nil) + checkAnswer( + sqlContext.sql( + """ + |SELECT key, mean(value) + |FROM agg1 + |GROUP BY key + """.stripMargin), + Row(1, 20.0) :: Row(2, -0.5) :: Row(3, null) :: Row(null, 10.0) :: Nil) + checkAnswer( sqlContext.sql( """ @@ -266,13 +434,44 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be |SELECT avg(value) FROM agg1 """.stripMargin), Row(11.125) :: Nil) + } - checkAnswer( - sqlContext.sql( - """ - |SELECT avg(null) - """.stripMargin), - Row(null) :: Nil) + test("first_value and last_value") { + // We force to use a single partition for the sort and aggregate to make result + // deterministic. + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer( + sqlContext.sql( + """ + |SELECT + | first_valUE(key), + | lasT_value(key), + | firSt(key), + | lASt(key), + | first_valUE(key, true), + | lasT_value(key, true), + | firSt(key, true), + | lASt(key, true) + |FROM (SELECT key FROM agg1 ORDER BY key) tmp + """.stripMargin), + Row(null, 3, null, 3, 1, 3, 1, 3) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | first_valUE(key), + | lasT_value(key), + | firSt(key), + | lASt(key), + | first_valUE(key, true), + | lasT_value(key, true), + | firSt(key, true), + | lASt(key, true) + |FROM (SELECT key FROM agg1 ORDER BY key DESC) tmp + """.stripMargin), + Row(3, null, 3, null, 3, 1, 3, 1) :: Nil) + } } test("udaf") { @@ -295,7 +494,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be Row(null, null, 110.0, null, null, 10.0) :: Nil) } - test("non-AlgebraicAggregate aggreguate function") { + test("interpreted aggregate function") { checkAnswer( sqlContext.sql( """ @@ -320,7 +519,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be Row(null) :: Nil) } - test("non-AlgebraicAggregate and AlgebraicAggregate aggreguate function") { + test("interpreted and expression-based aggregation functions") { checkAnswer( sqlContext.sql( """ @@ -364,7 +563,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be | max(distinct value1) |FROM agg2 """.stripMargin), - Row(-60, 70.0, 101.0/9.0, 5.6, 100.0)) + Row(-60, 70.0, 101.0/9.0, 5.6, 100)) checkAnswer( sqlContext.sql( @@ -402,6 +601,67 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be Row(2, 100.0, 3.0, 0.0, 100.0, 1.0/3.0 + 100.0) :: Row(3, null, 3.0, null, null, null) :: Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | count(value1), + | count(*), + | count(1), + | count(DISTINCT value1), + | key + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(3, 3, 3, 2, 1) :: + Row(3, 4, 4, 2, 2) :: + Row(0, 2, 2, 0, 3) :: + Row(3, 4, 4, 3, null) :: Nil) + } + + test("single distinct multiple columns set") { + checkAnswer( + sqlContext.sql( + """ + |SELECT + | key, + | count(distinct value1, value2) + |FROM 
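first and last are order-dependent, which is why the new test pins the shuffle partition count to 1 and feeds the aggregates from an ORDER BY subquery. A condensed sketch of the same idea, with the expected values taken from the rows above:

withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
  val row = sqlContext.sql(
    "SELECT first(key, true), last(key, true) FROM (SELECT key FROM agg1 ORDER BY key) tmp").head()
  // With nulls ignored, the smallest key comes first and the largest comes last.
  assert(row == Row(1, 3))
}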
agg2 + |GROUP BY key + """.stripMargin), + Row(null, 3) :: + Row(1, 3) :: + Row(2, 1) :: + Row(3, 0) :: Nil) + } + + test("multiple distinct multiple columns sets") { + checkAnswer( + sqlContext.sql( + """ + |SELECT + | key, + | count(distinct value1), + | sum(distinct value1), + | count(distinct value2), + | sum(distinct value2), + | count(distinct value1, value2), + | longProductSum(distinct value1, value2), + | count(value1), + | sum(value1), + | count(value2), + | sum(value2), + | longProductSum(value1, value2), + | count(*), + | count(1) + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(null, 3, 30, 3, 60, 3, -4700, 3, 30, 3, 60, -4700, 4, 4) :: + Row(1, 2, 40, 3, -10, 3, -100, 3, 70, 3, -10, -100, 3, 3) :: + Row(2, 2, 0, 1, 1, 1, 1, 3, 1, 3, 3, 2, 4, 4) :: + Row(3, 0, null, 1, 3, 0, 0, 0, null, 1, 3, 0, 2, 2) :: Nil) } test("test count") { @@ -453,87 +713,249 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be Row(0, null, 1, 1, null, 0) :: Nil) } - test("error handling") { - withSQLConf("spark.sql.useAggregate2" -> "false") { - val errorMessage = intercept[AnalysisException] { - sqlContext.sql( - """ - |SELECT - | key, - | sum(value + 1.5 * key), - | mydoublesum(value), - | mydoubleavg(value) - |FROM agg1 - |GROUP BY key - """.stripMargin).collect() - }.getMessage - assert(errorMessage.contains("implemented based on the new Aggregate Function interface")) + test("pearson correlation") { + val df = Seq.tabulate(10)(i => (1.0 * i, 2.0 * i, i * -1.0)).toDF("a", "b", "c") + val corr1 = df.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0) + assert(math.abs(corr1 - 1.0) < 1e-12) + val corr2 = df.groupBy().agg(corr("a", "c")).collect()(0).getDouble(0) + assert(math.abs(corr2 + 1.0) < 1e-12) + // non-trivial example. 
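The corr assertions exploit the fact that an exact linear relationship yields a Pearson coefficient of +1 (or -1 when the slope is negative). A minimal standalone sketch using the DataFrame API, assuming the suite's testImplicits for toDF:

import org.apache.spark.sql.functions.corr

// b = 2 * a is perfectly linearly related to a, so r should be 1.0 up to rounding error.
val linear = Seq.tabulate(10)(i => (1.0 * i, 2.0 * i)).toDF("a", "b")
val r = linear.groupBy().agg(corr("a", "b")).head().getDouble(0)
assert(math.abs(r - 1.0) < 1e-12)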
To reproduce in python, use: + // >>> from scipy.stats import pearsonr + // >>> import numpy as np + // >>> a = np.array(range(20)) + // >>> b = np.array([x * x - 2 * x + 3.5 for x in range(20)]) + // >>> pearsonr(a, b) + // (0.95723391394758572, 3.8902121417802199e-11) + // In R, use: + // > a <- 0:19 + // > b <- mapply(function(x) x * x - 2 * x + 3.5, a) + // > cor(a, b) + // [1] 0.957233913947585835 + val df2 = Seq.tabulate(20)(x => (1.0 * x, x * x - 2 * x + 3.5)).toDF("a", "b") + val corr3 = df2.groupBy().agg(corr("a", "b")).collect()(0).getDouble(0) + assert(math.abs(corr3 - 0.95723391394758572) < 1e-12) + + val df3 = Seq.tabulate(0)(i => (1.0 * i, 2.0 * i)).toDF("a", "b") + val corr4 = df3.groupBy().agg(corr("a", "b")).collect()(0) + assert(corr4 == Row(null)) + + val df4 = Seq.tabulate(10)(i => (1 * i, 2 * i, i * -1)).toDF("a", "b", "c") + val corr5 = df4.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0) + assert(math.abs(corr5 - 1.0) < 1e-12) + val corr6 = df4.groupBy().agg(corr("a", "c")).collect()(0).getDouble(0) + assert(math.abs(corr6 + 1.0) < 1e-12) + + // Test for udaf_corr in HiveCompatibilitySuite + // udaf_corr has been blacklisted due to numerical errors + // We test it here: + // SELECT corr(b, c) FROM covar_tab WHERE a < 1; => NULL + // SELECT corr(b, c) FROM covar_tab WHERE a < 3; => NULL + // SELECT corr(b, c) FROM covar_tab WHERE a = 3; => NULL + // SELECT a, corr(b, c) FROM covar_tab GROUP BY a ORDER BY a; => + // 1 NULL + // 2 NULL + // 3 NULL + // 4 NULL + // 5 NULL + // 6 NULL + // SELECT corr(b, c) FROM covar_tab; => 0.6633880657639323 + + val covar_tab = Seq[(Integer, Integer, Integer)]( + (1, null, 15), + (2, 3, null), + (3, 7, 12), + (4, 4, 14), + (5, 8, 17), + (6, 2, 11)).toDF("a", "b", "c") + + covar_tab.registerTempTable("covar_tab") + + checkAnswer( + sqlContext.sql( + """ + |SELECT corr(b, c) FROM covar_tab WHERE a < 1 + """.stripMargin), + Row(null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT corr(b, c) FROM covar_tab WHERE a < 3 + """.stripMargin), + Row(null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT corr(b, c) FROM covar_tab WHERE a = 3 + """.stripMargin), + Row(null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT a, corr(b, c) FROM covar_tab GROUP BY a ORDER BY a + """.stripMargin), + Row(1, null) :: + Row(2, null) :: + Row(3, null) :: + Row(4, null) :: + Row(5, null) :: + Row(6, null) :: Nil) + + val corr7 = sqlContext.sql("SELECT corr(b, c) FROM covar_tab").collect()(0).getDouble(0) + assert(math.abs(corr7 - 0.6633880657639323) < 1e-12) + } + + test("no aggregation function (SPARK-11486)") { + val df = sqlContext.range(20).selectExpr("id", "repeat(id, 1) as s") + .groupBy("s").count() + .groupBy().count() + checkAnswer(df, Row(20) :: Nil) + } + + test("udaf with all data types") { + val struct = + StructType( + StructField("f1", FloatType, true) :: + StructField("f2", ArrayType(BooleanType), true) :: Nil) + val dataTypes = Seq(StringType, BinaryType, NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), + DateType, TimestampType, + ArrayType(IntegerType), MapType(StringType, LongType), struct, + new MyDenseVectorUDT()) + // Right now, we will use SortBasedAggregate to handle UDAFs. + // UnsafeRow.mutableFieldTypes.asScala.toSeq will trigger SortBasedAggregate to use + // UnsafeRow as the aggregation buffer. While, dataTypes will trigger + // SortBasedAggregate to use a safe row as the aggregation buffer. 
+ Seq(dataTypes, UnsafeRow.mutableFieldTypes.asScala.toSeq).foreach { dataTypes => + val fields = dataTypes.zipWithIndex.map { case (dataType, index) => + StructField(s"col$index", dataType, nullable = true) + } + // The schema used for data generator. + val schemaForGenerator = StructType(fields) + // The schema used for the DataFrame df. + val schema = StructType(StructField("id", IntegerType) +: fields) + + logInfo(s"Testing schema: ${schema.treeString}") + + val udaf = new ScalaAggregateFunction(schema) + // Generate data at the driver side. We need to materialize the data first and then + // create RDD. + val maybeDataGenerator = + RandomDataGenerator.forType( + dataType = schemaForGenerator, + nullable = true, + seed = Some(System.nanoTime())) + val dataGenerator = + maybeDataGenerator + .getOrElse(fail(s"Failed to create data generator for schema $schemaForGenerator")) + val data = (1 to 50).map { i => + dataGenerator.apply() match { + case row: Row => Row.fromSeq(i +: row.toSeq) + case null => Row.fromSeq(i +: Seq.fill(schemaForGenerator.length)(null)) + case other => + fail(s"Row or null is expected to be generated, " + + s"but a ${other.getClass.getCanonicalName} is generated.") + } + } + + // Create a DF for the schema with random data. + val rdd = sqlContext.sparkContext.parallelize(data, 1) + val df = sqlContext.createDataFrame(rdd, schema) + + val allColumns = df.schema.fields.map(f => col(f.name)) + val expectedAnswer = + data + .find(r => r.getInt(0) == 50) + .getOrElse(fail("A row with id 50 should be the expected answer.")) + checkAnswer( + df.groupBy().agg(udaf(allColumns: _*)), + // udaf returns a Row as the output value. + Row(expectedAnswer) + ) } + } - // TODO: once we support Hive UDAF in the new interface, - // we can remove the following two tests. - withSQLConf("spark.sql.useAggregate2" -> "true") { - val errorMessage = intercept[AnalysisException] { + test("udaf without specifying inputSchema") { + withTempTable("noInputSchemaUDAF") { + sqlContext.udf.register("noInputSchema", new ScalaAggregateFunctionWithoutInputSchema) + + val data = + Row(1, Seq(Row(1), Row(2), Row(3))) :: + Row(1, Seq(Row(4), Row(5), Row(6))) :: + Row(2, Seq(Row(-10))) :: Nil + val schema = + StructType( + StructField("key", IntegerType) :: + StructField("myArray", + ArrayType(StructType(StructField("v", IntegerType) :: Nil))) :: Nil) + sqlContext.createDataFrame( + sparkContext.parallelize(data, 2), + schema) + .registerTempTable("noInputSchemaUDAF") + + checkAnswer( sqlContext.sql( """ - |SELECT - | key, - | mydoublesum(value + 1.5 * key), - | stddev_samp(value) - |FROM agg1 + |SELECT key, noInputSchema(myArray) + |FROM noInputSchemaUDAF |GROUP BY key - """.stripMargin).collect() - }.getMessage - assert(errorMessage.contains("implemented based on the new Aggregate Function interface")) + """.stripMargin), + Row(1, 21) :: Row(2, -10) :: Nil) - // This will fall back to the old aggregate - val newAggregateOperators = sqlContext.sql( - """ - |SELECT - | key, - | sum(value + 1.5 * key), - | stddev_samp(value) - |FROM agg1 - |GROUP BY key - """.stripMargin).queryExecution.executedPlan.collect { - case agg: aggregate.Aggregate => agg - } - val message = - "We should fallback to the old aggregation code path if " + - "there is any aggregate function that cannot be converted to the new interface." 
- assert(newAggregateOperators.isEmpty, message) + checkAnswer( + sqlContext.sql( + """ + |SELECT noInputSchema(myArray) + |FROM noInputSchemaUDAF + """.stripMargin), + Row(11) :: Nil) } } } -class SortBasedAggregationQuerySuite extends AggregationQuerySuite { - var originalUnsafeEnabled: Boolean = _ +class TungstenAggregationQuerySuite extends AggregationQuerySuite - override def beforeAll(): Unit = { - originalUnsafeEnabled = sqlContext.conf.unsafeEnabled - sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, "false") - super.beforeAll() - } - override def afterAll(): Unit = { - super.afterAll() - sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) - } -} +class TungstenAggregationQueryWithControlledFallbackSuite extends AggregationQuerySuite { -class TungstenAggregationQuerySuite extends AggregationQuerySuite { + override protected def checkAnswer(actual: => DataFrame, expectedAnswer: Seq[Row]): Unit = { + (0 to 2).foreach { fallbackStartsAt => + withSQLConf("spark.sql.TungstenAggregate.testFallbackStartsAt" -> fallbackStartsAt.toString) { + // Create a new df to make sure its physical operator picks up + // spark.sql.TungstenAggregate.testFallbackStartsAt. + // todo: remove it? + val newActual = DataFrame(sqlContext, actual.logicalPlan) - var originalUnsafeEnabled: Boolean = _ + QueryTest.checkAnswer(newActual, expectedAnswer) match { + case Some(errorMessage) => + val newErrorMessage = + s""" + |The following aggregation query failed when using TungstenAggregate with + |controlled fallback (it falls back to sort-based aggregation once it has processed + |$fallbackStartsAt input rows). The query is + |${actual.queryExecution} + | + |$errorMessage + """.stripMargin - override def beforeAll(): Unit = { - originalUnsafeEnabled = sqlContext.conf.unsafeEnabled - sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, "true") - super.beforeAll() + fail(newErrorMessage) + case None => + } + } + } } - override def afterAll(): Unit = { - super.afterAll() - sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) + // Override it to make sure we call the actually overridden checkAnswer. + override protected def checkAnswer(df: => DataFrame, expectedAnswer: Row): Unit = { + checkAnswer(df, Seq(expectedAnswer)) + } + + // Override it to make sure we call the actually overridden checkAnswer. 
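The TungstenAggregationQueryWithControlledFallbackSuite reruns every checkAnswer under the internal spark.sql.TungstenAggregate.testFallbackStartsAt setting so that both the hash-based path and the sort-based fallback are exercised. A condensed sketch of that loop; the query and expectedRows are placeholders, not values from the suite:

(0 to 2).foreach { fallbackStartsAt =>
  withSQLConf("spark.sql.TungstenAggregate.testFallbackStartsAt" -> fallbackStartsAt.toString) {
    // The aggregate switches to sort-based aggregation after `fallbackStartsAt` input rows.
    checkAnswer(
      sqlContext.sql("SELECT key, count(*) FROM agg1 GROUP BY key"), // placeholder query
      expectedRows) // placeholder Seq[Row]
  }
}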
+ override protected def checkAnswer(df: => DataFrame, expectedAnswer: DataFrame): Unit = { + checkAnswer(df, expectedAnswer.collect()) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala index b0d3dd44daed..e38d1eb5779f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala @@ -25,8 +25,10 @@ class ConcurrentHiveSuite extends SparkFunSuite with BeforeAndAfterAll { ignore("multiple instances not supported") { test("Multiple Hive Instances") { (1 to 10).map { i => + val conf = new SparkConf() + conf.set("spark.ui.enabled", "false") val ts = - new TestHiveContext(new SparkContext("local", s"TestSQLContext$i", new SparkConf())) + new TestHiveContext(new SparkContext("local", s"TestSQLContext$i", conf)) ts.executeSql("SHOW TABLES").toRdd.collect() ts.executeSql("SELECT * FROM src").toRdd.collect() ts.executeSql("SHOW TABLES").toRdd.collect() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 638b9c810372..4455430aa727 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -19,9 +19,11 @@ package org.apache.spark.sql.hive.execution import java.io._ +import scala.util.control.NonFatal + import org.scalatest.{BeforeAndAfterAll, GivenWhenThen} -import org.apache.spark.{Logging, SparkFunSuite} +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ @@ -40,7 +42,7 @@ import org.apache.spark.sql.hive.test.TestHive * configured using system properties. */ abstract class HiveComparisonTest - extends SparkFunSuite with BeforeAndAfterAll with GivenWhenThen with Logging { + extends SparkFunSuite with BeforeAndAfterAll with GivenWhenThen { /** * When set, any cache files that result in test failures will be deleted. Used when the test @@ -124,7 +126,7 @@ abstract class HiveComparisonTest protected val cacheDigest = java.security.MessageDigest.getInstance("MD5") protected def getMd5(str: String): String = { val digest = java.security.MessageDigest.getInstance("MD5") - digest.update(str.getBytes("utf-8")) + digest.update(str.replaceAll(System.lineSeparator(), "\n").getBytes("utf-8")) new java.math.BigInteger(1, digest.digest).toString(16) } @@ -207,7 +209,11 @@ abstract class HiveComparisonTest } val installHooksCommand = "(?i)SET.*hooks".r - def createQueryTest(testCaseName: String, sql: String, reset: Boolean = true) { + def createQueryTest( + testCaseName: String, + sql: String, + reset: Boolean = true, + tryWithoutResettingFirst: Boolean = false) { // testCaseName must not contain ':', which is not allowed to appear in a filename of Windows assert(!testCaseName.contains(":")) @@ -238,9 +244,6 @@ abstract class HiveComparisonTest test(testCaseName) { logDebug(s"=== HIVE TEST: $testCaseName ===") - // Clear old output for this testcase. 
- outputDirectories.map(new File(_, testCaseName)).filter(_.exists()).foreach(_.delete()) - val sqlWithoutComment = sql.split("\n").filterNot(l => l.matches("--.*(?<=[^\\\\]);")).mkString("\n") val allQueries = @@ -267,11 +270,32 @@ abstract class HiveComparisonTest }.mkString("\n== Console version of this test ==\n", "\n", "\n") } - try { + def doTest(reset: Boolean, isSpeculative: Boolean = false): Unit = { + // Clear old output for this testcase. + outputDirectories.map(new File(_, testCaseName)).filter(_.exists()).foreach(_.delete()) + if (reset) { TestHive.reset() } + // Many tests drop indexes on src and srcpart at the beginning, so we need to load those + // tables here. Since DROP INDEX DDL is just passed to Hive, it bypasses the analyzer and + // thus the tables referenced in those DDL commands cannot be extracted for use by our + // test table auto-loading mechanism. In addition, the tests which use the SHOW TABLES + // command expect these tables to exist. + val hasShowTableCommand = queryList.exists(_.toLowerCase.contains("show tables")) + for (table <- Seq("src", "srcpart")) { + val hasMatchingQuery = queryList.exists { query => + val normalizedQuery = query.toLowerCase.stripSuffix(";") + normalizedQuery.endsWith(table) || + normalizedQuery.contains(s"from $table") || + normalizedQuery.contains(s"from default.$table") + } + if (hasShowTableCommand || hasMatchingQuery) { + TestHive.loadTestTable(table) + } + } + val hiveCacheFiles = queryList.zipWithIndex.map { case (queryString, i) => val cachedAnswerName = s"$testCaseName-$i-${getMd5(queryString)}" @@ -386,20 +410,87 @@ abstract class HiveComparisonTest hiveCacheFiles.foreach(_.delete()) } + // If this query is reading other tables that were created during this test run + // also print out the query plans and results for those. + val computedTablesMessages: String = try { + val tablesRead = new TestHive.QueryExecution(query).executedPlan.collect { + case ts: HiveTableScan => ts.relation.tableName + }.toSet + + TestHive.reset() + val executions = queryList.map(new TestHive.QueryExecution(_)) + executions.foreach(_.toRdd) + val tablesGenerated = queryList.zip(executions).flatMap { + case (q, e) => e.executedPlan.collect { + case i: InsertIntoHiveTable if tablesRead contains i.table.tableName => + (q, e, i) + } + } + + tablesGenerated.map { case (hiveql, execution, insert) => + s""" + |=== Generated Table === + |$hiveql + |$execution + |== Results == + |${insert.child.execute().collect().mkString("\n")} + """.stripMargin + }.mkString("\n") + + } catch { + case NonFatal(e) => + logError("Failed to compute generated tables", e) + s"Couldn't compute dependent tables: $e" + } + val errorMessage = s""" |Results do not match for $testCaseName: |$hiveQuery\n${hiveQuery.analyzed.output.map(_.name).mkString("\t")} |$resultComparison + |$computedTablesMessages """.stripMargin stringToFile(new File(wrongDirectory, testCaseName), errorMessage + consoleTestCase) - fail(errorMessage) + if (isSpeculative && !reset) { + fail("Failed on first run; retrying") + } else { + fail(errorMessage) + } } } // Touch passed file. 
new FileOutputStream(new File(passedDirectory, testCaseName)).close() + } + + val canSpeculativelyTryWithoutReset: Boolean = { + val excludedSubstrings = Seq( + "into table", + "create table", + "drop index" + ) + !queryList.map(_.toLowerCase).exists { query => + excludedSubstrings.exists(s => query.contains(s)) + } + } + + try { + try { + if (tryWithoutResettingFirst && canSpeculativelyTryWithoutReset) { + doTest(reset = false, isSpeculative = true) + } else { + doTest(reset) + } + } catch { + case tf: org.scalatest.exceptions.TestFailedException => + if (tryWithoutResettingFirst && canSpeculativelyTryWithoutReset) { + logWarning("Test failed without reset(); retrying with reset()") + doTest(reset = true) + } else { + throw tf + } + } } catch { case tf: org.scalatest.exceptions.TestFailedException => throw tf case originalException: Exception => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index 697211222b90..a7b7ad009391 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -18,12 +18,14 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.hive.test.TestHiveSingleton /** * A set of tests that validates support for Hive Explain command. */ -class HiveExplainSuite extends QueryTest { +class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + test("explain extended command") { checkExistence(sql(" explain select * from src where key=123 "), true, "== Physical Plan ==") @@ -35,8 +37,7 @@ class HiveExplainSuite extends QueryTest { "== Parsed Logical Plan ==", "== Analyzed Logical Plan ==", "== Optimized Logical Plan ==", - "== Physical Plan ==", - "Code Generation", "== RDD ==") + "== Physical Plan ==") } test("explain create table command") { @@ -74,4 +75,30 @@ class HiveExplainSuite extends QueryTest { "Limit", "src") } + + test("SPARK-6212: The EXPLAIN output of CTAS only shows the analyzed plan") { + withTempTable("jt") { + val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""")) + hiveContext.read.json(rdd).registerTempTable("jt") + val outputs = sql( + s""" + |EXPLAIN EXTENDED + |CREATE TABLE t1 + |AS + |SELECT * FROM jt + """.stripMargin).collect().map(_.mkString).mkString + + val shouldContain = + "== Parsed Logical Plan ==" :: "== Analyzed Logical Plan ==" :: "Subquery" :: + "== Optimized Logical Plan ==" :: "== Physical Plan ==" :: + "CreateTableAsSelect" :: "InsertIntoHiveTable" :: "jt" :: Nil + for (key <- shouldContain) { + assert(outputs.contains(key), s"$key doesn't exist in result") + } + + val physicalIndex = outputs.indexOf("== Physical Plan ==") + assert(!outputs.substring(physicalIndex).contains("Subquery"), + "Physical Plan should not contain Subquery since it's eliminated by optimizer") + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala index efbef68cd444..0d4c7f86b315 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala +++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala @@ -18,14 +18,16 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.{Row, QueryTest} -import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.{TestHive, TestHiveSingleton} /** * A set of tests that validates commands can also be queried by like a table */ -class HiveOperatorQueryableSuite extends QueryTest { +class HiveOperatorQueryableSuite extends QueryTest with TestHiveSingleton { + import hiveContext._ + test("SPARK-5324 query result of describe command") { - loadTestTable("src") + hiveContext.loadTestTable("src") // register a describe command to be a temp table sql("desc src").registerTempTable("mydesc") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala index ba56a8a6b689..cd055f9eca37 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala @@ -21,11 +21,11 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.expressions.Window -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHiveSingleton -class HivePlanTest extends QueryTest { - import TestHive._ - import TestHive.implicits._ +class HivePlanTest extends QueryTest with TestHiveSingleton { + import hiveContext.sql + import hiveContext.implicits._ test("udf constant folding") { Seq.empty[Tuple1[Int]].toDF("a").registerTempTable("t") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala index f7b37dae0a5f..f96c989c4614 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala @@ -59,7 +59,7 @@ abstract class HiveQueryFileTest extends HiveComparisonTest { runAll) { // Build a test case and submit it to scala test framework... val queriesString = fileToString(testCaseFile) - createQueryTest(testCaseName, queriesString) + createQueryTest(testCaseName, queriesString, reset = true, tryWithoutResettingFirst = true) } else { // Only output warnings for the built in whitelist as this clutters the output when the user // trying to execute a single test from the commandline. 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 11a843becce6..8a5acaf3e10b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -18,21 +18,22 @@ package org.apache.spark.sql.hive.execution import java.io.File +import java.sql.Timestamp import java.util.{Locale, TimeZone} import scala.util.Try -import org.scalatest.BeforeAndAfter - import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.scalatest.BeforeAndAfter -import org.apache.spark.{SparkFiles, SparkException} -import org.apache.spark.sql.{AnalysisException, DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.Cast import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoin import org.apache.spark.sql.hive._ -import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} +import org.apache.spark.sql.{AnalysisException, DataFrame, Row} +import org.apache.spark.{SparkException, SparkFiles} case class TestData(a: Int, b: String) @@ -52,14 +53,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) // Add Locale setting Locale.setDefault(Locale.US) - sql(s"ADD JAR ${TestHive.getHiveFile("TestUDTF.jar").getCanonicalPath()}") - // The function source code can be found at: - // https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF - sql( - """ - |CREATE TEMPORARY FUNCTION udtf_count2 - |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' - """.stripMargin) } override def afterAll() { @@ -69,15 +62,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { sql("DROP TEMPORARY FUNCTION udtf_count2") } - createQueryTest("Test UDTF.close in Lateral Views", - """ - |SELECT key, cc - |FROM src LATERAL VIEW udtf_count2(value) dd AS cc - """.stripMargin, false) // false mean we have to keep the temp function in registry - - createQueryTest("Test UDTF.close in SELECT", - "SELECT udtf_count2(a) FROM (SELECT 1 AS a FROM src LIMIT 3) table", false) - test("SPARK-4908: concurrent hive native commands") { (1 to 100).par.map { _ => sql("USE default") @@ -85,6 +69,58 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } } + // Testing the Broadcast based join for cartesian join (cross join) + // We assume that the Broadcast Join Threshold will work since the src is a small table + private val spark_10484_1 = """ + | SELECT a.key, b.key + | FROM src a LEFT JOIN src b WHERE a.key > b.key + 300 + | ORDER BY b.key, a.key + | LIMIT 20 + """.stripMargin + private val spark_10484_2 = """ + | SELECT a.key, b.key + | FROM src a RIGHT JOIN src b WHERE a.key > b.key + 300 + | ORDER BY a.key, b.key + | LIMIT 20 + """.stripMargin + private val spark_10484_3 = """ + | SELECT a.key, b.key + | FROM src a FULL OUTER JOIN src b WHERE a.key > b.key + 300 + | ORDER BY a.key, b.key + | LIMIT 20 + """.stripMargin + private val spark_10484_4 = """ + | SELECT a.key, b.key + | FROM src a JOIN src b WHERE a.key > b.key + 300 + | ORDER BY a.key, b.key + | LIMIT 20 + """.stripMargin + + createQueryTest("SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #1", + spark_10484_1) +
+ createQueryTest("SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #2", + spark_10484_2) + + createQueryTest("SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #3", + spark_10484_3) + + createQueryTest("SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #4", + spark_10484_4) + + test("SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN") { + def assertBroadcastNestedLoopJoin(sqlText: String): Unit = { + assert(sql(sqlText).queryExecution.sparkPlan.collect { + case _: BroadcastNestedLoopJoin => 1 + }.nonEmpty) + } + + assertBroadcastNestedLoopJoin(spark_10484_1) + assertBroadcastNestedLoopJoin(spark_10484_2) + assertBroadcastNestedLoopJoin(spark_10484_3) + assertBroadcastNestedLoopJoin(spark_10484_4) + } + createQueryTest("SPARK-8976 Wrong Result for Rollup #1", """ SELECT count(*) AS cnt, key % 5,GROUPING__ID FROM src group by key%5 WITH ROLLUP @@ -176,8 +212,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { createQueryTest("! operator", """ |SELECT a FROM ( - | SELECT 1 AS a FROM src LIMIT 1 UNION ALL - | SELECT 2 AS a FROM src LIMIT 1) table + | SELECT 1 AS a UNION ALL SELECT 2 AS a) t |WHERE !(a>1) """.stripMargin) @@ -214,12 +249,17 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |IF(TRUE, CAST(NULL AS BINARY), CAST("1" AS BINARY)) AS COL18, |IF(FALSE, CAST(NULL AS DATE), CAST("1970-01-01" AS DATE)) AS COL19, |IF(TRUE, CAST(NULL AS DATE), CAST("1970-01-01" AS DATE)) AS COL20, - |IF(FALSE, CAST(NULL AS TIMESTAMP), CAST(1 AS TIMESTAMP)) AS COL21, - |IF(TRUE, CAST(NULL AS TIMESTAMP), CAST(1 AS TIMESTAMP)) AS COL22, - |IF(FALSE, CAST(NULL AS DECIMAL), CAST(1 AS DECIMAL)) AS COL23, - |IF(TRUE, CAST(NULL AS DECIMAL), CAST(1 AS DECIMAL)) AS COL24 + |IF(TRUE, CAST(NULL AS TIMESTAMP), CAST(1 AS TIMESTAMP)) AS COL21, + |IF(FALSE, CAST(NULL AS DECIMAL), CAST(1 AS DECIMAL)) AS COL22, + |IF(TRUE, CAST(NULL AS DECIMAL), CAST(1 AS DECIMAL)) AS COL23 |FROM src LIMIT 1""".stripMargin) + test("constant null testing timestamp") { + val r1 = sql("SELECT IF(FALSE, CAST(NULL AS TIMESTAMP), CAST(1 AS TIMESTAMP)) AS COL20") + .collect().head + assert(new Timestamp(1000) == r1.getTimestamp(0)) + } + createQueryTest("constant array", """ |SELECT sort_array( @@ -229,71 +269,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |FROM src LIMIT 1; """.stripMargin) - createQueryTest("count distinct 0 values", - """ - |SELECT COUNT(DISTINCT a) FROM ( - | SELECT 'a' AS a FROM src LIMIT 0) table - """.stripMargin) - - createQueryTest("count distinct 1 value strings", - """ - |SELECT COUNT(DISTINCT a) FROM ( - | SELECT 'a' AS a FROM src LIMIT 1 UNION ALL - | SELECT 'b' AS a FROM src LIMIT 1) table - """.stripMargin) - - createQueryTest("count distinct 1 value", - """ - |SELECT COUNT(DISTINCT a) FROM ( - | SELECT 1 AS a FROM src LIMIT 1 UNION ALL - | SELECT 1 AS a FROM src LIMIT 1) table - """.stripMargin) - - createQueryTest("count distinct 2 values", - """ - |SELECT COUNT(DISTINCT a) FROM ( - | SELECT 1 AS a FROM src LIMIT 1 UNION ALL - | SELECT 2 AS a FROM src LIMIT 1) table - """.stripMargin) - - createQueryTest("count distinct 2 values including null", - """ - |SELECT COUNT(DISTINCT a, 1) FROM ( - | SELECT 1 AS a FROM src LIMIT 1 UNION ALL - | SELECT 1 AS a FROM src LIMIT 1 UNION ALL - | SELECT null AS a FROM src LIMIT 1) table - """.stripMargin) - - createQueryTest("count distinct 1 value + null", - """ - |SELECT COUNT(DISTINCT a) FROM ( - | SELECT 1 
AS a FROM src LIMIT 1 UNION ALL - | SELECT 1 AS a FROM src LIMIT 1 UNION ALL - | SELECT null AS a FROM src LIMIT 1) table - """.stripMargin) - - createQueryTest("count distinct 1 value long", - """ - |SELECT COUNT(DISTINCT a) FROM ( - | SELECT 1L AS a FROM src LIMIT 1 UNION ALL - | SELECT 1L AS a FROM src LIMIT 1) table - """.stripMargin) - - createQueryTest("count distinct 2 values long", - """ - |SELECT COUNT(DISTINCT a) FROM ( - | SELECT 1L AS a FROM src LIMIT 1 UNION ALL - | SELECT 2L AS a FROM src LIMIT 1) table - """.stripMargin) - - createQueryTest("count distinct 1 value + null long", - """ - |SELECT COUNT(DISTINCT a) FROM ( - | SELECT 1L AS a FROM src LIMIT 1 UNION ALL - | SELECT 1L AS a FROM src LIMIT 1 UNION ALL - | SELECT null AS a FROM src LIMIT 1) table - """.stripMargin) - createQueryTest("null case", "SELECT case when(true) then 1 else null end FROM src LIMIT 1") @@ -510,7 +485,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' |USING 'cat' AS (tKey, tValue) ROW FORMAT SERDE |'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' FROM src; - """.stripMargin.replaceAll("\n", " ")) + """.stripMargin.replaceAll(System.lineSeparator(), " ")) test("transform with SerDe2") { @@ -529,7 +504,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |('avro.schema.literal'='{"namespace": "testing.hive.avro.serde","name": |"src","type": "record","fields": [{"name":"key","type":"int"}]}') |FROM small_src - """.stripMargin.replaceAll("\n", " ")).collect().head + """.stripMargin.replaceAll(System.lineSeparator(), " ")).collect().head assert(expected(0) === res(0)) } @@ -541,7 +516,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |('serialization.last.column.takes.rest'='true') USING 'cat' AS (tKey, tValue) |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' |WITH SERDEPROPERTIES ('serialization.last.column.takes.rest'='true') FROM src; - """.stripMargin.replaceAll("\n", " ")) + """.stripMargin.replaceAll(System.lineSeparator(), " ")) createQueryTest("transform with SerDe4", """ @@ -550,7 +525,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |('serialization.last.column.takes.rest'='true') USING 'cat' ROW FORMAT SERDE |'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' WITH SERDEPROPERTIES |('serialization.last.column.takes.rest'='true') FROM src; - """.stripMargin.replaceAll("\n", " ")) + """.stripMargin.replaceAll(System.lineSeparator(), " ")) createQueryTest("LIKE", "SELECT * FROM src WHERE value LIKE '%1%'") @@ -591,6 +566,12 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { createQueryTest("Specify the udtf output", "SELECT d FROM (SELECT explode(array(1,1)) d FROM src LIMIT 1) t") + createQueryTest("SPARK-9034 Reflect field names defined in GenericUDTF #1", + "SELECT col FROM (SELECT explode(array(key,value)) FROM src LIMIT 1) t") + + createQueryTest("SPARK-9034 Reflect field names defined in GenericUDTF #2", + "SELECT key,value FROM (SELECT explode(map(key,value)) FROM src LIMIT 1) t") + test("sampling") { sql("SELECT * FROM src TABLESAMPLE(0.1 PERCENT) s") sql("SELECT * FROM src TABLESAMPLE(100 PERCENT) s") @@ -628,26 +609,32 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { // Jdk version leads to different query output for double, so not use createQueryTest here test("timestamp cast #1") { val res = sql("SELECT CAST(CAST(1 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 
1").collect().head - assert(0.001 == res.getDouble(0)) + assert(1 == res.getDouble(0)) } createQueryTest("timestamp cast #2", "SELECT CAST(CAST(1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") - createQueryTest("timestamp cast #3", - "SELECT CAST(CAST(1200 AS TIMESTAMP) AS INT) FROM src LIMIT 1") + test("timestamp cast #3") { + val res = sql("SELECT CAST(CAST(1200 AS TIMESTAMP) AS INT) FROM src LIMIT 1").collect().head + assert(1200 == res.getInt(0)) + } createQueryTest("timestamp cast #4", "SELECT CAST(CAST(1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") - createQueryTest("timestamp cast #5", - "SELECT CAST(CAST(-1 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") + test("timestamp cast #5") { + val res = sql("SELECT CAST(CAST(-1 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1").collect().head + assert(-1 == res.get(0)) + } createQueryTest("timestamp cast #6", "SELECT CAST(CAST(-1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") - createQueryTest("timestamp cast #7", - "SELECT CAST(CAST(-1200 AS TIMESTAMP) AS INT) FROM src LIMIT 1") + test("timestamp cast #7") { + val res = sql("SELECT CAST(CAST(-1200 AS TIMESTAMP) AS INT) FROM src LIMIT 1").collect().head + assert(-1200 == res.getInt(0)) + } createQueryTest("timestamp cast #8", "SELECT CAST(CAST(-1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") @@ -670,11 +657,62 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |select * where key = 4 """.stripMargin) + // test get_json_object again Hive, because the HiveCompatabilitySuite cannot handle result + // with newline in it. + createQueryTest("get_json_object #1", + "SELECT get_json_object(src_json.json, '$') FROM src_json") + + createQueryTest("get_json_object #2", + "SELECT get_json_object(src_json.json, '$.owner'), get_json_object(src_json.json, '$.store')" + + " FROM src_json") + + createQueryTest("get_json_object #3", + "SELECT get_json_object(src_json.json, '$.store.bicycle'), " + + "get_json_object(src_json.json, '$.store.book') FROM src_json") + + createQueryTest("get_json_object #4", + "SELECT get_json_object(src_json.json, '$.store.book[0]'), " + + "get_json_object(src_json.json, '$.store.book[*]') FROM src_json") + + createQueryTest("get_json_object #5", + "SELECT get_json_object(src_json.json, '$.store.book[0].category'), " + + "get_json_object(src_json.json, '$.store.book[*].category'), " + + "get_json_object(src_json.json, '$.store.book[*].isbn'), " + + "get_json_object(src_json.json, '$.store.book[*].reader') FROM src_json") + + createQueryTest("get_json_object #6", + "SELECT get_json_object(src_json.json, '$.store.book[*].reader[0].age'), " + + "get_json_object(src_json.json, '$.store.book[*].reader[*].age') FROM src_json") + + createQueryTest("get_json_object #7", + "SELECT get_json_object(src_json.json, '$.store.basket[0][1]'), " + + "get_json_object(src_json.json, '$.store.basket[*]'), " + + // Hive returns wrong result with [*][0], so this expression is change to make test pass + "get_json_object(src_json.json, '$.store.basket[0][0]'), " + + "get_json_object(src_json.json, '$.store.basket[0][*]'), " + + "get_json_object(src_json.json, '$.store.basket[*][*]'), " + + "get_json_object(src_json.json, '$.store.basket[0][2].b'), " + + "get_json_object(src_json.json, '$.store.basket[0][*].b') FROM src_json") + + createQueryTest("get_json_object #8", + "SELECT get_json_object(src_json.json, '$.non_exist_key'), " + + "get_json_object(src_json.json, '$..no_recursive'), " + + "get_json_object(src_json.json, '$.store.book[10]'), " + + "get_json_object(src_json.json, 
'$.store.book[0].non_exist_key'), " + + "get_json_object(src_json.json, '$.store.basket[*].non_exist_key'), " + + "get_json_object(src_json.json, '$.store.basket[0][*].non_exist_key') FROM src_json") + + createQueryTest("get_json_object #9", + "SELECT get_json_object(src_json.json, '$.zip code') FROM src_json") + + createQueryTest("get_json_object #10", + "SELECT get_json_object(src_json.json, '$.fb:testid') FROM src_json") + test("predicates contains an empty AttributeSet() references") { sql( """ |SELECT a FROM ( - | SELECT 1 AS a FROM src LIMIT 1 ) table + | SELECT 1 AS a FROM src LIMIT 1 ) t |WHERE abs(20141202) is not null """.stripMargin).collect() } @@ -898,7 +936,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { test("SPARK-2263: Insert Map values") { sql("CREATE TABLE m(value MAP)") sql("INSERT OVERWRITE TABLE m SELECT MAP(key, value) FROM src LIMIT 10") - sql("SELECT * FROM m").collect().zip(sql("SELECT * FROM src LIMIT 10").collect()).map { + sql("SELECT * FROM m").collect().zip(sql("SELECT * FROM src LIMIT 10").collect()).foreach { case (Row(map: Map[_, _]), Row(key: Int, value: String)) => assert(map.size === 1) assert(map.head === (key, value)) @@ -930,6 +968,18 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { sql("DROP TABLE t1") } + test("CREATE TEMPORARY FUNCTION") { + val funcJar = TestHive.getHiveFile("TestUDTF.jar").getCanonicalPath + val jarURL = s"file://$funcJar" + sql(s"ADD JAR $jarURL") + sql( + """CREATE TEMPORARY FUNCTION udtf_count2 AS + |'org.apache.spark.sql.hive.execution.GenericUDTFCount2' + """.stripMargin) + assert(sql("DESCRIBE FUNCTION udtf_count2").count > 1) + sql("DROP TEMPORARY FUNCTION udtf_count2") + } + test("ADD FILE command") { val testFile = TestHive.getHiveFile("data/files/v1.txt").getCanonicalFile sql(s"ADD FILE $testFile") @@ -987,7 +1037,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { .zip(parts) .map { case (k, v) => if (v == "NULL") { - s"$k=${ConfVars.DEFAULTPARTITIONNAME.defaultVal}" + s"$k=${ConfVars.DEFAULTPARTITIONNAME.defaultStrVal}" } else { s"$k=$v" } @@ -1136,18 +1186,19 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { // "SET" itself returns all config variables currently specified in SQLConf. // TODO: Should we be listing the default here always? probably... 
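get_json_object evaluates a JsonPath-like expression against a JSON string and returns the matching fragment as a string, or NULL when the path does not exist. A minimal illustration with an inline JSON literal (hypothetical data; the tests above run against the src_json test table):

// Hypothetical JSON document used only for this sketch.
val json = """{"store": {"book": [{"title": "spark"}]}}"""
checkAnswer(
  sql(s"SELECT get_json_object('$json', '$$.store.book[0].title'), " +
    s"get_json_object('$json', '$$.missing') FROM src LIMIT 1"),
  Row("spark", null))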
- assert(sql("SET").collect().size == 0) + assert(sql("SET").collect().size === TestHiveContext.overrideConfs.size) + val defaults = collectResults(sql("SET")) assertResult(Set(testKey -> testVal)) { collectResults(sql(s"SET $testKey=$testVal")) } - assert(hiveconf.get(testKey, "") == testVal) - assertResult(Set(testKey -> testVal))(collectResults(sql("SET"))) + assert(hiveconf.get(testKey, "") === testVal) + assertResult(defaults ++ Set(testKey -> testVal))(collectResults(sql("SET"))) sql(s"SET ${testKey + testKey}=${testVal + testVal}") assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) - assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { + assertResult(defaults ++ Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { collectResults(sql("SET")) } @@ -1163,6 +1214,58 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { conf.clear() } + test("current_database with multiple sessions") { + sql("create database a") + sql("use a") + val s2 = newSession() + s2.sql("create database b") + s2.sql("use b") + + assert(sql("select current_database()").first() === Row("a")) + assert(s2.sql("select current_database()").first() === Row("b")) + + try { + sql("create table test_a(key INT, value STRING)") + s2.sql("create table test_b(key INT, value STRING)") + + sql("select * from test_a") + intercept[AnalysisException] { + sql("select * from test_b") + } + sql("select * from b.test_b") + + s2.sql("select * from test_b") + intercept[AnalysisException] { + s2.sql("select * from test_a") + } + s2.sql("select * from a.test_a") + } finally { + sql("DROP TABLE IF EXISTS test_a") + s2.sql("DROP TABLE IF EXISTS test_b") + } + + } + + test("lookup hive UDF in another thread") { + val e = intercept[AnalysisException] { + range(1).selectExpr("not_a_udf()") + } + assert(e.getMessage.contains("undefined function not_a_udf")) + var success = false + val t = new Thread("test") { + override def run(): Unit = { + val e = intercept[AnalysisException] { + range(1).selectExpr("not_a_udf()") + } + assert(e.getMessage.contains("undefined function not_a_udf")) + success = true + } + } + t.start() + t.join() + assert(success) + } + createQueryTest("select from thrift based table", "SELECT * from src_thrift") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala index 197e9bfb02c4..5bd323ea096a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala @@ -43,7 +43,9 @@ class HiveTypeCoercionSuite extends HiveComparisonTest { test("[SPARK-2210] boolean cast on boolean value should be removed") { val q = "select cast(cast(key=0 as boolean) as boolean) from src" - val project = TestHive.sql(q).queryExecution.executedPlan.collect { case e: Project => e }.head + val project = TestHive.sql(q).queryExecution.executedPlan.collect { + case e: Project => e + }.head // No cast expression introduced project.transformAllExpressions { case c: Cast => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index 7069afc9f7da..9deb1a6db15a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -17,23 +17,22 @@ package org.apache.spark.sql.hive.execution -import java.io.{DataInput, DataOutput} -import java.util -import java.util.Properties +import java.io.{PrintWriter, File, DataInput, DataOutput} +import java.util.{ArrayList, Arrays, Properties} import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.hive.ql.udf.generic.{GenericUDAFAverage, GenericUDF} +import org.apache.hadoop.hive.ql.udf.UDAFPercentile +import org.apache.hadoop.hive.ql.udf.generic.{GenericUDFOPAnd, GenericUDTFExplode, GenericUDAFAverage, GenericUDF} import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats} import org.apache.hadoop.io.Writable -import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SQLConf} -import org.apache.spark.sql.hive.test.TestHive - +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.util.Utils -import scala.collection.JavaConversions._ case class Fields(f1: Int, f2: Int, f3: Int, f4: Int, f5: Int) @@ -46,10 +45,10 @@ case class ListStringCaseClass(l: Seq[String]) /** * A test suite for Hive custom UDFs. */ -class HiveUDFSuite extends QueryTest { +class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { - import TestHive.{udf, sql} - import TestHive.implicits._ + import hiveContext.{udf, sql} + import hiveContext.implicits._ test("spark sql udf test that returns a struct") { udf.register("getStruct", (_: Int) => Fields(1, 2, 3, 4, 5)) @@ -94,53 +93,45 @@ class HiveUDFSuite extends QueryTest { } test("Max/Min on named_struct") { - def testOrderInStruct(): Unit = { - checkAnswer(sql( - """ - |SELECT max(named_struct( - | "key", key, - | "value", value)).value FROM src - """.stripMargin), Seq(Row("val_498"))) - checkAnswer(sql( - """ - |SELECT min(named_struct( - | "key", key, - | "value", value)).value FROM src - """.stripMargin), Seq(Row("val_0"))) - - // nested struct cases - checkAnswer(sql( - """ - |SELECT max(named_struct( - | "key", named_struct( - "key", key, - "value", value), - | "value", value)).value FROM src - """.stripMargin), Seq(Row("val_498"))) - checkAnswer(sql( - """ - |SELECT min(named_struct( - | "key", named_struct( - "key", key, - "value", value), - | "value", value)).value FROM src - """.stripMargin), Seq(Row("val_0"))) - } - val codegenDefault = TestHive.getConf(SQLConf.CODEGEN_ENABLED) - TestHive.setConf(SQLConf.CODEGEN_ENABLED, true) - testOrderInStruct() - TestHive.setConf(SQLConf.CODEGEN_ENABLED, false) - testOrderInStruct() - TestHive.setConf(SQLConf.CODEGEN_ENABLED, codegenDefault) + checkAnswer(sql( + """ + |SELECT max(named_struct( + | "key", key, + | "value", value)).value FROM src + """.stripMargin), Seq(Row("val_498"))) + checkAnswer(sql( + """ + |SELECT min(named_struct( + | "key", key, + | "value", value)).value FROM src + """.stripMargin), Seq(Row("val_0"))) + + // nested struct cases + checkAnswer(sql( + """ + |SELECT max(named_struct( + | "key", named_struct( + "key", key, + "value", value), + | "value", value)).value FROM src + """.stripMargin), Seq(Row("val_498"))) + checkAnswer(sql( + """ + |SELECT min(named_struct( + 
| "key", named_struct( + "key", key, + "value", value), + | "value", value)).value FROM src + """.stripMargin), Seq(Row("val_0"))) } - test("SPARK-6409 UDAFAverage test") { + test("SPARK-6409 UDAF Average test") { sql(s"CREATE TEMPORARY FUNCTION test_avg AS '${classOf[GenericUDAFAverage].getName}'") checkAnswer( sql("SELECT test_avg(1), test_avg(substr(value,5)) FROM src"), Seq(Row(1.0, 260.182))) sql("DROP TEMPORARY FUNCTION IF EXISTS test_avg") - TestHive.reset() + hiveContext.reset() } test("SPARK-2693 udaf aggregates test") { @@ -160,7 +151,7 @@ class HiveUDFSuite extends QueryTest { } test("UDFIntegerToString") { - val testData = TestHive.sparkContext.parallelize( + val testData = hiveContext.sparkContext.parallelize( IntegerCaseClass(1) :: IntegerCaseClass(2) :: Nil).toDF() testData.registerTempTable("integerTable") @@ -171,41 +162,41 @@ class HiveUDFSuite extends QueryTest { Seq(Row("1"), Row("2"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFIntegerToString") - TestHive.reset() + hiveContext.reset() } test("UDFToListString") { - val testData = TestHive.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + val testData = hiveContext.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() testData.registerTempTable("inputTable") sql(s"CREATE TEMPORARY FUNCTION testUDFToListString AS '${classOf[UDFToListString].getName}'") val errMsg = intercept[AnalysisException] { sql("SELECT testUDFToListString(s) FROM inputTable") } - assert(errMsg.getMessage === "List type in java is unsupported because " + + assert(errMsg.getMessage contains "List type in java is unsupported because " + "JVM type erasure makes spark fail to catch a component type in List<>;") sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToListString") - TestHive.reset() + hiveContext.reset() } test("UDFToListInt") { - val testData = TestHive.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + val testData = hiveContext.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() testData.registerTempTable("inputTable") sql(s"CREATE TEMPORARY FUNCTION testUDFToListInt AS '${classOf[UDFToListInt].getName}'") val errMsg = intercept[AnalysisException] { sql("SELECT testUDFToListInt(s) FROM inputTable") } - assert(errMsg.getMessage === "List type in java is unsupported because " + + assert(errMsg.getMessage contains "List type in java is unsupported because " + "JVM type erasure makes spark fail to catch a component type in List<>;") sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToListInt") - TestHive.reset() + hiveContext.reset() } test("UDFToStringIntMap") { - val testData = TestHive.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + val testData = hiveContext.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() testData.registerTempTable("inputTable") sql(s"CREATE TEMPORARY FUNCTION testUDFToStringIntMap " + @@ -213,15 +204,15 @@ class HiveUDFSuite extends QueryTest { val errMsg = intercept[AnalysisException] { sql("SELECT testUDFToStringIntMap(s) FROM inputTable") } - assert(errMsg.getMessage === "Map type in java is unsupported because " + + assert(errMsg.getMessage contains "Map type in java is unsupported because " + "JVM type erasure makes spark fail to catch key and value types in Map<>;") sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToStringIntMap") - TestHive.reset() + hiveContext.reset() } test("UDFToIntIntMap") { - val testData = TestHive.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + val testData = hiveContext.sparkContext.parallelize(StringCaseClass("") :: 
Nil).toDF() testData.registerTempTable("inputTable") sql(s"CREATE TEMPORARY FUNCTION testUDFToIntIntMap " + @@ -229,15 +220,15 @@ class HiveUDFSuite extends QueryTest { val errMsg = intercept[AnalysisException] { sql("SELECT testUDFToIntIntMap(s) FROM inputTable") } - assert(errMsg.getMessage === "Map type in java is unsupported because " + + assert(errMsg.getMessage contains "Map type in java is unsupported because " + "JVM type erasure makes spark fail to catch key and value types in Map<>;") sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToIntIntMap") - TestHive.reset() + hiveContext.reset() } test("UDFListListInt") { - val testData = TestHive.sparkContext.parallelize( + val testData = hiveContext.sparkContext.parallelize( ListListIntCaseClass(Nil) :: ListListIntCaseClass(Seq((1, 2, 3))) :: ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: Nil).toDF() @@ -249,11 +240,11 @@ class HiveUDFSuite extends QueryTest { Seq(Row(0), Row(2), Row(13))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListListInt") - TestHive.reset() + hiveContext.reset() } test("UDFListString") { - val testData = TestHive.sparkContext.parallelize( + val testData = hiveContext.sparkContext.parallelize( ListStringCaseClass(Seq("a", "b", "c")) :: ListStringCaseClass(Seq("d", "e")) :: Nil).toDF() testData.registerTempTable("listStringTable") @@ -264,11 +255,11 @@ class HiveUDFSuite extends QueryTest { Seq(Row("a,b,c"), Row("d,e"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListString") - TestHive.reset() + hiveContext.reset() } test("UDFStringString") { - val testData = TestHive.sparkContext.parallelize( + val testData = hiveContext.sparkContext.parallelize( StringCaseClass("world") :: StringCaseClass("goodbye") :: Nil).toDF() testData.registerTempTable("stringTable") @@ -276,13 +267,18 @@ class HiveUDFSuite extends QueryTest { checkAnswer( sql("SELECT testStringStringUDF(\"hello\", s) FROM stringTable"), Seq(Row("hello world"), Row("hello goodbye"))) + + checkAnswer( + sql("SELECT testStringStringUDF(\"\", testStringStringUDF(\"hello\", s)) FROM stringTable"), + Seq(Row(" hello world"), Row(" hello goodbye"))) + sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUDF") - TestHive.reset() + hiveContext.reset() } test("UDFTwoListList") { - val testData = TestHive.sparkContext.parallelize( + val testData = hiveContext.sparkContext.parallelize( ListListIntCaseClass(Nil) :: ListListIntCaseClass(Seq((1, 2, 3))) :: ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: @@ -295,7 +291,151 @@ class HiveUDFSuite extends QueryTest { Seq(Row("0, 0"), Row("2, 2"), Row("13, 13"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList") - TestHive.reset() + hiveContext.reset() + } + + test("Hive UDFs with insufficient number of input arguments should trigger an analysis error") { + Seq((1, 2)).toDF("a", "b").registerTempTable("testUDF") + + { + // HiveSimpleUDF + sql(s"CREATE TEMPORARY FUNCTION testUDFTwoListList AS '${classOf[UDFTwoListList].getName}'") + val message = intercept[AnalysisException] { + sql("SELECT testUDFTwoListList() FROM testUDF") + }.getMessage + assert(message.contains("No handler for Hive udf")) + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList") + } + + { + // HiveGenericUDF + sql(s"CREATE TEMPORARY FUNCTION testUDFAnd AS '${classOf[GenericUDFOPAnd].getName}'") + val message = intercept[AnalysisException] { + sql("SELECT testUDFAnd() FROM testUDF") + }.getMessage + assert(message.contains("No handler for Hive udf")) + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFAnd") + } + + { + // Hive UDAF + 
sql(s"CREATE TEMPORARY FUNCTION testUDAFPercentile AS '${classOf[UDAFPercentile].getName}'") + val message = intercept[AnalysisException] { + sql("SELECT testUDAFPercentile(a) FROM testUDF GROUP BY b") + }.getMessage + assert(message.contains("No handler for Hive udf")) + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDAFPercentile") + } + + { + // AbstractGenericUDAFResolver + sql(s"CREATE TEMPORARY FUNCTION testUDAFAverage AS '${classOf[GenericUDAFAverage].getName}'") + val message = intercept[AnalysisException] { + sql("SELECT testUDAFAverage() FROM testUDF GROUP BY b") + }.getMessage + assert(message.contains("No handler for Hive udf")) + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDAFAverage") + } + + { + // Hive UDTF + sql(s"CREATE TEMPORARY FUNCTION testUDTFExplode AS '${classOf[GenericUDTFExplode].getName}'") + val message = intercept[AnalysisException] { + sql("SELECT testUDTFExplode() FROM testUDF") + }.getMessage + assert(message.contains("No handler for Hive udf")) + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDTFExplode") + } + + sqlContext.dropTempTable("testUDF") + } + + test("SPARK-11522 select input_file_name from non-parquet table"){ + + withTempDir { tempDir => + + // EXTERNAL OpenCSVSerde table pointing to LOCATION + + val file1 = new File(tempDir + "/data1") + val writer1 = new PrintWriter(file1) + writer1.write("1,2") + writer1.close() + + val file2 = new File(tempDir + "/data2") + val writer2 = new PrintWriter(file2) + writer2.write("1,2") + writer2.close() + + sql( + s"""CREATE EXTERNAL TABLE csv_table(page_id INT, impressions INT) + ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde' + WITH SERDEPROPERTIES ( + \"separatorChar\" = \",\", + \"quoteChar\" = \"\\\"\", + \"escapeChar\" = \"\\\\\") + LOCATION '$tempDir' + """) + + val answer1 = + sql("SELECT input_file_name() FROM csv_table").head().getString(0) + assert(answer1.contains("data1") || answer1.contains("data2")) + + val count1 = sql("SELECT input_file_name() FROM csv_table").distinct().count() + assert(count1 == 2) + sql("DROP TABLE csv_table") + + // EXTERNAL pointing to LOCATION + + sql( + s"""CREATE EXTERNAL TABLE external_t5 (c1 int, c2 int) + ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' + LOCATION '$tempDir' + """) + + val answer2 = + sql("SELECT input_file_name() as file FROM external_t5").head().getString(0) + assert(answer1.contains("data1") || answer1.contains("data2")) + + val count2 = sql("SELECT input_file_name() as file FROM external_t5").distinct().count + assert(count2 == 2) + sql("DROP TABLE external_t5") + } + + withTempDir { tempDir => + + // External parquet pointing to LOCATION + + val parquetLocation = tempDir + "/external_parquet" + sql("SELECT 1, 2").write.parquet(parquetLocation) + + sql( + s"""CREATE EXTERNAL TABLE external_parquet(c1 int, c2 int) + STORED AS PARQUET + LOCATION '$parquetLocation' + """) + + val answer3 = + sql("SELECT input_file_name() as file FROM external_parquet").head().getString(0) + assert(answer3.contains("external_parquet")) + + val count3 = sql("SELECT input_file_name() as file FROM external_parquet").distinct().count + assert(count3 == 1) + sql("DROP TABLE external_parquet") + } + + // Non-External parquet pointing to /tmp/... 
+ + sql("CREATE TABLE parquet_tmp(c1 int, c2 int) " + + " STORED AS parquet " + + " AS SELECT 1, 2") + + val answer4 = + sql("SELECT input_file_name() as file FROM parquet_tmp").head().getString(0) + assert(answer4.contains("parquet_tmp")) + + val count4 = sql("SELECT input_file_name() as file FROM parquet_tmp").distinct().count + assert(count4 == 1) + sql("DROP TABLE parquet_tmp") } } @@ -321,11 +461,11 @@ class PairSerDe extends AbstractSerDe { override def getObjectInspector: ObjectInspector = { ObjectInspectorFactory .getStandardStructObjectInspector( - Seq("pair"), - Seq(ObjectInspectorFactory.getStandardStructObjectInspector( - Seq("id", "value"), - Seq(PrimitiveObjectInspectorFactory.javaIntObjectInspector, - PrimitiveObjectInspectorFactory.javaIntObjectInspector)) + Arrays.asList("pair"), + Arrays.asList(ObjectInspectorFactory.getStandardStructObjectInspector( + Arrays.asList("id", "value"), + Arrays.asList(PrimitiveObjectInspectorFactory.javaIntObjectInspector, + PrimitiveObjectInspectorFactory.javaIntObjectInspector)) )) } @@ -338,10 +478,10 @@ class PairSerDe extends AbstractSerDe { override def deserialize(value: Writable): AnyRef = { val pair = value.asInstanceOf[TestPair] - val row = new util.ArrayList[util.ArrayList[AnyRef]] - row.add(new util.ArrayList[AnyRef](2)) - row(0).add(Integer.valueOf(pair.entry._1)) - row(0).add(Integer.valueOf(pair.entry._2)) + val row = new ArrayList[ArrayList[AnyRef]] + row.add(new ArrayList[AnyRef](2)) + row.get(0).add(Integer.valueOf(pair.entry._1)) + row.get(0).add(Integer.valueOf(pair.entry._2)) row } @@ -350,9 +490,9 @@ class PairSerDe extends AbstractSerDe { class PairUDF extends GenericUDF { override def initialize(p1: Array[ObjectInspector]): ObjectInspector = ObjectInspectorFactory.getStandardStructObjectInspector( - Seq("id", "value"), - Seq(PrimitiveObjectInspectorFactory.javaIntObjectInspector, - PrimitiveObjectInspectorFactory.javaIntObjectInspector) + Arrays.asList("id", "value"), + Arrays.asList(PrimitiveObjectInspectorFactory.javaIntObjectInspector, + PrimitiveObjectInspectorFactory.javaIntObjectInspector) ) override def evaluate(args: Array[DeferredObject]): AnyRef = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index e83a7dc77e32..210d56674541 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.sql.hive.execution +import scala.collection.JavaConverters._ + import org.scalatest.BeforeAndAfter import org.apache.spark.sql.hive.test.TestHive -/* Implicit conversions */ -import scala.collection.JavaConversions._ - /** * A set of test cases that validate partition and column pruning. 
*/ @@ -82,16 +81,16 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { Seq.empty) createPruningTest("Column pruning - non-trivial top project with aliases", - "SELECT c1 * 2 AS double FROM (SELECT key AS c1 FROM src WHERE key > 10) t1 LIMIT 3", - Seq("double"), + "SELECT c1 * 2 AS dbl FROM (SELECT key AS c1 FROM src WHERE key > 10) t1 LIMIT 3", + Seq("dbl"), Seq("key"), Seq.empty) // Partition pruning tests createPruningTest("Partition pruning - non-partitioned, non-trivial project", - "SELECT key * 2 AS double FROM src WHERE value IS NOT NULL", - Seq("double"), + "SELECT key * 2 AS dbl FROM src WHERE value IS NOT NULL", + Seq("dbl"), Seq("key", "value"), Seq.empty) @@ -161,7 +160,7 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { assert(actualOutputColumns === expectedOutputColumns, "Output columns mismatch") assert(actualScannedColumns === expectedScannedColumns, "Scanned columns mismatch") - val actualPartitions = actualPartValues.map(_.toSeq.mkString(",")).sorted + val actualPartitions = actualPartValues.map(_.asScala.mkString(",")).sorted val expectedPartitions = expectedPartValues.map(_.mkString(",")).sorted assert(actualPartitions === expectedPartitions, "Partitions selected do not match") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index c4923d83e48f..3427152b2da0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -19,20 +19,19 @@ package org.apache.spark.sql.hive.execution import java.sql.{Date, Timestamp} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.DefaultParserDialect +import org.apache.spark.sql.catalyst.{TableIdentifier, DefaultParserDialect} import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, EliminateSubQueries} import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.hive.{HiveContext, HiveQLDialect, MetastoreRelation} -import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval case class Nested1(f1: Nested2) case class Nested2(f2: Nested3) @@ -64,8 +63,28 @@ class MyDialect extends DefaultParserDialect * Hive to generate them (in contrast to HiveQuerySuite). Often this is because the query is * valid, but Hive currently cannot execute it. 
*/ -class SQLQuerySuite extends QueryTest with SQLTestUtils { - override def sqlContext: SQLContext = TestHive +class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + import hiveContext._ + import hiveContext.implicits._ + + test("UDTF") { + sql(s"ADD JAR ${hiveContext.getHiveFile("TestUDTF.jar").getCanonicalPath()}") + // The function source code can be found at: + // https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF + sql( + """ + |CREATE TEMPORARY FUNCTION udtf_count2 + |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' + """.stripMargin) + + checkAnswer( + sql("SELECT key, cc FROM src LATERAL VIEW udtf_count2(value) dd AS cc"), + Row(97, 500) :: Row(97, 500) :: Nil) + + checkAnswer( + sql("SELECT udtf_count2(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"), + Row(3) :: Row(3) :: Nil) + } test("SPARK-6835: udtf in lateral view") { val df = Seq((1, 1)).toDF("c1", "c2") @@ -141,10 +160,15 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { } test("show functions") { - val allFunctions = + val allBuiltinFunctions = (FunctionRegistry.builtin.listFunction().toSet[String] ++ - org.apache.hadoop.hive.ql.exec.FunctionRegistry.getFunctionNames).toList.sorted - checkAnswer(sql("SHOW functions"), allFunctions.map(Row(_))) + org.apache.hadoop.hive.ql.exec.FunctionRegistry.getFunctionNames.asScala).toList.sorted + // The TestContext is shared by all the test cases, some functions may be registered before + // this, so we check that all the builtin functions are returned. + val allFunctions = sql("SHOW functions").collect().map(r => r(0)) + allBuiltinFunctions.foreach { f => + assert(allFunctions.contains(f)) + } checkAnswer(sql("SHOW functions abs"), Row("abs")) checkAnswer(sql("SHOW functions 'abs'"), Row("abs")) checkAnswer(sql("SHOW functions abc.abs"), Row("abs")) @@ -242,9 +266,9 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { test("CTAS without serde") { def checkRelation(tableName: String, isDataSourceParquet: Boolean): Unit = { - val relation = EliminateSubQueries(catalog.lookupRelation(Seq(tableName))) + val relation = EliminateSubQueries(catalog.lookupRelation(TableIdentifier(tableName))) relation match { - case LogicalRelation(r: ParquetRelation) => + case LogicalRelation(r: ParquetRelation, _) => if (!isDataSourceParquet) { fail( s"${classOf[MetastoreRelation].getCanonicalName} is expected, but found " + @@ -264,47 +288,58 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { setConf(HiveContext.CONVERT_CTAS, true) - sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - sql("CREATE TABLE IF NOT EXISTS ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - var message = intercept[AnalysisException] { + try { sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - }.getMessage - assert(message.contains("ctas1 already exists")) - checkRelation("ctas1", true) - sql("DROP TABLE ctas1") - - // Specifying database name for query can be converted to data source write path - // is not allowed right now. 
- message = intercept[AnalysisException] { - sql("CREATE TABLE default.ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - }.getMessage - assert( - message.contains("Cannot specify database name in a CTAS statement"), - "When spark.sql.hive.convertCTAS is true, we should not allow " + - "database name specified.") - - sql("CREATE TABLE ctas1 stored as textfile AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", true) - sql("DROP TABLE ctas1") - - sql( - "CREATE TABLE ctas1 stored as sequencefile AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", true) - sql("DROP TABLE ctas1") - - sql("CREATE TABLE ctas1 stored as rcfile AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false) - sql("DROP TABLE ctas1") - - sql("CREATE TABLE ctas1 stored as orc AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false) - sql("DROP TABLE ctas1") - - sql("CREATE TABLE ctas1 stored as parquet AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false) - sql("DROP TABLE ctas1") + sql("CREATE TABLE IF NOT EXISTS ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + var message = intercept[AnalysisException] { + sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + }.getMessage + assert(message.contains("ctas1 already exists")) + checkRelation("ctas1", true) + sql("DROP TABLE ctas1") + + // Specifying database name for query can be converted to data source write path + // is not allowed right now. + message = intercept[AnalysisException] { + sql("CREATE TABLE default.ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + }.getMessage + assert( + message.contains("Cannot specify database name in a CTAS statement"), + "When spark.sql.hive.convertCTAS is true, we should not allow " + + "database name specified.") + + sql("CREATE TABLE ctas1 stored as textfile" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", true) + sql("DROP TABLE ctas1") + + sql("CREATE TABLE ctas1 stored as sequencefile" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", true) + sql("DROP TABLE ctas1") + + sql("CREATE TABLE ctas1 stored as rcfile AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", false) + sql("DROP TABLE ctas1") + + sql("CREATE TABLE ctas1 stored as orc AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", false) + sql("DROP TABLE ctas1") + + sql("CREATE TABLE ctas1 stored as parquet AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", false) + sql("DROP TABLE ctas1") + } finally { + setConf(HiveContext.CONVERT_CTAS, originalConf) + sql("DROP TABLE IF EXISTS ctas1") + } + } - setConf(HiveContext.CONVERT_CTAS, originalConf) + test("SQL dialect at the start of HiveContext") { + val hiveContext = new HiveContext(sqlContext.sparkContext) + val dialectConf = "spark.sql.dialect" + checkAnswer(hiveContext.sql(s"set $dialectConf"), Row(dialectConf, "hiveql")) + assert(hiveContext.getSQLDialect().getClass === classOf[HiveQLDialect]) } test("SQL Dialect Switching") { @@ -484,19 +519,19 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { checkAnswer( sql("SELECT f1.f2.f3 FROM nested"), Row(1)) - checkAnswer(sql("CREATE TABLE test_ctas_1234 AS SELECT * from nested"), - Seq.empty[Row]) + + sql("CREATE TABLE test_ctas_1234 AS SELECT * from nested") checkAnswer( sql("SELECT * FROM test_ctas_1234"), sql("SELECT * FROM 
nested").collect().toSeq) intercept[AnalysisException] { - sql("CREATE TABLE test_ctas_12345 AS SELECT * from notexists").collect() + sql("CREATE TABLE test_ctas_1234 AS SELECT * from notexists").collect() } } test("test CTAS") { - checkAnswer(sql("CREATE TABLE test_ctas_123 AS SELECT key, value FROM src"), Seq.empty[Row]) + sql("CREATE TABLE test_ctas_123 AS SELECT key, value FROM src") checkAnswer( sql("SELECT key, value FROM test_ctas_123 ORDER BY key"), sql("SELECT key, value FROM src ORDER BY key").collect().toSeq) @@ -589,7 +624,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { val rowRdd = sparkContext.parallelize(row :: Nil) - TestHive.createDataFrame(rowRdd, schema).registerTempTable("testTable") + hiveContext.createDataFrame(rowRdd, schema).registerTempTable("testTable") sql( """CREATE TABLE nullValuesInInnerComplexTypes @@ -637,7 +672,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { test("resolve udtf in projection #2") { val rdd = sparkContext.makeRDD((1 to 2).map(i => s"""{"a":[$i, ${i + 1}]}""")) - jsonRDD(rdd).registerTempTable("data") + read.json(rdd).registerTempTable("data") checkAnswer(sql("SELECT explode(map(1, 1)) FROM data LIMIT 1"), Row(1, 1) :: Nil) checkAnswer(sql("SELECT explode(map(1, 1)) as (k1, k2) FROM data LIMIT 1"), Row(1, 1) :: Nil) intercept[AnalysisException] { @@ -652,7 +687,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { // TGF with non-TGF in project is allowed in Spark SQL, but not in Hive test("TGF with non-TGF in projection") { val rdd = sparkContext.makeRDD( """{"a": "1", "b":"1"}""" :: Nil) - jsonRDD(rdd).registerTempTable("data") + read.json(rdd).registerTempTable("data") checkAnswer( sql("SELECT explode(map(a, b)) as (k1, k2), a, b FROM data"), Row("1", "1", "1", "1") :: Nil) @@ -670,29 +705,32 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { val originalConf = convertCTAS setConf(HiveContext.CONVERT_CTAS, false) - sql("CREATE TABLE explodeTest (key bigInt)") - table("explodeTest").queryExecution.analyzed match { - case metastoreRelation: MetastoreRelation => // OK - case _ => - fail("To correctly test the fix of SPARK-5875, explodeTest should be a MetastoreRelation") - } + try { + sql("CREATE TABLE explodeTest (key bigInt)") + table("explodeTest").queryExecution.analyzed match { + case metastoreRelation: MetastoreRelation => // OK + case _ => + fail("To correctly test the fix of SPARK-5875, explodeTest should be a MetastoreRelation") + } - sql(s"INSERT OVERWRITE TABLE explodeTest SELECT explode(a) AS val FROM data") - checkAnswer( - sql("SELECT key from explodeTest"), - (1 to 5).flatMap(i => Row(i) :: Row(i + 1) :: Nil) - ) + sql(s"INSERT OVERWRITE TABLE explodeTest SELECT explode(a) AS val FROM data") + checkAnswer( + sql("SELECT key from explodeTest"), + (1 to 5).flatMap(i => Row(i) :: Row(i + 1) :: Nil) + ) - sql("DROP TABLE explodeTest") - dropTempTable("data") - setConf(HiveContext.CONVERT_CTAS, originalConf) + sql("DROP TABLE explodeTest") + dropTempTable("data") + } finally { + setConf(HiveContext.CONVERT_CTAS, originalConf) + } } test("sanity test for SPARK-6618") { (1 to 100).par.map { i => val tableName = s"SPARK_6618_table_$i" sql(s"CREATE TABLE $tableName (col1 string)") - catalog.lookupRelation(Seq(tableName)) + catalog.lookupRelation(TableIdentifier(tableName)) table(tableName) tables() sql(s"DROP TABLE $tableName") @@ -725,6 +763,16 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { .queryExecution.toRdd.count()) } + test("test script transform data type") { + val 
data = (1 to 5).map { i => (i, i) } + data.toDF("key", "value").registerTempTable("test") + checkAnswer( + sql("""FROM + |(FROM test SELECT TRANSFORM(key, value) USING 'cat' AS (thing1 int, thing2 string)) t + |SELECT thing1 + 1 + """.stripMargin), (2 to 6).map(i => Row(i))) + } + test("window function: udaf with aggregate expressin") { val data = Seq( WindowData(1, "a", 5), @@ -797,6 +845,33 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { ).map(i => Row(i._1, i._2, i._3))) } + test("window function: refer column in inner select block") { + val data = Seq( + WindowData(1, "a", 5), + WindowData(2, "a", 6), + WindowData(3, "b", 7), + WindowData(4, "b", 8), + WindowData(5, "c", 9), + WindowData(6, "c", 10) + ) + sparkContext.parallelize(data).toDF().registerTempTable("windowData") + + checkAnswer( + sql( + """ + |select area, rank() over (partition by area order by tmp.month) + tmp.tmp1 as c1 + |from (select month, area, product, 1 as tmp1 from windowData) tmp + """.stripMargin), + Seq( + ("a", 2), + ("a", 3), + ("b", 2), + ("b", 3), + ("c", 2), + ("c", 3) + ).map(i => Row(i._1, i._2))) + } + test("window function: partition and order expressions") { val data = Seq( WindowData(1, "a", 5), @@ -912,6 +987,8 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { } test("SPARK-7595: Window will cause resolve failed with self join") { + sql("SELECT * FROM src") // Force loading of src table. + checkAnswer(sql( """ |with @@ -1004,10 +1081,10 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { val thread = new Thread { override def run() { // To make sure this test works, this jar should not be loaded in another place. - TestHive.sql( - s"ADD JAR ${TestHive.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath()}") + sql( + s"ADD JAR ${hiveContext.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath()}") try { - TestHive.sql( + sql( """ |CREATE TEMPORARY FUNCTION example_max |AS 'org.apache.hadoop.hive.contrib.udaf.example.UDAFExampleMax' @@ -1057,24 +1134,349 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { test("SPARK-8588 HiveTypeCoercion.inConversion fires too early") { val df = - TestHive.createDataFrame(Seq((1, "2014-01-01"), (2, "2015-01-01"), (3, "2016-01-01"))) - df.toDF("id", "date").registerTempTable("test_SPARK8588") + createDataFrame(Seq((1, "2014-01-01"), (2, "2015-01-01"), (3, "2016-01-01"))) + df.toDF("id", "datef").registerTempTable("test_SPARK8588") checkAnswer( - TestHive.sql( + sql( """ - |select id, concat(year(date)) - |from test_SPARK8588 where concat(year(date), ' year') in ('2015 year', '2014 year') + |select id, concat(year(datef)) + |from test_SPARK8588 where concat(year(datef), ' year') in ('2015 year', '2014 year') """.stripMargin), Row(1, "2014") :: Row(2, "2015") :: Nil ) - TestHive.dropTempTable("test_SPARK8588") + dropTempTable("test_SPARK8588") } test("SPARK-9371: fix the support for special chars in column names for hive context") { - TestHive.read.json(TestHive.sparkContext.makeRDD( + read.json(sparkContext.makeRDD( """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil)) .registerTempTable("t") checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) } + + test("Convert hive interval term into Literal of CalendarIntervalType") { + checkAnswer(sql("select interval '10-9' year to month"), + Row(CalendarInterval.fromString("interval 10 years 9 months"))) + checkAnswer(sql("select interval '20 15:40:32.99899999' day to second"), + 
Row(CalendarInterval.fromString("interval 2 weeks 6 days 15 hours 40 minutes " + + "32 seconds 99 milliseconds 899 microseconds"))) + checkAnswer(sql("select interval '30' year"), + Row(CalendarInterval.fromString("interval 30 years"))) + checkAnswer(sql("select interval '25' month"), + Row(CalendarInterval.fromString("interval 25 months"))) + checkAnswer(sql("select interval '-100' day"), + Row(CalendarInterval.fromString("interval -14 weeks -2 days"))) + checkAnswer(sql("select interval '40' hour"), + Row(CalendarInterval.fromString("interval 1 days 16 hours"))) + checkAnswer(sql("select interval '80' minute"), + Row(CalendarInterval.fromString("interval 1 hour 20 minutes"))) + checkAnswer(sql("select interval '299.889987299' second"), + Row(CalendarInterval.fromString( + "interval 4 minutes 59 seconds 889 milliseconds 987 microseconds"))) + } + + test("specifying database name for a temporary table is not allowed") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str") + df + .write + .format("parquet") + .save(path) + + val message = intercept[AnalysisException] { + sqlContext.sql( + s""" + |CREATE TEMPORARY TABLE db.t + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + }.getMessage + assert(message.contains("Specifying database name or other qualifiers are not allowed")) + + // If you use backticks to quote the name of a temporary table having dot in it. + sqlContext.sql( + s""" + |CREATE TEMPORARY TABLE `db.t` + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + checkAnswer(sqlContext.table("`db.t`"), df) + } + } + + test("SPARK-10593 same column names in lateral view") { + val df = sqlContext.sql( + """ + |select + |insideLayer2.json as a2 + |from (select '{"layer1": {"layer2": "text inside layer 2"}}' json) test + |lateral view json_tuple(json, 'layer1') insideLayer1 as json + |lateral view json_tuple(insideLayer1.json, 'layer2') insideLayer2 as json + """.stripMargin + ) + + checkAnswer(df, Row("text inside layer 2") :: Nil) + } + + test("SPARK-10310: " + + "script transformation using default input/output SerDe and record reader/writer") { + sqlContext + .range(5) + .selectExpr("id AS a", "id AS b") + .registerTempTable("test") + + checkAnswer( + sql( + """FROM( + | FROM test SELECT TRANSFORM(a, b) + | USING 'python src/test/resources/data/scripts/test_transform.py "\t"' + | AS (c STRING, d STRING) + |) t + |SELECT c + """.stripMargin), + (0 until 5).map(i => Row(i + "#"))) + } + + test("SPARK-10310: script transformation using LazySimpleSerDe") { + sqlContext + .range(5) + .selectExpr("id AS a", "id AS b") + .registerTempTable("test") + + val df = sql( + """FROM test + |SELECT TRANSFORM(a, b) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' + |WITH SERDEPROPERTIES('field.delim' = '|') + |USING 'python src/test/resources/data/scripts/test_transform.py "|"' + |AS (c STRING, d STRING) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' + |WITH SERDEPROPERTIES('field.delim' = '|') + """.stripMargin) + + checkAnswer(df, (0 until 5).map(i => Row(i + "#", i + "#"))) + } + + test("SPARK-10741: Sort on Aggregate using parquet") { + withTable("test10741") { + withTempTable("src") { + Seq("a" -> 5, "a" -> 9, "b" -> 6).toDF().registerTempTable("src") + sql("CREATE TABLE test10741(c1 STRING, c2 INT) STORED AS PARQUET AS SELECT * FROM src") + } + + checkAnswer(sql( + """ + |SELECT c1, AVG(c2) AS c_avg + |FROM 
test10741 + |GROUP BY c1 + |HAVING (AVG(c2) > 5) ORDER BY c1 + """.stripMargin), Row("a", 7.0) :: Row("b", 6.0) :: Nil) + + checkAnswer(sql( + """ + |SELECT c1, AVG(c2) AS c_avg + |FROM test10741 + |GROUP BY c1 + |ORDER BY AVG(c2) + """.stripMargin), Row("b", 6.0) :: Row("a", 7.0) :: Nil) + } + } + + test("run sql directly on files") { + val df = sqlContext.range(100) + withTempPath(f => { + df.write.parquet(f.getCanonicalPath) + checkAnswer(sql(s"select id from parquet.`${f.getCanonicalPath}`"), + df) + checkAnswer(sql(s"select id from `org.apache.spark.sql.parquet`.`${f.getCanonicalPath}`"), + df) + checkAnswer(sql(s"select a.id from parquet.`${f.getCanonicalPath}` as a"), + df) + }) + } + + test("correctly parse CREATE VIEW statement") { + withSQLConf(SQLConf.NATIVE_VIEW.key -> "true") { + withTable("jt") { + val df = (1 until 10).map(i => i -> i).toDF("i", "j") + df.write.format("json").saveAsTable("jt") + sql( + """CREATE VIEW IF NOT EXISTS + |default.testView (c1 COMMENT 'blabla', c2 COMMENT 'blabla') + |COMMENT 'blabla' + |TBLPROPERTIES ('a' = 'b') + |AS SELECT * FROM jt""".stripMargin) + checkAnswer(sql("SELECT c1, c2 FROM testView ORDER BY c1"), (1 to 9).map(i => Row(i, i))) + sql("DROP VIEW testView") + } + } + } + + test("correctly handle CREATE VIEW IF NOT EXISTS") { + withSQLConf(SQLConf.NATIVE_VIEW.key -> "true") { + withTable("jt", "jt2") { + sqlContext.range(1, 10).write.format("json").saveAsTable("jt") + sql("CREATE VIEW testView AS SELECT id FROM jt") + + val df = (1 until 10).map(i => i -> i).toDF("i", "j") + df.write.format("json").saveAsTable("jt2") + sql("CREATE VIEW IF NOT EXISTS testView AS SELECT * FROM jt2") + + // make sure our view doesn't change. + checkAnswer(sql("SELECT * FROM testView ORDER BY id"), (1 to 9).map(i => Row(i))) + sql("DROP VIEW testView") + } + } + } + + test("correctly handle CREATE OR REPLACE VIEW") { + withSQLConf(SQLConf.NATIVE_VIEW.key -> "true") { + withTable("jt", "jt2") { + sqlContext.range(1, 10).write.format("json").saveAsTable("jt") + sql("CREATE OR REPLACE VIEW testView AS SELECT id FROM jt") + checkAnswer(sql("SELECT * FROM testView ORDER BY id"), (1 to 9).map(i => Row(i))) + + val df = (1 until 10).map(i => i -> i).toDF("i", "j") + df.write.format("json").saveAsTable("jt2") + sql("CREATE OR REPLACE VIEW testView AS SELECT * FROM jt2") + // make sure the view has been changed. + checkAnswer(sql("SELECT * FROM testView ORDER BY i"), (1 to 9).map(i => Row(i, i))) + + sql("DROP VIEW testView") + + val e = intercept[AnalysisException] { + sql("CREATE OR REPLACE VIEW IF NOT EXISTS testView AS SELECT id FROM jt") + } + assert(e.message.contains("not allowed to define a view")) + } + } + } + + test("correctly handle ALTER VIEW") { + withSQLConf(SQLConf.NATIVE_VIEW.key -> "true") { + withTable("jt", "jt2") { + sqlContext.range(1, 10).write.format("json").saveAsTable("jt") + sql("CREATE VIEW testView AS SELECT id FROM jt") + + val df = (1 until 10).map(i => i -> i).toDF("i", "j") + df.write.format("json").saveAsTable("jt2") + sql("ALTER VIEW testView AS SELECT * FROM jt2") + // make sure the view has been changed. + checkAnswer(sql("SELECT * FROM testView ORDER BY i"), (1 to 9).map(i => Row(i, i))) + + sql("DROP VIEW testView") + } + } + } + + test("create hive view for json table") { + // json table is not hive-compatible, make sure the new flag fix it. 
+ withSQLConf(SQLConf.NATIVE_VIEW.key -> "true") { + withTable("jt") { + sqlContext.range(1, 10).write.format("json").saveAsTable("jt") + sql("CREATE VIEW testView AS SELECT id FROM jt") + checkAnswer(sql("SELECT * FROM testView ORDER BY id"), (1 to 9).map(i => Row(i))) + sql("DROP VIEW testView") + } + } + } + + test("create hive view for partitioned parquet table") { + // partitioned parquet table is not hive-compatible, make sure the new flag fix it. + withSQLConf(SQLConf.NATIVE_VIEW.key -> "true") { + withTable("parTable") { + val df = Seq(1 -> "a").toDF("i", "j") + df.write.format("parquet").partitionBy("i").saveAsTable("parTable") + sql("CREATE VIEW testView AS SELECT i, j FROM parTable") + checkAnswer(sql("SELECT * FROM testView"), Row(1, "a")) + sql("DROP VIEW testView") + } + } + } + + test("create hive view for joined tables") { + // make sure the new flag can handle some complex cases like join and schema change. + withSQLConf(SQLConf.NATIVE_VIEW.key -> "true") { + withTable("jt1", "jt2") { + sqlContext.range(1, 10).toDF("id1").write.format("json").saveAsTable("jt1") + sqlContext.range(1, 10).toDF("id2").write.format("json").saveAsTable("jt2") + sql("CREATE VIEW testView AS SELECT * FROM jt1 JOIN jt2 ON id1 == id2") + checkAnswer(sql("SELECT * FROM testView ORDER BY id1"), (1 to 9).map(i => Row(i, i))) + + val df = (1 until 10).map(i => i -> i).toDF("id1", "newCol") + df.write.format("json").mode(SaveMode.Overwrite).saveAsTable("jt1") + checkAnswer(sql("SELECT * FROM testView ORDER BY id1"), (1 to 9).map(i => Row(i, i))) + + sql("DROP VIEW testView") + } + } + } + + test("SPARK-10562: partition by column with mixed case name") { + withTable("tbl10562") { + val df = Seq(2012 -> "a").toDF("Year", "val") + df.write.partitionBy("Year").saveAsTable("tbl10562") + checkAnswer(sql("SELECT Year FROM tbl10562"), Row(2012)) + checkAnswer(sql("SELECT yEAr FROM tbl10562"), Row(2012)) + checkAnswer(sql("SELECT val FROM tbl10562 WHERE Year > 2015"), Nil) + checkAnswer(sql("SELECT val FROM tbl10562 WHERE Year == 2012"), Row("a")) + } + } + + test("SPARK-11453: append data to partitioned table") { + withTable("tbl11453") { + Seq("1" -> "10", "2" -> "20").toDF("i", "j") + .write.partitionBy("i").saveAsTable("tbl11453") + + Seq("3" -> "30").toDF("i", "j") + .write.mode(SaveMode.Append).partitionBy("i").saveAsTable("tbl11453") + checkAnswer( + sqlContext.read.table("tbl11453").select("i", "j").orderBy("i"), + Row("1", "10") :: Row("2", "20") :: Row("3", "30") :: Nil) + + // make sure case sensitivity is correct. + Seq("4" -> "40").toDF("i", "j") + .write.mode(SaveMode.Append).partitionBy("I").saveAsTable("tbl11453") + checkAnswer( + sqlContext.read.table("tbl11453").select("i", "j").orderBy("i"), + Row("1", "10") :: Row("2", "20") :: Row("3", "30") :: Row("4", "40") :: Nil) + } + } + + test("SPARK-11590: use native json_tuple in lateral view") { + checkAnswer(sql( + """ + |SELECT a, b + |FROM (SELECT '{"f1": "value1", "f2": 12}' json) test + |LATERAL VIEW json_tuple(json, 'f1', 'f2') jt AS a, b + """.stripMargin), Row("value1", "12")) + + // we should use `c0`, `c1`... as the name of fields if no alias is provided, to follow hive. + checkAnswer(sql( + """ + |SELECT c0, c1 + |FROM (SELECT '{"f1": "value1", "f2": 12}' json) test + |LATERAL VIEW json_tuple(json, 'f1', 'f2') jt + """.stripMargin), Row("value1", "12")) + + // we can also use `json_tuple` in project list. 
+ checkAnswer(sql( + """ + |SELECT json_tuple(json, 'f1', 'f2') + |FROM (SELECT '{"f1": "value1", "f2": 12}' json) test + """.stripMargin), Row("value1", "12")) + + // we can also mix `json_tuple` with other project expressions. + checkAnswer(sql( + """ + |SELECT json_tuple(json, 'f1', 'f2'), 3.14, str + |FROM (SELECT '{"f1": "value1", "f2": 12}' json, 'hello' as str) test + """.stripMargin), Row("value1", "12", 3.14, "hello")) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala index 0875232aede3..7cfdb886b585 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala @@ -22,16 +22,14 @@ import org.scalatest.exceptions.TestFailedException import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.execution.{UnaryNode, SparkPlan, SparkPlanTest} -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.types.StringType -class ScriptTransformationSuite extends SparkPlanTest { - - override def sqlContext: SQLContext = TestHive +class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton { + import hiveContext.implicits._ private val noSerdeIOSchema = HiveScriptIOSchema( inputRowFormat = Seq.empty, @@ -40,6 +38,8 @@ class ScriptTransformationSuite extends SparkPlanTest { outputSerdeClass = None, inputSerdeProps = Seq.empty, outputSerdeProps = Seq.empty, + recordReaderClass = None, + recordWriterClass = None, schemaLess = false ) @@ -58,7 +58,7 @@ class ScriptTransformationSuite extends SparkPlanTest { output = Seq(AttributeReference("a", StringType)()), child = child, ioschema = noSerdeIOSchema - )(TestHive), + )(hiveContext), rowsDf.collect()) } @@ -72,7 +72,7 @@ class ScriptTransformationSuite extends SparkPlanTest { output = Seq(AttributeReference("a", StringType)()), child = child, ioschema = serdeIOSchema - )(TestHive), + )(hiveContext), rowsDf.collect()) } @@ -87,7 +87,7 @@ class ScriptTransformationSuite extends SparkPlanTest { output = Seq(AttributeReference("a", StringType)()), child = ExceptionInjectingOperator(child), ioschema = noSerdeIOSchema - )(TestHive), + )(hiveContext), rowsDf.collect()) } assert(e.getMessage().contains("intentional exception")) @@ -104,7 +104,7 @@ class ScriptTransformationSuite extends SparkPlanTest { output = Seq(AttributeReference("a", StringType)()), child = ExceptionInjectingOperator(child), ioschema = serdeIOSchema - )(TestHive), + )(hiveContext), rowsDf.collect()) } assert(e.getMessage().contains("intentional exception")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala new file mode 100644 index 000000000000..c05dbfd7608d --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.sql._ +import org.apache.spark.sql.hive.test.{TestHive, TestHiveSingleton} +import org.apache.spark.sql.test.SQLTestUtils + +/** + * This suite contains a couple of Hive window tests which fail in the typical setup due to tiny + * numerical differences or due to semantic differences between Hive and Spark. + */ +class WindowQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + + override def beforeAll(): Unit = { + sql("DROP TABLE IF EXISTS part") + sql( + """ + |CREATE TABLE part( + | p_partkey INT, + | p_name STRING, + | p_mfgr STRING, + | p_brand STRING, + | p_type STRING, + | p_size INT, + | p_container STRING, + | p_retailprice DOUBLE, + | p_comment STRING) + """.stripMargin) + val testData1 = TestHive.getHiveFile("data/files/part_tiny.txt").getCanonicalPath + sql( + s""" + |LOAD DATA LOCAL INPATH '$testData1' overwrite into table part + """.stripMargin) + } + + override def afterAll(): Unit = { + sql("DROP TABLE IF EXISTS part") + } + + test("windowing.q -- 15. testExpressions") { + // Moved because: + // - Spark uses a different default stddev (sample instead of pop) + // - Tiny numerical differences in stddev results.
+ // - Different StdDev behavior when n=1 (NaN instead of 0) + checkAnswer(sql(s""" + |select p_mfgr,p_name, p_size, + |rank() over(distribute by p_mfgr sort by p_name) as r, + |dense_rank() over(distribute by p_mfgr sort by p_name) as dr, + |cume_dist() over(distribute by p_mfgr sort by p_name) as cud, + |percent_rank() over(distribute by p_mfgr sort by p_name) as pr, + |ntile(3) over(distribute by p_mfgr sort by p_name) as nt, + |count(p_size) over(distribute by p_mfgr sort by p_name) as ca, + |avg(p_size) over(distribute by p_mfgr sort by p_name) as avg, + |stddev(p_size) over(distribute by p_mfgr sort by p_name) as st, + |first_value(p_size % 5) over(distribute by p_mfgr sort by p_name) as fv, + |last_value(p_size) over(distribute by p_mfgr sort by p_name) as lv, + |first_value(p_size) over w1 as fvW1 + |from part + |window w1 as (distribute by p_mfgr sort by p_mfgr, p_name + | rows between 2 preceding and 2 following) + """.stripMargin), + // scalastyle:off + Seq( + Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 1, 1, 0.3333333333333333, 0.0, 1, 2, 2.0, 0.0, 2, 2, 2), + Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 1, 1, 0.3333333333333333, 0.0, 1, 2, 2.0, 0.0, 2, 2, 2), + Row("Manufacturer#1", "almond antique chartreuse lavender yellow", 34, 3, 2, 0.5, 0.4, 2, 3, 12.666666666666666, 18.475208614068027, 2, 34, 2), + Row("Manufacturer#1", "almond antique salmon chartreuse burlywood", 6, 4, 3, 0.6666666666666666, 0.6, 2, 4, 11.0, 15.448840301675292, 2, 6, 2), + Row("Manufacturer#1", "almond aquamarine burnished black steel", 28, 5, 4, 0.8333333333333334, 0.8, 3, 5, 14.4, 15.388307249337076, 2, 28, 34), + Row("Manufacturer#1", "almond aquamarine pink moccasin thistle", 42, 6, 5, 1.0, 1.0, 3, 6, 19.0, 17.787636155487327, 2, 42, 6), + Row("Manufacturer#2", "almond antique violet chocolate turquoise", 14, 1, 1, 0.2, 0.0, 1, 1, 14.0, Double.NaN, 4, 14, 14), + Row("Manufacturer#2", "almond antique violet turquoise frosted", 40, 2, 2, 0.4, 0.25, 1, 2, 27.0, 18.384776310850235, 4, 40, 14), + Row("Manufacturer#2", "almond aquamarine midnight light salmon", 2, 3, 3, 0.6, 0.5, 2, 3, 18.666666666666668, 19.42506971244462, 4, 2, 14), + Row("Manufacturer#2", "almond aquamarine rose maroon antique", 25, 4, 4, 0.8, 0.75, 2, 4, 20.25, 16.17353805861084, 4, 25, 40), + Row("Manufacturer#2", "almond aquamarine sandy cyan gainsboro", 18, 5, 5, 1.0, 1.0, 3, 5, 19.8, 14.042791745233567, 4, 18, 2), + Row("Manufacturer#3", "almond antique chartreuse khaki white", 17, 1, 1, 0.2, 0.0, 1, 1, 17.0,Double.NaN, 2, 17, 17), + Row("Manufacturer#3", "almond antique forest lavender goldenrod", 14, 2, 2, 0.4, 0.25, 1, 2, 15.5, 2.1213203435596424, 2, 14, 17), + Row("Manufacturer#3", "almond antique metallic orange dim", 19, 3, 3, 0.6, 0.5, 2, 3, 16.666666666666668, 2.516611478423583, 2, 19, 17), + Row("Manufacturer#3", "almond antique misty red olive", 1, 4, 4, 0.8, 0.75, 2, 4, 12.75, 8.098353742170895, 2, 1, 14), + Row("Manufacturer#3", "almond antique olive coral navajo", 45, 5, 5, 1.0, 1.0, 3, 5, 19.2, 16.037456157383566, 2, 45, 19), + Row("Manufacturer#4", "almond antique gainsboro frosted violet", 10, 1, 1, 0.2, 0.0, 1, 1, 10.0, Double.NaN, 0, 10, 10), + Row("Manufacturer#4", "almond antique violet mint lemon", 39, 2, 2, 0.4, 0.25, 1, 2, 24.5, 20.506096654409877, 0, 39, 10), + Row("Manufacturer#4", "almond aquamarine floral ivory bisque", 27, 3, 3, 0.6, 0.5, 2, 3, 25.333333333333332, 14.571661996262929, 0, 27, 10), + Row("Manufacturer#4", "almond aquamarine yellow dodger mint", 
7, 4, 4, 0.8, 0.75, 2, 4, 20.75, 15.01943185787443, 0, 7, 39), + Row("Manufacturer#4", "almond azure aquamarine papaya violet", 12, 5, 5, 1.0, 1.0, 3, 5, 19.0, 13.583077707206124, 0, 12, 27), + Row("Manufacturer#5", "almond antique blue firebrick mint", 31, 1, 1, 0.2, 0.0, 1, 1, 31.0, Double.NaN, 1, 31, 31), + Row("Manufacturer#5", "almond antique medium spring khaki", 6, 2, 2, 0.4, 0.25, 1, 2, 18.5, 17.67766952966369, 1, 6, 31), + Row("Manufacturer#5", "almond antique sky peru orange", 2, 3, 3, 0.6, 0.5, 2, 3, 13.0, 15.716233645501712, 1, 2, 31), + Row("Manufacturer#5", "almond aquamarine dodger light gainsboro", 46, 4, 4, 0.8, 0.75, 2, 4, 21.25, 20.902551678363736, 1, 46, 6), + Row("Manufacturer#5", "almond azure blanched chiffon midnight", 23, 5, 5, 1.0, 1.0, 3, 5, 21.6, 18.1190507477627, 1, 23, 2))) + // scalastyle:on + } + + test("windowing.q -- 20. testSTATs") { + // Moved because: + // - Spark uses a different default stddev/variance (sample instead of pop) + // - Tiny numerical differences in aggregation results. + checkAnswer(sql(""" + |select p_mfgr,p_name, p_size, sdev, sdev_pop, uniq_data, var, cor, covarp + |from ( + |select p_mfgr,p_name, p_size, + |stddev_pop(p_retailprice) over w1 as sdev, + |stddev_pop(p_retailprice) over w1 as sdev_pop, + |collect_set(p_size) over w1 as uniq_size, + |var_pop(p_retailprice) over w1 as var, + |corr(p_size, p_retailprice) over w1 as cor, + |covar_pop(p_size, p_retailprice) over w1 as covarp + |from part + |window w1 as (distribute by p_mfgr sort by p_mfgr, p_name + | rows between 2 preceding and 2 following) + |) t lateral view explode(uniq_size) d as uniq_data + |order by p_mfgr,p_name, p_size, sdev, sdev_pop, uniq_data, var, cor, covarp + """.stripMargin), + // scalastyle:off + Seq( + Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 258.10677784349247, 258.10677784349247, 2, 66619.10876874997, 0.811328754177887, 2801.7074999999995), + Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 258.10677784349247, 258.10677784349247, 6, 66619.10876874997, 0.811328754177887, 2801.7074999999995), + Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 258.10677784349247, 258.10677784349247, 34, 66619.10876874997, 0.811328754177887, 2801.7074999999995), + Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 273.70217881648085, 273.70217881648085, 2, 74912.88268888886, 1.0, 4128.782222222221), + Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 273.70217881648085, 273.70217881648085, 34, 74912.88268888886, 1.0, 4128.782222222221), + Row("Manufacturer#1", "almond antique chartreuse lavender yellow", 34, 230.9015158547037, 230.9015158547037, 2, 53315.510023999974, 0.6956393773976641, 2210.7864), + Row("Manufacturer#1", "almond antique chartreuse lavender yellow", 34, 230.9015158547037, 230.9015158547037, 6, 53315.510023999974, 0.6956393773976641, 2210.7864), + Row("Manufacturer#1", "almond antique chartreuse lavender yellow", 34, 230.9015158547037, 230.9015158547037, 28, 53315.510023999974, 0.6956393773976641, 2210.7864), + Row("Manufacturer#1", "almond antique chartreuse lavender yellow", 34, 230.9015158547037, 230.9015158547037, 34, 53315.510023999974, 0.6956393773976641, 2210.7864), + Row("Manufacturer#1", "almond antique salmon chartreuse burlywood", 6, 202.73109328368943, 202.73109328368943, 2, 41099.89618399999, 0.6307859771012139, 2009.9536000000007), + Row("Manufacturer#1", "almond antique salmon chartreuse burlywood", 6, 202.73109328368943, 202.73109328368943, 6, 
41099.89618399999, 0.6307859771012139, 2009.9536000000007), + Row("Manufacturer#1", "almond antique salmon chartreuse burlywood", 6, 202.73109328368943, 202.73109328368943, 28, 41099.89618399999, 0.6307859771012139, 2009.9536000000007), + Row("Manufacturer#1", "almond antique salmon chartreuse burlywood", 6, 202.73109328368943, 202.73109328368943, 34, 41099.89618399999, 0.6307859771012139, 2009.9536000000007), + Row("Manufacturer#1", "almond antique salmon chartreuse burlywood", 6, 202.73109328368943, 202.73109328368943, 42, 41099.89618399999, 0.6307859771012139, 2009.9536000000007), + Row("Manufacturer#1", "almond aquamarine burnished black steel", 28, 121.60645179738611, 121.60645179738611, 6, 14788.129118749992, 0.2036684720435979, 331.1337500000004), + Row("Manufacturer#1", "almond aquamarine burnished black steel", 28, 121.60645179738611, 121.60645179738611, 28, 14788.129118749992, 0.2036684720435979, 331.1337500000004), + Row("Manufacturer#1", "almond aquamarine burnished black steel", 28, 121.60645179738611, 121.60645179738611, 34, 14788.129118749992, 0.2036684720435979, 331.1337500000004), + Row("Manufacturer#1", "almond aquamarine burnished black steel", 28, 121.60645179738611, 121.60645179738611, 42, 14788.129118749992, 0.2036684720435979, 331.1337500000004), + Row("Manufacturer#1", "almond aquamarine pink moccasin thistle", 42, 96.57515864168516, 96.57515864168516, 6, 9326.761266666656, -1.4442181184933883E-4, -0.20666666666708502), + Row("Manufacturer#1", "almond aquamarine pink moccasin thistle", 42, 96.57515864168516, 96.57515864168516, 28, 9326.761266666656, -1.4442181184933883E-4, -0.20666666666708502), + Row("Manufacturer#1", "almond aquamarine pink moccasin thistle", 42, 96.57515864168516, 96.57515864168516, 42, 9326.761266666656, -1.4442181184933883E-4, -0.20666666666708502), + Row("Manufacturer#2", "almond antique violet chocolate turquoise", 14, 142.23631697518977, 142.23631697518977, 2, 20231.16986666666, -0.4936952655452319, -1113.7466666666658), + Row("Manufacturer#2", "almond antique violet chocolate turquoise", 14, 142.23631697518977, 142.23631697518977, 14, 20231.16986666666, -0.4936952655452319, -1113.7466666666658), + Row("Manufacturer#2", "almond antique violet chocolate turquoise", 14, 142.23631697518977, 142.23631697518977, 40, 20231.16986666666, -0.4936952655452319, -1113.7466666666658), + Row("Manufacturer#2", "almond antique violet turquoise frosted", 40, 137.7630649884068, 137.7630649884068, 2, 18978.662074999997, -0.5205630897335946, -1004.4812499999995), + Row("Manufacturer#2", "almond antique violet turquoise frosted", 40, 137.7630649884068, 137.7630649884068, 14, 18978.662074999997, -0.5205630897335946, -1004.4812499999995), + Row("Manufacturer#2", "almond antique violet turquoise frosted", 40, 137.7630649884068, 137.7630649884068, 25, 18978.662074999997, -0.5205630897335946, -1004.4812499999995), + Row("Manufacturer#2", "almond antique violet turquoise frosted", 40, 137.7630649884068, 137.7630649884068, 40, 18978.662074999997, -0.5205630897335946, -1004.4812499999995), + Row("Manufacturer#2", "almond aquamarine midnight light salmon", 2, 130.03972279269132, 130.03972279269132, 2, 16910.329504000005, -0.46908967495720255, -766.1791999999995), + Row("Manufacturer#2", "almond aquamarine midnight light salmon", 2, 130.03972279269132, 130.03972279269132, 14, 16910.329504000005, -0.46908967495720255, -766.1791999999995), + Row("Manufacturer#2", "almond aquamarine midnight light salmon", 2, 130.03972279269132, 130.03972279269132, 18, 16910.329504000005, 
-0.46908967495720255, -766.1791999999995), + Row("Manufacturer#2", "almond aquamarine midnight light salmon", 2, 130.03972279269132, 130.03972279269132, 25, 16910.329504000005, -0.46908967495720255, -766.1791999999995), + Row("Manufacturer#2", "almond aquamarine midnight light salmon", 2, 130.03972279269132, 130.03972279269132, 40, 16910.329504000005, -0.46908967495720255, -766.1791999999995), + Row("Manufacturer#2", "almond aquamarine rose maroon antique", 25, 135.55100986344593, 135.55100986344593, 2, 18374.076275000018, -0.6091405874714462, -1128.1787499999987), + Row("Manufacturer#2", "almond aquamarine rose maroon antique", 25, 135.55100986344593, 135.55100986344593, 18, 18374.076275000018, -0.6091405874714462, -1128.1787499999987), + Row("Manufacturer#2", "almond aquamarine rose maroon antique", 25, 135.55100986344593, 135.55100986344593, 25, 18374.076275000018, -0.6091405874714462, -1128.1787499999987), + Row("Manufacturer#2", "almond aquamarine rose maroon antique", 25, 135.55100986344593, 135.55100986344593, 40, 18374.076275000018, -0.6091405874714462, -1128.1787499999987), + Row("Manufacturer#2", "almond aquamarine sandy cyan gainsboro", 18, 156.44019460768035, 156.44019460768035, 2, 24473.534488888898, -0.9571686373491605, -1441.4466666666676), + Row("Manufacturer#2", "almond aquamarine sandy cyan gainsboro", 18, 156.44019460768035, 156.44019460768035, 18, 24473.534488888898, -0.9571686373491605, -1441.4466666666676), + Row("Manufacturer#2", "almond aquamarine sandy cyan gainsboro", 18, 156.44019460768035, 156.44019460768035, 25, 24473.534488888898, -0.9571686373491605, -1441.4466666666676), + Row("Manufacturer#3", "almond antique chartreuse khaki white", 17, 196.77422668858057, 196.77422668858057, 14, 38720.0962888889, 0.5557168646224995, 224.6944444444446), + Row("Manufacturer#3", "almond antique chartreuse khaki white", 17, 196.77422668858057, 196.77422668858057, 17, 38720.0962888889, 0.5557168646224995, 224.6944444444446), + Row("Manufacturer#3", "almond antique chartreuse khaki white", 17, 196.77422668858057, 196.77422668858057, 19, 38720.0962888889, 0.5557168646224995, 224.6944444444446), + Row("Manufacturer#3", "almond antique forest lavender goldenrod", 14, 275.1414418985261, 275.1414418985261, 1, 75702.81305000003, -0.6720833036576083, -1296.9000000000003), + Row("Manufacturer#3", "almond antique forest lavender goldenrod", 14, 275.1414418985261, 275.1414418985261, 14, 75702.81305000003, -0.6720833036576083, -1296.9000000000003), + Row("Manufacturer#3", "almond antique forest lavender goldenrod", 14, 275.1414418985261, 275.1414418985261, 17, 75702.81305000003, -0.6720833036576083, -1296.9000000000003), + Row("Manufacturer#3", "almond antique forest lavender goldenrod", 14, 275.1414418985261, 275.1414418985261, 19, 75702.81305000003, -0.6720833036576083, -1296.9000000000003), + Row("Manufacturer#3", "almond antique metallic orange dim", 19, 260.23473614412046, 260.23473614412046, 1, 67722.11789600001, -0.5703526513979519, -2129.0664), + Row("Manufacturer#3", "almond antique metallic orange dim", 19, 260.23473614412046, 260.23473614412046, 14, 67722.11789600001, -0.5703526513979519, -2129.0664), + Row("Manufacturer#3", "almond antique metallic orange dim", 19, 260.23473614412046, 260.23473614412046, 17, 67722.11789600001, -0.5703526513979519, -2129.0664), + Row("Manufacturer#3", "almond antique metallic orange dim", 19, 260.23473614412046, 260.23473614412046, 19, 67722.11789600001, -0.5703526513979519, -2129.0664), + Row("Manufacturer#3", "almond antique metallic orange 
dim", 19, 260.23473614412046, 260.23473614412046, 45, 67722.11789600001, -0.5703526513979519, -2129.0664), + Row("Manufacturer#3", "almond antique misty red olive", 1, 275.913996235693, 275.913996235693, 1, 76128.53331875002, -0.5774768996448021, -2547.7868749999993), + Row("Manufacturer#3", "almond antique misty red olive", 1, 275.913996235693, 275.913996235693, 14, 76128.53331875002, -0.5774768996448021, -2547.7868749999993), + Row("Manufacturer#3", "almond antique misty red olive", 1, 275.913996235693, 275.913996235693, 19, 76128.53331875002, -0.5774768996448021, -2547.7868749999993), + Row("Manufacturer#3", "almond antique misty red olive", 1, 275.913996235693, 275.913996235693, 45, 76128.53331875002, -0.5774768996448021, -2547.7868749999993), + Row("Manufacturer#3", "almond antique olive coral navajo", 45, 260.58159187137954, 260.58159187137954, 1, 67902.7660222222, -0.8710736366736884, -4099.731111111111), + Row("Manufacturer#3", "almond antique olive coral navajo", 45, 260.58159187137954, 260.58159187137954, 19, 67902.7660222222, -0.8710736366736884, -4099.731111111111), + Row("Manufacturer#3", "almond antique olive coral navajo", 45, 260.58159187137954, 260.58159187137954, 45, 67902.7660222222, -0.8710736366736884, -4099.731111111111), + Row("Manufacturer#4", "almond antique gainsboro frosted violet", 10, 170.1301188959661, 170.1301188959661, 10, 28944.25735555556, -0.6656975320098423, -1347.4777777777779), + Row("Manufacturer#4", "almond antique gainsboro frosted violet", 10, 170.1301188959661, 170.1301188959661, 27, 28944.25735555556, -0.6656975320098423, -1347.4777777777779), + Row("Manufacturer#4", "almond antique gainsboro frosted violet", 10, 170.1301188959661, 170.1301188959661, 39, 28944.25735555556, -0.6656975320098423, -1347.4777777777779), + Row("Manufacturer#4", "almond antique violet mint lemon", 39, 242.26834609323197, 242.26834609323197, 7, 58693.95151875002, -0.8051852719193339, -2537.328125), + Row("Manufacturer#4", "almond antique violet mint lemon", 39, 242.26834609323197, 242.26834609323197, 10, 58693.95151875002, -0.8051852719193339, -2537.328125), + Row("Manufacturer#4", "almond antique violet mint lemon", 39, 242.26834609323197, 242.26834609323197, 27, 58693.95151875002, -0.8051852719193339, -2537.328125), + Row("Manufacturer#4", "almond antique violet mint lemon", 39, 242.26834609323197, 242.26834609323197, 39, 58693.95151875002, -0.8051852719193339, -2537.328125), + Row("Manufacturer#4", "almond aquamarine floral ivory bisque", 27, 234.10001662537323, 234.10001662537323, 7, 54802.81778400003, -0.6046935574240581, -1719.8079999999995), + Row("Manufacturer#4", "almond aquamarine floral ivory bisque", 27, 234.10001662537323, 234.10001662537323, 10, 54802.81778400003, -0.6046935574240581, -1719.8079999999995), + Row("Manufacturer#4", "almond aquamarine floral ivory bisque", 27, 234.10001662537323, 234.10001662537323, 12, 54802.81778400003, -0.6046935574240581, -1719.8079999999995), + Row("Manufacturer#4", "almond aquamarine floral ivory bisque", 27, 234.10001662537323, 234.10001662537323, 27, 54802.81778400003, -0.6046935574240581, -1719.8079999999995), + Row("Manufacturer#4", "almond aquamarine floral ivory bisque", 27, 234.10001662537323, 234.10001662537323, 39, 54802.81778400003, -0.6046935574240581, -1719.8079999999995), + Row("Manufacturer#4", "almond aquamarine yellow dodger mint", 7, 247.33427141977316, 247.33427141977316, 7, 61174.241818750015, -0.5508665654707869, -1719.0368749999975), + Row("Manufacturer#4", "almond aquamarine yellow dodger mint", 7, 
247.33427141977316, 247.33427141977316, 12, 61174.241818750015, -0.5508665654707869, -1719.0368749999975), + Row("Manufacturer#4", "almond aquamarine yellow dodger mint", 7, 247.33427141977316, 247.33427141977316, 27, 61174.241818750015, -0.5508665654707869, -1719.0368749999975), + Row("Manufacturer#4", "almond aquamarine yellow dodger mint", 7, 247.33427141977316, 247.33427141977316, 39, 61174.241818750015, -0.5508665654707869, -1719.0368749999975), + Row("Manufacturer#4", "almond azure aquamarine papaya violet", 12, 283.33443305668936, 283.33443305668936, 7, 80278.4009555556, -0.7755740084632333, -1867.4888888888881), + Row("Manufacturer#4", "almond azure aquamarine papaya violet", 12, 283.33443305668936, 283.33443305668936, 12, 80278.4009555556, -0.7755740084632333, -1867.4888888888881), + Row("Manufacturer#4", "almond azure aquamarine papaya violet", 12, 283.33443305668936, 283.33443305668936, 27, 80278.4009555556, -0.7755740084632333, -1867.4888888888881), + Row("Manufacturer#5", "almond antique blue firebrick mint", 31, 83.69879024746344, 83.69879024746344, 2, 7005.487488888881, 0.3900430308728505, 418.9233333333353), + Row("Manufacturer#5", "almond antique blue firebrick mint", 31, 83.69879024746344, 83.69879024746344, 6, 7005.487488888881, 0.3900430308728505, 418.9233333333353), + Row("Manufacturer#5", "almond antique blue firebrick mint", 31, 83.69879024746344, 83.69879024746344, 31, 7005.487488888881, 0.3900430308728505, 418.9233333333353), + Row("Manufacturer#5", "almond antique medium spring khaki", 6, 316.68049612345885, 316.68049612345885, 2, 100286.53662500005, -0.7136129117761831, -4090.853749999999), + Row("Manufacturer#5", "almond antique medium spring khaki", 6, 316.68049612345885, 316.68049612345885, 6, 100286.53662500005, -0.7136129117761831, -4090.853749999999), + Row("Manufacturer#5", "almond antique medium spring khaki", 6, 316.68049612345885, 316.68049612345885, 31, 100286.53662500005, -0.7136129117761831, -4090.853749999999), + Row("Manufacturer#5", "almond antique medium spring khaki", 6, 316.68049612345885, 316.68049612345885, 46, 100286.53662500005, -0.7136129117761831, -4090.853749999999), + Row("Manufacturer#5", "almond antique sky peru orange", 2, 285.4050629824216, 285.4050629824216, 2, 81456.04997600004, -0.712858514567818, -3297.2011999999986), + Row("Manufacturer#5", "almond antique sky peru orange", 2, 285.4050629824216, 285.4050629824216, 6, 81456.04997600004, -0.712858514567818, -3297.2011999999986), + Row("Manufacturer#5", "almond antique sky peru orange", 2, 285.4050629824216, 285.4050629824216, 23, 81456.04997600004, -0.712858514567818, -3297.2011999999986), + Row("Manufacturer#5", "almond antique sky peru orange", 2, 285.4050629824216, 285.4050629824216, 31, 81456.04997600004, -0.712858514567818, -3297.2011999999986), + Row("Manufacturer#5", "almond antique sky peru orange", 2, 285.4050629824216, 285.4050629824216, 46, 81456.04997600004, -0.712858514567818, -3297.2011999999986), + Row("Manufacturer#5", "almond aquamarine dodger light gainsboro", 46, 285.43749038756283, 285.43749038756283, 2, 81474.56091875004, -0.9841287871533909, -4871.028125000002), + Row("Manufacturer#5", "almond aquamarine dodger light gainsboro", 46, 285.43749038756283, 285.43749038756283, 6, 81474.56091875004, -0.9841287871533909, -4871.028125000002), + Row("Manufacturer#5", "almond aquamarine dodger light gainsboro", 46, 285.43749038756283, 285.43749038756283, 23, 81474.56091875004, -0.9841287871533909, -4871.028125000002), + Row("Manufacturer#5", "almond aquamarine dodger 
light gainsboro", 46, 285.43749038756283, 285.43749038756283, 46, 81474.56091875004, -0.9841287871533909, -4871.028125000002), + Row("Manufacturer#5", "almond azure blanched chiffon midnight", 23, 315.9225931564038, 315.9225931564038, 2, 99807.08486666666, -0.9978877469246935, -5664.856666666666), + Row("Manufacturer#5", "almond azure blanched chiffon midnight", 23, 315.9225931564038, 315.9225931564038, 23, 99807.08486666666, -0.9978877469246935, -5664.856666666666), + Row("Manufacturer#5", "almond azure blanched chiffon midnight", 23, 315.9225931564038, 315.9225931564038, 46, 99807.08486666666, -0.9978877469246935, -5664.856666666666))) + // scalastyle:on + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala index af3f468aaa5e..92043d66c914 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala @@ -24,10 +24,17 @@ import org.apache.spark.sql.sources.HadoopFsRelationTest import org.apache.spark.sql.types._ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { + import testImplicits._ + override val dataSourceName: String = classOf[DefaultSource].getCanonicalName - import sqlContext._ - import sqlContext.implicits._ + // ORC does not play well with NullType and UDT. + override protected def supportsDataType(dataType: DataType): Boolean = dataType match { + case _: NullType => false + case _: CalendarIntervalType => false + case _: UserDefinedType[_] => false + case _ => true + } test("save()/load() - partitioned table - simple queries - partition columns in data") { withTempDir { file => @@ -48,11 +55,9 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) checkQueries( - load( - source = dataSourceName, - options = Map( - "path" -> file.getCanonicalPath, - "dataSchema" -> dataSchemaWithPartition.json))) + hiveContext.read.options(Map( + "path" -> file.getCanonicalPath, + "dataSchema" -> dataSchemaWithPartition.json)).format(dataSourceName).load()) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala index d463e8fd626f..52e09f9496f0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala @@ -18,19 +18,16 @@ package org.apache.spark.sql.hive.orc import java.io.File -import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.apache.spark.util.Utils -import org.scalatest.BeforeAndAfterAll import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag +import org.scalatest.BeforeAndAfterAll +import org.apache.hadoop.hive.conf.HiveConf.ConfVars + +import org.apache.spark.sql._ +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.util.Utils // The data where the partitioning key exists only in the directory structure. 
case class OrcParData(intField: Int, stringField: String) @@ -39,8 +36,11 @@ case class OrcParData(intField: Int, stringField: String) case class OrcParDataWithKey(intField: Int, pi: Int, stringField: String, ps: String) // TODO This test suite duplicates ParquetPartitionDiscoverySuite a lot -class OrcPartitionDiscoverySuite extends QueryTest with BeforeAndAfterAll { - val defaultPartitionName = ConfVars.DEFAULTPARTITIONNAME.defaultVal +class OrcPartitionDiscoverySuite extends QueryTest with TestHiveSingleton with BeforeAndAfterAll { + import hiveContext._ + import hiveContext.implicits._ + + val defaultPartitionName = ConfVars.DEFAULTPARTITIONNAME.defaultStrVal def withTempDir(f: File => Unit): Unit = { val dir = Utils.createTempDir().getCanonicalFile @@ -59,7 +59,7 @@ class OrcPartitionDiscoverySuite extends QueryTest with BeforeAndAfterAll { } protected def withTempTable(tableName: String)(f: => Unit): Unit = { - try f finally TestHive.dropTempTable(tableName) + try f finally hiveContext.dropTempTable(tableName) } protected def makePartitionDir( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 744d46293814..2156806d21f9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -24,6 +24,7 @@ import org.apache.hadoop.hive.ql.io.orc.CompressionKind import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ @@ -218,7 +219,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { sql("INSERT INTO TABLE t SELECT * FROM tmp") checkAnswer(table("t"), (data ++ data).map(Row.fromTuple)) } - catalog.unregisterTable(Seq("tmp")) + catalog.unregisterTable(TableIdentifier("tmp")) } test("overwriting") { @@ -228,7 +229,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") checkAnswer(table("t"), data.map(Row.fromTuple)) } - catalog.unregisterTable(Seq("tmp")) + catalog.unregisterTable(TableIdentifier("tmp")) } test("self-join") { @@ -287,6 +288,20 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } } + test("SPARK-9170: Don't implicitly lowercase of user-provided columns") { + withTempPath { dir => + val path = dir.getCanonicalPath + + sqlContext.range(0, 10).select('id as "Acol").write.format("orc").save(path) + sqlContext.read.format("orc").load(path).schema("Acol") + intercept[IllegalArgumentException] { + sqlContext.read.format("orc").load(path).schema("acol") + } + checkAnswer(sqlContext.read.format("orc").load(path).select("acol").sort("acol"), + (0 until 10).map(Row(_))) + } + } + test("SPARK-8501: Avoids discovery schema from empty ORC files") { withTempPath { dir => val path = dir.getCanonicalPath @@ -330,4 +345,53 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } } } + + test("SPARK-10623 Enable ORC PPD") { + withTempPath { dir => + withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") { + import testImplicits._ + val path = dir.getCanonicalPath + + // For field "a", the first column has odds integers. This is to check the filtered count + // when `isNull` is performed. 
For Field "b", `isNotNull` of ORC file filters rows + // only when all the values are null (maybe this works differently when the data + // or query is complicated). So, simply here a column only having `null` is added. + val data = (0 until 10).map { i => + val maybeInt = if (i % 2 == 0) None else Some(i) + val nullValue: Option[String] = None + (maybeInt, nullValue) + } + createDataFrame(data).toDF("a", "b").write.orc(path) + val df = sqlContext.read.orc(path) + + def checkPredicate(pred: Column, answer: Seq[Row]): Unit = { + val sourceDf = stripSparkFilter(df.where(pred)) + val data = sourceDf.collect().toSet + val expectedData = answer.toSet + + // When a filter is pushed to ORC, ORC can apply it to rows. So, we can check + // the number of rows returned from the ORC to make sure our filter pushdown work. + // A tricky part is, ORC does not process filter rows fully but return some possible + // results. So, this checks if the number of result is less than the original count + // of data, and then checks if it contains the expected data. + val isOrcFiltered = sourceDf.count < 10 && expectedData.subsetOf(data) + assert(isOrcFiltered) + } + + checkPredicate('a === 5, List(5).map(Row(_, null))) + checkPredicate('a <=> 5, List(5).map(Row(_, null))) + checkPredicate('a < 5, List(1, 3).map(Row(_, null))) + checkPredicate('a <= 5, List(1, 3, 5).map(Row(_, null))) + checkPredicate('a > 5, List(7, 9).map(Row(_, null))) + checkPredicate('a >= 5, List(5, 7, 9).map(Row(_, null))) + checkPredicate('a.isNull, List(null).map(Row(_, null))) + checkPredicate('b.isNotNull, List()) + checkPredicate('a.isin(3, 5, 7), List(3, 5, 7).map(Row(_, null))) + checkPredicate('a > 0 && 'a < 3, List(1).map(Row(_, null))) + checkPredicate('a < 1 || 'a > 8, List(9).map(Row(_, null))) + checkPredicate(!('a > 3), List(1, 3).map(Row(_, null))) + checkPredicate(!('a > 0 && 'a < 3), List(3, 5, 7, 9).map(Row(_, null))) + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index 82e08caf4645..7a34cf731b4c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -21,12 +21,14 @@ import java.io.File import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.hive.test.TestHiveSingleton case class OrcData(intField: Int, stringField: String) -abstract class OrcSuite extends QueryTest with BeforeAndAfterAll { +abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndAfterAll { + import hiveContext._ + var orcTableDir: File = null var orcTableAsDir: File = null @@ -121,13 +123,42 @@ abstract class OrcSuite extends QueryTest with BeforeAndAfterAll { sql("SELECT * FROM normal_orc_as_source"), (6 to 10).map(i => Row(i, s"part-$i"))) } + + test("write null values") { + sql("DROP TABLE IF EXISTS orcNullValues") + + val df = sql( + """ + |SELECT + | CAST(null as TINYINT), + | CAST(null as SMALLINT), + | CAST(null as INT), + | CAST(null as BIGINT), + | CAST(null as FLOAT), + | CAST(null as DOUBLE), + | CAST(null as DECIMAL(7,2)), + | CAST(null as TIMESTAMP), + | CAST(null as DATE), + | CAST(null as STRING), + | CAST(null as VARCHAR(10)) + |FROM orc_temp_table limit 1 + """.stripMargin) + + df.write.format("orc").saveAsTable("orcNullValues") + + checkAnswer( + sql("SELECT * FROM 
orcNullValues"), + Row.fromSeq(Seq.fill(11)(null))) + + sql("DROP TABLE IF EXISTS orcNullValues") + } } class OrcSourceSuite extends OrcSuite { override def beforeAll(): Unit = { super.beforeAll() - sql( + hiveContext.sql( s"""CREATE TEMPORARY TABLE normal_orc_source |USING org.apache.spark.sql.hive.orc |OPTIONS ( @@ -135,7 +166,7 @@ class OrcSourceSuite extends OrcSuite { |) """.stripMargin) - sql( + hiveContext.sql( s"""CREATE TEMPORARY TABLE normal_orc_as_source |USING org.apache.spark.sql.hive.orc |OPTIONS ( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala index 145965388da0..88a0ed511749 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala @@ -22,15 +22,12 @@ import java.io.File import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.hive.test.TestHiveSingleton -private[sql] trait OrcTest extends SQLTestUtils { this: SparkFunSuite => - lazy val sqlContext = org.apache.spark.sql.hive.test.TestHive - - import sqlContext.implicits._ - import sqlContext.sparkContext +private[sql] trait OrcTest extends SQLTestUtils with TestHiveSingleton { + import testImplicits._ /** * Writes `data` to a Orc file, which is then passed to `f` and will be deleted after `f` diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index f56fb96c52d3..905eb7a3925b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -19,16 +19,12 @@ package org.apache.spark.sql.hive import java.io.File -import org.scalatest.BeforeAndAfterAll - import org.apache.spark.sql._ import org.apache.spark.sql.execution.datasources.{InsertIntoDataSource, InsertIntoHadoopFsRelation, LogicalRelation} import org.apache.spark.sql.execution.{ExecutedCommand, PhysicalRDD} import org.apache.spark.sql.hive.execution.HiveTableScan -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -58,9 +54,18 @@ case class ParquetDataWithKeyAndComplexTypes( * built in parquet support. 
*/ class ParquetMetastoreSuite extends ParquetPartitioningTest { + import hiveContext._ + override def beforeAll(): Unit = { super.beforeAll() - + dropTables("partitioned_parquet", + "partitioned_parquet_with_key", + "partitioned_parquet_with_complextypes", + "partitioned_parquet_with_key_and_complextypes", + "normal_parquet", + "jt", + "jt_array", + "test_parquet") sql(s""" create external table partitioned_parquet ( @@ -172,14 +177,14 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { } override def afterAll(): Unit = { - sql("DROP TABLE partitioned_parquet") - sql("DROP TABLE partitioned_parquet_with_key") - sql("DROP TABLE partitioned_parquet_with_complextypes") - sql("DROP TABLE partitioned_parquet_with_key_and_complextypes") - sql("DROP TABLE normal_parquet") - sql("DROP TABLE IF EXISTS jt") - sql("DROP TABLE IF EXISTS jt_array") - sql("DROP TABLE IF EXISTS test_parquet") + dropTables("partitioned_parquet", + "partitioned_parquet_with_key", + "partitioned_parquet_with_complextypes", + "partitioned_parquet_with_key_and_complextypes", + "normal_parquet", + "jt", + "jt_array", + "test_parquet") setConf(HiveContext.CONVERT_METASTORE_PARQUET, false) } @@ -203,6 +208,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { } test("insert into an empty parquet table") { + dropTables("test_insert_parquet") sql( """ |create table test_insert_parquet @@ -228,7 +234,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { sql(s"SELECT intField, stringField FROM test_insert_parquet WHERE intField > 2"), Row(3, "str3") :: Row(4, "str4") :: Nil ) - sql("DROP TABLE IF EXISTS test_insert_parquet") + dropTables("test_insert_parquet") // Create it again. sql( @@ -255,166 +261,166 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { sql(s"SELECT intField, stringField FROM test_insert_parquet"), (1 to 10).map(i => Row(i, s"str$i")) ++ (1 to 4).map(i => Row(i, s"str$i")) ) - sql("DROP TABLE IF EXISTS test_insert_parquet") + dropTables("test_insert_parquet") } test("scan a parquet table created through a CTAS statement") { - sql( - """ - |create table test_parquet_ctas ROW FORMAT - |SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - |STORED AS - | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - |AS select * from jt - """.stripMargin) + withTable("test_parquet_ctas") { + sql( + """ + |create table test_parquet_ctas ROW FORMAT + |SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + |AS select * from jt + """.stripMargin) - checkAnswer( - sql(s"SELECT a, b FROM test_parquet_ctas WHERE a = 1"), - Seq(Row(1, "str1")) - ) + checkAnswer( + sql(s"SELECT a, b FROM test_parquet_ctas WHERE a = 1"), + Seq(Row(1, "str1")) + ) - table("test_parquet_ctas").queryExecution.optimizedPlan match { - case LogicalRelation(_: ParquetRelation) => // OK - case _ => fail( - "test_parquet_ctas should be converted to " + - s"${classOf[ParquetRelation].getCanonicalName}") + table("test_parquet_ctas").queryExecution.optimizedPlan match { + case LogicalRelation(_: ParquetRelation, _) => // OK + case _ => fail( + "test_parquet_ctas should be converted to " + + s"${classOf[ParquetRelation].getCanonicalName }") + } } - - sql("DROP TABLE IF EXISTS test_parquet_ctas") } 
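+ // NOTE: `withTable` is assumed here to be the SQLTestUtils helper that drops the named + // tables in a finally block once the enclosed body completes, so the temporary tables are + // cleaned up even when an assertion fails; it replaces the manual + // "DROP TABLE IF EXISTS ..." statements shown in the removed lines.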
test("MetastoreRelation in InsertIntoTable will be converted") { - sql( - """ - |create table test_insert_parquet - |( - | intField INT - |) - |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - |STORED AS - | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - """.stripMargin) + withTable("test_insert_parquet") { + sql( + """ + |create table test_insert_parquet + |( + | intField INT + |) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt") + df.queryExecution.executedPlan match { + case ExecutedCommand(InsertIntoHadoopFsRelation(_: ParquetRelation, _, _)) => // OK + case o => fail("test_insert_parquet should be converted to a " + + s"${classOf[ParquetRelation].getCanonicalName} and " + + s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan. " + + s"However, found a ${o.toString} ") + } - val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt") - df.queryExecution.executedPlan match { - case ExecutedCommand(InsertIntoHadoopFsRelation(_: ParquetRelation, _, _)) => // OK - case o => fail("test_insert_parquet should be converted to a " + - s"${classOf[ParquetRelation].getCanonicalName} and " + - s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan. " + - s"However, found a ${o.toString} ") + checkAnswer( + sql("SELECT intField FROM test_insert_parquet WHERE test_insert_parquet.intField > 5"), + sql("SELECT a FROM jt WHERE jt.a > 5").collect() + ) } - - checkAnswer( - sql("SELECT intField FROM test_insert_parquet WHERE test_insert_parquet.intField > 5"), - sql("SELECT a FROM jt WHERE jt.a > 5").collect() - ) - - sql("DROP TABLE IF EXISTS test_insert_parquet") } test("MetastoreRelation in InsertIntoHiveTable will be converted") { - sql( - """ - |create table test_insert_parquet - |( - | int_array array - |) - |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - |STORED AS - | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - """.stripMargin) + withTable("test_insert_parquet") { + sql( + """ + |create table test_insert_parquet + |( + | int_array array + |) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt_array") + df.queryExecution.executedPlan match { + case ExecutedCommand(InsertIntoHadoopFsRelation(r: ParquetRelation, _, _)) => // OK + case o => fail("test_insert_parquet should be converted to a " + + s"${classOf[ParquetRelation].getCanonicalName} and " + + s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan." 
+ + s"However, found a ${o.toString} ") + } - val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt_array") - df.queryExecution.executedPlan match { - case ExecutedCommand(InsertIntoHadoopFsRelation(r: ParquetRelation, _, _)) => // OK - case o => fail("test_insert_parquet should be converted to a " + - s"${classOf[ParquetRelation].getCanonicalName} and " + - s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan." + - s"However, found a ${o.toString} ") + checkAnswer( + sql("SELECT int_array FROM test_insert_parquet"), + sql("SELECT a FROM jt_array").collect() + ) } - - checkAnswer( - sql("SELECT int_array FROM test_insert_parquet"), - sql("SELECT a FROM jt_array").collect() - ) - - sql("DROP TABLE IF EXISTS test_insert_parquet") } test("SPARK-6450 regression test") { - sql( - """CREATE TABLE IF NOT EXISTS ms_convert (key INT) - |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - |STORED AS - | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - """.stripMargin) + withTable("ms_convert") { + sql( + """CREATE TABLE IF NOT EXISTS ms_convert (key INT) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + // This shouldn't throw AnalysisException + val analyzed = sql( + """SELECT key FROM ms_convert + |UNION ALL + |SELECT key FROM ms_convert + """.stripMargin).queryExecution.analyzed - // This shouldn't throw AnalysisException - val analyzed = sql( - """SELECT key FROM ms_convert - |UNION ALL - |SELECT key FROM ms_convert - """.stripMargin).queryExecution.analyzed - - assertResult(2) { - analyzed.collect { - case r @ LogicalRelation(_: ParquetRelation) => r - }.size + assertResult(2) { + analyzed.collect { + case r @ LogicalRelation(_: ParquetRelation, _) => r + }.size + } } - - sql("DROP TABLE ms_convert") } def collectParquetRelation(df: DataFrame): ParquetRelation = { val plan = df.queryExecution.analyzed plan.collectFirst { - case LogicalRelation(r: ParquetRelation) => r + case LogicalRelation(r: ParquetRelation, _) => r }.getOrElse { fail(s"Expecting a ParquetRelation2, but got:\n$plan") } } test("SPARK-7749: non-partitioned metastore Parquet table lookup should use cached relation") { - sql( - s"""CREATE TABLE nonPartitioned ( - | key INT, - | value STRING - |) - |STORED AS PARQUET - """.stripMargin) - - // First lookup fills the cache - val r1 = collectParquetRelation(table("nonPartitioned")) - // Second lookup should reuse the cache - val r2 = collectParquetRelation(table("nonPartitioned")) - // They should be the same instance - assert(r1 eq r2) - - sql("DROP TABLE nonPartitioned") + withTable("nonPartitioned") { + sql( + s"""CREATE TABLE nonPartitioned ( + | key INT, + | value STRING + |) + |STORED AS PARQUET + """.stripMargin) + + // First lookup fills the cache + val r1 = collectParquetRelation(table("nonPartitioned")) + // Second lookup should reuse the cache + val r2 = collectParquetRelation(table("nonPartitioned")) + // They should be the same instance + assert(r1 eq r2) + } } test("SPARK-7749: partitioned metastore Parquet table lookup should use cached relation") { - sql( - s"""CREATE TABLE partitioned ( - | key INT, - | value STRING - |) - |PARTITIONED BY (part INT) - |STORED AS 
PARQUET + withTable("partitioned") { + sql( + s"""CREATE TABLE partitioned ( + | key INT, + | value STRING + |) + |PARTITIONED BY (part INT) + |STORED AS PARQUET """.stripMargin) - // First lookup fills the cache - val r1 = collectParquetRelation(table("partitioned")) - // Second lookup should reuse the cache - val r2 = collectParquetRelation(table("partitioned")) - // They should be the same instance - assert(r1 eq r2) - - sql("DROP TABLE partitioned") + // First lookup fills the cache + val r1 = collectParquetRelation(table("partitioned")) + // Second lookup should reuse the cache + val r2 = collectParquetRelation(table("partitioned")) + // They should be the same instance + assert(r1 eq r2) + } } test("Caching converted data source Parquet Relations") { @@ -422,7 +428,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { // Converted test_parquet should be cached. catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) match { case null => fail("Converted test_parquet should be cached in the cache.") - case logical @ LogicalRelation(parquetRelation: ParquetRelation) => // OK + case logical @ LogicalRelation(parquetRelation: ParquetRelation, _) => // OK case other => fail( "The cached test_parquet should be a Parquet Relation. " + @@ -430,8 +436,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { } } - sql("DROP TABLE IF EXISTS test_insert_parquet") - sql("DROP TABLE IF EXISTS test_parquet_partitioned_cache_test") + dropTables("test_insert_parquet", "test_parquet_partitioned_cache_test") sql( """ @@ -479,7 +484,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { | intField INT, | stringField STRING |) - |PARTITIONED BY (date string) + |PARTITIONED BY (`date` string) |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' |STORED AS | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' @@ -491,7 +496,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { sql( """ |INSERT INTO TABLE test_parquet_partitioned_cache_test - |PARTITION (date='2015-04-01') + |PARTITION (`date`='2015-04-01') |select a, b from jt """.stripMargin) // Right now, insert into a partitioned Parquet is not supported in data source Parquet. @@ -500,7 +505,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { sql( """ |INSERT INTO TABLE test_parquet_partitioned_cache_test - |PARTITION (date='2015-04-02') + |PARTITION (`date`='2015-04-02') |select a, b from jt """.stripMargin) assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) @@ -510,7 +515,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { checkCached(tableIdentifier) // Make sure we can read the data. checkAnswer( - sql("select STRINGField, date, intField from test_parquet_partitioned_cache_test"), + sql("select STRINGField, `date`, intField from test_parquet_partitioned_cache_test"), sql( """ |select b, '2015-04-01', a FROM jt @@ -521,8 +526,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { invalidateTable("test_parquet_partitioned_cache_test") assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) - sql("DROP TABLE test_insert_parquet") - sql("DROP TABLE test_parquet_partitioned_cache_test") + dropTables("test_insert_parquet", "test_parquet_partitioned_cache_test") } } @@ -530,8 +534,16 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { * A suite of tests for the Parquet support through the data sources API. 
*/ class ParquetSourceSuite extends ParquetPartitioningTest { + import testImplicits._ + import hiveContext._ + override def beforeAll(): Unit = { super.beforeAll() + dropTables("partitioned_parquet", + "partitioned_parquet_with_key", + "partitioned_parquet_with_complextypes", + "partitioned_parquet_with_key_and_complextypes", + "normal_parquet") sql( s""" create temporary table partitioned_parquet @@ -608,7 +620,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest { val conf = Seq( HiveContext.CONVERT_METASTORE_PARQUET.key -> "false", SQLConf.PARQUET_BINARY_AS_STRING.key -> "true", - SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key -> "true") + SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "false") withSQLConf(conf: _*) { sql( @@ -635,22 +647,22 @@ class ParquetSourceSuite extends ParquetPartitioningTest { StructField("a", arrayType1, nullable = true) :: Nil) assert(df.schema === expectedSchema1) - df.write.format("parquet").saveAsTable("alwaysNullable") + withTable("alwaysNullable") { + df.write.format("parquet").saveAsTable("alwaysNullable") - val mapType2 = MapType(IntegerType, IntegerType, valueContainsNull = true) - val arrayType2 = ArrayType(IntegerType, containsNull = true) - val expectedSchema2 = - StructType( - StructField("m", mapType2, nullable = true) :: - StructField("a", arrayType2, nullable = true) :: Nil) + val mapType2 = MapType(IntegerType, IntegerType, valueContainsNull = true) + val arrayType2 = ArrayType(IntegerType, containsNull = true) + val expectedSchema2 = + StructType( + StructField("m", mapType2, nullable = true) :: + StructField("a", arrayType2, nullable = true) :: Nil) - assert(table("alwaysNullable").schema === expectedSchema2) - - checkAnswer( - sql("SELECT m, a FROM alwaysNullable"), - Row(Map(2 -> 3), Seq(4, 5, 6))) + assert(table("alwaysNullable").schema === expectedSchema2) - sql("DROP TABLE alwaysNullable") + checkAnswer( + sql("SELECT m, a FROM alwaysNullable"), + Row(Map(2 -> 3), Seq(4, 5, 6))) + } } test("Aggregation attribute names can't contain special chars \" ,;{}()\\n\\t=\"") { @@ -673,8 +685,8 @@ class ParquetSourceSuite extends ParquetPartitioningTest { /** * A collection of tests for parquet data with various forms of partitioning. 
*/ -abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with BeforeAndAfterAll { - override def sqlContext: SQLContext = TestHive +abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with TestHiveSingleton { + import testImplicits._ var partitionedTableDir: File = null var normalTableDir: File = null @@ -738,6 +750,16 @@ abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with partitionedTableDirWithKeyAndComplexTypes.delete() } + /** + * Drop named tables if they exist + * @param tableNames tables to drop + */ + def dropTables(tableNames: String*): Unit = { + tableNames.foreach { name => + sql(s"DROP TABLE IF EXISTS $name") + } + } + Seq( "partitioned_parquet", "partitioned_parquet_with_key", @@ -847,8 +869,7 @@ abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with (1 to 10).map(i => Row(1, i, f"${i}_string"))) } - // Re-enable this after SPARK-5508 is fixed - ignore(s"SPARK-5775 read array from $table") { + test(s"SPARK-5775 read array from $table") { checkAnswer( sql(s"SELECT arrayField, p FROM $table WHERE p = 1"), (1 to 10).map(i => Row(1 to i, 1))) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala index e976125b3706..dc0531a6d4bc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala @@ -18,14 +18,13 @@ package org.apache.spark.sql.sources import org.apache.hadoop.fs.Path +import org.apache.spark.SparkException import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.{SparkException, SparkFunSuite} -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils -class CommitFailureTestRelationSuite extends SparkFunSuite with SQLTestUtils { - override val sqlContext = TestHive +class CommitFailureTestRelationSuite extends SQLTestUtils with TestHiveSingleton { // When committing a task, `CommitFailureTestSource` throws an exception for testing purpose. val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala new file mode 100644 index 000000000000..ef37787137d0 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.sources + +import java.math.BigDecimal + +import org.apache.hadoop.fs.Path + +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.Row +import org.apache.spark.sql.types._ + +class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { + override val dataSourceName: String = "json" + + // JSON does not write data of NullType and does not play well with BinaryType. + override protected def supportsDataType(dataType: DataType): Boolean = dataType match { + case _: NullType => false + case _: BinaryType => false + case _: CalendarIntervalType => false + case _ => true + } + + test("save()/load() - partitioned table - simple queries - partition columns in data") { + withTempDir { file => + val basePath = new Path(file.getCanonicalPath) + val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) + val qualifiedBasePath = fs.makeQualified(basePath) + + for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { + val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") + sparkContext + .parallelize(for (i <- 1 to 3) yield s"""{"a":$i,"b":"val_$i"}""") + .saveAsTextFile(partitionDir.toString) + } + + val dataSchemaWithPartition = + StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) + + checkQueries( + hiveContext.read.format(dataSourceName) + .option("dataSchema", dataSchemaWithPartition.json) + .load(file.getCanonicalPath)) + } + } + + test("SPARK-9894: save complex types to JSON") { + withTempDir { file => + file.delete() + + val schema = + new StructType() + .add("array", ArrayType(LongType)) + .add("map", MapType(StringType, new StructType().add("innerField", LongType))) + + val data = + Row(Seq(1L, 2L, 3L), Map("m1" -> Row(4L))) :: + Row(Seq(5L, 6L, 7L), Map("m2" -> Row(10L))) :: Nil + val df = hiveContext.createDataFrame(sparkContext.parallelize(data), schema) + + // Write the data out. + df.write.format(dataSourceName).save(file.getCanonicalPath) + + // Read it back and check the result. + checkAnswer( + hiveContext.read.format(dataSourceName).schema(schema).load(file.getCanonicalPath), + df + ) + } + } + + test("SPARK-10196: save decimal type to JSON") { + withTempDir { file => + file.delete() + + val schema = + new StructType() + .add("decimal", DecimalType(7, 2)) + + val data = + Row(new BigDecimal("10.02")) :: + Row(new BigDecimal("20000.99")) :: + Row(new BigDecimal("10000")) :: Nil + val df = hiveContext.createDataFrame(sparkContext.parallelize(data), schema) + + // Write the data out. + df.write.format(dataSourceName).save(file.getCanonicalPath) + + // Read it back and check the result. 
+ checkAnswer( + hiveContext.read.format(dataSourceName).schema(schema).load(file.getCanonicalPath), + df + ) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala index d280543a071d..e866493ee6c9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala @@ -23,15 +23,21 @@ import com.google.common.io.Files import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql.{AnalysisException, SaveMode, parquet} -import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.apache.spark.sql._ +import org.apache.spark.sql.types._ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { - override val dataSourceName: String = classOf[parquet.DefaultSource].getCanonicalName + import testImplicits._ - import sqlContext._ - import sqlContext.implicits._ + override val dataSourceName: String = "parquet" + + // Parquet does not play well with NullType. + override protected def supportsDataType(dataType: DataType): Boolean = dataType match { + case _: NullType => false + case _: CalendarIntervalType => false + case _ => true + } test("save()/load() - partitioned table - simple queries - partition columns in data") { withTempDir { file => @@ -51,7 +57,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) checkQueries( - read.format(dataSourceName) + hiveContext.read.format(dataSourceName) .option("dataSchema", dataSchemaWithPartition.json) .load(file.getCanonicalPath)) } @@ -69,7 +75,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { .format("parquet") .save(s"${dir.getCanonicalPath}/_temporary") - checkAnswer(read.format("parquet").load(dir.getCanonicalPath), df.collect()) + checkAnswer(hiveContext.read.format("parquet").load(dir.getCanonicalPath), df.collect()) } } @@ -97,7 +103,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { // This shouldn't throw anything. df.write.format("parquet").mode(SaveMode.Overwrite).save(path) - checkAnswer(read.format("parquet").load(path), df) + checkAnswer(hiveContext.read.format("parquet").load(path), df) } } @@ -107,7 +113,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { // Parquet doesn't allow field names with spaces. Here we are intentionally making an // exception thrown from the `ParquetRelation2.prepareForWriteJob()` method to trigger // the bug. Please refer to spark-8079 for more details. 
- range(1, 10) + hiveContext.range(1, 10) .withColumnRenamed("id", "a b") .write .format("parquet") @@ -125,7 +131,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { val summaryPath = new Path(path, "_metadata") val commonSummaryPath = new Path(path, "_common_metadata") - val fs = summaryPath.getFileSystem(configuration) + val fs = summaryPath.getFileSystem(hadoopConfiguration) fs.delete(summaryPath, true) fs.delete(commonSummaryPath, true) @@ -136,4 +142,36 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { assert(fs.exists(commonSummaryPath)) } } + + test("SPARK-10334 Projections and filters should be kept in physical plan") { + withTempPath { dir => + val path = dir.getCanonicalPath + + sqlContext.range(2).select('id as 'a, 'id as 'b).write.partitionBy("b").parquet(path) + val df = sqlContext.read.parquet(path).filter('a === 0).select('b) + val physicalPlan = df.queryExecution.executedPlan + + assert(physicalPlan.collect { case p: execution.Project => p }.length === 1) + assert(physicalPlan.collect { case p: execution.Filter => p }.length === 1) + } + } + + test("SPARK-11500: Not deterministic order of columns when using merging schemas.") { + import testImplicits._ + withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true") { + withTempPath { dir => + val pathOne = s"${dir.getCanonicalPath}/part=1" + Seq(1, 1).zipWithIndex.toDF("a", "b").write.parquet(pathOne) + val pathTwo = s"${dir.getCanonicalPath}/part=2" + Seq(1, 1).zipWithIndex.toDF("c", "b").write.parquet(pathTwo) + val pathThree = s"${dir.getCanonicalPath}/part=3" + Seq(1, 1).zipWithIndex.toDF("d", "b").write.parquet(pathThree) + + // The schema consists of the leading columns of the first part-file + // in the lexicographic order. + assert(sqlContext.read.parquet(dir.getCanonicalPath).schema.map(_.name) + === Seq("a", "b", "c", "d", "part")) + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala index e8975e5f5cd0..b554d135e4b5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala @@ -17,15 +17,41 @@ package org.apache.spark.sql.sources +import java.io.File + import org.apache.hadoop.fs.Path +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, PredicateHelper} +import org.apache.spark.sql.execution.{LogicalRDD, PhysicalRDD} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{Column, DataFrame, Row, execution} +import org.apache.spark.util.Utils + +class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest with PredicateHelper { + import testImplicits._ -class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { override val dataSourceName: String = classOf[SimpleTextSource].getCanonicalName - import sqlContext._ + // We have a very limited number of supported types at here since it is just for a + // test relation and we do very basic testing at here. 
+ override protected def supportsDataType(dataType: DataType): Boolean = dataType match { + case _: BinaryType => false + // We are using a random data generator and the generated strings are not really valid strings. + case _: StringType => false + case _: BooleanType => false // see https://issues.apache.org/jira/browse/SPARK-10442 + case _: CalendarIntervalType => false + case _: DateType => false + case _: TimestampType => false + case _: ArrayType => false + case _: MapType => false + case _: StructType => false + case _: UserDefinedType[_] => false + case _ => true + } test("save()/load() - partitioned table - simple queries - partition columns in data") { withTempDir { file => @@ -44,9 +70,314 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) checkQueries( - read.format(dataSourceName) + hiveContext.read.format(dataSourceName) .option("dataSchema", dataSchemaWithPartition.json) .load(file.getCanonicalPath)) } } + + private var tempPath: File = _ + + private var partitionedDF: DataFrame = _ + + private val partitionedDataSchema: StructType = + new StructType() + .add("a", IntegerType) + .add("b", IntegerType) + .add("c", StringType) + + protected override def beforeAll(): Unit = { + this.tempPath = Utils.createTempDir() + + val df = sqlContext.range(10).select( + 'id cast IntegerType as 'a, + ('id cast IntegerType) * 2 as 'b, + concat(lit("val_"), 'id) as 'c + ) + + partitionedWriter(df).save(s"${tempPath.getCanonicalPath}/p=0") + partitionedWriter(df).save(s"${tempPath.getCanonicalPath}/p=1") + + partitionedDF = partitionedReader.load(tempPath.getCanonicalPath) + } + + override protected def afterAll(): Unit = { + Utils.deleteRecursively(tempPath) + } + + private def partitionedWriter(df: DataFrame) = + df.write.option("dataSchema", partitionedDataSchema.json).format(dataSourceName) + + private def partitionedReader = + sqlContext.read.option("dataSchema", partitionedDataSchema.json).format(dataSourceName) + + /** + * Constructs test cases that test column pruning and filter push-down. + * + * For filter push-down, the following filters are not pushed down. + * + * 1. Partitioning filters don't participate in filter push-down; they are handled separately in + * `DataSourceStrategy` + * + * 2. Catalyst filter `Expression`s that cannot be converted to data source `Filter`s are not + * pushed down (e.g. UDF and filters referencing multiple columns). + * + * 3. Catalyst filter `Expression`s that can be converted to data source `Filter`s but cannot be + * handled by the underlying data source are not pushed down (e.g. returned from + * `BaseRelation.unhandledFilters()`). + * + * Note that for [[SimpleTextRelation]], all data source [[Filter]]s other than [[GreaterThan]] + * are unhandled. We made this assumption in [[SimpleTextRelation.unhandledFilters()]] only + * for testing purposes. 
+ * + * @param projections Projection list of the query + * @param filter Filter condition of the query + * @param requiredColumns Expected names of required columns + * @param pushedFilters Expected data source [[Filter]]s that are pushed down + * @param inconvertibleFilters Expected Catalyst filter [[Expression]]s that cannot be converted + * to data source [[Filter]]s + * @param unhandledFilters Expected Catalyst filter [[Expression]]s that can be converted to data + * source [[Filter]]s but cannot be handled by the data source relation + * @param partitioningFilters Expected Catalyst filter [[Expression]]s that reference partition + * columns + * @param expectedRawScanAnswer Expected query result of the raw table scan returned by the data + * source relation + * @param expectedAnswer Expected query result of the full query + */ + def testPruningAndFiltering( + projections: Seq[Column], + filter: Column, + requiredColumns: Seq[String], + pushedFilters: Seq[Filter], + inconvertibleFilters: Seq[Column], + unhandledFilters: Seq[Column], + partitioningFilters: Seq[Column])( + expectedRawScanAnswer: => Seq[Row])( + expectedAnswer: => Seq[Row]): Unit = { + test(s"pruning and filtering: df.select(${projections.mkString(", ")}).where($filter)") { + val df = partitionedDF.where(filter).select(projections: _*) + val queryExecution = df.queryExecution + val executedPlan = queryExecution.executedPlan + + val rawScan = executedPlan.collect { + case p: PhysicalRDD => p + } match { + case Seq(scan) => scan + case _ => fail(s"More than one PhysicalRDD found\n$queryExecution") + } + + markup("Checking raw scan answer") + checkAnswer( + DataFrame(sqlContext, LogicalRDD(rawScan.output, rawScan.rdd)(sqlContext)), + expectedRawScanAnswer) + + markup("Checking full query answer") + checkAnswer(df, expectedAnswer) + + markup("Checking required columns") + assert(requiredColumns === SimpleTextRelation.requiredColumns) + + val nonPushedFilters = { + val boundFilters = executedPlan.collect { + case f: execution.Filter => f + } match { + case Nil => Nil + case Seq(f) => splitConjunctivePredicates(f.condition) + case _ => fail(s"More than one Filter found\n$queryExecution") + } + + // Unbind these bound filters so that we can easily compare them with expected results. + boundFilters.map { + _.transform { case a: AttributeReference => UnresolvedAttribute(a.name) } + }.toSet + } + + markup("Checking pushed filters") + assert(SimpleTextRelation.pushedFilters === pushedFilters.toSet) + + val expectedInconvertibleFilters = inconvertibleFilters.map(_.expr).toSet + val expectedUnhandledFilters = unhandledFilters.map(_.expr).toSet + val expectedPartitioningFilters = partitioningFilters.map(_.expr).toSet + + markup("Checking unhandled and inconvertible filters") + assert(expectedInconvertibleFilters ++ expectedUnhandledFilters === nonPushedFilters) + + markup("Checking partitioning filters") + val actualPartitioningFilters = splitConjunctivePredicates(filter.expr).filter { + _.references.contains(UnresolvedAttribute("p")) + }.toSet + + // Partitioning filters are handled separately and don't participate in filter push-down. So they
+ assert(expectedPartitioningFilters.intersect(nonPushedFilters).isEmpty) + assert(expectedPartitioningFilters === actualPartitioningFilters) + } + } + + testPruningAndFiltering( + projections = Seq('*), + filter = 'p > 0, + requiredColumns = Seq("a", "b", "c"), + pushedFilters = Nil, + inconvertibleFilters = Nil, + unhandledFilters = Nil, + partitioningFilters = Seq('p > 0) + ) { + Seq( + Row(0, 0, "val_0", 1), + Row(1, 2, "val_1", 1), + Row(2, 4, "val_2", 1), + Row(3, 6, "val_3", 1), + Row(4, 8, "val_4", 1), + Row(5, 10, "val_5", 1), + Row(6, 12, "val_6", 1), + Row(7, 14, "val_7", 1), + Row(8, 16, "val_8", 1), + Row(9, 18, "val_9", 1)) + } { + Seq( + Row(0, 0, "val_0", 1), + Row(1, 2, "val_1", 1), + Row(2, 4, "val_2", 1), + Row(3, 6, "val_3", 1), + Row(4, 8, "val_4", 1), + Row(5, 10, "val_5", 1), + Row(6, 12, "val_6", 1), + Row(7, 14, "val_7", 1), + Row(8, 16, "val_8", 1), + Row(9, 18, "val_9", 1)) + } + + testPruningAndFiltering( + projections = Seq('c, 'p), + filter = 'a < 3 && 'p > 0, + requiredColumns = Seq("c", "a"), + pushedFilters = Seq(LessThan("a", 3)), + inconvertibleFilters = Nil, + unhandledFilters = Seq('a < 3), + partitioningFilters = Seq('p > 0) + ) { + Seq( + Row("val_0", 1, 0), + Row("val_1", 1, 1), + Row("val_2", 1, 2), + Row("val_3", 1, 3), + Row("val_4", 1, 4), + Row("val_5", 1, 5), + Row("val_6", 1, 6), + Row("val_7", 1, 7), + Row("val_8", 1, 8), + Row("val_9", 1, 9)) + } { + Seq( + Row("val_0", 1), + Row("val_1", 1), + Row("val_2", 1)) + } + + testPruningAndFiltering( + projections = Seq('*), + filter = 'a > 8, + requiredColumns = Seq("a", "b", "c"), + pushedFilters = Seq(GreaterThan("a", 8)), + inconvertibleFilters = Nil, + unhandledFilters = Nil, + partitioningFilters = Nil + ) { + Seq( + Row(9, 18, "val_9", 0), + Row(9, 18, "val_9", 1)) + } { + Seq( + Row(9, 18, "val_9", 0), + Row(9, 18, "val_9", 1)) + } + + testPruningAndFiltering( + projections = Seq('b, 'p), + filter = 'a > 8, + requiredColumns = Seq("b"), + pushedFilters = Seq(GreaterThan("a", 8)), + inconvertibleFilters = Nil, + unhandledFilters = Nil, + partitioningFilters = Nil + ) { + Seq( + Row(18, 0), + Row(18, 1)) + } { + Seq( + Row(18, 0), + Row(18, 1)) + } + + testPruningAndFiltering( + projections = Seq('b, 'p), + filter = 'a > 8 && 'p > 0, + requiredColumns = Seq("b"), + pushedFilters = Seq(GreaterThan("a", 8)), + inconvertibleFilters = Nil, + unhandledFilters = Nil, + partitioningFilters = Seq('p > 0) + ) { + Seq( + Row(18, 1)) + } { + Seq( + Row(18, 1)) + } + + testPruningAndFiltering( + projections = Seq('b, 'p), + filter = 'c > "val_7" && 'b < 18 && 'p > 0, + requiredColumns = Seq("b"), + pushedFilters = Seq(GreaterThan("c", "val_7"), LessThan("b", 18)), + inconvertibleFilters = Nil, + unhandledFilters = Seq('b < 18), + partitioningFilters = Seq('p > 0) + ) { + Seq( + Row(16, 1), + Row(18, 1)) + } { + Seq( + Row(16, 1)) + } + + testPruningAndFiltering( + projections = Seq('b, 'p), + filter = 'a % 2 === 0 && 'c > "val_7" && 'b < 18 && 'p > 0, + requiredColumns = Seq("b", "a"), + pushedFilters = Seq(GreaterThan("c", "val_7"), LessThan("b", 18)), + inconvertibleFilters = Seq('a % 2 === 0), + unhandledFilters = Seq('b < 18), + partitioningFilters = Seq('p > 0) + ) { + Seq( + Row(16, 1, 8), + Row(18, 1, 9)) + } { + Seq( + Row(16, 1)) + } + + testPruningAndFiltering( + projections = Seq('b, 'p), + filter = 'a > 7 && 'a < 9, + requiredColumns = Seq("b", "a"), + pushedFilters = Seq(GreaterThan("a", 7), LessThan("a", 9)), + inconvertibleFilters = Nil, + unhandledFilters = Seq('a < 9), + 
partitioningFilters = Nil + ) { + Seq( + Row(16, 0, 8), + Row(16, 1, 8), + Row(18, 0, 9), + Row(18, 1, 9)) + } { + Seq( + Row(16, 0), + Row(16, 1)) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index e8141923a9b5..01960fd2901b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.sources import java.text.NumberFormat -import java.util.UUID import com.google.common.base.Objects import org.apache.hadoop.fs.{FileStatus, Path} @@ -26,11 +25,12 @@ import org.apache.hadoop.io.{NullWritable, Text} import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat} import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, expressions} import org.apache.spark.sql.types.{DataType, StructType} -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.{Row, SQLContext, sources} /** * A simple example [[HadoopFsRelationProvider]]. @@ -53,8 +53,10 @@ class AppendingTextOutputFormat(outputFile: Path) extends TextOutputFormat[NullW numberFormat.setGroupingUsed(false) override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val uniqueWriteJobId = context.getConfiguration.get("spark.sql.sources.writeJobUUID") - val split = context.getTaskAttemptID.getTaskID.getId + val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") + val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) + val split = taskAttemptId.getTaskID.getId val name = FileOutputFormat.getOutputName(context) new Path(outputFile, s"$name-${numberFormat.format(split)}-$uniqueWriteJobId") } @@ -65,7 +67,9 @@ class SimpleTextOutputWriter(path: String, context: TaskAttemptContext) extends new AppendingTextOutputFormat(new Path(path)).getRecordWriter(context) override def write(row: Row): Unit = { - val serialized = row.toSeq.map(_.toString).mkString(",") + val serialized = row.toSeq.map { v => + if (v == null) "" else v.toString + }.mkString(",") recordWriter.write(null, new Text(serialized)) } @@ -85,7 +89,7 @@ class SimpleTextRelation( override val userDefinedPartitionColumns: Option[StructType], parameters: Map[String, String])( @transient val sqlContext: SQLContext) - extends HadoopFsRelation { + extends HadoopFsRelation(parameters) { import sqlContext.sparkContext @@ -109,7 +113,8 @@ class SimpleTextRelation( val fields = dataSchema.map(_.dataType) sparkContext.textFile(inputStatuses.map(_.getPath).mkString(",")).map { record => - Row(record.split(",").zip(fields).map { case (value, dataType) => + Row(record.split(",", -1).zip(fields).map { case (v, dataType) => + val value = if (v == "") null else v // `Cast`ed values are always of Catalyst types (i.e. UTF8String instead of String, etc.) 
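The write and read paths above encode a null field as an empty string and rely on `split(",", -1)` to keep trailing empty fields. A minimal standalone sketch of that round trip (the `serialize`/`deserialize` helper names are illustrative, not part of the relation):

```scala
// Standalone sketch: round-tripping nulls through a comma-separated text
// format, mirroring the convention used in SimpleTextRelation above where a
// null field is written as an empty string and read back as null.
object NullRoundTripSketch {
  def serialize(row: Seq[Any]): String =
    row.map(v => if (v == null) "" else v.toString).mkString(",")

  def deserialize(line: String): Seq[String] =
    // The -1 limit keeps trailing empty fields; a plain split(",") would drop
    // them and shift columns whenever the last value is null.
    line.split(",", -1).map(v => if (v == "") null else v).toSeq

  def main(args: Array[String]): Unit = {
    val line = serialize(Seq(1, null, "val_1"))   // "1,,val_1"
    println(deserialize(line))                    // List(1, null, val_1)
    println(deserialize("1,,").length)            // 3, not 1
  }
}
```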
val catalystValue = Cast(Literal(value), dataType).eval() // Here we're converting Catalyst values to Scala values to test `needsConversion` @@ -118,6 +123,56 @@ class SimpleTextRelation( } } + override def buildScan( + requiredColumns: Array[String], + filters: Array[Filter], + inputFiles: Array[FileStatus]): RDD[Row] = { + + SimpleTextRelation.requiredColumns = requiredColumns + SimpleTextRelation.pushedFilters = filters.toSet + + val fields = this.dataSchema.map(_.dataType) + val inputAttributes = this.dataSchema.toAttributes + val outputAttributes = requiredColumns.flatMap(name => inputAttributes.find(_.name == name)) + val dataSchema = this.dataSchema + + val inputPaths = inputFiles.map(_.getPath).mkString(",") + sparkContext.textFile(inputPaths).mapPartitions { iterator => + // Constructs a filter predicate to simulate filter push-down + val predicate = { + val filterCondition: Expression = filters.collect { + // According to `unhandledFilters`, `SimpleTextRelation` only handles `GreaterThan` filter + case sources.GreaterThan(column, value) => + val dataType = dataSchema(column).dataType + val literal = Literal.create(value, dataType) + val attribute = inputAttributes.find(_.name == column).get + expressions.GreaterThan(attribute, literal) + }.reduceOption(expressions.And).getOrElse(Literal(true)) + InterpretedPredicate.create(filterCondition, inputAttributes) + } + + // Uses a simple projection to simulate column pruning + val projection = new InterpretedMutableProjection(outputAttributes, inputAttributes) + val toScala = { + val requiredSchema = StructType.fromAttributes(outputAttributes) + CatalystTypeConverters.createToScalaConverter(requiredSchema) + } + + iterator.map { record => + new GenericInternalRow(record.split(",", -1).zip(fields).map { + case (v, dataType) => + val value = if (v == "") null else v + // `Cast`ed values are always of internal types (e.g. UTF8String instead of String) + Cast(Literal(value), dataType).eval() + }) + }.filter { row => + predicate(row) + }.map { row => + toScala(projection(row)).asInstanceOf[Row] + } + } + } + override def prepareJobForWrite(job: Job): OutputWriterFactory = new OutputWriterFactory { job.setOutputFormatClass(classOf[TextOutputFormat[_, _]]) @@ -128,6 +183,23 @@ class SimpleTextRelation( new SimpleTextOutputWriter(path, context) } } + + // `SimpleTextRelation` only handles `GreaterThan` filter. This is used to test filter push-down + // and `BaseRelation.unhandledFilters()`. 
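The scan above only evaluates `GreaterThan` itself; every other pushed-down filter is reported back through `unhandledFilters`, so Spark re-applies those predicates on top of the scan. A small sketch of that contract, assuming Spark's `org.apache.spark.sql.sources` filter classes are on the classpath:

```scala
// Illustrative sketch of the unhandledFilters contract exercised here: a
// relation that handles GreaterThan itself hands every other pushed filter
// back to Spark, which then keeps a Filter node for it above the scan.
import org.apache.spark.sql.sources.{Filter, GreaterThan, LessThan}

object UnhandledFiltersSketch {
  // Mirrors the override below: only GreaterThan counts as handled.
  def unhandledFilters(filters: Array[Filter]): Array[Filter] =
    filters.filterNot(_.isInstanceOf[GreaterThan])

  def main(args: Array[String]): Unit = {
    val pushed: Array[Filter] = Array(GreaterThan("a", 8), LessThan("b", 18))
    // Only LessThan(b, 18) comes back, so Spark still evaluates it itself.
    println(unhandledFilters(pushed).toSeq)
  }
}
```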
+ override def unhandledFilters(filters: Array[Filter]): Array[Filter] = { + filters.filter { + case _: GreaterThan => false + case _ => true + } + } +} + +object SimpleTextRelation { + // Used to test column pruning + var requiredColumns: Seq[String] = Nil + + // Used to test filter push-down + var pushedFilters: Set[Filter] = Set.empty } /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index dd274023a1cf..665e87e3e335 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -27,20 +27,20 @@ import org.apache.parquet.hadoop.ParquetOutputCommitter import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql._ +import org.apache.spark.sql.execution.ConvertToUnsafe import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ -abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { - override lazy val sqlContext: SQLContext = TestHive - - import sqlContext.sql +abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with TestHiveSingleton { import sqlContext.implicits._ val dataSourceName: String + protected def supportsDataType(dataType: DataType): Boolean = true + val dataSchema = StructType( Seq( @@ -101,6 +101,60 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { } } + private val supportedDataTypes = Seq( + StringType, BinaryType, + NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), + DateType, TimestampType, + ArrayType(IntegerType), + MapType(StringType, LongType), + new StructType() + .add("f1", FloatType, nullable = true) + .add("f2", ArrayType(BooleanType, containsNull = true), nullable = true), + new MyDenseVectorUDT() + ).filter(supportsDataType) + + for (dataType <- supportedDataTypes) { + test(s"test all data types - $dataType") { + withTempPath { file => + val path = file.getCanonicalPath + + val dataGenerator = RandomDataGenerator.forType( + dataType = dataType, + nullable = true, + seed = Some(System.nanoTime()) + ).getOrElse { + fail(s"Failed to create data generator for schema $dataType") + } + + // Create a DF for the schema with random data. The index field is used to sort the + // DataFrame. This is a workaround for SPARK-10591. 
+ val schema = new StructType() + .add("index", IntegerType, nullable = false) + .add("col", dataType, nullable = true) + val rdd = sqlContext.sparkContext.parallelize((1 to 10).map(i => Row(i, dataGenerator()))) + val df = sqlContext.createDataFrame(rdd, schema).orderBy("index").coalesce(1) + + df.write + .mode("overwrite") + .format(dataSourceName) + .option("dataSchema", df.schema.json) + .save(path) + + val loadedDF = sqlContext + .read + .format(dataSourceName) + .option("dataSchema", df.schema.json) + .schema(df.schema) + .load(path) + .orderBy("index") + + checkAnswer(loadedDF, df) + } + } + } + test("save()/load() - non-partitioned table - Overwrite") { withTempPath { file => testDF.write.mode(SaveMode.Overwrite).format(dataSourceName).save(file.getCanonicalPath) @@ -299,6 +353,19 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { } } + test("saveAsTable()/load() - partitioned table - boolean type") { + sqlContext.range(2) + .select('id, ('id % 2 === 0).as("b")) + .write.partitionBy("b").saveAsTable("t") + + withTable("t") { + checkAnswer( + sqlContext.table("t").sort('id), + Row(0, true) :: Row(1, false) :: Nil + ) + } + } + test("saveAsTable()/load() - partitioned table - Overwrite") { partitionedTestDF.write .format(dataSourceName) @@ -419,6 +486,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { val df = sqlContext.read .format(dataSourceName) .option("dataSchema", dataSchema.json) + .option("basePath", file.getCanonicalPath) .load(s"${file.getCanonicalPath}/p1=*/p2=???") val expectedPaths = Set( @@ -433,7 +501,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { } val actualPaths = df.queryExecution.analyzed.collectFirst { - case LogicalRelation(relation: HadoopFsRelation) => + case LogicalRelation(relation: HadoopFsRelation, _) => relation.paths.toSet }.getOrElse { fail("Expect an FSBasedRelation, but none could be found") @@ -444,19 +512,39 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { } } - test("Partition column type casting") { + test("SPARK-9735 Partition column type casting") { withTempPath { file => - val input = partitionedTestDF.select('a, 'b, 'p1.cast(StringType).as('ps), 'p2) - - input - .write - .format(dataSourceName) - .mode(SaveMode.Overwrite) - .partitionBy("ps", "p2") - .saveAsTable("t") + val df = (for { + i <- 1 to 3 + p2 <- Seq("foo", "bar") + } yield (i, s"val_$i", 1.0d, p2, 123, 123.123f)).toDF("a", "b", "p1", "p2", "p3", "f") + + val input = df.select( + 'a, + 'b, + 'p1.cast(StringType).as('ps1), + 'p2, + 'p3.cast(FloatType).as('pf1), + 'f) withTempTable("t") { - checkAnswer(sqlContext.table("t"), input.collect()) + input + .write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .partitionBy("ps1", "p2", "pf1", "f") + .saveAsTable("t") + + input + .write + .format(dataSourceName) + .mode(SaveMode.Append) + .partitionBy("ps1", "p2", "pf1", "f") + .saveAsTable("t") + + val realData = input.collect() + + checkAnswer(sqlContext.table("t"), realData ++ realData) } } } @@ -503,17 +591,17 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { } test("SPARK-8578 specified custom output committer will not be used to append data") { - val clonedConf = new Configuration(configuration) + val clonedConf = new Configuration(hadoopConfiguration) try { val df = sqlContext.range(1, 10).toDF("i") withTempPath { dir => df.write.mode("append").format(dataSourceName).save(dir.getCanonicalPath) - configuration.set( + 
hadoopConfiguration.set( SQLConf.OUTPUT_COMMITTER_CLASS.key, classOf[AlwaysFailOutputCommitter].getName) // Since Parquet has its own output committer setting, also set it // to AlwaysFailParquetOutputCommitter at here. - configuration.set("spark.sql.parquet.output.committer.class", + hadoopConfiguration.set("spark.sql.parquet.output.committer.class", classOf[AlwaysFailParquetOutputCommitter].getName) // Because there data already exists, // this append should succeed because we will use the output committer associated @@ -532,12 +620,12 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { } } withTempPath { dir => - configuration.set( + hadoopConfiguration.set( SQLConf.OUTPUT_COMMITTER_CLASS.key, classOf[AlwaysFailOutputCommitter].getName) // Since Parquet has its own output committer setting, also set it // to AlwaysFailParquetOutputCommitter at here. - configuration.set("spark.sql.parquet.output.committer.class", + hadoopConfiguration.set("spark.sql.parquet.output.committer.class", classOf[AlwaysFailParquetOutputCommitter].getName) // Because there is no existing data, // this append will fail because AlwaysFailOutputCommitter is used when we do append @@ -548,8 +636,87 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { } } finally { // Hadoop 1 doesn't have `Configuration.unset` - configuration.clear() - clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) + } + } + + test("SPARK-8887: Explicitly define which data types can be used as dynamic partition columns") { + val df = Seq( + (1, "v1", Array(1, 2, 3), Map("k1" -> "v1"), Tuple2(1, "4")), + (2, "v2", Array(4, 5, 6), Map("k2" -> "v2"), Tuple2(2, "5")), + (3, "v3", Array(7, 8, 9), Map("k3" -> "v3"), Tuple2(3, "6"))).toDF("a", "b", "c", "d", "e") + withTempDir { file => + intercept[AnalysisException] { + df.write.format(dataSourceName).partitionBy("c", "d", "e").save(file.getCanonicalPath) + } + } + intercept[AnalysisException] { + df.write.format(dataSourceName).partitionBy("c", "d", "e").saveAsTable("t") + } + } + + test("SPARK-9899 Disable customized output committer when speculation is on") { + val clonedConf = new Configuration(hadoopConfiguration) + val speculationEnabled = + sqlContext.sparkContext.conf.getBoolean("spark.speculation", defaultValue = false) + + try { + withTempPath { dir => + // Enables task speculation + sqlContext.sparkContext.conf.set("spark.speculation", "true") + + // Uses a customized output committer which always fails + hadoopConfiguration.set( + SQLConf.OUTPUT_COMMITTER_CLASS.key, + classOf[AlwaysFailOutputCommitter].getName) + + // Code below shouldn't throw since customized output committer should be disabled. 
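These committer tests snapshot the Hadoop configuration up front and restore it entry by entry in a `finally` block, since Hadoop 1 has no `Configuration.unset`. A hedged sketch of that pattern (the configuration key used in `main` is only an example):

```scala
// Sketch of the snapshot/restore pattern used by the output-committer tests:
// copy the whole Configuration, mutate it freely, then clear it and put every
// original entry back.
import scala.collection.JavaConverters._
import org.apache.hadoop.conf.Configuration

object ConfSnapshotSketch {
  def withTemporaryConf[T](conf: Configuration)(body: => T): T = {
    val snapshot = new Configuration(conf)  // copy of all current entries
    try {
      body
    } finally {
      conf.clear()
      snapshot.asScala.foreach(entry => conf.set(entry.getKey, entry.getValue))
    }
  }

  def main(args: Array[String]): Unit = {
    val conf = new Configuration()
    withTemporaryConf(conf) {
      conf.set("example.output.committer.class", "some.failing.Committer")
    }
    // The temporary setting is gone once the block exits.
    println(conf.get("example.output.committer.class"))  // null
  }
}
```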
+ val df = sqlContext.range(10).coalesce(1) + df.write.format(dataSourceName).save(dir.getCanonicalPath) + checkAnswer( + sqlContext + .read + .format(dataSourceName) + .option("dataSchema", df.schema.json) + .load(dir.getCanonicalPath), + df) + } + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) + sqlContext.sparkContext.conf.set("spark.speculation", speculationEnabled.toString) + } + } + + test("HadoopFsRelation produces UnsafeRow") { + withTempTable("test_unsafe") { + withTempPath { dir => + val path = dir.getCanonicalPath + sqlContext.range(3).write.format(dataSourceName).save(path) + sqlContext.read + .format(dataSourceName) + .option("dataSchema", new StructType().add("id", LongType, nullable = false).json) + .load(path) + .registerTempTable("test_unsafe") + + val df = sqlContext.sql( + """SELECT COUNT(*) + |FROM test_unsafe a JOIN test_unsafe b + |WHERE a.id = b.id + """.stripMargin) + + val plan = df.queryExecution.executedPlan + + assert( + plan.collect { case plan: ConvertToUnsafe => plan }.isEmpty, + s"""Query plan shouldn't have ${classOf[ConvertToUnsafe].getSimpleName} node(s): + |$plan + """.stripMargin) + + checkAnswer(df, Row(3)) + } } } } diff --git a/streaming/pom.xml b/streaming/pom.xml index 697895e72fe5..435e16db13ab 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml @@ -47,6 +47,10 @@ test-jar test + + org.apache.spark + spark-test-tags_${scala.binary.version} + @@ -84,19 +88,14 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - org.seleniumhq.selenium selenium-java test - com.novocode - junit-interface + org.mockito + mockito-core test diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContextState.java b/streaming/src/main/java/org/apache/spark/streaming/StreamingContextState.java similarity index 100% rename from streaming/src/main/scala/org/apache/spark/streaming/StreamingContextState.java rename to streaming/src/main/java/org/apache/spark/streaming/StreamingContextState.java diff --git a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java index 8c0fdfa9c747..2803cad8095d 100644 --- a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java +++ b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java @@ -21,6 +21,8 @@ import java.util.Iterator; /** + * :: DeveloperApi :: + * * This abstract class represents a write ahead log (aka journal) that is used by Spark Streaming * to save the received data (by receivers) and associated metadata to a reliable storage, so that * they can be recovered after driver failures. See the Spark documentation for more information @@ -35,26 +37,26 @@ public abstract class WriteAheadLog { * ensure that the written data is durable and readable (using the record handle) by the * time this function returns. */ - abstract public WriteAheadLogRecordHandle write(ByteBuffer record, long time); + public abstract WriteAheadLogRecordHandle write(ByteBuffer record, long time); /** * Read a written record based on the given record handle. 
*/ - abstract public ByteBuffer read(WriteAheadLogRecordHandle handle); + public abstract ByteBuffer read(WriteAheadLogRecordHandle handle); /** * Read and return an iterator of all the records that have been written but not yet cleaned up. */ - abstract public Iterator readAll(); + public abstract Iterator readAll(); /** * Clean all the records that are older than the threshold time. It can wait for * the completion of the deletion. */ - abstract public void clean(long threshTime, boolean waitForCompletion); + public abstract void clean(long threshTime, boolean waitForCompletion); /** * Close this log and release any resources. */ - abstract public void close(); + public abstract void close(); } diff --git a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLogRecordHandle.java b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLogRecordHandle.java index 02324189b782..662889e779fb 100644 --- a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLogRecordHandle.java +++ b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLogRecordHandle.java @@ -18,6 +18,8 @@ package org.apache.spark.streaming.util; /** + * :: DeveloperApi :: + * * This abstract class represents a handle that refers to a record written in a * {@link org.apache.spark.streaming.util.WriteAheadLog WriteAheadLog}. * It must contain all the information necessary for the record to be read and returned by diff --git a/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.js b/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.js index 4886b68eeaf7..f82323a1cdd9 100644 --- a/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.js +++ b/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.js @@ -154,34 +154,40 @@ function drawTimeline(id, data, minX, maxX, minY, maxY, unitY, batchInterval) { var lastClickedBatch = null; var lastTimeout = null; + function isFailedBatch(batchTime) { + return $("#batch-" + batchTime).attr("isFailed") == "true"; + } + // Add points to the line. However, we make it invisible at first. But when the user moves mouse // over a point, it will be displayed with its detail. svg.selectAll(".point") .data(data) .enter().append("circle") - .attr("stroke", "white") // white and opacity = 0 make it invisible - .attr("fill", "white") - .attr("opacity", "0") + .attr("stroke", function(d) { return isFailedBatch(d.x) ? "red" : "white";}) // white and opacity = 0 make it invisible + .attr("fill", function(d) { return isFailedBatch(d.x) ? "red" : "white";}) + .attr("opacity", function(d) { return isFailedBatch(d.x) ? "1" : "0";}) .style("cursor", "pointer") .attr("cx", function(d) { return x(d.x); }) .attr("cy", function(d) { return y(d.y); }) - .attr("r", function(d) { return 3; }) + .attr("r", function(d) { return isFailedBatch(d.x) ? "2" : "0";}) .on('mouseover', function(d) { var tip = formatYValue(d.y) + " " + unitY + " at " + timeFormat[d.x]; showBootstrapTooltip(d3.select(this).node(), tip); // show the point d3.select(this) - .attr("stroke", "steelblue") - .attr("fill", "steelblue") - .attr("opacity", "1"); + .attr("stroke", function(d) { return isFailedBatch(d.x) ? "red" : "steelblue";}) + .attr("fill", function(d) { return isFailedBatch(d.x) ? 
"red" : "steelblue";}) + .attr("opacity", "1") + .attr("r", "3"); }) .on('mouseout', function() { hideBootstrapTooltip(d3.select(this).node()); // hide the point d3.select(this) - .attr("stroke", "white") - .attr("fill", "white") - .attr("opacity", "0"); + .attr("stroke", function(d) { return isFailedBatch(d.x) ? "red" : "white";}) + .attr("fill", function(d) { return isFailedBatch(d.x) ? "red" : "white";}) + .attr("opacity", function(d) { return isFailedBatch(d.x) ? "1" : "0";}) + .attr("r", function(d) { return isFailedBatch(d.x) ? "2" : "0";}); }) .on("click", function(d) { if (lastTimeout != null) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 2780d5b6adbc..d0046afdeb44 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -32,7 +32,7 @@ import org.apache.spark.streaming.scheduler.JobGenerator private[streaming] -class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) +class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) extends Logging with Serializable { val master = ssc.sc.master val framework = ssc.sc.appName @@ -49,11 +49,14 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) // Reload properties for the checkpoint application since user wants to set a reload property // or spark had changed its value and user wants to set it back. val propertiesToReload = List( + "spark.yarn.app.id", + "spark.yarn.app.attemptId", "spark.driver.host", "spark.driver.port", "spark.master", "spark.yarn.keytab", - "spark.yarn.principal") + "spark.yarn.principal", + "spark.ui.filters") val newSparkConf = new SparkConf(loadDefaults = false).setAll(sparkConfPairs) .remove("spark.driver.host") @@ -64,6 +67,16 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) newSparkConf.set(prop, value) } } + + // Add Yarn proxy filter specific configurations to the recovered SparkConf + val filter = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter" + val filterPrefix = s"spark.$filter.param." + newReloadConf.getAll.foreach { case (k, v) => + if (k.startsWith(filterPrefix) && k.length > filterPrefix.length) { + newSparkConf.set(k, v) + } + } + newSparkConf } @@ -174,16 +187,30 @@ class CheckpointWriter( private var stopped = false private var fs_ : FileSystem = _ + @volatile private var latestCheckpointTime: Time = null + class CheckpointWriteHandler( checkpointTime: Time, bytes: Array[Byte], clearCheckpointDataLater: Boolean) extends Runnable { def run() { + if (latestCheckpointTime == null || latestCheckpointTime < checkpointTime) { + latestCheckpointTime = checkpointTime + } var attempts = 0 val startTime = System.currentTimeMillis() val tempFile = new Path(checkpointDir, "temp") - val checkpointFile = Checkpoint.checkpointFile(checkpointDir, checkpointTime) - val backupFile = Checkpoint.checkpointBackupFile(checkpointDir, checkpointTime) + // We will do checkpoint when generating a batch and completing a batch. When the processing + // time of a batch is greater than the batch interval, checkpointing for completing an old + // batch may run after checkpointing of a new batch. If this happens, checkpoint of an old + // batch actually has the latest information, so we want to recovery from it. 
Therefore, we + * also use the latest checkpoint time as the file name, so that we can recover from the + * latest checkpoint file.
+ //
+ // Note: there is only one thread writing the checkpoint files, so we don't need to worry
+ // about thread-safety.
+ val checkpointFile = Checkpoint.checkpointFile(checkpointDir, latestCheckpointTime)
+ val backupFile = Checkpoint.checkpointBackupFile(checkpointDir, latestCheckpointTime)
 while (attempts < MAX_ATTEMPTS && !stopped) { attempts += 1 @@ -192,7 +219,9 @@ class CheckpointWriter( + "'") // Write checkpoint to temp file
- fs.delete(tempFile, true) // just in case it exists
+ if (fs.exists(tempFile)) {
+ fs.delete(tempFile, true) // just in case it exists
+ }
 val fos = fs.create(tempFile) Utils.tryWithSafeFinally { fos.write(bytes) @@ -203,7 +232,9 @@ class CheckpointWriter( // If the checkpoint file exists, back it up // If the backup exists as well, just delete it, otherwise rename will fail if (fs.exists(checkpointFile)) {
- fs.delete(backupFile, true) // just in case it exists
+ if (fs.exists(backupFile)) {
+ fs.delete(backupFile, true) // just in case it exists
+ }
 if (!fs.rename(checkpointFile, backupFile)) { logWarning("Could not rename " + checkpointFile + " to " + backupFile) } @@ -246,7 +277,7 @@ class CheckpointWriter( val bytes = Checkpoint.serialize(checkpoint, conf) executor.execute(new CheckpointWriteHandler( checkpoint.checkpointTime, bytes, clearCheckpointDataLater))
- logDebug("Submitted checkpoint of time " + checkpoint.checkpointTime + " writer queue")
+ logInfo("Submitted checkpoint of time " + checkpoint.checkpointTime + " writer queue")
 } catch { case rej: RejectedExecutionException => logError("Could not submit checkpoint task to the thread pool executor", rej) @@ -282,6 +313,15 @@ class CheckpointWriter( private[streaming] object CheckpointReader extends Logging {
+ /**
+ * Read checkpoint files present in the given checkpoint directory. If there are no checkpoint
+ * files, then return None, else try to return the latest valid checkpoint object. If no
+ * checkpoint files could be read correctly, then return None.
+ */
+ def read(checkpointDir: String): Option[Checkpoint] = {
+ read(checkpointDir, new SparkConf(), SparkHadoopUtil.get.conf, ignoreReadError = true)
+ }
+
+ /**
+ * Read checkpoint files present in the given checkpoint directory. If there are no checkpoint
+ * files, then return None, else try to return the latest valid checkpoint object.
If no @@ -306,7 +346,7 @@ object CheckpointReader extends Logging { // Try to read the checkpoint files in the order logInfo("Checkpoint files found: " + checkpointFiles.mkString(",")) - val compressionCodec = CompressionCodec.createCodec(conf) + var readError: Exception = null checkpointFiles.foreach(file => { logInfo("Attempting to load checkpoint from file " + file) try { @@ -317,13 +357,15 @@ object CheckpointReader extends Logging { return Some(cp) } catch { case e: Exception => + readError = e logWarning("Error reading checkpoint from file " + file, e) } }) // If none of checkpoint files could be read, then throw exception if (!ignoreReadError) { - throw new SparkException(s"Failed to read checkpoint from directory $checkpointPath") + throw new SparkException( + s"Failed to read checkpoint from directory $checkpointPath", readError) } None } @@ -335,7 +377,9 @@ class ObjectInputStreamWithLoader(inputStream_ : InputStream, loader: ClassLoade override def resolveClass(desc: ObjectStreamClass): Class[_] = { try { - return loader.loadClass(desc.getName()) + // scalastyle:off classforname + return Class.forName(desc.getName(), false, loader) + // scalastyle:on classforname } catch { case e: Exception => } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala index 40789c66f399..7829f5e88799 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala @@ -38,9 +38,7 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { def start(time: Time) { this.synchronized { - if (zeroTime != null) { - throw new Exception("DStream graph computation already started") - } + require(zeroTime == null, "DStream graph computation already started") zeroTime = time startTime = time outputStreams.foreach(_.initialize(zeroTime)) @@ -68,20 +66,16 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { def setBatchDuration(duration: Duration) { this.synchronized { - if (batchDuration != null) { - throw new Exception("Batch duration already set as " + batchDuration + - ". cannot set it again.") - } + require(batchDuration == null, + s"Batch duration already set as $batchDuration. Cannot set it again.") batchDuration = duration } } def remember(duration: Duration) { this.synchronized { - if (rememberDuration != null) { - throw new Exception("Remember duration already set as " + batchDuration + - ". cannot set it again.") - } + require(rememberDuration == null, + s"Remember duration already set as $rememberDuration. Cannot set it again.") rememberDuration = duration } } @@ -117,7 +111,11 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { def generateJobs(time: Time): Seq[Job] = { logDebug("Generating jobs for time " + time) val jobs = this.synchronized { - outputStreams.flatMap(outputStream => outputStream.generateJob(time)) + outputStreams.flatMap { outputStream => + val jobOption = outputStream.generateJob(time) + jobOption.foreach(_.setCallSite(outputStream.creationSite)) + jobOption + } } logDebug("Generated " + jobs.length + " jobs for time " + time) jobs @@ -169,7 +167,8 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { * safe remember duration which can be used to perform cleanup operations. 
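The remember-duration change in the following hunk drops null entries before taking the maximum, because an input stream that was never used still has a null `rememberDuration`. A minimal sketch of the idea, with `Duration` as a stand-in for `org.apache.spark.streaming.Duration`:

```scala
// Sketch: maxBy would throw a NullPointerException on a null rememberDuration,
// so unused input streams are filtered out before computing the maximum.
object RememberDurationSketch {
  final case class Duration(milliseconds: Long)

  def maxRememberDuration(rememberDurations: Seq[Duration]): Duration =
    rememberDurations.filter(_ != null).maxBy(_.milliseconds)

  def main(args: Array[String]): Unit = {
    val durations = Seq(Duration(20000), null, Duration(60000))
    println(maxRememberDuration(durations))  // Duration(60000)
  }
}
```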
*/ def getMaxInputStreamRememberDuration(): Duration = { - inputStreams.map { _.rememberDuration }.maxBy { _.milliseconds } + // If an InputDStream is not used, its `rememberDuration` will be null and we can ignore them + inputStreams.map(_.rememberDuration).filter(_ != null).maxBy(_.milliseconds) } @throws(classOf[IOException]) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/State.scala b/streaming/src/main/scala/org/apache/spark/streaming/State.scala new file mode 100644 index 000000000000..42424d67d883 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/State.scala @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import scala.language.implicitConversions + +import org.apache.spark.annotation.Experimental + +/** + * :: Experimental :: + * Abstract class for getting and updating the state in mapping function used in the `mapWithState` + * operation of a [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) + * or a [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java). + * + * Scala example of using `State`: + * {{{ + * // A mapping function that maintains an integer state and returns a String + * def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = { + * // Check if state exists + * if (state.exists) { + * val existingState = state.get // Get the existing state + * val shouldRemove = ... // Decide whether to remove the state + * if (shouldRemove) { + * state.remove() // Remove the state + * } else { + * val newState = ... + * state.update(newState) // Set the new state + * } + * } else { + * val initialState = ... + * state.update(initialState) // Set the initial state + * } + * ... 
// return something + * } + * + * }}} + * + * Java example of using `State`: + * {{{ + * // A mapping function that maintains an integer state and returns a String + * Function3, State, String> mappingFunction = + * new Function3, State, String>() { + * + * @Override + * public String call(String key, Optional value, State state) { + * if (state.exists()) { + * int existingState = state.get(); // Get the existing state + * boolean shouldRemove = ...; // Decide whether to remove the state + * if (shouldRemove) { + * state.remove(); // Remove the state + * } else { + * int newState = ...; + * state.update(newState); // Set the new state + * } + * } else { + * int initialState = ...; // Set the initial state + * state.update(initialState); + * } + * // return something + * } + * }; + * }}} + * + * @tparam S Class of the state + */ +@Experimental +sealed abstract class State[S] { + + /** Whether the state already exists */ + def exists(): Boolean + + /** + * Get the state if it exists, otherwise it will throw `java.util.NoSuchElementException`. + * Check with `exists()` whether the state exists or not before calling `get()`. + * + * @throws java.util.NoSuchElementException If the state does not exist. + */ + def get(): S + + /** + * Update the state with a new value. + * + * State cannot be updated if it has been already removed (that is, `remove()` has already been + * called) or it is going to be removed due to timeout (that is, `isTimingOut()` is `true`). + * + * @throws java.lang.IllegalArgumentException If the state has already been removed, or is + * going to be removed + */ + def update(newState: S): Unit + + /** + * Remove the state if it exists. + * + * State cannot be updated if it has been already removed (that is, `remove()` has already been + * called) or it is going to be removed due to timeout (that is, `isTimingOut()` is `true`). + */ + def remove(): Unit + + /** + * Whether the state is timing out and going to be removed by the system after the current batch. + * This timeout can occur if timeout duration has been specified in the + * [[org.apache.spark.streaming.StateSpec StatSpec]] and the key has not received any new data + * for that timeout duration. + */ + def isTimingOut(): Boolean + + /** + * Get the state as an [[scala.Option]]. It will be `Some(state)` if it exists, otherwise `None`. 
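A short usage sketch of the `State` API documented above, in the spirit of the scaladoc example; the running-sum logic and names are illustrative only:

```scala
// Sketch of a mapping function written against State: keep a running sum per
// key and emit the new total. Once a key is timing out, update() is not
// allowed, so the final total is emitted instead.
import org.apache.spark.streaming.State

object RunningSumSketch {
  def trackSum(key: String, value: Option[Int], state: State[Int]): Option[(String, Int)] = {
    if (state.isTimingOut()) {
      // The key has been idle past the configured timeout and its state is
      // about to be dropped; just report the last known total.
      Some(key -> state.get())
    } else {
      val sum = value.getOrElse(0) + state.getOption().getOrElse(0)
      state.update(sum)
      Some(key -> sum)
    }
  }
}
```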
+ */ + @inline final def getOption(): Option[S] = if (exists) Some(get()) else None + + @inline final override def toString(): String = { + getOption.map { _.toString }.getOrElse("") + } +} + +/** Internal implementation of the [[State]] interface */ +private[streaming] class StateImpl[S] extends State[S] { + + private var state: S = null.asInstanceOf[S] + private var defined: Boolean = false + private var timingOut: Boolean = false + private var updated: Boolean = false + private var removed: Boolean = false + + // ========= Public API ========= + override def exists(): Boolean = { + defined + } + + override def get(): S = { + if (defined) { + state + } else { + throw new NoSuchElementException("State is not set") + } + } + + override def update(newState: S): Unit = { + require(!removed, "Cannot update the state after it has been removed") + require(!timingOut, "Cannot update the state that is timing out") + state = newState + defined = true + updated = true + } + + override def isTimingOut(): Boolean = { + timingOut + } + + override def remove(): Unit = { + require(!timingOut, "Cannot remove the state that is timing out") + require(!removed, "Cannot remove the state that has already been removed") + defined = false + updated = false + removed = true + } + + // ========= Internal API ========= + + /** Whether the state has been marked for removing */ + def isRemoved(): Boolean = { + removed + } + + /** Whether the state has been been updated */ + def isUpdated(): Boolean = { + updated + } + + /** + * Update the internal data and flags in `this` to the given state option. + * This method allows `this` object to be reused across many state records. + */ + def wrap(optionalState: Option[S]): Unit = { + optionalState match { + case Some(newState) => + this.state = newState + defined = true + + case None => + this.state = null.asInstanceOf[S] + defined = false + } + timingOut = false + removed = false + updated = false + } + + /** + * Update the internal data and flags in `this` to the given state that is going to be timed out. + * This method allows `this` object to be reused across many state records. + */ + def wrapTimingOutState(newState: S): Unit = { + this.state = newState + defined = true + timingOut = true + removed = false + updated = false + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala new file mode 100644 index 000000000000..9f6f95223f61 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala @@ -0,0 +1,274 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming + +import com.google.common.base.Optional +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.{JavaPairRDD, JavaUtils} +import org.apache.spark.api.java.function.{Function3 => JFunction3, Function4 => JFunction4} +import org.apache.spark.rdd.RDD +import org.apache.spark.util.ClosureCleaner +import org.apache.spark.{HashPartitioner, Partitioner} + +/** + * :: Experimental :: + * Abstract class representing all the specifications of the DStream transformation + * `mapWithState` operation of a + * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a + * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java). + * Use the [[org.apache.spark.streaming.StateSpec StateSpec.apply()]] or + * [[org.apache.spark.streaming.StateSpec StateSpec.create()]] to create instances of + * this class. + * + * Example in Scala: + * {{{ + * // A mapping function that maintains an integer state and return a String + * def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = { + * // Use state.exists(), state.get(), state.update() and state.remove() + * // to manage state, and return the necessary string + * } + * + * val spec = StateSpec.function(mappingFunction).numPartitions(10) + * + * val mapWithStateDStream = keyValueDStream.mapWithState[StateType, MappedType](spec) + * }}} + * + * Example in Java: + * {{{ + * // A mapping function that maintains an integer state and return a string + * Function3, State, String> mappingFunction = + * new Function3, State, String>() { + * @Override + * public Optional call(Optional value, State state) { + * // Use state.exists(), state.get(), state.update() and state.remove() + * // to manage state, and return the necessary string + * } + * }; + * + * JavaMapWithStateDStream mapWithStateDStream = + * keyValueDStream.mapWithState(StateSpec.function(mappingFunc)); + * }}} + * + * @tparam KeyType Class of the state key + * @tparam ValueType Class of the state value + * @tparam StateType Class of the state data + * @tparam MappedType Class of the mapped elements + */ +@Experimental +sealed abstract class StateSpec[KeyType, ValueType, StateType, MappedType] extends Serializable { + + /** Set the RDD containing the initial states that will be used by `mapWithState` */ + def initialState(rdd: RDD[(KeyType, StateType)]): this.type + + /** Set the RDD containing the initial states that will be used by `mapWithState` */ + def initialState(javaPairRDD: JavaPairRDD[KeyType, StateType]): this.type + + /** + * Set the number of partitions by which the state RDDs generated by `mapWithState` + * will be partitioned. Hash partitioning will be used. + */ + def numPartitions(numPartitions: Int): this.type + + /** + * Set the partitioner by which the state RDDs generated by `mapWithState` will be + * be partitioned. + */ + def partitioner(partitioner: Partitioner): this.type + + /** + * Set the duration after which the state of an idle key will be removed. A key and its state is + * considered idle if it has not received any data for at least the given duration. The + * mapping function will be called one final time on the idle states that are going to be + * removed; [[org.apache.spark.streaming.State State.isTimingOut()]] set + * to `true` in that call. 
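A hedged usage sketch of the builder: assuming a `DStream[(String, Int)]` of per-batch word counts and Spark 1.6's `mapWithState`, the spec below keeps a running sum per key, hash-partitions the state, and times out idle keys:

```scala
// Sketch only: wiring a StateSpec into mapWithState. The mapping function,
// partition count, and timeout are illustrative values.
import org.apache.spark.streaming.{Seconds, State, StateSpec}
import org.apache.spark.streaming.dstream.DStream

object StateSpecUsageSketch {
  def withRunningSums(wordCounts: DStream[(String, Int)]): DStream[(String, Int)] = {
    val spec = StateSpec
      .function((key: String, value: Option[Int], state: State[Int]) => {
        val sum = value.getOrElse(0) + state.getOption().getOrElse(0)
        state.update(sum)
        (key, sum)          // the mapped element emitted for this record
      })
      .numPartitions(10)    // hash-partition the state RDDs
      .timeout(Seconds(3600))  // drop keys that stay idle for an hour

    wordCounts.mapWithState(spec)
  }
}
```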
+ */ + def timeout(idleDuration: Duration): this.type +} + + +/** + * :: Experimental :: + * Builder object for creating instances of [[org.apache.spark.streaming.StateSpec StateSpec]] + * that is used for specifying the parameters of the DStream transformation `mapWithState` + * that is used for specifying the parameters of the DStream transformation + * `mapWithState` operation of a + * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a + * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java). + * + * Example in Scala: + * {{{ + * // A mapping function that maintains an integer state and return a String + * def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = { + * // Use state.exists(), state.get(), state.update() and state.remove() + * // to manage state, and return the necessary string + * } + * + * val spec = StateSpec.function(mappingFunction).numPartitions(10) + * + * val mapWithStateDStream = keyValueDStream.mapWithState[StateType, MappedType](spec) + * }}} + * + * Example in Java: + * {{{ + * // A mapping function that maintains an integer state and return a string + * Function3, State, String> mappingFunction = + * new Function3, State, String>() { + * @Override + * public Optional call(Optional value, State state) { + * // Use state.exists(), state.get(), state.update() and state.remove() + * // to manage state, and return the necessary string + * } + * }; + * + * JavaMapWithStateDStream mapWithStateDStream = + * keyValueDStream.mapWithState(StateSpec.function(mappingFunc)); + *}}} + */ +@Experimental +object StateSpec { + /** + * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications + * of the `mapWithState` operation on a + * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]]. + * + * @param mappingFunction The function applied on every data item to manage the associated state + * and generate the mapped data + * @tparam KeyType Class of the keys + * @tparam ValueType Class of the values + * @tparam StateType Class of the states data + * @tparam MappedType Class of the mapped data + */ + def function[KeyType, ValueType, StateType, MappedType]( + mappingFunction: (Time, KeyType, Option[ValueType], State[StateType]) => Option[MappedType] + ): StateSpec[KeyType, ValueType, StateType, MappedType] = { + ClosureCleaner.clean(mappingFunction, checkSerializable = true) + new StateSpecImpl(mappingFunction) + } + + /** + * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications + * of the `mapWithState` operation on a + * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]]. 
+ * + * @param mappingFunction The function applied on every data item to manage the associated state + * and generate the mapped data + * @tparam ValueType Class of the values + * @tparam StateType Class of the states data + * @tparam MappedType Class of the mapped data + */ + def function[KeyType, ValueType, StateType, MappedType]( + mappingFunction: (KeyType, Option[ValueType], State[StateType]) => MappedType + ): StateSpec[KeyType, ValueType, StateType, MappedType] = { + ClosureCleaner.clean(mappingFunction, checkSerializable = true) + val wrappedFunction = + (time: Time, key: KeyType, value: Option[ValueType], state: State[StateType]) => { + Some(mappingFunction(key, value, state)) + } + new StateSpecImpl(wrappedFunction) + } + + /** + * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all + * the specifications of the `mapWithState` operation on a + * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]]. + * + * @param mappingFunction The function applied on every data item to manage the associated + * state and generate the mapped data + * @tparam KeyType Class of the keys + * @tparam ValueType Class of the values + * @tparam StateType Class of the states data + * @tparam MappedType Class of the mapped data + */ + def function[KeyType, ValueType, StateType, MappedType](mappingFunction: + JFunction4[Time, KeyType, Optional[ValueType], State[StateType], Optional[MappedType]]): + StateSpec[KeyType, ValueType, StateType, MappedType] = { + val wrappedFunc = (time: Time, k: KeyType, v: Option[ValueType], s: State[StateType]) => { + val t = mappingFunction.call(time, k, JavaUtils.optionToOptional(v), s) + Option(t.orNull) + } + StateSpec.function(wrappedFunc) + } + + /** + * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications + * of the `mapWithState` operation on a + * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]]. + * + * @param mappingFunction The function applied on every data item to manage the associated + * state and generate the mapped data + * @tparam ValueType Class of the values + * @tparam StateType Class of the states data + * @tparam MappedType Class of the mapped data + */ + def function[KeyType, ValueType, StateType, MappedType]( + mappingFunction: JFunction3[KeyType, Optional[ValueType], State[StateType], MappedType]): + StateSpec[KeyType, ValueType, StateType, MappedType] = { + val wrappedFunc = (k: KeyType, v: Option[ValueType], s: State[StateType]) => { + mappingFunction.call(k, Optional.fromNullable(v.get), s) + } + StateSpec.function(wrappedFunc) + } +} + + +/** Internal implementation of [[org.apache.spark.streaming.StateSpec]] interface. 
*/ +private[streaming] +case class StateSpecImpl[K, V, S, T]( + function: (Time, K, Option[V], State[S]) => Option[T]) extends StateSpec[K, V, S, T] { + + require(function != null) + + @volatile private var partitioner: Partitioner = null + @volatile private var initialStateRDD: RDD[(K, S)] = null + @volatile private var timeoutInterval: Duration = null + + override def initialState(rdd: RDD[(K, S)]): this.type = { + this.initialStateRDD = rdd + this + } + + override def initialState(javaPairRDD: JavaPairRDD[K, S]): this.type = { + this.initialStateRDD = javaPairRDD.rdd + this + } + + override def numPartitions(numPartitions: Int): this.type = { + this.partitioner(new HashPartitioner(numPartitions)) + this + } + + override def partitioner(partitioner: Partitioner): this.type = { + this.partitioner = partitioner + this + } + + override def timeout(interval: Duration): this.type = { + this.timeoutInterval = interval + this + } + + // ================= Private Methods ================= + + private[streaming] def getFunction(): (Time, K, Option[V], State[S]) => Option[T] = function + + private[streaming] def getInitialStateRDD(): Option[RDD[(K, S)]] = Option(initialStateRDD) + + private[streaming] def getPartitioner(): Option[Partitioner] = Option(partitioner) + + private[streaming] def getTimeoutInterval(): Option[Duration] = Option(timeoutInterval) +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 177e710ace54..b24c0d067bb0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -44,7 +44,7 @@ import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.receiver.{ActorReceiver, ActorSupervisorStrategy, Receiver} import org.apache.spark.streaming.scheduler.{JobScheduler, StreamingListener} import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab} -import org.apache.spark.util.{CallSite, Utils} +import org.apache.spark.util.{AsynchronousListenerBus, CallSite, ShutdownHookManager, ThreadUtils, Utils} /** * Main entry point for Spark Streaming functionality. It provides methods used to create @@ -200,6 +200,8 @@ class StreamingContext private[streaming] ( private val startSite = new AtomicReference[CallSite](null) + private[streaming] def getStartSite(): CallSite = startSite.get() + private var shutdownHookRef: AnyRef = _ conf.getOption("spark.streaming.checkpoint.directory").foreach(checkpoint) @@ -443,8 +445,6 @@ class StreamingContext private[streaming] ( } /** - * :: Experimental :: - * * Create an input stream that monitors a Hadoop-compatible filesystem * for new files and reads them as flat binary files, assuming a fixed length per record, * generating one byte array per record. Files must be written to the monitored directory @@ -457,7 +457,6 @@ class StreamingContext private[streaming] ( * @param directory HDFS directory to monitor for new file * @param recordLength length of each record in bytes */ - @Experimental def binaryRecordsStream( directory: String, recordLength: Int): DStream[Array[Byte]] = withNamedScope("binary records stream") { @@ -562,17 +561,25 @@ class StreamingContext private[streaming] ( ) } } + + if (Utils.isDynamicAllocationEnabled(sc.conf)) { + logWarning("Dynamic Allocation is enabled for this application. 
" + + "Enabling Dynamic allocation for Spark Streaming applications can cause data loss if " + + "Write Ahead Log is not enabled for non-replayable sources like Flume. " + + "See the programming guide for details on how to enable the Write Ahead Log") + } } /** * :: DeveloperApi :: * * Return the current state of the context. The context can be in three possible states - - * - StreamingContextState.INTIALIZED - The context has been created, but not been started yet. - * Input DStreams, transformations and output operations can be created on the context. - * - StreamingContextState.ACTIVE - The context has been started, and been not stopped. - * Input DStreams, transformations and output operations cannot be created on the context. - * - StreamingContextState.STOPPED - The context has been stopped and cannot be used any more. + * + * - StreamingContextState.INTIALIZED - The context has been created, but not been started yet. + * Input DStreams, transformations and output operations can be created on the context. + * - StreamingContextState.ACTIVE - The context has been started, and been not stopped. + * Input DStreams, transformations and output operations cannot be created on the context. + * - StreamingContextState.STOPPED - The context has been stopped and cannot be used any more. */ @DeveloperApi def getState(): StreamingContextState = synchronized { @@ -588,12 +595,20 @@ class StreamingContext private[streaming] ( state match { case INITIALIZED => startSite.set(DStream.getCreationSite()) - sparkContext.setCallSite(startSite.get) StreamingContext.ACTIVATION_LOCK.synchronized { StreamingContext.assertNoOtherContextIsActive() try { validate() - scheduler.start() + + // Start the streaming scheduler in a new thread, so that thread local properties + // like call sites and job groups can be reset without affecting those of the + // current thread. + ThreadUtils.runInNewThread("streaming-start") { + sparkContext.setCallSite(startSite.get) + sparkContext.clearJobGroup() + sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "false") + scheduler.start() + } state = StreamingContextState.ACTIVE } catch { case NonFatal(e) => @@ -604,7 +619,7 @@ class StreamingContext private[streaming] ( } StreamingContext.setActiveContext(this) } - shutdownHookRef = Utils.addShutdownHook( + shutdownHookRef = ShutdownHookManager.addShutdownHook( StreamingContext.SHUTDOWN_HOOK_PRIORITY)(stopOnShutdown) // Registering Streaming Metrics at the start of the StreamingContext assert(env.metricsSystem != null) @@ -618,6 +633,7 @@ class StreamingContext private[streaming] ( } } + /** * Wait for the execution to stop. Any exceptions that occurs during the execution * will be thrown in this thread. 
@@ -676,14 +692,28 @@ class StreamingContext private[streaming] ( * @param stopGracefully if true, stops gracefully by waiting for the processing of all * received data to be completed */ - def stop(stopSparkContext: Boolean, stopGracefully: Boolean): Unit = synchronized { - try { + def stop(stopSparkContext: Boolean, stopGracefully: Boolean): Unit = { + var shutdownHookRefToRemove: AnyRef = null + if (AsynchronousListenerBus.withinListenerThread.value) { + throw new SparkException("Cannot stop StreamingContext within listener thread of" + + " AsynchronousListenerBus") + } + synchronized { + // The state should always be Stopped after calling `stop()`, even if we haven't started yet state match { case INITIALIZED => logWarning("StreamingContext has not been started yet") + state = STOPPED case STOPPED => logWarning("StreamingContext has already been stopped") + state = STOPPED case ACTIVE => + // It's important that we don't set state = STOPPED until the very end of this case, + // since we need to ensure that we're still able to call `stop()` to recover from + // a partially-stopped StreamingContext which resulted from this `stop()` call being + // interrupted. See SPARK-12001 for more details. Because the body of this case can be + // executed twice in the case of a partial stop, all methods called here need to be + // idempotent. scheduler.stop(stopGracefully) // Removing the streamingSource to de-register the metrics on stop() env.metricsSystem.removeSource(streamingSource) @@ -691,17 +721,19 @@ class StreamingContext private[streaming] ( StreamingContext.setActiveContext(null) waiter.notifyStop() if (shutdownHookRef != null) { - Utils.removeShutdownHook(shutdownHookRef) + shutdownHookRefToRemove = shutdownHookRef + shutdownHookRef = null } logInfo("StreamingContext stopped successfully") + state = STOPPED } - // Even if we have already stopped, we still need to attempt to stop the SparkContext because - // a user might stop(stopSparkContext = false) and then call stop(stopSparkContext = true). - if (stopSparkContext) sc.stop() - } finally { - // The state should always be Stopped after calling `stop()`, even if we haven't started yet - state = STOPPED } + if (shutdownHookRefToRemove != null) { + ShutdownHookManager.removeShutdownHook(shutdownHookRefToRemove) + } + // Even if we have already stopped, we still need to attempt to stop the SparkContext because + // a user might stop(stopSparkContext = false) and then call stop(stopSparkContext = true). + if (stopSparkContext) sc.stop() } private def stopOnShutdown(): Unit = { @@ -725,7 +757,7 @@ object StreamingContext extends Logging { */ private val ACTIVATION_LOCK = new Object() - private val SHUTDOWN_HOOK_PRIORITY = Utils.SPARK_CONTEXT_SHUTDOWN_PRIORITY + 1 + private val SHUTDOWN_HOOK_PRIORITY = ShutdownHookManager.SPARK_CONTEXT_SHUTDOWN_PRIORITY + 1 private val activeContext = new AtomicReference[StreamingContext](null) @@ -735,7 +767,7 @@ object StreamingContext extends Logging { throw new IllegalStateException( "Only one StreamingContext may be started in this JVM. 
" + "Currently running StreamingContext was started at" + - activeContext.get.startSite.get.longForm) + activeContext.get.getStartSite().longForm) } } } @@ -860,12 +892,13 @@ object StreamingContext extends Logging { } private[streaming] def rddToFileName[T](prefix: String, suffix: String, time: Time): String = { - if (prefix == null) { - time.milliseconds.toString - } else if (suffix == null || suffix.length ==0) { - prefix + "-" + time.milliseconds - } else { - prefix + "-" + time.milliseconds + "." + suffix + var result = time.milliseconds.toString + if (prefix != null && prefix.length > 0) { + result = s"$prefix-$result" + } + if (suffix != null && suffix.length > 0) { + result = s"$result.$suffix" } + result } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index 808dcc174cf9..84acec7d8e33 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -17,18 +17,17 @@ package org.apache.spark.streaming.api.java -import java.util import java.lang.{Long => JLong} import java.util.{List => JList} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.ClassTag import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaRDDLike} import org.apache.spark.api.java.JavaPairRDD._ import org.apache.spark.api.java.JavaSparkContext.fakeClassTag -import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, Function3 => JFunction3, _} +import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, Function3 => JFunction3, VoidFunction => JVoidFunction, VoidFunction2 => JVoidFunction2, _} import org.apache.spark.rdd.RDD import org.apache.spark.streaming._ import org.apache.spark.streaming.api.java.JavaDStream._ @@ -145,8 +144,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * an array. 
*/ def glom(): JavaDStream[JList[T]] = - new JavaDStream(dstream.glom().map(x => new java.util.ArrayList[T](x.toSeq))) - + new JavaDStream(dstream.glom().map(_.toSeq.asJava)) /** Return the [[org.apache.spark.streaming.StreamingContext]] associated with this DStream */ @@ -191,7 +189,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T */ def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaDStream[U] = { def fn: (Iterator[T]) => Iterator[U] = { - (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + (x: Iterator[T]) => f.call(x.asJava).iterator().asScala } new JavaDStream(dstream.mapPartitions(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -204,7 +202,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]) : JavaPairDStream[K2, V2] = { def fn: (Iterator[T]) => Iterator[(K2, V2)] = { - (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + (x: Iterator[T]) => f.call(x.asJava).iterator().asScala } new JavaPairDStream(dstream.mapPartitions(fn))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -282,7 +280,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * Return all the RDDs between 'fromDuration' to 'toDuration' (both included) */ def slice(fromTime: Time, toTime: Time): JList[R] = { - new util.ArrayList(dstream.slice(fromTime, toTime).map(wrapRDD(_)).toSeq) + dstream.slice(fromTime, toTime).map(wrapRDD).asJava } /** @@ -291,7 +289,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * * @deprecated As of release 0.9.0, replaced by foreachRDD */ - @Deprecated + @deprecated("Use foreachRDD", "0.9.0") def foreach(foreachFunc: JFunction[R, Void]) { foreachRDD(foreachFunc) } @@ -302,7 +300,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * * @deprecated As of release 0.9.0, replaced by foreachRDD */ - @Deprecated + @deprecated("Use foreachRDD", "0.9.0") def foreach(foreachFunc: JFunction2[R, Time, Void]) { foreachRDD(foreachFunc) } @@ -310,7 +308,10 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T /** * Apply a function to each RDD in this DStream. This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. + * + * @deprecated As of release 1.6.0, replaced by foreachRDD(JVoidFunction) */ + @deprecated("Use foreachRDD(foreachFunc: JVoidFunction[R])", "1.6.0") def foreachRDD(foreachFunc: JFunction[R, Void]) { dstream.foreachRDD(rdd => foreachFunc.call(wrapRDD(rdd))) } @@ -318,11 +319,30 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T /** * Apply a function to each RDD in this DStream. This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. + * + * @deprecated As of release 1.6.0, replaced by foreachRDD(JVoidFunction2) */ + @deprecated("Use foreachRDD(foreachFunc: JVoidFunction2[R, Time])", "1.6.0") def foreachRDD(foreachFunc: JFunction2[R, Time, Void]) { dstream.foreachRDD((rdd, time) => foreachFunc.call(wrapRDD(rdd), time)) } + /** + * Apply a function to each RDD in this DStream. This is an output operator, so + * 'this' DStream will be registered as an output stream and therefore materialized. 
+ */ + def foreachRDD(foreachFunc: JVoidFunction[R]) { + dstream.foreachRDD(rdd => foreachFunc.call(wrapRDD(rdd))) + } + + /** + * Apply a function to each RDD in this DStream. This is an output operator, so + * 'this' DStream will be registered as an output stream and therefore materialized. + */ + def foreachRDD(foreachFunc: JVoidFunction2[R, Time]) { + dstream.foreachRDD((rdd, time) => foreachFunc.call(wrapRDD(rdd), time)) + } + /** * Return a new DStream in which each RDD is generated by applying a function * on each RDD of 'this' DStream. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaMapWithStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaMapWithStateDStream.scala new file mode 100644 index 000000000000..16c0d6fff822 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaMapWithStateDStream.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.api.java + +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaSparkContext +import org.apache.spark.streaming.dstream.MapWithStateDStream + +/** + * :: Experimental :: + * DStream representing the stream of data generated by `mapWithState` operation on a + * [[JavaPairDStream]]. Additionally, it also gives access to the + * stream of state snapshots, that is, the state data of all keys after a batch has updated them. 
+ * + * @tparam KeyType Class of the keys + * @tparam ValueType Class of the values + * @tparam StateType Class of the state data + * @tparam MappedType Class of the mapped data + */ +@Experimental +class JavaMapWithStateDStream[KeyType, ValueType, StateType, MappedType] private[streaming]( + dstream: MapWithStateDStream[KeyType, ValueType, StateType, MappedType]) + extends JavaDStream[MappedType](dstream)(JavaSparkContext.fakeClassTag) { + + def stateSnapshots(): JavaPairDStream[KeyType, StateType] = + new JavaPairDStream(dstream.stateSnapshots())( + JavaSparkContext.fakeClassTag, + JavaSparkContext.fakeClassTag) +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index 26383e420101..42ddd63f0f06 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.api.java import java.lang.{Long => JLong, Iterable => JIterable} import java.util.{List => JList} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.ClassTag @@ -28,8 +28,10 @@ import com.google.common.base.Optional import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapred.{JobConf, OutputFormat} import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} + import org.apache.spark.Partitioner -import org.apache.spark.api.java.{JavaPairRDD, JavaUtils} +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.{JavaPairRDD, JavaSparkContext, JavaUtils} import org.apache.spark.api.java.JavaPairRDD._ import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2} @@ -116,14 +118,14 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * generate the RDDs with Spark's default number of partitions. */ def groupByKey(): JavaPairDStream[K, JIterable[V]] = - dstream.groupByKey().mapValues(asJavaIterable _) + dstream.groupByKey().mapValues(_.asJava) /** * Return a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to * generate the RDDs with `numPartitions` partitions. */ def groupByKey(numPartitions: Int): JavaPairDStream[K, JIterable[V]] = - dstream.groupByKey(numPartitions).mapValues(asJavaIterable _) + dstream.groupByKey(numPartitions).mapValues(_.asJava) /** * Return a new DStream by applying `groupByKey` on each RDD of `this` DStream. @@ -132,7 +134,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * is used to control the partitioning of each RDD. */ def groupByKey(partitioner: Partitioner): JavaPairDStream[K, JIterable[V]] = - dstream.groupByKey(partitioner).mapValues(asJavaIterable _) + dstream.groupByKey(partitioner).mapValues(_.asJava) /** * Return a new DStream by applying `reduceByKey` to each RDD. 
The values for each key are @@ -197,7 +199,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * batching interval */ def groupByKeyAndWindow(windowDuration: Duration): JavaPairDStream[K, JIterable[V]] = { - dstream.groupByKeyAndWindow(windowDuration).mapValues(asJavaIterable _) + dstream.groupByKeyAndWindow(windowDuration).mapValues(_.asJava) } /** @@ -212,7 +214,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( */ def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration) : JavaPairDStream[K, JIterable[V]] = { - dstream.groupByKeyAndWindow(windowDuration, slideDuration).mapValues(asJavaIterable _) + dstream.groupByKeyAndWindow(windowDuration, slideDuration).mapValues(_.asJava) } /** @@ -228,8 +230,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( */ def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration, numPartitions: Int) : JavaPairDStream[K, JIterable[V]] = { - dstream.groupByKeyAndWindow(windowDuration, slideDuration, numPartitions) - .mapValues(asJavaIterable _) + dstream.groupByKeyAndWindow(windowDuration, slideDuration, numPartitions).mapValues(_.asJava) } /** @@ -248,8 +249,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( slideDuration: Duration, partitioner: Partitioner ): JavaPairDStream[K, JIterable[V]] = { - dstream.groupByKeyAndWindow(windowDuration, slideDuration, partitioner) - .mapValues(asJavaIterable _) + dstream.groupByKeyAndWindow(windowDuration, slideDuration, partitioner).mapValues(_.asJava) } /** @@ -428,10 +428,46 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( ) } + /** + * :: Experimental :: + * Return a [[JavaMapWithStateDStream]] by applying a function to every key-value element of + * `this` stream, while maintaining some state data for each unique key. The mapping function + * and other specification (e.g. partitioners, timeouts, initial state data, etc.) of this + * transformation can be specified using [[StateSpec]] class. The state data is accessible in + * as a parameter of type [[State]] in the mapping function. 
+ * + * Example of using `mapWithState`: + * {{{ + * // A mapping function that maintains an integer state and return a string + * Function3, State, String> mappingFunction = + * new Function3, State, String>() { + * @Override + * public Optional call(Optional value, State state) { + * // Use state.exists(), state.get(), state.update() and state.remove() + * // to manage state, and return the necessary string + * } + * }; + * + * JavaMapWithStateDStream mapWithStateDStream = + * keyValueDStream.mapWithState(StateSpec.function(mappingFunc)); + *}}} + * + * @param spec Specification of this transformation + * @tparam StateType Class type of the state data + * @tparam MappedType Class type of the mapped data + */ + @Experimental + def mapWithState[StateType, MappedType](spec: StateSpec[K, V, StateType, MappedType]): + JavaMapWithStateDStream[K, V, StateType, MappedType] = { + new JavaMapWithStateDStream(dstream.mapWithState(spec)( + JavaSparkContext.fakeClassTag, + JavaSparkContext.fakeClassTag)) + } + private def convertUpdateStateFunction[S](in: JFunction2[JList[V], Optional[S], Optional[S]]): (Seq[V], Option[S]) => Option[S] = { val scalaFunc: (Seq[V], Option[S]) => Option[S] = (values, state) => { - val list: JList[V] = values + val list: JList[V] = values.asJava val scalaState: Optional[S] = JavaUtils.optionToOptional(state) val result: Optional[S] = in.apply(list, scalaState) result.isPresent match { @@ -539,7 +575,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( */ def cogroup[W](other: JavaPairDStream[K, W]): JavaPairDStream[K, (JIterable[V], JIterable[W])] = { implicit val cm: ClassTag[W] = fakeClassTag - dstream.cogroup(other.dstream).mapValues(t => (asJavaIterable(t._1), asJavaIterable((t._2)))) + dstream.cogroup(other.dstream).mapValues(t => (t._1.asJava, t._2.asJava)) } /** @@ -551,8 +587,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( numPartitions: Int ): JavaPairDStream[K, (JIterable[V], JIterable[W])] = { implicit val cm: ClassTag[W] = fakeClassTag - dstream.cogroup(other.dstream, numPartitions) - .mapValues(t => (asJavaIterable(t._1), asJavaIterable((t._2)))) + dstream.cogroup(other.dstream, numPartitions).mapValues(t => (t._1.asJava, t._2.asJava)) } /** @@ -564,8 +599,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( partitioner: Partitioner ): JavaPairDStream[K, (JIterable[V], JIterable[W])] = { implicit val cm: ClassTag[W] = fakeClassTag - dstream.cogroup(other.dstream, partitioner) - .mapValues(t => (asJavaIterable(t._1), asJavaIterable((t._2)))) + dstream.cogroup(other.dstream, partitioner).mapValues(t => (t._1.asJava, t._2.asJava)) } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 35cc3ce5cf46..8f21c79a760c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -21,7 +21,7 @@ import java.lang.{Boolean => JBoolean} import java.io.{Closeable, InputStream} import java.util.{List => JList, Map => JMap} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import akka.actor.{Props, SupervisorStrategy} @@ -115,7 +115,13 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { sparkHome: String, jars: Array[String], environment: JMap[String, String]) = - 
this(new StreamingContext(master, appName, batchDuration, sparkHome, jars, environment)) + this(new StreamingContext( + master, + appName, + batchDuration, + sparkHome, + jars, + environment.asScala)) /** * Create a JavaStreamingContext using an existing JavaSparkContext. @@ -197,7 +203,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { converter: JFunction[InputStream, java.lang.Iterable[T]], storageLevel: StorageLevel) : JavaReceiverInputDStream[T] = { - def fn: (InputStream) => Iterator[T] = (x: InputStream) => converter.call(x).toIterator + def fn: (InputStream) => Iterator[T] = (x: InputStream) => converter.call(x).iterator().asScala implicit val cmt: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] ssc.socketStream(hostname, port, fn, storageLevel) @@ -216,8 +222,6 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { } /** - * :: Experimental :: - * * Create an input stream that monitors a Hadoop-compatible filesystem * for new files and reads them as flat binary files with fixed record lengths, * yielding byte arrays @@ -228,7 +232,6 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * @param directory HDFS directory to monitor for new files * @param recordLength The length at which to split the records */ - @Experimental def binaryRecordsStream(directory: String, recordLength: Int): JavaDStream[Array[Byte]] = { ssc.binaryRecordsStream(directory, recordLength) } @@ -432,7 +435,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] val sQueue = new scala.collection.mutable.Queue[RDD[T]] - sQueue.enqueue(queue.map(_.rdd).toSeq: _*) + sQueue.enqueue(queue.asScala.map(_.rdd).toSeq: _*) ssc.queueStream(sQueue) } @@ -456,7 +459,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] val sQueue = new scala.collection.mutable.Queue[RDD[T]] - sQueue.enqueue(queue.map(_.rdd).toSeq: _*) + sQueue.enqueue(queue.asScala.map(_.rdd).toSeq: _*) ssc.queueStream(sQueue, oneAtATime) } @@ -481,7 +484,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] val sQueue = new scala.collection.mutable.Queue[RDD[T]] - sQueue.enqueue(queue.map(_.rdd).toSeq: _*) + sQueue.enqueue(queue.asScala.map(_.rdd).toSeq: _*) ssc.queueStream(sQueue, oneAtATime, defaultRDD.rdd) } @@ -500,7 +503,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Create a unified DStream from multiple DStreams of the same type and same slide duration. 
*/ def union[T](first: JavaDStream[T], rest: JList[JavaDStream[T]]): JavaDStream[T] = { - val dstreams: Seq[DStream[T]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.dstream) + val dstreams: Seq[DStream[T]] = (Seq(first) ++ rest.asScala).map(_.dstream) implicit val cm: ClassTag[T] = first.classTag ssc.union(dstreams)(cm) } @@ -512,7 +515,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { first: JavaPairDStream[K, V], rest: JList[JavaPairDStream[K, V]] ): JavaPairDStream[K, V] = { - val dstreams: Seq[DStream[(K, V)]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.dstream) + val dstreams: Seq[DStream[(K, V)]] = (Seq(first) ++ rest.asScala).map(_.dstream) implicit val cm: ClassTag[(K, V)] = first.classTag implicit val kcm: ClassTag[K] = first.kManifest implicit val vcm: ClassTag[V] = first.vManifest @@ -534,12 +537,11 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { ): JavaDStream[T] = { implicit val cmt: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] - val scalaDStreams = dstreams.map(_.dstream).toSeq val scalaTransformFunc = (rdds: Seq[RDD[_]], time: Time) => { - val jrdds = rdds.map(rdd => JavaRDD.fromRDD[AnyRef](rdd.asInstanceOf[RDD[AnyRef]])).toList + val jrdds = rdds.map(JavaRDD.fromRDD(_)).asJava transformFunc.call(jrdds, time).rdd } - ssc.transform(scalaDStreams, scalaTransformFunc) + ssc.transform(dstreams.asScala.map(_.dstream).toSeq, scalaTransformFunc) } /** @@ -559,12 +561,11 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]] implicit val cmv: ClassTag[V] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]] - val scalaDStreams = dstreams.map(_.dstream).toSeq val scalaTransformFunc = (rdds: Seq[RDD[_]], time: Time) => { - val jrdds = rdds.map(rdd => JavaRDD.fromRDD[AnyRef](rdd.asInstanceOf[RDD[AnyRef]])).toList + val jrdds = rdds.map(JavaRDD.fromRDD(_)).asJava transformFunc.call(jrdds, time).rdd } - ssc.transform(scalaDStreams, scalaTransformFunc) + ssc.transform(dstreams.asScala.map(_.dstream).toSeq, scalaTransformFunc) } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListener.scala new file mode 100644 index 000000000000..7bfd6bd5af75 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListener.scala @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.api.java + +import org.apache.spark.streaming.Time +import org.apache.spark.streaming.scheduler.StreamingListener + +private[streaming] trait PythonStreamingListener{ + + /** Called when a receiver has been started */ + def onReceiverStarted(receiverStarted: JavaStreamingListenerReceiverStarted) { } + + /** Called when a receiver has reported an error */ + def onReceiverError(receiverError: JavaStreamingListenerReceiverError) { } + + /** Called when a receiver has been stopped */ + def onReceiverStopped(receiverStopped: JavaStreamingListenerReceiverStopped) { } + + /** Called when a batch of jobs has been submitted for processing. */ + def onBatchSubmitted(batchSubmitted: JavaStreamingListenerBatchSubmitted) { } + + /** Called when processing of a batch of jobs has started. */ + def onBatchStarted(batchStarted: JavaStreamingListenerBatchStarted) { } + + /** Called when processing of a batch of jobs has completed. */ + def onBatchCompleted(batchCompleted: JavaStreamingListenerBatchCompleted) { } + + /** Called when processing of a job of a batch has started. */ + def onOutputOperationStarted( + outputOperationStarted: JavaStreamingListenerOutputOperationStarted) { } + + /** Called when processing of a job of a batch has completed. */ + def onOutputOperationCompleted( + outputOperationCompleted: JavaStreamingListenerOutputOperationCompleted) { } +} + +private[streaming] class PythonStreamingListenerWrapper(listener: PythonStreamingListener) + extends JavaStreamingListener { + + /** Called when a receiver has been started */ + override def onReceiverStarted(receiverStarted: JavaStreamingListenerReceiverStarted): Unit = { + listener.onReceiverStarted(receiverStarted) + } + + /** Called when a receiver has reported an error */ + override def onReceiverError(receiverError: JavaStreamingListenerReceiverError): Unit = { + listener.onReceiverError(receiverError) + } + + /** Called when a receiver has been stopped */ + override def onReceiverStopped(receiverStopped: JavaStreamingListenerReceiverStopped): Unit = { + listener.onReceiverStopped(receiverStopped) + } + + /** Called when a batch of jobs has been submitted for processing. */ + override def onBatchSubmitted(batchSubmitted: JavaStreamingListenerBatchSubmitted): Unit = { + listener.onBatchSubmitted(batchSubmitted) + } + + /** Called when processing of a batch of jobs has started. */ + override def onBatchStarted(batchStarted: JavaStreamingListenerBatchStarted): Unit = { + listener.onBatchStarted(batchStarted) + } + + /** Called when processing of a batch of jobs has completed. */ + override def onBatchCompleted(batchCompleted: JavaStreamingListenerBatchCompleted): Unit = { + listener.onBatchCompleted(batchCompleted) + } + + /** Called when processing of a job of a batch has started. */ + override def onOutputOperationStarted( + outputOperationStarted: JavaStreamingListenerOutputOperationStarted): Unit = { + listener.onOutputOperationStarted(outputOperationStarted) + } + + /** Called when processing of a job of a batch has completed. */ + override def onOutputOperationCompleted( + outputOperationCompleted: JavaStreamingListenerOutputOperationCompleted): Unit = { + listener.onOutputOperationCompleted(outputOperationCompleted) + } +} + +/** + * A listener interface for receiving information about an ongoing streaming computation. 
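The `PythonStreamingListener` and its wrapper above mirror, callback for callback, the existing Scala `StreamingListener` trait. As a rough illustration (not part of this patch), a Scala listener that logs batch latency could be registered on a `StreamingContext` like so:

```scala
import org.apache.spark.streaming.scheduler.{StreamingListener, StreamingListenerBatchCompleted}

// Sketch of a custom listener built on the public Scala StreamingListener API,
// which the Java/Python-facing listener classes in this file mirror.
class BatchLatencyLogger extends StreamingListener {
  override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = {
    val info = batchCompleted.batchInfo
    // totalDelay is an Option[Long] in the Scala API; the Java-facing classes flatten it to -1
    println(s"Batch ${info.batchTime}: total delay ${info.totalDelay.getOrElse(-1L)} ms")
  }
}

// ssc.addStreamingListener(new BatchLatencyLogger())   // ssc: StreamingContext
```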
+ */ +private[streaming] class JavaStreamingListener { + + /** Called when a receiver has been started */ + def onReceiverStarted(receiverStarted: JavaStreamingListenerReceiverStarted): Unit = { } + + /** Called when a receiver has reported an error */ + def onReceiverError(receiverError: JavaStreamingListenerReceiverError): Unit = { } + + /** Called when a receiver has been stopped */ + def onReceiverStopped(receiverStopped: JavaStreamingListenerReceiverStopped): Unit = { } + + /** Called when a batch of jobs has been submitted for processing. */ + def onBatchSubmitted(batchSubmitted: JavaStreamingListenerBatchSubmitted): Unit = { } + + /** Called when processing of a batch of jobs has started. */ + def onBatchStarted(batchStarted: JavaStreamingListenerBatchStarted): Unit = { } + + /** Called when processing of a batch of jobs has completed. */ + def onBatchCompleted(batchCompleted: JavaStreamingListenerBatchCompleted): Unit = { } + + /** Called when processing of a job of a batch has started. */ + def onOutputOperationStarted( + outputOperationStarted: JavaStreamingListenerOutputOperationStarted): Unit = { } + + /** Called when processing of a job of a batch has completed. */ + def onOutputOperationCompleted( + outputOperationCompleted: JavaStreamingListenerOutputOperationCompleted): Unit = { } +} + +/** + * Base trait for events related to JavaStreamingListener + */ +private[streaming] sealed trait JavaStreamingListenerEvent + +private[streaming] class JavaStreamingListenerBatchSubmitted(val batchInfo: JavaBatchInfo) + extends JavaStreamingListenerEvent + +private[streaming] class JavaStreamingListenerBatchCompleted(val batchInfo: JavaBatchInfo) + extends JavaStreamingListenerEvent + +private[streaming] class JavaStreamingListenerBatchStarted(val batchInfo: JavaBatchInfo) + extends JavaStreamingListenerEvent + +private[streaming] class JavaStreamingListenerOutputOperationStarted( + val outputOperationInfo: JavaOutputOperationInfo) extends JavaStreamingListenerEvent + +private[streaming] class JavaStreamingListenerOutputOperationCompleted( + val outputOperationInfo: JavaOutputOperationInfo) extends JavaStreamingListenerEvent + +private[streaming] class JavaStreamingListenerReceiverStarted(val receiverInfo: JavaReceiverInfo) + extends JavaStreamingListenerEvent + +private[streaming] class JavaStreamingListenerReceiverError(val receiverInfo: JavaReceiverInfo) + extends JavaStreamingListenerEvent + +private[streaming] class JavaStreamingListenerReceiverStopped(val receiverInfo: JavaReceiverInfo) + extends JavaStreamingListenerEvent + +/** + * Class having information on batches. + * + * @param batchTime Time of the batch + * @param streamIdToInputInfo A map of input stream id to its input info + * @param submissionTime Clock time of when jobs of this batch was submitted to the streaming + * scheduler queue + * @param processingStartTime Clock time of when the first job of this batch started processing. + * `-1` means the batch has not yet started + * @param processingEndTime Clock time of when the last job of this batch finished processing. `-1` + * means the batch has not yet completed. + * @param schedulingDelay Time taken for the first job of this batch to start processing from the + * time this batch was submitted to the streaming scheduler. Essentially, it + * is `processingStartTime` - `submissionTime`. `-1` means the batch has not + * yet started + * @param processingDelay Time taken for the all jobs of this batch to finish processing from the + * time they started processing. 
Essentially, it is + * `processingEndTime` - `processingStartTime`. `-1` means the batch has not + * yet completed. + * @param totalDelay Time taken for all the jobs of this batch to finish processing from the time + * they were submitted. Essentially, it is `processingDelay` + `schedulingDelay`. + * `-1` means the batch has not yet completed. + * @param numRecords The number of recorders received by the receivers in this batch + * @param outputOperationInfos The output operations in this batch + */ +private[streaming] case class JavaBatchInfo( + batchTime: Time, + streamIdToInputInfo: java.util.Map[Int, JavaStreamInputInfo], + submissionTime: Long, + processingStartTime: Long, + processingEndTime: Long, + schedulingDelay: Long, + processingDelay: Long, + totalDelay: Long, + numRecords: Long, + outputOperationInfos: java.util.Map[Int, JavaOutputOperationInfo]) + +/** + * Track the information of input stream at specified batch time. + * + * @param inputStreamId the input stream id + * @param numRecords the number of records in a batch + * @param metadata metadata for this batch. It should contain at least one standard field named + * "Description" which maps to the content that will be shown in the UI. + * @param metadataDescription description of this input stream + */ +private[streaming] case class JavaStreamInputInfo( + inputStreamId: Int, + numRecords: Long, + metadata: java.util.Map[String, Any], + metadataDescription: String) + +/** + * Class having information about a receiver + */ +private[streaming] case class JavaReceiverInfo( + streamId: Int, + name: String, + active: Boolean, + location: String, + executorId: String, + lastErrorMessage: String, + lastError: String, + lastErrorTime: Long) + +/** + * Class having information on output operations. + * + * @param batchTime Time of the batch + * @param id Id of this output operation. Different output operations have different ids in a batch. + * @param name The name of this output operation. + * @param description The description of this output operation. + * @param startTime Clock time of when the output operation started processing. `-1` means the + * output operation has not yet started + * @param endTime Clock time of when the output operation started processing. `-1` means the output + * operation has not yet completed + * @param failureReason Failure reason if this output operation fails. If the output operation is + * successful, this field is `null`. + */ +private[streaming] case class JavaOutputOperationInfo( + batchTime: Time, + id: Int, + name: String, + description: String, + startTime: Long, + endTime: Long, + failureReason: String) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapper.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapper.scala new file mode 100644 index 000000000000..b109b9f1cbea --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapper.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.api.java + +import scala.collection.JavaConverters._ + +import org.apache.spark.streaming.scheduler._ + +/** + * A wrapper to convert a [[JavaStreamingListener]] to a [[StreamingListener]]. + */ +private[streaming] class JavaStreamingListenerWrapper(javaStreamingListener: JavaStreamingListener) + extends StreamingListener { + + private def toJavaReceiverInfo(receiverInfo: ReceiverInfo): JavaReceiverInfo = { + JavaReceiverInfo( + receiverInfo.streamId, + receiverInfo.name, + receiverInfo.active, + receiverInfo.location, + receiverInfo.executorId, + receiverInfo.lastErrorMessage, + receiverInfo.lastError, + receiverInfo.lastErrorTime + ) + } + + private def toJavaStreamInputInfo(streamInputInfo: StreamInputInfo): JavaStreamInputInfo = { + JavaStreamInputInfo( + streamInputInfo.inputStreamId, + streamInputInfo.numRecords: Long, + streamInputInfo.metadata.asJava, + streamInputInfo.metadataDescription.orNull + ) + } + + private def toJavaOutputOperationInfo( + outputOperationInfo: OutputOperationInfo): JavaOutputOperationInfo = { + JavaOutputOperationInfo( + outputOperationInfo.batchTime, + outputOperationInfo.id, + outputOperationInfo.name, + outputOperationInfo.description: String, + outputOperationInfo.startTime.getOrElse(-1), + outputOperationInfo.endTime.getOrElse(-1), + outputOperationInfo.failureReason.orNull + ) + } + + private def toJavaBatchInfo(batchInfo: BatchInfo): JavaBatchInfo = { + JavaBatchInfo( + batchInfo.batchTime, + batchInfo.streamIdToInputInfo.mapValues(toJavaStreamInputInfo(_)).asJava, + batchInfo.submissionTime, + batchInfo.processingStartTime.getOrElse(-1), + batchInfo.processingEndTime.getOrElse(-1), + batchInfo.schedulingDelay.getOrElse(-1), + batchInfo.processingDelay.getOrElse(-1), + batchInfo.totalDelay.getOrElse(-1), + batchInfo.numRecords, + batchInfo.outputOperationInfos.mapValues(toJavaOutputOperationInfo(_)).asJava + ) + } + + override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted): Unit = { + javaStreamingListener.onReceiverStarted( + new JavaStreamingListenerReceiverStarted(toJavaReceiverInfo(receiverStarted.receiverInfo))) + } + + override def onReceiverError(receiverError: StreamingListenerReceiverError): Unit = { + javaStreamingListener.onReceiverError( + new JavaStreamingListenerReceiverError(toJavaReceiverInfo(receiverError.receiverInfo))) + } + + override def onReceiverStopped(receiverStopped: StreamingListenerReceiverStopped): Unit = { + javaStreamingListener.onReceiverStopped( + new JavaStreamingListenerReceiverStopped(toJavaReceiverInfo(receiverStopped.receiverInfo))) + } + + override def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted): Unit = { + javaStreamingListener.onBatchSubmitted( + new JavaStreamingListenerBatchSubmitted(toJavaBatchInfo(batchSubmitted.batchInfo))) + } + + override def onBatchStarted(batchStarted: StreamingListenerBatchStarted): Unit = { + javaStreamingListener.onBatchStarted( + new JavaStreamingListenerBatchStarted(toJavaBatchInfo(batchStarted.batchInfo))) + } + + override def onBatchCompleted(batchCompleted: 
StreamingListenerBatchCompleted): Unit = { + javaStreamingListener.onBatchCompleted( + new JavaStreamingListenerBatchCompleted(toJavaBatchInfo(batchCompleted.batchInfo))) + } + + override def onOutputOperationStarted( + outputOperationStarted: StreamingListenerOutputOperationStarted): Unit = { + javaStreamingListener.onOutputOperationStarted(new JavaStreamingListenerOutputOperationStarted( + toJavaOutputOperationInfo(outputOperationStarted.outputOperationInfo))) + } + + override def onOutputOperationCompleted( + outputOperationCompleted: StreamingListenerOutputOperationCompleted): Unit = { + javaStreamingListener.onOutputOperationCompleted( + new JavaStreamingListenerOutputOperationCompleted( + toJavaOutputOperationInfo(outputOperationCompleted.outputOperationInfo))) + } + +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index d06401245ff1..056248ccc7bc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -20,14 +20,14 @@ package org.apache.spark.streaming.api.python import java.io.{ObjectInputStream, ObjectOutputStream} import java.lang.reflect.Proxy import java.util.{ArrayList => JArrayList, List => JList} -import scala.collection.JavaConversions._ + import scala.collection.JavaConverters._ import scala.language.existentials import py4j.GatewayServer +import org.apache.spark.SparkException import org.apache.spark.api.java._ -import org.apache.spark.api.python._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Interval, Duration, Time} @@ -41,6 +41,13 @@ import org.apache.spark.util.Utils */ private[python] trait PythonTransformFunction { def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]] + + /** + * Get the failure, if any, in the last call to `call`. + * + * @return the failure message if there was a failure, or `null` if there was no failure. + */ + def getLastFailure: String } /** @@ -49,6 +56,13 @@ private[python] trait PythonTransformFunction { private[python] trait PythonTransformFunctionSerializer { def dumps(id: String): Array[Byte] def loads(bytes: Array[Byte]): PythonTransformFunction + + /** + * Get the failure, if any, in the last call to `dumps` or `loads`. + * + * @return the failure message if there was a failure, or `null` if there was no failure. 
+ */ + def getLastFailure: String } /** @@ -60,18 +74,27 @@ private[python] class TransformFunction(@transient var pfunc: PythonTransformFun extends function.Function2[JList[JavaRDD[_]], Time, JavaRDD[Array[Byte]]] { def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { - Option(pfunc.call(time.milliseconds, List(rdd.map(JavaRDD.fromRDD(_)).orNull).asJava)) - .map(_.rdd) + val rdds = List(rdd.map(JavaRDD.fromRDD(_)).orNull).asJava + Option(callPythonTransformFunction(time.milliseconds, rdds)).map(_.rdd) } def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { val rdds = List(rdd.map(JavaRDD.fromRDD(_)).orNull, rdd2.map(JavaRDD.fromRDD(_)).orNull).asJava - Option(pfunc.call(time.milliseconds, rdds)).map(_.rdd) + Option(callPythonTransformFunction(time.milliseconds, rdds)).map(_.rdd) } // for function.Function2 def call(rdds: JList[JavaRDD[_]], time: Time): JavaRDD[Array[Byte]] = { - pfunc.call(time.milliseconds, rdds) + callPythonTransformFunction(time.milliseconds, rdds) + } + + private def callPythonTransformFunction(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]] = { + val resultRDD = pfunc.call(time, rdds) + val failure = pfunc.getLastFailure + if (failure != null) { + throw new SparkException("An exception was raised by Python:\n" + failure) + } + resultRDD } private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { @@ -104,23 +127,33 @@ private[python] object PythonTransformFunctionSerializer { /* * Register a serializer from Python, should be called during initialization */ - def register(ser: PythonTransformFunctionSerializer): Unit = { + def register(ser: PythonTransformFunctionSerializer): Unit = synchronized { serializer = ser } - def serialize(func: PythonTransformFunction): Array[Byte] = { + def serialize(func: PythonTransformFunction): Array[Byte] = synchronized { require(serializer != null, "Serializer has not been registered!") // get the id of PythonTransformFunction in py4j val h = Proxy.getInvocationHandler(func.asInstanceOf[Proxy]) val f = h.getClass().getDeclaredField("id") f.setAccessible(true) val id = f.get(h).asInstanceOf[String] - serializer.dumps(id) + val results = serializer.dumps(id) + val failure = serializer.getLastFailure + if (failure != null) { + throw new SparkException("An exception was raised by Python:\n" + failure) + } + results } - def deserialize(bytes: Array[Byte]): PythonTransformFunction = { + def deserialize(bytes: Array[Byte]): PythonTransformFunction = synchronized { require(serializer != null, "Serializer has not been registered!") - serializer.loads(bytes) + val pfunc = serializer.loads(bytes) + val failure = serializer.getLastFailure + if (failure != null) { + throw new SparkException("An exception was raised by Python:\n" + failure) + } + pfunc } } @@ -161,7 +194,7 @@ private[python] object PythonDStream { */ def toRDDQueue(rdds: JArrayList[JavaRDD[Array[Byte]]]): java.util.Queue[JavaRDD[Array[Byte]]] = { val queue = new java.util.LinkedList[JavaRDD[Array[Byte]]] - rdds.forall(queue.add(_)) + rdds.asScala.foreach(queue.add) queue } } @@ -171,7 +204,7 @@ private[python] object PythonDStream { */ private[python] abstract class PythonDStream( parent: DStream[_], - @transient pfunc: PythonTransformFunction) + pfunc: PythonTransformFunction) extends DStream[Array[Byte]] (parent.ssc) { val func = new TransformFunction(pfunc) @@ -188,7 +221,7 @@ private[python] abstract class PythonDStream( */ private[python] class PythonTransformedDStream ( parent: DStream[_], - 
@transient pfunc: PythonTransformFunction) + pfunc: PythonTransformFunction) extends PythonDStream(parent, pfunc) { override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { @@ -207,7 +240,7 @@ private[python] class PythonTransformedDStream ( private[python] class PythonTransformed2DStream( parent: DStream[_], parent2: DStream[_], - @transient pfunc: PythonTransformFunction) + pfunc: PythonTransformFunction) extends DStream[Array[Byte]] (parent.ssc) { val func = new TransformFunction(pfunc) @@ -231,9 +264,19 @@ private[python] class PythonTransformed2DStream( */ private[python] class PythonStateDStream( parent: DStream[Array[Byte]], - @transient reduceFunc: PythonTransformFunction) + reduceFunc: PythonTransformFunction, + initialRDD: Option[RDD[Array[Byte]]]) extends PythonDStream(parent, reduceFunc) { + def this( + parent: DStream[Array[Byte]], + reduceFunc: PythonTransformFunction) = this(parent, reduceFunc, None) + + def this( + parent: DStream[Array[Byte]], + reduceFunc: PythonTransformFunction, + initialRDD: JavaRDD[Array[Byte]]) = this(parent, reduceFunc, Some(initialRDD.rdd)) + super.persist(StorageLevel.MEMORY_ONLY) override val mustCheckpoint = true @@ -241,7 +284,7 @@ private[python] class PythonStateDStream( val lastState = getOrCompute(validTime - slideDuration) val rdd = parent.getOrCompute(validTime) if (rdd.isDefined) { - func(lastState, rdd, validTime) + func(lastState.orElse(initialRDD), rdd, validTime) } else { lastState } @@ -253,8 +296,8 @@ private[python] class PythonStateDStream( */ private[python] class PythonReducedWindowedDStream( parent: DStream[Array[Byte]], - @transient preduceFunc: PythonTransformFunction, - @transient pinvReduceFunc: PythonTransformFunction, + preduceFunc: PythonTransformFunction, + @transient private val pinvReduceFunc: PythonTransformFunction, _windowDuration: Duration, _slideDuration: Duration) extends PythonDStream(parent, preduceFunc) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala index f396c347581c..4eb92dd8b105 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala @@ -17,9 +17,10 @@ package org.apache.spark.streaming.dstream +import scala.reflect.ClassTag + import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Time, StreamingContext} -import scala.reflect.ClassTag /** * An input stream that always returns the same RDD on each timestep. Useful for testing. @@ -27,6 +28,9 @@ import scala.reflect.ClassTag class ConstantInputDStream[T: ClassTag](ssc_ : StreamingContext, rdd: RDD[T]) extends InputDStream[T](ssc_) { + require(rdd != null, + "parameter rdd null is illegal, which will lead to NPE in the following transformation") + override def start() {} override def stop() {} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index 1da0b0a54df0..1a6edf9473d8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -341,7 +341,7 @@ abstract class DStream[T: ClassTag] ( // of RDD generation, else generate nothing. 
if (isTimeValid(time)) { - val rddOption = createRDDWithLocalProperties(time) { + val rddOption = createRDDWithLocalProperties(time, displayInnerRDDOps = false) { // Disable checks for existing output directories in jobs launched by the streaming // scheduler, since we may need to write output to an existing directory during checkpoint // recovery; see SPARK-4835 for more details. We need to have this call here because @@ -373,27 +373,52 @@ abstract class DStream[T: ClassTag] ( /** * Wrap a body of code such that the call site and operation scope * information are passed to the RDDs created in this body properly. - */ - protected def createRDDWithLocalProperties[U](time: Time)(body: => U): U = { + * @param body RDD creation code to execute with certain local properties. + * @param time Current batch time that should be embedded in the scope names + * @param displayInnerRDDOps Whether the detailed callsites and scopes of the inner RDDs generated + * by `body` will be displayed in the UI; only the scope and callsite + * of the DStream operation that generated `this` will be displayed. + */ + protected[streaming] def createRDDWithLocalProperties[U]( + time: Time, + displayInnerRDDOps: Boolean)(body: => U): U = { val scopeKey = SparkContext.RDD_SCOPE_KEY val scopeNoOverrideKey = SparkContext.RDD_SCOPE_NO_OVERRIDE_KEY // Pass this DStream's operation scope and creation site information to RDDs through // thread-local properties in our SparkContext. Since this method may be called from another // DStream, we need to temporarily store any old scope and creation site information to // restore them later after setting our own. - val prevCallSite = ssc.sparkContext.getCallSite() + val prevCallSite = CallSite( + ssc.sparkContext.getLocalProperty(CallSite.SHORT_FORM), + ssc.sparkContext.getLocalProperty(CallSite.LONG_FORM) + ) val prevScope = ssc.sparkContext.getLocalProperty(scopeKey) val prevScopeNoOverride = ssc.sparkContext.getLocalProperty(scopeNoOverrideKey) try { - ssc.sparkContext.setCallSite(creationSite) + if (displayInnerRDDOps) { + // Unset the short form call site, so that generated RDDs get their own + ssc.sparkContext.setLocalProperty(CallSite.SHORT_FORM, null) + ssc.sparkContext.setLocalProperty(CallSite.LONG_FORM, null) + } else { + // Set the callsite, so that the generated RDDs get the DStream's call site and + // the internal RDD call sites do not get displayed + ssc.sparkContext.setCallSite(creationSite) + } + // Use the DStream's base scope for this RDD so we can (1) preserve the higher level // DStream operation name, and (2) share this scope with other DStreams created in the // same operation. Disallow nesting so that low-level Spark primitives do not show up. 
// TODO: merge callsites with scopes so we can just reuse the code there makeScope(time).foreach { s => ssc.sparkContext.setLocalProperty(scopeKey, s.toJson) - ssc.sparkContext.setLocalProperty(scopeNoOverrideKey, "true") + if (displayInnerRDDOps) { + // Allow inner RDDs to add inner scopes + ssc.sparkContext.setLocalProperty(scopeNoOverrideKey, null) + } else { + // Do not allow inner RDDs to override the scope set by DStream + ssc.sparkContext.setLocalProperty(scopeNoOverrideKey, "true") + } } body @@ -628,7 +653,7 @@ abstract class DStream[T: ClassTag] ( */ def foreachRDD(foreachFunc: RDD[T] => Unit): Unit = ssc.withScope { val cleanedF = context.sparkContext.clean(foreachFunc, false) - this.foreachRDD((r: RDD[T], t: Time) => cleanedF(r)) + foreachRDD((r: RDD[T], t: Time) => cleanedF(r), displayInnerRDDOps = true) } /** @@ -639,7 +664,23 @@ abstract class DStream[T: ClassTag] ( // because the DStream is reachable from the outer object here, and because // DStreams can't be serialized with closures, we can't proactively check // it for serializability and so we pass the optional false to SparkContext.clean - new ForEachDStream(this, context.sparkContext.clean(foreachFunc, false)).register() + foreachRDD(foreachFunc, displayInnerRDDOps = true) + } + + /** + * Apply a function to each RDD in this DStream. This is an output operator, so + * 'this' DStream will be registered as an output stream and therefore materialized. + * @param foreachFunc foreachRDD function + * @param displayInnerRDDOps Whether the detailed callsites and scopes of the RDDs generated + * in the `foreachFunc` to be displayed in the UI. If `false`, then + * only the scopes and callsites of `foreachRDD` will override those + * of the RDDs on the display. + */ + private def foreachRDD( + foreachFunc: (RDD[T], Time) => Unit, + displayInnerRDDOps: Boolean): Unit = { + new ForEachDStream(this, + context.sparkContext.clean(foreachFunc, false), displayInnerRDDOps).register() } /** @@ -730,7 +771,7 @@ abstract class DStream[T: ClassTag] ( // scalastyle:on println } } - new ForEachDStream(this, context.sparkContext.clean(foreachFunc)).register() + foreachRDD(context.sparkContext.clean(foreachFunc), displayInnerRDDOps = false) } /** @@ -900,7 +941,7 @@ abstract class DStream[T: ClassTag] ( val file = rddToFileName(prefix, suffix, time) rdd.saveAsObjectFile(file) } - this.foreachRDD(saveFunc) + this.foreachRDD(saveFunc, displayInnerRDDOps = false) } /** @@ -913,7 +954,7 @@ abstract class DStream[T: ClassTag] ( val file = rddToFileName(prefix, suffix, time) rdd.saveAsTextFile(file) } - this.foreachRDD(saveFunc) + this.foreachRDD(saveFunc, displayInnerRDDOps = false) } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index c358f5b5bd70..cb5b1f252e90 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -42,6 +42,7 @@ import org.apache.spark.util.{SerializableConfiguration, TimeStampedHashMap, Uti * class remembers the information about the files selected in past batches for * a certain duration (say, "remember window") as shown in the figure below. 
* + * {{{ * |<----- remember window ----->| * ignore threshold --->| |<--- current batch time * |____.____.____.____.____.____| @@ -49,6 +50,7 @@ import org.apache.spark.util.{SerializableConfiguration, TimeStampedHashMap, Uti * ---------------------|----|----|----|----|----|----|-----------------------> Time * |____|____|____|____|____|____| * remembered batches + * }}} * * The trailing end of the window is the "ignore threshold" and all files whose mod times * are less than this threshold are assumed to have already been selected and are therefore @@ -59,18 +61,19 @@ import org.apache.spark.util.{SerializableConfiguration, TimeStampedHashMap, Uti * `isNewFile` for more details. * * This makes some assumptions from the underlying file system that the system is monitoring. - * - The clock of the file system is assumed to synchronized with the clock of the machine running - * the streaming app. - * - If a file is to be visible in the directory listings, it must be visible within a certain - * duration of the mod time of the file. This duration is the "remember window", which is set to - * 1 minute (see `FileInputDStream.minRememberDuration`). Otherwise, the file will never be - * selected as the mod time will be less than the ignore threshold when it becomes visible. - * - Once a file is visible, the mod time cannot change. If it does due to appends, then the - * processing semantics are undefined. + * + * - The clock of the file system is assumed to synchronized with the clock of the machine running + * the streaming app. + * - If a file is to be visible in the directory listings, it must be visible within a certain + * duration of the mod time of the file. This duration is the "remember window", which is set to + * 1 minute (see `FileInputDStream.minRememberDuration`). Otherwise, the file will never be + * selected as the mod time will be less than the ignore threshold when it becomes visible. + * - Once a file is visible, the mod time cannot change. If it does due to appends, then the + * processing semantics are undefined. */ private[streaming] class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( - @transient ssc_ : StreamingContext, + ssc_ : StreamingContext, directory: String, filter: Path => Boolean = FileInputDStream.defaultFilter, newFilesOnly: Boolean = true, diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala index c109ceccc698..4410a9977c87 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala @@ -22,10 +22,19 @@ import org.apache.spark.streaming.{Duration, Time} import org.apache.spark.streaming.scheduler.Job import scala.reflect.ClassTag +/** + * An internal DStream used to represent output operations like DStream.foreachRDD. + * @param parent Parent DStream + * @param foreachFunc Function to apply on each RDD generated by the parent DStream + * @param displayInnerRDDOps Whether the detailed callsites and scopes of the RDDs generated + * by `foreachFunc` will be displayed in the UI; only the scope and + * callsite of `DStream.foreachRDD` will be displayed. 
+ */ private[streaming] class ForEachDStream[T: ClassTag] ( parent: DStream[T], - foreachFunc: (RDD[T], Time) => Unit + foreachFunc: (RDD[T], Time) => Unit, + displayInnerRDDOps: Boolean ) extends DStream[Unit](parent.ssc) { override def dependencies: List[DStream[_]] = List(parent) @@ -37,8 +46,7 @@ class ForEachDStream[T: ClassTag] ( override def generateJob(time: Time): Option[Job] = { parent.getOrCompute(time) match { case Some(rdd) => - val jobFunc = () => createRDDWithLocalProperties(time) { - ssc.sparkContext.setCallSite(creationSite) + val jobFunc = () => createRDDWithLocalProperties(time, displayInnerRDDOps) { foreachFunc(rdd, time) } Some(new Job(time, jobFunc)) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala index a6c4cd220e42..95994c983c0c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala @@ -39,7 +39,7 @@ import org.apache.spark.util.Utils * * @param ssc_ Streaming context that will execute this input stream */ -abstract class InputDStream[T: ClassTag] (@transient ssc_ : StreamingContext) +abstract class InputDStream[T: ClassTag] (ssc_ : StreamingContext) extends DStream[T](ssc_) { private[streaming] var lastValidTime: Time = null diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala new file mode 100644 index 000000000000..706465d4e25d --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.dstream + +import scala.reflect.ClassTag + +import org.apache.spark._ +import org.apache.spark.annotation.Experimental +import org.apache.spark.rdd.{EmptyRDD, RDD} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming._ +import org.apache.spark.streaming.rdd.{MapWithStateRDD, MapWithStateRDDRecord} +import org.apache.spark.streaming.dstream.InternalMapWithStateDStream._ + +/** + * :: Experimental :: + * DStream representing the stream of data generated by `mapWithState` operation on a + * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]]. + * Additionally, it also gives access to the stream of state snapshots, that is, the state data of + * all keys after a batch has updated them. 
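For context, a hedged usage sketch of the Scala `mapWithState` operation that produces this DStream, assuming a `DStream[(String, Int)]` named `wordCounts` from an earlier transformation and the `StateSpec.function` overload that takes a `(key, value, state)` mapping function:

```scala
import org.apache.spark.streaming.{State, StateSpec}

// Running count per key: the state is an Int, the mapped output is (word, updatedCount).
val mappingFunc = (word: String, one: Option[Int], state: State[Int]) => {
  val sum = one.getOrElse(0) + state.getOption.getOrElse(0)
  state.update(sum)
  (word, sum)
}

val stateDStream = wordCounts.mapWithState(StateSpec.function(mappingFunc))

// Snapshot stream: one (key, state) pair per key after each batch has updated it
val snapshots = stateDStream.stateSnapshots()
```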
+ * + * @tparam KeyType Class of the key + * @tparam ValueType Class of the value + * @tparam StateType Class of the state data + * @tparam MappedType Class of the mapped data + */ +@Experimental +sealed abstract class MapWithStateDStream[KeyType, ValueType, StateType, MappedType: ClassTag]( + ssc: StreamingContext) extends DStream[MappedType](ssc) { + + /** Return a pair DStream where each RDD is the snapshot of the state of all the keys. */ + def stateSnapshots(): DStream[(KeyType, StateType)] +} + +/** Internal implementation of the [[MapWithStateDStream]] */ +private[streaming] class MapWithStateDStreamImpl[ + KeyType: ClassTag, ValueType: ClassTag, StateType: ClassTag, MappedType: ClassTag]( + dataStream: DStream[(KeyType, ValueType)], + spec: StateSpecImpl[KeyType, ValueType, StateType, MappedType]) + extends MapWithStateDStream[KeyType, ValueType, StateType, MappedType](dataStream.context) { + + private val internalStream = + new InternalMapWithStateDStream[KeyType, ValueType, StateType, MappedType](dataStream, spec) + + override def slideDuration: Duration = internalStream.slideDuration + + override def dependencies: List[DStream[_]] = List(internalStream) + + override def compute(validTime: Time): Option[RDD[MappedType]] = { + internalStream.getOrCompute(validTime).map { _.flatMap[MappedType] { _.mappedData } } + } + + /** + * Forward the checkpoint interval to the internal DStream that computes the state maps. This + * to make sure that this DStream does not get checkpointed, only the internal stream. + */ + override def checkpoint(checkpointInterval: Duration): DStream[MappedType] = { + internalStream.checkpoint(checkpointInterval) + this + } + + /** Return a pair DStream where each RDD is the snapshot of the state of all the keys. */ + def stateSnapshots(): DStream[(KeyType, StateType)] = { + internalStream.flatMap { + _.stateMap.getAll().map { case (k, s, _) => (k, s) }.toTraversable } + } + + def keyClass: Class[_] = implicitly[ClassTag[KeyType]].runtimeClass + + def valueClass: Class[_] = implicitly[ClassTag[ValueType]].runtimeClass + + def stateClass: Class[_] = implicitly[ClassTag[StateType]].runtimeClass + + def mappedClass: Class[_] = implicitly[ClassTag[MappedType]].runtimeClass +} + +/** + * A DStream that allows per-key state to be maintains, and arbitrary records to be generated + * based on updates to the state. This is the main DStream that implements the `mapWithState` + * operation on DStreams. 
+ * + * @param parent Parent (key, value) stream that is the source + * @param spec Specifications of the mapWithState operation + * @tparam K Key type + * @tparam V Value type + * @tparam S Type of the state maintained + * @tparam E Type of the mapped data + */ +private[streaming] +class InternalMapWithStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag]( + parent: DStream[(K, V)], spec: StateSpecImpl[K, V, S, E]) + extends DStream[MapWithStateRDDRecord[K, S, E]](parent.context) { + + persist(StorageLevel.MEMORY_ONLY) + + private val partitioner = spec.getPartitioner().getOrElse( + new HashPartitioner(ssc.sc.defaultParallelism)) + + private val mappingFunction = spec.getFunction() + + override def slideDuration: Duration = parent.slideDuration + + override def dependencies: List[DStream[_]] = List(parent) + + /** Enable automatic checkpointing */ + override val mustCheckpoint = true + + /** Override the default checkpoint duration */ + override def initialize(time: Time): Unit = { + if (checkpointDuration == null) { + checkpointDuration = slideDuration * DEFAULT_CHECKPOINT_DURATION_MULTIPLIER + } + super.initialize(time) + } + + /** Method that generates a RDD for the given time */ + override def compute(validTime: Time): Option[RDD[MapWithStateRDDRecord[K, S, E]]] = { + // Get the previous state or create a new empty state RDD + val prevStateRDD = getOrCompute(validTime - slideDuration) match { + case Some(rdd) => + if (rdd.partitioner != Some(partitioner)) { + // If the RDD is not partitioned the right way, let us repartition it using the + // partition index as the key. This is to ensure that state RDD is always partitioned + // before creating another state RDD using it + MapWithStateRDD.createFromRDD[K, V, S, E]( + rdd.flatMap { _.stateMap.getAll() }, partitioner, validTime) + } else { + rdd + } + case None => + MapWithStateRDD.createFromPairRDD[K, V, S, E]( + spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)), + partitioner, + validTime + ) + } + + + // Compute the new state RDD with previous state RDD and partitioned data RDD + // Even if there is no data RDD, use an empty one to create a new state RDD + val dataRDD = parent.getOrCompute(validTime).getOrElse { + context.sparkContext.emptyRDD[(K, V)] + } + val partitionedDataRDD = dataRDD.partitionBy(partitioner) + val timeoutThresholdTime = spec.getTimeoutInterval().map { interval => + (validTime - interval).milliseconds + } + Some(new MapWithStateRDD( + prevStateRDD, partitionedDataRDD, mappingFunction, validTime, timeoutThresholdTime)) + } +} + +private[streaming] object InternalMapWithStateDStream { + private val DEFAULT_CHECKPOINT_DURATION_MULTIPLIER = 10 +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index 71bec96d46c8..a64a1fe93f40 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -24,19 +24,19 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapred.{JobConf, OutputFormat} import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} -import org.apache.spark.{HashPartitioner, Partitioner} +import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD -import org.apache.spark.streaming.{Duration, Time} import 
org.apache.spark.streaming.StreamingContext.rddToFileName +import org.apache.spark.streaming._ import org.apache.spark.util.{SerializableConfiguration, SerializableJobConf} +import org.apache.spark.{HashPartitioner, Partitioner} /** * Extra functions available on DStream of (key, value) pairs through an implicit conversion. */ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K]) - extends Serializable -{ + extends Serializable { private[streaming] def ssc = self.ssc private[streaming] def sparkContext = self.context.sparkContext @@ -350,6 +350,41 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) ) } + /** + * :: Experimental :: + * Return a [[MapWithStateDStream]] by applying a function to every key-value element of + * `this` stream, while maintaining some state data for each unique key. The mapping function + * and other specification (e.g. partitioners, timeouts, initial state data, etc.) of this + * transformation can be specified using [[StateSpec]] class. The state data is accessible in + * as a parameter of type [[State]] in the mapping function. + * + * Example of using `mapWithState`: + * {{{ + * // A mapping function that maintains an integer state and return a String + * def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = { + * // Use state.exists(), state.get(), state.update() and state.remove() + * // to manage state, and return the necessary string + * } + * + * val spec = StateSpec.function(mappingFunction).numPartitions(10) + * + * val mapWithStateDStream = keyValueDStream.mapWithState[StateType, MappedType](spec) + * }}} + * + * @param spec Specification of this transformation + * @tparam StateType Class type of the state data + * @tparam MappedType Class type of the mapped data + */ + @Experimental + def mapWithState[StateType: ClassTag, MappedType: ClassTag]( + spec: StateSpec[K, V, StateType, MappedType] + ): MapWithStateDStream[K, V, StateType, MappedType] = { + new MapWithStateDStreamImpl[K, V, StateType, MappedType]( + self, + spec.asInstanceOf[StateSpecImpl[K, V, StateType, MappedType]] + ) + } + /** * Return a new "state" DStream where the state for each key is updated by applying * the given function on the previous state of the key and the new values of each key. 
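The abbreviated example in the `mapWithState` Scaladoc above omits the surrounding setup; an end-to-end sketch of the new API might look like the following (hypothetical socket source, checkpoint directory and partition count; not part of the patch):

{{{
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, State, StateSpec, StreamingContext}

object MapWithStateSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("MapWithStateSketch")
    val ssc = new StreamingContext(conf, Seconds(1))
    // mapWithState requires checkpointing; the directory here is hypothetical.
    ssc.checkpoint("/tmp/map-with-state-checkpoint")

    // Hypothetical source: whitespace-separated words arriving on a local socket.
    val words = ssc.socketTextStream("localhost", 9999).flatMap(_.split(" "))
    val pairs = words.map(word => (word, 1))

    // Mapping function: keep a running count per word and emit the updated (word, count).
    val mappingFunc = (word: String, one: Option[Int], state: State[Int]) => {
      val sum = one.getOrElse(0) + state.getOption.getOrElse(0)
      state.update(sum)
      (word, sum)
    }

    val spec = StateSpec.function(mappingFunc).numPartitions(10)

    val countsWithState = pairs.mapWithState(spec)
    countsWithState.print()

    // Snapshot of the state of all keys after each batch (see MapWithStateDStream above).
    countsWithState.stateSnapshots().print()

    ssc.start()
    ssc.awaitTermination()
  }
}
}}}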
@@ -692,7 +727,8 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) val serializableConf = new SerializableJobConf(conf) val saveFunc = (rdd: RDD[(K, V)], time: Time) => { val file = rddToFileName(prefix, suffix, time) - rdd.saveAsHadoopFile(file, keyClass, valueClass, outputFormatClass, serializableConf.value) + rdd.saveAsHadoopFile(file, keyClass, valueClass, outputFormatClass, + new JobConf(serializableConf.value)) } self.foreachRDD(saveFunc) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala index 186e1bf03a94..002aac9f4361 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala @@ -23,7 +23,7 @@ import org.apache.spark.streaming.receiver.Receiver private[streaming] class PluggableInputDStream[T: ClassTag]( - @transient ssc_ : StreamingContext, + ssc_ : StreamingContext, receiver: Receiver[T]) extends ReceiverInputDStream[T](ssc_) { def getReceiver(): Receiver[T] = { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala index a2f5d82a79bd..cd073646370d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.dstream -import java.io.{NotSerializableException, ObjectOutputStream} +import java.io.{NotSerializableException, ObjectInputStream, ObjectOutputStream} import scala.collection.mutable.{ArrayBuffer, Queue} import scala.reflect.ClassTag @@ -27,7 +27,7 @@ import org.apache.spark.streaming.{Time, StreamingContext} private[streaming] class QueueInputDStream[T: ClassTag]( - @transient ssc: StreamingContext, + ssc: StreamingContext, val queue: Queue[RDD[T]], oneAtATime: Boolean, defaultRDD: RDD[T] @@ -37,8 +37,13 @@ class QueueInputDStream[T: ClassTag]( override def stop() { } + private def readObject(in: ObjectInputStream): Unit = { + throw new NotSerializableException("queueStream doesn't support checkpointing. 
" + + "Please don't use queueStream when checkpointing is enabled.") + } + private def writeObject(oos: ObjectOutputStream): Unit = { - throw new NotSerializableException("queueStream doesn't support checkpointing") + logWarning("queueStream doesn't support checkpointing") } override def compute(validTime: Time): Option[RDD[T]] = { @@ -52,12 +57,12 @@ class QueueInputDStream[T: ClassTag]( if (oneAtATime) { Some(buffer.head) } else { - Some(new UnionRDD(ssc.sc, buffer.toSeq)) + Some(new UnionRDD(context.sc, buffer.toSeq)) } } else if (defaultRDD != null) { Some(defaultRDD) } else { - None + Some(ssc.sparkContext.emptyRDD) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala index e2925b9e03ec..5a9eda7c1277 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala @@ -39,7 +39,7 @@ import org.apache.spark.streaming.receiver.Receiver */ private[streaming] class RawInputDStream[T: ClassTag]( - @transient ssc_ : StreamingContext, + ssc_ : StreamingContext, host: String, port: Int, storageLevel: StorageLevel diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index 670ef8d296a0..87c20afd5c13 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -21,12 +21,12 @@ import scala.reflect.ClassTag import org.apache.spark.rdd.{BlockRDD, RDD} import org.apache.spark.storage.BlockId -import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD import org.apache.spark.streaming.receiver.Receiver -import org.apache.spark.streaming.scheduler.{RateController, StreamInputInfo} import org.apache.spark.streaming.scheduler.rate.RateEstimator +import org.apache.spark.streaming.scheduler.{ReceivedBlockInfo, RateController, StreamInputInfo} import org.apache.spark.streaming.util.WriteAheadLogUtils +import org.apache.spark.streaming.{StreamingContext, Time} /** * Abstract class for defining any [[org.apache.spark.streaming.dstream.InputDStream]] @@ -38,7 +38,7 @@ import org.apache.spark.streaming.util.WriteAheadLogUtils * @param ssc_ Streaming context that will execute this input stream * @tparam T Class type of the object of this stream */ -abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingContext) +abstract class ReceiverInputDStream[T: ClassTag](ssc_ : StreamingContext) extends InputDStream[T](ssc_) { /** @@ -79,48 +79,63 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont // for this batch val receiverTracker = ssc.scheduler.receiverTracker val blockInfos = receiverTracker.getBlocksOfBatch(validTime).getOrElse(id, Seq.empty) - val blockIds = blockInfos.map { _.blockId.asInstanceOf[BlockId] }.toArray // Register the input blocks information into InputInfoTracker val inputInfo = StreamInputInfo(id, blockInfos.flatMap(_.numRecords).sum) ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) - if (blockInfos.nonEmpty) { - // Are WAL record handles present with all the blocks - val areWALRecordHandlesPresent = blockInfos.forall { _.walRecordHandleOption.nonEmpty } + 
// Create the BlockRDD + createBlockRDD(validTime, blockInfos) + } + } + Some(blockRDD) + } + + private[streaming] def createBlockRDD(time: Time, blockInfos: Seq[ReceivedBlockInfo]): RDD[T] = { - if (areWALRecordHandlesPresent) { - // If all the blocks have WAL record handle, then create a WALBackedBlockRDD - val isBlockIdValid = blockInfos.map { _.isBlockIdValid() }.toArray - val walRecordHandles = blockInfos.map { _.walRecordHandleOption.get }.toArray - new WriteAheadLogBackedBlockRDD[T]( - ssc.sparkContext, blockIds, walRecordHandles, isBlockIdValid) - } else { - // Else, create a BlockRDD. However, if there are some blocks with WAL info but not - // others then that is unexpected and log a warning accordingly. - if (blockInfos.find(_.walRecordHandleOption.nonEmpty).nonEmpty) { - if (WriteAheadLogUtils.enableReceiverLog(ssc.conf)) { - logError("Some blocks do not have Write Ahead Log information; " + - "this is unexpected and data may not be recoverable after driver failures") - } else { - logWarning("Some blocks have Write Ahead Log information; this is unexpected") - } - } - new BlockRDD[T](ssc.sc, blockIds) - } - } else { - // If no block is ready now, creating WriteAheadLogBackedBlockRDD or BlockRDD - // according to the configuration + if (blockInfos.nonEmpty) { + val blockIds = blockInfos.map { _.blockId.asInstanceOf[BlockId] }.toArray + + // Are WAL record handles present with all the blocks + val areWALRecordHandlesPresent = blockInfos.forall { _.walRecordHandleOption.nonEmpty } + + if (areWALRecordHandlesPresent) { + // If all the blocks have WAL record handle, then create a WALBackedBlockRDD + val isBlockIdValid = blockInfos.map { _.isBlockIdValid() }.toArray + val walRecordHandles = blockInfos.map { _.walRecordHandleOption.get }.toArray + new WriteAheadLogBackedBlockRDD[T]( + ssc.sparkContext, blockIds, walRecordHandles, isBlockIdValid) + } else { + // Else, create a BlockRDD. However, if there are some blocks with WAL info but not + // others then that is unexpected and log a warning accordingly. + if (blockInfos.find(_.walRecordHandleOption.nonEmpty).nonEmpty) { if (WriteAheadLogUtils.enableReceiverLog(ssc.conf)) { - new WriteAheadLogBackedBlockRDD[T]( - ssc.sparkContext, Array.empty, Array.empty, Array.empty) + logError("Some blocks do not have Write Ahead Log information; " + + "this is unexpected and data may not be recoverable after driver failures") } else { - new BlockRDD[T](ssc.sc, Array.empty) + logWarning("Some blocks have Write Ahead Log information; this is unexpected") } } + val validBlockIds = blockIds.filter { id => + ssc.sparkContext.env.blockManager.master.contains(id) + } + if (validBlockIds.size != blockIds.size) { + logWarning("Some blocks could not be recovered as they were not found in memory. 
" + + "To prevent such data loss, enabled Write Ahead Log (see programming guide " + + "for more details.") + } + new BlockRDD[T](ssc.sc, validBlockIds) + } + } else { + // If no block is ready now, creating WriteAheadLogBackedBlockRDD or BlockRDD + // according to the configuration + if (WriteAheadLogUtils.enableReceiverLog(ssc.conf)) { + new WriteAheadLogBackedBlockRDD[T]( + ssc.sparkContext, Array.empty, Array.empty, Array.empty) + } else { + new BlockRDD[T](ssc.sc, Array.empty) } } - Some(blockRDD) } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala index 5ce5b7aae6e6..de84e0c9a498 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala @@ -32,7 +32,7 @@ import org.apache.spark.streaming.receiver.Receiver private[streaming] class SocketInputDStream[T: ClassTag]( - @transient ssc_ : StreamingContext, + ssc_ : StreamingContext, host: String, port: Int, bytesToObjects: InputStream => Iterator[T], diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala index 5d46ca0715ff..080bc873fa0a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala @@ -17,10 +17,12 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.rdd.{PairRDDFunctions, RDD} -import org.apache.spark.streaming.{Duration, Time} import scala.reflect.ClassTag +import org.apache.spark.SparkException +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.{Duration, Time} + private[streaming] class TransformedDStream[U: ClassTag] ( parents: Seq[DStream[_]], @@ -37,7 +39,29 @@ class TransformedDStream[U: ClassTag] ( override def slideDuration: Duration = parents.head.slideDuration override def compute(validTime: Time): Option[RDD[U]] = { - val parentRDDs = parents.map(_.getOrCompute(validTime).orNull).toSeq - Some(transformFunc(parentRDDs, validTime)) + val parentRDDs = parents.map { parent => parent.getOrCompute(validTime).getOrElse( + // Guard out against parent DStream that return None instead of Some(rdd) to avoid NPE + throw new SparkException(s"Couldn't generate RDD from parent at time $validTime")) + } + val transformedRDD = transformFunc(parentRDDs, validTime) + if (transformedRDD == null) { + throw new SparkException("Transform function must not return null. " + + "Return SparkContext.emptyRDD() instead to represent no element " + + "as the result of transformation.") + } + Some(transformedRDD) + } + + /** + * Wrap a body of code such that the call site and operation scope + * information are passed to the RDDs created in this body properly. + * This has been overriden to make sure that `displayInnerRDDOps` is always `true`, that is, + * the inner scopes and callsites of RDDs generated in `DStream.transform` are always + * displayed in the UI. 
+ */ + override protected[streaming] def createRDDWithLocalProperties[U]( + time: Time, + displayInnerRDDOps: Boolean)(body: => U): U = { + super.createRDDWithLocalProperties(time, displayInnerRDDOps = true)(body) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala index 9405dbaa1232..d73ffdfd84d2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala @@ -17,13 +17,14 @@ package org.apache.spark.streaming.dstream +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + +import org.apache.spark.SparkException import org.apache.spark.streaming.{Duration, Time} import org.apache.spark.rdd.RDD import org.apache.spark.rdd.UnionRDD -import scala.collection.mutable.ArrayBuffer -import scala.reflect.ClassTag - private[streaming] class UnionDStream[T: ClassTag](parents: Array[DStream[T]]) extends DStream[T](parents.head.ssc) { @@ -41,8 +42,8 @@ class UnionDStream[T: ClassTag](parents: Array[DStream[T]]) val rdds = new ArrayBuffer[RDD[T]]() parents.map(_.getOrCompute(validTime)).foreach { case Some(rdd) => rdds += rdd - case None => throw new Exception("Could not generate RDD from a parent for unifying at time " - + validTime) + case None => throw new SparkException("Could not generate RDD from a parent for unifying at" + + s" time $validTime") } if (rdds.size > 0) { Some(new UnionRDD(ssc.sc, rdds)) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala new file mode 100644 index 000000000000..fdf61674a37f --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.rdd + +import java.io.{IOException, ObjectInputStream, ObjectOutputStream} + +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + +import org.apache.spark.rdd.{MapPartitionsRDD, RDD} +import org.apache.spark.streaming.{Time, StateImpl, State} +import org.apache.spark.streaming.util.{EmptyStateMap, StateMap} +import org.apache.spark.util.Utils +import org.apache.spark._ + +/** + * Record storing the keyed-state [[MapWithStateRDD]]. Each record contains a [[StateMap]] and a + * sequence of records returned by the mapping function of `mapWithState`. 
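Each partition of the state RDD carries exactly one such record, and `updateRecordWithData` below folds one batch's records into it. A simplified, self-contained model of that fold (hypothetical values; plain immutable Maps stand in for the internal `StateMap`/`State` types):

{{{
// Simplified model of the per-partition update: apply the mapping function to every
// new (key, value) record, updating the keyed state it touches and collecting the output.
object RecordUpdateModel {
  def main(args: Array[String]): Unit = {
    val prevState = Map("a" -> 3, "b" -> 1)        // state carried over from the previous batch
    val newData   = Seq(("a", 2), ("c", 5))        // this batch's (key, value) records

    var stateMap = prevState
    val mappedData = newData.map { case (key, value) =>
      val updated = stateMap.getOrElse(key, 0) + value
      stateMap = stateMap.updated(key, updated)    // state update is a side effect of the mapping
      (key, updated)
    }

    println(stateMap)     // a -> 5, b -> 1, c -> 5
    println(mappedData)   // (a,5), (c,5)
  }
}
}}}

When a timeout threshold is supplied and a full scan is enabled, keys idle past the threshold are additionally passed to the mapping function with a value of `None` and then removed, as `updateRecordWithData` below shows.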
+ */ +private[streaming] case class MapWithStateRDDRecord[K, S, E]( + var stateMap: StateMap[K, S], var mappedData: Seq[E]) + +private[streaming] object MapWithStateRDDRecord { + def updateRecordWithData[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag]( + prevRecord: Option[MapWithStateRDDRecord[K, S, E]], + dataIterator: Iterator[(K, V)], + mappingFunction: (Time, K, Option[V], State[S]) => Option[E], + batchTime: Time, + timeoutThresholdTime: Option[Long], + removeTimedoutData: Boolean + ): MapWithStateRDDRecord[K, S, E] = { + // Create a new state map by cloning the previous one (if it exists) or by creating an empty one + val newStateMap = prevRecord.map { _.stateMap.copy() }. getOrElse { new EmptyStateMap[K, S]() } + + val mappedData = new ArrayBuffer[E] + val wrappedState = new StateImpl[S]() + + // Call the mapping function on each record in the data iterator, and accordingly + // update the states touched, and collect the data returned by the mapping function + dataIterator.foreach { case (key, value) => + wrappedState.wrap(newStateMap.get(key)) + val returned = mappingFunction(batchTime, key, Some(value), wrappedState) + if (wrappedState.isRemoved) { + newStateMap.remove(key) + } else if (wrappedState.isUpdated || timeoutThresholdTime.isDefined) { + newStateMap.put(key, wrappedState.get(), batchTime.milliseconds) + } + mappedData ++= returned + } + + // Get the timed out state records, call the mapping function on each and collect the + // data returned + if (removeTimedoutData && timeoutThresholdTime.isDefined) { + newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) => + wrappedState.wrapTimingOutState(state) + val returned = mappingFunction(batchTime, key, None, wrappedState) + mappedData ++= returned + newStateMap.remove(key) + } + } + + MapWithStateRDDRecord(newStateMap, mappedData) + } +} + +/** + * Partition of the [[MapWithStateRDD]], which depends on corresponding partitions of prev state + * RDD, and a partitioned keyed-data RDD + */ +private[streaming] class MapWithStateRDDPartition( + idx: Int, + @transient private var prevStateRDD: RDD[_], + @transient private var partitionedDataRDD: RDD[_]) extends Partition { + + private[rdd] var previousSessionRDDPartition: Partition = null + private[rdd] var partitionedDataRDDPartition: Partition = null + + override def index: Int = idx + override def hashCode(): Int = idx + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException { + // Update the reference to parent split at the time of task serialization + previousSessionRDDPartition = prevStateRDD.partitions(index) + partitionedDataRDDPartition = partitionedDataRDD.partitions(index) + oos.defaultWriteObject() + } +} + + +/** + * RDD storing the keyed states of `mapWithState` operation and corresponding mapped data. + * Each partition of this RDD has a single record of type [[MapWithStateRDDRecord]]. This contains a + * [[StateMap]] (containing the keyed-states) and the sequence of records returned by the mapping + * function of `mapWithState`. + * @param prevStateRDD The previous MapWithStateRDD on whose StateMap data `this` RDD + * will be created + * @param partitionedDataRDD The partitioned data RDD which is used update the previous StateMaps + * in the `prevStateRDD` to create `this` RDD + * @param mappingFunction The function that will be used to update state and return new data + * @param batchTime The time of the batch to which this RDD belongs to. 
Use to update + * @param timeoutThresholdTime The time to indicate which keys are timeout + */ +private[streaming] class MapWithStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag]( + private var prevStateRDD: RDD[MapWithStateRDDRecord[K, S, E]], + private var partitionedDataRDD: RDD[(K, V)], + mappingFunction: (Time, K, Option[V], State[S]) => Option[E], + batchTime: Time, + timeoutThresholdTime: Option[Long] + ) extends RDD[MapWithStateRDDRecord[K, S, E]]( + partitionedDataRDD.sparkContext, + List( + new OneToOneDependency[MapWithStateRDDRecord[K, S, E]](prevStateRDD), + new OneToOneDependency(partitionedDataRDD)) + ) { + + @volatile private var doFullScan = false + + require(prevStateRDD.partitioner.nonEmpty) + require(partitionedDataRDD.partitioner == prevStateRDD.partitioner) + + override val partitioner = prevStateRDD.partitioner + + override def checkpoint(): Unit = { + super.checkpoint() + doFullScan = true + } + + override def compute( + partition: Partition, context: TaskContext): Iterator[MapWithStateRDDRecord[K, S, E]] = { + + val stateRDDPartition = partition.asInstanceOf[MapWithStateRDDPartition] + val prevStateRDDIterator = prevStateRDD.iterator( + stateRDDPartition.previousSessionRDDPartition, context) + val dataIterator = partitionedDataRDD.iterator( + stateRDDPartition.partitionedDataRDDPartition, context) + + val prevRecord = if (prevStateRDDIterator.hasNext) Some(prevStateRDDIterator.next()) else None + val newRecord = MapWithStateRDDRecord.updateRecordWithData( + prevRecord, + dataIterator, + mappingFunction, + batchTime, + timeoutThresholdTime, + removeTimedoutData = doFullScan // remove timedout data only when full scan is enabled + ) + Iterator(newRecord) + } + + override protected def getPartitions: Array[Partition] = { + Array.tabulate(prevStateRDD.partitions.length) { i => + new MapWithStateRDDPartition(i, prevStateRDD, partitionedDataRDD)} + } + + override def clearDependencies(): Unit = { + super.clearDependencies() + prevStateRDD = null + partitionedDataRDD = null + } + + def setFullScan(): Unit = { + doFullScan = true + } +} + +private[streaming] object MapWithStateRDD { + + def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag]( + pairRDD: RDD[(K, S)], + partitioner: Partitioner, + updateTime: Time): MapWithStateRDD[K, V, S, E] = { + + val stateRDD = pairRDD.partitionBy(partitioner).mapPartitions ({ iterator => + val stateMap = StateMap.create[K, S](SparkEnv.get.conf) + iterator.foreach { case (key, state) => stateMap.put(key, state, updateTime.milliseconds) } + Iterator(MapWithStateRDDRecord(stateMap, Seq.empty[E])) + }, preservesPartitioning = true) + + val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner) + + val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None + + new MapWithStateRDD[K, V, S, E]( + stateRDD, emptyDataRDD, noOpFunc, updateTime, None) + } + + def createFromRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag]( + rdd: RDD[(K, S, Long)], + partitioner: Partitioner, + updateTime: Time): MapWithStateRDD[K, V, S, E] = { + + val pairRDD = rdd.map { x => (x._1, (x._2, x._3)) } + val stateRDD = pairRDD.partitionBy(partitioner).mapPartitions({ iterator => + val stateMap = StateMap.create[K, S](SparkEnv.get.conf) + iterator.foreach { case (key, (state, updateTime)) => + stateMap.put(key, state, updateTime) + } + Iterator(MapWithStateRDDRecord(stateMap, Seq.empty[E])) + }, preservesPartitioning = true) + + val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, 
V)].partitionBy(partitioner) + + val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None + + new MapWithStateRDD[K, V, S, E]( + stateRDD, emptyDataRDD, noOpFunc, updateTime, None) + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index 31ce8e1ec14d..f811784b25c8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -61,7 +61,7 @@ class WriteAheadLogBackedBlockRDDPartition( * * * @param sc SparkContext - * @param blockIds Ids of the blocks that contains this RDD's data + * @param _blockIds Ids of the blocks that contains this RDD's data * @param walRecordHandles Record handles in write ahead logs that contain this RDD's data * @param isBlockIdValid Whether the block Ids are valid (i.e., the blocks are present in the Spark * executors). If not, then block lookups by the block ids will be skipped. @@ -73,23 +73,23 @@ class WriteAheadLogBackedBlockRDDPartition( */ private[streaming] class WriteAheadLogBackedBlockRDD[T: ClassTag]( - @transient sc: SparkContext, - @transient blockIds: Array[BlockId], - @transient walRecordHandles: Array[WriteAheadLogRecordHandle], - @transient isBlockIdValid: Array[Boolean] = Array.empty, + sc: SparkContext, + @transient private val _blockIds: Array[BlockId], + @transient val walRecordHandles: Array[WriteAheadLogRecordHandle], + @transient private val isBlockIdValid: Array[Boolean] = Array.empty, storeInBlockManager: Boolean = false, storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER) - extends BlockRDD[T](sc, blockIds) { + extends BlockRDD[T](sc, _blockIds) { require( - blockIds.length == walRecordHandles.length, - s"Number of block Ids (${blockIds.length}) must be " + - s" same as number of WAL record handles (${walRecordHandles.length}})") + _blockIds.length == walRecordHandles.length, + s"Number of block Ids (${_blockIds.length}) must be " + + s" same as number of WAL record handles (${walRecordHandles.length})") require( - isBlockIdValid.isEmpty || isBlockIdValid.length == blockIds.length, + isBlockIdValid.isEmpty || isBlockIdValid.length == _blockIds.length, s"Number of elements in isBlockIdValid (${isBlockIdValid.length}) must be " + - s" same as number of block Ids (${blockIds.length})") + s" same as number of block Ids (${_blockIds.length})") // Hadoop configuration is not serializable, so broadcast it as a serializable. 
@transient private val hadoopConfig = sc.hadoopConfiguration @@ -99,9 +99,9 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( override def getPartitions: Array[Partition] = { assertValid() - Array.tabulate(blockIds.length) { i => + Array.tabulate(_blockIds.length) { i => val isValid = if (isBlockIdValid.length == 0) true else isBlockIdValid(i) - new WriteAheadLogBackedBlockRDDPartition(i, blockIds(i), isValid, walRecordHandles(i)) + new WriteAheadLogBackedBlockRDDPartition(i, _blockIds(i), isValid, walRecordHandles(i)) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala index cd309788a771..7ec74016a1c2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala @@ -144,7 +144,7 @@ private[streaming] class ActorReceiver[T: ClassTag]( receiverSupervisorStrategy: SupervisorStrategy ) extends Receiver[T](storageLevel) with Logging { - protected lazy val supervisor = SparkEnv.get.actorSystem.actorOf(Props(new Supervisor), + protected lazy val actorSupervisor = SparkEnv.get.actorSystem.actorOf(Props(new Supervisor), "Supervisor" + streamId) class Supervisor extends Actor { @@ -191,11 +191,11 @@ private[streaming] class ActorReceiver[T: ClassTag]( } def onStart(): Unit = { - supervisor - logInfo("Supervision tree for receivers initialized at:" + supervisor.path) + actorSupervisor + logInfo("Supervision tree for receivers initialized at:" + actorSupervisor.path) } def onStop(): Unit = { - supervisor ! PoisonPill + actorSupervisor ! PoisonPill } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala index 92b51ce39234..cc7c04bfc9f6 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala @@ -21,10 +21,10 @@ import java.util.concurrent.{ArrayBlockingQueue, TimeUnit} import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.{SparkException, Logging, SparkConf} import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.util.RecurringTimer -import org.apache.spark.util.SystemClock +import org.apache.spark.util.{Clock, SystemClock} /** Listener object for BlockGenerator events */ private[streaming] trait BlockGeneratorListener { @@ -69,16 +69,36 @@ private[streaming] trait BlockGeneratorListener { * named blocks at regular intervals. This class starts two threads, * one to periodically start a new batch and prepare the previous batch of as a block, * the other to push the blocks into the block manager. + * + * Note: Do not create BlockGenerator instances directly inside receivers. Use + * `ReceiverSupervisor.createBlockGenerator` to create a BlockGenerator and use it. */ private[streaming] class BlockGenerator( listener: BlockGeneratorListener, receiverId: Int, - conf: SparkConf + conf: SparkConf, + clock: Clock = new SystemClock() ) extends RateLimiter(conf) with Logging { private case class Block(id: StreamBlockId, buffer: ArrayBuffer[Any]) - private val clock = new SystemClock() + /** + * The BlockGenerator can be in 5 possible states, in the order as follows. 
+ * + * - Initialized: Nothing has been started + * - Active: start() has been called, and it is generating blocks on added data. + * - StoppedAddingData: stop() has been called, the adding of data has been stopped, + * but blocks are still being generated and pushed. + * - StoppedGeneratingBlocks: Generating of blocks has been stopped, but + * they are still being pushed. + * - StoppedAll: Everything has stopped, and the BlockGenerator object can be GCed. + */ + private object GeneratorState extends Enumeration { + type GeneratorState = Value + val Initialized, Active, StoppedAddingData, StoppedGeneratingBlocks, StoppedAll = Value + } + import GeneratorState._ + private val blockIntervalMs = conf.getTimeAsMs("spark.streaming.blockInterval", "200ms") require(blockIntervalMs > 0, s"'spark.streaming.blockInterval' should be a positive value") @@ -89,70 +109,141 @@ private[streaming] class BlockGenerator( private val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } @volatile private var currentBuffer = new ArrayBuffer[Any] - @volatile private var stopped = false + @volatile private var state = Initialized /** Start block generating and pushing threads. */ - def start() { - blockIntervalTimer.start() - blockPushingThread.start() - logInfo("Started BlockGenerator") + def start(): Unit = synchronized { + if (state == Initialized) { + state = Active + blockIntervalTimer.start() + blockPushingThread.start() + logInfo("Started BlockGenerator") + } else { + throw new SparkException( + s"Cannot start BlockGenerator as its not in the Initialized state [state = $state]") + } } - /** Stop all threads. */ - def stop() { + /** + * Stop everything in the right order such that all the data added is pushed out correctly. + * + * - First, stop adding data to the current buffer. + * - Second, stop generating blocks. + * - Finally, wait for queue of to-be-pushed blocks to be drained. + */ + def stop(): Unit = { + // Set the state to stop adding data + synchronized { + if (state == Active) { + state = StoppedAddingData + } else { + logWarning(s"Cannot stop BlockGenerator as its not in the Active state [state = $state]") + return + } + } + + // Stop generating blocks and set the state for block pushing thread to start draining the queue logInfo("Stopping BlockGenerator") blockIntervalTimer.stop(interruptTimer = false) - stopped = true - logInfo("Waiting for block pushing thread") + synchronized { state = StoppedGeneratingBlocks } + + // Wait for the queue to drain and mark generated as stopped + logInfo("Waiting for block pushing thread to terminate") blockPushingThread.join() + synchronized { state = StoppedAll } logInfo("Stopped BlockGenerator") } /** - * Push a single data item into the buffer. All received data items - * will be periodically pushed into BlockManager. + * Push a single data item into the buffer. */ - def addData (data: Any): Unit = synchronized { - waitToPush() - currentBuffer += data + def addData(data: Any): Unit = { + if (state == Active) { + waitToPush() + synchronized { + if (state == Active) { + currentBuffer += data + } else { + throw new SparkException( + "Cannot add data as BlockGenerator has not been started or has been stopped") + } + } + } else { + throw new SparkException( + "Cannot add data as BlockGenerator has not been started or has been stopped") + } } /** * Push a single data item into the buffer. After buffering the data, the - * `BlockGeneratorListener.onAddData` callback will be called. 
All received data items - * will be periodically pushed into BlockManager. + * `BlockGeneratorListener.onAddData` callback will be called. */ - def addDataWithCallback(data: Any, metadata: Any): Unit = synchronized { - waitToPush() - currentBuffer += data - listener.onAddData(data, metadata) + def addDataWithCallback(data: Any, metadata: Any): Unit = { + if (state == Active) { + waitToPush() + synchronized { + if (state == Active) { + currentBuffer += data + listener.onAddData(data, metadata) + } else { + throw new SparkException( + "Cannot add data as BlockGenerator has not been started or has been stopped") + } + } + } else { + throw new SparkException( + "Cannot add data as BlockGenerator has not been started or has been stopped") + } } /** * Push multiple data items into the buffer. After buffering the data, the - * `BlockGeneratorListener.onAddData` callback will be called. All received data items - * will be periodically pushed into BlockManager. Note that all the data items is guaranteed - * to be present in a single block. + * `BlockGeneratorListener.onAddData` callback will be called. Note that all the data items + * are atomically added to the buffer, and are hence guaranteed to be present in a single block. */ - def addMultipleDataWithCallback(dataIterator: Iterator[Any], metadata: Any): Unit = synchronized { - dataIterator.foreach { data => - waitToPush() - currentBuffer += data + def addMultipleDataWithCallback(dataIterator: Iterator[Any], metadata: Any): Unit = { + if (state == Active) { + // Unroll iterator into a temp buffer, and wait for pushing in the process + val tempBuffer = new ArrayBuffer[Any] + dataIterator.foreach { data => + waitToPush() + tempBuffer += data + } + synchronized { + if (state == Active) { + currentBuffer ++= tempBuffer + listener.onAddData(tempBuffer, metadata) + } else { + throw new SparkException( + "Cannot add data as BlockGenerator has not been started or has been stopped") + } + } + } else { + throw new SparkException( + "Cannot add data as BlockGenerator has not been started or has been stopped") } - listener.onAddData(dataIterator, metadata) } + def isActive(): Boolean = state == Active + + def isStopped(): Boolean = state == StoppedAll + /** Change the buffer to which single records are added to. */ - private def updateCurrentBuffer(time: Long): Unit = synchronized { + private def updateCurrentBuffer(time: Long): Unit = { try { - val newBlockBuffer = currentBuffer - currentBuffer = new ArrayBuffer[Any] - if (newBlockBuffer.size > 0) { - val blockId = StreamBlockId(receiverId, time - blockIntervalMs) - val newBlock = new Block(blockId, newBlockBuffer) - listener.onGenerateBlock(blockId) + var newBlock: Block = null + synchronized { + if (currentBuffer.nonEmpty) { + val newBlockBuffer = currentBuffer + currentBuffer = new ArrayBuffer[Any] + val blockId = StreamBlockId(receiverId, time - blockIntervalMs) + listener.onGenerateBlock(blockId) + newBlock = new Block(blockId, newBlockBuffer) + } + } + + if (newBlock != null) { blocksForPushing.put(newBlock) // put is blocking when queue is full - logDebug("Last element in " + blockId + " is " + newBlockBuffer.last) } } catch { case ie: InterruptedException => @@ -165,18 +256,25 @@ private[streaming] class BlockGenerator( /** Keep pushing blocks to the BlockManager. 
*/ private def keepPushingBlocks() { logInfo("Started block pushing thread") + + def areBlocksBeingGenerated: Boolean = synchronized { + state != StoppedGeneratingBlocks + } + try { - while (!stopped) { - Option(blocksForPushing.poll(100, TimeUnit.MILLISECONDS)) match { + // While blocks are being generated, keep polling for to-be-pushed blocks and push them. + while (areBlocksBeingGenerated) { + Option(blocksForPushing.poll(10, TimeUnit.MILLISECONDS)) match { case Some(block) => pushBlock(block) case None => } } - // Push out the blocks that are still left + + // At this point, state is StoppedGeneratingBlock. So drain the queue of to-be-pushed blocks. logInfo("Pushing out the last " + blocksForPushing.size() + " blocks") while (!blocksForPushing.isEmpty) { - logDebug("Getting block ") val block = blocksForPushing.take() + logDebug(s"Pushing block $block") pushBlock(block) logInfo("Blocks left to push " + blocksForPushing.size()) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala index f663def4c051..bca1fbc8fda2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala @@ -45,8 +45,7 @@ private[receiver] abstract class RateLimiter(conf: SparkConf) extends Logging { /** * Return the current rate limit. If no limit has been set so far, it returns {{{Long.MaxValue}}}. */ - def getCurrentLimit: Long = - rateLimiter.getRate.toLong + def getCurrentLimit: Long = rateLimiter.getRate.toLong /** * Set the rate limit to `newRate`. The new rate will not exceed the maximum rate configured by diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index c8dd6e06812d..5f6c5b024085 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -222,7 +222,7 @@ private[streaming] object WriteAheadLogBasedBlockHandler { /** * A utility that will wrap the Iterator to get the count */ -private class CountingIterator[T](iterator: Iterator[T]) extends Iterator[T] { +private[streaming] class CountingIterator[T](iterator: Iterator[T]) extends Iterator[T] { private var _count = 0 private def isFullyConsumed: Boolean = !iterator.hasNext diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala index 7504fa44d9fa..2252e28f22af 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.receiver import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.storage.StorageLevel import org.apache.spark.annotation.DeveloperApi @@ -116,12 +116,12 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * being pushed into Spark's memory. 
*/ def store(dataItem: T) { - executor.pushSingle(dataItem) + supervisor.pushSingle(dataItem) } /** Store an ArrayBuffer of received data as a data block into Spark's memory. */ def store(dataBuffer: ArrayBuffer[T]) { - executor.pushArrayBuffer(dataBuffer, None, None) + supervisor.pushArrayBuffer(dataBuffer, None, None) } /** @@ -130,12 +130,12 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * for being used in the corresponding InputDStream. */ def store(dataBuffer: ArrayBuffer[T], metadata: Any) { - executor.pushArrayBuffer(dataBuffer, Some(metadata), None) + supervisor.pushArrayBuffer(dataBuffer, Some(metadata), None) } /** Store an iterator of received data as a data block into Spark's memory. */ def store(dataIterator: Iterator[T]) { - executor.pushIterator(dataIterator, None, None) + supervisor.pushIterator(dataIterator, None, None) } /** @@ -144,12 +144,12 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * for being used in the corresponding InputDStream. */ def store(dataIterator: java.util.Iterator[T], metadata: Any) { - executor.pushIterator(dataIterator, Some(metadata), None) + supervisor.pushIterator(dataIterator.asScala, Some(metadata), None) } /** Store an iterator of received data as a data block into Spark's memory. */ def store(dataIterator: java.util.Iterator[T]) { - executor.pushIterator(dataIterator, None, None) + supervisor.pushIterator(dataIterator.asScala, None, None) } /** @@ -158,7 +158,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * for being used in the corresponding InputDStream. */ def store(dataIterator: Iterator[T], metadata: Any) { - executor.pushIterator(dataIterator, Some(metadata), None) + supervisor.pushIterator(dataIterator, Some(metadata), None) } /** @@ -167,7 +167,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * that Spark is configured to use. */ def store(bytes: ByteBuffer) { - executor.pushBytes(bytes, None, None) + supervisor.pushBytes(bytes, None, None) } /** @@ -176,12 +176,12 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * for being used in the corresponding InputDStream. */ def store(bytes: ByteBuffer, metadata: Any) { - executor.pushBytes(bytes, Some(metadata), None) + supervisor.pushBytes(bytes, Some(metadata), None) } /** Report exceptions in receiving data. */ def reportError(message: String, throwable: Throwable) { - executor.reportError(message, throwable) + supervisor.reportError(message, throwable) } /** @@ -193,7 +193,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * The `message` will be reported to the driver. */ def restart(message: String) { - executor.restartReceiver(message) + supervisor.restartReceiver(message) } /** @@ -205,7 +205,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * The `message` and `exception` will be reported to the driver. */ def restart(message: String, error: Throwable) { - executor.restartReceiver(message, Some(error)) + supervisor.restartReceiver(message, Some(error)) } /** @@ -215,22 +215,22 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * in a background thread. */ def restart(message: String, error: Throwable, millisecond: Int) { - executor.restartReceiver(message, Some(error), millisecond) + supervisor.restartReceiver(message, Some(error), millisecond) } /** Stop the receiver completely. 
*/ def stop(message: String) { - executor.stop(message, None) + supervisor.stop(message, None) } /** Stop the receiver completely due to an exception */ def stop(message: String, error: Throwable) { - executor.stop(message, Some(error)) + supervisor.stop(message, Some(error)) } /** Check if the receiver has started or not. */ def isStarted(): Boolean = { - executor.isReceiverStarted() + supervisor.isReceiverStarted() } /** @@ -238,7 +238,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * the receiving of data should be stopped. */ def isStopped(): Boolean = { - executor.isReceiverStopped() + supervisor.isReceiverStopped() } /** @@ -257,7 +257,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable private var id: Int = -1 /** Handler object that runs the receiver. This is instantiated lazily in the worker. */ - private[streaming] var executor_ : ReceiverSupervisor = null + @transient private var _supervisor : ReceiverSupervisor = null /** Set the ID of the DStream that this receiver is associated with. */ private[streaming] def setReceiverId(id_ : Int) { @@ -265,15 +265,17 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable } /** Attach Network Receiver executor to this receiver. */ - private[streaming] def attachExecutor(exec: ReceiverSupervisor) { - assert(executor_ == null) - executor_ = exec + private[streaming] def attachSupervisor(exec: ReceiverSupervisor) { + assert(_supervisor == null) + _supervisor = exec } - /** Get the attached executor. */ - private def executor: ReceiverSupervisor = { - assert(executor_ != null, "Executor has not been attached to this receiver") - executor_ + /** Get the attached supervisor. */ + private[streaming] def supervisor: ReceiverSupervisor = { + assert(_supervisor != null, + "A ReceiverSupervisor have not been attached to the receiver yet. Maybe you are starting " + + "some computation in the receiver before the Receiver.onStart() has been called.") + _supervisor } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala index e98017a63756..158d1ba2f183 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala @@ -44,8 +44,8 @@ private[streaming] abstract class ReceiverSupervisor( } import ReceiverState._ - // Attach the executor to the receiver - receiver.attachExecutor(this) + // Attach the supervisor to the receiver + receiver.attachSupervisor(this) private val futureExecutionContext = ExecutionContext.fromExecutorService( ThreadUtils.newDaemonCachedThreadPool("receiver-supervisor-future", 128)) @@ -60,7 +60,7 @@ private[streaming] abstract class ReceiverSupervisor( private val defaultRestartDelay = conf.getInt("spark.streaming.receiverRestartDelay", 2000) /** The current maximum rate limit for this receiver. 
*/ - private[streaming] def getCurrentRateLimit: Option[Long] = None + private[streaming] def getCurrentRateLimit: Long = Long.MaxValue /** Exception associated with the stopping of the receiver */ @volatile protected var stoppingError: Throwable = null @@ -92,13 +92,30 @@ private[streaming] abstract class ReceiverSupervisor( optionalBlockId: Option[StreamBlockId] ) + /** + * Create a custom [[BlockGenerator]] that the receiver implementation can directly control + * using their provided [[BlockGeneratorListener]]. + * + * Note: Do not explicitly start or stop the `BlockGenerator`, the `ReceiverSupervisorImpl` + * will take care of it. + */ + def createBlockGenerator(blockGeneratorListener: BlockGeneratorListener): BlockGenerator + /** Report errors. */ def reportError(message: String, throwable: Throwable) - /** Called when supervisor is started */ + /** + * Called when supervisor is started. + * Note that this must be called before the receiver.onStart() is called to ensure + * things like [[BlockGenerator]]s are started before the receiver starts sending data. + */ protected def onStart() { } - /** Called when supervisor is stopped */ + /** + * Called when supervisor is stopped. + * Note that this must be called after the receiver.onStop() is called to ensure + * things like [[BlockGenerator]]s are cleaned up after the receiver stops sending data. + */ protected def onStop(message: String, error: Option[Throwable]) { } /** Called when receiver is started. Return true if the driver accepts us */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 0d802f83549a..167f56aa4228 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -20,6 +20,7 @@ package org.apache.spark.streaming.receiver import java.nio.ByteBuffer import java.util.concurrent.atomic.AtomicLong +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import com.google.common.base.Throwables @@ -46,7 +47,8 @@ private[streaming] class ReceiverSupervisorImpl( checkpointDirOption: Option[String] ) extends ReceiverSupervisor(receiver, env.conf) with Logging { - private val hostPort = SparkEnv.get.blockManager.blockManagerId.hostPort + private val host = SparkEnv.get.blockManager.blockManagerId.host + private val executorId = SparkEnv.get.blockManager.blockManagerId.executorId private val receivedBlockHandler: ReceivedBlockHandler = { if (WriteAheadLogUtils.enableReceiverLog(env.conf)) { @@ -81,15 +83,20 @@ private[streaming] class ReceiverSupervisorImpl( cleanupOldBlocks(threshTime) case UpdateRateLimit(eps) => logInfo(s"Received a new rate limit: $eps.") - blockGenerator.updateRate(eps) + registeredBlockGenerators.foreach { bg => + bg.updateRate(eps) + } } }) /** Unique block ids if one wants to add blocks directly */ private val newBlockId = new AtomicLong(System.currentTimeMillis()) + private val registeredBlockGenerators = new mutable.ArrayBuffer[BlockGenerator] + with mutable.SynchronizedBuffer[BlockGenerator] + /** Divides received data records into data blocks for pushing in BlockManager. 
*/ - private val blockGenerator = new BlockGenerator(new BlockGeneratorListener { + private val defaultBlockGeneratorListener = new BlockGeneratorListener { def onAddData(data: Any, metadata: Any): Unit = { } def onGenerateBlock(blockId: StreamBlockId): Unit = { } @@ -101,14 +108,15 @@ private[streaming] class ReceiverSupervisorImpl( def onPushBlock(blockId: StreamBlockId, arrayBuffer: ArrayBuffer[_]) { pushArrayBuffer(arrayBuffer, None, Some(blockId)) } - }, streamId, env.conf) + } + private val defaultBlockGenerator = createBlockGenerator(defaultBlockGeneratorListener) - override private[streaming] def getCurrentRateLimit: Option[Long] = - Some(blockGenerator.getCurrentLimit) + /** Get the current rate limit of the default block generator */ + override private[streaming] def getCurrentRateLimit: Long = defaultBlockGenerator.getCurrentLimit /** Push a single record of received data into block generator. */ def pushSingle(data: Any) { - blockGenerator.addData(data) + defaultBlockGenerator.addData(data) } /** Store an ArrayBuffer of received data as a data block into Spark's memory. */ @@ -162,17 +170,17 @@ private[streaming] class ReceiverSupervisorImpl( } override protected def onStart() { - blockGenerator.start() + registeredBlockGenerators.foreach { _.start() } } override protected def onStop(message: String, error: Option[Throwable]) { - blockGenerator.stop() + registeredBlockGenerators.foreach { _.stop() } env.rpcEnv.stop(endpoint) } override protected def onReceiverStart(): Boolean = { val msg = RegisterReceiver( - streamId, receiver.getClass.getSimpleName, hostPort, endpoint) + streamId, receiver.getClass.getSimpleName, host, executorId, endpoint) trackerEndpoint.askWithRetry[Boolean](msg) } @@ -183,6 +191,16 @@ private[streaming] class ReceiverSupervisorImpl( logInfo("Stopped receiver " + streamId) } + override def createBlockGenerator( + blockGeneratorListener: BlockGeneratorListener): BlockGenerator = { + // Cleanup BlockGenerators that have already been stopped + registeredBlockGenerators --= registeredBlockGenerators.filter{ _.isStopped() } + + val newBlockGenerator = new BlockGenerator(blockGeneratorListener, streamId, env.conf) + registeredBlockGenerators += newBlockGenerator + newBlockGenerator + } + /** Generate new block ID */ private def nextBlockId = StreamBlockId(streamId, newBlockId.getAndIncrement) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala index 9922b6bc1201..436eb0a56614 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala @@ -29,6 +29,7 @@ import org.apache.spark.streaming.Time * the streaming scheduler queue * @param processingStartTime Clock time of when the first job of this batch started processing * @param processingEndTime Clock time of when the last job of this batch finished processing + * @param outputOperationInfos The output operations in this batch */ @DeveloperApi case class BatchInfo( @@ -36,7 +37,8 @@ case class BatchInfo( streamIdToInputInfo: Map[Int, StreamInputInfo], submissionTime: Long, processingStartTime: Option[Long], - processingEndTime: Option[Long] + processingEndTime: Option[Long], + outputOperationInfos: Map[Int, OutputOperationInfo] ) { @deprecated("Use streamIdToInputInfo instead", "1.5.0") @@ -67,4 +69,5 @@ case class BatchInfo( * The number of recorders received by the receivers in this 
batch. */ def numRecords: Long = streamIdToInputInfo.values.map(_.numRecords).sum + } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala index 363c03d431f0..deb15d075975 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala @@ -66,7 +66,7 @@ private[streaming] class InputInfoTracker(ssc: StreamingContext) extends Logging new mutable.HashMap[Int, StreamInputInfo]()) if (inputInfos.contains(inputInfo.inputStreamId)) { - throw new IllegalStateException(s"Input stream ${inputInfo.inputStreamId}} for batch" + + throw new IllegalStateException(s"Input stream ${inputInfo.inputStreamId} for batch" + s"$batchTime is already added into InputInfoTracker, this is a illegal state") } inputInfos += ((inputInfo.inputStreamId, inputInfo)) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala index 3c481bf3491f..ab1b3565fcc1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala @@ -17,8 +17,10 @@ package org.apache.spark.streaming.scheduler +import scala.util.{Failure, Try} + import org.apache.spark.streaming.Time -import scala.util.Try +import org.apache.spark.util.{Utils, CallSite} /** * Class representing a Spark computation. It may contain multiple Spark jobs. @@ -29,6 +31,9 @@ class Job(val time: Time, func: () => _) { private var _outputOpId: Int = _ private var isSet = false private var _result: Try[_] = null + private var _callSite: CallSite = null + private var _startTime: Option[Long] = None + private var _endTime: Option[Long] = None def run() { _result = Try(func()) @@ -70,5 +75,29 @@ class Job(val time: Time, func: () => _) { _outputOpId = outputOpId } + def setCallSite(callSite: CallSite): Unit = { + _callSite = callSite + } + + def callSite: CallSite = _callSite + + def setStartTime(startTime: Long): Unit = { + _startTime = Some(startTime) + } + + def setEndTime(endTime: Long): Unit = { + _endTime = Some(endTime) + } + + def toOutputOperationInfo: OutputOperationInfo = { + val failureReason = if (_result != null && _result.isFailure) { + Some(Utils.exceptionString(_result.asInstanceOf[Failure[_]].exception)) + } else { + None + } + OutputOperationInfo( + time, outputOpId, callSite.shortForm, callSite.longForm, _startTime, _endTime, failureReason) + } + override def toString: String = id } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 9f2117ada61c..8dfdc1f57b40 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -79,6 +79,10 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { def start(): Unit = synchronized { if (eventLoop != null) return // generator has already been started + // Call checkpointWriter here to initialize it before eventLoop uses it to avoid a deadlock. 
+ // See SPARK-10125 + checkpointWriter + eventLoop = new EventLoop[JobGeneratorEvent]("JobGenerator") { override protected def onReceive(event: JobGeneratorEvent): Unit = processEvent(event) @@ -216,7 +220,8 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { logInfo("Batches pending processing (" + pendingTimes.size + " batches): " + pendingTimes.mkString(", ")) // Reschedule jobs for these times - val timesToReschedule = (pendingTimes ++ downTimes).distinct.sorted(Time.ordering) + val timesToReschedule = (pendingTimes ++ downTimes).filter { _ < restartTime } + .distinct.sorted(Time.ordering) logInfo("Batches to reschedule (" + timesToReschedule.size + " batches): " + timesToReschedule.mkString(", ")) timesToReschedule.foreach { time => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 58bdda7794bf..1ed6fb0aa9d5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -17,20 +17,21 @@ package org.apache.spark.streaming.scheduler -import java.util.concurrent.{TimeUnit, ConcurrentHashMap, Executors} +import java.util.concurrent.{ConcurrentHashMap, TimeUnit} -import scala.collection.JavaConversions._ -import scala.util.{Failure, Success} +import scala.collection.JavaConverters._ +import scala.util.Failure import org.apache.spark.Logging import org.apache.spark.rdd.PairRDDFunctions import org.apache.spark.streaming._ -import org.apache.spark.util.EventLoop +import org.apache.spark.streaming.ui.UIUtils +import org.apache.spark.util.{EventLoop, ThreadUtils, Utils} private[scheduler] sealed trait JobSchedulerEvent -private[scheduler] case class JobStarted(job: Job) extends JobSchedulerEvent -private[scheduler] case class JobCompleted(job: Job) extends JobSchedulerEvent +private[scheduler] case class JobStarted(job: Job, startTime: Long) extends JobSchedulerEvent +private[scheduler] case class JobCompleted(job: Job, completedTime: Long) extends JobSchedulerEvent private[scheduler] case class ErrorReported(msg: String, e: Throwable) extends JobSchedulerEvent /** @@ -40,9 +41,12 @@ private[scheduler] case class ErrorReported(msg: String, e: Throwable) extends J private[streaming] class JobScheduler(val ssc: StreamingContext) extends Logging { - private val jobSets = new ConcurrentHashMap[Time, JobSet] + // Use of ConcurrentHashMap.keySet later causes an odd runtime problem due to Java 7/8 diff + // https://gist.github.com/AlainODea/1375759b8720a3f9f094 + private val jobSets: java.util.Map[Time, JobSet] = new ConcurrentHashMap[Time, JobSet] private val numConcurrentJobs = ssc.conf.getInt("spark.streaming.concurrentJobs", 1) - private val jobExecutor = Executors.newFixedThreadPool(numConcurrentJobs) + private val jobExecutor = + ThreadUtils.newDaemonFixedThreadPool(numConcurrentJobs, "streaming-job-executor") private val jobGenerator = new JobGenerator(this) val clock = jobGenerator.clock val listenerBus = new StreamingListenerBus() @@ -84,8 +88,10 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { if (eventLoop == null) return // scheduler has already been stopped logDebug("Stopping JobScheduler") - // First, stop receiving - receiverTracker.stop(processAllReceivedData) + if (receiverTracker != null) { + // First, stop receiving + receiverTracker.stop(processAllReceivedData) + } // Second, stop generating jobs. 
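The `jobSets` field above is deliberately typed as `java.util.Map` and read through `JavaConverters` because `ConcurrentHashMap.keySet()` changed its return type (to `KeySetView`) in Java 8, which can trigger a `NoSuchMethodError` when code compiled on one JDK runs on the other. A minimal sketch of the same workaround, with illustrative key/value types:

```scala
import java.util.concurrent.ConcurrentHashMap

import scala.collection.JavaConverters._

// Keep the field typed as the java.util.Map interface and iterate via asScala, so the
// compiled bytecode never references ConcurrentHashMap.keySet()'s JDK-specific return type.
val jobSets: java.util.Map[Long, String] = new ConcurrentHashMap[Long, String]()
jobSets.put(1000L, "jobset for batch 1000")

val pendingTimes: Seq[Long] = jobSets.asScala.keys.toSeq
```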
If it has to process all received data, // then this will wait for all the processing through JobScheduler to be over. @@ -125,7 +131,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { } def getPendingTimes(): Seq[Time] = { - jobSets.keySet.toSeq + jobSets.asScala.keys.toSeq } def reportError(msg: String, e: Throwable) { @@ -139,8 +145,8 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { private def processEvent(event: JobSchedulerEvent) { try { event match { - case JobStarted(job) => handleJobStart(job) - case JobCompleted(job) => handleJobCompletion(job) + case JobStarted(job, startTime) => handleJobStart(job, startTime) + case JobCompleted(job, completedTime) => handleJobCompletion(job, completedTime) case ErrorReported(m, e) => handleError(m, e) } } catch { @@ -149,7 +155,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { } } - private def handleJobStart(job: Job) { + private def handleJobStart(job: Job, startTime: Long) { val jobSet = jobSets.get(job.time) val isFirstJobOfJobSet = !jobSet.hasStarted jobSet.handleJobStart(job) @@ -158,26 +164,30 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { // correct "jobSet.processingStartTime". listenerBus.post(StreamingListenerBatchStarted(jobSet.toBatchInfo)) } + job.setStartTime(startTime) + listenerBus.post(StreamingListenerOutputOperationStarted(job.toOutputOperationInfo)) logInfo("Starting job " + job.id + " from job set of time " + jobSet.time) } - private def handleJobCompletion(job: Job) { + private def handleJobCompletion(job: Job, completedTime: Long) { + val jobSet = jobSets.get(job.time) + jobSet.handleJobCompletion(job) + job.setEndTime(completedTime) + listenerBus.post(StreamingListenerOutputOperationCompleted(job.toOutputOperationInfo)) + logInfo("Finished job " + job.id + " from job set of time " + jobSet.time) + if (jobSet.hasCompleted) { + jobSets.remove(jobSet.time) + jobGenerator.onBatchCompletion(jobSet.time) + logInfo("Total delay: %.3f s for time %s (execution: %.3f s)".format( + jobSet.totalDelay / 1000.0, jobSet.time.toString, + jobSet.processingDelay / 1000.0 + )) + listenerBus.post(StreamingListenerBatchCompleted(jobSet.toBatchInfo)) + } job.result match { - case Success(_) => - val jobSet = jobSets.get(job.time) - jobSet.handleJobCompletion(job) - logInfo("Finished job " + job.id + " from job set of time " + jobSet.time) - if (jobSet.hasCompleted) { - jobSets.remove(jobSet.time) - jobGenerator.onBatchCompletion(jobSet.time) - logInfo("Total delay: %.3f s for time %s (execution: %.3f s)".format( - jobSet.totalDelay / 1000.0, jobSet.time.toString, - jobSet.processingDelay / 1000.0 - )) - listenerBus.post(StreamingListenerBatchCompleted(jobSet.toBatchInfo)) - } case Failure(e) => reportError("Error running job " + job, e) + case _ => } } @@ -187,18 +197,39 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { } private class JobHandler(job: Job) extends Runnable with Logging { + import JobScheduler._ + def run() { - ssc.sc.setLocalProperty(JobScheduler.BATCH_TIME_PROPERTY_KEY, job.time.milliseconds.toString) - ssc.sc.setLocalProperty(JobScheduler.OUTPUT_OP_ID_PROPERTY_KEY, job.outputOpId.toString) try { - eventLoop.post(JobStarted(job)) - // Disable checks for existing output directories in jobs launched by the streaming - // scheduler, since we may need to write output to an existing directory during checkpoint - // recovery; see SPARK-4835 for more details. 
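With `handleJobStart`/`handleJobCompletion` now posting `StreamingListenerOutputOperationStarted` and `StreamingListenerOutputOperationCompleted`, user code can observe individual output operations rather than only whole batches. A minimal sketch of a listener that logs each output operation's duration; the class name is invented, and registration assumes an existing `StreamingContext` named `ssc`:

```scala
import org.apache.spark.streaming.scheduler.{
  StreamingListener, StreamingListenerOutputOperationCompleted}

// Illustrative listener built on the events introduced in this patch.
class OutputOpDurationLogger extends StreamingListener {
  override def onOutputOperationCompleted(
      outputOperationCompleted: StreamingListenerOutputOperationCompleted): Unit = {
    val info = outputOperationCompleted.outputOperationInfo
    val duration = info.duration.map(d => s"$d ms").getOrElse("unknown")
    val failure = info.failureReason.map(r => s" (failed: $r)").getOrElse("")
    println(s"Batch ${info.batchTime}: output op ${info.id} [${info.name}] took $duration$failure")
  }
}

// ssc.addStreamingListener(new OutputOpDurationLogger)  // register on an existing context
```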
- PairRDDFunctions.disableOutputSpecValidation.withValue(true) { - job.run() + val formattedTime = UIUtils.formatBatchTime( + job.time.milliseconds, ssc.graph.batchDuration.milliseconds, showYYYYMMSS = false) + val batchUrl = s"/streaming/batch/?id=${job.time.milliseconds}" + val batchLinkText = s"[output operation ${job.outputOpId}, batch time ${formattedTime}]" + + ssc.sc.setJobDescription( + s"""Streaming job from $batchLinkText""") + ssc.sc.setLocalProperty(BATCH_TIME_PROPERTY_KEY, job.time.milliseconds.toString) + ssc.sc.setLocalProperty(OUTPUT_OP_ID_PROPERTY_KEY, job.outputOpId.toString) + + // We need to assign `eventLoop` to a temp variable. Otherwise, because + // `JobScheduler.stop(false)` may set `eventLoop` to null when this method is running, then + // it's possible that when `post` is called, `eventLoop` happens to null. + var _eventLoop = eventLoop + if (_eventLoop != null) { + _eventLoop.post(JobStarted(job, clock.getTimeMillis())) + // Disable checks for existing output directories in jobs launched by the streaming + // scheduler, since we may need to write output to an existing directory during checkpoint + // recovery; see SPARK-4835 for more details. + PairRDDFunctions.disableOutputSpecValidation.withValue(true) { + job.run() + } + _eventLoop = eventLoop + if (_eventLoop != null) { + _eventLoop.post(JobCompleted(job, clock.getTimeMillis())) + } + } else { + // JobScheduler has been stopped. } - eventLoop.post(JobCompleted(job)) } finally { ssc.sc.setLocalProperty(JobScheduler.BATCH_TIME_PROPERTY_KEY, null) ssc.sc.setLocalProperty(JobScheduler.OUTPUT_OP_ID_PROPERTY_KEY, null) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala index 95833efc9417..f76300351e3c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala @@ -18,8 +18,10 @@ package org.apache.spark.streaming.scheduler import scala.collection.mutable.HashSet +import scala.util.Failure import org.apache.spark.streaming.Time +import org.apache.spark.util.Utils /** Class representing a set of Jobs * belong to the same batch. @@ -62,12 +64,13 @@ case class JobSet( } def toBatchInfo: BatchInfo = { - new BatchInfo( + BatchInfo( time, streamIdToInputInfo, submissionTime, - if (processingStartTime >= 0 ) Some(processingStartTime) else None, - if (processingEndTime >= 0 ) Some(processingEndTime) else None + if (processingStartTime >= 0) Some(processingStartTime) else None, + if (processingEndTime >= 0) Some(processingEndTime) else None, + jobs.map { job => (job.outputOpId, job.toOutputOperationInfo) }.toMap ) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/OutputOperationInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/OutputOperationInfo.scala new file mode 100644 index 000000000000..137e512a670d --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/OutputOperationInfo.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.streaming.Time + +/** + * :: DeveloperApi :: + * Class having information on output operations. + * @param batchTime Time of the batch + * @param id Id of this output operation. Different output operations have different ids in a batch. + * @param name The name of this output operation. + * @param description The description of this output operation. + * @param startTime Clock time of when the output operation started processing + * @param endTime Clock time of when the output operation started processing + * @param failureReason Failure reason if this output operation fails + */ +@DeveloperApi +case class OutputOperationInfo( + batchTime: Time, + id: Int, + name: String, + description: String, + startTime: Option[Long], + endTime: Option[Long], + failureReason: Option[String]) { + + /** + * Return the duration of this output operation. + */ + def duration: Option[Long] = for (s <- startTime; e <- endTime) yield e - s +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala index 882ca0676b6a..a46c0c1b25e7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala @@ -76,9 +76,9 @@ private[streaming] abstract class RateController(val streamUID: Int, rateEstimat val elements = batchCompleted.batchInfo.streamIdToInputInfo for { - processingEnd <- batchCompleted.batchInfo.processingEndTime; - workDelay <- batchCompleted.batchInfo.processingDelay; - waitDelay <- batchCompleted.batchInfo.schedulingDelay; + processingEnd <- batchCompleted.batchInfo.processingEndTime + workDelay <- batchCompleted.batchInfo.processingDelay + waitDelay <- batchCompleted.batchInfo.schedulingDelay elems <- elements.get(streamUID).map(_.numRecords) } computeAndPublish(processingEnd, elems, workDelay, waitDelay) } @@ -86,5 +86,5 @@ private[streaming] abstract class RateController(val streamUID: Int, rateEstimat object RateController { def isBackPressureEnabled(conf: SparkConf): Boolean = - conf.getBoolean("spark.streaming.backpressure.enable", false) + conf.getBoolean("spark.streaming.backpressure.enabled", false) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index 7720259a5d79..4dab64d696b3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -19,16 +19,19 @@ package org.apache.spark.streaming.scheduler import java.nio.ByteBuffer +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.language.implicitConversions +import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import 
org.apache.hadoop.fs.Path +import org.apache.spark.network.util.JavaUtils import org.apache.spark.streaming.Time -import org.apache.spark.streaming.util.{WriteAheadLog, WriteAheadLogUtils} +import org.apache.spark.streaming.util.{BatchedWriteAheadLog, WriteAheadLog, WriteAheadLogUtils} import org.apache.spark.util.{Clock, Utils} -import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.{Logging, SparkConf} /** Trait representing any event in the ReceivedBlockTracker that updates its state. */ private[streaming] sealed trait ReceivedBlockTrackerLogEvent @@ -40,7 +43,6 @@ private[streaming] case class BatchAllocationEvent(time: Time, allocatedBlocks: private[streaming] case class BatchCleanupEvent(times: Seq[Time]) extends ReceivedBlockTrackerLogEvent - /** Class representing the blocks of all the streams allocated to a batch */ private[streaming] case class AllocatedBlocks(streamIdToAllocatedBlocks: Map[Int, Seq[ReceivedBlockInfo]]) { @@ -81,15 +83,22 @@ private[streaming] class ReceivedBlockTracker( } /** Add received block. This event will get written to the write ahead log (if enabled). */ - def addBlock(receivedBlockInfo: ReceivedBlockInfo): Boolean = synchronized { + def addBlock(receivedBlockInfo: ReceivedBlockInfo): Boolean = { try { - writeToLog(BlockAdditionEvent(receivedBlockInfo)) - getReceivedBlockQueue(receivedBlockInfo.streamId) += receivedBlockInfo - logDebug(s"Stream ${receivedBlockInfo.streamId} received " + - s"block ${receivedBlockInfo.blockStoreResult.blockId}") - true + val writeResult = writeToLog(BlockAdditionEvent(receivedBlockInfo)) + if (writeResult) { + synchronized { + getReceivedBlockQueue(receivedBlockInfo.streamId) += receivedBlockInfo + } + logDebug(s"Stream ${receivedBlockInfo.streamId} received " + + s"block ${receivedBlockInfo.blockStoreResult.blockId}") + } else { + logDebug(s"Failed to acknowledge stream ${receivedBlockInfo.streamId} receiving " + + s"block ${receivedBlockInfo.blockStoreResult.blockId} in the Write Ahead Log.") + } + writeResult } catch { - case e: Exception => + case NonFatal(e) => logError(s"Error adding block $receivedBlockInfo", e) false } @@ -105,10 +114,12 @@ private[streaming] class ReceivedBlockTracker( (streamId, getReceivedBlockQueue(streamId).dequeueAll(x => true)) }.toMap val allocatedBlocks = AllocatedBlocks(streamIdToBlocks) - writeToLog(BatchAllocationEvent(batchTime, allocatedBlocks)) - timeToAllocatedBlocks(batchTime) = allocatedBlocks - lastAllocatedBatchTime = batchTime - allocatedBlocks + if (writeToLog(BatchAllocationEvent(batchTime, allocatedBlocks))) { + timeToAllocatedBlocks.put(batchTime, allocatedBlocks) + lastAllocatedBatchTime = batchTime + } else { + logInfo(s"Possibly processed batch $batchTime need to be processed again in WAL recovery") + } } else { // This situation occurs when: // 1. 
WAL is ended with BatchAllocationEvent, but without BatchCleanupEvent, @@ -156,9 +167,12 @@ private[streaming] class ReceivedBlockTracker( require(cleanupThreshTime.milliseconds < clock.getTimeMillis()) val timesToCleanup = timeToAllocatedBlocks.keys.filter { _ < cleanupThreshTime }.toSeq logInfo("Deleting batches " + timesToCleanup) - writeToLog(BatchCleanupEvent(timesToCleanup)) - timeToAllocatedBlocks --= timesToCleanup - writeAheadLogOption.foreach(_.clean(cleanupThreshTime.milliseconds, waitForCompletion)) + if (writeToLog(BatchCleanupEvent(timesToCleanup))) { + timeToAllocatedBlocks --= timesToCleanup + writeAheadLogOption.foreach(_.clean(cleanupThreshTime.milliseconds, waitForCompletion)) + } else { + logWarning("Failed to acknowledge batch clean up in the Write Ahead Log.") + } } /** Stop the block tracker. */ @@ -184,8 +198,8 @@ private[streaming] class ReceivedBlockTracker( logTrace(s"Recovery: Inserting allocated batch for time $batchTime to " + s"${allocatedBlocks.streamIdToAllocatedBlocks}") streamIdToUnallocatedBlockQueues.values.foreach { _.clear() } - lastAllocatedBatchTime = batchTime timeToAllocatedBlocks.put(batchTime, allocatedBlocks) + lastAllocatedBatchTime = batchTime } // Cleanup the batch allocations @@ -196,10 +210,10 @@ private[streaming] class ReceivedBlockTracker( writeAheadLogOption.foreach { writeAheadLog => logInfo(s"Recovering from write ahead logs in ${checkpointDirOption.get}") - import scala.collection.JavaConversions._ - writeAheadLog.readAll().foreach { byteBuffer => - logTrace("Recovering record " + byteBuffer) - Utils.deserialize[ReceivedBlockTrackerLogEvent](byteBuffer.array) match { + writeAheadLog.readAll().asScala.foreach { byteBuffer => + logInfo("Recovering record " + byteBuffer) + Utils.deserialize[ReceivedBlockTrackerLogEvent]( + JavaUtils.bufferToArray(byteBuffer), Thread.currentThread().getContextClassLoader) match { case BlockAdditionEvent(receivedBlockInfo) => insertAddedBlock(receivedBlockInfo) case BatchAllocationEvent(time, allocatedBlocks) => @@ -212,12 +226,20 @@ private[streaming] class ReceivedBlockTracker( } /** Write an update to the tracker to the write ahead log */ - private def writeToLog(record: ReceivedBlockTrackerLogEvent) { + private def writeToLog(record: ReceivedBlockTrackerLogEvent): Boolean = { if (isWriteAheadLogEnabled) { - logDebug(s"Writing to log $record") - writeAheadLogOption.foreach { logManager => - logManager.write(ByteBuffer.wrap(Utils.serialize(record)), clock.getTimeMillis()) + logTrace(s"Writing record: $record") + try { + writeAheadLogOption.get.write(ByteBuffer.wrap(Utils.serialize(record)), + clock.getTimeMillis()) + true + } catch { + case NonFatal(e) => + logWarning(s"Exception thrown while writing record: $record to the WriteAheadLog.", e) + false } + } else { + true } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala index 59df892397fe..3b35964114c0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala @@ -30,6 +30,7 @@ case class ReceiverInfo( name: String, active: Boolean, location: String, + executorId: String, lastErrorMessage: String = "", lastError: String = "", lastErrorTime: Long = -1L diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala 
b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala index ef5b687b5831..391a461f0812 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala @@ -20,8 +20,40 @@ package org.apache.spark.streaming.scheduler import scala.collection.Map import scala.collection.mutable +import org.apache.spark.scheduler.{ExecutorCacheTaskLocation, TaskLocation} import org.apache.spark.streaming.receiver.Receiver +/** + * A class that tries to schedule receivers with evenly distributed. There are two phases for + * scheduling receivers. + * + * - The first phase is global scheduling when ReceiverTracker is starting and we need to schedule + * all receivers at the same time. ReceiverTracker will call `scheduleReceivers` at this phase. + * It will try to schedule receivers such that they are evenly distributed. ReceiverTracker + * should update its `receiverTrackingInfoMap` according to the results of `scheduleReceivers`. + * `ReceiverTrackingInfo.scheduledLocations` for each receiver should be set to an location list + * that contains the scheduled locations. Then when a receiver is starting, it will send a + * register request and `ReceiverTracker.registerReceiver` will be called. In + * `ReceiverTracker.registerReceiver`, if a receiver's scheduled locations is set, it should + * check if the location of this receiver is one of the scheduled locations, if not, the register + * will be rejected. + * - The second phase is local scheduling when a receiver is restarting. There are two cases of + * receiver restarting: + * - If a receiver is restarting because it's rejected due to the real location and the scheduled + * locations mismatching, in other words, it fails to start in one of the locations that + * `scheduleReceivers` suggested, `ReceiverTracker` should firstly choose the executors that + * are still alive in the list of scheduled locations, then use them to launch the receiver + * job. + * - If a receiver is restarting without a scheduled locations list, or the executors in the list + * are dead, `ReceiverTracker` should call `rescheduleReceiver`. If so, `ReceiverTracker` + * should not set `ReceiverTrackingInfo.scheduledLocations` for this receiver, instead, it + * should clear it. Then when this receiver is registering, we can know this is a local + * scheduling, and `ReceiverTrackingInfo` should call `rescheduleReceiver` again to check if + * the launching location is matching. + * + * In conclusion, we should make a global schedule, try to achieve that exactly as long as possible, + * otherwise do local scheduling. + */ private[streaming] class ReceiverSchedulingPolicy { /** @@ -38,9 +70,12 @@ private[streaming] class ReceiverSchedulingPolicy { * * * This method is called when we start to launch receivers at the first time. 
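The scheduling described above keys off each receiver's `preferredLocation`, which the policy maps onto the executors running on that host. That is the only knob a receiver author controls; a minimal, hypothetical receiver that pins itself to a host looks like this:

```scala
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.receiver.Receiver

// Hypothetical receiver that asks the scheduling policy to place it on a specific host.
class PinnedReceiver(host: String)
  extends Receiver[String](StorageLevel.MEMORY_AND_DISK_2) {

  override def preferredLocation: Option[String] = Some(host)

  def onStart(): Unit = {
    // Start a background thread here and call store(...) for every received record.
  }

  def onStop(): Unit = {
    // Signal the background thread to stop.
  }
}
```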
+ * + * @return a map for receivers and their scheduled locations */ def scheduleReceivers( - receivers: Seq[Receiver[_]], executors: Seq[String]): Map[Int, Seq[String]] = { + receivers: Seq[Receiver[_]], + executors: Seq[ExecutorCacheTaskLocation]): Map[Int, Seq[TaskLocation]] = { if (receivers.isEmpty) { return Map.empty } @@ -49,16 +84,16 @@ private[streaming] class ReceiverSchedulingPolicy { return receivers.map(_.streamId -> Seq.empty).toMap } - val hostToExecutors = executors.groupBy(_.split(":")(0)) - val scheduledExecutors = Array.fill(receivers.length)(new mutable.ArrayBuffer[String]) - val numReceiversOnExecutor = mutable.HashMap[String, Int]() + val hostToExecutors = executors.groupBy(_.host) + val scheduledLocations = Array.fill(receivers.length)(new mutable.ArrayBuffer[TaskLocation]) + val numReceiversOnExecutor = mutable.HashMap[ExecutorCacheTaskLocation, Int]() // Set the initial value to 0 executors.foreach(e => numReceiversOnExecutor(e) = 0) // Firstly, we need to respect "preferredLocation". So if a receiver has "preferredLocation", // we need to make sure the "preferredLocation" is in the candidate scheduled executor list. for (i <- 0 until receivers.length) { - // Note: preferredLocation is host but executors are host:port + // Note: preferredLocation is host but executors are host_executorId receivers(i).preferredLocation.foreach { host => hostToExecutors.get(host) match { case Some(executorsOnHost) => @@ -66,7 +101,7 @@ private[streaming] class ReceiverSchedulingPolicy { // this host val leastScheduledExecutor = executorsOnHost.minBy(executor => numReceiversOnExecutor(executor)) - scheduledExecutors(i) += leastScheduledExecutor + scheduledLocations(i) += leastScheduledExecutor numReceiversOnExecutor(leastScheduledExecutor) = numReceiversOnExecutor(leastScheduledExecutor) + 1 case None => @@ -75,17 +110,20 @@ private[streaming] class ReceiverSchedulingPolicy { // 1. This executor is not up. But it may be up later. // 2. This executor is dead, or it's not a host in the cluster. // Currently, simply add host to the scheduled executors. - scheduledExecutors(i) += host + + // Note: host could be `HDFSCacheTaskLocation`, so use `TaskLocation.apply` to handle + // this case + scheduledLocations(i) += TaskLocation(host) } } } // For those receivers that don't have preferredLocation, make sure we assign at least one // executor to them. - for (scheduledExecutorsForOneReceiver <- scheduledExecutors.filter(_.isEmpty)) { + for (scheduledLocationsForOneReceiver <- scheduledLocations.filter(_.isEmpty)) { // Select the executor that has the least receivers val (leastScheduledExecutor, numReceivers) = numReceiversOnExecutor.minBy(_._2) - scheduledExecutorsForOneReceiver += leastScheduledExecutor + scheduledLocationsForOneReceiver += leastScheduledExecutor numReceiversOnExecutor(leastScheduledExecutor) = numReceivers + 1 } @@ -93,23 +131,22 @@ private[streaming] class ReceiverSchedulingPolicy { val idleExecutors = numReceiversOnExecutor.filter(_._2 == 0).map(_._1) for (executor <- idleExecutors) { // Assign an idle executor to the receiver that has least candidate executors. - val leastScheduledExecutors = scheduledExecutors.minBy(_.size) + val leastScheduledExecutors = scheduledLocations.minBy(_.size) leastScheduledExecutors += executor } - receivers.map(_.streamId).zip(scheduledExecutors).toMap + receivers.map(_.streamId).zip(scheduledLocations).toMap } /** - * Return a list of candidate executors to run the receiver. 
If the list is empty, the caller can - * run this receiver in arbitrary executor. The caller can use `preferredNumExecutors` to require - * returning `preferredNumExecutors` executors if possible. + * Return a list of candidate locations to run the receiver. If the list is empty, the caller can + * run this receiver in arbitrary executor. * * This method tries to balance executors' load. Here is the approach to schedule executors * for a receiver. *
* 1. - * If preferredLocation is set, preferredLocation should be one of the candidate executors. + * If preferredLocation is set, preferredLocation should be one of the candidate locations.
* 2. * Every executor will be assigned to a weight according to the receivers running or @@ -122,9 +159,8 @@ private[streaming] class ReceiverSchedulingPolicy { * If a receiver is scheduled to an executor but has not yet run, it contributes * `1.0 / #candidate_executors_of_this_receiver` to the executor's weight.
* 3. - * At last, if there are more than `preferredNumExecutors` idle executors (weight = 0), - * returns all idle executors. Otherwise, we only return `preferredNumExecutors` best options - * according to the weights. + * At last, if there are any idle executors (weight = 0), returns all idle executors. + * Otherwise, returns the executors that have the minimum weight. + *
        * @@ -134,38 +170,58 @@ private[streaming] class ReceiverSchedulingPolicy { receiverId: Int, preferredLocation: Option[String], receiverTrackingInfoMap: Map[Int, ReceiverTrackingInfo], - executors: Seq[String], - preferredNumExecutors: Int = 3): Seq[String] = { + executors: Seq[ExecutorCacheTaskLocation]): Seq[TaskLocation] = { if (executors.isEmpty) { return Seq.empty } // Always try to schedule to the preferred locations - val scheduledExecutors = mutable.Set[String]() - scheduledExecutors ++= preferredLocation - - val executorWeights = receiverTrackingInfoMap.values.flatMap { receiverTrackingInfo => - receiverTrackingInfo.state match { - case ReceiverState.INACTIVE => Nil - case ReceiverState.SCHEDULED => - val scheduledExecutors = receiverTrackingInfo.scheduledExecutors.get - // The probability that a scheduled receiver will run in an executor is - // 1.0 / scheduledLocations.size - scheduledExecutors.map(location => location -> (1.0 / scheduledExecutors.size)) - case ReceiverState.ACTIVE => Seq(receiverTrackingInfo.runningExecutor.get -> 1.0) - } - }.groupBy(_._1).mapValues(_.map(_._2).sum) // Sum weights for each executor + val scheduledLocations = mutable.Set[TaskLocation]() + // Note: preferredLocation could be `HDFSCacheTaskLocation`, so use `TaskLocation.apply` to + // handle this case + scheduledLocations ++= preferredLocation.map(TaskLocation(_)) + + val executorWeights: Map[ExecutorCacheTaskLocation, Double] = { + receiverTrackingInfoMap.values.flatMap(convertReceiverTrackingInfoToExecutorWeights) + .groupBy(_._1).mapValues(_.map(_._2).sum) // Sum weights for each executor + } - val idleExecutors = (executors.toSet -- executorWeights.keys).toSeq - if (idleExecutors.size >= preferredNumExecutors) { - // If there are more than `preferredNumExecutors` idle executors, return all of them - scheduledExecutors ++= idleExecutors + val idleExecutors = executors.toSet -- executorWeights.keys + if (idleExecutors.nonEmpty) { + scheduledLocations ++= idleExecutors } else { - // If there are less than `preferredNumExecutors` idle executors, return 3 best options - scheduledExecutors ++= idleExecutors - val sortedExecutors = executorWeights.toSeq.sortBy(_._2).map(_._1) - scheduledExecutors ++= (idleExecutors ++ sortedExecutors).take(preferredNumExecutors) + // There is no idle executor. So select all executors that have the minimum weight. + val sortedExecutors = executorWeights.toSeq.sortBy(_._2) + if (sortedExecutors.nonEmpty) { + val minWeight = sortedExecutors(0)._2 + scheduledLocations ++= sortedExecutors.takeWhile(_._2 == minWeight).map(_._1) + } else { + // This should not happen since "executors" is not empty + } + } + scheduledLocations.toSeq + } + + /** + * This method tries to convert a receiver tracking info to executor weights. Every executor will + * be assigned to a weight according to the receivers running or scheduling on it: + * + * - If a receiver is running on an executor, it contributes 1.0 to the executor's weight. + * - If a receiver is scheduled to an executor but has not yet run, it contributes + * `1.0 / #candidate_executors_of_this_receiver` to the executor's weight. 
+ */ + private def convertReceiverTrackingInfoToExecutorWeights( + receiverTrackingInfo: ReceiverTrackingInfo): Seq[(ExecutorCacheTaskLocation, Double)] = { + receiverTrackingInfo.state match { + case ReceiverState.INACTIVE => Nil + case ReceiverState.SCHEDULED => + val scheduledLocations = receiverTrackingInfo.scheduledLocations.get + // The probability that a scheduled receiver will run in an executor is + // 1.0 / scheduledLocations.size + scheduledLocations.filter(_.isInstanceOf[ExecutorCacheTaskLocation]).map { location => + location.asInstanceOf[ExecutorCacheTaskLocation] -> (1.0 / scheduledLocations.size) + } + case ReceiverState.ACTIVE => Seq(receiverTrackingInfo.runningExecutor.get -> 1.0) } - scheduledExecutors.toSeq } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index e076fb5ea174..ea5d12b50fcc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -17,20 +17,21 @@ package org.apache.spark.streaming.scheduler -import java.util.concurrent.{TimeUnit, CountDownLatch} +import java.util.concurrent.{CountDownLatch, TimeUnit} import scala.collection.mutable.HashMap -import scala.concurrent.ExecutionContext +import scala.concurrent.{Future, ExecutionContext} import scala.language.existentials import scala.util.{Failure, Success} -import org.apache.spark.streaming.util.WriteAheadLogUtils import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.rpc._ +import org.apache.spark.scheduler.{TaskLocation, ExecutorCacheTaskLocation} import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.receiver._ -import org.apache.spark.util.{ThreadUtils, SerializableConfiguration} +import org.apache.spark.streaming.util.WriteAheadLogUtils +import org.apache.spark.util.{SerializableConfiguration, ThreadUtils, Utils} /** Enumeration to identify current state of a Receiver */ @@ -47,7 +48,8 @@ private[streaming] sealed trait ReceiverTrackerMessage private[streaming] case class RegisterReceiver( streamId: Int, typ: String, - hostPort: String, + host: String, + executorId: String, receiverEndpoint: RpcEndpointRef ) extends ReceiverTrackerMessage private[streaming] case class AddBlock(receivedBlockInfo: ReceivedBlockInfo) @@ -235,7 +237,8 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false private def registerReceiver( streamId: Int, typ: String, - hostPort: String, + host: String, + executorId: String, receiverEndpoint: RpcEndpointRef, senderAddress: RpcAddress ): Boolean = { @@ -244,8 +247,26 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false } if (isTrackerStopping || isTrackerStopped) { - false - } else if (!scheduleReceiver(streamId).contains(hostPort)) { + return false + } + + val scheduledLocations = receiverTrackingInfos(streamId).scheduledLocations + val acceptableExecutors = if (scheduledLocations.nonEmpty) { + // This receiver is registering and it's scheduled by + // ReceiverSchedulingPolicy.scheduleReceivers. So use "scheduledLocations" to check it. + scheduledLocations.get + } else { + // This receiver is scheduled by "ReceiverSchedulingPolicy.rescheduleReceiver", so calling + // "ReceiverSchedulingPolicy.rescheduleReceiver" again to check it. 
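The rewritten selection above boils down to: compute a weight per executor, hand the receiver to idle executors if any exist, otherwise to every executor tied for the minimum weight. A standalone sketch of that rule, using a stand-in `ExecLocation` case class instead of Spark's internal `ExecutorCacheTaskLocation`:

```scala
// Stand-in for Spark's internal ExecutorCacheTaskLocation (host + executor id).
case class ExecLocation(host: String, executorId: String)

// Prefer idle executors (no weight recorded); otherwise take all executors tied for the
// minimum weight, mirroring the rescheduleReceiver logic above.
def pickCandidates(
    executors: Seq[ExecLocation],
    weights: Map[ExecLocation, Double]): Seq[ExecLocation] = {
  if (executors.isEmpty) {
    Seq.empty
  } else {
    val idle = executors.filterNot(weights.contains)
    if (idle.nonEmpty) {
      idle
    } else {
      val sorted = executors.map(e => e -> weights.getOrElse(e, 0.0)).sortBy(_._2)
      val minWeight = sorted.head._2
      sorted.takeWhile(_._2 == minWeight).map(_._1)
    }
  }
}

// pickCandidates(Seq(ExecLocation("h1", "1"), ExecLocation("h2", "2")),
//                Map(ExecLocation("h1", "1") -> 1.0))  // => Seq(ExecLocation("h2", "2"))
```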
+ scheduleReceiver(streamId) + } + + def isAcceptable: Boolean = acceptableExecutors.exists { + case loc: ExecutorCacheTaskLocation => loc.executorId == executorId + case loc: TaskLocation => loc.host == host + } + + if (!isAcceptable) { // Refuse it since it's scheduled to a wrong executor false } else { @@ -253,8 +274,8 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false val receiverTrackingInfo = ReceiverTrackingInfo( streamId, ReceiverState.ACTIVE, - scheduledExecutors = None, - runningExecutor = Some(hostPort), + scheduledLocations = None, + runningExecutor = Some(ExecutorCacheTaskLocation(host, executorId)), name = Some(name), endpoint = Some(receiverEndpoint)) receiverTrackingInfos.put(streamId, receiverTrackingInfo) @@ -278,7 +299,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false ReceiverTrackingInfo( streamId, ReceiverState.INACTIVE, None, None, None, None, Some(errorInfo)) } - receiverTrackingInfos -= streamId + receiverTrackingInfos(streamId) = newReceiverTrackingInfo listenerBus.post(StreamingListenerReceiverStopped(newReceiverTrackingInfo.toReceiverInfo)) val messageWithError = if (error != null && !error.isEmpty) { s"$message - $error" @@ -325,25 +346,25 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false logWarning(s"Error reported by receiver for stream $streamId: $messageWithError") } - private def scheduleReceiver(receiverId: Int): Seq[String] = { + private def scheduleReceiver(receiverId: Int): Seq[TaskLocation] = { val preferredLocation = receiverPreferredLocations.getOrElse(receiverId, None) - val scheduledExecutors = schedulingPolicy.rescheduleReceiver( + val scheduledLocations = schedulingPolicy.rescheduleReceiver( receiverId, preferredLocation, receiverTrackingInfos, getExecutors) - updateReceiverScheduledExecutors(receiverId, scheduledExecutors) - scheduledExecutors + updateReceiverScheduledExecutors(receiverId, scheduledLocations) + scheduledLocations } private def updateReceiverScheduledExecutors( - receiverId: Int, scheduledExecutors: Seq[String]): Unit = { + receiverId: Int, scheduledLocations: Seq[TaskLocation]): Unit = { val newReceiverTrackingInfo = receiverTrackingInfos.get(receiverId) match { case Some(oldInfo) => oldInfo.copy(state = ReceiverState.SCHEDULED, - scheduledExecutors = Some(scheduledExecutors)) + scheduledLocations = Some(scheduledLocations)) case None => ReceiverTrackingInfo( receiverId, ReceiverState.SCHEDULED, - Some(scheduledExecutors), + Some(scheduledLocations), runningExecutor = None) } receiverTrackingInfos.put(receiverId, newReceiverTrackingInfo) @@ -357,13 +378,16 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false /** * Get the list of executors excluding driver */ - private def getExecutors: Seq[String] = { + private def getExecutors: Seq[ExecutorCacheTaskLocation] = { if (ssc.sc.isLocal) { - Seq(ssc.sparkContext.env.blockManager.blockManagerId.hostPort) + val blockManagerId = ssc.sparkContext.env.blockManager.blockManagerId + Seq(ExecutorCacheTaskLocation(blockManagerId.host, blockManagerId.executorId)) } else { ssc.sparkContext.env.blockManager.master.getMemoryStatus.filter { case (blockManagerId, _) => blockManagerId.executorId != SparkContext.DRIVER_IDENTIFIER // Ignore the driver location - }.map { case (blockManagerId, _) => blockManagerId.hostPort }.toSeq + }.map { case (blockManagerId, _) => + ExecutorCacheTaskLocation(blockManagerId.host, blockManagerId.executorId) + }.toSeq } } @@ -413,26 
+437,44 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false // TODO Remove this thread pool after https://github.com/apache/spark/issues/7385 is merged private val submitJobThreadPool = ExecutionContext.fromExecutorService( - ThreadUtils.newDaemonCachedThreadPool("submit-job-thead-pool")) + ThreadUtils.newDaemonCachedThreadPool("submit-job-thread-pool")) + + private val walBatchingThreadPool = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("wal-batching-thread-pool")) + + @volatile private var active: Boolean = true override def receive: PartialFunction[Any, Unit] = { // Local messages case StartAllReceivers(receivers) => - val scheduledExecutors = schedulingPolicy.scheduleReceivers(receivers, getExecutors) + val scheduledLocations = schedulingPolicy.scheduleReceivers(receivers, getExecutors) for (receiver <- receivers) { - val executors = scheduledExecutors(receiver.streamId) + val executors = scheduledLocations(receiver.streamId) updateReceiverScheduledExecutors(receiver.streamId, executors) receiverPreferredLocations(receiver.streamId) = receiver.preferredLocation startReceiver(receiver, executors) } case RestartReceiver(receiver) => - val scheduledExecutors = schedulingPolicy.rescheduleReceiver( - receiver.streamId, - receiver.preferredLocation, - receiverTrackingInfos, - getExecutors) - updateReceiverScheduledExecutors(receiver.streamId, scheduledExecutors) - startReceiver(receiver, scheduledExecutors) + // Old scheduled executors minus the ones that are not active any more + val oldScheduledExecutors = getStoredScheduledExecutors(receiver.streamId) + val scheduledLocations = if (oldScheduledExecutors.nonEmpty) { + // Try global scheduling again + oldScheduledExecutors + } else { + val oldReceiverInfo = receiverTrackingInfos(receiver.streamId) + // Clear "scheduledLocations" to indicate we are going to do local scheduling + val newReceiverInfo = oldReceiverInfo.copy( + state = ReceiverState.INACTIVE, scheduledLocations = None) + receiverTrackingInfos(receiver.streamId) = newReceiverInfo + schedulingPolicy.rescheduleReceiver( + receiver.streamId, + receiver.preferredLocation, + receiverTrackingInfos, + getExecutors) + } + // Assume there is one receiver restarting at one time, so we don't need to update + // receiverTrackingInfos + startReceiver(receiver, scheduledLocations) case c: CleanupOldBlocks => receiverTrackingInfos.values.flatMap(_.endpoint).foreach(_.send(c)) case UpdateReceiverRateLimit(streamUID, newRate) => @@ -446,30 +488,70 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { // Remote messages - case RegisterReceiver(streamId, typ, hostPort, receiverEndpoint) => + case RegisterReceiver(streamId, typ, host, executorId, receiverEndpoint) => val successful = - registerReceiver(streamId, typ, hostPort, receiverEndpoint, context.sender.address) + registerReceiver(streamId, typ, host, executorId, receiverEndpoint, context.senderAddress) context.reply(successful) case AddBlock(receivedBlockInfo) => - context.reply(addBlock(receivedBlockInfo)) + if (WriteAheadLogUtils.isBatchingEnabled(ssc.conf, isDriver = true)) { + walBatchingThreadPool.execute(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + if (active) { + context.reply(addBlock(receivedBlockInfo)) + } else { + throw new IllegalStateException("ReceiverTracker RpcEndpoint shut down.") + } + } + }) + } else { + 
context.reply(addBlock(receivedBlockInfo)) + } case DeregisterReceiver(streamId, message, error) => deregisterReceiver(streamId, message, error) context.reply(true) // Local messages case AllReceiverIds => - context.reply(receiverTrackingInfos.keys.toSeq) + context.reply(receiverTrackingInfos.filter(_._2.state != ReceiverState.INACTIVE).keys.toSeq) case StopAllReceivers => assert(isTrackerStopping || isTrackerStopped) stopReceivers() context.reply(true) } + /** + * Return the stored scheduled executors that are still alive. + */ + private def getStoredScheduledExecutors(receiverId: Int): Seq[TaskLocation] = { + if (receiverTrackingInfos.contains(receiverId)) { + val scheduledLocations = receiverTrackingInfos(receiverId).scheduledLocations + if (scheduledLocations.nonEmpty) { + val executors = getExecutors.toSet + // Only return the alive executors + scheduledLocations.get.filter { + case loc: ExecutorCacheTaskLocation => executors(loc) + case loc: TaskLocation => true + } + } else { + Nil + } + } else { + Nil + } + } + /** * Start a receiver along with its scheduled executors */ - private def startReceiver(receiver: Receiver[_], scheduledExecutors: Seq[String]): Unit = { + private def startReceiver( + receiver: Receiver[_], + scheduledLocations: Seq[TaskLocation]): Unit = { + def shouldStartReceiver: Boolean = { + // It's okay to start when trackerState is Initialized or Started + !(isTrackerStopping || isTrackerStopped) + } + val receiverId = receiver.streamId - if (!isTrackerStarted) { + if (!shouldStartReceiver) { onReceiverJobFinish(receiverId) return } @@ -479,29 +561,49 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false new SerializableConfiguration(ssc.sparkContext.hadoopConfiguration) // Function to start the receiver on the worker node - val startReceiverFunc = new StartReceiverFunc(checkpointDirOption, serializableHadoopConf) + val startReceiverFunc: Iterator[Receiver[_]] => Unit = + (iterator: Iterator[Receiver[_]]) => { + if (!iterator.hasNext) { + throw new SparkException( + "Could not start receiver as object not found.") + } + if (TaskContext.get().attemptNumber() == 0) { + val receiver = iterator.next() + assert(iterator.hasNext == false) + val supervisor = new ReceiverSupervisorImpl( + receiver, SparkEnv.get, serializableHadoopConf.value, checkpointDirOption) + supervisor.start() + supervisor.awaitTermination() + } else { + // It's restarted by TaskScheduler, but we want to reschedule it again. So exit it. 
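The receiver itself is launched through `SparkContext.submitJob` on a one-element RDD whose location preferences come from the scheduling policy, and the `attemptNumber() == 0` guard makes Spark's own task retries bail out so the tracker can reschedule deliberately. A stripped-down sketch of that pattern with a hypothetical payload:

```scala
import org.apache.spark.{SparkContext, TaskContext}

// Run `payload` once on an executor chosen from `preferredHosts` (if any), letting the
// caller, not the TaskScheduler, decide what to do when the task has to be retried.
def runPinned(sc: SparkContext, payload: String, preferredHosts: Seq[String]): Unit = {
  val rdd =
    if (preferredHosts.isEmpty) sc.makeRDD(Seq(payload), 1)
    else sc.makeRDD(Seq(payload -> preferredHosts)) // one partition with location preferences

  sc.submitJob[String, Unit, Unit](
    rdd,
    (iter: Iterator[String]) => {
      if (TaskContext.get().attemptNumber() == 0) {
        iter.foreach(p => println(s"processing $p")) // long-running work goes here
      }
      // On a retried attempt: do nothing and let the submitter reschedule.
    },
    Seq(0),
    (_, _) => (),
    ())
}
```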
+ } + } - // Create the RDD using the scheduledExecutors to run the receiver in a Spark job + // Create the RDD using the scheduledLocations to run the receiver in a Spark job val receiverRDD: RDD[Receiver[_]] = - if (scheduledExecutors.isEmpty) { + if (scheduledLocations.isEmpty) { ssc.sc.makeRDD(Seq(receiver), 1) } else { - ssc.sc.makeRDD(Seq(receiver -> scheduledExecutors)) + val preferredLocations = scheduledLocations.map(_.toString).distinct + ssc.sc.makeRDD(Seq(receiver -> preferredLocations)) } receiverRDD.setName(s"Receiver $receiverId") + ssc.sparkContext.setJobDescription(s"Streaming job running receiver $receiverId") + ssc.sparkContext.setCallSite(Option(ssc.getStartSite()).getOrElse(Utils.getCallSite())) + val future = ssc.sparkContext.submitJob[Receiver[_], Unit, Unit]( receiverRDD, startReceiverFunc, Seq(0), (_, _) => Unit, ()) // We will keep restarting the receiver job until ReceiverTracker is stopped future.onComplete { case Success(_) => - if (!isTrackerStarted) { + if (!shouldStartReceiver) { onReceiverJobFinish(receiverId) } else { logInfo(s"Restarting Receiver $receiverId") self.send(RestartReceiver(receiver)) } case Failure(e) => - if (!isTrackerStarted) { + if (!shouldStartReceiver) { onReceiverJobFinish(receiverId) } else { logError("Receiver has been stopped. Try to restart it.", e) @@ -514,6 +616,8 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false override def onStop(): Unit = { submitJobThreadPool.shutdownNow() + active = false + walBatchingThreadPool.shutdown() } /** @@ -536,31 +640,3 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false } } - -/** - * Function to start the receiver on the worker node. Use a class instead of closure to avoid - * the serialization issue. - */ -private class StartReceiverFunc( - checkpointDirOption: Option[String], - serializableHadoopConf: SerializableConfiguration) - extends (Iterator[Receiver[_]] => Unit) with Serializable { - - override def apply(iterator: Iterator[Receiver[_]]): Unit = { - if (!iterator.hasNext) { - throw new SparkException( - "Could not start receiver as object not found.") - } - if (TaskContext.get().attemptNumber() == 0) { - val receiver = iterator.next() - assert(iterator.hasNext == false) - val supervisor = new ReceiverSupervisorImpl( - receiver, SparkEnv.get, serializableHadoopConf.value, checkpointDirOption) - supervisor.start() - supervisor.awaitTermination() - } else { - // It's restarted by TaskScheduler, but we want to reschedule it again. So exit it. 
- } - } - -} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala index 043ff4d0ff05..4dc5bb9c3bfb 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala @@ -18,6 +18,7 @@ package org.apache.spark.streaming.scheduler import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.scheduler.{ExecutorCacheTaskLocation, TaskLocation} import org.apache.spark.streaming.scheduler.ReceiverState._ private[streaming] case class ReceiverErrorInfo( @@ -28,7 +29,7 @@ private[streaming] case class ReceiverErrorInfo( * * @param receiverId the unique receiver id * @param state the current Receiver state - * @param scheduledExecutors the scheduled executors provided by ReceiverSchedulingPolicy + * @param scheduledLocations the scheduled locations provided by ReceiverSchedulingPolicy * @param runningExecutor the running executor if the receiver is active * @param name the receiver name * @param endpoint the receiver endpoint. It can be used to send messages to the receiver @@ -37,8 +38,8 @@ private[streaming] case class ReceiverErrorInfo( private[streaming] case class ReceiverTrackingInfo( receiverId: Int, state: ReceiverState, - scheduledExecutors: Option[Seq[String]], - runningExecutor: Option[String], + scheduledLocations: Option[Seq[TaskLocation]], + runningExecutor: Option[ExecutorCacheTaskLocation], name: Option[String] = None, endpoint: Option[RpcEndpointRef] = None, errorInfo: Option[ReceiverErrorInfo] = None) { @@ -47,7 +48,8 @@ private[streaming] case class ReceiverTrackingInfo( receiverId, name.getOrElse(""), state == ReceiverState.ACTIVE, - location = runningExecutor.getOrElse(""), + location = runningExecutor.map(_.host).getOrElse(""), + executorId = runningExecutor.map(_.executorId).getOrElse(""), lastErrorMessage = errorInfo.map(_.lastErrorMessage).getOrElse(""), lastError = errorInfo.map(_.lastError).getOrElse(""), lastErrorTime = errorInfo.map(_.lastErrorTime).getOrElse(-1L) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala index 74dbba453f02..d19bdbb443c5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala @@ -38,6 +38,14 @@ case class StreamingListenerBatchCompleted(batchInfo: BatchInfo) extends Streami @DeveloperApi case class StreamingListenerBatchStarted(batchInfo: BatchInfo) extends StreamingListenerEvent +@DeveloperApi +case class StreamingListenerOutputOperationStarted(outputOperationInfo: OutputOperationInfo) + extends StreamingListenerEvent + +@DeveloperApi +case class StreamingListenerOutputOperationCompleted(outputOperationInfo: OutputOperationInfo) + extends StreamingListenerEvent + @DeveloperApi case class StreamingListenerReceiverStarted(receiverInfo: ReceiverInfo) extends StreamingListenerEvent @@ -75,6 +83,14 @@ trait StreamingListener { /** Called when processing of a batch of jobs has completed. */ def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) { } + + /** Called when processing of a job of a batch has started. 
*/ + def onOutputOperationStarted( + outputOperationStarted: StreamingListenerOutputOperationStarted) { } + + /** Called when processing of a job of a batch has completed. */ + def onOutputOperationCompleted( + outputOperationCompleted: StreamingListenerOutputOperationCompleted) { } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala index b07d6cf347ca..ca111bb636ed 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala @@ -43,6 +43,10 @@ private[spark] class StreamingListenerBus listener.onBatchStarted(batchStarted) case batchCompleted: StreamingListenerBatchCompleted => listener.onBatchCompleted(batchCompleted) + case outputOperationStarted: StreamingListenerOutputOperationStarted => + listener.onOutputOperationStarted(outputOperationStarted) + case outputOperationCompleted: StreamingListenerOutputOperationCompleted => + listener.onOutputOperationCompleted(outputOperationCompleted) case _ => } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala index 6ae56a68ad88..84a3ca9d74e5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala @@ -17,6 +17,8 @@ package org.apache.spark.streaming.scheduler.rate +import org.apache.spark.Logging + /** * Implements a proportional-integral-derivative (PID) controller which acts on * the speed of ingestion of elements into Spark Streaming. A PID controller works @@ -26,7 +28,7 @@ package org.apache.spark.streaming.scheduler.rate * * @see https://en.wikipedia.org/wiki/PID_controller * - * @param batchDurationMillis the batch duration, in milliseconds + * @param batchIntervalMillis the batch duration, in milliseconds * @param proportional how much the correction should depend on the current * error. This term usually provides the bulk of correction and should be positive or zero. * A value too large would make the controller overshoot the setpoint, while a small value @@ -39,13 +41,17 @@ package org.apache.spark.streaming.scheduler.rate * of future errors, based on current rate of change. This value should be positive or 0. * This term is not used very often, as it impacts stability of the system. The default * value is 0. + * @param minRate what is the minimum rate that can be estimated. + * This must be greater than zero, so that the system always receives some data for rate + * estimation to work. 
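The PID estimator documented above is selected and tuned entirely through configuration. A sketch of the back-pressure knobs touched by this patch, set to the defaults that appear in the hunks here (the corrected `spark.streaming.backpressure.enabled` flag and the new `pid.minRate` floor); treat the values as the patch's defaults, not tuning advice:

```scala
import org.apache.spark.SparkConf

// Back-pressure settings referenced by this patch, with their default values.
val conf = new SparkConf()
  .set("spark.streaming.backpressure.enabled", "true")      // key fixed from "...enable"
  .set("spark.streaming.backpressure.pid.proportional", "1.0")
  .set("spark.streaming.backpressure.pid.integral", "0.2")
  .set("spark.streaming.backpressure.pid.derived", "0.0")
  .set("spark.streaming.backpressure.pid.minRate", "100")   // new: floor for the estimated rate
```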
*/ private[streaming] class PIDRateEstimator( batchIntervalMillis: Long, - proportional: Double = 1D, - integral: Double = .2D, - derivative: Double = 0D) - extends RateEstimator { + proportional: Double, + integral: Double, + derivative: Double, + minRate: Double + ) extends RateEstimator with Logging { private var firstRun: Boolean = true private var latestTime: Long = -1L @@ -64,16 +70,23 @@ private[streaming] class PIDRateEstimator( require( derivative >= 0, s"Derivative term $derivative in PIDRateEstimator should be >= 0.") + require( + minRate > 0, + s"Minimum rate in PIDRateEstimator should be > 0") + logInfo(s"Created PIDRateEstimator with proportional = $proportional, integral = $integral, " + + s"derivative = $derivative, min rate = $minRate") - def compute(time: Long, // in milliseconds + def compute( + time: Long, // in milliseconds numElements: Long, processingDelay: Long, // in milliseconds schedulingDelay: Long // in milliseconds ): Option[Double] = { - + logTrace(s"\ntime = $time, # records = $numElements, " + + s"processing time = $processingDelay, scheduling delay = $schedulingDelay") this.synchronized { - if (time > latestTime && processingDelay > 0 && batchIntervalMillis > 0) { + if (time > latestTime && numElements > 0 && processingDelay > 0) { // in seconds, should be close to batchDuration val delaySinceUpdate = (time - latestTime).toDouble / 1000 @@ -104,21 +117,30 @@ private[streaming] class PIDRateEstimator( val newRate = (latestRate - proportional * error - integral * historicalError - - derivative * dError).max(0.0) + derivative * dError).max(minRate) + logTrace(s""" + | latestRate = $latestRate, error = $error + | latestError = $latestError, historicalError = $historicalError + | delaySinceUpdate = $delaySinceUpdate, dError = $dError + """.stripMargin) + latestTime = time if (firstRun) { latestRate = processingRate latestError = 0D firstRun = false - + logTrace("First run, rate estimation skipped") None } else { latestRate = newRate latestError = error - + logTrace(s"New rate = $newRate") Some(newRate) } - } else None + } else { + logTrace("Rate estimation skipped") + None + } } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala index 17ccebc1ed41..d7210f64fcc3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala @@ -18,7 +18,6 @@ package org.apache.spark.streaming.scheduler.rate import org.apache.spark.SparkConf -import org.apache.spark.SparkException import org.apache.spark.streaming.Duration /** @@ -61,7 +60,8 @@ object RateEstimator { val proportional = conf.getDouble("spark.streaming.backpressure.pid.proportional", 1.0) val integral = conf.getDouble("spark.streaming.backpressure.pid.integral", 0.2) val derived = conf.getDouble("spark.streaming.backpressure.pid.derived", 0.0) - new PIDRateEstimator(batchInterval.milliseconds, proportional, integral, derived) + val minRate = conf.getDouble("spark.streaming.backpressure.pid.minRate", 100) + new PIDRateEstimator(batchInterval.milliseconds, proportional, integral, derived, minRate) case estimator => throw new IllegalArgumentException(s"Unkown rate estimator: $estimator") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala 
b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala index f702bd5bc946..d33972342731 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala @@ -17,9 +17,6 @@ package org.apache.spark.streaming.ui -import java.text.SimpleDateFormat -import java.util.Date - import scala.xml.Node import org.apache.spark.ui.{UIUtils => SparkUIUtils} @@ -36,6 +33,22 @@ private[ui] abstract class BatchTableBase(tableId: String, batchInterval: Long) {SparkUIUtils.tooltip("Time taken to process all jobs of a batch", "top")} } + /** + * Return the first failure reason if finding in the batches. + */ + protected def getFirstFailureReason(batches: Seq[BatchUIData]): Option[String] = { + batches.flatMap(_.outputOperations.flatMap(_._2.failureReason)).headOption + } + + protected def getFirstFailureTableCell(batch: BatchUIData): Seq[Node] = { + val firstFailureReason = batch.outputOperations.flatMap(_._2.failureReason).headOption + firstFailureReason.map { failureReason => + val failureReasonForUI = UIUtils.createOutputOperationFailureForUI(failureReason) + UIUtils.failureReasonCell( + failureReasonForUI, rowspan = 1, includeFirstLineInExpandDetails = false) + }.getOrElse(-) + } + protected def baseRow(batch: BatchUIData): Seq[Node] = { val batchTime = batch.batchTime.milliseconds val formattedBatchTime = UIUtils.formatBatchTime(batchTime, batchInterval) @@ -46,7 +59,8 @@ private[ui] abstract class BatchTableBase(tableId: String, batchInterval: Long) val formattedProcessingTime = processingTime.map(SparkUIUtils.formatDuration).getOrElse("-") val batchTimeId = s"batch-$batchTime" - + {formattedBatchTime} @@ -75,6 +89,19 @@ private[ui] abstract class BatchTableBase(tableId: String, batchInterval: Long) batchTable } + protected def createOutputOperationProgressBar(batch: BatchUIData): Seq[Node] = { + + { + SparkUIUtils.makeProgressBar( + started = batch.numActiveOutputOp, + completed = batch.numCompletedOutputOp, + failed = batch.numFailedOutputOp, + skipped = 0, + total = batch.outputOperations.size) + } + + } + /** * Return HTML for all rows of this table. */ @@ -86,7 +113,18 @@ private[ui] class ActiveBatchTable( waitingBatches: Seq[BatchUIData], batchInterval: Long) extends BatchTableBase("active-batches-table", batchInterval) { - override protected def columns: Seq[Node] = super.columns ++ Status + private val firstFailureReason = getFirstFailureReason(runningBatches) + + override protected def columns: Seq[Node] = super.columns ++ { + Output Ops: Succeeded/Total + Status ++ { + if (firstFailureReason.nonEmpty) { + Error + } else { + Nil + } + } + } override protected def renderRows: Seq[Node] = { // The "batchTime"s of "waitingBatches" must be greater than "runningBatches"'s, so display @@ -96,20 +134,42 @@ private[ui] class ActiveBatchTable( } private def runningBatchRow(batch: BatchUIData): Seq[Node] = { - baseRow(batch) ++ processing + baseRow(batch) ++ createOutputOperationProgressBar(batch) ++ processing ++ { + if (firstFailureReason.nonEmpty) { + getFirstFailureTableCell(batch) + } else { + Nil + } + } } private def waitingBatchRow(batch: BatchUIData): Seq[Node] = { - baseRow(batch) ++ queued + baseRow(batch) ++ createOutputOperationProgressBar(batch) ++ queued++ { + if (firstFailureReason.nonEmpty) { + // Waiting batches have not run yet, so must have no failure reasons. 
+ - + } else { + Nil + } + } } } private[ui] class CompletedBatchTable(batches: Seq[BatchUIData], batchInterval: Long) extends BatchTableBase("completed-batches-table", batchInterval) { - override protected def columns: Seq[Node] = super.columns ++ - Total Delay - {SparkUIUtils.tooltip("Total time taken to handle a batch", "top")} + private val firstFailureReason = getFirstFailureReason(batches) + + override protected def columns: Seq[Node] = super.columns ++ { + Total Delay {SparkUIUtils.tooltip("Total time taken to handle a batch", "top")} + Output Ops: Succeeded/Total ++ { + if (firstFailureReason.nonEmpty) { + Error + } else { + Nil + } + } + } override protected def renderRows: Seq[Node] = { batches.flatMap(batch => {completedBatchRow(batch)}) @@ -118,9 +178,17 @@ private[ui] class CompletedBatchTable(batches: Seq[BatchUIData], batchInterval: private def completedBatchRow(batch: BatchUIData): Seq[Node] = { val totalDelay = batch.totalDelay val formattedTotalDelay = totalDelay.map(SparkUIUtils.formatDuration).getOrElse("-") - baseRow(batch) ++ + + baseRow(batch) ++ { {formattedTotalDelay} + } ++ createOutputOperationProgressBar(batch)++ { + if (firstFailureReason.nonEmpty) { + getFirstFailureTableCell(batch) + } else { + Nil + } + } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala index 0c891662c264..bc1711930d3a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala @@ -19,16 +19,16 @@ package org.apache.spark.streaming.ui import javax.servlet.http.HttpServletRequest -import scala.xml.{NodeSeq, Node, Text, Unparsed} +import scala.xml._ import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.streaming.Time -import org.apache.spark.ui.{UIUtils => SparkUIUtils, WebUIPage} -import org.apache.spark.streaming.ui.StreamingJobProgressListener.{SparkJobId, OutputOpId} +import org.apache.spark.streaming.ui.StreamingJobProgressListener.{OutputOpId, SparkJobId} import org.apache.spark.ui.jobs.UIData.JobUIData +import org.apache.spark.ui.{UIUtils => SparkUIUtils, WebUIPage} -private case class SparkJobIdWithUIData(sparkJobId: SparkJobId, jobUIData: Option[JobUIData]) +private[ui] case class SparkJobIdWithUIData(sparkJobId: SparkJobId, jobUIData: Option[JobUIData]) private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { private val streamingListener = parent.listener @@ -38,6 +38,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { Output Op Id Description Duration + Status Job Id Duration Stages: Succeeded/Total @@ -46,27 +47,49 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { } private def generateJobRow( - outputOpId: OutputOpId, + outputOpData: OutputOperationUIData, outputOpDescription: Seq[Node], formattedOutputOpDuration: String, numSparkJobRowsInOutputOp: Int, isFirstRow: Boolean, sparkJob: SparkJobIdWithUIData): Seq[Node] = { if (sparkJob.jobUIData.isDefined) { - generateNormalJobRow(outputOpId, outputOpDescription, formattedOutputOpDuration, + generateNormalJobRow(outputOpData, outputOpDescription, formattedOutputOpDuration, numSparkJobRowsInOutputOp, isFirstRow, sparkJob.jobUIData.get) } else { - generateDroppedJobRow(outputOpId, outputOpDescription, formattedOutputOpDuration, + generateDroppedJobRow(outputOpData, outputOpDescription, formattedOutputOpDuration, 
numSparkJobRowsInOutputOp, isFirstRow, sparkJob.sparkJobId) } } + private def generateOutputOpRowWithoutSparkJobs( + outputOpData: OutputOperationUIData, + outputOpDescription: Seq[Node], + formattedOutputOpDuration: String): Seq[Node] = { + + {outputOpData.id.toString} + {outputOpDescription} + {formattedOutputOpDuration} + {outputOpStatusCell(outputOpData, rowspan = 1)} + + - + + - + + - + + - + + - + + } + /** * Generate a row for a Spark Job. Because duplicated output op infos needs to be collapsed into * one cell, we use "rowspan" for the first row of a output op. */ private def generateNormalJobRow( - outputOpId: OutputOpId, + outputOpData: OutputOperationUIData, outputOpDescription: Seq[Node], formattedOutputOpDuration: String, numSparkJobRowsInOutputOp: Int, @@ -90,11 +113,12 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { // scalastyle:off val prefixCells = if (isFirstRow) { - {outputOpId.toString} + {outputOpData.id.toString} {outputOpDescription} - {formattedOutputOpDuration} + {formattedOutputOpDuration} ++ + {outputOpStatusCell(outputOpData, numSparkJobRowsInOutputOp)} } else { Nil } @@ -125,7 +149,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { total = sparkJob.numTasks - sparkJob.numSkippedTasks) } - {failureReasonCell(lastFailureReason)} + {UIUtils.failureReasonCell(lastFailureReason)} } @@ -134,7 +158,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { * with "-" cells. */ private def generateDroppedJobRow( - outputOpId: OutputOpId, + outputOpData: OutputOperationUIData, outputOpDescription: Seq[Node], formattedOutputOpDuration: String, numSparkJobRowsInOutputOp: Int, @@ -145,9 +169,10 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { // scalastyle:off val prefixCells = if (isFirstRow) { - {outputOpId.toString} + {outputOpData.id.toString} {outputOpDescription} - {formattedOutputOpDuration} + {formattedOutputOpDuration} ++ + {outputOpStatusCell(outputOpData, numSparkJobRowsInOutputOp)} } else { Nil } @@ -156,7 +181,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { {prefixCells} - {jobId.toString} + {if (jobId >= 0) jobId.toString else "-"} - @@ -170,78 +195,54 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { } private def generateOutputOpIdRow( - outputOpId: OutputOpId, sparkJobs: Seq[SparkJobIdWithUIData]): Seq[Node] = { - // We don't count the durations of dropped jobs - val sparkJobDurations = sparkJobs.filter(_.jobUIData.nonEmpty).map(_.jobUIData.get). 
- map(sparkJob => { - sparkJob.submissionTime.map { start => - val end = sparkJob.completionTime.getOrElse(System.currentTimeMillis()) - end - start - } - }) + outputOpData: OutputOperationUIData, + sparkJobs: Seq[SparkJobIdWithUIData]): Seq[Node] = { val formattedOutputOpDuration = - if (sparkJobDurations.isEmpty || sparkJobDurations.exists(_ == None)) { - // If no job or any job does not finish, set "formattedOutputOpDuration" to "-" + if (outputOpData.duration.isEmpty) { "-" } else { - SparkUIUtils.formatDuration(sparkJobDurations.flatMap(x => x).sum) + SparkUIUtils.formatDuration(outputOpData.duration.get) } - val description = generateOutputOpDescription(sparkJobs) + val description = generateOutputOpDescription(outputOpData) - generateJobRow( - outputOpId, description, formattedOutputOpDuration, sparkJobs.size, true, sparkJobs.head) ++ - sparkJobs.tail.map { sparkJob => + if (sparkJobs.isEmpty) { + generateOutputOpRowWithoutSparkJobs(outputOpData, description, formattedOutputOpDuration) + } else { + val firstRow = generateJobRow( - outputOpId, description, formattedOutputOpDuration, sparkJobs.size, false, sparkJob) - }.flatMap(x => x) - } - - private def generateOutputOpDescription(sparkJobs: Seq[SparkJobIdWithUIData]): Seq[Node] = { - val lastStageInfo = - sparkJobs.flatMap(_.jobUIData).headOption. // Get the first JobUIData - flatMap { sparkJob => // For the first job, get the latest Stage info - if (sparkJob.stageIds.isEmpty) { - None - } else { - sparkListener.stageIdToInfo.get(sparkJob.stageIds.max) - } + outputOpData, + description, + formattedOutputOpDuration, + sparkJobs.size, + true, + sparkJobs.head) + val tailRows = + sparkJobs.tail.map { sparkJob => + generateJobRow( + outputOpData, + description, + formattedOutputOpDuration, + sparkJobs.size, + false, + sparkJob) } - val lastStageData = lastStageInfo.flatMap { s => - sparkListener.stageIdToData.get((s.stageId, s.attemptId)) + (firstRow ++ tailRows).flatten } - - val lastStageName = lastStageInfo.map(_.name).getOrElse("(Unknown Stage Name)") - val lastStageDescription = lastStageData.flatMap(_.description).getOrElse("") - - - {lastStageDescription} - ++ Text(lastStageName) } - private def failureReasonCell(failureReason: String): Seq[Node] = { - val isMultiline = failureReason.indexOf('\n') >= 0 - // Display the first line by default - val failureReasonSummary = StringEscapeUtils.escapeHtml4( - if (isMultiline) { - failureReason.substring(0, failureReason.indexOf('\n')) - } else { - failureReason - }) - val details = if (isMultiline) { - // scalastyle:off - - +details - ++ - - // scalastyle:on - } else { - "" - } - {failureReasonSummary}{details} + private def generateOutputOpDescription(outputOp: OutputOperationUIData): Seq[Node] = { +
        + {outputOp.name} + + +details + + +
        } private def getJobData(sparkJobId: SparkJobId): Option[JobUIData] = { @@ -252,20 +253,37 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { } } + private def generateOutputOperationStatusForUI(failure: String): String = { + if (failure.startsWith("org.apache.spark.SparkException")) { + "Failed due to Spark job error\n" + failure + } else { + var nextLineIndex = failure.indexOf("\n") + if (nextLineIndex < 0) { + nextLineIndex = failure.size + } + val firstLine = failure.substring(0, nextLineIndex) + s"Failed due to error: $firstLine\n$failure" + } + } + /** * Generate the job table for the batch. */ private def generateJobTable(batchUIData: BatchUIData): Seq[Node] = { - val outputOpIdToSparkJobIds = batchUIData.outputOpIdSparkJobIdPairs.groupBy(_.outputOpId).toSeq. - sortBy(_._1). // sorted by OutputOpId + val outputOpIdToSparkJobIds = batchUIData.outputOpIdSparkJobIdPairs.groupBy(_.outputOpId). map { case (outputOpId, outputOpIdAndSparkJobIds) => // sort SparkJobIds for each OutputOpId (outputOpId, outputOpIdAndSparkJobIds.map(_.sparkJobId).sorted) } + + val outputOps: Seq[(OutputOperationUIData, Seq[SparkJobId])] = + batchUIData.outputOperations.map { case (outputOpId, outputOperation) => + val sparkJobIds = outputOpIdToSparkJobIds.getOrElse(outputOpId, Seq.empty) + (outputOperation, sparkJobIds) + }.toSeq.sortBy(_._1.id) sparkListener.synchronized { - val outputOpIdWithJobs: Seq[(OutputOpId, Seq[SparkJobIdWithUIData])] = - outputOpIdToSparkJobIds.map { case (outputOpId, sparkJobIds) => - (outputOpId, + val outputOpWithJobs = outputOps.map { case (outputOpData, sparkJobIds) => + (outputOpData, sparkJobIds.map(sparkJobId => SparkJobIdWithUIData(sparkJobId, getJobData(sparkJobId)))) } @@ -275,8 +293,8 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { { - outputOpIdWithJobs.map { - case (outputOpId, sparkJobIds) => generateOutputOpIdRow(outputOpId, sparkJobIds) + outputOpWithJobs.map { case (outputOpData, sparkJobIds) => + generateOutputOpIdRow(outputOpData, sparkJobIds) } } @@ -284,7 +302,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { } } - def render(request: HttpServletRequest): Seq[Node] = { + def render(request: HttpServletRequest): Seq[Node] = streamingListener.synchronized { val batchTime = Option(request.getParameter("id")).map(id => Time(id.toLong)).getOrElse { throw new IllegalArgumentException(s"Missing id parameter") } @@ -337,20 +355,13 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { - val jobTable = - if (batchUIData.outputOpIdSparkJobIdPairs.isEmpty) { -
        Cannot find any job for Batch {formattedBatchTime}.
        - } else { - generateJobTable(batchUIData) - } - - val content = summary ++ jobTable + val content = summary ++ generateJobTable(batchUIData) SparkUIUtils.headerSparkPage(s"Details of batch at $formattedBatchTime", content, parent) } def generateInputMetadataTable(inputMetadatas: Seq[(Int, String)]): Seq[Node] = { - +
        @@ -377,4 +388,19 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { Unparsed(StringEscapeUtils.escapeHtml4(metadataDescription). replaceAllLiterally("\t", "    ").replaceAllLiterally("\n", "
        ")) } + + private def outputOpStatusCell(outputOp: OutputOperationUIData, rowspan: Int): Seq[Node] = { + outputOp.failureReason match { + case Some(failureReason) => + val failureReasonForUI = UIUtils.createOutputOperationFailureForUI(failureReason) + UIUtils.failureReasonCell( + failureReasonForUI, rowspan, includeFirstLineInExpandDetails = false) + case None => + if (outputOp.endTime.isEmpty) { + + } else { + + } + } + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala index ae508c0e9577..3ef3689de1c4 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala @@ -18,8 +18,10 @@ package org.apache.spark.streaming.ui +import scala.collection.mutable + import org.apache.spark.streaming.Time -import org.apache.spark.streaming.scheduler.{BatchInfo, StreamInputInfo} +import org.apache.spark.streaming.scheduler.{BatchInfo, OutputOperationInfo, StreamInputInfo} import org.apache.spark.streaming.ui.StreamingJobProgressListener._ private[ui] case class OutputOpIdAndSparkJobId(outputOpId: OutputOpId, sparkJobId: SparkJobId) @@ -30,6 +32,7 @@ private[ui] case class BatchUIData( val submissionTime: Long, val processingStartTime: Option[Long], val processingEndTime: Option[Long], + val outputOperations: mutable.HashMap[OutputOpId, OutputOperationUIData] = mutable.HashMap(), var outputOpIdSparkJobIdPairs: Seq[OutputOpIdAndSparkJobId] = Seq.empty) { /** @@ -59,17 +62,75 @@ private[ui] case class BatchUIData( * The number of recorders received by the receivers in this batch. */ def numRecords: Long = streamIdToInputInfo.values.map(_.numRecords).sum + + /** + * Update an output operation information of this batch. + */ + def updateOutputOperationInfo(outputOperationInfo: OutputOperationInfo): Unit = { + assert(batchTime == outputOperationInfo.batchTime) + outputOperations(outputOperationInfo.id) = OutputOperationUIData(outputOperationInfo) + } + + /** + * Return the number of failed output operations. + */ + def numFailedOutputOp: Int = outputOperations.values.count(_.failureReason.nonEmpty) + + /** + * Return the number of running output operations. + */ + def numActiveOutputOp: Int = outputOperations.values.count(_.endTime.isEmpty) + + /** + * Return the number of completed output operations. 
+ */ + def numCompletedOutputOp: Int = outputOperations.values.count { + op => op.failureReason.isEmpty && op.endTime.nonEmpty + } + + /** + * Return if this batch has any output operations + */ + def isFailed: Boolean = numFailedOutputOp != 0 } private[ui] object BatchUIData { def apply(batchInfo: BatchInfo): BatchUIData = { + val outputOperations = mutable.HashMap[OutputOpId, OutputOperationUIData]() + outputOperations ++= batchInfo.outputOperationInfos.mapValues(OutputOperationUIData.apply) new BatchUIData( batchInfo.batchTime, batchInfo.streamIdToInputInfo, batchInfo.submissionTime, batchInfo.processingStartTime, - batchInfo.processingEndTime + batchInfo.processingEndTime, + outputOperations + ) + } +} + +private[ui] case class OutputOperationUIData( + id: OutputOpId, + name: String, + description: String, + startTime: Option[Long], + endTime: Option[Long], + failureReason: Option[String]) { + + def duration: Option[Long] = for (s <- startTime; e <- endTime) yield e - s +} + +private[ui] object OutputOperationUIData { + + def apply(outputOperationInfo: OutputOperationInfo): OutputOperationUIData = { + OutputOperationUIData( + outputOperationInfo.id, + outputOperationInfo.name, + outputOperationInfo.description, + outputOperationInfo.startTime, + outputOperationInfo.endTime, + outputOperationInfo.failureReason ) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala index b77c555c68b8..f6cc6edf2569 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala @@ -119,6 +119,20 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) } } + override def onOutputOperationStarted( + outputOperationStarted: StreamingListenerOutputOperationStarted): Unit = synchronized { + // This method is called after onBatchStarted + runningBatchUIData(outputOperationStarted.outputOperationInfo.batchTime). + updateOutputOperationInfo(outputOperationStarted.outputOperationInfo) + } + + override def onOutputOperationCompleted( + outputOperationCompleted: StreamingListenerOutputOperationCompleted): Unit = synchronized { + // This method is called before onBatchCompleted + runningBatchUIData(outputOperationCompleted.outputOperationInfo.batchTime). 
+ updateOutputOperationInfo(outputOperationCompleted.outputOperationInfo) + } + override def onJobStart(jobStart: SparkListenerJobStart): Unit = synchronized { getBatchTimeAndOutputOpId(jobStart.properties).foreach { case (batchTime, outputOpId) => var outputOpIdToSparkJobIds = batchTimeToOutputOpIdSparkJobIdPair.get(batchTime) @@ -148,6 +162,14 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) receiverInfos.size } + def numActiveReceivers: Int = synchronized { + receiverInfos.count(_._2.active) + } + + def numInactiveReceivers: Int = { + ssc.graph.getReceiverInputStreams().size - numActiveReceivers + } + def numTotalCompletedBatches: Long = synchronized { totalCompletedBatches } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index 87af902428ec..b3692c3ea302 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -303,6 +303,7 @@ private[ui] class StreamingPage(parent: StreamingTab) val numCompletedBatches = listener.retainedCompletedBatches.size val numActiveBatches = batchTimes.length - numCompletedBatches + val numReceivers = listener.numInactiveReceivers + listener.numActiveReceivers val table = // scalastyle:off
        Input-Succeeded
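The next hunk adds a "Receivers: active / total" line to this summary; the numbers come from the two listener methods added above. A REPL-style sketch of that arithmetic over invented data:

    // Invented data: receiver id -> is the receiver currently active?
    val receiverActive = Map(0 -> true, 1 -> false, 2 -> true)
    // In the real code this is ssc.graph.getReceiverInputStreams().size.
    val totalReceiverInputStreams = 3

    val numActiveReceivers   = receiverActive.count(_._2)                      // 2
    val numInactiveReceivers = totalReceiverInputStreams - numActiveReceivers  // 1
    val numReceivers         = numInactiveReceivers + numActiveReceivers       // 3
    // Rendered on the page as "Receivers: 2 / 3 active".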
        @@ -330,6 +331,11 @@ private[ui] class StreamingPage(parent: StreamingTab) } } + { + if (numReceivers > 0) { +
        Receivers: {listener.numActiveReceivers} / {numReceivers} active
        + } + }
        Avg: {eventRateForAllStreams.formattedAvg} events/sec
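The hunk that follows replaces the per-row Y-axis maximum with one maximum shared across all input streams, so the per-stream charts become directly comparable. A REPL-style sketch of that reduction over invented rate data:

    // Invented data: streamId -> Seq((batchTime, eventsPerSecond)), the shape kept by the listener.
    val receivedEventRateWithBatchTime: Map[Int, Seq[(Long, Double)]] = Map(
      0 -> Seq((1000L, 12.3), (2000L, 80.7)),
      1 -> Seq((1000L, 45.0), (2000L, 10.2)))

    val maxY: Long = receivedEventRateWithBatchTime.values
      .flatMap(_.map { case (_, rate) => rate })  // every rate of every stream
      .reduceOption[Double](math.max)             // global maximum, if there is any data
      .map(_.ceil.toLong)                         // round up so the tallest point still fits
      .getOrElse(0L)                              // no data yet -> 0

    // maxY == 81; the same value is then passed to every generateInputDStreamRow call.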
        @@ -386,8 +392,15 @@ private[ui] class StreamingPage(parent: StreamingTab) maxX: Long, minY: Double, maxY: Double): Seq[Node] = { - val content = listener.receivedEventRateWithBatchTime.map { case (streamId, eventRates) => - generateInputDStreamRow(jsCollector, streamId, eventRates, minX, maxX, minY, maxY) + val maxYCalculated = listener.receivedEventRateWithBatchTime.values + .flatMap { case streamAndRates => streamAndRates.map { case (_, eventRate) => eventRate } } + .reduceOption[Double](math.max) + .map(_.ceil.toLong) + .getOrElse(0L) + + val content = listener.receivedEventRateWithBatchTime.toList.sortBy(_._1).map { + case (streamId, eventRates) => + generateInputDStreamRow(jsCollector, streamId, eventRates, minX, maxX, minY, maxYCalculated) }.foldLeft[Seq[Node]](Nil)(_ ++ _) // scalastyle:off @@ -396,7 +409,7 @@ private[ui] class StreamingPage(parent: StreamingTab) - + @@ -424,7 +437,11 @@ private[ui] class StreamingPage(parent: StreamingTab) val receiverActive = receiverInfo.map { info => if (info.active) "ACTIVE" else "INACTIVE" }.getOrElse(emptyCell) - val receiverLocation = receiverInfo.map(_.location).getOrElse(emptyCell) + val receiverLocation = receiverInfo.map { info => + val executorId = if (info.executorId.isEmpty) emptyCell else info.executorId + val location = if (info.location.isEmpty) emptyCell else info.location + s"$executorId / $location" + }.getOrElse(emptyCell) val receiverLastError = receiverInfo.map { info => val msg = s"${info.lastErrorMessage} - ${info.lastError}" if (msg.size > 100) msg.take(97) + "..." else msg diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala index 86cfb1fa4737..d89f7ad3e16b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala @@ -17,6 +17,10 @@ package org.apache.spark.streaming.ui +import scala.xml.Node + +import org.apache.commons.lang3.StringEscapeUtils + import java.text.SimpleDateFormat import java.util.TimeZone import java.util.concurrent.TimeUnit @@ -124,4 +128,60 @@ private[streaming] object UIUtils { } } } + + def createOutputOperationFailureForUI(failure: String): String = { + if (failure.startsWith("org.apache.spark.Spark")) { + // SparkException or SparkDriverExecutionException + "Failed due to Spark job error\n" + failure + } else { + var nextLineIndex = failure.indexOf("\n") + if (nextLineIndex < 0) { + nextLineIndex = failure.size + } + val firstLine = failure.substring(0, nextLineIndex) + s"Failed due to error: $firstLine\n$failure" + } + } + + def failureReasonCell( + failureReason: String, + rowspan: Int = 1, + includeFirstLineInExpandDetails: Boolean = true): Seq[Node] = { + val isMultiline = failureReason.indexOf('\n') >= 0 + // Display the first line by default + val failureReasonSummary = StringEscapeUtils.escapeHtml4( + if (isMultiline) { + failureReason.substring(0, failureReason.indexOf('\n')) + } else { + failureReason + }) + val failureDetails = + if (isMultiline && !includeFirstLineInExpandDetails) { + // Skip the first line + failureReason.substring(failureReason.indexOf('\n') + 1) + } else { + failureReason + } + val details = if (isMultiline) { + // scalastyle:off + + +details + ++ + + // scalastyle:on + } else { + "" + } + + if (rowspan == 1) { + + } else { + + } + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala 
b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala new file mode 100644 index 000000000000..b2cd524f28b7 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.util + +import java.nio.ByteBuffer +import java.util.concurrent.LinkedBlockingQueue +import java.util.{Iterator => JIterator} + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.{Await, Promise} +import scala.concurrent.duration._ +import scala.util.control.NonFatal + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.util.Utils + +/** + * A wrapper for a WriteAheadLog that batches records before writing data. Handles aggregation + * during writes, and de-aggregation in the `readAll` method. The end consumer has to handle + * de-aggregation after the `read` method. In addition, the `WriteAheadLogRecordHandle` returned + * after the write will contain the batch of records rather than individual records. + * + * When writing a batch of records, the `time` passed to the `wrappedLog` will be the timestamp + * of the latest record in the batch. This is very important in achieving correctness. Consider the + * following example: + * We receive records with timestamps 1, 3, 5, 7. We use "log-1" as the filename. Once we receive + * a clean up request for timestamp 3, we would clean up the file "log-1", and lose data regarding + * 5 and 7. + * + * This means the caller can assume the same write semantics as any other WriteAheadLog + * implementation despite the batching in the background - when the write() returns, the data is + * written to the WAL and is durable. To take advantage of the batching, the caller can write from + * multiple threads, each of which will stay blocked until the corresponding data has been written. + * + * All other methods of the WriteAheadLog interface will be passed on to the wrapped WriteAheadLog. + */ +private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: SparkConf) + extends WriteAheadLog with Logging { + + import BatchedWriteAheadLog._ + + private val walWriteQueue = new LinkedBlockingQueue[Record]() + + // Whether the writer thread is active + @volatile private var active: Boolean = true + private val buffer = new ArrayBuffer[Record]() + + private val batchedWriterThread = startBatchedWriterThread() + + /** + * Write a byte buffer to the log file. This method adds the byteBuffer to a queue and blocks + * until the record is properly written by the parent. 
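A self-contained, REPL-style sketch of the queue-plus-promise handshake that the write() below relies on: callers enqueue their record together with a Promise and block on its future, while a single writer thread drains the queue in one batch and completes every promise with the same handle. The types here are invented stand-ins, not the real WAL classes:

    import java.util.concurrent.LinkedBlockingQueue
    import scala.collection.JavaConverters._
    import scala.collection.mutable.ArrayBuffer
    import scala.concurrent.{Await, Promise}
    import scala.concurrent.duration._

    // Invented stand-in for BatchedWriteAheadLog.Record: payload, timestamp, caller's promise.
    case class QueuedWrite(payload: Array[Byte], time: Long, promise: Promise[Long])

    val queue = new LinkedBlockingQueue[QueuedWrite]()

    // Writer side: block for at least one record, drain whatever else is queued, write the
    // batch once, then unblock every caller with the same (stand-in) handle.
    val writer = new Thread(new Runnable {
      override def run(): Unit = {
        val buffer = new ArrayBuffer[QueuedWrite]()
        buffer.append(queue.take())
        queue.drainTo(buffer.asJava)
        val handle = buffer.map(_.time).max   // stand-in for the WriteAheadLogRecordHandle
        buffer.foreach(_.promise.success(handle))
      }
    }, "sketch-batched-writer")
    writer.setDaemon(true)
    writer.start()

    // Caller side: enqueue, then block until the batch containing this record has been written.
    def write(payload: Array[Byte], time: Long): Long = {
      val promise = Promise[Long]()
      queue.offer(QueuedWrite(payload, time, promise))
      Await.result(promise.future, 5.seconds)
    }

    println(write("record-1".getBytes("UTF-8"), time = System.currentTimeMillis()))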
+ */ + override def write(byteBuffer: ByteBuffer, time: Long): WriteAheadLogRecordHandle = { + val promise = Promise[WriteAheadLogRecordHandle]() + val putSuccessfully = synchronized { + if (active) { + walWriteQueue.offer(Record(byteBuffer, time, promise)) + true + } else { + false + } + } + if (putSuccessfully) { + Await.result(promise.future, WriteAheadLogUtils.getBatchingTimeout(conf).milliseconds) + } else { + throw new IllegalStateException("close() was called on BatchedWriteAheadLog before " + + s"write request with time $time could be fulfilled.") + } + } + + /** + * This method is not supported as the resulting ByteBuffer would actually require de-aggregation. + * This method is primarily used in testing, and to ensure that it is not used in production, + * we throw an UnsupportedOperationException. + */ + override def read(segment: WriteAheadLogRecordHandle): ByteBuffer = { + throw new UnsupportedOperationException("read() is not supported for BatchedWriteAheadLog " + + "as the data may require de-aggregation.") + } + + /** + * Read all the existing logs from the log directory. The output of the wrapped WriteAheadLog + * will be de-aggregated. + */ + override def readAll(): JIterator[ByteBuffer] = { + wrappedLog.readAll().asScala.flatMap(deaggregate).asJava + } + + /** + * Delete the log files that are older than the threshold time. + * + * This method is handled by the parent WriteAheadLog. + */ + override def clean(threshTime: Long, waitForCompletion: Boolean): Unit = { + wrappedLog.clean(threshTime, waitForCompletion) + } + + + /** + * Stop the batched writer thread, fulfill promises with failures and close the wrapped WAL. + */ + override def close(): Unit = { + logInfo(s"BatchedWriteAheadLog shutting down at time: ${System.currentTimeMillis()}.") + synchronized { + active = false + } + batchedWriterThread.interrupt() + batchedWriterThread.join() + while (!walWriteQueue.isEmpty) { + val Record(_, time, promise) = walWriteQueue.poll() + promise.failure(new IllegalStateException("close() was called on BatchedWriteAheadLog " + + s"before write request with time $time could be fulfilled.")) + } + wrappedLog.close() + } + + /** Start the actual log writer on a separate thread. */ + private def startBatchedWriterThread(): Thread = { + val thread = new Thread(new Runnable { + override def run(): Unit = { + while (active) { + try { + flushRecords() + } catch { + case NonFatal(e) => + logWarning("Encountered exception in Batched Writer Thread.", e) + } + } + logInfo("BatchedWriteAheadLog Writer thread exiting.") + } + }, "BatchedWriteAheadLog Writer") + thread.setDaemon(true) + thread.start() + thread + } + + /** Write all the records in the buffer to the write ahead log. */ + private def flushRecords(): Unit = { + try { + buffer.append(walWriteQueue.take()) + val numBatched = walWriteQueue.drainTo(buffer.asJava) + 1 + logDebug(s"Received $numBatched records from queue") + } catch { + case _: InterruptedException => + logWarning("BatchedWriteAheadLog Writer queue interrupted.") + } + try { + var segment: WriteAheadLogRecordHandle = null + if (buffer.length > 0) { + logDebug(s"Batched ${buffer.length} records for Write Ahead Log write") + // threads may not be able to add items in order by time + val sortedByTime = buffer.sortBy(_.time) + // We take the latest record for the timestamp. 
Please refer to the class Javadoc for + // detailed explanation + val time = sortedByTime.last.time + segment = wrappedLog.write(aggregate(sortedByTime), time) + } + buffer.foreach(_.promise.success(segment)) + } catch { + case e: InterruptedException => + logWarning("BatchedWriteAheadLog Writer queue interrupted.", e) + buffer.foreach(_.promise.failure(e)) + case NonFatal(e) => + logWarning(s"BatchedWriteAheadLog Writer failed to write $buffer", e) + buffer.foreach(_.promise.failure(e)) + } finally { + buffer.clear() + } + } + + /** Method for querying the queue length. Should only be used in tests. */ + private def getQueueLength(): Int = walWriteQueue.size() +} + +/** Static methods for aggregating and de-aggregating records. */ +private[util] object BatchedWriteAheadLog { + + /** + * Wrapper class for representing the records that we will write to the WriteAheadLog. Coupled + * with the timestamp for the write request of the record, and the promise that will block the + * write request, while a separate thread is actually performing the write. + */ + case class Record(data: ByteBuffer, time: Long, promise: Promise[WriteAheadLogRecordHandle]) + + /** Aggregate multiple serialized ReceivedBlockTrackerLogEvents in a single ByteBuffer. */ + def aggregate(records: Seq[Record]): ByteBuffer = { + ByteBuffer.wrap(Utils.serialize[Array[Array[Byte]]]( + records.map(record => JavaUtils.bufferToArray(record.data)).toArray)) + } + + /** + * De-aggregate serialized ReceivedBlockTrackerLogEvents in a single ByteBuffer. + * A stream may not have used batching initially, but started using it after a restart. This + * method therefore needs to be backwards compatible. + */ + def deaggregate(buffer: ByteBuffer): Array[ByteBuffer] = { + val prevPosition = buffer.position() + try { + Utils.deserialize[Array[Array[Byte]]](JavaUtils.bufferToArray(buffer)).map(ByteBuffer.wrap) + } catch { + case _: ClassCastException => // users may restart a stream with batching enabled + // Restore `position` so that the user can read `buffer` later + buffer.position(prevPosition) + Array(buffer) + } + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index fe6328b1ce72..b946e0d8e927 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -17,9 +17,12 @@ package org.apache.spark.streaming.util import java.nio.ByteBuffer +import java.util.concurrent.{RejectedExecutionException, ThreadPoolExecutor} import java.util.{Iterator => JIterator} +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer +import scala.collection.parallel.ThreadPoolTaskSupport import scala.concurrent.{Await, ExecutionContext, Future} import scala.language.postfixOps @@ -31,9 +34,10 @@ import org.apache.spark.{Logging, SparkConf} /** * This class manages write ahead log files. - * - Writes records (bytebuffers) to periodically rotating log files. - * - Recovers the log files and the reads the recovered records upon failures. - * - Cleans up old log files. + * + * - Writes records (bytebuffers) to periodically rotating log files. + * - Recovers the log files and the reads the recovered records upon failures. + * - Cleans up old log files. 
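Before moving on to the file-based log: the aggregate/deaggregate helpers above amount to serializing an Array[Array[Byte]] and wrapping each element back into a ByteBuffer on read. A round-trip sketch using plain JDK serialization in place of Spark's internal Utils.serialize and JavaUtils.bufferToArray:

    import java.io._
    import java.nio.ByteBuffer

    // Copy a buffer's remaining bytes without disturbing its position.
    def toBytes(buf: ByteBuffer): Array[Byte] = {
      val arr = new Array[Byte](buf.remaining())
      buf.duplicate().get(arr)
      arr
    }

    // Aggregate: one ByteBuffer holding a serialized Array[Array[Byte]].
    def aggregate(records: Seq[ByteBuffer]): ByteBuffer = {
      val bos = new ByteArrayOutputStream()
      val oos = new ObjectOutputStream(bos)
      oos.writeObject(records.map(toBytes).toArray)
      oos.close()
      ByteBuffer.wrap(bos.toByteArray)
    }

    // De-aggregate: read the array back and re-wrap every record.
    def deaggregate(buffer: ByteBuffer): Array[ByteBuffer] = {
      val ois = new ObjectInputStream(new ByteArrayInputStream(toBytes(buffer)))
      try ois.readObject().asInstanceOf[Array[Array[Byte]]].map(ByteBuffer.wrap) finally ois.close()
    }

    val records = Seq(ByteBuffer.wrap("a".getBytes("UTF-8")), ByteBuffer.wrap("bb".getBytes("UTF-8")))
    val roundTripped = deaggregate(aggregate(records))
    assert(roundTripped.map(b => new String(toBytes(b), "UTF-8")).toSeq == Seq("a", "bb"))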
* * Uses [[org.apache.spark.streaming.util.FileBasedWriteAheadLogWriter]] to write * and [[org.apache.spark.streaming.util.FileBasedWriteAheadLogReader]] to read. @@ -46,7 +50,8 @@ private[streaming] class FileBasedWriteAheadLog( logDirectory: String, hadoopConf: Configuration, rollingIntervalSecs: Int, - maxFailures: Int + maxFailures: Int, + closeFileAfterWrite: Boolean ) extends WriteAheadLog with Logging { import FileBasedWriteAheadLog._ @@ -55,8 +60,8 @@ private[streaming] class FileBasedWriteAheadLog( private val callerNameTag = getCallerName.map(c => s" for $c").getOrElse("") private val threadpoolName = s"WriteAheadLogManager $callerNameTag" - implicit private val executionContext = ExecutionContext.fromExecutorService( - ThreadUtils.newDaemonSingleThreadExecutor(threadpoolName)) + private val threadpool = ThreadUtils.newDaemonCachedThreadPool(threadpoolName, 20) + private val executionContext = ExecutionContext.fromExecutorService(threadpool) override protected val logName = s"WriteAheadLogManager $callerNameTag" private var currentLogPath: Option[String] = None @@ -79,6 +84,9 @@ private[streaming] class FileBasedWriteAheadLog( while (!succeeded && failures < maxFailures) { try { fileSegment = getLogWriter(time).write(byteBuffer) + if (closeFileAfterWrite) { + resetWriter() + } succeeded = true } catch { case ex: Exception => @@ -118,15 +126,20 @@ private[streaming] class FileBasedWriteAheadLog( * hence the implementation is kept simple. */ def readAll(): JIterator[ByteBuffer] = synchronized { - import scala.collection.JavaConversions._ val logFilesToRead = pastLogs.map{ _.path} ++ currentLogPath - logInfo("Reading from the logs: " + logFilesToRead.mkString("\n")) - - logFilesToRead.iterator.map { file => + logInfo("Reading from the logs:\n" + logFilesToRead.mkString("\n")) + def readFile(file: String): Iterator[ByteBuffer] = { logDebug(s"Creating log reader with $file") val reader = new FileBasedWriteAheadLogReader(file, hadoopConf) CompletionIterator[ByteBuffer, Iterator[ByteBuffer]](reader, reader.close _) - } flatMap { x => x } + } + if (!closeFileAfterWrite) { + logFilesToRead.iterator.map(readFile).flatten.asJava + } else { + // For performance gains, it makes sense to parallelize the recovery if + // closeFileAfterWrite = true + seqToParIterator(threadpool, logFilesToRead, readFile).asJava + } } /** @@ -142,30 +155,39 @@ private[streaming] class FileBasedWriteAheadLog( * asynchronously. 
*/ def clean(threshTime: Long, waitForCompletion: Boolean): Unit = { - val oldLogFiles = synchronized { pastLogs.filter { _.endTime < threshTime } } + val oldLogFiles = synchronized { + val expiredLogs = pastLogs.filter { _.endTime < threshTime } + pastLogs --= expiredLogs + expiredLogs + } logInfo(s"Attempting to clear ${oldLogFiles.size} old log files in $logDirectory " + s"older than $threshTime: ${oldLogFiles.map { _.path }.mkString("\n")}") - def deleteFiles() { - oldLogFiles.foreach { logInfo => - try { - val path = new Path(logInfo.path) - val fs = HdfsUtils.getFileSystemForPath(path, hadoopConf) - fs.delete(path, true) - synchronized { pastLogs -= logInfo } - logDebug(s"Cleared log file $logInfo") - } catch { - case ex: Exception => - logWarning(s"Error clearing write ahead log file $logInfo", ex) - } + def deleteFile(walInfo: LogInfo): Unit = { + try { + val path = new Path(walInfo.path) + val fs = HdfsUtils.getFileSystemForPath(path, hadoopConf) + fs.delete(path, true) + logDebug(s"Cleared log file $walInfo") + } catch { + case ex: Exception => + logWarning(s"Error clearing write ahead log file $walInfo", ex) } logInfo(s"Cleared log files in $logDirectory older than $threshTime") } - if (!executionContext.isShutdown) { - val f = Future { deleteFiles() } - if (waitForCompletion) { - import scala.concurrent.duration._ - Await.ready(f, 1 second) + oldLogFiles.foreach { logInfo => + if (!executionContext.isShutdown) { + try { + val f = Future { deleteFile(logInfo) }(executionContext) + if (waitForCompletion) { + import scala.concurrent.duration._ + Await.ready(f, 1 second) + } + } catch { + case e: RejectedExecutionException => + logWarning("Execution context shutdown before deleting old WriteAheadLogs. " + + "This would not affect recovery correctness.", e) + } } } } @@ -231,7 +253,7 @@ private[streaming] object FileBasedWriteAheadLog { def getCallerName(): Option[String] = { val stackTraceClasses = Thread.currentThread.getStackTrace().map(_.getClassName) - stackTraceClasses.find(!_.contains("WriteAheadLog")).flatMap(_.split(".").lastOption) + stackTraceClasses.find(!_.contains("WriteAheadLog")).flatMap(_.split("\\.").lastOption) } /** Convert a sequence of files to a sequence of sorted LogInfo objects */ @@ -247,4 +269,23 @@ private[streaming] object FileBasedWriteAheadLog { } }.sortBy { _.startTime } } + + /** + * This creates an iterator from a parallel collection, by keeping at most `n` objects in memory + * at any given time, where `n` is the size of the thread pool. This is crucial for use cases + * where we create `FileBasedWriteAheadLogReader`s during parallel recovery. We don't want to + * open up `k` streams altogether where `k` is the size of the Seq that we want to parallelize. 
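A usage sketch for the helper defined just below (callable like this only from within the streaming package, since it is private[streaming], and assuming the pre-2.13 scala.collection.parallel this codebase builds against): each group of inputs is processed in parallel on the pool, but only one group's worth of iterators is ever open at a time.

    import java.util.concurrent.{LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit}

    // A plain executor for the sketch; the real caller uses ThreadUtils.newDaemonCachedThreadPool.
    val pool = new ThreadPoolExecutor(
      4, 4, 60L, TimeUnit.SECONDS, new LinkedBlockingQueue[Runnable]())

    // Pretend each "log file" yields a couple of records; during WAL recovery this would be
    // a FileBasedWriteAheadLogReader wrapped in a CompletionIterator.
    def readFile(file: String): Iterator[String] = Iterator(s"$file-rec1", s"$file-rec2")

    val files = (1 to 100).map(i => s"log-$i")

    // Groups of size max(maximumPoolSize, 8) = 8 are read in parallel; later files stay
    // untouched until the consumer actually pulls that far into the iterator.
    val allRecords: Iterator[String] =
      FileBasedWriteAheadLog.seqToParIterator(pool, files, readFile)

    println(allRecords.take(3).toList)
    pool.shutdown()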
+ */ + def seqToParIterator[I, O]( + tpool: ThreadPoolExecutor, + source: Seq[I], + handler: I => Iterator[O]): Iterator[O] = { + val taskSupport = new ThreadPoolTaskSupport(tpool) + val groupSize = tpool.getMaximumPoolSize.max(8) + source.grouped(groupSize).flatMap { group => + val parallelCollection = group.par + parallelCollection.tasksupport = taskSupport + parallelCollection.map(handler) + }.flatten + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogRandomReader.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogRandomReader.scala index f7168229ec15..56d4977da0b5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogRandomReader.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogRandomReader.scala @@ -30,7 +30,7 @@ private[streaming] class FileBasedWriteAheadLogRandomReader(path: String, conf: extends Closeable { private val instream = HdfsUtils.getInputStream(path, conf) - private var closed = false + private var closed = (instream == null) // the file may be deleted as we're opening the stream def read(segment: FileBasedWriteAheadLogSegment): ByteBuffer = synchronized { assertOpen() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogReader.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogReader.scala index c3bb59f3fef9..a375c0729534 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogReader.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogReader.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.streaming.util -import java.io.{Closeable, EOFException} +import java.io.{IOException, Closeable, EOFException} import java.nio.ByteBuffer import org.apache.hadoop.conf.Configuration @@ -32,7 +32,7 @@ private[streaming] class FileBasedWriteAheadLogReader(path: String, conf: Config extends Iterator[ByteBuffer] with Closeable with Logging { private val instream = HdfsUtils.getInputStream(path, conf) - private var closed = false + private var closed = (instream == null) // the file may be deleted as we're opening the stream private var nextItem: Option[ByteBuffer] = None override def hasNext: Boolean = synchronized { @@ -55,6 +55,19 @@ private[streaming] class FileBasedWriteAheadLogReader(path: String, conf: Config logDebug("Error reading next item, EOF reached", e) close() false + case e: IOException => + logWarning("Error while trying to read data. If the file was deleted, " + + "this should be okay.", e) + close() + if (HdfsUtils.checkFileExists(path, conf)) { + // If file exists, this could be a legitimate error + throw e + } else { + // File was deleted. This can occur when the daemon cleanup thread takes time to + // delete the file during recovery. 
+ false + } + case e: Exception => logWarning("Error while trying to read data from HDFS.", e) close() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala index e146bec32a45..1185f30265f6 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala @@ -24,6 +24,8 @@ import scala.util.Try import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.FSDataOutputStream +import org.apache.spark.util.Utils + /** * A writer for writing byte-buffers to a write ahead log file. */ @@ -48,17 +50,7 @@ private[streaming] class FileBasedWriteAheadLogWriter(path: String, hadoopConf: val lengthToWrite = data.remaining() val segment = new FileBasedWriteAheadLogSegment(path, nextOffset, lengthToWrite) stream.writeInt(lengthToWrite) - if (data.hasArray) { - stream.write(data.array()) - } else { - // If the buffer is not backed by an array, we transfer using temp array - // Note that despite the extra array copy, this should be faster than byte-by-byte copy - while (data.hasRemaining) { - val array = new Array[Byte](data.remaining) - data.get(array) - stream.write(array) - } - } + Utils.writeByteBuffer(data, stream: OutputStream) flush() nextOffset = stream.getPos() segment diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala index f60688f173c4..13a765d035ee 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.streaming.util +import java.io.IOException + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ @@ -42,8 +44,19 @@ private[streaming] object HdfsUtils { def getInputStream(path: String, conf: Configuration): FSDataInputStream = { val dfsPath = new Path(path) val dfs = getFileSystemForPath(dfsPath, conf) - val instream = dfs.open(dfsPath) - instream + if (dfs.isFile(dfsPath)) { + try { + dfs.open(dfsPath) + } catch { + case e: IOException => + // If we are really unlucky, the file may be deleted as we're opening the stream. + // This can happen as clean up is performed by daemon threads that may be left over from + // previous runs. + if (!dfs.isFile(dfsPath)) null else throw e + } + } else { + null + } } def checkState(state: Boolean, errorMsg: => String) { @@ -71,4 +84,11 @@ private[streaming] object HdfsUtils { case _ => fs } } + + /** Check if the file exists at the given path. */ + def checkFileExists(path: String, conf: Configuration): Boolean = { + val hdpPath = new Path(path) + val fs = getFileSystemForPath(hdpPath, conf) + fs.isFile(hdpPath) + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala index dd32ad5ad811..bfb53614050a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala @@ -72,8 +72,10 @@ class RecurringTimer(clock: Clock, period: Long, callback: (Long) => Unit, name: /** * Stop the timer, and return the last time the callback was made. 
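The refactoring below pulls the per-interval work into triggerActionForNextInterval() so the loop can issue exactly one more callback after stop() when the timer is not interrupted. A minimal generic sketch of that shape, not the real Clock/RecurringTimer API:

    // Simplified stand-in: a wall-clock timer instead of Spark's Clock abstraction.
    class SimpleRecurringTimer(period: Long, callback: Long => Unit) {
      @volatile private var stopped = false
      private var nextTime = (System.currentTimeMillis() / period + 1) * period

      private def triggerActionForNextInterval(): Unit = {
        val sleep = nextTime - System.currentTimeMillis()
        if (sleep > 0) Thread.sleep(sleep)   // stand-in for clock.waitTillTime(nextTime)
        callback(nextTime)
        nextTime += period
      }

      private val thread = new Thread(new Runnable {
        override def run(): Unit = {
          while (!stopped) triggerActionForNextInterval()
          triggerActionForNextInterval()     // the "at least one callback after stop()" guarantee
        }
      }, "sketch-recurring-timer")

      def start(): Unit = thread.start()

      def stop(): Unit = { stopped = true; thread.join() }
    }

With new SimpleRecurringTimer(1000, t => println(t)), calling stop() between ticks still lets one final tick fire, which is useful for callers that rely on a final flush before shutdown.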
- * interruptTimer = true will interrupt the callback - * if it is in progress (not guaranteed to give correct time in this case). + * + * @param interruptTimer True will interrupt the callback if it is in progress (not guaranteed to + * give correct time in this case). False guarantees that there will be at + * least one callback after `stop` has been called. */ def stop(interruptTimer: Boolean): Long = synchronized { if (!stopped) { @@ -87,18 +89,23 @@ class RecurringTimer(clock: Clock, period: Long, callback: (Long) => Unit, name: prevTime } + private def triggerActionForNextInterval(): Unit = { + clock.waitTillTime(nextTime) + callback(nextTime) + prevTime = nextTime + nextTime += period + logDebug("Callback for " + name + " called at time " + prevTime) + } + /** * Repeatedly call the callback every interval. */ private def loop() { try { while (!stopped) { - clock.waitTillTime(nextTime) - callback(nextTime) - prevTime = nextTime - nextTime += period - logDebug("Callback for " + name + " called at time " + prevTime) + triggerActionForNextInterval() } + triggerActionForNextInterval() } catch { case e: InterruptedException => } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala new file mode 100644 index 000000000000..3f139ad138c8 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala @@ -0,0 +1,346 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.util + +import java.io.{ObjectInputStream, ObjectOutputStream} + +import scala.reflect.ClassTag + +import org.apache.spark.SparkConf +import org.apache.spark.streaming.util.OpenHashMapBasedStateMap._ +import org.apache.spark.util.collection.OpenHashMap + +/** Internal interface for defining the map that keeps track of sessions. */ +private[streaming] abstract class StateMap[K: ClassTag, S: ClassTag] extends Serializable { + + /** Get the state for a key if it exists */ + def get(key: K): Option[S] + + /** Get all the keys and states whose updated time is older than the given threshold time */ + def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)] + + /** Get all the keys and states in this map. */ + def getAll(): Iterator[(K, S, Long)] + + /** Add or update state */ + def put(key: K, state: S, updatedTime: Long): Unit + + /** Remove a key */ + def remove(key: K): Unit + + /** + * Shallow copy `this` map to create a new state map. + * Updates to the new map should not mutate `this` map. 
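A REPL-style sketch of how this interface is used (the classes are private[streaming], and the companion's StateMap.create appears a little further down): updates land in the copy's delta while the parent stays untouched, and once the delta chain grows past spark.streaming.sessionByKey.deltaChainThreshold (20 by default) the next serialization compacts it.

    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.util.StateMap

    val parent = StateMap.create[String, Int](new SparkConf())
    parent.put("a", 1, 1000L)
    parent.put("b", 2, 2000L)

    // Shallow copy: subsequent updates go into the copy's delta map, not into `parent`.
    val child = parent.copy()
    child.put("c", 3, 3000L)
    child.remove("a")

    assert(parent.get("a") == Some(1))                           // parent unaffected
    assert(child.get("a").isEmpty)                               // tombstoned in the delta
    assert(child.getAll().map(_._1).toSet == Set("b", "c"))
    assert(child.getByTime(2500L).map(_._1).toSet == Set("b"))   // only states updated before 2500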
+ */ + def copy(): StateMap[K, S] + + def toDebugString(): String = toString() +} + +/** Companion object for [[StateMap]], with utility methods */ +private[streaming] object StateMap { + def empty[K: ClassTag, S: ClassTag]: StateMap[K, S] = new EmptyStateMap[K, S] + + def create[K: ClassTag, S: ClassTag](conf: SparkConf): StateMap[K, S] = { + val deltaChainThreshold = conf.getInt("spark.streaming.sessionByKey.deltaChainThreshold", + DELTA_CHAIN_LENGTH_THRESHOLD) + new OpenHashMapBasedStateMap[K, S](deltaChainThreshold) + } +} + +/** Implementation of StateMap interface representing an empty map */ +private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMap[K, S] { + override def put(key: K, session: S, updateTime: Long): Unit = { + throw new NotImplementedError("put() should not be called on an EmptyStateMap") + } + override def get(key: K): Option[S] = None + override def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)] = Iterator.empty + override def getAll(): Iterator[(K, S, Long)] = Iterator.empty + override def copy(): StateMap[K, S] = this + override def remove(key: K): Unit = { } + override def toDebugString(): String = "" +} + +/** Implementation of StateMap based on Spark's [[org.apache.spark.util.collection.OpenHashMap]] */ +private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( + @transient @volatile var parentStateMap: StateMap[K, S], + initialCapacity: Int = DEFAULT_INITIAL_CAPACITY, + deltaChainThreshold: Int = DELTA_CHAIN_LENGTH_THRESHOLD + ) extends StateMap[K, S] { self => + + def this(initialCapacity: Int, deltaChainThreshold: Int) = this( + new EmptyStateMap[K, S], + initialCapacity = initialCapacity, + deltaChainThreshold = deltaChainThreshold) + + def this(deltaChainThreshold: Int) = this( + initialCapacity = DEFAULT_INITIAL_CAPACITY, deltaChainThreshold = deltaChainThreshold) + + def this() = this(DELTA_CHAIN_LENGTH_THRESHOLD) + + require(initialCapacity >= 1, "Invalid initial capacity") + require(deltaChainThreshold >= 1, "Invalid delta chain threshold") + + @transient @volatile private var deltaMap = new OpenHashMap[K, StateInfo[S]](initialCapacity) + + /** Get the session data if it exists */ + override def get(key: K): Option[S] = { + val stateInfo = deltaMap(key) + if (stateInfo != null) { + if (!stateInfo.deleted) { + Some(stateInfo.data) + } else { + None + } + } else { + parentStateMap.get(key) + } + } + + /** Get all the keys and states whose updated time is older than the give threshold time */ + override def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)] = { + val oldStates = parentStateMap.getByTime(threshUpdatedTime).filter { case (key, value, _) => + !deltaMap.contains(key) + } + + val updatedStates = deltaMap.iterator.filter { case (_, stateInfo) => + !stateInfo.deleted && stateInfo.updateTime < threshUpdatedTime + }.map { case (key, stateInfo) => + (key, stateInfo.data, stateInfo.updateTime) + } + oldStates ++ updatedStates + } + + /** Get all the keys and states in this map. */ + override def getAll(): Iterator[(K, S, Long)] = { + + val oldStates = parentStateMap.getAll().filter { case (key, _, _) => + !deltaMap.contains(key) + } + + val updatedStates = deltaMap.iterator.filter { ! 
_._2.deleted }.map { case (key, stateInfo) => + (key, stateInfo.data, stateInfo.updateTime) + } + oldStates ++ updatedStates + } + + /** Add or update state */ + override def put(key: K, state: S, updateTime: Long): Unit = { + val stateInfo = deltaMap(key) + if (stateInfo != null) { + stateInfo.update(state, updateTime) + } else { + deltaMap.update(key, new StateInfo(state, updateTime)) + } + } + + /** Remove a state */ + override def remove(key: K): Unit = { + val stateInfo = deltaMap(key) + if (stateInfo != null) { + stateInfo.markDeleted() + } else { + val newInfo = new StateInfo[S](deleted = true) + deltaMap.update(key, newInfo) + } + } + + /** + * Shallow copy the map to create a new session store. Updates to the new map + * should not mutate `this` map. + */ + override def copy(): StateMap[K, S] = { + new OpenHashMapBasedStateMap[K, S](this, deltaChainThreshold = deltaChainThreshold) + } + + /** Whether the delta chain lenght is long enough that it should be compacted */ + def shouldCompact: Boolean = { + deltaChainLength >= deltaChainThreshold + } + + /** Length of the delta chains of this map */ + def deltaChainLength: Int = parentStateMap match { + case map: OpenHashMapBasedStateMap[_, _] => map.deltaChainLength + 1 + case _ => 0 + } + + /** + * Approximate number of keys in the map. This is an overestimation that is mainly used to + * reserve capacity in a new map at delta compaction time. + */ + def approxSize: Int = deltaMap.size + { + parentStateMap match { + case s: OpenHashMapBasedStateMap[_, _] => s.approxSize + case _ => 0 + } + } + + /** Get all the data of this map as string formatted as a tree based on the delta depth */ + override def toDebugString(): String = { + val tabs = if (deltaChainLength > 0) { + (" " * (deltaChainLength - 1)) + "+--- " + } else "" + parentStateMap.toDebugString() + "\n" + deltaMap.iterator.mkString(tabs, "\n" + tabs, "") + } + + override def toString(): String = { + s"[${System.identityHashCode(this)}, ${System.identityHashCode(parentStateMap)}]" + } + + /** + * Serialize the map data. Besides serialization, this method actually compact the deltas + * (if needed) in a single pass over all the data in the map. + */ + + private def writeObject(outputStream: ObjectOutputStream): Unit = { + // Write all the non-transient fields, especially class tags, etc. + outputStream.defaultWriteObject() + + // Write the data in the delta of this state map + outputStream.writeInt(deltaMap.size) + val deltaMapIterator = deltaMap.iterator + var deltaMapCount = 0 + while (deltaMapIterator.hasNext) { + deltaMapCount += 1 + val (key, stateInfo) = deltaMapIterator.next() + outputStream.writeObject(key) + outputStream.writeObject(stateInfo) + } + assert(deltaMapCount == deltaMap.size) + + // Write the data in the parent state map while copying the data into a new parent map for + // compaction (if needed) + val doCompaction = shouldCompact + val newParentSessionStore = if (doCompaction) { + val initCapacity = if (approxSize > 0) approxSize else 64 + new OpenHashMapBasedStateMap[K, S](initialCapacity = initCapacity, deltaChainThreshold) + } else { null } + + val iterOfActiveSessions = parentStateMap.getAll() + + var parentSessionCount = 0 + + // First write the approximate size of the data to be written, so that readObject can + // allocate appropriately sized OpenHashMap. 
+    outputStream.writeInt(approxSize)
+
+    while (iterOfActiveSessions.hasNext) {
+      parentSessionCount += 1
+
+      val (key, state, updateTime) = iterOfActiveSessions.next()
+      outputStream.writeObject(key)
+      outputStream.writeObject(state)
+      outputStream.writeLong(updateTime)
+
+      if (doCompaction) {
+        newParentSessionStore.deltaMap.update(
+          key, StateInfo(state, updateTime, deleted = false))
+      }
+    }
+
+    // Write the final limit marking object with the correct count of records written.
+    val limiterObj = new LimitMarker(parentSessionCount)
+    outputStream.writeObject(limiterObj)
+    if (doCompaction) {
+      parentStateMap = newParentSessionStore
+    }
+  }
+
+  /** Deserialize the map data. */
+  private def readObject(inputStream: ObjectInputStream): Unit = {
+
+    // Read the non-transient fields, especially class tags, etc.
+    inputStream.defaultReadObject()
+
+    // Read the data of the delta
+    val deltaMapSize = inputStream.readInt()
+    deltaMap = if (deltaMapSize != 0) {
+      new OpenHashMap[K, StateInfo[S]](deltaMapSize)
+    } else {
+      new OpenHashMap[K, StateInfo[S]](initialCapacity)
+    }
+    var deltaMapCount = 0
+    while (deltaMapCount < deltaMapSize) {
+      val key = inputStream.readObject().asInstanceOf[K]
+      val sessionInfo = inputStream.readObject().asInstanceOf[StateInfo[S]]
+      deltaMap.update(key, sessionInfo)
+      deltaMapCount += 1
+    }
+
+    // Read the data of the parent map. Keep reading records until the limiter is reached.
+    // First read the approximate number of records to expect and allocate a properly sized
+    // OpenHashMap
+    val parentStateMapSizeHint = inputStream.readInt()
+    val newStateMapInitialCapacity = math.max(parentStateMapSizeHint, DEFAULT_INITIAL_CAPACITY)
+    val newParentSessionStore = new OpenHashMapBasedStateMap[K, S](
+      initialCapacity = newStateMapInitialCapacity, deltaChainThreshold)
+
+    // Read the records until the limit marking object has been reached
+    var parentSessionLoopDone = false
+    while (!parentSessionLoopDone) {
+      val obj = inputStream.readObject()
+      if (obj.isInstanceOf[LimitMarker]) {
+        parentSessionLoopDone = true
+        val expectedCount = obj.asInstanceOf[LimitMarker].num
+        assert(expectedCount == newParentSessionStore.deltaMap.size)
+      } else {
+        val key = obj.asInstanceOf[K]
+        val state = inputStream.readObject().asInstanceOf[S]
+        val updateTime = inputStream.readLong()
+        newParentSessionStore.deltaMap.update(
+          key, StateInfo(state, updateTime, deleted = false))
+      }
+    }
+    parentStateMap = newParentSessionStore
+  }
+}
+
+/**
+ * Companion object of [[OpenHashMapBasedStateMap]] with associated helper
+ * classes and methods
+ */
+private[streaming] object OpenHashMapBasedStateMap {
+
+  /** Internal class to represent the state information */
+  case class StateInfo[S](
+      var data: S = null.asInstanceOf[S],
+      var updateTime: Long = -1,
+      var deleted: Boolean = false) {
+
+    def markDeleted(): Unit = {
+      deleted = true
+    }
+
+    def update(newData: S, newUpdateTime: Long): Unit = {
+      data = newData
+      updateTime = newUpdateTime
+      deleted = false
+    }
+  }
+
+  /**
+   * Internal class to represent a marker that demarcates the end of all state data in the
+   * serialized bytes.
+ */ + class LimitMarker(val num: Int) extends Serializable + + val DELTA_CHAIN_LENGTH_THRESHOLD = 20 + + val DEFAULT_INITIAL_CAPACITY = 64 +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala index 7f6ff12c58d4..7f9e2c973497 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala @@ -31,11 +31,17 @@ private[streaming] object WriteAheadLogUtils extends Logging { val RECEIVER_WAL_ROLLING_INTERVAL_CONF_KEY = "spark.streaming.receiver.writeAheadLog.rollingIntervalSecs" val RECEIVER_WAL_MAX_FAILURES_CONF_KEY = "spark.streaming.receiver.writeAheadLog.maxFailures" + val RECEIVER_WAL_CLOSE_AFTER_WRITE_CONF_KEY = + "spark.streaming.receiver.writeAheadLog.closeFileAfterWrite" val DRIVER_WAL_CLASS_CONF_KEY = "spark.streaming.driver.writeAheadLog.class" val DRIVER_WAL_ROLLING_INTERVAL_CONF_KEY = "spark.streaming.driver.writeAheadLog.rollingIntervalSecs" val DRIVER_WAL_MAX_FAILURES_CONF_KEY = "spark.streaming.driver.writeAheadLog.maxFailures" + val DRIVER_WAL_BATCHING_CONF_KEY = "spark.streaming.driver.writeAheadLog.allowBatching" + val DRIVER_WAL_BATCHING_TIMEOUT_CONF_KEY = "spark.streaming.driver.writeAheadLog.batchingTimeout" + val DRIVER_WAL_CLOSE_AFTER_WRITE_CONF_KEY = + "spark.streaming.driver.writeAheadLog.closeFileAfterWrite" val DEFAULT_ROLLING_INTERVAL_SECS = 60 val DEFAULT_MAX_FAILURES = 3 @@ -60,6 +66,26 @@ private[streaming] object WriteAheadLogUtils extends Logging { } } + def isBatchingEnabled(conf: SparkConf, isDriver: Boolean): Boolean = { + isDriver && conf.getBoolean(DRIVER_WAL_BATCHING_CONF_KEY, defaultValue = true) + } + + /** + * How long we will wait for the wrappedLog in the BatchedWriteAheadLog to write the records + * before we fail the write attempt to unblock receivers. + */ + def getBatchingTimeout(conf: SparkConf): Long = { + conf.getLong(DRIVER_WAL_BATCHING_TIMEOUT_CONF_KEY, defaultValue = 5000) + } + + def shouldCloseFileAfterWrite(conf: SparkConf, isDriver: Boolean): Boolean = { + if (isDriver) { + conf.getBoolean(DRIVER_WAL_CLOSE_AFTER_WRITE_CONF_KEY, defaultValue = false) + } else { + conf.getBoolean(RECEIVER_WAL_CLOSE_AFTER_WRITE_CONF_KEY, defaultValue = false) + } + } + /** * Create a WriteAheadLog for the driver. If configured with custom WAL class, it will try * to create instance of that class, otherwise it will create the default FileBasedWriteAheadLog. 
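The write-ahead-log options above are read straight from SparkConf, so a job enables them by setting plain configuration keys rather than calling any new API. Below is a minimal, illustrative sketch of a driver setting them before creating its StreamingContext; the values are examples only, and the pre-existing spark.streaming.receiver.writeAheadLog.enable flag (not part of this patch) is assumed to be what turns the receiver WAL on in the first place. The defaults added here are allowBatching = true, a batching timeout of 5000 ms, and closeFileAfterWrite = false for both driver and receiver.

  import org.apache.spark.SparkConf

  // Illustrative values; the key names are the constants defined in this patch.
  val conf = new SparkConf()
    .setAppName("streaming-wal-example")
    .set("spark.streaming.receiver.writeAheadLog.enable", "true")
    .set("spark.streaming.driver.writeAheadLog.allowBatching", "true")
    // Give batched driver-side WAL writes 10 s before failing them (default 5000 ms).
    .set("spark.streaming.driver.writeAheadLog.batchingTimeout", "10000")
    // Close the log file after every write instead of keeping it open (default false).
    .set("spark.streaming.driver.writeAheadLog.closeFileAfterWrite", "true")
    .set("spark.streaming.receiver.writeAheadLog.closeFileAfterWrite", "true")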
@@ -103,7 +129,7 @@ private[streaming] object WriteAheadLogUtils extends Logging { } else { sparkConf.getOption(RECEIVER_WAL_CLASS_CONF_KEY) } - classNameOption.map { className => + val wal = classNameOption.map { className => try { instantiateClass( Utils.classForName(className).asInstanceOf[Class[_ <: WriteAheadLog]], sparkConf) @@ -113,7 +139,13 @@ private[streaming] object WriteAheadLogUtils extends Logging { } }.getOrElse { new FileBasedWriteAheadLog(sparkConf, fileWalLogDirectory, fileWalHadoopConf, - getRollingIntervalSecs(sparkConf, isDriver), getMaxFailures(sparkConf, isDriver)) + getRollingIntervalSecs(sparkConf, isDriver), getMaxFailures(sparkConf, isDriver), + shouldCloseFileAfterWrite(sparkConf, isDriver)) + } + if (isBatchingEnabled(sparkConf, isDriver)) { + new BatchedWriteAheadLog(wal, sparkConf) + } else { + wal } } diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index e0718f73aa13..9722c60bba1c 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -18,28 +18,28 @@ package org.apache.spark.streaming; import java.io.*; -import java.lang.Iterable; import java.nio.charset.Charset; import java.util.*; import java.util.concurrent.atomic.AtomicBoolean; +import scala.Tuple2; + +import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.mapreduce.lib.input.TextInputFormat; -import scala.Tuple2; - import org.junit.Assert; -import static org.junit.Assert.*; import org.junit.Test; import com.google.common.base.Optional; -import com.google.common.collect.Lists; import com.google.common.io.Files; import com.google.common.collect.Sets; +import org.apache.spark.Accumulator; import org.apache.spark.HashPartitioner; +import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -47,21 +47,20 @@ import org.apache.spark.storage.StorageLevel; import org.apache.spark.streaming.api.java.*; import org.apache.spark.util.Utils; -import org.apache.spark.SparkConf; // The test suite itself is Serializable so that anonymous Function implementations can be // serialized, as an alternative to converting these anonymous classes to static inner classes; // see http://stackoverflow.com/questions/758570/. 
public class JavaAPISuite extends LocalJavaStreamingContext implements Serializable { - public void equalIterator(Iterator a, Iterator b) { + public static void equalIterator(Iterator a, Iterator b) { while (a.hasNext() && b.hasNext()) { Assert.assertEquals(a.next(), b.next()); } Assert.assertEquals(a.hasNext(), b.hasNext()); } - public void equalIterable(Iterable a, Iterable b) { + public static void equalIterable(Iterable a, Iterable b) { equalIterator(a.iterator(), b.iterator()); } @@ -74,14 +73,14 @@ public void testInitialization() { @Test public void testContextState() { List> inputData = Arrays.asList(Arrays.asList(1, 2, 3, 4)); - Assert.assertTrue(ssc.getState() == StreamingContextState.INITIALIZED); + Assert.assertEquals(StreamingContextState.INITIALIZED, ssc.getState()); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaTestUtils.attachTestOutputStream(stream); - Assert.assertTrue(ssc.getState() == StreamingContextState.INITIALIZED); + Assert.assertEquals(StreamingContextState.INITIALIZED, ssc.getState()); ssc.start(); - Assert.assertTrue(ssc.getState() == StreamingContextState.ACTIVE); + Assert.assertEquals(StreamingContextState.ACTIVE, ssc.getState()); ssc.stop(); - Assert.assertTrue(ssc.getState() == StreamingContextState.STOPPED); + Assert.assertEquals(StreamingContextState.STOPPED, ssc.getState()); } @SuppressWarnings("unchecked") @@ -118,7 +117,7 @@ public void testMap() { JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream letterCount = stream.map(new Function() { @Override - public Integer call(String s) throws Exception { + public Integer call(String s) { return s.length(); } }); @@ -180,7 +179,7 @@ public void testWindowWithSlideDuration() { public void testFilter() { List> inputData = Arrays.asList( Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); + Arrays.asList("yankees", "red sox")); List> expected = Arrays.asList( Arrays.asList("giants"), @@ -189,7 +188,7 @@ public void testFilter() { JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream filtered = stream.filter(new Function() { @Override - public Boolean call(String s) throws Exception { + public Boolean call(String s) { return s.contains("a"); } }); @@ -243,11 +242,11 @@ public void testRepartitionFewerPartitions() { public void testGlom() { List> inputData = Arrays.asList( Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); + Arrays.asList("yankees", "red sox")); List>> expected = Arrays.asList( Arrays.asList(Arrays.asList("giants", "dodgers")), - Arrays.asList(Arrays.asList("yankees", "red socks"))); + Arrays.asList(Arrays.asList("yankees", "red sox"))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream> glommed = stream.glom(); @@ -262,22 +261,22 @@ public void testGlom() { public void testMapPartitions() { List> inputData = Arrays.asList( Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); + Arrays.asList("yankees", "red sox")); List> expected = Arrays.asList( Arrays.asList("GIANTSDODGERS"), - Arrays.asList("YANKEESRED SOCKS")); + Arrays.asList("YANKEESRED SOX")); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream mapped = stream.mapPartitions( new FlatMapFunction, String>() { @Override public Iterable call(Iterator in) { - String out = ""; + StringBuilder out = new StringBuilder(); while (in.hasNext()) { - out = out + in.next().toUpperCase(); + 
out.append(in.next().toUpperCase(Locale.ENGLISH)); } - return Lists.newArrayList(out); + return Arrays.asList(out.toString()); } }); JavaTestUtils.attachTestOutputStream(mapped); @@ -286,16 +285,16 @@ public Iterable call(Iterator in) { Assert.assertEquals(expected, result); } - private class IntegerSum implements Function2 { + private static class IntegerSum implements Function2 { @Override - public Integer call(Integer i1, Integer i2) throws Exception { + public Integer call(Integer i1, Integer i2) { return i1 + i2; } } - private class IntegerDifference implements Function2 { + private static class IntegerDifference implements Function2 { @Override - public Integer call(Integer i1, Integer i2) throws Exception { + public Integer call(Integer i1, Integer i2) { return i1 - i2; } } @@ -347,13 +346,13 @@ private void testReduceByWindow(boolean withInverse) { Arrays.asList(24)); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream reducedWindowed = null; + JavaDStream reducedWindowed; if (withInverse) { reducedWindowed = stream.reduceByWindow(new IntegerSum(), - new IntegerDifference(), new Duration(2000), new Duration(1000)); + new IntegerDifference(), new Duration(2000), new Duration(1000)); } else { reducedWindowed = stream.reduceByWindow(new IntegerSum(), - new Duration(2000), new Duration(1000)); + new Duration(2000), new Duration(1000)); } JavaTestUtils.attachTestOutputStream(reducedWindowed); List> result = JavaTestUtils.runStreams(ssc, 4, 4); @@ -378,11 +377,11 @@ public void testQueueStream() { Arrays.asList(7,8,9)); JavaSparkContext jsc = new JavaSparkContext(ssc.ssc().sc()); - JavaRDD rdd1 = ssc.sparkContext().parallelize(Arrays.asList(1, 2, 3)); - JavaRDD rdd2 = ssc.sparkContext().parallelize(Arrays.asList(4, 5, 6)); - JavaRDD rdd3 = ssc.sparkContext().parallelize(Arrays.asList(7,8,9)); + JavaRDD rdd1 = jsc.parallelize(Arrays.asList(1, 2, 3)); + JavaRDD rdd2 = jsc.parallelize(Arrays.asList(4, 5, 6)); + JavaRDD rdd3 = jsc.parallelize(Arrays.asList(7,8,9)); - LinkedList> rdds = Lists.newLinkedList(); + Queue> rdds = new LinkedList<>(); rdds.add(rdd1); rdds.add(rdd2); rdds.add(rdd3); @@ -410,10 +409,10 @@ public void testTransform() { JavaDStream transformed = stream.transform( new Function, JavaRDD>() { @Override - public JavaRDD call(JavaRDD in) throws Exception { + public JavaRDD call(JavaRDD in) { return in.map(new Function() { @Override - public Integer call(Integer i) throws Exception { + public Integer call(Integer i) { return i + 2; } }); @@ -435,70 +434,70 @@ public void testVariousTransform() { JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); List>> pairInputData = - Arrays.asList(Arrays.asList(new Tuple2("x", 1))); + Arrays.asList(Arrays.asList(new Tuple2<>("x", 1))); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream( JavaTestUtils.attachTestInputStream(ssc, pairInputData, 1)); - JavaDStream transformed1 = stream.transform( + stream.transform( new Function, JavaRDD>() { @Override - public JavaRDD call(JavaRDD in) throws Exception { + public JavaRDD call(JavaRDD in) { return null; } } ); - JavaDStream transformed2 = stream.transform( + stream.transform( new Function2, Time, JavaRDD>() { - @Override public JavaRDD call(JavaRDD in, Time time) throws Exception { + @Override public JavaRDD call(JavaRDD in, Time time) { return null; } } ); - JavaPairDStream transformed3 = stream.transformToPair( + stream.transformToPair( new Function, JavaPairRDD>() { - @Override public JavaPairRDD call(JavaRDD in) 
throws Exception { + @Override public JavaPairRDD call(JavaRDD in) { return null; } } ); - JavaPairDStream transformed4 = stream.transformToPair( + stream.transformToPair( new Function2, Time, JavaPairRDD>() { - @Override public JavaPairRDD call(JavaRDD in, Time time) throws Exception { + @Override public JavaPairRDD call(JavaRDD in, Time time) { return null; } } ); - JavaDStream pairTransformed1 = pairStream.transform( + pairStream.transform( new Function, JavaRDD>() { - @Override public JavaRDD call(JavaPairRDD in) throws Exception { + @Override public JavaRDD call(JavaPairRDD in) { return null; } } ); - JavaDStream pairTransformed2 = pairStream.transform( + pairStream.transform( new Function2, Time, JavaRDD>() { - @Override public JavaRDD call(JavaPairRDD in, Time time) throws Exception { + @Override public JavaRDD call(JavaPairRDD in, Time time) { return null; } } ); - JavaPairDStream pairTransformed3 = pairStream.transformToPair( + pairStream.transformToPair( new Function, JavaPairRDD>() { - @Override public JavaPairRDD call(JavaPairRDD in) throws Exception { + @Override public JavaPairRDD call(JavaPairRDD in) { return null; } } ); - JavaPairDStream pairTransformed4 = pairStream.transformToPair( + pairStream.transformToPair( new Function2, Time, JavaPairRDD>() { - @Override public JavaPairRDD call(JavaPairRDD in, Time time) throws Exception { + @Override public JavaPairRDD call(JavaPairRDD in, Time time) { return null; } } @@ -511,32 +510,32 @@ public JavaRDD call(JavaRDD in) throws Exception { public void testTransformWith() { List>> stringStringKVStream1 = Arrays.asList( Arrays.asList( - new Tuple2("california", "dodgers"), - new Tuple2("new york", "yankees")), + new Tuple2<>("california", "dodgers"), + new Tuple2<>("new york", "yankees")), Arrays.asList( - new Tuple2("california", "sharks"), - new Tuple2("new york", "rangers"))); + new Tuple2<>("california", "sharks"), + new Tuple2<>("new york", "rangers"))); List>> stringStringKVStream2 = Arrays.asList( Arrays.asList( - new Tuple2("california", "giants"), - new Tuple2("new york", "mets")), + new Tuple2<>("california", "giants"), + new Tuple2<>("new york", "mets")), Arrays.asList( - new Tuple2("california", "ducks"), - new Tuple2("new york", "islanders"))); + new Tuple2<>("california", "ducks"), + new Tuple2<>("new york", "islanders"))); List>>> expected = Arrays.asList( Sets.newHashSet( - new Tuple2>("california", - new Tuple2("dodgers", "giants")), - new Tuple2>("new york", - new Tuple2("yankees", "mets"))), + new Tuple2<>("california", + new Tuple2<>("dodgers", "giants")), + new Tuple2<>("new york", + new Tuple2<>("yankees", "mets"))), Sets.newHashSet( - new Tuple2>("california", - new Tuple2("sharks", "ducks")), - new Tuple2>("new york", - new Tuple2("rangers", "islanders")))); + new Tuple2<>("california", + new Tuple2<>("sharks", "ducks")), + new Tuple2<>("new york", + new Tuple2<>("rangers", "islanders")))); JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( ssc, stringStringKVStream1, 1); @@ -552,14 +551,12 @@ public void testTransformWith() { JavaPairRDD, JavaPairRDD, Time, - JavaPairRDD> - >() { + JavaPairRDD>>() { @Override public JavaPairRDD> call( JavaPairRDD rdd1, JavaPairRDD rdd2, - Time time - ) throws Exception { + Time time) { return rdd1.join(rdd2); } } @@ -567,9 +564,9 @@ public JavaPairRDD> call( JavaTestUtils.attachTestOutputStream(joined); List>>> result = JavaTestUtils.runStreams(ssc, 2, 2); - List>>> unorderedResult = Lists.newArrayList(); + List>>> unorderedResult = new ArrayList<>(); for (List>> 
res: result) { - unorderedResult.add(Sets.newHashSet(res)); + unorderedResult.add(Sets.newHashSet(res)); } Assert.assertEquals(expected, unorderedResult); @@ -587,89 +584,89 @@ public void testVariousTransformWith() { JavaDStream stream2 = JavaTestUtils.attachTestInputStream(ssc, inputData2, 1); List>> pairInputData1 = - Arrays.asList(Arrays.asList(new Tuple2("x", 1))); + Arrays.asList(Arrays.asList(new Tuple2<>("x", 1))); List>> pairInputData2 = - Arrays.asList(Arrays.asList(new Tuple2(1.0, 'x'))); + Arrays.asList(Arrays.asList(new Tuple2<>(1.0, 'x'))); JavaPairDStream pairStream1 = JavaPairDStream.fromJavaDStream( JavaTestUtils.attachTestInputStream(ssc, pairInputData1, 1)); JavaPairDStream pairStream2 = JavaPairDStream.fromJavaDStream( JavaTestUtils.attachTestInputStream(ssc, pairInputData2, 1)); - JavaDStream transformed1 = stream1.transformWith( + stream1.transformWith( stream2, new Function3, JavaRDD, Time, JavaRDD>() { @Override - public JavaRDD call(JavaRDD rdd1, JavaRDD rdd2, Time time) throws Exception { + public JavaRDD call(JavaRDD rdd1, JavaRDD rdd2, Time time) { return null; } } ); - JavaDStream transformed2 = stream1.transformWith( + stream1.transformWith( pairStream1, new Function3, JavaPairRDD, Time, JavaRDD>() { @Override - public JavaRDD call(JavaRDD rdd1, JavaPairRDD rdd2, Time time) throws Exception { + public JavaRDD call(JavaRDD rdd1, JavaPairRDD rdd2, Time time) { return null; } } ); - JavaPairDStream transformed3 = stream1.transformWithToPair( + stream1.transformWithToPair( stream2, new Function3, JavaRDD, Time, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaRDD rdd1, JavaRDD rdd2, Time time) throws Exception { + public JavaPairRDD call(JavaRDD rdd1, JavaRDD rdd2, Time time) { return null; } } ); - JavaPairDStream transformed4 = stream1.transformWithToPair( + stream1.transformWithToPair( pairStream1, new Function3, JavaPairRDD, Time, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaRDD rdd1, JavaPairRDD rdd2, Time time) throws Exception { + public JavaPairRDD call(JavaRDD rdd1, JavaPairRDD rdd2, Time time) { return null; } } ); - JavaDStream pairTransformed1 = pairStream1.transformWith( + pairStream1.transformWith( stream2, new Function3, JavaRDD, Time, JavaRDD>() { @Override - public JavaRDD call(JavaPairRDD rdd1, JavaRDD rdd2, Time time) throws Exception { + public JavaRDD call(JavaPairRDD rdd1, JavaRDD rdd2, Time time) { return null; } } ); - JavaDStream pairTransformed2_ = pairStream1.transformWith( + pairStream1.transformWith( pairStream1, new Function3, JavaPairRDD, Time, JavaRDD>() { @Override - public JavaRDD call(JavaPairRDD rdd1, JavaPairRDD rdd2, Time time) throws Exception { + public JavaRDD call(JavaPairRDD rdd1, JavaPairRDD rdd2, Time time) { return null; } } ); - JavaPairDStream pairTransformed3 = pairStream1.transformWithToPair( + pairStream1.transformWithToPair( stream2, new Function3, JavaRDD, Time, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaPairRDD rdd1, JavaRDD rdd2, Time time) throws Exception { + public JavaPairRDD call(JavaPairRDD rdd1, JavaRDD rdd2, Time time) { return null; } } ); - JavaPairDStream pairTransformed4 = pairStream1.transformWithToPair( + pairStream1.transformWithToPair( pairStream2, new Function3, JavaPairRDD, Time, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaPairRDD rdd1, JavaPairRDD rdd2, Time time) throws Exception { + public JavaPairRDD call(JavaPairRDD rdd1, JavaPairRDD rdd2, Time time) { return null; } } @@ -690,13 +687,13 @@ public void testStreamingContextTransform(){ ); 
List>> pairStream1input = Arrays.asList( - Arrays.asList(new Tuple2(1, "x")), - Arrays.asList(new Tuple2(2, "y")) + Arrays.asList(new Tuple2<>(1, "x")), + Arrays.asList(new Tuple2<>(2, "y")) ); List>>> expected = Arrays.asList( - Arrays.asList(new Tuple2>(1, new Tuple2(1, "x"))), - Arrays.asList(new Tuple2>(2, new Tuple2(2, "y"))) + Arrays.asList(new Tuple2<>(1, new Tuple2<>(1, "x"))), + Arrays.asList(new Tuple2<>(2, new Tuple2<>(2, "y"))) ); JavaDStream stream1 = JavaTestUtils.attachTestInputStream(ssc, stream1input, 1); @@ -707,7 +704,7 @@ public void testStreamingContextTransform(){ List> listOfDStreams1 = Arrays.>asList(stream1, stream2); // This is just to test whether this transform to JavaStream compiles - JavaDStream transformed1 = ssc.transform( + ssc.transform( listOfDStreams1, new Function2>, Time, JavaRDD>() { @Override @@ -733,8 +730,8 @@ public JavaPairRDD> call(List> listO JavaPairRDD prdd3 = JavaPairRDD.fromJavaRDD(rdd3); PairFunction mapToTuple = new PairFunction() { @Override - public Tuple2 call(Integer i) throws Exception { - return new Tuple2(i, i); + public Tuple2 call(Integer i) { + return new Tuple2<>(i, i); } }; return rdd1.union(rdd2).mapToPair(mapToTuple).join(prdd3); @@ -763,7 +760,7 @@ public void testFlatMap() { JavaDStream flatMapped = stream.flatMap(new FlatMapFunction() { @Override public Iterable call(String x) { - return Lists.newArrayList(x.split("(?!^)")); + return Arrays.asList(x.split("(?!^)")); } }); JavaTestUtils.attachTestOutputStream(flatMapped); @@ -772,6 +769,44 @@ public Iterable call(String x) { assertOrderInvariantEquals(expected, result); } + @SuppressWarnings("unchecked") + @Test + public void testForeachRDD() { + final Accumulator accumRdd = ssc.sc().accumulator(0); + final Accumulator accumEle = ssc.sc().accumulator(0); + List> inputData = Arrays.asList( + Arrays.asList(1,1,1), + Arrays.asList(1,1,1)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaTestUtils.attachTestOutputStream(stream.count()); // dummy output + + stream.foreachRDD(new VoidFunction>() { + @Override + public void call(JavaRDD rdd) { + accumRdd.add(1); + rdd.foreach(new VoidFunction() { + @Override + public void call(Integer i) { + accumEle.add(1); + } + }); + } + }); + + // This is a test to make sure foreachRDD(VoidFunction2) can be called from Java + stream.foreachRDD(new VoidFunction2, Time>() { + @Override + public void call(JavaRDD rdd, Time time) { + } + }); + + JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(2, accumRdd.value().intValue()); + Assert.assertEquals(6, accumEle.value().intValue()); + } + @SuppressWarnings("unchecked") @Test public void testPairFlatMap() { @@ -782,39 +817,39 @@ public void testPairFlatMap() { List>> expected = Arrays.asList( Arrays.asList( - new Tuple2(6, "g"), - new Tuple2(6, "i"), - new Tuple2(6, "a"), - new Tuple2(6, "n"), - new Tuple2(6, "t"), - new Tuple2(6, "s")), + new Tuple2<>(6, "g"), + new Tuple2<>(6, "i"), + new Tuple2<>(6, "a"), + new Tuple2<>(6, "n"), + new Tuple2<>(6, "t"), + new Tuple2<>(6, "s")), Arrays.asList( - new Tuple2(7, "d"), - new Tuple2(7, "o"), - new Tuple2(7, "d"), - new Tuple2(7, "g"), - new Tuple2(7, "e"), - new Tuple2(7, "r"), - new Tuple2(7, "s")), + new Tuple2<>(7, "d"), + new Tuple2<>(7, "o"), + new Tuple2<>(7, "d"), + new Tuple2<>(7, "g"), + new Tuple2<>(7, "e"), + new Tuple2<>(7, "r"), + new Tuple2<>(7, "s")), Arrays.asList( - new Tuple2(9, "a"), - new Tuple2(9, "t"), - new Tuple2(9, "h"), - new Tuple2(9, "l"), - new Tuple2(9, "e"), - new Tuple2(9, 
"t"), - new Tuple2(9, "i"), - new Tuple2(9, "c"), - new Tuple2(9, "s"))); + new Tuple2<>(9, "a"), + new Tuple2<>(9, "t"), + new Tuple2<>(9, "h"), + new Tuple2<>(9, "l"), + new Tuple2<>(9, "e"), + new Tuple2<>(9, "t"), + new Tuple2<>(9, "i"), + new Tuple2<>(9, "c"), + new Tuple2<>(9, "s"))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream flatMapped = stream.flatMapToPair( new PairFlatMapFunction() { @Override - public Iterable> call(String in) throws Exception { - List> out = Lists.newArrayList(); + public Iterable> call(String in) { + List> out = new ArrayList<>(); for (String letter: in.split("(?!^)")) { - out.add(new Tuple2(in.length(), letter)); + out.add(new Tuple2<>(in.length(), letter)); } return out; } @@ -859,13 +894,13 @@ public void testUnion() { */ public static void assertOrderInvariantEquals( List> expected, List> actual) { - List> expectedSets = new ArrayList>(); + List> expectedSets = new ArrayList<>(); for (List list: expected) { - expectedSets.add(Collections.unmodifiableSet(new HashSet(list))); + expectedSets.add(Collections.unmodifiableSet(new HashSet<>(list))); } - List> actualSets = new ArrayList>(); + List> actualSets = new ArrayList<>(); for (List list: actual) { - actualSets.add(Collections.unmodifiableSet(new HashSet(list))); + actualSets.add(Collections.unmodifiableSet(new HashSet<>(list))); } Assert.assertEquals(expectedSets, actualSets); } @@ -877,25 +912,25 @@ public static void assertOrderInvariantEquals( public void testPairFilter() { List> inputData = Arrays.asList( Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); + Arrays.asList("yankees", "red sox")); List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("giants", 6)), - Arrays.asList(new Tuple2("yankees", 7))); + Arrays.asList(new Tuple2<>("giants", 6)), + Arrays.asList(new Tuple2<>("yankees", 7))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = stream.mapToPair( new PairFunction() { @Override - public Tuple2 call(String in) throws Exception { - return new Tuple2(in, in.length()); + public Tuple2 call(String in) { + return new Tuple2<>(in, in.length()); } }); JavaPairDStream filtered = pairStream.filter( new Function, Boolean>() { @Override - public Boolean call(Tuple2 in) throws Exception { + public Boolean call(Tuple2 in) { return in._1().contains("a"); } }); @@ -906,28 +941,28 @@ public Boolean call(Tuple2 in) throws Exception { } @SuppressWarnings("unchecked") - private List>> stringStringKVStream = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers"), - new Tuple2("california", "giants"), - new Tuple2("new york", "yankees"), - new Tuple2("new york", "mets")), - Arrays.asList(new Tuple2("california", "sharks"), - new Tuple2("california", "ducks"), - new Tuple2("new york", "rangers"), - new Tuple2("new york", "islanders"))); + private final List>> stringStringKVStream = Arrays.asList( + Arrays.asList(new Tuple2<>("california", "dodgers"), + new Tuple2<>("california", "giants"), + new Tuple2<>("new york", "yankees"), + new Tuple2<>("new york", "mets")), + Arrays.asList(new Tuple2<>("california", "sharks"), + new Tuple2<>("california", "ducks"), + new Tuple2<>("new york", "rangers"), + new Tuple2<>("new york", "islanders"))); @SuppressWarnings("unchecked") - private List>> stringIntKVStream = Arrays.asList( + private final List>> stringIntKVStream = Arrays.asList( Arrays.asList( - new Tuple2("california", 1), - new Tuple2("california", 3), - new 
Tuple2("new york", 4), - new Tuple2("new york", 1)), + new Tuple2<>("california", 1), + new Tuple2<>("california", 3), + new Tuple2<>("new york", 4), + new Tuple2<>("new york", 1)), Arrays.asList( - new Tuple2("california", 5), - new Tuple2("california", 5), - new Tuple2("new york", 3), - new Tuple2("new york", 1))); + new Tuple2<>("california", 5), + new Tuple2<>("california", 5), + new Tuple2<>("new york", 3), + new Tuple2<>("new york", 1))); @SuppressWarnings("unchecked") @Test @@ -936,22 +971,22 @@ public void testPairMap() { // Maps pair -> pair of different type List>> expected = Arrays.asList( Arrays.asList( - new Tuple2(1, "california"), - new Tuple2(3, "california"), - new Tuple2(4, "new york"), - new Tuple2(1, "new york")), + new Tuple2<>(1, "california"), + new Tuple2<>(3, "california"), + new Tuple2<>(4, "new york"), + new Tuple2<>(1, "new york")), Arrays.asList( - new Tuple2(5, "california"), - new Tuple2(5, "california"), - new Tuple2(3, "new york"), - new Tuple2(1, "new york"))); + new Tuple2<>(5, "california"), + new Tuple2<>(5, "california"), + new Tuple2<>(3, "new york"), + new Tuple2<>(1, "new york"))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream reversed = pairStream.mapToPair( new PairFunction, Integer, String>() { @Override - public Tuple2 call(Tuple2 in) throws Exception { + public Tuple2 call(Tuple2 in) { return in.swap(); } }); @@ -969,23 +1004,23 @@ public void testPairMapPartitions() { // Maps pair -> pair of different type List>> expected = Arrays.asList( Arrays.asList( - new Tuple2(1, "california"), - new Tuple2(3, "california"), - new Tuple2(4, "new york"), - new Tuple2(1, "new york")), + new Tuple2<>(1, "california"), + new Tuple2<>(3, "california"), + new Tuple2<>(4, "new york"), + new Tuple2<>(1, "new york")), Arrays.asList( - new Tuple2(5, "california"), - new Tuple2(5, "california"), - new Tuple2(3, "new york"), - new Tuple2(1, "new york"))); + new Tuple2<>(5, "california"), + new Tuple2<>(5, "california"), + new Tuple2<>(3, "new york"), + new Tuple2<>(1, "new york"))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream reversed = pairStream.mapPartitionsToPair( new PairFlatMapFunction>, Integer, String>() { @Override - public Iterable> call(Iterator> in) throws Exception { - LinkedList> out = new LinkedList>(); + public Iterable> call(Iterator> in) { + List> out = new LinkedList<>(); while (in.hasNext()) { Tuple2 next = in.next(); out.add(next.swap()); @@ -1014,7 +1049,7 @@ public void testPairMap2() { // Maps pair -> single JavaDStream reversed = pairStream.map( new Function, Integer>() { @Override - public Integer call(Tuple2 in) throws Exception { + public Integer call(Tuple2 in) { return in._2(); } }); @@ -1030,23 +1065,23 @@ public Integer call(Tuple2 in) throws Exception { public void testPairToPairFlatMapWithChangingTypes() { // Maps pair -> pair List>> inputData = Arrays.asList( Arrays.asList( - new Tuple2("hi", 1), - new Tuple2("ho", 2)), + new Tuple2<>("hi", 1), + new Tuple2<>("ho", 2)), Arrays.asList( - new Tuple2("hi", 1), - new Tuple2("ho", 2))); + new Tuple2<>("hi", 1), + new Tuple2<>("ho", 2))); List>> expected = Arrays.asList( Arrays.asList( - new Tuple2(1, "h"), - new Tuple2(1, "i"), - new Tuple2(2, "h"), - new Tuple2(2, "o")), + new Tuple2<>(1, "h"), + new Tuple2<>(1, "i"), + new Tuple2<>(2, "h"), + new 
Tuple2<>(2, "o")), Arrays.asList( - new Tuple2(1, "h"), - new Tuple2(1, "i"), - new Tuple2(2, "h"), - new Tuple2(2, "o"))); + new Tuple2<>(1, "h"), + new Tuple2<>(1, "i"), + new Tuple2<>(2, "h"), + new Tuple2<>(2, "o"))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); @@ -1054,10 +1089,10 @@ public void testPairToPairFlatMapWithChangingTypes() { // Maps pair -> pair JavaPairDStream flatMapped = pairStream.flatMapToPair( new PairFlatMapFunction, Integer, String>() { @Override - public Iterable> call(Tuple2 in) throws Exception { - List> out = new LinkedList>(); + public Iterable> call(Tuple2 in) { + List> out = new LinkedList<>(); for (Character s : in._1().toCharArray()) { - out.add(new Tuple2(in._2(), s.toString())); + out.add(new Tuple2<>(in._2(), s.toString())); } return out; } @@ -1075,11 +1110,11 @@ public void testPairGroupByKey() { List>>> expected = Arrays.asList( Arrays.asList( - new Tuple2>("california", Arrays.asList("dodgers", "giants")), - new Tuple2>("new york", Arrays.asList("yankees", "mets"))), + new Tuple2<>("california", Arrays.asList("dodgers", "giants")), + new Tuple2<>("new york", Arrays.asList("yankees", "mets"))), Arrays.asList( - new Tuple2>("california", Arrays.asList("sharks", "ducks")), - new Tuple2>("new york", Arrays.asList("rangers", "islanders")))); + new Tuple2<>("california", Arrays.asList("sharks", "ducks")), + new Tuple2<>("new york", Arrays.asList("rangers", "islanders")))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); @@ -1111,11 +1146,11 @@ public void testPairReduceByKey() { List>> expected = Arrays.asList( Arrays.asList( - new Tuple2("california", 4), - new Tuple2("new york", 5)), + new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), Arrays.asList( - new Tuple2("california", 10), - new Tuple2("new york", 4))); + new Tuple2<>("california", 10), + new Tuple2<>("new york", 4))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); @@ -1136,20 +1171,20 @@ public void testCombineByKey() { List>> expected = Arrays.asList( Arrays.asList( - new Tuple2("california", 4), - new Tuple2("new york", 5)), + new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), Arrays.asList( - new Tuple2("california", 10), - new Tuple2("new york", 4))); + new Tuple2<>("california", 10), + new Tuple2<>("new york", 4))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream combined = pairStream.combineByKey( + JavaPairDStream combined = pairStream.combineByKey( new Function() { @Override - public Integer call(Integer i) throws Exception { + public Integer call(Integer i) { return i; } }, new IntegerSum(), new IntegerSum(), new HashPartitioner(2)); @@ -1170,13 +1205,13 @@ public void testCountByValue() { List>> expected = Arrays.asList( Arrays.asList( - new Tuple2("hello", 1L), - new Tuple2("world", 1L)), + new Tuple2<>("hello", 1L), + new Tuple2<>("world", 1L)), Arrays.asList( - new Tuple2("hello", 1L), - new Tuple2("moon", 1L)), + new Tuple2<>("hello", 1L), + new Tuple2<>("moon", 1L)), Arrays.asList( - new Tuple2("hello", 1L))); + new Tuple2<>("hello", 1L))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream counted = stream.countByValue(); @@ -1193,16 +1228,16 @@ public void testGroupByKeyAndWindow() { List>>> expected = Arrays.asList( Arrays.asList( 
- new Tuple2>("california", Arrays.asList(1, 3)), - new Tuple2>("new york", Arrays.asList(1, 4)) + new Tuple2<>("california", Arrays.asList(1, 3)), + new Tuple2<>("new york", Arrays.asList(1, 4)) ), Arrays.asList( - new Tuple2>("california", Arrays.asList(1, 3, 5, 5)), - new Tuple2>("new york", Arrays.asList(1, 1, 3, 4)) + new Tuple2<>("california", Arrays.asList(1, 3, 5, 5)), + new Tuple2<>("new york", Arrays.asList(1, 1, 3, 4)) ), Arrays.asList( - new Tuple2>("california", Arrays.asList(5, 5)), - new Tuple2>("new york", Arrays.asList(1, 3)) + new Tuple2<>("california", Arrays.asList(5, 5)), + new Tuple2<>("new york", Arrays.asList(1, 3)) ) ); @@ -1220,16 +1255,16 @@ public void testGroupByKeyAndWindow() { } } - private HashSet>> convert(List>> listOfTuples) { - List>> newListOfTuples = new ArrayList>>(); + private static Set>> convert(List>> listOfTuples) { + List>> newListOfTuples = new ArrayList<>(); for (Tuple2> tuple: listOfTuples) { newListOfTuples.add(convert(tuple)); } - return new HashSet>>(newListOfTuples); + return new HashSet<>(newListOfTuples); } - private Tuple2> convert(Tuple2> tuple) { - return new Tuple2>(tuple._1(), new HashSet(tuple._2())); + private static Tuple2> convert(Tuple2> tuple) { + return new Tuple2<>(tuple._1(), new HashSet<>(tuple._2())); } @SuppressWarnings("unchecked") @@ -1238,12 +1273,12 @@ public void testReduceByKeyAndWindow() { List>> inputData = stringIntKVStream; List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", 4), - new Tuple2("new york", 5)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9)), - Arrays.asList(new Tuple2("california", 10), - new Tuple2("new york", 4))); + Arrays.asList(new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), + Arrays.asList(new Tuple2<>("california", 14), + new Tuple2<>("new york", 9)), + Arrays.asList(new Tuple2<>("california", 10), + new Tuple2<>("new york", 4))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); @@ -1262,12 +1297,12 @@ public void testUpdateStateByKey() { List>> inputData = stringIntKVStream; List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", 4), - new Tuple2("new york", 5)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9))); + Arrays.asList(new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), + Arrays.asList(new Tuple2<>("california", 14), + new Tuple2<>("new york", 9)), + Arrays.asList(new Tuple2<>("california", 14), + new Tuple2<>("new york", 9))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); @@ -1278,10 +1313,10 @@ public void testUpdateStateByKey() { public Optional call(List values, Optional state) { int out = 0; if (state.isPresent()) { - out = out + state.get(); + out += state.get(); } for (Integer v : values) { - out = out + v; + out += v; } return Optional.of(out); } @@ -1297,20 +1332,20 @@ public Optional call(List values, Optional state) { public void testUpdateStateByKeyWithInitial() { List>> inputData = stringIntKVStream; - List> initial = Arrays.asList ( - new Tuple2 ("california", 1), - new Tuple2 ("new york", 2)); + List> initial = Arrays.asList( + new Tuple2<>("california", 1), + new Tuple2<>("new york", 2)); JavaRDD> tmpRDD = ssc.sparkContext().parallelize(initial); - JavaPairRDD initialRDD = 
JavaPairRDD.fromJavaRDD (tmpRDD); + JavaPairRDD initialRDD = JavaPairRDD.fromJavaRDD(tmpRDD); List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", 5), - new Tuple2("new york", 7)), - Arrays.asList(new Tuple2("california", 15), - new Tuple2("new york", 11)), - Arrays.asList(new Tuple2("california", 15), - new Tuple2("new york", 11))); + Arrays.asList(new Tuple2<>("california", 5), + new Tuple2<>("new york", 7)), + Arrays.asList(new Tuple2<>("california", 15), + new Tuple2<>("new york", 11)), + Arrays.asList(new Tuple2<>("california", 15), + new Tuple2<>("new york", 11))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); @@ -1321,10 +1356,10 @@ public void testUpdateStateByKeyWithInitial() { public Optional call(List values, Optional state) { int out = 0; if (state.isPresent()) { - out = out + state.get(); + out += state.get(); } for (Integer v : values) { - out = out + v; + out += v; } return Optional.of(out); } @@ -1341,19 +1376,19 @@ public void testReduceByKeyAndWindowWithInverse() { List>> inputData = stringIntKVStream; List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", 4), - new Tuple2("new york", 5)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9)), - Arrays.asList(new Tuple2("california", 10), - new Tuple2("new york", 4))); + Arrays.asList(new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), + Arrays.asList(new Tuple2<>("california", 14), + new Tuple2<>("new york", 9)), + Arrays.asList(new Tuple2<>("california", 10), + new Tuple2<>("new york", 4))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream reduceWindowed = pairStream.reduceByKeyAndWindow(new IntegerSum(), new IntegerDifference(), - new Duration(2000), new Duration(1000)); + new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(reduceWindowed); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -1370,15 +1405,15 @@ public void testCountByValueAndWindow() { List>> expected = Arrays.asList( Sets.newHashSet( - new Tuple2("hello", 1L), - new Tuple2("world", 1L)), + new Tuple2<>("hello", 1L), + new Tuple2<>("world", 1L)), Sets.newHashSet( - new Tuple2("hello", 2L), - new Tuple2("world", 1L), - new Tuple2("moon", 1L)), + new Tuple2<>("hello", 2L), + new Tuple2<>("world", 1L), + new Tuple2<>("moon", 1L)), Sets.newHashSet( - new Tuple2("hello", 2L), - new Tuple2("moon", 1L))); + new Tuple2<>("hello", 2L), + new Tuple2<>("moon", 1L))); JavaDStream stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); @@ -1386,7 +1421,7 @@ public void testCountByValueAndWindow() { stream.countByValueAndWindow(new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(counted); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); - List>> unorderedResult = Lists.newArrayList(); + List>> unorderedResult = new ArrayList<>(); for (List> res: result) { unorderedResult.add(Sets.newHashSet(res)); } @@ -1399,27 +1434,27 @@ public void testCountByValueAndWindow() { public void testPairTransform() { List>> inputData = Arrays.asList( Arrays.asList( - new Tuple2(3, 5), - new Tuple2(1, 5), - new Tuple2(4, 5), - new Tuple2(2, 5)), + new Tuple2<>(3, 5), + new Tuple2<>(1, 5), + new Tuple2<>(4, 5), + new Tuple2<>(2, 5)), Arrays.asList( - new Tuple2(2, 5), - new Tuple2(3, 5), - new Tuple2(4, 5), - new Tuple2(1, 5))); + new 
Tuple2<>(2, 5), + new Tuple2<>(3, 5), + new Tuple2<>(4, 5), + new Tuple2<>(1, 5))); List>> expected = Arrays.asList( Arrays.asList( - new Tuple2(1, 5), - new Tuple2(2, 5), - new Tuple2(3, 5), - new Tuple2(4, 5)), + new Tuple2<>(1, 5), + new Tuple2<>(2, 5), + new Tuple2<>(3, 5), + new Tuple2<>(4, 5)), Arrays.asList( - new Tuple2(1, 5), - new Tuple2(2, 5), - new Tuple2(3, 5), - new Tuple2(4, 5))); + new Tuple2<>(1, 5), + new Tuple2<>(2, 5), + new Tuple2<>(3, 5), + new Tuple2<>(4, 5))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); @@ -1428,7 +1463,7 @@ public void testPairTransform() { JavaPairDStream sorted = pairStream.transformToPair( new Function, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaPairRDD in) throws Exception { + public JavaPairRDD call(JavaPairRDD in) { return in.sortByKey(); } }); @@ -1444,15 +1479,15 @@ public JavaPairRDD call(JavaPairRDD in) thro public void testPairToNormalRDDTransform() { List>> inputData = Arrays.asList( Arrays.asList( - new Tuple2(3, 5), - new Tuple2(1, 5), - new Tuple2(4, 5), - new Tuple2(2, 5)), + new Tuple2<>(3, 5), + new Tuple2<>(1, 5), + new Tuple2<>(4, 5), + new Tuple2<>(2, 5)), Arrays.asList( - new Tuple2(2, 5), - new Tuple2(3, 5), - new Tuple2(4, 5), - new Tuple2(1, 5))); + new Tuple2<>(2, 5), + new Tuple2<>(3, 5), + new Tuple2<>(4, 5), + new Tuple2<>(1, 5))); List> expected = Arrays.asList( Arrays.asList(3,1,4,2), @@ -1465,11 +1500,11 @@ public void testPairToNormalRDDTransform() { JavaDStream firstParts = pairStream.transform( new Function, JavaRDD>() { @Override - public JavaRDD call(JavaPairRDD in) throws Exception { + public JavaRDD call(JavaPairRDD in) { return in.map(new Function, Integer>() { @Override - public Integer call(Tuple2 in) { - return in._1(); + public Integer call(Tuple2 in2) { + return in2._1(); } }); } @@ -1487,14 +1522,14 @@ public void testMapValues() { List>> inputData = stringStringKVStream; List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", "DODGERS"), - new Tuple2("california", "GIANTS"), - new Tuple2("new york", "YANKEES"), - new Tuple2("new york", "METS")), - Arrays.asList(new Tuple2("california", "SHARKS"), - new Tuple2("california", "DUCKS"), - new Tuple2("new york", "RANGERS"), - new Tuple2("new york", "ISLANDERS"))); + Arrays.asList(new Tuple2<>("california", "DODGERS"), + new Tuple2<>("california", "GIANTS"), + new Tuple2<>("new york", "YANKEES"), + new Tuple2<>("new york", "METS")), + Arrays.asList(new Tuple2<>("california", "SHARKS"), + new Tuple2<>("california", "DUCKS"), + new Tuple2<>("new york", "RANGERS"), + new Tuple2<>("new york", "ISLANDERS"))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); @@ -1502,8 +1537,8 @@ public void testMapValues() { JavaPairDStream mapped = pairStream.mapValues(new Function() { @Override - public String call(String s) throws Exception { - return s.toUpperCase(); + public String call(String s) { + return s.toUpperCase(Locale.ENGLISH); } }); @@ -1519,22 +1554,22 @@ public void testFlatMapValues() { List>> inputData = stringStringKVStream; List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers1"), - new Tuple2("california", "dodgers2"), - new Tuple2("california", "giants1"), - new Tuple2("california", "giants2"), - new Tuple2("new york", "yankees1"), - new Tuple2("new york", "yankees2"), - new Tuple2("new york", "mets1"), - new Tuple2("new york", "mets2")), - Arrays.asList(new Tuple2("california", "sharks1"), - new Tuple2("california", "sharks2"), - 
new Tuple2("california", "ducks1"), - new Tuple2("california", "ducks2"), - new Tuple2("new york", "rangers1"), - new Tuple2("new york", "rangers2"), - new Tuple2("new york", "islanders1"), - new Tuple2("new york", "islanders2"))); + Arrays.asList(new Tuple2<>("california", "dodgers1"), + new Tuple2<>("california", "dodgers2"), + new Tuple2<>("california", "giants1"), + new Tuple2<>("california", "giants2"), + new Tuple2<>("new york", "yankees1"), + new Tuple2<>("new york", "yankees2"), + new Tuple2<>("new york", "mets1"), + new Tuple2<>("new york", "mets2")), + Arrays.asList(new Tuple2<>("california", "sharks1"), + new Tuple2<>("california", "sharks2"), + new Tuple2<>("california", "ducks1"), + new Tuple2<>("california", "ducks2"), + new Tuple2<>("new york", "rangers1"), + new Tuple2<>("new york", "rangers2"), + new Tuple2<>("new york", "islanders1"), + new Tuple2<>("new york", "islanders2"))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); @@ -1545,7 +1580,7 @@ public void testFlatMapValues() { new Function>() { @Override public Iterable call(String in) { - List out = new ArrayList(); + List out = new ArrayList<>(); out.add(in + "1"); out.add(in + "2"); return out; @@ -1562,29 +1597,29 @@ public Iterable call(String in) { @Test public void testCoGroup() { List>> stringStringKVStream1 = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers"), - new Tuple2("new york", "yankees")), - Arrays.asList(new Tuple2("california", "sharks"), - new Tuple2("new york", "rangers"))); + Arrays.asList(new Tuple2<>("california", "dodgers"), + new Tuple2<>("new york", "yankees")), + Arrays.asList(new Tuple2<>("california", "sharks"), + new Tuple2<>("new york", "rangers"))); List>> stringStringKVStream2 = Arrays.asList( - Arrays.asList(new Tuple2("california", "giants"), - new Tuple2("new york", "mets")), - Arrays.asList(new Tuple2("california", "ducks"), - new Tuple2("new york", "islanders"))); + Arrays.asList(new Tuple2<>("california", "giants"), + new Tuple2<>("new york", "mets")), + Arrays.asList(new Tuple2<>("california", "ducks"), + new Tuple2<>("new york", "islanders"))); List, List>>>> expected = Arrays.asList( Arrays.asList( - new Tuple2, List>>("california", - new Tuple2, List>(Arrays.asList("dodgers"), Arrays.asList("giants"))), - new Tuple2, List>>("new york", - new Tuple2, List>(Arrays.asList("yankees"), Arrays.asList("mets")))), + new Tuple2<>("california", + new Tuple2<>(Arrays.asList("dodgers"), Arrays.asList("giants"))), + new Tuple2<>("new york", + new Tuple2<>(Arrays.asList("yankees"), Arrays.asList("mets")))), Arrays.asList( - new Tuple2, List>>("california", - new Tuple2, List>(Arrays.asList("sharks"), Arrays.asList("ducks"))), - new Tuple2, List>>("new york", - new Tuple2, List>(Arrays.asList("rangers"), Arrays.asList("islanders"))))); + new Tuple2<>("california", + new Tuple2<>(Arrays.asList("sharks"), Arrays.asList("ducks"))), + new Tuple2<>("new york", + new Tuple2<>(Arrays.asList("rangers"), Arrays.asList("islanders"))))); JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( @@ -1620,29 +1655,29 @@ public void testCoGroup() { @Test public void testJoin() { List>> stringStringKVStream1 = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers"), - new Tuple2("new york", "yankees")), - Arrays.asList(new Tuple2("california", "sharks"), - new Tuple2("new york", "rangers"))); + Arrays.asList(new Tuple2<>("california", "dodgers"), + new Tuple2<>("new york", "yankees")), + Arrays.asList(new Tuple2<>("california", "sharks"), + new 
Tuple2<>("new york", "rangers"))); List>> stringStringKVStream2 = Arrays.asList( - Arrays.asList(new Tuple2("california", "giants"), - new Tuple2("new york", "mets")), - Arrays.asList(new Tuple2("california", "ducks"), - new Tuple2("new york", "islanders"))); + Arrays.asList(new Tuple2<>("california", "giants"), + new Tuple2<>("new york", "mets")), + Arrays.asList(new Tuple2<>("california", "ducks"), + new Tuple2<>("new york", "islanders"))); List>>> expected = Arrays.asList( Arrays.asList( - new Tuple2>("california", - new Tuple2("dodgers", "giants")), - new Tuple2>("new york", - new Tuple2("yankees", "mets"))), + new Tuple2<>("california", + new Tuple2<>("dodgers", "giants")), + new Tuple2<>("new york", + new Tuple2<>("yankees", "mets"))), Arrays.asList( - new Tuple2>("california", - new Tuple2("sharks", "ducks")), - new Tuple2>("new york", - new Tuple2("rangers", "islanders")))); + new Tuple2<>("california", + new Tuple2<>("sharks", "ducks")), + new Tuple2<>("new york", + new Tuple2<>("rangers", "islanders")))); JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( @@ -1664,13 +1699,13 @@ public void testJoin() { @Test public void testLeftOuterJoin() { List>> stringStringKVStream1 = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers"), - new Tuple2("new york", "yankees")), - Arrays.asList(new Tuple2("california", "sharks") )); + Arrays.asList(new Tuple2<>("california", "dodgers"), + new Tuple2<>("new york", "yankees")), + Arrays.asList(new Tuple2<>("california", "sharks") )); List>> stringStringKVStream2 = Arrays.asList( - Arrays.asList(new Tuple2("california", "giants") ), - Arrays.asList(new Tuple2("new york", "islanders") ) + Arrays.asList(new Tuple2<>("california", "giants") ), + Arrays.asList(new Tuple2<>("new york", "islanders") ) ); @@ -1713,7 +1748,7 @@ public void testCheckpointMasterRecovery() throws InterruptedException { JavaDStream stream = JavaCheckpointTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream letterCount = stream.map(new Function() { @Override - public Integer call(String s) throws Exception { + public Integer call(String s) { return s.length(); } }); @@ -1752,6 +1787,7 @@ public void testContextGetOrCreate() throws InterruptedException { // (used to detect the new context) final AtomicBoolean newContextCreated = new AtomicBoolean(false); Function0 creatingFunc = new Function0() { + @Override public JavaStreamingContext call() { newContextCreated.set(true); return new JavaStreamingContext(conf, Seconds.apply(1)); @@ -1765,20 +1801,20 @@ public JavaStreamingContext call() { newContextCreated.set(false); ssc = JavaStreamingContext.getOrCreate(corruptedCheckpointDir, creatingFunc, - new org.apache.hadoop.conf.Configuration(), true); + new Configuration(), true); Assert.assertTrue("new context not created", newContextCreated.get()); ssc.stop(); newContextCreated.set(false); ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc, - new org.apache.hadoop.conf.Configuration()); + new Configuration()); Assert.assertTrue("old context not recovered", !newContextCreated.get()); ssc.stop(); newContextCreated.set(false); JavaSparkContext sc = new JavaSparkContext(conf); ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc, - new org.apache.hadoop.conf.Configuration()); + new Configuration()); Assert.assertTrue("old context not recovered", !newContextCreated.get()); ssc.stop(); } @@ -1800,7 +1836,7 @@ public void testCheckpointofIndividualStream() throws InterruptedException { JavaDStream stream = 
JavaCheckpointTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream letterCount = stream.map(new Function() { @Override - public Integer call(String s) throws Exception { + public Integer call(String s) { return s.length(); } }); @@ -1818,29 +1854,26 @@ public Integer call(String s) throws Exception { // InputStream functionality is deferred to the existing Scala tests. @Test public void testSocketTextStream() { - JavaReceiverInputDStream test = ssc.socketTextStream("localhost", 12345); + ssc.socketTextStream("localhost", 12345); } @Test public void testSocketString() { - - class Converter implements Function> { - public Iterable call(InputStream in) throws IOException { - BufferedReader reader = new BufferedReader(new InputStreamReader(in)); - List out = new ArrayList(); - while (true) { - String line = reader.readLine(); - if (line == null) { break; } - out.add(line); - } - return out; - } - } - - JavaDStream test = ssc.socketStream( + ssc.socketStream( "localhost", 12345, - new Converter(), + new Function>() { + @Override + public Iterable call(InputStream in) throws IOException { + List out = new ArrayList<>(); + try (BufferedReader reader = new BufferedReader(new InputStreamReader(in))) { + for (String line; (line = reader.readLine()) != null;) { + out.add(line); + } + } + return out; + } + }, StorageLevel.MEMORY_ONLY()); } @@ -1870,7 +1903,7 @@ public void testFileStream() throws IOException { TextInputFormat.class, new Function() { @Override - public Boolean call(Path v1) throws Exception { + public Boolean call(Path v1) { return Boolean.TRUE; } }, @@ -1879,7 +1912,7 @@ public Boolean call(Path v1) throws Exception { JavaDStream test = inputStream.map( new Function, String>() { @Override - public String call(Tuple2 v1) throws Exception { + public String call(Tuple2 v1) { return v1._2().toString(); } }); @@ -1892,19 +1925,15 @@ public String call(Tuple2 v1) throws Exception { @Test public void testRawSocketStream() { - JavaReceiverInputDStream test = ssc.rawSocketStream("localhost", 12345); + ssc.rawSocketStream("localhost", 12345); } - private List> fileTestPrepare(File testDir) throws IOException { + private static List> fileTestPrepare(File testDir) throws IOException { File existingFile = new File(testDir, "0"); Files.write("0\n", existingFile, Charset.forName("UTF-8")); - assertTrue(existingFile.setLastModified(1000) && existingFile.lastModified() == 1000); - - List> expected = Arrays.asList( - Arrays.asList("0") - ); - - return expected; + Assert.assertTrue(existingFile.setLastModified(1000)); + Assert.assertEquals(1000, existingFile.lastModified()); + return Arrays.asList(Arrays.asList("0")); } @SuppressWarnings("unchecked") diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java new file mode 100644 index 000000000000..bc4bc2eb4223 --- /dev/null +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Set; + +import scala.Tuple2; + +import com.google.common.base.Optional; +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.util.ManualClock; +import org.junit.Assert; +import org.junit.Test; + +import org.apache.spark.HashPartitioner; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.function.Function3; +import org.apache.spark.api.java.function.Function4; +import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaMapWithStateDStream; + +public class JavaMapWithStateSuite extends LocalJavaStreamingContext implements Serializable { + + /** + * This test is only for testing the APIs. It's not necessary to run it. + */ + public void testAPI() { + JavaPairRDD initialRDD = null; + JavaPairDStream wordsDstream = null; + + final Function4, State, Optional> + mappingFunc = + new Function4, State, Optional>() { + + @Override + public Optional call( + Time time, String word, Optional one, State state) { + // Use all State's methods here + state.exists(); + state.get(); + state.isTimingOut(); + state.remove(); + state.update(true); + return Optional.of(2.0); + } + }; + + JavaMapWithStateDStream stateDstream = + wordsDstream.mapWithState( + StateSpec.function(mappingFunc) + .initialState(initialRDD) + .numPartitions(10) + .partitioner(new HashPartitioner(10)) + .timeout(Durations.seconds(10))); + + JavaPairDStream stateSnapshots = stateDstream.stateSnapshots(); + + final Function3, State, Double> mappingFunc2 = + new Function3, State, Double>() { + + @Override + public Double call(String key, Optional one, State state) { + // Use all State's methods here + state.exists(); + state.get(); + state.isTimingOut(); + state.remove(); + state.update(true); + return 2.0; + } + }; + + JavaMapWithStateDStream stateDstream2 = + wordsDstream.mapWithState( + StateSpec.function(mappingFunc2) + .initialState(initialRDD) + .numPartitions(10) + .partitioner(new HashPartitioner(10)) + .timeout(Durations.seconds(10))); + + JavaPairDStream stateSnapshots2 = stateDstream2.stateSnapshots(); + } + + @Test + public void testBasicFunction() { + List> inputData = Arrays.asList( + Collections.emptyList(), + Arrays.asList("a"), + Arrays.asList("a", "b"), + Arrays.asList("a", "b", "c"), + Arrays.asList("a", "b"), + Arrays.asList("a"), + Collections.emptyList() + ); + + List> outputData = Arrays.asList( + Collections.emptySet(), + Sets.newHashSet(1), + Sets.newHashSet(2, 1), + Sets.newHashSet(3, 2, 1), + Sets.newHashSet(4, 3), + Sets.newHashSet(5), + Collections.emptySet() + ); + + List>> stateData = Arrays.asList( + Collections.>emptySet(), + Sets.newHashSet(new Tuple2("a", 1)), + Sets.newHashSet(new Tuple2("a", 2), new Tuple2("b", 1)), + Sets.newHashSet( + new 
Tuple2("a", 3), + new Tuple2("b", 2), + new Tuple2("c", 1)), + Sets.newHashSet( + new Tuple2("a", 4), + new Tuple2("b", 3), + new Tuple2("c", 1)), + Sets.newHashSet( + new Tuple2("a", 5), + new Tuple2("b", 3), + new Tuple2("c", 1)), + Sets.newHashSet( + new Tuple2("a", 5), + new Tuple2("b", 3), + new Tuple2("c", 1)) + ); + + Function3, State, Integer> mappingFunc = + new Function3, State, Integer>() { + + @Override + public Integer call(String key, Optional value, State state) throws Exception { + int sum = value.or(0) + (state.exists() ? state.get() : 0); + state.update(sum); + return sum; + } + }; + testOperation( + inputData, + StateSpec.function(mappingFunc), + outputData, + stateData); + } + + private void testOperation( + List> input, + StateSpec mapWithStateSpec, + List> expectedOutputs, + List>> expectedStateSnapshots) { + int numBatches = expectedOutputs.size(); + JavaDStream inputStream = JavaTestUtils.attachTestInputStream(ssc, input, 2); + JavaMapWithStateDStream mapWithStateDStream = + JavaPairDStream.fromJavaDStream(inputStream.map(new Function>() { + @Override + public Tuple2 call(K x) throws Exception { + return new Tuple2(x, 1); + } + })).mapWithState(mapWithStateSpec); + + final List> collectedOutputs = + Collections.synchronizedList(Lists.>newArrayList()); + mapWithStateDStream.foreachRDD(new Function, Void>() { + @Override + public Void call(JavaRDD rdd) throws Exception { + collectedOutputs.add(Sets.newHashSet(rdd.collect())); + return null; + } + }); + final List>> collectedStateSnapshots = + Collections.synchronizedList(Lists.>>newArrayList()); + mapWithStateDStream.stateSnapshots().foreachRDD(new Function, Void>() { + @Override + public Void call(JavaPairRDD rdd) throws Exception { + collectedStateSnapshots.add(Sets.newHashSet(rdd.collect())); + return null; + } + }); + BatchCounter batchCounter = new BatchCounter(ssc.ssc()); + ssc.start(); + ((ManualClock) ssc.ssc().scheduler().clock()) + .advance(ssc.ssc().progressListener().batchDuration() * numBatches + 1); + batchCounter.waitUntilBatchesCompleted(numBatches, 10000); + + Assert.assertEquals(expectedOutputs, collectedOutputs); + Assert.assertEquals(expectedStateSnapshots, collectedStateSnapshots); + } +} diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java index 1b0787fe69de..7a8ef9d14784 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java @@ -23,6 +23,7 @@ import org.apache.spark.streaming.api.java.JavaStreamingContext; import static org.junit.Assert.*; +import com.google.common.io.Closeables; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -36,7 +37,6 @@ import java.io.Serializable; import java.net.ConnectException; import java.net.Socket; -import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; public class JavaReceiverAPISuite implements Serializable { @@ -64,16 +64,16 @@ public void testReceiver() throws InterruptedException { ssc.receiverStream(new JavaSocketReceiver("localhost", server.port())); JavaDStream mapped = input.map(new Function() { @Override - public String call(String v1) throws Exception { + public String call(String v1) { return v1 + "."; } }); mapped.foreachRDD(new Function, Void>() { @Override - public Void call(JavaRDD rdd) throws Exception { - long count = rdd.count(); - 
dataCounter.addAndGet(count); - return null; + public Void call(JavaRDD rdd) { + long count = rdd.count(); + dataCounter.addAndGet(count); + return null; } }); @@ -83,7 +83,7 @@ public Void call(JavaRDD rdd) throws Exception { Thread.sleep(200); for (int i = 0; i < 6; i++) { - server.send("" + i + "\n"); // \n to make sure these are separate lines + server.send(i + "\n"); // \n to make sure these are separate lines Thread.sleep(100); } while (dataCounter.get() == 0 && System.currentTimeMillis() - startTime < timeout) { @@ -95,50 +95,54 @@ public Void call(JavaRDD rdd) throws Exception { server.stop(); } } -} -class JavaSocketReceiver extends Receiver { + private static class JavaSocketReceiver extends Receiver { - String host = null; - int port = -1; + String host = null; + int port = -1; - public JavaSocketReceiver(String host_ , int port_) { - super(StorageLevel.MEMORY_AND_DISK()); - host = host_; - port = port_; - } + JavaSocketReceiver(String host_ , int port_) { + super(StorageLevel.MEMORY_AND_DISK()); + host = host_; + port = port_; + } - @Override - public void onStart() { - new Thread() { - @Override public void run() { - receive(); - } - }.start(); - } + @Override + public void onStart() { + new Thread() { + @Override public void run() { + receive(); + } + }.start(); + } - @Override - public void onStop() { - } + @Override + public void onStop() { + } - private void receive() { - Socket socket = null; - try { - socket = new Socket(host, port); - BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream())); - String userInput; - while ((userInput = in.readLine()) != null) { - store(userInput); + private void receive() { + try { + Socket socket = null; + BufferedReader in = null; + try { + socket = new Socket(host, port); + in = new BufferedReader(new InputStreamReader(socket.getInputStream())); + String userInput; + while ((userInput = in.readLine()) != null) { + store(userInput); + } + } finally { + Closeables.close(in, /* swallowIOException = */ true); + Closeables.close(socket, /* swallowIOException = */ true); + } + } catch(ConnectException ce) { + ce.printStackTrace(); + restart("Could not connect", ce); + } catch(Throwable t) { + t.printStackTrace(); + restart("Error receiving data", t); } - in.close(); - socket.close(); - } catch(ConnectException ce) { - ce.printStackTrace(); - restart("Could not connect", ce); - } catch(Throwable t) { - t.printStackTrace(); - restart("Error receiving data", t); } } -} +} diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaStreamingListenerAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaStreamingListenerAPISuite.java new file mode 100644 index 000000000000..67b2a0703e02 --- /dev/null +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaStreamingListenerAPISuite.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.spark.streaming; + +import org.apache.spark.streaming.api.java.*; + +public class JavaStreamingListenerAPISuite extends JavaStreamingListener { + + @Override + public void onReceiverStarted(JavaStreamingListenerReceiverStarted receiverStarted) { + JavaReceiverInfo receiverInfo = receiverStarted.receiverInfo(); + receiverInfo.streamId(); + receiverInfo.name(); + receiverInfo.active(); + receiverInfo.location(); + receiverInfo.executorId(); + receiverInfo.lastErrorMessage(); + receiverInfo.lastError(); + receiverInfo.lastErrorTime(); + } + + @Override + public void onReceiverError(JavaStreamingListenerReceiverError receiverError) { + JavaReceiverInfo receiverInfo = receiverError.receiverInfo(); + receiverInfo.streamId(); + receiverInfo.name(); + receiverInfo.active(); + receiverInfo.location(); + receiverInfo.executorId(); + receiverInfo.lastErrorMessage(); + receiverInfo.lastError(); + receiverInfo.lastErrorTime(); + } + + @Override + public void onReceiverStopped(JavaStreamingListenerReceiverStopped receiverStopped) { + JavaReceiverInfo receiverInfo = receiverStopped.receiverInfo(); + receiverInfo.streamId(); + receiverInfo.name(); + receiverInfo.active(); + receiverInfo.location(); + receiverInfo.executorId(); + receiverInfo.lastErrorMessage(); + receiverInfo.lastError(); + receiverInfo.lastErrorTime(); + } + + @Override + public void onBatchSubmitted(JavaStreamingListenerBatchSubmitted batchSubmitted) { + super.onBatchSubmitted(batchSubmitted); + } + + @Override + public void onBatchStarted(JavaStreamingListenerBatchStarted batchStarted) { + super.onBatchStarted(batchStarted); + } + + @Override + public void onBatchCompleted(JavaStreamingListenerBatchCompleted batchCompleted) { + super.onBatchCompleted(batchCompleted); + } + + @Override + public void onOutputOperationStarted(JavaStreamingListenerOutputOperationStarted outputOperationStarted) { + super.onOutputOperationStarted(outputOperationStarted); + } + + @Override + public void onOutputOperationCompleted(JavaStreamingListenerOutputOperationCompleted outputOperationCompleted) { + super.onOutputOperationCompleted(outputOperationCompleted); + } +} diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala b/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala index bb80bff6dc2e..57b50bdfd652 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala @@ -17,16 +17,13 @@ package org.apache.spark.streaming -import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} +import java.util.{List => JList} + +import scala.collection.JavaConverters._ import scala.reflect.ClassTag -import java.util.{List => JList} -import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaDStreamLike, JavaDStream, JavaStreamingContext} -import org.apache.spark.streaming._ -import java.util.ArrayList -import collection.JavaConversions._ import org.apache.spark.api.java.JavaRDDLike -import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.streaming.api.java.{JavaDStreamLike, JavaDStream, JavaStreamingContext} /** Exposes streaming test functionality in a Java-friendly way. 
*/ trait JavaTestBase extends TestSuiteBase { @@ -39,7 +36,7 @@ trait JavaTestBase extends TestSuiteBase { ssc: JavaStreamingContext, data: JList[JList[T]], numPartitions: Int) = { - val seqData = data.map(Seq(_:_*)) + val seqData = data.asScala.map(_.asScala) implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] @@ -72,9 +69,7 @@ trait JavaTestBase extends TestSuiteBase { implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]] ssc.getState() val res = runStreams[V](ssc.ssc, numBatches, numExpectedOutput) - val out = new ArrayList[JList[V]]() - res.map(entry => out.append(new ArrayList[V](entry))) - out + res.map(_.asJava).asJava } /** @@ -90,12 +85,7 @@ trait JavaTestBase extends TestSuiteBase { implicit val cm: ClassTag[V] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]] val res = runStreamsWithPartitions[V](ssc.ssc, numBatches, numExpectedOutput) - val out = new ArrayList[JList[JList[V]]]() - res.map{entry => - val lists = entry.map(new ArrayList[V](_)) - out.append(new ArrayList[JList[V]](lists)) - } - out + res.map(entry => entry.map(_.asJava).asJava).asJava } } diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java index 50e8f9fc159c..f02fa87f6194 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java @@ -20,11 +20,13 @@ import java.util.ArrayList; import java.nio.ByteBuffer; import java.util.Arrays; -import java.util.Collection; +import java.util.Iterator; +import java.util.List; -import org.apache.commons.collections.CollectionUtils; -import org.apache.commons.collections.Transformer; +import com.google.common.base.Function; +import com.google.common.collect.Iterators; import org.apache.spark.SparkConf; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.streaming.util.WriteAheadLog; import org.apache.spark.streaming.util.WriteAheadLogRecordHandle; import org.apache.spark.streaming.util.WriteAheadLogUtils; @@ -32,40 +34,40 @@ import org.junit.Test; import org.junit.Assert; -class JavaWriteAheadLogSuiteHandle extends WriteAheadLogRecordHandle { - int index = -1; - public JavaWriteAheadLogSuiteHandle(int idx) { - index = idx; - } -} - public class JavaWriteAheadLogSuite extends WriteAheadLog { - class Record { + static class JavaWriteAheadLogSuiteHandle extends WriteAheadLogRecordHandle { + int index = -1; + JavaWriteAheadLogSuiteHandle(int idx) { + index = idx; + } + } + + static class Record { long time; int index; ByteBuffer buffer; - public Record(long tym, int idx, ByteBuffer buf) { + Record(long tym, int idx, ByteBuffer buf) { index = idx; time = tym; buffer = buf; } } private int index = -1; - private ArrayList records = new ArrayList(); + private final List records = new ArrayList<>(); // Methods for WriteAheadLog @Override - public WriteAheadLogRecordHandle write(java.nio.ByteBuffer record, long time) { + public WriteAheadLogRecordHandle write(ByteBuffer record, long time) { index += 1; - records.add(new org.apache.spark.streaming.JavaWriteAheadLogSuite.Record(time, index, record)); + records.add(new Record(time, index, record)); return new JavaWriteAheadLogSuiteHandle(index); } @Override - public java.nio.ByteBuffer read(WriteAheadLogRecordHandle handle) { + public ByteBuffer read(WriteAheadLogRecordHandle handle) { if (handle instanceof JavaWriteAheadLogSuiteHandle) { int 
reqdIndex = ((JavaWriteAheadLogSuiteHandle) handle).index; for (Record record: records) { @@ -78,14 +80,13 @@ public java.nio.ByteBuffer read(WriteAheadLogRecordHandle handle) { } @Override - public java.util.Iterator readAll() { - Collection buffers = CollectionUtils.collect(records, new Transformer() { + public Iterator readAll() { + return Iterators.transform(records.iterator(), new Function() { @Override - public Object transform(Object input) { - return ((Record) input).buffer; + public ByteBuffer apply(Record input) { + return input.buffer; } }); - return buffers.iterator(); } @Override @@ -107,23 +108,24 @@ public void close() { public void testCustomWAL() { SparkConf conf = new SparkConf(); conf.set("spark.streaming.driver.writeAheadLog.class", JavaWriteAheadLogSuite.class.getName()); + conf.set("spark.streaming.driver.writeAheadLog.allowBatching", "false"); WriteAheadLog wal = WriteAheadLogUtils.createLogForDriver(conf, null, null); String data1 = "data1"; - WriteAheadLogRecordHandle handle = wal.write(ByteBuffer.wrap(data1.getBytes()), 1234); + WriteAheadLogRecordHandle handle = wal.write(JavaUtils.stringToBytes(data1), 1234); Assert.assertTrue(handle instanceof JavaWriteAheadLogSuiteHandle); - Assert.assertTrue(new String(wal.read(handle).array()).equals(data1)); + Assert.assertEquals(JavaUtils.bytesToString(wal.read(handle)), data1); - wal.write(ByteBuffer.wrap("data2".getBytes()), 1235); - wal.write(ByteBuffer.wrap("data3".getBytes()), 1236); - wal.write(ByteBuffer.wrap("data4".getBytes()), 1237); + wal.write(JavaUtils.stringToBytes("data2"), 1235); + wal.write(JavaUtils.stringToBytes("data3"), 1236); + wal.write(JavaUtils.stringToBytes("data4"), 1237); wal.clean(1236, false); - java.util.Iterator dataIterator = wal.readAll(); - ArrayList readData = new ArrayList(); + Iterator dataIterator = wal.readAll(); + List readData = new ArrayList<>(); while (dataIterator.hasNext()) { - readData.add(new String(dataIterator.next().array())); + readData.add(JavaUtils.bytesToString(dataIterator.next())); } - Assert.assertTrue(readData.equals(Arrays.asList("data3", "data4"))); + Assert.assertEquals(readData, Arrays.asList("data3", "data4")); } } diff --git a/streaming/src/test/java/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala b/streaming/src/test/java/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala new file mode 100644 index 000000000000..0295e059f7bc --- /dev/null +++ b/streaming/src/test/java/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala @@ -0,0 +1,294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.api.java + +import scala.collection.JavaConverters._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.streaming.Time +import org.apache.spark.streaming.scheduler._ + +class JavaStreamingListenerWrapperSuite extends SparkFunSuite { + + test("basic") { + val listener = new TestJavaStreamingListener() + val listenerWrapper = new JavaStreamingListenerWrapper(listener) + + val receiverStarted = StreamingListenerReceiverStarted(ReceiverInfo( + streamId = 2, + name = "test", + active = true, + location = "localhost", + executorId = "1" + )) + listenerWrapper.onReceiverStarted(receiverStarted) + assertReceiverInfo(listener.receiverStarted.receiverInfo, receiverStarted.receiverInfo) + + val receiverStopped = StreamingListenerReceiverStopped(ReceiverInfo( + streamId = 2, + name = "test", + active = false, + location = "localhost", + executorId = "1" + )) + listenerWrapper.onReceiverStopped(receiverStopped) + assertReceiverInfo(listener.receiverStopped.receiverInfo, receiverStopped.receiverInfo) + + val receiverError = StreamingListenerReceiverError(ReceiverInfo( + streamId = 2, + name = "test", + active = false, + location = "localhost", + executorId = "1", + lastErrorMessage = "failed", + lastError = "failed", + lastErrorTime = System.currentTimeMillis() + )) + listenerWrapper.onReceiverError(receiverError) + assertReceiverInfo(listener.receiverError.receiverInfo, receiverError.receiverInfo) + + val batchSubmitted = StreamingListenerBatchSubmitted(BatchInfo( + batchTime = Time(1000L), + streamIdToInputInfo = Map( + 0 -> StreamInputInfo( + inputStreamId = 0, + numRecords = 1000, + metadata = Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "receiver1")), + 1 -> StreamInputInfo( + inputStreamId = 1, + numRecords = 2000, + metadata = Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "receiver2"))), + submissionTime = 1001L, + None, + None, + outputOperationInfos = Map( + 0 -> OutputOperationInfo( + batchTime = Time(1000L), + id = 0, + name = "op1", + description = "operation1", + startTime = None, + endTime = None, + failureReason = None), + 1 -> OutputOperationInfo( + batchTime = Time(1000L), + id = 1, + name = "op2", + description = "operation2", + startTime = None, + endTime = None, + failureReason = None)) + )) + listenerWrapper.onBatchSubmitted(batchSubmitted) + assertBatchInfo(listener.batchSubmitted.batchInfo, batchSubmitted.batchInfo) + + val batchStarted = StreamingListenerBatchStarted(BatchInfo( + batchTime = Time(1000L), + streamIdToInputInfo = Map( + 0 -> StreamInputInfo( + inputStreamId = 0, + numRecords = 1000, + metadata = Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "receiver1")), + 1 -> StreamInputInfo( + inputStreamId = 1, + numRecords = 2000, + metadata = Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "receiver2"))), + submissionTime = 1001L, + Some(1002L), + None, + outputOperationInfos = Map( + 0 -> OutputOperationInfo( + batchTime = Time(1000L), + id = 0, + name = "op1", + description = "operation1", + startTime = Some(1003L), + endTime = None, + failureReason = None), + 1 -> OutputOperationInfo( + batchTime = Time(1000L), + id = 1, + name = "op2", + description = "operation2", + startTime = Some(1005L), + endTime = None, + failureReason = None)) + )) + listenerWrapper.onBatchStarted(batchStarted) + assertBatchInfo(listener.batchStarted.batchInfo, batchStarted.batchInfo) + + val batchCompleted = StreamingListenerBatchCompleted(BatchInfo( + batchTime = Time(1000L), + streamIdToInputInfo = Map( + 0 -> StreamInputInfo( + 
inputStreamId = 0, + numRecords = 1000, + metadata = Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "receiver1")), + 1 -> StreamInputInfo( + inputStreamId = 1, + numRecords = 2000, + metadata = Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "receiver2"))), + submissionTime = 1001L, + Some(1002L), + Some(1010L), + outputOperationInfos = Map( + 0 -> OutputOperationInfo( + batchTime = Time(1000L), + id = 0, + name = "op1", + description = "operation1", + startTime = Some(1003L), + endTime = Some(1004L), + failureReason = None), + 1 -> OutputOperationInfo( + batchTime = Time(1000L), + id = 1, + name = "op2", + description = "operation2", + startTime = Some(1005L), + endTime = Some(1010L), + failureReason = None)) + )) + listenerWrapper.onBatchCompleted(batchCompleted) + assertBatchInfo(listener.batchCompleted.batchInfo, batchCompleted.batchInfo) + + val outputOperationStarted = StreamingListenerOutputOperationStarted(OutputOperationInfo( + batchTime = Time(1000L), + id = 0, + name = "op1", + description = "operation1", + startTime = Some(1003L), + endTime = None, + failureReason = None + )) + listenerWrapper.onOutputOperationStarted(outputOperationStarted) + assertOutputOperationInfo(listener.outputOperationStarted.outputOperationInfo, + outputOperationStarted.outputOperationInfo) + + val outputOperationCompleted = StreamingListenerOutputOperationCompleted(OutputOperationInfo( + batchTime = Time(1000L), + id = 0, + name = "op1", + description = "operation1", + startTime = Some(1003L), + endTime = Some(1004L), + failureReason = None + )) + listenerWrapper.onOutputOperationCompleted(outputOperationCompleted) + assertOutputOperationInfo(listener.outputOperationCompleted.outputOperationInfo, + outputOperationCompleted.outputOperationInfo) + } + + private def assertReceiverInfo( + javaReceiverInfo: JavaReceiverInfo, receiverInfo: ReceiverInfo): Unit = { + assert(javaReceiverInfo.streamId === receiverInfo.streamId) + assert(javaReceiverInfo.name === receiverInfo.name) + assert(javaReceiverInfo.active === receiverInfo.active) + assert(javaReceiverInfo.location === receiverInfo.location) + assert(javaReceiverInfo.executorId === receiverInfo.executorId) + assert(javaReceiverInfo.lastErrorMessage === receiverInfo.lastErrorMessage) + assert(javaReceiverInfo.lastError === receiverInfo.lastError) + assert(javaReceiverInfo.lastErrorTime === receiverInfo.lastErrorTime) + } + + private def assertBatchInfo(javaBatchInfo: JavaBatchInfo, batchInfo: BatchInfo): Unit = { + assert(javaBatchInfo.batchTime === batchInfo.batchTime) + assert(javaBatchInfo.streamIdToInputInfo.size === batchInfo.streamIdToInputInfo.size) + batchInfo.streamIdToInputInfo.foreach { case (streamId, streamInputInfo) => + assertStreamingInfo(javaBatchInfo.streamIdToInputInfo.get(streamId), streamInputInfo) + } + assert(javaBatchInfo.submissionTime === batchInfo.submissionTime) + assert(javaBatchInfo.processingStartTime === batchInfo.processingStartTime.getOrElse(-1)) + assert(javaBatchInfo.processingEndTime === batchInfo.processingEndTime.getOrElse(-1)) + assert(javaBatchInfo.schedulingDelay === batchInfo.schedulingDelay.getOrElse(-1)) + assert(javaBatchInfo.processingDelay === batchInfo.processingDelay.getOrElse(-1)) + assert(javaBatchInfo.totalDelay === batchInfo.totalDelay.getOrElse(-1)) + assert(javaBatchInfo.numRecords === batchInfo.numRecords) + assert(javaBatchInfo.outputOperationInfos.size === batchInfo.outputOperationInfos.size) + batchInfo.outputOperationInfos.foreach { case (outputOperationId, outputOperationInfo) => + 
assertOutputOperationInfo( + javaBatchInfo.outputOperationInfos.get(outputOperationId), outputOperationInfo) + } + } + + private def assertStreamingInfo( + javaStreamInputInfo: JavaStreamInputInfo, streamInputInfo: StreamInputInfo): Unit = { + assert(javaStreamInputInfo.inputStreamId === streamInputInfo.inputStreamId) + assert(javaStreamInputInfo.numRecords === streamInputInfo.numRecords) + assert(javaStreamInputInfo.metadata === streamInputInfo.metadata.asJava) + assert(javaStreamInputInfo.metadataDescription === streamInputInfo.metadataDescription.orNull) + } + + private def assertOutputOperationInfo( + javaOutputOperationInfo: JavaOutputOperationInfo, + outputOperationInfo: OutputOperationInfo): Unit = { + assert(javaOutputOperationInfo.batchTime === outputOperationInfo.batchTime) + assert(javaOutputOperationInfo.id === outputOperationInfo.id) + assert(javaOutputOperationInfo.name === outputOperationInfo.name) + assert(javaOutputOperationInfo.description === outputOperationInfo.description) + assert(javaOutputOperationInfo.startTime === outputOperationInfo.startTime.getOrElse(-1)) + assert(javaOutputOperationInfo.endTime === outputOperationInfo.endTime.getOrElse(-1)) + assert(javaOutputOperationInfo.failureReason === outputOperationInfo.failureReason.orNull) + } +} + +class TestJavaStreamingListener extends JavaStreamingListener { + + var receiverStarted: JavaStreamingListenerReceiverStarted = null + var receiverError: JavaStreamingListenerReceiverError = null + var receiverStopped: JavaStreamingListenerReceiverStopped = null + var batchSubmitted: JavaStreamingListenerBatchSubmitted = null + var batchStarted: JavaStreamingListenerBatchStarted = null + var batchCompleted: JavaStreamingListenerBatchCompleted = null + var outputOperationStarted: JavaStreamingListenerOutputOperationStarted = null + var outputOperationCompleted: JavaStreamingListenerOutputOperationCompleted = null + + override def onReceiverStarted(receiverStarted: JavaStreamingListenerReceiverStarted): Unit = { + this.receiverStarted = receiverStarted + } + + override def onReceiverError(receiverError: JavaStreamingListenerReceiverError): Unit = { + this.receiverError = receiverError + } + + override def onReceiverStopped(receiverStopped: JavaStreamingListenerReceiverStopped): Unit = { + this.receiverStopped = receiverStopped + } + + override def onBatchSubmitted(batchSubmitted: JavaStreamingListenerBatchSubmitted): Unit = { + this.batchSubmitted = batchSubmitted + } + + override def onBatchStarted(batchStarted: JavaStreamingListenerBatchStarted): Unit = { + this.batchStarted = batchStarted + } + + override def onBatchCompleted(batchCompleted: JavaStreamingListenerBatchCompleted): Unit = { + this.batchCompleted = batchCompleted + } + + override def onOutputOperationStarted( + outputOperationStarted: JavaStreamingListenerOutputOperationStarted): Unit = { + this.outputOperationStarted = outputOperationStarted + } + + override def onOutputOperationCompleted( + outputOperationCompleted: JavaStreamingListenerOutputOperationCompleted): Unit = { + this.outputOperationCompleted = outputOperationCompleted + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 255376807c95..9d296c6d3ef8 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -191,6 +191,20 @@ class 
BasicOperationsSuite extends TestSuiteBase { ) } + test("union with input stream return None") { + val input = Seq(1 to 4, 101 to 104, 201 to 204, null) + val output = Seq(1 to 8, 101 to 108, 201 to 208) + intercept[SparkException] { + testOperation( + input, + (s: DStream[Int]) => s.union(s.map(_ + 4)), + output, + input.length, + false + ) + } + } + test("StreamingContext.union") { val input = Seq(1 to 4, 101 to 104, 201 to 204) val output = Seq(1 to 12, 101 to 112, 201 to 212) @@ -211,6 +225,32 @@ class BasicOperationsSuite extends TestSuiteBase { ) } + test("transform with NULL") { + val input = Seq(1 to 4) + intercept[SparkException] { + testOperation( + input, + (r: DStream[Int]) => r.transform(rdd => null.asInstanceOf[RDD[Int]]), + Seq(Seq()), + 1, + false + ) + } + } + + test("transform with input stream return None") { + val input = Seq(1 to 4, 5 to 8, null) + intercept[SparkException] { + testOperation( + input, + (r: DStream[Int]) => r.transform(rdd => rdd.map(_.toString)), + input.filterNot(_ == null).map(_.map(_.toString)), + input.length, + false + ) + } + } + test("transformWith") { val inputData1 = Seq( Seq("a", "b"), Seq("a", ""), Seq(""), Seq() ) val inputData2 = Seq( Seq("a", "b"), Seq("b", ""), Seq(), Seq("") ) @@ -231,6 +271,27 @@ class BasicOperationsSuite extends TestSuiteBase { testOperation(inputData1, inputData2, operation, outputData, true) } + test("transformWith with input stream return None") { + val inputData1 = Seq( Seq("a", "b"), Seq("a", ""), Seq(""), null ) + val inputData2 = Seq( Seq("a", "b"), Seq("b", ""), Seq(), null ) + val outputData = Seq( + Seq("a", "b", "a", "b"), + Seq("a", "b", "", ""), + Seq("") + ) + + val operation = (s1: DStream[String], s2: DStream[String]) => { + s1.transformWith( // RDD.join in transform + s2, + (rdd1: RDD[String], rdd2: RDD[String]) => rdd1.union(rdd2) + ) + } + + intercept[SparkException] { + testOperation(inputData1, inputData2, operation, outputData, inputData1.length, true) + } + } + test("StreamingContext.transform") { val input = Seq(1 to 4, 101 to 104, 201 to 204) val output = Seq(1 to 12, 101 to 112, 201 to 212) @@ -247,6 +308,24 @@ class BasicOperationsSuite extends TestSuiteBase { testOperation(input, operation, output) } + test("StreamingContext.transform with input stream return None") { + val input = Seq(1 to 4, 101 to 104, 201 to 204, null) + val output = Seq(1 to 12, 101 to 112, 201 to 212) + + // transform over 3 DStreams by doing union of the 3 RDDs + val operation = (s: DStream[Int]) => { + s.context.transform( + Seq(s, s.map(_ + 4), s.map(_ + 8)), // 3 DStreams + (rdds: Seq[RDD[_]], time: Time) => + rdds.head.context.union(rdds.map(_.asInstanceOf[RDD[Int]])) // union of RDDs + ) + } + + intercept[SparkException] { + testOperation(input, operation, output, input.length, false) + } + } + test("cogroup") { val inputData1 = Seq( Seq("a", "a", "b"), Seq("a", ""), Seq(""), Seq() ) val inputData2 = Seq( Seq("a", "a", "b"), Seq("b", ""), Seq(), Seq() ) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 67c2d900940a..cd28d3cf408d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.streaming -import java.io.File +import java.io.{ObjectOutputStream, ByteArrayOutputStream, ByteArrayInputStream, File} -import 
scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} +import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.reflect.ClassTag import com.google.common.base.Charsets @@ -29,19 +29,153 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.io.{IntWritable, Text} import org.apache.hadoop.mapred.TextOutputFormat import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} +import org.mockito.Mockito.mock import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite, TestUtils} import org.apache.spark.streaming.dstream.{DStream, FileInputDStream} -import org.apache.spark.streaming.scheduler.{RateLimitInputDStream, ConstantEstimator, SingletonTestRateReceiver} -import org.apache.spark.util.{Clock, ManualClock, Utils} +import org.apache.spark.streaming.scheduler._ +import org.apache.spark.util.{MutableURLClassLoader, Clock, ManualClock, Utils} + +/** + * A trait of that can be mixed in to get methods for testing DStream operations under + * DStream checkpointing. Note that the implementations of this trait has to implement + * the `setupCheckpointOperation` + */ +trait DStreamCheckpointTester { self: SparkFunSuite => + + /** + * Tests a streaming operation under checkpointing, by restarting the operation + * from checkpoint file and verifying whether the final output is correct. + * The output is assumed to have come from a reliable queue which an replay + * data as required. + * + * NOTE: This takes into consideration that the last batch processed before + * master failure will be re-processed after restart/recovery. + */ + protected def testCheckpointedOperation[U: ClassTag, V: ClassTag]( + input: Seq[Seq[U]], + operation: DStream[U] => DStream[V], + expectedOutput: Seq[Seq[V]], + numBatchesBeforeRestart: Int, + batchDuration: Duration = Milliseconds(500), + stopSparkContextAfterTest: Boolean = true + ) { + require(numBatchesBeforeRestart < expectedOutput.size, + "Number of batches before context restart less than number of expected output " + + "(i.e. 
number of total batches to run)") + require(StreamingContext.getActive().isEmpty, + "Cannot run test with already active streaming context") + + // Current code assumes that number of batches to be run = number of inputs + val totalNumBatches = input.size + val batchDurationMillis = batchDuration.milliseconds + + // Setup the stream computation + val checkpointDir = Utils.createTempDir(this.getClass.getSimpleName()).toString + logDebug(s"Using checkpoint directory $checkpointDir") + val ssc = createContextForCheckpointOperation(batchDuration) + require(ssc.conf.get("spark.streaming.clock") === classOf[ManualClock].getName, + "Cannot run test without manual clock in the conf") + + val inputStream = new TestInputStream(ssc, input, numPartitions = 2) + val operatedStream = operation(inputStream) + operatedStream.print() + val outputStream = new TestOutputStreamWithPartitions(operatedStream, + new ArrayBuffer[Seq[Seq[V]]] with SynchronizedBuffer[Seq[Seq[V]]]) + outputStream.register() + ssc.checkpoint(checkpointDir) + + // Do the computation for initial number of batches, create checkpoint file and quit + val beforeRestartOutput = generateOutput[V](ssc, + Time(batchDurationMillis * numBatchesBeforeRestart), checkpointDir, stopSparkContextAfterTest) + assertOutput(beforeRestartOutput, expectedOutput, beforeRestart = true) + // Restart and complete the computation from checkpoint file + logInfo( + "\n-------------------------------------------\n" + + " Restarting stream computation " + + "\n-------------------------------------------\n" + ) + + val restartedSsc = new StreamingContext(checkpointDir) + val afterRestartOutput = generateOutput[V](restartedSsc, + Time(batchDurationMillis * totalNumBatches), checkpointDir, stopSparkContextAfterTest) + assertOutput(afterRestartOutput, expectedOutput, beforeRestart = false) + } + + protected def createContextForCheckpointOperation(batchDuration: Duration): StreamingContext = { + val conf = new SparkConf().setMaster("local").setAppName(this.getClass.getSimpleName) + conf.set("spark.streaming.clock", classOf[ManualClock].getName()) + new StreamingContext(SparkContext.getOrCreate(conf), batchDuration) + } + + private def generateOutput[V: ClassTag]( + ssc: StreamingContext, + targetBatchTime: Time, + checkpointDir: String, + stopSparkContext: Boolean + ): Seq[Seq[V]] = { + try { + val batchDuration = ssc.graph.batchDuration + val batchCounter = new BatchCounter(ssc) + ssc.start() + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val currentTime = clock.getTimeMillis() + + logInfo("Manual clock before advancing = " + clock.getTimeMillis()) + clock.setTime(targetBatchTime.milliseconds) + logInfo("Manual clock after advancing = " + clock.getTimeMillis()) + + val outputStream = ssc.graph.getOutputStreams().filter { dstream => + dstream.isInstanceOf[TestOutputStreamWithPartitions[V]] + }.head.asInstanceOf[TestOutputStreamWithPartitions[V]] + + eventually(timeout(10 seconds)) { + ssc.awaitTerminationOrTimeout(10) + assert(batchCounter.getLastCompletedBatchTime === targetBatchTime) + } + + eventually(timeout(10 seconds)) { + val checkpointFilesOfLatestTime = Checkpoint.getCheckpointFiles(checkpointDir).filter { + _.toString.contains(clock.getTimeMillis.toString) + } + // Checkpoint files are written twice for every batch interval. So assert that both + // are written to make sure that both of them have been written. 
+ assert(checkpointFilesOfLatestTime.size === 2) + } + outputStream.output.map(_.flatten) + + } finally { + ssc.stop(stopSparkContext = stopSparkContext) + } + } + + private def assertOutput[V: ClassTag]( + output: Seq[Seq[V]], + expectedOutput: Seq[Seq[V]], + beforeRestart: Boolean): Unit = { + val expectedPartialOutput = if (beforeRestart) { + expectedOutput.take(output.size) + } else { + expectedOutput.takeRight(output.size) + } + val setComparison = output.zip(expectedPartialOutput).forall { + case (o, e) => o.toSet === e.toSet + } + assert(setComparison, s"set comparison failed\n" + + s"Expected output items:\n${expectedPartialOutput.mkString("\n")}\n" + + s"Generated output items: ${output.mkString("\n")}" + ) + } +} /** * This test suites tests the checkpointing functionality of DStreams - * the checkpointing of a DStream's RDDs as well as the checkpointing of * the whole DStream graph. */ -class CheckpointSuite extends TestSuiteBase { +class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester { var ssc: StreamingContext = null @@ -54,7 +188,7 @@ class CheckpointSuite extends TestSuiteBase { override def afterFunction() { super.afterFunction() - if (ssc != null) ssc.stop() + if (ssc != null) { ssc.stop() } Utils.deleteRecursively(new File(checkpointDir)) } @@ -249,7 +383,9 @@ class CheckpointSuite extends TestSuiteBase { Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), - Seq(("", 2)), Seq() ), + Seq(("", 2)), + Seq() + ), 3 ) } @@ -397,26 +533,28 @@ class CheckpointSuite extends TestSuiteBase { ssc = new StreamingContext(conf, batchDuration) ssc.checkpoint(checkpointDir) - val dstream = new RateLimitInputDStream(ssc) { + val dstream = new RateTestInputDStream(ssc) { override val rateController = - Some(new ReceiverRateController(id, new ConstantEstimator(200.0))) + Some(new ReceiverRateController(id, new ConstantEstimator(200))) } - SingletonTestRateReceiver.reset() val output = new TestOutputStreamWithPartitions(dstream.checkpoint(batchDuration * 2)) output.register() runStreams(ssc, 5, 5) - SingletonTestRateReceiver.reset() ssc = new StreamingContext(checkpointDir) ssc.start() - val outputNew = advanceTimeWithRealDelay(ssc, 2) - eventually(timeout(5.seconds)) { - assert(dstream.getCurrentRateLimit === Some(200)) + eventually(timeout(10.seconds)) { + assert(RateTestReceiver.getActive().nonEmpty) + } + + advanceTimeWithRealDelay(ssc, 2) + + eventually(timeout(10.seconds)) { + assert(RateTestReceiver.getActive().get.getDefaultBlockGeneratorRateLimit() === 200) } ssc.stop() - ssc = null } // This tests whether file input stream remembers what files were seen before @@ -577,52 +715,57 @@ class CheckpointSuite extends TestSuiteBase { } } + // This tests whether spark can deserialize array object + // refer to SPARK-5569 + test("recovery from checkpoint contains array object") { + // create a class which is invisible to app class loader + val jar = TestUtils.createJarWithClasses( + classNames = Seq("testClz"), + toStringValue = "testStringValue" + ) - /** - * Tests a streaming operation under checkpointing, by restarting the operation - * from checkpoint file and verifying whether the final output is correct. - * The output is assumed to have come from a reliable queue which an replay - * data as required. - * - * NOTE: This takes into consideration that the last batch processed before - * master failure will be re-processed after restart/recovery. 
- */ - def testCheckpointedOperation[U: ClassTag, V: ClassTag]( - input: Seq[Seq[U]], - operation: DStream[U] => DStream[V], - expectedOutput: Seq[Seq[V]], - initialNumBatches: Int - ) { - - // Current code assumes that: - // number of inputs = number of outputs = number of batches to be run - val totalNumBatches = input.size - val nextNumBatches = totalNumBatches - initialNumBatches - val initialNumExpectedOutputs = initialNumBatches - val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs + 1 - // because the last batch will be processed again - - // Do the computation for initial number of batches, create checkpoint file and quit - ssc = setupStreams[U, V](input, operation) - ssc.start() - val output = advanceTimeWithRealDelay[V](ssc, initialNumBatches) - ssc.stop() - verifyOutput[V](output, expectedOutput.take(initialNumBatches), true) - Thread.sleep(1000) + // invisible to current class loader + val appClassLoader = getClass.getClassLoader + intercept[ClassNotFoundException](appClassLoader.loadClass("testClz")) + + // visible to mutableURLClassLoader + val loader = new MutableURLClassLoader( + Array(jar), appClassLoader) + assert(loader.loadClass("testClz").newInstance().toString == "testStringValue") + + // create and serialize Array[testClz] + // scalastyle:off classforname + val arrayObj = Class.forName("[LtestClz;", false, loader) + // scalastyle:on classforname + val bos = new ByteArrayOutputStream() + new ObjectOutputStream(bos).writeObject(arrayObj) + + // deserialize the Array[testClz] + val ois = new ObjectInputStreamWithLoader( + new ByteArrayInputStream(bos.toByteArray), loader) + assert(ois.readObject().asInstanceOf[Class[_]].getName == "[LtestClz;") + } - // Restart and complete the computation from checkpoint file - logInfo( - "\n-------------------------------------------\n" + - " Restarting stream computation " + - "\n-------------------------------------------\n" - ) - ssc = new StreamingContext(checkpointDir) - ssc.start() - val outputNew = advanceTimeWithRealDelay[V](ssc, nextNumBatches) - // the first element will be re-processed data of the last batch before restart - verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true) - ssc.stop() - ssc = null + test("SPARK-11267: the race condition of two checkpoints in a batch") { + val jobGenerator = mock(classOf[JobGenerator]) + val checkpointDir = Utils.createTempDir().toString + val checkpointWriter = + new CheckpointWriter(jobGenerator, conf, checkpointDir, new Configuration()) + val bytes1 = Array.fill[Byte](10)(1) + new checkpointWriter.CheckpointWriteHandler( + Time(2000), bytes1, clearCheckpointDataLater = false).run() + val bytes2 = Array.fill[Byte](10)(2) + new checkpointWriter.CheckpointWriteHandler( + Time(1000), bytes2, clearCheckpointDataLater = true).run() + val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir).reverse.map { path => + new File(path.toUri) + } + assert(checkpointFiles.size === 2) + // Although bytes2 was written with an old time, it contains the latest status, so we should + // try to read from it at first. 
+ assert(Files.toByteArray(checkpointFiles(0)) === bytes2) + assert(Files.toByteArray(checkpointFiles(1)) === bytes1) + checkpointWriter.stop() } /** diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala index 8844c9d74b93..bc223e648a41 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala @@ -17,12 +17,15 @@ package org.apache.spark.streaming +import scala.collection.mutable.ArrayBuffer + import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} -import org.apache.spark.{SparkContext, SparkFunSuite} -import org.apache.spark.rdd.RDDOperationScope +import org.apache.spark.rdd.{RDD, RDDOperationScope} import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.ui.UIUtils +import org.apache.spark.util.ManualClock +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} /** * Tests whether scope information is passed from DStream operations to RDDs correctly. @@ -32,7 +35,9 @@ class DStreamScopeSuite extends SparkFunSuite with BeforeAndAfter with BeforeAnd private val batchDuration: Duration = Seconds(1) override def beforeAll(): Unit = { - ssc = new StreamingContext(new SparkContext("local", "test"), batchDuration) + val conf = new SparkConf().setMaster("local").setAppName("test") + conf.set("spark.streaming.clock", classOf[ManualClock].getName()) + ssc = new StreamingContext(new SparkContext(conf), batchDuration) } override def afterAll(): Unit = { @@ -103,6 +108,8 @@ class DStreamScopeSuite extends SparkFunSuite with BeforeAndAfter with BeforeAnd test("scoping nested operations") { val inputStream = new DummyInputDStream(ssc) + // countByKeyAndWindow internally uses reduceByKeyAndWindow, but only countByKeyAndWindow + // should appear in scope val countStream = inputStream.countByWindow(Seconds(10), Seconds(1)) countStream.initialize(Time(0)) @@ -137,6 +144,57 @@ class DStreamScopeSuite extends SparkFunSuite with BeforeAndAfter with BeforeAnd testStream(countStream) } + test("transform should allow RDD operations to be captured in scopes") { + val inputStream = new DummyInputDStream(ssc) + val transformedStream = inputStream.transform { _.map { _ -> 1}.reduceByKey(_ + _) } + transformedStream.initialize(Time(0)) + + val transformScopeBase = transformedStream.baseScope.map(RDDOperationScope.fromJson) + val transformScope1 = transformedStream.getOrCompute(Time(1000)).get.scope + val transformScope2 = transformedStream.getOrCompute(Time(2000)).get.scope + val transformScope3 = transformedStream.getOrCompute(Time(3000)).get.scope + + // Assert that all children RDDs inherit the DStream operation name correctly + assertDefined(transformScopeBase, transformScope1, transformScope2, transformScope3) + assert(transformScopeBase.get.name === "transform") + assertNestedScopeCorrect(transformScope1.get, 1000) + assertNestedScopeCorrect(transformScope2.get, 2000) + assertNestedScopeCorrect(transformScope3.get, 3000) + + def assertNestedScopeCorrect(rddScope: RDDOperationScope, batchTime: Long): Unit = { + assert(rddScope.name === "reduceByKey") + assert(rddScope.parent.isDefined) + assertScopeCorrect(transformScopeBase.get, rddScope.parent.get, batchTime) + } + } + + test("foreachRDD should allow RDD operations to be captured in scope") { + val inputStream = new DummyInputDStream(ssc) + val generatedRDDs = new ArrayBuffer[RDD[(Int, Int)]] + 
inputStream.foreachRDD { rdd => + generatedRDDs += rdd.map { _ -> 1}.reduceByKey(_ + _) + } + val batchCounter = new BatchCounter(ssc) + ssc.start() + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + clock.advance(3000) + batchCounter.waitUntilBatchesCompleted(3, 10000) + assert(generatedRDDs.size === 3) + + val foreachBaseScope = + ssc.graph.getOutputStreams().head.baseScope.map(RDDOperationScope.fromJson) + assertDefined(foreachBaseScope) + assert(foreachBaseScope.get.name === "foreachRDD") + + val rddScopes = generatedRDDs.map { _.scope } + assertDefined(rddScopes: _*) + rddScopes.zipWithIndex.foreach { case (rddScope, idx) => + assert(rddScope.get.name === "reduceByKey") + assert(rddScope.get.parent.isDefined) + assertScopeCorrect(foreachBaseScope.get, rddScope.get.parent.get, (idx + 1) * 1000) + } + } + /** Assert that the RDD operation scope properties are not set in our SparkContext. */ private def assertPropertiesNotSet(): Unit = { assert(ssc != null) @@ -149,19 +207,12 @@ class DStreamScopeSuite extends SparkFunSuite with BeforeAndAfter with BeforeAnd baseScope: RDDOperationScope, rddScope: RDDOperationScope, batchTime: Long): Unit = { - assertScopeCorrect(baseScope.id, baseScope.name, rddScope, batchTime) - } - - /** Assert that the given RDD scope inherits the base name and ID correctly. */ - private def assertScopeCorrect( - baseScopeId: String, - baseScopeName: String, - rddScope: RDDOperationScope, - batchTime: Long): Unit = { + val (baseScopeId, baseScopeName) = (baseScope.id, baseScope.name) val formattedBatchTime = UIUtils.formatBatchTime( batchTime, ssc.graph.batchDuration.milliseconds, showYYYYMMSS = false) assert(rddScope.id === s"${baseScopeId}_$batchTime") assert(rddScope.name.replaceAll("\\n", " ") === s"$baseScopeName @ $formattedBatchTime") + assert(rddScope.parent.isEmpty) // There should not be any higher scope } /** Assert that all the specified options are defined. */ diff --git a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala index 0c4c06534a69..e82c2fa4e72a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala @@ -17,25 +17,32 @@ package org.apache.spark.streaming -import org.apache.spark.Logging +import java.io.File + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SparkFunSuite, Logging} import org.apache.spark.util.Utils /** * This testsuite tests master failures at random times while the stream is running using * the real clock. 
*/ -class FailureSuite extends TestSuiteBase with Logging { +class FailureSuite extends SparkFunSuite with BeforeAndAfter with Logging { - val directory = Utils.createTempDir() - val numBatches = 30 + private val batchDuration: Duration = Milliseconds(1000) + private val numBatches = 30 + private var directory: File = null - override def batchDuration: Duration = Milliseconds(1000) - - override def useManualClock: Boolean = false + before { + directory = Utils.createTempDir() + } - override def afterFunction() { - Utils.deleteRecursively(directory) - super.afterFunction() + after { + if (directory != null) { + Utils.deleteRecursively(directory) + } + StreamingContext.getActive().foreach { _.stop() } } test("multiple failures with map") { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index ec2852d9a020..3a3176b91b1e 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -76,6 +76,9 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { fail("Timeout: cannot finish all batches in 30 seconds") } + // Ensure progress listener has been notified of all events + ssc.scheduler.listenerBus.waitUntilEmpty(500) + // Verify all "InputInfo"s have been reported assert(ssc.progressListener.numTotalReceivedRecords === input.size) assert(ssc.progressListener.numTotalProcessedRecords === input.size) @@ -203,28 +206,28 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val numTotalRecords = numThreads * numRecordsPerThread val testReceiver = new MultiThreadTestReceiver(numThreads, numRecordsPerThread) MultiThreadTestReceiver.haveAllThreadsFinished = false - - // set up the network stream using the test receiver - val ssc = new StreamingContext(conf, batchDuration) - val networkStream = ssc.receiverStream[Int](testReceiver) - val countStream = networkStream.count val outputBuffer = new ArrayBuffer[Seq[Long]] with SynchronizedBuffer[Seq[Long]] - val outputStream = new TestOutputStream(countStream, outputBuffer) def output: ArrayBuffer[Long] = outputBuffer.flatMap(x => x) - outputStream.register() - ssc.start() - - // Let the data from the receiver be received - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - val startTime = System.currentTimeMillis() - while((!MultiThreadTestReceiver.haveAllThreadsFinished || output.sum < numTotalRecords) && - System.currentTimeMillis() - startTime < 5000) { - Thread.sleep(100) - clock.advance(batchDuration.milliseconds) + + // set up the network stream using the test receiver + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => + val networkStream = ssc.receiverStream[Int](testReceiver) + val countStream = networkStream.count + + val outputStream = new TestOutputStream(countStream, outputBuffer) + outputStream.register() + ssc.start() + + // Let the data from the receiver be received + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val startTime = System.currentTimeMillis() + while ((!MultiThreadTestReceiver.haveAllThreadsFinished || output.sum < numTotalRecords) && + System.currentTimeMillis() - startTime < 5000) { + Thread.sleep(100) + clock.advance(batchDuration.milliseconds) + } + Thread.sleep(1000) } - Thread.sleep(1000) - logInfo("Stopping context") - ssc.stop() // Verify whether data received was as expected logInfo("--------------------------------") 
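Note on the InputStreamsSuite hunks around this point: each test is being refactored to build its StreamingContext inside `withStreamingContext` and to drive batches by advancing the ManualClock configured through `spark.streaming.clock`, instead of relying on a trailing `ssc.stop()`. As a rough illustration of the loan pattern this relies on (a sketch only, not the actual TestSuiteBase helper; the name below is hypothetical):

    import org.apache.spark.streaming.StreamingContext

    object StreamingTestHelpers {
      // Hypothetical stand-in for TestSuiteBase.withStreamingContext: run the test
      // body, then stop the context however the body exits, so a failing assertion
      // cannot leak a running StreamingContext into the next test.
      def withStreamingContextSketch[R](ssc: StreamingContext)(body: StreamingContext => R): R = {
        try {
          body(ssc)
        } finally {
          ssc.stop(stopSparkContext = true) // the real helper may pass a different stop flag
        }
      }
    }

Inside the body, the refactored tests then advance `ssc.scheduler.clock.asInstanceOf[ManualClock]` by one batch duration per expected batch rather than sleeping on wall-clock time.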
@@ -236,30 +239,30 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } test("queue input stream - oneAtATime = true") { - // Set up the streaming context and input streams - val ssc = new StreamingContext(conf, batchDuration) - val queue = new SynchronizedQueue[RDD[String]]() - val queueStream = ssc.queueStream(queue, oneAtATime = true) - val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] - val outputStream = new TestOutputStream(queueStream, outputBuffer) - def output: ArrayBuffer[Seq[String]] = outputBuffer.filter(_.size > 0) - outputStream.register() - ssc.start() - - // Setup data queued into the stream - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] val input = Seq("1", "2", "3", "4", "5") val expectedOutput = input.map(Seq(_)) + val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] + def output: ArrayBuffer[Seq[String]] = outputBuffer.filter(_.size > 0) - val inputIterator = input.toIterator - for (i <- 0 until input.size) { - // Enqueue more than 1 item per tick but they should dequeue one at a time - inputIterator.take(2).foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) - clock.advance(batchDuration.milliseconds) + // Set up the streaming context and input streams + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => + val queue = new SynchronizedQueue[RDD[String]]() + val queueStream = ssc.queueStream(queue, oneAtATime = true) + val outputStream = new TestOutputStream(queueStream, outputBuffer) + outputStream.register() + ssc.start() + + // Setup data queued into the stream + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + + val inputIterator = input.toIterator + for (i <- 0 until input.size) { + // Enqueue more than 1 item per tick but they should dequeue one at a time + inputIterator.take(2).foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) + clock.advance(batchDuration.milliseconds) + } + Thread.sleep(1000) } - Thread.sleep(1000) - logInfo("Stopping context") - ssc.stop() // Verify whether data received was as expected logInfo("--------------------------------") @@ -279,33 +282,33 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } test("queue input stream - oneAtATime = false") { - // Set up the streaming context and input streams - val ssc = new StreamingContext(conf, batchDuration) - val queue = new SynchronizedQueue[RDD[String]]() - val queueStream = ssc.queueStream(queue, oneAtATime = false) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] - val outputStream = new TestOutputStream(queueStream, outputBuffer) def output: ArrayBuffer[Seq[String]] = outputBuffer.filter(_.size > 0) - outputStream.register() - ssc.start() - - // Setup data queued into the stream - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] val input = Seq("1", "2", "3", "4", "5") val expectedOutput = Seq(Seq("1", "2", "3"), Seq("4", "5")) - // Enqueue the first 3 items (one by one), they should be merged in the next batch - val inputIterator = input.toIterator - inputIterator.take(3).foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) - clock.advance(batchDuration.milliseconds) - Thread.sleep(1000) - - // Enqueue the remaining items (again one by one), merged in the final batch - inputIterator.foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) - clock.advance(batchDuration.milliseconds) - Thread.sleep(1000) - logInfo("Stopping context") - ssc.stop() + // Set up the streaming context and 
input streams + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => + val queue = new SynchronizedQueue[RDD[String]]() + val queueStream = ssc.queueStream(queue, oneAtATime = false) + val outputStream = new TestOutputStream(queueStream, outputBuffer) + outputStream.register() + ssc.start() + + // Setup data queued into the stream + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + + // Enqueue the first 3 items (one by one), they should be merged in the next batch + val inputIterator = input.toIterator + inputIterator.take(3).foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) + clock.advance(batchDuration.milliseconds) + Thread.sleep(1000) + + // Enqueue the remaining items (again one by one), merged in the final batch + inputIterator.foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) + clock.advance(batchDuration.milliseconds) + Thread.sleep(1000) + } // Verify whether data received was as expected logInfo("--------------------------------") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala new file mode 100644 index 000000000000..6b21433f1781 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala @@ -0,0 +1,581 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming + +import java.io.File + +import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} +import scala.reflect.ClassTag + +import org.scalatest.PrivateMethodTester._ +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} + +import org.apache.spark.streaming.dstream.{DStream, InternalMapWithStateDStream, MapWithStateDStream, MapWithStateDStreamImpl} +import org.apache.spark.util.{ManualClock, Utils} +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} + +class MapWithStateSuite extends SparkFunSuite + with DStreamCheckpointTester with BeforeAndAfterAll with BeforeAndAfter { + + private var sc: SparkContext = null + protected var checkpointDir: File = null + protected val batchDuration = Seconds(1) + + before { + StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) } + checkpointDir = Utils.createTempDir("checkpoint") + } + + after { + if (checkpointDir != null) { + Utils.deleteRecursively(checkpointDir) + } + StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) } + } + + override def beforeAll(): Unit = { + val conf = new SparkConf().setMaster("local").setAppName("MapWithStateSuite") + conf.set("spark.streaming.clock", classOf[ManualClock].getName()) + sc = new SparkContext(conf) + } + + override def afterAll(): Unit = { + if (sc != null) { + sc.stop() + } + } + + test("state - get, exists, update, remove, ") { + var state: StateImpl[Int] = null + + def testState( + expectedData: Option[Int], + shouldBeUpdated: Boolean = false, + shouldBeRemoved: Boolean = false, + shouldBeTimingOut: Boolean = false + ): Unit = { + if (expectedData.isDefined) { + assert(state.exists) + assert(state.get() === expectedData.get) + assert(state.getOption() === expectedData) + assert(state.getOption.getOrElse(-1) === expectedData.get) + } else { + assert(!state.exists) + intercept[NoSuchElementException] { + state.get() + } + assert(state.getOption() === None) + assert(state.getOption.getOrElse(-1) === -1) + } + + assert(state.isTimingOut() === shouldBeTimingOut) + if (shouldBeTimingOut) { + intercept[IllegalArgumentException] { + state.remove() + } + intercept[IllegalArgumentException] { + state.update(-1) + } + } + + assert(state.isUpdated() === shouldBeUpdated) + + assert(state.isRemoved() === shouldBeRemoved) + if (shouldBeRemoved) { + intercept[IllegalArgumentException] { + state.remove() + } + intercept[IllegalArgumentException] { + state.update(-1) + } + } + } + + state = new StateImpl[Int]() + testState(None) + + state.wrap(None) + testState(None) + + state.wrap(Some(1)) + testState(Some(1)) + + state.update(2) + testState(Some(2), shouldBeUpdated = true) + + state = new StateImpl[Int]() + state.update(2) + testState(Some(2), shouldBeUpdated = true) + + state.remove() + testState(None, shouldBeRemoved = true) + + state.wrapTimingOutState(3) + testState(Some(3), shouldBeTimingOut = true) + } + + test("mapWithState - basic operations with simple API") { + val inputData = + Seq( + Seq(), + Seq("a"), + Seq("a", "b"), + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq() + ) + + val outputData = + Seq( + Seq(), + Seq(1), + Seq(2, 1), + Seq(3, 2, 1), + Seq(4, 3), + Seq(5), + Seq() + ) + + val stateData = + Seq( + Seq(), + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 3), ("b", 2), ("c", 1)), + Seq(("a", 4), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)) + ) + + // state maintains running count, and updated count is returned + val mappingFunc = (key: 
String, value: Option[Int], state: State[Int]) => { + val sum = value.getOrElse(0) + state.getOption.getOrElse(0) + state.update(sum) + sum + } + + testOperation[String, Int, Int]( + inputData, StateSpec.function(mappingFunc), outputData, stateData) + } + + test("mapWithState - basic operations with advanced API") { + val inputData = + Seq( + Seq(), + Seq("a"), + Seq("a", "b"), + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq() + ) + + val outputData = + Seq( + Seq(), + Seq("aa"), + Seq("aa", "bb"), + Seq("aa", "bb", "cc"), + Seq("aa", "bb"), + Seq("aa"), + Seq() + ) + + val stateData = + Seq( + Seq(), + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 3), ("b", 2), ("c", 1)), + Seq(("a", 4), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)) + ) + + // state maintains running count, key string doubled and returned + val mappingFunc = (batchTime: Time, key: String, value: Option[Int], state: State[Int]) => { + val sum = value.getOrElse(0) + state.getOption.getOrElse(0) + state.update(sum) + Some(key * 2) + } + + testOperation(inputData, StateSpec.function(mappingFunc), outputData, stateData) + } + + test("mapWithState - type inferencing and class tags") { + + // Simple track state function with value as Int, state as Double and mapped type as Double + val simpleFunc = (key: String, value: Option[Int], state: State[Double]) => { + 0L + } + + // Advanced track state function with key as String, value as Int, state as Double and + // mapped type as Double + val advancedFunc = (time: Time, key: String, value: Option[Int], state: State[Double]) => { + Some(0L) + } + + def testTypes(dstream: MapWithStateDStream[_, _, _, _]): Unit = { + val dstreamImpl = dstream.asInstanceOf[MapWithStateDStreamImpl[_, _, _, _]] + assert(dstreamImpl.keyClass === classOf[String]) + assert(dstreamImpl.valueClass === classOf[Int]) + assert(dstreamImpl.stateClass === classOf[Double]) + assert(dstreamImpl.mappedClass === classOf[Long]) + } + val ssc = new StreamingContext(sc, batchDuration) + val inputStream = new TestInputStream[(String, Int)](ssc, Seq.empty, numPartitions = 2) + + // Defining StateSpec inline with mapWithState and simple function implicitly gets the types + val simpleFunctionStateStream1 = inputStream.mapWithState( + StateSpec.function(simpleFunc).numPartitions(1)) + testTypes(simpleFunctionStateStream1) + + // Separately defining StateSpec with simple function requires explicitly specifying types + val simpleFuncSpec = StateSpec.function[String, Int, Double, Long](simpleFunc) + val simpleFunctionStateStream2 = inputStream.mapWithState(simpleFuncSpec) + testTypes(simpleFunctionStateStream2) + + // Separately defining StateSpec with advanced function implicitly gets the types + val advFuncSpec1 = StateSpec.function(advancedFunc) + val advFunctionStateStream1 = inputStream.mapWithState(advFuncSpec1) + testTypes(advFunctionStateStream1) + + // Defining StateSpec inline with mapWithState and advanced func implicitly gets the types + val advFunctionStateStream2 = inputStream.mapWithState( + StateSpec.function(simpleFunc).numPartitions(1)) + testTypes(advFunctionStateStream2) + + // Defining StateSpec inline with mapWithState and advanced func implicitly gets the types + val advFuncSpec2 = StateSpec.function[String, Int, Double, Long](advancedFunc) + val advFunctionStateStream3 = inputStream.mapWithState[Double, Long](advFuncSpec2) + testTypes(advFunctionStateStream3) + } + + test("mapWithState - states as mapped data") { + val inputData = + Seq( + 
Seq(), + Seq("a"), + Seq("a", "b"), + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq() + ) + + val outputData = + Seq( + Seq(), + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 3), ("b", 2), ("c", 1)), + Seq(("a", 4), ("b", 3)), + Seq(("a", 5)), + Seq() + ) + + val stateData = + Seq( + Seq(), + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 3), ("b", 2), ("c", 1)), + Seq(("a", 4), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)) + ) + + val mappingFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => { + val sum = value.getOrElse(0) + state.getOption.getOrElse(0) + val output = (key, sum) + state.update(sum) + Some(output) + } + + testOperation(inputData, StateSpec.function(mappingFunc), outputData, stateData) + } + + test("mapWithState - initial states, with nothing returned as from mapping function") { + + val initialState = Seq(("a", 5), ("b", 10), ("c", -20), ("d", 0)) + + val inputData = + Seq( + Seq(), + Seq("a"), + Seq("a", "b"), + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq() + ) + + val outputData = Seq.fill(inputData.size)(Seq.empty[Int]) + + val stateData = + Seq( + Seq(("a", 5), ("b", 10), ("c", -20), ("d", 0)), + Seq(("a", 6), ("b", 10), ("c", -20), ("d", 0)), + Seq(("a", 7), ("b", 11), ("c", -20), ("d", 0)), + Seq(("a", 8), ("b", 12), ("c", -19), ("d", 0)), + Seq(("a", 9), ("b", 13), ("c", -19), ("d", 0)), + Seq(("a", 10), ("b", 13), ("c", -19), ("d", 0)), + Seq(("a", 10), ("b", 13), ("c", -19), ("d", 0)) + ) + + val mappingFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => { + val sum = value.getOrElse(0) + state.getOption.getOrElse(0) + val output = (key, sum) + state.update(sum) + None.asInstanceOf[Option[Int]] + } + + val mapWithStateSpec = StateSpec.function(mappingFunc).initialState(sc.makeRDD(initialState)) + testOperation(inputData, mapWithStateSpec, outputData, stateData) + } + + test("mapWithState - state removing") { + val inputData = + Seq( + Seq(), + Seq("a"), + Seq("a", "b"), // a will be removed + Seq("a", "b", "c"), // b will be removed + Seq("a", "b", "c"), // a and c will be removed + Seq("a", "b"), // b will be removed + Seq("a"), // a will be removed + Seq() + ) + + // States that were removed + val outputData = + Seq( + Seq(), + Seq(), + Seq("a"), + Seq("b"), + Seq("a", "c"), + Seq("b"), + Seq("a"), + Seq() + ) + + val stateData = + Seq( + Seq(), + Seq(("a", 1)), + Seq(("b", 1)), + Seq(("a", 1), ("c", 1)), + Seq(("b", 1)), + Seq(("a", 1)), + Seq(), + Seq() + ) + + val mappingFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => { + if (state.exists) { + state.remove() + Some(key) + } else { + state.update(value.get) + None + } + } + + testOperation( + inputData, StateSpec.function(mappingFunc).numPartitions(1), outputData, stateData) + } + + test("mapWithState - state timing out") { + val inputData = + Seq( + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq(), // c will time out + Seq(), // b will time out + Seq("a") // a will not time out + ) ++ Seq.fill(20)(Seq("a")) // a will continue to stay active + + val mappingFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => { + if (value.isDefined) { + state.update(1) + } + if (state.isTimingOut) { + Some(key) + } else { + None + } + } + + val (collectedOutputs, collectedStateSnapshots) = getOperationOutput( + inputData, StateSpec.function(mappingFunc).timeout(Seconds(3)), 20) + + // b and c should be returned once each, when they were 
marked as expired + assert(collectedOutputs.flatten.sorted === Seq("b", "c")) + + // States for a, b, c should be defined at one point of time + assert(collectedStateSnapshots.exists { + _.toSet == Set(("a", 1), ("b", 1), ("c", 1)) + }) + + // Finally state should be defined only for a + assert(collectedStateSnapshots.last.toSet === Set(("a", 1))) + } + + test("mapWithState - checkpoint durations") { + val privateMethod = PrivateMethod[InternalMapWithStateDStream[_, _, _, _]]('internalStream) + + def testCheckpointDuration( + batchDuration: Duration, + expectedCheckpointDuration: Duration, + explicitCheckpointDuration: Option[Duration] = None + ): Unit = { + val ssc = new StreamingContext(sc, batchDuration) + + try { + val inputStream = new TestInputStream(ssc, Seq.empty[Seq[Int]], 2).map(_ -> 1) + val dummyFunc = (key: Int, value: Option[Int], state: State[Int]) => 0 + val mapWithStateStream = inputStream.mapWithState(StateSpec.function(dummyFunc)) + val internalmapWithStateStream = mapWithStateStream invokePrivate privateMethod() + + explicitCheckpointDuration.foreach { d => + mapWithStateStream.checkpoint(d) + } + mapWithStateStream.register() + ssc.checkpoint(checkpointDir.toString) + ssc.start() // should initialize all the checkpoint durations + assert(mapWithStateStream.checkpointDuration === null) + assert(internalmapWithStateStream.checkpointDuration === expectedCheckpointDuration) + } finally { + ssc.stop(stopSparkContext = false) + } + } + + testCheckpointDuration(Milliseconds(100), Seconds(1)) + testCheckpointDuration(Seconds(1), Seconds(10)) + testCheckpointDuration(Seconds(10), Seconds(100)) + + testCheckpointDuration(Milliseconds(100), Seconds(2), Some(Seconds(2))) + testCheckpointDuration(Seconds(1), Seconds(2), Some(Seconds(2))) + testCheckpointDuration(Seconds(10), Seconds(20), Some(Seconds(20))) + } + + + test("mapWithState - driver failure recovery") { + val inputData = + Seq( + Seq(), + Seq("a"), + Seq("a", "b"), + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq() + ) + + val stateData = + Seq( + Seq(), + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 3), ("b", 2), ("c", 1)), + Seq(("a", 4), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)) + ) + + def operation(dstream: DStream[String]): DStream[(String, Int)] = { + + val checkpointDuration = batchDuration * (stateData.size / 2) + + val runningCount = (key: String, value: Option[Int], state: State[Int]) => { + state.update(state.getOption().getOrElse(0) + value.getOrElse(0)) + state.get() + } + + val mapWithStateStream = dstream.map { _ -> 1 }.mapWithState( + StateSpec.function(runningCount)) + // Set internval make sure there is one RDD checkpointing + mapWithStateStream.checkpoint(checkpointDuration) + mapWithStateStream.stateSnapshots() + } + + testCheckpointedOperation(inputData, operation, stateData, inputData.size / 2, + batchDuration = batchDuration, stopSparkContextAfterTest = false) + } + + private def testOperation[K: ClassTag, S: ClassTag, T: ClassTag]( + input: Seq[Seq[K]], + mapWithStateSpec: StateSpec[K, Int, S, T], + expectedOutputs: Seq[Seq[T]], + expectedStateSnapshots: Seq[Seq[(K, S)]] + ): Unit = { + require(expectedOutputs.size == expectedStateSnapshots.size) + + val (collectedOutputs, collectedStateSnapshots) = + getOperationOutput(input, mapWithStateSpec, expectedOutputs.size) + assert(expectedOutputs, collectedOutputs, "outputs") + assert(expectedStateSnapshots, collectedStateSnapshots, "state snapshots") + } + + private def 
getOperationOutput[K: ClassTag, S: ClassTag, T: ClassTag]( + input: Seq[Seq[K]], + mapWithStateSpec: StateSpec[K, Int, S, T], + numBatches: Int + ): (Seq[Seq[T]], Seq[Seq[(K, S)]]) = { + + // Setup the stream computation + val ssc = new StreamingContext(sc, Seconds(1)) + val inputStream = new TestInputStream(ssc, input, numPartitions = 2) + val trackeStateStream = inputStream.map(x => (x, 1)).mapWithState(mapWithStateSpec) + val collectedOutputs = new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]] + val outputStream = new TestOutputStream(trackeStateStream, collectedOutputs) + val collectedStateSnapshots = new ArrayBuffer[Seq[(K, S)]] with SynchronizedBuffer[Seq[(K, S)]] + val stateSnapshotStream = new TestOutputStream( + trackeStateStream.stateSnapshots(), collectedStateSnapshots) + outputStream.register() + stateSnapshotStream.register() + + val batchCounter = new BatchCounter(ssc) + ssc.checkpoint(checkpointDir.toString) + ssc.start() + + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + clock.advance(batchDuration.milliseconds * numBatches) + + batchCounter.waitUntilBatchesCompleted(numBatches, 10000) + ssc.stop(stopSparkContext = false) + (collectedOutputs, collectedStateSnapshots) + } + + private def assert[U](expected: Seq[Seq[U]], collected: Seq[Seq[U]], typ: String) { + val debugString = "\nExpected:\n" + expected.mkString("\n") + + "\nCollected:\n" + collected.mkString("\n") + assert(expected.size === collected.size, + s"number of collected $typ (${collected.size}) different from expected (${expected.size})" + + debugString) + expected.zip(collected).foreach { case (c, e) => + assert(c.toSet === e.toSet, + s"collected $typ is different from expected $debugString" + ) + } + } +} + diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 6c0c926755c2..c17fb7238151 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -29,7 +29,8 @@ import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ import org.apache.spark._ -import org.apache.spark.network.nio.NioBlockTransferService +import org.apache.spark.memory.StaticMemoryManager +import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.KryoSerializer @@ -47,7 +48,9 @@ class ReceivedBlockHandlerSuite with Matchers with Logging { - val conf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.rollingIntervalSecs", "1") + val conf = new SparkConf() + .set("spark.streaming.receiver.writeAheadLog.rollingIntervalSecs", "1") + .set("spark.app.id", "streaming-test") val hadoopConf = new Configuration() val streamId = 1 val securityMgr = new SecurityManager(conf) @@ -184,7 +187,7 @@ class ReceivedBlockHandlerSuite } test("Test Block - isFullyConsumed") { - val sparkConf = new SparkConf() + val sparkConf = new SparkConf().set("spark.app.id", "streaming-test") sparkConf.set("spark.storage.unrollMemoryThreshold", "512") // spark.storage.unrollFraction set to 0.4 for BlockManager sparkConf.set("spark.storage.unrollFraction", "0.4") @@ -251,12 +254,14 @@ class ReceivedBlockHandlerSuite maxMem: Long, conf: SparkConf, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { - val transfer = new 
NioBlockTransferService(conf, securityMgr) - val manager = new BlockManager(name, rpcEnv, blockManagerMaster, serializer, maxMem, conf, - mapOutputTracker, shuffleManager, transfer, securityMgr, 0) - manager.initialize("app-id") - blockManagerBuffer += manager - manager + val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1) + val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) + val blockManager = new BlockManager(name, rpcEnv, blockManagerMaster, serializer, conf, + memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) + memManager.setMemoryStore(blockManager.memoryStore) + blockManager.initialize("app-id") + blockManagerBuffer += blockManager + blockManager } /** diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index f793a12843b2..081f5a1c93e6 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.streaming import java.io.File +import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ @@ -32,7 +33,7 @@ import org.apache.spark.{Logging, SparkConf, SparkException, SparkFunSuite} import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult import org.apache.spark.streaming.scheduler._ -import org.apache.spark.streaming.util.{WriteAheadLogUtils, FileBasedWriteAheadLogReader} +import org.apache.spark.streaming.util._ import org.apache.spark.streaming.util.WriteAheadLogSuite._ import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils} @@ -207,6 +208,75 @@ class ReceivedBlockTrackerSuite tracker1.isWriteAheadLogEnabled should be (false) } + test("parallel file deletion in FileBasedWriteAheadLog is robust to deletion error") { + conf.set("spark.streaming.driver.writeAheadLog.rollingIntervalSecs", "1") + require(WriteAheadLogUtils.getRollingIntervalSecs(conf, isDriver = true) === 1) + + val addBlocks = generateBlockInfos() + val batch1 = addBlocks.slice(0, 1) + val batch2 = addBlocks.slice(1, 3) + val batch3 = addBlocks.slice(3, addBlocks.length) + + assert(getWriteAheadLogFiles().length === 0) + + // list of timestamps for files + val t = Seq.tabulate(5)(i => i * 1000) + + writeEventsManually(getLogFileName(t(0)), Seq(createBatchCleanup(t(0)))) + assert(getWriteAheadLogFiles().length === 1) + + // The goal is to create several log files which should have been cleaned up. 
+ // If we face any issue during recovery, because these old files exist, then we need to make + // deletion more robust rather than a parallelized operation where we fire and forget + val batch1Allocation = createBatchAllocation(t(1), batch1) + writeEventsManually(getLogFileName(t(1)), batch1.map(BlockAdditionEvent) :+ batch1Allocation) + + writeEventsManually(getLogFileName(t(2)), Seq(createBatchCleanup(t(1)))) + + val batch2Allocation = createBatchAllocation(t(3), batch2) + writeEventsManually(getLogFileName(t(3)), batch2.map(BlockAdditionEvent) :+ batch2Allocation) + + writeEventsManually(getLogFileName(t(4)), batch3.map(BlockAdditionEvent)) + + // We should have 5 different log files as we called `writeEventsManually` with 5 different + // timestamps + assert(getWriteAheadLogFiles().length === 5) + + // Create the tracker to recover from the log files. We're going to ask the tracker to clean + // things up, and then we're going to rewrite that data, and recover using a different tracker. + // They should have identical data no matter what + val tracker = createTracker(recoverFromWriteAheadLog = true, clock = new ManualClock(t(4))) + + def compareTrackers(base: ReceivedBlockTracker, subject: ReceivedBlockTracker): Unit = { + subject.getBlocksOfBatchAndStream(t(3), streamId) should be( + base.getBlocksOfBatchAndStream(t(3), streamId)) + subject.getBlocksOfBatchAndStream(t(1), streamId) should be( + base.getBlocksOfBatchAndStream(t(1), streamId)) + subject.getBlocksOfBatchAndStream(t(0), streamId) should be(Nil) + } + + // ask the tracker to clean up some old files + tracker.cleanupOldBatches(t(3), waitForCompletion = true) + assert(getWriteAheadLogFiles().length === 3) + + val tracker2 = createTracker(recoverFromWriteAheadLog = true, clock = new ManualClock(t(4))) + compareTrackers(tracker, tracker2) + + // rewrite first file + writeEventsManually(getLogFileName(t(0)), Seq(createBatchCleanup(t(0)))) + assert(getWriteAheadLogFiles().length === 4) + // make sure trackers are consistent + val tracker3 = createTracker(recoverFromWriteAheadLog = true, clock = new ManualClock(t(4))) + compareTrackers(tracker, tracker3) + + // rewrite second file + writeEventsManually(getLogFileName(t(1)), batch1.map(BlockAdditionEvent) :+ batch1Allocation) + assert(getWriteAheadLogFiles().length === 5) + // make sure trackers are consistent + val tracker4 = createTracker(recoverFromWriteAheadLog = true, clock = new ManualClock(t(4))) + compareTrackers(tracker, tracker4) + } + /** * Create tracker object with the optional provided clock. Use fake clock if you * want to control time by manually incrementing it to test log clean. @@ -228,11 +298,30 @@ class ReceivedBlockTrackerSuite BlockManagerBasedStoreResult(StreamBlockId(streamId, math.abs(Random.nextInt)), Some(0L)))) } + /** + * Write received block tracker events to a file manually. + */ + def writeEventsManually(filePath: String, events: Seq[ReceivedBlockTrackerLogEvent]): Unit = { + val writer = HdfsUtils.getOutputStream(filePath, hadoopConf) + events.foreach { event => + val bytes = Utils.serialize(event) + writer.writeInt(bytes.size) + writer.write(bytes) + } + writer.close() + } + /** Get all the data written in the given write ahead log file. */ def getWrittenLogData(logFile: String): Seq[ReceivedBlockTrackerLogEvent] = { getWrittenLogData(Seq(logFile)) } + /** Get the log file name for the given log start time. 
*/ + def getLogFileName(time: Long, rollingIntervalSecs: Int = 1): String = { + checkpointDirectory.toString + File.separator + "receivedBlockMetadata" + + File.separator + s"log-$time-${time + rollingIntervalSecs * 1000}" + } + /** * Get all the data written in the given write ahead log files. By default, it will read all * files in the test log directory. @@ -241,8 +330,13 @@ class ReceivedBlockTrackerSuite : Seq[ReceivedBlockTrackerLogEvent] = { logFiles.flatMap { file => new FileBasedWriteAheadLogReader(file, hadoopConf).toSeq - }.map { byteBuffer => - Utils.deserialize[ReceivedBlockTrackerLogEvent](byteBuffer.array) + }.flatMap { byteBuffer => + val validBuffer = if (WriteAheadLogUtils.isBatchingEnabled(conf, isDriver = true)) { + Utils.deserialize[Array[Array[Byte]]](byteBuffer.array()).map(ByteBuffer.wrap) + } else { + Array(byteBuffer) + } + validBuffer.map(b => Utils.deserialize[ReceivedBlockTrackerLogEvent](b.array())) }.toList } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala new file mode 100644 index 000000000000..6d388d9624d9 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming + +import scala.util.Random + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.rdd.BlockRDD +import org.apache.spark.storage.{StorageLevel, StreamBlockId} +import org.apache.spark.streaming.dstream.ReceiverInputDStream +import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD +import org.apache.spark.streaming.receiver.{BlockManagerBasedStoreResult, Receiver, WriteAheadLogBasedStoreResult} +import org.apache.spark.streaming.scheduler.ReceivedBlockInfo +import org.apache.spark.streaming.util.{WriteAheadLogRecordHandle, WriteAheadLogUtils} +import org.apache.spark.{SparkConf, SparkEnv} + +class ReceiverInputDStreamSuite extends TestSuiteBase with BeforeAndAfterAll { + + override def afterAll(): Unit = { + StreamingContext.getActive().map { _.stop() } + } + + testWithoutWAL("createBlockRDD creates empty BlockRDD when no block info") { receiverStream => + val rdd = receiverStream.createBlockRDD(Time(0), Seq.empty) + assert(rdd.isInstanceOf[BlockRDD[_]]) + assert(!rdd.isInstanceOf[WriteAheadLogBackedBlockRDD[_]]) + assert(rdd.isEmpty()) + } + + testWithoutWAL("createBlockRDD creates correct BlockRDD with block info") { receiverStream => + val blockInfos = Seq.fill(5) { createBlockInfo(withWALInfo = false) } + val blockIds = blockInfos.map(_.blockId) + + // Verify that there are some blocks that are present, and some that are not + require(blockIds.forall(blockId => SparkEnv.get.blockManager.master.contains(blockId))) + + val rdd = receiverStream.createBlockRDD(Time(0), blockInfos) + assert(rdd.isInstanceOf[BlockRDD[_]]) + assert(!rdd.isInstanceOf[WriteAheadLogBackedBlockRDD[_]]) + val blockRDD = rdd.asInstanceOf[BlockRDD[_]] + assert(blockRDD.blockIds.toSeq === blockIds) + } + + testWithoutWAL("createBlockRDD filters non-existent blocks before creating BlockRDD") { + receiverStream => + val presentBlockInfos = Seq.fill(2)(createBlockInfo(withWALInfo = false, createBlock = true)) + val absentBlockInfos = Seq.fill(3)(createBlockInfo(withWALInfo = false, createBlock = false)) + val blockInfos = presentBlockInfos ++ absentBlockInfos + val blockIds = blockInfos.map(_.blockId) + + // Verify that there are some blocks that are present, and some that are not + require(blockIds.exists(blockId => SparkEnv.get.blockManager.master.contains(blockId))) + require(blockIds.exists(blockId => !SparkEnv.get.blockManager.master.contains(blockId))) + + val rdd = receiverStream.createBlockRDD(Time(0), blockInfos) + assert(rdd.isInstanceOf[BlockRDD[_]]) + val blockRDD = rdd.asInstanceOf[BlockRDD[_]] + assert(blockRDD.blockIds.toSeq === presentBlockInfos.map { _.blockId}) + } + + testWithWAL("createBlockRDD creates empty WALBackedBlockRDD when no block info") { + receiverStream => + val rdd = receiverStream.createBlockRDD(Time(0), Seq.empty) + assert(rdd.isInstanceOf[WriteAheadLogBackedBlockRDD[_]]) + assert(rdd.isEmpty()) + } + + testWithWAL( + "createBlockRDD creates correct WALBackedBlockRDD with all block info having WAL info") { + receiverStream => + val blockInfos = Seq.fill(5) { createBlockInfo(withWALInfo = true) } + val blockIds = blockInfos.map(_.blockId) + val rdd = receiverStream.createBlockRDD(Time(0), blockInfos) + assert(rdd.isInstanceOf[WriteAheadLogBackedBlockRDD[_]]) + val blockRDD = rdd.asInstanceOf[WriteAheadLogBackedBlockRDD[_]] + assert(blockRDD.blockIds.toSeq === blockIds) + assert(blockRDD.walRecordHandles.toSeq === blockInfos.map { _.walRecordHandleOption.get }) + } + + testWithWAL("createBlockRDD creates BlockRDD when some 
block info dont have WAL info") { + receiverStream => + val blockInfos1 = Seq.fill(2) { createBlockInfo(withWALInfo = true) } + val blockInfos2 = Seq.fill(3) { createBlockInfo(withWALInfo = false) } + val blockInfos = blockInfos1 ++ blockInfos2 + val blockIds = blockInfos.map(_.blockId) + val rdd = receiverStream.createBlockRDD(Time(0), blockInfos) + assert(rdd.isInstanceOf[BlockRDD[_]]) + val blockRDD = rdd.asInstanceOf[BlockRDD[_]] + assert(blockRDD.blockIds.toSeq === blockIds) + } + + + private def testWithoutWAL(msg: String)(body: ReceiverInputDStream[_] => Unit): Unit = { + test(s"Without WAL enabled: $msg") { + runTest(enableWAL = false, body) + } + } + + private def testWithWAL(msg: String)(body: ReceiverInputDStream[_] => Unit): Unit = { + test(s"With WAL enabled: $msg") { + runTest(enableWAL = true, body) + } + } + + private def runTest(enableWAL: Boolean, body: ReceiverInputDStream[_] => Unit): Unit = { + val conf = new SparkConf() + conf.setMaster("local[4]").setAppName("ReceiverInputDStreamSuite") + conf.set(WriteAheadLogUtils.RECEIVER_WAL_ENABLE_CONF_KEY, enableWAL.toString) + require(WriteAheadLogUtils.enableReceiverLog(conf) === enableWAL) + val ssc = new StreamingContext(conf, Seconds(1)) + val receiverStream = new ReceiverInputDStream[Int](ssc) { + override def getReceiver(): Receiver[Int] = null + } + withStreamingContext(ssc) { ssc => + body(receiverStream) + } + } + + /** + * Create a block info for input to the ReceiverInputDStream.createBlockRDD + * @param withWALInfo Create block with WAL info in it + * @param createBlock Actually create the block in the BlockManager + * @return + */ + private def createBlockInfo( + withWALInfo: Boolean, + createBlock: Boolean = true): ReceivedBlockInfo = { + val blockId = new StreamBlockId(0, Random.nextLong()) + if (createBlock) { + SparkEnv.get.blockManager.putSingle(blockId, 1, StorageLevel.MEMORY_ONLY, tellMaster = true) + require(SparkEnv.get.blockManager.master.contains(blockId)) + } + val storeResult = if (withWALInfo) { + new WriteAheadLogBasedStoreResult(blockId, None, new WriteAheadLogRecordHandle { }) + } else { + new BlockManagerBasedStoreResult(blockId, None) + } + new ReceivedBlockInfo(0, None, None, storeResult) + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index 13b4d17c8618..01279b34f73d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -129,32 +129,6 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { } } - test("block generator") { - val blockGeneratorListener = new FakeBlockGeneratorListener - val blockIntervalMs = 200 - val conf = new SparkConf().set("spark.streaming.blockInterval", s"${blockIntervalMs}ms") - val blockGenerator = new BlockGenerator(blockGeneratorListener, 1, conf) - val expectedBlocks = 5 - val waitTime = expectedBlocks * blockIntervalMs + (blockIntervalMs / 2) - val generatedData = new ArrayBuffer[Int] - - // Generate blocks - val startTime = System.currentTimeMillis() - blockGenerator.start() - var count = 0 - while(System.currentTimeMillis - startTime < waitTime) { - blockGenerator.addData(count) - generatedData += count - count += 1 - Thread.sleep(10) - } - blockGenerator.stop() - - val recordedData = blockGeneratorListener.arrayBuffers.flatten - assert(blockGeneratorListener.arrayBuffers.size > 0) - assert(recordedData.toSet === 
generatedData.toSet) - } - ignore("block generator throttling") { val blockGeneratorListener = new FakeBlockGeneratorListener val blockIntervalMs = 100 @@ -348,6 +322,11 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { } override protected def onReceiverStart(): Boolean = true + + override def createBlockGenerator( + blockGeneratorListener: BlockGeneratorListener): BlockGenerator = { + null + } } /** diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala new file mode 100644 index 000000000000..c4a01eaea739 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala @@ -0,0 +1,324 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import scala.collection.{immutable, mutable, Map} +import scala.util.Random + +import org.apache.spark.SparkFunSuite +import org.apache.spark.streaming.util.{EmptyStateMap, OpenHashMapBasedStateMap, StateMap} +import org.apache.spark.util.Utils + +class StateMapSuite extends SparkFunSuite { + + test("EmptyStateMap") { + val map = new EmptyStateMap[Int, Int] + intercept[scala.NotImplementedError] { + map.put(1, 1, 1) + } + assert(map.get(1) === None) + assert(map.getByTime(10000).isEmpty) + assert(map.getAll().isEmpty) + map.remove(1) // no exception + assert(map.copy().eq(map)) + } + + test("OpenHashMapBasedStateMap - put, get, getByTime, getAll, remove") { + val map = new OpenHashMapBasedStateMap[Int, Int]() + + map.put(1, 100, 10) + assert(map.get(1) === Some(100)) + assert(map.get(2) === None) + assert(map.getByTime(11).toSet === Set((1, 100, 10))) + assert(map.getByTime(10).toSet === Set.empty) + assert(map.getByTime(9).toSet === Set.empty) + assert(map.getAll().toSet === Set((1, 100, 10))) + + map.put(2, 200, 20) + assert(map.getByTime(21).toSet === Set((1, 100, 10), (2, 200, 20))) + assert(map.getByTime(11).toSet === Set((1, 100, 10))) + assert(map.getByTime(10).toSet === Set.empty) + assert(map.getByTime(9).toSet === Set.empty) + assert(map.getAll().toSet === Set((1, 100, 10), (2, 200, 20))) + + map.remove(1) + assert(map.get(1) === None) + assert(map.getAll().toSet === Set((2, 200, 20))) + } + + test("OpenHashMapBasedStateMap - put, get, getByTime, getAll, remove with copy") { + val parentMap = new OpenHashMapBasedStateMap[Int, Int]() + parentMap.put(1, 100, 1) + parentMap.put(2, 200, 2) + parentMap.remove(1) + + // Create child map and make changes + val map = parentMap.copy() + assert(map.get(1) === None) + assert(map.get(2) === Some(200)) + assert(map.getByTime(10).toSet === Set((2, 200, 2))) + assert(map.getByTime(2).toSet === Set.empty) + assert(map.getAll().toSet === Set((2, 200, 2))) + + // Add 
new items + map.put(3, 300, 3) + assert(map.get(3) === Some(300)) + map.put(4, 400, 4) + assert(map.get(4) === Some(400)) + assert(map.getByTime(10).toSet === Set((2, 200, 2), (3, 300, 3), (4, 400, 4))) + assert(map.getByTime(4).toSet === Set((2, 200, 2), (3, 300, 3))) + assert(map.getAll().toSet === Set((2, 200, 2), (3, 300, 3), (4, 400, 4))) + assert(parentMap.getAll().toSet === Set((2, 200, 2))) + + // Remove items + map.remove(4) + assert(map.get(4) === None) // item added in this map, then removed in this map + map.remove(2) + assert(map.get(2) === None) // item removed in parent map, then added in this map + assert(map.getAll().toSet === Set((3, 300, 3))) + assert(parentMap.getAll().toSet === Set((2, 200, 2))) + + // Update items + map.put(1, 1000, 100) + assert(map.get(1) === Some(1000)) // item removed in parent map, then added in this map + map.put(2, 2000, 200) + assert(map.get(2) === Some(2000)) // item added in parent map, then removed + added in this map + map.put(3, 3000, 300) + assert(map.get(3) === Some(3000)) // item added + updated in this map + map.put(4, 4000, 400) + assert(map.get(4) === Some(4000)) // item removed + updated in this map + + assert(map.getAll().toSet === + Set((1, 1000, 100), (2, 2000, 200), (3, 3000, 300), (4, 4000, 400))) + assert(parentMap.getAll().toSet === Set((2, 200, 2))) + + map.remove(2) // remove item present in parent map, so that its not visible in child map + + // Create child map and see availability of items + val childMap = map.copy() + assert(childMap.getAll().toSet === map.getAll().toSet) + assert(childMap.get(1) === Some(1000)) // item removed in grandparent, but added in parent map + assert(childMap.get(2) === None) // item added in grandparent, but removed in parent map + assert(childMap.get(3) === Some(3000)) // item added and updated in parent map + + childMap.put(2, 20000, 200) + assert(childMap.get(2) === Some(20000)) // item map + } + + test("OpenHashMapBasedStateMap - serializing and deserializing") { + val map1 = new OpenHashMapBasedStateMap[Int, Int]() + testSerialization(map1, "error deserializing and serialized empty map") + + map1.put(1, 100, 1) + map1.put(2, 200, 2) + testSerialization(map1, "error deserializing and serialized map with data + no delta") + + val map2 = map1.copy() + // Do not test compaction + assert(map2.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false) + testSerialization(map2, "error deserializing and serialized map with 1 delta + no new data") + + map2.put(3, 300, 3) + map2.put(4, 400, 4) + testSerialization(map2, "error deserializing and serialized map with 1 delta + new data") + + val map3 = map2.copy() + assert(map3.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false) + testSerialization(map3, "error deserializing and serialized map with 2 delta + no new data") + map3.put(3, 600, 3) + map3.remove(2) + testSerialization(map3, "error deserializing and serialized map with 2 delta + new data") + } + + test("OpenHashMapBasedStateMap - serializing and deserializing with compaction") { + val targetDeltaLength = 10 + val deltaChainThreshold = 5 + + var map = new OpenHashMapBasedStateMap[Int, Int]( + deltaChainThreshold = deltaChainThreshold) + + // Make large delta chain with length more than deltaChainThreshold + for(i <- 1 to targetDeltaLength) { + map.put(Random.nextInt(), Random.nextInt(), 1) + map = map.copy().asInstanceOf[OpenHashMapBasedStateMap[Int, Int]] + } + assert(map.deltaChainLength > deltaChainThreshold) + assert(map.shouldCompact === true) + + val 
deser_map = testSerialization(map, "Deserialized + compacted map not same as original map") + assert(deser_map.deltaChainLength < deltaChainThreshold) + assert(deser_map.shouldCompact === false) + } + + test("OpenHashMapBasedStateMap - all possible sequences of operations with copies ") { + /* + * This tests the map using all permutations of sequences operations, across multiple map + * copies as well as between copies. It is to ensure complete coverage, though it is + * kind of hard to debug this. It is set up as follows. + * + * - For any key, there can be 2 types of update ops on a state map - put or remove + * + * - These operations are done on a test map in "sets". After each set, the map is "copied" + * to create a new map, and the next set of operations are done on the new one. This tests + * whether the map data persistes correctly across copies. + * + * - Within each set, there are a number of operations to test whether the map correctly + * updates and removes data without affecting the parent state map. + * + * - Overall this creates (numSets * numOpsPerSet) operations, each of which that can 2 types + * of operations. This leads to a total of [2 ^ (numSets * numOpsPerSet)] different sequence + * of operations, which we will test with different keys. + * + * Example: With numSets = 2, and numOpsPerSet = 2 give numTotalOps = 4. This means that + * 2 ^ 4 = 16 possible permutations needs to be tested using 16 keys. + * _______________________________________________ + * | | Set1 | Set2 | + * | |-----------------|-----------------| + * | | Op1 Op2 |c| Op3 Op4 | + * |---------|----------------|o|----------------| + * | key 0 | put put |p| put put | + * | key 1 | put put |y| put rem | + * | key 2 | put put | | rem put | + * | key 3 | put put |t| rem rem | + * | key 4 | put rem |h| put put | + * | key 5 | put rem |e| put rem | + * | key 6 | put rem | | rem put | + * | key 7 | put rem |s| rem rem | + * | key 8 | rem put |t| put put | + * | key 9 | rem put |a| put rem | + * | key 10 | rem put |t| rem put | + * | key 11 | rem put |e| rem rem | + * | key 12 | rem rem | | put put | + * | key 13 | rem rem |m| put rem | + * | key 14 | rem rem |a| rem put | + * | key 15 | rem rem |p| rem rem | + * |_________|________________|_|________________| + */ + + val numTypeMapOps = 2 // 0 = put a new value, 1 = remove value + val numSets = 3 + val numOpsPerSet = 3 // to test seq of ops like update -> remove -> update in same set + val numTotalOps = numOpsPerSet * numSets + val numKeys = math.pow(numTypeMapOps, numTotalOps).toInt // to get all combinations of ops + + val refMap = new mutable.HashMap[Int, (Int, Long)]() + var prevSetRefMap: immutable.Map[Int, (Int, Long)] = null + + var stateMap: StateMap[Int, Int] = new OpenHashMapBasedStateMap[Int, Int]() + var prevSetStateMap: StateMap[Int, Int] = null + + var time = 1L + + for (setId <- 0 until numSets) { + for (opInSetId <- 0 until numOpsPerSet) { + val opId = setId * numOpsPerSet + opInSetId + for (keyId <- 0 until numKeys) { + time += 1 + // Find the operation type that needs to be done + // This is similar to finding the nth bit value of a binary number + // E.g. 
nth bit from the right of any binary number B is [ B / (2 ^ (n - 1)) ] % 2 + val opCode = + (keyId / math.pow(numTypeMapOps, numTotalOps - opId - 1).toInt) % numTypeMapOps + opCode match { + case 0 => + val value = Random.nextInt() + stateMap.put(keyId, value, time) + refMap.put(keyId, (value, time)) + case 1 => + stateMap.remove(keyId) + refMap.remove(keyId) + } + } + + // Test whether the current state map after all key updates is correct + assertMap(stateMap, refMap, time, "State map does not match reference map") + + // Test whether the previous map before copy has not changed + if (prevSetStateMap != null && prevSetRefMap != null) { + assertMap(prevSetStateMap, prevSetRefMap, time, + "Parent state map somehow got modified, does not match corresponding reference map") + } + } + + // Copy the map and remember the previous maps for future tests + prevSetStateMap = stateMap + prevSetRefMap = refMap.toMap + stateMap = stateMap.copy() + + // Assert that the copied map has the same data + assertMap(stateMap, prevSetRefMap, time, + "State map does not match reference map after copying") + } + assertMap(stateMap, refMap.toMap, time, "Final state map does not match reference map") + } + + private def testSerialization[MapType <: StateMap[Int, Int]]( + map: MapType, msg: String): MapType = { + val deserMap = Utils.deserialize[MapType]( + Utils.serialize(map), Thread.currentThread().getContextClassLoader) + assertMap(deserMap, map, 1, msg) + deserMap + } + + // Assert whether all the data and operations on a state map matches that of a reference state map + private def assertMap( + mapToTest: StateMap[Int, Int], + refMapToTestWith: StateMap[Int, Int], + time: Long, + msg: String): Unit = { + withClue(msg) { + // Assert all the data is same as the reference map + assert(mapToTest.getAll().toSet === refMapToTestWith.getAll().toSet) + + // Assert that get on every key returns the right value + for (keyId <- refMapToTestWith.getAll().map { _._1 }) { + assert(mapToTest.get(keyId) === refMapToTestWith.get(keyId)) + } + + // Assert that every time threshold returns the correct data + for (t <- 0L to (time + 1)) { + assert(mapToTest.getByTime(t).toSet === refMapToTestWith.getByTime(t).toSet) + } + } + } + + // Assert whether all the data and operations on a state map matches that of a reference map + private def assertMap( + mapToTest: StateMap[Int, Int], + refMapToTestWith: Map[Int, (Int, Long)], + time: Long, + msg: String): Unit = { + withClue(msg) { + // Assert all the data is same as the reference map + assert(mapToTest.getAll().toSet === + refMapToTestWith.iterator.map { x => (x._1, x._2._1, x._2._2) }.toSet) + + // Assert that get on every key returns the right value + for (keyId <- refMapToTestWith.keys) { + assert(mapToTest.get(keyId) === refMapToTestWith.get(keyId).map { _._1 }) + } + + // Assert that every time threshold returns the correct data + for (t <- 0L to (time + 1)) { + val expectedRecords = + refMapToTestWith.iterator.filter { _._2._2 < t }.map { x => (x._1, x._2._1, x._2._2) } + assert(mapToTest.getByTime(t).toSet === expectedRecords.toSet) + } + } + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index b7db280f6358..860fac29c0ee 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -30,7 +30,7 @@ import 
org.scalatest.concurrent.Timeouts import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.SpanSugar._ -import org.apache.spark.{Logging, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark._ import org.apache.spark.metrics.MetricsSystem import org.apache.spark.metrics.source.Source import org.apache.spark.storage.StorageLevel @@ -180,6 +180,38 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo assert(ssc.scheduler.isStarted === false) } + test("start should set job group and description of streaming jobs correctly") { + ssc = new StreamingContext(conf, batchDuration) + ssc.sc.setJobGroup("non-streaming", "non-streaming", true) + val sc = ssc.sc + + @volatile var jobGroupFound: String = "" + @volatile var jobDescFound: String = "" + @volatile var jobInterruptFound: String = "" + @volatile var allFound: Boolean = false + + addInputStream(ssc).foreachRDD { rdd => + jobGroupFound = sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) + jobDescFound = sc.getLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION) + jobInterruptFound = sc.getLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL) + allFound = true + } + ssc.start() + + eventually(timeout(10 seconds), interval(10 milliseconds)) { + assert(allFound === true) + } + + // Verify streaming jobs have expected thread-local properties + assert(jobGroupFound === null) + assert(jobDescFound.contains("Streaming job from")) + assert(jobInterruptFound === "false") + + // Verify current thread's thread-local properties have not changed + assert(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) === "non-streaming") + assert(sc.getLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION) === "non-streaming") + assert(sc.getLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL) === "true") + } test("start multiple times") { ssc = new StreamingContext(master, appName, batchDuration) @@ -726,16 +758,42 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo } test("queueStream doesn't support checkpointing") { - val checkpointDir = Utils.createTempDir() - ssc = new StreamingContext(master, appName, batchDuration) - val rdd = ssc.sparkContext.parallelize(1 to 10) - ssc.queueStream[Int](Queue(rdd)).print() - ssc.checkpoint(checkpointDir.getAbsolutePath) - val e = intercept[NotSerializableException] { - ssc.start() + val checkpointDirectory = Utils.createTempDir().getAbsolutePath() + def creatingFunction(): StreamingContext = { + val _ssc = new StreamingContext(conf, batchDuration) + val rdd = _ssc.sparkContext.parallelize(1 to 10) + _ssc.checkpoint(checkpointDirectory) + _ssc.queueStream[Int](Queue(rdd)).register() + _ssc + } + ssc = StreamingContext.getOrCreate(checkpointDirectory, creatingFunction _) + ssc.start() + eventually(timeout(10000 millis)) { + assert(Checkpoint.getCheckpointFiles(checkpointDirectory).size > 1) + } + ssc.stop() + val e = intercept[SparkException] { + ssc = StreamingContext.getOrCreate(checkpointDirectory, creatingFunction _) } // StreamingContext.validate changes the message, so use "contains" here - assert(e.getMessage.contains("queueStream doesn't support checkpointing")) + assert(e.getCause.getMessage.contains("queueStream doesn't support checkpointing. 
" + + "Please don't use queueStream when checkpointing is enabled.")) + } + + test("Creating an InputDStream but not using it should not crash") { + ssc = new StreamingContext(master, appName, batchDuration) + val input1 = addInputStream(ssc) + val input2 = addInputStream(ssc) + val output = new TestOutputStream(input2) + output.register() + val batchCount = new BatchCounter(ssc) + ssc.start() + // Just wait for completing 2 batches to make sure it triggers + // `DStream.getMaxInputStreamRememberDuration` + batchCount.waitUntilBatchesCompleted(2, 10000) + // Throw the exception if crash + ssc.awaitTerminationOrTimeout(1) + ssc.stop() } def addInputStream(s: StreamingContext): DStream[Int] = { @@ -789,7 +847,8 @@ class TestReceiver extends Receiver[Int](StorageLevel.MEMORY_ONLY) with Logging } def onStop() { - // no clean to be done, the receiving thread should stop on it own + // no clean to be done, the receiving thread should stop on it own, so just wait for it. + receivingThreadOption.foreach(_.join()) } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index d840c349bbbc..04cd5bdc26be 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.streaming -import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} +import scala.collection.mutable.{ArrayBuffer, HashMap, SynchronizedBuffer, SynchronizedMap} import scala.concurrent.Future import scala.concurrent.ExecutionContext.Implicits.global +import org.apache.spark.SparkException import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.receiver.Receiver @@ -140,6 +141,113 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { } } + test("output operation reporting") { + ssc = new StreamingContext("local[2]", "test", Milliseconds(1000)) + val inputStream = ssc.receiverStream(new StreamingListenerSuiteReceiver) + inputStream.foreachRDD(_.count()) + inputStream.foreachRDD(_.collect()) + inputStream.foreachRDD(_.count()) + + val collector = new OutputOperationInfoCollector + ssc.addStreamingListener(collector) + + ssc.start() + try { + eventually(timeout(30 seconds), interval(20 millis)) { + collector.startedOutputOperationIds.take(3) should be (Seq(0, 1, 2)) + collector.completedOutputOperationIds.take(3) should be (Seq(0, 1, 2)) + } + } finally { + ssc.stop() + } + } + + test("don't call ssc.stop in listener") { + ssc = new StreamingContext("local[2]", "ssc", Milliseconds(1000)) + val inputStream = ssc.receiverStream(new StreamingListenerSuiteReceiver) + inputStream.foreachRDD(_.count) + + startStreamingContextAndCallStop(ssc) + } + + test("onBatchCompleted with successful batch") { + ssc = new StreamingContext("local[2]", "test", Milliseconds(1000)) + val inputStream = ssc.receiverStream(new StreamingListenerSuiteReceiver) + inputStream.foreachRDD(_.count) + + val failureReasons = startStreamingContextAndCollectFailureReasons(ssc) + assert(failureReasons != null && failureReasons.isEmpty, + "A successful batch should not set errorMessage") + } + + test("onBatchCompleted with failed batch and one failed job") { + ssc = new StreamingContext("local[2]", "test", Milliseconds(1000)) + val inputStream = ssc.receiverStream(new StreamingListenerSuiteReceiver) + 
inputStream.foreachRDD { _ => + throw new RuntimeException("This is a failed job") + } + + // Check if failureReasons contains the correct error message + val failureReasons = startStreamingContextAndCollectFailureReasons(ssc, isFailed = true) + assert(failureReasons != null) + assert(failureReasons.size === 1) + assert(failureReasons.contains(0)) + assert(failureReasons(0).contains("This is a failed job")) + } + + test("onBatchCompleted with failed batch and multiple failed jobs") { + ssc = new StreamingContext("local[2]", "test", Milliseconds(1000)) + val inputStream = ssc.receiverStream(new StreamingListenerSuiteReceiver) + inputStream.foreachRDD { _ => + throw new RuntimeException("This is a failed job") + } + inputStream.foreachRDD { _ => + throw new RuntimeException("This is another failed job") + } + + // Check if failureReasons contains the correct error messages + val failureReasons = + startStreamingContextAndCollectFailureReasons(ssc, isFailed = true) + assert(failureReasons != null) + assert(failureReasons.size === 2) + assert(failureReasons.contains(0)) + assert(failureReasons.contains(1)) + assert(failureReasons(0).contains("This is a failed job")) + assert(failureReasons(1).contains("This is another failed job")) + } + + private def startStreamingContextAndCallStop(_ssc: StreamingContext): Unit = { + val contextStoppingCollector = new StreamingContextStoppingCollector(_ssc) + _ssc.addStreamingListener(contextStoppingCollector) + val batchCounter = new BatchCounter(_ssc) + _ssc.start() + // Make sure running at least one batch + if (!batchCounter.waitUntilBatchesCompleted(expectedNumCompletedBatches = 1, timeout = 10000)) { + fail("The first batch cannot complete in 10 seconds") + } + // When reaching here, we can make sure `StreamingContextStoppingCollector` won't call + // `ssc.stop()`, so it's safe to call `_ssc.stop()` now. 
+ _ssc.stop() + assert(contextStoppingCollector.sparkExSeen) + } + + private def startStreamingContextAndCollectFailureReasons( + _ssc: StreamingContext, isFailed: Boolean = false): Map[Int, String] = { + val failureReasonsCollector = new FailureReasonsCollector() + _ssc.addStreamingListener(failureReasonsCollector) + val batchCounter = new BatchCounter(_ssc) + _ssc.start() + // Make sure running at least one batch + batchCounter.waitUntilBatchesCompleted(expectedNumCompletedBatches = 1, timeout = 10000) + if (isFailed) { + intercept[RuntimeException] { + _ssc.awaitTerminationOrTimeout(10000) + } + } + _ssc.stop() + failureReasonsCollector.failureReasons.toMap + } + /** Check if a sequence of numbers is in increasing order */ def isInIncreasingOrder(seq: Seq[Long]): Boolean = { for (i <- 1 until seq.size) { @@ -191,6 +299,22 @@ class ReceiverInfoCollector extends StreamingListener { } } +/** Listener that collects information on processed output operations */ +class OutputOperationInfoCollector extends StreamingListener { + val startedOutputOperationIds = new ArrayBuffer[Int] with SynchronizedBuffer[Int] + val completedOutputOperationIds = new ArrayBuffer[Int] with SynchronizedBuffer[Int] + + override def onOutputOperationStarted( + outputOperationStarted: StreamingListenerOutputOperationStarted): Unit = { + startedOutputOperationIds += outputOperationStarted.outputOperationInfo.id + } + + override def onOutputOperationCompleted( + outputOperationCompleted: StreamingListenerOutputOperationCompleted): Unit = { + completedOutputOperationIds += outputOperationCompleted.outputOperationInfo.id + } +} + class StreamingListenerSuiteReceiver extends Receiver[Any](StorageLevel.MEMORY_ONLY) with Logging { def onStart() { Future { @@ -205,3 +329,41 @@ class StreamingListenerSuiteReceiver extends Receiver[Any](StorageLevel.MEMORY_O } def onStop() { } } + +/** + * A StreamingListener that saves all latest `failureReasons` in a batch. + */ +class FailureReasonsCollector extends StreamingListener { + + val failureReasons = new HashMap[Int, String] with SynchronizedMap[Int, String] + + override def onOutputOperationCompleted( + outputOperationCompleted: StreamingListenerOutputOperationCompleted): Unit = { + outputOperationCompleted.outputOperationInfo.failureReason.foreach { f => + failureReasons(outputOperationCompleted.outputOperationInfo.id) = f + } + } +} +/** + * A StreamingListener that calls StreamingContext.stop(). + */ +class StreamingContextStoppingCollector(val ssc: StreamingContext) extends StreamingListener { + @volatile var sparkExSeen = false + + private var isFirstBatch = true + + override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) { + if (isFirstBatch) { + // We should only call `ssc.stop()` in the first batch. Otherwise, it's possible that the main + // thread is calling `ssc.stop()`, while StreamingContextStoppingCollector is also calling + // `ssc.stop()` in the listener thread, which becomes a dead-lock. 
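// --------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): per the comment above, stop()
// must not run on the listener thread itself. If application code wants to shut
// the context down from a listener callback, a common workaround is to hand the
// call to a separate thread so the callback returns immediately. Hedged sketch
// only; `shouldStop` is a hypothetical application-level predicate.
// (imports: org.apache.spark.streaming.StreamingContext, org.apache.spark.streaming.scheduler._)
class AsyncStoppingListener(ssc: StreamingContext) extends StreamingListener {
  override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = {
    if (shouldStop(batchCompleted.batchInfo)) {
      new Thread("stop-streaming-context") {
        // Stopping from a fresh thread avoids calling stop() on the listener
        // thread, which is what triggers the SparkException / deadlock above.
        override def run(): Unit = ssc.stop(stopSparkContext = true, stopGracefully = false)
      }.start()
    }
  }
  private def shouldStop(info: BatchInfo): Boolean = info.numRecords == 0  // hypothetical condition
}
// --------------------------------------------------------------------------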
+ isFirstBatch = false + try { + ssc.stop() + } catch { + case se: SparkException => + sparkExSeen = true + } + } + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 0d58a7b54412..be0f4636a6cb 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -98,7 +98,7 @@ class TestOutputStream[T: ClassTag]( ) extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => { val collected = rdd.collect() output += collected - }) { + }, false) { // This is to clear the output buffer every it is read from a checkpoint @throws(classOf[IOException]) @@ -122,7 +122,7 @@ class TestOutputStreamWithPartitions[T: ClassTag]( extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => { val collected = rdd.glom().collect().map(_.toSeq) output += collected - }) { + }, false) { // This is to clear the output buffer every it is read from a checkpoint @throws(classOf[IOException]) @@ -142,6 +142,7 @@ class BatchCounter(ssc: StreamingContext) { // All access to this state should be guarded by `BatchCounter.this.synchronized` private var numCompletedBatches = 0 private var numStartedBatches = 0 + private var lastCompletedBatchTime: Time = null private val listener = new StreamingListener { override def onBatchStarted(batchStarted: StreamingListenerBatchStarted): Unit = @@ -152,6 +153,7 @@ class BatchCounter(ssc: StreamingContext) { override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = BatchCounter.this.synchronized { numCompletedBatches += 1 + lastCompletedBatchTime = batchCompleted.batchInfo.batchTime BatchCounter.this.notifyAll() } } @@ -165,6 +167,10 @@ class BatchCounter(ssc: StreamingContext) { numStartedBatches } + def getLastCompletedBatchTime: Time = this.synchronized { + lastCompletedBatchTime + } + /** * Wait until `expectedNumCompletedBatches` batches are completed, or timeout. Return true if * `expectedNumCompletedBatches` batches are completed. 
Otherwise, return false to indicate it's diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala index a08578680cff..a5744a9009c1 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala @@ -100,8 +100,8 @@ class UISeleniumSuite // Check stat table val statTableHeaders = findAll(cssSelector("#stat-table th")).map(_.text).toSeq statTableHeaders.exists( - _.matches("Timelines \\(Last \\d+ batches, \\d+ active, \\d+ completed\\)")) should be - (true) + _.matches("Timelines \\(Last \\d+ batches, \\d+ active, \\d+ completed\\)") + ) should be (true) statTableHeaders should contain ("Histograms") val statTableCells = findAll(cssSelector("#stat-table td")).map(_.text).toSeq @@ -117,11 +117,11 @@ class UISeleniumSuite findAll(cssSelector("""#active-batches-table th""")).map(_.text).toSeq should be { List("Batch Time", "Input Size", "Scheduling Delay (?)", "Processing Time (?)", - "Status") + "Output Ops: Succeeded/Total", "Status") } findAll(cssSelector("""#completed-batches-table th""")).map(_.text).toSeq should be { List("Batch Time", "Input Size", "Scheduling Delay (?)", "Processing Time (?)", - "Total Delay (?)") + "Total Delay (?)", "Output Ops: Succeeded/Total") } val batchLinks = @@ -138,7 +138,7 @@ class UISeleniumSuite summaryText should contain ("Total delay:") findAll(cssSelector("""#batch-job-table th""")).map(_.text).toSeq should be { - List("Output Op Id", "Description", "Duration", "Job Id", "Duration", + List("Output Op Id", "Description", "Duration", "Status", "Job Id", "Duration", "Stages: Succeeded/Total", "Tasks (for all stages): Succeeded/Total", "Error") } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala new file mode 100644 index 000000000000..aa95bd33dda9 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala @@ -0,0 +1,389 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.rdd + +import java.io.File + +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark._ +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.util.OpenHashMapBasedStateMap +import org.apache.spark.streaming.{State, Time} +import org.apache.spark.util.Utils + +class MapWithStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with BeforeAndAfterAll { + + private var sc: SparkContext = null + private var checkpointDir: File = _ + + override def beforeAll(): Unit = { + sc = new SparkContext( + new SparkConf().setMaster("local").setAppName("MapWithStateRDDSuite")) + checkpointDir = Utils.createTempDir() + sc.setCheckpointDir(checkpointDir.toString) + } + + override def afterAll(): Unit = { + if (sc != null) { + sc.stop() + } + Utils.deleteRecursively(checkpointDir) + } + + override def sparkContext: SparkContext = sc + + test("creation from pair RDD") { + val data = Seq((1, "1"), (2, "2"), (3, "3")) + val partitioner = new HashPartitioner(10) + val rdd = MapWithStateRDD.createFromPairRDD[Int, Int, String, Int]( + sc.parallelize(data), partitioner, Time(123)) + assertRDD[Int, Int, String, Int](rdd, data.map { x => (x._1, x._2, 123)}.toSet, Set.empty) + assert(rdd.partitions.size === partitioner.numPartitions) + + assert(rdd.partitioner === Some(partitioner)) + } + + test("updating state and generating mapped data in MapWithStateRDDRecord") { + + val initialTime = 1000L + val updatedTime = 2000L + val thresholdTime = 1500L + @volatile var functionCalled = false + + /** + * Assert that applying given data on a prior record generates correct updated record, with + * correct state map and mapped data + */ + def assertRecordUpdate( + initStates: Iterable[Int], + data: Iterable[String], + expectedStates: Iterable[(Int, Long)], + timeoutThreshold: Option[Long] = None, + removeTimedoutData: Boolean = false, + expectedOutput: Iterable[Int] = None, + expectedTimingOutStates: Iterable[Int] = None, + expectedRemovedStates: Iterable[Int] = None + ): Unit = { + val initialStateMap = new OpenHashMapBasedStateMap[String, Int]() + initStates.foreach { s => initialStateMap.put("key", s, initialTime) } + functionCalled = false + val record = MapWithStateRDDRecord[String, Int, Int](initialStateMap, Seq.empty) + val dataIterator = data.map { v => ("key", v) }.iterator + val removedStates = new ArrayBuffer[Int] + val timingOutStates = new ArrayBuffer[Int] + /** + * Mapping function that updates/removes state based on instructions in the data, and + * return state (when instructed or when state is timing out). 
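// --------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): testFunc below drives the same
// State[S] API that user code passes to mapWithState. For orientation, a
// user-facing mapping function and its wiring look roughly like this
// (StateSpec/mapWithState from the 1.6 API; `pairs` is a hypothetical
// DStream[(String, Int)]):
//
//   import org.apache.spark.streaming.{Minutes, State, StateSpec, Time}
//
//   def countingFunc(
//       time: Time, key: String, value: Option[Int], state: State[Int]): Option[(String, Int)] = {
//     if (state.isTimingOut()) {
//       None                                      // key idle past the timeout; its state is dropped
//     } else {
//       val newCount = state.getOption().getOrElse(0) + value.getOrElse(0)
//       state.update(newCount)                    // carried over to the next batch
//       Some(key -> newCount)                     // becomes the stream's mapped output
//     }
//   }
//
//   val counts = pairs.mapWithState(StateSpec.function(countingFunc _).timeout(Minutes(10)))
// --------------------------------------------------------------------------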
+ */ + def testFunc(t: Time, key: String, data: Option[String], state: State[Int]): Option[Int] = { + functionCalled = true + + assert(t.milliseconds === updatedTime, "mapping func called with wrong time") + + data match { + case Some("noop") => + None + case Some("get-state") => + Some(state.getOption().getOrElse(-1)) + case Some("update-state") => + if (state.exists) state.update(state.get + 1) else state.update(0) + None + case Some("remove-state") => + removedStates += state.get() + state.remove() + None + case None => + assert(state.isTimingOut() === true, "State is not timing out when data = None") + timingOutStates += state.get() + None + case _ => + fail("Unexpected test data") + } + } + + val updatedRecord = MapWithStateRDDRecord.updateRecordWithData[String, String, Int, Int]( + Some(record), dataIterator, testFunc, + Time(updatedTime), timeoutThreshold, removeTimedoutData) + + val updatedStateData = updatedRecord.stateMap.getAll().map { x => (x._2, x._3) } + assert(updatedStateData.toSet === expectedStates.toSet, + "states do not match after updating the MapWithStateRDDRecord") + + assert(updatedRecord.mappedData.toSet === expectedOutput.toSet, + "mapped data do not match after updating the MapWithStateRDDRecord") + + assert(timingOutStates.toSet === expectedTimingOutStates.toSet, "timing out states do not " + + "match those that were expected to do so while updating the MapWithStateRDDRecord") + + assert(removedStates.toSet === expectedRemovedStates.toSet, "removed states do not " + + "match those that were expected to do so while updating the MapWithStateRDDRecord") + + } + + // No data, no state should be changed, function should not be called, + assertRecordUpdate(initStates = Nil, data = None, expectedStates = Nil) + assert(functionCalled === false) + assertRecordUpdate(initStates = Seq(0), data = None, expectedStates = Seq((0, initialTime))) + assert(functionCalled === false) + + // Data present, function should be called irrespective of whether state exists + assertRecordUpdate(initStates = Seq(0), data = Seq("noop"), + expectedStates = Seq((0, initialTime))) + assert(functionCalled === true) + assertRecordUpdate(initStates = None, data = Some("noop"), expectedStates = None) + assert(functionCalled === true) + + // Function called with right state data + assertRecordUpdate(initStates = None, data = Seq("get-state"), + expectedStates = None, expectedOutput = Seq(-1)) + assertRecordUpdate(initStates = Seq(123), data = Seq("get-state"), + expectedStates = Seq((123, initialTime)), expectedOutput = Seq(123)) + + // Update state and timestamp, when timeout not present + assertRecordUpdate(initStates = Nil, data = Seq("update-state"), + expectedStates = Seq((0, updatedTime))) + assertRecordUpdate(initStates = Seq(0), data = Seq("update-state"), + expectedStates = Seq((1, updatedTime))) + + // Remove state + assertRecordUpdate(initStates = Seq(345), data = Seq("remove-state"), + expectedStates = Nil, expectedRemovedStates = Seq(345)) + + // State strictly older than timeout threshold should be timed out + assertRecordUpdate(initStates = Seq(123), data = Nil, + timeoutThreshold = Some(initialTime), removeTimedoutData = true, + expectedStates = Seq((123, initialTime)), expectedTimingOutStates = Nil) + + assertRecordUpdate(initStates = Seq(123), data = Nil, + timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true, + expectedStates = Nil, expectedTimingOutStates = Seq(123)) + + // State should not be timed out after it has received data + assertRecordUpdate(initStates = 
Seq(123), data = Seq("noop"), + timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true, + expectedStates = Seq((123, updatedTime)), expectedTimingOutStates = Nil) + assertRecordUpdate(initStates = Seq(123), data = Seq("remove-state"), + timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true, + expectedStates = Nil, expectedTimingOutStates = Nil, expectedRemovedStates = Seq(123)) + + } + + test("states generated by MapWithStateRDD") { + val initStates = Seq(("k1", 0), ("k2", 0)) + val initTime = 123 + val initStateWthTime = initStates.map { x => (x._1, x._2, initTime) }.toSet + val partitioner = new HashPartitioner(2) + val initStateRDD = MapWithStateRDD.createFromPairRDD[String, Int, Int, Int]( + sc.parallelize(initStates), partitioner, Time(initTime)).persist() + assertRDD(initStateRDD, initStateWthTime, Set.empty) + + val updateTime = 345 + + /** + * Test that the test state RDD, when operated with new data, + * creates a new state RDD with expected states + */ + def testStateUpdates( + testStateRDD: MapWithStateRDD[String, Int, Int, Int], + testData: Seq[(String, Int)], + expectedStates: Set[(String, Int, Int)]): MapWithStateRDD[String, Int, Int, Int] = { + + // Persist the test MapWithStateRDD so that its not recomputed while doing the next operation. + // This is to make sure that we only touch which state keys are being touched in the next op. + testStateRDD.persist().count() + + // To track which keys are being touched + MapWithStateRDDSuite.touchedStateKeys.clear() + + val mappingFunction = (time: Time, key: String, data: Option[Int], state: State[Int]) => { + + // Track the key that has been touched + MapWithStateRDDSuite.touchedStateKeys += key + + // If the data is 0, do not do anything with the state + // else if the data is 1, increment the state if it exists, or set new state to 0 + // else if the data is 2, remove the state if it exists + data match { + case Some(1) => + if (state.exists()) { state.update(state.get + 1) } + else state.update(0) + case Some(2) => + state.remove() + case _ => + } + None.asInstanceOf[Option[Int]] // Do not return anything, not being tested + } + val newDataRDD = sc.makeRDD(testData).partitionBy(testStateRDD.partitioner.get) + + // Assert that the new state RDD has expected state data + val newStateRDD = assertOperation( + testStateRDD, newDataRDD, mappingFunction, updateTime, expectedStates, Set.empty) + + // Assert that the function was called only for the keys present in the data + assert(MapWithStateRDDSuite.touchedStateKeys.size === testData.size, + "More number of keys are being touched than that is expected") + assert(MapWithStateRDDSuite.touchedStateKeys.toSet === testData.toMap.keys, + "Keys not in the data are being touched unexpectedly") + + // Assert that the test RDD's data has not changed + assertRDD(initStateRDD, initStateWthTime, Set.empty) + newStateRDD + } + + // Test no-op, no state should change + testStateUpdates(initStateRDD, Seq(), initStateWthTime) // should not scan any state + testStateUpdates( + initStateRDD, Seq(("k1", 0)), initStateWthTime) // should not update existing state + testStateUpdates( + initStateRDD, Seq(("k3", 0)), initStateWthTime) // should not create new state + + // Test creation of new state + val rdd1 = testStateUpdates(initStateRDD, Seq(("k3", 1)), // should create k3's state as 0 + Set(("k1", 0, initTime), ("k2", 0, initTime), ("k3", 0, updateTime))) + + val rdd2 = testStateUpdates(rdd1, Seq(("k4", 1)), // should create k4's state as 0 + Set(("k1", 0, initTime), ("k2", 0, 
initTime), ("k3", 0, updateTime), ("k4", 0, updateTime))) + + // Test updating of state + val rdd3 = testStateUpdates( + initStateRDD, Seq(("k1", 1)), // should increment k1's state 0 -> 1 + Set(("k1", 1, updateTime), ("k2", 0, initTime))) + + val rdd4 = testStateUpdates(rdd3, + Seq(("x", 0), ("k2", 1), ("k2", 1), ("k3", 1)), // should update k2, 0 -> 2 and create k3, 0 + Set(("k1", 1, updateTime), ("k2", 2, updateTime), ("k3", 0, updateTime))) + + val rdd5 = testStateUpdates( + rdd4, Seq(("k3", 1)), // should update k3's state 0 -> 2 + Set(("k1", 1, updateTime), ("k2", 2, updateTime), ("k3", 1, updateTime))) + + // Test removing of state + val rdd6 = testStateUpdates( // should remove k1's state + initStateRDD, Seq(("k1", 2)), Set(("k2", 0, initTime))) + + val rdd7 = testStateUpdates( // should remove k2's state + rdd6, Seq(("k2", 2), ("k0", 2), ("k3", 1)), Set(("k3", 0, updateTime))) + + val rdd8 = testStateUpdates( // should remove k3's state + rdd7, Seq(("k3", 2)), Set()) + } + + test("checkpointing") { + /** + * This tests whether the MapWithStateRDD correctly truncates any references to its parent RDDs + * - the data RDD and the parent MapWithStateRDD. + */ + def rddCollectFunc(rdd: RDD[MapWithStateRDDRecord[Int, Int, Int]]) + : Set[(List[(Int, Int, Long)], List[Int])] = { + rdd.map { record => (record.stateMap.getAll().toList, record.mappedData.toList) } + .collect.toSet + } + + /** Generate MapWithStateRDD with data RDD having a long lineage */ + def makeStateRDDWithLongLineageDataRDD(longLineageRDD: RDD[Int]) + : MapWithStateRDD[Int, Int, Int, Int] = { + MapWithStateRDD.createFromPairRDD(longLineageRDD.map { _ -> 1}, partitioner, Time(0)) + } + + testRDD( + makeStateRDDWithLongLineageDataRDD, reliableCheckpoint = true, rddCollectFunc _) + testRDDPartitions( + makeStateRDDWithLongLineageDataRDD, reliableCheckpoint = true, rddCollectFunc _) + + /** Generate MapWithStateRDD with parent state RDD having a long lineage */ + def makeStateRDDWithLongLineageParenttateRDD( + longLineageRDD: RDD[Int]): MapWithStateRDD[Int, Int, Int, Int] = { + + // Create a MapWithStateRDD that has a long lineage using the data RDD with a long lineage + val stateRDDWithLongLineage = makeStateRDDWithLongLineageDataRDD(longLineageRDD) + + // Create a new MapWithStateRDD, with the lineage lineage MapWithStateRDD as the parent + new MapWithStateRDD[Int, Int, Int, Int]( + stateRDDWithLongLineage, + stateRDDWithLongLineage.sparkContext.emptyRDD[(Int, Int)].partitionBy(partitioner), + (time: Time, key: Int, value: Option[Int], state: State[Int]) => None, + Time(10), + None + ) + } + + testRDD( + makeStateRDDWithLongLineageParenttateRDD, reliableCheckpoint = true, rddCollectFunc _) + testRDDPartitions( + makeStateRDDWithLongLineageParenttateRDD, reliableCheckpoint = true, rddCollectFunc _) + } + + test("checkpointing empty state RDD") { + val emptyStateRDD = MapWithStateRDD.createFromPairRDD[Int, Int, Int, Int]( + sc.emptyRDD[(Int, Int)], new HashPartitioner(10), Time(0)) + emptyStateRDD.checkpoint() + assert(emptyStateRDD.flatMap { _.stateMap.getAll() }.collect().isEmpty) + val cpRDD = sc.checkpointFile[MapWithStateRDDRecord[Int, Int, Int]]( + emptyStateRDD.getCheckpointFile.get) + assert(cpRDD.flatMap { _.stateMap.getAll() }.collect().isEmpty) + } + + /** Assert whether the `mapWithState` operation generates expected results */ + private def assertOperation[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( + testStateRDD: MapWithStateRDD[K, V, S, T], + newDataRDD: RDD[(K, V)], + mappingFunction: (Time, K, 
Option[V], State[S]) => Option[T], + currentTime: Long, + expectedStates: Set[(K, S, Int)], + expectedMappedData: Set[T], + doFullScan: Boolean = false + ): MapWithStateRDD[K, V, S, T] = { + + val partitionedNewDataRDD = if (newDataRDD.partitioner != testStateRDD.partitioner) { + newDataRDD.partitionBy(testStateRDD.partitioner.get) + } else { + newDataRDD + } + + val newStateRDD = new MapWithStateRDD[K, V, S, T]( + testStateRDD, newDataRDD, mappingFunction, Time(currentTime), None) + if (doFullScan) newStateRDD.setFullScan() + + // Persist to make sure that it gets computed only once and we can track precisely how many + // state keys the computing touched + newStateRDD.persist().count() + assertRDD(newStateRDD, expectedStates, expectedMappedData) + newStateRDD + } + + /** Assert whether the [[MapWithStateRDD]] has the expected state and mapped data */ + private def assertRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( + stateRDD: MapWithStateRDD[K, V, S, T], + expectedStates: Set[(K, S, Int)], + expectedMappedData: Set[T]): Unit = { + val states = stateRDD.flatMap { _.stateMap.getAll() }.collect().toSet + val mappedData = stateRDD.flatMap { _.mappedData }.collect().toSet + assert(states === expectedStates, + "states after mapWithState operation were not as expected") + assert(mappedData === expectedMappedData, + "mapped data after mapWithState operation were not as expected") + } +} + +object MapWithStateRDDSuite { + private val touchedStateKeys = new ArrayBuffer[String]() +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala new file mode 100644 index 000000000000..92ad9fe52b77 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala @@ -0,0 +1,255 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.receiver + +import scala.collection.mutable +import scala.language.reflectiveCalls + +import org.scalatest.BeforeAndAfter +import org.scalatest.Matchers._ +import org.scalatest.concurrent.Timeouts._ +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.storage.StreamBlockId +import org.apache.spark.util.ManualClock +import org.apache.spark.{SparkException, SparkConf, SparkFunSuite} + +class BlockGeneratorSuite extends SparkFunSuite with BeforeAndAfter { + + private val blockIntervalMs = 10 + private val conf = new SparkConf().set("spark.streaming.blockInterval", s"${blockIntervalMs}ms") + @volatile private var blockGenerator: BlockGenerator = null + + after { + if (blockGenerator != null) { + blockGenerator.stop() + } + } + + test("block generation and data callbacks") { + val listener = new TestBlockGeneratorListener + val clock = new ManualClock() + + require(blockIntervalMs > 5) + require(listener.onAddDataCalled === false) + require(listener.onGenerateBlockCalled === false) + require(listener.onPushBlockCalled === false) + + // Verify that creating the generator does not start it + blockGenerator = new BlockGenerator(listener, 0, conf, clock) + assert(blockGenerator.isActive() === false, "block generator active before start()") + assert(blockGenerator.isStopped() === false, "block generator stopped before start()") + assert(listener.onAddDataCalled === false) + assert(listener.onGenerateBlockCalled === false) + assert(listener.onPushBlockCalled === false) + + // Verify start marks the generator active, but does not call the callbacks + blockGenerator.start() + assert(blockGenerator.isActive() === true, "block generator active after start()") + assert(blockGenerator.isStopped() === false, "block generator stopped after start()") + withClue("callbacks called before adding data") { + assert(listener.onAddDataCalled === false) + assert(listener.onGenerateBlockCalled === false) + assert(listener.onPushBlockCalled === false) + } + + // Verify whether addData() adds data that is present in generated blocks + val data1 = 1 to 10 + data1.foreach { blockGenerator.addData _ } + withClue("callbacks called on adding data without metadata and without block generation") { + assert(listener.onAddDataCalled === false) // should be called only with addDataWithCallback() + assert(listener.onGenerateBlockCalled === false) + assert(listener.onPushBlockCalled === false) + } + clock.advance(blockIntervalMs) // advance clock to generate blocks + withClue("blocks not generated or pushed") { + eventually(timeout(1 second)) { + assert(listener.onGenerateBlockCalled === true) + assert(listener.onPushBlockCalled === true) + } + } + listener.pushedData should contain theSameElementsInOrderAs (data1) + assert(listener.onAddDataCalled === false) // should be called only with addDataWithCallback() + + // Verify addDataWithCallback() add data+metadata and and callbacks are called correctly + val data2 = 11 to 20 + val metadata2 = data2.map { _.toString } + data2.zip(metadata2).foreach { case (d, m) => blockGenerator.addDataWithCallback(d, m) } + assert(listener.onAddDataCalled === true) + listener.addedData should contain theSameElementsInOrderAs (data2) + listener.addedMetadata should contain theSameElementsInOrderAs (metadata2) + clock.advance(blockIntervalMs) // advance clock to generate blocks + eventually(timeout(1 second)) { + listener.pushedData should contain theSameElementsInOrderAs (data1 ++ data2) + } + + // 
Verify addMultipleDataWithCallback() add data+metadata and and callbacks are called correctly + val data3 = 21 to 30 + val metadata3 = "metadata" + blockGenerator.addMultipleDataWithCallback(data3.iterator, metadata3) + listener.addedMetadata should contain theSameElementsInOrderAs (metadata2 :+ metadata3) + clock.advance(blockIntervalMs) // advance clock to generate blocks + eventually(timeout(1 second)) { + listener.pushedData should contain theSameElementsInOrderAs (data1 ++ data2 ++ data3) + } + + // Stop the block generator by starting the stop on a different thread and + // then advancing the manual clock for the stopping to proceed. + val thread = stopBlockGenerator(blockGenerator) + eventually(timeout(1 second), interval(10 milliseconds)) { + clock.advance(blockIntervalMs) + assert(blockGenerator.isStopped() === true) + } + thread.join() + + // Verify that the generator cannot be used any more + intercept[SparkException] { + blockGenerator.addData(1) + } + intercept[SparkException] { + blockGenerator.addDataWithCallback(1, 1) + } + intercept[SparkException] { + blockGenerator.addMultipleDataWithCallback(Iterator(1), 1) + } + intercept[SparkException] { + blockGenerator.start() + } + blockGenerator.stop() // Calling stop again should be fine + } + + test("stop ensures correct shutdown") { + val listener = new TestBlockGeneratorListener + val clock = new ManualClock() + blockGenerator = new BlockGenerator(listener, 0, conf, clock) + require(listener.onGenerateBlockCalled === false) + blockGenerator.start() + assert(blockGenerator.isActive() === true, "block generator") + assert(blockGenerator.isStopped() === false) + + val data = 1 to 1000 + data.foreach { blockGenerator.addData _ } + + // Verify that stop() shutdowns everything in the right order + // - First, stop receiving new data + // - Second, wait for final block with all buffered data to be generated + // - Finally, wait for all blocks to be pushed + clock.advance(1) // to make sure that the timer for another interval to complete + val thread = stopBlockGenerator(blockGenerator) + eventually(timeout(1 second), interval(10 milliseconds)) { + assert(blockGenerator.isActive() === false) + } + assert(blockGenerator.isStopped() === false) + + // Verify that data cannot be added + intercept[SparkException] { + blockGenerator.addData(1) + } + intercept[SparkException] { + blockGenerator.addDataWithCallback(1, null) + } + intercept[SparkException] { + blockGenerator.addMultipleDataWithCallback(Iterator(1), null) + } + + // Verify that stop() stays blocked until another block containing all the data is generated + // This intercept always succeeds, as the body either will either throw a timeout exception + // (expected as stop() should never complete) or a SparkException (unexpected as stop() + // completed and thread terminated). 
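// --------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): the comment above describes the
// standard ScalaTest idiom for asserting that a call is still blocked. In its
// plain form (same Timeouts/SpanSugar imports this suite already uses;
// `blockingCall` is hypothetical):
//
//   intercept[TestFailedDueToTimeoutException] {
//     failAfter(200 milliseconds) { blockingCall() }   // times out => the call was still blocked
//   }
//
// The test below cannot intercept that exact type because its body may instead
// surface a SparkException, so it intercepts Exception and checks the type afterwards.
// --------------------------------------------------------------------------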
+ val exception = intercept[Exception] { + failAfter(200 milliseconds) { + thread.join() + throw new SparkException( + "BlockGenerator.stop() completed before generating timer was stopped") + } + } + exception should not be a [SparkException] + + + // Verify that the final data is present in the final generated block and + // pushed before complete stop + assert(blockGenerator.isStopped() === false) // generator has not stopped yet + eventually(timeout(10 seconds), interval(10 milliseconds)) { + // Keep calling `advance` to avoid blocking forever in `clock.waitTillTime` + clock.advance(blockIntervalMs) + assert(thread.isAlive === false) + } + assert(blockGenerator.isStopped() === true) // generator has finally been completely stopped + assert(listener.pushedData === data, "All data not pushed by stop()") + } + + test("block push errors are reported") { + val listener = new TestBlockGeneratorListener { + @volatile var errorReported = false + override def onPushBlock( + blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = { + throw new SparkException("test") + } + override def onError(message: String, throwable: Throwable): Unit = { + errorReported = true + } + } + blockGenerator = new BlockGenerator(listener, 0, conf) + blockGenerator.start() + assert(listener.errorReported === false) + blockGenerator.addData(1) + eventually(timeout(1 second), interval(10 milliseconds)) { + assert(listener.errorReported === true) + } + blockGenerator.stop() + } + + /** + * Helper method to stop the block generator with manual clock in a different thread, + * so that the main thread can advance the clock that allows the stopping to proceed. + */ + private def stopBlockGenerator(blockGenerator: BlockGenerator): Thread = { + val thread = new Thread() { + override def run(): Unit = { + blockGenerator.stop() + } + } + thread.start() + thread + } + + /** A listener for BlockGenerator that records the data in the callbacks */ + private class TestBlockGeneratorListener extends BlockGeneratorListener { + val pushedData = new mutable.ArrayBuffer[Any] with mutable.SynchronizedBuffer[Any] + val addedData = new mutable.ArrayBuffer[Any] with mutable.SynchronizedBuffer[Any] + val addedMetadata = new mutable.ArrayBuffer[Any] with mutable.SynchronizedBuffer[Any] + @volatile var onGenerateBlockCalled = false + @volatile var onAddDataCalled = false + @volatile var onPushBlockCalled = false + + override def onPushBlock(blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = { + pushedData ++= arrayBuffer + onPushBlockCalled = true + } + override def onError(message: String, throwable: Throwable): Unit = {} + override def onGenerateBlock(blockId: StreamBlockId): Unit = { + onGenerateBlockCalled = true + } + override def onAddData(data: Any, metadata: Any): Unit = { + addedData += data + addedMetadata += metadata + onAddDataCalled = true + } + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala index a2dbae149f31..9b6cd4bc4e31 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala @@ -56,7 +56,8 @@ class JobGeneratorSuite extends TestSuiteBase { // 4. allow subsequent batches to be generated (to allow premature deletion of 3rd batch metadata) // 5. 
verify whether 3rd batch's block metadata still exists // - test("SPARK-6222: Do not clear received block data too soon") { + // TODO: SPARK-7420 enable this test + ignore("SPARK-6222: Do not clear received block data too soon") { import JobGeneratorSuite._ val checkpointDir = Utils.createTempDir() val testConf = conf diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala index 921da773f6c1..1eb52b7029a2 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala @@ -18,10 +18,7 @@ package org.apache.spark.streaming.scheduler import scala.collection.mutable -import scala.reflect.ClassTag -import scala.util.control.NonFatal -import org.scalatest.Matchers._ import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ @@ -32,72 +29,63 @@ class RateControllerSuite extends TestSuiteBase { override def useManualClock: Boolean = false - test("rate controller publishes updates") { + override def batchDuration: Duration = Milliseconds(50) + + test("RateController - rate controller publishes updates after batches complete") { val ssc = new StreamingContext(conf, batchDuration) withStreamingContext(ssc) { ssc => - val dstream = new RateLimitInputDStream(ssc) + val dstream = new RateTestInputDStream(ssc) dstream.register() ssc.start() eventually(timeout(10.seconds)) { - assert(dstream.publishCalls > 0) + assert(dstream.publishedRates > 0) } } } - test("publish rates reach receivers") { + test("ReceiverRateController - published rates reach receivers") { val ssc = new StreamingContext(conf, batchDuration) withStreamingContext(ssc) { ssc => - val dstream = new RateLimitInputDStream(ssc) { + val estimator = new ConstantEstimator(100) + val dstream = new RateTestInputDStream(ssc) { override val rateController = - Some(new ReceiverRateController(id, new ConstantEstimator(200.0))) + Some(new ReceiverRateController(id, estimator)) } dstream.register() - SingletonTestRateReceiver.reset() ssc.start() - eventually(timeout(10.seconds)) { - assert(dstream.getCurrentRateLimit === Some(200)) + // Wait for receiver to start + eventually(timeout(5.seconds)) { + RateTestReceiver.getActive().nonEmpty } - } - } - test("multiple publish rates reach receivers") { - val ssc = new StreamingContext(conf, batchDuration) - withStreamingContext(ssc) { ssc => - val rates = Seq(100L, 200L, 300L) - - val dstream = new RateLimitInputDStream(ssc) { - override val rateController = - Some(new ReceiverRateController(id, new ConstantEstimator(rates.map(_.toDouble): _*))) + // Update rate in the estimator and verify whether the rate was published to the receiver + def updateRateAndVerify(rate: Long): Unit = { + estimator.updateRate(rate) + eventually(timeout(5.seconds)) { + assert(RateTestReceiver.getActive().get.getDefaultBlockGeneratorRateLimit() === rate) + } } - SingletonTestRateReceiver.reset() - dstream.register() - - val observedRates = mutable.HashSet.empty[Long] - ssc.start() - eventually(timeout(20.seconds)) { - dstream.getCurrentRateLimit.foreach(observedRates += _) - // Long.MaxValue (essentially, no rate limit) is the initial rate limit for any Receiver - observedRates should contain theSameElementsAs (rates :+ Long.MaxValue) + // Verify multiple rate update + Seq(100, 200, 300).foreach { rate => + updateRateAndVerify(rate) } } } } 
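// --------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): the RateController/RateEstimator
// pieces exercised above are what back Spark Streaming's backpressure feature.
// An application opts in through configuration only; a hedged example of the
// relevant keys (master, app name, and rate values are placeholders):
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

val sparkConf = new SparkConf()
  .setMaster("local[2]")                                  // example master
  .setAppName("backpressure-example")
  .set("spark.streaming.backpressure.enabled", "true")    // let the rate estimator drive receiver rates
  .set("spark.streaming.receiver.maxRate", "10000")       // optional hard cap, records/sec per receiver
val ssc = new StreamingContext(sparkConf, Seconds(1))
// --------------------------------------------------------------------------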
-private[streaming] class ConstantEstimator(rates: Double*) extends RateEstimator { - private var idx: Int = 0 +private[streaming] class ConstantEstimator(@volatile private var rate: Long) + extends RateEstimator { - private def nextRate(): Double = { - val rate = rates(idx) - idx = (idx + 1) % rates.size - rate + def updateRate(newRate: Long): Unit = { + rate = newRate } def compute( time: Long, elements: Long, processingDelay: Long, - schedulingDelay: Long): Option[Double] = Some(nextRate()) + schedulingDelay: Long): Option[Double] = Some(rate) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala index 0418d776ecc9..05b4e66c63ac 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala @@ -20,73 +20,96 @@ package org.apache.spark.streaming.scheduler import scala.collection.mutable import org.apache.spark.SparkFunSuite +import org.apache.spark.scheduler.{ExecutorCacheTaskLocation, HostTaskLocation, TaskLocation} class ReceiverSchedulingPolicySuite extends SparkFunSuite { val receiverSchedulingPolicy = new ReceiverSchedulingPolicy test("rescheduleReceiver: empty executors") { - val scheduledExecutors = + val scheduledLocations = receiverSchedulingPolicy.rescheduleReceiver(0, None, Map.empty, executors = Seq.empty) - assert(scheduledExecutors === Seq.empty) + assert(scheduledLocations === Seq.empty) } test("rescheduleReceiver: receiver preferredLocation") { + val executors = Seq(ExecutorCacheTaskLocation("host2", "2")) val receiverTrackingInfoMap = Map( 0 -> ReceiverTrackingInfo(0, ReceiverState.INACTIVE, None, None)) - val scheduledExecutors = receiverSchedulingPolicy.rescheduleReceiver( - 0, Some("host1"), receiverTrackingInfoMap, executors = Seq("host2")) - assert(scheduledExecutors.toSet === Set("host1", "host2")) + val scheduledLocations = receiverSchedulingPolicy.rescheduleReceiver( + 0, Some("host1"), receiverTrackingInfoMap, executors) + assert(scheduledLocations.toSet === Set(HostTaskLocation("host1"), executors(0))) } - test("rescheduleReceiver: return all idle executors if more than 3 idle executors") { - val executors = Seq("host1", "host2", "host3", "host4", "host5") - // host3 is idle + test("rescheduleReceiver: return all idle executors if there are any idle executors") { + val executors = (1 to 5).map(i => ExecutorCacheTaskLocation(s"host$i", s"$i")) + // executor 1 is busy, others are idle. 
val receiverTrackingInfoMap = Map( - 0 -> ReceiverTrackingInfo(0, ReceiverState.ACTIVE, None, Some("host1"))) - val scheduledExecutors = receiverSchedulingPolicy.rescheduleReceiver( + 0 -> ReceiverTrackingInfo(0, ReceiverState.ACTIVE, None, Some(executors(0)))) + val scheduledLocations = receiverSchedulingPolicy.rescheduleReceiver( 1, None, receiverTrackingInfoMap, executors) - assert(scheduledExecutors.toSet === Set("host2", "host3", "host4", "host5")) + assert(scheduledLocations.toSet === executors.tail.toSet) } - test("rescheduleReceiver: return 3 best options if less than 3 idle executors") { - val executors = Seq("host1", "host2", "host3", "host4", "host5") - // Weights: host1 = 1.5, host2 = 0.5, host3 = 1.0 - // host4 and host5 are idle + test("rescheduleReceiver: return all executors that have minimum weight if no idle executors") { + val executors = Seq( + ExecutorCacheTaskLocation("host1", "1"), + ExecutorCacheTaskLocation("host2", "2"), + ExecutorCacheTaskLocation("host3", "3"), + ExecutorCacheTaskLocation("host4", "4"), + ExecutorCacheTaskLocation("host5", "5") + ) + // Weights: host1 = 1.5, host2 = 0.5, host3 = 1.0, host4 = 0.5, host5 = 0.5 val receiverTrackingInfoMap = Map( - 0 -> ReceiverTrackingInfo(0, ReceiverState.ACTIVE, None, Some("host1")), - 1 -> ReceiverTrackingInfo(1, ReceiverState.SCHEDULED, Some(Seq("host2", "host3")), None), - 2 -> ReceiverTrackingInfo(1, ReceiverState.SCHEDULED, Some(Seq("host1", "host3")), None)) - val scheduledExecutors = receiverSchedulingPolicy.rescheduleReceiver( - 3, None, receiverTrackingInfoMap, executors) - assert(scheduledExecutors.toSet === Set("host2", "host4", "host5")) + 0 -> ReceiverTrackingInfo(0, ReceiverState.ACTIVE, None, + Some(ExecutorCacheTaskLocation("host1", "1"))), + 1 -> ReceiverTrackingInfo(1, ReceiverState.SCHEDULED, + Some(Seq(ExecutorCacheTaskLocation("host2", "2"), ExecutorCacheTaskLocation("host3", "3"))), + None), + 2 -> ReceiverTrackingInfo(2, ReceiverState.SCHEDULED, + Some(Seq(ExecutorCacheTaskLocation("host1", "1"), ExecutorCacheTaskLocation("host3", "3"))), + None), + 3 -> ReceiverTrackingInfo(4, ReceiverState.SCHEDULED, + Some(Seq(ExecutorCacheTaskLocation("host4", "4"), + ExecutorCacheTaskLocation("host5", "5"))), None)) + val scheduledLocations = receiverSchedulingPolicy.rescheduleReceiver( + 4, None, receiverTrackingInfoMap, executors) + val expectedScheduledLocations = Set( + ExecutorCacheTaskLocation("host2", "2"), + ExecutorCacheTaskLocation("host4", "4"), + ExecutorCacheTaskLocation("host5", "5") + ) + assert(scheduledLocations.toSet === expectedScheduledLocations) } test("scheduleReceivers: " + "schedule receivers evenly when there are more receivers than executors") { val receivers = (0 until 6).map(new RateTestReceiver(_)) - val executors = (10000 until 10003).map(port => s"localhost:${port}") - val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors) - val numReceiversOnExecutor = mutable.HashMap[String, Int]() + val executors = (0 until 3).map(executorId => + ExecutorCacheTaskLocation("localhost", executorId.toString)) + val scheduledLocations = receiverSchedulingPolicy.scheduleReceivers(receivers, executors) + val numReceiversOnExecutor = mutable.HashMap[TaskLocation, Int]() // There should be 2 receivers running on each executor and each receiver has one executor - scheduledExecutors.foreach { case (receiverId, executors) => - assert(executors.size == 1) - numReceiversOnExecutor(executors(0)) = numReceiversOnExecutor.getOrElse(executors(0), 0) + 1 + 
scheduledLocations.foreach { case (receiverId, locations) => + assert(locations.size == 1) + assert(locations(0).isInstanceOf[ExecutorCacheTaskLocation]) + numReceiversOnExecutor(locations(0)) = numReceiversOnExecutor.getOrElse(locations(0), 0) + 1 } assert(numReceiversOnExecutor === executors.map(_ -> 2).toMap) } - test("scheduleReceivers: " + "schedule receivers evenly when there are more executors than receivers") { val receivers = (0 until 3).map(new RateTestReceiver(_)) - val executors = (10000 until 10006).map(port => s"localhost:${port}") - val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors) - val numReceiversOnExecutor = mutable.HashMap[String, Int]() + val executors = (0 until 6).map(executorId => + ExecutorCacheTaskLocation("localhost", executorId.toString)) + val scheduledLocations = receiverSchedulingPolicy.scheduleReceivers(receivers, executors) + val numReceiversOnExecutor = mutable.HashMap[TaskLocation, Int]() // There should be 1 receiver running on each executor and each receiver has two executors - scheduledExecutors.foreach { case (receiverId, executors) => - assert(executors.size == 2) - executors.foreach { l => + scheduledLocations.foreach { case (receiverId, locations) => + assert(locations.size == 2) + locations.foreach { l => + assert(l.isInstanceOf[ExecutorCacheTaskLocation]) numReceiversOnExecutor(l) = numReceiversOnExecutor.getOrElse(l, 0) + 1 } } @@ -96,35 +119,43 @@ class ReceiverSchedulingPolicySuite extends SparkFunSuite { test("scheduleReceivers: schedule receivers evenly when the preferredLocations are even") { val receivers = (0 until 3).map(new RateTestReceiver(_)) ++ (3 until 6).map(new RateTestReceiver(_, Some("localhost"))) - val executors = (10000 until 10003).map(port => s"localhost:${port}") ++ - (10003 until 10006).map(port => s"localhost2:${port}") - val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors) - val numReceiversOnExecutor = mutable.HashMap[String, Int]() + val executors = (0 until 3).map(executorId => + ExecutorCacheTaskLocation("localhost", executorId.toString)) ++ + (3 until 6).map(executorId => + ExecutorCacheTaskLocation("localhost2", executorId.toString)) + val scheduledLocations = receiverSchedulingPolicy.scheduleReceivers(receivers, executors) + val numReceiversOnExecutor = mutable.HashMap[TaskLocation, Int]() // There should be 1 receiver running on each executor and each receiver has 1 executor - scheduledExecutors.foreach { case (receiverId, executors) => + scheduledLocations.foreach { case (receiverId, executors) => assert(executors.size == 1) executors.foreach { l => + assert(l.isInstanceOf[ExecutorCacheTaskLocation]) numReceiversOnExecutor(l) = numReceiversOnExecutor.getOrElse(l, 0) + 1 } } assert(numReceiversOnExecutor === executors.map(_ -> 1).toMap) // Make sure we schedule the receivers to their preferredLocations val executorsForReceiversWithPreferredLocation = - scheduledExecutors.filter { case (receiverId, executors) => receiverId >= 3 }.flatMap(_._2) + scheduledLocations.filter { case (receiverId, executors) => receiverId >= 3 }.flatMap(_._2) // We can simply check the executor set because we only know each receiver only has 1 executor assert(executorsForReceiversWithPreferredLocation.toSet === - (10000 until 10003).map(port => s"localhost:${port}").toSet) + (0 until 3).map(executorId => + ExecutorCacheTaskLocation("localhost", executorId.toString) + ).toSet) } test("scheduleReceivers: return empty if no receiver") { - 
assert(receiverSchedulingPolicy.scheduleReceivers(Seq.empty, Seq("localhost:10000")).isEmpty) + val scheduledLocations = receiverSchedulingPolicy. + scheduleReceivers(Seq.empty, Seq(ExecutorCacheTaskLocation("localhost", "1"))) + assert(scheduledLocations.isEmpty) } test("scheduleReceivers: return empty scheduled executors if no executors") { val receivers = (0 until 3).map(new RateTestReceiver(_)) - val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, Seq.empty) - scheduledExecutors.foreach { case (receiverId, executors) => + val scheduledLocations = receiverSchedulingPolicy.scheduleReceivers(receivers, Seq.empty) + scheduledLocations.foreach { case (receiverId, executors) => assert(executors.isEmpty) } } + } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala index afad5f16dbc7..3bd8d086abf7 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala @@ -17,119 +17,193 @@ package org.apache.spark.streaming.scheduler +import scala.collection.mutable.ArrayBuffer + import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.SparkConf +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart, TaskLocality} +import org.apache.spark.scheduler.TaskLocality.TaskLocality +import org.apache.spark.storage.{StorageLevel, StreamBlockId} import org.apache.spark.streaming._ -import org.apache.spark.streaming.receiver._ import org.apache.spark.streaming.dstream.ReceiverInputDStream -import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.receiver._ /** Testsuite for receiver scheduling */ class ReceiverTrackerSuite extends TestSuiteBase { - val sparkConf = new SparkConf().setMaster("local[8]").setAppName("test") - - test("Receiver tracker - propagates rate limit") { - withStreamingContext(new StreamingContext(sparkConf, Milliseconds(100))) { ssc => - object ReceiverStartedWaiter extends StreamingListener { - @volatile - var started = false - - override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted): Unit = { - started = true - } - } - ssc.addStreamingListener(ReceiverStartedWaiter) + test("send rate update to receivers") { + withStreamingContext(new StreamingContext(conf, Milliseconds(100))) { ssc => ssc.scheduler.listenerBus.start(ssc.sc) - SingletonTestRateReceiver.reset() val newRateLimit = 100L - val inputDStream = new RateLimitInputDStream(ssc) + val inputDStream = new RateTestInputDStream(ssc) val tracker = new ReceiverTracker(ssc) tracker.start() try { // we wait until the Receiver has registered with the tracker, // otherwise our rate update is lost eventually(timeout(5 seconds)) { - assert(ReceiverStartedWaiter.started) + assert(RateTestReceiver.getActive().nonEmpty) } + + + // Verify that the rate of the block generator in the receiver get updated + val activeReceiver = RateTestReceiver.getActive().get tracker.sendRateUpdate(inputDStream.id, newRateLimit) - // this is an async message, we need to wait a bit for it to be processed - eventually(timeout(3 seconds)) { - assert(inputDStream.getCurrentRateLimit.get === newRateLimit) + eventually(timeout(5 seconds)) { + assert(activeReceiver.getDefaultBlockGeneratorRateLimit() === newRateLimit, + "default block generator did not receive 
rate update") + assert(activeReceiver.getCustomBlockGeneratorRateLimit() === newRateLimit, + "other block generator did not receive rate update") } } finally { tracker.stop(false) } } } + + test("should restart receiver after stopping it") { + withStreamingContext(new StreamingContext(conf, Milliseconds(100))) { ssc => + @volatile var startTimes = 0 + ssc.addStreamingListener(new StreamingListener { + override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted): Unit = { + startTimes += 1 + } + }) + val input = ssc.receiverStream(new StoppableReceiver) + val output = new TestOutputStream(input) + output.register() + ssc.start() + StoppableReceiver.shouldStop = true + eventually(timeout(10 seconds), interval(10 millis)) { + // The receiver is stopped once, so if it's restarted, it should be started twice. + assert(startTimes === 2) + } + } + } + + test("SPARK-11063: TaskSetManager should use Receiver RDD's preferredLocations") { + // Use ManualClock to prevent from starting batches so that we can make sure the only task is + // for starting the Receiver + val _conf = conf.clone.set("spark.streaming.clock", "org.apache.spark.util.ManualClock") + withStreamingContext(new StreamingContext(_conf, Milliseconds(100))) { ssc => + @volatile var receiverTaskLocality: TaskLocality = null + ssc.sparkContext.addSparkListener(new SparkListener { + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { + receiverTaskLocality = taskStart.taskInfo.taskLocality + } + }) + val input = ssc.receiverStream(new TestReceiver) + val output = new TestOutputStream(input) + output.register() + ssc.start() + eventually(timeout(10 seconds), interval(10 millis)) { + // If preferredLocations is set correctly, receiverTaskLocality should be PROCESS_LOCAL + assert(receiverTaskLocality === TaskLocality.PROCESS_LOCAL) + } + } + } } -/** - * An input DStream with a hard-coded receiver that gives access to internals for testing. - * - * @note Make sure to call {{{SingletonDummyReceiver.reset()}}} before using this in a test, - * or otherwise you may get {{{NotSerializableException}}} when trying to serialize - * the receiver. - * @see [[[SingletonDummyReceiver]]]. 
- */ -private[streaming] class RateLimitInputDStream(@transient ssc_ : StreamingContext) +/** An input DStream with for testing rate controlling */ +private[streaming] class RateTestInputDStream(@transient ssc_ : StreamingContext) extends ReceiverInputDStream[Int](ssc_) { - override def getReceiver(): RateTestReceiver = SingletonTestRateReceiver - - def getCurrentRateLimit: Option[Long] = { - invokeExecutorMethod.getCurrentRateLimit - } + override def getReceiver(): Receiver[Int] = new RateTestReceiver(id) @volatile - var publishCalls = 0 + var publishedRates = 0 override val rateController: Option[RateController] = { - Some(new RateController(id, new ConstantEstimator(100.0)) { + Some(new RateController(id, new ConstantEstimator(100)) { override def publish(rate: Long): Unit = { - publishCalls += 1 + publishedRates += 1 } }) } +} - private def invokeExecutorMethod: ReceiverSupervisor = { - val c = classOf[Receiver[_]] - val ex = c.getDeclaredMethod("executor") - ex.setAccessible(true) - ex.invoke(SingletonTestRateReceiver).asInstanceOf[ReceiverSupervisor] +/** A receiver implementation for testing rate controlling */ +private[streaming] class RateTestReceiver(receiverId: Int, host: Option[String] = None) + extends Receiver[Int](StorageLevel.MEMORY_ONLY) { + + private lazy val customBlockGenerator = supervisor.createBlockGenerator( + new BlockGeneratorListener { + override def onPushBlock(blockId: StreamBlockId, arrayBuffer: ArrayBuffer[_]): Unit = {} + override def onError(message: String, throwable: Throwable): Unit = {} + override def onGenerateBlock(blockId: StreamBlockId): Unit = {} + override def onAddData(data: Any, metadata: Any): Unit = {} + } + ) + + setReceiverId(receiverId) + + override def onStart(): Unit = { + customBlockGenerator + RateTestReceiver.registerReceiver(this) + } + + override def onStop(): Unit = { + RateTestReceiver.deregisterReceiver() + } + + override def preferredLocation: Option[String] = host + + def getDefaultBlockGeneratorRateLimit(): Long = { + supervisor.getCurrentRateLimit + } + + def getCustomBlockGeneratorRateLimit(): Long = { + customBlockGenerator.getCurrentLimit } } /** - * A Receiver as an object so we can read its rate limit. Make sure to call `reset()` when - * reusing this receiver, otherwise a non-null `executor_` field will prevent it from being - * serialized when receivers are installed on executors. - * - * @note It's necessary to be a top-level object, or else serialization would create another - * one on the executor side and we won't be able to read its rate limit. + * A helper object to RateTestReceiver that give access to the currently active RateTestReceiver + * instance. */ -private[streaming] object SingletonTestRateReceiver extends RateTestReceiver(0) { +private[streaming] object RateTestReceiver { + @volatile private var activeReceiver: RateTestReceiver = null + + def registerReceiver(receiver: RateTestReceiver): Unit = { + activeReceiver = receiver + } - /** Reset the object to be usable in another test. 
*/ - def reset(): Unit = { - executor_ = null + def deregisterReceiver(): Unit = { + activeReceiver = null } + + def getActive(): Option[RateTestReceiver] = Option(activeReceiver) } /** - * Dummy receiver implementation + * A custom receiver that could be stopped via StoppableReceiver.shouldStop */ -private[streaming] class RateTestReceiver(receiverId: Int, host: Option[String] = None) - extends Receiver[Int](StorageLevel.MEMORY_ONLY) { +class StoppableReceiver extends Receiver[Int](StorageLevel.MEMORY_ONLY) { - setReceiverId(receiverId) + var receivingThreadOption: Option[Thread] = None - override def onStart(): Unit = {} + def onStart() { + val thread = new Thread() { + override def run() { + while (!StoppableReceiver.shouldStop) { + Thread.sleep(10) + } + StoppableReceiver.this.stop("stop") + } + } + thread.start() + } - override def onStop(): Unit = {} + def onStop() { + StoppableReceiver.shouldStop = true + receivingThreadOption.foreach(_.join()) + // Reset it so as to restart it + StoppableReceiver.shouldStop = false + } +} - override def preferredLocation: Option[String] = host +object StoppableReceiver { + @volatile var shouldStop = false } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala index 97c32d8f2d59..a1af95be81c8 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala @@ -36,72 +36,89 @@ class PIDRateEstimatorSuite extends SparkFunSuite with Matchers { test("estimator checks ranges") { intercept[IllegalArgumentException] { - new PIDRateEstimator(0, 1, 2, 3) + new PIDRateEstimator(batchIntervalMillis = 0, 1, 2, 3, 10) } intercept[IllegalArgumentException] { - new PIDRateEstimator(100, -1, 2, 3) + new PIDRateEstimator(100, proportional = -1, 2, 3, 10) } intercept[IllegalArgumentException] { - new PIDRateEstimator(100, 0, -1, 3) + new PIDRateEstimator(100, 0, integral = -1, 3, 10) } intercept[IllegalArgumentException] { - new PIDRateEstimator(100, 0, 0, -1) + new PIDRateEstimator(100, 0, 0, derivative = -1, 10) + } + intercept[IllegalArgumentException] { + new PIDRateEstimator(100, 0, 0, 0, minRate = 0) + } + intercept[IllegalArgumentException] { + new PIDRateEstimator(100, 0, 0, 0, minRate = -10) } } - private def createDefaultEstimator: PIDRateEstimator = { - new PIDRateEstimator(20, 1D, 0D, 0D) - } - - test("first bound is None") { - val p = createDefaultEstimator + test("first estimate is None") { + val p = createDefaultEstimator() p.compute(0, 10, 10, 0) should equal(None) } - test("second bound is rate") { - val p = createDefaultEstimator + test("second estimate is not None") { + val p = createDefaultEstimator() p.compute(0, 10, 10, 0) // 1000 elements / s p.compute(10, 10, 10, 0) should equal(Some(1000)) } - test("works even with no time between updates") { - val p = createDefaultEstimator + test("no estimate when no time difference between successive calls") { + val p = createDefaultEstimator() + p.compute(0, 10, 10, 0) + p.compute(time = 10, 10, 10, 0) shouldNot equal(None) + p.compute(time = 10, 10, 10, 0) should equal(None) + } + + test("no estimate when no records in previous batch") { + val p = createDefaultEstimator() p.compute(0, 10, 10, 0) - p.compute(10, 10, 10, 0) - p.compute(10, 10, 10, 0) should equal(None) + p.compute(10, numElements = 0, 10, 0) should 
equal(None) + p.compute(20, numElements = -10, 10, 0) should equal(None) } - test("bound is never negative") { - val p = new PIDRateEstimator(20, 1D, 1D, 0D) + test("no estimate when there is no processing delay") { + val p = createDefaultEstimator() + p.compute(0, 10, 10, 0) + p.compute(10, 10, processingDelay = 0, 0) should equal(None) + p.compute(20, 10, processingDelay = -10, 0) should equal(None) + } + + test("estimate is never less than min rate") { + val minRate = 5D + val p = new PIDRateEstimator(20, 1D, 1D, 0D, minRate) // prepare a series of batch updates, one every 20ms, 0 processed elements, 2ms of processing // this might point the estimator to try and decrease the bound, but we test it never - // goes below zero, which would be nonsensical. + // goes below the min rate, which would be nonsensical. val times = List.tabulate(50)(x => x * 20) // every 20ms - val elements = List.fill(50)(0) // no processing + val elements = List.fill(50)(1) // no processing val proc = List.fill(50)(20) // 20ms of processing val sched = List.fill(50)(100) // strictly positive accumulation val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i)) res.head should equal(None) - res.tail should equal(List.fill(49)(Some(0D))) + res.tail should equal(List.fill(49)(Some(minRate))) } test("with no accumulated or positive error, |I| > 0, follow the processing speed") { - val p = new PIDRateEstimator(20, 1D, 1D, 0D) + val p = new PIDRateEstimator(20, 1D, 1D, 0D, 10) // prepare a series of batch updates, one every 20ms with an increasing number of processed // elements in each batch, but constant processing time, and no accumulated error. Even though // the integral part is non-zero, the estimated rate should follow only the proportional term val times = List.tabulate(50)(x => x * 20) // every 20ms - val elements = List.tabulate(50)(x => x * 20) // increasing + val elements = List.tabulate(50)(x => (x + 1) * 20) // increasing val proc = List.fill(50)(20) // 20ms of processing val sched = List.fill(50)(0) val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i)) res.head should equal(None) - res.tail should equal(List.tabulate(50)(x => Some(x * 1000D)).tail) + res.tail should equal(List.tabulate(50)(x => Some((x + 1) * 1000D)).tail) } test("with no accumulated but some positive error, |I| > 0, follow the processing speed") { - val p = new PIDRateEstimator(20, 1D, 1D, 0D) + val p = new PIDRateEstimator(20, 1D, 1D, 0D, 10) // prepare a series of batch updates, one every 20ms with an decreasing number of processed // elements in each batch, but constant processing time, and no accumulated error. 
Even though // the integral part is non-zero, the estimated rate should follow only the proportional term, @@ -116,13 +133,14 @@ class PIDRateEstimatorSuite extends SparkFunSuite with Matchers { } test("with some accumulated and some positive error, |I| > 0, stay below the processing speed") { - val p = new PIDRateEstimator(20, 1D, .01D, 0D) + val minRate = 10D + val p = new PIDRateEstimator(20, 1D, .01D, 0D, minRate) val times = List.tabulate(50)(x => x * 20) // every 20ms val rng = new Random() - val elements = List.tabulate(50)(x => rng.nextInt(1000)) + val elements = List.tabulate(50)(x => rng.nextInt(1000) + 1000) val procDelayMs = 20 val proc = List.fill(50)(procDelayMs) // 20ms of processing - val sched = List.tabulate(50)(x => rng.nextInt(19)) // random wait + val sched = List.tabulate(50)(x => rng.nextInt(19) + 1) // random wait val speeds = elements map ((x) => x.toDouble / procDelayMs * 1000) val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i)) @@ -131,7 +149,12 @@ class PIDRateEstimatorSuite extends SparkFunSuite with Matchers { res(n) should not be None if (res(n).get > 0 && sched(n) > 0) { res(n).get should be < speeds(n) + res(n).get should be >= minRate } } } + + private def createDefaultEstimator(): PIDRateEstimator = { + new PIDRateEstimator(20, 1D, 0D, 0D, 10) + } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala index 995f1197ccdf..34cd7435569e 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala @@ -63,7 +63,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { 1 -> StreamInputInfo(1, 300L, Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "test"))) // onBatchSubmitted - val batchInfoSubmitted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, None, None) + val batchInfoSubmitted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, None, None, Map.empty) listener.onBatchSubmitted(StreamingListenerBatchSubmitted(batchInfoSubmitted)) listener.waitingBatches should be (List(BatchUIData(batchInfoSubmitted))) listener.runningBatches should be (Nil) @@ -75,7 +75,8 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.numTotalReceivedRecords should be (0) // onBatchStarted - val batchInfoStarted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) + val batchInfoStarted = + BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None, Map.empty) listener.onBatchStarted(StreamingListenerBatchStarted(batchInfoStarted)) listener.waitingBatches should be (Nil) listener.runningBatches should be (List(BatchUIData(batchInfoStarted))) @@ -116,7 +117,8 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { OutputOpIdAndSparkJobId(1, 1)) // onBatchCompleted - val batchInfoCompleted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) + val batchInfoCompleted = + BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None, Map.empty) listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) listener.waitingBatches should be (Nil) listener.runningBatches should be (Nil) @@ -128,20 +130,20 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { 
listener.numTotalReceivedRecords should be (600) // onReceiverStarted - val receiverInfoStarted = ReceiverInfo(0, "test", true, "localhost") + val receiverInfoStarted = ReceiverInfo(0, "test", true, "localhost", "0") listener.onReceiverStarted(StreamingListenerReceiverStarted(receiverInfoStarted)) listener.receiverInfo(0) should be (Some(receiverInfoStarted)) listener.receiverInfo(1) should be (None) // onReceiverError - val receiverInfoError = ReceiverInfo(1, "test", true, "localhost") + val receiverInfoError = ReceiverInfo(1, "test", true, "localhost", "1") listener.onReceiverError(StreamingListenerReceiverError(receiverInfoError)) listener.receiverInfo(0) should be (Some(receiverInfoStarted)) listener.receiverInfo(1) should be (Some(receiverInfoError)) listener.receiverInfo(2) should be (None) // onReceiverStopped - val receiverInfoStopped = ReceiverInfo(2, "test", true, "localhost") + val receiverInfoStopped = ReceiverInfo(2, "test", true, "localhost", "2") listener.onReceiverStopped(StreamingListenerReceiverStopped(receiverInfoStopped)) listener.receiverInfo(0) should be (Some(receiverInfoStarted)) listener.receiverInfo(1) should be (Some(receiverInfoError)) @@ -156,7 +158,8 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { val streamIdToInputInfo = Map(0 -> StreamInputInfo(0, 300L), 1 -> StreamInputInfo(1, 300L)) - val batchInfoCompleted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) + val batchInfoCompleted = + BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None, Map.empty) for(_ <- 0 until (limit + 10)) { listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) @@ -173,8 +176,8 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { // fulfill completedBatchInfos for(i <- 0 until limit) { - val batchInfoCompleted = - BatchInfo(Time(1000 + i * 100), Map.empty, 1000 + i * 100, Some(2000 + i * 100), None) + val batchInfoCompleted = BatchInfo( + Time(1000 + i * 100), Map.empty, 1000 + i * 100, Some(2000 + i * 100), None, Map.empty) listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) val jobStart = createJobStart(Time(1000 + i * 100), outputOpId = 0, jobId = 1) listener.onJobStart(jobStart) @@ -185,7 +188,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.onJobStart(jobStart) val batchInfoSubmitted = - BatchInfo(Time(1000 + limit * 100), Map.empty, (1000 + limit * 100), None, None) + BatchInfo(Time(1000 + limit * 100), Map.empty, (1000 + limit * 100), None, None, Map.empty) listener.onBatchSubmitted(StreamingListenerBatchSubmitted(batchInfoSubmitted)) // We still can see the info retrieved from onJobStart @@ -201,8 +204,8 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { // A lot of "onBatchCompleted"s happen before "onJobStart" for(i <- limit + 1 to limit * 2) { - val batchInfoCompleted = - BatchInfo(Time(1000 + i * 100), Map.empty, 1000 + i * 100, Some(2000 + i * 100), None) + val batchInfoCompleted = BatchInfo( + Time(1000 + i * 100), Map.empty, 1000 + i * 100, Some(2000 + i * 100), None, Map.empty) listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) } @@ -227,11 +230,13 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { val streamIdToInputInfo = Map(0 -> StreamInputInfo(0, 300L), 1 -> StreamInputInfo(1, 300L)) // onBatchSubmitted - val batchInfoSubmitted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, None, None) + 
val batchInfoSubmitted = + BatchInfo(Time(1000), streamIdToInputInfo, 1000, None, None, Map.empty) listener.onBatchSubmitted(StreamingListenerBatchSubmitted(batchInfoSubmitted)) // onBatchStarted - val batchInfoStarted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) + val batchInfoStarted = + BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None, Map.empty) listener.onBatchStarted(StreamingListenerBatchStarted(batchInfoStarted)) // onJobStart @@ -248,7 +253,8 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.onJobStart(jobStart4) // onBatchCompleted - val batchInfoCompleted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) + val batchInfoCompleted = + BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None, Map.empty) listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/RecurringTimerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/RecurringTimerSuite.scala new file mode 100644 index 000000000000..0544972d95c0 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/RecurringTimerSuite.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.util + +import scala.collection.mutable +import scala.concurrent.duration._ + +import org.scalatest.PrivateMethodTester +import org.scalatest.concurrent.Eventually._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.util.ManualClock + +class RecurringTimerSuite extends SparkFunSuite with PrivateMethodTester { + + test("basic") { + val clock = new ManualClock() + val results = new mutable.ArrayBuffer[Long]() with mutable.SynchronizedBuffer[Long] + val timer = new RecurringTimer(clock, 100, time => { + results += time + }, "RecurringTimerSuite-basic") + timer.start(0) + eventually(timeout(10.seconds), interval(10.millis)) { + assert(results === Seq(0L)) + } + clock.advance(100) + eventually(timeout(10.seconds), interval(10.millis)) { + assert(results === Seq(0L, 100L)) + } + clock.advance(200) + eventually(timeout(10.seconds), interval(10.millis)) { + assert(results === Seq(0L, 100L, 200L, 300L)) + } + assert(timer.stop(interruptTimer = true) === 300L) + } + + test("SPARK-10224: call 'callback' after stopping") { + val clock = new ManualClock() + val results = new mutable.ArrayBuffer[Long]() with mutable.SynchronizedBuffer[Long] + val timer = new RecurringTimer(clock, 100, time => { + results += time + }, "RecurringTimerSuite-SPARK-10224") + timer.start(0) + eventually(timeout(10.seconds), interval(10.millis)) { + assert(results === Seq(0L)) + } + @volatile var lastTime = -1L + // Now RecurringTimer is waiting for the next interval + val thread = new Thread { + override def run(): Unit = { + lastTime = timer.stop(interruptTimer = false) + } + } + thread.start() + val stopped = PrivateMethod[RecurringTimer]('stopped) + // Make sure the `stopped` field has been changed + eventually(timeout(10.seconds), interval(10.millis)) { + assert(timer.invokePrivate(stopped()) === true) + } + clock.advance(200) + // When RecurringTimer is awake from clock.waitTillTime, it will call `callback` once. + // Then it will find `stopped` is true and exit the loop, but it should call `callback` again + // before exiting its internal thread. 
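+    // In other words (per the assertions below): the wake-up at time 100 appends 100L to `results`,
+    // the loop then observes `stopped == true`, and the extra callback made before the internal
+    // thread exits appends 200L, which is also the value `stop` returns into `lastTime`.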
+ thread.join() + assert(results === Seq(0L, 100L, 200L)) + assert(lastTime === 200L) + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 325ff7c74c39..ef1e89df3130 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -18,30 +18,46 @@ package org.apache.spark.streaming.util import java.io._ import java.nio.ByteBuffer -import java.util +import java.util.{Iterator => JIterator} +import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.{RejectedExecutionException, TimeUnit, CountDownLatch, ThreadPoolExecutor} +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer +import scala.concurrent._ import scala.concurrent.duration._ import scala.language.{implicitConversions, postfixOps} -import scala.reflect.ClassTag import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.mockito.ArgumentCaptor +import org.mockito.Matchers.{eq => meq} +import org.mockito.Matchers._ +import org.mockito.Mockito._ +import org.scalatest.concurrent.Eventually import org.scalatest.concurrent.Eventually._ -import org.scalatest.BeforeAndAfter +import org.scalatest.{PrivateMethodTester, BeforeAndAfterEach, BeforeAndAfter} +import org.scalatest.mock.MockitoSugar -import org.apache.spark.util.{ManualClock, Utils} -import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} +import org.apache.spark.streaming.scheduler._ +import org.apache.spark.util.{CompletionIterator, ThreadUtils, ManualClock, Utils} +import org.apache.spark.{SparkConf, SparkFunSuite} -class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter { +/** Common tests for WriteAheadLogs that we would like to test with different configurations. 
*/ +abstract class CommonWriteAheadLogTests( + allowBatching: Boolean, + closeFileAfterWrite: Boolean, + testTag: String = "") + extends SparkFunSuite with BeforeAndAfter { import WriteAheadLogSuite._ - val hadoopConf = new Configuration() - var tempDir: File = null - var testDir: String = null - var testFile: String = null - var writeAheadLog: FileBasedWriteAheadLog = null + protected val hadoopConf = new Configuration() + protected var tempDir: File = null + protected var testDir: String = null + protected var testFile: String = null + protected var writeAheadLog: WriteAheadLog = null + protected def testPrefix = if (testTag != "") testTag + " - " else testTag before { tempDir = Utils.createTempDir() @@ -57,47 +73,208 @@ class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter { Utils.deleteRecursively(tempDir) } - test("WriteAheadLogUtils - log selection and creation") { - val logDir = Utils.createTempDir().getAbsolutePath() + test(testPrefix + "read all logs") { + // Write data manually for testing reading through WriteAheadLog + val writtenData = (1 to 10).map { i => + val data = generateRandomData() + val file = testDir + s"/log-$i-$i" + writeDataManually(data, file, allowBatching) + data + }.flatten + + val logDirectoryPath = new Path(testDir) + val fileSystem = HdfsUtils.getFileSystemForPath(logDirectoryPath, hadoopConf) + assert(fileSystem.exists(logDirectoryPath) === true) + + // Read data using manager and verify + val readData = readDataUsingWriteAheadLog(testDir, closeFileAfterWrite, allowBatching) + assert(readData === writtenData) + } + + test(testPrefix + "write logs") { + // Write data with rotation using WriteAheadLog class + val dataToWrite = generateRandomData() + writeDataUsingWriteAheadLog(testDir, dataToWrite, closeFileAfterWrite = closeFileAfterWrite, + allowBatching = allowBatching) + + // Read data manually to verify the written data + val logFiles = getLogFilesInDirectory(testDir) + assert(logFiles.size > 1) + val writtenData = readAndDeserializeDataManually(logFiles, allowBatching) + assert(writtenData === dataToWrite) + } + + test(testPrefix + "read all logs after write") { + // Write data with manager, recover with new manager and verify + val dataToWrite = generateRandomData() + writeDataUsingWriteAheadLog(testDir, dataToWrite, closeFileAfterWrite, allowBatching) + val logFiles = getLogFilesInDirectory(testDir) + assert(logFiles.size > 1) + val readData = readDataUsingWriteAheadLog(testDir, closeFileAfterWrite, allowBatching) + assert(dataToWrite === readData) + } + + test(testPrefix + "clean old logs") { + logCleanUpTest(waitForCompletion = false) + } + + test(testPrefix + "clean old logs synchronously") { + logCleanUpTest(waitForCompletion = true) + } + + private def logCleanUpTest(waitForCompletion: Boolean): Unit = { + // Write data with manager, recover with new manager and verify + val manualClock = new ManualClock + val dataToWrite = generateRandomData() + writeAheadLog = writeDataUsingWriteAheadLog(testDir, dataToWrite, closeFileAfterWrite, + allowBatching, manualClock, closeLog = false) + val logFiles = getLogFilesInDirectory(testDir) + assert(logFiles.size > 1) + + writeAheadLog.clean(manualClock.getTimeMillis() / 2, waitForCompletion) + + if (waitForCompletion) { + assert(getLogFilesInDirectory(testDir).size < logFiles.size) + } else { + eventually(Eventually.timeout(1 second), interval(10 milliseconds)) { + assert(getLogFilesInDirectory(testDir).size < logFiles.size) + } + } + } + + test(testPrefix + "handling file errors while reading 
rotating logs") { + // Generate a set of log files + val manualClock = new ManualClock + val dataToWrite1 = generateRandomData() + writeDataUsingWriteAheadLog(testDir, dataToWrite1, closeFileAfterWrite, allowBatching, + manualClock) + val logFiles1 = getLogFilesInDirectory(testDir) + assert(logFiles1.size > 1) + + + // Recover old files and generate a second set of log files + val dataToWrite2 = generateRandomData() + manualClock.advance(100000) + writeDataUsingWriteAheadLog(testDir, dataToWrite2, closeFileAfterWrite, allowBatching , + manualClock) + val logFiles2 = getLogFilesInDirectory(testDir) + assert(logFiles2.size > logFiles1.size) - def assertDriverLogClass[T <: WriteAheadLog: ClassTag](conf: SparkConf): WriteAheadLog = { - val log = WriteAheadLogUtils.createLogForDriver(conf, logDir, hadoopConf) - assert(log.getClass === implicitly[ClassTag[T]].runtimeClass) - log + // Read the files and verify that all the written data can be read + val readData1 = readDataUsingWriteAheadLog(testDir, closeFileAfterWrite, allowBatching) + assert(readData1 === (dataToWrite1 ++ dataToWrite2)) + + // Corrupt the first set of files so that they are basically unreadable + logFiles1.foreach { f => + val raf = new FileOutputStream(f, true).getChannel() + raf.truncate(1) + raf.close() + } + + // Verify that the corrupted files do not prevent reading of the second set of data + val readData = readDataUsingWriteAheadLog(testDir, closeFileAfterWrite, allowBatching) + assert(readData === dataToWrite2) + } + + test(testPrefix + "do not create directories or files unless write") { + val nonexistentTempPath = File.createTempFile("test", "") + nonexistentTempPath.delete() + assert(!nonexistentTempPath.exists()) + + val writtenSegment = writeDataManually(generateRandomData(), testFile, allowBatching) + val wal = createWriteAheadLog(testDir, closeFileAfterWrite, allowBatching) + assert(!nonexistentTempPath.exists(), "Directory created just by creating log object") + if (allowBatching) { + intercept[UnsupportedOperationException](wal.read(writtenSegment.head)) + } else { + wal.read(writtenSegment.head) } + assert(!nonexistentTempPath.exists(), "Directory created just by attempting to read segment") + } + + test(testPrefix + "parallel recovery not enabled if closeFileAfterWrite = false") { + // write some data + val writtenData = (1 to 10).map { i => + val data = generateRandomData() + val file = testDir + s"/log-$i-$i" + writeDataManually(data, file, allowBatching) + data + }.flatten - def assertReceiverLogClass[T: ClassTag](conf: SparkConf): WriteAheadLog = { - val log = WriteAheadLogUtils.createLogForReceiver(conf, logDir, hadoopConf) - assert(log.getClass === implicitly[ClassTag[T]].runtimeClass) - log + val wal = createWriteAheadLog(testDir, closeFileAfterWrite, allowBatching) + // create iterator but don't materialize it + val readData = wal.readAll().asScala.map(byteBufferToString) + wal.close() + if (closeFileAfterWrite) { + // the threadpool is shutdown by the wal.close call above, therefore we shouldn't be able + // to materialize the iterator with parallel recovery + intercept[RejectedExecutionException](readData.toArray) + } else { + assert(readData.toSeq === writtenData) } + } +} + +class FileBasedWriteAheadLogSuite + extends CommonWriteAheadLogTests(false, false, "FileBasedWriteAheadLog") { - val emptyConf = new SparkConf() // no log configuration - assertDriverLogClass[FileBasedWriteAheadLog](emptyConf) - assertReceiverLogClass[FileBasedWriteAheadLog](emptyConf) - - // Verify setting driver WAL class - 
val conf1 = new SparkConf().set("spark.streaming.driver.writeAheadLog.class", - classOf[MockWriteAheadLog0].getName()) - assertDriverLogClass[MockWriteAheadLog0](conf1) - assertReceiverLogClass[FileBasedWriteAheadLog](conf1) - - // Verify setting receiver WAL class - val receiverWALConf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class", - classOf[MockWriteAheadLog0].getName()) - assertDriverLogClass[FileBasedWriteAheadLog](receiverWALConf) - assertReceiverLogClass[MockWriteAheadLog0](receiverWALConf) - - // Verify setting receiver WAL class with 1-arg constructor - val receiverWALConf2 = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class", - classOf[MockWriteAheadLog1].getName()) - assertReceiverLogClass[MockWriteAheadLog1](receiverWALConf2) - - // Verify failure setting receiver WAL class with 2-arg constructor - intercept[SparkException] { - val receiverWALConf3 = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class", - classOf[MockWriteAheadLog2].getName()) - assertReceiverLogClass[MockWriteAheadLog1](receiverWALConf3) + import WriteAheadLogSuite._ + + test("FileBasedWriteAheadLog - seqToParIterator") { + /* + If the setting `closeFileAfterWrite` is enabled, we start generating a very large number of + files. This causes recovery to take a very long time. In order to make it quicker, we + parallelized the reading of these files. This test makes sure that we limit the number of + open files to the size of the number of threads in our thread pool rather than the size of + the list of files. + */ + val numThreads = 8 + val tpool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "wal-test-thread-pool") + class GetMaxCounter { + private val value = new AtomicInteger() + @volatile private var max: Int = 0 + def increment(): Unit = synchronized { + val atInstant = value.incrementAndGet() + if (atInstant > max) max = atInstant + } + def decrement(): Unit = synchronized { value.decrementAndGet() } + def get(): Int = synchronized { value.get() } + def getMax(): Int = synchronized { max } + } + try { + // If Jenkins is slow, we may not have a chance to run many threads simultaneously. Having + // a latch will make sure that all the threads can be launched altogether. + val latch = new CountDownLatch(1) + val testSeq = 1 to 1000 + val counter = new GetMaxCounter() + def handle(value: Int): Iterator[Int] = { + new CompletionIterator[Int, Iterator[Int]](Iterator(value)) { + counter.increment() + // block so that other threads also launch + latch.await(10, TimeUnit.SECONDS) + override def completion() { counter.decrement() } + } + } + @volatile var collected: Seq[Int] = Nil + val t = new Thread() { + override def run() { + // run the calculation on a separate thread so that we can release the latch + val iterator = FileBasedWriteAheadLog.seqToParIterator[Int, Int](tpool, testSeq, handle) + collected = iterator.toSeq + } + } + t.start() + eventually(Eventually.timeout(10.seconds)) { + // make sure we are doing a parallel computation! 
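+        // `getMax()` records the peak number of CompletionIterators open at the same time, so a
+        // value above 1 shows several inputs were being read in parallel; the final check that it
+        // never exceeds `numThreads` shows the parallelism stays bounded by the thread pool size.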
+ assert(counter.getMax() > 1) + } + latch.countDown() + t.join(10000) + assert(collected === testSeq) + // make sure we didn't open too many Iterators + assert(counter.getMax() <= numThreads) + } finally { + tpool.shutdownNow() } } @@ -121,7 +298,7 @@ class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter { test("FileBasedWriteAheadLogReader - sequentially reading data") { val writtenData = generateRandomData() - writeDataManually(writtenData, testFile) + writeDataManually(writtenData, testFile, allowBatching = false) val reader = new FileBasedWriteAheadLogReader(testFile, hadoopConf) val readData = reader.toSeq.map(byteBufferToString) assert(readData === writtenData) @@ -162,10 +339,30 @@ class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter { assert(readDataUsingReader(testFile) === (dataToWrite.dropRight(1))) } + test("FileBasedWriteAheadLogReader - handles errors when file doesn't exist") { + // Write data manually for testing the sequential reader + val dataToWrite = generateRandomData() + writeDataUsingWriter(testFile, dataToWrite) + val tFile = new File(testFile) + assert(tFile.exists()) + // Verify the data can be read and is same as the one correctly written + assert(readDataUsingReader(testFile) === dataToWrite) + + tFile.delete() + assert(!tFile.exists()) + + val reader = new FileBasedWriteAheadLogReader(testFile, hadoopConf) + assert(!reader.hasNext) + reader.close() + + // Verify that no exception is thrown if file doesn't exist + assert(readDataUsingReader(testFile) === Nil) + } + test("FileBasedWriteAheadLogRandomReader - reading data using random reader") { // Write data manually for testing the random reader val writtenData = generateRandomData() - val segments = writeDataManually(writtenData, testFile) + val segments = writeDataManually(writtenData, testFile, allowBatching = false) // Get a random order of these segments and read them back val writtenDataAndSegments = writtenData.zip(segments).toSeq.permutations.take(10).flatten @@ -189,148 +386,219 @@ class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter { } reader.close() } +} + +abstract class CloseFileAfterWriteTests(allowBatching: Boolean, testTag: String) + extends CommonWriteAheadLogTests(allowBatching, closeFileAfterWrite = true, testTag) { - test("FileBasedWriteAheadLog - write rotating logs") { + import WriteAheadLogSuite._ + test(testPrefix + "close after write flag") { // Write data with rotation using WriteAheadLog class - val dataToWrite = generateRandomData() - writeDataUsingWriteAheadLog(testDir, dataToWrite) + val numFiles = 3 + val dataToWrite = Seq.tabulate(numFiles)(_.toString) + // total advance time is less than 1000, therefore log shouldn't be rolled, but manually closed + writeDataUsingWriteAheadLog(testDir, dataToWrite, closeLog = false, clockAdvanceTime = 100, + closeFileAfterWrite = true, allowBatching = allowBatching) // Read data manually to verify the written data val logFiles = getLogFilesInDirectory(testDir) - assert(logFiles.size > 1) - val writtenData = logFiles.flatMap { file => readDataManually(file)} + assert(logFiles.size === numFiles) + val writtenData: Seq[String] = readAndDeserializeDataManually(logFiles, allowBatching) assert(writtenData === dataToWrite) } +} - test("FileBasedWriteAheadLog - read rotating logs") { - // Write data manually for testing reading through WriteAheadLog - val writtenData = (1 to 10).map { i => - val data = generateRandomData() - val file = testDir + s"/log-$i-$i" - writeDataManually(data, file) - data - }.flatten 
+class FileBasedWriteAheadLogWithFileCloseAfterWriteSuite + extends CloseFileAfterWriteTests(allowBatching = false, "FileBasedWriteAheadLog") - val logDirectoryPath = new Path(testDir) - val fileSystem = HdfsUtils.getFileSystemForPath(logDirectoryPath, hadoopConf) - assert(fileSystem.exists(logDirectoryPath) === true) +class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( + allowBatching = true, + closeFileAfterWrite = false, + "BatchedWriteAheadLog") + with MockitoSugar + with BeforeAndAfterEach + with Eventually + with PrivateMethodTester { - // Read data using manager and verify - val readData = readDataUsingWriteAheadLog(testDir) - assert(readData === writtenData) - } + import BatchedWriteAheadLog._ + import WriteAheadLogSuite._ - test("FileBasedWriteAheadLog - recover past logs when creating new manager") { - // Write data with manager, recover with new manager and verify - val dataToWrite = generateRandomData() - writeDataUsingWriteAheadLog(testDir, dataToWrite) - val logFiles = getLogFilesInDirectory(testDir) - assert(logFiles.size > 1) - val readData = readDataUsingWriteAheadLog(testDir) - assert(dataToWrite === readData) - } + private var wal: WriteAheadLog = _ + private var walHandle: WriteAheadLogRecordHandle = _ + private var walBatchingThreadPool: ThreadPoolExecutor = _ + private var walBatchingExecutionContext: ExecutionContextExecutorService = _ + private val sparkConf = new SparkConf() - test("FileBasedWriteAheadLog - clean old logs") { - logCleanUpTest(waitForCompletion = false) - } + private val queueLength = PrivateMethod[Int]('getQueueLength) - test("FileBasedWriteAheadLog - clean old logs synchronously") { - logCleanUpTest(waitForCompletion = true) + override def beforeEach(): Unit = { + wal = mock[WriteAheadLog] + walHandle = mock[WriteAheadLogRecordHandle] + walBatchingThreadPool = ThreadUtils.newDaemonFixedThreadPool(8, "wal-test-thread-pool") + walBatchingExecutionContext = ExecutionContext.fromExecutorService(walBatchingThreadPool) } - private def logCleanUpTest(waitForCompletion: Boolean): Unit = { - // Write data with manager, recover with new manager and verify - val manualClock = new ManualClock - val dataToWrite = generateRandomData() - writeAheadLog = writeDataUsingWriteAheadLog(testDir, dataToWrite, manualClock, closeLog = false) - val logFiles = getLogFilesInDirectory(testDir) - assert(logFiles.size > 1) - - writeAheadLog.clean(manualClock.getTimeMillis() / 2, waitForCompletion) - - if (waitForCompletion) { - assert(getLogFilesInDirectory(testDir).size < logFiles.size) - } else { - eventually(timeout(1 second), interval(10 milliseconds)) { - assert(getLogFilesInDirectory(testDir).size < logFiles.size) - } + override def afterEach(): Unit = { + if (walBatchingExecutionContext != null) { + walBatchingExecutionContext.shutdownNow() } } - test("FileBasedWriteAheadLog - handling file errors while reading rotating logs") { - // Generate a set of log files - val manualClock = new ManualClock - val dataToWrite1 = generateRandomData() - writeDataUsingWriteAheadLog(testDir, dataToWrite1, manualClock) - val logFiles1 = getLogFilesInDirectory(testDir) - assert(logFiles1.size > 1) + test("BatchedWriteAheadLog - serializing and deserializing batched records") { + val events = Seq( + BlockAdditionEvent(ReceivedBlockInfo(0, None, None, null)), + BatchAllocationEvent(null, null), + BatchCleanupEvent(Nil) + ) + val buffers = events.map(e => Record(ByteBuffer.wrap(Utils.serialize(e)), 0L, null)) + val batched = BatchedWriteAheadLog.aggregate(buffers) + val 
deaggregate = BatchedWriteAheadLog.deaggregate(batched).map(buffer => + Utils.deserialize[ReceivedBlockTrackerLogEvent](buffer.array())) - // Recover old files and generate a second set of log files - val dataToWrite2 = generateRandomData() - manualClock.advance(100000) - writeDataUsingWriteAheadLog(testDir, dataToWrite2, manualClock) - val logFiles2 = getLogFilesInDirectory(testDir) - assert(logFiles2.size > logFiles1.size) + assert(deaggregate.toSeq === events) + } - // Read the files and verify that all the written data can be read - val readData1 = readDataUsingWriteAheadLog(testDir) - assert(readData1 === (dataToWrite1 ++ dataToWrite2)) + test("BatchedWriteAheadLog - failures in wrappedLog get bubbled up") { + when(wal.write(any[ByteBuffer], anyLong)).thenThrow(new RuntimeException("Hello!")) + // the BatchedWriteAheadLog should bubble up any exceptions that may have happened during writes + val batchedWal = new BatchedWriteAheadLog(wal, sparkConf) - // Corrupt the first set of files so that they are basically unreadable - logFiles1.foreach { f => - val raf = new FileOutputStream(f, true).getChannel() - raf.truncate(1) - raf.close() + intercept[RuntimeException] { + val buffer = mock[ByteBuffer] + batchedWal.write(buffer, 2L) } - - // Verify that the corrupted files do not prevent reading of the second set of data - val readData = readDataUsingWriteAheadLog(testDir) - assert(readData === dataToWrite2) } - test("FileBasedWriteAheadLog - do not create directories or files unless write") { - val nonexistentTempPath = File.createTempFile("test", "") - nonexistentTempPath.delete() - assert(!nonexistentTempPath.exists()) + // we make the write requests in separate threads so that we don't block the test thread + private def writeAsync(wal: WriteAheadLog, event: String, time: Long): Promise[Unit] = { + val p = Promise[Unit]() + p.completeWith(Future { + val v = wal.write(event, time) + assert(v === walHandle) + }(walBatchingExecutionContext)) + p + } - val writtenSegment = writeDataManually(generateRandomData(), testFile) - val wal = new FileBasedWriteAheadLog( - new SparkConf(), tempDir.getAbsolutePath, new Configuration(), 1, 1) - assert(!nonexistentTempPath.exists(), "Directory created just by creating log object") - wal.read(writtenSegment.head) - assert(!nonexistentTempPath.exists(), "Directory created just by attempting to read segment") + test("BatchedWriteAheadLog - name log with the highest timestamp of aggregated entries") { + val blockingWal = new BlockingWriteAheadLog(wal, walHandle) + val batchedWal = new BatchedWriteAheadLog(blockingWal, sparkConf) + + val event1 = "hello" + val event2 = "world" + val event3 = "this" + val event4 = "is" + val event5 = "doge" + + // The queue.take() immediately takes the 3, and there is nothing left in the queue at that + // moment. Then the promise blocks the writing of 3. The rest get queued. 
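+    // While the write of event1 is blocked inside the wrapped WAL, events 2-5 pile up in the
+    // batching queue; once the block is released they should be aggregated into a single record
+    // whose log file name carries the highest queued timestamp (12L), as asserted below.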
+ writeAsync(batchedWal, event1, 3L) + eventually(timeout(1 second)) { + assert(blockingWal.isBlocked) + assert(batchedWal.invokePrivate(queueLength()) === 0) + } + // rest of the records will be batched while it takes time for 3 to get written + writeAsync(batchedWal, event2, 5L) + writeAsync(batchedWal, event3, 8L) + // we would like event 5 to be written before event 4 in order to test that they get + // sorted before being aggregated + writeAsync(batchedWal, event5, 12L) + eventually(timeout(1 second)) { + assert(blockingWal.isBlocked) + assert(batchedWal.invokePrivate(queueLength()) === 3) + } + writeAsync(batchedWal, event4, 10L) + eventually(timeout(1 second)) { + assert(walBatchingThreadPool.getActiveCount === 5) + assert(batchedWal.invokePrivate(queueLength()) === 4) + } + blockingWal.allowWrite() + + val buffer = wrapArrayArrayByte(Array(event1)) + val queuedEvents = Set(event2, event3, event4, event5) + + eventually(timeout(1 second)) { + assert(batchedWal.invokePrivate(queueLength()) === 0) + verify(wal, times(1)).write(meq(buffer), meq(3L)) + // the file name should be the timestamp of the last record, as events should be naturally + // in order of timestamp, and we need the last element. + val bufferCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer]) + verify(wal, times(1)).write(bufferCaptor.capture(), meq(12L)) + val records = BatchedWriteAheadLog.deaggregate(bufferCaptor.getValue).map(byteBufferToString) + assert(records.toSet === queuedEvents) + } } -} -object WriteAheadLogSuite { + test("BatchedWriteAheadLog - shutdown properly") { + val batchedWal = new BatchedWriteAheadLog(wal, sparkConf) + batchedWal.close() + verify(wal, times(1)).close() - class MockWriteAheadLog0() extends WriteAheadLog { - override def write(record: ByteBuffer, time: Long): WriteAheadLogRecordHandle = { null } - override def read(handle: WriteAheadLogRecordHandle): ByteBuffer = { null } - override def readAll(): util.Iterator[ByteBuffer] = { null } - override def clean(threshTime: Long, waitForCompletion: Boolean): Unit = { } - override def close(): Unit = { } + intercept[IllegalStateException](batchedWal.write(mock[ByteBuffer], 12L)) } - class MockWriteAheadLog1(val conf: SparkConf) extends MockWriteAheadLog0() + test("BatchedWriteAheadLog - fail everything in queue during shutdown") { + val blockingWal = new BlockingWriteAheadLog(wal, walHandle) + val batchedWal = new BatchedWriteAheadLog(blockingWal, sparkConf) + + val event1 = "hello" + val event2 = "world" + val event3 = "this" - class MockWriteAheadLog2(val conf: SparkConf, x: Int) extends MockWriteAheadLog0() + // The queue.take() immediately takes the 3, and there is nothing left in the queue at that + // moment. Then the promise blocks the writing of 3. The rest get queued. 
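+    // Same pattern as the previous test: event1 blocks inside the wrapped WAL while events 2 and 3
+    // queue up, so that close() runs with writes still outstanding and every queued promise is
+    // expected to complete as a failure.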
+ val promise1 = writeAsync(batchedWal, event1, 3L) + eventually(timeout(1 second)) { + assert(blockingWal.isBlocked) + assert(batchedWal.invokePrivate(queueLength()) === 0) + } + // rest of the records will be batched while it takes time for 3 to get written + val promise2 = writeAsync(batchedWal, event2, 5L) + val promise3 = writeAsync(batchedWal, event3, 8L) + + eventually(timeout(1 second)) { + assert(walBatchingThreadPool.getActiveCount === 3) + assert(blockingWal.isBlocked) + assert(batchedWal.invokePrivate(queueLength()) === 2) // event1 is being written + } + + val writePromises = Seq(promise1, promise2, promise3) + + batchedWal.close() + eventually(timeout(1 second)) { + assert(writePromises.forall(_.isCompleted)) + assert(writePromises.forall(_.future.value.get.isFailure)) // all should have failed + } + } +} +class BatchedWriteAheadLogWithCloseFileAfterWriteSuite + extends CloseFileAfterWriteTests(allowBatching = true, "BatchedWriteAheadLog") + +object WriteAheadLogSuite { private val hadoopConf = new Configuration() /** Write data to a file directly and return an array of the file segments written. */ - def writeDataManually(data: Seq[String], file: String): Seq[FileBasedWriteAheadLogSegment] = { + def writeDataManually( + data: Seq[String], + file: String, + allowBatching: Boolean): Seq[FileBasedWriteAheadLogSegment] = { val segments = new ArrayBuffer[FileBasedWriteAheadLogSegment]() val writer = HdfsUtils.getOutputStream(file, hadoopConf) - data.foreach { item => + def writeToStream(bytes: Array[Byte]): Unit = { val offset = writer.getPos - val bytes = Utils.serialize(item) writer.writeInt(bytes.size) writer.write(bytes) segments += FileBasedWriteAheadLogSegment(file, offset, bytes.size) } + if (allowBatching) { + writeToStream(wrapArrayArrayByte(data.toArray[String]).array()) + } else { + data.foreach { item => + writeToStream(Utils.serialize(item)) + } + } writer.close() segments } @@ -340,8 +608,7 @@ object WriteAheadLogSuite { */ def writeDataUsingWriter( filePath: String, - data: Seq[String] - ): Seq[FileBasedWriteAheadLogSegment] = { + data: Seq[String]): Seq[FileBasedWriteAheadLogSegment] = { val writer = new FileBasedWriteAheadLogWriter(filePath, hadoopConf) val segments = data.map { item => writer.write(item) @@ -354,15 +621,17 @@ object WriteAheadLogSuite { def writeDataUsingWriteAheadLog( logDirectory: String, data: Seq[String], + closeFileAfterWrite: Boolean, + allowBatching: Boolean, manualClock: ManualClock = new ManualClock, - closeLog: Boolean = true - ): FileBasedWriteAheadLog = { + closeLog: Boolean = true, + clockAdvanceTime: Int = 500): WriteAheadLog = { if (manualClock.getTimeMillis() < 100000) manualClock.setTime(10000) - val wal = new FileBasedWriteAheadLog(new SparkConf(), logDirectory, hadoopConf, 1, 1) + val wal = createWriteAheadLog(logDirectory, closeFileAfterWrite, allowBatching) // Ensure that 500 does not get sorted after 2000, so put a high base value. data.foreach { item => - manualClock.advance(500) + manualClock.advance(clockAdvanceTime) wal.write(item, manualClock.getTimeMillis()) } if (closeLog) wal.close() @@ -388,16 +657,16 @@ object WriteAheadLogSuite { } /** Read all the data from a log file directly and return the list of byte buffers. 
*/ - def readDataManually(file: String): Seq[String] = { + def readDataManually[T](file: String): Seq[T] = { val reader = HdfsUtils.getInputStream(file, hadoopConf) - val buffer = new ArrayBuffer[String] + val buffer = new ArrayBuffer[T] try { while (true) { // Read till EOF is thrown val length = reader.readInt() val bytes = new Array[Byte](length) reader.read(bytes) - buffer += Utils.deserialize[String](bytes) + buffer += Utils.deserialize[T](bytes) } } catch { case ex: EOFException => @@ -416,15 +685,17 @@ object WriteAheadLogSuite { } /** Read all the data in the log file in a directory using the WriteAheadLog class. */ - def readDataUsingWriteAheadLog(logDirectory: String): Seq[String] = { - import scala.collection.JavaConversions._ - val wal = new FileBasedWriteAheadLog(new SparkConf(), logDirectory, hadoopConf, 1, 1) - val data = wal.readAll().map(byteBufferToString).toSeq + def readDataUsingWriteAheadLog( + logDirectory: String, + closeFileAfterWrite: Boolean, + allowBatching: Boolean): Seq[String] = { + val wal = createWriteAheadLog(logDirectory, closeFileAfterWrite, allowBatching) + val data = wal.readAll().asScala.map(byteBufferToString).toArray wal.close() data } - /** Get the log files in a direction */ + /** Get the log files in a directory. */ def getLogFilesInDirectory(directory: String): Seq[String] = { val logDirectoryPath = new Path(directory) val fileSystem = HdfsUtils.getFileSystemForPath(logDirectoryPath, hadoopConf) @@ -440,10 +711,31 @@ object WriteAheadLogSuite { } } + def createWriteAheadLog( + logDirectory: String, + closeFileAfterWrite: Boolean, + allowBatching: Boolean): WriteAheadLog = { + val sparkConf = new SparkConf + val wal = new FileBasedWriteAheadLog(sparkConf, logDirectory, hadoopConf, 1, 1, + closeFileAfterWrite) + if (allowBatching) new BatchedWriteAheadLog(wal, sparkConf) else wal + } + def generateRandomData(): Seq[String] = { (1 to 100).map { _.toString } } + def readAndDeserializeDataManually(logFiles: Seq[String], allowBatching: Boolean): Seq[String] = { + if (allowBatching) { + logFiles.flatMap { file => + val data = readDataManually[Array[Array[Byte]]](file) + data.flatMap(byteArray => byteArray.map(Utils.deserialize[String])) + } + } else { + logFiles.flatMap { file => readDataManually[String](file)} + } + } + implicit def stringToByteBuffer(str: String): ByteBuffer = { ByteBuffer.wrap(Utils.serialize(str)) } @@ -451,4 +743,41 @@ object WriteAheadLogSuite { implicit def byteBufferToString(byteBuffer: ByteBuffer): String = { Utils.deserialize[String](byteBuffer.array) } + + def wrapArrayArrayByte[T](records: Array[T]): ByteBuffer = { + ByteBuffer.wrap(Utils.serialize[Array[Array[Byte]]](records.map(Utils.serialize[T]))) + } + + /** + * A wrapper WriteAheadLog that blocks the write function to allow batching with the + * BatchedWriteAheadLog. 
+ */ + class BlockingWriteAheadLog( + wal: WriteAheadLog, + handle: WriteAheadLogRecordHandle) extends WriteAheadLog { + @volatile private var isWriteCalled: Boolean = false + @volatile private var blockWrite: Boolean = true + + override def write(record: ByteBuffer, time: Long): WriteAheadLogRecordHandle = { + isWriteCalled = true + eventually(Eventually.timeout(2 second)) { + assert(!blockWrite) + } + wal.write(record, time) + isWriteCalled = false + handle + } + override def read(segment: WriteAheadLogRecordHandle): ByteBuffer = wal.read(segment) + override def readAll(): JIterator[ByteBuffer] = wal.readAll() + override def clean(threshTime: Long, waitForCompletion: Boolean): Unit = { + wal.clean(threshTime, waitForCompletion) + } + override def close(): Unit = wal.close() + + def allowWrite(): Unit = { + blockWrite = false + } + + def isBlocked: Boolean = isWriteCalled + } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogUtilsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogUtilsSuite.scala new file mode 100644 index 000000000000..bfc5b0cf60fb --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogUtilsSuite.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.util + +import java.nio.ByteBuffer +import java.util + +import scala.reflect.ClassTag + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.{SparkException, SparkConf, SparkFunSuite} +import org.apache.spark.util.Utils + +class WriteAheadLogUtilsSuite extends SparkFunSuite { + import WriteAheadLogUtilsSuite._ + + private val logDir = Utils.createTempDir().getAbsolutePath() + private val hadoopConf = new Configuration() + + def assertDriverLogClass[T <: WriteAheadLog: ClassTag]( + conf: SparkConf, + isBatched: Boolean = false): WriteAheadLog = { + val log = WriteAheadLogUtils.createLogForDriver(conf, logDir, hadoopConf) + if (isBatched) { + assert(log.isInstanceOf[BatchedWriteAheadLog]) + val parentLog = log.asInstanceOf[BatchedWriteAheadLog].wrappedLog + assert(parentLog.getClass === implicitly[ClassTag[T]].runtimeClass) + } else { + assert(log.getClass === implicitly[ClassTag[T]].runtimeClass) + } + log + } + + def assertReceiverLogClass[T <: WriteAheadLog: ClassTag](conf: SparkConf): WriteAheadLog = { + val log = WriteAheadLogUtils.createLogForReceiver(conf, logDir, hadoopConf) + assert(log.getClass === implicitly[ClassTag[T]].runtimeClass) + log + } + + test("log selection and creation") { + + val emptyConf = new SparkConf() // no log configuration + assertDriverLogClass[FileBasedWriteAheadLog](emptyConf, isBatched = true) + assertReceiverLogClass[FileBasedWriteAheadLog](emptyConf) + + // Verify setting driver WAL class + val driverWALConf = new SparkConf().set("spark.streaming.driver.writeAheadLog.class", + classOf[MockWriteAheadLog0].getName()) + assertDriverLogClass[MockWriteAheadLog0](driverWALConf, isBatched = true) + assertReceiverLogClass[FileBasedWriteAheadLog](driverWALConf) + + // Verify setting receiver WAL class + val receiverWALConf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class", + classOf[MockWriteAheadLog0].getName()) + assertDriverLogClass[FileBasedWriteAheadLog](receiverWALConf, isBatched = true) + assertReceiverLogClass[MockWriteAheadLog0](receiverWALConf) + + // Verify setting receiver WAL class with 1-arg constructor + val receiverWALConf2 = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class", + classOf[MockWriteAheadLog1].getName()) + assertReceiverLogClass[MockWriteAheadLog1](receiverWALConf2) + + // Verify failure setting receiver WAL class with 2-arg constructor + intercept[SparkException] { + val receiverWALConf3 = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class", + classOf[MockWriteAheadLog2].getName()) + assertReceiverLogClass[MockWriteAheadLog1](receiverWALConf3) + } + } + + test("wrap WriteAheadLog in BatchedWriteAheadLog when batching is enabled") { + def getBatchedSparkConf: SparkConf = + new SparkConf().set("spark.streaming.driver.writeAheadLog.allowBatching", "true") + + val justBatchingConf = getBatchedSparkConf + assertDriverLogClass[FileBasedWriteAheadLog](justBatchingConf, isBatched = true) + assertReceiverLogClass[FileBasedWriteAheadLog](justBatchingConf) + + // Verify setting driver WAL class + val driverWALConf = getBatchedSparkConf.set("spark.streaming.driver.writeAheadLog.class", + classOf[MockWriteAheadLog0].getName()) + assertDriverLogClass[MockWriteAheadLog0](driverWALConf, isBatched = true) + assertReceiverLogClass[FileBasedWriteAheadLog](driverWALConf) + + // Verify receivers are not wrapped + val receiverWALConf = getBatchedSparkConf.set("spark.streaming.receiver.writeAheadLog.class", + 
classOf[MockWriteAheadLog0].getName()) + assertDriverLogClass[FileBasedWriteAheadLog](receiverWALConf, isBatched = true) + assertReceiverLogClass[MockWriteAheadLog0](receiverWALConf) + } + + test("batching is enabled by default in WriteAheadLog") { + val conf = new SparkConf() + assert(WriteAheadLogUtils.isBatchingEnabled(conf, isDriver = true)) + // batching is not valid for receiver WALs + assert(!WriteAheadLogUtils.isBatchingEnabled(conf, isDriver = false)) + } + + test("closeFileAfterWrite is disabled by default in WriteAheadLog") { + val conf = new SparkConf() + assert(!WriteAheadLogUtils.shouldCloseFileAfterWrite(conf, isDriver = true)) + assert(!WriteAheadLogUtils.shouldCloseFileAfterWrite(conf, isDriver = false)) + } +} + +object WriteAheadLogUtilsSuite { + + class MockWriteAheadLog0() extends WriteAheadLog { + override def write(record: ByteBuffer, time: Long): WriteAheadLogRecordHandle = { null } + override def read(handle: WriteAheadLogRecordHandle): ByteBuffer = { null } + override def readAll(): util.Iterator[ByteBuffer] = { null } + override def clean(threshTime: Long, waitForCompletion: Boolean): Unit = { } + override def close(): Unit = { } + } + + class MockWriteAheadLog1(val conf: SparkConf) extends MockWriteAheadLog0() + + class MockWriteAheadLog2(val conf: SparkConf, x: Int) extends MockWriteAheadLog0() +} diff --git a/tags/pom.xml b/tags/pom.xml new file mode 100644 index 000000000000..ca93722e7334 --- /dev/null +++ b/tags/pom.xml @@ -0,0 +1,50 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.10 + 1.6.0-SNAPSHOT + ../pom.xml + + + org.apache.spark + spark-test-tags_2.10 + jar + Spark Project Test Tags + http://spark.apache.org/ + + test-tags + + + + + org.scalatest + scalatest_${scala.binary.version} + compile + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/tags/src/main/java/org/apache/spark/tags/DockerTest.java b/tags/src/main/java/org/apache/spark/tags/DockerTest.java new file mode 100644 index 000000000000..0fecf3b8f979 --- /dev/null +++ b/tags/src/main/java/org/apache/spark/tags/DockerTest.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.tags; + +import java.lang.annotation.*; +import org.scalatest.TagAnnotation; + +@TagAnnotation +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.METHOD, ElementType.TYPE}) +public @interface DockerTest { } diff --git a/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java b/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java new file mode 100644 index 000000000000..83279e5e93c0 --- /dev/null +++ b/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.tags; + +import java.lang.annotation.*; + +import org.scalatest.TagAnnotation; + +@TagAnnotation +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.METHOD, ElementType.TYPE}) +public @interface ExtendedHiveTest { } diff --git a/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java b/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java new file mode 100644 index 000000000000..108300168e17 --- /dev/null +++ b/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.tags; + +import java.lang.annotation.*; + +import org.scalatest.TagAnnotation; + +@TagAnnotation +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.METHOD, ElementType.TYPE}) +public @interface ExtendedYarnTest { } diff --git a/tools/pom.xml b/tools/pom.xml index 298ee2348b58..1e64f280e5be 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala index 9418beb6b3e3..5155daa6d17b 100644 --- a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala +++ b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala @@ -22,7 +22,7 @@ import java.io.File import java.util.jar.JarFile import scala.collection.mutable -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.runtimeMirror import scala.reflect.runtime.{universe => unv} import scala.util.Try @@ -72,7 +72,9 @@ object GenerateMIMAIgnore { val classSymbol = mirror.classSymbol(Class.forName(className, false, classLoader)) val moduleSymbol = mirror.staticModule(className) val directlyPrivateSpark = - isPackagePrivate(classSymbol) || isPackagePrivateModule(moduleSymbol) + isPackagePrivate(classSymbol) || + isPackagePrivateModule(moduleSymbol) || + classSymbol.isPrivate val developerApi = isDeveloperApi(classSymbol) || isDeveloperApi(moduleSymbol) val experimental = isExperimental(classSymbol) || isExperimental(moduleSymbol) /* Inner classes defined within a private[spark] class or object are effectively @@ -161,7 +163,7 @@ object GenerateMIMAIgnore { val path = packageName.replace('.', '/') val resources = classLoader.getResources(path) - val jars = resources.filter(x => x.getProtocol == "jar") + val jars = resources.asScala.filter(_.getProtocol == "jar") .map(_.getFile.split(":")(1).split("!")(0)).toSeq jars.flatMap(getClassesFromJar(_, path)) @@ -175,7 +177,7 @@ object GenerateMIMAIgnore { private def getClassesFromJar(jarPath: String, packageName: String) = { import scala.collection.mutable val jar = new JarFile(new File(jarPath)) - val enums = jar.entries().map(_.getName).filter(_.startsWith(packageName)) + val enums = jar.entries().asScala.map(_.getName).filter(_.startsWith(packageName)) val classes = mutable.HashSet[Class[_]]() for (entry <- enums if entry.endsWith(".class")) { try { diff --git a/unsafe/pom.xml b/unsafe/pom.xml index 89475ee3cf5a..a1c1111364ee 100644 --- a/unsafe/pom.xml +++ b/unsafe/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml @@ -36,6 +36,10 @@ + + com.twitter + chill_${scala.binary.version} + @@ -56,14 +60,8 @@ - junit - junit - test - - - com.novocode - junit-interface - test + org.apache.spark + spark-test-tags_${scala.binary.version} org.mockito diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java new file mode 100644 index 000000000000..0d6b215fe5aa --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe; + +import java.lang.reflect.Field; + +import sun.misc.Unsafe; + +public final class Platform { + + private static final Unsafe _UNSAFE; + + public static final int BYTE_ARRAY_OFFSET; + + public static final int INT_ARRAY_OFFSET; + + public static final int LONG_ARRAY_OFFSET; + + public static final int DOUBLE_ARRAY_OFFSET; + + public static int getInt(Object object, long offset) { + return _UNSAFE.getInt(object, offset); + } + + public static void putInt(Object object, long offset, int value) { + _UNSAFE.putInt(object, offset, value); + } + + public static boolean getBoolean(Object object, long offset) { + return _UNSAFE.getBoolean(object, offset); + } + + public static void putBoolean(Object object, long offset, boolean value) { + _UNSAFE.putBoolean(object, offset, value); + } + + public static byte getByte(Object object, long offset) { + return _UNSAFE.getByte(object, offset); + } + + public static void putByte(Object object, long offset, byte value) { + _UNSAFE.putByte(object, offset, value); + } + + public static short getShort(Object object, long offset) { + return _UNSAFE.getShort(object, offset); + } + + public static void putShort(Object object, long offset, short value) { + _UNSAFE.putShort(object, offset, value); + } + + public static long getLong(Object object, long offset) { + return _UNSAFE.getLong(object, offset); + } + + public static void putLong(Object object, long offset, long value) { + _UNSAFE.putLong(object, offset, value); + } + + public static float getFloat(Object object, long offset) { + return _UNSAFE.getFloat(object, offset); + } + + public static void putFloat(Object object, long offset, float value) { + _UNSAFE.putFloat(object, offset, value); + } + + public static double getDouble(Object object, long offset) { + return _UNSAFE.getDouble(object, offset); + } + + public static void putDouble(Object object, long offset, double value) { + _UNSAFE.putDouble(object, offset, value); + } + + public static Object getObjectVolatile(Object object, long offset) { + return _UNSAFE.getObjectVolatile(object, offset); + } + + public static void putObjectVolatile(Object object, long offset, Object value) { + _UNSAFE.putObjectVolatile(object, offset, value); + } + + public static long allocateMemory(long size) { + return _UNSAFE.allocateMemory(size); + } + + public static void freeMemory(long address) { + _UNSAFE.freeMemory(address); + } + + public static void copyMemory( + Object src, long srcOffset, Object dst, long dstOffset, long length) { + // Check if dstOffset is before or after srcOffset to determine if we should copy + // forward or backwards. This is necessary in case src and dst overlap. 
+ if (dstOffset < srcOffset) { + while (length > 0) { + long size = Math.min(length, UNSAFE_COPY_THRESHOLD); + _UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); + length -= size; + srcOffset += size; + dstOffset += size; + } + } else { + srcOffset += length; + dstOffset += length; + while (length > 0) { + long size = Math.min(length, UNSAFE_COPY_THRESHOLD); + srcOffset -= size; + dstOffset -= size; + _UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); + length -= size; + } + + } + } + + /** + * Raises an exception bypassing compiler checks for checked exceptions. + */ + public static void throwException(Throwable t) { + _UNSAFE.throwException(t); + } + + /** + * Limits the number of bytes to copy per {@link Unsafe#copyMemory(long, long, long)} to + * allow safepoint polling during a large copy. + */ + private static final long UNSAFE_COPY_THRESHOLD = 1024L * 1024L; + + static { + sun.misc.Unsafe unsafe; + try { + Field unsafeField = Unsafe.class.getDeclaredField("theUnsafe"); + unsafeField.setAccessible(true); + unsafe = (sun.misc.Unsafe) unsafeField.get(null); + } catch (Throwable cause) { + unsafe = null; + } + _UNSAFE = unsafe; + + if (_UNSAFE != null) { + BYTE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(byte[].class); + INT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(int[].class); + LONG_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(long[].class); + DOUBLE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(double[].class); + } else { + BYTE_ARRAY_OFFSET = 0; + INT_ARRAY_OFFSET = 0; + LONG_ARRAY_OFFSET = 0; + DOUBLE_ARRAY_OFFSET = 0; + } + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java b/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java deleted file mode 100644 index 192c6714b240..000000000000 --- a/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java +++ /dev/null @@ -1,161 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.unsafe; - -import java.lang.reflect.Field; - -import sun.misc.Unsafe; - -public final class PlatformDependent { - - /** - * Facade in front of {@link sun.misc.Unsafe}, used to avoid directly exposing Unsafe outside of - * this package. This also lets us avoid accidental use of deprecated methods. 
- */ - public static final class UNSAFE { - - private UNSAFE() { } - - public static int getInt(Object object, long offset) { - return _UNSAFE.getInt(object, offset); - } - - public static void putInt(Object object, long offset, int value) { - _UNSAFE.putInt(object, offset, value); - } - - public static boolean getBoolean(Object object, long offset) { - return _UNSAFE.getBoolean(object, offset); - } - - public static void putBoolean(Object object, long offset, boolean value) { - _UNSAFE.putBoolean(object, offset, value); - } - - public static byte getByte(Object object, long offset) { - return _UNSAFE.getByte(object, offset); - } - - public static void putByte(Object object, long offset, byte value) { - _UNSAFE.putByte(object, offset, value); - } - - public static short getShort(Object object, long offset) { - return _UNSAFE.getShort(object, offset); - } - - public static void putShort(Object object, long offset, short value) { - _UNSAFE.putShort(object, offset, value); - } - - public static long getLong(Object object, long offset) { - return _UNSAFE.getLong(object, offset); - } - - public static void putLong(Object object, long offset, long value) { - _UNSAFE.putLong(object, offset, value); - } - - public static float getFloat(Object object, long offset) { - return _UNSAFE.getFloat(object, offset); - } - - public static void putFloat(Object object, long offset, float value) { - _UNSAFE.putFloat(object, offset, value); - } - - public static double getDouble(Object object, long offset) { - return _UNSAFE.getDouble(object, offset); - } - - public static void putDouble(Object object, long offset, double value) { - _UNSAFE.putDouble(object, offset, value); - } - - public static long allocateMemory(long size) { - return _UNSAFE.allocateMemory(size); - } - - public static void freeMemory(long address) { - _UNSAFE.freeMemory(address); - } - - } - - private static final Unsafe _UNSAFE; - - public static final int BYTE_ARRAY_OFFSET; - - public static final int INT_ARRAY_OFFSET; - - public static final int LONG_ARRAY_OFFSET; - - public static final int DOUBLE_ARRAY_OFFSET; - - /** - * Limits the number of bytes to copy per {@link Unsafe#copyMemory(long, long, long)} to - * allow safepoint polling during a large copy. - */ - private static final long UNSAFE_COPY_THRESHOLD = 1024L * 1024L; - - static { - sun.misc.Unsafe unsafe; - try { - Field unsafeField = Unsafe.class.getDeclaredField("theUnsafe"); - unsafeField.setAccessible(true); - unsafe = (sun.misc.Unsafe) unsafeField.get(null); - } catch (Throwable cause) { - unsafe = null; - } - _UNSAFE = unsafe; - - if (_UNSAFE != null) { - BYTE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(byte[].class); - INT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(int[].class); - LONG_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(long[].class); - DOUBLE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(double[].class); - } else { - BYTE_ARRAY_OFFSET = 0; - INT_ARRAY_OFFSET = 0; - LONG_ARRAY_OFFSET = 0; - DOUBLE_ARRAY_OFFSET = 0; - } - } - - static public void copyMemory( - Object src, - long srcOffset, - Object dst, - long dstOffset, - long length) { - while (length > 0) { - long size = Math.min(length, UNSAFE_COPY_THRESHOLD); - _UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); - length -= size; - srcOffset += size; - dstOffset += size; - } - } - - /** - * Raises an exception bypassing compiler checks for checked exceptions. 
- */ - public static void throwException(Throwable t) { - _UNSAFE.throwException(t); - } -} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java index cf693d01a4f5..cf42877bf9fd 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -17,7 +17,7 @@ package org.apache.spark.unsafe.array; -import static org.apache.spark.unsafe.PlatformDependent.*; +import org.apache.spark.unsafe.Platform; public class ByteArrayMethods { @@ -25,6 +25,12 @@ private ByteArrayMethods() { // Private constructor, since this class only contains static methods. } + /** Returns the next number greater or equal num that is power of 2. */ + public static long nextPowerOf2(long num) { + final long highBit = Long.highestOneBit(num); + return (highBit == num) ? num : highBit << 1; + } + public static int roundNumberOfBytesToNearestWord(int numBytes) { int remainder = numBytes & 0x07; // This is equivalent to `numBytes % 8` if (remainder == 0) { @@ -39,20 +45,18 @@ public static int roundNumberOfBytesToNearestWord(int numBytes) { * @return true if the arrays are equal, false otherwise */ public static boolean arrayEquals( - Object leftBase, - long leftOffset, - Object rightBase, - long rightOffset, - final long length) { + Object leftBase, long leftOffset, Object rightBase, long rightOffset, final long length) { int i = 0; while (i <= length - 8) { - if (UNSAFE.getLong(leftBase, leftOffset + i) != UNSAFE.getLong(rightBase, rightOffset + i)) { + if (Platform.getLong(leftBase, leftOffset + i) != + Platform.getLong(rightBase, rightOffset + i)) { return false; } i += 8; } while (i < length) { - if (UNSAFE.getByte(leftBase, leftOffset + i) != UNSAFE.getByte(rightBase, rightOffset + i)) { + if (Platform.getByte(leftBase, leftOffset + i) != + Platform.getByte(rightBase, rightOffset + i)) { return false; } i += 1; diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java b/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java index 18d1f0d2d7eb..1a3cdff63826 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java @@ -17,7 +17,7 @@ package org.apache.spark.unsafe.array; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.MemoryBlock; /** @@ -39,7 +39,6 @@ public final class LongArray { private final long length; public LongArray(MemoryBlock memory) { - assert memory.size() % WIDTH == 0 : "Memory not aligned (" + memory.size() + ")"; assert memory.size() < (long) Integer.MAX_VALUE * 8: "Array size > 4 billion elements"; this.memory = memory; this.baseObj = memory.getBaseObject(); @@ -51,6 +50,14 @@ public MemoryBlock memoryBlock() { return memory; } + public Object getBaseObject() { + return baseObj; + } + + public long getBaseOffset() { + return baseOffset; + } + /** * Returns the number of elements this array can hold. */ @@ -58,13 +65,22 @@ public long size() { return length; } + /** + * Fill this all with 0L. + */ + public void zeroOut() { + for (long off = baseOffset; off < baseOffset + length * WIDTH; off += WIDTH) { + Platform.putLong(baseObj, off, 0); + } + } + /** * Sets the value at position {@code index}. 
*/ public void set(int index, long value) { assert index >= 0 : "index (" + index + ") should >= 0"; assert index < length : "index (" + index + ") should < length (" + length + ")"; - PlatformDependent.UNSAFE.putLong(baseObj, baseOffset + index * WIDTH, value); + Platform.putLong(baseObj, baseOffset + index * WIDTH, value); } /** @@ -73,6 +89,6 @@ public void set(int index, long value) { public long get(int index) { assert index >= 0 : "index (" + index + ") should >= 0"; assert index < length : "index (" + index + ") should < length (" + length + ")"; - return PlatformDependent.UNSAFE.getLong(baseObj, baseOffset + index * WIDTH); + return Platform.getLong(baseObj, baseOffset + index * WIDTH); } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java deleted file mode 100644 index 7c124173b0bb..000000000000 --- a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.unsafe.bitset; - -import org.apache.spark.unsafe.array.LongArray; -import org.apache.spark.unsafe.memory.MemoryBlock; - -/** - * A fixed size uncompressed bit set backed by a {@link LongArray}. - * - * Each bit occupies exactly one bit of storage. - */ -public final class BitSet { - - /** A long array for the bits. */ - private final LongArray words; - - /** Length of the long array. */ - private final int numWords; - - private final Object baseObject; - private final long baseOffset; - - /** - * Creates a new {@link BitSet} using the specified memory block. Size of the memory block must be - * multiple of 8 bytes (i.e. 64 bits). - */ - public BitSet(MemoryBlock memory) { - words = new LongArray(memory); - assert (words.size() <= Integer.MAX_VALUE); - numWords = (int) words.size(); - baseObject = words.memoryBlock().getBaseObject(); - baseOffset = words.memoryBlock().getBaseOffset(); - } - - public MemoryBlock memoryBlock() { - return words.memoryBlock(); - } - - /** - * Returns the number of bits in this {@code BitSet}. - */ - public long capacity() { - return numWords * 64; - } - - /** - * Sets the bit at the specified index to {@code true}. - */ - public void set(int index) { - assert index < numWords * 64 : "index (" + index + ") should < length (" + numWords * 64 + ")"; - BitSetMethods.set(baseObject, baseOffset, index); - } - - /** - * Sets the bit at the specified index to {@code false}. - */ - public void unset(int index) { - assert index < numWords * 64 : "index (" + index + ") should < length (" + numWords * 64 + ")"; - BitSetMethods.unset(baseObject, baseOffset, index); - } - - /** - * Returns {@code true} if the bit is set at the specified index. 
- */ - public boolean isSet(int index) { - assert index < numWords * 64 : "index (" + index + ") should < length (" + numWords * 64 + ")"; - return BitSetMethods.isSet(baseObject, baseOffset, index); - } - - /** - * Returns the index of the first bit that is set to true that occurs on or after the - * specified starting index. If no such bit exists then {@code -1} is returned. - *
- * To iterate over the true bits in a BitSet, use the following loop:
- * <pre>
- * <code>
- *  for (long i = bs.nextSetBit(0); i >= 0; i = bs.nextSetBit(i + 1)) {
- *    // operate on index i here
- *  }
- * </code>
- * </pre>
        - * - * @param fromIndex the index to start checking from (inclusive) - * @return the index of the next set bit, or -1 if there is no such bit - */ - public int nextSetBit(int fromIndex) { - return BitSetMethods.nextSetBit(baseObject, baseOffset, fromIndex, numWords); - } - - /** - * Returns {@code true} if any bit is set. - */ - public boolean anySet() { - return BitSetMethods.anySet(baseObject, baseOffset, numWords); - } - -} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java index 27462c7fa5e6..7857bf66a72a 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java @@ -17,7 +17,7 @@ package org.apache.spark.unsafe.bitset; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; /** * Methods for working with fixed-size uncompressed bitsets. @@ -41,8 +41,8 @@ public static void set(Object baseObject, long baseOffset, int index) { assert index >= 0 : "index (" + index + ") should >= 0"; final long mask = 1L << (index & 0x3f); // mod 64 and shift final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE; - final long word = PlatformDependent.UNSAFE.getLong(baseObject, wordOffset); - PlatformDependent.UNSAFE.putLong(baseObject, wordOffset, word | mask); + final long word = Platform.getLong(baseObject, wordOffset); + Platform.putLong(baseObject, wordOffset, word | mask); } /** @@ -52,8 +52,8 @@ public static void unset(Object baseObject, long baseOffset, int index) { assert index >= 0 : "index (" + index + ") should >= 0"; final long mask = 1L << (index & 0x3f); // mod 64 and shift final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE; - final long word = PlatformDependent.UNSAFE.getLong(baseObject, wordOffset); - PlatformDependent.UNSAFE.putLong(baseObject, wordOffset, word & ~mask); + final long word = Platform.getLong(baseObject, wordOffset); + Platform.putLong(baseObject, wordOffset, word & ~mask); } /** @@ -63,7 +63,7 @@ public static boolean isSet(Object baseObject, long baseOffset, int index) { assert index >= 0 : "index (" + index + ") should >= 0"; final long mask = 1L << (index & 0x3f); // mod 64 and shift final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE; - final long word = PlatformDependent.UNSAFE.getLong(baseObject, wordOffset); + final long word = Platform.getLong(baseObject, wordOffset); return (word & mask) != 0; } @@ -73,7 +73,7 @@ public static boolean isSet(Object baseObject, long baseOffset, int index) { public static boolean anySet(Object baseObject, long baseOffset, long bitSetWidthInWords) { long addr = baseOffset; for (int i = 0; i < bitSetWidthInWords; i++, addr += WORD_SIZE) { - if (PlatformDependent.UNSAFE.getLong(baseObject, addr) != 0) { + if (Platform.getLong(baseObject, addr) != 0) { return true; } } @@ -109,8 +109,7 @@ public static int nextSetBit( // Try to find the next set bit in the current word final int subIndex = fromIndex & 0x3f; - long word = - PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + wi * WORD_SIZE) >> subIndex; + long word = Platform.getLong(baseObject, baseOffset + wi * WORD_SIZE) >> subIndex; if (word != 0) { return (wi << 6) + subIndex + java.lang.Long.numberOfTrailingZeros(word); } @@ -118,7 +117,7 @@ public static int nextSetBit( // Find the next set bit in the rest of the words wi += 1; while (wi < bitsetSizeInWords) { - word = 
PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + wi * WORD_SIZE); + word = Platform.getLong(baseObject, baseOffset + wi * WORD_SIZE); if (word != 0) { return (wi << 6) + java.lang.Long.numberOfTrailingZeros(word); } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java index 61f483ced321..4276f25c2165 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java @@ -17,7 +17,7 @@ package org.apache.spark.unsafe.hash; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; /** * 32-bit Murmur3 hasher. This is based on Guava's Murmur3_32HashFunction. @@ -53,7 +53,7 @@ public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, i assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)"; int h1 = seed; for (int i = 0; i < lengthInBytes; i += 4) { - int halfWord = PlatformDependent.UNSAFE.getInt(base, offset + i); + int halfWord = Platform.getInt(base, offset + i); int k1 = mixK1(halfWord); h1 = mixH1(h1, k1); } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java deleted file mode 100644 index cbbe8594627a..000000000000 --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.unsafe.memory; - -import java.lang.ref.WeakReference; -import java.util.HashMap; -import java.util.LinkedList; -import java.util.Map; -import javax.annotation.concurrent.GuardedBy; - -/** - * Manages memory for an executor. Individual operators / tasks allocate memory through - * {@link TaskMemoryManager} objects, which obtain their memory from ExecutorMemoryManager. - */ -public class ExecutorMemoryManager { - - /** - * Allocator, exposed for enabling untracked allocations of temporary data structures. - */ - public final MemoryAllocator allocator; - - /** - * Tracks whether memory will be allocated on the JVM heap or off-heap using sun.misc.Unsafe. - */ - final boolean inHeap; - - @GuardedBy("this") - private final Map>> bufferPoolsBySize = - new HashMap>>(); - - private static final int POOLING_THRESHOLD_BYTES = 1024 * 1024; - - /** - * Construct a new ExecutorMemoryManager. 
- * - * @param allocator the allocator that will be used - */ - public ExecutorMemoryManager(MemoryAllocator allocator) { - this.inHeap = allocator instanceof HeapMemoryAllocator; - this.allocator = allocator; - } - - /** - * Returns true if allocations of the given size should go through the pooling mechanism and - * false otherwise. - */ - private boolean shouldPool(long size) { - // Very small allocations are less likely to benefit from pooling. - // At some point, we should explore supporting pooling for off-heap memory, but for now we'll - // ignore that case in the interest of simplicity. - return size >= POOLING_THRESHOLD_BYTES && allocator instanceof HeapMemoryAllocator; - } - - /** - * Allocates a contiguous block of memory. Note that the allocated memory is not guaranteed - * to be zeroed out (call `zero()` on the result if this is necessary). - */ - MemoryBlock allocate(long size) throws OutOfMemoryError { - if (shouldPool(size)) { - synchronized (this) { - final LinkedList> pool = bufferPoolsBySize.get(size); - if (pool != null) { - while (!pool.isEmpty()) { - final WeakReference blockReference = pool.pop(); - final MemoryBlock memory = blockReference.get(); - if (memory != null) { - assert (memory.size() == size); - return memory; - } - } - bufferPoolsBySize.remove(size); - } - } - return allocator.allocate(size); - } else { - return allocator.allocate(size); - } - } - - void free(MemoryBlock memory) { - final long size = memory.size(); - if (shouldPool(size)) { - synchronized (this) { - LinkedList> pool = bufferPoolsBySize.get(size); - if (pool == null) { - pool = new LinkedList>(); - bufferPoolsBySize.put(size, pool); - } - pool.add(new WeakReference(memory)); - } - } else { - allocator.free(memory); - } - } - -} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java index bbe83d36cf36..09847cec9c4c 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java @@ -17,19 +17,70 @@ package org.apache.spark.unsafe.memory; +import javax.annotation.concurrent.GuardedBy; +import java.lang.ref.WeakReference; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.Map; + +import org.apache.spark.unsafe.Platform; + /** * A simple {@link MemoryAllocator} that can allocate up to 16GB using a JVM long primitive array. */ public class HeapMemoryAllocator implements MemoryAllocator { + @GuardedBy("this") + private final Map>> bufferPoolsBySize = + new HashMap<>(); + + private static final int POOLING_THRESHOLD_BYTES = 1024 * 1024; + + /** + * Returns true if allocations of the given size should go through the pooling mechanism and + * false otherwise. + */ + private boolean shouldPool(long size) { + // Very small allocations are less likely to benefit from pooling. 
+ return size >= POOLING_THRESHOLD_BYTES; + } + @Override public MemoryBlock allocate(long size) throws OutOfMemoryError { - long[] array = new long[(int) (size / 8)]; - return MemoryBlock.fromLongArray(array); + if (shouldPool(size)) { + synchronized (this) { + final LinkedList> pool = bufferPoolsBySize.get(size); + if (pool != null) { + while (!pool.isEmpty()) { + final WeakReference blockReference = pool.pop(); + final MemoryBlock memory = blockReference.get(); + if (memory != null) { + assert (memory.size() == size); + return memory; + } + } + bufferPoolsBySize.remove(size); + } + } + } + long[] array = new long[(int) ((size + 7) / 8)]; + return new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size); } @Override public void free(MemoryBlock memory) { - // Do nothing + final long size = memory.size(); + if (shouldPool(size)) { + synchronized (this) { + LinkedList> pool = bufferPoolsBySize.get(size); + if (pool == null) { + pool = new LinkedList<>(); + bufferPoolsBySize.put(size, pool); + } + pool.add(new WeakReference<>(memory)); + } + } else { + // Do nothing + } } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java index 91be46ba21ff..e3e79471154d 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java @@ -19,7 +19,7 @@ import javax.annotation.Nullable; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; /** * A consecutive block of memory, starting at a {@link MemoryLocation} with a fixed size. @@ -30,9 +30,10 @@ public class MemoryBlock extends MemoryLocation { /** * Optional page number; used when this MemoryBlock represents a page allocated by a - * MemoryManager. This is package-private and is modified by MemoryManager. + * TaskMemoryManager. This field is public so that it can be modified by the TaskMemoryManager, + * which lives in a different package. */ - int pageNumber = -1; + public int pageNumber = -1; public MemoryBlock(@Nullable Object obj, long offset, long length) { super(obj, offset); @@ -50,6 +51,6 @@ public long size() { * Creates a memory block pointing to the memory used by the long array. */ public static MemoryBlock fromLongArray(final long[] array) { - return new MemoryBlock(array, PlatformDependent.LONG_ARRAY_OFFSET, array.length * 8); + return new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, array.length * 8); } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java index 15898771fef2..98ce711176e4 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java @@ -17,7 +17,7 @@ package org.apache.spark.unsafe.memory; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; /** * A simple {@link MemoryAllocator} that uses {@code Unsafe} to allocate off-heap memory. 
@@ -26,7 +26,7 @@ public class UnsafeMemoryAllocator implements MemoryAllocator { @Override public MemoryBlock allocate(long size) throws OutOfMemoryError { - long address = PlatformDependent.UNSAFE.allocateMemory(size); + long address = Platform.allocateMemory(size); return new MemoryBlock(null, address, size); } @@ -34,6 +34,6 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { public void free(MemoryBlock memory) { assert (memory.obj == null) : "baseObject not null; are you trying to use the off-heap allocator to free on-heap memory?"; - PlatformDependent.UNSAFE.freeMemory(memory.offset); + Platform.freeMemory(memory.offset); } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java index 69b0e206cef1..3ced2094f5e6 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java @@ -17,9 +17,13 @@ package org.apache.spark.unsafe.types; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; -public class ByteArray { +import java.util.Arrays; + +public final class ByteArray { + + public static final byte[] EMPTY_BYTE = new byte[0]; /** * Writes the content of a byte array into a memory address, identified by an object and an @@ -27,12 +31,47 @@ public class ByteArray { * hold all the bytes in this string. */ public static void writeToMemory(byte[] src, Object target, long targetOffset) { - PlatformDependent.copyMemory( - src, - PlatformDependent.BYTE_ARRAY_OFFSET, - target, - targetOffset, - src.length - ); + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET, target, targetOffset, src.length); + } + + /** + * Returns a 64-bit integer that can be used as the prefix used in sorting. 
+ */ + public static long getPrefix(byte[] bytes) { + if (bytes == null) { + return 0L; + } else { + final int minLen = Math.min(bytes.length, 8); + long p = 0; + for (int i = 0; i < minLen; ++i) { + p |= (128L + Platform.getByte(bytes, Platform.BYTE_ARRAY_OFFSET + i)) + << (56 - 8 * i); + } + return p; + } + } + + public static byte[] subStringSQL(byte[] bytes, int pos, int len) { + // This pos calculation is according to UTF8String#subStringSQL + if (pos > bytes.length) { + return EMPTY_BYTE; + } + int start = 0; + int end; + if (pos > 0) { + start = pos - 1; + } else if (pos < 0) { + start = bytes.length + pos; + } + if ((bytes.length - start) < len) { + end = bytes.length; + } else { + end = start + len; + } + start = Math.max(start, 0); // underflow + if (start >= end) { + return EMPTY_BYTE; + } + return Arrays.copyOfRange(bytes, start, end); } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java index 92a5e4f86f23..30e175807636 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java @@ -50,6 +50,14 @@ private static String unitRegex(String unit) { unitRegex("week") + unitRegex("day") + unitRegex("hour") + unitRegex("minute") + unitRegex("second") + unitRegex("millisecond") + unitRegex("microsecond")); + private static Pattern yearMonthPattern = + Pattern.compile("^(?:['|\"])?([+|-])?(\\d+)-(\\d+)(?:['|\"])?$"); + + private static Pattern dayTimePattern = + Pattern.compile("^(?:['|\"])?([+|-])?(\\d+) (\\d+):(\\d+):(\\d+)(\\.(\\d+))?(?:['|\"])?$"); + + private static Pattern quoteTrimPattern = Pattern.compile("^(?:['|\"])?(.*?)(?:['|\"])?$"); + private static long toLong(String s) { if (s == null) { return 0; @@ -79,6 +87,154 @@ public static CalendarInterval fromString(String s) { } } + public static long toLongWithRange(String fieldName, + String s, long minValue, long maxValue) throws IllegalArgumentException { + long result = 0; + if (s != null) { + result = Long.valueOf(s); + if (result < minValue || result > maxValue) { + throw new IllegalArgumentException(String.format("%s %d outside range [%d, %d]", + fieldName, result, minValue, maxValue)); + } + } + return result; + } + + /** + * Parse YearMonth string in form: [-]YYYY-MM + * + * adapted from HiveIntervalYearMonth.valueOf + */ + public static CalendarInterval fromYearMonthString(String s) throws IllegalArgumentException { + CalendarInterval result = null; + if (s == null) { + throw new IllegalArgumentException("Interval year-month string was null"); + } + s = s.trim(); + Matcher m = yearMonthPattern.matcher(s); + if (!m.matches()) { + throw new IllegalArgumentException( + "Interval string does not match year-month format of 'y-m': " + s); + } else { + try { + int sign = m.group(1) != null && m.group(1).equals("-") ? 
-1 : 1; + int years = (int) toLongWithRange("year", m.group(2), 0, Integer.MAX_VALUE); + int months = (int) toLongWithRange("month", m.group(3), 0, 11); + result = new CalendarInterval(sign * (years * 12 + months), 0); + } catch (Exception e) { + throw new IllegalArgumentException( + "Error parsing interval year-month string: " + e.getMessage(), e); + } + } + return result; + } + + /** + * Parse dayTime string in form: [-]d HH:mm:ss.nnnnnnnnn + * + * adapted from HiveIntervalDayTime.valueOf + */ + public static CalendarInterval fromDayTimeString(String s) throws IllegalArgumentException { + CalendarInterval result = null; + if (s == null) { + throw new IllegalArgumentException("Interval day-time string was null"); + } + s = s.trim(); + Matcher m = dayTimePattern.matcher(s); + if (!m.matches()) { + throw new IllegalArgumentException( + "Interval string does not match day-time format of 'd h:m:s.n': " + s); + } else { + try { + int sign = m.group(1) != null && m.group(1).equals("-") ? -1 : 1; + long days = toLongWithRange("day", m.group(2), 0, Integer.MAX_VALUE); + long hours = toLongWithRange("hour", m.group(3), 0, 23); + long minutes = toLongWithRange("minute", m.group(4), 0, 59); + long seconds = toLongWithRange("second", m.group(5), 0, 59); + // Hive allow nanosecond precision interval + long nanos = toLongWithRange("nanosecond", m.group(7), 0L, 999999999L); + result = new CalendarInterval(0, sign * ( + days * MICROS_PER_DAY + hours * MICROS_PER_HOUR + minutes * MICROS_PER_MINUTE + + seconds * MICROS_PER_SECOND + nanos / 1000L)); + } catch (Exception e) { + throw new IllegalArgumentException( + "Error parsing interval day-time string: " + e.getMessage(), e); + } + } + return result; + } + + public static CalendarInterval fromSingleUnitString(String unit, String s) + throws IllegalArgumentException { + + CalendarInterval result = null; + if (s == null) { + throw new IllegalArgumentException(String.format("Interval %s string was null", unit)); + } + s = s.trim(); + Matcher m = quoteTrimPattern.matcher(s); + if (!m.matches()) { + throw new IllegalArgumentException( + "Interval string does not match day-time format of 'd h:m:s.n': " + s); + } else { + try { + if (unit.equals("year")) { + int year = (int) toLongWithRange("year", m.group(1), + Integer.MIN_VALUE / 12, Integer.MAX_VALUE / 12); + result = new CalendarInterval(year * 12, 0L); + + } else if (unit.equals("month")) { + int month = (int) toLongWithRange("month", m.group(1), + Integer.MIN_VALUE, Integer.MAX_VALUE); + result = new CalendarInterval(month, 0L); + + } else if (unit.equals("day")) { + long day = toLongWithRange("day", m.group(1), + Long.MIN_VALUE / MICROS_PER_DAY, Long.MAX_VALUE / MICROS_PER_DAY); + result = new CalendarInterval(0, day * MICROS_PER_DAY); + + } else if (unit.equals("hour")) { + long hour = toLongWithRange("hour", m.group(1), + Long.MIN_VALUE / MICROS_PER_HOUR, Long.MAX_VALUE / MICROS_PER_HOUR); + result = new CalendarInterval(0, hour * MICROS_PER_HOUR); + + } else if (unit.equals("minute")) { + long minute = toLongWithRange("minute", m.group(1), + Long.MIN_VALUE / MICROS_PER_MINUTE, Long.MAX_VALUE / MICROS_PER_MINUTE); + result = new CalendarInterval(0, minute * MICROS_PER_MINUTE); + + } else if (unit.equals("second")) { + long micros = parseSecondNano(m.group(1)); + result = new CalendarInterval(0, micros); + } + } catch (Exception e) { + throw new IllegalArgumentException("Error parsing interval string: " + e.getMessage(), e); + } + } + return result; + } + + /** + * Parse second_nano string in 
ss.nnnnnnnnn format to microseconds + */ + public static long parseSecondNano(String secondNano) throws IllegalArgumentException { + String[] parts = secondNano.split("\\."); + if (parts.length == 1) { + return toLongWithRange("second", parts[0], Long.MIN_VALUE / MICROS_PER_SECOND, + Long.MAX_VALUE / MICROS_PER_SECOND) * MICROS_PER_SECOND; + + } else if (parts.length == 2) { + long seconds = parts[0].equals("") ? 0L : toLongWithRange("second", parts[0], + Long.MIN_VALUE / MICROS_PER_SECOND, Long.MAX_VALUE / MICROS_PER_SECOND); + long nanos = toLongWithRange("nanosecond", parts[1], 0L, 999999999L); + return seconds * MICROS_PER_SECOND + nanos / 1000L; + + } else { + throw new IllegalArgumentException( + "Interval string does not match second-nano format of ss.nnnnnnnnn"); + } + } + public final int months; public final long microseconds; diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index f6c9b87778f8..5b6138680876 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -18,14 +18,21 @@ package org.apache.spark.unsafe.types; import javax.annotation.Nonnull; -import java.io.Serializable; -import java.io.UnsupportedEncodingException; +import java.io.*; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.util.Arrays; +import java.util.Map; -import org.apache.spark.unsafe.PlatformDependent; +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.KryoSerializable; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; + +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; -import static org.apache.spark.unsafe.PlatformDependent.*; +import static org.apache.spark.unsafe.Platform.*; /** @@ -36,12 +43,13 @@ *
        * Note: This is not designed for general use cases, should not be used outside SQL. */ -public final class UTF8String implements Comparable, Serializable { +public final class UTF8String implements Comparable, Externalizable, KryoSerializable { + // These are only updated by readExternal() or read() @Nonnull - private final Object base; - private final long offset; - private final int numBytes; + private Object base; + private long offset; + private int numBytes; public Object getBaseObject() { return base; } public long getBaseOffset() { return offset; } @@ -53,6 +61,9 @@ public final class UTF8String implements Comparable, Serializable { 5, 5, 5, 5, 6, 6}; + private static boolean isLittleEndian = ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN; + + private static final UTF8String COMMA_UTF8 = UTF8String.fromString(","); public static final UTF8String EMPTY_UTF8 = UTF8String.fromString(""); /** @@ -85,11 +96,7 @@ public static UTF8String fromBytes(byte[] bytes, int offset, int numBytes) { * Creates an UTF8String from given address (base and offset) and length. */ public static UTF8String fromAddress(Object base, long offset, int numBytes) { - if (base != null) { - return new UTF8String(base, offset, numBytes); - } else { - return null; - } + return new UTF8String(base, offset, numBytes); } /** @@ -122,19 +129,27 @@ protected UTF8String(Object base, long offset, int numBytes) { this.numBytes = numBytes; } + // for serialization + public UTF8String() { + this(null, 0, 0); + } + /** * Writes the content of this string into a memory address, identified by an object and an offset. * The target memory address must already been allocated, and have enough space to hold all the * bytes in this string. */ public void writeToMemory(Object target, long targetOffset) { - PlatformDependent.copyMemory( - base, - offset, - target, - targetOffset, - numBytes - ); + Platform.copyMemory(base, offset, target, targetOffset, numBytes); + } + + public void writeTo(ByteBuffer buffer) { + assert(buffer.hasArray()); + byte[] target = buffer.array(); + int offset = buffer.arrayOffset(); + int pos = buffer.position(); + writeToMemory(target, Platform.BYTE_ARRAY_OFFSET + offset + pos); + buffer.position(pos + numBytes); } /** @@ -175,18 +190,35 @@ public long getPrefix() { // If size is greater than 4, assume we have at least 8 bytes of data to fetch. // After getting the data, we use a mask to mask out data that is not part of the string. 
long p; - if (numBytes >= 8) { - p = PlatformDependent.UNSAFE.getLong(base, offset); - } else if (numBytes > 4) { - p = PlatformDependent.UNSAFE.getLong(base, offset); - p = p & ((1L << numBytes * 8) - 1); - } else if (numBytes > 0) { - p = (long) PlatformDependent.UNSAFE.getInt(base, offset); - p = p & ((1L << numBytes * 8) - 1); + long mask = 0; + if (isLittleEndian) { + if (numBytes >= 8) { + p = Platform.getLong(base, offset); + } else if (numBytes > 4) { + p = Platform.getLong(base, offset); + mask = (1L << (8 - numBytes) * 8) - 1; + } else if (numBytes > 0) { + p = (long) Platform.getInt(base, offset); + mask = (1L << (8 - numBytes) * 8) - 1; + } else { + p = 0; + } + p = java.lang.Long.reverseBytes(p); } else { - p = 0; + // byteOrder == ByteOrder.BIG_ENDIAN + if (numBytes >= 8) { + p = Platform.getLong(base, offset); + } else if (numBytes > 4) { + p = Platform.getLong(base, offset); + mask = (1L << (8 - numBytes) * 8) - 1; + } else if (numBytes > 0) { + p = ((long) Platform.getInt(base, offset)) << 32; + mask = (1L << (8 - numBytes) * 8) - 1; + } else { + p = 0; + } } - p = java.lang.Long.reverseBytes(p); + p &= ~mask; return p; } @@ -271,7 +303,7 @@ public boolean contains(final UTF8String substring) { * Returns the byte at position `i`. */ private byte getByte(int i) { - return UNSAFE.getByte(base, offset + i); + return Platform.getByte(base, offset + i); } private boolean matchAt(final UTF8String s, int pos) { @@ -391,6 +423,36 @@ private UTF8String toTitleCaseSlow() { return fromString(sb.toString()); } + /* + * Returns the index of the string `match` in this String. This string has to be a comma separated + * list. If `match` contains a comma 0 will be returned. If the `match` isn't part of this String, + * 0 will be returned, else the index of match (1-based index) + */ + public int findInSet(UTF8String match) { + if (match.contains(COMMA_UTF8)) { + return 0; + } + + int n = 1, lastComma = -1; + for (int i = 0; i < numBytes; i++) { + if (getByte(i) == (byte) ',') { + if (i - (lastComma + 1) == match.numBytes && + ByteArrayMethods.arrayEquals(base, offset + (lastComma + 1), match.base, match.offset, + match.numBytes)) { + return n; + } + lastComma = i; + n++; + } + } + if (numBytes - (lastComma + 1) == match.numBytes && + ByteArrayMethods.arrayEquals(base, offset + (lastComma + 1), match.base, match.offset, + match.numBytes)) { + return n; + } + return 0; + } + /** * Copy the bytes from the current UTF8String, and make a new UTF8String. * @param start the start position of the current UTF8String in bytes. @@ -717,7 +779,7 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) { int len = inputs[i].numBytes; copyMemory( inputs[i].base, inputs[i].offset, - result, PlatformDependent.BYTE_ARRAY_OFFSET + offset, + result, BYTE_ARRAY_OFFSET + offset, len); offset += len; @@ -726,7 +788,7 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... 
inputs) { if (j < numInputs) { copyMemory( separator.base, separator.offset, - result, PlatformDependent.BYTE_ARRAY_OFFSET + offset, + result, BYTE_ARRAY_OFFSET + offset, separator.numBytes); offset += separator.numBytes; } @@ -744,6 +806,21 @@ public UTF8String[] split(UTF8String pattern, int limit) { return res; } + // TODO: Need to use `Code Point` here instead of Char in case the character longer than 2 bytes + public UTF8String translate(Map dict) { + String srcStr = this.toString(); + + StringBuilder sb = new StringBuilder(); + for(int k = 0; k< srcStr.length(); k++) { + if (null == dict.get(srcStr.charAt(k))) { + sb.append(srcStr.charAt(k)); + } else if ('\0' != dict.get(srcStr.charAt(k))){ + sb.append(dict.get(srcStr.charAt(k))); + } + } + return fromString(sb.toString()); + } + @Override public String toString() { try { @@ -823,9 +900,9 @@ public int levenshteinDistance(UTF8String other) { m = swap; } - int p[] = new int[n + 1]; - int d[] = new int[n + 1]; - int swap[]; + int[] p = new int[n + 1]; + int[] d = new int[n + 1]; + int[] swap; int i, i_bytes, j, j_bytes, num_bytes_j, cost; @@ -888,7 +965,7 @@ public UTF8String soundex() { // first character must be a letter return this; } - byte sx[] = {'0', '0', '0', '0'}; + byte[] sx = {'0', '0', '0', '0'}; sx[0] = b; int sxi = 1; int idx = b - 'A'; @@ -917,4 +994,33 @@ public UTF8String soundex() { } return UTF8String.fromBytes(sx); } + + public void writeExternal(ObjectOutput out) throws IOException { + byte[] bytes = getBytes(); + out.writeInt(bytes.length); + out.write(bytes); + } + + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + offset = BYTE_ARRAY_OFFSET; + numBytes = in.readInt(); + base = new byte[numBytes]; + in.readFully((byte[]) base); + } + + @Override + public void write(Kryo kryo, Output out) { + byte[] bytes = getBytes(); + out.writeInt(bytes.length); + out.write(bytes); + } + + @Override + public void read(Kryo kryo, Input in) { + this.offset = BYTE_ARRAY_OFFSET; + this.numBytes = in.readInt(); + this.base = new byte[numBytes]; + in.read((byte[]) base); + } + } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java new file mode 100644 index 000000000000..693ec6ec58db --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.unsafe; + +import org.junit.Assert; +import org.junit.Test; + +public class PlatformUtilSuite { + + @Test + public void overlappingCopyMemory() { + byte[] data = new byte[3 * 1024 * 1024]; + int size = 2 * 1024 * 1024; + for (int i = 0; i < data.length; ++i) { + data[i] = (byte)i; + } + + Platform.copyMemory(data, Platform.BYTE_ARRAY_OFFSET, data, Platform.BYTE_ARRAY_OFFSET, size); + for (int i = 0; i < data.length; ++i) { + Assert.assertEquals((byte)i, data[i]); + } + + Platform.copyMemory( + data, + Platform.BYTE_ARRAY_OFFSET + 1, + data, + Platform.BYTE_ARRAY_OFFSET, + size); + for (int i = 0; i < size; ++i) { + Assert.assertEquals((byte)(i + 1), data[i]); + } + + for (int i = 0; i < data.length; ++i) { + data[i] = (byte)i; + } + Platform.copyMemory( + data, + Platform.BYTE_ARRAY_OFFSET, + data, + Platform.BYTE_ARRAY_OFFSET + 1, + size); + for (int i = 0; i < size; ++i) { + Assert.assertEquals((byte)i, data[i + 1]); + } + } +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java index 5974cf91ff99..fb8e53b3348f 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java @@ -34,5 +34,9 @@ public void basicTest() { Assert.assertEquals(2, arr.size()); Assert.assertEquals(1L, arr.get(0)); Assert.assertEquals(3L, arr.get(1)); + + arr.zeroOut(); + Assert.assertEquals(0L, arr.get(0)); + Assert.assertEquals(0L, arr.get(1)); } } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java deleted file mode 100644 index a93fc0ee297c..000000000000 --- a/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.unsafe.bitset; - -import junit.framework.Assert; -import org.junit.Test; - -import org.apache.spark.unsafe.memory.MemoryBlock; - -public class BitSetSuite { - - private static BitSet createBitSet(int capacity) { - assert capacity % 64 == 0; - return new BitSet(MemoryBlock.fromLongArray(new long[capacity / 64])); - } - - @Test - public void basicOps() { - BitSet bs = createBitSet(64); - Assert.assertEquals(64, bs.capacity()); - - // Make sure the bit set starts empty. - for (int i = 0; i < bs.capacity(); i++) { - Assert.assertFalse(bs.isSet(i)); - } - // another form of asserting that the bit set is empty - Assert.assertFalse(bs.anySet()); - - // Set every bit and check it. - for (int i = 0; i < bs.capacity(); i++) { - bs.set(i); - Assert.assertTrue(bs.isSet(i)); - } - - // Unset every bit and check it. 
- for (int i = 0; i < bs.capacity(); i++) { - Assert.assertTrue(bs.isSet(i)); - bs.unset(i); - Assert.assertFalse(bs.isSet(i)); - } - - // Make sure anySet() can detect any set bit - bs = createBitSet(256); - bs.set(64); - Assert.assertTrue(bs.anySet()); - } - - @Test - public void traversal() { - BitSet bs = createBitSet(256); - - Assert.assertEquals(-1, bs.nextSetBit(0)); - Assert.assertEquals(-1, bs.nextSetBit(10)); - Assert.assertEquals(-1, bs.nextSetBit(64)); - - bs.set(10); - Assert.assertEquals(10, bs.nextSetBit(0)); - Assert.assertEquals(10, bs.nextSetBit(1)); - Assert.assertEquals(10, bs.nextSetBit(10)); - Assert.assertEquals(-1, bs.nextSetBit(11)); - - bs.set(11); - Assert.assertEquals(10, bs.nextSetBit(10)); - Assert.assertEquals(11, bs.nextSetBit(11)); - - // Skip a whole word and find it - bs.set(190); - Assert.assertEquals(190, bs.nextSetBit(12)); - - Assert.assertEquals(-1, bs.nextSetBit(191)); - Assert.assertEquals(-1, bs.nextSetBit(256)); - } -} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java b/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java index 3b9175835229..e759cb33b3e6 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java @@ -17,12 +17,13 @@ package org.apache.spark.unsafe.hash; +import java.nio.charset.StandardCharsets; import java.util.HashSet; import java.util.Random; import java.util.Set; -import junit.framework.Assert; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; +import org.junit.Assert; import org.junit.Test; /** @@ -56,7 +57,7 @@ public void randomizedStressTest() { Random rand = new Random(); // A set used to track collision rate. - Set hashcodes = new HashSet(); + Set hashcodes = new HashSet<>(); for (int i = 0; i < size; i++) { int vint = rand.nextInt(); long lint = rand.nextLong(); @@ -76,18 +77,18 @@ public void randomizedStressTestBytes() { Random rand = new Random(); // A set used to track collision rate. - Set hashcodes = new HashSet(); + Set hashcodes = new HashSet<>(); for (int i = 0; i < size; i++) { int byteArrSize = rand.nextInt(100) * 8; byte[] bytes = new byte[byteArrSize]; rand.nextBytes(bytes); Assert.assertEquals( - hasher.hashUnsafeWords(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize), - hasher.hashUnsafeWords(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize)); + hasher.hashUnsafeWords(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), + hasher.hashUnsafeWords(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); hashcodes.add(hasher.hashUnsafeWords( - bytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize)); + bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); } // A very loose bound. @@ -98,19 +99,19 @@ public void randomizedStressTestBytes() { public void randomizedStressTestPaddedStrings() { int size = 64000; // A set used to track collision rate. 
- Set hashcodes = new HashSet(); + Set hashcodes = new HashSet<>(); for (int i = 0; i < size; i++) { int byteArrSize = 8; - byte[] strBytes = ("" + i).getBytes(); + byte[] strBytes = String.valueOf(i).getBytes(StandardCharsets.UTF_8); byte[] paddedBytes = new byte[byteArrSize]; System.arraycopy(strBytes, 0, paddedBytes, 0, strBytes.length); Assert.assertEquals( - hasher.hashUnsafeWords(paddedBytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize), - hasher.hashUnsafeWords(paddedBytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize)); + hasher.hashUnsafeWords(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), + hasher.hashUnsafeWords(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); hashcodes.add(hasher.hashUnsafeWords( - paddedBytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize)); + paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); } // A very loose bound. diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java deleted file mode 100644 index 06fb08118365..000000000000 --- a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.unsafe.memory; - -import org.junit.Assert; -import org.junit.Test; - -public class TaskMemoryManagerSuite { - - @Test - public void leakedNonPageMemoryIsDetected() { - final TaskMemoryManager manager = - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); - manager.allocate(1024); // leak memory - Assert.assertEquals(1024, manager.cleanUpAllAllocatedMemory()); - } - - @Test - public void leakedPageMemoryIsDetected() { - final TaskMemoryManager manager = - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); - manager.allocatePage(4096); // leak memory - Assert.assertEquals(4096, manager.cleanUpAllAllocatedMemory()); - } - - @Test - public void encodePageNumberAndOffsetOffHeap() { - final TaskMemoryManager manager = - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE)); - final MemoryBlock dataPage = manager.allocatePage(256); - // In off-heap mode, an offset is an absolute address that may require more than 51 bits to - // encode. 
This test exercises that corner-case: - final long offset = ((1L << TaskMemoryManager.OFFSET_BITS) + 10); - final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, offset); - Assert.assertEquals(null, manager.getPage(encodedAddress)); - Assert.assertEquals(offset, manager.getOffsetInPage(encodedAddress)); - } - - @Test - public void encodePageNumberAndOffsetOnHeap() { - final TaskMemoryManager manager = - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); - final MemoryBlock dataPage = manager.allocatePage(256); - final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, 64); - Assert.assertEquals(dataPage.getBaseObject(), manager.getPage(encodedAddress)); - Assert.assertEquals(64, manager.getOffsetInPage(encodedAddress)); - } - -} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java index 6274b92b47dd..9e69e264ff28 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java @@ -19,7 +19,7 @@ import org.junit.Test; -import static junit.framework.Assert.*; +import static org.junit.Assert.*; import static org.apache.spark.unsafe.types.CalendarInterval.*; public class CalendarIntervalSuite { @@ -42,19 +42,19 @@ public void toStringTest() { CalendarInterval i; i = new CalendarInterval(34, 0); - assertEquals(i.toString(), "interval 2 years 10 months"); + assertEquals("interval 2 years 10 months", i.toString()); i = new CalendarInterval(-34, 0); - assertEquals(i.toString(), "interval -2 years -10 months"); + assertEquals("interval -2 years -10 months", i.toString()); i = new CalendarInterval(0, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123); - assertEquals(i.toString(), "interval 3 weeks 13 hours 123 microseconds"); + assertEquals("interval 3 weeks 13 hours 123 microseconds", i.toString()); i = new CalendarInterval(0, -3 * MICROS_PER_WEEK - 13 * MICROS_PER_HOUR - 123); - assertEquals(i.toString(), "interval -3 weeks -13 hours -123 microseconds"); + assertEquals("interval -3 weeks -13 hours -123 microseconds", i.toString()); i = new CalendarInterval(34, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123); - assertEquals(i.toString(), "interval 2 years 10 months 3 weeks 13 hours 123 microseconds"); + assertEquals("interval 2 years 10 months 3 weeks 13 hours 123 microseconds", i.toString()); } @Test @@ -73,32 +73,123 @@ public void fromStringTest() { input = "interval -5 years 23 month"; CalendarInterval result = new CalendarInterval(-5 * 12 + 23, 0); - assertEquals(CalendarInterval.fromString(input), result); + assertEquals(fromString(input), result); input = "interval -5 years 23 month "; - assertEquals(CalendarInterval.fromString(input), result); + assertEquals(fromString(input), result); input = " interval -5 years 23 month "; - assertEquals(CalendarInterval.fromString(input), result); + assertEquals(fromString(input), result); // Error cases input = "interval 3month 1 hour"; - assertEquals(CalendarInterval.fromString(input), null); + assertNull(fromString(input)); input = "interval 3 moth 1 hour"; - assertEquals(CalendarInterval.fromString(input), null); + assertNull(fromString(input)); input = "interval"; - assertEquals(CalendarInterval.fromString(input), null); + assertNull(fromString(input)); input = "int"; - assertEquals(CalendarInterval.fromString(input), null); + assertNull(fromString(input)); input = ""; - 
assertEquals(CalendarInterval.fromString(input), null); + assertNull(fromString(input)); input = null; - assertEquals(CalendarInterval.fromString(input), null); + assertNull(fromString(input)); + } + + @Test + public void fromYearMonthStringTest() { + String input; + CalendarInterval i; + + input = "99-10"; + i = new CalendarInterval(99 * 12 + 10, 0L); + assertEquals(fromYearMonthString(input), i); + + input = "-8-10"; + i = new CalendarInterval(-8 * 12 - 10, 0L); + assertEquals(fromYearMonthString(input), i); + + try { + input = "99-15"; + fromYearMonthString(input); + fail("Expected to throw an exception for the invalid input"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("month 15 outside range")); + } + } + + @Test + public void fromDayTimeStringTest() { + String input; + CalendarInterval i; + + input = "5 12:40:30.999999999"; + i = new CalendarInterval(0, 5 * MICROS_PER_DAY + 12 * MICROS_PER_HOUR + + 40 * MICROS_PER_MINUTE + 30 * MICROS_PER_SECOND + 999999L); + assertEquals(fromDayTimeString(input), i); + + input = "10 0:12:0.888"; + i = new CalendarInterval(0, 10 * MICROS_PER_DAY + 12 * MICROS_PER_MINUTE); + assertEquals(fromDayTimeString(input), i); + + input = "-3 0:0:0"; + i = new CalendarInterval(0, -3 * MICROS_PER_DAY); + assertEquals(fromDayTimeString(input), i); + + try { + input = "5 30:12:20"; + fromDayTimeString(input); + fail("Expected to throw an exception for the invalid input"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("hour 30 outside range")); + } + + try { + input = "5 30-12"; + fromDayTimeString(input); + fail("Expected to throw an exception for the invalid input"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("not match day-time format")); + } + } + + @Test + public void fromSingleUnitStringTest() { + String input; + CalendarInterval i; + + input = "12"; + i = new CalendarInterval(12 * 12, 0L); + assertEquals(fromSingleUnitString("year", input), i); + + input = "100"; + i = new CalendarInterval(0, 100 * MICROS_PER_DAY); + assertEquals(fromSingleUnitString("day", input), i); + + input = "1999.38888"; + i = new CalendarInterval(0, 1999 * MICROS_PER_SECOND + 38); + assertEquals(fromSingleUnitString("second", input), i); + + try { + input = String.valueOf(Integer.MAX_VALUE); + fromSingleUnitString("year", input); + fail("Expected to throw an exception for the invalid input"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("outside range")); + } + + try { + input = String.valueOf(Long.MAX_VALUE / MICROS_PER_HOUR + 1); + fromSingleUnitString("hour", input); + fail("Expected to throw an exception for the invalid input"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("outside range")); + } } @Test @@ -106,16 +197,16 @@ public void addTest() { String input = "interval 3 month 1 hour"; String input2 = "interval 2 month 100 hour"; - CalendarInterval interval = CalendarInterval.fromString(input); - CalendarInterval interval2 = CalendarInterval.fromString(input2); + CalendarInterval interval = fromString(input); + CalendarInterval interval2 = fromString(input2); assertEquals(interval.add(interval2), new CalendarInterval(5, 101 * MICROS_PER_HOUR)); input = "interval -10 month -81 hour"; input2 = "interval 75 month 200 hour"; - interval = CalendarInterval.fromString(input); - interval2 = CalendarInterval.fromString(input2); + interval = fromString(input); + interval2 = fromString(input2); 
assertEquals(interval.add(interval2), new CalendarInterval(65, 119 * MICROS_PER_HOUR)); } @@ -125,25 +216,25 @@ public void subtractTest() { String input = "interval 3 month 1 hour"; String input2 = "interval 2 month 100 hour"; - CalendarInterval interval = CalendarInterval.fromString(input); - CalendarInterval interval2 = CalendarInterval.fromString(input2); + CalendarInterval interval = fromString(input); + CalendarInterval interval2 = fromString(input2); assertEquals(interval.subtract(interval2), new CalendarInterval(1, -99 * MICROS_PER_HOUR)); input = "interval -10 month -81 hour"; input2 = "interval 75 month 200 hour"; - interval = CalendarInterval.fromString(input); - interval2 = CalendarInterval.fromString(input2); + interval = fromString(input); + interval2 = fromString(input2); assertEquals(interval.subtract(interval2), new CalendarInterval(-85, -281 * MICROS_PER_HOUR)); } - private void testSingleUnit(String unit, int number, int months, long microseconds) { + private static void testSingleUnit(String unit, int number, int months, long microseconds) { String input1 = "interval " + number + " " + unit; String input2 = "interval " + number + " " + unit + "s"; CalendarInterval result = new CalendarInterval(months, microseconds); - assertEquals(CalendarInterval.fromString(input1), result); - assertEquals(CalendarInterval.fromString(input2), result); + assertEquals(fromString(input1), result); + assertEquals(fromString(input2), result); } } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 9b3190f8f0c3..e21ffdcff9ab 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -19,16 +19,18 @@ import java.io.UnsupportedEncodingException; import java.util.Arrays; +import java.util.HashMap; +import com.google.common.collect.ImmutableMap; import org.junit.Test; -import static junit.framework.Assert.*; +import static org.junit.Assert.*; import static org.apache.spark.unsafe.types.UTF8String.*; public class UTF8StringSuite { - private void checkBasic(String str, int len) throws UnsupportedEncodingException { + private static void checkBasic(String str, int len) throws UnsupportedEncodingException { UTF8String s1 = fromString(str); UTF8String s2 = fromBytes(str.getBytes("utf8")); assertEquals(s1.numChars(), len); @@ -40,12 +42,12 @@ private void checkBasic(String str, int len) throws UnsupportedEncodingException assertEquals(s1.hashCode(), s2.hashCode()); - assertEquals(s1.compareTo(s2), 0); + assertEquals(0, s1.compareTo(s2)); - assertEquals(s1.contains(s2), true); - assertEquals(s2.contains(s1), true); - assertEquals(s1.startsWith(s1), true); - assertEquals(s1.endsWith(s1), true); + assertTrue(s1.contains(s2)); + assertTrue(s2.contains(s1)); + assertTrue(s1.startsWith(s1)); + assertTrue(s1.endsWith(s1)); } @Test @@ -57,8 +59,8 @@ public void basicTest() throws UnsupportedEncodingException { @Test public void emptyStringTest() { - assertEquals(fromString(""), EMPTY_UTF8); - assertEquals(fromBytes(new byte[0]), EMPTY_UTF8); + assertEquals(EMPTY_UTF8, fromString("")); + assertEquals(EMPTY_UTF8, fromBytes(new byte[0])); assertEquals(0, EMPTY_UTF8.numChars()); assertEquals(0, EMPTY_UTF8.numBytes()); } @@ -74,9 +76,9 @@ public void prefix() { byte[] buf1 = {1, 2, 3, 4, 5, 6, 7, 8, 9}; byte[] buf2 = {1, 2, 3}; - UTF8String str1 = UTF8String.fromBytes(buf1, 0, 3); - UTF8String str2 
= UTF8String.fromBytes(buf1, 0, 8); - UTF8String str3 = UTF8String.fromBytes(buf2); + UTF8String str1 = fromBytes(buf1, 0, 3); + UTF8String str2 = fromBytes(buf1, 0, 8); + UTF8String str3 = fromBytes(buf2); assertTrue(str1.getPrefix() - str2.getPrefix() < 0); assertEquals(str1.getPrefix(), str3.getPrefix()); } @@ -96,7 +98,7 @@ public void compareTo() { assertTrue(fromString("你好123").compareTo(fromString("你好122")) > 0); } - protected void testUpperandLower(String upper, String lower) { + protected static void testUpperandLower(String upper, String lower) { UTF8String us = fromString(upper); UTF8String ls = fromString(lower); assertEquals(ls, us.toLowerCase()); @@ -125,22 +127,22 @@ public void titleCase() { @Test public void concatTest() { assertEquals(EMPTY_UTF8, concat()); - assertEquals(null, concat((UTF8String) null)); + assertNull(concat((UTF8String) null)); assertEquals(EMPTY_UTF8, concat(EMPTY_UTF8)); assertEquals(fromString("ab"), concat(fromString("ab"))); assertEquals(fromString("ab"), concat(fromString("a"), fromString("b"))); assertEquals(fromString("abc"), concat(fromString("a"), fromString("b"), fromString("c"))); - assertEquals(null, concat(fromString("a"), null, fromString("c"))); - assertEquals(null, concat(fromString("a"), null, null)); - assertEquals(null, concat(null, null, null)); + assertNull(concat(fromString("a"), null, fromString("c"))); + assertNull(concat(fromString("a"), null, null)); + assertNull(concat(null, null, null)); assertEquals(fromString("数据砖头"), concat(fromString("数据"), fromString("砖头"))); } @Test public void concatWsTest() { // Returns null if the separator is null - assertEquals(null, concatWs(null, (UTF8String)null)); - assertEquals(null, concatWs(null, fromString("a"))); + assertNull(concatWs(null, (UTF8String) null)); + assertNull(concatWs(null, fromString("a"))); // If separator is null, concatWs should skip all null inputs and never return null. 
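A brief aside on the assertEquals swaps running through these test hunks: JUnit's Assert.assertEquals takes the expected value first and the actual value second, so reordering the arguments makes failure messages report expected/actual the right way round. A minimal illustration (the object name and the stand-in string are ours, not part of the patch):

import org.junit.Assert.assertEquals

object AssertOrderSketch {
  def main(args: Array[String]): Unit = {
    val computed = "interval 2 years 10 months" // stand-in for i.toString()
    // Expected literal first, computed value second: a failure then reads
    // "expected:<interval 2 years 10 months> but was:<...>".
    assertEquals("interval 2 years 10 months", computed)
  }
}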
UTF8String sep = fromString("哈哈"); @@ -379,16 +381,45 @@ public void split() { @Test public void levenshteinDistance() { - assertEquals(EMPTY_UTF8.levenshteinDistance(EMPTY_UTF8), 0); - assertEquals(EMPTY_UTF8.levenshteinDistance(fromString("a")), 1); - assertEquals(fromString("aaapppp").levenshteinDistance(EMPTY_UTF8), 7); - assertEquals(fromString("frog").levenshteinDistance(fromString("fog")), 1); - assertEquals(fromString("fly").levenshteinDistance(fromString("ant")),3); - assertEquals(fromString("elephant").levenshteinDistance(fromString("hippo")), 7); - assertEquals(fromString("hippo").levenshteinDistance(fromString("elephant")), 7); - assertEquals(fromString("hippo").levenshteinDistance(fromString("zzzzzzzz")), 8); - assertEquals(fromString("hello").levenshteinDistance(fromString("hallo")),1); - assertEquals(fromString("世界千世").levenshteinDistance(fromString("千a世b")),4); + assertEquals(0, EMPTY_UTF8.levenshteinDistance(EMPTY_UTF8)); + assertEquals(1, EMPTY_UTF8.levenshteinDistance(fromString("a"))); + assertEquals(7, fromString("aaapppp").levenshteinDistance(EMPTY_UTF8)); + assertEquals(1, fromString("frog").levenshteinDistance(fromString("fog"))); + assertEquals(3, fromString("fly").levenshteinDistance(fromString("ant"))); + assertEquals(7, fromString("elephant").levenshteinDistance(fromString("hippo"))); + assertEquals(7, fromString("hippo").levenshteinDistance(fromString("elephant"))); + assertEquals(8, fromString("hippo").levenshteinDistance(fromString("zzzzzzzz"))); + assertEquals(1, fromString("hello").levenshteinDistance(fromString("hallo"))); + assertEquals(4, fromString("世界千世").levenshteinDistance(fromString("千a世b"))); + } + + @Test + public void translate() { + assertEquals( + fromString("1a2s3ae"), + fromString("translate").translate(ImmutableMap.of( + 'r', '1', + 'n', '2', + 'l', '3', + 't', '\0' + ))); + assertEquals( + fromString("translate"), + fromString("translate").translate(new HashMap())); + assertEquals( + fromString("asae"), + fromString("translate").translate(ImmutableMap.of( + 'r', '\0', + 'n', '\0', + 'l', '\0', + 't', '\0' + ))); + assertEquals( + fromString("aa世b"), + fromString("花花世界").translate(ImmutableMap.of( + '花', 'a', + '界', 'b' + ))); } @Test @@ -399,6 +430,18 @@ public void createBlankString() { assertEquals(fromString(""), blankString(0)); } + @Test + public void findInSet() { + assertEquals(1, fromString("ab").findInSet(fromString("ab"))); + assertEquals(2, fromString("a,b").findInSet(fromString("b"))); + assertEquals(3, fromString("abc,b,ab,c,def").findInSet(fromString("ab"))); + assertEquals(1, fromString("ab,abc,b,ab,c,def").findInSet(fromString("ab"))); + assertEquals(4, fromString(",,,ab,abc,b,ab,c,def").findInSet(fromString("ab"))); + assertEquals(1, fromString(",ab,abc,b,ab,c,def").findInSet(fromString(""))); + assertEquals(4, fromString("数据砖头,abc,b,ab,c,def").findInSet(fromString("ab"))); + assertEquals(6, fromString("数据砖头,abc,b,ab,c,def").findInSet(fromString("def"))); + } + @Test public void soundex() { assertEquals(fromString("Robert").soundex(), fromString("R163")); diff --git a/yarn/pom.xml b/yarn/pom.xml index 2aeed98285aa..989b820bec9e 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml @@ -30,7 +30,6 @@ Spark Project YARN yarn - 1.9 @@ -39,6 +38,12 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-network-yarn_${scala.binary.version} + ${project.version} + test + org.apache.spark 
spark-core_${scala.binary.version} @@ -46,6 +51,10 @@ test-jar test + + org.apache.spark + spark-test-tags_${scala.binary.version} + org.apache.hadoop hadoop-yarn-api @@ -93,12 +102,28 @@ jetty-servlet - + + + + org.eclipse.jetty.orbit + javax.servlet.jsp + 2.2.0.v201112011158 + test + + + org.eclipse.jetty.orbit + javax.servlet.jsp.jstl + 1.2.0.v201105211821 + test + + - + org.apache.hadoop hadoop-yarn-server-tests @@ -125,29 +150,45 @@ com.sun.jersey jersey-core - ${jersey.version} test com.sun.jersey jersey-json - ${jersey.version} test - - - stax - stax-api - - com.sun.jersey jersey-server - ${jersey.version} test + + + + ${hive.group} + hive-exec + test + + + ${hive.group} + hive-metastore + test + + + org.apache.thrift + libthrift + test + + + org.apache.thrift + libfb303 + test + - + target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 1d67b3ebb51b..fc742df73d73 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -30,8 +30,8 @@ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.spark.rpc._ -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkEnv} -import org.apache.spark.SparkException +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkEnv, + SparkException, SparkUserAppException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, YarnSchedulerBackend} @@ -62,9 +62,21 @@ private[spark] class ApplicationMaster( .asInstanceOf[YarnConfiguration] private val isClusterMode = args.userClass != null - // Default to numExecutors * 2, with minimum of 3 - private val maxNumExecutorFailures = sparkConf.getInt("spark.yarn.max.executor.failures", - sparkConf.getInt("spark.yarn.max.worker.failures", math.max(args.numExecutors * 2, 3))) + // Default to twice the number of executors (twice the maximum number of executors if dynamic + // allocation is enabled), with a minimum of 3. + + private val maxNumExecutorFailures = { + val defaultKey = + if (Utils.isDynamicAllocationEnabled(sparkConf)) { + "spark.dynamicAllocation.maxExecutors" + } else { + "spark.executor.instances" + } + val effectiveNumExecutors = sparkConf.getInt(defaultKey, 0) + val defaultMaxNumExecutorFailures = math.max(3, 2 * effectiveNumExecutors) + + sparkConf.getInt("spark.yarn.max.executor.failures", defaultMaxNumExecutorFailures) + } @volatile private var exitCode = 0 @volatile private var unregistered = false @@ -75,8 +87,27 @@ private[spark] class ApplicationMaster( @volatile private var reporterThread: Thread = _ @volatile private var allocator: YarnAllocator = _ + + // Lock for controlling the allocator (heartbeat) thread. private val allocatorLock = new Object() + // Steady state heartbeat interval. We want to be reasonably responsive without causing too many + // requests to RM. + private val heartbeatInterval = { + // Ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapses. 
+ val expiryInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000) + math.max(0, math.min(expiryInterval / 2, + sparkConf.getTimeAsMs("spark.yarn.scheduler.heartbeat.interval-ms", "3s"))) + } + + // Initial wait interval before allocator poll, to allow for quicker ramp up when executors are + // being requested. + private val initialAllocationInterval = math.min(heartbeatInterval, + sparkConf.getTimeAsMs("spark.yarn.scheduler.initial-allocation.interval", "200ms")) + + // Next wait interval before allocator poll. + private var nextAllocationInterval = initialAllocationInterval + // Fields used in client mode. private var rpcEnv: RpcEnv = null private var amEndpoint: RpcEndpointRef = _ @@ -86,6 +117,10 @@ private[spark] class ApplicationMaster( private var delegationTokenRenewerOption: Option[AMDelegationTokenRenewer] = None + def getAttemptId(): ApplicationAttemptId = { + client.getAttemptId() + } + final def run(): Int = { try { val appAttemptId = client.getAttemptId() @@ -111,7 +146,8 @@ private[spark] class ApplicationMaster( val fs = FileSystem.get(yarnConf) // This shutdown hook should run *after* the SparkContext is shut down. - Utils.addShutdownHook(Utils.SPARK_CONTEXT_SHUTDOWN_PRIORITY - 1) { () => + val priority = ShutdownHookManager.SPARK_CONTEXT_SHUTDOWN_PRIORITY - 1 + ShutdownHookManager.addShutdownHook(priority) { () => val maxAppAttempts = client.getMaxRegAttempts(sparkConf, yarnConf) val isLastAttempt = client.getAttemptId().getAttemptId() >= maxAppAttempts @@ -198,7 +234,7 @@ private[spark] class ApplicationMaster( final def finish(status: FinalApplicationStatus, code: Int, msg: String = null): Unit = { synchronized { if (!finished) { - val inShutdown = Utils.inShutdown() + val inShutdown = ShutdownHookManager.inShutdown() logInfo(s"Final app status: $status, exitCode: $code" + Option(msg).map(msg => s", (reason: $msg)").getOrElse("")) exitCode = code @@ -253,7 +289,6 @@ private[spark] class ApplicationMaster( driverRef, yarnConf, _sparkConf, - if (sc != null) sc.preferredNodeLocationData else Map(), uiAddress, historyAddress, securityMgr) @@ -309,7 +344,8 @@ private[spark] class ApplicationMaster( private def runExecutorLauncher(securityMgr: SecurityManager): Unit = { val port = sparkConf.getInt("spark.yarn.am.port", 0) - rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, port, sparkConf, securityMgr) + rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, port, sparkConf, securityMgr, + clientMode = true) val driverRef = waitForSparkDriver() addAmIpFilter() registerAM(rpcEnv, driverRef, sparkConf.get("spark.driver.appUIAddress", ""), securityMgr) @@ -319,19 +355,6 @@ private[spark] class ApplicationMaster( } private def launchReporterThread(): Thread = { - // Ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapses. - val expiryInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000) - - // we want to be reasonably responsive without causing too many requests to RM. 
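The heartbeat fields introduced above, together with the sleep-interval logic further down in this diff, implement a simple exponential backoff: poll the allocator quickly while container requests are outstanding, then fall back to the steady-state heartbeat. A condensed, illustrative model follows (the constants and object name are ours; the real code also factors in pending executor-loss-reason requests):

object AllocatorBackoffSketch {
  // Stand-ins for the values computed from yarnConf/sparkConf in the patch.
  val heartbeatInterval = 3000L          // min(RM expiry / 2, spark.yarn.scheduler.heartbeat.interval-ms)
  val initialAllocationInterval = 200L   // spark.yarn.scheduler.initial-allocation.interval
  var nextAllocationInterval = initialAllocationInterval

  def sleepInterval(numPendingAllocate: Int): Long = {
    if (numPendingAllocate > 0) {
      // Containers are still pending: poll fast, doubling each round but never
      // exceeding the steady-state heartbeat (the min() bounds the doubling).
      val current = math.min(heartbeatInterval, nextAllocationInterval)
      nextAllocationInterval = current * 2
      current
    } else {
      // Nothing pending: reset the backoff and sleep for the normal heartbeat.
      nextAllocationInterval = initialAllocationInterval
      heartbeatInterval
    }
  }
}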
- val heartbeatInterval = math.max(0, math.min(expiryInterval / 2, - sparkConf.getTimeAsMs("spark.yarn.scheduler.heartbeat.interval-ms", "3s"))) - - // we want to check more frequently for pending containers - val initialAllocationInterval = math.min(heartbeatInterval, - sparkConf.getTimeAsMs("spark.yarn.scheduler.initial-allocation.interval", "200ms")) - - var nextAllocationInterval = initialAllocationInterval - // The number of failures in a row until Reporter thread give up val reporterMaxFailures = sparkConf.getInt("spark.yarn.scheduler.reporterThread.maxFailures", 5) @@ -343,7 +366,7 @@ private[spark] class ApplicationMaster( if (allocator.getNumExecutorsFailed >= maxNumExecutorFailures) { finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_MAX_EXECUTOR_FAILURES, - "Max number of executor failures reached") + s"Max number of executor failures ($maxNumExecutorFailures) reached") } else { logDebug("Sending progress") allocator.allocateResources() @@ -353,7 +376,14 @@ private[spark] class ApplicationMaster( case i: InterruptedException => case e: Throwable => { failureCount += 1 - if (!NonFatal(e) || failureCount >= reporterMaxFailures) { + // this exception was introduced in hadoop 2.4 and this code would not compile + // with earlier versions if we refer it directly. + if ("org.apache.hadoop.yarn.exceptions.ApplicationAttemptNotFoundException" == + e.getClass().getName()) { + logError("Exception from Reporter thread.", e) + finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_REPORTER_FAILURE, + e.getMessage) + } else if (!NonFatal(e) || failureCount >= reporterMaxFailures) { finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_REPORTER_FAILURE, "Exception was thrown " + s"$failureCount time(s) from Reporter thread.") @@ -363,20 +393,20 @@ private[spark] class ApplicationMaster( } } try { - val numPendingAllocate = allocator.getNumPendingAllocate - val sleepInterval = - if (numPendingAllocate > 0) { - val currentAllocationInterval = - math.min(heartbeatInterval, nextAllocationInterval) - nextAllocationInterval = currentAllocationInterval * 2 // avoid overflow - currentAllocationInterval - } else { - nextAllocationInterval = initialAllocationInterval - heartbeatInterval - } - logDebug(s"Number of pending allocations is $numPendingAllocate. " + - s"Sleeping for $sleepInterval.") + val numPendingAllocate = allocator.getPendingAllocate.size allocatorLock.synchronized { + val sleepInterval = + if (numPendingAllocate > 0 || allocator.getNumPendingLossReasonRequests > 0) { + val currentAllocationInterval = + math.min(heartbeatInterval, nextAllocationInterval) + nextAllocationInterval = currentAllocationInterval * 2 // avoid overflow + currentAllocationInterval + } else { + nextAllocationInterval = initialAllocationInterval + heartbeatInterval + } + logDebug(s"Number of pending allocations is $numPendingAllocate. 
" + + s"Sleeping for $sleepInterval.") allocatorLock.wait(sleepInterval) } } catch { @@ -493,7 +523,6 @@ private[spark] class ApplicationMaster( */ private def startUserApplication(): Thread = { logInfo("Starting the user application in a separate Thread") - System.setProperty("spark.executor.instances", args.numExecutors.toString) val classpath = Client.getUserClasspath(sparkConf) val urls = classpath.map { entry => @@ -529,6 +558,10 @@ private[spark] class ApplicationMaster( e.getCause match { case _: InterruptedException => // Reporter thread can interrupt to stop user class + case SparkUserAppException(exitCode) => + val msg = s"User application exited with status $exitCode" + logError(msg) + finish(FinalApplicationStatus.FAILED, exitCode, msg) case cause: Throwable => logError("User class threw exception: " + cause, cause) finish(FinalApplicationStatus.FAILED, @@ -544,6 +577,11 @@ private[spark] class ApplicationMaster( userThread } + private def resetAllocatorInterval(): Unit = allocatorLock.synchronized { + nextAllocationInterval = initialAllocationInterval + allocatorLock.notifyAll() + } + /** * An [[RpcEndpoint]] that communicates with the driver's scheduler backend. */ @@ -553,7 +591,6 @@ private[spark] class ApplicationMaster( override def onStart(): Unit = { driver.send(RegisterClusterManager(self)) - } override def receive: PartialFunction[Any, Unit] = { @@ -566,17 +603,16 @@ private[spark] class ApplicationMaster( case RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount) => Option(allocator) match { case Some(a) => - allocatorLock.synchronized { - if (a.requestTotalExecutorsWithPreferredLocalities(requestedTotal, - localityAwareTasks, hostToLocalTaskCount)) { - allocatorLock.notifyAll() - } + if (a.requestTotalExecutorsWithPreferredLocalities(requestedTotal, + localityAwareTasks, hostToLocalTaskCount)) { + resetAllocatorInterval() } + context.reply(true) case None => logWarning("Container allocator is not ready to request executors yet.") + context.reply(false) } - context.reply(true) case KillExecutors(executorIds) => logInfo(s"Driver requested to kill executor(s) ${executorIds.mkString(", ")}.") @@ -585,13 +621,22 @@ private[spark] class ApplicationMaster( case None => logWarning("Container allocator is not ready to kill executors yet.") } context.reply(true) + + case GetExecutorLossReason(eid) => + Option(allocator) match { + case Some(a) => + a.enqueueGetLossReasonRequest(eid, context) + resetAllocatorInterval() + case None => + logWarning("Container allocator is not ready to find executor loss reasons yet.") + } } override def onDisconnected(remoteAddress: RpcAddress): Unit = { - logInfo(s"Driver terminated or disconnected! Shutting down. $remoteAddress") // In cluster mode, do not rely on the disassociated event to exit // This avoids potentially reporting incorrect exit codes if the driver fails if (!isClusterMode) { + logInfo(s"Driver terminated or disconnected! Shutting down. 
$remoteAddress") finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) } } @@ -629,6 +674,10 @@ object ApplicationMaster extends Logging { master.sparkContextStopped(sc) } + private[spark] def getAttemptId(): ApplicationAttemptId = { + master.getAttemptId + } + } /** diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala index 37f793763367..17d9943c795e 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -29,7 +29,6 @@ class ApplicationMasterArguments(val args: Array[String]) { var userArgs: Seq[String] = Nil var executorMemory = 1024 var executorCores = 1 - var numExecutors = DEFAULT_NUMBER_EXECUTORS var propertiesFile: String = null parseArgs(args.toList) @@ -63,10 +62,6 @@ class ApplicationMasterArguments(val args: Array[String]) { userArgsBuffer += value args = tail - case ("--num-workers" | "--num-executors") :: IntParam(value) :: tail => - numExecutors = value - args = tail - case ("--worker-memory" | "--executor-memory") :: MemoryParam(value) :: tail => executorMemory = value args = tail @@ -110,9 +105,9 @@ class ApplicationMasterArguments(val args: Array[String]) { | place on the PYTHONPATH for Python apps. | --args ARGS Arguments to be passed to your application's main class. | Multiple invocations are possible, each will be passed in order. - | --num-executors NUM Number of executors to start (Default: 2) | --executor-cores NUM Number of cores for the executors (Default: 1) | --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G) + | --properties-file FILE Path to a custom Spark properties file. 
""".stripMargin) // scalastyle:on println System.exit(exitCode) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index fc11bbf97e2e..7742ec92eb4e 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -25,7 +25,7 @@ import java.security.PrivilegedExceptionAction import java.util.{Properties, UUID} import java.util.zip.{ZipEntry, ZipOutputStream} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, ListBuffer, Map} import scala.reflect.runtime.universe import scala.util.{Try, Success, Failure} @@ -54,8 +54,9 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.exceptions.ApplicationNotFoundException import org.apache.hadoop.yarn.util.Records -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkException} +import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle, YarnCommandBuilderUtils} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.util.Utils private[spark] class Client( @@ -69,8 +70,6 @@ private[spark] class Client( def this(clientArgs: ClientArguments, spConf: SparkConf) = this(clientArgs, SparkHadoopUtil.get.newConfiguration(spConf), spConf) - def this(clientArgs: ClientArguments) = this(clientArgs, new SparkConf()) - private val yarnClient = YarnClient.createYarnClient private val yarnConf = new YarnConfiguration(hadoopConf) private var credentials: Credentials = null @@ -83,10 +82,31 @@ private[spark] class Client( private var principal: String = null private var keytab: String = null + private val launcherBackend = new LauncherBackend() { + override def onStopRequest(): Unit = { + if (isClusterMode && appId != null) { + yarnClient.killApplication(appId) + } else { + setState(SparkAppHandle.State.KILLED) + stop() + } + } + } private val fireAndForget = isClusterMode && !sparkConf.getBoolean("spark.yarn.submit.waitAppCompletion", true) - def stop(): Unit = yarnClient.stop() + private var appId: ApplicationId = null + + def reportLauncherState(state: SparkAppHandle.State): Unit = { + launcherBackend.setState(state) + } + + def stop(): Unit = { + launcherBackend.close() + yarnClient.stop() + // Unset YARN mode system env variable, to allow switching between cluster types. + System.clearProperty("SPARK_YARN_MODE") + } /** * Submit an application running our ApplicationMaster to the ResourceManager. @@ -98,6 +118,7 @@ private[spark] class Client( def submitApplication(): ApplicationId = { var appId: ApplicationId = null try { + launcherBackend.connect() // Setup the credentials before doing anything else, // so we have don't have issues at any point. 
setupCredentials() @@ -111,6 +132,8 @@ private[spark] class Client( val newApp = yarnClient.createApplication() val newAppResponse = newApp.getNewApplicationResponse() appId = newAppResponse.getApplicationId() + reportLauncherState(SparkAppHandle.State.SUBMITTED) + launcherBackend.setAppId(appId.toString()) // Verify whether the cluster has enough resources for our AM verifyClusterResources(newAppResponse) @@ -163,15 +186,70 @@ private[spark] class Client( appContext.setQueue(args.amQueue) appContext.setAMContainerSpec(containerContext) appContext.setApplicationType("SPARK") + sparkConf.getOption(CONF_SPARK_YARN_APPLICATION_TAGS) + .map(StringUtils.getTrimmedStringCollection(_)) + .filter(!_.isEmpty()) + .foreach { tagCollection => + try { + // The setApplicationTags method was only introduced in Hadoop 2.4+, so we need to use + // reflection to set it, printing a warning if a tag was specified but the YARN version + // doesn't support it. + val method = appContext.getClass().getMethod( + "setApplicationTags", classOf[java.util.Set[String]]) + method.invoke(appContext, new java.util.HashSet[String](tagCollection)) + } catch { + case e: NoSuchMethodException => + logWarning(s"Ignoring $CONF_SPARK_YARN_APPLICATION_TAGS because this version of " + + "YARN does not support it") + } + } sparkConf.getOption("spark.yarn.maxAppAttempts").map(_.toInt) match { case Some(v) => appContext.setMaxAppAttempts(v) case None => logDebug("spark.yarn.maxAppAttempts is not set. " + "Cluster's default value will be used.") } + + if (sparkConf.contains("spark.yarn.am.attemptFailuresValidityInterval")) { + try { + val interval = sparkConf.getTimeAsMs("spark.yarn.am.attemptFailuresValidityInterval") + val method = appContext.getClass().getMethod( + "setAttemptFailuresValidityInterval", classOf[Long]) + method.invoke(appContext, interval: java.lang.Long) + } catch { + case e: NoSuchMethodException => + logWarning("Ignoring spark.yarn.am.attemptFailuresValidityInterval because the version " + + "of YARN does not support it") + } + } + val capability = Records.newRecord(classOf[Resource]) capability.setMemory(args.amMemory + amMemoryOverhead) capability.setVirtualCores(args.amCores) - appContext.setResource(capability) + + if (sparkConf.contains("spark.yarn.am.nodeLabelExpression")) { + try { + val amRequest = Records.newRecord(classOf[ResourceRequest]) + amRequest.setResourceName(ResourceRequest.ANY) + amRequest.setPriority(Priority.newInstance(0)) + amRequest.setCapability(capability) + amRequest.setNumContainers(1) + val amLabelExpression = sparkConf.get("spark.yarn.am.nodeLabelExpression") + val method = amRequest.getClass.getMethod("setNodeLabelExpression", classOf[String]) + method.invoke(amRequest, amLabelExpression) + + val setResourceRequestMethod = + appContext.getClass.getMethod("setAMContainerResourceRequest", classOf[ResourceRequest]) + setResourceRequestMethod.invoke(appContext, amRequest) + } catch { + case e: NoSuchMethodException => + logWarning("Ignoring spark.yarn.am.nodeLabelExpression because the version " + + "of YARN does not support it") + appContext.setResource(capability) + } + } else { + appContext.setResource(capability) + } + appContext } @@ -203,12 +281,15 @@ private[spark] class Client( val executorMem = args.executorMemory + executorMemoryOverhead if (executorMem > maxMem) { throw new IllegalArgumentException(s"Required executor memory (${args.executorMemory}" + - s"+$executorMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster!") + s"+$executorMemoryOverhead MB) 
is above the max threshold ($maxMem MB) of this cluster! " + + "Please check the values of 'yarn.scheduler.maximum-allocation-mb' and/or " + + "'yarn.nodemanager.resource.memory-mb'.") } val amMem = args.amMemory + amMemoryOverhead if (amMem > maxMem) { throw new IllegalArgumentException(s"Required AM memory (${args.amMemory}" + - s"+$amMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster!") + s"+$amMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster! " + + "Please increase the value of 'yarn.scheduler.maximum-allocation-mb'.") } logInfo("Will allocate AM container, with %d MB memory including %d MB overhead".format( amMem, @@ -266,8 +347,8 @@ private[spark] class Client( // multiple times, YARN will fail to launch containers for the app with an internal // error. val distributedUris = new HashSet[String] - obtainTokenForHiveMetastore(hadoopConf, credentials) - obtainTokenForHBase(hadoopConf, credentials) + obtainTokenForHiveMetastore(sparkConf, hadoopConf, credentials) + obtainTokenForHBase(sparkConf, hadoopConf, credentials) val replication = sparkConf.getInt("spark.yarn.submit.file.replication", fs.getDefaultReplication(dst)).toShort @@ -316,7 +397,8 @@ private[spark] class Client( destName: Option[String] = None, targetDir: Option[String] = None, appMasterOnly: Boolean = false): (Boolean, String) = { - val localURI = new URI(path.trim()) + val trimmedPath = path.trim() + val localURI = Utils.resolveURI(trimmedPath) if (localURI.getScheme != LOCAL_SCHEME) { if (addDistributedUri(localURI)) { val localPath = getQualifiedLocalPath(localURI, hadoopConf) @@ -332,7 +414,7 @@ private[spark] class Client( (false, null) } } else { - (true, path.trim()) + (true, trimmedPath) } } @@ -412,7 +494,7 @@ private[spark] class Client( } // Distribute an archive with Hadoop and Spark configuration for the AM. - val (_, confLocalizedPath) = distribute(createConfArchive().getAbsolutePath(), + val (_, confLocalizedPath) = distribute(createConfArchive().toURI().getPath(), resType = LocalResourceType.ARCHIVE, destName = Some(LOCALIZED_CONF_DIR), appMasterOnly = true) @@ -440,6 +522,19 @@ private[spark] class Client( */ private def createConfArchive(): File = { val hadoopConfFiles = new HashMap[String, File]() + + // Uploading $SPARK_CONF_DIR/log4j.properties file to the distributed cache to make sure that + // the executors will use the latest configurations instead of the default values. This is + // required when user changes log4j.properties directly to set the log configurations. If + // configuration file is provided through --files then executors will be taking configurations + // from --files instead of $SPARK_CONF_DIR/log4j.properties. 
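An aside on the reflection pattern used above for setApplicationTags, setAttemptFailuresValidityInterval and setNodeLabelExpression: these YARN APIs only exist in newer Hadoop releases, so they must be invoked reflectively or the module would not compile against older versions. A minimal sketch of the idea (the helper and object names are ours, not Spark's):

object OptionalYarnApiSketch {
  // Invoke `methodName(arg)` on `target` if the running YARN version provides it;
  // return false (so the caller can log a warning) when the method is missing.
  def callIfSupported(target: AnyRef, methodName: String, argType: Class[_], arg: AnyRef): Boolean = {
    try {
      target.getClass.getMethod(methodName, argType).invoke(target, arg)
      true
    } catch {
      case _: NoSuchMethodException => false
    }
  }
  // e.g. callIfSupported(appContext, "setApplicationTags",
  //        classOf[java.util.Set[String]], new java.util.HashSet[String](tagCollection))
}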
+ val log4jFileName = "log4j.properties" + Option(Utils.getContextOrSparkClassLoader.getResource(log4jFileName)).foreach { url => + if (url.getProtocol == "file") { + hadoopConfFiles(log4jFileName) = new File(url.getPath) + } + } + Seq("HADOOP_CONF_DIR", "YARN_CONF_DIR").foreach { envKey => sys.env.get(envKey).foreach { path => val dir = new File(path) @@ -492,7 +587,7 @@ private[spark] class Client( val nns = YarnSparkHadoopUtil.get.getNameNodesToAccess(sparkConf) + stagingDirPath YarnSparkHadoopUtil.get.obtainTokensForNamenodes( nns, hadoopConf, creds, Some(sparkConf.get("spark.yarn.principal"))) - val t = creds.getAllTokens + val t = creds.getAllTokens.asScala .filter(_.getKind == DelegationTokenIdentifier.HDFS_DELEGATION_KIND) .head val newExpiration = t.renew(hadoopConf) @@ -553,10 +648,10 @@ private[spark] class Client( LOCALIZED_PYTHON_DIR) } (pySparkArchives ++ pyArchives).foreach { path => - val uri = new URI(path) + val uri = Utils.resolveURI(path) if (uri.getScheme != LOCAL_SCHEME) { pythonPath += buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), - new Path(path).getName()) + new Path(uri).getName()) } else { pythonPath += uri.getPath() } @@ -631,8 +726,8 @@ private[spark] class Client( distCacheMgr.setDistArchivesEnv(launchEnv) val amContainer = Records.newRecord(classOf[ContainerLaunchContext]) - amContainer.setLocalResources(localResources) - amContainer.setEnvironment(launchEnv) + amContainer.setLocalResources(localResources.asJava) + amContainer.setEnvironment(launchEnv.asJava) val javaOpts = ListBuffer[String]() @@ -707,6 +802,7 @@ private[spark] class Client( // For log4j configuration to reference javaOpts += ("-Dspark.yarn.app.container.log.dir=" + ApplicationConstants.LOG_DIR_EXPANSION_VAR) + YarnCommandBuilderUtils.addPermGenSizeOpt(javaOpts) val userClass = if (isClusterMode) { @@ -749,7 +845,6 @@ private[spark] class Client( userArgs ++ Seq( "--executor-memory", args.executorMemory.toString + "m", "--executor-cores", args.executorCores.toString, - "--num-executors ", args.numExecutors.toString, "--properties-file", buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), LOCALIZED_CONF_DIR, SPARK_CONF_FILE)) @@ -764,7 +859,7 @@ private[spark] class Client( // TODO: it would be nicer to just make sure there are no null commands here val printableCommands = commands.map(s => if (s == null) "null" else s).toList - amContainer.setCommands(printableCommands) + amContainer.setCommands(printableCommands.asJava) logDebug("===============================================================================") logDebug("YARN AM launch context:") @@ -779,7 +874,8 @@ private[spark] class Client( // send the acl settings into YARN to control who has access via YARN interfaces val securityManager = new SecurityManager(sparkConf) - amContainer.setApplicationACLs(YarnSparkHadoopUtil.getApplicationAclsForYarn(securityManager)) + amContainer.setApplicationACLs( + YarnSparkHadoopUtil.getApplicationAclsForYarn(securityManager).asJava) setupSecurityToken(amContainer) UserGroupInformation.getCurrentUser().addCredentials(credentials) @@ -856,6 +952,20 @@ private[spark] class Client( } } + if (lastState != state) { + state match { + case YarnApplicationState.RUNNING => + reportLauncherState(SparkAppHandle.State.RUNNING) + case YarnApplicationState.FINISHED => + reportLauncherState(SparkAppHandle.State.FINISHED) + case YarnApplicationState.FAILED => + reportLauncherState(SparkAppHandle.State.FAILED) + case YarnApplicationState.KILLED => + 
reportLauncherState(SparkAppHandle.State.KILLED) + case _ => + } + } + if (state == YarnApplicationState.FINISHED || state == YarnApplicationState.FAILED || state == YarnApplicationState.KILLED) { @@ -903,8 +1013,8 @@ private[spark] class Client( * throw an appropriate SparkException. */ def run(): Unit = { - val appId = submitApplication() - if (fireAndForget) { + this.appId = submitApplication() + if (!launcherBackend.isConnected() && fireAndForget) { val report = getApplicationReport(appId) val state = report.getYarnApplicationState logInfo(s"Application report for $appId (state: $state)") @@ -936,9 +1046,9 @@ private[spark] class Client( val pyArchivesFile = new File(pyLibPath, "pyspark.zip") require(pyArchivesFile.exists(), "pyspark.zip not found; cannot run pyspark application in YARN mode.") - val py4jFile = new File(pyLibPath, "py4j-0.8.2.1-src.zip") + val py4jFile = new File(pyLibPath, "py4j-0.9-src.zip") require(py4jFile.exists(), - "py4j-0.8.2.1-src.zip not found; cannot run pyspark application in YARN mode.") + "py4j-0.9-src.zip not found; cannot run pyspark application in YARN mode.") Seq(pyArchivesFile.getAbsolutePath(), py4jFile.getAbsolutePath()) } } @@ -946,6 +1056,7 @@ private[spark] class Client( } object Client extends Logging { + def main(argStrings: Array[String]) { if (!sys.props.contains("SPARK_SUBMIT")) { logWarning("WARNING: This client is deprecated and will be removed in a " + @@ -958,6 +1069,10 @@ object Client extends Logging { val sparkConf = new SparkConf val args = new ClientArguments(argStrings, sparkConf) + // to maintain backwards-compatibility + if (!Utils.isDynamicAllocationEnabled(sparkConf)) { + sparkConf.setIfMissing("spark.executor.instances", args.numExecutors.toString) + } new Client(args, sparkConf).run() } @@ -982,6 +1097,10 @@ object Client extends Logging { // of the executors val CONF_SPARK_YARN_SECONDARY_JARS = "spark.yarn.secondary.jars" + // Comma-separated list of strings to pass through as YARN application tags appearing + // in YARN ApplicationReports, which can be used for filtering when querying YARN. + val CONF_SPARK_YARN_APPLICATION_TAGS = "spark.yarn.tags" + // Staging directory is private! -> rwx-------- val STAGING_DIR_PERMISSION: FsPermission = FsPermission.createImmutable(Integer.parseInt("700", 8).toShort) @@ -1018,7 +1137,10 @@ object Client extends Logging { s"in favor of the $CONF_SPARK_JAR configuration variable.") System.getenv(ENV_SPARK_JAR) } else { - SparkContext.jarOfClass(this.getClass).head + SparkContext.jarOfClass(this.getClass).getOrElse(throw new SparkException("Could not " + + "find jar containing Spark classes. The jar can be defined using the " + + "spark.yarn.jar configuration option. If testing Spark, either set that option or " + + "make sure SPARK_PREPEND_CLASSES is not set.")) } } @@ -1072,20 +1194,10 @@ object Client extends Logging { triedDefault.toOption } - /** - * In Hadoop 0.23, the MR application classpath comes with the YARN application - * classpath. In Hadoop 2.0, it's an array of Strings, and in 2.2+ it's a String. - * So we need to use reflection to retrieve it. 
- */ private[yarn] def getDefaultMRApplicationClasspath: Option[Seq[String]] = { val triedDefault = Try[Seq[String]] { val field = classOf[MRJobConfig].getField("DEFAULT_MAPREDUCE_APPLICATION_CLASSPATH") - val value = if (field.getType == classOf[String]) { - StringUtils.getStrings(field.get(null).asInstanceOf[String]).toArray - } else { - field.get(null).asInstanceOf[Array[String]] - } - value.toSeq + StringUtils.getStrings(field.get(null).asInstanceOf[String]).toSeq } recoverWith { case e: NoSuchFieldException => Success(Seq.empty[String]) } @@ -1129,17 +1241,28 @@ object Client extends Logging { } if (sparkConf.getBoolean("spark.yarn.user.classpath.first", false)) { - val userClassPath = + // in order to properly add the app jar when user classpath is first + // we have to do the mainJar separate in order to send the right thing + // into addFileToClasspath + val mainJar = + if (args != null) { + getMainJarUri(Option(args.userJar)) + } else { + getMainJarUri(sparkConf.getOption(CONF_SPARK_USER_JAR)) + } + mainJar.foreach(addFileToClasspath(sparkConf, conf, _, APP_JAR, env)) + + val secondaryJars = if (args != null) { - getUserClasspath(Option(args.userJar), Option(args.addJars)) + getSecondaryJarUris(Option(args.addJars)) } else { - getUserClasspath(sparkConf) + getSecondaryJarUris(sparkConf.getOption(CONF_SPARK_YARN_SECONDARY_JARS)) } - userClassPath.foreach { x => - addFileToClasspath(sparkConf, x, null, env) + secondaryJars.foreach { x => + addFileToClasspath(sparkConf, conf, x, null, env) } } - addFileToClasspath(sparkConf, new URI(sparkJar(sparkConf)), SPARK_JAR, env) + addFileToClasspath(sparkConf, conf, new URI(sparkJar(sparkConf)), SPARK_JAR, env) populateHadoopClasspath(conf, env) sys.env.get(ENV_DIST_CLASSPATH).foreach { cp => addClasspathEntry(getClusterPath(sparkConf, cp), env) @@ -1152,16 +1275,20 @@ object Client extends Logging { * @param conf Spark configuration. */ def getUserClasspath(conf: SparkConf): Array[URI] = { - getUserClasspath(conf.getOption(CONF_SPARK_USER_JAR), - conf.getOption(CONF_SPARK_YARN_SECONDARY_JARS)) + val mainUri = getMainJarUri(conf.getOption(CONF_SPARK_USER_JAR)) + val secondaryUris = getSecondaryJarUris(conf.getOption(CONF_SPARK_YARN_SECONDARY_JARS)) + (mainUri ++ secondaryUris).toArray } - private def getUserClasspath( - mainJar: Option[String], - secondaryJars: Option[String]): Array[URI] = { - val mainUri = mainJar.orElse(Some(APP_JAR)).map(new URI(_)) - val secondaryUris = secondaryJars.map(_.split(",")).toSeq.flatten.map(new URI(_)) - (mainUri ++ secondaryUris).toArray + private def getMainJarUri(mainJar: Option[String]): Option[URI] = { + mainJar.flatMap { path => + val uri = Utils.resolveURI(path) + if (uri.getScheme == LOCAL_SCHEME) Some(uri) else None + }.orElse(Some(new URI(APP_JAR))) + } + + private def getSecondaryJarUris(secondaryJars: Option[String]): Seq[URI] = { + secondaryJars.map(_.split(",")).toSeq.flatten.map(new URI(_)) } /** @@ -1170,15 +1297,17 @@ object Client extends Logging { * If an alternate name for the file is given, and it's not a "local:" file, the alternate * name will be added to the classpath (relative to the job's work directory). * - * If not a "local:" file and no alternate name, the environment is not modified. + * If not a "local:" file and no alternate name, the linkName will be added to the classpath. * - * @param conf Spark configuration. - * @param uri URI to add to classpath (optional). - * @param fileName Alternate name for the file (optional). - * @param env Map holding the environment variables. 
+ * @param conf Spark configuration. + * @param hadoopConf Hadoop configuration. + * @param uri URI to add to classpath (optional). + * @param fileName Alternate name for the file (optional). + * @param env Map holding the environment variables. */ private def addFileToClasspath( conf: SparkConf, + hadoopConf: Configuration, uri: URI, fileName: String, env: HashMap[String, String]): Unit = { @@ -1187,6 +1316,11 @@ object Client extends Logging { } else if (fileName != null) { addClasspathEntry(buildPath( YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), fileName), env) + } else if (uri != null) { + val localPath = getQualifiedLocalPath(uri, hadoopConf) + val linkName = Option(uri.getFragment()).getOrElse(localPath.getName()) + addClasspathEntry(buildPath( + YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), linkName), env) } } @@ -1202,11 +1336,11 @@ object Client extends Logging { * * This method uses two configuration values: * - * - spark.yarn.config.gatewayPath: a string that identifies a portion of the input path that may - * only be valid in the gateway node. - * - spark.yarn.config.replacementPath: a string with which to replace the gateway path. This may - * contain, for example, env variable references, which will be expanded by the NMs when - * starting containers. + * - spark.yarn.config.gatewayPath: a string that identifies a portion of the input path that may + * only be valid in the gateway node. + * - spark.yarn.config.replacementPath: a string with which to replace the gateway path. This may + * contain, for example, env variable references, which will be expanded by the NMs when + * starting containers. * * If either config is not available, the input path is returned. */ @@ -1223,93 +1357,28 @@ object Client extends Logging { /** * Obtains token for the Hive metastore and adds them to the credentials. 
*/ - private def obtainTokenForHiveMetastore(conf: Configuration, credentials: Credentials) { - if (UserGroupInformation.isSecurityEnabled) { - val mirror = universe.runtimeMirror(getClass.getClassLoader) - - try { - val hiveClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.ql.metadata.Hive") - val hive = hiveClass.getMethod("get").invoke(null) - - val hiveConf = hiveClass.getMethod("getConf").invoke(hive) - val hiveConfClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.conf.HiveConf") - - val hiveConfGet = (param: String) => Option(hiveConfClass - .getMethod("get", classOf[java.lang.String]) - .invoke(hiveConf, param)) - - val metastore_uri = hiveConfGet("hive.metastore.uris") - - // Check for local metastore - if (metastore_uri != None && metastore_uri.get.toString.size > 0) { - val metastore_kerberos_principal_conf_var = mirror.classLoader - .loadClass("org.apache.hadoop.hive.conf.HiveConf$ConfVars") - .getField("METASTORE_KERBEROS_PRINCIPAL").get("varname").toString - - val principal = hiveConfGet(metastore_kerberos_principal_conf_var) - - val username = Option(UserGroupInformation.getCurrentUser().getUserName) - if (principal != None && username != None) { - val tokenStr = hiveClass.getMethod("getDelegationToken", - classOf[java.lang.String], classOf[java.lang.String]) - .invoke(hive, username.get, principal.get).asInstanceOf[java.lang.String] - - val hive2Token = new Token[DelegationTokenIdentifier]() - hive2Token.decodeFromUrlString(tokenStr) - credentials.addToken(new Text("hive.server2.delegation.token"), hive2Token) - logDebug("Added hive.Server2.delegation.token to conf.") - hiveClass.getMethod("closeCurrent").invoke(null) - } else { - logError("Username or principal == NULL") - logError(s"""username=${username.getOrElse("(NULL)")}""") - logError(s"""principal=${principal.getOrElse("(NULL)")}""") - throw new IllegalArgumentException("username and/or principal is equal to null!") - } - } else { - logDebug("HiveMetaStore configured in localmode") - } - } catch { - case e: java.lang.NoSuchMethodException => { logInfo("Hive Method not found " + e); return } - case e: java.lang.ClassNotFoundException => { logInfo("Hive Class not found " + e); return } - case e: Exception => { logError("Unexpected Exception " + e) - throw new RuntimeException("Unexpected exception", e) - } + private def obtainTokenForHiveMetastore( + sparkConf: SparkConf, + conf: Configuration, + credentials: Credentials) { + if (shouldGetTokens(sparkConf, "hive") && UserGroupInformation.isSecurityEnabled) { + YarnSparkHadoopUtil.get.obtainTokenForHiveMetastore(conf).foreach { + credentials.addToken(new Text("hive.server2.delegation.token"), _) } } } /** - * Obtain security token for HBase. + * Obtain a security token for HBase. */ - def obtainTokenForHBase(conf: Configuration, credentials: Credentials): Unit = { - if (UserGroupInformation.isSecurityEnabled) { - val mirror = universe.runtimeMirror(getClass.getClassLoader) - - try { - val confCreate = mirror.classLoader. - loadClass("org.apache.hadoop.hbase.HBaseConfiguration"). - getMethod("create", classOf[Configuration]) - val obtainToken = mirror.classLoader. - loadClass("org.apache.hadoop.hbase.security.token.TokenUtil"). 
- getMethod("obtainToken", classOf[Configuration]) - - logDebug("Attempting to fetch HBase security token.") - - val hbaseConf = confCreate.invoke(null, conf).asInstanceOf[Configuration] - if ("kerberos" == hbaseConf.get("hbase.security.authentication")) { - val token = obtainToken.invoke(null, hbaseConf).asInstanceOf[Token[TokenIdentifier]] - credentials.addToken(token.getService, token) - logInfo("Added HBase security token to credentials.") - } - } catch { - case e: java.lang.NoSuchMethodException => - logInfo("HBase Method not found: " + e) - case e: java.lang.ClassNotFoundException => - logDebug("HBase Class not found: " + e) - case e: java.lang.NoClassDefFoundError => - logDebug("HBase Class not found: " + e) - case e: Exception => - logError("Exception when obtaining HBase security token: " + e) + def obtainTokenForHBase( + sparkConf: SparkConf, + conf: Configuration, + credentials: Credentials): Unit = { + if (shouldGetTokens(sparkConf, "hbase") && UserGroupInformation.isSecurityEnabled) { + YarnSparkHadoopUtil.get.obtainTokenForHBase(conf).foreach { token => + credentials.addToken(token.getService, token) + logInfo("Added HBase security token to credentials.") } } } @@ -1378,4 +1447,13 @@ object Client extends Logging { components.mkString(Path.SEPARATOR) } + /** + * Return whether delegation tokens should be retrieved for the given service when security is + * enabled. By default, tokens are retrieved, but that behavior can be changed by setting + * a service-specific configuration. + */ + def shouldGetTokens(conf: SparkConf, service: String): Boolean = { + conf.getBoolean(s"spark.yarn.security.tokens.${service}.enabled", true) + } + } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 20d63d40cf60..a9f437435735 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -53,8 +53,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) private val amMemOverheadKey = "spark.yarn.am.memoryOverhead" private val driverCoresKey = "spark.driver.cores" private val amCoresKey = "spark.yarn.am.cores" - private val isDynamicAllocationEnabled = - sparkConf.getBoolean("spark.dynamicAllocation.enabled", false) + private val isDynamicAllocationEnabled = Utils.isDynamicAllocationEnabled(sparkConf) parseArgs(args.toList) loadEnvironmentArgs() @@ -82,22 +81,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) .orNull // If dynamic allocation is enabled, start at the configured initial number of executors. // Default to minExecutors if no initialExecutors is set. 
- if (isDynamicAllocationEnabled) { - val minExecutorsConf = "spark.dynamicAllocation.minExecutors" - val initialExecutorsConf = "spark.dynamicAllocation.initialExecutors" - val maxExecutorsConf = "spark.dynamicAllocation.maxExecutors" - val minNumExecutors = sparkConf.getInt(minExecutorsConf, 0) - val initialNumExecutors = sparkConf.getInt(initialExecutorsConf, minNumExecutors) - val maxNumExecutors = sparkConf.getInt(maxExecutorsConf, Integer.MAX_VALUE) - - // If defined, initial executors must be between min and max - if (initialNumExecutors < minNumExecutors || initialNumExecutors > maxNumExecutors) { - throw new IllegalArgumentException( - s"$initialExecutorsConf must be between $minExecutorsConf and $maxNumExecutors!") - } - - numExecutors = initialNumExecutors - } + numExecutors = YarnSparkHadoopUtil.getInitialTargetExecutorNumber(sparkConf, numExecutors) principal = Option(principal) .orElse(sparkConf.getOption("spark.yarn.principal")) .orNull @@ -196,11 +180,6 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) if (args(0) == "--num-workers") { println("--num-workers is deprecated. Use --num-executors instead.") } - // Dynamic allocation is not compatible with this option - if (isDynamicAllocationEnabled) { - throw new IllegalArgumentException("Explicitly setting the number " + - "of executors is not compatible with spark.dynamicAllocation.enabled!") - } numExecutors = value args = tail diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 52580deb372c..2232ffba473b 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -20,14 +20,13 @@ package org.apache.spark.deploy.yarn import java.io.File import java.net.URI import java.nio.ByteBuffer +import java.util.Collections -import org.apache.hadoop.fs.Path -import org.apache.hadoop.yarn.api.ApplicationConstants.Environment -import org.apache.spark.util.Utils - -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{HashMap, ListBuffer} +import org.apache.hadoop.fs.Path +import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.DataOutputBuffer import org.apache.hadoop.security.UserGroupInformation @@ -39,7 +38,9 @@ import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{ConverterUtils, Records} import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} +import org.apache.spark.launcher.YarnCommandBuilderUtils import org.apache.spark.network.util.JavaUtils +import org.apache.spark.util.Utils class ExecutorRunnable( container: Container, @@ -74,9 +75,9 @@ class ExecutorRunnable( .asInstanceOf[ContainerLaunchContext] val localResources = prepareLocalResources - ctx.setLocalResources(localResources) + ctx.setLocalResources(localResources.asJava) - ctx.setEnvironment(env) + ctx.setEnvironment(env.asJava) val credentials = UserGroupInformation.getCurrentUser().getCredentials() val dob = new DataOutputBuffer() @@ -96,8 +97,9 @@ class ExecutorRunnable( |=============================================================================== """.stripMargin) - ctx.setCommands(commands) - ctx.setApplicationACLs(YarnSparkHadoopUtil.getApplicationAclsForYarn(securityMgr)) + ctx.setCommands(commands.asJava) + 
ctx.setApplicationACLs( + YarnSparkHadoopUtil.getApplicationAclsForYarn(securityMgr).asJava) // If external shuffle service is enabled, register with the Yarn shuffle service already // started on the NodeManager and, if authentication is enabled, provide it with our secret @@ -112,7 +114,7 @@ class ExecutorRunnable( // Authentication is not enabled, so just provide dummy metadata ByteBuffer.allocate(0) } - ctx.setServiceData(Map[String, ByteBuffer]("spark_shuffle" -> secretBytes)) + ctx.setServiceData(Collections.singletonMap("spark_shuffle", secretBytes)) } // Send the start request to the ContainerManager @@ -198,6 +200,7 @@ class ExecutorRunnable( // For log4j configuration to reference javaOpts += ("-Dspark.yarn.app.container.log.dir=" + ApplicationConstants.LOG_DIR_EXPANSION_VAR) + YarnCommandBuilderUtils.addPermGenSizeOpt(javaOpts) val userClassPath = Client.getUserClasspath(sparkConf).flatMap { uri => val absPath = @@ -217,7 +220,7 @@ class ExecutorRunnable( // an inconsistent state. // TODO: If the OOM is not recoverable by rescheduling it on different node, then do // 'something' to fail job ... akin to blacklisting trackers in mapred ? - "-XX:OnOutOfMemoryError='kill %p'") ++ + YarnSparkHadoopUtil.getOutOfMemoryErrorArgument) ++ javaOpts ++ Seq("org.apache.spark.executor.CoarseGrainedExecutorBackend", "--driver-url", masterAddress.toString, @@ -314,7 +317,8 @@ class ExecutorRunnable( env("SPARK_LOG_URL_STDOUT") = s"$baseUrl/stdout?start=-4096" } - System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k, v) => env(k) = v } + System.getenv().asScala.filterKeys(_.startsWith("SPARK")) + .foreach { case (k, v) => env(k) = v } env } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala index 081780204e42..2ec189de7c91 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala @@ -18,9 +18,11 @@ package org.apache.spark.deploy.yarn import scala.collection.mutable.{ArrayBuffer, HashMap, Set} +import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.yarn.api.records.{ContainerId, Resource} +import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest import org.apache.hadoop.yarn.util.RackResolver import org.apache.spark.SparkConf @@ -30,8 +32,8 @@ private[yarn] case class ContainerLocalityPreferences(nodes: Array[String], rack /** * This strategy is calculating the optimal locality preferences of YARN containers by considering * the node ratio of pending tasks, number of required cores/containers and and locality of current - * existing containers. The target of this algorithm is to maximize the number of tasks that - * would run locally. + * existing and pending allocated containers. The target of this algorithm is to maximize the number + * of tasks that would run locally. 
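The ExecutorRunnable changes above are part of a broader migration from scala.collection.JavaConversions (implicit conversions) to JavaConverters (explicit .asJava / .asScala). A minimal, self-contained illustration of the two directions used in this file:

import java.util.{Map => JMap}
import scala.collection.JavaConverters._

object ConvertersSketch {
  def main(args: Array[String]): Unit = {
    val env = Map("SPARK_HOME" -> "/opt/spark", "PATH" -> "/usr/bin")

    // Scala -> Java, as in ctx.setEnvironment(env.asJava).
    val javaEnv: JMap[String, String] = env.asJava
    println(javaEnv.get("SPARK_HOME"))

    // Java -> Scala, as in System.getenv().asScala.filterKeys(_.startsWith("SPARK")).
    val sparkVars = System.getenv().asScala.filterKeys(_.startsWith("SPARK"))
    println(sparkVars.keys.toList)
  }
}

Making the conversions explicit avoids the surprising implicit wrapping that JavaConversions performs, which is the design choice behind this repeated change.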
* * Consider a situation in which we have 20 tasks that require (host1, host2, host3) * and 10 tasks that require (host1, host2, host4), besides each container has 2 cores @@ -91,6 +93,11 @@ private[yarn] class LocalityPreferredContainerPlacementStrategy( * @param numLocalityAwareTasks number of locality required tasks * @param hostToLocalTaskCount a map to store the preferred hostname and possible task * numbers running on it, used as hints for container allocation + * @param allocatedHostToContainersMap host to allocated containers map, used to calculate the + * expected locality preference by considering the existing + * containers + * @param localityMatchedPendingAllocations A sequence of pending container request which + * matches the localities of current required tasks. * @return node localities and rack localities, each locality is an array of string, * the length of localities is the same as number of containers */ @@ -98,10 +105,12 @@ private[yarn] class LocalityPreferredContainerPlacementStrategy( numContainer: Int, numLocalityAwareTasks: Int, hostToLocalTaskCount: Map[String, Int], - allocatedHostToContainersMap: HashMap[String, Set[ContainerId]] + allocatedHostToContainersMap: HashMap[String, Set[ContainerId]], + localityMatchedPendingAllocations: Seq[ContainerRequest] ): Array[ContainerLocalityPreferences] = { val updatedHostToContainerCount = expectedHostToContainerCount( - numLocalityAwareTasks, hostToLocalTaskCount, allocatedHostToContainersMap) + numLocalityAwareTasks, hostToLocalTaskCount, allocatedHostToContainersMap, + localityMatchedPendingAllocations) val updatedLocalityAwareContainerNum = updatedHostToContainerCount.values.sum // The number of containers to allocate, divided into two groups, one with preferred locality, @@ -158,20 +167,28 @@ private[yarn] class LocalityPreferredContainerPlacementStrategy( * @param localityAwareTasks number of locality aware tasks * @param hostToLocalTaskCount a map to store the preferred hostname and possible task * numbers running on it, used as hints for container allocation + * @param allocatedHostToContainersMap host to allocated containers map, used to calculate the + * expected locality preference by considering the existing + * containers + * @param localityMatchedPendingAllocations A sequence of pending container request which + * matches the localities of current required tasks. * @return a map with hostname as key and required number of containers on this host as value */ private def expectedHostToContainerCount( localityAwareTasks: Int, hostToLocalTaskCount: Map[String, Int], - allocatedHostToContainersMap: HashMap[String, Set[ContainerId]] + allocatedHostToContainersMap: HashMap[String, Set[ContainerId]], + localityMatchedPendingAllocations: Seq[ContainerRequest] ): Map[String, Int] = { val totalLocalTaskNum = hostToLocalTaskCount.values.sum + val pendingHostToContainersMap = pendingHostToContainerCount(localityMatchedPendingAllocations) + hostToLocalTaskCount.map { case (host, count) => val expectedCount = count.toDouble * numExecutorsPending(localityAwareTasks) / totalLocalTaskNum - val existedCount = allocatedHostToContainersMap.get(host) - .map(_.size) - .getOrElse(0) + // Take the locality of pending containers into consideration + val existedCount = allocatedHostToContainersMap.get(host).map(_.size).getOrElse(0) + + pendingHostToContainersMap.getOrElse(host, 0.0) // If existing container can not fully satisfy the expected number of container, // the required container number is expected count minus existed count. 
Otherwise the @@ -179,4 +196,31 @@ private[yarn] class LocalityPreferredContainerPlacementStrategy( (host, math.max(0, (expectedCount - existedCount).ceil.toInt)) } } + + /** + * According to the locality ratio and number of container requests, calculate the host to + * possible number of containers for pending allocated containers. + * + * If current locality ratio of hosts is: Host1 : Host2 : Host3 = 20 : 20 : 10, + * and pending container requests is 3, so the possible number of containers on + * Host1 : Host2 : Host3 will be 1.2 : 1.2 : 0.6. + * @param localityMatchedPendingAllocations A sequence of pending container request which + * matches the localities of current required tasks. + * @return a Map with hostname as key and possible number of containers on this host as value + */ + private def pendingHostToContainerCount( + localityMatchedPendingAllocations: Seq[ContainerRequest]): Map[String, Double] = { + val pendingHostToContainerCount = new HashMap[String, Int]() + localityMatchedPendingAllocations.foreach { cr => + cr.getNodes.asScala.foreach { n => + val count = pendingHostToContainerCount.getOrElse(n, 0) + 1 + pendingHostToContainerCount(n) = count + } + } + + val possibleTotalContainerNum = pendingHostToContainerCount.values.sum + val localityMatchedPendingNum = localityMatchedPendingAllocations.size.toDouble + pendingHostToContainerCount.mapValues(_ * localityMatchedPendingNum / possibleTotalContainerNum) + .toMap + } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 59caa787b6e2..4e044aa4788d 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -21,10 +21,9 @@ import java.util.Collections import java.util.concurrent._ import java.util.regex.Pattern -import scala.collection.JavaConversions._ +import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} - -import com.google.common.util.concurrent.ThreadFactoryBuilder +import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.yarn.api.records._ @@ -34,11 +33,12 @@ import org.apache.hadoop.yarn.util.RackResolver import org.apache.log4j.{Level, Logger} -import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ -import org.apache.spark.rpc.RpcEndpointRef -import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend -import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ +import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef} +import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason} +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor +import org.apache.spark.util.ThreadUtils /** * YarnAllocator is charged with requesting containers from the YARN ResourceManager and deciding @@ -86,7 +86,17 @@ private[yarn] class YarnAllocator( private var executorIdCounter = 0 @volatile private var numExecutorsFailed = 0 - @volatile private var targetNumExecutors = args.numExecutors + @volatile private var targetNumExecutors = + YarnSparkHadoopUtil.getInitialTargetExecutorNumber(sparkConf) + + // Executor loss reason requests that are pending - maps from executor ID for inquiry to a + // list of requesters that should be 
responded to once we find out why the given executor + // was lost. + private val pendingLossReasonRequests = new HashMap[String, mutable.Buffer[RpcCallContext]] + + // Maintain loss reasons for already released executors, it will be added when executor loss + // reason is got from AM-RM call, and be removed after querying this loss reason. + private val releasedExecutorLossReasons = new HashMap[String, ExecutorLossReason] // Keep track of which container is running which executor to remove the executors later // Visible for testing. @@ -105,13 +115,9 @@ private[yarn] class YarnAllocator( // Resource capability requested for each executors private[yarn] val resource = Resource.newInstance(executorMemory + memoryOverhead, executorCores) - private val launcherPool = new ThreadPoolExecutor( - // max pool size of Integer.MAX_VALUE is ignored because we use an unbounded queue - sparkConf.getInt("spark.yarn.containerLauncherMaxThreads", 25), Integer.MAX_VALUE, - 1, TimeUnit.MINUTES, - new LinkedBlockingQueue[Runnable](), - new ThreadFactoryBuilder().setNameFormat("ContainerLauncher #%d").setDaemon(true).build()) - launcherPool.allowCoreThreadTimeOut(true) + private val launcherPool = ThreadUtils.newDaemonCachedThreadPool( + "ContainerLauncher", + sparkConf.getInt("spark.yarn.containerLauncherMaxThreads", 25)) // For testing private val launchContainers = sparkConf.getBoolean("spark.yarn.launchContainers", true) @@ -149,15 +155,19 @@ private[yarn] class YarnAllocator( def getNumExecutorsFailed: Int = numExecutorsFailed /** - * Number of container requests that have not yet been fulfilled. + * A sequence of pending container requests that have not yet been fulfilled. */ - def getNumPendingAllocate: Int = getNumPendingAtLocation(ANY_HOST) + def getPendingAllocate: Seq[ContainerRequest] = getPendingAtLocation(ANY_HOST) /** - * Number of container requests at the given location that have not yet been fulfilled. + * A sequence of pending container requests at the given location that have not yet been + * fulfilled. */ - private def getNumPendingAtLocation(location: String): Int = - amClient.getMatchingRequests(RM_REQUEST_PRIORITY, location, resource).map(_.size).sum + private def getPendingAtLocation(location: String): Seq[ContainerRequest] = { + amClient.getMatchingRequests(RM_REQUEST_PRIORITY, location, resource).asScala + .flatMap(_.asScala) + .toSeq + } /** * Request as many executors from the ResourceManager as needed to reach the desired total. If @@ -190,8 +200,7 @@ private[yarn] class YarnAllocator( */ def killExecutor(executorId: String): Unit = synchronized { if (executorIdToContainer.contains(executorId)) { - val container = executorIdToContainer.remove(executorId).get - containerIdToExecutorId.remove(container.getId) + val container = executorIdToContainer.get(executorId).get internalReleaseContainer(container) numExecutorsRunning -= 1 } else { @@ -224,15 +233,13 @@ private[yarn] class YarnAllocator( numExecutorsRunning, allocateResponse.getAvailableResources)) - handleAllocatedContainers(allocatedContainers) + handleAllocatedContainers(allocatedContainers.asScala) } val completedContainers = allocateResponse.getCompletedContainersStatuses() if (completedContainers.size > 0) { logDebug("Completed %d containers".format(completedContainers.size)) - - processCompletedContainers(completedContainers) - + processCompletedContainers(completedContainers.asScala) logDebug("Finished processing %d completed containers. Current running executor count: %d." 
.format(completedContainers.size, numExecutorsRunning)) } @@ -245,26 +252,37 @@ private[yarn] class YarnAllocator( * Visible for testing. */ def updateResourceRequests(): Unit = { - val numPendingAllocate = getNumPendingAllocate + val pendingAllocate = getPendingAllocate + val numPendingAllocate = pendingAllocate.size val missing = targetNumExecutors - numPendingAllocate - numExecutorsRunning - // TODO. Consider locality preferences of pending container requests. - // Since the last time we made container requests, stages have completed and been submitted, - // and that the localities at which we requested our pending executors - // no longer apply to our current needs. We should consider to remove all outstanding - // container requests and add requests anew each time to avoid this. if (missing > 0) { logInfo(s"Will request $missing executor containers, each with ${resource.getVirtualCores} " + s"cores and ${resource.getMemory} MB memory including $memoryOverhead MB overhead") + // Split the pending container request into three groups: locality matched list, locality + // unmatched list and non-locality list. Take the locality matched container request into + // consideration of container placement, treat as allocated containers. + // For locality unmatched and locality free container requests, cancel these container + // requests, since required locality preference has been changed, recalculating using + // container placement strategy. + val (localityMatched, localityUnMatched, localityFree) = splitPendingAllocationsByLocality( + hostToLocalTaskCounts, pendingAllocate) + + // Remove the outdated container request and recalculate the requested container number + localityUnMatched.foreach(amClient.removeContainerRequest) + localityFree.foreach(amClient.removeContainerRequest) + val updatedNumContainer = missing + localityUnMatched.size + localityFree.size + val containerLocalityPreferences = containerPlacementStrategy.localityOfRequestedContainers( - missing, numLocalityAwareTasks, hostToLocalTaskCounts, allocatedHostToContainersMap) + updatedNumContainer, numLocalityAwareTasks, hostToLocalTaskCounts, + allocatedHostToContainersMap, localityMatched) for (locality <- containerLocalityPreferences) { val request = createContainerRequest(resource, locality.nodes, locality.racks) amClient.addContainerRequest(request) val nodes = request.getNodes - val hostStr = if (nodes == null || nodes.isEmpty) "Any" else nodes.last + val hostStr = if (nodes == null || nodes.isEmpty) "Any" else nodes.asScala.last logInfo(s"Container request (host: $hostStr, capability: $resource)") } } else if (missing < 0) { @@ -273,7 +291,8 @@ private[yarn] class YarnAllocator( val matchingRequests = amClient.getMatchingRequests(RM_REQUEST_PRIORITY, ANY_HOST, resource) if (!matchingRequests.isEmpty) { - matchingRequests.head.take(numToCancel).foreach(amClient.removeContainerRequest) + matchingRequests.iterator().next().asScala + .take(numToCancel).foreach(amClient.removeContainerRequest) } else { logWarning("Expected to find pending requests, but found none.") } @@ -284,7 +303,7 @@ private[yarn] class YarnAllocator( * Creates a container request, handling the reflection required to use YARN features that were * added in recent versions. 
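A numeric sketch of the request-reconciliation arithmetic in updateResourceRequests after this change: locality-unmatched and locality-free pending requests are cancelled and folded back into the number of containers to request, while locality-matched pending requests are kept and treated as already placed. The surplus branch caps cancellations at the number of pending requests, which is an assumption about code not fully visible in this hunk.

object ReconcileRequestsSketch {
  final case class Plan(toCancel: Int, toRequest: Int)

  def reconcile(
      target: Int,
      running: Int,
      pendingMatched: Int,
      pendingUnmatched: Int,
      pendingFree: Int): Plan = {
    val pending = pendingMatched + pendingUnmatched + pendingFree
    val missing = target - pending - running
    if (missing > 0) {
      // Cancel the stale requests and re-request them together with the shortfall.
      Plan(toCancel = pendingUnmatched + pendingFree,
           toRequest = missing + pendingUnmatched + pendingFree)
    } else if (missing < 0) {
      // Too many outstanding requests: cancel up to the surplus.
      Plan(toCancel = math.min(-missing, pending), toRequest = 0)
    } else {
      Plan(0, 0)
    }
  }

  def main(args: Array[String]): Unit = {
    // target 10, 4 running, 3 pending (1 matched, 1 unmatched, 1 locality-free):
    // cancel 2 stale requests and issue 3 + 2 = 5 new ones.
    println(reconcile(target = 10, running = 4,
      pendingMatched = 1, pendingUnmatched = 1, pendingFree = 1))
  }
}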
*/ - protected def createContainerRequest( + private def createContainerRequest( resource: Resource, nodes: Array[String], racks: Array[String]): ContainerRequest = { @@ -423,39 +442,63 @@ private[yarn] class YarnAllocator( for (completedContainer <- completedContainers) { val containerId = completedContainer.getContainerId val alreadyReleased = releasedContainers.remove(containerId) - if (!alreadyReleased) { + val hostOpt = allocatedContainerToHostMap.get(containerId) + val onHostStr = hostOpt.map(host => s" on host: $host").getOrElse("") + val exitReason = if (!alreadyReleased) { // Decrement the number of executors running. The next iteration of // the ApplicationMaster's reporting thread will take care of allocating. numExecutorsRunning -= 1 - logInfo("Completed container %s (state: %s, exit status: %s)".format( + logInfo("Completed container %s%s (state: %s, exit status: %s)".format( containerId, + onHostStr, completedContainer.getState, completedContainer.getExitStatus)) // Hadoop 2.2.X added a ContainerExitStatus we should switch to use // there are some exit status' we shouldn't necessarily count against us, but for - // now I think its ok as none of the containers are expected to exit - if (completedContainer.getExitStatus == ContainerExitStatus.PREEMPTED) { - logInfo("Container preempted: " + containerId) - } else if (completedContainer.getExitStatus == -103) { // vmem limit exceeded - logWarning(memLimitExceededLogMessage( - completedContainer.getDiagnostics, - VMEM_EXCEEDED_PATTERN)) - } else if (completedContainer.getExitStatus == -104) { // pmem limit exceeded - logWarning(memLimitExceededLogMessage( - completedContainer.getDiagnostics, - PMEM_EXCEEDED_PATTERN)) - } else if (completedContainer.getExitStatus != 0) { - logInfo("Container marked as failed: " + containerId + - ". Exit status: " + completedContainer.getExitStatus + - ". Diagnostics: " + completedContainer.getDiagnostics) - numExecutorsFailed += 1 + // now I think its ok as none of the containers are expected to exit. + val exitStatus = completedContainer.getExitStatus + val (exitCausedByApp, containerExitReason) = exitStatus match { + case ContainerExitStatus.SUCCESS => + (false, s"Executor for container $containerId exited because of a YARN event (e.g., " + + "pre-emption) and not because of an error in the running job.") + case ContainerExitStatus.PREEMPTED => + // Preemption is not the fault of the running tasks, since YARN preempts containers + // merely to do resource sharing, and tasks that fail due to preempted executors could + // just as easily finish on any other executor. See SPARK-8167. + (false, s"Container ${containerId}${onHostStr} was preempted.") + // Should probably still count memory exceeded exit codes towards task failures + case VMEM_EXCEEDED_EXIT_CODE => + (true, memLimitExceededLogMessage( + completedContainer.getDiagnostics, + VMEM_EXCEEDED_PATTERN)) + case PMEM_EXCEEDED_EXIT_CODE => + (true, memLimitExceededLogMessage( + completedContainer.getDiagnostics, + PMEM_EXCEEDED_PATTERN)) + case _ => + numExecutorsFailed += 1 + (true, "Container marked as failed: " + containerId + onHostStr + + ". Exit status: " + completedContainer.getExitStatus + + ". 
Diagnostics: " + completedContainer.getDiagnostics) + } + if (exitCausedByApp) { + logWarning(containerExitReason) + } else { + logInfo(containerExitReason) + } + ExecutorExited(exitStatus, exitCausedByApp, containerExitReason) + } else { + // If we have already released this container, then it must mean + // that the driver has explicitly requested it to be killed + ExecutorExited(completedContainer.getExitStatus, exitCausedByApp = false, + s"Container $containerId exited from explicit termination request.") } - if (allocatedContainerToHostMap.containsKey(containerId)) { - val host = allocatedContainerToHostMap.get(containerId).get - val containerSet = allocatedHostToContainersMap.get(host).get - + for { + host <- hostOpt + containerSet <- allocatedHostToContainersMap.get(host) + } { containerSet.remove(containerId) if (containerSet.isEmpty) { allocatedHostToContainersMap.remove(host) @@ -468,18 +511,50 @@ private[yarn] class YarnAllocator( containerIdToExecutorId.remove(containerId).foreach { eid => executorIdToContainer.remove(eid) - + pendingLossReasonRequests.remove(eid) match { + case Some(pendingRequests) => + // Notify application of executor loss reasons so it can decide whether it should abort + pendingRequests.foreach(_.reply(exitReason)) + + case None => + // We cannot find executor for pending reasons. This is because completed container + // is processed before querying pending result. We should store it for later query. + // This is usually happened when explicitly killing a container, the result will be + // returned in one AM-RM communication. So query RPC will be later than this completed + // container process. + releasedExecutorLossReasons.put(eid, exitReason) + } if (!alreadyReleased) { // The executor could have gone away (like no route to host, node failure, etc) // Notify backend about the failure of the executor numUnexpectedContainerRelease += 1 - driverRef.send(RemoveExecutor(eid, - s"Yarn deallocated the executor $eid (container $containerId)")) + driverRef.send(RemoveExecutor(eid, exitReason)) } } } } + /** + * Register that some RpcCallContext has asked the AM why the executor was lost. Note that + * we can only find the loss reason to send back in the next call to allocateResources(). + */ + private[yarn] def enqueueGetLossReasonRequest( + eid: String, + context: RpcCallContext): Unit = synchronized { + if (executorIdToContainer.contains(eid)) { + pendingLossReasonRequests + .getOrElseUpdate(eid, new ArrayBuffer[RpcCallContext]) += context + } else if (releasedExecutorLossReasons.contains(eid)) { + // Executor is already released explicitly before getting the loss reason, so directly send + // the pre-stored lost reason + context.reply(releasedExecutorLossReasons.remove(eid).get) + } else { + logWarning(s"Tried to get the loss reason for non-existent executor $eid") + context.sendFailure( + new SparkException(s"Fail to find loss reason for non-existent executor $eid")) + } + } + private def internalReleaseContainer(container: Container): Unit = { releasedContainers.add(container.getId()) amClient.releaseAssignedContainer(container.getId()) @@ -487,6 +562,43 @@ private[yarn] class YarnAllocator( private[yarn] def getNumUnexpectedContainerRelease = numUnexpectedContainerRelease + private[yarn] def getNumPendingLossReasonRequests: Int = synchronized { + pendingLossReasonRequests.size + } + + /** + * Split the pending container requests into 3 groups based on current localities of pending + * tasks. 
+ * @param hostToLocalTaskCount a map of preferred hostname to possible task counts to be used as + * container placement hint. + * @param pendingAllocations A sequence of pending allocation container request. + * @return A tuple of 3 sequences, first is a sequence of locality matched container + * requests, second is a sequence of locality unmatched container requests, and third is a + * sequence of locality free container requests. + */ + private def splitPendingAllocationsByLocality( + hostToLocalTaskCount: Map[String, Int], + pendingAllocations: Seq[ContainerRequest] + ): (Seq[ContainerRequest], Seq[ContainerRequest], Seq[ContainerRequest]) = { + val localityMatched = ArrayBuffer[ContainerRequest]() + val localityUnMatched = ArrayBuffer[ContainerRequest]() + val localityFree = ArrayBuffer[ContainerRequest]() + + val preferredHosts = hostToLocalTaskCount.keySet + pendingAllocations.foreach { cr => + val nodes = cr.getNodes + if (nodes == null) { + localityFree += cr + } else if (nodes.asScala.toSet.intersect(preferredHosts).nonEmpty) { + localityMatched += cr + } else { + localityUnMatched += cr + } + } + + (localityMatched.toSeq, localityUnMatched.toSeq, localityFree.toSeq) + } + } private object YarnAllocator { @@ -495,6 +607,8 @@ private object YarnAllocator { Pattern.compile(s"$MEM_REGEX of $MEM_REGEX physical memory used") val VMEM_EXCEEDED_PATTERN = Pattern.compile(s"$MEM_REGEX of $MEM_REGEX virtual memory used") + val VMEM_EXCEEDED_EXIT_CODE = -103 + val PMEM_EXCEEDED_EXIT_CODE = -104 def memLimitExceededLogMessage(diagnostics: String, pattern: Pattern): String = { val matcher = pattern.matcher(diagnostics) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index 4999f9c06210..d2a211f6711f 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -19,17 +19,15 @@ package org.apache.spark.deploy.yarn import java.util.{List => JList} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.{Map, Set} import scala.util.Try import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.hadoop.yarn.util.ConverterUtils import org.apache.hadoop.yarn.webapp.util.WebAppUtils import org.apache.spark.{Logging, SecurityManager, SparkConf} @@ -51,7 +49,6 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg * * @param conf The Yarn configuration. * @param sparkConf The Spark configuration. - * @param preferredNodeLocations Map with hints about where to allocate containers. * @param uiAddress Address of the SparkUI. * @param uiHistoryAddress Address of the application on the History Server. 
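A plain-Scala sketch of the three-way split performed by splitPendingAllocationsByLocality: requests with no node list are "locality free", requests whose nodes intersect the currently preferred hosts are "matched", and everything else is "unmatched" (and gets cancelled and re-requested). The Request case class is a stand-in for YARN's ContainerRequest.

object SplitPendingSketch {
  final case class Request(id: Int, nodes: Option[Seq[String]])

  def split(preferredHosts: Set[String], pending: Seq[Request])
      : (Seq[Request], Seq[Request], Seq[Request]) = {
    val (withNodes, localityFree) = pending.partition(_.nodes.isDefined)
    val (matched, unmatched) =
      withNodes.partition(_.nodes.get.toSet.intersect(preferredHosts).nonEmpty)
    (matched, unmatched, localityFree)
  }

  def main(args: Array[String]): Unit = {
    val pending = Seq(
      Request(1, Some(Seq("host1", "host2"))),
      Request(2, Some(Seq("host9"))),
      Request(3, None))
    val (matched, unmatched, free) = split(Set("host1", "host3"), pending)
    println(s"matched=$matched unmatched=$unmatched free=$free")
  }
}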
*/ @@ -60,7 +57,6 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg driverRef: RpcEndpointRef, conf: YarnConfiguration, sparkConf: SparkConf, - preferredNodeLocations: Map[String, Set[SplitInfo]], uiAddress: String, uiHistoryAddress: String, securityMgr: SecurityManager @@ -108,8 +104,8 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg val method = classOf[WebAppUtils].getMethod("getProxyHostsAndPortsForAmFilter", classOf[Configuration]) val proxies = method.invoke(null, conf).asInstanceOf[JList[String]] - val hosts = proxies.map { proxy => proxy.split(":")(0) } - val uriBases = proxies.map { proxy => prefix + proxy + proxyBase } + val hosts = proxies.asScala.map { proxy => proxy.split(":")(0) } + val uriBases = proxies.asScala.map { proxy => prefix + proxy + proxyBase } Map("PROXY_HOSTS" -> hosts.mkString(","), "PROXY_URI_BASES" -> uriBases.mkString(",")) } catch { case e: NoSuchMethodException => diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 68d01c17ef72..36a2d6142988 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -18,18 +18,22 @@ package org.apache.spark.deploy.yarn import java.io.File +import java.nio.charset.StandardCharsets.UTF_8 import java.util.regex.Matcher import java.util.regex.Pattern import scala.collection.mutable.HashMap +import scala.reflect.runtime._ import scala.util.Try import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier import org.apache.hadoop.io.Text import org.apache.hadoop.mapred.{Master, JobConf} import org.apache.hadoop.security.Credentials import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.security.token.{Token, TokenIdentifier} import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.ApplicationConstants.Environment @@ -37,6 +41,7 @@ import org.apache.hadoop.yarn.api.records.{ApplicationAccessType, ContainerId, P import org.apache.hadoop.yarn.util.ConverterUtils import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.launcher.YarnCommandBuilderUtils import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.util.Utils @@ -77,7 +82,7 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { override def addSecretKeyToUserCredentials(key: String, secret: String) { val creds = new Credentials() - creds.addSecretKey(new Text(key), secret.getBytes("utf-8")) + creds.addSecretKey(new Text(key), secret.getBytes(UTF_8)) addCurrentUserCredentials(creds) } @@ -141,6 +146,125 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { val containerIdString = System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name()) ConverterUtils.toContainerId(containerIdString) } + + /** + * Obtains token for the Hive metastore, using the current user as the principal. + * Some exceptions are caught and downgraded to a log message. + * @param conf hadoop configuration; the Hive configuration will be based on this + * @return a token, or `None` if there's no need for a token (no metastore URI or principal + * in the config), or if a binding exception was caught and downgraded. 
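The doc comment above describes the "optional dependency" pattern this patch applies to the Hive and HBase token helpers: load the class reflectively and downgrade ClassNotFoundException to None so clusters without Hive or HBase on the classpath keep working. A generic sketch of that pattern (Class.forName stands in for the runtime-mirror class loading; the class and token names in main are illustrative):

object OptionalTokenSketch {
  def obtainTokenIfPresent(className: String)(fetch: Class[_] => String): Option[String] =
    try {
      val cls = Class.forName(className)
      Some(fetch(cls))
    } catch {
      case e: ClassNotFoundException =>
        println(s"$className not found, skipping token: $e")
        None
    }

  def main(args: Array[String]): Unit = {
    println(obtainTokenIfPresent("java.lang.String")(_ => "fake-token"))   // Some(fake-token)
    println(obtainTokenIfPresent("org.example.NotOnClasspath")(_ => ""))   // None
  }
}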
+ */ + def obtainTokenForHiveMetastore(conf: Configuration): Option[Token[DelegationTokenIdentifier]] = { + try { + obtainTokenForHiveMetastoreInner(conf, UserGroupInformation.getCurrentUser().getUserName) + } catch { + case e: ClassNotFoundException => + logInfo(s"Hive class not found $e") + logDebug("Hive class not found", e) + None + } + } + + /** + * Inner routine to obtains token for the Hive metastore; exceptions are raised on any problem. + * @param conf hadoop configuration; the Hive configuration will be based on this. + * @param username the username of the principal requesting the delegating token. + * @return a delegation token + */ + private[yarn] def obtainTokenForHiveMetastoreInner(conf: Configuration, + username: String): Option[Token[DelegationTokenIdentifier]] = { + val mirror = universe.runtimeMirror(Utils.getContextOrSparkClassLoader) + + // the hive configuration class is a subclass of Hadoop Configuration, so can be cast down + // to a Configuration and used without reflection + val hiveConfClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.conf.HiveConf") + // using the (Configuration, Class) constructor allows the current configuratin to be included + // in the hive config. + val ctor = hiveConfClass.getDeclaredConstructor(classOf[Configuration], + classOf[Object].getClass) + val hiveConf = ctor.newInstance(conf, hiveConfClass).asInstanceOf[Configuration] + val metastoreUri = hiveConf.getTrimmed("hive.metastore.uris", "") + + // Check for local metastore + if (metastoreUri.nonEmpty) { + require(username.nonEmpty, "Username undefined") + val principalKey = "hive.metastore.kerberos.principal" + val principal = hiveConf.getTrimmed(principalKey, "") + require(principal.nonEmpty, "Hive principal $principalKey undefined") + logDebug(s"Getting Hive delegation token for $username against $principal at $metastoreUri") + val hiveClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.ql.metadata.Hive") + val closeCurrent = hiveClass.getMethod("closeCurrent") + try { + // get all the instance methods before invoking any + val getDelegationToken = hiveClass.getMethod("getDelegationToken", + classOf[String], classOf[String]) + val getHive = hiveClass.getMethod("get", hiveConfClass) + + // invoke + val hive = getHive.invoke(null, hiveConf) + val tokenStr = getDelegationToken.invoke(hive, username, principal).asInstanceOf[String] + val hive2Token = new Token[DelegationTokenIdentifier]() + hive2Token.decodeFromUrlString(tokenStr) + Some(hive2Token) + } finally { + Utils.tryLogNonFatalError { + closeCurrent.invoke(null) + } + } + } else { + logDebug("HiveMetaStore configured in localmode") + None + } + } + + /** + * Obtain a security token for HBase. + * + * Requirements + * + * 1. `"hbase.security.authentication" == "kerberos"` + * 2. The HBase classes `HBaseConfiguration` and `TokenUtil` could be loaded + * and invoked. + * + * @param conf Hadoop configuration; an HBase configuration is created + * from this. + * @return a token if the requirements were met, `None` if not. + */ + def obtainTokenForHBase(conf: Configuration): Option[Token[TokenIdentifier]] = { + try { + obtainTokenForHBaseInner(conf) + } catch { + case e: ClassNotFoundException => + logInfo(s"HBase class not found $e") + logDebug("HBase class not found", e) + None + } + } + + /** + * Obtain a security token for HBase if `"hbase.security.authentication" == "kerberos"` + * + * @param conf Hadoop configuration; an HBase configuration is created + * from this. 
+ * @return a token if one was needed + */ + def obtainTokenForHBaseInner(conf: Configuration): Option[Token[TokenIdentifier]] = { + val mirror = universe.runtimeMirror(getClass.getClassLoader) + val confCreate = mirror.classLoader. + loadClass("org.apache.hadoop.hbase.HBaseConfiguration"). + getMethod("create", classOf[Configuration]) + val obtainToken = mirror.classLoader. + loadClass("org.apache.hadoop.hbase.security.token.TokenUtil"). + getMethod("obtainToken", classOf[Configuration]) + val hbaseConf = confCreate.invoke(null, conf).asInstanceOf[Configuration] + if ("kerberos" == hbaseConf.get("hbase.security.authentication")) { + logDebug("Attempting to fetch HBase security token.") + Some(obtainToken.invoke(null, hbaseConf).asInstanceOf[Token[TokenIdentifier]]) + } else { + None + } + } + } object YarnSparkHadoopUtil { @@ -219,26 +343,61 @@ object YarnSparkHadoopUtil { } } + /** + * The handler if an OOM Exception is thrown by the JVM must be configured on Windows + * differently: the 'taskkill' command should be used, whereas Unix-based systems use 'kill'. + * + * As the JVM interprets both %p and %%p as the same, we can use either of them. However, + * some tests on Windows computers suggest, that the JVM only accepts '%%p'. + * + * Furthermore, the behavior of the character '%' on the Windows command line differs from + * the behavior of '%' in a .cmd file: it gets interpreted as an incomplete environment + * variable. Windows .cmd files escape a '%' by '%%'. Thus, the correct way of writing + * '%%p' in an escaped way is '%%%%p'. + * + * @return The correct OOM Error handler JVM option, platform dependent. + */ + def getOutOfMemoryErrorArgument : String = { + if (Utils.isWindows) { + escapeForShell("-XX:OnOutOfMemoryError=taskkill /F /PID %%%%p") + } else { + "-XX:OnOutOfMemoryError='kill %p'" + } + } + /** * Escapes a string for inclusion in a command line executed by Yarn. Yarn executes commands - * using `bash -c "command arg1 arg2"` and that means plain quoting doesn't really work. The - * argument is enclosed in single quotes and some key characters are escaped. + * using either + * + * (Unix-based) `bash -c "command arg1 arg2"` and that means plain quoting doesn't really work. + * The argument is enclosed in single quotes and some key characters are escaped. + * + * (Windows-based) part of a .cmd file in which case windows escaping for each argument must be + * applied. Windows is quite lenient, however it is usually Java that causes trouble, needing to + * distinguish between arguments starting with '-' and class names. If arguments are surrounded + * by ' java takes the following string as is, hence an argument is mistakenly taken as a class + * name which happens to start with a '-'. The way to avoid this, is to surround nothing with + * a ', but instead with a ". * * @param arg A single argument. * @return Argument quoted for execution via Yarn's generated shell script. 
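A sketch of the Unix branch of escapeForShell as documented above: the argument is wrapped in single quotes and the characters that would break out of them are escaped. The Windows branch delegates to the launcher's quoteForBatchScript and is not reproduced here.

object ShellEscapeSketch {
  def escapeForShell(arg: String): String =
    if (arg == null) arg
    else {
      val escaped = new StringBuilder("'")
      arg.foreach {
        case '$'  => escaped.append("\\$")
        case '"'  => escaped.append("\\\"")
        case '\'' => escaped.append("'\\''")
        case c    => escaped.append(c)
      }
      escaped.append("'").toString()
    }

  def main(args: Array[String]): Unit = {
    println(escapeForShell("""-Dspark.app.name=my "quoted" app"""))
    println(escapeForShell("-XX:OnOutOfMemoryError='kill %p'"))
  }
}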
*/ def escapeForShell(arg: String): String = { if (arg != null) { - val escaped = new StringBuilder("'") - for (i <- 0 to arg.length() - 1) { - arg.charAt(i) match { - case '$' => escaped.append("\\$") - case '"' => escaped.append("\\\"") - case '\'' => escaped.append("'\\''") - case c => escaped.append(c) + if (Utils.isWindows) { + YarnCommandBuilderUtils.quoteForBatchScript(arg) + } else { + val escaped = new StringBuilder("'") + for (i <- 0 to arg.length() - 1) { + arg.charAt(i) match { + case '$' => escaped.append("\\$") + case '"' => escaped.append("\\\"") + case '\'' => escaped.append("'\\''") + case c => escaped.append(c) + } } + escaped.append("'").toString() } - escaped.append("'").toString() } else { arg } @@ -278,5 +437,31 @@ object YarnSparkHadoopUtil { def getClassPathSeparator(): String = { classPathSeparatorField.get(null).asInstanceOf[String] } + + /** + * Getting the initial target number of executors depends on whether dynamic allocation is + * enabled. + * If not using dynamic allocation it gets the number of executors reqeusted by the user. + */ + def getInitialTargetExecutorNumber( + conf: SparkConf, + numExecutors: Int = DEFAULT_NUMBER_EXECUTORS): Int = { + if (Utils.isDynamicAllocationEnabled(conf)) { + val minNumExecutors = conf.getInt("spark.dynamicAllocation.minExecutors", 0) + val initialNumExecutors = + conf.getInt("spark.dynamicAllocation.initialExecutors", minNumExecutors) + val maxNumExecutors = conf.getInt("spark.dynamicAllocation.maxExecutors", Int.MaxValue) + require(initialNumExecutors >= minNumExecutors && initialNumExecutors <= maxNumExecutors, + s"initial executor number $initialNumExecutors must between min executor number" + + s"$minNumExecutors and max executor number $maxNumExecutors") + + initialNumExecutors + } else { + val targetNumExecutors = + sys.env.get("SPARK_EXECUTOR_INSTANCES").map(_.toInt).getOrElse(numExecutors) + // System property can override environment variable. + conf.getInt("spark.executor.instances", targetNumExecutors) + } + } } diff --git a/yarn/src/main/scala/org/apache/spark/launcher/YarnCommandBuilderUtils.scala b/yarn/src/main/scala/org/apache/spark/launcher/YarnCommandBuilderUtils.scala new file mode 100644 index 000000000000..7d246bf40712 --- /dev/null +++ b/yarn/src/main/scala/org/apache/spark/launcher/YarnCommandBuilderUtils.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ListBuffer + +/** + * Exposes methods from the launcher library that are used by the YARN backend. 
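A sketch of the executor-count resolution consolidated into getInitialTargetExecutorNumber, with plain Maps standing in for SparkConf and the process environment. With dynamic allocation enabled, the initial count comes from the dynamicAllocation settings (validated against min and max); otherwise spark.executor.instances overrides SPARK_EXECUTOR_INSTANCES, which overrides the default. The default of 2 is assumed to match DEFAULT_NUMBER_EXECUTORS.

object InitialExecutorsSketch {
  def initialTargetExecutors(
      conf: Map[String, String],
      env: Map[String, String],
      default: Int = 2): Int = {
    val dynamic = conf.get("spark.dynamicAllocation.enabled").exists(_.toBoolean)
    if (dynamic) {
      val min = conf.get("spark.dynamicAllocation.minExecutors").map(_.toInt).getOrElse(0)
      val initial =
        conf.get("spark.dynamicAllocation.initialExecutors").map(_.toInt).getOrElse(min)
      val max = conf.get("spark.dynamicAllocation.maxExecutors").map(_.toInt).getOrElse(Int.MaxValue)
      require(initial >= min && initial <= max,
        s"initial executor number $initial must be between $min and $max")
      initial
    } else {
      val fromEnv = env.get("SPARK_EXECUTOR_INSTANCES").map(_.toInt).getOrElse(default)
      conf.get("spark.executor.instances").map(_.toInt).getOrElse(fromEnv)  // conf wins over env
    }
  }

  def main(args: Array[String]): Unit = {
    println(initialTargetExecutors(
      Map("spark.dynamicAllocation.enabled" -> "true",
          "spark.dynamicAllocation.minExecutors" -> "2",
          "spark.dynamicAllocation.initialExecutors" -> "5"),
      Map.empty))                                              // 5
    println(initialTargetExecutors(
      Map("spark.executor.instances" -> "8"),
      Map("SPARK_EXECUTOR_INSTANCES" -> "4")))                 // 8
  }
}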
+ */ +private[spark] object YarnCommandBuilderUtils { + + def quoteForBatchScript(arg: String): String = { + CommandBuilderUtils.quoteForBatchScript(arg) + } + + /** + * Adds the perm gen configuration to the list of java options if needed and not yet added. + * + * Note that this method adds the option based on the local JVM version; if the node where + * the container is running has a different Java version, there's a risk that the option will + * not be added (e.g. if the AM is running Java 8 but the container's node is set up to use + * Java 7). + */ + def addPermGenSizeOpt(args: ListBuffer[String]): Unit = { + CommandBuilderUtils.addPermGenSizeOpt(args.asJava) + } + +} diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala new file mode 100644 index 000000000000..c06452184539 --- /dev/null +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import java.util.concurrent.atomic.AtomicBoolean + +import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId} + +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.util.Utils + +/** + * An extension service that can be loaded into a Spark YARN scheduler. + * A Service that can be started and stopped. + * + * 1. For implementations to be loadable by `SchedulerExtensionServices`, + * they must provide an empty constructor. + * 2. The `stop()` operation MUST be idempotent, and succeed even if `start()` was + * never invoked. + */ +trait SchedulerExtensionService { + + /** + * Start the extension service. This should be a no-op if + * called more than once. + * @param binding binding to the spark application and YARN + */ + def start(binding: SchedulerExtensionServiceBinding): Unit + + /** + * Stop the service + * The `stop()` operation MUST be idempotent, and succeed even if `start()` was + * never invoked. + */ + def stop(): Unit +} + +/** + * Binding information for a [[SchedulerExtensionService]]. + * + * The attempt ID will be set if the service is started within a YARN application master; + * there is then a different attempt ID for every time that AM is restarted. + * When the service binding is instantiated in client mode, there's no attempt ID, as it lacks + * this information. + * @param sparkContext current spark context + * @param applicationId YARN application ID + * @param attemptId YARN attemptID. This will always be unset in client mode, and always set in + * cluster mode. 
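The SchedulerExtensionService contract above requires start() to be a no-op after the first call and stop() to succeed even if start() never ran. A minimal sketch of that idempotency discipline using an AtomicBoolean, the same primitive SchedulerExtensionServices uses for its started flag:

import java.util.concurrent.atomic.AtomicBoolean

class IdempotentService {
  private val started = new AtomicBoolean(false)

  def start(): Unit =
    if (started.compareAndSet(false, true)) println("started")
    else println("ignoring re-entrant start")

  def stop(): Unit =
    if (started.compareAndSet(true, false)) println("stopped")
    else println("stop before start: no-op")
}

object IdempotentServiceDemo {
  def main(args: Array[String]): Unit = {
    val s = new IdempotentService
    s.stop()    // safe even though start() never ran
    s.start()
    s.start()   // ignored
    s.stop()
  }
}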
+ */ +case class SchedulerExtensionServiceBinding( + sparkContext: SparkContext, + applicationId: ApplicationId, + attemptId: Option[ApplicationAttemptId] = None) + +/** + * Container for [[SchedulerExtensionService]] instances. + * + * Loads Extension Services from the configuration property + * `"spark.yarn.services"`, instantiates and starts them. + * When stopped, it stops all child entries. + * + * The order in which child extension services are started and stopped + * is undefined. + */ +private[spark] class SchedulerExtensionServices extends SchedulerExtensionService + with Logging { + private var serviceOption: Option[String] = None + private var services: List[SchedulerExtensionService] = Nil + private val started = new AtomicBoolean(false) + private var binding: SchedulerExtensionServiceBinding = _ + + /** + * Binding operation will load the named services and call bind on them too; the + * entire set of services are then ready for `init()` and `start()` calls. + * + * @param binding binding to the spark application and YARN + */ + def start(binding: SchedulerExtensionServiceBinding): Unit = { + if (started.getAndSet(true)) { + logWarning("Ignoring re-entrant start operation") + return + } + require(binding.sparkContext != null, "Null context parameter") + require(binding.applicationId != null, "Null appId parameter") + this.binding = binding + val sparkContext = binding.sparkContext + val appId = binding.applicationId + val attemptId = binding.attemptId + logInfo(s"Starting Yarn extension services with app $appId and attemptId $attemptId") + + serviceOption = sparkContext.getConf.getOption(SchedulerExtensionServices.SPARK_YARN_SERVICES) + services = serviceOption + .map { s => + s.split(",").map(_.trim()).filter(!_.isEmpty) + .map { sClass => + val instance = Utils.classForName(sClass) + .newInstance() + .asInstanceOf[SchedulerExtensionService] + // bind this service + instance.start(binding) + logInfo(s"Service $sClass started") + instance + }.toList + }.getOrElse(Nil) + } + + /** + * Get the list of services. + * + * @return a list of services; Nil until the service is started + */ + def getServices: List[SchedulerExtensionService] = services + + /** + * Stop the services; idempotent. 
+ * + */ + override def stop(): Unit = { + if (started.getAndSet(false)) { + logInfo(s"Stopping $this") + services.foreach { s => + Utils.tryLogNonFatalError(s.stop()) + } + } + } + + override def toString(): String = s"""SchedulerExtensionServices + |(serviceOption=$serviceOption, + | services=$services, + | started=$started)""".stripMargin +} + +private[spark] object SchedulerExtensionServices { + + /** + * A list of comma separated services to instantiate in the scheduler + */ + val SPARK_YARN_SERVICES = "spark.yarn.services" +} diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index d97fa2e2151b..0e27a2665e93 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -19,10 +19,11 @@ package org.apache.spark.scheduler.cluster import scala.collection.mutable.ArrayBuffer -import org.apache.hadoop.yarn.api.records.{ApplicationId, YarnApplicationState} +import org.apache.hadoop.yarn.api.records.YarnApplicationState import org.apache.spark.{SparkException, Logging, SparkContext} import org.apache.spark.deploy.yarn.{Client, ClientArguments, YarnSparkHadoopUtil} +import org.apache.spark.launcher.SparkAppHandle import org.apache.spark.scheduler.TaskSchedulerImpl private[spark] class YarnClientSchedulerBackend( @@ -32,8 +33,7 @@ private[spark] class YarnClientSchedulerBackend( with Logging { private var client: Client = null - private var appId: ApplicationId = null - private var monitorThread: Thread = null + private var monitorThread: MonitorThread = null /** * Create a Yarn client to submit an application to the ResourceManager. @@ -53,13 +53,12 @@ private[spark] class YarnClientSchedulerBackend( val args = new ClientArguments(argsArrayBuf.toArray, conf) totalExpectedExecutors = args.numExecutors client = new Client(args, conf) - appId = client.submitApplication() + bindToYarn(client.submitApplication(), None) // SPARK-8687: Ensure all necessary properties have already been set before // we initialize our driver scheduler backend, which serves these properties // to the executors super.start() - waitForApplication() // SPARK-8851: In yarn-client mode, the AM still does the credentials refresh. 
The driver @@ -81,8 +80,6 @@ private[spark] class YarnClientSchedulerBackend( // List of (target Client argument, environment variable, Spark property) val optionTuples = List( - ("--num-executors", "SPARK_WORKER_INSTANCES", "spark.executor.instances"), - ("--num-executors", "SPARK_EXECUTOR_INSTANCES", "spark.executor.instances"), ("--executor-memory", "SPARK_WORKER_MEMORY", "spark.executor.memory"), ("--executor-memory", "SPARK_EXECUTOR_MEMORY", "spark.executor.memory"), ("--executor-cores", "SPARK_WORKER_CORES", "spark.executor.cores"), @@ -92,7 +89,6 @@ private[spark] class YarnClientSchedulerBackend( ) // Warn against the following deprecated environment variables: env var -> suggestion val deprecatedEnvVars = Map( - "SPARK_WORKER_INSTANCES" -> "SPARK_WORKER_INSTANCES or --num-executors through spark-submit", "SPARK_WORKER_MEMORY" -> "SPARK_EXECUTOR_MEMORY or --executor-memory through spark-submit", "SPARK_WORKER_CORES" -> "SPARK_EXECUTOR_CORES or --executor-cores through spark-submit") optionTuples.foreach { case (optionName, envVar, sparkProp) => @@ -118,8 +114,8 @@ private[spark] class YarnClientSchedulerBackend( * This assumes both `client` and `appId` have already been set. */ private def waitForApplication(): Unit = { - assert(client != null && appId != null, "Application has not been submitted yet!") - val (state, _) = client.monitorApplication(appId, returnOnRunning = true) // blocking + assert(client != null && appId.isDefined, "Application has not been submitted yet!") + val (state, _) = client.monitorApplication(appId.get, returnOnRunning = true) // blocking if (state == YarnApplicationState.FINISHED || state == YarnApplicationState.FAILED || state == YarnApplicationState.KILLED) { @@ -127,7 +123,35 @@ private[spark] class YarnClientSchedulerBackend( "It might have been killed or unable to launch application master.") } if (state == YarnApplicationState.RUNNING) { - logInfo(s"Application $appId has started running.") + logInfo(s"Application ${appId.get} has started running.") + } + } + + /** + * We create this class for SPARK-9519. Basically when we interrupt the monitor thread it's + * because the SparkContext is being shut down(sc.stop() called by user code), but if + * monitorApplication return, it means the Yarn application finished before sc.stop() was called, + * which means we should call sc.stop() here, and we don't allow the monitor to be interrupted + * before SparkContext stops successfully. + */ + private class MonitorThread extends Thread { + private var allowInterrupt = true + + override def run() { + try { + val (state, _) = client.monitorApplication(appId.get, logApplicationReport = false) + logError(s"Yarn application has already exited with state $state!") + allowInterrupt = false + sc.stop() + } catch { + case e: InterruptedException => logInfo("Interrupting monitor thread") + } + } + + def stopMonitor(): Unit = { + if (allowInterrupt) { + this.interrupt() + } } } @@ -136,19 +160,9 @@ private[spark] class YarnClientSchedulerBackend( * If the application has exited for any reason, stop the SparkContext. * This assumes both `client` and `appId` have already been set. 
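The MonitorThread above addresses the SPARK-9519 race: once the monitor has seen the application exit and is stopping the SparkContext itself, the interrupt issued by the normal sc.stop() path must be suppressed so the two shutdown paths cannot collide. The same idea stated generically (a sketch, not Spark code; the @volatile on the flag is an assumption added here for cross-thread visibility):

class ShutdownMonitor(waitForExit: () => Unit, shutdown: () => Unit) extends Thread {
  @volatile private var allowInterrupt = true

  override def run(): Unit = {
    try {
      waitForExit()           // blocks until the monitored application exits
      allowInterrupt = false  // from here on this thread owns the shutdown
      shutdown()
    } catch {
      case _: InterruptedException => // the other shutdown path is already running
    }
  }

  // Called by the normal shutdown path; a no-op once the monitor has taken over.
  def stopMonitor(): Unit = if (allowInterrupt) interrupt()
}

In the patch, stopMonitor() takes the place of the earlier monitorThread.interrupt() call inside stop().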
*/ - private def asyncMonitorApplication(): Thread = { - assert(client != null && appId != null, "Application has not been submitted yet!") - val t = new Thread { - override def run() { - try { - val (state, _) = client.monitorApplication(appId, logApplicationReport = false) - logError(s"Yarn application has already exited with state $state!") - sc.stop() - } catch { - case e: InterruptedException => logInfo("Interrupting monitor thread") - } - } - } + private def asyncMonitorApplication(): MonitorThread = { + assert(client != null && appId.isDefined, "Application has not been submitted yet!") + val t = new MonitorThread t.setName("Yarn application state monitor") t.setDaemon(true) t @@ -160,19 +174,21 @@ private[spark] class YarnClientSchedulerBackend( override def stop() { assert(client != null, "Attempted to stop this scheduler before starting it!") if (monitorThread != null) { - monitorThread.interrupt() + monitorThread.stopMonitor() } + + // Report a final state to the launcher if one is connected. This is needed since in client + // mode this backend doesn't let the app monitor loop run to completion, so it does not report + // the final state itself. + // + // Note: there's not enough information at this point to provide a better final state, + // so assume the application was successful. + client.reportLauncherState(SparkAppHandle.State.FINISHED) + super.stop() - client.stop() YarnSparkHadoopUtil.get.stopExecutorDelegationTokenRenewer() + client.stop() logInfo("Stopped") } - override def applicationId(): String = { - Option(appId).map(_.toString).getOrElse { - logWarning("Application ID is not initialized yet.") - super.applicationId - } - } - } diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala index 1aed5a167507..ced597bed36d 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala @@ -17,21 +17,13 @@ package org.apache.spark.scheduler.cluster -import java.net.NetworkInterface - import org.apache.hadoop.yarn.api.ApplicationConstants.Environment - -import scala.collection.JavaConverters._ - -import org.apache.hadoop.yarn.api.records.NodeState -import org.apache.hadoop.yarn.client.api.YarnClient import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.spark.SparkContext -import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil -import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ +import org.apache.spark.deploy.yarn.{ApplicationMaster, YarnSparkHadoopUtil} import org.apache.spark.scheduler.TaskSchedulerImpl -import org.apache.spark.util.{IntParam, Utils} +import org.apache.spark.util.Utils private[spark] class YarnClusterSchedulerBackend( scheduler: TaskSchedulerImpl, @@ -39,32 +31,12 @@ private[spark] class YarnClusterSchedulerBackend( extends YarnSchedulerBackend(scheduler, sc) { override def start() { + val attemptId = ApplicationMaster.getAttemptId + bindToYarn(attemptId.getApplicationId(), Some(attemptId)) super.start() - totalExpectedExecutors = DEFAULT_NUMBER_EXECUTORS - if (System.getenv("SPARK_EXECUTOR_INSTANCES") != null) { - totalExpectedExecutors = IntParam.unapply(System.getenv("SPARK_EXECUTOR_INSTANCES")) - .getOrElse(totalExpectedExecutors) - } - // System property can override environment variable. 
- totalExpectedExecutors = sc.getConf.getInt("spark.executor.instances", totalExpectedExecutors) + totalExpectedExecutors = YarnSparkHadoopUtil.getInitialTargetExecutorNumber(sc.conf) } - override def applicationId(): String = - // In YARN Cluster mode, the application ID is expected to be set, so log an error if it's - // not found. - sc.getConf.getOption("spark.yarn.app.id").getOrElse { - logError("Application ID is not set.") - super.applicationId - } - - override def applicationAttemptId(): Option[String] = - // In YARN Cluster mode, the attempt ID is expected to be set, so log an error if it's - // not found. - sc.getConf.getOption("spark.yarn.app.attemptId").orElse { - logError("Application attempt ID is not set.") - super.applicationAttemptId - } - override def getDriverLogUrls: Option[Map[String, String]] = { var driverLogs: Option[Map[String, String]] = None try { diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala new file mode 100644 index 000000000000..1431bceb256a --- /dev/null +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -0,0 +1,307 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import scala.concurrent.{ExecutionContext, Future} +import scala.util.control.NonFatal + +import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId} + +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.rpc._ +import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ +import org.apache.spark.ui.JettyUtils +import org.apache.spark.util.{RpcUtils, ThreadUtils} + +/** + * Abstract Yarn scheduler backend that contains common logic + * between the client and cluster Yarn scheduler backends. + */ +private[spark] abstract class YarnSchedulerBackend( + scheduler: TaskSchedulerImpl, + sc: SparkContext) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) { + + if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { + minRegisteredRatio = 0.8 + } + + protected var totalExpectedExecutors = 0 + + private val yarnSchedulerEndpoint = new YarnSchedulerEndpoint(rpcEnv) + + private val yarnSchedulerEndpointRef = rpcEnv.setupEndpoint( + YarnSchedulerBackend.ENDPOINT_NAME, yarnSchedulerEndpoint) + + private implicit val askTimeout = RpcUtils.askRpcTimeout(sc.conf) + + /** Application ID. */ + protected var appId: Option[ApplicationId] = None + + /** Attempt ID. This is unset for client-mode schedulers */ + private var attemptId: Option[ApplicationAttemptId] = None + + /** Scheduler extension services. 
*/ + private val services: SchedulerExtensionServices = new SchedulerExtensionServices() + + // Flag to specify whether this schedulerBackend should be reset. + private var shouldResetOnAmRegister = false + + /** + * Bind to YARN. This *must* be done before calling [[start()]]. + * + * @param appId YARN application ID + * @param attemptId Optional YARN attempt ID + */ + protected def bindToYarn(appId: ApplicationId, attemptId: Option[ApplicationAttemptId]): Unit = { + this.appId = Some(appId) + this.attemptId = attemptId + } + + override def start() { + require(appId.isDefined, "application ID unset") + val binding = SchedulerExtensionServiceBinding(sc, appId.get, attemptId) + services.start(binding) + super.start() + } + + override def stop(): Unit = { + try { + super.stop() + } finally { + services.stop() + } + } + + /** + * Get the attempt ID for this run, if the cluster manager supports multiple + * attempts. Applications run in client mode will not have attempt IDs. + * + * @return The application attempt id, if available. + */ + override def applicationAttemptId(): Option[String] = { + attemptId.map(_.toString) + } + + /** + * Get an application ID associated with the job. + * This returns the string value of [[appId]] if set, otherwise + * the locally-generated ID from the superclass. + * @return The application ID + */ + override def applicationId(): String = { + appId.map(_.toString).getOrElse { + logWarning("Application ID is not initialized yet.") + super.applicationId + } + } + + /** + * Request executors from the ApplicationMaster by specifying the total number desired. + * This includes executors already pending or running. + */ + override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { + yarnSchedulerEndpointRef.askWithRetry[Boolean]( + RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount)) + } + + /** + * Request that the ApplicationMaster kill the specified executors. + */ + override def doKillExecutors(executorIds: Seq[String]): Boolean = { + yarnSchedulerEndpointRef.askWithRetry[Boolean](KillExecutors(executorIds)) + } + + override def sufficientResourcesRegistered(): Boolean = { + totalRegisteredExecutors.get() >= totalExpectedExecutors * minRegisteredRatio + } + + /** + * Add filters to the SparkUI. + */ + private def addWebUIFilter( + filterName: String, + filterParams: Map[String, String], + proxyBase: String): Unit = { + if (proxyBase != null && proxyBase.nonEmpty) { + System.setProperty("spark.ui.proxyBase", proxyBase) + } + + val hasFilter = + filterName != null && filterName.nonEmpty && + filterParams != null && filterParams.nonEmpty + if (hasFilter) { + logInfo(s"Add WebUI Filter. $filterName, $filterParams, $proxyBase") + conf.set("spark.ui.filters", filterName) + filterParams.foreach { case (k, v) => conf.set(s"spark.$filterName.param.$k", v) } + scheduler.sc.ui.foreach { ui => JettyUtils.addFilters(ui.getHandlers, conf) } + } + } + + override def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = { + new YarnDriverEndpoint(rpcEnv, properties) + } + + /** + * Reset the state of SchedulerBackend to the initial state. This is happened when AM is failed + * and re-registered itself to driver after a failure. The stale state in driver should be + * cleaned. + */ + override protected def reset(): Unit = { + super.reset() + sc.executorAllocationManager.foreach(_.reset()) + } + + /** + * Override the DriverEndpoint to add extra logic for the case when an executor is disconnected. 
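The two `do*Executors` overrides above are the YARN endpoints of the public executor-management API on SparkContext; a hedged usage sketch, assuming a live SparkContext `sc` running against a YARN master:

// Ask for two additional executors; on YARN this reaches doRequestTotalExecutors and
// is forwarded to the ApplicationMaster as a RequestExecutors message.
val granted: Boolean = sc.requestExecutors(2)

// Release two specific executors; on YARN this becomes a KillExecutors message to the AM.
val released: Boolean = sc.killExecutors(Seq("1", "2"))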
+ * This endpoint communicates with the executors and queries the AM for an executor's exit + * status when the executor is disconnected. + */ + private class YarnDriverEndpoint(rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) + extends DriverEndpoint(rpcEnv, sparkProperties) { + + /** + * When onDisconnected is received at the driver endpoint, the superclass DriverEndpoint + * handles it by assuming the Executor was lost for a bad reason and removes the executor + * immediately. + * + * In YARN's case however it is crucial to talk to the application master and ask why the + * executor had exited. If the executor exited for some reason unrelated to the running tasks + * (e.g., preemption), according to the application master, then we pass that information down + * to the TaskSetManager to inform the TaskSetManager that tasks on that lost executor should + * not count towards a job failure. + */ + override def onDisconnected(rpcAddress: RpcAddress): Unit = { + addressToExecutorId.get(rpcAddress).foreach { executorId => + if (disableExecutor(executorId)) { + yarnSchedulerEndpoint.handleExecutorDisconnectedFromDriver(executorId, rpcAddress) + } + } + } + } + + /** + * An [[RpcEndpoint]] that communicates with the ApplicationMaster. + */ + private class YarnSchedulerEndpoint(override val rpcEnv: RpcEnv) + extends ThreadSafeRpcEndpoint with Logging { + private var amEndpoint: Option[RpcEndpointRef] = None + + private val askAmThreadPool = + ThreadUtils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-thread-pool") + implicit val askAmExecutor = ExecutionContext.fromExecutor(askAmThreadPool) + + private[YarnSchedulerBackend] def handleExecutorDisconnectedFromDriver( + executorId: String, + executorRpcAddress: RpcAddress): Unit = { + amEndpoint match { + case Some(am) => + val lossReasonRequest = GetExecutorLossReason(executorId) + val future = am.ask[ExecutorLossReason](lossReasonRequest, askTimeout) + future onSuccess { + case reason: ExecutorLossReason => { + driverEndpoint.askWithRetry[Boolean](RemoveExecutor(executorId, reason)) + } + } + future onFailure { + case NonFatal(e) => { + logWarning(s"Attempted to get executor loss reason" + + s" for executor id ${executorId} at RPC address ${executorRpcAddress}," + + s" but got no response. Marking as slave lost.", e) + driverEndpoint.askWithRetry[Boolean](RemoveExecutor(executorId, SlaveLost())) + } + case t => throw t + } + case None => + logWarning("Attempted to check for an executor loss reason" + + " before the AM has registered!") + driverEndpoint.askWithRetry[Boolean]( + RemoveExecutor(executorId, SlaveLost("AM is not yet registered."))) + } + } + + override def receive: PartialFunction[Any, Unit] = { + case RegisterClusterManager(am) => + logInfo(s"ApplicationMaster registered as $am") + amEndpoint = Option(am) + if (!shouldResetOnAmRegister) { + shouldResetOnAmRegister = true + } else { + // AM is already registered before, this potentially means that AM failed and + // a new one registered after the failure. This will only happen in yarn-client mode. 
+ reset() + } + + case AddWebUIFilter(filterName, filterParams, proxyBase) => + addWebUIFilter(filterName, filterParams, proxyBase) + + case RemoveExecutor(executorId, reason) => + logWarning(reason.toString) + removeExecutor(executorId, reason) + } + + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case r: RequestExecutors => + amEndpoint match { + case Some(am) => + Future { + context.reply(am.askWithRetry[Boolean](r)) + } onFailure { + case NonFatal(e) => + logError(s"Sending $r to AM was unsuccessful", e) + context.sendFailure(e) + } + case None => + logWarning("Attempted to request executors before the AM has registered!") + context.reply(false) + } + + case k: KillExecutors => + amEndpoint match { + case Some(am) => + Future { + context.reply(am.askWithRetry[Boolean](k)) + } onFailure { + case NonFatal(e) => + logError(s"Sending $k to AM was unsuccessful", e) + context.sendFailure(e) + } + case None => + logWarning("Attempted to kill executors before the AM has registered!") + context.reply(false) + } + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (amEndpoint.exists(_.address == remoteAddress)) { + logWarning(s"ApplicationMaster has disassociated: $remoteAddress") + amEndpoint = None + } + } + + override def onStop(): Unit = { + askAmThreadPool.shutdownNow() + } + } +} + +private[spark] object YarnSchedulerBackend { + val ENDPOINT_NAME = "YarnScheduler" +} diff --git a/yarn/src/test/resources/log4j.properties b/yarn/src/test/resources/log4j.properties index 6b8a5dbf6373..6b9a799954bf 100644 --- a/yarn/src/test/resources/log4j.properties +++ b/yarn/src/test/resources/log4j.properties @@ -23,6 +23,9 @@ log4j.appender.file.file=target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n -# Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN +# Ignore messages below warning level from a few verbose libraries. +log4j.logger.com.sun.jersey=WARN log4j.logger.org.apache.hadoop=WARN +log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.mortbay=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala new file mode 100644 index 000000000000..12494b01054b --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala @@ -0,0 +1,231 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.yarn + +import java.io.{File, FileOutputStream, OutputStreamWriter} +import java.util.Properties +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters._ +import scala.concurrent.duration._ +import scala.language.postfixOps + +import com.google.common.base.Charsets.UTF_8 +import com.google.common.io.Files +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.server.MiniYARNCluster +import org.scalatest.{BeforeAndAfterAll, Matchers} +import org.scalatest.concurrent.Eventually._ + +import org.apache.spark._ +import org.apache.spark.launcher._ +import org.apache.spark.util.Utils + +abstract class BaseYarnClusterSuite + extends SparkFunSuite with BeforeAndAfterAll with Matchers with Logging { + + // log4j configuration for the YARN containers, so that their output is collected + // by YARN instead of trying to overwrite unit-tests.log. + protected val LOG4J_CONF = """ + |log4j.rootCategory=DEBUG, console + |log4j.appender.console=org.apache.log4j.ConsoleAppender + |log4j.appender.console.target=System.err + |log4j.appender.console.layout=org.apache.log4j.PatternLayout + |log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n + |log4j.logger.org.apache.hadoop=WARN + |log4j.logger.org.eclipse.jetty=WARN + |log4j.logger.org.mortbay=WARN + |log4j.logger.org.spark-project.jetty=WARN + """.stripMargin + + private var yarnCluster: MiniYARNCluster = _ + protected var tempDir: File = _ + private var fakeSparkJar: File = _ + protected var hadoopConfDir: File = _ + private var logConfDir: File = _ + + def newYarnConfig(): YarnConfiguration + + override def beforeAll() { + super.beforeAll() + + tempDir = Utils.createTempDir() + logConfDir = new File(tempDir, "log4j") + logConfDir.mkdir() + System.setProperty("SPARK_YARN_MODE", "true") + + val logConfFile = new File(logConfDir, "log4j.properties") + Files.write(LOG4J_CONF, logConfFile, UTF_8) + + // Disable the disk utilization check to avoid the test hanging when people's disks are + // getting full. + val yarnConf = newYarnConfig() + yarnConf.set("yarn.nodemanager.disk-health-checker.max-disk-utilization-per-disk-percentage", + "100.0") + + yarnCluster = new MiniYARNCluster(getClass().getName(), 1, 1, 1) + yarnCluster.init(yarnConf) + yarnCluster.start() + + // There's a race in MiniYARNCluster in which start() may return before the RM has updated + // its address in the configuration. You can see this in the logs by noticing that when + // MiniYARNCluster prints the address, it still has port "0" assigned, although later the + // test works sometimes: + // + // INFO MiniYARNCluster: MiniYARN ResourceManager address: blah:0 + // + // That log message prints the contents of the RM_ADDRESS config variable. If you check it + // later on, it looks something like this: + // + // INFO YarnClusterSuite: RM address in configuration is blah:42631 + // + // This hack loops for a bit waiting for the port to change, and fails the test if it hasn't + // done so in a timely manner (defined to be 10 seconds). 
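As an aside, the polling that follows could equally be written with the Eventually helper this suite already imports; a sketch of that variant (the hand-rolled loop arguably gives a clearer failure message, so this is illustration only):

// Equivalent wait using org.scalatest.concurrent.Eventually, imported above.
eventually(timeout(10 seconds), interval(100 millis)) {
  val rmAddress = yarnCluster.getConfig().get(YarnConfiguration.RM_ADDRESS)
  assert(rmAddress.split(":")(1) != "0", s"RM address not yet published: $rmAddress")
}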
+ val config = yarnCluster.getConfig() + val deadline = System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(10) + while (config.get(YarnConfiguration.RM_ADDRESS).split(":")(1) == "0") { + if (System.currentTimeMillis() > deadline) { + throw new IllegalStateException("Timed out waiting for RM to come up.") + } + logDebug("RM address still not set in configuration, waiting...") + TimeUnit.MILLISECONDS.sleep(100) + } + + logInfo(s"RM address in configuration is ${config.get(YarnConfiguration.RM_ADDRESS)}") + + fakeSparkJar = File.createTempFile("sparkJar", null, tempDir) + hadoopConfDir = new File(tempDir, Client.LOCALIZED_CONF_DIR) + assert(hadoopConfDir.mkdir()) + File.createTempFile("token", ".txt", hadoopConfDir) + } + + override def afterAll() { + yarnCluster.stop() + System.clearProperty("SPARK_YARN_MODE") + super.afterAll() + } + + protected def runSpark( + clientMode: Boolean, + klass: String, + appArgs: Seq[String] = Nil, + sparkArgs: Seq[(String, String)] = Nil, + extraClassPath: Seq[String] = Nil, + extraJars: Seq[String] = Nil, + extraConf: Map[String, String] = Map(), + extraEnv: Map[String, String] = Map()): SparkAppHandle.State = { + val master = if (clientMode) "yarn-client" else "yarn-cluster" + val propsFile = createConfFile(extraClassPath = extraClassPath, extraConf = extraConf) + val env = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath()) ++ extraEnv + + val launcher = new SparkLauncher(env.asJava) + if (klass.endsWith(".py")) { + launcher.setAppResource(klass) + } else { + launcher.setMainClass(klass) + launcher.setAppResource(fakeSparkJar.getAbsolutePath()) + } + launcher.setSparkHome(sys.props("spark.test.home")) + .setMaster(master) + .setConf("spark.executor.instances", "1") + .setPropertiesFile(propsFile) + .addAppArgs(appArgs.toArray: _*) + + sparkArgs.foreach { case (name, value) => + if (value != null) { + launcher.addSparkArg(name, value) + } else { + launcher.addSparkArg(name) + } + } + extraJars.foreach(launcher.addJar) + + val handle = launcher.startApplication() + try { + eventually(timeout(2 minutes), interval(1 second)) { + assert(handle.getState().isFinal()) + } + } finally { + handle.kill() + } + + handle.getState() + } + + /** + * This is a workaround for an issue with yarn-cluster mode: the Client class will not provide + * any sort of error when the job process finishes successfully, but the job itself fails. So + * the tests enforce that something is written to a file after everything is ok to indicate + * that the job succeeded. 
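checkResult() therefore relies on the application writing its own verdict file; a minimal hypothetical driver following that convention (the object name and the parallelize check are placeholders; the real drivers used by these tests appear further down in this patch):

import java.io.File

import com.google.common.base.Charsets.UTF_8
import com.google.common.io.Files

import org.apache.spark.{SparkConf, SparkContext}

object MarkerWritingDriver {
  def main(args: Array[String]): Unit = {
    val status = new File(args(0))  // result file handed over via appArgs
    var result = "failure"
    val sc = new SparkContext(new SparkConf().setAppName("marker-writing-driver"))
    try {
      // Whatever the test actually verifies goes here.
      require(sc.parallelize(1 to 4).collect().toSet == Set(1, 2, 3, 4))
      result = "success"
    } finally {
      // Write the verdict before stopping the context, matching the ordering the
      // drivers in this patch use.
      Files.write(result, status, UTF_8)
      sc.stop()
    }
  }
}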
+ */ + protected def checkResult(finalState: SparkAppHandle.State, result: File): Unit = { + checkResult(finalState, result, "success") + } + + protected def checkResult( + finalState: SparkAppHandle.State, + result: File, + expected: String): Unit = { + finalState should be (SparkAppHandle.State.FINISHED) + val resultString = Files.toString(result, UTF_8) + resultString should be (expected) + } + + protected def mainClassName(klass: Class[_]): String = { + klass.getName().stripSuffix("$") + } + + protected def createConfFile( + extraClassPath: Seq[String] = Nil, + extraConf: Map[String, String] = Map()): String = { + val props = new Properties() + props.put("spark.yarn.jar", "local:" + fakeSparkJar.getAbsolutePath()) + + val testClasspath = new TestClasspathBuilder() + .buildClassPath( + logConfDir.getAbsolutePath() + + File.pathSeparator + + extraClassPath.mkString(File.pathSeparator)) + .asScala + .mkString(File.pathSeparator) + + props.put("spark.driver.extraClassPath", testClasspath) + props.put("spark.executor.extraClassPath", testClasspath) + + // SPARK-4267: make sure java options are propagated correctly. + props.setProperty("spark.driver.extraJavaOptions", "-Dfoo=\"one two three\"") + props.setProperty("spark.executor.extraJavaOptions", "-Dfoo=\"one two three\"") + + yarnCluster.getConfig().asScala.foreach { e => + props.setProperty("spark.hadoop." + e.getKey(), e.getValue()) + } + sys.props.foreach { case (k, v) => + if (k.startsWith("spark.")) { + props.setProperty(k, v) + } + } + extraConf.foreach { case (k, v) => props.setProperty(k, v) } + + val propsFile = File.createTempFile("spark", ".properties", tempDir) + val writer = new OutputStreamWriter(new FileOutputStream(propsFile), UTF_8) + props.store(writer, "Spark properties.") + writer.close() + propsFile.getAbsolutePath() + } + +} diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index 837f8d3fa55a..e7f2501e7899 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.deploy.yarn import java.io.File import java.net.URI -import scala.collection.JavaConversions._ -import scala.collection.mutable.{ HashMap => MutableHashMap } +import scala.collection.JavaConverters._ +import scala.collection.mutable.{HashMap => MutableHashMap} import scala.reflect.ClassTag import scala.util.Try @@ -29,13 +29,16 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.MRJobConfig import org.apache.hadoop.yarn.api.ApplicationConstants.Environment +import org.apache.hadoop.yarn.api.protocolrecords.GetNewApplicationResponse import org.apache.hadoop.yarn.api.records._ +import org.apache.hadoop.yarn.client.api.YarnClientApplication import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.util.Records import org.mockito.Matchers._ import org.mockito.Mockito._ import org.scalatest.{BeforeAndAfterAll, Matchers} -import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.util.Utils class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll { @@ -170,6 +173,39 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll { cp should contain ("/remotePath/my1.jar") } + test("configuration and args propagate through 
createApplicationSubmissionContext") { + val conf = new Configuration() + // When parsing tags, duplicates and leading/trailing whitespace should be removed. + // Spaces between non-comma strings should be preserved as single tags. Empty strings may or + // may not be removed depending on the version of Hadoop being used. + val sparkConf = new SparkConf() + .set(Client.CONF_SPARK_YARN_APPLICATION_TAGS, ",tag1, dup,tag2 , ,multi word , dup") + .set("spark.yarn.maxAppAttempts", "42") + val args = new ClientArguments(Array( + "--name", "foo-test-app", + "--queue", "staging-queue"), sparkConf) + + val appContext = Records.newRecord(classOf[ApplicationSubmissionContext]) + val getNewApplicationResponse = Records.newRecord(classOf[GetNewApplicationResponse]) + val containerLaunchContext = Records.newRecord(classOf[ContainerLaunchContext]) + + val client = new Client(args, conf, sparkConf) + client.createApplicationSubmissionContext( + new YarnClientApplication(getNewApplicationResponse, appContext), + containerLaunchContext) + + appContext.getApplicationName should be ("foo-test-app") + appContext.getQueue should be ("staging-queue") + appContext.getAMContainerSpec should be (containerLaunchContext) + appContext.getApplicationType should be ("SPARK") + appContext.getClass.getMethods.filter(_.getName.equals("getApplicationTags")).foreach{ method => + val tags = method.invoke(appContext).asInstanceOf[java.util.Set[String]] + tags should contain allOf ("tag1", "dup", "tag2", "multi word") + tags.asScala.filter(_.nonEmpty).size should be (4) + } + appContext.getMaxAppAttempts should be (42) + } + object Fixtures { val knownDefYarnAppCP: Seq[String] = diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ContainerPlacementStrategySuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ContainerPlacementStrategySuite.scala index b7fe4ccc67a3..afb4b691b52d 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ContainerPlacementStrategySuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ContainerPlacementStrategySuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.deploy.yarn +import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest import org.scalatest.{BeforeAndAfterEach, Matchers} import org.apache.spark.SparkFunSuite @@ -26,6 +27,9 @@ class ContainerPlacementStrategySuite extends SparkFunSuite with Matchers with B private val yarnAllocatorSuite = new YarnAllocatorSuite import yarnAllocatorSuite._ + def createContainerRequest(nodes: Array[String]): ContainerRequest = + new ContainerRequest(containerResource, nodes, null, YarnSparkHadoopUtil.RM_REQUEST_PRIORITY) + override def beforeEach() { yarnAllocatorSuite.beforeEach() } @@ -44,7 +48,8 @@ class ContainerPlacementStrategySuite extends SparkFunSuite with Matchers with B handler.handleAllocatedContainers(Array(createContainer("host1"), createContainer("host2"))) val localities = handler.containerPlacementStrategy.localityOfRequestedContainers( - 3, 15, Map("host3" -> 15, "host4" -> 15, "host5" -> 10), handler.allocatedHostToContainersMap) + 3, 15, Map("host3" -> 15, "host4" -> 15, "host5" -> 10), + handler.allocatedHostToContainersMap, Seq.empty) assert(localities.map(_.nodes) === Array( Array("host3", "host4", "host5"), @@ -66,7 +71,8 @@ class ContainerPlacementStrategySuite extends SparkFunSuite with Matchers with B )) val localities = handler.containerPlacementStrategy.localityOfRequestedContainers( - 3, 15, Map("host1" -> 15, "host2" -> 15, "host3" -> 10), handler.allocatedHostToContainersMap) + 
3, 15, Map("host1" -> 15, "host2" -> 15, "host3" -> 10), + handler.allocatedHostToContainersMap, Seq.empty) assert(localities.map(_.nodes) === Array(null, Array("host2", "host3"), Array("host2", "host3"))) @@ -86,7 +92,8 @@ class ContainerPlacementStrategySuite extends SparkFunSuite with Matchers with B )) val localities = handler.containerPlacementStrategy.localityOfRequestedContainers( - 1, 15, Map("host1" -> 15, "host2" -> 15, "host3" -> 10), handler.allocatedHostToContainersMap) + 1, 15, Map("host1" -> 15, "host2" -> 15, "host3" -> 10), + handler.allocatedHostToContainersMap, Seq.empty) assert(localities.map(_.nodes) === Array(Array("host2", "host3"))) } @@ -105,7 +112,8 @@ class ContainerPlacementStrategySuite extends SparkFunSuite with Matchers with B )) val localities = handler.containerPlacementStrategy.localityOfRequestedContainers( - 3, 15, Map("host1" -> 15, "host2" -> 15, "host3" -> 10), handler.allocatedHostToContainersMap) + 3, 15, Map("host1" -> 15, "host2" -> 15, "host3" -> 10), + handler.allocatedHostToContainersMap, Seq.empty) assert(localities.map(_.nodes) === Array(null, null, null)) } @@ -118,8 +126,28 @@ class ContainerPlacementStrategySuite extends SparkFunSuite with Matchers with B handler.handleAllocatedContainers(Array(createContainer("host1"), createContainer("host2"))) val localities = handler.containerPlacementStrategy.localityOfRequestedContainers( - 1, 0, Map.empty, handler.allocatedHostToContainersMap) + 1, 0, Map.empty, handler.allocatedHostToContainersMap, Seq.empty) assert(localities.map(_.nodes) === Array(null)) } + + test("allocate locality preferred containers by considering the localities of pending requests") { + val handler = createAllocator(3) + handler.updateResourceRequests() + handler.handleAllocatedContainers(Array( + createContainer("host1"), + createContainer("host1"), + createContainer("host2") + )) + + val pendingAllocationRequests = Seq( + createContainerRequest(Array("host2", "host3")), + createContainerRequest(Array("host1", "host4"))) + + val localities = handler.containerPlacementStrategy.localityOfRequestedContainers( + 1, 15, Map("host1" -> 15, "host2" -> 15, "host3" -> 10), + handler.allocatedHostToContainersMap, pendingAllocationRequests) + + assert(localities.map(_.nodes) === Array(Array("host3"))) + } } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index 58318bf9bcc0..bd80036c5cfa 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -87,16 +87,17 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter def createAllocator(maxExecutors: Int = 5): YarnAllocator = { val args = Array( - "--num-executors", s"$maxExecutors", "--executor-cores", "5", "--executor-memory", "2048", "--jar", "somejar.jar", "--class", "SomeClass") + val sparkConfClone = sparkConf.clone() + sparkConfClone.set("spark.executor.instances", maxExecutors.toString) new YarnAllocator( "not used", mock(classOf[RpcEndpointRef]), conf, - sparkConf, + sparkConfClone, rmClient, appAttemptId, new ApplicationMasterArguments(args), @@ -115,7 +116,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter val handler = createAllocator(1) handler.updateResourceRequests() handler.getNumExecutorsRunning should be (0) - handler.getNumPendingAllocate should be (1) + handler.getPendingAllocate.size 
should be (1) val container = createContainer("host1") handler.handleAllocatedContainers(Array(container)) @@ -133,7 +134,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter val handler = createAllocator(4) handler.updateResourceRequests() handler.getNumExecutorsRunning should be (0) - handler.getNumPendingAllocate should be (4) + handler.getPendingAllocate.size should be (4) val container1 = createContainer("host1") val container2 = createContainer("host1") @@ -153,7 +154,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter val handler = createAllocator(2) handler.updateResourceRequests() handler.getNumExecutorsRunning should be (0) - handler.getNumPendingAllocate should be (2) + handler.getPendingAllocate.size should be (2) val container1 = createContainer("host1") val container2 = createContainer("host2") @@ -173,11 +174,11 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter val handler = createAllocator(4) handler.updateResourceRequests() handler.getNumExecutorsRunning should be (0) - handler.getNumPendingAllocate should be (4) + handler.getPendingAllocate.size should be (4) handler.requestTotalExecutorsWithPreferredLocalities(3, 0, Map.empty) handler.updateResourceRequests() - handler.getNumPendingAllocate should be (3) + handler.getPendingAllocate.size should be (3) val container = createContainer("host1") handler.handleAllocatedContainers(Array(container)) @@ -188,18 +189,18 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter handler.requestTotalExecutorsWithPreferredLocalities(2, 0, Map.empty) handler.updateResourceRequests() - handler.getNumPendingAllocate should be (1) + handler.getPendingAllocate.size should be (1) } test("decrease total requested executors to less than currently running") { val handler = createAllocator(4) handler.updateResourceRequests() handler.getNumExecutorsRunning should be (0) - handler.getNumPendingAllocate should be (4) + handler.getPendingAllocate.size should be (4) handler.requestTotalExecutorsWithPreferredLocalities(3, 0, Map.empty) handler.updateResourceRequests() - handler.getNumPendingAllocate should be (3) + handler.getPendingAllocate.size should be (3) val container1 = createContainer("host1") val container2 = createContainer("host2") @@ -209,7 +210,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter handler.requestTotalExecutorsWithPreferredLocalities(1, 0, Map.empty) handler.updateResourceRequests() - handler.getNumPendingAllocate should be (0) + handler.getPendingAllocate.size should be (0) handler.getNumExecutorsRunning should be (2) } @@ -217,7 +218,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter val handler = createAllocator(4) handler.updateResourceRequests() handler.getNumExecutorsRunning should be (0) - handler.getNumPendingAllocate should be (4) + handler.getPendingAllocate.size should be (4) val container1 = createContainer("host1") val container2 = createContainer("host2") @@ -232,14 +233,14 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter handler.updateResourceRequests() handler.processCompletedContainers(statuses.toSeq) handler.getNumExecutorsRunning should be (0) - handler.getNumPendingAllocate should be (1) + handler.getPendingAllocate.size should be (1) } test("lost executor removed from backend") { val handler = createAllocator(4) handler.updateResourceRequests() handler.getNumExecutorsRunning 
should be (0) - handler.getNumPendingAllocate should be (4) + handler.getPendingAllocate.size should be (4) val container1 = createContainer("host1") val container2 = createContainer("host2") @@ -254,7 +255,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter handler.processCompletedContainers(statuses.toSeq) handler.updateResourceRequests() handler.getNumExecutorsRunning should be (0) - handler.getNumPendingAllocate should be (2) + handler.getPendingAllocate.size should be (2) handler.getNumExecutorsFailed should be (2) handler.getNumUnexpectedContainerRelease should be (2) } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 547863d9a073..6db012a77a93 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -17,25 +17,26 @@ package org.apache.spark.deploy.yarn -import java.io.{File, FileOutputStream, OutputStreamWriter} +import java.io.File import java.net.URL -import java.util.Properties -import java.util.concurrent.TimeUnit +import java.util.{HashMap => JHashMap, Properties} -import scala.collection.JavaConversions._ import scala.collection.mutable +import scala.concurrent.duration._ +import scala.language.postfixOps import com.google.common.base.Charsets.UTF_8 -import com.google.common.io.ByteStreams -import com.google.common.io.Files +import com.google.common.io.{ByteStreams, Files} import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.hadoop.yarn.server.MiniYARNCluster -import org.scalatest.{BeforeAndAfterAll, Matchers} +import org.scalatest.Matchers +import org.scalatest.concurrent.Eventually._ import org.apache.spark._ -import org.apache.spark.scheduler.cluster.ExecutorInfo +import org.apache.spark.launcher._ import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationStart, SparkListenerExecutorAdded} +import org.apache.spark.scheduler.cluster.ExecutorInfo +import org.apache.spark.tags.ExtendedYarnTest import org.apache.spark.util.Utils /** @@ -43,17 +44,10 @@ import org.apache.spark.util.Utils * applications, and require the Spark assembly to be built before they can be successfully * run. */ -class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matchers with Logging { - - // log4j configuration for the YARN containers, so that their output is collected - // by YARN instead of trying to overwrite unit-tests.log. 
- private val LOG4J_CONF = """ - |log4j.rootCategory=DEBUG, console - |log4j.appender.console=org.apache.log4j.ConsoleAppender - |log4j.appender.console.target=System.err - |log4j.appender.console.layout=org.apache.log4j.PatternLayout - |log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n - """.stripMargin +@ExtendedYarnTest +class YarnClusterSuite extends BaseYarnClusterSuite { + + override def newYarnConfig(): YarnConfiguration = new YarnConfiguration() private val TEST_PYFILE = """ |import mod1, mod2 @@ -82,65 +76,6 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher | return 42 """.stripMargin - private var yarnCluster: MiniYARNCluster = _ - private var tempDir: File = _ - private var fakeSparkJar: File = _ - private var hadoopConfDir: File = _ - private var logConfDir: File = _ - - override def beforeAll() { - super.beforeAll() - - tempDir = Utils.createTempDir() - logConfDir = new File(tempDir, "log4j") - logConfDir.mkdir() - System.setProperty("SPARK_YARN_MODE", "true") - - val logConfFile = new File(logConfDir, "log4j.properties") - Files.write(LOG4J_CONF, logConfFile, UTF_8) - - yarnCluster = new MiniYARNCluster(getClass().getName(), 1, 1, 1) - yarnCluster.init(new YarnConfiguration()) - yarnCluster.start() - - // There's a race in MiniYARNCluster in which start() may return before the RM has updated - // its address in the configuration. You can see this in the logs by noticing that when - // MiniYARNCluster prints the address, it still has port "0" assigned, although later the - // test works sometimes: - // - // INFO MiniYARNCluster: MiniYARN ResourceManager address: blah:0 - // - // That log message prints the contents of the RM_ADDRESS config variable. If you check it - // later on, it looks something like this: - // - // INFO YarnClusterSuite: RM address in configuration is blah:42631 - // - // This hack loops for a bit waiting for the port to change, and fails the test if it hasn't - // done so in a timely manner (defined to be 10 seconds). - val config = yarnCluster.getConfig() - val deadline = System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(10) - while (config.get(YarnConfiguration.RM_ADDRESS).split(":")(1) == "0") { - if (System.currentTimeMillis() > deadline) { - throw new IllegalStateException("Timed out waiting for RM to come up.") - } - logDebug("RM address still not set in configuration, waiting...") - TimeUnit.MILLISECONDS.sleep(100) - } - - logInfo(s"RM address in configuration is ${config.get(YarnConfiguration.RM_ADDRESS)}") - - fakeSparkJar = File.createTempFile("sparkJar", null, tempDir) - hadoopConfDir = new File(tempDir, Client.LOCALIZED_CONF_DIR) - assert(hadoopConfDir.mkdir()) - File.createTempFile("token", ".txt", hadoopConfDir) - } - - override def afterAll() { - yarnCluster.stop() - System.clearProperty("SPARK_YARN_MODE") - super.afterAll() - } - test("run Spark in yarn-client mode") { testBasicYarnApp(true) } @@ -151,10 +86,8 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher test("run Spark in yarn-cluster mode unsuccessfully") { // Don't provide arguments so the driver will fail. 
- val exception = intercept[SparkException] { - runSpark(false, mainClassName(YarnClusterDriver.getClass)) - fail("Spark application should have failed.") - } + val finalState = runSpark(false, mainClassName(YarnClusterDriver.getClass)) + finalState should be (SparkAppHandle.State.FAILED) } test("run Python application in yarn-client mode") { @@ -173,17 +106,59 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher testUseClassPathFirst(false) } + test("monitor app using launcher library") { + val env = new JHashMap[String, String]() + env.put("YARN_CONF_DIR", hadoopConfDir.getAbsolutePath()) + + val propsFile = createConfFile() + val handle = new SparkLauncher(env) + .setSparkHome(sys.props("spark.test.home")) + .setConf("spark.ui.enabled", "false") + .setPropertiesFile(propsFile) + .setMaster("yarn-client") + .setAppResource("spark-internal") + .setMainClass(mainClassName(YarnLauncherTestApp.getClass)) + .startApplication() + + try { + eventually(timeout(30 seconds), interval(100 millis)) { + handle.getState() should be (SparkAppHandle.State.RUNNING) + } + + handle.getAppId() should not be (null) + handle.getAppId() should startWith ("application_") + handle.stop() + + eventually(timeout(30 seconds), interval(100 millis)) { + handle.getState() should be (SparkAppHandle.State.KILLED) + } + } finally { + handle.kill() + } + } + private def testBasicYarnApp(clientMode: Boolean): Unit = { - var result = File.createTempFile("result", null, tempDir) - runSpark(clientMode, mainClassName(YarnClusterDriver.getClass), + val result = File.createTempFile("result", null, tempDir) + val finalState = runSpark(clientMode, mainClassName(YarnClusterDriver.getClass), appArgs = Seq(result.getAbsolutePath())) - checkResult(result) + checkResult(finalState, result) } private def testPySpark(clientMode: Boolean): Unit = { val primaryPyFile = new File(tempDir, "test.py") Files.write(TEST_PYFILE, primaryPyFile, UTF_8) + // When running tests, let's not assume the user has built the assembly module, which also + // creates the pyspark archive. Instead, let's use PYSPARK_ARCHIVES_PATH to point at the + // needed locations. + val sparkHome = sys.props("spark.test.home"); + val pythonPath = Seq( + s"$sparkHome/python/lib/py4j-0.9-src.zip", + s"$sparkHome/python") + val extraEnv = Map( + "PYSPARK_ARCHIVES_PATH" -> pythonPath.map("local:" + _).mkString(File.pathSeparator), + "PYTHONPATH" -> pythonPath.mkString(File.pathSeparator)) + val moduleDir = if (clientMode) { // In client-mode, .py files added with --py-files are not visible in the driver. 
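The launcher-based test above polls handle.getState() through eventually; the same handle also supports push-style notification via SparkAppHandle.Listener. A sketch of that variant, with placeholder paths and class names:

import org.apache.spark.launcher.{SparkAppHandle, SparkLauncher}

// Launch an application and react to state transitions instead of polling.
val handle = new SparkLauncher()
  .setSparkHome("/path/to/spark")        // placeholders, not values from this patch
  .setAppResource("/path/to/app.jar")
  .setMainClass("com.example.MyApp")
  .setMaster("yarn-client")
  .startApplication(new SparkAppHandle.Listener {
    override def stateChanged(h: SparkAppHandle): Unit =
      println(s"state=${h.getState} appId=${h.getAppId}")
    override def infoChanged(h: SparkAppHandle): Unit = ()
  })

// The handle remains usable imperatively as well, e.g. handle.stop() or handle.kill().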
@@ -201,10 +176,11 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher val pyFiles = Seq(pyModule.getAbsolutePath(), mod2Archive.getPath()).mkString(",") val result = File.createTempFile("result", null, tempDir) - runSpark(clientMode, primaryPyFile.getAbsolutePath(), - sparkArgs = Seq("--py-files", pyFiles), - appArgs = Seq(result.getAbsolutePath())) - checkResult(result) + val finalState = runSpark(clientMode, primaryPyFile.getAbsolutePath(), + sparkArgs = Seq("--py-files" -> pyFiles), + appArgs = Seq(result.getAbsolutePath()), + extraEnv = extraEnv) + checkResult(finalState, result) } private def testUseClassPathFirst(clientMode: Boolean): Unit = { @@ -213,98 +189,15 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher val userJar = TestUtils.createJarWithFiles(Map("test.resource" -> "OVERRIDDEN"), tempDir) val driverResult = File.createTempFile("driver", null, tempDir) val executorResult = File.createTempFile("executor", null, tempDir) - runSpark(clientMode, mainClassName(YarnClasspathTest.getClass), + val finalState = runSpark(clientMode, mainClassName(YarnClasspathTest.getClass), appArgs = Seq(driverResult.getAbsolutePath(), executorResult.getAbsolutePath()), extraClassPath = Seq(originalJar.getPath()), extraJars = Seq("local:" + userJar.getPath()), extraConf = Map( "spark.driver.userClassPathFirst" -> "true", "spark.executor.userClassPathFirst" -> "true")) - checkResult(driverResult, "OVERRIDDEN") - checkResult(executorResult, "OVERRIDDEN") - } - - private def runSpark( - clientMode: Boolean, - klass: String, - appArgs: Seq[String] = Nil, - sparkArgs: Seq[String] = Nil, - extraClassPath: Seq[String] = Nil, - extraJars: Seq[String] = Nil, - extraConf: Map[String, String] = Map()): Unit = { - val master = if (clientMode) "yarn-client" else "yarn-cluster" - val props = new Properties() - - props.setProperty("spark.yarn.jar", "local:" + fakeSparkJar.getAbsolutePath()) - - val childClasspath = logConfDir.getAbsolutePath() + - File.pathSeparator + - sys.props("java.class.path") + - File.pathSeparator + - extraClassPath.mkString(File.pathSeparator) - props.setProperty("spark.driver.extraClassPath", childClasspath) - props.setProperty("spark.executor.extraClassPath", childClasspath) - - // SPARK-4267: make sure java options are propagated correctly. - props.setProperty("spark.driver.extraJavaOptions", "-Dfoo=\"one two three\"") - props.setProperty("spark.executor.extraJavaOptions", "-Dfoo=\"one two three\"") - - yarnCluster.getConfig().foreach { e => - props.setProperty("spark.hadoop." 
+ e.getKey(), e.getValue()) - } - - sys.props.foreach { case (k, v) => - if (k.startsWith("spark.")) { - props.setProperty(k, v) - } - } - - extraConf.foreach { case (k, v) => props.setProperty(k, v) } - - val propsFile = File.createTempFile("spark", ".properties", tempDir) - val writer = new OutputStreamWriter(new FileOutputStream(propsFile), UTF_8) - props.store(writer, "Spark properties.") - writer.close() - - val extraJarArgs = if (!extraJars.isEmpty()) Seq("--jars", extraJars.mkString(",")) else Nil - val mainArgs = - if (klass.endsWith(".py")) { - Seq(klass) - } else { - Seq("--class", klass, fakeSparkJar.getAbsolutePath()) - } - val argv = - Seq( - new File(sys.props("spark.test.home"), "bin/spark-submit").getAbsolutePath(), - "--master", master, - "--num-executors", "1", - "--properties-file", propsFile.getAbsolutePath()) ++ - extraJarArgs ++ - sparkArgs ++ - mainArgs ++ - appArgs - - Utils.executeAndGetOutput(argv, - extraEnvironment = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath())) - } - - /** - * This is a workaround for an issue with yarn-cluster mode: the Client class will not provide - * any sort of error when the job process finishes successfully, but the job itself fails. So - * the tests enforce that something is written to a file after everything is ok to indicate - * that the job succeeded. - */ - private def checkResult(result: File): Unit = { - checkResult(result, "success") - } - - private def checkResult(result: File, expected: String): Unit = { - var resultString = Files.toString(result, UTF_8) - resultString should be (expected) - } - - private def mainClassName(klass: Class[_]): String = { - klass.getName().stripSuffix("$") + checkResult(finalState, driverResult, "OVERRIDDEN") + checkResult(finalState, executorResult, "OVERRIDDEN") } } @@ -351,8 +244,8 @@ private object YarnClusterDriver extends Logging with Matchers { data should be (Set(1, 2, 3, 4)) result = "success" } finally { - sc.stop() Files.write(result, status, UTF_8) + sc.stop() } // verify log urls are present @@ -371,8 +264,8 @@ private object YarnClusterDriver extends Logging with Matchers { assert(listener.driverLogs.nonEmpty) val driverLogs = listener.driverLogs.get assert(driverLogs.size === 2) - assert(driverLogs.containsKey("stderr")) - assert(driverLogs.containsKey("stdout")) + assert(driverLogs.contains("stderr")) + assert(driverLogs.contains("stdout")) val urlStr = driverLogs("stderr") // Ensure that this is a valid URL, else this will throw an exception new URL(urlStr) @@ -384,19 +277,29 @@ private object YarnClusterDriver extends Logging with Matchers { } -private object YarnClasspathTest { +private object YarnClasspathTest extends Logging { + + var exitCode = 0 + + def error(m: String, ex: Throwable = null): Unit = { + logError(m, ex) + // scalastyle:off println + System.out.println(m) + if (ex != null) { + ex.printStackTrace(System.out) + } + // scalastyle:on println + } def main(args: Array[String]): Unit = { if (args.length != 2) { - // scalastyle:off println - System.err.println( + error( s""" |Invalid command line: ${args.mkString(" ")} | |Usage: YarnClasspathTest [driver result file] [executor result file] """.stripMargin) // scalastyle:on println - System.exit(1) } readResource(args(0)) @@ -406,6 +309,7 @@ private object YarnClasspathTest { } finally { sc.stop() } + System.exit(exitCode) } private def readResource(resultPath: String): Unit = { @@ -415,9 +319,29 @@ private object YarnClasspathTest { val resource = ccl.getResourceAsStream("test.resource") val bytes = 
ByteStreams.toByteArray(resource) result = new String(bytes, 0, bytes.length, UTF_8) + } catch { + case t: Throwable => + error(s"loading test.resource to $resultPath", t) + // set the exit code if not yet set + exitCode = 2 } finally { Files.write(result, new File(resultPath), UTF_8) } } } + +private object YarnLauncherTestApp { + + def main(args: Array[String]): Unit = { + // Do not stop the application; the test will stop it using the launcher lib. Just run a task + // that will prevent the process from exiting. + val sc = new SparkContext(new SparkConf()) + sc.parallelize(Seq(1)).foreach { i => + this.synchronized { + wait() + } + } + } + +} diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala new file mode 100644 index 000000000000..c17e8695c24f --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala @@ -0,0 +1,111 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.deploy.yarn + +import java.io.File + +import com.google.common.base.Charsets.UTF_8 +import com.google.common.io.Files +import org.apache.commons.io.FileUtils +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.scalatest.Matchers + +import org.apache.spark._ +import org.apache.spark.network.shuffle.ShuffleTestAccessor +import org.apache.spark.network.yarn.{YarnShuffleService, YarnTestAccessor} +import org.apache.spark.tags.ExtendedYarnTest + +/** + * Integration test for the external shuffle service with a yarn mini-cluster + */ +@ExtendedYarnTest +class YarnShuffleIntegrationSuite extends BaseYarnClusterSuite { + + override def newYarnConfig(): YarnConfiguration = { + val yarnConfig = new YarnConfiguration() + yarnConfig.set(YarnConfiguration.NM_AUX_SERVICES, "spark_shuffle") + yarnConfig.set(YarnConfiguration.NM_AUX_SERVICE_FMT.format("spark_shuffle"), + classOf[YarnShuffleService].getCanonicalName) + yarnConfig.set("spark.shuffle.service.port", "0") + yarnConfig + } + + test("external shuffle service") { + val shuffleServicePort = YarnTestAccessor.getShuffleServicePort + val shuffleService = YarnTestAccessor.getShuffleServiceInstance + + val registeredExecFile = YarnTestAccessor.getRegisteredExecutorFile(shuffleService) + + logInfo("Shuffle service port = " + shuffleServicePort) + val result = File.createTempFile("result", null, tempDir) + val finalState = runSpark( + false, + mainClassName(YarnExternalShuffleDriver.getClass), + appArgs = Seq(result.getAbsolutePath(), registeredExecFile.getAbsolutePath), + extraConf = Map( + "spark.shuffle.service.enabled" -> "true", + "spark.shuffle.service.port" -> shuffleServicePort.toString + ) + ) + checkResult(finalState, result) + assert(YarnTestAccessor.getRegisteredExecutorFile(shuffleService).exists()) + } +} + +private object YarnExternalShuffleDriver extends Logging with Matchers { + + val WAIT_TIMEOUT_MILLIS = 10000 + + def main(args: Array[String]): Unit = { + if (args.length != 2) { + // scalastyle:off println + System.err.println( + s""" + |Invalid command line: ${args.mkString(" ")} + | + |Usage: ExternalShuffleDriver [result file] [registed exec file] + """.stripMargin) + // scalastyle:on println + System.exit(1) + } + + val sc = new SparkContext(new SparkConf() + .setAppName("External Shuffle Test")) + val conf = sc.getConf + val status = new File(args(0)) + val registeredExecFile = new File(args(1)) + logInfo("shuffle service executor file = " + registeredExecFile) + var result = "failure" + val execStateCopy = new File(registeredExecFile.getAbsolutePath + "_dup") + try { + val data = sc.parallelize(0 until 100, 10).map { x => (x % 10) -> x }.reduceByKey{ _ + _ }. 
+ collect().toSet + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + data should be ((0 until 10).map{x => x -> (x * 10 + 450)}.toSet) + result = "success" + // only one process can open a leveldb file at a time, so we copy the files + FileUtils.copyDirectory(registeredExecFile, execStateCopy) + assert(!ShuffleTestAccessor.reloadRegisteredExecutors(execStateCopy).isEmpty) + } finally { + sc.stop() + FileUtils.deleteDirectory(execStateCopy) + Files.write(result, status, UTF_8) + } + } + +} diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala index 49bee0866dd4..3fafc91a166a 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -18,10 +18,13 @@ package org.apache.spark.deploy.yarn import java.io.{File, IOException} +import java.lang.reflect.InvocationTargetException import com.google.common.io.{ByteStreams, Files} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.ql.metadata.HiveException +import org.apache.hadoop.io.Text import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.conf.YarnConfiguration @@ -30,6 +33,7 @@ import org.scalatest.Matchers import org.apache.hadoop.yarn.api.records.ApplicationAccessType import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException, SparkFunSuite} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.util.Utils @@ -233,4 +237,79 @@ class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging } assert(caught.getMessage === "Can't get Master Kerberos principal for use as renewer") } + + test("check different hadoop utils based on env variable") { + try { + System.setProperty("SPARK_YARN_MODE", "true") + assert(SparkHadoopUtil.get.getClass === classOf[YarnSparkHadoopUtil]) + System.setProperty("SPARK_YARN_MODE", "false") + assert(SparkHadoopUtil.get.getClass === classOf[SparkHadoopUtil]) + } finally { + System.clearProperty("SPARK_YARN_MODE") + } + } + + test("Obtain tokens For HiveMetastore") { + val hadoopConf = new Configuration() + hadoopConf.set("hive.metastore.kerberos.principal", "bob") + // thrift picks up on port 0 and bails out, without trying to talk to endpoint + hadoopConf.set("hive.metastore.uris", "http://localhost:0") + val util = new YarnSparkHadoopUtil + assertNestedHiveException(intercept[InvocationTargetException] { + util.obtainTokenForHiveMetastoreInner(hadoopConf, "alice") + }) + assertNestedHiveException(intercept[InvocationTargetException] { + util.obtainTokenForHiveMetastore(hadoopConf) + }) + } + + private def assertNestedHiveException(e: InvocationTargetException): Throwable = { + val inner = e.getCause + if (inner == null) { + fail("No inner cause", e) + } + if (!inner.isInstanceOf[HiveException]) { + fail("Not a hive exception", inner) + } + inner + } + + test("Obtain tokens For HBase") { + val hadoopConf = new Configuration() + hadoopConf.set("hbase.security.authentication", "kerberos") + val util = new YarnSparkHadoopUtil + intercept[ClassNotFoundException] { + util.obtainTokenForHBaseInner(hadoopConf) + } + util.obtainTokenForHBase(hadoopConf) should be (None) + } + + // This test needs to live here because it depends on isYarnMode returning true, which can only + // 
happen in the YARN module. + test("security manager token generation") { + try { + System.setProperty("SPARK_YARN_MODE", "true") + val initial = SparkHadoopUtil.get + .getSecretKeyFromUserCredentials(SecurityManager.SECRET_LOOKUP_KEY) + assert(initial === null || initial.length === 0) + + val conf = new SparkConf() + .set(SecurityManager.SPARK_AUTH_CONF, "true") + .set(SecurityManager.SPARK_AUTH_SECRET_CONF, "unused") + val sm = new SecurityManager(conf) + + val generated = SparkHadoopUtil.get + .getSecretKeyFromUserCredentials(SecurityManager.SECRET_LOOKUP_KEY) + assert(generated != null) + val genString = new Text(generated).toString() + assert(genString != "unused") + assert(sm.getSecretKey() === genString) + } finally { + // removeSecretKey() was only added in Hadoop 2.6, so instead we just set the secret + // to an empty string. + SparkHadoopUtil.get.addSecretKeyToUserCredentials(SecurityManager.SECRET_LOOKUP_KEY, "") + System.clearProperty("SPARK_YARN_MODE") + } + } + } diff --git a/yarn/src/test/scala/org/apache/spark/launcher/TestClasspathBuilder.scala b/yarn/src/test/scala/org/apache/spark/launcher/TestClasspathBuilder.scala new file mode 100644 index 000000000000..da9e8e21a26a --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/launcher/TestClasspathBuilder.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher + +import java.util.{List => JList, Map => JMap} + +/** + * Exposes AbstractCommandBuilder to the YARN tests, so that they can build classpaths the same + * way other cluster managers do. + */ +private[spark] class TestClasspathBuilder extends AbstractCommandBuilder { + + childEnv.put(CommandBuilderUtils.ENV_SPARK_HOME, sys.props("spark.test.home")) + + override def buildClassPath(extraCp: String): JList[String] = super.buildClassPath(extraCp) + + /** Not used by the YARN tests. */ + override def buildCommand(env: JMap[String, String]): JList[String] = + throw new UnsupportedOperationException() + +} diff --git a/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala b/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala new file mode 100644 index 000000000000..94bf579dc824 --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.network.shuffle + +import java.io.{IOException, File} +import java.util.concurrent.ConcurrentMap + +import org.apache.hadoop.yarn.api.records.ApplicationId +import org.fusesource.leveldbjni.JniDBFactory +import org.iq80.leveldb.{DB, Options} + +import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo + +/** + * just a cheat to get package-visible members in tests + */ +object ShuffleTestAccessor { + + def getBlockResolver(handler: ExternalShuffleBlockHandler): ExternalShuffleBlockResolver = { + handler.blockManager + } + + def getExecutorInfo( + appId: ApplicationId, + execId: String, + resolver: ExternalShuffleBlockResolver + ): Option[ExecutorShuffleInfo] = { + val id = new AppExecId(appId.toString, execId) + Option(resolver.executors.get(id)) + } + + def registeredExecutorFile(resolver: ExternalShuffleBlockResolver): File = { + resolver.registeredExecutorFile + } + + def shuffleServiceLevelDB(resolver: ExternalShuffleBlockResolver): DB = { + resolver.db + } + + def reloadRegisteredExecutors( + file: File): ConcurrentMap[ExternalShuffleBlockResolver.AppExecId, ExecutorShuffleInfo] = { + val options: Options = new Options + options.createIfMissing(true) + val factory = new JniDBFactory + val db = factory.open(file, options) + val result = ExternalShuffleBlockResolver.reloadRegisteredExecutors(db) + db.close() + result + } + + def reloadRegisteredExecutors( + db: DB): ConcurrentMap[ExternalShuffleBlockResolver.AppExecId, ExecutorShuffleInfo] = { + ExternalShuffleBlockResolver.reloadRegisteredExecutors(db) + } +} diff --git a/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala new file mode 100644 index 000000000000..6aa8c814cd4f --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.network.yarn + +import java.io.{DataOutputStream, File, FileOutputStream} + +import scala.annotation.tailrec + +import org.apache.commons.io.FileUtils +import org.apache.hadoop.yarn.api.records.ApplicationId +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.server.api.{ApplicationInitializationContext, ApplicationTerminationContext} +import org.scalatest.{BeforeAndAfterEach, Matchers} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.network.shuffle.ShuffleTestAccessor +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo + +class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach { + private[yarn] var yarnConfig: YarnConfiguration = new YarnConfiguration + + override def beforeEach(): Unit = { + yarnConfig.set(YarnConfiguration.NM_AUX_SERVICES, "spark_shuffle") + yarnConfig.set(YarnConfiguration.NM_AUX_SERVICE_FMT.format("spark_shuffle"), + classOf[YarnShuffleService].getCanonicalName) + yarnConfig.setInt("spark.shuffle.service.port", 0) + + yarnConfig.get("yarn.nodemanager.local-dirs").split(",").foreach { dir => + val d = new File(dir) + if (d.exists()) { + FileUtils.deleteDirectory(d) + } + FileUtils.forceMkdir(d) + logInfo(s"creating yarn.nodemanager.local-dirs: $d") + } + } + + var s1: YarnShuffleService = null + var s2: YarnShuffleService = null + var s3: YarnShuffleService = null + + override def afterEach(): Unit = { + if (s1 != null) { + s1.stop() + s1 = null + } + if (s2 != null) { + s2.stop() + s2 = null + } + if (s3 != null) { + s3.stop() + s3 = null + } + } + + test("executor state kept across NM restart") { + s1 = new YarnShuffleService + s1.init(yarnConfig) + val app1Id = ApplicationId.newInstance(0, 1) + val app1Data: ApplicationInitializationContext = + new ApplicationInitializationContext("user", app1Id, null) + s1.initializeApplication(app1Data) + val app2Id = ApplicationId.newInstance(0, 2) + val app2Data: ApplicationInitializationContext = + new ApplicationInitializationContext("user", app2Id, null) + s1.initializeApplication(app2Data) + + val execStateFile = s1.registeredExecutorFile + execStateFile should not be (null) + val shuffleInfo1 = new ExecutorShuffleInfo(Array("/foo", "/bar"), 3, "sort") + val shuffleInfo2 = new ExecutorShuffleInfo(Array("/bippy"), 5, "hash") + + val blockHandler = s1.blockHandler + val blockResolver = ShuffleTestAccessor.getBlockResolver(blockHandler) + ShuffleTestAccessor.registeredExecutorFile(blockResolver) should be (execStateFile) + + blockResolver.registerExecutor(app1Id.toString, "exec-1", shuffleInfo1) + blockResolver.registerExecutor(app2Id.toString, "exec-2", shuffleInfo2) + ShuffleTestAccessor.getExecutorInfo(app1Id, "exec-1", blockResolver) should + be (Some(shuffleInfo1)) + ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", blockResolver) should + be (Some(shuffleInfo2)) + + if (!execStateFile.exists()) { + @tailrec def findExistingParent(file: File): File = { + if (file == null) file + else if (file.exists()) file + else findExistingParent(file.getParentFile()) + } + val existingParent = findExistingParent(execStateFile) + assert(false, s"$execStateFile does not exist -- closest existing parent is $existingParent") + } + assert(execStateFile.exists(), s"$execStateFile did not exist") + + // now we pretend the shuffle service goes down, and comes back up + s1.stop() + s2 = new YarnShuffleService + s2.init(yarnConfig) + s2.registeredExecutorFile should be (execStateFile) + + val handler2 = 
s2.blockHandler + val resolver2 = ShuffleTestAccessor.getBlockResolver(handler2) + + // now we reinitialize only one of the apps, and expect yarn to tell us that app2 was stopped + // during the restart + s2.initializeApplication(app1Data) + s2.stopApplication(new ApplicationTerminationContext(app2Id)) + ShuffleTestAccessor.getExecutorInfo(app1Id, "exec-1", resolver2) should be (Some(shuffleInfo1)) + ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", resolver2) should be (None) + + // Act like the NM restarts one more time + s2.stop() + s3 = new YarnShuffleService + s3.init(yarnConfig) + s3.registeredExecutorFile should be (execStateFile) + + val handler3 = s3.blockHandler + val resolver3 = ShuffleTestAccessor.getBlockResolver(handler3) + + // app1 is still running + s3.initializeApplication(app1Data) + ShuffleTestAccessor.getExecutorInfo(app1Id, "exec-1", resolver3) should be (Some(shuffleInfo1)) + ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", resolver3) should be (None) + s3.stop() + } + + test("removed applications should not be in registered executor file") { + s1 = new YarnShuffleService + s1.init(yarnConfig) + val app1Id = ApplicationId.newInstance(0, 1) + val app1Data: ApplicationInitializationContext = + new ApplicationInitializationContext("user", app1Id, null) + s1.initializeApplication(app1Data) + val app2Id = ApplicationId.newInstance(0, 2) + val app2Data: ApplicationInitializationContext = + new ApplicationInitializationContext("user", app2Id, null) + s1.initializeApplication(app2Data) + + val execStateFile = s1.registeredExecutorFile + execStateFile should not be (null) + val shuffleInfo1 = new ExecutorShuffleInfo(Array("/foo", "/bar"), 3, "sort") + val shuffleInfo2 = new ExecutorShuffleInfo(Array("/bippy"), 5, "hash") + + val blockHandler = s1.blockHandler + val blockResolver = ShuffleTestAccessor.getBlockResolver(blockHandler) + ShuffleTestAccessor.registeredExecutorFile(blockResolver) should be (execStateFile) + + blockResolver.registerExecutor(app1Id.toString, "exec-1", shuffleInfo1) + blockResolver.registerExecutor(app2Id.toString, "exec-2", shuffleInfo2) + + val db = ShuffleTestAccessor.shuffleServiceLevelDB(blockResolver) + ShuffleTestAccessor.reloadRegisteredExecutors(db) should not be empty + + s1.stopApplication(new ApplicationTerminationContext(app1Id)) + ShuffleTestAccessor.reloadRegisteredExecutors(db) should not be empty + s1.stopApplication(new ApplicationTerminationContext(app2Id)) + ShuffleTestAccessor.reloadRegisteredExecutors(db) shouldBe empty + } + + test("shuffle service should be robust to corrupt registered executor file") { + s1 = new YarnShuffleService + s1.init(yarnConfig) + val app1Id = ApplicationId.newInstance(0, 1) + val app1Data: ApplicationInitializationContext = + new ApplicationInitializationContext("user", app1Id, null) + s1.initializeApplication(app1Data) + + val execStateFile = s1.registeredExecutorFile + val shuffleInfo1 = new ExecutorShuffleInfo(Array("/foo", "/bar"), 3, "sort") + + val blockHandler = s1.blockHandler + val blockResolver = ShuffleTestAccessor.getBlockResolver(blockHandler) + ShuffleTestAccessor.registeredExecutorFile(blockResolver) should be (execStateFile) + + blockResolver.registerExecutor(app1Id.toString, "exec-1", shuffleInfo1) + + // now we pretend the shuffle service goes down, and comes back up. 
But we'll also + // make a corrupt registeredExecutor File + s1.stop() + + execStateFile.listFiles().foreach{_.delete()} + + val out = new DataOutputStream(new FileOutputStream(execStateFile + "/CURRENT")) + out.writeInt(42) + out.close() + + s2 = new YarnShuffleService + s2.init(yarnConfig) + s2.registeredExecutorFile should be (execStateFile) + + val handler2 = s2.blockHandler + val resolver2 = ShuffleTestAccessor.getBlockResolver(handler2) + + // we re-initialize app1, but since the file was corrupt there is nothing we can do about it ... + s2.initializeApplication(app1Data) + // however, when we initialize a totally new app2, everything is still happy + val app2Id = ApplicationId.newInstance(0, 2) + val app2Data: ApplicationInitializationContext = + new ApplicationInitializationContext("user", app2Id, null) + s2.initializeApplication(app2Data) + val shuffleInfo2 = new ExecutorShuffleInfo(Array("/bippy"), 5, "hash") + resolver2.registerExecutor(app2Id.toString, "exec-2", shuffleInfo2) + ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", resolver2) should be (Some(shuffleInfo2)) + s2.stop() + + // another stop & restart should be fine though (e.g., we recover from previous corruption) + s3 = new YarnShuffleService + s3.init(yarnConfig) + s3.registeredExecutorFile should be (execStateFile) + val handler3 = s3.blockHandler + val resolver3 = ShuffleTestAccessor.getBlockResolver(handler3) + + s3.initializeApplication(app2Data) + ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", resolver3) should be (Some(shuffleInfo2)) + s3.stop() + + } + +} diff --git a/yarn/src/test/scala/org/apache/spark/network/yarn/YarnTestAccessor.scala b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnTestAccessor.scala new file mode 100644 index 000000000000..db322cd18e15 --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnTestAccessor.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.spark.network.yarn + +import java.io.File + +/** + * just a cheat to get package-visible members in tests + */ +object YarnTestAccessor { + def getShuffleServicePort: Int = { + YarnShuffleService.boundPort + } + + def getShuffleServiceInstance: YarnShuffleService = { + YarnShuffleService.instance + } + + def getRegisteredExecutorFile(service: YarnShuffleService): File = { + service.registeredExecutorFile + } + +} diff --git a/yarn/src/test/scala/org/apache/spark/scheduler/cluster/ExtensionServiceIntegrationSuite.scala b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/ExtensionServiceIntegrationSuite.scala new file mode 100644 index 000000000000..b4d1b0a3d22a --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/ExtensionServiceIntegrationSuite.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{LocalSparkContext, Logging, SparkConf, SparkContext, SparkFunSuite} + +/** + * Test the integration with [[SchedulerExtensionServices]] + */ +class ExtensionServiceIntegrationSuite extends SparkFunSuite + with LocalSparkContext with BeforeAndAfter + with Logging { + + val applicationId = new StubApplicationId(0, 1111L) + val attemptId = new StubApplicationAttemptId(applicationId, 1) + + /* + * Setup phase creates the spark context + */ + before { + val sparkConf = new SparkConf() + sparkConf.set(SchedulerExtensionServices.SPARK_YARN_SERVICES, + classOf[SimpleExtensionService].getName()) + sparkConf.setMaster("local").setAppName("ExtensionServiceIntegrationSuite") + sc = new SparkContext(sparkConf) + } + + test("Instantiate") { + val services = new SchedulerExtensionServices() + assertResult(Nil, "non-nil service list") { + services.getServices + } + services.start(SchedulerExtensionServiceBinding(sc, applicationId)) + services.stop() + } + + test("Contains SimpleExtensionService Service") { + val services = new SchedulerExtensionServices() + try { + services.start(SchedulerExtensionServiceBinding(sc, applicationId)) + val serviceList = services.getServices + assert(serviceList.nonEmpty, "empty service list") + val (service :: Nil) = serviceList + val simpleService = service.asInstanceOf[SimpleExtensionService] + assert(simpleService.started.get, "service not started") + services.stop() + assert(!simpleService.started.get, "service not stopped") + } finally { + services.stop() + } + } +} + + diff --git a/yarn/src/test/scala/org/apache/spark/scheduler/cluster/SimpleExtensionService.scala b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/SimpleExtensionService.scala new file mode 100644 index 000000000000..9b8c98cda8da --- /dev/null +++ 
b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/SimpleExtensionService.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import java.util.concurrent.atomic.AtomicBoolean + +private[spark] class SimpleExtensionService extends SchedulerExtensionService { + + /** started flag; set in the `start()` call, stopped in `stop()`. */ + val started = new AtomicBoolean(false) + + override def start(binding: SchedulerExtensionServiceBinding): Unit = { + started.set(true) + } + + override def stop(): Unit = { + started.set(false) + } +} diff --git a/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationAttemptId.scala b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationAttemptId.scala new file mode 100644 index 000000000000..4b57b9509a65 --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationAttemptId.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId} + +/** + * A stub application ID; can be set in constructor and/or updated later. 
+ * @param applicationId application ID + * @param attempt an attempt counter + */ +class StubApplicationAttemptId(var applicationId: ApplicationId, var attempt: Int) + extends ApplicationAttemptId { + + override def setApplicationId(appID: ApplicationId): Unit = { + applicationId = appID + } + + override def getAttemptId: Int = { + attempt + } + + override def setAttemptId(attemptId: Int): Unit = { + attempt = attemptId + } + + override def getApplicationId: ApplicationId = { + applicationId + } + + override def build(): Unit = { + } +} diff --git a/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationId.scala b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationId.scala new file mode 100644 index 000000000000..bffa0e09befd --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationId.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import org.apache.hadoop.yarn.api.records.ApplicationId + +/** + * Simple Testing Application Id; ID and cluster timestamp are set in constructor + * and cannot be updated. + * @param id app id + * @param clusterTimestamp timestamp + */ +private[spark] class StubApplicationId(id: Int, clusterTimestamp: Long) extends ApplicationId { + override def getId: Int = { + id + } + + override def getClusterTimestamp: Long = { + clusterTimestamp + } + + override def setId(id: Int): Unit = {} + + override def setClusterTimestamp(clusterTimestamp: Long): Unit = {} + + override def build(): Unit = {} +}

        Status
        Location
        Executor ID / Host
        Last Error Time
        Last Error Message
        {failureReasonSummary}{details} + {failureReasonSummary}{details} +